Transformer Networks

Attention is all you need.

Vaswani, Shazeer, Parmar, Uszkoreit, Jones, Gomez, Kaiser, Polosukhin

2019

Since the release of BERT and GPT-2, I have created a set of slides that are clearer than these notes and have better examples. See them here.

Since 2020, almost all models are decoder-only (often termed "GPT-2 style"). The exceptions are encoder-only models like BERT, which are still used for, well, encoding. If you're interested in encoding, I'd read Matryoshka Representation Learning.

Introduction

Attention Is All You Need (Vaswani, A. et al., 2017, NeurIPS) introduces the Transformer Network. This network is a shift away from recurrent networks; economy inspires design. It does not use stateful or recurrent functions; instead, it is parallelized across all symbols in an input sequence. However, it is difficult at first to see how this works with sequences of different lengths. The primary goal of this post is to demonstrate how the Transformer Network fits together.

Overview

This work focuses on the task of natural language translation (e.g. translating English to German or vice versa). This notebook focuses on the unique modules the authors present, and how the system fits together. The Transformer Network (TN) is composed of attention modules, linear mappings, and regularization features, and it uses an Encoder-Decoder structure. Since this publication, Transformer Network encoders have been used to great success in a wide range of applications (Devlin, J. et al., 2018, BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding).

I modify and present implementations from the tensor2tensor library and The Annotated Transformer. For a complete view and implementation of this system, please visit these sources. Diagrams are recreations, and all blocked quotes are from the original paper.

Encoder-Decoder Structure

The transformer uses an encoder-decoder structure (Bahdanau, D. et al., 2014, Neural Machine Translation by Jointly Learning to Align and Translate): an input sequence of symbols, $x = \{x_1, x_2, \ldots, x_n\}$, is encoded into a sequence of continuous variables, $z = \{z_1, z_2, \ldots, z_n\}$. This is then decoded into a sequence of symbols, $y = \{y_1, y_2, \ldots, y_m\}$, whose length need not match the input's. In some cases, z is a single continuous variable rather than a sequence. This generation of symbols occurs one at a time: it is auto-regressive, consuming previously generated symbols as additional input when generating the next. Encoder-decoders often use recurrent architectures.

According to Cho et al., the encoding function e can be any non-linear function, but it is often implemented as an RNN.

$$h_t = e(h_{t-1}, x_t)$$

Encoding the input into a hidden state.

The input sentence x is encoded into the vector z. Depending on the implementation, we consider the final hidden states as the encoding, or some operation on all the hidden states.

$$z = \sum_t h_t$$

Summarize the hidden states. Attention could be used here.

Next, we decode z into the output predictions y. This typically uses a recurrent function (RNN).

$$h_t = d(h_{t-1}, y_{t-1}, z)$$

Process the outputs, given the previous generated symbol along with the summarized vector z.

$$y_t = g(h_t, y_{t-1}, z)$$

Decode into output symbols.
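
The four equations above can be sketched in a few lines of PyTorch. This is only an illustration of the recurrent encoder-decoder pattern, not code from the paper: the sizes, the choice of `nn.RNNCell` for e and d, and the linear layer for g are all assumptions made for the example, and the weights are untrained.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical sizes chosen only for this illustration.
vocab_size, d_hidden = 10, 8

embed = nn.Embedding(vocab_size, d_hidden)
e = nn.RNNCell(d_hidden, d_hidden)        # h_t = e(h_{t-1}, x_t)
d = nn.RNNCell(2 * d_hidden, d_hidden)    # h_t = d(h_{t-1}, y_{t-1}, z)
g = nn.Linear(3 * d_hidden, vocab_size)   # y_t = g(h_t, y_{t-1}, z)


def encode(x):
    """Run the encoder one symbol at a time and summarize with z = sum_t h_t."""
    h = torch.zeros(1, d_hidden)
    states = []
    for t in range(x.size(0)):            # inherently sequential
        h = e(embed(x[t:t + 1]), h)
        states.append(h)
    return torch.stack(states).sum(dim=0)  # shape (1, d_hidden)


def decode(z, start_symbol, max_len):
    """Greedily emit symbols, feeding each one back in as the next input."""
    h = torch.zeros(1, d_hidden)
    y_prev = torch.tensor([start_symbol])
    ys = []
    for _ in range(max_len):
        y_emb = embed(y_prev)
        h = d(torch.cat([y_emb, z], dim=-1), h)
        logits = g(torch.cat([h, y_emb, z], dim=-1))
        y_prev = logits.argmax(dim=-1)
        ys.append(y_prev.item())
    return ys


with torch.no_grad():
    z = encode(torch.tensor([1, 4, 2, 7]))
    ys = decode(z, start_symbol=0, max_len=4)
print(z.shape, ys)
```

Both loops depend on the hidden state from the previous step, which is why this style of model cannot process all positions of a sequence at once.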

This recurrent formulation has a few drawbacks:

  1. It is sequential and cannot be easily parallelized.
  2. Often z is input into each step of the decoding function. Because z is up to O(n) recurrent steps away from each input symbol, it becomes difficult to learn long-range dependencies.
  3. The path between an output symbol and its corresponding source symbol depends on the length of x.

TN's stateless auto-regressive strategy decodes encoded (but not summarized) source words and the current output words, outputting probability distributions for new symbols. This allows the model to be parallelized.

Scaled Dot-Product Attention

The authors describe attention as follows:

An attention function can be described as mapping a query and a set of key-value pairs to an output, where the query, keys, values, and output are all vectors. The output is computed as a weighted sum of the values, where the weight assigned to each value is computed by a compatibility function of the query with the corresponding key.

As noted by the authors, attention maps a query to a weighted combination of the given values, as determined by the query's compatibility with the corresponding keys. As the autological name "Scaled Dot-Product Attention" implies, the authors use the dot product as their compatibility function. One could use any similarity measure, learned or otherwise, for example cosine similarity or a feedforward neural network layer.

For their formulation of attention to work, there are a few requirements for the inputs. There must be a mapping between the keys and values, and the compatibility function must be valid for the queries and the keys. In the paper, there is a 1:1 mapping between the keys and values (by index), and the dot-product compatibility function requires that the queries and the keys have the same dimensionality.

Figure: Attention intuition.
  1. Each key Ki maps to a value Vi.
  2. Each query Qj will operate on all the keys with a compatibility function (dot product). The larger the dot product (roughly, the more aligned the two vectors are in high-dimensional space), the more compatible they are. These scores will be transformed into a probability distribution by a softmax.
  3. Then, each query will be mapped to a linear combination of the values as determined by the probability distribution.

As shown in the example above, the query q1 is most similar to k1, thus it is predominantly mapped to the corresponding value v1.
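
To make the three steps concrete, here is a small worked example. The specific vectors are made up for illustration (they are not the ones in the figure): the single query is aligned with the first key, so most of the weight should land on the first value.

```python
import math

import torch

# Made-up keys, values, and query for illustration only.
K = torch.tensor([[1.0, 0.0], [0.0, 1.0]])    # step 1: key K_i maps to value V_i
V = torch.tensor([[10.0, 0.0], [0.0, 10.0]])
Q = torch.tensor([[1.0, 0.0]])                # q_1 points in the same direction as k_1

# Step 2: compatibility of the query with every key, then softmax.
scores = Q @ K.T / math.sqrt(Q.size(-1))
weights = torch.softmax(scores, dim=-1)

# Step 3: the query is mapped to a linear combination of the values.
output = weights @ V

print(weights)
print(output)
```

The weights sum to one, and the output sits closer to the first value than to the second.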

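$$A : Q \times K \times V \to O$$

$$A(Q, K, V) = \mathrm{softmax}\left(\frac{QK^\top}{\sqrt{d}}\right) V$$

$$Q \in \mathbb{R}^{q \times d}, \quad K \in \mathbb{R}^{n \times d}, \quad V \in \mathbb{R}^{n \times v}, \quad O \in \mathbb{R}^{q \times v}$$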

The authors note that the variance of a dot product grows with the dimensionality of the input vectors. Increased variance results in scores of larger magnitude, "pushing the softmax function into regions where it has extremely small gradients." This motivates scaling the dot product by the square root of the dimensionality of the input vectors.
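
A quick numerical check of this argument, assuming (as the paper does in its footnote) that the components of the queries and keys are independent with mean 0 and variance 1. The helper name and sample count are choices made for this sketch.

```python
import torch

torch.manual_seed(0)


def dot_product_variance(d, n_samples=20000):
    """Empirical variance of q.k before and after scaling by sqrt(d)."""
    q = torch.randn(n_samples, d)
    k = torch.randn(n_samples, d)
    dots = (q * k).sum(dim=-1)
    return dots.var().item(), (dots / d ** 0.5).var().item()


for d in (4, 64, 256):
    raw, scaled = dot_product_variance(d)
    print(f"d={d:4d}  var(q.k)={raw:8.2f}  var(q.k / sqrt(d))={scaled:.2f}")
```

The unscaled variance should come out close to d, while the scaled variance should stay close to 1 regardless of d.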

Figure: Scaled dot product attention.

Below is an implementation for scaled dot product attention. Each line corresponds to a box in the figure above.

import math

import torch
import torch.nn.functional as F


def attention(query, key, value, mask=None, dropout=None):
    "Compute 'Scaled Dot Product Attention'"
    # Compatibility function (dot product) between the query and keys.
    scores = torch.matmul(query, key.transpose(-2, -1))
    # Scale the scores depending on the size of the inputs.
    scores = scores / math.sqrt(query.size(-1))
    # Optional mask: positions where mask == 0 receive (almost) zero weight.
    if mask is not None:
        scores = scores.masked_fill(mask == 0, -1e9)
    # Compute probability distribution across the final dimension.
    p_attn = F.softmax(scores, dim=-1)
    # Optional dropout on the attention weights (the multi-head module passes one in).
    if dropout is not None:
        p_attn = dropout(p_attn)
    # Output linear combinations of values.
    return torch.matmul(p_attn, value), p_attn

Self Attention

With a single query, self attention has no effect. The output is a weighted combination of the values, and with only one key-value pair the softmax assigns it a weight of 1, so the mechanism serves as an identity function.

def SelfAttention(X):
    Q, K, V = X, X, X
    return attention(Q, K, V)

>>> out, alpha = SelfAttention(torch.FloatTensor([[0.1,0.1,0.8]]))
>>> print(out)
tensor([[0.1000, 0.1000, 0.8000]])
>>> print(alpha)
tensor([[1.]])

When there are multiple queries, the vectors that are most compatible will become more similar because they are mapped to combinations consisting mostly of the already-compatible vectors.

>>> X = torch.FloatTensor(
    [
        [0,0,1],
        [0,0,2],
        [1,0,0]
    ]
)
>>> out, alpha = SelfAttention(X)
>>> print(alpha)
tensor(
    [
        [0.2992, 0.5329, 0.1679],
        [0.2228, 0.7070, 0.0702],
        [0.2645, 0.2645, 0.4711]
    ]
)

Note that, especially with values greater than 1, a vector can have a greater dot product with another vector than with itself. Thus, the first vector is mapped to a combination consisting mostly of the second vector (weight 0.53) rather than itself (weight 0.30), and the second vector follows the same trend but more extreme, placing 0.71 of its weight on itself. Lastly, the third vector, less compatible with the others, puts the most weight on itself but is still mixed with the other two, so it becomes pseudo-normalized.
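
The example above only prints the attention weights. The short sketch below recomputes the same weights without relying on the helper functions defined earlier and also prints the resulting output vectors, which makes the mixing visible.

```python
import math

import torch

X = torch.tensor([
    [0.0, 0.0, 1.0],
    [0.0, 0.0, 2.0],
    [1.0, 0.0, 0.0],
])

# Self attention: queries, keys, and values are all X.
alpha = torch.softmax(X @ X.T / math.sqrt(X.size(-1)), dim=-1)
out = alpha @ X

print(alpha)
print(out)
```

Each output row is a convex combination of the rows of X, so the first two vectors pick up a small first component from the third vector, and the third vector picks up a sizeable third component from the first two.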

Multi-Head Attention

The transformer uses "Multi-Head Attention" as its primary module for representational power. It is built up using scaled dot product attention. But, rather than attending over the raw inputs a single time, this method attends over h learned linear projections of them. For each of the h heads, the inputs (Q, K, V) are linearly projected with a learned mapping.

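$$Q \in \mathbb{R}^{q \times m}, \quad K \in \mathbb{R}^{n \times m}, \quad V \in \mathbb{R}^{n \times m}$$

$$W_i^Q, W_i^K, W_i^V \in \mathbb{R}^{m \times d}, \quad W^O \in \mathbb{R}^{(h \cdot v) \times m}$$

$$\mathrm{head}_i = \mathrm{Attention}(QW_i^Q, KW_i^K, VW_i^V)$$

$$\mathrm{out} = \mathrm{Concat}(\mathrm{head}_1, \ldots, \mathrm{head}_h) W^O$$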

The compatibility function and the projections are linear. Does including a non-linearity affect the performance of this method? How well would the transformer perform using a feed forward layer?

import copy

import torch.nn as nn


def clones(module, N):
    "Produce N identical (deep-copied) layers."
    return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])


class MultiHeadedAttention(nn.Module):
    def __init__(self, h, d_model, dropout=0.1):
        """Take in model size and number of heads."""
        super().__init__()
        assert d_model % h == 0
        # Each head works in a d_model / h dimensional subspace.
        self.d_k = d_model // h
        self.h = h
        # Three input projections (query, key, value) plus the output projection.
        self.linears = clones(nn.Linear(d_model, d_model), 4)
        self.attn = None
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, query, key, value, mask=None):
        if mask is not None:
            # The same mask is applied to all h heads.
            mask = mask.unsqueeze(1)
        nbatches = query.size(0)

        # 1) Do all the linear projections in batch
        query, key, value = \
            [l(x).view(nbatches, -1, self.h, self.d_k).transpose(1, 2)
            for l, x in zip(self.linears, (query, key, value))]

        # 2) Apply attention on all the projected vectors in batch
        #    (assumes attention() accepts an optional dropout module,
        #    as in The Annotated Transformer)
        x, self.attn = attention(query, key, value, mask=mask,
                                dropout=self.dropout)

        # 3) "Concat" using a view and apply a final linear
        x = x.transpose(1, 2).contiguous() \
            .view(nbatches, -1, self.h * self.d_k)
        return self.linears[-1](x)

Thus, with the shapes defined above, multi-headed attention maps the queries in $\mathbb{R}^{q \times m}$ to an output in $\mathbb{R}^{q \times m}$. Like scaled dot-product attention, it operates on all the queries in parallel regardless of the length of the sentence. Lastly, this module supports h different heads and still outputs a fixed-size vector for each query, by concatenating the heads and applying a linear mapping to the result.
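
The equations can also be followed head by head. The sketch below is a deliberately naive, unbatched version with random (untrained) weights, written only to check the shapes; the sizes, the weight scaling, and the function name are assumptions of this example, and each head's output size v is taken to equal d.

```python
import math

import torch

torch.manual_seed(0)

# Hypothetical sizes: h heads, model width m, per-head width d.
h, m = 4, 16
d = m // h

# One projection per head for Q, K, and V, plus the output projection.
W_Q = torch.randn(h, m, d) / math.sqrt(m)
W_K = torch.randn(h, m, d) / math.sqrt(m)
W_V = torch.randn(h, m, d) / math.sqrt(m)
W_O = torch.randn(h * d, m) / math.sqrt(h * d)


def multi_head_attention(Q, K, V):
    """head_i = Attention(Q W_i^Q, K W_i^K, V W_i^V); out = Concat(heads) W^O."""
    heads = []
    for i in range(h):
        q, k, v = Q @ W_Q[i], K @ W_K[i], V @ W_V[i]
        weights = torch.softmax(q @ k.T / math.sqrt(d), dim=-1)
        heads.append(weights @ v)                 # (q_len, d)
    return torch.cat(heads, dim=-1) @ W_O         # (q_len, m)


# The same weights handle any number of queries and any number of keys.
for q_len, n in [(1, 5), (7, 3)]:
    out = multi_head_attention(
        torch.randn(q_len, m), torch.randn(n, m), torch.randn(n, m))
    print(q_len, n, tuple(out.shape))
```

The loop over heads is what the batched implementation above replaces with a single reshape and transpose.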

Figure: Multi-Headed Attention.

Other Features

Position-wise Feed-Forward Networks

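This is two linear transforms with a nonlinear (ReLU) operation in between. "Position-wise" means that the same network is applied to each position separately and identically: it does not mix information across positions and has no directly spatial functionality. (The paper notes this is equivalent to two convolutions with kernel size 1.)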

$$\mathrm{FFN}(x) = \max(0, xW_1 + b_1)W_2 + b_2$$

class PositionwiseFeedForward(nn.Module):
    """Implements FFN equation."""
    def __init__(self, d_model=512, d_ff=2048, dropout=0.1):
        super(PositionwiseFeedForward, self).__init__()
        self.w_1 = nn.Linear(d_model, d_ff)
        self.w_2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        return self.w_2(self.dropout(F.relu(self.w_1(x))))
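
One way to see what "position-wise" means is to check that the network treats every position independently. The sketch below uses a small stand-in network built with `nn.Sequential` (without dropout, and with made-up sizes) rather than the class above, so that it runs on its own.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Stand-in for the feed-forward block: Linear -> ReLU -> Linear, no dropout.
ffn = nn.Sequential(nn.Linear(8, 32), nn.ReLU(), nn.Linear(32, 8))

x = torch.randn(5, 8)                     # 5 positions, 8 features each
perm = torch.tensor([4, 2, 0, 3, 1])      # an arbitrary reordering of positions

with torch.no_grad():
    out = ffn(x)
    out_perm = ffn(x[perm])
    one_by_one = torch.stack([ffn(row) for row in x])

# Reordering the positions just reorders the outputs.
print(torch.allclose(out[perm], out_perm, atol=1e-6))
# Applying the network to each position separately gives the same result.
print(torch.allclose(out, one_by_one, atol=1e-6))
```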

The remaining features used by the network are residual connections, layer normalization, and positional encoding. The structure and features of the model all work to make short paths between inputs and outputs, while also being highly regularized.

The positional encoding is used to represent the position of the queries in their embeddings. This is important because the attention mechanisms have no notion of order among the queries, and order determines the semantics of a sentence. The authors use sinusoidal positional encodings:

$$PE_{(p, 2i)} = \sin\left(p / 10000^{2i/d_{\text{model}}}\right)$$

$$PE_{(p, 2i+1)} = \cos\left(p / 10000^{2i/d_{\text{model}}}\right)$$

where p is the position and i is the dimension.

The authors explain their choice:

"That is, each dimension of the positional encoding corresponds to a sinusoid. The wavelengths form a geometric progression from 2π to 10000 · 2π. We chose this function because we hypothesized it would allow the model to easily learn to attend by relative positions, since for any fixed offset k, $PE_{pos+k}$ can be represented as a linear function of $PE_{pos}$."

Figure: Positional encoding. Each dimension corresponds to its location; each line in the vertical slice of the graph would be added to the corresponding dimension in the word embeddings.

Figure: Positional encoding with dropout. The authors use dropout to reduce the strength of the signal.
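
A minimal sketch of the sinusoidal encoding, assuming an even d_model. The function names are made up for this example. The second helper illustrates the "linear function" claim from the quote: shifting a position by k amounts to rotating each (sin, cos) pair by a fixed angle that depends only on k.

```python
import math

import torch


def positional_encoding(max_len, d_model):
    """Return (encodings of shape (max_len, d_model), per-pair frequencies)."""
    position = torch.arange(max_len, dtype=torch.float32).unsqueeze(1)
    # 1 / 10000^(2i / d_model) for each (sin, cos) pair i.
    inv_freq = torch.exp(
        -math.log(10000.0)
        * torch.arange(0, d_model, 2, dtype=torch.float32) / d_model)
    pe = torch.zeros(max_len, d_model)
    pe[:, 0::2] = torch.sin(position * inv_freq)   # even dimensions
    pe[:, 1::2] = torch.cos(position * inv_freq)   # odd dimensions
    return pe, inv_freq


def shift(pe_row, k, inv_freq):
    """Compute PE_{pos+k} from PE_{pos} using only a linear map (a rotation)."""
    s, c = pe_row[0::2], pe_row[1::2]
    shifted = torch.empty_like(pe_row)
    shifted[0::2] = s * torch.cos(k * inv_freq) + c * torch.sin(k * inv_freq)
    shifted[1::2] = c * torch.cos(k * inv_freq) - s * torch.sin(k * inv_freq)
    return shifted


pe, inv_freq = positional_encoding(max_len=50, d_model=16)
print(pe.shape)
print(torch.allclose(shift(pe[10], 5, inv_freq), pe[15], atol=1e-4))
```

In the model, rows of this matrix are added to the word embeddings at the matching positions.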

Architecture

Each pass through the transformer outputs a probability distribution for the next symbol. The encoder and decoder stacks are repeated N times each. In the paper the default was N = 6. The input and output of each stack are of the same dimensionality. In addition to attention modules, they use a few techniques to regularize their network: layer normalization, residual connections, and dropout.

Figure: Transformer architecture.

Encoder

The encoder consists of a stack of identical modules.

Figure: Transformer Network Encoder Details.

First, an input embedding for each word is retrieved. TN uses byte-pair encoding (BPE), a subword tokenization of the vocabulary, with a shared embedding matrix; this itself improves performance. Next, a positional encoding is added elementwise to each input vector. The identical encoder modules will operate on this representation.

Each module has two sublayers: Multi-Head Attention (self-attending) and a feed forward layer. This process manipulates the inputs and captures their interactions, outputting a sequence of the same dimensionality.

The residual connections maintain a direct path to the inputs, and the normalization stabilizes the embeddings. This encoder architecture mirrors Highway Networks (Srivastava, R. K. et al., 2015, Highway Networks) because additive connections allow for a clear path through the architecture, supporting many layers.
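
The wrapping of each sublayer can be sketched as below. It follows the paper's description, LayerNorm(x + Sublayer(x)) with dropout applied to the sublayer output before the addition; the class name and the stand-in linear sublayer are assumptions of this example, not the implementation used elsewhere in this post.

```python
import torch
import torch.nn as nn


class ResidualSublayer(nn.Module):
    """Wrap a sublayer as LayerNorm(x + Dropout(Sublayer(x)))."""

    def __init__(self, d_model, sublayer, dropout=0.1):
        super().__init__()
        self.sublayer = sublayer
        self.norm = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        # The "x +" term is the direct additive path back to the input.
        return self.norm(x + self.dropout(self.sublayer(x)))


torch.manual_seed(0)
# A plain linear layer stands in for attention or the feed-forward network.
block = ResidualSublayer(d_model=8, sublayer=nn.Linear(8, 8))
block.eval()                      # disable dropout for a deterministic check

x = torch.randn(5, 8)
with torch.no_grad():
    out = block(x)
print(out.shape)                  # same shape as the input
```

Because the input and output shapes match, blocks like this can be stacked N times without any extra projection.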

Decoder

The decoder resembles the encoder. All symbols already generated (beginning with a start symbol) are embedded and combined with the positional encoding.

Figure: Transformer Network Decoder Details.

Next, masked self-attention is computed. A mask is applied so that each position can attend only to itself and to earlier outputs, preventing any contamination from later positions. After this, multi-headed attention is applied, where the vectors from the output sequence are the queries, and the encoded symbols are the keys and the values. The resulting vectors have the same dimensionality as those output by the encoder.

After the first "execution" of the decoder module, the inputs to each subsequent module are the previous module's outputs, which already incorporate the encoded symbols, rather than the embedded output symbols themselves. Note that while the encoder and decoder modules are repeated, they do not share weights. They are separate instances. Finally, after decoding the encoded inputs, a linear map is applied to the vectors and a softmax generates an output probability distribution.

Decoding

The linear layer takes q input vectors in $\mathbb{R}^{d_{\text{model}}}$ (a matrix in $\mathbb{R}^{q \times d_{\text{model}}}$) and has a weight matrix in $\mathbb{R}^{d_{\text{model}} \times \text{vocab}}$, outputting a matrix in $\mathbb{R}^{q \times \text{vocab}}$. During training, all subsequent positions are masked out during attention, so that a symbol can never see "into the future".
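
A small sketch of the "no looking into the future" mask. The helper names are made up for this example, and for simplicity the queries, keys, and values are all the same random matrix.

```python
import math

import torch


def causal_mask(size):
    """Boolean (size, size) mask; entry [i, j] is True if position i may see j."""
    return torch.tril(torch.ones(size, size)).bool()


def masked_self_attention_weights(x):
    """Attention weights where each position sees only itself and earlier ones."""
    scores = x @ x.T / math.sqrt(x.size(-1))
    # Future positions get -inf, so the softmax gives them exactly zero weight.
    scores = scores.masked_fill(~causal_mask(x.size(0)), float("-inf"))
    return torch.softmax(scores, dim=-1)


torch.manual_seed(0)
weights = masked_self_attention_weights(torch.randn(4, 8))
print(causal_mask(4))
print(weights)
```

Everything above the diagonal of the weight matrix is zero, and the first position can only attend to itself.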

Figure: Simplified Transformer Network Architecture.

When decoding an output sequence, the network is run repeatedly. During the first run q = 1. For each run afterwards, q increases by one, and the prediction at the right-most position is selected as the generated symbol.

A greedy approach looks something like this:

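def greedy_decode(model, src, src_mask, max_len, start_symbol):
    # Encode the source once; the result is reused at every decoding step.
    memory = model.encode(src, src_mask)
    # Begin the output sequence with the start symbol (batch size 1).
    ys = torch.ones(1, 1).fill_(start_symbol).type_as(src.data)
    for i in range(max_len - 1):
        # subsequent_mask (from The Annotated Transformer) hides future positions.
        out = model.decode(
            memory,
            src_mask,
            ys,
            subsequent_mask(ys.size(1)).type_as(src.data))
        # Only the right-most position predicts the next symbol.
        prob = model.generator(out[:, -1])
        _, next_word = torch.max(prob, dim=1)
        next_word = next_word.item()
        # Append the chosen symbol and run the decoder again.
        ys = torch.cat([ys,
            torch.ones(1, 1).type_as(src.data).fill_(next_word)], dim=1)
    return ys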

Using beam search (as the authors did), a path is selected by maintaining k beams, i.e. the k best partial sequences found so far.
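
Below is a minimal beam search sketch in plain Python. It is not the authors' implementation (it omits, for example, any length normalization), and the toy scoring table, the symbol ids, and the `step_log_probs` callback are all hypothetical stand-ins for a trained model. The toy table is built so that the greedy first choice leads to a worse overall sequence.

```python
import math

# Hypothetical symbol ids for the toy example.
START, A, B, END = 0, 1, 2, 3


def toy_log_probs(seq):
    """Stand-in for a model: log-probabilities of the next symbol given seq."""
    table = {
        (START,): {A: 0.6, B: 0.4},
        (START, A): {A: 0.34, B: 0.33, END: 0.33},
        (START, B): {A: 0.05, B: 0.05, END: 0.9},
    }
    probs = table.get(tuple(seq), {END: 1.0})
    return {symbol: math.log(p) for symbol, p in probs.items()}


def beam_search(step_log_probs, start_symbol, end_symbol, beam_size, max_len):
    """Keep the beam_size highest-scoring partial sequences at every step."""
    beams = [([start_symbol], 0.0)]
    for _ in range(max_len - 1):
        candidates = []
        for seq, score in beams:
            if seq[-1] == end_symbol:
                candidates.append((seq, score))     # finished: carry over as-is
                continue
            for symbol, logp in step_log_probs(seq).items():
                candidates.append((seq + [symbol], score + logp))
        candidates.sort(key=lambda c: c[1], reverse=True)
        beams = candidates[:beam_size]
        if all(seq[-1] == end_symbol for seq, _ in beams):
            break
    return beams


greedy = beam_search(toy_log_probs, START, END, beam_size=1, max_len=5)
beam = beam_search(toy_log_probs, START, END, beam_size=2, max_len=5)
print("beam_size=1:", greedy[0][0], math.exp(greedy[0][1]))
print("beam_size=2:", beam[0][0], math.exp(beam[0][1]))
```

With a beam size of 1 this reduces to greedy decoding, which commits to the locally best first symbol; with two beams the search keeps the runner-up alive long enough to find a sequence with higher total probability.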

Charles Lovering © 2026