Lab 08-03 — Rejection Sampling: Lossless Speculation with Temperature [CPU-OK]
Lab-01's greedy verify had an easy life: at temperature 0 there's exactly one right
token, so "accept iff the draft equals it" is obviously lossless. But production serving
samples — and now the claim that speculative decoding "doesn't change the output"
becomes a real theorem with a real proof obligation: the verified output must be
distributed exactly according to the target model's distribution p, no matter how
wrong the drafter's q is. This lab has you implement the three-line algorithm that
achieves it — accept draft x with probability min(1, p[x]/q[x]), else resample from
the residual normalize(max(p − q, 0)) — and then verify the theorem empirically:
200,000 draws through a deliberately clueless uniform drafter land on the target
distribution to within sampling noise. This is the mathematical heart of every
speculative method in vLLM, from n-gram to EAGLE.
Contents
- Why this lab exists
- Background: why the algorithm works
- Files
- Run
- What the tests prove
- Hitchhiker's notes
- Going further
- References
Why this lab exists
"Speculative decoding is lossless" is repeated everywhere and understood almost nowhere
— most explanations stop at the greedy case, leaving the sampled case as folklore. But
the sampled case is where the engineering risk lives: a subtly wrong residual, a missing
clamp, a normalization slip — and your serving system is quietly sampling from a
distribution that is not the model's, a bug invisible to every output-equality test
(each individual output is plausible!) and detectable only distributionally. The
defense, which you'll build, is the statistical test: histogram many draws, compare to
p. If you ever touch rejection_sampler.py upstream — and spec-decode PRs touch it
constantly — this lab's test design is how you protect the change.
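To make the "invisible to output-equality tests" point concrete, here is a minimal sketch that computes the exact output law for one plausible slip: resampling from `p` itself on rejection instead of from the residual. The helper name `output_distribution` and the example `p` and `q` are illustrative assumptions, not the lab's code. Every token the buggy sampler emits is one the target could have produced, yet the distribution is off by far more than a 0.005 tolerance.

```python
import numpy as np

def output_distribution(p, q, fallback):
    """Exact law of the emitted token when rejections resample from `fallback`."""
    accepted = np.minimum(p, q)          # q[x] * min(1, p[x]/q[x]), per token
    p_reject = 1.0 - accepted.sum()      # total rejection probability
    return accepted + p_reject * fallback

# Illustrative pair (an assumption): skewed target, uniform drafter.
p = np.array([0.6, 0.3, 0.1])
q = np.full(3, 1.0 / 3.0)

surplus = np.maximum(p - q, 0.0)
correct = output_distribution(p, q, surplus / surplus.sum())  # proper residual
buggy = output_distribution(p, q, p)                          # bug: fallback is p

max_error = float(np.abs(buggy - p).max())
print("correct:", correct)
print("buggy:  ", buggy, "max |error| =", max_error)
```

The buggy law still sums to 1 and puts mass only where `p` does, which is why only a histogram comparison catches it.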
The second deliverable is the acceptance-rate formula Σ min(p, q) — the overlap of
the two distributions (equivalently 1 − total-variation distance). It converts "is the
drafter any good?" from vibes into one number per position, and it's the alpha that
lab-04's economics run on. Drafter evaluation, acceptance metrics in vLLM's logs,
temperature's effect on speedup — all read off this one quantity.
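As a quick check of the overlap identity, the sketch below computes Σ min(p, q) and the total-variation distance for a made-up pair and confirms they sum to 1. The function names `overlap` and `total_variation` are hypothetical; the lab's `expected_acceptance_rate` may have a different signature.

```python
import numpy as np

def overlap(p, q):
    """Per-position acceptance rate: sum over tokens of min(p[x], q[x])."""
    return float(np.minimum(p, q).sum())

def total_variation(p, q):
    """Total-variation distance: half the L1 distance between p and q."""
    return 0.5 * float(np.abs(p - q).sum())

# Illustrative distributions (assumptions, not the lab's pair).
p = np.array([0.6, 0.3, 0.1])
q = np.array([0.2, 0.3, 0.5])

alpha = overlap(p, q)        # 0.2 + 0.3 + 0.1
tv = total_variation(p, q)   # 0.5 * (0.4 + 0.0 + 0.4)
print(alpha, tv, alpha + tv)
```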
Background: why the algorithm works
The output token's probability decomposes into "accepted draft" + "residual resample":
P(output = x) = q[x]·min(1, p[x]/q[x]) + P(reject)·residual[x]
= min(p[x], q[x]) + (1 − Σ min(p,q)) · max(p[x]−q[x],0) / Σ max(p−q,0)
Since Σ max(p−q, 0) = 1 − Σ min(p, q) (the surplus equals the deficit — both are the
TV distance), the second term simplifies to max(p[x] − q[x], 0), and:
P(output = x) = min(p[x], q[x]) + max(p[x] − q[x], 0) = p[x] ∎
Read the proof's shape: where the draft over-serves (q > p), acceptance is throttled
by exactly the ratio; where it under-serves (q < p), drafts always pass and the
residual makes up precisely the shortfall. The two errors cancel by construction, not by
luck, which is why the result holds for any q, including adversarially bad ones. Your
test_disjoint_distributions_never_accept pins that limit: zero overlap, everything
rejected, output still exactly p. Spec decode degrades toward baseline speed (plus the
wasted draft cost), never to wrongness, and that graceful-degradation property is what
makes it safe to deploy aggressively.
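The algorithm and the proof above can be sketched in a few lines of NumPy. This is one possible shape, not the reference solution: the `_sketch` names and signatures are assumptions and may differ from `starter.py`. The last function evaluates the proof's decomposition exactly, so the identity can be checked without any sampling noise.

```python
import numpy as np

def accept_prob_sketch(p, q, x):
    """min(1, p[x]/q[x]); q[x] > 0 because x was drawn from q."""
    return min(1.0, p[x] / q[x])

def residual_sketch(p, q):
    """normalize(max(p - q, 0)); falls back to p when p == q (never sampled then)."""
    surplus = np.maximum(p - q, 0.0)
    total = surplus.sum()
    return surplus / total if total > 0.0 else p.copy()

def speculative_token_sketch(p, q, rng):
    """Draw a draft from q, accept or resample. Returns (token, accepted)."""
    x = int(rng.choice(len(q), p=q))
    if rng.random() < accept_prob_sketch(p, q, x):
        return x, True
    return int(rng.choice(len(p), p=residual_sketch(p, q))), False

def exact_output_distribution(p, q):
    """The proof's decomposition: accepted mass plus P(reject) * residual."""
    accepted = np.minimum(p, q)
    return accepted + (1.0 - accepted.sum()) * residual_sketch(p, q)

rng = np.random.default_rng(0)
p = np.array([0.6, 0.3, 0.1])                 # illustrative target
q = np.full(3, 1.0 / 3.0)                     # clueless uniform drafter
print(exact_output_distribution(p, q))        # matches p up to float error

# Disjoint limit: the draft is always rejected, the output is still drawn from p.
p_dis, q_dis = np.array([1.0, 0.0]), np.array([0.0, 1.0])
print(speculative_token_sketch(p_dis, q_dis, rng))
```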
Files
- `starter.py`: `accept_prob`, `residual_distribution`, `speculative_token`, `expected_acceptance_rate`. Your work.
- `solution.py`: reference.
- `test_lab.py`: the formula edges, the residual, the empirical theorem (200k draws), the identical-distribution and disjoint-distribution limits, and the overlap formula against simulation.
Run
LAB_IMPL=starter pytest phase-08-speculative-decoding/labs/lab-03-rejection-sampling -q
pytest phase-08-speculative-decoding/labs/lab-03-rejection-sampling -q # reference
What the tests prove
| Test | What it pins |
|---|---|
| `test_accept_prob_formula` | The two regimes: p ≥ q → always accept; p < q → the exact ratio |
| `test_residual_is_the_renormalized_surplus` | The fallback distribution, value by value; get this wrong and the theorem dies silently |
| `test_output_distribution_is_exactly_the_target` | The theorem, empirically: uniform drafter, skewed target, 200k draws, histogram ≈ p within 0.005. This is the test that catches the silent bug class |
| `test_identical_distributions_always_accept` | The q = p limit: overlap 1, acceptance 1. A perfect drafter is never rejected (and the p == q residual edge case stays well-defined) |
| `test_acceptance_rate_is_the_overlap` | Σ min(p,q) = 0.70 for the lab's pair, confirmed by simulating the accept branch alone |
| `test_disjoint_distributions_never_accept` | The adversarial limit: zero overlap → pure residual → still exactly p. Wrongness is impossible; only speed is at stake |
The statistical tolerance (atol=0.005 at N=200k) is calibrated, not hand-waved: the
binomial standard error at p=0.5 is √(0.25/200000) ≈ 0.0011, so 0.005 is ~4.5σ —
tight enough to catch any real implementation error, loose enough to never flake. When
you write distributional tests (and after this lab, you will), do this arithmetic.
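The tolerance arithmetic is worth keeping as a tiny helper so you can redo it whenever N or the tolerance changes. A minimal sketch, with the hypothetical name `sigma_multiple`; it uses the worst-case bin probability 0.5, so the bound is conservative for every other bin.

```python
import math

def sigma_multiple(atol, n_draws, prob=0.5):
    """Return (binomial standard error, how many standard errors atol spans)."""
    se = math.sqrt(prob * (1.0 - prob) / n_draws)
    return se, atol / se

se, k = sigma_multiple(atol=0.005, n_draws=200_000)
print(f"standard error = {se:.6f}, tolerance = {k:.2f} sigma")
```

Halving the tolerance or quartering N both halve the sigma multiple, which is the trade to check before tightening a flaky test.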
Hitchhiker's notes
- Greedy verify is this algorithm's zero-temperature limit: as temperature → 0, `p` and `q` collapse toward one-hots; `min(1, p[x]/q[x])` becomes "1 if the argmaxes match, else 0", and the residual becomes the target's argmax. Lab-01 was a special case all along. Upstream's `RejectionSampler` has the explicit greedy fast path (`rejection_sampler.py:87`) for exactly this case, because comparing argmaxes is cheaper than the full machinery.
- Multi-token drafts chain this per position: verify token 1 against `p₁`; if accepted, token 2 against `p₂` (computed with token 1 in context; the target's one batched forward scored all positions); the first rejection stops the chain and resamples from that position's residual. The i.i.d.-ish per-position acceptance is the `alpha` lab-04 models. Crucially, all `k+1` target distributions came from one forward pass, and that batching is the entire economic basis (lab-04's `cost = k·c + 1`).
- Where the probabilities come from matters:
  `p` and `q` here are post-temperature, post-top-p distributions. The verifier must apply the same sampling-parameter pipeline (Phase 0 lab-03) to both models' logits, or the ratio compares apples to oranges. Sampling-parameter mismatches between draft and target paths are a real upstream bug category; now you know what they corrupt.
- The same trick generalizes: speculative sampling is importance-sampling-flavored rejection sampling with a guaranteed-exact fallback, and variants (tree drafts with multiple candidates per position, typical acceptance in Medusa/EAGLE-2) bend the acceptance rule. The lossless variants preserve the distributional identity; typical acceptance deliberately relaxes it in exchange for a higher acceptance rate. When reading any new spec-decode paper, find its version of this lemma first (or confirm it has none); everything else is scheduling.
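The per-position chaining described in the notes above can be sketched as follows. Assumptions to flag: the name `verify_chain_sketch` is hypothetical, `p_list` holds the `k+1` target distributions from the single batched forward pass (each conditioned on the preceding drafts being accepted), `q_list` holds the `k` draft distributions, and the last target distribution is used for a bonus token when every draft is accepted. This is a toy over fixed arrays, not the upstream implementation.

```python
import numpy as np

def verify_chain_sketch(draft_tokens, p_list, q_list, rng):
    """Verify k drafted tokens left to right; stop at the first rejection."""
    out = []
    for i, x in enumerate(draft_tokens):
        p, q = p_list[i], q_list[i]
        if rng.random() < min(1.0, p[x] / q[x]):
            out.append(int(x))                 # accepted: keep going
            continue
        surplus = np.maximum(p - q, 0.0)       # rejected: resample from residual
        total = surplus.sum()
        fallback = surplus / total if total > 0.0 else p
        out.append(int(rng.choice(len(p), p=fallback)))
        return out                             # later drafts are discarded
    # All k drafts accepted: emit one extra token from the (k+1)-th target dist.
    out.append(int(rng.choice(len(p_list[-1]), p=p_list[-1])))
    return out

rng = np.random.default_rng(0)
same = np.array([0.5, 0.5])
print(verify_chain_sketch([0, 1], [same] * 3, [same] * 2, rng))   # k + 1 tokens

p_dis, q_dis = np.array([1.0, 0.0]), np.array([0.0, 1.0])
print(verify_chain_sketch([1, 1], [p_dis] * 3, [q_dis] * 2, rng))  # one token
```

Each call returns between 1 and `k+1` tokens for one target forward pass, which is where the speedup comes from.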
Going further
- Implement chained multi-token verification (`speculative_sequence(p_list, q_list, k, rng)`) and verify the joint distribution of two-token outputs matches sequential target sampling: the full lossless claim, one level up.
- Measure acceptance vs temperature: fix logits, sweep T ∈ {0.2, 0.7, 1.0, 1.5} for both models, plot `Σ min(p,q)`. When the two models agree on the top token, the overlap climbs toward 1 as T → 0, so spec decode loves low temperature; connect to lab-02's "code accepts at 80%" observation.
- Break it on purpose: skip the `min(1, ·)` clamp (accept with raw `p/q`… capped how?) or forget to renormalize the residual, and watch which test catches each. Knowing the failure signatures is half the review skill.
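A starting point for the temperature sweep, as a hedged sketch: the logits below are made up (same argmax for both models, disagreement on the runners-up), and `softmax` here is a local helper, not the lab's sampling pipeline.

```python
import numpy as np

def softmax(logits, temperature):
    """Temperature-scaled softmax with the usual max-subtraction for stability."""
    z = np.asarray(logits, dtype=float) / temperature
    z = z - z.max()
    e = np.exp(z)
    return e / e.sum()

# Hypothetical logits: both models rank token 0 first, then disagree.
target_logits = [2.0, 1.0, 0.0]
draft_logits = [2.0, 0.0, 1.0]

overlaps = {}
for T in (0.2, 0.7, 1.0, 1.5):
    p = softmax(target_logits, T)
    q = softmax(draft_logits, T)
    overlaps[T] = float(np.minimum(p, q).sum())   # acceptance rate at this T
    print(f"T={T}: overlap={overlaps[T]:.4f}")
```

For this toy pair the overlap is highest at T = 0.2 but is not monotone across the mid-range, because both distributions also flatten toward uniform as T grows. If the argmaxes disagreed, low temperature would push the overlap toward 0 instead, so plot the whole sweep rather than trusting two points.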
References
- Leviathan et al., Fast Inference from Transformers via Speculative Decoding (2022) — the theorem (their Theorem 3.1 / Appendix A): https://arxiv.org/abs/2211.17192
- Chen et al., Accelerating Large Language Model Decoding with Speculative Sampling (2023) — the same result, DeepMind flavor: https://arxiv.org/abs/2302.01318
- `upstream/vllm/v1/sample/rejection_sampler.py`: the production implementation. Find the ratio, the residual, and the greedy fast path (`:87`).
- Lab-04: what `Σ min(p,q)` is worth in milliseconds; lab-01: the zero-temperature special case you already built.