RNNs

This is part 5 of my applied DL notes. They will cover high-level intuition about RNNs and some practical advice on using them for language modeling/machine translation.

Language Modeling

One of the primary applications of RNNs is for language modeling. The goal is to learn the joint probability of a text sequence P(x_{1}, x_{2}, \dots, x_{T}). This enables us to do text generation with x_{t} \sim P(x_{t} \mid x_{t - 1}, \dots, x_{1}). Here are some practical insights which motivate and set up using RNNs for this task:

Perplexity

Say a language model is completing a sequence “It is raining … “. A model which completes it with “outside” is better than one which completes it with “asdasdsad”. We can compute the likelihood of “It is raining outside” and see it is higher than that of “It is raining asdasdsad.” But raw likelihoods are hard to compare, especially since longer sequences will always be less likely; we need a per-token average. Hence it is reasonable to use perplexity:

\exp\left(-\frac{1}{n} \sum_{t = 1}^{n} \log P(x_{t} \mid x_{t - 1}, \dots, x_{1})\right)

Here P is given by a language model and x_{t} is the actual token at time t. So, this is the exponential of the cross-entropy loss averaged over the n tokens. If our model is perfect, perplexity is 1. If our model is always completely wrong, perplexity is \infty. A model which predicts uniformly over the tokens in a vocab will have a perplexity equal to the number of unique tokens. Intuitively, perplexity tells us a better language model is able to spend fewer bits to compress the sequence.
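
As a quick sanity check, here is a minimal sketch (plain Python; the per-token probabilities are made up for illustration) of perplexity as the exponentiated average negative log-likelihood:

```python
import math

def perplexity(token_probs):
    """Perplexity = exp of the average negative log-likelihood per token."""
    n = len(token_probs)
    return math.exp(-sum(math.log(p) for p in token_probs) / n)

# Hypothetical per-token probabilities P(x_t | x_{t-1}, ..., x_1) from a model.
print(perplexity([1.0, 1.0, 1.0]))       # perfect model -> 1.0
print(perplexity([0.25, 0.25, 0.25]))    # uniform over a 4-token vocab -> 4.0
```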

Machine Translation

Machine translation is a related, but different problem. The goal is to map an input sequence in one language to an output sequence in another language. The text processing is much the same, but there are a few differences:

BLEU

For machine translation, the performance metric is usually BLEU. For any n-gram in the predicted sequence, BLEU evaluates whether this n-gram appears in the label sequence. We denote p_{n} as the ratio of the number of matched n-grams to the number of n-grams in the predicted sequence. For example, given a predicted sequence ABBCD and a label sequence ABCDEF, we have p_{1} = 4/5, p_{2} = 3/4, p_{3} = 1/3 and p_{4} = 0. Then, BLEU is:

\exp\left(\min\left(0, 1 - \frac{\text{len}_{\text{label}}}{\text{len}_{\text{pred}}}\right)\right)\prod_{n = 1}^{k} p_{n}^{1/2^{n}}

where k is the longest n-gram length used for matching. If the predicted sequence exactly equals the label sequence, BLEU is 1. Also, since p_{n}^{1/2^{n}} grows with n for a fixed p_{n}, BLEU assigns greater weight to longer n-gram precision. Finally, since predicting shorter sequences tends to yield higher p_{n}, the \exp(\cdot) term penalizes shorter predicted sequences.
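
Here is a minimal sketch (plain Python; the helper name is my own, and matched counts are clipped by how often each n-gram occurs in the label, consistent with the worked example) that implements this definition. On the example above, p_{4} = 0, so the k = 4 score is zero while k = 3 gives a positive score:

```python
import math
from collections import Counter

def bleu(pred, label, k):
    """BLEU as defined above: brevity penalty times weighted n-gram precisions."""
    len_pred, len_label = len(pred), len(label)
    score = math.exp(min(0.0, 1 - len_label / len_pred))
    for n in range(1, k + 1):
        pred_ngrams = Counter(tuple(pred[i:i + n]) for i in range(len_pred - n + 1))
        label_ngrams = Counter(tuple(label[i:i + n]) for i in range(len_label - n + 1))
        matched = sum(min(c, label_ngrams[g]) for g, c in pred_ngrams.items())
        p_n = matched / (len_pred - n + 1)      # p_1 = 4/5, p_2 = 3/4, p_3 = 1/3, p_4 = 0 here
        score *= p_n ** (1 / 2 ** n)
    return score

print(bleu(list("ABBCD"), list("ABCDEF"), k=3))   # ~0.59
print(bleu(list("ABBCD"), list("ABCDEF"), k=4))   # 0.0, since p_4 = 0
```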

RNNs

Assume an n-gram model where the conditional probability of a word at time step t only depends on the last n - 1 words. For a vocab \mathcal{V}, a counting model would need to store \lvert \mathcal{V} \rvert^{n} numbers. Instead, we can use a latent variable model:

P(x_{t} \mid x_{t - 1}, \dots, x_{1}) \approx P(x_{t} \mid h_{t - 1})

where h_{t - 1} is a hidden state that summarizes information through time t - 1. We can let h_{t} = f(x_{t}, h_{t - 1}). Note we want h_{t} to be a small but useful representation of the past time steps. It should not just store the full sequence; that would be infeasible memory/compute-wise. RNNs are NNs with hidden states. Let X_{t} \in \mathbb{R}^{n \times d} be a minibatch of inputs at time t and H_{t} \in \mathbb{R}^{n \times h} be the hidden state at time t. We have H_{t - 1} \in \mathbb{R}^{n \times h} stored from the last step. We model:

H_{t} = \phi(X_{t}W_{xh} + H_{t - 1}W_{hh} + b_{h})

So, the hidden state today depends on the hidden state yesterday, parameterized by W_{hh}, and on the input today, parameterized by W_{xh}. Finally, we generate the output O_{t} \in \mathbb{R}^{n \times q} as:

O_{t} = H_{t}W_{hq} + b_{q}

Note that the weights W_{xh}, W_{hh}, W_{hq}, b_{h} and b_{q} are constant across time steps, so the number of parameters does not grow with t. Note we can simplify the computation of H_{t} here by concatenating the matrices into [X_{t}, H_{t - 1}] \in \mathbb{R}^{n \times (d + h)} and feeding them into a FC “hidden state” layer parameterized by [W_{xh}, W_{hh}] \in \mathbb{R}^{(d + h) \times h}. Then, we take H_{t} and feed it into a FC “output” layer parameterized by W_{hq} to get O_{t}.
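
A minimal PyTorch sketch of a single recurrent step using the shapes above (random weights, purely illustrative), including a check that the concatenated formulation gives the same hidden state:

```python
import torch

n, d, h, q = 2, 5, 4, 3          # batch, input dim, hidden dim, output dim
X_t = torch.randn(n, d)
H_prev = torch.randn(n, h)
W_xh, W_hh, b_h = torch.randn(d, h), torch.randn(h, h), torch.zeros(h)
W_hq, b_q = torch.randn(h, q), torch.zeros(q)

# One recurrent step: H_t = phi(X_t W_xh + H_{t-1} W_hh + b_h), O_t = H_t W_hq + b_q.
H_t = torch.tanh(X_t @ W_xh + H_prev @ W_hh + b_h)
O_t = H_t @ W_hq + b_q

# Equivalent computation via concatenation along the feature dimension.
H_t_cat = torch.tanh(torch.cat([X_t, H_prev], dim=1) @ torch.cat([W_xh, W_hh], dim=0) + b_h)
print(torch.allclose(H_t, H_t_cat))  # True
```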

RNNs for Character-Level Language Models

As a concrete example, consider the character RNN below where we want to predict the next token using current and past tokens. Consider a minibatch of size 1 with the sequence given by “machine.” In step 3, O_{3} depends on “mac” and the cross-entropy loss will depend on \text{softmax}(O_{3}) and the true char “h.” In practice, each token is a vector and batch size is greater than 1. So, the input is X_{t} \in \mathbb{R}^{n \times d}.

Some implementation advice:

Backprop Through Time

Backprop through time is just backprop for RNNs. There is nothing conceptually different, but it is worth studying the computational graph to point out potential issues we run into with long sequences.

Think about a sequence with T = 1000. The first token can potentially influence the last token. Computing and storing the gradients can then take too long and require too much memory. Moreover, it will typically be numerically unstable. To be more mathematically precise, consider an RNN with the identity activation and no bias parameters. For time step t, let a single example input be x_{t} \in \mathbb{R}^d, label y_{t}, hidden state h_{t} \in \mathbb{R}^{h} and output o_{t} \in \mathbb{R}^{q}. Then we have:

h_{t} = W_{hx}x_{t} + W_{hh}h_{t - 1}

o_{t} = W_{qh}h_{t}

If the loss at time t is l(o_{t}, y_{t}), the objective function is:

L = \frac{1}{T}\sum_{t} l(y_{t}, o_{t})

The computational graph below shows the dependencies. Following the arrows backwards from L to parameters W_{hx}, W_{hh} and W_{qh}, we can get the gradients of interest.

First, we have:

\frac{\partial L}{\partial o_{t}} = \frac{1}{T} \frac{\partial l(o_{t}, y_{t})}{\partial o_{t}} \in \mathbb{R}^{q}

Since L depends on W_{qh} through o_{1}, o_{2}, \dots, o_{T}, the chain rule for total derivatives gives us:

\frac{\partial L}{\partial W_{qh}} = \sum_{t} \text{prod}\left(\frac{\partial L}{\partial o_{t}}, \frac{\partial o_{t}}{\partial W_{qh}} \right) = \sum_{t} \frac{\partial L}{\partial o_{t}} h_{t}^{T} \in \mathbb{R}^{q \times h}

For the final time step T, L depends on h_{T} only through o_{T}:

\frac{\partial L}{\partial h_{T}} = \text{prod}\left(\frac{\partial L}{\partial o_{T}}, \frac{\partial o_{T}}{\partial h_{T}} \right) = W_{qh}^{T} \frac{\partial L}{\partial o_{T}} \in \mathbb{R}^{h}

For t < T, it is more complicated as L depends on h_{t} through both o_{t} and h_{t + 1}:

\frac{\partial L}{\partial h_{t}} = \text{prod}\left(\frac{\partial L}{\partial h_{t + 1}}, \frac{\partial h_{t + 1}}{\partial h_{t}}\right) + \text{prod}\left(\frac{\partial L}{\partial o_{t}}, \frac{\partial o_{t}}{\partial h_{t}}\right) = W_{hh}^{T} \frac{\partial L}{\partial h_{t + 1}} + W_{qh}^{T} \frac{\partial L}{\partial o_{t}}

Unrolling the recursion, we get:

\frac{\partial L}{\partial h_{t}} = \sum_{i = t}^{T} (W_{hh}^{T})^{T - i} W_{qh}^{T} \frac{\partial L}{\partial o_{T + t - i}} \in \mathbb{R}^{h}

Here we see that for long sequences we are going to have very large powers of W_{hh}^{T}. Eigenvalues smaller than 1 will vanish, while eigenvalues greater than 1 will diverge. This numerical instability shows up as vanishing or exploding gradients. Finally, notice that L depends on W_{hx} and W_{hh} through the hidden states h_{1}, \dots, h_{T}, so we get:

\frac{\partial L}{\partial W_{hx}} = \sum_{t} \text{prod}\left(\frac{\partial L}{\partial h_{t}}, \frac{\partial h_{t}}{\partial W_{hx}} \right) = \sum_{t = 1}^{T} \frac{\partial L}{\partial h_{t}} x_{t}^{T} \in \mathbb{R}^{h \times d}

\frac{\partial L}{\partial W_{hh}} = \sum_{t} \text{prod}\left(\frac{\partial L}{\partial h_{t}}, \frac{\partial h_{t}}{\partial W_{hh}} \right) = \sum_{t = 1}^{T} \frac{\partial L}{\partial h_{t}} h_{t - 1}^{T} \in \mathbb{R}^{h \times h}

Again, any numerical instability in \frac{\partial L}{\partial h_{t}} will show up in both \frac{\partial L}{\partial W_{hh}} and \frac{\partial L}{\partial W_{hx}}.

Training an RNN is the same as any other NN. We alternate between forward passes and backprop through time. Any intermediate values are cached, i.e. \frac{\partial L}{\partial h_{t}} is stored in memory to compute \frac{\partial L}{\partial W_{hh}} and \frac{\partial L}{\partial W_{hx}}.

Strategies to Handle Numerical Instability

First, we could do the full backprop computation. But then we are giving up on finding robust models which will generalize well. As a result of the instability, small perturbations in initial conditions can lead to vastly different updates and hence final models.

Second, we can truncate the summation in \frac{\partial L}{\partial h_{t}} after some number of steps \tau. For example, one could detach the gradient after a given number of time steps or between minibatches. The model then focuses on short-term influences rather than long-term influences. People have found this bias is desirable as it leads to simpler, more stable models.
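
A minimal PyTorch sketch of this kind of truncation between minibatches (module sizes and data are placeholders): detaching the hidden state passes its value forward but stops gradients from flowing into earlier chunks.

```python
import torch
import torch.nn as nn

rnn = nn.RNN(input_size=8, hidden_size=16, batch_first=True)
head = nn.Linear(16, 8)
opt = torch.optim.SGD(list(rnn.parameters()) + list(head.parameters()), lr=0.1)

state = torch.zeros(1, 4, 16)          # (num_layers, batch, hidden)
for _ in range(5):                     # loop over consecutive chunks of a long sequence
    x = torch.randn(4, 10, 8)          # hypothetical chunk: (batch, chunk_len, features)
    y = torch.randn(4, 10, 8)
    state = state.detach()             # truncate BPTT: no gradient into previous chunks
    out, state = rnn(x, state)
    loss = ((head(out) - y) ** 2).mean()
    opt.zero_grad()
    loss.backward()
    opt.step()
```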

A more complicated truncation approach is randomized. Specifically, define a sequence \epsilon_{t} with parameter 0 \leq \pi_{t} \leq 1 where P(\epsilon_{t} = 0) = 1 - \pi_{t} and P(\epsilon_{t} = \pi_{t}^{-1}) = \pi_{t}, so \mathbb{E}[\epsilon_{t}] = 1. We then define:

z_{t} = \epsilon_{t} W_{hh}^{T} \frac{\partial L}{\partial h_{t + 1}} + W_{qh}^{T} \frac{\partial L}{\partial o_{t}}

Notice that \mathbb{E}[z_{t}] = \frac{\partial L}{\partial h_{t}}, but whenever \epsilon_{t} = 0, we do not unroll the recursion. So, only rarely do we get very long chain-rule sequences. By re-weighting the long sequences up (via \pi_{t}^{-1}) when they do occur, such a scheme provides unbiased gradient estimates.

From the bottom up, the picture shows these three strategies for analyzing the first few words of a text. In practice, the regular truncation works best: it sufficiently captures the relevant dependencies, it has lower variance than the randomized strategy, and it has a desirable regularization effect.

Gradient Clipping

In RNNs, a well-known issue is unstable optimization. For a sequence length of T, the gradient will involve a chain of matrix products of length O(T) during backprop. Since this product involves the same matrices over and over again, we can have the vanishing/exploding gradients problem.

One issue you can face is that once in a while your gradients can get too large and your algorithm diverges. A reasonable approach is to use gradient clipping:

g \leftarrow \min\left(1, \frac{\theta}{\|g\|}\right)g

So, now \|g\| \leq \theta and the gradient points in the original direction. It also has the nice side effect of robustifying optimization against any given minibatch or particular sample. In practice, people often compute the gradient norm over all the parameters at once.
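
A minimal sketch of this clipping rule in PyTorch (a manual version of what torch.nn.utils.clip_grad_norm_ does over a parameter list):

```python
import torch

def clip_gradients(params, theta):
    """Rescale gradients so their global norm is at most theta, keeping the direction."""
    params = [p for p in params if p.grad is not None]
    norm = torch.sqrt(sum(torch.sum(p.grad ** 2) for p in params))
    if norm > theta:
        for p in params:
            p.grad.mul_(theta / norm)

# Usage, after loss.backward() and before optimizer.step():
# clip_gradients(model.parameters(), theta=1.0)
```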

Gated Recurrent Units (GRU)

We know that long products of matrices can lead to vanishing/exploding gradients. To get better control of our gradients, we may want to add the ability to: store vital early information which would otherwise need a large gradient to exert influence, skip irrelevant tokens, and forget our internal state representation when there is a logical break in the sequence.

To address these concerns, GRUs support gating of the hidden state. Specifically, they introduce learned mechanisms for when a hidden state should be updated and when it should be reset.

Reset and Update Gates

Consider a minibatch of n samples. Reset gates, R_{t} \in \mathbb{R}^{n \times h}, will help capture short-term dependencies in sequences. Update gates, Z_{t} \in \mathbb{R}^{n \times h}, will help capture long-term dependencies in sequences. They are computed as follows:

R_{t} = \sigma(X_{t} W_{xr} + H_{t - 1} W_{hr} + b_{r})

Z_{t} = \sigma(X_{t} W_{xz} + H_{t - 1} W_{hz} + b_{z})

The sigmoid functions transform the input values to vectors with entries in (0, 1). This lets us form convex combinations, i.e. treat the values as weights.

Candidate Hidden States

The reset gate R_{t} controls how much of the previous state we might still want to remember. It gets combined with the regular hidden state updating mechanism to get the following candidate hidden state \tilde{H}_{t} \in \mathbb{R}^{n \times h}:

\tilde{H}_{t} = \tanh(X_{t}W_{xh} + (R_{t} \odot H_{t - 1}) W_{hh} + b_{h})

The \tanh ensures the values remain in the interval (-1, 1). Notice that if R_{t} \approx 0, then \tilde{H}_{t} is just an MLP result with the input X_{t}. The pre-existing state is reset! If R_{t} \approx 1, we recover the original RNN set-up. This is only a candidate hidden state because we have not yet accounted for Z_{t}.

Hidden State

The update gate Z_{t} will control how much of the new state is just a copy of the old state and to what degree the new candidate state is used. We get:

H_{t} = Z_{t} \odot H_{t - 1} + (1 - Z_{t}) \odot \tilde{H}_{t}

When Z_{t} \approx 1, the new candidate state is irrelevant and we just keep the old state H_{t - 1}. So, information in X_{t} is ignored, effectively skipping time t in the dependency chain. When Z_{t} \approx 0, the new candidate state is all that matters. This flexibility can help us solve vanishing gradient problems and capture long-run dependencies. If Z_{t} \approx 1 for a while, the hidden state from close to the beginning can easily be retained and passed down the subsequence.
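
A minimal PyTorch sketch of one GRU step, transcribing the equations above (weights are random placeholders rather than learned parameters):

```python
import torch

n, d, h = 2, 5, 4
X_t, H_prev = torch.randn(n, d), torch.randn(n, h)

def p():  # one (input-to-hidden, hidden-to-hidden, bias) parameter triple
    return torch.randn(d, h), torch.randn(h, h), torch.zeros(h)

(W_xr, W_hr, b_r), (W_xz, W_hz, b_z), (W_xh, W_hh, b_h) = p(), p(), p()

R_t = torch.sigmoid(X_t @ W_xr + H_prev @ W_hr + b_r)           # reset gate
Z_t = torch.sigmoid(X_t @ W_xz + H_prev @ W_hz + b_z)           # update gate
H_tilde = torch.tanh(X_t @ W_xh + (R_t * H_prev) @ W_hh + b_h)  # candidate state
H_t = Z_t * H_prev + (1 - Z_t) * H_tilde                        # convex combination
```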

Long Short-Term Memory (LSTM)

LSTMs, much like GRUs, care about long-term information preservation and short-term input skipping.

Gated Memory Cell

A memory cell records additional information about when to remember and when to ignore inputs in the hidden state. However, it is not passed to the output layer; it exists exclusively for state control. To control it, we have the input gate I_{t} \in \mathbb{R}^{n \times h} (decides when to read data into the cell), the forget gate F_{t} \in \mathbb{R}^{n \times h} (decides when to reset the cell) and the output gate O_{t} \in \mathbb{R}^{n \times h} (reads entries out from the cell).

I_{t} = \sigma(X_{t} W_{xi} + H_{t - 1} W_{hi} + b_{i})

F_{t} = \sigma(X_{t} W_{xf} + H_{t - 1} W_{hf} + b_{f})

O_{t} = \sigma(X_{t} W_{xo} + H_{t - 1} W_{ho} + b_{o})

Notice all these values are in (0, 1) due to the sigmoids. The candidate memory cell is given by:

\tilde{C}_{t} = \tanh(X_{t}W_{xc} + H_{t - 1} W_{hc} + b_{c})

Now, the input gate governs how much new data we take into account via \tilde{C}_{t}, and the forget gate addresses how much of the old memory cell content C_{t - 1} we retain. We get:

C_{t} = F_{t} \odot C_{t - 1} + I_{t} \odot \tilde{C}_{t}

Hence if F_{t} \approx 1 and I_{t} \approx 0, the past memory cell content C_{t - 1} will be preserved over time and passed to the current time step. This captures long-run dependencies and helps with the vanishing gradient problem.

Hidden State

The output gate finally comes into play for the hidden state. We have:

H_{t} = O_{t} \odot \tanh(C_{t})

Notice that the values of H_{t} are always in (-1, 1). When O_{t} \approx 1, we pass all the info from the memory cell into the predictor. If instead O_{t} \approx 0, we retain all the info within the memory cell.
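
Similarly, a sketch of one LSTM step following the equations above (again with random placeholder weights):

```python
import torch

n, d, h = 2, 5, 4
X_t, H_prev, C_prev = torch.randn(n, d), torch.randn(n, h), torch.randn(n, h)

def p():  # one (input-to-hidden, hidden-to-hidden, bias) parameter triple
    return torch.randn(d, h), torch.randn(h, h), torch.zeros(h)

(W_xi, W_hi, b_i), (W_xf, W_hf, b_f) = p(), p()
(W_xo, W_ho, b_o), (W_xc, W_hc, b_c) = p(), p()

I_t = torch.sigmoid(X_t @ W_xi + H_prev @ W_hi + b_i)   # input gate
F_t = torch.sigmoid(X_t @ W_xf + H_prev @ W_hf + b_f)   # forget gate
O_t = torch.sigmoid(X_t @ W_xo + H_prev @ W_ho + b_o)   # output gate
C_tilde = torch.tanh(X_t @ W_xc + H_prev @ W_hc + b_c)  # candidate memory cell
C_t = F_t * C_prev + I_t * C_tilde                      # new memory cell
H_t = O_t * torch.tanh(C_t)                             # new hidden state
```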

Deep RNNs

All the RNNs above have a single unidirectional hidden layer. GRUs and LSTMs specify how the inputs and latent variables interact within the hidden layer in different ways, and such specifications can be fairly arbitrary. To create even more flexibility, we can stack several hidden layers on top of each other. Intuitively, we may think particular types of information are relevant at different levels of the stack. For example, maybe the higher levels record macro-level trends, while the lower levels keep track of shorter-term dynamics.

In a deep RNN with L hidden layers, each hidden state is passed both to the next time step of the current layer and to the current time step of the next layer. For simplicity, consider vanilla RNNs. For each hidden layer l \in [L], we have:

H_{t}^{(l)} = \phi_{l}(H_{t}^{(l - 1)} W_{xh}^{(l)} + H_{t - 1}^{(l)} W_{hh}^{(l)} + b_{h}^{(l)})

Here we have H_{t}^{(0)} = X_{t}, W_{xh}^{(l)} \in \mathbb{R}^{h \times h} (except W_{xh}^{(1)} \in \mathbb{R}^{d \times h}) and W_{hh}^{(l)} \in \mathbb{R}^{h \times h}. The calculation at the output layer just depends on the last hidden layer:

O_{t} = H_{t}^{(L)} W_{hq} + b_{q}

By simply replacing the hidden state computation with GRUs or LSTMs, we can get a deep gated RNN. Overall, ensuring proper convergence of deep RNNs requires careful settings for the learning rate, proper initialization and gradient clipping.
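
In practice one rarely writes the stacking by hand; framework RNN layers take a num_layers argument. A small PyTorch sketch of the shapes involved (sizes are illustrative):

```python
import torch
import torch.nn as nn

batch, seq_len, d, h, q, L = 4, 10, 8, 16, 5, 3
deep_rnn = nn.GRU(input_size=d, hidden_size=h, num_layers=L, batch_first=True)
output_layer = nn.Linear(h, q)

X = torch.randn(batch, seq_len, d)
H_all, H_last = deep_rnn(X)      # H_all: top-layer hidden states, (batch, seq_len, h)
O = output_layer(H_all)          # outputs read only the last hidden layer
print(O.shape, H_last.shape)     # (4, 10, 5) and (3, 4, 16): one final state per layer
```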

Bidirectional RNNs

In most sequence learning problems, we want to model the next output given what we have seen so far. But there are problems where we may see the future and want to infer the past. For example, consider fill-in-the-blank tasks where longer-range context may be useful. In HMMs, one can combine forward and backward recursions to infer P(y_{j} \mid y_{-j}) where the y_{i} form a sequence of outputs generated from hidden states h_{i}. We can similarly learn from the future with bidirectional RNNs.

Instead of only running an RNN forward from X_{1} to X_{T}, we also run an RNN in reverse from X_{T} to X_{1}. That is, another hidden layer is added to process information in the backward direction. Formally, the forward and backward hidden states are given by \overset{\rightarrow}{H_{t}} \in \mathbb{R}^{n \times h} and \overset{\leftarrow}{H_{t}} \in \mathbb{R}^{n \times h}. We have:

\overset{\rightarrow}{H_{t}} = \phi(X_{t} W_{xh}^{(f)} + \overset{\rightarrow}{H_{t - 1}}W_{hh}^{(f)} + b_{h}^{(f)})

\overset{\leftarrow}{H_{t}} = \phi(X_{t} W_{xh}^{(b)} + \overset{\leftarrow}{H_{t + 1}}W_{hh}^{(b)} + b_{h}^{(b)})

Next, one concatenates \overset{\rightarrow}{H_{t}} and \overset{\leftarrow}{H_{t}} to get the hidden state H_{t} \in \mathbb{R}^{n \times 2h}, which is fed to the output layer. We have:

O_{t} = H_{t}W_{hq} + b_{q}

where W_{hq} \in \mathbb{R}^{2h \times q}. In deep bidirectional RNNs, the hidden states are passed as inputs to the next forward/backward layers. Also, the two directions can have different numbers of hidden units.
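
A small PyTorch sketch of the bidirectional shapes (sizes are illustrative); note the output layer consumes a width of 2h:

```python
import torch
import torch.nn as nn

batch, seq_len, d, h, q = 4, 10, 8, 16, 5
birnn = nn.GRU(input_size=d, hidden_size=h, batch_first=True, bidirectional=True)
output_layer = nn.Linear(2 * h, q)   # hidden state fed to the output layer has width 2h

X = torch.randn(batch, seq_len, d)
H, _ = birnn(X)                      # forward and backward states already concatenated
print(H.shape)                       # (4, 10, 32)
O = output_layer(H)
print(O.shape)                       # (4, 10, 5)
```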

If you train a bidirectional RNN and then test it on a next-token prediction problem, you will see poor performance. The model only has access to past data at test time, but its parameters were optimized assuming access to both past and future data. Training these models is also very slow: the forward pass requires both the forward and backward recursions, and backprop depends on these values, creating gradients with long dependency chains.

RNN Encoder-Decoder Architecture

For modeling input/output sequences, an encoder-decoder architecture makes sense. The encoder takes the variable-length input and maps it to a fixed-length state. The decoder takes this fixed-length state and transforms it into a variable-length output.

For machine translation, the encoder and decoder are typically RNNs. Information about the input sequence is encoded in the hidden state of the RNN encoder. An RNN decoder then generates the output token by token based on the tokens it has seen/generated plus the hidden state from the RNN encoder. Below there are two special design decisions regarding the start of the RNN decoder. First, there is a <bos> token used as an input. Second, usually the final RNN encoder hidden state is used to initialize the hidden state of the RNN decoder; here it is taken as an input at all time steps. Notice the RNN decoder can stop making predictions once it generates <eos>. Here the labels are just the original output sequence shifted by one token.

Encoder

More precisely, the encoder transforms the input sequence into a fixed-shape context variable c:

c = q(h_1, h_2, \dots, h_{T})

using an RNN. Taking q(h_1, h_2, \dots, h_{T}) = h_{T} would mean the context variable is just the final hidden state. If one used an LSTM, c would include both the hidden state as well as the memory cell. If one used a GRU, c would just be the hidden state.

Decoder

We can think of the decoder as modeling P(y_{t'} \mid y_{1}, \dots, y_{t' - 1}, c) where y_{1}, \dots, y_{T'} is the output sequence. At each time step t', the RNN takes c, the previous hidden state s_{t' - 1} and the previous output y_{t' - 1}, transforming them into the new hidden state s_{t'}. We have:

s_{t'} = g(c, s_{t' - 1}, y_{t' - 1})

After obtaining s_{t'}, we can use an output layer plus a softmax operation to compute P(y_{t'} \mid y_{1}, \dots, y_{t' - 1}, c).
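
A minimal sketch of this wiring in PyTorch (GRU-based; the class names, vocabulary size, and the choice to both initialize the decoder with the final encoder state and feed it as an input at every decoder step are illustrative assumptions):

```python
import torch
import torch.nn as nn

vocab, emb, h = 100, 32, 64

class Encoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.embed = nn.Embedding(vocab, emb)
        self.rnn = nn.GRU(emb, h, batch_first=True)
    def forward(self, src):                        # src: (batch, T) of token ids
        _, state = self.rnn(self.embed(src))
        return state                               # (1, batch, h), the context

class Decoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.embed = nn.Embedding(vocab, emb)
        self.rnn = nn.GRU(emb + h, h, batch_first=True)
        self.out = nn.Linear(h, vocab)
    def forward(self, tgt, state):                 # tgt: (batch, T'), state from encoder
        x = self.embed(tgt)
        context = state[-1].unsqueeze(1).expand(-1, x.shape[1], -1)
        x = torch.cat([x, context], dim=2)         # context fed at every time step
        out, state = self.rnn(x, state)
        return self.out(out), state                # logits over the vocab, new state

src = torch.randint(0, vocab, (4, 7))
tgt = torch.randint(0, vocab, (4, 9))
state = Encoder()(src)
logits, _ = Decoder()(tgt, state)
print(logits.shape)                                # (4, 9, 100)
```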

Implementation Details

Output Sequence Search Strategies

In machine translation, we are interested in a search problem for the output sequence. Say the maximum output sequence length is T'. Given an output vocabulary \mathcal{Y}, the goal is to search for the ideal sequence in a universe of \lvert \mathcal{Y} \rvert^{T'} sequences. The actual outputs will remove the portion including and after the <eos> token.

In greedy search, we generate the output sequence by setting the output at each time step t' with the following rule:

y_{t'} = \arg \max_{y \in \mathcal{Y}} P(y \mid y_{1}, \dots, y_{t' - 1}, c)

Once the <eos> token is generated or the length of the output sequence reaches T', the sequence is done. The optimal sequence is the one with the maximum value of \prod_{t' = 1}^{T'} P(y_{t'} \mid y_{1}, \dots, y_{t' - 1}, c). However, greedy search is not guaranteed to return this optimal sequence.
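
A minimal sketch of greedy decoding around a decoder like the one sketched above (the bos_id/eos_id values and the decoder interface are assumptions for illustration):

```python
import torch

def greedy_decode(decoder, state, bos_id=1, eos_id=2, max_len=10):
    """Pick the argmax token at every step until <eos> or max_len."""
    prev = torch.tensor([[bos_id]])           # batch of one sequence, starting from <bos>
    output = []
    for _ in range(max_len):
        logits, state = decoder(prev, state)  # logits: (1, 1, vocab)
        prev = logits[:, -1].argmax(dim=-1, keepdim=True)
        token = prev.item()
        if token == eos_id:
            break
        output.append(token)
    return output

# Usage with the earlier sketch: state = Encoder()(src); greedy_decode(Decoder(), state)
```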

Consider the example below. Both images show the time steps on the horizontal axis and the conditional probs on the vertical axis. The blue boxes are possible outputs. In the first case, the output is ABC<eos> with prob 0.5 \cdot 0.4 \cdot 0.4 \cdot 0.6 = 0.048. This is the choice selected by greedy search. In the second case, the output is ACB<eos>. Because the second output char is different, the conditional probs of the later time steps change. The prob here is 0.5 \cdot 0.3 \cdot 0.6 \cdot 0.6 = 0.054, so it is better than the first choice.

Exhaustive search will be computationally infeasible since it requires evaluating the probs of all sequences. Notice that greedy search only has a complexity of O(\lvert \mathcal{Y} \rvert T'), which is much better than O(\lvert \mathcal{Y} \rvert^{T'}).

Beam search can optimize the tradeoff between accuracy and computational cost. There is a hyperparameter, k, called the beam size. At time step 1, we select the k tokens with the highest conditional probs. At each following time step, we continue by selecting the k candidate output sequences with the highest conditional probs from the k \lvert \mathcal{Y} \rvert choices. The image below shows the process for k = 2, \mathcal{Y} = \{A, B, C, D, E\} and T' = 3. We start by picking A and C since they have the two highest values of P(y_1 \mid c) over y_1 \in \mathcal{Y}. In step 2, for all y_2 \in \mathcal{Y}, we compute:

P(A, y_2 \mid c) = P(A \mid c)P(y_2 \mid A, c) \quad \text{and} \quad P(C, y_2 \mid c) = P(C \mid c)P(y_2 \mid C, c)

Among the ten choices, the top two are AB and CE. The last step is similar.

Over this full process, we have generated six candidate output sequences: A, C, AB, CE, ABD and CED. After discarding the portion of each sequence including and after <eos>, we choose the sequence which maximizes:

\frac{1}{L^{\alpha}} \log P(y_1, \dots, y_{L} \mid c) = \frac{1}{L^{\alpha}} \sum_{t' = 1}^{L} \log P(y_{t'} \mid y_1, \dots, y_{t' - 1}, c)

where L is the length of the candidate sequence and, typically, \alpha = 0.75. The L^{\alpha} term compensates longer sequences, which will have more (negative) terms in the summation.

The compute cost of beam search is O(k \lvert \mathcal{Y} \rvert T'). Greedy search is just beam search with k = 1.
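
To make the bookkeeping concrete, here is a minimal beam search sketch over a toy next-token distribution (next_token_probs stands in for the decoder's softmax output and <eos> handling is omitted; the length-normalized scoring follows the formula above):

```python
import math

def beam_search(next_token_probs, vocab, k=2, max_len=3, alpha=0.75):
    """next_token_probs(prefix) -> dict mapping token -> P(token | prefix, c)."""
    beams = [((), 0.0)]                          # (prefix, sum of log-probs)
    candidates_seen = []
    for _ in range(max_len):
        expanded = []
        for prefix, logp in beams:
            probs = next_token_probs(prefix)
            for tok in vocab:
                expanded.append((prefix + (tok,), logp + math.log(probs[tok])))
        beams = sorted(expanded, key=lambda c: c[1], reverse=True)[:k]
        candidates_seen.extend(beams)            # every kept prefix is a candidate output
    # Rescore all candidates by length-normalized log-probability.
    best = max(candidates_seen, key=lambda c: c[1] / len(c[0]) ** alpha)
    return best[0]

# Toy example: a next-token distribution that ignores the prefix (purely illustrative).
vocab = "ABCDE"
table = {"A": 0.4, "B": 0.3, "C": 0.15, "D": 0.1, "E": 0.05}
print(beam_search(lambda prefix: table, vocab, k=2, max_len=3))
```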