The Complete Beginner's Path

On-Policy Distillation
for LLMs

Why small models can't reason by copying — and the training loop that fixes it.

Prerequisites: Autoregressive LMs + Probability basics. That's it.
10 Chapters · 8+ Simulations · 0 Assumed Knowledge

Chapter 0: Why Practice?

You clone GPT-4's homework answers perfectly. Your student model scores 95% on predicting the next word. Then someone asks it to prove that √2 is irrational, and it produces: "Assume √2 = p/q. Then 2 = p²/q². Therefore p² = 2q². So p is even. Wait, actually let me try a different approach... The proof is left as an exercise."

It sounds like math. It is not math. The model learned to imitate the surface pattern of reasoning — hedging, self-correction phrases, formal notation — without learning the actual logical chain that connects assumptions to conclusions.

This is the central failure of naive knowledge distillation: training a small "student" model by having it predict the same tokens as a large "teacher" model. The student learns what the right answer looks like, but not how to produce it step-by-step when things go slightly off-script.

The problem is subtle but devastating. During training, the student sees perfect teacher-generated text at every position. It learns: "given this perfect prefix, predict the next token." But during inference, the student generates its own prefix. The moment it makes a single mistake — one wrong token — it's now in territory it has never seen during training. And from that point forward, every subsequent token is generated from an unfamiliar context.

Think of it this way: off-policy training is like learning to drive by watching dashcam footage of perfect driving. You know what a good lane change looks like. But the first time you drift slightly left, you've never practiced recovery — you've only ever seen the view from perfectly centered. Panic ensues.

On-policy training is like actually getting behind the wheel with an instructor. You make mistakes. You drift. And the instructor says "ok, from HERE, do THIS." You learn what recovery looks like from error states, not just from perfect states. This is the key insight of on-policy distillation.

The core failure: A student trained only on teacher outputs has 95% per-token accuracy. Sounds great. But for a 20-step reasoning chain, the probability of getting the entire sequence right is $0.95^{20} \approx 36\%$. For 50 steps, it is $0.95^{50} \approx 7.7\%$. The model can rarely complete a medium-length proof, because it has never practiced recovering from its own errors.
Simulation: Off-Policy vs On-Policy, Error Propagation. In copy mode (off-policy training), the student trains on perfect tokens but fails at inference. In practice mode (on-policy training), it trains on its own outputs and learns to recover.

The Numbers Don't Lie

Let's make this concrete with actual arithmetic. Suppose your student model achieves ε = 0.05 per-token error rate — meaning 95% accuracy at each individual position. That sounds excellent. But reasoning requires getting a sequence of tokens all correct.

For a sequence of length T, the probability of producing the entire sequence correctly is $(1 - \varepsilon)^T$. This assumes independent errors, which is optimistic. Let's compute this for various sequence lengths:

| Sequence length T | Success prob. $(1-\varepsilon)^T$ | Interpretation |
|---|---|---|
| 5 tokens | $0.95^{5} = 0.774$ (77%) | Short factual answer: mostly works |
| 10 tokens | $0.95^{10} = 0.599$ (60%) | Simple reasoning: roughly a coin flip |
| 20 tokens | $0.95^{20} = 0.358$ (36%) | Medium proof: usually fails |
| 50 tokens | $0.95^{50} = 0.077$ (7.7%) | Long derivation: almost always fails |
| 100 tokens | $0.95^{100} = 0.006$ (0.6%) | Complex reasoning: effectively impossible |

This is the compounding error problem. Even a tiny per-step error rate becomes catastrophic over long sequences. And the situation is actually worse than this table suggests, because errors aren't independent — once the model makes one error, it enters an unfamiliar state, making subsequent errors more likely.

Key insight: Knowledge distillation isn't about per-token accuracy. It's about trajectory-level success. A model that's 95% accurate per token but has never practiced error recovery is dramatically worse than a model that's 90% accurate per token but knows what to do when things go wrong.

Hand Computation: Success Probability

Let's trace through the math carefully for T=5, ε=0.05:

$$P(\text{success}) = (1 - \varepsilon)^T = (1 - 0.05)^5 = 0.95^5$$

Step by step: $0.95^1 = 0.95$. Then $0.95^2 = 0.95 \times 0.95 = 0.9025$. Then $0.95^3 = 0.9025 \times 0.95 = 0.8574$. Then $0.95^4 = 0.8574 \times 0.95 = 0.8145$. Finally $0.95^5 = 0.8145 \times 0.95 = 0.7738$.

So even with 95% per-step accuracy over just 5 steps, we've already lost 23% of our sequences. Now imagine T=50: we'd compute 0.95 multiplied by itself 50 times. The result is 0.0769 — only 7.7% of sequences survive intact.

From Arithmetic to Code

First, let's compute this manually, step by step:

python
# Manual computation — no libraries
epsilon = 0.05   # per-token error rate
T = 20            # sequence length

# Multiply (1 - epsilon) by itself T times
prob = 1.0
for step in range(1, T + 1):
    prob = prob * (1 - epsilon)
    if step in (1, 2, 5, 10, 20):  # print a few checkpoints only
        print(f"Step {step:2d}: P(all correct) = {prob:.4f}")

# Step  1: P(all correct) = 0.9500
# Step  2: P(all correct) = 0.9025
# Step  5: P(all correct) = 0.7738
# Step 10: P(all correct) = 0.5987
# Step 20: P(all correct) = 0.3585

Now the compact version:

python
# One-liner: success probability for any epsilon, T
success = (1 - 0.05) ** 20  # = 0.3585

# Sweep across sequence lengths
for T in [5, 10, 20, 50, 100]:
    p = (1 - 0.05) ** T
    print(f"T={T:3d}:  P(success) = {p:.4f}  ({p*100:.1f}%)")

And with numpy for vectorized computation:

python
import numpy as np

epsilons = np.array([0.01, 0.05, 0.10])
lengths = np.arange(1, 101)

# Shape: (3, 100) — each row is one epsilon, columns are T=1..100
success_grid = (1 - epsilons[:, None]) ** lengths[None, :]
# success_grid[1, 49] = P(success | epsilon=0.05, T=50) = 0.0769
The piano analogy: Off-policy distillation is learning piano by watching YouTube. You can visualize the fingering perfectly. But the first time your finger slips to the wrong key, you freeze — you've never felt a wrong note before, never practiced moving FROM a wrong position TO the next correct one. On-policy training is actual practice: you hit wrong notes, and your teacher shows you how to recover from each specific mistake.
A model trained only on teacher outputs encounters its own error at step 4 of a 20-step proof. Why can't it recover?

Chapter 1: The Autoregressive Trap

Chapter 0 showed that something goes wrong. Now let's be precise about what goes wrong, mathematically. The training loop looks innocent enough — but it hides a devastating assumption that only reveals itself at inference time.

Here's the standard token-level knowledge distillation objective. We have a teacher model $p_T$ and a student model $p_\theta$. At each position t in a sequence, we minimize the KL divergence between the teacher's distribution and the student's distribution:

$$\mathcal{L}_{\text{off-policy}} = \sum_{t=1}^{T} D_{\mathrm{KL}}\big(p_T(\cdot \mid x, y_{<t}) \,\Vert\, p_\theta(\cdot \mid x, y_{<t})\big)$$

Read that carefully. The prefix y<t — all the tokens before position t — is the teacher's sequence. The student is being asked: "given this perfect prefix written by the teacher, what should the next token be?" And it learns to answer that question very well.

But at inference time, the student doesn't get the teacher's prefix; it generates its own. The on-policy objective trains on exactly those self-generated prefixes:

$$\mathcal{L}_{\text{on-policy}} = \sum_{t=1}^{T} D_{\mathrm{KL}}\big(p_T(\cdot \mid x, \hat{y}_{<t}) \,\Vert\, p_\theta(\cdot \mid x, \hat{y}_{<t})\big)$$

where $\hat{y}_{<t} \sim p_\theta$, meaning the prefix is sampled from the student's own distribution. This is the train/test mismatch. During off-policy training, the student always conditions on clean, teacher-generated context. During inference, it conditions on its own noisy, imperfect context. The distribution of prefixes at test time is fundamentally different from the one at training time.

This mismatch has a name in the sequential decision-making literature: distribution shift (also called covariate shift in the supervised learning context). The input distribution (prefixes) changes between training and testing. And the shift compounds — each error makes future inputs even more out-of-distribution.

The key distinction: Off-policy = train on teacher's prefixes. On-policy = train on student's own prefixes. The ONLY difference is who generates y<t. But this single difference determines whether errors compound quadratically or linearly (as we'll prove in Chapter 2).
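A toy sketch shows that the two objectives weight contexts differently. The "language models" below are hypothetical bigram tables over a 3-token vocabulary, so each next-token distribution depends only on the previous token. That is a big simplification of a real LM. The teacher almost always returns to token 0. The student drifts into token 1 and lingers there, which is exactly where it matches the teacher worst. The sketch computes the expected per-position forward KL under prefixes written by the teacher (off-policy) and by the student (on-policy). It propagates the Markov chain exactly rather than sampling.

```python
import numpy as np

# Hypothetical toy "language models": bigram tables over a 3-token vocabulary.
# Row i is the next-token distribution when the previous token is i.
teacher = np.array([[0.9, 0.05, 0.05],
                    [0.8, 0.10, 0.10],
                    [0.8, 0.10, 0.10]])
student = np.array([[0.6, 0.3, 0.1],
                    [0.3, 0.5, 0.2],
                    [0.7, 0.1, 0.2]])

# Forward KL(p_T || p_theta) for each possible context (previous token).
kl_per_context = np.sum(teacher * np.log(teacher / student), axis=1)


def expected_kd_loss(prefix_model, T, start=0):
    """Expected sum over T positions of KL(p_T || p_theta), where the context
    at each position comes from rolling out `prefix_model` from `start`."""
    dist = np.zeros(len(kl_per_context))
    dist[start] = 1.0  # position 1 conditions only on the prompt token
    total = 0.0
    for _ in range(T):
        total += dist @ kl_per_context  # average KL under the current contexts
        dist = dist @ prefix_model      # advance one token under prefix_model
    return float(total)


off_policy = expected_kd_loss(teacher, T=10)  # prefixes written by the teacher
on_policy = expected_kd_loss(student, T=10)   # prefixes written by the student
print(f"off-policy loss: {off_policy:.3f}")
print(f"on-policy loss:  {on_policy:.3f}")
```

At position 1 the two objectives agree, because only the prompt token is in context. After that, the student's rollouts spend far more time in context 1, where it is furthest from the teacher, so the on-policy loss comes out larger. The off-policy objective barely sees that context.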

Visualizing the Drift

Let's make this concrete with a picture. At position t=1, both training and inference are identical — there's no prefix yet, just the prompt x. At position t=2, training uses the teacher's first token (perfect), while inference uses the student's first token (possibly wrong). By position t=5, training still operates from a perfect 4-token prefix, while inference has been building on potentially corrupted tokens for 4 steps.

Simulation: Training vs Inference, The Prefix Problem. Left: during training, the student always sees perfect teacher tokens (green). Right: during inference, it sees its own tokens, some correct (green) and some wrong (orange), and drift accumulates token by token.

Worked Example: KL Divergence at One Position

Let's compute the actual loss at a single token position. Suppose we have a tiny vocabulary of 3 tokens: ["therefore", "so", "but"]. The teacher assigns probabilities:

pT = [0.7, 0.2, 0.1] — the teacher is fairly confident "therefore" comes next.

The student's current distribution is:

pθ = [0.3, 0.4, 0.3] — the student is confused, spreading probability too broadly.

The forward KL divergence DKL(pT || pθ) measures how surprised the student would be if reality followed the teacher:

$$D_{\mathrm{KL}}(p_T \,\Vert\, p_\theta) = \sum_i p_T(i) \log\frac{p_T(i)}{p_\theta(i)}$$

Let's compute each term:

Term 1 ("therefore"): 0.7 × log(0.7 / 0.3) = 0.7 × log(2.333) = 0.7 × 0.8473 = 0.5931

Term 2 ("so"): 0.2 × log(0.2 / 0.4) = 0.2 × log(0.5) = 0.2 × (-0.6931) = -0.1386

Term 3 ("but"): 0.1 × log(0.1 / 0.3) = 0.1 × log(0.333) = 0.1 × (-1.0986) = -0.1099

Total: $D_{\mathrm{KL}} = 0.5931 + (-0.1386) + (-0.1099) = 0.3446$ nats

Wait, negative terms? Yes! KL divergence always has a non-negative total (≥ 0 by Gibbs' inequality), but individual terms can be negative. The negative terms come from tokens where the student assigns MORE probability than the teacher. In our case, the student over-weights "so" and "but" while under-weighting "therefore." The net KL is positive (0.3446 nats), meaning the student still has significant room to improve.

Misconception alert: "KL divergence measures distance between distributions." Not quite — it's ASYMMETRIC. DKL(P||Q) ≠ DKL(Q||P). Forward KL (teacher || student) penalizes the student for putting LOW probability where the teacher puts HIGH probability. Reverse KL (student || teacher) penalizes the student for putting HIGH probability where the teacher puts LOW probability. This asymmetry is the entire basis of Chapter 3.

Implementing KL Divergence

First, from absolute scratch — no libraries:

python
import math

# Teacher and student distributions
p_teacher = [0.7, 0.2, 0.1]
p_student = [0.3, 0.4, 0.3]

# Manual KL: sum p_T * log(p_T / p_S)
kl = 0.0
for i in range(len(p_teacher)):
    if p_teacher[i] > 0:
        term = p_teacher[i] * math.log(p_teacher[i] / p_student[i])
        print(f"  Term {i}: {p_teacher[i]:.1f} * log({p_teacher[i]:.1f}/{p_student[i]:.1f}) = {term:.4f}")
        kl += term

print(f"KL divergence = {kl:.4f} nats")
# Output:
#   Term 0: 0.7 * log(0.7/0.3) = 0.5931
#   Term 1: 0.2 * log(0.2/0.4) = -0.1386
#   Term 2: 0.1 * log(0.1/0.3) = -0.1099
# KL divergence = 0.3446 nats

Now with PyTorch — how you'd actually compute it in a training loop:

python
import torch
import torch.nn.functional as F

# Logits from teacher and student (before softmax)
teacher_logits = torch.tensor([1.5, 0.3, -0.5])  # arbitrary logits
student_logits = torch.tensor([0.1, 0.5, 0.1])

# Convert to log-probs (student) and probs (teacher)
teacher_probs = F.softmax(teacher_logits, dim=-1)
student_log_probs = F.log_softmax(student_logits, dim=-1)

# PyTorch KL: expects log(Q) and P, computes sum P * (log P - log Q)
kl = F.kl_div(student_log_probs, teacher_probs, reduction='sum')
print(f"KL = {kl.item():.4f}")
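Swapping the arguments on the worked example makes the asymmetry from the misconception alert concrete. This is a plain-Python sketch, and the helper name `kl` is just for illustration.

```python
import math


def kl(p, q):
    """D_KL(p || q) in nats; assumes q(i) > 0 wherever p(i) > 0."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)


p_teacher = [0.7, 0.2, 0.1]
p_student = [0.3, 0.4, 0.3]

forward = kl(p_teacher, p_student)  # D_KL(p_T || p_theta): the worked example
reverse = kl(p_student, p_teacher)  # D_KL(p_theta || p_T): arguments swapped
print(f"forward KL = {forward:.4f} nats")  # 0.3446
print(f"reverse KL = {reverse:.4f} nats")  # 0.3527
```

The two values are close here but not equal. They drift much further apart once one distribution puts real mass where the other has almost none.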

The Compounding Mechanism

Now let's see why the mismatch doesn't just cause a fixed offset, but actually compounds over time. Consider a sequence where the student makes an error at position t=3. At position t=4:

During training, the input is [prompt, teacher_tok_1, teacher_tok_2, teacher_tok_3]. The student has seen millions of examples with this exact type of clean context.

During inference, the input is [prompt, student_tok_1, student_tok_2, WRONG_tok_3]. The student has never seen context with this error. Its prediction at position 4 will be unreliable. And unreliable position 4 feeds into position 5, which feeds into position 6...

This is exposure bias (Bengio et al., 2015) — the model is never "exposed" to its own errors during training, so it can't handle them at inference. The errors compound because each error creates a novel context that triggers further errors.
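Here is a minimal sketch of the compounding mechanism under an assumed two-state model. The per-token error rate is `eps_clean` while the prefix is still clean. It jumps to a hypothetical `eps_after_error` once any mistake has been made. The sketch computes expected errors exactly, with no sampling, by tracking the probability that the prefix is still clean.

```python
def expected_errors(T, eps_clean, eps_after_error):
    """Expected number of wrong tokens in a T-token rollout.

    Hypothetical two-state model of exposure bias: the per-token error rate is
    eps_clean while the prefix is still error-free and eps_after_error once any
    mistake has been made. Computed exactly, no sampling."""
    p_clean = 1.0  # P(no error in the prefix so far)
    total = 0.0
    for _ in range(T):
        total += p_clean * eps_clean + (1 - p_clean) * eps_after_error
        p_clean *= 1 - eps_clean  # prefix stays clean only if this token is right
    return total


independent = expected_errors(20, 0.05, 0.05)  # errors never compound
compounding = expected_errors(20, 0.05, 0.30)  # an error raises later error rates
print(f"independent: {independent:.3f} expected errors")  # 1.000
print(f"compounding: {compounding:.3f} expected errors")  # 2.792
```

The 0.30 post-error rate is an illustrative assumption, not a measured value. Even so, it nearly triples the expected error count over 20 tokens.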

Summary so far: The off-policy KD loss trains on teacher prefixes. At inference, the student generates its own prefixes. One error creates out-of-distribution context, triggering more errors. This is not a bug in the student's capacity — it's a fundamental mismatch between training and testing conditions.
What's the key difference between the off-policy and on-policy KD objectives?

Chapter 2: The DAgger Theorem

Chapter 1 showed there's a mismatch between training and inference. But how bad is it? Could we just add more training data and fix it? Could we make the student bigger? Or is there a fundamental scaling law that makes off-policy training catastrophically worse as sequences get longer?

The answer comes from a remarkable theorem by Ross, Gordon, and Bagnell (2011), originally proven for imitation learning, where a policy is trained to mimic an expert in sequential decision-making tasks. It applies directly to autoregressive language models and gives us a precise bound on how errors scale with sequence length.

The theorem says: if your per-step error rate is ε, then your total accumulated error over a sequence of length T scales as:

$$\text{Off-policy:}\quad \text{Total Error} \le O(\varepsilon T^2) \quad \text{[QUADRATIC]}$$
$$\text{On-policy:}\quad \text{Total Error} \le O(\varepsilon T) \quad \text{[LINEAR]}$$

Quadratic vs linear. That's not a constant-factor improvement; it's a fundamentally different scaling law. Ignoring constants, at T=100 quadratic gives 10,000ε while linear gives 100ε, so off-policy is 100x worse. Keeping the factor of 1/2 that appears in the derivation, the gap is 50x.

Deriving the Quadratic Bound (from scratch)

Let's derive WHY the error is quadratic for off-policy training. The argument is intuitive once you see it:

Setup: The student has per-step error rate ε. This means at any given step, there's a probability ε that it produces a token different from what the teacher would produce.

Step 1: At time step t, what's the probability that the student is in an "unseen state" — a prefix it never encountered during training? If any of the previous t-1 tokens were wrong, the current prefix is novel. The probability of being in a novel state at step t is at most (t-1)·ε (union bound — each of the t-1 previous positions had ε probability of error).

Step 2: Once in an unseen state, the student's error rate is no longer ε — it could be as bad as 1 (total failure). So at step t, the expected error contribution is:

$$\mathbb{E}[\text{error at step } t] \le \varepsilon + (t-1)\,\varepsilon \cdot 1 = \varepsilon\, t$$

The first ε is the intrinsic error (even on familiar states). The second term is: probability of being in an unseen state (t-1)·ε, times worst-case error of 1.

Step 3: Sum over all T steps to get total error:

$$\text{Total Error} \le \sum_{t=1}^{T} \varepsilon\, t = \varepsilon \sum_{t=1}^{T} t = \varepsilon\,\frac{T(T+1)}{2} \approx \frac{\varepsilon T^2}{2}$$

That's the quadratic bound: $O(\varepsilon T^2)$. The classic formula $1+2+3+\dots+T = T(T+1)/2$ gives us the $T^2$ scaling.

Why On-Policy Gives Linear

Now suppose we train on-policy: the student sees its own prefixes during training. In that case, it HAS seen error states before. It knows how to behave when the prefix is imperfect. So even in a "novel" state (one that arose from a student error), the expected error per step is still just ε — because the student has practiced recovery.

$$\text{Total Error (on-policy)} \le \sum_{t=1}^{T} \varepsilon = \varepsilon\, T$$

No compounding. Just ε per step, T steps, total = εT. Linear.

Numerical Comparison

Let's plug in numbers. Set ε = 0.05 (95% per-step accuracy):

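| Sequence length T | Off-policy ($\varepsilon T^2/2$) | On-policy ($\varepsilon T$) | Ratio (off/on) |
|---|---|---|---|
| T = 10 | 0.05×100/2 = 2.5 | 0.05×10 = 0.5 | 5× |
| T = 20 | 0.05×400/2 = 10 | 0.05×20 = 1.0 | 10× |
| T = 50 | 0.05×2500/2 = 62.5 | 0.05×50 = 2.5 | 25× |
| T = 100 | 0.05×10000/2 = 250 | 0.05×100 = 5.0 | 50× |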

At T=100, the off-policy bound (250) exceeds the sequence length itself, so it no longer rules out total failure: in the worst case, every token could be wrong. The on-policy bound (5.0) stays manageable. That is about 5 mistakes in 100 tokens, and the model can keep going.

Misconception: "Just train longer and ε will go to zero." It won't. The student has limited capacity (fewer parameters than the teacher). There's a floor on ε determined by the capacity gap. Even ε = 0.01 gives devastating off-policy error at T=100: 0.01 × 10000 / 2 = 50 total errors. Meanwhile on-policy gives only 0.01 × 100 = 1. The problem isn't insufficient training — it's the wrong training distribution.
Simulation: DAgger Error Bounds, Quadratic vs Linear. As the per-step error rate ε varies, the off-policy bound (quadratic, orange) pulls away from the on-policy bound (linear, teal). A dashed line marks T, the point of total failure.

The DAgger Algorithm (Sketch)

The DAgger paper didn't just prove the bound — it proposed a fix. The Dataset Aggregation (DAgger) algorithm works like this:

Step 1 (roll out): Run the student policy to generate trajectories (on-policy data).
Step 2 (label): Ask the teacher, "At each state the STUDENT visited, what would you do?"
Step 3 (aggregate): Add these new (student-state, teacher-action) pairs to the dataset.
Step 4 (retrain): Train the student on the growing dataset (old + new data).
Repeat from Step 1.

The key insight: in step 2, the teacher labels the student's states. This means the training data now includes examples from the states the student actually visits — including error states. Over iterations, the student builds experience with its own mistakes.

For LLM distillation, DAgger translates to: let the student generate text, then ask the teacher for probability distributions conditioned on the student's actual generated prefixes. This is computationally expensive, because it requires running the teacher on student-generated sequences. Under the assumptions of the DAgger analysis, however, it reduces the error bound from quadratic to linear.
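The four steps can be sketched on a toy problem that mirrors the driving analogy from Chapter 0. Everything here is a hypothetical stand-in rather than an LLM setup:

- positions on a 1-D lane replace prefixes;
- a teacher that always steers back to the center replaces the teacher LM;
- a lookup-table student replaces the network.

With probability `eps`, the learner's action is replaced by a drift to the right, which plays the role of its own mistake. The original algorithm also blends expert and learner policies during rollouts with a coefficient β; this simplified sketch omits that. All function names are illustrative.

```python
import random
from collections import Counter, defaultdict


def teacher_action(pos):
    """Hypothetical expert: always steer back toward the lane center (pos 0)."""
    if pos > 0:
        return -1
    if pos < 0:
        return 1
    return 0


def rollout(policy, T, eps, rng):
    """Run `policy` for T steps from pos 0. With probability eps the chosen action
    is replaced by a +1 slip (the learner's own mistake). Returns the visited
    states and how many steps ended off-center (our stand-in for 'errors')."""
    pos, states, off_center = 0, [], 0
    for _ in range(T):
        states.append(pos)
        action = policy(pos)
        if rng.random() < eps:
            action = 1
        pos += action
        off_center += pos != 0
    return states, off_center


def label(states):
    """DAgger Step 2: ask the teacher what it would do in each visited state."""
    return [(s, teacher_action(s)) for s in states]


def fit(dataset):
    """Lookup-table student: majority teacher action per state. Unseen states
    fall back to 0 ('keep going straight'), i.e. no recovery skill."""
    votes = defaultdict(Counter)
    for state, action in dataset:
        votes[state][action] += 1
    table = {s: c.most_common(1)[0][0] for s, c in votes.items()}
    return lambda pos: table.get(pos, 0)


def train_bc(T, rng):
    """Off-policy baseline: learn only from a clean teacher demonstration."""
    states, _ = rollout(teacher_action, T, 0.0, rng)
    return fit(label(states))


def train_dagger(T, eps, iters, rollouts_per_iter, rng):
    states, _ = rollout(teacher_action, T, 0.0, rng)
    dataset = label(states)
    student = fit(dataset)
    for _ in range(iters):
        for _ in range(rollouts_per_iter):
            states, _ = rollout(student, T, eps, rng)  # Step 1: roll out the student
            dataset += label(states)                   # Steps 2-3: label, aggregate
        student = fit(dataset)                         # Step 4: retrain on all data
    return student


def avg_errors(policy, T, eps, n, rng):
    return sum(rollout(policy, T, eps, rng)[1] for _ in range(n)) / n


rng = random.Random(0)
T, EPS = 50, 0.05
bc_student = train_bc(T, rng)
dagger_student = train_dagger(T, EPS, iters=5, rollouts_per_iter=20, rng=rng)
bc_err = avg_errors(bc_student, T, EPS, 500, rng)
dagger_err = avg_errors(dagger_student, T, EPS, 500, rng)
print(f"behavior cloning: {bc_err:.1f} off-center steps per episode")
print(f"DAgger:           {dagger_err:.1f} off-center steps per episode")
```

The behavior-cloned student only ever saw the center of the lane, so after its first slip it never steers back. The DAgger student was labeled by the teacher in the drifted states it actually visited, so each slip costs it roughly one off-center step.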

Code: Computing the Error Bounds

python
# Manual computation of DAgger bounds
epsilon = 0.05

# Off-policy: sum of epsilon * t for t=1..T
def off_policy_error(eps, T):
    # = eps * T*(T+1)/2
    total = 0.0
    for t in range(1, T+1):
        total += eps * t
    return total

# On-policy: sum of epsilon for t=1..T
def on_policy_error(eps, T):
    return eps * T

# Check: off_policy_error(0.05, 50)
# = 0.05 * (1+2+3+...+50) = 0.05 * 1275 = 63.75
print(f"{off_policy_error(epsilon, 50):.2f}")  # 63.75
print(f"{on_policy_error(epsilon, 50):.2f}")   # 2.50

Compact one-liners:

python
# Closed-form expressions
off_err = lambda e, T: e * T * (T + 1) / 2
on_err  = lambda e, T: e * T

# At what T does off-policy become 10x worse?
# off/on = [T*(T+1)/2] / T = (T+1)/2 = 10  =>  T = 19
T_star = 2 * 10 - 1
print(f"Off-policy is 10x worse at T = {T_star}")
print(f"  Off: {off_err(0.05, T_star):.2f}, On: {on_err(0.05, T_star):.2f}")
# Off: 9.50, On: 0.95 (ratio = 10.0)

And visualized with numpy:

python
import numpy as np

eps = 0.05
T = np.arange(1, 101)

off_policy = eps * T * (T + 1) / 2   # quadratic
on_policy  = eps * T                   # linear

# Heuristic "fraction correct": 1 - error/T, floored at 0 since the bound can exceed T
off_success = np.maximum(0, 1 - off_policy / T)
on_success  = np.maximum(0, 1 - on_policy / T)
# At T=50: off_success = max(0, 1 - 63.75/50) = 0 (total failure)
# At T=50: on_success  = max(0, 1 - 2.5/50) = 0.95 (mostly works!)
The crossover point: At what sequence length does off-policy become K times worse than on-policy? The ratio is εT(T+1)/2 divided by εT, which is (T+1)/2 ≈ T/2. So for K=10 (10x worse), we need T=19 (about 20 with the T/2 approximation). For K=50 (50x worse), T=99 (about 100). The gap grows linearly with sequence length, so longer reasoning chains suffer disproportionately from off-policy training.
At what sequence length T does off-policy distillation become exactly 10× worse than on-policy? (Hint: ratio = (T+1)/2 ≈ T/2)

Chapter 3: f-Divergences — The Design Space

OK, so on-policy training helps (Chapter 2). But what should the student actually optimize? "Match the teacher's distribution" can mean many different things mathematically. The choice of how to measure "mismatch" dramatically changes what the student learns.

All the distance measures we'll consider belong to a single family: the f-divergences. Every f-divergence has the form:

$$D_f(P \,\Vert\, Q) = \sum_i Q(i)\, f\!\left(\frac{P(i)}{Q(i)}\right)$$

where f is a convex function with f(1) = 0. Different choices of f give different divergences, each with distinct behavior:

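| Divergence | f(u) | Formula | Behavior |
|---|---|---|---|
| Forward KL | $u \log u$ | $\sum P \log(P/Q)$ | Mode-covering (broad) |
| Reverse KL | $-\log u$ | $\sum Q \log(Q/P)$ | Mode-seeking (sharp) |
| JSD | $\frac{u}{2}\log u - \frac{1+u}{2}\log\frac{1+u}{2}$ | $\tfrac12 \mathrm{KL}(P \Vert M) + \tfrac12 \mathrm{KL}(Q \Vert M)$, where $M = (P+Q)/2$ | Symmetric compromise |
| Total Variation | $\lvert u-1\rvert/2$ | $\tfrac12 \sum \lvert P-Q\rvert$ | Largest probability gap over any set of tokens |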
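This sketch plugs each generator f(u) from the table into the general formula. It evaluates them on the 3-token example used in this chapter, with teacher P_T = [0.5, 0.4, 0.1] as P and student P_S = [0.6, 0.3, 0.1] as Q. It assumes every probability is strictly positive, so u = P/Q is always finite and the logs are defined.

```python
import math


def f_divergence(p, q, f):
    """D_f(P || Q) = sum_i Q(i) * f(P(i) / Q(i)); assumes all entries > 0."""
    return sum(qi * f(pi / qi) for pi, qi in zip(p, q))


# Generator functions f(u) from the table, with u = P/Q and natural logs.
GENERATORS = {
    "forward KL": lambda u: u * math.log(u),
    "reverse KL": lambda u: -math.log(u),
    "JSD": lambda u: (u / 2) * math.log(u) - ((1 + u) / 2) * math.log((1 + u) / 2),
    "total variation": lambda u: abs(u - 1) / 2,
}

P_T = [0.5, 0.4, 0.1]  # teacher, in the role of P
P_S = [0.6, 0.3, 0.1]  # student, in the role of Q

results = {name: f_divergence(P_T, P_S, f) for name, f in GENERATORS.items()}
for name, value in results.items():
    print(f"{name}: {value:.4f}")
# forward KL: 0.0239
# reverse KL: 0.0231
# JSD: 0.0059
# total variation: 0.1000
```

Every generator satisfies f(1) = 0, which is why each divergence vanishes when teacher and student agree exactly.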

The two most important for LLM distillation are Forward KL (FKL) and Reverse KL (RKL). Their difference is best understood through a specific example.

The Bimodal Teacher: Why Mode-Covering vs Mode-Seeking Matters

Imagine a teacher distribution with two peaks — two equally valid ways to answer a math question. Maybe "therefore x = 3" and "so we get x = 3" are both correct continuations. The teacher assigns probability to both modes.

Now the student is smaller — maybe it can only represent a unimodal distribution (one peak). What does it do?

Forward KL (mode-covering): FKL = ∑ pT log(pT/pθ) penalizes the student for putting LOW probability where the teacher puts HIGH probability. If the teacher has two peaks, the student MUST cover both — or face infinite penalty (log(something/0) = ∞). So the student spreads out to cover both modes... and ends up assigning probability to the gap between them. This means it might generate tokens that are BETWEEN the two valid answers — hallucination.

Reverse KL (mode-seeking): RKL = ∑ pθ log(pθ/pT) penalizes the student for putting HIGH probability where the teacher puts LOW probability. The student is rewarded for being precise: it picks ONE mode, locks onto it sharply, and ignores the other mode entirely. It might miss valid alternatives. But it rarely puts probability where the teacher puts little, so it hallucinates far less.

JSD (compromise): M = (P+Q)/2 and JSD = ½KL(P||M) + ½KL(Q||M). It's symmetric, bounded between 0 and log(2), and offers a middle ground. The student covers most of the teacher's mass without putting too much weight in low-density regions.

The practical implication for LLMs: Math/code tasks have ONE correct answer — use Reverse KL (mode-seeking). Creative/translation tasks have MANY valid outputs — Forward KL (mode-covering) preserves diversity. This is why GKD (Chapter 4) makes the divergence choice orthogonal to the sampling strategy — different tasks need different divergences.

Worked Example: FKL vs RKL on a 3-Token Vocabulary

Let's compute both divergences for concrete numbers. Teacher: PT = [0.5, 0.4, 0.1]. Student: Pθ = [0.6, 0.3, 0.1].

Forward KL: DKL(PT || Pθ) = ∑ PT(i) · log(PT(i) / Pθ(i))

Term 1: 0.5 × log(0.5/0.6) = 0.5 × log(0.8333) = 0.5 × (-0.1823) = -0.0912

Term 2: 0.4 × log(0.4/0.3) = 0.4 × log(1.3333) = 0.4 × 0.2877 = 0.1151

Term 3: 0.1 × log(0.1/0.1) = 0.1 × log(1.0) = 0.1 × 0 = 0

FKL = -0.0912 + 0.1151 + 0 = 0.0239 nats

Reverse KL: DKL(Pθ || PT) = ∑ Pθ(i) · log(Pθ(i) / PT(i))

Term 1: 0.6 × log(0.6/0.5) = 0.6 × log(1.2) = 0.6 × 0.1823 = 0.1094

Term 2: 0.3 × log(0.3/0.4) = 0.3 × log(0.75) = 0.3 × (-0.2877) = -0.0863

Term 3: 0.1 × log(0.1/0.1) = 0.1 × log(1.0) = 0.1 × 0 = 0

RKL = 0.1094 + (-0.0863) + 0 = 0.0231 nats

In this case FKL (0.0239) and RKL (0.0231) are similar — because the distributions are close and roughly unimodal. The dramatic difference appears when the teacher is bimodal and the student is unimodal.

The Dramatic Case: Bimodal Teacher

Now let PT = [0.49, 0.01, 0.49, 0.01]: two strong modes at positions 0 and 2, with very little mass between them. The student can only be unimodal. Two strategies:

Strategy A (cover both): Pθ = [0.3, 0.2, 0.3, 0.2], broad, covering everything.

Strategy B (pick one): Pθ = [0.97, 0.01, 0.01, 0.01], sharp, focused on mode 0.

FKL prefers Strategy A (0.421 nats) over Strategy B (1.572 nats). B puts only 0.01 where the teacher puts 0.49, and that single token costs 0.49 × log(0.49/0.01) = 1.907 nats. But Strategy A hallucinates: it puts 20% probability on each of tokens 1 and 3, where the teacher has only 1%.

RKL prefers Strategy B (0.623 nats) over Strategy A (0.904 nats), because B doesn't put high probability where the teacher puts low. The "dropped mode" (token 2, which gets only 0.01 from the student) costs little in RKL because the student's mass there is small. A's hallucinated mass, by contrast, costs 2 × 0.2 × log(0.2/0.01) ≈ 1.198 nats.
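Here is a plain-Python check of how forward and reverse KL rank a covering strategy against a mode-picking one. The teacher is bimodal, with two modes of 0.49 and gap tokens of 0.01. `kl` is an illustrative helper using natural logs. Edit the distributions to see when each divergence flips its preference.

```python
import math


def kl(p, q):
    """D_KL(p || q) in nats, using the convention 0 * log(0 / q) = 0."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)


teacher = [0.49, 0.01, 0.49, 0.01]
strategies = {
    "A (cover both)": [0.30, 0.20, 0.30, 0.20],
    "B (pick one)": [0.97, 0.01, 0.01, 0.01],
}

scores = {}
for name, student in strategies.items():
    fkl = kl(teacher, student)  # forward KL: D_KL(P_T || P_theta)
    rkl = kl(student, teacher)  # reverse KL: D_KL(P_theta || P_T)
    scores[name] = (fkl, rkl)
    print(f"{name}: FKL = {fkl:.3f}, RKL = {rkl:.3f}")
# A (cover both): FKL = 0.421, RKL = 0.904
# B (pick one): FKL = 1.572, RKL = 0.623
```

The mode-seeking preference of RKL depends on the gap tokens being rare under the teacher. If the teacher's gap mass is raised (say to 0.05 per token), covering becomes cheap enough that RKL can also prefer Strategy A.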

Misconception: "JSD is 'the best of both worlds.'" It's not — it's a compromise that has neither FKL's guaranteed coverage of all teacher modes nor RKL's guaranteed precision within chosen modes. JSD is good for moderate-diversity tasks (translation, summarization) where you want some mode coverage without hallucination. But for math (use RKL) or creative writing (use FKL), the extremes are better.
Simulation: Mode-Covering vs Mode-Seeking. A bimodal teacher distribution (white curve) is shown with a student distribution. The student morphs from the Forward KL optimum (broad, covers both modes) to the Reverse KL optimum (sharp, picks one mode). The red region shows hallucination mass.

Code: All Three Divergences from Scratch

python
import math

def forward_kl(p_teacher, p_student):
    """D_KL(P_T || P_S) = sum P_T * log(P_T / P_S)"""
    kl = 0.0
    for pt, ps in zip(p_teacher, p_student):
        if pt > 0:
            assert ps > 0, "FKL infinite if student has zero where teacher is positive"
            kl += pt * math.log(pt / ps)
    return kl

def reverse_kl(p_teacher, p_student):
    """D_KL(P_S || P_T) = sum P_S * log(P_S / P_T)"""
    kl = 0.0
    for pt, ps in zip(p_teacher, p_student):
        if ps > 0:
            assert pt > 0, "RKL infinite if teacher has zero where student is positive"
            kl += ps * math.log(ps / pt)
    return kl

def jsd(p_teacher, p_student):
    """JSD = 0.5 * KL(P_T||M) + 0.5 * KL(P_S||M) where M = (P_T+P_S)/2"""
    m = [(pt + ps) / 2 for pt, ps in zip(p_teacher, p_student)]
    return 0.5 * forward_kl(p_teacher, m) + 0.5 * forward_kl(p_student, m)

# Test with our worked example
P_T = [0.5, 0.4, 0.1]
P_S = [0.6, 0.3, 0.1]

print(f"Forward KL: {forward_kl(P_T, P_S):.4f}")  # 0.0239
print(f"Reverse KL: {reverse_kl(P_T, P_S):.4f}")  # 0.0231
print(f"JSD:        {jsd(P_T, P_S):.4f}")          # 0.0059

And the PyTorch equivalent:

python
import torch
import torch.nn.functional as F

P_T = torch.tensor([0.5, 0.4, 0.1])
P_S = torch.tensor([0.6, 0.3, 0.1])

# Forward KL: F.kl_div expects log(Q), P — computes sum P*(log P - log Q)
fkl = F.kl_div(P_S.log(), P_T, reduction='sum')

# Reverse KL: swap arguments
rkl = F.kl_div(P_T.log(), P_S, reduction='sum')

# JSD
M = 0.5 * (P_T + P_S)
jsd_val = 0.5 * F.kl_div(M.log(), P_T, reduction='sum') + \
          0.5 * F.kl_div(M.log(), P_S, reduction='sum')

print(f"FKL={fkl:.4f}, RKL={rkl:.4f}, JSD={jsd_val:.4f}")

When to Use Which Divergence

Here's the practical decision rule for LLM distillation:

| Task type | Teacher shape | Best divergence | Why |
|---|---|---|---|
| Math proofs | Unimodal (one correct path) | Reverse KL | Precision: don't hallucinate wrong steps |
| Code generation | Few modes (correct implementations) | Reverse KL | Correctness over diversity |
| Translation | Moderate modes (valid phrasings) | JSD | Some variety without garbage |
| Creative writing | Highly multimodal (many valid outputs) | Forward KL | Preserve diversity of expression |
| Summarization | Moderate modes | JSD or FKL | Multiple valid summaries exist |

Key takeaway: The divergence choice is a precision-diversity tradeoff. RKL = precise but misses modes. FKL = diverse but hallucinates. JSD = compromise. There's no universal best — it depends on whether your task rewards precision (math, code) or diversity (creative, conversational).
A math problem has exactly one correct derivation path. Which divergence should you use for distilling a math-reasoning student, and why?

Chapter 4: GKD — The First Practical Framework

We know on-policy helps (Chapter 2). We have a menu of divergences (Chapter 3). But how do you actually combine these into a practical training algorithm? The first paper to do this systematically was Generalized Knowledge Distillation (GKD) by Agarwal et al. (2024). Its key insight: make the sampling policy and the divergence measure orthogonal design choices.

Previous methods hard-coded both: standard KD uses off-policy sampling + forward KL. SeqKD uses off-policy sampling + sequence-level loss. GKD says: pick ANY combination from a 2D grid:

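$$\mathcal{L}_{\text{GKD}} = \mathbb{E}_{y \sim \pi_{\text{mix}}}\Big[\sum_t D_f\big(p_T(\cdot \mid x, y_{<t}) \,\Vert\, p_\theta(\cdot \mid x, y_{<t})\big)\Big]$$

where $\pi_{\text{mix}} = \lambda\, p_\theta + (1-\lambda)\, p_{\text{data}}$ is a mixture sampling policy. The parameter $\lambda \in [0, 1]$ controls how "on-policy" you are: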

| λ | Sampling policy | What the student sees |
|---|---|---|
| 0 | Pure off-policy ($p_{\text{data}}$) | Only ground-truth / teacher sequences |
| 0.5 | 50/50 mixture | Half teacher sequences, half student rollouts |
| 1.0 | Pure on-policy ($p_\theta$) | Only the student's own generated sequences |

And Df can be any f-divergence from Chapter 3 — FKL, RKL, JSD, etc. The two choices are independent.

GKD's Killer Finding

The paper's most surprising result: the sampling policy (λ) matters MORE than the divergence choice. Even with simple forward KL, switching from λ=0 (off-policy) to λ=0.5 (mixed) dramatically improves downstream task performance. The divergence choice matters less — it's a second-order effect.

This makes intuitive sense given Chapter 2: the DAgger theorem says the sampling distribution (on-policy vs off-policy) determines whether errors scale quadratically or linearly. The divergence just affects the constant factor.

The GKD recipe in practice: (1) Warm up the student with off-policy training (λ=0) for a few thousand steps. (2) Switch to mixed training (λ=0.5) for the bulk of training. (3) Optionally increase to λ=1.0 for final fine-tuning. This curriculum ensures the student is "good enough" before we ask it to generate its own training data.
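Below is a minimal sketch of that curriculum, together with the weighted coin that implements π_mix. The phase boundaries (2,000 and 8,000 steps) are hypothetical placeholders for "a few thousand steps", not values from the GKD paper. `lam_schedule` and `choose_source` are illustrative names.

```python
import random


def lam_schedule(step, warmup_until=2000, mixed_until=8000):
    """Hypothetical GKD curriculum: lambda = 0 (off-policy warm-up), then 0.5
    (mixed), then 1.0 (on-policy). Boundaries are illustrative assumptions."""
    if step < warmup_until:
        return 0.0
    if step < mixed_until:
        return 0.5
    return 1.0


def choose_source(lam, rng):
    """The weighted coin behind pi_mix: use a student rollout with probability
    lam, otherwise a sequence from the fixed dataset."""
    return "student" if rng.random() < lam else "dataset"


rng = random.Random(0)
for step in (0, 1999, 2000, 5000, 8000, 12000):
    lam = lam_schedule(step)
    print(f"step {step:>5}: lambda = {lam:.1f}, this batch uses {choose_source(lam, rng)}")
```

Keeping the schedule separate from the sampling coin lets you swap in a smoother ramp, such as a linear increase of λ, without touching the training step.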

Walked Through: One GKD Training Step

Let's trace through exactly what happens in a single training step with λ=0.5 and JSD loss. This is the complete data flow, no hand-waving.

Step 1: Decide sampling source. We flip a weighted coin: with probability λ=0.5, we use on-policy (student generates); with probability 1-λ=0.5, we use the ground-truth sequence from the dataset. Suppose the coin says "on-policy."

Step 2: Generate student rollout. Given prompt x = "Prove that 2+2=4", the student autoregressively generates: ŷ = [2, +, 2, =, 4, by, definition]. Seven tokens. Each sampled from pθ(·|x, ŷ<t).

Step 3: Get teacher scores. We feed the student's sequence back through the teacher model. At EACH position t, we get the teacher's full probability distribution pT(·|x, ŷ<t). Note: the teacher conditions on the STUDENT's prefix — this is the on-policy part.

Step 4: Compute per-token JSD loss. At position t=3 (predicting "2"), for example:

Teacher distribution: pT = [0.01, 0.02, 0.85, 0.01, 0.05, 0.03, 0.03] (over a tiny vocab for illustration)

Student distribution: pθ = [0.05, 0.10, 0.50, 0.10, 0.10, 0.05, 0.10]

Mixture: M = (pT + pθ)/2 = [0.03, 0.06, 0.675, 0.055, 0.075, 0.04, 0.065]

JSD = 0.5 · KL(pT||M) + 0.5 · KL(pθ||M)

For the dominant token (index 2): 0.5 × 0.85 × log(0.85/0.675) + 0.5 × 0.50 × log(0.50/0.675) = 0.5 × 0.85 × 0.2305 + 0.5 × 0.50 × (-0.3001) ≈ 0.0980 + (-0.0750) ≈ 0.023

Sum across all vocabulary entries and all T positions to get the total loss for this sample.
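The full JSD at this position can be checked with NumPy. The sketch uses the natural log and the same seven-entry toy vocabulary. Individual entries can contribute negatively; only the total is guaranteed to lie between 0 and ln 2.

```python
import numpy as np

# Teacher and student next-token distributions from the worked example
p_T = np.array([0.01, 0.02, 0.85, 0.01, 0.05, 0.03, 0.03])
p_S = np.array([0.05, 0.10, 0.50, 0.10, 0.10, 0.05, 0.10])
M = 0.5 * (p_T + p_S)

# Per-entry contributions to JSD = 0.5*KL(p_T||M) + 0.5*KL(p_S||M), natural log.
# Single entries can be negative; their sum cannot.
terms = 0.5 * p_T * np.log(p_T / M) + 0.5 * p_S * np.log(p_S / M)
jsd = float(terms.sum())

print(f"index-2 term: {terms[2]:.4f}")
print(f"JSD at this position: {jsd:.4f}  (upper bound ln 2 = {np.log(2):.4f})")
```

Summing `jsd` over all T positions gives the per-sample loss described above.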

Step 5: Backpropagate and update. Compute gradients of the total loss with respect to student parameters θ. Apply optimizer step (AdamW, lr=1e-5 typical).

Misconception: "Fully on-policy (λ=1) is always best." It's not — at the START of training, the student is so weak that its generations are garbage. Training on garbage doesn't help. You need λ=0 or λ=0.5 initially to give the student coherent examples to learn from. λ=1 works only after the student can produce semi-reasonable text. Think of it as: you can't "practice piano" productively if you can't even find middle C yet.

Implementation: The Full GKD Loop

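First, a single GKD training step, annotated:

python
import math
import torch
import torch.nn.functional as F

def gkd_step(student, teacher, prompt_ids, ground_truth_ids, lam=0.5, divergence='jsd'):
    """One GKD training step. Returns a scalar loss.

    prompt_ids:       [1, P] prompt token ids
    ground_truth_ids: [1, T] prompt + reference response from the dataset
    Assumes Hugging Face-style causal LMs (`.generate()`, `.logits`).
    """
    prompt_len = prompt_ids.shape[1]

    # Step 1: Decide sampling source
    if torch.rand(1).item() < lam:
        # On-policy: student generates the sequence (no gradient through sampling)
        with torch.no_grad():
            seq_ids = student.generate(prompt_ids, do_sample=True, max_new_tokens=64)
    else:
        # Off-policy: use the ground-truth sequence from the dataset
        seq_ids = ground_truth_ids

    # Logits at position t predict token t+1: keep only the positions that
    # predict response tokens (skip the prompt, drop the final position)
    resp = slice(prompt_len - 1, seq_ids.shape[1] - 1)

    # Step 2: Get teacher distribution at each response position
    with torch.no_grad():
        teacher_logits = teacher(seq_ids).logits[:, resp]  # [1, T_resp, V]
        teacher_log_probs = F.log_softmax(teacher_logits, dim=-1)
        teacher_probs = teacher_log_probs.exp()

    # Step 3: Get student distribution (WITH gradients)
    student_logits = student(seq_ids).logits[:, resp]  # [1, T_resp, V]
    student_log_probs = F.log_softmax(student_logits, dim=-1)
    student_probs = student_log_probs.exp()

    # Step 4: Compute chosen divergence (sum over vocab, mean over positions).
    # Stay in log space: probs.log() can return -inf, and 0 * -inf = NaN.
    if divergence == 'fkl':
        # Forward KL: sum P_T * log(P_T / P_S)
        loss = (teacher_probs * (teacher_log_probs - student_log_probs)).sum(-1).mean()
    elif divergence == 'rkl':
        # Reverse KL: sum P_S * log(P_S / P_T)
        loss = (student_probs * (student_log_probs - teacher_log_probs)).sum(-1).mean()
    elif divergence == 'jsd':
        # JSD: 0.5*KL(P_T||M) + 0.5*KL(P_S||M), with log M computed stably
        M_log = torch.logsumexp(torch.stack([teacher_log_probs, student_log_probs]), dim=0) - math.log(2.0)
        loss = 0.5 * (teacher_probs * (teacher_log_probs - M_log)).sum(-1).mean() + \
               0.5 * (student_probs * (student_log_probs - M_log)).sum(-1).mean()
    else:
        raise ValueError(f"unknown divergence: {divergence!r}")

    return loss

And a minimal training loop using this function. It assumes `student`, `teacher`, and a `dataloader` that yields `(prompt_ids, ground_truth_ids)` pairs already exist:

python
import torch

# GKD training schedule with lambda warmup
teacher.eval()
optimizer = torch.optim.AdamW(student.parameters(), lr=1e-5)
data_iter = iter(dataloader)

for step in range(10000):
    # Curriculum: ramp lambda linearly from 0 to 0.5 over the first 1000 steps
    lam = min(0.5, step / 2000)

    try:
        prompt_ids, ground_truth_ids = next(data_iter)
    except StopIteration:  # start a new epoch
        data_iter = iter(dataloader)
        prompt_ids, ground_truth_ids = next(data_iter)

    loss = gkd_step(student, teacher, prompt_ids, ground_truth_ids,
                    lam=lam, divergence='jsd')

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if step % 100 == 0:
        print(f"Step {step}, λ={lam:.2f}, loss={loss.item():.4f}")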

Why the Mixture Matters: Worked Numbers

Let's see quantitatively why λ=0.5 outperforms λ=0. Consider a student with ε=0.08 (8% per-token error on teacher-distributed inputs). After 1000 steps of training:

λ=0 (off-policy): The student only sees perfect prefixes, so its per-token error on its own (student-distributed) prefixes stays at ~0.08 or worse, because it never practices recovery. On a T=30 reasoning chain, the quadratic bound gives εT²/2 = 0.08 × 900/2 = 36 expected mistakes. That is more than the 30 tokens in the chain, so the bound is vacuous: nothing rules out complete failure.

λ=0.5 (mixed): Half the time, the student sees its own imperfect prefixes and learns to handle error states. Suppose its effective ε on student-distributed inputs drops to ~0.04 (an illustrative figure: half the original, because it now trains on the right distribution). On a T=30 chain, the linear bound gives εT = 0.04 × 30 = 1.2. That is about one mistake per chain, which is recoverable.

The numbers aren't precise (real LLMs are more complex), but the mechanism is: training on your own distribution makes you better at handling your own distribution. This is the DAgger theorem in action.
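The bounds above are simple arithmetic. The sketch below also holds ε = 0.08 fixed in an on-policy column, which separates the two effects. The switch from quadratic to linear growth does most of the work. Halving ε (the illustrative 0.04) is a smaller, second-order gain.

```python
def off_policy_bound(eps, T):
    """Quadratic compounding-error bound used in the text: eps * T^2 / 2."""
    return eps * T * T / 2

def on_policy_bound(eps, T):
    """Linear bound: eps * T."""
    return eps * T

print(f"{'T':>3} {'off (eps=.08)':>14} {'on (eps=.08)':>13} {'on (eps=.04)':>13}")
for T in (5, 10, 30, 50):
    print(f"{T:>3} {off_policy_bound(0.08, T):>14.1f} "
          f"{on_policy_bound(0.08, T):>13.2f} {on_policy_bound(0.04, T):>13.2f}")
```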

GKD Training Simulator

Watch a student distribution train toward a bimodal teacher. Adjust λ to control sampling policy. Higher λ = more on-policy. Click "Train 10 Steps" to animate learning. The teal curve is the student; the white curve is the teacher.


GKD's Ablation Results (Summarized)

The paper tested all combinations of sampling policy and divergence on summarization (XSum) and translation (WMT) tasks. Key findings:

Configuration | XSum ROUGE-L | Improvement over off-policy FKL
λ=0, FKL (standard KD) | 32.1 | baseline
λ=0, RKL | 32.4 | +0.3 (divergence alone helps a little)
λ=0.5, FKL | 34.2 | +2.1 (sampling helps A LOT)
λ=0.5, JSD | 34.5 | +2.4 (best combination)
λ=1.0, JSD | 33.9 | +1.8 (fully on-policy slightly worse without warmup)

The pattern is clear: changing λ from 0 to 0.5 gives +2.1 ROUGE-L. Changing divergence from FKL to JSD (same λ) gives only +0.3. The sampling distribution is the dominant factor.

GKD's legacy: Before GKD, much of the distillation debate centered on FKL vs RKL. GKD showed this was the wrong debate. The right question isn't "which loss function" but "whose prefixes does the student train on?" In the ablation above, changing the sampling distribution helped about 7× more than changing the loss (+2.1 vs +0.3 ROUGE-L). This insight runs through the on-policy distillation work in Chapters 5-8.

The Complete GKD Algorithm

Phase 1: Warmup
λ=0, FKL. Train on teacher sequences for 500-1000 steps. Student learns basic competence.
Phase 2: Mixed Training
λ=0.5, JSD. Half on-policy, half teacher sequences. Student learns error recovery.
Phase 3: On-Policy Refinement
λ=1.0, JSD (optional). Fully on-policy. Student refines on its own distribution.
Result
Student that performs well on ITS OWN generated sequences, not just teacher-conditioned inputs.
GKD in one sentence: Generate some sequences from the student, score them with the teacher, compute a divergence loss at each token position, and backpropagate. The λ parameter controls what fraction of training sequences come from the student vs. from the dataset. That's it. The rest is engineering (warmup schedules, temperature scaling, gradient accumulation).
Why does GKD use a mixture policy πmix = λ·pθ + (1-λ)·pdata instead of pure on-policy (λ=1) from the start?

Chapter 5: Reverse KL & The RL Equivalence

MiniLLM (Gu et al., 2024) went all-in on reverse KL. Why? And why does this turn distillation into reinforcement learning?

Recall from Chapter 3: reverse KL is DKL(Pθ || PT) = Ey~Pθ[log Pθ(y) - log PT(y)]. The expectation is over the STUDENT's own distribution. To compute this loss, we must sample from the student — generate full sequences — then evaluate how well those sequences score under the teacher.

Here's the problem: sampling from a categorical distribution is not differentiable. When the student picks token yt from its softmax output, that discrete choice has no gradient. We can't backpropagate through "I chose token #4527 from a 32,000-token vocabulary." The argmax (or sample) operation is a wall.

The solution is ancient RL wisdom: the REINFORCE trick (Williams, 1992). Instead of backpropagating through the sample, we treat the sampled sequence as given and adjust the probability of producing it based on how good it turned out to be. Good sequence? Increase its probability. Bad sequence? Decrease it. The "goodness" is the reward.

The key equivalence: Define per-token reward as rt = log PT(yt | y<t) - log Pθ(yt | y<t). Then minimizing reverse KL is EXACTLY equivalent to maximizing expected cumulative reward. Distillation IS reinforcement learning with teacher log-probability as the reward function.

Let's derive this carefully. Start with reverse KL over full sequences:

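$$\min_\theta D_{\mathrm{KL}}(P_\theta \,\|\, P_T) = \min_\theta \mathbb{E}_{y \sim P_\theta}\left[\log P_\theta(y) - \log P_T(y)\right]$$

Negate (turn min into max) and rearrange:

$$\Longleftrightarrow\; \max_\theta \mathbb{E}_{y \sim P_\theta}\left[\log P_T(y) - \log P_\theta(y)\right]$$

Split into two terms:

$$= \max_\theta \left\{ \mathbb{E}_{y \sim P_\theta}\left[\log P_T(y)\right] + H(P_\theta) \right\}$$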

The first term is the expected teacher log-probability — this is the reward. The second term is the student's own entropy — a bonus for exploration, preventing the student from collapsing to a single deterministic output. This is EXACTLY the entropy-regularized RL objective (MaxEnt RL, Ziebart et al., 2008).
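The equivalence can be checked numerically on a toy single-step vocabulary (all numbers below are made up for illustration). The sketch computes the exact expectation of the REINFORCE estimator by summing over every token instead of sampling, with r(y) = log PT(y) - log Pθ(y) held fixed. It then compares that against a finite-difference gradient of the reverse KL with respect to the student's logits. The two agree up to sign: ascending the expected reward is the same as descending the reverse KL. Holding the reward fixed is safe, because differentiating the -log Pθ inside it adds -E[∇ log Pθ], which is zero.

```python
import numpy as np

def softmax(z):
    e = np.exp(z - z.max())
    return e / e.sum()

def reverse_kl(z, log_q):
    """KL(P_theta || P_T) for student logits z and teacher log-probs log_q."""
    p = softmax(z)
    return float(np.sum(p * (np.log(p) - log_q)))

def reinforce_grad(z, log_q):
    """Exact expectation E_y[r(y) * d/dz log P_theta(y)], summed over every token y,
    with r(y) = log P_T(y) - log P_theta(y) treated as a constant."""
    p = softmax(z)
    grad = np.zeros_like(z)
    for y in range(len(z)):
        reward = log_q[y] - np.log(p[y])
        score = np.eye(len(z))[y] - p      # d/dz log softmax(z)[y]
        grad += p[y] * reward * score
    return grad

def finite_diff_grad(f, z, h=1e-6):
    """Central-difference gradient of a scalar function f at z."""
    g = np.zeros_like(z)
    for j in range(len(z)):
        dz = np.zeros_like(z)
        dz[j] = h
        g[j] = (f(z + dz) - f(z - dz)) / (2 * h)
    return g

z = np.array([0.5, -1.0, 2.0, 0.0])               # toy student logits
log_q = np.log(np.array([0.1, 0.2, 0.6, 0.1]))    # toy teacher probabilities
g_pg = reinforce_grad(z, log_q)                   # ascent direction for E[r]
g_kl = finite_diff_grad(lambda v: reverse_kl(v, log_q), z)
print(np.allclose(g_pg, -g_kl, atol=1e-6))         # maximizing E[r] == minimizing reverse KL
```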

G-OPD (Generalized On-Policy Distillation) takes this further with reward extrapolation. If log PT(y) is the base reward, we can ADD an external signal to push the student BEYOND the teacher:

R(y) = log PT(y) + β · Rext(y)

Where Rext might be a math verifier (is the proof correct?), a code executor (does it pass tests?), or a human preference model. This is how you get a 7B student that's BETTER than the 70B teacher on specific tasks — impossible with forward KL alone, because FKL can only make the student MATCH the teacher, never exceed it.
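The following toy illustration of reward extrapolation uses hypothetical numbers. One candidate is closer to the teacher's style but fails the verifier. The other is less teacher-like but correct. With β = 0 the ranking follows the teacher alone. A large enough β flips it toward the verified answer, which is the mechanism that lets the student move past the teacher.

```python
def extrapolated_reward(teacher_logprob, verifier_reward, beta):
    """R(y) = log P_T(y) + beta * R_ext(y)."""
    return teacher_logprob + beta * verifier_reward

def best_candidate(candidates, beta):
    """Name of the candidate with the highest extrapolated reward."""
    return max(candidates, key=lambda name: extrapolated_reward(*candidates[name], beta))

# Hypothetical candidates: (sequence log-prob under the teacher, verifier reward 0/1)
candidates = {
    "teacher-style wrong proof": (-12.0, 0.0),
    "unusual correct proof": (-15.0, 1.0),
}
for beta in (0.0, 5.0):
    print(f"beta={beta}: highest reward -> {best_candidate(candidates, beta)}")
```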

Worked Example: One REINFORCE Gradient Step

The student generates the token "therefore" at position t with probability pθ("therefore") = 0.3. The teacher assigns pT("therefore") = 0.7. Let's compute the gradient:

Step 1 — Per-token reward:

rt = log PT("therefore") - log Pθ("therefore") = log(0.7) - log(0.3) = -0.357 - (-1.204) = +0.847

Positive reward! The teacher likes this token much more than the student's current belief suggests. The student should increase its probability here.

Step 2 — REINFORCE gradient:

∇θ J = rt × ∇θ log Pθ("therefore")

For a softmax policy, the score function with respect to the chosen token's own logit z is ∂ log Pθ("therefore") / ∂z = 1 - Pθ("therefore") = 1 - 0.3 = 0.7. (Every other logit j gets -rt · Pθ(j), which moves probability away from the alternatives.) So the gradient on the "therefore" logit is:

∂J/∂z = 0.847 × 0.7 ≈ 0.59

This is a positive gradient: the update raises the "therefore" logit and pushes the student toward "therefore." Now consider what happens at equilibrium: if pθ = pT = 0.7, then rt = log(0.7) - log(0.7) = 0. Zero reward, zero gradient. The student has matched the teacher on this token, so there is no more signal. Beautiful.

What if the student generates a token the teacher hates? Say pθ("however") = 0.4 but pT("however") = 0.02:

rt = log(0.02) - log(0.4) = -3.912 - (-0.916) ≈ -3.00

Massive negative reward. The gradient on the "however" logit is -3.00 × (1 - 0.4) ≈ -1.80: about three times the size of the "therefore" update, in the opposite direction. The student gets punished HARD for generating tokens the teacher considers nearly impossible. This is the mode-seeking behavior of reverse KL in action: it ruthlessly suppresses tokens outside the teacher's support.
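The per-token numbers can be reproduced directly, assuming a softmax policy. The sketch reports only the gradient on the sampled token's own logit; every other logit j receives -rt · Pθ(j).

```python
import math

def chosen_logit_grad(p_student, p_teacher):
    """Per-token reward and REINFORCE gradient on the sampled token's own logit.

    r = log P_T(y) - log P_theta(y), treated as a fixed number (no gradient);
    d log softmax(z)[y] / d z[y] = 1 - P_theta(y).
    """
    r = math.log(p_teacher) - math.log(p_student)
    return r, r * (1.0 - p_student)

for word, p_s, p_t in [("therefore", 0.3, 0.7), ("however", 0.4, 0.02), ("matched", 0.7, 0.7)]:
    r, g = chosen_logit_grad(p_s, p_t)
    print(f"{word:<10} r = {r:+.3f}   gradient on its own logit = {g:+.3f}")
```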

python
import torch
import torch.nn.functional as F

def generated_token_log_probs(model, sequences, prompt_len):
    """Log-probs of the generated tokens, plus the full log-softmax at those positions.
    Logits at position t predict token t+1, hence the one-step shift."""
    logits = model(sequences).logits[:, prompt_len - 1:-1]              # [B, L, V]
    log_probs = F.log_softmax(logits, dim=-1)
    targets = sequences[:, prompt_len:]                                   # [B, L]
    token_lp = log_probs.gather(-1, targets.unsqueeze(-1)).squeeze(-1)   # [B, L]
    return token_lp, log_probs

def minillm_loss(student, teacher, prompts, max_len=128, entropy_coef=0.01):
    """MiniLLM-style sketch: reverse KL via REINFORCE.
    1. Generate full sequences from student (on-policy sampling)
    2. Score each token: r_t = log P_T(y_t) - log P_θ(y_t)
    3. Apply policy gradient on reward-to-go, with a mean baseline
    Assumes Hugging Face-style models and equal-length prompts [B, P];
    padding after EOS is not masked here.
    """
    prompt_len = prompts.shape[1]

    # Step 1: Generate from student (NO gradient through sampling)
    with torch.no_grad():
        sequences = student.generate(prompts, max_length=max_len, do_sample=True)
        # Student log-probs of its own samples: the log P_θ term of the reward
        old_lp, _ = generated_token_log_probs(student, sequences, prompt_len)
        # Step 2: Teacher log-probs of the same tokens (reward signal)
        teacher_lp, _ = generated_token_log_probs(teacher, sequences, prompt_len)

    # Per-token reward: how much the teacher likes this token vs the student
    rewards = teacher_lp - old_lp  # r_t = log P_T(y_t) - log P_θ(y_t)
    # Reward-to-go: token t is credited with every reward from t onward
    returns = rewards.flip(-1).cumsum(-1).flip(-1)
    # Baseline: subtract the batch mean to reduce variance
    advantages = returns - returns.mean()

    # Step 3: Recompute student log-probs WITH gradient
    log_pi, new_log_probs = generated_token_log_probs(student, sequences, prompt_len)

    # REINFORCE: maximize E[advantage * log_prob]
    policy_loss = -(advantages.detach() * log_pi).mean()

    # Extra entropy bonus (on top of the -log P_θ inside the reward): prevents collapse
    entropy = -(new_log_probs.exp() * new_log_probs).sum(dim=-1).mean()

    return policy_loss - entropy_coef * entropy
Distillation as RL: Reward Landscape & Policy

Top: teacher log-prob (reward landscape). Bottom: student policy shifting toward high-reward regions. Toggle "Reward Extrapolation" to see the student surpass the teacher.

Misconception: People think "distillation = supervised learning." G-OPD proved it's NOT. Minimizing reverse KL IS entropy-regularized RL. This means every trick from RLHF — PPO clipping, advantage estimation, GAE, reward shaping, KL penalties — ALL apply directly to distillation. If you already know RL, you already know advanced distillation. The only difference is the reward source: in RLHF it's a human preference model, in distillation it's the teacher's log-probability.

The practical implications are profound. Reward extrapolation lets the student exceed the teacher. This is impossible with forward KL (which can only make you match the teacher). With reverse KL + external verifier, a 7B model can beat a 70B model on math — it just needs to explore more and the verifier tells it when it's right.

In the RL view of on-policy distillation, what plays the role of the reward function?

Chapter 6: Self-Distillation — No Teacher Needed

What if you don't HAVE a bigger teacher? What if your model IS the biggest model? This is the situation frontier labs face: GPT-5 can't distill from GPT-6 because GPT-6 doesn't exist yet. The model must improve from itself.

Self-distillation sounds paradoxical — you can't create knowledge from nothing. But you CAN reorganize existing knowledge. Three distinct mechanisms make this possible, each exploiting a different kind of "hidden signal" within the model's own capabilities.

Mechanism 1: Privileged Information (OPSD)

Give the model extra information at training time that's unavailable at test time. Concretely: provide the ground-truth answer as a "hint" during training. The model with hints generates better reasoning traces. Then distill the hint-conditioned model into a hint-free model:

P(reasoning | question, hint) → P(reasoning | question)

Example: Math problem "What is ∫ sin(x) dx?" The privileged version sees "[Answer: -cos(x) + C]" as context and produces a clean step-by-step derivation. The unprivileged version learns to produce that same derivation WITHOUT seeing the answer — it learns the reasoning PATTERN, not the specific answer.

This works because knowing the answer simplifies reasoning enormously. When you know the destination, you don't explore dead ends. The privileged model produces cleaner, shorter, more logical traces — and the unprivileged model learns to mimic that efficiency.

Mechanism 2: Self-Play (SPIN)

SPIN (Self-Play fINe-tuning) uses the model's previous version as a reference. At each round, the model's own generations are "rejected" and the ground-truth responses are "chosen." DPO (Direct Preference Optimization) then pushes the model to prefer the ground-truth responses over its own current outputs.

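Round k:
  1. Freeze: copy the current model θ_k as the reference.
  2. Generate: θ_k samples responses to the training prompts; these are "rejected".
  3. Pair: the ground-truth responses to the same prompts are "chosen".
  4. DPO update: train θ_{k+1} (starting from θ_k) to increase P(chosen) and decrease P(rejected), relative to the reference.
  ↻ Set θ_k = θ_{k+1} and repeat.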

Mechanism 3: Verifier Feedback (SD-ZERO)

No teacher and no reference outputs. The model generates many candidate solutions to a problem, and a binary verifier grades them (correct/incorrect). GRPO (Group Relative Policy Optimization) then scores each candidate relative to the rest of its group: correct solutions get a positive advantage, incorrect ones a negative one. Pure exploration + automated grading, no human in the loop.

The reinforcement-learning stages of DeepSeek-R1 work this way: sample a group of candidate solutions to each math problem, check which are correct, and reward the correct ones. The model learns which of its OWN strategies work, purely from trial and error.
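The group-relative scoring step can be sketched as follows, assuming one binary verifier reward per candidate. Each candidate's advantage is its reward standardized against its own group. The full GRPO objective also includes a clipped policy-ratio term and regularization, both omitted here. The group size of 8 is arbitrary.

```python
import numpy as np

def group_relative_advantages(rewards, eps=1e-8):
    """Standardize each candidate's reward against its own group.
    If every candidate gets the same reward, all advantages are 0 (no signal)."""
    r = np.asarray(rewards, dtype=float)
    return (r - r.mean()) / (r.std() + eps)

# 8 sampled solutions to one problem, graded by a binary verifier (1 = correct)
rewards = [1, 0, 0, 1, 0, 0, 0, 1]
adv = group_relative_advantages(rewards)
print(np.round(adv, 3))
```

Correct solutions end up above zero and incorrect ones below. A group where every answer is right (or every answer is wrong) contributes no gradient.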

Worked Example: SPIN Iteration

The model generates "2+2=5" (rejected). The ground truth is "2+2=4" (chosen). With β=0.1:

log Pθ("2+2=4") = -2.3,   log Pθ("2+2=5") = -1.8
log Pref("2+2=4") = -1.5,   log Pref("2+2=5") = -2.5

The DPO logit is:

h = β × [(log Pθ(chosen) - log Pref(chosen)) - (log Pθ(rej) - log Pref(rej))]
= 0.1 × [(-2.3 - (-1.5)) - (-1.8 - (-2.5))]
= 0.1 × [(-0.8) - (0.7)] = 0.1 × (-1.5) = -0.15

Loss = -log σ(-0.15) = -log(0.4626) ≈ 0.771. This gradient pushes the model to increase P("2+2=4") relative to P("2+2=5"). After the update, the model prefers the correct answer slightly more. Repeat for many examples across many rounds.
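The DPO arithmetic above can be checked in a few lines. The derivative confirms the direction of the update. It is negative, so gradient descent increases h, which raises the chosen sequence's log-ratio relative to the rejected one.

```python
import math

beta = 0.1
logp_chosen, logp_rejected = -2.3, -1.8   # current model
ref_chosen, ref_rejected = -1.5, -2.5     # frozen reference

h = beta * ((logp_chosen - ref_chosen) - (logp_rejected - ref_rejected))
sigma = 1.0 / (1.0 + math.exp(-h))        # σ(h)
loss = -math.log(sigma)
dloss_dh = -(1.0 - sigma)                 # d/dh of -log σ(h)
print(f"h = {h:.3f}, sigma(h) = {sigma:.4f}, loss = {loss:.3f}, dloss/dh = {dloss_dh:.3f}")
```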

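After 3-5 rounds, the model's outputs become indistinguishable from the ground-truth responses (and therefore from the previous round's model). At that point, "chosen" and "rejected" look the same to DPO, so the expected log-ratio margin is zero and no consistent gradient remains. Training has saturated.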

python
import torch
import torch.nn.functional as F

# Assumes a hypothetical model interface (not a specific library's API):
#   model.generate(prompts, do_sample=True, temperature=...) -> sampled responses
#   model.log_prob(prompts, responses) -> [B] summed log P(response | prompt)

def spin_round(model, ref_model, prompts, ground_truth, beta=0.1):
    """SPIN (Self-Play fINe-tuning) loss for one batch of a round.
    model:        current policy to improve
    ref_model:    frozen copy of the model from the start of this round (DPO reference)
    ground_truth: high-quality responses ("chosen")
    """
    # Generate "rejected" from the current model (no gradient through sampling)
    with torch.no_grad():
        rejected = model.generate(prompts, do_sample=True, temperature=0.7)

    # Log-probs under the current model (WITH gradient)
    log_p_chosen = model.log_prob(prompts, ground_truth)
    log_p_rejected = model.log_prob(prompts, rejected)

    # Log-probs under the frozen reference
    with torch.no_grad():
        log_ref_chosen = ref_model.log_prob(prompts, ground_truth)
        log_ref_rejected = ref_model.log_prob(prompts, rejected)

    # DPO loss on the log-ratio margin
    logits = beta * (
        (log_p_chosen - log_ref_chosen) -
        (log_p_rejected - log_ref_rejected)
    )
    loss = -F.logsigmoid(logits).mean()
    return loss

def check_saturation(model, ref_model, test_prompts, threshold=0.01):
    """Has SPIN converged? Monte Carlo estimate of KL(model || ref) on model samples."""
    with torch.no_grad():
        outputs = model.generate(test_prompts, do_sample=True)  # sample at temperature 1
        kl = (model.log_prob(test_prompts, outputs) -
              ref_model.log_prob(test_prompts, outputs)).mean()
    print(f"KL(model || ref) ≈ {kl.item():.4f}")
    return abs(kl.item()) < threshold  # True → saturated, stop training
Self-Play Convergence

Watch the current model (orange) converge toward the reference (teal). After ~6 rounds, the signal saturates.

Misconception: People think self-distillation is "free lunch" — unlimited improvement from nothing. It's NOT. It saturates. You can only improve to the level of the reference / ground truth. To go BEYOND that level, you need external signal: a verifier (SD-ZERO), human feedback (RLHF), or a stronger teacher. Self-play is a polishing step, not a capability-creation step. Labs use it as ONE stage in a multi-stage pipeline: SFT → self-distill → RL → self-distill again.
Why does SPIN stop improving after 3-5 iterations?

Chapter 7: Training Stability — Making It Work in Practice

Everything we've discussed sounds great in theory. In practice, on-policy training is a knife's edge. The data distribution changes with every gradient step. The student generates new text, which changes the loss landscape, which changes the gradients, which changes the student, which changes the generated text. It's a feedback loop that can spiral into instability or mode collapse in minutes.

Three pillars make on-policy distillation stable in production systems. Without them, you burn GPU-hours producing garbage.

Pillar 1: Token Importance Weighting (TIP)

Not all tokens are equally informative. Consider three cases:

Teacher Confidence | Student Agreement | Action | Why
High (entropy < 0.2) | Low (divergence > 0.5) | TRAIN | Teacher knows the answer, student doesn't: maximum learning signal
High | High (divergence < 0.1) | Skip | Both already agree: zero information gain, wasted compute
Low (entropy > 1.0) | Any | Skip | Teacher itself is uncertain: its "guidance" here is noise, not signal

The importance score for each token position combines teacher confidence with student confusion:

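$$w_t = \left(1 - \frac{H\big(P_T(\cdot \mid y_{<t})\big)}{\log V}\right) \times D_{\mathrm{KL}}\big(P_T(\cdot \mid y_{<t}) \,\|\, P_\theta(\cdot \mid y_{<t})\big)$$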

Where H is the teacher's entropy at position t, V is vocabulary size (so H/log V normalizes to [0,1]), and DKL measures how different student is from teacher. High weight = teacher confident AND student confused. This is the "golden quadrant" — maximum information gain per gradient step.
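The importance weight can be computed for every position at once from the full next-token distributions. This is a NumPy sketch that assumes `p_T` and `p_S` are [T, V] arrays of teacher and student probabilities, with natural logs throughout. The small `eps` guards against log(0). The two toy positions are hypothetical: a confident teacher facing a confused student, and a uniform (maximally uncertain) teacher.

```python
import numpy as np

def importance_weights(p_T, p_S, eps=1e-12):
    """w_t = (1 - H(P_T,t) / log V) * KL(P_T,t || P_S,t) for each position t."""
    V = p_T.shape[-1]
    log_T = np.log(p_T + eps)
    log_S = np.log(p_S + eps)
    entropy_T = -(p_T * log_T).sum(-1)          # [T] teacher entropy (nats)
    kl = (p_T * (log_T - log_S)).sum(-1)        # [T] forward KL per position
    confidence = 1.0 - entropy_T / np.log(V)    # in [0, 1]
    return confidence * kl

p_T = np.array([[0.97, 0.01, 0.01, 0.01],   # confident teacher
                [0.25, 0.25, 0.25, 0.25]])  # uniform teacher
p_S = np.array([[0.25, 0.25, 0.25, 0.25],   # confused student
                [0.97, 0.01, 0.01, 0.01]])
print(importance_weights(p_T, p_S))
```

The uniform-teacher position gets a weight of essentially zero even though the student disagrees strongly there. That is the "teacher is uncertain, skip" row of the table.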

Pillar 2: Curriculum Scheduling (PACED)

Start with easy sequences and ramp difficulty. "Easy" = short sequences, simple prompts, high teacher confidence. "Hard" = long sequences, multi-step reasoning, ambiguous prompts. A beta-kernel sampler selects training examples from the frontier — sequences just barely beyond the student's current ability. Too easy = no learning. Too hard = noise.
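The text does not spell out PACED's kernel. The sketch below is one hypothetical reading. It measures each example's difficulty by the student's estimated success rate p on that prompt and weights examples with a Beta(2, 2)-shaped kernel, p(1 - p). That weight is zero for examples the student always solves (too easy) or never solves (too hard) and peaks at the frontier. Both the difficulty measure and the kernel parameters are assumptions.

```python
import numpy as np

def beta_kernel_weights(success_rates, a=2.0, b=2.0):
    """Unnormalized Beta(a, b)-shaped weight over each example's success rate p.
    With a = b = 2 the weight is p * (1 - p), largest at p = 0.5."""
    p = np.clip(np.asarray(success_rates, dtype=float), 0.0, 1.0)
    return p ** (a - 1) * (1 - p) ** (b - 1)

def sample_frontier(success_rates, batch_size, rng=None, a=2.0, b=2.0):
    """Sample example indices with probability proportional to the kernel weight."""
    rng = np.random.default_rng() if rng is None else rng
    w = beta_kernel_weights(success_rates, a, b)
    if w.sum() == 0:  # every example is trivially easy or impossible: fall back to uniform
        w = np.ones_like(w)
    return rng.choice(len(w), size=batch_size, replace=True, p=w / w.sum())

# Hypothetical success rates for five prompts (fraction of student samples judged correct)
rates = [0.0, 0.1, 0.5, 0.9, 1.0]
print(beta_kernel_weights(rates))
print(sample_frontier(rates, batch_size=8, rng=np.random.default_rng(0)))
```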

Pillar 3: Flawed Prefix Detection

If the student generates garbage (repetitive loops like "the the the the", incoherent word salad, or degenerate patterns), the teacher's feedback on that garbage is NOISE. The teacher was trained on coherent text — asking it "what comes after 'the the the the the'?" produces meaningless output. Training on that meaningless output actively harms the student.

Solution: compute the perplexity of the student's generated prefix. If PPL > threshold (typically 100-500), discard that prefix and resample. Only train on prefixes that are at least minimally coherent.
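Here is a minimal sketch of the perplexity filter. The text does not say which model scores the prefix. This sketch assumes the per-token log-probabilities come from the teacher's forward pass on the student's prefix, which is already computed for the loss. It uses a threshold of 200, inside the 100-500 range above.

```python
import math

def prefix_perplexity(token_log_probs):
    """PPL = exp(-mean natural-log probability) over the prefix's tokens."""
    if len(token_log_probs) == 0:
        raise ValueError("empty prefix")
    return math.exp(-sum(token_log_probs) / len(token_log_probs))

def keep_prefix(token_log_probs, max_ppl=200.0):
    """True -> coherent enough to train on; False -> discard and resample."""
    return prefix_perplexity(token_log_probs) <= max_ppl

coherent = [math.log(0.4)] * 20    # tokens the scorer finds plausible: PPL = 2.5
garbage = [math.log(0.001)] * 20   # tokens the scorer finds absurd: PPL = 1000
print(keep_prefix(coherent), keep_prefix(garbage))
```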

Worked Example: Scoring 5 Tokens

python
import numpy as np

V = 32000      # vocabulary size
log_V = np.log(V)  # ≈ 10.37, maximum possible entropy

# Each token: (word, teacher_entropy, KL_divergence)
tokens = [
    ("the",         0.08, 0.03),   # function word, both agree
    ("integration", 0.15, 0.82),   # content word, student confused
    ("um",          2.30, 0.45),   # filler, teacher uncertain too
    ("equals",      0.05, 0.91),   # math token, student very wrong
    ("maybe",       1.80, 0.12),   # hedge word, teacher unsure
]

threshold = 0.30
print(f"{'Token':<12} {'H_T':<6} {'KL':<6} {'Score':<8} {'Decision'}")
print("-" * 50)
for word, h_t, kl in tokens:
    confidence = 1.0 - h_t / log_V    # how sure the teacher is
    score = confidence * kl           # importance weight
    action = "TRAIN" if score >= threshold else "skip"
    print(f"{word:<12} {h_t:<6.2f} {kl:<6.2f} {score:<8.3f} {action}")

# Token        H_T    KL     Score    Decision
# --------------------------------------------------
# the          0.08   0.03   0.030    skip
# integration  0.15   0.82   0.808    TRAIN  ← golden!
# um           2.30   0.45   0.350    TRAIN  (borderline)
# equals       0.05   0.91   0.906    TRAIN  ← golden!
# maybe        1.80   0.12   0.099    skip

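Tokens "integration" and "equals" are in the golden quadrant: teacher is confident, student is wrong. These positions carry most of the learning signal. "the" and "maybe" are both worthless to train on: "the" because everyone agrees, "maybe" because nobody knows. Note that "um" clears the 0.30 threshold even though its teacher entropy (2.30) is above the 1.0 "skip" line in the table above. Dividing by log V ≈ 10.4 makes the confidence term lenient, so a hard entropy cutoff like the one in that table would be needed to catch it.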

The compute savings are enormous:

Approach | Relative Cost | Why
Naive on-policy (all tokens) | 3-5× baseline | Generate + teacher forward + student backward on every token
+ Token filtering (skip ~60%) | 1.5× | Only backpropagate on high-importance tokens
+ Prefix caching & async generation | 0.8× | Overlap generation with gradient computation in pipeline
+ Flawed prefix rejection | 0.6× | Don't waste teacher queries on degenerate student outputs
The industrial pipeline schedule: Off-policy warmup (first 20% of training, λ=0) → Mixed phase (next 40%, λ=0.3→0.7 ramping) → Full on-policy (final 40%, λ=1.0). This avoids on-policy instability when the student is still terrible early on, and captures the full benefit of distribution matching later.
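The three-phase schedule can be written as a function of training progress. This sketch follows the percentages above. The jump from 0.7 to 1.0 at the 60% mark is taken literally from the description and could be smoothed in practice.

```python
def lambda_schedule(step, total_steps):
    """On-policy fraction λ for the three-phase schedule.

    first 20% of training: λ = 0                     (off-policy warmup)
    next 40%:              λ ramps linearly 0.3 -> 0.7 (mixed)
    final 40%:             λ = 1.0                   (full on-policy)
    """
    frac = step / total_steps
    if frac < 0.2:
        return 0.0
    if frac < 0.6:
        return 0.3 + (0.7 - 0.3) * (frac - 0.2) / 0.4
    return 1.0

for step in (0, 1999, 2000, 4000, 5999, 6000, 9999):
    print(step, round(lambda_schedule(step, 10_000), 3))
```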
Token Importance Scorer

Each token colored by importance score. Adjust threshold to see which tokens get trained on vs. skipped. Below: the 2D scatter — golden quadrant (low teacher entropy, high divergence) highlighted.

Misconception: People think "just train on everything on-policy" is the right default. It's NOT. This WASTES compute on uninformative tokens (where student already agrees with teacher) and HARMS training on noisy tokens (where teacher is uncertain too — its gradient at those positions is random noise). Selective training gives the same final quality in 40% of the compute. The token importance filter is one of the highest-ROI engineering decisions in distillation pipelines.
Why should you SKIP tokens where the teacher has high entropy (is uncertain)?

Chapter 8: OPD Lab — Full Pipeline Simulator

Time to run the full pipeline yourself. This simulator lets you compare every distillation method we've covered on a concrete task: matching a bimodal teacher distribution over varying sequence lengths. The teacher has two equally valid modes (like a translation task with two correct phrasings). The student starts as a broad, uncertain distribution and must learn to match both peaks.

Watch the key behaviors emerge:

On-Policy Distillation Laboratory

Top: teacher (teal) vs student (orange) distributions. Bottom: KL divergence over training steps. Select method, set T, and train.

Experiments to try:
  1. Set T=5, Off-Policy → works fine (converges in ~40 steps). Now set T=40 → watch it fail.
  2. Switch to GKD λ=1.0 at T=40 → steady convergence. Linear error holds.
  3. Try MiniLLM (RKL) → one peak grows, the other vanishes. Mode-seeking in action.
  4. Self-Play → blazing fast for 15 steps, then FLAT. Saturation is real.
  5. Compare all methods at T=10 (similar) vs T=50 (enormous gap). This IS the DAgger bound.
Method | Error Scaling | Mode Behavior | Failure Pattern
Off-Policy | O(εT²) | Mode-covering | Catastrophic at long T
GKD λ=0.5 | O(εT) | Balanced (both peaks) | Slow initial convergence
GKD λ=1.0 | O(εT) | Balanced (both peaks) | Noisy early, best final quality
MiniLLM (RKL) | O(εT) | Mode-seeking (one peak) | Drops valid modes entirely
Self-Play (SPIN) | O(εT) → flat | Matching (reference-bounded) | Saturates at ~30 steps

The showcase demonstrates why production systems use hybrid pipelines: off-policy warmup for early stability, GKD for the main training phase, and optional RKL + verifier for final polishing on deterministic tasks (math, code) where mode-seeking is desirable.

Chapter 9: Mastery & Connections

Everything you've learned, compressed for interview prep, implementation, and further study. This chapter tests whether you can APPLY the concepts, not just recognize them.

Derivation Challenge

Derive the DAgger bound from scratch by filling in each "???":

Setup: Student has per-step error ε under training distribution. At step t:

1. P(at least one mistake in steps 1..t) ≤ ???

2. Once off-distribution, worst-case per-step error = ???

3. Expected error at step t = ???

4. Total over T steps: ∑_{t=1}^{T} tε = ε · ???

5. With on-policy (DAgger): error at step t = ???
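For reference, here is one way to fill in the blanks using a union bound over steps. It is a sketch of the standard argument at the level of rigor used in this course, with costs normalized so a single step's error is at most 1.

```latex
\begin{aligned}
&\text{1. } P(\text{at least one mistake in steps } 1..t) \le t\varepsilon
  && \text{(union bound)}\\
&\text{2. Once off-distribution, worst-case per-step error} = 1\\
&\text{3. } \mathbb{E}[\text{error at step } t] \le P(\text{mistake in } 1..t)\cdot 1 \le t\varepsilon\\
&\text{4. } \sum_{t=1}^{T} t\varepsilon = \varepsilon \cdot \frac{T(T+1)}{2} = O(\varepsilon T^2)\\
&\text{5. On-policy: error at step } t \le \varepsilon
  \;\Rightarrow\; \sum_{t=1}^{T} \varepsilon = \varepsilon T = O(\varepsilon T)
\end{aligned}
```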

Design Challenge

You're distilling a 70B reasoning model into 7B for deployment. Budget: 1000 GPU-hours. Design the pipeline by filling in each "?" cell:

Stage | Budget % | Method | λ | Rationale
Warmup | ? | ? | ? | ?
Main | ? | ? | ? | ?
RL Polish | ? | ? | ? | ?
Final | ? | ? | ? | ?


Break-It Lab

Three scenarios where OPD breaks. For each, identify the failure mode and the fix.

Scenario 1: Student generates "the the the the the..." for 50 tokens. You query the teacher to score it.

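Failure + Fix: Flawed prefix. The teacher was trained on coherent text, so its scores on a degenerate loop are noise, and training on them harms the student. Fix: compute the prefix's perplexity, discard it if it exceeds the threshold (typically 100-500), and resample (Pillar 3, Chapter 7).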

Scenario 2: Self-play (SPIN) metrics are flat after round 3. No improvement.

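Failure + Fix: Saturation. The model's outputs now match the reference/ground truth, so "chosen" and "rejected" look alike and the DPO signal vanishes. Fix: add an external signal, such as a verifier (SD-ZERO), human feedback (RLHF), or a stronger teacher.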

Scenario 3: Student trained with FKL on math problems outputs "7" when valid answers are "4" and "12".

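Failure + Fix: Mode-covering. Forward KL spreads probability mass over both valid answers, so the student lands between the modes ("7" between "4" and "12"). Fix: switch to a mode-seeking objective for this deterministic task, e.g. reverse KL (MiniLLM) plus a verifier in the final polishing stage.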

Cheat Sheet

Concept | Key Equation | When to Use
Off-policy loss | E_{y~P_T}[KL(P_T || P_θ)] | Warmup; short sequences only
On-policy loss | E_{y~P_θ}[KL(P_T || P_θ)] | Main training; long sequences
GKD mix | λ·L_on + (1-λ)·L_off | Always; interpolate via λ
DAgger off-policy | O(εT²) | Predicting failure at length T
DAgger on-policy | O(εT) | Linear error guarantee
RL equivalence | max E_{P_θ}[log P_T] + H(P_θ) | When using reverse KL (MiniLLM)
Token importance | (1 - H_T/log V) × KL | Filtering uninformative tokens
Reward extrap. | R = log P_T + β·R_ext | Exceeding teacher quality

Connections

Leads to:

  • Agent distillation (tool use, multi-turn)
  • Speculative decoding (student as draft model)
  • Scaling laws for distillation compute
  • Test-time compute & self-refinement

"What I cannot create, I do not understand." — Richard Feynman. In distillation: what the student cannot generate on its own, it has not truly learned.

What unifies knowledge distillation, RLHF, and imitation learning into a single framework?