Phase 02 — Deep Dive: PagedAttention in the real vLLM

All paths are relative to upstream/ at the pinned commit v0.22.1 @ 0decac0 (UPSTREAM_PIN.md). Open each file as we go. Line numbers are valid at the pin; the named symbol lets you re-find anything if you're on a different version.

The V1 KV-cache stack lives in vllm/v1/core/:

vllm/v1/core/
  kv_cache_utils.py        KVCacheBlock, FreeKVCacheBlockQueue, hashing  (the primitives)
  block_pool.py            BlockPool                                     (the allocator)
  kv_cache_manager.py      KVCacheManager, KVCacheBlocks                 (per-request tables)
  kv_cache_coordinator.py  coordinates groups (hybrid models)           (one level up)
  single_type_kv_cache_manager.py                                       (per-group logic)

We'll go bottom-up: the block, the free list, the pool, then the manager the scheduler calls.



1. KVCacheBlock — metadata for one physical block

vllm/v1/core/kv_cache_utils.py:116:

@dataclass
class KVCacheBlock:
    """KV-cache block metadata."""
    block_id: int
    ref_cnt: int = 0
    _block_hash: BlockHashWithGroupId | None = None
    # Used to construct a doubly linked list for free blocks.
    prev_free_block: "KVCacheBlock | None" = None
    next_free_block: "KVCacheBlock | None" = None
    is_null: bool = False

Crucial things to notice:

  • A KVCacheBlock is metadata only. The actual K/V tensors live in a big GPU buffer; this object just says "block #block_id, used by ref_cnt requests, hashing to _block_hash." Your mini_vllm.block_pool.KVCacheBlock is the same shape minus the GPU tensors.
  • ref_cnt is the heart of sharing (I1). The block_hash setter (line 139) asserts the block has no hash yet — enforcing I3/I2: a block's hash is set once when it fills, and the block id is stable.
  • prev_free_block/next_free_block are the linked-list pointers. The comment (line 128) warns: "These two attributes should only be manipulated by FreeKVCacheBlockQueue." That's an invariant about ownership — exactly the kind of thing a maintainer must respect.

reset_hash() (line 146) clears the hash on eviction. We'll see it called from _maybe_evict_cached_block.
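
The set-once rule described above can be sketched in a few lines. This is a minimal illustration, not the upstream class: SketchBlock is a hypothetical name, and a plain bytes value stands in for BlockHashWithGroupId.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class SketchBlock:
    """Hypothetical stand-in for KVCacheBlock: metadata only, no tensors."""
    block_id: int
    ref_cnt: int = 0
    _block_hash: Optional[bytes] = None  # assumption: bytes instead of BlockHashWithGroupId

    @property
    def block_hash(self) -> Optional[bytes]:
        return self._block_hash

    @block_hash.setter
    def block_hash(self, value: bytes) -> None:
        # Set-once: a block that already carries a hash must be reset first.
        assert self._block_hash is None, "block already has a hash"
        self._block_hash = value

    def reset_hash(self) -> None:
        # Called on eviction: the block no longer holds cacheable content.
        self._block_hash = None


blk = SketchBlock(block_id=7)
blk.block_hash = b"h1"
try:
    blk.block_hash = b"h2"          # second set without a reset
    second_set_rejected = False
except AssertionError:
    second_set_rejected = True
blk.reset_hash()                    # eviction clears the hash...
blk.block_hash = b"h2"              # ...after which the block can be hashed again
```

The block id never changes across this cycle; only the hash is cleared and reassigned.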


2. FreeKVCacheBlockQueue — the free list, and why it's hand-rolled

vllm/v1/core/kv_cache_utils.py:164. Read its docstring in full — it's a masterclass. The key sentences:

"We implement this class instead of using Python builtin deque to support removing a block in the middle of the queue in O(1) time. … this class does not allocate any Python objects when manipulating the linked list."

Two design decisions, both about performance on the hot path (this runs for every allocation and free, every step):

  1. O(1) middle removal. On a prefix-cache hit, a block that was a free eviction candidate gets revived — pulled out of wherever it sits in the free list. A deque only does O(1) at the ends; the middle is O(n). So they wrote a doubly linked list.
  2. Zero allocation. They reuse the prev/next fields on the blocks themselves rather than allocating node wrappers. No GC pressure in the scheduler loop.

The eviction order is the other half (docstring lines 173–180):

"1. The least recently used block is at the front (LRU). 2. If two blocks have the same last accessed time … the one with more hash tokens (the tail of a block chain) is at the front."

So popleft() evicts LRU-first, and within a freed request, tail blocks go first (we'll see KVCacheManager.free frees in reverse so the longest shared prefix survives longest).

The sentinel trick (lines 196–214): a fake head and tail node so push/pop never special-case "is this the first/last?". Read popleft (216), popleft_n (253), remove (286), append (306), and append_n (329). Your mini_vllm.block_pool.FreeKVCacheBlockQueue implements the same operations with the same sentinel trick; compare them side by side.
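
A rough sketch of the two ideas, sentinels and O(1) middle removal, might look like the following. Node and FreeQueueSketch are hypothetical names, and the batch operations (popleft_n, append_n) are left out to keep it short.

```python
from typing import Optional


class Node:
    """Hypothetical block stand-in that carries its own list pointers."""
    def __init__(self, block_id: int) -> None:
        self.block_id = block_id
        self.prev: Optional["Node"] = None
        self.next: Optional["Node"] = None


class FreeQueueSketch:
    def __init__(self, blocks: list[Node]) -> None:
        # Sentinels: real blocks always sit between head and tail,
        # so no operation has to ask "is this the first/last node?".
        self.head, self.tail = Node(-1), Node(-1)
        self.head.next, self.tail.prev = self.tail, self.head
        self.num_free = 0
        for b in blocks:
            self.append(b)

    def append(self, b: Node) -> None:
        last = self.tail.prev
        last.next, b.prev = b, last
        b.next, self.tail.prev = self.tail, b
        self.num_free += 1

    def remove(self, b: Node) -> None:
        # O(1) from anywhere in the list: relink the two neighbours.
        b.prev.next, b.next.prev = b.next, b.prev
        b.prev = b.next = None
        self.num_free -= 1

    def popleft(self) -> Node:
        if self.head.next is self.tail:
            raise ValueError("no free blocks")
        b = self.head.next
        self.remove(b)
        return b


blocks = [Node(i) for i in range(5)]
queue = FreeQueueSketch(blocks)
queue.remove(blocks[2])                     # pull a block out of the middle
order = [queue.popleft().block_id for _ in range(queue.num_free)]
```

The list never allocates wrapper objects: the pointers live on the blocks, which is the zero-allocation property the docstring calls out.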

Interview gold: "Why does vLLM use a custom linked list instead of collections.deque for free blocks?" → O(1) removal from the middle for prefix-cache revival, and zero per-operation allocation on the scheduler hot path. If you can also say where the middle removal happens (touch), you're answering at staff level.


3. BlockPool — owns every block, the free list, and the cache index

vllm/v1/core/block_pool.py:130. The constructor (__init__, line 149):

self.blocks: list[KVCacheBlock] = [KVCacheBlock(idx) for idx in range(num_gpu_blocks)]
self.free_block_queue = FreeKVCacheBlockQueue(self.blocks)
self.cached_block_hash_to_block: BlockHashToBlockMap = BlockHashToBlockMap()
# To represent a placeholder block with block_id=0.
self.null_block = self.free_block_queue.popleft()
self.null_block.is_null = True

  • One KVCacheBlock per physical block, all initially free.
  • A null block (id 0) is reserved as a placeholder (used for skipped positions, e.g. outside a sliding window). mini_vllm reserves block 0 the same way (BlockPool.__init__).
  • cached_block_hash_to_block is the prefix-cache index: block_hash → block. (Upstream uses a BlockHashToBlockMap that can hold multiple blocks per hash; mini_vllm simplifies to one block per hash — read the BlockHashToBlockMap docstring at line 34 to see why the real one is more complex: it must keep block ids stable, I2, so it doesn't dedup.)

Allocation: get_new_blocks (line 333)

def get_new_blocks(self, num_blocks: int) -> list[KVCacheBlock]:
    if num_blocks > self.get_num_free_blocks():
        raise ValueError(f"Cannot get {num_blocks} free blocks from the pool")
    ret: list[KVCacheBlock] = self.free_block_queue.popleft_n(num_blocks)
    if self.enable_caching:
        for block in ret:
            self._maybe_evict_cached_block(block)   # <- was it a cached eviction candidate?
            assert block.ref_cnt == 0
            block.ref_cnt += 1
    else:
        for block in ret:
            assert block.ref_cnt == 0
            block.ref_cnt += 1
    return ret

Pop n blocks off the front of the free queue (LRU). If caching is on, each popped block might still be sitting in the prefix cache as an eviction candidate (I4), so _maybe_evict_cached_block removes its hash entry before the block is reused. Each block then takes its first reference (ref_cnt goes from 0 to 1). mini_vllm.BlockPool.get_new_blocks mirrors this exactly (including _maybe_evict).

Eviction: _maybe_evict_cached_block (line 365)

block_hash = block.block_hash
if block_hash is None:
    return False            # block was never cached, nothing to evict
if self.cached_block_hash_to_block.pop(block_hash, block.block_id) is None:
    return False
block.reset_hash()          # <- I3: it no longer holds cacheable content

This is the OS analogy made literal: reusing a physical page means invalidating whatever was mapped there. The hash is cleared so no future request thinks this block holds their prefix.

Sharing: touch (line 402) — the O(1) middle removal in action

def touch(self, blocks: Sequence[KVCacheBlock]) -> None:
    for block in blocks:
        # ref_cnt=0 means this block is in the free list (eviction candidate), so remove it.
        if block.ref_cnt == 0 and not block.is_null:
            self.free_block_queue.remove(block)   # <- O(1) middle removal! (the whole reason
        block.ref_cnt += 1                        #     for the custom linked list)

When a new request hits a prefix-cached block that happened to be free, touch revives it: pull it out of the middle of the free list and bump its ref count. This single line is why FreeKVCacheBlockQueue exists. mini_vllm.BlockPool.touch is identical in spirit.

Freeing: free_blocks (line 419)

for block in blocks_list:
    block.ref_cnt -= 1
self.free_block_queue.append_n(
    [block for block in blocks_list if block.ref_cnt == 0 and not block.is_null]
)

Decrement refs; any block that hit 0 goes back on the free queue (and stays in the cache as an eviction candidate — I4). The caller is expected to pass blocks in eviction-priority order (docstring line 419: "first block will be evicted first").
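
The lifecycle above (allocate, cache, free, revive via touch, evict on reuse) can be traced end to end in a toy pool. This is a sketch under simplifying assumptions: PoolSketch, Blk, and cache_block are hypothetical names, an OrderedDict stands in for the custom linked list, hashes are plain strings, and there is no null block.

```python
from collections import OrderedDict
from typing import Optional


class Blk:
    def __init__(self, block_id: int) -> None:
        self.block_id = block_id
        self.ref_cnt = 0
        self.block_hash: Optional[str] = None


class PoolSketch:
    def __init__(self, num_blocks: int) -> None:
        self.blocks = [Blk(i) for i in range(num_blocks)]
        # Assumption: OrderedDict as the free list; front = evicted first.
        self.free = OrderedDict((b.block_id, b) for b in self.blocks)
        self.cache: dict[str, Blk] = {}          # block_hash -> block

    def get_new_blocks(self, n: int) -> list[Blk]:
        if n > len(self.free):
            raise ValueError(f"Cannot get {n} free blocks from the pool")
        out = []
        for _ in range(n):
            _, b = self.free.popitem(last=False)  # pop from the front
            if b.block_hash is not None:          # still a cached eviction candidate
                self.cache.pop(b.block_hash, None)
                b.block_hash = None
            b.ref_cnt += 1
            out.append(b)
        return out

    def cache_block(self, b: Blk, block_hash: str) -> None:
        assert b.block_hash is None
        b.block_hash = block_hash
        self.cache[block_hash] = b

    def touch(self, b: Blk) -> None:
        if b.ref_cnt == 0:
            del self.free[b.block_id]             # revive: leave the free list
        b.ref_cnt += 1

    def free_blocks(self, blocks: list[Blk]) -> None:
        for b in blocks:
            b.ref_cnt -= 1
        for b in blocks:
            if b.ref_cnt == 0:
                self.free[b.block_id] = b         # back of the queue, hash kept


pool = PoolSketch(3)
(a,) = pool.get_new_blocks(1)       # takes block 0
pool.cache_block(a, "h0")
pool.free_blocks([a])               # free order is now 1, 2, 0; "h0" still cached
hit = pool.cache["h0"]
pool.touch(hit)                     # prefix hit revives block 0
revived = (hit.ref_cnt, 0 in pool.free)
pool.free_blocks([hit])
reused = pool.get_new_blocks(3)     # pops 1, 2, then 0, evicting "h0"
```

After the last call the hash entry is gone, so no later request can mistake block 0 for a cached prefix.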

Caching full blocks: cache_full_blocks (line 211)

The big method that registers newly-full blocks into the prefix cache. The important loop (line 267):

for i, blk in enumerate(new_full_blocks):
    if blk.is_null or (block_mask is not None and not block_mask[i]):
        continue
    assert blk.block_hash is None         # I3 again
    block_hash = new_block_hashes[i]
    block_hash_with_group_id = make_block_hash_with_group_id(block_hash, kv_cache_group_id)
    blk.block_hash = block_hash_with_group_id
    self.cached_block_hash_to_block.insert(block_hash_with_group_id, blk)

Only full, non-null, non-masked blocks get a hash and enter the index. The rest of the method (lines 285–331) emits optional KV-cache events (for observability / external KV stores) — skip that on first read.


4. The hash that makes it a prefix cache: hash_block_tokens

vllm/v1/core/kv_cache_utils.py:541:

def hash_block_tokens(hash_function, parent_block_hash, curr_block_token_ids, extra_keys=None):
    if not parent_block_hash:
        parent_block_hash = NONE_HASH
    curr_block_token_ids_tuple = tuple(curr_block_token_ids)
    return BlockHash(
        hash_function((parent_block_hash, curr_block_token_ids_tuple, extra_keys))
    )

The block's hash includes its parent's hash. That chaining is the entire reason this is a prefix cache and not just a block cache: block [c, d] hashes differently depending on what came before it, so a hit on block k implies (barring hash collisions) that blocks 0..k are all identical. extra_keys folds in things that must not collide across contexts (LoRA id, multimodal content, a cache_salt); see generate_block_hash_extra_keys (line 503). Your mini_vllm.block_pool.hash_block_tokens keeps the parent chaining (the essential part) and drops extra_keys; the test test_prefix_hash_is_chained pins the property.
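
To see the chaining property concretely, here is a small sketch using hashlib. The names chained_hash and hash_prompt, the all-zero NONE_HASH value, and the byte encoding of tokens are assumptions for illustration; upstream uses its own hash function and root value.

```python
import hashlib
from typing import Optional, Sequence

NONE_HASH = b"\x00" * 32  # assumption: placeholder root, not upstream's value


def chained_hash(parent: Optional[bytes], token_ids: Sequence[int]) -> bytes:
    """Hash one full block together with its parent's hash."""
    parent = parent or NONE_HASH
    payload = parent + b"|" + ",".join(map(str, token_ids)).encode()
    return hashlib.sha256(payload).digest()


def hash_prompt(token_ids: Sequence[int], block_size: int) -> list[bytes]:
    """Hash every full block of a prompt; a trailing partial block gets no hash."""
    hashes: list[bytes] = []
    parent: Optional[bytes] = None
    full = len(token_ids) - len(token_ids) % block_size
    for start in range(0, full, block_size):
        parent = chained_hash(parent, token_ids[start:start + block_size])
        hashes.append(parent)
    return hashes


a = hash_prompt([1, 2, 3, 4], block_size=2)
b = hash_prompt([9, 9, 3, 4], block_size=2)   # same second block, different prefix
c = hash_prompt([1, 2, 7, 8], block_size=2)   # same first block, different second
```

Here a[1] and b[1] differ even though both cover the tokens [3, 4], while a[0] and c[0] match because the shared prefix is identical.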


5. KVCacheManager — the per-request API the scheduler uses

vllm/v1/core/kv_cache_manager.py:110. This is the only KV class the scheduler talks to; it hides the pool/coordinator behind a clean interface. Two methods matter most.

get_computed_blocks (line 194) — prefix-cache lookup

max_cache_hit_length = request.num_tokens - 1   # must recompute last token to get logits
computed_blocks, num_new_computed_tokens = self.coordinator.find_longest_cache_hit(
    request.block_hashes, max_cache_hit_length
)

Note the num_tokens - 1: even if the entire prompt is cached, the last token must be recomputed to produce logits. mini_vllm.KVCacheManager.get_computed_blocks reproduces this exact max_hit_tokens = num_tokens - 1 rule and walks block hashes from the front, stopping at the first miss (a prefix must be contiguous from the start).
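
The lookup rule can be sketched as a front-to-back walk that is capped before it starts. This is an illustrative simplification: longest_cache_hit is a hypothetical function, the cache is a plain dict from hash to block id, and capping by whole blocks (max_hit_tokens // block_size) is an assumption about how the token limit is applied.

```python
def longest_cache_hit(
    block_hashes: list[str],
    cache: dict[str, int],
    num_tokens: int,
    block_size: int,
) -> tuple[list[int], int]:
    """Return (hit block ids, number of cached tokens) for a contiguous prefix."""
    max_hit_tokens = num_tokens - 1            # last token must be recomputed
    max_blocks = max_hit_tokens // block_size  # only whole blocks can be hits
    hit: list[int] = []
    for h in block_hashes[:max_blocks]:
        block_id = cache.get(h)
        if block_id is None:
            break                              # first miss ends the prefix
        hit.append(block_id)
    return hit, len(hit) * block_size


cache = {"h0": 10, "h1": 11}
fully_cached = longest_cache_hit(["h0", "h1"], cache, num_tokens=8, block_size=4)
one_extra = longest_cache_hit(["h0", "h1"], cache, num_tokens=9, block_size=4)
front_miss = longest_cache_hit(["hX", "h1"], cache, num_tokens=9, block_size=4)
```

With an 8-token prompt whose two blocks are both cached, only the first block counts as a hit, because reusing the second would leave no token to run through the model.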

allocate_slots (line 236) — the workhorse

Read the giant ASCII docstring (lines 273–305): it diagrams how a request's tokens split into comp | new_comp | ext_comp | new | lookahead. The control flow (simplified):

num_blocks_to_allocate = self.coordinator.get_num_blocks_to_allocate(...)
if num_blocks_to_allocate > self.block_pool.get_num_free_blocks():
    return None                                  # <- OOM! caller must preempt and retry
...
new_blocks = self.coordinator.allocate_new_blocks(...)
...
self.coordinator.cache_blocks(request, num_tokens_to_cache)   # cache newly-full blocks
return self.create_kv_cache_blocks(new_blocks)

The single most important line for Phase 03 is return None: when there aren't enough free blocks, allocate_slots returns None, and the scheduler responds by preempting a running request and retrying. That handshake between the KV manager (memory truth) and the scheduler (policy) is the seam where memory management meets scheduling. mini_vllm.KVCacheManager.allocate_slots returns None on OOM for exactly this reason, and mini_vllm.Scheduler.schedule preempts on None.
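
The None handshake can be sketched with a toy block counter and a scheduling loop. Everything here is hypothetical: ToyKV and schedule_step are illustration-only names, and choosing the last running request as the preemption victim is an assumption about policy, not a statement of what upstream does in every configuration.

```python
from typing import Optional


class ToyKV:
    """Hypothetical block counter standing in for the KV manager."""
    def __init__(self, num_free: int) -> None:
        self.num_free = num_free
        self.held: dict[str, int] = {}

    def allocate_slots(self, req_id: str, num_blocks: int) -> Optional[int]:
        if num_blocks > self.num_free:
            return None                      # out of blocks: the caller decides
        self.num_free -= num_blocks
        self.held[req_id] = self.held.get(req_id, 0) + num_blocks
        return num_blocks

    def free(self, req_id: str) -> None:
        self.num_free += self.held.pop(req_id, 0)


def schedule_step(kv: ToyKV, running: list[str], need: dict[str, int]) -> list[str]:
    """Try to allocate for each running request; return the ids preempted."""
    preempted: list[str] = []
    i = 0
    while i < len(running):
        req = running[i]
        while kv.allocate_slots(req, need[req]) is None:
            victim = running.pop()           # assumption: last request is the victim
            kv.free(victim)
            preempted.append(victim)
            if victim == req:
                break                        # the request preempted itself
        else:
            i += 1                           # allocation succeeded, move on
    return preempted


kv = ToyKV(num_free=8)
kv.allocate_slots("a", 3)
kv.allocate_slots("b", 3)                    # 2 blocks left
running = ["a", "b"]
preempted = schedule_step(kv, running, need={"a": 3, "b": 1})
```

The memory side only reports that the request does not fit; which request gives up its blocks is entirely the scheduler's decision.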

free (line 429) — reverse order on purpose

"""We free the blocks in reverse order so that the tail blocks are evicted first when
caching is enabled."""
self.coordinator.free(request.request_id)

Freeing tail-first means the head blocks (the shared prefix) stay in the free queue longest, so they survive for the next request that shares that prefix. mini_vllm.KVCacheManager.free does reversed(blocks) for the same reason — see the comment there.
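
The effect of reverse-order freeing on eviction order is easy to check with a plain deque as the free queue. free_request is a hypothetical helper, and block ids stand in for block objects.

```python
from collections import deque


def free_request(free_queue: deque, request_blocks: list[int]) -> None:
    """Append a request's blocks tail-first, so its head block is evicted last."""
    free_queue.extend(reversed(request_blocks))


free_queue: deque = deque([7])          # an older free block, already queued
free_request(free_queue, [0, 1, 2])     # request held blocks 0 (head) .. 2 (tail)
eviction_order = list(free_queue)       # popleft() would follow this order
```

Block 0, the one most likely to be shared as a prefix, ends up at the back of the queue and is the last of the three to be reused.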


6. Where the blocks actually get used: the attention kernel

We've managed metadata; where do the K/V tensors and block tables meet a GPU kernel? Two places to glance at (full treatment in Phase 04):

  • The classic CUDA kernels: csrc/attention/paged_attention_v1.cu and ..._v2.cu. These take a block table and gather KV from scattered physical blocks. Search the .cu for block_table to see the indirection: physical_block = block_table[seq][logical_block].
  • The V1 backends that build the metadata: vllm/v1/attention/backends/flash_attn.py turns the scheduler's block ids + sequence lengths into the slot_mapping (where to write new K/V) and block tables (where to read old K/V) the kernel needs.

You don't need to read CUDA to pass this phase — but knowing that "the block table is literally passed into the attention kernel, which dereferences it per token" closes the loop on why the metadata we manage here is shaped the way it is.
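
In plain Python, the per-token indirection looks roughly like this. slot_for is a hypothetical helper, and the flat layout (physical block id times block size, plus the offset within the block) is an assumption about how slots are numbered, used here only to make the lookup concrete.

```python
def slot_for(block_table: list[int], pos: int, block_size: int) -> int:
    """Map a token position in one sequence to a flat KV-cache slot index."""
    logical_block = pos // block_size             # which block of the sequence
    physical_block = block_table[logical_block]   # the indirection the kernel performs
    return physical_block * block_size + pos % block_size


block_table = [5, 2, 9]     # logical blocks 0, 1, 2 live in physical blocks 5, 2, 9
slots = [slot_for(block_table, pos, block_size=4) for pos in (0, 5, 11)]
```

Consecutive token positions can land in physical blocks that are nowhere near each other, which is why the kernel needs the table rather than a base pointer and a length.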


Reading checklist

Tick these off in your lab notebook (write one sentence each):

  • KVCacheBlock — what does ref_cnt gate? what does the block_hash setter assert?
  • FreeKVCacheBlockQueue — why a linked list not a deque? where is the middle-removal used?
  • BlockPool.get_new_blocks — why call _maybe_evict_cached_block before reusing?
  • BlockPool.touch — trace the O(1) revival of a cached free block.
  • hash_block_tokens — why include the parent hash?
  • KVCacheManager.allocate_slots — what does returning None trigger, and where?

Now go build it: 02-mini-build.md, then the labs.