Phase 09 — Deep Dive: the batched sampler

Paths relative to upstream/ at v0.22.1 @ 0decac0.

vllm/v1/sample/sampler.py                  the batched Sampler (orchestrates the pipeline)
vllm/v1/sample/ops/topk_topp_sampler.py    vectorized top-k/top-p
vllm/v1/sample/ops/penalties.py            repetition/frequency/presence penalties
vllm/v1/sample/ops/bad_words.py            banned sequences
vllm/v1/sample/logits_processor/           the pluggable pre-sampling hook (builtin/interface/state)
vllm/v1/sample/metadata.py                 SamplingMetadata: per-request params packed into tensors
vllm/sampling_params.py                    SamplingParams (the user-facing knobs)

1. SamplingParams — the knobs

vllm/sampling_params.py:168 class SamplingParams. Fields: temperature (:205, default 1.0), top_p (:209), top_k, min_p, penalties, n (parallel samples), seed, logit_bias, max_tokens, stop conditions. mini_vllm/sampler.py:SamplingParams is a faithful subset (temperature/top_k/top_p/seed/max_tokens/ignore_eos).
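To make the subset concrete, here is a minimal sketch of what such a parameter object could look like. The class name MiniSamplingParams, the validation rules, and every default other than temperature=1.0 are assumptions for illustration; they are not copied from vLLM or from mini_vllm.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class MiniSamplingParams:
    """Hypothetical subset of the user-facing knobs (not vLLM's class)."""

    temperature: float = 1.0    # default stated above; 0.0 means greedy in this sketch
    top_k: int = 0              # assumption: 0 disables top-k
    top_p: float = 1.0          # assumption: 1.0 disables nucleus truncation
    seed: Optional[int] = None
    max_tokens: int = 16        # assumption, chosen only for the example
    ignore_eos: bool = False

    def __post_init__(self) -> None:
        # Fail early on values that would make the sampling math meaningless.
        if self.temperature < 0.0:
            raise ValueError("temperature must be >= 0")
        if not 0.0 < self.top_p <= 1.0:
            raise ValueError("top_p must be in (0, 1]")
        if self.top_k < 0:
            raise ValueError("top_k must be >= 0")
        if self.max_tokens < 1:
            raise ValueError("max_tokens must be >= 1")

    @property
    def is_greedy(self) -> bool:
        return self.temperature == 0.0
```

Validating at construction time keeps bad values out of the batched tensors, where a single invalid row would be much harder to trace.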

2. Sampler.forward — the pipeline

vllm/v1/sample/sampler.py:20 class Sampler, :67 def forward. Read it; the order is the guide's pipeline:

  1. logits processors / penalties edit the logits (repetition penalty, bad-words, logit bias, and the structured-output grammar mask — Phase 12).
  2. apply_temperature (:223) divides by per-request temperature.
  3. top-k / top-p truncation (ops/topk_topp_sampler.py).
  4. sample (:238) draws the token (argmax for greedy rows, multinomial for the rest).
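The four steps above can be sketched in NumPy as follows. This is a toy under stated assumptions, not the code in sampler.py: the function name sample_batch, the use of an additive bias array to stand in for step 1, and the inverse-CDF draw in step 4 are all illustrative choices, and step 3 is left out to keep the sketch short.

```python
import numpy as np


def sample_batch(logits, temperature, bias, rng):
    """Toy pipeline over a [batch, vocab] array. All names are hypothetical."""
    # 1. Edit the logits (a bias of -inf acts as a mask).
    logits = logits.astype(np.float64) + bias
    greedy = temperature == 0.0                    # per-row boolean mask

    # 2. Temperature: substitute 1.0 on greedy rows to avoid dividing by zero.
    safe_t = np.where(greedy, 1.0, temperature)
    scaled = logits / safe_t[:, None]

    # 3. Top-k / top-p truncation would go here (omitted in this sketch).

    # 4. Sample: softmax weights, then an inverse-CDF draw per row.
    weights = np.exp(scaled - scaled.max(axis=1, keepdims=True))
    cdf = weights.cumsum(axis=1)
    cdf /= cdf[:, -1:]                             # last entry becomes exactly 1.0
    u = rng.random(len(cdf))                       # one uniform draw per row
    random_ids = (cdf <= u[:, None]).sum(axis=1)   # first index with cdf > u

    # Greedy rows take the argmax; the other rows take the random draw.
    return np.where(greedy, logits.argmax(axis=1), random_ids)
```

Both the argmax and the random draw are computed for every row, and np.where selects between them. That costs some redundant work but keeps the whole step free of per-row branching.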

The crucial detail: every step operates on the whole batch at once, with per-request params read from SamplingMetadata (metadata.py), whose tensors are aligned to the batch dimension. Greedy and random requests coexist in one call; greedy rows (temperature 0) take the argmax result instead of a random draw. There is no per-request Python loop on the hot path, and that is the systems win.
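As a rough illustration of "per-request params packed into tensors", the sketch below gathers a list of request records into batch-aligned arrays. ToyRequest, ToyMetadata, and pack are hypothetical names, NumPy arrays stand in for tensors, and the real SamplingMetadata carries more fields than the three shown here.

```python
from dataclasses import dataclass

import numpy as np


@dataclass
class ToyRequest:
    """Hypothetical per-request record (not a vLLM class)."""
    temperature: float
    top_k: int
    top_p: float


@dataclass
class ToyMetadata:
    """Hypothetical stand-in for SamplingMetadata: one array entry per batch row."""
    temperature: np.ndarray
    top_k: np.ndarray
    top_p: np.ndarray
    all_greedy: bool


def pack(requests):
    # The only per-request Python loop runs here, when the batch is assembled.
    temperature = np.array([r.temperature for r in requests], dtype=np.float32)
    top_k = np.array([r.top_k for r in requests], dtype=np.int64)
    top_p = np.array([r.top_p for r in requests], dtype=np.float32)
    all_greedy = bool((temperature == 0.0).all())
    return ToyMetadata(temperature, top_k, top_p, all_greedy)
```

Row i of every array lines up with row i of the logits, so each later step becomes a single broadcast operation over the batch. A summary flag such as all_greedy lets a sampler skip the random path entirely when no row needs it.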

3. Vectorized top-k/top-p

vllm/v1/sample/ops/topk_topp_sampler.py applies top-k and top-p across the batch with masked sorts and cumulative sums, each row using its own k and p. (There's also a Triton variant, topk_topp_triton.py, for speed.) Your mini_vllm/sampler.py _apply_top_k/_apply_top_p do the single-row version; the real challenge is doing it for a batch of, say, 256 different (k, p) pairs at once without branching.
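One way to do this branch-free is sketched below in NumPy: sort each row once, build a boolean keep-mask from rank and cumulative probability, then scatter the masked values back to their original positions. It is a sketch under assumptions, not a copy of topk_topp_sampler.py; the function name and the convention that k equal to the vocab size disables top-k are choices made for the example.

```python
import numpy as np


def apply_top_k_top_p(logits, k, p):
    """Mask logits outside each row's top-k and top-p sets with -inf.

    logits: float [batch, vocab]; k: int [batch], k >= 1; p: float [batch] in (0, 1].
    """
    # Sort each row in descending order and remember where each value came from.
    order = np.argsort(-logits, axis=1, kind="stable")
    sorted_logits = np.take_along_axis(logits, order, axis=1)

    # Top-k: keep sorted positions 0..k-1, with a different k per row.
    ranks = np.arange(logits.shape[1])[None, :]
    keep = ranks < k[:, None]
    sorted_logits = np.where(keep, sorted_logits, -np.inf)

    # Top-p: softmax over the survivors, then keep the shortest prefix whose
    # mass reaches p. A token stays if the mass before it is still below p,
    # so the most likely token is always kept.
    probs = np.exp(sorted_logits - sorted_logits[:, :1])
    probs /= probs.sum(axis=1, keepdims=True)
    mass_before = probs.cumsum(axis=1) - probs
    keep &= mass_before < p[:, None]
    sorted_logits = np.where(keep, sorted_logits, -np.inf)

    # Scatter the masked values back to their original vocabulary positions.
    out = np.empty_like(sorted_logits)
    np.put_along_axis(out, order, sorted_logits, axis=1)
    return out
```

For example, with logits [[2, 1, 0, -1], [0, 3, 1, 2]], k = [2, 4] and p = [1.0, 0.5], the first row keeps its two largest logits and the second row keeps only the token with logit 3, because that token alone already carries more than half of the probability mass.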

4. Penalties

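vllm/v1/sample/ops/penalties.py: given the tokens generated so far (and the prompt), adjust the corresponding logits. Frequency and presence penalties are subtracted; the repetition penalty is multiplicative (positive logits are divided by it, negative logits are multiplied by it). This needs per-request output token histories, threaded through SamplingMetadata.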
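A small NumPy sketch of the three penalties follows. The function names and argument layout are hypothetical, and token histories are plain Python lists here, whereas the real code receives them as tensors via SamplingMetadata. The sketch assumes the common convention that the repetition penalty covers prompt and output tokens while frequency and presence penalties look at output tokens only.

```python
import numpy as np


def _counts(ids_per_row, vocab):
    # Per-row histogram of token ids, shape [batch, vocab].
    return np.stack([
        np.bincount(np.asarray(ids, dtype=np.int64), minlength=vocab)
        for ids in ids_per_row
    ])


def apply_penalties(logits, prompt_ids, output_ids, rep, freq, pres):
    """Toy batched penalties. logits: [batch, vocab]; rep/freq/pres: [batch]."""
    vocab = logits.shape[1]
    out_counts = _counts(output_ids, vocab)
    out_seen = out_counts > 0
    prompt_seen = _counts(prompt_ids, vocab) > 0

    # Repetition: multiplicative, and always lowers a seen token's logit
    # when the penalty is greater than 1.
    r = np.where(prompt_seen | out_seen, rep[:, None], 1.0)
    logits = np.where(logits > 0, logits / r, logits * r)

    # Frequency: subtract once per occurrence. Presence: subtract once if seen.
    logits = logits - freq[:, None] * out_counts
    logits = logits - pres[:, None] * out_seen
    return logits
```

Dividing positive logits and multiplying negative ones is what makes the repetition penalty push in the same direction regardless of sign; a plain division would raise the score of a token whose logit is negative.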

5. Logits processors — the hook everything uses

vllm/v1/sample/logits_processor/:

  • interface.py — the LogitsProcessor contract (transform logits in place given state).
  • builtin.py — the built-in processors (min-p, logit bias, etc.).
  • state.py — per-request state management across steps.

This is the seam structured output (Phase 12) plugs into: a grammar produces a per-step bitmask of allowed tokens, applied as a logits processor that sets illegal tokens to -inf before sampling. Penalties, bias, and grammar masks all compose at this one well-defined point.
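The composition idea can be sketched with a deliberately simplified contract. ToyLogitsProcessor, LogitBias, AllowedTokenMask, and run_processors are hypothetical names; the real interface in interface.py has its own method names and state handling, which this sketch does not attempt to reproduce.

```python
import numpy as np


class ToyLogitsProcessor:
    """Hypothetical contract: edit [batch, vocab] logits in place and return them."""

    def apply(self, logits):
        raise NotImplementedError


class LogitBias(ToyLogitsProcessor):
    def __init__(self, bias):
        self.bias = bias                  # {row: {token_id: delta}}

    def apply(self, logits):
        for row, deltas in self.bias.items():
            for token_id, delta in deltas.items():
                logits[row, token_id] += delta
        return logits


class AllowedTokenMask(ToyLogitsProcessor):
    def __init__(self, allowed):
        self.allowed = allowed            # bool [batch, vocab], True = legal token

    def apply(self, logits):
        logits[~self.allowed] = -np.inf   # illegal tokens can never be sampled
        return logits


def run_processors(logits, processors):
    # Every processor sees the output of the previous one.
    for processor in processors:
        logits = processor.apply(logits)
    return logits
```

Because each processor only maps logits to logits, a grammar mask needs no special case in the sampler: it is one more entry in the list, and a token it sets to -inf has zero probability after the softmax, so no later temperature or truncation step can bring it back.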

6. Parallel sampling

vllm/v1/engine/parallel_sampling.py manages n>1: it expands one request into n child sequences that share the prompt's KV (prefix caching, Phase 2/3) and can diverge from the first sampled token onward. Beam search has its own handling (it changes the active set each step, unlike plain sampling).
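The expansion step can be pictured with a short sketch. ParentRequest, expand, the child id format, and the per-child seed offset are all assumptions made for illustration; they are not taken from parallel_sampling.py.

```python
from dataclasses import dataclass, replace
from typing import List, Optional, Tuple


@dataclass(frozen=True)
class ParentRequest:
    """Hypothetical request record (not a vLLM class)."""
    request_id: str
    prompt_token_ids: Tuple[int, ...]
    n: int = 1
    seed: Optional[int] = None


def expand(parent: ParentRequest) -> List[ParentRequest]:
    """Turn one n>1 request into n single-sample children."""
    children = []
    for index in range(parent.n):
        # Assumption: offset the seed so seeded children do not all draw the
        # same tokens; unseeded children stay unseeded.
        child_seed = None if parent.seed is None else parent.seed + index
        children.append(
            replace(
                parent,
                request_id=f"{index}_{parent.request_id}",
                n=1,
                seed=child_seed,
            )
        )
    return children
```

Every child carries the same prompt token ids, so a prefix cache keyed on those tokens can serve the prompt's KV blocks to all of them after the prompt has been computed once; only the sampled continuation differs per child.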

Reading checklist

  • Sampler.forward — recite the pipeline order.
  • Where do per-request params live, and why packed into tensors (not a Python loop)?
  • topk_topp_sampler.py — how is heterogeneous-batch top-p done branch-free?
  • The LogitsProcessor interface — how does Phase 12's grammar mask reuse it?
  • parallel_sampling.py — how does n>1 reuse prefix caching?

Now build it: 02-mini-build.md, then the labs.