Lab 04-01 — Paged Attention with Online Softmax [CPU-OK]
This lab is where the two most important kernel ideas in LLM inference fuse into one
function. From PagedAttention (Phase 2): K/V live in scattered physical blocks, reached
through a block table. From FlashAttention: you never materialize the full score row —
you stream the keys and maintain a running softmax. Put them together and you have, in
~25 lines of numpy, the algorithm at the heart of every decode kernel vLLM ships:
paged_attention_v1.cu, the Triton fallbacks, FlashInfer's decode path. When the tests
pass, you don't "know about" these kernels anymore — you've written their semantics.
Did Phase 2 lab-06 already? Good — that was the gather with ordinary softmax. This lab replaces the softmax with the online recurrence, the part that makes the streaming exact. Different load-bearing idea, same scaffolding, deliberately.
Contents
- Why this lab exists
- Background: the recurrence
- Files
- Run
- What to implement
- What the tests prove
- Hitchhiker's notes
- Going further
- References
Why this lab exists
Naive attention computes all N scores, softmaxes the row, then blends N value rows.
That's three passes over data that, for a long context, doesn't fit in any fast memory —
on a GPU it means writing an O(N) score row to HBM and reading it back, twice, in the
hottest loop of the entire system. FlashAttention's insight is that softmax can be
computed in one streaming pass with O(1) extra state, if you're willing to rescale
history every time you discover a new maximum. That rescaling trick — three running
quantities and a correction factor — is the single most important piece of kernel math in
this field, and the only way to actually own it is to implement it and watch it match the
naive answer to 1e-6 on inputs where a wrong correction factor would diverge wildly.
The phase needs this lab as its foundation: lab-03 runs this recurrence per query row (prefill), lab-04 proves it's a mergeable monoid (flash-decoding), and the deep-dive's tour of real backends assumes you can see this loop inside every one of them.
Background: the recurrence
You hold three things while streaming key blocks: m (max score so far), denom (sum of
exp(score − m) so far), acc (sum of exp(score − m) · v so far — unnormalized). For
each new block with scores s:
```
m_new = max(m, max(s))
corr  = exp(m − m_new)      # how much history shrinks under the new max
p     = exp(s − m_new)      # new block's weights, on the new scale
acc   = acc · corr + p @ V_block
denom = denom · corr + sum(p)
m     = m_new
```
Final answer: acc / denom. Why it's exact (not an approximation): every exp(s_i) you
ever wanted appears in the final sums multiplied by exp(−m_final) — the corrections
compose so each term is rescaled from whatever max it was added under to the final max.
It's a telescoping product, and the only thing subtraction-by-max changes is overflow
behavior, never the ratio. The same algebra is why the state merges across partitions
in lab-04 — write it out once by hand for two blocks and the whole phase unlocks.
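Below is a numpy sketch of the recurrence, plus the two-state merge that the same algebra implies. It makes three assumptions: a single query `q`, contiguous (unpaged) K/V, and no `1/sqrt(d)` score scaling. `stream`, `merge` and `dense_reference` are illustrative names, not the lab's API. The merge previews the lab-04 idea: rescale both states to the larger max, then add.

```python
import numpy as np

def dense_reference(q, K, V):
    """The slow truth for one query: all scores, one softmax, one blend."""
    s = K @ q
    w = np.exp(s - s.max())
    return (w / w.sum()) @ V

def stream(q, K, V, block_size):
    """Return the (m, denom, acc) state after streaming K/V block by block."""
    m, denom, acc = -np.inf, 0.0, np.zeros(V.shape[1])
    for start in range(0, K.shape[0], block_size):
        s = K[start:start + block_size] @ q
        m_new = max(m, s.max())
        corr = np.exp(m - m_new) if m != -np.inf else 0.0  # 0: no history yet
        p = np.exp(s - m_new)
        acc = acc * corr + p @ V[start:start + block_size]
        denom = denom * corr + p.sum()
        m = m_new
    return m, denom, acc

def merge(a, b):
    """Combine two states by rescaling both to the larger max (same algebra as corr)."""
    (m1, d1, acc1), (m2, d2, acc2) = a, b
    m = max(m1, m2)
    c1, c2 = np.exp(m1 - m), np.exp(m2 - m)
    return m, d1 * c1 + d2 * c2, acc1 * c1 + acc2 * c2

rng = np.random.default_rng(0)
q = rng.normal(size=8)
K = rng.normal(size=(13, 8))
V = rng.normal(size=(13, 4))

_, denom, acc = stream(q, K, V, block_size=4)
out_stream = acc / denom

# Stream tokens [0, 8) and [8, 13) independently, then merge the two states.
_, d, a = merge(stream(q, K[:8], V[:8], 4), stream(q, K[8:], V[8:], 4))
out_merged = a / d
out_dense = dense_reference(q, K, V)
```

`out_stream`, `out_merged` and `out_dense` should agree to float64 rounding. The split point is arbitrary: any partition of the tokens merges back to the same answer.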
The paged part you know from Phase 2. Token `t` lives at
`k_cache[block_table[t // block_size], t % block_size]`. The last block of a sequence
is usually partial, so read only `seq_len − start` rows of it, where `start` is the logical position of that block's first token.
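Here is a sketch of that address arithmetic and the gather it drives. It assumes a cache shaped `(num_blocks, block_size, d)` and borrows the scattered table `[3, 1, 7, 0]` from the aligned test. `token_address` and `gather_keys` are hypothetical helpers. The scatter loop stands in for the write side that Phase 2 covers.

```python
import numpy as np

def token_address(t, block_table, block_size):
    """Physical (block, offset) of logical token t."""
    return block_table[t // block_size], t % block_size

def gather_keys(k_cache, block_table, seq_len, block_size):
    """Return the seq_len logical key rows, in order, from scattered physical blocks."""
    rows = []
    for logical_block, start in enumerate(range(0, seq_len, block_size)):
        physical = block_table[logical_block]
        valid = min(block_size, seq_len - start)    # last block is usually partial
        rows.append(k_cache[physical, :valid])
    return np.concatenate(rows, axis=0)

block_size, d, seq_len = 4, 3, 13
rng = np.random.default_rng(1)
K_logical = rng.normal(size=(seq_len, d))

block_table = [3, 1, 7, 0]
k_cache = np.zeros((8, block_size, d))              # 8 physical blocks; unused slots stay zero
for t in range(seq_len):                            # scatter: the write side
    b, off = token_address(t, block_table, block_size)
    k_cache[b, off] = K_logical[t]

K_read = gather_keys(k_cache, block_table, seq_len, block_size)
```

`K_read` should equal `K_logical` row for row. Token 12 lands at physical block 0, offset 0: the single valid row of the partial last block.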
Files
- `starter.py`: `dense_attention` (the slow truth) and `paged_online_attention` (the streaming, gathered version). Your work.
- `solution.py`: reference.
- `test_lab.py`: equality with dense for aligned and ragged lengths, and paging invariance.
Run
```shell
LAB_IMPL=starter pytest phase-04-attention-backends/labs/lab-01-paged-attention-gather -q
pytest phase-04-attention-backends/labs/lab-01-paged-attention-gather -q   # reference
```
What to implement
Write `dense_attention` first and convince yourself it's correct. It's your oracle, and
the whole discipline of kernel work is never to port anything to a fast form until a slow version has proven it right. Then
write the streaming version per the recurrence above, iterating over the logical blocks that cover
`[0, seq_len)`. The two classic stumbles, both covered by tests:
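- The first-block edge: `m` starts at `−inf`, so `corr = exp(−inf − m_new)` must come out as 0, not NaN. With finite scores, IEEE arithmetic already delivers that (`exp(−inf) = 0`). NaN appears only if `m_new` is also `−inf`, because `−inf − (−inf)` is NaN, and that takes a block of all-`−inf` (masked) scores. Guard it anyway so the intent is explicit (the solution branches on `m != -inf`).
- The partial last block: `valid = min(block_size, seq_len − start)`. Read one row too many and you're attending over uninitialized cache: the bug that "almost works". Phase 2 lab-06 poisoned the padding to make this loud. Here the unused rows aren't poisoned, so the bug is quiet, but the 1e-6 equality still catches it.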
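Both stumbles fit in a few lines of plain Python floats. This sketch assumes `block_size = 4`, which is consistent with the test shapes (16 tokens in 4 blocks; 13 tokens as 3 full blocks plus 1). `safe_corr` and `valid_rows` are hypothetical helpers. The sketch also pins down when the NaN actually shows up.

```python
import math

def safe_corr(m, m_new):
    """History rescale factor; 0.0 when there is no history yet (m is -inf)."""
    if m == -math.inf:
        return 0.0                      # explicit first-block branch
    return math.exp(m - m_new)

def valid_rows(seq_len, block_size):
    """How many rows to read from each logical block covering [0, seq_len)."""
    return [min(block_size, seq_len - start) for start in range(0, seq_len, block_size)]

# Finite m_new: IEEE arithmetic already gives exp(-inf) == 0.0.
unguarded_finite = math.exp(-math.inf - 2.5)

# m_new also -inf (every score in the block is -inf, e.g. fully masked):
# -inf - (-inf) is NaN, and a NaN here would poison acc and denom from then on.
unguarded_masked = math.exp(-math.inf - (-math.inf))

rows_ragged = valid_rows(13, 4)     # shape of the partial-last-block test
rows_aligned = valid_rows(16, 4)    # shape of the aligned test
```

`rows_ragged` comes out as `[4, 4, 4, 1]`. Reading `block_size` rows from that last block is exactly the one-row-too-many bug.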
What the tests prove
| Test | What it pins |
|---|---|
test_matches_dense_block_aligned | The recurrence itself: 16 tokens, 4 scattered blocks ([3, 1, 7, 0]), equal to dense within 1e-6. A wrong corr doesn't fail subtly — softmax weights are exponential in the error, so divergence is loud |
test_matches_dense_partial_last_block | 13 tokens = 3 full + 1 single-token block: the valid bound |
test_paging_invariance | Same logical sequence at physical placements [0,1,2] vs [7,3,5] → identical output. The block table is the only coupling between logical and physical — Phase 2's identity theorem, restated where the math happens |
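The following self-contained sketch shows what the invariance test checks. The same logical K/V is scattered at two placements, read back through the block table, and streamed with the recurrence. It assumes `block_size = 4` and `seq_len = 10` (three logical blocks, matching the three-entry tables); the real test's sizes may differ. `paged_online` and `place` are illustrative names, not the lab's `paged_online_attention` signature.

```python
import numpy as np

def paged_online(q, k_cache, v_cache, block_table, seq_len):
    """One query, paged K/V, online softmax: a sketch of the streaming function."""
    block_size = k_cache.shape[1]
    m, denom, acc = -np.inf, 0.0, np.zeros(v_cache.shape[2])
    for i, start in enumerate(range(0, seq_len, block_size)):
        phys = block_table[i]                       # logical block i -> physical block
        valid = min(block_size, seq_len - start)    # partial last block
        s = k_cache[phys, :valid] @ q
        m_new = max(m, s.max())
        corr = np.exp(m - m_new) if m != -np.inf else 0.0
        p = np.exp(s - m_new)
        acc = acc * corr + p @ v_cache[phys, :valid]
        denom = denom * corr + p.sum()
        m = m_new
    return acc / denom

def place(K, V, block_table, num_blocks, block_size):
    """Scatter logical rows into the physical blocks named by block_table."""
    k_cache = np.zeros((num_blocks, block_size, K.shape[1]))
    v_cache = np.zeros((num_blocks, block_size, V.shape[1]))
    for t in range(K.shape[0]):
        phys, off = block_table[t // block_size], t % block_size
        k_cache[phys, off] = K[t]
        v_cache[phys, off] = V[t]
    return k_cache, v_cache

rng = np.random.default_rng(2)
seq_len, block_size, num_blocks = 10, 4, 8          # last logical block holds 2 tokens
q = rng.normal(size=6)
K = rng.normal(size=(seq_len, 6))
V = rng.normal(size=(seq_len, 5))

outputs = {}
for table in ([0, 1, 2], [7, 3, 5]):
    k_cache, v_cache = place(K, V, table, num_blocks, block_size)
    outputs[tuple(table)] = paged_online(q, k_cache, v_cache, table, seq_len)

s = K @ q
w = np.exp(s - s.max())
dense = (w / w.sum()) @ V
```

Both placements feed the same values through the same operations in the same order. The two outputs should therefore match each other, and match `dense` to float64 rounding.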
Hitchhiker's notes
- Map your variables to the CUDA kernel. In `paged_attention_v1.cu`, your `m` is `qk_max` (computed via warp/block reductions instead of `max()`), your `denom` is `exp_sum`, your `acc` lives in registers as `accs`, and your gather is the `block_table`-indexed pointer arithmetic in the main loop. One difference to watch for: v1 keeps the sequence's logits in shared memory and reduces `qk_max` over all of them before exponentiating. It takes the max up front instead of rescaling as it goes; folding that into a single pass is exactly what the online recurrence buys. Read the kernel right after finishing. It's ~400 lines, of which you now understand the load-bearing 40. The rest is vectorized loads, shared-memory staging and reduction plumbing (the "95% performance engineering" of Phase 2 lab-04).
- Why subtract the max at all, again? `exp(90)` overflows float32. A logit of ~90 is not exotic: it's a confident model with a sharp head. Unprotected softmax is a NaN factory, and subtraction-by-max makes every exponent ≤ 0. The online version simply maintains that protection without knowing the max in advance, and that's the whole cleverness.
- One query here, many heads in reality. Real decode runs this once per (sequence, KV-head), with the query being that head's slice. Heads are embarrassingly parallel and share nothing (Phase 2 lab-06's per-head loop). GQA means several query heads stream the same K/V blocks: bandwidth amortization inside the kernel, and one more reason GQA wins (Phase 0 lab-02).
- Numerics note on the tests' 1e-6: the tests use float64 throughout, so the tolerance is generous. It's calibrated to catch algorithmic error (a missing `corr`, an off-by-one), not rounding. In fp16 kernels the same comparison runs at 1e-2 with fp32 accumulators (Phase 2 lab-04's gate). The tolerance always encodes what you're testing for.
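The overflow claim is easy to check directly in float32. This sketch uses made-up logits (90, 89, 10) to stand in for a sharp head.

```python
import numpy as np

logits = np.array([90.0, 89.0, 10.0], dtype=np.float32)

with np.errstate(over="ignore", invalid="ignore"):
    e = np.exp(logits)            # exp(90), exp(89) exceed float32 max (~3.4e38) -> inf
    naive = e / e.sum()           # inf / inf -> nan for the two largest logits

shifted = np.exp(logits - logits.max())   # every exponent <= 0, values in (0, 1]
stable = shifted / shifted.sum()
```

The naive weights for the two largest logits come out NaN. The shifted version stays finite and still sums to 1 in float32.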
Going further
- Hand-trace two blocks of two tokens each on paper, with block 2's max larger than block 1's, and watch `corr` shrink the history. Then trace it once with block 2's max smaller: `corr = 1` and nothing rescales. The recurrence has exactly these two behaviors.
- Delete the `corr` factor and run the tests. The aligned test fails with weights skewed toward the earlier blocks, the ones accumulated under a smaller running max and never scaled down. Now you know this failure's signature, which is useful the day you review a kernel PR that gets it almost right.
- Batch it: take a list of `(q, block_table, seq_len)` and loop over it. You've now built `paged_attention_v1`'s grid (one program per sequence per head). Then go to lab-04 to split within a sequence, and to lab-03 to widen to query chunks.
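The first two exercises fit in one small script. This sketch assumes the scores are already computed per block. It also uses one-hot value rows, a trick for this illustration only and not something the lab does, so the output reads as the softmax mass landing on each block. `stream_mass` is a hypothetical helper.

```python
import numpy as np

def stream_mass(s_blocks, use_corr=True):
    """Stream score blocks; V rows are one-hot block ids, so output = softmax mass per block.

    Returns (mass_per_block, corr_factors)."""
    n = len(s_blocks)
    m, denom, acc, corrs = -np.inf, 0.0, np.zeros(n), []
    for j, s in enumerate(s_blocks):
        s = np.asarray(s, dtype=float)
        v = np.zeros((len(s), n))
        v[:, j] = 1.0                               # every token in block j votes for column j
        m_new = max(m, s.max())
        if use_corr:
            corr = np.exp(m - m_new) if m != -np.inf else 0.0
        else:
            corr = 1.0                              # the bug: history never rescaled
        corrs.append(corr)
        p = np.exp(s - m_new)
        acc = acc * corr + p @ v
        denom = denom * corr + p.sum()
        m = m_new
    return acc / denom, corrs

rising = [[0.0, 1.0], [3.0, 2.0]]     # block 2's max is larger
falling = [[3.0, 2.0], [0.0, 1.0]]    # block 2's max is smaller

mass_rising, corrs_rising = stream_mass(rising)
_, corrs_falling = stream_mass(falling)
mass_buggy, _ = stream_mass(rising, use_corr=False)
```

For `rising`, the second `corr` is `exp(1 − 3)`; for `falling`, it is exactly 1. Without `corr`, both blocks end up with equal mass. The correct softmax gives block 1 only about 12%, so the buggy version overweights the earlier block.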
References
- Milakov & Gimelshein, Online normalizer calculation for softmax (2018). The recurrence, in 3 readable pages: https://arxiv.org/abs/1805.02867
- Dao et al., FlashAttention (2022). The recurrence plus tiling plus IO analysis: https://arxiv.org/abs/2205.14135
- `upstream/csrc/attention/paged_attention_v1.cu`: `qk_max`, `exp_sum`, the gather. Your function in CUDA.
- `upstream/vllm/v1/attention/backends/flash_attn.py:592`: where the real engine hands block tables and slot mappings to the kernel (find both; Phase 2 lab-06 explains the write side).
- `02-mini-build.md`: the recurrence derived step by step.