Workbook — Training & Backpropagation

Training Backprop

Every gradient, loss, and optimizer calculation a training engineer needs to do by hand. Chain rule, gradient shapes, cross-entropy, attention backward, Adam dynamics, mixed precision — all solvable in-browser with instant feedback.

Prerequisites: Basic calculus (derivatives, chain rule) + Matrix multiply shapes. That's it.
10 Chapters · 54 Exercises · 5 Exercise Types

Chapter 0: Chain Rule Fundamentals

You're debugging a training run. The loss isn't going down. Before you can diagnose anything, you need to trace the gradient from the loss all the way back to any weight in the network. This is the chain rule — the single most important tool in all of deep learning.

For a composition of functions f(g(x)), the derivative is:

Chain rule: ∂f/∂x = (∂f/∂g) · (∂g/∂x)

For a 2-layer network: f(x) = w2 · σ(w1 · x + b1) + b2
Let z = w1 · x + b1 // pre-activation
Let a = σ(z) // activation (e.g., ReLU)
Let y = w2 · a + b2 // output

∂y/∂w1 = (∂y/∂a) · (∂a/∂z) · (∂z/∂w1) = w2 · σ'(z) · x
Backprop is just the chain rule, applied systematically. Every "backward pass" in PyTorch computes exactly these partial derivatives, layer by layer, from the loss back to the parameters. If you can trace the chain rule by hand, you understand backprop.
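The derivation above can be checked numerically. The following is a minimal sketch, assuming a scalar network with ReLU as σ; the parameter values and the helper names (forward, gradW1, numericGradW1) are illustrative and not part of the workbook. It compares the chain-rule result w2 · σ'(z) · x against a central finite difference.

```javascript
// Scalar two-layer network: y = w2 * relu(w1 * x + b1) + b2
function forward(p, x) {
  const z = p.w1 * x + p.b1;  // pre-activation
  const a = Math.max(0, z);   // ReLU activation
  const y = p.w2 * a + p.b2;  // output
  return { z, a, y };
}

// Analytic dy/dw1 from the chain rule: w2 * relu'(z) * x
function gradW1(p, x) {
  const { z } = forward(p, x);
  const reluPrime = z > 0 ? 1 : 0;
  return p.w2 * reluPrime * x;
}

// Central finite difference for the same derivative
function numericGradW1(p, x, h = 1e-6) {
  const plus = forward({ ...p, w1: p.w1 + h }, x).y;
  const minus = forward({ ...p, w1: p.w1 - h }, x).y;
  return (plus - minus) / (2 * h);
}

// Illustrative values (assumed, not taken from any exercise)
const params = { w1: 0.5, b1: 0.25, w2: -1.5, b2: 0.1 };
const analytic = gradW1(params, 2.0);
const numeric = numericGradW1(params, 2.0);
console.log(analytic, numeric);
```

If the two numbers disagree by more than rounding noise, the hand-derived gradient is wrong somewhere along the chain.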
Exercise 0.1: Simple Chain Rule Derive

Given f(x) = (3x + 2)², compute ∂f/∂x at x = 1.

Let u = 3x + 2, then f = u². Apply the chain rule: ∂f/∂x = ∂f/∂u · ∂u/∂x.

Show derivation
u = 3x + 2,   f = u²
∂f/∂u = 2u,   ∂u/∂x = 3
∂f/∂x = 2u · 3 = 6u = 6(3x + 2)
At x = 1: ∂f/∂x = 6(3 + 2) = 6 × 5 = 30
Exercise 0.2: Gradient Through ReLU Derive

Network: y = w2 · ReLU(w1 · x + b1) + b2. Given w1=2, b1=−1, w2=3, b2=0.5, x=1. Loss L = (y − target)² with target=4. Compute ∂L/∂w1.

Forward: z = 2·1 − 1 = 1, a = ReLU(1) = 1, y = 3·1 + 0.5 = 3.5. L = (3.5 − 4)² = 0.25. Then backprop.

Show derivation
∂L/∂y = 2(y − target) = 2(3.5 − 4) = −1
∂y/∂a = w2 = 3
∂a/∂z = ReLU'(z) = 1 // since z = 1 > 0
∂z/∂w1 = x = 1
∂L/∂w1 = (−1)(3)(1)(1) = −3

The gradient is negative, so increasing w1 would decrease the loss — which makes sense since y = 3.5 is below target = 4 and increasing w1 increases the output.

Exercise 0.3: Dead ReLU Trace
Same network as 0.2, but now x = −2. Forward: z = 2(−2) − 1 = −5, a = ReLU(−5) = 0. What is ∂L/∂w1?
Show explanation
∂a/∂z = ReLU'(−5) = 0
∂L/∂w1 = ∂L/∂y · ∂y/∂a · 0 · ∂z/∂w1 = 0

This is the dead ReLU problem. When the pre-activation z is negative, the gradient is exactly zero, so the weight gets no learning signal at all. If a neuron's z is negative for every training example, it is "dead" and cannot recover through gradient descent. Activations such as LeakyReLU and GELU avoid this by letting some gradient flow for negative inputs.

Exercise 0.4: Three-Layer Chain Derive

Network: y = w3 · ReLU(w2 · ReLU(w1 · x)). Given w1=2, w2=−3, w3=4, x=1. Compute ∂y/∂w1.

Forward first: z1 = 2, a1 = 2, z2 = −6, a2 = ReLU(−6) = 0, y = 0.

Show derivation
∂y/∂a2 = w3 = 4
∂a2/∂z2 = ReLU'(−6) = 0
∂y/∂w1 = 4 × 0 × ... = 0

The dead ReLU at layer 2 kills the gradient for ALL layers before it. This is how vanishing gradients work in practice — a single dead activation can block learning for an entire subnetwork upstream.

Exercise 0.5: Implement chainGrad() Build

Write a function that computes the gradient of a composition fn(...f2(f1(x))). Each function is given as {f: x => ..., df: x => ...} (value and derivative). Return ∂(composition)/∂x.

Forward pass to get intermediate values, then backward pass multiplying derivatives.
Show solution
javascript
function chainGrad(funcs, x) {
  // Forward: store each intermediate value
  const vals = [x];
  for (const fn of funcs) {
    vals.push(fn.f(vals[vals.length - 1]));
  }
  // Backward: multiply all local derivatives
  let grad = 1;
  for (let i = funcs.length - 1; i >= 0; i--) {
    grad *= funcs[i].df(vals[i]);
  }
  return grad;
}
Exercise 0.6: Sigmoid Derivative Derive

The sigmoid function is σ(z) = 1/(1 + e^(−z)). A beautiful fact: σ'(z) = σ(z)(1 − σ(z)). Compute σ'(0).

Show derivation
σ(0) = 1/(1 + e^0) = 1/(1 + 1) = 0.5
σ'(0) = 0.5 × (1 − 0.5) = 0.5 × 0.5 = 0.25

The maximum value of σ' is 0.25, occurring at z = 0. For large |z|, σ' → 0. In a deep network with sigmoid activations, each layer's activation derivative therefore multiplies the gradient by at most 0.25 (a shrink of at least 4× per layer, before accounting for the weights). This is the vanishing gradient problem that plagued early deep networks.
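To get a feel for how quickly this compounds, here is a small sketch. The function names are illustrative, and maxSigmoidChainFactor only bounds the product of the activation derivatives; it ignores the weight terms that also appear in the chain rule.

```javascript
function sigmoid(z) {
  return 1 / (1 + Math.exp(-z));
}

// Derivative via the identity sigma'(z) = sigma(z) * (1 - sigma(z))
function sigmoidPrime(z) {
  const s = sigmoid(z);
  return s * (1 - s);
}

// Upper bound on the product of sigmoid derivatives across `depth` layers
function maxSigmoidChainFactor(depth) {
  return Math.pow(0.25, depth);
}

console.log(sigmoidPrime(0));           // the maximum of the derivative
console.log(sigmoidPrime(5));           // far from zero the derivative is tiny
console.log(maxSigmoidChainFactor(10)); // best case after 10 sigmoid layers
```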

Chapter 1: Gradient Shapes

You're reviewing a custom layer implementation and need to verify the backward pass is correct. The first thing to check: does every gradient have the same shape as the parameter it corresponds to? If ∂L/∂W has a different shape than W, something is very wrong.

The fundamental shape rule:
If W has shape [d_in, d_out], then ∂L/∂W always has shape [d_in, d_out].

Linear layer forward: Y = X · W    where X is [B, d_in], W is [d_in, d_out], Y is [B, d_out]
Linear layer backward:
∂L/∂W = X^T · ∂L/∂Y    shape: [d_in, B] · [B, d_out] = [d_in, d_out]   // same as W ✓
∂L/∂X = ∂L/∂Y · W^T    shape: [B, d_out] · [d_out, d_in] = [B, d_in]   // same as X ✓
Shape-matching is your first debugging tool. The gradient of a scalar loss with respect to any tensor always has the same shape as that tensor. If you're ever confused about a backward pass, start by writing down the shapes — they constrain everything.
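The two backward formulas can be sketched with plain nested arrays. This assumes no matrix library; matmul, transpose, shape, and linearBackward are hypothetical helper names, and the tiny matrices are made up for illustration.

```javascript
// Minimal nested-array matrix helpers
function matmul(A, B) {
  return A.map(row =>
    B[0].map((_, j) => row.reduce((s, a, k) => s + a * B[k][j], 0)));
}
function transpose(A) {
  return A[0].map((_, j) => A.map(row => row[j]));
}
function shape(A) {
  return [A.length, A[0].length];
}

// Backward pass of Y = X · W given the upstream gradient dY = dL/dY
function linearBackward(X, W, dY) {
  const dW = matmul(transpose(X), dY); // [d_in, B] · [B, d_out]
  const dX = matmul(dY, transpose(W)); // [B, d_out] · [d_out, d_in]
  return { dW, dX };
}

const X = [[1, 2, 3], [4, 5, 6]];   // [B=2, d_in=3]
const W = [[1, 0], [0, 1], [1, 1]]; // [d_in=3, d_out=2]
const dY = [[1, 1], [1, 1]];        // [B=2, d_out=2]
const { dW, dX } = linearBackward(X, W, dY);
console.log(shape(dW), shape(W)); // shapes should match
console.log(shape(dX), shape(X)); // shapes should match
```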
Exercise 1.1: Linear Layer Gradients Derive

A linear layer has W with shape [768, 3072] and input X with shape [32, 768]. How many elements are in ∂L/∂W?

Show derivation
∂L/∂W has the same shape as W: [768, 3072]
Elements = 768 × 3072 = 2,359,296

The batch dimension (32) does not appear in the gradient shape: it is summed out during the X^T · dY multiply. Every sample in the batch contributes to the same gradient tensor.

Exercise 1.2: Trace Through a MLP Trace
Pipeline: Linear(768, 3072) → ReLU → Linear(3072, 768) → Softmax → Cross-Entropy. The upstream gradient ∂L/∂logits has shape [B, 768]. What is the shape of ∂L/∂W1 (the first linear layer's weight)?
Show explanation

The gradient of L with respect to W1 always has the same shape as W1: [768, 3072]. It doesn't matter how many layers come after — the chain rule multiplies everything together and the batch dimension gets summed out.

Exercise 1.3: Total Gradient Elements Derive

A transformer layer has: WQ, WK, WV, WO each [d, d] and FFN W1 [d, 4d], W2 [4d, d] where d = 768. How many total gradient elements must be computed for one backward pass through this layer?

elements (millions)
Show derivation
Attention: 4 × 768 × 768 = 4 × 589,824 = 2,359,296
FFN W1: 768 × 3072 = 2,359,296
FFN W2: 3072 × 768 = 2,359,296
Total = 2,359,296 + 2,359,296 + 2,359,296 = 7,077,888 ≈ 7.08M

Each gradient tensor has the same shape as its parameter, so the backward pass computes exactly as many gradient values as there are parameters. For a 7B model, that's 7 billion gradient elements per training step.

Exercise 1.4: Bias Gradient Shape Trace
A bias vector b has shape [dout]. In the forward pass Y = XW + b, the bias is broadcast across the batch. How is ∂L/∂b computed from ∂L/∂Y (shape [B, dout])?
Show explanation

Since b is broadcast to every row of Y, the gradient flows back from every row. The chain rule requires summing over the batch: ∂L/∂b = ∑i ∂L/∂Yi. This is a sum, not a mean — the loss is already averaged over the batch if needed.
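As a quick sketch of that column-wise sum (biasGrad is a hypothetical helper name and the numbers are made up):

```javascript
// dY has shape [B, d_out]; the bias gradient sums each column over the batch
function biasGrad(dY) {
  return dY[0].map((_, j) => dY.reduce((s, row) => s + row[j], 0));
}

const dYExample = [[1, 2], [3, 4], [5, 6]]; // B = 3, d_out = 2
console.log(biasGrad(dYExample));           // one value per output feature
```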

Exercise 1.5: Backward Memory Derive

During backprop through Linear(768, 3072) with batch=32 and sequence length 512, we need the stored activation X of shape [B×seq, d_in] to compute ∂L/∂W = X^T · dY. How much memory (in MB) does storing X in FP32 require?

MB
Show derivation
X shape = [32 × 512, 768] = [16384, 768]
Elements = 16384 × 768 = 12,582,912
Bytes (FP32) = 12,582,912 × 4 = 50,331,648 bytes = 48 MiB (≈ 50.3 MB in decimal units)

This is why activation memory dominates training — every linear layer must save its input for the backward pass. A transformer with 32 layers saves these activations 32 times. Activation checkpointing trades compute for memory by recomputing these instead of storing them.

Chapter 2: Cross-Entropy Loss

The loss function is the single number that tells the optimizer "how wrong you are." For language models, the loss is almost always cross-entropy: it measures how surprised the model is by the correct next token.

Cross-entropy loss: L = −∑_i y_i · log(p_i)

For one-hot labels (only one correct class c):
L = −log(p_c) // just the negative log of the correct class probability

Softmax: p_i = e^(z_i) / ∑_j e^(z_j)   // converts logits z to probabilities p

Gradient of CE + Softmax: ∂L/∂z_i = p_i − y_i // beautifully simple!
The gradient is just (predicted − actual). If the model predicts p_c = 0.9 for the correct class, the gradient for that logit is 0.9 − 1 = −0.1 (small push to increase it). If p_c = 0.01, the gradient is 0.01 − 1 = −0.99 (massive push). The loss automatically provides stronger signal when the model is more wrong.
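The "predicted minus actual" formula is easy to verify against a finite difference. The sketch below uses made-up logits and hypothetical helper names (softmax, ceLoss, ceGrad, numericGrad); it is a sanity check, not a production implementation.

```javascript
function softmax(logits) {
  const maxZ = Math.max(...logits); // subtract the max for numerical stability
  const exps = logits.map(z => Math.exp(z - maxZ));
  const sum = exps.reduce((a, b) => a + b, 0);
  return exps.map(e => e / sum);
}

function ceLoss(logits, target) {
  return -Math.log(softmax(logits)[target]);
}

// Analytic gradient: p_i - y_i with a one-hot y
function ceGrad(logits, target) {
  return softmax(logits).map((p, i) => p - (i === target ? 1 : 0));
}

// Central finite difference on each logit
function numericGrad(logits, target, h = 1e-6) {
  return logits.map((_, i) => {
    const plus = logits.slice();
    const minus = logits.slice();
    plus[i] += h;
    minus[i] -= h;
    return (ceLoss(plus, target) - ceLoss(minus, target)) / (2 * h);
  });
}

const demoLogits = [0.5, -1.2, 0.3]; // illustrative values
const demoTarget = 2;
console.log(ceGrad(demoLogits, demoTarget));
console.log(numericGrad(demoLogits, demoTarget));
```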
Exercise 2.1: Compute Cross-Entropy Derive

Logits: z = [2.0, 1.0, 0.1]. True class = 0 (first position). Compute the cross-entropy loss. (First apply softmax, then −log(p0).)

Softmax: p_i = e^(z_i) / (e^2.0 + e^1.0 + e^0.1). Use e^2 ≈ 7.389, e^1 ≈ 2.718, e^0.1 ≈ 1.105.

Show derivation
e^2.0 = 7.389,   e^1.0 = 2.718,   e^0.1 = 1.105
Sum = 7.389 + 2.718 + 1.105 = 11.212
p0 = 7.389 / 11.212 = 0.659
L = −log(0.659) = 0.417
Exercise 2.2: Loss When Confident vs. Wrong Trace
The model assigns p_correct = 0.9 (loss = −log(0.9) ≈ 0.105). Now p_correct drops to 0.1 (loss = −log(0.1) ≈ 2.303). By what factor did the loss increase?
Show derivation
Ratio = 2.303 / 0.105 ≈ 21.9×

Cross-entropy is logarithmic — being very wrong is dramatically more expensive than being slightly wrong. This is by design: the −log function has an asymptote at 0, so assigning near-zero probability to the correct class is catastrophically penalized. This is also why label smoothing helps — it prevents the model from trying to push pc all the way to 1.0.

Exercise 2.3: Implement crossEntropy() Build

Write a function that computes cross-entropy loss given logits (raw scores) and the correct class index. Apply softmax internally. Use Math.log and Math.exp.

Softmax first (subtract max for numerical stability), then -log(p_target).
Show solution
javascript
function crossEntropy(logits, target) {
  const maxZ = Math.max(...logits);
  const exps = logits.map(z => Math.exp(z - maxZ));
  const sumExp = exps.reduce((a, b) => a + b);
  const logProb = (logits[target] - maxZ) - Math.log(sumExp);
  return -logProb;
}
Exercise 2.4: Find the Bug Debug

This cross-entropy implementation produces NaN for large logits. Click the buggy line.

function crossEntropy(logits, target) {
  const exps = logits.map(z => Math.exp(z));
  const sumExp = exps.reduce((a, b) => a + b);
  const probs = exps.map(e => e / sumExp);
  return -Math.log(probs[target]);
}
Show explanation

Line 2 is the bug. Computing Math.exp(z) directly on raw logits overflows to Infinity when z > ~709 (the FP64 limit), and the later division Infinity / Infinity yields NaN. The fix is to subtract the max logit first: logits.map(z => Math.exp(z - maxZ)). This max-subtraction step, the core of the log-sum-exp trick, produces the same softmax probabilities but avoids overflow. Production softmax implementations rely on it.

Exercise 2.5: Perplexity Derive

Perplexity = e^L where L is the average cross-entropy loss per token. If a model achieves an average loss of 2.5 nats per token on a test set, what is the perplexity? (Use e^2.5 ≈ 12.18.)

Show derivation
Perplexity = e^L = e^2.5 ≈ 12.18

Perplexity has an intuitive interpretation: the model is "as confused as if it were choosing uniformly between ~12 options at each token." A perfect model has perplexity 1 (loss = 0). GPT-3 reported a zero-shot perplexity of about 20 on Penn Treebank. A model that always guesses uniformly over V=50k tokens has perplexity 50,000.
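A short sketch of the definition, assuming per-token losses are given in nats. The function name perplexity and the 8-way uniform example are illustrative.

```javascript
// Perplexity = exp(mean per-token cross-entropy), with losses in nats
function perplexity(tokenLosses) {
  const mean = tokenLosses.reduce((a, b) => a + b, 0) / tokenLosses.length;
  return Math.exp(mean);
}

// A model that guesses uniformly over 8 options pays log(8) nats per token,
// so its perplexity should come out at about 8.
const uniformLosses = Array(5).fill(Math.log(8));
console.log(perplexity(uniformLosses));
console.log(perplexity([0, 0, 0])); // a perfect model
```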

Exercise 2.6: Softmax + CE Gradient Derive

Logits z = [2.0, 1.0, 0.1], true class = 0. From Exercise 2.1, softmax gives p = [0.659, 0.242, 0.099]. Compute all three components of the gradient ∂L/∂z.

Use the formula: ∂L/∂zi = pi − yi where y = [1, 0, 0].

∂L/∂z0 (the correct class)
Show derivation
∂L/∂z0 = p0 − y0 = 0.659 − 1 = −0.341
∂L/∂z1 = p1 − y1 = 0.242 − 0 = +0.242
∂L/∂z2 = p2 − y2 = 0.099 − 0 = +0.099

The gradient pushes the correct class logit up (negative gradient = increase) and all other logits down (positive gradient = decrease). The components sum to zero: −0.341 + 0.242 + 0.099 = 0. This is a consequence of the probabilities summing to 1.

Chapter 3: Backprop Through Attention

Attention is the most memory-hungry part of a transformer at long sequence lengths, in both the forward and the backward pass, because its cost grows quadratically with sequence length. If you understand the gradient flow through attention, you understand why FlashAttention matters.

Attention forward:
S = Q · K^T / √d_k    // [seq, seq] scores
P = softmax(S)          // [seq, seq] attention weights
O = P · V                // [seq, d_k] output

Attention backward (given dO = ∂L/∂O):
dV = P^T · dO              // [seq, d_k]
dP = dO · V^T              // [seq, seq]
dS = dP ⊙ P − P ⊙ (dP ⊙ P).sum(dim=−1, keepdim=True) // softmax backward
dQ = dS · K / √d_k       // [seq, d_k]
dK = dS^T · Q / √d_k     // [seq, d_k]
Why FlashAttention matters for backward. The naive backward pass must store the full [seq, seq] attention matrix P. For seq=8192 and 32 heads, that's 32 × 8192² × 4 bytes = 8 GB just for one layer's attention weights. FlashAttention recomputes P from Q, K in blocks during the backward pass, trading a small amount of extra compute for massive memory savings.
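The naive (non-Flash) backward equations can be written out directly for a single head. This is a minimal sketch on nested arrays with no masking, batching, or blocking; all helper names and the 2×2 inputs are assumptions made for illustration.

```javascript
function matmul(A, B) {
  return A.map(row =>
    B[0].map((_, j) => row.reduce((s, a, k) => s + a * B[k][j], 0)));
}
function transpose(A) {
  return A[0].map((_, j) => A.map(row => row[j]));
}
function scale(A, c) {
  return A.map(row => row.map(v => v * c));
}
function softmaxRows(S) {
  return S.map(row => {
    const m = Math.max(...row);
    const e = row.map(v => Math.exp(v - m));
    const sum = e.reduce((a, b) => a + b, 0);
    return e.map(v => v / sum);
  });
}

function attentionForward(Q, K, V) {
  const dk = Q[0].length;
  const S = scale(matmul(Q, transpose(K)), 1 / Math.sqrt(dk));
  const P = softmaxRows(S); // stored for the backward pass in the naive version
  const O = matmul(P, V);
  return { P, O };
}

function attentionBackward(Q, K, V, P, dO) {
  const inv = 1 / Math.sqrt(Q[0].length);
  const dV = matmul(transpose(P), dO);
  const dP = matmul(dO, transpose(V));
  // Row-wise softmax backward: dS = P * (dP - sum_j(dP_j * P_j))
  const dS = P.map((pRow, i) => {
    const dot = pRow.reduce((s, p, j) => s + p * dP[i][j], 0);
    return pRow.map((p, j) => p * (dP[i][j] - dot));
  });
  const dQ = scale(matmul(dS, K), inv);
  const dK = scale(matmul(transpose(dS), Q), inv);
  return { dQ, dK, dV };
}

// Illustrative inputs: seq = 2, d_k = 2
const Q = [[0.1, 0.2], [0.3, -0.4]];
const K = [[0.5, -0.1], [0.2, 0.7]];
const V = [[1.0, 2.0], [-1.0, 0.5]];
const dO = [[1.0, 0.5], [-1.0, 2.0]];
const { P, O } = attentionForward(Q, K, V);
const grads = attentionBackward(Q, K, V, P, dO);
console.log(grads.dQ, grads.dK, grads.dV);
```

Note that P is a full [seq, seq] matrix here, which is exactly the storage that FlashAttention avoids.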
Exercise 3.1: Attention Score Matrix Size Derive

For a single attention head with seq=2048, d_k=64, how many elements are in the attention score matrix S = QK^T/√d_k?

elements (millions)
Show derivation
S has shape [seq, seq] = [2048, 2048]
Elements = 2048 × 2048 = 4,194,304 ≈ 4.19M

This is O(seq²) — the quadratic bottleneck of attention. With 32 heads, the total is 32 × 4.19M = 134M elements. At FP32 that's 536 MB just for one layer's attention scores.

Exercise 3.2: Attention Backward FLOPs Derive

For one head with d_k=64, seq=512: the backward pass through attention involves these matmuls: P^T·dO [seq,seq]×[seq,d_k], dO·V^T [seq,d_k]×[d_k,seq], dS·K [seq,seq]×[seq,d_k], dS^T·Q [seq,seq]×[seq,d_k]. Each matmul of [m,n]×[n,p] costs 2mnp FLOPs. What are the total FLOPs for these 4 matmuls?

million FLOPs
Show derivation
P^T·dO: 2 × 512 × 512 × 64 = 33,554,432
dO·V^T: 2 × 512 × 64 × 512 = 33,554,432
dS·K: 2 × 512 × 512 × 64 = 33,554,432
dS^T·Q: 2 × 512 × 512 × 64 = 33,554,432
Total = 4 × 33,554,432 = 134,217,728 ≈ 134.2M FLOPs

The forward pass has only 2 matmuls (QK^T and PV), so the backward pass does roughly 2× the FLOPs of the forward. This "backward is ~2× forward" rule holds approximately across the entire transformer.
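The same bookkeeping can be sketched as a helper for other sizes. It counts only the matmuls listed above (softmax and scaling are ignored), and the function names and the seq=128, d_k=32 example are illustrative.

```javascript
// FLOPs for an [m, n] x [n, p] matmul: one multiply and one add per term
function matmulFlops(m, n, p) {
  return 2 * m * n * p;
}

// Matmul-only FLOP counts for one attention head
function attentionMatmulFlops(seq, dk) {
  const forward =
    matmulFlops(seq, dk, seq) +   // Q · K^T
    matmulFlops(seq, seq, dk);    // P · V
  const backward =
    matmulFlops(seq, seq, dk) +   // P^T · dO
    matmulFlops(seq, dk, seq) +   // dO · V^T
    matmulFlops(seq, seq, dk) +   // dS · K
    matmulFlops(seq, seq, dk);    // dS^T · Q
  return { forward, backward, ratio: backward / forward };
}

const flopsDemo = attentionMatmulFlops(128, 32);
console.log(flopsDemo);
```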

Exercise 3.3: Softmax Backward Trace
The softmax backward pass has an unusual property compared to ReLU or linear layers. Which statement is true?
Show explanation

The Jacobian of softmax is ∂p_i/∂z_j = p_i(δ_ij − p_j), which can be written as diag(p) − pp^T. This is a dense matrix: changing any one logit z_j affects all probabilities. This is fundamentally different from ReLU (diagonal Jacobian) and is one reason softmax is relatively expensive in the backward pass.

Exercise 3.4: FlashAttention Memory Savings Derive

Standard attention stores P (shape [h, seq, seq]) for the backward pass. FlashAttention only stores O (shape [h, seq, dk]) and the per-row logsumexp (shape [h, seq]). For h=32, seq=4096, dk=128, compute the memory ratio (standard / flash) in FP16.

× ratio
Show derivation
Standard: h × seq × seq × 2 bytes = 32 × 4096 × 4096 × 2 = 1,073,741,824 bytes = 1024 MB
Flash O: h × seq × dk × 2 = 32 × 4096 × 128 × 2 = 33,554,432 bytes ≈ 32 MB
Flash LSE: h × seq × 4 = 32 × 4096 × 4 = 524,288 bytes ≈ 0.5 MB // FP32 for stability
Ratio = 1024 / (32 + 0.5) ≈ 31.5×
Exercise 3.5: Arrange the Attention Backward Design

Put these attention backward steps in the correct order, starting from the upstream gradient dO.

dV = P^T·dO
dP = dO·V^T
dS = softmax_bwd(dP, P)
dQ = dS·K / √d_k
dK = dS^T·Q / √d_k
Show explanation

The order is: dV (only needs P and dO, both available), dP (needs dO and V), dS (needs dP and P from softmax backward), dQ (needs dS and K), dK (needs dS and Q). Note that dV and dP can be computed in parallel since they have no dependency on each other, but dS must come before dQ and dK.

Chapter 4: Learning Rate & Schedules

The learning rate is the most important hyperparameter in all of deep learning. Too high and training diverges. Too low and you waste compute. The schedule — how the LR changes over time — is just as critical as the peak value.

Linear warmup (steps 0 to W):
lr(t) = lr_max × t / W

Cosine annealing (steps W to T):
lr(t) = lr_min + 0.5 × (lr_max − lr_min) × (1 + cos(π × (t − W) / (T − W)))

Typical values for LLM training:
lr_max = 3 × 10^−4,   lr_min = 3 × 10^−5 (10% of max),   warmup = 2000 steps
Warmup prevents early instability. At step 0, the Adam optimizer's momentum terms (m and v) are initialized to zero. Without warmup, the first update uses a very noisy gradient estimate with no momentum smoothing. The warmup period lets Adam build up reliable statistics before taking full-sized steps.
Exercise 4.1: Linear Warmup Derive

Warmup from 0 to lr_max = 3×10^−4 over 2000 steps. What is the learning rate at step 500?

(e.g. 0.000075)
Show derivation
lr(500) = 3 × 10^−4 × 500 / 2000 = 3 × 10^−4 × 0.25 = 7.5 × 10^−5
Exercise 4.2: Cosine Annealing Derive

Cosine schedule: lr_max=3×10^−4, lr_min=3×10^−5, warmup=2000, total=10000. What is the LR at step 6000? (cos(π/2) = 0, cos(π) = −1)

(e.g. 0.000165)
Show derivation
Progress = (6000 − 2000) / (10000 − 2000) = 4000/8000 = 0.5
cos(π × 0.5) = cos(π/2) = 0
lr = 3×10^−5 + 0.5 × (3×10^−4 − 3×10^−5) × (1 + 0)
lr = 3×10^−5 + 0.5 × 2.7×10^−4 × 1 = 3×10^−5 + 1.35×10^−4 = 1.65 × 10^−4

At the midpoint of the cosine decay, the LR is exactly halfway between lr_max and lr_min. The cosine shape decays slowly at first, then fast in the middle, then slowly again at the end, matching the intuition that late training needs fine-grained updates.

Exercise 4.3: Implement cosineSchedule() Build

Write a function that returns the learning rate at a given step, with linear warmup followed by cosine decay to lrmin.

During warmup: linear ramp. After warmup: cosine annealing. Use Math.cos and Math.PI.
Show solution
javascript
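function cosineSchedule(step, warmup, total, lrMax, lrMin) {
  // Linear warmup from 0 to lrMax
  if (step < warmup) {
    return lrMax * step / warmup;
  }
  // Fraction of the decay phase completed, clamped so the LR stays at lrMin past `total`
  const progress = Math.min(1, (step - warmup) / (total - warmup));
  // Cosine decay from lrMax (progress = 0) to lrMin (progress = 1)
  return lrMin + 0.5 * (lrMax - lrMin) * (1 + Math.cos(Math.PI * progress));
}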
Exercise 4.4: LR Divergence Trace
You start training with lr = 1×10^−3 (no warmup). After 100 steps, the loss suddenly shoots to NaN. You restart with lr = 1×10^−4 and training converges smoothly. What most likely happened?
Show explanation

When the learning rate is too high, each gradient step overshoots the minimum. The loss increases, which produces larger gradients, which produce even larger updates: a positive feedback loop that quickly reaches Infinity, then NaN. This is training divergence. The usual fix is to reduce the LR, add warmup, or clip gradients. The critical LR depends on the model architecture, batch size, and loss surface curvature.

Exercise 4.5: Warmup Steps for LLaMA 3 Derive

LLaMA 3 70B trains for 15T tokens with batch size 4M tokens (4,194,304). Warmup is 2000 steps. How many tokens does the model see during warmup?

billion tokens
Show derivation
Tokens during warmup = 2000 steps × 4,194,304 tokens/step = 8,388,608,000 ≈ 8.39B tokens

That's ~0.056% of the total 15T tokens spent just ramping up the learning rate. It seems like a waste, but without warmup the first few hundred steps would produce garbage updates that can permanently damage the model. The investment pays for itself.

Chapter 5: Adam Optimizer

Adam is the default optimizer for virtually all modern deep learning. It combines two ideas: momentum (running average of gradients to smooth noisy updates) and adaptive learning rates (scaling each parameter's update by the inverse of its recent gradient magnitude).

Adam update rule at step t:
m_t = β1 · m_(t−1) + (1 − β1) · g_t     // first moment (momentum)
v_t = β2 · v_(t−1) + (1 − β2) · g_t²   // second moment (squared grad)
m̂_t = m_t / (1 − β1^t)                // bias correction
v̂_t = v_t / (1 − β2^t)                // bias correction
θ_t = θ_(t−1) − lr · m̂_t / (√v̂_t + ε)

Defaults: β1 = 0.9, β2 = 0.999, ε = 10^−8
Why bias correction matters. At step 1, m_1 = 0.1·g_1 (since m_0 = 0). This underestimates the true mean gradient by 10×. The correction divides by (1 − 0.9^1) = 0.1, recovering the true scale. Without it, the first-moment estimate would be severely underscaled for the first ~10 steps.
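To see how long the correction stays significant, the factor 1 / (1 − β^t) can be tabulated for both moments. This is a small sketch; biasCorrection is a hypothetical helper name and the chosen steps are arbitrary.

```javascript
// Multiplicative bias-correction factor applied to a moment estimate at step t
function biasCorrection(beta, t) {
  return 1 / (1 - Math.pow(beta, t));
}

for (const t of [1, 10, 100, 1000]) {
  console.log(
    "step", t,
    "first moment:", biasCorrection(0.9, t).toFixed(3),
    "second moment:", biasCorrection(0.999, t).toFixed(3)
  );
}
```

The first-moment factor decays to about 1 within a few dozen steps, while the second-moment factor (β2 = 0.999) takes on the order of thousands of steps.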
Exercise 5.1: First Moment at Step 3 Derive

β1 = 0.9, m0 = 0. Gradients: g1 = 4.0, g2 = −2.0, g3 = 6.0. Compute the bias-corrected first moment m̂3.

Show derivation
m1 = 0.9 × 0 + 0.1 × 4.0 = 0.4
m2 = 0.9 × 0.4 + 0.1 × (−2.0) = 0.36 − 0.2 = 0.16
m3 = 0.9 × 0.16 + 0.1 × 6.0 = 0.144 + 0.6 = 0.744
m̂3 = 0.744 / (1 − 0.9³) = 0.744 / (1 − 0.729) = 0.744 / 0.271 = 2.745

Note: The uncorrected m3 = 0.744 would severely underestimate the true running mean. The correction factor 1/(1 − 0.729) ≈ 3.69 compensates for the zero initialization. After ~30 steps, β1^t becomes negligible and the correction is effectively 1.

Exercise 5.2: Second Moment Derive

Same gradients: g1=4.0, g2=−2.0, g3=6.0. β2=0.999, v0=0. Compute v̂3 (bias-corrected second moment at step 3).

Show derivation
v1 = 0.999 × 0 + 0.001 × 16.0 = 0.016
v2 = 0.999 × 0.016 + 0.001 × 4.0 = 0.015984 + 0.004 = 0.019984
v3 = 0.999 × 0.019984 + 0.001 × 36.0 = 0.019964 + 0.036 = 0.055964
v̂3 = 0.055964 / (1 − 0.999³) = 0.055964 / (1 − 0.997003) = 0.055964 / 0.002997 = 18.67

The bias correction for v is even more dramatic: the correction factor is ~333× at step 3. This is because β2 = 0.999 means v accumulates very slowly, so it takes many more steps to reach its steady state. This large correction is why the first few Adam steps behave differently — and why warmup is important.

Exercise 5.3: Adam Memory Cost Derive

Adam stores m and v (both same shape as parameters) in FP32, even when parameters are in FP16. For a 7B parameter model, how many GB do the Adam optimizer states consume?

GB
Show derivation
Adam states = m (FP32) + v (FP32) = 2 × 7B × 4 bytes = 56 GB

This is why training is so memory-hungry. The model weights in FP16 are only 14 GB, but the optimizer states add 56 GB — 4× more than the model itself! This is the main reason ZeRO-style optimizer sharding exists: split these 56 GB across GPUs instead of duplicating on each one.

Exercise 5.4: Find the Bug in Adam Debug

This Adam implementation gives wrong updates for the first ~100 steps but converges to correct behavior later. Click the buggy line.

function adamStep(theta, grad, m, v, t, lr, b1, b2, eps) {
  m = b1 * m + (1 - b1) * grad;
  v = b2 * v + (1 - b2) * grad * grad;
  const mHat = m;
  const vHat = v;
  theta = theta - lr * mHat / (Math.sqrt(vHat) + eps);
  return { theta, m, v };
}
Show explanation

Line 4 (and line 5) is the bug. The bias correction is missing: mHat = m should be mHat = m / (1 - b1**t), and similarly for vHat. Because m and v start at 0, both are underestimated early on, but by different amounts. With the default β values, at step 1 m is 10× too small while √v is about 32× too small, so the uncorrected update is roughly 3× too large. After many steps, β^t → 0 and the correction becomes negligible, which is why it "converges to correct behavior later."

Exercise 5.5: Implement adamStep() Build

Write a single Adam update step. Return the updated {theta, m, v}.

Remember bias correction: mHat = m / (1 - b1^t), vHat = v / (1 - b2^t).
Show solution
javascript
function adamStep(theta, grad, m, v, t, lr, b1, b2, eps) {
  m = b1 * m + (1 - b1) * grad;
  v = b2 * v + (1 - b2) * grad * grad;
  const mHat = m / (1 - Math.pow(b1, t));
  const vHat = v / (1 - Math.pow(b2, t));
  theta = theta - lr * mHat / (Math.sqrt(vHat) + eps);
  return { theta, m, v };
}

Chapter 6: Batch Normalization

Batch normalization was one of the biggest training breakthroughs: it stabilizes training by normalizing activations across the batch. But it has quirks — it behaves differently at train time vs. inference, and it fails at small batch sizes. Understanding the math explains all these behaviors.

BatchNorm forward (training):
μ_B = (1/B) ∑_i x_i                // batch mean
σ_B² = (1/B) ∑_i (x_i − μ_B)²    // batch variance
x̂_i = (x_i − μ_B) / √(σ_B² + ε) // normalize
y_i = γ · x̂_i + β                // scale and shift (learnable)
γ and β let the network undo normalization. If the optimal activation distribution is not zero-mean unit-variance, the network can learn γ and β to recover any mean and variance. BatchNorm constrains the activations initially (helping training) but doesn't permanently limit expressivity.
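The four forward equations map directly onto code. The sketch below handles a single feature across a batch of scalars in training mode only; the function name batchNorm, the default eps, and the sample values are assumptions for illustration.

```javascript
// Training-mode BatchNorm for one feature: x is an array of B scalars
function batchNorm(x, gamma, beta, eps = 1e-5) {
  const B = x.length;
  const mean = x.reduce((a, b) => a + b, 0) / B;
  // Biased variance (divide by B), as in the formula above
  const variance = x.reduce((s, v) => s + (v - mean) ** 2, 0) / B;
  const xHat = x.map(v => (v - mean) / Math.sqrt(variance + eps));
  const y = xHat.map(v => gamma * v + beta);
  return { mean, variance, y };
}

// With gamma = 1, beta = 0, eps = 0 the output has zero mean and unit variance
const bnOut = batchNorm([1, 3, 5, 7], 1.0, 0.0, 0);
console.log(bnOut.mean, bnOut.variance, bnOut.y);
```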
Exercise 6.1: Compute BatchNorm by Hand Derive

Batch of 4 values: x = [2, 1, 3, 4]. γ = 2, β = 1, ε = 0. Compute the BN output for x1 = 2.

First compute μB and σB², then normalize x1, then scale and shift.

y1
Show derivation
μB = (2 + 1 + 3 + 4) / 4 = 10/4 = 2.5
σB² = ((2−2.5)² + (1−2.5)² + (3−2.5)² + (4−2.5)²) / 4
σB² = (0.25 + 2.25 + 0.25 + 2.25) / 4 = 5.0/4 = 1.25
x̂1 = (2 − 2.5) / √1.25 = −0.5 / 1.118 = −0.4472
y1 = 2 × (−0.4472) + 1 = −0.8944 + 1 = 0.106
Exercise 6.2: Why BN Fails at Batch=1 Trace
If the batch has only 1 sample, what happens during BatchNorm training?
Show explanation

With batch=1, μB = x, so (x − μB) = 0 for every element. The variance is also 0. After normalization, x̂ = 0/√ε ≈ 0 everywhere — the layer outputs γ·0 + β = β regardless of input. This is why LayerNorm was invented for sequence models: it normalizes across features (the d dimension) instead of across the batch, so it works with any batch size, even 1.

Exercise 6.3: Running Stats Derive

During training, BN tracks running statistics: μrun = momentum × μrun + (1 − momentum) × μB with momentum=0.1. If μrun=0 initially and the first 3 batch means are 2.5, 3.0, 2.0, what is μrun after 3 batches?

Show derivation
μrun,1 = 0.1 × 0 + 0.9 × 2.5 = 2.25
μrun,2 = 0.1 × 2.25 + 0.9 × 3.0 = 0.225 + 2.7 = 2.925
μrun,3 = 0.1 × 2.925 + 0.9 × 2.0 = 0.2925 + 1.8 = 2.093

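Note: In this exercise's formula, momentum weights the old value, so momentum=0.1 means "keep 10% of old, take 90% of new." This is why the running mean tracks recent batch means so closely. PyTorch's BatchNorm uses the opposite convention, running = (1 − momentum) × running + momentum × batch, so its default momentum=0.1 keeps 90% of the old value. At eval time, the accumulated μrun replaces the batch mean to ensure deterministic outputs.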

Exercise 6.4: LayerNorm vs BatchNorm Trace
Transformers use LayerNorm instead of BatchNorm. Which is the key practical reason?
Show explanation

BatchNorm computes statistics across the batch for each feature. This creates three problems for transformers: (1) different sequence lengths in a batch have different valid positions, (2) autoregressive generation uses batch=1, and (3) batch statistics couple samples in unwanted ways. LayerNorm sidesteps all of this by normalizing across the d-dimensional feature vector for each token independently.

Exercise 6.5: RMSNorm Simplification Derive

RMSNorm (used in LLaMA) skips the mean subtraction: x̂i = xi / RMS(x) where RMS(x) = √((1/d)∑xi²). For x = [3, 4] and γ = [1, 1], compute the RMSNorm output for the first element.

Show derivation
RMS = √((3² + 4²) / 2) = √(25/2) = √12.5 = 3.536
x̂1 = 3 / 3.536 = 0.8485
y1 = 1 × 0.8485 = 0.849

RMSNorm saves the mean computation (one fewer reduction) and has fewer parameters (no β, only γ). The original RMSNorm paper (Zhang and Sennrich, 2019) reported quality comparable to LayerNorm at lower running time, and LLaMA uses it in place of LayerNorm.
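A minimal sketch of the RMSNorm formula follows. The eps parameter is an assumption added for numerical safety (it is not in the formula above), and the function name and input vector are illustrative.

```javascript
// RMSNorm over one feature vector x, with a per-feature gain gamma
function rmsNorm(x, gamma, eps = 0) {
  const meanSquare = x.reduce((s, v) => s + v * v, 0) / x.length;
  const rms = Math.sqrt(meanSquare + eps);
  return x.map((v, i) => gamma[i] * v / rms);
}

const rmsIn = [1, 2, 2];
const rmsOut = rmsNorm(rmsIn, [1, 1, 1]);
// Scaling the input by a constant leaves the output unchanged
const rmsOutScaled = rmsNorm(rmsIn.map(v => 10 * v), [1, 1, 1]);
console.log(rmsOut, rmsOutScaled);
```

Unlike LayerNorm, the mean is not subtracted, so the output is scale-invariant but not shift-invariant.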

Chapter 7: Grad Accumulation & Mixed Precision

Real-world training rarely fits the ideal "one big batch" into GPU memory. Instead, we use two tricks: gradient accumulation (split the batch across multiple forward/backward passes, sum gradients) and mixed precision (use FP16 for speed but FP32 where precision matters).

Effective batch size:
Beff = micro_batch × accum_steps × num_GPUs

Mixed precision training:
1. Keep master weights in FP32
2. Cast to FP16 for forward + backward // 2× memory savings on activations
3. Compute gradients in FP16
4. Scale loss by S (e.g. 2^16) before backward // prevent gradient underflow
5. Unscale gradients (divide by S) after backward
6. Update master weights in FP32 with Adam
Why loss scaling? FP16 can only represent values down to ~6×10^−8. Many gradients in deep networks are smaller than this and would become zero in FP16 ("underflow"). Multiplying the loss by 2^16 = 65536 shifts all gradients up by that factor, keeping them in FP16's representable range. We divide by 65536 before the optimizer step to compensate.
Exercise 7.1: Effective Batch Size Derive

micro_batch = 4, accumulation_steps = 8, num_GPUs = 4. What is the effective batch size?

Show derivation
Beff = 4 × 8 × 4 = 128

Each GPU processes a micro-batch of 4 samples, accumulates gradients over 8 steps (32 effective per GPU), and 4 GPUs contribute in parallel. The optimizer step happens once per 128 samples — identical to training with batch size 128 on a single GPU (if it fit in memory).
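The equivalence can be checked on a toy problem. The sketch below assumes a one-parameter least-squares model with a mean-reduced loss and equal-size micro-batches, so each micro-batch gradient is divided by the number of accumulation steps; the function names and data are made up for illustration.

```javascript
// Gradient of mean((w * x - y)^2) with respect to w
function gradMSE(w, xs, ys) {
  return xs.reduce((s, x, i) => s + 2 * (w * x - ys[i]) * x, 0) / xs.length;
}

// Accumulate gradients over equal-size micro-batches
function accumulatedGrad(w, xs, ys, microBatch) {
  const steps = xs.length / microBatch;
  let grad = 0;
  for (let s = 0; s < steps; s++) {
    const lo = s * microBatch;
    const g = gradMSE(w, xs.slice(lo, lo + microBatch), ys.slice(lo, lo + microBatch));
    grad += g / steps; // scale so the total matches the full-batch mean
  }
  return grad;
}

const xsDemo = [1, 2, 3, 4, 5, 6, 7, 8];
const ysDemo = [2, 4, 6, 8, 10, 12, 14, 16];
const fullGrad = gradMSE(0.5, xsDemo, ysDemo);
const accumGrad = accumulatedGrad(0.5, xsDemo, ysDemo, 2);
console.log(fullGrad, accumGrad);
```

Forgetting the division by the number of accumulation steps is a common bug: the gradient then comes out larger by that factor, which acts like an unintended learning-rate increase.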

Exercise 7.2: FP16 Activation Savings Derive

A transformer layer stores activations for the backward pass. For batch=32, seq=2048, d=4096, the activations for one layer include: input (X), QKV projections, attention scores (softmax output), FFN intermediate, and residual connections — approximately 4 × [B×seq, d] tensors in the main path. How much memory is saved by storing these in FP16 instead of FP32?

GB saved
Show derivation
One tensor = 32 × 2048 × 4096 = 268,435,456 elements
4 tensors = 4 × 268,435,456 = 1,073,741,824 elements
FP32 cost = 1,073,741,824 × 4 bytes = 4.0 GB
FP16 cost = 1,073,741,824 × 2 bytes = 2.0 GB
Savings = 4.0 − 2.0 = 2.0 GB per layer

For a 32-layer model, that's 64 GB of activation memory saved. This is the main practical benefit of mixed precision — not the 2× faster matmuls on tensor cores (though that helps too), but the halved activation memory that lets you train larger models or use larger batch sizes.

Exercise 7.3: Loss Scaling Trace
A gradient value is 1×10^−9 in FP32. FP16's smallest positive subnormal is ~6×10^−8 (its smallest positive normal is ~6.1×10^−5). With loss scaling factor S = 2^16 = 65536, is this gradient preserved?
Show explanation

Multiplying the loss by S scales all gradients by S (by the chain rule). The original 1×10^−9 becomes 6.55×10^−5, which FP16 can represent. After the backward pass, we divide all gradients by S before the optimizer step, recovering the true gradient value in FP32.
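JavaScript has no native FP16 type, so the sketch below emulates FP16 rounding with a simplified helper. roundToFp16 is a hypothetical function that models the 10-bit mantissa and the subnormal grid but uses round-half-up rather than IEEE round-to-nearest-even; treat it as an approximation for building intuition only.

```javascript
// Simplified emulation of rounding a double to the nearest FP16 value
function roundToFp16(x) {
  if (x === 0 || !Number.isFinite(x)) return x;
  const ax = Math.abs(x);
  // Exponents below -14 fall on the fixed subnormal grid
  const exp = Math.max(Math.floor(Math.log2(ax)), -14);
  const ulp = Math.pow(2, exp - 10); // spacing of FP16 values at this exponent
  const r = Math.round(ax / ulp) * ulp;
  return Math.sign(x) * (r > 65504 ? Infinity : r);
}

const S = 65536;          // loss scale, 2^16
const trueGrad = 1e-9;    // gradient value from the exercise

const unscaledFp16 = roundToFp16(trueGrad);    // stored directly in FP16
const scaledFp16 = roundToFp16(trueGrad * S);  // stored after loss scaling
const recovered = scaledFp16 / S;              // unscale in higher precision

console.log(unscaledFp16, scaledFp16, recovered);
```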

Exercise 7.4: Training Memory Budget Derive

For a 7B parameter model with mixed precision (FP16 weights + FP32 master weights + Adam states), compute the total memory for the weights and optimizer states. Parameters in FP16: 2 bytes each. Master copy in FP32: 4 bytes. Adam m and v in FP32: 4 bytes each.

GB
Show derivation
FP16 params = 7B × 2 = 14 GB
FP32 master = 7B × 4 = 28 GB
Adam m (FP32) = 7B × 4 = 28 GB
Adam v (FP32) = 7B × 4 = 28 GB
Total = 14 + 28 + 28 + 28 = 98 GB

That's 2 + 4 + 4 + 4 = 14 bytes per parameter, or 7× the FP16 model size. This is the "14 bytes per parameter" rule for mixed-precision Adam training (16 bytes if FP16 gradients are counted too). A single 80GB A100 can't hold these 98 GB for a 7B model, so you need at least 2 GPUs with ZeRO-style sharding.
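The same budget can be sketched as a small calculator for other model sizes. It counts only the four items in the derivation (gradients and activations are excluded) and uses decimal GB; the function and field names are illustrative.

```javascript
// Bytes per parameter for mixed-precision Adam, following the derivation above
function mixedPrecisionAdamMemory(numParams) {
  const bytesPerParam = {
    fp16Weights: 2,
    fp32Master: 4,
    adamM: 4,
    adamV: 4,
  };
  const total = Object.values(bytesPerParam).reduce((a, b) => a + b, 0);
  return {
    bytesPerParam: total,
    totalGB: (numParams * total) / 1e9, // decimal gigabytes
  };
}

console.log(mixedPrecisionAdamMemory(7e9));
console.log(mixedPrecisionAdamMemory(70e9));
```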

Exercise 7.5: Tokens Per Step Derive

You're training with micro_batch=2, seq=4096, accum_steps=16, num_GPUs=8. How many tokens does the model process per optimizer step?

tokens (thousands)
Show derivation
Effective batch = 2 × 16 × 8 = 256 sequences
Tokens per step = 256 × 4096 = 1,048,576 ≈ 1049K tokens

~1M tokens per step is a common target for LLM training. LLaMA 3 used ~4M tokens per step. The tokens-per-step determines how many optimizer steps you need: to train on 1T tokens at 1M tokens/step, you need 1,000,000 steps.

Chapter 8: Training Diagnostics

The loss curve and gradient statistics are your window into what's happening inside the model during training. Learning to read these signals is like a doctor reading vital signs — it tells you whether the patient is healthy, sick, or about to crash.

Gradient L2 norm:
||g||₂ = √(∑ᵢ gᵢ²)    // across ALL parameters

Gradient clipping (max_norm method):
if ||g||₂ > max_norm:
   g = g × max_norm / ||g||₂    // scale down to max_norm, preserve direction
Loss curve shapes tell a story. Smooth decrease: healthy training. Oscillation: LR too high. Plateau: LR too low or a model bottleneck. Sudden spike then recovery: a bad batch or a gradient explosion that clipping contained. Spike to NaN: unrecoverable divergence. Monotonic increase: LR far too high, or a sign-flip bug in the loss function.
Exercise 8.1: Gradient Norm Derive

A tiny model has 5 parameters with gradients: g = [0.3, −0.4, 0.5, −0.1, 0.2]. Compute the L2 gradient norm.

Show derivation
||g||₂ = √(0.09 + 0.16 + 0.25 + 0.01 + 0.04) = √0.55 ≈ 0.7416
Exercise 8.2: Gradient Clipping Derive

The total gradient norm is 5.0 but max_norm = 1.0. What scaling factor is applied to all gradients? After clipping, what is the new gradient for a parameter whose original gradient was 2.0?

clipped gradient value
Show derivation
scale = max_norm / ||g||₂ = 1.0 / 5.0 = 0.2
clipped_grad = 2.0 × 0.2 = 0.4

Gradient clipping preserves the direction of the gradient vector but caps its magnitude. Every parameter's gradient is scaled by the same factor, so the relative magnitudes are preserved. This is much better than per-parameter clipping, which distorts the direction.
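
To see the difference in direction concretely, the sketch below compares norm clipping with per-parameter (value) clipping on a two-element gradient and measures the cosine similarity with the original. The helper names and the example gradient are illustrative assumptions.

```javascript
function l2Norm(v) {
  return Math.sqrt(v.reduce((s, x) => s + x * x, 0));
}

// Norm clipping: one shared scale factor for every entry.
function clipByNorm(grads, maxNorm) {
  const norm = l2Norm(grads);
  const scale = norm > maxNorm ? maxNorm / norm : 1;
  return grads.map(g => g * scale);
}

// Per-parameter clipping: clamp each entry independently to [-limit, limit].
function clipByValue(grads, limit) {
  return grads.map(g => Math.max(-limit, Math.min(limit, g)));
}

// Cosine similarity: 1 means the two vectors point in the same direction.
function cosine(a, b) {
  const dot = a.reduce((s, x, i) => s + x * b[i], 0);
  return dot / (l2Norm(a) * l2Norm(b));
}

const grads = [4.0, 0.3];
console.log(cosine(grads, clipByNorm(grads, 1.0)));  // ~1: direction unchanged
console.log(cosine(grads, clipByValue(grads, 1.0))); // below 1: direction distorted
```

Value clipping shrinks only the large entry, so the small entry gains relative weight and the update direction rotates.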

Exercise 8.3: Implement clipGradNorm() Build

Write a function that clips a gradient array by its L2 norm. Return the clipped gradient array.

Compute L2 norm, then if norm > maxNorm, scale all grads by maxNorm/norm.
Show solution
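javascript
function clipGradNorm(grads, maxNorm) {
  // Global L2 norm over every gradient entry.
  const norm = Math.sqrt(grads.reduce((s, g) => s + g * g, 0));
  // Already within budget: return an unmodified copy.
  if (norm <= maxNorm) return grads.slice();
  // Otherwise shrink every entry by the same factor, preserving direction.
  const scale = maxNorm / norm;
  return grads.map(g => g * scale);
}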
Exercise 8.4: Diagnose the Loss Curve Trace
Your loss curve drops rapidly for the first 1000 steps, then completely flattens. The gradient norm is tiny (~1e-7). What is the most likely diagnosis?
Show explanation

A flat loss with tiny gradient norm means the optimizer has no useful signal to follow. This is different from convergence (where loss would be near the theoretical minimum). Common causes: (1) the LR schedule decayed too fast and hit ~0 too early, (2) the LR is so small that updates are negligible, (3) the model landed in a saddle point. The fix is usually to restart with a higher minimum LR or longer warmup.

Exercise 8.5: Gradient Explosion Detection Derive

You log the gradient norm every 100 steps: [1.2, 1.5, 1.8, 2.4, 4.1, 12.3, 89.7, NaN]. With max_norm=1.0 gradient clipping, at which step range did clipping first activate?

step
Show explanation

Clipping activates whenever the norm exceeds max_norm = 1.0. The very first logged norm is 1.2 > 1.0, so clipping was already active at step 100 (and likely from the start). The escalating norms (1.2 → 89.7 → NaN) suggest the model is diverging despite clipping — the gradients are growing faster than clipping can contain. This needs a lower learning rate, not just more aggressive clipping.
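
The same check can be automated over a logged series. A minimal sketch, assuming the first entry was logged at step 100 and the values are pre-clipping norms; `analyzeNorms` is a hypothetical helper.

```javascript
// Find the first logged step where clipping would activate and where NaN appears.
// Assumes norms[i] was logged at step (i + 1) * logEvery.
function analyzeNorms(norms, maxNorm, logEvery) {
  const toStep = i => (i === -1 ? null : (i + 1) * logEvery);
  // NaN > maxNorm is false, so NaN entries never count as "clipped" here.
  const firstClipped = norms.findIndex(n => n > maxNorm);
  const firstNaN = norms.findIndex(n => Number.isNaN(n));
  return { firstClippedStep: toStep(firstClipped), firstNaNStep: toStep(firstNaN) };
}

const logged = [1.2, 1.5, 1.8, 2.4, 4.1, 12.3, 89.7, NaN];
console.log(analyzeNorms(logged, 1.0, 100));
// { firstClippedStep: 100, firstNaNStep: 800 }
```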

Chapter 9: Capstone: Train a Tiny Model

Time to put it all together. You're planning a training run for a 125M parameter model on 10 billion tokens. You need to compute everything: FLOPs, training time, memory, learning rate, cost. These are the exact calculations that ML engineers at frontier labs do before every run.

Key formulas for training estimation:
Total FLOPs ≈ 6 × N × D // N = params, D = tokens (Kaplan et al.)
GPU throughput ≈ 150 TFLOP/s per A100 (effective, mixed precision) // ~50% of the 312 TFLOP/s peak
Training time = Total FLOPs / (num_GPUs × throughput)
Memory per GPU ≈ 14 × N bytes (mixed precision Adam) + activation memory
The 6ND rule. A single training step does ~6 FLOPs per parameter per token: 2 for the forward pass, 4 for the backward pass (2 for computing gradients of activations, 2 for computing gradients of weights). This 6ND approximation is accurate to within ~10% for standard transformers.
Exercise 9.1: Total FLOPs Derive

N = 125M parameters, D = 10B tokens. Total training FLOPs?

× 10¹⁸ FLOPs (exaFLOPs)
Show derivation
FLOPs = 6 × (125 × 10⁶) × (10 × 10⁹) = 6 × 1.25 × 10¹⁸ = 7.5 × 10¹⁸
Exercise 9.2: Training Time Derive

8 A100 GPUs, each at 150 TFLOP/s effective throughput. How many hours to complete 7.5 × 10¹⁸ FLOPs?

hours
Show derivation
Total throughput = 8 × 150 × 10¹² = 1.2 × 10¹⁵ FLOP/s
Time = 7.5 × 10¹⁸ / (1.2 × 10¹⁵) = 6250 seconds
Hours = 6250 / 3600 ≈ 1.74 hours

A 125M model on 10B tokens is quite small: under 2 hours on 8 A100s. For comparison, the same formula gives 6 × 70B × 15T = 6.3 × 10²⁴ FLOPs for a 70B model on 15T tokens, or roughly 12 million A100-hours at this throughput.
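
The two estimation formulas fit in a few lines. This sketch assumes the 6ND approximation and perfect scaling across GPUs; the function names are illustrative.

```javascript
// Total training FLOPs under the 6ND approximation.
function totalFlops(numParams, numTokens) {
  return 6 * numParams * numTokens;
}

// Wall-clock hours, assuming throughput scales linearly with GPU count.
function trainingHours(flops, numGpus, flopsPerGpuPerSec) {
  const seconds = flops / (numGpus * flopsPerGpuPerSec);
  return seconds / 3600;
}

const flops = totalFlops(125e6, 10e9);
console.log(flops);                            // 7.5e18
console.log(trainingHours(flops, 8, 150e12));  // ~1.736
```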

Exercise 9.3: Memory Budget Derive

125M params, mixed-precision Adam (14 bytes/param for model + optimizer). Does this fit on a single 80GB A100 (leaving 40GB for activations)?

GB (model + optimizer only)
Show derivation
Memory = 14 × 125 × 10⁶ = 1.75 × 10⁹ bytes = 1.75 GB

Easily fits on a single GPU! About 78 GB remains, far more than the 40 GB set aside for activations, so you can use very large batch sizes or long sequences. In fact, the bottleneck for a 125M model is usually compute (GPU utilization), not memory. You'd use 8 GPUs for speed, not memory.

Exercise 9.4: Optimal Learning Rate Trace
As a rough heuristic, the optimal LR scales approximately as N^−0.5 (decreasing with model size). If a 1B model uses lr = 3×10⁻⁴, what's a reasonable estimate for our 125M model?
Show derivation
lr ∝ N^−0.5
lr(125M) / lr(1B) = (125M / 1B)^−0.5 = (0.125)^−0.5 = 1/√0.125 = √8 ≈ 2.83
lr(125M) ≈ 2.83 × 3×10⁻⁴ ≈ 8.5 × 10⁻⁴

Smaller models generally tolerate higher learning rates. For reference, GPT-3 Small (125M) was trained with lr = 6×10⁻⁴, and modern practice with cosine schedules suggests ~6-10×10⁻⁴ works well for this scale.
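
The power-law rescaling from the exercise can be written as a one-line helper. This is a sketch of the heuristic only: `scaleLearningRate` is a hypothetical name, and the default exponent of −0.5 is the assumption stated in the exercise, not a tuned value.

```javascript
// Rescale a known-good learning rate to a different model size,
// assuming lr is proportional to N^exponent.
function scaleLearningRate(baseLr, baseParams, targetParams, exponent = -0.5) {
  return baseLr * Math.pow(targetParams / baseParams, exponent);
}

console.log(scaleLearningRate(3e-4, 1e9, 125e6)); // ~8.49e-4
console.log(scaleLearningRate(3e-4, 1e9, 1e9));   // unchanged at the same size
```

Treat the result as a starting point for a small LR sweep rather than a final setting.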

Exercise 9.5: Training Steps Derive

10B tokens, batch size = 512 sequences of length 1024. How many optimizer steps?

steps
Show derivation
Tokens per step = 512 × 1024 = 524,288
Steps = 10,000,000,000 / 524,288 = 19,073

~19K steps is a short training run. With 2000 warmup steps, that means warmup is ~10% of total training — typical for this scale. The cosine schedule decays over the remaining 17K steps.
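
A linear-warmup plus cosine-decay schedule for this run could look like the sketch below. The peak LR of 6×10⁻⁴ and minimum LR of 0 are illustrative assumptions, as is the function name `lrAtStep`.

```javascript
// Learning rate at a given optimizer step: linear warmup, then cosine decay.
function lrAtStep(step, { peakLr, minLr, warmupSteps, totalSteps }) {
  if (step < warmupSteps) {
    return peakLr * (step / warmupSteps); // linear ramp from 0 to peakLr
  }
  // Fraction of the decay phase completed, clamped to 1 past the end.
  const progress = Math.min(1, (step - warmupSteps) / (totalSteps - warmupSteps));
  return minLr + 0.5 * (peakLr - minLr) * (1 + Math.cos(Math.PI * progress));
}

const totalSteps = Math.floor(10e9 / (512 * 1024)); // 19073
const schedule = { peakLr: 6e-4, minLr: 0, warmupSteps: 2000, totalSteps };

console.log(totalSteps);
console.log(schedule.warmupSteps / totalSteps); // warmup fraction, ~0.105
console.log(lrAtStep(1000, schedule));          // halfway through warmup
console.log(lrAtStep(2000, schedule));          // peak
console.log(lrAtStep(totalSteps, schedule));    // end of decay
```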

Exercise 9.6: Arrange the Training Pipeline Design

Put these steps in the correct order for one complete training iteration with gradient accumulation and mixed precision.

Steps to order: FP16 forward pass | Scale loss (× S) | FP16 backward pass | Unscale grads (÷ S) | Clip gradients | FP32 Adam step
Show explanation

The correct order: FP16 forward → Scale loss → FP16 backward → Unscale grads → Clip gradients → FP32 Adam step. Scaling must happen before backward (so gradients stay in FP16 range). Unscaling must happen before clipping (clip on true gradient magnitude). The Adam step uses FP32 master weights to maintain precision.
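
The ordering can be traced on a toy problem. The sketch below uses a single weight with y = w · x and a squared-error loss, plain JavaScript numbers instead of real FP16 tensors, and no gradient accumulation. All names (`trainStep`, `cfg`, `state`) are hypothetical; the point is only the sequence of operations.

```javascript
// One training iteration for a one-weight model: y = w * x, L = (y - target)^2.
function trainStep(state, batch, cfg) {
  const { x, target } = batch;

  // 1. Forward pass (would run in FP16 in a real mixed-precision setup).
  const y = state.w * x;

  // 2. Scale the loss by S.
  const scaledLoss = cfg.lossScale * (y - target) ** 2;

  // 3. Backward pass on the scaled loss: d(S * L)/dw = S * 2 * (y - target) * x.
  let grad = cfg.lossScale * 2 * (y - target) * x;

  // 4. Unscale so that clipping and Adam see the true gradient.
  grad /= cfg.lossScale;

  // 5. Clip by L2 norm (with one parameter, the norm is |grad|).
  const gradNorm = Math.abs(grad);
  if (gradNorm > cfg.maxNorm) grad *= cfg.maxNorm / gradNorm;

  // 6. Adam step on the master weight, with bias correction.
  const t = state.t + 1;
  const m = cfg.beta1 * state.m + (1 - cfg.beta1) * grad;
  const v = cfg.beta2 * state.v + (1 - cfg.beta2) * grad * grad;
  const mHat = m / (1 - cfg.beta1 ** t);
  const vHat = v / (1 - cfg.beta2 ** t);
  const w = state.w - (cfg.lr * mHat) / (Math.sqrt(vHat) + cfg.eps);

  return { w, m, v, t, gradNorm, scaledLoss };
}

const cfg = { lossScale: 65536, maxNorm: 1.0, lr: 0.1, beta1: 0.9, beta2: 0.999, eps: 1e-8 };
const next = trainStep({ w: 0, m: 0, v: 0, t: 0 }, { x: 1, target: 4 }, cfg);
console.log(next.gradNorm); // 8 (true gradient magnitude before clipping)
console.log(next.w);        // ~0.1 (first Adam step moves by about lr)
```

If clipping ran before unscaling, the threshold would be compared against a norm 65536× too large and every step would be clipped.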

Exercise 9.7: Training Cost Derive

8 A100 GPUs for 1.74 hours. Cloud cost: $2.50/GPU-hour (A100 80GB spot). How much does this training run cost?

USD
Show derivation
GPU-hours = 8 × 1.74 = 13.92
Cost = 13.92 × $2.50 = $34.80

A 125M model on 10B tokens costs ~$35 on spot instances. Scaling this up: a 7B model on 1T tokens costs ~$200K, and LLaMA 3 405B on 15T tokens cost an estimated $100M+. Compute cost scales as O(N × D), so both model size and dataset size are cost multipliers.
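
Chaining the 6ND estimate with a price per GPU-hour gives a quick cost calculator. This sketch assumes the workbook's 150 TFLOP/s effective throughput and $2.50/GPU-hour, with perfect scaling so that total GPU-hours do not depend on GPU count; `trainingCostUSD` is a hypothetical helper.

```javascript
// Estimated GPU-hours and dollar cost for a training run under the 6ND rule.
function trainingCostUSD({ numParams, numTokens, flopsPerGpuPerSec, pricePerGpuHour }) {
  const flops = 6 * numParams * numTokens;
  const gpuHours = flops / flopsPerGpuPerSec / 3600;
  return { gpuHours, costUSD: gpuHours * pricePerGpuHour };
}

const hw = { flopsPerGpuPerSec: 150e12, pricePerGpuHour: 2.5 };
console.log(trainingCostUSD({ numParams: 125e6, numTokens: 10e9, ...hw })); // ~13.9 GPU-hours, ~$35
console.log(trainingCostUSD({ numParams: 7e9, numTokens: 1e12, ...hw }));   // ~77.8K GPU-hours, ~$194K
```

The small-model figure differs from $34.80 by a few cents only because the derivation above rounds the run time to 1.74 hours first.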

The proof of work. If you completed every exercise in this workbook from scratch — traced chain rules through networks, computed gradient shapes, derived loss functions, worked through Adam updates, and planned a full training run — you can reason about training at a systems level. These calculations are the foundation for every training decision: architecture, optimizer, schedule, precision, budget. "What I cannot create, I do not understand."

Related Lessons

Topic | Lesson
Transformer fundamentals | Transformer — From Absolute Zero
GPT architecture | GPT — From Absolute Zero
Transformer math | Transformer Math Workbook
Distributed training | Distributed Training — From Absolute Zero