Transformer Networks

Attention is all you need.


Since the release of BERT and GPT-2, I have created a set of slides that are clearer than these notes and have better examples. See them here.


Attention Is All You Need introduces the Transformer Network. This network is a shift away from recurrent networks; economy inspires its design. It does not use stateful or recurrent functions; instead, it is parallelized across all symbols in an input sequence. However, it is difficult at first to see how this works with sequences of different lengths. The primary goal of this post is to demonstrate how the Transformer Network fits together.


This work focuses on the task of natural language translation (e.g. translating English to German or vice versa). This notebook focuses on the unique modules the authors present and how the system fits together. The Transformer Network (TN) is composed of attention modules, linear mappings, and regularization features, and it uses an Encoder-Decoder structure. Since this publication, Transformer Network encoders have been used to great success in a wide range of applications.

I modify and present implementations from the tensor2tensor library and The Annotated Transformer. For a complete view and implementation of this system, please visit these sources. Diagrams are recreations, and all block quotes are from the original paper.

Encoder-Decoder Structure

The transformer uses an encoder-decoder structure: an input sequence of symbols, x = \{ x_1, x_2, ..., x_n \}, is encoded into a sequence of continuous variables, \mathbf{z} = \{ z_1, z_2, ..., z_n \}. This is then decoded into a sequence of symbols, y = \{ y_1, y_2, ..., y_m \}, whose length m need not equal n. In some cases, \mathbf{z} is a single continuous variable rather than a sequence. The generation of these symbols occurs one at a time: it is auto-regressive, consuming previously generated symbols as additional input when generating the next. Encoder-decoders often use recurrent architectures.

According to Cho et al., the encoding function e can be any non-linear function, but it is often implemented as an RNN.

h_{t} = e(h_{t-1}, x_t)
Encoding the input into a hidden state.

The input sentence x is encoded into the vector \mathbf{z}. Depending on the implementation, we take either the final hidden state as the encoding or some operation on all the hidden states.

\mathbf{z} = \sum_t h_t
Summarize the hidden states. Attention could be used here.

Next, we decode \mathbf{z} into the output predictions y. This typically uses a recurrent function (RNN).

h_{t} = d(h_{t-1}, y_{t-1}, \mathbf{z})
Process the outputs, given the previously generated symbol along with the summarized vector \mathbf{z}.
y_{t} = g(h_{t}, y_{t-1}, \mathbf{z})
Decode into output symbols.

This recurrent formulation has a few drawbacks:
  1. It is sequential and cannot be easily parallelized.
  2. Often \mathbf{z} is input into each instance of the decoding function. Because there is an O(n) distance from \mathbf{z} to each input symbol, it becomes difficult to learn long-range dependencies.
  3. The path between an output symbol and its corresponding source symbol depends on the length of x.
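
The sequential nature of the recurrent encoder can be seen in a minimal sketch. It is only an illustration: `nn.GRUCell` stands in for the encoding function e, the sizes are made up, and the decoder is omitted.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
d_in, d_hid, n = 4, 8, 5              # hypothetical sizes, not from the paper

encoder_cell = nn.GRUCell(d_in, d_hid)  # assumption: a GRU cell plays the role of e

x = torch.randn(n, 1, d_in)           # n input symbols, batch of 1
h = torch.zeros(1, d_hid)             # initial hidden state h_0
hidden_states = []
for t in range(n):
    # Step t needs the hidden state from step t-1, so the loop
    # cannot be parallelized across the sequence.
    h = encoder_cell(x[t], h)
    hidden_states.append(h)

# z = sum_t h_t: summarize all hidden states into one vector.
z = torch.stack(hidden_states).sum(dim=0)
print(z.shape)
```

Every input symbol reaches \mathbf{z} only through the chain of hidden states, which is where the length-dependent paths in the list above come from.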

TN's stateless auto-regressive strategy decodes the encoded (but not summarized) source words together with the current output words, outputting probability distributions for new symbols. This allows the model to be parallelized.

Scaled Dot-Product Attention

The authors describe attention as follows:

An attention function can be described as mapping a query and a set of key-value pairs to an output, where the query, keys, values, and output are all vectors. The output is computed as a weighted sum of the values, where the weight assigned to each value is computed by a compatibility function of the query with the corresponding key.

As noted by the authors, attention maps a query to a combination of given outputs, as determined by the query's compatibility with the input keys. As the autological name "Scaled Dot-Product Attention" implies, the authors use a dot product for their compatibility function. One could use any metric, learned or otherwise, for example cosine similarity or a feedforward neural network layer.

For their formulation of attention to work, there are a few requirements for the inputs. There must be a mapping between the keys and values, and the compatibility function must be defined for the queries and the keys. In the paper, there is a 1:1 mapping between the keys and values (by index), and the dot-product compatibility function requires that the queries and the keys have the same dimensionality.

Attention intuition.
  1. Each key K_i maps to a value V_i.
  2. Each query Q_j will operate on all the keys with a compatibility function (dot product). As shown in (b), the closer the vectors are in high-dimensional space, the more compatible. These scores will be transformed into a probability distribution by a softmax.
  3. Then, each query will be mapped to a linear combination of the values as determined by the probability distribution (c).

As shown in the example above, the query q_1 is most similar to k_1, thus it is predominantly mapped to the corresponding value v_1. Note: these values are examples, not accurate.
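
The same intuition can be reproduced numerically. The sketch below uses made-up keys, values, and a single query (none of these numbers come from the figure) to show the query being mapped almost entirely to the value whose key it matches.

```python
import math
import torch

# Hypothetical toy data: two orthogonal keys, each indexing one value.
K = torch.tensor([[10.0, 0.0],
                  [0.0, 10.0]])
V = torch.tensor([[1.0, 0.0, 0.0],
                  [0.0, 1.0, 0.0]])
Q = torch.tensor([[1.0, 0.1]])        # q_1 points mostly along k_1

scores = Q @ K.T / math.sqrt(K.size(-1))   # compatibility of q_1 with each key
weights = torch.softmax(scores, dim=-1)    # probability distribution over the keys
out = weights @ V                          # weighted sum of the values

print(weights)
print(out)
```

Because q_1 is far more compatible with k_1 than with k_2, nearly all of the weight lands on v_1, and the output is close to v_1.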

The scaled dot product attention is straightforward.

A: Q \times K \times V \to O
The attention operates on queries Q, keys K, and values V.
A(Q, K, V) = \text{SOFTMAX}\left(\frac{QK^{\intercal}}{\sqrt{d}}\right) V
Attention definition.
Q \in \mathbb{R}^{q \times d}, \quad K \in \mathbb{R}^{n \times d}, \quad V \in \mathbb{R}^{n \times v}
O \in \mathbb{R}^{q \times v}
Attended output.

The authors note that the variance of a dot product scales with the size of the input vectors. Increased variance will result in increased magnitude, "pushing the softmax function into regions where it has extremely small gradients." This motivates the scaling of the dot-product based on the dimensionality of the input vectors.
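
A quick empirical check of this claim, as a sketch: it assumes the components of the queries and keys are independent with mean 0 and variance 1, in which case the dot product has variance d. The helper name and sample count are made up for illustration.

```python
import math
import torch

torch.manual_seed(0)

def dot_product_variance(d, samples=10000):
    """Return (variance of q.k, variance of q.k / sqrt(d)) for random q, k."""
    q = torch.randn(samples, d)          # components with mean 0, variance 1
    k = torch.randn(samples, d)
    dots = (q * k).sum(dim=-1)           # one dot product per sample
    return dots.var().item(), (dots / math.sqrt(d)).var().item()

for d in (4, 64, 1024):
    raw, scaled = dot_product_variance(d)
    print(d, round(raw, 2), round(scaled, 2))
```

The unscaled variance grows roughly in proportion to d, while dividing by \sqrt{d} keeps it near 1, so the softmax inputs stay in a range where the gradients are not vanishingly small.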

Scaled dot product attention.

Below is an implementation for scaled dot product attention. Each line corresponds to a box in the figure above.

import math

import torch
import torch.nn.functional as F


def attention(query, key, value, mask=None, dropout=None):
    "Compute 'Scaled Dot Product Attention'"
    # Compatibility function (dot product) between the query and keys.
    scores = torch.matmul(query, key.transpose(-2, -1))
    # Scale the scores depending on the size of the inputs.
    scores = scores / math.sqrt(query.size(-1))
    # Optional mask. Masked positions get a large negative score, so the
    # softmax assigns them (almost) zero weight.
    if mask is not None:
        scores = scores.masked_fill(mask == 0, -1e9)
    # Compute probability distribution across the final dimension.
    p_attn = F.softmax(scores, dim=-1)
    # Optional dropout on the attention weights (passed in by MultiHeadedAttention).
    if dropout is not None:
        p_attn = dropout(p_attn)
    # Output linear combinations of values, as determined by the distribution.
    return torch.matmul(p_attn, value), p_attn

Self Attention

With a single query, self attention will have no effect. This is because the attention mechanism will be a linear combination of the values, and it can only reproduce itself so it serves as an identity function.

def SelfAttention(X):
    Q, K, V = X, X, X
    return attention(Q, K, V)

>>> out, alpha = SelfAttention(torch.FloatTensor([[0.1, 0.1, 0.8]]))
>>> print(out)
tensor([[0.1000, 0.1000, 0.8000]])
>>> print(alpha)
tensor([[1.]])

When there are multiple queries, the vectors that are most *compatible* will become more similar because they are mapped to combinations consisting mostly of the already-compatible vectors. The remaining vectors will also be normalized.

>>> X = torch.FloatTensor([[0, 0, 1], [0, 0, 2], [1, 0, 0]])
>>> out, alpha = SelfAttention(X)
>>> print(alpha)
tensor([[0.2992, 0.5329, 0.1679],
        [0.2228, 0.7070, 0.0702],
        [0.2645, 0.2645, 0.4711]])

Note that, especially with values greater than 1, a vector can have a greater dot product with another vector than with itself. So, similarity is arguably not the correct word to describe this interaction (at least when using a dot product). Here, the first vector is mapped to a combination consisting mostly of itself and the second vector, with the second vector weighted more heavily. The second vector follows the same trend but more extremely. Lastly, the third vector, less compatible with the others, becomes pseudo-normalized.

Multi-Head Attention

The transformer uses "Multi-Head Attention" as its primary module for representational power. It is built up using scaled dot product attention. But, rather than attending to the raw queries a single time, this method attends to *h* linear projections of the input. For each of the *h* heads, the inputs (Q, K, V) are linearly projected with a learned mapping.

Q \in \mathbb{R}^{q \times m}, \quad K \in \mathbb{R}^{n \times m}, \quad V \in \mathbb{R}^{n \times m}
W_i^Q, W_i^K \in \mathbb{R}^{m \times d}, \quad W_i^V \in \mathbb{R}^{m \times v}, \quad W^O \in \mathbb{R}^{(h \cdot v) \times m}
\text{head}_i = \texttt{Attention}(QW_i^Q, KW_i^K, VW_i^V)
Each head attends a projection of the input vectors.
\text{out} = \texttt{Concat}(\text{head}_1, ..., \text{head}_h) W^O
The output is the concatenation of the heads, which is then projected into the correct dimensionality.

The compatibility function and the projections are linear. Does including a non-linearity affect the performance of this method? How well would the transformer perform using a feed forward layer?

import copy

import torch.nn as nn


def clones(module, N):
    "Produce N identical layers (helper from The Annotated Transformer)."
    return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])


class MultiHeadedAttention(nn.Module):
    def __init__(self, h, d_model, dropout=0.1):
        """Take in model size and number of heads."""
        super(MultiHeadedAttention, self).__init__()
        assert d_model % h == 0
        # We assume d_v always equals d_k
        self.d_k = d_model // h
        self.h = h
        self.linears = clones(nn.Linear(d_model, d_model), 4)
        self.attn = None
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, query, key, value, mask=None):
        """Outputs weighted sums of the value as determined by
        interactions between the query and the key."""
        if mask is not None:
            # Same mask applied to all h heads.
            mask = mask.unsqueeze(1)
        nbatches = query.size(0)

        # 1) Do all the linear projections in batch from d_model => h x d_k
        query, key, value = [
            l(x).view(nbatches, -1, self.h, self.d_k).transpose(1, 2)
            for l, x in zip(self.linears, (query, key, value))
        ]

        # 2) Apply attention on all the projected vectors in batch.
        #    attention() must accept an optional dropout module for this call.
        x, self.attn = attention(query, key, value, mask=mask, dropout=self.dropout)

        # 3) "Concat" using a view and apply a final linear.
        x = x.transpose(1, 2).contiguous().view(nbatches, -1, self.h * self.d_k)
        return self.linears[-1](x)

Thus, given the keys and values, the multi-headed attention is a function from \mathbb{R}^{q \times m} \to \mathbb{R}^{q \times m}. Like the scaled dot-product attention, it is able to operate on all the queries in parallel regardless of the size of the sentence. Lastly, this module is able to support *h* different heads and still output a fixed-size vector for each query, by concatenation followed by a linear mapping of the output.
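
To make the shapes concrete, here is a sketch that follows the equations head by head instead of using the batched view trick. The weights are random stand-ins for learned parameters, the sizes are made up, and d = v = m / h is assumed.

```python
import math
import torch

torch.manual_seed(0)
q, n, m, h = 2, 5, 16, 4              # hypothetical sizes
d = m // h                            # per-head size; assumes d = v

Q, K, V = torch.randn(q, m), torch.randn(n, m), torch.randn(n, m)
# Random stand-ins for the learned projections W_i^Q, W_i^K, W_i^V and W^O.
W_Q, W_K, W_V = (torch.randn(h, m, d) for _ in range(3))
W_O = torch.randn(h * d, m)

def attention(Q, K, V):
    """Scaled dot-product attention without masking or dropout."""
    weights = torch.softmax(Q @ K.T / math.sqrt(Q.size(-1)), dim=-1)
    return weights @ V

# head_i = Attention(Q W_i^Q, K W_i^K, V W_i^V)
heads = [attention(Q @ W_Q[i], K @ W_K[i], V @ W_V[i]) for i in range(h)]
# out = Concat(head_1, ..., head_h) W^O
out = torch.cat(heads, dim=-1) @ W_O

print(heads[0].shape, out.shape)
```

Each head produces a q x d matrix, the concatenation is q x (h * d), and the final projection returns one m-dimensional vector per query, independent of n.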

Multi-Headed Attention.

Other Features

Position-wise Feed-Forward Networks

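This module applies two linear transforms with a nonlinear (ReLU) operation in between. The term position-wise remarks on the fact that the same network is applied to each position separately: it is not a convolution over neighboring positions, nor does it have any directly spatial functionality.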

\text{FFN}(x) = \max(0, xW_1 + b_1)W_2 + b_2

import torch.nn as nn
import torch.nn.functional as F


class PositionwiseFeedForward(nn.Module):
    """Implements FFN equation."""

    def __init__(self, d_model=512, d_ff=2048, dropout=0.1):
        super(PositionwiseFeedForward, self).__init__()
        self.w_1 = nn.Linear(d_model, d_ff)
        self.w_2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        return self.w_2(self.dropout(F.relu(self.w_1(x))))
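
The position-wise property can be checked directly. This sketch uses a small stand-in network built with `nn.Sequential` (no dropout, made-up sizes) and compares applying it to a whole sequence with applying it to each position on its own.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
d_model, d_ff, seq_len = 8, 32, 5     # hypothetical sizes

# Stand-in for the FFN equation: Linear -> ReLU -> Linear, without dropout.
ffn = nn.Sequential(
    nn.Linear(d_model, d_ff),
    nn.ReLU(),
    nn.Linear(d_ff, d_model),
)

x = torch.randn(seq_len, d_model)
whole = ffn(x)                                                  # all positions at once
one_by_one = torch.stack([ffn(x[p]) for p in range(seq_len)])   # each position alone

print(torch.allclose(whole, one_by_one, atol=1e-5))
```

The two results agree up to floating-point error, since no information is exchanged between positions inside this layer; mixing across positions is left to the attention modules.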

The remaining features used by the network are residual layers, layer normalization, and positional encoding. The structure and features of the model all work to make short paths between inputs and outputs, while also being highly regularized. Layer normalization and residual layers are topics unto themselves.

The positional encoding is used to represent the position of the queries in their embeddings. This is important because the attention mechanisms have no notion of order among the queries, and order determines the semantics of a sentence. The authors use a sinusoidal positional encoding:

\text{PE}_{(p,2i)} = \sin(p / 10000^{2i/d_{\text{model}}}) \;
\text{PE}_{(p,2i+1)} = \cos(p / 10000^{2i/d_{\text{model}}})
where p is the position and i is the dimension.

As the authors describe:

That is, each dimension of the positional encoding corresponds to a sinusoid. The wavelengths form a geometric progression from 2\pi to 10000 \cdot 2\pi. We chose this function because we hypothesized it would allow the model to easily learn to attend by relative positions, since for any fixed offset k, PE_{pos+k} can be represented as a linear function of PE_{pos}.
Each dimension corresponds to its location; each line in the vertical slice of the graph would be added to the corresponding dimension in the word embeddings.
The authors use dropout to reduce the strength of the signal.
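
A minimal sketch of the sinusoidal encoding defined above. The function name and the example sizes are illustrative, and d_model is assumed to be even.

```python
import math
import torch

def positional_encoding(max_len, d_model):
    """Return a (max_len, d_model) tensor of sinusoidal encodings (d_model even)."""
    position = torch.arange(max_len, dtype=torch.float32).unsqueeze(1)  # p, shape (max_len, 1)
    two_i = torch.arange(0, d_model, 2, dtype=torch.float32)            # 2i for each pair
    # 1 / 10000^(2i / d_model), computed in log space.
    div_term = torch.exp(-math.log(10000.0) * two_i / d_model)
    pe = torch.zeros(max_len, d_model)
    pe[:, 0::2] = torch.sin(position * div_term)   # even dimensions
    pe[:, 1::2] = torch.cos(position * div_term)   # odd dimensions
    return pe

pe = positional_encoding(max_len=50, d_model=16)
print(pe.shape)
```

Row p of the result is the vector that would be added to the embedding of the symbol at position p.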


Each instance of the transformer will output a probability distribution for the next symbol. As you can see, the encoder and decoder stacks are repeated N times each. In the paper, the default was N = 6. The input and output of each stack are of the same dimensionality. In addition to attention modules, they use a few techniques to regularize their network: layer normalization, residual connections, and dropout.



The encoder consists of a stack of identical modules.
Transformer Network Encoder Details.

First, an input embedding for each word is retrieved. TN uses byte-pair encoding (BPE), a subword tokenization of the vocabulary, with a shared embedding matrix; this itself improves performance. Next, a positional encoding is added elementwise to the input vector. The identical encoder modules will operate on this representation.

Each module has two sublayers: Multi-Head Attention (self-attending) and a feed forward layer. This process manipulates the inputs and captures their interactions, outputting a sequence of the same dimensionality. As the authors note, keeping the dimensionality fixed makes using the residual connections more natural. (It is possible to use residual connections with varying dimensions, but it is less clean.)

The residual connections maintain a direct path to the inputs, and the normalization stabilizes the embeddings. This encoder architecture mirrors Highway Networks because additive connections allow for a clear path through the architecture, supporting many layers.
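
How a sublayer is wrapped can be sketched as follows. This follows the form given in the paper, LayerNorm(x + Sublayer(x)), with dropout applied to the sublayer output; the class name, the sizes, and the linear layer used as a stand-in sublayer are assumptions for illustration.

```python
import torch
import torch.nn as nn

class SublayerConnection(nn.Module):
    """Residual connection followed by layer normalization."""

    def __init__(self, d_model, dropout=0.1):
        super().__init__()
        self.norm = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, sublayer):
        # The additive term keeps a direct path from the input to the output.
        return self.norm(x + self.dropout(sublayer(x)))

torch.manual_seed(0)
d_model = 8                                   # hypothetical size
block = SublayerConnection(d_model).eval()    # eval() disables dropout
x = torch.randn(3, d_model)
# Any shape-preserving sublayer works; a linear layer stands in for attention here.
y = block(x, nn.Linear(d_model, d_model))
print(y.shape)
```

The sublayer must return the same dimensionality as its input for the addition to be defined, which is why every sublayer in the model keeps the width fixed.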


The decoder resembles the encoder. All symbols already generated (beginning with a start symbol) are embedded and combined with the positional encoding.

Transformer Network Decoder Details.

Next, masked self-attention is computed. A mask is applied so that each position can only see itself and previous outputs, preventing any contamination from later positions. After this, multi-headed attention is applied, where the output sequences are the queries, and the encoded symbols are the keys and the values. This maps the dimensionality of the vectors to be the same as those outputted by the encoder. The previous outputs serve to mark what information has already been generated; the linear maps intrinsic to the multi-head attention could make the compatibility *negative* in some sense, ignoring already generated content. After a feed forward layer, this process is repeated N times.
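
The masking step can be illustrated with a small sketch. The helper `causal_mask` is a hypothetical name, and the scores are random stand-ins for real query-key compatibilities.

```python
import torch

def causal_mask(size):
    """(size, size) boolean mask; entry [i, j] is True when position i may attend to j."""
    return torch.tril(torch.ones(size, size)).bool()

torch.manual_seed(0)
scores = torch.randn(4, 4)                               # stand-in self-attention scores
masked = scores.masked_fill(~causal_mask(4), -1e9)       # hide later positions
weights = torch.softmax(masked, dim=-1)

print(weights)
```

Each row still sums to 1, but all of the weight above the diagonal is removed, so position i is built only from positions 0 through i.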

After the first "execution" of the decoder, the inputs to the module are derived from the encoded symbols rather than the previous output symbols. Note that while the encoder and decoder modules are repeated, they do not share weights. They are separate instances. Thus, I expect the first and second decoder layers learn different things. Finally, after decoding the encoded inputs, a linear map is applied to the vectors and a softmax generates an output probability distribution.


The linear layer takes q inputs in \mathbb{R}^{d_{model}} and has a weight of shape \mathbb{R}^{d_{model} \times vocab}, outputting \mathbb{R}^{q \times vocab}. During training, the decoding is set up so that all subsequent positions are masked out during attention, so that a symbol can never see "into the future". So, the final linear layer (followed by the softmax) will output a probability distribution for each query (each symbol generated so far), starting with the start symbol. When decoding, the next symbol is always read from the right-most position.
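
The shapes in this final step can be sketched as below. The sizes are made up and the decoder output is a random stand-in.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
d_model, vocab, q = 16, 100, 3        # hypothetical sizes

generator = nn.Linear(d_model, vocab)             # the final linear map

decoder_out = torch.randn(q, d_model)             # stand-in: one vector per generated position
probs = torch.softmax(generator(decoder_out), dim=-1)   # shape (q, vocab)
next_symbol = probs[-1].argmax().item()           # read off the right-most position

print(probs.shape, next_symbol)
```

Only the last row is needed at decoding time; during training, every row contributes a prediction for the symbol that follows it.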

Simplified Transformer Network Architecture.

When decoding an output sequence, the network is run repeatedly. During the first run, q = 1. For each run afterwards, q increases by one, and the right-most position is selected as the generated symbol.

A greedy approach looks something like this:

import torch
from torch.autograd import Variable


def greedy_decode(model, src, src_mask, max_len, start_symbol):
    # Reconstructed following greedy_decode in The Annotated Transformer;
    # subsequent_mask is that source's helper for masking future positions.
    memory = model.encode(src, src_mask)
    ys = torch.ones(1, 1).fill_(start_symbol).type_as(src.data)
    # generate a word up to the max length.
    # the system could represent stop symbols to stop early.
    for i in range(max_len - 1):
        out = model.decode(
            memory,
            src_mask,
            Variable(ys),
            Variable(subsequent_mask(ys.size(1)).type_as(src.data)),
        )
        # select the final outputs' result.
        prob = model.generator(out[:, -1])
        _, next_word = torch.max(prob, dim=1)
        next_word = next_word.data[0]
        # concat the most likely word to the result.
        ys = torch.cat([ys, torch.ones(1, 1).type_as(src.data).fill_(next_word)], dim=1)
    return ys

Using beam search (as the authors did), a path is selected by maintaining k beams, i.e. the best-so-far k options.
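
A minimal beam search sketch in plain Python is shown below. It is not the authors' implementation: the `next_log_probs` callback, the toy probability table, and the lack of any length penalty are assumptions made for illustration.

```python
import math

def beam_search(next_log_probs, start, end, k, max_len):
    """Keep the k highest-scoring partial sequences at every step.

    next_log_probs(seq) is a hypothetical callback returning
    {symbol: log_prob} for the symbol that follows seq.
    """
    beams = [([start], 0.0)]                      # (sequence, total log-probability)
    for _ in range(max_len - 1):
        candidates = []
        for seq, score in beams:
            if seq[-1] == end:                    # finished beams are carried over
                candidates.append((seq, score))
                continue
            for sym, lp in next_log_probs(seq).items():
                candidates.append((seq + [sym], score + lp))
        beams = sorted(candidates, key=lambda c: c[1], reverse=True)[:k]
        if all(seq[-1] == end for seq, _ in beams):
            break
    return beams[0][0]

# Toy model: the most likely first symbol "a" leads to weaker continuations.
table = {
    ("<s>",): {"a": math.log(0.6), "b": math.log(0.4)},
    ("<s>", "a"): {"</s>": math.log(0.55), "c": math.log(0.45)},
    ("<s>", "a", "c"): {"</s>": math.log(1.0)},
    ("<s>", "b"): {"</s>": math.log(0.9), "c": math.log(0.1)},
    ("<s>", "b", "c"): {"</s>": math.log(1.0)},
}

def toy_model(seq):
    return table[tuple(seq)]

print(beam_search(toy_model, "<s>", "</s>", k=1, max_len=5))
print(beam_search(toy_model, "<s>", "</s>", k=2, max_len=5))
```

With k = 1 this reduces to greedy decoding and commits to "a" (total probability 0.33), while k = 2 keeps "b" alive long enough to find the more probable sequence through it (0.36).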