Why deep nets won't train (and how to fix it): activations, gradients, and BatchNorm

The MLP from last lesson works and generates better names, but if you train it naively you hit a frustrating wall: the loss starts absurdly high, then barely moves, and the network learns slowly or not at all. The architecture is fine. The problem is everything around the math: how the weights start and how the signals flow. Getting that wrong quietly sabotages an otherwise correct network. This lesson opens the network up while it trains, finds the two things going wrong, and fixes them.

The contract holds: nothing inside is a mystery, including the parts that usually get hand-waved as “training tricks.” They are not magic; they are answers to specific, diagnosable problems.

Watch the loss on the very first step of naive training and it is enormous, far higher than it has any right to be. Here is why: at initialization the weights are random, and random weights in the output layer produce logits (the 27 raw output numbers) that are large and spread out. Softmax turns large, spread-out logits into a confident probability distribution, so the untrained network starts out sure of nonsense. Being confidently wrong is the most expensive thing a model can do under negative log likelihood, so the loss spikes.

What should the starting loss be? A model that knows nothing should hedge: assign every one of the 27 characters an equal 1/27 probability. The loss for that uniform guess is -log(1/27) = log(27) ≈ 3.30 (using natural log, the makemore convention). A healthy network should begin training right around 3.30, not at 20 or more.
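To check that baseline yourself, here is a minimal sketch in plain Python. The helper name cross_entropy and the specific "confidently wrong" logits (a single logit of 20 on the wrong character) are illustrative assumptions, not part of the makemore code.

```python
import math

def cross_entropy(logits, target):
    """Negative log likelihood of `target` under softmax(logits)."""
    peak = max(logits)  # subtract the max for numerical stability
    log_sum_exp = peak + math.log(sum(math.exp(z - peak) for z in logits))
    return log_sum_exp - logits[target]

vocab_size = 27

# A model that knows nothing: all logits equal, so softmax is uniform.
uniform_loss = cross_entropy([0.0] * vocab_size, target=0)

# A confidently wrong model (assumed example): one large logit on the wrong character.
wrong_logits = [0.0] * vocab_size
wrong_logits[5] = 20.0
confident_wrong_loss = cross_entropy(wrong_logits, target=0)

print(f"uniform baseline:  {uniform_loss:.4f}")   # log(27), about 3.2958
print(f"confidently wrong: {confident_wrong_loss:.4f}")
```

The second number lands near 20, which is the kind of starting loss the naive network shows.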

The fix is to start the output weights small, so the initial logits are near zero. Near-equal logits give a near-uniform softmax, the loss starts at its natural 3.30 baseline, and every training step from there does real work instead of spending its first hundreds of steps just squashing the over-confident logits back down. That wasted early phase is the telltale “hockey stick” in the loss curve, and a good initialization removes it.
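You can see the effect of logit scale without training anything. The sketch below stands in for an untrained output layer by drawing random Gaussian logits at two scales and averaging the loss against random targets. The function name mean_initial_loss, the scales 10.0 and 0.01, and the trial count are all assumptions chosen for illustration.

```python
import math
import random

def cross_entropy(logits, target):
    """Negative log likelihood of `target` under softmax(logits)."""
    peak = max(logits)  # subtract the max for numerical stability
    log_sum_exp = peak + math.log(sum(math.exp(z - peak) for z in logits))
    return log_sum_exp - logits[target]

def mean_initial_loss(logit_scale, vocab_size=27, trials=2000, seed=0):
    """Average loss when logits are random noise with the given spread."""
    rng = random.Random(seed)
    total = 0.0
    for _ in range(trials):
        logits = [rng.gauss(0.0, logit_scale) for _ in range(vocab_size)]
        target = rng.randrange(vocab_size)
        total += cross_entropy(logits, target)
    return total / trials

large = mean_initial_loss(logit_scale=10.0)   # spread-out logits: confidently wrong
small = mean_initial_loss(logit_scale=0.01)   # near-zero logits: close to uniform

print(f"large logits: {large:.2f}")
print(f"small logits: {small:.2f}  (baseline log(27) = {math.log(27):.2f})")
```

With small logits the average loss sits on the log(27) baseline; with large ones it starts many times higher, and the first stretch of training would be spent undoing that.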

The deeper and more dangerous problem is in the hidden layer’s tanh. Recall from the autograd lesson that tanh’s local derivative is 1 - tanh(x)^2. When a neuron’s input is large (strongly positive or negative), tanh output sits out near +1 or -1, in its flat tails, and there 1 - tanh^2 is nearly zero.

Put numbers on it. The local derivative 1 - tanh(x)^2 collapses fast as the input grows:

input x = 0: tanh = 0.00, local derivative = 1.00 (fully responsive)
input x = 1: tanh = 0.76, local derivative = 0.42
input x = 2: tanh = 0.96, local derivative = 0.07
input x = 3: tanh = 0.995, local derivative = 0.01 (effectively numb)

By a pre-activation of 3, the neuron passes back only about one percent of the gradient it would at zero. That near-zero local derivative is fatal, because backpropagation multiplies the incoming gradient by it. A neuron pinned in the tails passes almost no gradient to the weights feeding it, so those weights barely update: the neuron has gone numb, or “saturated.” If a neuron is saturated for every training example, it is effectively dead: it learns nothing and contributes nothing. A whole layer can quietly die this way, and the network limps along on whatever neurons survived.
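The numbers in that table take only a few lines to reproduce. This is a small sketch using the standard library; tanh_local_grad is a name made up here for the expression 1 - tanh(x)^2.

```python
import math

def tanh_local_grad(x):
    """Local derivative of tanh at x: 1 - tanh(x)^2."""
    t = math.tanh(x)
    return 1.0 - t * t

for x in [0.0, 1.0, 2.0, 3.0]:
    # During backprop, the gradient arriving at the tanh output is multiplied by this factor.
    print(f"x = {x:.0f}: tanh = {math.tanh(x):.3f}, local derivative = {tanh_local_grad(x):.3f}")
```

Try larger inputs such as 5 or 10 and the factor becomes indistinguishable from zero, which is what "saturated" means in practice.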

You can see this. Plot a histogram of a layer’s tanh outputs during training: if they pile up against -1 and +1, the layer is saturated and in trouble; if they spread across the responsive middle range, it is healthy. That histogram is one of the most useful diagnostics in all of neural network training, because it turns an invisible failure into a picture.
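Here is one way that diagnostic might look, sketched with NumPy. The pre-activations are random stand-ins rather than values from a real layer, and both the function name saturation_fraction and the 0.97 cutoff for "pinned in the tails" are assumptions; pick whatever threshold suits your plots.

```python
import numpy as np

def saturation_fraction(tanh_outputs, threshold=0.97):
    """Fraction of tanh outputs pinned in the tails (|output| > threshold)."""
    return float(np.mean(np.abs(tanh_outputs) > threshold))

rng = np.random.default_rng(0)
fractions = {}
for label, spread in [("large pre-activations", 10.0), ("modest pre-activations", 0.5)]:
    preacts = rng.standard_normal(20_000) * spread  # stand-in for a layer's pre-activations
    outputs = np.tanh(preacts)
    # Ten equal-width bins across tanh's output range, from -1 to +1.
    counts, _ = np.histogram(outputs, bins=10, range=(-1.0, 1.0))
    fractions[label] = saturation_fraction(outputs)
    print(f"{label}: {fractions[label]:.1%} saturated, bin counts = {counts.tolist()}")
```

In the saturated case nearly all the counts land in the first and last bins; in the healthy case they spread across the middle.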

The cause of saturation is the same as the cause of the bad starting loss: values that are too large going into a squashing function (logits into softmax there, pre-activations into tanh here). So the cure is to control their size.

How large the pre-activations get depends on two things: the scale of the weights and how many inputs feed each neuron (the fan-in). Each neuron sums many weighted inputs, and the spread of a sum of independent terms grows with the square root of the number of terms. So for a given weight scale, the more inputs a neuron has, the bigger its pre-activation tends to be, and the more likely it is to saturate tanh.

The fix is to shrink the weights in proportion: scale each layer’s initial weights by roughly 1 / sqrt(number of inputs). That counteracts the growth from summing, so the pre-activations come out with a healthy, roughly unit-sized spread, large enough to be expressive, small enough to stay in tanh’s responsive middle. Done layer by layer, this keeps the signal at a steady scale all the way through the network instead of exploding into saturation or shrinking toward nothing. This recipe is the core of Kaiming (or He) initialization, which also multiplies by a small gain that depends on the activation function (PyTorch uses 5/3 for tanh). Good initialization alone is often enough to make a deep network train well.
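A quick numerical check of the fan-in argument, as a sketch with NumPy. The sizes (1000 examples, 200 inputs, 100 neurons) are arbitrary assumptions, the inputs and weights are unit-variance Gaussians, and no activation-specific gain is applied.

```python
import numpy as np

rng = np.random.default_rng(0)
batch, fan_in, hidden = 1000, 200, 100

x = rng.standard_normal((batch, fan_in))        # unit-variance inputs
W_raw = rng.standard_normal((fan_in, hidden))   # unit-variance weights, no scaling
W_scaled = W_raw / np.sqrt(fan_in)              # shrink by 1 / sqrt(fan_in)

std_raw = float((x @ W_raw).std())
std_scaled = float((x @ W_scaled).std())

print(f"sqrt(fan_in)          = {np.sqrt(fan_in):.2f}")
print(f"unscaled pre-act std  = {std_raw:.2f}")     # grows like sqrt(fan_in)
print(f"scaled pre-act std    = {std_scaled:.2f}")  # stays near 1
```

The unscaled spread tracks sqrt(fan_in), far out in tanh's flat tails, while the scaled version stays close to 1.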

Hand-tuning the initialization of every layer is doable but fragile, especially as networks get deep. Batch normalization sidesteps the whole problem by enforcing well-behaved activations directly, during training, rather than hoping a careful initial scale survives.

The idea is simple. Before a layer’s pre-activations go into tanh, insert a step that, for the current minibatch, normalizes them to have zero mean and unit variance: literally subtract the batch’s mean and divide by its standard deviation. Now, by construction, the pre-activations are centered and modestly sized, exactly the range where tanh is responsive, no matter how the weights happened to be initialized. To make sure the network can still represent whatever distribution it actually needs (normalizing everything to a fixed shape would be too rigid), batch normalization then applies a learned gain and bias, two trainable parameters that let the network scale and shift the normalized values back to whatever spread works best.
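The normalization step itself is short. Below is a minimal sketch of the training-time forward pass in NumPy; the function name batchnorm_forward and the small eps added under the square root (to avoid dividing by zero) are assumptions, and the gain and bias are simply set to their usual starting values of 1 and 0 rather than learned.

```python
import numpy as np

def batchnorm_forward(preacts, gain, bias, eps=1e-5):
    """Normalize each feature over the batch, then apply the gain and bias."""
    mean = preacts.mean(axis=0, keepdims=True)   # one mean per neuron, across the batch
    var = preacts.var(axis=0, keepdims=True)     # one variance per neuron
    normalized = (preacts - mean) / np.sqrt(var + eps)
    return gain * normalized + bias

rng = np.random.default_rng(0)
preacts = rng.standard_normal((32, 100)) * 15.0 + 4.0  # badly scaled and off-center
gain = np.ones((1, 100))    # would be trainable in a real network
bias = np.zeros((1, 100))   # would be trainable in a real network

out = batchnorm_forward(preacts, gain, bias)
print(f"before: mean {preacts.mean():.2f}, std {preacts.std():.2f}")
print(f"after:  mean {out.mean():.2f}, std {out.std():.2f}")
```

Whatever scale the pre-activations arrive with, they leave centered and unit-sized, ready for tanh.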

Two caveats are worth naming, because they trip people up. First, batch normalization couples the examples in a batch: a given example’s output now depends on the other examples it happened to be batched with, since they all share the batch’s mean and variance. That coupling acts as a mild regularizer (a feature) but is also a notorious source of subtle bugs (a cost). Second, at inference time you often have just one example and no batch to compute statistics from, so batch normalization keeps a running average of the mean and variance seen during training and uses those at inference instead. With these in place, the network trains smoothly almost regardless of initialization.
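The second caveat is easier to follow in code. This sketch keeps running averages during training and switches to them at inference. The class name SimpleBatchNorm, the momentum of 0.1, and the eps value are assumptions for illustration, and the gain and bias are left at 1 and 0 rather than trained.

```python
import numpy as np

class SimpleBatchNorm:
    """Hypothetical minimal batch norm layer with running statistics."""

    def __init__(self, dim, eps=1e-5, momentum=0.1):
        self.eps, self.momentum = eps, momentum
        self.training = True
        self.gain = np.ones(dim)
        self.bias = np.zeros(dim)
        self.running_mean = np.zeros(dim)
        self.running_var = np.ones(dim)

    def __call__(self, x):
        if self.training:
            mean, var = x.mean(axis=0), x.var(axis=0)  # statistics of this batch
            m = self.momentum
            self.running_mean = (1 - m) * self.running_mean + m * mean
            self.running_var = (1 - m) * self.running_var + m * var
        else:
            mean, var = self.running_mean, self.running_var  # frozen averages
        return self.gain * (x - mean) / np.sqrt(var + self.eps) + self.bias

rng = np.random.default_rng(0)
bn = SimpleBatchNorm(dim=10)
for _ in range(200):                          # "training": each batch updates the averages
    bn(rng.standard_normal((32, 10)) * 2.0 + 3.0)

bn.training = False                           # "inference": use the running averages
batch = rng.standard_normal((32, 10)) * 2.0 + 3.0
alone = bn(batch[:1])                         # a single example, no batch needed
in_batch = bn(batch)[:1]                      # the same example alongside 31 others
print(np.allclose(alone, in_batch))           # no batch coupling in inference mode
```

Forgetting to flip that training flag is exactly the train-versus-inference bug the caveat warns about.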

None of this is specific to makemore. Deep networks in general, large language models included, have to solve these same two problems to train well. Sane initialization is standard practice in every major framework. And normalization layers are everywhere: transformers rely on them (usually a close cousin called layer normalization), placed precisely to keep activations and gradients well-behaved as signals pass through dozens or hundreds of layers.
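To see how the cousin differs, here is a small sketch: layer normalization computes its mean and variance across the features of each example, rather than across the batch for each feature. Both functions are stripped-down assumptions (no learned gain or bias, an assumed eps), written only to contrast the axis being normalized.

```python
import numpy as np

def batch_norm(x, eps=1e-5):
    """Normalize each feature using statistics taken across the batch (axis 0)."""
    mean = x.mean(axis=0, keepdims=True)
    var = x.var(axis=0, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)

def layer_norm(x, eps=1e-5):
    """Normalize each example using statistics taken across its own features."""
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)

rng = np.random.default_rng(0)
example = rng.standard_normal((1, 8))
# The same example placed in two different batches.
batch_a = np.vstack([example, rng.standard_normal((3, 8))])
batch_b = np.vstack([example, rng.standard_normal((3, 8)) * 5.0])

ln_same = np.allclose(layer_norm(batch_a)[0], layer_norm(batch_b)[0])
bn_same = np.allclose(batch_norm(batch_a)[0], batch_norm(batch_b)[0])
print(ln_same, bn_same)  # expected: True False
```

Layer normalization gives the example the same output whichever batch it is in, so it needs no running statistics and behaves identically at training and inference time.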

This is the real answer to a question you may have wondered about: why is training large models considered hard, when the loop is just forward, loss, backward, update? Because in a deep network those signals can explode or vanish as they propagate, and a network whose gradients have vanished learns nothing no matter how long you train it. Initialization and normalization are what keep the gradients alive across depth. When you hear that a training run “diverged” or “was unstable,” or that a new architecture “trains more stably,” this is the territory being discussed. The difference between a deep network that learns and one that sits dead is often nothing more than the techniques in this lesson.
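The exploding-or-vanishing behaviour shows up even in a toy stack of layers. The sketch below pushes random data through 20 purely linear layers (no tanh, a deliberate simplification to isolate the scaling effect) whose weights are 1 / sqrt(width) times an assumed multiplier. The function name and all sizes are hypothetical.

```python
import numpy as np

def final_activation_std(weight_multiplier, depth=20, width=200, batch=256, seed=0):
    """Spread of the signal after `depth` linear layers with scaled random weights."""
    rng = np.random.default_rng(seed)
    h = rng.standard_normal((batch, width))
    for _ in range(depth):
        W = rng.standard_normal((width, width)) * weight_multiplier / np.sqrt(width)
        h = h @ W  # linear layers only, so each layer rescales the signal by about the multiplier
    return float(h.std())

std_vanish = final_activation_std(0.5)   # weights a bit too small
std_stable = final_activation_std(1.0)   # the 1 / sqrt(fan_in) scale
std_explode = final_activation_std(2.0)  # weights a bit too large

print(f"multiplier 0.5: {std_vanish:.2e}")
print(f"multiplier 1.0: {std_stable:.2e}")
print(f"multiplier 2.0: {std_explode:.2e}")
```

A factor of two per layer in either direction compounds to roughly a millionfold change over 20 layers, which is why small scaling errors that are harmless in a shallow network become fatal in a deep one. Gradients flowing backward are multiplied through the same weights, so they compound in the same way.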

Reading a high starting loss as a model problem. A loss that starts far above the uniform baseline (3.30 for 27 characters) is almost always an initialization problem (over-confident logits), not a sign the architecture is wrong. Check the starting loss first.

Forgetting that saturation kills the gradient, not the output. A saturated tanh still produces an output (near +1 or -1); what it loses is the gradient, because 1 - tanh^2 is near zero there. The neuron looks like it is doing something while it has actually stopped learning.

Thinking batch normalization replaces good initialization. It makes training far more forgiving, but the underlying goal is the same (well-scaled activations), and understanding initialization is what lets you diagnose problems when normalization alone is not enough.

Ignoring the batch-coupling of batch normalization. Because each example’s output depends on its batchmates, batch normalization behaves differently at training time (real batch statistics) and inference time (running averages). Mishandling that difference is one of the most common real-world bugs with it.

  • A naive deep network fails in two diagnosable ways. It starts confidently wrong (random weights make over-confident logits, so the loss spikes far above the -log(1/27) = 3.30 uniform baseline), and its tanh neurons saturate (large pre-activations push outputs to the flat tails where 1 - tanh^2 is near zero, so almost no gradient flows and neurons go numb or die).
  • Good initialization fixes both at the source. Start output weights small so the loss begins at its natural baseline, and scale each layer’s weights by about 1 / sqrt(number of inputs) so pre-activations keep a healthy spread and signals neither explode nor vanish through depth.
  • Batch normalization makes it automatic. Normalize each layer’s pre-activations to zero mean and unit variance across the batch, then rescale with a learned gain and bias. It keeps activations well-behaved during training regardless of initialization, at the cost of coupling examples in a batch and needing running statistics at inference. These exact techniques (sane init plus normalization layers like layer normalization) are what make every deep network, transformers included, trainable at all.

You can now not only build a deep network but make it actually train, by seeing what the activations and gradients are doing and keeping them healthy. There is one piece of the engine you have so far taken on trust: the backward pass itself. The next lesson removes the autograd engine’s safety net and has you backpropagate through the whole network by hand, so that gradient flow stops being something a library does for you and becomes something you could compute, and debug, yourself.