Cheatsheet: Writing fast kernels

  • Kernel = GPU code for one op (matmul, softmax, norm, …).
  • CPU launches kernels onto the device.
  • PyTorch’s stock ops are kernels (often cuBLAS/cuDNN).
  • Write a custom kernel when: stock doesn’t cover your case, or you want to fuse several ops.
Unfused (2 ops): HBM -> compute A -> HBM -> compute B -> HBM
Fused (1 op): HBM -> compute A -> [SRAM/regs] -> compute B -> HBM
  • Eliminates HBM round-trips for intermediates.
  • Arithmetic intensity (FLOPs/byte from HBM) rises.
  • Memory-bound chains become much less memory-bound.
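
To make the round-trip accounting concrete, here is a minimal sketch in plain Python. The `HBM` class is a hypothetical stand-in that only counts bytes moved; it is not a real device API. The two-op chain (scalar bias add, then ReLU) and the 2-byte element size are assumptions chosen for illustration.

```python
class HBM:
    """Toy stand-in for device memory: counts bytes moved (hypothetical, not a real API)."""

    def __init__(self, bytes_per_elem=2):  # assumption: 2-byte (fp16/bf16) elements
        self.bytes_per_elem = bytes_per_elem
        self.bytes_moved = 0

    def load(self, xs):
        self.bytes_moved += len(xs) * self.bytes_per_elem
        return list(xs)

    def store(self, xs):
        self.bytes_moved += len(xs) * self.bytes_per_elem
        return list(xs)


def relu(v):
    return v if v > 0 else 0.0


def unfused(hbm, x, bias):
    # Kernel A: load x, add bias, write the intermediate back to HBM.
    a = hbm.load(x)
    t = hbm.store([v + bias for v in a])
    # Kernel B: reload the intermediate, apply ReLU, write the result.
    b = hbm.load(t)
    return hbm.store([relu(v) for v in b])


def fused(hbm, x, bias):
    # One kernel: the intermediate (v + bias) never leaves "registers".
    a = hbm.load(x)
    return hbm.store([relu(v + bias) for v in a])


x = [float(i - 4) for i in range(8)]
h_unfused, h_fused = HBM(), HBM()
same = unfused(h_unfused, x, 0.5) == fused(h_fused, x, 0.5)
print(same, h_unfused.bytes_moved, h_fused.bytes_moved)
```

In this toy model the fused version moves half the bytes for the same FLOPs, so its arithmetic intensity doubles. Longer chains save proportionally more.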
| | Triton | XLA |
|---|---|---|
| What you write | Kernel code (Python-like, block-level) | A graph of standard ops (model) |
| Who fuses | You (compiler handles warps/registers/SRAM) | The compiler, automatically |
| Control | High (custom kernels) | Lower (graph-driven) |
| Used in | OpenAI, custom kernels (FlashAttention, MoE) | JAX, TensorFlow, PyTorch/XLA |

Both keep data in fast memory across multiple ops.

Problem: naive attention materializes an N x N scores matrix in HBM, which grows quadratically with sequence length N. Each step (scores, softmax, multiply by V) round-trips through HBM, so bandwidth saturates and tensor cores sit idle.

Fix: tile the computation so scores never sit in HBM:

  1. Load a block of Q + a block of K into SRAM.
  2. Compute partial scores there.
  3. Apply tile-wise (online) softmax: track a running max and running sum across tiles, rescaling earlier partial results whenever the max changes.
  4. Multiply by a block of V; accumulate output.
  5. Stream the next K/V block; repeat.
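
The steps above can be sketched in plain Python. This is an illustration of the tiling and running-statistics idea, not the FlashAttention kernel itself. The function names, the one-row "Q block", the `block` parameter, and the 1/sqrt(d) score scaling are assumptions made for the example.

```python
import math


def naive_attention(Q, K, V):
    """Reference: builds every full row of the N x N score matrix."""
    d = len(Q[0])
    out = []
    for q in Q:
        scores = [sum(a * b for a, b in zip(q, k)) / math.sqrt(d) for k in K]
        m = max(scores)
        w = [math.exp(s - m) for s in scores]
        z = sum(w)
        out.append([sum(wi * v[j] for wi, v in zip(w, V)) / z
                    for j in range(len(V[0]))])
    return out


def tiled_attention(Q, K, V, block=2):
    """Streams K/V in blocks, keeping only per-row running statistics."""
    d, dv = len(Q[0]), len(V[0])
    out = []
    for q in Q:  # assumption: a "Q block" of one row, to keep the sketch short
        m = float("-inf")   # running max of the scores seen so far
        l = 0.0             # running sum of exp(score - m)
        acc = [0.0] * dv    # unnormalized output accumulator
        for start in range(0, len(K), block):
            Kb, Vb = K[start:start + block], V[start:start + block]
            # Steps 1-2: partial scores for this K block only.
            s = [sum(a * b for a, b in zip(q, k)) / math.sqrt(d) for k in Kb]
            # Step 3: update the running max and rescale earlier partial results.
            m_new = max(m, max(s))
            scale = math.exp(m - m_new)
            p = [math.exp(si - m_new) for si in s]
            l = l * scale + sum(p)
            # Step 4: multiply by the V block and accumulate.
            acc = [a * scale + sum(pi * v[j] for pi, v in zip(p, Vb))
                   for j, a in enumerate(acc)]
            m = m_new
        out.append([a / l for a in acc])  # normalize once at the end
    return out
```

The tiled version never holds more than `block` scores per row at a time, yet it matches the naive result up to floating-point rounding for any block size.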

Result: same math, ~2-4x faster forward pass, and score memory that grows as O(N) rather than O(N^2), which makes long contexts practical.

  • Lesson 2: arithmetic intensity = FLOPs / bytes moved.
  • Memory-bound (low intensity) means tensor cores sit idle waiting on HBM.
  • Fusion is the code-level mechanism that raises intensity (many ops per HBM trip).
  • The hardware lesson’s “stage data, reuse it” becomes real here.
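
As a rough worked example of the FLOPs-per-byte ratio, the sketch below compares a square matmul with a single elementwise op. The helper names and the 2-byte element size are assumptions. The traffic model is idealized: each tensor is read or written exactly once.

```python
def intensity(flops, bytes_moved):
    """Arithmetic intensity = FLOPs / bytes moved."""
    return flops / bytes_moved


def matmul_intensity(n, bytes_per_elem=2):
    # n x n times n x n: n^2 outputs, each needing n multiply-adds (2 FLOPs each).
    flops = 2 * n**3
    # Idealized traffic: read A, read B, write C, each exactly once.
    traffic = 3 * n**2 * bytes_per_elem
    return intensity(flops, traffic)


def elementwise_intensity(flops_per_elem=1, bytes_per_elem=2):
    # One input read and one output write per element.
    return intensity(flops_per_elem, 2 * bytes_per_elem)


print(matmul_intensity(4096))     # n / 3 FLOPs per byte under this model
print(elementwise_intensity())    # 0.25 FLOPs per byte
```

Under these assumptions the matmul's intensity grows with n, while the elementwise op stays fixed at a fraction of a FLOP per byte. That is why chains of small ops are the natural fusion targets.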

Diagnostic: chain of small ops between matmuls?

If your model runs a chain of small ops between matmuls (bias add, activation, dropout, LayerNorm/RMSNorm):

  • The chain is memory-bound (each op is elementwise, low intensity).
  • Fix: fuse them (a hand-written Triton kernel, or a fusing compiler such as XLA or torch.compile).
  • Expected gain: large (often 1.5-3x on those segments).
  • Kernel: GPU code for one op.
  • Fusion: combining adjacent ops into one kernel, eliminating intermediate HBM trips.
  • Tile / block: a chunk of a tensor processed together in SRAM.
  • FlashAttention: tiled, fused attention that never materializes N x N scores in HBM.
  • Stanford CS336, Lecture 6 (Kernels, Triton, XLA), by Hashimoto and Liang. cs336.stanford.edu. Independent structural mirror in original prose; see references.