Lab 02-06 — Paged Attention in Pure Numpy [CPU-OK]

For the whole phase, you've been told the kernel "just follows the block table." Here you make that sentence true with your own hands: no GPU, no Triton, no excuses. You'll implement the complete data path of one decode step: build the slot_mapping (the write map), scatter new K/V into a shuffled physical cache (write_kv), then gather it all back through the block table and compute attention (paged_attention). Then you'll prove, to 1e-12, that the result is identical to attention over a contiguous cache.

Do this lab before lab-04 (Triton) — it's the same algorithm; lab-04 just adds the GPU dialect and the online-softmax streaming. If you have no GPU, this lab is your kernel lab, with nothing lost but the silicon.

Why this lab exists

There's a gap in most people's understanding of PagedAttention, right between "the allocator hands out block ids" (labs 01–05) and "the CUDA kernel is fast" (lab-04, Phase 4). The gap is the data path: how, concretely, does a token's K vector end up at a physical address, and how does attention find it again? Numpy is the perfect language for closing that gap — fancy indexing makes the scatter and the gather each a single line, so the indirection stands alone with zero kernel noise around it.

The deeper point this lab proves is the load-bearing theorem of the whole phase:

Gather-through-a-table is mathematically the identity. Paging changes where bytes live, never what they mean. Attention over a paged cache is not an approximation of dense attention — it is dense attention, composed with a permutation.

The tests don't just claim this; they check it to 1e-12 (same dtype, same operation order ⇒ the only differences would be real bugs, not float noise), and then check it twice — same logical content under two different physical layouts must produce bit-identical answers. When you later benchmark real paged kernels and someone asks "but does paging hurt accuracy?", you'll have the right reflex: it can't; only masking or indexing bugs can.
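One way to write the identity claim down, as a sketch rather than the lab's official statement. The symbols are assumptions of this note: B is block_size, T is the block table, k_t and v_t are the K/V rows of logical token t, C_K and C_V are the flat physical caches, and n is seq_len. The argument needs the slot map s to be injective on 0..n-1, which holds whenever the block table's entries are distinct.

```latex
% Sketch under assumed notation: B = block_size, T = block table,
% C_K, C_V = flat physical caches, n = seq_len, d = head dimension.
\begin{align*}
s(t) &= T[\lfloor t / B \rfloor] \cdot B + (t \bmod B), \qquad 0 \le t < n \\
\text{write:}\quad C_K[s(t)] &= k_t, \qquad C_V[s(t)] = v_t \\
\text{gather:}\quad \tilde{k}_t &= C_K[s(t)] = k_t, \qquad \tilde{v}_t = C_V[s(t)] = v_t \\
\text{hence:}\quad \operatorname{softmax}\!\left(\frac{\tilde{K} q}{\sqrt{d}}\right)^{\!\top} \tilde{V}
  &= \operatorname{softmax}\!\left(\frac{K q}{\sqrt{d}}\right)^{\!\top} V
\end{align*}
```

Nothing in the last line depends on T: the table cancels out between the write and the gather, which is why two different physical layouts have to give the same answer.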

Background: two maps, two directions

Each engine step, the model runner (upstream: gpu_model_runner.py) turns the scheduler's block tables into two tensors for the kernels — and they answer opposite questions:

  • slot_mapping is the write map, with one entry per scheduled token this step: "put this new token's K/V at this flat cache row." For a decode step that's a single entry per request (start = current length, num_tokens = 1); for a prefill chunk it's the chunk's whole range. The formula is the phase's one formula: slot(t) = block_table[t // block_size] * block_size + t % block_size.
  • block_table is the read map, with one entry per logical block of the whole sequence: "all prior KV for this request lives in these physical blocks, in this logical order." The attention kernel gathers through it every step.

Write one token; read them all. That asymmetry is the decode workload in a nutshell, and it's why decode is memory-bandwidth-bound: the gather reads seq_len × heads × dim × 2 values (K and V) to produce one token.
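The write-map formula above is small enough to check by hand. Here is a minimal sketch, not the lab's solution.py: the block table [7, 2, 5] and block_size = 16 are made-up values for illustration, and the function body is just one plausible way to vectorize the formula.

```python
import numpy as np

def build_slot_mapping(block_table, block_size, start, num_tokens):
    """Flat cache row for each logical token in [start, start + num_tokens)."""
    t = np.arange(start, start + num_tokens)
    table = np.asarray(block_table)
    # slot(t) = block_table[t // block_size] * block_size + t % block_size
    return table[t // block_size] * block_size + t % block_size

# Hypothetical layout: logical blocks 0, 1, 2 live in physical blocks 7, 2, 5.
block_table = [7, 2, 5]
block_size = 16

# Decode step: one token at start = current length (35 -> block 2, offset 3).
decode_slots = build_slot_mapping(block_table, block_size, start=35, num_tokens=1)

# Prefill chunk that straddles a block boundary (tokens 14..17).
prefill_slots = build_slot_mapping(block_table, block_size, start=14, num_tokens=4)

print(decode_slots)   # [83]
print(prefill_slots)  # [126 127  32  33]
```

Note the jump from 127 to 32 in the second result: logically adjacent tokens land in physically unrelated rows the moment a block boundary is crossed.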

Files

  • starter.py — three functions with the recipes in their docstrings. Your work.
  • solution.py — reference (the gather really is one line).
  • test_lab.py — formula checks, round-trip, dense-equivalence, the poison-masked tail, and the two-layouts identity.

Run

LAB_IMPL=starter pytest phase-02-paged-attention/labs/lab-06-paged-attention-numpy -q
pytest phase-02-paged-attention/labs/lab-06-paged-attention-numpy -q   # reference (default)

What to implement

  1. build_slot_mapping(block_table, block_size, start, num_tokens) — the formula, over a token range. The start parameter is not decoration: a decode step writes one token at start = seq_len, a chunked prefill writes a range starting mid-sequence — getting ranges right here is exactly what makes chunked prefill (Phase 3) compose with paging.
  2. write_kv(...) — scatter new_k/new_v rows to slot_mapping rows. Numpy fancy indexing (cache[slots] = new) — one line each, and a quiet preview of what reshape_and_cache does in CUDA upstream.
  3. paged_attention(q, k_cache, v_cache, block_table, seq_len, block_size) — gather seq_len rows through the table, then per head: softmax(K·q/√d)·V. Subtract the max before exp (the standard stability trick — and the seed of the online softmax you'll meet in lab-04).
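To see how the three pieces fit together, here is a compact sketch of the whole data path. Treat it as one possible shape, not the reference: the write_kv argument order, the cache layout [num_slots, heads, dim], the q layout [heads, dim], and the dense_attention helper are all assumptions of this example (starter.py defines the real signatures). Try the lab yourself first; this is for comparing notes afterwards.

```python
import numpy as np

def build_slot_mapping(block_table, block_size, start, num_tokens):
    t = np.arange(start, start + num_tokens)
    return np.asarray(block_table)[t // block_size] * block_size + t % block_size

def write_kv(k_cache, v_cache, slot_mapping, new_k, new_v):
    # Assumed argument order. Scatter: one fancy-indexed assignment each.
    k_cache[slot_mapping] = new_k
    v_cache[slot_mapping] = new_v

def dense_attention(q, k, v):
    # Hypothetical helper. q: [heads, dim]; k, v: [seq_len, heads, dim].
    out = np.empty_like(q)
    scale = 1.0 / np.sqrt(q.shape[-1])
    for h in range(q.shape[0]):
        scores = (k[:, h, :] @ q[h]) * scale
        w = np.exp(scores - scores.max())   # subtract the max before exp
        out[h] = (w / w.sum()) @ v[:, h, :]
    return out

def paged_attention(q, k_cache, v_cache, block_table, seq_len, block_size):
    # Gather exactly seq_len rows through the table, then attend as usual.
    slots = build_slot_mapping(block_table, block_size, 0, seq_len)
    return dense_attention(q, k_cache[slots], v_cache[slots])

rng = np.random.default_rng(0)
heads, dim, block_size, seq_len, num_blocks = 4, 8, 16, 35, 8
block_table = [6, 1, 4]                      # shuffled, non-identity layout
k = rng.standard_normal((seq_len, heads, dim))
v = rng.standard_normal((seq_len, heads, dim))
q = rng.standard_normal((heads, dim))

k_cache = np.zeros((num_blocks * block_size, heads, dim))
v_cache = np.zeros_like(k_cache)
write_kv(k_cache, v_cache,
         build_slot_mapping(block_table, block_size, 0, seq_len), k, v)

paged = paged_attention(q, k_cache, v_cache, block_table, seq_len, block_size)
dense = dense_attention(q, k, v)
max_diff = np.abs(paged - dense).max()
print(max_diff <= 1e-12)
```

Because the gathered rows are the same float64 values fed through the same operations in the same order, the comparison has no rounding slack to hide behind.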

What the tests prove — including the poison trick

| Test | What it pins |
| --- | --- |
| test_slot_mapping_formula | The formula at the edges: block boundaries, mid-block offsets, and the single-token decode case |
| test_write_then_gather_round_trips | Write map and read map agree: the two tensors are consistent views of one layout |
| test_paged_matches_dense_exactly | The identity theorem, atol=1e-12, under a shuffled, non-identity block table |
| test_partial_tail_block_is_masked | The bug that ships: seq_len=35 fills 2 blocks + 3 slots; the other 13 slots of the tail block are poisoned with 1e6 before the call. If your gather uses len(block_table) * block_size rows instead of seq_len, the poison detonates and the diff is enormous, by design. Real kernels' masking bugs are subtle precisely because real garbage memory is small numbers; in tests, make garbage loud. |
| test_indirection_is_the_identity | Same logical tokens, two different physical placements → identical output. Physical layout is unobservable from the math |

That poison-the-padding trick is worth stealing for every masked computation you ever test: don't hope the unmasked path is never read — make reading it catastrophic.
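A stripped-down version of the poison trick, as a sketch and not a copy of test_lab.py. It uses a single head for brevity, and the helper names attend and slots_for are invented for this example. One assumption is deliberate: q is kept positive so the poisoned rows win the softmax.

```python
import numpy as np

def attend(q, k, v):
    # Single-head attention for brevity. q: [dim]; k, v: [n, dim].
    s = k @ q / np.sqrt(q.shape[0])
    w = np.exp(s - s.max())
    return (w / w.sum()) @ v

def slots_for(block_table, block_size, n):
    # Hypothetical helper: flat cache rows for logical tokens 0..n-1.
    t = np.arange(n)
    return np.asarray(block_table)[t // block_size] * block_size + t % block_size

rng = np.random.default_rng(1)
dim, block_size, seq_len = 8, 16, 35
block_table = [3, 0, 5]
k = rng.standard_normal((seq_len, dim))
v = rng.standard_normal((seq_len, dim))
q = np.abs(rng.standard_normal(dim))   # positive on purpose, see the note below

k_cache = np.zeros((6 * block_size, dim))
v_cache = np.zeros_like(k_cache)
valid = slots_for(block_table, block_size, seq_len)
k_cache[valid] = k
v_cache[valid] = v

# Poison the 13 unused slots of the tail block.
all_rows = slots_for(block_table, block_size, len(block_table) * block_size)
padding = all_rows[seq_len:]
k_cache[padding] = 1e6
v_cache[padding] = 1e6

ref = attend(q, k, v)
good = attend(q, k_cache[valid], v_cache[valid])        # gathers seq_len rows
bad = attend(q, k_cache[all_rows], v_cache[all_rows])   # the bug: gathers all 48

print(np.abs(good - ref).max())   # tiny: padding never read
print(np.abs(bad - ref).max())    # enormous: the poison detonated
```

The positive q is doing real work here. If the poisoned scores came out hugely negative instead, their softmax weights would underflow to zero and the bug would stay silent, so when you poison, check that your poison lands on the loud side of the softmax.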

Hitchhiker's notes

  • Your gather is a memcpy the GPU never does. k_cache[slots] materializes a contiguous copy of K — fine in numpy, ruinous on a GPU (it would double memory traffic for the engine's hottest loop). The real kernel follows the indirection inside the compute, loading each block tile straight from its physical address into registers/SRAM. Same semantics, zero copies — that difference is the entire reason kernel-level paging support (lab-04) has to exist at all, rather than a gather-then-dense-kernel two-step.
  • Why per-head loops? Clarity. Attention is independent per head; vectorizing over heads (einsum) is a one-liner you should try after green, and it changes nothing semantically. The real kernel parallelizes over (sequence, head) pairs — your loop nest, mapped to the GPU grid.
  • 1e-12, not 1e-2. Lab-04 tolerates 1e-2 because fp16 + a different operation order (online softmax) genuinely changes rounding. Here, same dtype (float64) and same order mean the comparison can be essentially exact. Calibrating tolerance to the reason for divergence — instead of slapping 1e-3 on everything — is a numerics habit that catches real bugs other suites wave through.
  • GQA fits in one index. Llama-style models have fewer KV heads than query heads; the cache shape grows a num_kv_heads dimension and several query heads share a KV head. The block table doesn't change at all — paging is orthogonal to head layout. (Try it: KV_HEADS = 2, map query head h to KV head h // 2. Ten lines.)
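The GQA note can be sketched in about the promised ten lines. This is an illustration under assumed shapes, not upstream code: the cache is [num_slots, kv_heads, dim], q is [q_heads, dim], and the name paged_attention_gqa is made up here. With 4 query heads and KV_HEADS = 2, query head h reads KV head h // 2.

```python
import numpy as np

def paged_attention_gqa(q, k_cache, v_cache, block_table, seq_len, block_size):
    # The gather is unchanged: paging never looks at the head dimension.
    t = np.arange(seq_len)
    slots = np.asarray(block_table)[t // block_size] * block_size + t % block_size
    k, v = k_cache[slots], v_cache[slots]        # [seq_len, kv_heads, dim]
    group = q.shape[0] // k.shape[1]             # query heads per KV head
    out = np.empty_like(q)
    for h in range(q.shape[0]):
        kv = h // group                          # the one extra index
        s = k[:, kv, :] @ q[h] / np.sqrt(q.shape[1])
        w = np.exp(s - s.max())
        out[h] = (w / w.sum()) @ v[:, kv, :]
    return out

rng = np.random.default_rng(2)
Q_HEADS, KV_HEADS, dim, block_size, seq_len = 4, 2, 8, 16, 35
block_table = [5, 0, 3]
k = rng.standard_normal((seq_len, KV_HEADS, dim))
v = rng.standard_normal((seq_len, KV_HEADS, dim))
q = rng.standard_normal((Q_HEADS, dim))

k_cache = np.zeros((6 * block_size, KV_HEADS, dim))
v_cache = np.zeros_like(k_cache)
t = np.arange(seq_len)
slots = np.asarray(block_table)[t // block_size] * block_size + t % block_size
k_cache[slots] = k
v_cache[slots] = v

out = paged_attention_gqa(q, k_cache, v_cache, block_table, seq_len, block_size)
print(out.shape)  # (4, 8)
```

A handy cross-check is to expand K and V to one copy per query head with np.repeat(..., axis=1) and run ordinary multi-head attention on that; the two results should agree to the same tight tolerance.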

Going further

  • Batch it: extend paged_attention to take a batch of queries with a ragged set of block tables and seq_lens — now you've implemented the actual decode-batch kernel interface (compare with paged_attention_v1.cu's argument list: it's your signature, plus strides).
  • Chunked-prefill write path: simulate prefilling a 40-token prompt in chunks of 16 using build_slot_mapping(start=16, ...) etc., then attend. You've just verified the Phase 3 invariant (chunking changes when, never what) at the memory level.
  • Measure the gather tax in numpy: time k_cache[slots] vs a contiguous slice of the same size for seq_len = 64k. The scatter-gather costs real bandwidth even on CPU — now reread lab-04's note on why GPUs fold it into the kernel.
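For the chunked-prefill exercise above, a starting point might look like the following. It is a sketch with invented values (block table [9, 4, 11], a 2D cache with no head dimension) that checks only the write path: writing the prompt in chunks of 16 should leave the cache byte-for-byte equal to writing it in one shot.

```python
import numpy as np

def build_slot_mapping(block_table, block_size, start, num_tokens):
    t = np.arange(start, start + num_tokens)
    return np.asarray(block_table)[t // block_size] * block_size + t % block_size

rng = np.random.default_rng(3)
block_table, block_size = [9, 4, 11], 16     # hypothetical layout
prompt_len, chunk, dim, num_blocks = 40, 16, 4, 12
k = rng.standard_normal((prompt_len, dim))

# Chunked write path: chunks start at 0, 16, 32 with 16, 16, 8 tokens.
chunked_cache = np.zeros((num_blocks * block_size, dim))
chunk_slots = []
for start in range(0, prompt_len, chunk):
    n = min(chunk, prompt_len - start)
    slots = build_slot_mapping(block_table, block_size, start, n)
    chunked_cache[slots] = k[start:start + n]
    chunk_slots.append(slots)

# One-shot write path for comparison.
oneshot_slots = build_slot_mapping(block_table, block_size, 0, prompt_len)
oneshot_cache = np.zeros_like(chunked_cache)
oneshot_cache[oneshot_slots] = k

same_map = np.array_equal(np.concatenate(chunk_slots), oneshot_slots)
same_cache = np.array_equal(chunked_cache, oneshot_cache)
print(same_map, same_cache)  # True True
```

From here, the remaining step of the exercise is to run your paged_attention over the chunk-written cache and compare against the one-shot result.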

References

  • upstream/vllm/v1/worker/gpu_model_runner.py — search slot_mapping: where both maps are built from scheduler output, every step.
  • upstream/csrc/cache_kernels.cu (reshape_and_cache): your write_kv, in CUDA.
  • upstream/csrc/attention/paged_attention_v1.cu — your paged_attention, with the performance engineering attached.
  • Kwon et al., PagedAttention (SOSP 2023), §4.3 — kernel-side gather design: https://arxiv.org/abs/2309.06180
  • Milakov & Gimelshein, Online normalizer calculation for softmax (2018) — what your max-subtraction becomes when the row streams in blocks: https://arxiv.org/abs/1805.02867