Lab 04-03 — Causal Prefill Attention over a Paged Cache [CPU-OK]
Lab-01 gave you the decode kernel shape: one query, N keys. But every token in that
cache got there through the other shape — prefill: M queries at once (a prompt, or a
chunk of one), each allowed to see only its own past. In this lab you build the prefill
shape on top of your lab-01 recurrence, with the two ingredients that make it interesting:
the causal mask (query i attends to positions 0..start_pos+i, nothing later) and
the start_pos offset that makes chunked prefill possible at the kernel level. The
payoff test proves, in attention outputs rather than scheduler bookkeeping, the invariant
Phase 3 lab-02 promised you: prefilling in chunks computes exactly what one-shot prefill
computes.
Contents
- Why this lab exists
- Background: one mechanism, two shapes
- Files
- Run
- What to implement
- What the tests prove
- Hitchhiker's notes
- Going further
- References
Why this lab exists
Every attention backend in vLLM ships (at least) two code paths, and PRs routinely touch
one and break the other. If you've only ever written the decode path, prefill kernels read
as a wall of index arithmetic: why does the mask depend on a start offset? why does the
kernel receive query_start_loc arrays? what exactly must hold for a chunk computed today
to splice seamlessly with a chunk computed three steps ago? This lab makes you derive all
three answers, because you need them to make four tests pass.
It also closes a loop the course opened two phases ago. Phase 3 proved "chunking changes
when, never what" behaviorally — same tokens out of the engine. But that proof leaned on
the kernel doing its part: a query at absolute position 7, computed in a chunk that starts
at position 5, must attend over tokens 0–7 exactly as it would have in a one-shot
prefill. That's a property of the attention math plus the cache, and here you verify it at
the layer where it actually lives — test_chunked_equals_one_shot is Phase 3 lab-02
restated in linear algebra.
Background: one mechanism, two shapes
The contract when a prefill chunk runs (this ordering is upstream's, and Phase 2 lab-06's):
- The runner writes the chunk's K/V first: slot_mapping scatters rows for positions start_pos..start_pos+M−1 into the paged cache. So by the time attention runs, the cache holds tokens 0..start_pos+M−1, which is everything each query may legally see.
- The kernel then computes, for each query row i (absolute position start_pos+i), attention over the causal prefix [0, start_pos+i], gathered through the block table and streamed with online softmax. This is exactly your lab-01 loop with a per-query length.
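The write half of that contract fits in a few lines of NumPy. This is a minimal sketch, not upstream's code: chunk_slot_mapping and write_chunk are hypothetical helpers, the caches are assumed to be flat (num_slots, head_dim) arrays, and the slot formula physical_block * block_size + offset_in_block is the conventional paged layout, assumed here rather than taken from the lab.

```python
import numpy as np


def chunk_slot_mapping(block_table, start_pos, num_tokens, block_size):
    """Hypothetical helper: flat cache slot for each position in
    [start_pos, start_pos + num_tokens), assuming the paged layout
    slot = physical_block * block_size + offset_in_block."""
    positions = np.arange(start_pos, start_pos + num_tokens)
    physical = np.asarray(block_table)[positions // block_size]
    return physical * block_size + positions % block_size


def write_chunk(k_cache, v_cache, k_new, v_new, slots):
    """Step 1 of the contract: scatter the chunk's K/V rows into the flat
    (num_slots, head_dim) caches *before* attention runs."""
    k_cache[slots] = k_new
    v_cache[slots] = v_new


# A 4-token chunk at positions 5..8, block_size 4, logical blocks -> [7, 2, 5].
block_size, block_table = 4, [7, 2, 5]
slots = chunk_slot_mapping(block_table, start_pos=5, num_tokens=4, block_size=block_size)
print(slots.tolist())  # [9, 10, 11, 20]

k_cache = np.zeros((8 * block_size, 2))  # 8 physical blocks, head_dim 2
v_cache = np.zeros_like(k_cache)
k_new = np.arange(8.0).reshape(4, 2)
write_chunk(k_cache, v_cache, k_new, -k_new, slots)
```

Positions 5–7 land in logical block 1 (physical block 2) and position 8 starts logical block 2 (physical block 5), which is why a chunk that begins mid-block needs no special handling on the write side.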
Note what the causal mask is not: a -inf matrix you materialize. In a streaming kernel
the mask degenerates into a loop bound: you simply stop reading keys at the query's
own position. (Real kernels that process key tiles need the mask only on the tile or tiles
straddling the diagonal, where query and key positions overlap; every earlier tile is fully
visible and every later tile is skipped entirely. "The mask is mostly a loop bound" is why a
causal prefill from position zero costs roughly half as much as bidirectional attention over
the same tokens, rather than the same cost plus masking overhead.)
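The tile picture can be made concrete by sorting key tiles, for each query tile, into fully visible, diagonal (the only place a mask is applied), and skipped. This sketch assumes one square tile size for both queries and keys and counts tile pairs rather than FLOPs; both function names are hypothetical. Note the second case: when start_pos is not tile-aligned, a single query tile can straddle two key tiles.

```python
def classify_key_tiles(q_lo, q_hi, num_keys, tile):
    """For a query tile covering absolute positions [q_lo, q_hi), sort key
    tiles by how the causal mask treats them."""
    visible, diagonal, skipped = [], [], []
    for k_lo in range(0, num_keys, tile):
        k_hi = min(k_lo + tile, num_keys)  # exclusive end of this key tile
        if k_hi - 1 <= q_lo:  # newest key <= oldest query: no mask needed
            visible.append(k_lo)
        elif k_lo > q_hi - 1:  # oldest key > newest query: never read
            skipped.append(k_lo)
        else:  # straddles the diagonal: the only tiles that need the mask
            diagonal.append(k_lo)
    return visible, diagonal, skipped


def tile_census(num_queries, start_pos, tile):
    """Count (query tile, key tile) pairs for a chunk starting at start_pos."""
    num_keys = start_pos + num_queries
    counts = {"visible": 0, "diagonal": 0, "skipped": 0}
    for i in range(0, num_queries, tile):
        q_lo = start_pos + i
        q_hi = start_pos + min(i + tile, num_queries)
        vis, diag, skip = classify_key_tiles(q_lo, q_hi, num_keys, tile)
        counts["visible"] += len(vis)
        counts["diagonal"] += len(diag)
        counts["skipped"] += len(skip)
    return counts


full = tile_census(num_queries=1024, start_pos=0, tile=64)
print(full)  # {'visible': 120, 'diagonal': 16, 'skipped': 120}

mid = tile_census(num_queries=4, start_pos=5, tile=4)
print(mid)  # {'visible': 1, 'diagonal': 2, 'skipped': 0}
```

For the one-shot case, 136 of 256 tile pairs are touched (about 53%); with tile size 1 the count becomes the pair-level M(M+1)/2 out of M², which is where "roughly half" comes from.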
And start_pos is the entire kernel-side story of chunked prefill: a chunk is just a
prefill whose queries don't start at zero. No special "resume" state — the cache is the
state, which is the same insight (the counter/cache is the resume mechanism) you've now
met in the scheduler (Phase 3), in preemption recovery (Phase 3 lab-04), and here in the
kernel.
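A dense NumPy sketch of that claim: one function serves both chunks, and the only things that change between calls are start_pos and how much of the cache exists yet. causal_prefill is a hypothetical name, and the cache here is contiguous rather than paged to keep the point isolated.

```python
import numpy as np


def causal_prefill(q, k_cache, v_cache, start_pos):
    """Dense causal attention for queries at absolute positions
    start_pos..start_pos+M-1 over a contiguous cache of K/V rows."""
    scale = 1.0 / np.sqrt(q.shape[-1])
    out = np.empty_like(q)
    for i in range(q.shape[0]):
        n = start_pos + i + 1  # causal prefix length, including self
        s = (k_cache[:n] @ q[i]) * scale  # (n,) scores
        p = np.exp(s - s.max())  # numerically stable softmax numerator
        out[i] = (p / p.sum()) @ v_cache[:n]
    return out


rng = np.random.default_rng(0)
T, d = 12, 8
q, k, v = (rng.standard_normal((T, d)) for _ in range(3))

one_shot = causal_prefill(q, k, v, start_pos=0)

# Chunk 1 runs when the cache holds only tokens 0..4; chunk 2 when it holds 0..11.
chunk1 = causal_prefill(q[:5], k[:5], v[:5], start_pos=0)
chunk2 = causal_prefill(q[5:], k[:12], v[:12], start_pos=5)
chunked = np.concatenate([chunk1, chunk2])

print(np.abs(chunked - one_shot).max())
```

No state is carried from chunk1 to chunk2: everything chunk2 needs about tokens 0..4 is already in the cache.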
Files
- starter.py: dense_causal_attention (the reference) and paged_causal_prefill_attention (the paged, online-softmax version). Your work.
- solution.py: reference; note how it reuses the lab-01 recurrence as an inner function. The decode kernel is literally a sub-case.
- test_lab.py: full prefill, mid-sequence chunk, chunked ≡ one-shot, and the poisoned-future causality test.
Run
LAB_IMPL=starter pytest phase-04-attention-backends/labs/lab-03-causal-prefill-attention -q
pytest phase-04-attention-backends/labs/lab-03-causal-prefill-attention -q # reference
What to implement
Two functions. The dense reference is a per-query loop: slice the causal prefix, score,
softmax, blend. The paged version wraps your lab-01 recurrence: for query i, run the
block-streaming loop with seq_len = start_pos + i + 1. That +1 is load-bearing: a
token does attend to itself (its K/V are in the cache before its attention runs; see
the contract above). Drop it and test_full_prefill_from_position_zero fails on the
very first row, whose causal prefix is exactly one token: without the +1 the prefix is
empty and the softmax has nothing to normalize.
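A minimal sketch of both functions, under stated assumptions: a single head, K/V caches shaped (num_blocks, block_size, head_dim), the block table as a plain list, and signatures invented for illustration (the starter's may differ; this is not the lab's solution.py). The paged version is nothing more than a lab-01-style online-softmax row loop called with seq_len = start_pos + i + 1.

```python
import numpy as np


def dense_causal_attention(q, k, v, start_pos):
    """Reference: per-query loop over the contiguous causal prefix."""
    scale = 1.0 / np.sqrt(q.shape[-1])
    out = np.empty_like(q)
    for i in range(q.shape[0]):
        n = start_pos + i + 1
        s = (k[:n] @ q[i]) * scale
        p = np.exp(s - s.max())
        out[i] = (p / p.sum()) @ v[:n]
    return out


def _stream_one_row(q_row, k_cache, v_cache, block_table, seq_len, block_size, scale):
    """Lab-01-style online softmax over the first seq_len cached tokens."""
    m, l = -np.inf, 0.0  # running max and running denominator
    acc = np.zeros(v_cache.shape[-1])  # running softmax-weighted sum of V
    for b in range(-(-seq_len // block_size)):  # ceil(seq_len / block_size)
        valid = min(block_size, seq_len - b * block_size)  # partial last block
        phys = block_table[b]
        s = (k_cache[phys, :valid] @ q_row) * scale
        m_new = max(m, s.max())
        corr = np.exp(m - m_new)  # rescales old state; 0 on the first block
        p = np.exp(s - m_new)
        l = l * corr + p.sum()
        acc = acc * corr + p @ v_cache[phys, :valid]
        m = m_new
    return acc / l


def paged_causal_prefill_attention(q, k_cache, v_cache, block_table, start_pos, block_size):
    """Row i sits at absolute position start_pos + i, so it streams
    seq_len = start_pos + i + 1 tokens: its own K/V are already cached."""
    scale = 1.0 / np.sqrt(q.shape[-1])
    return np.stack([
        _stream_one_row(q[i], k_cache, v_cache, block_table,
                        start_pos + i + 1, block_size, scale)
        for i in range(q.shape[0])
    ])


# 13 tokens, block_size 4 -> 4 logical blocks scattered over physical blocks.
rng = np.random.default_rng(1)
T, d, bs = 13, 8, 4
block_table = [3, 0, 2, 1]
q, k, v = (rng.standard_normal((T, d)) for _ in range(3))
k_cache = np.zeros((4, bs, d))
v_cache = np.zeros((4, bs, d))
for pos in range(T):
    k_cache[block_table[pos // bs], pos % bs] = k[pos]
    v_cache[block_table[pos // bs], pos % bs] = v[pos]

full = paged_causal_prefill_attention(q, k_cache, v_cache, block_table, 0, bs)
mid = paged_causal_prefill_attention(q[5:9], k_cache, v_cache, block_table, 5, bs)
```

Row 0 of full equals v[0]: a one-token prefix gives that token softmax weight 1. The mid call reads a 13-token cache but, through its loop bounds, only ever touches each query's own prefix.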
What the tests prove
| Test | What it pins |
|---|---|
| test_full_prefill_from_position_zero | The base case (start_pos=0), with a partial last block: 13 tokens in 4 blocks |
| test_mid_sequence_chunk | The chunked case: queries for positions 5–8 of a 9-token cache attend over exactly the right prefixes despite starting mid-block |
| test_chunked_equals_one_shot | The phase-bridging invariant: 12 positions as one chunk ≡ as 5 + 7, with every output row matching to within 1e-9. Phase 3 lab-02's theorem, at the layer where it's actually enforced |
| test_causality_future_tokens_are_invisible | A 1e3 "loud future" in the last token's K/V changes only the last query's row; rows 0–6 are provably deaf to it. A non-causal bug here doesn't crash: it leaks the future into every earlier row, feeding the model inputs unlike anything it saw in training, and outputs degrade mysteriously. This test makes the leak deafening instead |
The poison technique is Phase 2 lab-06's trick pointed at a different boundary: there it
guarded seq_len masking, here it guards the causal frontier. Same principle — make the
forbidden region catastrophic to touch, then prove nothing touched it.
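The poison technique fits in a few lines. In this sketch (the lab's actual test may build its poison differently), the last token's key is aligned with its own query and scaled by 1e3, so its score dominates any row that reads it, and its value is set to 1e3. Rows 0–6 never read index 7, so they come out bit-identical; the last row is dragged to roughly 1e3.

```python
import numpy as np


def causal_attention(q, k, v):
    """Dense causal attention from position 0 (row i sees keys 0..i)."""
    scale = 1.0 / np.sqrt(q.shape[-1])
    out = np.empty_like(q)
    for i in range(q.shape[0]):
        s = (k[: i + 1] @ q[i]) * scale
        p = np.exp(s - s.max())
        out[i] = (p / p.sum()) @ v[: i + 1]
    return out


rng = np.random.default_rng(2)
T, d = 8, 16
q, k, v = (rng.standard_normal((T, d)) for _ in range(3))
clean = causal_attention(q, k, v)

# Poison the last token: a key aligned with its own query (huge score) and a
# 1e3-valued V, so any row that reads it is dragged toward 1e3.
k_bad, v_bad = k.copy(), v.copy()
k_bad[-1] = 1e3 * q[-1]
v_bad[-1] = 1e3
poisoned = causal_attention(q, k_bad, v_bad)

print(np.array_equal(clean[:-1], poisoned[:-1]))  # True: rows 0-6 untouched
print(np.abs(poisoned[-1] - clean[-1]).max())  # large: the last row hears it
```

Requiring exact equality (not allclose) for the earlier rows is the point: a causal kernel never reads the poisoned slot, so there is no rounding to forgive.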
Hitchhiker's notes
- Why is prefill compute-bound while decode is bandwidth-bound when it's the same math? Count the reuse: in prefill, each gathered K/V block is dotted against many query rows (every query whose prefix covers it); in decode, against exactly one. That's the arithmetic-intensity difference of Phase 0 lab-04, visible in this very loop nest — and it's why real prefill kernels tile over both queries and keys (FlashAttention's 2D blocking) while decode kernels tile only keys (lab-01/lab-04 shapes).
- query_start_loc and friends: real batches contain many requests' chunks concatenated; upstream passes per-request offsets (query_start_loc, seq_lens) so one kernel launch handles a ragged batch. Your start_pos is the single-request version of that metadata. Find the production form in upstream/vllm/v1/attention/backends/flash_attn.py (FlashAttentionMetadata).
- The solution's per-query inner loop is honest but quadratic in reads: it re-gathers shared prefix blocks once per query. Real kernels invert the nest (outer loop over key tiles, inner over queries, with the diagonal-tile mask) precisely to read each block once. Try the inversion as an exercise; the recurrence per query row is unchanged, which is the point: the math doesn't care which loop is outside.
- Sliding-window attention (Mistral et al.) is one more loop-bound tweak: the prefix becomes [max(0, pos−W+1), pos]. If you can place the causal bound, you can place the window bound, and you now know why window support is a per-backend feature flag rather than a model-side trick.
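Both loop orders can be compared on a CPU with a read counter. This sketch is illustrative only: it counts block gathers rather than bytes, uses a single head, and every name in it (to_paged, queries_outer, key_tiles_outer) is hypothetical. Real GPU kernels also tile the queries and spread tiles across thread blocks, which this does not model. Each query row keeps its own (m, l, acc) online-softmax state and sees blocks in the same order under both nests, so the outputs agree while the gather count drops from one per (query, prefix block) to one per block.

```python
import numpy as np


def to_paged(x, block_table, block_size):
    """Scatter contiguous rows x (T, d) into a (num_physical, block_size, d) cache."""
    cache = np.zeros((max(block_table) + 1, block_size, x.shape[1]))
    for pos in range(x.shape[0]):
        cache[block_table[pos // block_size], pos % block_size] = x[pos]
    return cache


def _online_update(state, s, v_rows):
    """One online-softmax step for a single query row; state = [m, l, acc]."""
    m, l, acc = state
    m_new = max(m, s.max())
    corr, p = np.exp(m - m_new), np.exp(s - m_new)
    return [m_new, l * corr + p.sum(), acc * corr + p @ v_rows]


def queries_outer(q, kc, vc, table, start_pos, bs):
    """Per-query nest: every query re-gathers every block of its prefix."""
    scale, reads, out = 1.0 / np.sqrt(q.shape[1]), 0, []
    for i in range(q.shape[0]):
        n = start_pos + i + 1
        state = [-np.inf, 0.0, np.zeros(vc.shape[-1])]
        for b in range(-(-n // bs)):
            kb, vb = kc[table[b]], vc[table[b]]
            reads += 1
            valid = min(bs, n - b * bs)
            state = _online_update(state, (kb[:valid] @ q[i]) * scale, vb[:valid])
        out.append(state[2] / state[1])
    return np.stack(out), reads


def key_tiles_outer(q, kc, vc, table, start_pos, bs):
    """Inverted nest: gather each block once, feed it to every row that sees it."""
    M, scale, reads = q.shape[0], 1.0 / np.sqrt(q.shape[1]), 0
    states = [[-np.inf, 0.0, np.zeros(vc.shape[-1])] for _ in range(M)]
    for b in range(-(-(start_pos + M) // bs)):
        kb, vb = kc[table[b]], vc[table[b]]
        reads += 1
        for i in range(M):
            valid = min(bs, start_pos + i + 1 - b * bs)  # diagonal-tile mask
            if valid <= 0:
                continue  # this block lies entirely in row i's future
            states[i] = _online_update(states[i], (kb[:valid] @ q[i]) * scale, vb[:valid])
    return np.stack([acc / l for _, l, acc in states]), reads


rng = np.random.default_rng(3)
T, d, bs, table = 12, 8, 4, [2, 0, 1]
q, k, v = (rng.standard_normal((T, d)) for _ in range(3))
kc, vc = to_paged(k, table, bs), to_paged(v, table, bs)

out_a, reads_a = queries_outer(q, kc, vc, table, 0, bs)
out_b, reads_b = key_tiles_outer(q, kc, vc, table, 0, bs)
print(reads_a, reads_b)  # 24 3
```

Per-query reads grow as the sum of ceil((start_pos+i+1)/block_size) over rows, which is quadratic in M for a prefill from zero; the inverted nest reads each block exactly once.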
Going further
- Vectorize the dense reference into a single masked matmul
(
scores + np.triu(-inf, k=1+start_pos_offset)) and check it against your loop — then notice the materialized(M, N)score matrix is exactly what FlashAttention exists to avoid. - Invert the loop nest (key-tiles outer) as sketched above and re-run the suite — same
four green tests, different memory behavior. You've reproduced the actual structure of
flash_attn's prefill kernel. - Implement sliding-window (
windowparameter, prefix startmax(0, pos−W+1)) and write the poison test for the left boundary: a loud token just outside the window must be inaudible.
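A sketch of the sliding-window variant and its left-boundary poison test, assuming a hypothetical window parameter W and a dense contiguous cache. The only change from plain causal attention is the lower slice bound lo = max(0, pos − W + 1); the poison helper is likewise hypothetical.

```python
import numpy as np


def windowed_causal_attention(q, k, v, start_pos, window):
    """Hypothetical `window` parameter: the query at absolute position pos sees
    keys [max(0, pos - window + 1), pos], at most `window` tokens incl. itself."""
    scale = 1.0 / np.sqrt(q.shape[-1])
    out = np.empty_like(q)
    for i in range(q.shape[0]):
        pos = start_pos + i
        lo = max(0, pos - window + 1)  # left (window) bound
        s = (k[lo : pos + 1] @ q[i]) * scale  # right (causal) bound is pos
        p = np.exp(s - s.max())
        out[i] = (p / p.sum()) @ v[lo : pos + 1]
    return out


def poison(k, v, idx, q_row):
    """Make token idx deafening for q_row: an aligned key and a 1e3-valued V."""
    k2, v2 = k.copy(), v.copy()
    k2[idx] = 1e3 * q_row
    v2[idx] = 1e3
    return k2, v2


rng = np.random.default_rng(4)
T, d, W = 10, 16, 4
q, k, v = (rng.standard_normal((T, d)) for _ in range(3))
last_q = q[-1:]  # a single query at absolute position 9
clean = windowed_causal_attention(last_q, k, v, start_pos=9, window=W)

# The window for position 9 is [6, 9]: token 5 is just outside, token 6 just inside.
outside = windowed_causal_attention(last_q, *poison(k, v, 5, q[-1]), start_pos=9, window=W)
inside = windowed_causal_attention(last_q, *poison(k, v, 6, q[-1]), start_pos=9, window=W)
print(np.array_equal(clean, outside))  # True
```

Poisoning both sides of the boundary matters: the inside case proves the poison is loud enough to detect, so the outside case's silence actually means something.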
References
- Dao et al., FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness (2022), the 2D-tiled prefill kernel this lab is the skeleton of: https://arxiv.org/abs/2205.14135
- Dao, FlashAttention-2 (2023), the loop-nest inversion and work partitioning: https://arxiv.org/abs/2307.08691
- upstream/vllm/v1/attention/backends/flash_attn.py: FlashAttentionMetadata (query_start_loc, seq_lens) and the cascade of shapes one launch handles.
- Phase 3 lab-02: the engine-level statement of test_chunked_equals_one_shot.
- Phase 2 lab-06: the write path (slot_mapping) that fills the cache this lab reads.