Cheatsheet: Writing fast kernels
Kernel basics
- Kernel = GPU code for one op (matmul, softmax, norm, …).
- CPU launches kernels onto the device.
- PyTorch’s stock ops dispatch to prebuilt kernels (often from cuBLAS/cuDNN).
- Write a custom kernel when: stock doesn’t cover your case, or you want to fuse several ops.
Why fuse (the biggest lever)
Unfused (2 ops): HBM -> compute A -> HBM -> compute B -> HBM
Fused (1 op): HBM -> compute A -> [SRAM/regs] -> compute B -> HBM
- Eliminates HBM round-trips for intermediates.
- Arithmetic intensity (FLOPs/byte from HBM) rises.
- Memory-bound chains become much less memory-bound.
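The unfused vs fused diagram can be turned into a back-of-envelope count. This is a simplified model, assuming every op in the chain is elementwise (one read and one write per element, 1 FLOP per element by default) and that nothing is cached between unfused ops; `hbm_bytes` and `arithmetic_intensity` are hypothetical helpers, not a library API.

```python
# Simplified HBM-traffic model for a chain of elementwise ops on an n-element tensor.
# Assumption: each op reads one input and writes one output of the same size.


def hbm_bytes(n: int, num_ops: int, fused: bool, bytes_per_elem: int = 4) -> int:
    """Bytes moved to/from HBM for the whole chain."""
    if fused:
        # One read of the input, one write of the final output;
        # intermediates stay in SRAM/registers.
        return 2 * n * bytes_per_elem
    # Unfused: every op reads its input from HBM and writes its output back.
    return 2 * n * bytes_per_elem * num_ops


def arithmetic_intensity(n: int, num_ops: int, fused: bool,
                         flops_per_elem: int = 1, bytes_per_elem: int = 4) -> float:
    """FLOPs per byte of HBM traffic for the whole chain."""
    flops = n * flops_per_elem * num_ops
    return flops / hbm_bytes(n, num_ops, fused, bytes_per_elem)


n = 1 << 20  # 1M fp32 elements
for fused in (False, True):
    print(f"fused={fused}: {hbm_bytes(n, 2, fused) / 2**20:.0f} MiB moved, "
          f"{arithmetic_intensity(n, 2, fused):.3f} FLOP/byte")
```

With two ops, fusing halves HBM traffic and doubles intensity; for a chain of k such ops the unfused traffic grows with k while the fused traffic stays at one read plus one write.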
Triton vs XLA
| | Triton | XLA |
|---|---|---|
| What you write | Kernel code (Python-like, block-level) | A graph of standard ops (model) |
| Who fuses | You (compiler handles warps/registers/SRAM) | The compiler, automatically |
| Control | High (custom kernels) | Lower (graph-driven) |
| Used in | OpenAI (where Triton originated), custom kernels (FlashAttention, MoE), torch.compile's default Inductor backend on GPU | JAX, TF, PyTorch/XLA |
Both end up fusing ops, so intermediates stay in fast on-chip memory (SRAM/registers) across multiple ops instead of round-tripping HBM.
FlashAttention (the canonical example)
Problem: at long sequence length N, naive attention materializes an N x N scores matrix in HBM. Each step (scores, softmax, V multiply) round-trips HBM, so bandwidth saturates while tensor cores sit idle.
Fix: tile the computation so scores never sit in HBM:
- Load a block of Q + a block of K into SRAM.
- Compute partial scores there.
- Apply an online (tile-wise) softmax: track the running max and running sum across tiles, rescaling the running sum and partial output whenever the max increases.
- Multiply by a block of V; accumulate output.
- Stream the next K/V block; repeat.
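The tiling steps can be checked against naive attention with a small pure-Python sketch. It processes one query row at a time (a Q block of size 1), streams K/V in blocks of `block` rows, and keeps a running max `m`, running sum `l`, and unnormalized accumulator `acc`; Python lists stand in for SRAM tiles. The function names are hypothetical, and a real FlashAttention kernel also tiles Q, runs on tensor cores, and handles batching, heads and masking, none of which is shown here.

```python
import math


def naive_attention_row(q, K, V):
    """softmax(q . K^T / sqrt(d)) @ V for one query, with the full score row materialized."""
    d = len(q)
    scores = [sum(qi * ki for qi, ki in zip(q, k)) / math.sqrt(d) for k in K]
    m = max(scores)
    w = [math.exp(s - m) for s in scores]
    total = sum(w)
    return [sum(wj * v[c] for wj, v in zip(w, V)) / total for c in range(len(V[0]))]


def tiled_attention_row(q, K, V, block=2):
    """Same result, streaming K/V in blocks with an online softmax."""
    d = len(q)
    m = -math.inf              # running max of scores seen so far
    l = 0.0                    # running sum of exp(score - m)
    acc = [0.0] * len(V[0])    # running unnormalized sum of p_j * V_j
    for start in range(0, len(K), block):
        K_blk = K[start:start + block]   # "load" a K tile
        V_blk = V[start:start + block]   # "load" the matching V tile
        s_blk = [sum(qi * ki for qi, ki in zip(q, k)) / math.sqrt(d) for k in K_blk]
        m_new = max(m, max(s_blk))
        scale = math.exp(m - m_new)      # 0.0 on the first tile, since m = -inf
        p = [math.exp(s - m_new) for s in s_blk]
        l = l * scale + sum(p)           # rescale old sum, add this tile
        acc = [a * scale + sum(pj * v[c] for pj, v in zip(p, V_blk))
               for c, a in enumerate(acc)]
        m = m_new
    return [a / l for a in acc]          # normalize once at the end


Q = [[1.0, 0.0], [0.5, -1.0]]
K = [[1.0, 2.0], [0.0, 1.0], [-1.0, 0.5], [2.0, -1.0], [0.3, 0.3]]
V = [[1.0, 0.0], [0.0, 1.0], [2.0, 2.0], [-1.0, 0.5], [0.5, 0.5]]
for q in Q:
    ref = naive_attention_row(q, K, V)
    out = tiled_attention_row(q, K, V, block=2)  # tiles of 2, 2, 1 keys
    assert all(abs(a - b) < 1e-9 for a, b in zip(ref, out))
```

The key line is the rescale by `exp(m - m_new)`: when a later tile contains a larger score, earlier contributions to both `l` and `acc` are shrunk so every term ends up relative to the same final max, which is why the tiled result matches the naive one.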
Result: exact attention (same math, no approximation), ~2-4x faster forward pass, and memory that grows linearly rather than quadratically in N, so longer contexts become practical.
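To put the memory savings in numbers, a rough sketch assuming fp16 scores (2 bytes), one head and one batch element, and that a FlashAttention-style forward pass only needs an fp32 running max and running sum per query row as extra state (the N x d output, needed in both cases, is not counted). The helper names are hypothetical.

```python
def scores_matrix_bytes(n: int, bytes_per_elem: int = 2) -> int:
    """Bytes to materialize one N x N score matrix (one head, one batch element)."""
    return n * n * bytes_per_elem


def online_softmax_state_bytes(n: int, bytes_per_elem: int = 4) -> int:
    """Per-query-row running max and running sum kept by a tiled kernel: O(N)."""
    return 2 * n * bytes_per_elem


for n in (2048, 32768):
    print(f"N={n}: scores {scores_matrix_bytes(n) / 2**20:.0f} MiB, "
          f"online-softmax state {online_softmax_state_bytes(n) / 2**10:.0f} KiB")
```

Doubling N quadruples the score matrix but only doubles the running state; at N = 32768 the score matrix alone is 2 GiB per head under these assumptions.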
How fusion connects to lesson 2
- Lesson 2: arithmetic intensity = FLOPs / bytes moved.
- Memory-bound (low intensity) means the tensor cores sit idle waiting on HBM.
- Fusion is the code-level mechanism that raises intensity (many ops per HBM trip).
- The hardware lesson’s “stage data, reuse it” becomes real here.
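Intensity can be tied to speed with the standard roofline model: attainable throughput is `min(peak FLOP/s, intensity * HBM bandwidth)`. A sketch, assuming a hypothetical accelerator with 100 TFLOP/s peak and 2 TB/s HBM bandwidth (round numbers, not any specific GPU), plus simple intensity estimates: 1 FLOP per 8 bytes for an fp32 elementwise op, and a square fp16 matmul that reads A and B and writes C exactly once.

```python
# Hypothetical round-number accelerator, not any specific GPU.
PEAK_FLOPS = 100e12          # FLOP/s
HBM_BW = 2e12                # bytes/s
RIDGE = PEAK_FLOPS / HBM_BW  # 50 FLOP/byte: below this, HBM is the bottleneck


def attainable_flops(intensity: float) -> float:
    """Roofline: capped by compute or by bandwidth * intensity, whichever is lower."""
    return min(PEAK_FLOPS, intensity * HBM_BW)


def matmul_intensity(n: int, bytes_per_elem: int = 2) -> float:
    """Square n x n matmul: 2n^3 FLOPs over reading A and B and writing C once."""
    return (2 * n**3) / (3 * n**2 * bytes_per_elem)


cases = [
    ("fp32 elementwise op (1 FLOP / 8 bytes)", 1 / 8),
    ("4096 x 4096 fp16 matmul", matmul_intensity(4096)),
]
for name, ai in cases:
    regime = "memory-bound" if ai < RIDGE else "compute-bound"
    print(f"{name}: {ai:.3f} FLOP/byte, "
          f"{attainable_flops(ai) / PEAK_FLOPS:.2%} of peak, {regime}")
```

Under these assumptions the elementwise op reaches about 0.25% of peak while the matmul (about n/3 FLOP/byte) is far past the ridge. Fusing k elementwise ops into one kernel multiplies the chain's intensity by roughly k, moving it right along the roofline.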
Diagnostic: chain of small ops between matmuls?
If your model has a chain of small ops between matmuls (bias add, activation, dropout, LayerNorm/RMSNorm):
- The chain is memory-bound: each op does only a few FLOPs per byte (elementwise ops, or cheap per-row reductions for the norms).
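- Fix: fuse them (a hand-written Triton kernel, or torch.compile, whose default Inductor backend emits fused Triton kernels on GPU; under PyTorch/XLA, XLA does the fusion).
- Expected gain: large (often 1.5-3x on those segments).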
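A pure-Python illustration of what fusing that chain means structurally. Each unfused helper makes its own pass and materializes a new list (standing in for an HBM write plus a re-read by the next op); the fused version does bias add, ReLU and the RMSNorm sum of squares in one pass while holding the row locally (standing in for SRAM/registers). All names are hypothetical; dropout and the norm's learnable scale are omitted to keep it deterministic and short, and pure Python shows only the change in data movement, not the GPU speedup.

```python
import math


def bias_add(x, b):
    return [xi + bi for xi, bi in zip(x, b)]            # pass 1: new list


def relu(x):
    return [max(xi, 0.0) for xi in x]                   # pass 2: new list


def rms_norm(x, eps=1e-6):
    r = math.sqrt(sum(xi * xi for xi in x) / len(x) + eps)  # pass 3: reduction
    return [xi / r for xi in x]                         # pass 4: output


def unfused_chain(x, b, eps=1e-6):
    # Each op re-reads the previous op's output: the "HBM round-trips".
    return rms_norm(relu(bias_add(x, b)), eps)


def fused_chain(x, b, eps=1e-6):
    row = []       # stands in for a row tile held in SRAM/registers
    sq = 0.0
    for xi, bi in zip(x, b):
        v = max(xi + bi, 0.0)  # bias add + ReLU, no intermediate written out
        row.append(v)
        sq += v * v            # RMSNorm's reduction folded into the same pass
    r = math.sqrt(sq / len(row) + eps)
    return [v / r for v in row]  # the only result that leaves the "kernel"


x = [1.0, -2.0, 3.0, 0.5]
b = [0.1, 0.2, -0.5, 0.0]
assert all(abs(u - f) < 1e-9 for u, f in zip(unfused_chain(x, b), fused_chain(x, b)))
```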
Words to use precisely
- Kernel: GPU code for one op.
- Fusion: combining adjacent ops into one kernel, eliminating intermediate HBM trips.
- Tile / block: a chunk of a tensor processed together in SRAM.
- FlashAttention: tiled, fused attention that never materializes N x N scores in HBM.
Source
- Stanford CS336, Lecture 6 (Kernels, Triton, XLA), by Hashimoto and Liang.
cs336.stanford.edu. Independent structural mirror in original prose; see references.