
Lab 14-02 — Trace Weight Loading in the Real llama.py [GPU-OPT]

Lab-03 had you implement the translation; this lab has you watch the production version run — and read it as a peer. You'll trace five tensors from a real Llama checkpoint through load_weights: the safetensors name on disk, the mapping row that claims it, the vLLM parameter and shard it lands in, and (under TP) which rows of that shard each rank takes. The deliverable is the filled-in mapping table below — five rows of checkpoint surgery, verified against a live load.

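No GPU? Don't panic. Listing the checkpoint's names and walking the mapping by hand need no GPU at all, so you can trace most of this from the captured table below plus the source. Only the live load needs a working vLLM install, and a CPU build of vLLM should be enough to inspect shapes. The reading is the lab.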



Why this lab exists

load_weights is where most model-integration PRs live or die, and it's also the single most readable "real" function in the model zoo once you have lab-03's vocabulary — a loop, a mapping table, and weight_loader callbacks. Tracing five tensors end-to-end converts the function from code-you-scroll-past into code-you-could-have-written, and it arms you for the two production moments that need this knowledge: a new checkpoint that won't load (which mapping row is missing?), and a loaded model that generates garbage (which shard landed wrong?).

Requirements

uv pip install -e ".[vllm]"
huggingface-cli download meta-llama/Meta-Llama-3-8B-Instruct  # or any Llama-family model

Steps

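  1. List the checkpoint's names. A safetensors file begins with a JSON header recording every tensor's name, dtype, shape and byte offsets, so you can enumerate names without loading any weights: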
from safetensors import safe_open
import glob
names = []
for f in sorted(glob.glob("<model_dir>/*.safetensors")):
    with safe_open(f, framework="np") as sf:
        names += list(sf.keys())
print(len(names))           # ~291 for an 8B
print([n for n in names if ".layers.0." in n])   # one layer's worth
  2. Open upstream/vllm/model_executor/models/llama.py, find stacked_params_mapping and the load_weights loop. For each of the five tensor names in the table below, walk the loop by hand: which mapping row matches? What name does it become? Which shard_id rides along?

  3. Verify live: load the model with LLM(model=..., enforce_eager=True), then inspect a parameter's shape on the underlying model (the attribute path from the LLM object to the model differs across vLLM versions). model.model.layers[0].self_attn.qkv_proj.weight.shape should be (6144, 4096). Reconcile it: 32 q-heads × 128 + 2 × (8 kv-heads × 128) = 4096 + 2048 = 6144. That is lab-03's qkv_slices, on a real tensor.
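Step 1 works without touching weights because of how the format is laid out: a safetensors file starts with an 8-byte little-endian unsigned integer giving the header length, then a UTF-8 JSON header mapping each tensor name to its dtype, shape and data_offsets, then the raw tensor bytes. The stdlib-only sketch below writes a tiny file in that layout and lists its names by reading the header alone. The helpers write_tiny_safetensors and list_tensor_names are hypothetical, and the writer skips details such as header padding, so treat it as an illustration of the layout, not a replacement for the safetensors library.

```python
import json
import os
import struct
import tempfile


def write_tiny_safetensors(path, tensors):
    """Hypothetical writer: tensors maps name -> (dtype, shape, raw_bytes).

    Layout: 8-byte little-endian header length, JSON header, concatenated data.
    """
    header, blobs, offset = {}, [], 0
    for name, (dtype, shape, raw) in tensors.items():
        header[name] = {
            "dtype": dtype,
            "shape": list(shape),
            "data_offsets": [offset, offset + len(raw)],  # relative to the data section
        }
        offset += len(raw)
        blobs.append(raw)
    header_bytes = json.dumps(header).encode("utf-8")
    with open(path, "wb") as f:
        f.write(struct.pack("<Q", len(header_bytes)))
        f.write(header_bytes)
        for raw in blobs:
            f.write(raw)


def list_tensor_names(path):
    """Read only the header: returns {name: (dtype, shape)}; tensor bytes are never read."""
    with open(path, "rb") as f:
        (header_len,) = struct.unpack("<Q", f.read(8))
        header = json.loads(f.read(header_len))
    header.pop("__metadata__", None)  # optional metadata entry, not a tensor
    return {name: (info["dtype"], tuple(info["shape"])) for name, info in header.items()}


path = os.path.join(tempfile.mkdtemp(), "tiny.safetensors")
write_tiny_safetensors(path, {
    "model.layers.0.input_layernorm.weight": ("F32", (4,), struct.pack("<4f", 1, 1, 1, 1)),
    "model.layers.0.self_attn.q_proj.weight": ("F32", (2, 2), struct.pack("<4f", 0, 1, 2, 3)),
})
print(list_tensor_names(path))
```

On a real sharded checkpoint the same header read is what lets you count ~291 names in milliseconds.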

Captured mapping (Llama-3-8B, vLLM 0.22.1)

checkpoint tensor (HF) | vLLM parameter | shard | rows in fused
---|---|---|---
model.layers.0.self_attn.q_proj.weight | ...self_attn.qkv_proj.weight | q | 0:4096
model.layers.0.self_attn.k_proj.weight | ...self_attn.qkv_proj.weight | k | 4096:5120
model.layers.0.self_attn.v_proj.weight | ...self_attn.qkv_proj.weight | v | 5120:6144
model.layers.0.mlp.gate_proj.weight | ...mlp.gate_up_proj.weight | 0 | 0:14336
model.layers.0.input_layernorm.weight | (itself) | none | unfused
# live verification:
qkv_proj.weight.shape   = (6144, 4096)    # 4096 q + 1024 k + 1024 v  (GQA: 8 kv heads)
gate_up_proj.weight.shape = (28672, 4096) # 14336 gate + 14336 up
# under tensor_parallel_size=2: (3072, 4096) per rank — heads split, slices halve
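The table can be regenerated from two ingredients: the mapping rows and the config's head counts. The sketch below copies the five (param_name, weight_name, shard_id) rows as we understand them from llama.py's stacked_params_mapping (compare against your checkout), reproduces the loop's first-match substring test, and computes each shard's row range in the fused parameter from Llama-3-8B's numbers. translate and shard_rows are hypothetical helpers, not vLLM functions.

```python
# Rows of (param_name, weight_name, shard_id), mirroring stacked_params_mapping
# in vLLM's llama.py as an assumption; check your own checkout.
STACKED_PARAMS_MAPPING = [
    (".qkv_proj", ".q_proj", "q"),
    (".qkv_proj", ".k_proj", "k"),
    (".qkv_proj", ".v_proj", "v"),
    (".gate_up_proj", ".gate_proj", 0),
    (".gate_up_proj", ".up_proj", 1),
]


def translate(name):
    """Hypothetical helper: the first matching row wins, as in the load_weights loop."""
    for param_name, weight_name, shard_id in STACKED_PARAMS_MAPPING:
        if weight_name in name:
            return name.replace(weight_name, param_name), shard_id
    return name, None  # no row matched: the tensor loads into a same-named parameter


def shard_rows(shard_id, num_heads=32, num_kv_heads=8, head_dim=128, intermediate=14336):
    """Hypothetical helper: (start, end) rows of a shard inside its fused parameter."""
    q, kv = num_heads * head_dim, num_kv_heads * head_dim
    offsets = {
        "q": (0, q),
        "k": (q, q + kv),
        "v": (q + kv, q + 2 * kv),
        0: (0, intermediate),                 # gate half of gate_up_proj
        1: (intermediate, 2 * intermediate),  # up half of gate_up_proj
    }
    return offsets.get(shard_id)  # None for unfused tensors


for ckpt in [
    "model.layers.0.self_attn.q_proj.weight",
    "model.layers.0.self_attn.k_proj.weight",
    "model.layers.0.self_attn.v_proj.weight",
    "model.layers.0.mlp.gate_proj.weight",
    "model.layers.0.input_layernorm.weight",
]:
    target, shard = translate(ckpt)
    print(ckpt, "->", target, shard, shard_rows(shard))
```

Swapping in an MHA config (num_kv_heads=32) moves v's end row to 12288, the fingerprint discussed next.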

Reading the trace

  • The k/v rows are 4× narrower than q's — GQA's 8 KV heads vs 32 query heads, lab-03's slice asymmetry on a real 8B. If you ever see (12288, 4096) here instead, you're looking at an MHA model — the fused shape is an architecture fingerprint.
  • gate_up_proj at 28,672 rows: the MLP's two halves stacked; down_proj stays unfused (it has no sibling to stack with). The five rows of stacked_params_mapping claim 5 of the 9 tensors in each decoder layer (160 of the ~291 tensors in an 8B checkpoint, roughly 55%); everything else passes through under its own name.
  • Under TP=2, every fused shape halves along rows: Phase 10 lab-01's column-parallel sharding composed with lab-03's stacking. Each rank narrows its own heads' rows out of each checkpoint shard as it loads, so there are two slicings, one read, and no redistribution between ranks. That is the loading-is-part-of-sharding point from Phase 10, visible in a tensor shape.
  • enforce_eager=True keeps the trace clean (no capture pass cluttering logs — Phase 5 lab-04's test-suite setting, used for exactly its intended purpose).
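The TP=2 shape falls out of per-rank arithmetic. Assuming both head counts divide evenly by the TP size (so no KV-head replication), each rank takes a contiguous block of heads from each of the q, k and v checkpoint tensors and packs them back to back in its local fused buffer. The hypothetical qkv_rank_plan below computes, for one rank, which checkpoint rows it reads and where they land; it mirrors the narrow-then-copy idea of the column-parallel weight loaders but is not vLLM's code.

```python
def qkv_rank_plan(tp_size, rank, num_heads=32, num_kv_heads=8, head_dim=128):
    """Hypothetical helper: {shard: ((src_start, src_end), (dst_start, dst_end))}, local rows.

    src rows index the unfused checkpoint tensor (q_proj, k_proj or v_proj);
    dst rows index this rank's local qkv_proj parameter.
    """
    # Assumption: heads divide evenly, so every rank owns distinct KV heads.
    assert num_heads % tp_size == 0 and num_kv_heads % tp_size == 0
    q_local = num_heads // tp_size * head_dim
    kv_local = num_kv_heads // tp_size * head_dim
    plan, local_offset = {}, 0
    for shard, size in (("q", q_local), ("k", kv_local), ("v", kv_local)):
        src = (rank * size, (rank + 1) * size)     # this rank's heads in the checkpoint tensor
        dst = (local_offset, local_offset + size)  # packed position in the local fused buffer
        plan[shard] = (src, dst)
        local_offset += size
    return plan, local_offset  # local_offset ends as the local fused row count


for rank in range(2):
    plan, rows = qkv_rank_plan(tp_size=2, rank=rank)
    print(f"rank {rank}: local qkv rows = {rows}", plan)
```

Both ranks end with 3072 local rows, matching the captured (3072, 4096), and no rank ever needs another rank's slice.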

Hitchhiker's notes

  • --load-format dummy skips real weights (random init) — the tool for testing mapping and shapes without downloading 16 GB, and how CI exercises loaders cheaply. Pair with a tiny --max-model-len and loader bugs surface in seconds.
  • Watch for the unloaded-parameter check: upstream tracks which params got weights and errors on leftovers — the missed-tensor guard from lab-03's going-further, in production. When adding a model, that error message is your todo list.
  • Sharded checkpoints (multiple .safetensors files) interleave layers across files arbitrarily — the loader is order-independent by design (each tensor knows its name; the mapping doesn't care about file layout). Resist any urge to assume file order means anything.
  • Quantized variants add scale/zero tensors with their own names (...qweight, ...scales) routed by the quant method's loader (Phase 6) — same loop, more rows. Tracing one AWQ tensor through is the natural sequel to this lab.
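The unloaded-parameter guard is essentially a set difference. The sketch below is a toy loader, not vLLM's: it renames checkpoint tensors with a stacking table, records which parameters were written, and reports the leftovers. The function name load_names and the parameter names in the demo are hypothetical, and the error wording is ours rather than upstream's.

```python
# Toy stacking table: (param_name, weight_name); only q/k/v for brevity.
STACKED = [(".qkv_proj", ".q_proj"), (".qkv_proj", ".k_proj"), (".qkv_proj", ".v_proj")]


def load_names(param_names, checkpoint_names):
    """Toy loader: map checkpoint names to parameters, then report params never written."""
    loaded = set()
    for name in checkpoint_names:
        for param_name, weight_name in STACKED:
            if weight_name in name:
                name = name.replace(weight_name, param_name)
                break
        if name not in param_names:
            raise KeyError(f"no parameter named {name!r}")  # a mapping row is missing
        loaded.add(name)
    missing = sorted(set(param_names) - loaded)
    if missing:
        # The leftover list is the "todo list" when adding a model.
        raise ValueError(f"weights not initialized from checkpoint: {missing}")
    return loaded


params = {"layers.0.self_attn.qkv_proj.weight", "layers.0.mlp.down_proj.weight"}
ckpt = [
    "layers.0.self_attn.q_proj.weight",
    "layers.0.self_attn.k_proj.weight",
    "layers.0.self_attn.v_proj.weight",
]
try:
    load_names(params, ckpt)
except ValueError as err:
    print(err)  # names layers.0.mlp.down_proj.weight as never loaded
```

Without the final check, the toy loader would return happily and down_proj would keep its init values: exactly the "loads but outputs garbage" failure.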

Reflect

  • From the shapes alone — (6144, 4096) qkv, (28672, 4096) gate_up — reconstruct the model card: hidden size, head count, KV heads, MLP expansion. (4096 hidden; 32 heads × 128; 8 KV heads; 3.5× ffn ratio.) Checkpoint forensics is a real skill; you just did it.
  • A teammate's new model PR loads but outputs garbage; loading reported no errors. Using labs 01–03: what are your first three checks? (Mapping rows for every fused family — a missed one means init-valued weights; slice boundaries vs the config's head counts; and the q/k/v order in the fused buffer vs what the attention layer expects.)
  • Why does vLLM fuse at load time rather than shipping a conversion script? (Checkpoints stay interchange-format; the fusion choice is the runtime's, can change between versions, and composes with TP/quantization decided at startup — lab-03's interface-vs-implementation point, operationalized.)
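The first reflection can be checked mechanically. The sketch assumes a Llama-style layout in which q_proj is square (num_heads × head_dim equals the hidden size) and takes head_dim = 128 as a given, since shapes alone cannot separate head count from head size. reconstruct_config is a hypothetical helper.

```python
def reconstruct_config(qkv_shape, gate_up_shape, head_dim=128):
    """Hypothetical helper: recover model-card numbers from two fused weight shapes."""
    qkv_rows, hidden = qkv_shape
    gate_up_rows, gate_up_cols = gate_up_shape
    assert hidden == gate_up_cols, "both projections read the same hidden size"
    num_heads = hidden // head_dim           # assumes q rows == hidden (square q_proj)
    kv_rows = qkv_rows - hidden              # k and v rows together
    num_kv_heads = kv_rows // (2 * head_dim)
    intermediate = gate_up_rows // 2         # gate and up halves are equal
    return {
        "hidden_size": hidden,
        "num_heads": num_heads,
        "num_kv_heads": num_kv_heads,
        "intermediate_size": intermediate,
        "ffn_ratio": intermediate / hidden,
    }


print(reconstruct_config((6144, 4096), (28672, 4096)))
```

Feeding it the MHA fingerprint (12288, 4096) instead yields num_kv_heads equal to num_heads, which is the point of reading shapes as fingerprints.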

References

  • upstream/vllm/model_executor/models/llama.py: load_weights and stacked_params_mapping, the function under trace.
  • upstream/vllm/model_executor/layers/linear.py — the weight_loader callbacks that place each shard (lab-03's load_stacked, with TP).
  • upstream/vllm/model_executor/model_loader/ — the loader framework (formats, dummy loading, sharded files).
  • Lab-03 — the implementation this trace recognizes; lab-01 — the contract the loaded model serves through.