Practice: Why pretraining is a memory engineering problem (parallelism and Flash Attention)

Answer in your head (or on paper) before opening the collapsible.

1. What four kinds of data have to fit in memory during one training step?

Show answer

(1) Parameters (the model weights). (2) Gradients (the direction each weight should be nudged). (3) Optimizer states, specifically Adam’s first and second moments per parameter, so roughly 2x the parameter memory. (4) Activations from the forward pass (the values at each layer that the backward pass needs). Activation memory depends on model size, batch size, and the square of context length because attention is O(n²) in sequence length.

2. Describe what data parallelism contributes and what it costs.

Show answer

Contributes: distributes the batch across GPUs. Each GPU has its own full copy of the model and works on a slice of the batch. Reduces the activation memory per GPU because each GPU only sees part of the batch.

Costs: (1) every GPU still needs the full model copy plus its own gradients and optimizer states, so data parallelism alone does not help for models that are too large for one GPU; (2) gradients have to be averaged across GPUs before each weight update, which is a communication cost that grows with the number of GPUs. There is a regime where adding more GPUs hurts more than it helps because communication starts to dominate.
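The gradient-averaging step can be made concrete with a toy model. This is a minimal sketch under stated assumptions: a one-parameter linear model with mean-squared-error loss, a batch that divides evenly across the simulated GPUs, and a plain Python average standing in for the all-reduce. With equal shard sizes, the average of the per-GPU gradients equals the full-batch gradient, which is why every replica can apply the same update.

```python
# Data parallelism in miniature: each simulated "GPU" holds a full copy of the
# weight, computes the gradient on its own slice of the batch, and the local
# gradients are averaged (standing in for an all-reduce) before the update.
# Assumption: one-parameter model y ~ w * x with mean-squared-error loss.

def mse_grad(w, xs, ys):
    """Gradient of mean((w*x - y)^2) with respect to w over one batch."""
    return sum(2 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)


def split_batch(xs, ys, num_gpus):
    """Split the batch into equal contiguous shards, one per simulated GPU."""
    shard = len(xs) // num_gpus
    return [(xs[i * shard:(i + 1) * shard], ys[i * shard:(i + 1) * shard])
            for i in range(num_gpus)]


def data_parallel_grad(w, xs, ys, num_gpus):
    """Mean of the per-GPU gradients, i.e. what an averaging all-reduce returns."""
    local_grads = [mse_grad(w, sx, sy) for sx, sy in split_batch(xs, ys, num_gpus)]
    return sum(local_grads) / num_gpus


xs = [float(i) for i in range(8)]
ys = [3.0 * x + 1.0 for x in xs]
w = 0.5

full = mse_grad(w, xs, ys)
parallel = data_parallel_grad(w, xs, ys, num_gpus=4)
print(full, parallel)  # both -94.5
```

If the shards had unequal sizes, a plain mean of the shard gradients would no longer match the full-batch gradient; it would need weighting by shard size.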

3. Distinguish ZeRO-1, ZeRO-2, and ZeRO-3 by which quantity each one partitions.

Show answer
  • ZeRO-1 partitions the optimizer states across GPUs. Each GPU holds Adam’s first and second moments for only its slice of the parameters. Substantial memory savings because optimizer states are roughly 2x the parameter memory.
  • ZeRO-2 also partitions gradients. Each GPU only holds the gradient for the slice of parameters it is responsible for.
  • ZeRO-3 also partitions parameters. Each GPU only stores its slice of the model weights at any given time; GPUs fetch the parts they need from each other on demand.

The trade-off as you move up the levels: more memory savings, more communication. ZeRO is not separate from data parallelism; it is data parallelism with the redundancy removed.
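ZeRO-1 can be illustrated on a small list of parameters. This is a sketch, not DeepSpeed's implementation: the learning rate, Adam constants, shard layout, and the helper names `adam_step` and `shard_bounds` are illustrative assumptions, and list concatenation stands in for the all-gather. It shows that when each rank keeps Adam moments only for its own slice and updates only that slice, the reassembled parameters match an unsharded Adam step exactly, because Adam's update is elementwise.

```python
import math

# ZeRO-1 in miniature (illustrative sketch): every rank has the full averaged
# gradient, but Adam's first and second moments exist only for the slice of
# parameters that rank owns. Each rank updates its slice; concatenating the
# slices stands in for the all-gather that redistributes updated parameters.

LR, BETA1, BETA2, EPS = 1e-3, 0.9, 0.999, 1e-8


def adam_step(params, grads, m, v, t):
    """Plain elementwise Adam; returns new params, first and second moments."""
    new_p, new_m, new_v = [], [], []
    for p, g, mi, vi in zip(params, grads, m, v):
        mi = BETA1 * mi + (1 - BETA1) * g
        vi = BETA2 * vi + (1 - BETA2) * g * g
        m_hat = mi / (1 - BETA1 ** t)
        v_hat = vi / (1 - BETA2 ** t)
        new_p.append(p - LR * m_hat / (math.sqrt(v_hat) + EPS))
        new_m.append(mi)
        new_v.append(vi)
    return new_p, new_m, new_v


def shard_bounds(n, world_size, rank):
    """Contiguous [start, end) range of the n parameters owned by `rank`."""
    per_rank = math.ceil(n / world_size)
    return min(rank * per_rank, n), min((rank + 1) * per_rank, n)


params = [0.1 * i for i in range(10)]
grads = [0.01 * (i - 5) for i in range(10)]
world_size = 4

# Reference: one process holds optimizer state for all 10 parameters.
ref_params, _, _ = adam_step(params, grads, [0.0] * 10, [0.0] * 10, t=1)

# ZeRO-1: each rank allocates moments only for its own slice.
updated_slices = []
for rank in range(world_size):
    lo, hi = shard_bounds(len(params), world_size, rank)
    m_local = [0.0] * (hi - lo)  # this rank's share of the first moment
    v_local = [0.0] * (hi - lo)  # this rank's share of the second moment
    new_slice, _, _ = adam_step(params[lo:hi], grads[lo:hi], m_local, v_local, t=1)
    updated_slices.append(new_slice)

zero1_params = [p for s in updated_slices for p in s]  # the "all-gather"
print(zero1_params == ref_params)
```

Each rank here stores moments for at most 3 of the 10 parameters instead of all 10, which is the memory saving ZeRO-1 is after.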

4. Explain what model parallelism is and name its three common variants.

Show answer

Model parallelism splits the model itself across GPUs (as opposed to data parallelism, which splits the batch). Used when the model is too large for any single GPU to hold even briefly.

Three variants the lecturer flags:

  • Tensor parallelism cuts large matrix multiplications across GPUs. Each GPU computes a slice of the multiplication; slices are combined across the network. Used in most frontier-class training.
  • Pipeline parallelism splits the layer stack across GPUs. GPU 1 holds layers 1-3, GPU 2 holds layers 4-6, and so on; a batch flows through layers like an assembly line. Useful when the model is too deep for tensor parallelism alone.
  • Expert parallelism is specific to mixture-of-experts (MoE) architectures. Different experts live on different GPUs; tokens get routed to the right expert’s GPU.

In practice, frontier runs combine multiple flavors (e.g., ZeRO-3 plus tensor parallelism plus pipeline parallelism).
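Tensor parallelism is the easiest of the three variants to check numerically. This sketch assumes a single token, a made-up 4 x 4 weight matrix, and two simulated GPUs. It shows the two usual ways to cut one matrix multiply: by columns (each GPU produces some output columns, which are concatenated) and by rows (each GPU produces a partial sum over its rows, and the partial sums are added, which is the all-reduce in a real system).

```python
# Tensor parallelism in miniature (illustrative sketch): y = x @ W split across
# two simulated GPUs.
#   Column split: each GPU holds some columns of W and produces those columns
#                 of y; the pieces are concatenated (an all-gather).
#   Row split:    each GPU holds some rows of W and the matching slice of x;
#                 the partial outputs are summed (an all-reduce).

def matmul(a, b):
    """Naive (rows x k) @ (k x cols) product on nested lists."""
    return [[sum(a[i][k] * b[k][j] for k in range(len(b)))
             for j in range(len(b[0]))] for i in range(len(a))]


x = [[1.0, 2.0, 3.0, 4.0]]                                      # one token, hidden size 4
W = [[float(r * 4 + c) for c in range(4)] for r in range(4)]    # 4 x 4 weights
full = matmul(x, W)

# Column split: GPU 0 owns columns 0-1, GPU 1 owns columns 2-3.
W_cols = [[row[:2] for row in W], [row[2:] for row in W]]
col_partials = [matmul(x, shard) for shard in W_cols]
col_parallel = [col_partials[0][0] + col_partials[1][0]]        # concatenate columns

# Row split: GPU 0 owns rows 0-1 of W and x[:, 0:2]; GPU 1 owns rows 2-3.
x_parts = [[x[0][:2]], [x[0][2:]]]
W_rows = [W[:2], W[2:]]
row_partials = [matmul(xp, wp) for xp, wp in zip(x_parts, W_rows)]
row_parallel = [[a + b for a, b in zip(row_partials[0][0], row_partials[1][0])]]  # sum

print(full, col_parallel, row_parallel)  # all three: [[80.0, 90.0, 100.0, 110.0]]
```

Each GPU holds only half of W, which is the point: the weight matrix never has to fit on a single device. Chaining a column split into a row split lets two consecutive matrix multiplies run with one combine step at the end.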

5. Describe how Flash Attention uses the GPU memory hierarchy.

Show answer

A GPU has two relevant kinds of memory: HBM (the GPU memory you see on a spec sheet, tens of GB, a few TB/s of bandwidth) and SRAM (on-chip memory next to the compute units, tens of MB, tens of TB/s, roughly 10x faster).

Standard attention reads and writes the Q, K, V matrices and the giant intermediate QK^T matrix to and from HBM many times. Even though GPU compute is fast, the data movement between HBM and the compute units is the bottleneck. The lecturer’s framing: “GPU is very very fast but you spend a lot of time just loading matrices from the memory.”

Flash Attention tiles the computation: cut the matrices into blocks small enough to fit in SRAM, compute each tile end-to-end in fast SRAM (the matrix multiply, the partial softmax, the multiply by the value tile, all without leaving SRAM), and write back to HBM only once. The mathematical trick that makes this work: softmax can be computed block by block as long as each block tracks its own scaling factor that combines correctly with the others.
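The block-by-block softmax trick can be checked in plain Python. This is a numerical sketch for a single query vector, not the fused GPU kernel; the vector sizes, the random-looking inputs, and `block_size=3` are arbitrary assumptions. The tiled version keeps a running maximum `m`, a running denominator `l`, and a running weighted sum `acc`, and rescales the earlier tiles' contributions whenever a new tile raises the maximum.

```python
import math

# Online (block-by-block) softmax for one query, the idea behind Flash
# Attention's tiling. Only one tile of scores exists at a time.

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def attention_naive(q, keys, values):
    """Standard attention: materialize all scores, softmax, weighted sum."""
    scale = 1.0 / math.sqrt(len(q))
    scores = [dot(q, k) * scale for k in keys]
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]
    total = sum(weights)
    dim = len(values[0])
    return [sum(w * v[d] for w, v in zip(weights, values)) / total for d in range(dim)]


def attention_tiled(q, keys, values, block_size):
    """Same result, computed one key/value tile at a time."""
    scale = 1.0 / math.sqrt(len(q))
    dim = len(values[0])
    m, l, acc = -math.inf, 0.0, [0.0] * dim
    for start in range(0, len(keys), block_size):
        k_tile = keys[start:start + block_size]
        v_tile = values[start:start + block_size]
        scores = [dot(q, k) * scale for k in k_tile]
        m_new = max(m, max(scores))
        correction = math.exp(m - m_new)  # rescales earlier tiles; 0.0 on the first tile
        p = [math.exp(s - m_new) for s in scores]
        l = l * correction + sum(p)
        acc = [a * correction + sum(pi * v[d] for pi, v in zip(p, v_tile))
               for d, a in enumerate(acc)]
        m = m_new
    return [a / l for a in acc]  # normalize once at the end


q = [0.3, -1.2, 0.8, 0.5]
keys = [[math.sin(i + j) for j in range(4)] for i in range(10)]
values = [[math.cos(i * j) for j in range(3)] for i in range(10)]

exact = attention_naive(q, keys, values)
tiled = attention_tiled(q, keys, values, block_size=3)
print(max(abs(a - b) for a, b in zip(exact, tiled)))  # near zero; only rounding differs
```

The two functions perform the same arithmetic in a different order, so any difference between them is floating-point rounding rather than approximation.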

The result is mathematically exact, not an approximation, and materially faster on real hardware. Long context windows became practical in part because of this technique.

6. If Flash Attention is mathematically equivalent to standard attention (same output), where does its speedup come from?

Show answer

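Not from skipping computation. Flash Attention performs all the same matrix multiplications and softmax operations that standard attention does (in the backward pass it even recomputes attention values instead of storing them). The output is mathematically identical, though not necessarily bit-for-bit identical, because the floating-point operations happen in a different order. The speedup comes from reduced data movement: standard attention reads and writes large intermediate matrices to slow HBM many times, while Flash Attention does that intermediate work in fast SRAM, touching HBM mainly to read Q, K, and V tiles and to write the final output, and never materializes the full attention matrix in HBM. The phrase “memory-bound, not compute-bound” describes the bottleneck this technique addresses.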

Try it yourself: a memory budget walk-through

This exercise puts the memory-engineering picture into a concrete back-of-envelope estimate. About 12 minutes. Pen and paper, or a calculator.

A team has access to eight GPUs, each with 80 GB of memory. They want to train a model with 20 billion parameters, which they will store in 16-bit precision (so each parameter takes 2 bytes).

a) How much memory do the parameters alone consume?

Show answer

20 billion parameters × 2 bytes/parameter = 40 GB

The parameters by themselves take 40 GB, half of one GPU’s 80 GB.

b) Adam’s optimizer states (first moment + second moment) take roughly 2x the parameter memory in terms of the number of values stored (two numbers per parameter), but each moment is typically stored in higher precision (FP32, so 4 bytes per number). Estimate the optimizer-state memory cost.

Show answer

Two moments per parameter at 4 bytes each = 20 billion × 2 × 4 = 160 GB.

The optimizer states alone are 160 GB, twice the entire memory of one GPU. (This is one reason ZeRO-1’s partitioning of optimizer states is the largest single memory win.)

c) Gradients are typically stored in the same precision as parameters (FP16, 2 bytes). Estimate the gradient memory cost.

Show answer

20 billion × 2 = 40 GB. Same as the parameters.

d) Sum the parameter, optimizer-state, and gradient memory. Will plain data parallelism (every GPU has a full copy of all of this) fit on a single 80 GB GPU?

Show answer

40 + 160 + 40 = 240 GB of memory just for parameters + gradients + optimizer states (no activations yet).

Plain data parallelism will not work here: it requires every GPU to hold all 240 GB, which is three times one GPU’s 80 GB capacity. And we have not even counted activations, which depend on batch size and context length and can add tens of GB more.
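Parts a) through d) reduce to a few multiplications. This sketch uses the exercise’s assumptions (2-byte parameters and gradients, two 4-byte Adam moments, activations ignored) and decimal gigabytes. The `include_master_weights` flag is an addition beyond the exercise: mixed-precision setups commonly also keep an FP32 master copy of the weights, which adds another 4 bytes per parameter.

```python
# Back-of-envelope training-state memory (activations ignored).
# Assumptions: 16-bit parameters and gradients (2 bytes each), Adam's two
# moments in FP32 (4 bytes each). `include_master_weights` optionally adds an
# FP32 copy of the weights, counted here with the optimizer state.

GB = 1e9  # decimal gigabytes, matching the exercise's arithmetic


def training_state_gb(num_params, include_master_weights=False):
    params = num_params * 2 / GB
    grads = num_params * 2 / GB
    optimizer_bytes = 2 * 4 + (4 if include_master_weights else 0)
    optimizer = num_params * optimizer_bytes / GB
    return {"params": params, "grads": grads, "optimizer": optimizer,
            "total": params + grads + optimizer}


budget = training_state_gb(20e9)
print(budget)  # params 40.0, grads 40.0, optimizer 160.0, total 240.0
print("fits on one 80 GB GPU:", budget["total"] <= 80)  # False
```

With the optional FP32 master weights the total rises to 320 GB, which only strengthens the conclusion.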

e) Apply ZeRO-3 across the eight GPUs (each GPU holds a slice of the parameters, gradients, and optimizer states). Roughly how much per-GPU memory do those three categories consume?

Show answer

240 GB / 8 GPUs = 30 GB per GPU for the three partitioned categories. That leaves roughly 50 GB per GPU for activations and intermediate computation. This kind of arithmetic is why ZeRO-3 is an enabling technique for runs of this scale.
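The same arithmetic can be run for every ZeRO stage to see where the budget first fits. This is a sketch that uses the exercise’s figures (40 GB parameters, 40 GB gradients, 160 GB Adam moments, 8 GPUs) and ignores activations, communication buffers, and fragmentation, so real headroom is smaller than shown. On these numbers ZeRO-2 already squeezes under 80 GB (65 GB per GPU) but leaves only about 15 GB for activations, while ZeRO-3 leaves about 50 GB.

```python
# Per-GPU memory for parameters + gradients + optimizer states under each
# ZeRO stage, with the exercise's sizes. Activations and buffers are ignored.

STAGE_NAMES = ["plain data parallelism", "ZeRO-1", "ZeRO-2", "ZeRO-3"]


def per_gpu_gb(params_gb, grads_gb, optim_gb, num_gpus, stage):
    """Training-state memory on one GPU; stage 0 means plain data parallelism."""
    if stage >= 1:  # ZeRO-1: shard the optimizer states
        optim_gb /= num_gpus
    if stage >= 2:  # ZeRO-2: also shard the gradients
        grads_gb /= num_gpus
    if stage >= 3:  # ZeRO-3: also shard the parameters
        params_gb /= num_gpus
    return params_gb + grads_gb + optim_gb


for stage, name in enumerate(STAGE_NAMES):
    used = per_gpu_gb(40, 40, 160, num_gpus=8, stage=stage)
    print(f"{name}: {used:.0f} GB per GPU, fits in 80 GB: {used <= 80}")
```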

f) What additional cost does ZeRO-3 incur compared to plain data parallelism?

Show answer

Communication. ZeRO-3 partitions parameters across GPUs, so during each forward pass and backward pass each GPU has to fetch the parameter slices it does not currently hold from the other GPUs that do. That fetch traffic happens on the GPU interconnect at every layer of every step. The trade-off across ZeRO levels is exactly this: more memory savings, more communication.

Part three: which technique for which problem?

For each scenario, identify which technique (or combination) is the most appropriate.

a) A team’s model fits on one 80 GB GPU but they want to use eight GPUs to train faster on a fixed dataset.

Show answer

Plain data parallelism is enough. The model fits on one device, so there is no need to partition the model itself. Splitting the batch across eight GPUs gives close to 8x throughput minus the cost of averaging gradients. ZeRO would be overkill for this case.

b) A team’s model is 100 billion parameters and does not fit on one GPU even counting the parameters alone (100 billion × 2 bytes = 200 GB in 16-bit precision), regardless of batch size.

Show answer

ZeRO-3 at minimum, almost certainly combined with tensor parallelism and possibly pipeline parallelism. ZeRO-3 partitions parameters across GPUs so that no single GPU holds the full model. Tensor parallelism cuts the matrix multiplications across GPUs so that each forward step operates on slices. Pipeline parallelism distributes layers across GPUs so the depth fits. Frontier-class training routinely combines all three.

c) A team has a 7-billion-parameter model that fits comfortably on one GPU, but they are training on context lengths of 32K tokens and finding that attention is the bottleneck.

Show answer

Flash Attention. This is exactly the case it is designed for: the model fits, but attention’s O(n²) memory cost in sequence length combined with the HBM data-movement bottleneck makes long-context training expensive. Flash Attention removes the data-movement bottleneck while producing the same output as standard attention. Many training frameworks now use it (or a comparable fused attention kernel) by default for this reason.
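To see why 32K-token contexts hurt standard attention, it helps to size the n × n score matrix it writes to HBM. This sketch assumes 16-bit scores (2 bytes each) and takes 32K to mean 32,768 tokens; the figures are for one head of one sequence, and a real model multiplies them by the number of heads and the batch size. Flash Attention never materializes this matrix in HBM.

```python
# Size of one n x n attention score matrix, assuming 2-byte (16-bit) entries.
# Figures are per head, per sequence; multiply by heads x batch for a model.

def score_matrix_bytes(seq_len, bytes_per_elem=2):
    return seq_len * seq_len * bytes_per_elem


for n in (2_048, 8_192, 32_768):
    gib = score_matrix_bytes(n) / 2**30
    print(f"{n:>6} tokens: {gib:8.3f} GiB per head per sequence")
```

Going from 8,192 to 32,768 tokens multiplies the matrix size by 16 (0.125 GiB to 2 GiB), which is the quadratic growth described above.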

Sanity check: the four techniques in this lesson are not alternatives; they address different bottlenecks. Faster training of a model that already fits wants plain data parallelism. Per-GPU memory pressure from parameters, gradients, and optimizer states wants ZeRO. A model too big for one GPU wants ZeRO-3 and/or model parallelism. Long-context attention constraints want Flash Attention. Frontier runs combine the relevant techniques rather than picking one.

Twelve cards. Click any card to reveal the answer.

Q. What four kinds of data must fit in memory during one training step?
A.

Parameters, gradients, optimizer states (Adam’s first and second moments per parameter, roughly 2x parameters), and activations from the forward pass. Activation memory grows with model size, batch size, and the square of context length.

Q. Why is single-GPU memory the constraint for frontier training?
A.

Frontier-scale models cannot fit even just their parameters plus gradients plus optimizer states on a single 80 GB GPU, let alone the activations from a forward pass. The Stanford lecturer’s framing: there is “a lot of things to save” against memory that “is not unlimited.”

Q. What does data parallelism do?
A.

Distributes the batch across GPUs. Each GPU has its own full copy of the model and works on a slice of the batch. Reduces activation memory per GPU. Requires gradient averaging across GPUs before each weight update, which is a communication cost that grows with GPU count.

Q. What does ZeRO-1 partition?
A.

Optimizer states. Each GPU holds Adam’s first and second moments for only its slice of the parameters. Substantial memory savings because optimizer states are roughly 2x the parameter memory.

Q. What does ZeRO-2 add to ZeRO-1?
A.

Partitions gradients in addition to optimizer states. Each GPU only holds the gradient for the slice of parameters it is responsible for.

Q. What does ZeRO-3 add to ZeRO-2?
A.

Partitions parameters themselves. Each GPU only stores its slice of the model weights; GPUs fetch the parts they need from each other on demand. Most memory savings, most communication cost.

Q. What is model parallelism?
A.

Splitting the model itself across GPUs (as opposed to data parallelism, which splits the batch). Used when a single GPU cannot hold the model. Three common variants: tensor parallelism (cut matrix multiplications), pipeline parallelism (cut the layer stack), expert parallelism (specific to MoE).

Q. What does tensor parallelism do?
A.

Cuts large matrix multiplications inside the model across GPUs. Each GPU computes one slice of the multiplication; slices are combined across the network. Used in most frontier-class training.

Q. What does pipeline parallelism do?
A.

Splits the layer stack across GPUs. GPU 1 holds layers 1-3, GPU 2 holds layers 4-6, and so on. A batch flows through the layers like an assembly line.

Q. What is HBM vs SRAM in a GPU?
A.

HBM is the “GPU memory” on a spec sheet: tens of GB, a few TB/s bandwidth, slow relative to compute. SRAM is on-chip memory next to compute units: tens of MB, tens of TB/s bandwidth, roughly 10x faster than HBM.

Q. What does Flash Attention do?
A.

Tiles the attention computation into blocks small enough to fit in SRAM, computes each tile end-to-end in fast SRAM, writes back to HBM only once. Removes the data-movement bottleneck of standard attention. Mathematically exact, not an approximation. Developed by Tri Dao and collaborators at Stanford in 2022.

Q. What is the one-sentence takeaway?
A.

Pretraining at scale is a memory engineering problem. Parallelism distributes memory across many GPUs. Flash Attention rearranges memory inside one GPU.
