From the O(n²) bottleneck to the engineering tricks that make modern LLMs fast — MHA, MQA, GQA, and FlashAttention.
You're chatting with an LLM. Short messages get instant replies. But paste in a long document and ask a question about it — suddenly the model takes 10x longer. The response trickles out word by word instead of flowing. Why?
The answer is hidden inside the attention mechanism — the mathematical core of every transformer. Attention lets each output token "look at" every input token to decide what's relevant. When there are 128 tokens, that's 128 × 128 = 16,384 comparisons. When there are 8,192 tokens, it's 8,192 × 8,192 = 67 million comparisons. The cost grows with the square of the sequence length.
This is not a minor engineering inconvenience. It is THE fundamental bottleneck in modern language models. Most of the techniques in this lesson (KV caching, multi-query and grouped-query attention, sliding windows, Flash Attention) exist because someone needed to tame this quadratic beast; multi-head attention, covered first, is the foundation they all build on.
Let's make the cost concrete. To generate a token, the model computes a dot product between that token's query and every previous token's key. If the sequence has n tokens, that's n dot products per new token. During the initial processing of the prompt (the "prefill" phase), the model does the same for every position. Total comparisons: about n × n = n² (a causal mask trims this to n(n + 1)/2, which still grows quadratically).
Doubling the sequence length doesn't double the work — it quadruples it. Let's count by hand.
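4 tokens: 4 × 4 = 16 dot products
8 tokens: 8 × 8 = 64 dot products
16 tokens: 16 × 16 = 256 dot products
1,024 tokens: 1,024 × 1,024 = 1,048,576 dot products
8,192 tokens (typical LLM context window): 8,192 × 8,192 = 67,108,864 dot products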
From 4 tokens to 8,192 tokens — a 2,048× increase in sequence length causes a 2,048² = 4,194,304× increase in dot products. That's the quadratic wall.
Each dot product between a query and a key involves d_k multiplications and d_k additions. Counting each multiply-add as one operation, the attention score computation costs approximately:

$$\text{FLOPs} \approx 2 \cdot n^2 \cdot d_k$$

The factor of 2 accounts for both the Q·Kᵀ computation and the (scores)·V multiplication, each contributing roughly n²·d_k multiply-adds. Let's plug in d_k = 512 (an illustrative value):
| Sequence Length (n) | n² | FLOPs (2·n²·512) | Ratio vs n=512 |
|---|---|---|---|
| 512 | 262,144 | 268 million | 1× |
| 1,024 | 1,048,576 | 1.07 billion | 4× |
| 2,048 | 4,194,304 | 4.29 billion | 16× |
| 4,096 | 16,777,216 | 17.2 billion | 64× |
| 8,192 | 67,108,864 | 68.7 billion | 256× |
Going from 512 to 2,048 tokens (4× longer) costs 16× more FLOPs. Going from 512 to 8,192 (16× longer) costs 256× more. This is why a 128K-context model needs special engineering tricks just to be usable.
The cost of generating one token grows linearly with its position, because each new token attends to all previous tokens; summing that linear growth over every position produces the quadratic total.
```python
def attention_flops(n, d_k=512):
    """Approximate FLOPs for one attention layer, counting a multiply-add as one FLOP."""
    qk_flops = n * n * d_k  # Q @ K^T
    sv_flops = n * n * d_k  # scores @ V
    return qk_flops + sv_flops  # total ≈ 2·n²·d_k

for n in [512, 1024, 2048, 4096, 8192]:
    flops = attention_flops(n)
    print(f"n={n:>5}: {flops/1e9:.2f}B FLOPs ({flops/attention_flops(512):.0f}x vs n=512)")

# Output:
# n=  512: 0.27B FLOPs (1x vs n=512)
# n= 1024: 1.07B FLOPs (4x vs n=512)
# n= 2048: 4.29B FLOPs (16x vs n=512)
# n= 4096: 17.18B FLOPs (64x vs n=512)
# n= 8192: 68.72B FLOPs (256x vs n=512)
```
This lesson builds your understanding of attention from the ground up, then shows you how the field has attacked the quadratic problem from every angle:
By the end, you'll understand why so many modern LLMs use grouped-query attention, what Flash Attention actually does at the hardware level, and how to calculate whether a model's KV cache will fit on your GPU. Let's start by building attention from scratch.
Before we can fix attention, we need to understand exactly what it computes. Not the high-level "tokens look at other tokens" hand-wave — the actual matrix operations, from input to output, with every intermediate number visible.
The mechanism is called scaled dot-product attention, and it was introduced in the "Attention Is All You Need" paper (Vaswani et al., 2017). It has three inputs — queries, keys, and values — and produces a weighted combination of the values, where the weights come from comparing queries to keys.
Think of it like a library lookup. You walk in with a question (your query). Each book on the shelf has a label (its key). You compare your question to every label, figure out which books are most relevant (the attention weights), then read those books (the values) in proportion to their relevance. The output is a weighted blend of the relevant books' contents.
The input to attention is a sequence of token embeddings, one vector per token. For a sequence of n tokens with embedding dimension d_model, the input is a matrix X of shape (n, d_model).
Attention doesn't work on the raw embeddings directly. Instead, it projects them through three learned weight matrices to create three separate representations:
| Name | Shape | What It Represents |
|---|---|---|
| Q (queries) | (n, d_k) | "What am I looking for?" (each token's search query) |
| K (keys) | (n, d_k) | "What do I contain?" (each token's advertised label) |
| V (values) | (n, d_v) | "What information do I carry?" (each token's payload) |
The projection matrices W_Q, W_K, and W_V are learned during training. This is crucial: the model doesn't just compare raw token embeddings. It learns what to look for (Q), how to be found (K), and what to send when found (V). These can be entirely different functions of the same input.
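A minimal NumPy sketch of the projection step, assuming toy sizes chosen only for illustration and random matrices standing in for the learned W_Q, W_K, W_V (a trained model would load these from its weights):

```python
import numpy as np

rng = np.random.default_rng(0)
n, d_model, d_k, d_v = 4, 8, 3, 3  # toy sizes, chosen for illustration

X = rng.standard_normal((n, d_model))  # one embedding row per token

# Random stand-ins for the learned projection matrices
W_Q = rng.standard_normal((d_model, d_k))
W_K = rng.standard_normal((d_model, d_k))
W_V = rng.standard_normal((d_model, d_v))

# The same input X, projected three different ways
Q = X @ W_Q  # (n, d_k): what each token is looking for
K = X @ W_K  # (n, d_k): what each token advertises
V = X @ W_V  # (n, d_v): what each token carries
print(Q.shape, K.shape, V.shape)  # (4, 3) (4, 3) (4, 3)
```

Because all three come from the same X, any difference between Q, K, and V is entirely due to the three weight matrices.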
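Once we have Q, K, and V, attention proceeds in four steps: (1) compute raw scores Q · Kᵀ, (2) scale them by √d_k, (3) apply a row-wise softmax to turn scores into attention weights, and (4) use those weights to take a weighted sum of V.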
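Written as a single expression, these four steps are the scaled dot-product attention formula from Vaswani et al. (2017), with the softmax applied row by row so that each query gets its own probability distribution over keys:

$$\text{Attention}(Q, K, V) = \operatorname{softmax}\!\left(\frac{QK^\top}{\sqrt{d_k}}\right) V$$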
Let's trace this with real numbers.
Q matrix (4 × 3) — what each token is looking for:
| Token | dim 0 | dim 1 | dim 2 |
|---|---|---|---|
| Token 0 | 1.0 | 0.0 | 1.0 |
| Token 1 | 0.0 | 1.0 | 0.0 |
| Token 2 | 1.0 | 1.0 | 0.0 |
| Token 3 | 0.0 | 0.0 | 1.0 |
K matrix (4 × 3) — what each token advertises:
| Token | dim 0 | dim 1 | dim 2 |
|---|---|---|---|
| Token 0 | 1.0 | 0.0 | 0.0 |
| Token 1 | 0.0 | 1.0 | 1.0 |
| Token 2 | 1.0 | 1.0 | 0.0 |
| Token 3 | 0.0 | 0.0 | 1.0 |
V matrix (4 × 3) — the information payload:
| Token | dim 0 | dim 1 | dim 2 |
|---|---|---|---|
| Token 0 | 0.5 | 1.0 | 0.0 |
| Token 1 | 1.0 | 0.0 | 0.5 |
| Token 2 | 0.0 | 0.5 | 1.0 |
| Token 3 | 0.5 | 0.5 | 0.5 |
Step 1: Compute Q · Kᵀ
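Each entry (i, j) is the dot product between query i and key j. Let's compute row 0 (token 0's query [1, 0, 1] against all keys): q₀·k₀ = 1·1 + 0·0 + 1·0 = 1, q₀·k₁ = 1·0 + 0·1 + 1·1 = 1, q₀·k₂ = 1·1 + 0·1 + 1·0 = 1, q₀·k₃ = 1·0 + 0·0 + 1·1 = 1.
Token 0's query matches all keys equally! Now row 1 (token 1's query [0, 1, 0]): only dim 1 is nonzero, so each score is just that key's dim 1, giving q₁·k₀ = 0, q₁·k₁ = 1, q₁·k₂ = 1, q₁·k₃ = 0.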
Token 1 strongly matches keys 1 and 2, ignores keys 0 and 3. Continuing for all rows, the full score matrix is:
| Q\K | Key 0 | Key 1 | Key 2 | Key 3 |
|---|---|---|---|---|
| Query 0 | 1.0 | 1.0 | 1.0 | 1.0 |
| Query 1 | 0.0 | 1.0 | 1.0 | 0.0 |
| Query 2 | 1.0 | 1.0 | 2.0 | 0.0 |
| Query 3 | 0.0 | 1.0 | 0.0 | 1.0 |
Note row 2: token 2's query [1,1,0] dots strongly with key 2 [1,1,0] giving a score of 2.0. Token 2 attends most to itself — a common pattern in self-attention.
Step 2: Scale by √d_k
Here d_k = 3, so we divide every score by √3 ≈ 1.732:
| Q\K | Key 0 | Key 1 | Key 2 | Key 3 |
|---|---|---|---|---|
| Query 0 | 0.577 | 0.577 | 0.577 | 0.577 |
| Query 1 | 0.000 | 0.577 | 0.577 | 0.000 |
| Query 2 | 0.577 | 0.577 | 1.155 | 0.000 |
| Query 3 | 0.000 | 0.577 | 0.000 | 0.577 |
Step 3: Softmax (row-wise)
Softmax converts each row to a probability distribution. For row 0, all values are equal (0.577), so softmax gives uniform weights: 0.25 each. For row 2, the 1.155 entry gets boosted:
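Row 2: softmax([0.577, 0.577, 1.155, 0.000]). The exponentials are e^0.577 ≈ 1.781 (twice), e^1.155 ≈ 3.173, and e^0 = 1, which sum to 7.735. Dividing each by the sum gives [0.230, 0.230, 0.410, 0.129].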
Token 2 puts 41% of its attention on itself (key 2), 23% each on keys 0 and 1, and only 13% on key 3. The full attention weight matrix:
| Q\K | Key 0 | Key 1 | Key 2 | Key 3 |
|---|---|---|---|---|
| Query 0 | 0.250 | 0.250 | 0.250 | 0.250 |
| Query 1 | 0.180 | 0.320 | 0.320 | 0.180 |
| Query 2 | 0.230 | 0.230 | 0.410 | 0.129 |
| Query 3 | 0.180 | 0.320 | 0.180 | 0.320 |
Each row sums to 1.0 — it's a probability distribution over which tokens to attend to. This is the "attention pattern." Token 0 attends uniformly. Token 1 focuses on tokens 1 and 2. Token 2 focuses on itself.
Step 4: Weighted sum of V
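Each token's output is the weighted average of all value vectors, using the attention weights. For token 2: 0.230·V₀ + 0.230·V₁ + 0.410·V₂ + 0.129·V₃ = [0.115 + 0.230 + 0 + 0.065, 0.230 + 0 + 0.205 + 0.065, 0 + 0.115 + 0.410 + 0.065] = [0.410, 0.500, 0.590].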
Token 2's output is [0.410, 0.500, 0.590]. It's a blend of all value vectors, but skewed toward V2 (its own value) because it had the highest attention weight on itself. The output "carries information" from the tokens that scored highest.
Without the √d_k scaling, the dot products grow with the dimension of the key vectors. Here's why this is a problem.
If the entries of Q and K are roughly standard normal (mean 0, variance 1) and independent, then each dot product q·k is the sum of d_k products of independent standard normals. Each product has mean 0 and variance 1. Variances of independent terms add, so the sum has mean 0 and variance d_k, and its standard deviation is √d_k (by the central limit theorem, it is also approximately normal).
For d_k = 64 (a typical head size), that means dot products have standard deviation √64 = 8. For d_k = 512, it's √512 ≈ 22.6. Values of magnitude 20+ push softmax into extreme saturation: one entry gets weight ≈ 1.0, everything else gets ≈ 0.0. The output becomes nearly a hard lookup of a single token instead of a soft blend.
Worse: softmax gradients vanish in the saturated regime. The model can't learn to adjust attention patterns if the gradients are dead. Dividing by √d_k normalizes the dot products back to unit variance, keeping softmax in its sensitive region where gradients flow.
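A quick NumPy experiment makes both claims concrete. It assumes the same idealized setup as above (independent standard-normal queries and keys; trained weights won't be exactly normal), measures the spread of raw dot products, and then compares how peaked the softmax becomes with and without √d_k scaling for one query against 8 keys. The `softmax` and `entropy` helpers are written here for illustration:

```python
import numpy as np

rng = np.random.default_rng(0)

# Part 1: the spread of q.k grows like sqrt(d_k); dividing by sqrt(d_k) restores ~1.
for d_k in [3, 64, 512]:
    q = rng.standard_normal((10_000, d_k))
    k = rng.standard_normal((10_000, d_k))
    dots = (q * k).sum(axis=1)  # 10,000 independent sample dot products
    print(f"d_k={d_k:>3}: std(q.k)={dots.std():6.2f}, sqrt(d_k)={np.sqrt(d_k):6.2f}, "
          f"std after scaling={(dots / np.sqrt(d_k)).std():.2f}")

# Part 2: saturation for one query against 8 keys at d_k = 512.
def softmax(x):
    e = np.exp(x - x.max())  # subtract the max for numerical stability
    return e / e.sum()

def entropy(p):
    # 0 for a one-hot distribution, log(8) ≈ 2.08 for uniform over 8 keys
    return -(p * np.log(p + 1e-12)).sum()

d_k = 512
query = rng.standard_normal(d_k)
keys = rng.standard_normal((8, d_k))
raw = keys @ query  # raw scores with std ≈ sqrt(512) ≈ 22.6

results = {}
for name, logits in [("unscaled", raw), ("scaled", raw / np.sqrt(d_k))]:
    p = softmax(logits)
    results[name] = (p.max(), entropy(p))
    print(f"{name:>8}: max weight={p.max():.3f}, entropy={entropy(p):.3f}")
```

The unscaled run always has the larger maximum weight and the lower entropy (a sharper, more saturated distribution); the exact values depend on the random seed.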
First, the pure NumPy implementation — every step visible:
```python
import numpy as np

def attention(Q, K, V):
    """Scaled dot-product attention from scratch."""
    d_k = Q.shape[-1]

    # Step 1: raw scores
    scores = Q @ K.T  # (n, n)

    # Step 2: scale
    scores = scores / np.sqrt(d_k)  # prevents softmax saturation

    # Step 3: softmax (row-wise)
    exp_scores = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights = exp_scores / exp_scores.sum(axis=-1, keepdims=True)

    # Step 4: weighted sum of values
    output = weights @ V  # (n, d_v)
    return output, weights

# Test with our hand-calculated example
Q = np.array([[1,0,1],[0,1,0],[1,1,0],[0,0,1]], dtype=np.float32)
K = np.array([[1,0,0],[0,1,1],[1,1,0],[0,0,1]], dtype=np.float32)
V = np.array([[.5,1,0],[1,0,.5],[0,.5,1],[.5,.5,.5]], dtype=np.float32)

out, wts = attention(Q, K, V)
print("Attention weights:\n", wts.round(3))
print("Output:\n", out.round(3))
```
Now the PyTorch one-liner equivalent:
```python
import torch
import torch.nn.functional as F

Q = torch.tensor([[1,0,1],[0,1,0],[1,1,0],[0,0,1]], dtype=torch.float32)
K = torch.tensor([[1,0,0],[0,1,1],[1,1,0],[0,0,1]], dtype=torch.float32)
V = torch.tensor([[.5,1,0],[1,0,.5],[0,.5,1],[.5,.5,.5]], dtype=torch.float32)

# PyTorch's built-in (handles batches, masking, dropout too)
out = F.scaled_dot_product_attention(
    Q.unsqueeze(0), K.unsqueeze(0), V.unsqueeze(0)
).squeeze(0)
print(out.round(decimals=3))  # matches our hand calculation
```
One attention pattern isn't enough. Consider the sentence "The cat sat on the mat because it was tired." The word "it" needs to simultaneously attend to "cat" (to resolve the pronoun) and "sat" (to understand what action "it" performed). A single set of Q, K, V projections can only produce one attention pattern — one way of routing information.
Multi-head attention solves this by running multiple attention operations in parallel, each with its own learned Q, K, V projections. Each head can specialize: one head might focus on local syntax (nearby words), another on long-range semantic connections (pronouns to antecedents), another on positional patterns (sentence boundaries). The outputs are concatenated and projected back to the original dimension.
Here's the key insight that makes multi-head attention efficient: we don't add more parameters for more heads. We split the existing dimension across heads.
If d_model = 512 and we want H = 8 heads, each head gets d_k = d_v = d_model / H = 512 / 8 = 64 dimensions.
Each head operates on 64-dimensional queries, keys, and values instead of 512-dimensional ones. Since the total dimension across all heads is still 8 × 64 = 512, the total computation is the same as a single head with d_k = 512.
Let's trace the exact tensor shapes through multi-head attention for a concrete example: batch size B = 1, sequence length n = 6, d_model = 512, H = 8 heads.
(PyTorch provides a built-in version of this layer as nn.MultiheadAttention.)
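Step 1: Project Q, K, V. The input X of shape (1, 6, 512) is multiplied by W_Q, W_K, and W_V, giving Q, K, and V, each of shape (1, 6, 512).
Step 2: Reshape into heads. Each of Q, K, V goes (1, 6, 512) → (1, 6, 8, 64) → (1, 8, 6, 64) after swapping the sequence and head axes.
Step 3: Attention per head. Q · Kᵀ gives scores of shape (1, 8, 6, 6), one 6 × 6 matrix per head; after the softmax, weights · V gives (1, 8, 6, 64).
Step 4: Concatenate heads. (1, 8, 6, 64) → (1, 6, 8, 64) → (1, 6, 512).
Step 5: Output projection. Multiplying by W_O maps (1, 6, 512) → (1, 6, 512), the same shape as the input.
Each projection matrix maps d_model to d_model (we project all heads at once, then reshape), so W_Q, W_K, W_V, and W_O each hold 512 × 512 = 262,144 parameters, ignoring biases.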
Total: 4 × 262,144 = 1,048,576 parameters (about 1 million).
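Now the crucial comparison. What if we used a single head with d_k = 512? The four projection matrices are still 512 × 512 each, for the same 1,048,576 parameters.
The parameter count is the same because we just split the projection into smaller pieces per head. The compute cost is also essentially the same: the projections are identical, and 8 heads × n² × 64 equals 1 head × n² × 512 for the score and weighted-sum steps. The one thing that does grow is the number of n × n attention-weight matrices, one per head. Apart from that, multi-head attention is close to a free lunch: more expressive attention patterns at almost no extra cost.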
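The arithmetic behind that comparison fits in a few lines. This sketch counts multiply-adds for the score and weighted-sum steps only (the projections are identical in both layouts), with an illustrative n = 1,024; the helper names are made up for this example:

```python
def attention_core_macs(n, n_heads, d_head):
    """Multiply-adds for Q·Kᵀ plus weights·V, summed over all heads."""
    return n_heads * (n * n * d_head + n * n * d_head)

def projection_params(d_model):
    """W_Q, W_K, W_V, W_O, each d_model × d_model (biases ignored)."""
    return 4 * d_model * d_model

n, d_model = 1024, 512
multi = attention_core_macs(n, n_heads=8, d_head=d_model // 8)  # 8 heads of 64 dims
single = attention_core_macs(n, n_heads=1, d_head=d_model)      # 1 head of 512 dims
print(f"8 heads: {multi:,} MACs, 1 head: {single:,} MACs")
print(f"projection params (either layout): {projection_params(d_model):,}")

# What does grow: one n x n attention-weight matrix per head
print(f"attention-weight entries: 8 heads = {8 * n * n:,}, 1 head = {n * n:,}")
```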
In practice, different heads learn strikingly different attention patterns. Analyses of trained transformers (Voita et al., 2019; Clark et al., 2019), along with later interpretability work on induction heads (Olsson et al., 2022), have identified several head types:
| Head Type | Pattern | What It Does |
|---|---|---|
| Positional | Attend to previous/next token | Local syntax, bigram patterns |
| Syntactic | Attend to subject-verb pairs | Long-range grammar agreement |
| Semantic | Attend to coreferent tokens | Pronoun resolution, entity tracking |
| Delimiter | Attend to separator/BOS tokens | "No-op" head: dump unused attention |
| Induction | Attend to token after previous match | In-context learning (copy patterns) |
With only 1 head, the model must compromise — one attention pattern for all these functions. With 8+ heads, each can specialize. This is why multi-head attention is universal in modern transformers.
```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, n_heads):
        super().__init__()
        assert d_model % n_heads == 0
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_k = d_model // n_heads  # 512 // 8 = 64

        # Project all heads at once for efficiency
        self.W_q = nn.Linear(d_model, d_model)  # (512 → 512)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)

    def forward(self, x):
        B, N, D = x.shape  # (batch, seq_len, d_model)

        # Step 1: Project
        Q = self.W_q(x)  # (B, N, 512)
        K = self.W_k(x)
        V = self.W_v(x)

        # Step 2: Reshape into heads
        # (B, N, 512) → (B, N, 8, 64) → (B, 8, N, 64)
        Q = Q.view(B, N, self.n_heads, self.d_k).transpose(1, 2)
        K = K.view(B, N, self.n_heads, self.d_k).transpose(1, 2)
        V = V.view(B, N, self.n_heads, self.d_k).transpose(1, 2)

        # Step 3: Scaled dot-product attention (per head)
        scores = Q @ K.transpose(-2, -1) / (self.d_k ** 0.5)
        weights = F.softmax(scores, dim=-1)  # (B, 8, N, N)
        attn_out = weights @ V               # (B, 8, N, 64)

        # Step 4: Concatenate heads
        # (B, 8, N, 64) → (B, N, 8, 64) → (B, N, 512)
        attn_out = attn_out.transpose(1, 2).contiguous().view(B, N, D)

        # Step 5: Output projection
        return self.W_o(attn_out)  # (B, N, 512)

# Verify shapes
mha = MultiHeadAttention(d_model=512, n_heads=8)
x = torch.randn(1, 6, 512)  # (batch=1, seq=6, dim=512)
out = mha(x)
print(out.shape)  # torch.Size([1, 6, 512]) (same as input)
print(f"Total params: {sum(p.numel() for p in mha.parameters()):,}")
# Total params: 1,050,624 (≈1M + biases)
```
During training, all tokens are processed in parallel — attention sees the full sequence at once. This is efficient because GPUs excel at large batch matrix multiplies. But during generation, we produce ONE token at a time. Each new token needs to attend to ALL previous tokens.
Without any tricks, generating token number 1000 would require recomputing the key and value vectors for all 999 previous tokens — even though those tokens haven't changed since the last step. That's a colossal waste.
The solution is the KV cache: store every previous token's key and value vectors in memory, and reuse them on each generation step. This turns attention from O(n²) recomputation into O(n) new work per step — but at a steep memory cost.
Let's trace exactly what happens when an LLM generates a 5-token response to a 3-token prompt. We'll call the prompt tokens [A, B, C] and the generated tokens [D, E, F, G, H].
Step 1 (prefill): Process prompt [A, B, C]
Step 2: Generate token D
Step 3: Generate token E
Without caching, each step recomputes K and V for every token seen so far. The total K/V computations across all steps: 3 + 4 + 5 + 6 + 7 = 25. With caching: 3 + 1 + 1 + 1 + 1 = 7. That's a 3.6× savings for just 5 generated tokens. For 1,000 generated tokens with a 1,000-token prompt, the same count gives 1,499,500 versus 1,999 computations, roughly a 750× savings.
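The counting above can be reproduced with a small helper (`kv_projections` is a hypothetical name for illustration). It follows the example's convention: prefill processes the whole prompt and yields the first new token, and each later step adds one token to the context:

```python
def kv_projections(prompt_len, n_generated):
    """Count K/V computations with and without a KV cache."""
    # Tokens in context at each step: the prefill, then one more token per step
    context_sizes = [prompt_len + step for step in range(n_generated)]
    without_cache = sum(context_sizes)           # recompute K/V for every token, every step
    with_cache = prompt_len + (n_generated - 1)  # prefill once, then 1 new token per step
    return without_cache, with_cache

for prompt_len, n_gen in [(3, 5), (1000, 1000)]:
    no_cache, cache = kv_projections(prompt_len, n_gen)
    print(f"prompt={prompt_len}, generated={n_gen}: "
          f"{no_cache:,} vs {cache:,} ({no_cache / cache:.1f}x fewer with cache)")
```

The savings grow with length because the uncached total is quadratic in the number of steps while the cached total is linear.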
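For each token in the sequence, we store one key vector and one value vector per layer, per head. The cache size per token:

$$\text{bytes per token} = 2 \times n_{\text{layers}} \times n_{\text{heads}} \times d_{\text{head}} \times \text{bytes per value}$$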
The factor of 2 is for K and V (we cache both). Let's compute this for Llama 2 7B:
| Parameter | Llama 2 7B Value |
|---|---|
| n_layers | 32 |
| n_heads | 32 |
| d_head | 128 (= 4096 / 32) |
| Precision | FP16 (2 bytes per value) |
Bytes per token: 2 × 32 × 32 × 128 × 2 = 524,288 bytes = 512 KiB (about 0.5 MB).
At different sequence lengths: 2,048 tokens → 1.07 GB; 4,096 tokens → 2.15 GB; 8,192 tokens → 4.29 GB.
Now here's where it gets alarming. The model weights themselves for Llama 2 7B in FP16 are about 14 GB. That's fixed — load once, serve forever. But the KV cache is per user, per request.
For a batch of 32 concurrent users at 4,096 context: 32 × 2.15 GB ≈ 68.7 GB of KV cache, on top of 14 GB of weights.
The KV cache alone needs 5× more memory than the model weights. This is why you can't "just add more users" to a GPU — each concurrent user adds ~2 GB of KV cache, and GPU memory is finite.
| Model | Weights (FP16) | KV Cache / User @4K | Max Users on H100 (80 GB) |
|---|---|---|---|
| Llama 2 7B | 14 GB | 2.15 GB | ~30 |
| Llama 2 13B | 26 GB | 3.36 GB | ~16 |
| Llama 2 70B (if it used MHA) | 140 GB | 10.7 GB | 0 (needs multi-GPU) |
For the 70B model, the weights alone don't fit on a single H100. Even with tensor parallelism across 2 GPUs (160 GB total), after loading the 140 GB model you have only 20 GB left, not even enough for 2 concurrent users at 4K context with full multi-head attention. (The real Llama 2 70B avoids this by caching only 8 K/V heads per layer, a technique covered below.) This is why KV cache optimization, the topic of the next chapters, is one of the most active areas of LLM systems engineering.
```python
def kv_cache_size_gb(
    n_layers,
    n_heads,            # K/V heads cached per layer
    d_head,
    seq_len,
    n_users=1,
    bytes_per_value=2,  # FP16
):
    """Calculate KV cache memory in decimal GB (1 GB = 1e9 bytes)."""
    bytes_per_token = 2 * n_layers * n_heads * d_head * bytes_per_value  # 2 = K and V
    total_bytes = bytes_per_token * seq_len * n_users
    return total_bytes / 1e9

# Llama 2 family
models = {
    "7B": (32, 32, 128),   # (layers, K/V heads, d_head)
    "13B": (40, 40, 128),
    "70B": (80, 64, 128),  # 64 K/V heads = MHA-equivalent; the real 70B caches 8 (GQA)
}

print(f"{'Model':>6} {'SeqLen':>7} {'Users':>6} {'KV Cache':>10} {'Weights':>10} {'Total':>10}")
print("-" * 55)
for name, (nl, nh, dh) in models.items():
    weight_gb = {"7B": 14, "13B": 26, "70B": 140}[name]
    for seq in [2048, 4096, 8192]:
        for users in [1, 32]:
            kv = kv_cache_size_gb(nl, nh, dh, seq, users)
            total = kv + weight_gb
            print(f"{name:>6} {seq:>7} {users:>6} {kv:>9.2f}G {weight_gb:>9}G {total:>9.1f}G")

# Selected output lines (spacing condensed):
# 7B   4096  32   68.72G   14G   82.7G  ← doesn't fit on one H100!
# 70B  4096   1   10.74G  140G  150.7G  ← needs 2+ GPUs
# 70B  4096  32  343.60G  140G  483.6G  ← needs 7+ H100s
```
The KV cache bottleneck explains almost every engineering decision in modern LLM deployment:
The rest of this lesson dives into each of these techniques. Every one of them exists because someone looked at the KV cache memory equation and found a term to shrink.
Chapter 3 showed us the KV cache is the memory bottleneck. Every attention head stores its own K and V vectors for every token generated so far. For a model with 32 heads, that's 32 separate copies of keys and 32 separate copies of values, all sitting in GPU memory, all growing with every new token.
Here's the first clever fix: what if all 32 attention heads shared a single set of K and V vectors? Each head still asks its own question (different Q projections), but they all search through the same keys and retrieve from the same values. That's Multi-Query Attention (MQA), proposed by Noam Shazeer in 2019.
The idea sounds radical — surely different heads need different keys? But remember: Q is the "question," and K is the "index." If every head asks a different question, they'll get different attention patterns even against the same set of keys. The queries are still fully expressive; we're just sharing the lookup table.
In standard Multi-Head Attention (MHA), the input X gets projected into Q, K, and V per head. If there are H heads and each head has dimension d_h, the K and V projection matrices are each of size [d_model × (H · d_h)]. That's one d_h-dimensional key vector and one d_h-dimensional value vector per head, per token.
In MQA, the K and V projection matrices shrink to [d_model × d_h]: just one d_h-dimensional key vector and one d_h-dimensional value vector, shared across all H heads. The Q projection stays the same: each head still gets its own query.
Let's get concrete with Llama 2 7B numbers. The model has 32 layers, each with 32 attention heads of d_head = 128. The KV cache is stored in float16 (2 bytes per value) and holds K and V for every token generated so far.
MHA KV cache per token: 2 (K and V) × 32 layers × 32 heads × 128 × 2 bytes = 524,288 bytes = 512 KiB
MQA KV cache per token: 2 × 32 layers × 1 shared head × 128 × 2 bytes = 16,384 bytes = 16 KiB
That's a 32× reduction — exactly equal to the number of heads.
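Now scale it up. With a context window of 4,096 tokens, MHA needs 512 KiB × 4,096 = 2 GiB per sequence, while MQA needs 16 KiB × 4,096 = 64 MiB.
Now consider serving 64 concurrent users (a modest production workload): MHA needs 64 × 2 GiB = 128 GiB of KV cache, more than an 80 GB H100 holds, while MQA needs 64 × 64 MiB = 4 GiB.
For models with more heads the savings are larger still. PaLM 540B has 48 heads and uses MQA, which shrinks its KV cache by 48× relative to MHA. At that scale, the KV cache largely determines how many requests, and how much context, a serving deployment can hold at once.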
The obvious question: if all heads share the same K and V, don't they all attend to the same things? Don't we lose the diversity that makes multi-head attention powerful?
Surprisingly, no (or at least, not much). Each head still has its own Q projection, which means each head asks a different question of the shared keys. Head 1 might project its query to emphasize syntactic features; head 12 might emphasize positional features. They'll compute different attention weights, softmax(Q_h Kᵀ / √d_h), because Q_h differs, even though K is the same for all.
Empirically, MQA drops quality by roughly 0.5–1% on standard benchmarks compared to MHA. That's a tiny cost for a 32× memory reduction. The tradeoff is overwhelmingly worthwhile for production inference.
Here's how MQA modifies the standard MHA forward pass. The key change is in the projection dimensions: K and V are projected to d_head (not H × d_head), then broadcast across all heads during the attention computation.
```python
import torch
import torch.nn as nn

class MultiQueryAttention(nn.Module):
    def __init__(self, d_model, n_heads):
        super().__init__()
        self.n_heads = n_heads
        self.d_head = d_model // n_heads

        # Q: per-head projections (standard)
        self.W_q = nn.Linear(d_model, d_model)      # [d, H*d_h]
        # K, V: SINGLE head projection (the MQA trick)
        self.W_k = nn.Linear(d_model, self.d_head)  # [d, d_h]
        self.W_v = nn.Linear(d_model, self.d_head)  # [d, d_h]
        self.W_o = nn.Linear(d_model, d_model)

    def forward(self, x):
        B, T, D = x.shape
        H, d_h = self.n_heads, self.d_head

        # Project Q, K, V
        q = self.W_q(x).view(B, T, H, d_h).transpose(1, 2)  # [B, H, T, d_h]
        k = self.W_k(x).unsqueeze(1)                         # [B, 1, T, d_h]
        v = self.W_v(x).unsqueeze(1)                         # [B, 1, T, d_h]

        # k and v broadcast across H heads automatically
        scores = (q @ k.transpose(-2, -1)) / (d_h ** 0.5)   # [B, H, T, T]
        attn = scores.softmax(dim=-1)
        out = (attn @ v).transpose(1, 2).reshape(B, T, D)   # [B, T, D]
        return self.W_o(out)
```
The critical lines are the W_k and W_v projections: they output d_head dimensions (one head's worth), then unsqueeze(1) adds the head dimension so PyTorch's broadcasting does the rest. Each of the H heads computes Q_h Kᵀ using the same K, without any copying.
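The "without any copying" claim can be checked directly. This NumPy sketch (NumPy's `@` broadcasts leading dimensions the same way PyTorch's does; the sizes are arbitrary) compares the broadcast product against an explicitly expanded K, and confirms that the expanded K is a zero-stride view of the single shared head rather than a copy:

```python
import numpy as np

rng = np.random.default_rng(0)
B, H, T, d_h = 2, 4, 5, 8                # arbitrary toy sizes
q = rng.standard_normal((B, H, T, d_h))  # one query per head
k = rng.standard_normal((B, 1, T, d_h))  # ONE shared key head (MQA)

# Broadcasting: the size-1 head axis of k is stretched across all H heads
broadcast = q @ k.transpose(0, 1, 3, 2)  # (B, H, T, T)

# Explicit version: view k as if it had H heads (no data copied)
k_view = np.broadcast_to(k, (B, H, T, d_h))
explicit = q @ k_view.transpose(0, 1, 3, 2)

print(broadcast.shape)                                 # (2, 4, 5, 5)
print(np.allclose(broadcast, explicit))                # True
print(k_view.strides[1], np.shares_memory(k_view, k))  # 0 True
```

A stride of 0 on the head axis means every head reads the same bytes in memory.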
MQA is aggressive. All heads share a single set of keys and values. For most large models, the quality drop is negligible. But some teams found it too steep — especially for models in the 7B–13B range where every fraction of a percent on benchmarks matters for competitive positioning.
What if we could dial the tradeoff between MHA and MQA? Instead of giving every head its own KV (MHA) or sharing one KV across all heads (MQA), what if we put heads into groups and shared one KV per group?
That's Grouped-Query Attention (GQA), introduced by Ainslie et al. at Google in 2023. Think of it as a knob: turn it all the way left and you get MQA (one group). Turn it all the way right and you get MHA (every head is its own group). GQA sits anywhere in between.
Say your model has H = 32 query heads and you choose G = 8 KV groups. That means every group of 32/8 = 4 query heads shares a single K and V. The model has 8 K heads and 8 V heads, instead of 32 (MHA) or 1 (MQA).
During attention, heads 0–3 all use KV group 0. Heads 4–7 use KV group 1. And so on. Within each group, the mechanism is identical to MQA: different queries, shared keys and values.
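Llama 2 70B uses GQA with 64 query heads and 8 KV groups. Each head has d_head = 128. Let's compute the KV cache savings for a single layer (the full model has 80 layers, so every per-layer figure below gets multiplied by 80).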
Heads per group: 64 query heads / 8 groups = 8 query heads per KV group.
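GQA-8 KV cache per token: 2 (K and V) × 8 groups × 128 × 2 bytes = 4,096 bytes per layer
MHA (for comparison) KV cache per token: 2 × 64 heads × 128 × 2 bytes = 32,768 bytes per layer
MQA (for comparison) KV cache per token: 2 × 1 × 128 × 2 bytes = 512 bytes per layer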
GQA-8 gives an 8× reduction over MHA. Not as extreme as MQA's 64×, but with much better quality retention. Here's the per-layer comparison at a 4,096-token context:
| Method | KV Groups | KV/Token (per layer) | KV @ 4K ctx (per layer) | 64 Users (per layer) |
|---|---|---|---|---|
| MHA | 64 | 32,768 B | 128 MB | 8.0 GB |
| GQA-8 | 8 | 4,096 B | 16 MB | 1.0 GB |
| MQA | 1 | 512 B | 2 MB | 128 MB |
Keep in mind that these are per-layer figures. Across all 80 layers of Llama 2 70B, 64 concurrent users at 4K context would need about 640 GiB of KV cache with MHA (more than eight 80 GB H100s hold, before the 140 GB of weights), 80 GiB with GQA-8, and 10 GiB with MQA. The sweet spot depends on your model size and serving constraints, but for most 70B-class models, GQA-8 is the winner.
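A small helper (hypothetical name) reproduces both the per-layer table and the full-model totals. It assumes the Llama 2 70B shape given above (80 layers, d_head = 128, FP16) and varies only the number of K/V heads:

```python
def kv_bytes_per_token(n_layers, n_kv_heads, d_head=128, bytes_per_value=2):
    """KV cache bytes per token: K and V for every layer and every K/V head."""
    return 2 * n_layers * n_kv_heads * d_head * bytes_per_value

GiB = 1024 ** 3
ctx, users = 4096, 64
for name, kv_heads in [("MHA", 64), ("GQA-8", 8), ("MQA", 1)]:
    per_layer = kv_bytes_per_token(1, kv_heads)
    full_model = kv_bytes_per_token(80, kv_heads)
    print(f"{name:>5}: {per_layer:>6,} B/token/layer, "
          f"{full_model * ctx * users / GiB:6.1f} GiB for {users} users @ {ctx} tokens")
```

Changing only `n_kv_heads` is the whole difference between the three methods; everything else in the formula stays fixed.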
Here's a remarkable fact: you can convert an already-trained MHA model to GQA without retraining from scratch. The recipe is simple: take the K and V weight matrices from each group of heads and average them.
Say you have 64 heads and want 8 groups. Heads 0–7 form group 0. Take their 8 K weight matrices, average them element-by-element, and use that average as the single K matrix for the group. Do the same for V. Repeat for all 8 groups.
Then continue training the model briefly so it can adapt to the shared parameters. Ainslie et al. call this uptraining: starting from existing MHA checkpoints (T5 models), they mean-pooled the K and V heads this way and continued pre-training for about 5% of the original compute, recovering quality close to MHA. (Llama 2 70B, by contrast, was trained with GQA from the start.)
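The averaging step can be sketched in NumPy. This assumes a K (or V) projection stored as a (d_model, n_heads · d_head) matrix applied as x @ W, with head h owning columns h·d_head through (h+1)·d_head, which is the layout the head-splitting `.view(...)` reshapes in this lesson rely on. `mean_pool_kv_heads` is a hypothetical helper, not a library function:

```python
import numpy as np

def mean_pool_kv_heads(W, n_heads, n_groups):
    """Average the per-head column blocks of a K or V projection within each group.

    W: (d_model, n_heads * d_head), applied as x @ W; head h owns columns
    h*d_head : (h+1)*d_head. Heads [g*hpg, (g+1)*hpg) form group g.
    """
    d_model, width = W.shape
    d_head = width // n_heads
    heads_per_group = n_heads // n_groups
    blocks = W.reshape(d_model, n_groups, heads_per_group, d_head)
    return blocks.mean(axis=2).reshape(d_model, n_groups * d_head)

# Toy example: 8 heads of width 4, pooled into 2 groups
rng = np.random.default_rng(0)
W_k = rng.standard_normal((16, 8 * 4))
W_k_gqa = mean_pool_kv_heads(W_k, n_heads=8, n_groups=2)
print(W_k.shape, "->", W_k_gqa.shape)  # (16, 32) -> (16, 8)
```

The pooled matrix is only a starting point; the short continued-training phase is what lets the model adjust to sharing it.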
The key operation in GQA is the group repeat: after projecting K and V to G groups, we need to "expand" them so each query head has matching K and V tensors. PyTorch handles this with repeat_interleave or manual reshaping.
```python
import torch
import torch.nn as nn

class GroupedQueryAttention(nn.Module):
    def __init__(self, d_model, n_heads, n_kv_groups):
        super().__init__()
        self.n_heads = n_heads
        self.n_kv = n_kv_groups
        self.d_head = d_model // n_heads
        self.heads_per_group = n_heads // n_kv_groups

        # Q: full per-head projections
        self.W_q = nn.Linear(d_model, n_heads * self.d_head)
        # K, V: only n_kv_groups projections
        self.W_k = nn.Linear(d_model, n_kv_groups * self.d_head)
        self.W_v = nn.Linear(d_model, n_kv_groups * self.d_head)
        self.W_o = nn.Linear(d_model, d_model)

    def repeat_kv(self, x):
        # x: [B, G, T, d_h] -> [B, H, T, d_h]
        # Repeat each KV group heads_per_group times
        B, G, T, d = x.shape
        x = x[:, :, None, :, :].expand(B, G, self.heads_per_group, T, d)  # view, no copy
        return x.reshape(B, self.n_heads, T, d)  # merging G with the repeats copies (G > 1)

    def forward(self, x):
        B, T, D = x.shape
        d_h = self.d_head
        q = self.W_q(x).view(B, T, self.n_heads, d_h).transpose(1, 2)
        k = self.W_k(x).view(B, T, self.n_kv, d_h).transpose(1, 2)
        v = self.W_v(x).view(B, T, self.n_kv, d_h).transpose(1, 2)

        # Expand K, V to match query head count
        k = self.repeat_kv(k)  # [B, H, T, d_h]
        v = self.repeat_kv(v)  # [B, H, T, d_h]

        scores = (q @ k.transpose(-2, -1)) / (d_h ** 0.5)
        attn = scores.softmax(dim=-1)
        out = (attn @ v).transpose(1, 2).reshape(B, T, D)
        return self.W_o(out)
```
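The repeat_kv method is the core of GQA: it takes the G-grouped K or V tensor and repeats each group H/G times so the shapes align for the batched matrix multiply. Note that expand only creates a view; the reshape that follows generally has to materialize a full H-head copy during the attention computation. The memory savings are still real where they matter: the KV cache stores only the G groups, and the expanded tensor is temporary. Fused kernels such as FlashAttention-2 accept fewer K/V heads than query heads directly and skip the expansion altogether.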
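A NumPy analogue of `repeat_kv` (a hypothetical standalone function with toy sizes, for illustration only) makes both points visible: head h reads group h // heads_per_group, and while the broadcast itself is a view, merging the group and repeat axes forces a copy. Only the small G-group tensor ever needs to live in the cache:

```python
import numpy as np

def repeat_kv(x, n_heads):
    """(B, G, T, d) -> (B, n_heads, T, d), repeating each group consecutively."""
    B, G, T, d = x.shape
    heads_per_group = n_heads // G
    view = np.broadcast_to(x[:, :, None], (B, G, heads_per_group, T, d))  # no copy
    return view.reshape(B, n_heads, T, d)  # for G > 1 this reshape must copy

rng = np.random.default_rng(0)
kv = rng.standard_normal((1, 2, 3, 4))  # B=1, G=2 groups, T=3, d=4
expanded = repeat_kv(kv, n_heads=8)     # 4 query heads per group

# Heads 0-3 read group 0, heads 4-7 read group 1
print(all(np.array_equal(expanded[:, h], kv[:, h // 4]) for h in range(8)))  # True
print(np.shares_memory(expanded, kv))  # False: the expanded tensor is a fresh copy
```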
MQA and GQA attacked the KV cache size per token — shrinking how much memory each position needs. But there's a second dimension to the problem: the number of tokens each position attends to. In standard attention, every token looks at every previous token. That's the n×n attention matrix that makes the whole thing quadratic.
What if each token only looked at the W nearest tokens instead of all n? That's Sliding Window Attention — also called local attention. It caps the attention cost per token at W, making the total cost O(n · W) instead of O(n²). When the sequence is long and the window is moderate, the savings are enormous.
This isn't a new idea — local attention patterns appeared in research as early as 2019. But it became famous when Mistral 7B shipped with sliding window attention as a key architectural choice, showing that you could match or beat much larger models by being smart about what each token actually needs to see.
Think about how language works. When you read the word "running" in "the dog is running through the park," you need the nearby tokens to understand it. You almost never need token #47 from 10,000 positions ago. Language is overwhelmingly local: subjects, verbs, and objects cluster together. Pronouns reference things within a few sentences. Semantic coherence is maintained over paragraphs, not pages.
There are exceptions — a reference to a character introduced chapters ago, or a variable defined hundreds of lines up in code. But these long-range dependencies are rare compared to local ones. Sliding window attention bets that most attention mass lands in a local neighborhood, and that bet pays off.
Visualize the attention matrix as a grid. Rows are query positions (tokens asking questions). Columns are key positions (tokens being attended to). In standard causal attention, the lower-left triangle is filled — each token can attend to itself and all previous tokens.
In sliding window attention with window size W, each token can only attend to the W most recent tokens (including itself). The filled region shrinks from a triangle to a diagonal band of width W. Tokens outside the band get attention weight zero.
Compare mask patterns: full causal, sliding window, and hybrid (Longformer-style with global tokens). Each cell is one attention entry.
Let's count the number of attention entries (cells in the matrix) for different configurations.
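Full causal attention, n = 16,384:
The lower triangle has n · (n+1) / 2 entries. For n = 16,384: 16,384 × 16,385 / 2 = 134,225,920 ≈ 134M entries.
Sliding window, n = 16,384, W = 4,096:
Each of the first W tokens attends to all previous tokens (building up to the full window). After that, each token attends to exactly W tokens. A close upper bound: n · W = 16,384 × 4,096 = 67,108,864 ≈ 67M entries, about half of full attention.
At n = 32,768 with W = 4,096: full attention needs 32,768 × 32,769 / 2 ≈ 537M entries, while sliding window needs at most 32,768 × 4,096 ≈ 134M, a 4× reduction.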
The longer the sequence relative to the window, the bigger the savings. At n = 131,072 with W = 4,096: full attention requires ∼8.6 billion entries, while sliding window needs ∼537 million. That's a 16× reduction. The savings scale as n / (2W) for large n.
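A short sketch (the helper names are hypothetical) that counts attention entries exactly rather than with the n · W upper bound. Token i, counting from 0, attends to min(i + 1, W) positions.

```python
def causal_entries(n):
    # Lower triangle including the diagonal: 1 + 2 + ... + n
    return n * (n + 1) // 2


def sliding_window_entries(n, W):
    # Token i (0-indexed) attends to min(i + 1, W) positions
    if n <= W:
        return causal_entries(n)
    ramp = W * (W + 1) // 2   # first W tokens build up to a full window
    steady = (n - W) * W      # every later token sees exactly W tokens
    return ramp + steady


for n in (16_384, 32_768, 131_072):
    full = causal_entries(n)
    local = sliding_window_entries(n, 4_096)
    print(f"n={n:>7,}: full={full:>13,}  window={local:>11,}  ratio={full / local:.1f}x")
```

The exact ratios come out near 2.3×, 4.3×, and 16.3×, slightly above the n / (2W) estimate, because the first W tokens see fewer than W positions.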
The obvious worry: if each token only sees W positions, doesn't the model lose access to information beyond the window? How can it handle a question like "Who was the main character introduced in chapter 1?" when you're 50,000 tokens later?
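The answer is layer stacking. Each layer extends the effective reach by W positions. Think of it like a relay race: in layer 1, a token gathers information from the W positions before it; in layer 2, it attends to tokens that already gathered from their own windows, so its reach doubles to about 2W; after k layers, the reach is about k · W.
Mistral 7B has 32 layers and window size W = 4,096. Its effective receptive field is 32 × 4,096 = 131,072 tokens.
That's 131K tokens, far beyond the model's actual context window. The sliding window doesn't cut off information flow; it makes information travel through the network layer by layer rather than arriving all at once. This is a theoretical reach: information relayed across many hops can be diluted along the way, but in practice the trade-off works well for most tasks.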
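The relay argument is easy to check on a toy example. This sketch (hypothetical helper names, toy sizes) composes a sliding-window mask with itself once per layer and measures how far back the last token can reach. Because the window includes the token itself, each hop actually spans W − 1 earlier positions, so k layers reach k · (W − 1) positions back, slightly under the "W positions per layer" rule of thumb.

```python
import torch


def sliding_window_mask(seq_len, window_size):
    # True where key j lies in [i - window_size + 1, i] for query i
    i = torch.arange(seq_len).unsqueeze(1)
    j = torch.arange(seq_len).unsqueeze(0)
    return (i - j >= 0) & (i - j < window_size)


def max_lookback(seq_len, window_size, num_layers):
    """How many positions back the last token can draw information from."""
    hop = sliding_window_mask(seq_len, window_size).float()
    reach = torch.eye(seq_len)
    for _ in range(num_layers):
        # reach[i, j] > 0 iff some chain of attention hops links query i to key j
        reach = ((hop @ reach) > 0).float()
    earliest = reach[-1].nonzero().min().item()
    return (seq_len - 1) - earliest


for layers in (1, 2, 3, 4):
    print(layers, max_lookback(64, 8, layers))  # reach grows by W - 1 = 7 per layer
```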
Sliding window attention enables another powerful optimization: a rolling buffer KV cache. In standard attention, the KV cache grows with every new token and never shrinks. At position 10,000, you're storing KV entries for all 10,000 previous positions.
With a sliding window, token 10,000 only attends to positions [5,905, 10,000] (for W = 4,096). Positions before 5,905 are outside every current and future token's window. They will never be attended to again. So we can evict them from the cache.
The result: the KV cache is bounded at W entries, regardless of how long the generation runs. Instead of KV cache growing linearly with sequence length, it stays constant. For long-form generation (writing a novel, processing a codebase), this is transformative.
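A minimal sketch of a rolling buffer cache, assuming a single head and keys that already carry their positional information (for example, rotary embeddings applied before caching). The class and method names are hypothetical. Token t is written to slot t mod W, overwriting the entry that just fell out of the window. Slot order doesn't matter for attention, because a softmax-weighted sum is unchanged when keys and their values are permuted together.

```python
import torch


class RollingKVCache:
    """Fixed-size KV cache for sliding-window attention (hypothetical helper)."""

    def __init__(self, window_size, d_head):
        self.W = window_size
        self.k = torch.zeros(window_size, d_head)
        self.v = torch.zeros(window_size, d_head)
        # Absolute position held by each slot; -1 marks a slot not yet filled
        self.pos = torch.full((window_size,), -1, dtype=torch.long)
        self.t = 0  # tokens seen so far

    def append(self, k_t, v_t):
        slot = self.t % self.W  # overwrite the token that just left the window
        self.k[slot] = k_t
        self.v[slot] = v_t
        self.pos[slot] = self.t
        self.t += 1

    def attend(self, q_t):
        filled = self.pos >= 0
        keys, values = self.k[filled], self.v[filled]
        scores = (keys @ q_t) / keys.shape[-1] ** 0.5
        return scores.softmax(dim=-1) @ values


torch.manual_seed(0)
cache = RollingKVCache(window_size=4, d_head=8)
ks, vs = torch.randn(10, 8), torch.randn(10, 8)
for t in range(10):
    cache.append(ks[t], vs[t])
print(sorted(cache.pos.tolist()))  # [6, 7, 8, 9]
```

After ten tokens the buffer still holds exactly W = 4 entries, the last four positions, no matter how long generation runs.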
```python
import torch

def sliding_window_mask(seq_len, window_size):
    # Create causal mask first (lower triangle)
    causal = torch.ones(seq_len, seq_len, dtype=torch.bool).tril()
    # Create window mask: token i attends to [i-W+1, i]
    rows = torch.arange(seq_len).unsqueeze(1)  # [n, 1]
    cols = torch.arange(seq_len).unsqueeze(0)  # [1, n]
    window = (rows - cols) < window_size       # within W positions
    # Combine: must be both causal AND within window
    mask = causal & window
    return mask  # True = attend, False = mask out

# Example: 8 tokens, window size 3
mask = sliding_window_mask(8, 3)
# Token 0 sees: [0]
# Token 1 sees: [0, 1]
# Token 2 sees: [0, 1, 2]
# Token 3 sees: [1, 2, 3]  <- window kicks in
# Token 7 sees: [5, 6, 7]
```
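One way to use such a mask is PyTorch's `torch.nn.functional.scaled_dot_product_attention` (PyTorch 2.0+), whose boolean `attn_mask` also uses True to mean "attend". This sketch compares it against a hand-written masked softmax; the helper names are hypothetical. Note that both versions still build the full n × n score matrix, so the O(n · W) savings only materialize in kernels that skip the masked-out blocks entirely.

```python
import torch
import torch.nn.functional as F


def sliding_window_mask(seq_len, window_size):
    rows = torch.arange(seq_len).unsqueeze(1)
    cols = torch.arange(seq_len).unsqueeze(0)
    diff = rows - cols
    return (diff >= 0) & (diff < window_size)  # causal AND within window


def sliding_window_attention(q, k, v, window_size):
    # q, k, v: [batch, heads, seq_len, d_head]
    n, d = q.shape[-2], q.shape[-1]
    mask = sliding_window_mask(n, window_size).to(q.device)
    scores = (q @ k.transpose(-2, -1)) / d ** 0.5
    scores = scores.masked_fill(~mask, float("-inf"))  # blocked entries get weight 0
    return scores.softmax(dim=-1) @ v


torch.manual_seed(0)
q, k, v = torch.randn(3, 1, 2, 8, 16).unbind(0)
manual = sliding_window_attention(q, k, v, window_size=3)
fused = F.scaled_dot_product_attention(q, k, v, attn_mask=sliding_window_mask(8, 3))
print(torch.allclose(manual, fused, atol=1e-5))
```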
Every variant we've seen so far still computes the n×n attention matrix; they just make it sparser (sliding window) or share the KV cache (MQA, GQA). The matrix itself, that softmax(QKᵀ/√d) grid of attention weights, is still there. It's still quadratic.
What if we could avoid the n×n matrix entirely? Not approximate it with sparsity, not share components around it, but mathematically side-step it altogether?
That's what Linear Attention does. The key insight is startlingly simple: once softmax is replaced by a kernel feature map, changing the order of matrix multiplication computes the same output without ever forming the n×n intermediate matrix. The cost drops from O(n²d) to O(nd²), and when d is much smaller than n (as it is for long sequences, where d_head is typically 64 to 128), that's a massive win.
Let's start with standard attention. For one head, given queries Q (n×d), keys K (n×d), and values V (n×d):
Attention(Q, K, V) = softmax(QKᵀ / √d) · V
The problem is the QKᵀ term: (n×d) · (d×n) = n×n matrix. This requires O(n²d) operations and O(n²) memory.
Now, softmax is what forces us to compute QKᵀ first: we need every entry of a row before we can normalize it. But what if we replaced softmax with a simpler function that doesn't require the full matrix?
Replace softmax with a kernel function φ applied element-wise to Q and K:
LinearAttention(Q, K, V) = φ(Q) · (φ(K)ᵀ · V)   (row normalization is covered below)
Read that carefully. The parentheses changed. Instead of computing (φ(Q) · φ(K)ᵀ) · V, we compute φ(Q) · (φ(K)ᵀ · V). Same letters, different grouping. This is the associativity of matrix multiplication.
Why does the grouping matter? Let's trace the shapes:
| Operation | Standard Order | Linear Order |
|---|---|---|
| Step 1 | φ(Q)φ(K)ᵀ: (n×d)·(d×n) = n×n | φ(K)ᵀV: (d×n)·(n×d) = d×d |
| Step 2 | (n×n)·V: (n×n)·(n×d) = n×d | φ(Q)·(d×d): (n×d)·(d×d) = n×d |
| Intermediate | n×n (grows with sequence!) | d×d (fixed!) |
| FLOPs | O(n²d) | O(nd²) |
The intermediate matrix went from n×n (grows quadratically with sequence length) to d×d (fixed at d_head², regardless of sequence length). This is the core of linear attention.
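A quick numerical check of the regrouping, as a sketch with random data and φ = elu + 1 (the feature map the Linear Transformer uses). Both groupings give the same n × d output, but only one of them builds an n × n intermediate. Normalization is left out to isolate the associativity step, and float64 keeps rounding differences negligible.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
n, d = 1024, 64
Q, K, V = torch.randn(3, n, d, dtype=torch.float64).unbind(0)
phi_q, phi_k = F.elu(Q) + 1, F.elu(K) + 1  # positive feature maps

# Standard grouping: (phi(Q) phi(K)^T) V, with an n x n intermediate
left = phi_q @ phi_k.T                     # (n, n)
out_standard = left @ V

# Linear grouping: phi(Q) (phi(K)^T V), with a d x d intermediate
right = phi_k.T @ V                        # (d, d)
out_linear = phi_q @ right

print(left.shape, right.shape)             # torch.Size([1024, 1024]) torch.Size([64, 64])
print(torch.allclose(out_standard, out_linear))
```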
Let's work through a tiny example. n = 4 tokens, d = 3 dimensions. We'll skip the kernel φ (assume identity) and just compare the two multiplication orders.
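Standard order: (QKᵀ)V
Step 1, QKᵀ: multiply (4×3) by (3×4) = (4×4) matrix. Each of the 16 entries takes 3 multiplications: 48.
Step 2, (4×4) · V: multiply (4×4) by (4×3) = (4×3) output. Each of the 12 entries takes 4 multiplications: 48.
Grand total: 48 + 48 = 96 multiplications. Intermediate matrix: 4×4 = 16 elements.
Linear order: Q(KᵀV)
Step 1, KᵀV: multiply (3×4) by (4×3) = (3×3) matrix. Each of the 9 entries takes 4 multiplications: 36.
Step 2, Q · (3×3): multiply (4×3) by (3×3) = (4×3) output. Each of the 12 entries takes 3 multiplications: 36.
Grand total: 36 + 36 = 72 multiplications. Intermediate matrix: 3×3 = 9 elements.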
At n = 4, the saving is modest: 96 vs 72 operations. But watch what happens as n grows. The standard approach scales as n² while the linear approach scales as n:
| n | d | Standard (2n²d) | Linear (2nd²) | Ratio (n/d) |
|---|---|---|---|---|
| 4 | 3 | 96 | 72 | 1.3× |
| 64 | 3 | 24,576 | 1,152 | 21× |
| 1,024 | 64 | 134M | 8.4M | 16× |
| 16,384 | 128 | 68.7B | 537M | 128× |
At n = 16,384 with d = 128 (a typical head dimension in modern LLMs), linear attention needs 128× fewer multiplications than standard attention. Since the ratio is exactly n / d, linear becomes cheaper once n exceeds d, which covers any practical sequence length. Real wall-clock gains also depend on memory traffic and kernel efficiency, so they are usually smaller than the FLOP ratio.
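The cost comparison can be reproduced with a tiny cost model. This sketch (hypothetical helper) counts only the multiplications in the two matrix products of each grouping, the same convention as the n = 4 worked example (2n²d versus 2nd²); it ignores φ, normalization, and memory traffic.

```python
def attention_mults(n, d):
    """Multiplications in the two matmuls of each grouping (illustrative cost model)."""
    standard = n * n * d + n * n * d  # (n x d)(d x n), then (n x n)(n x d)
    linear = d * n * d + n * d * d    # (d x n)(n x d), then (n x d)(d x d)
    return standard, linear


for n, d in [(4, 3), (64, 3), (1_024, 64), (16_384, 128)]:
    std, lin = attention_mults(n, d)
    print(f"n={n:>6,} d={d:>3}: standard={std:>14,} linear={lin:>12,} ratio={std / lin:.1f}x")
```

The ratio simplifies to n / d, which is where the "linear wins once n > d" rule of thumb comes from.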
Left: standard attention materializes the n×n matrix. Right: linear attention computes a d×d summary. As the sequence length grows, the left intermediate grows quadratically while the right one stays fixed.
We can't just remove softmax and call it a day. Softmax ensures attention weights are positive and sum to 1 — it's a proper probability distribution. If we want the associativity trick to work, we need a kernel function φ that preserves these properties (or close enough).
The original Linear Transformer (Katharopoulos et al., 2020) used a simple choice:
φ(x) = elu(x) + 1
Where elu is the exponential linear unit: elu(x) = x if x > 0, eˣ − 1 if x ≤ 0. Adding 1 ensures all values are positive (matching softmax's positivity). The output is then normalized by dividing by the row sums of φ(Q)φ(K)ᵀ, which can be computed as φ(Q) · Σⱼ φ(kⱼ) without forming the n × n matrix.
Other kernels have been explored: positive random features (FAVOR+ in Performer), polynomial kernels, learned kernels. The choice of φ determines how well linear attention approximates the original softmax attention pattern.
If linear attention is so much faster, why isn't everyone using it? The answer is quality. Softmax attention creates sharp, peaked distributions — it can put 90% of the attention weight on a single token when that token is the relevant one. This sharp retrieval is what makes attention so powerful for tasks like question answering, code completion, and factual recall.
The kernel approximation φ produces smoother distributions. Instead of a sharp peak on one token, it spreads weight more broadly. For most generation tasks (continuing text, summarizing, chat), this barely matters. But for "needle in a haystack" retrieval — finding a specific fact in a long context — linear attention is measurably weaker.
This is an active area of research. Models like RWKV, RetNet, and Mamba (which uses a related selective state space mechanism) are closing the gap. Each generation gets closer to softmax quality while keeping linear or near-linear scaling. The bet in the research community is that subquadratic attention will eventually match softmax — but we're not there yet for all tasks.
```python
import torch
import torch.nn.functional as F

def linear_attention(Q, K, V):
    # Q, K, V: [batch, heads, seq_len, d_head]
    # Non-causal: every query sees every key (no autoregressive mask)
    # Apply kernel: phi(x) = elu(x) + 1 (ensures positivity)
    Q = F.elu(Q) + 1  # [B, H, n, d]
    K = F.elu(K) + 1  # [B, H, n, d]
    # THE TRICK: compute K^T V first (d×d, not n×n!)
    KV = K.transpose(-2, -1) @ V  # [B, H, d, d]
    # Then multiply Q by the d×d summary
    out = Q @ KV  # [B, H, n, d]
    # Normalizer, the analogue of the softmax denominator: Z_i = phi(q_i) · sum_j phi(k_j)
    Z = Q @ K.transpose(-2, -1).sum(dim=-1, keepdim=True)  # [B, H, n, 1]
    # Equivalent: Z = (Q * K.sum(dim=-2, keepdim=True)).sum(dim=-1, keepdim=True)
    return out / (Z + 1e-6)  # epsilon avoids division by zero
```
The entire trick is on one line: KV = K.transpose(-2, -1) @ V. This computes the d×d summary matrix. Then Q @ KV produces the output without ever materializing an n×n matrix. Dividing by Z makes each output row a weighted average of the rows of V, with positive weights that sum to 1, just as softmax weights do.
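The `linear_attention` function above is non-causal: every query sees every key. For autoregressive decoding, the Linear Transformer (Katharopoulos et al., 2020) instead keeps running sums, so each new token updates a fixed d × d state; this is why linear attention's cache doesn't grow with sequence length. Below is a minimal single-head sketch of that recurrent form (the function names are hypothetical), checked against an explicit, masked n × n computation of the same kernelized attention.

```python
import torch
import torch.nn.functional as F


def phi(x):
    return F.elu(x) + 1  # positive feature map, as in the Linear Transformer


def causal_linear_attention_recurrent(Q, K, V, eps=1e-6):
    # Q, K, V: [n, d]. The state (S, z) has a fixed size, whatever n is.
    n, d = Q.shape
    S = torch.zeros(d, V.shape[-1], dtype=Q.dtype)  # running sum of phi(k_j) v_j^T
    z = torch.zeros(d, dtype=Q.dtype)               # running sum of phi(k_j)
    outputs = []
    for i in range(n):
        q_i, k_i = phi(Q[i]), phi(K[i])
        S = S + torch.outer(k_i, V[i])
        z = z + k_i
        outputs.append((q_i @ S) / (q_i @ z + eps))
    return torch.stack(outputs)


def causal_linear_attention_quadratic(Q, K, V, eps=1e-6):
    # Reference: build the n x n kernel matrix explicitly, then mask it
    A = torch.tril(phi(Q) @ phi(K).T)  # token i only sees j <= i
    return (A @ V) / (A.sum(dim=-1, keepdim=True) + eps)


torch.manual_seed(0)
Q, K, V = torch.randn(3, 16, 8, dtype=torch.float64).unbind(0)
recurrent = causal_linear_attention_recurrent(Q, K, V)
quadratic = causal_linear_attention_quadratic(Q, K, V)
print(torch.allclose(recurrent, quadratic))
```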
Every variant we've seen so far changes the attention algorithm — fewer heads, local windows, kernel tricks. FlashAttention takes a completely different approach: it computes exact standard attention, but reorganizes the computation to be friendly to the GPU's memory hierarchy. Same math, radically faster.
The insight is almost embarrassingly simple once you see it. Modern GPUs have two kinds of memory: a large, slow main memory (HBM — High Bandwidth Memory), and a tiny, blazing-fast on-chip cache (SRAM). Standard attention constantly shuffles enormous matrices between these two levels. FlashAttention restructures the work so that almost everything happens in the fast cache, and the huge intermediate matrix never gets written at all.
Think of HBM as a warehouse and SRAM as your workbench. The warehouse is huge (80 GB on an H100) but it takes time to haul things back and forth. Your workbench is tiny (a few hundred KB per streaming multiprocessor, a few tens of MB across the whole chip) but everything on it is instantly accessible. Standard attention is like building furniture by carrying every plank to the warehouse after cutting it, then hauling it all back to assemble. FlashAttention keeps the planks on your workbench and only moves the finished piece once.
Concrete numbers for an NVIDIA H100:
| Level | Size | Bandwidth | Relative speed |
|---|---|---|---|
| HBM (main memory) | 80 GB | 3.35 TB/s | 1× (baseline) |
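| SRAM (on-chip shared memory / L1) | 256 KB per SM, ~33 MB across 132 SMs | Tens of TB/s aggregate (approximate) | Roughly 10× or more |
| Registers | 256 KB per SM, ~33 MB total | Highest; operands feed the compute units directly | Fastest |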
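For attention, the bottleneck on modern GPUs is usually not compute but memory bandwidth: moving data to and from HBM is what makes standard attention slow. FlashAttention is an IO-aware algorithm: it minimizes the number of HBM reads and writes by doing as much work as possible while data sits in fast SRAM.
Let's trace what standard attention does for a single head. Inputs: Q, K, V ∈ ℝ^(n × d), all sitting in HBM.
1. Read Q and K from HBM, compute S = QKᵀ, and write S (n × n) back to HBM.
2. Read S, compute P = softmax(S), and write P (n × n) to HBM.
3. Read P and V, compute O = PV, and write O (n × d) to HBM.
Count the HBM traffic: the n × n matrix crosses the HBM boundary four times (write S, read S for softmax, write P, read P for the multiply). For n = 8192 in FP16 (2 bytes per value), one n × n matrix is 8192² × 2 bytes = 128 MB, so a single head moves about 512 MB of score traffic per layer, while Q, K, V, and O (with d = 128) are only 2 MB each.
With 32 heads and 80 layers, just writing each score matrix once adds up to 32 × 80 × 128 MB = 320 GB of HBM traffic in a single forward pass, and counting all four passes quadruples that. Since HBM bandwidth is the bottleneck, this is where the time goes.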
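A back-of-the-envelope sketch of that traffic estimate. The function and its defaults are illustrative: one sequence, FP16, and four passes over each n × n matrix (write S, read S, write P, read P); it ignores the comparatively small Q, K, V, and O transfers. Sizes are in binary units, matching the 320 GB figure.

```python
def score_matrix_traffic_bytes(n, heads, layers, bytes_per_value=2, passes=4):
    """Rough HBM bytes spent moving n x n score/probability matrices in one forward pass."""
    one_matrix = n * n * bytes_per_value  # 8192^2 * 2 bytes = 128 MiB
    return one_matrix * passes * heads * layers


GiB = 1024 ** 3
print(score_matrix_traffic_bytes(8192, heads=32, layers=80, passes=1) / GiB)  # 320.0
print(score_matrix_traffic_bytes(8192, heads=32, layers=80) / GiB)            # 1280.0
```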
FlashAttention's core idea: never materialize the full n × n matrix. Instead, split Q into row blocks of size B_r and K, V into column blocks of size B_c. Load one Q-block and one K,V-block into SRAM at a time. Compute that tile of the attention matrix, immediately use it, and accumulate the result. The tile lives only in SRAM; it never touches HBM.
SRAM must hold: one Q-block (B_r × d), one K-block (B_c × d), one V-block (B_c × d), and the partial output block (B_r × d). That's approximately 4 × B × d entries (if B_r = B_c = B).
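For an A100 with 192 KB of SRAM per SM, d = 128, FP16 (2 bytes): 4 × B × 128 × 2 bytes ≤ 192,000 bytes gives B ≤ 187 (reading 192 KB as 192,000 bytes). Real kernels pick power-of-two tiles such as 64 or 128 and must also leave room for the score tile itself.
For a sequence of n = 8192 tokens: ⌈8192 / 187⌉ = 44 blocks along each axis, or 44 × 44 = 1,936 tiles per head.
Each tile computes a 187 × 187 sub-matrix = 34,969 entries, and that tiny block lives entirely in SRAM. Standard attention would need all 67 million entries in HBM simultaneously. FlashAttention needs only ~35,000 at any moment, roughly a 1,920× reduction in peak score-matrix memory.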
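The tile-size arithmetic as a small sketch. The helper is hypothetical and only counts the four B × d blocks, ignoring the score tile and any other scratch space a real kernel needs.

```python
import math


def max_square_tile(sram_bytes, d, bytes_per_value=2):
    # Q, K, V, and O blocks of B x d values each must fit: 4 * B * d * bytes <= SRAM
    return sram_bytes // (4 * d * bytes_per_value)


n, d = 8192, 128
B = max_square_tile(192_000, d)    # 192 KB read as 192,000 bytes
blocks = math.ceil(n / B)          # tiles along each axis of the n x n matrix
print(B, blocks, blocks * blocks)  # 187 44 1936
print(B * B, n * n / (B * B))      # entries per tile, full-matrix / tile ratio
```

Reading 192 KB as 192 × 1024 bytes gives B = 192 instead; either way, a single tile holds well under 0.1% of the full matrix.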
There's a catch. Softmax requires the maximum over all scores in a row for numerical stability. The formula is:
softmax(s)ᵢ = exp(sᵢ − m) / Σⱼ exp(sⱼ − m), where m = maxⱼ sⱼ
This seems to require seeing ALL n scores before computing any output. If we're processing one tile at a time, we don't know the global max yet. How can we compute softmax?
The answer is the online softmax algorithm (Milakov & Gimelshein, 2018). It processes one block at a time, maintaining a running maximum and a running denominator. When a new tile arrives with a larger maximum, it rescales all previous results:
```python
import torch


def online_softmax_attention(Q_block, K, V, B_c):
    """Online softmax, the trick that makes FlashAttention possible.

    Processes one query block against K, V in tiles of B_c rows
    (single head, no mask). Each S_ij tile is created and consumed
    inside the loop, so the full n x n matrix never exists.
    """
    B_r, d = Q_block.shape
    scale = d ** -0.5                            # 1/sqrt(d), as in standard attention
    m_prev = torch.full((B_r,), float("-inf"))   # running row maximum
    l_prev = torch.zeros(B_r)                    # running sum of exp(scores - max)
    O_prev = torch.zeros(B_r, d)                 # running (unnormalized) output accumulator

    for j in range(0, K.shape[0], B_c):
        # Load one K, V block (on a GPU: HBM -> SRAM)
        K_j = K[j:j + B_c]                       # shape: (B_c, d)
        V_j = V[j:j + B_c]                       # shape: (B_c, d)

        # Compute local attention scores for this tile
        S_ij = (Q_block @ K_j.T) * scale         # shape: (B_r, B_c)

        # Update running max, row by row
        m_new = torch.maximum(m_prev, S_ij.max(dim=-1).values)

        # Rescale previous accumulations to the new max
        correction = torch.exp(m_prev - m_new)   # shape: (B_r,)
        P_ij = torch.exp(S_ij - m_new[:, None])  # local unnormalized weights
        l_new = correction * l_prev + P_ij.sum(dim=-1)

        # Update output: rescale old + add new contribution
        O_new = correction[:, None] * O_prev + P_ij @ V_j

        m_prev, l_prev, O_prev = m_new, l_new, O_new

    # Final output: normalize by the total denominator
    return O_prev / l_prev[:, None]              # shape: (B_r, d)
```
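The key insight: the correction = exp(m_prev - m_new) term rescales all prior work when the max changes. If the new tile has a larger maximum, the previous exponentials were computed against a smaller max and are too large, so they get multiplied down. The result is mathematically identical to standard softmax, with no approximation whatsoever; only floating-point rounding can differ, because the additions happen in a different order.
Putting it all together:
1. Split Q into blocks of B_r rows and K, V into blocks of B_c rows.
2. For each Q block, load it into SRAM and initialize the running max, running sum, and output accumulator.
3. For each K, V block, load it into SRAM, compute the score tile, and update the running statistics and the output with the online softmax above.
4. After the last K, V block, divide the output block by its running sum and write it to HBM once.
HBM traffic for FlashAttention: read Q once (n × d), read K and V once per query block (n × d each per pass), and write O once (n × d). The n × n attention matrix is never written to HBM and never read from HBM. It exists only briefly in SRAM, one tile at a time.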
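An end-to-end sketch of the tiled forward pass in plain PyTorch, assuming a single head, no masking, and arbitrary block sizes (the function names are hypothetical). It runs on the CPU, so it demonstrates the arithmetic of tiling plus online softmax, not the speed; a real implementation is a fused GPU kernel. The result is compared against naive attention that materializes the full matrix.

```python
import torch


def tiled_attention(Q, K, V, B_r=32, B_c=48):
    """Exact attention via Q-blocks (outer loop) and K/V-blocks (inner loop)."""
    n, d = Q.shape
    scale = d ** -0.5
    O = torch.empty_like(Q)
    for r in range(0, n, B_r):                  # outer loop: one query block
        Q_i = Q[r:r + B_r]
        m = torch.full((Q_i.shape[0],), float("-inf"), dtype=Q.dtype)
        l = torch.zeros(Q_i.shape[0], dtype=Q.dtype)
        acc = torch.zeros_like(Q_i)
        for c in range(0, n, B_c):              # inner loop: one K/V block
            S = (Q_i @ K[c:c + B_c].T) * scale  # only a B_r x B_c tile exists
            m_new = torch.maximum(m, S.max(dim=-1).values)
            P = torch.exp(S - m_new[:, None])
            correction = torch.exp(m - m_new)
            l = correction * l + P.sum(dim=-1)
            acc = correction[:, None] * acc + P @ V[c:c + B_c]
            m = m_new
        O[r:r + B_r] = acc / l[:, None]         # single write of this output block
    return O


def naive_attention(Q, K, V):
    S = (Q @ K.T) / Q.shape[-1] ** 0.5          # full n x n matrix
    return S.softmax(dim=-1) @ V


torch.manual_seed(0)
Q, K, V = torch.randn(3, 200, 64).unbind(0)
print(torch.allclose(tiled_attention(Q, K, V), naive_attention(Q, K, V), atol=1e-5))
```

The sequence length of 200 is deliberately not a multiple of either block size, so the last, partial blocks are exercised too.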
Side by side, the difference is stark: standard attention bounces data back and forth to HBM at every stage, while FlashAttention loads tiles into SRAM and keeps them there.
Left: standard attention writes the full n×n matrix to HBM. Right: FlashAttention tiles the computation through SRAM. Red arrows = slow HBM transfers. Green arrows = fast SRAM operations.
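Here are approximate forward-pass timings that illustrate the trend reported in the FlashAttention-2 paper (Dao, 2023), on an A100-80GB GPU with head dimension d = 128 and 12 heads. Exact numbers depend on batch size, precision, and kernel version. The last column compares the full n × n score matrix with a single 256 × 256 tile.
| Sequence length | Standard (ms) | FlashAttention-2 (ms) | Speedup | Score-matrix memory reduction (n² vs one 256×256 tile) |
|---|---|---|---|---|
| 512 | 0.8 | 0.5 | 1.6× | 4× |
| 1024 | 2.1 | 0.9 | 2.3× | 16× |
| 2048 | 7.6 | 2.4 | 3.2× | 64× |
| 4096 | 29 | 7.8 | 3.7× | 256× |
| 8192 | 115 | 24 | 4.8× | 1024× |
| 16384 | OOM | 85 | n/a (standard OOM) | 4096× |
The pattern: the longer the sequence, the bigger the win. At n = 16384, standard attention runs out of memory entirely while FlashAttention still fits comfortably, because its memory use grows only linearly with n. That is why FlashAttention, or a fused kernel built on the same idea, is the default in most modern LLM training and inference stacks: faster, with no change to the math.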
| Version | Year | Key improvement |
|---|---|---|
| FlashAttention | 2022 | Tiled computation, online softmax. 2-4× speedup. |
| FlashAttention-2 | 2023 | Better parallelism across GPU warps. 1.7-2× over FA1. |
| FlashAttention-3 | 2024 | Exploits H100 Tensor Cores, FP8 support. 1.5-2× over FA2. |
FlashAttention-3 on H100 reaches up to about 740 TFLOPS in FP16, roughly 75% of the GPU's theoretical peak, whereas FlashAttention-2 reaches only about 35% utilization on the same GPU. Same math, same answer, just dramatically better use of the hardware.
You've learned seven different approaches to attention — from full multi-head attention down to linear approximations, from multi-query compression to FlashAttention's memory-hierarchy trick. But which one should you actually use? It depends on your sequence length, model dimension, and quality requirements.
Let's race them. A simulation of throughput, memory usage, and quality tradeoffs for all seven variants, across scenarios ranging from short prompts to long context and from small models to frontier scale, gives the following picture.
At short sequences (512–1024): All variants perform similarly. The n² cost is small, so clever tricks don't save much. Flash-MHA gives the best quality at competitive speed. This is why BERT-era models used full MHA — sequences were short enough that it didn't matter.
At medium sequences (2048–4096): The pack starts to separate. GQA and MQA pull ahead on throughput because they have smaller KV caches (less memory traffic). FlashAttention gives MHA a lifeline. Sliding window is great if your task is local.
At long sequences (8192+): The gap is dramatic. Standard MHA without FlashAttention may OOM entirely. Linear attention dominates raw throughput with its O(n) scaling. Sliding window excels if the task doesn't need global context. GQA + FlashAttention is the modern sweet spot — near-MHA quality with 4–8× less KV memory.
On quality: MHA and FlashAttention are lossless (they compute the same thing). GQA-8 is near-lossless (<0.1% degradation in most benchmarks). MQA has slight degradation (0.5–1%). Linear attention has noticeable degradation (1–3%) on tasks requiring precise long-range recall. Sliding window has zero loss for local patterns but misses distant information.
| Variant | n = 512 | n = 4096 | n = 16384 | Scaling |
|---|---|---|---|---|
| MHA | ~1.0× | 0.3× | OOM | O(n²) |
| MQA | ~1.1× | 0.8× | 0.15× | O(n²) less KV IO |
| GQA-4 | ~1.05× | 0.65× | 0.12× | O(n²) less KV IO |
| GQA-8 | ~1.05× | 0.7× | 0.13× | O(n²) less KV IO |
| Sliding-4K | ~1.0× | ~0.9× | 0.75× | O(n · W) |
| Linear | ~0.95× | ~1.1× | 1.2× | O(n · d²) |
| Flash-MHA | ~1.2× | 0.85× | 0.25× | O(n²) less IO |
(Throughput relative to MHA at n = 512. "OOM" means out of memory on a single A100-80GB.)
KV cache memory per layer (d_model = 4096, d_head = 128, H = 32, FP16):
| Variant | Formula | Bytes per token per layer | At n = 8192 |
|---|---|---|---|
| MHA | 2 × H × d_head × 2 | 16,384 | 128 MB |
| GQA-8 | 2 × 8 × d_head × 2 | 4,096 | 32 MB |
| GQA-4 | 2 × 4 × d_head × 2 | 2,048 | 16 MB |
| MQA | 2 × 1 × d_head × 2 | 512 | 4 MB |
| Sliding-W (W = 4096) | 2 × H × d_head × 2, for at most W tokens | 16,384 | 64 MB (capped) |
| Linear | H × (d_head² + d_head) × 2, in total (one fixed state per head, not per token) | n/a (fixed) | ~1 MB (fixed) |
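The same arithmetic as a small sketch (hypothetical helper; FP16, one sequence, H = 32, d_head = 128). For linear attention it counts one d_head × d_head state plus a d_head normalizer per head, following the Linear Transformer's recurrent formulation, so the result doesn't depend on n. Multiply by the number of layers and the batch size for a full-model figure.

```python
def kv_cache_bytes_per_layer(n, n_kv_heads, d_head=128, window=None,
                             linear_state=False, bytes_per_value=2):
    """Per-layer cache size in bytes for one sequence (illustrative model)."""
    if linear_state:
        # Here n_kv_heads is simply the head count: one d_head x d_head state
        # plus a d_head normalizer per head, independent of n
        return n_kv_heads * (d_head * d_head + d_head) * bytes_per_value
    tokens = n if window is None else min(n, window)  # sliding window caps the cache
    return 2 * n_kv_heads * d_head * bytes_per_value * tokens  # 2 = one K and one V


MB = 1024 ** 2
configs = {
    "MHA (H=32)": dict(n_kv_heads=32),
    "GQA-8": dict(n_kv_heads=8),
    "GQA-4": dict(n_kv_heads=4),
    "MQA": dict(n_kv_heads=1),
    "Sliding (W=4096)": dict(n_kv_heads=32, window=4096),
    "Linear (state)": dict(n_kv_heads=32, linear_state=True),
}
for name, cfg in configs.items():
    print(f"{name:>17}: {kv_cache_bytes_per_layer(8192, **cfg) / MB:8.3f} MB")
```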
You've mastered the complete attention variant toolkit — from standard multi-head attention through all its descendants, down to the hardware tricks that make it fast. This chapter is your reference card. No new concepts. Just the decision guide, the comparison table, and where to go next.
Follow the path that matches your situation:
| Variant | KV Memory | Compute | Quality vs MHA | Used By |
|---|---|---|---|---|
| MHA | O(H · n · d) | O(n² · d) | Baseline | GPT-2/3, BERT, T5 |
| MQA | O(n · d) | O(n² · d) | −0.5–1% | PaLM, Falcon, StarCoder |
| GQA-G | O(G · n · d) | O(n² · d) | −0.1–0.5% | Llama 2 70B, Llama 3, Mistral, Gemma 2 |
| Sliding-W | O(H · W · d) | O(n · W · d) | ≈ same (local) | Mistral, Longformer, BigBird |
| Linear | O(H · d²) | O(n · d²) | −1–3% | RWKV, RetNet, Linear Transformer |
| FlashAttn | O(n) scratch | O(n² · d) | Exact | Most modern LLM stacks |
| Symbol | Meaning | Typical values |
|---|---|---|
| n | Sequence length | 512–128K+ |
| d or d_model | Model dimension | 512–8192 |
| d_head | Per-head dimension (d_model / H) | 64–128 |
| H | Number of query heads | 8–128 |
| G | Number of KV groups (GQA) | 1 (MQA) to H (MHA) |
| W | Window size (sliding window) | 512–4096 |
| B_r, B_c | FlashAttention tile sizes | 64–256 (hardware-dependent) |
| Q, K, V | Query, Key, Value matrices | ∈ ℝ^(n × d_head) |
| S | Attention score matrix (QKᵀ) | ∈ ℝ^(n × n) |
| O | Output matrix | ∈ ℝ^(n × d_head) |
| HBM | High Bandwidth Memory (GPU main memory) | 40–80 GB |
| SRAM | On-chip static RAM (per-SM shared memory/L1) | 192–256 KB per SM, ~20–33 MB total |
Attention variants don't exist in isolation. Here's where to go next:
| Paper | Year | Contribution |
|---|---|---|
| Vaswani et al., "Attention Is All You Need" | 2017 | Multi-head attention, the Transformer |
| Shazeer, "Fast Transformer Decoding: One Write-Head Is All You Need" | 2019 | Multi-Query Attention (MQA) |
| Ainslie et al., "GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints" | 2023 | Grouped-Query Attention |
| Beltagy et al., "Longformer: The Long-Document Transformer" | 2020 | Sliding window + global attention |
| Katharopoulos et al., "Transformers are RNNs" | 2020 | Linear attention via kernel trick |
| Dao et al., "FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness" | 2022 | Tiled attention, online softmax |
| Dao, "FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning" | 2023 | Improved warp-level parallelism |
| Shah et al., "FlashAttention-3: Fast and Accurate Attention with Asynchrony and Low-precision" | 2024 | H100 optimization, FP8 support |