Lab 14-03 — Checkpoint Surgery: HF Names → vLLM Params, Shards → Fused [CPU-OK]

A HuggingFace checkpoint and a vLLM model disagree about what a layer is. The checkpoint stores q_proj, k_proj, and v_proj as three tensors; vLLM runs one fused qkv_proj (one big GEMM beats three small ones: Phase 7 lab-03's tiling economics, applied to layer design). The MLP gets the same treatment: gate_proj + up_proj → gate_up_proj. Loading weights is therefore translation: rename every checkpoint tensor to its vLLM parameter, and copy each shard tensor into the right slice of the fused buffer. This lab has you build the translation table (llama.py's stacked_params_mapping, in spirit), the GQA-aware slice arithmetic, and the shape guard that turns wrong-checkpoint disasters into loud load-time errors. Then you prove the fusion legal with the test that matters: slicing the fused matmul's output recovers each of the three separate projections (to within 1e-12).

Why this lab exists

When a newly-added model loads and generates fluent nonsense, the bug is almost never in the forward pass — it's here, in the mapping: a shard landed in the wrong slice, a name pattern missed a tensor (silently left at init values), or an MHA checkpoint met a GQA config. These failures are maddening precisely because nothing crashes: the shapes coincidentally fit, the matmuls run, the output is garbage. The two defenses you'll build are the professional's toolkit: exact slice arithmetic (derived, not pattern-matched) and assert-on-shape at load time (test_shape_mismatch_is_loud — the wrong-checkpoint case caught at the door, not at the demo).

This lab also sets up lab-02 properly: lab-02 has you read load_weights in llama.py, and implementing its core here first turns that reading into recognition. The pairing (build small, then read big) is the course's method, and this is its purest instance: the production function is your three functions plus a loop over the checkpoint.

Background: why fused, and where the slices fall

Why fuse at all: three matmuls over the same input x with weights Wq, Wk, Wv equal one matmul with the row-stacked weight, x @ [Wq; Wk; Wv]ᵀ. The single GEMM launches once, tiles better (Phase 7 lab-03: bigger M×N per weight-read), and reads x from memory once instead of three times. The legality is two lines of block matrix algebra, and test_fused_matmul_equals_separate_projections states it as an executable fact. (The stacking composes with tensor parallelism too: stacking weight rows concatenates output columns, so QKVParallelLinear is Phase 10 lab-01's column-parallel class with this stacking built in, and its shard boundaries respect head boundaries on every rank.)
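A minimal pure-Python sketch of the block-matrix identity x·[Wq; Wk; Wv]ᵀ = [x·Wqᵀ | x·Wkᵀ | x·Wvᵀ]: stack the weight rows, run one matmul, and cut the output columns at the same boundaries. The sizes and helper names (matmul_t, rand_matrix) are illustrative assumptions, not the lab's code. Because every output element here is the same dot product computed in the same order, equality is exact; a real BLAS may block the fused and separate GEMMs differently, which is why the lab's test compares to 1e-12 rather than bit-for-bit.

```python
import random


def matmul_t(x, w):
    """x @ w^T with plain lists: x is (tokens, d_in), w is (d_out, d_in) like nn.Linear weights."""
    return [[sum(a * b for a, b in zip(row, w_row)) for w_row in w] for row in x]


def rand_matrix(rows, cols, rng):
    return [[rng.uniform(-1.0, 1.0) for _ in range(cols)] for _ in range(rows)]


rng = random.Random(0)
tokens, d_in, q_out, kv_out = 3, 8, 8, 2  # toy GQA-like widths (assumed)

x = rand_matrix(tokens, d_in, rng)
wq = rand_matrix(q_out, d_in, rng)
wk = rand_matrix(kv_out, d_in, rng)
wv = rand_matrix(kv_out, d_in, rng)

# The fused weight is the row-stack [Wq; Wk; Wv]: concatenate the row lists.
w_qkv = wq + wk + wv

# One matmul, then cut the output columns at the same boundaries as the weight rows.
fused_out = matmul_t(x, w_qkv)
q = [row[:q_out] for row in fused_out]
k = [row[q_out:q_out + kv_out] for row in fused_out]
v = [row[q_out + kv_out:] for row in fused_out]

# Same dot products in the same order, so the comparison can be exact here.
assert q == matmul_t(x, wq)
assert k == matmul_t(x, wk)
assert v == matmul_t(x, wv)
print("fused == separate with", len(w_qkv), "fused rows")
```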

Where the slices fall — the GQA wrinkle: with nh query heads, nkv KV heads, head_dim hd, the fused weight has (nh + 2·nkv)·hd rows: q owns the first nh·hd, k the next nkv·hd, v the last nkv·hd. Under GQA (Phase 0 lab-02's 4× KV saving) nkv < nh, so k and v slices are narrower than q's — the asymmetry test_qkv_slices_account_for_gqa pins, and exactly the place hand-written loaders go wrong when their author last looked at an MHA model.
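The slice arithmetic as a hypothetical helper (qkv_row_slices is an illustrative name; the lab's qkv_slices may return a different structure). Deriving each boundary from the previous slice's stop, rather than computing offsets independently, makes the no-gap, no-overlap property hold by construction. The example config (nh=4, nkv=1, hd=32) is assumed; it is one configuration that produces the 128/32/32 split the test describes.

```python
def qkv_row_slices(num_heads, num_kv_heads, head_dim):
    """Hypothetical helper: row ranges of q, k, v inside a fused qkv weight of shape
    ((num_heads + 2 * num_kv_heads) * head_dim, hidden)."""
    if num_heads % num_kv_heads != 0:
        raise ValueError("GQA needs num_heads to be a multiple of num_kv_heads")
    q_rows = num_heads * head_dim       # q: one head_dim block per query head
    kv_rows = num_kv_heads * head_dim   # k and v: one block per KV head (narrower under GQA)
    q = slice(0, q_rows)
    k = slice(q.stop, q.stop + kv_rows)
    v = slice(k.stop, k.stop + kv_rows)
    return q, k, v


q, k, v = qkv_row_slices(num_heads=4, num_kv_heads=1, head_dim=32)
total = (4 + 2 * 1) * 32
# The three slices tile [0, total) with no gap and no overlap.
assert q.start == 0 and q.stop == k.start and k.stop == v.start and v.stop == total
print(q, k, v)  # slice(0, 128, None) slice(128, 160, None) slice(160, 192, None)
```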

The name mapping: a substring rewrite (q_proj → qkv_proj) plus a shard tag telling the loader which slice to fill. Tensors outside the table (norms, embeddings, down_proj: anything unfused) map to themselves. Upstream's stacked_params_mapping is literally this list of triples; your STACKED_PARAMS copies its shape.
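A sketch of the name rewrite with a triple table modeled on upstream's stacked_params_mapping. STACKED_PARAMS_SKETCH and map_name are hypothetical names, and the entries are illustrative; check llama.py for the exact current values. The first matching fragment wins, and unmatched names pass through with shard_id None.

```python
# (vLLM param fragment, checkpoint fragment, shard id): modeled on upstream's
# stacked_params_mapping; the exact upstream entries may differ.
STACKED_PARAMS_SKETCH = [
    (".qkv_proj", ".q_proj", "q"),
    (".qkv_proj", ".k_proj", "k"),
    (".qkv_proj", ".v_proj", "v"),
    (".gate_up_proj", ".gate_proj", 0),
    (".gate_up_proj", ".up_proj", 1),
]


def map_name(ckpt_name, table=STACKED_PARAMS_SKETCH):
    """Hypothetical map_weight_name: return (vllm_param_name, shard_id or None)."""
    for param_frag, ckpt_frag, shard_id in table:
        # The leading dot anchors the match at a path boundary, so ".up_proj"
        # cannot match inside "gate_up_proj".
        if ckpt_frag in ckpt_name:
            return ckpt_name.replace(ckpt_frag, param_frag), shard_id
    return ckpt_name, None  # unfused tensors (norms, embeddings, down_proj) pass through


print(map_name("model.layers.3.self_attn.q_proj.weight"))
print(map_name("model.layers.3.mlp.down_proj.weight"))
```

The leading dot in each checkpoint fragment is the design choice doing the work: it anchors matches at a path boundary, so a too-greedy "up_proj" pattern never fires inside "gate_up_proj".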

Files

  • starter.py: map_weight_name, qkv_slices, load_stacked (+ the STACKED_PARAMS table, provided). Your work.
  • solution.py: reference.
  • test_lab.py: the mapping, pass-throughs, GQA slice arithmetic, the fusion-legality proof, and the loud-mismatch guard.

Run

LAB_IMPL=starter pytest phase-14-model-architectures/labs/lab-03-weight-mapping -q
pytest phase-14-model-architectures/labs/lab-03-weight-mapping -q   # reference

What the tests prove

| Test | What it pins |
|---|---|
| test_name_mapping | The rewrite preserves the layer path and swaps only the projection name: model.layers.3.self_attn.q_proj.weight keeps its layers.3 identity |
| test_unstacked_names_pass_through | Norms, embeddings, down_proj, lm_head: shard_id None, name unchanged. A mapping that's too greedy (matching up_proj inside gate_up_proj-like names) fails here |
| test_qkv_slices_account_for_gqa | q gets 128 rows, k and v get 32 each, and the three slices tile the fused rows with no gaps and no overlap (assert the boundary equalities; off-by-ones here are the garbage-output bug) |
| test_fused_matmul_equals_separate_projections | The legality theorem: slice the fused output and recover each projection to 1e-12. Fusion is layout, not math: the course's paged-attention identity (Phase 2 lab-06), weight edition |
| test_shape_mismatch_is_loud | An MHA-width k shard against a GQA config: caught by the assert at load, with shapes in the message. The alternative is a demo that hallucinates |
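A sketch of the load-time guard in the spirit of load_stacked, reproducing the test_shape_mismatch_is_loud scenario with toy sizes. load_qkv_shard, the list-of-rows weight representation, and the config values are assumptions for illustration; a real loader would copy into a slice of a tensor parameter.

```python
def load_qkv_shard(fused, shard, shard_id, num_heads, num_kv_heads, head_dim):
    """Hypothetical load_stacked-style helper. `fused` and `shard` are lists of rows
    (out_features x in_features); the shard is copied into its slice of `fused` in place."""
    q_rows, kv_rows = num_heads * head_dim, num_kv_heads * head_dim
    start = {"q": 0, "k": q_rows, "v": q_rows + kv_rows}[shard_id]
    want = (q_rows if shard_id == "q" else kv_rows, len(fused[0]))
    got = (len(shard), len(shard[0]) if shard else 0)
    # The loud guard: a shape disagreement stops the load, with both shapes in the message.
    assert got == want, f"{shard_id}_proj: checkpoint shape {got}, config expects {want}"
    fused[start:start + len(shard)] = [list(row) for row in shard]


hidden, nh, nkv, hd = 6, 4, 1, 2  # toy GQA config (assumed)
fused = [[0.0] * hidden for _ in range((nh + 2 * nkv) * hd)]  # 12 rows

# Correct GQA-width k shard: 1 KV head * 2 = 2 rows, landing at rows [8, 10).
load_qkv_shard(fused, [[1.0] * hidden] * 2, "k", nh, nkv, hd)

# MHA-width k shard (4 heads * 2 = 8 rows) against the GQA config: refused at load time.
try:
    load_qkv_shard(fused, [[2.0] * hidden] * 8, "k", nh, nkv, hd)
except AssertionError as err:
    print("caught:", err)  # caught: k_proj: checkpoint shape (8, 6), config expects (2, 6)
```

Note that python -O strips assert statements; if the guard must survive optimized runs, raising a ValueError with the same message is the safer choice.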

Hitchhiker's notes

  • Read load_weights right after this (upstream/vllm/model_executor/models/llama.py, search stacked_params_mapping): the production loop is your three functions plus reality — iterating safetensors shards, skipping rotary-embedding buffers, handling TP (each rank loads only its rows of each slice: Phase 10 lab-01's sharding composed with this lab's stacking — two slicings, one tensor), and weight_loader callbacks per parameter that encapsulate the slice placement. Your load_stacked is the weight_loader of QKVParallelLinear, minus distribution.
  • Quantized checkpoints stack the stakes: AWQ/GPTQ tensors come with scales and zero-points per group (Phase 6 lab-03) that must be sliced consistently with their weights — a mapping bug now corrupts numerics in a way that's only statistically visible. Same machinery, smaller margin for error; the loud-assert habit pays double.
  • The mapping table is per-architecture API: when a new HF model renames a tensor (mlp.experts.0.w1 vs block_sparse_moe...), vLLM's loader needs a new mapping entry. A missing entry is a very common cause of "KeyError loading model X with vLLM version Y" issues. You can now read those tracebacks as "the translation table is missing a row" and often fix them yourself; that's a real first upstream PR shape.
  • Why not store fused in the checkpoint? The checkpoint serves every runtime (HF transformers, llama.cpp, MLX...), each with its own fusion choices. Unfused is the interchange format; fusion is a runtime optimization — the same interface-vs-implementation split as Phase 11's unmerged LoRA, and the reason load_weights exists at all.
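The production loop from the first note can be sketched as a driver over a fake checkpoint: skip derived rotary buffers, translate each name, place the tensor, then check that every (parameter, shard) slot was filled exactly once, which makes the missed-tensor bug checkable. Everything here is hypothetical scaffolding (STACKED, translate, expected_slots, load_checkpoint); fake tensors are strings, placement is a dict write standing in for the slice copy, and the "rotary_emb.inv_freq" name is assumed as the example of a derived buffer. TP, safetensors iteration, and weight_loader callbacks are left out.

```python
# Hypothetical (param fragment, checkpoint fragment, shard id) table.
STACKED = [
    (".qkv_proj", ".q_proj", "q"),
    (".qkv_proj", ".k_proj", "k"),
    (".qkv_proj", ".v_proj", "v"),
    (".gate_up_proj", ".gate_proj", 0),
    (".gate_up_proj", ".up_proj", 1),
]


def translate(ckpt_name):
    for param_frag, ckpt_frag, shard_id in STACKED:
        if ckpt_frag in ckpt_name:
            return ckpt_name.replace(ckpt_frag, param_frag), shard_id
    return ckpt_name, None


def expected_slots(param_names):
    """Every (param, shard) pair that must be filled exactly once."""
    slots = set()
    for name in param_names:
        shard_ids = [s for frag, _, s in STACKED if frag in name] or [None]
        slots.update((name, s) for s in shard_ids)
    return slots


def load_checkpoint(params, checkpoint):
    """params: name -> {shard_id: tensor}. Returns the set of filled slots."""
    loaded = set()
    for ckpt_name, tensor in checkpoint.items():
        if "rotary_emb.inv_freq" in ckpt_name:
            continue  # derived buffer, recomputed rather than loaded
        name, shard_id = translate(ckpt_name)
        if name not in params:
            raise KeyError(f"{ckpt_name!r} maps to {name!r}: no such parameter (missing mapping row?)")
        slot = (name, shard_id)
        if slot in loaded:
            raise ValueError(f"{slot} loaded twice")
        params[name][shard_id] = tensor  # stand-in for copying into the shard's slice
        loaded.add(slot)
    missing = expected_slots(params) - loaded
    if missing:
        raise ValueError(f"never loaded (still at init values): {sorted(map(str, missing))}")
    return loaded


layer = "model.layers.0"
params = {f"{layer}.self_attn.qkv_proj.weight": {}, f"{layer}.mlp.gate_up_proj.weight": {}, "model.norm.weight": {}}
ckpt = {
    f"{layer}.self_attn.q_proj.weight": "Wq",
    f"{layer}.self_attn.k_proj.weight": "Wk",
    f"{layer}.self_attn.v_proj.weight": "Wv",
    f"{layer}.self_attn.rotary_emb.inv_freq": "freqs",
    f"{layer}.mlp.gate_proj.weight": "Wg",
    f"{layer}.mlp.up_proj.weight": "Wu",
    "model.norm.weight": "g",
}
print(len(load_checkpoint(params, ckpt)))  # 6 slots filled
```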

Going further

  • Add down_proj and embedding handling plus a full load_checkpoint(params, ckpt) driver: iterate a dict of fake checkpoint tensors, translate, place, and assert every parameter got touched exactly once (the missed-tensor bug class, made checkable — upstream tracks loaded_params for the same reason).
  • Compose with TP: given tp_rank, tp_size, make load_stacked place only the rank's rows of each shard (q rows shard by head; k/v by KV head — and note what happens when nkv < tp_size: KV-head replication, the real constraint from Phase 10 lab-01's divisibility test).
  • Write the MoE mapping rows (experts.N.w1 → experts.w13_weight with expert-index shards) by reading mixtral.py's table — the same idea with two stacking axes.
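For the TP exercise, a hypothetical sketch of the per-rank row arithmetic on the unfused checkpoint tensors: q splits by query head, k/v by KV head, and when num_kv_heads < tp_size each KV head is replicated across tp_size // num_kv_heads adjacent ranks. tp_qkv_rows and its return format are assumptions; upstream's QKVParallelLinear handles replication with its own code. Each rank's query heads still map to the KV head it loads, because GQA groups are contiguous runs of num_heads // num_kv_heads query heads.

```python
def tp_qkv_rows(num_heads, num_kv_heads, head_dim, tp_size, tp_rank):
    """Hypothetical sketch: rows of each unfused q/k/v checkpoint tensor that one TP rank
    loads (q by head; k/v by KV head, replicated when num_kv_heads < tp_size)."""
    if num_heads % tp_size:
        raise ValueError("num_heads must be divisible by tp_size")
    q_heads = num_heads // tp_size
    if num_kv_heads >= tp_size:
        if num_kv_heads % tp_size:
            raise ValueError("num_kv_heads must be divisible by tp_size")
        kv_heads = num_kv_heads // tp_size
        first_kv = tp_rank * kv_heads
    else:
        if tp_size % num_kv_heads:
            raise ValueError("tp_size must be a multiple of num_kv_heads")
        replicas = tp_size // num_kv_heads  # ranks that share one KV head
        kv_heads = 1
        first_kv = tp_rank // replicas      # adjacent ranks load the same KV head
    q = range(tp_rank * q_heads * head_dim, (tp_rank + 1) * q_heads * head_dim)
    kv = range(first_kv * head_dim, (first_kv + kv_heads) * head_dim)
    return {"q": q, "k": kv, "v": kv}


# 8 query heads, 2 KV heads, tp_size 4: each KV head is replicated on 2 ranks.
for rank in range(4):
    rows = tp_qkv_rows(num_heads=8, num_kv_heads=2, head_dim=4, tp_size=4, tp_rank=rank)
    print(rank, rows["q"], rows["k"])
```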

References

  • upstream/vllm/model_executor/models/llama.py: load_weights + stacked_params_mapping; this lab, productionized (lab-02 reads it with you).
  • upstream/vllm/model_executor/layers/linear.py: QKVParallelLinear.weight_loader; your load_stacked with TP.
  • vLLM docs, Adding a New Model — where the mapping table fits in the integrator's checklist: https://docs.vllm.ai/en/latest/contributing/model/
  • Phase 7 lab-03 — why fused GEMMs win; Phase 10 lab-01 — the sharding this stacking composes with; Phase 6 lab-03 — the quantized version of the stakes.