Phase 07 — Mini-Build: the MoE forward in numpy
You'll implement the full MoE forward pass (router → top-k → permute → grouped experts → un-permute → weighted combine) and check that it matches a simple reference. This makes the fused kernel's job concrete: it does this same pipeline, fused on the GPU.
Contents
- The task (lab-01)
- Why permute/un-permute (the key insight)
- Definition of done
- Map to the real engine
The task (lab-01)
Implement, in numpy:
- route(x, W_gate, k) → (topk_ids (T,k), topk_weights (T,k)): logits = x @ W_gate.T; pick the top-k experts per token; softmax the k selected logits to get the combine weights.
- moe_forward_reference(x, experts, topk_ids, topk_weights) → the naive version: for each token, for each of its k experts, run that expert's MLP and add the weighted result. (Correct but slow: this is the oracle.)
- moe_forward_grouped(x, experts, topk_ids, topk_weights) → the "fused" idea: permute tokens by expert (argsort), run each expert once on its contiguous block (grouped GEMM), un-permute, then combine. Must match the reference.
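A minimal sketch of route, assuming W_gate has shape (E, d) so that x @ W_gate.T yields (T, E) logits. It uses a full argsort for clarity; np.argpartition would avoid sorting all E scores per token.

```python
import numpy as np

def route(x, W_gate, k):
    """Top-k routing sketch: x (T, d), W_gate (E, d) -> topk_ids (T, k), topk_weights (T, k)."""
    logits = x @ W_gate.T                                # (T, E) router scores
    # sort each row descending and keep the first k columns: the k best experts per token
    topk_ids = np.argsort(-logits, axis=1)[:, :k]
    topk_logits = np.take_along_axis(logits, topk_ids, axis=1)  # (T, k)
    # numerically stable softmax over only the k selected logits
    z = topk_logits - topk_logits.max(axis=1, keepdims=True)
    e = np.exp(z)
    topk_weights = e / e.sum(axis=1, keepdims=True)
    return topk_ids, topk_weights
```

Because the softmax runs over the k selected logits only, each token's combine weights sum to 1, which is one of the properties the tests pin.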
An "expert" here is a tiny MLP: relu(x @ W1) @ W2.
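One possible shape for the oracle, assuming experts is a list of (W1, W2) pairs with W1 of shape (d, h) and W2 of shape (h, d); the lab's own expert container may differ, and expert_mlp is a hypothetical helper name.

```python
import numpy as np

def expert_mlp(x, W1, W2):
    """One expert: relu(x @ W1) @ W2, with x (n, d), W1 (d, h), W2 (h, d_out)."""
    return np.maximum(x @ W1, 0.0) @ W2

def moe_forward_reference(x, experts, topk_ids, topk_weights):
    """Naive oracle: for each token and each of its k experts, run the MLP and weight-sum."""
    T = x.shape[0]
    d_out = experts[0][1].shape[1]
    out = np.zeros((T, d_out), dtype=x.dtype)
    for t in range(T):
        for j in range(topk_ids.shape[1]):
            W1, W2 = experts[topk_ids[t, j]]
            # x[t:t+1] keeps a (1, d) row so expert_mlp sees a 2-D input
            out[t] += topk_weights[t, j] * expert_mlp(x[t:t + 1], W1, W2)[0]
    return out
```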
Why permute/un-permute (the key insight)
Scattered per-token expert calls are tiny matmuls, so kernel-launch overhead dominates their runtime
(they are launch-bound). Sorting tokens by expert turns the work into a handful of large matmuls,
one per expert, which keep the GPU's compute units busy. Your argsort-based permute is the CPU
mirror of moe_align_block_size / moe_permute_unpermute.
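A sketch of moe_forward_grouped under the same assumption (experts as (W1, W2) pairs). Each token appears k times in the flattened topk_ids, once per slot, so the permutation acts on T*k rows; a stable argsort keeps each expert's rows in token order. The per-expert loop stands in for the grouped GEMM. This illustrates the argsort idea only, not vLLM's implementation.

```python
import numpy as np

def moe_forward_grouped(x, experts, topk_ids, topk_weights):
    T, k = topk_ids.shape
    flat_ids = topk_ids.reshape(-1)               # (T*k,) expert id per (token, slot)
    order = np.argsort(flat_ids, kind="stable")   # permutation grouping slots by expert
    x_perm = x[order // k]                        # (T*k, d): each expert's rows are contiguous
    counts = np.bincount(flat_ids, minlength=len(experts))
    offsets = np.concatenate(([0], np.cumsum(counts)))
    y_perm = np.empty((T * k, experts[0][1].shape[1]), dtype=x.dtype)
    for e, (W1, W2) in enumerate(experts):
        lo, hi = offsets[e], offsets[e + 1]
        if hi > lo:
            # one matmul pair per expert over its whole block (the "grouped GEMM")
            y_perm[lo:hi] = np.maximum(x_perm[lo:hi] @ W1, 0.0) @ W2
    # un-permute: permuted row p returns to flat (token, slot) position order[p]
    y = np.empty_like(y_perm)
    y[order] = y_perm
    y = y.reshape(T, k, -1)
    # weighted combine over each token's k slots
    return (topk_weights[..., None] * y).sum(axis=1)
```

Summation order differs from the per-token loop, so compare against the reference with a floating-point tolerance (for example np.allclose) rather than exact equality.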
Definition of done
pytest phase-07-gemm-and-moe-kernels/labs -q
Tests pin: grouped == reference output; the permutation round-trips (un-permute ∘ permute = identity); each expert is invoked on exactly its assigned tokens; top-k weights sum to 1 per token.
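These are not the lab's actual tests; this is a minimal sketch of how the pinned properties can be expressed, assuming the argsort layout above. permute, unpermute and check_properties are hypothetical names. Since each token is copied once per slot, "identity" here means un-permute recovers x repeated k times along the token axis.

```python
import numpy as np

def permute(x, topk_ids):
    """Group the T*k (token, slot) copies of x by expert; return rows and the permutation."""
    k = topk_ids.shape[1]
    order = np.argsort(topk_ids.reshape(-1), kind="stable")
    return x[order // k], order

def unpermute(rows_perm, order):
    """Inverse of permute: permuted row p goes back to flat position order[p]."""
    rows = np.empty_like(rows_perm)
    rows[order] = rows_perm
    return rows

def check_properties(x, topk_ids, topk_weights):
    k = topk_ids.shape[1]
    x_perm, order = permute(x, topk_ids)
    # round trip: un-permute(permute(x)) gives each token repeated once per slot
    assert np.array_equal(unpermute(x_perm, order), np.repeat(x, k, axis=0))
    # after permuting, expert ids are non-decreasing, so each expert's rows are contiguous
    flat = topk_ids.reshape(-1)
    assert np.all(np.diff(flat[order]) >= 0)
    # each expert's block holds exactly the tokens routed to it
    for e in np.unique(flat):
        block_tokens = np.sort(order[flat[order] == e] // k)
        routed = np.nonzero((topk_ids == e).any(axis=1))[0]
        assert np.array_equal(block_tokens, routed)
    # combine weights sum to 1 per token
    assert np.allclose(topk_weights.sum(axis=1), 1.0)
```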
Map to the real engine
| your numpy | real vLLM |
|---|---|
| route top-k | routing in FusedMoE/fused_moe.py |
| permute by argsort | moe_align_block_size / moe_permute_unpermute.py |
| grouped expert matmuls | fused_moe_kernel (fused_moe.py:295) |
| weighted combine | the combine in fused_experts_impl (:1664) |
| (experts on different GPUs) | expert parallelism (all2all_utils.py) |
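On the GPU, moe_align_block_size additionally pads each expert's group of rows up to a multiple of the kernel's block size, so every GEMM tile belongs to a single expert. The sketch below mimics that idea on the CPU; align_block_size_sketch, its sentinel value and its return layout are hypothetical, and the real op's outputs may differ.

```python
import numpy as np

def align_block_size_sketch(topk_ids, num_experts, block_size):
    """Hypothetical CPU sketch of block-aligned permutation.

    Returns:
      sorted_ids: flat (token, slot) indices grouped by expert, each group padded
                  to a multiple of block_size with the sentinel T*k ("no row here")
      block_expert: the expert id that owns each block of block_size rows
    """
    flat = topk_ids.reshape(-1)
    pad_value = flat.size
    sorted_ids, block_expert = [], []
    for e in range(num_experts):
        idx = np.nonzero(flat == e)[0]          # flat positions routed to expert e, ascending
        n_blocks = -(-idx.size // block_size)   # ceiling division
        padded = np.full(n_blocks * block_size, pad_value, dtype=np.int64)
        padded[:idx.size] = idx
        sorted_ids.append(padded)
        block_expert.extend([e] * n_blocks)
    return np.concatenate(sorted_ids), np.array(block_expert, dtype=np.int64)
```

A kernel can then launch one program per block, read block_expert to pick the weights, and skip rows equal to the sentinel.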