How to train small models that actually reason — by letting them practice on their own mistakes, not just copy perfect answers.
You've trained a small model to imitate GPT-4's outputs. It sounds fluent on short prompts. Then you ask it to solve a 10-step math problem, and it falls apart after step 3. The answers are grammatically perfect but logically incoherent. Why?
The problem is deceptively simple: during training, your student model always conditioned on perfect prefixes — the teacher's token sequences. It never saw its own mistakes. At inference time, the student generates from its own imperfect prefixes, and the moment it makes a small error, it enters territory it has literally never encountered.
Off-policy distillation means the student trains on data generated by someone else — typically the teacher model. The training signal is: "given this perfect prefix that the teacher wrote, predict the next token distribution." The student learns to be a good next-token predictor when everything before it is correct.
This works beautifully for short sequences. A single sentence, a quick factual answer — the student rarely drifts far enough from teacher territory to matter.
But for multi-step reasoning? Catastrophe.
Imagine learning to drive by watching 10,000 hours of perfect dashcam footage. You see smooth lane changes, perfect braking distances, flawless parallel parking. You memorize it all.
Then someone puts you behind the wheel. You drift slightly left. Nothing in your training covered "recovering from a leftward drift" — you only ever saw the car perfectly centered. So you do something random. Now you're even further from any state you've seen. Panic compounds.
On-policy distillation means the student generates (some or all) prefixes itself, then receives the teacher's guidance on those self-generated contexts. The training distribution matches the inference distribution because both come from the student.
When the student drifts at step 3, it still gets teacher signal at step 4. It learns: "when I'm in this slightly-wrong state, here's how to recover." Over thousands of such episodes, the student builds robustness to its own failure modes.
Toggle between modes to see how the student behaves at inference time. In off-policy mode, the student has never seen its own errors — it enters unknown territory after a mistake. In on-policy mode, it has practiced recovery.
The intuition from Chapter 0 has a precise mathematical formulation. Let's derive exactly WHY errors compound — and quantify how much worse off-policy is than on-policy.
Standard knowledge distillation trains the student to match the teacher's output distribution, conditioned on teacher-generated prefixes: L_KD(θ) = E_{y∼pT}[ ∑t DKL( pT(· | x, y<t) || pθ(· | x, y<t) ) ].
The critical detail: y<t comes from the teacher. The student never conditions on its own generated tokens during training. At inference, however, it must condition on its own previous outputs.
This creates a train-test distribution mismatch. During training, the student sees the distribution of prefixes pT(y<t). During inference, it sees pθ(y<t). These two distributions diverge more with each token generated.
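Let ε be the per-step error probability, that is, the chance that the student generates a token that deviates from the teacher's distribution at any given step. After T steps, the expected total error grows as O(εT²) when the student trains only on teacher prefixes, versus O(εT) when it trains on its own rollouts.
This is the compounding-error bound from the DAgger analysis (Ross et al., 2011), originally proven for imitation learning in robotics. It applies directly to autoregressive language modeling.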
Suppose your student has 95% per-step accuracy (ε = 0.05). For a 10-step reasoning chain:
Independent errors (best case): P(all correct) = 0.95^10 ≈ 0.60. Already losing 40% of trajectories.
With compounding (off-policy): Each error makes subsequent errors more likely because the student enters unseen territory. The actual success rate drops well below 60%, and the expected error follows the O(εT²) = O(0.05 × 100) = O(5) scaling, meaning errors dominate completely.
With on-policy training: The student has practiced recovery. Errors at step 3 don't cascade because the student knows what to do from its own imperfect states. The expected error follows the O(εT) = O(0.05 × 10) = O(0.5) scaling, which is much more controlled.
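The arithmetic in this example can be checked in a few lines. This is a minimal sketch: the function names are illustrative, and the constant hidden by the O(·) notation is assumed to be 1, so the two scaling values are only meaningful relative to each other.

```python
def p_all_correct(eps: float, steps: int) -> float:
    """Probability that every step is correct if errors are independent."""
    return (1.0 - eps) ** steps


def off_policy_scaling(eps: float, steps: int) -> float:
    """Quadratic scaling eps * T^2 (constant factor assumed to be 1)."""
    return eps * steps ** 2


def on_policy_scaling(eps: float, steps: int) -> float:
    """Linear scaling eps * T (constant factor assumed to be 1)."""
    return eps * steps


eps, steps = 0.05, 10
print(f"P(all correct)        = {p_all_correct(eps, steps):.3f}")
print(f"off-policy eps * T^2  = {off_policy_scaling(eps, steps):.2f}")
print(f"on-policy  eps * T    = {on_policy_scaling(eps, steps):.2f}")
```

Doubling `steps` doubles the linear term but quadruples the quadratic one, which is why the gap between the two regimes widens with sequence length.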
On-policy distillation replaces teacher prefixes with student-generated (or mixed) prefixes: L_OPD(θ) = E_{y∼πmix}[ ∑t Df( pT(· | x, y<t), pθ(· | x, y<t) ) ].
Here πmix is a mixture policy — typically λ · pθ + (1-λ) · pdata — that blends student rollouts with ground-truth sequences. The divergence Df can be forward KL, reverse KL, JSD, or any f-divergence (Chapter 2 explores which to choose).
Adjust per-step error ε and see how total error scales with sequence length T. The quadratic (off-policy) curve explodes while the linear (on-policy) curve remains manageable.
Now we know WHY on-policy helps. The next question: what should the student optimize when training on its own outputs? The choice of divergence measure fundamentally shapes what the student learns.
An f-divergence is a general way to measure how different two probability distributions are. Every f-divergence has the form: Df(P || Q) = ∑x Q(x) · f( P(x) / Q(x) ),
where f is a convex function with f(1) = 0. Different choices of f give different divergences — and each one tells the student to prioritize different aspects of the teacher distribution.
Forward KL (f(u) = u log u) measures DKL(Pteacher || Qstudent). The expectation is over the teacher's distribution. This means: wherever the teacher has probability mass, the student MUST also have probability mass (or the divergence explodes).
Result: the student covers all modes of the teacher. If the teacher assigns probability to two different valid answers, the student spreads probability over both. But it also assigns probability to the space between the modes — hallucinating intermediate outputs that neither mode produces.
Reverse KL (f(u) = -log u) measures DKL(Qstudent || Pteacher). The expectation is over the student's distribution. This means: wherever the student has probability mass, the teacher must also (or the divergence explodes).
Result: the student seeks one mode and concentrates on it. It won't hallucinate between modes because it never places mass where the teacher doesn't. But it drops entire modes — if the teacher has two valid answers, the student picks one and ignores the other.
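To make the mode-covering versus mode-seeking contrast concrete, here is a small numerical sketch. It assumes the discrete form Df(P || Q) = ∑ Q(x) · f(P(x)/Q(x)) with P the teacher and Q the student. The three distributions are made-up toy values, and the conclusion holds for these particular numbers rather than for every possible pair of students.

```python
import math


def f_divergence(p, q, f):
    """D_f(P || Q) = sum_x Q(x) * f(P(x) / Q(x)) for discrete distributions."""
    return sum(qx * f(px / qx) for px, qx in zip(p, q))


def forward_f(u):
    return u * math.log(u)      # yields KL(P || Q)


def reverse_f(u):
    return -math.log(u)         # yields KL(Q || P)


def kl(a, b):
    """Direct KL(A || B), used only to cross-check the f-divergence form."""
    return sum(ax * math.log(ax / bx) for ax, bx in zip(a, b))


# Toy bimodal teacher over 5 outcomes: modes at index 0 and index 4.
teacher = [0.48, 0.01, 0.02, 0.01, 0.48]
covering = [0.25, 0.17, 0.16, 0.17, 0.25]   # spreads over both modes and the gap
seeking = [0.96, 0.01, 0.01, 0.01, 0.01]    # commits to a single mode

results = {}
for name, student in [("covering", covering), ("seeking", seeking)]:
    fkl = f_divergence(teacher, student, forward_f)
    rkl = f_divergence(teacher, student, reverse_f)
    results[name] = (fkl, rkl)
    print(f"{name:9s} forward KL = {fkl:.3f}   reverse KL = {rkl:.3f}")
```

With these toy numbers, forward KL scores the covering student as closer to the teacher, while reverse KL scores the single-mode student as closer.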
The Jensen-Shannon Divergence (JSD) is a symmetric divergence computed through a mixture M = (P+Q)/2: JSD(P, Q) = ½ · DKL(P || M) + ½ · DKL(Q || M). It's bounded in [0, log 2], which keeps the loss and its gradients from exploding.
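The log 2 ceiling is easy to verify numerically. The sketch below assumes the standard definition JSD(P, Q) = ½ · KL(P || M) + ½ · KL(Q || M) and uses two distributions with disjoint support as the worst case; the helper names are illustrative.

```python
import math


def kl(a, b):
    """KL(A || B) for discrete distributions; terms with a_i = 0 contribute 0."""
    return sum(ax * math.log(ax / bx) for ax, bx in zip(a, b) if ax > 0)


def jsd(p, q):
    """Jensen-Shannon divergence via the mixture M = (P + Q) / 2."""
    m = [(px + qx) / 2 for px, qx in zip(p, q)]
    return 0.5 * kl(p, m) + 0.5 * kl(q, m)


p = [1.0, 0.0]   # all mass on outcome 0
q = [0.0, 1.0]   # all mass on outcome 1 (disjoint support)

print(jsd(p, q), math.log(2))   # the worst case reaches log 2
print(jsd(p, p))                # identical distributions give 0
```

Forward or reverse KL would be infinite for this pair, because one side assigns zero probability where the other has mass. The mixture M is what keeps each term finite.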
The α-divergence family provides a continuous interpolation parameter α between forward KL (α → 1) and reverse KL (α → 0). Setting α = 0.5 gives a symmetric divergence similar to JSD.
| Task Type | Best Divergence | Why |
|---|---|---|
| Math / Code | Reverse KL | One correct answer — mode-seeking concentrates on it |
| Creative writing | Forward KL | Many valid outputs — mode-covering preserves diversity |
| General chat | JSD / α-div | Balance between precision and coverage |
| Safety-critical | Reverse KL | Must not hallucinate outside teacher's support |
A bimodal teacher (two valid answers). Slide from Forward KL to Reverse KL and watch how the student distribution changes. Forward KL spreads to cover both modes (hallucinating between). Reverse KL collapses onto one mode.
Three papers crystallized the core approaches to on-policy distillation. Each picks a different point in the design space — and each reveals a different tradeoff.
GKD is the simplest framework. Its key insight: make on-policy distillation modular. Separate the sampling policy (who generates the prefixes?) from the divergence (what does the student optimize?).
The sampling policy is a mixture: πmix = λ · pθ + (1-λ) · pdata.
When λ = 0, it's standard off-policy (train on ground-truth prefixes). When λ = 1, it's fully on-policy (train only on student's own generations). Values in between blend both — the student sees some perfect prefixes and some of its own imperfect ones.
GKD is divergence-agnostic: you can plug in forward KL, reverse KL, JSD, or any f-divergence. The sampling policy and the divergence are orthogonal choices.
MiniLLM goes all-in on reverse KL. The problem: reverse KL DKL(pθ || pT) requires computing the gradient through the student's own sampling process (because the expectation is over pθ). You can't just backpropagate through discrete token sampling.
MiniLLM's solution: treat it as a reinforcement learning problem. The "reward" for generating token yt is: rt = log pT(yt | x, y<t) − log pθ(yt | x, y<t).
This is the log-probability ratio: how much more the teacher likes this token than the student does. High reward = the teacher strongly approves of something the student isn't yet confident about. The REINFORCE estimator uses this reward to compute policy gradients.
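The log-probability-ratio reward can be sketched on plain lists of numbers. The log-probabilities below are hypothetical values for three sampled tokens, and mean-centering is just one simple baseline choice, not necessarily the one MiniLLM uses.

```python
def token_rewards(teacher_logprobs, student_logprobs):
    """Per-token reward: teacher log-prob minus student log-prob of the sampled token."""
    return [lt - ls for lt, ls in zip(teacher_logprobs, student_logprobs)]


def centered(rewards):
    """Subtract the mean reward as a simple variance-reduction baseline."""
    mean = sum(rewards) / len(rewards)
    return [r - mean for r in rewards]


# Hypothetical log-probabilities of three sampled tokens under each model.
teacher_lp = [-0.1, -2.3, -0.5]
student_lp = [-0.7, -0.2, -0.5]

rewards = token_rewards(teacher_lp, student_lp)
advantages = centered(rewards)
print(rewards)      # positive where the teacher likes the token more than the student
print(advantages)   # same signal, shifted to have zero mean
```

The second token gets a strongly negative reward: the student was confident in it, but the teacher was not.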
The catch: REINFORCE has high variance. The gradient estimates are noisy, requiring large batch sizes, careful baseline subtraction, and many training steps to converge. Training is unstable — the loss oscillates wildly before settling.
DistiLLM solves MiniLLM's instability problem with a clever trick: instead of minimizing DKL(pθ || pT) directly, minimize DKL(pθ || p̃) where p̃ is a skewed mixture: p̃ = α · pT + (1-α) · pθ.
Why does this help? The reverse-KL objective contains the log-ratio log(pθ/pT). When the student samples a token to which the teacher assigns near-zero probability, pθ/pT → ∞, and the loss and its gradient explode.
By mixing in the student's own distribution, DistiLLM ensures p̃ never falls too far below pθ. Even when pT(y) is near zero, p̃(y) = α · pT(y) + (1-α) · pθ(y) ≥ (1-α) · pθ(y), so the ratio pθ/p̃ can never exceed 1/(1-α). Gradient explosions are tamed.
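A quick numerical check shows the effect of the skewed target. This sketch assumes the mixture p̃ = α · pT + (1-α) · pθ and looks only at the log-ratio inside the reverse-KL expectation for a single token. The probabilities are made-up values.

```python
import math


def log_ratio_plain(p_student, p_teacher):
    """log(p_theta / p_T): the term inside reverse KL; unbounded as p_T -> 0."""
    return math.log(p_student / p_teacher)


def log_ratio_skewed(p_student, p_teacher, alpha=0.9):
    """log(p_theta / p_mix) with p_mix = alpha * p_T + (1 - alpha) * p_theta."""
    p_mix = alpha * p_teacher + (1 - alpha) * p_student
    return math.log(p_student / p_mix)


alpha = 0.9
cap = math.log(1 / (1 - alpha))   # upper bound on the skewed log-ratio

for p_teacher in (1e-2, 1e-6, 1e-12):
    plain = log_ratio_plain(0.5, p_teacher)
    skewed = log_ratio_skewed(0.5, p_teacher, alpha)
    print(f"p_T={p_teacher:.0e}  plain={plain:6.2f}  skewed={skewed:5.2f}  cap={cap:.2f}")
```

As the teacher probability shrinks, the plain log-ratio keeps growing, while the skewed one flattens out just under `cap`.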
```python
import random

# Pseudocode: `student`, `teacher`, `ground_truth`, `compute_divergence`,
# `moving_avg`, `softmax`, and `kl_div` are placeholders, not a real library API.

# GKD: simple mixture sampling + any divergence
def gkd_loss(student, teacher, x, lam=0.5, div='jsd'):
    # Sample the prefix from the mixture policy
    if random.random() < lam:
        prefix = student.generate(x)      # on-policy
    else:
        prefix = ground_truth(x)          # off-policy
    p_t = teacher.forward(x, prefix)      # teacher logits
    p_s = student.forward(x, prefix)      # student logits
    return compute_divergence(p_t, p_s, div)

# MiniLLM: reverse KL via REINFORCE
def minillm_loss(student, teacher, x):
    y = student.generate(x)               # full on-policy rollout
    logp_s = student.logprob(y, x)
    # The reward is treated as a constant: no gradient flows through it
    r = (teacher.logprob(y, x) - logp_s).detach()
    baseline = moving_avg(r)              # variance reduction
    return -((r - baseline) * logp_s).mean()

# DistiLLM: skewed mixture target
def distillm_loss(student, teacher, x, alpha=0.9):
    y = student.generate(x)
    p_t = softmax(teacher.forward(x, y))  # mix probabilities, not raw logits
    p_s = softmax(student.forward(x, y))
    p_mix = alpha * p_t + (1 - alpha) * p_s   # skewed target
    return kl_div(p_s, p_mix)             # ratio p_s / p_mix is at most 1 / (1 - alpha)
```
See how each method's student distribution looks after training on the same bimodal teacher. Adjust λ (GKD's on-policy fraction) to see how more on-policy data improves the student.
Fixed divergences are one-size-fits-all. But what if different tokens need different treatment? A confident token where teacher and student agree needs less gradient pressure than a confused token at a reasoning branch point.
Consider a math proof: tokens like "therefore" are near-deterministic — both teacher and student assign 95%+ probability. But choosing between "integration by parts" vs. "substitution" is the crux. A single divergence treats both identically. That's wasteful at best, harmful at worst.
ToDi (Token-level Divergence routing) scores each token by its log-ratio rt = log pθ(yt) − log pT(yt). If rt < 0 (student underestimates), use FKL, whose penalty is weighted by the teacher's probability, to push the student up. If rt > 0 (student overestimates), use RKL, whose penalty is weighted by the student's probability, to pull it down. The routing is per-token and parameter-free.
AKL (Adaptive KL) blends FKL and RKL based on the head/tail gap of the teacher's distribution. Peaked teacher → RKL. Flat teacher → FKL. The weight is a smooth function of teacher entropy.
EOPD interpolates continuously: αt = σ(H(pT) − τ). High teacher entropy → more FKL. Low entropy → more RKL.
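The entropy gate αt = σ(H(pT) − τ) can be sketched directly from that formula. The sketch assumes αt is the weight on FKL and (1 − αt) the weight on RKL. The threshold value and the two teacher distributions are hypothetical.

```python
import math


def entropy(probs):
    """Shannon entropy H(p) in nats."""
    return -sum(p * math.log(p) for p in probs if p > 0)


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def fkl_weight(teacher_probs, tau):
    """alpha_t = sigmoid(H(p_T) - tau): weight placed on forward KL for this token."""
    return sigmoid(entropy(teacher_probs) - tau)


def blended_loss(fkl, rkl, alpha_t):
    """Assumed blend: alpha_t * FKL + (1 - alpha_t) * RKL."""
    return alpha_t * fkl + (1.0 - alpha_t) * rkl


tau = 0.7                            # hypothetical entropy threshold (nats)
peaked = [0.97, 0.01, 0.01, 0.01]    # confident teacher, low entropy
flat = [0.25, 0.25, 0.25, 0.25]      # uncertain teacher, high entropy

print(fkl_weight(peaked, tau))   # below 0.5: leans toward RKL
print(fkl_weight(flat, tau))     # above 0.5: leans toward FKL
```

Because the gate is a sigmoid rather than a hard switch, tokens with entropy near τ receive a roughly even mix of both divergences.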
G-OPD proves a startling result: on-policy distillation IS reinforcement learning. min DKL(pθ || pT) = max E_{y∼pθ}[log pT(y)] + H(pθ).
Left side: minimize KL from student to teacher. Right side: maximize reward (teacher log-prob) plus entropy (exploration). Identical. Every RL trick — baselines, advantage estimation, PPO clipping — applies directly to distillation.
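The identity follows in two lines from the definition of KL divergence. The derivation below is a sketch at the sequence level, writing y for a full student rollout and H for Shannon entropy, and it leaves out the conditioning on the prompt x for brevity.

```latex
\begin{aligned}
D_{\mathrm{KL}}(p_\theta \,\|\, p_T)
  &= \mathbb{E}_{y \sim p_\theta}\big[\log p_\theta(y) - \log p_T(y)\big] \\
  &= -H(p_\theta) - \mathbb{E}_{y \sim p_\theta}\big[\log p_T(y)\big], \\[4pt]
\arg\min_\theta \, D_{\mathrm{KL}}(p_\theta \,\|\, p_T)
  &= \arg\max_\theta \Big( \mathbb{E}_{y \sim p_\theta}\big[\log p_T(y)\big] + H(p_\theta) \Big).
\end{aligned}
```

The first term is an expected reward with r(y) = log pT(y), and the second is the entropy bonus familiar from maximum-entropy RL.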
Each token routed to FKL or RKL based on teacher entropy. Drag the threshold to see routing change.
| Method | Divergence | Key Innovation | Best For |
|---|---|---|---|
| GKD | Fixed (any) | λ-interpolation | General |
| ToDi | Adaptive FKL↔RKL | Per-token routing via log-ratio | Mixed tasks |
| G-OPD | KL-constrained RL | Reward extrapolation beyond teacher | Reasoning |
| AOPD | PG/FKL switch | Advantage-sign routing | Math proofs |
We've discussed WHAT to optimize. Now: where does the teaching signal come from? The answer determines everything — compute budget, quality ceiling, deployment constraints.
Think of it as a spectrum. At one end, full access to the teacher's brain (every logit). At the other end, just a thumbs-up or thumbs-down on complete outputs.
You run the teacher and get its full output distribution at every token position. This is the richest signal: you know exactly how much probability the teacher assigns to every possible next token. Cost: a full teacher forward pass for every sample, with the teacher held in memory (~140 GB of weights for a 70B model at 16-bit precision).
Cross-family challenge: Different tokenizers mean logits aren't aligned. DSKD and ULD solve this via shared vocabulary projection or distribution-level matching.
In the black-box setting, you get only generated text or scalar scores. This is the reality when distilling from API-only models such as GPT-4 or Claude. GAD trains a discriminator. OVD uses the teacher to SELECT among student candidates.
The model teaches itself by exploiting asymmetries in its own capabilities, with no separate teacher model.
```python
# Pseudocode: `kl_div`, `log_prob`, `teacher_api`, and `model` are placeholders.

# White-box: exact KL at every token (dense signal)
loss = kl_div(teacher_logits, student_logits)   # logits shaped [B, T, V]

# Black-box: only an outcome reward (sparse signal)
reward = teacher_api.score(student_output)      # scalar
loss = -reward * log_prob(student_output)       # REINFORCE

# Self-distillation: the model teaches itself
good = model.generate(prompt, with_hints=True)
# The hinted pass acts as the teacher, so it goes first, matching the white-box call
loss = kl_div(model(prompt + hints).logits, model(prompt).logits)
```
The most surprising finding: you don't always need a teacher. A model can improve itself by exploiting asymmetries in its own capabilities.
OPSD: Condition on privileged information (PI), here the ground-truth answer, during training. The model with answers generates better reasoning chains. Distill: p(reasoning | x, answer) → p(reasoning | x). The model learns to reason well WITHOUT seeing the answer.
CRISP: Concise prompt as PI — saves 57% tokens at inference while maintaining quality.
GATES: Full document as PI → distill to model with only a summary.
SPIN: At iteration k, generate yk. DPO with "chosen" = reference, "rejected" = yk. Update. Repeat. Converges when pθ = pref (distributions indistinguishable, DPO signal = zero).
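The per-pair loss in a SPIN-style round can be sketched with the standard DPO formula. This sketch assumes the reference model is the previous iteration's checkpoint. The sequence log-probabilities and β are hypothetical values, and `dpo_loss` is an illustrative helper, not any library's API.

```python
import math


def dpo_loss(logp_chosen, logp_rejected, ref_logp_chosen, ref_logp_rejected, beta=0.1):
    """DPO loss for one pair: -log sigmoid(beta * (chosen margin - rejected margin))."""
    margin = (logp_chosen - ref_logp_chosen) - (logp_rejected - ref_logp_rejected)
    z = beta * margin
    # -log(sigmoid(z)) is the same as log(1 + exp(-z))
    return math.log1p(math.exp(-z))


# Start of a round: the policy equals the reference, so the margin is 0.
start = dpo_loss(-5.0, -3.0, -5.0, -3.0)

# After some updates: the policy raises the chosen response and lowers its own old sample.
later = dpo_loss(-4.0, -6.0, -5.0, -3.0)

print(start, math.log(2))   # loss starts at log 2
print(later)                # and falls as the margin grows
```

When the model's own samples become indistinguishable from the reference data, "chosen" and "rejected" come from the same distribution and the expected gradient vanishes, which is the saturation described above.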
π-Play: Multi-agent variant — several models co-evolve, playing against each other rather than a fixed reference.
SD-ZERO: Generate N solutions, verifier gives binary reward (correct/incorrect), optimize with GRPO. No teacher at all — just trial and error with automated grading. This is essentially what DeepSeek-R1-Zero does.
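The group-relative step in GRPO can be sketched on the binary rewards alone. This is a simplified illustration: the rewards are hypothetical verifier outputs, the population standard deviation is one of several normalization choices seen in practice, and the policy-gradient update itself is omitted.

```python
from statistics import mean, pstdev


def group_advantages(rewards, eps=1e-6):
    """Normalize rewards within one prompt's group of sampled solutions."""
    mu = mean(rewards)
    sigma = pstdev(rewards)
    return [(r - mu) / (sigma + eps) for r in rewards]


# Hypothetical verifier output for N = 4 sampled solutions: 1 = correct, 0 = incorrect.
mixed = group_advantages([1, 0, 0, 1])
print(mixed)        # correct samples get positive advantage, incorrect ones negative

# If every sample gets the same reward, the group carries no learning signal.
all_wrong = group_advantages([0, 0, 0, 0])
print(all_wrong)
```

The all-zero case shows why prompts that are far too hard (or far too easy) for the current model contribute nothing to the update.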
Current model (warm) vs. reference (teal). Click "Play Round" to iterate. Watch KL drop to zero (saturation).
| Approach | Signal Source | Assumption | Risk |
|---|---|---|---|
| OPSD | Ground-truth PI | GT available at train | PI leakage |
| SPIN | Previous iteration | Stable reference | Saturation after 3-5 rounds |
| SD-ZERO | Binary verifier | Correctness checkable | Reward hacking |
| GATES | Document gating | Long context helps | Not always true |
On-policy training is inherently unstable. The data distribution shifts with every gradient step. The student explores bad regions and gets noisy gradients. How do Qwen3 and DeepSeek-V4 make this actually work?
TIP (Token Importance Profiling) scores each token on a 2D quadrant: teacher entropy (is the teacher confident?) × student-teacher divergence (does the student disagree?). The golden tokens: teacher is confident but student is confused. These get highest weight.
SCOPE uses surprise weighting: tokens where student assigns very low probability but teacher assigns high get upweighted by their log-ratio.
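Surprise weighting can be sketched as a per-token weight built from the teacher/student log-ratio. The exact functional form SCOPE uses is not given here, so the form below (a floor of 1 plus the positive part of the log-ratio) is an assumption, and the probabilities are made-up values.

```python
import math


def surprise_weight(p_teacher, p_student, floor=1.0):
    """Assumed form: floor + max(0, log p_T - log p_theta) for the target token."""
    return floor + max(0.0, math.log(p_teacher) - math.log(p_student))


def weighted_mean(losses, weights):
    """Average per-token losses using the token weights."""
    return sum(l * w for l, w in zip(losses, weights)) / sum(weights)


# Hypothetical (teacher prob, student prob) for the target token at three positions.
pairs = [(0.90, 0.85), (0.80, 0.01), (0.10, 0.40)]
weights = [surprise_weight(pt, ps) for pt, ps in pairs]
print(weights)

# Hypothetical per-token losses; the surprising middle token dominates the average.
print(weighted_mean([0.1, 4.0, 0.2], weights))
```

The middle token, where the teacher is confident and the student assigns almost no probability, receives by far the largest weight. The last token, where the student already exceeds the teacher, stays at the floor.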
Start with easy samples (low teacher perplexity), gradually increase difficulty. Beta-kernel sampling from the difficulty frontier — the difficulty slider auto-adjusts as the student improves.
On-policy is 3-8× more expensive (generate + score + train). Solutions:
| Method | Cost | Trick | Speedup |
|---|---|---|---|
| Naive | 3-5× | None | 1× |
| FOPD | 1.5× | Prefix truncation | 2-3× |
| Lightning-OPD | 1.2× | Offline caching | 4× |
| NPD | 0.6× | Async + ΔI-IFD | 8.1× |
If the student generates a terrible prefix (off-topic, degenerate repetition), teacher feedback on it is meaningless noise. Detection: perplexity threshold, length filter, reward floor.
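The three detection checks can be combined into a single filter applied before the teacher scores a rollout. This is a minimal sketch: every threshold is a hypothetical default, and `token_logprobs` is assumed to hold the teacher's log-probabilities of the student's sampled tokens.

```python
import math


def perplexity(token_logprobs):
    """exp of the mean negative log-likelihood of the rollout."""
    return math.exp(-sum(token_logprobs) / len(token_logprobs))


def keep_rollout(token_logprobs, reward, max_ppl=50.0, min_len=4,
                 max_len=2048, min_reward=0.0):
    """Return True if the rollout passes the length, perplexity, and reward checks."""
    n = len(token_logprobs)
    if n < min_len or n > max_len:
        return False                       # length filter
    if perplexity(token_logprobs) > max_ppl:
        return False                       # perplexity threshold
    return reward >= min_reward            # reward floor


print(keep_rollout([-0.5] * 10, reward=1.0))    # plausible rollout: kept
print(keep_rollout([-6.0] * 10, reward=1.0))    # teacher finds it very unlikely: dropped
print(keep_rollout([-0.5] * 2, reward=1.0))     # too short: dropped
print(keep_rollout([-0.5] * 10, reward=-1.0))   # below the reward floor: dropped
```

Running the cheap length check first avoids computing perplexity for rollouts that would be rejected anyway.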
Tokens scored by importance (divergence × teacher confidence). Drag threshold — gray tokens get skipped.
Let's put it all together. You're going to run an on-policy distillation pipeline and see every design choice interact in real time.
The teacher has a bimodal distribution — two valid answer modes. The student starts broad and must learn to match. Pick a method, set sequence length, and watch.
Top: teacher (fixed) vs student (evolving). Bottom: error accumulation over T. Pick a method and train.
At short sequences (T<10), all methods work equally well. As T grows, compounding error separates them: off-policy O(εT²) explodes, on-policy O(εT) stays manageable. This is a central reason on-policy training is favored for reasoning models.
On-Policy Distillation sits at the intersection of three converging fields: knowledge distillation, reinforcement learning, and imitation learning.
| Perspective | Sampling | Objective | "Teacher" |
|---|---|---|---|
| Knowledge Distillation | Student rollouts | Df(PT || Pθ) | Larger model |
| RLHF / DPO | Student rollouts | Reward maximization | Human preferences |
| Imitation Learning | Learner trajectories | Expert correction | Expert policy |
| Self-Play | Self-generated | Improve over reference | Previous self |
The convergence is not just conceptual — it's infrastructural. OPD and RLHF need the same pipeline: rollout generation, scoring, filtering, gradient updates. They share 90% of code.
| Concept | Equation | When to Use |
|---|---|---|
| Off-policy KD | ∑ DKL(pT || pθ) on teacher prefixes | Short sequences, same family |
| On-policy (GKD) | Ey~πmix[∑ Df(pT, pθ)] | General purpose, moderate compute |
| RL equivalence | min DKL(Pθ||PT) = max E[log PT] + H | When you want reward shaping |
| Exposure bias | Off = O(εT²), On = O(εT) | Always — this is WHY OPD |