Lab 04-03 — Causal Prefill Attention over a Paged Cache [CPU-OK]

Lab-01 gave you the decode kernel shape: one query, N keys. But every token in that cache got there through the other shape — prefill: M queries at once (a prompt, or a chunk of one), each allowed to see only its own past. In this lab you build the prefill shape on top of your lab-01 recurrence, with the two ingredients that make it interesting: the causal mask (query i attends to positions 0..start_pos+i, nothing later) and the start_pos offset that makes chunked prefill possible at the kernel level. The payoff test proves, in attention outputs rather than scheduler bookkeeping, the invariant Phase 3 lab-02 promised you: prefilling in chunks computes exactly what one-shot prefill computes.


Why this lab exists

Every attention backend in vLLM ships (at least) two code paths, and PRs routinely touch one and break the other. If you've only ever written the decode path, prefill kernels read as a wall of index arithmetic: why does the mask depend on a start offset? why does the kernel receive query_start_loc arrays? what exactly must hold for a chunk computed today to splice seamlessly with a chunk computed three steps ago? This lab makes you derive all three answers, because you need them to make four tests pass.

It also closes a loop the course opened two phases ago. Phase 3 proved "chunking changes when, never what" behaviorally — same tokens out of the engine. But that proof leaned on the kernel doing its part: a query at absolute position 7, computed in a chunk that starts at position 5, must attend over tokens 0–7 exactly as it would have in a one-shot prefill. That's a property of the attention math plus the cache, and here you verify it at the layer where it actually lives — test_chunked_equals_one_shot is Phase 3 lab-02 restated in linear algebra.

Background: one mechanism, two shapes

The contract when a prefill chunk runs (this ordering is upstream's, and Phase 2 lab-06's):

  1. The runner writes the chunk's K/V first: slot_mapping scatters rows for positions start_pos..start_pos+M−1 into the paged cache. So by the time attention runs, the cache holds tokens 0..start_pos+M−1, which is everything each query may legally see.
  2. The kernel then computes, for each query row i (absolute position start_pos+i), attention over the causal prefix [0, start_pos+i] — gathered through the block table, streamed with online softmax, exactly your lab-01 loop with a per-query length.
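The write-then-read ordering can be sketched with a toy paged cache. This is a minimal illustration and not upstream's layout. Its assumptions are a block size of 4, a flat Python list standing in for the cache, and rows that hold only their own position so gathers are easy to check. The helper names (slot_for, write_chunk, gather_prefix) are hypothetical.

```python
BLOCK_SIZE = 4  # assumption: toy block size

def slot_for(pos, block_table, block_size=BLOCK_SIZE):
    """Map an absolute token position to a flat slot index in the paged cache."""
    physical_block = block_table[pos // block_size]   # logical block -> physical block
    return physical_block * block_size + pos % block_size

def write_chunk(cache, block_table, start_pos, rows):
    """Step 1: scatter the chunk's rows for positions start_pos..start_pos+M-1."""
    slot_mapping = [slot_for(start_pos + i, block_table) for i in range(len(rows))]
    for slot, row in zip(slot_mapping, rows):
        cache[slot] = row
    return slot_mapping

def gather_prefix(cache, block_table, pos):
    """Read side of step 2: rows for positions 0..pos inclusive, via the block table."""
    return [cache[slot_for(p, block_table)] for p in range(pos + 1)]

cache = [None] * (8 * BLOCK_SIZE)   # 8 physical blocks
block_table = [5, 2, 7]             # deliberately scrambled physical blocks

# Each row holds its own position, so a correct gather returns 0, 1, 2, ...
write_chunk(cache, block_table, 0, [[float(p)] for p in range(5)])             # chunk 1: 0..4
slots = write_chunk(cache, block_table, 5, [[float(p)] for p in range(5, 9)])  # chunk 2: 5..8

# After step 1, every query of chunk 2 (positions 5..8) finds its whole causal prefix.
for i in range(4):
    prefix = gather_prefix(cache, block_table, 5 + i)
    assert [row[0] for row in prefix] == [float(p) for p in range(5 + i + 1)]
print(slots)  # [9, 10, 11, 28]
```

Position 8 lands in a different physical block from 5..7. This is why the kernel always goes through the block table and never assumes contiguity.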

Note what the causal mask is not: a -inf matrix you materialize. In a streaming kernel the mask degenerates into a loop bound, because you simply stop reading keys at the query's own position. (Real kernels that process key tiles need the mask only for the one diagonal tile per query tile, where queries and keys overlap. Every earlier tile is fully visible, and every later tile is skipped entirely. "The mask is mostly a loop bound" is why causal attention costs roughly half as much as bidirectional attention, rather than the same cost plus masking overhead.)
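The tile argument can be checked by counting. The sketch below classifies every (query tile, key tile) pair for a one-shot prefill. It assumes start_pos = 0 and equal, aligned query and key tiles. Both are simplifications: with an offset or unequal tile sizes, the diagonal band can straddle more than one key tile. The function name classify_tiles is hypothetical.

```python
def classify_tiles(num_tokens, tile):
    """Count key tiles each query tile reads fully, masks, or skips (start_pos = 0)."""
    counts = {"full": 0, "diagonal": 0, "skipped": 0}
    for q0 in range(0, num_tokens, tile):
        q_last = min(q0 + tile, num_tokens) - 1        # last query position in this tile
        for k0 in range(0, num_tokens, tile):
            k_last = min(k0 + tile, num_tokens) - 1    # last key position in this tile
            if k_last <= q0:
                counts["full"] += 1      # every key visible to every query: plain loop, no mask
            elif k0 > q_last:
                counts["skipped"] += 1   # every key is in every query's future: never read
            else:
                counts["diagonal"] += 1  # straddles the causal frontier: the only masked tile
    return counts

counts = classify_tiles(num_tokens=16, tile=4)
print(counts)                                   # {'full': 6, 'diagonal': 4, 'skipped': 6}
visible_pairs = sum(p + 1 for p in range(16))   # query p sees keys 0..p
print(visible_pairs, 16 * 16)                   # 136 256
```

Only 4 of 16 tile pairs need a mask, and 136 of 256 (query, key) pairs are visible. That is the "roughly half" of bidirectional.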

And start_pos is the entire kernel-side story of chunked prefill: a chunk is just a prefill whose queries don't start at zero. No special "resume" state — the cache is the state, which is the same insight (the counter/cache is the resume mechanism) you've now met in the scheduler (Phase 3), in preemption recovery (Phase 3 lab-04), and here in the kernel.
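Kernel-side, "a chunk is a prefill whose queries don't start at zero" reduces to a list of per-query prefix lengths. In this small sketch (hypothetical helper names), any split of a prompt into chunks yields the same lengths as one-shot prefill, as long as start_pos advances by the tokens already cached.

```python
def causal_lengths(start_pos, num_queries):
    """Number of keys query row i may read: positions 0..start_pos+i inclusive."""
    return [start_pos + i + 1 for i in range(num_queries)]

def chunked_lengths(chunk_sizes):
    """Run a prompt as consecutive chunks; start_pos is just 'tokens already cached'."""
    lengths, start_pos = [], 0
    for m in chunk_sizes:
        lengths += causal_lengths(start_pos, m)
        start_pos += m   # the cache now holds 0..start_pos-1: that is the only resume state
    return lengths

one_shot = causal_lengths(0, 12)
assert chunked_lengths([12]) == one_shot
assert chunked_lengths([5, 7]) == one_shot
assert chunked_lengths([1] * 12) == one_shot   # twelve single-query "chunks" are decode steps
```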

Files

  • starter.py: dense_causal_attention (the reference) and paged_causal_prefill_attention (the paged, online-softmax version). Your work.
  • solution.py — reference; note how it reuses the lab-01 recurrence as an inner function — the decode kernel is literally a sub-case.
  • test_lab.py — full prefill, mid-sequence chunk, chunked ≡ one-shot, and the poisoned-future causality test.

Run

LAB_IMPL=starter pytest phase-04-attention-backends/labs/lab-03-causal-prefill-attention -q
pytest phase-04-attention-backends/labs/lab-03-causal-prefill-attention -q   # reference

What to implement

Two functions. The dense reference is a per-query loop: slice the causal prefix, score, softmax, blend. The paged version wraps your lab-01 recurrence: for query i, run the block-streaming loop with seq_len = start_pos + i + 1. That +1 is load-bearing, because a token does attend to itself (its K/V are in the cache before its attention runs; see the contract above). Get it off by one and test_full_prefill_from_position_zero fails on the very first row, where the prefix is exactly one token.
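Here is one possible shape for the two functions, in pure Python lists. This is a sketch under assumptions, not the lab's solution. The signatures are guesses, since starter.py's real ones are not shown here. decode_attention is a hypothetical stand-in for your lab-01 recurrence. The usage sets up test_full_prefill_from_position_zero's case: 13 tokens in 4 blocks of 4, with a scrambled block table.

```python
import math
import random

def _dot(a, b):
    return sum(x * y for x, y in zip(a, b))

def dense_causal_attention(q, k, v, start_pos):
    """Reference. q: M query rows; k, v: rows for absolute positions 0..start_pos+M-1."""
    scale = 1.0 / math.sqrt(len(q[0]))
    out = []
    for i, qi in enumerate(q):
        n = start_pos + i + 1                       # the +1: a token attends to itself
        scores = [_dot(qi, k[j]) * scale for j in range(n)]
        m = max(scores)
        w = [math.exp(s - m) for s in scores]
        z = sum(w)
        out.append([sum(w[j] * v[j][d] for j in range(n)) / z for d in range(len(v[0]))])
    return out

def decode_attention(qi, k_cache, v_cache, block_table, block_size, seq_len):
    """Lab-01-style recurrence: one query, keys streamed block by block, online softmax."""
    scale = 1.0 / math.sqrt(len(qi))
    m, l = -math.inf, 0.0                           # running max and running denominator
    acc = [0.0] * len(v_cache[0])                   # running unnormalized output
    num_blocks = (seq_len + block_size - 1) // block_size
    for logical in range(num_blocks):
        base = block_table[logical] * block_size
        valid = min(block_size, seq_len - logical * block_size)  # partial last block
        for off in range(valid):
            s = _dot(qi, k_cache[base + off]) * scale
            m_new = max(m, s)
            alpha = math.exp(m - m_new)             # rescales everything accumulated so far
            p = math.exp(s - m_new)
            l = l * alpha + p
            acc = [a * alpha + p * x for a, x in zip(acc, v_cache[base + off])]
            m = m_new
    return [a / l for a in acc]

def paged_causal_prefill_attention(q, k_cache, v_cache, block_table, block_size, start_pos):
    """Prefill = the decode recurrence with a per-query length start_pos + i + 1."""
    return [decode_attention(qi, k_cache, v_cache, block_table, block_size, start_pos + i + 1)
            for i, qi in enumerate(q)]

random.seed(0)
D, BS, N = 8, 4, 13

def rand_rows(n):
    return [[random.uniform(-1, 1) for _ in range(D)] for _ in range(n)]

K, V, Q = rand_rows(N), rand_rows(N), rand_rows(N)
block_table = [3, 0, 2, 1]
k_cache = [[0.0] * D for _ in range(4 * BS)]
v_cache = [[0.0] * D for _ in range(4 * BS)]
for pos in range(N):                                # the write path: a slot_mapping scatter
    slot = block_table[pos // BS] * BS + pos % BS
    k_cache[slot], v_cache[slot] = K[pos], V[pos]

dense = dense_causal_attention(Q, K, V, start_pos=0)
paged = paged_causal_prefill_attention(Q, k_cache, v_cache, block_table, BS, start_pos=0)
max_err = max(abs(a - b) for ra, rb in zip(dense, paged) for a, b in zip(ra, rb))
print(max_err < 1e-9)  # True
```

The paged function contains no new math. It only chooses seq_len per row, which is the sense in which decode is a sub-case of prefill.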

What the tests prove

| Test | What it pins |
|---|---|
| test_full_prefill_from_position_zero | The base case (start_pos=0), with a partial last block: 13 tokens in 4 blocks |
| test_mid_sequence_chunk | The chunked case: queries for positions 5–8 of a 9-token cache attend over exactly the right prefixes despite starting mid-block |
| test_chunked_equals_one_shot | The phase-bridging invariant: 12 positions as one chunk ≡ as 5 + 7, with every output row identical to within 1e-9. Phase 3 lab-02's theorem, at the layer where it's actually enforced |
| test_causality_future_tokens_are_invisible | A 1e3 "loud future" in the last token's K/V changes only the last query's row; rows 0–6 are provably deaf to it. A non-causal bug here doesn't crash. It leaks the future into every token, producing inputs unlike anything the model saw in training, and outputs degrade mysteriously. This test makes the leak deafening instead |
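The essence of test_chunked_equals_one_shot fits in a few lines if the cache is simplified to a flat, growing list. This is an assumption made for brevity: the real lab cache is paged. The key detail is that chunk 1 runs before tokens 5..11 exist, yet its rows must match the one-shot rows computed with the full cache present. The helper names are hypothetical.

```python
import math
import random

def attend_prefix(qi, k_rows, v_rows, n):
    """Softmax attention of one query over cache rows 0..n-1."""
    scale = 1.0 / math.sqrt(len(qi))
    s = [sum(a * b for a, b in zip(qi, k_rows[j])) * scale for j in range(n)]
    m = max(s)
    w = [math.exp(x - m) for x in s]
    z = sum(w)
    return [sum(w[j] * v_rows[j][d] for j in range(n)) / z for d in range(len(v_rows[0]))]

def prefill_chunk(cache_k, cache_v, q, k_new, v_new, start_pos):
    """Write-then-attend for one chunk; the cache, not saved kernel state, carries the past."""
    assert len(cache_k) == start_pos, "cache must hold exactly positions 0..start_pos-1"
    cache_k.extend(k_new)                           # step 1: write the chunk's K/V
    cache_v.extend(v_new)
    return [attend_prefix(qi, cache_k, cache_v, start_pos + i + 1)   # step 2: causal read
            for i, qi in enumerate(q)]

random.seed(1)
D, T = 4, 12

def rand_rows(n):
    return [[random.uniform(-1, 1) for _ in range(D)] for _ in range(n)]

Q, K, V = rand_rows(T), rand_rows(T), rand_rows(T)

one_shot = prefill_chunk([], [], Q, K, V, start_pos=0)

ck, cv = [], []                                     # chunk 1 runs before tokens 5..11 exist
chunked = prefill_chunk(ck, cv, Q[:5], K[:5], V[:5], start_pos=0)
chunked += prefill_chunk(ck, cv, Q[5:], K[5:], V[5:], start_pos=5)

max_err = max(abs(a - b) for r1, r2 in zip(one_shot, chunked) for a, b in zip(r1, r2))
print(max_err <= 1e-9)  # True
```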

The poison technique is Phase 2 lab-06's trick pointed at a different boundary: there it guarded seq_len masking, here it guards the causal frontier. Same principle — make the forbidden region catastrophic to touch, then prove nothing touched it.
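Here is the poison technique in miniature, written as a sketch against a dense attention with a causal switch (causal=False plays the bug). It deviates from the lab's test in one respect: only the last token's V row is made loud. The reason is that a loud K can drive some queries' weight on that token to zero, which would hide a leak in those rows.

```python
import math
import random

def attention(q, k, v, start_pos, causal=True):
    """Dense attention; causal=False is the bug this test exists to catch."""
    scale = 1.0 / math.sqrt(len(q[0]))
    out = []
    for i, qi in enumerate(q):
        n = start_pos + i + 1 if causal else len(k)
        s = [sum(a * b for a, b in zip(qi, k[j])) * scale for j in range(n)]
        m = max(s)
        w = [math.exp(x - m) for x in s]
        z = sum(w)
        out.append([sum(w[j] * v[j][d] for j in range(n)) / z for d in range(len(v[0]))])
    return out

random.seed(2)
D, T = 4, 8

def rand_rows(n):
    return [[random.uniform(-1, 1) for _ in range(D)] for _ in range(n)]

Q, K, V = rand_rows(T), rand_rows(T), rand_rows(T)
V_loud = [row[:] for row in V]
V_loud[-1] = [1e3] * D                              # the "loud future": last token's value

clean = attention(Q, K, V, start_pos=0)
loud = attention(Q, K, V_loud, start_pos=0)
assert loud[:-1] == clean[:-1]                      # rows 0..6: bit-identical
assert loud[-1] != clean[-1]                        # only the last query hears it

leaky_clean = attention(Q, K, V, 0, causal=False)
leaky_loud = attention(Q, K, V_loud, 0, causal=False)
leaks = [max(abs(a - b) for a, b in zip(r1, r2)) > 1e-3
         for r1, r2 in zip(leaky_clean, leaky_loud)]
print(leaks)  # [True, True, True, True, True, True, True, True]
```

With the correct bound, the earlier rows never read the poisoned row, so equality is exact rather than approximate. With the bug, every row shifts by at least several units.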

Hitchhiker's notes

  • Why is prefill compute-bound while decode is bandwidth-bound when it's the same math? Count the reuse: in prefill, each gathered K/V block is dotted against many query rows (every query whose prefix covers it); in decode, against exactly one. That's the arithmetic-intensity difference of Phase 0 lab-04, visible in this very loop nest — and it's why real prefill kernels tile over both queries and keys (FlashAttention's 2D blocking) while decode kernels tile only keys (lab-01/lab-04 shapes).
  • query_start_loc and friends: real batches contain many requests' chunks concatenated; upstream passes per-request offsets (query_start_loc, seq_lens) so one kernel launch handles a ragged batch. Your start_pos is the single-request version of that metadata. Find the production form in upstream/vllm/v1/attention/backends/flash_attn.py (FlashAttentionMetadata).
  • The solution's per-query inner loop is honest but quadratic in reads — it re-gathers shared prefix blocks once per query. Real kernels invert the nest (outer loop over key tiles, inner over queries, with the diagonal-tile mask) precisely to read each block once. Try the inversion as an exercise; the recurrence per query row is unchanged, which is the point — the math doesn't care which loop is outside.
  • Sliding-window attention (Mistral et al.) is one more loop-bound tweak: the prefix becomes [max(0, pos−W+1), pos]. If you can place the causal bound, you can place the window bound — and you now know why window support is a per-backend feature flag rather than a model-side trick.
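To make the query_start_loc note concrete, here is a sketch of how a ragged batch reduces to per-request start_pos values. It assumes the layout as we read upstream's convention, which you should check against FlashAttentionMetadata. Under that layout, query_start_loc holds cumulative query offsets (length num_requests + 1), and seq_lens counts each request's cached tokens plus its new ones, so start_pos = seq_len − query_len. unpack_ragged is a hypothetical helper.

```python
def unpack_ragged(query_start_loc, seq_lens):
    """Per request: (first row in the packed query tensor, num queries, start_pos)."""
    meta = []
    for r, seq_len in enumerate(seq_lens):
        q_begin, q_end = query_start_loc[r], query_start_loc[r + 1]
        q_len = q_end - q_begin
        meta.append((q_begin, q_len, seq_len - q_len))   # assumes seq_lens include new tokens
    return meta

# One launch, three requests: a fresh 5-token prompt, a 7-token chunk resuming at
# position 5, and a single decode query at position 20.
query_start_loc = [0, 5, 12, 13]
seq_lens = [5, 12, 21]
meta = unpack_ragged(query_start_loc, seq_lens)
print(meta)  # [(0, 5, 0), (5, 7, 5), (12, 1, 20)]

# Every packed query row still gets the single-request rule: prefix length start_pos + i + 1.
per_row_lengths = [start + i + 1 for (_, q_len, start) in meta for i in range(q_len)]
```

The decode request is just the q_len = 1 case, so prefill chunks and decode steps can share a launch.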

Going further

  • Vectorize the dense reference into a single masked matmul (scores + np.triu(np.full((M, N), -np.inf), k=1 + start_pos), with N = start_pos + M keys) and check it against your loop. Then notice that the materialized (M, N) score matrix is exactly what FlashAttention exists to avoid.
  • Invert the loop nest (key-tiles outer) as sketched above and re-run the suite — same four green tests, different memory behavior. You've reproduced the actual structure of flash_attn's prefill kernel.
  • Implement sliding-window (window parameter, prefix start max(0, pos−W+1)) and write the poison test for the left boundary: a loud token just outside the window must be inaudible.
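For the sliding-window exercise, here is one possible starting point. It assumes a dense, non-paged cache for brevity, uses a hypothetical function name, and poisons only a V row, as in the causality sketch. With W = 4, token 5 is inside the window of positions 5..8 and just outside the window of position 9.

```python
import math
import random

def window_attention(q, k, v, start_pos, window):
    """Causal sliding-window attention: query at pos sees [max(0, pos-window+1), pos]."""
    scale = 1.0 / math.sqrt(len(q[0]))
    out = []
    for i, qi in enumerate(q):
        pos = start_pos + i
        lo = max(0, pos - window + 1)               # the only change vs plain causal
        idx = range(lo, pos + 1)
        s = [sum(a * b for a, b in zip(qi, k[j])) * scale for j in idx]
        m = max(s)
        w = [math.exp(x - m) for x in s]
        z = sum(w)
        out.append([sum(wj * v[j][d] for wj, j in zip(w, idx)) / z for d in range(len(v[0]))])
    return out

random.seed(3)
D, T, W = 4, 10, 4

def rand_rows(n):
    return [[random.uniform(-1, 1) for _ in range(D)] for _ in range(n)]

Q, K, V = rand_rows(T), rand_rows(T), rand_rows(T)
V_loud = [row[:] for row in V]
V_loud[5] = [1e3] * D                               # loud token at position 5

clean = window_attention(Q, K, V, 0, W)
loud = window_attention(Q, K, V_loud, 0, W)
changed = [c != l for c, l in zip(clean, loud)]
print(changed)
# [False, False, False, False, False, True, True, True, True, False]
```

Rows 0..4 are protected by the causal bound, and row 9 by the window's left bound. Both are plain loop bounds, so neither row ever reads the loud token.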

References

  • Dao et al., FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness (2022) — the 2D-tiled prefill kernel this lab is the skeleton of: https://arxiv.org/abs/2205.14135
  • Dao, FlashAttention-2 (2023) — the loop-nest inversion and work partitioning: https://arxiv.org/abs/2307.08691
  • upstream/vllm/v1/attention/backends/flash_attn.py, FlashAttentionMetadata: query_start_loc, seq_lens, and the cascade of shapes one launch handles.
  • Phase 3 lab-02 — the engine-level statement of test_chunked_equals_one_shot.
  • Phase 2 lab-06 — the write path (slot_mapping) that fills the cache this lab reads.