Lab 01-04 — Watch the Batch: Continuous Batching Made Visible [CPU-OK]
One request's lifecycle (lab-01) is a nice story. But inference engines earn their living when many requests share the machine — and the way they share it is the single biggest throughput idea of the last few years: continuous batching. In this lab you'll instrument the engine to photograph the batch composition of every step — who got scheduled, for how many tokens — and you'll directly observe the thing the famous benchmark posts only describe: prefill chunks of one request riding in the same step as decodes of another.
Contents
- Why this lab exists
- Background: static vs continuous batching
- Files
- Run
- What to implement
- What you should see — the full trace, explained
- What the tests prove
- Hitchhiker's notes
- Going further
- References
Why this lab exists
This lab is Phase 3 knocking on the door early — on purpose. The scheduler is easier to implement (Phase 3 lab-01) after you've seen its decisions laid out step by step. More practically: per-step batch composition is the engine's most important hidden variable. The wall-clock time of a step is roughly proportional to the tokens scheduled in it, so the sequence of dicts you're about to record is, up to a constant, the latency profile of the server. Spiky dicts = spiky inter-token latency. When Phase 3 lab-05 measures chunked prefill's effect on decode latency, it will use exactly the probe you build here.
You'll also learn the instrumentation pattern itself: wrapping a component's method to observe a system without changing its behavior. That's how vLLM's own stat loggers attach to the engine, and how you'll debug schedulers for the rest of your career — schedulers rarely crash; they just quietly make bad batches. You can't grep for a bad batch. You have to look at it.
Background: static vs continuous batching
The old way (pre-Orca, ~2022): collect N requests, run them as a unit until all N finish, then take the next N. Two disasters hide in that sentence. First, requests finish at different times, and finished slots sit idle while the longest request drags on (the "convoy"). Second, a request arriving one millisecond after the batch launched waits an entire batch lifetime to even start. GPU utilization graphs of static-batched servers look like a comb: bursts of work, then teeth-gaps of idle.
The Orca insight (OSDI 2022), which vLLM adopted and the whole industry copied: rebuild the batch every step. A request can join the batch at any step boundary (its prefill just becomes part of that step) and leave at any step boundary (its slot is free next step, not at end-of-batch). The batch isn't a unit of work anymore — it's whatever the scheduler composed for this one tick. Anyscale's benchmark of this idea measured up to 23× throughput over static batching. That entire revolution is visible in the data structure you're about to record: consecutive dicts whose key sets grow and shrink while the engine never stops.
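The difference can be sketched with a toy model. This is a simplification under stated assumptions, not how the lab's engine works: each request is reduced to a fixed number of steps, the machine has a fixed number of slots, and prefill and token budgets are ignored. The names static_steps and continuous_steps are hypothetical.

```python
from collections import deque


def static_steps(lengths, slots):
    """Static batching: take `slots` requests, run until the longest one finishes."""
    total = 0
    for i in range(0, len(lengths), slots):
        batch = lengths[i:i + slots]
        total += max(batch)  # finished slots sit idle until the longest is done
    return total


def continuous_steps(lengths, slots):
    """Continuous batching: refill any free slot at every step boundary."""
    waiting = deque(lengths)
    running = []  # remaining steps for each in-flight request
    steps = 0
    while waiting or running:
        while waiting and len(running) < slots:
            running.append(waiting.popleft())  # join at a step boundary
        steps += 1
        # A request with one step left finishes now and frees its slot.
        running = [r - 1 for r in running if r > 1]
    return steps


# Hypothetical workload: two long requests hidden among short ones.
lengths = [8, 1, 1, 1, 8, 1, 1, 1]
print("static:", static_steps(lengths, slots=4))
print("continuous:", continuous_steps(lengths, slots=4))
```

In this toy workload the static version pays for the long request twice, once per batch, while the continuous version lets the second long request start as soon as a short one leaves.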
Files
- starter.py: implement trace_batches (engine + probe). Your work.
- solution.py: reference.
- test_lab.py: pins step-1 composition, token conservation, budget cap, deferral under a tight budget, and the existence of mixed prefill+decode steps.
Run
LAB_IMPL=starter pytest phase-01-architecture-and-request-lifecycle/labs/lab-04-watch-the-batch -q
pytest phase-01-architecture-and-request-lifecycle/labs/lab-04-watch-the-batch -q # reference
What to implement
def trace_batches(prompts, max_tokens=4, **engine_kwargs)
-> tuple[list[str], list[dict[str, int]]]
Add all prompts (greedy, ignore_eos=True), then run eng.step() to completion — but
first, wrap eng.scheduler.schedule with a closure that calls the original, appends a
copy of out.num_scheduled_tokens to your trace, and returns out unchanged. The probe
must be invisible: same engine behavior with or without it. (Copy the dict! The scheduler
gives you its own object; aliasing it is the kind of bug that produces a trace where every
step mysteriously looks like the last one.)
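The wrapping pattern, and the aliasing bug it warns about, can be sketched on a stand-in scheduler. Everything here except the names schedule and num_scheduled_tokens is hypothetical: ToyOutput, ToyScheduler, and attach_probe are illustration only, and the toy deliberately reuses one dict across steps to make the aliasing failure obvious. The real engine may or may not reuse its dict, which is exactly why copying is the safe default.

```python
from dataclasses import dataclass, field


@dataclass
class ToyOutput:
    # Hypothetical stand-in for the scheduler's per-step output object.
    num_scheduled_tokens: dict = field(default_factory=dict)


class ToyScheduler:
    """Hypothetical scheduler that reuses ONE dict across all steps."""

    def __init__(self, plan):
        self._plan = list(plan)
        self._out = ToyOutput()

    def schedule(self):
        self._out.num_scheduled_tokens.clear()
        self._out.num_scheduled_tokens.update(self._plan.pop(0))
        return self._out


def attach_probe(scheduler, copy=True):
    """Wrap scheduler.schedule and return the list the wrapper appends to."""
    trace = []
    original = scheduler.schedule  # bound method, captured by the closure

    def probed():
        out = original()
        snap = out.num_scheduled_tokens
        trace.append(dict(snap) if copy else snap)
        return out  # returned unchanged, so the probe stays invisible

    scheduler.schedule = probed  # instance attribute shadows the class method
    return trace


plan = [{"A": 8}, {"A": 3, "B": 5}]
good_sched, bad_sched = ToyScheduler(plan), ToyScheduler(plan)
good = attach_probe(good_sched)
bad = attach_probe(bad_sched, copy=False)
for _ in plan:
    good_sched.schedule()
    bad_sched.schedule()
print("copied: ", good)
print("aliased:", bad)
```

The aliased trace holds the same dict object twice, so both entries show the final step. The copied trace keeps each step as it was when scheduled.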
What you should see — the full trace, explained
Two prompts — A = "hello world" (11 tokens) and B = "goodbye" (7 tokens),
max_tokens=4 — with a tight budget of max_num_batched_tokens=8:
step 1: {A: 8} # A's prefill, CHUNKED to the budget. B is NOT admitted: budget spent.
step 2: {A: 3, B: 5} # A finishes prefill (3 left) + samples token 1.
# B finally admitted with the leftover budget: 8-3=5 of its 7.
step 3: {A: 1, B: 2} # ← THE MONEY STEP: A is decoding (1 token) while B is still
# prefilling (its last 2) — prefill and decode IN THE SAME BATCH.
step 4: {A: 1, B: 1} # both decoding.
step 5: {A: 1, B: 1} # ...
step 6: {B: 1} # A hit max_tokens and left; B has the machine to itself.
Read it like a maintainer:
- Step 1 is {A: 8}, not {A: 11}: the budget (8) caps the step, so the scheduler takes the first 8 tokens of A's prompt and stops. Nothing is special-cased: n = min(remaining, budget). B isn't admitted at all, because admission requires leftover budget. B spends this step in the WAITING queue; this is the queueing delay that lab-01 promised you'd see under load.
- Step 2 is where continuous batching starts paying: A's last chunk and B's first chunk share a step. A static-batch engine cannot produce this step; it doesn't have a concept for "half of A and half of B."
- Step 3 is the signature: min=1, max>1 in the same dict, a decode and a prefill chunk co-scheduled. The test test_mixed_batches_exist_under_load hunts for exactly this shape. On a GPU this mixing is also an efficiency trick: a decode-only step underuses compute (it is bandwidth-bound), while a prefill-only step stalls every in-flight decode; mixed batches fill the compute bubbles with prefill work (Sarathi's "piggybacking", covered in Phase 3).
- Step 6's shrinking key set: A finished and was reaped mid-flight; B never noticed. A's slot is reusable immediately. That, and nothing more, is "continuous."
- Conservation check: sum A's numbers: 8+3+1+1+1 = 14 = 11 + 4 − 1. B's agree: 5+2+1+1+1 = 10 = 7 + 4 − 1. Each request's scheduled tokens total prompt + max_tokens − 1. Why −1? The final sampled token is appended and the request immediately finishes, so its KV is never computed, because no further token will ever attend to it. The engine doesn't do work the future won't read. When a counter is off by one in a scheduler, this is the kind of identity you use to find it; that's why there's a test pinning it.
Rerun with the default roomy budget (2048) and the drama disappears: step 1 is
{A: 11, B: 7}, everything after is decodes. Scheduling is only interesting under
scarcity — keep that in mind when building benchmarks, or you'll "validate" a scheduler
on workloads that never exercise it.
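Both traces can be reproduced with a few lines of toy scheduling. This is a minimal sketch, not mini_vllm's scheduler: it assumes strict FCFS in arrival order, the rule n = min(remaining, budget), one sampled token whenever a request's computed tokens catch up with its sequence, and a budget of at least 1. The function simulate and the ids "A" and "B" are placeholders.

```python
def simulate(prompt_lens, max_tokens, budget):
    """Toy FCFS scheduler: each request gets min(remaining, budget left) per step."""
    # Per request: [tokens whose KV is computed, tokens generated so far].
    # Dict order is arrival order, which stands in for the FCFS queue.
    state = {rid: [0, 0] for rid in prompt_lens}
    trace = []
    while state:
        left, step = budget, {}
        for rid, (computed, generated) in state.items():
            if left == 0:
                break  # budget spent: everyone behind this point waits
            n = min(prompt_lens[rid] + generated - computed, left)
            step[rid] = n
            left -= n
        for rid, n in step.items():
            st = state[rid]
            st[0] += n
            if st[0] == prompt_lens[rid] + st[1]:  # caught up: sample one token
                st[1] += 1
                if st[1] == max_tokens:  # done; the last token's KV is never computed
                    del state[rid]
        trace.append(step)
    return trace


prompts = {"A": 11, "B": 7}
tight = simulate(prompts, max_tokens=4, budget=8)
roomy = simulate(prompts, max_tokens=4, budget=2048)
for name, tr in (("tight", tight), ("roomy", roomy)):
    print(name)
    for i, step in enumerate(tr, start=1):
        print(f"  step {i}: {step}")
```

With the tight budget the toy produces the six steps walked through above; with the roomy budget both prompts prefill in step 1 and only decodes follow.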
What the tests prove
| Test | Invariant |
|---|---|
| test_ample_budget_prefills_everyone_in_step_one | With budget to spare, admission is immediate: queueing is a scarcity phenomenon, not a constant tax |
| test_token_conservation_per_request | Σ scheduled = prompt + max_tokens − 1, the off-by-one identity above |
| test_budget_is_never_exceeded | Σ over the batch ≤ max_num_batched_tokens, every single step: the engine's load-bearing promise to the GPU's latency |
| test_tight_budget_chunks_and_defers | The exact step-1/step-2 composition above: chunking + deferred admission |
| test_mixed_batches_exist_under_load | A prefill chunk and a decode co-exist in one step |
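Three of these invariants can be checked directly on a recorded trace. The sketch below is not the contents of test_lab.py; it hardcodes the tight-budget trace from this lab with placeholder ids "A" and "B" and checks the budget cap, token conservation, and the mixed-batch shape. The mixed-batch check is the min=1, max>1 heuristic from the walkthrough, which would also fire on a final prefill chunk that happens to be one token long.

```python
from collections import Counter

# The tight-budget trace from this lab, with placeholder ids "A" and "B".
trace = [{"A": 8}, {"A": 3, "B": 5}, {"A": 1, "B": 2},
         {"A": 1, "B": 1}, {"A": 1, "B": 1}, {"B": 1}]
prompt_len = {"A": 11, "B": 7}
max_tokens, budget = 4, 8

# Budget cap: no step schedules more tokens than max_num_batched_tokens.
budget_ok = all(sum(step.values()) <= budget for step in trace)

# Token conservation: each request totals prompt + max_tokens - 1.
totals = Counter()
for step in trace:
    totals.update(step)  # Counter.update adds counts from a mapping
conserved = all(totals[r] == prompt_len[r] + max_tokens - 1 for r in prompt_len)

# Mixed batch: a decode (1 token) next to a prefill chunk (>1) in one step.
mixed = [i for i, step in enumerate(trace, start=1)
         if min(step.values()) == 1 and max(step.values()) > 1]

print("budget ok:", budget_ok)
print("conserved:", conserved, dict(totals))
print("mixed steps:", mixed)
```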
Hitchhiker's notes
- The probe pattern beats print-debugging schedulers. You get structured data you can assert on, diff between runs, and plot. The real engine's equivalent surface is SchedulerOutput (upstream/vllm/v1/core/sched/output.py); when debugging real vLLM, logging num_scheduled_tokens per step gives you this exact trace.
- Why does B wait a whole step when the budget is spent? Could the scheduler give A 7 and B 1 instead of A 8? It could, but FCFS says finish admitting A's work first; fairness policies are a deep rabbit hole (priority scheduling lands in Phase 3's exercises). The shape to remember: policy decides who, budget decides how much, and they're separable concerns in the code.
- Step time ∝ scheduled tokens is a good first-order model but not exact on real hardware: a decode-only step pays memory-bandwidth costs that token count alone doesn't capture, and tiny steps pay fixed launch overheads (which CUDA graphs attack, in Phase 5). Phase 18 refines the model; the trace you built stays the right raw material.
- Request IDs are global. mini_vllm numbers requests with a module-level counter, so don't hardcode req-0 in your own experiments; use the ids trace_batches returns. The tests are written that way for exactly this reason.
Going further
- Plot the trace: steps on x, stacked bars of scheduled tokens per request. You've recreated the iconic continuous-batching diagram from the Orca paper and the Anyscale post, except yours is measured, not illustrated.
- Sweep max_num_batched_tokens from 4 to 64 over the same prompts and plot total steps vs budget. You'll see a hyperbola flatten: past "everything fits," more budget buys nothing. Congratulations, you've found a saturation knee; Phase 18 is full of these.
- Add 8 requests with staggered arrival (add two, step twice, add two more …). Watch key sets churn. This is what a production batch actually looks like: a rolling membership, no two steps alike.
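Before running the sweep against the real engine, the expected shape can be previewed with a toy step counter. This is a sketch under assumptions, not mini_vllm: strict FCFS, prefill chunks capped by the remaining budget, exactly one token per decode step, and a budget of at least 1. The function count_steps is hypothetical; the prompt lengths 11 and 7 are the lab's two prompts.

```python
def count_steps(prompt_lens, max_tokens, budget):
    """Steps a toy FCFS token-budget scheduler takes to drain all requests."""
    # Per request: [prompt tokens left to prefill, decode steps left].
    # The last sampled token needs no step of its own, hence max_tokens - 1.
    reqs = [[p, max_tokens - 1] for p in prompt_lens]
    steps = 0
    while reqs:
        left = budget
        for r in reqs:  # list order = arrival order (FCFS)
            if left == 0:
                break  # budget spent: remaining requests wait for the next step
            if r[0] > 0:  # prefill chunk, capped by what is left of the budget
                n = min(r[0], left)
                r[0] -= n
            else:  # decode: exactly one token
                n = 1
                r[1] -= 1
            left -= n
        reqs = [r for r in reqs if r[0] > 0 or r[1] > 0]
        steps += 1
    return steps


prompts = [11, 7]
sweep = {b: count_steps(prompts, max_tokens=4, budget=b) for b in (4, 8, 16, 32, 64)}
for b, s in sweep.items():
    print(f"budget={b:>2}  steps={s}")
```

In this toy the step count stops falling once both prompts fit in a single step; past that point it is bounded below by max_tokens, since each request can sample at most one token per step.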
References
- Yu et al., Orca: A Distributed Serving System for Transformer-Based Generative Models (OSDI 2022) — iteration-level scheduling, the idea this lab photographs: https://www.usenix.org/conference/osdi22/presentation/yu
- Anyscale, How continuous batching enables 23x throughput in LLM inference (2023) — the benchmark post that made this mainstream: https://www.anyscale.com/blog/continuous-batching-llm-inference
- Agrawal et al., Sarathi-Serve: Taming Throughput-Latency Tradeoff in LLM Inference (OSDI 2024) — why mixed prefill+decode batches are not just legal but desirable: https://arxiv.org/abs/2403.02310
- upstream/vllm/v1/core/sched/output.py: SchedulerOutput, the real engine's version of the dicts you recorded.
- upstream/vllm/v1/core/sched/scheduler.py:329: the loop that composed every step you traced; you implement its core in Phase 3 lab-01.