Lesson: Why pretraining is a memory engineering problem (parallelism and Flash Attention)
Attention is O(n²) in sequence length, the textbooks say. Modern decoder-only transformers have billions to hundreds of billions of parameters, the textbooks say. Both statements are mathematically true and operationally useless: they tell you almost nothing about whether your training run will actually fit on the hardware you have. On real GPUs, the attention computation is memory-bound rather than compute-bound, the parameters do not fit on one device, and even if they did, the activations from a forward pass would not. Pretraining at frontier scale is, more than anything else, a memory engineering problem.
The previous lesson gave the target: a 70-billion-parameter model wants about 1.4 trillion training tokens for Chinchilla-aligned pretraining. Stating the target is the easy part. This lesson is about how the run actually happens on real hardware. Four engineering tricks, in order: data parallelism, the ZeRO optimization on top of data parallelism, model parallelism, and Flash Attention. The first three distribute memory across many GPUs. The fourth uses the memory hierarchy within one GPU more cleverly. Together they explain how trillions of tokens of training are tractable in practice.
What has to fit in memory during training
To understand why memory is the bottleneck, walk through one training step.
The model is a stack of transformer blocks (everything you learned in Phase 2). For training, you initialize the weights and prepare to update them. One step has three parts.
Forward pass. You feed a batch of input data through the network. Each layer produces values, called activations, that you have to keep in memory because the backward pass is going to need them. The amount of memory you spend here scales with three things at once: the model size (wider and deeper networks produce more activations), the batch size (more examples in parallel means proportionally more activations), and the context length (every layer's activations grow with the number of tokens, and standard attention also stores a score matrix that is O(n²) in sequence length, so the attention layers are hit hardest).
Backward pass. You compute the loss (how wrong the model was on this batch), then walk backward through the network to find the gradient: the direction in which each weight should be nudged to reduce the loss. The gradients also live in memory while the step runs.
Weight update. A modern optimizer like Adam does not just apply the gradient to the weights. It also tracks two moments: a moving average of the gradient and a moving average of the squared gradient. Both have to be stored, per parameter. So for a model with N parameters, Adam needs roughly 2N additional numbers in memory beyond the weights and gradients themselves. (In practice, those moments are typically stored in FP32 precision even when the weights are FP16, and mixed-precision training usually keeps an FP32 master copy of the weights as well, so the optimizer-state memory cost is more than 2x the parameter memory; the next lesson covers precision in detail.)
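To make the two moments concrete, here is a minimal NumPy sketch of a single Adam update, assuming the commonly used default hyperparameters (learning rate 1e-3, beta1 0.9, beta2 0.999). The function name `adam_step` and the all-ones stand-in gradient are illustrative; real frameworks fuse this into optimized GPU kernels. The point is only that `m` and `v` are full-size arrays that live alongside the weights.

```python
import numpy as np

def adam_step(w, g, m, v, t, lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8):
    """One Adam update; m and v are the two per-parameter moments."""
    m = beta1 * m + (1 - beta1) * g        # moving average of the gradient
    v = beta2 * v + (1 - beta2) * g * g    # moving average of the squared gradient
    m_hat = m / (1 - beta1 ** t)           # bias correction for the early steps
    v_hat = v / (1 - beta2 ** t)
    w = w - lr * m_hat / (np.sqrt(v_hat) + eps)
    return w, m, v

n_params = 1_000
w = np.zeros(n_params, dtype=np.float32)
m = np.zeros_like(w)   # first moment: one extra number per parameter
v = np.zeros_like(w)   # second moment: another extra number per parameter
g = np.ones_like(w)    # stand-in gradient for the demo
w, m, v = adam_step(w, g, m, v, t=1)

# Optimizer state = two arrays, each as large as the weights themselves.
print((m.size + v.size) / w.size)   # 2.0
```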
Adding up: parameters, gradients, two optimizer moments per parameter, plus all the activations from the forward pass. The total grows fast.
Now look at the hardware. The Stanford lecturer’s reference is the NVIDIA H100, with 80 GB of memory. Other cards have less; some specialized cards have more. The order of magnitude is “tens of gigabytes per GPU.” That is what you have to fit everything into.
For a frontier-scale model, the total of all those quantities (parameters, gradients, optimizer states, activations) exceeds 80 GB by a wide margin. The lecturer’s framing: there is “a lot of things to save” against memory that “is not unlimited.” There is no way to fit one frontier-scale training step on a single GPU.
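The arithmetic is easy to sketch. The byte counts below assume the common mixed-precision Adam recipe used in the ZeRO paper's accounting (16-bit weights and gradients, plus FP32 master weights and FP32 moments, for 16 bytes per parameter); real frameworks vary, and activations are left out entirely, so the result is a lower bound.

```python
# Bytes per parameter under a common mixed-precision Adam recipe
# (an assumption; real frameworks differ in the details).
BYTES_PER_PARAM = {
    "weights (16-bit)": 2,
    "gradients (16-bit)": 2,
    "master weights (FP32)": 4,
    "Adam first moment (FP32)": 4,
    "Adam second moment (FP32)": 4,
}

def model_state_gb(n_params: float) -> float:
    """Memory for weights, gradients and optimizer state in GB, excluding activations."""
    return n_params * sum(BYTES_PER_PARAM.values()) / 1e9

gpu_gb = 80          # one H100
n_params = 70e9      # the 70B model from the previous lesson
total = model_state_gb(n_params)
print(f"model states: {total:,.0f} GB")
print(f"H100s needed just to hold them: {total / gpu_gb:.0f}")
```

Even before a single activation is stored, the 70B model needs about 1,120 GB, or fourteen H100s' worth of memory, just for its weights, gradients, and optimizer state.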
So we use many GPUs. The question becomes: how do you split the work across them? There are a few different answers to that question, each with different trade-offs.
Data parallelism: split the batch, copy the model
The first and simplest approach is data parallelism, sometimes abbreviated DP.
The idea: every GPU gets its own complete copy of the model. The training batch is divided across GPUs, and each GPU runs its forward and backward passes on its slice independently. This works because the computations on different slices of the batch do not depend on each other, right up until the gradients have to be combined.
The memory savings are real. The activation memory (which depends on batch size) is now divided across GPUs. The cluster is doing, in effect, a forward pass on a much smaller batch per device, just done many times in parallel.
The independence is not quite total. There is one synchronization point. After each GPU computes its gradient on its batch slice, those gradients have to be averaged across all the GPUs before any of them can apply the weight update. Otherwise each GPU would apply a slightly different update to its copy of the model, and the copies would drift apart. This averaging step (usually implemented as an all-reduce) is a communication cost: bytes moving between GPUs over the interconnect. Communication is slower than computation. The more GPUs you add, the more communication you incur.
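A small simulation shows why averaging is the right synchronization step. This sketch assumes a toy linear model with mean-squared-error loss, and the "GPUs" are just list entries in one NumPy process; on real hardware the averaging would be an all-reduce over the interconnect.

```python
import numpy as np

rng = np.random.default_rng(0)
X = rng.normal(size=(64, 5))   # one global batch of 64 examples
y = rng.normal(size=64)
w = rng.normal(size=5)         # every simulated GPU holds this same copy

def grad_mse(w, X, y):
    """Gradient of mean squared error for a linear model."""
    return 2 * X.T @ (X @ w - y) / len(y)

n_gpus = 4
# Each simulated GPU gets an equal slice of the batch and computes its own gradient.
local_grads = [grad_mse(w, Xs, ys)
               for Xs, ys in zip(np.array_split(X, n_gpus), np.array_split(y, n_gpus))]

# The synchronization point: average the local gradients.
avg_grad = np.mean(local_grads, axis=0)

# With equal slice sizes, the average equals the full-batch gradient, so every
# replica applies the same update and the copies stay identical.
print(np.allclose(avg_grad, grad_mse(w, X, y)))   # True
```

The equality relies on equal slice sizes; with uneven slices you would weight each local gradient by its slice size.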
Two things to know about data parallelism:
- It distributes the activation memory (the part that scales with batch size), not the model memory. Every GPU still needs its complete copy of the parameters, gradients, and optimizer states. So data parallelism only helps if your model fits on one GPU in the first place.
- The communication cost is real. You do not get linear speedup from N GPUs; you get N GPUs of throughput minus the cost of averaging gradients, which grows with N.
For models that do fit on one GPU (smaller models, or larger models in lower precision), data parallelism alone is often enough. For frontier models, it is not.
ZeRO: stop duplicating what you do not need to
The big inefficiency in plain data parallelism is duplication. Every GPU stores the full set of parameters, the full set of gradients, and the full set of optimizer states. That is the same information stored N times across N GPUs, taking N times the memory it needs to.
The ZeRO technique (Zero Redundancy Optimizer) addresses this by partitioning the duplicated quantities across GPUs instead of replicating them. The lecturer walks through three increasing levels.
- ZeRO-1 partitions the optimizer states across GPUs. Each GPU holds the optimizer’s first and second moments for only a slice of the parameters. The savings are substantial because optimizer states (two moments per parameter, so at least 2x the parameter memory, and more in mixed precision) are bigger than the parameters themselves.
- ZeRO-2 partitions the gradients as well. Each GPU only holds the gradient for the slice of parameters it is responsible for.
- ZeRO-3 partitions the parameters themselves. Each GPU only stores its slice of the model weights at any given time.
As you go from ZeRO-1 to ZeRO-3, the memory savings grow and the communication cost grows. ZeRO-3 in particular requires GPUs to fetch the parts of the model they need from each other on demand during both the forward and backward passes, which is bandwidth-hungry. Different runs pick different ZeRO levels depending on how memory-constrained versus communication-constrained the setup is.
The takeaway: ZeRO is not a separate paradigm from data parallelism; it is data parallelism with the duplication removed. When you read about a training run “using ZeRO-3,” it is data parallelism with optimizer states, gradients, and parameters all sharded.
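The memory side of that trade-off can be tabulated. The helper below is hypothetical and follows the byte accounting from the ZeRO paper as an assumption: 2 bytes per parameter for 16-bit weights, 2 for 16-bit gradients, and 12 for optimizer state (FP32 master weights plus Adam's two FP32 moments). It ignores activations and the communication cost entirely.

```python
def zero_gb_per_gpu(n_params, n_gpus, stage, k=12):
    """Per-GPU model-state memory in GB, excluding activations.
    Assumed accounting: 2 bytes/param for 16-bit weights, 2 for 16-bit gradients,
    and k = 12 for optimizer state (FP32 master weights + two FP32 Adam moments)."""
    weights, grads, optim = 2 * n_params, 2 * n_params, k * n_params
    if stage >= 1:
        optim /= n_gpus     # ZeRO-1: shard optimizer states
    if stage >= 2:
        grads /= n_gpus     # ZeRO-2: also shard gradients
    if stage >= 3:
        weights /= n_gpus   # ZeRO-3: also shard the parameters
    return (weights + grads + optim) / 1e9

for stage in range(4):
    label = "plain DP" if stage == 0 else f"ZeRO-{stage}"
    print(f"{label:>8}: {zero_gb_per_gpu(7.5e9, n_gpus=64, stage=stage):6.1f} GB per GPU")
```

For a 7.5B-parameter model on 64 GPUs, this goes from 120 GB per GPU with plain data parallelism to roughly 31.4, 16.6, and 1.9 GB for ZeRO-1, ZeRO-2, and ZeRO-3.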
Model parallelism: split the model itself
Data parallelism (with or without ZeRO) still has every GPU run the whole model on its own slice of the batch. Plain data parallelism needs the full model state to fit on one GPU; ZeRO-3 relaxes that, since it gathers the weights one layer at a time, but each layer’s full weights and the activations for each GPU’s examples still have to fit. If the model is too big for that, you need to split the model itself across GPUs. That is model parallelism.
The Stanford lecturer flags the main variants without going deep, and our coverage will match: knowing they exist matters more than memorizing the details.
- Tensor parallelism. Big matrix multiplications inside the model are cut across GPUs. Each GPU computes one slice of the multiplication; the slices are combined across the network. Used in most frontier-class training.
- Pipeline parallelism. The model’s layers are split across GPUs: GPU 1 holds layers 1-3, GPU 2 holds layers 4-6, and so on. A batch flows through the layers like an assembly line, usually cut into micro-batches so that every stage has work to do at the same time. Useful when the model is still too big after tensor parallelism has split each layer’s matrices: pipelining splits the stack of layers instead.
- Context parallelism (CP) and sequence parallelism (SP). When the input sequence itself is the bottleneck (as it is for the very long contexts of recent frontier models, some of which advertise windows of a million tokens or more), the sequence dimension gets split across GPUs. Each GPU holds a slice of the tokens; attention is computed in pieces that get combined across the network (Liu et al. 2024, arxiv 2411.01783; also documented in NVIDIA’s Megatron-LM context-parallelism notes). CP and SP are critical at long context, where neither tensor nor pipeline parallelism alone can fit the activation memory.
- Expert parallelism. Specific to mixture-of-experts (MoE) architectures (Phase 7 covers MoE). Different experts live on different GPUs; tokens get routed to the right expert’s GPU.
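Tensor parallelism is easiest to see on one matrix multiply. This NumPy sketch, with made-up sizes and four simulated GPUs, shows the two usual ways to cut a weight matrix: by columns, where each shard produces part of the output and the parts are concatenated, and by rows, where each shard produces a partial result that has to be summed across GPUs.

```python
import numpy as np

rng = np.random.default_rng(0)
X = rng.normal(size=(8, 16))    # activations: 8 tokens, hidden size 16
W = rng.normal(size=(16, 32))   # one layer's weight matrix
n_gpus = 4

# Column-parallel: each GPU holds a vertical slice of W and computes a slice of
# the output columns; gathering the slices (an all-gather) rebuilds the result.
col_shards = np.split(W, n_gpus, axis=1)
Y_col = np.concatenate([X @ Wi for Wi in col_shards], axis=1)

# Row-parallel: each GPU holds a horizontal slice of W plus the matching slice of
# X's columns; the partial products must be summed (an all-reduce).
row_shards = np.split(W, n_gpus, axis=0)
x_shards = np.split(X, n_gpus, axis=1)
Y_row = sum(Xi @ Wi for Xi, Wi in zip(x_shards, row_shards))

print(np.allclose(Y_col, X @ W), np.allclose(Y_row, X @ W))   # True True
```

The concatenation and the sum are exactly where the cross-GPU communication happens on real hardware.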
In practice, frontier training runs combine multiple flavors. A typical setup might use tensor parallelism inside each server node (where GPUs share the fastest interconnect) for the matrix multiplications, pipeline parallelism across nodes for layer distribution, and data parallelism, usually with some level of ZeRO-style sharding, across the resulting copies of the model. That is one of the reasons frontier training is operationally complex even when the theoretical idea (predict the next token at scale) is simple.
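A combined setup is ultimately a way of giving every GPU a coordinate along each parallelism axis. The helper below is hypothetical, and its rank ordering is an assumption (tensor-parallel ranks adjacent, so with 8 GPUs per node each tensor-parallel group sits inside one node); real frameworks choose their own orderings. With 64 GPUs, tensor parallelism of 8, and pipeline parallelism of 4, each model copy spans 32 GPUs, which leaves 2 data-parallel replicas.

```python
from itertools import product

def parallel_layout(world_size, tp, pp):
    """Map each global rank to (data, pipeline, tensor) coordinates.
    Assumed ordering: tensor-parallel ranks vary fastest, then pipeline stages,
    then data-parallel replicas."""
    assert world_size % (tp * pp) == 0, "world size must be divisible by tp * pp"
    dp = world_size // (tp * pp)
    layout = {}
    for rank, (d, p, t) in enumerate(product(range(dp), range(pp), range(tp))):
        layout[rank] = {"dp": d, "pp": p, "tp": t}
    return dp, layout

dp, layout = parallel_layout(world_size=64, tp=8, pp=4)
print(dp)           # 2
print(layout[9])    # {'dp': 0, 'pp': 1, 'tp': 1}
print(layout[63])   # {'dp': 1, 'pp': 3, 'tp': 7}
```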
Flash Attention: use the GPU memory hierarchy
So far the lesson has been about memory across multiple GPUs. The last technique is about memory inside a single GPU. It targets attention specifically, the operation Phase 2 was built around.
A GPU has two relevant kinds of memory.
- HBM (high-bandwidth memory) is the “GPU memory” you read about on a spec sheet. On an H100 it is 80 GB. Big, but relatively slow when measured in bandwidth: a few terabytes per second.
- SRAM is on-chip memory, sitting right next to the compute units. Tens of megabytes per GPU in total, split across the streaming multiprocessors at a few hundred kilobytes each. Much smaller, but roughly ten times faster: tens of terabytes per second.
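Those bandwidth figures are why attention ends up waiting on memory. A back-of-the-envelope roofline check, assuming approximate H100 SXM spec-sheet numbers (about 989 TFLOP/s of dense BF16 tensor-core compute and 3.35 TB/s of HBM bandwidth) and a rough count of 5 FLOPs per softmax entry, compares how much arithmetic a kernel does per byte it moves against what the chip needs to stay busy.

```python
# Approximate H100 SXM spec-sheet figures (assumptions; check your own hardware).
PEAK_FLOPS = 989e12          # dense BF16 tensor-core FLOP/s
HBM_BYTES_PER_S = 3.35e12    # HBM bandwidth in bytes/s

# A kernel keeps the compute units busy only if it performs at least this many
# FLOPs per byte it moves to or from HBM (the "ridge point" of a roofline plot).
ridge = PEAK_FLOPS / HBM_BYTES_PER_S

# Softmax over a 16-bit score matrix: read 2 bytes and write 2 bytes per entry,
# for roughly 5 FLOPs (max, subtract, exp, sum, divide). A rough estimate.
softmax_intensity = 5 / (2 + 2)

print(f"ridge point: {ridge:.0f} FLOPs/byte")
print(f"softmax:     {softmax_intensity:.2f} FLOPs/byte")
print(f"softmax can use at most {softmax_intensity / ridge:.2%} of peak tensor-core throughput")
```

The softmax step does about one FLOP per byte, hundreds of times below the break-even point, so the compute units sit idle waiting for HBM.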
The standard (“vanilla”) way to compute attention does not use this hierarchy well. The attention formula is softmax(QK^T / √d) · V, with three big matrices (queries, keys, values) and one giant intermediate matrix QK^T, which has n × n entries for a sequence of n tokens. The standard implementation reads Q and K from HBM, computes QK^T, writes the result back to HBM, reads the result again to compute the softmax, writes that back to HBM, reads it once more along with V to do the final multiply, and writes the output to HBM. Lots of reading from and writing to the slow memory.
Even though the GPU’s compute is enormously fast, the bottleneck is data movement between HBM and the compute units. The Stanford lecturer’s framing: “GPU is very very fast but you spend a lot of time just loading matrices from the memory.”
Why did the standard implementation do this? Because softmax needs the whole row of QK^T to compute correctly: each entry has to be normalized against all the other entries in its row to make the row sum to one.
So at first glance you appear to have a hard ordering. Compute the whole QK^T. Store it. Then do the softmax over rows. Then multiply by V. Each step waits for the previous one to finish writing to memory.
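Written out as code, the hard ordering looks like this. The NumPy sketch below is illustrative (the function name and sizes are assumptions); NumPy runs on the CPU, so the comments mark where a GPU implementation would round-trip through HBM.

```python
import math
import numpy as np

def naive_attention(Q, K, V):
    """Standard attention that materializes the full n x n score matrix.
    On a GPU, S and P would each be written to HBM and read back."""
    S = (Q @ K.T) / math.sqrt(Q.shape[1])          # step 1: all n x n scores
    P = np.exp(S - S.max(axis=1, keepdims=True))   # step 2: softmax needs whole rows
    P /= P.sum(axis=1, keepdims=True)              #   (max-subtraction keeps exp stable)
    return P @ V, S                                # step 3: multiply by V

rng = np.random.default_rng(0)
n, d = 1024, 64
Q, K, V = (rng.normal(size=(n, d)).astype(np.float32) for _ in range(3))
out, S = naive_attention(Q, K, V)
print(out.shape, S.shape, S.nbytes)   # (1024, 64) (1024, 1024) 4194304
```

Even at a modest 1,024 tokens, the FP32 score matrix is 4 MiB per head per sequence, and it grows with the square of the length.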
Flash Attention is the insight that you do not have to. Developed by Tri Dao and collaborators at Stanford in 2022, the technique relies on a mathematical trick about softmax: you can compute the softmax of a row block by block, as long as you keep a running maximum and a running sum for each row and rescale the partial results whenever a new block raises the maximum. That lets the computation be tiled: split into pieces small enough to fit in SRAM. Each tile gets read from HBM into SRAM, computed end-to-end (the matrix multiply, the partial softmax, the multiply against the value tile, all of it), and only the final output goes back to HBM.
The result: the n × n intermediate matrix is never written to HBM. Q, K, and V stream through SRAM in tiles, all the intermediate computation happens there, and only the output (plus a small set of per-row softmax statistics) is written back. Most of the data-movement bottleneck is removed. Same mathematical answer (Flash Attention is exact, not an approximation), but materially faster on real hardware.
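The running-maximum trick can be checked numerically. This NumPy sketch is a simplified, hypothetical illustration of the Flash Attention forward pass, not the real CUDA kernel: for brevity it tiles only the keys and values and processes all queries at once, whereas a real kernel also splits the queries so that each tile fits in SRAM. It compares the result against a plain softmax reference.

```python
import math
import numpy as np

def tiled_attention(Q, K, V, block=128):
    """Flash-Attention-style forward pass (NumPy sketch, not the real kernel).
    Keys/values are processed one block at a time; per query row we keep a
    running max m and running sum l, so the n x n matrix is never materialized."""
    n, d = Q.shape
    scale = 1.0 / math.sqrt(d)
    O = np.zeros_like(Q)        # unnormalized output accumulator
    m = np.full(n, -np.inf)     # running row maximum
    l = np.zeros(n)             # running row sum of exp(score - m)
    for start in range(0, K.shape[0], block):
        Kb, Vb = K[start:start + block], V[start:start + block]
        S = (Q @ Kb.T) * scale                 # scores for this key block only
        m_new = np.maximum(m, S.max(axis=1))
        P = np.exp(S - m_new[:, None])         # this block's softmax numerators
        correction = np.exp(m - m_new)         # rescale earlier partial results
        l = l * correction + P.sum(axis=1)
        O = O * correction[:, None] + P @ Vb
        m = m_new
    return O / l[:, None]

def reference_attention(Q, K, V):
    S = (Q @ K.T) / math.sqrt(Q.shape[1])
    P = np.exp(S - S.max(axis=1, keepdims=True))
    return (P / P.sum(axis=1, keepdims=True)) @ V

rng = np.random.default_rng(0)
Q, K, V = (rng.normal(size=(512, 64)) for _ in range(3))
print(np.allclose(tiled_attention(Q, K, V), reference_attention(Q, K, V)))   # True
```

The `correction` factor is the per-row rescaling step: whenever a new block raises a row's maximum, the earlier partial sums and outputs are scaled down to match, which is why the final answer is exact rather than approximate.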
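The reason this matters for pretraining specifically: attention is O(n²) in sequence length, so the cost of materializing that n × n matrix grows quickly with longer contexts. Flash Attention still does O(n²) arithmetic, but because it never stores the full matrix, the extra memory attention needs grows only linearly with n. That is what makes long-context training tractable in practice. Without it, the long-context era would be prohibitively expensive in HBM capacity and bandwidth.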
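The difference in memory growth is plain arithmetic. The sketch below assumes 16-bit attention scores for the standard path and two FP32 statistics per row (a running max and a running sum) for the Flash-style path, per head and per sequence; Q, K, V, and the output are needed either way and are left out.

```python
def attention_extra_bytes(n, bytes_per_score=2):
    """Extra per-head, per-sequence memory beyond Q, K, V and the output.
    Standard: the full n x n score matrix (16-bit scores assumed).
    Flash-style: a running max and a running sum per row (FP32 assumed)."""
    standard = n * n * bytes_per_score
    flash = n * 2 * 4
    return standard, flash

for n in (4_096, 32_768, 131_072):
    standard, flash = attention_extra_bytes(n)
    print(f"n={n:>7,}: standard {standard / 2**20:>8,.0f} MiB, flash {flash / 2**20:.2f} MiB")
```

At 131,072 tokens the standard score matrix alone is 32 GiB per head per sequence, while the per-row statistics take 1 MiB.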
Why this matters when you use AI
The four techniques in this lesson are invisible to the people who use the resulting models. You will never see “ZeRO-3” mentioned in a chat assistant’s response. But several user-facing facts about modern AI trace directly back here.
- Long context windows are largely a Flash Attention story. The long-context features that frontier assistants advertise (context windows of hundreds of thousands of tokens, and in some cases a million or more) became practical in part because Flash Attention fit the attention computation at those sequence lengths on the available hardware. Other ingredients matter too (position encoding choices, KV-cache tricks, context parallelism), but Flash Attention is the central memory-side reason long context shipped.
- “Frontier model” is partly a hardware-cluster story. When you read that a model was trained on a very large GPU cluster, that is the data-parallelism + ZeRO + model-parallelism stack making something tractable that would otherwise be impossible. The size of these training clusters is a real moat: there are only a handful of organizations with the hardware to run a frontier pretraining loop end-to-end.
- Smaller models can be trained much more cheaply. If your model fits on one GPU, or on a few with ZeRO-style sharding, data parallelism without any model-parallel tricks is enough. Most fine-tuning runs (Phase 4 territory), most academic experiments, and most enterprise deployments work this way. The complexity of this lesson is about pretraining specifically; downstream stages are much simpler.
Common pitfalls
Here are a few mistakes worth naming up front; it is faster than catching them later.
“Data parallelism speeds up training linearly with GPU count.” It does not. Each new GPU you add increases the amount of communication needed to aggregate gradients across the cluster. There is a regime where communication dominates and adding GPUs hurts more than it helps. Frontier training picks GPU counts and topologies carefully to stay below that threshold.
“ZeRO is a separate technique from data parallelism.” It is not. ZeRO is a memory optimization layered on top of data parallelism. Same partitioning of data across GPUs; the addition is that the redundant copies of parameters, gradients, and optimizer states are also partitioned.
“Model parallelism and data parallelism are alternatives.” They are usually combined. A real frontier training setup typically runs data parallelism with ZeRO-style sharding, plus tensor parallelism (matrix-multiplication splitting), plus pipeline parallelism (layer distribution) all at once, with each addressing a different bottleneck.
“Flash Attention is faster because it skips computation.” No. Flash Attention is mathematically equivalent to standard attention; it produces the same output. The speedup is entirely from data movement: fewer reads and writes between HBM and the compute units, by way of tiling the computation into SRAM-sized blocks.
“Once you understand the techniques, you can train a frontier model.” The techniques are only one layer of what frontier training requires. The other layers (acquiring the hardware, building the data pipeline, paying the electricity bill, debugging multi-week runs that fail at hour 700) are operationally enormous.
What you should remember
- Memory is the bottleneck during training. You have to fit parameters, gradients, optimizer states (Adam tracks two moments per parameter, so roughly 2x extra), and all the activations from the forward pass. A single H100 with 80 GB cannot hold this for any frontier-scale model.
- Data parallelism splits the batch across GPUs. Each GPU has its own full copy of the model and works on a slice of the batch. Gradients are averaged across GPUs before the weight update. The communication cost grows with the number of GPUs.
- ZeRO removes the duplication in data parallelism. ZeRO-1 partitions optimizer states, ZeRO-2 also partitions gradients, ZeRO-3 also partitions parameters. More memory savings, more communication cost, as you go up the levels.
- Model parallelism splits the model itself across GPUs. Tensor parallelism cuts matrix multiplications, pipeline parallelism cuts the layer stack, context parallelism cuts the sequence, expert parallelism cuts MoE experts. Frontier runs combine multiple flavors.
- Flash Attention is a memory-hierarchy trick inside one GPU. Tile the attention computation into SRAM-sized blocks, do each tile end-to-end in fast SRAM, write back to HBM once. Mathematically exact, not an approximation. Several times faster on long sequences in practice.
Parallelism distributes the bytes across GPUs. The next lesson handles the other half of the problem: making each byte smaller. The two are complementary axes, not a sequence. Quantization and mixed precision let the same model fit in much less memory by storing weights in lower-precision formats than the standard 32-bit floating point; the next lesson also shows why, in practice, mixed precision pushes the optimizer-state cost mentioned earlier above the 2x baseline. After that, Phase 3 closes and the curriculum moves to what happens to a pretrained base model after training.
If you remember one thing
Pretraining at scale is a memory engineering problem.
Parallelism distributes memory across many GPUs.
Flash Attention rearranges memory inside one GPU.