Lab 07-03 — Tiled GEMM and the Memory-Traffic Model [CPU-OK]
A matrix multiply is three nested loops a first-year student can write. CUTLASS — the template library behind most of vLLM's GEMMs — is tens of thousands of lines. This lab is about the single idea that fills that gap: tiling. Not because the loops are wrong, but because the memory traffic is: a naive GEMM re-reads its operands from slow memory incessantly, while a tiled one stages blocks in fast memory and reuses every loaded element many times. You'll implement the tiling (and prove it changes nothing numerically — it's pure loop reordering), then build the traffic model that explains why tile shape is the most important number in any GEMM kernel — and derive, as a bonus, exactly why a decode-shaped matmul (M=1) can't be saved by any tile size at all.
Contents
- Why this lab exists
- Background: the reuse arithmetic
- Files
- Run
- What the tests prove
- Hitchhiker's notes
- Going further
- References
Why this lab exists
Phase 0 lab-04 gave you the roofline: below the ridge you're bandwidth-bound, above it compute-bound. What it didn't say is that arithmetic intensity is not a property of the problem; it's a property of the algorithm. A 1024³ GEMM has plenty of FLOPs per byte in principle (at 2 bytes per element the operands total ~6 MB and the work totals ~2 GFLOPs, so hundreds of FLOPs per byte), but the naive loop order achieves only ~1 FLOP per loaded element anyway, because it keeps re-loading what it just evicted. Tiling is the act of claiming the intensity the math always had. Every fast kernel you'll ever read (CUTLASS GEMMs; FlashAttention, which is exactly this lab applied to attention in Phase 4; the fused MoE kernel from lab-02's profile) is this one idea wearing different clothes, and the traffic model you build here is how you estimate any of them on a napkin.
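The napkin arithmetic behind that claim fits in a few lines. This is a minimal sketch, assuming 2-byte (fp16) elements and counting one multiply-add as 2 FLOPs; the variable names are illustrative and not part of the lab files.

```python
# Napkin intensity for a 1024^3 GEMM (assumption: fp16, 2 bytes per element).
M = N = K = 1024
DTYPE_BYTES = 2

flops = 2 * M * N * K                                # one multiply-add = 2 FLOPs
operand_bytes = (M * K + K * N + M * N) * DTYPE_BYTES  # A, B and C touched once each

# Best case: every operand byte crosses the slow-memory boundary exactly once.
ideal_flops_per_byte = flops / operand_bytes

# Naive loop order: each C[i, j] streams K elements of A and K elements of B.
naive_loads = M * N * 2 * K
naive_flops_per_load = flops / naive_loads

print(f"ideal: {ideal_flops_per_byte:.1f} FLOPs/byte")
print(f"naive: {naive_flops_per_load:.1f} FLOPs per loaded element")
```

The gap between the two printed numbers is what tiling recovers.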
The numerics half matters too: tiling reorders the accumulation, and you'll prove with tests that for exact arithmetic it's an identity (ragged edges included — the place naive implementations corrupt silently). In floating point, reordering shifts the last ulp — the legitimate cross-kernel divergence you've now met in three phases (3, 4, 6).
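The reordering effect can be seen on a single dot product. This is a small sketch with hypothetical helper names (`dot_straight`, `dot_chunked`), not the lab's API: it accumulates the same products once in a straight line and once in `tile_k`-sized partial sums, the way a K-tiled kernel would.

```python
def dot_straight(a, b):
    """Accumulate products left to right, as the naive inner loop does."""
    acc = 0
    for x, y in zip(a, b):
        acc += x * y
    return acc


def dot_chunked(a, b, tile_k):
    """Accumulate per-chunk partial sums, then add the partials together."""
    total = 0
    for k0 in range(0, len(a), tile_k):
        partial = 0
        for k in range(k0, min(k0 + tile_k, len(a))):  # min() handles the ragged tail
            partial += a[k] * b[k]
        total += partial
    return total


# Exact arithmetic (Python ints): the order of accumulation cannot matter.
ints_a, ints_b = [10**16, 1, 1, 1], [1, 1, 1, 1]
exact_match = dot_straight(ints_a, ints_b) == dot_chunked(ints_a, ints_b, tile_k=2)

# Floating point: the same values as float64 round differently per ordering.
flt_a, flt_b = [1e16, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0]
float_straight = dot_straight(flt_a, flt_b)
float_chunked = dot_chunked(flt_a, flt_b, tile_k=2)
float_match = float_straight == float_chunked
```

With these inputs the integer results agree, while the two float64 results differ by one unit in the last place at that magnitude: neither ordering is wrong, they simply round at different points.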
Background: the reuse arithmetic
Count slow-memory loads. Naive: each output element C[i,j] streams a K-row of A and a
K-column of B → M·N·2K loads — every element of A is loaded N times, every element of
B loaded M times. Tiled with (tile_m × tile_n) output tiles: each tile streams its
A-rows and B-columns once (staged in fast memory while the tile's tile_m·tile_n·K
multiply-adds consume them):
tiled loads = M·K · ceil(N/tile_n) + K·N · ceil(M/tile_m)
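reuse = naive/tiled = 2 / (1/tile_m + 1/tile_n) ← the HARMONIC MEAN of the tile dims
(exact when tile_m divides M and tile_n divides N; otherwise the ceil terms round the tile counts up)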
That harmonic mean is the lab's punchline. It says: reuse is governed by the smaller
tile dimension (256×16 tiles reuse like ~30, not like their area suggests); square tiles
maximize reuse per unit of fast memory (a t×t tile gives reuse t while staging
O(t·K) operands); and — the inference-shaped consequence — when M=1 (a single decode
token), tile_m is pinned at 1 and reuse caps at 2, no matter how clever the kernel.
The weights must stream once per step. That's Phase 0 lab-04's "decode is
bandwidth-bound" re-derived from the kernel's side, and it's why decode optimization is
about shrinking bytes (Phase 6) and sharing the stream across a batch, never about
better GEMM tiling.
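The formulas above translate almost directly into code. The sketch below is one possible reading of them; the function names match those listed for starter.py, but the argument order and return types are assumptions, so check the starter file before relying on them.

```python
def naive_traffic(m, n, k):
    """Slow-memory loads for the naive loop: each C[i, j] streams K of A and K of B."""
    return m * n * 2 * k


def tiled_traffic(m, n, k, tile_m, tile_n):
    """Loads when each (tile_m x tile_n) output tile streams its operands once."""
    tiles_along_m = -(-m // tile_m)  # integer ceil(m / tile_m)
    tiles_along_n = -(-n // tile_n)  # integer ceil(n / tile_n)
    # A (m*k elements) is re-read once per column of tiles,
    # B (k*n elements) once per row of tiles.
    return m * k * tiles_along_n + k * n * tiles_along_m


def reuse_factor(tile_m, tile_n):
    """Harmonic mean of the tile dimensions."""
    return 2 / (1 / tile_m + 1 / tile_n)


square = reuse_factor(128, 128)   # square tile: reuse equals the tile side
skewed = reuse_factor(256, 16)    # 512 / 17, about 30 despite the larger area
decode = naive_traffic(1, 4096, 4096) / tiled_traffic(1, 4096, 4096, 1, 128)
```

The `decode` ratio uses a hypothetical 4096×4096 weight matrix with M=1; it stays below 2 whatever `tile_n` is chosen, which is the bandwidth wall described above.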
Files
- starter.py: tiled_gemm, naive_traffic, tiled_traffic, reuse_factor. Your work.
- solution.py: reference.
- test_lab.py: equality (divisible, ragged, tile=1), the traffic formulas, the bigger-tiles-less-traffic direction, the harmonic mean, and the decode-shape cap.
Run
LAB_IMPL=starter pytest phase-07-gemm-and-moe-kernels/labs/lab-03-tiled-gemm -q
pytest phase-07-gemm-and-moe-kernels/labs/lab-03-tiled-gemm -q # reference
What the tests prove
| Test | What it pins |
|---|---|
| test_tiled_equals_matmul_divisible / _ragged | Tiling is loop reordering, not approximation, including 37×23×19, where every edge tile is partial. Ragged edges are where real kernel bugs live (predication/masking in CUTLASS); your min() bounds are their readable form |
| test_tile_size_one_is_the_naive_algorithm | The degenerate case anchors the model: tiles of 1 = no reuse = the naive loop |
| test_traffic_formulas | The load counts, exactly: 1024³ with 128² tiles moves ~16.8 M element loads instead of ~2.1 G, a 128× reduction |
| test_bigger_tiles_mean_less_traffic | The direction that justifies burning shared memory on bigger tiles |
| test_reuse_factor_is_the_harmonic_tile_size | Square 128 → reuse 128; skewed 256×16 → ~30. Shape, not area |
| test_decode_shape_has_no_reuse_to_harvest | M=1 → reuse ≤ 2. The GEMM-side proof of decode's bandwidth wall |
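
For orientation, here is what the equality property looks like in a dependency-free sketch. It is not the reference solution: the signature `tiled_gemm(a, b, tile_m, tile_n, tile_k)` over nested lists is an assumption, and mapping 37, 23, 19 to M, N, K is a guess at the test's shape. Integer inputs keep the arithmetic exact, so equality is strict.

```python
import random


def naive_gemm(a, b):
    """Three nested loops: C[i][j] = sum over p of A[i][p] * B[p][j]."""
    m, k, n = len(a), len(b), len(b[0])
    return [[sum(a[i][p] * b[p][j] for p in range(k)) for j in range(n)]
            for i in range(m)]


def tiled_gemm(a, b, tile_m, tile_n, tile_k):
    """Same multiply-adds, visited tile by tile; min() clips partial edge tiles."""
    m, k, n = len(a), len(b), len(b[0])
    c = [[0] * n for _ in range(m)]
    for i0 in range(0, m, tile_m):
        i1 = min(i0 + tile_m, m)
        for j0 in range(0, n, tile_n):
            j1 = min(j0 + tile_n, n)
            for k0 in range(0, k, tile_k):
                k1 = min(k0 + tile_k, k)
                # Work on one (i0:i1, j0:j1) output tile and one K-slice.
                for i in range(i0, i1):
                    for j in range(j0, j1):
                        acc = c[i][j]
                        for p in range(k0, k1):
                            acc += a[i][p] * b[p][j]
                        c[i][j] = acc
    return c


rng = random.Random(0)
M, N, K = 37, 23, 19  # none divisible by 8, so every edge tile is partial
A = [[rng.randint(-9, 9) for _ in range(K)] for _ in range(M)]
B = [[rng.randint(-9, 9) for _ in range(N)] for _ in range(K)]

reference = naive_gemm(A, B)
ragged_ok = tiled_gemm(A, B, 8, 8, 8) == reference
tile_one_ok = tiled_gemm(A, B, 1, 1, 1) == reference
```

Dropping any of the three `min()` calls makes the ragged case index out of range or, in a kernel without bounds checks, read garbage, which is the bug class the ragged test exists to catch.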
Hitchhiker's notes
- Why not just make tiles enormous? Fast memory is finite: a GPU SM has ~100–230 KB of shared memory, and a t×t fp16 tile's staging (A-panel + B-panel + accumulator) must fit, which lands real kernels at tiles like 128×128 or 128×256, exactly where your model's curve flattens against the hardware budget. Tile choice is a constrained optimization, and CUTLASS exposes it as template parameters because the optimum moves with dtype, shape, and architecture.
- The hierarchy repeats: HBM → shared memory is your model's level, but the same arithmetic recurs for shared memory → registers (warp tiles), and L2 ↔ HBM (threadblock swizzling for L2 reuse). Production GEMMs tile at three levels with the same formula at each. Learn it once, apply it fractally.
- Tensor cores change the FLOP rate, not the traffic math. They make the compute side faster, which raises the ridge (Phase 0 lab-04) and makes good tiling more necessary, not less — a tensor-core GEMM that under-tiles just starves faster. This is why Hopper added TMA (bulk async copies HBM→shared): feeding the tiles became the whole game.
- Grouped GEMM (the MoE kernel, lab-01/02) is this lab plus one indirection: many small GEMMs (one per expert) whose tiles are scheduled from a single kernel launch so the tile machinery amortizes across experts. moe_align_block_size exists precisely to organize tokens into tile-shaped groups: your lab-01 argsort, upgraded to be tile-aware.
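
To make "tile-shaped groups" concrete, here is a toy version of the idea in the last note. `align_tokens_to_tiles` is a hypothetical helper written for this illustration; it does not reproduce the signature or output layout of vLLM's `moe_align_block_size`. It sorts token indices by expert and pads each expert's group to a multiple of the block size, so every block of `block_size` rows belongs to exactly one expert.

```python
def align_tokens_to_tiles(expert_ids, num_experts, block_size):
    """Group token indices by expert and pad each group to whole blocks.

    Returns (sorted_token_ids, block_expert). Padding slots hold the sentinel
    len(expert_ids), which is not a valid token index.
    """
    pad = len(expert_ids)
    # Stable argsort by expert id (the lab-01 step).
    order = sorted(range(len(expert_ids)), key=lambda t: expert_ids[t])

    sorted_token_ids, block_expert = [], []
    pos = 0
    for expert in range(num_experts):
        group = []
        while pos < len(order) and expert_ids[order[pos]] == expert:
            group.append(order[pos])
            pos += 1
        # Round the group up to a whole number of blocks.
        group += [pad] * (-len(group) % block_size)
        sorted_token_ids += group
        block_expert += [expert] * (len(group) // block_size)
    return sorted_token_ids, block_expert


# Seven tokens routed to three experts, tiles of four rows.
ids, blocks = align_tokens_to_tiles([2, 0, 2, 1, 0, 2, 2], num_experts=3, block_size=4)
# ids    -> [1, 4, 7, 7,  3, 7, 7, 7,  0, 2, 5, 6]
# blocks -> [0, 1, 2]
```

Each block can then be handed to the same tile loop with a different expert's weights; the padding rows are wasted work, which is the price of keeping every tile uniform.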
Going further
- Add a fast_memory_bytes(tile_m, tile_n, tile_k, dtype_bytes) function and find the best square tile under a 100 KB budget for K=4096, then compare against the tile shapes in a CUTLASS config or a Triton autotune list. You'll land within a factor of 2 of what the pros chose, from a 5-line model.
- Time it for real: your tiled_gemm vs A @ B in numpy is unfair (BLAS is tiled and vectorized), but tiled_gemm with tile 64 vs tile 1 against each other shows the traffic effect even through Python overhead. Measure, then explain the ratio.
- Extend the traffic model with the K-dimension split (tile_k, split-K reduction, needed when M and N are both small but K is huge). Notice the merge-partials shape from Phase 4 lab-04 reappearing: split-K GEMM is the same monoid trick, applied to plain sums.
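
One way to start on the first exercise is sketched below. The footprint model is an assumption: it charges an A-panel, a B-panel and an accumulator, all at `dtype_bytes` per element, and reads "100 KB" as 100 × 1024 bytes. `best_square_tile` is a hypothetical helper, and `tile_k=32` is an illustrative slice depth rather than a recommendation.

```python
def fast_memory_bytes(tile_m, tile_n, tile_k, dtype_bytes):
    """Bytes staged in fast memory for one output tile and one K-slice."""
    a_panel = tile_m * tile_k * dtype_bytes
    b_panel = tile_k * tile_n * dtype_bytes
    accumulator = tile_m * tile_n * dtype_bytes  # assumption: same dtype as operands
    return a_panel + b_panel + accumulator


def best_square_tile(budget_bytes, tile_k, dtype_bytes):
    """Largest t such that a t x t tile fits within budget_bytes."""
    t = 0
    while fast_memory_bytes(t + 1, t + 1, tile_k, dtype_bytes) <= budget_bytes:
        t += 1
    return t


BUDGET = 100 * 1024  # assumption: 100 KB means 100 KiB

# Staging the full K=4096 panels leaves room for only a tiny tile ...
full_k_tile = best_square_tile(BUDGET, tile_k=4096, dtype_bytes=2)
# ... while slicing K into chunks of 32 frees the budget for the output tile.
sliced_k_tile = best_square_tile(BUDGET, tile_k=32, dtype_bytes=2)
```

Under these assumptions the full-K variant fits only t=6, while the sliced variant fits t=196, whose nearest power of two below is 128. Changing the accumulator width or the slice depth moves the answer, which is the point of the exercise.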
References
- upstream/csrc/quantization/cutlass_w8a8/ and upstream/cmake/external_projects/ (where CUTLASS enters vLLM; the deep-dive maps the entry points).
- NVIDIA, CUTLASS: Efficient GEMM in CUDA (docs; the three-level tiling hierarchy): https://github.com/NVIDIA/cutlass/blob/main/media/docs/efficient_gemm.md
- Triton tutorial, Matrix Multiplication (your lab in Triton, with autotuned tiles): https://triton-lang.org/main/getting-started/tutorials/03-matrix-multiplication.html
- Williams et al., Roofline (2009; intensity as an algorithm property): https://dl.acm.org/doi/10.1145/1498765.1498785
- Phase 0 lab-04 (the ridge this lab's reuse factor is racing toward); Phase 4 lab-01 (FlashAttention as tiling applied to attention).