Lab 07-03 — Tiled GEMM and the Memory-Traffic Model [CPU-OK]

A matrix multiply is three nested loops a first-year student can write. CUTLASS — the template library behind most of vLLM's GEMMs — is tens of thousands of lines. This lab is about the single idea that fills that gap: tiling. Not because the loops are wrong, but because the memory traffic is: a naive GEMM re-reads its operands from slow memory incessantly, while a tiled one stages blocks in fast memory and reuses every loaded element many times. You'll implement the tiling (and prove it changes nothing numerically — it's pure loop reordering), then build the traffic model that explains why tile shape is the most important number in any GEMM kernel — and derive, as a bonus, exactly why a decode-shaped matmul (M=1) can't be saved by any tile size at all.

Why this lab exists

Phase 0 lab-04 gave you the roofline: below the ridge you're bandwidth-bound, above it compute-bound. What it didn't say is that arithmetic intensity is not a property of the problem; it's a property of the algorithm. A 1024³ GEMM has plenty of FLOPs per byte in principle (at 2 bytes per element the operands total ~6 MB and the work totals ~2 GFLOPs, roughly 340 FLOPs per byte), but the naive loop order achieves only about 1 FLOP per element loaded, because it keeps re-loading what it just evicted. Tiling is the act of claiming the intensity the math always had. Every fast kernel you'll ever read is this one idea wearing different clothes: CUTLASS GEMMs, FlashAttention (which is exactly this lab applied to attention, Phase 4), and the fused MoE kernel (lab-02's profile). The traffic model you build here is how you estimate any of them on a napkin.
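The gap between the intensity the problem offers and the intensity the naive loop achieves can be checked with a few lines of arithmetic. This is a minimal sketch, not part of the lab's files: the helper name `gemm_intensity_sketch` is hypothetical, 2 bytes per element (fp16) is an assumption, and the naive count ignores the writes to C.

```python
def gemm_intensity_sketch(m, n, k, dtype_bytes=2):
    """Return (ideal, naive) arithmetic intensity in FLOPs per byte.

    Hypothetical helper for illustration; dtype_bytes=2 assumes fp16.
    """
    flops = 2 * m * n * k  # one multiply and one add per (i, j, p) triple
    # Ideal: A, B, and C each cross the slow-memory boundary exactly once.
    ideal_bytes = (m * k + k * n + m * n) * dtype_bytes
    # Naive: every multiply-add reloads one element of A and one of B.
    naive_bytes = 2 * m * n * k * dtype_bytes
    return flops / ideal_bytes, flops / naive_bytes


ideal, naive = gemm_intensity_sketch(1024, 1024, 1024)
print(f"ideal ~{ideal:.0f} FLOPs/byte, naive {naive} FLOPs/byte")
# prints: ideal ~341 FLOPs/byte, naive 0.5 FLOPs/byte
```

The naive figure of 0.5 FLOPs per byte at 2 bytes per element is the same thing as one FLOP per element loaded.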

The numerics half matters too: tiling reorders the accumulation, and you'll prove with tests that for exact arithmetic it's an identity (ragged edges included — the place naive implementations corrupt silently). In floating point, reordering shifts the last ulp — the legitimate cross-kernel divergence you've now met in three phases (3, 4, 6).
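The exactness claim is easy to see on integer inputs, where addition is associative and no rounding can hide a bug. The sketch below is one possible shape for a tiled GEMM in pure Python, not the lab's reference solution: the name `tiled_gemm_sketch`, its signature, and the choice to also tile K are assumptions. The ragged case uses the dimensions 37, 23, and 19, assigned here as M, N, K (also an assumption).

```python
import random


def naive_gemm(a, b):
    """Textbook triple loop over lists of lists."""
    m, k, n = len(a), len(b), len(b[0])
    return [[sum(a[i][p] * b[p][j] for p in range(k)) for j in range(n)]
            for i in range(m)]


def tiled_gemm_sketch(a, b, tile_m, tile_n, tile_k):
    """Same arithmetic as naive_gemm, visited tile by tile."""
    m, k, n = len(a), len(b), len(b[0])
    c = [[0] * n for _ in range(m)]
    for i0 in range(0, m, tile_m):
        i1 = min(i0 + tile_m, m)          # min() clips partial edge tiles
        for j0 in range(0, n, tile_n):
            j1 = min(j0 + tile_n, n)
            for k0 in range(0, k, tile_k):
                k1 = min(k0 + tile_k, k)
                # Accumulate this K-slice into the output tile.
                for i in range(i0, i1):
                    for j in range(j0, j1):
                        acc = c[i][j]
                        for p in range(k0, k1):
                            acc += a[i][p] * b[p][j]
                        c[i][j] = acc
    return c


rng = random.Random(0)
M, N, K = 37, 23, 19  # no dimension is a multiple of the tile size 8
a = [[rng.randint(-5, 5) for _ in range(K)] for _ in range(M)]
b = [[rng.randint(-5, 5) for _ in range(N)] for _ in range(K)]

c_tiled = tiled_gemm_sketch(a, b, 8, 8, 8)
assert c_tiled == naive_gemm(a, b)  # exact equality, edge tiles included
```

With float inputs the same comparison would need a tolerance, because the K-slices change the order of the additions.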

Background: the reuse arithmetic

Count slow-memory loads. Naive: each output element C[i,j] streams a K-row of A and a K-column of B → M·N·2K loads — every element of A is loaded N times, every element of B loaded M times. Tiled with (tile_m × tile_n) output tiles: each tile streams its A-rows and B-columns once (staged in fast memory while the tile's tile_m·tile_n·K FLOPs consume them):

tiled loads = M·K · ceil(N/tile_n)  +  K·N · ceil(M/tile_m)

reuse = naive/tiled = 2 / (1/tile_m + 1/tile_n)     ← the HARMONIC MEAN of the tile dims
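The ratio follows in two steps when the tile sizes divide M and N exactly, which is an assumption of this derivation (with ragged edges the ceilings make the result approximate). Writing t_m for tile_m and t_n for tile_n:

```latex
\begin{aligned}
\text{naive} &= 2MNK \\
\text{tiled} &= MK\cdot\frac{N}{t_n} + KN\cdot\frac{M}{t_m}
              = MNK\left(\frac{1}{t_m} + \frac{1}{t_n}\right) \\
\text{reuse} &= \frac{\text{naive}}{\text{tiled}}
              = \frac{2}{\frac{1}{t_m} + \frac{1}{t_n}}
              = \frac{2\,t_m t_n}{t_m + t_n} \\
t_m = t_n = t &\;\Rightarrow\; \text{reuse} = t \\
t_m = 1 &\;\Rightarrow\; \text{reuse} = \frac{2\,t_n}{1 + t_n} < 2
\end{aligned}
```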

That harmonic mean is the lab's punchline. It says: reuse is governed by the smaller tile dimension (256×16 tiles reuse like ~30, not like their area suggests); square tiles maximize reuse per unit of fast memory (a t×t tile gives reuse t while staging O(t·K) operands); and — the inference-shaped consequence — when M=1 (a single decode token), tile_m is pinned at 1 and reuse caps at 2, no matter how clever the kernel. The weights must stream once per step. That's Phase 0 lab-04's "decode is bandwidth-bound" re-derived from the kernel's side, and it's why decode optimization is about shrinking bytes (Phase 6) and sharing the stream across a batch, never about better GEMM tiling.
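The three consequences above can be read off numerically. This is a small sketch under the assumption that tile sizes divide the matrix dimensions; `reuse_factor_sketch` is a hypothetical name and need not match the signature the lab's starter expects.

```python
def reuse_factor_sketch(tile_m, tile_n):
    """Harmonic mean of the two tile dimensions (divisible case)."""
    return 2.0 / (1.0 / tile_m + 1.0 / tile_n)


square = reuse_factor_sketch(128, 128)    # 128.0
skewed = reuse_factor_sketch(256, 16)     # about 30.1
same_area = reuse_factor_sketch(64, 64)   # 64.0, same area as 256x16

# Decode shape: tile_m is pinned at 1, so widening tile_n barely helps.
decode = [reuse_factor_sketch(1, t) for t in (16, 128, 4096)]

print(square, round(skewed, 1), same_area)
print([round(r, 3) for r in decode])
```

A 64×64 tile and a 256×16 tile cover the same 4096 output elements, yet the square one reuses each loaded element about twice as often. The decode values climb toward 2 and never reach it.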

Files

  • starter.py: tiled_gemm, naive_traffic, tiled_traffic, reuse_factor. Your work.
  • solution.py: reference.
  • test_lab.py: equality (divisible, ragged, tile=1), the traffic formulas, the bigger-tiles-less-traffic direction, the harmonic mean, and the decode-shape cap.

Run

LAB_IMPL=starter pytest phase-07-gemm-and-moe-kernels/labs/lab-03-tiled-gemm -q
pytest phase-07-gemm-and-moe-kernels/labs/lab-03-tiled-gemm -q   # reference

What the tests prove

| Test | What it pins |
| --- | --- |
| test_tiled_equals_matmul_divisible / _ragged | Tiling is loop reordering, not approximation, including 37×23×19, where every edge tile is partial. Ragged edges are where real kernel bugs live (predication/masking in CUTLASS); your min() bounds are their readable form |
| test_tile_size_one_is_the_naive_algorithm | The degenerate case anchors the model: tiles of 1 = no reuse = the naive loop |
| test_traffic_formulas | The load counts, exactly: 1024³ with 128² tiles moves 16 MB-equivalents instead of 2 GB-equivalents |
| test_bigger_tiles_mean_less_traffic | The direction that justifies burning shared memory on bigger tiles |
| test_reuse_factor_is_the_harmonic_tile_size | Square 128 → reuse 128; skewed 256×16 → ~30. Shape, not area |
| test_decode_shape_has_no_reuse_to_harvest | M=1 → reuse ≤ 2. The GEMM-side proof of decode's bandwidth wall |
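The load counts in the traffic rows can be reproduced directly from the two formulas in the background section. The function names and argument order below are assumptions made for this sketch; the lab's own naive_traffic and tiled_traffic may be shaped differently. Counts are in element loads, not bytes.

```python
def ceil_div(x, y):
    """Integer ceiling division, so partial edge tiles count as whole tiles."""
    return -(-x // y)


def naive_traffic_sketch(m, n, k):
    # Each of the M*N outputs streams K elements of A and K elements of B.
    return 2 * m * n * k


def tiled_traffic_sketch(m, n, k, tile_m, tile_n):
    # A is re-streamed once per column of tiles, B once per row of tiles.
    return m * k * ceil_div(n, tile_n) + k * n * ceil_div(m, tile_m)


naive = naive_traffic_sketch(1024, 1024, 1024)
tiled = tiled_traffic_sketch(1024, 1024, 1024, 128, 128)
print(naive, tiled, naive // tiled)
# prints: 2147483648 16777216 128
```

Setting both tile sizes to 1 makes the tiled count equal the naive count, which is the degenerate case the table describes.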

Hitchhiker's notes

  • Why not just make tiles enormous? Fast memory is finite: a GPU SM has ~100–230 KB of shared memory, and a t×t fp16 tile's staging (A-panel + B-panel + accumulator) must fit — which lands real kernels at tiles like 128×128 or 128×256, exactly where your model's curve flattens against the hardware budget. Tile choice is a constrained optimization, and CUTLASS exposes it as template parameters because the optimum moves with dtype, shape, and architecture.
  • The hierarchy repeats: HBM → shared memory is your model's level, but the same arithmetic recurs for shared memory → registers (warp tiles), and L2 ↔ HBM (threadblock swizzling for L2 reuse). Production GEMMs tile at three levels with the same formula at each. Learn it once, apply it fractally.
  • Tensor cores change the FLOP rate, not the traffic math. They make the compute side faster, which raises the ridge (Phase 0 lab-04) and makes good tiling more necessary, not less — a tensor-core GEMM that under-tiles just starves faster. This is why Hopper added TMA (bulk async copies HBM→shared): feeding the tiles became the whole game.
  • Grouped GEMM (the MoE kernel, lab-01/02) is this lab plus one indirection: many small GEMMs (one per expert) whose tiles are scheduled from a single kernel launch so the tile machinery amortizes across experts. moe_align_block_size exists precisely to organize tokens into tile-shaped groups — your lab-01 argsort, upgraded to be tile-aware.
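The fast-memory constraint from the first note can be roughed out in a few lines. Everything here is a simplifying assumption rather than how any real kernel budgets its memory: the accumulator is counted in fast memory in fp32, only one K-slice of depth tile_k=32 is staged at a time, there is no double buffering, and the budget is taken as 100 KB. The helper name `staging_bytes_sketch` is hypothetical.

```python
def staging_bytes_sketch(tile_m, tile_n, tile_k, dtype_bytes, acc_bytes=4):
    """Bytes needed to stage one output tile (A-panel + B-panel + accumulator)."""
    a_panel = tile_m * tile_k * dtype_bytes
    b_panel = tile_k * tile_n * dtype_bytes
    accumulator = tile_m * tile_n * acc_bytes  # assumed fp32 accumulation
    return a_panel + b_panel + accumulator


BUDGET_BYTES = 100 * 1024  # assumed fast-memory budget
TILE_K = 32                # assumed K-slice depth
candidates = [8, 16, 32, 64, 128, 256, 512]

# Keep the power-of-two square tiles whose staging fits the budget.
fitting = [t for t in candidates
           if staging_bytes_sketch(t, t, TILE_K, dtype_bytes=2) <= BUDGET_BYTES]
best_square = max(fitting)
print(best_square, staging_bytes_sketch(best_square, best_square, TILE_K, 2))
# prints: 128 81920
```

Under these assumptions the largest square tile that fits is 128×128, and the next power of two overshoots the budget by almost 3×. Changing the dtype, the K-slice depth, or where the accumulator lives moves the answer.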

Going further

  • Add a fast_memory_bytes(tile_m, tile_n, tile_k, dtype_bytes) function and find the best square tile under a 100 KB budget for K=4096 — then compare against the tile shapes in a CUTLASS config or a Triton autotune list. You'll land within a factor of 2 of what the pros chose, from a 5-line model.
  • Time it for real: your tiled_gemm vs A @ B in numpy is unfair (BLAS is tiled and vectorized), but tiled_gemm with tile 64 vs tile 1 against each other shows the traffic effect even through Python overhead. Measure, then explain the ratio.
  • Extend the traffic model with the K-dimension split (tile_k, split-K reduction — needed when M and N are both small but K is huge). Notice the merge-partials shape from Phase 4 lab-04 reappearing: split-K GEMM is the same monoid trick, applied to plain sums.
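The split-K idea from the last item can be sketched on integers, where the partial sums recombine exactly. This is an illustration of the reduction structure only, with hypothetical function names; it says nothing about how a GPU kernel schedules the splits or where the partial results live.

```python
import random


def matmul_k_range(a, b, k0, k1):
    """Partial product using only the K-slice [k0, k1)."""
    m, n = len(a), len(b[0])
    return [[sum(a[i][p] * b[p][j] for p in range(k0, k1)) for j in range(n)]
            for i in range(m)]


def split_k_gemm_sketch(a, b, splits):
    """Compute independent partials over K-chunks, then sum them."""
    m, k, n = len(a), len(b), len(b[0])
    chunk = -(-k // splits)  # ceiling division; the last chunk may be short
    partials = [matmul_k_range(a, b, k0, min(k0 + chunk, k))
                for k0 in range(0, k, chunk)]
    # Reduction step: elementwise sum of the partial outputs.
    c = [[0] * n for _ in range(m)]
    for part in partials:
        for i in range(m):
            for j in range(n):
                c[i][j] += part[i][j]
    return c


rng = random.Random(1)
M, N, K = 2, 3, 1000  # small M and N, large K: the split-K regime
a = [[rng.randint(-3, 3) for _ in range(K)] for _ in range(M)]
b = [[rng.randint(-3, 3) for _ in range(N)] for _ in range(K)]

c_split = split_k_gemm_sketch(a, b, splits=8)
assert c_split == matmul_k_range(a, b, 0, K)
```

Each partial is an M×N matrix that has to be written out and read back for the reduction, so a traffic model extended this way would need a term that grows with the number of splits.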
