Lab 08-03 — Rejection Sampling: Lossless Speculation with Temperature [CPU-OK]

Lab-01's greedy verify had an easy life: at temperature 0 there's exactly one right token, so "accept iff the draft equals it" is obviously lossless. But production serving samples — and now the claim that speculative decoding "doesn't change the output" becomes a real theorem with a real proof obligation: the verified output must be distributed exactly according to the target model's distribution p, no matter how wrong the drafter's q is. This lab has you implement the three-line algorithm that achieves it — accept draft x with probability min(1, p[x]/q[x]), else resample from the residual normalize(max(p − q, 0)) — and then verify the theorem empirically: 200,000 draws through a deliberately clueless uniform drafter land on the target distribution to within sampling noise. This is the mathematical heart of every speculative method in vLLM, from n-gram to EAGLE.
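The three-line algorithm described above can be sketched in plain Python. This is a minimal illustration, not the lab's reference solution: the function names match the ones listed for starter.py, but the signatures, the `(token, accepted)` return value, and the choice to fall back to p when p == q are assumptions made for this sketch.

```python
import random

def accept_prob(p, q, x):
    """Probability of keeping draft token x: min(1, p[x] / q[x])."""
    # x is assumed to have been sampled from q, so q[x] > 0 here.
    return min(1.0, p[x] / q[x])

def residual_distribution(p, q):
    """normalize(max(p - q, 0)): the mass the target still needs after accepts."""
    surplus = [max(pi - qi, 0.0) for pi, qi in zip(p, q)]
    total = sum(surplus)
    if total == 0.0:
        # p == q: rejection has probability 0, so the fallback is never used.
        # Returning p is an assumption of this sketch, chosen to stay well-defined.
        return list(p)
    return [s / total for s in surplus]

def speculative_token(p, q, rng):
    """Return (token, accepted) where token is distributed according to p."""
    vocab = range(len(p))
    x = rng.choices(vocab, weights=q)[0]           # the drafter's proposal
    if rng.random() < accept_prob(p, q, x):        # accept with min(1, p/q)
        return x, True
    fallback = residual_distribution(p, q)         # else resample the residual
    return rng.choices(vocab, weights=fallback)[0], False

# Illustrative values only: a skewed target and a clueless uniform drafter.
rng = random.Random(0)
token, accepted = speculative_token([0.6, 0.3, 0.1], [1 / 3, 1 / 3, 1 / 3], rng)
```

Passing the random generator in explicitly keeps the draws reproducible, which matters once you start writing distributional tests against this code.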

Why this lab exists

"Speculative decoding is lossless" is repeated everywhere and understood almost nowhere — most explanations stop at the greedy case, leaving the sampled case as folklore. But the sampled case is where the engineering risk lives: a subtly wrong residual, a missing clamp, a normalization slip — and your serving system is quietly sampling from a distribution that is not the model's, a bug invisible to every output-equality test (each individual output is plausible!) and detectable only distributionally. The defense, which you'll build, is the statistical test: histogram many draws, compare to p. If you ever touch rejection_sampler.py upstream — and spec-decode PRs touch it constantly — this lab's test design is how you protect the change.

The second deliverable is the acceptance-rate formula Σ min(p, q) — the overlap of the two distributions (equivalently 1 − total-variation distance). It converts "is the drafter any good?" from vibes into one number per position, and it's the alpha that lab-04's economics run on. Drafter evaluation, acceptance metrics in vLLM's logs, temperature's effect on speedup — all read off this one quantity.
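The overlap formula and its link to total-variation distance can be checked in a few lines. The distributions below are illustrative values picked so that the overlap comes out to 0.70; they are not necessarily the pair the lab's tests use, and `total_variation` is a helper introduced only for this sketch.

```python
def expected_acceptance_rate(p, q):
    """Overlap of the two distributions: sum over tokens of min(p, q)."""
    return sum(min(pi, qi) for pi, qi in zip(p, q))

def total_variation(p, q):
    """Total-variation distance: half the L1 distance between p and q."""
    return 0.5 * sum(abs(pi - qi) for pi, qi in zip(p, q))

p = [0.5, 0.3, 0.1, 0.1]          # skewed target (illustrative)
q = [0.25, 0.25, 0.25, 0.25]      # uniform drafter (illustrative)

alpha = expected_acceptance_rate(p, q)   # 0.25 + 0.25 + 0.1 + 0.1
tv = total_variation(p, q)

# The two quantities are complementary, up to floating-point rounding.
assert abs(alpha - (1.0 - tv)) < 1e-12
```

A perfect drafter (q = p) gives an overlap of 1 and a disjoint drafter gives 0, so the number reads directly as "fraction of draft tokens expected to survive" at that position.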

Background: why the algorithm works

The output token's probability decomposes into "accepted draft" + "residual resample":

P(output = x) = q[x]·min(1, p[x]/q[x]) + P(reject)·residual[x]
              = min(p[x], q[x])        + (1 − Σ min(p,q)) · max(p[x]−q[x],0) / Σ max(p−q,0)

Since Σ max(p−q, 0) = 1 − Σ min(p, q) (the surplus equals the deficit — both are the TV distance), the second term simplifies to max(p[x] − q[x], 0), and:

P(output = x) = min(p[x], q[x]) + max(p[x] − q[x], 0) = p[x]      ∎

Read the proof's shape: where the draft over-serves (q > p), acceptance is throttled by exactly the ratio p/q; where it under-serves (q < p), drafts always pass and the residual makes up precisely the shortfall. The two errors cancel by construction, not by luck, which is why the result holds for any q, including adversarially bad ones. Your test_disjoint_distributions_never_accept covers that extreme: zero overlap, everything rejected, output still exactly p. Spec decode degrades to baseline speed, never to wrongness, and that graceful-degradation property is what makes it safe to deploy aggressively.
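The decomposition in the proof can also be evaluated exactly, with no sampling at all. The sketch below uses `fractions.Fraction` so the equality check is exact rather than approximate; `output_distribution` is a hypothetical helper written for this illustration, and the test distributions are made up.

```python
from fractions import Fraction as F

def output_distribution(p, q):
    """Exact P(output = x): accepted-draft mass plus rejected mass times residual."""
    # q[x] * min(1, p[x]/q[x]) simplifies to min(p[x], q[x]), valid even if q[x] == 0.
    accepted = [min(pi, qi) for pi, qi in zip(p, q)]
    p_reject = 1 - sum(accepted)
    surplus = [max(pi - qi, F(0)) for pi, qi in zip(p, q)]
    total = sum(surplus)
    if total == 0:
        # p == q: nothing is ever rejected, so the accepted mass is already p.
        return accepted
    return [a + p_reject * s / total for a, s in zip(accepted, surplus)]

target = [F(1, 2), F(3, 10), F(1, 10), F(1, 10)]

uniform_draft = [F(1, 4)] * 4
assert output_distribution(target, uniform_draft) == target

# Adversarial limit: the drafter puts all its mass where the target has none.
disjoint_target = [F(1, 2), F(1, 2), F(0), F(0)]
disjoint_draft = [F(0), F(0), F(1, 2), F(1, 2)]
assert output_distribution(disjoint_target, disjoint_draft) == disjoint_target
```

An exact check like this complements the sampled test rather than replacing it: it confirms the algebra, while the histogram test confirms that the sampling code actually implements that algebra.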

Files

  • starter.py: accept_prob, residual_distribution, speculative_token, expected_acceptance_rate. Your work.
  • solution.py: reference.
  • test_lab.py: the formula edges, the residual, the empirical theorem (200k draws), the identical-distribution and disjoint-distribution limits, and the overlap formula against simulation.

Run

LAB_IMPL=starter pytest phase-08-speculative-decoding/labs/lab-03-rejection-sampling -q
pytest phase-08-speculative-decoding/labs/lab-03-rejection-sampling -q   # reference

What the tests prove

| Test | What it pins |
| --- | --- |
| test_accept_prob_formula | The two regimes: p ≥ q → always accept; p < q → the exact ratio |
| test_residual_is_the_renormalized_surplus | The fallback distribution, value by value. Get this wrong and the theorem dies silently |
| test_output_distribution_is_exactly_the_target | The theorem, empirically: uniform drafter, skewed target, 200k draws, histogram ≈ p within 0.005. This is the test that catches the silent bug class |
| test_identical_distributions_always_accept | The q = p limit: overlap 1, acceptance 1. A perfect drafter is never rejected (and the p == q residual edge case stays well-defined) |
| test_acceptance_rate_is_the_overlap | Σ min(p,q) = 0.70 for the lab's pair, confirmed by simulating the accept branch alone |
| test_disjoint_distributions_never_accept | The adversarial limit: zero overlap → pure residual → still exactly p. Wrongness is impossible; only speed is at stake |

The statistical tolerance (atol=0.005 at N=200k) is calibrated, not hand-waved: the binomial standard error is largest at p=0.5, where it is √(0.25/200000) ≈ 0.0011, so 0.005 is at least ~4.5σ for every histogram bin. That is tight enough to catch any real implementation error and loose enough that a correct implementation essentially never flakes. When you write distributional tests (and after this lab, you will), do this arithmetic.
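The tolerance arithmetic is short enough to keep next to any distributional test you write. A minimal sketch, assuming a normal approximation to the binomial for the tail estimate; `binomial_std_error` is a helper name made up for this example.

```python
import math

def binomial_std_error(prob, n):
    """Standard error of an observed frequency when the true probability is prob."""
    return math.sqrt(prob * (1.0 - prob) / n)

n_draws = 200_000
atol = 0.005

# p = 0.5 maximizes prob * (1 - prob), so it is the worst case for any bin.
worst_case_se = binomial_std_error(0.5, n_draws)
sigmas = atol / worst_case_se

# Two-sided tail probability under a normal approximation: the rough chance
# that a single bin of a correct implementation lands outside the tolerance.
per_bin_flake_probability = math.erfc(sigmas / math.sqrt(2.0))

print(f"standard error: {worst_case_se:.4f}")
print(f"tolerance in sigmas: {sigmas:.2f}")
print(f"approx. per-bin flake probability: {per_bin_flake_probability:.1e}")
```

If you change N or the tolerance, rerun the numbers: halving the tolerance at the same N would put it near 2.2σ, which is flaky territory.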

Hitchhiker's notes

  • Greedy verify is this algorithm's zero-temperature limit: as temperature → 0, p and q collapse toward one-hots; min(1, p[x]/q[x]) becomes "1 if the argmaxes match, else 0", and the residual becomes the target's argmax. Lab-01 was a special case all along — upstream's RejectionSampler has the explicit greedy fast path (rejection_sampler.py:87) for exactly this case, because comparing argmaxes is cheaper than the full machinery.
  • Multi-token drafts chain this per position: verify token 1 against p₁; if accepted, token 2 against p₂ (computed with token 1 in context — the target's one batched forward scored all positions); first rejection stops the chain and resamples from that position's residual. The i.i.d.-ish per-position acceptance is the alpha lab-04 models. Crucially, all k+1 target distributions came from one forward pass — that batching is the entire economic basis (lab-04's cost = k·c + 1).
  • Where the probabilities come from matters: p and q here are post-temperature, post-top-p distributions — the verifier must apply the same sampling-parameter pipeline (Phase 0 lab-03) to both models' logits, or the ratio compares apples to oranges. Sampling-parameter mismatches between draft and target paths are a real upstream bug category; now you know what they corrupt.
  • The same trick generalizes — speculative sampling is importance-sampling-flavored rejection sampling with a guaranteed-exact fallback, and variants (tree drafts with multiple candidates per position, typical acceptance in Medusa/EAGLE-2) bend the acceptance rule while preserving the distributional identity. When reading any new spec-decode paper, find its version of this lemma first; everything else is scheduling.
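The first note, that greedy verify is the zero-temperature limit, can be seen numerically. In this sketch the logits are invented, the same temperature is applied to both models as the third note requires, and the draft token is taken to be the drafter's argmax (which is what sampling from q converges to as temperature → 0). `softmax` and `accept_prob_of_draft_argmax` are illustrative helpers, not upstream functions.

```python
import math

def softmax(logits, temperature):
    """Temperature-scaled softmax, shifted by the max for numerical stability."""
    scaled = [value / temperature for value in logits]
    top = max(scaled)
    exps = [math.exp(s - top) for s in scaled]
    norm = sum(exps)
    return [e / norm for e in exps]

def accept_prob_of_draft_argmax(target_logits, draft_logits, temperature):
    """min(1, p[x]/q[x]) for x = the drafter's most likely token."""
    p = softmax(target_logits, temperature)
    q = softmax(draft_logits, temperature)
    x = max(range(len(q)), key=q.__getitem__)
    return min(1.0, p[x] / q[x])

target_logits = [2.0, 1.0, 0.0]      # target prefers token 0
agreeing_draft = [1.5, 1.2, 0.1]     # drafter also prefers token 0
disagreeing_draft = [1.0, 2.0, 0.0]  # drafter prefers token 1

for temperature in (1.0, 0.3, 0.01):
    agree = accept_prob_of_draft_argmax(target_logits, agreeing_draft, temperature)
    clash = accept_prob_of_draft_argmax(target_logits, disagreeing_draft, temperature)
    print(f"T={temperature}: argmaxes match -> {agree:.3f}, differ -> {clash:.3g}")
```

As the temperature shrinks, the acceptance probability heads to 1 when the argmaxes match and to 0 when they differ, which is exactly the "accept iff the draft equals the target's token" rule from lab-01.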

Going further

  • Implement chained multi-token verification (speculative_sequence(p_list, q_list, k, rng)) and verify the joint distribution of two-token outputs matches sequential target sampling — the full lossless claim, one level up.
  • Measure acceptance vs temperature: fix logits, sweep T ∈ {0.2, 0.7, 1.0, 1.5} for both models, plot Σ min(p,q). When the two models agree on the top token, sharper distributions overlap more, so spec decode loves low temperature; connect this to lab-02's "code accepts at 80%" observation.
  • Break it on purpose: skip the min(1, ·) clamp (accept with raw p/q… capped how?) or forget to renormalize the residual, and watch which test catches each. Knowing the failure signatures is half the review skill.

References

  • Leviathan et al., Fast Inference from Transformers via Speculative Decoding (2022) — the theorem (their Theorem 3.1 / Appendix A): https://arxiv.org/abs/2211.17192
  • Chen et al., Accelerating Large Language Model Decoding with Speculative Sampling (2023) — the same result, DeepMind flavor: https://arxiv.org/abs/2302.01318
  • upstream/vllm/v1/sample/rejection_sampler.py — the production implementation: find the ratio, the residual, and the greedy fast path (:87).
  • Lab-04 — what Σ min(p,q) is worth in milliseconds; lab-01 — the zero-temperature special case you already built.