
Practice: Multi-head attention: many lenses on the same sentence

Seven short questions. Try to answer each one in your head (or on paper) before reading its answer. Active retrieval is where the learning sticks; rereading is comfortable but does much less.

1. Why is one attention head structurally limited?


It produces only one weighted blend of context per token. A real sentence has many kinds of structure happening at once (syntactic, coreference, positional, semantic), and a single softmax-weighted sum can only encode one weighting. Information about the other structures gets compressed away.

2. Define h, d_model, and d_k. What’s the relationship between them?


h is the number of heads in the layer. d_model is the model’s main embedding dimension (the vector size of each token going into and out of attention). d_k is the per-head dimension. In the standard design, d_k = d_model / h, so with d_model = 768 and h = 12, d_k = 64.
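As a minimal sketch, the standard split amounts to a divisibility check plus integer division. The helper name `head_dim` is hypothetical and not taken from any library.

```python
def head_dim(d_model: int, h: int) -> int:
    """Per-head dimension d_k under the standard split d_k = d_model / h."""
    if d_model % h != 0:
        # The standard design needs h to divide d_model evenly.
        raise ValueError(f"d_model={d_model} is not divisible by h={h}")
    return d_model // h


print(head_dim(768, 12))  # 64
```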

3. Walk through what happens to a single token’s representation in one multi-head attention layer, in terms of vector shape.


Input: a d_model-dim vector (the token’s embedding). For each of h heads: project it to d_k-dim Q, K, and V vectors via that head’s W_Q, W_K, W_V. Each head runs the attention formula with its own d_k-dim projections (this token’s Q is scored against the K of every token in the sequence, and the resulting weights blend those tokens’ V vectors), producing a d_k-dim output. The h head outputs concatenate back into a d_model-dim vector (since h × d_k = d_model). One final linear projection through W_O produces the layer’s d_model-dim output. Shape in equals shape out.
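To make the shape bookkeeping concrete, here is a toy pure-Python sketch of one multi-head attention layer. It assumes tiny made-up sizes (d_model = 8, h = 2, three tokens), random weights, no biases, no masking, and the usual 1/sqrt(d_k) scaling inside the softmax. All helper names are hypothetical; the point is the shapes, not efficiency.

```python
import math
import random


def matmul(a, b):
    """Multiply an (n x m) matrix by an (m x p) matrix; both are lists of rows."""
    return [[sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]


def transpose(m):
    return [list(col) for col in zip(*m)]


def softmax(row):
    top = max(row)  # subtract the max for numerical stability
    exps = [math.exp(x - top) for x in row]
    total = sum(exps)
    return [e / total for e in exps]


def random_matrix(rows, cols, rng):
    return [[rng.gauss(0.0, 0.1) for _ in range(cols)] for _ in range(rows)]


def multi_head_attention(x, w_q, w_k, w_v, w_o):
    """x: seq_len x d_model. w_q, w_k, w_v: lists holding one d_model x d_k matrix
    per head. w_o: d_model x d_model. Returns seq_len x d_model."""
    head_outputs = []
    for wq, wk, wv in zip(w_q, w_k, w_v):
        q, k, v = matmul(x, wq), matmul(x, wk), matmul(x, wv)  # each seq_len x d_k
        d_k = len(q[0])
        scores = matmul(q, transpose(k))                        # seq_len x seq_len
        weights = [softmax([s / math.sqrt(d_k) for s in row]) for row in scores]
        head_outputs.append(matmul(weights, v))                 # seq_len x d_k
    # Concatenate per token: h slices of width d_k become one d_model vector.
    concat = [[val for head in head_outputs for val in head[t]] for t in range(len(x))]
    return matmul(concat, w_o)                                  # seq_len x d_model


rng = random.Random(0)
d_model, h, seq_len = 8, 2, 3
d_k = d_model // h
x = random_matrix(seq_len, d_model, rng)
w_q = [random_matrix(d_model, d_k, rng) for _ in range(h)]
w_k = [random_matrix(d_model, d_k, rng) for _ in range(h)]
w_v = [random_matrix(d_model, d_k, rng) for _ in range(h)]
w_o = random_matrix(d_model, d_model, rng)

out = multi_head_attention(x, w_q, w_k, w_v, w_o)
print(len(out), len(out[0]))  # 3 8: shape in equals shape out
```

Each head reads the whole sequence, but only through its own d_k-wide projections; the heads meet only at the concatenation and the multiplication by `w_o`.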

4. What does W_O do, and why is it necessary?


W_O is the output projection that mixes the concatenated head outputs into the layer’s final output. Without it, the layer’s output would just be the independent head outputs placed side by side, with no interaction; W_O lets the model combine information across heads (giving more weight to some, less to others) before passing the result to the next layer or block.
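A tiny numeric sketch of what W_O adds, using made-up head outputs and a hypothetical 4 × 4 W_O (h = 2, d_k = 2, so d_model = 4). Concatenation keeps each head in its own slots; multiplying by W_O lets every output coordinate draw on every head. The helper `apply_w_o` is hypothetical.

```python
def apply_w_o(head_outputs, w_o):
    """Concatenate per-head outputs, then multiply the row vector by w_o."""
    concat = [x for head in head_outputs for x in head]
    out = [sum(concat[i] * w_o[i][j] for i in range(len(concat)))
           for j in range(len(w_o[0]))]
    return concat, out


# Hypothetical W_O: row i = concatenated slot i, column j = output coordinate j.
w_o = [
    [0.5, 0.0, 0.25, 0.0],  # slot 0 (head 0)
    [0.5, 0.0, 0.0, 0.5],   # slot 1 (head 0)
    [0.5, 1.0, 0.0, 0.0],   # slot 2 (head 1)
    [0.5, 0.0, 0.25, 0.0],  # slot 3 (head 1)
]

concat, out = apply_w_o([[1.0, 2.0], [3.0, 4.0]], w_o)
print(concat)  # [1.0, 2.0, 3.0, 4.0]
print(out)     # [5.0, 3.0, 1.25, 1.0]

# Change only head 1's output.
concat2, out2 = apply_w_o([[1.0, 2.0], [3.0, 5.0]], w_o)
print(concat2[0] == concat[0])  # True: slot 0 still reflects head 0 alone
print(out2[0] == out[0])        # False: output 0 mixes in head 1 through W_O
```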

5. Are individual heads typically interpretable as having human-readable roles (“the gender head,” “the syntax head”)?


No, in general. Some heads do attend to recognizable patterns (positional, syntactic head of phrase, coreference), but most do not have a clean role. The literature on head interpretability is mixed; many heads can also be pruned without much performance loss. Treat individual heads as a structural mechanism, not as named lenses.

6. A model card says “768 hidden dim, 12 heads.” What does this tell you about the multi-head attention layers?


d_model = 768, h = 12, so d_k = 64. Each layer runs 12 parallel attention computations on 64-dim Q, K, V vectors, concatenates the results back to 768 dims, then projects through W_O (a 768 × 768 matrix). The layer’s input and output are both 768-dim per token.
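A back-of-the-envelope sketch of what such a model card implies for the parameter count of one attention layer. It assumes the standard d_k = d_model / h split and, by default, ignores bias terms (many implementations do include them). The function name is hypothetical.

```python
def attention_layer_params(d_model: int, h: int, bias: bool = False) -> dict:
    """Rough parameter count for one multi-head attention layer."""
    d_k = d_model // h
    one_head_w_q = d_model * d_k        # each head's W_K and W_V have the same size
    qkv_total = 3 * h * one_head_w_q    # W_Q, W_K, W_V across all h heads
    w_o = d_model * d_model
    total = qkv_total + w_o
    if bias:
        total += 3 * h * d_k + d_model  # one bias per output unit of each projection
    return {"d_k": d_k, "one_head_W_Q": one_head_w_q, "W_O": w_o, "total": total}


print(attention_layer_params(768, 12))
# {'d_k': 64, 'one_head_W_Q': 49152, 'W_O': 589824, 'total': 2359296}
```

With `bias=True` the total becomes 2,362,368, i.e. four projections of 768 × 768 weights plus 768 biases each.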

7. A model card lists “12 layers, 12 heads per layer.” How many attention computations happen in one forward pass through the model, and why?


144. Heads run in parallel inside a single attention layer; layers stack vertically (the output of one full layer becomes the input to the next). So 12 layers times 12 heads per layer is 144 attention computations per forward pass. The two numbers compose multiplicatively, not additively.
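The counting argument, written out as a small sketch: enumerate every (layer, head) pair. The enumeration only counts computations; it does not model the fact that layers run one after another while heads within a layer run in parallel.

```python
from itertools import product

num_layers, heads_per_layer = 12, 12

# Every (layer, head) pair is a separate attention computation.
pairs = list(product(range(num_layers), range(heads_per_layer)))
print(len(pairs))                    # 144
print(num_layers * heads_per_layer)  # 144: the counts multiply
print(num_layers + heads_per_layer)  # 24: what adding them would (wrongly) give
```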

Try it yourself: dimension arithmetic on a different config


This is the mechanism in motion. Different numbers, same split-run-concatenate pattern. About 10 minutes with a pen.

Side effects: none. Paper arithmetic. No API calls, no costs.

Setup: consider a transformer layer with d_model = 1024 and h = 16 heads.

Steps:

  1. Compute d_k for this configuration.

  2. Walk through the dimension flow for a single token: write down the shape after each step (input, per-head Q / K / V, per-head attention output, concatenated output, final output after W_O).

  3. Compute the parameter count of one head’s W_Q matrix, ignoring bias terms. Hint: W_Q has shape d_model × d_k.

  4. Compute the parameter count of W_O. Hint: W_O has shape d_model × d_model.

  5. Sanity check. If you doubled h to 32 while keeping d_model = 1024 constant, what would d_k be? Would the total parameter count of the head projections change?

Expected outcomes:

  • Step 1: d_k = d_model / h = 1024 / 16 = 64.
  • Step 2: input is 1024. Per head: Q, K, V each at 64. Per head output at 64. Concatenated output: 16 × 64 = 1024. After W_O: still 1024.
  • Step 3: one head’s W_Q has shape 1024 × 64, so 65,536 parameters. There are 16 heads, so total W_Q parameters across heads: 16 × 65,536 = 1,048,576. Same count for W_K and W_V.
  • Step 4: W_O is 1024 × 1024 = 1,048,576 parameters.
  • Step 5: doubling h to 32 would give d_k = 1024 / 32 = 32. Each head’s W_Q would be 1024 × 32 = 32,768 parameters; with 32 heads that totals 32 × 32,768 = 1,048,576. The total parameter count of head projections is the same; only the slicing changes.

The takeaway is that head count partitions the same parameter budget (and roughly the same compute) rather than adding capacity for free. More heads at the same d_model means a smaller per-head dimension; the trade is between many narrow lenses and fewer wide ones.
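The expected outcomes can be double-checked with a short sketch. It assumes the standard d_k = d_model / h split and ignores biases; `dimension_flow` is a hypothetical helper name.

```python
def dimension_flow(d_model: int, h: int) -> dict:
    """Shapes and head-projection parameter counts for one token (biases ignored)."""
    if d_model % h != 0:
        raise ValueError("d_model must be divisible by h")
    d_k = d_model // h
    one_head_w_q = d_model * d_k
    return {
        "input": d_model,
        "per_head_qkv": d_k,
        "per_head_output": d_k,
        "concatenated": h * d_k,
        "after_W_O": d_model,
        "one_head_W_Q": one_head_w_q,
        "all_heads_W_Q": h * one_head_w_q,
        "W_O": d_model * d_model,
    }


for h in (16, 32):
    flow = dimension_flow(1024, h)
    print(h, flow["per_head_qkv"], flow["one_head_W_Q"], flow["all_heads_W_Q"])
# 16 64 65536 1048576
# 32 32 32768 1048576
```

Doubling h halves both d_k and each head’s W_Q, so the all-heads total stays at 1,048,576.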

Try it yourself: look up the head count of a published model


This is a 5-minute exercise. Open the Hugging Face Hub. Pick any transformer-based model that interests you. A small one is easiest; the 100M to 1B parameter range works well. Open its model card or config.json.

Find these two fields:

  • hidden_size (or d_model, or n_embd, depending on the architecture)
  • num_attention_heads (or n_head, n_heads, num_heads, or attention_heads)

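Verify that d_k = hidden_size / num_attention_heads is a clean integer. Some configs also list a head_dim field explicitly; if yours does, check whether it matches that quotient.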

Bonus: check whether the model also lists num_key_value_heads. If that number is smaller than num_attention_heads, the model shares K/V across query heads to cut inference cost (an optimization mentioned in the lesson body): a value of 1 means multi-query attention, and a value between 1 and num_attention_heads means grouped-query attention.

You’ll find that model designers tend to cluster around a few common configurations: 8, 12, 16, or 32 heads, with hidden_size chosen so that d_k lands at 64 or 128. Once you’ve seen a few, the pattern becomes obvious.
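The lookup can be automated with a small sketch that works on an already-loaded config dictionary (downloading the file from the Hub is left out). The key-name lists are assumptions and not exhaustive, the helper names are hypothetical, and the example configs are made up rather than copied from specific models.

```python
import json

# Key names vary by architecture; these lists are illustrative, not exhaustive.
HIDDEN_KEYS = ("hidden_size", "d_model", "n_embd")
HEAD_KEYS = ("num_attention_heads", "n_head", "n_heads", "num_heads", "attention_heads")


def first_present(config: dict, keys: tuple):
    """Return the value of the first key in `keys` that appears in `config`."""
    for key in keys:
        if key in config:
            return config[key]
    raise KeyError(f"none of {keys} found in config")


def describe_attention(config: dict) -> dict:
    hidden = first_present(config, HIDDEN_KEYS)
    heads = first_present(config, HEAD_KEYS)
    if hidden % heads != 0:
        raise ValueError(f"hidden size {hidden} is not divisible by {heads} heads")
    # Missing (or null) num_key_value_heads: every query head has its own K/V.
    kv_heads = config.get("num_key_value_heads") or heads
    if kv_heads == heads:
        kind = "standard multi-head"
    elif kv_heads == 1:
        kind = "multi-query (MQA)"
    else:
        kind = "grouped-query (GQA)"
    return {
        "d_k": hidden // heads,
        "heads": heads,
        "kv_heads": kv_heads,
        "queries_per_kv_head": heads // kv_heads,
        "kind": kind,
    }


# Hypothetical config.json contents in the GPT-2 naming style.
example = json.loads('{"n_embd": 768, "n_head": 12, "n_layer": 12}')
print(describe_attention(example))

# Hypothetical config using hidden_size / num_attention_heads / num_key_value_heads.
info = describe_attention({"hidden_size": 4096, "num_attention_heads": 32, "num_key_value_heads": 8})
print(info["d_k"], info["queries_per_kv_head"], info["kind"])  # 128 4 grouped-query (GQA)
```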

Twelve flashcards for review, each a question followed by its answer.

Q. Why does single-head attention have a structural limit?
A.

It produces only one weighted blend of context per token. A sentence has multiple kinds of structure happening at once, and a single softmax-weighted sum can only encode one weighting. Information about the other structures gets compressed away.

Q. What is the relationship between `h`, `d_model`, and `d_k`?
A.

d_k = d_model / h. With d_model = 768 and h = 12, each head operates on d_k = 64-dim Q, K, V vectors. The total per-token compute is roughly comparable to one big attention, just split across heads.

Q. What is the split-run-concatenate pattern in multi-head attention?
A.

(1) Project the input to h smaller Q, K, V triples (each at d_k = d_model / h). (2) Run h independent attention computations in parallel. (3) Concatenate the h outputs back into a single d_model-dim vector. (4) Project once more through W_O. The output shape equals the input shape.

Q. What does `W_O` do?
A.

It mixes the concatenated head outputs into the layer’s final output. Without W_O, the layer would be a stack of independent heads with no interaction; W_O lets the model combine information across heads before the output goes to the next layer or block.

Q. Do individual heads have human-interpretable roles?
A.

In general, no. Some heads attend to recognizable patterns (positional, syntactic, coreference), but most do not have clean roles. The literature on head interpretability is mixed; many heads can also be pruned without significant performance loss. Treat heads as a structural mechanism, not as named lenses.

Q. What does '768 hidden dim, 12 heads' mean in a model card?
A.

d_model = 768, h = 12, so d_k = 64. Each attention layer runs 12 parallel attention computations on 64-dim Q, K, V, concatenates back to 768, then projects through a 768-by-768 W_O. Shape in equals shape out.

Q. What is the difference between heads and layers?
A.

Heads run in parallel inside one attention layer. Layers stack vertically; the output of one full transformer layer becomes the input to the next. A 12-layer model with 12 heads per layer has 12 × 12 = 144 attention computations per forward pass, not 12 or 24: the counts multiply.

Q. Does adding more heads at the same `d_model` always improve the model?
A.

No. d_k shrinks as h grows (since d_k = d_model / h); past a point, each head has too few dimensions to learn anything useful. Many models use 8 to 32 heads per layer (the largest models go higher), typically keeping d_k around 64 or 128.

Q. What are MQA and GQA, and what do they optimize?
A.

Multi-query attention (MQA) shares a single key and value projection across all query heads; grouped-query attention (GQA) shares one K/V projection within each group of query heads. Both are inference-time optimizations that shrink the key/value cache and cut inference cost. If a model card lists num_key_value_heads smaller than num_attention_heads, the model is using one of these.

Q. Does multi-head attention only apply to self-attention?
A.

No. The multi-head trick works for self-attention (Q, K, V from the same sequence) and cross-attention (Q from one sequence, K and V from another) equally well. Multi-head is orthogonal to the self-versus-cross distinction.

Q. Is multi-head attention the same as multi-layer or mixture of experts?
A.

No. Three different ideas. Multi-head: h parallel attention computations within one layer. Multi-layer: stacked attention layers, output of one feeds the next. Mixture of experts (MoE): different feed-forward networks for different tokens within a layer.

Q. What is the one-sentence takeaway from this lesson?
A.

One head asks one question. Many heads ask many, all at once.