Lab 09-04 — Per-Request RNG & Batch Invariance [CPU-OK]
Here's a bug report you will eventually receive: "I set seed=7, temperature 1.0, and
I get different outputs every time. Your API is broken." The API isn't broken — but it
would be, in exactly this way, if the sampler used one shared random generator for the
whole batch. Whoever shares the batch with you consumes numbers from the shared stream
and shifts yours; your "seeded" request reproduces only when the entire fleet's traffic
reproduces. This lab builds the fix — per-request generator state — and proves the
contract that production samplers must honor: a seeded request's token stream is
identical whether it runs alone or interleaved with five neighbors. The test suite
includes the broken shared-RNG sampler as a control, so you see the failure, not just
read about it.
Contents
- Why this lab exists
- Background: randomness as private state
- Files
- Run
- What the tests prove
- Hitchhiker's notes
- Going further
- References
Why this lab exists
Determinism under batching is one of those properties that's trivial to state,
genuinely subtle to deliver, and commercially important: seeded sampling is how users
build reproducible evals, how you bisect a generation bug ("same seed, same output —
now change one thing"), and how A/B tests hold the noise still. And it's violated by
the most natural implementation — rng = np.random.default_rng(...) at sampler scope,
draw per request in batch order — which works perfectly in every single-request test
and fails the moment two users share a step. Bugs that pass single-user tests and fail
under concurrency are the defining bug class of serving systems; this lab is a clean
specimen you can build, break, and internalize in twenty minutes.
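Below is a minimal sketch of that natural implementation, assuming NumPy. SharedRNGSampler, sample_token, and run_request_a are hypothetical names chosen for illustration; they are not the lab's starter API. Flat (all-zero) logits make each token id a direct function of one uniform draw, so any shift in the stream shows up immediately.

```python
import numpy as np


def sample_token(logits, temperature, rng):
    """Hypothetical helper: temperature softmax, then one inverse-CDF draw from rng."""
    z = np.asarray(logits, dtype=np.float64) / temperature
    p = np.exp(z - z.max())
    p /= p.sum()
    u = rng.random()  # exactly one uniform draw per sampled token
    idx = int(np.searchsorted(np.cumsum(p), u, side="right"))
    return min(idx, len(p) - 1)  # guard against the CDF rounding to just under 1.0


class SharedRNGSampler:
    """The natural, broken design: one generator at sampler scope."""

    def __init__(self, seed):
        self.rng = np.random.default_rng(seed)

    def sample(self, request_id, logits, temperature):
        # request_id is ignored: every request pulls from the same stream.
        return sample_token(logits, temperature, self.rng)


def run_request_a(num_neighbors, steps=8, vocab=50):
    sampler = SharedRNGSampler(seed=7)
    flat = np.zeros(vocab)  # uniform logits: the token id tracks the uniform draw
    tokens = []
    for _ in range(steps):
        for n in range(num_neighbors):  # neighbors sample before A every step
            sampler.sample(f"neighbor-{n}", flat, 1.0)
        tokens.append(sampler.sample("A", flat, 1.0))
    return tokens


alone = run_request_a(num_neighbors=0)
crowded = run_request_a(num_neighbors=1)
print(alone == run_request_a(num_neighbors=0))  # True: alone, it reproduces
print(crowded == run_request_a(num_neighbors=0, steps=16)[1::2])  # True: A got the odd-indexed draws
```

The single-request comparison passes, and that is why this design survives single-user testing. The second comparison shows what happens when a neighbor is present: A simply receives every other number from the shared stream.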
It also completes the phase's batching story: lab-01 gave you the per-request pipeline (each request its own temperature/top-k/penalties), lab-02 showed requests sharing compute and KV, and this lab adds the last isolation boundary — randomness. The full production picture is vLLM's sampling metadata: per-request parameters, per-request generators, batched execution. Shared work, private state.
Background: randomness as private state
Three requirements, each pinned by a test:
- Reproducibility: same seed + same logits stream → same tokens, across process restarts and sampler instances. (The generator must be created from the seed, deterministically.)
- Continuity: a request's draws across its decode steps come from one continuing
stream, so create the generator once per request, not once per step. Re-seeding every
step is the sneaky variant bug: step 1 is correct, but every later step redraws the same
"random" number as step 1. (test_request_stream_is_stateful_not_reset constructs a uniform distribution where this is visible; on peaked distributions the bug hides, which is what makes it sneaky.)
- Isolation (batch invariance): request A's stream must be untouched by neighbors' draws. This is what per-request state buys; the shared-RNG control test shows the alternative failing.
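The continuity failure is easiest to see in the raw draws. Here is a small sketch that assumes NumPy's np.random.default_rng is the per-request generator. The helper to_token is hypothetical: it maps a uniform draw onto a uniform distribution over VOCAB tokens, the case where the bug becomes visible.

```python
import numpy as np

SEED, STEPS, VOCAB = 7, 5, 1000


def to_token(u, vocab=VOCAB):
    """Hypothetical: map a uniform draw to a token id under a uniform distribution."""
    return min(int(u * vocab), vocab - 1)


# Buggy: re-create the generator from the seed at every decode step.
reseeded = [to_token(np.random.default_rng(SEED).random()) for _ in range(STEPS)]

# Correct: create the generator once per request and keep drawing from it.
rng = np.random.default_rng(SEED)
continuing = [to_token(rng.random()) for _ in range(STEPS)]

print(reseeded[0] == continuing[0])  # True: step 1 is correct either way
print(len(set(reseeded)))            # 1: every later step replays step 1's token
```

On a peaked distribution, most values of u map to the same top token. The replayed draw therefore looks like ordinary confident output, and that is how the bug hides.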
Plus the greedy rule from Phase 0 lab-03, restated with a reason: temperature == 0
must touch no RNG at all — not "use a default seed," no draw — so greedy requests are
reproducible without any seed bookkeeping, and so they don't perturb anyone else's
stream either (a greedy request that consumed RNG would break a seeded neighbor's
invariance — isolation cuts both ways).
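Combining the three requirements with the greedy rule gives roughly the shape the Files section describes for starter.py. This is a hedged sketch, not the reference solution. The sample signature, the optional seed argument, and the sample_token helper are assumptions made for illustration. An unseeded request falls back to np.random.default_rng(None), which uses fresh OS entropy.

```python
import numpy as np


def sample_token(logits, temperature, rng):
    """Hypothetical helper: temperature softmax, then one inverse-CDF draw from rng."""
    z = np.asarray(logits, dtype=np.float64) / temperature
    p = np.exp(z - z.max())
    p /= p.sum()
    u = rng.random()
    idx = int(np.searchsorted(np.cumsum(p), u, side="right"))
    return min(idx, len(p) - 1)


class PerRequestSampler:
    """Sketch: private generator state per request id; greedy touches no RNG."""

    def __init__(self):
        self.generators = {}  # request_id -> np.random.Generator

    def sample(self, request_id, logits, temperature, seed=None):
        if temperature == 0:
            return int(np.argmax(logits))  # greedy fast path: no generator, no draw
        if request_id not in self.generators:
            # Created once per request (continuity), purely from the seed (reproducibility).
            self.generators[request_id] = np.random.default_rng(seed)
        return sample_token(logits, temperature, self.generators[request_id])


def run_a(num_neighbors, greedy_neighbor=False, steps=8, vocab=50):
    sampler = PerRequestSampler()
    flat = np.zeros(vocab)
    tokens = []
    for _ in range(steps):
        for n in range(num_neighbors):  # seeded neighbors sample before A
            sampler.sample(f"nb-{n}", flat, 1.0, seed=100 + n)
        if greedy_neighbor:
            sampler.sample("greedy", flat, 0.0)
        tokens.append(sampler.sample("A", flat, 1.0, seed=7))
    return tokens, sampler


alone, _ = run_a(0)
print(run_a(1)[0] == alone and run_a(5)[0] == alone)  # True: isolation
with_greedy, s = run_a(0, greedy_neighbor=True)
print(with_greedy == alone, "greedy" in s.generators)  # True False
```

The greedy neighbor never creates a generator. That behavior is the "isolation cuts both ways" half of the rule: A's stream cannot notice that the neighbor exists.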
Files
- starter.py (PerRequestSampler): a generator dict keyed by request id, a greedy fast path, one draw per call. About 15 lines. Your work.
- solution.py: reference.
- test_lab.py: reproducibility, divergence across seeds, the invariance contract, the shared-RNG failure (control), greedy's RNG-free path, and stream continuity.
Run
LAB_IMPL=starter pytest phase-09-sampling-and-decoding-algorithms/labs/lab-04-seeded-rng-batch-invariance -q
pytest phase-09-sampling-and-decoding-algorithms/labs/lab-04-seeded-rng-batch-invariance -q # reference
What the tests prove
| Test | What it pins |
|---|---|
| test_same_seed_reproduces_across_instances | Requirement 1: the stream is a pure function of the seed |
| test_different_seeds_diverge | Seeds actually matter (a sampler that ignores its seed passes test 1 vacuously; paired tests close the loophole) |
| test_batch_invariance | The contract: A's stream with 0, 1, and 5 interleaved neighbors is identical. The neighbors even sample before A each step, the worst case for a shared stream |
| test_shared_rng_breaks_batch_invariance | The control: one global generator, same scenario; neighbors shift A's tokens. The bug, demonstrated rather than asserted |
| test_greedy_ignores_seed_and_rng_state | Temperature 0 → argmax, no RNG touched, any seed |
| test_request_stream_is_stateful_not_reset | Requirement 2: two draws match a reference generator's first two draws, not the first draw twice |
The test-design pattern is worth keeping: every isolation claim ships with a broken control. "X holds" plus "here is the natural implementation where X fails" teaches reviewers what the protective code protects against — and stops the next refactorer from "simplifying" the generator dict away.
Hitchhiker's notes
- Where this lives upstream: vLLM keeps per-request torch.Generator objects in its sampling state (seeded requests get their own; see the generator plumbing in upstream/vllm/v1/worker/gpu_input_batch.py and the sampler). The batched GPU sampler does the draws vectorized, but seeded rows use their private generator state. That is the exact structure of your dict, tensor-shaped.
- What batch invariance does not promise: bitwise-identical logits. Different batch compositions change kernel tiling and reduction order (the recurring last-ulp story from Phases 3/4/6), so two near-tied tokens can flip even with perfect RNG isolation. True end-to-end batch-invariant inference also requires batch-invariant kernels, a real and recent line of engineering work (deterministic-inference modes). RNG isolation is the necessary first floor, not the whole building. Know which layer a nondeterminism report belongs to before debugging it.
- Cleanup is part of the contract: request ids recycle; a production sampler must drop a request's generator when it finishes (your dict grows forever — fine for a lab, a leak in a server). Per-request anything implies a lifecycle hook — tie it to Phase 1's reaping path mentally.
- Why not one generator seeded per (request, step)? It "fixes" continuity bugs by construction but costs a generator init per token and — worse — makes the stream depend on step numbering, which shifts under speculative decoding (Phase 8: a cycle emits several tokens). Stream-per-request is the design that survives feature composition; most alternatives quietly don't.
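The step-numbering argument can be made concrete with a small sketch over raw uniform draws, assuming NumPy. seeded_per_step is the hypothetical alternative, keyed on (seed, step); np.random.default_rng accepts a sequence of ints as a seed. Group sizes stand in for the number of tokens each step or speculative cycle emits. The same six tokens, grouped differently, should see the same six draws.

```python
import numpy as np

SEED, TOKENS = 7, 6


def stream_per_request(group_sizes):
    """One generator per request; draws continue across steps and cycles."""
    rng = np.random.default_rng(SEED)
    return [rng.random() for size in group_sizes for _ in range(size)]


def seeded_per_step(group_sizes):
    """Hypothetical alternative: a fresh generator keyed on (seed, step)."""
    draws = []
    for step, size in enumerate(group_sizes):
        rng = np.random.default_rng((SEED, step))  # one generator init per step
        draws.extend(rng.random() for _ in range(size))
    return draws


plain = [1] * TOKENS  # ordinary decoding: one token per step
spec = [3, 3]         # speculative decoding: several tokens per cycle

print(stream_per_request(plain) == stream_per_request(spec))  # True
print(seeded_per_step(plain)[0] == seeded_per_step(spec)[0])  # True: only draw 1 agrees
```

Under per-step seeding, the later draws depend on how tokens were grouped into steps. The continuing stream depends only on how many draws came before.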
Going further
- Add finish(request_id) and a test that a recycled id with a new seed starts a fresh stream (the leak-plus-collision bug, both halves).
- Vectorize: sample_batch(ids, logits_matrix, temps, seeds) doing one softmax over the batch but per-row draws from per-row generators, which is the actual shape of the GPU sampler. Verify that batch invariance still holds (it must: that's the point of the structure).
- Compose with Phase 8: simulate a speculative cycle (k draft draws + residual draws from lab 08-03) using the request's generator, and check if a request's output is invariant to enabling speculation, given the same accepted tokens. (It isn't, in general: spec decode consumes RNG differently. Production systems accept this; knowing why is the exercise.)
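Below is one possible shape for the Vectorize exercise, sketched with NumPy under these assumptions: BatchSampler is a hypothetical class, seeds holds one seed per row (ignored for greedy rows), and the per-row draws are gathered in a loop before a single vectorized inverse-CDF step. A GPU sampler would do the same thing with tensors and torch.Generator objects. This sketch does not attempt that.

```python
import numpy as np


class BatchSampler:
    """Hypothetical sketch: batched softmax, per-row private draws."""

    def __init__(self):
        self.generators = {}  # request_id -> np.random.Generator

    def sample_batch(self, ids, logits_matrix, temps, seeds):
        logits = np.asarray(logits_matrix, dtype=np.float64)
        temps = np.asarray(temps, dtype=np.float64)
        greedy = temps == 0
        # One softmax over the whole batch; greedy rows divide by 1.0 and are ignored below.
        z = logits / np.where(greedy, 1.0, temps)[:, None]
        p = np.exp(z - z.max(axis=1, keepdims=True))
        p /= p.sum(axis=1, keepdims=True)
        cdf = np.cumsum(p, axis=1)
        # Per-row draws: each sampled row consumes exactly one number from its own stream.
        u = np.zeros(len(ids))
        for row, (rid, seed) in enumerate(zip(ids, seeds)):
            if greedy[row]:
                continue  # greedy rows never create or touch a generator
            if rid not in self.generators:
                self.generators[rid] = np.random.default_rng(seed)
            u[row] = self.generators[rid].random()
        # Vectorized inverse CDF: count of CDF entries <= u, clamped to the last token.
        sampled = np.minimum((cdf <= u[:, None]).sum(axis=1), logits.shape[1] - 1)
        return np.where(greedy, np.argmax(logits, axis=1), sampled).tolist()


def run_a(with_neighbors, steps=8, vocab=50):
    sampler = BatchSampler()
    tokens = []
    for _ in range(steps):
        if with_neighbors:
            ids, temps, seeds = ["nb", "greedy", "A"], [1.0, 0.0, 1.0], [3, None, 7]
        else:
            ids, temps, seeds = ["A"], [1.0], [7]
        out = sampler.sample_batch(ids, np.zeros((len(ids), vocab)), temps, seeds)
        tokens.append(out[-1])  # A is the last row in both layouts
    return tokens


print(run_a(with_neighbors=False) == run_a(with_neighbors=True))  # True
```

The invariance holds because the batch shares only the arithmetic, while each row's randomness comes from its own generator. The test uses flat logits, so A's row is bitwise identical in every batch layout. The kernel-level caveat from the Hitchhiker's notes still applies to real logits.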
References
- upstream/vllm/v1/worker/gpu_input_batch.py: per-request generator state in the input batch (search for generator).
- upstream/vllm/v1/sample/sampler.py: where seeded rows meet the batched sampler.
- vLLM docs, Sampling Parameters (the seed field's contract): https://docs.vllm.ai/en/latest/api/inference_params.html
- Thinking Machines, Defeating Nondeterminism in LLM Inference (2025), the kernel layer of this problem (batch-invariant kernels), for the full picture: https://thinkingmachines.ai/blog/defeating-nondeterminism-in-llm-inference/
- Phase 0 lab-03 — the greedy fast path's origin; lab-01 — the per-request pipeline this lab adds state isolation to.