Lab 02-01 — Build the Paged Block Allocator [CPU-OK]
This is the lab of the phase, and arguably of the course. You are going to implement, from a skeleton, the data structure that made vLLM famous: the paged KV-cache block allocator — the free queue, the block pool, the reference counts, and the prefix-cache index. When the tests go green, the thing that serves trillions of tokens a day in production deployments around the world will exist, in miniature, written by your hands.
Contents
- Why this lab exists
- Background: what problem this structure solves
- The cast of characters
- Files
- How to run
- What to implement (in starter.py)
- The invariants you're proving
- The one data-structure decision to savor
- What the tests prove
- Hitchhiker's notes
- Success, and what to do with it
- References
Why this lab exists
Here is the surprise at the heart of vLLM: its breakthrough wasn't a kernel, a model trick, or CUDA wizardry. It was an operating-systems idea from 1962, paged virtual memory, applied to the KV cache. The PagedAttention paper's headline numbers (2–4× throughput over the prior state of the art) come almost entirely from the metadata structure you're about to build: a few hundred lines of bookkeeping that decide which 16-token "page" of GPU memory belongs to whom.
That's also why this lab is CPU-only with zero loss of fidelity. The GPU tensors are, as the
module docstring puts it, "just an array indexed by block_id." The hard part — the part
maintainers actually edit, review, and break — is the metadata: ref counts, free lists,
eviction, the prefix-cache index. You'll write all of it. And because mini_vllm's version
is a faithful-but-small port of the real one (same class names, same invariants, line
references throughout), finishing this lab means you can open
upstream/vllm/v1/core/block_pool.py and read it like something you wrote.
Background: what problem this structure solves
Every token a transformer processes leaves a residue: its attention keys and values, needed by every future token of the same sequence. For a 7B model that's ~0.5 MB per token. The pre-vLLM engines stored each request's KV in one contiguous tensor sized for the maximum possible length — and since you can't know in advance how long a generation will run, they reserved worst case and used average case. Result: 60–80% of "used" KV memory held nothing (measured in the PagedAttention paper, §2; you'll reproduce the number yourself in lab-02).
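The ~0.5 MB figure can be checked with a few lines of arithmetic. The sketch below assumes a Llama-7B-like shape (32 layers, hidden size 4096, 16-bit values); those numbers are illustrative assumptions, not read from any config in this repo. The request lengths in the second half are likewise made up, to show how reserving for the worst case lands in the waste range quoted above.

```python
# Assumed model shape (Llama-7B-like); adjust for the model you care about.
num_layers = 32
hidden_size = 4096       # num_heads * head_dim
bytes_per_value = 2      # fp16 / bf16

# Every layer keeps one key vector and one value vector per token.
kv_bytes_per_token = 2 * num_layers * hidden_size * bytes_per_value
kv_mib_per_token = kv_bytes_per_token / 2**20

# Hypothetical request: space reserved for 2048 tokens, only 500 generated.
reserved_tokens = 2048
used_tokens = 500
wasted_fraction = 1 - used_tokens / reserved_tokens

print(f"{kv_bytes_per_token} bytes/token = {kv_mib_per_token} MiB/token")
print(f"wasted fraction of the reservation: {wasted_fraction:.2f}")
```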
The fix is the OS playbook, almost verbatim:
| OS virtual memory | vLLM | In this lab |
|---|---|---|
| physical page frame | KV block (block_size tokens of K/V) | KVCacheBlock |
| free frame list | free queue | FreeKVCacheBlockQueue |
| page table (per process) | block table (per request) | Phase 2's KVCacheManager (next file over) |
| shared pages + refcounts | prefix sharing + ref_cnt | touch / free_blocks |
| page cache | prefix-cache index | cached_block_hash_to_block |
A request takes blocks one at a time, from anywhere, as it grows. Nothing is reserved.
External fragmentation: impossible (all blocks the same size). Internal fragmentation: at
most block_size − 1 tokens, in the last block only. Sharing: free, via refcounts. That's
the whole revolution.
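The internal-fragmentation bound is easy to verify numerically. This is a minimal sketch; `blocks_needed` and `wasted_slots` are hypothetical helper names for this illustration, not functions from starter.py, and the block size of 16 is the page size mentioned earlier.

```python
def blocks_needed(num_tokens: int, block_size: int) -> int:
    """Smallest number of fixed-size blocks that can hold num_tokens."""
    return -(-num_tokens // block_size)  # integer ceiling division


def wasted_slots(num_tokens: int, block_size: int) -> int:
    """Unused token slots, all of which sit in the last block."""
    return blocks_needed(num_tokens, block_size) * block_size - num_tokens


block_size = 16
for n in (1, 16, 17, 500):
    print(n, blocks_needed(n, block_size), wasted_slots(n, block_size))

# The waste never reaches a full block, whatever the sequence length.
worst_waste = max(wasted_slots(n, block_size) for n in range(1, 1000))
```

With these definitions `worst_waste` comes out as `block_size - 1`, which is the bound stated above.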
The cast of characters
You implement three things, mirroring (with line references) the real engine:
- KVCacheBlock: one block's metadata. It holds block_id (its fixed address in the GPU tensor), ref_cnt (how many requests use it), block_hash (set only when full and cached), and two linked-list pointers it does not manage itself. (upstream: kv_cache_utils.py:116)
- FreeKVCacheBlockQueue: a doubly linked list with head/tail sentinels holding every ref_cnt == 0 block in eviction order. Supports popleft (allocate), append (free), and the crucial remove(block), an O(1) extraction from the middle. (upstream: kv_cache_utils.py:164, where the docstring explains exactly why a deque can't do this job)
- BlockPool: the owner. It provides get_new_blocks (allocate + maybe-evict), touch (adopt a cached block, reviving it from the free queue if needed), free_blocks (decref, return to queue at ref_cnt == 0), cache_full_blocks / get_cached_block (the prefix-cache index), plus hash_block_tokens, the parent-chained content hash. (upstream: block_pool.py:130)
Files
- starter.py: the skeleton. Method bodies raise NotImplementedError. Fill them in.
- solution.py: a complete reference. Don't open it until you're green or truly stuck; this lab's struggle is its value.
- test_lab.py: every invariant from the deep-dive §1–3, executable.
How to run
# Grade YOUR implementation:
LAB_IMPL=starter pytest phase-02-paged-attention/labs/lab-01-block-allocator -q
# The reference (default — keeps the suite green out of the box):
pytest phase-02-paged-attention/labs/lab-01-block-allocator -q
What to implement (in starter.py)
Recommended order — each layer is testable before the next:
1. FreeKVCacheBlockQueue: popleft, remove, append, get_all_free_blocks. The sentinels (_head, _tail) are pre-wired so you never branch on "am I first/last?"; notice how much conditional logic two dummy nodes delete. Keep num_free_blocks exact; the pool's OOM answer depends on it.
2. hash_block_tokens: hash((parent_hash, tokens_tuple)). One line, but read the docstring until you can say why the parent is in there (see Hitchhiker's notes).
3. BlockPool: get_new_blocks (pop, _maybe_evict, assert ref_cnt == 0, set to 1), _maybe_evict (drop the hash↔block mapping if this block was a cached eviction candidate), touch, free_blocks, cache_full_blocks, get_cached_block, get_num_free_blocks. Mind block 0: it's reserved as the null block at construction, exactly like upstream.
The invariants you're proving
These four lines are the closest thing the KV subsystem has to a constitution. Real scheduler bugs — upstream, in production — are violations of one of these:
- I1. A block is in the free queue ⟺ ref_cnt == 0 (and it's not the null block). Both directions. A block in the queue with refs is a use-after-free wearing a disguise: someone will allocate it and overwrite KV another request is still reading, which means silent corruption, tokens from someone else's conversation.
- I2. Block ids are stable: once given to a request, a block is never renumbered or deduplicated out from under it. The GPU kernel reads physical addresses computed from block_id; metadata cleverness must never move data.
- I3. Only full blocks get hashed and cached. A partial block's contents are still changing; caching it would serve half-written KV to a prefix match.
- I4. Cached ≠ unusable. A cached block with ref_cnt == 0 sits in the free queue as an eviction candidate: it can be reclaimed (evicted) by get_new_blocks or revived (re-referenced) by touch. This dual citizenship is the whole trick of zero-cost prefix caching: the cache rides for free in memory that's already free.
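I1 is mechanical enough to be checked by a short function. The sketch below is one way to phrase it, assuming a simplified model in which the free queue is just a set of block ids; `ToyBlock` and `violations_of_i1` are hypothetical names for this illustration and are not part of starter.py or the test suite.

```python
from dataclasses import dataclass


@dataclass
class ToyBlock:
    block_id: int
    ref_cnt: int = 0


def violations_of_i1(blocks, free_ids, null_block_id=0):
    """Return the ids of blocks whose queue membership disagrees with ref_cnt."""
    bad = []
    for block in blocks:
        in_queue = block.block_id in free_ids
        if block.block_id == null_block_id:
            should_be_in_queue = False      # the null block is never free
        else:
            should_be_in_queue = block.ref_cnt == 0
        if in_queue != should_be_in_queue:
            bad.append(block.block_id)
    return bad


# Block 0 is the null block, block 1 is in use, blocks 2 and 3 are free.
blocks = [ToyBlock(0), ToyBlock(1, ref_cnt=1), ToyBlock(2), ToyBlock(3)]
free_ids = {2, 3}
healthy = violations_of_i1(blocks, free_ids)

# Simulate the bug: block 3 gains a reference but is left in the queue.
blocks[3].ref_cnt = 1
broken = violations_of_i1(blocks, free_ids)
```

Here `healthy` is empty and `broken` flags block 3, the "use-after-free wearing a disguise" case: the next allocation would hand it out while another request still reads it.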
The one data-structure decision to savor
Why is the free "queue" a hand-rolled doubly linked list instead of
collections.deque? Because of I4. When a prefix-cache hit revives a block, that block is
sitting somewhere in the middle of the free queue, and it must leave now, in O(1) —
not via an O(n) scan of a deque. The eviction end (popleft) and the return end (append)
are deque-friendly; it's the revival path that forces real pointers. The upstream class
exists for precisely this reason and says so in its docstring.
Generalize the lesson: the access pattern dictates the structure. "Queue with O(1) middle removal" doesn't have a stdlib name, so vLLM built one. When you find a hand-rolled structure in a mature codebase, your first question should be "which operation forced this?" — the answer is usually a design document in disguise.
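To make "queue with O(1) middle removal" concrete, here is a generic toy version of the idea. It is a sketch, not the starter or upstream code: `ToyNode` and `ToyFreeList` are hypothetical names, and the real class tracks more state. The point to notice is that `remove` touches only the node's two neighbours, and the sentinels mean no method ever asks "am I first or last?".

```python
class ToyNode:
    def __init__(self, block_id):
        self.block_id = block_id
        self.prev = None
        self.next = None


class ToyFreeList:
    def __init__(self):
        # Two dummy nodes bracket the list so real nodes always have neighbours.
        self.head = ToyNode(-1)
        self.tail = ToyNode(-1)
        self.head.next = self.tail
        self.tail.prev = self.head
        self.size = 0

    def append(self, node):
        last = self.tail.prev
        last.next = node
        node.prev = last
        node.next = self.tail
        self.tail.prev = node
        self.size += 1

    def remove(self, node):
        if node.prev is None or node.next is None:
            raise ValueError("node is not in the list")
        # O(1): splice the neighbours together, no scan required.
        node.prev.next = node.next
        node.next.prev = node.prev
        node.prev = node.next = None
        self.size -= 1

    def popleft(self):
        if self.size == 0:
            raise ValueError("no free blocks")
        node = self.head.next
        self.remove(node)
        return node

    def ids(self):
        out, cur = [], self.head.next
        while cur is not self.tail:
            out.append(cur.block_id)
            cur = cur.next
        return out


nodes = [ToyNode(i) for i in range(5)]
free = ToyFreeList()
for n in nodes:
    free.append(n)

free.remove(nodes[2])        # revival: leaves from the middle
after_remove = free.ids()
first = free.popleft()       # allocation: takes the front
free.append(nodes[2])        # release: returns at the back
final_order = free.ids()
```

A `collections.deque` handles the `popleft` and `append` calls just as well; it is only `remove(nodes[2])` that would cost a linear scan there.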
What the tests prove
| Test group | Invariant |
|---|---|
| free-queue mechanics | popleft/append/remove keep order and counts exact; sentinels never leak |
| allocate/free round-trips | I1 in both directions |
| no-dedup on identical content | I2 — two requests writing the same tokens get different blocks |
| partial blocks never cached | I3 |
| revive-from-middle via touch | I4 + the O(1) removal that motivates the linked list |
| eviction drops the cache entry | _maybe_evict keeps the index consistent with reality |
Hitchhiker's notes
- The chained hash is the prefix property. hash(block) = hash(parent_hash, tokens) means a block matches only if its entire ancestry matches. Without the chain, the block containing tokens [c, d] would collide between "ab|cd" and "xy|cd", and a request would inherit KV computed under a different prefix. Attention is causal: KV at position i encodes everything before i. The chain is causality, hashed. (Upstream goes further and also folds in extras like LoRA id and multimodal hashes: same idea, more ancestry. And since v0.9, the hash uses SHA-256 by default rather than Python's hash, because across a fleet, a 64-bit hash collision means serving someone else's KV: at scale, "unlikely" is a frequency.)
- The null block (id 0) is not a hack. Reserving a permanent placeholder block means "no block here yet" can be represented inside the block-table tensor without sentinels like −1 leaking into kernels. Upstream does exactly this. Watch that your free_blocks and touch never count it.
- Eviction is lazy and that's the elegance. Nothing proactively cleans the cache. A cached-but-free block just sits in the queue; if demand arrives first, get_new_blocks evicts it in passing (_maybe_evict); if a prefix hit arrives first, touch revives it. The cache is exactly as big as whatever memory happens to be idle: no knob to tune, no background thread to race with.
- Order of the free queue = eviction policy. popleft takes the front, so whatever ordering append maintains is your eviction policy. Append-on-free gives LRU-ish. Phase 3's KVCacheManager.free exploits this by returning a request's blocks in reverse order, so deep suffix blocks die before shared prefix blocks. Policy, encoded as list order, with no priority queue in sight. (Upstream v0.22 keeps maybe_evict and the queue discipline in BlockPool; older versions had a pluggable Evictor class. The simplification is itself an instructive PR to read.)
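The "ab|cd" versus "xy|cd" collision from the first note can be reproduced in a few lines. This is a sketch under stated assumptions: it hashes a `repr` of the inputs with SHA-256 purely to get deterministic output, which is a choice made for this illustration and not how the lab or upstream serializes blocks. All function names are hypothetical.

```python
import hashlib


def toy_chained_hash(parent_hash, tokens):
    """Hash a block's tokens together with its parent's hash."""
    payload = repr((parent_hash, tuple(tokens))).encode("utf-8")
    return hashlib.sha256(payload).hexdigest()


def toy_unchained_hash(tokens):
    """Hash a block's tokens alone, ignoring everything before it."""
    return hashlib.sha256(repr(tuple(tokens)).encode("utf-8")).hexdigest()


def block_hashes(blocks, chained=True):
    hashes, parent = [], None
    for tokens in blocks:
        if chained:
            h = toy_chained_hash(parent, tokens)
        else:
            h = toy_unchained_hash(tokens)
        hashes.append(h)
        parent = h
    return hashes


seq_ab = [["a", "b"], ["c", "d"]]
seq_xy = [["x", "y"], ["c", "d"]]
seq_ab_longer = [["a", "b"], ["c", "d"], ["e", "f"]]

# Without the chain, the [c, d] blocks are indistinguishable.
unchained_collides = (
    block_hashes(seq_ab, chained=False)[1] == block_hashes(seq_xy, chained=False)[1]
)
# With the chain, a different first block changes the second block's hash.
chained_collides = block_hashes(seq_ab)[1] == block_hashes(seq_xy)[1]
# A genuinely shared prefix still matches block for block.
shared_prefix_matches = block_hashes(seq_ab_longer)[:2] == block_hashes(seq_ab)
```

Only the unchained variant collides, and a genuinely shared prefix still produces identical hashes, which is exactly what a prefix-cache lookup needs.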
Success, and what to do with it
LAB_IMPL=starter pytest phase-02-paged-attention/labs/lab-01-block-allocator -q
........ [100%]
Then do the two diffs that cement the knowledge:
1. diff your starter.py against solution.py: note every place you did it differently and decide which you prefer (sometimes yours is better; say why).
2. Open upstream/vllm/v1/core/block_pool.py next to your file and read get_new_blocks, touch, free_blocks for real. List what production adds (multi-group KV for hybrid models, eviction events for observability, BlockHash types) and notice that nothing structural differs. You now read this file as its author.
References
- mini_vllm/block_pool.py: the faithful port you're rebuilding, with upstream line refs.
- upstream/vllm/v1/core/block_pool.py:130 is BlockPool in production.
- upstream/vllm/v1/core/kv_cache_utils.py:116,164,541 hold KVCacheBlock, FreeKVCacheBlockQueue (read its docstring!), and hash_block_tokens.
- Kwon et al., Efficient Memory Management for Large Language Model Serving with PagedAttention (SOSP 2023). The paper; §4 is this lab: https://arxiv.org/abs/2309.06180
- vLLM blog, vLLM: Easy, Fast, and Cheap LLM Serving with PagedAttention (June 2023). The original announcement, with the fragmentation figures: https://blog.vllm.ai/2023/06/20/vllm.html
- vLLM docs, Automatic Prefix Caching (design). The hash-chain design you implemented: https://docs.vllm.ai/en/latest/design/prefix_caching.html
- Denning, Virtual Memory (ACM Computing Surveys, 1970). The 50-year-old playbook vLLM ran: https://dl.acm.org/doi/10.1145/356571.356573