Lab 02-04 — A Block-Table-Indexed Attention in Triton [GPU-REQ]
The payoff lab. For three labs you've been managing metadata (block ids, ref counts, free queues) on the promise that some kernel, somewhere, turns those tables into actual attention. This is that kernel. You'll write a small Triton program that does what the real paged-attention kernel does: gather K/V from scattered physical blocks through a block table, inside the GPU, and produce attention output bit-for-bit (well, half-precision-for-half-precision) equal to the dense reference. The metadata finally meets the math.
No GPU? Don't panic. Do lab-06 first — it's this lab's exact algorithm in pure numpy, CPU-only, fully tested. Then read this walkthrough and the captured output; the indirection is the lesson, Triton is the dialect. You can rent an A10 for about a dollar when you want the real thing (see SETUP.md).
Contents
- Why this lab exists
- Background: what the kernel must do
- Requirements
- The task
- Steps
- Compare to the real kernel
- Captured output (real run, A10 24GB, triton 3.x)
- Hitchhiker's notes
- Reflect
- References
Why this lab exists
There's a moment of disbelief everyone has with PagedAttention: "wait — the KV for one sequence is scattered across random physical blocks, and the attention kernel just… deals with it?" Yes. And the dealing is two lines of address arithmetic. This lab exists so you stop believing that and start knowing it — because you wrote the two lines.
The career payoff is concrete: attention backends are where vLLM meets the hardware, and
"can read/modify a paged attention kernel" is the dividing line between engineers who
configure vLLM and engineers who fix it. Phase 4 (FlashAttention/FlashInfer backends),
Phase 7 (kernels), and a large fraction of real upstream PRs assume exactly the literacy
this lab builds. Triton is the right first dialect: Python-syntax, explicit about memory,
and what vLLM itself uses for many fallback kernels (upstream/vllm/attention/ops/).
Background: what the kernel must do
One decode step of attention, for one request:
out = softmax(q · Kᵀ / √d) · V
where q is this step's single query vector, and K/V are all previous tokens' keys and
values. In a dense engine, K and V are contiguous [seq_len, heads, dim] tensors —
token t is at row t. Under paging, token t lives in physical block
block_table[t // block_size] at offset t % block_size:
physical_row(t) = block_table[t // block_size] * block_size + (t % block_size)
That one formula is the entire difference between dense and paged attention. Everything else — the dot products, the softmax, the weighted sum — is unchanged. The kernel receives one extra input (the block table, an int array) and performs one extra indexed load per block. The cost is one address computation; the benefit was labs 01–03.
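To see that the formula really is the whole trick, here is a minimal pure-Python sketch (no GPU, no torch) that scatters dense rows into a shuffled pool and gathers them back through the same table. The toy sizes, the flat pool list and the string "rows" are illustrative assumptions, not the lab's data layout.

```python
import random

def physical_row(t, block_table, block_size):
    """Row of logical token t in a pool flattened to [num_blocks * block_size]."""
    return block_table[t // block_size] * block_size + (t % block_size)

# Toy sizes (hypothetical): 10 tokens, block_size 4, a pool of 8 physical blocks.
block_size, seq_len, num_blocks = 4, 10, 8
num_logical = -(-seq_len // block_size)             # ceil(10 / 4) = 3 logical blocks
dense = [f"row{t}" for t in range(seq_len)]         # stand-in for K (or V) rows

random.seed(0)
block_table = random.sample(range(num_blocks), num_logical)  # distinct, out-of-order ids
pool = [None] * (num_blocks * block_size)           # flattened physical pool

# Scatter each dense row to its physical location...
for t in range(seq_len):
    pool[physical_row(t, block_table, block_size)] = dense[t]

# ...then gather back through the same table.
gathered = [pool[physical_row(t, block_table, block_size)] for t in range(seq_len)]
print(gathered == dense)  # → True
```

Because the physical ids are distinct and the offset is always below block_size, no two tokens collide, so the gather is the identity on content no matter how the ids are shuffled.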
The second idea you'll need is online softmax (the FlashAttention trick): for long
sequences you can't materialize the full score row in fast memory, so you stream K/V block
by block, keeping a running maximum m, running denominator l, and rescaling the
accumulator as m updates. Mathematically exact (the result equals a full softmax up to floating-point rounding), with O(1) extra state. Phase 4 dives deep;
here you implement the minimal version.
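A scalar sketch of the online-softmax update in plain Python, using one value per token instead of a head_dim vector (a simplification for readability). It streams the scores in blocks while keeping m, l and the accumulator, and it should agree with a full softmax to rounding error for any block size, including a block larger than the whole sequence.

```python
import math

def full_softmax_avg(scores, values):
    """Reference: materialize the whole score row, then normalize."""
    m = max(scores)
    w = [math.exp(s - m) for s in scores]
    return sum(wi * v for wi, v in zip(w, values)) / sum(w)

def online_softmax_avg(scores, values, block):
    """Stream block by block: running max m, running denominator l, accumulator acc."""
    m, l, acc = -math.inf, 0.0, 0.0
    for start in range(0, len(scores), block):
        s_blk = scores[start:start + block]
        v_blk = values[start:start + block]
        m_new = max(m, max(s_blk))
        alpha = math.exp(m - m_new)        # rescales old state; exp(-inf) = 0 on the first block
        p = [math.exp(s - m_new) for s in s_blk]
        l = l * alpha + sum(p)
        acc = acc * alpha + sum(pi * v for pi, v in zip(p, v_blk))
        m = m_new
    return acc / l

scores = [3.0 * math.sin(t) for t in range(37)]
values = [math.cos(t) for t in range(37)]
ref = full_softmax_avg(scores, values)
for block in (1, 4, 16, 64):
    print(block, abs(online_softmax_avg(scores, values, block) - ref))
```

The only state carried between blocks is three numbers per query, which is what lets a kernel stream K/V without ever holding the full score row.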
Requirements
uv pip install -e ".[torch,triton]" # needs a CUDA GPU (T4/A10/L4 all fine)
The task
Implement single-query (one decode step) attention over a paged KV cache:
- KV cache: kv[num_blocks, block_size, num_heads, head_dim], the physical blocks, fp16.
- block_table[num_logical_blocks]: the logical → physical mapping for one sequence.
- seq_len: how many tokens are valid (the tail block is partly empty, so mask it!).
For query q[num_heads, head_dim], produce softmax(q·Kᵀ/√d)·V where K/V are gathered
through the block table.
Steps
- Torch reference first (in starter.py): a slow, obviously correct paged version (a Python loop over logical blocks, gather via the formula, regular softmax). Verify it matches a dense baseline on the same data to ~1e-3 (fp16). Never port to a kernel language something you haven't proven in a slow language. This reference is also your debugger: when the Triton version disagrees, binary-search by comparing per-block partial sums.
- Port to Triton: one program per head; loop over logical blocks. In each iteration, tl.load the physical block id from the table, load that block's K tile and compute its scores, update the online-softmax state (m, l, accumulator), and fold in the matching V tile; mask the tail block with offs < seq_len. Keep block_size equal to the tile size and the kernel stays readable (~40 lines).
- Correctness gate: max |Δ| vs the torch reference within 1e-2 (fp16 accumulation noise). Use fp32 accumulators inside the kernel (Triton's default for tl.dot) and you'll land on the order of 1e-3.
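Before touching Triton, it can help to see the kernel's loop structure in plain Python. The sketch below is an assumption-laden stand-in, not the lab's starter.py or lab-06: one head, nested lists instead of tensors, float64 instead of fp16, and hypothetical random data. It loops over logical blocks, gathers each block through the table, applies the same offs < seq_len tail mask, and updates the online-softmax state, then compares against a dense reference that uses a regular softmax.

```python
import math
import random

def dense_attention(q, K, V):
    """Reference: softmax(q·Kᵀ/√d)·V over contiguous rows, regular softmax."""
    d = len(q)
    s = [sum(a * b for a, b in zip(q, k)) / math.sqrt(d) for k in K]
    m = max(s)
    w = [math.exp(x - m) for x in s]
    z = sum(w)
    return [sum(wi * v[j] for wi, v in zip(w, V)) / z for j in range(d)]

def paged_attention(q, k_pool, v_pool, block_table, seq_len, block_size):
    """One head, one query; k_pool/v_pool are [num_blocks][block_size][head_dim]."""
    d = len(q)
    m, l, acc = -math.inf, 0.0, [0.0] * d
    for lb, pb in enumerate(block_table):                      # loop over logical blocks
        n_valid = min(block_size, seq_len - lb * block_size)   # tail mask: offs < seq_len
        if n_valid <= 0:
            break
        ks, vs = k_pool[pb][:n_valid], v_pool[pb][:n_valid]    # gather via the table
        s = [sum(a * b for a, b in zip(q, k)) / math.sqrt(d) for k in ks]
        m_new = max(m, max(s))
        alpha = math.exp(m - m_new)                            # rescale old state
        p = [math.exp(x - m_new) for x in s]
        l = l * alpha + sum(p)
        acc = [a * alpha + sum(pi * v[j] for pi, v in zip(p, vs)) for j, a in enumerate(acc)]
        m = m_new
    return [a / l for a in acc]

# Hypothetical data with the captured run's shape: seq_len 130, block_size 16.
random.seed(1)
head_dim, block_size, seq_len, num_blocks = 8, 16, 130, 100

def rand_rows(n, sd):
    return [[random.gauss(0, sd) for _ in range(head_dim)] for _ in range(n)]

K, V, q = rand_rows(seq_len, 1.0), rand_rows(seq_len, 1.0), rand_rows(1, 1.0)[0]
block_table = random.sample(range(num_blocks), -(-seq_len // block_size))

# Fill every slot with large garbage first, so a missing tail mask cannot hide.
k_pool = [rand_rows(block_size, 50.0) for _ in range(num_blocks)]
v_pool = [rand_rows(block_size, 50.0) for _ in range(num_blocks)]
for t in range(seq_len):
    pb, off = block_table[t // block_size], t % block_size
    k_pool[pb][off], v_pool[pb][off] = K[t], V[t]

ref = dense_attention(q, K, V)
out = paged_attention(q, k_pool, v_pool, block_table, seq_len, block_size)
print("max|Δ| =", max(abs(a - b) for a, b in zip(ref, out)))
```

Passing len(block_table) * block_size as seq_len disables the mask, lets the garbage rows into the softmax, and the difference jumps by orders of magnitude: the "almost right" failure mode, made loud on purpose.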
Compare to the real kernel
Now open the production versions and find your two lines:
- upstream/csrc/attention/paged_attention_v1.cu: search for block_table. Same indirection, plus vectorized 16-byte loads, warp-level reductions, head-dim tiling, and a v2 variant that partitions long sequences across thread blocks and reduces the partial results (useful for long contexts: with one thread block per sequence, a few long sequences leave most SMs idle, and the per-block score buffer in shared memory grows with context length).
- upstream/vllm/v1/attention/backends/flash_attn.py: where the metadata you've been building all phase is marshaled into the kernel's arguments. Find block_table (read path: where all prior KV lives) and slot_mapping (write path: where this step's new K/V get scattered). Two tensors, two directions: the scheduler's decisions, compiled.
The honest takeaway: production kernels are 95% performance engineering wrapped around the 5% of logic you just wrote. You now own the 5% that defines correctness; Phase 4 teaches the 95%.
Captured output (real run, A10 24GB, triton 3.x)
$ python lab.py
dense baseline : output[0,:4] = [ 0.0123 -0.0455 0.0991 0.0237]
paged torch ref : output[0,:4] = [ 0.0123 -0.0455 0.0991 0.0237] max|Δ| = 0.0e+00
paged triton : output[0,:4] = [ 0.0124 -0.0454 0.0990 0.0238] max|Δ| = 7.6e-03 ✓
seq_len=130 block_size=16 -> 9 logical blocks, physical ids = [12, 3, 47, 1, 88, 5, 9, 22, 0]
PASS: triton paged attention matches dense within 1e-2
Read the last data line closely — it's the whole phase in one line. The sequence's 130
tokens live in physical blocks [12, 3, 47, 1, 88, 5, 9, 22, 0]: out of order, scattered
anywhere in the pool (block 0 here is just whatever the allocator handed out — in
mini_vllm it'd be reserved as the null block; the simulation hands out arbitrary ids).
The 9th block holds only 130 − 8·16 = 2 valid tokens — your tail mask earned its keep.
And max|Δ| = 7.6e-03 is fp16 rounding, not error: the paged result is the dense result,
because gathering through a table is mathematically the identity. The block table changed
where bytes live, never what they mean. That sentence is PagedAttention.
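The arithmetic in that paragraph can be checked directly. This small sketch uses only the numbers printed in the captured run; the per-block boolean lists stand in for the kernel's offs < seq_len mask.

```python
block_size, seq_len = 16, 130
block_table = [12, 3, 47, 1, 88, 5, 9, 22, 0]             # physical ids from the captured run

num_logical = -(-seq_len // block_size)                 # ceil(130 / 16) = 9
tail_valid = seq_len - (num_logical - 1) * block_size   # 130 - 8*16 = 2

# Per-block mask, as a kernel would compute it from offs < seq_len.
masks = [[lb * block_size + i < seq_len for i in range(block_size)]
         for lb in range(num_logical)]

# Physical rows of the last two tokens (t = 128, 129).
rows = [block_table[t // block_size] * block_size + t % block_size for t in (128, 129)]
print(num_logical, tail_valid, sum(masks[-1]), rows)  # → 9 2 2 [0, 1]
```

Every block except the ninth is fully unmasked; the ninth keeps two of its sixteen slots, and those two tokens land in rows 0 and 1 of the pool.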
Hitchhiker's notes
- block_table reads; slot_mapping writes. Per step, the runner first scatters the new K/V into their assigned slots (slot_mapping, one entry per scheduled token), then the kernel gathers everything through block_table. Mixing these up is the most common conceptual error in this phase; they're different tensors with different shapes built from the same allocator state.
- Masking bugs read as "almost right." Forget the tail mask and you attend over garbage in the unfilled slots: outputs are subtly wrong, worse on short sequences, and pass eyeball tests. This is why the correctness gate is a max-abs-diff against a reference, never "looks plausible." (And why the gate uses varied seq_lens that don't divide evenly by block_size.)
- Why fp32 accumulators? Summing many fp16 products loses bits; flash-style kernels accumulate in fp32 and round once at the end. The 7.6e-03 above would be 10× worse with fp16 accumulation. Try it: it's a one-line change and an excellent numerics lesson.
- Decode vs prefill kernels differ. You wrote the decode shape (1 query × N keys). Prefill is M queries × N keys with causal masking: same indirection, different tiling, which is why real backends ship separate paths (and why chunked prefill needs kernels that handle "M queries starting at offset k"; see Phase 4).
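A toy model of one decode step makes the two directions concrete. Everything here is hypothetical (the block ids, the 33-token history, the flat pool list, and the helper slot_mapping_for); it assumes only that both paths use the same address formula. Note the shapes: slot_mapping has one entry per scheduled token, while block_table has one entry per logical block.

```python
def slot_mapping_for(positions, block_table, block_size):
    """Hypothetical helper: one flat slot per scheduled token (the write path)."""
    return [block_table[p // block_size] * block_size + p % block_size for p in positions]

block_size = 16
block_table = [5, 2, 7]                   # hypothetical: 3 physical blocks for one request
pool = [None] * (8 * block_size)          # flattened cache, one slot per token row

# Earlier steps wrote positions 0..32 (prefill plus previous decodes).
for p, slot in enumerate(slot_mapping_for(range(33), block_table, block_size)):
    pool[slot] = f"kv{p}"

# This decode step schedules one new token at position 33.
step_slots = slot_mapping_for([33], block_table, block_size)
pool[step_slots[0]] = "kv33"              # write: scatter via slot_mapping

# Read: the kernel gathers all 34 tokens through block_table.
seq_len = 34
gathered = [pool[block_table[t // block_size] * block_size + t % block_size]
            for t in range(seq_len)]
print(step_slots, len(block_table), gathered[-1])   # → [113] 3 kv33
```

The write touches exactly one slot; the read walks the whole table. Swap the two and you either overwrite history or read only the newest token.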
Reflect
- Why must the kernel receive the block table at all — could the runner instead copy each sequence's KV into a contiguous scratch buffer and call a dense kernel? (It could — and it would burn memory bandwidth proportional to the whole context per step, exactly the resource decode is starved for. The indirection moves the scatter/gather into the compute, paying address arithmetic — which is free next to memory traffic — instead of copies.)
- The block table for a 128k-token sequence at block_size 16 has 8192 entries. Where does it live, and does reading it hurt? (Global memory; one extra int load per 16 tokens, amortized to noise. But the CPU-side construction of batched block tables every step is real overhead, which is why upstream builds them incrementally; peek at block_table.py in the worker.)
- What breaks if two requests share a block (ref_cnt = 2, prefix caching) and one of them writes to it? (Corruption of the other's prefix, which is why shared blocks are read-only by construction: writes only ever target a request's own tail block via slot_mapping. Copy-on-write for the partial-block case is exactly how upstream handles the edge; find copy in the kv-cache manager when you're curious.)
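The copy-on-write answer in the last question can be sketched with a toy allocator. ToyBlockPool and its methods are hypothetical, written only for this illustration; upstream's kv-cache manager is organized differently. The point is the rule: a write that lands in a block with ref_cnt > 1 first copies that block and repoints only the writer's block table.

```python
class ToyBlockPool:
    """Hypothetical ref-counted block pool; not vLLM's API."""

    def __init__(self, num_blocks, block_size):
        self.block_size = block_size
        self.data = [[None] * block_size for _ in range(num_blocks)]
        self.ref_cnt = [0] * num_blocks
        self.free = list(range(num_blocks))

    def allocate(self):
        b = self.free.pop(0)
        self.ref_cnt[b] = 1
        return b

    def share(self, b):
        """Prefix caching: another request reuses block b."""
        self.ref_cnt[b] += 1
        return b

    def write(self, block_table, pos, value):
        """Write token pos for the request owning block_table; copy a shared block first."""
        lb, off = divmod(pos, self.block_size)
        b = block_table[lb]
        if self.ref_cnt[b] > 1:                      # shared: never write in place
            new_b = self.allocate()
            self.data[new_b] = list(self.data[b])    # copy existing contents
            self.ref_cnt[b] -= 1
            block_table[lb] = new_b                  # repoint only this request's table
            b = new_b
        self.data[b][off] = value


pool = ToyBlockPool(num_blocks=4, block_size=4)
table_a = [pool.allocate()]
pool.write(table_a, 0, "p0")
pool.write(table_a, 1, "p1")                         # request A: 2-token partial block
table_b = [pool.share(table_a[0])]                   # request B shares it (ref_cnt = 2)
pool.write(table_b, 2, "B2")                         # B extends its own copy

print(table_a, table_b)                              # → [0] [1]
print(pool.data[table_a[0]])                         # → ['p0', 'p1', None, None]
print(pool.data[table_b[0]])                         # → ['p0', 'p1', 'B2', None]
```

Request A's prefix is untouched, and each block ends with ref_cnt 1, so later writes by either request go in place.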
References
- upstream/csrc/attention/paged_attention_v1.cu — the production CUDA kernel.
- upstream/vllm/attention/ops/ — Triton kernels in-tree; closest cousins to yours.
- upstream/vllm/v1/attention/backends/flash_attn.py — metadata → kernel arguments.
- Kwon et al., PagedAttention (SOSP 2023), §4.3 — the kernel design: https://arxiv.org/abs/2309.06180
- Dao et al., FlashAttention (2022) — the online-softmax streaming you implemented: https://arxiv.org/abs/2205.14135
- Milakov & Gimelshein, Online normalizer calculation for softmax (2018) — the original online-softmax trick, 3 pages, very readable: https://arxiv.org/abs/1805.02867
- Triton tutorials — Fused Attention is this lab with prefill shapes: https://triton-lang.org/main/getting-started/tutorials/