Lab 00-01 — The KV-Cache Speedup [CPU-OK]

This is the experiment that motivates the entire course — and arguably the entire field of LLM inference engineering. You will implement autoregressive generation twice: once the naive way (recompute attention's keys and values for the whole sequence, every step) and once with a KV cache (compute each token's K/V exactly once, ever). Same model, same output, and a work difference that grows with the square of the sequence length. By the end you'll have measured, with an exact integer counter you control, why every serving engine on earth is built around a cache — and why the rest of this course is about managing that cache.

Why this lab exists

Ask a newcomer why LLM inference is expensive and they'll say "big matrices." True, but it misses the structural problem: generation is autoregressive. The model emits one token, appends it, and runs again on the longer sequence — N tokens of output means N forward passes. If each pass reprocesses everything before it, total work is 1+2+3+…+N ≈ N²/2 token-computations for N tokens of value. Quadratic. A 10k-token answer would cost ~50 million token-computations to produce 10 thousand tokens.
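
The arithmetic above is small enough to check directly. The following is a minimal sketch (the function name is illustrative and not part of the lab) that adds up the tokens reprocessed across N passes and compares the total with the N²/2 estimate:

```python
def naive_token_computations(n_tokens: int) -> int:
    """Total tokens processed if pass k reprocesses all k tokens seen so far."""
    return sum(range(1, n_tokens + 1))  # 1 + 2 + ... + N = N(N+1)/2


n = 10_000
exact = naive_token_computations(n)
estimate = n * n // 2

print(exact)      # 50005000
print(estimate)   # 50000000
print(exact / n)  # 5000.5 token-computations per token of output
```

The exact sum and the N²/2 estimate differ only by the N/2 term, which is why the text rounds to "~50 million".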

The KV cache is the observation that almost all of that recomputation is byte-identical every time, and the field's entire architecture flows downstream of caching it. Once you store K/V, the K/V work for generation becomes O(N)… and a new problem is born: that cache is state, it lives in scarce GPU memory, it grows every step, and somebody has to manage it. That somebody is vLLM, and managing the cache well is Phases 1–19. This lab is where you earn the premise.

We measure with a counter, not a stopwatch, on purpose. Wall-clock on a laptop is noisy and proves nothing about asymptotics; an exact work count (one unit per token processed) gives you the formula, and formulas transfer to any hardware. You'll meet this counting-over-clocking style again in Phase 3's labs.

Background: what K and V are, and why they're recomputable

In attention, each token's hidden state is projected three ways: a query (what am I looking for?), a key (what can I be found by?), and a value (what do I contribute when found?). When the model processes token t, its query is dotted against the keys of all previous tokens, and the resulting weights blend their values:

attn(t) = softmax(q_t · [k_0 … k_t]ᵀ / √d) · [v_0 … v_t]

The crucial property is causality: k_i and v_i depend only on tokens 0..i. Token 5's key is the same whether the sequence currently has 6 tokens or 6,000. So once computed, (k_i, v_i) is valid forever — it's a pure function of the prefix, which makes it perfectly cacheable. The query is the only part that's fresh each step (it belongs to the new token), which is why we cache K and V but never Q.

That's the whole trick. "KV cache" sounds like infrastructure; it's actually a one-line theorem about causal attention plus the decision to spend memory on it.
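
To make the cacheability claim concrete, here is a toy single-layer sketch in plain Python. Everything in it is an assumption for illustration: the random weights, the 4-dimensional hidden states, and the helper names project and attend do not come from the lab. It recomputes K/V for the whole prefix at every step, and separately appends only the new token's K/V to a cache, then checks that both paths give identical keys, values, and attention output.

```python
import math
import random


def project(weight, x):
    """Matrix-vector product: one row of `weight` per output dimension."""
    return [sum(w * xi for w, xi in zip(row, x)) for row in weight]


def attend(q, keys, values):
    """softmax(q . K^T / sqrt(d)) . V for a single query vector."""
    d = len(q)
    scores = [sum(a * b for a, b in zip(q, k)) / math.sqrt(d) for k in keys]
    top = max(scores)  # subtract the max for numerical stability
    exps = [math.exp(s - top) for s in scores]
    total = sum(exps)
    return [
        sum(e / total * v[j] for e, v in zip(exps, values))
        for j in range(len(values[0]))
    ]


random.seed(0)
D = 4  # toy hidden size


def rand_matrix():
    return [[random.uniform(-1, 1) for _ in range(D)] for _ in range(D)]


W_q, W_k, W_v = rand_matrix(), rand_matrix(), rand_matrix()
hidden = [[random.uniform(-1, 1) for _ in range(D)] for _ in range(8)]  # 8 toy tokens

cached_k, cached_v = [], []
for t, h in enumerate(hidden):
    # Recompute path: project K/V for every token 0..t from scratch.
    fresh_k = [project(W_k, x) for x in hidden[: t + 1]]
    fresh_v = [project(W_v, x) for x in hidden[: t + 1]]
    # Cached path: project K/V for the new token only and append.
    cached_k.append(project(W_k, h))
    cached_v.append(project(W_v, h))

    q = project(W_q, h)  # the query is always fresh
    assert fresh_k == cached_k and fresh_v == cached_v
    assert attend(q, fresh_k, fresh_v) == attend(q, cached_k, cached_v)

print("cached and recomputed attention agree at every step")
```

In this one-layer toy the keys depend only on each token's own hidden state. In a stacked model the same property holds because causal masking keeps every layer's hidden state at position i a function of tokens 0..i.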

Files

  • starter.py: the file you edit. Implement generate_no_cache and generate_with_cache here; the work meter (compute_kv / KVWork) and the deterministic next_token are provided.
  • solution.py: the reference implementation.
  • test_lab.py: pins identical outputs, the exact quadratic and linear work formulas, and the growing ratio.

Run

LAB_IMPL=starter pytest phase-00-foundations/labs/lab-01-kv-cache-speedup -q
pytest phase-00-foundations/labs/lab-01-kv-cache-speedup -q   # reference (default)

What to implement

Both functions generate n_new tokens from a prompt of length P and return (full_token_sequence, total_kv_work):

  • generate_no_cache — each decode step first calls compute_kv(tok, pos) for every token currently in the sequence (the model "re-reads" everything), then appends next_token(tokens). Step i (0-indexed) costs P + i units.
  • generate_with_cache — prefill once (compute_kv per prompt token, P units), then each decode step computes K/V for only the newly appended token (1 unit).

next_token is deterministic — a hash of the context — so both implementations must produce the same token sequence. That's not a convenience; it's the point (see the first test).
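
The two loops described above can be sketched as follows. This is not the lab's starter or solution code: WorkMeter, its compute_kv method, and this hash-based next_token are hypothetical stand-ins whose signatures are assumed, since the provided helpers are not shown here. Only the loop structure is the point.

```python
import hashlib


class WorkMeter:
    """Hypothetical stand-in for the lab's work meter."""

    def __init__(self):
        self.units = 0

    def compute_kv(self, tok, pos):
        """Pretend to compute one token's K/V; costs exactly one unit."""
        self.units += 1
        return (tok, pos)


def next_token(tokens):
    """Deterministic stand-in: hash the full context into a token id."""
    digest = hashlib.sha256(repr(tokens).encode()).digest()
    return int.from_bytes(digest[:4], "big") % 1000


def generate_no_cache(prompt, n_new):
    meter = WorkMeter()
    tokens = list(prompt)
    for _ in range(n_new):
        # Re-read everything: K/V for every token currently in the sequence.
        for pos, tok in enumerate(tokens):
            meter.compute_kv(tok, pos)
        tokens.append(next_token(tokens))
    return tokens, meter.units


def generate_with_cache(prompt, n_new):
    meter = WorkMeter()
    tokens = list(prompt)
    # Prefill: K/V for each prompt token, computed once and kept.
    cache = [meter.compute_kv(tok, pos) for pos, tok in enumerate(tokens)]
    for _ in range(n_new):
        tokens.append(next_token(tokens))  # still sees the full context
        # Decode: K/V for the newly appended token only.
        cache.append(meter.compute_kv(tokens[-1], len(tokens) - 1))
    return tokens, meter.units


prompt = [1, 2, 3, 4, 5]
slow_tokens, slow_work = generate_no_cache(prompt, 10)
fast_tokens, fast_work = generate_with_cache(prompt, 10)
print(slow_tokens == fast_tokens, slow_work, fast_work)  # True 95 15
```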

What you should see — and why every number is what it is

For P = 5, n_new = 10:

no cache : work = 5+6+7+8+9+10+11+12+13+14 = 95      (sum of P..P+n_new-1 → O(N²))
cached   : work = 5 + 10                   = 15      (P prefill + 1/step   → O(N))
  • Why 95? Step 0 reprocesses the 5 prompt tokens; step 1 reprocesses 6 (prompt + the token just generated); … step 9 reprocesses 14. The arithmetic series is the quadratic, made concrete enough to check by hand — which is exactly what the test does.
  • Why 15? Each of the 15 tokens that ever exists has its K/V computed exactly once. The cached cost is the number of tokens. It cannot be beaten by any scheme that actually computes the KV (it can be beaten by schemes that reuse KV across requests — that's prefix caching, Phase 2/3).
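  • At n_new = 1000 (still with P = 5): the work is 504,500 units versus 1,005, a ratio of about 502×, well past the >100× the test asserts and still climbing linearly (~N/2). On real hardware this asymptotic gap is the difference between "chatbots are economically possible" and not.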
  • Notice the two-phase shape that fell out for free: a big batch of K/V work up front (the prefill — all P prompt tokens at once, parallelizable, compute-hungry), then a drip of single-token steps (the decode — serial, one unit each). You didn't design that; caching created it. Prefill-vs-decode is the most consequential workload split in inference (lab-04 quantifies it; Phase 1 traces it; Phase 3 schedules around it), and it is born right here, in your 20 lines.
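
The numbers above follow from two closed forms, which can be checked with a short sketch. The function names are illustrative and not part of the lab files:

```python
def no_cache_work(p: int, n_new: int) -> int:
    """Sum of P, P+1, ..., P+n_new-1 (step i reprocesses P + i tokens)."""
    return n_new * p + n_new * (n_new - 1) // 2


def cached_work(p: int, n_new: int) -> int:
    """One unit per token that ever exists: P prefill + n_new decode."""
    return p + n_new


assert no_cache_work(5, 10) == sum(range(5, 15)) == 95
assert cached_work(5, 10) == 15

for n_new in (10, 100, 1000):
    ratio = no_cache_work(5, n_new) / cached_work(5, n_new)
    print(n_new, round(ratio, 1))
# 10 6.3
# 100 51.9
# 1000 502.0
```

The n_new·(n_new − 1)/2 term is the quadratic part; dividing by the linear P + n_new leaves a ratio that grows like N/2.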

What the tests prove

| Test | What it pins |
| --- | --- |
| test_both_produce_identical_tokens | Caching is an optimization, not a behavior change: the cached run's outputs are bit-identical. This is the course's master invariant: every optimization from here on (chunked prefill, prefix caching, preemption, paging) is proven safe by exactly this kind of equality test. |
| test_no_cache_is_quadratic | work == sum(P .. P+n_new-1): the formula, not "roughly slower". |
| test_cached_is_linear | work == P + n_new: every token computed once, ever. |
| test_work_ratio_grows_with_length | The gap grows with N (>100× at n=1000): this is an asymptotic class difference, not a constant factor someone could optimize away. |

Hitchhiker's notes

  • The cache is a time–space trade, and the space is the plot of this course. You just converted O(N²) compute into O(N) memory: every token now permanently occupies bytes (about 128 KiB/token for Llama-3-8B; lab-02 computes this). One number to foreshadow: at that rate 1 GiB of cache holds 8,192 tokens, so a 24 GiB GPU that also has to hold roughly 15 GiB of fp16 weights has room for only a handful of full-length (8,192-token) sequences. Scarcity is immediate, and scarcity is why Phases 2–3 exist.
  • Real transformers hide the no-cache cost inside one matmul. Hugging Face's generate(use_cache=False) doesn't loop per token like your simulation; it reprocesses the whole sequence in a single (big) forward pass per step. The work is still quadratic in total: your counter does not measure FLOPs directly, but it tracks how they scale faithfully even though the loop structure differs.
  • Where the cache actually lives upstream: vllm.attention.layer.Attention writes each step's new K/V into the paged cache (via slot_mapping — Phase 2 lab-06), and the kernel reads all prior K/V (via block_table). What you modeled as a counter is, in production, tensors + an allocator + a scheduler. Same theorem underneath.
  • Why does the cached version call next_token(tokens) with the full list, then? Because the model function still needs the whole context semantically — the cache changes what is recomputed, not what the model "knows." In a real model, "the cache was consulted" and "the context was read" are the same act: attention over cached K/V. Don't confuse caching KV with truncating context.
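
The 128 KiB/token figure in the notes above can be reproduced with a few lines. The model shape used here (32 layers, 8 KV heads, head dimension 128, 2-byte cache entries) is an assumption taken from the published Llama-3-8B configuration rather than from this lab, and the function name is illustrative:

```python
def kv_bytes_per_token(n_layers: int, n_kv_heads: int, head_dim: int, dtype_bytes: int) -> int:
    """Bytes of cache one token occupies: a K and a V vector per KV head per layer."""
    return 2 * n_layers * n_kv_heads * head_dim * dtype_bytes


# Assumed Llama-3-8B shape: 32 layers, 8 KV heads (grouped-query attention),
# head dimension 128, 2-byte (fp16/bf16) cache entries.
per_token = kv_bytes_per_token(n_layers=32, n_kv_heads=8, head_dim=128, dtype_bytes=2)

print(per_token // 1024, "KiB per token")        # 128 KiB per token
print((1 << 30) // per_token, "tokens per GiB")  # 8192 tokens per GiB
```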

Going further

  • Plot work_no_cache / work_cached for n in 1..2000 — confirm the ~N/2 line. Then plot cached work alone: a flat 1/step. That flat line is why decode latency is stable and why per-token pricing is linear. Economics from asymptotics.
  • Model prompt length: sweep P from 10 to 10,000 at fixed n_new=100. Notice prefill dominates total cached work for long prompts — the TTFT story (Phase 1) in miniature.
  • Add a kv_bytes counter alongside the work counter (one cache entry per compute_kv) and watch memory grow linearly while per-step compute stays flat. You've now built both axes of lab-02 and the motivating tension of Phase 2 with ~5 extra lines.
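
As a starting point for the prompt-length sweep and the cache-size counter suggested above, here is a minimal sketch that works from the closed forms instead of the lab's counters. The function name and the dictionary keys are illustrative assumptions:

```python
def cached_run_profile(p: int, n_new: int) -> dict:
    """Work split and cache size for a cached run, one unit per compute_kv call."""
    prefill, decode = p, n_new
    return {
        "prefill_share": prefill / (prefill + decode),
        "kv_entries": prefill + decode,  # one cache entry per token, never freed
    }


# Sweep the prompt length at a fixed number of generated tokens.
for p in (10, 100, 1000, 10_000):
    profile = cached_run_profile(p, n_new=100)
    print(p, f"{profile['prefill_share']:.1%}", profile["kv_entries"])
```

At n_new = 100 the prefill share climbs from under 10% of the cached work at P = 10 to about 99% at P = 10,000, while the number of cache entries is always exactly P + n_new.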

References

  • Vaswani et al., Attention Is All You Need (2017) — where K/Q/V come from: https://arxiv.org/abs/1706.03762
  • kipply, Transformer Inference Arithmetic — the canonical blog walkthrough of KV-cache math and why decode is bandwidth-bound: https://kipp.ly/transformer-inference-arithmetic/
  • Pope et al., Efficiently Scaling Transformer Inference (2022) — §3 formalizes the prefill/decode split your counter just exposed: https://arxiv.org/abs/2211.05102
  • upstream/vllm/attention/layer.py — the production home of the cache write.
  • Phase 0 guide §"the KV cache" (00-guide.md) — the intuition this lab makes quantitative.