Lab 03-02 — Chunked Prefill: Same Output, Different Timing [CPU-OK]
This lab proves, on a running engine, the most important safety property in vLLM:
Chunked prefill changes WHEN tokens are computed, never WHAT tokens are produced.
If that sentence were false, no scheduling optimization in this codebase would be safe to
ship — every knob that re-times work would be a knob that corrupts output. You'll verify it
the strong way (identical token ids, chunked vs unchunked, on the real mini_vllm engine),
and you'll learn to predict the timing side: exactly how many steps a prefill takes under
any threshold/budget combination.
Contents
- Why this lab exists
- Background: why chunk a prefill at all
- Why the output cannot change — the actual argument
- Files
- Run
- The formula to implement
- What the tests prove
- Hitchhiker's notes
- Going further
- References
Why this lab exists
"Re-timing is output-invariant" is the kind of claim engineers nod along to and never
check. But your career will repeatedly hand you moments where the nod isn't enough: a
customer reports different outputs between two deployments that differ only in scheduler
config; a reviewer asks whether your scheduler PR can change generations; an incident
review wants to know whether enabling chunked prefill mid-fleet is provably safe or just
probably safe. This lab gives you the proof technique: drive the same deterministic
workload through both schedules and diff the token ids. It's mini_vllm's own regression
test (test_engine.py::test_chunked_prefill_matches_unchunked_output), reproduced by your
hand so you know why it must hold, not just that it does.
The second skill is the timing model. "How many steps does a 4000-token prompt take at threshold 512?" is a real capacity question (it sets TTFT for that request and the interference window for everyone else — lab-05). The answer is a one-line formula, and you should never need to run the engine to produce it.
Background: why chunk a prefill at all
Without chunking, a 4096-token prompt arrives and the scheduler faces an ugly choice: schedule the whole prefill in one step — a step that takes hundreds of times longer than a decode step, during which every other user's token stream visibly freezes — or make the new request wait indefinitely. Early engines picked the freeze; users called it "jitter" and "stalls."
Chunked prefill (Sarathi's contribution, default-on in vLLM V1) dissolves the choice:
split the prompt into budget-sized chunks across several steps, and let decodes ride along
in each step's leftover budget. The long prompt pays slightly more total latency (more
steps, plus re-reading its growing KV each chunk); everyone else's inter-token latency
stays smooth. The two-counters model from Phase 1 makes the implementation almost
embarrassingly small: a prefill is just a request whose num_computed_tokens is far
behind, so capping its per-step advance — clamp from lab-01 — is chunking. No prefill
state machine, no resume logic; the counter is the resume logic.
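A minimal sketch of "capping the per-step advance is chunking", assuming a single request that owns the whole cap. The names Request, clamp_new_tokens, and run_prefill are hypothetical illustrations, not mini_vllm's actual classes or functions.

```python
from dataclasses import dataclass


@dataclass
class Request:
    # Hypothetical stand-in for an engine request; only the two counters matter.
    num_tokens: int                # tokens known so far (here: the prompt length)
    num_computed_tokens: int = 0   # tokens whose KV has already been computed


def clamp_new_tokens(req: Request, cap: int) -> int:
    """Tokens to compute this step: the backlog, capped at `cap`."""
    return min(req.num_tokens - req.num_computed_tokens, cap)


def run_prefill(req: Request, cap: int) -> list[int]:
    """Advance the counter step by step and return each step's chunk size."""
    if cap <= 0:
        raise ValueError("cap must be positive")
    chunks = []
    while req.num_computed_tokens < req.num_tokens:
        n = clamp_new_tokens(req, cap)
        req.num_computed_tokens += n   # the counter is the resume logic
        chunks.append(n)
    return chunks


print(run_prefill(Request(num_tokens=10), cap=4))    # [4, 4, 2]
print(run_prefill(Request(num_tokens=10), cap=16))   # [10]
```

Nothing in the loop records "which chunk we are on"; the next step's work is derived entirely from how far num_computed_tokens lags behind num_tokens.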
Why the output cannot change — the actual argument
Spell it out once, carefully, because this is the argument you'll reuse for every scheduling feature:
- The model's logits at position k depend only on tokens 0..k (causality) and their KV values, not on which step computed that KV. KV is a pure function of the tokens.
- The engine samples for a request only when num_computed_tokens + n == num_tokens (Scheduler.needs_sample, Phase 1 lab-03's guard). Mid-prefill chunks emit nothing.
- Therefore the first sample happens at the same logical state (all prompt KV computed, position = prompt length) whether the prompt was computed in 1 chunk or 10. Same state + same sampling → same token. Induction extends this to every later token.
The invariant has exactly two load-bearing dependencies: causality (KV doesn't depend on schedule) and the sampling guard (no logits read mid-prefill). Notice what that implies for review: a PR can only break output-invariance by touching one of those two things. That's a checklist of length two for an entire class of changes — and on real GPUs, a softer third dependency appears (batch-shape-dependent floating-point reduction order), which is why the real engine's version of this test compares with tolerance while ours can demand exact equality. See the Hitchhiker's notes.
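The argument can be replayed on a toy model. This is a sketch under stated assumptions, not mini_vllm's engine: toy_kv, toy_next_token, and generate are hypothetical, the "KV" is an integer hash of the token prefix, and "sampling" is a deterministic function of all KV so far. What it preserves are the two dependencies: KV is a pure function of the tokens, and sampling happens only when the request has caught up.

```python
def toy_kv(tokens: list[int], pos: int) -> int:
    """Hypothetical 'KV' at position pos: a pure function of tokens[0..pos]."""
    acc = 0
    for t in tokens[: pos + 1]:
        acc = (acc * 31 + t) % 1009
    return acc


def toy_next_token(kv: list[int]) -> int:
    """Greedy-style 'sampling': deterministic in the KV computed so far."""
    return sum(kv) % 50


def generate(prompt: list[int], chunk: int, max_new: int) -> tuple[list[int], int]:
    """Return (output token ids, number of engine steps) for one request."""
    tokens = list(prompt)
    kv: list[int] = []     # len(kv) plays the role of num_computed_tokens
    out: list[int] = []
    steps = 0
    while len(out) < max_new:
        start = len(kv)
        n = min(len(tokens) - start, chunk)
        caught_up = start + n == len(tokens)   # the sampling guard
        kv.extend([toy_kv(tokens, p) for p in range(start, start + n)])
        steps += 1
        if caught_up:                          # mid-prefill chunks emit nothing
            tok = toy_next_token(kv)
            out.append(tok)
            tokens.append(tok)
    return out, steps


prompt = [3, 1, 4, 1, 5, 9, 2, 6, 5, 3]
chunked, chunked_steps = generate(prompt, chunk=4, max_new=3)
unchunked, unchunked_steps = generate(prompt, chunk=10**9, max_new=3)
print(chunked == unchunked)            # True: same tokens
print(chunked_steps, unchunked_steps)  # 5 3: different timing
```

Deleting the caught_up check (sampling after every chunk) is the quickest way to see the invariant fail: the chunked run would then emit tokens computed from a partial prompt.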
Files
- starter.py: implement num_prefill_steps(prompt_len, threshold, budget). Your work.
- solution.py: reference.
- test_lab.py: checks your formula on the boundary cases AND runs the engine both ways, asserting identical output token ids.
Run
LAB_IMPL=starter pytest phase-03-continuous-batching-scheduler/labs/lab-02-chunked-prefill -q
pytest phase-03-continuous-batching-scheduler/labs/lab-02-chunked-prefill -q # reference
The formula to implement
A single request (it owns the whole budget) with a prompt_len-token prompt. The per-step
chunk is threshold if threshold > 0 else budget, but never more than budget:
chunk = min(threshold or budget, budget). The prefill then takes
ceil(prompt_len / chunk)
steps. Watch the boundaries the tests probe: threshold = 0 means disabled (not "chunk
of zero"); a threshold larger than the budget is moot (the budget binds); a prompt that
divides evenly takes exactly prompt_len / chunk, no +1. Off-by-ones here are off-by-ones
in someone's TTFT model later.
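One way to write the formula down, as a sketch rather than the contents of solution.py. It assumes integer inputs with budget > 0 and prompt_len >= 0, and uses integer ceiling division so no float rounding can creep in.

```python
def num_prefill_steps(prompt_len: int, threshold: int, budget: int) -> int:
    """Steps for one request's prefill when it owns the whole budget."""
    # threshold == 0 means "disabled", so the budget alone sets the chunk.
    chunk = threshold if threshold > 0 else budget
    # A threshold above the budget is moot: the budget binds.
    chunk = min(chunk, budget)
    # Integer ceiling division: exact multiples get no extra step.
    return -(-prompt_len // chunk)


print(num_prefill_steps(4000, threshold=512, budget=2048))   # 8
print(num_prefill_steps(4096, threshold=0, budget=8192))     # 1
print(num_prefill_steps(1024, threshold=4096, budget=512))   # 2
print(num_prefill_steps(1025, threshold=512, budget=2048))   # 3
```

The first call answers the capacity question from earlier, assuming the budget is at least 512: a 4000-token prompt at threshold 512 takes 8 steps, because 7 chunks cover only 3584 tokens.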
What the tests prove
- The formula tests pin the chunk-size selection logic and the ceiling division, including threshold=0 (threshold disabled: 1 step when the prompt fits in the budget), threshold > budget (budget wins), and exact-division boundaries.
- The engine test generates from the same prompt with long_prefill_token_threshold=0 and with a small threshold, and asserts identical output token ids: not similar, identical. It can demand exactness because mini_vllm is deterministic end-to-end (greedy sampling, deterministic toy model), which turns the safety property into a hard equality a CI can enforce forever. This is the test you write first when building any scheduling feature: pin the semantics, then optimize the timing freely. (Compare to the trace shape you saw in Phase 1 lab-04: chunking visibly rearranged the steps. Same engine, same tokens; the timing is the only degree of freedom.)
Hitchhiker's notes
- On real GPUs, "identical" softens to "equivalent." Chunking changes batch shapes; different GEMM/attention tile sizes can change floating-point reduction order; logits wiggle in the last ulp; and a greedy argmax between two near-tied tokens can flip. The semantic invariant (same distribution, same correctness) holds; bitwise equality does not. This is why upstream correctness tests for chunked prefill compare with tolerance or check logprob closeness, and it's the first thing to say in the incident review when two configs differ by one token at position 947: not all divergence is a bug — divergence beyond rounding is.
- The threshold is a latency/throughput dial, not free money. Small chunks: smoother decode latency for others, but the long prompt's prefill stretches across more steps (worse TTFT for it), and each chunk re-reads the KV accumulated by every earlier chunk, so total KV read traffic grows with the number of chunks (roughly (C+1)/2 passes over the prompt for C equal chunks, versus one pass for the one-shot). Sarathi-Serve's whole paper is about choosing this number; lab-05 lets you feel it.
- Where would chunking change output? It wouldn't, but a bug that sampled mid-prefill would (a request emitting a token from logits computed over half its prompt). Find the guard in mini_vllm/scheduler.py::needs_sample and its upstream twin (the logits_indices selection in the model runner). If a future refactor moved sampling before the catch-up check, this lab's engine test is the tripwire that catches it.
- Chunked prefill and prefix caching compose. A cache-hit request enters admission with num_computed_tokens already nonzero; the chunk math applies to the remainder. No interaction code exists because both features speak the same language: the counter. (Lab-06 shows the composed behavior in a trace.)
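The first note above (reduction order, near-tied argmax, tolerance) can be reproduced in plain Python floats without a GPU. This is only an illustration of the mechanism: the three addends and the rival logit 0.6 are made-up values, and argmax is a local helper, not an engine function.

```python
import math

# The same three terms, summed in two different orders, as two
# differently shaped reductions might do.
terms = [0.1, 0.2, 0.3]
left_to_right = (terms[0] + terms[1]) + terms[2]
right_to_left = terms[0] + (terms[1] + terms[2])

print(left_to_right == right_to_left)                            # False
print(math.isclose(left_to_right, right_to_left, rel_tol=1e-9))  # True


def argmax(xs: list[float]) -> int:
    """Index of the largest value; ties go to the lowest index."""
    return max(range(len(xs)), key=xs.__getitem__)


# A rival token whose logit sits exactly at the near-tie.
rival = 0.6
print(argmax([rival, left_to_right]))   # 1
print(argmax([rival, right_to_left]))   # 0
```

Exact comparison reports a difference, a tolerance comparison does not, and the greedy pick flips anyway. That is the whole reason a real-hardware version of the test checks closeness instead of equality.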
Going further
- Extend num_prefill_steps to two concurrent prompts sharing the budget fairly. Suddenly you need to model the RUNNING phase's in-flight chunks competing with admissions, and the closed form gets genuinely interesting. Check your model against Phase 1 lab-04's probe.
- Compute TTFT in steps as a function of threshold for a 4096-token prompt at budget 512, then sketch the other requests' worst-case stall at each threshold (lab-05 measures it). Plot both curves; their crossing is the tuning decision Sarathi formalizes.
- Read upstream's long_prefill_token_threshold handling and the scheduler config's chunked-prefill defaults, and write down which of your formula's branches each config combination exercises.
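A possible starting point for the TTFT exercise, as a sketch only. prefill_steps repeats the lab's formula so the block is self-contained, and kv_positions_read is a deliberately crude cost model of my own (an assumption, not something the engine reports): each chunk is charged for every KV position computed so far, including its own. It does not model the other requests' stall; that part is left to lab-05.

```python
def prefill_steps(prompt_len: int, threshold: int, budget: int) -> tuple[int, int]:
    """Return (chunk size, steps) for a single request owning the budget."""
    chunk = min(threshold if threshold > 0 else budget, budget)
    return chunk, -(-prompt_len // chunk)


def kv_positions_read(prompt_len: int, chunk: int) -> int:
    """Crude model (assumption): each chunk reads all KV computed so far."""
    total = 0
    computed = 0
    while computed < prompt_len:
        computed += min(chunk, prompt_len - computed)
        total += computed
    return total


PROMPT_LEN, BUDGET = 4096, 512
for threshold in (64, 128, 256, 512, 0):
    chunk, steps = prefill_steps(PROMPT_LEN, threshold, BUDGET)
    reads = kv_positions_read(PROMPT_LEN, chunk)
    print(f"threshold={threshold:4d} chunk={chunk:4d} "
          f"ttft_steps={steps:3d} kv_positions_read={reads}")
```

Under this model the threshold=64 row takes 64 steps against 8 for threshold=512, and reads roughly seven times as many KV positions. At budget 512, threshold=0 and threshold=512 land on the same row because the budget binds either way.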
References
- mini_vllm/test_engine.py::test_chunked_prefill_matches_unchunked_output: the course's own regression test you just rebuilt.
- mini_vllm/scheduler.py: _clamp_new_tokens (the chunker) and needs_sample (the guard).
- upstream/vllm/v1/core/sched/scheduler.py: the production clamp; search long_prefill_token_threshold.
- Agrawal et al., SARATHI: Efficient LLM Inference by Piggybacking Decodes with Chunked Prefills (2023). The original chunking paper: https://arxiv.org/abs/2308.16369
- Agrawal et al., Sarathi-Serve (OSDI 2024). The production-grade follow-up with the threshold-tuning math: https://arxiv.org/abs/2403.02310
- vLLM docs, Chunked Prefill. The feature's official knobs and defaults: https://docs.vllm.ai/en/latest/configuration/optimization.html