Why small models can't reason by copying — and the training loop that fixes it.
You clone GPT-4's homework answers perfectly. Your student model scores 95% on predicting the next word. Then someone asks it to prove that √2 is irrational, and it produces: "Assume √2 = p/q. Then 2 = p²/q². Therefore p² = 2q². So p is even. Wait, actually let me try a different approach... The proof is left as an exercise."
It sounds like math. It is not math. The model learned to imitate the surface pattern of reasoning — hedging, self-correction phrases, formal notation — without learning the actual logical chain that connects assumptions to conclusions.
This is the central failure of naive knowledge distillation: training a small "student" model by having it predict the same tokens as a large "teacher" model. The student learns what the right answer looks like, but not how to produce it step-by-step when things go slightly off-script.
The problem is subtle but devastating. During training, the student sees perfect teacher-generated text at every position. It learns: "given this perfect prefix, predict the next token." But during inference, the student generates its own prefix. The moment it makes a single mistake — one wrong token — it's now in territory it has never seen during training. And from that point forward, every subsequent token is generated from an unfamiliar context.
Think of it this way: off-policy training is like learning to drive by watching dashcam footage of perfect driving. You know what a good lane change looks like. But the first time you drift slightly left, you've never practiced recovery — you've only ever seen the view from perfectly centered. Panic ensues.
On-policy training is like actually getting behind the wheel with an instructor. You make mistakes. You drift. And the instructor says "ok, from HERE, do THIS." You learn what recovery looks like from error states, not just from perfect states. This is the key insight of on-policy distillation.
[Interactive figure: error propagation in two modes. In copy mode, the student trains on perfect teacher tokens but fails at inference. In practice mode, it trains on its own outputs and learns to recover.]
Let's make this concrete with actual arithmetic. Suppose your student model achieves ε = 0.05 per-token error rate — meaning 95% accuracy at each individual position. That sounds excellent. But reasoning requires getting a sequence of tokens all correct.
For a sequence of length T, the probability of producing the entire sequence correctly (assuming independent errors, which is optimistic) is (1 - ε)^T. Let's compute this for various sequence lengths:
| Sequence Length T | Success Prob (1-ε)^T | Interpretation |
|---|---|---|
| 5 tokens | 0.95^5 = 0.774 (77%) | Short factual answer: mostly works |
| 10 tokens | 0.95^10 = 0.599 (60%) | Simple reasoning: little better than a coin flip |
| 20 tokens | 0.95^20 = 0.358 (36%) | Medium proof: usually fails |
| 50 tokens | 0.95^50 = 0.077 (7.7%) | Long derivation: almost always fails |
| 100 tokens | 0.95^100 = 0.006 (0.6%) | Complex reasoning: effectively impossible |
This is the compounding error problem. Even a tiny per-step error rate becomes catastrophic over long sequences. And the situation is actually worse than this table suggests, because errors aren't independent — once the model makes one error, it enters an unfamiliar state, making subsequent errors more likely.
Let's trace through the math carefully for T=5, ε=0.05:
Step by step: 0.95^1 = 0.95. Then 0.95^2 = 0.95 × 0.95 = 0.9025. Then 0.95^3 = 0.9025 × 0.95 = 0.8574. Then 0.95^4 = 0.8574 × 0.95 = 0.8145. Finally 0.95^5 = 0.8145 × 0.95 = 0.7738.
So even with 95% per-step accuracy over just 5 steps, we've already lost 23% of our sequences. Now imagine T=50: we'd compute 0.95 multiplied by itself 50 times. The result is 0.0769 — only 7.7% of sequences survive intact.
First, let's compute this manually, step by step:
```python
# Manual computation, no libraries
epsilon = 0.05   # per-token error rate
T = 20           # sequence length

# Multiply (1 - epsilon) by itself T times
prob = 1.0
for step in range(T):
    prob = prob * (1 - epsilon)
    print(f"Step {step+1:2d}: P(all correct) = {prob:.4f}")

# Selected lines of the output:
# Step  1: P(all correct) = 0.9500
# Step  2: P(all correct) = 0.9025
# Step  5: P(all correct) = 0.7738
# Step 10: P(all correct) = 0.5987
# Step 20: P(all correct) = 0.3585
```
Now the compact version:
```python
# One-liner: success probability for any epsilon, T
success = (1 - 0.05) ** 20   # = 0.3585

# Sweep across sequence lengths
for T in [5, 10, 20, 50, 100]:
    p = (1 - 0.05) ** T
    print(f"T={T:3d}: P(success) = {p:.4f} ({p*100:.1f}%)")
```
And with numpy for vectorized computation:
```python
import numpy as np

epsilons = np.array([0.01, 0.05, 0.10])
lengths = np.arange(1, 101)

# Shape: (3, 100). Each row is one epsilon, columns are T = 1..100
success_grid = (1 - epsilons[:, None]) ** lengths[None, :]

# success_grid[1, 49] = P(success | epsilon=0.05, T=50) = 0.0769
```
Chapter 0 showed that something goes wrong. Now let's be precise about what goes wrong, mathematically. The training loop looks innocent enough — but it hides a devastating assumption that only reveals itself at inference time.
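Here's the standard token-level knowledge distillation objective. We have a teacher model pT and a student model pθ. At each position t in a sequence, we minimize the KL divergence between the teacher's distribution and the student's distribution:

$$\mathcal{L}_{\text{KD}}(\theta) = \mathbb{E}_{y \sim p_T(\cdot \mid x)} \left[ \sum_{t=1}^{T} D_{\mathrm{KL}}\big( p_T(\cdot \mid x, y_{<t}) \,\|\, p_\theta(\cdot \mid x, y_{<t}) \big) \right]$$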
Read that carefully. The prefix y<t — all the tokens before position t — is the teacher's sequence. The student is being asked: "given this perfect prefix written by the teacher, what should the next token be?" And it learns to answer that question very well.
But at inference time, the student doesn't get the teacher's prefix. It generates its own:

$$\hat{y}_t \sim p_\theta(\cdot \mid x, \hat{y}_{<t})$$
where ŷ<t ~ pθ — the prefix is sampled from the student's own distribution. This is the train/test mismatch. During training, the student always conditions on clean, teacher-generated context. During inference, it conditions on its own noisy, imperfect context. The distribution of prefixes at test time is fundamentally different from training time.
This mismatch has a name in the sequential decision-making literature: distribution shift (also called covariate shift in the supervised learning context). The input distribution (prefixes) changes between training and testing. And the shift compounds — each error makes future inputs even more out-of-distribution.
Let's make this concrete with a picture. At position t=1, both training and inference are identical — there's no prefix yet, just the prompt x. At position t=2, training uses the teacher's first token (perfect), while inference uses the student's first token (possibly wrong). By position t=5, training still operates from a perfect 4-token prefix, while inference has been building on potentially corrupted tokens for 4 steps.
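The position-by-position picture can be sketched in a few lines. This is a toy illustration, not a real model: the two token lists are invented, and `contexts` is a hypothetical helper that returns the prefix a model conditions on at each position.

```python
# Toy illustration: prefixes seen under teacher forcing vs. free-running inference.
# Both token lists are made up for illustration.
prompt = ["<prompt>"]
teacher_tokens = ["Assume", "sqrt2", "=", "p/q", "then"]
student_tokens = ["Assume", "sqrt2", "is", "p/q", "so"]  # student errs at position 3

def contexts(prompt, tokens):
    """Prefix the model conditions on when predicting position t = 1..T."""
    return [prompt + tokens[:t] for t in range(len(tokens))]

train_ctx = contexts(prompt, teacher_tokens)   # training: teacher's prefix
infer_ctx = contexts(prompt, student_tokens)   # inference: student's own prefix

for t, (tr, inf) in enumerate(zip(train_ctx, infer_ctx), start=1):
    status = "same" if tr == inf else "DIFFERENT"
    print(f"t={t}: training and inference contexts are {status}")

# First position whose context differs between training and inference
first_mismatch = next(
    t for t, (tr, inf) in enumerate(zip(train_ctx, infer_ctx), start=1) if tr != inf
)
```

The wrong token at position 3 does not change the context for position 3 itself; it contaminates every context from position 4 onward.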
[Interactive figure. Left: during training, the student always sees perfect teacher tokens (green). Right: during inference, it sees its own tokens, some correct (green), some wrong (orange), and drift accumulates token by token.]
Let's compute the actual loss at a single token position. Suppose we have a tiny vocabulary of 3 tokens: ["therefore", "so", "but"]. The teacher assigns probabilities:
pT = [0.7, 0.2, 0.1] — the teacher is fairly confident "therefore" comes next.
The student's current distribution is:
pθ = [0.3, 0.4, 0.3] — the student is confused, spreading probability too broadly.
The forward KL divergence DKL(pT || pθ) measures how surprised the student would be if reality followed the teacher:

$$D_{\mathrm{KL}}(p_T \,\|\, p_\theta) = \sum_i p_T(i) \log \frac{p_T(i)}{p_\theta(i)}$$
Let's compute each term:
Term 1 ("therefore"): 0.7 × log(0.7 / 0.3) = 0.7 × log(2.333) = 0.7 × 0.8473 = 0.5931
Term 2 ("so"): 0.2 × log(0.2 / 0.4) = 0.2 × log(0.5) = 0.2 × (-0.6931) = -0.1386
Term 3 ("but"): 0.1 × log(0.1 / 0.3) = 0.1 × log(0.333) = 0.1 × (-1.0986) = -0.1099
Total: DKL = 0.5931 + (-0.1386) + (-0.1099) = 0.3446 nats
Wait, negative terms? Yes! KL divergence sums to a non-negative total (always ≥ 0 by Gibbs' inequality), but individual terms can be negative. The negative terms come from tokens where the student assigns MORE probability than the teacher. In our case, the student over-weights "so" and "but" while under-weighting "therefore." The net KL is positive: 0.3446 nats, meaning the student still has significant room to improve.
First, from absolute scratch — no libraries:
```python
import math

# Teacher and student distributions
p_teacher = [0.7, 0.2, 0.1]
p_student = [0.3, 0.4, 0.3]

# Manual KL: sum p_T * log(p_T / p_S)
kl = 0.0
for i in range(len(p_teacher)):
    if p_teacher[i] > 0:
        term = p_teacher[i] * math.log(p_teacher[i] / p_student[i])
        print(f"  Term {i}: {p_teacher[i]:.1f} * log({p_teacher[i]:.1f}/{p_student[i]:.1f}) = {term:.4f}")
        kl += term

print(f"KL divergence = {kl:.4f} nats")

# Output:
#   Term 0: 0.7 * log(0.7/0.3) = 0.5931
#   Term 1: 0.2 * log(0.2/0.4) = -0.1386
#   Term 2: 0.1 * log(0.1/0.3) = -0.1099
# KL divergence = 0.3446 nats
```
Now with PyTorch — how you'd actually compute it in a training loop:
```python
import torch
import torch.nn.functional as F

# Logits from teacher and student (before softmax)
teacher_logits = torch.tensor([1.5, 0.3, -0.5])  # arbitrary logits
student_logits = torch.tensor([0.1, 0.5, 0.1])

# Convert to log-probs (student) and probs (teacher)
teacher_probs = F.softmax(teacher_logits, dim=-1)
student_log_probs = F.log_softmax(student_logits, dim=-1)

# PyTorch KL: expects log(Q) and P, computes sum P * (log P - log Q)
kl = F.kl_div(student_log_probs, teacher_probs, reduction='sum')
print(f"KL = {kl.item():.4f}")
```
Now let's see why the mismatch doesn't just cause a fixed offset, but actually compounds over time. Consider a sequence where the student makes an error at position t=3. At position t=4:
During training, the input is [prompt, teacher_tok_1, teacher_tok_2, teacher_tok_3]. The student has seen millions of examples with this exact type of clean context.
During inference, the input is [prompt, student_tok_1, student_tok_2, WRONG_tok_3]. The student has never seen context with this error. Its prediction at position 4 will be unreliable. And unreliable position 4 feeds into position 5, which feeds into position 6...
This is exposure bias (Bengio et al., 2015) — the model is never "exposed" to its own errors during training, so it can't handle them at inference. The errors compound because each error creates a novel context that triggers further errors.
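To get a feel for the compounding, here is a small Monte Carlo sketch. It is a deliberately crude model, not a claim about any real LLM: the student errs with probability `eps_clean` while its prefix is still clean, and with probability `eps_off_dist` once it has made at least one mistake. The value 0.5 for `eps_off_dist` is an invented number chosen only for illustration.

```python
import random

def mean_errors(T, eps_clean, eps_off_dist, trials=5000, seed=0):
    """Average number of wrong tokens per length-T sequence in the toy model."""
    rng = random.Random(seed)
    total = 0
    for _ in range(trials):
        derailed = False              # has the prefix been corrupted yet?
        for _ in range(T):
            eps = eps_off_dist if derailed else eps_clean
            if rng.random() < eps:
                total += 1
                derailed = True       # every later step sees an unfamiliar prefix
    return total / trials

# Independent errors: the error rate never changes
independent = mean_errors(T=50, eps_clean=0.05, eps_off_dist=0.05)
# Compounding errors: the rate jumps after the first mistake (assumed value)
compounding = mean_errors(T=50, eps_clean=0.05, eps_off_dist=0.5)

print(f"independent errors: {independent:.2f} mistakes per sequence")
print(f"compounding errors: {compounding:.2f} mistakes per sequence")
```

With independent errors the average stays near εT. Once the error rate depends on the prefix, the same per-step ε on clean prefixes produces several times as many mistakes.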
Chapter 1 showed there's a mismatch between training and inference. But how bad is it? Could we just add more training data and fix it? Could we make the student bigger? Or is there a fundamental scaling law that makes off-policy training catastrophically worse as sequences get longer?
The answer comes from a remarkable theorem by Ross, Gordon, and Bagnell (2011), originally proven for imitation learning in robotics. It applies directly to autoregressive language models and gives us a precise bound on how errors scale with sequence length.
The theorem says: if your per-step error rate is ε, then your total accumulated error over a sequence of length T scales as:

$$\text{off-policy: } O(\epsilon T^2) \qquad\qquad \text{on-policy: } O(\epsilon T)$$

Quadratic vs linear. That's not a constant-factor improvement; it's a fundamentally different scaling law. Ignoring constants, at T=100 quadratic gives 10,000ε while linear gives 100ε: the off-policy approach is on the order of 100x worse at sequence length 100.
Let's derive WHY the error is quadratic for off-policy training. The argument is intuitive once you see it:
Setup: The student has per-step error rate ε. This means at any given step, there's a probability ε that it produces a token different from what the teacher would produce.
Step 1: At time step t, what's the probability that the student is in an "unseen state" — a prefix it never encountered during training? If any of the previous t-1 tokens were wrong, the current prefix is novel. The probability of being in a novel state at step t is at most (t-1)·ε (union bound — each of the t-1 previous positions had ε probability of error).
Step 2: Once in an unseen state, the student's error rate is no longer ε. It could be as bad as 1 (total failure). So at step t, the expected error contribution is:

$$\text{err}_t \le \epsilon + (t-1)\,\epsilon \cdot 1 = t\,\epsilon$$
The first ε is the intrinsic error (even on familiar states). The second term is: probability of being in an unseen state (t-1)·ε, times worst-case error of 1.
Step 3: Sum over all T steps to get total error:

$$\sum_{t=1}^{T} t\,\epsilon = \epsilon \, \frac{T(T+1)}{2} = O(\epsilon T^2)$$

That's the quadratic bound: O(εT²). The sum 1+2+3+...+T = T(T+1)/2 is the classic formula, and it gives us the T² scaling.
Now suppose we train on-policy: the student sees its own prefixes during training. In that case, it HAS seen error states before. It knows how to behave when the prefix is imperfect. So even in a "novel" state (one that arose from a student error), the expected error per step is still just ε — because the student has practiced recovery.
No compounding. Just ε per step, T steps, total = εT. Linear.
Let's plug in numbers. Set ε = 0.05 (95% per-step accuracy):
| Sequence Length T | Off-Policy (≈ εT²/2) | On-Policy (εT) | Ratio (off/on) |
|---|---|---|---|
| T = 10 | 0.05×100/2 = 2.5 | 0.05×10 = 0.5 | 5× |
| T = 20 | 0.05×400/2 = 10 | 0.05×20 = 1.0 | 10× |
| T = 50 | 0.05×2500/2 = 62.5 | 0.05×50 = 2.5 | 25× |
| T = 100 | 0.05×10000/2 = 250 | 0.05×100 = 5.0 | 50× |
At T=100, the off-policy bound (250) exceeds the sequence length itself, so it guarantees nothing: the model may get essentially every token wrong. The on-policy bound (5.0) is manageable: about 5 mistakes in 100 tokens, and the model can keep going.
[Interactive figure: off-policy error (quadratic, orange) versus on-policy error (linear, teal) as a function of T, for an adjustable per-step error rate ε. The dashed line shows T, the point at which the error equals total failure.]
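The DAgger paper didn't just prove the bound; it proposed a fix. The Dataset Aggregation (DAgger) algorithm works like this: (1) run the current student and record the states it actually visits; (2) ask the teacher to label each of those states with what it would have done there; (3) add the newly labeled examples to one growing, aggregated dataset; (4) retrain the student on the aggregate and repeat.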
The key insight: in step 2, the teacher labels the student's states. This means the training data now includes examples from the states the student actually visits — including error states. Over iterations, the student builds experience with its own mistakes.
For LLM distillation, DAgger translates to: let the student generate text, then ask the teacher to provide probability distributions conditioned on the student's actual generated prefixes. This is computationally expensive (requires running the teacher on student-generated sequences) but provably reduces the error bound from quadratic to linear.
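The loop can be sketched on a toy problem. Everything here is a stand-in chosen for illustration: the "teacher" just counts upward modulo 10, the "student" is a lookup table that memorizes labeled contexts, and `slip` is an invented probability that the student emits a random token. It is not the paper's implementation, only the shape of the algorithm.

```python
import random

VOCAB = list(range(10))

def teacher_next(prev):
    """Toy teacher: always counts upward modulo 10."""
    return (prev + 1) % 10

def student_rollout(policy, start, T, rng, slip=0.2):
    """Student generates T tokens. With probability `slip` it emits a random
    token (a stand-in for a sampling mistake); unseen contexts fall back to 0."""
    seq = [start]
    for _ in range(T):
        prev = seq[-1]
        if rng.random() < slip:
            seq.append(rng.choice(VOCAB))
        else:
            seq.append(policy.get(prev, 0))
    return seq

def dagger(iterations=5, T=8, seed=0):
    rng = random.Random(seed)
    # Start from teacher demonstrations only: contexts 0, 1, 2
    dataset = {s: teacher_next(s) for s in (0, 1, 2)}
    policy = dict(dataset)
    for _ in range(iterations):
        # 1. The student visits its own states, mistakes included
        seq = student_rollout(policy, start=0, T=T, rng=rng)
        # 2-3. The teacher labels those states; labels join the aggregate dataset
        for prev in seq[:-1]:
            dataset[prev] = teacher_next(prev)
        # 4. "Retrain" on the aggregate (here: memorize it)
        policy = dict(dataset)
    return policy

policy = dagger()
print(f"contexts covered after DAgger: {sorted(policy)}")
```

The point to notice is where the labels come from: the teacher never generates a sequence after the initial demonstrations, it only answers "what would you do here?" for contexts the student reached on its own.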
```python
# Manual computation of DAgger bounds
epsilon = 0.05

# Off-policy: sum of epsilon * t for t = 1..T
def off_policy_error(eps, T):
    # Closed form: eps * T * (T + 1) / 2
    total = 0.0
    for t in range(1, T + 1):
        total += eps * t
    return total

# On-policy: sum of epsilon for t = 1..T
def on_policy_error(eps, T):
    return eps * T

# Check: off_policy_error(0.05, 50) = 0.05 * (1 + 2 + ... + 50) = 0.05 * 1275 = 63.75
print(f"{off_policy_error(epsilon, 50):.2f}")  # 63.75
print(f"{on_policy_error(epsilon, 50):.2f}")   # 2.50
```
Compact one-liners:
```python
# Closed-form expressions
off_err = lambda e, T: e * T * (T + 1) / 2
on_err = lambda e, T: e * T

# At what T does off-policy become 10x worse?
# off/on = T*(T+1)/2 / T = (T+1)/2 = 10  =>  T = 19
print("Off-policy is 10x worse at T = 19")
print(f"  Off: {off_err(0.05, 19):.2f}, On: {on_err(0.05, 19):.2f}")
# Off: 9.50, On: 0.95, ratio = 10.0
```
And visualized with numpy:
```python
import numpy as np

eps = 0.05
T = np.arange(1, 101)

off_policy = eps * T * (T + 1) / 2   # quadratic
on_policy = eps * T                  # linear

# "Success probability" heuristic: clip error to [0, T]
off_success = np.maximum(0, 1 - off_policy / T)
on_success = np.maximum(0, 1 - on_policy / T)

# At T=50: off_success = max(0, 1 - 63.75/50) = 0    (total failure)
# At T=50: on_success  = max(0, 1 - 2.5/50)   = 0.95 (mostly works!)
```
OK, so on-policy training helps (Chapter 2). But what should the student actually optimize? "Match the teacher's distribution" can mean many different things mathematically. The choice of how to measure "mismatch" dramatically changes what the student learns.
All the distance measures we'll consider belong to a single family: the f-divergences. Every f-divergence has the form:

$$D_f(P \,\|\, Q) = \sum_x Q(x)\, f\!\left(\frac{P(x)}{Q(x)}\right)$$
where f is a convex function with f(1) = 0. Different choices of f give different divergences, each with distinct behavior:
| Divergence | f(u) | Formula | Behavior |
|---|---|---|---|
| Forward KL | u log u | ∑ P log(P/Q) | Mode-covering (broad) |
| Reverse KL | -log u | ∑ Q log(Q/P) | Mode-seeking (sharp) |
| JSD | -(1+u)/2 · log((1+u)/2) + u/2 · log u | ½KL(P‖M) + ½KL(Q‖M), with M = (P+Q)/2 | Symmetric compromise |
| Total Variation | \|u-1\|/2 | ½ ∑ \|P-Q\| | Largest probability gap over any set of tokens |
The two most important for LLM distillation are Forward KL (FKL) and Reverse KL (RKL). Their difference is best understood through a specific example.
Imagine a teacher distribution with two peaks — two equally valid ways to answer a math question. Maybe "therefore x = 3" and "so we get x = 3" are both correct continuations. The teacher assigns probability to both modes.
Now the student is smaller — maybe it can only represent a unimodal distribution (one peak). What does it do?
Forward KL (mode-covering): FKL = ∑ pT log(pT/pθ) penalizes the student for putting LOW probability where the teacher puts HIGH probability. If the teacher has two peaks, the student MUST cover both — or face infinite penalty (log(something/0) = ∞). So the student spreads out to cover both modes... and ends up assigning probability to the gap between them. This means it might generate tokens that are BETWEEN the two valid answers — hallucination.
Reverse KL (mode-seeking): RKL = ∑ pθ log(pθ/pT) penalizes the student for putting HIGH probability where the teacher puts LOW probability. The student is rewarded for being precise: it picks ONE mode and locks onto it sharply, ignoring the other mode entirely. This means it might miss valid alternatives, but it is much less likely to hallucinate, because mass on tokens the teacher considers unlikely is exactly what RKL punishes.
JSD (compromise): M = (P+Q)/2 and JSD = ½KL(P||M) + ½KL(Q||M). It's symmetric, bounded between 0 and log(2), and offers a middle ground. The student covers most of the teacher's mass without putting too much weight in low-density regions.
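The mode-covering versus mode-seeking behavior can be checked numerically. The sketch below is an illustration under assumptions of its own: the "vocabulary" is a discretized number line, the teacher is a mixture of two Gaussian bumps at -4 and +4, and the student is restricted to a single Gaussian bump whose mean and width are found by brute-force grid search. None of these choices come from a specific paper.

```python
import numpy as np

x = np.linspace(-10, 10, 81)   # discretized "token axis", step 0.25

def gaussian(mu, sigma):
    """Discretized, normalized Gaussian bump over x."""
    w = np.exp(-0.5 * ((x - mu) / sigma) ** 2)
    return w / w.sum()

# Bimodal teacher: two equally valid modes
teacher = 0.5 * gaussian(-4.0, 1.0) + 0.5 * gaussian(4.0, 1.0)

def kl(p, q):
    """D_KL(p || q) for strictly positive discrete distributions."""
    return float(np.sum(p * np.log(p / q)))

# Candidate unimodal students: every (mean, width) pair on a grid
mus = np.linspace(-6, 6, 49)
sigmas = np.linspace(0.5, 6, 23)
candidates = [(m, s, gaussian(m, s)) for m in mus for s in sigmas]

# Forward KL: teacher in the first slot. Reverse KL: student in the first slot.
mu_fkl, sigma_fkl, _ = min(candidates, key=lambda c: kl(teacher, c[2]))
mu_rkl, sigma_rkl, _ = min(candidates, key=lambda c: kl(c[2], teacher))

print(f"Forward KL optimum: mean={mu_fkl:.2f}, width={sigma_fkl:.2f}")
print(f"Reverse KL optimum: mean={mu_rkl:.2f}, width={sigma_rkl:.2f}")
```

The forward-KL student ends up centered between the two peaks with a large width, so much of its mass sits where the teacher has almost none. The reverse-KL student sits on one peak with roughly the teacher's width there. Because the teacher is symmetric, which peak the reverse-KL search lands on is arbitrary.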
Let's compute both divergences for concrete numbers. Teacher: PT = [0.5, 0.4, 0.1]. Student: Pθ = [0.6, 0.3, 0.1].
Forward KL: DKL(PT || Pθ) = ∑ PT(i) · log(PT(i) / Pθ(i))
Term 1: 0.5 × log(0.5/0.6) = 0.5 × log(0.8333) = 0.5 × (-0.1823) = -0.0912
Term 2: 0.4 × log(0.4/0.3) = 0.4 × log(1.3333) = 0.4 × 0.2877 = 0.1151
Term 3: 0.1 × log(0.1/0.1) = 0.1 × log(1.0) = 0.1 × 0 = 0
FKL = -0.0912 + 0.1151 + 0 = 0.0239 nats
Reverse KL: DKL(Pθ || PT) = ∑ Pθ(i) · log(Pθ(i) / PT(i))
Term 1: 0.6 × log(0.6/0.5) = 0.6 × log(1.2) = 0.6 × 0.1823 = 0.1094
Term 2: 0.3 × log(0.3/0.4) = 0.3 × log(0.75) = 0.3 × (-0.2877) = -0.0863
Term 3: 0.1 × log(0.1/0.1) = 0.1 × log(1.0) = 0.1 × 0 = 0
RKL = 0.1094 + (-0.0863) + 0 = 0.0231 nats
In this case FKL (0.0239) and RKL (0.0231) are similar — because the distributions are close and roughly unimodal. The dramatic difference appears when the teacher is bimodal and the student is unimodal.
Now let PT = [0.45, 0.05, 0.45, 0.05]: two strong modes at positions 0 and 2, with small mass between them. Student can only be unimodal. Two strategies:
Strategy A (cover both): Pθ = [0.25, 0.25, 0.25, 0.25], broad, covering everything.
Strategy B (pick one): Pθ = [0.7, 0.2, 0.05, 0.05], sharp, focused on mode 0.
FKL prefers Strategy A (0.368 nats) over Strategy B (0.721 nats), because B puts only 0.05 where the teacher puts 0.45, creating a large log(0.45/0.05) = 2.197 penalty. But Strategy A hallucinates: it puts 25% probability on each of tokens 1 and 3, where the teacher has only 5%.
RKL prefers Strategy B (0.477 nats) over Strategy A (0.511 nats), because A puts half of its mass on tokens 1 and 3, which the teacher considers unlikely, while B puts only a quarter there. The "dropped mode" (token 2 getting only 0.05 from the student) costs nothing in RKL: its term is 0.05 × log(0.05/0.45) = -0.110, because the student's mass there is small.
[Interactive figure: a bimodal teacher distribution (white curve) and a student distribution that morphs from the Forward KL optimum (broad, covers both modes) to the Reverse KL optimum (sharp, picks one mode). The red region shows hallucination mass.]
```python
import math

def forward_kl(p_teacher, p_student):
    """D_KL(P_T || P_S) = sum P_T * log(P_T / P_S)"""
    kl = 0.0
    for pt, ps in zip(p_teacher, p_student):
        if pt > 0:
            assert ps > 0, "FKL infinite if student has zero where teacher is positive"
            kl += pt * math.log(pt / ps)
    return kl

def reverse_kl(p_teacher, p_student):
    """D_KL(P_S || P_T) = sum P_S * log(P_S / P_T)"""
    kl = 0.0
    for pt, ps in zip(p_teacher, p_student):
        if ps > 0:
            assert pt > 0, "RKL infinite if teacher has zero where student is positive"
            kl += ps * math.log(ps / pt)
    return kl

def jsd(p_teacher, p_student):
    """JSD = 0.5 * KL(P_T||M) + 0.5 * KL(P_S||M) where M = (P_T+P_S)/2"""
    m = [(pt + ps) / 2 for pt, ps in zip(p_teacher, p_student)]
    return 0.5 * forward_kl(p_teacher, m) + 0.5 * forward_kl(p_student, m)

# Test with our worked example
P_T = [0.5, 0.4, 0.1]
P_S = [0.6, 0.3, 0.1]
print(f"Forward KL: {forward_kl(P_T, P_S):.4f}")  # 0.0239
print(f"Reverse KL: {reverse_kl(P_T, P_S):.4f}")  # 0.0231
print(f"JSD:        {jsd(P_T, P_S):.4f}")         # 0.0059
```
And the PyTorch equivalent:
```python
import torch
import torch.nn.functional as F

P_T = torch.tensor([0.5, 0.4, 0.1])
P_S = torch.tensor([0.6, 0.3, 0.1])

# Forward KL: F.kl_div expects log(Q), P and computes sum P * (log P - log Q)
fkl = F.kl_div(P_S.log(), P_T, reduction='sum')

# Reverse KL: swap arguments
rkl = F.kl_div(P_T.log(), P_S, reduction='sum')

# JSD
M = 0.5 * (P_T + P_S)
jsd_val = 0.5 * F.kl_div(M.log(), P_T, reduction='sum') + \
          0.5 * F.kl_div(M.log(), P_S, reduction='sum')

print(f"FKL={fkl:.4f}, RKL={rkl:.4f}, JSD={jsd_val:.4f}")
```
Here's the practical decision rule for LLM distillation:
| Task Type | Teacher Shape | Best Divergence | Why |
|---|---|---|---|
| Math proofs | Unimodal (one correct path) | Reverse KL | Precision: don't hallucinate wrong steps |
| Code generation | Few modes (correct implementations) | Reverse KL | Correctness over diversity |
| Translation | Moderate modes (valid phrasings) | JSD | Some variety without garbage |
| Creative writing | Highly multimodal (many valid outputs) | Forward KL | Preserve diversity of expression |
| Summarization | Moderate modes | JSD or FKL | Multiple valid summaries exist |
We know on-policy helps (Chapter 2). We have a menu of divergences (Chapter 3). But how do you actually combine these into a practical training algorithm? The first paper to do this systematically was Generalized Knowledge Distillation (GKD) by Agarwal et al. (2024). Its key insight: make the sampling policy and the divergence measure orthogonal design choices.
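Previous methods hard-coded both: standard KD uses off-policy sampling + forward KL. SeqKD uses off-policy sampling + sequence-level loss. GKD says: pick ANY combination from a 2D grid:

$$\mathcal{L}_{\text{GKD}}(\theta) = \mathbb{E}_{x}\, \mathbb{E}_{y \sim \pi_{\text{mix}}(\cdot \mid x)} \left[ \sum_{t} D_f\big( p_T(\cdot \mid x, y_{<t}) \,\|\, p_\theta(\cdot \mid x, y_{<t}) \big) \right]$$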
where πmix = λ · pθ + (1-λ) · pdata is a mixture sampling policy. The parameter λ ∈ [0, 1] controls how "on-policy" you are:
| λ | Sampling Policy | What the student sees |
|---|---|---|
| 0 | Pure off-policy (pdata) | Only ground-truth / teacher sequences |
| 0.5 | 50/50 mixture | Half teacher sequences, half student rollouts |
| 1.0 | Pure on-policy (pθ) | Only the student's own generated sequences |
And Df can be any f-divergence from Chapter 3 — FKL, RKL, JSD, etc. The two choices are independent.
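In practice the mixture is usually realized as a per-sequence coin flip rather than by literally mixing distributions. A minimal sketch, where `sample_source` is a hypothetical helper name and the strings "student" and "data" are placeholders for the two sequence sources:

```python
import random

def sample_source(lam, rng):
    """Return 'student' with probability lam, otherwise 'data'."""
    return "student" if rng.random() < lam else "data"

rng = random.Random(0)
n = 10_000
counts = {"student": 0, "data": 0}
for _ in range(n):
    counts[sample_source(0.5, rng)] += 1

frac_student = counts["student"] / n
print(f"fraction of on-policy sequences at lambda=0.5: {frac_student:.3f}")
```

At λ=0 the function always returns "data" and at λ=1 it always returns "student", matching the first and last rows of the table.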
The paper's most surprising result: the sampling policy (λ) matters MORE than the divergence choice. Even with simple forward KL, switching from λ=0 (off-policy) to λ=0.5 (mixed) dramatically improves downstream task performance. The divergence choice matters less — it's a second-order effect.
This makes intuitive sense given Chapter 2: the DAgger theorem says the sampling distribution (on-policy vs off-policy) determines whether errors scale quadratically or linearly. The divergence just affects the constant factor.
Let's trace through exactly what happens in a single training step with λ=0.5 and JSD loss. This is the complete data flow, no hand-waving.
Step 1: Decide sampling source. We flip a weighted coin: with probability λ=0.5, we use on-policy (student generates); with probability 1-λ=0.5, we use the ground-truth sequence from the dataset. Suppose the coin says "on-policy."
Step 2: Generate student rollout. Given prompt x = "Prove that 2+2=4", the student autoregressively generates: ŷ = [2, +, 2, =, 4, by, definition]. Seven tokens. Each sampled from pθ(·|x, ŷ<t).
Step 3: Get teacher scores. We feed the student's sequence back through the teacher model. At EACH position t, we get the teacher's full probability distribution pT(·|x, ŷ<t). Note: the teacher conditions on the STUDENT's prefix — this is the on-policy part.
Step 4: Compute per-token JSD loss. At position t=3 (predicting "2"), for example:
Teacher distribution: pT = [0.01, 0.02, 0.85, 0.01, 0.05, 0.03, 0.03] (over a tiny vocab for illustration)
Student distribution: pθ = [0.05, 0.10, 0.50, 0.10, 0.10, 0.05, 0.10]
Mixture: M = (pT + pθ)/2 = [0.03, 0.06, 0.675, 0.055, 0.075, 0.04, 0.065]
JSD = 0.5 · KL(pT||M) + 0.5 · KL(pθ||M)
For the dominant token (index 2): 0.5 × 0.85 × log(0.85/0.675) + 0.5 × 0.50 × log(0.50/0.675) = 0.5 × 0.85 × 0.2305 + 0.5 × 0.50 × (-0.3001) = 0.0980 + (-0.0750) ≈ 0.0229
Sum across all vocabulary entries and all T positions to get the total loss for this sample.
Step 5: Backpropagate and update. Compute gradients of the total loss with respect to student parameters θ. Apply optimizer step (AdamW, lr=1e-5 typical).
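The per-token JSD from Step 4 can be computed term by term with the standard library alone. The two distributions are the ones from the worked example; `jsd_terms` is a helper name introduced here for illustration.

```python
import math

# Distributions at position t=3 from the worked example (7-token toy vocabulary)
p_teacher = [0.01, 0.02, 0.85, 0.01, 0.05, 0.03, 0.03]
p_student = [0.05, 0.10, 0.50, 0.10, 0.10, 0.05, 0.10]

def jsd_terms(p, q):
    """Per-token contributions to JSD = 0.5*KL(p||m) + 0.5*KL(q||m), m = (p+q)/2."""
    terms = []
    for pi, qi in zip(p, q):
        m = 0.5 * (pi + qi)
        terms.append(0.5 * pi * math.log(pi / m) + 0.5 * qi * math.log(qi / m))
    return terms

terms = jsd_terms(p_teacher, p_student)
jsd_t3 = sum(terms)

for i, term in enumerate(terms):
    print(f"token {i}: contribution = {term:.4f}")
print(f"JSD at this position = {jsd_t3:.4f} nats")
```

Unlike the individual terms of a KL divergence, each per-token JSD contribution is non-negative, so the printed values show directly which tokens drive the loss.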
First, the complete training loop in ~30 lines, annotated:
```python
import torch
import torch.nn.functional as F

def gkd_step(student, teacher, prompt_ids, ground_truth_ids, lam=0.5, divergence='jsd'):
    """One GKD training step. Returns scalar loss.

    Assumes Hugging Face style causal LMs: model(ids).logits and model.generate(...).
    ground_truth_ids is the full dataset sequence (prompt plus reference output).
    """
    # Step 1: Decide sampling source
    if torch.rand(1).item() < lam:
        # On-policy: student samples the sequence (no gradient through sampling)
        with torch.no_grad():
            sequence_ids = student.generate(prompt_ids, max_new_tokens=64, do_sample=True)
    else:
        # Off-policy: use ground-truth from dataset
        sequence_ids = ground_truth_ids

    # Step 2: Get teacher distribution at each position (no gradients)
    with torch.no_grad():
        teacher_log_probs = F.log_softmax(teacher(sequence_ids).logits, dim=-1)  # [B, T, V]
    teacher_probs = teacher_log_probs.exp()

    # Step 3: Get student distribution (WITH gradients)
    student_log_probs = F.log_softmax(student(sequence_ids).logits, dim=-1)      # [B, T, V]
    student_probs = student_log_probs.exp()

    # Step 4: Compute chosen divergence
    if divergence == 'fkl':
        # Forward KL: sum P_T * log(P_T / P_S)
        loss = (teacher_probs * (teacher_log_probs - student_log_probs)).sum(-1).mean()
    elif divergence == 'rkl':
        # Reverse KL: sum P_S * log(P_S / P_T)
        loss = (student_probs * (student_log_probs - teacher_log_probs)).sum(-1).mean()
    elif divergence == 'jsd':
        # JSD: 0.5*KL(P_T||M) + 0.5*KL(P_S||M)
        M_log = (0.5 * (teacher_probs + student_probs)).log()
        loss = 0.5 * (teacher_probs * (teacher_log_probs - M_log)).sum(-1).mean() + \
               0.5 * (student_probs * (student_log_probs - M_log)).sum(-1).mean()
    else:
        raise ValueError(f"unknown divergence: {divergence}")
    return loss
```
And a minimal training loop using this function:
```python
# GKD training schedule with lambda warmup.
# Assumes `student`, `teacher`, and a `dataloader` yielding
# (prompt_ids, ground_truth_ids) batches are already defined.
optimizer = torch.optim.AdamW(student.parameters(), lr=1e-5)
data_iter = iter(dataloader)

for step in range(10000):
    # Curriculum: ramp lambda linearly from 0 to 0.5 over the first 1000 steps
    lam = min(0.5, step / 2000)
    prompt_ids, ground_truth_ids = next(data_iter)
    loss = gkd_step(student, teacher, prompt_ids, ground_truth_ids,
                    lam=lam, divergence='jsd')
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    if step % 100 == 0:
        print(f"Step {step}, λ={lam:.2f}, loss={loss.item():.4f}")
```
Let's see quantitatively why λ=0.5 outperforms λ=0. Consider a student with ε=0.08 (8% per-token error on teacher-distributed inputs). After 1000 steps of training:
λ=0 (off-policy): The student only sees perfect prefixes. Its per-token error on student-distributed inputs stays at ~0.08 or worse (it never practices recovery). On a T=30 reasoning chain: error bound = 0.08 × 900/2 = 36. Complete failure.
λ=0.5 (mixed): Half the time, the student sees its own imperfect prefixes. It learns to handle error states. Its effective ε on student-distributed inputs drops to ~0.04 (half the original, because it's now trained on the right distribution). On a T=30 chain: error bound = 0.04 × 30 = 1.2. About one mistake per chain — recoverable.
The numbers aren't precise (real LLMs are more complex), but the mechanism is: training on your own distribution makes you better at handling your own distribution. This is the DAgger theorem in action.
[Interactive figure: a student distribution (teal curve) training toward a bimodal teacher (white curve), with an adjustable λ controlling the sampling policy. Higher λ means more on-policy.]
The paper tested all combinations of sampling policy and divergence on summarization (XSum) and translation (WMT) tasks. Key findings:
| Configuration | XSum ROUGE-L | Improvement over off-policy FKL |
|---|---|---|
| λ=0, FKL (standard KD) | 32.1 | baseline |
| λ=0, RKL | 32.4 | +0.3 (divergence alone helps a little) |
| λ=0.5, FKL | 34.2 | +2.1 (sampling helps A LOT) |
| λ=0.5, JSD | 34.5 | +2.4 (best combination) |
| λ=1.0, JSD | 33.9 | +1.8 (fully on-policy slightly worse without warmup) |
The pattern is clear: changing λ from 0 to 0.5 gives +2.1 ROUGE-L. Changing divergence from FKL to JSD (same λ) gives only +0.3. The sampling distribution is the dominant factor.
MiniLLM (Gu et al., 2024) went all-in on reverse KL. Why? And why does this turn distillation into reinforcement learning?
Recall from Chapter 3: reverse KL is DKL(Pθ || PT) = Ey~Pθ[log Pθ(y) - log PT(y)]. The expectation is over the STUDENT's own distribution. To compute this loss, we must sample from the student — generate full sequences — then evaluate how well those sequences score under the teacher.
Here's the problem: sampling from a categorical distribution is not differentiable. When the student picks token yt from its softmax output, that discrete choice has no gradient. We can't backpropagate through "I chose token #4527 from a 32,000-word vocabulary." The argmax (or sample) operation is a wall.
The solution is ancient RL wisdom: the REINFORCE trick (Williams, 1992). Instead of backpropagating through the sample, we treat the sampled sequence as given and adjust the probability of producing it based on how good it turned out to be. Good sequence? Increase its probability. Bad sequence? Decrease it. The "goodness" is the reward.
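The trick is easy to verify on a tiny categorical distribution, where the exact gradient is available in closed form. In the sketch below the logits and rewards are made-up numbers. The identity used is the standard one for a softmax: the gradient of log p(y) with respect to the logits is onehot(y) minus p.

```python
import numpy as np

rng = np.random.default_rng(0)
logits = np.array([0.2, -0.1, 0.5])
reward = np.array([1.0, 0.0, -1.0])   # made-up reward for each of 3 tokens

# Softmax probabilities
p = np.exp(logits - logits.max())
p /= p.sum()

# Exact gradient of E_{y~p}[reward(y)] with respect to the logits
exact = p * (reward - p @ reward)

# REINFORCE: sample tokens, then average reward(y) * grad log p(y).
# No gradient ever flows through the sampling step itself.
n_samples = 200_000
samples = rng.choice(3, size=n_samples, p=p)
onehot = np.eye(3)[samples]
estimate = (reward[samples, None] * (onehot - p)).mean(axis=0)

print("exact gradient:    ", np.round(exact, 4))
print("REINFORCE estimate:", np.round(estimate, 4))
```

The two vectors agree up to sampling noise, which is the whole point: the estimator only needs samples and their rewards, not a differentiable sampling operation.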
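Let's derive this carefully. Start with reverse KL over full sequences:

$$\min_\theta \; D_{\mathrm{KL}}(P_\theta \,\|\, P_T) = \min_\theta \; \mathbb{E}_{y \sim P_\theta}\big[ \log P_\theta(y) - \log P_T(y) \big]$$

Negate (turn min into max) and rearrange:

$$\max_\theta \; \mathbb{E}_{y \sim P_\theta}\big[ \log P_T(y) - \log P_\theta(y) \big]$$

Split into two terms:

$$\max_\theta \; \underbrace{\mathbb{E}_{y \sim P_\theta}\big[ \log P_T(y) \big]}_{\text{reward}} \; + \; \underbrace{H(P_\theta)}_{\text{entropy}}, \qquad H(P_\theta) = -\mathbb{E}_{y \sim P_\theta}\big[ \log P_\theta(y) \big]$$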
The first term is the expected teacher log-probability — this is the reward. The second term is the student's own entropy — a bonus for exploration, preventing the student from collapsing to a single deterministic output. This is EXACTLY the entropy-regularized RL objective (MaxEnt RL, Ziebart et al., 2008).
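The decomposition is an exact identity, which a few lines of arithmetic confirm. The distributions below are the three-token example used for the FKL/RKL comparison earlier, treated here as stand-ins for full-sequence distributions.

```python
import math

p_student = [0.6, 0.3, 0.1]
p_teacher = [0.5, 0.4, 0.1]

# Reverse KL: expectation under the STUDENT of log P_student - log P_teacher
reverse_kl = sum(q * math.log(q / p) for q, p in zip(p_student, p_teacher))

# The two pieces of the RL objective
expected_reward = sum(q * math.log(p) for q, p in zip(p_student, p_teacher))
entropy = -sum(q * math.log(q) for q in p_student)

print(f"-RKL             = {-reverse_kl:.6f}")
print(f"reward + entropy = {expected_reward + entropy:.6f}")
```

Maximizing reward plus entropy is therefore the same optimization problem as minimizing reverse KL, not an approximation of it.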
G-OPD (Generalized On-Policy Distillation) takes this further with reward extrapolation. If log PT(y) is the base reward, we can ADD an external signal to push the student BEYOND the teacher:
Where Rext might be a math verifier (is the proof correct?), a code executor (does it pass tests?), or a human preference model. This is how you get a 7B student that's BETTER than the 70B teacher on specific tasks — impossible with forward KL alone, because FKL can only make the student MATCH the teacher, never exceed it.
The student generates the token "therefore" at position t with probability pθ("therefore") = 0.3. The teacher assigns pT("therefore") = 0.7. Let's compute the gradient:
Step 1 (per-token reward):

$$r_t = \log p_T(y_t) - \log p_\theta(y_t) = \log 0.7 - \log 0.3 = -0.3567 + 1.2040 = 0.8473$$
Positive reward! The teacher likes this token much more than the student's current belief suggests. The student should increase its probability here.
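Step 2 (REINFORCE gradient):

$$\nabla_\theta J = r_t \cdot \nabla_\theta \log p_\theta(y_t)$$

The score function ∇ log Pθ for a categorical distribution at the chosen token has magnitude 1/Pθ = 1/0.3 = 3.33 when taken with respect to the token's probability (with respect to the token's logit it would be 1 - Pθ = 0.7). Using the 1/Pθ scaling, the effective gradient magnitude is:

$$|r_t| \cdot \frac{1}{p_\theta(y_t)} = 0.8473 \times 3.33 \approx 2.82$$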
This is a strong positive gradient — it pushes the student hard toward "therefore." Now consider what happens at equilibrium: if pθ = pT = 0.7, then rt = log(0.7) - log(0.7) = 0. Zero reward, zero gradient. The student has matched the teacher on this token — no more signal. Beautiful.
What if the student generates a token the teacher hates? Say pθ("however") = 0.4 but pT("however") = 0.02:
$$r_t = \log 0.02 - \log 0.4 = -3.912 - (-0.916) = -2.996$$
Massive negative reward. Gradient magnitude: |-2.996 × (1/0.4)| ≈ 7.49. The student gets punished HARD for generating tokens the teacher considers nearly impossible. This is the mode-seeking behavior of reverse KL in action: it ruthlessly suppresses tokens outside the teacher's support.
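The two worked examples, plus the equilibrium case, can be reproduced in a few lines. This is only a sketch: the helper names are made up for illustration, and the "magnitude" follows the text's 1/Pθ convention rather than a full backward pass through a network.

```python
import math


def per_token_reward(p_student, p_teacher):
    """r_t = log P_T(y_t) - log P_theta(y_t)."""
    return math.log(p_teacher) - math.log(p_student)


def gradient_magnitude(p_student, p_teacher):
    """|r_t| * (1 / P_theta), the convention used in the worked example."""
    return abs(per_token_reward(p_student, p_teacher)) / p_student


cases = [
    ("therefore", 0.3, 0.7),   # teacher likes it more than the student does
    ("however", 0.4, 0.02),    # teacher considers it nearly impossible
    ("matched", 0.7, 0.7),     # equilibrium: student already agrees
]
for token, p_s, p_t in cases:
    r = per_token_reward(p_s, p_t)
    g = gradient_magnitude(p_s, p_t)
    print(f"{token:<10} reward={r:+.3f}  |gradient|={g:.3f}")
```

The sign of the reward decides whether the token is reinforced or suppressed, and the signal vanishes once the two probabilities match.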
```python
import torch
import torch.nn.functional as F


def token_log_probs(logits, sequences):
    """Log-probs of the realized tokens.

    Logits at position t predict token t+1, so drop the last logit
    and the first token before gathering.
    """
    log_probs = F.log_softmax(logits[:, :-1], dim=-1)
    targets = sequences[:, 1:].unsqueeze(-1)
    return log_probs, log_probs.gather(-1, targets).squeeze(-1)


def minillm_loss(student, teacher, prompts, max_len=128, entropy_coef=0.01):
    """MiniLLM: Reverse KL via REINFORCE.

    1. Generate full sequences from student (on-policy sampling)
    2. Score each token with teacher log-prob (reward)
    3. Apply policy gradient (REINFORCE)

    Assumes student(x) and teacher(x) return raw logits of shape
    (batch, seq_len, vocab).
    """
    # Step 1: Generate from student, with NO gradient through sampling
    with torch.no_grad():
        sequences = student.generate(prompts, max_length=max_len, do_sample=True)
        # Student log-probs at sampling time (part of the reward, not the gradient)
        _, old_lp = token_log_probs(student(sequences), sequences)

    # Step 2: Get teacher log-probs (reward signal)
    with torch.no_grad():
        _, teacher_lp = token_log_probs(teacher(sequences), sequences)

    # Per-token reward: how much teacher likes this vs student's prior
    rewards = teacher_lp - old_lp  # r_t = log P_T(y_t) - log P_θ(y_t)

    # Step 3: Recompute student log-probs WITH gradient
    new_log_probs, log_pi = token_log_probs(student(sequences), sequences)

    # REINFORCE: maximize E[reward * log_prob]
    policy_loss = -(rewards.detach() * log_pi).mean()

    # Entropy bonus: prevent collapse
    probs = new_log_probs.exp()
    entropy = -(probs * new_log_probs).sum(dim=-1).mean()

    return policy_loss - entropy_coef * entropy
```
Top: teacher log-prob (reward landscape). Bottom: student policy shifting toward high-reward regions. Toggle "Reward Extrapolation" to see the student surpass the teacher.
The practical implication: reward extrapolation lets the student exceed the teacher. Forward KL cannot do this, since the best it can do is match the teacher. With reverse KL plus an external verifier, a 7B model can beat a 70B model on verifiable tasks such as math, provided it explores enough for the verifier to find and reward correct solutions.
What if you don't HAVE a bigger teacher? What if your model IS the biggest model? This is the situation frontier labs face: GPT-5 can't distill from GPT-6 because GPT-6 doesn't exist yet. The model must improve from itself.
Self-distillation sounds paradoxical — you can't create knowledge from nothing. But you CAN reorganize existing knowledge. Three distinct mechanisms make this possible, each exploiting a different kind of "hidden signal" within the model's own capabilities.
Give the model extra information at training time that's unavailable at test time. Concretely: provide the ground-truth answer as a "hint" during training. The model with hints generates better reasoning traces. Then distill the hint-conditioned model into a hint-free model:
Example: Math problem "What is ∫ sin(x) dx?" The privileged version sees "[Answer: -cos(x) + C]" as context and produces a clean step-by-step derivation. The unprivileged version learns to produce that same derivation WITHOUT seeing the answer — it learns the reasoning PATTERN, not the specific answer.
This works because knowing the answer simplifies reasoning enormously. When you know the destination, you don't explore dead ends. The privileged model produces cleaner, shorter, more logical traces — and the unprivileged model learns to mimic that efficiency.
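The data flow can be sketched as a pair of prompts per problem: one that carries the hint and one that does not. The `[Answer: ...]` format follows the example above; the function name and the instruction string are assumptions for illustration, and the actual distillation loss between the two conditionals is left out.

```python
def build_privileged_pair(question, answer):
    """Return the hint-conditioned (teacher) and hint-free (student) prompts."""
    instruction = "Show the derivation step by step."
    teacher_prompt = f"{question}\n[Answer: {answer}]\n{instruction}"
    student_prompt = f"{question}\n{instruction}"
    return teacher_prompt, student_prompt


teacher_prompt, student_prompt = build_privileged_pair(
    "What is ∫ sin(x) dx?", "-cos(x) + C"
)
print(teacher_prompt)
print("---")
print(student_prompt)
```

The same model generates a trace from `teacher_prompt`; that trace then becomes the training target for `student_prompt`, so the hint never appears in the inputs the model sees at test time.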
SPIN (Self-Play fINe-tuning) pits the model against its own previous version. At each round, the model's own generations are "rejected" and high-quality reference outputs (the ground-truth training data) are "chosen," while a frozen copy of the previous round's model serves as the reference policy. DPO (Direct Preference Optimization) then pushes the model to prefer reference-quality outputs over its own current outputs.
No teacher, no reference model. The model generates many candidate solutions to a problem, a binary verifier grades them (correct/incorrect), and GRPO (Group Relative Policy Optimization) uses the good ones as positive signal. Pure exploration + automated grading, no human in the loop.
This is the core of the RL recipe behind DeepSeek-R1: sample a group of candidate solutions for each math problem (say, 64), check which are correct, and reward the correct ones. The model learns which of its OWN strategies work, purely from trial and error.
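The "group relative" part of GRPO can be sketched in a few lines: each sample's reward is normalized against the mean and standard deviation of its own group, so no learned value function is needed. The verdicts below are made up, and using the population standard deviation with a small epsilon is an implementation assumption.

```python
import statistics


def group_relative_advantages(rewards, eps=1e-8):
    """Normalize rewards within one group: (r - mean) / (std + eps)."""
    mean = statistics.fmean(rewards)
    std = statistics.pstdev(rewards)
    return [(r - mean) / (std + eps) for r in rewards]


# Verifier verdicts for 8 sampled solutions to one problem (1 = correct).
verdicts = [1, 0, 0, 1, 0, 0, 0, 1]
advantages = group_relative_advantages(verdicts)
print([round(a, 2) for a in advantages])

# If every sample gets the same verdict, there is nothing to learn from.
print(group_relative_advantages([1, 1, 1, 1]))
```

Correct samples receive positive advantages and incorrect ones negative; a group that is all-correct or all-wrong contributes no gradient, which is why problems near the edge of the model's ability are the most useful.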
Current model generates "2+2=5" (rejected). Reference has "2+2=4" (chosen). With β=0.1:
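The DPO logit is:
$$h = \beta\Big[\big(\log P_\theta(y_w) - \log P_{\text{ref}}(y_w)\big) - \big(\log P_\theta(y_l) - \log P_{\text{ref}}(y_l)\big)\Big]$$
where y_w is the chosen sequence and y_l the rejected one. In this example it evaluates to h = -0.15, negative because the current model still favors its own wrong answer.
Loss = -log σ(-0.15) = -log(0.4626) = 0.771. This gradient pushes the model to increase P("2+2=4") relative to P("2+2=5"). After the update, the model prefers the correct answer slightly more. Repeat for many examples across many rounds.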
After 3-5 rounds, the model's outputs become indistinguishable from the reference's. At that point, "chosen" and "rejected" look the same to DPO — the reward signal is zero. Training has saturated.
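The arithmetic of the worked example can be checked with a scalar version of the DPO loss. The four sequence log-probabilities below are made-up values, chosen only so that the logit comes out to -0.15 as in the text; they are not from the original example.

```python
import math


def dpo_loss(logp_chosen, logp_rejected, ref_logp_chosen, ref_logp_rejected, beta=0.1):
    """Return (logit, loss) for one chosen/rejected pair."""
    logit = beta * (
        (logp_chosen - ref_logp_chosen) - (logp_rejected - ref_logp_rejected)
    )
    loss = math.log1p(math.exp(-logit))  # equals -log(sigmoid(logit))
    return logit, loss


# Hypothetical log-probs: the policy under-weights "2+2=4" and over-weights "2+2=5".
logit, loss = dpo_loss(
    logp_chosen=-5.0, logp_rejected=-3.5,
    ref_logp_chosen=-4.0, ref_logp_rejected=-4.0,
)
print(f"logit = {logit:.2f}, loss = {loss:.3f}")

# When policy and reference agree on both sequences, the logit is 0.
flat_logit, flat_loss = dpo_loss(-4.0, -4.0, -4.0, -4.0)
print(f"logit = {flat_logit:.2f}, loss = {flat_loss:.3f}")
```

At a logit of zero the loss sits at log 2, and once chosen and rejected sequences are indistinguishable their gradients cancel, which is the saturation described above.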
```python
import torch
import torch.nn.functional as F


def spin_round(model, ref_model, prompts, ground_truth, beta=0.1):
    """One round of SPIN (Self-Play fINe-tuning).

    model:        current policy to improve
    ref_model:    frozen previous version
    ground_truth: high-quality outputs (chosen)

    Assumes model.generate(...) returns sequences and model.log_prob(seq)
    returns one sequence-level log-prob per example.
    """
    # Generate "rejected" from current model
    with torch.no_grad():
        rejected = model.generate(prompts, temperature=0.7)

    # Log-probs under the current policy (WITH gradient)
    log_p_chosen = model.log_prob(ground_truth)
    log_p_rejected = model.log_prob(rejected)

    # Log-probs under the frozen reference (no gradient)
    with torch.no_grad():
        log_ref_chosen = ref_model.log_prob(ground_truth)
        log_ref_rejected = ref_model.log_prob(rejected)

    # DPO loss
    logits = beta * (
        (log_p_chosen - log_ref_chosen)
        - (log_p_rejected - log_ref_rejected)
    )
    loss = -F.logsigmoid(logits).mean()
    return loss


def check_saturation(model, ref_model, test_prompts, threshold=0.01):
    """Has SPIN converged? Check if model ≈ reference."""
    with torch.no_grad():
        outputs = model.generate(test_prompts, temperature=0.5)
        # Monte Carlo estimate of KL(model || ref) on the model's own samples
        kl = (model.log_prob(outputs) - ref_model.log_prob(outputs)).mean()
    print(f"KL(model || ref) = {kl.item():.4f}")
    return abs(kl.item()) < threshold  # True → saturated, stop training
```
Watch the current model (orange) converge toward the reference (teal). After ~6 rounds, the signal saturates.
Everything we've discussed sounds great in theory. In practice, on-policy training is a knife's edge. The data distribution changes with every gradient step. The student generates new text, which changes the loss landscape, which changes the gradients, which changes the student, which changes the generated text. It's a feedback loop that can spiral into instability or mode collapse in minutes.
Three pillars make on-policy distillation stable in production systems. Without them, you burn GPU-hours producing garbage.
Not all tokens are equally informative. Consider three cases:
| Teacher Confidence | Student Agreement | Action | Why |
|---|---|---|---|
| High (entropy < 0.2) | Low (divergence > 0.5) | TRAIN | Teacher knows the answer, student doesn't — maximum learning signal |
| High | High (divergence < 0.1) | Skip | Both already agree — zero information gain, wasted compute |
| Low (entropy > 1.0) | Any | Skip | Teacher itself is uncertain — its "guidance" here is noise, not signal |
The importance score for each token position combines teacher confidence with student confusion:
$$w_t = \Big(1 - \frac{H_T(t)}{\log V}\Big) \cdot D_{\mathrm{KL}}(t)$$
Where H is the teacher's entropy at position t, V is vocabulary size (so H/log V normalizes to [0,1]), and DKL measures how different student is from teacher. High weight = teacher confident AND student confused. This is the "golden quadrant" — maximum information gain per gradient step.
Start with easy sequences and ramp difficulty. "Easy" = short sequences, simple prompts, high teacher confidence. "Hard" = long sequences, multi-step reasoning, ambiguous prompts. A beta-kernel sampler selects training examples from the frontier — sequences just barely beyond the student's current ability. Too easy = no learning. Too hard = noise.
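One possible reading of a "beta-kernel sampler" is sketched below: every example gets a difficulty score in (0, 1), and sampling weights follow an unnormalized Beta density whose mode sits at the student's current frontier. The parameterization, the `concentration` knob, and the difficulty scores are all assumptions for illustration; the text does not specify them.

```python
import random


def beta_kernel_weights(difficulties, frontier, concentration=20.0):
    """Sampling weights peaked at `frontier` (Beta density with that mode)."""
    a = 1.0 + concentration * frontier
    b = 1.0 + concentration * (1.0 - frontier)
    raw = [d ** (a - 1.0) * (1.0 - d) ** (b - 1.0) for d in difficulties]
    total = sum(raw)
    return [w / total for w in raw]


# Ten buckets of examples, from easy (0.05) to hard (0.95).
difficulties = [0.05 + 0.10 * i for i in range(10)]
weights = beta_kernel_weights(difficulties, frontier=0.35)

rng = random.Random(0)  # fixed seed so the sketch is repeatable
batch = rng.choices(range(len(difficulties)), weights=weights, k=8)

for d, w in zip(difficulties, weights):
    print(f"difficulty={d:.2f}  weight={w:.3f}")
print("sampled buckets:", batch)
```

Raising `frontier` as the student improves moves the peak toward harder examples, while `concentration` controls how tightly sampling hugs the frontier.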
If the student generates garbage (repetitive loops like "the the the the", incoherent word salad, or degenerate patterns), the teacher's feedback on that garbage is NOISE. The teacher was trained on coherent text — asking it "what comes after 'the the the the the'?" produces meaningless output. Training on that meaningless output actively harms the student.
Solution: compute the perplexity of the student's generated prefix. If PPL > threshold (typically 100-500), discard that prefix and resample. Only train on prefixes that are at least minimally coherent.
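A minimal sketch of the rejection rule: perplexity is the exponential of the mean negative log-likelihood per token, and prefixes above the threshold are dropped. The log-probabilities below are made up, which model scores the prefix is an assumption left to the caller, and the threshold of 200 is simply one value inside the 100-500 range quoted above.

```python
import math


def perplexity(token_log_probs):
    """exp of the mean negative log-likelihood per token."""
    return math.exp(-sum(token_log_probs) / len(token_log_probs))


def keep_prefix(token_log_probs, max_ppl=200.0):
    """True if the prefix is coherent enough to be worth a teacher query."""
    return perplexity(token_log_probs) <= max_ppl


# Hypothetical per-token log-probs from the scoring model.
coherent_prefix = [-1.2, -0.8, -2.0, -1.5]
degenerate_prefix = [-6.5, -7.0, -5.8, -6.9]

for name, lps in [("coherent", coherent_prefix), ("degenerate", degenerate_prefix)]:
    print(f"{name:<11} PPL={perplexity(lps):8.1f}  keep={keep_prefix(lps)}")
```

Running the filter before the teacher forward pass avoids spending teacher compute on prefixes that would be discarded anyway.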
```python
import numpy as np

V = 32000          # vocabulary size
log_V = np.log(V)  # ≈ 10.37, maximum possible entropy

# Each token: (word, teacher_entropy, KL_divergence)
tokens = [
    ("the",         0.08, 0.03),  # function word, both agree
    ("integration", 0.15, 0.82),  # content word, student confused
    ("um",          2.30, 0.45),  # filler, teacher uncertain too
    ("equals",      0.05, 0.91),  # math token, student very wrong
    ("maybe",       1.80, 0.12),  # hedge word, teacher unsure
]
threshold = 0.30

print(f"{'Token':<12} {'H_T':<6} {'KL':<6} {'Score':<8} {'Decision'}")
print("-" * 50)
for word, h_t, kl in tokens:
    confidence = 1.0 - h_t / log_V  # how sure the teacher is
    score = confidence * kl         # importance weight
    action = "TRAIN" if score >= threshold else "skip"
    print(f"{word:<12} {h_t:<6.2f} {kl:<6.2f} {score:<8.3f} {action}")

# Output (arrow annotations added by hand):
# Token        H_T    KL     Score    Decision
# --------------------------------------------------
# the          0.08   0.03   0.030    skip
# integration  0.15   0.82   0.808    TRAIN  ← golden!
# um           2.30   0.45   0.350    TRAIN  (borderline)
# equals       0.05   0.91   0.906    TRAIN  ← golden!
# maybe        1.80   0.12   0.099    skip
```
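Tokens "integration" and "equals" are in the golden quadrant: teacher is confident, student is wrong. These are where most of the learning signal lives. "the" and "maybe" are both worthless to train on: "the" because everyone agrees, "maybe" because nobody knows. Note that "um" squeaks past the threshold even though the teacher is uncertain, because dividing by log V ≈ 10.37 barely penalizes an entropy of 2.30; a hard entropy cutoff like the one in the table above would skip it.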
The compute savings are enormous:
| Approach | Relative Cost | Why |
|---|---|---|
| Naive on-policy (all tokens) | 3-5× baseline | Generate + teacher forward + student backward on every token |
| + Token filtering (skip ~60%) | 1.5× | Only backpropagate on high-importance tokens |
| + Prefix caching & async generation | 0.8× | Overlap generation with gradient computation in pipeline |
| + Flawed prefix rejection | 0.6× | Don't waste teacher queries on degenerate student outputs |
Each token colored by importance score. Adjust threshold to see which tokens get trained on vs. skipped. Below: the 2D scatter — golden quadrant (low teacher entropy, high divergence) highlighted.
Time to run the full pipeline yourself. This simulator lets you compare every distillation method we've covered on a concrete task: matching a bimodal teacher distribution over varying sequence lengths. The teacher has two equally valid modes (like a translation task with two correct phrasings). The student starts as a broad, uncertain distribution and must learn to match both peaks.
Watch the key behaviors emerge:
Top: teacher (teal) vs student (orange) distributions. Bottom: KL divergence over training steps. Select method, set T, and train.
| Method | Error Scaling | Mode Behavior | Failure Pattern |
|---|---|---|---|
| Off-Policy | O(εT²) | Mode-covering | Catastrophic at long T |
| GKD λ=0.5 | O(εT) | Balanced (both peaks) | Slow initial convergence |
| GKD λ=1.0 | O(εT) | Balanced (both peaks) | Noisy early, best final quality |
| MiniLLM (RKL) | O(εT) | Mode-seeking (one peak) | Drops valid modes entirely |
| Self-Play (SPIN) | O(εT) → flat | Matching (reference-bounded) | Saturates at ~30 steps |
The showcase demonstrates why production systems use hybrid pipelines: off-policy warmup for early stability, GKD for the main training phase, and optional RKL + verifier for final polishing on deterministic tasks (math, code) where mode-seeking is desirable.
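The mode-covering versus mode-seeking contrast in the table can be reproduced without any training loop. The sketch below fits a single Gaussian student to a bimodal teacher by brute-force grid search, once under forward KL and once under reverse KL. The teacher's modes at ±2 with width 0.5, the grid ranges, and the single-Gaussian student are all illustrative assumptions, not the simulator's actual settings.

```python
import numpy as np

x = np.linspace(-6.0, 6.0, 1201)
dx = x[1] - x[0]


def gaussian(mu, sigma):
    """Gaussian density evaluated on the shared grid x."""
    return np.exp(-0.5 * ((x - mu) / sigma) ** 2) / (sigma * np.sqrt(2.0 * np.pi))


def kl(p, q, eps=1e-300):
    """Discretized KL(p || q) on the grid; eps guards against log(0)."""
    return float(np.sum(p * (np.log(p + eps) - np.log(q + eps))) * dx)


# Bimodal teacher: two equally valid modes.
teacher = 0.5 * gaussian(-2.0, 0.5) + 0.5 * gaussian(2.0, 0.5)

# Candidate single-Gaussian students (mu, sigma).
candidates = [(mu, sigma)
              for mu in np.linspace(-3.0, 3.0, 61)
              for sigma in np.linspace(0.3, 3.0, 28)]

forward_fit = min(candidates, key=lambda c: kl(teacher, gaussian(*c)))
reverse_fit = min(candidates, key=lambda c: kl(gaussian(*c), teacher))

print(f"forward KL fit: mu={forward_fit[0]:+.2f}, sigma={forward_fit[1]:.2f}")
print(f"reverse KL fit: mu={reverse_fit[0]:+.2f}, sigma={reverse_fit[1]:.2f}")
```

The forward-KL student centers itself between the peaks with a wide spread (covering both, matching neither), while the reverse-KL student locks onto one peak and ignores the other.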
Everything you've learned, compressed for interview prep, implementation, and further study. This chapter tests whether you can APPLY the concepts, not just recognize them.
Derive the DAgger bound from scratch. Click each "???" to reveal:
Setup: Student has per-step error ε under training distribution. At step t:
1. P(at least one mistake in steps 1..t) ≤ ???
2. Once off-distribution, worst-case per-step error = ???
3. Expected error at step t = ???
4. Total over T steps: $\sum_{t=1}^{T} t\varepsilon = \varepsilon \cdot$ ???
5. With on-policy (DAgger): error at step t = ???
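Once the blanks are filled in, the two bounds can be compared numerically. The sketch below uses one standard way to model the argument, stated here as an assumption: off-policy, the chance of being off-distribution by step t is at most t·ε (capped at 1) and an off-distribution step is treated as a mistake; on-policy, every step errs with probability ε.

```python
def off_policy_error(eps, T):
    """Expected mistakes when step t errs with probability min(1, t * eps)."""
    return sum(min(1.0, t * eps) for t in range(1, T + 1))


def on_policy_error(eps, T):
    """Expected mistakes when every step errs with probability eps."""
    return eps * T


eps = 0.001
for T in (10, 100, 1000):
    off = off_policy_error(eps, T)
    on = on_policy_error(eps, T)
    print(f"T={T:>4}  off-policy={off:8.3f}  on-policy={on:6.3f}  ratio={off / on:6.1f}")
```

The ratio between the two grows roughly linearly with T, which is the practical meaning of quadratic versus linear error scaling.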
You're distilling a 70B reasoning model into 7B for deployment. Budget: 1000 GPU-hours. Design the pipeline (click "?" to reveal suggested answers):
| Stage | Budget % | Method | λ | Rationale |
|---|---|---|---|---|
| Warmup | ? | ? | ? | ? |
| Main | ? | ? | ? | ? |
| RL Polish | ? | ? | ? | ? |
| Final | ? | ? | ? | ? |
Three scenarios where OPD breaks. For each: identify the failure mode and the fix (click to reveal).
Scenario 1: Student generates "the the the the the..." for 50 tokens. You query the teacher to score it.
Failure + Fix: Click to reveal
Scenario 2: Self-play (SPIN) metrics are flat after round 3. No improvement.
Failure + Fix: Click to reveal
Scenario 3: Student trained with FKL on math problems outputs "7" when valid answers are "4" and "12".
Failure + Fix: Click to reveal
| Concept | Key Equation | When to Use |
|---|---|---|
| Off-policy loss | $\mathbb{E}_{y \sim P_T}[\mathrm{KL}(P_T \,\Vert\, P_\theta)]$ | Warmup; short sequences only |
| On-policy loss | $\mathbb{E}_{y \sim P_\theta}[\mathrm{KL}(P_T \,\Vert\, P_\theta)]$ | Main training; long sequences |
| GKD mix | $\lambda \cdot L_{\text{on}} + (1-\lambda) \cdot L_{\text{off}}$ | Always; interpolate via λ |
| DAgger off-policy | $O(\varepsilon T^2)$ | Predicting failure at length T |
| DAgger on-policy | $O(\varepsilon T)$ | Linear error guarantee |
| RL equivalence | $\max \; \mathbb{E}_{P_\theta}[\log P_T] + H(P_\theta)$ | When using reverse KL (MiniLLM) |
| Token importance | $(1 - H_T/\log V) \times \mathrm{KL}$ | Filtering uninformative tokens |
| Reward extrap. | $R = \log P_T + \beta \cdot R_{\text{ext}}$ | Exceeding teacher quality |
"What I cannot create, I do not understand." — Richard Feynman. In distillation: what the student cannot generate on its own, it has not truly learned.