Lab 09-01 — Sampling Ops & Logits Processors [CPU-OK]

Phase 0 lab-03 built the four classic knobs. This lab builds the production pipeline: penalties that read generation history, the confidence-relative min_p cutoff, and — the architecturally important part — the logits-processor hook, a pluggable (logits, ctx) → logits stage that turns the sampler from a fixed function into an extension point. That hook is how structured output injects its grammar mask (Phase 12), how logit_bias and bad-words lists work, and how every "force the model to/never let the model…" feature you'll ever build gets in. The ordering of the stages is the lab's quiet theorem: each transform assumes the one before it, and reorderings produce different samplers, not equivalent ones.

Why this lab exists

Two reasons, one per half of the lab. The ops half: min_p and repetition penalties are the knobs real traffic actually exercises (every chat frontend ships a repetition penalty; min_p has become the open-model community's favorite truncation), and both have semantics subtle enough that implementing them is the only reliable way to stop mis-explaining them. min_p's cutoff scales with the model's confidence — strict when one token dominates, permissive when the distribution is flat — which is the adaptivity top-k lacks and top-p only approximates. The repetition penalty's divide-positive, multiply-negative asymmetry is the detail everyone forgets: a naive uniform division would boost already-negative logits.
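To make the confidence-relative cutoff concrete, here is a minimal pure-Python sketch. It assumes logits are plain lists of floats and that masked tokens are set to −∞; the lab's real tensor type is not shown in this README, and the `softmax` helper is illustrative rather than taken from the lab.

```python
import math

def softmax(logits):
    """Numerically stable softmax over a list of floats; -inf entries get probability 0."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

def apply_min_p(logits, min_p):
    """Mask every token whose probability is below min_p * max_prob."""
    probs = softmax(logits)
    threshold = min_p * max(probs)
    return [x if p >= threshold else -math.inf for x, p in zip(logits, probs)]

def count_kept(logits):
    """Number of tokens that are still sampleable (not masked to -inf)."""
    return sum(1 for x in logits if x != -math.inf)

peaked = [8.0, 2.0, 1.0, 0.0]   # one token dominates
flat = [1.0, 0.9, 0.8, 0.7]     # near-uniform distribution

kept_peaked = count_kept(apply_min_p(peaked, 0.1))
kept_flat = count_kept(apply_min_p(flat, 0.1))
```

With the same `min_p = 0.1`, the peaked row keeps only its dominant token while the flat row keeps all four. A fixed probability floor would not adapt this way.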

The hook half is about architecture. vLLM cannot ship every conceivable logits intervention, so it ships an interface; the pipeline you build — an ordered list of (logits, ctx) → logits callables, run before the standard knobs — is that interface in miniature. After this lab, Phase 12's grammar-constrained decoding is "a logits processor that masks non-grammatical tokens," full stop; the mystery is relocated to building the mask fast, which is where it belongs.
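The interface described above can be sketched as follows. This is an illustration under stated assumptions, not the reference solution: logits are plain lists of floats, and `logit_bias` and `discourage_last` are hypothetical processors invented for the example.

```python
class Pipeline:
    """An ordered list of (logits, ctx) -> logits callables, run in list order."""

    def __init__(self, processors=()):
        self.processors = list(processors)

    def __call__(self, logits, ctx):
        for processor in self.processors:
            logits = processor(logits, ctx)
        return logits

def logit_bias(bias):
    """Hypothetical processor: add a fixed offset to selected token ids."""
    def processor(logits, ctx):
        return [x + bias.get(i, 0.0) for i, x in enumerate(logits)]
    return processor

def discourage_last(amount):
    """Hypothetical processor: subtract `amount` from the most recently generated token."""
    def processor(logits, ctx):
        out = list(logits)  # copy, so the caller's row is never mutated
        if ctx["generated"]:
            out[ctx["generated"][-1]] -= amount
        return out
    return processor

pipeline = Pipeline([logit_bias({2: 1.5}), discourage_last(2.0)])
ctx = {"generated": [0, 2]}
original = [3.0, 1.0, 0.5]
out = pipeline(original, ctx)
```

Both processors read `ctx` and return a new list, leaving the input row untouched. That keeps them safe to run when rows are processed in arbitrary order.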

Background: the pipeline and why its order is fixed

custom processors → repetition penalty → temperature → top-k → top-p → min-p → softmax → draw

Walk the order backwards and each placement explains itself:

  • Truncations (top-k/p/min-p) come after temperature because they're defined over the distribution you'll actually sample from — truncating pre-temperature evaluates the nucleus on the wrong distribution (and yes, top-k commutes with monotone temperature but top-p does not: temperature changes the probabilities the cumulative sum is built from).
  • Penalties come before temperature: they're corrections to the model's raw scores, not to the sampling distribution; applying them post-temperature would make the penalty's strength depend on T — two knobs tangled into one.
  • Custom processors run first: a hard constraint (grammar mask, banned token) must shape everything downstream — a token masked to −∞ before truncation can never sneak back, no matter what k/p/min-p do. Mask after truncation and you can end up with an empty candidate set (every surviving token banned) — the all-states-are-−∞ crash class that constrained-decoding implementations know well.
  • Order within truncations (k → p → min-p) matches vLLM's. These stages don't commute either, so matching the engine's order is what makes your sampler's outputs comparable to the engine's.
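A small check of the claim that top-p does not commute with temperature. This sketch assumes list-of-floats logits and an inclusive nucleus, meaning the token that crosses the `top_p` boundary is kept. The helper names are illustrative.

```python
import math

def softmax(logits):
    """Numerically stable softmax; -inf entries get probability 0."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

def apply_temperature(logits, t):
    return [x / t for x in logits]

def apply_top_p(logits, top_p):
    """Keep the smallest descending-probability prefix whose mass reaches top_p."""
    probs = softmax(logits)
    order = sorted(range(len(logits)), key=lambda i: probs[i], reverse=True)
    keep, cumulative = set(), 0.0
    for i in order:
        keep.add(i)
        cumulative += probs[i]
        if cumulative >= top_p:
            break
    return [x if i in keep else -math.inf for i, x in enumerate(logits)]

def support(logits):
    """Token ids that are still sampleable."""
    return [i for i, x in enumerate(logits) if x != -math.inf]

logits = [2.0, 1.0, 0.0, -1.0]
T, P = 2.0, 0.8

temp_then_top_p = support(apply_top_p(apply_temperature(logits, T), P))
top_p_then_temp = support(apply_temperature(apply_top_p(logits, P), T))
```

At T = 2 the distribution flattens, so the 0.8 nucleus needs three tokens. Evaluated on the raw logits, two tokens already cover it. The two orders give different candidate sets and therefore different samplers.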

Files

  • starter.py — the five ops, the Pipeline, and sample (the full ordered assembly). Your work.
  • solution.py — reference.
  • test_lab.py — each op's exact semantics plus the ban-token processor pattern.

Run

LAB_IMPL=starter pytest phase-09-sampling-and-decoding-algorithms/labs/lab-01-sampling-ops -q
pytest phase-09-sampling-and-decoding-algorithms/labs/lab-01-sampling-ops -q   # reference

What to implement

Implement the ops from Phase 0 lab-03 (temperature, top-k, top-p) plus three new pieces:

  • apply_min_p: threshold = min_p × max_prob, computed on the current distribution.
  • apply_repetition_penalty: divide positive logits by the penalty and multiply negative ones, so both directions push the token down. Apply it once per distinct token, not once per occurrence.
  • Pipeline and sample: the assembly in the order above. Greedy short-circuits after penalties, so penalties do apply to greedy decoding. People miss this detail: a repetition penalty that only worked at temperature > 0 would be a different feature.
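The repetition-penalty semantics can be sketched as below, again on plain lists of floats. `naive_uniform_division` is a deliberately wrong variant included only for contrast, and the argmax lines stand in for the greedy path; none of this is the lab's reference code.

```python
def apply_repetition_penalty(logits, generated, penalty):
    """Divide positive logits, multiply non-positive ones; once per distinct token."""
    out = list(logits)
    for tok in set(generated):  # set(): a token seen twice is penalized once
        if out[tok] > 0:
            out[tok] /= penalty
        else:
            out[tok] *= penalty
    return out

def naive_uniform_division(logits, generated, penalty):
    """Wrong on purpose: dividing a negative logit moves it toward zero, a boost."""
    out = list(logits)
    for tok in set(generated):
        out[tok] /= penalty
    return out

def argmax(values):
    return max(range(len(values)), key=values.__getitem__)

logits = [4.0, -2.0, 3.0]
generated = [0, 0, 1]  # token 0 appears twice, token 1 once

penalized = apply_repetition_penalty(logits, generated, 2.0)
naive = naive_uniform_division(logits, generated, 2.0)

greedy_before = argmax(logits)     # token 0 wins on the raw logits
greedy_after = argmax(penalized)   # penalties run before the greedy short-circuit
```

Token 0 is divided once despite two occurrences, token 1 drops from −2 to −4, and the greedy choice moves from token 0 to token 2. In the naive variant, token 1 rises from −2 to −1.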

What the tests prove

| Test | What it pins |
| --- | --- |
| top-k = 1 ⇒ argmax only | The truncation-to-greedy limit |
| top-p keeps exactly the nucleus | The inclusive-boundary semantics (Phase 0 lab-03's footgun, still armed) |
| min-p cutoff scales with max prob | The confidence-relative behavior that distinguishes it from a fixed floor |
| repetition penalty lowers a repeated token | Both signs handled: the divide/multiply asymmetry |
| ban-token processor ⇒ token unsamplable | The hook works, and −∞ survives the whole downstream pipeline: the Phase 12 grammar-mask pattern in one assert |
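The ban-token pattern from the last row might look like the sketch below. It assumes list-of-floats logits and a trimmed pipeline (processors → temperature → top-k → softmax → draw). `ban_token` and `sample_probs` are illustrative names, not the lab's API.

```python
import math
import random

def softmax(logits):
    """Numerically stable softmax; -inf entries get probability 0."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

def ban_token(token_id):
    """Processor that makes one token unsampleable by masking it to -inf."""
    def processor(logits, ctx):
        out = list(logits)
        out[token_id] = -math.inf
        return out
    return processor

def apply_top_k(logits, k):
    """Keep the k largest logits (ties at the boundary are all kept)."""
    kth = sorted(logits, reverse=True)[k - 1]
    return [x if x >= kth else -math.inf for x in logits]

def sample_probs(logits, ctx, processors, temperature, top_k):
    """Processors first, then temperature, then top-k, then softmax."""
    for processor in processors:
        logits = processor(logits, ctx)
    logits = [x / temperature for x in logits]  # -inf stays -inf
    logits = apply_top_k(logits, top_k)
    return softmax(logits)

logits = [5.0, 1.0, 0.5, -1.0]  # token 0 would win easily if it were allowed
probs = sample_probs(logits, {"generated": []}, [ban_token(0)],
                     temperature=0.7, top_k=2)

rng = random.Random(0)
draws = rng.choices(range(len(probs)), weights=probs, k=1000)
```

Because the mask is applied before truncation, top-k selects among the allowed tokens only. The banned token ends with probability exactly 0 and never appears in the draws.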

Hitchhiker's notes

  • The ctx dict is the processor's window into the request — here just {"generated": [...]}, upstream a richer per-request state (prompt tokens, output tokens, FSM state for grammars). The discipline that keeps the hook safe: processors read ctx and return logits; a processor that mutates shared state breaks the batched execution model (rows are processed in arbitrary order — Phase 9 lab-04's isolation lesson, one layer up).
  • Penalties are why samplers need history. Temperature/top-k are pure functions of the logits row; penalties read generated — meaning the production sampler carries per-request token-id state to the GPU (upstream: the penalty path in vllm/v1/sample/, with prompt-vs-output token distinction: presence_penalty, frequency_penalty, repetition_penalty — three related-but-different formulas; read them once and save yourself a support ticket).
  • Each stage is cheap; the sort in top-p is the expensive one (O(V log V) per row, V = 128k+). Vectorized GPU implementations care a lot — there are sort-free top-p approximations and threshold-precomputation tricks upstream. When sampling shows up in a profile (it does, at high batch), this is the line.
  • Processor order is API, the Phase 1 lab-05 lesson recurring: two processors (say, a grammar mask and a logit bias) don't commute either. vLLM applies user-supplied processors in list order — document yours.

Going further

  • Implement presence_penalty and frequency_penalty (additive, occurrence-counting — distinct from the multiplicative repetition penalty) and write the test that distinguishes all three on a token generated twice.
  • Build a MinTokensProcessor that masks EOS while len(generated) < min_tokens — you've now implemented the min_tokens feature from Phase 1 lab-05's going-further, as a processor, which is exactly how the engine structures it.
  • Property-test the pipeline: for random logits and any knob combo, assert the output distribution (a) sums to 1, (b) supports only unmasked tokens, (c) is unchanged when all knobs are neutral. Three invariants that catch most pipeline-assembly bugs.
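As a starting point for the MinTokensProcessor exercise, here is one possible shape. It assumes the `{"generated": [...]}` ctx from this lab and list-of-floats logits. The constructor arguments are hypothetical and may not match the engine's.

```python
import math

class MinTokensProcessor:
    """Mask EOS until at least `min_tokens` tokens have been generated."""

    def __init__(self, min_tokens, eos_token_id):
        self.min_tokens = min_tokens
        self.eos_token_id = eos_token_id

    def __call__(self, logits, ctx):
        if len(ctx["generated"]) < self.min_tokens:
            out = list(logits)  # copy rather than mutate the caller's row
            out[self.eos_token_id] = -math.inf
            return out
        return logits

processor = MinTokensProcessor(min_tokens=3, eos_token_id=2)
row = [0.5, 0.1, 4.0]  # EOS (id 2) is the model's favorite

too_early = processor(row, {"generated": [0, 1]})
long_enough = processor(row, {"generated": [0, 1, 0]})
```

The processor holds only configuration. Everything that varies per request comes from `ctx`, in line with the read-ctx, return-logits discipline from the Hitchhiker's notes.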

References

  • upstream/vllm/v1/sample/sampler.py — the batched pipeline; find your stage order.
  • upstream/vllm/v1/sample/logits_processor/ — the production hook interface.
  • Nguyen et al., Min-p Sampling (2024) — the case for confidence-relative truncation: https://arxiv.org/abs/2407.01082
  • Keskar et al., CTRL (2019) — where the repetition penalty's divide/multiply form comes from (§4.1): https://arxiv.org/abs/1909.05858
  • Phase 0 lab-03 — the four base knobs; Phase 12 — the grammar mask that rides this lab's hook.