Summary: Writing fast kernels

This lesson turns lesson 5’s hardware picture into code. A kernel is the GPU code for one operation. PyTorch’s standard ops are backed by kernels, many of them from vendor libraries such as cuBLAS and cuDNN. The single biggest performance lever is fusion: combining adjacent ops into one kernel so the intermediate stays in SRAM/registers and the data makes one round trip through HBM instead of several. There are two practical paths. Triton lets you write custom kernels in Python (block-level code; the compiler handles warps and tiling). XLA is a compiler that takes a graph of standard ops and emits fused kernels automatically. The canonical worked example is FlashAttention, which tiles attention so the N-by-N scores matrix never lives in HBM. With identical math, it yields 2-4x speed-ups and large memory savings on long contexts. Fusion is the code-level mechanism that raises arithmetic intensity, the quantity from lesson 2. This summary is the quick-scan version; the full lesson explains why most modern speed-ups are data-movement wins.
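
The round-trip argument can be made concrete with a toy model. The following is a minimal pure-Python sketch that treats HBM as a counter of element loads and stores. The `CountingHBM` class and the scale-then-ReLU chain are hypothetical illustrations, not Triton or PyTorch code. A real fused kernel would process blocks of elements in parallel on the GPU.

```python
class CountingHBM:
    # Hypothetical toy model of device memory: counts element loads and stores.
    def __init__(self) -> None:
        self.loads = 0
        self.stores = 0

    def load(self, buf, i):
        self.loads += 1
        return buf[i]

    def store(self, buf, i, value):
        self.stores += 1
        buf[i] = value


def scale_kernel(hbm, x, out, s):
    # Kernel 1: read x from "HBM", write the scaled intermediate back to "HBM".
    for i in range(len(x)):
        hbm.store(out, i, hbm.load(x, i) * s)


def relu_kernel(hbm, y, out):
    # Kernel 2: read the intermediate back from "HBM", write the final result.
    for i in range(len(y)):
        hbm.store(out, i, max(hbm.load(y, i), 0.0))


def fused_scale_relu_kernel(hbm, x, out, s):
    # Fused: the intermediate v stays in a local variable (stand-in for a register).
    for i in range(len(x)):
        v = hbm.load(x, i) * s
        hbm.store(out, i, max(v, 0.0))


x = [-2.0, -1.0, 0.5, 3.0]
s = 2.0

unfused = CountingHBM()
tmp, out_a = [0.0] * len(x), [0.0] * len(x)
scale_kernel(unfused, x, tmp, s)
relu_kernel(unfused, tmp, out_a)

fused = CountingHBM()
out_b = [0.0] * len(x)
fused_scale_relu_kernel(fused, x, out_b, s)

print(out_a == out_b)                   # True: identical math
print(unfused.loads + unfused.stores)   # 16 element transfers
print(fused.loads + fused.stores)       # 8 element transfers
```

Both paths compute the same values. However, the fused version moves half as many elements through memory, so with the same two operations per element its FLOPs per byte double. Longer elementwise chains save proportionally more, because each extra unfused op adds another write-then-read of the intermediate.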

  • Kernel = GPU code for one op. The CPU launches kernels onto the device. Standard ops ship with good kernels; custom ones are for cases the stock kernels miss, or for fusion.
  • Fusion is the lever. Two sequential ops in separate kernels mean two HBM round-trips, with the intermediate written out and read back (memory-bound). Fused: the intermediate stays in SRAM/registers, one HBM trip. FLOPs-per-byte rises; memory-bound chains become far less so.
  • Triton lets you write kernels in Python-like block code; compiler does warp scheduling, register allocation, SRAM tiling. Fast attention, fused norms, MoE dispatch are commonly written in it.
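  • XLA is a graph compiler (used by JAX and TensorFlow, and by PyTorch through PyTorch/XLA) that fuses standard-op graphs into kernels automatically. torch.compile takes the same graph-compiler approach, but its default GPU backend, TorchInductor, emits Triton kernels rather than going through XLA. You trade control for free fusion.
  • FlashAttention = canonical fusion example. Tiles attention so the N-by-N scores never sit in HBM; partial scores, tile-wise softmax (running max/sum), and multiplication by V all happen in SRAM. Same math, ~2-4x speed-up, big memory savings, and it makes long contexts practical.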
  • Fusion connects to lesson 2. It is the code-level mechanism that raises arithmetic intensity, moving operations from memory-bound toward compute-bound.
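
The tile-wise softmax with a running max and running sum can be sketched for a single query row. This is a scalar, pure-Python illustration of the online-softmax recurrence that FlashAttention relies on. It assumes one query and K/V tiles processed one after another. The real kernel also tiles the queries, keeps tiles in GPU SRAM, and uses matrix operations. The function names here are hypothetical.

```python
import math


def naive_attention_row(q, K, V):
    # Reference: materialize all N scores for one query, softmax them, then weight V.
    scale = 1.0 / math.sqrt(len(q))
    scores = [sum(qi * ki for qi, ki in zip(q, k)) * scale for k in K]
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]
    total = sum(weights)
    return [sum(w * v[j] for w, v in zip(weights, V)) / total for j in range(len(V[0]))]


def tiled_attention_row(q, K, V, tile=2):
    # Online softmax: visit K/V one tile at a time, keeping only O(d_v) running state.
    scale = 1.0 / math.sqrt(len(q))
    m = -math.inf               # running max of all scores seen so far
    denom = 0.0                 # running sum of exp(score - m)
    acc = [0.0] * len(V[0])     # running sum of exp(score - m) * v (unnormalized output)
    for start in range(0, len(K), tile):
        K_t, V_t = K[start:start + tile], V[start:start + tile]
        s_t = [sum(qi * ki for qi, ki in zip(q, k)) * scale for k in K_t]
        m_new = max(m, max(s_t))
        # Rescale state accumulated under the old max; exp(-inf) == 0.0 on the first tile.
        correction = math.exp(m - m_new)
        p_t = [math.exp(s - m_new) for s in s_t]
        denom = denom * correction + sum(p_t)
        acc = [a * correction + sum(p * v[j] for p, v in zip(p_t, V_t))
               for j, a in enumerate(acc)]
        m = m_new
    # Normalize once at the end.
    return [a / denom for a in acc]


q = [0.1, 0.4]
K = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [-1.0, 0.5], [0.3, -0.2]]
V = [[1.0, 2.0], [0.0, 1.0], [3.0, -1.0], [0.5, 0.5], [2.0, 0.0]]
ref = naive_attention_row(q, K, V)
out = tiled_attention_row(q, K, V, tile=2)
print(all(abs(a - b) < 1e-9 for a, b in zip(ref, out)))  # True
```

Only one tile of scores exists at a time, yet the result matches the version that materializes all N scores. The `correction` factor is what makes this work. Whenever a new tile raises the running max, the previously accumulated sum and output are rescaled, so every term ends up normalized against the same global max.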

Two takeaways. First, you can read systems papers and release notes critically. When someone reports a large speed-up at the same math, it is almost always a fused kernel, not a new algorithm. The headline systems wins of the past few years (FlashAttention, fused normalization-and-activation chains, well-implemented mixed precision) are data-movement wins, not arithmetic ones. Second, fusion is no longer exotic. Triton puts custom kernels within reach of any practitioner who knows Python and lesson 5’s memory hierarchy. A graph compiler such as XLA gives much of the same benefit automatically. Modern from-scratch model work assumes you reach for one of these when stock kernels leave performance on the table, and expects you to recognize when that is happening. The next lesson scales out from one fast device to many, with parallelism.

Most “X is faster” claims in LLM systems work are really “we fused these ops so data stays in fast memory.” Triton lets you write the fusion; XLA lets a compiler do it; FlashAttention is the famous case. The math doesn’t change; the data movement does.