Byte-Pair Encoding Representation

Neural machine translation of rare words with subword units

Byte-pair encoding (BPE) builds a subword tokenization of a vocabulary. Sennrich et al. give the following example of how it works, and why it is useful, at test time:

V = { "low", "lowest", "newer", "wider" }
BPE("lower") -> "low", "er"

This is much more valuable than a single symbol for an unknown word (UNK). Both subword segments already occur in vocabulary words ("low" in "lowest", "er" in "newer" and "wider"), so it would not be surprising if a model were able to interpret "lower" correctly.
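At test time, segmentation amounts to replaying learned merges on the characters of a new word. Here is a minimal sketch, assuming a hypothetical `segment` helper and a hand-written merge list (not taken from Sennrich et al.) chosen so that the example above comes out as described:

```python
def segment(word, merges):
    """Split `word` into subword units by replaying `merges` in order.

    `merges` is an ordered list of (left, right) symbol pairs; both the
    function name and the list below are illustrative assumptions.
    """
    symbols = list(word)  # start from single characters
    for left, right in merges:
        merged = []
        i = 0
        while i < len(symbols):
            # Join the pair wherever it appears as two adjacent symbols.
            if (i + 1 < len(symbols)
                    and symbols[i] == left
                    and symbols[i + 1] == right):
                merged.append(left + right)
                i += 2
            else:
                merged.append(symbols[i])
                i += 1
        symbols = merged
    return symbols


# Hypothetical merges, listed in the order they would have been learned.
merges = [("l", "o"), ("lo", "w"), ("e", "r")]
print(segment("lower", merges))  # ['low', 'er']
```

A word that is never touched by any merge simply stays as a sequence of characters, so nothing has to fall back to UNK.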

To build this representation, an iterative algorithm links together the most common adjacent segments, starting with character pairs. Below is the pseudocode provided by the original authors, with a few changes. (The authors also provide an optimized implementation that uses batching and avoids recomputation.) The algorithm works bottom-up: at each step it gets the bigram counts for every pair of adjacent symbols. These symbols start out as characters but in later steps become common strings.

    import collections
    import re


    def get_stats(vocab):
        """For the given merge step, get the bigram counts for every pair
        of adjacent symbols (strings).

        For each word and its frequency in the corpus, add that frequency
        to the count of each adjacent symbol pair in the word.
        """
        pairs = collections.defaultdict(int)
        for word, freq in vocab.items():
            symbols = word.split()
            for i in range(len(symbols) - 1):
                pairs[symbols[i], symbols[i + 1]] += freq
        return pairs


    def merge_vocab(pair, v_in):
        """Merge every occurrence of `pair` in the vocab into one symbol."""
        v_out = {}
        bigram_pattern = re.escape(' '.join(pair))
        # The lookarounds make the pattern match whole symbols only, so
        # merging ('e', 'r') cannot fire inside 'he r'.
        p = re.compile(r'(?<!\S)' + bigram_pattern + r'(?!\S)')
        for word in v_in:
            w_out = p.sub(''.join(pair), word)
            v_out[w_out] = v_in[word]
        return v_out
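One subtlety in the merge step is that the pattern needs to match the pair only where both halves are complete symbols. The sketch below compares a plain escaped pattern with one guarded by whitespace lookarounds. The variable names and the toy word are illustrative assumptions:

```python
import re

pair = ('e', 'r')
word = 'he r e r'  # symbols: 'he', 'r', 'e', 'r'

bigram = re.escape(' '.join(pair))
naive = re.compile(bigram)
# Require whitespace (or the string boundary) on both sides of the pair.
guarded = re.compile(r'(?<!\S)' + bigram + r'(?!\S)')

naive_result = naive.sub(''.join(pair), word)
guarded_result = guarded.sub(''.join(pair), word)

print(naive_result)    # her er   (wrongly glues the tail of 'he' to 'r')
print(guarded_result)  # he r er  (only the real ('e', 'r') pair is merged)
```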

Next, given the bigram counts for a vocabulary, merge the most frequent bigram into a single symbol wherever it occurs, and repeat for a fixed number of merges.

    def byte_pair_encoding(vocab, num_merges=5):
        """For the given number of merges, find and merge the most common
        pair of symbols.
        """
        for _ in range(num_merges):
            pairs = get_stats(vocab)
            if not pairs:
                # Every word is already a single symbol.
                break
            best = max(pairs, key=pairs.get)
            if pairs[best] < 2:
                print('no pair has frequency > 1. Stopping\n')
                break
            vocab = merge_vocab(best, vocab)
            print(vocab)

    >>> # '</w>' marks the end of each word.
    >>> vocab = {
    ...     'l o w </w>': 5,
    ...     'f a r t h e s t </w>': 5,
    ...     'n e w e r </w>': 5,
    ...     'w i d e r </w>': 5,
    ... }
    >>> byte_pair_encoding(vocab)
    {'l o w </w>': 5, 'f a r t h e s t </w>': 5, 'n e w er </w>': 5, 'w i d er </w>': 5}
    {'l o w </w>': 5, 'f a r t h e s t </w>': 5, 'n e w er</w>': 5, 'w i d er</w>': 5}
    {'lo w </w>': 5, 'f a r t h e s t </w>': 5, 'n e w er</w>': 5, 'w i d er</w>': 5}
    {'low </w>': 5, 'f a r t h e s t </w>': 5, 'n e w er</w>': 5, 'w i d er</w>': 5}
    {'low</w>': 5, 'f a r t h e s t </w>': 5, 'n e w er</w>': 5, 'w i d er</w>': 5}
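The loop above only prints intermediate vocabularies. To segment new words later, the merges themselves need to be kept, in order. A minimal self-contained sketch, assuming a hypothetical `learn_merges` helper (not part of the original pseudocode) and using `</w>` as an end-of-word marker:

```python
import collections
import re


def learn_merges(vocab, num_merges):
    """Return the ordered list of merges learned from `vocab`.

    `vocab` maps space-separated symbol strings to word frequencies.
    The function name and return format are illustrative assumptions.
    """
    merges = []
    for _ in range(num_merges):
        # Count adjacent symbol pairs, weighted by word frequency.
        pairs = collections.Counter()
        for word, freq in vocab.items():
            symbols = word.split()
            for left, right in zip(symbols, symbols[1:]):
                pairs[left, right] += freq
        if not pairs:
            break
        # Ties are broken by first occurrence (dict insertion order).
        best = max(pairs, key=pairs.get)
        if pairs[best] < 2:
            break
        # Replace the pair only where it appears as two whole symbols.
        pattern = re.compile(
            r'(?<!\S)' + re.escape(' '.join(best)) + r'(?!\S)')
        merged = ''.join(best)
        vocab = {pattern.sub(lambda _m: merged, word): freq
                 for word, freq in vocab.items()}
        merges.append(best)
    return merges


vocab = {
    'l o w </w>': 5,
    'f a r t h e s t </w>': 5,
    'n e w e r </w>': 5,
    'w i d e r </w>': 5,
}
learned = learn_merges(vocab, num_merges=5)
print(learned)
# [('e', 'r'), ('er', '</w>'), ('l', 'o'), ('lo', 'w'), ('low', '</w>')]
```

Returning the merges as an ordered list matters because later merges are built from the results of earlier ones, so they have to be replayed in the same order when segmenting unseen words.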