Phase 04 — Mini-Build: paged attention with online softmax

You'll implement the heart of a "Flash"-style attention kernel in numpy: online softmax over a paged KV cache. You'll then verify numerically that it matches plain dense attention. This single build demystifies both FlashAttention (the streaming softmax) and the kernel side of PagedAttention (the block-table gather).

The task (lab-01)

Given:

  • a query vector q (one decode step, one head): shape (d,),
  • a paged KV cache k_cache, v_cache: shape (num_blocks, block_size, d),
  • a block_table: list[int] mapping logical→physical block,
  • a seq_len (valid tokens),

compute attention(q) = softmax(q·Kᵀ / √d) · V, where K and V are gathered through the block table: token t lives in physical block block_table[t // block_size] at offset t % block_size. Use the online softmax recurrence (running max plus rescale) so that you never build the full score vector.
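To make the logical→physical mapping concrete, here is a minimal sketch of the gather on its own, without any attention math. The helper name gather_kv and the toy cache contents are assumptions made for illustration; the lab does not require this function, and the online version never needs to materialize the full K and V like this.

```python
import numpy as np

def gather_kv(k_cache, v_cache, block_table, seq_len):
    """Return contiguous K, V of shape (seq_len, d), gathered through block_table.

    Hypothetical helper for illustration only.
    """
    block_size, d = k_cache.shape[1], k_cache.shape[2]
    num_logical = -(-seq_len // block_size)  # ceil(seq_len / block_size)
    phys = np.asarray(block_table[:num_logical], dtype=np.int64)
    # Index the physical blocks in logical order, flatten (blocks, block_size)
    # into a token axis, then drop the unused slots of the partial last block.
    K = k_cache[phys].reshape(-1, d)[:seq_len]
    V = v_cache[phys].reshape(-1, d)[:seq_len]
    return K, V

# Toy cache: 4 physical blocks, 2 tokens per block, d = 3.
k_cache = np.arange(4 * 2 * 3, dtype=np.float64).reshape(4, 2, 3)
v_cache = -k_cache
block_table = [3, 0, 2]  # logical blocks 0, 1, 2 live at physical blocks 3, 0, 2
K, V = gather_kv(k_cache, v_cache, block_table, seq_len=5)

# Token t = 3: logical block 3 // 2 = 1 -> physical block 0, offset 3 % 2 = 1.
assert np.array_equal(K[3], k_cache[0, 1])
```

With seq_len=5 and block_size=2, the third logical block contributes only one valid token; the slice [:seq_len] is what discards the unused slot.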

Implement two functions and show they match:

  • dense_attention(q, K, V) — the reference (build all scores, softmax, weighted sum).
  • paged_online_attention(q, k_cache, v_cache, block_table, seq_len) — block-table gather + online softmax, processed block by block.
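One possible shape for the reference function is sketched below. It assumes K and V are already contiguous arrays of shape (n, d) and that n is at least 1; subtracting the maximum score before exponentiating is a standard stability trick and does not change the result.

```python
import numpy as np

def dense_attention(q, K, V):
    """Reference attention for one query: softmax(q K^T / sqrt(d)) V.

    q has shape (d,); K and V have shape (n, d). A sketch, not the lab's
    official solution.
    """
    d = q.shape[0]
    scores = (K @ q) / np.sqrt(d)       # (n,) one score per token
    scores = scores - scores.max()      # shift so the largest exponent is 0
    weights = np.exp(scores)
    weights = weights / weights.sum()   # softmax over all n tokens at once
    return weights @ V                  # (d,) weighted sum of value rows

# Sanity check: identical keys give uniform weights, so the output is mean(V).
q = np.ones(4)
K = np.zeros((3, 4))
V = np.arange(12, dtype=np.float64).reshape(3, 4)
out = dense_attention(q, K, V)
assert np.allclose(out, V.mean(axis=0))
```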

The online-softmax recurrence (from the guide)

m, denom, acc = -inf, 0, zeros(d)
for each block (gathered via block_table, up to seq_len):
    s = (q · Kblockᵀ) / sqrt(d)            # scores for this block's tokens
    m_new = max(m, s.max())
    corr = exp(m - m_new)
    acc = acc*corr + (exp(s - m_new) @ Vblock)
    denom = denom*corr + exp(s - m_new).sum()
    m = m_new
return acc / denom
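The recurrence above translates almost line for line into numpy. The following is a sketch under a few assumptions: seq_len is at least 1, block_table holds plain integer ids, and the variable names k_blk, v_blk, n_valid, and p are illustrative choices rather than anything the lab prescribes.

```python
import numpy as np

def paged_online_attention(q, k_cache, v_cache, block_table, seq_len):
    """Online-softmax attention over a paged KV cache, one block at a time."""
    d = q.shape[0]
    block_size = k_cache.shape[1]
    m = -np.inf                              # running max of all scores seen
    denom = 0.0                              # running softmax denominator
    acc = np.zeros(v_cache.shape[2])         # running unnormalized output
    for logical, start in enumerate(range(0, seq_len, block_size)):
        n_valid = min(block_size, seq_len - start)   # partial last block
        phys = block_table[logical]                  # logical -> physical
        k_blk = k_cache[phys, :n_valid]              # (n_valid, d)
        v_blk = v_cache[phys, :n_valid]              # (n_valid, d)
        s = (k_blk @ q) / np.sqrt(d)                 # scores for this block
        m_new = max(m, s.max())
        corr = np.exp(m - m_new)     # rescales old state; exp(-inf) = 0 at start
        p = np.exp(s - m_new)
        acc = acc * corr + p @ v_blk
        denom = denom * corr + p.sum()
        m = m_new
    return acc / denom

# Compare against a dense computation on a non-block-aligned sequence.
rng = np.random.default_rng(0)
d, block_size, num_blocks, seq_len = 8, 4, 6, 10
k_cache = rng.standard_normal((num_blocks, block_size, d))
v_cache = rng.standard_normal((num_blocks, block_size, d))
q = rng.standard_normal(d)
block_table = [5, 2, 0]
out = paged_online_attention(q, k_cache, v_cache, block_table, seq_len)

K = np.concatenate([k_cache[b] for b in block_table])[:seq_len]
V = np.concatenate([v_cache[b] for b in block_table])[:seq_len]
s = K @ q / np.sqrt(d)
w = np.exp(s - s.max())
w = w / w.sum()
assert np.allclose(out, w @ V)
```

The only state carried between blocks is m, denom, and acc, so memory stays proportional to d plus one block, regardless of seq_len.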

Definition of done

pytest phase-04-attention-backends/labs -q

The test asserts paged_online_attention ≈ dense_attention within tolerance, for non-block-aligned seq_len (so you handle the partial last block), and that scattering the logical blocks to arbitrary physical ids doesn't change the result (that's the whole point of paging).
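If you want to reproduce those conditions yourself, one way to build inputs is to start from dense K and V and scatter them into randomly chosen physical blocks. The helper make_paged_cache below is hypothetical (it is not part of the lab's test suite as far as this page shows); it fills unused slots with random values so that reading past seq_len or ignoring the block table changes the answer.

```python
import numpy as np

def make_paged_cache(K, V, block_size, num_blocks, rng):
    """Scatter dense K, V of shape (n, d) into random physical blocks.

    Hypothetical test helper. Returns (k_cache, v_cache, block_table).
    """
    n, d = K.shape
    num_logical = -(-n // block_size)        # ceil(n / block_size)
    assert num_logical <= num_blocks, "not enough physical blocks"
    # Distinct physical ids in random order, one per logical block.
    block_table = rng.permutation(num_blocks)[:num_logical].tolist()
    # Start from random garbage so unused slots are never accidentally zero.
    k_cache = rng.standard_normal((num_blocks, block_size, d))
    v_cache = rng.standard_normal((num_blocks, block_size, d))
    for t in range(n):
        b, off = block_table[t // block_size], t % block_size
        k_cache[b, off] = K[t]
        v_cache[b, off] = V[t]
    return k_cache, v_cache, block_table

rng = np.random.default_rng(0)
K = rng.standard_normal((10, 8))   # 10 tokens, block_size 4: last block holds 2
V = rng.standard_normal((10, 8))
k_cache, v_cache, block_table = make_paged_cache(
    K, V, block_size=4, num_blocks=16, rng=rng
)

# Every token must read back from its mapped physical slot.
for t in range(10):
    assert np.array_equal(k_cache[block_table[t // 4], t % 4], K[t])
```

Calling the helper twice with different seeds gives two different physical layouts of the same logical sequence, which is the pair of inputs the scattering check needs.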

Map to the real engine

| your numpy | real vLLM |
| --- | --- |
| block_table gather | the block table fed to FlashAttentionImpl (flash_attn.py:592) |
| online softmax | the FlashAttention/FlashInfer kernels |
| seq_len partial block | varlen handling in the metadata builder (flash_attn.py:276) |
| dense reference | what a naive (pre-Flash) kernel did, O(N²) memory |