Attention patterns during prefill predict which KV pairs matter during generation. Compress the KV cache per-head with a voting mechanism. 3.6x faster decoding, 8.2x memory savings, 380K context on a single A100.
You are serving a chatbot that handles long documents. A user pastes a 50-page contract and asks: "What is the termination clause?" The model reads every token, computes attention, and generates an answer. Behind the scenes, it builds a KV cache — the key and value tensors at every layer for every input token — so that during autoregressive generation, each new token can attend to the full context without re-computing everything.
Here is the problem: the KV cache grows linearly with context length and linearly with the number of layers and heads. For a model like LLaMA-2 7B with 32 layers, 32 heads, and a head dimension of 128, stored in FP16, each token in the KV cache costs 2 (keys and values) × 32 layers × 32 heads × 128 dimensions × 2 bytes = 524,288 bytes.
That is half a megabyte per token. A 128K-token context? That is roughly 62.5 GB just for the KV cache, most of the memory of an A100-80GB GPU. Little room is left for model weights, activations, or a batch size greater than 1.
But the pain does not stop at memory. During decoding, every new token must attend to every KV pair. That attention computation costs O(n) per token per head, where n is the context length. At 128K tokens, each generation step touches 128K key-value pairs across all heads and layers. Decoding latency grows linearly with context length, and throughput tanks.
The naive solution is truncation: keep only the most recent N tokens. But this throws away the beginning of the document, exactly where titles, definitions, and key facts tend to live. Another approach: keep everything but use sparse attention patterns at decode time. But this requires architectural changes and re-training.
What if the model itself could tell you which KV pairs it actually needs — before it generates a single token?
[Interactive figure: KV cache size as a function of context length. The red line marks 80 GB (A100 capacity); the teal bar is SnapKV's compressed cache at a fixed budget of 1024 KV pairs.]
SnapKV is built on a single, elegant observation: the attention patterns that emerge during prefill are nearly identical to the attention patterns used during generation. In other words, the model "knows what it is looking for" before it generates a single token.
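Here is how the authors discovered this. They took samples from Ultrachat, a multi-turn instruction dataset with 1.4 million dialogues, and filtered for sequences with response length greater than 512 and prompt length greater than 3,000 tokens. Then they asked two questions: can the attention pattern of a prompt window predict which positions matter during generation, and does that pattern stay consistent as generation proceeds?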
Split the input sequence into windows of 128 tokens each. For each window, compute the average attention weights over the prefix tokens. Identify the "important" prefix positions — those with high average attention — and compare them to the positions that actually receive high attention during generation.
The result: the last window of the input sequence (the final 128 tokens before generation begins) identifies important positions with a hit rate above 85% on most layers. The last window's attention pattern is an excellent predictor of what the model will attend to during generation.
During generation, the authors split the generated tokens into windows of 128 and checked whether the important positions identified by the last input window remained important throughout generation. They do. The overlap rate between the prefill pattern and each generation window stays above 80% across layers, even 512 tokens into generation.
This is surprising. You might expect that as the model generates more tokens, its attention drifts to new parts of the context. But in practice, the core "information sources" the model identified during prefill remain the dominant attention targets throughout the response.
This is the insight that makes SnapKV possible: you do not need to run any generation tokens to know which KV pairs matter. The prefill attention pattern — specifically the pattern in the observation window — tells you everything.
The observation window is the last L_obs tokens of the prompt. It is the window through which SnapKV "observes" the attention pattern. The rest of the prompt, before the observation window, is called the prefix (length L_prefix).
Only the prefix gets compressed. The observation window's KV pairs are always kept in full — they are the most recent context and the model relies on them heavily during generation (recency bias is real).
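Concretely, if your prompt has 100,000 tokens and L_obs = 32, the prefix is the first 99,968 tokens (the only part that gets compressed) and the observation window is the last 32 tokens (always kept in full).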
The authors experiment with window sizes of 16, 32, and 64. Their default is 16 or 32 tokens. A larger window means more queries "voting" on which prefix positions are important (better signal), but it also means slightly more KV pairs kept unconditionally. In practice, 16-32 is a sweet spot: enough signal to identify important positions, small enough to be negligible in the overall budget.
You could, in principle, average attention weights across all query positions. But the key finding is that you do not need to. The last window alone achieves hit rates above 85%. Using a small window keeps the observation cost negligible — you are just reading attention weights that you already computed during prefill.
Importantly, the observation window is not a fixed set of special tokens or learned prompts. It is simply the last L_obs tokens of whatever prompt the user provides. It works across different tasks, different instructions, and different document types because the observation leverages the model's own attention mechanism rather than any task-specific heuristic.
The observation window gives us the attention weights from the last L_obs query positions over all prefix positions. But we have N attention heads, and each head might focus on different parts of the prefix. How do we decide which prefix positions to keep per head?
For a single head on a single layer, the observation window gives us a matrix W of shape (L_obs, L_prefix): each of the L_obs query tokens has a softmax-normalized attention distribution over the L_prefix prefix positions. We sum across the query dimension: C[j] = Σ_i W[i, j], with i running over the L_obs query rows.
This gives a single vote score C[j] for each prefix position j. Positions that receive high attention from many observation window tokens get a high total vote. Positions that receive high attention from only one or two tokens still get a moderate vote. Positions that nobody attends to get a low vote.
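The summation can be sketched in a few lines. This is a minimal illustration for a single head, written with plain Python lists instead of tensors to stay dependency-free; the function name `vote_scores` and the toy attention values are assumptions made for the example, not part of the SnapKV code.

```python
def vote_scores(attn):
    """Sum an (L_obs x L_prefix) attention matrix over its query rows.

    attn[i][j] is the weight that observation-window query i places on
    prefix position j. Returns one vote score per prefix position.
    """
    l_prefix = len(attn[0])
    return [sum(row[j] for row in attn) for j in range(l_prefix)]


# Hypothetical weights: 3 observation queries over 5 prefix positions.
# Rows do not sum to 1 because the window's own columns are left out.
attn = [
    [0.50, 0.05, 0.05, 0.30, 0.00],
    [0.40, 0.00, 0.10, 0.30, 0.10],
    [0.00, 0.05, 0.05, 0.60, 0.10],
]
votes = vote_scores(attn)
```

Position 3 is attended to by all three queries and collects the largest vote, while position 0 is attended to by only two and still ranks second.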
Given the vote scores, select the k positions with the highest votes: I = Top_k(C, k), where k = ⌊p × L_prefix⌋.
Here p is the compression rate, the fraction of prefix positions to keep. If p = 0.01 and L_prefix = 100,000, you keep k = 1,000 positions per head. The indices I tell you exactly which KV pairs to retain.
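As a rough sketch of the selection step, the snippet below picks the highest-voted positions for one head in plain Python. The helper name `select_top_k` is hypothetical, and rounding k down is an assumption; the returned indices are sorted only so they read in document order.

```python
import math


def select_top_k(votes, p):
    """Indices of the k = floor(p * L_prefix) highest-voted prefix positions."""
    k = math.floor(p * len(votes))
    ranked = sorted(range(len(votes)), key=lambda j: votes[j], reverse=True)
    return sorted(ranked[:k])  # report kept positions in original order


# Hypothetical vote scores for 8 prefix positions; p = 0.25 keeps k = 2.
votes = [0.9, 0.1, 0.2, 1.2, 0.2, 0.05, 0.7, 0.0]
kept = select_top_k(votes, p=0.25)
```

In a real cache this selection runs independently for every head, so different heads can keep different positions.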
Use the selected indices to gather the corresponding key and value tensors from the prefix, then concatenate with the observation window's KV pairs:

    # Per-head compressed keys and values
    k_prefix_compressed = key_states[..., :-window_size, :].gather(dim=2, index=indices)
    v_prefix_compressed = value_states[..., :-window_size, :].gather(dim=2, index=indices)
    # Observation window keys and values (always kept)
    k_obs = key_states[..., -window_size:, :]
    v_obs = value_states[..., -window_size:, :]
    # Final compressed cache
    key_states = torch.cat([k_prefix_compressed, k_obs], dim=2)
    value_states = torch.cat([v_prefix_compressed, v_obs], dim=2)

After this, the KV cache on this layer has only k + L_obs entries per head, instead of the original L_prompt. Generation proceeds as usual, but every attention computation is now over a much smaller cache.
How do you know the voting mechanism picked the right positions? The paper defines a hit rate metric. During generation, at each step, check whether the positions that receive high attention (above a threshold θ) were among the ones selected by the voting mechanism: H = Σ(M_threshold ∧ M_vote) / Σ M_threshold.
M_threshold is a binary mask of positions with attention weight above θ during generation. M_vote is the binary mask of positions selected by voting. A hit rate of 0.85 means the voting mechanism captured 85% of the positions the model actually attends to during generation.
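The metric can be illustrated for a single generation step as follows. This is a sketch in plain Python, not the paper's evaluation code; the function name `hit_rate`, the toy attention values, and the choice to return `None` when no position crosses the threshold are all assumptions.

```python
def hit_rate(gen_attn, selected, theta):
    """Fraction of above-threshold positions that the vote had selected.

    gen_attn: attention weights over prefix positions at one decode step
    selected: set of prefix indices kept by the voting step
    theta:    threshold that defines an "important" position
    """
    important = [j for j, a in enumerate(gen_attn) if a > theta]
    if not important:
        return None  # undefined when nothing crosses the threshold
    hits = sum(1 for j in important if j in selected)
    return hits / len(important)


# Hypothetical step: positions 0, 3, 4 and 6 exceed theta = 0.05,
# and the vote kept positions 0, 3 and 6.
gen_attn = [0.40, 0.01, 0.02, 0.30, 0.15, 0.02, 0.10, 0.00]
rate = hit_rate(gen_attn, selected={0, 3, 6}, theta=0.05)
```

Here three of the four important positions were kept, so the hit rate for this step is 0.75.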
[Interactive figure: heatmap of observation-window attention weights (rows) over prefix positions (columns), with the summed votes shown along the bottom and the top-k positions highlighted.]
The voting mechanism with pure top-k selection has a subtle problem. Imagine the model is trying to retrieve a phone number: "+1 (555) 867-5309". The attention pattern might strongly focus on the digits "867" and "5309" but give low attention to the parentheses, dashes, and country code. Top-k selection would keep "867" and "5309" but drop "+1 (555)" — and the model would generate an incomplete phone number.
The issue is that nearby tokens form semantic clusters. Information is rarely contained in a single token — it spans a contiguous region of tokens. If you select individual high-attention tokens without their neighbors, you lose contextual integrity.
Before applying top-k to the vote scores, SnapKV applies a 1D average pooling operation along the sequence dimension:

    import torch.nn.functional as F

    # votes: shape (num_heads, L_prefix)
    pool_vote = F.avg_pool1d(votes, kernel_size=kernel_size,
                             padding=kernel_size // 2, stride=1)
    # For odd kernel sizes, pool_vote has the same shape as votes.
    # Each position's score is now the average of its neighborhood.
    indices = pool_vote.topk(max_capacity_prompt - window_size, dim=-1).indices

With kernel_size = 5, each position's score becomes the average of itself and its two neighbors on each side. This means a position with a moderate score that is surrounded by high-score positions gets a boost. Conversely, an isolated high-attention position that is surrounded by unimportant tokens gets slightly smoothed down.
After pooling, top-k selection tends to pick contiguous clusters of positions rather than scattered individual tokens. The phone number "+1 (555) 867-5309" gets selected as a unit because the pooling kernel lifts the scores of the low-attention tokens ("+1", "(", ")") that sit between the high-attention digits.
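The clustering effect is easy to reproduce on a toy vote vector. The sketch below is plain Python with hypothetical helper names (`avg_pool_1d`, `top_k`) and made-up scores; it assumes zero padding at the edges, mirroring `padding=kernel_size // 2`, and an odd kernel size.

```python
def avg_pool_1d(scores, kernel_size):
    """Average each score with its neighbors, zero-padding the edges."""
    half = kernel_size // 2
    padded = [0.0] * half + list(scores) + [0.0] * half
    return [sum(padded[j:j + kernel_size]) / kernel_size
            for j in range(len(scores))]


def top_k(scores, k):
    """Indices of the k largest scores, in original position order."""
    order = sorted(range(len(scores)), key=lambda j: scores[j], reverse=True)
    return sorted(order[:k])


# Positions 3..7 form one span: high votes at 3, 5, 7 and low votes at the
# "punctuation" positions 4 and 6. Positions 10 and 12 are isolated blips.
votes = [0.0, 0.0, 0.0, 0.9, 0.1, 0.8, 0.1, 0.9, 0.0, 0.0, 0.3, 0.0, 0.25, 0.0]

raw_pick = top_k(votes, 5)                     # scattered tokens
pooled_pick = top_k(avg_pool_1d(votes, 5), 5)  # one contiguous span
```

Without pooling, the budget of five goes to positions 3, 5, 7, 10 and 12, leaving holes in the span. With a kernel of 5, the same budget covers positions 3 through 7 as a unit.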
| Kernel Size | Effect | Use Case |
|---|---|---|
| 1 | No pooling: pure top-k | When individual tokens carry independent facts |
| 3 | Light smoothing: small clusters | General-purpose default |
| 5 | Moderate clustering: captures phrases | QA, summarization (authors' default) |
| 7 | Aggressive clustering: captures sentences | Very long documents where context spans are wide |
The authors use kernel_size = 5 or 7 as their default, with observation window size of 16. In ablation studies (Section 5.2 of the paper), they show that pooling significantly improves retrieval accuracy on the LongEval-Lines benchmark — a task where the model must retrieve a specific value from a key-value pair buried in a long noisy context. Without pooling, SnapKV retrieves the value token but misses surrounding formatting. With pooling, it captures the full key-value pair.
Now we can assemble the complete SnapKV algorithm. It runs once, at the end of prefill, on every layer independently. There is no training, no learned parameters, no architectural change. It modifies only the KV cache.
| Parameter | Symbol | Typical value or source |
|---|---|---|
| Query states | Q ∈ ℝ^(B×N×L×d) | From the prefill forward pass |
| Key states | K ∈ ℝ^(B×N×L×d) | Full prefill KV cache |
| Value states | V ∈ ℝ^(B×N×L×d) | Full prefill KV cache |
| Observation window size | L_obs | 16 or 32 |
| Max KV cache capacity | k | 1024, 2048, or 4096 |
| Pooling kernel size | kernel_size | 5 or 7 |

    import math

    import torch
    import torch.nn.functional as F


    def snap_kv(query_states, key_states, value_states,
                window_size, max_capacity, kernel_size):
        bsz, num_heads, q_len, head_dim = query_states.shape
        assert key_states.shape[-2] == query_states.shape[-2]

        # Step 1: Short-circuit when no compression is needed
        if q_len <= max_capacity:
            return key_states, value_states

        # Step 2: Attention from the observation window to every key
        q_obs = query_states[..., -window_size:, :]
        scores = torch.matmul(q_obs, key_states.transpose(2, 3)) / math.sqrt(head_dim)
        # Causal mask inside the window: a query cannot see later window tokens
        causal = torch.full((window_size, window_size), float("-inf"),
                            dtype=scores.dtype, device=scores.device).triu(1)
        scores[..., -window_size:] += causal
        attn_weights = F.softmax(scores, dim=-1, dtype=torch.float32)

        # Step 3: Vote by summing over the query dimension (prefix columns only)
        votes = attn_weights[..., :-window_size].sum(dim=-2)

        # Step 4: Pool votes for clustering
        pool_vote = F.avg_pool1d(votes, kernel_size=kernel_size,
                                 padding=kernel_size // 2, stride=1)

        # Step 5: Select top-k prefix positions per head
        indices = pool_vote.topk(max_capacity - window_size, dim=-1).indices
        # Sort so the kept positions stay in their original order
        indices = indices.sort(dim=-1).values
        indices = indices.unsqueeze(-1).expand(-1, -1, -1, head_dim)

        # Step 6: Gather compressed prefix + observation window
        k_compress = key_states[..., :-window_size, :].gather(dim=2, index=indices)
        v_compress = value_states[..., :-window_size, :].gather(dim=2, index=indices)
        k_obs = key_states[..., -window_size:, :]
        v_obs = value_states[..., -window_size:, :]
        return (torch.cat([k_compress, k_obs], dim=2),
                torch.cat([v_compress, v_obs], dim=2))

Let us do the concrete arithmetic for a real deployment scenario. This is where SnapKV's value becomes visceral.
| Parameter | Value |
|---|---|
| Layers | 32 |
| KV heads | 32 (no GQA in LLaMA-2 7B) |
| Head dimension | 128 |
| Precision | FP16 (2 bytes per element) |
| Context length | 128,000 tokens |
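Full KV cache: 128,000 tokens × 0.5 MB per token = 62.5 GB. That is most of the memory of an A100-80GB. You cannot even batch two requests.
SnapKV with k = 1024: 1,024 KV pairs × 0.5 MB = 0.5 GB. From 62.5 GB to 0.5 GB: a 125x compression ratio. With k = 4096, you still get a 31x compression ratio.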
| Context Length | Full Cache | SnapKV (k=1024) | Compression |
|---|---|---|---|
| 16K | 7.8 GB | 0.5 GB | 15.6x |
| 32K | 15.6 GB | 0.5 GB | 31.3x |
| 64K | 31.3 GB | 0.5 GB | 62.5x |
| 128K | 62.5 GB | 0.5 GB | 125x |
| 380K | 185.5 GB | 0.5 GB | 371x |
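The arithmetic behind the table can be checked with a short script. This is a sketch under the configuration listed above (32 layers, 32 KV heads, head dimension 128, FP16); the function name `kv_cache_bytes` is hypothetical, and it assumes that "K" means 1,000 tokens and "GB" means 2^30 bytes, which is the convention that reproduces the full-cache column.

```python
def kv_cache_bytes(tokens, layers=32, kv_heads=32, head_dim=128, bytes_per_elem=2):
    """Bytes needed to cache keys and values for `tokens` positions."""
    # The leading 2 accounts for storing both a key and a value per position.
    return 2 * layers * kv_heads * head_dim * bytes_per_elem * tokens


GIB = 1024 ** 3
BUDGET = 1024  # SnapKV cache budget k, in KV pairs per head

per_token = kv_cache_bytes(1)  # half a mebibyte per token
for tokens in (16_000, 32_000, 64_000, 128_000, 380_000):
    full = kv_cache_bytes(tokens) / GIB
    snap = kv_cache_bytes(BUDGET) / GIB
    print(f"{tokens:>7} tokens: full {full:6.1f} GiB, "
          f"SnapKV {snap:.1f} GiB, ratio {tokens / BUDGET:.1f}x")
```

Swapping in a different `kv_heads` value gives a quick estimate for models that use grouped-query attention, where the cache is smaller to begin with.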
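Memory savings are not the only benefit. Decoding speed also improves because each attention computation now operates over k entries instead of L. The paper reports up to 3.6x faster decoding and 8.2x better memory efficiency than the full-cache baseline on 16K-token inputs, with decoding latency that stays roughly constant as the prompt grows.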
The constant-latency property is particularly important for production serving. With a full cache, decoding latency scales with prompt length, making latency SLAs impossible to guarantee. With SnapKV, decoding latency depends only on k (a hyperparameter you control), not on the user's prompt length.
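One way to see why latency flattens out is to count how many KV entries a head attends to at each decode step. The sketch below is only a counting proxy, not a latency measurement; the function name `kv_entries_per_step` is hypothetical, and it assumes that tokens generated after prefill are appended to the cache uncompressed.

```python
def kv_entries_per_step(prompt_len, step, budget=None):
    """KV entries one head attends to at decode step `step` (0-based).

    budget=None models the full cache; otherwise the prompt cache is
    capped at `budget` entries, as after SnapKV compression.
    """
    cached_prompt = prompt_len if budget is None else min(prompt_len, budget)
    return cached_prompt + step  # tokens generated so far are appended


for prompt_len in (16_000, 128_000):
    full = kv_entries_per_step(prompt_len, step=0)
    snap = kv_entries_per_step(prompt_len, step=0, budget=1024)
    print(f"prompt {prompt_len:>7}: full cache {full:>7} entries, SnapKV {snap} entries")
```

With the budget fixed, the count at a given step is the same for a 16K prompt and a 128K prompt, which is the property that makes a latency target practical to set.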
The paper evaluates SnapKV across multiple models, benchmarks, and settings. Here are the key results.
The classic long-context stress test: hide a specific sentence (the "needle") at a random position in a very long document (the "haystack"). The model must retrieve it. The authors extend the haystack to 380K tokens on a single A100-80GB GPU with SnapKV (k = 1024, observation window = 16, kernel = 5).
Result: LWM-Text-Chat-1M with SnapKV retrieves the needle correctly up to about 140K tokens and shows only a slight accuracy drop beyond that, out to 380K. The original implementation runs out of memory (OOM) at 33K tokens. At 380K tokens with k = 1024, that is roughly a 380x compression ratio with minimal accuracy loss.
LongBench tests long-context understanding across single-doc QA, multi-doc QA, summarization, few-shot learning, synthetic tasks, and code completion. The authors test on four models:
| Model | Full KV | SnapKV-1024 | SnapKV-2048 | SnapKV-4096 |
|---|---|---|---|---|
| LWM-Text-Chat-1M | Baseline | ≈ equal | ≈ equal | ≈ equal |
| Mistral-7B-Instruct-v0.2 | Baseline | Slight drop | ≈ equal | ≈ equal |
| LongChat-7B-v1.5-32K | Baseline | ≈ equal | ≈ equal | ≈ equal |
| Mixtral-8x7B-Instruct-v0.1 | Baseline | Slight drop | ≈ equal | ≈ equal |
The striking finding: with k = 1024 (keeping only 1024 KV pairs per head), SnapKV achieves comparable or even better performance than the full KV cache on most datasets. Some models actually improve with SnapKV-1024, likely because removing noisy, irrelevant KV pairs acts as a form of attention denoising.
The authors also evaluate on Cohere's Command-R, a production model designed for RAG with 128K context. On the Needle-in-a-Haystack test with a 128K sequence and 32x compression:
| Model | Score (Full) | Score (SnapKV) | Difference |
|---|---|---|---|
| Command-R | 9.866 | 9.819 | −0.5% |
On RAG tasks with 20K-40K token contexts and a 5-10x compression ratio, SnapKV retains 98.8% of Command-R's RAG Citation F1 score (only −1.2%) and 97.9% of the end-to-end RAG F1 score (only −2.1%).
[Interactive figure: compression ratio and accuracy as a function of the KV budget k for a 16K-token input. The sweet spot is around k = 1024 to 2048.]
SnapKV is not the first attempt to compress the KV cache. Let us compare it with the key alternatives and understand what SnapKV does differently.
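H2O (Zhang et al., 2023) maintains a fixed-size KV cache during generation by tracking cumulative attention scores and greedily evicting low-scoring KV pairs at each step. The key differences from SnapKV are when and what it compresses: H2O evicts at every decode step and targets the cache that grows during generation, whereas SnapKV compresses the prompt cache once, at the end of prefill, using the observation window's attention.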
On LongBench, SnapKV with k=1024 outperforms H2O with k=4096 on 11 out of 16 benchmarks. SnapKV gets better compression with better quality because it compresses the right thing (the prompt cache) using a better signal (the prefill attention pattern).
StreamingLLM (Xiao et al., 2023) keeps only the first few "attention sink" tokens plus a sliding window of recent tokens. Everything in the middle is discarded.
SnapKV selects positions based on what the model actually attends to, so it can keep important information regardless of where it appears in the context.
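The contrast can be made concrete with a toy example. The sketch below is a deliberately simplified, position-only caricature of a "first few plus most recent" policy next to an attention-based pick; it is not the StreamingLLM or SnapKV implementation, and the function names and vote values are assumptions for illustration.

```python
def sink_plus_window(n_tokens, n_sink, n_recent):
    """Positions kept by a 'first few + most recent' policy."""
    sinks = set(range(min(n_sink, n_tokens)))
    recent = set(range(max(0, n_tokens - n_recent), n_tokens))
    return sorted(sinks | recent)


def attention_top_k(votes, k, n_recent):
    """Top-k voted prefix positions plus the always-kept recent window."""
    n_prefix = len(votes)
    order = sorted(range(n_prefix), key=lambda j: votes[j], reverse=True)
    return sorted(order[:k]) + list(range(n_prefix, n_prefix + n_recent))


# 20-token prompt: a 16-token prefix plus a 4-token recent window.
# The "needle" sits mid-document at prefix positions 9 and 10.
votes = [0.01] * 16
votes[0], votes[9], votes[10] = 0.2, 0.8, 0.7

positional = sink_plus_window(n_tokens=20, n_sink=2, n_recent=4)
attention_based = attention_top_k(votes, k=2, n_recent=4)
```

Both policies keep six entries, but only the attention-based one retains positions 9 and 10, because the positional policy never looks at where the model is attending.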
ScissorHands (Liu et al., 2023) identifies "pivotal tokens" that exhibit consistent attention weight patterns across previous generation windows. Like H2O, it operates during generation and focuses on tokens appended during decoding. It misses the prompt cache entirely.
FastGen (Ge et al., 2023) implements a dual-phase algorithm: profile attention patterns during prompt encoding, then evict during generation based on those profiles. It is conceptually closer to SnapKV but more complex (four compression policies) and still primarily targets the generation-phase cache.
| Method | Compresses Prompt KV? | When? | Signal | Fine-tuning? |
|---|---|---|---|---|
| SnapKV | Yes | Once, end of prefill | Observation window attention + pooling | No |
| H2O | No — generation only | Every decode step | Cumulative attention | No |
| StreamingLLM | Yes (crude: drop middle) | Continuous | Position (first + last) | No |
| ScissorHands | No — generation only | Every decode step | Consistent attention | No |
| FastGen | Partial | Profile + evict | Attention profiling | No |
| TOVA | No — generation only | Every decode step | Current-step attention | No |
Because SnapKV operates only on the prompt cache and only at prefill time, it is orthogonal to generation-phase methods like H2O. You could combine SnapKV (compress the prompt cache) with H2O (manage the growing generation cache) for maximum efficiency. The paper also shows that SnapKV works with parallel decoding strategies like Medusa for additional speedups.
SnapKV requires minimal code changes. The authors report that their HuggingFace implementation modifies only a few lines of code in the attention module — specifically, adding the snap_kv function call after the prefill forward pass computes the KV cache. For vLLM, the integration is similarly lightweight: intercept the KV cache after prefill, run the voting/pooling/gather, and replace the cache before generation begins.
The key implementation detail: SnapKV operates on the already-computed attention weights and KV states from prefill. It does not require a separate forward pass or any additional model computation. The observation window attention is extracted from the prefill computation that was going to happen anyway.