Phase 07 Labs — GEMM & MoE Kernels
Four labs below the attention line: the matmuls that are most of every step's milliseconds, and the mixture-of-experts machinery that reorganizes them. The arc: build the MoE forward and prove the grouped formulation exact (lab-01), learn the tiling arithmetic that makes any GEMM fast — and why decode shapes defeat it (lab-03), measure the balance tax that routing levies on parallel experts (lab-04), then profile a real MoE model and check all three models against silicon (lab-02).
Recommended order: 01 → 03 → 04 → 02. (Directory numbers predate labs 03–04.) CPU
labs follow the standard contract: starter.py (your work), solution.py (reference),
test_lab.py (the spec). By default the tests run the solution; set LAB_IMPL=starter to grade yours.
# Whole phase (GPU tests auto-skip without CUDA):
pytest phase-07-gemm-and-moe-kernels/labs -m "not gpu"
# Grade yourself on one lab:
LAB_IMPL=starter pytest phase-07-gemm-and-moe-kernels/labs/lab-01-moe-routing -q
Contents
- lab-01-moe-routing [CPU-OK]
- lab-02-profile-fused-moe [GPU-OPT]
- lab-03-tiled-gemm [CPU-OK]
- lab-04-expert-load-balance [CPU-OK]
- What you can do after this phase
Labs
lab-01-moe-routing [CPU-OK]
Implement the MoE forward twice — the naive per-token oracle and the grouped
formulation (permute by expert → one matmul per expert → scatter back with combine
weights) — and prove them equal. The grouped version is the readable edition of
fused_moe_kernel; the scatter-back hides the lab's boss fight (np.add.at vs the
duplicate-dropping +=). Skills: select-normalize-combine routing; the permute trick;
conservation bookkeeping as debugging.
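A minimal sketch of the two formulations, assuming NumPy, softmax-then-top-k routing, and one square weight matrix per expert. The function names (route, moe_naive, moe_grouped) and the use_add_at switch are illustrative, not the lab's actual API. The switch exists only to show the scatter-back pitfall: with repeated token indices, fancy-indexed += keeps a single contribution per token, while np.add.at accumulates all of them.

```python
import numpy as np

def route(logits, k):
    """Select top-k experts per token and renormalize their softmax weights."""
    z = logits - logits.max(axis=-1, keepdims=True)   # stable softmax
    probs = np.exp(z)
    probs /= probs.sum(axis=-1, keepdims=True)
    topk_ids = np.argsort(-probs, axis=-1)[:, :k]     # k distinct experts per token
    topk_w = np.take_along_axis(probs, topk_ids, axis=-1)
    topk_w /= topk_w.sum(axis=-1, keepdims=True)      # combine weights sum to 1
    return topk_ids, topk_w

def moe_naive(x, experts, topk_ids, topk_w):
    """Per-token oracle: loop over tokens and their selected experts."""
    out = np.zeros_like(x)
    for t in range(x.shape[0]):
        for j in range(topk_ids.shape[1]):
            out[t] += topk_w[t, j] * (x[t] @ experts[topk_ids[t, j]])
    return out

def moe_grouped(x, experts, topk_ids, topk_w, use_add_at=True):
    """Permute by expert, one matmul per expert, scatter back with weights."""
    T, k = topk_ids.shape
    flat_ids = topk_ids.reshape(-1)
    flat_w = topk_w.reshape(-1)
    flat_tok = np.repeat(np.arange(T), k)             # source token of each pair
    order = np.argsort(flat_ids, kind="stable")       # group pairs by expert
    s_ids, s_tok, s_w = flat_ids[order], flat_tok[order], flat_w[order]

    counts = np.bincount(s_ids, minlength=experts.shape[0])
    assert counts.sum() == T * k                      # conservation check
    y = np.empty((T * k, experts.shape[2]))
    start = 0
    for e, n in enumerate(counts):                    # one matmul per expert
        if n:
            y[start:start + n] = x[s_tok[start:start + n]] @ experts[e]
        start += n

    contrib = s_w[:, None] * y
    out = np.zeros((T, experts.shape[2]))
    if use_add_at:
        np.add.at(out, s_tok, contrib)                # accumulates duplicates
    else:
        out[s_tok] += contrib                         # buggy: duplicates overwrite
    return out

rng = np.random.default_rng(0)
T, d, E, k = 16, 8, 4, 2
x = rng.standard_normal((T, d))
experts = rng.standard_normal((E, d, d))
ids, w = route(rng.standard_normal((T, E)), k)

ref = moe_naive(x, experts, ids, w)
good = moe_grouped(x, experts, ids, w)
bad = moe_grouped(x, experts, ids, w, use_add_at=False)
print("grouped matches oracle:", np.allclose(ref, good))
print("+= scatter matches oracle:", np.allclose(ref, bad))
```

The comparison uses np.allclose because a batched matmul and a per-row matmul may round differently in floating point; the lab's own tests may use a different notion of equality.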
lab-02-profile-fused-moe [GPU-OPT]
Capture decode steps of a real MoE model under torch.profiler and read the kernel
table: experts ~41%, permute ~7%, router ~1%. Predict the breakdown first; the gaps are
your misconceptions. Annotated capture included. Skills: warm-up discipline;
kernel-table 80/20; decision-cheap/consequence-expensive structure of routing.
lab-03-tiled-gemm [CPU-OK]
The idea that fills the gap between three nested loops and CUTLASS: tiling. Implement a tiled matmul (exact, ragged edges included) and the memory-traffic model — reuse equals the harmonic mean of tile dimensions, square tiles win, and a decode-shaped matmul (M=1) caps at reuse 2 no matter what, re-deriving decode's bandwidth wall from the kernel side. Skills: traffic counting; intensity as an algorithm property; the three-level tiling hierarchy.
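The traffic claim can be checked with a few lines of pure Python. This sketch assumes a simple model that may differ from the lab's: each output tile of size tm × tn streams its rows of A and columns of B once over the full K dimension, and reuse is counted as FLOPs per element loaded. Under that assumption the ratio works out to 2·tm·tn / (tm + tn), the harmonic mean of the two tile dimensions. All function names here are hypothetical.

```python
def matmul_naive(A, B):
    """Three nested loops, the reference."""
    M, K, N = len(A), len(B), len(B[0])
    return [[sum(A[i][k] * B[k][j] for k in range(K)) for j in range(N)]
            for i in range(M)]

def matmul_tiled(A, B, tm, tn, tk):
    """Tiled matmul; min() clamps the ragged edge tiles."""
    M, K, N = len(A), len(B), len(B[0])
    C = [[0] * N for _ in range(M)]
    for i0 in range(0, M, tm):
        for j0 in range(0, N, tn):
            for k0 in range(0, K, tk):
                for i in range(i0, min(i0 + tm, M)):
                    for j in range(j0, min(j0 + tn, N)):
                        acc = C[i][j]
                        for k in range(k0, min(k0 + tk, K)):
                            acc += A[i][k] * B[k][j]
                        C[i][j] = acc
    return C

def elements_loaded(M, N, K, tm, tn):
    """Elements of A and B fetched, summed over all output tiles."""
    total = 0
    for i0 in range(0, M, tm):
        h = min(tm, M - i0)            # actual tile height at the edge
        for j0 in range(0, N, tn):
            w = min(tn, N - j0)        # actual tile width at the edge
            total += K * (h + w)       # h rows of A plus w columns of B
    return total

def intensity(M, N, K, tm, tn):
    """FLOPs per element loaded (a multiply-add counts as 2 FLOPs)."""
    return 2 * M * N * K / elements_loaded(M, N, K, tm, tn)

def harmonic_mean(a, b):
    return 2 * a * b / (a + b)

# Ragged shapes: 5x7 times 7x3 with tiles that do not divide evenly.
A = [[i + k for k in range(7)] for i in range(5)]
B = [[k - j for j in range(3)] for k in range(7)]
print(matmul_tiled(A, B, 2, 2, 3) == matmul_naive(A, B))

# Same tile area (4096 outputs), square versus skinny.
print(intensity(256, 256, 64, 64, 64), harmonic_mean(64, 64))
print(intensity(256, 256, 64, 128, 32), harmonic_mean(128, 32))

# Decode shape: M = 1 clamps the tile height to 1, so reuse stays below 2.
print(intensity(1, 4096, 4096, 64, 64))
```

The model ignores writes of C and any caching between tiles, which is enough to show the shape of the result: for a fixed tile area the square tile maximizes reuse, and no tile width rescues M=1.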
lab-04-expert-load-balance [CPU-OK]
The MoE serving problem: with experts sharded across devices, a step lasts as long as the busiest device. Build the diagnostics (loads, imbalance factor, EP step time, capacity-overflow drops) and prove a hot expert inflates step time >2.5× at identical total work. Skills: straggler arithmetic; placement vs routing; why inference never drops tokens; what EPLB optimizes.
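The straggler arithmetic fits in a short pure-Python sketch. The function names, the cost model (step time proportional to the busiest device's token count), the contiguous expert placement, and the specific token counts below are all assumptions chosen for illustration, not the lab's actual data or interface.

```python
from collections import Counter

def expert_loads(assignments, num_experts):
    """Tokens routed to each expert (assignments is a flat list of expert ids)."""
    c = Counter(assignments)
    return [c.get(e, 0) for e in range(num_experts)]

def imbalance_factor(loads):
    """Busiest expert's load relative to the mean load."""
    return max(loads) / (sum(loads) / len(loads))

def ep_step_time(loads, placement, num_devices, time_per_token=1.0):
    """Step time under expert parallelism: the busiest device sets the pace."""
    device = [0] * num_devices
    for e, n in enumerate(loads):
        device[placement[e]] += n
    return max(device) * time_per_token

def overflow_drops(loads, capacity):
    """Tokens that would be dropped if each expert accepted at most `capacity`."""
    return sum(max(0, n - capacity) for n in loads)

E, D = 8, 4
placement = [e // (E // D) for e in range(E)]    # two experts per device

# Same total work (1024 assignments) in both scenarios.
balanced = [e for e in range(E) for _ in range(128)]
hot = [0] * 660 + [e for e in range(1, E) for _ in range(52)]

lb, lh = expert_loads(balanced, E), expert_loads(hot, E)
tb = ep_step_time(lb, placement, D)
th = ep_step_time(lh, placement, D)

print("imbalance:", imbalance_factor(lb), imbalance_factor(lh))
print("step time ratio:", th / tb)
print("drops at capacity 160:", overflow_drops(lb, 160), overflow_drops(lh, 160))
```

With these made-up counts the hot scenario's step takes about 2.8 times as long as the balanced one, even though both process the same 1024 assignments. A capacity limit of 160 would drop 500 of the hot expert's tokens, which is the cost an inference system avoids by never dropping.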
What you can do after this phase
Estimate any GEMM's achievable performance from shape + tile + hardware on a napkin;
explain why grouped MoE kernels exist and verify one against an oracle; diagnose an
underperforming MoE deployment with a routing histogram before touching a profiler, and
with one afterward; and read fused_moe.py upstream as four phases you've personally
implemented. Phase 8 (speculative decoding) spends the idle FLOPs you now know how to
find; Phase 10 stretches the all-to-all across nodes.