Lab 11-01 — Batched Multi-Adapter LoRA [CPU-OK]

A fine-tuned model is a base model plus a small correction — LoRA makes the correction a rank-r factorization (ΔW = B @ A, lab-03 prices it at ~1/400th of the base). The serving problem this lab solves: a single batch arrives carrying requests for different fine-tunes — tenant 1 wants the SQL adapter, tenant 2 the support-bot adapter, tenant 3 the plain base — and the engine must apply each row's own correction without forking the base computation. You'll implement the answer in three layers: the rank-r delta itself (shrink → expand), single-adapter application, and the batched grouped form — one shared base matmul for everyone, plus per-adapter-group deltas — proven exactly equal to the naive per-row loop. That grouped form is the punica/SGMV idea, and it's what makes multi-tenant fine-tune serving a product instead of a hack.

Why this lab exists

Multi-LoRA is this course's cleanest case study of a structural insight beating a resource problem. The naive reading of "serve 50 fine-tunes" is 50 model deployments — 50× the weights, 50× the GPUs (lab-03 does the bill). The structural reading: the 50 models share 99.75% of their parameters, so factor the computation the same way the parameters factor — shared base, per-tenant deltas. This lab makes you earn that reading by implementing it and proving equality, because the equality is the entire safety case: a tenant must get bit-for-bit (well, float-for-float) the same output from the shared deployment as from a dedicated one, or the consolidation is a quality regression wearing a cost-savings hat.

It's also the phase's foundation stone: lab-02 runs this exact computation on a GPU, lab-03 prices the structures you're multiplying, lab-04 manages which adapters are allowed into the batch. And the grouping pattern itself — sort work by its parameter-set, run one efficient op per group, scatter back — is Phase 7 lab-01's MoE permute trick with adapters in place of experts. Second appearance; it has a third (Phase 13's modality grouping). Learn the shape, not just the instance.

Background: shrink, expand, group

The delta for one row: Δy = scaling · (x @ Aᵀ) @ Bᵀ, i.e. shrink to the r-dimensional bottleneck (x @ Aᵀ: in→r), then expand back (r→out). Never materialize B @ A (that's an out × in matrix — the whole point is not to build it); the two skinny matmuls cost r·(in+out) multiplies per token vs the base's in·out — the ~128× compute shrink that mirrors lab-03's memory shrink. scaling (= α/r in the standard parametrization) is a training-side constant that rides along.
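
The shrink/expand formula above can be sketched in NumPy as follows. This is a minimal illustration, not the lab's reference solution: the signature of `lora_delta`, the shape convention (A is r × in, B is out × r), the toy sizes, and the value alpha = 16 are all assumptions made for the example. The multiply count at the end uses in = out = 4096 and r = 16, one set of sizes for which the quoted ~128× ratio works out exactly.

```python
import numpy as np

def lora_delta(x, A, B, scaling):
    """Rank-r delta for a batch of rows: scaling * (x @ A.T) @ B.T."""
    h = x @ A.T                    # shrink: (n, in) -> (n, r)
    return scaling * (h @ B.T)     # expand: (n, r) -> (n, out)

rng = np.random.default_rng(0)
n, d_in, d_out, r = 4, 64, 48, 8           # illustrative sizes only
x = rng.standard_normal((n, d_in))
A = rng.standard_normal((r, d_in))         # assumed layout: r x in
B = rng.standard_normal((d_out, r))        # assumed layout: out x r
scaling = 16 / r                           # alpha / r with a hypothetical alpha = 16

delta = lora_delta(x, A, B, scaling)

# Reference only: this materializes B @ A (out x in), which the real path avoids.
reference = scaling * (x @ (B @ A).T)
assert np.allclose(delta, reference)

def multiplies_per_token(d_in, d_out, r):
    """Multiplies for the two skinny matmuls vs. the dense base matmul."""
    return r * (d_in + d_out), d_in * d_out

lora_cost, base_cost = multiplies_per_token(4096, 4096, 16)
ratio = base_cost // lora_cost             # 16_777_216 // 131_072
```

The materialized form appears only as a check; the point of the two-step version is that nothing of size out × in is ever built for the adapter.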

The batch: each row is tagged with an entry in adapter_ids (-1 = base only). The grouped application:

  1. One base matmul for the whole batch: x @ Wᵀ, every row, regardless of adapter. This is the line that shares the expensive read (the base weights stream from HBM once — Phase 0 lab-04's bandwidth economics, multi-tenant edition).
  2. Per adapter group: gather that adapter's rows, run shrink/expand on the slice, scatter-add back. Segments of rows × one small GEMM each — "Segmented Gather Matrix-Vector multiply" (SGMV), named exactly for this shape.

Base-only rows simply skip step 2 — they cost nothing extra, which is why mixed base+adapter batches (lab-02's demo) are free to compose.
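
The two steps above, plus the naive per-row loop they must match, might look like this in NumPy. Treat it as a sketch under assumptions: the function names follow the lab's, but the signatures, the list-of-(A, B) adapter container, and the toy sizes are invented for illustration and may differ from starter.py.

```python
import numpy as np

def apply_batched(x, W, adapters, adapter_ids, scaling):
    """One shared base matmul, then one shrink/expand per adapter group."""
    out = x @ W.T                                  # step 1: every row, any adapter
    for aid in np.unique(adapter_ids):
        if aid < 0:
            continue                               # base-only rows skip step 2
        rows = np.nonzero(adapter_ids == aid)[0]   # gather this adapter's rows
        A, B = adapters[int(aid)]
        out[rows] += scaling * ((x[rows] @ A.T) @ B.T)   # scatter-add back
    return out

def apply_per_row(x, W, adapters, adapter_ids, scaling):
    """Naive reference: each row computed independently."""
    out = np.empty((x.shape[0], W.shape[0]))
    for i, aid in enumerate(adapter_ids):
        y = W @ x[i]
        if aid >= 0:
            A, B = adapters[int(aid)]
            y = y + scaling * (B @ (A @ x[i]))
        out[i] = y
    return out

rng = np.random.default_rng(0)
d_in, d_out, r = 32, 24, 4                         # illustrative sizes only
W = rng.standard_normal((d_out, d_in))
adapters = [(rng.standard_normal((r, d_in)), rng.standard_normal((d_out, r)))
            for _ in range(2)]                     # two hypothetical adapters
adapter_ids = np.array([0, -1, 1, 0, -1, 1])       # -1 marks base-only rows
x = rng.standard_normal((len(adapter_ids), d_in))

batched = apply_batched(x, W, adapters, adapter_ids, scaling=0.5)
per_row = apply_per_row(x, W, adapters, adapter_ids, scaling=0.5)
base_rows = adapter_ids == -1
```

Comparing `batched` against `per_row` with `np.allclose` is the equality the lab asks for; comparing `batched[base_rows]` against `(x @ W.T)[base_rows]` shows that base traffic is untouched by the adapter loop.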

Files

  • starter.py — lora_delta, apply_single, apply_batched. Your work.
  • solution.py — reference.
  • test_lab.py — batched ≡ per-row, base-only rows, the rank-r structure, and the shared-base property.

Run

LAB_IMPL=starter pytest phase-11-multi-lora/labs/lab-01-batched-lora-matmul -q
pytest phase-11-multi-lora/labs/lab-01-batched-lora-matmul -q   # reference

What to implement

The three functions per 02-mini-build.md. The one trap: in apply_batched, accumulate with indexed addition onto the base output (out[rows] += …) — and note that here, unlike Phase 7 lab-01's MoE scatter, plain fancy-indexed += is safe, because each row belongs to exactly one adapter (no duplicate indices). If you reflexively reached for np.add.at after Phase 7: good reflex, then notice why it's not needed — knowing when the footgun fires is better than fearing it always.
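
To see when the footgun fires, here is a small NumPy demonstration of fancy-indexed `+=` versus `np.add.at`, first with duplicate indices (the MoE-style case) and then with unique ones (the case here). The arrays are toy values chosen for the example.

```python
import numpy as np

# Duplicate indices: `+=` reads, adds, then writes back, so the two updates
# aimed at index 0 collapse into one.
idx = np.array([0, 0, 2])
buffered = np.zeros(3)
buffered[idx] += 1.0             # -> [1., 0., 1.]

# np.add.at applies every update, including repeats.
accumulated = np.zeros(3)
np.add.at(accumulated, idx, 1.0) # -> [2., 0., 1.]

# Unique indices, as in apply_batched (each row has exactly one adapter):
# both forms agree, so plain `+=` is enough.
rows = np.array([0, 2])
plain = np.zeros(3)
plain[rows] += 1.0
safe = np.zeros(3)
np.add.at(safe, rows, 1.0)
```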

What the tests prove

| Test | What it pins |
| --- | --- |
| batched ≡ per-row loop | The consolidation safety case: grouping is an execution strategy, not a semantics change — the course's master invariant, tenant edition |
| adapter_id == -1 rows equal pure base | Base traffic rides free in mixed batches; no adapter machinery touches it |
| the delta is genuinely rank-r | It factors through the r-dim bottleneck — a delta that doesn't is a bug that costs you the entire economics (you'd be applying a full-rank update at full-rank prices) |
| one shared base matmul | The structural win itself, asserted: the base is read once per batch, not once per tenant |
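
As an illustration of the rank-r row, one way such a property could be checked is to push the identity through the delta and measure the rank of the resulting linear map. This is a guess at the shape of the check, not a copy of test_lab.py; the `lora_delta` signature and all sizes are assumptions.

```python
import numpy as np

def lora_delta(x, A, B, scaling):
    return scaling * ((x @ A.T) @ B.T)

rng = np.random.default_rng(0)
d_in, d_out, r = 32, 24, 4                     # illustrative sizes only
A = rng.standard_normal((r, d_in))
B = rng.standard_normal((d_out, r))

# Feeding the identity recovers the delta's full (in x out) linear map.
delta_map = lora_delta(np.eye(d_in), A, B, scaling=0.5)
rank = np.linalg.matrix_rank(delta_map)
assert rank <= r                               # factors through the bottleneck

# Contrast: an unconstrained update of the same shape is generally full rank.
full_update = rng.standard_normal((d_in, d_out))
full_rank = np.linalg.matrix_rank(full_update)
```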

Hitchhiker's notes

  • Map to upstream: add_shrink / add_expand in upstream/vllm/lora/punica_wrapper/punica_base.py (and the CPU reference in punica_cpu.py — genuinely readable, go diff it against your solution) are your two halves of lora_delta; add_lora_linear is your apply_batched. The GPU versions fuse the segment loop into one kernel launch indexed by lab-04's slot ids — grouping logic identical, loop distributed across the grid.
  • Where LoRA hooks into the model: every ColumnParallelLinear / RowParallelLinear (Phase 10 lab-01!) gets a LoRA-aware wrapper that adds the delta after the base matmul. Under tensor parallelism the adapter shards along the same axes as its base layer — A with the input shard, B with the output shard — so TP × LoRA composes with no new collectives. Layer abstractions that compose are what make features multiply instead of interfere; vLLM's linear-layer stack is the load-bearing example.
  • Why group at all, on a GPU? The per-row loop launches a skinny matmul per request; the grouped form launches per adapter — and within a group, the rows share the adapter's A/B read (the tiling/reuse argument of Phase 7 lab-03, at miniature scale). With 64 rows across 4 adapters, that's 4 well-shaped small GEMMs vs 64 degenerate ones. Same arithmetic, ~order-of-magnitude better hardware shape.
  • The delta is dense in the batch dimension but tiny in compute — so multi-LoRA overhead rides almost entirely on decode steps' idle compute (Phase 0 lab-04's story again: bandwidth-bound steps have FLOPs to spare, and the adapter's extra bytes are 32 MiB against the base's 13 GiB). This is why lab-02's capture shows no visible throughput tax — and why the claim "LoRA serving is nearly free" survives measurement.
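
The launch-count argument in the notes above (64 rows across 4 adapters) can be made concrete with a few lines of NumPy. The even 16-rows-per-adapter split is an assumption for the example; real batches will be lumpier.

```python
import numpy as np

rng = np.random.default_rng(0)
adapter_ids = np.repeat(np.arange(4), 16)      # 64 rows, 4 adapters, 16 rows each
rng.shuffle(adapter_ids)                       # arrival order is arbitrary

per_row_launches = adapter_ids.size            # one skinny matmul per request
groups, counts = np.unique(adapter_ids, return_counts=True)
grouped_launches = groups.size                 # one small GEMM per adapter

# Within a group, every row reuses the same A/B read.
rows_sharing_each_read = counts.tolist()
```

The arithmetic is identical in both forms; only the number of launches and the amount of weight reuse per launch change.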

Going further

  • Implement the fused-into-base alternative for a single-adapter batch ((W + scaling·B@A) materialized, one matmul) and benchmark both in numpy at batch 1 vs 64. Merging wins single-tenant; grouping wins multi-tenant — find the crossover and you've reproduced the deployment decision lab-03's notes describe.
  • Add rank heterogeneity: adapters of rank 8, 16, 64 in one batch (real fleets have this). Your grouped loop handles it naturally; the slot-buffer version (lab-04) pads everyone to max_lora_rank — compute the padding waste and you've found why that config knob is set with gritted teeth.
  • Wire it into mini_vllm: adapter id on the Request, deltas applied to the toy model's logits per row. Multi-tenant mini-serving in ~30 lines — and the scheduler interaction (lab-04's max_schedulable) has a home to land in.
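
As a starting point for the first two exercises, the sketch below checks that the merged form agrees with the split form for a single adapter, and computes the padding waste for the rank mix 8/16/64. The sizes, the scaling value, and the cost model (one adapter per rank, work proportional to rank) are simplifying assumptions, and no benchmark is included.

```python
import numpy as np

rng = np.random.default_rng(0)
n, d_in, d_out, r = 8, 32, 24, 4               # illustrative sizes only
x = rng.standard_normal((n, d_in))
W = rng.standard_normal((d_out, d_in))
A = rng.standard_normal((r, d_in))
B = rng.standard_normal((d_out, r))
scaling = 0.5

# Merged: materialize W + scaling * B @ A once, then a single matmul.
W_merged = W + scaling * (B @ A)
y_merged = x @ W_merged.T

# Split: shared base matmul plus shrink/expand delta.
y_split = x @ W.T + scaling * ((x @ A.T) @ B.T)

# Padding waste when every adapter is padded to the largest rank.
ranks = [8, 16, 64]
max_lora_rank = max(ranks)
useful = sum(ranks)                            # 88
padded = max_lora_rank * len(ranks)            # 192
waste = 1 - useful / padded                    # fraction of padded work that is zeros
```

Under these assumptions a little over half of the padded rank dimension is zeros, which is the cost the max_lora_rank exercise asks you to quantify.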

References

  • upstream/vllm/lora/punica_wrapper/punica_cpu.py — the readable reference your solution mirrors; punica_base.py — add_shrink/add_expand/add_lora_linear.
  • Hu et al., LoRA (2021) — the factorization: https://arxiv.org/abs/2106.09685
  • Chen et al., Punica: Multi-Tenant LoRA Serving (2023) — SGMV, the kernel this lab's grouping becomes: https://arxiv.org/abs/2310.18547
  • Phase 7 lab-01 — the same grouping pattern with experts; lab-03 — the economics; lab-04 — which adapters get into the batch at all.