Phase 05 — Deep Dive: CUDA Graphs & torch.compile in real vLLM

Paths are relative to upstream/ at v0.22.1 @ 0decac0 (see UPSTREAM_PIN.md). The compilation subsystem:

vllm/compilation/
  cuda_graph.py        CUDAGraphWrapper, CUDAGraphEntry   (capture/replay — read this first)
  decorators.py        @support_torch_compile             (how a model opts in)
  backends.py          the VllmBackend for torch.compile  (trace -> split -> compile)
  piecewise_backend.py piecewise compiled regions
  passes/pass_manager.py  + passes/fusion/  custom graph rewrites
vllm/config/compilation.py   CompilationMode, CUDAGraphMode, CompilationConfig

We read capture/replay (the core), the two config enums (the vocabulary), and the decorator (the seam). The Inductor internals are deep — return after you're comfortable here.


1. The two enums that name everything

CompilationMode — the "level" (vllm/config/compilation.py:37)

class CompilationMode(enum.IntEnum):
    NONE = 0                  # pure eager, model runs as-is (what enforce_eager gives you)
    STOCK_TORCH_COMPILE = 1   # the standard torch.compile pipeline
    DYNAMO_TRACE_ONCE = 2     # single Dynamo trace, avoid recompilation
    VLLM_COMPILE = 3          # vLLM's Inductor backend: caching, piecewise, shape
                              # specialization, custom passes   <- V1 default

This answers "how hard does the compiler work?" Level 3 (VLLM_COMPILE) is where vLLM adds its value: its own backend with caching and piecewise splitting. Levels 0–2 are mostly for debugging and comparison. mini_vllm doesn't compile anything (no GPU), but the idea of "a level dial from eager to fully-optimized" is the thing to carry.

CUDAGraphMode — the capture strategy (vllm/config/compilation.py:53)

class CUDAGraphMode(enum.Enum):
    NONE = 0
    PIECEWISE = 1
    FULL = 2
    FULL_DECODE_ONLY = (FULL, NONE)        # full graph for decode, nothing for mixed
    FULL_AND_PIECEWISE = (FULL, PIECEWISE) # full for decode, piecewise for mixed (v1 default)

Notice the clever encoding: the last two are tuples (decode_mode, mixed_mode). A batch is either pure-decode (uniform shapes — safe for a FULL graph) or mixed prefill+decode (variable attention metadata — needs PIECEWISE). The helper methods make this explicit (compilation.py:65):

def decode_mode(self) -> "CUDAGraphMode":
    return CUDAGraphMode(self.value[0]) if self.separate_routine() else self
def mixed_mode(self) -> "CUDAGraphMode":
    return CUDAGraphMode(self.value[1]) if self.separate_routine() else self
def has_mode(self, mode) -> bool: ...        # is `mode` one of my routines?
def requires_piecewise_compilation(self) -> bool:
    return self.has_mode(CUDAGraphMode.PIECEWISE)

So FULL_AND_PIECEWISE.decode_mode() == FULL and .mixed_mode() == PIECEWISE. You will reimplement these exact methods in lab-03; they're small and they encode the whole "which graph for which batch" decision. The comment at lines 595–620 of the config spells out the tradeoffs (PIECEWISE keeps attention out of the graph and captures only the non-attention pieces; FULL_AND_PIECEWISE is generally fastest).
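To see the tuple encoding work end to end, here is a minimal standalone sketch of the enum. The member values and the bodies of decode_mode, mixed_mode and requires_piecewise_compilation follow the excerpt above; the bodies of separate_routine and has_mode are assumptions filled in for illustration, since the excerpt elides them.

```python
import enum


class CUDAGraphMode(enum.Enum):
    NONE = 0
    PIECEWISE = 1
    FULL = 2
    # Inside the class body FULL, NONE and PIECEWISE are still plain ints,
    # so these two values are the tuples (2, 0) and (2, 1).
    FULL_DECODE_ONLY = (FULL, NONE)
    FULL_AND_PIECEWISE = (FULL, PIECEWISE)

    def separate_routine(self) -> bool:
        # Assumed body: a tuple value means decode and mixed batches differ.
        return isinstance(self.value, tuple)

    def decode_mode(self) -> "CUDAGraphMode":
        return CUDAGraphMode(self.value[0]) if self.separate_routine() else self

    def mixed_mode(self) -> "CUDAGraphMode":
        return CUDAGraphMode(self.value[1]) if self.separate_routine() else self

    def has_mode(self, mode: "CUDAGraphMode") -> bool:
        # Assumed body: is `mode` one of the routines this member dispatches to?
        if self.separate_routine():
            return mode.value in self.value
        return self == mode

    def requires_piecewise_compilation(self) -> bool:
        return self.has_mode(CUDAGraphMode.PIECEWISE)


# Print the decode/mixed routine chosen by every member.
for mode in CUDAGraphMode:
    print(f"{mode.name:20} decode={mode.decode_mode().name:10} "
          f"mixed={mode.mixed_mode().name:10} "
          f"piecewise_compile={mode.requires_piecewise_compilation()}")
```

The simple members return themselves for both routines, so callers never need to branch on whether a mode is a tuple.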


2. CUDAGraphWrapper — capture and replay (the heart)

vllm/compilation/cuda_graph.py:145. Read its docstring (lines 146–168) — it states the dispatch protocol precisely. The key data structure (line 207):

# the entries for different batch descriptors that we need to capture cudagraphs for.
self.concrete_cudagraph_entries: dict[BatchDescriptor, CUDAGraphEntry] = {}

A dict of graphs keyed by batch shape. This is Constraint 1 (per-shape) made concrete. Your mini_vllm.GraphRunner.graphs: dict[shape, GraphEntry] is the same structure.

A CUDAGraphEntry (line 128) is what we cache per shape:

@dataclass
class CUDAGraphEntry:
    batch_descriptor: BatchDescriptor
    cudagraph: torch.cuda.CUDAGraph | None = None
    output: Any | None = None
    # for cudagraph debugging, track the input addresses during capture,
    # and check if they are the same during replay
    input_addresses: list[int] | None = None

That input_addresses field is Constraint 2 (static buffers) made checkable: capture records the input tensor addresses; replay asserts they're unchanged. Your simulation models this with the static_input buffer you must np.copyto into.
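The address check can be tried on the CPU with nothing but the standard library. The sketch below is an illustration, not vLLM or mini_vllm code: EntrySketch and address_of are hypothetical names, a Python list stands in for a tensor, and id() stands in for data_ptr(). With NumPy the in-place write would be np.copyto(static_input, x).

```python
from dataclasses import dataclass
from typing import Any, Optional


@dataclass
class EntrySketch:
    """Hypothetical stand-in for a cached graph entry (not vLLM's class)."""
    shape_key: int
    static_input: list[float]                    # persistent buffer the "graph" reads
    input_addresses: Optional[list[int]] = None  # recorded at capture time
    output: Optional[Any] = None


def address_of(buf: list[float]) -> int:
    # id() plays the role of tensor.data_ptr() in this CPU-only sketch.
    return id(buf)


entry = EntrySketch(shape_key=4, static_input=[0.0] * 4)
entry.input_addresses = [address_of(entry.static_input)]

# Correct: overwrite the contents in place, so the address is unchanged.
entry.static_input[:] = [1.0, 2.0, 3.0, 4.0]
in_place_ok = [address_of(entry.static_input)] == entry.input_addresses

# Incorrect: a freshly allocated list lives at a different address,
# which is what the debug assertion on replay would catch.
fresh = [5.0, 6.0, 7.0, 8.0]
rebind_ok = [address_of(fresh)] == entry.input_addresses

print(in_place_ok, rebind_ok)  # prints: True False
```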

The dispatch: __call__ (line 233)

Walk it in three branches:

(a) No graph / mode mismatch → run eagerly (lines 234–254):

forward_context = get_forward_context()
batch_descriptor = forward_context.batch_descriptor
cudagraph_runtime_mode = forward_context.cudagraph_runtime_mode

if (cudagraph_runtime_mode == CUDAGraphMode.NONE
        or cudagraph_runtime_mode != self.runtime_mode):
    # profile run, warmup, no-cudagraph, OR a different wrapper's turn
    return self.runnable(*args, **kwargs)

The wrapper "blindly trusts" the mode + shape key set by the model runner in the forward_context. If the runtime says NONE (profiling/warmup) or this isn't this wrapper's mode, just run the real function. (This is how FULL and PIECEWISE wrappers can be nested and each only fires for its own mode.) Your GraphRunner doesn't need modes, but the trust-the-context pattern is why the wrapper stays decoupled from the compiler.
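The nesting behavior is easier to believe once you watch it. The following sketch assumes a simplified forward context (a dict-held ForwardContextSketch with string modes) and a hypothetical ModeGatedWrapper; none of these names come from vLLM. Only the gating condition mirrors the excerpt above.

```python
from dataclasses import dataclass
from typing import Callable

NONE, PIECEWISE, FULL = "NONE", "PIECEWISE", "FULL"


@dataclass
class ForwardContextSketch:
    cudagraph_runtime_mode: str
    batch_descriptor: int


# Hypothetical global holder; the model runner would set this before each step.
_state = {"ctx": ForwardContextSketch(NONE, 0)}


def set_forward_context(mode: str, batch_descriptor: int) -> None:
    _state["ctx"] = ForwardContextSketch(mode, batch_descriptor)


def get_forward_context() -> ForwardContextSketch:
    return _state["ctx"]


class ModeGatedWrapper:
    def __init__(self, runnable: Callable[[int], int], runtime_mode: str,
                 fired: list[tuple[str, int]]) -> None:
        self.runnable = runnable
        self.runtime_mode = runtime_mode
        self.fired = fired  # shared log of (mode, shape key) that took the graph path

    def __call__(self, x: int) -> int:
        ctx = get_forward_context()
        if (ctx.cudagraph_runtime_mode == NONE
                or ctx.cudagraph_runtime_mode != self.runtime_mode):
            return self.runnable(x)  # not this wrapper's turn: plain call
        # A real wrapper would capture or replay here; the sketch only logs.
        self.fired.append((self.runtime_mode, ctx.batch_descriptor))
        return self.runnable(x)


fired: list[tuple[str, int]] = []
piece = ModeGatedWrapper(lambda x: x + 1, PIECEWISE, fired)
full = ModeGatedWrapper(piece, FULL, fired)  # FULL wrapper nests the PIECEWISE one

results = []
for mode in (NONE, FULL, PIECEWISE):
    set_forward_context(mode, batch_descriptor=8)
    results.append(full(41))

print(fired)    # prints: [('FULL', 8), ('PIECEWISE', 8)]
print(results)  # prints: [42, 42, 42]
```

Each wrapper fires exactly once, for its own mode, and the computed result is the same on every path.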

(b) Shape not seen → CAPTURE (lines 257–344):

if batch_descriptor not in self.concrete_cudagraph_entries:
    self.concrete_cudagraph_entries[batch_descriptor] = CUDAGraphEntry(batch_descriptor=...)
entry = self.concrete_cudagraph_entries[batch_descriptor]

if entry.cudagraph is None:
    validate_cudagraph_capturing_enabled()
    input_addresses = [x.data_ptr() for x in args if isinstance(x, torch.Tensor)]
    entry.input_addresses = input_addresses
    cudagraph = torch.cuda.CUDAGraph()
    ...
    with torch.cuda.graph(cudagraph, pool=self.graph_pool, stream=current_stream()):
        output = self.runnable(*args, **kwargs)      # the kernels are RECORDED, not just run
        if self.cudagraph_options.weak_ref_output:
            output = weak_ref_tensors(output)
    entry.output = weak_ref_tensors(output)
    entry.cudagraph = cudagraph
    compilation_counter.num_cudagraph_captured += 1
    return output                                    # return the REAL output on capture step

The with torch.cuda.graph(...) context is where CUDA records every kernel issued by self.runnable(...) into cudagraph. The weak_ref_tensors dance (lines 325–336) is the "mind-exploding" memory management the comment warns about: the output lives in the graph's private memory pool, so vLLM holds only weak references to avoid leaking it while still letting PyTorch manage the pool. Your simulation skips this (numpy has no pools) but captures the structure: first sight of a shape → run once, record, cache.

(c) Shape seen → REPLAY (lines 346–361):

if self.is_debugging_mode:
    new_input_addresses = [x.data_ptr() for x in args if isinstance(x, torch.Tensor)]
    assert new_input_addresses == entry.input_addresses, (
        "Input addresses for cudagraphs are different during replay...")
...
entry.cudagraph.replay()
return entry.output

This is the entire win in two lines: entry.cudagraph.replay() issues one launch and the GPU runs the whole recorded sequence; return the cached output tensor. Note the debug assertion — it enforces Constraint 2 (inputs must be at the same addresses; the model runner guarantees this by writing new inputs into persistent buffers before calling). Your GraphRunner.__call__ replay branch is the direct analog: np.copyto(entry.static_input, x) then "replay" as a single LaunchCounter.bump(1).

The whole class in one sentence: a per-shape dict where the first call captures and every later call with that shape replays — exactly your mini_vllm.GraphRunner.
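As a rough illustration of that sentence, here is a CPU-only sketch of the capture-then-replay pattern. It is not the mini_vllm implementation: GraphRunnerSketch and GraphEntrySketch are hypothetical names, the LaunchCounter body is assumed (only bump is mentioned above), plain lists replace arrays, and a "replay" simply reruns the ops while charging a single launch.

```python
from dataclasses import dataclass
from typing import Callable

Op = Callable[[list[float]], list[float]]


class LaunchCounter:
    """Counts simulated kernel launches."""

    def __init__(self) -> None:
        self.count = 0

    def bump(self, n: int = 1) -> None:
        self.count += n


@dataclass
class GraphEntrySketch:
    static_input: list[float]   # persistent input buffer (Constraint 2)
    static_output: list[float]  # cached output buffer, reused on every replay


class GraphRunnerSketch:
    def __init__(self, ops: list[Op], counter: LaunchCounter) -> None:
        self.ops = ops
        self.counter = counter
        self.graphs: dict[int, GraphEntrySketch] = {}  # keyed by shape (Constraint 1)

    def _run_ops(self, buf: list[float]) -> list[float]:
        out = buf
        for op in self.ops:
            out = op(out)
        return out

    def __call__(self, x: list[float]) -> list[float]:
        key = len(x)
        entry = self.graphs.get(key)
        if entry is None:
            # Capture: first sight of this shape costs one launch per op.
            static_input = list(x)
            out = self._run_ops(static_input)
            self.counter.bump(len(self.ops))
            entry = GraphEntrySketch(static_input, list(out))
            self.graphs[key] = entry
            return entry.static_output
        # Replay: copy into the static buffer, then charge a single launch.
        entry.static_input[:] = x
        entry.static_output[:] = self._run_ops(entry.static_input)
        self.counter.bump(1)
        return entry.static_output


counter = LaunchCounter()
runner = GraphRunnerSketch(
    [lambda v: [a * 2 for a in v],
     lambda v: [a + 1 for a in v],
     lambda v: [max(a, 0.0) for a in v]],
    counter,
)

first = list(runner([1.0, -2.0, 3.0, 4.0]))  # copy: the output buffer is reused
after_capture = counter.count
second = list(runner([0.0, 0.0, 0.0, 0.0]))
after_replay = counter.count
print(first, after_capture)   # prints: [3.0, 0.0, 7.0, 9.0] 3
print(second, after_replay)   # prints: [1.0, 1.0, 1.0, 1.0] 4
```

Note that the returned list is the cached output buffer, so a caller that needs to keep a result across steps has to copy it, which is the same hazard the real replay path has with entry.output.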


3. @support_torch_compile — the seam between a model and the compiler

vllm/compilation/decorators.py:118. Models opt in by decorating the class:

@support_torch_compile(dynamic_arg_dims={"x": 0, "y": 0})
class MyModel(nn.Module):
    def forward(self, x: torch.Tensor, y: Optional[torch.Tensor]): ...

What it does (read the docstring, 126–176): it wraps the class so that, when compilation is enabled, the forward is run through torch.compile/the vLLM backend, and it marks which tensor dimensions are dynamic (the batch/sequence dim) so the compiler treats them as symbolic instead of specializing on one concrete size. dynamic_arg_dims says "dimension 0 of x varies"; that's the batch dimension the CUDA-graph capture sizes range over. If you don't pass it, vLLM infers it from the type annotations (line 153): torch.Tensor args get dim 0 marked dynamic.

The important takeaway: adding compile support to a model is one decorator, and the dynamic dims you declare are what let the same compiled artifact serve many batch sizes (and what the CUDA-graph layer keys its captured graphs on). When you add a model in Phase 14, this decorator is part of the recipe.


4. The backend + passes (skim now, return later)

  • vllm/compilation/backends.py — VllmBackend, the torch.compile backend Dynamo calls with the traced FX graph. It splits the graph at splitting_ops (attention) for piecewise compilation, compiles each piece with Inductor, caches the results, and arranges the pieces for piecewise CUDA-graph capture. This is the level-3 VLLM_COMPILE machinery.
  • vllm/compilation/piecewise_backend.py — manages a single piecewise compiled region.
  • vllm/compilation/passes/pass_manager.py + passes/fusion/ — the custom graph passes: rewrites vLLM applies to the traced graph that stock Inductor wouldn't, e.g. fusing add + RMSNorm, fusing quantization into the preceding op, sequence-parallel rewrites. Each pass is an FX-graph-in, FX-graph-out transform. Reading one small fusion pass is a great way to see "graph-level transformation" concretely.

Your mini_vllm.PiecewiseGraphRunner models the split idea (break at uncapturable ops, capture the rest) and leaves out the Inductor compilation. The split is the part that matters for the mental model.
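The split itself fits in a few lines. This sketch treats a model as a flat list of op names, which is an assumption made for illustration (the real backend splits an FX graph), and split_at is a hypothetical helper rather than anything in vLLM or mini_vllm.

```python
def split_at(ops: list[str], splitting_ops: set[str]) -> list[tuple[str, list[str]]]:
    """Cut `ops` into capturable runs separated by eagerly executed splitting ops."""
    pieces: list[tuple[str, list[str]]] = []
    current: list[str] = []
    for op in ops:
        if op in splitting_ops:
            if current:
                pieces.append(("capture", current))
                current = []
            pieces.append(("eager", [op]))  # the splitting op stays outside any graph
        else:
            current.append(op)
    if current:
        pieces.append(("capture", current))
    return pieces


layer = ["rms_norm", "qkv_proj", "attention", "o_proj", "rms_norm", "mlp"]
pieces = split_at(layer * 2, {"attention"})

# In this toy model a captured piece replays as one launch, and an eager
# piece costs one launch per op.
eager_launches = len(layer * 2)
piecewise_launches = sum(1 if kind == "capture" else len(ops) for kind, ops in pieces)

for kind, ops in pieces:
    print(kind, ops)
print(eager_launches, piecewise_launches)  # prints: 12 5
```

Notice that the tail of one layer and the head of the next merge into a single captured piece, so the number of pieces grows with the number of attention calls, not with the number of ops.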


5. Where it's wired into the engine

The model runner (vllm/v1/worker/gpu_model_runner.py) is what:

  • decides the cudagraph_runtime_mode for the current batch (FULL for pure decode, PIECEWISE for mixed, NONE during profiling/warmup) and the batch_descriptor (the shape key),
  • sets them on the forward_context (which the CUDAGraphWrapper reads),
  • writes the step's inputs into the persistent buffers the captured graph reads from (Constraint 2), padding the batch up to a captured size (Constraint 1),
  • runs a warmup at startup that captures graphs for every size in cudagraph_capture_sizes.

Search gpu_model_runner.py for cudagraph and capture to see the warmup/capture loop and the input-buffer copies. That's the production embodiment of everything above.
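Two of the runner's decisions listed above, padding to a captured size and picking the runtime mode, can be sketched as follows. The helper names (padded_size, pick_runtime_mode) are hypothetical, the mode choice hard-codes the FULL_AND_PIECEWISE policy, and falling back to eager when the batch exceeds the largest captured size is an assumption of this sketch.

```python
import bisect
from typing import Optional


def padded_size(num_tokens: int, capture_sizes: list[int]) -> Optional[int]:
    """Smallest captured size that fits `num_tokens`, or None if none does."""
    sizes = sorted(capture_sizes)
    i = bisect.bisect_left(sizes, num_tokens)
    return sizes[i] if i < len(sizes) else None


def pick_runtime_mode(pure_decode: bool, padded: Optional[int],
                      profiling: bool = False) -> str:
    """Mode for this batch under a FULL_AND_PIECEWISE-style policy."""
    if profiling or padded is None:
        return "NONE"  # run eagerly
    return "FULL" if pure_decode else "PIECEWISE"


capture_sizes = [1, 2, 4, 8, 16]
decisions = []
for num_tokens, pure_decode in [(5, True), (5, False), (16, True), (17, True)]:
    padded = padded_size(num_tokens, capture_sizes)
    decisions.append((num_tokens, padded, pick_runtime_mode(pure_decode, padded)))

for row in decisions:
    print(row)
# prints:
# (5, 8, 'FULL')
# (5, 8, 'PIECEWISE')
# (16, 16, 'FULL')
# (17, None, 'NONE')
```

The pair (mode, padded size) is what ends up on the forward context as the runtime mode and the shape key.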


Reading checklist

One sentence each in your notebook:

  • CompilationMode — what does level 3 (VLLM_COMPILE) add over stock torch.compile?
  • CUDAGraphMode — why are FULL_AND_PIECEWISE/FULL_DECODE_ONLY encoded as tuples?
  • concrete_cudagraph_entries — what is the key, and which constraint does that enforce?
  • CUDAGraphEntry.input_addresses — which constraint, and when is it checked?
  • __call__ — name the three branches (eager / capture / replay) and their triggers.
  • entry.cudagraph.replay() — why is this "the entire win"?
  • @support_torch_compile dynamic_arg_dims — why does the compiler need to know the dynamic dimension?

Now build it: 02-mini-build.md, then the labs.