Practice: Why transformers need stability to learn: LayerNorm, pre-norm, and RMSNorm
Self-check
Answer in your head (or on paper) before reading each solution.
1. What does LayerNorm actually do to a vector?
Per token, LayerNorm subtracts the mean of the vector’s components and divides by their standard deviation, then applies two learnable parameters: gamma (rescale) and beta (shift). The result is an activation vector whose components sit in a controlled range, regardless of how extreme the inputs to the layer were. The keyword the literature uses for the underlying problem is internal covariate shift.
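The steps above can be sketched in plain Python for a single token's vector. This is a minimal illustration, not any library's implementation. The small `eps` added to the variance is an assumption: common implementations include one for numerical safety, but the exact value varies.

```python
import math

def layer_norm(x, gamma, beta, eps=1e-5):
    """Normalize one token's activation vector x (a list of floats).

    eps is an assumed small constant added to the variance for numerical
    safety; real implementations include one, but the value varies.
    """
    d = len(x)
    mean = sum(x) / d
    var = sum((xi - mean) ** 2 for xi in x) / d  # population variance
    std = math.sqrt(var + eps)
    # Center and rescale, then apply the learnable per-component gamma and beta.
    return [g * (xi - mean) / std + b for xi, g, b in zip(x, gamma, beta)]

# Components with wildly different magnitudes come out in a controlled range.
out = layer_norm([100.0, -50.0, 0.001, 75.0], gamma=[1.0] * 4, beta=[0.0] * 4)
print([round(v, 3) for v in out])
```

With the identity gamma and beta, the output has mean 0 and standard deviation (almost exactly) 1, whatever the scale of the input.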
2. Why do transformers use LayerNorm instead of BatchNorm?
Two reasons. First, in the lecturer’s words: “probably because empirically it works better.” Second, BatchNorm depends on batch composition: during training it normalizes each component with the mean and standard deviation of the current batch, but at inference time (when the batch may be a single sequence) it has to fall back on running averages collected during training, so the two modes compute different things. LayerNorm avoids that train-vs-inference difference because it operates per-token without ever looking at other vectors in the batch.
CV intuition: BatchNorm normalizes “one component across many vectors”; LayerNorm normalizes “one vector across many components.” Same operation, different axis.
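The axis difference, and why it matters for batch dependence, can be shown with a small sketch. Here a batch is a list of token vectors. `layer_norm_batch` normalizes each row, and `batch_norm_train` normalizes each column the way BatchNorm does in training mode. Both helper names are hypothetical, gamma and beta are omitted, and inference-time running averages are not modelled.

```python
import math

def normalize(values, eps=1e-5):
    """Subtract the mean and divide by the population std (no gamma/beta)."""
    mean = sum(values) / len(values)
    var = sum((v - mean) ** 2 for v in values) / len(values)
    return [(v - mean) / math.sqrt(var + eps) for v in values]

def layer_norm_batch(batch):
    # One vector across many components: normalize each row on its own.
    return [normalize(row) for row in batch]

def batch_norm_train(batch):
    # One component across many vectors: normalize each column over the batch.
    columns = [normalize(list(col)) for col in zip(*batch)]
    return [list(row) for row in zip(*columns)]

token = [8.0, -2.0, 4.0, 6.0]
batch_a = [token, [1.0, 2.0, 3.0, 4.0], [0.0, 0.0, 0.0, 0.0]]
batch_b = [token, [-5.0, 0.0, 10.0, 20.0], [30.0, 1.0, 1.0, 1.0]]

# LayerNorm: the first token's output is identical whichever batch it sits in.
print(layer_norm_batch(batch_a)[0] == layer_norm_batch(batch_b)[0])
# BatchNorm (training mode): the same token's output changes with its batch-mates.
print(batch_norm_train(batch_a)[0] == batch_norm_train(batch_b)[0])
```

The first comparison is true and the second is false: the same token gets a different BatchNorm output depending on which other sequences share its batch.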
3. What is the structural difference between post-norm and pre-norm?
Post-norm (the original 2017 transformer): output = LayerNorm(x + SubLayer(x)). Run the sub-layer, add the residual, then normalize the sum.
Pre-norm (modern transformers): output = x + SubLayer(LayerNorm(x)). Normalize the input first, then run the sub-layer on the normalized version, then add the residual.
Same components, different placement. Sub-layer = either attention or feed-forward network.
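The two placements differ only in where the norm call sits. This sketch uses a hypothetical `toy_sublayer` as a stand-in for attention or the FFN, and a LayerNorm with gamma = 1 and beta = 0 to keep it short.

```python
import math

def layer_norm(x, eps=1e-5):
    # Plain LayerNorm with gamma = 1 and beta = 0, to keep the sketch short.
    mean = sum(x) / len(x)
    var = sum((v - mean) ** 2 for v in x) / len(x)
    return [(v - mean) / math.sqrt(var + eps) for v in x]

def add(a, b):
    return [u + v for u, v in zip(a, b)]

def toy_sublayer(x):
    # Hypothetical stand-in for attention or the FFN: any vector-to-vector map.
    return [0.5 * v + 1.0 for v in x]

def post_norm_block(x, sublayer):
    # Original 2017 transformer: LayerNorm(x + SubLayer(x)).
    return layer_norm(add(x, sublayer(x)))

def pre_norm_block(x, sublayer):
    # Modern transformers: x + SubLayer(LayerNorm(x)).
    return add(x, sublayer(layer_norm(x)))

x = [8.0, -2.0, 4.0, 6.0]
print("post-norm:", [round(v, 3) for v in post_norm_block(x, toy_sublayer)])
print("pre-norm: ", [round(v, 3) for v in pre_norm_block(x, toy_sublayer)])
```

The post-norm output is always a normalized vector, while the pre-norm output keeps x itself on the residual path and only adds the sub-layer's contribution to it.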
4. How does RMSNorm differ from LayerNorm, and what is the practical reason most modern LLMs use it?
RMSNorm skips the mean subtraction and the learnable shift. Just divide the vector by the root mean square of its components, then apply learnable rescale (gamma) only. No beta.
The lecturer’s framing: “the convergence properties are basically comparable, but here you have fewer parameters to learn, so it’s basically quicker.”
Same per-token operation as LayerNorm, fewer arithmetic steps, fewer learnable parameters.
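A minimal RMSNorm sketch, for comparison with LayerNorm. The `eps` value is an assumption (implementations differ), and `d_model = 4096` is a hypothetical hidden size used only to make the per-layer parameter count concrete.

```python
import math

def rms_norm(x, gamma, eps=1e-6):
    """RMSNorm for one token's activation vector x (a list of floats).

    eps is an assumed small constant for numerical safety; the value differs
    between implementations.
    """
    rms = math.sqrt(sum(v * v for v in x) / len(x) + eps)
    # No mean subtraction and no beta: divide by the RMS, then rescale by gamma.
    return [g * v / rms for v, g in zip(x, gamma)]

d_model = 4096  # hypothetical hidden size, for the parameter count only
norm_params = {
    "layernorm": 2 * d_model,  # gamma and beta, one of each per component
    "rmsnorm": d_model,        # gamma only
}
print(norm_params)
print([round(v, 3) for v in rms_norm([8.0, -2.0, 4.0, 6.0], [1.0] * 4)])
```

With gamma = 1 the output has root mean square 1, but unlike LayerNorm its mean is not forced to zero.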
5. What does “Pre-RMSNorm” in a model card mean?
Two architectural choices combined. Pre-norm placement: the normalization sits before the sub-layer (x + SubLayer(Norm(x))). RMSNorm computation: the normalization is RMSNorm rather than LayerNorm (no mean subtraction, no learnable shift). Most modern open-weight LLMs use this combination as their default.
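Putting the two choices together, a pre-RMSNorm block is x + SubLayer(RMSNorm(x)). In this sketch `toy_attention` and `toy_ffn` are hypothetical stand-ins for the real sub-layers, and each sub-layer gets its own norm with its own gamma. This is an assumption about the typical layout, not any specific model's code.

```python
import math

def rms_norm(x, gamma, eps=1e-6):
    rms = math.sqrt(sum(v * v for v in x) / len(x) + eps)
    return [g * v / rms for v, g in zip(x, gamma)]

def pre_rmsnorm_block(x, sublayer, gamma):
    # Pre-norm placement + RMSNorm computation: x + SubLayer(RMSNorm(x)).
    normed = rms_norm(x, gamma)
    return [u + v for u, v in zip(x, sublayer(normed))]

def toy_attention(h):  # hypothetical stand-in for self-attention
    return [0.1 * v for v in h]

def toy_ffn(h):  # hypothetical stand-in for the feed-forward network
    return [max(0.0, v) for v in h]

def decoder_layer(x, gamma_attn, gamma_ffn):
    # Attention sub-layer, then FFN sub-layer, each wrapped the same way.
    x = pre_rmsnorm_block(x, toy_attention, gamma_attn)
    x = pre_rmsnorm_block(x, toy_ffn, gamma_ffn)
    return x

ones = [1.0] * 4
print([round(v, 3) for v in decoder_layer([8.0, -2.0, 4.0, 6.0], ones, ones)])
```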
6. Why are the “Add & Norm” boxes in the original transformer diagram more important than they look?
They improve convergence and shorten training time at every model size. The lecturer calls the residual-plus-normalization combination “a little trick” the original authors used; in practice the trick is what makes the deep stack of sub-layers trainable in the first place. The mechanism is small. Removing it makes training noticeably worse.
Try it yourself: LayerNorm by hand
This exercise puts the mechanism into practice. About 12 minutes.
Side effects: none. Pen and paper, or a text editor.
Setup: you have a 4-component activation vector x = (8, -2, 4, 6). Apply LayerNorm step by step.
Step 1: Compute the mean.
mean(x) = (8 + (-2) + 4 + 6) / 4 = 16 / 4 = 4
Step 2: Compute the standard deviation. Use the population formula: std(x) = sqrt( sum((x_i - mean)^2) / d ).
Squared deviations: (8-4)^2 = 16, (-2-4)^2 = 36, (4-4)^2 = 0, (6-4)^2 = 4.
Sum: 16 + 36 + 0 + 4 = 56.
Variance: 56 / 4 = 14.
Standard deviation: sqrt(14) ≈ 3.742.
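If you want to check Steps 1 and 2 by machine, Python's standard `statistics` module has the population versions built in. Note that `statistics.stdev` divides by d - 1 (the sample formula), which is not what Step 2 asks for.

```python
import statistics

x = [8, -2, 4, 6]
mean = statistics.fmean(x)          # Step 1
var = statistics.pvariance(x)       # Step 2: population variance, divides by d
std = statistics.pstdev(x)          # Step 2: population standard deviation
sample_std = statistics.stdev(x)    # divides by d - 1: the wrong formula here
print(mean, var, round(std, 3), round(sample_std, 3))
```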
Step 3: Compute the normalized vector (x - mean(x)) / std(x).
(8 - 4) / 3.742 ≈ 1.069
(-2 - 4) / 3.742 ≈ -1.604
(4 - 4) / 3.742 = 0
(6 - 4) / 3.742 ≈ 0.535
So normalized ≈ (1.069, -1.604, 0, 0.535). The components now have mean 0 and standard deviation 1, which puts them roughly in the -2 to 2 range regardless of how big or small the original components were.
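A quick check of Step 3 in plain Python, reproducing the hand calculation without any epsilon so the numbers match exactly:

```python
import math

x = [8, -2, 4, 6]
d = len(x)
mean = sum(x) / d
std = math.sqrt(sum((v - mean) ** 2 for v in x) / d)
normalized = [(v - mean) / std for v in x]
print([round(v, 3) for v in normalized])  # -> [1.069, -1.604, 0.0, 0.535]
```

Summing `normalized` gives 0 up to floating-point rounding, which is the centering that RMSNorm will skip.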
Step 4: Apply learnable rescale and shift. Suppose the model has learned gamma = (1, 1, 1, 1) and beta = (0, 0, 0, 0) (the identity defaults). What is the output?
With gamma = (1, 1, 1, 1) and beta = (0, 0, 0, 0), the rescale and shift are identity operations. The output is just the normalized vector: (1.069, -1.604, 0, 0.535).
In a trained model, gamma and beta would have non-trivial learned values that allow the model to recover any range it needs while keeping the normalization step in place.
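To see what non-trivial gamma and beta do, this sketch applies hypothetical values (chosen only for illustration, not learned by any model) to the same x. It also shows that setting every gamma to std(x) and every beta to mean(x) undoes the normalization for this particular input. Because gamma and beta are fixed per component and shared by all tokens, that exact undo works only for one input, not for all.

```python
import math

def layer_norm(x, gamma, beta):
    # No eps here, to keep the arithmetic identical to the hand calculation.
    d = len(x)
    mean = sum(x) / d
    std = math.sqrt(sum((v - mean) ** 2 for v in x) / d)
    return [g * (v - mean) / std + b for v, g, b in zip(x, gamma, beta)]

x = [8.0, -2.0, 4.0, 6.0]

# Hypothetical learned values, chosen only for illustration.
gamma = [2.0, 0.5, 1.0, 3.0]
beta = [0.0, 1.0, -1.0, 0.5]
print([round(v, 3) for v in layer_norm(x, gamma, beta)])

# gamma = std(x) = sqrt(14) and beta = mean(x) = 4 recover this x exactly.
std_x = math.sqrt(14)
print([round(v, 6) for v in layer_norm(x, [std_x] * 4, [4.0] * 4)])
```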
Step 5: RMSNorm comparison. Compute RMSNorm on the same x = (8, -2, 4, 6). Show how the steps differ from LayerNorm, and what the output is with gamma = (1, 1, 1, 1).
RMSNorm skips the mean subtraction. Compute RMS directly:
rms(x) = sqrt( (8^2 + (-2)^2 + 4^2 + 6^2) / 4 ) = sqrt( (64 + 4 + 16 + 36) / 4 ) = sqrt(120 / 4) = sqrt(30) ≈ 5.477
Normalized: x / rms(x) = (8/5.477, -2/5.477, 4/5.477, 6/5.477) ≈ (1.461, -0.365, 0.730, 1.095).
With gamma = (1, 1, 1, 1), the output is just the normalized vector.
Notice two things: (a) RMSNorm did not subtract the mean, so the output is not centered around zero; (b) there was no beta to add at the end. Fewer arithmetic steps, fewer learned parameters.
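The Step 5 numbers can be checked the same way. The sketch omits epsilon so the result matches the hand calculation, and it also prints the output's mean to make point (a) concrete.

```python
import math

x = [8.0, -2.0, 4.0, 6.0]
rms = math.sqrt(sum(v * v for v in x) / len(x))  # sqrt(30)
rms_out = [v / rms for v in x]                   # gamma = (1, 1, 1, 1)
print(round(rms, 3), [round(v, 3) for v in rms_out])
# -> 5.477 [1.461, -0.365, 0.73, 1.095]

# Not centered: the mean of the output is mean(x) / rms(x), not zero.
print(round(sum(rms_out) / len(rms_out), 3))
```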
Sanity check: the goal of this exercise is to feel the operation in your hands. Once you have computed both LayerNorm and RMSNorm on the same vector, the difference is concrete: LayerNorm centers and rescales; RMSNorm only rescales. Both bring extreme values into a controlled range, and RMSNorm does it with less arithmetic.
Flashcards
Twelve cards, each a question followed by its answer.
Q. What problem does LayerNorm solve?
Activations inside a deep network drift (one component at 50, another at 0.001, etc.); the next layer struggles to learn from inputs that vary wildly. The keyword in the literature is internal covariate shift. LayerNorm rescales each token’s activation vector into a controlled range so the next layer sees something usable.
Q. What does LayerNorm do, in five steps?
(1) Take a single token’s activation vector. (2) Compute the mean of its components. (3) Subtract the mean from each component. (4) Divide by the standard deviation. (5) Apply learnable rescale (gamma) and shift (beta), per component. Output: an activation vector with components in a controlled range, regardless of input magnitude.
Q. Why use LayerNorm instead of BatchNorm in transformers?
Empirically better (the lecturer’s framing: “probably because empirically it works better”), and structurally cleaner: BatchNorm depends on batch composition, so train-time statistics differ from inference-time statistics. LayerNorm avoids that gap by operating per-token without looking at other vectors in the batch.
Q. CV intuition: what is the difference between LayerNorm and BatchNorm?
BatchNorm normalizes “one component across many vectors” (across the batch dimension). LayerNorm normalizes “one vector across many components” (across the feature dimension). Same operation, different axis.
Q. What is post-norm in the original 2017 transformer?
output = LayerNorm(x + SubLayer(x)). Run the sub-layer (attention or FFN) on the input, add the residual, then normalize the sum. The LayerNorm sits AFTER the addition.
Q. What is pre-norm in modern transformers?
output = x + SubLayer(LayerNorm(x)). Normalize the input first, then run the sub-layer on the normalized version, then add the residual. The LayerNorm sits BEFORE the sub-layer.
Q. Why did the field move from post-norm to pre-norm?
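The lecture is brief on this and does not unpack it. The widely cited explanation in the literature is about gradient flow: in pre-norm, the residual path from input to output is a clean identity with no LayerNorm sitting on it, so gradients pass through deep stacks more easily and training is more stable, with less need for careful learning-rate warmup. What the lecture does say clearly: most modern transformer architectures use pre-norm.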
Q. What is RMSNorm?
A simplification of LayerNorm. Skip the mean subtraction; just divide the vector by the root mean square of its components. Skip the learnable shift; learn rescale (gamma) only. Same per-token operation, fewer arithmetic steps, fewer learnable parameters.
Q. Why use RMSNorm instead of LayerNorm?
In the lecturer’s words: “the convergence properties are basically comparable, but here you have fewer parameters to learn, so it’s basically quicker.” Empirically comparable quality, faster computation, smaller parameter count.
Q. What does 'Pre-RMSNorm' in a model card mean?
Pre-norm placement (normalize before the sub-layer) plus RMSNorm computation (no mean subtraction, no learnable shift). The combination most modern open-weight LLMs use as their default.
Q. What is the most common mistake in distinguishing LayerNorm from BatchNorm?
Conflating them because both are normalizations. Same general idea (rescale activations into a usable range), different axis. LayerNorm operates per-token across the feature dimension; BatchNorm operates per-feature across the batch dimension. Transformers use LayerNorm because it doesn’t depend on batch composition.
Q. What is the one-sentence takeaway?
LayerNorm rescales the activations. Pre-norm moves where it sits. RMSNorm changes what it computes.