Batch Normalization

This is part 3 of my applied DL notes. It covers a very specific topic: batch normalization.

Problems with Training NNs

As motivation, here are some common issues with training NNs that batchnorm may address: vanishing and exploding gradients, sensitivity to weight initialization and to the learning rate, intermediate-layer inputs whose scale and distribution drift as lower layers are updated, and activations getting stuck in saturating regimes.

Batch Normalization Overview

Batchnorm addresses some of these issues and consistently helps accelerate convergence of deep NNs (sometimes with depth > 100). Batchnorm is applied layer-by-layer. In each training iteration, the inputs are standardized based on the estimated mean/std of the current minibatch and then scaled based on a learned scale coefficient/offset. More precisely, for minibatch $\mathcal{B}$ and input $x \in \mathcal{B}$ to a batchnorm operation, we have:

$$\text{BN}(x) = \beta \odot \frac{x - \hat{\mu}_{\mathcal{B}}}{\hat{\sigma}_{\mathcal{B}}} + \alpha$$

where $\hat{\mu}_{\mathcal{B}}$ is the sample mean and $\hat{\sigma}_{\mathcal{B}}$ is the sample std of minibatch $\mathcal{B}$, calculated per dimension. Standardizing to mean 0 and std 1 is a somewhat arbitrary choice and can reduce the expressive power of the NN. Hence we include the scale parameter $\beta$ and offset parameter $\alpha$, which are learned jointly with the rest of the model parameters and have the same size as the input $x$. Note we usually add some small constant $\epsilon$ to $\hat{\sigma}_{\mathcal{B}}$ in order to prevent divide-by-zero errors.
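
To make the operation concrete, here is a minimal NumPy sketch of the forward pass above (my own illustration, not code from any library), with $\beta$ as the per-dimension scale and $\alpha$ as the per-dimension offset:

```python
import numpy as np

def batch_norm(x, beta, alpha, eps=1e-5):
    """Minimal batchnorm forward pass for x of shape (batch, features)."""
    mu = x.mean(axis=0)               # per-dimension sample mean over the minibatch
    sigma = x.std(axis=0)             # per-dimension sample std over the minibatch
    x_hat = (x - mu) / (sigma + eps)  # standardize; eps avoids divide-by-zero
    return beta * x_hat + alpha       # learned scale and offset restore expressiveness

x = 5.0 + 3.0 * np.random.randn(64, 10)                     # a minibatch of 64 samples, 10 features
out = batch_norm(x, beta=np.ones(10), alpha=np.zeros(10))
print(out.mean(axis=0).round(3), out.std(axis=0).round(3))  # ~0 and ~1 per dimension
```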

For batchnorm to work we need minibatches of size greater than 1, that is $\lvert \mathcal{B} \rvert > 1$. Otherwise, every standardized value would be 0, since each value equals its own minibatch mean. Batchnorm tends to work well for standard minibatch sizes, but the key point is that using batchnorm makes the choice of minibatch size even more important.
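
A quick toy check of the degenerate case (my own example): with $\lvert \mathcal{B} \rvert = 1$, each dimension's sample mean equals the single sample itself, so every standardized value is 0 and the layer output collapses to the constant offset $\alpha$.

```python
import numpy as np

x_single = np.array([[2.7, -1.3, 0.4]])      # a "minibatch" with |B| = 1
mu = x_single.mean(axis=0)                   # equals x_single itself
sigma = x_single.std(axis=0)                 # identically 0
print((x_single - mu) / (sigma + 1e-5))      # [[0. 0. 0.]]
```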

For batchnorm to work in practice, we just add batchnorm layers to the NN. It is a differentiable operation, and the gradient of the loss with respect to the different parameters can be computed with the chain rule as usual. However, unlike for other operations, the minibatch size shows up in the gradient formulas. Consider fully-connected layers. Denote the input as $x$, the weight matrix as $W$, the bias vector as $b$, and the activation as $\phi$. The output $h$ is then given by:

$$h = \phi(\text{BN}(Wx + b))$$

In certain applications, you can also put the $\text{BN}$ operation after the activation $\phi$. Convolutional layers are a little different. We apply $\text{BN}$ after the convolution and before $\phi$. If there are multiple output channels, batchnorm is applied to each channel separately, and each channel has its own scalar scale coefficient and offset. Assume that $\lvert \mathcal{B} \rvert = m$ and the convolution output has height $p$ and width $q$. Then, per output channel, batchnorm computes the mean/std over $m \times p \times q$ elements.
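
As a sketch of the placement described above (using standard PyTorch layers; the layer sizes are arbitrary assumptions of mine), nn.BatchNorm1d handles the fully-connected case and nn.BatchNorm2d handles the per-channel convolutional case:

```python
import torch
from torch import nn

# Fully-connected block: BN applied to Wx + b, before the activation phi.
fc_block = nn.Sequential(
    nn.Linear(256, 128),
    nn.BatchNorm1d(128),   # per-dimension stats over the minibatch
    nn.ReLU(),
)

# Convolutional block: BN applied per output channel, before the activation.
conv_block = nn.Sequential(
    nn.Conv2d(3, 16, kernel_size=3, padding=1),
    nn.BatchNorm2d(16),    # per-channel stats over m * p * q elements
    nn.ReLU(),
)

x = torch.randn(32, 3, 28, 28)    # m = 32, so each channel's stats use 32 * 28 * 28 elements
print(conv_block(x).shape)        # torch.Size([32, 16, 28, 28])
```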

Another detail to keep in mind is how batchnorm works differently at training and inference time. At training time, the variable values are standardized by minibatch statistics. It would be impossible to use full-dataset statistics because the variable values change every time a minibatch passes through and the model parameters are updated. At inference time, we have a fixed trained model and can use full-dataset statistics to get estimates with less sample noise. Moreover, at inference time we may be making predictions one at a time, so we do not even have access to the original minibatches. In practice, people keep track of EWMAs of the intermediate-layer sample means/stds and use these statistics for inference. In PyTorch, you typically set the momentum parameter to 0.1 or 0.01, which corresponds to the weight placed on the latest minibatch statistic. The closer we set momentum to 0, the more we rely on historical minibatches.
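
A short PyTorch sketch of this train/inference distinction (my own example; the momentum convention is the PyTorch one described above, where the latest minibatch statistic gets weight momentum and the history gets weight 1 - momentum):

```python
import torch
from torch import nn

bn = nn.BatchNorm1d(10, momentum=0.1)   # running_stat <- (1 - momentum) * running_stat + momentum * batch_stat

bn.train()                              # training mode: standardize with minibatch statistics
for _ in range(200):
    x = 3.0 + 2.0 * torch.randn(64, 10)
    bn(x)                               # each forward pass also updates the running EWMAs

print(bn.running_mean)                  # roughly 3.0 per dimension
print(bn.running_var)                   # roughly 4.0 per dimension

bn.eval()                               # inference mode: use the stored running statistics
y = bn(3.0 + 2.0 * torch.randn(1, 10))  # a single example works fine at inference time
```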

Why Does It Help?

In some sense, batchnorm does not seem to do anything because it standardizes variable values with $\hat{\mu}_{\mathcal{B}}$ and $\hat{\sigma}_{\mathcal{B}}$, but then scales them back with $\alpha$ and $\beta$. The key idea is that NNs parameterized with and without batchnorm can represent the same set of functions, but using batchnorm makes optimization with standard algorithms like GD much easier. Empirically, DL practitioners have found that batchnorm allows for more aggressive learning rates without vanishing/exploding gradients (so faster convergence) and makes training a NN robust to different initializations and/or learning rates. Intuitively, it plays the same role as pre-standardizing one's features in non-deep ML. Moreover, inputs to intermediate layers cannot diverge to crazy magnitudes during training since they are actively re-standardized. Also, for particular activation functions $\phi$, it may help keep variable values in the non-saturating regime. For linear 1-layer NNs, one can explicitly show that batchnorm implies a better-conditioned Hessian.

Going back to our example $\hat{y} = x w_{1} w_{2} \dots w_{l}$, we can see more precisely how batchnorm solves the stated issues and makes learning easier. Specifically, assume that $x$ is $\mathcal{N}(0, 1)$. Then it is clear that $h_{l - 1}$ will also be Gaussian, but not with mean 0 and std 1. However, $\text{BN}(h_{l - 1})$ is back to being $\mathcal{N}(0, 1)$ and will remain so under almost any update to the lower layers. Hence we can learn the simple linear function $\hat{y} = w_{l} \text{BN}(h_{l - 1})$. Since the parameters of the preceding layers no longer have any effect, learning here is a simple task. There are corner cases where the lower layers do have an effect. For example, if one sets some $w_{i \leq l - 1}$ to 0, the output is degenerate. Also, if one flips the sign of a $w_{i \leq l - 1}$, then the relationship between $\text{BN}(h_{l - 1})$ and $y$ also flips sign. In this example, we have stabilized learning at the expense of making the earlier layers useless, but this is because we have a linear model. With non-linear activations, the preceding layers remain useful.
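
A toy numeric check of this argument (my own sketch, with scalar weights): whatever the lower-layer weights are, $\text{BN}(h_{l-1})$ looks the same up to a sign, so the last layer faces the same learning problem.

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.standard_normal(100_000)             # x ~ N(0, 1)

def standardize(h, eps=1e-5):
    return (h - h.mean()) / (h.std() + eps)  # BN without the learned scale/offset

for weights in ([0.5, 2.0, 3.0], [0.01, 0.01, 0.01], [-4.0, 0.2, 10.0]):
    h = x * np.prod(weights)                 # h_{l-1} for the deep linear "network"
    z = standardize(h)
    # mean ~0, std ~1 regardless of how the lower-layer weights scale h;
    # only the sign of the weight product matters.
    print(weights, z.mean().round(3), z.std().round(3), np.sign(np.prod(weights)))
```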

In addition to improving training, batchnorm improves generalization. One theory is that batchnorm has a regularization effect. Empirically, practitioners have found that dropout tends to be less important when using batchnorm. First, the use of minibatches adds a source of noise to the training process (much like dropout) which, for reasons not fully understood, reduces overfitting. More precisely, since the samples in a given minibatch are randomly selected, the quantities $\hat{\mu}_{\mathcal{B}}$ and $\hat{\sigma}_{\mathcal{B}}$ introduce random fluctuations into the training of the hidden layers. More theoretically, the authors of LWSP18 analyze batchnorm in a Bayesian framework. They connect its effect to imposing particular priors/penalties on the parameters, which encourage the NN not to over-rely on any particular neuron and reduce correlations among different neurons. In this set-up, the strength of regularization is inversely proportional to $\lvert \mathcal{B} \rvert$, indicating that making the minibatch too large may not be optimal. This result is consistent with the finding that batchnorm works best for minibatches of size 50-100.

The theoretical backing for batchnorm’s optimization improvement is controversial. In STIM18, the authors argue it has nothing to do with internal covariate shift. They look at histograms of the inputs into various layers and find that these histograms are fairly stable over the course of training for both standard and batch normalized VGG NNs. The authors then create a “noisy” batchnorm NN where they add non-stationary Gaussian noise to the outputs of the batchnorm layers. While the histograms of this “noisy” batchnorm NN are noticeably less stable, it still converges faster than a standard NN.

Instead, the authors argue that batchnorm has a smoothing effect on the optimization landscape which makes first-order methods quicker and more robust to hyperparameter choices. They measure the variation of the loss value, $\mathcal{L}(x - \eta \nabla \mathcal{L}(x))$, and the change of the loss gradient, $\|\nabla \mathcal{L}(x) - \nabla \mathcal{L}(x - \eta \nabla \mathcal{L}(x))\|$, for $\eta \in [0.05, 0.4]$ over the course of training. They find that the batchnorm network has smaller variability of the loss and smaller change of the loss gradient. So, steps taken during GD are unlikely to drive the loss very high, and the gradient at a point $x$ stays relevant over longer distances, supporting larger $\eta$. More precisely, their theoretical analysis shows batchnorm improves the Lipschitzness of both the loss and the gradients (i.e. $\beta$-smoothness). Together these imply that the gradients are more predictive, allowing for larger learning rates and faster convergence.
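
This kind of probe is easy to sketch for one's own models. Below is a rough PyTorch version of the measurement described (an illustration under my own assumptions, not the authors' code): step the parameters by $-\eta \nabla \mathcal{L}(x)$, record the new loss and the change in the gradient, then undo the step.

```python
import torch

def probe_landscape(model, loss_fn, inputs, targets, etas=(0.05, 0.1, 0.2, 0.4)):
    """For each step size eta, report the loss and gradient change after a step of -eta * grad."""
    params = [p for p in model.parameters() if p.requires_grad]
    loss = loss_fn(model(inputs), targets)
    grads = torch.autograd.grad(loss, params)
    results = []
    for eta in etas:
        with torch.no_grad():                      # take the candidate GD step
            for p, g in zip(params, grads):
                p.add_(g, alpha=-eta)
        new_loss = loss_fn(model(inputs), targets)
        new_grads = torch.autograd.grad(new_loss, params)
        grad_change = torch.sqrt(sum(((g - ng) ** 2).sum() for g, ng in zip(grads, new_grads)))
        results.append((eta, new_loss.item(), grad_change.item()))
        with torch.no_grad():                      # undo the step before trying the next eta
            for p, g in zip(params, grads):
                p.add_(g, alpha=eta)
    return results
```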

The theory of batchnorm is again brought into question by CKLKT21. The authors argue that for a wide range of practical learning rates and many NN architectures, full-batch GD very quickly enters the edge of stability: the sharpness (the largest eigenvalue of the training loss Hessian) rises until it hovers just above $2/\eta$, while the training loss keeps decreasing over long timescales, but non-monotonically.

They find this phenomenon holds for both standard and batch normalized VGG NNs. Hence they state that the experiments in STIM18 cannot be interpreted as showing that batchnorm improves the local smoothness of the loss along the optimization trajectory. Some of the empirical analysis here is based on re-doing the experiments with a more appropriate measure for “effective smoothness” and showing that this new measure does not actually show better behavior in the batchnorm setting.