Lesson: Writing fast kernels, Triton and XLA

Lesson 5 showed the chip: tensor cores, the memory hierarchy, and the physical reason arithmetic intensity decides whether the hardware is busy. This lesson is what you do about it. The single largest lever, by a wide margin, is fusing operations so data stays in fast memory across many computations instead of round-tripping through HBM. The two ways you get there in practice are Triton (you write the kernel in Python) and XLA (a compiler fuses your graph for you). The example that made fusion famous is FlashAttention, and it is the cleanest case study of the idea.

A kernel is code that runs on the GPU. The CPU launches kernels onto the device, each one performing one operation (a matmul, a softmax, a normalization). PyTorch’s standard ops are kernels under the hood. The heavy ones, such as matmuls and convolutions, dispatch to NVIDIA libraries like cuBLAS and cuDNN, which are extremely good at the cases they cover. Writing a custom kernel matters when (a) you have an op or pattern those libraries don’t cover well, or (b) you want to fuse several ops into one.

Recall the arithmetic-intensity picture from lesson 2 and the memory hierarchy from lesson 5. When you call op A and then op B sequentially in Python, each is a separate kernel launch. The shape of each launch is:

1. Read inputs from HBM (slow)
2. Compute (fast, if the data is large enough)
3. Write outputs back to HBM (slow)

If A’s output feeds straight into B, that HBM round-trip in the middle is wasted: A wrote the data out, B will immediately read it back. A fused kernel does A and B together in one launch: it reads the inputs from HBM once, keeps the intermediate in registers or SRAM, and writes only the final output back. Two HBM trips become one. The arithmetic intensity (FLOPs per byte moved from HBM) rises, and a chain of memory-bound operations becomes far less memory-bound, sometimes compute-bound.
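The effect on arithmetic intensity can be worked through with a small model. The sketch below is a back-of-the-envelope estimate, not a measurement: it assumes a chain of unary elementwise ops at one FLOP per element each, 2-byte elements, and ignores small broadcast operands such as a bias vector. The function names and the sizes are illustrative.

```python
def hbm_traffic_bytes(num_elements, num_ops, bytes_per_element=2, fused=False):
    """Bytes moved to and from HBM for a chain of unary elementwise ops.

    Unfused: every op reads its input from HBM and writes its output back.
    Fused: the input is read once and only the final output is written.
    """
    launches = 1 if fused else num_ops
    return launches * 2 * num_elements * bytes_per_element


def arithmetic_intensity(num_elements, num_ops, flops_per_op=1, **kwargs):
    """FLOPs per byte of HBM traffic for the whole chain."""
    flops = num_elements * num_ops * flops_per_op
    return flops / hbm_traffic_bytes(num_elements, num_ops, **kwargs)


n = 1 << 20   # elements in the tensor (illustrative)
chain = 3     # e.g. bias add, activation, scale

unfused = arithmetic_intensity(n, chain, fused=False)
fused = arithmetic_intensity(n, chain, fused=True)
print(f"unfused: {unfused} FLOPs/byte, fused: {fused} FLOPs/byte")
```

Under these assumptions the fused chain does the same FLOPs on one third of the traffic, so the intensity rises by a factor equal to the chain length. Both numbers are still low, which is why a fused elementwise chain is usually less memory-bound rather than compute-bound.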

This is why fusion dwarfs almost every other optimization for memory-bound code, and why a chain of small elementwise ops (a bias add, an activation, a normalization) is exactly the right thing to fuse. The math is identical; the data movement is much less.

Triton is an OpenAI-developed language and compiler that lets you write GPU kernels in Python-like code, with the compiler handling the parts that make raw CUDA painful: scheduling threads onto warps, allocating registers, staging tiles into shared memory. You think in blocks of data (a tile of the matrix, a chunk of the sequence); Triton turns that into the warp-level code.

The practical result is that you can write a custom fused kernel in a few dozen lines and get performance close to what a CUDA expert would reach by hand-tuning; before Triton the same kernel would have been hundreds of lines of CUDA and out of reach for most teams. Modern fast attention implementations, custom fused norms, gated FFN variants, and mixture-of-experts dispatch are often written in Triton.
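To make "thinking in blocks" concrete, here is a plain-Python emulation of the block-level style for a fused bias-add plus ReLU. It is not Triton code and runs on the CPU; the names `fused_bias_relu_program`, `launch`, and `BLOCK_SIZE` are made up for this sketch. In a real Triton kernel the corresponding pieces would be `tl.program_id` for the block index, `tl.arange` for the offsets, and masked `tl.load` / `tl.store` for the memory traffic.

```python
BLOCK_SIZE = 4  # elements handled by one program instance (illustrative)


def fused_bias_relu_program(pid, x, bias, out):
    """One 'program instance': processes one block of the flat input."""
    start = pid * BLOCK_SIZE
    offsets = [start + i for i in range(BLOCK_SIZE)]
    mask = [off < len(x) for off in offsets]  # guard the ragged last block

    # Load the block once (stands in for one HBM -> on-chip load).
    x_blk = [x[off] if ok else 0.0 for off, ok in zip(offsets, mask)]
    b_blk = [bias[off] if ok else 0.0 for off, ok in zip(offsets, mask)]

    # Fused math: the sum x + bias is never written out as an intermediate.
    result = [max(xv + bv, 0.0) for xv, bv in zip(x_blk, b_blk)]

    # Store the final values once, skipping masked-off lanes.
    for off, ok, r in zip(offsets, mask, result):
        if ok:
            out[off] = r


def launch(x, bias):
    """Run one program instance per block; on a GPU these run in parallel."""
    out = [0.0] * len(x)
    grid = (len(x) + BLOCK_SIZE - 1) // BLOCK_SIZE  # ceiling division
    for pid in range(grid):
        fused_bias_relu_program(pid, x, bias, out)
    return out


y = launch([-1.0, 2.0, 0.5, -3.0, 4.0], [0.5] * 5)
print(y)
```

The shape to notice is load, compute everything, store: the block is read once, both ops are applied to values already on hand, and only the final result is written.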

XLA (Accelerated Linear Algebra) takes the other route. Instead of you writing kernels, you write your model as a graph of standard operations (in JAX, TensorFlow, or PyTorch through the PyTorch/XLA bridge), and the compiler analyzes that graph, decides which adjacent operations should be fused into single kernels, tiles them for the target hardware, and emits the fused kernel code. PyTorch’s torch.compile follows the same compiler-does-the-fusion idea, though its default backend is TorchInductor, which generates Triton kernels on GPU, rather than XLA. You give up some control; you get automatic fusion across the whole model without writing a line of GPU code yourself.
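The core of what a fusing compiler decides can be shown with a toy pass. This sketch is far simpler than XLA's real fusion logic and uses none of its APIs: it assumes a linear chain of ops rather than a general graph, treats a hypothetical set of op names as elementwise, and groups each run of adjacent elementwise ops into one kernel. A real compiler can also fuse elementwise work into neighboring non-elementwise ops, which this toy does not attempt.

```python
# Hypothetical op names, chosen only for this illustration.
ELEMENTWISE = {"add_bias", "relu", "gelu", "scale"}


def fuse_elementwise(ops):
    """Group runs of adjacent elementwise ops into single 'kernels'.

    Returns a list of tuples; each tuple is one kernel launch.
    Non-elementwise ops (e.g. matmul) stay as their own kernels.
    """
    kernels, run = [], []
    for op in ops:
        if op in ELEMENTWISE:
            run.append(op)          # extend the current fusable run
        else:
            if run:                 # a non-elementwise op closes the run
                kernels.append(tuple(run))
                run = []
            kernels.append((op,))
    if run:
        kernels.append(tuple(run))
    return kernels


graph = ["matmul", "add_bias", "gelu", "matmul", "add_bias", "scale"]
kernels = fuse_elementwise(graph)
print(len(graph), "ops ->", len(kernels), "kernel launches:", kernels)
```

Each fused tuple stands for one kernel whose intermediates never leave fast memory, which is the decision the compiler makes on your behalf.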

Two tools, one purpose: keep data in fast memory across multiple operations.

The reason fusion gets talked about constantly is the dramatic case study of attention. Standard attention computes:

scores = Q @ K^T / sqrt(d) # an N x N matrix; d is the head dimension
attn = softmax(scores)
output = attn @ V

At long sequence length N, the N-by-N scores matrix is huge. The naive implementation materializes that matrix in HBM, writes it, reads it back for the softmax, writes the softmax, reads it back to multiply by V. Each of those HBM trips is enormous, and the whole pipeline is memory-bound: tensor cores idle while the bandwidth is saturated moving the scores matrix around.
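A quick size calculation shows how lopsided this is. The numbers are assumptions picked for illustration: a sequence length of 32768, a head dimension of 128, and 2-byte (fp16 or bf16) elements, counted for a single head of a single sequence.

```python
def attention_bytes(seq_len, head_dim, bytes_per_element=2):
    """Return (bytes for the N x N scores matrix, bytes for Q, K, V and output)."""
    scores = seq_len * seq_len * bytes_per_element
    qkvo = 4 * seq_len * head_dim * bytes_per_element
    return scores, qkvo


scores_bytes, qkvo_bytes = attention_bytes(seq_len=32768, head_dim=128)
print(scores_bytes / 2**30, "GiB for the scores matrix")
print(qkvo_bytes / 2**20, "MiB for Q, K, V and the output combined")
print(scores_bytes // qkvo_bytes, "x larger")
```

With these sizes the intermediate is 2 GiB against 32 MiB of actual inputs and outputs, and the naive pipeline moves an N-by-N matrix through HBM several times (scores written and read, softmax written and read).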

FlashAttention rewrites this so the N-by-N matrix is never materialized in HBM at all. It tiles the computation: load a block of Q and a block of K into SRAM, compute the partial scores, apply a numerically careful tile-wise softmax (which requires tracking a running max and a running sum across tiles), multiply by a block of V, accumulate the output, then load the next K/V block and repeat. The intermediate scores live in SRAM and never see HBM; only the inputs and the final output cross the bandwidth boundary. The math is the same and the result is exact attention, matching the naive version up to floating-point rounding, but the data movement drops by a large factor. On long contexts the speed-up is on the order of two to four times for forward passes, with large memory savings because the N-by-N intermediate is no longer stored, and much longer sequences become practical.
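The running-max and running-sum bookkeeping is the part that is easiest to get wrong, so here is a minimal pure-Python sketch of it. It illustrates the tiling idea only and is not the FlashAttention kernel: it runs on the CPU with lists, tiles over K/V but not over Q, uses the usual 1/sqrt(d) score scaling, and the function names and toy inputs are made up. A naive reference is included for comparison.

```python
import math


def naive_attention(Q, K, V):
    """Reference: builds every row of the full N x N scores matrix."""
    d, dv = len(Q[0]), len(V[0])
    out = []
    for q in Q:
        scores = [sum(a * b for a, b in zip(q, k)) / math.sqrt(d) for k in K]
        m = max(scores)
        weights = [math.exp(s - m) for s in scores]
        total = sum(weights)
        out.append([sum(w * v[j] for w, v in zip(weights, V)) / total
                    for j in range(dv)])
    return out


def tiled_attention(Q, K, V, block=2):
    """Streams K/V in blocks, keeping only running statistics per query row."""
    d, dv = len(Q[0]), len(V[0])
    out = []
    for q in Q:
        running_max = -math.inf
        running_sum = 0.0
        acc = [0.0] * dv  # unnormalized output accumulator
        for start in range(0, len(K), block):
            k_blk = K[start:start + block]
            v_blk = V[start:start + block]
            s = [sum(a * b for a, b in zip(q, k)) / math.sqrt(d) for k in k_blk]
            new_max = max(running_max, max(s))
            # Rescale everything accumulated under the old max.
            correction = math.exp(running_max - new_max)
            p = [math.exp(x - new_max) for x in s]
            running_sum = running_sum * correction + sum(p)
            acc = [a * correction + sum(pi * v[j] for pi, v in zip(p, v_blk))
                   for j, a in enumerate(acc)]
            running_max = new_max
        out.append([a / running_sum for a in acc])  # normalize once at the end
    return out


Q = [[1.0, 0.0], [0.5, -1.0], [0.0, 2.0]]
K = [[1.0, 2.0], [0.0, 1.0], [-1.0, 0.5], [2.0, 0.0], [0.3, 0.3]]
V = [[1.0, 0.0], [0.0, 1.0], [2.0, -1.0], [0.5, 0.5], [-1.0, 3.0]]

ref = naive_attention(Q, K, V)
tiled = tiled_attention(Q, K, V, block=2)
max_diff = max(abs(a - b) for ra, rb in zip(ref, tiled) for a, b in zip(ra, rb))
print("max abs difference:", max_diff)
```

At any moment the tiled version holds only one block of scores plus three running quantities per query row, which is what lets the real kernel keep its intermediates in SRAM. The two functions agree up to floating-point rounding, not bit for bit, because the sums are accumulated in a different order.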

FlashAttention is the canonical example because it shows the principle perfectly: the gain is not from doing less math but from arranging the math so it fits in fast memory.

How this connects to the rest of the track

Two threads tie together. First, FlashAttention and the grouped-query attention of lesson 4 are complementary optimizations: FlashAttention tiles the computation so the full attention matrix never has to be written to slow memory, while GQA shrinks the KV cache; they are orthogonal and combine for long-context inference. Second, fusion is precisely what raises the arithmetic intensity discussed in lessons 2 and 5: instead of one or two operations per byte transferred, you get many operations per byte transferred, which is the literal definition of moving from memory-bound toward compute-bound. The hardware lesson said “stage data into fast memory and reuse it”; this lesson is how you actually do that in code.

Two lessons stick. First, when you read a paper or release that reports a large speed-up over a baseline, the explanation is very often a fused kernel, not new math. Knowing that demystifies many of the headline performance improvements in the field, and lets you read them critically: check whether the gain comes from moving less data rather than from computing something different. Second, fusion is no longer an exotic skill. Triton brings custom kernels within reach for any practitioner who knows Python, and XLA gives you a lot of the benefit for free if you use a graph compiler. Modern from-scratch model work assumes you will reach for one or the other when stock kernels leave performance on the table, and it expects you to recognize when that is happening. The systems half of an LLM build is, in large part, fusion-aware code.

  • A kernel is GPU code for one operation. PyTorch’s standard ops are kernels (often from cuBLAS/cuDNN). Custom kernels matter when stock ones don’t cover your case well, or when you want to fuse several operations into one.
  • Fusion is the single biggest lever for memory-bound code. Two sequential ops mean two HBM round-trips; fused they keep the intermediate in SRAM/registers and round-trip once. Arithmetic intensity rises; memory-bound chains become much less memory-bound.
  • Triton writes kernels in Python. OpenAI’s language/compiler that lets you express block-level GPU code while the compiler handles warp scheduling, register allocation, and SRAM tiling. Modern fast attention, fused norms, and MoE dispatch are often written in Triton.
  • XLA fuses a graph for you. A compiler (used by JAX, TensorFlow, and PyTorch/XLA) that analyzes a graph of standard ops and emits fused kernels automatically. You trade some control for free fusion across the model.
  • FlashAttention is the canonical example. It tiles attention so the N-by-N scores matrix never sits in HBM; intermediates live in SRAM. Same math, large speed-ups (2-4x forward) and large memory savings; unlocks long contexts.
  • Fusion is the code-level mechanism that raises arithmetic intensity and moves operations from memory-bound toward compute-bound, the principle lessons 2 and 5 named in the abstract.

Most “X is faster” claims in the systems half of LLM work are really “we fused these ops so the data stays in fast memory.” Triton lets you write that fusion; XLA lets a compiler do it; FlashAttention is the famous case. The math doesn’t change; the data movement does, and that is most of the speed.