Cheatsheet: Why pretraining is a memory engineering problem (parallelism and Flash Attention)

Pretraining memory must hold:
parameters + gradients + optimizer states (Adam: 2 moments per param)
+ activations from the forward pass
A single 80 GB GPU cannot hold this for any frontier model.
The field's answer: distribute memory across GPUs (parallelism),
and rearrange memory inside one GPU (Flash Attention).
| What | Detail |
| --- | --- |
| Parameters | The model weights themselves |
| Gradients | One per parameter, computed during the backward pass |
| Optimizer states (Adam) | First and second moments per parameter; ~2x parameter memory in the same precision, more in mixed precision (moments are typically stored in FP32 even when weights are FP16) |
| Activations | Per-layer values from the forward pass (needed for the backward pass). Grow with model size and batch size, and quadratically, O(n²), with context length under standard attention |
| Hardware reference | NVIDIA H100 has 80 GB. “Tens of gigabytes per GPU” is the order of magnitude. |
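
To make the table concrete, here is a rough back-of-the-envelope sketch of model-state memory (everything except activations). The byte counts are an assumption, not something the table above specifies: they follow a common mixed-precision layout with FP16 weights and gradients plus FP32 master weights and two FP32 Adam moments, for 16 bytes per parameter. The function name `model_state_gb` and the model sizes are illustrative.

```python
# Rough model-state memory under mixed-precision Adam (activations excluded).
# ASSUMPTION: a common mixed-precision layout, 16 bytes per parameter in total.
BYTES_PER_PARAM = {
    "fp16_weights": 2,
    "fp16_gradients": 2,
    "fp32_master_weights": 4,      # assumed FP32 copy kept for the optimizer step
    "fp32_adam_first_moment": 4,
    "fp32_adam_second_moment": 4,
}


def model_state_gb(num_params: float) -> float:
    """Return model-state memory in GB (10^9 bytes) for num_params parameters."""
    return num_params * sum(BYTES_PER_PARAM.values()) / 1e9


if __name__ == "__main__":
    # Illustrative model sizes, compared against one 80 GB GPU.
    for billions in (1, 7, 70):
        gb = model_state_gb(billions * 1e9)
        print(f"{billions:>3}B params: ~{gb:,.0f} GB of model states "
              f"({gb / 80:.1f}x one 80 GB GPU)")
```

Under these assumptions even a 7B-parameter model needs more than one 80 GB GPU for model states alone, before any activations are counted.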
Idea (data parallelism): split the batch across GPUs; every GPU has a full model copy.
Gradients are averaged across GPUs before each weight update.
Reduces: activation memory (each GPU sees only part of the batch)
Does NOT reduce: parameter / gradient / optimizer-state memory
(each GPU still has the whole thing)
Communication cost: grows with GPU count; scaling becomes sub-linear beyond some GPU count
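
The gradient-averaging step can be sketched in a few lines. This is a toy, single-process illustration under stated assumptions: a one-parameter linear model with mean-squared-error loss, equal-sized shards, and a plain Python average standing in for the all-reduce that real GPUs would perform. The names `grad_mse` and `data_parallel_grad` are hypothetical.

```python
def grad_mse(w, xs, ys):
    """Gradient of mean((w*x - y)^2) with respect to the single weight w."""
    return sum(2 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)


def data_parallel_grad(w, xs, ys, num_gpus):
    """Split the batch into equal shards, compute one gradient per 'GPU', average."""
    shard = len(xs) // num_gpus
    assert shard * num_gpus == len(xs), "batch must divide evenly across GPUs"
    local_grads = [
        grad_mse(w, xs[i * shard:(i + 1) * shard], ys[i * shard:(i + 1) * shard])
        for i in range(num_gpus)
    ]
    # Stand-in for the all-reduce average performed across real GPUs.
    return sum(local_grads) / num_gpus


xs = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]
ys = [2.0, 4.0, 6.0, 8.0, 10.0, 12.0, 14.0, 16.0]
w = 0.5

full = grad_mse(w, xs, ys)                             # one GPU, whole batch
sharded = data_parallel_grad(w, xs, ys, num_gpus=4)    # four GPUs, two samples each
print(full, sharded)
```

With equal shard sizes the averaged gradient matches the full-batch gradient up to floating-point rounding, which is why every replica can apply the same weight update.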
ZeRO = Zero Redundancy Optimizer. Layered on top of data parallelism.
Removes the duplicated copies of high-memory quantities across GPUs.
| Level | Partitions | Memory savings | Communication cost |
| --- | --- | --- | --- |
| ZeRO-1 | Optimizer states | Big (states are ~2x parameter memory) | Modest |
| ZeRO-2 | + Gradients | Bigger | Larger |
| ZeRO-3 | + Parameters | Biggest | Largest (per-step parameter fetches) |
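
The three levels can be turned into a rough per-GPU memory estimate. This is a sketch under assumptions, not a measurement: the default byte counts (2 bytes each for FP16 weights and gradients, 12 bytes per parameter of FP32 optimizer state) and the example of 7.5B parameters on 64 GPUs are illustrative choices, and the function name `zero_bytes_per_gpu` is hypothetical. Activations, buffers, and fragmentation are ignored.

```python
def zero_bytes_per_gpu(num_params, num_gpus, stage,
                       weight_bytes=2, grad_bytes=2, optim_bytes=12):
    """Per-GPU model-state bytes under data parallelism with ZeRO stage 0-3.

    Stage 0 is plain data parallelism (everything replicated on every GPU).
    ASSUMPTION: default byte counts are FP16 weights and gradients plus
    12 bytes/param of FP32 optimizer state (master weights + 2 Adam moments).
    """
    if stage not in (0, 1, 2, 3):
        raise ValueError("stage must be 0, 1, 2, or 3")
    weights = weight_bytes * num_params
    grads = grad_bytes * num_params
    optim = optim_bytes * num_params
    if stage >= 1:
        optim /= num_gpus      # ZeRO-1: partition optimizer states
    if stage >= 2:
        grads /= num_gpus      # ZeRO-2: also partition gradients
    if stage >= 3:
        weights /= num_gpus    # ZeRO-3: also partition parameters
    return weights + grads + optim


if __name__ == "__main__":
    for stage in range(4):
        gb = zero_bytes_per_gpu(7.5e9, num_gpus=64, stage=stage) / 1e9
        print(f"stage {stage}: ~{gb:.1f} GB per GPU")
```

The estimate only covers the memory column of the table; the communication column is not modeled here.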
Idea (model parallelism): split the model itself across GPUs (orthogonal to data parallelism).
Used when the model is too big for one GPU even briefly.
| Variant | What it splits |
| --- | --- |
| Tensor parallelism | Large matrix multiplications (cut across GPUs, slices combined) |
| Pipeline parallelism | The layer stack (GPU 1 holds early layers, GPU 2 holds middle layers, etc.) |
| Expert parallelism | MoE experts (specific to mixture-of-experts architectures) |
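
As an illustration of the tensor-parallelism row, the sketch below splits a weight matrix by columns, multiplies each slice separately, and concatenates the results. It is a single-process toy under assumptions: lists of rows stand in for tensors, a loop stands in for separate GPUs, and column splitting is only one of several ways a matmul can be cut. The names `matmul` and `column_parallel_matmul` are hypothetical.

```python
def matmul(a, b):
    """Plain matrix product of two matrices given as lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def column_parallel_matmul(x, w, num_gpus):
    """Split W by columns, multiply each slice separately, concatenate the results."""
    cols = len(w[0])
    assert cols % num_gpus == 0, "columns must divide evenly across GPUs"
    width = cols // num_gpus
    outputs = []
    for g in range(num_gpus):
        # The column shard of W that simulated GPU g would hold.
        w_slice = [row[g * width:(g + 1) * width] for row in w]
        outputs.append(matmul(x, w_slice))
    # Concatenate per-GPU outputs along the column axis
    # (on real hardware this would be a gather across devices).
    return [sum((out[i] for out in outputs), []) for i in range(len(x))]


x = [[1.0, 2.0], [3.0, 4.0]]
w = [[1.0, 0.0, 2.0, 1.0],
     [0.0, 1.0, 1.0, 3.0]]

full = matmul(x, w)
sharded = column_parallel_matmul(x, w, num_gpus=2)
print(full == sharded)
```

Each simulated GPU only ever holds its own slice of `w`, which is the memory saving; the price is the communication needed to combine the slices.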

Frontier runs combine multiple flavors: typically ZeRO-3 + tensor parallelism + pipeline parallelism, each addressing a different bottleneck.

Two GPU memories:
HBM = ~tens of GB, ~few TB/s ("GPU memory" on a spec sheet)
SRAM = ~tens of MB, ~tens of TB/s, on-chip next to compute (~10x faster)
Standard attention: read/write Q, K, V, QK^T, softmax(QK^T) to/from HBM
many times. Bottleneck is HBM bandwidth, not compute.
Flash Attention: tile the computation; each tile fits in SRAM;
do matmul + partial softmax + multiply by V end-to-end
in SRAM; write back to HBM only once.
Mathematically exact. Same output as standard attention.
Speedup is entirely from data movement.
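
To see why the intermediate QK^T matrix is the problem, a quick size estimate helps. This is a sketch under assumptions: 32 attention heads, batch size 1, and 2 bytes per element are illustrative values, and the function name `score_matrix_bytes` is hypothetical.

```python
def score_matrix_bytes(seq_len, num_heads, batch_size=1, bytes_per_elem=2):
    """Bytes for one layer's full n x n attention score matrix across all heads."""
    return batch_size * num_heads * seq_len * seq_len * bytes_per_elem


if __name__ == "__main__":
    # ASSUMPTION: 32 heads, FP16 scores, batch size 1 (illustrative only).
    for n in (2_048, 32_768, 131_072):
        gib = score_matrix_bytes(n, num_heads=32) / 2**30
        print(f"n = {n:>7}: {gib:,.2f} GiB per layer")
```

Under these assumptions the score matrix for a single layer outgrows an 80 GB GPU well before the longest sequence length shown, which is why avoiding its materialization in HBM matters.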
| Property | Detail |
| --- | --- |
| Origin | Tri Dao and collaborators at Stanford, 2022 |
| Core technique | Tiling |
| Mathematical trick | Softmax can be computed block by block by carrying a running maximum and running sum, and rescaling earlier partial results as each new block arrives |
| Result | Materially faster, especially at long context lengths |
| Why it matters | Long context windows (tens to hundreds of thousands of tokens) became practical largely because of this technique |
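
The block-by-block softmax can be checked numerically. The sketch below is a simplified illustration, not the actual Flash Attention kernel: it tiles only over keys and values, runs in plain Python with no notion of SRAM or HBM, and omits the usual 1/sqrt(d) score scaling in both versions. The function names and the small test matrices are hypothetical.

```python
import math


def standard_attention(q, k, v):
    """Reference: build each full score row, softmax it, then weight V."""
    out = []
    for qi in q:
        scores = [sum(a * b for a, b in zip(qi, kj)) for kj in k]
        top = max(scores)
        weights = [math.exp(s - top) for s in scores]
        total = sum(weights)
        out.append([sum(w * vj[d] for w, vj in zip(weights, v)) / total
                    for d in range(len(v[0]))])
    return out


def tiled_attention(q, k, v, block_size):
    """Visit K/V one block at a time, keeping a running max, sum, and accumulator."""
    dim_v = len(v[0])
    out = []
    for qi in q:
        running_max = -math.inf
        running_sum = 0.0
        acc = [0.0] * dim_v
        for start in range(0, len(k), block_size):
            k_blk = k[start:start + block_size]
            v_blk = v[start:start + block_size]
            scores = [sum(a * b for a, b in zip(qi, kj)) for kj in k_blk]
            new_max = max(running_max, max(scores))
            # Rescale everything accumulated under the old max to the new max.
            scale = math.exp(running_max - new_max)
            weights = [math.exp(s - new_max) for s in scores]
            running_sum = running_sum * scale + sum(weights)
            acc = [a * scale + sum(w * vj[d] for w, vj in zip(weights, v_blk))
                   for d, a in enumerate(acc)]
            running_max = new_max
        out.append([a / running_sum for a in acc])
    return out


q = [[0.1, 0.2], [0.3, -0.1]]
k = [[0.5, 0.1], [-0.2, 0.4], [0.3, 0.3], [0.0, -0.5], [0.7, 0.2]]
v = [[1.0, 0.0], [0.0, 1.0], [0.5, 0.5], [2.0, -1.0], [-1.0, 1.0]]

ref = standard_attention(q, k, v)
tiled = tiled_attention(q, k, v, block_size=2)
max_diff = max(abs(a - b) for r, t in zip(ref, tiled) for a, b in zip(r, t))
print(max_diff)
```

The two outputs agree up to floating-point rounding, and the tiled version never holds more than one block of scores at a time. That property is what lets the real kernel keep its working set in SRAM.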
| Phenomenon | What it tells you |
| --- | --- |
| “Trained on a very large GPU cluster” | A stack of data parallelism + ZeRO + model parallelism is making something tractable that would otherwise be impossible. The cluster size is a real moat. |
| “Long context windows” | Largely a Flash Attention story. The attention math is unchanged; the same computation becomes feasible at long sequence lengths because of the SRAM-vs-HBM trick. |
| “Only a handful of labs train frontier models” | Partly the hardware investment, partly the operational complexity of running these techniques together at scale. |
| Pitfall | Reality |
| --- | --- |
| Data parallelism gives linear speedup | No. Communication cost grows with GPU count; there is a regime where adding GPUs hurts. |
| ZeRO is separate from data parallelism | No. ZeRO is data parallelism with the redundant copies partitioned across GPUs. |
| Model and data parallelism are alternatives | No. They address different bottlenecks and are usually combined in frontier runs. |
| Flash Attention is faster because it skips work | No. Same output as standard attention. Speedup is entirely from fewer HBM read/write operations. |
| Knowing the techniques means you can train a frontier model | No. Operational layers (hardware acquisition, data pipeline, electricity, debugging multi-week runs) are enormous on top. |
  • Activation: the value at a layer during the forward pass; needed by the backward pass to compute gradients. Memory cost grows with model size, batch size, and the square of context length.
  • Adam (optimizer): standard training optimizer. Tracks first and second moments per parameter, roughly 2x parameter memory.
  • Data parallelism (DP): distribute the training batch across GPUs; each GPU has a full model copy.
  • ZeRO (Zero Redundancy Optimizer): data parallelism with the duplicated copies of optimizer states / gradients / parameters partitioned across GPUs. Three increasing levels.
  • Model parallelism: distribute the model itself across GPUs. Variants: tensor, pipeline, expert.
  • Tensor parallelism: cut large matrix multiplications across GPUs.
  • Pipeline parallelism: distribute layers across GPUs; batch flows through like an assembly line.
  • HBM (high-bandwidth memory): the GPU memory on a spec sheet. Tens of GB, a few TB/s.
  • SRAM: on-chip GPU memory next to compute. Tens of MB, tens of TB/s, roughly 10x faster than HBM.
  • Flash Attention: Stanford-2022 technique to compute attention by tiling into SRAM-sized blocks. Mathematically exact, materially faster.

Pretraining at scale is a memory engineering problem.
Parallelism distributes memory across many GPUs.
Flash Attention rearranges memory inside one GPU.