Lab 14-02 — Trace Weight Loading in the Real llama.py [GPU-OPT]
Lab-03 had you implement the translation; this lab has you watch the production
version run — and read it as a peer. You'll trace five tensors from a real Llama
checkpoint through load_weights: the safetensors name on disk, the mapping row that
claims it, the vLLM parameter and shard it lands in, and (under TP) which rows of
that shard each rank takes. The deliverable is the filled-in mapping table below —
five rows of checkpoint surgery, verified against a live load.
No GPU? Don't panic. Loading happens on CPU before anything touches CUDA — you can trace most of this with
device="cpu"-ish settings or just the captured table below plus the source. The reading is the lab.
Contents
- Why this lab exists
- Requirements
- Steps
- Captured mapping (Llama-3-8B, vLLM 0.22.1)
- Reading the trace
- Hitchhiker's notes
- Reflect
- References
Why this lab exists
load_weights is where most model-integration PRs live or die, and it's also the
single most readable "real" function in the model zoo once you have lab-03's
vocabulary — a loop, a mapping table, and weight_loader callbacks. Tracing five
tensors end-to-end converts the function from code-you-scroll-past into
code-you-could-have-written, and it arms you for the two production moments that
need this knowledge: a new checkpoint that won't load (which mapping row is
missing?), and a loaded model that generates garbage (which shard landed wrong?).
Requirements
uv pip install -e ".[vllm]"
huggingface-cli download meta-llama/Meta-Llama-3-8B-Instruct # or any Llama-family model
Steps
- List the checkpoint's names:
safetensors files start with a JSON header that lists every tensor, so you can enumerate names without loading any tensor data:
from safetensors import safe_open
import glob
names = []
for f in sorted(glob.glob("<model_dir>/*.safetensors")):
    with safe_open(f, framework="np") as sf:
        names += list(sf.keys())
print(len(names)) # ~291 for an 8B
print([n for n in names if ".layers.0." in n]) # one layer's worth
- Open upstream/vllm/model_executor/models/llama.py and find stacked_params_mapping and the load_weights loop. For each of the five tensor names in the table below, walk the loop by hand: which mapping row matches? What name does it become? Which shard_id rides along?
- Verify live: load the model with LLM(model=..., enforce_eager=True), then inspect a parameter's shape: model.model.layers[0].self_attn.qkv_proj.weight.shape → (6144, 4096). Reconcile it: 32 q-heads × 128 + 2 × (8 kv-heads × 128) = 4096 + 2048 = 6144. That is lab-03's qkv_slices, on a real tensor.
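The hand walk in step 2 can be rehearsed in plain Python before opening the source. The sketch below is a simplified stand-in for the name-translation part of the loop, not the upstream code: the mapping rows are written as (fused suffix, checkpoint suffix, shard id) from memory of upstream's table, so check them against llama.py, and translate is a hypothetical helper name.

```python
# Simplified sketch of stacked-parameter name translation.
# Rows: (fused_param_suffix, checkpoint_suffix, shard_id).
# Assumption: these rows mirror upstream's table; verify against the source.
STACKED_PARAMS_MAPPING = [
    (".qkv_proj", ".q_proj", "q"),
    (".qkv_proj", ".k_proj", "k"),
    (".qkv_proj", ".v_proj", "v"),
    (".gate_up_proj", ".gate_proj", 0),
    (".gate_up_proj", ".up_proj", 1),
]


def translate(name):
    """Return (vllm_param_name, shard_id); shard_id is None for pass-through tensors."""
    for fused, original, shard_id in STACKED_PARAMS_MAPPING:
        if original in name:
            # The first matching row claims the tensor and renames it.
            return name.replace(original, fused), shard_id
    # No row matched: the tensor loads under its own name.
    return name, None


five_tensors = [
    "model.layers.0.self_attn.q_proj.weight",
    "model.layers.0.self_attn.k_proj.weight",
    "model.layers.0.self_attn.v_proj.weight",
    "model.layers.0.mlp.gate_proj.weight",
    "model.layers.0.input_layernorm.weight",
]
for tensor_name in five_tensors:
    print(tensor_name, "->", translate(tensor_name))
```

The real loop does more than rename (it looks up the parameter and calls its weight_loader with the shard id), but the matching rule is the part worth predicting by hand.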
Captured mapping (Llama-3-8B, vLLM 0.22.1)
| checkpoint tensor (HF) | vLLM parameter | shard | rows in fused |
|---|---|---|---|
| model.layers.0.self_attn.q_proj.weight | ...self_attn.qkv_proj.weight | q | 0:4096 |
| model.layers.0.self_attn.k_proj.weight | ...self_attn.qkv_proj.weight | k | 4096:5120 |
| model.layers.0.self_attn.v_proj.weight | ...self_attn.qkv_proj.weight | v | 5120:6144 |
| model.layers.0.mlp.gate_proj.weight | ...mlp.gate_up_proj.weight | 0 | 0:14336 |
| model.layers.0.input_layernorm.weight | (itself) | — | unfused |
# live verification:
qkv_proj.weight.shape = (6144, 4096) # 4096 q + 1024 k + 1024 v (GQA: 8 kv heads)
gate_up_proj.weight.shape = (28672, 4096) # 14336 gate + 14336 up
# under tensor_parallel_size=2: (3072, 4096) per rank — heads split, slices halve
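The "rows in fused" column follows directly from the head counts. A minimal sketch of that arithmetic, assuming q, k, v are stacked in that order along rows; fused_qkv_slices is a hypothetical helper and may not match the signature of lab-03's qkv_slices.

```python
def fused_qkv_slices(num_heads, num_kv_heads, head_dim):
    """Row ranges of q, k, v inside a fused qkv weight (stacked in q, k, v order)."""
    q_rows = num_heads * head_dim
    kv_rows = num_kv_heads * head_dim
    return {
        "q": (0, q_rows),
        "k": (q_rows, q_rows + kv_rows),
        "v": (q_rows + kv_rows, q_rows + 2 * kv_rows),
    }


# Llama-3-8B: 32 query heads, 8 KV heads, head_dim 128, intermediate size 14336.
slices = fused_qkv_slices(num_heads=32, num_kv_heads=8, head_dim=128)
print(slices)  # {'q': (0, 4096), 'k': (4096, 5120), 'v': (5120, 6144)}

gate_up_rows = 2 * 14336  # gate and up stacked along rows
print(gate_up_rows)  # 28672
```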
Reading the trace
- The k/v rows are 4× narrower than q's: GQA's 8 KV heads vs 32 query heads, lab-03's slice asymmetry on a real 8B. If you ever see (12288, 4096) here instead, you're looking at an MHA model; the fused shape is an architecture fingerprint.
- gate_up_proj at 28,672 rows is the MLP's two halves stacked; down_proj stays unfused (it has no sibling to stack with). The five rows of stacked_params_mapping claim five of the nine tensors in each decoder layer (q, k, v, gate, up); everything else passes through under its own name.
- Under TP=2, every fused shape halves along rows. This is Phase 10 lab-01's column-parallel sharding composed with lab-03's stacking: each rank loads its heads' rows of each shard directly from disk. Two slicings, one read, no redistribution: the loading-is-part-of-sharding point from Phase 10, visible in a tensor shape.
- enforce_eager=True keeps the trace clean (no capture pass cluttering logs; this is Phase 5 lab-04's test-suite setting, used for exactly its intended purpose).
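The "two slicings" in the TP bullet can be made concrete: for each shard, a rank reads one row range from the checkpoint tensor and writes it to a different row range of its local fused parameter. The sketch below assumes head counts divide evenly by the TP size (it ignores the case where KV heads must be replicated); rank_rows and local_layout are hypothetical helper names, not upstream functions.

```python
def rank_rows(total_rows, tp_size, rank):
    """Contiguous row range of a checkpoint tensor owned by one rank."""
    assert total_rows % tp_size == 0, "sketch assumes an even split"
    per_rank = total_rows // tp_size
    return rank * per_rank, (rank + 1) * per_rank


def local_layout(shard_rows, tp_size, rank):
    """Map each shard to (src rows in the checkpoint tensor, dst rows in the local fused param)."""
    layout = {}
    offset = 0
    for shard_id, rows in shard_rows.items():
        src = rank_rows(rows, tp_size, rank)
        size = src[1] - src[0]
        layout[shard_id] = {"src": src, "dst": (offset, offset + size)}
        offset += size  # shards stay stacked in order inside the local parameter
    return layout


# Llama-3-8B qkv under tensor_parallel_size=2, seen from rank 1.
qkv_rows = {"q": 4096, "k": 1024, "v": 1024}
layout = local_layout(qkv_rows, tp_size=2, rank=1)
for shard_id, rows in layout.items():
    print(shard_id, rows)
local_rows = layout["v"]["dst"][1]
print(local_rows)  # 3072
```

Rank 1 takes the second half of each checkpoint tensor, yet its local q block still starts at row 0: the source range depends on the rank, the destination range only on the shard order.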
Hitchhiker's notes
- --load-format dummy skips real weights (random init). It is the tool for testing mapping and shapes without downloading 16 GB, and how CI exercises loaders cheaply. Pair it with a tiny --max-model-len and loader bugs surface in seconds.
- Watch for the unloaded-parameter check: upstream tracks which params got weights and errors on leftovers. This is the missed-tensor guard from lab-03's going-further, in production. When adding a model, that error message is your todo list.
- Sharded checkpoints (multiple .safetensors files) interleave layers across files arbitrarily. The loader is order-independent by design (each tensor knows its name; the mapping doesn't care about file layout). Resist any urge to assume file order means anything.
- Quantized variants add scale/zero tensors with their own names (...qweight, ...scales) routed by the quant method's loader (Phase 6): same loop, more rows. Tracing one AWQ tensor through is the natural sequel to this lab.
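The unloaded-parameter check mentioned above reduces to a set difference between the parameters the model declares and the names the loader actually touched. This is a sketch of the idea only; upstream's bookkeeping and error text differ, and find_unloaded is a hypothetical name.

```python
def find_unloaded(expected_params, loaded_params):
    """Return the parameter names that never received weights, sorted for stable output."""
    return sorted(set(expected_params) - set(loaded_params))


# Hypothetical example: a loader whose mapping forgot o_proj.
expected = [
    "model.layers.0.self_attn.qkv_proj.weight",
    "model.layers.0.self_attn.o_proj.weight",
    "model.layers.0.mlp.gate_up_proj.weight",
    "model.layers.0.mlp.down_proj.weight",
]
loaded = [
    "model.layers.0.self_attn.qkv_proj.weight",
    "model.layers.0.mlp.gate_up_proj.weight",
    "model.layers.0.mlp.down_proj.weight",
]
missing = find_unloaded(expected, loaded)
print(missing)  # ['model.layers.0.self_attn.o_proj.weight']
```

A real guard would raise on a non-empty result; printing here keeps the example runnable end to end.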
Reflect
- From the shapes alone, (6144, 4096) qkv and (28672, 4096) gate_up, reconstruct the model card: hidden size, head count, KV heads, MLP expansion. (4096 hidden; 32 heads × 128; 8 KV heads; 3.5× ffn ratio.) Checkpoint forensics is a real skill; you just did it.
- A teammate's new model PR loads but outputs garbage; loading reported no errors. Using labs 01–03: what are your first three checks? (Mapping rows for every fused family, since a missed one means init-valued weights; slice boundaries vs the config's head counts; and the q/k/v order in the fused buffer vs what the attention layer expects.)
- Why does vLLM fuse at load time rather than shipping a conversion script? (Checkpoints stay interchange-format; the fusion choice is the runtime's, can change between versions, and composes with TP/quantization decided at startup — lab-03's interface-vs-implementation point, operationalized.)
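The first reflection question can be checked mechanically. The sketch below inverts the fused shapes under two stated assumptions: the q block has as many rows as the hidden size (num_heads × head_dim = hidden_size, true for Llama-3-8B but not for every architecture), and head_dim is 128. infer_config is a hypothetical helper.

```python
def infer_config(qkv_shape, gate_up_shape, head_dim=128):
    """Recover basic config fields from fused weight shapes (rows, cols)."""
    qkv_rows, hidden = qkv_shape
    gate_up_rows, gate_up_cols = gate_up_shape
    assert gate_up_cols == hidden, "both weights should share the hidden dimension"

    q_rows = hidden                     # assumption: num_heads * head_dim == hidden_size
    kv_rows = (qkv_rows - q_rows) // 2  # k and v blocks are the same size
    intermediate = gate_up_rows // 2    # gate and up are the same size
    return {
        "hidden_size": hidden,
        "num_heads": q_rows // head_dim,
        "num_kv_heads": kv_rows // head_dim,
        "intermediate_size": intermediate,
        "ffn_ratio": intermediate / hidden,
    }


config = infer_config((6144, 4096), (28672, 4096))
print(config)
```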
References
- upstream/vllm/model_executor/models/llama.py: load_weights and stacked_params_mapping, the function under trace.
- upstream/vllm/model_executor/layers/linear.py: the weight_loader callbacks that place each shard (lab-03's load_stacked, with TP).
- upstream/vllm/model_executor/model_loader/: the loader framework (formats, dummy loading, sharded files).
- Lab-03: the implementation this trace recognizes; lab-01: the contract the loaded model serves through.