Lab 07-01 — The MoE Forward (Routing + Grouped Experts) [CPU-OK]

A mixture-of-experts layer makes a strange promise: a model with 8× the parameters at ~1× the per-token compute, because each token visits only its top-k of E expert MLPs. The catch is operational — "each token visits different experts" is a scatter/gather nightmare for hardware that loves big uniform matmuls. This lab has you implement both sides of the resolution: the naive per-token loop (obviously correct, hopelessly slow — your oracle) and the grouped formulation (permute tokens by expert → one big matmul per expert → scatter back with combine weights), and prove they're equal. The grouped version is, step for step, what vLLM's fused_moe_kernel does in one GPU pass — you're writing the readable edition of one of the hottest kernels in modern serving.

Why this lab exists

MoE is where the frontier lives (Mixtral, DeepSeek-V3, Qwen-MoE, most rumored frontier models), and its serving stack confuses newcomers because the math (a weighted sum of small MLPs) and the implementation (sorts, histograms, alignment buffers, grouped GEMMs) look unrelated. They aren't — the implementation is the math, reorganized so the GPU sees few large uniform operations instead of many tiny ragged ones. Building both versions and asserting equality is how you internalize the correspondence; after this lab, moe_align_block_size (a real kernel whose name suggests nothing) reads as "my argsort, made tile-friendly."

The grouped-equals-reference test is this course's master invariant (optimizations must not change output) once more, this time in its most insidious habitat: the scatter-back. Combine-weight bugs and duplicate-token-row bugs (np.add.at vs out[toks] +=; see the notes) produce outputs that are plausibly wrong, the worst kind. The oracle test is the only honest defense.

Background: the permute trick

Per token: router logits x @ W_gateᵀ → take top-k experts → softmax the selected k logits into combine weights → output is the weighted sum of those experts' MLP outputs. Done literally, that's T × k tiny matmuls — death by launch overhead and zero data reuse (lab-03 quantifies why tiny matmuls waste a GPU).
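The routing step described above can be sketched in NumPy as follows. This is a minimal illustration, not the lab's solution.py: the shapes (W_gate stored as (E, D), so logits are x @ W_gate.T), the variable names, and the use of np.argpartition are all assumptions made for the example.

```python
import numpy as np

def route(x, w_gate, k):
    """Top-k routing sketch: returns (ids, weights), each of shape (T, k)."""
    logits = x @ w_gate.T                                  # (T, E) router logits
    # The first k columns of the partition hold the k largest logits per row
    # (in no particular order), which is all routing needs.
    ids = np.argpartition(-logits, k - 1, axis=1)[:, :k]   # (T, k) expert ids
    sel = np.take_along_axis(logits, ids, axis=1)          # (T, k) selected logits
    sel = sel - sel.max(axis=1, keepdims=True)             # numerical stability
    w = np.exp(sel)
    w /= w.sum(axis=1, keepdims=True)                      # softmax over the k selected only
    return ids, w

# Hypothetical toy shapes: 6 tokens, hidden size 16, 8 experts, top-2.
rng = np.random.default_rng(0)
T, D, E, K = 6, 16, 8, 2
x = rng.standard_normal((T, D))
w_gate = rng.standard_normal((E, D))
ids, weights = route(x, w_gate, K)
```

Each row of weights sums to 1 because the softmax runs over the selected logits only, which is the property the "combine weights sum to 1" test checks.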

The grouped reformulation observes that the same set of (token, expert) pairs can be processed expert-major instead of token-major:

  1. Flatten the (T, k) assignment matrix into T·k (token, expert, weight) triples.
  2. Permute: sort triples by expert (your argsort; real kernels build the equivalent grouping with a histogram + prefix sum — moe_align_block_size).
  3. Grouped GEMM: for each expert, one matmul over its contiguous block of tokens — E medium matmuls instead of T·k tiny ones, each big enough to tile well (lab-03).
  4. Un-permute + combine: scatter results back to token order, multiplying by the combine weights, summing the k contributions per token.

No arithmetic changed — only its order. The speedup comes entirely from shaping the work to what hardware rewards: contiguity and uniformity.
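The four steps above can be put side by side with the per-token loop. The sketch below is one possible arrangement under stated assumptions: W1 has shape (E, D, H) and W2 has shape (E, H, D), the function names are illustrative rather than the lab's exact signatures, and the routing inputs are generated at random instead of coming from a real router.

```python
import numpy as np

def expert_mlp(x, w1, w2):
    return np.maximum(x @ w1, 0.0) @ w2                 # relu(x @ W1) @ W2

def moe_reference(x, ids, weights, W1, W2):
    """Token-major oracle: T * k tiny matmuls."""
    out = np.zeros_like(x)
    for t in range(x.shape[0]):
        for j in range(ids.shape[1]):
            e = ids[t, j]
            out[t] += weights[t, j] * expert_mlp(x[t:t + 1], W1[e], W2[e])[0]
    return out

def moe_grouped(x, ids, weights, W1, W2):
    """Expert-major version: one matmul per expert."""
    T, k = ids.shape
    E = W1.shape[0]
    # 1. Flatten (T, k) into T*k (token, expert, weight) triples.
    tok = np.repeat(np.arange(T), k)
    exp = ids.reshape(-1)
    wgt = weights.reshape(-1)
    # 2. Permute: sort triples by expert so each expert's rows are contiguous.
    order = np.argsort(exp, kind="stable")
    tok, exp, wgt = tok[order], exp[order], wgt[order]
    counts = np.bincount(exp, minlength=E)
    starts = np.concatenate(([0], np.cumsum(counts)))
    # 3. Grouped GEMM: one matmul over each expert's contiguous block.
    y = np.empty((T * k, x.shape[1]), dtype=x.dtype)
    for e in range(E):
        lo, hi = starts[e], starts[e + 1]
        if hi > lo:
            y[lo:hi] = expert_mlp(x[tok[lo:hi]], W1[e], W2[e])
    # 4. Un-permute + combine: tok repeats each token k times, so accumulate
    #    with np.add.at rather than fancy-indexed +=.
    out = np.zeros_like(x)
    np.add.at(out, tok, wgt[:, None] * y)
    return out

rng = np.random.default_rng(0)
T, D, H, E, K = 32, 16, 24, 8, 2
x = rng.standard_normal((T, D))
W1 = rng.standard_normal((E, D, H)) * 0.1
W2 = rng.standard_normal((E, H, D)) * 0.1
# Stand-in routing: K distinct experts per token, random convex weights.
ids = np.stack([rng.choice(E, size=K, replace=False) for _ in range(T)])
raw = rng.random((T, K))
weights = raw / raw.sum(axis=1, keepdims=True)

ref = moe_reference(x, ids, weights, W1, W2)
grp = moe_grouped(x, ids, weights, W1, W2)
max_abs_diff = float(np.abs(ref - grp).max())
```

The two outputs should agree up to floating-point rounding, since only the summation order differs; compare with np.allclose rather than exact equality.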

Files

  • starter.py — route, expert_mlp, moe_forward_reference, moe_forward_grouped. Your work.
  • solution.py — reference.
  • test_lab.py — grouped == reference, combine weights sum to 1, assignment bookkeeping.

Run

LAB_IMPL=starter pytest phase-07-gemm-and-moe-kernels/labs/lab-01-moe-routing -q
pytest phase-07-gemm-and-moe-kernels/labs/lab-01-moe-routing -q   # reference

What to implement

Per 02-mini-build.md: route (logits → top-k ids + softmax of the selected logits), expert_mlp (relu(x @ W1) @ W2), the reference loop, and the grouped version. Two precision points: softmax over the selected k logits only (not all E — selecting then normalizing is the standard formulation; normalizing then selecting gives different weights), and the scatter-back must handle a token appearing twice in an expert's block when top-k assigns it duplicate experts — np.add.at accumulates correctly where fancy-indexed += silently drops duplicates. That numpy footgun is the lab's hidden boss; the bookkeeping test exists for it.
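Both precision points can be checked in isolation. The toy arrays below are made up for illustration; the duplicate indices stand in for any scatter where the same token row is targeted more than once (for example, scattering all T·k rows in a single call, where each token id appears k times).

```python
import numpy as np

# Point 1: select-then-normalize vs normalize-then-select.
def softmax(z):
    e = np.exp(z - z.max())
    return e / e.sum()

logits = np.array([2.0, 1.0, 0.0, -1.0])
top = np.argsort(-logits)[:2]                  # indices of the two largest logits
select_then_norm = softmax(logits[top])        # sums to 1 over the chosen experts
norm_then_select = softmax(logits)[top]        # sums to < 1 unless renormalized

# Point 2: scatter-add with repeated row indices.
toks = np.array([0, 2, 0, 1, 2, 0])            # token 0 three times, token 2 twice
contrib = np.ones((6, 2))

buffered = np.zeros((3, 2))
buffered[toks] += contrib                      # each row is written once: duplicates lost
# buffered == [[1, 1], [1, 1], [1, 1]]

accumulated = np.zeros((3, 2))
np.add.at(accumulated, toks, contrib)          # unbuffered: every contribution lands
# accumulated == [[3, 3], [1, 1], [2, 2]]
```

The fancy-indexed += raises no error and returns an array of the right shape, which is why the failure is easy to miss without an oracle.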

What the tests prove

| Test | What it pins |
| --- | --- |
| grouped ≈ reference | The permute/group/scatter pipeline is an identity on the math — the kernel's entire correctness claim |
| combine weights sum to 1 | The router emits a proper convex combination — drop this and outputs scale with k |
| assignments = T × k, each expert sees exactly its tokens | The bookkeeping conservation law: nothing dropped, nothing duplicated in the permute — the histogram you'd actually print when debugging a real routing issue (lab-04 builds the diagnostics on top) |

Hitchhiker's notes

  • Why softmax-after-top-k? It renormalizes mass over the experts actually used, so the output is a proper weighted average regardless of how confident the router was. Mixtral and most modern MoEs do exactly this; some (DeepSeek-V3) use sigmoid gates with normalization — same pipeline, different gate function. The structure (select → normalize → combine) is the stable part.
  • The real kernel fuses steps 2–4 into one launch: fused_moe_kernel (upstream/vllm/model_executor/layers/fused_moe/fused_moe.py:295) — a Triton kernel whose grid covers (expert blocks × tile positions), reading the alignment metadata that moe_align_block_size produced. Your four functions are its four phases; the fusion exists so intermediate permuted tensors never round-trip through HBM (the recurring lesson from Phase 2 lab-06 and lab-03 here: materializing intermediates forfeits the bandwidth win).
  • SwiGLU, not ReLU, in real models: (silu(x@W1) * (x@W3)) @ W2 — three weight matrices per expert, one extra elementwise multiply. Changes the per-expert FLOPs, changes nothing about routing/grouping. The lab uses ReLU to keep the oracle short.
  • Where the time really goes: lab-02's profile shows experts (grouped GEMM) at ~41%, permute at ~7%, router at ~1%. Routing is decision-cheap, consequence-expensive — the gate is a tiny matmul whose output determines whether the expensive part runs balanced (lab-04's entire subject).
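For the SwiGLU note above, a drop-in expert could look like the sketch below. The weight shapes (W1 and W3 as (D, H), W2 as (H, D)) and the function names are assumptions for the example; the formula itself is the one quoted in the note.

```python
import numpy as np

def silu(z):
    return z / (1.0 + np.exp(-z))              # z * sigmoid(z)

def expert_swiglu(x, w1, w3, w2):
    # (silu(x @ W1) * (x @ W3)) @ W2: a gated up-projection, then the down-projection.
    return (silu(x @ w1) * (x @ w3)) @ w2

rng = np.random.default_rng(0)
D, H = 16, 24
x = rng.standard_normal((5, D))
w1 = rng.standard_normal((D, H)) * 0.1
w3 = rng.standard_normal((D, H)) * 0.1
w2 = rng.standard_normal((H, D)) * 0.1
y = expert_swiglu(x, w1, w3, w2)               # (5, D), same output shape as the ReLU expert
```

Because the input and output shapes match those of the ReLU expert, the routing, permute, and scatter code does not need to know which expert body is in use.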

Going further

  • Replace argsort with the histogram + prefix-sum (counting sort) the real kernel uses: np.bincount → np.cumsum → stable placement. Same permutation, O(T·k + E) instead of O(T·k log(T·k)) — and now you've written moe_align_block_size's algorithm.
  • Implement SwiGLU experts and re-run the equality test (it should pass untouched — routing is orthogonal to expert internals; prove it).
  • Pad each expert's token block to a multiple of 16 (the GEMM tile constraint — lab-03) with zero rows, and verify the output is still exact. You've discovered why moe_align_block_size has "block size" in its name, and where MoE's small padding-waste overhead comes from.
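As a starting point for the first and third exercises, here is one way the counting-sort permutation and the padded block sizes could be computed. It is a sketch only: the function name, the explicit Python loop, and the block size of 16 are illustrative choices, and it is not a transcription of moe_align_block_size.

```python
import numpy as np

def counting_sort_permutation(expert_ids, num_experts):
    """Return (order, counts, starts) such that expert_ids[order] is sorted."""
    counts = np.bincount(expert_ids, minlength=num_experts)   # histogram per expert
    starts = np.cumsum(counts) - counts                       # exclusive prefix sum
    cursor = starts.copy()                                    # next free slot per expert
    order = np.empty(expert_ids.size, dtype=np.int64)
    for i, e in enumerate(expert_ids):                        # in-order placement keeps it stable
        order[cursor[e]] = i
        cursor[e] += 1
    return order, counts, starts

rng = np.random.default_rng(0)
E, BLOCK = 8, 16
expert_ids = rng.integers(0, E, size=64)                      # flattened (token, expert) pairs
order, counts, starts = counting_sort_permutation(expert_ids, E)

# Round each expert's row count up to a multiple of BLOCK (zero rows fill the gap).
padded_counts = ((counts + BLOCK - 1) // BLOCK) * BLOCK
padding_rows = int((padded_counts - counts).sum())
```

The resulting order matches np.argsort(expert_ids, kind="stable"), so it can replace the argsort in a grouped forward without changing any output; padding_rows is the extra work the tile constraint adds.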

References

  • upstream/vllm/model_executor/layers/fused_moe/fused_moe.py:295 — the fused kernel; read it next to your grouped function, phase by phase.
  • Shazeer et al., Outrageously Large Neural Networks: The Sparsely-Gated Mixture-of-Experts Layer (2017) — the modern MoE formulation: https://arxiv.org/abs/1701.06538
  • Jiang et al., Mixtral of Experts (2024) — the architecture this lab's shapes mimic (8 experts, top-2): https://arxiv.org/abs/2401.04088
  • Lab-03 — why grouped beats tiny matmuls (tiling/reuse); lab-04 — what the routing histogram costs; lab-02 — the profile showing where the milliseconds go.