Lesson: Why transformers need stability to learn: LayerNorm, pre-norm, and RMSNorm
The “Add & Norm” boxes in the original transformer diagram are easy to skim past. They are doing real work.
Each one does two specific things between every sub-layer (attention, feed-forward) and the next: it adds the residual back in, and it normalizes the result so the next sub-layer sees activations in a useful range. The Stanford lecturer calls it “a little trick” the original authors used; in practice it improves convergence and shortens training time enough to matter at every model size.
The mechanism the original paper used is called LayerNorm. The 2017 transformer wrapped it around each sub-layer as LayerNorm(x + SubLayer(x)). Modern transformers do two things differently. First, they put the LayerNorm in a different place: x + SubLayer(LayerNorm(x)) instead of around the residual. Second, they often use a simplification of LayerNorm called RMSNorm. Both changes show up explicitly in modern model cards (you’ll see “Pre-LayerNorm” or “Pre-RMSNorm” listed as architecture features), and both exist for a specific structural payoff.
This is the second of three lessons in this lecture on the architectural changes that stuck after 2017. The position embeddings lesson covered the first one; the next lesson covers attention efficiency. Most of the original transformer is intact in modern LLMs; normalization is one of the small handful of pieces that genuinely moved.
What normalization actually does
Activations inside a deep network drift. The vector flowing between two sub-layers might have one component at 50, another at 0.001, and a third near -200. The next forward pass through the same layer might push everything small, or everything large. The model is trying to learn weights for the next layer, but those weights have to handle inputs that vary wildly across batches and across layers. That instability slows training down, and in deep enough networks it can stop training entirely.
The fix is to bring the values of the components into a controlled range before they reach the next sub-layer. Why exactly normalization helps is a question worth slowing down on, because the framing has shifted in the literature.
The original BatchNorm paper (Ioffe and Szegedy, 2015) framed the benefit in terms of internal covariate shift: the distribution of activations shifts as the network trains, and that shift was supposed to be what makes the next layer’s job harder. That framing has not held up empirically. Santurkar, Tsipras, Ilyas and Madry (2018, “How Does Batch Normalization Help Optimization?”) showed that BatchNorm does not actually reduce internal covariate shift, and that you can artificially inject covariate shift after the normalization without hurting optimization. The currently accepted explanation is that normalization smooths the loss landscape (reduces the Lipschitz constants of the loss and its gradient), which lets gradient descent take larger stable steps. The empirical observation (deep networks train faster and more reliably with normalization) is robust; the mechanism is more subtle than the original paper claimed.
LayerNorm itself, introduced by Ba, Kiros and Hinton (2016) for recurrent networks, was motivated differently from the start. The LayerNorm paper’s framing is per-example hidden-state stabilization: normalize each example’s activations using statistics computed only from that example, which avoids the train-test discrepancies and small-batch failures that BatchNorm runs into. That motivation has translated cleanly to transformers, where each token’s activation vector gets normalized independently of the rest of the batch.
LayerNorm does this per-token. For a single token’s activation vector x with d components:
mean(x) = (x_1 + x_2 + ... + x_d) / d
std(x) = sqrt( ((x_1 - mean)^2 + ... + (x_d - mean)^2) / d )
normalized = (x - mean(x)) / std(x)
output = gamma * normalized + beta
Subtract the mean, divide by the standard deviation, then apply two learnable parameters: gamma (a per-component rescaling factor) and beta (a per-component shift). The model learns gamma and beta during training, so it can recover any range it needs while keeping the normalization step in place.
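The four formulas above can be sketched in plain Python for a single token. This is a minimal illustration, not any library's implementation: the function name `layer_norm`, the list-of-floats representation, and the small `eps` term (added here only to avoid dividing by zero) are assumptions that do not appear in the lesson.

```python
import math

def layer_norm(x, gamma, beta, eps=1e-5):
    """Normalize one token's activation vector x (a list of floats)."""
    d = len(x)
    mean = sum(x) / d
    # Variance divides by d, matching the std(x) formula in the lesson.
    var = sum((xi - mean) ** 2 for xi in x) / d
    # eps is an assumption for numerical safety; the lesson's formula omits it.
    std = math.sqrt(var + eps)
    # gamma rescales and beta shifts each component independently.
    return [g * (xi - mean) / std + b for xi, g, b in zip(x, gamma, beta)]

# The wildly scaled example vector from the text, plus one extra component.
x = [50.0, 0.001, -200.0, 3.0]
d = len(x)

# With gamma = 1 and beta = 0 the output has mean ~0 and variance ~1.
y = layer_norm(x, gamma=[1.0] * d, beta=[0.0] * d)

# With other gamma/beta values the model can recover a different range.
y_scaled = layer_norm(x, gamma=[2.0] * d, beta=[1.0] * d)
```

With `gamma` fixed at 1 and `beta` at 0, the output is the bare normalized vector. In a real model both would be trained parameters of length d.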
The result: every token’s activation vector is scaled into a usable range before the next sub-layer sees it. Convergence improves, training stabilizes, and deep networks become trainable.
LayerNorm versus BatchNorm
A reader coming from computer vision will recognize the pattern but expect a different name. In CV, BatchNorm is the dominant normalization scheme: it normalizes each component across the batch dimension, so each component of each vector is normalized against the same component of the other vectors in the batch.
Transformers use LayerNorm, not BatchNorm. Two reasons.
First, the lecturer’s framing: “probably because empirically it works better.” The empirical preference is consistent enough that it is the default for transformer-style models, even if the field has not converged on one clean theoretical reason.
Second, BatchNorm has a structural disadvantage for sequence models: it depends on the batch composition. The statistics you compute at training time (mean and standard deviation of each component across the batch) are different from the statistics you compute at inference time, when the batch may be a single sequence. LayerNorm sidesteps that train-vs-inference difference entirely, because it operates per-token without ever looking at other vectors in the batch.
If you have CV intuition, the easiest substitution is: BatchNorm normalizes “one component across many vectors”; LayerNorm normalizes “one vector across many components.” Same operation, different axis.
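The "different axis" point, and the batch-dependence argument before it, can be made concrete with a toy batch. The sketch below is simplified on purpose: it leaves out gamma, beta, and the running statistics that real BatchNorm layers keep for inference. The helper names and the example numbers are hypothetical.

```python
import math

def normalize(values, eps=1e-5):
    """Subtract the mean and divide by the standard deviation of a list."""
    mean = sum(values) / len(values)
    var = sum((v - mean) ** 2 for v in values) / len(values)
    return [(v - mean) / math.sqrt(var + eps) for v in values]

def layer_norm_batch(batch):
    # One vector across many components: each row is normalized on its own.
    return [normalize(row) for row in batch]

def batch_norm_batch(batch):
    # One component across many vectors: each column is normalized across rows.
    columns = [normalize(list(col)) for col in zip(*batch)]
    return [list(row) for row in zip(*columns)]

# The same token vector placed in two different batches.
token = [1.0, 2.0, 6.0]
batch_a = [token, [0.0, 0.0, 3.0]]
batch_b = [token, [100.0, -50.0, 7.0]]

# LayerNorm gives the token the same output regardless of its batch mates.
ln_same = layer_norm_batch(batch_a)[0] == layer_norm_batch(batch_b)[0]

# The batch-wise version changes the token's output when the batch changes.
bn_same = batch_norm_batch(batch_a)[0] == batch_norm_batch(batch_b)[0]

print(ln_same, bn_same)
```

Here the token's LayerNorm output is identical in both batches, while its batch-wise output differs. That difference is the batch-composition dependence described above.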
The first shift: post-norm to pre-norm
The original 2017 transformer placed LayerNorm in a specific spot. For each sub-layer (attention or feed-forward), the formula was:
output = LayerNorm(x + SubLayer(x))
Take the input x, run it through the sub-layer, add the original x back as a residual, then normalize the sum. This is post-norm: the LayerNorm sits AFTER the addition.
Modern transformers put the LayerNorm somewhere else:
output = x + SubLayer(LayerNorm(x))
Normalize the input first, then run the sub-layer on the normalized version, then add the original x back as a residual. This is pre-norm: the LayerNorm sits BEFORE the sub-layer.
Same components, different placement. The lecture is brief on why the field moved: it notes only that “nowadays we use a prenorm version” and does not unpack the mechanics. The widely cited explanation in the literature is that pre-norm leaves the residual path as a plain identity connection, so gradients reach early layers without passing through a normalization at every block, which makes training more stable as networks get deeper.
What the lecture does say clearly: most modern transformer architectures use pre-norm.
Pre-norm moves the LayerNorm before the sub-layer. Post-norm wrapped it around the residual addition.
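The two placements can be compared side by side in a small sketch. Everything here is illustrative: `toy_sublayer` is a hypothetical stand-in for attention or feed-forward, and `layer_norm` is a bare version with gamma fixed at 1, beta at 0, and an assumed `eps` for numerical safety.

```python
import math

def layer_norm(x, eps=1e-5):
    """Bare LayerNorm for one token (gamma = 1, beta = 0 for simplicity)."""
    mean = sum(x) / len(x)
    var = sum((xi - mean) ** 2 for xi in x) / len(x)
    return [(xi - mean) / math.sqrt(var + eps) for xi in x]

def toy_sublayer(x):
    # Hypothetical stand-in for attention or feed-forward.
    return [0.5 * xi for xi in x]

def post_norm_block(x, sublayer):
    # output = LayerNorm(x + SubLayer(x)): normalize AFTER the residual add.
    summed = [xi + si for xi, si in zip(x, sublayer(x))]
    return layer_norm(summed)

def pre_norm_block(x, sublayer):
    # output = x + SubLayer(LayerNorm(x)): normalize BEFORE the sub-layer.
    update = sublayer(layer_norm(x))
    return [xi + ui for xi, ui in zip(x, update)]

x = [4.0, -2.0, 1.0, 5.0]
post = post_norm_block(x, toy_sublayer)
pre = pre_norm_block(x, toy_sublayer)
```

In the post-norm block the whole sum is renormalized, so the original `x` never leaves the block untouched. In the pre-norm block `x` passes through the residual path unchanged and only the sub-layer's input is normalized.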
The second shift: LayerNorm to RMSNorm
The other thing modern transformers do differently is use a simpler normalization called RMSNorm (root mean square normalization).
LayerNorm subtracts the mean and divides by the standard deviation, then applies a learnable rescale and shift. RMSNorm skips the mean subtraction. It just divides by the root mean square of the components, and learns a rescale (gamma) but no shift (beta):
rms(x) = sqrt( (x_1^2 + x_2^2 + ... + x_d^2) / d )
normalized = x / rms(x)
output = gamma * normalized
That is the entire mechanism. Same per-token operation, fewer arithmetic steps, fewer learnable parameters.
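The RMSNorm formulas translate almost line for line into code. The sketch below uses plain Python lists. The function name `rms_norm` and the small `eps` inside the square root are assumptions added for illustration, not part of the lesson's formula.

```python
import math

def rms_norm(x, gamma, eps=1e-6):
    """RMSNorm for one token: no mean subtraction, no beta shift."""
    d = len(x)
    # Root mean square of the components; eps is an assumed safety term.
    rms = math.sqrt(sum(xi * xi for xi in x) / d + eps)
    return [g * xi / rms for xi, g in zip(x, gamma)]

# A vector whose components already sum to zero.
x = [3.0, -1.0, 2.0, -4.0]
y = rms_norm(x, gamma=[1.0] * len(x))

# After normalization the root mean square of the output is ~1.
rms_of_y = math.sqrt(sum(v * v for v in y) / len(y))

# A vector with a nonzero mean keeps a nonzero mean: nothing re-centers it.
y_offset = rms_norm([1.0, 2.0, 3.0], gamma=[1.0] * 3)
mean_of_y_offset = sum(y_offset) / len(y_offset)
```

When the input already has zero mean, as `x` does here, the root mean square equals the standard deviation, so RMSNorm and LayerNorm (with beta at zero) give the same result. The two differ only when the mean is nonzero, as in `y_offset`.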
Why use it? In the lecturer’s words: “the convergence properties are basically comparable, but here you have fewer parameters to learn, so it’s basically quicker.”
Most modern open-weight LLMs you read about use RMSNorm, often paired with the pre-norm placement above. When a model card lists “Pre-RMSNorm” as an architectural choice, that is what it means: normalize the input first via RMSNorm, then run the sub-layer, then add the residual.
RMSNorm drops the mean subtraction and the learned shift. Same idea, fewer steps.
Why this matters when you use AI
Two consequences worth holding onto when you read AI tooling docs or model cards.
- “Pre-LayerNorm” and “RMSNorm” in model cards are real structural choices, not vendor jargon. When you see them, you now know what they mean. Pre-norm is about where the normalization sits (before each sub-layer, not around the residual). RMSNorm is about what the normalization computes (no mean subtraction, no learnable shift). Both are now common replacements for the 2017 default in LLM-scale models.
- Normalization is one of three places the field genuinely moved on from the original transformer. Position embeddings (previous lesson), normalization (this lesson), and attention efficiency (next lesson). Most of the rest of the architecture is intact in modern LLMs; these are the substitutions that actually changed.
Common pitfalls
A few mistakes are common enough to be worth naming.
Conflating LayerNorm with BatchNorm. Same general idea (rescale activations to a usable range), different axis. LayerNorm operates per-token across the feature dimension; BatchNorm operates per-feature across the batch dimension. Transformers use LayerNorm because it does not depend on batch composition and avoids train-vs-inference mismatch.
Thinking RMSNorm is fundamentally different from LayerNorm. It is a simplification, not a different idea. Same per-token normalization, fewer steps (skip the mean subtraction), fewer learnable parameters (skip the shift). Convergence is comparable.
Assuming pre-norm is universally better. For modern LLM-scale networks, yes. For shallow networks or specific architectures that were tuned for post-norm, post-norm can still work. Saying “always pre-norm” overgeneralizes from the modern-LLM case.
Assuming the “Add & Norm” boxes are decoration. They are doing real work. Without the residual connection and the normalization, a deep transformer is prone to exploding or vanishing gradients, and training often fails to converge. The mechanism is small; the consequence of removing it is severe.
What you should remember
- LayerNorm rescales each token’s activation vector into a controlled range. Subtract the mean, divide by the standard deviation, apply learnable rescale (gamma) and shift (beta). Per-token, not per-batch.
- LayerNorm is preferred over BatchNorm in transformers. Empirically better, and it does not depend on batch composition (so train-vs-inference statistics match).
- The post-norm to pre-norm shift moved the LayerNorm from after the residual addition to before the sub-layer. Pre-norm trains more stably at the depths of modern LLMs; post-norm worked for the original, shallower transformers but becomes harder to train as depth grows.
- RMSNorm is the modern simplification of LayerNorm. Skip the mean subtraction; learn gamma only, no beta. Comparable convergence, fewer parameters, faster.
- Most modern open-weight LLMs use Pre-RMSNorm. When you see “Pre-LayerNorm” or “RMSNorm” in a model card, that is what is happening.
If you remember one thing
LayerNorm rescales the activations.
Pre-norm moves where it sits.
RMSNorm changes what it computes.