Phase 02 — The Hitchhiker's Guide to PagedAttention ⭐
← Phase 01 · Course home · Phase 03 →
This is a flagship phase — written in full. Use it as the template for the depth every other phase aims at.
Contents
- Don't Panic
- Step 1: Why is memory the problem at all?
- Step 2: The old way, and why it bled memory
- Step 3: The fix — pages (blocks)
- Step 4: The bonus that falls out — sharing
- Step 5: The data structures you're about to meet
- The four invariants (memorize these)
- What you'll do in this phase
Don't Panic
Here is the entire idea, in one breath:
The KV cache is the model's memory of the conversation so far. Naively, you'd give each request one big contiguous slab of GPU memory to hold it. PagedAttention instead chops the KV cache into fixed-size blocks (like an operating system chops memory into pages) and lets each request's blocks live anywhere in GPU memory, tracked by a little block table. That one change (contiguous slab → scattered pages) is why vLLM serves several times more requests per GPU than the systems that came before it.
If you have ever learned how an OS gives processes "virtual memory" backed by scattered physical pages, you already understand PagedAttention. It is literally that idea, applied to the KV cache. The vLLM paper's title even says so: "Efficient Memory Management for Large Language Model Serving with PagedAttention."
Take a breath. By the end of this phase you will have written a working paged block
allocator yourself (mini_vllm/block_pool.py) and read the real one
(upstream/vllm/v1/core/block_pool.py) line by line.
Step 1: Why is memory the problem at all?
Recall from Phase 0: during generation, the model caches a Key and Value vector for every token it has seen, in every layer. This is the KV cache. It is enormous and it grows as the conversation gets longer.
A rough size for one sequence:
kv_bytes_per_token = 2 (K and V) × num_layers × num_kv_heads × head_dim × dtype_bytes
For Llama-3-8B (32 layers, 8 KV heads, head_dim 128, fp16) that's about:
2 × 32 × 8 × 128 × 2 ≈ 131 KB per token
A 2,000-token conversation is ~262 MB of KV, for one user. On a 24 GB GPU, after the ~16 GB of weights, you have ~8 GB for KV: maybe ~30 such conversations. Memory, not compute, is what caps how many users you can serve. So how you manage that memory is the whole ballgame.
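The arithmetic above is easy to check in a few lines. This is a rough sketch: the function name and the 8 GB budget are illustrative assumptions, and it ignores activation memory and allocator overhead, so the numbers land close to the rounded figures in the text rather than matching any real deployment.

```python
def kv_bytes_per_token(num_layers, num_kv_heads, head_dim, dtype_bytes):
    """Bytes of KV cache that one token occupies across all layers."""
    return 2 * num_layers * num_kv_heads * head_dim * dtype_bytes  # 2 = K and V

# Llama-3-8B figures from the text: 32 layers, 8 KV heads, head_dim 128, fp16.
per_token = kv_bytes_per_token(32, 8, 128, 2)
per_conversation = 2_000 * per_token
kv_budget = 8 * 10**9  # assumed: ~8 GB left for KV after the weights

print(per_token)                      # 131072 bytes, i.e. ~131 KB
print(per_conversation / 1e6)         # 262.144 (MB)
print(kv_budget // per_conversation)  # 30 conversations
```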
Step 2: The old way, and why it bled memory
Pre-vLLM systems reserved a contiguous chunk of KV memory per request, sized for the maximum possible length (e.g. 2048 tokens), up front.
Request A (will generate 30 tokens, reserved 2048):
[####..............................................................] <- 2018 slots WASTED
^30 used
Request B (reserved 2048):
[#########.........................................................] <- ~2000 WASTED
Two diseases:
- Internal fragmentation — you reserve for the worst case (2048) but use 30. The other ~2018 slots sit idle, reserved, unusable by anyone else.
- External fragmentation — as requests of different sizes come and go, free memory breaks into chunks too small to fit the next contiguous request, even though the total free memory is plenty.
Measurements in the vLLM paper found that these two effects wasted 60–80% of KV memory in earlier systems. That translates directly into 60–80% fewer concurrent users than the hardware could support.
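Both diseases can be reproduced with a toy model. The sketch below is illustrative only: the request lengths and the 12-slot memory map are made up, and the helper names are hypothetical.

```python
MAX_LEN = 2048  # slots reserved per request, regardless of actual length

def internal_waste(actual_lengths):
    """Fraction of reserved slots that never hold a token."""
    reserved = MAX_LEN * len(actual_lengths)
    return 1 - sum(actual_lengths) / reserved

def largest_free_run(slots):
    """Longest run of consecutive free (False) slots in a toy memory map."""
    best = run = 0
    for used in slots:
        run = 0 if used else run + 1
        best = max(best, run)
    return best

# Internal fragmentation: four hypothetical requests, each reserved 2048 slots.
print(round(internal_waste([30, 400, 900, 1500]), 3))  # 0.655

# External fragmentation: 12 slots, every third one still in use.
memory = [True, False, False] * 4
print(memory.count(False), largest_free_run(memory))   # 8 2
```

In the second case 8 slots are free in total, yet a request that needs 4 contiguous slots cannot be placed because the longest free run is 2.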
Step 3: The fix — pages (blocks)
PagedAttention says: stop reserving contiguous slabs. Instead:
- Carve all KV memory into many small, equal blocks. A block holds the KV of block_size tokens (commonly 16).
- Maintain a global pool of free blocks.
- Give each request blocks on demand, one at a time, as it generates — and the blocks can be anywhere in physical memory.
- Keep a per-request block table: a little array mapping the request's logical block index (0, 1, 2, …) to the physical block id it actually got.
Physical KV memory (one big array of blocks, ids 0..N):
┌────┬────┬────┬────┬────┬────┬────┬────┬────┬────┐
│ b0 │ b1 │ b2 │ b3 │ b4 │ b5 │ b6 │ b7 │ b8 │ b9 │ ...
└────┴────┴────┴────┴────┴────┴────┴────┴────┴────┘
Request A's block table: [ 4, 1, 7 ] (logical 0→phys 4, 1→1, 2→7)
Request B's block table: [ 2, 9 ]
A's tokens live in blocks 4,1,7 — NOT contiguous, and that's totally fine.
Only A's *last* block may be partly empty (≤ block_size−1 wasted). No giant reservations.
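Now waste is at most block_size − 1 tokens per request (the tail of the last block): at most 15 tokens with a block size of 16, not thousands of reserved-but-idle slots.
External fragmentation is gone, because any free block fits any request, and internal fragmentation is bounded by that one partial block.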
The mental shift: a request's KV no longer needs to be contiguous in memory; it only needs to be contiguous in the block table. The attention kernel is handed the block table and gathers KV from the scattered physical blocks. That's the "Paged" in PagedAttention.
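The block-table bookkeeping can be sketched as a toy allocator. This is not the vLLM code: ToyPagedCache and its methods are hypothetical names, blocks are plain integers, and no KV tensors are stored. It only shows on-demand allocation and the logical-to-physical translation that a paged attention kernel relies on.

```python
BLOCK_SIZE = 16

class ToyPagedCache:
    """Toy model: a free-block pool plus one block table per request."""

    def __init__(self, num_blocks):
        self.free_blocks = list(range(num_blocks))  # physical block ids
        self.block_tables = {}                      # request id -> [physical ids]
        self.num_tokens = {}                        # request id -> tokens stored

    def append_token(self, req_id):
        """Reserve a slot for one new token; returns (physical_block, offset)."""
        table = self.block_tables.setdefault(req_id, [])
        n = self.num_tokens.get(req_id, 0)
        if n % BLOCK_SIZE == 0:                     # last block is full, or none yet
            if not self.free_blocks:
                raise MemoryError("no free KV blocks")
            table.append(self.free_blocks.pop())    # any free block will do
        self.num_tokens[req_id] = n + 1
        return table[n // BLOCK_SIZE], n % BLOCK_SIZE

    def locate(self, req_id, token_idx):
        """Translate a logical token position to (physical_block, offset)."""
        table = self.block_tables[req_id]
        return table[token_idx // BLOCK_SIZE], token_idx % BLOCK_SIZE

cache = ToyPagedCache(num_blocks=10)
for _ in range(40):
    cache.append_token("A")
for _ in range(20):
    cache.append_token("B")
for _ in range(10):                  # A grows again after B took blocks
    cache.append_token("A")

print(cache.block_tables["A"])       # [9, 8, 7, 4]
print(cache.block_tables["B"])       # [6, 5]
print(cache.locate("A", 49))         # (4, 1)
```

Request A holds 50 tokens in 4 blocks, so its waste is 4 × 16 − 50 = 14 slots, within the block_size − 1 bound, even though its blocks are not adjacent.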
Step 4: The bonus that falls out — sharing
Once KV is in blocks tracked by tables, two requests can point their block tables at the
same physical block. If two requests start with the same prompt (a shared system prompt, or
n=4 samples of one prompt), they can share the physical KV blocks of that prefix — compute
it once, store it once.
System prompt blocks (computed once): b5 b6
Request A table: [ b5, b6, b1 ] ─┐
Request B table: [ b5, b6, b8 ] ─┴─ both point at b5,b6 (shared!), diverge after.
This is prefix caching (the star of Phase 03). To make sharing safe we need two more concepts, both straight from operating systems:
- Reference counting: each block knows how many requests use it (ref_cnt). A block is truly free only when ref_cnt == 0.
- Copy-on-write: if a shared block must change for just one request, copy it first so the other sharer's view is untouched.
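The two concepts above can be sketched together in a few lines. This is a toy illustration of the ideas, not how vLLM implements them: ToyBlock, ToySharingPool, and the writable() helper are hypothetical names, and a Python list stands in for the block's KV contents.

```python
class ToyBlock:
    def __init__(self, block_id):
        self.block_id = block_id
        self.ref_cnt = 0
        self.tokens = []              # stands in for the block's KV contents

class ToySharingPool:
    def __init__(self, num_blocks):
        self.blocks = [ToyBlock(i) for i in range(num_blocks)]
        self.free = list(range(num_blocks))

    def allocate(self):
        block = self.blocks[self.free.pop(0)]
        block.ref_cnt = 1
        return block

    def share(self, block):
        block.ref_cnt += 1            # one more request points at this block
        return block

    def release(self, block):
        block.ref_cnt -= 1
        if block.ref_cnt == 0:        # truly free only at zero
            block.tokens = []
            self.free.append(block.block_id)

    def writable(self, block):
        """Copy-on-write: return a block the caller may safely modify."""
        if block.ref_cnt == 1:
            return block              # sole owner, write in place
        copy = self.allocate()
        copy.tokens = list(block.tokens)
        self.release(block)           # caller gives up its share of the original
        return copy

pool = ToySharingPool(4)
a_view = pool.allocate()              # physical block 0
a_view.tokens = ["sys", "prompt"]
b_view = pool.share(a_view)           # same block, ref_cnt == 2

b_view = pool.writable(b_view)        # B wants to change it: gets a private copy
b_view.tokens.append("b-only")

print(a_view.ref_cnt, a_view.tokens)    # 1 ['sys', 'prompt']
print(b_view.block_id, b_view.tokens)   # 1 ['sys', 'prompt', 'b-only']
```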
Step 5: The data structures you're about to meet
The real vLLM (and your mini_vllm) implement paging with exactly four pieces:
| Piece | Job | Real code | Your code |
|---|---|---|---|
| KVCacheBlock | metadata for one physical block (id, ref_cnt, hash) | kv_cache_utils.py:116 | mini_vllm/block_pool.py |
| FreeKVCacheBlockQueue | the free list, in eviction order, O(1) middle-removal | kv_cache_utils.py:164 | mini_vllm/block_pool.py |
| BlockPool | owns all blocks + the free list + the prefix-cache index | block_pool.py:130 | mini_vllm/block_pool.py |
| KVCacheManager | per-request block tables; the API the scheduler calls | kv_cache_manager.py:110 | mini_vllm/kv_cache.py |
A surprising detail you'll appreciate: the free list is a hand-rolled doubly linked list,
not a Python deque. Why? Because on a prefix-cache hit we must yank a specific block out of
the middle of the free list in O(1). A deque can't do that: removing from its middle is O(n). The real code has a 30-line
docstring justifying this exact decision (kv_cache_utils.py:164). Reading that docstring and
understanding why is a rite of passage — and a great interview answer.
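To see why a linked list makes middle-removal cheap, here is a minimal sketch of such a free queue. It is an assumption-laden toy, not the upstream class: Node and ToyFreeQueue are hypothetical names, and the dict from block id to node is a shortcut chosen for this example.

```python
class Node:
    def __init__(self, block_id):
        self.block_id = block_id
        self.prev = None
        self.next = None

class ToyFreeQueue:
    """Doubly linked list with sentinel head and tail; every operation is O(1)."""

    def __init__(self, block_ids):
        self.head, self.tail = Node(-1), Node(-1)   # sentinels, never handed out
        self.head.next, self.tail.prev = self.tail, self.head
        self.nodes = {}                             # block id -> node
        for block_id in block_ids:
            self.append(block_id)

    def append(self, block_id):
        """Put a block at the back (last to be evicted)."""
        node = self.nodes.setdefault(block_id, Node(block_id))
        last = self.tail.prev
        last.next, node.prev = node, last
        node.next, self.tail.prev = self.tail, node

    def remove(self, block_id):
        """Unlink a specific block, wherever it sits in the queue."""
        node = self.nodes[block_id]
        node.prev.next, node.next.prev = node.next, node.prev
        node.prev = node.next = None

    def popleft(self):
        """Take the block at the front (first eviction candidate)."""
        first = self.head.next
        if first is self.tail:
            raise IndexError("free queue is empty")
        self.remove(first.block_id)
        return first.block_id

    def to_list(self):
        out, node = [], self.head.next
        while node is not self.tail:
            out.append(node.block_id)
            node = node.next
        return out

q = ToyFreeQueue(range(5))
q.remove(2)            # e.g. a prefix-cache hit on block 2
print(q.to_list())     # [0, 1, 3, 4]
print(q.popleft())     # 0
q.append(2)            # block 2 is released again
print(q.to_list())     # [1, 3, 4, 2]
```

remove() touches only the node and its two neighbours, so its cost does not depend on the queue length. Only to_list(), which exists for printing, walks the whole list.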
The four invariants (memorize these)
A maintainer holds these in their head at all times. They're tested in
mini_vllm/test_block_pool.py and asserted throughout the real code:
- I1. A block is in the free queue ⟺ block.ref_cnt == 0 (and it isn't the null block).
- I2. Block tables are append-only: an allocated block_id never changes under a request. (This is why the cache doesn't de-duplicate; see block_pool.py:48.)
- I3. Only a full block (exactly block_size tokens) ever gets a hash and enters the prefix cache.
- I4. "Cached" ≠ "unusable." A block can be a free eviction candidate (in the free queue) while still being a prefix-cache hit target. touch() revives it.
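I1 and I4 are easier to remember after watching them hold in a toy pool. The sketch below is hypothetical throughout: ToyPool, its touch() and release() signatures, and the string hash are stand-ins chosen for this example, there is no null block, and the free list is a plain Python list, so removal is O(n) here.

```python
from dataclasses import dataclass
from typing import Optional

@dataclass
class ToyBlock:
    block_id: int
    ref_cnt: int = 0
    block_hash: Optional[str] = None   # set only once the block is full (I3)

class ToyPool:
    def __init__(self, num_blocks):
        self.blocks = [ToyBlock(i) for i in range(num_blocks)]
        self.free_ids = [b.block_id for b in self.blocks]   # eviction order
        self.cached = {}                                     # block_hash -> block

    def touch(self, block):
        """Take a reference; pull the block off the free list if it was idle."""
        if block.ref_cnt == 0:
            self.free_ids.remove(block.block_id)
        block.ref_cnt += 1

    def release(self, block):
        block.ref_cnt -= 1
        if block.ref_cnt == 0:
            self.free_ids.append(block.block_id)   # still in self.cached (I4)

    def check_i1(self):
        for b in self.blocks:
            assert (b.block_id in self.free_ids) == (b.ref_cnt == 0)

pool = ToyPool(3)
blk = pool.blocks[1]
pool.touch(blk)                    # the demo allocates by touching
blk.block_hash = "h(prefix)"       # pretend the block filled up
pool.cached[blk.block_hash] = blk
pool.release(blk)                  # idle and evictable, but still cached
pool.check_i1()

hit = pool.cached["h(prefix)"]     # prefix-cache hit on an idle block
pool.touch(hit)                    # revived from the free list
pool.check_i1()
print(pool.free_ids, hit.ref_cnt)  # [0, 2] 1
```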
What you'll do in this phase
- Read the real allocator: 01-deep-dive.md walks block_pool.py and kv_cache_utils.py line by line.
- Build your own: 02-mini-build.md (you've got mini_vllm/block_pool.py as the reference; the lab has you write it from a stub).
- Labs (see labs/README.md; recommended order 01 → 02 → 05 → 06 → 03 → 04):
  - lab-01-block-allocator [CPU-OK]: implement the paged allocator + free queue, pass the tests.
  - lab-02-fragmentation-viz [CPU-OK]: simulate contiguous vs paged allocation; measure the waste.
  - lab-03-real-vllm-blocks [GPU-OPT]: run real vLLM, read num_gpu_blocks and KV usage, prove no fragmentation.
  - lab-04-triton-paged-attn [GPU-REQ]: port a block-table-indexed attention to a Triton kernel.
  - lab-05-share-and-evict [CPU-OK]: the life of a cached block, covering sharing (ref_cnt==2), eviction order (tails before shared prefixes), and revival from the middle of the free queue.
  - lab-06-paged-attention-numpy [CPU-OK]: the kernel's data path in pure numpy (slot_mapping, scatter, gather-through-the-table) and proof that paged == dense to 1e-12. (The CPU twin of lab-04.)
- Test yourself: EXERCISES.md, INTERVIEW.md, CHEATSHEET.md.
When you can whiteboard the block table + free queue from memory and explain copy-on-write and the four invariants, you understand the single most important idea in vLLM. Onward.