Lab 07-01 — The MoE Forward (Routing + Grouped Experts) [CPU-OK]
A mixture-of-experts layer makes a strange promise: a model with 8× the parameters at
~1× the per-token compute, because each token visits only its top-k of E expert MLPs.
The catch is operational — "each token visits different experts" is a scatter/gather
nightmare for hardware that loves big uniform matmuls. This lab has you implement both
sides of the resolution: the naive per-token loop (obviously correct, hopelessly
slow — your oracle) and the grouped formulation (permute tokens by expert → one big
matmul per expert → scatter back with combine weights), and prove they're equal. The
grouped version is, step for step, what vLLM's fused_moe_kernel does in one GPU pass —
you're writing the readable edition of one of the hottest kernels in modern serving.
Contents
- Why this lab exists
- Background: the permute trick
- Files
- Run
- What to implement
- What the tests prove
- Hitchhiker's notes
- Going further
- References
Why this lab exists
MoE is where the frontier lives (Mixtral, DeepSeek-V3, Qwen-MoE, most rumored frontier
models), and its serving stack confuses newcomers because the math (a weighted sum of
small MLPs) and the implementation (sorts, histograms, alignment buffers, grouped
GEMMs) look unrelated. They aren't — the implementation is the math, reorganized so the
GPU sees few large uniform operations instead of many tiny ragged ones. Building both
versions and asserting equality is how you internalize the correspondence; after this
lab, moe_align_block_size (a real kernel whose name suggests nothing) reads as "my
argsort, made tile-friendly."
The grouped-equals-reference test is also this course's master invariant again
(optimizations must not change output) in its most insidious habitat: the scatter-back.
Combine-weight bugs and duplicate-token-row bugs (np.add.at vs out[toks] +=; see
What to implement) produce outputs that are plausibly wrong, the worst kind. The oracle test
is the only honest defense.
Background: the permute trick
Per token: router logits x @ W_gateᵀ → take top-k experts → softmax the selected k
logits into combine weights → output is the weighted sum of those experts' MLP outputs.
Done literally, that's T × k tiny matmuls — death by launch overhead and zero data
reuse (lab-03 quantifies why tiny matmuls waste a GPU).
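The per-token recipe above can be sketched in NumPy as follows. This is a minimal sketch, not the lab's solution: the function names come from the Files list, but the signatures, the weight layouts (w_gate as (E, D), w1 as (E, D, H), w2 as (E, H, D)) and the toy sizes are assumptions made for illustration.

```python
import numpy as np

def route(x, w_gate, k):
    """Top-k routing: returns expert ids (T, k) and combine weights (T, k)."""
    logits = x @ w_gate.T                            # (T, E) router logits
    ids = np.argsort(-logits, axis=1)[:, :k]         # k largest logits per token
    sel = np.take_along_axis(logits, ids, axis=1)    # keep only the selected logits
    sel = sel - sel.max(axis=1, keepdims=True)       # stabilize the exponentials
    weights = np.exp(sel)
    weights /= weights.sum(axis=1, keepdims=True)    # softmax over the k selected
    return ids, weights

def expert_mlp(x, w1, w2):
    return np.maximum(x @ w1, 0.0) @ w2              # relu(x @ W1) @ W2

def moe_forward_reference(x, w_gate, w1, w2, k):
    """Token-major oracle: T * k tiny matmuls, one (token, expert) pair at a time."""
    ids, weights = route(x, w_gate, k)
    out = np.zeros((x.shape[0], w2.shape[2]))
    for t in range(x.shape[0]):
        for j in range(k):
            e = ids[t, j]
            out[t] += weights[t, j] * expert_mlp(x[t:t + 1], w1[e], w2[e])[0]
    return out

# Toy sizes (assumed): 6 tokens, model dim 8, hidden dim 16, 4 experts, top-2.
rng = np.random.default_rng(0)
T, D, H, E, K = 6, 8, 16, 4, 2
x = rng.standard_normal((T, D))
w_gate = rng.standard_normal((E, D))
w1 = rng.standard_normal((E, D, H))
w2 = rng.standard_normal((E, H, D))
ids, weights = route(x, w_gate, K)
y = moe_forward_reference(x, w_gate, w1, w2, K)
```

The double loop is deliberately literal: it is slow, but each line maps to one clause of the description, which is what makes it usable as an oracle.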
The grouped reformulation observes that the same set of (token, expert) pairs can be processed expert-major instead of token-major:
1. Flatten the (T, k) assignment matrix into T·k (token, expert, weight) triples.
2. Permute: sort the triples by expert (your argsort; real kernels build the equivalent grouping with a histogram + prefix sum, moe_align_block_size).
3. Grouped GEMM: for each expert, one matmul over its contiguous block of tokens. That is E medium matmuls instead of T·k tiny ones, each big enough to tile well (lab-03).
4. Un-permute + combine: scatter results back to token order, multiplying by the combine weights and summing the k contributions per token.
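The four steps above can be sketched in NumPy as follows. Treat this as one possible shape for the grouped forward, not the lab's solution: the signature (routing results passed in as ids and weights), the weight layouts and the toy data are assumptions, and the inline token-major loop at the end only serves as a local oracle.

```python
import numpy as np

def moe_forward_grouped(x, ids, weights, w1, w2):
    T, k = ids.shape
    num_experts = w1.shape[0]
    # 1. Flatten (T, k) into T*k (token, expert, weight) triples.
    tok = np.repeat(np.arange(T), k)
    exp = ids.reshape(-1)
    wgt = weights.reshape(-1)
    # 2. Permute: stable sort by expert so each expert's rows are contiguous.
    order = np.argsort(exp, kind="stable")
    tok, exp, wgt = tok[order], exp[order], wgt[order]
    counts = np.bincount(exp, minlength=num_experts)
    offsets = np.concatenate(([0], np.cumsum(counts)))
    # 3. Grouped GEMM: one ReLU MLP per expert over its contiguous block.
    x_perm = x[tok]
    y_perm = np.empty((T * k, w2.shape[2]))
    for e in range(num_experts):
        lo, hi = offsets[e], offsets[e + 1]
        y_perm[lo:hi] = np.maximum(x_perm[lo:hi] @ w1[e], 0.0) @ w2[e]
    # 4. Un-permute + combine: weighted scatter-add back to token order.
    out = np.zeros((T, w2.shape[2]))
    np.add.at(out, tok, wgt[:, None] * y_perm)
    return out

# Toy data (assumed): random distinct experts per token, random convex weights.
rng = np.random.default_rng(1)
T, D, H, E, K = 7, 8, 16, 4, 2
x = rng.standard_normal((T, D))
w1 = rng.standard_normal((E, D, H))
w2 = rng.standard_normal((E, H, D))
ids = np.stack([rng.permutation(E)[:K] for _ in range(T)])
weights = rng.random((T, K))
weights /= weights.sum(axis=1, keepdims=True)

out = moe_forward_grouped(x, ids, weights, w1, w2)

# Token-major oracle for comparison: same arithmetic, different order.
ref = np.zeros((T, D))
for t in range(T):
    for j in range(K):
        e = ids[t, j]
        ref[t] += weights[t, j] * (np.maximum(x[t] @ w1[e], 0.0) @ w2[e])
```

Comparing out and ref with np.allclose is the grouped-equals-reference check in miniature; an expert that receives no tokens simply gets an empty block.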
No arithmetic changed — only its order. The speedup comes entirely from shaping the work to what hardware rewards: contiguity and uniformity.
Files
- starter.py: route, expert_mlp, moe_forward_reference, moe_forward_grouped. Your work.
- solution.py: reference.
- test_lab.py: grouped == reference, combine weights sum to 1, assignment bookkeeping.
Run
LAB_IMPL=starter pytest phase-07-gemm-and-moe-kernels/labs/lab-01-moe-routing -q
pytest phase-07-gemm-and-moe-kernels/labs/lab-01-moe-routing -q # reference
What to implement
Per 02-mini-build.md: route (logits → top-k ids + softmax of
the selected logits), expert_mlp (relu(x @ W1) @ W2), the reference loop, and the
grouped version. Two points demand precision. First, take the softmax over the
selected k logits only, not all E: selecting then normalizing is the standard
formulation, and normalizing over all E then selecting gives different weights that
no longer sum to 1. Second, the scatter-back must handle repeated token indices:
top-k picks k distinct experts, so each token occurs k times among the T·k flattened
assignments, and any scatter that spans more than one expert's block sees the same
token row more than once. np.add.at accumulates every contribution, whereas
fancy-indexed += silently keeps only one write per repeated index. That numpy
footgun is the lab's hidden boss; the bookkeeping test exists for it.
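The scatter-back footgun is easy to reproduce in isolation. The snippet below is a small illustration with made-up numbers (the arrays toks and contrib are hypothetical, not lab data): two tokens each receive two contributions, and only np.add.at sums them.

```python
import numpy as np

toks = np.array([0, 1, 0, 2, 1])                 # tokens 0 and 1 each appear twice
contrib = np.array([1.0, 10.0, 2.0, 100.0, 20.0])

buffered = np.zeros(3)
buffered[toks] += contrib                        # one write per repeated index survives

accumulated = np.zeros(3)
np.add.at(accumulated, toks, contrib)            # unbuffered: every contribution is added
# accumulated is [3., 30., 100.]; buffered is missing part of tokens 0 and 1
```

The buffered result has the right shape and plausible magnitudes, which is why this bug survives a casual look at the output.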
What the tests prove
| Test | What it pins |
|---|---|
| grouped ≈ reference | The permute/group/scatter pipeline is an identity on the math — the kernel's entire correctness claim |
| combine weights sum to 1 | The router emits a proper convex combination — drop this and outputs scale with k |
| assignments = T × k, each expert sees exactly its tokens | The bookkeeping conservation law: nothing dropped, nothing duplicated in the permute. This is the histogram you'd actually print when debugging a real routing issue (lab-04 builds the diagnostics on top) |
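
The second and third rows can be checked without running any expert at all. The following is a hedged sketch of such checks on synthetic routing output; it is not the content of test_lab.py, and the variable names and sizes are assumptions.

```python
import numpy as np

rng = np.random.default_rng(0)
T, E, K = 10, 4, 2
# Synthetic routing output (assumed): K distinct experts per token, softmax weights.
ids = np.stack([rng.permutation(E)[:K] for _ in range(T)])
logits = rng.standard_normal((T, K))
weights = np.exp(logits) / np.exp(logits).sum(axis=1, keepdims=True)

# Combine weights form a convex combination per token.
assert np.allclose(weights.sum(axis=1), 1.0) and (weights >= 0).all()

# Conservation: the expert histogram accounts for all T * K assignments.
tok = np.repeat(np.arange(T), K)
exp = ids.reshape(-1)
order = np.argsort(exp, kind="stable")
counts = np.bincount(exp, minlength=E)
assert counts.sum() == T * K

# Each expert's contiguous block holds exactly the tokens routed to it.
offsets = np.concatenate(([0], np.cumsum(counts)))
for e in range(E):
    block = tok[order][offsets[e]:offsets[e + 1]]
    expected = np.nonzero((ids == e).any(axis=1))[0]
    assert np.array_equal(np.sort(block), expected)
```

Printing counts is the routing histogram the table refers to.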
Hitchhiker's notes
- Why softmax-after-top-k? It renormalizes mass over the experts actually used, so the output is a proper weighted average regardless of how confident the router was. Mixtral and most modern MoEs do exactly this; some (DeepSeek-V3) use sigmoid gates with normalization — same pipeline, different gate function. The structure (select → normalize → combine) is the stable part.
- The real kernel fuses steps 2–4 into one launch: fused_moe_kernel (upstream/vllm/model_executor/layers/fused_moe/fused_moe.py:295) is a Triton kernel whose grid covers (expert blocks × tile positions), reading the alignment metadata that moe_align_block_size produced. Your four functions are its four phases; the fusion exists so intermediate permuted tensors never round-trip through HBM (the recurring lesson from Phase 2 lab-06 and lab-03 here: materializing intermediates forfeits the bandwidth win).
- SwiGLU, not ReLU, in real models: (silu(x@W1) * (x@W3)) @ W2, with three weight matrices per expert and one extra elementwise multiply. It changes the per-expert FLOPs and changes nothing about routing/grouping. The lab uses ReLU to keep the oracle short.
- Where the time really goes: lab-02's profile shows experts (grouped GEMM) at ~41%, permute at ~7%, router at ~1%. Routing is decision-cheap, consequence-expensive: the gate is a tiny matmul whose output determines whether the expensive part runs balanced (lab-04's entire subject).
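For reference, the SwiGLU expert from the notes above could look like this in NumPy. The function name expert_mlp_swiglu and the shapes are hypothetical; only the formula (silu(x@W1) * (x@W3)) @ W2 comes from the text.

```python
import numpy as np

def silu(z):
    return z / (1.0 + np.exp(-z))                # z * sigmoid(z)

def expert_mlp_swiglu(x, w1, w3, w2):
    """Gated expert: (silu(x @ W1) * (x @ W3)) @ W2, three weight matrices."""
    return (silu(x @ w1) * (x @ w3)) @ w2

# Toy shapes (assumed): 5 tokens, model dim 8, hidden dim 16.
rng = np.random.default_rng(0)
T, D, H = 5, 8, 16
x = rng.standard_normal((T, D))
w1 = rng.standard_normal((D, H))
w3 = rng.standard_normal((D, H))
w2 = rng.standard_normal((H, D))
y = expert_mlp_swiglu(x, w1, w3, w2)
```

The input and output shapes match the ReLU expert, so nothing in the permute or scatter-back needs to know which one is in use.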
Going further
- Replace argsort with the histogram + prefix-sum (counting sort) the real kernel uses: np.bincount → np.cumsum → stable placement. Same permutation as a stable argsort, in O(T·k + E) instead of O(T·k log(T·k)), and now you've written moe_align_block_size's algorithm.
- Implement SwiGLU experts and re-run the equality test (it should pass untouched, because routing is orthogonal to expert internals; prove it).
- Pad each expert's token block to a multiple of 16 (the GEMM tile constraint, lab-03) with zero rows, and verify the output is still exact. You've discovered why moe_align_block_size has "block size" in its name, and where MoE's small padding-waste overhead comes from.
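As a starting point for the first and third exercises, here is a minimal counting-sort sketch. It illustrates the histogram + prefix-sum idea only; it makes no claim about how moe_align_block_size is actually written, and the function names and example ids are hypothetical.

```python
import numpy as np

def counting_sort_permutation(expert_ids, num_experts):
    """Return the permutation that groups rows by expert, preserving input order."""
    counts = np.bincount(expert_ids, minlength=num_experts)      # histogram
    offsets = np.concatenate(([0], np.cumsum(counts)[:-1]))     # exclusive prefix sum
    cursor = offsets.copy()                                      # next free slot per expert
    order = np.empty(expert_ids.size, dtype=np.int64)
    for i, e in enumerate(expert_ids):                           # stable placement
        order[cursor[e]] = i
        cursor[e] += 1
    return order, counts, offsets

def padded_counts(counts, block):
    """Round each expert's row count up to a multiple of the tile size."""
    return (counts + block - 1) // block * block

expert_ids = np.array([2, 0, 1, 2, 0, 3, 1, 2])
order, counts, offsets = counting_sort_permutation(expert_ids, num_experts=4)
padded = padded_counts(counts, block=4)
# order  -> [1, 4, 2, 6, 0, 3, 7, 5]
# counts -> [2, 2, 3, 1]; padded -> [4, 4, 4, 4]
```

The Python loop is there for readability; the point is that placement needs only one pass over the T·k assignments once the offsets are known.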
References
- upstream/vllm/model_executor/layers/fused_moe/fused_moe.py:295: the fused kernel; read it next to your grouped function, phase by phase.
- Shazeer et al., Outrageously Large Neural Networks: The Sparsely-Gated Mixture-of-Experts Layer (2017), the modern MoE formulation: https://arxiv.org/abs/1701.06538
- Jiang et al., Mixtral of Experts (2024) — the architecture this lab's shapes mimic (8 experts, top-2): https://arxiv.org/abs/2401.04088
- Lab-03 — why grouped beats tiny matmuls (tiling/reuse); lab-04 — what the routing histogram costs; lab-02 — the profile showing where the milliseconds go.