Lab 10-01 — Tensor Parallelism Math [CPU-OK]

A 70B model's weights don't fit on your GPU. Tensor parallelism's answer is almost insolent in its simplicity: a matrix multiply distributes over slicing — cut the weight matrix into N pieces, give each GPU one piece, and the partial results reassemble into exactly the unsharded answer. This lab makes you prove it, in numpy, with the two sharding patterns that production TP is built from — column-parallel (slice outputs, reassemble by concatenation = all-gather) and row-parallel (slice inputs, reassemble by summation = all-reduce) — and then the composition trick that makes a whole transformer block cost only one all-reduce: pair them column→row, and the intermediate never needs reassembling at all.


Why this lab exists

Distributed inference has a reputation for being infrastructure wizardry — Ray clusters, NCCL, process groups — and that reputation obscures the fact that the core is linear algebra a laptop verifies in milliseconds. Separating the two layers is the point of this lab: the math (which sharding produces which partial result, and what collective reassembles it) is exact, provable, and small; the infrastructure (Phase 10's deep-dive: process groups, communicators, weight loaders that shard at load time) exists to execute that math. Engineers who learn the infrastructure first treat ColumnParallelLinear as an incantation; engineers who learn the math first read it as "my column_parallel, with NCCL where my np.concatenate is."

The one-all-reduce composition is the part that earns the word design. Naively sharding two consecutive matmuls costs a collective after each. The Megatron insight — which every serving stack inherited — is that the column shard's output is already partitioned exactly the way the row shard's input wants it: the activation flows from shard to shard without ever being whole. Communication is designed out, not optimized out. You'll assert it: num_all_reduces == 1.

Background: two shardings and the pairing trick

For y = x @ W.T (W: (out, in)):

  • Column-parallel (shard W's output rows): rank r computes x @ W_r.T, a slice of y's columns. Reassembly = concatenation (all-gather). Every rank needs all of x, which it has (the previous all-reduce ended with everyone holding the full activation).
  • Row-parallel (shard W's input columns): rank r holds W[:, r·c:(r+1)·c], where c = in / N is the per-rank chunk width, and only the matching slice of x, computing a full-shaped but partial y_r. Reassembly = elementwise sum (all-reduce).
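
The two shardings above can be sketched in numpy as follows. This is a minimal illustration, not the lab's reference solution: the signatures `column_parallel(x, W, num_ranks)` and `row_parallel(x, W, num_ranks)` are assumed here, and the loop over `r` stands in for separate ranks.

```python
import numpy as np

def column_parallel(x, W, num_ranks):
    """Shard W's output rows; reassemble by concatenation (the all-gather)."""
    out_features = W.shape[0]
    assert out_features % num_ranks == 0, "out_features must divide evenly"
    chunk = out_features // num_ranks
    # Each list element is one simulated rank: full x, one row-slice of W.
    parts = [x @ W[r * chunk:(r + 1) * chunk].T for r in range(num_ranks)]
    return np.concatenate(parts, axis=-1)

def row_parallel(x, W, num_ranks):
    """Shard W's input columns; reassemble by elementwise sum (the all-reduce)."""
    in_features = W.shape[1]
    assert in_features % num_ranks == 0, "in_features must divide evenly"
    chunk = in_features // num_ranks
    y = np.zeros(x.shape[:-1] + (W.shape[0],))
    for r in range(num_ranks):
        sl = slice(r * chunk, (r + 1) * chunk)
        # Full-shaped but partial result from the matching slices of x and W.
        y += x[..., sl] @ W[:, sl].T
    return y

rng = np.random.default_rng(0)
x = rng.standard_normal((3, 8))    # (batch, in)
W = rng.standard_normal((16, 8))   # (out, in)
dense = x @ W.T

for n in (1, 2, 4, 8):
    assert np.allclose(column_parallel(x, W, n), dense)
    assert np.allclose(row_parallel(x, W, n), dense)
```

Both functions return the same array as the dense product for every rank count tried, which is the reconstruction property the tests pin.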

The MLP composition: W1 column-parallel → each rank holds a slice of the hidden activation → apply the nonlinearity per-shard (elementwise, so it commutes with slicing — this is why the trick works for ReLU/SiLU but would break for anything mixing hidden dims) → W2 row-parallel consumes exactly that slice → one all-reduce at the end. Attention follows the same pattern with heads as the natural column boundary: QKV projections column-parallel (each rank owns whole heads), out-proj row-parallel. Two blocks per layer, one all-reduce each — lab-03 prices them.
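
A sketch of the MLP pairing described above, assuming ReLU as the nonlinearity and a hypothetical signature `mlp_tp(x, W1, W2, num_ranks)` that returns the output together with a collective counter. The lab's actual signature may differ.

```python
import numpy as np

def relu(a):
    return np.maximum(a, 0.0)

def mlp_tp(x, W1, W2, num_ranks):
    """Column-parallel W1, per-shard ReLU, row-parallel W2, one all-reduce."""
    hidden = W1.shape[0]               # W1: (hidden, in), W2: (out, hidden)
    assert hidden % num_ranks == 0, "hidden must divide evenly across ranks"
    chunk = hidden // num_ranks
    num_all_reduces = 0

    partials = []
    for r in range(num_ranks):         # one iteration = one simulated rank
        sl = slice(r * chunk, (r + 1) * chunk)
        h_r = relu(x @ W1[sl].T)       # this rank's slice of the hidden activation
        partials.append(h_r @ W2[:, sl].T)  # full-shaped partial output

    y = np.sum(partials, axis=0)       # the single collective: all-reduce
    num_all_reduces += 1
    return y, num_all_reduces

rng = np.random.default_rng(0)
x = rng.standard_normal((3, 8))
W1 = rng.standard_normal((32, 8))
W2 = rng.standard_normal((8, 32))
dense = relu(x @ W1.T) @ W2.T

y, count = mlp_tp(x, W1, W2, num_ranks=4)
assert np.allclose(y, dense)
assert count == 1
```

Note that `h_r` is never concatenated: each rank's hidden slice feeds directly into the matching column slice of `W2`.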

Files

  • starter.py: column_parallel, row_parallel, mlp_tp. Your work.
  • solution.py: reference.
  • test_lab.py: exact reconstruction for several rank counts, the one-all-reduce property, and the divisibility constraint.

Run

LAB_IMPL=starter pytest phase-10-distributed-inference/labs/lab-01-tp-sharding-math -q
pytest phase-10-distributed-inference/labs/lab-01-tp-sharding-math -q   # reference

What to implement

Per 02-mini-build.md. The loop over ranks is the simulation — each iteration is one GPU's life; the concatenate and the running sum are the collectives. Keep that mapping conscious: when you later read real TP code, every line will be one of your loop bodies with the loop distributed across processes.

What the tests prove

| Test | What it pins |
|---|---|
| column/row reconstruct x @ W.T exactly | Sharding is algebra, not approximation: the result matches to machine precision for num_ranks ∈ {1, 2, 4, 8}. Rank-count invariance is itself the deployment-critical property: TP=4 and TP=8 must serve identical models. |
| mlp_tp == dense MLP with num_all_reduces == 1 | The Megatron pairing: the hidden activation never reassembles. The counter in the return value is the design, made falsifiable. |
| divisibility asserted | hidden % num_ranks == 0. This is why TP sizes are typically powers of two and why some models can't run at TP=6: head counts and hidden dims must divide. A real constraint users hit (a GQA model with 8 KV heads caps practical TP at 8 without head replication). |
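
The divisibility constraint can be made concrete with a small helper. This is an illustrative sketch: `valid_tp_sizes` is a hypothetical function, the shape numbers are example values chosen to include 8 KV heads, and it assumes no head replication.

```python
def valid_tp_sizes(hidden, num_heads, num_kv_heads, max_tp=16):
    """List the TP sizes that evenly divide every sharded dimension.

    Assumes KV heads are not replicated, so num_kv_heads must divide too.
    """
    return [
        n for n in range(1, max_tp + 1)
        if hidden % n == 0 and num_heads % n == 0 and num_kv_heads % n == 0
    ]

# Example shapes (illustrative): hidden=8192, 64 query heads, 8 KV heads.
print(valid_tp_sizes(hidden=8192, num_heads=64, num_kv_heads=8))
# → [1, 2, 4, 8]
```

TP=6 is absent because none of the three dimensions is divisible by 6, and TP=16 is absent because 8 KV heads cannot be split 16 ways.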

Hitchhiker's notes

  • Floating point note: the row-parallel sum reorders additions relative to the dense matmul, so on real hardware TP=2 and TP=1 can differ in the last ulp. This is the recurring last-ulp story (Phases 3/4/6/9), now with rank count as the trigger. In float64 numpy the difference sits far below the test tolerance; in fp16 on a GPU it is visible. Reports of "different outputs at different TP sizes" are usually this reordering, not a logic bug.
  • Map to upstream: ColumnParallelLinear / RowParallelLinear in upstream/vllm/model_executor/layers/linear.py — find the single all_reduce in the row class's forward (linear.py:1392), and notice gather_output=False on the column class: the default is the paired pattern, all-gather elided. Model code composes these two classes and TP falls out — that's why adding a new model (Phase 14) barely thinks about TP.
  • Weights are sharded at load time, not runtime — each rank reads only its slice from the checkpoint (the weight loader's shard_id machinery). The lab's W[r*chunk:(r+1)*chunk] is, in production, a file-read pattern: TP=8 startup reads each tensor once across 8 processes. Loading is part of the sharding design, not an afterthought.
  • Embedding and LM head shard on the vocabulary dimension (vocab-parallel) — same two patterns, different axis, with a gather at the logits. Every weight matrix in the model has a natural slicing axis; TP is the discipline of choosing axes so the collectives stay rare.
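
The floating-point note in the list above can be reproduced with three fp16 numbers. This is a toy illustration of summation order, not a model of how any particular GPU kernel accumulates. The split of terms across "ranks" is an assumption made for the example.

```python
import numpy as np

# Three partial products that a row-parallel layer must add together.
a, b, c = np.float16(2048), np.float16(1), np.float16(1)

# TP=1: one rank adds the terms in sequence.
# 2049 is not representable in fp16 (spacing is 2 above 2048), so a + b rounds to 2048.
one_rank = (a + b) + c

# TP=2 (assumed split): rank 0 holds a, rank 1 holds b and c; the all-reduce adds the two partials.
two_ranks = a + (b + c)

print(one_rank, two_ranks)
# → 2048.0 2050.0
```

Same inputs, same math, different grouping: the two results differ by one fp16 spacing at this magnitude.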

Going further

  • Add attention_tp(x, Wqkv, Wo, num_heads, num_ranks): heads as the column boundary, out-proj row-parallel, assert one all-reduce and head-count divisibility. You've now sharded both halves of a real layer.
  • Implement gather_output=True (the elided all-gather) and count collectives for the unpaired composition — two matmuls sharded naively. The diff against mlp_tp's 1 is the Megatron paper's contribution, measured by your counter.
  • Simulate a wrong sharding: shard W1 by rows instead of columns, watch the nonlinearity break the reconstruction (ReLU of a partial sum ≠ partial of a ReLU). The elementwise-commutes-with-slicing condition, demonstrated by violating it.
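
As a starting point for the wrong-sharding exercise, here is the smallest case where the failure shows up: one hidden unit, two input features, two simulated ranks. The values are chosen by hand for illustration.

```python
import numpy as np

def relu(a):
    return np.maximum(a, 0.0)

x = np.array([[1.0, 1.0]])     # one token, two input features
W1 = np.array([[3.0, -2.0]])   # one hidden unit, shape (hidden=1, in=2)

dense = relu(x @ W1.T)         # relu(3 - 2) = 1

# Wrong sharding: split W1 along its input axis, so each rank holds a partial sum,
# then apply ReLU per rank before the partials are added.
partials = [x[:, r:r + 1] @ W1[:, r:r + 1].T for r in range(2)]  # 3 and -2
wrong = sum(relu(p) for p in partials)                           # relu(3) + relu(-2) = 3

print(dense.item(), wrong.item())
# → 1.0 3.0
```

The ReLU zeroes rank 1's negative partial before it can cancel against rank 0's positive one, so the sum no longer matches the dense result.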

References

  • Shoeybi et al., Megatron-LM: Training Multi-Billion Parameter Language Models Using Model Parallelism (2019) — the column→row pairing, Figure 3: https://arxiv.org/abs/1909.08053
  • upstream/vllm/model_executor/layers/linear.py — the two classes and the one all-reduce (:1392).
  • upstream/vllm/model_executor/layers/vocab_parallel_embedding.py — the same idea on the vocab axis.
  • Lab-03 — what the one all-reduce costs; lab-02 — the memory split, observed live.