Lab 05-01 — Build the Capture/Replay Simulator [CPU-OK]

Here's the absurdity CUDA graphs exist to fix: a decode step for a small model can spend more time on the CPU — Python dispatch, kernel argument marshaling, cudaLaunchKernel calls, one per operation, hundreds per step — than the GPU spends computing. The GPU finishes each tiny kernel and idles, waiting for the next launch to arrive. CUDA graphs fix it by recording the whole kernel sequence once and replaying it as a single launch. In this lab you build that mechanism on CPU — capture, shape-keyed dispatch, static buffers, replay — and in doing so you'll discover that both of its infamous constraints aren't incidental limitations but the direct price of the win.

Why this lab exists

CUDA graphs have a reputation as deep GPU arcana, and the reputation is wrong: the mechanism is pure systems — a cache of recorded work, keyed by shape, replayed from fixed memory — and it simulates perfectly on a laptop. What's genuinely hard about graphs in production is not the replay; it's the discipline the constraints impose on everything else: every batch must arrive at a captured shape (lab-05's padding ladder), every input must be written into the same buffers (the input_addresses checks upstream), and anything dynamic — like attention over varying sequence lengths — must either be made shape-stable or cut out of the graph (lab-03's piecewise modes). You can't reason about any of that machinery until the core capture/replay contract is in your fingers. That's this lab.

The simulator you build mirrors mini_vllm/cudagraph.py, which itself mirrors the real CUDAGraphWrapper (upstream/vllm/compilation/cuda_graph.py) — same per-shape dict, same static-buffer copy, same single-launch accounting. The launch counter stands in for wall-clock CPU overhead, for the usual course reason: a counter gives you formulas (lab-02 derives them), a stopwatch gives you noise.

Background: one win, two constraints

  • The WIN — eager execution pays one launch per op, every call. A captured graph pays the full cost once (capture), then one launch per replay regardless of how many kernels are inside. For a 300-kernel decode step replayed thousands of times per second, that's the difference lab-04 measures at ~2.5× end-to-end.
  • CONSTRAINT 1 (fixed shape) — the recording bakes in every tensor size, grid dimension, and memory extent. A different batch size is a different recording. Hence: graphs are stored in a dict keyed by shape (upstream: concrete_cudagraph_entries keyed by BatchDescriptor), and unseen shapes must capture anew.
  • CONSTRAINT 2 (static buffers) — the recording bakes in addresses. Replay reads the same input memory it was captured from, so new inputs must be copied into the captured buffer before replay (upstream asserts this: the input_addresses consistency check). Forget the copy and the graph happily recomputes last step's batch — the classic graph bug, and test_static_buffer_reflects_new_input exists to make you commit it once, here, where it's cheap.

Both constraints are the same fact stated twice: a graph is a recording, not a program. Recordings don't take arguments.
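
Constraint 2 can be seen in plain NumPy before you write any lab code. The sketch below is an illustration only: `captured` and `replay` are hypothetical names standing in for a recording that holds a reference to one input buffer, and the multiply-by-10 is an arbitrary placeholder op.

```python
import numpy as np

# Hypothetical stand-in for a captured graph: the "recording" holds a
# reference to one specific buffer and always reads from it.
static_input = np.ones(4)
captured = static_input          # the address baked in at capture time


def replay():
    # Replay takes no arguments; it reads whatever the captured buffer holds.
    return captured * 10


new_batch = np.full(4, 5.0)

# Wrong: rebinding the name leaves the captured memory untouched.
static_input = new_batch
stale = replay()                 # still computed from the old values

# Right: write the new values into the buffer the recording points at.
np.copyto(captured, new_batch)
fresh = replay()                 # now computed from the new values

print(stale, fresh)
```

The first replay silently reuses the previous input, which is the failure mode the static-buffer test is designed to expose.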

Files

  • starter.py: LaunchCounter, run_eager, and GraphRunner stubbed. Your work.
  • solution.py: reference (mirrors mini_vllm/cudagraph.py).
  • test_lab.py: the win, both constraints, correctness, and the 100-call accounting.

Run

LAB_IMPL=starter pytest phase-05-cuda-graphs-and-torch-compile/labs/lab-01-graph-replay-simulator -q
pytest phase-05-cuda-graphs-and-torch-compile/labs/lab-01-graph-replay-simulator -q   # reference

What to implement

  1. LaunchCounter — class-level n, reset(), bump(k=1). (Global on purpose: launch overhead is a process-wide resource, which is also why one slow Python step stalls every request in the batch.)
  2. run_eager(ops, x) — bump once per op, every call.
  3. GraphRunner(ops).__call__(x):
    • Capture (shape unseen): copy x into static_input, run ops (bump each), cache a GraphEntry, return the output.
    • Replay (shape seen): np.copyto(entry.static_input, x) into the existing buffer (never rebind the reference), recompute from the buffer, bump(1) total, return.
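
The three steps above can be sketched as follows. This is one possible shape, written under assumptions, and not the reference in solution.py: the `entries` dict name, the single-field `GraphEntry`, and the three toy ops are all hypothetical, and the real files may organize things differently.

```python
from dataclasses import dataclass

import numpy as np


class LaunchCounter:
    n = 0  # class-level on purpose: one counter for the whole process

    @classmethod
    def reset(cls):
        cls.n = 0

    @classmethod
    def bump(cls, k=1):
        cls.n += k


def run_eager(ops, x):
    # Eager mode: one launch per op, on every call.
    for op in ops:
        LaunchCounter.bump()
        x = op(x)
    return x


@dataclass
class GraphEntry:
    static_input: np.ndarray  # the buffer this "recording" reads from


class GraphRunner:
    def __init__(self, ops):
        self.ops = list(ops)
        self.entries = {}  # shape -> GraphEntry (assumed layout)

    def __call__(self, x):
        entry = self.entries.get(x.shape)
        if entry is None:
            # Capture: own a private copy of the input, pay one launch per op.
            entry = GraphEntry(static_input=np.array(x, copy=True))
            self.entries[x.shape] = entry
            out = entry.static_input
            for op in self.ops:
                LaunchCounter.bump()
                out = op(out)
            return out
        # Replay: write into the existing buffer, pay one launch in total.
        np.copyto(entry.static_input, x)
        LaunchCounter.bump(1)
        out = entry.static_input
        for op in self.ops:  # the simulator still runs the ops on replay
            out = op(out)
        return out


# Usage with three toy ops (hypothetical; each op returns a new array).
ops = [lambda t: t + 1, lambda t: t * 2, lambda t: t - 3]
runner = GraphRunner(ops)
LaunchCounter.reset()
first = runner(np.ones(4))         # capture
after_capture = LaunchCounter.n
second = runner(np.full(4, 5.0))   # replay, same shape
after_replay = LaunchCounter.n
print(after_capture, after_replay, second)
```

The sketch assumes ops that return fresh arrays rather than mutating their argument; an in-place op would overwrite the static buffer during the run.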

What the tests prove

| Test | What it pins |
| --- | --- |
| test_eager_pays_one_launch_per_op | The baseline cost model |
| test_capture_then_replay_is_one_launch | The WIN: capture = len(ops), replay = exactly 1 |
| test_replay_output_matches_eager | Replay is an optimization, not a behavior change: the course's master invariant, graph edition |
| test_static_buffer_reflects_new_input | Constraint 2: capture with value 1, replay with value 5, get 50; the copy-into-buffer is live |
| test_new_shape_triggers_recapture | Constraint 1: shape (8,) after shape (4,) pays full capture; both entries coexist in the dict |
| test_graphs_win_when_overhead_dominates | 100 calls: 300 eager launches vs 102 graph launches; the amortization lab-02 turns into formulas |
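
The 100-call figures can be reproduced with a few lines of arithmetic. The sketch below assumes three ops per call, which is inferred from the 300 and 102 figures and not stated in the lab text. The function names are hypothetical, and lab-02 may present its formulas differently.

```python
def eager_launches(num_ops, num_calls):
    # Eager pays one launch per op on every call.
    return num_ops * num_calls


def graph_launches(num_ops, num_calls, num_shapes=1):
    # Each distinct shape pays one full capture (num_ops launches);
    # every remaining call is a replay costing a single launch.
    # Assumes num_calls >= num_shapes.
    captures = num_shapes
    replays = num_calls - captures
    return captures * num_ops + replays


print(eager_launches(3, 100), graph_launches(3, 100))
```

Under the same assumptions, raising `num_shapes` shows how each extra captured shape adds another full capture to the graph total.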

Hitchhiker's notes

  • np.copyto(buf, x) vs buf = x is the whole lab. Rebinding the Python name does nothing to the captured memory; the real API has the same trap (in the PyTorch graph idiom you must call static_tensor.copy_(new), never reassign). If you remember one line from this phase, make it this one.
  • Find your three lines upstream: capture (cuda_graph.py:313, inside torch.cuda.graph(...)), replay (:360, entry.cudagraph.replay()), the per-shape dict (:207). The production wrapper adds warmup runs before capture (CUDA needs the allocator and autotuners settled), a memory pool shared across graphs, and debug-mode address assertions — engineering around exactly the two constraints you implemented.
  • What can't be captured at all? Anything whose control flow depends on data: CPU-side branching, dynamic shapes inside the sequence, unsupported ops (some collectives, host syncs). vLLM's answer is to compile the model into a shape-stable form first (torch.compile, with attention marked as a splitting op) — graphs are the last stage of the compilation pipeline, not a standalone trick. That pipeline is the deep-dive's subject; lab-03 handles the mode routing it produces.
  • Replay still runs the ops here (numpy has no real recording) — the simulation's one honest cheat. The accounting (one launch) models the real benefit; the real replay also skips Python entirely, which is why the measured win (lab-04) can exceed what launch-counting alone predicts.

Going further

  • Add an input_addresses assertion to your replay path (store id(entry.static_input) at capture; assert it unchanged at replay) — you've reproduced upstream's debug check, and you'll appreciate why it exists the first time you "optimize" the copy away.
  • Give GraphRunner a memory budget: each entry costs prod(shape) bytes; evict LRU when over budget. Now you have the graph-pool problem, and a feel for why upstream shares one memory pool across all captured sizes instead.
  • Wire your GraphRunner around mini_vllm's ToyModel.forward for fixed batch sizes and count launches across a full generate() — the engine-level integration upstream does in the model runner.
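
For the first item, a minimal sketch of the address check might look like the following. `CheckedEntry`, `replay_checked`, and the `rebind` flag are hypothetical names introduced only to demonstrate the failure; the upstream check compares tensor addresses and is structured differently.

```python
import numpy as np


class CheckedEntry:
    def __init__(self, x):
        # Capture: own a private buffer and remember its identity.
        self.static_input = np.array(x, copy=True)
        self.input_id = id(self.static_input)


def replay_checked(entry, ops, x, rebind=False):
    if rebind:
        # The tempting "optimization": skip the copy and rebind instead.
        entry.static_input = x
    else:
        np.copyto(entry.static_input, x)
    # Debug check: the buffer object must be the one seen at capture.
    assert id(entry.static_input) == entry.input_id, "static input was rebound"
    out = entry.static_input
    for op in ops:
        out = op(out)
    return out


entry = CheckedEntry(np.ones(4))
ops = [lambda t: t * 10]
ok = replay_checked(entry, ops, np.full(4, 5.0))   # copy path passes the check

try:
    replay_checked(entry, ops, np.full(4, 7.0), rebind=True)
    caught = False
except AssertionError:
    caught = True                                   # rebinding trips the check

print(ok, caught)
```

In this NumPy simulation the rebound version would still compute the right answer, which is exactly why an explicit check is useful: a real graph replay would read the old memory without complaint.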

References

  • mini_vllm/cudagraph.py — the annotated simulator this lab rebuilds, with upstream line references throughout.
  • upstream/vllm/compilation/cuda_graph.py (CUDAGraphWrapper): capture, replay, BatchDescriptor dict, address checks.
  • NVIDIA, Getting Started with CUDA Graphs — the original motivation and API: https://developer.nvidia.com/blog/cuda-graphs/
  • PyTorch docs, CUDA Graphs (torch.cuda.CUDAGraph) — the idiom vLLM builds on, including the static-buffer pattern: https://pytorch.org/docs/stable/notes/cuda.html#cuda-graphs
  • Phase 0 lab-04 — why small-batch decode is launch-overhead territory in the first place.