
Phase 04 — The Hitchhiker's Guide to Attention Backends



Don't Panic

Attention is one mathematical operation, but a dozen hyper-tuned GPU kernels compute it (FlashAttention, FlashInfer, Triton, FlashMLA, TRTLLM-GEN…), each best for some combination of hardware, model, and batch shape. vLLM hides them all behind one interface, picks the right one at startup, and feeds it the metadata it needs (the block tables from Phase 2). This phase covers that interface and that choice. Attention is usually the single hottest kernel in decode, so this is where many real performance wins and bugs live.

model's Attention layer  (one API)
        │  q, k, v
        ▼
   AttentionImpl  (the chosen backend: FlashAttention / FlashInfer / Triton / MLA / ...)
        │  + AttentionMetadata (block tables, seq lens, slot mapping)
        ▼
   the CUDA kernel  ── gathers paged KV via the block table, computes softmax(QKᵀ/√d)V

Step 1: Why attention needs a special kernel (recap Phase 2)

A token attends to all earlier tokens, whose K/V live in scattered physical blocks (PagedAttention). So the kernel can't just multiply two contiguous matrices — it must, per token, look up physical_block = block_table[logical_block] and gather K/V from all over memory. It also must write this step's new K/V to the right slot (slot_mapping). Two pieces of metadata the scheduler/runner build and hand the kernel:

  • block table — where to read prior KV (logical → physical block).
  • slot mapping — where to write this step's new K/V.

Plus per-request sequence lengths so variable-length (varlen) batches pack together.
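The read and write paths above can be sketched in a few lines of numpy. This is a minimal illustration, not vLLM's code: it assumes a single-head, single-layer cache laid out as (num_blocks, block_size, head_dim) and the convention slot = physical_block * block_size + offset. BLOCK_SIZE, write_kv, and gather_kv are hypothetical names.

```python
import numpy as np

BLOCK_SIZE = 4  # tokens per KV block (illustrative value)
HEAD_DIM = 8

# Toy paged cache: one pool of physical blocks for K and one for V (single head, no layers).
NUM_BLOCKS = 6
k_cache = np.zeros((NUM_BLOCKS, BLOCK_SIZE, HEAD_DIM))
v_cache = np.zeros((NUM_BLOCKS, BLOCK_SIZE, HEAD_DIM))

def write_kv(k_new, v_new, slot_mapping):
    """Write path: scatter this step's K/V rows into their slots (slot = block * BLOCK_SIZE + offset)."""
    for row, slot in enumerate(slot_mapping):
        block, offset = divmod(int(slot), BLOCK_SIZE)
        k_cache[block, offset] = k_new[row]
        v_cache[block, offset] = v_new[row]

def gather_kv(block_table, seq_len):
    """Read path: fetch a request's first seq_len tokens in logical order via its block table."""
    locs = [(block_table[pos // BLOCK_SIZE], pos % BLOCK_SIZE) for pos in range(seq_len)]
    K = np.stack([k_cache[b, o] for b, o in locs])
    V = np.stack([v_cache[b, o] for b, o in locs])
    return K, V

# A 6-token request whose logical blocks 0 and 1 live in physical blocks 5 and 2.
rng = np.random.default_rng(0)
block_table = [5, 2]
seq_len = 6
k_tokens = rng.standard_normal((seq_len, HEAD_DIM))
v_tokens = rng.standard_normal((seq_len, HEAD_DIM))
slot_mapping = [block_table[p // BLOCK_SIZE] * BLOCK_SIZE + p % BLOCK_SIZE for p in range(seq_len)]
write_kv(k_tokens, v_tokens, slot_mapping)   # slot_mapping == [20, 21, 22, 23, 8, 9]
K, V = gather_kv(block_table, seq_len)
assert np.array_equal(K, k_tokens) and np.array_equal(V, v_tokens)
```

A real kernel does the same lookups per layer and per head on the GPU and never materializes the gathered K/V. It reads each block in place as it computes.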


Step 2: Why so many kernels?

The math is fixed; the fast way to do it depends on context:

  • FlashAttention — the classic: never materializes the full N×N attention matrix; streams K/V in tiles using online softmax (running max + rescale), so memory is O(N) not O(N²). Great general default.
  • FlashInfer — a library specialized for serving: paged KV, prefill+decode wrappers, fast for many small/decode requests; often wins at high concurrency.
  • Triton — kernels written in Triton (Python-ish DSL); portable, the fallback when a hand-tuned CUDA kernel isn't available for your case.
  • FlashMLA — for MLA (Multi-head Latent Attention), DeepSeek's design that compresses KV into a low-rank latent — different KV layout, so it needs its own kernel.
  • TRTLLM-GEN — NVIDIA TensorRT-LLM generated kernels, tuned for specific GPUs/precisions.
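A back-of-envelope calculation shows why the FlashAttention bullet's O(N) vs O(N²) memory claim matters. The assumptions here are illustrative: one 32k-token sequence, 32 heads, fp16 scores if the full matrix were materialized, and fp32 running statistics (one max and one denominator per query row) for the streaming version. Both sides also hold Q/K/V and the output, which this calculation ignores.

```python
def attn_matrix_bytes(seq_len, num_heads, bytes_per_elem=2):
    """Bytes to materialize the full N x N score matrix for every head (fp16 by default)."""
    return seq_len * seq_len * num_heads * bytes_per_elem

def streaming_state_bytes(seq_len, num_heads, bytes_per_elem=4):
    """Bytes for online softmax's per-row running max and denominator (fp32): O(N)."""
    return seq_len * num_heads * 2 * bytes_per_elem

print(attn_matrix_bytes(32768, 32) / 2**30, "GiB")      # 64.0 GiB
print(streaming_state_bytes(32768, 32) / 2**20, "MiB")  # 8.0 MiB
```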

Different head dims, dtypes (fp16/bf16/fp8), features (sliding window, soft-cap, ALiBi), and hardware all shift which kernel is fastest or even available.
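Each feature changes what the kernel does to the scores before softmax, so a kernel either implements it or cannot be used for that model. The sketch below applies common formulations: soft-capping as cap * tanh(s / cap), ALiBi as a linear distance penalty, and a sliding window that keeps the last `window` keys. The window convention and the order in which the features combine are assumptions, and `modified_scores` is a hypothetical helper.

```python
import numpy as np

def modified_scores(q, K, q_pos, softcap=None, window=None, alibi_slope=None):
    """Scores for one query at position q_pos against keys at positions 0..len(K)-1."""
    k_pos = np.arange(K.shape[0])
    s = K @ q / np.sqrt(q.shape[0])          # plain scaled dot-product scores
    if softcap is not None:                  # soft-cap: squash logits into (-cap, cap)
        s = softcap * np.tanh(s / softcap)
    if alibi_slope is not None:              # ALiBi: penalty grows with query-key distance
        s = s - alibi_slope * (q_pos - k_pos)
    mask = k_pos > q_pos                     # causal: never attend to future tokens
    if window is not None:                   # sliding window: only the last `window` keys
        mask |= k_pos <= q_pos - window
    return np.where(mask, -np.inf, s)        # masked keys get zero weight after softmax
```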


Step 3: The backend abstraction

vLLM factors attention into four roles (vllm/v1/attention/backend.py):

| Role | Job |
|---|---|
| Attention layer | what the model calls (q, k, v → out); backend-agnostic |
| AttentionBackend | names the impl and metadata classes for a kernel family |
| AttentionImpl | the actual forward that runs the kernel |
| AttentionMetadataBuilder | turns SchedulerOutput into the kernel's metadata (block tables, seq lens, slot mapping) each step |
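The four roles can be mimicked with toy classes to show how they connect. Every class here is a hypothetical, heavily simplified stand-in. These are not vLLM's real base classes or signatures. The builder takes a dict of per-request sequence lengths instead of a SchedulerOutput, and there is one decode query per request with K/V packed back to back (varlen).

```python
from abc import ABC, abstractmethod
from dataclasses import dataclass

import numpy as np

@dataclass
class ToyMetadata:                  # stand-in for AttentionMetadata
    seq_lens: list

class ToyImpl(ABC):                 # stand-in for AttentionImpl
    @abstractmethod
    def forward(self, q, k, v, meta: ToyMetadata) -> np.ndarray: ...

class DenseImpl(ToyImpl):
    """Reference 'kernel': softmax(q K^T / sqrt(d)) V, one query per request."""
    def forward(self, q, k, v, meta):
        out, start = [], 0
        for i, n in enumerate(meta.seq_lens):       # varlen: requests packed back to back
            s = k[start:start + n] @ q[i] / np.sqrt(q.shape[1])
            p = np.exp(s - s.max())
            out.append(p @ v[start:start + n] / p.sum())
            start += n
        return np.stack(out)

class ToyBackend:                   # stand-in for AttentionBackend: names impl + metadata classes
    impl_cls = DenseImpl
    metadata_cls = ToyMetadata

class ToyMetadataBuilder:           # stand-in for AttentionMetadataBuilder
    def build(self, scheduled_seq_lens: dict) -> ToyMetadata:
        return ToyBackend.metadata_cls(seq_lens=list(scheduled_seq_lens.values()))

class ToyAttentionLayer:            # what the model calls; it never names a kernel
    def __init__(self, backend):
        self.impl = backend.impl_cls()

    def __call__(self, q, k, v, meta):
        return self.impl.forward(q, k, v, meta)
```

Swapping in a different impl class on the backend changes the kernel without touching ToyAttentionLayer, which mirrors invariant 1 below.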

A selector (get_attn_backend, selector.py:52) picks the backend at startup from platform + dtype + head_dim + model features. You can override it with the environment variable VLLM_ATTENTION_BACKEND (e.g. FLASH_ATTN, FLASHINFER, TRITON_ATTN, ...). The model code never changes; only the AttentionImpl plugged into it does.
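The shape of that decision (explicit override first, then feature checks, then a portable fallback) can be sketched as a pure function. The rules below are invented for illustration, including the head-dim constraint and the returned names for the non-override cases. Do not read them as vLLM's actual selection logic, which lives in selector.py and the platform code. Only the VLLM_ATTENTION_BACKEND variable name comes from the text above.

```python
import os

def toy_select_backend(platform: str, dtype: str, head_dim: int, use_mla: bool) -> str:
    """Hypothetical selection rules; they only show the override -> features -> fallback shape."""
    override = os.environ.get("VLLM_ATTENTION_BACKEND")
    if override:                              # an explicit user choice always wins
        return override
    if use_mla:                               # MLA has its own KV layout, so its own kernels
        return "FLASHMLA" if platform == "cuda" else "TRITON_MLA"
    fast_path_ok = (platform == "cuda"
                    and dtype in ("float16", "bfloat16")
                    and head_dim % 8 == 0 and head_dim <= 256)   # assumed constraint
    if fast_path_ok:
        return "FLASH_ATTN"
    return "TRITON_ATTN"                      # portable fallback
```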


Step 4: Online softmax (the FlashAttention trick), in one picture

You can't hold a 1×N attention row in fast SRAM for long N. So FlashAttention streams K/V in tiles and keeps a running result, rescaling as it goes:

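import numpy as np

def online_softmax_attention(q, K, V, tile=64):      # q: (d,), K and V: (N, d)
    m, denom, acc = -np.inf, 0.0, np.zeros(V.shape[1])
    for i in range(0, len(K), tile):                   # for each tile of (K, V)
        s = K[i:i + tile] @ q / np.sqrt(len(q))        # scaled scores for this tile
        m_new = max(m, s.max())                        # running max (for numerical stability)
        correction = np.exp(m - m_new)                 # shrinks old terms to the new max
        p = np.exp(s - m_new)
        acc = acc * correction + p @ V[i:i + tile]     # rescale old, add new
        denom = denom * correction + p.sum()
        m = m_new                                      # carry the max to the next tile
    return acc / denom                                 # out = acc / denom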

You'll implement exactly this in lab-01 (numpy, CPU) over a paged KV cache, and prove it equals plain dense attention. That single lab demystifies FlashAttention and PagedAttention's kernel side at once.
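A minimal sketch of that lab idea follows. It is not the lab's reference solution, and the function names are hypothetical. It treats each physical KV block as one tile, fetches it through the block table, runs the same online-softmax update, and checks the result against dense attention over the gathered K/V. It assumes a single head, one decode query, and a cache shaped (num_blocks, block_size, head_dim).

```python
import numpy as np

def paged_online_attention(q, k_cache, v_cache, block_table, seq_len):
    """One decode query over a paged KV cache; each physical block is one tile."""
    block_size = k_cache.shape[1]
    scale = 1.0 / np.sqrt(q.shape[0])
    m, denom, acc = -np.inf, 0.0, np.zeros(v_cache.shape[2])
    for logical, physical in enumerate(block_table):
        n = min(block_size, seq_len - logical * block_size)  # last block may be partial
        if n <= 0:
            break
        K_tile = k_cache[physical, :n]                # gather through the block table
        V_tile = v_cache[physical, :n]
        s = K_tile @ q * scale
        m_new = max(m, s.max())
        correction = np.exp(m - m_new)
        p = np.exp(s - m_new)
        acc = acc * correction + p @ V_tile
        denom = denom * correction + p.sum()
        m = m_new
    return acc / denom

def dense_attention(q, K, V):
    """Reference: plain softmax(q K^T / sqrt(d)) V over contiguous K/V."""
    s = K @ q / np.sqrt(q.shape[0])
    p = np.exp(s - s.max())
    return p @ V / p.sum()

# 10 tokens scattered over physical blocks 3, 0, 5 (block_size 4; the last block is half full).
rng = np.random.default_rng(0)
d, block_size = 16, 4
k_cache = rng.standard_normal((6, block_size, d))
v_cache = rng.standard_normal((6, block_size, d))
block_table, seq_len = [3, 0, 5], 10
q = rng.standard_normal(d)
K = np.concatenate([k_cache[b] for b in block_table])[:seq_len]
V = np.concatenate([v_cache[b] for b in block_table])[:seq_len]
assert np.allclose(paged_online_attention(q, k_cache, v_cache, block_table, seq_len),
                   dense_attention(q, K, V))
```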


The invariants to memorize

  1. Attention is one op; the backend is which kernel computes it. Model code is backend-agnostic.
  2. The kernel needs block table (read map), slot mapping (write map), seq lens (varlen).
  3. Online softmax makes attention O(N) memory and is why "Flash" kernels exist.
  4. Backend is chosen at startup (selector) and overridable via VLLM_ATTENTION_BACKEND.
  5. MLA models need MLA-specific backends (different KV layout).

What you'll do

  • Read: 01-deep-dive.md — the Attention layer, the backend base classes, the selector, and FlashAttentionImpl with its metadata builder, all line-anchored.
  • Build: 02-mini-build.md — paged attention with online softmax in numpy.
  • Labs (see labs/README.md; recommended order 01 → 03 → 04 → 02):
    • lab-01-paged-attention-gather [CPU-OK] — implement online-softmax attention over a paged KV cache; prove it equals dense attention.
    • lab-02-backend-selection [GPU-OPT] — read the selector, build the (GPU, dtype, model) → backend matrix, verify with env overrides (captured output).
    • lab-03-causal-prefill-attention [CPU-OK] — the prefill kernel shape: M queries, causal loop bounds, start_pos offsets; prove chunked prefill == one-shot at the attention layer.
    • lab-04-flash-decoding-partitions [CPU-OK] — split-KV decode: attention state as a mergeable (max, denom, acc) triple; equality with dense for any partition count/order.
  • Test yourself: EXERCISES.md, INTERVIEW.md, CHEATSHEET.md.
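Two of the lab ideas fit in a short numpy sketch. These are not the labs' reference solutions, and all names are hypothetical. First, split-KV decode: each partition produces a (max, denom, acc) state, and merging rescales both states to a shared max, so partitions can be combined in any grouping or order. Second, causal prefill: query i of a chunk starting at start_pos may see keys 0 .. start_pos + i, which is why a chunked prefill reproduces the one-shot result. Both assume a single head and unbatched inputs.

```python
import numpy as np

def partial_state(q, K, V):
    """Attention state (max, denom, acc) for one query over one KV partition."""
    s = K @ q / np.sqrt(q.shape[0])
    m = s.max()
    p = np.exp(s - m)
    return m, p.sum(), p @ V

def merge(a, b):
    """Combine two partition states by rescaling both to their shared max."""
    (ma, da, xa), (mb, db, xb) = a, b
    m = max(ma, mb)
    ca, cb = np.exp(ma - m), np.exp(mb - m)
    return m, da * ca + db * cb, xa * ca + xb * cb

def split_kv_decode(q, K, V, num_parts, order=None):
    """Split K/V into num_parts (<= len(K)) partitions, then merge their states in `order`."""
    parts = [partial_state(q, Kp, Vp)
             for Kp, Vp in zip(np.array_split(K, num_parts), np.array_split(V, num_parts))]
    state = None
    for i in (range(num_parts) if order is None else order):
        state = parts[i] if state is None else merge(state, parts[i])
    _, denom, acc = state
    return acc / denom

def causal_prefill(Q, K, V, start_pos):
    """Q holds tokens start_pos .. start_pos+M-1; K, V hold tokens 0 .. start_pos+M-1."""
    out = []
    for i, q in enumerate(Q):
        visible = start_pos + i + 1                  # causal bound: keys 0 .. start_pos+i
        s = K[:visible] @ q / np.sqrt(q.shape[0])
        p = np.exp(s - s.max())
        out.append(p @ V[:visible] / p.sum())
    return np.stack(out)
```

In the prefill sketch, the second chunk's K/V must already include the first chunk's tokens. In a real engine, those tokens come from the paged cache.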
