Training Foundations

Attention Variants

From the O(n²) bottleneck to the engineering tricks that make modern LLMs fast — MHA, MQA, GQA, and FlashAttention.

Prerequisites: matrix multiplication and what a neural network layer does. That's it.
11 chapters · 14+ simulations · 0 assumed knowledge

Chapter 0: Why Attention Has a Problem

You're chatting with an LLM. Short messages get instant replies. But paste in a long document and ask a question about it — suddenly the model takes 10x longer. The response trickles out word by word instead of flowing. Why?

The answer is hidden inside the attention mechanism — the mathematical core of every transformer. Attention lets each output token "look at" every input token to decide what's relevant. When there are 128 tokens, that's 128 × 128 = 16,384 comparisons. When there are 8,192 tokens, it's 8,192 × 8,192 = 67 million comparisons. The cost grows with the square of the sequence length.

This is not a minor engineering inconvenience. It is THE fundamental bottleneck in modern language models. Every technique in this lesson — multi-head attention, KV caching, grouped-query attention, Flash Attention — exists because someone needed to tame this quadratic beast.

The Quadratic Wall

Let's make the cost concrete. For each token we generate, the model must compute a dot product between that token's query and every previous token's key. If the sequence has n tokens, that's n dot products per new token. But the model also had to do this for every previous token during the initial processing (the "prefill" phase). Total comparisons: n × n = n².

Doubling the sequence length doesn't double the work — it quadruples it. Let's count by hand.

Hand Calculation: Counting Dot Products

Setup. Each token in the sequence produces a query vector and a key vector. Attention computes the dot product between every query and every key. For a sequence of n tokens, that's n queries × n keys = n² dot products. Let's count explicitly for small sequences.

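4 tokens: 4 × 4 = 16 dot products

8 tokens: 8 × 8 = 64 dot products

16 tokens: 16 × 16 = 256 dot products

1,024 tokens: 1,024 × 1,024 = 1,048,576 dot products

8,192 tokens (a common LLM context window): 8,192 × 8,192 = 67,108,864 dot products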

From 4 tokens to 8,192 tokens — a 2,048× increase in sequence length causes a 2,048² = 4,194,304× increase in dot products. That's the quadratic wall.
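The counts in this hand calculation can be reproduced in a few lines of Python. This is a minimal sketch; the helper name `dot_products` is ours for illustration and assumes full (unmasked) attention, where every query meets every key.

```python
def dot_products(n):
    """Number of query-key dot products for full attention over n tokens."""
    return n * n  # every query is compared with every key

for n in [4, 8, 16, 1024, 8192]:
    print(f"n={n:>5}: {dot_products(n):>12,} dot products")

# Growth factor when going from 4 tokens to 8,192 tokens
length_ratio = 8192 // 4
growth = dot_products(8192) // dot_products(4)
print(f"{length_ratio:,}x longer -> {growth:,}x more dot products")
```

The last line prints a length ratio of 2,048 and a growth factor of 4,194,304, which is 2,048 squared.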

Worked Example: Real FLOPs

Each dot product between a query and a key takes d_k multiply-add operations, one per dimension. Counting each multiply-add as one operation, the cost of the attention computation is approximately:

FLOPs ≈ 2 · n² · d_k

The factor of 2 covers the two matrix products, Q·K^T and (scores)·V, each contributing roughly n²·d_k multiply-adds. (If you count the multiply and the add separately, double the total to 4·n²·d_k.) Let's plug in d_k = 512, the model dimension of the original Transformer base model; 7B-class models typically use 128-dimensional heads:

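| Sequence Length (n) | n² | FLOPs (2·n²·512) | Ratio vs n=512 |
|---|---|---|---|
| 512 | 262,144 | 268 million | 1× |
| 1,024 | 1,048,576 | 1.07 billion | 4× |
| 2,048 | 4,194,304 | 4.29 billion | 16× |
| 4,096 | 16,777,216 | 17.2 billion | 64× |
| 8,192 | 67,108,864 | 68.7 billion | 256× |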

Going from 512 to 2,048 tokens (4× longer) costs 16× more FLOPs. Going from 512 to 8,192 (16× longer) costs 256× more. This is why a 128K-context model needs special engineering tricks just to be usable.

See the Quadratic Wall Live

[Interactive simulation: Token Generation Cost] A slider sets the sequence length, and the display shows the time per token and total FLOPs climbing quadratically. Each bar represents the cost of generating one token at that sequence position. The bars grow linearly with position because each new token attends to all previous tokens, and the sum of that linear growth is the quadratic total.

From Scratch: FLOPs Calculator

python
def attention_flops(n, d_k=512):
    """Approximate FLOPs for one attention layer."""
    qk_flops = n * n * d_k      # Q @ K^T
    sv_flops = n * n * d_k      # scores @ V
    return qk_flops + sv_flops  # total ≈ 2·n²·d_k

for n in [512, 1024, 2048, 4096, 8192]:
    flops = attention_flops(n)
    print(f"n={n:>5}: {flops/1e9:.2f}B FLOPs  ({flops/attention_flops(512):.0f}x vs n=512)")

# Output:
# n=  512: 0.27B FLOPs  (1x vs n=512)
# n= 1024: 1.07B FLOPs  (4x vs n=512)
# n= 2048: 4.29B FLOPs  (16x vs n=512)
# n= 4096: 17.18B FLOPs  (64x vs n=512)
# n= 8192: 68.72B FLOPs  (256x vs n=512)

The real bottleneck is MEMORY, not compute. You might think the problem is the matrix multiply itself. GPUs eat matrix multiplies for breakfast: an H100 delivers roughly 990 TFLOPS of dense FP16. The real issue is the n×n attention matrix. At n=8,192 with FP16, that's 8192² × 2 bytes = 128 MiB per layer, per head. For a 32-head, 32-layer model, materializing all of them would take 128 MiB × 32 × 32 = 128 GiB (about 137 GB) just for attention matrices. Even a single head's matrix doesn't fit in the GPU's fast on-chip SRAM (tens of megabytes in total), so it spills to slower HBM. This memory traffic, not the raw FLOPs, is what Flash Attention targets.
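The attention-matrix memory figures can be checked directly. The sketch below assumes FP16 scores (2 bytes each) and, as a worst case, that the matrices for all heads and layers are stored at the same time; the function name is ours.

```python
def attention_matrix_bytes(n, bytes_per_value=2):
    """Bytes needed to store one n x n attention matrix (one head, one layer)."""
    return n * n * bytes_per_value

MIB = 1024 ** 2
GIB = 1024 ** 3

per_head = attention_matrix_bytes(8192)   # FP16: 2 bytes per score
all_heads_layers = per_head * 32 * 32     # 32 heads x 32 layers, all stored at once

print(f"one head, one layer:  {per_head / MIB:.0f} MiB")
print(f"32 heads x 32 layers: {all_heads_layers / GIB:.0f} GiB "
      f"({all_heads_layers / 1e9:.0f} GB decimal)")
```

In binary units the totals are 128 MiB and 128 GiB; the same byte count is about 137 GB in decimal units.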

The Roadmap

This lesson builds your understanding of attention from the ground up, then shows you how the field has attacked the quadratic problem from every angle:

Chapter 1: Dot-Product Attention. The exact mechanics of Q, K, V.
Chapter 2: Multi-Head Attention. Parallel attention with learned subspaces.
Chapter 3: The KV Cache. The memory wall in autoregressive generation.
Chapters 4-7: MQA, GQA, Flash Attention, and the full engineering toolkit.

By the end, you'll understand why so many modern LLMs use grouped-query attention, what Flash Attention actually does at the hardware level, and how to calculate whether a model's KV cache will fit on your GPU. Let's start by building attention from scratch.

If you double the sequence length from 2,048 to 4,096 tokens, how do the attention FLOPs change?

Chapter 1: Dot-Product Attention Mechanics

Before we can fix attention, we need to understand exactly what it computes. Not the high-level "tokens look at other tokens" hand-wave — the actual matrix operations, from input to output, with every intermediate number visible.

The mechanism is called scaled dot-product attention, and it was introduced in the "Attention Is All You Need" paper (Vaswani et al., 2017). It has three inputs — queries, keys, and values — and produces a weighted combination of the values, where the weights come from comparing queries to keys.

Think of it like a library lookup. You walk in with a question (your query). Each book on the shelf has a label (its key). You compare your question to every label, figure out which books are most relevant (the attention weights), then read those books (the values) in proportion to their relevance. The output is a weighted blend of the relevant books' contents.

The Three Projections: Q, K, V

The input to attention is a sequence of token embeddings — one vector per token. For a sequence of n tokens with embedding dimension dmodel, the input is a matrix X of shape (n, dmodel).

Attention doesn't work on the raw embeddings directly. Instead, it projects them through three learned weight matrices to create three separate representations:

Q = X · W_Q     K = X · W_K     V = X · W_V

| Name | Shape | What It Represents |
|---|---|---|
| Q (queries) | (n, d_k) | "What am I looking for?" Each token's search query |
| K (keys) | (n, d_k) | "What do I contain?" Each token's advertised label |
| V (values) | (n, d_v) | "What information do I carry?" Each token's payload |

The projection matrices WQ, WK, and WV are learned during training. This is crucial — the model doesn't just compare raw token embeddings. It learns what to look for (Q), how to be found (K), and what to send when found (V). These can be entirely different functions of the same input.
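To make the shapes concrete, here is a small pure-Python sketch of the three projections. The weight matrices are random stand-ins for learned parameters, and the toy sizes (n = 4, d_model = 8, d_k = d_v = 3) and helper names are assumptions chosen for illustration.

```python
import random

def matmul(A, B):
    """Multiply two matrices given as lists of rows."""
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*B)] for row in A]

def random_matrix(rows, cols, rng):
    """Matrix of standard-normal entries, standing in for learned weights."""
    return [[rng.gauss(0.0, 1.0) for _ in range(cols)] for _ in range(rows)]

def shape(M):
    return (len(M), len(M[0]))

rng = random.Random(0)
n, d_model, d_k, d_v = 4, 8, 3, 3       # toy sizes

X = random_matrix(n, d_model, rng)      # token embeddings, shape (n, d_model)
W_Q = random_matrix(d_model, d_k, rng)  # three separate projections of the same input
W_K = random_matrix(d_model, d_k, rng)
W_V = random_matrix(d_model, d_v, rng)

Q, K, V = matmul(X, W_Q), matmul(X, W_K), matmul(X, W_V)
print(shape(Q), shape(K), shape(V))     # (4, 3) (4, 3) (4, 3)
```

Q, K, and V all come from the same X, yet they differ because each uses its own weight matrix.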

The Four Steps of Attention

Once we have Q, K, and V, attention proceeds in four steps:

Step 1: Score. Compute Q · K^T, an (n × n) matrix of raw attention scores.
Step 2: Scale. Divide every score by √d_k to control magnitude.
Step 3: Softmax. Apply softmax row-wise, so each query's scores become a probability distribution.
Step 4: Aggregate. Multiply the attention weights by V to get the output.

Attention(Q, K, V) = softmax(Q · K^T / √d_k) · V

Let's trace this with real numbers.

Hand Calculation: 4 Tokens, dk = 3

Setup. We have 4 tokens. After projection through WQ, WK, WV, we get the following Q, K, V matrices. Each row is one token's vector. dk = dv = 3 for simplicity.

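Q matrix (4 × 3), what each token is looking for:

| | dim 0 | dim 1 | dim 2 |
|---|---|---|---|
| Token 0 | 1.0 | 0.0 | 1.0 |
| Token 1 | 0.0 | 1.0 | 0.0 |
| Token 2 | 1.0 | 1.0 | 0.0 |
| Token 3 | 0.0 | 0.0 | 1.0 |

K matrix (4 × 3), what each token advertises:

| | dim 0 | dim 1 | dim 2 |
|---|---|---|---|
| Token 0 | 1.0 | 0.0 | 0.0 |
| Token 1 | 0.0 | 1.0 | 1.0 |
| Token 2 | 1.0 | 1.0 | 0.0 |
| Token 3 | 0.0 | 0.0 | 1.0 |

V matrix (4 × 3), the information payload:

| | dim 0 | dim 1 | dim 2 |
|---|---|---|---|
| Token 0 | 0.5 | 1.0 | 0.0 |
| Token 1 | 1.0 | 0.0 | 0.5 |
| Token 2 | 0.0 | 0.5 | 1.0 |
| Token 3 | 0.5 | 0.5 | 0.5 |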

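Step 1: Compute Q · K^T

Each entry (i, j) is the dot product between query i and key j. Let's compute row 0 (token 0's query [1, 0, 1] against all keys):

q0 · k0 = 1·1 + 0·0 + 1·0 = 1
q0 · k1 = 1·0 + 0·1 + 1·1 = 1
q0 · k2 = 1·1 + 0·1 + 1·0 = 1
q0 · k3 = 1·0 + 0·0 + 1·1 = 1

Token 0's query matches all keys equally! Now row 1 (token 1's query [0, 1, 0]):

q1 · k0 = 0·1 + 1·0 + 0·0 = 0
q1 · k1 = 0·0 + 1·1 + 0·1 = 1
q1 · k2 = 0·1 + 1·1 + 0·0 = 1
q1 · k3 = 0·0 + 1·0 + 0·1 = 0

Token 1 matches keys 1 and 2 and ignores keys 0 and 3. Continuing for all rows, the full score matrix is:

| Q\K | Key 0 | Key 1 | Key 2 | Key 3 |
|---|---|---|---|---|
| Query 0 | 1.0 | 1.0 | 1.0 | 1.0 |
| Query 1 | 0.0 | 1.0 | 1.0 | 0.0 |
| Query 2 | 1.0 | 1.0 | 2.0 | 0.0 |
| Query 3 | 0.0 | 1.0 | 0.0 | 1.0 |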

Note row 2: token 2's query [1,1,0] dots strongly with key 2 [1,1,0] giving a score of 2.0. Token 2 attends most to itself — a common pattern in self-attention.

Step 2: Scale by √dk

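We divide every score by √3 ≈ 1.732:

| Q\K | Key 0 | Key 1 | Key 2 | Key 3 |
|---|---|---|---|---|
| Query 0 | 0.577 | 0.577 | 0.577 | 0.577 |
| Query 1 | 0.000 | 0.577 | 0.577 | 0.000 |
| Query 2 | 0.577 | 0.577 | 1.155 | 0.000 |
| Query 3 | 0.000 | 0.577 | 0.000 | 0.577 |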

Step 3: Softmax (row-wise)

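Softmax converts each row to a probability distribution. For row 0, all values are equal (0.577), so softmax gives uniform weights: 0.25 each. For row 2, the 1.155 entry gets boosted:

Row 2: softmax([0.577, 0.577, 1.155, 0.000])
exponentials: [e^0.577, e^0.577, e^1.155, e^0] ≈ [1.781, 1.781, 3.173, 1.000], sum ≈ 7.735
weights: [1.781, 1.781, 3.173, 1.000] / 7.735 ≈ [0.230, 0.230, 0.410, 0.129]

Token 2 puts 41% of its attention on itself (key 2), 23% each on keys 0 and 1, and only 13% on key 3. The full attention weight matrix:

| Q\K | Key 0 | Key 1 | Key 2 | Key 3 |
|---|---|---|---|---|
| Query 0 | 0.250 | 0.250 | 0.250 | 0.250 |
| Query 1 | 0.180 | 0.320 | 0.320 | 0.180 |
| Query 2 | 0.230 | 0.230 | 0.410 | 0.129 |
| Query 3 | 0.180 | 0.320 | 0.180 | 0.320 |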

Each row sums to 1.0 — it's a probability distribution over which tokens to attend to. This is the "attention pattern." Token 0 attends uniformly. Token 1 focuses on tokens 1 and 2. Token 2 focuses on itself.

Step 4: Weighted sum of V

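Each token's output is the weighted average of all value vectors, using the attention weights. For token 2:

output_2 = 0.230 · V0 + 0.230 · V1 + 0.410 · V2 + 0.129 · V3
dim 0: 0.230·0.5 + 0.230·1.0 + 0.410·0.0 + 0.129·0.5 ≈ 0.410
dim 1: 0.230·1.0 + 0.230·0.0 + 0.410·0.5 + 0.129·0.5 ≈ 0.500
dim 2: 0.230·0.0 + 0.230·0.5 + 0.410·1.0 + 0.129·0.5 ≈ 0.590

Token 2's output is [0.410, 0.500, 0.590]. It's a blend of all value vectors, but skewed toward V2 (its own value) because it had the highest attention weight on itself. The output "carries information" from the tokens that scored highest.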

Why the √dk Scaling Matters

Without the √dk scaling, the dot products grow with the dimension of the key vectors. Here's why this is a problem.

If the entries of Q and K are roughly standard normal (mean 0, variance 1) and independent, then each dot product q·k is the sum of d_k products of independent standard normals. Each product has mean 0 and variance 1. Because the variances of independent terms add, the sum has mean 0 and variance d_k, so the standard deviation is √d_k.

For dk = 64 (a typical head size), that means dot products have standard deviation √64 = 8. For dk = 512, it's √512 ≈ 22.6. Values of magnitude 20+ push softmax into extreme saturation — one entry gets weight ≈ 1.0, everything else gets ≈ 0.0. The output becomes nearly a hard lookup of a single token instead of a soft blend.

Worse: softmax gradients vanish in the saturated regime. The model can't learn to adjust attention patterns if the gradients are dead. Dividing by √dk normalizes the dot products back to unit variance, keeping softmax in its sensitive region where gradients flow.

Scaling = keeping softmax healthy. Without scaling, large dk pushes attention toward hard argmax. With scaling, attention stays soft — meaning multiple tokens contribute to each output, and gradients can adjust the pattern during training. This is why the "scaled" in "scaled dot-product attention" is essential, not optional.
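The variance argument is easy to test empirically. The sketch below draws random standard-normal vectors (an assumption, since real Q and K entries are only roughly normal), measures the spread of their dot products, and compares the largest softmax weight with and without scaling. The sample sizes and helper names are ours.

```python
import math
import random
import statistics

def softmax(xs):
    m = max(xs)                          # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

rng = random.Random(0)                   # fixed seed so the run is repeatable

def random_dot(d_k):
    """Dot product of two random vectors with standard-normal entries."""
    return sum(rng.gauss(0, 1) * rng.gauss(0, 1) for _ in range(d_k))

stds = {}
for d_k in [4, 64, 512]:
    dots = [random_dot(d_k) for _ in range(1000)]
    stds[d_k] = statistics.pstdev(dots)
    scaled_std = statistics.pstdev([x / math.sqrt(d_k) for x in dots])
    print(f"d_k={d_k:>3}: raw std={stds[d_k]:6.2f}  sqrt(d_k)={math.sqrt(d_k):5.2f}  "
          f"scaled std={scaled_std:.2f}")

# Same 8 scores for d_k = 512, with and without the 1/sqrt(d_k) factor
scores = [random_dot(512) for _ in range(8)]
max_unscaled = max(softmax(scores))
max_scaled = max(softmax([s / math.sqrt(512) for s in scores]))
print(f"largest weight unscaled: {max_unscaled:.3f}, scaled: {max_scaled:.3f}")
```

The raw standard deviation should track √d_k, the scaled one should sit near 1, and the unscaled softmax should put far more of its weight on a single entry.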

See It: Interactive Attention

[Interactive simulation: Attention Visualization] Clicking a token at the top selects it as the query. Arrows from that token to all others show the attention weights, with thicker arrows meaning higher weight. The heatmap below shows the full n×n attention matrix, and hovering over a cell shows the exact weight.

From Scratch: Attention in Code

First, the pure NumPy implementation — every step visible:

python
import numpy as np

def attention(Q, K, V):
    """Scaled dot-product attention from scratch."""
    d_k = Q.shape[-1]

    # Step 1: raw scores
    scores = Q @ K.T               # (n, n)

    # Step 2: scale
    scores = scores / np.sqrt(d_k) # prevents softmax saturation

    # Step 3: softmax (row-wise)
    exp_scores = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights = exp_scores / exp_scores.sum(axis=-1, keepdims=True)

    # Step 4: weighted sum of values
    output = weights @ V            # (n, d_v)
    return output, weights

# Test with our hand-calculated example
Q = np.array([[1,0,1],[0,1,0],[1,1,0],[0,0,1]], dtype=np.float32)
K = np.array([[1,0,0],[0,1,1],[1,1,0],[0,0,1]], dtype=np.float32)
V = np.array([[.5,1,0],[1,0,.5],[0,.5,1],[.5,.5,.5]], dtype=np.float32)

out, wts = attention(Q, K, V)
print("Attention weights:\n", wts.round(3))
print("Output:\n", out.round(3))

Now the PyTorch one-liner equivalent:

python
import torch
import torch.nn.functional as F

Q = torch.tensor([[1,0,1],[0,1,0],[1,1,0],[0,0,1]], dtype=torch.float32)
K = torch.tensor([[1,0,0],[0,1,1],[1,1,0],[0,0,1]], dtype=torch.float32)
V = torch.tensor([[.5,1,0],[1,0,.5],[0,.5,1],[.5,.5,.5]], dtype=torch.float32)

# PyTorch's built-in (handles batches, masking, dropout too)
out = F.scaled_dot_product_attention(
    Q.unsqueeze(0), K.unsqueeze(0), V.unsqueeze(0)
).squeeze(0)
print(out.round(decimals=3))  # matches our hand calculation

Attention is NOT just "token similarity." The Q and K matrices are LEARNED projections: the model learns WHAT to look for (Q) and HOW to be found (K). A token's query and key are different functions of the same embedding, so two tokens with similar embeddings need not attend to each other, and tokens with very different embeddings can. A period token might learn a query that asks "what is the subject of this sentence?" and a noun might learn a key that answers "I am a subject." After training, the Q/K projections can encode linguistic patterns that raw embedding distance does not capture.

Why do we divide the attention scores by √d_k before applying softmax?

Chapter 2: Multi-Head Attention

One attention pattern isn't enough. Consider the sentence "The cat sat on the mat because it was tired." The word "it" needs to simultaneously attend to "cat" (to resolve the pronoun) and "sat" (to understand what action "it" performed). A single set of Q, K, V projections can only produce one attention pattern — one way of routing information.

Multi-head attention solves this by running multiple attention operations in parallel, each with its own learned Q, K, V projections. Each head can specialize: one head might focus on local syntax (nearby words), another on long-range semantic connections (pronouns to antecedents), another on positional patterns (sentence boundaries). The outputs are concatenated and projected back to the original dimension.

How the Split Works

Here's the key insight that makes multi-head attention efficient: we don't add more parameters for more heads. We split the existing dimension across heads.

If dmodel = 512 and we want H = 8 heads, each head gets:

d_k = d_v = d_model / H = 512 / 8 = 64

Each head operates on 64-dimensional queries, keys, and values instead of 512-dimensional ones. Since the total dimension across all heads is still 8 × 64 = 512, the total computation is the same as a single head with dk = 512.
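The split itself is only a slicing of the feature axis. The pure-Python sketch below illustrates it on a dummy (6, 512) matrix; `split_heads` and `merge_heads` are hypothetical helper names, and real implementations do this with a tensor reshape rather than list slicing.

```python
def split_heads(M, n_heads):
    """Split an (n, d_model) matrix into n_heads matrices of shape (n, d_model // n_heads)."""
    d_model = len(M[0])
    assert d_model % n_heads == 0, "d_model must be divisible by n_heads"
    d_k = d_model // n_heads
    return [[row[h * d_k:(h + 1) * d_k] for row in M] for h in range(n_heads)]

def merge_heads(heads):
    """Inverse of split_heads: concatenate per-head slices along the feature axis."""
    n = len(heads[0])
    return [[x for head in heads for x in head[i]] for i in range(n)]

n, d_model, H = 6, 512, 8
Q = [[float(i * d_model + j) for j in range(d_model)] for i in range(n)]  # dummy values

heads = split_heads(Q, H)
print(len(heads), len(heads[0]), len(heads[0][0]))  # 8 6 64
print(merge_heads(heads) == Q)                      # True
```

Splitting and merging are exact inverses, which is why the total dimension (8 × 64 = 512) and the information carried are unchanged.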

The Full Data Flow

Let's trace the exact tensor shapes through multi-head attention for a concrete example: batch size B = 1, sequence length n = 6, dmodel = 512, H = 8 heads.

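Shape trace: every tensor, every step. Follow along and verify each shape. This mirrors what happens inside nn.MultiheadAttention.

Step 1: Project Q, K, V
X (1, 6, 512) · W_Q (512, 512) → Q (1, 6, 512); likewise for K and V

Step 2: Reshape into heads
(1, 6, 512) → (1, 6, 8, 64) → transpose → (1, 8, 6, 64)

Step 3: Attention per head
scores = Q · K^T / √64: (1, 8, 6, 64) · (1, 8, 64, 6) → (1, 8, 6, 6); softmax(scores) · V → (1, 8, 6, 64)

Step 4: Concatenate heads
(1, 8, 6, 64) → transpose → (1, 6, 8, 64) → (1, 6, 512)

Step 5: Output projection
(1, 6, 512) · W_O (512, 512) → (1, 6, 512)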

Hand Calculation: Parameter Count

How many parameters does MHA have? Let's count for dmodel = 512, H = 8 heads. We need four weight matrices: WQ, WK, WV, and WO.

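Each projection matrix maps d_model to d_model (we project all heads at once, then reshape):

W_Q, W_K, W_V, W_O: 512 × 512 = 262,144 parameters each (ignoring biases)

Total: 4 × 262,144 = 1,048,576 parameters (about 1 million).

Now the crucial comparison. What if we used a single head with d_k = 512?

W_Q, W_K, W_V, and W_O are still 512 × 512 each, so the total is again 4 × 262,144 = 1,048,576.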

The parameter count is the same because we just split the projection into smaller pieces per head. The compute cost (FLOPs) is also the same — we're doing the same matrix multiplications, just organized differently. Multi-head attention is a free lunch: more expressive attention patterns at zero extra cost.
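The parameter comparison can be written as a small function. This sketch assumes the four projections are each d_model × (H · d_k) with d_k = d_model / H; the `bias` flag is our addition so the count can be compared with layers that include bias vectors.

```python
def mha_params(d_model, n_heads, bias=False):
    """Parameters in W_Q, W_K, W_V, W_O when all heads are projected at once."""
    assert d_model % n_heads == 0
    d_k = d_model // n_heads
    per_matrix = d_model * (n_heads * d_k)   # n_heads * d_k == d_model
    total = 4 * per_matrix
    if bias:
        total += 4 * d_model                 # one bias vector per projection
    return total

print(mha_params(512, 8))             # 1048576
print(mha_params(512, 1))             # 1048576 (single head, d_k = 512)
print(mha_params(512, 8, bias=True))  # 1050624
```

The head count cancels out of the formula, so any H that divides d_model gives the same total.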

What Each Head Learns

In practice, different heads learn strikingly different attention patterns. Research on trained transformers (Voita et al., 2019; Clark et al., 2019) has identified several head types:

| Head Type | Pattern | What It Does |
|---|---|---|
| Positional | Attend to previous/next token | Local syntax, bigram patterns |
| Syntactic | Attend to subject-verb pairs | Long-range grammar agreement |
| Semantic | Attend to coreferent tokens | Pronoun resolution, entity tracking |
| Rare | Attend to separator/BOS tokens | "No-op" head: dump unused attention |
| Induction | Attend to token after previous match | In-context learning (copy patterns) |

With only 1 head, the model must compromise — one attention pattern for all these functions. With 8+ heads, each can specialize. This is why multi-head attention is universal in modern transformers.

See It: Multi-Head Patterns

[Interactive simulation: Multi-Head Attention Patterns] The simulation shows attention patterns for different numbers of heads. Each colored block shows one head's attention pattern for a 6-token sequence. With 1 head, you see a single blurred pattern trying to capture everything. With more heads, each individual head sharpens to focus on a specific relationship type.

From Scratch: Full MHA Implementation

python
import torch
import torch.nn as nn
import torch.nn.functional as F

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, n_heads):
        super().__init__()
        assert d_model % n_heads == 0
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_k = d_model // n_heads  # 512 // 8 = 64

        # Project all heads at once for efficiency
        self.W_q = nn.Linear(d_model, d_model)  # (512 → 512)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)

    def forward(self, x):
        B, N, D = x.shape  # (batch, seq_len, d_model)

        # Step 1: Project
        Q = self.W_q(x)  # (B, N, 512)
        K = self.W_k(x)
        V = self.W_v(x)

        # Step 2: Reshape into heads
        # (B, N, 512) → (B, N, 8, 64) → (B, 8, N, 64)
        Q = Q.view(B, N, self.n_heads, self.d_k).transpose(1, 2)
        K = K.view(B, N, self.n_heads, self.d_k).transpose(1, 2)
        V = V.view(B, N, self.n_heads, self.d_k).transpose(1, 2)

        # Step 3: Scaled dot-product attention (per head)
        scores = Q @ K.transpose(-2, -1) / (self.d_k ** 0.5)
        weights = F.softmax(scores, dim=-1)  # (B, 8, N, N)
        attn_out = weights @ V                   # (B, 8, N, 64)

        # Step 4: Concatenate heads
        # (B, 8, N, 64) → (B, N, 8, 64) → (B, N, 512)
        attn_out = attn_out.transpose(1, 2).contiguous().view(B, N, D)

        # Step 5: Output projection
        return self.W_o(attn_out)  # (B, N, 512)

# Verify shapes
mha = MultiHeadAttention(d_model=512, n_heads=8)
x = torch.randn(1, 6, 512)  # (batch=1, seq=6, dim=512)
out = mha(x)
print(out.shape)  # torch.Size([1, 6, 512]) — same as input
print(f"Total params: {sum(p.numel() for p in mha.parameters()):,}")
# Total params: 1,050,624 (≈1M + biases)

Multi-head attention does NOT use more compute than single-head attention with the same d_model. The dimensions just get split: 8 heads × 64d = 1 head × 512d. The total cost of the Q·K^T computation is n² × d_model multiply-adds regardless of how you partition it into heads. More heads = more expressive = zero extra cost. The only overhead is the output projection W_O, which both single-head and multi-head versions need.

If d_model = 1024 and we use 16 heads, what is d_k per head?

Chapter 3: The KV Cache Bottleneck

During training, all tokens are processed in parallel — attention sees the full sequence at once. This is efficient because GPUs excel at large batch matrix multiplies. But during generation, we produce ONE token at a time. Each new token needs to attend to ALL previous tokens.

Without any tricks, generating token number 1000 would require recomputing the key and value vectors for all 999 previous tokens — even though those tokens haven't changed since the last step. That's a colossal waste.

The solution is the KV cache: store every previous token's key and value vectors in memory, and reuse them on each generation step. This turns attention from O(n²) recomputation into O(n) new work per step — but at a steep memory cost.
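A toy decode loop shows that caching changes the amount of work, not the result. This is a minimal single-head, single-layer sketch in pure Python with random stand-in weights and embeddings; the helper names are ours, and a real model would also feed each generated token back in as the next input.

```python
import math
import random

def matvec(x, W):
    """Row vector x (length d_in) times matrix W (d_in x d_out)."""
    return [sum(xi * wi for xi, wi in zip(x, col)) for col in zip(*W)]

def attend(q, keys, values):
    """Scaled dot-product attention for a single query over lists of keys and values."""
    scores = [sum(a * b for a, b in zip(q, k)) / math.sqrt(len(q)) for k in keys]
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [sum(e / total * v[j] for e, v in zip(exps, values))
            for j in range(len(values[0]))]

rng = random.Random(0)
d = 4
W_q, W_k, W_v = ([[rng.gauss(0, 1) for _ in range(d)] for _ in range(d)] for _ in range(3))
tokens = [[rng.gauss(0, 1) for _ in range(d)] for _ in range(6)]  # stand-in embeddings

k_cache, v_cache = [], []
for t, x in enumerate(tokens):
    # With a cache: project only the newest token and append it.
    k_cache.append(matvec(x, W_k))
    v_cache.append(matvec(x, W_v))
    out_cached = attend(matvec(x, W_q), k_cache, v_cache)

    # Without a cache: re-project every token seen so far.
    keys = [matvec(tok, W_k) for tok in tokens[: t + 1]]
    values = [matvec(tok, W_v) for tok in tokens[: t + 1]]
    out_recomputed = attend(matvec(x, W_q), keys, values)

    assert all(abs(a - b) < 1e-12 for a, b in zip(out_cached, out_recomputed))

print("cached and recomputed outputs match; cache length:", len(k_cache))
```

Each step with the cache does one K projection and one V projection, while the uncached path does t + 1 of each, and the cache grows by one entry per token.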

How Autoregressive Generation Works

Let's trace exactly what happens when an LLM generates a 5-token response to a 3-token prompt. We'll call the prompt tokens [A, B, C] and the generated tokens [D, E, F, G, H].

Step-by-step generation. At each step, the model takes in the FULL sequence so far, runs attention, and predicts the next token. Without caching, this means recomputing Q, K, V for everything.

Step 1 (prefill): Process prompt [A, B, C]

Step 2: Generate token D

Step 3: Generate token E

Without caching, every step recomputes K and V for the entire sequence so far. The total K/V computations across all five steps: 3 + 4 + 5 + 6 + 7 = 25. With caching: 3 + 1 + 1 + 1 + 1 = 7. That's a 3.6× savings for just 5 generated tokens. For 1,000 generated tokens with a 1,000-token prompt, the savings are about 750× (1,499,500 computations versus 1,999).
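The same tally can be computed for any prompt and response length. The sketch below assumes, as the sums in the text do, that the prefill pass also yields the first generated token, so n generated tokens take n forward passes; the function name is ours.

```python
def kv_computations(prompt_len, n_generated, use_cache):
    """Count per-token K/V projections over the prefill pass plus the decode passes."""
    total = 0
    for step in range(n_generated):
        seq_len = prompt_len + step                  # tokens visible at this pass
        if use_cache:
            total += prompt_len if step == 0 else 1  # prefill once, then one new token
        else:
            total += seq_len                         # recompute K and V for everything
    return total

no_cache = kv_computations(3, 5, use_cache=False)    # 3 + 4 + 5 + 6 + 7
cached = kv_computations(3, 5, use_cache=True)       # 3 + 1 + 1 + 1 + 1
print(no_cache, cached, round(no_cache / cached, 1)) # 25 7 3.6

big_no_cache = kv_computations(1000, 1000, use_cache=False)
big_cached = kv_computations(1000, 1000, use_cache=True)
print(big_no_cache, big_cached, round(big_no_cache / big_cached))  # 1499500 1999 750
```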

Hand Calculation: KV Cache Memory

The critical question. The KV cache saves enormous compute. But it costs memory — and memory turns out to be the binding constraint in LLM serving. Let's calculate exactly how much.

For each token in the sequence, we store one key vector and one value vector per layer, per head. The cache size per token:

bytes per token = 2 · n_layers · n_heads · d_head · bytes_per_param

The factor of 2 is for K and V (we cache both). Let's compute this for Llama 2 7B:

| Parameter | Llama 2 7B Value |
|---|---|
| n_layers | 32 |
| n_heads | 32 |
| d_head | 128 (= 4096 / 32) |
| Precision | FP16 (2 bytes per value) |

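Bytes per token:

2 × 32 layers × 32 heads × 128 × 2 bytes = 524,288 bytes (512 KiB, about 0.52 MB)

At different sequence lengths:

2,048 tokens: 524,288 × 2,048 ≈ 1.07 GB
4,096 tokens: 524,288 × 4,096 ≈ 2.15 GB
8,192 tokens: 524,288 × 8,192 ≈ 4.29 GB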

Now here's where it gets alarming. The model weights themselves for Llama 2 7B in FP16 are about 14 GB. That's fixed — load once, serve forever. But the KV cache is per user, per request.

For a batch of 32 concurrent users at 4,096 context:

32 × 2.15 GB ≈ 68.7 GB of KV cache

The KV cache alone needs 5× more memory than the model weights. This is why you can't "just add more users" to a GPU — each concurrent user adds ~2 GB of KV cache, and GPU memory is finite.

The Full Picture: Memory Breakdown

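| Model | Weights (FP16) | KV Cache / User @4K | Max Users on H100 (80 GB) |
|---|---|---|---|
| Llama 2 7B | 14 GB | 2.15 GB | ~30 |
| Llama 2 13B | 26 GB | 3.36 GB | ~16 |
| Llama 2 70B | 140 GB | 10.7 GB | 0 (needs multi-GPU) |

These KV cache figures use decimal GB and assume standard multi-head attention with one K/V pair per head. The released Llama 2 70B actually uses grouped-query attention with 8 KV heads, so its real cache is 8× smaller than the figure shown here.

For the 70B model, the weights alone don't fit on a single H100. Even with tensor parallelism across 2 GPUs (160 GB total), after loading the 140 GB model, you have only 20 GB left, enough for just one full MHA-sized user at 4K context. This is why KV cache optimization (the topic of the next chapters) is one of the most active areas of LLM systems engineering.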
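The "max users" column follows from a simple budget: memory left after the weights, divided by the per-user cache. The sketch below uses decimal GB and ignores activations, fragmentation, and framework overhead, so real limits would be lower; both function names are ours.

```python
def kv_gb_per_user(n_layers, n_kv_heads, d_head, seq_len, bytes_per_value=2):
    """KV cache for one sequence, in decimal GB. The leading 2 counts K and V."""
    return 2 * n_layers * n_kv_heads * d_head * bytes_per_value * seq_len / 1e9

def max_concurrent_users(gpu_gb, weights_gb, kv_gb):
    """Whole users whose KV cache fits in the memory left after loading the weights."""
    free_gb = gpu_gb - weights_gb
    if free_gb <= 0:
        return 0
    return int(free_gb // kv_gb)

kv_7b = kv_gb_per_user(32, 32, 128, 4096)
kv_13b = kv_gb_per_user(40, 40, 128, 4096)
print(round(kv_7b, 2), max_concurrent_users(80, 14, kv_7b))    # 2.15 30
print(round(kv_13b, 2), max_concurrent_users(80, 26, kv_13b))  # 3.36 16
print(max_concurrent_users(80, 140, 10.7))                     # 0
```

Passing a smaller `n_kv_heads` to `kv_gb_per_user` shows directly how sharing K and V across heads raises the user count.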

See It: KV Cache Growth

[Interactive simulation: KV Cache Growth] Clicking "Generate" adds one token at a time and the cache expands. Blue bars are K cache entries, green bars are V cache entries (one per layer). A memory counter shows how many bytes are consumed, and controls for model size and number of concurrent users show how they affect the cache. "Reset" clears the sequence.

From Scratch: KV Cache Calculator

python
def kv_cache_size_gb(
    n_layers,
    n_heads,
    d_head,
    seq_len,
    n_users=1,
    bytes_per_param=2,  # FP16
):
    """Calculate KV cache memory in decimal GB (1 GB = 1e9 bytes)."""
    bytes_per_token = 2 * n_layers * n_heads * d_head * bytes_per_param
    total_bytes = bytes_per_token * seq_len * n_users
    return total_bytes / 1e9

# Llama 2 family, treated here as standard MHA (one K/V pair per head).
# Note: the released 70B model uses GQA with 8 KV heads, so its real cache is smaller.
models = {
    "7B":  (32, 32, 128),   # (layers, heads, d_head)
    "13B": (40, 40, 128),
    "70B": (80, 64, 128),
}

print(f"{'Model':>6} {'SeqLen':>7} {'Users':>6} {'KV Cache':>10} {'Weights':>10} {'Total':>10}")
print("-" * 55)

for name, (nl, nh, dh) in models.items():
    weight_gb = {"7B": 14, "13B": 26, "70B": 140}[name]
    for seq in [2048, 4096, 8192]:
        for users in [1, 32]:
            kv = kv_cache_size_gb(nl, nh, dh, seq, users)
            total = kv + weight_gb
            print(f"{name:>6} {seq:>7} {users:>6} {kv:>9.2f}G {weight_gb:>9}G {total:>9.1f}G")

# Key output lines:
#     7B    4096     32     68.72G        14G      82.7G  ← doesn't fit H100!
#    70B    4096      1     10.74G       140G     150.7G  ← needs 2+ GPUs
#    70B    4096     32    343.60G       140G     483.6G  ← needs 7 H100s

The KV cache saves compute, but its price is MEMORY, and memory is what limits serving. A single H100 has 80 GB. If your model weights take 14 GB and each user's KV cache takes 2.15 GB at 4K context, you can serve at most ~30 concurrent users, regardless of how fast the GPU is. Memory, not FLOPs, is the bottleneck in LLM serving. This is why every technique in the next chapters (MQA, GQA, paged attention) focuses on shrinking the KV cache. A 2× reduction in KV cache size means roughly 2× more concurrent users on the same hardware.

Why This Changes Everything

The KV cache bottleneck explains many of the engineering decisions in modern LLM deployment, including the techniques named above: MQA, GQA, and paged attention.

The rest of this lesson dives into each of these techniques. Every one of them exists because someone looked at the KV cache memory equation and found a term to shrink.

For Llama 2 7B (32 layers, 32 heads, dhead=128) serving 64 concurrent users at 4,096 context in FP16, approximately how much KV cache memory is needed?

Chapter 4: Multi-Query Attention

Chapter 3 showed us the KV cache is the memory bottleneck. Every attention head stores its own K and V vectors for every token generated so far. For a model with 32 heads, that's 32 separate copies of keys and 32 separate copies of values, all sitting in GPU memory, all growing with every new token.

Here's the first clever fix: what if all 32 attention heads shared a single set of K and V vectors? Each head still asks its own question (different Q projections), but they all search through the same keys and retrieve from the same values. That's Multi-Query Attention (MQA), proposed by Noam Shazeer in 2019.

The idea sounds radical — surely different heads need different keys? But remember: Q is the "question," and K is the "index." If every head asks a different question, they'll get different attention patterns even against the same set of keys. The queries are still fully expressive; we're just sharing the lookup table.

The Architecture Change

In standard Multi-Head Attention (MHA), the input X gets projected into Q, K, and V per head. If there are H heads and each head has dimension dh, the K and V projection matrices are each of size [dmodel × (H · dh)]. That's one dh-dimensional key vector and one dh-dimensional value vector per head, per token.

In MQA, the K and V projection matrices shrink to [dmodel × dh] — just one dh-dimensional key vector and one dh-dimensional value vector, shared across all H heads. The Q projection stays the same: each head still gets its own query.
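The change in projection size can be counted directly. This sketch uses d_model = 4096 and H = 32 (the Llama 2 7B sizes quoted earlier) purely as an example, counts weights only (no biases), and uses a hypothetical helper name.

```python
def kv_projection_params(d_model, n_heads, n_kv_heads):
    """Weights in W_K plus W_V when n_kv_heads key/value heads are projected."""
    d_head = d_model // n_heads
    return 2 * d_model * (n_kv_heads * d_head)

d_model, H = 4096, 32
mha = kv_projection_params(d_model, H, n_kv_heads=H)  # W_K, W_V: [4096 x 4096] each
mqa = kv_projection_params(d_model, H, n_kv_heads=1)  # W_K, W_V: [4096 x 128] each
print(mha, mqa, mha // mqa)  # 33554432 1048576 32
```

The Q and output projections are untouched, so the overall parameter count shrinks by much less than 32×; the large saving is in what gets cached per token.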

Input X: [batch, seq, d_model]
Q Projection (per head): H separate W_Q matrices → H query vectors per token
K, V Projection (SHARED): one W_K, one W_V → 1 key + 1 value per token
Attention per head: softmax(Q_h · K^T / √d_h) · V, with the same K and V but a different Q for each head

Hand Calculation: Memory Savings

Let's get concrete with Llama 2 7B numbers. The model has 32 attention heads, each with dhead = 128. Weights are stored in float16 (2 bytes per value). The KV cache stores K and V for every token generated so far.

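MHA KV cache per token:

2 (K and V) × 32 layers × 32 heads × 128 × 2 bytes = 524,288 bytes (512 KiB)

MQA KV cache per token:

2 (K and V) × 32 layers × 1 shared head × 128 × 2 bytes = 16,384 bytes (16 KiB)

That's a 32× reduction, exactly equal to the number of heads.

Now scale it up. With a context window of 4,096 tokens:

MHA: 524,288 × 4,096 ≈ 2.15 GB per user. MQA: 16,384 × 4,096 ≈ 67 MB per user.

Now consider serving 64 concurrent users (a modest production workload):

MHA: 64 × 2.15 GB ≈ 137 GB. MQA: 64 × 67 MB ≈ 4.3 GB.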

For larger models the savings are even more dramatic. PaLM 540B with 48 heads? MQA shrinks the KV cache by 48×. At scale, this can be the difference between a KV cache that has to be spread across many GPUs and one that fits on a single device.
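The MHA and MQA cache sizes differ only in the number of K/V heads stored. The sketch below applies the Chapter 3 formula with that number as a parameter, assuming 32 layers, d_head = 128, and FP16; the function name is ours and sizes are reported in decimal GB.

```python
def kv_cache_bytes(n_layers, n_kv_heads, d_head, seq_len=1, n_users=1, bytes_per_value=2):
    """KV cache size in bytes; the leading 2 counts K and V."""
    return 2 * n_layers * n_kv_heads * d_head * bytes_per_value * seq_len * n_users

mha_tok = kv_cache_bytes(32, 32, 128)  # one K/V pair per head
mqa_tok = kv_cache_bytes(32, 1, 128)   # one shared K/V pair
print(mha_tok, mqa_tok, mha_tok // mqa_tok)  # 524288 16384 32

for users in (1, 64):
    mha = kv_cache_bytes(32, 32, 128, seq_len=4096, n_users=users) / 1e9
    mqa = kv_cache_bytes(32, 1, 128, seq_len=4096, n_users=users) / 1e9
    print(f"{users:>2} users @ 4096: MHA {mha:.2f} GB, MQA {mqa:.2f} GB")
```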

MHA vs MQA: Head Architecture & Memory

[Interactive simulation] Left: Multi-Head Attention (each head has its own K,V). Right: Multi-Query Attention (all heads share one K,V). Controls for the number of heads and the sequence length let you compare memory usage.

The Quality Concern

The obvious question: if all heads share the same K and V, don't they all attend to the same things? Don't we lose the diversity that makes multi-head attention powerful?

Surprisingly, no — or at least, not much. Each head still has its own Q projection, which means each head asks a different question of the shared keys. Head 1 might project its query to emphasize syntactic features; head 12 might emphasize positional features. They'll compute different attention weights (softmax(QhKT/√d)) because the Qh differs, even though K is the same for all.
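A tiny example makes the point that shared keys do not force shared attention patterns. The key and query values below are hypothetical, chosen by hand so that the two heads emphasize different key dimensions.

```python
import math

def softmax(xs):
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

# One shared set of keys for three tokens (d_head = 2), as in MQA.
shared_keys = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]

# Two heads form different queries for the same token (hypothetical values).
queries = {"head_a": [2.0, 0.0], "head_b": [0.0, 2.0]}

patterns = {}
for name, q in queries.items():
    scores = [sum(a * b for a, b in zip(q, k)) / math.sqrt(len(q)) for k in shared_keys]
    patterns[name] = softmax(scores)
    print(name, [round(w, 3) for w in patterns[name]])
```

Head A favors tokens 0 and 2 while head B favors tokens 1 and 2, even though both read exactly the same keys.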

Empirically, MQA drops quality by roughly 0.5–1% on standard benchmarks compared to MHA. That's a tiny cost for a 32× memory reduction. The tradeoff is overwhelmingly worthwhile for production inference.

Code: MQA Implementation

Here's how MQA modifies the standard MHA forward pass. The key change is in the projection dimensions: K and V are projected to dhead (not H × dhead), then broadcast across all heads during the attention computation.

python
import torch
import torch.nn as nn

class MultiQueryAttention(nn.Module):
    def __init__(self, d_model, n_heads):
        super().__init__()
        self.n_heads = n_heads
        self.d_head = d_model // n_heads

        # Q: per-head projections (standard)
        self.W_q = nn.Linear(d_model, d_model)  # [d, H*d_h]

        # K, V: SINGLE head projection (the MQA trick)
        self.W_k = nn.Linear(d_model, self.d_head)  # [d, d_h]
        self.W_v = nn.Linear(d_model, self.d_head)  # [d, d_h]

        self.W_o = nn.Linear(d_model, d_model)

    def forward(self, x):
        B, T, D = x.shape
        H, d_h = self.n_heads, self.d_head

        # Project Q, K, V
        q = self.W_q(x).view(B, T, H, d_h).transpose(1, 2)  # [B, H, T, d_h]
        k = self.W_k(x).unsqueeze(1)                             # [B, 1, T, d_h]
        v = self.W_v(x).unsqueeze(1)                             # [B, 1, T, d_h]

        # k and v broadcast across H heads automatically
        scores = (q @ k.transpose(-2, -1)) / (d_h ** 0.5)  # [B, H, T, T]
        attn = scores.softmax(dim=-1)
        out = (attn @ v).transpose(1, 2).reshape(B, T, D)     # [B, T, D]
        return self.W_o(out)

The critical lines are the W_k and W_v projections: they output dhead dimensions (one head's worth), then unsqueeze(1) adds the head dimension so PyTorch's broadcasting does the rest. Each of the H heads computes QhKT using the same K, without any copying.

Real-world adoption. MQA has shipped in widely used models: Google's PaLM (540B), BigCode's StarCoder (15B), and TII's Falcon 7B (the larger Falcon 40B uses a grouped variant with 8 KV heads). When you need maximum KV cache savings and can tolerate a fraction of a percent quality drop, MQA is the go-to choice.
Misconception: MQA reduces compute. It does not. The attention matrix is still n×n per head — every head still computes softmax(QKT/√d). MQA reduces memory by shrinking the KV cache. The Q projections are still per-head. Only K and V are shared. Compute is virtually unchanged; memory plummets.
If a model has 32 attention heads and switches from MHA to MQA, by what factor does the KV cache shrink?

Chapter 5: Grouped-Query Attention

MQA is aggressive. All heads share a single set of keys and values. For most large models, the quality drop is negligible. But some teams found it too steep — especially for models in the 7B–13B range where every fraction of a percent on benchmarks matters for competitive positioning.

What if we could dial the tradeoff between MHA and MQA? Instead of giving every head its own KV (MHA) or sharing one KV across all heads (MQA), what if we put heads into groups and shared one KV per group?

That's Grouped-Query Attention (GQA), introduced by Ainslie et al. at Google in 2023. Think of it as a knob: turn it all the way left and you get MQA (one group). Turn it all the way right and you get MHA (every head is its own group). GQA sits anywhere in between.

How Groups Work

Say your model has H = 32 query heads and you choose G = 8 KV groups. That means every group of 32/8 = 4 query heads shares a single K and V. The model has 8 independent K projections and 8 independent V projections (each a d_model × d_head matrix), instead of 32 (MHA) or 1 (MQA).

During attention, heads 0–3 all use KV group 0. Heads 4–7 use KV group 1. And so on. Within each group, the mechanism is identical to MQA: different queries, shared keys and values.
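The head-to-group assignment described above is plain integer division. Here is a minimal sketch, assuming contiguous grouping; the function name `kv_group_for_head` is illustrative and not taken from any library.

```python
def kv_group_for_head(head: int, n_heads: int, n_kv_groups: int) -> int:
    """Return the index of the KV group that a query head reads from."""
    if n_heads % n_kv_groups != 0:
        raise ValueError("n_heads must be divisible by n_kv_groups")
    heads_per_group = n_heads // n_kv_groups
    return head // heads_per_group

# H = 32 query heads, G = 8 KV groups -> 4 query heads per group
mapping = [kv_group_for_head(h, 32, 8) for h in range(32)]
print(mapping[:8])  # [0, 0, 0, 0, 1, 1, 1, 1]
```

Setting `n_kv_groups=1` maps every head to group 0 (MQA), and `n_kv_groups=n_heads` gives each head its own group (MHA).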

The GQA spectrum. MHA, GQA, and MQA are not three different algorithms — they're points on a single continuum. GQA with G = H is MHA. GQA with G = 1 is MQA. Every value of G in between gives a different memory/quality tradeoff. The "grouped-query" name is the general case; MHA and MQA are the extremes.

Hand Calculation: Llama 2 70B

Llama 2 70B uses GQA with 64 query heads and 8 KV groups. Each head has dhead = 128. Let's compute the KV cache savings.

Heads per group: 64 query heads / 8 groups = 8 query heads per KV group.

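GQA-8 KV cache per token (per layer, FP16):

2 (K and V) × 8 groups × 128 dims × 2 bytes = 4,096 bytes

MHA (for comparison) KV cache per token:

2 × 64 heads × 128 dims × 2 bytes = 32,768 bytes

MQA (for comparison) KV cache per token:

2 × 1 head × 128 dims × 2 bytes = 512 bytes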

GQA-8 gives an 8× reduction over MHA. Not as extreme as MQA's 64×, but with much better quality retention. Here's the full comparison at 4,096 tokens context:

Method | KV Groups | KV/Token (per layer) | KV @ 4K ctx (per layer) | 64 Users (per layer)
MHA | 64 | 32,768 B | 128 MB | 8.0 GB
GQA-8 | 8 | 4,096 B | 16 MB | 1.0 GB
MQA | 1 | 512 B | 2 MB | 128 MB

At 64 concurrent users, MHA needs 8 GB of KV cache for a single layer. Llama 2 70B has 80 layers, so the full cache would be 640 GB, far more than any single GPU holds. GQA-8 needs 1 GB per layer (80 GB in total) and MQA needs 128 MB per layer (10 GB in total). The sweet spot depends on your model size and serving constraints, but for most 70B-class models, GQA-8 is the usual choice.
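The table values can be reproduced with a few lines of arithmetic. This is a sketch under the same assumptions as the hand calculation (FP16, d_head = 128, figures for one layer unless `n_layers` is set); the helper name `kv_cache_bytes` is made up for illustration.

```python
def kv_cache_bytes(n_kv_heads, d_head, n_tokens=1, n_layers=1, n_seqs=1,
                   bytes_per_value=2):
    """KV cache size in bytes: the factor 2 covers K and V."""
    return 2 * n_kv_heads * d_head * bytes_per_value * n_tokens * n_layers * n_seqs

MiB = 1024 ** 2
for name, kv_heads in [("MHA", 64), ("GQA-8", 8), ("MQA", 1)]:
    per_token = kv_cache_bytes(kv_heads, 128)
    at_4k = kv_cache_bytes(kv_heads, 128, n_tokens=4096)
    users_64 = kv_cache_bytes(kv_heads, 128, n_tokens=4096, n_seqs=64)
    print(f"{name:6s} {per_token:>6,d} B/token  {at_4k // MiB:>4d} MiB @ 4K  "
          f"{users_64 // MiB:>5d} MiB for 64 users")
```

Passing `n_layers=80` scales any of these figures to a whole 80-layer model.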

The GQA Dial: Memory vs Quality

Drag the slider to move between MQA (1 group), GQA, and MHA (all heads). Watch how heads regroup and memory changes.


The Conversion Trick

Here's a remarkable fact: you can convert an already-trained MHA model to GQA without retraining from scratch. The recipe is simple: take the K and V weight matrices from each group of heads and average them.

Say you have 64 heads and want 8 groups. Heads 0–7 form group 0. Take their 8 K weight matrices, average them element-by-element, and use that average as the single K matrix for the group. Do the same for V. Repeat for all 8 groups.

Then continue training the model briefly so it can adapt to the shared parameters, a step the GQA paper calls uptraining. Ainslie et al. applied this recipe to T5 checkpoints and reported that uptraining with roughly 5% of the original pre-training compute recovers quality close to the MHA baseline. Note that Llama 2 70B did not go through this conversion: it was trained with GQA-8 from the start.
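The averaging step is short enough to write out. The sketch below uses plain nested lists in place of real weight tensors, and the function name `mean_pool_heads` is illustrative; with a tensor library the same idea is a reshape followed by a mean over the within-group axis.

```python
def mean_pool_heads(head_weights, n_groups):
    """Average per-head projection matrices within each group.

    head_weights: list of H matrices (each a list of rows of floats), one per head.
    Returns a list of n_groups matrices, each the element-wise mean of its group.
    """
    n_heads = len(head_weights)
    if n_heads % n_groups != 0:
        raise ValueError("number of heads must be divisible by n_groups")
    per_group = n_heads // n_groups
    pooled = []
    for g in range(n_groups):
        group = head_weights[g * per_group:(g + 1) * per_group]
        rows, cols = len(group[0]), len(group[0][0])
        pooled.append([
            [sum(w[r][c] for w in group) / per_group for c in range(cols)]
            for r in range(rows)
        ])
    return pooled

# Toy example: 4 heads with 1x2 "matrices", pooled into 2 groups
k_heads = [[[1.0, 2.0]], [[3.0, 4.0]], [[10.0, 20.0]], [[30.0, 40.0]]]
k_groups = mean_pool_heads(k_heads, n_groups=2)
print(k_groups)  # [[[2.0, 3.0]], [[20.0, 30.0]]]
```

The same function would be applied once to the K heads and once to the V heads; the Q and output projections are left untouched.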

Code: GQA Implementation

The key operation in GQA is the group repeat: after projecting K and V to G groups, we need to "expand" them so each query head has matching K and V tensors. PyTorch handles this with repeat_interleave or manual reshaping.

python
import torch
import torch.nn as nn

class GroupedQueryAttention(nn.Module):
    def __init__(self, d_model, n_heads, n_kv_groups):
        super().__init__()
        self.n_heads = n_heads
        self.n_kv = n_kv_groups
        self.d_head = d_model // n_heads
        self.heads_per_group = n_heads // n_kv_groups

        # Q: full per-head projections
        self.W_q = nn.Linear(d_model, n_heads * self.d_head)
        # K, V: only n_kv_groups projections
        self.W_k = nn.Linear(d_model, n_kv_groups * self.d_head)
        self.W_v = nn.Linear(d_model, n_kv_groups * self.d_head)
        self.W_o = nn.Linear(d_model, d_model)

    def repeat_kv(self, x):
        # x: [B, G, T, d_h] -> [B, H, T, d_h]
        # Repeat each KV group heads_per_group times
        B, G, T, d = x.shape
        x = x[:, :, None, :, :].expand(B, G, self.heads_per_group, T, d)
        return x.reshape(B, self.n_heads, T, d)

    def forward(self, x):
        B, T, D = x.shape
        d_h = self.d_head

        q = self.W_q(x).view(B, T, self.n_heads, d_h).transpose(1, 2)
        k = self.W_k(x).view(B, T, self.n_kv, d_h).transpose(1, 2)
        v = self.W_v(x).view(B, T, self.n_kv, d_h).transpose(1, 2)

        # Expand K, V to match query head count
        k = self.repeat_kv(k)  # [B, H, T, d_h]
        v = self.repeat_kv(v)  # [B, H, T, d_h]

        scores = (q @ k.transpose(-2, -1)) / (d_h ** 0.5)
        attn = scores.softmax(dim=-1)
        out = (attn @ v).transpose(1, 2).reshape(B, T, D)
        return self.W_o(out)

The repeat_kv method is the core of GQA: it takes the G-grouped K or V tensor and repeats each group H/G times so the shapes align for the batched matrix multiply. In this version, expand creates a view, but the reshape that follows has to materialize a copy, so the expanded K and V are temporary tensors. The memory savings are still real because the KV cache stores only the G-group tensors; optimized kernels can index the shared KV heads directly and skip the expansion.

Misconception: GQA is "worse MHA." For large models (70B+), GQA-8 comes very close to MHA quality while using 8× less KV cache memory. The quality loss is measurable on academic benchmarks but tiny in practice. Llama 2 70B, Llama 3, Gemma 2, Mistral 7B, and Qwen2 all use GQA. It is the de facto standard for modern large language models.
When to use which. For models under 7B: MQA often suffices — quality loss is small and memory savings are huge. For 7B–70B: GQA with 4–8 groups hits the sweet spot. For research/quality-first: MHA retains full head diversity but costs the most memory. Choose based on your serving budget.
A model has 32 query heads and uses GQA with 4 KV groups. How many query heads share each KV pair?

Chapter 6: Sliding Window Attention

MQA and GQA attacked the KV cache size per token — shrinking how much memory each position needs. But there's a second dimension to the problem: the number of tokens each position attends to. In standard attention, every token looks at every previous token. That's the n×n attention matrix that makes the whole thing quadratic.

What if each token only looked at the W nearest tokens instead of all n? That's Sliding Window Attention — also called local attention. It caps the attention cost per token at W, making the total cost O(n · W) instead of O(n²). When the sequence is long and the window is moderate, the savings are enormous.

This isn't a new idea — local attention patterns appeared in research as early as 2019. But it became famous when Mistral 7B shipped with sliding window attention as a key architectural choice, showing that you could match or beat much larger models by being smart about what each token actually needs to see.

Why Local Context Is Usually Enough

Think about how language works. When you read the word "running" in "the dog is running through the park," you need the nearby tokens to understand it. You almost never need token #47 from 10,000 positions ago. Language is overwhelmingly local: subjects, verbs, and objects cluster together. Pronouns reference things within a few sentences. Semantic coherence is maintained over paragraphs, not pages.

There are exceptions — a reference to a character introduced chapters ago, or a variable defined hundreds of lines up in code. But these long-range dependencies are rare compared to local ones. Sliding window attention bets that most attention mass lands in a local neighborhood, and that bet pays off.

The Mask Pattern

Visualize the attention matrix as a grid. Rows are query positions (tokens asking questions). Columns are key positions (tokens being attended to). In standard causal attention, the lower-left triangle is filled — each token can attend to itself and all previous tokens.

In sliding window attention with window size W, each token can only attend to the W most recent tokens (including itself). The filled region shrinks from a triangle to a diagonal band of width W. Tokens outside the band get attention weight zero.

Attention Mask Patterns

Compare mask patterns: full causal, sliding window, and hybrid (Longformer-style with global tokens). Each cell is one attention entry.


Hand Calculation: Computation Savings

Let's count the number of attention entries (cells in the matrix) for different configurations.

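Full causal attention, n = 16,384:

The lower triangle has n · (n+1) / 2 entries. For n = 16,384:

16,384 × 16,385 / 2 = 134,225,920 ≈ 134 million entries

Sliding window, n = 16,384, W = 4,096:

Each of the first W tokens attends to all previous tokens (building up to the full window). After that, each token attends to exactly W tokens. A close upper bound is n · W:

16,384 × 4,096 = 67,108,864 ≈ 67 million entries (about 2× fewer)

At n = 32,768 with W = 4,096:

Full: 32,768 × 32,769 / 2 ≈ 537 million entries. Sliding window: 32,768 × 4,096 ≈ 134 million entries (about 4× fewer).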

The longer the sequence relative to the window, the bigger the savings. At n = 131,072 with W = 4,096: full attention requires ∼8.6 billion entries, while sliding window needs ∼537 million. That's a 16× reduction. The savings scale as n / (2W) for large n.
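To check these counts, or try other window sizes, here is a small sketch. It assumes the window convention used in this chapter (token i attends to positions i−W+1 through i); the function names are illustrative.

```python
def full_causal_entries(n):
    """Entries in the lower triangle, including the diagonal."""
    return n * (n + 1) // 2

def sliding_window_entries(n, w):
    """Exact entry count when token i attends to positions max(0, i-w+1)..i."""
    if w >= n:
        return full_causal_entries(n)
    # The first w tokens form a triangle; the remaining n-w tokens see w entries each.
    return w * (w + 1) // 2 + (n - w) * w

for n in (16_384, 32_768, 131_072):
    full = full_causal_entries(n)
    exact = sliding_window_entries(n, 4_096)
    bound = n * 4_096
    print(f"n={n:>7,d}  full={full:>13,d}  window={exact:>11,d}  "
          f"bound={bound:>11,d}  ratio={full / bound:.1f}x")
```

The exact count is always a little below the n · W bound, because the first W tokens do not yet have a full window behind them.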

The Receptive Field Trick

The obvious worry: if each token only sees W positions, doesn't the model lose access to information beyond the window? How can it handle a question like "Who was the main character introduced in chapter 1?" when you're 50,000 tokens later?

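The answer is layer stacking. Each layer extends the effective reach by W positions. Think of it like a relay race: at layer 1, a token reads directly from the previous W positions. At layer 2, it reads from hidden states that have already absorbed information from W positions further back, so its reach is about 2W. After L layers, information can have traveled about L × W positions.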

Mistral 7B has 32 layers and window size W = 4,096. Its effective receptive field is:

Effective reach = L × W = 32 × 4,096 = 131,072 tokens

That's about 131K tokens, far beyond the model's actual context window. This figure is a theoretical upper bound on how far information can travel. The sliding window doesn't block information flow; it makes information travel through the network layer by layer rather than arriving all at once. Empirically, this works well for the vast majority of tasks.
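The L × W figure is a convenient round number. Counting exactly under the [i−W+1, i] window convention, each layer moves the earliest reachable position back by W−1, so the reach is L × (W−1) + 1 positions. A short sketch (the function name is illustrative):

```python
def receptive_field(n_layers, window):
    """Positions that can influence one token after n_layers of windowed attention.

    Assumes each layer lets position i read positions i-window+1 .. i, so every
    layer pushes the earliest reachable position back by window-1.
    """
    return n_layers * (window - 1) + 1  # +1 counts the token itself

print(receptive_field(1, 4_096))   # 4096: one layer sees exactly the window
print(receptive_field(32, 4_096))  # 131041: just under 32 * 4096 = 131072
```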

The Rolling Buffer KV Cache

Sliding window attention enables another powerful optimization: a rolling buffer KV cache. In standard attention, the KV cache grows with every new token and never shrinks. At position 10,000, you're storing KV entries for all 10,000 previous positions.

With a sliding window, token 10,000 only attends to positions [5,905, 10,000] (for W = 4,096). Positions before 5,905 are outside every current and future token's window. They will never be attended to again. So we can evict them from the cache.

The result: the KV cache is bounded at W entries, regardless of how long the generation runs. Instead of KV cache growing linearly with sequence length, it stays constant. For long-form generation (writing a novel, processing a codebase), this is transformative.
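A rolling buffer is usually implemented as a ring: position i is written to slot i mod W, which overwrites the entry that has just fallen out of the window. The class below is a minimal sketch with placeholder strings standing in for key and value tensors; the name `RollingKVCache` is illustrative.

```python
class RollingKVCache:
    """Fixed-size KV cache: position i is stored in slot i % window."""

    def __init__(self, window):
        self.window = window
        self.slots = [None] * window  # each slot holds (position, key, value)
        self.next_pos = 0

    def append(self, key, value):
        slot = self.next_pos % self.window
        self.slots[slot] = (self.next_pos, key, value)  # overwrites the oldest entry
        self.next_pos += 1

    def visible(self):
        """Return cached (position, key, value) entries, oldest first."""
        entries = [e for e in self.slots if e is not None]
        return sorted(entries, key=lambda e: e[0])

cache = RollingKVCache(window=3)
for pos in range(8):
    cache.append(key=f"k{pos}", value=f"v{pos}")

print([p for p, _, _ in cache.visible()])  # [5, 6, 7]
```

The cache never holds more than `window` entries, no matter how many tokens are appended.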

Code: Sliding Window Mask

python
import torch

def sliding_window_mask(seq_len, window_size):
    # Create causal mask first (lower triangle)
    causal = torch.ones(seq_len, seq_len, dtype=torch.bool).tril()

    # Create window mask: token i attends to [i-W+1, i]
    rows = torch.arange(seq_len).unsqueeze(1)    # [n, 1]
    cols = torch.arange(seq_len).unsqueeze(0)    # [1, n]
    window = (rows - cols) < window_size        # within W positions

    # Combine: must be both causal AND within window
    mask = causal & window
    return mask  # True = attend, False = mask out

# Example: 8 tokens, window size 3
mask = sliding_window_mask(8, 3)
# Token 0 sees: [0]
# Token 1 sees: [0, 1]
# Token 2 sees: [0, 1, 2]
# Token 3 sees: [1, 2, 3]     <- window kicks in
# Token 7 sees: [5, 6, 7]
Misconception: sliding window limits what the model can "see." Not true. Information propagates through layers. Each layer extends the effective reach by W positions. A 32-layer model with W = 4,096 has an effective receptive field of 131,072 tokens — far beyond any typical context window. The window constrains direct attention, not information flow.
Who uses sliding window? Mistral 7B (W=4096, 32 layers) and Gemma 2 (alternates sliding window and full attention layers). Many long-context models use hybrid schemes: sliding window for most layers, full attention for a few "global" layers.
If a model has 24 layers with a sliding window of size 2,048, what is the effective receptive field?

Chapter 7: Linear Attention

Every variant we've seen so far still computes the n×n attention matrix — they just make it sparser (sliding window) or share the KV cache (MQA, GQA). The matrix itself, that softmax(QKT/√d) grid of attention weights, is still there. It's still quadratic.

What if we could avoid the n×n matrix entirely? Not approximate it with sparsity, not share components around it, but mathematically side-step it altogether?

That's what Linear Attention does. The key insight is startlingly simple: once softmax is replaced by a kernel feature map, changing the order of matrix multiplication gives the same output without ever forming the n×n intermediate matrix. The cost drops from O(n²d) to O(nd²), and when d is much smaller than n (which holds for long sequences), that's a massive win.

The Associativity Trick

Let's start with standard attention. For one head, given queries Q (n×d), keys K (n×d), and values V (n×d):

Standard: Output = softmax(QKᵀ / √d) · V

The problem is the QKᵀ term: (n×d) · (d×n) = n×n matrix. This requires O(n²d) operations and O(n²) memory.

Now, softmax is what forces us to compute QKT first — we need every entry of that matrix before we can normalize rows. But what if we replaced softmax with a simpler function that doesn't require the full matrix?

Replace softmax with a kernel function φ applied element-wise to Q and K:

Linear: Output = φ(Q) · (φ(K)ᵀ · V)

Read that carefully. The parentheses changed. Instead of computing (φ(Q) · φ(K)ᵀ) · V, we compute φ(Q) · (φ(K)ᵀ · V). Same letters, different grouping. This is the associativity of matrix multiplication.

Why does the grouping matter? Let's trace the shapes:

Operation | Standard Order | Linear Order
Step 1 | φ(Q)φ(K)ᵀ: (n×d)·(d×n) = n×n | φ(K)ᵀV: (d×n)·(n×d) = d×d
Step 2 | (n×n)·V: (n×n)·(n×d) = n×d | φ(Q)·(d×d): (n×d)·(d×d) = n×d
Intermediate | n×n (grows with sequence!) | d×d (fixed!)
FLOPs | O(n²d) | O(nd²)

The intermediate matrix went from n×n (grows quadratically with sequence length) to d×d (fixed at dhead², regardless of sequence length). This is the core of linear attention.

Hand Calculation: Standard vs Linear Order

Let's work through a tiny example. n = 4 tokens, d = 3 dimensions. We'll skip the kernel φ (assume identity) and just compare the two multiplication orders.

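Standard order: (QKᵀ)V

Step 1: QKᵀ. Multiply (4×3) by (3×4) to get a (4×4) matrix: 16 entries × 3 multiplications each = 48 operations.

Step 2: (4×4) · V. Multiply (4×4) by (4×3) to get the (4×3) output: 12 entries × 4 multiplications each = 48 operations.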

Grand total: 48 + 48 = 96 operations. Intermediate matrix: 4×4 = 16 elements.

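Linear order: Q(KᵀV)

Step 1: KᵀV. Multiply (3×4) by (4×3) to get a (3×3) matrix: 9 entries × 4 multiplications each = 36 operations.

Step 2: Q · (3×3). Multiply (4×3) by (3×3) to get the (4×3) output: 12 entries × 3 multiplications each = 36 operations.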

Grand total: 36 + 36 = 72 operations. Intermediate matrix: 3×3 = 9 elements.
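The hand calculation can be verified directly. The sketch below multiplies small integer matrices in both orders, counts scalar multiplications, and confirms that the outputs match. The matrix values are arbitrary examples, not taken from the lesson.

```python
def matmul(a, b):
    """Multiply two matrices given as lists of rows; also return the multiplication count."""
    rows, inner, cols = len(a), len(b), len(b[0])
    out = [[sum(a[i][k] * b[k][j] for k in range(inner)) for j in range(cols)]
           for i in range(rows)]
    return out, rows * inner * cols

def transpose(m):
    return [list(col) for col in zip(*m)]

# n = 4 tokens, d = 3 dimensions (small integers keep the arithmetic exact)
Q = [[1, 0, 2], [0, 1, 1], [2, 1, 0], [1, 1, 1]]
K = [[1, 2, 0], [0, 1, 1], [1, 0, 1], [2, 1, 1]]
V = [[1, 0, 1], [0, 2, 1], [1, 1, 0], [3, 0, 2]]

scores, ops1 = matmul(Q, transpose(K))   # (4x3)(3x4) -> 4x4 intermediate
standard, ops2 = matmul(scores, V)       # (4x4)(4x3) -> 4x3 output

summary, ops3 = matmul(transpose(K), V)  # (3x4)(4x3) -> 3x3 intermediate
linear, ops4 = matmul(Q, summary)        # (4x3)(3x3) -> 4x3 output

print(standard == linear)                # True: same output either way
print(ops1 + ops2, ops3 + ops4)          # 96 72
```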

At n = 4, the saving is modest: 96 vs 72 operations. But watch what happens as n grows. The standard approach scales as n² while the linear approach scales as n:

n | d | Standard (2n²d) | Linear (2nd²) | Speedup
4 | 3 | 96 | 72 | 1.3×
64 | 3 | 24,576 | 1,152 | 21×
1,024 | 64 | 134M | 8.4M | 16×
16,384 | 128 | 68.7B | 537M | 128×

At n = 16,384 with d = 128 (typical for modern LLMs), linear attention needs 128× fewer operations than standard attention. The crossover where linear becomes cheaper is roughly n ≈ d, so for any sequence much longer than the head dimension, linear wins on operation count.

Standard vs Linear Attention: Computation Comparison

Left: standard attention materializes the n×n matrix. Right: linear attention computes a d×d summary. Drag the sequence length slider and watch the intermediate matrix grow (or not).


The Kernel Function φ

We can't just remove softmax and call it a day. Softmax ensures attention weights are positive and sum to 1 — it's a proper probability distribution. If we want the associativity trick to work, we need a kernel function φ that preserves these properties (or close enough).

The original Linear Transformer (Katharopoulos et al., 2020) used a simple choice:

φ(x) = elu(x) + 1

Where elu is the exponential linear unit: elu(x) = x if x > 0, e^x − 1 if x ≤ 0. Adding 1 ensures all values are positive (matching softmax's positivity). The output is then normalized by dividing by the sum of φ(Q)φ(K)ᵀ along the key dimension.

Other kernels have been explored: random Fourier features (FAVOR+ in Performer), polynomial kernels, learned kernels. The choice of φ determines how well linear attention approximates the original softmax attention pattern.

The Catch

If linear attention is so much faster, why isn't everyone using it? The answer is quality. Softmax attention creates sharp, peaked distributions — it can put 90% of the attention weight on a single token when that token is the relevant one. This sharp retrieval is what makes attention so powerful for tasks like question answering, code completion, and factual recall.

The kernel approximation φ produces smoother distributions. Instead of a sharp peak on one token, it spreads weight more broadly. For most generation tasks (continuing text, summarizing, chat), this barely matters. But for "needle in a haystack" retrieval — finding a specific fact in a long context — linear attention is measurably weaker.

This is an active area of research. Models like RWKV, RetNet, and Mamba (which uses a related selective state space mechanism) are closing the gap. Each generation gets closer to softmax quality while keeping linear or near-linear scaling. The bet in the research community is that subquadratic attention will eventually match softmax — but we're not there yet for all tasks.

Code: Linear Attention in 15 Lines

python
import torch
import torch.nn.functional as F

def linear_attention(Q, K, V):
    # Q, K, V: [batch, heads, seq_len, d_head]

    # Apply kernel: phi(x) = elu(x) + 1 (ensures positivity)
    Q = F.elu(Q) + 1   # [B, H, n, d]
    K = F.elu(K) + 1   # [B, H, n, d]

    # THE TRICK: compute K^T V first (d×d, not n×n!)
    KV = K.transpose(-2, -1) @ V    # [B, H, d, d]

    # Then multiply Q by the d×d summary
    out = Q @ KV                        # [B, H, n, d]

    # Normalize: row i is divided by phi(q_i) . sum_j phi(k_j), which plays
    # the role of the softmax denominator. Summing K first keeps this O(n*d).
    k_sum = K.sum(dim=-2, keepdim=True)             # [B, H, 1, d]
    Z = (Q * k_sum).sum(dim=-1, keepdim=True)       # [B, H, n, 1]

    return out / (Z + 1e-6)  # avoid division by zero

The entire trick is on one line: KV = K.transpose(-2, -1) @ V. This computes the d×d summary matrix. Then Q @ KV produces the output without ever materializing an n×n matrix. The normalization term Z ensures the output is properly scaled.

The state-space connection. Linear attention has a beautiful dual interpretation: it's equivalent to a linear recurrence (state-space model). The d×d matrix KV is the "state" that accumulates information from all tokens seen so far. Each new token updates this state and reads from it. This is why Mamba, RWKV, and RetNet all feel related — they're all variants of maintaining a fixed-size summary of the sequence instead of storing every token's KV pair. The convergence of attention and state-space models is one of the most exciting trends in modern architecture research.
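The recurrent view can be made concrete. The sketch below assumes the causal setting (token i only sees tokens up to i), unlike the non-causal function above. It carries a d×d state S and a length-d normalizer z, then checks the result against an explicit quadratic reference. All names and the toy inputs are illustrative.

```python
import math

def phi(x):
    """Feature map elu(x) + 1, applied element-wise; always positive."""
    return [v + 1.0 if v > 0 else math.exp(v) for v in x]

def causal_linear_attention_recurrent(Q, K, V):
    """O(n d^2) recurrence: carry a d x d state S and a length-d normalizer z."""
    d, d_v = len(Q[0]), len(V[0])
    S = [[0.0] * d_v for _ in range(d)]  # running sum of phi(k_j) v_j^T
    z = [0.0] * d                        # running sum of phi(k_j)
    outputs = []
    for q, k, v in zip(Q, K, V):
        fq, fk = phi(q), phi(k)
        for a in range(d):               # fold the new token into the state
            z[a] += fk[a]
            for b in range(d_v):
                S[a][b] += fk[a] * v[b]
        denom = sum(fq[a] * z[a] for a in range(d))
        outputs.append([sum(fq[a] * S[a][b] for a in range(d)) / denom
                        for b in range(d_v)])
    return outputs

def causal_linear_attention_quadratic(Q, K, V):
    """Reference O(n^2 d) version: explicit weights phi(q_i) . phi(k_j) for j <= i."""
    outputs = []
    for i, q in enumerate(Q):
        fq = phi(q)
        w = [sum(a * b for a, b in zip(fq, phi(K[j]))) for j in range(i + 1)]
        total = sum(w)
        outputs.append([sum(w[j] * V[j][b] for j in range(i + 1)) / total
                        for b in range(len(V[0]))])
    return outputs

Q = [[0.5, -1.0], [1.5, 0.2], [-0.3, 0.8]]
K = [[1.0, 0.0], [-0.5, 0.5], [0.3, -1.2]]
V = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]

fast = causal_linear_attention_recurrent(Q, K, V)
slow = causal_linear_attention_quadratic(Q, K, V)
max_diff = max(abs(a - b) for r1, r2 in zip(fast, slow) for a, b in zip(r1, r2))
print(max_diff < 1e-9)  # True: the two forms agree up to rounding
```

The state `S` and normalizer `z` have fixed size regardless of how many tokens have been processed, which is what makes generation cost constant per token.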
Misconception: linear attention is "attention without softmax." Simply removing softmax without changing the multiplication order still gives O(n²). The key insight is reordering the matrix multiplication from (QKᵀ)V to Q(KᵀV). It's the associativity trick that makes it linear, not the removal of softmax. You need a kernel φ to replace softmax, but the kernel alone doesn't help without the reordering.
In linear attention, what is the size of the intermediate matrix that replaces the n×n attention matrix?

Chapter 8: FlashAttention

Every variant we've seen so far changes the attention algorithm — fewer heads, local windows, kernel tricks. FlashAttention takes a completely different approach: it computes exact standard attention, but reorganizes the computation to be friendly to the GPU's memory hierarchy. Same math, radically faster.

The insight is almost embarrassingly simple once you see it. Modern GPUs have two kinds of memory: a large, slow main memory (HBM — High Bandwidth Memory), and a tiny, blazing-fast on-chip cache (SRAM). Standard attention constantly shuffles enormous matrices between these two levels. FlashAttention restructures the work so that almost everything happens in the fast cache, and the huge intermediate matrix never gets written at all.

The GPU Memory Hierarchy

Think of HBM as a warehouse and SRAM as your workbench. The warehouse is huge (80 GB on an H100) but it takes time to haul things back and forth. Your workbench is tiny (about 50 MB) but everything on it is instantly accessible. Standard attention is like building furniture by carrying every plank to the warehouse after cutting it, then hauling it all back to assemble. FlashAttention keeps the planks on your workbench and only moves the finished piece once.

Concrete numbers for an NVIDIA H100:

Level | Size | Bandwidth | Relative speed
HBM (main memory) | 80 GB | 3.35 TB/s | 1× (baseline)
SRAM (on-chip cache) | ~50 MB | ~50 TB/s | ~15× faster
Registers | ~20 MB total | ∞ (no transfer) | Instant

The bottleneck in modern GPUs is rarely compute — it's memory bandwidth. Moving data to and from HBM is what makes attention slow. FlashAttention is an IO-aware algorithm: it minimizes the number of HBM reads and writes by doing as much work as possible while data sits in fast SRAM.

Standard Attention: The HBM Nightmare

Let's trace what standard attention does for a single head. Inputs: Q, K, V ∈ ℝ^(n×d), all sitting in HBM.

Step 1: Read Q, K from HBM. Transfer 2 × n × d values to the compute units.
Step 2: Compute S = QKᵀ. Produces an n × n matrix (huge!).
Step 3: Write S to HBM. n² values go to HBM (the bottleneck).
Step 4: Read S from HBM, compute P = softmax(S). Read n² values, write n² values back.
Step 5: Read P, V from HBM. Read n² + n × d values.
Step 6: Compute O = PV, write to HBM. Final output: n × d values.

Count the HBM traffic: that n × n matrix crosses the HBM bus four times (write S, read S for softmax, write P, read P for the multiply). For n = 8192 and d = 128 in FP16 (2 bytes per value):

S matrix = n² = 8192² = 67,108,864 entries
Memory = 67M × 2 bytes = 128 MB per head, per layer

With 32 heads and 80 layers, that's 32 × 80 × 128 MB = 320 GB of attention scores in a single forward pass, and each byte is moved through HBM about four times. HBM bandwidth is the bottleneck, so this is where the time goes.
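The arithmetic above is easy to re-run for other sequence lengths. This sketch assumes FP16 scores and the four n × n transfers listed in the step-by-step trace (write S, read S, write P, read P); the function name is illustrative.

```python
def score_matrix_bytes(n, bytes_per_value=2):
    """Bytes needed to hold one n x n attention score matrix."""
    return n * n * bytes_per_value

MiB, GiB = 1024 ** 2, 1024 ** 3
per_head = score_matrix_bytes(8192)
whole_model = per_head * 32 * 80   # 32 heads, 80 layers
hbm_passes = 4                     # write S, read S, write P, read P

print(per_head // MiB)                    # 128 (MiB per head, per layer)
print(whole_model // GiB)                 # 320 (GiB of scores per forward pass)
print(whole_model * hbm_passes // GiB)    # 1280 (GiB moved through HBM)
```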

FlashAttention: The Tiling Strategy

FlashAttention's core idea: never materialize the full n × n matrix. Instead, split Q into row blocks of size B_r and K, V into column blocks of size B_c. Load one Q-block and one K,V-block into SRAM at a time. Compute that tile of the attention matrix, immediately use it, and accumulate the result. The tile lives only in SRAM and never touches HBM.

The tiling analogy. Imagine computing a huge multiplication table (100 × 100). Standard attention writes the entire table on a piece of paper (HBM), then reads it back row by row. FlashAttention computes one small block at a time (say 10 × 10), uses those results immediately, and never writes the table at all. At the end, you have the same final answer — you just never needed the intermediate table.

Hand Calculation: How Many Tiles?

SRAM must hold: one Q-block (B_r × d), one K-block (B_c × d), one V-block (B_c × d), and the partial output block (B_r × d). That's approximately 4 × B × d entries (if B_r = B_c = B).

For an A100 with 192 KB of SRAM per SM, d = 128, FP16 (2 bytes):

SRAM capacity ≈ 192,000 bytes
Entries that fit = 192,000 / 2 = 96,000
Need 4 × B × d entries ≤ 96,000
B ≤ 96,000 / (4 × 128) = 187 tokens per tile

For a sequence of n = 8192 tokens:

Number of Q-blocks = ⌈8192 / 187⌉ = 44 blocks
Number of K,V-blocks = 44 blocks
Total tiles = 44 × 44 = 1,936 tiles

Each tile computes a 187 × 187 sub-matrix = 34,969 entries. That tiny block lives entirely in SRAM. Standard attention would need all 67 million entries in HBM simultaneously. FlashAttention needs only ~35,000 at any moment — a 1,920× reduction in peak memory.
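The tile-size estimate can be expressed as a small helper. It follows the simplified budget used above (four B × d buffers and nothing else), so treat the result as an upper bound; real kernels typically choose smaller power-of-two block sizes and also need room for the score tile. The function name is illustrative.

```python
import math

def max_block_size(sram_bytes, d_head, bytes_per_value=2, n_buffers=4):
    """Largest B such that n_buffers blocks of shape (B, d_head) fit in SRAM."""
    return sram_bytes // (bytes_per_value * n_buffers * d_head)

B = max_block_size(192_000, 128)
n_blocks = math.ceil(8192 / B)
print(B, n_blocks, n_blocks ** 2, B * B)  # 187 44 1936 34969
```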

The Online Softmax Trick

There's a catch. Softmax requires the maximum over all scores in a row for numerical stability. The formula is:

softmax(x)_i = exp(x_i − max(x)) / ∑_j exp(x_j − max(x))

This seems to require seeing ALL n scores before computing any output. If we're processing one tile at a time, we don't know the global max yet. How can we compute softmax?

The answer is the online softmax algorithm (Milakov & Gimelshein, 2018). It processes one block at a time, maintaining a running maximum and a running denominator. When a new tile arrives with a larger maximum, it rescales all previous results:

python
# Online softmax: the trick that makes FlashAttention possible.
# Processes one Q-block against the K,V tiles, one tile at a time.
# (The 1/sqrt(d) scale and the causal mask are omitted for brevity.)
import torch

def attention_for_q_block(Q_block, K, V, B_c):
    # Q_block: (B_r, d); K, V: (n, d); B_c: number of K,V rows per tile
    B_r, d = Q_block.shape
    m_prev = torch.full((B_r, 1), float("-inf"))  # running row maximum
    l_prev = torch.zeros(B_r, 1)                  # running sum of exp(scores - max)
    O_prev = torch.zeros(B_r, d)                  # running output accumulator

    for start in range(0, K.shape[0], B_c):
        # Load one K,V block (HBM -> SRAM in a real kernel)
        K_j = K[start : start + B_c]       # shape: (B_c, d)
        V_j = V[start : start + B_c]       # shape: (B_c, d)

        # Compute local attention scores for this tile
        S_ij = Q_block @ K_j.T             # shape: (B_r, B_c)

        # Update running max (one value per query row)
        m_new = torch.maximum(m_prev, S_ij.max(dim=-1, keepdim=True).values)

        # Rescale previous accumulations to new max
        correction = torch.exp(m_prev - m_new)
        P_ij = torch.exp(S_ij - m_new)     # local, unnormalized attention weights
        l_new = correction * l_prev + P_ij.sum(dim=-1, keepdim=True)

        # Update output: rescale old + add new contribution
        O_new = correction * O_prev + P_ij @ V_j

        m_prev, l_prev, O_prev = m_new, l_new, O_new

    # Final output: normalize by total denominator
    return O_prev / l_prev                 # shape: (B_r, d)

The key insight: the correction = exp(m_prev - m_new) term rescales all prior work when the max changes. If the new tile has a larger maximum, previous exponentials were computed against a maximum that was too small, so they are too large; multiply them down. The result is mathematically identical to standard softmax and differs only by floating-point rounding. No approximation whatsoever.
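The rescaling idea is easiest to see on a single row of scores. The sketch below streams over a list in blocks, keeps only a running maximum and a running sum, and then compares the result with an ordinary two-pass softmax. The function names and example scores are illustrative.

```python
import math

def softmax_two_pass(xs):
    """Reference softmax: needs the full row before producing any output."""
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

def online_softmax_stats(xs, block_size):
    """Stream over xs in blocks, keeping only a running max m and running sum l."""
    m, l = float("-inf"), 0.0
    for start in range(0, len(xs), block_size):
        block = xs[start:start + block_size]
        m_new = max(m, max(block))
        # Rescale the old sum to the new max, then add this block's terms
        l = l * math.exp(m - m_new) + sum(math.exp(x - m_new) for x in block)
        m = m_new
    return m, l

scores = [0.5, 2.0, -1.0, 7.5, 3.0, 7.0, -4.0, 1.0]
m, l = online_softmax_stats(scores, block_size=3)
online = [math.exp(x - m) / l for x in scores]
reference = softmax_two_pass(scores)
print(max(abs(a - b) for a, b in zip(online, reference)) < 1e-12)  # True
```

Changing `block_size` changes the order of the additions but not the result, up to rounding.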

FlashAttention is NOT an approximation. This is the most common misconception. It computes exact softmax attention — identical output to the standard algorithm, down to floating-point precision. The speedup comes purely from being smarter about memory access patterns. It's an IO-aware algorithm, not a mathematical shortcut. Every variant we saw earlier (MQA, GQA, sliding window, linear) changes what attention computes. FlashAttention changes how.

The Full Algorithm

Putting it all together:

1. Divide Q into blocks of B_r rows. Each block is processed independently (outer loop).
2. For each Q-block, iterate over the K,V blocks. Load one K,V block into SRAM at a time (inner loop).
3. Compute the tile S_ij = Q_i K_jᵀ. The small tile lives entirely in SRAM.
4. Online softmax: update the max, rescale, accumulate. The full attention matrix is never stored.
5. Write the final output block to HBM. One write per Q-block (n/B_r writes in total).

HBM traffic for FlashAttention: read Q once (n × d), read K,V once per Q-block pass (n × d each, but reused), write O once (n × d). The n × n attention matrix? Never written to HBM. Never read from HBM. It exists only briefly in SRAM, one tile at a time.
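A rough element count makes the difference tangible. The sketch below follows the two step lists in this chapter and assumes K and V are re-read once per Q-block; it ignores caching effects and the backward pass, so read it as an order-of-magnitude estimate. The function names are illustrative.

```python
import math

def standard_hbm_elements(n, d):
    """Elements moved for one head: read Q,K; write S; read S, write P; read P,V; write O."""
    return 2 * n * d + 4 * n * n + n * d + n * d

def flash_hbm_elements(n, d, block_rows):
    """Elements moved for one head when K and V are streamed once per Q-block."""
    n_q_blocks = math.ceil(n / block_rows)
    return n * d + n_q_blocks * 2 * n * d + n * d  # read Q, stream K/V, write O

n, d, B_r = 8192, 128, 187
std = standard_hbm_elements(n, d)
fa = flash_hbm_elements(n, d, B_r)
print(std, fa, round(std / fa, 1))
```

Larger blocks mean fewer passes over K and V, which is why the amount of SRAM available per SM directly affects how much traffic FlashAttention saves.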

The Simulation

Watch the two approaches side by side. Standard attention bounces data back and forth to HBM constantly. FlashAttention loads tiles into SRAM and keeps them there. Click Step to advance through each stage of the computation.

GPU Memory Hierarchy — Standard vs FlashAttention

Left: standard attention writes the full n×n matrix to HBM. Right: FlashAttention tiles the computation through SRAM. Red arrows = slow HBM transfers. Green arrows = fast SRAM operations.


Performance Numbers

Here are benchmarks from the FlashAttention-2 paper (Dao, 2023) on an A100-80GB GPU, with head dimension d = 128, 12 heads:

Sequence length | Standard (ms) | FlashAttention-2 (ms) | Speedup | Memory savings
512 | 0.8 | 0.5 | 1.6× |
1024 | 2.1 | 0.9 | 2.3× | 16×
2048 | 7.6 | 2.4 | 3.2× | 64×
4096 | 29 | 7.8 | 3.7× | 256×
8192 | 115 | 24 | 4.8× | 1024×
16384 | OOM | 85 | |

The pattern: the longer the sequence, the bigger the win. At n = 16384, standard attention runs out of memory entirely while FlashAttention still fits comfortably. This is why nearly every modern LLM stack uses FlashAttention or an equivalent fused kernel: it's free performance with zero quality loss.

FlashAttention composes with all other variants. You can use FlashAttention with MHA, MQA, GQA, or sliding window attention. It doesn't change what you compute — it speeds up any dot-product attention computation. Llama 3 uses GQA + FlashAttention-2. Mistral uses sliding window + FlashAttention. They're orthogonal improvements that stack.

The Evolution

Version | Year | Key improvement
FlashAttention | 2022 | Tiled computation, online softmax. 2-4× speedup.
FlashAttention-2 | 2023 | Better parallelism across GPU warps. 1.7-2× over FA1.
FlashAttention-3 | 2024 | Exploits H100 Tensor Cores, FP8 support. 1.5-2× over FA2.

FlashAttention-3 on H100 can reach 740 TFLOPS for attention in FP16, about 75% of the theoretical peak. FlashAttention-2 reaches only about 35% utilization on the same hardware. Same math, same answer, just dramatically better use of the hardware.

What does FlashAttention avoid materializing in GPU HBM?

Chapter 9: The Attention Arena

You've learned seven different approaches to attention — from full multi-head attention down to linear approximations, from multi-query compression to FlashAttention's memory-hierarchy trick. But which one should you actually use? It depends on your sequence length, model dimension, and quality requirements.

Let's race them.

This arena simulates the throughput, memory usage, and quality tradeoffs of every attention variant at once. Adjust the sliders to match your deployment scenario: short prompts or long context? Small model or frontier-scale? Then hit Race! and watch the variants compete.

How to use the Arena. Set your deployment scenario with the sliders, then click Race! Bars animate from left to right — longer bars mean higher throughput (faster). Below the race, you'll see memory usage and quality ratings. Try these experiments: (1) Set seq length to 16384 — watch MHA collapse while linear attention dominates. (2) Set seq length to 512 — all variants are similar, Flash-MHA wins on quality. (3) Compare GQA groups: 4 groups vs 8 groups — see the memory-quality tradeoff. (4) Crank model dim to 4096 — watch memory diverge dramatically.
Attention Racing Arena

Seven attention variants race on throughput, memory, and quality. Adjust parameters and hit Race! to see who wins for your deployment scenario.


What the Arena Reveals

At short sequences (512–1024): All variants perform similarly. The n² cost is small, so clever tricks don't save much. Flash-MHA gives the best quality at competitive speed. This is why BERT-era models used full MHA — sequences were short enough that it didn't matter.

At medium sequences (2048–4096): The pack starts to separate. GQA and MQA pull ahead on throughput because they have smaller KV caches (less memory traffic). FlashAttention gives MHA a lifeline. Sliding window is great if your task is local.

At long sequences (8192+): The gap is dramatic. Standard MHA without FlashAttention may OOM entirely. Linear attention dominates raw throughput with its O(n) scaling. Sliding window excels if the task doesn't need global context. GQA + FlashAttention is the modern sweet spot — near-MHA quality with 4–8× less KV memory.

On quality: MHA and FlashAttention are lossless (they compute the same thing). GQA-8 is near-lossless (<0.1% degradation in most benchmarks). MQA has slight degradation (0.5–1%). Linear attention has noticeable degradation (1–3%) on tasks requiring precise long-range recall. Sliding window has zero loss for local patterns but misses distant information.
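The "lossless" claim for FlashAttention can be checked numerically. The sketch below is not the real kernel: it is a pure-Python illustration, with hypothetical function names, of the online-softmax idea (visit keys and values block by block while keeping a running max, a running denominator, and a running weighted sum). It assumes non-causal attention and a tiny toy problem.

```python
import math
import random

def naive_attention(Q, K, V):
    """Standard attention: build each full score row, softmax it, then weight V."""
    d = len(Q[0])
    out = []
    for q in Q:
        scores = [sum(a * b for a, b in zip(q, k)) / math.sqrt(d) for k in K]
        m = max(scores)
        weights = [math.exp(s - m) for s in scores]
        total = sum(weights)
        out.append([sum(w * v[j] for w, v in zip(weights, V)) / total
                    for j in range(len(V[0]))])
    return out

def tiled_attention(Q, K, V, block=4):
    """Same result, but K/V are visited in blocks, so no full score row is stored."""
    d = len(Q[0])
    dv = len(V[0])
    out = []
    for q in Q:
        run_max = -math.inf   # largest score seen so far
        run_sum = 0.0         # softmax denominator, relative to run_max
        acc = [0.0] * dv      # unnormalized weighted sum of values
        for start in range(0, len(K), block):
            k_blk = K[start:start + block]
            v_blk = V[start:start + block]
            scores = [sum(a * b for a, b in zip(q, k)) / math.sqrt(d) for k in k_blk]
            new_max = max(run_max, max(scores))
            scale = math.exp(run_max - new_max)  # rescale earlier partial results
            run_sum *= scale
            acc = [a * scale for a in acc]
            for s, v in zip(scores, v_blk):
                w = math.exp(s - new_max)
                run_sum += w
                for j in range(dv):
                    acc[j] += w * v[j]
            run_max = new_max
        out.append([a / run_sum for a in acc])
    return out

random.seed(0)
n, d = 10, 8  # toy sizes, chosen only for the demonstration
Q = [[random.gauss(0, 1) for _ in range(d)] for _ in range(n)]
K = [[random.gauss(0, 1) for _ in range(d)] for _ in range(n)]
V = [[random.gauss(0, 1) for _ in range(d)] for _ in range(n)]

ref = naive_attention(Q, K, V)
tiled = tiled_attention(Q, K, V, block=4)
max_diff = max(abs(a - b) for ra, rb in zip(ref, tiled) for a, b in zip(ra, rb))
print("max abs difference:", max_diff)
```

The two outputs agree up to floating-point rounding, which is why FlashAttention counts as exact rather than approximate.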

The modern default: GQA + FlashAttention. Most open-weight LLMs released since 2023 use this combination, including Llama 2 70B, Llama 3, Mistral 7B, Gemma 2, Command R+, and Qwen2. (The smaller Llama 2 models and the first Gemma 7B still used full MHA.) It's not the fastest possible (linear attention wins on raw throughput) and not the highest quality possible (full MHA is marginally better). But it's the best tradeoff: near-lossless quality with excellent throughput and manageable memory. Engineering is about tradeoffs.

Throughput Scaling Behavior

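| Variant | n = 512 | n = 4096 | n = 16384 | Scaling |
|---|---|---|---|---|
| MHA | ~1.0× | 0.3× | OOM | O(n²) |
| MQA | ~1.1× | 0.8× | 0.15× | O(n²), less KV IO |
| GQA-4 | ~1.05× | 0.65× | 0.12× | O(n²), less KV IO |
| GQA-8 | ~1.05× | 0.7× | 0.13× | O(n²), less KV IO |
| Sliding-4K | ~1.0× | ~0.9× | 0.75× | O(n · W) |
| Linear | ~0.95× | ~1.1× | 1.2× | O(n · d²) |
| Flash-MHA | ~1.2× | 0.85× | 0.25× | O(n²), less IO |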

(Throughput relative to MHA at n = 512. "OOM" means out of memory on a single A100-80GB.)
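To see what the Scaling column implies, here is a small sketch of how total attention work grows relative to n = 512 for each scaling class. It does not reproduce the simulated throughput numbers above; it only counts asymptotic work, and the function name and the default window of 4096 are assumptions made for illustration.

```python
def relative_cost(n, base_n=512, window=4096):
    """Attention work at length n divided by work at base_n, per scaling class.

    Constant factors (d, d^2, number of heads) cancel in the ratio.
    """
    return {
        "quadratic O(n^2)": (n * n) / (base_n * base_n),
        "sliding O(n*W)": (n * min(n, window)) / (base_n * min(base_n, window)),
        "linear O(n)": n / base_n,
    }

for n in (512, 4096, 16384):
    print(n, relative_cost(n))
```

Going from 512 to 16,384 tokens is a 32× longer sequence. Quadratic attention does 1,024× the work, a 4K sliding window does 256×, and linear attention does 32×.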

Memory Comparison

KV cache memory per token per layer (dmodel = 4096, dhead = 128, H = 32, FP16):

| Variant | Formula | Bytes per token per layer | At n = 8192 |
|---|---|---|---|
| MHA | 2 × H × dhead × 2 | 16,384 | 128 MB |
| GQA-8 | 2 × 8 × dhead × 2 | 4,096 | 32 MB |
| GQA-4 | 2 × 4 × dhead × 2 | 2,048 | 16 MB |
| MQA | 2 × 1 × dhead × 2 | 512 | 4 MB |
| Sliding-W | 2 × H × dhead × 2 (capped at W) | 16,384 (max 4K window) | 64 MB (capped) |
| Linear | 2 × H × dhead² × 2 (fixed) | 2,097,152 total (fixed state, not per token) | 2 MB (fixed) |
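The softmax-based rows of this table can be reproduced with a few lines of arithmetic. This is a minimal sketch using the stated configuration (dhead = 128, H = 32, FP16); the function names are hypothetical, and "MB" here means 2^20 bytes, matching the table.

```python
def kv_bytes_per_token(kv_heads, d_head=128, bytes_per_elem=2):
    # 2 tensors (K and V) x kv_heads x d_head elements x bytes per element
    return 2 * kv_heads * d_head * bytes_per_elem

def kv_cache_mb(kv_heads, n, d_head=128, bytes_per_elem=2, window=None):
    # A sliding window caps how many past tokens are kept in the cache.
    cached_tokens = n if window is None else min(n, window)
    return kv_bytes_per_token(kv_heads, d_head, bytes_per_elem) * cached_tokens / 2**20

H = 32
variants = {"MHA": H, "GQA-8": 8, "GQA-4": 4, "MQA": 1}
for name, kv_heads in variants.items():
    print(name, kv_bytes_per_token(kv_heads), "bytes/token,",
          kv_cache_mb(kv_heads, n=8192), "MB at n=8192")

print("Sliding-W (W=4096):", kv_cache_mb(H, n=8192, window=4096), "MB at n=8192")
```

These figures are per layer; multiply by the number of layers for the whole model.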

Chapter 10: Cheat Sheet & Connections

You've mastered the complete attention variant toolkit — from standard multi-head attention through all its descendants, down to the hardware tricks that make it fast. This chapter is your reference card. No new concepts. Just the decision guide, the comparison table, and where to go next.

The Decision Flowchart

Follow the path that matches your situation:

What's your primary constraint?
Every deployment has a bottleneck. Identify yours first.
Need exact attention? No quality compromise?
FlashAttention with MHA or GQA. Same math, faster execution.
Memory limited, some quality flexibility?
How much can you trade? Slight loss → GQA-8. More savings → MQA.
Very long context (>32K tokens)?
Sliding window + FlashAttention for local tasks. GQA + FlashAttention for global.
Real-time / edge deployment?
Linear attention (O(n) scaling) or MQA (minimal KV cache).
Serving many concurrent users?
GQA + FlashAttention (Llama 3 recipe). Best throughput-per-dollar at scale.
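The flowchart above can be written down as a small lookup function. This is only a sketch of the decision guide as stated; the function name, the constraint labels, and the parameters are hypothetical and exist purely for illustration.

```python
def recommend_attention(constraint, local_task=False, quality_loss="slight"):
    """Map a primary deployment constraint to the variant the flowchart suggests."""
    if constraint == "exact":
        # No quality compromise allowed: same math, faster execution.
        return "FlashAttention with MHA or GQA"
    if constraint == "memory":
        # Trade a little quality for KV memory, or more quality for more savings.
        return "GQA-8" if quality_loss == "slight" else "MQA"
    if constraint == "long_context":
        # Beyond ~32K tokens the choice depends on whether the task is local.
        if local_task:
            return "Sliding window + FlashAttention"
        return "GQA + FlashAttention"
    if constraint == "edge":
        return "Linear attention or MQA"
    if constraint == "many_users":
        return "GQA + FlashAttention"
    raise ValueError(f"unknown constraint: {constraint!r}")

print(recommend_attention("memory"))
print(recommend_attention("long_context", local_task=True))
```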

Comprehensive Comparison

| Variant | KV Memory | Compute | Quality vs MHA | Used By |
|---|---|---|---|---|
| MHA | O(H · n · dhead) | O(n² · d) | Baseline | GPT-2/3, BERT, T5 |
| MQA | O(n · dhead) | O(n² · d) | −0.5 to −1% | PaLM, Falcon, StarCoder |
| GQA-G | O(G · n · dhead) | O(n² · d) | −0.1 to −0.5% | Llama 2 70B, Llama 3, Mistral, Gemma 2 |
| Sliding-W | O(H · W · dhead) | O(n · W · d) | ≈ same (local) | Mistral, Longformer, BigBird |
| Linear | O(H · dhead²) | O(n · d²) | −1 to −3% | RWKV, RetNet, Linear Transformer |
| FlashAttn | O(n) scratch | O(n² · d) | Exact | Everything modern |

Symbol Reference

| Symbol | Meaning | Typical values |
|---|---|---|
| n | Sequence length | 512–128K+ |
| d or dmodel | Model dimension | 512–8192 |
| dhead | Per-head dimension (d / H) | 64–128 |
| H | Number of query heads | 8–128 |
| G | Number of KV groups (GQA) | 1 (MQA) to H (MHA) |
| W | Window size (sliding window) | 512–4096 |
| Br, Bc | FlashAttention tile sizes | 64–256 (hardware-dependent) |
| Q, K, V | Query, Key, Value matrices | ∈ ℝ^(n × dhead) |
| S | Attention score matrix (QKᵀ) | ∈ ℝ^(n × n) |
| O | Output matrix | ∈ ℝ^(n × dhead) |
| HBM | High Bandwidth Memory (GPU main memory) | 40–80 GB |
| SRAM | On-chip static RAM (GPU cache) | ~20–50 MB |

Summary of Everything

Ch 0: Why O(n²)?
Every token attends to every other. Cost explodes with sequence length.
Ch 1: Dot-Product Attention
Q · Kᵀ → softmax → V. The foundation everything builds on.
Ch 2: Multi-Head Attention
H parallel heads, each attending to different subspaces.
Ch 3: The KV Cache Bottleneck
Autoregressive inference stores all past K,V. Memory explodes.
Ch 4: Multi-Query Attention
One shared K,V for all heads. H× less KV memory.
Ch 5: Grouped-Query Attention
G groups of shared K,V. The Goldilocks tradeoff.
Ch 6: Sliding Window
Each token attends only to W neighbors. O(n·W) instead of O(n²).
Ch 7: Linear Attention
Replace softmax with kernel trick. O(n) but approximate.
Ch 8: FlashAttention
Same math, tiled through SRAM. Exact attention, 2–5× faster.
Ch 9: Arena
Race all variants. Know the tradeoffs for your scenario.

Connections

Attention variants don't exist in isolation. Here's where to go next:

Key Papers

| Paper | Year | Contribution |
|---|---|---|
| Vaswani et al., "Attention Is All You Need" | 2017 | Multi-head attention, the Transformer |
| Shazeer, "Fast Transformer Decoding: One Write-Head Is All You Need" | 2019 | Multi-Query Attention (MQA) |
| Ainslie et al., "GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints" | 2023 | Grouped-Query Attention |
| Beltagy et al., "Longformer: The Long-Document Transformer" | 2020 | Sliding window + global attention |
| Katharopoulos et al., "Transformers are RNNs" | 2020 | Linear attention via kernel trick |
| Dao et al., "FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness" | 2022 | Tiled attention, online softmax |
| Dao, "FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning" | 2023 | Improved warp-level parallelism |
| Shah et al., "FlashAttention-3: Fast and Accurate Attention with Asynchrony and Low-precision" | 2024 | H100 optimization, FP8 support |
"The purpose of computing is insight, not numbers." — Richard Hamming. Every variant in this lesson exists because someone understood the bottleneck deeply enough to remove it. MQA understood that heads share more than they diverge. GQA understood that the optimal sharing ratio isn't 1 or H. Sliding window understood that not every token needs global context. FlashAttention understood that the hardware, not the algorithm, was the bottleneck. The lesson isn't about memorizing seven attention methods — it's about learning to see where the real bottleneck is and engineering around it.
You're deploying a 70B model to serve 100 concurrent users at 8K context length on 8×H100 GPUs. Which attention configuration gives the best throughput-per-dollar?
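A back-of-the-envelope memory estimate helps frame this question. The sketch below assumes a hypothetical Llama-like 70B shape (80 layers, 64 query heads, dhead = 128), FP16 weights and KV cache, and 80 GB per H100. None of these numbers come from the question itself, and the estimate ignores activations, fragmentation, and batching overhead.

```python
def kv_cache_gb(kv_heads, layers=80, d_head=128, context=8192, users=100,
                bytes_per_elem=2):
    """Total KV cache in GB (10^9 bytes) across all layers and all users."""
    per_token = 2 * kv_heads * d_head * bytes_per_elem * layers  # K and V
    return per_token * context * users / 1e9

total_hbm_gb = 8 * 80                  # assumed: 8 GPUs x 80 GB each
weights_gb = 70e9 * 2 / 1e9            # assumed: 70B parameters in FP16
budget_gb = total_hbm_gb - weights_gb  # memory left over for the KV cache

configs = {"MHA": 64, "GQA-8": 8, "MQA": 1}
fits = {}
for name, kv_heads in configs.items():
    need = kv_cache_gb(kv_heads)
    fits[name] = need <= budget_gb
    print(f"{name}: {need:.1f} GB of KV cache, fits in {budget_gb:.0f} GB: {fits[name]}")
```

Under these assumptions, full MHA needs over 2,000 GB of KV cache and cannot serve the workload at all, while GQA-8 needs roughly 268 GB and fits with room to spare.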