Summary: Multi-head attention: many lenses on the same sentence

Multi-head attention is what actually sits inside a real transformer layer. Instead of one set of Q, K, V projections per layer, the layer runs h independent attention computations in parallel, each through its own learned W_Q, W_K, W_V matrices. With a single head, an attention layer can encode only one weighting pattern per token, but real sentences have many simultaneous structures (syntactic, coreference, positional, semantic) that one weighting cannot capture.

With multi-head attention, each head can specialize on a different kind of context. The h outputs concatenate back into the layer’s main dimension d_model and pass through one final projection W_O. Shape in equals shape out; representational capacity goes up.
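The split, run, concatenate, project flow can be sketched in a few lines of NumPy. This is a minimal illustration, not a production implementation: the weights are random rather than learned, there is no masking or batching, and the function name multi_head_attention is made up for this example. It also assumes the common layout where the h per-head projection matrices are stored side by side as one d_model by d_model matrix, which is equivalent to h separate d_model by d_k matrices.

```python
import numpy as np

def softmax(x, axis=-1):
    # Subtract the max for numerical stability before exponentiating.
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def multi_head_attention(x, W_Q, W_K, W_V, W_O, h):
    """x: (seq_len, d_model). Returns the output and per-head attention weights."""
    seq_len, d_model = x.shape
    if d_model % h != 0:
        raise ValueError("d_model must be divisible by h")
    d_k = d_model // h

    def split(t):
        # (seq_len, d_model) -> (h, seq_len, d_k): one slice per head.
        return t.reshape(seq_len, h, d_k).transpose(1, 0, 2)

    Q, K, V = split(x @ W_Q), split(x @ W_K), split(x @ W_V)

    # Run: h independent scaled dot-product attentions, batched over axis 0.
    scores = Q @ K.transpose(0, 2, 1) / np.sqrt(d_k)   # (h, seq_len, seq_len)
    weights = softmax(scores, axis=-1)
    heads = weights @ V                                 # (h, seq_len, d_k)

    # Concatenate the h outputs back to d_model, then mix them with W_O.
    concat = heads.transpose(1, 0, 2).reshape(seq_len, d_model)
    return concat @ W_O, weights

rng = np.random.default_rng(0)
seq_len, d_model, h = 5, 768, 12
x = rng.standard_normal((seq_len, d_model))
# Random stand-ins for learned weights (assumption for illustration only).
W_Q, W_K, W_V, W_O = (rng.standard_normal((d_model, d_model)) * 0.02 for _ in range(4))

out, weights = multi_head_attention(x, W_Q, W_K, W_V, W_O, h)
print(out.shape)      # (5, 768): same shape as the input
print(weights.shape)  # (12, 5, 5): one weighting pattern per head
```

Each of the 12 heads gets its own 5 by 5 weighting over the same five tokens, which is the "many lenses" idea in concrete form.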

This summary is the scan-it-in-five-minutes version. The full lesson covers the structural limit of single-head, the split-run-concatenate pattern, the dimension flow on a worked 12-head example, why multiple heads beat one big head, what model cards mean when they say “12 heads,” and the named pitfalls (heads versus layers, heads versus mixture of experts).

  • Single attention has a one-perspective limit per token. It produces one weighted blend of context. A sentence has many kinds of structure happening at once; a single softmax-weighted sum can only encode one weighting at a time.
  • Multi-head attention runs h independent attention computations in parallel. Each head has its own W_Q, W_K, W_V matrices and its own perspective on the input.
  • Split, run, concatenate, project. Project the input to h smaller Q, K, V triples (each at d_k = d_model / h). Run h independent attention computations. Concatenate the h outputs back to d_model. Project once more through W_O. The output shape equals the input shape.
  • Running example: 12 heads at d_model = 768. Each head operates on 64-dim Q, K, V. Twelve 64-dim outputs concatenate to 768. W_O is a 768 by 768 matrix.
  • W_O is the mixing layer. Without it, the head outputs would sit side by side in separate slices of the concatenated vector, with nothing inside the attention step to combine them. W_O lets the model combine information across heads before the output goes to the next layer or block.
  • Heads can specialize, but most are not human-interpretable. Some heads attend to recognizable patterns. Most do not have clean roles. The interpretability literature is mixed; treat heads as a structural mechanism, not as named lenses.
  • Head count is a partition, not free capacity. As h grows at fixed d_model, d_k shrinks and each head has less to work with. Practical models cluster around 8, 12, 16, or 32 heads, with d_model chosen so d_k lands around 64 or 128.
  • Multi-head works for self-attention and cross-attention. Multi-head is orthogonal to the self-versus-cross distinction; both use the same trick.
  • Inference optimizations: MQA and GQA. Multi-query attention and grouped-query attention share keys and values across heads (or groups of heads) to cut inference cost. If a model card lists num_key_value_heads smaller than num_attention_heads, the model is using one of these.
  • Heads are not layers. Heads run in parallel inside one attention layer. Layers stack vertically (the output of one full layer becomes the input to the next). A 12-layer model with 12 heads per layer runs 12 × 12 = 144 head-level attention computations per forward pass.
  • Heads are not mixture of experts. Multi-head varies the attention computation; mixture of experts (MoE) varies the feed-forward network. Different mechanisms, different parts of the layer; do not conflate them.
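The key and value sharing behind MQA and GQA can be illustrated with a small NumPy sketch. It assumes the queries, keys, and values have already been projected and split into heads, uses random tensors in place of real activations, and maps each key/value head to a block of consecutive query heads, which is one common convention rather than the only possible one. The function name grouped_query_attention is hypothetical.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def grouped_query_attention(Q, K, V):
    """Q: (h, seq_len, d_k); K and V: (h_kv, seq_len, d_k), with h divisible by h_kv.

    h_kv == h is ordinary multi-head attention, h_kv == 1 is multi-query
    attention, and anything in between is grouped-query attention.
    """
    h, h_kv = Q.shape[0], K.shape[0]
    if h % h_kv != 0:
        raise ValueError("number of query heads must be divisible by number of K/V heads")
    group = h // h_kv
    # Each K/V head is reused by `group` consecutive query heads.
    K_shared = np.repeat(K, group, axis=0)   # (h, seq_len, d_k)
    V_shared = np.repeat(V, group, axis=0)
    d_k = Q.shape[-1]
    weights = softmax(Q @ K_shared.transpose(0, 2, 1) / np.sqrt(d_k), axis=-1)
    return weights @ V_shared                # (h, seq_len, d_k)

rng = np.random.default_rng(0)
seq_len, d_k, h = 5, 64, 12
Q = rng.standard_normal((h, seq_len, d_k))

outputs, cache_sizes = {}, {}
for h_kv in (12, 4, 1):                      # MHA, GQA, MQA
    K = rng.standard_normal((h_kv, seq_len, d_k))
    V = rng.standard_normal((h_kv, seq_len, d_k))
    outputs[h_kv] = grouped_query_attention(Q, K, V)
    cache_sizes[h_kv] = K.size + V.size      # numbers stored for keys and values
    print(h_kv, outputs[h_kv].shape, cache_sizes[h_kv])
```

The output shape stays at (12, 5, 64) in all three cases; only the amount of key and value data that has to be kept around shrinks as heads share it.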

Before this lesson, “12 heads” was a number you saw in model cards without a clear mental picture. Now it is a specific architecture choice: 12 parallel attention computations per layer, each at d_k = d_model / 12, concatenated and projected. When you read a model card and see num_attention_heads, you can reason about what changes when that number goes up or down. When a model card mentions multi-query or grouped-query attention, you understand what is being shared and why. The next lesson, on the full transformer block, picks up where this one stops: now that we have multi-head attention working, what else does a layer wrap around it (feed-forward, residuals, layer norm) to make a complete transformer block?
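As a rough sketch of that model-card reasoning, the helper below reads a config-style dictionary and reports the per-head dimension, the attention variant, and the total head count. The keys num_attention_heads and num_key_value_heads come from the text above; hidden_size and num_hidden_layers are assumed names for d_model and the layer count, and the helper describe_attention and the second example config are hypothetical. It also assumes d_k = d_model / h as in this lesson; some real configs may specify the head dimension separately.

```python
def describe_attention(cfg):
    """Summarize the attention layout implied by a config-style dict."""
    d_model = cfg["hidden_size"]                  # assumed key name for d_model
    h = cfg["num_attention_heads"]
    h_kv = cfg.get("num_key_value_heads", h)      # absent means plain multi-head
    n_layers = cfg["num_hidden_layers"]           # assumed key name for layer count
    if d_model % h != 0 or h % h_kv != 0:
        raise ValueError("head counts must divide evenly")
    d_k = d_model // h
    if h_kv == h:
        kind = "multi-head"
    elif h_kv == 1:
        kind = "multi-query"
    else:
        kind = "grouped-query"
    return {
        "d_k": d_k,
        "kind": kind,
        "total_heads": n_layers * h,   # heads across all layers
        "kv_width": h_kv * d_k,        # width of the key (or value) projection output
    }

# The lesson's running example: 12 layers, 12 heads, d_model = 768.
base = {"hidden_size": 768, "num_hidden_layers": 12, "num_attention_heads": 12}
# A hypothetical config with fewer key/value heads than query heads.
shared = {"hidden_size": 4096, "num_hidden_layers": 32,
          "num_attention_heads": 32, "num_key_value_heads": 8}

base_info = describe_attention(base)
shared_info = describe_attention(shared)
print(base_info)
print(shared_info)

# Head count is a partition: at fixed d_model = 768, more heads means smaller d_k.
partition = {h: 768 // h for h in (8, 12, 16, 32)}
print(partition)
```

For the running example this gives d_k = 64 and 144 heads in total; raising the head count to 32 at the same d_model would leave each head with only 24 dimensions.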

One head asks one question.
Many heads ask many, all at once.