Phase 04 — Mini-Build: paged attention with online softmax
You'll implement the heart of a "Flash"-style attention kernel in numpy, namely online softmax over a paged KV cache, and show that it matches plain dense attention to within floating-point tolerance. This single build demystifies two things at once: FlashAttention (the streaming softmax) and the kernel side of PagedAttention (the block-table gather).
Contents
- The task (lab-01)
- The online-softmax recurrence (from the guide)
- Definition of done
- Map to the real engine
The task (lab-01)
Given:
- a query vector q (one decode step, one head): shape (d,),
- a paged KV cache k_cache, v_cache: shape (num_blocks, block_size, d),
- a block_table: list[int] mapping logical → physical block,
- a seq_len (number of valid tokens),
compute attention(q) = softmax(q·Kᵀ / √d) · V, where K/V are gathered through the block
table (token t lives at block_table[t // block_size], offset t % block_size), using the
online softmax recurrence (running max + rescale) so you never build the full score vector.
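The address arithmetic is the part most easily gotten wrong, so here is a minimal numpy sketch of just the token → (physical block, offset) mapping. gather_tokens is a hypothetical helper, not part of the lab's required API, and it copies the whole logical sequence out at once, which the block-by-block kernel deliberately avoids; it is only meant for checking the mapping.

```python
import numpy as np

def gather_tokens(cache, block_table, seq_len):
    """Hypothetical helper: logical (seq_len, d) view of a (num_blocks, block_size, d) cache."""
    block_size = cache.shape[1]
    t = np.arange(seq_len)                            # logical token positions
    phys = np.asarray(block_table)[t // block_size]   # physical block holding each token
    offs = t % block_size                             # slot inside that block
    return cache[phys, offs]                          # advanced indexing -> (seq_len, d)

# Usage: 10 tokens, block_size 4, logical blocks 0, 1, 2 stored in physical blocks 5, 0, 3.
block_size, d = 4, 3
K = np.arange(10 * d, dtype=float).reshape(10, d)
block_table = [5, 0, 3]
k_cache = np.zeros((6, block_size, d))
for t in range(len(K)):                               # scatter: write token t into its slot
    k_cache[block_table[t // block_size], t % block_size] = K[t]
assert np.array_equal(gather_tokens(k_cache, block_table, 10), K)
```

Token 5, for example, lands in logical block 5 // 4 = 1, which the table maps to physical block 0, at offset 5 % 4 = 1.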
Implement two functions and show they match:
- dense_attention(q, K, V): the reference (build all scores, softmax, weighted sum).
- paged_online_attention(q, k_cache, v_cache, block_table, seq_len): block-table gather + online softmax, processed block by block.
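A minimal sketch of the reference, assuming float64 inputs and V rows of length d as in the task (one possible implementation, not the lab's official solution). Subtracting the max score before exponentiating leaves the softmax unchanged but keeps exp from overflowing; the online recurrence generalizes exactly this trick to a max that is only known block by block.

```python
import numpy as np

def dense_attention(q, K, V):
    """Reference: materialize all seq_len scores, softmax them, take the weighted sum of V."""
    d = q.shape[0]
    s = (K @ q) / np.sqrt(d)          # (seq_len,) scaled scores
    w = np.exp(s - s.max())           # subtract the max for numerical stability
    w /= w.sum()                      # softmax weights, sum to 1
    return w @ V                      # (d,) output

# Usage: if every key is identical, all weights are 1/seq_len and the output is the mean row of V.
q = np.array([1.0, -2.0, 0.5])
K = np.ones((5, 3))
V = np.arange(15, dtype=float).reshape(5, 3)
out = dense_attention(q, K, V)        # ≈ [6., 7., 8.]
```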
The online-softmax recurrence (from the guide)
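```python
import numpy as np

def online_softmax_attention(q, blocks):
    # Illustrative name. blocks yields (Kblock, Vblock) pairs, each (n, d), gathered
    # via block_table up to seq_len (the last block may hold fewer than block_size rows).
    d = q.shape[0]
    m, denom, acc = -np.inf, 0.0, np.zeros(d)
    for Kblock, Vblock in blocks:
        s = (q @ Kblock.T) / np.sqrt(d)    # scores for this block's tokens
        m_new = max(m, s.max())
        corr = np.exp(m - m_new)           # rescale old stats; exp(-inf) = 0 on block 1
        p = np.exp(s - m_new)
        acc = acc * corr + p @ Vblock
        denom = denom * corr + p.sum()
        m = m_new
    return acc / denom
```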
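Why the rescale is exact rather than an approximation: a short derivation in the recurrence's own variables, writing s_i for the scaled score of token i, ℓ for denom and a for acc.

```latex
% After the blocks seen so far (token set S_j), the loop maintains
m_j = \max_{i \in S_j} s_i, \qquad
\ell_j = \sum_{i \in S_j} e^{s_i - m_j}, \qquad
a_j = \sum_{i \in S_j} e^{s_i - m_j}\, v_i .

% Adding block B with m_{j+1} = \max\big(m_j, \max_{i \in B} s_i\big), and using
% e^{s_i - m_j} \cdot e^{m_j - m_{j+1}} = e^{s_i - m_{j+1}}:
\ell_{j+1} = \ell_j\, e^{m_j - m_{j+1}} + \sum_{i \in B} e^{s_i - m_{j+1}}, \qquad
a_{j+1} = a_j\, e^{m_j - m_{j+1}} + \sum_{i \in B} e^{s_i - m_{j+1}}\, v_i .

% After the last block (all N tokens, final max m), the common factor e^{-m} cancels:
\frac{a}{\ell} = \frac{\sum_i e^{s_i - m}\, v_i}{\sum_i e^{s_i - m}}
             = \sum_i \operatorname{softmax}(s)_i \, v_i .
```

So corr is e^{m_j - m_{j+1}}, and the only difference from the dense result is floating-point rounding.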
Definition of done
pytest phase-04-attention-backends/labs -q
The test asserts that paged_online_attention matches dense_attention within tolerance when seq_len is not
block-aligned (so you must handle the partial last block), and that scattering the logical blocks to
arbitrary physical ids doesn't change the result (that's the whole point of paging).
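A self-contained sketch of what such a test could check. make_paged_cache and test_paged_matches_dense are hypothetical names, the lab's actual test file may be organized differently, and the two attention functions are one possible implementation of the lab's signatures. Unused cache slots are filled with random garbage, so reading past seq_len in the partial last block would make the comparison fail.

```python
import numpy as np


def dense_attention(q, K, V):
    """Reference: all scores at once, stable softmax, weighted sum of V rows."""
    s = (K @ q) / np.sqrt(q.shape[0])
    w = np.exp(s - s.max())
    return (w / w.sum()) @ V


def paged_online_attention(q, k_cache, v_cache, block_table, seq_len):
    """Block-by-block online softmax over a paged cache (one possible implementation)."""
    d, block_size = q.shape[0], k_cache.shape[1]
    m, denom, acc = -np.inf, 0.0, np.zeros(d)
    for start in range(0, seq_len, block_size):
        phys = block_table[start // block_size]       # logical block -> physical block
        n = min(block_size, seq_len - start)          # fewer rows in a partial last block
        s = (k_cache[phys, :n] @ q) / np.sqrt(d)
        m_new = max(m, s.max())
        corr = np.exp(m - m_new)
        p = np.exp(s - m_new)
        acc = acc * corr + p @ v_cache[phys, :n]
        denom = denom * corr + p.sum()
        m = m_new
    return acc / denom


def make_paged_cache(K, V, block_size, num_blocks, rng):
    """Hypothetical fixture: scatter logical K/V rows into randomly chosen physical blocks."""
    seq_len, d = K.shape
    n_logical = -(-seq_len // block_size)                  # ceil(seq_len / block_size)
    block_table = [int(b) for b in rng.permutation(num_blocks)[:n_logical]]
    k_cache = rng.standard_normal((num_blocks, block_size, d))   # garbage in unused slots
    v_cache = rng.standard_normal((num_blocks, block_size, d))
    for t in range(seq_len):
        slot = (block_table[t // block_size], t % block_size)
        k_cache[slot] = K[t]
        v_cache[slot] = V[t]
    return k_cache, v_cache, block_table


def test_paged_matches_dense():
    rng = np.random.default_rng(0)
    d, block_size, num_blocks, seq_len = 8, 4, 16, 13    # 13 is not a multiple of 4
    q = rng.standard_normal(d)
    K = rng.standard_normal((seq_len, d))
    V = rng.standard_normal((seq_len, d))
    ref = dense_attention(q, K, V)
    for _ in range(5):                                   # five different scatterings
        k_cache, v_cache, bt = make_paged_cache(K, V, block_size, num_blocks, rng)
        out = paged_online_attention(q, k_cache, v_cache, bt, seq_len)
        np.testing.assert_allclose(out, ref, rtol=1e-10, atol=1e-12)


test_paged_matches_dense()
```

Under pytest, functions named test_* are collected automatically; the direct call at the bottom only lets the snippet run as a plain script.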
Map to the real engine
| your numpy | real vLLM |
|---|---|
| block_table gather | the block table fed to FlashAttentionImpl (flash_attn.py:592) |
| online softmax | the FlashAttention/FlashInfer kernels |
| seq_len partial block | varlen handling in the metadata builder (flash_attn.py:276) |
| dense reference | what a naive (pre-Flash) kernel did: materialize every score, O(N²) memory for an N-token prefill |
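To put a number on the O(N²) row, here is a back-of-the-envelope sketch. It assumes fp16 scores (2 bytes each), one head and one sequence, and ignores any softmax temporaries a naive kernel would also allocate; naive_score_bytes is a hypothetical helper.

```python
def naive_score_bytes(n_tokens: int, bytes_per_elem: int = 2) -> int:
    """Bytes for one head's full N x N prefill score matrix (fp16 assumed by default)."""
    return n_tokens * n_tokens * bytes_per_elem

for n in (1024, 2048, 4096, 8192):
    # doubling N quadruples the buffer: 2, 8, 32, 128 MiB
    print(f"N={n:5d}: {naive_score_bytes(n) / 2**20:6.1f} MiB")
```

The block-by-block recurrence in this lab, by contrast, only holds one block's scores at a time plus m, denom and acc.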