The tempo of training: from warmup through cosine decay to the WSD schedule used in many recent LLM training runs.
Set the learning rate to 0.1 and training explodes — loss shoots to infinity. Set it to 0.00001 and training crawls — after a thousand steps you've barely moved. There's a sweet spot, but here's the problem: the sweet spot changes as training progresses.
The learning rate (LR) is the single most important hyperparameter in all of deep learning. It controls how big a step the optimizer takes in the direction the gradient points. Too big, and you overshoot the minimum. Too small, and you waste compute crawling toward it.
But "too big" and "too small" aren't fixed thresholds. Early in training, when the model is far from any minimum, big steps are fine — they help you cover ground quickly and skip past shallow local minima. Late in training, when you're near a good minimum, those same big steps send you bouncing right over it.
Think of the loss function as a hilly terrain. Your model is a ball sitting somewhere on that terrain, and the gradient tells you which direction is downhill. The learning rate controls how far you roll the ball each step.
At a high learning rate, the ball takes enormous leaps. It can jump across valleys, clear ridges, and explore widely. But it can also sail right past the deepest valley and bounce endlessly between slopes. The ball has too much energy to settle.
At a low learning rate, the ball barely moves. It trickles into the nearest dip and sits there, even if that dip is shallow and a much deeper valley exists just over the next hill. The ball doesn't have enough energy to explore.
The simulation below lets you experience both failures firsthand.
Drag the learning rate slider and press "Run" to watch gradient descent navigate a 2D loss landscape. At high LR, the ball overshoots and bounces. At low LR, it barely moves. Try to find the sweet spot — then watch it STILL oscillate near the minimum.
Let's strip away everything and work with the simplest possible loss function: f(x) = (x - 3)². The minimum is at x = 3. We'll start at x = 0 and use gradient descent with different constant learning rates.
The gradient of f(x) = (x - 3)² is:

$$f'(x) = 2(x - 3)$$

The update rule is:

$$x_{t+1} = x_t - \eta \, f'(x_t) = x_t - 2\eta\,(x_t - 3)$$
where η is the learning rate. Let's trace three scenarios.
Here's the critical insight that motivates this entire lesson:
| LR | After 1 Step | After 10 Steps | Behavior |
|---|---|---|---|
| 0.01 | 0.06 | 0.549 | Crawling |
| 0.5 | 3.0 | 3.0 | Perfect (but only for quadratics) |
| 0.9 | 5.4 | 2.678 | Overshoots and oscillates, but the error shrinks 20% per step |
| 1.1 | 6.6 | -15.575 | Diverging oscillation: the error grows 20% per step, toward infinity |
For a pure quadratic f(x) = (x - 3)², the maximum stable learning rate is exactly 1.0. At η = 1.0, each step reflects x to the opposite side of the minimum with the exact same distance. At η > 1.0, the oscillations grow and the optimizer diverges. Real loss functions aren't quadratics, so the stability boundary depends on the local curvature — and that curvature changes as training progresses.
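The stability boundary can be checked with a short sketch. For this particular quadratic, each update multiplies the error (x - 3) by the factor (1 - 2η), so convergence requires |1 - 2η| < 1. The helper names `gd_quadratic` and `closed_form` are illustrative, not from any library.

```python
def gd_quadratic(lr, steps, x0=0.0, target=3.0):
    """Run gradient descent on f(x) = (x - target)^2 and return the final x."""
    x = x0
    for _ in range(steps):
        x = x - lr * 2 * (x - target)   # gradient of (x - target)^2 is 2(x - target)
    return x


def closed_form(lr, steps, x0=0.0, target=3.0):
    """Each step multiplies the error (x - target) by the factor (1 - 2*lr)."""
    return target + (x0 - target) * (1 - 2 * lr) ** steps


for lr in [0.01, 0.5, 0.9, 1.0, 1.1]:
    factor = 1 - 2 * lr
    status = "converges" if abs(factor) < 1 else "does not converge"
    print(f"lr={lr:<5} factor={factor:+.2f}  "
          f"x after 10 steps={gd_quadratic(lr, 10):8.3f}  ({status})")
```

The loop and the closed form should agree to floating-point precision, which is a handy sanity check when filling in tables like the one above by hand.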
Look at the regimes above. The ideal strategy is obvious: start with a moderate-to-large learning rate to make fast progress when you're far from any minimum. Then gradually shrink it as you approach the minimum, so you settle precisely into it instead of bouncing around.
A learning rate schedule is a function that maps training step to learning rate: η(t). Instead of a single number, you get a curve. The schedule encodes the insight that the optimal step size depends on where you are in training.
```python
# Gradient descent with constant LR: three regimes
def f(x):
    return (x - 3) ** 2

def grad_f(x):
    return 2 * (x - 3)

def gd_constant(lr, steps=50, x0=0.0):
    x = x0
    trajectory = [x]
    for _ in range(steps):
        g = grad_f(x)
        x = x - lr * g
        trajectory.append(x)
        if abs(x) > 1e6:  # diverged
            break
    return trajectory

# Too small: crawls
slow = gd_constant(lr=0.01)
print(f"LR=0.01, after 50 steps: x={slow[-1]:.4f}")
# LR=0.01, after 50 steps: x=1.9075

# Just right: converges fast
good = gd_constant(lr=0.5)
print(f"LR=0.5, after 50 steps: x={good[-1]:.4f}")
# LR=0.5, after 50 steps: x=3.0000

# Large: overshoots on every step, but the oscillation still damps out
osc = gd_constant(lr=0.9)
print(f"LR=0.9, after 50 steps: x={osc[-1]:.4f}")
# LR=0.9, after 50 steps: x=3.0000
```
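To make "a schedule is a function η(t)" concrete, here is a minimal sketch in which the optimizer asks a callable for its step size at every step. The halving schedule is a hypothetical example chosen only for illustration, and `gd_with_schedule` is not a library function.

```python
def grad_f(x):
    """Gradient of f(x) = (x - 3)^2."""
    return 2 * (x - 3)


def gd_with_schedule(schedule, steps=20, x0=0.0):
    """Gradient descent where the step size at step t comes from schedule(t)."""
    x = x0
    for t in range(steps):
        x = x - schedule(t) * grad_f(x)
    return x


def constant(t):
    return 0.9


def halving(t):
    # Hypothetical schedule: start at 0.9 and halve every 5 steps.
    return 0.9 * 0.5 ** (t // 5)


print(f"constant 0.9: x = {gd_with_schedule(constant):.6f}")
print(f"halving:      x = {gd_with_schedule(halving):.6f}")
```

On this toy problem the decaying schedule ends much closer to x = 3 than the constant one after the same 20 steps, because the later, smaller steps stop overshooting.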
You've decided on a target learning rate of 0.001. Naively, you set it to 0.001 from step 0. The model has random weights. Gradients are wild — pointing in random directions with huge magnitudes. Multiplying wild gradients by a non-trivial learning rate: chaos. The first few hundred steps are a car crash.
Learning rate warmup is a simple fix: instead of starting at the target LR, start at zero (or near-zero) and linearly increase to the target over the first N steps. This gives the optimizer time to get its bearings before you ask it to take real steps.
At step 0, every weight in your network is random. The model is essentially computing garbage — its output has nothing to do with the input. Gradients computed on garbage outputs are noisy, enormous, and point in nearly random directions.
Now multiply those wild gradients by a non-trivial learning rate. The result is large, random weight updates. The model lurches in an arbitrary direction. This lurch produces even wilder gradients at step 1. A positive feedback loop forms: bad predictions → wild gradients → large updates → worse predictions → wilder gradients. Training can diverge in the first hundred steps.
Adam maintains two running averages for each parameter: the first moment m (exponential moving average of gradients) and the second moment v (exponential moving average of squared gradients). The actual update divides the first moment by the square root of the second moment and scales the result by the learning rate: Δw = −η · m / (√v + ε), where ε is a small constant for numerical stability.
At step 0, both m and v are initialized to zero. Adam uses bias correction to account for this, dividing m by (1 − β₁^t) and v by (1 − β₂^t), but the corrected estimates are still based on just one or two gradient samples. They're unreliable.
Here's the dangerous scenario: suppose the first gradient for some parameter is accidentally small. Then v (squared gradient) is tiny. The update m / √v divides by a tiny number, producing an enormous step. One unlucky gradient can send a weight into orbit.
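To make the early-step behaviour concrete, here is a minimal single-weight sketch of Adam's bias-corrected update, assuming the usual default hyperparameters (β₁ = 0.9, β₂ = 0.999, ε = 1e-8). `adam_updates` is an illustrative helper, not a library function. It suggests that on the very first step the update has magnitude close to the full learning rate whatever the gradient's scale, so a single noisy gradient sets the direction at full strength.

```python
import math


def adam_updates(grads, lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8):
    """Return the updates Adam would apply to one scalar weight, step by step."""
    m, v = 0.0, 0.0
    updates = []
    for t, g in enumerate(grads, start=1):
        m = beta1 * m + (1 - beta1) * g        # first moment: EMA of gradients
        v = beta2 * v + (1 - beta2) * g * g    # second moment: EMA of squared gradients
        m_hat = m / (1 - beta1 ** t)           # bias correction
        v_hat = v / (1 - beta2 ** t)
        updates.append(-lr * m_hat / (math.sqrt(v_hat) + eps))
    return updates


# First step: the update is about -lr whether the gradient is tiny or huge.
for g in [1e-4, 1.0, 1e4]:
    print(f"first gradient {g:>8g} -> update {adam_updates([g])[0]:+.6f}")

# Two gradients that disagree: the second update is already far smaller than lr.
print(f"gradients [1, -1] -> second update {adam_updates([1.0, -1.0])[1]:+.6f}")
```

Warmup scales all of these early updates down by the factor t / T_w while the moment estimates accumulate more samples.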
Linear warmup is the simplest and most common form. Given a target LR of η_target and a warmup duration of T_w steps, the learning rate at step t is:

$$\eta(t) = \eta_{\text{target}} \cdot \min\left(1, \frac{t}{T_w}\right)$$
Let's trace it. Target LR = 0.001, warmup = 1000 steps:
| Step | t / Tw | η(t) | Description |
|---|---|---|---|
| 0 | 0.000 | 0.000000 | Effectively frozen |
| 100 | 0.100 | 0.000100 | 10% of target |
| 250 | 0.250 | 0.000250 | 25% of target |
| 500 | 0.500 | 0.000500 | 50% of target |
| 750 | 0.750 | 0.000750 | 75% of target |
| 1000 | 1.000 | 0.001000 | Full target reached |
| 2000 | 2.000 | 0.001000 | Clamped at target (min(1, ...)) |
Notice the LR at step 0 is exactly zero — no update happens at all. By step 100, the LR is just 10% of the target. This means even if the gradient is 10× too large (because Adam's estimates are bad), the effective update is only as big as it would be at the target LR with a normal gradient.
The standard recipe is 1-5% of total training steps. This is not a magic number — it's roughly how long Adam needs to build reliable running averages.
| Total Steps | Warmup (1%) | Warmup (5%) | Typical |
|---|---|---|---|
| 10,000 | 100 | 500 | 100-500 |
| 100,000 | 1,000 | 5,000 | 1,000-2,000 |
| 1,000,000 | 10,000 | 50,000 | 10,000-20,000 |
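GPT-3 warmed up linearly over its first 375 million tokens, out of roughly 300 billion training tokens (0.125%). BERT used 10,000 warmup steps out of 1,000,000 (1%). LLaMA used 2,000 warmup steps; at its 4M-token batch size and 1.0 to 1.4 trillion training tokens, that is roughly 250,000 to 350,000 steps in total (under 1%). At LLM scale, warmup is usually just long enough to stabilize training, no more.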
The simulation below shows two training runs side by side. The left curve trains without warmup — the LR starts at the target from step 0. The right curve uses linear warmup. Watch the loss in the first few hundred steps: without warmup, the loss often spikes or diverges before recovering (if it recovers at all). With warmup, the loss decreases smoothly from the start.
Two simulated training runs with Adam (target LR = 0.001). Left: no warmup. Right: linear warmup. Adjust the warmup fraction to see how warmup duration affects early training stability. Press "New Run" to generate fresh noise.
```python
# Linear warmup from scratch
import torch

def linear_warmup_lr(step, target_lr, warmup_steps):
    """Returns the learning rate at the given step."""
    if step < warmup_steps:
        return target_lr * (step / warmup_steps)
    return target_lr

# Usage in a training loop (model, batch, and total_steps are assumed to be defined)
target_lr = 1e-3
warmup_steps = 1000
optimizer = torch.optim.Adam(model.parameters(), lr=target_lr)

for step in range(total_steps):
    lr = linear_warmup_lr(step, target_lr, warmup_steps)
    for pg in optimizer.param_groups:
        pg['lr'] = lr
    loss = model(batch)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
```

```python
# PyTorch built-in: LambdaLR with warmup
import torch
from torch.optim.lr_scheduler import LambdaLR

optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
warmup_steps = 1000

scheduler = LambdaLR(
    optimizer,
    lr_lambda=lambda step: min(1.0, step / warmup_steps)
)

for step in range(total_steps):
    loss = model(batch)
    loss.backward()
    optimizer.step()
    scheduler.step()  # updates LR after each step
    optimizer.zero_grad()
```
Before fancy schedules, researchers used a simple recipe: train at LR = 0.1 for 30 epochs, then drop to 0.01, then to 0.001. Each drop lets the optimizer settle into a minimum that the previous LR was oscillating around. The standard ImageNet recipe for ResNet (the paper that put 152-layer networks on the map) follows exactly this pattern.
Step decay (also called staircase decay or piecewise constant schedule) is the simplest schedule beyond a constant LR. You pick a set of milestones (specific epochs or steps) and multiply the LR by a decay factor γ at each milestone.
Given an initial learning rate η_0, a decay factor γ (typically 0.1 or 0.5), and a step size S (the number of epochs between drops), the learning rate at epoch e is:

$$\eta(e) = \eta_0 \cdot \gamma^{\lfloor e / S \rfloor}$$

The floor function ⌊e / S⌋ counts how many decay events have happened. At epoch 0, zero events have happened, so γ^0 = 1 and the LR is η_0. At epoch S, one event has happened, so the LR drops to η_0 × γ. At epoch 2S, it drops again to η_0 × γ².
He et al. (2015) trained ResNet on ImageNet starting from η_0 = 0.1 and dividing the LR by 10 whenever the error plateaued. The common fixed-milestone version of that recipe runs 120 epochs with γ = 0.1 at epochs 30, 60, and 90. Let's trace through:
| Epoch Range | Decay Events | η | Description |
|---|---|---|---|
| 0 – 29 | 0 | 0.1 | Full exploration |
| 30 – 59 | 1 | 0.01 | 10× smaller |
| 60 – 89 | 2 | 0.001 | 100× smaller |
| 90 – 119 | 3 | 0.0001 | 1000× smaller (settling) |
Now let's work a gentler example. Initial LR = 0.1, γ = 0.5 (halve the LR each time), drop every 10 epochs, 50 epochs total.
| Epoch e | k = ⌊e/10⌋ | 0.5^k | η |
|---|---|---|---|
| 0 | 0 | 1.000 | 0.1000 |
| 10 | 1 | 0.500 | 0.0500 |
| 20 | 2 | 0.250 | 0.0250 |
| 30 | 3 | 0.125 | 0.0125 |
| 40 | 4 | 0.0625 | 0.00625 |
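The final LR is 0.1 × 0.5⁴ = 0.00625, which is 16× smaller than the initial LR. Over 50 epochs, the total "learning budget" (sum of LRs across all epochs, a rough proxy for total learning) is:

$$\sum_{e=0}^{49} \eta(e) = 10 \times (0.1 + 0.05 + 0.025 + 0.0125 + 0.00625) = 1.9375$$

For comparison, a constant LR of 0.1 over the same 50 epochs would sum to 5.0.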
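The learning budget for this gentler example can be tallied with a few lines. This is only a sketch of the "sum of LRs" proxy described above, with the example's parameters hard-coded as defaults.

```python
def step_decay_lr(epoch, init_lr=0.1, gamma=0.5, step_size=10):
    """LR at the given epoch: halve every 10 epochs, starting from 0.1."""
    return init_lr * gamma ** (epoch // step_size)


budget = sum(step_decay_lr(e) for e in range(50))
constant_budget = 0.1 * 50   # same run at a constant LR of 0.1

print(f"step decay budget:  {budget:.4f}")                    # 1.9375
print(f"constant LR budget: {constant_budget:.4f}")           # 5.0000
print(f"ratio:              {budget / constant_budget:.4f}")  # 0.3875
```

In other words, this schedule spends a bit under 40% of the budget a constant LR would, with most of it concentrated in the first 20 epochs.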
Step decay produces a distinctive loss curve shape that you'll recognize instantly once you know what to look for.
Between LR drops, the loss plateaus. The optimizer has found a minimum at the current LR scale and is oscillating around it. Each gradient step bounces the weights back and forth across the minimum — the LR is too large to settle, but small enough to stay in the neighborhood.
When the LR drops, the optimizer suddenly has smaller steps. It can now settle deeper into that minimum, causing a sudden decrease in loss. The loss drops sharply, then plateaus again at the new LR scale.
This plateau → drop → plateau pattern is the visual signature of step decay. If you see staircase-shaped loss curves in someone's training logs, they're using step decay.
The simulation below lets you design your own step decay schedule. The top plot shows the LR staircase. The bottom plot shows a simulated training loss curve — watch for the plateau-then-drop pattern at each LR reduction.
Adjust the decay factor, step interval, and initial LR. Watch how the loss curve responds: plateaus between drops, then sharp decreases at each LR reduction.
```python
# Step decay from scratch
def step_decay_lr(epoch, init_lr, gamma, step_size):
    """LR at the given epoch under step decay."""
    return init_lr * (gamma ** (epoch // step_size))

# Example: ResNet schedule
for epoch in [0, 29, 30, 59, 60, 89, 90, 119]:
    lr = step_decay_lr(epoch, 0.1, 0.1, 30)
    print(f"Epoch {epoch:3d}: LR = {lr:.4f}")
# Epoch   0: LR = 0.1000
# Epoch  29: LR = 0.1000
# Epoch  30: LR = 0.0100
# Epoch  59: LR = 0.0100
# Epoch  60: LR = 0.0010
# Epoch  89: LR = 0.0010
# Epoch  90: LR = 0.0001
# Epoch 119: LR = 0.0001
```

```python
# PyTorch built-in: StepLR
import torch

optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
scheduler = torch.optim.lr_scheduler.StepLR(
    optimizer, step_size=30, gamma=0.1
)

for epoch in range(120):
    train_one_epoch(model, optimizer)
    scheduler.step()  # call AFTER optimizer.step()
    print(f"Epoch {epoch}: LR = {scheduler.get_last_lr()[0]:.4f}")

# For custom milestones (not evenly spaced):
scheduler = torch.optim.lr_scheduler.MultiStepLR(
    optimizer, milestones=[30, 60, 90], gamma=0.1
)
```
Step decay requires knowing when to drop the LR, and getting it wrong costs you. What if we didn't drop at all, but smoothly lowered the LR over the entire run? Cosine annealing does exactly this, and the cosine shape works remarkably well in practice.
Proposed by Loshchilov and Hutter in 2016, cosine annealing has become the default schedule for much of modern deep learning. GPT-3, LLaMA, and Chinchilla all use cosine annealing (with warmup). Its appeal is simple: no milestone epochs to tune, no abrupt jumps in the LR, and a shape that holds the LR high early in training and low at the end.
We want a function that starts at η_max and smoothly decays to η_min over T total steps. Let's build it from the cosine function.
Step 1: The cosine function cos(θ) goes from 1 (at θ = 0) to -1 (at θ = π). We want something that goes from 1 to 0. So we take:

$$\frac{1}{2}\left(1 + \cos\left(\frac{\pi t}{T}\right)\right)$$

At t = 0: ½(1 + cos(0)) = ½(1 + 1) = 1. At t = T: ½(1 + cos(π)) = ½(1 - 1) = 0. It goes from 1 to 0 over T steps, following a smooth cosine curve.
Step 2: Scale it to our LR range. We want to go from η_max down to η_min:

$$\eta(t) = \eta_{\min} + \frac{1}{2}\left(\eta_{\max} - \eta_{\min}\right)\left(1 + \cos\left(\frac{\pi t}{T}\right)\right)$$

Verify: At t = 0: η_min + (η_max - η_min) × 1 = η_max. At t = T: η_min + (η_max - η_min) × 0 = η_min. Correct.
Let ηmax = 0.001, ηmin = 0, T = 100 steps. We'll compute the LR at key points along the schedule.
| Step t | πt / T | cos(πt/T) | ½(1 + cos) | η(t) |
|---|---|---|---|---|
| 0 | 0 | 1.000 | 1.000 | 0.001000 |
| 10 | π/10 | 0.951 | 0.976 | 0.000976 |
| 25 | π/4 | 0.707 | 0.854 | 0.000854 |
| 33 | ≈ π/3 | 0.509 | 0.755 | 0.000755 |
| 50 | π/2 | 0.000 | 0.500 | 0.000500 |
| 67 | ≈ 2π/3 | -0.509 | 0.245 | 0.000245 |
| 75 | 3π/4 | -0.707 | 0.146 | 0.000146 |
| 90 | 9π/10 | -0.951 | 0.024 | 0.000024 |
| 100 | π | -1.000 | 0.000 | 0.000000 |
Look at the rate of change. From step 0 to step 25 (first quarter), the LR drops from 0.001 to 0.000854, only a 15% decrease. From step 75 to step 100 (last quarter), it drops from 0.000146 to 0, also slow in absolute terms. But from step 25 to step 75 (the middle half), it drops from 0.000854 to 0.000146: about 71% of the total decrease happens in the middle half, an 83% drop relative to the LR at step 25.
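The quarter-by-quarter split can be checked numerically. This sketch reuses the worked example's values (η_max = 0.001, η_min = 0, T = 100) and reports what share of the total decrease falls in each quarter of the schedule.

```python
import math


def cosine_lr(step, total_steps, lr_max, lr_min=0.0):
    """Cosine annealing LR at the given step."""
    return lr_min + 0.5 * (lr_max - lr_min) * (1 + math.cos(math.pi * step / total_steps))


T, lr_max = 100, 0.001
marks = [0, 25, 50, 75, 100]
shares = []
for start, end in zip(marks, marks[1:]):
    drop = cosine_lr(start, T, lr_max) - cosine_lr(end, T, lr_max)
    shares.append(drop / lr_max)   # share of the total decrease (lr_min is 0 here)
    print(f"steps {start:3d}-{end:3d}: {shares[-1]:.1%} of the total decrease")

middle_half = shares[1] + shares[2]
print(f"middle half (steps 25-75): {middle_half:.1%}")
# steps   0- 25: 14.6%, steps 25- 50: 35.4%, steps 50- 75: 35.4%, steps 75-100: 14.6%
# middle half: 70.7%
```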
A linear decay drops the LR at a constant rate: η(t) = ηmax(1 - t/T). It spends equal time at every LR level. But not all LR levels are equally productive.
The middle of training is the most productive phase. The model has already learned the broad strokes (it's past the random-initialization chaos) but hasn't yet settled into a minimum (it's still making meaningful improvements). This is when the optimizer is doing real work — refining features, sharpening decision boundaries, tuning internal representations.
Cosine annealing keeps the LR higher than linear decay throughout the first half of training: at the quarter mark cosine is still at about 85% of the peak, while linear decay is already down to 75%. The two schedules cross at the midpoint, where both sit at half the peak, and in the second half cosine drops below linear, so it spends longer settling at small step sizes.
Another way to see this: imagine a histogram of "time spent at each LR level." Linear decay gives a flat histogram, with equal time at every level. Cosine gives a U-shaped histogram weighted toward the two ends: extra time near η_max early on, when large steps make fast progress, and extra time near η_min late, when small steps let the optimizer settle. It passes through the intermediate values comparatively quickly.
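The histogram picture can be made concrete with a small tally. This sketch splits the LR range into five equal bands (an arbitrary choice for illustration) and counts what fraction of steps each schedule spends in each band; the function names are hypothetical.

```python
import math


def linear_frac(p):
    """LR as a fraction of peak under linear decay; p is training progress in [0, 1)."""
    return 1 - p


def cosine_frac(p):
    """LR as a fraction of peak under cosine annealing to zero."""
    return 0.5 * (1 + math.cos(math.pi * p))


def time_share_by_band(schedule, total_steps=10_000, bands=5):
    """Fraction of steps whose LR falls in each equal-width band of [0, peak]."""
    counts = [0] * bands
    for t in range(total_steps):
        frac = schedule(t / total_steps)
        idx = min(int(frac * bands), bands - 1)   # clamp frac == 1.0 into the top band
        counts[idx] += 1
    return [c / total_steps for c in counts]


linear_hist = time_share_by_band(linear_frac)
cosine_hist = time_share_by_band(cosine_frac)
print("bands (share of peak LR): 0-20%, 20-40%, 40-60%, 60-80%, 80-100%")
print("linear:", [f"{s:.3f}" for s in linear_hist])
print("cosine:", [f"{s:.3f}" for s in cosine_hist])
```

The linear row comes out flat, and the cosine row shows the U shape: the outer bands hold more steps than the middle band.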
The simulation below plots cosine annealing, linear decay, and step decay on the same axes. Watch how the three schedules distribute the LR budget differently. The shaded areas represent total learning budget (integral of LR over time).
Three schedules with the same start and end LR. Adjust η_max to see how each schedule distributes the learning budget. The cosine schedule spends more time near the highest and lowest LRs and moves quickly through the middle.
Loshchilov and Hutter's original paper (SGDR: Stochastic Gradient Descent with Warm Restarts, 2016) proposed an elegant extension: instead of one cosine decay, use repeated cosine cycles. When the LR reaches ηmin, restart from ηmax.
Why would you want to increase the LR after carefully decreasing it? Because the optimizer might have settled into a local minimum that's not the best one. A sudden LR increase — a "warm restart" — gives the optimizer enough energy to escape and explore nearby regions. It might find a better minimum on the next cycle.
The resulting schedule looks like a sawtooth wave with cosine-shaped teeth: each tooth starts at ηmax and decays to ηmin. A common variation doubles the period after each restart (T, 2T, 4T, ...), giving the optimizer progressively longer settling phases.
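A from-scratch sketch of the warm-restart schedule, with the period growing after each restart. The parameter names `first_period` and `period_mult` are illustrative choices, and this is a minimal reading of the idea, not the paper's reference code.

```python
import math


def sgdr_lr(step, lr_max, lr_min=0.0, first_period=50, period_mult=2):
    """Cosine annealing with warm restarts; the period grows by period_mult each cycle."""
    period, start = first_period, 0
    while step >= start + period:   # walk forward to the cycle that contains `step`
        start += period
        period *= period_mult
    progress = (step - start) / period   # 0 at a restart, approaching 1 at the cycle's end
    return lr_min + 0.5 * (lr_max - lr_min) * (1 + math.cos(math.pi * progress))


for s in [0, 25, 49, 50, 100, 149, 150, 349]:
    print(f"step {s:3d}: LR = {sgdr_lr(s, lr_max=1e-3):.6f}")
```

With these settings the LR jumps back to the maximum at steps 50 and 150, and each cycle is twice as long as the one before.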
```python
# Cosine annealing from scratch: just 3 lines
import math

def cosine_lr(step, total_steps, lr_max, lr_min=0):
    """Cosine annealing LR at the given step."""
    return lr_min + 0.5 * (lr_max - lr_min) * (1 + math.cos(math.pi * step / total_steps))

# Trace it
for t in [0, 25, 50, 75, 100]:
    print(f"Step {t:3d}: LR = {cosine_lr(t, 100, 0.001):.6f}")
# Step   0: LR = 0.001000
# Step  25: LR = 0.000854
# Step  50: LR = 0.000500
# Step  75: LR = 0.000146
# Step 100: LR = 0.000000
```

```python
# PyTorch: CosineAnnealingLR
import torch

optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer,
    T_max=100,     # total steps
    eta_min=1e-5   # minimum LR (default 0)
)

for step in range(100):
    loss = model(batch)
    loss.backward()
    optimizer.step()
    scheduler.step()
    optimizer.zero_grad()
```

```python
# PyTorch: CosineAnnealingWarmRestarts (SGDR)
scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
    optimizer,
    T_0=50,       # steps in first cosine cycle
    T_mult=2,     # double the period after each restart
    eta_min=1e-5
)
# Cycle 1: steps 0-49 (period 50)
# Cycle 2: steps 50-149 (period 100)
# Cycle 3: steps 150-349 (period 200)
```
```python
# The complete recipe: warmup + cosine decay
import math

def warmup_cosine_lr(step, total_steps, warmup_steps, lr_max, lr_min=0):
    """Linear warmup then cosine decay: the modern standard."""
    if step < warmup_steps:
        return lr_max * (step / warmup_steps)
    progress = (step - warmup_steps) / (total_steps - warmup_steps)
    return lr_min + 0.5 * (lr_max - lr_min) * (1 + math.cos(math.pi * progress))

# Typical usage: 2000 warmup, 300K total, peak LR 3e-4
for s in [0, 1000, 2000, 50000, 150000, 299999]:
    lr = warmup_cosine_lr(s, 300000, 2000, 3e-4)
    print(f"Step {s:6d}: LR = {lr:.6f}")
# Step      0: LR = 0.000000 (warmup)
# Step   1000: LR = 0.000150 (mid-warmup)
# Step   2000: LR = 0.000300 (peak)
# Step  50000: LR = 0.000281 (slowly decaying)
# Step 150000: LR = 0.000152 (just before the decay midpoint)
# Step 299999: LR = 0.000000 (near zero)
```
What if lowering the learning rate isn't always the right move? What if raising it periodically — kicking the optimizer out of whatever valley it settled into — actually leads to better minima?
That's the radical idea behind Cyclical Learning Rates (CLR), proposed by Leslie Smith in 2017. Instead of a monotonic descent, the LR oscillates between a base_lr (minimum) and a max_lr (maximum) over repeating cycles. Each cycle has two halves: ramp up, then ramp down. The ramp up pushes the optimizer out of the current basin. The ramp down lets it settle into a new one — often a better one.
The key insight is about the loss landscape. Neural networks have many local minima. Some are sharp and narrow (bad — they overfit), some are wide and flat (good — they generalize). A low LR gets stuck in whatever minimum it finds first. A periodically high LR has enough energy to escape narrow minima but not enough to escape wide ones. Over multiple cycles, the optimizer naturally migrates toward wider, flatter minima.
Smith proposed three variants, all built on the same triangle wave:
| Policy | Shape | Behavior Across Cycles |
|---|---|---|
| triangular | Linear ramp up, linear ramp down | Same max_lr every cycle |
| triangular2 | Same triangle shape | Halve max_lr each cycle |
| exp_range | Same triangle shape | Exponential decay of max_lr: max_lr × γ^step |
All three share the same base formula. The only difference is how the peak shrinks (or doesn't) over successive cycles.
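Given a training step t, a step_size (half the cycle length in steps), base_lr, and max_lr:

$$\text{cycle} = \left\lfloor \frac{t}{2 \cdot \text{step\_size}} \right\rfloor$$

$$x = \left| \frac{t}{\text{step\_size}} - 2 \cdot \text{cycle} - 1 \right|$$

$$\eta(t) = \text{base\_lr} + (\text{max\_lr} - \text{base\_lr}) \cdot \max(0,\ 1 - x)$$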
The variable x is a normalized position within the current cycle. When x = 0, we're at the peak (LR = max_lr). When x = 1, we're at the trough (LR = base_lr). The absolute value creates the triangle wave — x decreases from 1 to 0 during the ramp-up half, then increases from 0 to 1 during the ramp-down half.
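Worked example: t = 3500 with step_size = 2000, base_lr = 0.001, and max_lr = 0.01.

Step 1: Which cycle are we in?

$$\text{cycle} = \lfloor 3500 / 4000 \rfloor = 0$$

Still in the first cycle (counting from zero).

Step 2: Where within the cycle?

$$x = \left| 3500 / 2000 - 2 \cdot 0 - 1 \right| = 0.75$$

We're 75% of the way through the ramp-down half (the expression inside the absolute value is positive, which means we're descending).

Step 3: Compute the LR.

$$\eta = 0.001 + (0.01 - 0.001) \times (1 - 0.75) = 0.001 + 0.009 \times 0.25 = 0.00325$$

Makes sense: step 3500 is seven-eighths of the way through a 4000-step cycle, deep into the ramp-down, so the LR is close to (but above) base_lr.

Verification checkpoints: at t = 0, x = 1 and the LR is base_lr = 0.001; at t = 2000, x = 0 and the LR is max_lr = 0.01; at t = 4000, x = 1 again and the LR is back at base_lr.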
Adjust the cycle parameters and switch between the three policies. The LR curve shows the schedule; the loss curve below shows how training responds — note the temporary loss spikes at each LR peak, followed by drops to new lows.
```python
import numpy as np

def cyclical_lr(step, base_lr=0.001, max_lr=0.01,
                step_size=2000, policy='triangular', gamma=0.99994):
    """Compute LR at a given step using cyclical schedule."""
    cycle = np.floor(step / (2 * step_size))
    x = np.abs(step / step_size - 2 * cycle - 1)

    if policy == 'triangular':
        scale = 1.0
    elif policy == 'triangular2':
        scale = 1 / (2 ** cycle)   # Halve max each cycle
    elif policy == 'exp_range':
        scale = gamma ** step      # Exponential decay
    else:
        raise ValueError(f"Unknown policy: {policy}")

    lr = base_lr + (max_lr - base_lr) * max(0, 1 - x) * scale
    return lr

# Verify our hand calculation
lr_3500 = cyclical_lr(3500)
print(f"LR at step 3500: {lr_3500:.5f}")
# Output: 0.00325 ✓

# Full schedule
steps = np.arange(10000)
lrs = [cyclical_lr(s) for s in steps]
print(f"Min LR: {min(lrs):.4f}, Max LR: {max(lrs):.4f}")
# Output: Min LR: 0.0010, Max LR: 0.0100
```
With PyTorch:
```python
import torch

model = torch.nn.Linear(10, 1)
# momentum is set explicitly because cycle_momentum=True cycles it during training
optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

scheduler = torch.optim.lr_scheduler.CyclicLR(
    optimizer,
    base_lr=0.001,
    max_lr=0.01,
    step_size_up=2000,     # Half-cycle (ramp up)
    step_size_down=2000,   # Half-cycle (ramp down)
    mode='triangular',     # or 'triangular2', 'exp_range'
    gamma=0.99994,         # Only used for exp_range
    cycle_momentum=True,   # Inversely cycle momentum too!
)

# Usage: call scheduler.step() after each BATCH, not each epoch
for step in range(10000):
    optimizer.zero_grad()
    loss = model(torch.randn(32, 10)).sum()
    loss.backward()
    optimizer.step()
    scheduler.step()  # Step per batch, not per epoch!
```
Many PyTorch training loops call scheduler.step() at the end of each epoch. CyclicLR is different: it needs to be called after every mini-batch. If you call it per-epoch, you'll only complete a tiny fraction of a cycle over the full training run, defeating the purpose. This is the most common CLR bug.
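The arithmetic behind that warning, as a tiny sketch. The run size (20 epochs of 500 batches) is a made-up example; the only fact it relies on is that a full triangular cycle takes 2 × step_size scheduler calls.

```python
def cycles_completed(scheduler_calls, step_size):
    """A full triangular cycle takes 2 * step_size calls to scheduler.step()."""
    return scheduler_calls / (2 * step_size)


# Hypothetical run: 20 epochs, 500 batches per epoch, step_size = 2000.
epochs, batches_per_epoch, step_size = 20, 500, 2000

per_batch = cycles_completed(epochs * batches_per_epoch, step_size)
per_epoch = cycles_completed(epochs, step_size)

print(f"stepping per batch: {per_batch} cycles")   # 2.5 cycles
print(f"stepping per epoch: {per_epoch} cycles")   # 0.005 cycles
```

Stepped per epoch, the schedule never leaves the very start of its first ramp, so the LR stays pinned near base_lr for the whole run.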
What if you could find the optimal learning rate AND train to completion in the same run? The 1cycle policy does both. It's the schedule behind super-convergence: in Smith and Topin's experiments, some networks reached the same accuracy in roughly 10 times fewer steps.
Leslie Smith (the same researcher behind CLR) teamed with Nicholay Topin in 2019 to propose something elegant: instead of repeating many small cycles, use one single large cycle that spans the entire training run. The LR starts low, climbs to a high peak in the middle of training, descends back to the starting point, and then plummets even further in a brief annihilation phase at the very end.
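1cycle divides training into three phases with very different purposes: a ramp-up from min_lr to max_lr, a mirrored ramp-down back to min_lr, and a short annihilation phase that pushes the LR far below min_lr.
The phase fractions are controlled by pct_start (typically 0.3 to 0.45, the fraction of training spent ramping up). Phase 2 mirrors phase 1 in length. Phase 3 gets whatever remains.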
Before using 1cycle, you need to know the right max_lr. Smith proposed a brilliant diagnostic: the LR range test. Here's how it works:
The resulting curve has a characteristic shape: the loss stays flat (LR too small to matter), then drops steeply (the sweet spot), then explodes upward (LR too large, training diverges). Your max_lr should be slightly below the explosion point — where the loss is still decreasing but the curve starts to bend upward. The min_lr is typically 1/10th to 1/25th of max_lr.
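One way to turn a range-test curve into a number is sketched below. The rule used here (take the LR where a smoothed loss is lowest, then back off by a safety factor) is an assumption for illustration, as are the `window` and `safety` defaults and the synthetic curve; in practice people usually read the value off the plot.

```python
def suggest_max_lr(lrs, losses, window=5, safety=0.5):
    """Heuristic: LR at the lowest smoothed loss, scaled back by `safety`."""
    half = window // 2
    smoothed = []
    for i in range(len(losses)):
        lo, hi = max(0, i - half), min(len(losses), i + half + 1)
        smoothed.append(sum(losses[lo:hi]) / (hi - lo))   # centered moving average
    best = min(range(len(smoothed)), key=smoothed.__getitem__)
    return lrs[best] * safety


# Synthetic range-test curve: flat, then falling, then exploding past lr = 1e-2.
exponents = [-6 + i / 20 for i in range(121)]   # log10(lr) from -6 to 0
lrs = [10 ** u for u in exponents]
losses = [
    2.0 if u < -4 else (2.0 - 0.75 * (u + 4) if u < -2 else 0.5 + 5 * (u + 2) ** 2)
    for u in exponents
]

max_lr = suggest_max_lr(lrs, losses)
print(f"suggested max_lr ~ {max_lr:.4g}, min_lr ~ {max_lr / 10:.4g}")
```

On this synthetic curve the loss bottoms out at an LR of 0.01, so the sketch suggests a max_lr of about 0.005 and, using the 1/10 rule of thumb, a min_lr of about 0.0005.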
The magic of 1cycle is that the high LR in the middle of training acts as implicit regularization. At high learning rates, the optimizer can't settle into sharp, narrow minima — the update steps are too large to fit inside them. It can only stay in wide, flat minima. And wide, flat minima generalize better because nearby points in weight space produce similar loss values.
This is closely related to the noise-driven bias toward flat regions that is often used to explain why small-batch, high-LR training generalizes well. The 1cycle policy exploits this deliberately: spend the middle of training at a high enough LR that only wide minima are stable, then reduce the LR to fine-tune within that wide basin.
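A toy illustration of why large steps "don't fit" inside sharp minima. On a one-dimensional quadratic with curvature a, plain gradient descent is stable only when the learning rate is below 2/a, so the same LR can settle in a flat basin while being thrown out of a sharp one. Real loss surfaces are not quadratics, and the curvature values here are arbitrary, so treat this as intuition only.

```python
def gd_on_quadratic(curvature, lr, steps=50, x0=1.0):
    """Gradient descent on f(x) = 0.5 * curvature * x^2, whose minimum is at x = 0."""
    x = x0
    for _ in range(steps):
        x = x - lr * curvature * x   # gradient of 0.5 * a * x^2 is a * x
    return x


lr = 0.1
sharp = gd_on_quadratic(curvature=50.0, lr=lr)   # stable only if lr < 2/50 = 0.04
flat = gd_on_quadratic(curvature=1.0, lr=lr)     # stable for any lr < 2/1 = 2.0

print(f"sharp minimum (curvature 50): |x| after 50 steps = {abs(sharp):.3e}")
print(f"flat minimum  (curvature 1):  |x| after 50 steps = {abs(flat):.3e}")
```

At lr = 0.1 the iterate escapes the sharp basin (the distance from the minimum quadruples each step) while it converges steadily in the flat one.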
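Worked example with total_steps = 10,000, max_lr = 0.001, div_factor = 10 (so min_lr = 0.0001), final_div_factor = 100 (so final_lr = 0.000001), and pct_start = 0.45, using linear ramps.

Phase boundaries:

$$\text{phase 1: } 0 \le t < 4500 \qquad \text{phase 2: } 4500 \le t < 9000 \qquad \text{phase 3: } 9000 \le t \le 10000$$

At step 2,250 (mid phase 1):

$$\eta = 0.0001 + (0.001 - 0.0001) \times 0.5 = 0.00055$$

At step 4,500 (peak):

$$\eta = \text{max\_lr} = 0.001$$

At step 6,750 (mid phase 2):

$$\eta = 0.001 - (0.001 - 0.0001) \times 0.5 = 0.00055$$

Notice the symmetry: mid phase 1 and mid phase 2 give the same LR.

At step 9,500 (mid phase 3, annihilation):

$$\eta = 0.0001 - (0.0001 - 0.000001) \times 0.5 = 0.0000505$$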
Top: the LR schedule with three phases labeled. Bottom-left: simulated training loss showing super-convergence. Bottom-right: the LR range test — the loss-vs-LR diagnostic that tells you where to set max_lr.
```python
import numpy as np

def one_cycle_lr(step, total_steps, max_lr=0.001,
                 div_factor=10, final_div_factor=100, pct_start=0.45):
    """Compute LR at a given step using 1cycle schedule."""
    min_lr = max_lr / div_factor
    final_lr = min_lr / final_div_factor

    phase1_end = int(total_steps * pct_start)
    phase2_end = int(total_steps * (2 * pct_start))
    # If pct_start=0.45, phase2 ends at 90%, leaving 10% for annihilation

    if step < phase1_end:
        # Phase 1: ramp up
        t = step / phase1_end
        lr = min_lr + (max_lr - min_lr) * t
    elif step < phase2_end:
        # Phase 2: ramp down
        t = (step - phase1_end) / (phase2_end - phase1_end)
        lr = max_lr - (max_lr - min_lr) * t
    else:
        # Phase 3: annihilation
        t = (step - phase2_end) / (total_steps - phase2_end)
        lr = min_lr - (min_lr - final_lr) * t

    return lr

# Verify hand calculation
print(f"Step 2250: {one_cycle_lr(2250, 10000):.5f}")  # 0.00055
print(f"Step 4500: {one_cycle_lr(4500, 10000):.5f}")  # 0.00100
print(f"Step 6750: {one_cycle_lr(6750, 10000):.5f}")  # 0.00055
print(f"Step 9500: {one_cycle_lr(9500, 10000):.5f}")  # 0.00005
```
The LR range test:
```python
def lr_range_test(model, train_loader, optimizer,
                  start_lr=1e-7, end_lr=10, num_steps=200):
    """Run the LR range test. Returns (lrs, losses)."""
    mult = (end_lr / start_lr) ** (1 / num_steps)
    lr = start_lr
    lrs, losses = [], []
    best_loss = float('inf')

    for step, (x, y) in zip(range(num_steps), train_loader):
        # Set LR
        for pg in optimizer.param_groups:
            pg['lr'] = lr

        # train_step is a placeholder: one forward/backward/optimizer step
        # that returns the loss as a float
        loss = model.train_step(x, y)
        lrs.append(lr)
        losses.append(loss)

        # Track the best (lowest) loss seen so far
        if loss < best_loss:
            best_loss = loss
        if loss > 4 * best_loss:
            break  # Loss exploded, stop here

        lr *= mult  # Geometric increase

    return lrs, losses

# After running: plot lrs (log scale) vs losses
# Pick max_lr just before the loss starts rising
```
With PyTorch:
```python
import torch

model = torch.nn.Linear(10, 1)
optimizer = torch.optim.AdamW(model.parameters(), lr=0.0001)

scheduler = torch.optim.lr_scheduler.OneCycleLR(
    optimizer,
    max_lr=0.001,            # Peak LR (from range test)
    total_steps=10000,       # Total training steps
    pct_start=0.3,           # 30% ramp-up; the remaining 70% anneals down
    div_factor=10,           # initial_lr = max_lr / 10
    final_div_factor=100,    # final_lr = initial_lr / 100
    anneal_strategy='cos',   # Cosine (smoother) or 'linear'
)
# Note: PyTorch's OneCycleLR uses cosine annealing within each phase by default,
# and two phases unless three_phase=True is passed.
# Our hand calc used three linear phases for simplicity; both work well.
```
You've seen warmup. You've seen cosine annealing. You've seen cyclical tricks. Now set them aside, because a schedule that has become increasingly popular for LLM pre-training is something much simpler. It's three phases: a short linear warmup, a long plateau at a constant peak LR, and a final decay.
This is Warmup-Stable-Decay (WSD), also called the "trapezoidal" schedule. It's a workhorse of recent large language model training, and understanding why it works says a lot about how LLMs learn.
Each phase serves a distinct purpose tied to the dynamics of LLM training:
Warmup (Phase 1): The optimizer state (momentum, second moments in Adam) starts from zero. During the first few hundred steps, these running averages are unreliable — they're dominated by the very first gradients, which are computed on random weights. The warmup phase keeps the LR low while the optimizer state "fills up" with reliable statistics. Without warmup, the first few large updates can push the model into a region of weight space from which it never recovers.
Stable (Phase 2): This is where the model learns. At a constant, high learning rate, the optimizer has enough energy to explore the loss landscape broadly. For an LLM, this means developing syntax, semantics, world knowledge, reasoning patterns — all the capabilities that emerge during pre-training. The constant LR is not optimal for any one point in training, but it's good enough across the entire phase. And "good enough with simplicity" beats "theoretically optimal with complexity" when you're training for weeks on thousands of GPUs.
Decay (Phase 3): The final annealing phase reduces the LR to squeeze out the last bits of performance. During stable-phase training, the model oscillates around a good solution, because the constant LR is high enough that the loss fluctuates. Decaying the LR lets the optimizer settle into the bottom of the basin. This phase typically improves benchmark scores by a few percent (roughly 1-3%), which doesn't sound like much but matters a great deal at frontier scale.
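Let η_peak be the peak LR, T_total the total training steps, w the warmup fraction, and s the stable fraction. Then d = 1 − w − s is the decay fraction, and the phase boundaries are t_warmup = w · T_total and t_stable = (w + s) · T_total.

During warmup (t < t_warmup):

$$\eta(t) = \eta_{\text{peak}} \cdot \frac{t}{t_{\text{warmup}}}$$

During stable (t_warmup ≤ t < t_stable):

$$\eta(t) = \eta_{\text{peak}}$$

During cosine decay (t ≥ t_stable):

$$\eta(t) = \eta_{\min} + \frac{1}{2}\left(\eta_{\text{peak}} - \eta_{\min}\right)\left(1 + \cos\left(\pi \cdot \frac{t - t_{\text{stable}}}{T_{\text{total}} - t_{\text{stable}}}\right)\right)$$

Where η_min is typically η_peak / 10 (a 10x ratio between peak and minimum LR).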
Here are illustrative WSD-style configurations, sized after well-known training runs. Treat the stable/decay split as a sketch: the published GPT-3, Chinchilla, and LLaMA recipes report linear warmup followed by cosine decay to 10% of the peak LR rather than a long constant phase, and the step counts are approximate.
| Model | Total Steps | Warmup | Stable | Decay | Peak LR | Min LR |
|---|---|---|---|---|---|---|
| Chinchilla | 300K | 2K (0.7%) | 270K (90%) | 28K (9.3%) | 3×10⁻⁴ | 3×10⁻⁵ |
| LLaMA 2 | ~2M | 2K (0.1%) | 1.8M (90%) | ~200K (10%) | 3×10⁻⁴ | 3×10⁻⁵ |
| GPT-3 | ~300K | 375 (0.1%) | 270K (90%) | ~30K (10%) | 6×10⁻⁵ | 6×10⁻⁶ |
| Mistral 7B | ~1M | 2K (0.2%) | ~900K (90%) | ~98K (10%) | 3×10⁻⁴ | 3×10⁻⁵ |
Notice the pattern: warmup is always tiny (under 1%), stable is dominant (around 90%), and the decay phase gets roughly 10%. The 10x ratio between peak and min LR matches the convention of decaying to 10% of the peak that GPT-3, Chinchilla, and LLaMA all report.
Configure a WSD schedule and watch training unfold. The top panel shows the LR curve with color-coded phases. The middle panel shows the simulated loss. The bottom bar shows the learning budget — what fraction of total loss reduction happens in each phase.
```python
import math

def wsd_lr(step, total_steps, peak_lr=3e-4, min_lr=3e-5,
           warmup_frac=0.02, stable_frac=0.88, decay='cosine'):
    """Warmup-Stable-Decay (WSD) learning rate at a given step."""
    warmup_steps = int(total_steps * warmup_frac)
    stable_end = int(total_steps * (warmup_frac + stable_frac))

    if step < warmup_steps:
        # Phase 1: linear warmup from 0 to peak_lr
        return peak_lr * (step / warmup_steps)
    elif step < stable_end:
        # Phase 2: constant at peak_lr
        return peak_lr
    else:
        # Phase 3: decay from peak_lr down to min_lr
        decay_steps = max(1, total_steps - stable_end)
        t = min(1.0, (step - stable_end) / decay_steps)  # 0 → 1
        if decay == 'cosine':
            return min_lr + (peak_lr - min_lr) * 0.5 * (1 + math.cos(math.pi * t))
        elif decay == 'linear':
            return peak_lr - (peak_lr - min_lr) * t
        elif decay == 'sqrt':
            return min_lr + (peak_lr - min_lr) * (1 - math.sqrt(t))
        else:
            raise ValueError(f"Unknown decay: {decay}")

# 300K total steps with the default 2% / 88% / 10% split:
# warmup ends at step 6000, decay starts at step 270000
for s in [0, 1000, 2000, 150000, 272000, 285000, 300000]:
    lr = wsd_lr(s, 300000)
    print(f"Step {s:>7d}: LR = {lr:.6f}")

# Step       0: LR = 0.000000  (start of warmup)
# Step    1000: LR = 0.000050  (warming up)
# Step    2000: LR = 0.000100  (still warming)
# Step  150000: LR = 0.000300  (stable phase)
# Step  272000: LR = 0.000297  (decay has just begun)
# Step  285000: LR = 0.000165  (halfway through decay)
# Step  300000: LR = 0.000030  (final min_lr)
```
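How a LLaMA-style setup might look in a training config (reusing `wsd_lr` from the previous block):
```python
import torch

# LLaMA-style training config (illustrative, not an official config file)
config = {
    "optimizer": {
        "lr": 3e-4,              # peak_lr
        "betas": (0.9, 0.95),
        "weight_decay": 0.1,
    },
    "scheduler": {
        "total_steps": 2_000_000,
        "warmup_frac": 0.001,    # about 2,000 warmup steps
        "stable_frac": 0.899,    # leaves 10% for decay
        "min_lr": 3e-5,          # 10x ratio
        "decay": "cosine",
    },
}

model = torch.nn.Linear(10, 1)   # stand-in for the real model
optimizer = torch.optim.AdamW(model.parameters(), **config["optimizer"])

# Training loop sketch: set the LR by hand before every optimizer step
for step in range(config["scheduler"]["total_steps"]):
    lr = wsd_lr(step, peak_lr=config["optimizer"]["lr"], **config["scheduler"])
    for pg in optimizer.param_groups:
        pg["lr"] = lr
    # ... forward pass, loss.backward(), optimizer.step(), optimizer.zero_grad()
```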
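Setting the LR by hand each step works, but the same schedule can also be expressed as a multiplier on the peak LR, which is the form most scheduler wrappers expect. The following is a minimal sketch under that assumption. The factory name `make_wsd_lambda` and its parameters are hypothetical, and the phase lengths are given in steps rather than fractions.

```python
import math

def make_wsd_lambda(total_steps, warmup_steps, decay_steps, min_ratio=0.1):
    """Return f(step) -> multiplier on the peak LR (hypothetical helper)."""
    stable_end = total_steps - decay_steps

    def lr_lambda(step):
        if step < warmup_steps:
            # Linear warmup: multiplier climbs from 0 to 1
            return step / max(1, warmup_steps)
        if step < stable_end:
            # Stable phase: stay at the peak
            return 1.0
        # Cosine decay from 1 down to min_ratio
        t = min(1.0, (step - stable_end) / max(1, decay_steps))
        return min_ratio + (1.0 - min_ratio) * 0.5 * (1.0 + math.cos(math.pi * t))

    return lr_lambda

wsd = make_wsd_lambda(total_steps=2_000_000, warmup_steps=2_000, decay_steps=200_000)
for step in (0, 1_000, 2_000, 1_000_000, 1_900_000, 2_000_000):
    print(f"step {step:>9d}: multiplier = {wsd(step):.3f}")
```

In PyTorch, a function like this could be passed as `lr_lambda` to `torch.optim.lr_scheduler.LambdaLR`, with the optimizer's `lr` set to the peak value, since `LambdaLR` multiplies the base LR by whatever the function returns.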
Cosine isn't the only smooth decay shape. Polynomial decay lets you control the curvature of the schedule with a single parameter. Want to decay fast early and slow late? Set the power above 1. Want to decay slow early and fast late? Set the power below 1.
And before cosine took over the world, the original Transformer paper used inverse square root decay — a specific shape that held as the standard for three years. Let's derive and compare them all.
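Polynomial decay interpolates from $\eta_{\text{init}}$ to $\eta_{\text{end}}$ over $T$ steps, with a power parameter $p$ that controls the shape:
$$\eta(t) = \left(\eta_{\text{init}} - \eta_{\text{end}}\right)\left(1 - \frac{t}{T}\right)^{p} + \eta_{\text{end}}$$
At $t = 0$: $(1 - 0)^p = 1$, so $\eta(0) = \eta_{\text{init}}$. At $t = T$: $(1 - 1)^p = 0$, so $\eta(T) = \eta_{\text{end}}$. The power $p$ controls the path between these endpoints: $p = 1$ is a straight line, $p > 1$ drops quickly at first and then flattens out, and $p < 1$ stays high for longer and drops steeply at the end.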
The physical intuition: high power = front-loaded decay (spend most of your time at a low LR), low power = back-loaded decay (spend most of your time at a high LR). The choice depends on your task: vision models often benefit from p = 2 (quadratic), while NLP tasks tend to prefer p = 1 (linear) or cosine.
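Each case below is evaluated at the halfway point $t = T/2$ with $\eta_{\text{end}} = 0$.
Linear (p = 1):
$$\eta(T/2) = \eta_{\text{init}} \cdot (0.5)^{1} = 0.5\,\eta_{\text{init}}$$
Exactly half. That's what linear means.
Quadratic (p = 2):
$$\eta(T/2) = \eta_{\text{init}} \cdot (0.5)^{2} = 0.25\,\eta_{\text{init}}$$
Already at 25% of the initial LR halfway through. Quadratic decays faster early.
Square root (p = 0.5):
$$\eta(T/2) = \eta_{\text{init}} \cdot (0.5)^{0.5} \approx 0.707\,\eta_{\text{init}}$$
Still at about 70% of the initial LR at the halfway point. Sqrt preserves high LR longer.
Cubic (p = 3):
$$\eta(T/2) = \eta_{\text{init}} \cdot (0.5)^{3} = 0.125\,\eta_{\text{init}}$$
Only 12.5% left at halfway. Aggressive.
Compare all four at the same moment: sqrt keeps about 2.8x as much LR as quadratic at the midpoint (0.707 vs 0.25). The choice of power fundamentally changes where the optimizer spends its "learning budget."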
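One way to put a number on the "learning budget" is the average LR over the whole schedule, as a fraction of the initial LR. The sketch below assumes an end LR of 0 and uses a hypothetical helper, `mean_lr_fraction`, that integrates (1 − t)^p numerically. For this curve the closed form is 1 / (p + 1), which the numeric result should match.

```python
def mean_lr_fraction(power, n=10_000):
    """Average of (1 - t)^power over t in [0, 1], using the midpoint rule."""
    return sum((1 - (i + 0.5) / n) ** power for i in range(n)) / n

for p in (0.5, 1.0, 2.0, 3.0):
    numeric = mean_lr_fraction(p)
    closed_form = 1 / (p + 1)
    print(f"p = {p}: numeric = {numeric:.4f}, closed form = {closed_form:.4f}")
```

By this measure a sqrt schedule spends about two thirds of the initial LR on average, while a cubic schedule spends only a quarter.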
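Vaswani et al.'s "Attention Is All You Need" (2017) used a distinctive schedule that combines linear warmup with inverse square root decay:
$$\eta(t) = d_{\text{model}}^{-0.5} \cdot \min\left(t^{-0.5},\; t \cdot \text{warmup\_steps}^{-1.5}\right)$$
The min selects between two curves: the linear warmup term ($t \cdot \text{warmup\_steps}^{-1.5}$) is the smaller of the two when $t$ is small, and the inverse sqrt term ($t^{-0.5}$) takes over once $t$ exceeds warmup_steps. The crossover happens exactly at $t$ = warmup_steps, where both terms equal $\text{warmup\_steps}^{-0.5}$.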
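Hand calculation: $d_{\text{model}} = 512$, warmup_steps = 4000.
At the peak (t = 4000):
$$\eta = 512^{-0.5} \times 4000^{-0.5} \approx 0.04419 \times 0.01581 \approx 6.99 \times 10^{-4}$$
At t = 16000 (4x past warmup):
$$\eta = 512^{-0.5} \times 16000^{-0.5} \approx 0.04419 \times 0.007906 \approx 3.49 \times 10^{-4}$$
Half the peak LR. Under inverse sqrt decay, you need 4x the steps to halve the LR (since √4 = 2). This makes it a very gentle decay, much slower than cosine or linear.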
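The hand calculation can be checked with a few lines of code. This is a small sketch using the same d_model = 512 and warmup_steps = 4000; the function name `transformer_lr` is a hypothetical label for the formula, not a library call.

```python
def transformer_lr(step, d_model=512, warmup_steps=4000):
    """Linear warmup followed by inverse square root decay."""
    step = max(step, 1)  # avoid 0 ** -0.5 at the very first step
    return d_model ** -0.5 * min(step ** -0.5, step * warmup_steps ** -1.5)

peak = transformer_lr(4000)
for mult in (1, 4, 16, 64):
    step = 4000 * mult
    lr = transformer_lr(step)
    print(f"t = {step:>6d}: LR = {lr:.3e}  ({lr / peak:.3f} of peak)")
```

Each row multiplies the step count by 4 and the LR drops by half, so reaching one eighth of the peak takes 64x the warmup length.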
| Schedule | Shape | When to Use | Era |
|---|---|---|---|
| Step Decay | Staircase | Simple baselines, classic CNNs | 2012-2016 |
| Inverse Sqrt | 1/√t after warmup | Original Transformer | 2017-2019 |
| Cosine Annealing | Half cosine wave | Vision (ViT, DINO), fine-tuning | 2017-present |
| Cyclical LR | Triangular oscillation | Quick exploration, avoiding local minima | 2017-present |
| 1Cycle | One big triangle + annihilation | Fast training, super-convergence | 2019-present |
| WSD | Trapezoid (warmup+constant+decay) | LLM pre-training (GPT, LLaMA) | 2020-present |
| Polynomial | Tunable curve (power p) | When you need precise control | General |
| Exponential | η × γᵗ | Reinforcement learning | General |
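Two of the shapes in the table, the staircase and the triangular oscillation, have no code elsewhere in this section, so here is a minimal sketch of each. The function names, default milestones, and LR bounds are illustrative assumptions, not values taken from a particular paper or library.

```python
def step_decay_lr(step, init_lr=0.1, gamma=0.1, milestones=(30, 60, 90)):
    """Staircase: multiply by gamma once for every milestone already passed."""
    passed = sum(1 for m in milestones if step >= m)
    return init_lr * gamma ** passed

def triangular_clr(step, base_lr=1e-4, max_lr=1e-3, step_size=2000):
    """Triangular cyclical LR: ramp base -> max over step_size steps, then back."""
    cycle_pos = step % (2 * step_size)   # position inside the current cycle
    frac = cycle_pos / step_size         # runs 0 -> 2 across one full cycle
    tri = 1.0 - abs(frac - 1.0)          # runs 0 -> 1 -> 0
    return base_lr + (max_lr - base_lr) * tri

for epoch in (0, 29, 30, 60, 90):
    print(f"step decay, epoch {epoch:>2d}: {step_decay_lr(epoch):.5f}")
for step in (0, 1000, 2000, 3000, 4000):
    print(f"triangular, step {step:>4d}: {triangular_clr(step):.6f}")
```

Step decay is indexed by epoch here and the cyclical schedule by optimizer step, which matches how the two are usually configured.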
All schedules on the same axes. Drag the power slider for polynomial decay and toggle each curve on/off. See how they compare.
```python
import math

def polynomial_lr(step, total_steps, init_lr=0.001, end_lr=0.0, power=2.0):
    """Polynomial decay: fast early for p>1, slow early for p<1."""
    t = min(step / total_steps, 1.0)
    return (init_lr - end_lr) * (1 - t) ** power + end_lr

def inverse_sqrt_lr(step, d_model=512, warmup_steps=4000):
    """Original Transformer schedule (Vaswani et al. 2017)."""
    if step == 0:
        step = 1
    return d_model ** (-0.5) * min(step ** (-0.5), step * warmup_steps ** (-1.5))

def exponential_lr(step, init_lr=0.001, gamma=0.9999):
    """Exponential decay: LR = init * gamma^step."""
    return init_lr * gamma ** step

# Compare at midpoint (step 50 out of 100)
print("Polynomial comparisons at t=50/100:")
for p in [0.5, 1.0, 2.0, 3.0]:
    lr = polynomial_lr(50, 100, power=p)
    print(f"  p={p:.1f}: {lr:.6f}")

# p=0.5: 0.000707  (sqrt: gentle early)
# p=1.0: 0.000500  (linear)
# p=2.0: 0.000250  (quadratic: aggressive early)
# p=3.0: 0.000125  (cubic: very aggressive early)

# Inverse sqrt peak
peak = inverse_sqrt_lr(4000)
print(f"Transformer peak LR: {peak:.6f}")  # 0.000699
```
With PyTorch:
```python
import torch

model = torch.nn.Linear(10, 1)

def make_optimizer(lr):
    """Fresh optimizer per scheduler, so the schedules don't interfere."""
    return torch.optim.AdamW(model.parameters(), lr=lr)

# Polynomial decay (built-in since PyTorch 1.13)
poly_sched = torch.optim.lr_scheduler.PolynomialLR(
    make_optimizer(0.001),
    total_iters=100,
    power=2.0,  # Quadratic decay
)

# Exponential decay
exp_sched = torch.optim.lr_scheduler.ExponentialLR(
    make_optimizer(0.001),
    gamma=0.9999,  # Multiply LR by 0.9999 each step
)

# Inverse sqrt (use LambdaLR). LambdaLR multiplies the optimizer's base LR
# by the lambda's return value, so the base LR is set to 1.0 here.
d_model = 512
warmup = 4000
inv_sqrt_sched = torch.optim.lr_scheduler.LambdaLR(
    make_optimizer(1.0),
    lr_lambda=lambda step: (
        d_model ** (-0.5)
        * min((step + 1) ** (-0.5), (step + 1) * warmup ** (-1.5))
    ),
)
```
Everything we've learned, in one simulation. Drop multiple schedules onto the same training run and watch them race. Each one runs its actual formula — the real schedule equations, not an approximation.
This is the payoff. You've learned what each schedule does; now you'll see the differences in real time. Which schedule reaches the lowest loss fastest? Which one oscillates? Which one wastes compute at the wrong learning rates?
Race schedules head-to-head on the same training simulation. Each runs its real formula.
Cosine vs WSD: Cosine starts decaying as soon as warmup ends, so it runs at a lower LR than WSD throughout mid-training. WSD holds the peak LR for ~90% of training, which gives it a higher average LR. In the simulation, the WSD loss curve drops faster during the stable phase.
1Cycle's high-LR hump: Watch the 1cycle loss — it actually increases during the ramp-up phase as the LR climbs. This is normal! The optimizer is being pushed out of its current basin. When the LR drops in the second half, the loss falls to a new low. The temporary spike buys a better final result.
Cyclical's oscillation: The CLR loss curve has a sawtooth pattern — periodic spikes at each LR peak followed by drops at each trough. This is the exploration/exploitation cycle in action. Over time, the spikes get smaller as the optimizer settles into progressively wider minima.
Step decay's plateaus: Look for the characteristic flat segments in the step decay loss curve. The loss stops improving between drops (the optimizer oscillates around a minimum at the current LR scale), then drops sharply when the LR is reduced.
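The cosine vs WSD observation above comes down to average LR, which is easy to measure directly. The sketch below assumes a 100K-step run with 2K warmup steps, a 10x peak-to-min ratio, and a 10% decay phase for WSD. The function names `cosine_schedule` and `wsd_schedule` are hypothetical and independent of any other code in this article.

```python
import math

def cosine_schedule(step, total, peak=3e-4, min_lr=3e-5, warmup=2000):
    """Linear warmup, then one half-cosine from peak to min over the rest of the run."""
    if step < warmup:
        return peak * step / warmup
    t = (step - warmup) / (total - warmup)
    return min_lr + (peak - min_lr) * 0.5 * (1 + math.cos(math.pi * t))

def wsd_schedule(step, total, peak=3e-4, min_lr=3e-5, warmup=2000, decay_frac=0.1):
    """Linear warmup, constant at peak, cosine decay over the last decay_frac of steps."""
    decay_start = int(total * (1 - decay_frac))
    if step < warmup:
        return peak * step / warmup
    if step < decay_start:
        return peak
    t = (step - decay_start) / (total - decay_start)
    return min_lr + (peak - min_lr) * 0.5 * (1 + math.cos(math.pi * t))

total = 100_000
avg_cos = sum(cosine_schedule(s, total) for s in range(total)) / total
avg_wsd = sum(wsd_schedule(s, total) for s in range(total)) / total
print(f"average LR, cosine: {avg_cos:.2e}")
print(f"average LR, WSD:    {avg_wsd:.2e}")
print(f"ratio WSD / cosine: {avg_wsd / avg_cos:.2f}")
```

Under these assumptions the cosine average sits near the midpoint of peak and min, while the WSD average sits close to the peak itself.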
You now understand the full schedule toolkit — from raw constant LR to the WSD schedule running inside every frontier LLM. This chapter is your practical reference. No new concepts. Just the decision guide you'll actually use.
| Situation | Schedule | Peak LR | Key Setting |
|---|---|---|---|
| LLM pre-training (GPT, LLaMA, Gemini) | WSD (warmup + stable + cosine decay) | 3e-4 | Warmup ~0.5%, stable ~90%, decay ~10% |
| Transformer fine-tuning (BERT, RoBERTa) | Linear warmup + linear decay | 2e-5 | Warmup 6% of total steps |
| Vision model (ResNet on ImageNet) | Step decay or cosine | 0.1 (SGD) | Drop at 30/60/90 epochs (γ=0.1) |
| Vision Transformer (ViT, DINOv2) | Warmup + cosine decay | 1e-3 | Warmup 10K steps, decay to 1e-5 |
| Fast experiment / competition | 1Cycle | From range test | pct_start=0.3, div_factor=10 |
| Quick hyperparameter search | Cyclical LR (triangular2) | From range test | step_size = 2-10 epochs |
| Reinforcement learning | Linear decay or constant | 3e-4 | Decay to 0 over training |
| Diffusion model | Constant or slow cosine | 1e-4 | Very long training, minimal decay |
| GAN training | Constant | 1e-4 to 2e-4 | No decay (fragile equilibrium) |
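As one worked instance of the table, the fine-tuning row (linear warmup + linear decay, peak 2e-5, 6% warmup) can be sketched in a few lines. The function name and the choice to decay all the way to 0 are assumptions made for illustration.

```python
def linear_warmup_linear_decay(step, total_steps, peak_lr=2e-5, warmup_frac=0.06):
    """Ramp linearly from 0 to peak_lr, then linearly back down to 0."""
    warmup_steps = max(1, round(total_steps * warmup_frac))
    if step < warmup_steps:
        return peak_lr * step / warmup_steps
    remaining = max(0, total_steps - step)
    return peak_lr * remaining / max(1, total_steps - warmup_steps)

total = 10_000  # hypothetical fine-tuning budget; warmup is then 600 steps
for step in (0, 300, 600, 5_300, 10_000):
    lr = linear_warmup_linear_decay(step, total)
    print(f"step {step:>6d}: LR = {lr:.2e}")
```

The same shape can be reused for the other rows by swapping the decay branch, for example replacing the linear ramp-down with a cosine.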
| Parameter | Default | Typical Range | What it Controls |
|---|---|---|---|
| Peak LR | 3e-4 (Adam), 0.1 (SGD) | 1e-5 to 1.0 | Maximum step size — the #1 hyperparameter |
| Min LR | Peak / 10 | Peak/10 to Peak/100 | Floor during/after decay. Too low = wasted steps |
| Warmup steps | 1-2% of total | 100 to 10,000 | How long to ramp from 0 to peak. For Adam stability |
| Total steps | Task-dependent | 10K to 2M+ | Training budget. Schedule is defined relative to this |
| Decay shape | Cosine | Cosine, linear, sqrt | How LR drops in the decay phase |
| γ (step decay) | 0.1 | 0.1 to 0.5 | Multiply LR by γ at each milestone |
| Cycle length (CLR) | 2-10 epochs | 500 to 10,000 steps | Period of LR oscillation |
| pct_start (1cycle) | 0.3 | 0.2 to 0.45 | Fraction of training spent ramping up |
When you have limited time for schedule tuning, this is the order of importance:
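A rule of thumb sometimes quoted alongside scaling-law work such as the Chinchilla paper (Hoffmann et al., 2022) is that peak LR scales as roughly 1 / √(model_size). The direction is reliable: larger models need smaller learning rates. The exact exponent is less so. The values in the table below fall off more slowly than 1/√ and are not strictly monotonic, since batch size and optimizer settings also differ between labs. The schedule shape matters less than getting the peak right, which is why WSD's simplicity (one flat plateau at peak) is so appealing for large-scale training.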
| Model Size | Typical Peak LR | Reasoning |
|---|---|---|
| 125M params | 6e-4 | Small model, can tolerate larger steps |
| 1.3B params | 2e-4 | Moderate — standard Chinchilla range |
| 7B params | 3e-4 | LLaMA's sweet spot |
| 70B params | 1.5e-4 | LLaMA 2 70B — smaller steps for stability |
| 540B params | 1e-4 | PaLM — very conservative for massive scale |
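To see how closely the rows above follow a power law, one option is a least-squares fit in log-log space. The sketch below uses only the five (size, peak LR) pairs from the table. The helper `fit_power_law` is hypothetical, and five points from different labs are far too few for the fitted exponent to be more than a rough indication.

```python
import math

# (parameter count, peak LR) pairs copied from the table above
rows = [(125e6, 6e-4), (1.3e9, 2e-4), (7e9, 3e-4), (70e9, 1.5e-4), (540e9, 1e-4)]

def fit_power_law(points):
    """Least-squares fit of log(lr) = log(c) + k * log(size); returns (c, k)."""
    xs = [math.log(size) for size, _ in points]
    ys = [math.log(lr) for _, lr in points]
    mean_x = sum(xs) / len(xs)
    mean_y = sum(ys) / len(ys)
    k = (sum((x - mean_x) * (y - mean_y) for x, y in zip(xs, ys))
         / sum((x - mean_x) ** 2 for x in xs))
    c = math.exp(mean_y - k * mean_x)
    return c, k

c, k = fit_power_law(rows)
print(f"fitted exponent k = {k:.2f} (a strict 1/sqrt rule would give k = -0.50)")
for size, lr in rows:
    print(f"{size:.1e} params: table = {lr:.1e}, fit = {c * size ** k:.1e}")
```

A fit like this is best read as a starting point for a sweep at a new model size, not as a substitute for one.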