
Why deep nets won't train (and how to fix it): activations, gradients, and BatchNorm

This is lesson 3 of Phase 2 (Building a language model) in the Build Neural Networks from Scratch track, which follows the arc of Andrej Karpathy’s Neural Networks: Zero to Hero series. The previous lesson built a deeper MLP that generates better names. This lesson is about making that network actually train well, which turns out to be surprisingly fragile.

You will open the network up during training and find the two things that go wrong. First, a confidently wrong start: random initial weights produce over-confident logits, so the initial loss sits far above the uniform baseline of -log(1/27) ≈ 3.30, and the early steps of training are wasted just squashing the logits back down. Second, saturated neurons: when pre-activations are large, tanh outputs sit in the flat tails where the local derivative 1 - tanh^2 is near zero (about 0.07 at an input of 2), so almost no gradient flows and the neurons go numb. The lesson diagnoses both with activation histograms, fixes them with scaled initialization (weights scaled by 1/sqrt(fan-in), the idea behind Kaiming initialization), and introduces batch normalization, the automatic fix. These are the techniques that make deep networks, transformers included, trainable at all.
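The first failure mode can be checked with a few lines of arithmetic. The sketch below uses only the Python standard library; the vocabulary size of 27 comes from the baseline quoted above, while the logit value of 10 and the class indices are made-up numbers chosen only to illustrate an over-confident wrong guess.

```python
import math

V = 27  # vocabulary size implied by the -log(1/27) baseline

# Uniform baseline: every character gets probability 1/V.
baseline = -math.log(1 / V)

def nll(logits, target):
    """Negative log-likelihood of `target` under softmax(logits)."""
    m = max(logits)  # subtract the max for numerical stability
    log_sum = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum - logits[target]

# All-equal logits give the uniform distribution, so the loss equals the baseline.
uniform_loss = nll([0.0] * V, target=0)

# Hypothetical bad start: one WRONG class (index 5) gets a large logit.
confident_wrong = [0.0] * V
confident_wrong[5] = 10.0
wrong_loss = nll(confident_wrong, target=0)

print(f"uniform baseline      : {baseline:.4f}")
print(f"loss, equal logits    : {uniform_loss:.4f}")
print(f"loss, confident+wrong : {wrong_loss:.4f}")
```

The point of the comparison is that a network whose initial logits are all close to zero starts at the baseline, while one with large random logits starts well above it and has to spend its first steps undoing that.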

This lesson leans directly on tanh’s local derivative from lesson 1 (saturation is that derivative going to zero), and it applies to any network with hidden layers, not just the language model. The next lesson takes away the autograd engine entirely and has you backpropagate through this network by hand, so the gradient flow you diagnose here becomes something you can compute yourself.

Prerequisite (within this track): lesson 4, Giving the model memory: the MLP language model. This lesson is about training that network well, so you need to know its shape (embeddings, a tanh hidden layer, a softmax output, negative-log-likelihood loss). It also relies on tanh’s local derivative 1 - tanh^2 from lesson 1 (the autograd engine), since saturation is exactly that derivative going to zero. If “backprop multiplies the incoming gradient by each operation’s local derivative” reads as a fact you own, you are ready. No coding is required to follow along, though running Karpathy’s makemore repo (MIT-licensed) and watching the histograms is the best way to make it concrete.
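As a quick self-check on that prerequisite, here is a minimal sketch of the backward pass through a single tanh, written with the standard library only. The function name `tanh_backward` is a hypothetical helper for illustration, not something taken from the makemore repo.

```python
import math

def tanh_backward(x, upstream_grad):
    """Gradient passed back through y = tanh(x): upstream * (1 - y^2)."""
    y = math.tanh(x)
    return upstream_grad * (1.0 - y * y)

# With an upstream gradient of 1.0, the result is just the local derivative.
for x in (0.0, 1.0, 2.0, 4.0):
    print(f"x = {x}: gradient passed back = {tanh_backward(x, 1.0):.4f}")
```

At an input of 0 the gradient passes through unchanged; at an input of 2 only about 7% of it survives, and further into the tail it effectively vanishes. That shrinking factor is what "saturation" means in the rest of the lesson.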

  • Explain why a naive deep network’s loss starts far above the uniform baseline, and compute that baseline as -log(1/V)
  • Describe tanh saturation and why a near-zero local derivative starves backpropagation, killing learning in a neuron
  • Use an activation histogram as a diagnostic for saturation, and read it
  • Explain how scaling weights by 1/sqrt(fan-in) keeps activations healthy through depth
  • Describe what batch normalization does, its two main caveats, and why normalization layers appear in nearly every modern architecture
  • Read time: about 12 minutes
  • Practice time: about 20 minutes (diagnosing a bad initialization and a saturated neuron by hand, optionally confirmed in the makemore repo, plus flashcards)
  • Difficulty: standard
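To preview the diagnostic and the two fixes named in the objectives above, the following is a rough standard-library sketch rather than the lesson's actual code. The layer sizes, the 0.97 saturation threshold, and the helper names are all assumptions chosen for illustration, and the batch normalization here omits the learned gain and bias and the running statistics that a real layer would carry.

```python
import math
import random

random.seed(0)

FAN_IN = 200      # hypothetical number of inputs per hidden neuron
N_NEURONS = 50    # hypothetical hidden layer width
N_EXAMPLES = 40   # hypothetical batch size

def preactivations(weight_std):
    """Pre-activations [example][neuron] for Gaussian inputs and weights."""
    xs = [[random.gauss(0, 1) for _ in range(FAN_IN)] for _ in range(N_EXAMPLES)]
    ws = [[random.gauss(0, weight_std) for _ in range(FAN_IN)] for _ in range(N_NEURONS)]
    return [[sum(xi * wi for xi, wi in zip(x, w)) for w in ws] for x in xs]

def batch_normalize(pre, eps=1e-5):
    """Standardize each neuron's pre-activation over the batch (no gain or bias)."""
    n = len(pre)
    out = [row[:] for row in pre]
    for j in range(len(pre[0])):
        col = [row[j] for row in pre]
        mean = sum(col) / n
        var = sum((v - mean) ** 2 for v in col) / n
        for i in range(n):
            out[i][j] = (pre[i][j] - mean) / math.sqrt(var + eps)
    return out

def saturated_fraction(pre, threshold=0.97):
    """Share of tanh outputs whose magnitude exceeds the threshold."""
    acts = [math.tanh(v) for row in pre for v in row]
    return sum(abs(a) > threshold for a in acts) / len(acts)

def text_histogram(pre, bins=10, width=40):
    """Print a crude histogram of tanh activations over [-1, 1]."""
    counts = [0] * bins
    for row in pre:
        for v in row:
            a = math.tanh(v)
            counts[min(int((a + 1) / 2 * bins), bins - 1)] += 1
    peak = max(counts)
    for i, c in enumerate(counts):
        print(f"{-1 + 2 * i / bins:+.1f} {'#' * round(width * c / peak)}")

naive = preactivations(1.0)                      # unscaled weights
scaled = preactivations(1 / math.sqrt(FAN_IN))   # 1/sqrt(fan-in) scaling
normed = batch_normalize(naive)                  # naive weights, then normalize

for name, pre in (("naive", naive), ("scaled", scaled), ("batch-normalized", normed)):
    print(f"\n{name}: saturated fraction = {saturated_fraction(pre):.2f}")
    text_histogram(pre)
```

With unscaled weights the histogram should pile up in the outermost bins, which is the visual signature of saturation. Scaling the weights by 1/sqrt(fan-in) or standardizing the pre-activations over the batch both pull the activations back toward the middle of the range, where tanh still passes gradient.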