Phase 01 — Deep Dive: tracing a request through real vLLM
Paths relative to upstream/ at v0.22.1 @ 0decac0. We follow one request from LLM.generate to tokens out, naming every file. Keep mini_vllm/engine.py open alongside: it's the same control flow, miniature.
Contents
- 1. The offline entry point: LLM.generate
- 2. The heartbeat: EngineCore.step
- 3. Down to the metal: Executor → Worker → ModelRunner
- 4. The async path (serving)
- 5. The output path
- The whole journey, named
- Reading checklist
1. The offline entry point: LLM.generate
vllm/entrypoints/llm.py: class LLM (:66), def generate (:422). generate validates
inputs, builds requests, adds them to the engine, and runs the engine to completion, collecting
RequestOutputs. Under the hood it drives an LLMEngine.
vllm/v1/engine/llm_engine.py: class LLMEngine (:47) with add_request (:209) and step
(:287). This is the synchronous wrapper: add_request tokenizes + enqueues; step pumps the
core once and returns finished RequestOutputs. mini_vllm.LLMEngine.{add_request,step,generate}
mirror these one-to-one.
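The add_request / step / generate split can be sketched in a few lines. This is a toy under stated assumptions, not vLLM's or mini_vllm's code: ToyEngine and ToyRequest are hypothetical names, "tokenization" is one id per character, and the "model" simply emits the previous id plus one.

```python
from collections import deque
from dataclasses import dataclass, field


@dataclass
class ToyRequest:
    request_id: str
    prompt_ids: list[int]
    max_tokens: int
    output_ids: list[int] = field(default_factory=list)


class ToyEngine:
    """Hypothetical stand-in for the add_request / step / generate trio."""

    def __init__(self) -> None:
        self.running: deque[ToyRequest] = deque()
        self._next_id = 0

    def add_request(self, prompt: str, max_tokens: int) -> str:
        # "Tokenize" (assumption: one id per character) and enqueue.
        request_id = str(self._next_id)
        self._next_id += 1
        prompt_ids = [ord(ch) for ch in prompt]
        self.running.append(ToyRequest(request_id, prompt_ids, max_tokens))
        return request_id

    def has_unfinished_requests(self) -> bool:
        return bool(self.running)

    def step(self) -> list[ToyRequest]:
        # Pump once: every running request gains exactly one token.
        finished = []
        for _ in range(len(self.running)):
            req = self.running.popleft()
            history = req.prompt_ids + req.output_ids
            last = history[-1] if history else 0
            req.output_ids.append(last + 1)  # fake "model": next id = last + 1
            if len(req.output_ids) >= req.max_tokens:
                finished.append(req)
            else:
                self.running.append(req)
        return finished

    def generate(self, prompts: list[str], max_tokens: int) -> list[ToyRequest]:
        # Add everything, then pump until the engine drains.
        order = [self.add_request(p, max_tokens) for p in prompts]
        done: dict[str, ToyRequest] = {}
        while self.has_unfinished_requests():
            for req in self.step():
                done[req.request_id] = req
        return [done[rid] for rid in order]  # results in submission order
```

Calling ToyEngine().generate(["a", "xy"], max_tokens=3) runs three steps and returns both requests in submission order, which is the same shape as the offline path described above: enqueue, loop on step, collect what finishes.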
2. The heartbeat: EngineCore.step
vllm/v1/engine/core.py:428 (you read this in Phase 00 — revisit with the architecture in mind):
    def step(self):
        if not self.scheduler.has_requests():
            return {}, False
        scheduler_output = self.scheduler.schedule()                          # 1. who runs (Ph 3)
        future = self.model_executor.execute_model(scheduler_output, ...)     # 2. run model (Ph 4–14)
        grammar_output = self.scheduler.get_grammar_bitmask(scheduler_output)
        model_output = future.result()
        if model_output is None:
            model_output = self.model_executor.sample_tokens(grammar_output)  # 3. sample (Ph 9)
        engine_core_outputs = self.scheduler.update_from_output(              # 4. advance (Ph 3)
            scheduler_output, model_output)
        return engine_core_outputs, scheduler_output.total_num_scheduled_tokens > 0
add_request is at core.py:337: it wraps the incoming EngineCoreRequest into a Request and
hands it to self.scheduler.add_request. Note that EngineCore is also subclassed by EngineCoreProc
(:835), the version that runs in its own process and receives requests over a queue. That's
the process split from the guide.
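The "receives requests over a queue" idea can be sketched as a loop that drains an inbox and steps whenever work is pending. This is a rough illustration only: a thread and queue.Queue stand in for the separate process and its IPC, and core_busy_loop, SHUTDOWN, and the (request_id, num_tokens) message shape are hypothetical, not vLLM names.

```python
import queue
import threading

SHUTDOWN = object()  # hypothetical sentinel, not a vLLM constant


def core_busy_loop(input_q: queue.Queue, output_q: queue.Queue) -> None:
    """Toy core loop: drain new requests, then take one step over pending work.

    Assumes every request asks for at least one token.
    """
    pending: dict[str, int] = {}  # request_id -> tokens still to produce
    while True:
        # Block only when idle; otherwise poll so stepping keeps going.
        try:
            msg = input_q.get(block=not pending)
        except queue.Empty:
            msg = None
        if msg is SHUTDOWN:
            return
        if msg is not None:
            request_id, num_tokens = msg
            pending[request_id] = num_tokens
            continue  # keep draining the inbox before stepping
        # One "step": every pending request emits one token.
        for request_id in list(pending):
            pending[request_id] -= 1
            finished = pending[request_id] == 0
            output_q.put((request_id, finished))
            if finished:
                del pending[request_id]


if __name__ == "__main__":
    in_q: queue.Queue = queue.Queue()
    out_q: queue.Queue = queue.Queue()
    worker = threading.Thread(target=core_busy_loop, args=(in_q, out_q))
    worker.start()
    in_q.put(("req-0", 2))
    print(out_q.get(), out_q.get())
    in_q.put(SHUTDOWN)
    worker.join()
```

The design point is that the caller never calls step directly: it only puts requests on one queue and reads outputs from another, so the loop can keep the device busy regardless of what the front end is doing.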
3. Down to the metal: Executor → Worker → ModelRunner
self.model_executor is an Executor (vllm/v1/executor/abstract.py defines the interface). For
single-GPU it's a UniProcExecutor; for multi-GPU a MultiProcExecutor (multiproc_executor.py,
Phase 10). execute_model(scheduler_output) forwards to the worker(s).
vllm/v1/worker/gpu_worker.py — class Worker: owns the device, the model, and the KV cache for
one GPU. Its execute_model calls into the model runner.
vllm/v1/worker/gpu_model_runner.py — GPUModelRunner.execute_model is where SchedulerOutput
becomes reality: it gathers the scheduled tokens into input tensors, builds attention metadata
(block tables + sequence lengths from Phase 2/3), runs the (possibly CUDA-graphed, Phase 5)
forward pass, and runs the sampler. Search it for execute_model and _prepare_inputs. This is
the single busiest file in the engine — you'll return to it in Phases 4, 5, 9, 13.
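To make "gathers the scheduled tokens into input tensors" concrete, here is a minimal sketch of batch flattening. It is not the real _prepare_inputs: plain Python lists stand in for tensors, block tables are omitted, and prepare_inputs and ToyRequestState are hypothetical names chosen for illustration.

```python
from dataclasses import dataclass
from itertools import accumulate


@dataclass
class ToyRequestState:
    token_ids: list[int]      # prompt + generated tokens so far
    num_computed_tokens: int  # tokens whose KV entries are already cached


def prepare_inputs(
    num_scheduled_tokens: dict[str, int],
    states: dict[str, ToyRequestState],
) -> tuple[list[int], list[int], list[int]]:
    """Flatten a mixed batch into (input_ids, positions, query_start_loc)."""
    input_ids: list[int] = []
    positions: list[int] = []
    query_lens: list[int] = []
    for request_id, n in num_scheduled_tokens.items():
        state = states[request_id]
        start = state.num_computed_tokens
        # Only the not-yet-computed slice of each request enters the batch.
        input_ids.extend(state.token_ids[start:start + n])
        positions.extend(range(start, start + n))
        query_lens.append(n)
    # Prefix sums mark where each request's slice starts in the flat batch.
    query_start_loc = [0, *accumulate(query_lens)]
    return input_ids, positions, query_start_loc


if __name__ == "__main__":
    states = {
        "prefill": ToyRequestState(token_ids=[10, 11, 12, 13], num_computed_tokens=0),
        "decode": ToyRequestState(token_ids=[20, 21, 22], num_computed_tokens=2),
    }
    print(prepare_inputs({"prefill": 4, "decode": 1}, states))
```

A prefill request contributes its whole prompt while a decode request contributes a single token, yet both land in one flat batch; the start offsets are what let attention tell the requests apart.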
4. The async path (serving)
vllm/v1/engine/async_llm.py: class AsyncLLM. The OpenAI server (Phase 16) calls
AsyncLLM.generate, an async generator that yields RequestOutput deltas as they're produced.
Internally it talks to the EngineCoreProc over IPC and runs the output processing/detokenization
on the server side, off the core's hot path. Same core, async shell.
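The "async generator that yields deltas" shape looks roughly like the sketch below. It is a simplified illustration, not AsyncLLM's implementation: asyncio.Queue replaces the IPC channel, toy_core is a hypothetical stand-in for the core process, and the function assumes max_tokens is at least 1.

```python
import asyncio
from collections.abc import AsyncIterator


async def toy_core(request_q: asyncio.Queue, output_q: asyncio.Queue) -> None:
    """Stand-in for the core: produce one fake token per step for one request."""
    prompt, max_tokens = await request_q.get()
    for i in range(max_tokens):
        await asyncio.sleep(0)  # yield control, as a real step would take time
        await output_q.put((f"{prompt}-tok{i}", i == max_tokens - 1))


async def generate(prompt: str, max_tokens: int) -> AsyncIterator[str]:
    """Submit the request, then yield each delta as soon as it arrives."""
    request_q: asyncio.Queue = asyncio.Queue()
    output_q: asyncio.Queue = asyncio.Queue()
    core_task = asyncio.create_task(toy_core(request_q, output_q))
    await request_q.put((prompt, max_tokens))
    try:
        while True:
            delta, finished = await output_q.get()
            yield delta
            if finished:
                break
    finally:
        core_task.cancel()  # no-op if the core already finished


async def main() -> None:
    async for delta in generate("hi", 3):
        print(delta)


if __name__ == "__main__":
    asyncio.run(main())
```

The caller consumes results with async for, so a server handler can forward each delta to the client while the core keeps stepping.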
5. The output path
vllm/v1/engine/output_processor.py + detokenizer.py: turn the core's sampled token ids back
into text, handle stop strings, and assemble RequestOutputs (streaming deltas for the server).
mini_vllm folds this into engine.generate (decode at the end) — simpler, same idea.
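Streaming detokenization with stop strings has one subtle point: a stop string can be split across tokens, so the tail of the text has to be held back until it is known to be safe. The sketch below shows one way to do that. ToyDetokenizer, its toy vocabulary, and the hold-back rule are assumptions made for illustration and are not taken from output_processor.py or detokenizer.py; it also assumes stop strings are non-empty.

```python
class ToyDetokenizer:
    """Hypothetical incremental detokenizer that streams text deltas."""

    def __init__(self, vocab: dict[int, str], stop: list[str]) -> None:
        self.vocab = vocab
        self.stop = stop
        self.text = ""        # full decoded text so far
        self.emitted = 0      # number of characters already streamed
        self.finished = False
        # Hold back enough characters that a stop string split across
        # tokens is never partially streamed.
        self.holdback = max((len(s) for s in stop), default=1) - 1

    def update(self, token_ids: list[int]) -> str:
        """Consume new token ids and return the next streamable delta."""
        if self.finished:
            return ""
        # A new match can only begin inside the held-back tail or later.
        search_from = max(0, len(self.text) - self.holdback)
        self.text += "".join(self.vocab[t] for t in token_ids)
        hits = [self.text.find(s, search_from) for s in self.stop]
        hits = [i for i in hits if i != -1]
        if hits:
            # Drop the earliest stop string and everything after it.
            self.text = self.text[:min(hits)]
            self.finished = True
            end = len(self.text)
        else:
            end = max(self.emitted, len(self.text) - self.holdback)
        delta = self.text[self.emitted:end]
        self.emitted = end
        return delta

    def finish(self) -> str:
        """Flush the held-back tail when the request ends without a stop hit."""
        delta = self.text[self.emitted:]
        self.emitted = len(self.text)
        self.finished = True
        return delta


if __name__ == "__main__":
    vocab = {0: "Hel", 1: "lo", 2: " wor", 3: "ld", 4: "<E", 5: "ND>", 6: "!"}
    detok = ToyDetokenizer(vocab, stop=["<END>"])
    for step_ids in ([0], [1], [2, 3], [4], [5, 6]):
        print(repr(detok.update(step_ids)))
```

The deltas lag the sampled tokens by a few characters, which is the price of never streaming a fragment of a stop string that later has to be retracted.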
The whole journey, named
LLM.generate (llm.py:422)
 └─ LLMEngine.add_request (llm_engine.py:209) -> EngineCore.add_request (core.py:337)
 └─ loop LLMEngine.step (llm_engine.py:287) -> EngineCore.step (core.py:428):
      scheduler.schedule()            (sched/scheduler.py:329)                   Phase 3
      executor.execute_model()        (executor/ -> worker/gpu_model_runner.py)  Phase 4-14
      executor.sample_tokens()        (sample/sampler.py)                        Phase 9
      scheduler.update_from_output()  (sched/scheduler.py:1283)                  Phase 3
 └─ output_processor/detokenizer -> RequestOutput
Reading checklist
- LLM.generate → which engine method adds requests, which pumps the loop?
- EngineCore.step → recite the four stages and the file each lives in.
- Executor vs Worker vs ModelRunner → who owns the GPU, who builds tensors?
- Why does EngineCoreProc exist (the process split)?
- Where does detokenization happen, and why off the core's hot path for serving?
Now build it: 02-mini-build.md, then the labs.