Lab 14-03 — Checkpoint Surgery: HF Names → vLLM Params, Shards → Fused [CPU-OK]
A HuggingFace checkpoint and a vLLM model disagree about what a layer is. The
checkpoint stores q_proj, k_proj, v_proj as three tensors; vLLM runs one fused
qkv_proj (one big GEMM beats three small ones — Phase 7 lab-03's tiling economics,
applied to layer design). Same for gate_proj+up_proj → gate_up_proj. Loading
weights is therefore translation: rename every checkpoint tensor to its vLLM
parameter, and copy shard tensors into the right slice of the fused buffer. This
lab has you build the translation table (llama.py's stacked_params_mapping, in
spirit), the GQA-aware slice arithmetic, and the shape guard that turns
wrong-checkpoint disasters into loud load-time errors — then prove the fusion legal
with the test that matters: the fused matmul's output slices equal the three separate
projections, exactly.
Contents
- Why this lab exists
- Background: why fused, and where the slices fall
- Files
- Run
- What the tests prove
- Hitchhiker's notes
- Going further
- References
Why this lab exists
When a newly-added model loads and generates fluent nonsense, the bug is almost never
in the forward pass — it's here, in the mapping: a shard landed in the wrong slice,
a name pattern missed a tensor (silently left at init values), or an MHA checkpoint
met a GQA config. These failures are maddening precisely because nothing crashes:
the shapes coincidentally fit, the matmuls run, the output is garbage. The two
defenses you'll build are the professional's toolkit: exact slice arithmetic
(derived, not pattern-matched) and assert-on-shape at load time
(test_shape_mismatch_is_loud — the wrong-checkpoint case caught at the door, not at
the demo).
This lab is also lab-02's prerequisite done right: lab-02 has you read
load_weights in llama.py; this lab has you implement its core first, so the
reading is recognition. The pairing (build small, then read big) is the course's
method; this is its purest instance — the production function is your three
functions plus a loop over the checkpoint.
Background: why fused, and where the slices fall
Why fuse at all: three matmuls over the same input x with weights Wq, Wk, Wv are
one matmul with the row-stacked weight in disguise, since
x @ [Wq; Wk; Wv]ᵀ = [x @ Wqᵀ | x @ Wkᵀ | x @ Wvᵀ]. The single GEMM launches once,
tiles better (Phase 7 lab-03: bigger M×N per weight-read), and reads x from memory
once instead of three times. The legality is two lines of block matrix algebra, and
test_fused_matmul_equals_separate_projections states it as an executable fact. (The
stacking runs along the output dimension, the same dimension a column-parallel layer
splits, so it composes with tensor parallelism too: QKVParallelLinear is Phase 10
lab-01's column-parallel class with this stacking built in, and the shard boundaries
respect head boundaries on every rank.)
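The block-matrix claim can be checked numerically in a few lines. This is a NumPy sketch with hypothetical sizes (4 query heads, 1 KV head, head_dim 16, hidden 64), not the lab's test code: stack the weights along the output dimension, run one GEMM, and split its columns back into q, k, v.

```python
import numpy as np

rng = np.random.default_rng(0)
tokens, hidden = 5, 64
# Hypothetical GQA-shaped widths: nh=4, nkv=1, hd=16.
q_rows, kv_rows = 4 * 16, 1 * 16

x = rng.standard_normal((tokens, hidden))
Wq = rng.standard_normal((q_rows, hidden))
Wk = rng.standard_normal((kv_rows, hidden))
Wv = rng.standard_normal((kv_rows, hidden))

# The fused qkv_proj weight: Wq, Wk, Wv stacked along the output (row) dimension.
W_qkv = np.concatenate([Wq, Wk, Wv], axis=0)

# One GEMM instead of three; output columns follow the same q | k | v order.
fused = x @ W_qkv.T
q, k, v = np.split(fused, [q_rows, q_rows + kv_rows], axis=1)

# Each column slice of the fused output recovers its separate projection.
for got, ref in [(q, x @ Wq.T), (k, x @ Wk.T), (v, x @ Wv.T)]:
    assert np.allclose(got, ref, rtol=0, atol=1e-12)
print(fused.shape)  # (5, 96)
```

The only thing fusion changes is where the rows live; the dot products are the same ones, which is why a tolerance near machine precision is the right bar.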
Where the slices fall — the GQA wrinkle: with nh query heads, nkv KV heads,
head_dim hd, the fused weight has (nh + 2·nkv)·hd rows: q owns the first nh·hd,
k the next nkv·hd, v the last nkv·hd. Under GQA (Phase 0 lab-02's 4× KV saving)
nkv < nh, so k and v slices are narrower than q's — the asymmetry
test_qkv_slices_account_for_gqa pins, and exactly the place hand-written loaders
go wrong when their author last looked at an MHA model.
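The slice arithmetic fits in one small function. A minimal sketch, assuming qkv_slices takes (num_heads, num_kv_heads, head_dim) and returns half-open (start, stop) row ranges; the starter's actual signature may differ. With the hypothetical config nh=4, nkv=1, hd=32, q gets 128 rows and k and v get 32 each.

```python
def qkv_slices(num_heads, num_kv_heads, head_dim):
    """Half-open (start, stop) row ranges of q, k, v inside the fused qkv_proj weight."""
    # GQA groups query heads evenly over KV heads.
    assert num_heads % num_kv_heads == 0, "GQA needs nh divisible by nkv"
    q_rows = num_heads * head_dim
    kv_rows = num_kv_heads * head_dim  # narrower than q_rows whenever nkv < nh
    return {
        "q": (0, q_rows),
        "k": (q_rows, q_rows + kv_rows),
        "v": (q_rows + kv_rows, q_rows + 2 * kv_rows),
    }


slices = qkv_slices(num_heads=4, num_kv_heads=1, head_dim=32)
print(slices)  # {'q': (0, 128), 'k': (128, 160), 'v': (160, 192)}
# Tiling check: each slice starts where the previous one stops, and the last
# stop equals the fused row count (nh + 2*nkv) * hd.
assert slices["q"][1] == slices["k"][0] and slices["k"][1] == slices["v"][0]
assert slices["v"][1] == (4 + 2 * 1) * 32
```

A loader written for MHA that splits the fused rows into three equal thirds would hand each projection 64 rows here: every boundary wrong, every shape plausible.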
The name mapping: a substring rewrite (q_proj → qkv_proj) plus a shard tag
telling the loader which slice. Tensors outside the table (norms, embeddings,
down_proj — anything unfused) map to themselves. Upstream's
stacked_params_mapping is literally this list of triples; your STACKED_PARAMS
copies its shape.
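The table and the rewrite fit in a few lines. This sketch copies the (param_name, shard_name, shard_id) triple shape of upstream's stacked_params_mapping; the exact entries and the map_weight_name signature are assumptions about the starter, not its actual code. The leading dot in each shard name is what keeps the match from being greedy: ".up_proj" does not occur inside ".gate_up_proj".

```python
# (param_name, shard_name, shard_id): fused vLLM parameter, checkpoint piece, slice tag.
STACKED_PARAMS = [
    (".qkv_proj", ".q_proj", "q"),
    (".qkv_proj", ".k_proj", "k"),
    (".qkv_proj", ".v_proj", "v"),
    (".gate_up_proj", ".gate_proj", 0),
    (".gate_up_proj", ".up_proj", 1),
]


def map_weight_name(ckpt_name):
    """Translate a checkpoint tensor name to (vLLM param name, shard_id or None)."""
    for param_name, shard_name, shard_id in STACKED_PARAMS:
        if shard_name in ckpt_name:
            # Swap only the projection name; the layer path is untouched.
            return ckpt_name.replace(shard_name, param_name), shard_id
    # Outside the table (norms, embeddings, down_proj, lm_head): maps to itself.
    return ckpt_name, None


print(map_weight_name("model.layers.3.self_attn.q_proj.weight"))
# ('model.layers.3.self_attn.qkv_proj.weight', 'q')
```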
Files
- starter.py: map_weight_name, qkv_slices, load_stacked (+ the STACKED_PARAMS table, provided). Your work.
- solution.py: reference.
- test_lab.py: the mapping, pass-throughs, GQA slice arithmetic, the fusion-legality proof, and the loud-mismatch guard.
Run
```bash
LAB_IMPL=starter pytest phase-14-model-architectures/labs/lab-03-weight-mapping -q
pytest phase-14-model-architectures/labs/lab-03-weight-mapping -q  # reference
```
What the tests prove
| Test | What it pins |
|---|---|
| test_name_mapping | The rewrite preserves the layer path and swaps only the projection name: model.layers.3.self_attn.q_proj.weight keeps its layers.3 identity |
| test_unstacked_names_pass_through | Norms, embeddings, down_proj, lm_head: shard_id None, name unchanged. A mapping that's too greedy (matching up_proj inside gate_up_proj-like names) fails here |
| test_qkv_slices_account_for_gqa | q gets 128 rows, k and v get 32 each, and the three slices tile the fused rows with no gaps and no overlap (assert the boundary equalities; off-by-ones here are the garbage-output bug) |
| test_fused_matmul_equals_separate_projections | The legality theorem: slice the fused output and recover each projection to 1e-12. Fusion is layout, not math: the course's paged-attention identity (Phase 2 lab-06), weight edition |
| test_shape_mismatch_is_loud | An MHA-width k shard against a GQA config: caught by the assert at load, with shapes in the message. The alternative is a demo that hallucinates |
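What the loud guard looks like in practice: a minimal NumPy sketch of load_stacked, assuming it receives the fused buffer, the shard, its shard_id and the head config (the starter's real signature may differ). It computes the target slice from the config, refuses any shard whose shape disagrees, and names the shapes in the message.

```python
import numpy as np


def load_stacked(fused, shard, shard_id, num_heads, num_kv_heads, head_dim):
    """Copy one q/k/v checkpoint shard into its row slice of the fused weight."""
    q_rows = num_heads * head_dim
    kv_rows = num_kv_heads * head_dim
    # (start row, row count) of each shard inside the fused (q | k | v) weight.
    start, rows = {
        "q": (0, q_rows),
        "k": (q_rows, kv_rows),
        "v": (q_rows + kv_rows, kv_rows),
    }[shard_id]
    expected = (rows, fused.shape[1])
    # The loud guard: refuse any shard whose shape disagrees with the config.
    assert shard.shape == expected, (
        f"{shard_id}_proj shard has shape {shard.shape}, expected {expected} "
        f"for nh={num_heads}, nkv={num_kv_heads}, hd={head_dim}"
    )
    fused[start:start + rows] = shard


nh, nkv, hd, hidden = 4, 1, 32, 64  # hypothetical GQA config
fused = np.zeros(((nh + 2 * nkv) * hd, hidden))
load_stacked(fused, np.ones((nkv * hd, hidden)), "k", nh, nkv, hd)  # GQA-width k: fits

try:
    # MHA-width k shard (nh * hd rows) against the GQA config above.
    load_stacked(fused, np.ones((nh * hd, hidden)), "k", nh, nkv, hd)
except AssertionError as err:
    message = str(err)
    print(message)
    # k_proj shard has shape (128, 64), expected (32, 64) for nh=4, nkv=1, hd=32
```

A plain assert disappears under python -O; a loader that must never skip the check can raise ValueError with the same message instead.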
Hitchhiker's notes
- Read load_weights right after this (upstream/vllm/model_executor/models/llama.py, search stacked_params_mapping): the production loop is your three functions plus reality, namely iterating safetensors shards, skipping rotary-embedding buffers, handling TP (each rank loads only its rows of each slice, so Phase 10 lab-01's sharding composes with this lab's stacking: two slicings, one tensor), and per-parameter weight_loader callbacks that encapsulate the slice placement. Your load_stacked is the weight_loader of QKVParallelLinear, minus distribution.
- Quantized checkpoints stack the stakes: AWQ/GPTQ tensors come with per-group scales and zero-points (Phase 6 lab-03) that must be sliced consistently with their weights, so a mapping bug now corrupts numerics in a way that is only statistically visible. Same machinery, smaller margin for error; the loud-assert habit pays double.
- The mapping table is per-architecture API: when a new HF model renames a tensor (mlp.experts.0.w1 vs block_sparse_moe...), vLLM's loader needs a new mapping entry. That is the single most common cause of "KeyError loading model X with vLLM version Y" issues. You can now read those tracebacks as "the translation table is missing a row" and often fix them yourself, which is a realistic shape for a first upstream PR.
- Why not store fused in the checkpoint? The checkpoint serves every runtime (HF transformers, llama.cpp, MLX...), each with its own fusion choices. Unfused is the interchange format; fusion is a runtime optimization. It is the same interface-vs-implementation split as Phase 11's unmerged LoRA, and the reason load_weights exists at all.
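The production loop has a recognizable skeleton. Below is a toy version under loud assumptions: parameters and checkpoint are plain dicts of NumPy arrays (no safetensors files, no TP), the table holds only the gate/up rows, plain functions stand in for vLLM's per-parameter weight_loader callbacks, and "rotary_emb.inv_freq" is used as the example of a buffer to skip. The returned set of touched names is a simplified take on upstream's loaded_params tracking.

```python
import numpy as np

# Mini table in the (param_name, shard_name, shard_id) shape; MLP rows only.
STACKED = [(".gate_up_proj", ".gate_proj", 0), (".gate_up_proj", ".up_proj", 1)]


def stacked_loader(param, shard, shard_id):
    """Place gate (shard_id 0) or up (shard_id 1) into its half of the fused rows."""
    rows = param.shape[0] // 2
    assert shard.shape == (rows, param.shape[1]), (shard.shape, param.shape)
    param[shard_id * rows:(shard_id + 1) * rows] = shard


def default_loader(param, tensor):
    """Unfused parameters: same name, same shape, straight copy."""
    assert param.shape == tensor.shape, (param.shape, tensor.shape)
    param[...] = tensor


def load_weights(params, checkpoint):
    """Translate and place every checkpoint tensor; return the parameter names touched."""
    loaded = set()
    for name, tensor in checkpoint.items():
        if "rotary_emb.inv_freq" in name:
            continue  # a derived buffer, not a learned weight: skip it
        for param_name, shard_name, shard_id in STACKED:
            if shard_name in name:
                target = name.replace(shard_name, param_name)
                stacked_loader(params[target], tensor, shard_id)
                break
        else:  # no table row matched: the tensor maps to itself
            target = name
            default_loader(params[target], tensor)
        loaded.add(target)
    return loaded


H, I = 4, 6  # hypothetical hidden and intermediate sizes
p = "model.layers.0.mlp."
params = {p + "gate_up_proj.weight": np.zeros((2 * I, H)), p + "down_proj.weight": np.zeros((H, I))}
checkpoint = {
    p + "gate_proj.weight": np.full((I, H), 1.0),
    p + "up_proj.weight": np.full((I, H), 2.0),
    p + "down_proj.weight": np.full((H, I), 3.0),
    "model.layers.0.self_attn.rotary_emb.inv_freq": np.ones(2),
}
loaded = load_weights(params, checkpoint)
missing = set(params) - loaded
assert not missing, f"never loaded: {sorted(missing)}"
```

A checkpoint name with no table row and no same-named parameter fails at params[target] with a KeyError: the "translation table is missing a row" symptom, surfaced at load time.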
Going further
- Add down_proj and embedding handling plus a full load_checkpoint(params, ckpt) driver: iterate a dict of fake checkpoint tensors, translate, place, and assert every parameter got touched exactly once (the missed-tensor bug class, made checkable; upstream tracks loaded_params for the same reason).
- Compose with TP: given tp_rank, tp_size, make load_stacked place only the rank's rows of each shard (q rows shard by head; k/v by KV head, and note what happens when nkv < tp_size: KV-head replication, the real constraint from Phase 10 lab-01's divisibility test).
- Write the MoE mapping rows (experts.N.w1 → experts.w13_weight with expert-index shards) by reading mixtral.py's table: the same idea with two stacking axes.
References
- upstream/vllm/model_executor/models/llama.py: load_weights + stacked_params_mapping; this lab, productionized (lab-02 reads it with you).
- upstream/vllm/model_executor/layers/linear.py: QKVParallelLinear.weight_loader; your load_stacked with TP.
- vLLM docs, Adding a New Model: where the mapping table fits in the integrator's checklist. https://docs.vllm.ai/en/latest/contributing/model/
- Phase 7 lab-03: why fused GEMMs win; Phase 10 lab-01: the sharding this stacking composes with; Phase 6 lab-03: the quantized version of the stakes.