Phase 03 — Deep Dive: the real vLLM Scheduler
Paths are relative to upstream/ at v0.22.1 @ 0decac0 (UPSTREAM_PIN.md). The scheduler is vllm/v1/core/sched/scheduler.py (~2,300 lines). We read the parts that matter; the rest is connectors, encoders, spec-decode glue, and stats. Return to those after Phases 8, 13, 15.
Supporting files:
vllm/v1/core/sched/
  scheduler.py       Scheduler.schedule() / update_from_output()            (the brain)
  output.py          SchedulerOutput, NewRequestData, CachedRequestData     (the wire format)
  request_queue.py   FCFS vs PRIORITY queues                                (ordering policy)
  interface.py       SchedulerInterface                                     (the contract)
vllm/v1/request.py   Request, RequestStatus                                 (the unit of work)
Contents
- 1. The unit of work: Request and its states
- 2. schedule() — the whole algorithm
- 3. The output: SchedulerOutput
- 4. The other half: update_from_output
- 5. Putting Phases 02 + 03 together
- Reading checklist
1. The unit of work: Request and its states
vllm/v1/request.py:315, RequestStatus:
class RequestStatus(enum.IntEnum):
    WAITING = enum.auto()
    WAITING_FOR_STRUCTURED_OUTPUT_GRAMMAR = enum.auto()
    WAITING_FOR_REMOTE_KVS = enum.auto()
    WAITING_FOR_STREAMING_REQ = enum.auto()
    RUNNING = enum.auto()
    PREEMPTED = enum.auto()
    # Note: anything after PREEMPTED will be considered as a finished status.
    FINISHED_STOPPED = enum.auto()
    FINISHED_LENGTH_CAPPED = enum.auto()
    FINISHED_ABORTED = enum.auto()
    ...
Two things to internalize:
- The extra WAITING_FOR_* states exist because a request can be not ready for reasons beyond "queued": waiting on a grammar to compile (Phase 12), on remote KV to arrive (Phase 15), etc. Your mini_vllm.RequestStatus keeps just WAITING / RUNNING / PREEMPTED / FINISHED_*, the essential skeleton.
- The ordering trick: is_finished is simply status > PREEMPTED (line 337). Enum order is the logic. mini_vllm copies this (is_finished = status >= FINISHED_STOPPED).
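The ordering trick is easy to try in isolation. Here is a minimal sketch, assuming a trimmed-down enum; the names MiniRequestStatus and is_finished are illustrative and are not taken from vLLM or mini_vllm:

```python
import enum


class MiniRequestStatus(enum.IntEnum):
    """Hypothetical trimmed-down status enum; declaration order carries meaning."""

    WAITING = enum.auto()
    RUNNING = enum.auto()
    PREEMPTED = enum.auto()
    # Everything declared after PREEMPTED counts as finished.
    FINISHED_STOPPED = enum.auto()
    FINISHED_LENGTH_CAPPED = enum.auto()
    FINISHED_ABORTED = enum.auto()


def is_finished(status: MiniRequestStatus) -> bool:
    # IntEnum members compare as integers, so a single comparison
    # covers every FINISHED_* state.
    return status > MiniRequestStatus.PREEMPTED


print([s.name for s in MiniRequestStatus if is_finished(s)])
# ['FINISHED_STOPPED', 'FINISHED_LENGTH_CAPPED', 'FINISHED_ABORTED']
```

The price of this design is that inserting a new state in the wrong position silently changes what "finished" means, which is presumably why the real enum carries a comment at the boundary.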
The master variables on Request: num_computed_tokens vs num_tokens (and
num_tokens_with_spec for speculative decoding). Everything in schedule() manipulates these.
2. schedule() — the whole algorithm
vllm/v1/core/sched/scheduler.py:329. The defining comment (lines 330–339) — read it; it's the
mental model from the guide, verbatim from the maintainers.
Setup (lines 341–362)
scheduled_new_reqs, scheduled_resumed_reqs = [], []
scheduled_running_reqs, preempted_reqs = [], []
req_to_new_blocks: dict[str, KVCacheBlocks] = {}
num_scheduled_tokens: dict[str, int] = {}
token_budget = self.max_num_scheduled_tokens # <- the per-step token budget
...
self.kv_cache_manager.new_step_starts()
token_budget is max_num_scheduled_tokens (derived from max_num_batched_tokens). This is
the global cap that makes chunked prefill work. mini_vllm: token_budget = self.max_num_batched_tokens.
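To see why a single per-step budget yields chunked prefill, here is a toy loop that feeds one long prompt through a fixed budget. It is not vLLM code: the function name and the numbers are made up for illustration, and it ignores every other request competing for the same budget.

```python
def prefill_chunks(prompt_len: int, token_budget: int) -> list[int]:
    """Return the number of prompt tokens processed in each step (toy model)."""
    if token_budget <= 0:
        raise ValueError("token_budget must be positive")
    chunks = []
    num_computed_tokens = 0
    while num_computed_tokens < prompt_len:
        # Each step may process at most token_budget of the remaining tokens.
        step = min(prompt_len - num_computed_tokens, token_budget)
        chunks.append(step)
        num_computed_tokens += step
    return chunks


print(prefill_chunks(prompt_len=10, token_budget=4))  # [4, 4, 2]
```

No step ever exceeds the budget, so a long prompt is split across steps without any dedicated "chunking" code path.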
Phase A — schedule RUNNING requests (lines 364–533)
req_index = 0
while req_index < len(self.running) and token_budget > 0:
    request = self.running[req_index]
    ...
    num_new_tokens = (
        request.num_tokens_with_spec
        + request.num_output_placeholders
        - request.num_computed_tokens
    )
    if 0 < self.scheduler_config.long_prefill_token_threshold < num_new_tokens:
        num_new_tokens = self.scheduler_config.long_prefill_token_threshold  # chunk long prefills
    num_new_tokens = min(num_new_tokens, token_budget)  # respect the budget
    num_new_tokens = min(num_new_tokens, self.max_model_len - 1 - request.num_computed_tokens)
num_new_tokens is how far this request is behind, clamped by (a) the long-prefill chunk
threshold, (b) the remaining token budget, and (c) the model length. This four-line clamp is
exactly your mini_vllm.Scheduler._clamp_new_tokens (minus the spec/placeholder terms). Note that
num_tokens_with_spec includes draft tokens: that is how speculative decoding (Phase 8) rides
the same scheduler with no special case, just as the top comment promised.
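The clamp can be exercised as a standalone function. The sketch below mirrors the excerpt above without the spec-decode and placeholder terms. The function name and signature are assumptions made for this example, not the actual mini_vllm or vLLM code.

```python
def clamp_new_tokens(
    num_tokens: int,
    num_computed_tokens: int,
    token_budget: int,
    max_model_len: int,
    long_prefill_token_threshold: int = 0,
) -> int:
    """Toy version of the clamp (no spec-decode or placeholder terms)."""
    # How far the request is behind.
    num_new_tokens = num_tokens - num_computed_tokens
    # (a) Chunk long prefills; a threshold of 0 disables this cap.
    if 0 < long_prefill_token_threshold < num_new_tokens:
        num_new_tokens = long_prefill_token_threshold
    # (b) Respect what is left of the per-step budget.
    num_new_tokens = min(num_new_tokens, token_budget)
    # (c) Never run past the model's maximum length.
    num_new_tokens = min(num_new_tokens, max_model_len - 1 - num_computed_tokens)
    return num_new_tokens


print(clamp_new_tokens(1000, 0, 512, 4096, 256))  # 256: threshold wins
print(clamp_new_tokens(1000, 0, 512, 4096))       # 512: budget wins
print(clamp_new_tokens(101, 100, 512, 4096))      # 1: an ordinary decode step
print(clamp_new_tokens(4100, 4090, 512, 4096))    # 5: model length wins
```

A decode step is just the case where the request is one token behind, which is why the same arithmetic serves prefill and decode.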
The preemption loop (lines 442–491) — the heart
with record_function_or_nullcontext("schedule: allocate_slots"):
    while True:
        new_blocks = self.kv_cache_manager.allocate_slots(
            request, num_new_tokens, num_lookahead_tokens=self.num_lookahead_tokens,
        )
        if new_blocks is not None:
            break  # got memory; schedule it
        # The request cannot be scheduled. Preempt the lowest-priority request.
        if self.policy == SchedulingPolicy.PRIORITY:
            preempted_req = max(self.running, key=lambda r: (r.priority, r.arrival_time))
            self.running.remove(preempted_req)
            ...
        else:
            preempted_req = self.running.pop()  # FCFS: preempt the most-recent
        self._preempt_request(preempted_req, scheduled_timestamp)
        preempted_reqs.append(preempted_req)
        if preempted_req == request:
            break  # nothing left to preempt; give up this req
if new_blocks is None:
    break
This is the None → preempt → retry handshake with the KV manager (Phase 02 §5). Under
FCFS it preempts self.running.pop() — the most recently admitted, i.e. lowest priority by
arrival. Under PRIORITY it preempts the worst (priority, arrival_time). mini_vllm implements
the FCFS branch (self.running.pop() + _preempt) — the PRIORITY branch is a great extension
exercise.
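The handshake and both victim-selection rules fit in a small standalone sketch. Everything here is a toy: Req, ToyPool, pick_victim and allocate_or_preempt are hypothetical names, the pool only counts blocks, and the sketch assumes the request being scheduled is itself in the running list (as it is in Phase A).

```python
from dataclasses import dataclass
from typing import Optional


@dataclass(eq=False)  # compare by identity, like distinct request objects
class Req:
    req_id: str
    priority: int = 0        # larger value = less important, as max() in the excerpt implies
    arrival_time: float = 0.0
    num_blocks: int = 0      # KV blocks currently held


class ToyPool:
    """Hypothetical stand-in for the KV manager: it only counts free blocks."""

    def __init__(self, free_blocks: int) -> None:
        self.free_blocks = free_blocks

    def allocate(self, req: Req, n: int) -> Optional[int]:
        if n > self.free_blocks:
            return None          # the out-of-memory signal
        self.free_blocks -= n
        req.num_blocks += n
        return n

    def free(self, req: Req) -> None:
        self.free_blocks += req.num_blocks
        req.num_blocks = 0


def pick_victim(running: list[Req], use_priority: bool) -> Req:
    if use_priority:
        victim = max(running, key=lambda r: (r.priority, r.arrival_time))
        running.remove(victim)
        return victim
    return running.pop()         # FCFS: the most recently admitted request


def allocate_or_preempt(pool, running, request, n, use_priority=False):
    """Return (allocated, preempted_ids). Assumes `request` is in `running`."""
    preempted = []
    while True:
        if pool.allocate(request, n) is not None:
            return True, preempted
        victim = pick_victim(running, use_priority)
        pool.free(victim)
        preempted.append(victim.req_id)
        if victim is request:
            return False, preempted  # nothing left to preempt


a, b, c = Req("a", 0, 0.0, 2), Req("b", 5, 1.0, 2), Req("c", 1, 2.0, 1)
print(allocate_or_preempt(ToyPool(1), [a, b, c], a, 2))  # (True, ['c'])
```

With use_priority=True and the same starting state, the victim would be "b" instead, because it has the largest (priority, arrival_time) pair.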
_preempt_request (line 929) frees the KV and resets the request to be recomputed. Compare
mini_vllm.Scheduler._preempt: frees KV, num_computed_tokens = 0, status PREEMPTED, back to
the front of waiting.
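The four effects listed for mini_vllm.Scheduler._preempt can be written down directly. This is a sketch of the described behaviour, not the actual mini_vllm source: ToyRequest, the string status, and the free_kv callback are assumptions made for the example.

```python
from collections import deque
from dataclasses import dataclass
from typing import Callable


@dataclass
class ToyRequest:
    req_id: str
    num_computed_tokens: int = 0
    status: str = "RUNNING"


def preempt(
    request: ToyRequest,
    waiting: "deque[ToyRequest]",
    free_kv: Callable[[ToyRequest], None],
) -> None:
    free_kv(request)                 # give the KV blocks back
    request.num_computed_tokens = 0  # the work will be recomputed
    request.status = "PREEMPTED"
    waiting.appendleft(request)      # front of the queue: resumes before newer arrivals


freed = []
waiting = deque([ToyRequest("new", status="WAITING")])
victim = ToyRequest("old", num_computed_tokens=42)
preempt(victim, waiting, freed.append)
print([r.req_id for r in waiting], victim.num_computed_tokens)  # ['old', 'new'] 0
```

Putting the victim at the front of the queue keeps arrival order intact: a preempted request is older than anything that arrived while it was running.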
Commit the scheduled running request (lines 493–533)
scheduled_running_reqs.append(request)
req_to_new_blocks[request_id] = new_blocks
num_scheduled_tokens[request_id] = num_new_tokens
token_budget -= num_new_tokens # <- budget bookkeeping
req_index += 1
# ... spec-decode + encoder bookkeeping ...
Phase B — admit WAITING requests (lines 544–...)
if not preempted_reqs and self._pause_state == PauseState.UNPAUSED:
    while (self.waiting or self.skipped_waiting) and token_budget > 0:
        if len(self.running) == self.max_num_running_reqs:
            break
        ...
        request = request_queue.peek_request()
        ...
        # Get already-cached tokens.
        if request.num_computed_tokens == 0:
            new_computed_blocks, num_new_local_computed_tokens = (
                self.kv_cache_manager.get_computed_blocks(request)  # <- prefix caching!
            )
        ...
Three gates before admitting anyone (mirrored in mini_vllm):
- if not preempted_reqs: don't admit new work in a step where we had to preempt (memory pressure). (mini_vllm: and not out.preempted_req_ids.)
- token_budget > 0: there is budget left.
- if len(self.running) == self.max_num_running_reqs: break (the seq-slot cap, max_num_seqs).
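Folded into one predicate, the three gates look like this. The function and its argument names are illustrative only, and the sketch leaves out the pause-state check that appears in the real condition.

```python
def can_admit(
    num_preempted: int,
    token_budget: int,
    num_running: int,
    max_num_running_reqs: int,
) -> bool:
    """Toy check of the three admission gates."""
    if num_preempted > 0:
        return False  # gate 1: memory pressure in this step
    if token_budget <= 0:
        return False  # gate 2: no budget left
    # Gate 3: sequence-slot cap. The excerpt tests ==; >= is a
    # defensive variant chosen for this sketch.
    if num_running >= max_num_running_reqs:
        return False
    return True


print(can_admit(num_preempted=0, token_budget=10, num_running=2, max_num_running_reqs=4))  # True
print(can_admit(num_preempted=1, token_budget=10, num_running=2, max_num_running_reqs=4))  # False
```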
Then get_computed_blocks(request) is the prefix-cache head start (Phase 02 §5, guide §4):
the request adopts the cached prefix and only prefills the remainder. The LoRA constraint just
below (lines 573–584) caps distinct adapters per step (max_loras, Phase 11) — another feature
riding the scheduler.
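The head start is simple arithmetic once you know how many leading tokens are already cached. The sketch below is a deliberately naive model and not how the KV manager is implemented: it keys each full block by the entire token prefix up to that block's end, so a hit implies every earlier block matched too. The function name and the set-of-tuples cache are assumptions for illustration.

```python
def cached_prefix_tokens(
    prompt: list[int],
    cached_blocks: set[tuple[int, ...]],
    block_size: int,
) -> int:
    """Count prompt tokens covered by consecutive cached full blocks (toy model)."""
    num_cached = 0
    for start in range(0, len(prompt) - block_size + 1, block_size):
        # The key is the whole prefix ending at this block.
        key = tuple(prompt[: start + block_size])
        if key not in cached_blocks:
            break  # the first miss ends the reusable prefix
        num_cached += block_size
    return num_cached


prompt = list(range(1, 11))                       # 10 tokens
cache = {tuple(prompt[:4]), tuple(prompt[:8])}    # two full blocks of size 4
hit = cached_prefix_tokens(prompt, cache, block_size=4)
print(hit, len(prompt) - hit)  # 8 2
```

In this example, admission starts the request with 8 tokens already accounted for, so only 2 tokens count against the step's token budget.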
3. The output: SchedulerOutput
vllm/v1/core/sched/output.py:181. What the scheduler hands the executor:
@dataclass
class SchedulerOutput:
    scheduled_new_reqs: list[NewRequestData]     # first-time-scheduled (full payload)
    scheduled_cached_reqs: CachedRequestData     # already-running (just deltas)
    num_scheduled_tokens: dict[str, int]         # req_id -> tokens this step
    total_num_scheduled_tokens: int
    scheduled_spec_decode_tokens: dict[str, list[int]]
    scheduled_encoder_inputs: dict[str, list[int]]
    num_common_prefix_blocks: list[int]
    finished_req_ids: set[str]
    ...
The split between NewRequestData (line 31 — full prompt, block_ids, sampling params) and
CachedRequestData (line 112 — just new tokens + new block ids) is a real optimization: for a
request already running, you don't resend the prompt every step, only the delta. mini_vllm
simplifies this to one num_scheduled_tokens dict + the request objects, but the idea — send
new requests in full, running requests as deltas — is worth knowing.
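The full-versus-delta idea can be shown with two toy payload types. These are not the real NewRequestData and CachedRequestData classes: the fields are reduced to the bare minimum, and build_payload and the sent bookkeeping dict are hypothetical helpers invented for this sketch.

```python
from dataclasses import dataclass


@dataclass
class ToyNewRequestData:
    """Full payload, sent the first time a request is scheduled."""
    req_id: str
    prompt_token_ids: list[int]
    block_ids: list[int]


@dataclass
class ToyCachedRequestData:
    """Delta payload for a request the worker already knows about."""
    req_id: str
    new_token_ids: list[int]
    new_block_ids: list[int]


def build_payload(req_id, token_ids, block_ids, sent):
    """`sent` maps req_id -> (tokens already sent, blocks already sent)."""
    if req_id not in sent:
        payload = ToyNewRequestData(req_id, list(token_ids), list(block_ids))
    else:
        n_tok, n_blk = sent[req_id]
        # Only what was appended since the previous step goes over the wire.
        payload = ToyCachedRequestData(req_id, token_ids[n_tok:], block_ids[n_blk:])
    sent[req_id] = (len(token_ids), len(block_ids))
    return payload


sent = {}
print(build_payload("r1", [1, 2, 3], [7], sent))
print(build_payload("r1", [1, 2, 3, 4], [7], sent))
```

The second call produces a ToyCachedRequestData carrying one new token and no new blocks, regardless of how long the prompt was.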
4. The other half: update_from_output
vllm/v1/core/sched/scheduler.py:1283. After the model runs and the sampler produces tokens,
the scheduler ingests the results: append sampled tokens, advance num_computed_tokens, detect
finished requests, free their KV, handle spec-decode acceptance/rejection, emit stats. Your
mini_vllm.Scheduler.update_from_output is the skeleton: num_computed_tokens += n; if a token
was sampled, append it and check stop conditions; reap finished requests (free KV, drop from
running).
The condition for "did this request emit a token this step" in mini_vllm is
needs_sample = (num_computed_tokens + num_scheduled == num_tokens) — only fully-caught-up
(prefill-complete) requests sample. The real engine encodes the same thing through the model
runner's logits-indices selection; the principle is identical (you only sample at the last
position of a request that has no more prompt to ingest).
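Tracing the needs_sample condition across a chunked prefill makes the rule concrete. The helper below restates the expression above as a function; the 10-token prompt and 4-token chunks are invented for the walk-through.

```python
def needs_sample(num_computed_tokens: int, num_scheduled: int, num_tokens: int) -> bool:
    # A token is sampled only when this step consumes the last known token.
    return num_computed_tokens + num_scheduled == num_tokens


# A 10-token prompt prefilled in chunks of 4, 4, 2, followed by one decode step.
print(needs_sample(0, 4, 10))   # False: prompt tokens remain
print(needs_sample(4, 4, 10))   # False: prompt tokens remain
print(needs_sample(8, 2, 10))   # True: the last prefill chunk samples the first output token
print(needs_sample(10, 1, 11))  # True: every decode step samples
```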
5. Putting Phases 02 + 03 together
The clean separation you should now see:
Scheduler (policy: who runs, how many tokens) ──calls──► KVCacheManager (truth: is there memory?)
    ▲                                                       │
    └────────────── None ◄── allocate_slots ◄───────────────┘  (OOM signal)
    │
    └─ responds: preempt a running request, free its KV, retry
The scheduler never touches blocks directly; the KV manager never decides policy. That clean seam is why each file stays readable despite the engine's complexity — and it's a design lesson worth stealing for your own systems.
Reading checklist
Write one sentence each in your notebook:
- The top comment of schedule(): restate the "no prefill/decode phase" idea in your words.
- The 4-line num_new_tokens clamp: what are the three caps and why each?
- The while True preemption loop: what does allocate_slots returning None trigger?
- FCFS vs PRIORITY preemption victim selection: who gets preempted in each?
- The three gates before admitting WAITING requests.
- get_computed_blocks in Phase B: how does prefix caching give a free head start?
- NewRequestData vs CachedRequestData: why send deltas for running requests?
Now build it: 02-mini-build.md, then the labs.