CS224N Lecture 13

Reasoning Part 2

Test-time compute, speculative decoding, and position encoding — making models think harder and run faster.

Prerequisites: L12 Reasoning 1. That's it.
10 chapters · 8+ simulations · 0 assumed knowledge

Chapter 0: Reasoning at Scale

In Lecture 12 we saw that chain-of-thought prompting lets LLMs reason step by step. But here's a question nobody asked until recently: if you have a fixed compute budget at inference time, should you spend it on a bigger model that answers once, or a smaller model that thinks longer?

This turns out to be one of the most important questions in modern AI. A 70B model answering a math problem in one shot uses roughly the same FLOPs as a 7B model generating 10 candidate solutions and picking the best one. Which strategy wins? The answer depends on the problem.

For easy problems — "What is 3 + 5?" — the big model is wasteful. The small model gets it right on the first try. For hard problems — AIME competition math, complex code generation — something remarkable happens: the small model with more attempts can match or beat the big model. The extra compute at inference time substitutes for compute at training time.

This is surprising. The conventional wisdom was: want better answers? Train a bigger model. The test-time compute revolution says: sometimes, think harder instead of being bigger. A 7B model that generates 100 solutions and verifies each step can outperform a 70B model that answers once. It's the difference between hiring one expert (expensive upfront, fast answers) vs. hiring a team of juniors with a good quality checker (cheap individuals, quality through selection).

The Crossover Point

The simulation below shows this tradeoff. On the x-axis is problem difficulty. On the y-axis is accuracy. The big model (purple) starts high but plateaus — it can't think harder, it can only think once. The small model with test-time compute (orange) starts lower but scales with more thinking. At some crossover point, the small + compute strategy wins.

Big Model vs. Small Model + Compute

Drag the difficulty slider to see where the crossover happens. The small model uses more test-time compute (multiple attempts + verification) while the big model answers in one shot.


This insight — that you can trade training compute for inference compute — launched an entirely new research direction. Instead of building ever-bigger models, what if we build smarter inference strategies?

The core question of this lecture: can a small model that thinks harder beat a big model that answers quickly? The answer is yes, but only if we build the right scaffolding around inference — process reward models, search strategies, and verification. That scaffolding is what this lesson is about.

This lesson covers two intertwined threads. First: test-time compute scaling — process reward models, best-of-N sampling, and the compute-optimal frontier that tells us when to think harder. Second: inference efficiency — speculative decoding to generate faster, and RoPE/context extension to handle longer inputs. Together, they define the modern inference stack.

A Mental Model for the Lesson

Think of an LLM deployment as having two budgets:

Quality budget (Chapters 1-3, 8): How much compute to spend making the answer better. More search, more verification, more candidates. Costs time and money, but improves accuracy on hard problems.

Speed budget (Chapters 4-6): How to make each unit of compute go further. Speculative decoding generates faster. Context extension handles longer inputs without retraining. These are pure engineering wins.

The ideal inference system optimizes both: maximize quality per unit of compute, then minimize compute per unit of wall-clock time. This lesson teaches you the key technique for each.

A 7B model generates 16 candidate solutions and picks the best. A 70B model generates one solution. They use roughly the same FLOPs. When does the 7B model win?

Chapter 1: Process Reward Models

If we're going to generate multiple reasoning attempts and pick the best one, we need a way to evaluate them. But how? We could check the final answer (is it correct?), but that throws away all the information in the reasoning chain. A solution that gets the right answer by lucky cancellation of errors looks identical to one that reasons correctly.

An Outcome Reward Model (ORM) assigns a single score to the entire solution. It is trained only on whether the final answer is correct, so its verdict amounts to "correct" or "wrong" for the solution as a whole. This is like grading an exam by only looking at the answer box: you can't distinguish a student who understood every step from one who guessed right.

A Process Reward Model (PRM) scores every intermediate step. It reads step 1 and says "correct." Step 2: "correct." Step 3: "this step contains an error." This is like a teacher grading the work, not just the answer. It knows exactly where reasoning went wrong.

Why Step-Level Feedback Matters

The Lightman et al. (2023) paper "Let's Verify Step by Step" made this concrete. They trained PRMs on the MATH dataset by having human labelers annotate every step in a solution as correct, incorrect, or neutral. The result: using the PRM to select among sampled solutions clearly outperformed using an ORM, especially on hard problems.

The intuition is simple. Consider a 5-step math proof where step 3 is wrong. An ORM sees: "final answer is wrong, score = 0.1." A PRM sees: "steps 1-2 correct (scores 0.95, 0.92), step 3 wrong (score 0.15), steps 4-5 follow from step 3 (scores 0.6, 0.4)." The PRM tells us exactly where to intervene — regenerate from step 3 onward, keeping the correct work.
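A minimal sketch of the "regenerate from step 3" idea, assuming the PRM's per-step scores are already available as a list. `first_error_step` and `prefix_to_keep` are hypothetical helpers, and the 0.5 cutoff for "this step is wrong" is an illustrative assumption, not a value from the paper:

```python
from typing import List, Optional

def first_error_step(step_scores: List[float], threshold: float = 0.5) -> Optional[int]:
    """Return the 0-based index of the first step scored below threshold, or None."""
    for i, score in enumerate(step_scores):
        if score < threshold:
            return i
    return None

def prefix_to_keep(steps: List[str], step_scores: List[float], threshold: float = 0.5) -> List[str]:
    """Keep every step before the first flagged error; regeneration restarts there."""
    idx = first_error_step(step_scores, threshold)
    return steps if idx is None else steps[:idx]

# PRM scores from the example above: steps 1-2 fine, step 3 wrong
scores = [0.95, 0.92, 0.15, 0.6, 0.4]
steps = ["step 1", "step 2", "step 3", "step 4", "step 5"]
print(first_error_step(scores))       # 2, i.e. step 3 in 1-based numbering
print(prefix_to_keep(steps, scores))  # ['step 1', 'step 2']
```

An ORM offers no equivalent: with a single score per solution there is no index to restart from.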

PRM vs ORM Scoring

A 5-step solution with an error at step 3. The PRM catches the exact failure point. ORM only sees the final answer. Click "Regenerate from Error" to see the PRM advantage: we keep correct work and retry from the first mistake.

Visualizing the Difference

To make the PRM/ORM difference concrete, imagine you're a math teacher grading 100 student exams:

ORM approach: You look at the final answer only. Student A writes "x = 7" — correct. Student B writes "x = 7" — correct. Both get full marks. But Student B made two errors that coincidentally cancelled out. You have no way to know this, and no way to help Student B improve their process.

PRM approach: You read every line of work. Student A: all steps correct. Student B: step 3 has an algebraic error, step 5 has a sign error that cancels it. Student B gets partial credit and specific feedback: "Review your factoring in step 3." This is dramatically more useful for improving future performance.

In the LLM setting, "improving future performance" means knowing where to regenerate from. The PRM tells us the exact step where reasoning went wrong, so we can keep the correct prefix and only regenerate from the error forward.

Training a PRM

How do you train a model to judge reasoning steps? Lightman et al. collected ~800K step-level labels from human annotators on GPT-4 solutions to MATH problems. Each step was labeled:

Label | Meaning | Example
Positive | Step is mathematically valid | "Expanding (x+2)^2 = x^2 + 4x + 4"
Negative | Step contains an error | "Since x^2 = 9, x = 3" (missed x = -3)
Neutral | Step is filler or restates the problem | "We need to find the value of x"

The PRM is a language model fine-tuned as a classifier: given the problem and all steps so far, predict whether the current step is correct. Practically, you take a pretrained LLM (like GPT-4 or Llama), add a classification head that outputs P(correct) for each step, and fine-tune on the labeled data.

A crucial design decision: the PRM must evaluate each step in context. It doesn't judge "x^2 + 4x + 4" in isolation — it judges whether this step follows logically from the previous steps and the problem statement. This means the PRM needs strong mathematical reasoning ability itself, which is why PRM training typically starts from a capable base model.

At inference time, the PRM's score for step i is the probability it assigns to "correct" given steps 1 through i.

$$\mathrm{PRM}(\text{step}_i) = P(\text{correct} \mid \text{problem}, \text{step}_1, \ldots, \text{step}_i)$$

The overall solution score can be computed in two ways:

$$\text{Score}_{\min} = \min_i \mathrm{PRM}(\text{step}_i) \qquad \text{Score}_{\text{prod}} = \prod_i \mathrm{PRM}(\text{step}_i)$$

The minimum score reflects the weakest link: a chain of reasoning is only as strong as its weakest step. The product score accumulates uncertainty: many slightly-uncertain steps compound into low overall confidence. In practice, the minimum works better — it directly identifies the bottleneck step.

Consider the difference with a concrete example. Solution A has PRM scores [0.95, 0.93, 0.91, 0.90, 0.88]. Solution B has scores [0.99, 0.99, 0.40, 0.99, 0.99]. The product scores are 0.64 vs 0.38, so both look questionable. But the minimum scores are 0.88 vs 0.40, immediately revealing that A is uniformly strong while B has a single catastrophic step. The minimum is a better indicator of solution quality because reasoning chains fail at their weakest point.
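The two aggregation rules can be checked on Solutions A and B with a few lines of Python; `score_min` and `score_prod` are illustrative names, not from any library:

```python
import math

def score_min(step_scores):
    """Weakest-link score: the lowest per-step PRM probability."""
    return min(step_scores)

def score_prod(step_scores):
    """Product score, read as the probability that every step is correct."""
    return math.prod(step_scores)

solution_a = [0.95, 0.93, 0.91, 0.90, 0.88]
solution_b = [0.99, 0.99, 0.40, 0.99, 0.99]

for name, s in (("A", solution_a), ("B", solution_b)):
    print(name, "prod =", round(score_prod(s), 2), "min =", score_min(s))
# A prod = 0.64 min = 0.88
# B prod = 0.38 min = 0.4
```

Both rules rank A above B here. The difference is diagnostic: the minimum points straight at B's step 3, while the product drags A down for five merely decent steps.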

In practice, Lightman et al. found that minimum scoring with PRM outperformed product scoring by 2-3 percentage points on MATH. The gap is larger for problems with more steps, where the product score penalizes many slightly-uncertain steps even when none is actually wrong.

A PRM scores every reasoning step, not just the final answer. This lets you pinpoint exactly where reasoning goes wrong and regenerate only from the error. ORM is like grading an exam by the answer box. PRM is like grading the work shown — and it's dramatically more effective for guiding search.

PRM vs ORM: The Numbers

Lightman et al. compared PRM and ORM on MATH, using best-of-N selection with N=100. The PRM-guided selection solved 78.2% of problems. The ORM-guided selection solved 72.4%. For context, a single GPT-4 attempt (N=1) solved only 52.7%. The PRM added 25.5 percentage points by being smarter about which solution to pick, while the ORM added only 19.7.

The gap widens on harder problems. On Level 5 MATH (the hardest), PRM's advantage over ORM exceeds 10 percentage points. Hard problems have more steps, more opportunities for error, and more benefit from step-level feedback.

PRM-Guided Search: Beyond Best-of-N

The real power of PRMs goes beyond simply scoring complete solutions. Because PRMs evaluate partial solutions, they enable tree search: at each step, generate multiple candidate next-steps, score them with the PRM, and only expand the most promising branches.

python
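def prm_beam_search(problem, model, prm, is_complete,
                    beam_width=4, k=4, max_steps=10):
    """PRM-guided beam search over reasoning steps.

    Assumed interfaces: model.generate_step(text, n) returns n candidate
    next-step strings, prm.score_step(text) returns P(last step correct),
    and is_complete(text) says whether a final answer has been reached.
    """
    # Start with the problem as the root: (score, partial_solution)
    beams = [(1.0, problem)]

    for _ in range(max_steps):
        candidates = []
        for score, partial in beams:
            # Generate k candidate next steps from this partial solution
            for next_step in model.generate_step(partial, n=k):
                extended = partial + next_step
                # PRM scores the new step in the context of all earlier steps
                step_score = prm.score_step(extended)
                # Product aggregation; min(score, step_score) gives weakest-link scoring
                candidates.append((score * step_score, extended))

        # Keep the top beam_width candidates
        candidates.sort(key=lambda c: c[0], reverse=True)
        beams = candidates[:beam_width]

        # Return the best-ranked beam that has reached a final answer, if any
        for _, solution in beams:
            if is_complete(solution):
                return solution

    return beams[0][1]  # best partial solution if none completed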

This can be much more compute-efficient than best-of-N. Instead of generating N complete solutions and checking at the end, tree search prunes bad paths early. If step 2 has a score of 0.15, we don't waste compute on steps 3-5 of that branch. The PRM acts as a pruning oracle, focusing compute on promising reasoning paths.

The Snell et al. (2024) paper showed that, at the same total compute budget, PRM-guided beam search can achieve significantly higher accuracy than best-of-N, especially on harder problems and at smaller budgets. Tree search tends to gain more over flat sampling as problems get harder, because hard problems have more steps and more opportunities for early pruning.

A 5-step math solution gets the right answer, but step 3 contains a mathematical error that happens to cancel out. How would a PRM and ORM differ in scoring this solution?

Chapter 2: Best-of-N Sampling

The simplest way to spend more compute at inference: generate N candidate solutions and pick the best one. This is best-of-N sampling (also called rejection sampling or reranking). You don't need to modify the model. You don't need to change the training. You just sample N times and select using a reward model.

Think of it like a student taking a test N times and submitting their best attempt. Each attempt uses the same knowledge (same model weights), but randomness in the generation process (temperature sampling) produces different reasoning paths. Some paths happen to navigate around common errors; others stumble into them.

The Algorithm

1. Sample
Generate N complete solutions from the LLM, each with different random sampling (temperature > 0).
2. Score
Run each solution through the reward model (ORM or PRM) to get a quality score.
3. Select
Return the solution with the highest score. Discard the rest.
python
def best_of_n(prompt, model, reward_model, n=8, temp=0.7):
    # model.generate and reward_model.score are assumed interfaces
    # Step 1: Generate N candidate solutions (temp > 0 so they differ)
    candidates = [model.generate(prompt, temperature=temp)
                  for _ in range(n)]

    # Step 2: Score each with reward model
    scores = [reward_model.score(prompt, c) for c in candidates]

    # Step 3: Return the highest-scoring candidate (a solution string)
    best_idx = scores.index(max(scores))
    return candidates[best_idx]

Generating the candidates costs exactly N times a single generation, plus the cost of reward-model scoring. For N=8, you use 8x the generation compute. But the improvement can be enormous: on MATH, going from N=1 to N=100 with a PRM improved GPT-4's accuracy from 52.7% to 78.2%, a 25.5 percentage point gain for 100x the compute.

How Accuracy Scales with N

The accuracy curve follows a characteristic pattern. Initial doublings of N give large gains: N=1 to N=2 might add 5-8 points. But gains diminish: N=50 to N=100 might add only 1-2 points. This is because easy problems are solved within the first few attempts, and adding more attempts only helps with the remaining hard problems.

The scaling law is roughly logarithmic: accuracy grows as log(N). Doubling N gives a roughly constant accuracy increase. This means 10x more compute gives a fixed additive improvement, not a multiplicative one. Eventually the reward model becomes the bottleneck — even with infinite samples, you can't do better than the reward model's ability to identify correct solutions.

Accuracy(N) ≈ Accuracy(1) + c · log(N)

Where c depends on the reward model quality and the problem difficulty. For easy problems, c is small (the model gets it right on the first try, so more samples barely help). For hard problems, c is larger (each additional sample has a decent chance of finding a new correct path).

Compute Cost Analysis

Let's be precise about the costs. If one generation takes T FLOPs:

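Component | FLOPs | Notes
N generations | N × T | Parallelizable across GPUs
N reward scores (ORM) | N × T_rm | T_rm ≈ T if the reward model is about the same size as the generator
N reward scores (PRM) | N × T_rm | A causal PRM can read all S step scores of a solution from one forward pass; scoring each prefix separately would cost N × S × T_rm
Selection | O(N) | Trivial: argmax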

Total cost with PRM: approximately 2N×T (generation + scoring). The latency, however, can be much lower than 2N×T suggests, because all N generations run in parallel. Wall-clock time is dominated by the single longest generation, not the sum. This makes best-of-N surprisingly practical in batch serving scenarios.

Best-of-N Sampling

Each bar is a candidate solution with a reward score. The best is highlighted. Adjust N to see how more candidates improve the odds of finding a good solution.

Best-of-N is the simplest test-time compute strategy: sample N times, score each, return the best. It requires no model modification — any LLM + any reward model works. Accuracy scales as ~log(N), meaning each doubling of N gives roughly constant improvement. The reward model quality is the ceiling.

Temperature Matters

Best-of-N only works if the N samples are diverse. If temperature is 0 (greedy decoding), every sample is identical — N copies of the same answer. As temperature increases, samples become more diverse but individually lower quality. The sweet spot is typically temperature 0.5-0.8: diverse enough to explore different reasoning paths, coherent enough that good paths appear frequently.
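A small sketch of why temperature 0 collapses best-of-N. It stands in for whole reasoning paths with a single choice among four hypothetical options, given by made-up logits; temperature divides the logits before the softmax, and T = 0 is treated as greedy argmax:

```python
import math
import random

def softmax_with_temperature(logits, temperature):
    """Logits -> probabilities; temperature 0 is treated as greedy (all mass on the argmax)."""
    if temperature == 0:
        best = max(range(len(logits)), key=lambda i: logits[i])
        return [1.0 if i == best else 0.0 for i in range(len(logits))]
    scaled = [x / temperature for x in logits]
    top = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(s - top) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]

def distinct_samples(logits, temperature, n=100, seed=0):
    """Draw n samples and count how many different options appear."""
    rng = random.Random(seed)
    probs = softmax_with_temperature(logits, temperature)
    draws = rng.choices(range(len(logits)), weights=probs, k=n)
    return len(set(draws))

# Hypothetical logits for four candidate first reasoning steps
logits = [2.0, 1.5, 0.5, -1.0]
for t in (0, 0.7, 1.5):
    probs = [round(p, 3) for p in softmax_with_temperature(logits, t)]
    print(f"T={t}: probs={probs}, distinct options in 100 draws={distinct_samples(logits, t)}")
```

At T = 0 every draw is the same option, so 100 samples carry no more information than one; raising T spreads probability onto the weaker options.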

There's a deeper connection to explore-exploit tradeoffs here. Low temperature exploits the model's best guess. High temperature explores the space of possible solutions. Best-of-N with a reward model combines the exploration of high temperature with the selection pressure of the reward model, getting the best of both worlds.

Best-of-N with PRM vs ORM: Side by Side

To make the PRM advantage even more concrete, let's trace through an example with N=4 solutions to a math problem:

Solution | Steps correct? | Final answer | ORM score | PRM score (min)
A | All 5 correct | Correct | 0.92 | 0.89
B | Error at step 3 | Wrong | 0.15 | 0.12
C | Error at step 4 | Correct (lucky) | 0.88 | 0.35
D | All 5 correct | Correct | 0.90 | 0.91

ORM selects: Solution A (score 0.92). Good choice: it's correct. But C sits only 0.04 behind with a score of 0.88 (third-highest), despite a fundamental error that happened to produce the right answer. On another problem, small shifts in scores could put such a solution on top. Over many problems, ORM occasionally selects "correct by accident" solutions.

PRM selects: Solution D (min score 0.91). Also correct, and importantly, the PRM flags C as a poor solution (min score 0.35 from the step 4 error) despite its correct final answer. Because it catches step-level errors, the PRM is far less likely to select "correct by accident" solutions, though an imperfect PRM can still miss an error.
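The selection logic on the table's scores, as a sketch (`pick` and `margin_over` are hypothetical helpers). Both scorers choose a correct solution here; the difference is how close the flawed solution C comes to being chosen:

```python
# Scores from the table above: name -> (orm_score, prm_min_score)
solutions = {
    "A": (0.92, 0.89),
    "B": (0.15, 0.12),
    "C": (0.88, 0.35),  # correct final answer, but an error at step 4
    "D": (0.90, 0.91),
}

ORM, PRM = 0, 1  # column indices

def pick(scored, column):
    """Best-of-N selection: the solution with the highest score in `column`."""
    return max(scored, key=lambda name: scored[name][column])

def margin_over(scored, column, name):
    """How far `name` sits below the selected solution; small margins are fragile."""
    best = scored[pick(scored, column)][column]
    return round(best - scored[name][column], 2)

for label, col in (("ORM", ORM), ("PRM", PRM)):
    print(label, "picks", pick(solutions, col), "| C trails by", margin_over(solutions, col, "C"))
```

Under the ORM, C trails the winner by 0.04; under the PRM, by 0.56.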

This distinction becomes critical at scale. With N=100, the ORM might select one of several "correct by accident" solutions that happen to have high final-answer confidence. The PRM reliably selects solutions with valid reasoning chains. Over thousands of problems, this compounds into a significant accuracy advantage.

The Coverage Question

Best-of-N has an important limitation: coverage. If the model cannot produce the correct answer at all (not with any temperature, any random seed, any prompt), then N=infinity won't help. The coverage of a model on a problem is the probability that at least one of N samples is correct:

$$\text{Coverage}(N) = 1 - (1 - p)^N$$

Where p is the per-sample success probability. If p = 0.1 (the model has a 10% chance of getting it right on any single attempt), then $\text{Coverage}(100) = 1 - 0.9^{100} \approx 99.997\%$. But if p = 0 (the model fundamentally cannot solve this type of problem), then Coverage(N) = 0 for all N. No amount of sampling helps. This is why model capability (training) still matters: test-time compute amplifies existing capability but cannot create capability from nothing.
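A sketch of the coverage formula, plus the number of samples needed to reach a target coverage (from solving $1 - (1-p)^N \ge \text{target}$ for N). It assumes independent samples with a fixed per-sample success rate p. Coverage is only an upper bound on best-of-N accuracy, because the reward model still has to pick out the correct sample:

```python
import math

def coverage(p: float, n: int) -> float:
    """Probability that at least one of n independent samples is correct."""
    return 1.0 - (1.0 - p) ** n

def samples_needed(p: float, target: float) -> float:
    """Smallest n with coverage(p, n) >= target; infinite if p == 0."""
    if p <= 0.0:
        return math.inf
    if p >= 1.0:
        return 1
    return math.ceil(math.log(1.0 - target) / math.log(1.0 - p))

for p in (0.0, 0.01, 0.1):
    print(p, [round(coverage(p, n), 5) for n in (1, 10, 100)], samples_needed(p, 0.99))
```

With p = 0.1, 44 samples already reach 99% coverage; with p = 0, no finite N does.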

Best-of-N with temperature=0 and N=100 gives the same result as N=1. Why?

Chapter 3: Scaling Test-Time Compute

Best-of-N is just one strategy for spending compute at test time. Snell et al. (2024) asked the deeper question: given a fixed compute budget at inference, what is the optimal way to spend it? Their answer depends on problem difficulty, and it fundamentally changes how we think about model deployment.

Consider three strategies for spending N times more compute:

Strategy A: Best-of-N. Run the same model N times, pick the best result. Simple, no model modification needed. Works well when the model can solve the problem but doesn't always — more attempts increase the chance of hitting a correct path.

Strategy B: Revision (sequential refinement). Run the model once, then have it revise its own answer N-1 times. Each revision sees the previous attempt and tries to improve it. This uses the same total compute but concentrates it sequentially. Works well when the model can identify and fix its own errors.

Strategy C: Tree search with PRM. Use the PRM to evaluate partial solutions and guide a beam search through the reasoning tree. At each step, expand the most promising branches and prune the weakest. This uses compute most efficiently but requires a trained PRM.

Each strategy has fundamentally different computational properties:

Property | Best-of-N | Revision | Tree search
Parallelism | Fully parallel | Sequential | Partially parallel
Memory | N copies of context | 1 copy + history | Beam-width copies
Requires RM? | Yes (for selection) | No (self-refine) | Yes (for pruning)
Latency | ≈ 1 generation | ≈ N × generation | ≈ depth × beam time
When it shines | Medium difficulty | When model can self-correct | Hard, multi-step problems

The Difficulty-Dependent Optimal Strategy

The key finding from Snell et al.: the optimal strategy depends on problem difficulty. This is a deep result — it means there is no universally best inference strategy. The best approach changes based on what you're trying to solve.

Why does this happen? Because different strategies have different scaling properties:

Big model, single shot: Accuracy is determined entirely by model capability. On problems within the model's competence, accuracy is high. On problems beyond it, no single-shot strategy can help. The model either "knows" the answer or it doesn't.

Best-of-N: Accuracy scales as log(N). Good for problems where the model can sometimes produce the right answer but doesn't consistently. More attempts increase coverage. Breaks down when the model never produces the right answer (p = 0).

Tree search + PRM: Accuracy depends on the PRM quality and search depth. Excellent for multi-step problems where errors compound — the PRM catches errors early and redirects. But expensive: O(beam_width × depth × branching_factor) generations.

Difficulty | Best strategy | Why
Easy | Use a bigger model (fewer tokens) | A single attempt usually suffices, so repeated sampling or search adds little. Spend the budget on one pass of a stronger model.
Medium | Best-of-N or revision | The model can solve it but doesn't always. More attempts or self-correction help.
Hard | Tree search with PRM | The model needs guidance at every step. Blind sampling wastes compute on doomed paths.
Impossible | Give up (use compute elsewhere) | No amount of test-time compute helps if the model lacks the fundamental capability.
Optimal Test-Time Strategy

Drag the difficulty slider to see which strategy is optimal. On easy problems, a big model wins. On hard problems, search with verification dominates.

There is no single best inference strategy. The optimal use of test-time compute depends on problem difficulty. Easy problems: use a bigger model. Medium problems: sample and select. Hard problems: search with step-level verification. Impossible problems: don't waste the compute. An ideal system classifies problem difficulty and routes to the right strategy.

Compute-Equivalent Comparison

Snell et al. showed that a smaller model using compute-optimal test-time strategies can match a substantially larger model on many tasks. The smaller model uses its extra compute budget for search and verification rather than raw model capacity. On MATH, in a FLOPs-matched comparison, compute-optimal test-time scaling outperformed a 14x larger model on problems where the smaller model already had a non-trivial success rate.

But this doesn't mean small models always win. The key insight is about compute allocation: for a fixed total compute budget (training + inference), the optimal split between model size and test-time search depends on the task distribution. If most queries are easy, invest in a bigger model. If many queries are hard, invest in search infrastructure.

The Difficulty Router

The practical implication is that an ideal system needs a difficulty router: a lightweight classifier that estimates problem difficulty and routes to the appropriate strategy. This creates a two-stage inference pipeline:

Stage 1: Classify
Lightweight model estimates difficulty from the prompt. O(1) cost. Labels: easy, medium, hard.
Stage 2: Route
Easy → single-shot big model. Medium → best-of-N. Hard → tree search + PRM. Each strategy has different compute/quality profiles.

Training the difficulty router is itself an interesting problem. You can use a proxy: if the model's own confidence is high (low entropy in its output distribution), the problem is probably easy; low confidence (high entropy) suggests it is hard. More sophisticated approaches train a small classifier on (prompt, difficulty_label) pairs derived from historical success rates.
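A minimal sketch of the entropy proxy, assuming the router sees a probability distribution over candidate answers (for example, estimated from a few cheap samples). `route` and both entropy thresholds are hypothetical; a real router would tune them on historical success rates:

```python
import math

def entropy(probs):
    """Shannon entropy in nats; low entropy means the model is confident."""
    return -sum(p * math.log(p) for p in probs if p > 0)

def route(answer_probs, easy_max=0.5, medium_max=1.2):
    """Map confidence to a strategy. Both thresholds are hypothetical placeholders."""
    h = entropy(answer_probs)
    if h <= easy_max:
        return "single-shot"
    if h <= medium_max:
        return "best-of-N"
    return "tree search + PRM"

print(route([0.95, 0.05]))     # confident: one answer dominates
print(route([0.5, 0.3, 0.2]))  # split between a few answers
print(route([0.25] * 4))       # no clear favorite
```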

Real-World Deployment Numbers

To make this concrete, consider a math tutoring service handling 10,000 queries per hour:

Difficulty (share) | Queries/hour | Strategy | Cost/query | Accuracy
Easy (60%) | 6,000 | 1-shot, 70B | $0.002 | 95%
Medium (30%) | 3,000 | Best-of-8, 70B | $0.016 | 88%
Hard (10%) | 1,000 | Tree search + PRM | $0.05 | 75%

Average cost: $0.011/query. Average accuracy: 90.9%. Compare to uniform best-of-8 everywhere: $0.016/query (about 45% more expensive) with only marginally better accuracy on easy problems (which were already 95%). The router saves money on easy problems and concentrates compute where it matters.
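The averages follow from weighting each row of the table by its share of traffic; a quick check in Python:

```python
# (share of queries, cost per query in $, accuracy) taken from the table above
mix = {
    "easy":   (0.60, 0.002, 0.95),
    "medium": (0.30, 0.016, 0.88),
    "hard":   (0.10, 0.050, 0.75),
}

avg_cost = sum(share * cost for share, cost, _ in mix.values())
avg_acc = sum(share * acc for share, _, acc in mix.values())
uniform_cost = 0.016  # best-of-8 on every query, as in the comparison

print(f"routed: ${avg_cost:.4f}/query at {avg_acc:.1%} accuracy")
print(f"uniform best-of-8 costs {uniform_cost / avg_cost - 1:.0%} more per query")
```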

The router also improves latency. Easy queries get immediate single-shot answers (~200ms). Without the router, every query would wait for 8 generations and PRM scoring (~1600ms if they run one after another). For a chat application where 60% of messages are simple, this means 60% of responses are 8x faster. Users notice the difference.

Connection to Scaling Laws

Kaplan et al. (2020) showed that loss decreases as a power law of training compute: $L(C) \approx (C_c / C)^{\alpha_C}$, where $C_c$ and $\alpha_C$ are fitted constants. Snell et al. showed that test-time compute also has predictable returns: accuracy improves as inference FLOPs grow, with the rate of improvement depending on problem difficulty and strategy choice.

This means we now have two independent scaling axes. The total capability of a deployed LLM system is a function of both training compute and test-time compute. The frontier models of 2025 push both axes simultaneously: massive pretraining (trillions of tokens) combined with adaptive inference (hundreds to thousands of thinking tokens per query).

The economic implication is that inference compute costs are becoming a significant fraction of total AI compute costs — sometimes exceeding training costs for heavily-queried models. This is driving investment in inference-specific hardware (Groq, Cerebras), specialized architectures (mixture of experts for cheaper inference), and efficiency techniques (speculative decoding, quantization).

For a math competition with extremely difficult problems, which test-time compute strategy is optimal?

Chapter 4: Speculative Decoding

We've been talking about spending more compute for better answers. Now let's flip the script: how do we generate the same answer with less time? This is the problem of inference efficiency, and speculative decoding is one of the most elegant solutions.

The bottleneck of LLM inference is that it's sequential. Each token depends on all previous tokens, so you generate one token at a time. A 70B model generating 500 tokens makes 500 sequential forward passes, each taking ~30ms. That's 15 seconds — and the GPU is underutilized because each forward pass processes a single token through a massive network.

Why is the GPU underutilized? Because during autoregressive decoding, the model processes one token through the entire network. The weights total 70B parameters, but the input is a single vector. The ratio of arithmetic to memory traffic is terrible: we're loading billions of parameters from memory just to multiply each of them by one vector. This is called being memory-bandwidth bound rather than compute-bound.

During prefill (processing the prompt), the model processes all tokens in parallel — fully utilizing the GPU's compute. But during generation, it's one token at a time, and the GPU sits mostly idle, waiting for memory. Any technique that processes multiple tokens in a single pass moves us closer to the efficient parallel regime.
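A back-of-the-envelope roofline model makes this concrete. It assumes fp16 weights (2 bytes per parameter), about 2 FLOPs per parameter per token, and hypothetical hardware with 2 TB/s of memory bandwidth and 300 TFLOP/s of compute. It ignores attention and KV-cache traffic. The absolute times differ from the ~30ms used in this chapter (which implies faster or multi-GPU hardware), but the shape is the point: pass time stays flat until the token count is large enough to make compute the bottleneck.

```python
def forward_pass_ms(n_params, tokens, bytes_per_param=2, mem_bw=2e12, peak_flops=300e12):
    """Roofline lower bound for one forward pass over `tokens` tokens.

    The weights are read from memory once per pass, while compute grows with
    the number of tokens; the pass takes at least the slower of the two.
    Hardware numbers are illustrative assumptions, not a specific GPU.
    """
    memory_s = n_params * bytes_per_param / mem_bw   # time to stream the weights
    compute_s = 2 * n_params * tokens / peak_flops   # ~2 FLOPs per param per token
    return 1000 * max(memory_s, compute_s)

def crossover_tokens(bytes_per_param=2, mem_bw=2e12, peak_flops=300e12):
    """Tokens per pass at which compute time catches up with weight-loading time."""
    return bytes_per_param * peak_flops / (2 * mem_bw)

for tokens in (1, 5, 50, 500):
    print(tokens, "tokens:", round(forward_pass_ms(70e9, tokens), 1), "ms")
print("compute-bound above about", crossover_tokens(), "tokens per pass")
```

Under these assumptions, 1 token and 5 tokens cost the same pass, which is exactly the slack speculative decoding exploits.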

Speculative decoding (Leviathan et al., 2022) exploits an asymmetry: verifying a sequence is faster than generating it. If someone hands you a draft essay, you can read it and check it much faster than writing it from scratch. The same is true for LLMs: verifying K tokens in parallel (one forward pass) is much faster than generating K tokens sequentially (K forward passes).

The Two-Model Setup

Speculative decoding uses two models:

Draft model (small, fast). A lightweight model (e.g., 1B parameters) that generates K "draft" tokens quickly. These tokens are approximate — they won't exactly match what the target model would produce, but they're often close enough.

Target model (large, slow). The full-size model (e.g., 70B parameters) that we actually want to sample from. It verifies the draft tokens in a single forward pass by computing the probability of each draft token given all previous context.

The Algorithm

1. Draft
Small model generates K tokens autoregressively: $t_1, t_2, \ldots, t_K$. Fast: K × 2ms = 10ms for K=5.
2. Verify
Large model processes ALL K tokens in parallel in ONE forward pass, getting $P_{\text{target}}(t_i \mid t_{<i})$ for each position. Takes ~30ms, roughly independent of K when K is small.
3. Accept/Reject
For each token, compare the draft probability q(t) to the target probability p(t) and accept with probability min(1, p(t)/q(t)). The first rejection ends the cycle: resample that token from the adjusted (residual) distribution. If all K are accepted, sample one bonus token from the target.
↻ repeat from step 1

The acceptance criterion ensures the output distribution is identical to standard sampling from the target model. This is not an approximation — the math guarantees exact distributional equivalence. You get the same quality, just faster.

To understand the acceptance criterion intuitively: if the draft model assigns probability q(t) = 0.3 to token "the" and the target assigns p(t) = 0.6, then min(1, p/q) = min(1, 2.0) = 1. We always accept — the target agrees even more strongly. But if q(t) = 0.6 and p(t) = 0.3, then min(1, p/q) = min(1, 0.5) = 0.5. We accept with 50% probability. The draft model was overconfident about this token, and we reject it half the time to match the target's softer preference.

The beauty is that rejected tokens aren't wasted. When we reject at position i, we know exactly what the target distribution looks like (we computed it during verification), so we can immediately sample the correct token from the residual distribution without another forward pass. The entire draft-verify-accept/reject cycle produces tokens from the exact same distribution as if we had only used the target model.
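A minimal sketch of one draft-verify-accept/reject cycle on toy "models", assuming each model is just a function from a context tuple to a {token: probability} dict. `draft`, `target`, and `speculative_step` are hypothetical stand-ins; a real system gets all the target distributions from one batched forward pass, whereas this sketch calls the target once per position for clarity.

```python
import random
from collections import Counter

VOCAB = ("a", "b", "c")

def target(ctx):
    """Toy 'large model': likes repeating the previous token."""
    if not ctx:
        return {"a": 0.5, "b": 0.3, "c": 0.2}
    return {t: (0.6 if t == ctx[-1] else 0.2) for t in VOCAB}

def draft(ctx):
    """Toy 'small model': a cruder guess that ignores context."""
    return {"a": 0.4, "b": 0.4, "c": 0.2}

def sample(dist, rng):
    """Draw one token from a {token: prob} dict."""
    tokens = list(dist)
    return rng.choices(tokens, weights=[dist[t] for t in tokens], k=1)[0]

def residual(p, q):
    """max(0, p - q), renormalized: the mass the draft under-proposed."""
    diff = {t: max(0.0, p[t] - q.get(t, 0.0)) for t in p}
    z = sum(diff.values())
    return {t: v / z for t, v in diff.items()}

def speculative_step(context, draft_model, target_model, k, rng):
    """Run one cycle and return the emitted tokens (between 1 and k + 1)."""
    context = tuple(context)
    # 1. Draft: k tokens autoregressively from the small model
    drafted, q_dists = [], []
    for _ in range(k):
        q = draft_model(context + tuple(drafted))
        q_dists.append(q)
        drafted.append(sample(q, rng))
    # 2. Verify: target distribution at each of the k + 1 positions
    p_dists = [target_model(context + tuple(drafted[:i])) for i in range(k + 1)]
    # 3. Accept/reject left to right; the first rejection ends the cycle
    out = []
    for i, t in enumerate(drafted):
        p, q = p_dists[i], q_dists[i]
        if rng.random() < min(1.0, p.get(t, 0.0) / q[t]):
            out.append(t)
        else:
            out.append(sample(residual(p, q), rng))
            return out
    out.append(sample(p_dists[k], rng))  # all k accepted: bonus token from the target
    return out

rng = random.Random(0)
trials = 20000
first = Counter(speculative_step((), draft, target, k=4, rng=rng)[0] for _ in range(trials))
print({t: round(first[t] / trials, 3) for t in VOCAB})  # close to target(()) = a .5, b .3, c .2
```

The draft alone would emit "b" 40% of the time; after accept/reject and residual resampling, the first emitted token follows the target's 30%.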

Why It's Faster

Without speculative decoding: each token costs one forward pass of the target model (~30ms). For 500 tokens: 500 × 30ms = 15 seconds.

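With speculative decoding (K=5, per-token acceptance rate α = 0.7): each cycle drafts 5 tokens (~10ms) and verifies them in one target pass (~30ms), so a cycle takes ~40ms. Because the cycle stops at the first rejection, it does not yield 5 × 0.7 = 3.5 accepted tokens. If each draft token is accepted independently with probability 0.7, the expected number of accepted drafts is 0.7 + 0.7² + ... + 0.7⁵ ≈ 1.94, plus the one token that is always resampled or added as a bonus, for ≈ 2.94 tokens per cycle. For 500 tokens: 500/2.94 × 40ms ≈ 6.8 seconds. That's a 2.2x speedup with identical output quality.

Let's break down the math more carefully. In each cycle:

Step | Operation | Time | GPU utilization
Draft 5 tokens | 5 sequential 1B forward passes | 5 × 2ms = 10ms | Low (small model)
Verify 5 tokens | 1 parallel 70B forward pass | 30ms | Better (batch of 5)
Accept/reject | Compare probabilities | <0.1ms | N/A
Total cycle | | ~40ms |
Expected tokens | ~1.94 accepted drafts + 1 resampled or bonus token ≈ 2.94 | |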

The key insight: the verification step processes 5 tokens in roughly the same time as processing 1 token. This is because the forward pass is memory-bandwidth bound: loading 70B parameters from GPU memory takes about the same time whether you multiply them by 1 vector or 5. The marginal cost of verifying extra draft tokens is nearly zero.

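$$\text{Speedup} \approx \frac{\mathbb{E}[\text{accepted}] + 1}{K \, t_{\text{draft}} + t_{\text{target}}} \cdot t_{\text{target}}$$

If each draft token is accepted independently with probability α (the simplifying assumption in Leviathan et al.'s analysis), the numerator has the closed form $\mathbb{E}[\text{accepted}] + 1 = (1 - \alpha^{K+1}) / (1 - \alpha)$.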
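Plugging numbers into the speedup formula is easy to script. This sketch assumes each draft token is accepted independently with probability `alpha`, in which case the expected tokens per cycle (accepted drafts plus the one resampled or bonus token) is $(1 - \alpha^{K+1})/(1 - \alpha)$; the function names are illustrative:

```python
def expected_tokens_per_cycle(alpha: float, k: int) -> float:
    """Accepted drafts plus the one resampled/bonus token, assuming each draft
    token is accepted independently with probability alpha."""
    if alpha >= 1.0:
        return k + 1.0
    return (1 - alpha ** (k + 1)) / (1 - alpha)

def speedup(alpha: float, k: int, t_draft_ms: float, t_target_ms: float) -> float:
    """Wall-clock speedup over plain decoding, which costs t_target_ms per token."""
    cycle_ms = k * t_draft_ms + t_target_ms
    return expected_tokens_per_cycle(alpha, k) * t_target_ms / cycle_ms

print(round(expected_tokens_per_cycle(0.7, 5), 2))               # 2.94
print(round(speedup(0.7, 5, t_draft_ms=2, t_target_ms=30), 2))  # 2.21

# Longer drafts stop paying off once extra draft time outweighs extra accepted tokens
for k in range(1, 9):
    print(k, round(speedup(0.7, k, 2, 30), 2))
```

With these numbers (α = 0.7, 2ms drafts, 30ms target passes), the speedup peaks at K = 5 and declines slowly after that.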

The speedup depends on the acceptance rate, which depends on how well the draft model approximates the target. A well-chosen draft model (trained on similar data, same tokenizer) gets 60-80% acceptance rates on typical text.

Speculative Decoding

Watch the draft model (small, fast) propose tokens, then the target model (large, slow) verify them in parallel. Green = accepted, red = rejected. Compare the speed counter to standard autoregressive decoding.

Speculative decoding is "write fast, verify fast." A small draft model proposes K tokens cheaply. The large target model verifies all K in a single parallel pass. The output follows exactly the distribution the target would produce on its own (with greedy decoding, exactly the same tokens). Typical speedup: 2-3x with no quality loss.

Choosing the Draft Model

The draft model should be:

| Property | Why |
|---|---|
| Same tokenizer as target | Tokens must correspond 1:1 for verification to work |
| Much smaller (5-20x fewer params) | Must be fast enough that drafting K tokens is cheaper than one target pass |
| Trained on similar data distribution | Higher acceptance rate = more speedup. Mismatched distribution = most drafts rejected |
| Not too small | Very small models have low acceptance rates, wasting the verification pass |

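In practice, the draft model is usually a smaller sibling from the same model family, and some companies train dedicated draft models alongside their flagship models. For example, Llama 3.1 8B and Llama 3.2 1B share Llama 3.1 70B's tokenizer and are common draft models for it. For closed models such as GPT-4, the draft setup has not been published.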

The Acceptance Criterion in Detail

The math behind the accept/reject step is elegant. Let q(x) be the draft model's probability for token x, and p(x) be the target model's probability. For each draft token t:

Accept t with probability min(1, p(t) / q(t))

If the draft model assigns lower probability than the target (p(t) > q(t)), we always accept — the draft was conservative, and the target agrees even more. If the draft assigns higher probability (q(t) > p(t)), we accept proportionally — the draft was overconfident, so we reject some fraction of the time.

When we reject at position i, we resample from a residual distribution:

$$p'(x) = \frac{\max(0,\, p(x) - q(x))}{Z}, \qquad Z = \sum_{x} \max(0,\, p(x) - q(x))$$

This residual distribution fills in exactly what the draft model missed. The combined process — accept with probability min(1, p/q), resample from residual on rejection — produces samples from p(x) exactly. This is a variant of rejection sampling, and the proof relies on the fact that the acceptance probability times the draft distribution plus the residual correction equals the target distribution.
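The accept/reject rule and the residual distribution fit in a few lines of NumPy. This sketch handles a single draft position over a toy vocabulary, and its function names are hypothetical. `p` and `q` are the target and draft distributions at that position. The second function computes the exact output distribution of the procedure: min(p, q) from acceptance, plus P(reject) times the residual from rejection. That is how the "samples from p exactly" claim can be checked numerically.

```python
import numpy as np


def accept_or_resample(p: np.ndarray, q: np.ndarray, draft_token: int,
                       rng: np.random.Generator) -> tuple[int, bool]:
    """One speculative-decoding step at a single position.

    p: target distribution, q: draft distribution (both sum to 1).
    draft_token was sampled from q. Returns (token, accepted).
    """
    # Accept with probability min(1, p(t) / q(t)); q(t) > 0 because t ~ q.
    if rng.random() < min(1.0, p[draft_token] / q[draft_token]):
        return draft_token, True
    # Rejected: resample from the residual max(0, p - q) / Z.
    residual = np.maximum(p - q, 0.0)
    residual /= residual.sum()
    return int(rng.choice(len(p), p=residual)), False


def output_distribution(p: np.ndarray, q: np.ndarray) -> np.ndarray:
    """Exact distribution of the token emitted by accept_or_resample."""
    accepted_mass = np.minimum(p, q)           # q(x) * min(1, p(x) / q(x))
    reject_prob = 1.0 - accepted_mass.sum()    # mathematically equal to Z
    residual = np.maximum(p - q, 0.0)
    z = residual.sum()
    if z == 0.0:                               # p == q: every draft is accepted
        return accepted_mass
    return accepted_mass + reject_prob * residual / z


p = np.array([0.5, 0.3, 0.2])   # toy target distribution
q = np.array([0.2, 0.5, 0.3])   # toy draft distribution
print(output_distribution(p, q))  # matches p up to floating-point error

# Monte Carlo check of the sampler itself.
rng = np.random.default_rng(0)
counts = np.zeros(3)
for _ in range(50_000):
    t = int(rng.choice(3, p=q))                # draft proposes t ~ q
    token, _ = accept_or_resample(p, q, t, rng)
    counts[token] += 1
print(counts / counts.sum())
```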

Variants and Extensions

Several extensions to speculative decoding have emerged:

Medusa (Cai et al., 2024): Instead of using a separate draft model, adds extra prediction heads to the target model itself. Alongside the usual next-token head, each extra head predicts a token further ahead (2, 3, 4, ... positions). Cheaper than maintaining two models, but slightly lower acceptance rates.

Lookahead decoding: Uses the target model's own n-gram patterns as the "draft." No separate model needed at all, but limited speedup (1.5-2x vs 2-3x). Works by maintaining a pool of candidate continuations from previous generations and verifying them in parallel.

Staged speculative decoding: Uses a cascade of draft models (tiny → small → medium) before the final verification. Each stage filters out easy-to-reject tokens, reducing load on larger models. For example (illustrative numbers): a 0.5B model drafts 10 tokens, a 3B model verifies and keeps 7, then the 70B model verifies those 7 and accepts 5. Because each stage only needs to approximate the stage directly above it, the cascade keeps acceptance rates high at every level.

Self-speculative decoding: Uses early layers of the target model itself as the draft model (by skipping the last N layers). No separate model needed, and the draft model perfectly shares the same tokenizer and training distribution. The tradeoff is that early exits from a 70B model are faster than a full forward pass but slower than a dedicated 1B model.

In production, speculative decoding is most beneficial for interactive applications where latency matters (chatbots, code completion). For batch processing where throughput matters more than latency, simply running more queries in parallel on the target model is often more efficient than speculative decoding.

Speculative decoding produces outputs with slightly lower quality than standard decoding, trading quality for speed. True or false?

Chapter 5: RoPE

We switch gears now from inference strategies to a foundational component of modern Transformers: how does the model know the order of tokens? Without positional information, "The dog bit the man" and "The man bit the dog" produce identical attention patterns. Position encoding is what makes word order matter.

Earlier Transformers used absolute positional encodings: each position gets a fixed vector added to the token embedding. Position 0 gets vector p0, position 1 gets p1, and so on. These can be learned (BERT, GPT-2) or sinusoidal (original Transformer). The problem: they don't generalize to positions the model hasn't seen during training. If you train on sequences of length 2048, position 2049 has no learned embedding.

Why does position matter at all? Attention computes $q_i^\top k_j$ for every pair of tokens. Without position information, attention is order-blind: permuting the input tokens simply permutes the outputs, so it treats "Alice loves Bob" the same way as "Bob loves Alice." The queries and keys are computed from token embeddings alone, which carry semantic but not positional information. Some form of position encoding is essential for language understanding.
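A quick NumPy check makes the order-blindness concrete. This sketch uses single-head scaled dot-product self-attention with no position encoding and no causal mask. `w_q`, `w_k`, and `w_v` are arbitrary random stand-ins, not trained weights. Shuffling the input tokens only shuffles the output rows; strictly, attention is permutation-equivariant.

```python
import numpy as np


def self_attention(x: np.ndarray, w_q: np.ndarray, w_k: np.ndarray,
                   w_v: np.ndarray) -> np.ndarray:
    """Single-head scaled dot-product self-attention, no mask, no positions."""
    q, k, v = x @ w_q, x @ w_k, x @ w_v
    scores = q @ k.T / np.sqrt(k.shape[-1])          # (seq, seq)
    scores -= scores.max(axis=-1, keepdims=True)     # numerical stability
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)   # row-wise softmax
    return weights @ v


rng = np.random.default_rng(0)
seq_len, d = 3, 8
x = rng.normal(size=(seq_len, d))       # stand-in embeddings for "Alice loves Bob"
w_q, w_k, w_v = (rng.normal(size=(d, d)) for _ in range(3))

perm = np.array([2, 1, 0])              # reorder to "Bob loves Alice"
out = self_attention(x, w_q, w_k, w_v)
out_perm = self_attention(x[perm], w_q, w_k, w_v)

# Each token gets the same output vector regardless of word order.
print(np.allclose(out_perm, out[perm]))  # True
```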

Rotary Position Embedding (RoPE), introduced by Su et al. (2021), solves this with a beautiful geometric idea: encode position by rotating the embedding vector.

The Core Idea: Rotation

Instead of adding a position vector, RoPE multiplies (rotates) the query and key vectors at each position. The angle of rotation depends on the position. Here's the key property: when you compute the dot product between a query at position m and a key at position n, the result depends only on the relative offset n - m, not on the absolute positions.

Think of it like clock hands. If one hand points at 3 o'clock and another at 5 o'clock, the angle between them is 2 hours — regardless of what time it actually is. RoPE makes attention work the same way: the "angle" between query and key depends on how far apart they are, not where they are in absolute terms.

This is a profound shift from absolute to relative position encoding. With absolute encoding, the model must memorize "position 42 means X." With RoPE, the model learns "distance 5 means nearby" and "distance 500 means far away" — patterns that transfer across any starting position. This is why RoPE generalizes better to unseen positions, and why it's the foundation for context extension.

The Math (2D case)

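For a 2D vector $x = [x_1, x_2]^\top$, rotation by angle θ is:

$$R(\theta)\begin{bmatrix} x_1 \\ x_2 \end{bmatrix} = \begin{bmatrix} \cos\theta & -\sin\theta \\ \sin\theta & \cos\theta \end{bmatrix}\begin{bmatrix} x_1 \\ x_2 \end{bmatrix} = \begin{bmatrix} x_1\cos\theta - x_2\sin\theta \\ x_1\sin\theta + x_2\cos\theta \end{bmatrix}$$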

RoPE rotates the query at position m by angle mθ and the key at position n by angle nθ. Their dot product becomes:

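$$q_m^\top k_n = (R(m\theta)\,q)^\top (R(n\theta)\,k) = q^\top R(m\theta)^\top R(n\theta)\,k = q^\top R((n-m)\theta)\,k$$

The last step uses $R(m\theta)^\top = R(-m\theta)$ and the fact that rotations compose by adding angles.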

The rotation R((n-m)θ) depends only on the relative position (n-m). This is the magic: relative position information emerges naturally from rotation, without explicit relative position biases.
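The identity can be checked numerically in 2D. The sketch below uses plain NumPy, an arbitrary θ = 0.3, and fixed example vectors q and k chosen only for illustration. It rotates q by mθ and k by nθ, then shows that shifting both positions by the same amount leaves the dot product unchanged, while swapping m and n does not.

```python
import numpy as np


def rotation(angle: float) -> np.ndarray:
    """2x2 rotation matrix R(angle)."""
    c, s = np.cos(angle), np.sin(angle)
    return np.array([[c, -s], [s, c]])


def rope_score(q: np.ndarray, k: np.ndarray, m: int, n: int, theta: float) -> float:
    """Dot product of q rotated to position m and k rotated to position n."""
    return float((rotation(m * theta) @ q) @ (rotation(n * theta) @ k))


q, k = np.array([1.0, 2.0]), np.array([0.5, -1.0])
theta = 0.3

a = rope_score(q, k, m=2, n=7, theta=theta)      # offset n - m = 5
b = rope_score(q, k, m=102, n=107, theta=theta)  # same offset, shifted by 100
c = float(q @ rotation((7 - 2) * theta) @ k)     # q^T R((n - m) theta) k
print(np.isclose(a, b), np.isclose(a, c))        # True True

# The sign of the offset matters: R(-5 theta) is the transpose of R(5 theta).
print(np.isclose(a, rope_score(q, k, m=7, n=2, theta=theta)))  # False
```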

For the full embedding dimension d (typically 128 per head), RoPE applies independent rotations to consecutive pairs of dimensions (1-2, 3-4, 5-6, ...), each with a different base frequency. There are d/2 = 64 independent rotations happening simultaneously, each encoding position at a different scale. The frequencies follow a geometric series:

$$\theta_i = \text{base}^{-2i/d}, \quad i = 0, 1, \ldots, d/2 - 1 \quad (\text{base} = 10000)$$

Low-frequency dimensions capture long-range position differences (like the hour hand), while high-frequency dimensions capture local position differences (like the second hand). This multi-scale encoding is analogous to the original sinusoidal positional encoding but integrated into attention rather than added to embeddings.

To make this concrete: dimension pair 0, with $\theta_0 = 1$, rotates by 1 radian per position, so after 6 positions it has completed nearly a full rotation (2π ≈ 6.28 radians). It can therefore distinguish "2 positions apart" from "4 positions apart" with high sensitivity. Dimension pair 63 (for d = 128), with $\theta_{63} = 10000^{-126/128} \approx 1.2 \times 10^{-4}$, barely rotates at all in 6 positions. It needs thousands of positions before the rotation becomes significant. This pair distinguishes "position 1000" from "position 5000" but can't tell "position 1" from "position 6."
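The per-pair frequencies and rotation periods follow directly from $\theta_i = \text{base}^{-2i/d}$. A small sketch, assuming d = 128 and base = 10000 as in the text (the helper names are illustrative):

```python
import math


def rope_frequencies(d: int = 128, base: float = 10000.0) -> list[float]:
    """theta_i = base^(-2i/d) for each of the d/2 dimension pairs."""
    return [base ** (-2 * i / d) for i in range(d // 2)]


def rotation_period(theta: float) -> float:
    """Number of positions needed for one full 2*pi rotation."""
    return 2 * math.pi / theta


thetas = rope_frequencies()
for i in (0, 31, 63):
    print(f"pair {i:2d}: theta = {thetas[i]:.3e} rad/position, "
          f"period = {rotation_period(thetas[i]):,.0f} positions")
```

Pair 63's period comes out at roughly 54,000 positions, so within a 4K training window it sweeps only a small fraction of one turn.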

The full RoPE encoding for a single token is a set of d/2 rotation angles, one per dimension pair, each capturing position information at a different scale. The attention mechanism's dot product naturally combines all these scales, giving the model both fine-grained local and coarse long-range position information simultaneously.

RoPE: Rotation in 2D

Drag the position slider to rotate the query vector. The key stays fixed. The dot product (similarity) depends on the angle between them — which depends only on relative position.

RoPE encodes position by rotating query/key vectors. The dot product between rotated vectors depends only on relative position, not absolute. This is why RoPE enables length generalization: the rotation operation itself works at any position. The challenge is that the model's attention patterns may not generalize to rotation angles it hasn't seen during training — which is why context length extension is needed.

Implementation in PyTorch

python
import torch

def apply_rope(x, positions, base=10000):
    # x: (batch, seq_len, n_heads, head_dim); positions: (seq_len,)
    d = x.shape[-1]
    # Frequencies: geometric series from 1 down toward 1/base
    freqs = 1.0 / (base ** (torch.arange(0, d, 2).float() / d))
    # Angles: position * frequency, shape: (seq_len, d/2)
    angles = positions.float().unsqueeze(-1) * freqs.unsqueeze(0)
    # Add a head axis so the angles broadcast over n_heads:
    # (seq_len, 1, d/2) against x1/x2 of shape (batch, seq_len, n_heads, d/2)
    cos_a = angles.cos().unsqueeze(1)
    sin_a = angles.sin().unsqueeze(1)

    # Split into pairs and rotate
    x1 = x[..., ::2]     # even dims
    x2 = x[..., 1::2]    # odd dims
    # 2D rotation: [x1*cos - x2*sin, x1*sin + x2*cos]
    out1 = x1 * cos_a - x2 * sin_a
    out2 = x1 * sin_a + x2 * cos_a
    # Interleave back: shape unchanged
    return torch.stack([out1, out2], dim=-1).flatten(-2)

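Notice: no learned parameters. RoPE is entirely deterministic: the rotation angles are computed from position indices and fixed frequencies. This means the rotation can be computed at any sequence length without modifying the model weights. As noted above, though, the model's learned attention patterns may not generalize to angles it never saw in training.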

Why RoPE Won

RoPE is used by nearly every modern LLM: Llama, Mistral, Gemma, Qwen, DeepSeek, Phi. Why did it beat the alternatives?

| Method | Type | Relative? | Length generalization | Used by |
|---|---|---|---|---|
| Sinusoidal (Vaswani 2017) | Additive, fixed | No | Theoretical only | Original Transformer |
| Learned absolute (GPT-2) | Additive, learned | No | None | GPT-2, BERT |
| Relative bias (T5) | Bias on attention | Yes | Moderate | T5, FLAN |
| ALiBi (Press 2022) | Linear bias | Yes | Good | BLOOM, MPT |
| RoPE | Rotation on Q/K | Yes | Good + extensible | Llama, Mistral, etc. |

RoPE's advantages: (1) naturally relative without explicit bias terms, (2) no extra parameters, (3) compatible with context extension techniques (PI, NTK, YaRN), (4) works in both causal and bidirectional attention. The rotation formulation is also GPU-friendly — it's just element-wise multiplication and addition, no attention mask modifications needed.

Multi-Scale Intuition

The geometric frequency series $\theta_i = 10000^{-2i/d}$ creates a multi-scale representation. Dimension pair 0 has the highest frequency (one full rotation every 2π ≈ 6 positions). Dimension pair d/2-1 has the lowest frequency. For d = 128 it completes one rotation roughly every 54,000 positions, and the period approaches 2π × 10000 ≈ 63,000 as d grows. This means:

High-frequency pairs distinguish adjacent tokens: position 5 vs 6. They're like the second hand on a clock — changing rapidly, sensitive to small differences. These matter for local syntax: word order within a phrase.

Low-frequency pairs distinguish distant tokens: position 100 vs 200. They're like the hour hand — changing slowly, encoding coarse position. These matter for document structure: which paragraph we're in.

This multi-scale design is why context extension methods treat different frequency bands differently (YaRN). Extending context length primarily affects low-frequency dimensions that encode long-range position — local structure stays intact.

Why does RoPE naturally encode relative position rather than absolute position?

Chapter 6: Context Length Extension

RoPE gives us position encoding that theoretically works at any length. But in practice, a model trained on 4K context length fails catastrophically at 8K. Why? Because even though the rotation operation is defined at position 8000, the model has never seen attention patterns at that distance. The attention weights learned during training don't generalize to unseen rotation angles.

This is like learning to read a clock up to 12 hours, then suddenly being asked to read a 24-hour clock. The math of rotation is the same, but you've never practiced interpreting angles in that range. The attention heads have learned specific patterns for rotation angles they've seen (0 to 4096θ), and angles outside this range produce unfamiliar patterns that the model interprets as noise.

More precisely, the problem is in the low-frequency dimensions. High-frequency dimensions (which encode local position) cycle through their full range within the training length. But low-frequency dimensions (which encode long-range position) have only completed a fraction of a full rotation during training. At 2x the training length, they encounter angles they've literally never seen.

The community developed several techniques to extend trained models to longer contexts without full retraining. They all modify RoPE's frequencies to compress or interpolate positions into the range the model already understands.

Method 1: Position Interpolation (PI)

The simplest approach: if the model was trained on positions 0-4096, and we want to handle positions 0-8192, just scale all positions by 0.5. Position 8192 becomes 4096 — within the training range. Chen et al. (2023) called this Position Interpolation.

$$\text{position}'_{\text{PI}} = \text{position} \times \frac{L_{\text{train}}}{L_{\text{target}}}$$

The problem: it compresses nearby positions closer together, reducing the model's ability to distinguish adjacent tokens. If positions 0 and 1 become 0 and 0.5, the rotation angle difference halves, and fine-grained local attention suffers. The model trained to distinguish "I ate cake" from "I ate the cake" now struggles because the positional signal between "ate" and "cake" has been cut in half.

In experiments, PI works well for moderate extensions (2-4x) but degrades noticeably at 8x+. The local resolution loss becomes too severe — the model confuses nearby tokens and produces garbled text at short range even while maintaining coherent long-range structure.

Method 2: NTK-Aware Scaling

NTK-aware scaling takes a smarter route. It was proposed in a 2023 open-source community post, and Code Llama's later increase of the RoPE base follows a related idea. Instead of scaling all positions uniformly, it increases the base of the frequency computation. This stretches the lower frequencies (long-range) while leaving higher frequencies (local) mostly unchanged.

$$\theta_i' = (\text{base} \cdot \alpha)^{-2i/d}, \qquad \alpha = \left(\frac{L_{\text{target}}}{L_{\text{train}}}\right)^{d/(d-2)}$$

The intuition: local position information (high-frequency dimensions) is more important for language understanding than long-range position information (low-frequency dimensions). NTK-aware scaling preserves local resolution while extending the range.

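The name "NTK-aware" comes from Neural Tangent Kernel theory. That theory suggests networks struggle to learn high-frequency features from low-dimensional inputs, so compressing the high-frequency RoPE dimensions (as PI does) is harmful. The exponent d/(d-2) is chosen so that the highest-frequency pair (i = 0) is left exactly unchanged. The lowest-frequency pair (i = d/2-1) is slowed by exactly the extension ratio, as in PI, and pairs in between are scaled by intermediate amounts. In practice, you don't need the NTK theory: just modify the base constant.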
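The base adjustment can be checked numerically. The sketch below is plain Python and assumes d = 128, base = 10000, and a 4K → 32K extension for illustration; the helper names are hypothetical. It compares the original and NTK-scaled frequencies. The highest-frequency pair is unchanged, and the lowest-frequency pair is divided by exactly the extension ratio s, matching Position Interpolation at that end.

```python
def frequencies(d: int, base: float) -> list[float]:
    """RoPE frequencies theta_i = base^(-2i/d), i = 0 .. d/2 - 1."""
    return [base ** (-2 * i / d) for i in range(d // 2)]


def ntk_base(base: float, d: int, train_len: int, target_len: int) -> float:
    """NTK-aware base: base * s^(d / (d - 2)) with s = target_len / train_len."""
    s = target_len / train_len
    return base * s ** (d / (d - 2))


d, base, train_len, target_len = 128, 10000.0, 4096, 32768
s = target_len / train_len
old = frequencies(d, base)
new = frequencies(d, ntk_base(base, d, train_len, target_len))

print(f"highest-frequency pair: {old[0]:.3g} -> {new[0]:.3g}")
print(f"lowest-frequency pair scaled by {new[-1] / old[-1]:.4f} (1/s = {1 / s:.4f})")
# Pair i is scaled by s^(-2i/(d-2)), interpolating between 1 and 1/s.
```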

NTK-aware scaling handles 4-16x extensions well. Code Llama used a related base-frequency change, raising the base from 10,000 to 1,000,000, to train Llama 2 at 16K context and extrapolate to about 100K. The quality at 100K is noticeably worse than at 16K (especially for retrieval tasks), but the model remains functional and doesn't catastrophically fail.

Method 3: YaRN (Yet another RoPE extensioN)

YaRN (Peng et al., 2023) combines interpolation with attention temperature scaling. It divides RoPE dimensions into three groups:

| Frequency group | Treatment | Why |
|---|---|---|
| High frequency (local) | No modification | Local patterns don't need extension |
| Medium frequency | Gradual blend between the two (linear ramp) | Smooth transition between local and global |
| Low frequency (global) | Linear interpolation (as in PI) | Long-range positions compressed to training range |
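The three-band treatment can be sketched as a per-pair blend between each pair's original frequency θ_i and the PI-interpolated frequency θ_i / s. This follows YaRN's "NTK-by-parts" rule. For each pair, compute r, the number of full rotations the pair completes within the training context. Pairs with r above an upper threshold keep their frequency, and pairs with r below a lower threshold are linearly interpolated. Pairs in between get a linear ramp. The thresholds `low = 1` and `high = 32` are the values suggested for Llama-family models; treat them, the defaults (d = 128, 4K training length, s = 8), and the function name as assumptions.

```python
import math


def yarn_frequencies(d: int = 128, base: float = 10000.0, train_len: int = 4096,
                     s: float = 8.0, low: float = 1.0, high: float = 32.0) -> list[float]:
    """Sketch of YaRN's 'NTK-by-parts' frequency interpolation.

    low/high are the ramp thresholds on r (rotations completed within
    train_len); 1 and 32 are assumed here.
    """
    out = []
    for i in range(d // 2):
        theta = base ** (-2 * i / d)
        wavelength = 2 * math.pi / theta
        r = train_len / wavelength            # full rotations seen in training
        if r > high:
            gamma = 1.0                        # high frequency: keep as-is
        elif r < low:
            gamma = 0.0                        # low frequency: interpolate like PI
        else:
            gamma = (r - low) / (high - low)   # medium: linear ramp
        out.append((1 - gamma) * theta / s + gamma * theta)
    return out


orig = [10000.0 ** (-2 * i / 128) for i in range(64)]
yarn = yarn_frequencies()
scales = [y / o for y, o in zip(yarn, orig)]
print("pairs kept unchanged:", sum(abs(x - 1.0) < 1e-12 for x in scales))
print("pairs fully interpolated (theta / s):", sum(abs(x - 1 / 8) < 1e-12 for x in scales))
```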

Additionally, YaRN applies an attention temperature t. The logits are divided by t, with $\sqrt{1/t} = 0.1 \ln s + 1$, where s is the extension ratio. Because t < 1 whenever s > 1, this sharpens the attention distribution. That compensates for the increased entropy that appears when positions are interpolated and contexts get longer.

Why Attention Temperature Matters

When you compress positions, the rotation angle differences between adjacent tokens decrease. This means $q^\top k$ values for nearby tokens become more similar, and the attention distribution becomes flatter (higher entropy). The model "forgets" which nearby tokens are most relevant, attending to everything more equally.

Temperature scaling counteracts this. Dividing the logits by t < 1 (equivalently, multiplying them by 1/t > 1) sharpens the distribution back toward its original entropy level. Without this correction, the model's local attention patterns degrade even though the global structure is preserved.

$$\text{Attention}(Q, K, V) = \text{softmax}\!\left(\frac{Q K^\top}{t\sqrt{d}}\right) V, \qquad \sqrt{1/t} = 0.1 \ln s + 1$$

Where s is the scaling ratio (target_len / train_len). For 4K → 32K, s = 8, so $\sqrt{1/t} = 0.1 \ln 8 + 1 \approx 1.21$ and the logits are multiplied by $1/t \approx 1.46$. This is a subtle but important correction that separates YaRN's quality from simpler methods.
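YaRN's temperature rule is easy to compute. The YaRN paper sets $\sqrt{1/t} = 0.1 \ln s + 1$ and divides the logits by t, which is the same as multiplying them by 1/t. The sketch below uses illustrative function names and a made-up row of logits. It evaluates that factor for s = 8 and shows that multiplying logits by a factor greater than 1 lowers the softmax entropy.

```python
import math


def yarn_logit_scale(s: float) -> float:
    """Factor 1/t that multiplies attention logits, from sqrt(1/t) = 0.1 ln(s) + 1."""
    return (0.1 * math.log(s) + 1.0) ** 2


def softmax_entropy(logits: list[float]) -> float:
    """Entropy (in nats) of softmax(logits)."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    probs = [e / total for e in exps]
    return -sum(p * math.log(p) for p in probs if p > 0)


scale = yarn_logit_scale(8.0)            # 4K -> 32K
logits = [2.0, 1.5, 1.0, 0.5, 0.0]       # toy q.k / sqrt(d) values for one query
print(f"1/t for s=8: {scale:.3f}")
print(f"entropy before: {softmax_entropy(logits):.3f} nats, "
      f"after: {softmax_entropy([x * scale for x in logits]):.3f} nats")
```

The YaRN paper also notes that the factor can be applied by scaling both q and k by $\sqrt{1/t}$, which folds into the RoPE cos/sin tables at no extra inference cost.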

Context Length Extension Methods

The heatmap shows attention score vs. position distance. Drag the context length slider to extend. Toggle between methods to see how they handle long-range positions.

All context extension methods modify RoPE frequencies to map new positions into the training range. PI scales everything uniformly (simple but loses local resolution). NTK adjusts the base frequency (preserves local detail). YaRN treats different frequency bands differently (best quality). All work best after a small amount of fine-tuning (on the order of 1000 steps or fewer).

Practical Impact

These techniques transformed what's deployable. Llama 2, trained on 4K context, was extended to 64K and 128K with YaRN. Code Llama reached 100K for long-document code understanding by raising RoPE's base frequency. GPT-4 Turbo's 128K context window likely uses similar techniques.

The fine-tuning requirement is minimal: on the order of 1000 steps on long-context data, a tiny fraction of the original pretraining compute, suffices to adapt the attention patterns. Compare this to pretraining from scratch on longer sequences, which would cost millions of GPU-hours. Context extension is perhaps the highest-ROI technique in modern LLM engineering.

Implementation: Extending Llama's Context

Let's see what context extension looks like in practice. Here's Position Interpolation applied to a RoPE model:

python
import torch

def apply_rope_with_pi(x, positions, base=10000,
                        train_len=4096, target_len=32768):
    # x: (batch, seq_len, n_heads, head_dim); positions: (seq_len,)
    # Position Interpolation: scale positions down
    scale = train_len / target_len  # 0.125 for 4K → 32K
    positions = positions.float() * scale   # pos 32768 → 4096

    d = x.shape[-1]
    freqs = 1.0 / (base ** (torch.arange(0, d, 2).float() / d))
    angles = positions.unsqueeze(-1) * freqs.unsqueeze(0)
    # (seq_len, 1, d/2) so the angles broadcast over the head axis
    cos_a, sin_a = angles.cos().unsqueeze(1), angles.sin().unsqueeze(1)

    x1, x2 = x[..., ::2], x[..., 1::2]
    return torch.stack([
        x1 * cos_a - x2 * sin_a,
        x1 * sin_a + x2 * cos_a
    ], dim=-1).flatten(-2)

def apply_rope_ntk(x, positions, base=10000,
                     train_len=4096, target_len=32768):
    # NTK-Aware: scale the base frequency instead
    d = x.shape[-1]
    ratio = target_len / train_len
    new_base = base * ratio ** (d / (d - 2))
    # Same formula, just different base
    freqs = 1.0 / (new_base ** (torch.arange(0, d, 2).float() / d))
    angles = positions.float().unsqueeze(-1) * freqs.unsqueeze(0)
    # (seq_len, 1, d/2) so the angles broadcast over the head axis
    cos_a, sin_a = angles.cos().unsqueeze(1), angles.sin().unsqueeze(1)
    x1, x2 = x[..., ::2], x[..., 1::2]
    return torch.stack([
        x1 * cos_a - x2 * sin_a,
        x1 * sin_a + x2 * cos_a
    ], dim=-1).flatten(-2)

Notice how minimal the code change is. Position Interpolation adds one line (positions * scale). NTK-aware scaling changes one constant (the base). The entire inference stack stays the same. This simplicity is why these techniques are so widely adopted.

The Needle-in-a-Haystack Test

How do we evaluate whether context extension actually works? The needle-in-a-haystack test embeds a random fact ("The best thing to do in San Francisco is eat a sandwich and sit in Dolores Park on a sunny day") at various positions in a long document, then asks the model to recall it. A heatmap of position vs. context length reveals where the model's attention breaks down.
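The test harness itself is mostly string manipulation. In this sketch all names are hypothetical, and the model call is left out, since any chat or completion API could be plugged in. It inserts the needle at a chosen depth, expressed as a fraction of the filler, and scores a response by checking for a key phrase. Sweeping `depth` and the filler length produces the grid behind the heatmap.

```python
def build_haystack(filler: list[str], needle: str, depth: float) -> str:
    """Insert `needle` into a list of filler sentences at fractional `depth`.

    depth = 0.0 puts the needle first, 1.0 puts it last.
    """
    if not 0.0 <= depth <= 1.0:
        raise ValueError("depth must be in [0, 1]")
    idx = round(depth * len(filler))
    return " ".join(filler[:idx] + [needle] + filler[idx:])


def score_response(response: str, key_phrase: str) -> bool:
    """Crude recall check: does the answer contain the key phrase?"""
    return key_phrase.lower() in response.lower()


NEEDLE = ("The best thing to do in San Francisco is eat a sandwich "
          "and sit in Dolores Park on a sunny day.")
QUESTION = "What is the best thing to do in San Francisco?"

filler = [f"Filler sentence number {i}." for i in range(1000)]
prompt = build_haystack(filler, NEEDLE, depth=0.5) + "\n\n" + QUESTION
# response = call_model(prompt)  # hypothetical: send the prompt to the model under test
response = "Eat a sandwich and sit in Dolores Park on a sunny day."
print(score_response(response, "Dolores Park"))
```

Real harnesses typically measure context length and needle depth in tokens rather than sentences, but the structure is the same.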

Without extension: the model retrieves perfectly up to the training length (4K), then recall drops to near zero. With PI: retrieval works at all positions but becomes fuzzy — the model sometimes confuses adjacent facts. With YaRN: near-perfect retrieval up to 128K, with graceful degradation beyond.

The needle-in-a-haystack test has become the standard benchmark for context extension quality. A clean heatmap (all green) means the model can reliably attend to any position in its context window. Dark patches indicate "dead zones" where the model loses attention — typically at the boundaries of the original training range, in the middle of very long documents (the "lost in the middle" problem), or at positions where frequency bands create destructive interference.

In practice, most applications need reliable retrieval at all positions, not just on average. A legal AI that ignores a clause because it falls in a dead zone is worse than useless. This is why YaRN's comprehensive approach (treating each frequency band appropriately) wins over simpler methods for production deployments, despite the slightly more complex implementation.

A model trained with RoPE on 4K context fails at 8K even though RoPE's rotation formula works at any position. Why?

Chapter 7: Inference Explorer

Time to put it all together. This simulation runs three inference strategies side by side on the same set of problems, so you can see the speed/quality/cost tradeoffs in real time.

Pipeline 1: Standard autoregressive decoding. The target model generates one token at a time. Reliable quality, predictable speed, baseline cost.

Pipeline 2: Speculative decoding. A draft model proposes tokens, the target model verifies. Same quality as standard, but 2-3x faster. Slightly higher peak memory (two models loaded).

Pipeline 3: Best-of-N with PRM. The model generates N=4 complete solutions, a PRM scores each, and we return the best. 4x slower than standard, but dramatically higher quality on hard problems. The PRM adds marginal compute.

Watch how these strategies perform differently depending on problem difficulty. Easy problems: all three get it right, speculative is just faster. Hard problems: best-of-N pulls ahead in quality while standard and speculative both struggle.
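The selection step of Pipeline 3 can be sketched in a few lines. This rests on assumptions not in the text. A solution is a list of reasoning steps, and `prm_step_scores` is a hypothetical stand-in for a trained PRM that returns a correctness score in [0, 1] per step. A solution's overall score is taken to be its minimum step score, one common aggregation; the product of step scores is another.

```python
from typing import Callable, Sequence

Solution = list[str]  # one reasoning step per string


def best_of_n(candidates: Sequence[Solution],
              prm_step_scores: Callable[[Solution], list[float]]) -> Solution:
    """Return the candidate whose weakest step scores highest under the PRM."""
    def solution_score(sol: Solution) -> float:
        return min(prm_step_scores(sol))   # a chain is only as good as its worst step
    return max(candidates, key=solution_score)


# Toy stand-in for a PRM: pretend steps marked "error" are likely wrong.
def toy_prm(sol: Solution) -> list[float]:
    return [0.1 if "error" in step else 0.9 for step in sol]


candidates = [
    ["2x + 3 = 11", "2x = 8", "x = 4"],
    ["2x + 3 = 11", "2x = 14 (error)", "x = 7"],
    ["2x + 3 = 11", "x = 11 - 3 (error)", "x = 8"],
    ["2x + 3 = 11", "2x = 8", "x = 4"],
]
print(best_of_n(candidates, toy_prm)[-1])   # x = 4
```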

Inference Strategy Race

Three strategies race to solve problems. Watch speed (tokens generated), quality (correct solutions), and cost (total FLOPs) diverge as difficulty increases.


Notice the fundamental tradeoff triangle:

| Strategy | Speed | Quality (hard problems) | Cost |
|---|---|---|---|
| Standard | 1x | Baseline | 1x |
| Speculative | 2-3x faster | Same as standard | ~1.1x (draft model is cheap) |
| Best-of-N + PRM | Nx slower | Much higher | ~Nx + PRM overhead |

In production, these strategies are often combined. A serving system might use speculative decoding for all requests (free speedup), then additionally use best-of-N for requests flagged as high-difficulty (math, code, complex reasoning). The routing decision itself can be made by a lightweight classifier trained on problem difficulty.

Reading the Race

Here's what to look for in the simulation:

On easy problems (slider left): All three strategies solve most problems correctly (green boxes). Speculative decoding finishes first, since it produces the same answers as standard decoding in less time. Best-of-N wastes compute generating 4 solutions when 1 was sufficient.

On hard problems (slider right): Standard and speculative both produce many red boxes (wrong answers). Best-of-N pulls ahead in green boxes because its 4 attempts, filtered by the PRM, find correct solutions more often. The extra compute pays for itself in quality.

The insight: There's no free lunch. Speculative decoding gives speed for free (same quality, less time). Best-of-N gives quality for cost (better answers, more compute). The optimal system uses both: speculative decoding for speed at all difficulty levels, plus best-of-N for accuracy when the problem demands it.

Production Inference Architecture

Modern LLM serving systems combine all three strategies in a layered architecture:

Layer 1: Difficulty Router
Classify incoming request as easy/medium/hard. Use entropy of first few tokens as a heuristic, or a trained classifier. ~1ms overhead.
Layer 2: Strategy Selector
Easy → single pass. Medium → best-of-4. Hard → tree search + PRM. Route to the appropriate inference pipeline.
Layer 3: Speculative Execution
ALL pipelines use speculative decoding for the individual generations. This is a free speedup layer that stacks with any strategy.
Layer 4: Verification
For medium/hard: PRM scores the candidates. For all: optional consistency check (do multiple attempts agree?). Return the best verified response.
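Layers 1 and 2 can be sketched as a tiny routing function. Everything here is an illustrative assumption. The difficulty signal is the mean entropy of the model's next-token distributions over the first few generated tokens, the heuristic Layer 1 mentions. The two thresholds are made-up values that a real system would tune on held-out traffic.

```python
import math
from typing import Sequence

# Hypothetical thresholds (nats); a real router would tune these on data.
EASY_MAX_ENTROPY = 0.5
MEDIUM_MAX_ENTROPY = 1.5


def entropy(probs: Sequence[float]) -> float:
    """Shannon entropy in nats of one next-token distribution."""
    return -sum(p * math.log(p) for p in probs if p > 0)


def route(first_token_dists: Sequence[Sequence[float]]) -> str:
    """Map mean early-token entropy to an inference strategy (Layer 2)."""
    mean_h = sum(entropy(dist) for dist in first_token_dists) / len(first_token_dists)
    if mean_h <= EASY_MAX_ENTROPY:
        return "single_pass"
    if mean_h <= MEDIUM_MAX_ENTROPY:
        return "best_of_4"
    return "tree_search_prm"


confident = [[0.97, 0.01, 0.01, 0.01]] * 3    # model is sure of its next tokens
unsure = [[0.125] * 8] * 3                    # spread evenly over 8 options
print(route(confident), route(unsure))        # single_pass tree_search_prm
```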

This layered architecture is how systems like ChatGPT likely work behind the scenes. The user sees a single response, but the system may have generated multiple candidates, verified them, and selected the best — all while using speculative decoding to minimize latency.

Chapter 8: The Compute Frontier

We now have two knobs for improving model capability: train-time compute (bigger model, more data, longer training) and test-time compute (search, verification, multiple attempts). How should we split a fixed total budget between them?

Traditional scaling laws (Kaplan et al., 2020; Hoffmann et al., 2022) focused entirely on training. The Chinchilla result said: for a given training budget, there's an optimal model size and data amount. But these laws assumed inference is free — one forward pass per query.

When we add test-time compute as a variable, the landscape changes. A smaller model + search at inference can achieve the same capability as a bigger model + single-shot inference. The total compute might be the same, but the allocation differs.

Amortized vs. Per-Query Cost

An important economic distinction: training compute is a one-time, amortized cost. You train the model once and serve it to millions of users. Test-time compute is a per-query cost. Every single query pays the inference tax.

This means the optimal split depends on query volume. For a service handling 1 billion queries per day (like ChatGPT), a 10x increase in per-query compute costs 10x more in ongoing GPU spend. Spending 10x more compute on training a model of the same size (for example, on more data) is expensive upfront but adds zero marginal cost per query. A bigger model, by contrast, also raises the cost of every query. At sufficient scale, training compute is almost always cheaper than test-time compute for the same capability level.

But there's a nuance: test-time compute is adaptive. You can use 1x compute for easy queries and 100x for hard ones. Training compute is fixed — the 70B model uses 70B parameters for "what is 2+2?" just as it does for AIME problems. The ability to allocate compute on-demand is test-time scaling's greatest advantage, and it becomes increasingly important as query distributions become more skewed.
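The amortization argument is simple arithmetic: total cost equals the one-time training cost plus the number of queries times the per-query cost. The sketch below compares two hypothetical ways to reach the same capability. Every number is invented for illustration, in arbitrary compute units, and is not an estimate for any real system.

```python
def total_cost(train_cost: float, per_query_cost: float, n_queries: float) -> float:
    """One-time training cost plus inference cost summed over all queries."""
    return train_cost + per_query_cost * n_queries


def break_even_queries(train_a: float, query_a: float,
                       train_b: float, query_b: float) -> float:
    """Query volume at which strategies A and B cost the same.

    Assumes A trains more (train_a > train_b) but is cheaper per query (query_a < query_b).
    """
    return (train_a - train_b) / (query_b - query_a)


# Hypothetical: A = more training, 1 unit/query; B = less training, 10 units/query (search).
train_a, query_a = 1_000_000.0, 1.0
train_b, query_b = 100_000.0, 10.0

n_star = break_even_queries(train_a, query_a, train_b, query_b)
print(f"break-even at {n_star:,.0f} queries")
for n in (1e4, 1e6, 1e8):
    cheaper = "A" if total_cost(train_a, query_a, n) < total_cost(train_b, query_b, n) else "B"
    print(f"{n:>12,.0f} queries: {cheaper} is cheaper")
```

Adaptive allocation shifts this picture. If only a fraction of queries trigger the expensive search, B's average per-query cost drops and the break-even volume moves higher.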

The Frontier Curve

Imagine a 2D plot: x-axis is training compute, y-axis is test-time compute. Each point represents a (model, inference strategy) pair. Draw an iso-capability curve — all (training, test-time) combinations that achieve the same accuracy on a benchmark. This curve is the compute frontier.

The frontier is convex: you can trade training compute for test-time compute, but with diminishing returns in both directions. Too much training, too little inference search: the model is big but inflexible — it can only answer as well as its single forward pass allows. Too little training, too much inference search: the model lacks fundamental capabilities that no amount of search can compensate for — you can't search your way to understanding calculus if the model has never seen calculus during training.

The frontier shape tells us something important: there's an optimal balance point. For any given total compute budget, some (training, inference) split maximizes capability, and moving to either extreme wastes compute. The practical challenge is that this optimal split changes depending on the task distribution, the deployment scale, and the latency requirements.

Train-Time vs. Test-Time Compute Frontier

The curve shows all (training, test-time) pairs that achieve the same accuracy. Drag the budget slider to increase total compute and watch the frontier shift outward. Click points on the curve to see what model + strategy they represent.


Where Are We Today?

The industry has shifted from purely scaling training to also scaling inference. The evidence:

| Signal | What it means |
|---|---|
| OpenAI o1/o3 models | Models explicitly trained to use more test-time compute via long chain-of-thought |
| Google DeepMind AlphaProof | MCTS-style search at inference solves IMO problems |
| Anthropic Claude "extended thinking" | Allocates more tokens for harder problems |
| DeepSeek-R1 | Open-weight reasoning model trained with reinforcement learning to think step by step |

The pattern is clear: the frontier models of 2024-2025 derive much of their capability from test-time compute, not just model size. The research question has shifted from "how big should the model be?" to "how should we split compute between training and inference?"

The compute frontier has two axes: training and inference. Modern AI advances come from scaling BOTH. Chinchilla optimized the training axis alone. The next frontier — and the one actively being explored — optimizes the joint allocation. The models that seem "smarter" than their size suggests are likely spending more compute at inference time.

The Future: Adaptive Compute

The ideal system doesn't use a fixed inference strategy. It adapts: easy questions get one-shot answers (fast, cheap). Hard questions get tree search with PRM verification (slow, expensive, accurate). The system itself decides how much to think.

This is already happening. OpenAI's o1 model uses variable-length chain-of-thought: it "thinks" for more tokens on harder problems. The model itself has learned when to think harder — a meta-capability trained via reinforcement learning. Anthropic's Claude similarly uses "extended thinking" for complex queries. The amount of test-time compute is no longer fixed but adaptive.

This creates a new optimization problem: given a stream of queries with varying difficulty, how do you allocate a fixed inference budget to maximize total quality? This is a scheduling problem, and it's an active area of research in ML systems.

The o1 / R1 Paradigm

OpenAI's o1 and DeepSeek's R1 represent the most visible test-time compute scaling in production. These models are trained via reinforcement learning to produce long chains of thought, often hundreds to thousands of tokens of "thinking" before the final answer. The RL reward encourages correct final answers, and the model learns to spend more thinking tokens on harder problems.

The key architectural insight: the model itself learns when to think harder. Unlike best-of-N (external search), o1-style reasoning is internal search — the model explores multiple approaches within a single generation, backtracking and trying alternatives as part of its chain of thought.

python
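# Simplified o1-style training loop (pseudocode). `model`, `extract_answer`,
# `policy_gradient_loss`, and `optimizer` are placeholders; o1's actual
# training recipe has not been published.
for problem in training_set:
    # Model generates reasoning + answer
    cot_output = model.generate(problem.prompt, max_tokens=4096)
    answer = extract_answer(cot_output)

    # RL reward: +1 if correct, -1 if wrong
    reward = 1.0 if answer == problem.ground_truth else -1.0

    # Policy-gradient update (e.g. PPO; DeepSeek-R1 used GRPO): the model
    # learns to produce CoTs that lead to correct answers.
    # Emergent behavior: harder problems → longer CoTs
    loss = policy_gradient_loss(model, problem.prompt, cot_output, reward)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()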
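
The loop above leaves `extract_answer` and the reward abstract. Below is a minimal sketch of an outcome reward. It assumes the model ends with a LaTeX-style `\boxed{...}` answer, as in MATH-style data, and that grading is an exact string match. Both assumptions are illustrative simplifications, not o1's or R1's actual grader.

```python
import re
from typing import Optional

# Assumption: the final answer appears as \boxed{...}. Nested braces
# (e.g. \boxed{\frac{1}{2}}) are deliberately not handled by this pattern.
BOXED = re.compile(r"\\boxed\{([^{}]*)\}")


def extract_answer(cot_output: str) -> Optional[str]:
    """Return the contents of the last boxed answer, or None if there is none."""
    matches = BOXED.findall(cot_output)
    return matches[-1].strip() if matches else None


def outcome_reward(cot_output: str, ground_truth: str) -> float:
    """+1.0 for an exactly matching final answer, -1.0 otherwise (incl. no answer)."""
    answer = extract_answer(cot_output)
    if answer is not None and answer == ground_truth.strip():
        return 1.0
    return -1.0


cot = r"2x + 3 = 9, so 2x = 6 and x = 3. Let me check: 2*3 + 3 = 9. \boxed{3}"
print(outcome_reward(cot, "3"))       # 1.0
print(outcome_reward("x is 4", "3"))  # -1.0: no boxed answer
```

Exact string matching is the crudest possible check. A practical grader would normalize answers first, so that `0.5` and `1/2` count as equal.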

The result: on AIME 2024 (hard competition math), OpenAI reported that o1 solved 83% of problems (with consensus voting over 64 samples; 74% with a single sample), while GPT-4o solved about 13%. Much of that gap comes from test-time compute spent on long reasoning chains. However, it cannot be cleanly separated from the effects of o1's RL training and base model, whose details OpenAI has not published.

DeepSeek-R1 demonstrated that this approach works with open-weight models too. Trained with RL on mathematical reasoning, R1 learned to produce long, structured chains of thought with self-correction ("Wait, let me reconsider..."), exploration of alternatives ("Another approach would be..."), and explicit verification ("Let me check: if x = 3, then..."). These behaviors emerged from RL training without explicit programming — the model discovered that thinking more carefully leads to more reward.
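
A deliberately tiny model shows how adapting reasoning length to difficulty can arise from reward alone. Everything here is an illustrative assumption, not R1's setup:

- two difficulty levels;
- a policy that only chooses short or long thinking for each level;
- made-up solve probabilities;
- a small cost for long thinking.

The ±1 reward in the loop above has no length cost. It is added here so that thinking is not free. To keep the sketch deterministic, it ascends the exact expected-reward gradient rather than sampling rollouts. Sampled REINFORCE follows the same gradient in expectation.

```python
import math

# Hypothetical P(correct) for each (difficulty, thinking length) pair.
SOLVE_PROB = {
    ("easy", "short"): 0.90, ("easy", "long"): 0.95,
    ("hard", "short"): 0.10, ("hard", "long"): 0.70,
}
LONG_COST = 0.2  # assumed penalty for spending many thinking tokens


def expected_reward(difficulty: str, length: str) -> float:
    """Expected +1/-1 outcome reward, minus the assumed cost of long thinking."""
    p = SOLVE_PROB[(difficulty, length)]
    cost = LONG_COST if length == "long" else 0.0
    return p * 1.0 + (1 - p) * (-1.0) - cost


def sigmoid(x: float) -> float:
    return 1.0 / (1.0 + math.exp(-x))


def train(steps: int = 500, lr: float = 1.0) -> dict[str, float]:
    # One logit per difficulty: P(think long | difficulty) = sigmoid(logit).
    logits = {"easy": 0.0, "hard": 0.0}
    for _ in range(steps):
        for d in logits:
            p_long = sigmoid(logits[d])
            advantage = expected_reward(d, "long") - expected_reward(d, "short")
            # d E[R] / d logit = p (1 - p) (R_long - R_short)
            logits[d] += lr * p_long * (1 - p_long) * advantage
    return {d: sigmoid(x) for d, x in logits.items()}


print(train())  # P(long) ends near 0 for easy, near 1 for hard
```

With these numbers, long thinking raises expected reward on hard problems (0.2 vs -0.8) but lowers it on easy ones (0.7 vs 0.8). The policy therefore learns to think long only when the problem is hard. Setting `LONG_COST = 0` makes long thinking the better choice for both levels, which is one way overthinking on easy problems can arise.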

The scaling trend for these models is roughly log-linear: accuracy grows approximately linearly in the logarithm of the number of thinking tokens. Each doubling of the thinking budget (from 1,000 to 2,000 tokens, say) buys a roughly constant accuracy gain. This is the test-time analog of neural scaling laws. Because accuracy is capped at 100%, the trend must eventually flatten, but it suggests substantial headroom remains from thinking harder.
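
A "constant gain per doubling" trend is a straight line in log2 of the thinking budget: acc(T) ≈ a + b·log2(T), where b is the gain per doubling. The sketch below fits a and b by ordinary least squares. The data points are synthetic, generated from a = 0.1 and b = 0.05; they are not measurements of any real model. `predict` clips at 1.0 because no extrapolation can exceed perfect accuracy.

```python
import math


def fit_log_linear(budgets: list[int], accuracies: list[float]) -> tuple[float, float]:
    """Least-squares fit of acc ≈ a + b * log2(budget); returns (a, b)."""
    xs = [math.log2(t) for t in budgets]
    n = len(xs)
    x_mean = sum(xs) / n
    y_mean = sum(accuracies) / n
    sxx = sum((x - x_mean) ** 2 for x in xs)
    sxy = sum((x - x_mean) * (y - y_mean) for x, y in zip(xs, accuracies))
    b = sxy / sxx
    a = y_mean - b * x_mean
    return a, b


def predict(a: float, b: float, budget: int) -> float:
    """Extrapolate the fit, clipped to [0, 1] since accuracy is bounded."""
    return min(1.0, max(0.0, a + b * math.log2(budget)))


# Synthetic, noise-free points at 250, 500, 1000, 2000 thinking tokens.
budgets = [250, 500, 1000, 2000]
accs = [0.1 + 0.05 * math.log2(t) for t in budgets]
a, b = fit_log_linear(budgets, accs)
print(f"gain per doubling: {b:.3f}")  # gain per doubling: 0.050
print(f"predicted at 4000 tokens: {predict(a, b, 4000):.3f}")
print(f"predicted at 1M tokens: {predict(a, b, 10**6):.3f}")  # clipped to 1.000
```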

The implication for practitioners: when evaluating an LLM's reasoning ability, always specify the inference budget. A model that scores 50% on MATH with 100 thinking tokens might score 80% with 2000 thinking tokens. The model didn't get "smarter" — it got more compute. Benchmark results without inference budget disclosure are increasingly misleading.

Why have modern AI labs shifted from just scaling model size to also scaling test-time compute?

Chapter 9: Connections

This lecture sits at the intersection of two major threads in modern AI: making models reason better (test-time compute) and making inference practical (speculative decoding, context extension). These ideas connect deeply to recent research.

Key Papers

| Paper | Contribution | Connection |
|---|---|---|
| Let's Verify Step by Step (Lightman 2023) | Process reward models with human step-level labels | Foundation for PRM-guided search (Ch 1-2) |
| Speculative Decoding (Leviathan 2022) | Draft-verify paradigm for lossless speedup | Core of Ch 4: the accept/reject criterion |
| Scaling Test-Time Compute (Snell 2024) | Compute-optimal inference strategies | Difficulty-dependent strategy selection (Ch 3) |
| RoPE (Su 2021) | Rotary position encoding via rotation | Position encoding and context extension (Ch 5-6) |

Lecture Connections

| Lecture | Relationship |
|---|---|
| L12: Reasoning Part 1 | Chain-of-thought and prompting strategies, the foundation this lecture builds on. L12 shows how to make models think step by step; L13 shows how to do it at scale with verification. |
| L14: ACL Guest Lecture | Broader perspective on where the field is going, including reasoning benchmarks and evaluation. |

The Big Picture

Lecture 13 Concept Map

The two threads of this lecture — reasoning at scale (left) and efficient inference (right) — converge at the compute frontier.

The story of this lecture in one sentence: modern inference is not just "run the model once" — it's a system of drafting, verifying, searching, and selecting that turns raw model capability into reliable, efficient performance.

What separates a research model from a production model is increasingly the inference stack, not the model weights. Speculative decoding makes it fast. PRM-guided search makes it accurate. Context extension makes it handle real-world documents. Together, these techniques define the modern LLM serving infrastructure.

Historical Arc

It's worth stepping back to see how far inference has come:

| Era | Inference Strategy | Bottleneck |
|---|---|---|
| 2018 (BERT) | Single forward pass, no generation | Model size |
| 2020 (GPT-3) | Greedy or nucleus sampling | Context length (2K) |
| 2022 | Chain-of-thought prompting | Prompt engineering |
| 2023 | Best-of-N + reward models | Reward model quality |
| 2024 | PRM-guided tree search + speculative decoding | Difficulty estimation, PRM training data |
| 2025 | Adaptive compute (o1, R1) + long context (1M tokens) | Inference cost management |

Each step expanded what's possible at inference time. We went from "run the model and take whatever it says" to "run it many times, verify each step, search through possibilities, and do it all fast with speculative decoding." The inference stack is now as complex and important as the training stack.

Looking Ahead

Several open questions define the next frontier of inference research:

1. How to train better PRMs without human labels. The 800K human-labeled steps in Lightman et al. are expensive. Can we use model-generated labels (self-play, MCTS) to train equally good PRMs? Early results from Math-Shepherd and other work suggest yes, but the quality gap isn't fully closed.

2. How to learn when to stop thinking. o1-style models sometimes "overthink" easy problems, wasting tokens on unnecessary reasoning. The optimal stopping criterion — when has the model thought enough? — is an open problem.

3. How to combine test-time compute with tool use. What if the model can call a calculator mid-reasoning, or verify a step by running code? Combining search (this lecture) with tools (L10) creates a richer action space for test-time compute. AlphaCode 2, for instance, generates many candidate programs, runs them against test cases (tool use), and selects the best by execution results — combining sampling, verification, and tool use into a single inference pipeline.

4. How to extend context efficiently. Common RoPE-based extension recipes reach roughly 128K tokens, and a few production models offer million-token windows. Even so, real-world documents (codebases, legal contracts, medical records) can run to millions of tokens. Ring Attention, Infini-attention, and retrieval-augmented approaches are pushing toward ever-longer contexts, but efficient attention at these lengths remains an active research area.

5. How to make PRMs cheaper to train. Can models self-supervise step-level correctness by checking consistency across multiple solution paths? If three independent attempts all agree on step 2 but disagree on step 3, that's evidence step 3 is the error point. This "verification by consensus" could generate PRM training data without human labelers.
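
The "verification by consensus" idea in item 5 can be sketched in a few lines. The sketch assumes each attempt has already been split into steps, and that each step reduces to a comparable intermediate result (plain strings here). Both are hypothetical simplifications: real reasoning steps are free-form text that would need normalization, or a model, to compare. `first_disagreement` and `consensus_labels` are illustrative names, not an existing library.

```python
from typing import Optional


def first_disagreement(attempts: list[list[str]]) -> Optional[int]:
    """1-based number of the first step on which the attempts disagree.

    Each attempt is a non-empty list of normalized step results (an assumed
    representation). Only steps present in every attempt are compared.
    Returns None if the attempts agree on all shared steps.
    """
    shared = min(len(a) for a in attempts)
    for i in range(shared):
        if len({a[i] for a in attempts}) > 1:
            return i + 1
    return None


def consensus_labels(attempts: list[list[str]]) -> list[str]:
    """Noisy step labels: 'agree' before the first divergence, 'suspect' at it."""
    split = first_disagreement(attempts)
    if split is None:
        return ["agree"] * min(len(a) for a in attempts)
    return ["agree"] * (split - 1) + ["suspect"]


# Hypothetical, already-normalized step results from three independent attempts.
attempts = [
    ["x = 4", "y = 2x = 8", "x + y = 12"],
    ["x = 4", "y = 2x = 8", "x + y = 11"],
    ["x = 4", "y = 2x = 8", "x + y = 13"],
]
print(first_disagreement(attempts))  # 3
print(consensus_labels(attempts))    # ['agree', 'agree', 'suspect']
```

Agreement is evidence, not proof. Several attempts can share the same wrong step, so labels like these are noisy and would likely need to be combined with a check of final-answer correctness.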

What We Covered

| Concept | Core Idea | Practical Impact |
|---|---|---|
| Process Reward Models | Score every reasoning step, not just the final answer | PRM + best-of-100 adds 25.5% to GPT-4 math accuracy |
| Best-of-N Sampling | Sample N times, return the best | Simplest test-time scaling; accuracy grows roughly with log(N) |
| Compute-Optimal Search | Optimal strategy depends on problem difficulty | Small model + search can match 4x larger model |
| Speculative Decoding | Draft fast, verify in parallel; output distribution identical to the target model's | 2-3x speedup with zero quality loss |
| RoPE | Encode position by rotating Q/K vectors | Standard in most modern LLMs; enables context extension |
| Context Extension | Scale RoPE frequencies to handle longer contexts | 4K → 100K+ with ~1 GPU-hour of fine-tuning |

"The best way to predict the future of AI is to look at where compute is being allocated." Training compute got us from GPT-2 to GPT-4. Test-time compute is getting us from GPT-4 to o1. The next frontier will optimize both jointly, with adaptive systems that decide how much to think, how to verify, and when to give up.