Workbook — Transformer Mathematics

Transformer Math

Every number a transformer engineer needs to be able to derive from scratch: parameter counts, attention arithmetic, memory budgets, KV caches, and throughput.

Prerequisites: Basic linear algebra (matrix multiply shapes) + Exponents (powers of 2). That's it.
10 chapters · 52 exercises · 5 exercise types

Chapter 0: Parameter Counting

You're reviewing a pull request that claims to add a "small" 125M parameter model. How do you verify that number? You need to count every learnable weight in the transformer — every matrix, every bias, every embedding vector. Off-by-one on a dimension means off-by-millions on the parameter count.

A standard decoder-only transformer (GPT-style) has these learnable components:

Per layer (repeated L times):
Self-attention: WQ, WK, WV, WO   each [d, d]   // 4d² params
FFN: W1 [d, 4d] + W2 [4d, d]            // 8d² params
LayerNorm × 2: γ, β each [d]         // 4d params (negligible)
Total per layer: 12d² (ignoring biases and LN)

Non-repeated:
Token embedding: [V, d]                   // Vd params
Position embedding: [T, d]                 // Td params (if learned)
Final LayerNorm: γ, β each [d]             // 2d params
LM head: often tied to embedding           // 0 if tied, Vd if not

Total ≈ 12Ld² + Vd // when embedding is tied and T << V
The 12d² rule. Each transformer layer has roughly 12d² parameters. This is the single most useful formula for quick estimation. A 96-layer model with d=12288 has approximately 96 × 12 × 12288² ≈ 174 billion parameters — that's GPT-3 scale.
Exercise 0.1: GPT-2 Small Derive

GPT-2 Small: d=768, L=12, V=50257, T=1024 (learned position embeddings, tied LM head). How many total parameters?

Compute: 12 × L × d² + V×d + T×d + final_LN. Ignore per-layer LN and biases for now.

Derivation (answer in millions of params):
Attention + FFN per layer = 12 × 768² = 12 × 589,824 = 7,077,888
All layers = 12 × 7,077,888 = 84,934,656
Token embedding = 50,257 × 768 = 38,597,376
Position embedding = 1,024 × 768 = 786,432
Final LN = 2 × 768 = 1,536
Total = 84,934,656 + 38,597,376 + 786,432 + 1,536 = 124,320,000 ≈ 124.3M

The official GPT-2 release reports "124M parameters." Adding the per-layer LayerNorm parameters (12 × 4 × 768 = 36,864) and the attention and FFN biases (12 × 6,912 = 82,944) gives 124,439,808 ≈ 124.4M. The token embedding dominates the non-layer parameters (38.6M of the 39.4M non-layer params).
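To reproduce the 124.4M figure, this minimal sketch adds the per-layer LayerNorms and the attention and FFN biases to the 12d² count. It assumes the GPT-2 layout (a bias on each of the Q, K, V, O projections and on both FFN layers, a tied LM head, learned position embeddings); `gpt2ParamCount` is a hypothetical helper.

```javascript
// Sketch: exact GPT-2-style parameter count, including what the 12d² rule ignores.
// Assumes biased linear layers, two LayerNorms per layer, a tied LM head,
// and learned position embeddings.
function gpt2ParamCount({ d, L, V, T }) {
  const ff = 4 * d;                    // standard FFN width
  const attnWeights = 4 * d * d;       // WQ, WK, WV, WO
  const attnBiases = 4 * d;            // one bias vector per projection
  const ffnWeights = 2 * d * ff;       // W1 [d, 4d] + W2 [4d, d]
  const ffnBiases = ff + d;            // b1 [4d] + b2 [d]
  const layerNorms = 2 * 2 * d;        // two LayerNorms, each with γ and β
  const perLayer =
    attnWeights + attnBiases + ffnWeights + ffnBiases + layerNorms;

  const tokenEmb = V * d;              // also serves as the tied LM head
  const posEmb = T * d;                // learned position embeddings
  const finalLN = 2 * d;               // final LayerNorm γ and β
  return L * perLayer + tokenEmb + posEmb + finalLN;
}

console.log(gpt2ParamCount({ d: 768, L: 12, V: 50257, T: 1024 }));
// 124439808
```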

Exercise 0.2: LLaMA 7B Derive

LLaMA 7B: d=4096, L=32, V=32000. Uses RoPE (no learned position embeddings), untied LM head, SwiGLU FFN with dff=11008 (not 4d). RMSNorm instead of LayerNorm (only γ, no β).

SwiGLU has 3 matrices: Wgate[d, dff], Wup[d, dff], Wdown[dff, d]. So FFN params = 3 × d × dff.

Derivation (answer in billions of params):
Attention per layer = 4 × d² = 4 × 4096² = 67,108,864
SwiGLU FFN per layer = 3 × 4096 × 11008 = 135,266,304
RMSNorm per layer = 2 × d = 8,192 // γ only, no β
Per layer total = 67,108,864 + 135,266,304 + 8,192 = 202,383,360
All 32 layers = 32 × 202,383,360 = 6,476,267,520
Token embedding = 32,000 × 4,096 = 131,072,000
LM head (untied) = 32,000 × 4,096 = 131,072,000
Final RMSNorm = 4,096
Total = 6,476,267,520 + 131,072,000 + 131,072,000 + 4,096 = 6,738,415,616 ≈ 6.74B

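The SwiGLU FFN has 3·d·dff parameters instead of 8d², which is why LLaMA uses dff = 11008 rather than the standard 4d = 16384. Setting dff = (8/3)·d ≈ 10,923 would make 3·d·dff exactly 8d²; rounding up to a multiple of 256 gives 11,008 = 43 × 256. The FFN therefore keeps essentially the same parameter count as a standard 8d² FFN while gaining the gating benefit.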
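The rounding step can be sketched as follows. Rounding up to a multiple of 256 is an assumption here that happens to reproduce 11008 for d = 4096; `swigluHidden` is a hypothetical helper.

```javascript
// Sketch: choose a SwiGLU hidden size so that 3·d·dff ≈ 8d²,
// then round up to a hardware-friendly multiple (assumed 256).
function swigluHidden(d, multipleOf = 256) {
  const target = Math.floor((8 * d) / 3);              // (8/3)·d, truncated
  return multipleOf * Math.ceil(target / multipleOf);  // round up
}

const dModel = 4096;
const dff = swigluHidden(dModel);                              // 11008
const ffnRatio = (3 * dModel * dff) / (8 * dModel * dModel);   // ≈ 1.008 of a standard FFN
console.log(dff, ffnRatio.toFixed(4));
```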

Exercise 0.3: Which Component Dominates? Trace
For a 70B parameter model with d=8192, L=80, V=32000: what fraction of total parameters are in the transformer layers (attention + FFN) versus the embeddings?
Derivation:
Layers ≈ 12 × 80 × 8192² = 64,424,509,440 ≈ 64.4B
Embeddings = 2 × 32,000 × 8,192 = 524,288,000 ≈ 0.52B
Layer fraction = 64.4 / (64.4 + 0.52) ≈ 99.2%

At 70B scale, the embedding table is less than 1% of total params. The 12Ld² approximation becomes extremely accurate. This is why "70B parameters" essentially means "70 billion weights in the attention and FFN layers."

Exercise 0.4: Design a 1B Model Design

You need approximately 1 billion parameters. Using the 12Ld² rule, which (d, L) combination gets closest to 1B?

Reasoning:

All options give ~0.8B from layers alone, but architecture design is about balance. Too few layers (L=4) means very shallow — poor reasoning capability. Too many layers with tiny d (L=256, d=512) means each layer has very little capacity and you pay huge activation memory.

The sweet spot for ~1B is d=2048, L=16-24. This matches models like TinyLLaMA (d=2048, L=22). The rule of thumb: d should be ~64-128 times L for a well-balanced architecture.

Exercise 0.5: Implement paramCount() Build

Write a function that computes the parameter count of a standard GPT-style transformer. Assume tied embeddings, standard 4d FFN, and ignore biases/LayerNorm.

Your function should return a single number (total params).
Solution:
```javascript
function paramCount(d, L, V) {
  const perLayer = 12 * d * d;  // 4d² attn + 8d² FFN
  const embedding = V * d;       // tied, so counted once
  return L * perLayer + embedding;
}
```
Exercise 0.6: Find the Bug Debug

This parameter counting function has a bug: it reports more parameters than the model actually has. Find the buggy line.

```javascript
function paramCount(d, L, V) {
  const attn = 4 * d * d;     // Q, K, V, O projections
  const ffn = 8 * d * d;      // W1 and W2
  const perLayer = attn + ffn; // 12d^2 total
  const layers = L * perLayer;
  const embed = V * d;        // token embedding
  const head = V * d;         // LM head (tied!)
  return layers + embed + head;
}
```
Explanation:

Line 7 is the bug. The comment says "tied!" but the code still adds V * d for the LM head. If the embedding is tied (shared) with the LM head, the LM head adds zero extra parameters — the same weight matrix is reused. The line should be const head = 0; or simply removed.

This is a real mistake people make: they know the weights are tied but still count them twice. It inflates the count by Vd, which for V=50257 and d=768 is 38.6M — a 31% overcount for GPT-2 Small.

Exercise 0.7: Mixtral 8x7B Derive

Mixtral 8x7B uses Mixture of Experts. Each layer has 8 expert FFNs (each d×dff×3 SwiGLU with dff=14336), but only top-2 are active per token. The attention layers are shared (not duplicated). d=4096, L=32, V=32000.

Total params = shared attention params + 8 × per-expert FFN params + router + embeddings. (The "8x7B" name does not mean 8 × 7B = 56B: only the FFNs are replicated, and only about 13B parameters are active per token.)

Derivation (answer in billions of total params):
Attention per layer = 4 × 4096² = 67,108,864
One expert FFN = 3 × 4096 × 14336 = 176,160,768
8 experts per layer = 8 × 176,160,768 = 1,409,286,144
Router per layer = 4096 × 8 = 32,768 // negligible
Per layer = 67,108,864 + 1,409,286,144 + 32,768 = 1,476,427,776
All layers = 32 × 1,476,427,776 = 47,245,688,832
Embeddings = 2 × 32,000 × 4,096 = 262,144,000
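Total = 47,245,688,832 + 262,144,000 = 47,507,832,832 ≈ 47.5B

This 4d² estimate overshoots the official Mixtral figure of ~46.7B total parameters because Mixtral's attention uses GQA with 8 KV heads (d_k = 128): WK and WV are [4096, 1024] rather than [4096, 4096], which removes 32 × 2 × 4096 × 3072 ≈ 0.81B parameters (47.5B − 0.8B ≈ 46.7B). Active parameters per token are only attention plus 2 experts: with GQA that is 41.9M + 2 × 176.2M ≈ 394.3M per layer, times 32 ≈ 12.6B, plus 0.26B of embeddings ≈ 12.9B active. This is why each token costs about as much compute as in a ~13B dense model, while the model stores ~47B weights.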
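The same bookkeeping can be written as a small function that separates stored parameters from active ones. This is a sketch, not Mixtral's reference code: it assumes GQA attention with 32 query heads and 8 KV heads (d_k = 128), counts the router as active for every token, and ignores norms; `moeParams` is a hypothetical helper.

```javascript
// Sketch: stored vs. active parameters for a top-k MoE transformer.
// Assumptions: GQA attention, SwiGLU experts, router runs for every token,
// untied embedding and LM head, norms ignored.
function moeParams({ d, L, V, nHeads, nKvHeads, dff, nExperts, topK }) {
  const dk = d / nHeads;
  const attn = 2 * d * d + 2 * d * nKvHeads * dk; // WQ, WO full; WK, WV per KV head
  const expert = 3 * d * dff;                      // gate, up, down projections
  const router = d * nExperts;                     // one logit per expert
  const embed = 2 * V * d;                         // embedding + untied LM head
  const total = L * (attn + nExperts * expert + router) + embed;
  const active = L * (attn + topK * expert + router) + embed;
  return { total, active };
}

const mixtral = moeParams({
  d: 4096, L: 32, V: 32000, nHeads: 32, nKvHeads: 8,
  dff: 14336, nExperts: 8, topK: 2,
});
console.log((mixtral.total / 1e9).toFixed(1));  // 46.7
console.log((mixtral.active / 1e9).toFixed(1)); // 12.9
```

Setting `nKvHeads: 32` (full MHA, the 4d² assumption used in the derivation) raises the total to about 47.5B, which is exactly the gap between the two estimates.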

Chapter 1: Attention Math

You're debugging a transformer and the attention outputs look wrong. Is it the scaling? The masking? A dimension mismatch? You need to be able to trace the full attention computation by hand — input to output, with exact shapes and values.

Scaled dot-product attention:
Attention(Q, K, V) = softmax(QKᵀ / √d_k) · V

Shapes (single head, single sequence):
Q: [seq, d_k]   K: [seq, d_k]   V: [seq, d_v]
QKᵀ: [seq, seq]   softmax(QKᵀ/√d_k): [seq, seq]   output: [seq, d_v]
Why √d_k? Without scaling, the variance of the dot products grows in proportion to d_k. If Q and K entries are ~N(0,1), each dot product is a sum of d_k products, giving variance d_k (a standard deviation of √d_k). Dividing by √d_k normalizes the variance back to 1, keeping softmax in a well-behaved regime. Without this, softmax saturates and gradients vanish.
Exercise 1.1: Attention FLOPs Derive

For a single attention head with seq_len=2048, d_k=d_v=128: how many FLOPs for the QKᵀ matmul alone? (Count each multiply-add as 2 FLOPs.)

Derivation (answer in millions of FLOPs):
QKᵀ: [2048, 128] × [128, 2048] → [2048, 2048]
FLOPs = 2 × 2048 × 128 × 2048 = 1,073,741,824 ≈ 1073.7M

This is a [seq, d_k] × [d_k, seq] matmul, producing a [seq, seq] attention matrix. The FLOPs scale as O(seq² × d_k), quadratic in sequence length. This quadratic cost, together with the O(seq²) memory of the attention matrix, motivates FlashAttention (which avoids materializing the matrix) and linear-attention variants (which avoid the quadratic FLOPs).

Exercise 1.2: Full Attention FLOPs Derive

Now count the total FLOPs for the FULL self-attention of one layer: Q/K/V projections + QKᵀ + softmax×V + output projection. d=4096, h=32 heads, d_k=d_v=128, seq=2048, batch=1.

Projections: 4 matmuls of [seq, d] × [d, d]. QKᵀ and attn×V: per head, then sum across heads.

Derivation (answer in billions of FLOPs):

The full self-attention FLOPs formula is 8sd² + 4s²d, where the first term covers the 4 projection matmuls and the second covers QKᵀ + attn×V.

Projections (Q, K, V, O): 4 × (2 × s × d × d) = 8sd²
= 8 × 2048 × 4096² = 274,877,906,944 ≈ 274.9B
QKᵀ all heads: h × (2 × s × d_k × s) = 2s²d
= 2 × 2048² × 4096 = 34,359,738,368 ≈ 34.4B
attn × V all heads: same = 34.4B
Total = 274.9 + 34.4 + 34.4 = 343.6 billion FLOPs

The projections dominate when s < 2d (true here: s=2048, d=4096). The QKᵀ and attn×V terms grow as s², so they overtake the projections only when 4s²d > 8sd², i.e. s > 2d. At s=8192=2d, both terms are equal (each ≈ 1.1 trillion FLOPs), making the attention core half of the total.
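As a check on the 8sd² + 4s²d formula and the s = 2d crossover, here is a minimal sketch; like the exercise, it counts a multiply-add as 2 FLOPs and leaves out the softmax and masking work. `attentionFlops` is a hypothetical helper.

```javascript
// Sketch: forward FLOPs for one self-attention layer (batch = 1),
// counting a multiply-add as 2 FLOPs and ignoring softmax itself.
function attentionFlops(s, d) {
  const projections = 4 * 2 * s * d * d; // Q, K, V, O: 8sd²
  const scores = 2 * s * s * d;          // QKᵀ summed over heads: 2s²d
  const mix = 2 * s * s * d;             // attention weights × V: 2s²d
  return { projections, core: scores + mix, total: projections + scores + mix };
}

const a = attentionFlops(2048, 4096);
console.log(a.total / 1e9);            // ≈ 343.6 (billion FLOPs)

const b = attentionFlops(8192, 4096);  // s = 2d
console.log(b.projections === b.core); // true
```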

Exercise 1.3: Attention Memory Derive

The attention weight matrix (after softmax) has shape [batch, heads, seq, seq]. For batch=1, heads=32, seq=8192, stored in FP32: how much memory does this matrix alone consume?

Derivation (answer in GB):
Elements = 1 × 32 × 8192 × 8192 = 2,147,483,648
Bytes (FP32) = 2,147,483,648 × 4 = 8,589,934,592 bytes ≈ 8.59 GB

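8.6 GB for a single attention matrix in a single layer! During training, the backward pass needs the attention weights of every layer, so a 32-layer model at this length would need about 32 × 8.6 ≈ 275 GB for these matrices alone, far beyond an 80GB A100. This is why FlashAttention is essential: it never materializes this O(s²) matrix, instead computing attention in tiles that fit in on-chip SRAM and recomputing what it needs during the backward pass.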

Exercise 1.4: Scaling Factor Derive

If Q and K vectors have entries drawn i.i.d. from N(0, 1), and d_k = 128, what is the expected variance of a single dot product q · k before scaling?

Derivation (answer is a variance):
q · k = Σ_{i=1}^{d_k} q_i k_i
E[q_i k_i] = E[q_i] E[k_i] = 0   (independent, zero mean)
Var(q_i k_i) = E[q_i²] E[k_i²] − (E[q_i k_i])² = 1 × 1 − 0 = 1
Var(q · k) = Σ_i Var(q_i k_i) = d_k = 128   (the terms are independent)

After dividing by √dk = √128 ≈ 11.3, the variance becomes 128/128 = 1. This keeps the softmax inputs in a stable range regardless of head dimension.
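The variance argument can also be checked empirically. This sketch draws Gaussian vectors with the Box-Muller transform from a small seeded PRNG (mulberry32); both are illustrative choices, and the sample variance only approximates the exact value of 128.

```javascript
// Sketch: estimate Var(q · k) for q, k ~ N(0, I) by simulation.
// mulberry32 is a small seeded PRNG so the run is reproducible.
function mulberry32(seed) {
  return function () {
    seed = (seed + 0x6d2b79f5) | 0;          // keep the state a 32-bit integer
    let t = seed;
    t = Math.imul(t ^ (t >>> 15), t | 1);
    t ^= t + Math.imul(t ^ (t >>> 7), t | 61);
    return ((t ^ (t >>> 14)) >>> 0) / 4294967296; // uniform in [0, 1)
  };
}

// Box-Muller: turn two uniforms into one standard normal sample.
function makeGaussian(rand) {
  return () => {
    const u1 = 1 - rand(); // in (0, 1], avoids log(0)
    const u2 = rand();
    return Math.sqrt(-2 * Math.log(u1)) * Math.cos(2 * Math.PI * u2);
  };
}

function dotVariance(dk, samples, seed = 42) {
  const gauss = makeGaussian(mulberry32(seed));
  let sum = 0, sumSq = 0;
  for (let n = 0; n < samples; n++) {
    let dot = 0;
    for (let i = 0; i < dk; i++) dot += gauss() * gauss();
    sum += dot;
    sumSq += dot * dot;
  }
  const mean = sum / samples;
  return sumSq / samples - mean * mean;
}

const raw = dotVariance(128, 10000); // should land near 128
const scaled = raw / 128;            // dividing q·k by √128 divides the variance by 128
console.log(raw.toFixed(1), scaled.toFixed(3));
```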

Exercise 1.5: Causal Mask Savings Trace
A causal mask zeros out the upper triangle of the attention matrix. For seq_len=1024, what fraction of the attention computation can you theoretically skip?
Derivation:
Total elements = 1024² = 1,048,576
Upper triangle (excluding diagonal) = 1024 × 1023 / 2 = 523,776
Fraction = 523,776 / 1,048,576 ≈ 49.95%

In practice, naive implementations still compute the full QKT and then mask. FlashAttention's causal variant actually skips these computations by processing only the lower-triangular tiles, giving a real ~2x speedup over non-causal attention.
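Element-level and tile-level savings differ slightly, because a tiled kernel still runs the tiles that straddle the diagonal (with masking applied inside them). The sketch below assumes square 128 × 128 tiles; real block sizes depend on the kernel and hardware, and `causalSavings` is a hypothetical helper.

```javascript
// Sketch: fraction of causal-attention work that can be skipped,
// per element and per square tile (the tile size is an assumption).
function causalSavings(seq, tile) {
  const maskedElements = (seq * (seq - 1)) / 2;  // strictly above the diagonal
  const elementFraction = maskedElements / (seq * seq);

  const nb = Math.ceil(seq / tile);               // tiles per side
  const skippedTiles = (nb * (nb - 1)) / 2;       // tiles entirely above the diagonal
  const tileFraction = skippedTiles / (nb * nb);  // diagonal tiles still run, masked
  return { elementFraction, tileFraction };
}

console.log(causalSavings(1024, 128));
// { elementFraction: 0.49951171875, tileFraction: 0.4375 }
```

Smaller tiles push the tile-level fraction closer to the element-level ~50%, at the cost of less work per tile.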

Exercise 1.6: Implement scaledDotProduct() Build

Implement scaled dot-product attention for 1D arrays representing a single query and set of keys/values. No batch, no heads — just the core math.

Solution:
```javascript
function scaledDotProduct(q, keys, values) {
  const dk = q.length;
  const scale = Math.sqrt(dk);

  // Scores: q · k / √d_k for every key
  const scores = keys.map(k =>
    k.reduce((s, ki, i) => s + ki * q[i], 0) / scale
  );

  // Softmax (numerically stable: subtract the max first)
  const maxS = Math.max(...scores);
  const exps = scores.map(s => Math.exp(s - maxS));
  const sumE = exps.reduce((a, b) => a + b, 0);
  const weights = exps.map(e => e / sumE);

  // Weighted sum of values
  const dv = values[0].length;
  const out = new Array(dv).fill(0);
  for (let i = 0; i < weights.length; i++)
    for (let j = 0; j < dv; j++)
      out[j] += weights[i] * values[i][j];
  return out;
}
```

Chapter 2: Multi-Head Attention

A single attention head can only learn one type of relationship at a time — maybe syntactic structure, or maybe semantic similarity, but not both simultaneously. Multi-head attention (MHA) runs multiple independent attention heads in parallel, each with its own Q/K/V projections, then concatenates and projects the results.

MHA variants:
Multi-Head (MHA): h heads, each with Qi,Ki,Vi ∈ [s, d/h]
→ KV params per layer = 2 × d × d = 2d² // h separate K,V projections

Multi-Query (MQA): h Q heads, 1 shared K,V head
→ KV params per layer = 2 × d × (d/h) // single K,V projection

Grouped-Query (GQA): h Q heads, g groups sharing K,V
→ KV params per layer = 2 × d × g × (d/h) = 2d² × g/h

MHA: g = h    MQA: g = 1    GQA: 1 < g < h
Why GQA matters. During autoregressive decoding, every past token's K and V vectors must be stored in the KV cache. MQA shrinks the KV cache by h× (e.g., 32×) but can hurt quality. GQA (used in LLaMA 2 70B and Mistral 7B) shrinks it by h/g: with h=32 and g=8 that is a 4× reduction (8× for LLaMA 2 70B, where h=64), typically with minimal quality loss. The choice of g trades KV-cache memory (and the bandwidth to read it) against model quality.
Exercise 2.1: MHA Parameter Breakdown Derive

For MHA with d=4096, h=32 heads: how many parameters in WQ, WK, WV, WO combined (one layer, no biases)?

Derivation (answer in millions of params):
WQ: [4096, 4096] = 16,777,216
WK: [4096, 4096] = 16,777,216
WV: [4096, 4096] = 16,777,216
WO: [4096, 4096] = 16,777,216
Total = 4 × 16,777,216 = 67,108,864 ≈ 67.1M

This is the 4d² from the parameter counting formula. Each W matrix maps from d to d. The multi-head split is a reshaping operation: WQ produces [s, d] which is reshaped to [s, h, d_k] where d_k = d/h = 128. The parameter count is the same regardless of the number of heads.

Exercise 2.2: GQA Savings Derive

LLaMA 2 70B uses GQA with h=64 query heads and g=8 KV groups. d=8192. How many parameters are in K and V projections combined (one layer)?

Each group has its own K and V projection of size [d, d_k] where d_k = d/h = 128.

Derivation (answer in millions of params):
dk = d / h = 8192 / 64 = 128
K projection: g groups × [d, dk] = 8 × [8192, 128] = 8 × 1,048,576 = 8,388,608
V projection: same = 8,388,608
Total K+V = 16,777,216 ≈ 16.8M

Compare to full MHA: K+V would be 2 × 8192 × 8192 = 134.2M. GQA with g=8 saves 8× on K/V parameters: 16.8M vs 134.2M. The Q projection and O projection remain full-size (h=64 heads), so the total attention params drop from 4d² = 268.4M to d² + d² + 2d×g×d_k = 134.2M + 16.8M = 151.0M per layer.
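The GQA counts above generalize to any number of KV groups. This is a minimal sketch with no biases; setting g = h recovers MHA and g = 1 recovers MQA. `attnParams` is a hypothetical helper.

```javascript
// Sketch: attention projection parameters for MHA / GQA / MQA (no biases).
// g = number of KV groups: g = h is MHA, g = 1 is MQA.
function attnParams(d, h, g) {
  const dk = d / h;            // per-head dimension
  const q = d * d;             // WQ: all h query heads
  const kv = 2 * d * g * dk;   // WK and WV: one [d, dk] projection per group
  const o = d * d;             // WO
  return { q, kv, o, total: q + kv + o };
}

console.log(attnParams(8192, 64, 8).total);  // 150994944 (≈ 151.0M, GQA)
console.log(attnParams(8192, 64, 64).total); // 268435456 (≈ 268.4M, MHA)
```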

Exercise 2.3: MQA vs MHA Quality Tradeoff Trace
Multi-Query Attention (g=1) reduces KV cache by h×. For a 32-head model, this is a 32× reduction. What is the primary quality risk?
Exercise 2.4: Head Dimension Sweet Spot Derive

Most transformers use d_k = 64 or d_k = 128 per head. For d=4096, how many heads do you get with d_k=64 vs d_k=128? And what is the FLOPs ratio for the QKᵀ computation (per layer, seq=2048)?

Derivation (answer: FLOPs ratio, d_k=64 / d_k=128):
d_k=64: h = 4096/64 = 64 heads
d_k=128: h = 4096/128 = 32 heads
QKᵀ FLOPs per head = 2 × s × d_k × s = 2 × s² × d_k
Total QKᵀ = h × 2 × s² × d_k = 2 × s² × (h × d_k) = 2 × s² × d

Because h × d_k = d in both configurations, the total is 2s²d = 2 × 2048² × 4096 ≈ 34.4B FLOPs either way, so the ratio is exactly 1.0. Each d_k=64 head does half the work of a d_k=128 head (a per-head ratio of 0.5), but there are twice as many heads.

The real difference is attention memory when the score matrices are materialized: d_k=64 produces 64 attention matrices of size [s, s] versus 32, doubling that activation memory. This is one reason many modern models use d_k=128: the same total FLOPs with half as many [s, s] matrices.

Exercise 2.5: Arrange the GQA Pipeline Design

Put these GQA operations in the correct order for a forward pass.

Operations: Project Q (h heads); Project K,V (g groups); Repeat K,V to match h; Compute QKᵀ/√d_k; Softmax × V; Concat heads + WO.

Explanation:

The correct order is: Project Q → Project K,V → Repeat K,V → QKᵀ/√d_k → Softmax × V → Concat + WO

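The key GQA-specific step is "Repeat K,V": each of the g KV groups is broadcast to serve h/g query heads. In LLaMA 2 70B (h=64, g=8), each KV group serves 8 query heads. The repeat can be expressed as a broadcast view (for example, PyTorch's expand) instead of a physical copy, and fused attention kernels can read the shared KV head directly; note, though, that reshaping an expanded tensor into separate per-head K,V does materialize a copy.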
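Why the repeat needs no extra storage is easiest to see as index arithmetic: each query head simply reads the K,V of group ⌊head / (h/g)⌋. This sketch assumes contiguous grouping (query heads 0-7 share group 0, and so on), which is a layout assumption; `kvGroupFor` is a hypothetical helper.

```javascript
// Sketch: which KV group does each query head read? (contiguous grouping assumed)
function kvGroupFor(qHead, h, g) {
  const headsPerGroup = h / g;             // query heads sharing one K,V pair
  return Math.floor(qHead / headsPerGroup);
}

// LLaMA 2 70B shape: h = 64 query heads, g = 8 KV groups.
const mapping = Array.from({ length: 64 }, (_, q) => kvGroupFor(q, 64, 8));
console.log(mapping.slice(0, 16));
// heads 0-7 map to group 0, heads 8-15 to group 1, and so on
```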

Chapter 3: Memory Budget

You have an 80GB A100. Can you fit a 7B model for training? What about inference? The answer depends on exactly what lives in GPU memory simultaneously. There are four categories of memory consumers, and they behave very differently between training and inference.

Training memory:
1. Model weights: P × bytes_per_param
2. Gradients: P × bytes_per_param // same shape as weights
3. Optimizer states: P × state_multiplier
• SGD (no momentum): 0 extra (updates in place)
• Adam: 2 × P × 4 bytes // m and v, typically kept in FP32
4. Activations: depends on seq_len, batch, checkpoint strategy

Rule of thumb for mixed-precision Adam training:
Weights (BF16): 2P bytes
Gradients (BF16): 2P bytes
Master weights (FP32): 4P bytes // optimizer keeps FP32 copy
Adam m (FP32): 4P bytes
Adam v (FP32): 4P bytes
Total ≈ 16P bytes (before activations)

Inference memory:
Weights only: P × bytes_per_param // + KV cache (see Ch 4)
The 16-bytes-per-param rule. Mixed-precision Adam training needs ~16 bytes per parameter. A 7B model needs 7B × 16 = 112 GB just for weights+gradients+optimizer, before any activations. This is why a single A100 (80GB) cannot train a 7B model this way without sharding that state (DeepSpeed ZeRO, FSDP) or offloading it to CPU.
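The 16-byte breakdown and the GPU counts in this chapter can be reproduced with a few lines. This sketch covers parameter state only (no activations, no framework overhead) and treats 80 GB as 80 × 10⁹ bytes; the helper names are hypothetical.

```javascript
// Sketch: parameter-state memory for mixed-precision Adam training,
// using the 16-bytes-per-parameter breakdown (activations excluded).
const BYTES_PER_PARAM = {
  weightsBf16: 2,
  gradsBf16: 2,
  masterFp32: 4,
  adamM: 4,
  adamV: 4,
};

function trainingStateBytes(params) {
  const perParam = Object.values(BYTES_PER_PARAM).reduce((a, b) => a + b, 0); // 16
  return params * perParam;
}

function minGpus(params, gpuBytes) {
  return Math.ceil(trainingStateBytes(params) / gpuBytes);
}

const A100_BYTES = 80e9; // 80 GB, treated as 80 × 10⁹ bytes
console.log(trainingStateBytes(7e9) / 1e9); // 112 (GB)
console.log(minGpus(6.74e9, A100_BYTES));   // 2
console.log(minGpus(70e9, A100_BYTES));     // 14
```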
Exercise 3.1: Training Memory for LLaMA 7B Derive

LLaMA 7B has 6.74B parameters. Using mixed-precision Adam (16 bytes/param), how much memory for weights + gradients + optimizer states? Will it fit on a single A100 80GB?

Derivation (answer in GB):
Memory = 6.74 × 10⁹ × 16 bytes = 107.84 × 10⁹ bytes = 107.8 GB

107.8 GB > 80 GB, so no, it does not fit on a single A100. And this is BEFORE activations! With activations (even with gradient checkpointing), you'd need ~130-150 GB. You need either: (a) 2+ GPUs with FSDP/ZeRO-3, or (b) quantized training (QLoRA), or (c) CPU offloading.

Exercise 3.2: Inference Memory Derive

For inference, you only need the weights (no gradients or optimizer states). How much memory does LLaMA 7B need in BF16? In INT4 (4-bit quantization)?

Derivation (answer in GB, BF16):
BF16: 6.74 × 10⁹ × 2 bytes = 13.48 GB
INT4: 6.74 × 10⁹ × 0.5 bytes = 3.37 GB

BF16 inference of a 7B model fits comfortably on a single RTX 4090 (24GB). INT4 weights (3.37 GB) can squeeze onto a laptop GPU with 4GB VRAM, though that leaves little room for the KV cache. This is why quantization is so powerful for deployment: it moves a model that needs a server GPU in BF16 onto consumer hardware.

Exercise 3.3: Activation Memory Derive

During training, activations must be stored for backprop. For a single transformer layer with d=4096, seq=2048, batch=8: estimate the activation memory for the attention output + FFN output (BF16, ignoring intermediate FFN activations).

Each layer produces: attention output [b, s, d] + FFN output [b, s, d] = 2 × b × s × d elements.

Derivation (answer in MB per layer):
Elements per layer = 2 × 8 × 2048 × 4096 = 134,217,728
Bytes (BF16) = 134,217,728 × 2 = 268,435,456 bytes = 256 MiB (≈ 268 MB)

256 MiB per layer × 32 layers = 8 GiB (≈ 8.6 GB) just for these two tensors. But the real activation memory is much worse: you also store Q, K, V tensors, attention matrices (O(s²) per head!), FFN intermediates, and normalization stats. The full activation memory for a 7B model at batch=8, seq=2048 is typically 50-80 GB without gradient checkpointing.

Exercise 3.4: Gradient Checkpointing Savings Trace
Gradient checkpointing (activation recomputation) saves activations at layer boundaries but recomputes them during backward. For a 32-layer model, what is the approximate memory-compute tradeoff?
Exercise 3.5: How Many GPUs for 70B Training? Derive

LLaMA 2 70B has ~70B parameters. With mixed-precision Adam (16 bytes/param), how much total memory for weights+gradients+optimizer? How many A100 80GB GPUs minimum (just for these, ignoring activations)?

Derivation (answer: minimum number of A100 80GB GPUs):
Memory = 70 × 10⁹ × 16 = 1,120 GB = 1.12 TB
GPUs = ⌈1120 / 80⌉ = 14 GPUs minimum

14 GPUs is only the floor for parameter state; activations need more memory on top. Real 70B training runs use far more GPUs than this minimum, mostly for data-parallel throughput rather than memory. For scale, Meta reported training the original LLaMA 65B on 2048 A100 80GB GPUs.

Exercise 3.6: LoRA Memory Savings Derive

LoRA freezes the base model and adds low-rank adapters. For a 7B model (6.74B params), you add rank-16 adapters to Q and V projections only. Each adapter pair (A, B) has shapes [d, r] and [r, d]. With d=4096, r=16, L=32 layers: how many trainable parameters? What's the ratio to full model?

Derivation (answer in millions of trainable params):
Per adapter pair: [d, r] + [r, d] = 4096 × 16 + 16 × 4096 = 131,072
Per layer (Q + V): 2 × 131,072 = 262,144
All layers: 32 × 262,144 = 8,388,608 ≈ 8.39M
Ratio: 8.39M / 6,740M = 0.12%

Only 0.12% of the base model size! Training memory for LoRA, before activations: frozen base weights in BF16 = 13.5 GB, gradients for the 8.4M adapter params in BF16 = 16.8 MB, and Adam states (m and v in FP32) for 8.4M params = 67 MB. Total ≈ 13.6 GB, which fits on a single 24GB consumer GPU with room left for activations.

Chapter 4: KV Cache

During autoregressive generation, each new token needs to attend to all previous tokens. Without caching, you'd recompute K and V projections for the entire context at every step. The KV cache stores the K and V vectors from all previous positions, so each step only computes K,V for the new token and appends to the cache.

KV cache size (MHA):
Size = 2 × L × s × h × dk × bytes_per_element
= 2 × L × s × d × bytes // since h × d_k = d

KV cache size (GQA with g groups):
Size = 2 × L × s × g × dk × bytes
= 2 × L × s × d × (g/h) × bytes // g/h reduction factor
KV cache is the bottleneck. For long-context inference, the KV cache often exceeds the model weights in memory. A 7B model in BF16 is 13.5 GB, but with standard MHA its KV cache at 128K context is about 64 GB in BF16. This is why KV cache optimization (GQA, quantization, PagedAttention, sliding window) is the central challenge of LLM serving.
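Both KV cache formulas reduce to one function with the group count g as a parameter. This is a per-sequence sketch (batch = 1, no cache quantization); `kvCacheBytes` is a hypothetical helper, and results are divided by 2³⁰ to match this chapter's binary GB.

```javascript
// Sketch: KV cache size for one sequence (batch = 1).
// g = number of KV heads/groups; the default g = h gives standard MHA.
function kvCacheBytes({ L, s, d, h, g = h, bytes = 2 }) {
  const dk = d / h;                   // per-head dimension
  return 2 * L * s * g * dk * bytes;  // 2 = one K and one V vector per head
}

const GiB = 2 ** 30;
const llama7b = { L: 32, d: 4096, h: 32 };
console.log(kvCacheBytes({ ...llama7b, s: 4096 }) / GiB);                  // 2
console.log(kvCacheBytes({ ...llama7b, s: 131072 }) / GiB);                // 64
console.log(kvCacheBytes({ L: 80, s: 4096, d: 8192, h: 64, g: 8 }) / GiB); // 1.25
console.log(kvCacheBytes({ ...llama7b, s: 1 }));                           // 524288 bytes per token
```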
Exercise 4.1: LLaMA 7B KV Cache Derive

LLaMA 7B: d=4096, L=32, h=32 (MHA, g=h). At context length 4096, in BF16: how much memory for the full KV cache?

Derivation (answer in GB):
KV cache = 2 × 32 × 4096 × 4096 × 2 bytes
= 2,147,483,648 bytes ≈ 2.0 GB

2 GB for KV cache on top of 13.5 GB for weights = 15.5 GB total. This fits on a 24GB RTX 4090 with room to spare. But watch what happens at longer contexts...

Exercise 4.2: Long Context KV Cache Derive

Same LLaMA 7B model, but now at context length 128K (131072 tokens). KV cache in BF16?

Derivation (answer in GB):
KV cache = 2 × 32 × 131072 × 4096 × 2 bytes
= 68,719,476,736 bytes ≈ 64.0 GB

64 GB for the KV cache alone — almost 5× the model weights! On a single A100 80GB: weights (13.5 GB) + KV cache (64 GB) = 77.5 GB. It barely fits, and only for a single user. This is why batch=1 long-context serving is so expensive, and why techniques like KV cache quantization (INT8 or even INT4) and sliding window attention are essential.

Exercise 4.3: GQA KV Cache Savings Derive

LLaMA 2 70B uses GQA with g=8, h=64, d=8192, L=80. KV cache at context=4096, BF16. Compare to full MHA (g=h=64).

Derivation (answer in GB, GQA with g=8):
dk = d / h = 8192 / 64 = 128
Per layer elements: 2 (K+V) × g × dk × s = 2 × 8 × 128 × 4096 = 8,388,608
All layers: 80 × 8,388,608 = 671,088,640 elements
Bytes (BF16) = 671,088,640 × 2 = 1,342,177,280 ≈ 1.25 GB
Full MHA (g=h=64): 2 × 80 × 4096 × 64 × 128 × 2 = 10.0 GB
Savings factor = 10.0 / 1.25 = 8× = h/g = 64/8

GQA with 8 groups saves 8× on KV cache memory. This matters enormously for serving: at 128K context, full MHA would need 320 GB of KV cache vs 40 GB with GQA. The parameter savings (fewer K/V projection weights) is a bonus, but the KV cache reduction is the primary motivation.

Exercise 4.4: Batch KV Cache Budget Derive

You have an A100 80GB serving LLaMA 7B in BF16 (13.5 GB weights). Each user gets up to 4096 context. KV cache per user = 2.0 GB (from Exercise 4.1). How many concurrent users can you serve?

Derivation (answer: concurrent users):
Available for KV = 80 - 13.5 = 66.5 GB
Users = ⌊66.5 / 2.0⌋ = 33 concurrent users

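33 concurrent users at 4K context. If you INT8-quantize the KV cache (1.0 GB/user), you get 66 users. If you use INT4 (0.5 GB/user), you get 133 users. This is why PagedAttention (vLLM) and KV cache quantization are among the most impactful serving optimizations: quantization shrinks each user's cache, and PagedAttention allocates cache memory in small blocks on demand instead of reserving a full max-context buffer per user, so both directly increase how many users (and tokens/second) one GPU can serve.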
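The user-count arithmetic is a single floor division. This sketch uses the same round GB figures as the exercise and ignores activation buffers and allocator overhead, which reduce the count in practice; `maxUsers` is a hypothetical helper.

```javascript
// Sketch: how many users' KV caches fit after the weights are loaded?
// Same round GB figures as the exercise; real servers also reserve
// memory for activations and allocator overhead.
function maxUsers(gpuGB, weightsGB, kvPerUserGB) {
  return Math.floor((gpuGB - weightsGB) / kvPerUserGB);
}

console.log(maxUsers(80, 13.5, 2.0)); // 33  (BF16 KV cache)
console.log(maxUsers(80, 13.5, 1.0)); // 66  (INT8 KV cache)
console.log(maxUsers(80, 13.5, 0.5)); // 133 (INT4 KV cache)
```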

Exercise 4.5: KV Cache per Token Derive

For LLaMA 7B (d=4096, L=32, MHA), how many bytes does each new token add to the KV cache in BF16?

Derivation (answer in KB per token):
Per token: 2 (K+V) × 32 (layers) × 4096 (d) × 2 (BF16) = 524,288 bytes = 512 KB

Every single generated token consumes 512 KB of GPU memory for the rest of the request. At 100 tokens/second generation speed, that's about 50 MB/s of memory allocation per user. For a chatbot generating 500-token responses, each response adds 500 × 512 KB = 256,000 KB ≈ 250 MB of KV cache, on top of the prompt's.

Chapter 5: Softmax & Layer Norm

These are the two non-linear operations that every transformer token passes through repeatedly. Getting them wrong — numerically, conceptually, or implementationally — produces subtle bugs that corrupt outputs without crashing.

Softmax:
softmax(x_i) = exp(x_i) / Σ_j exp(x_j)
Stable softmax: subtract max first
softmax(x_i) = exp(x_i − max(x)) / Σ_j exp(x_j − max(x))

LayerNorm:
LN(x) = γ ⊙ (x - μ) / √(σ² + ε) + β
μ = mean(x), σ² = var(x), γ and β are learned [d]

RMSNorm (used in LLaMA, Mistral):
RMSNorm(x) = γ ⊙ x / √(mean(x²) + ε)
// No centering (no μ), no β — just scale normalization
Exercise 5.1: Softmax Overflow Trace
Given logits [1000, 1001, 999], what happens if you compute exp(x_i) directly in FP32 (max representable ≈ 3.4×10³⁸)?
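The overflow is easy to reproduce. JavaScript numbers are FP64, whose exp overflows above x ≈ 709.78, so exp(1000) is already Infinity; `Math.fround` rounds to FP32 and exposes its much lower limit (exp overflows just above x ≈ 88.72). This sketch shows only the naive computation, not the fix.

```javascript
// Naive softmax on large logits: exp overflows to Infinity,
// and Infinity / Infinity is NaN.
const logits = [1000, 1001, 999];
const exps = logits.map((x) => Math.exp(x));   // [Infinity, Infinity, Infinity]
const total = exps.reduce((a, b) => a + b, 0); // Infinity
const naive = exps.map((e) => e / total);      // [NaN, NaN, NaN]
console.log(naive);

// FP32 overflows far earlier than FP64: round exp(x) to FP32 with Math.fround.
console.log(Math.fround(Math.exp(88))); // finite, about 1.65e38
console.log(Math.fround(Math.exp(89))); // Infinity
```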
Exercise 5.2: Implement stableSoftmax() Build
Solution:
```javascript
function stableSoftmax(logits) {
  const maxL = Math.max(...logits);
  const exps = logits.map(l => Math.exp(l - maxL));
  const sum = exps.reduce((a, b) => a + b, 0);
  return exps.map(e => e / sum);
}
```
Exercise 5.3: RMSNorm vs LayerNorm Trace
RMSNorm omits the mean subtraction and β bias compared to LayerNorm. Why does this work? And why prefer it?
Exercise 5.4: Implement rmsNorm() Build
Solution:
```javascript
function rmsNorm(x, gamma, eps) {
  const meanSq = x.reduce((s, v) => s + v * v, 0) / x.length;
  const rms = Math.sqrt(meanSq + eps);
  return x.map((v, i) => gamma[i] * v / rms);
}
```
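To compare the two normalizations directly, here is LayerNorm next to the RMSNorm solution. This sketch illustrates one concrete fact: for a zero-mean input with β = 0, the two produce identical outputs, because the variance then equals mean(x²). It does not by itself explain why RMSNorm trains well.

```javascript
// Sketch: LayerNorm next to RMSNorm, to compare them on the same input.
function layerNorm(x, gamma, beta, eps) {
  const mean = x.reduce((s, v) => s + v, 0) / x.length;
  const variance = x.reduce((s, v) => s + (v - mean) ** 2, 0) / x.length;
  const denom = Math.sqrt(variance + eps);
  return x.map((v, i) => gamma[i] * (v - mean) / denom + beta[i]);
}

function rmsNorm(x, gamma, eps) {
  const meanSq = x.reduce((s, v) => s + v * v, 0) / x.length;
  const rms = Math.sqrt(meanSq + eps);
  return x.map((v, i) => gamma[i] * v / rms);
}

// Zero-mean input, β = 0: variance equals mean(x²) and there is nothing to subtract.
const x = [1, -1, 2, -2];
const ones = [1, 1, 1, 1];
const zeros = [0, 0, 0, 0];
console.log(layerNorm(x, ones, zeros, 1e-5));
console.log(rmsNorm(x, ones, 1e-5));
```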
Exercise 5.5: Temperature Scaling Derive

Temperature T divides logits before softmax: softmax(x/T). For logits [2.0, 1.0, 0.5] with T=0.5, what is the probability of the top token?

Derivation (answer: probability):
Scaled logits = [2.0/0.5, 1.0/0.5, 0.5/0.5] = [4.0, 2.0, 1.0]
exp([4.0, 2.0, 1.0]) = [54.598, 7.389, 2.718]
sum = 64.705
P(top) = 54.598 / 64.705 = 0.844

T < 1 makes the distribution sharper (more confident). At T=1, P(top) = exp(2)/(exp(2)+exp(1)+exp(0.5)) = 7.389/11.756 ≈ 0.629. At T=0.5, it jumps to 0.844. At T→0, it approaches 1.0 (greedy). At T→∞, it approaches 1/3 (uniform).
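A short sweep makes the sharpening effect concrete. This sketch reuses the stable-softmax pattern with the logits divided by T; `softmaxT` is a hypothetical helper name.

```javascript
// Sketch: softmax with temperature (divide logits by T, then stable softmax).
function softmaxT(logits, T) {
  const scaled = logits.map((l) => l / T);
  const maxS = Math.max(...scaled);
  const exps = scaled.map((s) => Math.exp(s - maxS));
  const sum = exps.reduce((a, b) => a + b, 0);
  return exps.map((e) => e / sum);
}

const logitsT = [2.0, 1.0, 0.5];
for (const T of [0.1, 0.5, 1, 10]) {
  console.log(T, softmaxT(logitsT, T)[0].toFixed(3));
}
// T=0.5 gives 0.844 and T=1 gives 0.629; small T approaches 1, large T approaches 1/3
```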

Chapter 6: Positional Encoding

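Self-attention (without a causal mask) has no built-in notion of order: it is permutation-equivariant, so shuffling the input tokens just shuffles the outputs. Without positional information, "The cat sat on the mat" and "mat the on sat cat the" give each token the same representation, only in a different order. Positional encoding injects position information so the model can distinguish token order.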

Sinusoidal (original Transformer):
PE(pos, 2i) = sin(pos / 10000^(2i/d))
PE(pos, 2i+1) = cos(pos / 10000^(2i/d))

Learned (GPT-2): lookup table [T, d], one vector per position

RoPE (LLaMA, Mistral): rotates Q,K vectors by position-dependent angle
R(pos) rotates dimension pair i by the angle pos · θ_i, where θ_i = 10000^(−2i/d)
Key property: q_m · k_n depends on positions only through (m − n), giving relative position

ALiBi (BLOOM): no encoding; adds bias −m·|i − j| to attention scores
where m is a per-head slope
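The relative-position property of RoPE can be checked numerically. This sketch rotates adjacent dimension pairs (2i, 2i+1) by pos · θ_i as in the RoFormer formulation; some implementations pair dimensions differently (for example, the first half with the second half), which changes the layout but not the property. `rope` and the test vectors are illustrative.

```javascript
// Sketch: apply RoPE to a vector at position pos (adjacent-pair convention).
function rope(x, pos, base = 10000) {
  const d = x.length;
  const out = new Array(d);
  for (let i = 0; i < d / 2; i++) {
    const angle = pos * Math.pow(base, (-2 * i) / d); // pos · θ_i
    const c = Math.cos(angle), s = Math.sin(angle);
    const a = x[2 * i], b = x[2 * i + 1];
    out[2 * i] = a * c - b * s;      // 2-D rotation of the pair (a, b)
    out[2 * i + 1] = a * s + b * c;
  }
  return out;
}

const dot = (u, v) => u.reduce((acc, ui, j) => acc + ui * v[j], 0);

const q = [0.3, -1.2, 0.5, 0.8];
const k = [1.0, 0.4, -0.7, 0.2];
// Same offset m - n = 3 at very different absolute positions:
console.log(dot(rope(q, 5), rope(k, 2)));
console.log(dot(rope(q, 105), rope(k, 102))); // equal up to rounding
```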
Exercise 6.1: Sinusoidal Wavelengths Derive

For d=512, what is the wavelength (period) of the sinusoidal encoding at dimension i=0 (the fastest oscillating) and i=255 (the slowest)?

Wavelength = 2π × 10000^(2i/d).

Derivation (answer: wavelength at i=0, ≈ 2π):
i=0: wavelength = 2π × 10000^(0/512) = 2π × 1 ≈ 6.28
i=255: wavelength = 2π × 10000^(510/512) ≈ 2π × 9,646.6 ≈ 60,611

The fastest dimension (i=0) oscillates with period 2π ≈ 6.3 positions, so it can distinguish adjacent tokens. The slowest (i=255) has a period of ~60,600 positions, so it encodes coarse position over very long sequences. This geometric spacing of wavelengths, from 2π up to ~60,600, gives the encoding position signals at every scale.
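The two wavelengths can be computed directly; `wavelength` is a hypothetical helper implementing 2π × 10000^(2i/d).

```javascript
// Sketch: wavelength (in positions) of sinusoidal-encoding pair i.
function wavelength(i, d, base = 10000) {
  return 2 * Math.PI * Math.pow(base, (2 * i) / d);
}

console.log(wavelength(0, 512).toFixed(2));   // 6.28
console.log(wavelength(255, 512).toFixed(0)); // about 60611
```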

Exercise 6.2: RoPE Parameters Derive

RoPE has 0 learnable parameters. Learned positional embedding for T=4096, d=4096 has how many parameters?

Derivation (answer in millions of params):
Learned PE = T × d = 4096 × 4096 = 16,777,216 ≈ 16.8M

16.8M saved parameters is one reason modern models use RoPE. But the bigger reason is that RoPE can be extended to sequence lengths longer than the training context (with NTK-aware or YaRN scaling), while learned embeddings cannot: with T=4096, the 4097th position simply has no learned vector.

Exercise 6.3: RoPE Relative Property Trace
RoPE's key property is that the dot product q_mᵀ k_n depends on positions only through (m − n). What does this mean for the model's behavior?

Chapter 7: Throughput & Latency

You're serving a 70B model and the PM asks: "How many tokens per second can we generate?" The answer depends on whether you're compute-bound (prefill) or memory-bandwidth-bound (decode), your batch size, and your hardware. You need to derive these numbers, not guess.

Forward pass FLOPs (approximate, per token):
FLOPs/token ≈ 2P // 2 FLOPs per parameter (multiply + add)

Prefill (prompt processing, batch of tokens):
Compute-bound. Throughput ≈ Peak FLOP/s / (2P)

Decode (autoregressive generation, one token at a time):
Memory-bandwidth-bound (batch=1). Must read all weights from HBM.
Throughput ≈ Memory bandwidth / (P × bytes_per_param)

TTFT = Time To First Token = prefill time
ITL = Inter-Token Latency = time per decode step
TPS = Tokens Per Second = 1 / ITL (single user)
The 2P rule. A forward pass through a model with P parameters costs approximately 2P FLOPs per token. This is because each parameter participates in one multiply-add operation per token. For a 7B model: 2 × 7B = 14 GFLOP per token. On an H100 (990 TFLOP/s BF16), that's 14G / 990T = 14 microseconds per token — absurdly fast, until you realize you're memory-bandwidth-bound during decode.
Exercise 7.1: Decode Throughput (Batch=1) Derive

LLaMA 7B in BF16 on an A100 (2 TB/s HBM bandwidth). During decode at batch=1, you read all weights once per token. What is the maximum tokens/second?

Derivation (answer in tokens/sec):
Weight bytes = 6.74B × 2 = 13.48 GB
Time per token = 13.48 GB / (2 TB/s) = 6.74 ms
Tokens/sec = 1000 / 6.74 = 148 tokens/sec

148 tokens/second at batch=1. But the A100 has 312 TFLOP/s BF16 — computing 2P = 13.5 GFLOP per token takes only 43 microseconds. The GPU is 99.4% idle during decode, just waiting for memory. This is why batching is essential: at batch=8, you amortize the weight read across 8 tokens, getting ~8× throughput.
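The decode estimates in this chapter follow a simple roofline: a step takes as long as the slower of its compute and its weight reads. This sketch ignores KV-cache reads, dequantization cost, and kernel overheads, and assumes the INT4 case still computes at the BF16 peak; `decodeStep` is a hypothetical helper.

```javascript
// Sketch: roofline estimate for one decode step. Weights are read once per
// step regardless of batch size; KV-cache reads and kernel overheads are ignored.
function decodeStep({ params, bytesPerParam, batch, peakFlops, bandwidth }) {
  const computeS = (2 * params * batch) / peakFlops;    // 2 FLOPs per param per token
  const memoryS = (params * bytesPerParam) / bandwidth; // one pass over the weights
  const stepS = Math.max(computeS, memoryS);            // the slower resource wins
  return { stepS, tokensPerSec: batch / stepS, memoryBound: memoryS >= computeS };
}

const a100 = { peakFlops: 312e12, bandwidth: 2e12 };
const bf16 = decodeStep({ params: 6.74e9, bytesPerParam: 2, batch: 1, ...a100 });
const int4 = decodeStep({ params: 6.74e9, bytesPerParam: 0.5, batch: 1, ...a100 });
console.log(bf16.tokensPerSec.toFixed(0), bf16.memoryBound); // ~148, true
console.log(int4.tokensPerSec.toFixed(0));                   // ~593
```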

Exercise 7.2: Prefill Throughput Derive

Same LLaMA 7B on A100 (312 TFLOP/s BF16). Prefilling a 2048-token prompt. Assuming 50% MFU (model FLOPs utilization), how long does prefill take?

Derivation (answer in ms):
FLOPs = 2 × 6.74B × 2048 = 27.6 TFLOP
Effective throughput = 312 × 0.5 = 156 TFLOP/s
Time = 27.6 / 156 = 0.1769 s = 176.9 ms

~177 ms TTFT for a 2048-token prompt. Prefill is compute-bound because you're processing many tokens at once in large matrix multiplies that fully utilize the GPU's tensor cores. The 50% MFU accounts for kernel launch overhead, memory access patterns, and operator fusion gaps.

Exercise 7.3: Batch Size Crossover Derive

At what batch size does decode become compute-bound instead of memory-bound on the A100? The crossover happens when compute time = memory time. Use: 312 TFLOP/s compute, 2 TB/s bandwidth, 6.74B params in BF16.

Hint: at batch B, one decode step performs 2PB FLOPs but reads the weights only once (P × 2 bytes in BF16), so arithmetic intensity = 2PB / 2P = B FLOPs per byte. Decode becomes compute-bound once B reaches the ridge point, Peak FLOP/s ÷ Bandwidth.

batch size
Derivation:
At batch B: FLOPs per token step = 2P × B, Bytes read = P × 2 (weights, once)
Arithmetic Intensity = 2PB / (2P) = B
Ridge point = 312 TFLOP/s / 2 TB/s = 156
Crossover at batch = 156

Below batch=156, you're memory-bound (wasting compute). Above 156, you're compute-bound (fully utilizing tensor cores). In practice, KV cache memory limits your batch size long before 156 — at 4K context, each user needs ~2 GB KV cache, so 156 users would need 312 GB of KV cache alone!
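The crossover can also be found by brute force: compute both roofline terms for each batch size and stop at the first batch where compute time catches up with the weight-read time. This is a sketch that, like the exercise, ignores KV-cache traffic; all names are illustrative.

```python
def decode_step_times(batch, n_params, bytes_per_param, peak_flops, bandwidth):
    """Return (compute_seconds, memory_seconds) for one decode step.

    Compute grows with the batch (2P FLOPs per sequence), but the weights are
    streamed from HBM only once per step and shared by the whole batch.
    KV-cache traffic is ignored, as in the exercise.
    """
    compute_s = 2 * n_params * batch / peak_flops
    memory_s = n_params * bytes_per_param / bandwidth
    return compute_s, memory_s


P, PEAK, BW = 6.74e9, 312e12, 2e12  # LLaMA 7B on A100, BF16 weights

ridge_point = PEAK / BW  # FLOPs per byte at which compute time equals memory time


def is_compute_bound(batch):
    compute_s, memory_s = decode_step_times(batch, P, 2, PEAK, BW)
    return compute_s >= memory_s


crossover = next(b for b in range(1, 1025) if is_compute_bound(b))
print(ridge_point, crossover)
```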

Exercise 7.4: INT4 Speedup Derive

If you INT4-quantize the weights (0.5 bytes/param), what is the new decode throughput at batch=1 on A100 (2 TB/s)?

tokens/sec
Derivation:
Weight bytes = 6.74B × 0.5 = 3.37 GB
Time per token = 3.37 GB / (2 TB/s) = 1.685 ms
Tokens/sec = 1000 / 1.685 ≈ 593 tokens/sec

4× speedup over BF16! INT4 quantization gives you 4× the decode throughput because you read 4× fewer bytes from memory, and decode is memory-bandwidth-bound. This is the main reason people quantize for serving, not just to save memory — it directly increases tokens/second.

Exercise 7.5: 70B Serving: H100 Count Derive

You need to serve LLaMA 70B at 30 tokens/sec per user, supporting 100 concurrent users (100 × 30 = 3000 tokens/sec total) at 4096-token context. BF16 weights, H100 (80 GB, 3.35 TB/s bandwidth). What is the minimum number of H100s for decode, given that the weights and every user's KV cache must fit in GPU memory?

H100s
Derivation:
Weight bytes = 70B × 2 = 140 GB
Per-GPU decode speed (batch=1) = 3.35 TB/s / 140 GB = 23.9 tokens/sec per copy, if the weights fit on one GPU

Weights first: 70B doesn't fit on one H100 (140 GB > 80 GB), so tensor parallelism is required. With TP=2: 70 GB per GPU. Bandwidth per model = 2 × 3.35 = 6.7 TB/s. Batch-1 throughput = 6.7T / 140G = 47.9 tok/s. At batch 100 (still memory-bound): ~4,790 tok/s on 2 GPUs, enough for the 3000 tok/s target. Bandwidth is not the binding constraint.

KV cache next: LLaMA 70B uses GQA with g=8 KV heads, L=80, d_k=128. KV per user at 4K context = 2 × 80 × 4096 × 8 × 128 × 2 bytes ≈ 1.34 GB, so 100 users need ~134 GB of KV cache. 2 GPUs have 160 GB total - 140 GB weights = 20 GB free, enough for only ~14 users. Need more GPUs for KV cache memory.

Memory needed: 140 GB (weights) + 134 GB (KV) = 274 GB
GPUs for memory: ⌈274 / 80⌉ = 4 GPUs

Throughput recheck: at full batch, each decode step reads the weights plus every user's KV cache, ~274 GB. With 4 × 3.35 = 13.4 TB/s, one step takes ≈ 274 GB / 13.4 TB/s ≈ 20 ms, so each user gets ~49 tok/s (~4,900 tok/s total), comfortably above 30 tok/s per user. Tensor-parallel all-reduce overhead eats into this margin.

Practical answer: 4 H100s minimum, set by memory (weights + KV cache for 100 users at 4K context), not by bandwidth. Real deployments provision well above this floor for activations, prefill compute, longer contexts, communication overhead, and redundancy.
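The sizing logic can be sketched in a few lines: compute the KV cache from the model shape, size the GPU count by memory, then recheck the per-user decode rate with a memory-bound roofline. This is a sketch under stated assumptions: LLaMA 2 70B's published shape (80 layers, 8 KV heads, head dimension 128), every user at the full 4096-token context, decimal GB, and no tensor-parallel communication cost. The function names are hypothetical.

```python
import math

GB = 1e9  # decimal gigabytes, matching the text


def kv_cache_bytes_per_user(n_layers, context, n_kv_heads, head_dim, bytes_per_value=2):
    # K and V tensors (factor 2) for every layer, token, and KV head; BF16 by default.
    return 2 * n_layers * context * n_kv_heads * head_dim * bytes_per_value


def size_decode_deployment(weight_bytes, kv_per_user, users, gpu_mem, gpu_bw):
    """Smallest GPU count whose combined memory holds weights + all KV caches,
    and the per-user decode rate it delivers under a memory-bound roofline."""
    total_bytes = weight_bytes + users * kv_per_user
    n_gpus = math.ceil(total_bytes / gpu_mem)
    # Each decode step streams the weights once and every user's KV cache once.
    step_seconds = total_bytes / (n_gpus * gpu_bw)
    return n_gpus, 1 / step_seconds


kv = kv_cache_bytes_per_user(80, 4096, 8, 128)  # assumed LLaMA 2 70B shape
n_gpus, per_user_tps = size_decode_deployment(140 * GB, kv, 100, 80 * GB, 3.35e12)
print(f"KV/user {kv / GB:.2f} GB -> {n_gpus} GPUs, {per_user_tps:.0f} tok/s per user")
```

Treating memory as the first constraint and bandwidth as the recheck mirrors how the numbers fall out here: the GPU count is forced up by capacity, and the extra HBM bandwidth then comes for free.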

Chapter 8: Architecture Design

You've learned to count parameters, compute FLOPs, estimate memory, and measure throughput. Now put it all together: given a set of constraints (compute budget, memory, latency target), design the right transformer architecture.

Exercise 8.1: d/L Balance Trace
Two models with the same parameter count (~1.3B): Model A has d=2048, L=24. Model B has d=4096, L=6. Which performs better on language modeling, and why?
Exercise 8.2: Chinchilla Optimal Derive

Chinchilla's finding: optimal training uses ~20 tokens per parameter (D ≈ 20P). For a 7B model, how many training tokens and what total training FLOPs (using the 6ND rule: FLOPs = 6 × N × D)?

training FLOPs (× 10^18)
Derivation:
D = 20 × 7B = 140B tokens
FLOPs = 6 × 7B × 140B = 5,880 × 10^18 = 5.88 × 10^21

5.88 × 10^21 FLOPs (5,880 EFLOP, or about 5.9 ZFLOP). On a cluster of 1024 A100s at 50% MFU: effective throughput = 1024 × 312 × 0.5 = 159,744 TFLOP/s ≈ 1.6 × 10^17 FLOP/s. Training time = 5.88 × 10^21 / (1.597 × 10^17) ≈ 36,800 seconds ≈ 10.2 hours. In practice, communication overhead can push this to ~20-30 hours. Note: LLaMA was trained on 1-1.4T tokens (far beyond Chinchilla optimal) to produce a better-quality smaller model; over-training is the norm for inference-optimized models.
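A minimal sketch of the Chinchilla sizing and the 6ND estimate, using the 20 tokens/parameter ratio and the A100 cluster figures from the text. The usual reading of 6ND is ~2ND for the forward pass plus ~4ND for the backward pass; the helper names are illustrative.

```python
def chinchilla_tokens(n_params, tokens_per_param=20):
    """Compute-optimal token count under the ~20 tokens/parameter rule of thumb."""
    return tokens_per_param * n_params


def training_flops(n_params, n_tokens):
    # 6ND: ~2ND for the forward pass plus ~4ND for the backward pass.
    return 6 * n_params * n_tokens


def training_hours(total_flops, n_gpus, peak_flops, mfu):
    """Wall-clock hours at a sustained fraction `mfu` of the cluster's peak."""
    return total_flops / (n_gpus * peak_flops * mfu) / 3600


D = chinchilla_tokens(7e9)
C = training_flops(7e9, D)
print(f"D = {D:.3g} tokens, C = {C:.3g} FLOPs, {training_hours(C, 1024, 312e12, 0.5):.1f} h")
```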

Exercise 8.3: Edge Deployment Design Design

You need a model that runs on a phone (6 GB RAM, 2 TOPS INT4 compute). Requirements: <3 GB model size (INT4), >10 tokens/sec decode. Design the architecture: how many parameters maximum? What d and L would you choose?

Reasoning:

Max INT4 model size = 3 GB → max params = 3 GB / 0.5 bytes = 6B. But we need room for KV cache and OS overhead. A 3B model at INT4 = 1.5 GB is the sweet spot: leaves 4.5 GB headroom.

For 3B params: 12Ld² = 3B → with L=32, d = √(3B / 384) = √7.8M ≈ 2795. Round down to d=2560 (a multiple of 128 for hardware alignment); with L=32 this gives 12×32×2560² = 2.52B layer params, plus embeddings ≈ 2.6-2.7B total depending on V.

Decode throughput: memory-bound at batch=1. Bandwidth on phone ~50 GB/s (LPDDR5). 1.5 GB / 50 GB/s = 30 ms/token = 33 tok/s. Meets the 10 tok/s target with margin.
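The three numbers in this design walk-through (the width implied by the budget, the layer parameters after rounding, and the bandwidth-bound decode rate) can be reproduced with a short sketch. The ~50 GB/s LPDDR5 bandwidth is the text's assumption, embeddings are left out of the width calculation, and the helper names are illustrative.

```python
import math


def width_for_budget(param_budget, n_layers):
    # Invert 12 * L * d^2 = budget (embeddings ignored).
    return math.sqrt(param_budget / (12 * n_layers))


def layer_params(d, n_layers):
    # The 12d^2-per-layer rule.
    return 12 * n_layers * d * d


def decode_tok_per_s(model_bytes, bandwidth):
    # Batch-1 decode streams every weight byte once per token.
    return bandwidth / model_bytes


print(round(width_for_budget(3e9, 32)))        # 2795
print(f"{layer_params(2560, 32) / 1e9:.2f}B")  # 2.52B
print(round(decode_tok_per_s(1.5e9, 50e9)))    # 33 (3B params at INT4 = 1.5 GB)
```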

Exercise 8.4: FFN Ratio Derive

Standard transformers use d_ff = 4d. SwiGLU uses d_ff = 8d/3 (rounded). Show that SwiGLU with d_ff = 8d/3 has approximately the same parameter count as a standard FFN with d_ff = 4d.

Standard FFN: 2 × d × 4d = 8d². SwiGLU FFN: 3 × d × (8d/3) = ?
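To check the claim numerically, here is a small sketch that counts both FFN variants for one illustrative width (d = 4096 is an arbitrary choice, not from the exercise). The SwiGLU count uses three [d, d_ff]-shaped matrices (gate, up, and down projections), and d_ff is rounded to the nearest integer; real models may round further for alignment.

```python
def standard_ffn_params(d):
    # W1 [d, 4d] + W2 [4d, d]
    return 2 * d * (4 * d)


def swiglu_ffn_params(d, d_ff):
    # Gate [d, d_ff] + up [d, d_ff] + down [d_ff, d]
    return 3 * d * d_ff


d = 4096                  # illustrative width
d_ff = round(8 * d / 3)   # nearest integer to 8d/3
print(swiglu_ffn_params(d, d_ff) / standard_ffn_params(d))
```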
Exercise 8.5: Complete Architecture Table Derive

Fill in the missing values. Use 12Ld² + Vd for total params (tied embedding).

Model | d | L | h | d_k | Params (B)
GPT-2 XL | 1600 | 48 | 25 | 64 | ?
GPT-3 | 12288 | 96 | 96 | 128 | ?
LLaMA 13B | 5120 | 40 | 40 | 128 | ?

Compute total params for GPT-3 (V=50257):

billion params (GPT-3)
Derivation:
GPT-2 XL: 12 × 48 × 1600² + 50257 × 1600 = 1,474,560,000 + 80,411,200 = 1.55B
GPT-3: 12 × 96 × 12288² + 50257 × 12288 = 173,946,175,488 + 617,558,016 = 174.6B
LLaMA 13B: 12 × 40 × 5120² + 32000 × 5120 = 12,582,912,000 + 163,840,000 = 12.75B
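The whole table can be filled in with one function. This sketch uses the exercise's tied-embedding formula 12Ld² + Vd and the vocabulary sizes from the derivation (50257 for the GPT models, 32000 for LLaMA).

```python
MODELS = {
    # name: (d, L, V)
    "GPT-2 XL": (1600, 48, 50257),
    "GPT-3": (12288, 96, 50257),
    "LLaMA 13B": (5120, 40, 32000),
}


def tied_params(d, n_layers, vocab):
    # 12 L d^2 for the transformer blocks + V d for the tied embedding / LM head
    return 12 * n_layers * d**2 + vocab * d


for name, (d, n_layers, vocab) in MODELS.items():
    print(f"{name}: {tied_params(d, n_layers, vocab) / 1e9:.2f}B")
```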

Chapter 9: Capstone Challenge

The capstone combines everything from Chapters 0-8 into a single multi-step design problem. This is the kind of analysis you'd do at a frontier lab when planning a new model or estimating serving costs for a customer.

The scenario. You are tasked with training and deploying a new 13B parameter model for a chat product. Your cluster has 64 H100 80GB GPUs. Your serving target: 50 concurrent users at 4096 context length, 30 tok/s per user. Work through each question below to determine feasibility.
Exercise 9.1: Architecture Derive

Choose d, L, h for a ~13B model. Use SwiGLU (d_ff = 8d/3), GQA with g=8, RoPE. V=32000. Verify the param count.

Hint: LLaMA 13B uses d=5120, L=40. Verify: attention = 4d² per layer, SwiGLU = 3 × d × (8d/3) = 8d² per layer. Total per layer = 12d². But with GQA, attention K/V projections shrink.

billion params (verify ~13B)
Derivation:
d=5120, L=40, h=40, g=8, d_k=128, V=32000
Quick estimate: 12Ld² + 2Vd = 12 × 40 × 5120² + 2 × 32000 × 5120
= 12,582,912,000 + 327,680,000 = 12.91B ≈ 13B

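The 12d² shortcut does not account for GQA. With g=8 KV heads, W_K and W_V shrink from [d, d] to [d, g·d_k] = [5120, 1024], so exact attention is 2d² + 2d·g·d_k ≈ 62.9M params per layer instead of 4d² ≈ 104.9M. With d_ff = 8d/3, the SwiGLU FFN stays at ≈ 8d², so the exact total is ≈ 11.2B and the 12d² estimate runs about 15% high. The real LLaMA 13B uses full MHA (g = h = 40), which is why the shortcut lands near its ~13B.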
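To see how far the 12Ld² shortcut drifts once GQA shrinks the K/V projections, here is a sketch of an exact per-layer count. It counts only the weight matrices (no norms, no biases) and assumes an untied embedding and LM head (2Vd), as in the quick estimate; the function names are illustrative.

```python
def exact_layer_params(d, n_heads, n_kv_heads, d_ff):
    d_k = d // n_heads
    # W_Q and W_O stay [d, d]; W_K and W_V shrink to [d, n_kv_heads * d_k] under GQA.
    attn = 2 * d * d + 2 * d * (n_kv_heads * d_k)
    # SwiGLU FFN: gate, up, and down projections.
    ffn = 3 * d * d_ff
    return attn + ffn


def total_params(d, n_layers, n_heads, n_kv_heads, d_ff, vocab):
    # Untied input embedding and LM head (2Vd); norms and biases ignored.
    return n_layers * exact_layer_params(d, n_heads, n_kv_heads, d_ff) + 2 * vocab * d


d, L, h, g, V = 5120, 40, 40, 8, 32000
d_ff = 8 * d // 3  # 8d/3 truncated to an integer
shortcut = 12 * L * d**2 + 2 * V * d
exact = total_params(d, L, h, g, d_ff, V)
print(f"12Ld^2 shortcut: {shortcut / 1e9:.2f}B, exact with GQA: {exact / 1e9:.2f}B")
```

As a cross-check, assuming LLaMA 13B's published FFN width of d_ff = 13824 and full MHA (g = 40), `total_params(5120, 40, 40, 40, 13824, 32000)` comes to about 13.0B.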

Exercise 9.2: Training Compute Derive

Train on 260B tokens (20 × 13B, Chinchilla optimal). Using 6ND: how many total FLOPs? On 64 H100s at 50% MFU (990 TFLOP/s BF16 peak), how many hours?

hours
Derivation:
FLOPs = 6 × 13 × 10^9 × 260 × 10^9 = 20,280 × 10^18 ≈ 2.03 × 10^22
Cluster throughput = 64 × 990 TFLOP/s × 0.5 = 31,680 TFLOP/s
Time = 2.028 × 10^22 / (3.168 × 10^16) ≈ 640,000 s ≈ 178 hours ≈ 7.4 days

7.4 days on 64 H100s. In practice, communication overhead (allreduce, pipeline bubbles) and checkpointing add 20-40%, making it ~10 days. This is well within a typical 2-week training window. The 6ND rule makes this estimation trivial — it's the first thing to compute when scoping a training run.
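The same estimate as a reusable sketch, taking the 6ND rule, the 990 TFLOP/s H100 BF16 peak, and 50% MFU from the exercise. Communication and checkpointing overheads are not modeled, and the function name is illustrative.

```python
def training_hours(n_params, n_tokens, n_gpus, peak_flops, mfu):
    total_flops = 6 * n_params * n_tokens      # 6ND rule
    cluster_flops = n_gpus * peak_flops * mfu  # sustained FLOP/s across the cluster
    return total_flops / cluster_flops / 3600


hours = training_hours(13e9, 260e9, 64, 990e12, 0.5)
print(f"{hours:.0f} h ≈ {hours / 24:.1f} days")
```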

Exercise 9.3: Serving Memory Derive

Serving this 13B model in BF16 with 50 concurrent users at 4096 context, GQA g=8, h=40, d=5120, L=40. How much total GPU memory needed (weights + KV cache)?

KV cache per user = 2 × L × s × g × d_k × 2 bytes.

GB total
Derivation:
Weights = 13B × 2 = 26 GB
d_k = 5120 / 40 = 128
KV per user = 2 × L × s × g × d_k × 2 bytes
= 2 × 40 × 4096 × 8 × 128 × 2 = 671,088,640 bytes ≈ 0.67 GB (0.625 GiB)
50 users × 0.67 GB ≈ 33.6 GB KV cache
Total = 26 + 33.6 ≈ 59.6 GB

59.6 GB fits on a single H100 (80 GB) with about 20 GB to spare for activations and runtime overhead. GQA is doing the heavy lifting here: with full MHA (g=h=40), the KV cache would be 5× larger at ~168 GB, requiring 3 GPUs just for the cache.
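A short sketch of the serving-memory budget, comparing GQA (g=8) with full MHA (g=h=40) under the exercise's shape: BF16 weights and KV cache, every user at the full 4096-token context, decimal GB. The per-user figure of 671,088,640 bytes is 0.625 GiB, or about 0.67 GB. The helper name is illustrative.

```python
def kv_bytes_per_user(n_layers, context, n_kv_heads, head_dim, bytes_per_value=2):
    # K and V tensors (factor 2) per layer, per token, per KV head.
    return 2 * n_layers * context * n_kv_heads * head_dim * bytes_per_value


d, h, L, ctx, users = 5120, 40, 40, 4096, 50
d_k = d // h
weights = 13e9 * 2  # BF16 weights, bytes

for g in (8, h):  # GQA with 8 KV heads vs full MHA
    kv = users * kv_bytes_per_user(L, ctx, g, d_k)
    print(f"g={g}: KV {kv / 1e9:.1f} GB, total {(weights + kv) / 1e9:.1f} GB")
```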

Exercise 9.4: Decode Throughput Check Derive

Can a single H100 (3.35 TB/s, 80 GB) serve this 13B model with 50 users at 30 tok/s each (1500 tok/s total)? Check both memory and throughput.

Exercise 9.5: Cost Estimation Derive

H100 cloud cost: ~$3/hour. If 1 H100 serves 50 users generating 1M tokens/day each (50M tokens/day total), what is the cost per 1M tokens?

$/1M tokens
Derivation:
Daily GPU cost = $3/hr × 24 hrs = $72/day
Tokens per day = 50M
Cost per 1M tokens = $72 / 50 = $1.44/1M tokens

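$1.44/1M output tokens for a 13B model on H100. For comparison, API providers charge $0.15-$3.00/1M tokens for similar-sized models. The gap is margin, overhead, prefill cost, idle time, and redundancy. If INT4 weights let one GPU serve 4× more users, the cost falls to $0.36/1M tokens; but the KV cache must shrink too, since at 4096 context 200 users would need ~134 GB of BF16 KV cache, more than an H100 holds.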
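The cost arithmetic reduces to one function. This sketch assumes the ~$3/hour H100 price and 50M tokens/day from the exercise, and it ignores prefill, idle time, and redundancy; the function name is illustrative.

```python
def cost_per_million_tokens(gpu_hourly_usd, n_gpus, tokens_per_day):
    daily_cost = gpu_hourly_usd * 24 * n_gpus
    return daily_cost / (tokens_per_day / 1e6)


# 1 H100, 50 users x 1M tokens/day each
print(f"${cost_per_million_tokens(3.0, 1, 50 * 1_000_000):.2f} per 1M tokens")
```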

The proof of work. If you completed every exercise in this workbook from scratch — counted parameters, traced attention, computed memory budgets, estimated throughput, and designed architectures — you can walk into any ML systems interview and hold your own. These are the exact calculations that frontier lab engineers do daily. "What I cannot create, I do not understand."

Related Lessons

Topic | Lesson
Transformer fundamentals | Transformer — From Absolute Zero
GPT architecture | GPT — From Absolute Zero
Scaling & inference | Scaling Book Workbook
Distributed training | Distributed Training — From Absolute Zero
Inference optimization | ML Inference Engineer — Day In The Life