The Complete Beginner's Path

Understand the Transformer From Scratch

The architecture behind GPT, BERT, LLaMA, and every frontier language model. One paper changed everything — here's how it works.

Prerequisites: basic linear algebra + intuition for neural nets. That's it.
11 Chapters · 12+ Simulations · 0 Assumed ML Knowledge

Chapter 0: What Is a Sequence?

Text is a sequence of words. Audio is a sequence of samples. Video is a sequence of frames. A stock price is a sequence of values over time. Before the Transformer, we mostly processed sequences with recurrent neural networks (RNNs), one element at a time, left to right. That was slow, because each step had to wait for the previous one, and RNNs struggled to hold on to information across long sequences.

The Transformer processes every element at once. Instead of a conveyor belt, it's a spotlight that shines on the entire sequence simultaneously. This parallelism is why Transformers train so fast on GPUs — and why they scale to billions of parameters.

The core shift: RNNs read a book one word at a time, trying to remember what happened on page 1 when they reach page 300. Transformers can see every page at once and decide which pages matter for understanding each word.
Sequence Types

Click a domain to see how it's tokenized into a sequence.

Key insight: Order matters. "Dog bites man" and "Man bites dog" have the same words but very different meanings. Any useful sequence model must somehow encode position — we'll see how in Chapter 4.
Check: Why do Transformers train faster than RNNs?

Chapter 1: Attention from Scratch

Imagine reading the sentence: "The cat sat on the mat because it was tired." What does "it" refer to? You instantly know it means "the cat." Your brain attends to "cat" when processing "it." That's attention.

In a neural network, each token is a vector (a list of numbers). Attention lets each token compute a weighted average of all tokens' vectors, including its own. The weights come from dot products, which measure how similar two tokens are, passed through a softmax so they are positive and sum to 1. Similar tokens get high weights; irrelevant ones get near-zero.
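A minimal NumPy sketch of this idea, before any learned projections are added: every token's new vector is a softmax-weighted average of all token vectors, with weights from pairwise dot products. The toy vectors and their labels are invented for illustration, and NumPy stands in for PyTorch to keep the example dependency-light.

```python
import numpy as np

def softmax(z, axis=-1):
    # Subtracting the row max avoids overflow and leaves the result unchanged.
    z = z - z.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

# Four toy token vectors, one per row (values invented for illustration).
x = np.array([
    [1.0, 0.0, 1.0],   # "cat"
    [0.9, 0.1, 1.1],   # "it": deliberately close to "cat"
    [0.0, 1.0, 0.0],   # "mat"
    [0.1, 0.9, -0.2],  # "tired"
])

scores = x @ x.T             # [4, 4]: dot product of every pair of tokens
weights = softmax(scores)    # row i = how much token i attends to each token
output = weights @ x         # row i = weighted average of all token vectors

print(np.round(weights[1], 2))  # attention of "it" over [cat, it, mat, tired]
```

Without learned projections, a token is always most similar to itself; the Q/K/V projections introduced next remove that constraint.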

The Q/K/V Machinery

But we don't compare raw token vectors directly. Each token is projected through three learned linear layers to produce three separate vectors: Query (Q), Key (K), and Value (V). Here's the actual tensor math:

python
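import torch

# Input: x has shape [batch, seq_len, d_model]
# e.g. batch=1, seq_len=10 tokens, d_model=512
x = torch.randn(1, 10, 512)
W_Q, W_K, W_V = (torch.randn(512, 512) / 512 ** 0.5 for _ in range(3))  # learned in practice; random here

Q = x @ W_Q   # [1, 10, 512] @ [512, 512] → [1, 10, 512]
K = x @ W_K   # [1, 10, 512] @ [512, 512] → [1, 10, 512]
V = x @ W_V   # [1, 10, 512] @ [512, 512] → [1, 10, 512]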

Three weight matrices, each [d_model, d_model]. That's 512 × 512 × 3 = 786,432 parameters just for one attention layer's projections. These matrices are learned during training — they determine what "looking for" (Q), "containing" (K), and "carrying" (V) mean.

Next, compute the full attention in one shot:

python
import math
import torch

d_k = Q.shape[-1]                        # 512: a single head uses the full d_model
scores = Q @ K.transpose(-2, -1)         # [1, 10, 512] @ [1, 512, 10] → [1, 10, 10]
scores = scores / math.sqrt(d_k)         # scale (Chapter 2 explains why)
weights = torch.softmax(scores, dim=-1)  # [1, 10, 10]: each row sums to 1
output = weights @ V                     # [1, 10, 10] @ [1, 10, 512] → [1, 10, 512]

The output has the same shape as the input: [batch, seq_len, d_model]. Each token's output is now a weighted mix of all tokens' Value vectors, where the weights are determined by Query-Key similarity. That's the entire mechanism.

The shape story: Input is [B, T, D]. Output is [B, T, D]. In between, we create a [B, T, T] attention matrix — that's the T × T "who attends to whom" map. For a 2048-token sequence, that's a 2048 × 2048 = ~4M entry matrix. This is why attention is O(n²) in sequence length.
Interactive: Token Similarity

Click any token to see its dot-product similarity with every other token. Brighter = higher similarity.

$$\text{similarity}(q, k) = q \cdot k = \sum_i q_i k_i$$
Why dot products? Two vectors pointing in the same direction have a large dot product. Orthogonal vectors have zero. This gives us a fast, differentiable way to measure "how much should token A care about token B?"
Check: What does a high dot product between two token vectors mean?
🔗 Pattern Recognition
Attention Weights = Optimal Weighting Under Uncertainty
Kalman Filter
$K = P H^\top (H P H^\top + R)^{-1}$
Weights prediction vs. measurement by their uncertainties. High measurement noise → trust prediction more.
Attention
$\alpha = \mathrm{softmax}(Q K^\top / \sqrt{d_k})$
Weights value vectors by their relevance to the query. Low similarity → near-zero weight.

Same deep principle: compute a weighted combination where the weights reflect how much each source of information should count. The Kalman gain asks "how much should I trust this measurement?" Attention asks "how much should I trust this token's information?" The difference: the Kalman gain is provably optimal (minimum mean-squared error) for linear-Gaussian systems, while attention weights are learned from data and come with no such guarantee.

Can you spot this same "optimal weighting" pattern when we reach Mixture-of-Experts routing in Chapter 9?

Chapter 2: Scaled Dot-Product Attention

Raw attention has three ingredients. Each token produces three vectors by multiplying its embedding with learned weight matrices: a Query (what the token is looking for), a Key (what it contains), and a Value (what it carries).

The attention score between tokens i and j is Qi · Kj. We divide by √dk to keep the dot products from growing too large (which would saturate the softmax and kill gradients). Then a softmax over each row (all keys j for a fixed query i) converts the scores to weights that sum to 1.

$$\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left(\frac{Q K^\top}{\sqrt{d_k}}\right) V$$
Interactive: Compute Attention Weights

Four tokens with 2D Q and K vectors. Watch how weights shift as you drag the query vector of the selected token.

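Why scale? If dk = 64 and the entries of q and k have unit variance, dot products have a standard deviation of √64 = 8, so gaps of 10 or more between scores are routine. A gap of 10 already gives the top token e¹⁰ ≈ 22,000 times the weight of the runner-up: softmax becomes essentially one-hot, attending to one token and ignoring the rest. Dividing by √64 = 8 brings the scores back to unit variance.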
🔨 Derivation Prove why the scaling factor is 1/√dk

Assume each entry of Q and K is drawn i.i.d. from N(0, 1). The dot product is q · k = ∑ qiki over dk dimensions.

Your task: Derive the variance of this dot product. Then explain why dividing by √dk (not dk, not 1) is the correct normalization.

Each qi and ki is N(0,1) and independent. For independent zero-mean random variables, Var(XY) = Var(X)·Var(Y) = 1 × 1 = 1. So each term has variance 1.
Variance is additive for independent terms: Var(∑ Xi) = ∑ Var(Xi) = dk × 1 = dk. So q·k has mean 0 and variance dk; by the central limit theorem it is approximately N(0, dk) for large dk.
If Var(q·k) = dk, then Var(q·k / c) = dk / c². Set this equal to 1: dk/c² = 1, so c = √dk. Dividing by dk would make variance = 1/dk (too small). Dividing by 1 leaves variance = dk (too big).

Full derivation:

1. Each entry qi, ki ~ N(0, 1), independent.

2. Var(qi · ki) = E[qi²]·E[ki²] - (E[qi]·E[ki])² = 1·1 - 0 = 1

3. The dot product sums dk such terms: Var(q·k) = dk

4. Standard deviation = √dk. For dk=64, typical dot products are ±8.

5. Dividing by √dk normalizes to unit variance, so nearly all scores fall within about [-3, 3] (three standard deviations), where softmax has healthy gradients.

The key insight: This isn't a heuristic. It's the unique scaling that preserves unit variance regardless of head dimension. It falls directly out of the statistics.

The Variance Proof

Let's prove this rigorously. Assume each entry of Q and K is drawn from N(0, 1) — mean zero, variance one. The dot product of two dk-dimensional vectors is:

$$q \cdot k = \sum_{i=1}^{d_k} q_i\, k_i$$

Each term qi · ki has mean 0 and variance 1 × 1 = 1. Since the terms are independent, the sum has variance dk. So the dot product has mean 0 and variance dk, and by the central limit theorem it is approximately N(0, dk) for large dk.
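A quick Monte Carlo check of this result, as a sketch under the same i.i.d. N(0, 1) assumption (the sample count and seed are arbitrary choices): the raw variance should track dk, and the scaled variance should stay near 1 for every dk.

```python
import numpy as np

rng = np.random.default_rng(0)
n_samples = 20_000
results = {}

for d_k in (16, 64, 256):
    # q and k have i.i.d. N(0, 1) entries, exactly as the derivation assumes.
    q = rng.standard_normal((n_samples, d_k))
    k = rng.standard_normal((n_samples, d_k))
    dots = np.einsum("nd,nd->n", q, k)   # one dot product per sample
    scaled = dots / np.sqrt(d_k)
    results[d_k] = (dots.var(), scaled.var())
    print(f"d_k={d_k:3d}  Var(q.k)={dots.var():7.2f}  Var(q.k / sqrt(d_k))={scaled.var():.3f}")
```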

For dk = 64, a typical dot product might be ±8. Let's see what softmax does with large values:

concrete numbers
# Without scaling (d_k = 64):
raw scores = [8.0, 0.1, -0.3]
softmax    = [0.9994, 0.0004, 0.0002]  # almost one-hot!

# With scaling (divide by √64 = 8):
scaled     = [1.0, 0.0125, -0.0375]
softmax    = [0.58, 0.22, 0.21]         # smooth distribution
The punch line: Without scaling, softmax becomes almost one-hot, and every entry of its Jacobian, ∂p_i/∂z_j = p_i(δ_ij − p_j), is nearly zero, so learning stalls. Dividing by √dk normalizes the variance back to 1, keeping softmax in the regime where gradients flow.
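The numbers in the example can be reproduced in NumPy, along with the largest entry of the softmax Jacobian, which scales every gradient flowing back into the scores. This is an illustrative sketch that reuses the score vector from the example.

```python
import numpy as np

def softmax(z):
    e = np.exp(z - z.max())
    return e / e.sum()

def softmax_jacobian(z):
    # J[i, j] = d p_i / d z_j = p_i * (delta_ij - p_j)
    p = softmax(z)
    return np.diag(p) - np.outer(p, p)

raw = np.array([8.0, 0.1, -0.3])
scaled = raw / np.sqrt(64)   # divide by sqrt(d_k) = 8

jac_max = {}
for name, z in (("raw", raw), ("scaled", scaled)):
    jac_max[name] = np.abs(softmax_jacobian(z)).max()
    print(f"{name:6s} softmax={np.round(softmax(z), 4)}  largest |dp/dz|={jac_max[name]:.4f}")
```

The raw scores give a Jacobian hundreds of times smaller than the scaled ones, which is the "gradients stop flowing" failure in concrete form.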
Check: What is the purpose of dividing by √dk?
Checkpoint — Before you move on
In your own words: What would happen to a Transformer's training if you removed the 1/√dk scaling entirely? Be specific about the failure mode.
Model Answer

Without scaling, dot products grow with dk. For dk=64, scores reach ±8. Softmax of [8, 0.1, -0.3] ≈ [0.9994, 0.0004, 0.0002], essentially one-hot. At saturation every entry of the softmax Jacobian, ∂p_i/∂z_j = p_i(δ_ij − p_j), is close to zero, so the attention weights barely update. The model tends to lock onto whichever token happened to have the highest initial score and struggles to learn to redistribute attention. Learning in the attention layers slows to a crawl.

The specific failure: gradient vanishing in the attention weights, not in the value path. The FFN and value projections still train, but attention behaves like a nearly fixed, arbitrary lookup, which is very damaging for language modeling.

Chapter 3: Multi-Head Attention

A single attention head produces one set of attention weights per token, so it has to blend every relationship into one distribution. But language has many simultaneous relationships: syntax (subject-verb), coreference (pronoun-noun), semantic similarity, positional patterns. Multi-head attention runs several attention operations in parallel, each with its own learned Q, K, V projections.

If the model dimension is d = 512 and we use h = 8 heads, each head works with d/h = 64 dimensions. After computing attention independently, we concatenate all head outputs and project back to the full dimension.

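$$\mathrm{MultiHead}(Q, K, V) = \mathrm{Concat}(\mathrm{head}_1, \dots, \mathrm{head}_h)\, W^O, \qquad \mathrm{head}_i = \mathrm{Attention}(Q W_i^Q, K W_i^K, V W_i^V)$$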
Input
$x \in \mathbb{R}^{n \times 512}$
Split into 8 heads
Each head: $Q, K, V \in \mathbb{R}^{n \times 64}$
Parallel Attention
8 independent softmax(QKᵀ/√64)V
Concat + Project
Concat all heads → $W^O$ → $\mathbb{R}^{n \times 512}$
Interactive: What Each Head Sees

Select a head to see its attention pattern. Each head learns to focus on different relationships.

How the Split Actually Works

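A common misconception is that multi-head attention needs 8 separate weight matrices and 8 separate attention calls. In practice, you do one big projection and then reshape. Each head's projection is just a 64-column slice of the big W_Q, so the result is mathematically the same as 8 separate heads: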

python
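import torch

# d_model=512, n_heads=8, d_k=64
B, T = 2, 10
x = torch.randn(B, T, 512)
W_Q, W_K, W_V, W_O = (torch.randn(512, 512) / 512 ** 0.5 for _ in range(4))

Q, K, V = x @ W_Q, x @ W_K, x @ W_V  # each [B, T, 512] @ [512, 512] → [B, T, 512]

# Reshape into heads: split the last dim into 8 heads of 64, then make heads a batch dim
Q = Q.view(B, T, 8, 64).transpose(1, 2)  # [B, 8, T, 64]
K = K.view(B, T, 8, 64).transpose(1, 2)  # same for K
V = V.view(B, T, 8, 64).transpose(1, 2)  # and V

# Now attention is a single batched matmul:
scores = Q @ K.transpose(-2, -1)  # [B, 8, T, 64] @ [B, 8, 64, T] → [B, 8, T, T]
# 8 independent T×T attention matrices from one batched matmul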

After attention, we concatenate heads and project back:

python
out = torch.softmax(scores / 8, dim=-1) @ V             # divide by √64 = 8 → [B, 8, T, 64]
out = out.transpose(1, 2).contiguous().view(B, T, 512)  # concat heads → [B, T, 512]
out = out @ W_O                                         # [B, T, 512] @ [512, 512] → [B, T, 512]
Key insight: Multi-head attention costs roughly the same as single-head attention with full dimensionality. The split is free — it's just a reshape, not a copy. And you get richer, more diverse attention patterns. The only extra cost is the output projection WO.
💻 Build It Implement Multi-Head Attention from Scratch
You've seen the code above. Now close it mentally and implement multi-head attention yourself. Given the signature below, write the body in the editor. No peeking.
signature
import torch
import torch.nn.functional as F

def multi_head_attention(x, W_Q, W_K, W_V, W_O, n_heads):
    """
    Args:
        x: [B, T, d_model], input sequence
        W_Q, W_K, W_V: [d_model, d_model], projection matrices
        W_O: [d_model, d_model], output projection
        n_heads: int, number of attention heads
    Returns:
        [B, T, d_model], attended output
    """
Test case
x = torch.randn(2, 10, 512)
W_Q = W_K = W_V = W_O = torch.randn(512, 512)
out = multi_head_attention(x, W_Q, W_K, W_V, W_O, 8)
assert out.shape == (2, 10, 512)
1. Q = x @ W_Q, then reshape to [B, T, n_heads, d_k] and transpose to [B, n_heads, T, d_k]
2. Same for K, V
3. scores = Q @ K.transpose(-2,-1) / sqrt(d_k)
4. weights = softmax(scores, dim=-1), normalizing over the keys
5. out = weights @ V → [B, n_heads, T, d_k]
6. Transpose back, reshape to [B, T, d_model], project through W_O
python
def multi_head_attention(x, W_Q, W_K, W_V, W_O, n_heads):
    B, T, d_model = x.shape
    d_k = d_model // n_heads

    # Project
    Q = x @ W_Q  # [B, T, d_model]
    K = x @ W_K
    V = x @ W_V

    # Reshape into heads
    Q = Q.view(B, T, n_heads, d_k).transpose(1, 2)  # [B, h, T, d_k]
    K = K.view(B, T, n_heads, d_k).transpose(1, 2)
    V = V.view(B, T, n_heads, d_k).transpose(1, 2)

    # Scaled dot-product attention
    scores = Q @ K.transpose(-2, -1) / (d_k ** 0.5)
    weights = F.softmax(scores, dim=-1)
    out = weights @ V  # [B, h, T, d_k]

    # Concat heads and project
    out = out.transpose(1, 2).contiguous().view(B, T, d_model)
    return out @ W_O
Bonus challenge: Add a causal mask (lower-triangular) so token i cannot attend to tokens j > i. What single line do you add, and where?
Check: Why use multiple attention heads instead of one?

Chapter 4: Positional Encoding

Attention has no built-in notion of order. Self-attention is permutation-equivariant: if you shuffle the input tokens, the outputs are shuffled in exactly the same way, and each token's output vector is unchanged. "Cat sat mat" and "Mat cat sat" would give each word the same output, just in a different order. We need to inject position information explicitly.
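A small NumPy check of this property, assuming a single head with random weights and no positional encoding: shuffling the input rows produces exactly the same output rows, shuffled the same way.

```python
import numpy as np

rng = np.random.default_rng(0)
T, d = 5, 8
x = rng.standard_normal((T, d))
W_Q, W_K, W_V = (rng.standard_normal((d, d)) for _ in range(3))

def self_attention(x):
    # Single-head self-attention with no positional information.
    q, k, v = x @ W_Q, x @ W_K, x @ W_V
    s = q @ k.T / np.sqrt(d)
    w = np.exp(s - s.max(axis=1, keepdims=True))
    w /= w.sum(axis=1, keepdims=True)
    return w @ v

perm = rng.permutation(T)
out = self_attention(x)
out_shuffled = self_attention(x[perm])

# Shuffling the input just shuffles the output rows the same way.
print(np.allclose(out_shuffled, out[perm]))
```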

Three approaches:

Sinusoidal (Original)
$$PE(pos, 2i) = \sin\!\left(pos / 10000^{2i/d}\right)$$
$$PE(pos, 2i+1) = \cos\!\left(pos / 10000^{2i/d}\right)$$
Learned (BERT, GPT-2)
Trainable embedding matrix $E_{\text{pos}} \in \mathbb{R}^{\text{max\_len} \times d}$
RoPE (LLaMA, Modern)
Rotate Q and K vectors by position-dependent angle. Relative position encoded in the dot product itself.
Interactive: Sinusoidal Position Encodings

Each row is a position (0–31). Each column is a dimension. Color = encoding value. Notice the wave patterns at different frequencies.

Why sinusoids? The original paper noted that for any fixed offset k, PE(pos+k) can be represented as a linear function of PE(pos), and hypothesized that this would let the model learn to attend to relative positions using simple linear operations.
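A sketch of the sinusoidal table in NumPy, together with a numerical check of the linear-shift claim. It assumes an even d and interleaves sin/cos as in the formulas above; the helper names `sinusoidal_pe` and `shift_matrix` are hypothetical.

```python
import numpy as np

def sinusoidal_pe(max_len, d):
    # PE[pos, 2i] = sin(pos / 10000^(2i/d)),  PE[pos, 2i+1] = cos(pos / 10000^(2i/d))
    pos = np.arange(max_len)[:, None]                # [max_len, 1]
    omega = 1.0 / 10000 ** (np.arange(0, d, 2) / d)  # [d/2] frequencies
    pe = np.zeros((max_len, d))
    pe[:, 0::2] = np.sin(pos * omega)
    pe[:, 1::2] = np.cos(pos * omega)
    return pe, omega

def shift_matrix(k, omega):
    # Block-diagonal M_k: one 2x2 block per frequency; nothing here depends on pos.
    d = 2 * len(omega)
    M = np.zeros((d, d))
    for i, w in enumerate(omega):
        c, s = np.cos(w * k), np.sin(w * k)
        M[2 * i:2 * i + 2, 2 * i:2 * i + 2] = [[c, s], [-s, c]]
    return M

pe, omega = sinusoidal_pe(max_len=64, d=16)
M3 = shift_matrix(3, omega)

# Rows are positions, so PE(pos + 3) = M3 @ PE(pos) becomes pe[pos + 3] = pe[pos] @ M3.T.
print(np.allclose(pe[3:], pe[:-3] @ M3.T))
```

The same matrix works for every starting position, which is exactly the "linear function of PE(pos)" property.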
Check: Without positional encoding, attention is:
🔨 Derivation Prove that sinusoidal encodings allow learning relative position

The original Transformer paper claims: "for any fixed offset k, PE(pos+k) can be represented as a linear function of PE(pos)." This means the model can learn to attend to "3 tokens back" using a simple linear transformation.

Your task: Given $PE(pos, 2i) = \sin(pos/10000^{2i/d})$ and $PE(pos, 2i+1) = \cos(pos/10000^{2i/d})$, prove that $PE(pos+k) = M_k \cdot PE(pos)$ for some matrix $M_k$ that depends only on $k$, not on $pos$.

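The angle addition formulas: sin(a+b) = sin(a)cos(b) + cos(a)sin(b) and cos(a+b) = cos(a)cos(b) − sin(a)sin(b). Here $a = pos/10000^{2i/d}$ and $b = k/10000^{2i/d}$.
For frequency $\omega = 1/10000^{2i/d}$:
$$\begin{bmatrix}\sin(\omega(pos+k)) & \cos(\omega(pos+k))\end{bmatrix} = \begin{bmatrix}\sin(\omega\,pos) & \cos(\omega\,pos)\end{bmatrix} \begin{bmatrix}\cos(\omega k) & -\sin(\omega k) \\ \sin(\omega k) & \cos(\omega k)\end{bmatrix}$$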
This is a 2D rotation matrix! The rotation angle depends only on k and the frequency, not on pos.
The full PE vector is d/2 independent (sin, cos) pairs, each at a different frequency. The full transformation PE(pos+k) = Mk · PE(pos) is a block-diagonal matrix of d/2 rotation matrices. Each 2×2 block rotates by a different angle. Mk depends only on k.

Proof: Let $\omega_i = 1/10000^{2i/d}$. For each dimension pair (2i, 2i+1):

PE(pos+k, 2i) = sin(ωi(pos+k)) = sin(ωi·pos)cos(ωi·k) + cos(ωi·pos)sin(ωi·k)

PE(pos+k, 2i+1) = cos(ωi(pos+k)) = cos(ωi·pos)cos(ωi·k) − sin(ωi·pos)sin(ωi·k)

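In matrix form:
$$\begin{bmatrix} PE(pos+k, 2i) \\ PE(pos+k, 2i+1) \end{bmatrix} = \begin{bmatrix} \cos(\omega_i k) & \sin(\omega_i k) \\ -\sin(\omega_i k) & \cos(\omega_i k) \end{bmatrix} \begin{bmatrix} PE(pos, 2i) \\ PE(pos, 2i+1) \end{bmatrix}$$

The 2×2 matrix is a rotation (by angle $-\omega_i k$, given this (sin, cos) ordering), and it depends only on k, not on pos. The full $M_k$ is block-diagonal with d/2 such rotation blocks.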

The key insight: Sinusoidal position encodings encode relative position as a rotation. The model's linear layers can learn these rotation matrices, letting it attend to "k positions back" regardless of absolute position. The original paper hoped this would also let the model extrapolate to sequences longer than those seen in training, though in practice that extrapolation is limited. RoPE, which applies the rotation directly to Q and K, is the modern evolution of this idea.
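A minimal RoPE-style sketch in NumPy. It assumes adjacent dimensions are paired and reuses the 10000 base from the sinusoidal formula; real implementations differ in details such as how dimensions are paired. Rotating q by its position and k by its position makes their dot product depend only on the offset between them.

```python
import numpy as np

def rope(x, pos, base=10000.0):
    # Rotate each adjacent pair (x[2i], x[2i+1]) by the angle pos * omega_i.
    d = x.shape[-1]
    omega = 1.0 / base ** (np.arange(0, d, 2) / d)
    theta = pos * omega
    c, s = np.cos(theta), np.sin(theta)
    x_even, x_odd = x[0::2], x[1::2]
    out = np.empty_like(x)
    out[0::2] = x_even * c - x_odd * s
    out[1::2] = x_even * s + x_odd * c
    return out

rng = np.random.default_rng(0)
q, k = rng.standard_normal(16), rng.standard_normal(16)

# Same offset (query 3 positions after the key) at two different absolute positions.
a = rope(q, 10) @ rope(k, 7)
b = rope(q, 105) @ rope(k, 102)
print(np.isclose(a, b))
```

Because each block is a pure rotation, vector norms are preserved and only the angle difference survives in the dot product.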

Chapter 5: The Encoder Block

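An encoder block takes a sequence and returns a refined sequence of the same shape. It has two sub-layers, each wrapped in a residual connection (add the input back) and layer normalization. The residual connections are critical: they let gradients flow straight through, enabling very deep stacks. The diagram below uses the pre-norm arrangement (LayerNorm before each sub-layer) found in GPT-2 and most modern models; the original 2017 Transformer applied LayerNorm after each residual add.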

Input
$x \in \mathbb{R}^{n \times d}$
LayerNorm
Normalize each token vector
Multi-Head Self-Attention
Each token attends to all tokens
↓ + residual
LayerNorm
Normalize again
Feed-Forward Network
Two linear layers with ReLU/GELU: d → 4d → d
↓ + residual
Output
Same shape: $\mathbb{R}^{n \times d}$
Interactive: Data Flow Through an Encoder Block

Watch a token vector flow through each sub-layer. The residual stream carries information forward.


Data Flow with Actual Shapes

Let's trace a concrete example through one encoder block. Assume d_model = 512, 8 heads, sequence length = 10:

python
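import torch
import torch.nn.functional as F

B, d_model, d_ff = 2, 512, 2048
x = torch.randn(B, 10, d_model)                 # Input: x = [B, 10, 512]

# Parameters (random here; learned in practice). Each LayerNorm has its own weights.
ln1, ln2 = torch.nn.LayerNorm(d_model), torch.nn.LayerNorm(d_model)
mha = torch.nn.MultiheadAttention(d_model, num_heads=8, batch_first=True)
W1, b1 = torch.randn(d_model, d_ff) * 0.02, torch.zeros(d_ff)
W2, b2 = torch.randn(d_ff, d_model) * 0.02, torch.zeros(d_model)

# Step 1: LayerNorm (normalize each token vector independently)
x_norm = ln1(x)                              # [B, 10, 512]: same shape

# Step 2: Multi-Head Self-Attention (query, key, value all come from x_norm)
attn_out, _ = mha(x_norm, x_norm, x_norm)    # [B, 10, 512]: same shape

# Step 3: Residual add
x = x + attn_out                             # [B, 10, 512]: ADD, not replace

# Step 4: LayerNorm again
x_norm = ln2(x)                              # [B, 10, 512]

# Step 5: Feed-Forward Network (the 4× expansion)
h = x_norm @ W1 + b1                         # [B, 10, 512] @ [512, 2048] → [B, 10, 2048]
h = F.gelu(h)                                # activation function
ffn_out = h @ W2 + b2                        # [B, 10, 2048] @ [2048, 512] → [B, 10, 512]

# Step 6: Residual add
x = x + ffn_out                              # [B, 10, 512]: still same shape!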

The FFN expands from 512 to 2048 (4×), applies a nonlinearity, then projects back to 512. Why 4×? It's a design choice, not derived from theory. The original paper used 4× and it stuck. Some modern models use 8/3× with gated variants (SwiGLU). The expansion gives each token a wider space to "think" before compressing back.

Parameter count per encoder block (d=512, 8 heads; weight matrices only, biases excluded):
Attention: 4 × 512 × 512 = 1,048,576 (WQ, WK, WV, WO)
FFN: 512 × 2048 + 2048 × 512 = 2,097,152
LayerNorm: 2 × 2 × 512 = 2,048
Total: ~3.15M per block. Stack 6 = 18.9M. Stack 96 = 302M just for attention + FFN.
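The per-block arithmetic above, as a few lines of Python (weights only, biases excluded, matching the counts listed):

```python
d, d_ff = 512, 2048   # the head count doesn't matter: heads split d rather than adding weights

attn = 4 * d * d              # W_Q, W_K, W_V, W_O (biases excluded)
ffn = d * d_ff + d_ff * d     # W1 and W2 (biases excluded)
ln = 2 * (2 * d)              # two LayerNorms, each with a scale and a shift vector
per_block = attn + ffn + ln

print(f"attention={attn:,}  ffn={ffn:,}  layernorm={ln:,}  total={per_block:,}")
print(f"6 blocks: {6 * per_block / 1e6:.1f}M   96 blocks: {96 * per_block / 1e6:.1f}M")
print(f"FFN share of attention + FFN weights: {ffn / (attn + ffn):.3f}")
```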
The FFN is a memory bank. The attention layer lets tokens talk to each other. The FFN processes each token independently, and it is widely thought to be where the model stores much of its factual knowledge. With the 4× expansion, the FFN holds 2/3 of each block's attention + FFN weights (2.1M of 3.1M above).
Check: What is the purpose of the residual connection?
💥 Break-It Lab What Dies When You Remove Components?
A 12-layer Transformer is training normally. Toggle off components below to see what specific pathology emerges. Experts don't just know "use residual connections" — they know exactly what breaks without them and at what depth.
Remove residual connections
Failure mode: Gradient vanishing at depth. Without residuals, gradients must pass through every layer's weights sequentially. By layer 6-8, gradients shrink to ~10⁻¹². Early layers stop learning entirely. The model effectively becomes a 3-4 layer network regardless of nominal depth. This is why pre-Transformer deep networks needed careful initialization (Xavier/Kaiming) — residuals make depth robust.
Remove layer normalization
Failure mode: Activation explosion. Without normalization, activations grow exponentially across layers. By layer 4-5, values reach 10³–10⁴. Softmax saturates (same problem as no scaling). Loss spikes, then NaN. You can partially compensate with very small learning rates (10⁻⁵), but training is 10× slower and unstable. LayerNorm is the "automatic gain control" that lets you stack layers without hand-tuning initialization.
Remove FFN (attention only)
Failure mode: Expressivity collapse. Attention output is a weighted average of value vectors, which are linear maps of the input; the only nonlinearity is the softmax that chooses the weights. Without the FFN, no layer applies a nonlinear transformation to each token's features, so the model can route information but can't transform it. Perplexity plateaus ~30% higher than with FFN. The FFN is widely thought to be where much of the model's factual knowledge lives: the "memory bank" of the Transformer.
Remove 1/√dk scaling
Failure mode: Attention collapse to one-hot. Without scaling (dk=64), dot products have variance 64. Softmax of scores like [8, -2, 1] → [0.999, 0.000, 0.001]. Each token puts essentially all its weight on a single token. No mixing, no soft weighting. The model behaves like a hard lookup table. Training still converges, but to a much worse solution: about 40% higher loss.

Chapter 6: The Decoder Block

The decoder block has the same structure as the encoder, plus two crucial additions: causal masking and (in encoder-decoder models) cross-attention.

Causal masking: During generation, token i must not see tokens i+1, i+2, ... (the future). We achieve this by setting those attention scores to −∞ before softmax, which forces their weights to zero. This creates a lower-triangular attention matrix.
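A NumPy sketch of the masking step, with random numbers standing in for the scaled scores QKᵀ/√dk: future positions are set to −∞, so the row-wise softmax gives them exactly zero weight.

```python
import numpy as np

T = 4
rng = np.random.default_rng(0)
scores = rng.standard_normal((T, T))          # stand-in for Q @ K.T / sqrt(d_k)

# Lower-triangular mask: position i may attend to positions j <= i.
mask = np.tril(np.ones((T, T), dtype=bool))
masked = np.where(mask, scores, -np.inf)      # future positions get -inf

# Row-wise softmax; exp(-inf) = 0, so masked weights are exactly zero.
weights = np.exp(masked - masked.max(axis=1, keepdims=True))
weights /= weights.sum(axis=1, keepdims=True)

print(np.round(weights, 2))
```

Every row keeps at least its diagonal entry, so the row max is always finite and no NaNs appear.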

Cross-attention: In translation models, the decoder attends to the encoder's output. The decoder provides Q; the encoder provides K and V. This is how the decoder "reads" the source language.
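A shape-level NumPy sketch of cross-attention, with random stand-ins for the encoder output and decoder states (the lengths 7 and 4 are arbitrary). Queries come from the decoder and keys/values from the encoder, so the attention matrix is [target length, source length]. No causal mask is applied here, since the decoder may look at the whole source.

```python
import numpy as np

rng = np.random.default_rng(0)
T_src, T_tgt, d = 7, 4, 16                    # source (encoder) and target (decoder) lengths
enc_out = rng.standard_normal((T_src, d))     # encoder output
dec_h = rng.standard_normal((T_tgt, d))       # decoder hidden states
W_Q, W_K, W_V = (rng.standard_normal((d, d)) / np.sqrt(d) for _ in range(3))

Q = dec_h @ W_Q                               # queries from the decoder: [T_tgt, d]
K, V = enc_out @ W_K, enc_out @ W_V           # keys/values from the encoder: [T_src, d]

scores = Q @ K.T / np.sqrt(d)                 # [T_tgt, T_src]: each target token scores every source token
weights = np.exp(scores - scores.max(axis=1, keepdims=True))
weights /= weights.sum(axis=1, keepdims=True)
out = weights @ V                             # [T_tgt, d]: one source summary per target token

print(weights.shape, out.shape)
```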

Interactive: Causal Mask

The attention matrix before and after masking. White cells are visible; dark cells are masked (−∞). Each token can only see itself and earlier tokens.

Decoder Input
Previously generated tokens
Masked Self-Attention
Causal: can only look backward
↓ + residual
Cross-Attention
Q from decoder, K/V from encoder
↓ + residual
Feed-Forward Network
Same as encoder FFN
↓ + residual
Output
Logits over vocabulary
Decoder-only models (GPT, LLaMA) skip cross-attention entirely. They use only masked self-attention + FFN. The entire "encoder" is implicit: the prompt IS the encoding. This simplicity is one reason decoder-only models dominate today.
Check: Why does the decoder use causal masking?
⚔ Adversarial: You're building a fill-in-the-blank model (like BERT) but accidentally use causal masking. What breaks?
Input: "The [MASK] sat on the mat." You want the model to predict "cat" using context from BOTH sides. But causal masking is active.

Chapter 7: Training

Training a Transformer for language modeling is deceptively simple: given a sequence of tokens, predict the next token at every position. The loss function is cross-entropy between the predicted probability distribution and the actual next token.

$$L = -\sum_{t} \log P(x_{t+1} \mid x_1, \dots, x_t)$$

Teacher forcing: During training, we don't feed the model's own predictions back in as the next input. Instead, we always feed the true previous tokens. This is faster and more stable than training on the model's own generated outputs, but it means the model never sees its own mistakes during training.

Interactive: Cross-Entropy Loss

The model predicts a probability for the correct next token. Drag the slider to see how loss changes. Higher confidence in the right answer = lower loss.

The beauty of causal masking: With one forward pass, the model produces predictions for every position simultaneously. Position 1 predicts token 2, position 2 predicts token 3, and so on. One sequence gives you n−1 training examples for free.
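The "n−1 examples from one pass" trick comes down to shifting the targets by one position. A NumPy sketch, with random logits standing in for model output and the loss averaged over positions (a common convention; the formula above sums):

```python
import numpy as np

rng = np.random.default_rng(0)
vocab, T = 10, 6
tokens = rng.integers(0, vocab, size=T)       # one training sequence of token ids
logits = rng.standard_normal((T, vocab))      # stand-in for the model output at every position

# Position t predicts token t+1: drop the last prediction and the first token.
pred, target = logits[:-1], tokens[1:]        # T-1 (prediction, target) pairs

# Numerically stable log-softmax, then pick out the log-prob of each true next token.
shifted = pred - pred.max(axis=1, keepdims=True)
log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
loss = -log_probs[np.arange(T - 1), target].mean()   # mean cross-entropy per position

print(len(target), round(float(loss), 3))
print(round(float(-np.log(0.5)), 3))          # loss when P(correct) = 0.5 → 0.693
```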
| Concept | What It Does |
|---|---|
| Next-token prediction | The training objective: predict $x_{t+1}$ from $x_1, \dots, x_t$ |
| Cross-entropy loss | Measures how far the predicted distribution is from the true next token |
| Teacher forcing | Use true tokens (not predictions) as input during training |
| AdamW optimizer | Adaptive learning rates + decoupled weight decay |
| Warmup + cosine decay | Gradually increase, then decrease, the learning rate |
Check: What is teacher forcing?
⚔ Adversarial: Teacher forcing uses true tokens as input, but at inference the model uses its own predictions. What specific failure mode does this mismatch create?
Scenario: Your model generates "The cat sat on the" perfectly, then produces "mxt" (garbage). Now it must continue from "mxt" — a state it never saw during training.

Chapter 8: KV Cache & Inference

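During generation, the model produces one token at a time. Because of causal masking, the Keys and Values of earlier tokens never change once computed. Without a cache, though, generating token n re-runs the model over all n−1 previous tokens, so a full sequence costs O(n²) token computations. The KV cache stores the Key and Value vectors from previous tokens so we never recompute them: each step processes only the new token, whose query still attends over all cached keys.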

Step 1
Process "The" → compute K1, V1 → store in cache
Step 2
Process "cat" → compute K2, V2 → append to cache → attend to [K1,K2]
Step 3
Process "sat" → compute K3, V3 → append → attend to [K1,K2,K3]
↓ ...
Step n
Compute Qn, Kn, Vn for the new token only; append Kn, Vn to the cache and attend over all cached K, V
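A single-layer NumPy sketch of cached decoding, with random weights as stand-ins: each step projects only the new token, appends its K and V to the cache, and attends over everything cached. The final check confirms it matches full causal attention recomputed from scratch. A real model keeps one such cache per layer (and per head).

```python
import numpy as np

rng = np.random.default_rng(0)
d, T = 8, 6
x = rng.standard_normal((T, d))                   # token vectors arriving one at a time
W_Q, W_K, W_V = (rng.standard_normal((d, d)) for _ in range(3))

def softmax(z):
    e = np.exp(z - z.max())
    return e / e.sum()

# Incremental decoding: each step projects ONLY the new token.
K_cache, V_cache, cached_out = [], [], []
for t in range(T):
    q, k, v = x[t] @ W_Q, x[t] @ W_K, x[t] @ W_V
    K_cache.append(k)
    V_cache.append(v)
    K, V = np.stack(K_cache), np.stack(V_cache)   # [t+1, d]
    cached_out.append(softmax(K @ q / np.sqrt(d)) @ V)
cached_out = np.stack(cached_out)

# Reference: full causal attention recomputed over the whole sequence at once.
Q, K, V = x @ W_Q, x @ W_K, x @ W_V
scores = np.where(np.tril(np.ones((T, T), dtype=bool)), Q @ K.T / np.sqrt(d), -np.inf)
w = np.exp(scores - scores.max(axis=1, keepdims=True))
full_out = (w / w.sum(axis=1, keepdims=True)) @ V

print(np.allclose(cached_out, full_out))
```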
Interactive: Watch the KV Cache Grow

Click "Generate Token" to add one token. The cache (blue bars) grows while each step only computes one new Q (orange).


Exact Memory Calculation

The KV cache stores K and V tensors for every layer and every token generated so far. Here's the formula:

KV memory = num_layers × 2 × seq_len × d_model × dtype_bytes

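The "2" is for K and V. The formula gives the cache for one sequence and assumes standard multi-head attention, where every head stores its own K and V (so the K/V width per token is d_model). Let's compute this for real models: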

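| Model | Layers | d_model | Seq Len | KV Cache Size (FP16, full MHA) |
|---|---|---|---|---|
| GPT-2 Small | 12 | 768 | 1,024 | 37.7 MB |
| GPT-3 (175B) | 96 | 12,288 | 2,048 | 9.7 GB |
| LLaMA-2 70B | 80 | 8,192 | 4,096 | 10.7 GB* |
| LLaMA-3.1 70B | 80 | 8,192 | 128,000 | 335.5 GB* |

*Computed as if every head stored its own K and V. Both 70B models actually use GQA with 8 KV heads, which cuts these figures by 8× (about 1.3 GB and 42 GB).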

For GPT-3, the calculation: 96 layers × 2 (K+V) × 2,048 tokens × 12,288 dims × 2 bytes (FP16) = 9.66 × 10⁹ bytes ≈ 9.7 GB. That's just for one user's cache; serving 100 concurrent users needs about 966 GB of KV cache memory alone.

Why this matters for deployment: The model weights for a 70B model are ~140 GB (FP16). But a full multi-head KV cache for a 128K context would be about 336 GB, more than double the weights. This is why techniques like GQA (Grouped-Query Attention, sharing K/V across query heads) and quantized KV caches (INT8 or INT4) are essential for production serving.
Memory cost: Even with GQA (8 KV heads, as in LLaMA-3.1 70B), a 128K context needs about 42 GB of KV cache per user. This is often the bottleneck for serving large models, not the model weights themselves.
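The formula as a small Python helper, written in terms of the per-token K/V width so it also covers GQA. The 70B configuration (80 layers, 64 query heads of dimension 128, 8 KV heads) is assumed from the published LLaMA-3 70B architecture; sizes are in decimal GB, and `kv_cache_bytes` is a hypothetical helper name.

```python
def kv_cache_bytes(num_layers, seq_len, kv_width, dtype_bytes=2):
    # kv_width = (number of KV heads) x (head dim); equals d_model for standard MHA.
    return num_layers * 2 * seq_len * kv_width * dtype_bytes   # 2 = K and V

gpt3 = kv_cache_bytes(96, 2_048, 12_288)
llama_mha = kv_cache_bytes(80, 128_000, 8_192)     # as if all 64 heads kept their own K/V
llama_gqa = kv_cache_bytes(80, 128_000, 8 * 128)   # 8 KV heads x head dim 128

for name, b in (("GPT-3, 2K", gpt3), ("70B MHA, 128K", llama_mha), ("70B GQA, 128K", llama_gqa)):
    print(f"{name:14s} {b / 1e9:7.1f} GB")
```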
🏗 Design Challenge You're the Architect: Serve a 70B Model to 100 Users
Your company needs to serve a 70B-parameter Transformer to 100 concurrent users with 32K context windows. You have a budget for 8× H100 GPUs (80GB each = 640GB total). Design the memory layout.
Model weights
70B params × 2 bytes = 140 GB (FP16)
KV per user
80 layers × 2 × 32K × 8192 × 2B = ?
Total GPU memory
640 GB across 8 GPUs
Concurrent users
100
1. Calculate: How much KV cache memory per user? Can you fit 100 users?
2. If it doesn't fit, which techniques would you apply? (GQA, KV quantization, paged attention, shorter context?)
3. What's the trade-off of each technique?

The math: KV per user = 80 × 2 × 32,768 × 8,192 × 2 bytes ≈ 85.9 GB per user. For 100 users: ~8,590 GB. Plus 140 GB weights. Total: ~8,730 GB. You have 640 GB. You're about 13.6× over budget.

Real solutions, in order of impact:

1. GQA (Grouped-Query Attention): LLaMA-3 70B uses 8 KV heads instead of 64 query heads. KV cache shrinks by 8× → ~10.7 GB/user.

2. KV cache quantization (INT8): Another 2× reduction → ~5.4 GB/user.

3. Paged Attention (vLLM): Don't pre-allocate the full 32K per user; allocate cache pages only as tokens arrive. If typical requests use 2-4K tokens (an assumption about the workload), the effective cost drops to roughly 0.3-0.7 GB/user.

4. Tensor parallelism: Shard weights across 8 GPUs (17.5 GB each), leaving ~62 GB per GPU, or ~496 GB in total, for KV cache. At 5.4 GB/user, even a full 32K context for everyone covers only ~92 users; with paged allocation at typical lengths, 100 users fit with plenty of headroom.

The insight: no single technique solves it. You need GQA + quantization + paged allocation + parallelism together. This is why LLaMA chose GQA — it's a serving-time decision made at training time.
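The capacity plan can be tabulated in a few lines of Python. This is a back-of-the-envelope sketch: it assumes ~62 GB of KV budget per GPU after sharding and, for the paging row, a hypothetical average of 4K tokens per request.

```python
GB = 1e9
layers, ctx, d_model = 80, 32_768, 8_192
kv_budget = 8 * 62 * GB                         # ~62 GB per GPU left after sharding the weights

per_token_fp16_mha = layers * 2 * d_model * 2   # bytes per cached token (K and V, FP16)
per_user = {
    "FP16, full MHA": per_token_fp16_mha * ctx,
    "+ GQA (8 of 64 heads)": per_token_fp16_mha * ctx / 8,
    "+ INT8 KV": per_token_fp16_mha * ctx / 16,
    "+ paging, ~4K avg tokens": per_token_fp16_mha * 4_096 / 16,   # assumed workload
}
for name, b in per_user.items():
    print(f"{name:26s} {b / GB:6.2f} GB/user -> {int(kv_budget // b):5d} users fit")
```

Each row applies one more technique on top of the previous ones; only the last row clears 100 users, which is the "no single technique solves it" point in numbers.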

Check: What does the KV cache store?

Chapter 9: MoE & Scaling

As models get bigger, we face a dilemma: more parameters = better quality, but also more compute per token. Mixture of Experts (MoE) breaks this tradeoff. Instead of one giant FFN, we have many smaller "expert" FFNs. A router selects the top-k experts for each token. Only those experts run.

$$y = \sum_{i \in \mathrm{TopK}} g_i \cdot \mathrm{Expert}_i(x)$$
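A minimal NumPy sketch of top-k routing for a handful of tokens. Each "expert" here is a single random linear map standing in for a full FFN, and the gates are renormalized with a softmax over only the selected logits (one common choice; other designs take the top-k of a softmax over all experts).

```python
import numpy as np

rng = np.random.default_rng(0)
d, n_experts, top_k, n_tokens = 8, 4, 2, 5
x = rng.standard_normal((n_tokens, d))
W_router = rng.standard_normal((d, n_experts))
# Each "expert" is a single linear map standing in for a full FFN.
experts = [rng.standard_normal((d, d)) / np.sqrt(d) for _ in range(n_experts)]

logits = x @ W_router                          # [n_tokens, n_experts]
top = np.argsort(logits, axis=1)[:, -top_k:]   # indices of the top-k experts per token

y = np.zeros_like(x)
for t in range(n_tokens):
    sel = logits[t, top[t]]
    g = np.exp(sel - sel.max())
    g /= g.sum()                               # gates: softmax over the selected experts only
    for gate, e in zip(g, top[t]):
        y[t] += gate * (x[t] @ experts[e])     # only the chosen experts run

print(top)
```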
Interactive: Expert Routing

Each token is routed to 2 of 8 experts. Different tokens activate different experts. Click "Route" to see a new random routing.

Scaling Laws

Kaplan et al. (2020) discovered that model performance follows predictable power laws:

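$\text{Loss} \propto N^{-0.076}$, where N = number of parameters (in the regime where data and compute are not the bottleneck). Double the parameters and loss drops about 5% ($2^{-0.076} \approx 0.949$). This is why labs race to build bigger models: the returns are predictable.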

Efficient Attention

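| Technique | Idea | Benefit |
|---|---|---|
| Flash Attention | Tile computation to stay in SRAM; never materialize the full attention matrix | 2–4× faster, exact attention |
| Ring Attention | Distribute the sequence across GPUs, passing KV blocks around a ring | Max context grows linearly with #GPUs |
| Grouped-Query Attention | Share each KV head across several Q heads | KV memory cut by (Q heads ÷ KV heads), e.g. 8× for 64 Q / 8 KV heads |
| Sliding Window | Each token only attends to nearby tokens (a window of size w) | O(n·w) instead of O(n²) |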
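A NumPy sketch of the Grouped-Query Attention idea, assuming 8 query heads sharing 2 K/V heads (the sizes are arbitrary): only the K/V heads are stored, and each one is repeated to serve its group of query heads at attention time.

```python
import numpy as np

rng = np.random.default_rng(0)
T, n_q_heads, n_kv_heads, d_head = 5, 8, 2, 4
group = n_q_heads // n_kv_heads                   # query heads per shared K/V head

q = rng.standard_normal((n_q_heads, T, d_head))
k = rng.standard_normal((n_kv_heads, T, d_head))  # only n_kv_heads K/V heads are cached
v = rng.standard_normal((n_kv_heads, T, d_head))

# Each K/V head serves `group` consecutive query heads.
k_rep = np.repeat(k, group, axis=0)               # [n_q_heads, T, d_head]
v_rep = np.repeat(v, group, axis=0)

scores = q @ k_rep.transpose(0, 2, 1) / np.sqrt(d_head)   # [n_q_heads, T, T]
w = np.exp(scores - scores.max(axis=-1, keepdims=True))
out = (w / w.sum(axis=-1, keepdims=True)) @ v_rep          # [n_q_heads, T, d_head]

print(out.shape, "KV cache shrinks by", n_q_heads // n_kv_heads, "x")
```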
Check: What is the key advantage of Mixture of Experts?
🏗 Design Challenge You're the Architect: Design an MoE Routing Strategy
You're training a 1T-parameter MoE model with 128 experts, top-2 routing. During training, you discover that 90% of tokens route to just 3 experts while 125 experts are nearly dormant. The model performs poorly. Fix the routing.
Total experts
128 per MoE layer
Active experts/token
Top-2 (must stay at 2)
Training budget
2T tokens, can't restart
Problem
Expert collapse: 3 experts get 90% of traffic
1. What causes expert collapse? (Think about the feedback loop in softmax routing)
2. Name 3 techniques to force balanced routing without hurting quality
3. What's the trade-off of each? (capacity factor, auxiliary loss, random routing)

Root cause: Softmax routing creates a rich-get-richer feedback loop. The expert that handles slightly more tokens gets more gradient signal, becomes slightly better, then attracts even more tokens. Left unchecked, this can snowball into near-complete collapse.

Industry solutions:

1. Load balancing loss (Switch Transformer): Add an auxiliary loss $\alpha \cdot N \sum_i f_i P_i$, where N is the number of experts, $f_i$ is the fraction of tokens routed to expert i, and $P_i$ is the average router probability for expert i. The loss is smallest when routing is uniform, so it penalizes concentration. α = 0.01 is typical. Trade-off: if α is too high, the balancing term competes with the language-modeling loss and can hurt quality.

2. Expert capacity factor (GShard): Cap each expert at C × (tokens per batch ÷ number of experts) tokens (C ≈ 1.25). Overflow tokens are dropped (they pass through on the residual connection) or sent to their second choice. Trade-off: too low a C drops tokens and loses information; too high a C wastes memory and compute on padding.

3. Noisy top-k gating (Shazeer et al., 2017): Add Gaussian noise with a trainable scale to the router logits during training, and take the top-2 of (logits + noise). Trade-off: can slow early convergence, but helps prevent lock-in.

4. Expert choice routing (Zhou et al.): Flip the problem — each expert CHOOSES its top-k tokens instead of each token choosing experts. Guarantees perfect balance by construction. Trade-off: variable number of experts per token; some tokens may get 0 experts.
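Item 1's auxiliary loss can be sketched in NumPy, assuming Switch-style top-1 dispatch for $f_i$. The two synthetic router outputs compare near-uniform routing with a collapsed router that sends every token to expert 0; `load_balancing_loss` is a hypothetical helper name.

```python
import numpy as np

def load_balancing_loss(router_logits, alpha=0.01):
    # router_logits: [n_tokens, n_experts]. Switch-style: alpha * N * sum_i f_i * P_i.
    n_tokens, n_experts = router_logits.shape
    z = router_logits - router_logits.max(axis=1, keepdims=True)
    probs = np.exp(z) / np.exp(z).sum(axis=1, keepdims=True)
    chosen = probs.argmax(axis=1)                            # top-1 dispatch
    f = np.bincount(chosen, minlength=n_experts) / n_tokens  # fraction of tokens per expert
    P = probs.mean(axis=0)                                   # mean router probability per expert
    return alpha * n_experts * np.sum(f * P)

rng = np.random.default_rng(0)
balanced = rng.standard_normal((1024, 8)) * 0.01   # near-uniform routing
collapsed = balanced.copy()
collapsed[:, 0] += 5.0                             # every token prefers expert 0

loss_balanced = load_balancing_loss(balanced)
loss_collapsed = load_balancing_loss(collapsed)
print(loss_balanced, loss_collapsed)
```

With perfectly uniform routing the loss bottoms out at α, so any value well above α signals concentration.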

Mixtral-8x7B uses a simple top-2 softmax router with load balancing loss. DeepSeek-MoE adds shared experts (always active) plus routed experts. The key insight: you need an explicit mechanism to fight the rich-get-richer dynamic. Plain softmax routing, left alone, tends to collapse.

⚔ Adversarial: Your MoE model has 8 experts with top-1 routing. At inference, you observe that the router assigns identical probabilities (0.125 each) to all experts for a given token. What happens?
The router softmax output is [0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125]. Top-1 selects one expert. But which one?

Chapter 10: What Makes It Work

The Transformer's power isn't in any single component — it's in how they compose. Here are the key phenomena researchers have discovered:

The Residual Stream

Think of the residual connections as a communication bus. Each layer reads from the stream, processes information, and adds its contribution back. Information from early layers is never replaced wholesale: it flows all the way to the final layer unless a later layer actively writes to cancel it. This is fundamentally different from a pipeline where each stage replaces the previous output.

Induction Heads

One of the most remarkable discoveries: pairs of attention heads that implement in-context pattern completion. If the model sees "Harry Potter is a wizard... Harry Potter is a", the induction head copies "wizard" from the earlier occurrence. Induction heads are thought to be a key mechanism behind in-context learning, the ability to learn new tasks from examples in the prompt without any weight updates.
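The behavior an induction head implements can be written as a plain rule: find the most recent earlier occurrence of the current token and predict whatever followed it. This is a rule-based Python sketch of that input-output behavior, not of the two-head circuit inside the network; `induction_predict` is a hypothetical name.

```python
def induction_predict(tokens):
    # For each position, find the most recent earlier occurrence of the current
    # token and predict the token that followed it: [A][B] ... [A] -> [B].
    preds = []
    for i, tok in enumerate(tokens):
        pred = None
        for j in range(i - 1, -1, -1):
            if tokens[j] == tok:
                pred = tokens[j + 1]   # j < i, so j + 1 is always a valid index
                break
        preds.append(pred)
    return preds

seq = "Harry Potter is a wizard . Harry Potter is a".split()
preds = induction_predict(seq)
print(preds[-1])
```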

Interactive: Induction Head Pattern

The sequence repeats. Watch how the attention pattern forms a diagonal stripe shifted by the repeat length — that's the induction head copying from the first occurrence.


Emergent Abilities

| Phenomenon | What Happens |
|---|---|
| In-context learning | Learns new tasks from examples in the prompt |
| Chain-of-thought | Step-by-step reasoning improves accuracy |
| Few-shot generalization | Solves unseen tasks with just a few examples |
| Tool use | Learns to call APIs, write code, use calculators |
The deep mystery: We designed the Transformer for translation. We didn't design it for reasoning, coding, or creative writing. These abilities emerged from scale. Understanding why is one of the most important open questions in AI.
🔗 The Same Architecture Everywhere
One Architecture, Five Domains

Beyond language, the Transformer you just learned is the same architecture behind:

Vision (ViT)
Split an image into 16×16-pixel patches, one token per patch. Apply a standard Transformer. The attention matrix shows which patches relate. → VLM lesson
Diffusion (DiT)
Replace the U-Net with a Transformer. Noisy latent patches as tokens. Condition on the timestep and the class or text prompt (via adaptive LayerNorm or cross-attention). → Diffusion lesson
RL (Decision Transformer)
Tokenize (return-to-go, state, action) triples. Predict the next action autoregressively. Offline RL as sequence modeling. → RL lesson
Robotics (VLA)
Vision tokens + language tokens + proprioception → action tokens. The same Q/K/V mechanism deciding what to attend to. → VLA lesson

The pattern: tokenize your domain, then attend. Any sequential or set-structured problem can be cast as a Transformer problem. The architecture doesn't know or care what the tokens represent.

"Attention is all you need."
— Vaswani et al., 2017

You now understand the architecture that powers every frontier AI model. From dot products to multi-head attention, from encoder blocks to KV caches — this is the foundation of modern AI.