Writing fast kernels: Triton and XLA
What you’ll learn
This lesson turns lesson 5’s hardware picture into code and into the systems instinct that explains most modern performance papers. The source curriculum is Stanford CS336, Lecture 6, by Tatsunori Hashimoto and Percy Liang; the lectures are freely available on YouTube, and the course site is cs336.stanford.edu.
You will learn what a GPU kernel is and when to write a custom one; why fusion is the single biggest lever for memory-bound code (intermediates stay in SRAM or registers, so HBM is read and written once for the whole chain instead of once per op); the two practical paths, Triton (hand-written kernels in Python with block-level abstractions) and XLA (a compiler that automatically fuses graphs of standard ops); and FlashAttention, the canonical fusion case study, including what it changes about attention without changing the math.
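To make the fusion claim concrete, here is a toy Python model (not a GPU kernel) of a three-op elementwise chain. It assumes each unfused op is its own kernel that reads its whole input from HBM and writes its whole output back, while the fused version reads once and writes once, keeping intermediates as temporaries that stand in for registers. The names `CountingHBM`, `elementwise_kernel`, `unfused`, and `fused` are hypothetical, chosen only for illustration.

```python
import math


class CountingHBM:
    """Hypothetical stand-in for device memory that counts element reads and writes."""

    def __init__(self):
        self.reads = 0
        self.writes = 0

    def load(self, buf, i):
        self.reads += 1
        return buf[i]

    def store(self, buf, i, value):
        self.writes += 1
        buf[i] = value


def elementwise_kernel(hbm, src, fn):
    # One "kernel launch": read each element from HBM, apply fn, write the result to HBM.
    out = [0.0] * len(src)
    for i in range(len(src)):
        hbm.store(out, i, fn(hbm.load(src, i)))
    return out


def unfused(hbm, x):
    # Three launches; the intermediates a and b are materialized in HBM.
    a = elementwise_kernel(hbm, x, lambda v: v * 2.0)
    b = elementwise_kernel(hbm, a, lambda v: v + 1.0)
    return elementwise_kernel(hbm, b, math.tanh)


def fused(hbm, x):
    # One launch; intermediates exist only as temporaries inside fn (stand-ins for registers).
    return elementwise_kernel(hbm, x, lambda v: math.tanh(v * 2.0 + 1.0))


x = [0.1 * i for i in range(1000)]
hbm_unfused, hbm_fused = CountingHBM(), CountingHBM()
y_unfused = unfused(hbm_unfused, x)
y_fused = fused(hbm_fused, x)
print(y_unfused == y_fused)                   # True
print(hbm_unfused.reads, hbm_unfused.writes)  # 3000 3000
print(hbm_fused.reads, hbm_fused.writes)      # 1000 1000
```

Both versions return identical values, but the fused one moves a third of the data. For a memory-bound chain, that cut in traffic is roughly the speed-up you can hope for.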
Where this fits
This is lesson 6 of 14, the second lesson of Phase 2 (systems and efficiency). It is the code-level companion to lesson 5’s hardware picture and lesson 2’s arithmetic-intensity accounting. After it, parallelism (lesson 7) scales out from one fast device to many; inference (lesson 8) closes the phase. FlashAttention here is also the natural extension of the KV-cache story from lesson 4, and the two combine for long-context inference.
Before you start
Prerequisites: lesson 5 (the GPU execution model and memory hierarchy this lesson uses to explain why fusion works). Lesson 2’s arithmetic intensity is the abstract version of the same idea. Familiarity with PyTorch is assumed; you do not need prior CUDA experience to read this lesson, but understanding the kernel-launch shape (read HBM, compute, write HBM) is essential.
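The read HBM, compute, write HBM shape is easiest to see from the point of view of a single program instance. The pure-Python emulation below loosely mirrors Triton’s block-level style (`tl.program_id`, a range of offsets, a mask, masked `tl.load`/`tl.store`), but it is a sketch for intuition, not Triton code: `BLOCK_SIZE`, `fused_kernel_program`, and `launch` are hypothetical names, and the serial loop over `pid` stands in for blocks a GPU would run in parallel.

```python
import math

BLOCK_SIZE = 4  # hypothetical tile width; real kernels typically use larger powers of two


def fused_kernel_program(pid, x_hbm, out_hbm, n):
    # What ONE program instance does: compute this block's offsets,
    # mask off lanes past the end, load once, compute on-chip, store once.
    offsets = [pid * BLOCK_SIZE + i for i in range(BLOCK_SIZE)]
    mask = [o < n for o in offsets]
    x = [x_hbm[o] if m else 0.0 for o, m in zip(offsets, mask)]  # read HBM (masked)
    y = [math.tanh(v * 2.0 + 1.0) for v in x]                     # compute in "registers"
    for o, m, v in zip(offsets, mask, y):                          # write HBM (masked)
        if m:
            out_hbm[o] = v


def launch(x_hbm):
    n = len(x_hbm)
    out_hbm = [0.0] * n
    grid = math.ceil(n / BLOCK_SIZE)  # one program instance per block
    for pid in range(grid):           # a GPU would run these concurrently
        fused_kernel_program(pid, x_hbm, out_hbm, n)
    return out_hbm


x = [0.5 * i - 2.0 for i in range(10)]  # 10 elements -> 3 blocks, last one partly masked
print(launch(x) == [math.tanh(v * 2.0 + 1.0) for v in x])  # True
```

The mask is what lets a fixed block width handle any input length: the last block covers offsets past the end of the array, and those lanes are neither read from nor written to HBM.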
About the math
None new. The FlashAttention story involves a tile-wise softmax that keeps a running max and a running sum across tiles, a numerically careful trick; the lesson explains what it accomplishes rather than deriving it. No formulas to memorize.
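The running max and sum fit in a few lines. The pure-Python function below computes one query’s attention output tile by tile, keeping only a running max `m`, a running normalizer `l`, and a running weighted sum `acc`; whenever a new tile raises the max, the old statistics are rescaled by `exp(m_old - m_new)`. This is the standard online-softmax rescaling idea, shown under simplifying assumptions: scalar values (head dimension 1), a single query row, and Python lists in place of SRAM tiles. The function names are hypothetical.

```python
import math
import random


def attention_row_reference(scores, values):
    # Materialize the whole softmax row, then take the weighted sum of values.
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]
    return sum(w * v for w, v in zip(weights, values)) / sum(weights)


def attention_row_online(scores, values, tile_size):
    # Visit scores/values one tile at a time, keeping only three running numbers.
    m = -math.inf  # running max of scores seen so far
    l = 0.0        # running sum of exp(score - m)
    acc = 0.0      # running sum of exp(score - m) * value
    for start in range(0, len(scores), tile_size):
        s_tile = scores[start:start + tile_size]
        v_tile = values[start:start + tile_size]
        m_new = max(m, max(s_tile))
        scale = math.exp(m - m_new)  # rescales old stats to the new max (0.0 on the first tile)
        p = [math.exp(s - m_new) for s in s_tile]
        l = l * scale + sum(p)
        acc = acc * scale + sum(pi * vi for pi, vi in zip(p, v_tile))
        m = m_new
    return acc / l


rng = random.Random(0)
scores = [rng.uniform(-10.0, 10.0) for _ in range(37)]
values = [rng.uniform(-1.0, 1.0) for _ in range(37)]
ref = attention_row_reference(scores, values)
for tile in (1, 4, 16, 37):
    out = attention_row_online(scores, values, tile)
    print(tile, math.isclose(out, ref, rel_tol=1e-9, abs_tol=1e-12))  # True for every tile size
```

Because the tiled result matches the full-row reference for any tile size, the full row of attention scores never has to be materialized, which is what lets a fused kernel keep each tile on-chip.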
By the end, you’ll be able to
The single capability this lesson builds: explain what a kernel is and how Triton/XLA make model operations fast, including the FlashAttention fusion idea. Concretely, you will be able to:
- Explain what a GPU kernel is and when to write a custom one
- Explain why fusion is the largest lever for memory-bound code
- Distinguish Triton (hand-written kernels) from XLA (compiler-fused graphs)
- Explain what FlashAttention changes about attention without changing the math
- Connect fusion to the arithmetic-intensity story from lessons 2 and 5
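As a back-of-the-envelope version of the arithmetic-intensity check, the sketch below counts FLOPs and HBM bytes for an elementwise chain, fused or not, and compares its intensity against a device’s ridge point (peak FLOP/s divided by peak bytes/s). The function names, the one-FLOP-per-element-per-op cost, fp32 (4-byte) elements, and the device figures (100 TFLOP/s, 2 TB/s) are illustrative assumptions, not measurements of any particular GPU.

```python
# Illustrative device figures (assumptions, not the spec of any real GPU).
HYPOTHETICAL_PEAK_FLOPS = 100e12  # FLOP/s
HYPOTHETICAL_PEAK_BW = 2e12       # bytes/s


def chain_flops_and_bytes(n_elements, n_ops, fused, flops_per_op=1, bytes_per_elem=4):
    # Assumes each op costs flops_per_op per element and elements are fp32 (4 bytes).
    flops = n_elements * n_ops * flops_per_op
    # Unfused: every op reads its input and writes its output. Fused: one read, one write.
    hbm_round_trips = 1 if fused else n_ops
    bytes_moved = hbm_round_trips * 2 * n_elements * bytes_per_elem
    return flops, bytes_moved


def is_memory_bound(flops, bytes_moved, peak_flops, peak_bw):
    # Below the ridge point (peak FLOP/s per peak byte/s), bandwidth is the limit.
    return flops / bytes_moved < peak_flops / peak_bw


def roofline_time_s(flops, bytes_moved, peak_flops, peak_bw):
    # Lower-bound time: whichever of compute or memory traffic takes longer.
    return max(flops / peak_flops, bytes_moved / peak_bw)


for fused in (False, True):
    flops, nbytes = chain_flops_and_bytes(1_000_000, n_ops=3, fused=fused)
    print(
        "fused" if fused else "unfused",
        flops / nbytes,  # arithmetic intensity in FLOP/byte
        is_memory_bound(flops, nbytes, HYPOTHETICAL_PEAK_FLOPS, HYPOTHETICAL_PEAK_BW),
        roofline_time_s(flops, nbytes, HYPOTHETICAL_PEAK_FLOPS, HYPOTHETICAL_PEAK_BW),
    )
```

With these assumed figures the ridge point is 100e12 / 2e12 = 50 FLOP/byte. The unfused chain’s intensity is 0.125 FLOP/byte and the fused chain’s is 0.375; both sit far below 50, so the chain stays memory-bound either way, and its estimated time falls by the same 3x factor as the bytes moved.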
Time and difficulty
- Read time: about 13 minutes
- Practice time: about 10 minutes (explaining a speed-up claim and diagnosing a memory-bound chain, plus flashcards)
- Difficulty: deep (Stage C; conceptual and code-aware, building on lesson 2’s accounting and lesson 5’s hardware picture)