Subword Tokenization

Neural machine translation of rare words with subword units.

Rico Sennrich, Barry Haddow, Alexandra Birch

2019

Byte pair encoding (BPE) builds a subword tokenization of a vocabulary. The following example, from Sennrich, Haddow, and Birch (2016), "Neural Machine Translation of Rare Words with Subword Units" (ACL), shows how this works and why it is useful at test time:

V = {  "low", "lowest", "newer", "wider" }
BPE("lower")
> "low", "er"

This is far more useful than mapping the word to a single unknown-word symbol (UNK). Given the two subword segments "low" and "er", a model has a reasonable chance of interpreting "lower" correctly.
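As a minimal sketch of this test-time step, a list of learned merges can be replayed, in order, on an unseen word. The segment helper and the three-entry merges list below are hypothetical, chosen only to reproduce the "lower" example, and the end-of-word marker is left out for brevity.

```python
def segment(word, merges):
    """Split a word into subword symbols by replaying merges in order.

    `merges` is an ordered list of (left, right) symbol pairs.
    This helper is an illustrative assumption, not the authors' code.
    """
    symbols = list(word)  # start from single characters
    for left, right in merges:
        merged = []
        i = 0
        while i < len(symbols):
            # Join the pair wherever the two symbols sit side by side.
            if i + 1 < len(symbols) and symbols[i] == left and symbols[i + 1] == right:
                merged.append(left + right)
                i += 2
            else:
                merged.append(symbols[i])
                i += 1
        symbols = merged
    return symbols


# Hypothetical merge list, for illustration only.
merges = [("e", "r"), ("l", "o"), ("lo", "w")]
print(segment("lower", merges))  # ['low', 'er']
```

A word that shares no merges with the list simply falls back to its characters, so nothing ever has to be mapped to UNK.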

To build this representation, an iterative algorithm merges the most frequent adjacent pair of symbols at each step, starting with character pairs. Below is the pseudocode provided by the original authors, with a few changes. (The authors also provide an optimized implementation that batches merges and avoids recomputation.) The algorithm works bottom up: at each step it counts every adjacent pair of symbols. The symbols start out as single characters, but in later steps they are longer, frequently occurring strings.

import re
import collections

def get_stats(vocab):
    """For the current merge step, count every adjacent
    pair of symbols (strings).

    Each key of vocab is a word written as space-separated
    symbols; each value is that word's corpus frequency.
    A pair's count is the summed frequency of the words it
    occurs in, counted once per occurrence.
    """
    pairs = collections.defaultdict(int)
    for word, freq in vocab.items():
        symbols = word.split()
        for i in range(len(symbols) - 1):
            pairs[symbols[i], symbols[i + 1]] += freq
    return pairs

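def merge_vocab(pair, v_in):
    """Replace every occurrence of the symbol pair with one merged symbol."""
    v_out = {}
    bigram = re.escape(' '.join(pair))
    # The lookarounds force the match to start and end on symbol
    # boundaries, so ('e', 'r') cannot match inside 'h e re' or 'ae r'.
    p = re.compile(r'(?<!\S)' + bigram + r'(?!\S)')
    for word in v_in:
        w_out = p.sub(''.join(pair), word)
        v_out[w_out] = v_in[word]
    return v_out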

Next, given the bigram counts for a vocabulary, merge the most frequent bigram into a single symbol everywhere it occurs, and repeat for a fixed number of merges.

def byte_pair_encoding(vocab, num_merges=5):
    """For the given number of merges, merge the most
    common pair of symbols and print the resulting vocab.
    """
    for _ in range(num_merges):
        pairs = get_stats(vocab)
        if not pairs:
            # Every word is already a single symbol.
            break
        best = max(pairs, key=pairs.get)
        if pairs[best] < 2:
            print('No pair has frequency > 1. Stopping.')
            break
        vocab = merge_vocab(best, vocab)
        # Print inside the loop so each merge step is visible.
        print(vocab)

>>> vocab = {
    'l o w </w>' : 5,
    'f a r t h e s t </w>' : 5,
    'n e w e r </w>': 5,
    'w i d e r </w>': 5 }

>>> byte_pair_encoding(vocab)
{'l o w </w>': 5, 'f a r t h e s t </w>': 5, 'n e w er </w>': 5, 'w i d er </w>': 5}
{'l o w </w>': 5, 'f a r t h e s t </w>': 5, 'n e w er</w>': 5, 'w i d er</w>': 5}
{'lo w </w>': 5, 'f a r t h e s t </w>': 5, 'n e w er</w>': 5, 'w i d er</w>': 5}
{'low </w>': 5, 'f a r t h e s t </w>': 5, 'n e w er</w>': 5, 'w i d er</w>': 5}
{'low</w>': 5, 'f a r t h e s t </w>': 5, 'n e w er</w>': 5, 'w i d er</w>': 5}
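
The walkthrough above only prints the vocab after each merge. To segment new words later, the ordered list of merges is what needs to be kept. The sketch below is one way to collect it; learn_merges is a hypothetical helper, not part of the authors' code, and it stores each word as a tuple of symbols rather than a spaced string, which is an assumption made here to avoid the regular expression.

```python
import collections


def learn_merges(vocab, num_merges):
    """Return the ordered list of merged symbol pairs.

    `vocab` maps space-separated symbol strings to frequencies,
    in the same format as the example above. Illustrative only.
    """
    # Assumption: keep each word as a tuple of symbols.
    words = {tuple(word.split()): freq for word, freq in vocab.items()}
    merges = []
    for _ in range(num_merges):
        # Count every adjacent pair, weighted by word frequency.
        pairs = collections.Counter()
        for symbols, freq in words.items():
            for left, right in zip(symbols, symbols[1:]):
                pairs[left, right] += freq
        if not pairs:
            break
        best = max(pairs, key=pairs.get)
        if pairs[best] < 2:
            break
        merges.append(best)
        # Rewrite every word with the chosen pair joined into one symbol.
        merged_words = {}
        for symbols, freq in words.items():
            out, i = [], 0
            while i < len(symbols):
                if symbols[i:i + 2] == best:
                    out.append(symbols[i] + symbols[i + 1])
                    i += 2
                else:
                    out.append(symbols[i])
                    i += 1
            merged_words[tuple(out)] = freq
        words = merged_words
    return merges


example_vocab = {
    'l o w </w>': 5,
    'f a r t h e s t </w>': 5,
    'n e w e r </w>': 5,
    'w i d e r </w>': 5,
}
print(learn_merges(example_vocab, 5))
```

On the example vocab this collects the same five merges that the printed steps show, starting with ('e', 'r') and ending with ('low', '</w>'). Ties between equally frequent pairs are broken by first occurrence, which is a side effect of dictionary ordering in this sketch rather than a rule of the algorithm.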
Charles Lovering © 2026