
Phase 07 Labs — GEMM & MoE Kernels

Four labs below the attention line: the matmuls that account for most of every step's milliseconds, and the mixture-of-experts machinery that reorganizes them. The arc: build the MoE forward and prove the grouped formulation exact (lab-01); learn the tiling arithmetic that makes any GEMM fast, and why decode shapes defeat it (lab-03); measure the balance tax that routing levies on parallel experts (lab-04); then profile a real MoE model and check all three models against silicon (lab-02).

Recommended order: 01 → 03 → 04 → 02 (directory numbers predate labs 03 and 04). CPU labs follow the standard contract: starter.py (your work), solution.py (reference), and test_lab.py (the spec). By default the tests run against the solution; set LAB_IMPL=starter to grade your own implementation.

# Whole phase (GPU tests auto-skip without CUDA):
pytest phase-07-gemm-and-moe-kernels/labs -m "not gpu"

# Grade yourself on one lab:
LAB_IMPL=starter pytest phase-07-gemm-and-moe-kernels/labs/lab-01-moe-routing -q

Labs

lab-01-moe-routing [CPU-OK]

Implement the MoE forward twice, once as the naive per-token oracle and once as the grouped formulation (permute by expert → one matmul per expert → scatter back with combine weights), and prove the two equal. The grouped version is the readable edition of fused_moe_kernel; the scatter-back hides the lab's boss fight: np.add.at, which accumulates every contribution to a repeated token index, versus fancy-indexed +=, which silently keeps only one of them. Skills: select-normalize-combine routing; the permute trick; conservation bookkeeping as debugging.
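A minimal NumPy sketch of both formulations, assuming each expert is a single linear map W[e] (real experts are gated MLPs) and that routing is a softmax over the top-k logits. The names route, moe_naive and moe_grouped are illustrative, not the lab's starter API; the hypothetical buggy_scatter flag reproduces the duplicate-dropping += so you can watch it fail.

```python
import numpy as np

def route(logits, k):
    """Select the top-k experts per token, then normalize their weights."""
    topk = np.argsort(-logits, axis=1)[:, :k]               # (T, k) expert ids
    sel = np.take_along_axis(logits, topk, axis=1)          # (T, k) selected logits
    w = np.exp(sel - sel.max(axis=1, keepdims=True))        # softmax over selected only
    w /= w.sum(axis=1, keepdims=True)
    return topk, w

def moe_naive(x, W, topk, w):
    """Per-token oracle: loop over tokens and their chosen experts."""
    out = np.zeros((x.shape[0], W.shape[2]))
    for t in range(x.shape[0]):
        for j in range(topk.shape[1]):
            out[t] += w[t, j] * (x[t] @ W[topk[t, j]])
    return out

def moe_grouped(x, W, topk, w, buggy_scatter=False):
    """Permute by expert -> one matmul per expert -> scatter back."""
    T, k = topk.shape
    E, _, d_out = W.shape
    flat_expert = topk.reshape(-1)             # expert chosen by each (token, slot)
    flat_token = np.repeat(np.arange(T), k)    # token that owns each slot
    flat_w = w.reshape(-1)                     # combine weight of each slot
    order = np.argsort(flat_expert, kind="stable")  # permute: group slots by expert
    counts = np.bincount(flat_expert, minlength=E)
    assert counts.sum() == T * k               # conservation: every slot lands once
    ys = np.empty((T * k, d_out))
    start = 0
    for e in range(E):
        rows = order[start:start + counts[e]]  # slots routed to expert e
        ys[start:start + counts[e]] = x[flat_token[rows]] @ W[e]  # one matmul per expert
        start += counts[e]
    contrib = flat_w[order][:, None] * ys      # combine weights, in permuted order
    out = np.zeros((T, d_out))
    if buggy_scatter:
        out[flat_token[order]] += contrib      # repeated indices: one write survives
    else:
        np.add.at(out, flat_token[order], contrib)  # accumulates every repeat
    return out
```

With k ≥ 2 every token index appears k times in flat_token[order], so the fancy-indexed += keeps only one contribution per token; with k = 1 there are no repeats and both scatters agree. The counts.sum() assertion is the conservation check in miniature.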

lab-02-profile-fused-moe [GPU-OPT]

Capture decode steps of a real MoE model under torch.profiler and read the kernel table: experts ~41%, permute ~7%, router ~1%. Predict the breakdown first; the gaps between your prediction and the table are your misconceptions. An annotated capture is included. Skills: warm-up discipline; kernel-table 80/20; the decision-cheap/consequence-expensive structure of routing.
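Writing down your prediction is easier once you have decided how kernel rows map to categories. A minimal pure-Python sketch that buckets (step, kernel_name, time_us) rows into time shares, assuming you have already exported such rows from your capture. The name patterns in CATEGORIES are hypothetical (real kernel names vary by model, backend and library version), and the step filter is only a crude stand-in for warm-up discipline.

```python
from collections import defaultdict

# Hypothetical substring patterns; adjust them after reading your own kernel table.
CATEGORIES = {
    "experts": ("fused_moe", "expert"),
    "permute": ("permute", "align", "scatter", "gather"),
    "router": ("topk", "router", "gate"),
}

def bucket_kernels(rows, warmup_steps=2):
    """rows: iterable of (step, kernel_name, time_us). Returns {category: share}."""
    totals = defaultdict(float)
    for step, name, t_us in rows:
        if step < warmup_steps:               # drop cold steps (allocation, autotuning)
            continue
        lname = name.lower()
        cat = next((c for c, pats in CATEGORIES.items()
                    if any(p in lname for p in pats)), "other")
        totals[cat] += t_us
    grand = sum(totals.values())
    # Largest share first: the kernel-table 80/20 view.
    return {c: v / grand for c, v in sorted(totals.items(), key=lambda kv: -kv[1])}
```

Categories are matched in dictionary order, so a kernel whose name fits two patterns lands in the first one; check the "other" bucket to see what your patterns missed.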

lab-03-tiled-gemm [CPU-OK]

The idea that fills the gap between three nested loops and CUTLASS: tiling. Implement a tiled matmul (exact, ragged edges included) and the memory-traffic model: reuse (FLOPs per element loaded) equals the harmonic mean of the tile dimensions, so at a fixed tile area square tiles win, and a decode-shaped matmul (M=1) stays below reuse 2 whatever the tile width, re-deriving decode's bandwidth wall from the kernel side. Skills: traffic counting; intensity as an algorithm property; the three-level tiling hierarchy.
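A minimal NumPy sketch of both halves, assuming the simplest traffic model: for each k-step, a tm × tn output tile loads tm elements of A and tn of B and performs tm·tn FMAs (2·tm·tn FLOPs). The names tiled_matmul and reuse are illustrative, not the lab's API, and the loop nest models only one level of the tiling hierarchy.

```python
import numpy as np

def tiled_matmul(A, B, tm=32, tn=32, tk=32):
    """C = A @ B computed one output tile at a time.

    NumPy slicing clips at the array edge, so ragged edge tiles need no
    special case: the last tile in each dimension is simply smaller."""
    M, K = A.shape
    K2, N = B.shape
    assert K == K2, "inner dimensions must match"
    C = np.zeros((M, N), dtype=np.result_type(A, B))
    for i in range(0, M, tm):
        for j in range(0, N, tn):
            # Accumulator for one output tile (the part a GPU keeps in registers).
            acc = np.zeros((min(tm, M - i), min(tn, N - j)), dtype=C.dtype)
            for p in range(0, K, tk):
                acc += A[i:i + tm, p:p + tk] @ B[p:p + tk, j:j + tn]
            C[i:i + tm, j:j + tn] = acc
    return C

def reuse(tm, tn):
    """FLOPs per loaded element for a tm x tn output tile.

    2*tm*tn FLOPs per (tm + tn) loaded elements: the harmonic mean of tm and tn."""
    return 2 * tm * tn / (tm + tn)
```

reuse(128, 128) is 128, while reuse(1, 128) is 256/129 ≈ 1.98: widening the tile of an M=1 matmul only pushes reuse toward 2, which is decode's bandwidth wall seen from inside the kernel.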

lab-04-expert-load-balance [CPU-OK]

The MoE serving problem: with experts sharded across devices, a step lasts as long as the busiest device. Build the diagnostics (loads, imbalance factor, expert-parallel (EP) step time, capacity-overflow drops) and prove that a hot expert inflates step time by more than 2.5× at identical total work. Skills: straggler arithmetic; placement vs routing; why inference never drops tokens; what EPLB optimizes.
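A minimal NumPy sketch of the four diagnostics, assuming step time is proportional to the tokens on the busiest device (no fixed overhead, no batch-size effect on matmul efficiency) and a hypothetical placement array that maps each expert to a device. Function names are illustrative.

```python
import numpy as np

def expert_loads(topk, num_experts):
    """Tokens routed to each expert; every (token, slot) pair counts once."""
    return np.bincount(np.asarray(topk).reshape(-1), minlength=num_experts)

def imbalance_factor(loads):
    """Busiest expert over the mean expert: 1.0 is perfect balance."""
    return loads.max() / loads.mean()

def ep_step_time(loads, placement, num_devices, time_per_token=1.0):
    """EP step time: the busiest device sets the pace.

    placement[e] is the device hosting expert e (hypothetical layout)."""
    device_loads = np.bincount(placement, weights=loads, minlength=num_devices)
    return device_loads.max() * time_per_token

def capacity_drops(loads, capacity):
    """Tokens a fixed per-expert capacity would discard."""
    return int(np.maximum(loads - capacity, 0).sum())
```

With 8 experts on 4 devices (two per device) and 800 routed tokens either way, balanced loads of 100 per expert put 200 tokens on every device, while loads of [500, 50, 50, 50, 50, 50, 25, 25] put 550 on device 0: a 2.75× longer step for the same total work. capacity_drops shows what a fixed-capacity layer would throw away; serving keeps every token and pays the straggler time instead.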

What you can do after this phase

Estimate any GEMM's achievable performance from shape + tile + hardware on a napkin; explain why grouped MoE kernels exist and verify one against an oracle; diagnose an underperforming MoE deployment with a routing histogram before touching a profiler, and with one afterward; and read fused_moe.py upstream as four phases you've personally implemented. Phase 8 (speculative decoding) spends the idle FLOPs you now know how to find; Phase 10 stretches the all-to-all across nodes.