Cheatsheet: Why pretraining is a memory engineering problem (parallelism and Flash Attention)
The one idea that matters
Pretraining memory must hold: parameters + gradients + optimizer states (Adam: 2 moments per parameter) + activations from the forward pass.
A single 80 GB GPU cannot hold this for any frontier model.
The field's answer: distribute memory across GPUs (parallelism), and rearrange memory inside one GPU (Flash Attention).
What has to fit during training
| What | Detail |
|---|---|
| Parameters | The model weights themselves |
| Gradients | One per parameter, computed during backward pass |
| Optimizer states (Adam) | First and second moments per parameter; ~2x parameter memory in same precision, more in mixed precision (moments typically stored in FP32 even when weights are FP16) |
| Activations | Per-layer values from the forward pass (needed for the backward pass). Grow with model size and batch size, and, with standard attention (which stores the n × n score matrix), as O(n²) in context length n |
| Hardware reference | NVIDIA H100 has 80 GB. “Tens of gigabytes per GPU” is the order of magnitude. |
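To make the table concrete, here is a back-of-the-envelope sketch of model state only (parameters, gradients, optimizer states; activations excluded). It assumes a common mixed-precision Adam layout: FP16 weights and gradients, plus an FP32 master copy of the weights and two FP32 Adam moments, for 16 bytes per parameter. The layout and the model sizes are illustrative assumptions, not figures from any specific training run.

```python
# Assumed mixed-precision Adam layout, in bytes per parameter. Real setups vary.
BYTES_PER_PARAM = {
    "fp16_weights": 2,
    "fp16_gradients": 2,
    "fp32_master_weights": 4,
    "fp32_adam_first_moment": 4,
    "fp32_adam_second_moment": 4,
}

H100_MEMORY_GB = 80  # NVIDIA H100 capacity from the table above


def model_state_gb(num_params: float) -> float:
    """GB (1e9 bytes) for parameters + gradients + optimizer states, no activations."""
    return num_params * sum(BYTES_PER_PARAM.values()) / 1e9


for n in (1e9, 7e9, 70e9):
    need = model_state_gb(n)
    fits = need <= H100_MEMORY_GB
    print(f"{n / 1e9:.0f}B params: {need:.0f} GB of model state; fits on one 80 GB GPU: {fits}")
```

Under these assumptions a 7B-parameter model already needs about 112 GB of model state before a single activation is stored.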
Data parallelism
Idea: split the batch across GPUs; every GPU has a full model copy. Gradients are averaged across GPUs before each weight update.
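A minimal single-process sketch of the averaging step, using NumPy and a linear least-squares model as a stand-in for a network. The "GPUs" are just list entries, and the shard count, loss, and function names are illustrative assumptions. With equal-sized shards, averaging the per-shard gradients reproduces the full-batch gradient (up to rounding), which is why every replica can apply the same update and stay in sync.

```python
import numpy as np


def mse_grad(w: np.ndarray, x: np.ndarray, y: np.ndarray) -> np.ndarray:
    """Gradient of 0.5 * mean((x @ w - y) ** 2) with respect to w."""
    residual = x @ w - y
    return x.T @ residual / len(y)


def data_parallel_grad(w, x, y, num_gpus: int) -> np.ndarray:
    """Split the batch across simulated GPUs, compute local gradients, then average.

    Each simulated GPU uses the full weight vector w (a full model copy).
    The np.mean stands in for the all-reduce a real framework would run.
    """
    x_shards = np.array_split(x, num_gpus)
    y_shards = np.array_split(y, num_gpus)
    local_grads = [mse_grad(w, xs, ys) for xs, ys in zip(x_shards, y_shards)]
    return np.mean(local_grads, axis=0)


rng = np.random.default_rng(0)
x = rng.normal(size=(64, 8))  # batch of 64 examples, 8 features
y = rng.normal(size=64)
w = rng.normal(size=8)

full = mse_grad(w, x, y)
dp = data_parallel_grad(w, x, y, num_gpus=4)
print(np.allclose(full, dp))  # True: equal-sized shards, so the average matches
```

If the batch did not split evenly, a plain average of per-shard means would weight examples unequally, so this sketch relies on equal per-GPU batch sizes.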
Reduces: activation memory per GPU (each GPU sees only part of the batch)
Does NOT reduce: parameter / gradient / optimizer-state memory (each GPU still has the whole thing)
Communication cost: grows with GPU count; scaling becomes non-linear above some N
ZeRO levels
ZeRO = Zero Redundancy Optimizer. Layered on top of data parallelism.
Removes the duplicated copies of high-memory quantities across GPUs.
| Level | Partitions | Memory savings | Communication cost |
|---|---|---|---|
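| ZeRO-1 | Optimizer states | Big (states are ~2x parameter memory) | About the same as plain data parallelism |
| ZeRO-2 | + Gradients | Bigger | About the same as plain data parallelism (gradients are reduce-scattered instead of all-reduced) |
| ZeRO-3 | + Parameters | Biggest | Largest: roughly 1.5x plain data parallelism (parameters fetched from other GPUs every step) |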
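The memory trend in the table can be quantified with the per-GPU accounting used in the ZeRO paper, sketched below. Assumptions: the same 16-bytes-per-parameter mixed-precision Adam layout (2 bytes FP16 weights, 2 bytes FP16 gradients, 12 bytes of FP32 optimizer state), activations ignored, and perfect partitioning over N data-parallel GPUs. The function name and its `level` convention (0 = plain data parallelism) are hypothetical.

```python
def zero_per_gpu_gb(num_params: float, num_gpus: int, level: int,
                    optimizer_bytes: int = 12) -> float:
    """Per-GPU model-state memory in GB (1e9 bytes) for plain DP (level 0) and ZeRO-1/2/3.

    Assumes FP16 parameters (2 B), FP16 gradients (2 B) and `optimizer_bytes`
    of FP32 optimizer state per parameter (master weights + two Adam moments).
    """
    n = num_gpus
    params, grads, opt = 2 * num_params, 2 * num_params, optimizer_bytes * num_params
    if level == 0:    # plain data parallelism: everything replicated
        total = params + grads + opt
    elif level == 1:  # partition optimizer states
        total = params + grads + opt / n
    elif level == 2:  # + partition gradients
        total = params + (grads + opt) / n
    elif level == 3:  # + partition parameters
        total = (params + grads + opt) / n
    else:
        raise ValueError("level must be 0, 1, 2 or 3")
    return total / 1e9


labels = ["Plain DP", "ZeRO-1", "ZeRO-2", "ZeRO-3"]
for level, label in enumerate(labels):
    print(f"{label}: {zero_per_gpu_gb(7.5e9, 64, level):.1f} GB per GPU")
```

For a 7.5B-parameter model on 64 GPUs this prints 120.0, 31.4, 16.6 and 1.9 GB per GPU, matching the worked example in the ZeRO paper. ZeRO-3 only reaches its number by fetching parameter shards from other GPUs during each step, which is where its extra communication comes from.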
Model parallelism
Idea: split the model itself across GPUs (orthogonal to data parallelism).
Used when the model is too big to fit on one GPU, even briefly.
| Variant | What it splits |
|---|---|
| Tensor parallelism | Large matrix multiplications (cut across GPUs, slices combined) |
| Pipeline parallelism | The layer stack (GPU 1 holds early layers, GPU 2 holds middle layers, etc.) |
| Expert parallelism | MoE experts (specific to mixture-of-experts architectures) |
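A NumPy sketch of the two usual ways to cut one matrix multiplication Y = XW for tensor parallelism, simulated in a single process (the "GPUs" are list entries; the shapes and shard count are arbitrary assumptions). Splitting W by columns lets each device produce a slice of Y, and the slices are concatenated. Splitting W by rows (and X by columns) makes each device produce a partial sum, and the partials are added, which on real hardware is an all-reduce.

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.normal(size=(4, 6))  # activations: 4 tokens, hidden size 6
w = rng.normal(size=(6, 8))  # weight matrix to be sharded
num_gpus = 2

# Column parallelism: each GPU holds a slice of W's columns and computes a slice of Y.
w_col_shards = np.split(w, num_gpus, axis=1)
y_col = np.concatenate([x @ w_i for w_i in w_col_shards], axis=1)

# Row parallelism: each GPU holds a slice of W's rows plus the matching columns of X,
# computes a partial product, and the partials are summed (an all-reduce on real GPUs).
w_row_shards = np.split(w, num_gpus, axis=0)
x_col_shards = np.split(x, num_gpus, axis=1)
y_row = sum(x_i @ w_i for x_i, w_i in zip(x_col_shards, w_row_shards))

reference = x @ w
print(np.allclose(y_col, reference), np.allclose(y_row, reference))  # True True
```

Megatron-LM-style transformer blocks pair the two: a column-parallel matmul followed by a row-parallel one, so the intermediate activations stay sharded and only one all-reduce is needed per pair.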
Frontier runs combine multiple flavors: typically ZeRO-3 + tensor parallelism + pipeline parallelism, each addressing a different bottleneck.
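One way to picture the combination is a 3D grid of GPUs. The sketch below assigns each global rank a (data, pipeline, tensor) coordinate for hypothetical degrees TP = 2, PP = 2, DP = 4 on 16 GPUs. Putting tensor-parallel ranks innermost, so they tend to share a node's fast interconnect, is a common convention but an assumption here; frameworks differ in the details. ZeRO would then partition state within each data-parallel group.

```python
from itertools import product


def rank_to_coords(rank: int, tp: int, pp: int, dp: int) -> tuple[int, int, int]:
    """Map a global rank to (dp_index, pp_index, tp_index), tensor-parallel fastest-varying."""
    assert 0 <= rank < tp * pp * dp
    tp_idx = rank % tp
    pp_idx = (rank // tp) % pp
    dp_idx = rank // (tp * pp)
    return dp_idx, pp_idx, tp_idx


TP, PP, DP = 2, 2, 4  # hypothetical degrees; world size = 2 * 2 * 4 = 16 GPUs
world = TP * PP * DP

coords = {r: rank_to_coords(r, TP, PP, DP) for r in range(world)}
# Every (dp, pp, tp) combination is covered exactly once.
assert sorted(coords.values()) == sorted(product(range(DP), range(PP), range(TP)))

# GPUs that average gradients together: same pipeline stage and same tensor slice.
dp_group_of_rank0 = [r for r, (d, p, t) in coords.items() if (p, t) == coords[0][1:]]
print(dp_group_of_rank0)  # [0, 4, 8, 12]
```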
Flash Attention
Two GPU memories:
- HBM: ~tens of GB, ~a few TB/s (the "GPU memory" on a spec sheet)
- SRAM: ~tens of MB, ~tens of TB/s, on-chip next to compute (~10x faster)
Standard attention: read/write Q, K, V, QK^T, softmax(QK^T) to/from HBM many times. Bottleneck is HBM bandwidth, not compute.
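Why the HBM traffic is so heavy: the score matrix QK^T has n × n entries per head, so its size grows quadratically with context length n. A quick calculation, assuming FP16 scores (2 bytes per entry) and counting a single head of a single sequence:

```python
BYTES_PER_ENTRY = 2  # assumption: FP16 attention scores


def score_matrix_mib(context_length: int) -> float:
    """Size in MiB of one n x n attention score matrix (one head, one sequence)."""
    return context_length ** 2 * BYTES_PER_ENTRY / 2**20


for n in (2_048, 32_768, 131_072):
    print(f"n = {n:>7,}: {score_matrix_mib(n):>9,.0f} MiB per head")
```

At 2,048 tokens the matrix is 8 MiB; at 131,072 tokens it is 32,768 MiB (32 GiB) for one head, far beyond SRAM and a large share of HBM, and standard attention writes and re-reads it (and its softmax) through HBM.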
Flash Attention: tile the computation; each tile fits in SRAM; do matmul + partial softmax + multiply by V end-to-end in SRAM; write back to HBM only once.
Mathematically exact: same output as standard attention (up to floating-point rounding).
Speedup is entirely from data movement.
| Property | Detail |
|---|---|
| Origin | Tri Dao and collaborators at Stanford, 2022 |
| Core technique | Tiling |
| Mathematical trick | Softmax can be computed block by block: keep a running max and running sum per row, and rescale earlier partial results by a single factor whenever the max changes |
| Result | Materially faster, especially at long context lengths |
| Why it matters | Long context windows (tens to hundreds of thousands of tokens) became practical largely because of this technique |
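A NumPy sketch of the tiling idea for a single attention head, forward pass only. It streams over blocks of K and V while keeping, per query row, a running max `m`, a running softmax denominator `l`, and an unnormalized output accumulator, rescaling earlier partial results whenever the max grows. The block size, the 1/√d score scaling, and deferring the division by `l` to the end are choices of this sketch; a real kernel does this bookkeeping in on-chip SRAM (which NumPy cannot show) and also tiles over Q. The final check illustrates the "mathematically exact" claim.

```python
import numpy as np


def standard_attention(q, k, v):
    """Materializes the full n x n score matrix, like the naive HBM-bound version."""
    s = q @ k.T / np.sqrt(q.shape[-1])
    p = np.exp(s - s.max(axis=1, keepdims=True))
    return (p / p.sum(axis=1, keepdims=True)) @ v


def tiled_attention(q, k, v, block_size: int = 16):
    """Processes K/V one block at a time with an online (running) softmax."""
    n, d = q.shape
    scale = 1.0 / np.sqrt(d)
    m = np.full(n, -np.inf)          # running max of scores per query row
    l = np.zeros(n)                  # running sum of exp(score - m) per row
    acc = np.zeros((n, v.shape[1]))  # running unnormalized output per row

    for start in range(0, k.shape[0], block_size):
        k_blk = k[start:start + block_size]
        v_blk = v[start:start + block_size]
        s = q @ k_blk.T * scale               # scores for this block only
        m_new = np.maximum(m, s.max(axis=1))  # updated running max
        correction = np.exp(m - m_new)        # rescales earlier partial results
        p = np.exp(s - m_new[:, None])        # block's unnormalized probabilities
        l = l * correction + p.sum(axis=1)
        acc = acc * correction[:, None] + p @ v_blk
        m = m_new

    return acc / l[:, None]                   # normalize once at the end


rng = np.random.default_rng(0)
q, k, v = (rng.normal(size=(128, 32)) for _ in range(3))
print(np.allclose(tiled_attention(q, k, v), standard_attention(q, k, v)))  # True
```

On the first block `m` is minus infinity, so `correction` is exactly 0 and the empty accumulators are simply replaced; after that, each block costs one rescale instead of a re-read of all earlier scores.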
Why this matters when you use AI
| Phenomenon | What it tells you |
|---|---|
| “Trained on a very large GPU cluster” | The data parallelism + ZeRO + model parallelism stack makes tractable a run that would otherwise be impossible. The cluster size is a real moat. |
| “Long context windows” | Largely a Flash Attention story. Not a smarter algorithm: the same exact attention computation, made fast and memory-efficient enough at long sequence lengths by the SRAM-vs-HBM trick. |
| “Only a handful of labs train frontier models” | Partly the hardware investment, partly the operational complexity of running these techniques together at scale. |
Pitfalls to dodge
| Pitfall | Reality |
|---|---|
| Data parallelism gives linear speedup | No. Communication cost grows with GPU count; there is a regime where adding GPUs hurts. |
| ZeRO is separate from data parallelism | No. ZeRO is data parallelism with the redundant copies partitioned across GPUs. |
| Model and data parallelism are alternatives | No. They address different bottlenecks and are usually combined in frontier runs. |
| Flash Attention is faster because it skips work | No. Same output as standard attention. Speedup is entirely from fewer HBM read/write operations. |
| Knowing the techniques means you can train a frontier model | No. Operational layers (hardware acquisition, data pipeline, electricity, debugging multi-week runs) are enormous on top. |
Glossary
- Activation: the value at a layer during the forward pass; needed by the backward pass to compute gradients. Memory cost grows with model size, batch size, and (with standard attention) the square of context length.
- Adam (optimizer): standard training optimizer. Tracks first and second moments per parameter, roughly 2x parameter memory.
- Data parallelism (DP): distribute the training batch across GPUs; each GPU has a full model copy.
- ZeRO (Zero Redundancy Optimizer): data parallelism with the duplicated copies of optimizer states / gradients / parameters partitioned across GPUs. Three increasing levels.
- Model parallelism: distribute the model itself across GPUs. Variants: tensor, pipeline, expert.
- Tensor parallelism: cut large matrix multiplications across GPUs.
- Pipeline parallelism: distribute layers across GPUs; the batch, split into micro-batches, flows through the stages like an assembly line.
- HBM (high-bandwidth memory): the GPU memory on a spec sheet. Tens of GB, a few TB/s.
- SRAM: on-chip GPU memory next to compute. Tens of MB, tens of TB/s, roughly 10x faster than HBM.
- Flash Attention: a 2022 technique from Stanford that computes attention by tiling it into SRAM-sized blocks. Mathematically exact, materially faster.
Pretraining at scale is a memory engineering problem.
Parallelism distributes memory across many GPUs.
Flash Attention rearranges memory inside one GPU.