Phase 05 — Deep Dive: CUDA Graphs & torch.compile in real vLLM
Paths are relative to upstream/ at v0.22.1 @ 0decac0 (UPSTREAM_PIN.md). The compilation subsystem:
    vllm/compilation/
      cuda_graph.py           CUDAGraphWrapper, CUDAGraphEntry (capture/replay; read this first)
      decorators.py           @support_torch_compile (how a model opts in)
      backends.py             the VllmBackend for torch.compile (trace -> split -> compile)
      piecewise_backend.py    piecewise compiled regions
      passes/pass_manager.py + passes/fusion/    custom graph rewrites
    vllm/config/compilation.py    CompilationMode, CUDAGraphMode, CompilationConfig
We read capture/replay (the core), the two config enums (the vocabulary), and the decorator (the seam). The Inductor internals are deep; return to them once you're comfortable here.
Contents
- 1. The two enums that name everything
- 2. CUDAGraphWrapper: capture and replay (the heart)
- 3. @support_torch_compile: the seam between a model and the compiler
- 4. The backend + passes (skim now, return later)
- 5. Where it's wired into the engine
- Reading checklist
1. The two enums that name everything
CompilationMode — the "level" (vllm/config/compilation.py:37)
    class CompilationMode(enum.IntEnum):
        NONE = 0                 # pure eager, model runs as-is (what enforce_eager gives you)
        STOCK_TORCH_COMPILE = 1  # the standard torch.compile pipeline
        DYNAMO_TRACE_ONCE = 2    # single Dynamo trace, avoid recompilation
        VLLM_COMPILE = 3         # vLLM's Inductor backend: caching, piecewise, shape
                                 # specialization, custom passes <- V1 default
This answers "how hard does the compiler work?" Level 3 (VLLM_COMPILE) is where vLLM's value
is — its own backend with caching and piecewise splitting. Levels 0–2 are mostly for
debugging/comparison. mini_vllm doesn't compile (no GPU), but the idea of "a level dial from
eager to fully-optimized" is the thing to carry.
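To make the "level dial" concrete, here is a minimal sketch that re-declares the enum locally so it runs without vLLM. The `resolve_mode` and `uses_vllm_backend` helpers are hypothetical illustrations of how an ordered level can be consulted; they are not vLLM's actual config logic.

```python
import enum


class CompilationMode(enum.IntEnum):
    # Local copy of the four levels listed above, for illustration only.
    NONE = 0
    STOCK_TORCH_COMPILE = 1
    DYNAMO_TRACE_ONCE = 2
    VLLM_COMPILE = 3


def resolve_mode(requested: CompilationMode, enforce_eager: bool) -> CompilationMode:
    """Hypothetical helper: enforce_eager turns the dial down to NONE."""
    return CompilationMode.NONE if enforce_eager else requested


def uses_vllm_backend(mode: CompilationMode) -> bool:
    # IntEnum members compare like ints, so the levels are ordered.
    return mode >= CompilationMode.VLLM_COMPILE
```

Because the enum is an IntEnum, "at least level N" checks are plain integer comparisons, which is one reason to model the dial as ordered levels rather than unrelated flags.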
CUDAGraphMode — the capture strategy (vllm/config/compilation.py:53)
    class CUDAGraphMode(enum.Enum):
        NONE = 0
        PIECEWISE = 1
        FULL = 2
        FULL_DECODE_ONLY = (FULL, NONE)         # full graph for decode, nothing for mixed
        FULL_AND_PIECEWISE = (FULL, PIECEWISE)  # full for decode, piecewise for mixed (v1 default)
Notice the clever encoding: the last two are tuples (decode_mode, mixed_mode). A batch is
either pure-decode (uniform shapes — safe for a FULL graph) or mixed prefill+decode (variable
attention metadata — needs PIECEWISE). The helper methods make this explicit
(compilation.py:65):
    def decode_mode(self) -> "CUDAGraphMode":
        return CUDAGraphMode(self.value[0]) if self.separate_routine() else self

    def mixed_mode(self) -> "CUDAGraphMode":
        return CUDAGraphMode(self.value[1]) if self.separate_routine() else self

    def has_mode(self, mode) -> bool: ...  # is `mode` one of my routines?

    def requires_piecewise_compilation(self) -> bool:
        return self.has_mode(CUDAGraphMode.PIECEWISE)
So FULL_AND_PIECEWISE.decode_mode() == FULL and .mixed_mode() == PIECEWISE. You will
reimplement these exact methods in lab-03; they're small, and they encode the whole
"which graph for which batch" decision. The comment at lines 595–620 of the config spells out
the tradeoffs (PIECEWISE keeps attention out of the graph and captures only the non-attention
pieces; FULL_AND_PIECEWISE is generally fastest).
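The tuple encoding is easy to try in isolation. The sketch below is a standalone copy of the enum that runs without vLLM; the bodies of `separate_routine` and `has_mode` are assumptions filled in for illustration, since the excerpt above does not show them.

```python
import enum


class CUDAGraphMode(enum.Enum):
    NONE = 0
    PIECEWISE = 1
    FULL = 2
    # Inside the class body the earlier names are still plain ints, so these
    # two values are the tuples (2, 0) and (2, 1), read as (decode, mixed).
    FULL_DECODE_ONLY = (FULL, NONE)
    FULL_AND_PIECEWISE = (FULL, PIECEWISE)

    def separate_routine(self) -> bool:
        # Assumed body: tuple-valued members have distinct decode/mixed routines.
        return isinstance(self.value, tuple)

    def decode_mode(self) -> "CUDAGraphMode":
        return CUDAGraphMode(self.value[0]) if self.separate_routine() else self

    def mixed_mode(self) -> "CUDAGraphMode":
        return CUDAGraphMode(self.value[1]) if self.separate_routine() else self

    def has_mode(self, mode: "CUDAGraphMode") -> bool:
        # Assumed body: is `mode` one of this member's routines?
        if self.separate_routine():
            return mode.value in self.value
        return self is mode

    def requires_piecewise_compilation(self) -> bool:
        return self.has_mode(CUDAGraphMode.PIECEWISE)


mode = CUDAGraphMode.FULL_AND_PIECEWISE
decode, mixed = mode.decode_mode(), mode.mixed_mode()
```

With this definition, `decode` is `CUDAGraphMode.FULL` and `mixed` is `CUDAGraphMode.PIECEWISE`, while a simple member such as `CUDAGraphMode.PIECEWISE` returns itself from both helpers.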
2. CUDAGraphWrapper — capture and replay (the heart)
vllm/compilation/cuda_graph.py:145. Read its docstring (lines 146–168) — it states the
dispatch protocol precisely. The key data structure (line 207):
    # the entries for different batch descriptors that we need to capture cudagraphs for.
    self.concrete_cudagraph_entries: dict[BatchDescriptor, CUDAGraphEntry] = {}
A dict of graphs keyed by batch shape. This is Constraint 1 (per-shape) made concrete. Your
mini_vllm.GraphRunner.graphs: dict[shape, GraphEntry] is the same structure.
A CUDAGraphEntry (line 128) is what we cache per shape:
    @dataclass
    class CUDAGraphEntry:
        batch_descriptor: BatchDescriptor
        cudagraph: torch.cuda.CUDAGraph | None = None
        output: Any | None = None
        # for cudagraph debugging, track the input addresses during capture,
        # and check if they are the same during replay
        input_addresses: list[int] | None = None
That input_addresses field is Constraint 2 (static buffers) made checkable: capture records
the input tensor addresses, and replay asserts they're unchanged when debugging mode is on.
Your simulation models this with the static_input buffer you must np.copyto into.
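The address check can be illustrated without a GPU. In this CPU-only sketch a Python list stands in for a tensor and `id()` stands in for `tensor.data_ptr()`; both substitutions, and the `addresses` helper, are assumptions made purely for illustration.

```python
def addresses(args: list[object]) -> list[int]:
    # id() stands in for tensor.data_ptr() in this CPU-only sketch.
    return [id(a) for a in args if isinstance(a, list)]


static_input = [0.0, 0.0, 0.0]        # persistent buffer, allocated once
captured = addresses([static_input])  # recorded at "capture" time

# Correct: overwrite the buffer's contents in place; its identity is unchanged.
static_input[:] = [1.0, 2.0, 3.0]
in_place_ok = addresses([static_input]) == captured

# Incorrect: a brand-new object with the same values lives somewhere else.
fresh_input = [1.0, 2.0, 3.0]
fresh_ok = addresses([fresh_input]) == captured
```

The in-place write passes the check and the fresh object fails it, which is why inputs have to be copied into the buffer that was present at capture time rather than passed as new tensors.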
The dispatch: __call__ (line 233)
Walk it in three branches:
(a) No graph / mode mismatch → run eagerly (lines 234–254):
    forward_context = get_forward_context()
    batch_descriptor = forward_context.batch_descriptor
    cudagraph_runtime_mode = forward_context.cudagraph_runtime_mode
    if (cudagraph_runtime_mode == CUDAGraphMode.NONE
            or cudagraph_runtime_mode != self.runtime_mode):
        # profile run, warmup, no-cudagraph, OR a different wrapper's turn
        return self.runnable(*args, **kwargs)
The wrapper "blindly trusts" the mode + shape key set by the model runner in the
forward_context. If the runtime says NONE (profiling/warmup) or this isn't this wrapper's
mode, just run the real function. (This is how FULL and PIECEWISE wrappers can be nested and each
only fires for its own mode.) Your GraphRunner doesn't need modes, but the trust-the-context
pattern is why the wrapper stays decoupled from the compiler.
(b) Shape not seen → CAPTURE (lines 257–344):
    if batch_descriptor not in self.concrete_cudagraph_entries:
        self.concrete_cudagraph_entries[batch_descriptor] = CUDAGraphEntry(batch_descriptor=...)
    entry = self.concrete_cudagraph_entries[batch_descriptor]
    if entry.cudagraph is None:
        validate_cudagraph_capturing_enabled()
        input_addresses = [x.data_ptr() for x in args if isinstance(x, torch.Tensor)]
        entry.input_addresses = input_addresses
        cudagraph = torch.cuda.CUDAGraph()
        ...
        with torch.cuda.graph(cudagraph, pool=self.graph_pool, stream=current_stream()):
            # kernels issued here are RECORDED into the graph, not executed
            output = self.runnable(*args, **kwargs)
            if self.cudagraph_options.weak_ref_output:
                output = weak_ref_tensors(output)
        entry.output = weak_ref_tensors(output)
        entry.cudagraph = cudagraph
        compilation_counter.num_cudagraph_captured += 1
        return output  # return the strong reference, not the weak ref, on the capture step
The with torch.cuda.graph(...) context is where CUDA records every kernel issued by
self.runnable(...) into cudagraph. The weak_ref_tensors dance (lines 325–336) is the
"mind-exploding" memory management the comment warns about: the output lives in the graph's
private memory pool, so vLLM holds only weak references to avoid leaking it while still letting
PyTorch manage the pool. Your simulation skips this (numpy has no pools) but captures the
structure: first sight of a shape → run once, record, cache.
(c) Shape seen → REPLAY (lines 346–361):
    if self.is_debugging_mode:
        new_input_addresses = [x.data_ptr() for x in args if isinstance(x, torch.Tensor)]
        assert new_input_addresses == entry.input_addresses, (
            "Input addresses for cudagraphs are different during replay...")
    ...
    entry.cudagraph.replay()
    return entry.output
This is the entire win in two lines: entry.cudagraph.replay() issues one launch and the
GPU runs the whole recorded sequence, then the wrapper returns the cached output tensor. Note
the debug assertion: it enforces Constraint 2 (inputs must be at the same addresses; the model
runner guarantees this by writing new inputs into persistent buffers before calling). Your
GraphRunner.__call__ replay branch is the direct analog: np.copyto(entry.static_input, x) then
"replay" as a single LaunchCounter.bump(1).
The whole class in one sentence: a per-shape dict where the first call captures and every later call with that shape replays — exactly your
mini_vllm.GraphRunner.
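As a rough illustration of that one-sentence summary, here is a dependency-free sketch of a per-shape capture/replay runner. It is not the course's mini_vllm code: the class layout, the instance-based `LaunchCounter`, and the use of list slice assignment in place of np.copyto are all assumptions chosen so the example runs with the standard library only.

```python
from dataclasses import dataclass
from typing import Callable

Vector = list[float]
Layer = Callable[[Vector], Vector]


class LaunchCounter:
    """Counts simulated kernel launches."""

    def __init__(self) -> None:
        self.count = 0

    def bump(self, n: int) -> None:
        self.count += n


@dataclass
class GraphEntry:
    static_input: Vector   # persistent buffer the "graph" reads from
    static_output: Vector  # persistent buffer the "graph" writes to


class GraphRunner:
    def __init__(self, layers: list[Layer], counter: LaunchCounter) -> None:
        self.layers = layers
        self.counter = counter
        self.graphs: dict[int, GraphEntry] = {}  # keyed by batch size

    def _forward(self, x: Vector) -> Vector:
        for layer in self.layers:
            x = layer(x)
        return x

    def __call__(self, x: Vector) -> Vector:
        entry = self.graphs.get(len(x))
        if entry is None:
            # Capture: first sight of this shape costs one launch per layer.
            static_input = list(x)
            entry = GraphEntry(static_input, self._forward(static_input))
            self.graphs[len(x)] = entry
            self.counter.bump(len(self.layers))
            return entry.static_output
        # Replay: copy into the static buffer, then one launch for everything.
        entry.static_input[:] = x
        entry.static_output[:] = self._forward(entry.static_input)
        self.counter.bump(1)
        return entry.static_output


counter = LaunchCounter()
runner = GraphRunner([lambda v: [2 * a for a in v], lambda v: [a + 1 for a in v]], counter)
first = list(runner([1.0, 2.0]))   # capture for batch size 2
second = list(runner([3.0, 4.0]))  # replay for batch size 2
third = list(runner([5.0]))        # new shape, so capture again
```

The two captures cost two launches each and the replay costs one, so the counter ends at five; with more layers the gap between capture cost and replay cost widens accordingly.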
3. @support_torch_compile — the seam between a model and the compiler
vllm/compilation/decorators.py:118. Models opt in by decorating the class:
    @support_torch_compile(dynamic_arg_dims={"x": 0, "y": 0})
    class MyModel(nn.Module):
        def forward(self, x: torch.Tensor, y: Optional[torch.Tensor]): ...
What it does (read the docstring, lines 126–176): it wraps the class so that, when compilation
is enabled, the forward is run through torch.compile/the vLLM backend, and it marks which tensor
dimensions are dynamic (the batch/sequence dim) so the compiler treats them as symbolic instead
of specializing on one concrete size. dynamic_arg_dims says "dimension 0 of x varies"; that's
the batch dimension the CUDA-graph capture sizes range over. If you don't pass it, vLLM infers
it from the type annotations (line 153): torch.Tensor args get dim 0 marked dynamic.
The important takeaway: adding compile support to a model is one decorator, and the dynamic dims you declare are what let the same compiled artifact serve many batch sizes (and what the CUDA-graph layer keys its captured graphs on). When you add a model in Phase 14, this decorator is part of the recipe.
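The annotation-based inference can be sketched with the standard library alone. The decorator below is a hypothetical toy named `support_compile_sketch`; it only records the inferred dims on the class and does no compilation, and the local `Tensor` class is a stand-in so the example runs without PyTorch.

```python
from typing import Optional, Union, get_args, get_origin, get_type_hints


class Tensor:
    """Hypothetical stand-in for torch.Tensor."""


def _is_tensor_annotation(tp: object) -> bool:
    if tp is Tensor:
        return True
    # Optional[Tensor] is Union[Tensor, None].
    return get_origin(tp) is Union and Tensor in get_args(tp)


def support_compile_sketch(dynamic_arg_dims: Optional[dict[str, int]] = None):
    def decorate(cls):
        dims = dynamic_arg_dims
        if dims is None:
            # Infer: every tensor-annotated forward() argument gets dim 0.
            hints = get_type_hints(cls.forward)
            hints.pop("return", None)
            dims = {name: 0 for name, tp in hints.items() if _is_tensor_annotation(tp)}
        cls.dynamic_arg_dims = dims  # a real decorator would also wrap forward
        return cls
    return decorate


@support_compile_sketch()
class InferredModel:
    def forward(self, x: Tensor, y: Optional[Tensor], scale: float): ...


@support_compile_sketch(dynamic_arg_dims={"x": 0})
class ExplicitModel:
    def forward(self, x: Tensor, y: Optional[Tensor]): ...
```

`InferredModel.dynamic_arg_dims` comes out as `{"x": 0, "y": 0}` because `scale` is not tensor-annotated, while `ExplicitModel` keeps exactly the mapping it was given.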
4. The backend + passes (skim now, return later)
- vllm/compilation/backends.py: VllmBackend, the torch.compile backend Dynamo calls with the traced FX graph. It splits the graph at splitting_ops (attention) for piecewise compilation, compiles each piece with Inductor, caches the results, and arranges the pieces for piecewise CUDA-graph capture. This is the level-3 VLLM_COMPILE machinery.
- vllm/compilation/piecewise_backend.py: manages a single piecewise compiled region.
- vllm/compilation/passes/pass_manager.py + passes/fusion/: the custom graph passes, which are rewrites vLLM applies to the traced graph that stock Inductor wouldn't, e.g. fusing add + RMSNorm, fusing quantization into the preceding op, sequence-parallel rewrites. Each pass is an FX-graph-in, FX-graph-out transform. Reading one small fusion pass is a great way to see "graph-level transformation" concretely.
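To get a feel for a graph-in, graph-out rewrite, here is a deliberately tiny sketch. A flat list of op names stands in for an FX graph, and the op names (including `fused_add_rms_norm`) are made up for illustration; real passes pattern-match on FX nodes and their data dependencies, which this toy ignores.

```python
def fuse_add_rms_norm(ops: list[str]) -> list[str]:
    """Replace each adjacent ("add", "rms_norm") pair with one fused op."""
    fused: list[str] = []
    i = 0
    while i < len(ops):
        if ops[i] == "add" and i + 1 < len(ops) and ops[i + 1] == "rms_norm":
            fused.append("fused_add_rms_norm")
            i += 2  # consume both ops of the matched pair
        else:
            fused.append(ops[i])
            i += 1
    return fused


before = ["matmul", "add", "rms_norm", "attention", "add", "relu"]
after = fuse_add_rms_norm(before)
```

Only the adjacent pair is fused; the later `add` followed by `relu` is left alone, and the input list is not modified.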
Your mini_vllm.PiecewiseGraphRunner models the split idea (break at uncapturable ops, capture
the rest) without the Inductor compilation; the split is the part that matters for the mental model.
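The split itself is simple enough to sketch on a list of op names. The `split_at` function and the op names are hypothetical, and a flat list ignores the data dependencies a real FX graph carries; the sketch only shows how breaking at designated ops yields alternating capturable and eager pieces.

```python
def split_at(ops: list[str], splitting_ops: set[str]) -> list[tuple[bool, list[str]]]:
    """Return (capturable, ops) pieces, breaking the sequence at each splitting op."""
    pieces: list[tuple[bool, list[str]]] = []
    current: list[str] = []
    for op in ops:
        if op in splitting_ops:
            if current:
                pieces.append((True, current))  # close the capturable run so far
                current = []
            pieces.append((False, [op]))        # the splitting op runs eagerly
        else:
            current.append(op)
    if current:
        pieces.append((True, current))
    return pieces


layer = ["norm", "qkv", "attention", "proj", "mlp", "attention", "out"]
pieces = split_at(layer, {"attention"})
```

For this input the result is three capturable pieces separated by two eager `attention` pieces, which is the pattern a piecewise runner would capture graph by graph.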
5. Where it's wired into the engine
The model runner (vllm/v1/worker/gpu_model_runner.py) is what:
- decides the cudagraph_runtime_mode for the current batch (FULL for pure decode, PIECEWISE for mixed, NONE during profiling/warmup) and the batch_descriptor (the shape key),
- sets them on the forward_context (which the CUDAGraphWrapper reads),
- writes the step's inputs into the persistent buffers the captured graph reads from (Constraint 2), padding the batch up to a captured size (Constraint 1),
- runs a warmup at startup that captures graphs for every size in cudagraph_capture_sizes.
Search gpu_model_runner.py for cudagraph and capture to see the warmup/capture loop and the
input-buffer copies. That's the production embodiment of everything above.
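The padding and buffer-staging steps can be sketched as follows. Both helpers are hypothetical and far simpler than the model runner's real logic; in particular, returning None for an oversized batch is an assumption standing in for "fall back to eager".

```python
import bisect
from typing import Optional


def padded_size(num_tokens: int, capture_sizes: list[int]) -> Optional[int]:
    """Smallest captured size >= num_tokens, or None if the batch is too big."""
    sizes = sorted(capture_sizes)
    i = bisect.bisect_left(sizes, num_tokens)
    return sizes[i] if i < len(sizes) else None


def stage_inputs(static_buf: list[int], token_ids: list[int], pad_id: int = 0) -> int:
    """Write token_ids into the persistent buffer in place; return the real length."""
    n = len(token_ids)
    if n > len(static_buf):
        raise ValueError("batch does not fit in the captured buffer")
    static_buf[:n] = token_ids
    static_buf[n:] = [pad_id] * (len(static_buf) - n)
    return n


capture_sizes = [1, 2, 4, 8]
size = padded_size(3, capture_sizes)  # a batch of 3 is padded up to 4
buffer = [9] * size                   # stands in for a buffer allocated at warmup
buffer_id = id(buffer)
real_len = stage_inputs(buffer, [5, 6, 7])
```

The buffer object is the same before and after staging, and only its contents change, which is the property the captured graph depends on.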
Reading checklist
One sentence each in your notebook:
- CompilationMode: what does level 3 (VLLM_COMPILE) add over stock torch.compile?
- CUDAGraphMode: why are FULL_AND_PIECEWISE / FULL_DECODE_ONLY encoded as tuples?
- concrete_cudagraph_entries: what is the key, and which constraint does that enforce?
- CUDAGraphEntry.input_addresses: which constraint, and when is it checked?
- __call__: name the three branches (eager / capture / replay) and their triggers.
- entry.cudagraph.replay(): why is this "the entire win"?
- @support_torch_compile dynamic_arg_dims: why does the compiler need to know the dynamic dimension?
Now build it: 02-mini-build.md, then the labs.