Li, Huang, Yang et al. — UIUC, Cohere, Princeton — 2024

SnapKV: LLM Knows What You Are Looking for Before Generation

Attention patterns during prefill predict which KV pairs matter during generation. Compress the KV cache per-head with a voting mechanism. 3.6x faster decoding, 8.2x memory savings, 380K context on a single A100.

Prerequisites: Transformer attention (Q, K, V) + KV cache basics + softmax & top-k
10 chapters, 3+ simulations

Chapter 0: The Problem

You are serving a chatbot that handles long documents. A user pastes a 50-page contract and asks: "What is the termination clause?" The model reads every token, computes attention, and generates an answer. Behind the scenes, it builds a KV cache — the key and value tensors at every layer for every input token — so that during autoregressive generation, each new token can attend to the full context without re-computing everything.

Here is the problem: the KV cache grows linearly with context length and linearly with the number of layers and heads. For LLaMA-2 7B, with 32 layers, 32 heads, and a head dimension of 128 in FP16, each token in the KV cache costs:

2 (K and V) × 32 layers × 32 heads × 128 dims × 2 bytes = 524,288 bytes ≈ 0.5 MB per token

That is half a megabyte per token. A 128K-token context? That is 64 GB just for the KV cache — nearly the entire memory of an A100-80GB GPU. You have no room left for model weights, activations, or batch size greater than 1.
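The arithmetic above can be checked with a few lines of Python. This is a minimal sketch; the function name `kv_cache_bytes` and its parameter names are illustrative assumptions, and "128K" is taken to mean 128 × 1024 tokens.

```python
def kv_cache_bytes(num_tokens, num_layers=32, num_kv_heads=32,
                   head_dim=128, bytes_per_element=2):
    """Bytes needed to cache keys and values for num_tokens tokens."""
    # The leading 2 accounts for storing both a key and a value per position.
    per_token = 2 * num_layers * num_kv_heads * head_dim * bytes_per_element
    return num_tokens * per_token

print(kv_cache_bytes(1))                     # 524288
print(kv_cache_bytes(128 * 1024) / 2**30)    # 64.0
```

The defaults mirror the 32-layer, 32-head, 128-dimension FP16 configuration used in this chapter; changing them shows how quickly the per-token cost scales with model size.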

But the pain does not stop at memory. During decoding, every new token must attend to every KV pair. That attention computation costs O(n) per token per head, where n is the context length. At 128K tokens, each generation step touches 128K key-value pairs across all heads and layers. Decoding latency grows linearly with context length, and throughput tanks.

The double squeeze: Long-context KV caches eat memory (preventing batching and longer contexts) and slow down decoding (each step must attend to the full cache). You need a way to keep the cache small without losing the information the model actually needs.

The naive solution is truncation: keep only the most recent N tokens. But this throws away the beginning of the document, exactly where titles, definitions, and key facts tend to live. Another approach: keep everything but use sparse attention patterns at decode time. But this requires architectural changes and re-training.

What if the model itself could tell you which KV pairs it actually needs — before it generates a single token?

KV Cache Memory Explosion

[Interactive simulation: KV cache size as a function of context length. A red line marks 80 GB (A100 capacity); a teal bar shows SnapKV's compressed cache at a fixed budget of 1024 KV pairs.]
For a 7B-parameter LLM with 32 layers and 32 KV heads (head_dim=128) using FP16, approximately how much memory does the KV cache consume for a 128K-token context?

Chapter 1: The Key Observation

SnapKV is built on a single, elegant observation: the attention patterns that emerge during prefill are nearly identical to the attention patterns used during generation. In other words, the model "knows what it is looking for" before it generates a single token.

Here is how the authors discovered this. They took samples from Ultrachat — a multi-turn instruction dataset with 1.4 million dialogues — and filtered for sequences with response length greater than 512 and prompt length greater than 3,000 tokens. Then they asked two questions:

Observation 1: The pattern is predictable

Split the input sequence into windows of 128 tokens each. For each window, compute the average attention weights over the prefix tokens. Identify the "important" prefix positions — those with high average attention — and compare them to the positions that actually receive high attention during generation.

The result: the last window of the input sequence (the final 128 tokens before generation begins) identifies important positions with a hit rate above 85% on most layers. The last window's attention pattern is an excellent predictor of what the model will attend to during generation.

Why the last window? The last tokens of the prompt have already attended to the entire context. Their attention weights are a summary of "what matters" in the full input. Think of it like a student who has read an entire essay — the thoughts in their head at the end reflect which parts were most important.

Observation 2: The pattern is stable

During generation, the authors split the generated tokens into windows of 128 and checked whether the important positions identified by the last input window remained important throughout generation. They do. The overlap rate between the prefill pattern and each generation window stays above 80% across layers, even 512 tokens into generation.

This is surprising. You might expect that as the model generates more tokens, its attention drifts to new parts of the context. But in practice, the core "information sources" the model identified during prefill remain the dominant attention targets throughout the response.

Prefill
The model processes the entire prompt. On each layer, each head's attention reveals which prefix positions are "important." The last window of the prompt captures this pattern.
Snapshot
Extract the attention pattern from the observation window. Use it to identify the top-k important KV positions per head.
Compress
Keep only those top-k KV pairs (plus the observation window itself). Discard the rest. The compressed cache is a tiny fraction of the original.
Generate
Decode using the compressed KV cache. Because the kept positions match what the model was going to attend to anyway, quality is preserved.

This is the insight that makes SnapKV possible: you do not need to run any generation tokens to know which KV pairs matter. The prefill attention pattern — specifically the pattern in the observation window — tells you everything.

Why does the "last window" of the input prompt work as a good predictor of which KV positions will be important during generation?

Chapter 2: The Observation Window

The observation window is the last Lobs tokens of the prompt. It is the window through which SnapKV "observes" the attention pattern. The rest of the prompt, before the observation window, is called the prefix (length Lprefix).

Lprompt = Lprefix + Lobs

Only the prefix gets compressed. The observation window's KV pairs are always kept in full — they are the most recent context and the model relies on them heavily during generation (recency bias is real).

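Concretely, if your prompt has 100,000 tokens and Lobs = 32, then Lprefix = 99,968: the first 99,968 tokens are candidates for compression, and the last 32 are always kept.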

How big should Lobs be?

The authors experiment with window sizes of 16, 32, and 64. Their default is 16 or 32 tokens. A larger window means more queries "voting" on which prefix positions are important (better signal), but it also means slightly more KV pairs kept unconditionally. In practice, 16-32 is a sweet spot: enough signal to identify important positions, small enough to be negligible in the overall budget.

The observation window is always kept. The compressed KV cache = selected prefix KV pairs + full observation window KV pairs. This means the effective cache size is k + Lobs, where k is the number of selected prefix positions per head.
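The split between prefix and observation window, and the resulting cache size, can be sketched as follows. The helper names `split_prompt` and `compressed_cache_size` are hypothetical, and the handling of prompts shorter than the window is an assumption made only for this illustration.

```python
def split_prompt(prompt_len, obs_window):
    """Return (prefix_len, obs_len); the observation window is the prompt's tail."""
    if prompt_len <= obs_window:
        # Assumption for this sketch: a short prompt is all window, no prefix.
        return 0, prompt_len
    return prompt_len - obs_window, obs_window

def compressed_cache_size(num_selected_prefix, obs_window):
    """Entries per head after compression: selected prefix + full window."""
    return num_selected_prefix + obs_window

prefix_len, obs_len = split_prompt(100_000, 32)
print(prefix_len, obs_len)                    # 99968 32
print(compressed_cache_size(1000, obs_len))   # 1032
```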

Why not use the whole prompt as the observation window?

You could, in principle, average attention weights across all query positions. But the key finding is that you do not need to. The last window alone achieves hit rates above 85%. Using a small window keeps the observation cost negligible — you are just reading attention weights that you already computed during prefill.

Importantly, the observation window is not a fixed set of special tokens or learned prompts. It is simply the last Lobs tokens of whatever prompt the user provides. It works across different tasks, different instructions, and different document types because the observation leverages the model's own attention mechanism rather than any task-specific heuristic.

In SnapKV, what is the observation window?

Chapter 3: The Voting Mechanism

The observation window gives us the attention weights from the last Lobs query positions over all prefix positions. But we have N attention heads, and each head might focus on different parts of the prefix. How do we decide which prefix positions to keep per head?

Step 1: Sum the attention weights

For a single head on a single layer, the observation window gives us a matrix of shape (Lobs, Lprefix) — each of the Lobs query tokens has a softmax-normalized attention distribution over the Lprefix prefix positions. We sum across the query dimension:

$$C[j] = \sum_{i=0}^{L_{obs}-1} W_{obs}[i, j] \quad \text{for each prefix position } j$$

This gives a single vote score C[j] for each prefix position j. Positions that receive high attention from many observation window tokens get a high total vote. Positions that receive high attention from only one or two tokens still get a moderate vote. Positions that nobody attends to get a low vote.
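A tiny worked example may make the vote sum concrete. This is a pure-Python sketch with made-up attention weights for one head; `vote_scores` is an illustrative name, not a function from the paper's code.

```python
def vote_scores(w_obs):
    """Sum attention weights over observation-window queries.

    w_obs: list of L_obs rows, each a list of L_prefix attention weights.
    Returns one vote score per prefix position.
    """
    num_prefix = len(w_obs[0])
    return [sum(row[j] for row in w_obs) for j in range(num_prefix)]

# Three observation-window queries attending over four prefix positions.
w_obs = [
    [0.50, 0.25, 0.125, 0.125],
    [0.50, 0.00, 0.25,  0.25],
    [0.25, 0.00, 0.50,  0.25],
]
print(vote_scores(w_obs))  # [1.25, 0.25, 0.875, 0.625]
```

Position 0 is attended to by every query and collects the highest vote, while position 1, attended to by only one query, collects the lowest.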

Step 2: Select Top-k

Given the vote scores, select the top k positions with the highest votes:

$$I = \mathrm{Topk}(C, k) \quad \text{where } k = \lfloor p \times L_{prefix} \rfloor$$

Here p is the compression rate — the fraction of prefix positions to keep. If p = 0.01 and Lprefix = 100,000, you keep k = 1,000 positions per head. The indices I tell you exactly which KV pairs to retain.

Per-head selection is critical. Each attention head can select a different set of positions. Head 5 on layer 12 might focus on named entities. Head 17 on layer 20 might focus on numerical values. By allowing each head to keep its own set of important positions, SnapKV preserves the specialized retrieval patterns that different heads have learned.
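The sketch below shows per-head selection on toy vote scores. The head names and vote values are invented for illustration, `select_top_k` is a hypothetical helper, and returning the indices in position order is a presentation choice of this example.

```python
import math

def select_top_k(votes, keep_fraction):
    """Indices of the floor(p * L_prefix) highest-vote positions, in position order."""
    k = math.floor(keep_fraction * len(votes))
    ranked = sorted(range(len(votes)), key=lambda j: votes[j], reverse=True)
    return sorted(ranked[:k])

# Two heads voting over the same 8 prefix positions.
votes_per_head = {
    "head_a": [0.9, 0.1, 0.1, 0.8, 0.1, 0.1, 0.1, 0.1],
    "head_b": [0.1, 0.1, 0.7, 0.1, 0.1, 0.1, 0.6, 0.1],
}
selected = {h: select_top_k(v, keep_fraction=0.25) for h, v in votes_per_head.items()}
print(selected)  # {'head_a': [0, 3], 'head_b': [2, 6]}
```

With the same budget of two positions, the two heads keep disjoint sets; a single global selection would force at least one of them to lose the positions it attends to most.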

Step 3: Gather the compressed KV pairs

Use the selected indices to gather the corresponding key and value tensors from the prefix, then concatenate with the observation window's KV pairs:

# indices: (batch, num_heads, k, head_dim), the selected prefix positions
# per head, repeated across head_dim so gather picks whole vectors

# Per-head compressed keys and values
k_prefix_compressed = key_states[..., :-window_size, :].gather(dim=2, index=indices)
v_prefix_compressed = value_states[..., :-window_size, :].gather(dim=2, index=indices)

# Observation window keys and values (always kept)
k_obs = key_states[..., -window_size:, :]
v_obs = value_states[..., -window_size:, :]

# Final compressed cache
key_states = torch.cat([k_prefix_compressed, k_obs], dim=2)
value_states = torch.cat([v_prefix_compressed, v_obs], dim=2)

After this, the KV cache on this layer has only k + Lobs entries per head, instead of the original Lprompt. Generation proceeds as usual, but every attention computation is now over a much smaller cache.
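To see the gather-and-concatenate step without tensor shapes in the way, here is a list-based sketch for a single head. `compress_head` is a hypothetical helper, and strings stand in for key and value vectors.

```python
def compress_head(keys, values, selected, obs_window):
    """Keep the selected prefix entries plus the last obs_window entries of one head."""
    prefix_len = len(keys) - obs_window
    assert all(0 <= j < prefix_len for j in selected), "indices must point into the prefix"
    new_keys = [keys[j] for j in selected] + keys[prefix_len:]
    new_values = [values[j] for j in selected] + values[prefix_len:]
    return new_keys, new_values

# A 10-token prompt: 8 prefix tokens and a 2-token observation window.
keys = [f"k{j}" for j in range(10)]
values = [f"v{j}" for j in range(10)]
ck, cv = compress_head(keys, values, selected=[1, 4, 5], obs_window=2)
print(ck)       # ['k1', 'k4', 'k5', 'k8', 'k9']
print(len(ck))  # 5
```

Three selected prefix entries plus a window of two gives five entries, matching the k + Lobs count stated above.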

Measuring effectiveness: Hit Rate

How do you know the voting mechanism picked the right positions? The paper defines a hit rate metric. During generation, at each step, check whether the positions that receive high attention (above a threshold θ) were among the ones selected by the voting mechanism:

H = ∑(Mthreshold ∧ Mvote) / ∑ Mthreshold

Mthreshold is a binary mask of positions with attention weight above θ during generation. Mvote is the binary mask of positions selected by voting. A hit rate of 0.85 means the voting mechanism captured 85% of the positions the model actually attends to during generation.
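The hit rate formula can be computed for a single decoding step as sketched below. The attention values and the threshold are made up, `hit_rate` is an illustrative name, and returning `None` when no position exceeds the threshold is an assumption of this sketch rather than something the paper specifies.

```python
def hit_rate(gen_attention, selected, theta):
    """Fraction of above-threshold positions that the vote also selected."""
    important = {j for j, w in enumerate(gen_attention) if w > theta}
    if not important:
        return None  # assumption: undefined when nothing exceeds the threshold
    return len(important & set(selected)) / len(important)

# Attention of one generated token over six prefix positions.
gen_attention = [0.40, 0.01, 0.30, 0.02, 0.15, 0.12]
print(hit_rate(gen_attention, selected=[0, 2, 5], theta=0.1))  # 0.75
```

Positions 0, 2, 4, and 5 exceed the threshold and the vote kept three of them, so the hit rate is 3/4.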

Attention Voting Visualization

[Interactive simulation: a heatmap of attention weights from the observation window (rows) over prefix positions (columns), with a bottom bar showing the summed votes and the top-k positions highlighted.]
In SnapKV's voting mechanism, votes are accumulated per head. Why is per-head selection important rather than using a single global set of positions for all heads?

Chapter 4: The Pooling Kernel

The voting mechanism with pure top-k selection has a subtle problem. Imagine the model is trying to retrieve a phone number: "+1 (555) 867-5309". The attention pattern might strongly focus on the digits "867" and "5309" but give low attention to the parentheses, dashes, and country code. Top-k selection would keep "867" and "5309" but drop "+1 (555)" — and the model would generate an incomplete phone number.

The issue is that nearby tokens form semantic clusters. Information is rarely contained in a single token — it spans a contiguous region of tokens. If you select individual high-attention tokens without their neighbors, you lose contextual integrity.

Induction heads and copying: LLMs retrieve information by copying features from high-attention positions. Induction heads (discovered by Olsson et al., 2022) copy tokens surrounding the attended position, not just the attended token itself. If you remove the neighbors, the induction head copies incomplete information.

The fix: 1D average pooling

Before applying top-k to the vote scores, SnapKV applies a 1D average pooling operation along the sequence dimension:

import torch.nn.functional as F

# votes: shape (num_heads, L_prefix)
pool_vote = F.avg_pool1d(votes,
                         kernel_size=kernel_size,
                         padding=kernel_size // 2,
                         stride=1)
# pool_vote has the same shape as votes (for odd kernel_size)
# Each position's score is now the average of its neighborhood
indices = pool_vote.topk(max_capacity_prompt - window_size, dim=-1).indices

With kernel_size = 5, each position's score becomes the average of itself and its two neighbors on each side. This means a position with a moderate score that is surrounded by high-score positions gets a boost. Conversely, an isolated high-attention position that is surrounded by unimportant tokens gets slightly smoothed down.

The effect: cluster selection

After pooling, top-k selection tends to pick contiguous clusters of positions rather than scattered individual tokens. The phone number "+1 (555) 867-5309" gets selected as a unit because the pooling kernel lifts the scores of the low-attention tokens ("+1", "(", ")") that sit between the high-attention digits.
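The clustering effect can be reproduced on a toy vote vector. This is a pure-Python sketch with zero padding and stride 1; the helper names and the vote values are invented for illustration, and a kernel size of 3 is used only to keep the example small.

```python
def avg_pool_1d(scores, kernel_size):
    """Stride-1 average pooling with zero padding (same length for odd kernels)."""
    pad = kernel_size // 2
    padded = [0.0] * pad + list(scores) + [0.0] * pad
    return [sum(padded[j:j + kernel_size]) / kernel_size for j in range(len(scores))]

def top_k(scores, k):
    """Indices of the k largest scores, returned in position order."""
    ranked = sorted(range(len(scores)), key=lambda j: scores[j], reverse=True)
    return sorted(ranked[:k])

# Two strong positions (1 and 3) with a weak one between them, plus an
# isolated moderately strong position at index 7.
votes = [0.0, 1.0, 0.25, 1.0, 0.0, 0.0, 0.0, 0.75, 0.0, 0.0]
print(top_k(votes, 3))                  # [1, 3, 7]
print(top_k(avg_pool_1d(votes, 3), 3))  # [1, 2, 3]
```

Both selections keep three positions, but after pooling the weak position between the two strong ones is lifted into the selection and the result is one contiguous span.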

| Kernel Size | Effect | Use Case |
| --- | --- | --- |
| 1 | No pooling: pure top-k | When individual tokens carry independent facts |
| 3 | Light smoothing: small clusters | General-purpose default |
| 5 | Moderate clustering: captures phrases | QA, summarization (authors' default) |
| 7 | Aggressive clustering: captures sentences | Very long documents where context spans are wide |

The authors use kernel_size = 5 or 7 as their default, with observation window size of 16. In ablation studies (Section 5.2 of the paper), they show that pooling significantly improves retrieval accuracy on the LongEval-Lines benchmark — a task where the model must retrieve a specific value from a key-value pair buried in a long noisy context. Without pooling, SnapKV retrieves the value token but misses surrounding formatting. With pooling, it captures the full key-value pair.

Pooling does not change the number of selected positions. It only changes which positions are selected. The same k positions are kept, but they tend to form contiguous clusters rather than scattered points. Same budget, better coverage.
Why does SnapKV apply a 1D average pooling kernel to the vote scores before top-k selection?

Chapter 5: The Full Algorithm

Now we can assemble the complete SnapKV algorithm. It runs once, at the end of prefill, on every layer independently. There is no training, no learned parameters, no architectural change. It modifies only the KV cache.

Inputs

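| Parameter | Symbol | Typical Value |
| --- | --- | --- |
| Query states | $Q \in \mathbb{R}^{B \times N \times L \times d}$ | From the prefill forward pass |
| Key states | $K \in \mathbb{R}^{B \times N \times L \times d}$ | Full prefill KV cache |
| Value states | $V \in \mathbb{R}^{B \times N \times L \times d}$ | Full prefill KV cache |
| Observation window size | $L_{obs}$ | 16 or 32 |
| Max KV cache capacity | $k$ | 1024, 2048, or 4096 |
| Pooling kernel size | kernel_size | 5 or 7 |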

Algorithm step by step

1. Short-circuit check
If the prompt length L ≤ k (max capacity), return the full KV cache unchanged. No compression needed.
2. Compute observation attention
Take the last Lobs queries and compute their attention weights over the prefix keys: $W_{obs} = \mathrm{softmax}\left(Q[\ldots, -L_{obs}:, :] \, K[\ldots, :-L_{obs}, :]^{\top} / \sqrt{d}\right) \in \mathbb{R}^{B \times N \times L_{obs} \times L_{prefix}}$
3. Vote: sum across query dimension
$C = \sum_i W_{obs}[:, :, i, :] \in \mathbb{R}^{B \times N \times L_{prefix}}$. Each prefix position gets a cumulative vote from all observation window queries.
4. Pool votes
Apply 1D average pooling with the kernel: pool_vote = AvgPool1D(C, kernel_size, padding=kernel_size//2, stride=1). Same shape as C but smoothed.
5. Select top-k
I = topk(pool_vote, k − Lobs). Select the k − Lobs highest-scoring prefix positions per head. Sort the indices to maintain position order.
6. Gather and concatenate
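K_compressed = cat(K[prefix][I], K[obs]). V_compressed = cat(V[prefix][I], V[obs]). Total cache size per head: k. Note that k here is the total budget including the observation window, whereas in Chapter 3 k counted only the selected prefix positions.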

Pseudo-code (PyTorch style)

import math
import torch
import torch.nn.functional as F

def snap_kv(query_states, key_states, value_states,
            window_size, max_capacity, kernel_size):
    bsz, num_heads, q_len, head_dim = query_states.shape
    assert key_states.shape[-2] == query_states.shape[-2]

    # Step 1: short-circuit, no compression needed
    if q_len <= max_capacity:
        return key_states, value_states

    # Step 2: attention from observation-window queries to prefix keys.
    # No causal mask is needed: every prefix key precedes every window query.
    q_obs = query_states[..., -window_size:, :]
    k_prefix = key_states[..., :-window_size, :]
    scores = torch.matmul(q_obs, k_prefix.transpose(-1, -2)) / math.sqrt(head_dim)
    attn_weights = torch.softmax(scores, dim=-1, dtype=torch.float32)

    # Step 3: vote by summing across the query dimension
    # -> shape (bsz, num_heads, L_prefix)
    votes = attn_weights.sum(dim=-2)

    # Step 4: pool votes for clustering (same shape for odd kernel_size)
    pool_vote = F.avg_pool1d(votes,
                             kernel_size=kernel_size,
                             padding=kernel_size // 2,
                             stride=1)

    # Step 5: select top-k prefix positions per head
    indices = pool_vote.topk(
        max_capacity - window_size, dim=-1).indices

    # Sort so the kept entries stay in their original position order,
    # then repeat the indices across head_dim for gather
    indices = indices.sort(dim=-1).values
    indices = indices.unsqueeze(-1).expand(-1, -1, -1, head_dim)

    # Step 6: gather compressed prefix + observation window
    k_compress = key_states[..., :-window_size, :].gather(dim=2, index=indices)
    v_compress = value_states[..., :-window_size, :].gather(dim=2, index=indices)
    k_obs = key_states[..., -window_size:, :]
    v_obs = value_states[..., -window_size:, :]

    return torch.cat([k_compress, k_obs], dim=2), \
           torch.cat([v_compress, v_obs], dim=2)
When does this run? Exactly once per layer, at the end of prefill (the prompt phase). After snap_kv compresses the cache, all subsequent generation tokens attend to the compressed cache only. The cost of the voting + pooling + gather is negligible compared to the prefill forward pass itself.
After SnapKV compression, what is the total number of KV pairs per head in the compressed cache?

Chapter 6: The Memory Math

Let us do the concrete arithmetic for a real deployment scenario. This is where SnapKV's value becomes visceral.

Setup: LLaMA-2 7B with 128K context

| Parameter | Value |
| --- | --- |
| Layers | 32 |
| KV heads | 32 (no GQA in LLaMA-2 7B) |
| Head dimension | 128 |
| Precision | FP16 (2 bytes per element) |
| Context length | 128,000 tokens |

Full KV cache size

Memory_full = 2 × L × layers × heads × d × bytes
= 2 × 128,000 × 32 × 32 × 128 × 2
= 2 × 128,000 × 262,144
= 67,108,864,000 bytes ≈ 62.5 GB

That is nearly the full memory of an A100-80GB. You cannot even batch two requests.

SnapKV compressed cache (k = 1024)

Memory_snap = 2 × k × layers × heads × d × bytes
= 2 × 1,024 × 32 × 32 × 128 × 2
= 2 × 1,024 × 262,144
= 536,870,912 bytes ≈ 0.5 GB

From 62.5 GB to 0.5 GB. That is a 125x compression ratio. With k = 4096, you still get a 31x compression ratio.

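| Context Length | Full Cache | SnapKV (k=1024) | Compression (approx.) |
| --- | --- | --- | --- |
| 16K | 7.8 GB | 0.5 GB | 16x |
| 32K | 15.6 GB | 0.5 GB | 32x |
| 64K | 31.3 GB | 0.5 GB | 64x |
| 128K | 62.5 GB | 0.5 GB | 125x |
| 380K | 185.5 GB | 0.5 GB | 380x |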
380K on a single GPU: The authors demonstrate that SnapKV enables LWM-Text-Chat-1M (a 7B model fine-tuned for 1M context) to process up to 380K tokens on a single A100-80GB GPU. The original implementation hits OOM at 33K tokens. SnapKV extends this by 11x.
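The memory figures in this chapter can be reproduced with the short sketch below. The constant and function names are illustrative assumptions, context lengths are taken as multiples of 1,000 tokens as in the setup table, and sizes are reported in units of 2^30 bytes.

```python
# K and V, 32 layers, 32 heads, head dimension 128, FP16 (2 bytes)
BYTES_PER_TOKEN = 2 * 32 * 32 * 128 * 2

def cache_gib(num_entries):
    """KV cache size for num_entries cached positions, in units of 2**30 bytes."""
    return num_entries * BYTES_PER_TOKEN / 2**30

def compression_ratio(context_len, budget):
    """How many times smaller the compressed cache is than the full one."""
    return context_len / budget

for context_len in (16_000, 32_000, 64_000, 128_000, 380_000):
    print(context_len,
          round(cache_gib(context_len), 1),
          cache_gib(1024),
          round(compression_ratio(context_len, 1024), 1))
```

Running it shows that the exact ratios differ slightly from the rounded values quoted for some rows (for example 16,000 / 1,024 is about 15.6), while the 128,000-token row comes out at exactly 125.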

Decoding speedup

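Memory savings are not the only benefit. Decoding speed also improves because each attention computation now operates over k entries instead of L. The paper reports 3.6x faster decoding and 8.2x memory savings, along with per-token decoding latency that stays roughly constant as the prompt grows.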

The constant-latency property is particularly important for production serving. With a full cache, decoding latency scales with prompt length, making latency SLAs impossible to guarantee. With SnapKV, decoding latency depends only on k (a hyperparameter you control), not on the user's prompt length.

If you set k = 2048 for a model with 128K context, what is the approximate compression ratio of the KV cache?

Chapter 7: Experimental Results

The paper evaluates SnapKV across multiple models, benchmarks, and settings. Here are the key results.

Needle-in-a-Haystack

The classic long-context stress test: hide a specific sentence (the "needle") at a random position in a very long document (the "haystack"). The model must retrieve it. The authors extend the haystack to 380K tokens on a single A100-80GB GPU with SnapKV (k = 1024, observation window = 16, kernel = 5).

Result: LWM-Text-Chat-1M with SnapKV retrieves the needle correctly up to 140K tokens with only a slight accuracy drop. The original implementation OOMs at 33K tokens. That is a 380x compression ratio with minimal accuracy loss.

LongBench (16 datasets)

LongBench tests long-context understanding across single-doc QA, multi-doc QA, summarization, few-shot learning, synthetic tasks, and code completion. The authors test on four models:

| Model | Full KV | SnapKV-1024 | SnapKV-2048 | SnapKV-4096 |
| --- | --- | --- | --- | --- |
| LWM-Text-Chat-1M | Baseline | ≈ equal | ≈ equal | ≈ equal |
| Mistral-7B-Instruct-v0.2 | Baseline | Slight drop | ≈ equal | ≈ equal |
| LongChat-7B-v1.5-32K | Baseline | ≈ equal | ≈ equal | ≈ equal |
| Mixtral-8x7B-Instruct-v0.1 | Baseline | Slight drop | ≈ equal | ≈ equal |

The striking finding: with k = 1024 (keeping only 1024 KV pairs per head), SnapKV achieves comparable or even better performance than the full KV cache on most datasets. Some models actually improve with SnapKV-1024, likely because removing noisy, irrelevant KV pairs acts as a form of attention denoising.

Average input token length across these 4 models is around 13K. With k = 1024, that is roughly a 92% compression rate (1 − 1024/13,000). With k = 4096, it is about 68% compression. Both achieve negligible accuracy drops across 16 benchmarks.
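The compression rates quoted for LongBench follow from a one-line formula, sketched here. The average input length is assumed to be exactly 13,000 tokens for the purpose of the calculation; the text only gives it as "around 13K", so the percentages are approximate. `compression_rate` is an illustrative name.

```python
def compression_rate(avg_input_len, budget):
    """Fraction of prompt KV pairs discarded when keeping `budget` entries."""
    return 1 - budget / avg_input_len

for budget in (1024, 2048, 4096):
    print(budget, f"{compression_rate(13_000, budget):.1%}")
```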

Command-R (35B, 128K context)

The authors also evaluate on Cohere's Command-R, a production model designed for RAG with 128K context. On the Needle-in-a-Haystack test with a 128K sequence and 32x compression:

| Model | Score (Full) | Score (SnapKV) | Difference |
| --- | --- | --- | --- |
| Command-R | 9.866 | 9.819 | −0.5% |

On RAG tasks with 20K-40K token contexts and a 5-10x compression ratio, SnapKV retains 98.8% of Command-R's RAG Citation F1 score (only −1.2%) and 97.9% of the end-to-end RAG F1 score (only −2.1%).

Compression vs. Accuracy Trade-off

[Interactive simulation: compression ratio and accuracy as a function of the KV budget (k) for a 16K-token input. The sweet spot is around k=1024-2048.]
On LongBench with Mistral-7B-Instruct-v0.2, SnapKV with k=1024 achieves comparable performance to the full KV cache. The average input length is ~13K tokens. What compression rate is that?

Chapter 8: Comparisons

SnapKV is not the first attempt to compress the KV cache. Let us compare it with the key alternatives and understand what SnapKV does differently.

H2O (Heavy-Hitter Oracle)

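H2O (Zhang et al., 2023) maintains a fixed-size KV cache during generation by tracking cumulative attention scores and greedily evicting low-scoring KV pairs at each step. The key differences from SnapKV are when it runs (at every decode step rather than once at the end of prefill) and what signal it uses (cumulative attention rather than the observation window's attention).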

On LongBench, SnapKV with k=1024 outperforms H2O with k=4096 on 11 out of 16 benchmarks. SnapKV gets better compression with better quality because it compresses the right thing (the prompt cache) using a better signal (the prefill attention pattern).

The key insight vs. H2O: H2O does not help with prompt processing — it still needs the full KV cache during prefill. SnapKV compresses the prompt cache, which is where most memory goes in long-context serving (prompts are much longer than generated responses in RAG, summarization, and code analysis).
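For contrast with SnapKV's one-shot selection, the per-step eviction loop described for H2O can be caricatured as below. This is a deliberately simplified illustration of the cumulative-score idea only, not H2O's actual implementation (which, among other things, also protects recent tokens); `evict_step` and all values are hypothetical.

```python
def evict_step(cumulative, step_attention, budget):
    """Add this step's attention to running scores, then evict down to `budget`.

    cumulative: dict position -> accumulated attention score (the live cache).
    step_attention: dict position -> attention weight from the newest token.
    """
    for pos, weight in step_attention.items():
        cumulative[pos] = cumulative.get(pos, 0.0) + weight
    while len(cumulative) > budget:
        weakest = min(cumulative, key=cumulative.get)
        del cumulative[weakest]
    return cumulative

cache = {}
cache = evict_step(cache, {0: 0.5, 1: 0.25, 2: 0.25}, budget=3)
cache = evict_step(cache, {0: 0.5, 1: 0.125, 2: 0.25, 3: 0.125}, budget=3)
print(sorted(cache))  # [0, 1, 2]
```

The point of the contrast is the call pattern: this function runs at every decode step, whereas SnapKV's selection runs once per layer at the end of prefill.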

StreamingLLM

StreamingLLM (Xiao et al., 2023) keeps only the first few "attention sink" tokens plus a sliding window of recent tokens. Everything in the middle is discarded.

SnapKV selects positions based on what the model actually attends to, so it can keep important information regardless of where it appears in the context.
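The position-only policy described for StreamingLLM is easy to state in code, which makes the contrast with attention-based selection visible. This is a minimal sketch; `streaming_keep` and the specific sink and window sizes are illustrative assumptions.

```python
def streaming_keep(seq_len, num_sinks, window):
    """Positions kept by a sink-plus-recent-window policy."""
    sinks = range(min(num_sinks, seq_len))
    recent = range(max(seq_len - window, 0), seq_len)
    return sorted(set(sinks) | set(recent))

kept = streaming_keep(seq_len=20, num_sinks=4, window=6)
print(kept)        # [0, 1, 2, 3, 14, 15, 16, 17, 18, 19]
print(10 in kept)  # False
```

Position 10 is dropped no matter how much attention it receives, because the rule never looks at attention weights at all.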

ScissorHands

ScissorHands (Liu et al., 2023) identifies "pivotal tokens" that exhibit consistent attention weight patterns across previous generation windows. Like H2O, it operates during generation and focuses on tokens appended during decoding. It misses the prompt cache entirely.

FastGen / Adaptive KV Compression

FastGen (Ge et al., 2023) implements a dual-phase algorithm: profile attention patterns during prompt encoding, then evict during generation based on those profiles. It is conceptually closer to SnapKV but more complex (four compression policies) and still primarily targets the generation-phase cache.

| Method | Compresses Prompt KV? | When? | Signal | Fine-tuning? |
| --- | --- | --- | --- | --- |
| SnapKV | Yes | Once, end of prefill | Observation window attention + pooling | No |
| H2O | No (generation only) | Every decode step | Cumulative attention | No |
| StreamingLLM | Yes (crude: drop middle) | Continuous | Position (first + last) | No |
| ScissorHands | No (generation only) | Every decode step | Consistent attention | No |
| FastGen | Partial | Profile + evict | Attention profiling | No |
| TOVA | No (generation only) | Every decode step | Current-step attention | No |

Compatibility: SnapKV + other methods

Because SnapKV operates only on the prompt cache and only at prefill time, it is orthogonal to generation-phase methods like H2O. You could combine SnapKV (compress the prompt cache) with H2O (manage the growing generation cache) for maximum efficiency. The paper also shows that SnapKV works with parallel decoding strategies like Medusa for additional speedups.

What is the fundamental difference between SnapKV and H2O?

Chapter 9: Connections

Integration

SnapKV requires minimal code changes. The authors report that their HuggingFace implementation modifies only a few lines of code in the attention module — specifically, adding the snap_kv function call after the prefill forward pass computes the KV cache. For vLLM, the integration is similarly lightweight: intercept the KV cache after prefill, run the voting/pooling/gather, and replace the cache before generation begins.

The key implementation detail: SnapKV operates on the already-computed attention weights and KV states from prefill. It does not require a separate forward pass or any additional model computation. The observation window attention is extracted from the prefill computation that was going to happen anyway.

What SnapKV does NOT do

Limitations

Related Veanors

CacheBlend
Reuses pre-computed KV caches across RAG queries by selectively re-computing 10-15% of tokens. Complementary to SnapKV — CacheBlend speeds up prefill, SnapKV speeds up generation.
Attention Is All You Need
The original transformer paper. Understanding multi-head attention is prerequisite to understanding why per-head KV selection works.
DeepSeek-V3
Uses Multi-head Latent Attention (MLA), a different approach to KV cache efficiency that compresses keys and values into a lower-dimensional latent space rather than selecting positions.
The one-sentence takeaway: SnapKV exploits the fact that attention patterns during prefill predict generation-time attention. By using a small observation window to vote on important prefix positions, pooling to preserve clusters, and selecting per-head top-k positions, it achieves 100x+ KV cache compression with negligible accuracy loss — no training, no architecture changes, a few lines of code.