Notes on "Attention Is All You Need" (2017)

$$ \newcommand{\bzero}{\mathbf{0}} \newcommand{\ba}{\mathbf{a}}\newcommand{\bA}{\mathbf{A}} \newcommand{\bb}{\mathbf{b}}\newcommand{\bB}{\mathbf{B}} \newcommand{\bc}{\mathbf{c}}\newcommand{\bC}{\mathbf{C}} \newcommand{\bd}{\mathbf{d}}\newcommand{\bD}{\mathbf{D}} \newcommand{\be}{\mathbf{e}}\newcommand{\bE}{\mathbf{E}} \newcommand{\bff}{\mathbf{f}}\newcommand{\bF}{\mathbf{F}} \newcommand{\bg}{\mathbf{g}}\newcommand{\bG}{\mathbf{G}} \newcommand{\bh}{\mathbf{h}}\newcommand{\bH}{\mathbf{H}} \newcommand{\bi}{\mathbf{i}}\newcommand{\bI}{\mathbf{I}} \newcommand{\bj}{\mathbf{j}}\newcommand{\bJ}{\mathbf{J}} \newcommand{\bk}{\mathbf{k}}\newcommand{\bK}{\mathbf{K}} \newcommand{\bl}{\mathbf{l}}\newcommand{\bL}{\mathbf{L}} \newcommand{\bm}{\mathbf{m}}\newcommand{\bM}{\mathbf{M}} \newcommand{\bn}{\mathbf{n}}\newcommand{\bN}{\mathbf{N}} \newcommand{\bo}{\mathbf{o}}\newcommand{\bO}{\mathbf{O}} \newcommand{\bp}{\mathbf{p}}\newcommand{\bP}{\mathbf{P}} \newcommand{\bq}{\mathbf{q}}\newcommand{\bQ}{\mathbf{Q}} \newcommand{\br}{\mathbf{r}}\newcommand{\bR}{\mathbf{R}} \newcommand{\bs}{\mathbf{s}}\newcommand{\bS}{\mathbf{S}} \newcommand{\bt}{\mathbf{t}}\newcommand{\bT}{\mathbf{T}} \newcommand{\bu}{\mathbf{u}}\newcommand{\bU}{\mathbf{U}} \newcommand{\bv}{\mathbf{v}}\newcommand{\bV}{\mathbf{V}} \newcommand{\bw}{\mathbf{w}}\newcommand{\bW}{\mathbf{W}} \newcommand{\bx}{\mathbf{x}}\newcommand{\bX}{\mathbf{X}} \newcommand{\by}{\mathbf{y}}\newcommand{\bY}{\mathbf{Y}} \newcommand{\bz}{\mathbf{z}}\newcommand{\bZ}{\mathbf{Z}} \newcommand{\balpha}{\boldsymbol{\alpha}}\newcommand{\bAlpha}{\boldsymbol{\Alpha}} \newcommand{\bbeta}{\boldsymbol{\beta}}\newcommand{\bBeta}{\boldsymbol{\Beta}} \newcommand{\bgamma}{\boldsymbol{\gamma}}\newcommand{\bGamma}{\boldsymbol{\Gamma}} \newcommand{\bdelta}{\boldsymbol{\delta}}\newcommand{\bDelta}{\boldsymbol{\Delta}} \newcommand{\bepsilon}{\boldsymbol{\epsilon}}\newcommand{\bEpsilon}{\boldsymbol{\Epsilon}} 
\newcommand{\bzeta}{\boldsymbol{\zeta}}\newcommand{\bZeta}{\boldsymbol{\Zeta}} \newcommand{\beeta}{\boldsymbol{\eta}}\newcommand{\bEta}{\boldsymbol{\Eta}} % \beta already taken \newcommand{\btheta}{\boldsymbol{\theta}}\newcommand{\bTheta}{\boldsymbol{\Theta}} \newcommand{\biota}{\boldsymbol{\iota}}\newcommand{\bIota}{\boldsymbol{\Iota}} \newcommand{\bkappa}{\boldsymbol{\kappa}}\newcommand{\bKappa}{\boldsymbol{\Kappa}} \newcommand{\blambda}{\boldsymbol{\lambda}}\newcommand{\bLambda}{\boldsymbol{\Lambda}} \newcommand{\bmu}{\boldsymbol{\mu}}\newcommand{\bMu}{\boldsymbol{\Mu}} \newcommand{\bnu}{\boldsymbol{\nu}}\newcommand{\bNu}{\boldsymbol{\Nu}} \newcommand{\bxi}{\boldsymbol{\xi}}\newcommand{\bXi}{\boldsymbol{\Xi}} \newcommand{\bomikron}{\boldsymbol{\omikron}}\newcommand{\bOmikron}{\boldsymbol{\Omikron}} \newcommand{\bpi}{\boldsymbol{\pi}}\newcommand{\bPi}{\boldsymbol{\Pi}} \newcommand{\brho}{\boldsymbol{\rho}}\newcommand{\bRho}{\boldsymbol{\Rho}} \newcommand{\bsigma}{\boldsymbol{\sigma}}\newcommand{\bSigma}{\boldsymbol{\Sigma}} \newcommand{\btau}{\boldsymbol{\tau}}\newcommand{\bTau}{\boldsymbol{\Tau}} \newcommand{\bypsilon}{\boldsymbol{\ypsilon}}\newcommand{\bYpsilon}{\boldsymbol{\Ypsilon}} \newcommand{\bphi}{\boldsymbol{\phi}}\newcommand{\bPhi}{\boldsymbol{\Phi}} \newcommand{\bchi}{\boldsymbol{\chi}}\newcommand{\bChi}{\boldsymbol{\Chi}} \newcommand{\bpsi}{\boldsymbol{\psi}}\newcommand{\bPsi}{\boldsymbol{\Psi}} \newcommand{\bomega}{\boldsymbol{\omega}}\newcommand{\bOmega}{\boldsymbol{\Omega}} \newcommand{\nA}{\mathbb{A}} \newcommand{\nB}{\mathbb{B}} \newcommand{\nC}{\mathbb{C}} \newcommand{\nD}{\mathbb{D}} \newcommand{\nE}{\mathbb{E}} \newcommand{\nF}{\mathbb{F}} \newcommand{\nG}{\mathbb{G}} \newcommand{\nH}{\mathbb{H}} \newcommand{\nI}{\mathbb{I}} \newcommand{\nJ}{\mathbb{J}} \newcommand{\nK}{\mathbb{K}} \newcommand{\nL}{\mathbb{L}} \newcommand{\nM}{\mathbb{M}} \newcommand{\nN}{\mathbb{N}} \newcommand{\nO}{\mathbb{O}} \newcommand{\nP}{\mathbb{P}} 
\newcommand{\nQ}{\mathbb{Q}} \newcommand{\nR}{\mathbb{R}} \newcommand{\nS}{\mathbb{S}} \newcommand{\nT}{\mathbb{T}} \newcommand{\nU}{\mathbb{U}} \newcommand{\nV}{\mathbb{V}} \newcommand{\nW}{\mathbb{W}} \newcommand{\nX}{\mathbb{X}} \newcommand{\nY}{\mathbb{Y}} \newcommand{\nZ}{\mathbb{Z}} \newcommand{\cA}{\mathcal{A}} \newcommand{\cB}{\mathcal{B}} \newcommand{\cC}{\mathcal{C}} \newcommand{\cD}{\mathcal{D}} \newcommand{\cE}{\mathcal{E}} \newcommand{\cF}{\mathcal{F}} \newcommand{\cG}{\mathcal{G}} \newcommand{\cH}{\mathcal{H}} \newcommand{\cI}{\mathcal{I}} \newcommand{\cJ}{\mathcal{J}} \newcommand{\cK}{\mathcal{K}} \newcommand{\cL}{\mathcal{L}} \newcommand{\cM}{\mathcal{M}} \newcommand{\cN}{\mathcal{N}} \newcommand{\cO}{\mathcal{O}} \newcommand{\cP}{\mathcal{P}} \newcommand{\cQ}{\mathcal{Q}} \newcommand{\cR}{\mathcal{R}} \newcommand{\cS}{\mathcal{S}} \newcommand{\cT}{\mathcal{T}} \newcommand{\cU}{\mathcal{U}} \newcommand{\cV}{\mathcal{V}} \newcommand{\cW}{\mathcal{W}} \newcommand{\cX}{\mathcal{X}} \newcommand{\cY}{\mathcal{Y}} \newcommand{\cZ}{\mathcal{Z}} \DeclareMathOperator*{\argmax}{argmax~} \DeclareMathOperator*{\argmin}{argmin~} \DeclareMathOperator*{\Tr}{Tr} \DeclareMathOperator*{\Bias}{Bias} \DeclareMathOperator*{\Var}{Var} \newcommand{\Perp}{\perp\!\!\! 
\perp} \let\dsad\d \renewcommand\d{\mathrm{d}} \newcommand{\R}{\mathbb{R}} \newcommand{\N}{\mathbb{N}} \newcommand{\E}{\mathbb{E}} \newcommand{\Eb}[1]{\mathbb{E} \left[ #1 \right]} \newcommand{\F}{\mathcal{F}} \newcommand{\X}{\mathcal{X}} \newcommand{\vocab}[1]{\textbf{\color{blue} #1}} \newcommand{\norm}[1]{\left\lVert#1\right\rVert} \newcommand{\inner}[2]{ \left \langle #1, #2 \right \rangle } \newcommand{\ibra}[1]{\llbracket #1 \rrbracket} \newenvironment{nalign}{ \begin{equation} \begin{aligned} }{ \end{aligned} \end{equation} \ignorespacesafterend } \newcommand{\icol}[1]{ \left(\begin{smallmatrix}#1\end{smallmatrix}\right)% } \newcommand{\cc}[1]{\overline{#1}} \newcommand{\tr}[1]{\text{tr} \left( #1 \right)} \newcommand{\aff}{\text{aff} \,} \newcommand{\ri}{\text{ri} \, } \newcommand{\rb}{\text{rb} \, } \newcommand{\cl}{\text{cl} \, } \newcommand{\conv}{\text{conv} \, } \newcommand{\irow}[1]{ \begin{smallmatrix}(#1)\end{smallmatrix}% } \newcommand\abs[1]{\lvert #1 \rvert} $$

In this post I overview and code up the Transformer.

Motivation

It has long been known that Recurrent Neural Networks (RNNs) are much harder to train at scale than Multi-Layer Perceptrons (MLPs), precisely because of their sequential nature. Additionally, the recurrent nature of RNNs makes it harder for them to retain long-term information.

Question: Is the bottleneck with RNNs fundamentally computational or also conceptual? Is the advance only based on the breakthrough in computation (i.e. is the progress based solely on the fit of hardware and type of computation introduced here?)

I think it is both. Computationally, RNNs are harder to train; conceptually, their recurrent nature makes them suffer from vanishing gradients and, in general, be less adept at capturing long-term dependencies. In contrast, Transformers don’t suffer from this precisely because they implement pairwise comparison between tokens.

The main goal here is to work with time-series data using an attention mechanism, without sequential computation.

One interesting point in the paper is the emphasis on long-term dependencies. Classically, sequential networks such as RNNs have a bias [citation needed!] towards recent inputs and states. This is exactly the “fundamental” reason for their shortcomings. The paper emphasizes that, at the time, attention was mostly used in conjunction with RNNs. The upside of this was the ability to have input-position-independent computations. The proposed approach, however, tries to get by with attention alone.

Interlude: Layer normalization

Architecture

The Transformer consists of two main building blocks: the encoder and the decoder. From the computational perspective, the encoder can be seen as, in RNN lingo, a one-shot computation that produces hidden representations of the data. The decoder is closer to an RNN because of the sequential way in which the outputs are produced. However, note that all representations are available at once in the decoder, so problems of information retention and filtering are non-existent here (at least in direct comparison to RNNs). Additionally, the Transformer is a set-based model. This means that if the positional encoding step is skipped, the representation that the Transformer yields for each input element is identical regardless of the order of the inputs. We now look into the components in greater detail.
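To make the set-based claim concrete, here is a minimal sketch in plain Python. It assumes a stripped-down self-attention with identity projections (keys, queries and values are all equal to the input rows) and no positional encoding; the function name `self_attention` and the toy sizes are illustrative choices, not part of the paper.

```python
import math
import random

def self_attention(R):
    """Self-attention with identity projections (K = Q = V = R), a simplifying assumption."""
    d = len(R[0])
    out = []
    for q in R:
        # Scaled dot-product score of this row (as query) against every row (as key).
        scores = [sum(a * b for a, b in zip(q, k)) / math.sqrt(d) for k in R]
        top = max(scores)
        exps = [math.exp(s - top) for s in scores]
        total = sum(exps)
        # Weighted sum of all rows (as values).
        out.append([sum(e / total * v[c] for e, v in zip(exps, R)) for c in range(d)])
    return out

random.seed(0)
R = [[random.gauss(0, 1) for _ in range(4)] for _ in range(5)]
perm = [3, 0, 4, 1, 2]

Z = self_attention(R)
Z_perm = self_attention([R[i] for i in perm])

# Row j of the permuted run corresponds to original row perm[j]:
# the outputs are reordered, but each element keeps its representation.
same = all(
    math.isclose(a, b, abs_tol=1e-9)
    for i, row in zip(perm, Z_perm)
    for a, b in zip(Z[i], row)
)
print(same)
```

Shuffling the inputs only shuffles the outputs in the same way, which is why order information has to be injected separately through the positional encoding.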

Figure 1. Transformer architecture.

Encoder

The central component of the encoder is the attention block, which produces pairwise representations of its inputs. When we stack multiple encoding layers, this leads to higher-order attention (think pairwise-pairwise connections), and so on. In a way, I guess, it forms a hierarchical connectivity, irrespective of the positions of the inputs (though positions can be baked into the inputs themselves, as explained later). Each encoder layer consists of two main blocks: multi-head attention and a feed-forward network.

Multi-head attention constructs a soft dictionary and provides queries for input elements to attend to, and finally queries the values from the dictionary. At this step, each element in the input sequence has indexed a content vector from other positions in the input that is the most helpful/important to it.

Then a feed-forward network processes the attended representations of the inputs. The key here is that this step is performed for each position independently of the others, with the same weights. This is crucial for keeping the model able to process sets as input.

This attend-then-transform process is repeated multiple times to get deeper representations.
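The position-wise nature of the feed-forward step can be sketched as follows. This is a toy illustration in plain Python, assuming a two-layer ReLU network as the feed-forward block; the names `feed_forward`, `rand_matrix`, `d_ff` and all sizes are made up for the example.

```python
import random

def feed_forward(R, W1, b1, W2, b2):
    """Apply the same two-layer ReLU network to every row (position) of R."""
    out = []
    for x in R:
        # Hidden layer: ReLU(x W1 + b1), computed column by column.
        hidden = [max(0.0, sum(xi * w for xi, w in zip(x, col)) + b)
                  for col, b in zip(zip(*W1), b1)]
        # Output layer: hidden W2 + b2.
        out.append([sum(hi * w for hi, w in zip(hidden, col)) + b
                    for col, b in zip(zip(*W2), b2)])
    return out

random.seed(0)

def rand_matrix(rows, cols):
    return [[random.gauss(0, 1) for _ in range(cols)] for _ in range(rows)]

n, d_model, d_ff = 3, 4, 6  # toy sizes, chosen for illustration
W1, b1 = rand_matrix(d_model, d_ff), [0.1] * d_ff
W2, b2 = rand_matrix(d_ff, d_model), [0.0] * d_model

R = rand_matrix(n, d_model)
F = feed_forward(R, W1, b1, W2, b2)

# Changing one position leaves the outputs at all other positions untouched.
R_changed = [[9.0] * d_model] + R[1:]
F_changed = feed_forward(R_changed, W1, b1, W2, b2)
print(F[1:] == F_changed[1:])
```

Because no information flows between rows in this step, all mixing across positions happens in the attention block.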

Attention

Single-head Attention

Single-head attention constructs a dictionary on the fly and performs soft indexing of it. Each input element (or token) in the sequence is transformed into a (key, query, value) triple. Intuitively, the key of each element describes the identifier of the element, the value describes the actual content of the element, and the query describes the key whose value is the most “interesting” for that particular element. Unlike in a dictionary, the queries and keys might not match identically because attention constructs the triples independently of other elements. Because of that, the dictionary indexing is soft: a dot-product is computed between each query and each key to determine the degree of similarity between the identifiers. Finally, this similarity is used in the weighted sum of the contents of each token. Let’s now make it more concrete.

Each attention head in the encoder block has three weight matrices $W^K, W^Q \in \R^{d_{\text{model}} \times d_k}$ and $W^V \in \R^{d_{\text{model}} \times d_v}$; also note that the projection matrices are independent of the length of the sequence. These matrices project the input embeddings to matrices $K, Q, V$, corresponding to the keys, queries and values of all the tokens, defined as follows.

Definition 1 (Key, value and query computation) The keys, values and queries for an embedded input sequence (with representation $R \in \R^{n \times d_{\text{model}}}$) are computed as

\[\begin{equation} \begin{split} K &:= R W^K \\ Q &:= RW^Q \\ V &:= RW^V. \end{split} \end{equation}\]

Keys and values can be used to define the dictionary. Queries can be used to, well, query the dictionary. Since we allow for soft indexing, multiple elements can be matched. This defines an $n \times n$ similarity matrix as follows.

Definition 2 (Key-query similarity matrix) Given queries $Q$ and keys $K$, the similarity matrix $S \in \R^{n \times n} $ is computed as

\[\begin{equation} \begin{split} S := \text{softmax} \left (Q K^\intercal / \sqrt{d_k} \right ), \end{split} \end{equation}\]

where the softmax is taken row-wise, so that each row sums to 1. Row $i$ of $S$ holds the similarities of the query at position $i$ with the keys of all positions.

After we have the similarity matches, we can take the (weighted) contents of the corresponding elements. This defines the attention function.

Definition 3 (Attention) Given $Q, K, V$, the attention at each position is computed as a sum of all values, weighted by the similarity of the query at that position with the keys of all positions

\[\begin{equation} \begin{split} \text{attention}(K, Q, V) := S V = \text{softmax}(Q K^\intercal / \sqrt{d_k}) V. \end{split} \end{equation}\]
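Definitions 1 to 3 can be put together in a short sketch. It uses plain Python lists to stay dependency-free (a real implementation would use an array library), random weights, and toy sizes; the helper names `matmul`, `softmax_rows` and `rand_matrix` are assumptions of this example. Queries index the rows of the score matrix, so that each row of $S$ sums to one.

```python
import math
import random

def matmul(A, B):
    """Product of two matrices stored as lists of rows."""
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*B)] for row in A]

def softmax_rows(M):
    """Row-wise softmax, shifted by the row maximum for numerical stability."""
    out = []
    for row in M:
        top = max(row)
        exps = [math.exp(x - top) for x in row]
        total = sum(exps)
        out.append([e / total for e in exps])
    return out

def attention(R, W_K, W_Q, W_V):
    """Single-head attention on an (n x d_model) representation R."""
    K, Q, V = matmul(R, W_K), matmul(R, W_Q), matmul(R, W_V)
    d_k = len(W_K[0])
    K_T = [list(col) for col in zip(*K)]
    scores = [[s / math.sqrt(d_k) for s in row] for row in matmul(Q, K_T)]
    S = softmax_rows(scores)   # (n x n), row i: query i against all keys
    return S, matmul(S, V)     # (n x d_v)

random.seed(0)

def rand_matrix(rows, cols):
    return [[random.gauss(0, 1) for _ in range(cols)] for _ in range(rows)]

n, d_model, d_k, d_v = 4, 8, 3, 5  # toy sizes, chosen for illustration
R = rand_matrix(n, d_model)
S, Z = attention(R, rand_matrix(d_model, d_k), rand_matrix(d_model, d_k),
                 rand_matrix(d_model, d_v))
print(len(S), len(S[0]), len(Z), len(Z[0]))  # 4 4 4 5
```

Note that none of the weight matrices depends on $n$, so the same head can be applied to sequences of any length.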

Multi-head attention

Multi-head attention is just the computation of attention with several different sets of weights, together with a way to pool the results into a single representation. The pooling is done by concatenating all the representations into a single matrix, and then integrating them through a linear projection.

Remark. Note that in the paper multi-head attention is implemented as linearly transforming the $Q, K, V$ matrices. Here we implement it differently (based on (missing reference)) by defining multiple weight matrices that transform the input representation.

This idea can be implemented as follows.

Definition 4 (Multi-head attention) Define $W^O \in \R^{(h \cdot d_{v}) \times d_{\text{model}}}$, and assume $h$ attention heads are used. Then multi-head attention is defined on a triple of lists $( \{ K_i \}, \{ Q_i \}, \{ V_i \})$ (for $i$ from 1 to $h$), with each $K_i = R W_i^K$ (and analogously for $Q_i, V_i$), as

\[\begin{equation} \begin{split} \text{multi-head-attention}( \{ K_i \}, \{ Q_i \}, \{ V_i\}) := \underbrace{\text{concatenate}(\text{attention}(K_1, Q_1, V_1), \dots, \text{attention}(K_h, Q_h, V_h))}_{\in \R^{n \times (h \cdot d_v)}} W^O \in \R^{n \times d_{\text{model}}}. \end{split} \end{equation}\]

Intuitively, it allows the attention heads to probe for different kinds of information. The job of the final projection $W^O$ is to integrate the most useful information from the separate attention heads.
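A minimal sketch of Definition 4, again in plain Python with random weights and toy sizes. Each head gets its own projection matrices applied to the input representation, as in the remark above; the helper names and the choice $h = 2$ are assumptions of this example.

```python
import math
import random

def matmul(A, B):
    """Product of two matrices stored as lists of rows."""
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*B)] for row in A]

def attention_head(R, W_K, W_Q, W_V):
    """One attention head: softmax(Q K^T / sqrt(d_k)) V with row-wise softmax."""
    K, Q, V = matmul(R, W_K), matmul(R, W_Q), matmul(R, W_V)
    d_k = len(W_K[0])
    weights = []
    for q in Q:
        scores = [sum(a * b for a, b in zip(q, k)) / math.sqrt(d_k) for k in K]
        top = max(scores)
        exps = [math.exp(s - top) for s in scores]
        total = sum(exps)
        weights.append([e / total for e in exps])
    return matmul(weights, V)

def multi_head_attention(R, heads, W_O):
    """Run every head, concatenate along the feature axis, then project with W_O."""
    Zs = [attention_head(R, *weights) for weights in heads]
    C = [[x for Z in Zs for x in Z[i]] for i in range(len(R))]  # (n x h*d_v)
    return matmul(C, W_O)                                       # (n x d_model)

random.seed(0)

def rand_matrix(rows, cols):
    return [[random.gauss(0, 1) for _ in range(cols)] for _ in range(rows)]

n, d_model, d_k, d_v, h = 4, 8, 3, 5, 2  # toy sizes, chosen for illustration
R = rand_matrix(n, d_model)
heads = [(rand_matrix(d_model, d_k), rand_matrix(d_model, d_k), rand_matrix(d_model, d_v))
         for _ in range(h)]
W_O = rand_matrix(h * d_v, d_model)

Z = multi_head_attention(R, heads, W_O)
print(len(Z), len(Z[0]))  # 4 8
```

The output has the same shape as the input representation, which is what makes the residual connection in the add & norm step possible.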

Question: Why does the linear pooling of individual attention heads make sense? It somehow doesn’t seem powerful at all. Why not just skip this step and immediately use a fully-connected layer with a non-linearity? It seems like important information might be discarded by this “pooling” operation.

Decoder

The decoder builds on the blocks introduced in the encoder, though one-shot computation is no longer possible here. Instead, the outputs are produced sequentially, one by one. Inside the decoder there are two attention mechanisms: one over the representations of the outputs produced so far, and one that integrates them with the encoder’s representations. There is some difference in their inner workings.

Output attention

The output attention (denoted as “Masked Multi-Head Attention” in the figure) takes in the keys, queries and values of the so-far-produced outputs’ representations, and produces, contrary to before, not any kind of representations of the outputs, but queries for downstream usage. These queries can then be used in the second attention mechanism that takes the dictionary (defined by keys and values) from the encoder, and the queries from the decoder to produce representations of the integrated values. This can be defined as follows.

Definition 5 (Masked multi-head attention)

The intuition here is that

Positional-encoding

Mechanical example

To understand how the network works mechanically, consider the conceptual path an input takes in a setting with $h$ attention heads. Say our input is a time series $X = (\bx_1, \dots, \bx_n)$ consisting of datapoints $\bx_i$ with an $f$-dimensional feature space, $\bx_i \in \R^f$; overall this means that $X \in \R^{n \times f}$. We can trace the path of the input as follows:

  • Input $X$ is passed to the input embedding layer; this yields a representation of the input $R \in \R^{n \times d_{\text{model}}}$.
  • The positional encodings $\text{PE} \in \R^{n \times d_{\text{model}}}$ are computed and added to the representation $R$ to get the encoded representation $R^{\text{pos}} = R + \text{PE} \in \R^{n \times d_{\text{model}}}$.
  • Encoding phase (repeat $n_e$ times, where at subsequent layers the input is taken to be the output of the previous layer, i.e. the embedding step is skipped).
    • Input processing step:
      • The positional representation is passed through the attention heads. Head $i=1,\dots,h$ produces $n$ vectors in a matrix $Z_i \in \R^{n \times d_v}$ by first computing the key, values and queries, computing pair-wise dot-products as scores for attention, and then computing the actual attention as a sum of values weighted by the scores. For each head the following process is followed:
        • Key is computed as $K_i = R^{\text{pos}} W_i^K \in \R^{n \times d_{k}}$
        • Query is computed as $Q_i = R^{\text{pos}} W_i^Q \in \R^{n \times d_k}$
        • Value is computed as $V_i = R^{\text{pos}} W_i^V \in \R^{n \times d_v}$
        • Attention score matrix is computed as $S_i = Q_i K_i^\intercal \in \R^{n \times n}$
        • Attention is computed as $Z_i = \text{softmax}\left(S_i / \sqrt{d_k} \right )V_i \in \R^{n \times d_v}$.
      • Attention variables $Z_i$ are concatenated into $C := \text{concatenate}(Z_1, \dots, Z_h) \in \R^{n \times (h \cdot d_v)}$, and projected back to $d_{\text{model}}$ dimensions, resulting in the final “pooled” attention variable $Z = C W^O \in \R^{n \times d_{\text{model}}}$.
    • Add & Norm step:
      • Take the positional input representation $R^{\text{pos}}$, add it to the attention variable $Z$ and apply layer normalization, yielding $\text{add-norm}_{1} := \text{layer-norm}(R^{\text{pos}} + Z) \in \R^{n \times d_{\text{model}}}$.
    • Feed Forward step:
      • Transform the representation through a non-linearity as $F := \text{MLP}(\text{add-norm}_1) \in \R^{n \times d_{\text{model}}}$.
    • Add & Norm step:
      • $\text{add-norm}_2 := \text{layer-norm}(F + \text{add-norm}_1) \in \R^{n \times d_{\text{model}}}$.

    Now the process of encoding can be repeated in the same way (just with different parameters) multiple times. The most important detail here is that after the encoding phase we end up with a representation of the input that can be used in the decoding step, implemented as follows. We now assume that all the input points have been processed, yielding $n$ representation vectors of the input sequence.

    The output-producing (decoding) step is sequential (at least in testing phase; in the training phase we do teacher forcing). The first step in this process looks as follows.

  • Decoding step (again, repeat the process multiple times if needed). Assume we have already decoded $a$ points ($a \in \{0, \dots, m\}$). Then we proceed as follows:
    • The sequence of all decoded outputs is passed through the embedding layer; if $a = 0$, the start-of-sequence token $\bx^{\text{out}}_1 := \text{SOS} \in \R^{f}$ is fed to the model instead (one can also reuse the end-of-sequence symbol for this purpose, see).
    • The tokens go through the input embedding, and a representation matrix $R := (\br_1, \dots, \br_{a+1}) \in \R^{(a+1) \times d_{\text{model}}}$ is computed.
    • The representation $\br_1$ then goes through the positional encoding layer, yielding $\br_1^{\text{pos}} \in \R^{d_{\text{model}}}$.
    • The positional representation goes through the masked multi-head attention, which works identically to self-attention, except that the scores for future positions in the attention matrix are set to $-\infty$, since we don’t want the model to attend to them. This is just an implementational detail; in a perfect world we wouldn’t need to worry about processing sequences in parallel. This yields an attended representation $\bz_1 \in \R^{d_{\text{model}}}$, which at the first step only attends to itself.
    • The representation goes through the add & layer normalize block; this yields $\text{add-norm}_1 := \text{layer-norm}(\bz_1 + \br_1^{\text{pos}}) \in \R^{d_{\text{model}}}$.
    • $\text{add-norm}_1$ goes through encoder-decoder attention. Firstly, the keys and values are computed from the encoder representation: $K^{\text{enc}} \in \R^{} V^{\text{enc}} $
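The masking used in the decoder's self-attention step above can be sketched on a hand-written score matrix. This is a minimal illustration in plain Python; the function name `causal_attention_weights` and the example scores are made up.

```python
import math

def causal_attention_weights(scores):
    """Row-wise softmax where position i may only attend to positions j <= i."""
    weights = []
    for i, row in enumerate(scores):
        # Scores of future positions are replaced by -inf, so exp(.) gives 0.
        masked = [s if j <= i else float("-inf") for j, s in enumerate(row)]
        top = max(masked)
        exps = [math.exp(s - top) for s in masked]
        total = sum(exps)
        weights.append([e / total for e in exps])
    return weights

# Unnormalized, already-scaled scores for a sequence of three positions.
scores = [[0.5, 2.0, -1.0],
          [1.0, 1.0, 3.0],
          [0.0, 0.0, 0.0]]

weights = causal_attention_weights(scores)
for row in weights:
    print([round(w, 3) for w in row])
# [1.0, 0.0, 0.0]
# [0.5, 0.5, 0.0]
# [0.333, 0.333, 0.333]
```

The first row shows that the first position can only attend to itself, and the resulting weight matrix is lower triangular, which is what lets all positions be processed in parallel during training without leaking future outputs.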

On complexity and comparison to RNNs

Training details

Assume we are concerned with a sequence-to-sequence task, where the input sequence is represented as $(\bx_1, \dots, \bx_n)$ and the corresponding target sequence is represented as $(\by_1, \dots, \by_m)$. Then this is the type of data used in training:

  • The sequence that the encoder model sees is the original sequence with an $\text{EOS}$ token appended to its end, i.e. $(\bx_1, \dots, \bx_n, \text{EOS})$.
  • The target sequence (the sequence that the model should match in the decoder) is the original target sequence with an $\text{EOS}$ token appended to its end, i.e. $(\by_1, \dots, \by_m, \text{EOS})$.
  • The sequence that the decoder sees in training (using teacher forcing), is identical to the target sequence, except that the start-of-sequence token is prepended to it, i.e. $(\text{SOS}, \by_1, \dots, \by_m)$.
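The three sequences listed above can be built with a few lines of Python. This is a small sketch with string tokens; the token spellings `<sos>`/`<eos>` and the function name `make_training_example` are illustrative assumptions.

```python
SOS, EOS = "<sos>", "<eos>"  # hypothetical special tokens

def make_training_example(source, target):
    """Build encoder input, decoder input and decoder target for teacher forcing."""
    encoder_input = list(source) + [EOS]   # (x_1, ..., x_n, EOS)
    decoder_input = [SOS] + list(target)   # (SOS, y_1, ..., y_m)
    decoder_target = list(target) + [EOS]  # (y_1, ..., y_m, EOS)
    return encoder_input, decoder_input, decoder_target

enc_in, dec_in, dec_target = make_training_example(["x1", "x2", "x3"], ["y1", "y2"])
print(enc_in)      # ['x1', 'x2', 'x3', '<eos>']
print(dec_in)      # ['<sos>', 'y1', 'y2']
print(dec_target)  # ['y1', 'y2', '<eos>']
```

The decoder input and the target have the same length and are shifted by one position relative to each other, so at every position the model is trained to predict the next element.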

Results

Implementing the model

There are some tricky details at the code level. One of them is variable-size inputs. Since we want to vectorize as much as possible, we need to make sure that all the sequences are padded and the extra padded data is ignored. This applies both to input and output sequences:

  • In the input sequences, the sequences are padded, i.e. additional “dummy” inputs are added until the sequence is of the required length.
  • In the output sequences, each
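Input-side padding can be sketched as follows. This is one possible scheme, not necessarily the one used in any particular implementation: it assumes padding on the right with a hypothetical `<pad>` token and returns a boolean mask that downstream code can use to ignore the padded positions.

```python
PAD = "<pad>"  # hypothetical padding token

def pad_batch(sequences, pad=PAD):
    """Right-pad every sequence to the longest length and mark the real positions."""
    max_len = max(len(seq) for seq in sequences)
    padded = [list(seq) + [pad] * (max_len - len(seq)) for seq in sequences]
    # True marks a real token, False marks padding that should be ignored.
    mask = [[True] * len(seq) + [False] * (max_len - len(seq)) for seq in sequences]
    return padded, mask

padded, mask = pad_batch([["a", "b", "c"], ["d"]])
print(padded)  # [['a', 'b', 'c'], ['d', '<pad>', '<pad>']]
print(mask)    # [[True, True, True], [True, False, False]]
```

In an attention layer, such a mask would typically be applied in the same way as the causal mask, by setting the scores of padded key positions to $-\infty$ before the softmax.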

$f(x)$.



