Lab 03-01 — Implement the Scheduler Step [CPU-OK]
The scheduler is the brain of the engine — the component that decides, every single step,
who computes and how much. In this lab you implement its core: the two-phase loop
(serve the RUNNING, then admit the WAITING) under a global token budget, a sequence-slot
cap, and a per-request chunk limit. It is maybe 30 lines of code. Those 30 lines are the
difference between a GPU that hums at 90% utilization and one that stutters between
overload and idleness — and they are, shape for shape, the same 30 lines at the heart of
upstream/vllm/v1/core/sched/scheduler.py.
Contents
- Why this lab exists
- Background: the three scarce resources
- Why running-first is not arbitrary
- Files
- Run
- What schedule_step must do
- What the tests prove — a guided tour
- How this maps to the real engine
- Hitchhiker's notes
- Going further
- References
Why this lab exists
In Phase 1 lab-04 you observed the scheduler's decisions as a trace of per-step batch dicts. Reverse the arrow: now you are the one producing those dicts. Everything you watched — chunking to the budget, deferred admission, mixed prefill+decode batches — must now fall out of code you write. This is the course's central loop made flesh: observe a mechanism, then build it, and the understanding compounds.
It's also the file you will touch most as a contributor. Scheduling policy is where vLLM
evolves fastest — priority scheduling, fairness, SLA-aware admission, disaggregated
prefill (Phase 15) are all edits to this loop. The deep-dive walks you through the real
Scheduler.schedule; this lab makes sure that when you read it, you're recognizing,
not learning.
Background: the three scarce resources
Every scheduling decision is a negotiation between three independent scarcities, and the loop checks all three — know which line enforces which:
- max_num_batched_tokens (the token budget, default 2048–8192 upstream): caps total tokens computed per step. This is a latency control: step wall-clock time grows with scheduled tokens, so the budget is, almost literally, your inter-token latency dial (lab-05 measures this). The budget is global per step: one pool shared by everyone scheduled.
- max_num_seqs (the slot cap): caps how many requests can be RUNNING at once. This bounds per-step fixed overheads and runner state (and, on real hardware, things like CUDA-graph batch-size buckets; see Phase 5). It is checked only at admission: an already-running request never re-competes for its slot.
- KV memory (via kv.allocate(...)): the hard wall from Phase 2. Unlike the other two, this one can refuse mid-flight (a decode needs one more block and the pool is empty); handling that refusal is preemption, deliberately deferred to lab-04. In this lab, allocation failure during admission simply stops admitting.
Three resources, three different enforcement points. Most scheduler bugs are one resource checked at the wrong point — e.g. counting seqs in the budget loop, or letting an admission overdraw the budget "just this once."
Why running-first is not arbitrary
The loop's order — RUNNING phase, then WAITING phase — encodes a policy with a name: decode-first. The requests already running have users watching tokens stream; a stalled decode is a frozen cursor in somebody's chat window. The waiting requests haven't received anything yet; making them wait one more step costs queueing delay but breaks no stream. So the scheduler protects in-flight experience first and spends whatever budget remains on admissions.
The inverse policy (admit-first) would maximize... nothing useful: it trades visible jitter for marginally earlier admissions. But note the deeper principle, because it generalizes: the loop's iteration order IS the priority policy. Upstream's priority scheduling and preemption-victim selection are both, at bottom, careful answers to "in what order do we iterate, and from which end do we take?"
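To make "iteration order IS the priority policy" concrete, here is a minimal sketch. It is not the lab's code: split_budget and the request ids are hypothetical, and the only assumption is that tokens are granted greedily, in iteration order, until the budget is spent.

```python
def split_budget(demands, budget):
    """Grant tokens to (rid, wanted) pairs in iteration order until budget runs out."""
    grants = {}
    for rid, want in demands:
        n = min(want, budget)   # a request never gets more than what is left
        grants[rid] = n
        budget -= n
    return grants


# Three in-flight decodes (1 token each) and one fresh 20-token prompt.
running = [("dec-a", 1), ("dec-b", 1), ("dec-c", 1)]
waiting = [("new-x", 20)]

decode_first = split_budget(running + waiting, budget=8)
admit_first = split_budget(waiting + running, budget=8)

print(decode_first)  # {'dec-a': 1, 'dec-b': 1, 'dec-c': 1, 'new-x': 5}
print(admit_first)   # {'new-x': 8, 'dec-a': 0, 'dec-b': 0, 'dec-c': 0}
```

Same function, same budget, same requests; only the order of the list changes. Under admit-first every running decode receives 0 tokens this step, which is exactly the frozen cursor described above.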
Files
- starter.py: clamp(...) and schedule_step(...) stubbed, with the full recipe in comments. Ships a tiny self-contained Req and FakeKV (a slot-counting memory model) so the lab isolates pure scheduling logic. Your work.
- solution.py: reference (mirrors mini_vllm/scheduler.py, minus preemption).
- test_lab.py: budget cap, slot cap, chunking, running-first ordering, and memory-stops-admission.
Run
LAB_IMPL=starter pytest phase-03-continuous-batching-scheduler/labs/lab-01-scheduler-step -q
pytest phase-03-continuous-batching-scheduler/labs/lab-01-scheduler-step -q # reference
What schedule_step must do
- budget = max_num_batched_tokens.
- RUNNING phase: for each running req in order: n = clamp(req.num_tokens - req.num_computed, budget, threshold); skip if n == 0; kv.allocate(req, n) (assume it succeeds for running reqs here; preemption is lab-04); commit scheduled[rid] = n, budget -= n.
- WAITING phase: while there are waiters AND budget > 0 AND len(running) < max_num_seqs: take the front waiter (FCFS: order is policy!), clamp the same way, try to allocate; on failure break (if the front request can't fit, don't go shopping deeper in the queue; see the head-of-line note below); on success, move it waiting → running and commit.
- Return {rid: n}.
And clamp(num_new, budget, threshold) is the whole chunking mechanism in one line:
cap by the per-request threshold (if 0 < threshold < num_new), then by the remaining
budget, floored at 0. Notice what isn't here: no "prefill mode," no "decode mode." A
decode is just a request whose num_tokens − num_computed == 1. The two-counters model
from Phase 1 means one code path schedules both — that unification is the deep design,
and it's why this loop stays 30 lines while doing what took Orca a paper to describe.
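The recipe above can be sketched as follows. Treat it as one possible reading rather than the reference solution: the Req and FakeKV classes here are hypothetical stand-ins for the ones the starter ships, and the schedule_step signature (argument names and order, threshold=0 meaning "no per-request cap") is an assumption made for illustration.

```python
from collections import deque
from dataclasses import dataclass


@dataclass
class Req:
    """Hypothetical stand-in for the lab's Req: an id plus the two counters."""
    rid: str
    num_tokens: int        # tokens that must be computed in total so far
    num_computed: int = 0  # tokens already computed


class FakeKV:
    """Hypothetical slot-counting memory model: one slot per token."""

    def __init__(self, capacity):
        self.free = capacity

    def allocate(self, req, n):
        if n > self.free:
            return False
        self.free -= n
        return True


def clamp(num_new, budget, threshold):
    """Cap by the per-request threshold, then by the remaining budget, floored at 0."""
    if 0 < threshold < num_new:
        num_new = threshold
    return max(0, min(num_new, budget))


def schedule_step(running, waiting, kv, max_num_batched_tokens, max_num_seqs, threshold=0):
    budget = max_num_batched_tokens
    scheduled = {}

    # Phase 1: serve the RUNNING requests, in order.
    for req in running:
        n = clamp(req.num_tokens - req.num_computed, budget, threshold)
        if n == 0:
            continue
        kv.allocate(req, n)  # assumed to succeed here; preemption is lab-04
        scheduled[req.rid] = n
        budget -= n

    # Phase 2: admit WAITING requests, front of the queue first (FCFS).
    while waiting and budget > 0 and len(running) < max_num_seqs:
        req = waiting[0]
        n = clamp(req.num_tokens - req.num_computed, budget, threshold)
        if not kv.allocate(req, n):
            break  # head-of-line blocking: do not look deeper in the queue
        waiting.popleft()
        running.append(req)
        scheduled[req.rid] = n
        budget -= n

    return scheduled


# Three 8-token prompts under a 10-token budget.
running = []
waiting = deque([Req("a", 8), Req("b", 8), Req("c", 8)])
scheduled = schedule_step(running, waiting, FakeKV(capacity=1000),
                          max_num_batched_tokens=10, max_num_seqs=8)
print(scheduled)  # {'a': 8, 'b': 2}
```

The sketch does not advance num_computed; it assumes the caller does that after the step runs. Note that nothing in either phase asks whether a request is prefilling or decoding: the subtraction inside clamp is the only place the distinction shows up.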
What the tests prove — a guided tour
- test_clamp_chunks_and_budgets: the clamp's three regimes (budget-bound, threshold-bound, neither). Get this right first; everything else composes it.
- test_budget_caps_total_tokens: three 8-token prompts under a 10-token budget schedule exactly 10 tokens, split as 8 + 2 (the second request's prefill is chunked mid-prompt) + 0 (the third isn't admitted). One assertion, three behaviors.
- test_max_num_seqs_caps_running: ten tiny requests, slots for four, and exactly four are admitted, despite infinite budget and memory. Each scarcity binds independently.
- test_chunked_prefill_caps_per_request: a 100-token prompt with threshold=16 schedules 16, not 100, even with budget to burn. The threshold protects other requests' latency from this request's prompt (lab-05 quantifies exactly how much).
- test_running_scheduled_before_waiting_admitted: the decode-first policy. The running decode gets its 1 token first; the eager 20-token waiter gets what's left, chunked. Order of phases = priority.
- test_admission_stops_when_memory_exhausted: A fills the pool; B stays WAITING. No crash, no partial admission: capacity exhaustion is a normal scheduling outcome, not an error path. (The engine-level consequence, B admitted later when A finishes, is Phase 1 lab-04's trace; the violent version is lab-04's preemption.)
How this maps to the real engine
Open upstream/vllm/v1/core/sched/scheduler.py:329 after you're green. The skeleton is
yours; production adds, in roughly descending order of weight: preemption inside the
RUNNING phase (the while True allocate-or-preempt dance — lab-04); prefix-cache
consultation at admission (get_computed_blocks — lab-06 / Phase 2 lab-05); structured-
output and LoRA gating; speculative-decoding token accounting; and the encoder budget for
multimodal inputs. Every one of those is a guard or a discount on num_new_tokens
inside the same two phases. Once you see the file that way — your 30 lines plus accessory
clauses — it stops being 700 intimidating lines.
Also worth noting upstream: _clamp_new_tokens's real twin is the interaction between
long_prefill_token_threshold and chunked_prefill_enabled in the scheduler config —
chunked prefill is default-on in V1, which tells you how settled this once-controversial
idea now is.
Hitchhiker's notes
- Head-of-line blocking is a choice. When the front waiter doesn't fit, we break rather than trying the next (smaller) one. Skipping ahead would raise utilization and starve large requests: a big prompt could wait forever behind a stream of small ones slipping past it. FCFS-with-blocking is the fairness-conservative default; if you relax it, you must add an aging mechanism. (Upstream has exactly this debate in its issue tracker; worth a read.)
- Why is the budget in tokens, not requests? Because step time scales with tokens through the model, not with request count: a 1-request 2048-token prefill costs about the same as 2048 one-token decodes through the GEMMs (attention differs; Phase 18 refines). Budgeting the actual scarce quantity is what makes the latency dial linear.
- num_computed > 0 for a waiter is not an error: it's a preempted request being re-admitted (lab-04) or a prefix-cache hit (lab-06). Your clamp already handles it: num_tokens − num_computed just comes out smaller. Design observation: by making "partial progress" a first-class state, recovery and caching share the admission path with fresh requests. No special cases.
- The FakeKV is a teaching instrument: one slot per token, no blocks, no hashes, so this lab's failures are always scheduling failures. When you wire the real KVCacheManager in (mini-build), block granularity adds a ceil() but changes no logic.
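The ceil() mentioned in the FakeKV note can be sketched in a few lines. This is an illustration, not the KVCacheManager API: blocks_for and new_blocks are hypothetical helpers, and the block size of 16 tokens is an assumption chosen for the example.

```python
BLOCK_SIZE = 16  # assumed block size, for illustration only


def blocks_for(num_tokens, block_size=BLOCK_SIZE):
    """Blocks needed to hold num_tokens: an integer ceil(num_tokens / block_size)."""
    return (num_tokens + block_size - 1) // block_size


def new_blocks(num_computed, n, block_size=BLOCK_SIZE):
    """Extra blocks needed to grow a request from num_computed to num_computed + n tokens."""
    return blocks_for(num_computed + n, block_size) - blocks_for(num_computed, block_size)


print(new_blocks(0, 100))  # 7: a fresh 100-token prefill
print(new_blocks(16, 1))   # 1: a decode that crosses a block boundary
print(new_blocks(17, 1))   # 0: a decode that fits in the current block
```

With one slot per token, every scheduled token costs memory. With blocks, most decodes cost nothing and an occasional one costs a whole block, which is why a running request can be refused mid-flight only at a block boundary.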
Going further
- Add priority classes: each Req gets priority: int; iterate waiting in priority order with FCFS tiebreak. Then write the test proving a late high-priority request overtakes the queue without stalling running decodes. You've just implemented the core of upstream's priority scheduling policy.
- Add the fully-cached edge case: if num_tokens − num_computed == 0 for a waiter (prefix cache covered everything it can), schedule 1 token anyway. Why must it be ≥ 1? (A request that schedules 0 tokens never produces logits, never samples, never finishes: an admission that can't make progress. mini_vllm/scheduler.py has this exact branch; lab-06 will show you the 1-token prefills it produces in a trace.)
- Make the budget elastic: allow one oversized decode batch when waiting is empty. Measure (with Phase 1 lab-04's probe) what it does to step-time variance. Most "clever" scheduler ideas die in exactly this experiment: cheap to run here, expensive to learn in production.
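As a starting point for the priority-classes exercise, the ordering itself might look like the sketch below. It covers only the sort, not the scheduler change or the test. The arrival field, the admission_order helper, and the convention that a lower priority value is served earlier are all assumptions made for illustration.

```python
from dataclasses import dataclass


@dataclass
class Req:
    """Hypothetical request with a priority class and an arrival counter."""
    rid: str
    priority: int  # assumption: lower value means served earlier
    arrival: int   # monotonically increasing arrival order, used as the FCFS tiebreak


def admission_order(waiting):
    """Order waiters by priority, breaking ties by arrival (FCFS within a class)."""
    return sorted(waiting, key=lambda r: (r.priority, r.arrival))


waiting = [Req("a", priority=1, arrival=0),
           Req("b", priority=1, arrival=1),
           Req("vip", priority=0, arrival=2)]

print([r.rid for r in admission_order(waiting)])  # ['vip', 'a', 'b']
```

Because only the WAITING phase consumes this order, the RUNNING phase is untouched and in-flight decodes keep their place ahead of every admission, whatever its priority.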
References
- mini_vllm/scheduler.py: the full version (with preemption + prefix caching) your solution grows into.
- upstream/vllm/v1/core/sched/scheduler.py:329 (Scheduler.schedule): the production loop; read it immediately after finishing.
- Yu et al., Orca (OSDI 2022): iteration-level scheduling, this loop's ancestor: https://www.usenix.org/conference/osdi22/presentation/yu
- Agrawal et al., Sarathi-Serve (OSDI 2024): why the chunk threshold exists; the prefill/decode interference math: https://arxiv.org/abs/2403.02310
- vLLM docs, Optimization and Tuning: max_num_batched_tokens / max_num_seqs guidance straight from the maintainers: https://docs.vllm.ai/en/latest/configuration/optimization.html