Lab 04-04 — Flash-Decoding: Split the Keys, Merge the Partials [CPU-OK]
Here's a problem your lab-01 kernel can't solve. One request, one decode query, a 128k-token context — and a GPU with 100+ streaming multiprocessors. The online-softmax loop is sequential over blocks: one SM grinds through 8,000 blocks while 99+ SMs watch. Decode latency for long contexts becomes a single-core problem on a massively parallel machine.
The fix — known as flash-decoding (Dao et al.), paged_attention_v2 in vLLM's CUDA,
split-k in FlashInfer — is the subject of this lab: partition the keys, attend each
partition independently and in parallel, then merge the partial results exactly. The
reason it works is a small piece of algebra worth owning forever: softmax-attention state
compresses to a triple (max, denominator, unnormalized-accumulator), and two such
triples combine associatively. You'll implement the triple, the merge, and prove
equality with dense attention for any partition count, any merge order, any tree shape.
Contents
- Why this lab exists
- Background: attention state is three numbers
- Files
- Run
- What to implement
- What the tests prove
- Hitchhiker's notes
- Going further
- References
Why this lab exists
This is the lab where "online softmax" stops being a trick you memorized and becomes a monoid you can wield. Lab-01's recurrence processes blocks left to right — it looks inherently sequential. The deep fact is that it isn't: the per-block update is just the binary merge applied repeatedly, and because the merge is associative and order-insensitive, you may evaluate it in any tree shape — including "all leaves in parallel, one combine at the end." Sequential streaming (FlashAttention), parallel split-KV (flash-decoding), and hierarchical reduction (multi-stage kernels) are the same algorithm under different parenthesizations.
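To make "same algorithm, different parenthesizations" concrete, here is a minimal NumPy sketch. It uses the (max, denominator, unnormalized-accumulator) summary defined in the Background section; the helper names `leaf`, `merge`, `tree_reduce`, and `finish` are illustrative assumptions, not the signatures in starter.py.

```python
import functools
import numpy as np

def leaf(q, K, V):
    # Summary (m, denom, acc) of one block of keys/values; acc stays unnormalized.
    s = K @ q / np.sqrt(q.shape[0])
    m = s.max()
    w = np.exp(s - m)
    return m, w.sum(), w @ V

def merge(a, b):
    # Binary merge: rescale both sides to the shared max, then add.
    m = max(a[0], b[0])
    ca, cb = np.exp(a[0] - m), np.exp(b[0] - m)
    return m, a[1] * ca + b[1] * cb, a[2] * ca + b[2] * cb

def tree_reduce(items):
    # Balanced pairwise reduction: merge neighbours until one summary remains.
    while len(items) > 1:
        paired = [merge(items[i], items[i + 1]) for i in range(0, len(items) - 1, 2)]
        if len(items) % 2:
            paired.append(items[-1])   # odd one out rides up to the next level
        items = paired
    return items[0]

def finish(summary):
    return summary[2] / summary[1]     # the only division, at the very end

rng = np.random.default_rng(1)
q = rng.standard_normal(8)
K = rng.standard_normal((64, 8))
V = rng.standard_normal((64, 4))
leaves = [leaf(q, K[i:i + 8], V[i:i + 8]) for i in range(0, 64, 8)]  # 8 blocks

streamed = finish(functools.reduce(merge, leaves))         # left-to-right fold
reversed_ = finish(functools.reduce(merge, leaves[::-1]))  # opposite order
tree = finish(tree_reduce(leaves))                         # balanced tree
dense = finish(leaf(q, K, V))                              # all keys at once
```

All four results should agree to rounding error; only the shape of the reduction differs.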
Practically, this is also the difference between usable and unusable long-context decode.
Batch-1, long-context inference — the agentic workload, increasingly the workload — has
no batch parallelism to hide behind; parallelism must come from within the single
query's attention. When vLLM picks paged_attention_v2 over v1, or FlashInfer chooses a
split-k plan, the decision is "is this context long enough that splitting beats the merge
overhead?" After this lab you'll know exactly what's being weighed.
Background: attention state is three numbers
For a query q and any set of keys/values, define:
m = max_i s_i (s_i = k_i·q / √d)
denom = Σ_i exp(s_i − m)
acc = Σ_i exp(s_i − m) · v_i ← UNNORMALIZED (a vector)
(m, denom, acc) is a summary of attention over that key set: the final output is
acc / denom, but crucially you don't divide until the very end. Two summaries over
disjoint key sets merge by rescaling both to the shared max:
m* = max(m₁, m₂)
denom* = denom₁·e^{m₁−m*} + denom₂·e^{m₂−m*}
acc* = acc₁·e^{m₁−m*} + acc₂·e^{m₂−m*}
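The definitions and the merge rule above can be sketched in a few lines of NumPy. This is an illustration under assumed names (`summarize`, `merge_two`) and assumed shapes (q is a vector of length d, K is n×d, V is n×d_v); it is not the lab's reference solution.

```python
import numpy as np

def summarize(q, K, V):
    """Return (m, denom, acc) for one key set; hypothetical helper name."""
    s = K @ q / np.sqrt(q.shape[0])   # s_i = k_i·q / sqrt(d)
    m = s.max()
    w = np.exp(s - m)                 # every exponent is <= 0
    return m, w.sum(), w @ V          # acc is left unnormalized

def merge_two(a, b):
    """Combine two summaries over disjoint key sets into one summary."""
    (m1, d1, a1), (m2, d2, a2) = a, b
    m = max(m1, m2)
    c1, c2 = np.exp(m1 - m), np.exp(m2 - m)   # correction factors, both <= 1
    return m, d1 * c1 + d2 * c2, a1 * c1 + a2 * c2

rng = np.random.default_rng(0)
q = rng.standard_normal(16)
K = rng.standard_normal((40, 16))
V = rng.standard_normal((40, 8))

# Split the 40 keys unevenly, summarize each side, merge, divide once.
m, denom, acc = merge_two(summarize(q, K[:25], V[:25]),
                          summarize(q, K[25:], V[25:]))
split_out = acc / denom

# Dense attention is the same summary over the full key set.
m_all, denom_all, acc_all = summarize(q, K, V)
dense_out = acc_all / denom_all
```

The two outputs should match to rounding error, because both amount to "rescale everything to the global max and add".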
Check the properties: commutative (symmetry of the formulas), associative (both sides
reduce to "rescale everything to the global max and add"), and lab-01's per-block update
is exactly this merge where one side is a single block's summary. The exp(m−m*)
correction factors are the price of never having seen the global max in advance — and
they're also the numerical-stability mechanism: no exponential is ever taken of a
positive number, so nothing overflows even when one partition holds a monster logit
(test_extreme_scores_do_not_overflow feeds it a score of ~200; a naive exp(200)
overflows to inf in float16 and float32, and float64 overflows once a logit passes ~709).
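A small sketch of that firewall. The choice of float32 is an assumption made so that a logit of 200 already overflows; in float64 the same exponential is still finite, though the mechanism is identical at larger logits.

```python
import numpy as np

scores = np.array([200.0, 1.0, -3.0], dtype=np.float32)

# Naive softmax: exp(200) overflows float32, so the top weight becomes inf / inf = nan.
with np.errstate(over="ignore", invalid="ignore"):
    naive = np.exp(scores) / np.exp(scores).sum()

# Max-shifted softmax: exponents are 0, -199, -203, so nothing can overflow.
shifted = np.exp(scores - scores.max())
stable = shifted / shifted.sum()

# The same protection inside a merge: two single-key summaries (m, denom).
m1, d1 = np.float32(200.0), np.float32(1.0)
m2, d2 = np.float32(1.0), np.float32(1.0)
m_star = max(m1, m2)
d_star = d1 * np.exp(m1 - m_star) + d2 * np.exp(m2 - m_star)  # 1 + exp(-199)
```

The small side simply underflows toward zero, which is the correct answer at this precision.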
Files
- starter.py: attend_partial (key range → summary), merge_partials (summaries → output), partitioned_attention (split, attend, merge). Your work.
- solution.py: reference (the whole thing is ~25 lines; the understanding is the deliverable).
- test_lab.py: identity at 1 partition, equality at any count, empty-chunk handling, order-invariance, hierarchical merging, and the overflow stress.
Run
LAB_IMPL=starter pytest phase-04-attention-backends/labs/lab-04-flash-decoding-partitions -q
pytest phase-04-attention-backends/labs/lab-04-flash-decoding-partitions -q # reference
What to implement
Follow the math above literally. The one design rule that matters: attend_partial
must not normalize. The moment you divide by the local denominator, the summary is no
longer mergeable — you've thrown away the weights needed to re-weight against other
partitions. (Returning normalized outputs and "averaging" them is the classic wrong
implementation; it passes the 1-partition test and fails every other one, which is
exactly why the 1-partition test isn't sufficient and the suite has six.)
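The failure mode is easy to reproduce. The sketch below is illustrative only (inline code rather than the lab's function signatures): it averages per-partition normalized outputs, then does the merge properly, and compares both against dense attention.

```python
import numpy as np

rng = np.random.default_rng(2)
d = 8
q = rng.standard_normal(d)
K = rng.standard_normal((30, d))
V = rng.standard_normal((30, 4))

def dense(q, K, V):
    s = K @ q / np.sqrt(q.shape[0])
    w = np.exp(s - s.max())
    return (w @ V) / w.sum()

chunks = [(K[:10], V[:10]), (K[10:], V[10:])]

# Wrong: normalize inside each partition, then average the outputs.
wrong = np.mean([dense(q, Kc, Vc) for Kc, Vc in chunks], axis=0)

# Right: keep (m, denom, acc) per partition and divide once after merging.
parts = []
for Kc, Vc in chunks:
    s = Kc @ q / np.sqrt(d)
    m = s.max()
    w = np.exp(s - m)
    parts.append((m, w.sum(), w @ Vc))
m_star = max(m for m, _, _ in parts)
denom = sum(dn * np.exp(m - m_star) for m, dn, _ in parts)
acc = sum(a * np.exp(m - m_star) for m, _, a in parts)
right = acc / denom

reference = dense(q, K, V)
```

With a single chunk the "average" is just the dense output, which is why the wrong version survives the 1-partition test.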
What the tests prove
| Test | What it pins |
|---|---|
| test_one_partition_is_just_attention | The degenerate case: summary → output round-trips |
| test_any_partition_count_matches_dense | 2, 3, 7, 32, 100 partitions, all 1e-12-equal to dense. Partitioning is exact, not approximate; any tolerance bigger than rounding would hide real bugs |
| test_more_partitions_than_keys | array_split hands you empty chunks; skip, don't crash. The GPU analogue: grid sized for max length, sequences shorter than the partition count |
| test_merge_is_order_invariant | Reversed and shuffled partial lists give identical output. Mandatory, because on hardware thread blocks finish in nondeterministic order |
| test_merge_is_hierarchical | Merging merges = attending over the union: associativity, demonstrated. This is the license for tree reductions and multi-stage kernels |
| test_extreme_scores_do_not_overflow | A ~200 logit in one partition: finite output, still 1e-12-equal. The running max isn't bookkeeping; it's the firewall |
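The empty-chunk case is worth seeing once. In this sketch, `partitioned_attention_sketch` is a hypothetical stand-in (its signature is an assumption, not the one in starter.py) that splits with `np.array_split`, skips empty partitions, and folds the merge into the loop.

```python
import numpy as np

def partitioned_attention_sketch(q, K, V, num_partitions):
    # Hypothetical stand-in for the lab function; signature assumed for illustration.
    idx_chunks = np.array_split(np.arange(K.shape[0]), num_partitions)
    m_star, denom, acc = -np.inf, 0.0, np.zeros(V.shape[1])
    for idx in idx_chunks:
        if idx.size == 0:
            continue                      # empty partition: nothing to attend, skip it
        s = K[idx] @ q / np.sqrt(q.shape[0])
        m = s.max()                       # .max() on an empty array would raise
        w = np.exp(s - m)
        m_new = max(m_star, m)
        scale_old, scale_new = np.exp(m_star - m_new), np.exp(m - m_new)
        denom = denom * scale_old + w.sum() * scale_new
        acc = acc * scale_old + (w @ V[idx]) * scale_new
        m_star = m_new
    return acc / denom

rng = np.random.default_rng(4)
q = rng.standard_normal(8)
K = rng.standard_normal((5, 8))           # only 5 keys
V = rng.standard_normal((5, 3))

# 12 partitions over 5 keys: array_split yields 5 one-key chunks and 7 empty ones.
n_empty = sum(c.size == 0 for c in np.array_split(np.arange(5), 12))
out = partitioned_attention_sketch(q, K, V, num_partitions=12)

s = K @ q / np.sqrt(8)
w = np.exp(s - s.max())
dense_out = (w @ V) / w.sum()
```

Skipping is the simple choice here; treating an empty chunk as a summary with max −inf would need extra care, since rescaling −inf against −inf produces nan.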
Hitchhiker's notes
- Where this lives upstream: upstream/csrc/attention/paged_attention_v2.cu. Search max_logits and exp_sums: those are your m and denom, written per partition to scratch buffers, merged by a second reduction kernel. The v1/v2 choice (v2 when partitioning pays) is made by the backend per launch. FlashInfer generalizes the same state into plan-based split-k; FlashAttention's flash_attn_with_kvcache exposes it as num_splits.
- The merge is also how cascade/shared-prefix attention works (FlashInfer's signature feature): attend over the shared system-prompt KV once for the whole batch (one pass over the prefix keys, yielding one summary per request's query), attend per-request suffixes separately, merge each request's pair. Same triple, same combine: prefix caching meeting kernel design. That's three course threads (Phase 2 sharing, Phase 3 caching, this lab) converging on one formula.
- Why does sequential streaming still exist if parallel split is exact? Overhead: each partition writes its summary to global memory and a second kernel reads them back. For short contexts the round-trip costs more than it saves; for prefill the parallelism already comes from query rows (lab-03). Split-KV wins specifically at long-context decode — engineering is choosing the parenthesization that matches the hardware's idle dimension.
- This trick is older and bigger than attention: it's a parallel reduction over a non-trivial monoid, the same pattern as parallel max/sum/scan. The general skill — "can I summarize partial state so summaries combine associatively?" — is how you parallelize anything with a running normalizer. You'll meet it again in distributed softmax (Phase 10's context parallelism splits attention across GPUs with exactly this merge).
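As a sketch of the "running normalizer" pattern outside attention, here is the denominator half of the triple on its own: a log-sum-exp reduction over (max, scaled sum) pairs. The names `lse_leaf` and `lse_merge` are illustrative assumptions.

```python
import functools
import math

def lse_leaf(x):
    # Summary of a single value: (running max, sum of exp(x - max)).
    return (x, 1.0)

def lse_merge(a, b):
    # Same combine as the attention merge, minus the accumulator.
    m = max(a[0], b[0])
    return (m, a[1] * math.exp(a[0] - m) + b[1] * math.exp(b[0] - m))

xs = [3.0, -1.0, 1000.0, 2.5, 999.0]      # math.exp(1000.0) alone raises OverflowError

m, total = functools.reduce(lse_merge, map(lse_leaf, xs))
log_sum_exp = m + math.log(total)

# Any order gives the same answer, so the fold can be split across workers.
m_r, total_r = functools.reduce(lse_merge, map(lse_leaf, reversed(xs)))
log_sum_exp_reversed = m_r + math.log(total_r)
```

Swap the scalar sum for a weighted vector sum and you are back at the attention triple.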
Going further
- Implement merge_two(a, b) -> summary (summary × summary → summary, not output) and rebuild merge_partials as a fold; then as a balanced tree with functools.reduce-style pairing. Verify all shapes agree: you've now written the reduction the way the GPU executes it.
- Combine with lab-01: make each partition gather through the block table (partition = a contiguous range of logical blocks). That composition, paged + split-KV, is precisely paged_attention_v2.
- Simulate the cascade pattern: 8 "requests" sharing a 512-token prefix with unique 64-token suffixes. Compute the prefix summaries in one pass over the shared prefix KV (one per request's query), plus 8 suffix summaries, and merge per request; compare against 8 dense computations. Measure the key-reads saved (should be ~7×512 rows): FlashInfer's headline, reproduced in numpy.
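One possible starting point for the cascade exercise, offered as a rough sketch and not as FlashInfer's implementation. It assumes each request has its own decode query, so the shared prefix is read once but produces one summary per query; all names and the head sizes (d = 16, d_v = 8) are illustrative.

```python
import numpy as np

rng = np.random.default_rng(3)
n_req, d, dv, n_prefix, n_suffix = 8, 16, 8, 512, 64
Q = rng.standard_normal((n_req, d))                  # one decode query per request
Kp = rng.standard_normal((n_prefix, d))              # shared prefix keys
Vp = rng.standard_normal((n_prefix, dv))             # shared prefix values
Ks = rng.standard_normal((n_req, n_suffix, d))       # per-request suffix keys
Vs = rng.standard_normal((n_req, n_suffix, dv))      # per-request suffix values

def summarize(Q, K, V):
    # Batched summaries: Q is (B, d), K is (n, d) -> m (B,), denom (B,), acc (B, dv).
    S = Q @ K.T / np.sqrt(Q.shape[1])
    m = S.max(axis=1)
    W = np.exp(S - m[:, None])
    return m, W.sum(axis=1), W @ V

def merge_two(a, b):
    m = np.maximum(a[0], b[0])
    ca, cb = np.exp(a[0] - m), np.exp(b[0] - m)
    return m, a[1] * ca + b[1] * cb, a[2] * ca[:, None] + b[2] * cb[:, None]

prefix = summarize(Q, Kp, Vp)        # one pass over the prefix rows for all 8 queries
cascade = np.empty((n_req, dv))
dense = np.empty((n_req, dv))
for r in range(n_req):
    suffix = summarize(Q[r:r + 1], Ks[r], Vs[r])
    pre_r = tuple(x[r:r + 1] for x in prefix)        # this request's prefix summary
    m, denom, acc = merge_two(pre_r, suffix)
    cascade[r] = (acc / denom[:, None])[0]
    # Dense baseline: every request attends over prefix + its own suffix.
    _, den_d, acc_d = summarize(Q[r:r + 1], np.vstack([Kp, Ks[r]]), np.vstack([Vp, Vs[r]]))
    dense[r] = (acc_d / den_d[:, None])[0]

rows_dense = n_req * (n_prefix + n_suffix)           # every request rereads the prefix
rows_cascade = n_prefix + n_req * n_suffix           # prefix rows read once
rows_saved = rows_dense - rows_cascade
```

Under these assumptions the saving is exactly 7×512 key rows, and the cascade outputs match the dense ones to rounding error.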
References
- Dao et al., Flash-Decoding for Long-Context Inference (2023) — the technique, with the parallelism diagrams: https://pytorch.org/blog/flash-decoding/
- Milakov & Gimelshein, Online normalizer calculation for softmax (2018) — the merge formula's original home: https://arxiv.org/abs/1805.02867
- Ye et al., FlashInfer: Efficient and Customizable Attention Engine for LLM Serving (2025), covering split-k plans and cascade/shared-prefix attention: https://arxiv.org/abs/2501.01005
- upstream/csrc/attention/paged_attention_v2.cu (max_logits / exp_sums / the reduce kernel): your lab, in CUDA.
- Phase 10: the same merge, stretched across GPUs (context parallelism).