Lab 00-01 — The KV-Cache Speedup [CPU-OK]
This is the experiment that motivates the entire course — and arguably the entire field of LLM inference engineering. You will implement autoregressive generation twice: once the naive way (recompute attention's keys and values for the whole sequence, every step) and once with a KV cache (compute each token's K/V exactly once, ever). Same model, same output, and a work difference that grows with the square of the sequence length. By the end you'll have measured, with an exact integer counter you control, why every serving engine on earth is built around a cache — and why the rest of this course is about managing that cache.
Contents
- Why this lab exists
- Background: what K and V are, and why they're recomputable
- Files
- Run
- What to implement
- What you should see — and why every number is what it is
- What the tests prove
- Hitchhiker's notes
- Going further
- References
Why this lab exists
Ask a newcomer why LLM inference is expensive and they'll say "big matrices." True, but it misses the structural problem: generation is autoregressive. The model emits one token, appends it, and runs again on the longer sequence — N tokens of output means N forward passes. If each pass reprocesses everything before it, total work is 1+2+3+…+N ≈ N²/2 token-computations for N tokens of value. Quadratic. A 10k-token answer would cost ~50 million token-computations to produce 10 thousand tokens.
The KV cache is the observation that almost all of that recomputation is byte-identical every time — and the field's entire architecture flows downstream of caching it. Once you store K/V, generation becomes O(N)… and a new problem is born: that cache is state, it lives in scarce GPU memory, it grows every step, and somebody has to manage it. That somebody is vLLM, and managing-the-cache-well is Phases 1–19. This lab is where you earn the premise.
We measure with a counter, not a stopwatch, on purpose. Wall-clock on a laptop is noisy
and proves nothing about asymptotics; an exact work count (one unit per token processed)
gives you the formula, and formulas transfer to any hardware. You'll meet this
counting-over-clocking style again in Phase 3's labs.
Background: what K and V are, and why they're recomputable
In attention, each token's hidden state is projected three ways: a query (what am I looking for?), a key (what can I be found by?), and a value (what do I contribute when found?). When the model processes token t, its query is dotted against the keys of all previous tokens, and the resulting weights blend their values:
attn(t) = softmax(q_t · [k_0 … k_t]ᵀ / √d) · [v_0 … v_t]
The crucial property is causality: k_i and v_i depend only on tokens 0..i. Token
5's key is the same whether the sequence currently has 6 tokens or 6,000. So once computed,
(k_i, v_i) is valid forever — it's a pure function of the prefix, which makes it
perfectly cacheable. The query is the only part that's fresh each step (it belongs to the
new token), which is why we cache K and V but never Q.
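A minimal sketch of that cacheability claim, assuming a single attention layer with made-up 2-dimensional embeddings and fixed projection weights. The names `project`, `attend`, `W_Q`, `W_K`, and `W_V` are illustrative and not part of the lab files. It checks, step by step, that attention over cached K/V matches attention over K/V rebuilt from scratch. In a multi-layer model the same argument holds layer by layer, because causal masking keeps each position independent of later tokens.

```python
import math

# Hypothetical toy setup: 2-d embeddings and fixed 2x2 projection matrices.
W_Q = [[0.5, -0.2], [0.1, 0.3]]
W_K = [[0.4, 0.1], [-0.3, 0.2]]
W_V = [[0.2, 0.6], [0.7, -0.1]]


def project(W, x):
    """Matrix-vector product: one projection of one token's hidden state."""
    return [sum(w * xi for w, xi in zip(row, x)) for row in W]


def attend(q, keys, values):
    """Softmax-weighted blend of values, scored by q . k / sqrt(d)."""
    d = len(q)
    scores = [sum(a * b for a, b in zip(q, k)) / math.sqrt(d) for k in keys]
    m = max(scores)  # subtract the max for numerical stability
    weights = [math.exp(s - m) for s in scores]
    total = sum(weights)
    return [
        sum(w / total * v[j] for w, v in zip(weights, values))
        for j in range(len(values[0]))
    ]


embeddings = [[1.0, 0.0], [0.5, 0.5], [0.0, 1.0], [-1.0, 0.3]]

k_cache, v_cache = [], []
for t, x in enumerate(embeddings):
    # Cached path: project K/V for the new token only, then append.
    k_cache.append(project(W_K, x))
    v_cache.append(project(W_V, x))
    q = project(W_Q, x)  # the query is fresh every step, so it is never cached
    cached_out = attend(q, k_cache, v_cache)

    # Naive path: rebuild K/V for every token 0..t from scratch.
    keys = [project(W_K, e) for e in embeddings[: t + 1]]
    values = [project(W_V, e) for e in embeddings[: t + 1]]
    naive_out = attend(q, keys, values)

    # Same inputs through the same function, so the results match exactly.
    assert cached_out == naive_out
```

The naive path performs `t + 1` key/value projections at step `t`, while the cached path performs one; the outputs are identical either way.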
That's the whole trick. "KV cache" sounds like infrastructure; it's actually a one-line theorem about causal attention plus the decision to spend memory on it.
Files
- `starter.py`: implement `generate_no_cache` and `generate_with_cache`. The work meter (`compute_kv` / `KVWork`) and the deterministic `next_token` are provided. Your work.
- `solution.py`: reference.
- `test_lab.py`: pins identical outputs, the exact quadratic and linear work formulas, and the growing ratio.
Run
LAB_IMPL=starter pytest phase-00-foundations/labs/lab-01-kv-cache-speedup -q
pytest phase-00-foundations/labs/lab-01-kv-cache-speedup -q # reference (default)
What to implement
Both functions generate n_new tokens from a prompt of length P and return
(full_token_sequence, total_kv_work):
- `generate_no_cache`: each decode step first calls `compute_kv(tok, pos)` for every token currently in the sequence (the model "re-reads" everything), then appends `next_token(tokens)`. Step `i` (0-indexed) costs `P + i` units.
- `generate_with_cache`: prefill once (`compute_kv` per prompt token, `P` units), then each decode step computes K/V for only the newly appended token (1 unit).
next_token is deterministic — a hash of the context — so both implementations must
produce the same token sequence. That's not a convenience; it's the point (see the first
test).
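One possible shape for the two loops, as a sketch rather than the reference solution. The `KVWork`, `compute_kv`, and `next_token` below are stand-ins written for this example; their internals are assumptions based on the description above, and the helpers provided in `starter.py` may differ.

```python
import hashlib


class KVWork:
    """Stand-in work meter (assumed shape): counts compute_kv calls."""

    def __init__(self):
        self.units = 0


WORK = KVWork()


def compute_kv(tok, pos):
    """Stand-in: pretend to compute one token's K/V; costs one unit."""
    WORK.units += 1
    return (tok, pos)


def next_token(tokens):
    """Stand-in deterministic 'model': a hash of the full context."""
    digest = hashlib.sha256(repr(tokens).encode()).digest()
    return digest[0]


def generate_no_cache(prompt, n_new):
    tokens, start = list(prompt), WORK.units
    for _ in range(n_new):
        # Re-read everything: K/V for every token currently in the sequence.
        for pos, tok in enumerate(tokens):
            compute_kv(tok, pos)
        tokens.append(next_token(tokens))
    return tokens, WORK.units - start


def generate_with_cache(prompt, n_new):
    tokens, start = list(prompt), WORK.units
    # Prefill: K/V for each prompt token, exactly once.
    cache = [compute_kv(tok, pos) for pos, tok in enumerate(tokens)]
    for _ in range(n_new):
        tokens.append(next_token(tokens))
        # Decode: K/V only for the token just appended.
        cache.append(compute_kv(tokens[-1], len(tokens) - 1))
    return tokens, WORK.units - start


prompt = [1, 2, 3, 4, 5]
naive_tokens, naive_work = generate_no_cache(prompt, 10)
cached_tokens, cached_work = generate_with_cache(prompt, 10)
print(naive_work, cached_work, naive_tokens == cached_tokens)  # 95 15 True
```

Both functions call `next_token` on the same growing list, which is why the token sequences agree; only the number of `compute_kv` calls differs.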
What you should see — and why every number is what it is
For P = 5, n_new = 10:
no cache : work = 5+6+7+8+9+10+11+12+13+14 = 95 (sum of P..P+n_new-1 → O(N²))
cached : work = 5 + 10 = 15 (P prefill + 1/step → O(N))
- Why 95? Step 0 reprocesses the 5 prompt tokens; step 1 reprocesses 6 (prompt + the token just generated); … step 9 reprocesses 14. The arithmetic series is the quadratic, made concrete enough to check by hand — which is exactly what the test does.
- Why 15? Each of the 15 tokens that ever exists has its K/V computed exactly once. The cached cost is the number of tokens. It cannot be beaten by any scheme that actually computes the KV (it can be beaten by schemes that reuse KV across requests — that's prefix caching, Phase 2/3).
- At `n_new = 1000`: the ratio is >100× (about 500× with P = 5) and still climbing linearly (~N/2). On real hardware this asymptotic gap is the difference between "chatbots are economically possible" and not.
- Notice the two-phase shape that fell out for free: a big batch of K/V work up front (the prefill: all P prompt tokens at once, parallelizable, compute-hungry), then a drip of single-token steps (the decode: serial, one unit each). You didn't design that; caching created it. Prefill-vs-decode is the most consequential workload split in inference (lab-04 quantifies it; Phase 1 traces it; Phase 3 schedules around it), and it is born right here, in your 20 lines.
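The numbers above can be checked against closed forms. This is a small sketch; the function names `work_no_cache` and `work_cached` are made up for the example and are not part of the lab files.

```python
def work_no_cache(P, n_new):
    """Sum of P, P+1, ..., P+n_new-1, written as a closed form."""
    return n_new * P + n_new * (n_new - 1) // 2


def work_cached(P, n_new):
    """One unit per token that ever exists: P prefill + n_new decode."""
    return P + n_new


# The hand-checkable case from the text.
assert work_no_cache(5, 10) == sum(range(5, 15)) == 95
assert work_cached(5, 10) == 15

for n_new in (10, 100, 1000):
    ratio = work_no_cache(5, n_new) / work_cached(5, n_new)
    print(f"n_new={n_new:5d}  ratio={ratio:6.1f}")
```

With P = 5 the ratio at `n_new = 1000` comes out near 502, which matches the ~N/2 estimate once `n_new` is much larger than P.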
What the tests prove
| Test | What it pins |
|---|---|
| `test_both_produce_identical_tokens` | Caching is an optimization, not a behavior change; the cached run's outputs are bit-identical. This is the course's master invariant: every optimization from here on (chunked prefill, prefix caching, preemption, paging) is proven safe by exactly this kind of equality test |
| `test_no_cache_is_quadratic` | `work == sum(P .. P+n_new−1)`: the formula, not "roughly slower" |
| `test_cached_is_linear` | `work == P + n_new`: every token computed once, ever |
| `test_work_ratio_grows_with_length` | The gap grows with N (>100× at n=1000): this is an asymptotic class difference, not a constant factor someone could optimize away |
Hitchhiker's notes
- The cache is a time–space trade, and the space is the plot of this course. You just converted O(N²) compute into O(N) memory: every token now occupies bytes for as long as its sequence is alive (about 128 KiB/token for Llama-3-8B at fp16; lab-02 computes this). One number to foreshadow: a 24 GiB GPU holds the fp16 weights (roughly 15 GiB) plus only a handful of full-length 8K-token sequences of cache, at about 1 GiB each. Scarcity is immediate, and scarcity is why Phases 2–3 exist.
- Real transformers hide the no-cache cost inside one matmul. HuggingFace `generate(use_cache=False)` doesn't loop per token like your simulation; it reprocesses the whole sequence in a single (big) forward pass per step. The work is still quadratic in total; your counter models the FLOPs faithfully even though the loop structure differs.
- Where the cache actually lives upstream: `vllm.attention.layer.Attention` writes each step's new K/V into the paged cache (via `slot_mapping`, see Phase 2 lab-06), and the kernel reads all prior K/V (via `block_table`). What you modeled as a counter is, in production, tensors + an allocator + a scheduler. Same theorem underneath.
- Why does the cached version call `next_token(tokens)` with the full list, then? Because the model function still needs the whole context semantically: the cache changes what is recomputed, not what the model "knows." In a real model, "the cache was consulted" and "the context was read" are the same act: attention over cached K/V. Don't confuse caching KV with truncating context.
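For the per-token memory figure in the first note, here is a back-of-the-envelope sketch. It assumes Llama-3-8B's published configuration (32 layers, 8 KV heads under grouped-query attention, head dimension 128) and 2-byte fp16/bf16 elements. The function name is hypothetical; lab-02 is where this is done properly.

```python
def kv_bytes_per_token(num_layers, num_kv_heads, head_dim, bytes_per_elem):
    """Bytes of KV cache one token occupies across all layers."""
    # Factor of 2: one key vector and one value vector per layer per KV head.
    return 2 * num_layers * num_kv_heads * head_dim * bytes_per_elem


# Assumed Llama-3-8B shape at 2 bytes per element.
per_token = kv_bytes_per_token(
    num_layers=32, num_kv_heads=8, head_dim=128, bytes_per_elem=2
)
print(per_token, per_token // 1024)  # 131072 bytes, i.e. 128 KiB
```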
Going further
- Plot `work_no_cache / work_cached` for n in 1..2000 and confirm the ~N/2 line. Then plot cached work alone: a flat 1/step. That flat line is why decode latency is stable and why per-token pricing is linear. Economics from asymptotics.
- Model prompt length: sweep P from 10 to 10,000 at fixed n_new=100. Notice prefill dominates total cached work for long prompts, which is the TTFT story (Phase 1) in miniature.
- Add a `kv_bytes` counter alongside the work counter (one cache entry per `compute_kv`) and watch memory grow linearly while compute stays flat. You've now built both axes of lab-02 and the motivating tension of Phase 2 with ~5 extra lines.
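As a starting point for the prompt-length sweep, the sketch below uses the cached work formula `P + n_new` from earlier in this lab. The helper name `prefill_share` is made up for the example.

```python
def prefill_share(P, n_new):
    """Fraction of total cached work (P + n_new) spent in prefill."""
    return P / (P + n_new)


n_new = 100
for P in (10, 100, 1000, 10_000):
    share = prefill_share(P, n_new)
    print(f"P={P:6d}  prefill share of cached work = {share:.1%}")
```

At P = 100 the split is even; by P = 10,000 prefill accounts for about 99% of the cached work, which is the long-prompt behavior the exercise asks you to notice.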
References
- Vaswani et al., Attention Is All You Need (2017) — where K/Q/V come from: https://arxiv.org/abs/1706.03762
- kipply, Transformer Inference Arithmetic — the canonical blog walkthrough of KV-cache math and why decode is bandwidth-bound: https://kipp.ly/transformer-inference-arithmetic/
- Pope et al., Efficiently Scaling Transformer Inference (2022) — §3 formalizes the prefill/decode split your counter just exposed: https://arxiv.org/abs/2211.05102
- `upstream/vllm/attention/layer.py`: the production home of the cache write.
- Phase 0 guide §"the KV cache" (00-guide.md): the intuition this lab makes quantitative.