
Lab 04-01 — Paged Attention with Online Softmax [CPU-OK]

This lab is where the two most important kernel ideas in LLM inference fuse into one function. From PagedAttention (Phase 2): K/V live in scattered physical blocks, reached through a block table. From FlashAttention: you never materialize the full score row — you stream the keys and maintain a running softmax. Put them together and you have, in ~25 lines of numpy, the algorithm at the heart of every decode kernel vLLM ships: paged_attention_v1.cu, the Triton fallbacks, FlashInfer's decode path. When the tests pass, you don't "know about" these kernels anymore — you've written their semantics.

Already done Phase 2 lab-06? Good: that was the gather with ordinary softmax. This lab replaces the softmax with the online recurrence, the part that makes streaming exact. Different load-bearing idea, same scaffolding, deliberately.

Why this lab exists

Naive attention computes all N scores, softmaxes the row, then blends N value rows. That's three passes over data that, for a long context, doesn't fit in any fast memory — on a GPU it means writing an O(N) score row to HBM and reading it back, twice, in the hottest loop of the entire system. FlashAttention's insight is that softmax can be computed in one streaming pass with O(1) extra state, if you're willing to rescale history every time you discover a new maximum. That rescaling trick — three running quantities and a correction factor — is the single most important piece of kernel math in this field, and the only way to actually own it is to implement it and watch it match the naive answer to 1e-6 on inputs where a wrong correction factor would diverge wildly.

The phase needs this lab as its foundation: lab-03 runs this recurrence per query row (prefill), lab-04 proves it's a mergeable monoid (flash-decoding), and the deep-dive's tour of real backends assumes you can see this loop inside every one of them.

Background: the recurrence

You hold three things while streaming key blocks: m (max score so far), denom (sum of exp(score − m) so far), acc (sum of exp(score − m) · v so far — unnormalized). For each new block with scores s:

m_new  = max(m, max(s))
corr   = exp(m − m_new)            # how much history shrinks under the new max
p      = exp(s − m_new)            # new block's weights, on the new scale
acc    = acc · corr + p @ V_block
denom  = denom · corr + sum(p)
m      = m_new

Final answer: acc / denom. Why it's exact (not an approximation): every exp(s_i) you ever wanted appears in the final sums multiplied by exp(−m_final) — the corrections compose so each term is rescaled from whatever max it was added under to the final max. It's a telescoping product, and the only thing subtraction-by-max changes is overflow behavior, never the ratio. The same algebra is why the state merges across partitions in lab-04 — write it out once by hand for two blocks and the whole phase unlocks.
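The recurrence above can be sketched in numpy over a contiguous (not yet paged) K/V and checked against a naive softmax. This is a minimal sketch, not the lab's starter signature. The names online_attention and naive_attention are hypothetical, there is a single query vector, and the 1/sqrt(d) score scale is omitted because it doesn't change the recurrence.

```python
import numpy as np

def naive_attention(q, K, V):
    """All scores at once: softmax the full row, then blend the values."""
    s = K @ q
    w = np.exp(s - s.max())
    return (w / w.sum()) @ V

def online_attention(q, K, V, block_size):
    """Stream K/V block by block, keeping only (m, denom, acc)."""
    m = -np.inf                   # max score so far
    denom = 0.0                   # sum of exp(score - m) so far
    acc = np.zeros(V.shape[1])    # sum of exp(score - m) * v so far (unnormalized)
    for start in range(0, K.shape[0], block_size):
        s = K[start:start + block_size] @ q
        m_new = max(m, s.max())
        corr = np.exp(m - m_new)  # first block: exp(-inf) = 0.0, since m_new is finite
        p = np.exp(s - m_new)     # this block's weights on the new scale
        acc = acc * corr + p @ V[start:start + block_size]
        denom = denom * corr + p.sum()
        m = m_new
    return acc / denom

rng = np.random.default_rng(0)
q = rng.standard_normal(8)
K = rng.standard_normal((13, 8))
V = rng.standard_normal((13, 8))
print(np.max(np.abs(online_attention(q, K, V, 4) - naive_attention(q, K, V))))
```

Any block_size gives the same answer up to rounding; block_size = 13 degenerates to a single block, which is just the naive computation.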

The paged part you know from Phase 2: token t is at k_cache[block_table[t // block_size], t % block_size], and the last block of a sequence is usually partial. If logical block b begins at token start = b · block_size, read only seq_len − start rows of it.
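The addressing rule can be exercised on its own. This sketch assumes a cache laid out as (num_physical_blocks, block_size, head_dim), matching the k_cache[block, offset] indexing above. gather_rows is a hypothetical helper, not a function from the lab files.

```python
import numpy as np

def gather_rows(cache, block_table, seq_len, block_size):
    """Rebuild a sequence's contiguous rows from a paged cache via its block table."""
    rows = []
    num_blocks = (seq_len + block_size - 1) // block_size
    for b in range(num_blocks):
        start = b * block_size
        valid = min(block_size, seq_len - start)   # the last block may be partial
        rows.append(cache[block_table[b], :valid])
    return np.concatenate(rows, axis=0)

block_size, head_dim, seq_len = 4, 8, 13
block_table = [3, 1, 7, 0]                          # logical block b lives in physical block_table[b]
rng = np.random.default_rng(0)
K_logical = rng.standard_normal((seq_len, head_dim))
k_cache = np.zeros((8, block_size, head_dim))       # 8 physical blocks, half of them unused here
for t in range(seq_len):                            # write side: token t -> (block_table[t // bs], t % bs)
    k_cache[block_table[t // block_size], t % block_size] = K_logical[t]

assert np.array_equal(gather_rows(k_cache, block_table, seq_len, block_size), K_logical)
```

Thirteen tokens need four blocks, and the fourth (physical block 0) contributes a single row.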

Files

  • starter.py: dense_attention (the slow truth) and paged_online_attention (the streaming, gathered version). Your work.
  • solution.py: reference.
  • test_lab.py: equality with dense for aligned and ragged lengths, and paging invariance.

Run

LAB_IMPL=starter pytest phase-04-attention-backends/labs/lab-01-paged-attention-gather -q
pytest phase-04-attention-backends/labs/lab-01-paged-attention-gather -q   # reference

What to implement

Write dense_attention first and convince yourself it's correct. It's your oracle, and the whole discipline of kernel work is that you never port what you haven't first proven correct in its slow form. Then write the streaming version per the recurrence above, iterating over the logical blocks that cover [0, seq_len). The two classic stumbles, both covered by tests:

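  • The first-block edge: m starts at −inf, so corr = exp(−inf − m_new) must come out as 0. IEEE arithmetic already gives 0 when m_new is finite, but if m_new is also −inf (a block whose scores are all −inf), then −inf − (−inf) is NaN and it poisons acc and denom. Guard it (the solution branches on m != -inf).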
  • The partial last block: valid = min(block_size, seq_len − start). Read one row too many and you're attending over uninitialized cache, the bug that "almost works". Phase 2 lab-06 poisoned the padding to make this loud; here the stray rows are zeros, which is quiet, but the 1e-6 equality still catches it: a zero key scores 0 and still adds exp(0 − m_new) to denom.
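Putting both stumbles together, one way the streaming, gathered function might look. This is a sketch, not the solution file: the signature paged_online_attention(q, k_cache, v_cache, block_table, seq_len, block_size) and the (num_blocks, block_size, head_dim) cache layout are assumptions, and the 1/sqrt(d) scale is omitted.

```python
import numpy as np

def dense_attention(q, K, V):
    """Oracle: materialize every score, softmax the row, blend the values."""
    s = K @ q
    w = np.exp(s - s.max())
    return (w / w.sum()) @ V

def paged_online_attention(q, k_cache, v_cache, block_table, seq_len, block_size):
    """Stream the logical blocks covering [0, seq_len) through the block table."""
    m, denom = -np.inf, 0.0
    acc = np.zeros(v_cache.shape[-1])
    num_blocks = (seq_len + block_size - 1) // block_size
    for b in range(num_blocks):
        start = b * block_size
        valid = min(block_size, seq_len - start)           # stumble 2: partial last block
        k_blk = k_cache[block_table[b], :valid]            # gather through the block table
        v_blk = v_cache[block_table[b], :valid]
        s = k_blk @ q
        m_new = max(m, s.max())
        corr = np.exp(m - m_new) if m != -np.inf else 0.0  # stumble 1: no history yet
        p = np.exp(s - m_new)
        acc = acc * corr + p @ v_blk
        denom = denom * corr + p.sum()
        m = m_new
    return acc / denom

# Usage: 13 tokens in scattered blocks, compared against the dense oracle.
rng = np.random.default_rng(1)
block_size, d, seq_len, block_table = 4, 8, 13, [3, 1, 7, 0]
K = rng.standard_normal((seq_len, d))
V = rng.standard_normal((seq_len, d))
q = rng.standard_normal(d)
k_cache = np.zeros((8, block_size, d))
v_cache = np.zeros((8, block_size, d))
for t in range(seq_len):
    k_cache[block_table[t // block_size], t % block_size] = K[t]
    v_cache[block_table[t // block_size], t % block_size] = V[t]
out = paged_online_attention(q, k_cache, v_cache, block_table, seq_len, block_size)
assert np.allclose(out, dense_attention(q, K, V), atol=1e-6)
```

Passing seq_len = 16 against the same cache reproduces the off-by-one bug: the three zero rows in physical block 0 inflate denom and the result drifts from the oracle.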

What the tests prove

| Test | What it pins |
|---|---|
| test_matches_dense_block_aligned | The recurrence itself: 16 tokens in 4 scattered blocks ([3, 1, 7, 0]), equal to dense within 1e-6. A wrong corr doesn't fail subtly: softmax weights are exponential in the error, so divergence is loud. |
| test_matches_dense_partial_last_block | 13 tokens = 3 full blocks + 1 single-token block: the valid bound. |
| test_paging_invariance | The same logical sequence at physical placements [0,1,2] vs [7,3,5] gives identical output. The block table is the only coupling between logical and physical: Phase 2's identity theorem, restated where the math happens. |

Hitchhiker's notes

  • Map your variables to the CUDA kernel: in paged_attention_v1.cu, your m is qk_max (computed via warp/block reductions instead of max()), your denom is exp_sum, your acc lives in registers as accs, and your gather is the block_table-indexed pointer arithmetic in the main loop. Read the kernel right after finishing — it's ~400 lines of which you now understand the load-bearing 40; the rest is vectorized loads, shared-memory staging, and reduction plumbing (the "95% performance engineering" of Phase 2 lab-04).
  • Why subtract the max at all, again? exp(90) overflows float32. Logit ~90 is not exotic — it's a confident model with a sharp head. Unprotected softmax is a NaN factory; subtraction-by-max makes every exponent ≤ 0. The online version just maintains that protection without knowing the max in advance — that's the whole cleverness.
  • One query here, many heads in reality: real decode runs this once per (sequence, KV-head) with the query being that head's slice — heads are embarrassingly parallel and share nothing (Phase 2 lab-06's per-head loop). GQA means several query heads stream the same K/V blocks — bandwidth amortization inside the kernel, one more reason GQA wins (Phase 0 lab-02).
  • Numerics note for the tests' 1e-6: float64 throughout, so the tolerance is generous — it's calibrated to catch algorithmic error (a missing corr, an off-by-one), not rounding. In fp16 kernels the same comparison runs at 1e-2 with fp32 accumulators (Phase 2 lab-04's gate); the tolerance always encodes what you're testing for.
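The subtraction-by-max note can be seen in a few lines of float32 numpy. This is a sketch with made-up logits chosen so the large ones overflow: float32 tops out near 3.4e38, and exp(x) passes that just below x = 88.7.

```python
import numpy as np

logits = np.array([90.0, 89.0, 10.0], dtype=np.float32)  # illustrative values, not from the lab

with np.errstate(over="ignore", invalid="ignore"):
    e = np.exp(logits)        # exp(90) and exp(89) overflow float32 to inf
    naive = e / e.sum()       # inf / inf = nan

shifted = np.exp(logits - logits.max())   # every exponent <= 0, largest term is exp(0) = 1
stable = shifted / shifted.sum()

print(naive)    # nan for the two large logits
print(stable)   # finite weights that sum to 1
```

The online recurrence keeps the same guarantee block by block: every p = exp(s − m_new) has a non-positive exponent, and corr = exp(m − m_new) is at most 1.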

Going further

  • Hand-trace two blocks with two tokens each on paper, with block 2's max larger than block 1's — watch corr shrink the history. Then once with block 2's max smaller — watch corr = 1 and nothing rescale. The recurrence has exactly these two behaviors.
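  • Delete the corr factor and run the tests: the aligned test fails, with weights skewed toward the earlier blocks, because history accumulated under a smaller max is never scaled down when a larger one arrives. Now you know this failure's signature, which is useful the day you review a kernel PR that gets it almost right.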
  • Batch it: take a list of (q, block_table, seq_len) and loop — you've built paged_attention_v1's grid (one program per sequence per head). Then go to lab-04 to split within a sequence, and lab-03 to widen to query chunks.
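The first two exercises can also be run instead of traced on paper. This is a minimal sketch with scalar values so every intermediate is hand-checkable. stream and softmax_blend are hypothetical helpers, and use_corr=False simulates deleting the correction factor.

```python
import numpy as np

def stream(blocks, use_corr=True):
    """Run the recurrence over (scores, values) blocks with scalar values.
    Returns the output and the corr applied at each block."""
    m, denom, acc = -np.inf, 0.0, 0.0
    corrs = []
    for s, v in blocks:
        s, v = np.asarray(s, dtype=float), np.asarray(v, dtype=float)
        m_new = max(m, s.max())
        if not use_corr:
            corr = 1.0                      # the bug: history is never rescaled
        elif m == -np.inf:
            corr = 0.0                      # first block: no history to rescale
        else:
            corr = float(np.exp(m - m_new))
        p = np.exp(s - m_new)
        acc = acc * corr + p @ v
        denom = denom * corr + p.sum()
        m = m_new
        corrs.append(corr)
    return acc / denom, corrs

def softmax_blend(scores, values):
    """Reference: ordinary softmax over all scores, then a weighted sum of values."""
    w = np.exp(np.asarray(scores, dtype=float) - max(scores))
    return float(w @ np.asarray(values, dtype=float) / w.sum())

# Block 2's max (3.0) beats block 1's (1.0): corr = exp(1 - 3), history shrinks.
up = [([0.0, 1.0], [1.0, 2.0]), ([3.0, 2.0], [3.0, 4.0])]
# Block 2's max (0.5) is below block 1's (1.0): corr = exp(0) = 1, nothing rescales.
down = [([0.0, 1.0], [1.0, 2.0]), ([0.5, -1.0], [3.0, 4.0])]
for blocks in (up, down):
    out, corrs = stream(blocks)
    print(corrs, out, softmax_blend(blocks[0][0] + blocks[1][0], blocks[0][1] + blocks[1][1]))

# Deleting corr: block 1 (score 0, value 1) is never scaled down after block 2 (score 5).
bad, _ = stream([([0.0], [1.0]), ([5.0], [0.0])], use_corr=False)
good, _ = stream([([0.0], [1.0]), ([5.0], [0.0])])
print(bad, good)   # bad = 0.5 overweights the earlier block; good = 1 / (1 + e^5)
```

The no-corr run shows the failure signature from the second exercise: the earlier block keeps its weight from the old scale, so the output leans toward it.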

References

  • Milakov & Gimelshein, Online normalizer calculation for softmax (2018) — the recurrence, 3 readable pages: https://arxiv.org/abs/1805.02867
  • Dao et al., FlashAttention (2022) — the recurrence + tiling + IO analysis: https://arxiv.org/abs/2205.14135
  • upstream/csrc/attention/paged_attention_v1.cu: qk_max, exp_sum, the gather; your function in CUDA.
  • upstream/vllm/v1/attention/backends/flash_attn.py:592 — where the real engine hands block tables and slot mappings to the kernel (find both; Phase 2 lab-06 explains the write side).
  • 02-mini-build.md — the recurrence derived step by step.