Phase 09 — Deep Dive: the batched sampler
Paths relative to upstream/ at v0.22.1 @ 0decac0.
- vllm/v1/sample/sampler.py: the batched Sampler (orchestrates the pipeline)
- vllm/v1/sample/ops/topk_topp_sampler.py: vectorized top-k/top-p
- vllm/v1/sample/ops/penalties.py: repetition/frequency/presence penalties
- vllm/v1/sample/ops/bad_words.py: banned sequences
- vllm/v1/sample/logits_processor/: the pluggable pre-sampling hook (builtin/interface/state)
- vllm/v1/sample/metadata.py: SamplingMetadata, per-request params packed into tensors
- vllm/sampling_params.py: SamplingParams (the user-facing knobs)
Contents
- 1. SamplingParams — the knobs
- 2. Sampler.forward — the pipeline
- 3. Vectorized top-k/top-p
- 4. Penalties
- 5. Logits processors — the hook everything uses
- 6. Parallel sampling
- Reading checklist
1. SamplingParams — the knobs
vllm/sampling_params.py:168 defines class SamplingParams. Its fields include temperature (:205, default 1.0), top_p (:209), top_k, min_p, the repetition/frequency/presence penalties, n (parallel samples), seed, logit_bias, max_tokens and the stop conditions. The SamplingParams in mini_vllm/sampler.py is a faithful subset (temperature/top_k/top_p/seed/max_tokens/ignore_eos).
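A minimal sketch of what such a subset can look like. It is a hypothetical dataclass, not the mini_vllm or vLLM source. The temperature default of 1.0 comes from the text above. The other defaults, the "k <= 0 disables top-k" convention, the `is_greedy` helper and the 1e-5 greedy cutoff are all assumptions made for illustration.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class SamplingParams:
    """Hypothetical mini_vllm-style subset of vLLM's SamplingParams."""

    temperature: float = 1.0       # default 1.0, matching the upstream default
    top_k: int = 0                 # assumption: k <= 0 means "top-k disabled"
    top_p: float = 1.0             # 1.0 keeps the whole distribution
    seed: Optional[int] = None     # per-request RNG seed for reproducibility
    max_tokens: int = 16           # assumed default
    ignore_eos: bool = False       # keep generating after EOS

    def __post_init__(self) -> None:
        # Validate eagerly so the batched sampler never sees bad values.
        if self.temperature < 0.0:
            raise ValueError("temperature must be >= 0")
        if not 0.0 < self.top_p <= 1.0:
            raise ValueError("top_p must be in (0, 1]")
        if self.max_tokens < 1:
            raise ValueError("max_tokens must be >= 1")

    @property
    def is_greedy(self) -> bool:
        # Assumed convention: a (near-)zero temperature means "take the argmax".
        return self.temperature < 1e-5
```

For example, `SamplingParams(temperature=0.0).is_greedy` is True, and `SamplingParams(top_p=0.0)` raises `ValueError`.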
2. Sampler.forward — the pipeline
vllm/v1/sample/sampler.py:20 defines class Sampler, and :67 is def forward. Read it; the order is the guide's pipeline:
- logits processors / penalties edit the logits (repetition penalty, bad words, logit bias, and the structured-output grammar mask from Phase 12).
- apply_temperature (:223) divides by the per-request temperature.
- top-k / top-p truncation (ops/topk_topp_sampler.py).
- sample (:238) draws the token (argmax for greedy rows, multinomial for the rest).
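The order above can be sketched with numpy on a (batch, vocab) logits array. This is a simplified illustration, not vLLM's code, and every function name in it is hypothetical. Processors are plain callables here. Top-p is left out (it is covered in the next section). The random draw uses an exponential race, which is one branch-free way to sample a categorical distribution and is swapped in for illustration. The greedy cutoff `EPS` is an assumption.

```python
import numpy as np

EPS = 1e-5  # assumed cutoff: temperatures below this count as greedy


def apply_temperature(logits, temperature):
    # Greedy rows divide by 1.0 instead of ~0 so they stay finite;
    # their token comes from argmax anyway.
    t = np.where(temperature < EPS, 1.0, temperature)
    return logits / t[:, None]


def apply_top_k(logits, k):
    # Each row keeps values >= its own k-th largest value (ties may keep more).
    vocab = logits.shape[1]
    kth = np.sort(logits, axis=-1)[np.arange(len(k)), vocab - k]
    return np.where(logits < kth[:, None], -np.inf, logits)


def sample(logits, temperature, rng):
    greedy = logits.argmax(axis=-1)
    probs = np.exp(logits - logits.max(axis=-1, keepdims=True))
    probs /= probs.sum(axis=-1, keepdims=True)
    # Exponential race: argmax(p / q) with q ~ Exp(1) picks token i with
    # probability p_i, for every row at once.
    q = rng.exponential(size=probs.shape)
    drawn = (probs / q).argmax(axis=-1)
    # Both results exist for every row; select per row by temperature.
    return np.where(temperature < EPS, greedy, drawn)


def sampler_forward(logits, processors, temperature, k, rng):
    for proc in processors:                          # 1. processors / penalties
        logits = proc(logits)
    logits = apply_temperature(logits, temperature)  # 2. temperature
    logits = apply_top_k(logits, k)                  # 3. truncation (top-p omitted)
    return sample(logits, temperature, rng)          # 4. draw the token
```

Every stage takes and returns whole-batch arrays. Reordering the pipeline is therefore just a matter of reordering these calls, and no stage inspects individual requests.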
The crucial detail: every step operates on the whole batch at once, with per-request params read from SamplingMetadata (metadata.py), i.e. tensors aligned row-for-row with the batch. Greedy and random requests coexist in one call. Greedy rows are identified by their (near-)zero temperature and take the argmax result, while the other rows take the sampled result. There is no Python per-request loop on the hot path; that is the systems win.
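The packing idea can be sketched as follows. `BatchMeta` and `pack` are hypothetical stand-ins, only loosely modelled on SamplingMetadata, and the flag names are assumptions. The key trick is that a disabled knob is encoded as a neutral value (top_k = vocab_size, top_p = 1.0). Every row can then run the same masked operations without a per-row branch.

```python
from dataclasses import dataclass

import numpy as np


@dataclass
class BatchMeta:
    """Hypothetical stand-in for SamplingMetadata: one array entry per batch row."""

    temperature: np.ndarray  # (B,) float32
    top_k: np.ndarray        # (B,) int64; vocab_size means "top-k disabled"
    top_p: np.ndarray        # (B,) float32; 1.0 means "top-p disabled"
    all_greedy: bool         # lets the sampler skip the random path entirely
    all_random: bool         # lets the sampler skip the argmax path entirely


def pack(requests, vocab_size, eps=1e-5):
    """requests: list of dicts with 'temperature', 'top_k', 'top_p' keys."""
    temperature = np.array([r["temperature"] for r in requests], dtype=np.float32)
    # Neutral encodings: k <= 0 becomes k = vocab_size (keeps every token).
    top_k = np.array(
        [r["top_k"] if r["top_k"] > 0 else vocab_size for r in requests],
        dtype=np.int64,
    )
    top_p = np.array([r["top_p"] for r in requests], dtype=np.float32)
    greedy = temperature < eps
    return BatchMeta(temperature, top_k, top_p,
                     bool(greedy.all()), bool((~greedy).all()))
```

The list comprehensions run only at packing time. The per-step sampling math afterwards sees nothing but arrays.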
3. Vectorized top-k/top-p
vllm/v1/sample/ops/topk_topp_sampler.py applies top-k and top-p across the batch with masked sorts and cumulative sums, each row using its own k and p. (There's also a Triton variant, topk_topp_triton.py, for speed.) Your mini_vllm/sampler.py _apply_top_k/_apply_top_p handle a single row. The real challenge is doing it for, say, 256 rows with 256 different (k, p) pairs at once, without branching.
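One way to do the sort/cumsum approach in numpy is sketched below; it is not the vLLM source. `apply_top_k_top_p` is a hypothetical name. The sketch assumes `k` lies in [1, vocab] (use vocab to disable top-k) and `p` lies in (0, 1] (use 1.0 to disable top-p). Each row is sorted once. The per-row k and p thresholds then become elementwise comparisons, and the result is scattered back to vocabulary order.

```python
import numpy as np


def apply_top_k_top_p(logits, k, p):
    """Mask logits outside each row's top-k / top-p set, with no per-row branch.

    logits: (B, V) float; k: (B,) int in [1, V]; p: (B,) float in (0, 1].
    """
    B, V = logits.shape
    # Sort each row ascending; keep the permutation to scatter back later.
    order = np.argsort(logits, axis=-1, kind="stable")
    srt = np.take_along_axis(logits, order, axis=-1)

    # Top-k: the k-th largest value of each row sits at column V - k.
    kth = np.take_along_axis(srt, (V - k)[:, None], axis=-1)
    srt = np.where(srt < kth, -np.inf, srt)

    # Top-p: softmax of the sorted row, then cumulative mass from the
    # smallest token upward. Tokens whose cumulative mass stays <= 1 - p
    # form the low-probability tail and are dropped.
    probs = np.exp(srt - srt[:, -1:])  # the row max is the last column
    probs /= probs.sum(axis=-1, keepdims=True)
    drop = np.cumsum(probs, axis=-1) <= (1.0 - p)[:, None]
    drop[:, -1] = False  # always keep the most likely token
    srt = np.where(drop, -np.inf, srt)

    # Scatter the masked values back to the original vocabulary order.
    out = np.empty_like(srt)
    np.put_along_axis(out, order, srt, axis=-1)
    return out
```

Rows differ only in the values of `k` and `p`, never in control flow. That uniformity is what makes the kernel batchable.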
4. Penalties
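vllm/v1/sample/ops/penalties.py adjusts the logits of tokens already seen. The repetition penalty is multiplicative: positive logits are divided by it and negative ones multiplied by it, for any token that appeared in the prompt or the output. Frequency and presence penalties are subtracted from output tokens' logits. The frequency penalty is scaled by how often the token was generated; the presence penalty is applied once if the token was generated at all. This needs per-request prompt and output token histories, threaded through SamplingMetadata.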
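A minimal batched sketch of those three penalties follows, with hypothetical function names. It assumes ragged token histories are padded with an out-of-vocabulary sentinel equal to `vocab_size`. Counting into a vocab_size + 1 wide array and then dropping the last column makes the padding disappear without any per-row branching.

```python
import numpy as np


def token_counts(padded_ids, vocab_size):
    """Per-row token counts. padded_ids: (B, L) int, padded with vocab_size."""
    B, L = padded_ids.shape
    counts = np.zeros((B, vocab_size + 1), dtype=np.int64)
    rows = np.repeat(np.arange(B), L)  # row index for each flattened id
    np.add.at(counts, (rows, padded_ids.ravel()), 1)
    return counts[:, :vocab_size]      # drop the sentinel column


def apply_penalties(logits, prompt_ids, output_ids, rep, freq, pres):
    """logits: (B, V); rep/freq/pres: (B,) per-request penalty values."""
    B, V = logits.shape
    out_counts = token_counts(output_ids, V)
    seen = (token_counts(prompt_ids, V) > 0) | (out_counts > 0)
    # Repetition: push every seen token's logit toward "less likely".
    r = rep[:, None]
    logits = np.where(seen, np.where(logits > 0, logits / r, logits * r), logits)
    # Frequency scales with the output count; presence is a flat one-off.
    logits = logits - freq[:, None] * out_counts
    logits = logits - pres[:, None] * (out_counts > 0)
    return logits
```

Neutral values (rep = 1.0, freq = 0.0, pres = 0.0) leave a row unchanged. A batch can therefore mix penalized and unpenalized requests.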
5. Logits processors — the hook everything uses
vllm/v1/sample/logits_processor/:
- interface.py: the LogitsProcessor contract (transform logits in place, given state).
- builtin.py: the built-in processors (min-p, logit bias, etc.).
- state.py: per-request state management across steps.
This is the seam structured output (Phase 12) plugs into: a grammar produces a per-step bitmask of
allowed tokens, applied as a logits processor that sets illegal tokens to -inf before sampling.
Penalties, bias, and grammar masks all compose at this one well-defined point.
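The composition point can be sketched with a tiny, hypothetical contract: each processor maps (B, V) logits to (B, V) logits. The class names and the `apply(logits)` signature are illustrative only and much simpler than interface.py. The grammar bitmask is assumed to be packed 32 tokens per int32 word, with bit t % 32 of word t // 32 set when token t is allowed.

```python
from typing import Protocol

import numpy as np


class LogitsProcessor(Protocol):
    """Hypothetical, simplified contract: edit a (B, V) logits array."""

    def apply(self, logits: np.ndarray) -> np.ndarray: ...


class LogitBias:
    def __init__(self, bias: np.ndarray):  # (B, V), mostly zeros
        self.bias = bias

    def apply(self, logits: np.ndarray) -> np.ndarray:
        return logits + self.bias


class GrammarMask:
    """Allowed-token bitmask packed 32 tokens per int32 word (assumed layout)."""

    def __init__(self, packed: np.ndarray):  # (B, ceil(V / 32)) int32
        self.packed = packed

    def apply(self, logits: np.ndarray) -> np.ndarray:
        tok = np.arange(logits.shape[1])
        words = self.packed[:, tok // 32]                  # (B, V) word per token
        allowed = ((words >> (tok % 32)) & 1).astype(bool)
        return np.where(allowed, logits, -np.inf)          # illegal -> -inf


def run_processors(logits: np.ndarray, processors) -> np.ndarray:
    # Every processor sees the whole batch; order is the only coupling.
    for proc in processors:
        logits = proc.apply(logits)
    return logits
```

A grammar mask becomes just one more entry in the list, applied alongside bias and penalties before sampling.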
6. Parallel sampling
vllm/v1/engine/parallel_sampling.py manages n > 1. It expands one request into n child sequences that share the prompt's KV (prefix caching, Phase 2/3) and diverge after the first sampled token. Beam search is handled separately, because unlike plain sampling it changes the active set each step.
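A toy sketch of the fan-out and of why it is cheap under prefix caching follows. All names are hypothetical. The child ID format and the "child seed = parent seed + index" rule are assumptions made so that children stay reproducible but distinct. The block hashing is a simplified chained hash over full prompt blocks, and the partial last block is left out.

```python
from dataclasses import dataclass, replace
from typing import Optional

BLOCK_SIZE = 4  # tokens per KV block (illustrative)


@dataclass(frozen=True)
class Params:
    n: int = 1
    seed: Optional[int] = None


@dataclass
class Request:
    request_id: str
    prompt: tuple
    params: Params


def fan_out(parent: Request) -> list:
    """Expand an n > 1 request into n independent n = 1 children."""
    children = []
    for i in range(parent.params.n):
        # Assumed rule: offset the seed so children sample differently.
        seed = None if parent.params.seed is None else parent.params.seed + i
        children.append(Request(f"{i}_{parent.request_id}", parent.prompt,
                                replace(parent.params, n=1, seed=seed)))
    return children


def block_hashes(tokens) -> list:
    """Chain-hash each full block so a hash identifies the whole prefix."""
    hashes, prev = [], None
    full = len(tokens) - len(tokens) % BLOCK_SIZE
    for start in range(0, full, BLOCK_SIZE):
        prev = hash((prev, tuple(tokens[start:start + BLOCK_SIZE])))
        hashes.append(prev)
    return hashes


def allocate(requests):
    """Toy prefix cache: each full prompt block is computed once, then reused."""
    cache, tables = {}, {}
    for req in requests:
        table = []
        for h in block_hashes(req.prompt):
            if h not in cache:
                cache[h] = len(cache)  # new physical block => one computation
            table.append(cache[h])
        tables[req.request_id] = table
    return len(cache), tables
```

For a 10-token prompt with n = 3, the three children map to the same two physical blocks. So only 2 blocks are computed rather than 6, and the children diverge only in the tokens they sample afterwards.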
Reading checklist
- Sampler.forward: recite the pipeline order.
- Where do per-request params live, and why are they packed into tensors (not a Python loop)?
- topk_topp_sampler.py: how is heterogeneous-batch top-p done branch-free?
- The LogitsProcessor interface: how does Phase 12's grammar mask reuse it?
- parallel_sampling.py: how does n > 1 reuse prefix caching?
Now build it: 02-mini-build.md, then the labs.