Practice: Writing fast kernels
Self-check
Seven short questions. Answer each before opening the collapsible.
1. What is a kernel, and why do you ever need a custom one?
Show answer
A kernel is code that runs on the GPU for one operation. The CPU launches kernels onto the device. PyTorch’s standard ops are kernels under the hood (often from cuBLAS/cuDNN). You need a custom kernel when stock ones do not cover your operation well, or when you want to fuse several operations into one to save HBM round-trips.
2. Why is fusion the single biggest lever for memory-bound code?
Show answer
Each kernel launch reads inputs from HBM, computes, and writes outputs back to HBM. Two sequential ops mean two HBM round-trips. A fused kernel does both in one launch, keeping the intermediate in registers/SRAM and round-tripping HBM once. Arithmetic intensity (FLOPs per byte moved) rises, and a memory-bound chain becomes far less memory-bound.
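The round-trip argument can be checked with a little bookkeeping. The sketch below is illustrative only: the helper `chain_traffic` is a hypothetical name, and it assumes a chain of elementwise ops where each op reads one tensor, writes one tensor of the same size, and does one FLOP per element.

```python
def chain_traffic(n_elements, n_ops, bytes_per_element=2, fused=False):
    """Return (hbm_bytes, flops) for a chain of elementwise ops on one tensor.

    Assumption: every op reads one tensor and writes one tensor of the same
    size, and performs one FLOP per element.
    """
    flops = n_elements * n_ops
    # Unfused: each op reads its input from HBM and writes its output back.
    # Fused: one read of the original input and one write of the final output.
    passes = 1 if fused else n_ops
    hbm_bytes = passes * 2 * n_elements * bytes_per_element
    return hbm_bytes, flops


n = 1 << 20  # about one million fp16 elements (illustrative size)
for fused in (False, True):
    hbm_bytes, flops = chain_traffic(n, n_ops=2, fused=fused)
    label = "fused" if fused else "unfused"
    print(f"{label:8s} HBM bytes={hbm_bytes:>9d}  FLOPs/byte={flops / hbm_bytes:.2f}")
```

Under these assumptions, fusing two ops halves the HBM traffic and doubles the arithmetic intensity while the FLOP count stays the same.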
3. What does Triton give you?
Show answer
A language and compiler from OpenAI for writing GPU kernels in Python-like code, thinking in blocks/tiles, with the compiler handling warp scheduling, register allocation, and SRAM tiling. Custom fused kernels become a few dozen lines instead of hundreds of CUDA lines. Modern fast attention, fused norms, and MoE dispatch are commonly written in it.
4. What does XLA give you, and how does it differ from Triton?
Show answer
XLA (Accelerated Linear Algebra) is a compiler that takes a graph of standard operations and emits fused kernels automatically. You do not write GPU code; you write the model in JAX, TensorFlow, or PyTorch/XLA and the compiler does the fusion. (torch.compile gives PyTorch the same kind of automatic fusion, but its default backend is TorchInductor, which generates Triton kernels on GPU, rather than XLA.) Triton is hand-written kernels in Python; XLA is a compiler doing fusion across a graph. With XLA you trade control for automatic coverage.
5. Why does naive attention become memory-bound at long sequences?
Show answer
It materializes an N-by-N scores matrix in HBM. At long N that matrix is huge, and the pipeline writes scores, reads them back for the softmax, writes the softmax, reads it back to multiply by V. Each HBM trip is large, and the bandwidth saturates while the tensor cores idle. The math is fine; the data movement is the problem.
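To get a feel for "huge", here is a rough size estimate. It is a sketch under stated assumptions: fp16 scores (2 bytes each), one head, one sequence, and the four passes over the matrix named above (write scores, read for softmax, write softmax, read for the V multiply). The function names are hypothetical.

```python
def scores_bytes(seq_len, bytes_per_element=2):
    """Bytes for one N-by-N attention score matrix (one head, one sequence)."""
    return seq_len * seq_len * bytes_per_element


def naive_scores_traffic(seq_len, bytes_per_element=2):
    """HBM bytes moved for the scores alone in the naive pipeline.

    Assumption: four full passes over the matrix (write scores, read them for
    the softmax, write the softmax, read it for the multiply by V).
    """
    return 4 * scores_bytes(seq_len, bytes_per_element)


for n in (1024, 8192, 32768):
    size_gib = scores_bytes(n) / 2**30
    traffic_gib = naive_scores_traffic(n) / 2**30
    print(f"N={n:6d}  matrix={size_gib:7.3f} GiB  traffic={traffic_gib:7.3f} GiB per head")
```

Because the matrix grows with N squared, going from N = 8192 to N = 32768 multiplies both figures by 16; at N = 32768 a single head's fp16 score matrix is already 2 GiB.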
6. What does FlashAttention change, without changing the math?
Show answer
It tiles the attention computation so the N-by-N scores matrix is never materialized in HBM. Blocks of Q, K, and V are loaded into SRAM; partial scores, a tile-wise softmax (tracking running max and sum), and the V multiplication all happen in SRAM; only inputs and the final output cross HBM. Same results, large speed-up (~2-4x forward) and large memory savings; long contexts become feasible.
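The running max and sum are the part that is easiest to see in code. The following is a minimal pure-Python sketch for a single query row, not the actual FlashAttention kernel: it ignores the GPU entirely, omits the usual 1/sqrt(d) scaling, and uses hypothetical function names. It only shows that processing K and V block by block gives the same output as materializing every score first.

```python
import math


def naive_attention_row(q, keys, values):
    """Reference: compute all scores, softmax them, then weight the values."""
    scores = [sum(qi * ki for qi, ki in zip(q, k)) for k in keys]
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]
    total = sum(weights)
    dim = len(values[0])
    return [sum(w * v[d] for w, v in zip(weights, values)) / total for d in range(dim)]


def tiled_attention_row(q, keys, values, block=2):
    """Same result, visiting K/V one block at a time with a running max and sum."""
    dim = len(values[0])
    running_max = float("-inf")
    running_sum = 0.0
    acc = [0.0] * dim  # unnormalized output accumulator
    for start in range(0, len(keys), block):
        k_blk = keys[start:start + block]
        v_blk = values[start:start + block]
        scores = [sum(qi * ki for qi, ki in zip(q, k)) for k in k_blk]
        new_max = max(running_max, max(scores))
        # Rescale everything accumulated so far to the new running max.
        scale = math.exp(running_max - new_max)
        running_sum *= scale
        acc = [a * scale for a in acc]
        for s, v in zip(scores, v_blk):
            w = math.exp(s - new_max)
            running_sum += w
            acc = [a + w * vd for a, vd in zip(acc, v)]
        running_max = new_max
    # Normalize once at the end.
    return [a / running_sum for a in acc]


q = [0.5, -1.0]
keys = [[1.0, 0.0], [0.0, 1.0], [2.0, -1.0], [-1.0, 0.5], [0.3, 0.3]]
values = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0], [9.0, 10.0]]
ref = naive_attention_row(q, keys, values)
out = tiled_attention_row(q, keys, values, block=2)
print(all(math.isclose(a, b, abs_tol=1e-12) for a, b in zip(ref, out)))
```

The tiled version never holds more than one block of scores at a time, which is the property that lets a real kernel keep them in SRAM.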
7. How does fusion connect to the arithmetic-intensity story from lessons 2 and 5?
Show answer
Fusion is the code-level mechanism that raises arithmetic intensity. Lessons 2 and 5 named it abstractly (FLOPs per byte moved; data must be reused in fast memory). A fused kernel does many ops per HBM byte transferred, so the operation moves from memory-bound toward compute-bound and the tensor cores stay fed.
Try it yourself: explain the speed-up
About 10 minutes, no code. Practice the diagnostic muscle.
Part A. A paper claims 2x throughput on long-context attention without changing the model. What is almost certainly responsible, and what should you look for in the paper to confirm?
What you’ll get
A fused kernel, almost certainly some variant of FlashAttention. Look for: tiling the attention computation so the N-by-N scores matrix never materializes in HBM, a custom kernel (often Triton), block sizes (Q/K/V tiles), and a numerically careful tile-wise softmax with running max/sum. If the paper’s gain is on long sequences specifically and the model itself is unchanged, the story is almost always data movement, not new math.
Part B (reasoning). A team’s training loop has a sequence of small ops between each matmul: a bias add, an activation, a normalization. Their GPU reports ~30% of peak FLOPs. What single change would you propose, and what gain do you expect?
What you should notice
Fuse the small ops with the surrounding matmul (or at least with each other) into one kernel, so the intermediates stay in SRAM and HBM is traversed once per chain instead of three times. The expected gain on memory-bound chains is large (often 1.5-3x on those segments), because the elementwise ops are entirely bandwidth-bound and fusion eliminates most of the round-trips. Either write a Triton kernel that does the whole chunk, or run the model under torch.compile / an XLA path that does it automatically.
Part C (reasoning). Why is it accurate to say that most “X is faster” speed-ups in modern LLM systems work are about data movement, not better algorithms?
What you should notice
Because the operations are mostly memory-bound at the scales we care about: tensor cores can compute faster than HBM can deliver data, so reducing data movement (fusion, tiling, lower-precision storage) almost always helps more than changing the math. Algorithmic changes (sparser attention, sub-quadratic alternatives) exist, but the headline systems wins of the past few years (FlashAttention, kernel fusion, well-executed mixed precision) are data-movement wins. The math hasn’t changed; the way it touches memory has.
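For readers who want a number to go with that argument, the standard roofline simplification captures it: attainable throughput is the smaller of peak compute and arithmetic intensity times memory bandwidth. The sketch below is optional and uses a made-up accelerator; `PEAK_TFLOPS` and `HBM_TB_PER_S` are hypothetical values, not the specification of any real device.

```python
def attainable_tflops(intensity_flops_per_byte, peak_tflops, hbm_tb_per_s):
    """Roofline estimate: capped by compute or by memory bandwidth.

    FLOPs/byte multiplied by TB/s gives TFLOP/s, so the units line up.
    """
    return min(peak_tflops, intensity_flops_per_byte * hbm_tb_per_s)


# Hypothetical accelerator, chosen only to make the arithmetic easy.
PEAK_TFLOPS = 300.0
HBM_TB_PER_S = 2.0

# Intensity needed before compute, not bandwidth, becomes the limit.
ridge = PEAK_TFLOPS / HBM_TB_PER_S

for intensity in (0.25, 0.5, 10.0, 150.0, 500.0):
    tflops = attainable_tflops(intensity, PEAK_TFLOPS, HBM_TB_PER_S)
    bound = "memory-bound" if intensity < ridge else "compute-bound"
    print(f"{intensity:7.2f} FLOPs/byte -> {tflops:6.1f} TFLOP/s ({bound})")
```

On this imaginary device, anything below 150 FLOPs per byte is limited by HBM, so raising intensity by cutting bytes moved raises throughput directly, while extra compute capacity changes nothing until that threshold is crossed.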
Flashcards
Nine cards.
Q. What is a kernel?
GPU code for one operation; CPU launches kernels onto the device. PyTorch’s standard ops are kernels (often cuBLAS/cuDNN). Custom kernels are for operations stock ones miss, or for fusing several ops into one.
Q. Why is fusion the biggest lever for memory-bound code?
Each unfused op round-trips HBM. Fusing two ops keeps the intermediate in SRAM/registers, round-tripping once. Arithmetic intensity rises; memory-bound chains become much less memory-bound. The math is identical; data movement halves.
Q. What is Triton?
OpenAI’s language/compiler for writing GPU kernels in Python-like code with block-level thinking. The compiler handles warp scheduling, register allocation, and SRAM tiling. Custom fused kernels become tens of lines instead of hundreds of lines of CUDA.
Q. What is XLA, and how is it different from Triton?
A graph compiler (used by JAX, TensorFlow, and PyTorch/XLA) that fuses standard ops into kernels automatically. You don’t write GPU code; you write the model and the compiler does the fusion. torch.compile offers the same kind of automatic fusion through its own default backend, TorchInductor. Triton: hand-written kernels; XLA: compiler-fused graphs.
Q. Why is naive attention memory-bound at long N?
It materializes the N-by-N scores matrix in HBM (then softmax, then multiply by V), with multiple HBM round-trips. At long sequences the matrix is huge and bandwidth saturates while tensor cores idle.
Q. What does FlashAttention change?
Tiles attention so the N-by-N scores never sit in HBM; blocks of Q/K/V are loaded into SRAM, with a tile-wise softmax (running max + sum) and accumulated output. Same math, ~2-4x faster forward, much less memory.
Q. How does fusion connect to arithmetic intensity?
Fusion is the code-level lever that raises FLOPs-per-byte: more ops per HBM trip. Operations move from memory-bound (cores idle on HBM) toward compute-bound (cores fed from SRAM). The lesson-2 number made concrete.
Q. Why are most modern systems speed-ups about data movement?
At scale, ops are mostly memory-bound: tensor cores can compute faster than HBM can deliver. Reducing data movement (fusion, tiling, lower-precision storage) helps more than algorithmic changes. Headline wins: FlashAttention, kernel fusion, mixed precision.
Q. When would you write a Triton kernel?
When stock kernels don’t cover your case (a fused chain of small ops, an unusual attention variant, MoE routing), and the chain is memory-bound enough that a custom fused kernel would meaningfully raise arithmetic intensity.