
Why pretraining is a memory engineering problem (parallelism and Flash Attention)

By the end of the previous lesson, you knew the target: a 70-billion-parameter model wants about 1.4 trillion training tokens. What you did not yet know is how that run actually happens on real hardware. The model does not fit on one GPU. The activations from a forward pass do not fit either. And attention itself, the operation Phase 2 was built around, turns out on real hardware to be limited by memory traffic rather than by arithmetic, something the textbook math does not warn you about. This lesson covers the four engineering techniques that solve those problems: data parallelism distributes the batch across GPUs; the ZeRO optimization stops every GPU from storing its own redundant copy of the training state (optimizer states, gradients, and parameters); model parallelism splits the model itself across GPUs; and Flash Attention rearranges the attention computation to make better use of the GPU's memory hierarchy. The first three spread memory across many GPUs; the fourth uses the memory inside a single GPU more efficiently.
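To see why the model does not fit, it helps to count bytes. The sketch below assumes mixed-precision training with the Adam optimizer and uses the accounting from the ZeRO paper: 2 bytes per parameter for 16-bit weights, 2 for 16-bit gradients, and 12 for the fp32 optimizer states (a master copy of the weights plus Adam's momentum and variance), or 16 bytes per parameter in total. It ignores activations, communication buffers, and fragmentation. The function name `model_state_gb` and the group of 64 GPUs are illustrative assumptions, not part of any library.

```python
# Sketch: per-GPU memory for model states (activations ignored) under mixed-precision Adam.
# Assumption: ZeRO-paper accounting of 16 bytes per parameter.

def model_state_gb(n_params, n_gpus=1, zero_stage=0):
    """Return per-GPU gigabytes for weights, gradients, and optimizer states."""
    weights = 2 * n_params   # 16-bit (fp16/bf16) weights
    grads = 2 * n_params     # 16-bit gradients
    optim = 12 * n_params    # fp32 master weights + Adam momentum + Adam variance
    if zero_stage >= 1:
        optim /= n_gpus      # ZeRO-1: each GPU keeps 1/n of the optimizer states
    if zero_stage >= 2:
        grads /= n_gpus      # ZeRO-2: ...and 1/n of the gradients
    if zero_stage >= 3:
        weights /= n_gpus    # ZeRO-3: ...and 1/n of the weights themselves
    return (weights + grads + optim) / 1e9


if __name__ == "__main__":
    n_params = 70e9          # the 70B model from the previous lesson
    n_gpus = 64              # hypothetical data-parallel group size
    for stage in range(4):
        gb = model_state_gb(n_params, n_gpus, stage)
        label = "no ZeRO (plain data parallelism)" if stage == 0 else f"ZeRO-{stage}"
        print(f"{label}: {gb:.1f} GB per GPU")
    # Prints 1120.0, 293.1, 155.3, and 17.5 GB for the four settings.
```

With no partitioning, the 70B model's training state is about 1,120 GB on every GPU, roughly 14 times the capacity of an 80 GB GPU. Each ZeRO stage partitions one more redundant quantity across the data-parallel group, and in this accounting only ZeRO-3 brings the per-GPU share under 80 GB, before any activations are counted.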

This is lesson 3 of Phase 3, How models are trained at scale. Phase 3 builds toward describing what it takes to train a frontier model and why most organizations cannot. This lesson takes the compute-and-data target from the previous lesson, Why scale matters: scaling laws and Chinchilla, and explains the engineering layer that makes that target reachable in practice.

Prerequisites: the scaling laws and Chinchilla lesson (Phase 3 lesson 2). You should be comfortable with FLOPs as the cost unit, the Kaplan empirical claims, and the Chinchilla rule. No GPU programming background needed; the lesson works at the conceptual level.

  • Describe what data parallelism contributes to training and the communication cost it incurs
  • Distinguish ZeRO-1, ZeRO-2, and ZeRO-3 by which redundant quantity each one partitions across GPUs
  • Explain what model parallelism is and name its three common variants
  • Describe how Flash Attention uses the GPU memory hierarchy to remove the data-movement bottleneck of standard attention
  • Read time: about 26 minutes (the longest lesson in Phase 3, because it covers four engineering concepts at once)
  • Practice time: about 18 minutes (a memory-budgeting exercise plus flashcards)
  • Difficulty: standard
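The last objective above concerns Flash Attention. Its central arithmetic trick is that softmax can be computed block by block: process the keys and values one tile at a time, keep a running maximum and a running sum, and rescale the partial result whenever the maximum grows. The real algorithm is a fused GPU kernel that holds each tile in fast on-chip SRAM so the full N×N score matrix is never written out to the slower high-bandwidth memory (HBM). The plain-Python sketch below reproduces only that arithmetic, for a single query vector, and checks it against ordinary attention; the function names and the block size are illustrative assumptions.

```python
import math


def standard_attention_row(q, K, V):
    """Ordinary attention for one query: materialize every score, then softmax."""
    scale = 1 / math.sqrt(len(q))
    scores = [scale * sum(qi * ki for qi, ki in zip(q, k)) for k in K]
    m = max(scores)                                 # subtract max for numerical stability
    weights = [math.exp(s - m) for s in scores]
    total = sum(weights)
    d_v = len(V[0])
    return [sum(w * v[j] for w, v in zip(weights, V)) / total for j in range(d_v)]


def tiled_attention_row(q, K, V, block_size=2):
    """Same result, but keys/values are visited one block at a time (online softmax)."""
    scale = 1 / math.sqrt(len(q))
    running_max = -math.inf                         # largest score seen so far
    running_sum = 0.0                               # sum of exp(score - running_max)
    acc = [0.0] * len(V[0])                         # unnormalized weighted sum of values
    for start in range(0, len(K), block_size):
        K_blk = K[start:start + block_size]
        V_blk = V[start:start + block_size]
        scores = [scale * sum(qi * ki for qi, ki in zip(q, k)) for k in K_blk]
        new_max = max(running_max, max(scores))
        # Rescale earlier statistics to the new max (exp(-inf) == 0.0 on the first block).
        correction = math.exp(running_max - new_max)
        running_sum *= correction
        acc = [a * correction for a in acc]
        for s, v in zip(scores, V_blk):
            w = math.exp(s - new_max)
            running_sum += w
            acc = [a + w * vj for a, vj in zip(acc, v)]
        running_max = new_max
    return [a / running_sum for a in acc]           # normalize once at the end


if __name__ == "__main__":
    q = [1.0, 0.5]
    K = [[0.2, 0.1], [1.0, -1.0], [0.5, 0.5], [-0.3, 0.8], [2.0, 0.0]]
    V = [[1.0, 0.0], [0.0, 1.0], [0.5, 0.5], [2.0, -1.0], [-1.0, 3.0]]
    print(standard_attention_row(q, K, V))
    print(tiled_attention_row(q, K, V, block_size=2))  # matches up to rounding
```

Because the tiled version never needs all of a row's scores at once, its working memory depends on the block size rather than on the sequence length, which is what allows the GPU kernel to keep each tile on-chip.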