Lab 10-01 — Tensor Parallelism Math [CPU-OK]
A 70B model's weights don't fit on your GPU. Tensor parallelism's answer is almost insolent in its simplicity: a matrix multiply distributes over slicing — cut the weight matrix into N pieces, give each GPU one piece, and the partial results reassemble into exactly the unsharded answer. This lab makes you prove it, in numpy, with the two sharding patterns that production TP is built from — column-parallel (slice outputs, reassemble by concatenation = all-gather) and row-parallel (slice inputs, reassemble by summation = all-reduce) — and then the composition trick that makes a whole transformer block cost only one all-reduce: pair them column→row, and the intermediate never needs reassembling at all.
Contents
- Why this lab exists
- Background: two shardings and the pairing trick
- Files
- Run
- What to implement
- What the tests prove
- Hitchhiker's notes
- Going further
- References
Why this lab exists
Distributed inference has a reputation for being infrastructure wizardry — Ray
clusters, NCCL, process groups — and that reputation obscures the fact that the core
is linear algebra a laptop verifies in milliseconds. Separating the two layers is the
point of this lab: the math (which sharding produces which partial result, and what
collective reassembles it) is exact, provable, and small; the infrastructure (Phase
10's deep-dive: process groups, communicators, weight loaders that shard at load time)
exists to execute that math. Engineers who learn the infrastructure first treat
ColumnParallelLinear as an incantation; engineers who learn the math first read it
as "my column_parallel, with NCCL where my np.concatenate is."
The one-all-reduce composition is the part that earns the word design. Naively
sharding two consecutive matmuls costs a collective after each. The Megatron insight —
which every serving stack inherited — is that the column shard's output is already
partitioned exactly the way the row shard's input wants it: the activation flows from
shard to shard without ever being whole. Communication is designed out, not optimized
out. You'll assert it: num_all_reduces == 1.
Background: two shardings and the pairing trick
For y = x @ W.T (W: (out, in)):
- Column-parallel (shard W's output rows): rank r computes x @ W_r.T, a slice of y's columns. Reassembly = concatenation (all-gather). Every rank needs all of x, which it has (the previous all-reduce ended with everyone holding the full activation).
- Row-parallel (shard W's input columns): rank r holds W[:, r·c:(r+1)·c] and only the matching slice of x, computing a full-shaped but partial y_r. Reassembly = elementwise sum (all-reduce).
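The two shardings above can be sketched in a few lines of numpy. This is a minimal illustration, not the lab's reference solution: the function names `column_parallel_sketch` and `row_parallel_sketch`, the 2-D shape of `x`, and the test shapes are all assumptions made for the example.

```python
import numpy as np

def column_parallel_sketch(x, W, num_ranks):
    """Shard W's output rows; np.concatenate stands in for the all-gather."""
    out_features = W.shape[0]
    assert out_features % num_ranks == 0, "output dim must divide across ranks"
    chunk = out_features // num_ranks
    # Each list element is one rank's work: the full x, one slice of W's rows.
    partials = [x @ W[r * chunk:(r + 1) * chunk].T for r in range(num_ranks)]
    return np.concatenate(partials, axis=-1)

def row_parallel_sketch(x, W, num_ranks):
    """Shard W's input columns; the running sum stands in for the all-reduce."""
    in_features = W.shape[1]
    assert in_features % num_ranks == 0, "input dim must divide across ranks"
    chunk = in_features // num_ranks
    y = np.zeros((x.shape[0], W.shape[0]))
    for r in range(num_ranks):
        cols = slice(r * chunk, (r + 1) * chunk)
        # Full-shaped but partial result from this rank's slice of x and W.
        y += x[:, cols] @ W[:, cols].T
    return y

rng = np.random.default_rng(0)
x = rng.standard_normal((3, 8))    # (batch, in), hypothetical sizes
W = rng.standard_normal((16, 8))   # (out, in)
dense = x @ W.T

for n in (1, 2, 4, 8):
    assert np.allclose(column_parallel_sketch(x, W, n), dense)
    assert np.allclose(row_parallel_sketch(x, W, n), dense)
```

The comparison uses `np.allclose` rather than exact equality because the row-parallel sum adds terms in a different order than the dense matmul.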
The MLP composition: W1 column-parallel → each rank holds a slice of the hidden
activation → apply the nonlinearity per-shard (elementwise, so it commutes with
slicing — this is why the trick works for ReLU/SiLU but would break for anything
mixing hidden dims) → W2 row-parallel consumes exactly that slice → one all-reduce
at the end. Attention follows the same pattern with heads as the natural column
boundary: QKV projections column-parallel (each rank owns whole heads), out-proj
row-parallel. Two blocks per layer, one all-reduce each — lab-03 prices them.
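The MLP pairing described above can be sketched as follows. The name `mlp_tp_sketch`, the choice of ReLU, and the shapes are assumptions for illustration; the lab's own `mlp_tp` may differ in signature and return value.

```python
import numpy as np

def relu(a):
    return np.maximum(a, 0.0)

def mlp_tp_sketch(x, W1, W2, num_ranks):
    """Column-parallel W1, per-shard ReLU, row-parallel W2, one all-reduce."""
    hidden = W1.shape[0]
    assert W2.shape[1] == hidden, "W2's input dim must match W1's output dim"
    assert hidden % num_ranks == 0, "hidden dim must divide across ranks"
    chunk = hidden // num_ranks

    partials = []
    for r in range(num_ranks):
        rows = slice(r * chunk, (r + 1) * chunk)
        h_r = relu(x @ W1[rows].T)           # this rank's slice of the hidden activation
        partials.append(h_r @ W2[:, rows].T)  # full-shaped, partial output

    num_all_reduces = 0
    y = np.sum(partials, axis=0)  # stand-in for the single all-reduce
    num_all_reduces += 1
    return y, num_all_reduces

rng = np.random.default_rng(1)
x = rng.standard_normal((2, 4))     # (batch, d_model), hypothetical sizes
W1 = rng.standard_normal((8, 4))    # (hidden, d_model)
W2 = rng.standard_normal((4, 8))    # (d_model, hidden)
dense = relu(x @ W1.T) @ W2.T

y, count = mlp_tp_sketch(x, W1, W2, num_ranks=4)
assert np.allclose(y, dense)
assert count == 1
```

Note that the hidden activation only ever exists as `h_r` slices inside the loop; nothing concatenates them.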
Files
- starter.py: column_parallel, row_parallel, mlp_tp. Your work.
- solution.py: reference.
- test_lab.py: exact reconstruction for several rank counts, the one-all-reduce property, and the divisibility constraint.
Run
LAB_IMPL=starter pytest phase-10-distributed-inference/labs/lab-01-tp-sharding-math -q
pytest phase-10-distributed-inference/labs/lab-01-tp-sharding-math -q # reference
What to implement
Per 02-mini-build.md. The loop over ranks is the
simulation — each iteration is one GPU's life; the concatenate and the running sum
are the collectives. Keep that mapping conscious: when you later read real TP code,
every line will be one of your loop bodies with the loop distributed across processes.
What the tests prove
| Test | What it pins |
|---|---|
| column/row reconstruct x @ W.T exactly | Sharding is algebra, not approximation: exact to machine precision for num_ranks ∈ {1, 2, 4, 8}. Rank-count invariance is itself the deployment-critical property: TP=4 and TP=8 must serve identical models |
| mlp_tp == dense MLP with num_all_reduces == 1 | The Megatron pairing: the hidden activation never reassembles. The counter in the return value is the design, made falsifiable |
| divisibility asserted | hidden % num_ranks == 0 — why TP sizes are powers of two and why some models can't run at TP=6: head counts and hidden dims must divide. A real constraint users hit (GQA's 8 KV heads cap practical TP at 8 without head replication) |
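As a rough illustration of the divisibility constraint, the sketch below lists which TP sizes divide a model's dimensions evenly. The helper `valid_tp_sizes` and the example dimensions are hypothetical, and the check assumes no KV-head replication.

```python
def valid_tp_sizes(hidden, num_heads, num_kv_heads, max_tp=16):
    """Return TP sizes that evenly divide hidden dim, heads, and KV heads."""
    return [
        tp for tp in range(1, max_tp + 1)
        if hidden % tp == 0 and num_heads % tp == 0 and num_kv_heads % tp == 0
    ]

# Hypothetical GQA-style shape: 8 KV heads cap the list at TP=8.
sizes = valid_tp_sizes(hidden=8192, num_heads=64, num_kv_heads=8)
print(sizes)  # [1, 2, 4, 8]
assert 6 not in sizes
```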
Hitchhiker's notes
- Floating point note: the row-parallel sum reorders additions vs the dense matmul, so on real hardware TP=2 and TP=1 differ in the last ulp — the recurring last-ulp story (Phases 3/4/6/9), now with rank count as the trigger. Your float64 numpy hides it; fp16 GPUs don't. "Different outputs at different TP sizes" bug reports are usually this, not a bug.
- Map to upstream: ColumnParallelLinear / RowParallelLinear in upstream/vllm/model_executor/layers/linear.py. Find the single all_reduce in the row class's forward (linear.py:1392), and notice gather_output=False on the column class: the default is the paired pattern, all-gather elided. Model code composes these two classes and TP falls out; that's why adding a new model (Phase 14) barely thinks about TP.
- Weights are sharded at load time, not runtime: each rank reads only its slice from the checkpoint (the weight loader's shard_id machinery). The lab's W[r*chunk:(r+1)*chunk] is, in production, a file-read pattern: TP=8 startup reads each tensor once across 8 processes. Loading is part of the sharding design, not an afterthought.
- Embedding and LM head shard on the vocabulary dimension (vocab-parallel): same two patterns, different axis, with a gather at the logits. Every weight matrix in the model has a natural slicing axis; TP is the discipline of choosing axes so the collectives stay rare.
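The floating-point note can be made concrete with a deliberately exaggerated float32 example. The values are chosen so that regrouping the additions changes the result by far more than one ulp; real TP differences are usually much smaller. The function names and inputs are assumptions for illustration, and explicit Python loops are used so the summation order is fixed rather than left to a BLAS kernel.

```python
import numpy as np

def sequential_dot(x, w):
    """Dot product accumulated left to right in float32."""
    acc = np.float32(0.0)
    for xi, wi in zip(x, w):
        acc = np.float32(acc + xi * wi)
    return acc

def row_parallel_dot(x, w, num_ranks):
    """Each rank sums its own slice; partials are then summed (the all-reduce)."""
    chunk = len(x) // num_ranks
    partials = [
        sequential_dot(x[r * chunk:(r + 1) * chunk], w[r * chunk:(r + 1) * chunk])
        for r in range(num_ranks)
    ]
    total = np.float32(0.0)
    for p in partials:
        total = np.float32(total + p)
    return total

# float32 spacing near 1e8 is 8, so adding 1.0 to +/-1e8 is lost to rounding.
x = np.array([1e8, 1.0, -1e8, 1.0], dtype=np.float32)
w = np.ones(4, dtype=np.float32)

tp1 = sequential_dot(x, w)        # ((1e8 + 1) - 1e8) + 1
tp2 = row_parallel_dot(x, w, 2)   # (1e8 + 1) + (-1e8 + 1)
print(float(tp1), float(tp2))     # 1.0 0.0
```

Neither answer matches the exact sum of 2.0; the point is only that the grouping, and therefore the rank count, changes which rounding errors occur.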
Going further
- Add attention_tp(x, Wqkv, Wo, num_heads, num_ranks): heads as the column boundary, out-proj row-parallel, assert one all-reduce and head-count divisibility. You've now sharded both halves of a real layer.
- Implement gather_output=True (the elided all-gather) and count collectives for the unpaired composition: two matmuls sharded naively. The diff against mlp_tp's 1 is the Megatron paper's contribution, measured by your counter.
- Simulate a wrong sharding: shard W1 row-parallel instead of column-parallel, apply the nonlinearity to each rank's partial sum, and watch the reconstruction break (the sum of per-rank ReLUs ≠ the ReLU of the full sum). The elementwise-commutes-with-slicing condition, demonstrated by violating it.
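The condition in the last item can be seen on a toy vector before writing the full experiment. The values below are made up for illustration: ReLU commutes with slicing a vector, but not with summing partial results.

```python
import numpy as np

def relu(a):
    return np.maximum(a, 0.0)

# Slicing: ReLU of a slice equals the slice of the ReLU (elementwise op).
h = np.array([1.0, -2.0, -3.0, 1.0])
assert np.array_equal(relu(h)[:2], relu(h[:2]))

# Summation: two ranks each hold a partial sum of the same hidden vector.
partial_rank0 = np.array([1.0, -2.0])
partial_rank1 = np.array([-3.0, 1.0])

relu_after_sum = relu(partial_rank0 + partial_rank1)         # correct order
sum_after_relu = relu(partial_rank0) + relu(partial_rank1)   # wrong sharding

print(relu_after_sum)  # [0. 0.]
print(sum_after_relu)  # [1. 1.]
assert not np.array_equal(relu_after_sum, sum_after_relu)
```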
References
- Shoeybi et al., Megatron-LM: Training Multi-Billion Parameter Language Models Using Model Parallelism (2019) — the column→row pairing, Figure 3: https://arxiv.org/abs/1909.08053
- upstream/vllm/model_executor/layers/linear.py: the two classes and the one all-reduce (:1392).
- upstream/vllm/model_executor/layers/vocab_parallel_embedding.py: the same idea on the vocab axis.
- Lab-03: what the one all-reduce costs; lab-02: the memory split, observed live.