Beam search is a method for decoding a sequence given an auto-regressive function that outputs a probability distribution over the next possible symbols. Ideally, a search algorithm would traverse all paths and select the most probable sequence. However, this is prohibitively expensive, because the number of candidate sequences grows exponentially with the sequence length.

This search algorithm is often used in translation. Beam search is most often used at test time, not during training. For a full implementation, see OpenNMT. I provide a basic implementation below for reference.

Beam search works iteratively. The details depend on the decoding function: a hidden Markov model with a memory of 1 consumes only the previous symbol; recurrent neural networks are stateful; and transformer networks consume the entire prefix.
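One way to hide these differences from the search loop is to give every model the same "prefix in, distribution out" interface. The sketch below is only an illustration: the names `from_markov`, `markov_step`, and the log probabilities are hypothetical, not taken from OpenNMT or any other library.

```python
from typing import Callable, Dict, Tuple

LogProbs = Dict[str, float]   # next symbol -> log probability
Prefix = Tuple[str, ...]      # symbols decoded so far


def from_markov(step: Callable[[str], LogProbs]) -> Callable[[Prefix], LogProbs]:
    """Adapt a memory-1 model so it can be called with a full prefix."""
    def prefix_step(prefix: Prefix) -> LogProbs:
        # Only the most recent symbol matters for a memory-1 model.
        return step(prefix[-1])
    return prefix_step


def markov_step(previous: str) -> LogProbs:
    """Hypothetical toy model with made-up log probabilities."""
    if previous == "S":
        return {"A": -0.4, "B": -0.6, "C": -0.5}
    return {"A": -0.3, "B": -0.5, "C": -0.7}


prefix_step = from_markov(markov_step)
print(prefix_step(("S", "A")))  # same distribution as markov_step("A")
```

A stateful model such as a recurrent network would also need each hypothesis to carry its own hidden state alongside the prefix, which this sketch leaves out.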

Walkthrough

Given a function that takes a prefix of a sequence and outputs a probability distribution over the symbols for the next item in the sequence, beam search is an approximate algorithm that searches for the path yielding the most probable sequence. The path with the highest probability at the start may not end up being the most likely sequence. Log probabilities are used so that a sequence's score is a sum of per-step values rather than a product of many small probabilities, which avoids floating-point underflow.
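To see why the log matters, here is a small sketch comparing the two ways of scoring a long sequence. The per-step probability of 0.01 and the length of 200 are arbitrary values chosen for illustration.

```python
import math

# Hypothetical per-step probabilities for a long sequence (not from the walkthrough).
step_probs = [0.01] * 200

# Multiplying raw probabilities underflows to 0.0 in double precision.
product = 1.0
for p in step_probs:
    product *= p

# Summing log probabilities keeps the score representable.
log_score = sum(math.log(p) for p in step_probs)

print(product)    # 0.0
print(log_score)  # roughly -921.03
```

Once the product has collapsed to zero, every long hypothesis looks equally bad, whereas the summed log scores can still be ranked.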

In this example, we will compute symbols until we reach the maximum length of 4, and maintain 2 beams (or hypotheses). There are three output symbols (A, B, C). The log probabilities from the start symbol S are -0.39, -0.60, and -0.45, respectively.
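This first pruning step can be sketched with the numbers above. The variable names are my own, and `heapq.nlargest` is just one convenient way to take the top entries.

```python
import heapq

# Log probabilities of each symbol given the start symbol S (values from the walkthrough).
first_step = {"A": -0.39, "B": -0.60, "C": -0.45}
beam_width = 2

# Keep the beam_width symbols with the highest (least negative) log probability.
kept = heapq.nlargest(beam_width, first_step.items(), key=lambda item: item[1])
print(kept)  # [('A', -0.39), ('C', -0.45)]
```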

The selected options are those with the highest log probabilities. Now, we will generate the next step's probabilities given these two prefixes (S-A and S-C). From here, the search continues along separate branches. The outputs highlighted in green are the paths with the current highest log probabilities.

Note that in this time step one branch fades out completely, because the other branch contains all of the options with the highest log probabilities.

Finally, beam search selects the path with the highest total log probability.
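Putting the steps together, the loop can be sketched roughly as follows. This is a minimal illustration under simplifying assumptions (fixed length, no end-of-sequence symbol, no length normalization), and `toy_model` with its transition table is entirely hypothetical, chosen so that the best first step does not lead to the best sequence.

```python
import heapq
import math
from typing import Callable, Dict, List, Tuple

Prefix = Tuple[str, ...]
Hypothesis = Tuple[float, Prefix]  # (total log probability, prefix)


def beam_search(
    next_log_probs: Callable[[Prefix], Dict[str, float]],
    start: str,
    beam_width: int,
    max_length: int,
) -> Hypothesis:
    """Return the highest-scoring hypothesis found with a fixed-width beam."""
    beams: List[Hypothesis] = [(0.0, (start,))]
    for _ in range(max_length):
        candidates: List[Hypothesis] = []
        for score, prefix in beams:
            # Expand every surviving prefix by every possible next symbol.
            for symbol, log_p in next_log_probs(prefix).items():
                candidates.append((score + log_p, prefix + (symbol,)))
        # Prune: keep only the beam_width best-scoring candidates.
        beams = heapq.nlargest(beam_width, candidates, key=lambda h: h[0])
    return max(beams, key=lambda h: h[0])


# Hypothetical transition probabilities, keyed by the previous symbol.
TABLE = {
    "S": {"A": 0.50, "B": 0.10, "C": 0.40},
    "A": {"A": 0.40, "B": 0.30, "C": 0.30},
    "B": {"A": 0.35, "B": 0.40, "C": 0.25},
    "C": {"A": 0.05, "B": 0.90, "C": 0.05},
}


def toy_model(prefix: Prefix) -> Dict[str, float]:
    """Memory-1 toy model: log probabilities depend only on the last symbol."""
    return {sym: math.log(p) for sym, p in TABLE[prefix[-1]].items()}


score, path = beam_search(toy_model, start="S", beam_width=2, max_length=4)
print(path)  # ('S', 'C', 'B', 'B', 'B')
```

Here A is the most probable first symbol, yet the winning path starts with C. A greedy search (a beam width of 1) would have committed to A and missed it.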

Below is a demonstration of how the algorithm searches through the graph, pruning all but the configured number of beams at each time step.