
Lab 04-04 — Flash-Decoding: Split the Keys, Merge the Partials [CPU-OK]

Here's a problem your lab-01 kernel can't solve. One request, one decode query, a 128k-token context — and a GPU with 100+ streaming multiprocessors. The online-softmax loop is sequential over blocks: one SM grinds through 8,000 blocks while 99+ SMs watch. Decode latency for long contexts becomes a single-core problem on a massively parallel machine.

The fix, known as flash-decoding (Dao et al.), as paged_attention_v2 in vLLM's CUDA kernels, and as split-k in FlashInfer, is the subject of this lab: partition the keys, attend over each partition independently and in parallel, then merge the partial results exactly. It works because of a small piece of algebra worth owning forever: softmax-attention state compresses to a triple (max, denominator, unnormalized accumulator), and two such triples combine associatively. You'll implement the triple and the merge, then verify equality with dense attention across partition counts, merge orders, and tree shapes.

Why this lab exists

This is the lab where "online softmax" stops being a trick you memorized and becomes a monoid you can wield. Lab-01's recurrence processes blocks left to right — it looks inherently sequential. The deep fact is that it isn't: the per-block update is just the binary merge applied repeatedly, and because the merge is associative and order-insensitive, you may evaluate it in any tree shape — including "all leaves in parallel, one combine at the end." Sequential streaming (FlashAttention), parallel split-KV (flash-decoding), and hierarchical reduction (multi-stage kernels) are the same algorithm under different parenthesizations.
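
To make "same algorithm, different parenthesizations" concrete, here is a small numpy sketch that reduces the same eight block summaries three ways. The first is a left-to-right fold, matching lab-01's streaming order. The second is a balanced pairwise tree, and the third runs that tree over a shuffled list. The helper names (summarize, merge, fold_left, tree, output) are hypothetical, not the lab's API. Random scores stand in for k·q/√d.

```python
import random
import numpy as np

def summarize(scores, values):
    """(m, denom, acc) for one block of precomputed scores; acc stays unnormalized."""
    m = scores.max()
    w = np.exp(scores - m)                      # every exponent is <= 0
    return m, w.sum(), w @ values

def merge(a, b):
    """Binary merge: rescale both summaries to the shared max, then add."""
    (m1, d1, acc1), (m2, d2, acc2) = a, b
    m_star = max(m1, m2)
    c1, c2 = np.exp(m1 - m_star), np.exp(m2 - m_star)
    return m_star, d1 * c1 + d2 * c2, acc1 * c1 + acc2 * c2

def fold_left(parts):
    """Sequential streaming, lab-01 style: ((p0 . p1) . p2) . ..."""
    state = parts[0]
    for p in parts[1:]:
        state = merge(state, p)
    return state

def tree(parts):
    """Balanced reduction: merge adjacent pairs, level by level."""
    parts = list(parts)
    while len(parts) > 1:
        level = [merge(parts[i], parts[i + 1]) for i in range(0, len(parts) - 1, 2)]
        if len(parts) % 2:                      # odd one out moves up unchanged
            level.append(parts[-1])
        parts = level
    return parts[0]

def output(summary):
    _, denom, acc = summary
    return acc / denom                          # divide once, at the very end

rng = np.random.default_rng(0)
scores = rng.normal(size=1000)                  # stand-ins for k_i·q / √d
values = rng.normal(size=(1000, 16))
blocks = [summarize(s, v) for s, v in zip(np.array_split(scores, 8), np.array_split(values, 8))]
shuffled = blocks[:]
random.Random(0).shuffle(shuffled)

p = np.exp(scores - scores.max())
dense = p @ values / p.sum()
for name, state in [("fold", fold_left(blocks)), ("tree", tree(blocks)), ("shuffled tree", tree(shuffled))]:
    print(name, np.max(np.abs(output(state) - dense)))
```

All three agree with dense softmax up to rounding. Only the order of the floating-point additions differs.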

Practically, this is also the difference between usable and unusable long-context decode. Batch-1, long-context inference — the agentic workload, increasingly the workload — has no batch parallelism to hide behind; parallelism must come from within the single query's attention. When vLLM picks paged_attention_v2 over v1, or FlashInfer chooses a split-k plan, the decision is "is this context long enough that splitting beats the merge overhead?" After this lab you'll know exactly what's being weighed.
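
Here is a toy version of that trade-off, built on assumptions of our own:

- a fixed partition size of 512 keys;
- an A100's 108 SMs;
- a rule that splits only when the unsplit grid (one work item per sequence-head pair) cannot fill the GPU.

The function name and thresholds are hypothetical. This is not claimed to be the rule vLLM or FlashInfer actually uses.

```python
import math

def should_split(context_len: int, num_seqs: int, num_heads: int,
                 num_sms: int = 108, partition_size: int = 512) -> bool:
    """Hypothetical heuristic: split the keys only if the context spans more than
    one partition AND the unsplit grid is too small to occupy every SM."""
    num_partitions = math.ceil(context_len / partition_size)
    unsplit_work_items = num_seqs * num_heads
    return num_partitions > 1 and unsplit_work_items < num_sms

# Batch-1, 128k context, 32 heads: 32 work items for 108 SMs, so split.
print(should_split(131_072, num_seqs=1, num_heads=32))    # True
# Short context fits in one partition: nothing to split.
print(should_split(256, num_seqs=1, num_heads=32))        # False
# A large batch already fills the GPU: merge overhead buys nothing.
print(should_split(131_072, num_seqs=64, num_heads=32))   # False
```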

Background: attention state is a triple

For a query q and any set of keys/values, define:

m     = max_i  s_i                    (s_i = k_i·q / √d)
denom = Σ_i exp(s_i − m)
acc   = Σ_i exp(s_i − m) · v_i        ← UNNORMALIZED (a vector)

(m, denom, acc) is a summary of attention over that key set: the final output is acc / denom, but crucially you don't divide until the very end. Two summaries over disjoint key sets merge by rescaling both to the shared max:

m* = max(m₁, m₂)
denom* = denom₁·e^{m₁−m*} + denom₂·e^{m₂−m*}
acc*   = acc₁·e^{m₁−m*}  + acc₂·e^{m₂−m*}
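
Below is a minimal numpy sketch of the triple and the two-way merge. It assumes a single query vector q and key/value matrices K and V with one row per token. attend_summary, merge, and dense_attention are illustrative names, not the starter's signatures.

```python
import numpy as np

def attend_summary(q, K, V):
    """Summary (m, denom, acc) of attention from q over keys K / values V."""
    s = K @ q / np.sqrt(q.shape[0])     # s_i = k_i·q / √d
    m = s.max()
    w = np.exp(s - m)                   # exp(s_i − m) <= 1
    return m, w.sum(), w @ V            # acc is NOT divided by denom

def merge(a, b):
    """Rescale both summaries to m* = max(m1, m2), then add."""
    (m1, d1, acc1), (m2, d2, acc2) = a, b
    m_star = max(m1, m2)
    c1, c2 = np.exp(m1 - m_star), np.exp(m2 - m_star)
    return m_star, d1 * c1 + d2 * c2, acc1 * c1 + acc2 * c2

def dense_attention(q, K, V):
    s = K @ q / np.sqrt(q.shape[0])
    p = np.exp(s - s.max())
    return (p / p.sum()) @ V

rng = np.random.default_rng(0)
d, n = 64, 300
q = rng.normal(size=d)
K = rng.normal(size=(n, d))
V = rng.normal(size=(n, d))

a = attend_summary(q, K[:100], V[:100])    # keys 0..99
b = attend_summary(q, K[100:], V[100:])    # keys 100..299 (disjoint)
m, denom, acc = merge(a, b)
out = acc / denom                          # the only division, at the very end
```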

Check the properties. The merge is commutative, because the formulas are symmetric in the two summaries. It is associative, because both groupings reduce to "rescale everything to the global max and add". And lab-01's per-block update is exactly this merge with one side being a single block's summary. The exp(m−m*) correction factors are the price of never having seen the global max in advance. They are also the numerical-stability mechanism: every exponent is ≤ 0 (s_i − m ≤ 0 and m − m* ≤ 0), so no exponential exceeds 1 and nothing overflows, even when one partition holds a monster logit. test_extreme_scores_do_not_overflow feeds it a score of ~200. Naive softmax would compute exp(200) ≈ 7×10⁸⁶, which is already inf in float32, where exp overflows past about 88.7.
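
The overflow claim is easy to reproduce in float32. That dtype is chosen here because its exp overflows early, not because the lab uses it. The sketch below puts one ~200 logit next to two ordinary ones. Naive softmax produces nan, while the max-shifted summaries of two partitions merge to the expected answer. The helper names are hypothetical.

```python
import numpy as np

scores = np.array([1.0, 2.0, 200.0], dtype=np.float32)   # one "monster" logit
values = np.array([[1.0], [2.0], [3.0]], dtype=np.float32)

# Naive softmax: exp(200) is inf in float32, and inf / inf gives nan.
with np.errstate(over="ignore", invalid="ignore"):
    naive = np.exp(scores) @ values / np.exp(scores).sum()

def summarize(s, v):
    m = s.max()
    w = np.exp(s - m)                  # exponents <= 0, so weights lie in [0, 1]
    return m, w.sum(), w @ v

def merge(a, b):
    (m1, d1, acc1), (m2, d2, acc2) = a, b
    m_star = max(m1, m2)
    c1, c2 = np.exp(m1 - m_star), np.exp(m2 - m_star)   # exp(2 - 200) underflows to 0: harmless
    return m_star, d1 * c1 + d2 * c2, acc1 * c1 + acc2 * c2

# Partition A holds the two ordinary logits, partition B the monster.
m, denom, acc = merge(summarize(scores[:2], values[:2]), summarize(scores[2:], values[2:]))
merged = acc / denom
print(naive, merged)                   # [nan] [3.]
```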

Files

  • starter.py: attend_partial (key range → summary), merge_partials (summaries → output), partitioned_attention (split, attend, merge). Your work.
  • solution.py: reference (the whole thing is ~25 lines; the understanding is the deliverable).
  • test_lab.py: identity at 1 partition, equality at any count, empty-chunk handling, order-invariance, hierarchical merging, and the overflow stress.

Run

LAB_IMPL=starter pytest phase-04-attention-backends/labs/lab-04-flash-decoding-partitions -q
pytest phase-04-attention-backends/labs/lab-04-flash-decoding-partitions -q   # reference

What to implement

Follow the math above literally. The one design rule that matters: attend_partial must not normalize. If you divide acc by the local denominator and drop the (m, denom) pair, the summary is no longer mergeable: you've thrown away the weights needed to re-weight it against other partitions. Returning normalized outputs and "averaging" them is the classic wrong implementation. It passes the 1-partition test and fails the multi-partition equality checks, which is exactly why the 1-partition test isn't sufficient and the suite has six.
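
Here is a quick numpy check of the classic mistake, using hypothetical helper names. Averaging per-partition normalized outputs matches dense attention for one partition but not for two. Weighting each normalized output by its denominator, rescaled to the shared max, recovers the exact answer. In other words, the fix is to keep the information a summary carries.

```python
import numpy as np

def local_output(s, v):
    """Normalized attention output over one partition (scores s, values v)."""
    w = np.exp(s - s.max())
    return (w / w.sum()) @ v

rng = np.random.default_rng(1)
s = rng.normal(size=10)                  # stand-in scores for 10 keys
v = rng.normal(size=(10, 4))
dense = local_output(s, v)               # one partition covering every key IS dense attention

parts = [(s[:3], v[:3]), (s[3:], v[3:])]

# WRONG: average the locally normalized outputs (equal weight per partition).
averaged = np.mean([local_output(ps, pv) for ps, pv in parts], axis=0)

# Fix: weight each normalized output by denom_k * exp(m_k - m*),
# i.e. keep the (m, denom) information a summary carries.
m_star = s.max()
weights = np.array([np.exp(ps - m_star).sum() for ps, _ in parts])
fixed = sum(w * local_output(ps, pv) for w, (ps, pv) in zip(weights, parts)) / weights.sum()

print(np.max(np.abs(averaged - dense)))  # clearly nonzero
print(np.max(np.abs(fixed - dense)))     # rounding-level
```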

What the tests prove

| Test | What it pins |
|---|---|
| test_one_partition_is_just_attention | The degenerate case: summary → output round-trips |
| test_any_partition_count_matches_dense | 2, 3, 7, 32, and 100 partitions, all equal to dense within 1e-12. Partitioning is exact, not approximate; any tolerance bigger than rounding would hide real bugs |
| test_more_partitions_than_keys | array_split hands you empty chunks; skip them, don't crash. The GPU analogue: a grid sized for the maximum length, so shorter sequences leave some partitions with no keys |
| test_merge_is_order_invariant | Reversed and shuffled partial lists give identical output. Mandatory, because on hardware, thread blocks finish in nondeterministic order |
| test_merge_is_hierarchical | Merging already-merged summaries equals attending over the union: associativity, demonstrated. This is the license for tree reductions and multi-stage kernels |
| test_extreme_scores_do_not_overflow | A ~200 logit in one partition: finite output, still equal within 1e-12. The running max isn't bookkeeping; it's the firewall |
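
Here is how the empty-chunk case arises and one way to handle it. np.array_split with more sections than elements returns zero-length arrays, and calling .max() on one raises ValueError. The sketch below skips empty chunks and merges the rest in a single rescale-and-add pass. split_attention is a hypothetical name, not the lab's partitioned_attention.

```python
import numpy as np

print([len(c) for c in np.array_split(np.arange(3), 5)])   # [1, 1, 1, 0, 0]

def split_attention(s, v, num_partitions):
    """Hypothetical split-KV attention over precomputed scores s (n,) and values v (n, d)."""
    summaries = []
    for ps, pv in zip(np.array_split(s, num_partitions), np.array_split(v, num_partitions)):
        if ps.size == 0:                 # empty chunk: ps.max() would raise ValueError
            continue
        m = ps.max()
        w = np.exp(ps - m)
        summaries.append((m, w.sum(), w @ pv))
    m_star = max(m for m, _, _ in summaries)
    denom = sum(d * np.exp(m - m_star) for m, d, _ in summaries)
    acc = sum(a * np.exp(m - m_star) for m, _, a in summaries)
    return acc / denom

rng = np.random.default_rng(2)
s, v = rng.normal(size=3), rng.normal(size=(3, 4))
p = np.exp(s - s.max())
dense = p @ v / p.sum()
out = split_attention(s, v, num_partitions=5)   # 2 of the 5 chunks are empty
```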

Hitchhiker's notes

  • Where this lives upstream: upstream/csrc/attention/paged_attention_v2.cu — search max_logits and exp_sums: those are your m and denom, written per partition to scratch buffers, merged by a second reduction kernel. The v1/v2 choice (v2 when partitioning pays) is made by the backend per launch. FlashInfer generalizes the same state into plan-based split-k; FlashAttention's flash_attn_with_kvcache exposes it as num_splits.
  • The merge is also how cascade/shared-prefix attention works (FlashInfer's signature feature). First, attend over the shared system-prompt KV in one pass for the whole batch: the prefix keys are read once and every request's query attends to them, giving one prefix summary per request. Then attend over each request's suffix separately and merge each request's pair. Same triple, same combine: prefix caching meeting kernel design. That's three course threads (Phase 2 sharing, Phase 3 caching, this lab) converging on one formula.
  • Why does sequential streaming still exist if parallel split is exact? Overhead: each partition writes its summary to global memory and a second kernel reads them back. For short contexts the round-trip costs more than it saves; for prefill the parallelism already comes from query rows (lab-03). Split-KV wins specifically at long-context decode — engineering is choosing the parenthesization that matches the hardware's idle dimension.
  • This trick is older and bigger than attention: it's a parallel reduction over a non-trivial monoid, the same pattern as parallel max/sum/scan. The general skill — "can I summarize partial state so summaries combine associatively?" — is how you parallelize anything with a running normalizer. You'll meet it again in distributed softmax (Phase 10's context parallelism splits attention across GPUs with exactly this merge).
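
Here is a CPU-side sketch of the two-phase structure described in the first note. Phase 1 writes one (max, denominator, accumulator) per partition into scratch arrays. Phase 2 reads them back and merges them in one vectorized step. The buffer names max_logits and exp_sums are borrowed from that note. partial_acc, phase1, phase2, and the 512-key partition size are assumptions, and the real CUDA kernels' buffer layout and normalization may differ.

```python
import numpy as np

def phase1(q, K, V, partition_size=512):
    """'Kernel 1': one work item per partition, each writing its summary to scratch."""
    n, d = K.shape
    num_partitions = -(-n // partition_size)            # ceil(n / partition_size)
    max_logits = np.empty(num_partitions)
    exp_sums = np.empty(num_partitions)
    partial_acc = np.empty((num_partitions, V.shape[1]))
    for p in range(num_partitions):                     # concurrent thread blocks on a GPU
        rows = slice(p * partition_size, (p + 1) * partition_size)
        s = K[rows] @ q / np.sqrt(d)
        max_logits[p] = s.max()
        w = np.exp(s - max_logits[p])
        exp_sums[p] = w.sum()
        partial_acc[p] = w @ V[rows]                    # unnormalized
    return max_logits, exp_sums, partial_acc

def phase2(max_logits, exp_sums, partial_acc):
    """'Kernel 2': read every partition's summary back and merge in one shot."""
    scale = np.exp(max_logits - max_logits.max())       # exp(m_p - m*), all <= 1
    return (scale @ partial_acc) / (scale @ exp_sums)

rng = np.random.default_rng(3)
n, d = 1300, 64                                         # 3 partitions: 512 + 512 + 276 keys
q, K, V = rng.normal(size=d), rng.normal(size=(n, d)), rng.normal(size=(n, d))
out = phase2(*phase1(q, K, V))
```

The scratch arrays are the global-memory round-trip that the overhead note refers to. For a short context, phase 1 has a single partition and phase 2 is pure cost.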

Going further

  • Implement merge_two(a, b) -> summary (summary × summary → summary, not output) and rebuild merge_partials as a left fold, functools.reduce(merge_two, partials). Then rebuild it as a balanced tree that merges adjacent pairs level by level. Verify that all shapes agree: you've now written the reduction the way the GPU executes it.
  • Combine with lab-01: make each partition gather through the block table (partition = a contiguous range of logical blocks). That composition — paged + split-KV — is precisely paged_attention_v2.
  • Simulate the cascade pattern: 8 "requests" sharing a 512-token prefix, each with a unique 64-token suffix. Compute all 8 prefix summaries in one batched pass over the shared prefix keys, compute 8 suffix summaries, and merge per request. Compare against 8 dense computations. Measure the key-reads saved (should be ~7×512 rows): FlashInfer's headline, reproduced in numpy.
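
Here is one way to set up the cascade exercise in numpy, assuming one decode query per request and a single head. Level 1 is one batched pass of all 8 queries over the shared 512 prefix rows. Level 2 handles each 64-token suffix, and the pair is then merged per request. Variable names are illustrative. The key-read count simply tallies rows touched, not real memory traffic.

```python
import numpy as np

rng = np.random.default_rng(0)
R, P, S, d = 8, 512, 64, 32                  # requests, prefix len, suffix len, head dim
Q = rng.normal(size=(R, d))                  # one decode query per request
Kp, Vp = rng.normal(size=(P, d)), rng.normal(size=(P, d))          # shared prefix
Ks, Vs = rng.normal(size=(R, S, d)), rng.normal(size=(R, S, d))    # per-request suffixes

# Level 1: one pass over the shared prefix KV serves all R queries.
s_pre = Q @ Kp.T / np.sqrt(d)                               # (R, P)
m_pre = s_pre.max(axis=1)
w_pre = np.exp(s_pre - m_pre[:, None])
den_pre, acc_pre = w_pre.sum(axis=1), w_pre @ Vp            # (R,), (R, d)

# Level 2: each request's own suffix.
s_suf = np.einsum("rd,rsd->rs", Q, Ks) / np.sqrt(d)         # (R, S)
m_suf = s_suf.max(axis=1)
w_suf = np.exp(s_suf - m_suf[:, None])
den_suf, acc_suf = w_suf.sum(axis=1), np.einsum("rs,rsd->rd", w_suf, Vs)

# Merge each request's (prefix, suffix) pair: rescale to the shared max, add, divide once.
m_star = np.maximum(m_pre, m_suf)
c_pre, c_suf = np.exp(m_pre - m_star), np.exp(m_suf - m_star)
cascade = (acc_pre * c_pre[:, None] + acc_suf * c_suf[:, None]) / (
    den_pre * c_pre + den_suf * c_suf)[:, None]

# Baseline: R independent dense computations over prefix + suffix.
dense = np.empty_like(cascade)
for r in range(R):
    K_full, V_full = np.concatenate([Kp, Ks[r]]), np.concatenate([Vp, Vs[r]])
    s = K_full @ Q[r] / np.sqrt(d)
    p = np.exp(s - s.max())
    dense[r] = p @ V_full / p.sum()

key_rows_dense = R * (P + S)                 # every request re-reads the prefix
key_rows_cascade = P + R * S                 # prefix read once
saved = key_rows_dense - key_rows_cascade
print(np.max(np.abs(cascade - dense)), saved)
```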

References

  • Dao et al., Flash-Decoding for Long-Context Inference (2023) — the technique, with the parallelism diagrams: https://pytorch.org/blog/flash-decoding/
  • Milakov & Gimelshein, Online normalizer calculation for softmax (2018) — the merge formula's original home: https://arxiv.org/abs/1805.02867
  • Ye et al., FlashInfer: Efficient and Customizable Attention Engine for LLM Serving (2024) — split-k plans and cascade/shared-prefix attention: https://arxiv.org/abs/2501.01005
  • upstream/csrc/attention/paged_attention_v2.cu (max_logits, exp_sums, and the reduce kernel): your lab, in CUDA.
  • Phase 10 — the same merge, stretched across GPUs (context parallelism).