Why gradients vanish and explode — and the five techniques that keep every modern LLM training stable.
Imagine a chain of 50 people. The first person whispers a number, say 1.0, to the next. Each person multiplies the number by their own factor before passing it along. If every factor is 0.9, the number after 50 steps is $0.9^{50} \approx 0.005$. The original message is almost gone.
If every factor is 1.1 instead, the number after 50 steps is $1.1^{50} \approx 117$. The message has become a scream.
This is exactly what happens to gradients in a deep neural network. During backpropagation, the gradient passes through each layer, getting multiplied by that layer's local derivative. After 50 layers, the gradient at the first layer is the product of 50 multiplications. If most of those factors are less than 1, the gradient vanishes. If they're greater than 1, it explodes.
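The two whisper chains are easy to check numerically. The snippet below is a minimal sketch; the helper name `chain_product` is made up for illustration and is not part of any library.

```python
def chain_product(factor, n_layers=50, start=1.0):
    """Multiply `start` by `factor` once per layer, as in the telephone game."""
    value = start
    for _ in range(n_layers):
        value *= factor
    return value

for factor in (0.9, 1.0, 1.1):
    result = chain_product(factor)
    print(f"factor={factor}: value after 50 layers = {result:.4g}")
```

Only the factor of exactly 1.0 leaves the value untouched; anything slightly below or above compounds into decay or blow-up.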
Watch the telephone game in action. Each person in the chain multiplies by a factor. Drag the slider to change that factor and see how the signal decays or explodes.
50-layer chain. Each layer multiplies by the factor below. The bar height shows the gradient magnitude at each layer.
With factor = 0.9, the gradient at layer 1 is less than 1% of the gradient at layer 50. The first layers barely learn. With factor = 1.1, the gradient at layer 1 is 117× the gradient at layer 50. The optimizer step is enormous and training diverges. Only at exactly 1.0 does the gradient survive intact — and real networks never achieve this perfectly.
Chapter 0 showed the symptom. Now let's understand the cause. Backpropagation computes gradients using the chain rule: to find how a change in layer 1's weights affects the loss, you multiply the local derivatives of every layer in between.
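For a network with L layers, the gradient at layer k is:

$$\frac{\partial\,\text{Loss}}{\partial W_k} = \frac{\partial\,\text{Loss}}{\partial a_L} \left( \prod_{i=k}^{L-1} \frac{\partial a_{i+1}}{\partial a_i} \right) \frac{\partial a_k}{\partial W_k}$$

Each factor $\partial a_{i+1} / \partial a_i$ is a layer Jacobian: how much the output $a_{i+1}$ changes per unit change in the input $a_i$. For a simple linear layer y = Wx + b followed by ReLU, the Jacobian has two parts: the weight matrix W and the derivative of ReLU (which is 1 for positive inputs, 0 for negative).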
Let's trace the gradient through a tiny network with 3 layers. Each layer is y = σ(Wx) where σ is the sigmoid function.
Layer 3 (output): Gradient from loss = 1.0 (starting point).
Layer 2: Multiply by σ'(z3) · W3.
Sigmoid's maximum derivative is 0.25 (at z=0). With a weight magnitude of ~0.5, the factor is 0.25 × 0.5 = 0.125.
Layer 1: Multiply again: 0.125 × 0.125 = 0.0156.
After just 3 layers with sigmoid, the gradient is already 64× smaller. After 10 layers: $0.125^{10} \approx 10^{-9}$. The gradient is effectively dead. This is why sigmoid networks couldn't be trained beyond ~5 layers, and why the invention of ReLU (whose derivative is 1 for positive inputs, rather than at most 0.25) was transformative.
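The arithmetic in this trace can be reproduced directly. This is a sketch under the same simplifying assumptions as the text: every layer sits at z = 0 (sigmoid's best case) and every weight has magnitude 0.5. The function names are illustrative.

```python
import math

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def sigmoid_grad(z):
    """Derivative of the sigmoid: s * (1 - s), which peaks at 0.25 for z = 0."""
    s = sigmoid(z)
    return s * (1.0 - s)

def gradient_after(n_factors, z=0.0, weight=0.5):
    """Running product of n_factors identical per-layer factors sigmoid'(z) * weight."""
    return (sigmoid_grad(z) * weight) ** n_factors

print(gradient_after(2))   # two multiplications, as in the 3-layer trace
print(gradient_after(10))  # ten multiplications
```

Moving z away from 0 only makes the sigmoid factor smaller, so real networks decay at least this fast under these weight magnitudes.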
See exactly how the chain rule multiplies derivatives layer by layer. Pick the activation function and watch the gradient shrink or survive.
Each layer shows its local derivative. The gradient at each layer is the running product of all derivatives to the right. Watch how sigmoid kills the gradient while ReLU preserves it.
The telephone game showed that gradients can explode. In practice, a single bad batch — an outlier example, a corrupted sample, or an unlucky combination — can produce a gradient 100× or 1000× normal magnitude. The optimizer applies that enormous gradient as a weight update, the loss spikes, and it can take days of training to recover. In some cases, training never recovers.
Gradient clipping is the safety valve. Before the optimizer step, you check the gradient's magnitude and cap it at a maximum threshold. If the gradient is within bounds, nothing changes. If it's too large, you scale it down to the threshold.
Norm clipping (the standard approach): compute the global norm of all gradients, that is $\sqrt{\sum_i g_i^2}$ across every parameter in the entire model. If the global norm exceeds a threshold (typically 1.0), scale every gradient by (threshold / norm). This preserves the direction of the gradient while capping its magnitude.
Value clipping (less common): clamp each gradient element independently to [-threshold, +threshold]. This changes the direction of the gradient vector, which is usually undesirable. Norm clipping is preferred for this reason.
A small model has 2 parameter tensors with gradients [3.0, 4.0] and [0.0]. The global norm is $\sqrt{3^2 + 4^2 + 0^2} = \sqrt{25} = 5.0$.
With max_norm = 1.0: since 5.0 > 1.0, we scale by 1.0 / 5.0 = 0.2. The clipped gradients become [0.6, 0.8] and [0.0]. Same direction, magnitude capped at 1.0.
With max_norm = 10.0: since 5.0 < 10.0, min(1, 10/5) = min(1, 2) = 1. No clipping occurs. The gradients pass through unchanged.
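The worked example above, plus the contrast with value clipping, can be sketched in plain Python. Gradients are modeled as nested lists rather than tensors, and the helper names are assumptions made for this illustration.

```python
import math

def global_norm(grads):
    """L2 norm over every element of every gradient tensor."""
    return math.sqrt(sum(g * g for tensor in grads for g in tensor))

def clip_by_global_norm(grads, max_norm):
    """Scale all gradients by min(1, max_norm / norm); direction is preserved."""
    norm = global_norm(grads)
    scale = min(1.0, max_norm / norm) if norm > 0 else 1.0
    return [[g * scale for g in tensor] for tensor in grads], norm

def clip_by_value(grads, threshold):
    """Clamp each element independently; this can change the direction."""
    return [[max(-threshold, min(threshold, g)) for g in tensor] for tensor in grads]

grads = [[3.0, 4.0], [0.0]]
clipped, norm = clip_by_global_norm(grads, max_norm=1.0)
print(norm)                         # 5.0
print(clipped)                      # roughly [[0.6, 0.8], [0.0]]
print(clip_by_value(grads, 1.0))    # [[1.0, 1.0], [0.0]]
```

Norm clipping keeps the 3:4 ratio between the two components, while value clipping flattens it to 1:1, which is the direction change the text warns about.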
Visualize norm clipping as a circle. Any gradient vector landing outside the circle gets pulled back to the boundary.
Red dot = unclipped gradient. Green dot = after norm clipping. Yellow dot = after value clipping. The circle is the max_norm boundary. Click/drag to move the gradient.
```python
import torch

model = MyModel()
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)

for batch in dataloader:
    loss = model(batch)
    loss.backward()

    # Norm clipping: the standard approach for LLMs
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

    optimizer.step()
    optimizer.zero_grad()
```
```python
# Manual implementation: roughly what clip_grad_norm_ does
def clip_grad_norm(parameters, max_norm):
    # Materialize the iterable first: model.parameters() is a generator
    # and would be exhausted after the first loop.
    params = [p for p in parameters if p.grad is not None]

    # Global L2 norm across every gradient tensor
    total_norm = sum(p.grad.norm(2).item() ** 2 for p in params) ** 0.5

    # Scale down only when the norm exceeds the threshold
    clip_coef = max_norm / (total_norm + 1e-6)
    if clip_coef < 1.0:
        for p in params:
            p.grad.mul_(clip_coef)
    return total_norm
```
Large language models need large batch sizes to train stably. GPT-3 used an effective batch size of 3.2 million tokens. LLaMA-2 used 4 million tokens. A single H100 GPU with 80 GB of memory can barely fit a batch of 4 sequences of length 2048 — that's about 8,000 tokens. How do you bridge the gap from 8,000 to 4,000,000?
Gradient accumulation. Instead of updating weights after every forward-backward pass, you run N micro-batches, accumulate (sum) the gradients, and then do ONE optimizer step using the total. Mathematically, this is identical to running one giant batch of size N × micro_batch_size — because gradients are linear in the loss, and the loss is typically averaged over the batch.
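True big batch (B=8): $\text{Loss} = \frac{1}{8}\sum_{i=1}^{8} L_i$, so $\text{Gradient} = \frac{1}{8}\sum_{i=1}^{8} \nabla L_i$.
Accumulated (2 micro-batches of 4): the micro-batches give $g_1 = \frac{1}{4}\sum_{i=1}^{4} \nabla L_i$ and $g_2 = \frac{1}{4}\sum_{i=5}^{8} \nabla L_i$. Summing them and dividing by the number of micro-batches gives $\frac{1}{2}(g_1 + g_2) = \frac{1}{8}\sum_{i=1}^{8} \nabla L_i$, the same gradient as the big batch.
Equivalently, if you divide the loss by the accumulation steps before backward, you don't need the final division: loss = loss / accum_steps before loss.backward().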
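The equivalence can be checked on a toy problem. The sketch below assumes a one-parameter model with per-example loss (w·x − y)², whose gradient has a closed form; the data points and helper names are invented for the illustration.

```python
def grad_sample(w, x, y):
    """d/dw of the per-example loss (w*x - y)**2."""
    return 2.0 * x * (w * x - y)

def batch_grad(w, batch):
    """Gradient of the mean loss over a batch."""
    return sum(grad_sample(w, x, y) for x, y in batch) / len(batch)

data = [(1.0, 2.0), (2.0, 3.0), (3.0, 7.0), (4.0, 9.0),
        (5.0, 9.0), (6.0, 13.0), (7.0, 15.0), (8.0, 15.0)]
w = 0.5

# One true batch of 8
big = batch_grad(w, data)

# Two micro-batches of 4, each contribution divided by accum_steps
accum_steps = 2
micro_batches = [data[:4], data[4:]]
accumulated = 0.0
for mb in micro_batches:
    accumulated += batch_grad(w, mb) / accum_steps

# Forgetting the division gives a gradient accum_steps times too large
unnormalized = sum(batch_grad(w, mb) for mb in micro_batches)

print(big, accumulated, unnormalized)
```

The match is exact here because every micro-batch has the same size; with uneven micro-batches a plain average of micro-batch means would weight examples differently.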
Watch gradients accumulate across micro-batches. Each micro-batch contributes a noisy gradient estimate; the accumulation averages out the noise.
Each arrow is one micro-batch's gradient (noisy). The thick arrow is the accumulated mean. More micro-batches = less noise, bigger effective batch.
```python
accum_steps = 8  # Effective batch = micro_batch x 8

for step, batch in enumerate(dataloader):
    loss = model(batch)
    loss = loss / accum_steps  # Normalize BEFORE backward
    loss.backward()            # Gradients ADD to .grad buffers

    if (step + 1) % accum_steps == 0:
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        optimizer.zero_grad()  # Reset .grad to zero
```
If you call loss.backward() N times without dividing by N, the summed gradient is N× too large, so your effective learning rate is N× too high and the model is likely to diverge. Always either (a) divide the loss by accum_steps before backward, or (b) divide the gradients by accum_steps before the optimizer step.
A 7B-parameter model in FP32 takes 28 GB just for the weights. In FP16, it takes 14 GB. Same model, half the memory, roughly 2× the throughput on modern GPUs. The catch: FP16 can only represent magnitudes up to 65,504, values below ~0.00006 lose precision as subnormals, and anything below ~0.00000006 rounds to zero. Gradients are often that small.
The solution is mixed precision training: use half-precision (FP16 or BF16) for the fast parts — forward pass, most of backward pass — but keep a master copy of the weights in full FP32 for the optimizer update. You get the speed of half precision and the accuracy of full precision.
But to understand why this works, we need to understand the three floating point formats and what each one sacrifices.
Every floating point number is stored as three fields: a sign bit (positive or negative), exponent bits (the scale — powers of 2), and mantissa bits (the precision — significant digits). More exponent bits means wider range. More mantissa bits means finer precision.
| Format | Total bits | Sign | Exponent | Mantissa | Range | Precision | Memory |
|---|---|---|---|---|---|---|---|
| FP32 | 32 | 1 | 8 | 23 | ±$3.4 \times 10^{38}$ | ~7 digits | 4 bytes |
| FP16 | 16 | 1 | 5 | 10 | ±65,504 | ~3.3 digits | 2 bytes |
| BF16 | 16 | 1 | 8 | 7 | ±$3.4 \times 10^{38}$ | ~2.4 digits | 2 bytes |
Notice the key difference between FP16 and BF16: they're both 16 bits, but they split those bits differently. FP16 gives 5 exponent bits (limited range) and 10 mantissa bits (decent precision). BF16 gives 8 exponent bits (same range as FP32!) and only 7 mantissa bits (less precision). BF16 = FP32's range in FP16's size.
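The range and precision columns follow directly from the bit counts. The sketch below derives them for an IEEE-style format with the usual exponent bias; `format_limits` is a name invented for this example.

```python
import math

def format_limits(exp_bits, man_bits):
    """Return (max finite, min normal, min subnormal, decimal digits)."""
    bias = 2 ** (exp_bits - 1) - 1
    max_finite = (2 - 2.0 ** -man_bits) * 2.0 ** bias
    min_normal = 2.0 ** (1 - bias)
    min_subnormal = 2.0 ** (1 - bias - man_bits)
    # man_bits stored bits plus one implicit leading bit
    digits = (man_bits + 1) * math.log10(2)
    return max_finite, min_normal, min_subnormal, digits

FORMATS = {"FP32": (8, 23), "FP16": (5, 10), "BF16": (8, 7)}

for name, (e, m) in FORMATS.items():
    mx, mn, sub, digits = format_limits(e, m)
    print(f"{name}: max={mx:.4g}  min normal={mn:.4g}  "
          f"min subnormal={sub:.4g}  ~{digits:.1f} digits")
```

The exponent bits alone set the range, which is why BF16 and FP32 share the same minimum normal value while FP16 runs out near 6 × 10⁻⁵.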
Let's trace a real gradient through all three formats. Suppose a gradient has value 0.00003 in FP32. This is typical for deep layers in a large model.
In FP32: The smallest positive normal number is ~$1.2 \times 10^{-38}$. Our value 0.00003 = $3 \times 10^{-5}$ is enormous by comparison. Stored to about 7 significant digits: 0.00003000000. No problem.
In FP16: With only 5 exponent bits, the smallest positive normal number is $2^{-14}$ ≈ 0.000061. Our gradient is 0.00003, which is below the smallest normal number. It enters the subnormal range, where values sit on a fixed grid of spacing $2^{-24}$ ≈ $5.96 \times 10^{-8}$ and relative precision drops as values shrink. Our 0.00003 survives as a subnormal (~0.00002998) with about 9 significant bits instead of 11. Values below roughly half the grid spacing round to exactly zero, so a gradient of 0.00000002 would be dead.
In BF16: With 8 exponent bits (same as FP32), the smallest positive normal number is ~$1.2 \times 10^{-38}$. Our gradient 0.00003 is comfortably representable. It rounds to the nearest BF16 value: approximately 0.00003004 (only ~2.4 digits of precision, but the value is preserved). The weight update happens.
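This trace can be reproduced without a GPU. The sketch uses the standard library's `struct` module, whose `"e"` format is IEEE binary16, and emulates BF16 by rounding the FP32 bit pattern to its top 16 bits. The helper names are assumptions, and NaN/Inf handling is left out.

```python
import struct

def to_fp16(x):
    """Round a Python float to the nearest FP16 (IEEE binary16) value."""
    return struct.unpack("<e", struct.pack("<e", x))[0]

def to_bf16(x):
    """Round to BF16: keep the top 16 bits of the FP32 encoding,
    rounding to nearest-even. NaN/Inf are not handled in this sketch."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    bits += 0x7FFF + ((bits >> 16) & 1)
    bits &= 0xFFFF0000
    return struct.unpack("<f", struct.pack("<I", bits))[0]

for g in (3e-5, 2e-8):
    print(f"g={g:.0e}  FP16 -> {to_fp16(g):.6e}  BF16 -> {to_bf16(g):.6e}")
```

The smaller value shows the asymmetry: FP16 flushes it to zero while BF16 still stores it, coarsely but with the right magnitude.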
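Here's the complete recipe used by modern training frameworks: keep an FP32 master copy of the weights, run the forward and backward passes in BF16, convert the gradients to FP32, apply the optimizer update to the FP32 master weights, then copy the updated weights back to BF16 for the next step.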
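The master weights are the key insight. The optimizer (Adam, AdamW, etc.) needs full precision because weight updates are tiny: a learning rate of $10^{-4}$ times a gradient of $10^{-3}$ gives a delta of $10^{-7}$. In BF16, adding that delta to a weight of magnitude ~1 changes nothing, because the nearest representable neighbors are about 0.004 to 0.008 away, so the delta rounds away. In FP32, it accumulates over thousands of steps.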
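The rounding-away effect is easy to simulate. The sketch below applies the same 10⁻⁷ update a thousand times to a weight of 1.0, once with the weight stored in emulated BF16 and once in FP32. The rounding helpers are illustrative stand-ins built on `struct`, not a real framework API.

```python
import struct

def to_fp32(x):
    """Round a Python float (double) to the nearest FP32 value."""
    return struct.unpack("<f", struct.pack("<f", x))[0]

def to_bf16(x):
    """Round to BF16 via the top 16 bits of the FP32 encoding (nearest-even)."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    bits += 0x7FFF + ((bits >> 16) & 1)
    bits &= 0xFFFF0000
    return struct.unpack("<f", struct.pack("<I", bits))[0]

lr, grad = 1e-4, 1e-3
delta = lr * grad          # about 1e-7

w_bf16 = 1.0               # weight stored in BF16, no master copy
w_fp32 = 1.0               # FP32 master weight
for _ in range(1000):
    w_bf16 = to_bf16(w_bf16 - delta)
    w_fp32 = to_fp32(w_fp32 - delta)

print(w_bf16)   # still exactly 1.0: every update rounded away
print(w_fp32)   # close to 1 - 1e-4: the updates accumulated
```

Even FP32 is near its resolution limit here (its spacing just below 1.0 is about 6 × 10⁻⁸), so the accumulated value is close to, but not exactly, 0.9999.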
Memory breakdown for a 7B model with Adam:
| Component | Pure FP32 | Mixed Precision (BF16) |
|---|---|---|
| Model weights | 28 GB (FP32) | 14 GB (BF16) + 28 GB (FP32 master) |
| Adam states (m, v) | 56 GB (FP32) | 56 GB (FP32) |
| Activations | ~40 GB (FP32) | ~20 GB (BF16) |
| Gradients | 28 GB (FP32) | 14 GB (BF16) |
| Total | ~152 GB | ~132 GB |
Wait — that's only 13% savings on paper? The real win is throughput. BF16 matrix multiplications run 2-4× faster on Tensor Cores (A100, H100). And the activation memory drop from 40 to 20 GB means you can double your batch size.
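The table's rows follow from bytes per parameter. The sketch below recomputes them for 7B parameters using decimal gigabytes; the activation figures (40 GB and 20 GB) are taken from the table as given rather than derived, and the function name is made up.

```python
def training_memory_gb(n_params, mixed_precision, activations_gb):
    """Rough training-state memory in GB (1 GB = 1e9 bytes) with Adam."""
    gb = 1e9
    if mixed_precision:
        parts = {
            "weights (BF16)": 2 * n_params / gb,
            "master weights (FP32)": 4 * n_params / gb,
            "Adam m, v (FP32)": 8 * n_params / gb,
            "gradients (BF16)": 2 * n_params / gb,
        }
    else:
        parts = {
            "weights (FP32)": 4 * n_params / gb,
            "Adam m, v (FP32)": 8 * n_params / gb,
            "gradients (FP32)": 4 * n_params / gb,
        }
    parts["activations"] = activations_gb
    return parts

fp32 = training_memory_gb(7e9, mixed_precision=False, activations_gb=40)
mixed = training_memory_gb(7e9, mixed_precision=True, activations_gb=20)
print(sum(fp32.values()), sum(mixed.values()))
print(f"savings: {1 - sum(mixed.values()) / sum(fp32.values()):.0%}")
```

The optimizer states and master weights are untouched by the switch to BF16, which is why the total shrinks so little.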
Explore the representable values of each format. Pick a number and see what happens.
Pick a value with the slider. Green = representable, yellow = rounded, red = underflow to zero. The density of tick marks shows where each format has precision.
```python
# Modern approach: BF16 with torch.autocast
import torch

model = MyModel().cuda()
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)

for batch in dataloader:
    inputs, targets = batch

    # Forward pass in BF16: faster matmuls on Tensor Cores
    with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
        logits = model(inputs)
        loss = loss_fn(logits, targets)  # e.g. cross_entropy runs in FP32

    # Backward runs outside the autocast block; each op's backward
    # uses the dtype autocast picked for its forward
    loss.backward()

    # Under autocast the parameters stay in FP32, so .grad is FP32 and
    # the model's own weights act as the master copy
    optimizer.step()
    optimizer.zero_grad()
```
```python
# Manual approach: explicit FP32 master weights next to a BF16 model
import torch
import torch.nn.functional as F

model_bf16 = MyModel().cuda().to(torch.bfloat16)
# Master copies must be detached leaf tensors so the optimizer accepts them
master_weights = {
    n: p.detach().clone().float().requires_grad_(True)
    for n, p in model_bf16.named_parameters()
}
optimizer = torch.optim.AdamW(master_weights.values(), lr=3e-4)

for batch in dataloader:
    inputs, targets = batch

    # Forward in BF16 (assumes float inputs; token ids would stay integer)
    logits = model_bf16(inputs.bfloat16())
    loss = F.cross_entropy(logits.float(), targets)  # loss in FP32

    # Backward: gradients flow in BF16
    loss.backward()

    # Copy BF16 grads -> FP32 master params
    for n, p in model_bf16.named_parameters():
        master_weights[n].grad = p.grad.float()

    # Optimizer updates FP32 master weights
    optimizer.step()
    optimizer.zero_grad()
    model_bf16.zero_grad()  # also clear the BF16 grads

    # Copy FP32 master -> BF16 model
    with torch.no_grad():
        for n, p in model_bf16.named_parameters():
            p.copy_(master_weights[n])
```
Mixed precision training in FP16 still has a problem: many gradients live in the range $10^{-5}$ to $10^{-8}$, where FP16 either loses precision (subnormals) or underflows to zero. You might lose 50-80% of your gradient values; the network trains, but slowly and poorly. Loss scaling is the fix.
The idea is beautifully simple. Loss scaling multiplies the loss by a large constant S (say 1024) before backprop. Because backprop is linear in the loss, this scales every single gradient by the same factor S. All those tiny $10^{-7}$ gradients become $10^{-4}$ gradients, safely inside FP16's representable range. After backprop, you divide every gradient by S to get the true values back.
Same math. No underflow. That's the whole trick.
True gradient: g = 0.00002. In FP16, the smallest normal number is ~0.00006, so this gradient falls into the subnormal range and loses some of its precision. A value like 0.00000002 would underflow to exactly zero.
With loss scaling (S = 1024): the scaled gradient is 0.00002 × 1024 = 0.02048, a normal FP16 number stored at full precision. After backprop, dividing by 1024 in FP32 recovers ~0.00002.
The gradient survived. Without scaling, it would have been truncated or zeroed. Over thousands of training steps, those lost gradients compound into a slower, worse model.
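The round trip can be emulated with the standard library's binary16 support. This is a simplified model: it rounds a single gradient value to FP16 with and without scaling, whereas real training scales the loss and lets backprop produce the scaled gradients. The helper names are invented.

```python
import struct

def to_fp16(x):
    """Round a Python float to the nearest FP16 (IEEE binary16) value."""
    return struct.unpack("<e", struct.pack("<e", x))[0]

def roundtrip(g, scale):
    """Return (gradient stored directly in FP16,
               gradient stored scaled in FP16, then unscaled in higher precision)."""
    direct = to_fp16(g)
    recovered = to_fp16(g * scale) / scale
    return direct, recovered

S = 1024.0
for g in (2e-5, 2e-8):
    direct, recovered = roundtrip(g, S)
    print(f"g={g:.0e}  direct={direct:.4e}  with scaling={recovered:.4e}")
```

For the 2 × 10⁻⁸ gradient the unscaled path returns exactly zero, while the scaled path recovers the value to within a fraction of a percent.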
What scale factor should you use? Too small, and gradients still underflow. Too large, and gradients overflow (exceed FP16's max of 65,504), producing NaN or Inf. The answer: dynamic loss scaling.
Start with a large scale factor, like S = $2^{16}$ = 65,536. Every training step, check if any gradient overflowed (is NaN or Inf). If yes: skip that optimizer step and halve S. If no overflow for N consecutive steps (typically N = 2000): double S.
This automatically finds the sweet spot: the largest scale factor that doesn't cause overflow. PyTorch's GradScaler does exactly this.
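The halve-on-overflow, double-when-stable rule is a small state machine. The class below is a framework-free sketch of that rule only; the class and method names are hypothetical and it is not PyTorch's `GradScaler`.

```python
class DynamicLossScaler:
    """Minimal sketch of the dynamic loss scaling rule described above."""

    def __init__(self, init_scale=2.0 ** 16, growth_interval=2000,
                 growth_factor=2.0, backoff_factor=0.5):
        self.scale = init_scale
        self.growth_interval = growth_interval
        self.growth_factor = growth_factor
        self.backoff_factor = backoff_factor
        self.clean_steps = 0

    def update(self, found_overflow):
        """Adjust the scale; return True if the optimizer step should run."""
        if found_overflow:
            self.scale *= self.backoff_factor   # halve on overflow
            self.clean_steps = 0
            return False                        # skip this step
        self.clean_steps += 1
        if self.clean_steps >= self.growth_interval:
            self.scale *= self.growth_factor    # double after a stable run
            self.clean_steps = 0
        return True

# Short interval so the growth is visible in a few steps
scaler = DynamicLossScaler(growth_interval=3)
for overflow in (True, False, False, False):
    applied = scaler.update(overflow)
    print(f"overflow={overflow}  step applied={applied}  scale={scaler.scale}")
```

In a training loop, `found_overflow` would come from checking the scaled gradients for Inf or NaN before unscaling.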
See how loss scaling shifts the gradient histogram from the underflow danger zone into the representable zone.
Gradients from a real-ish distribution. Red region = underflows to zero in FP16. Green = safely representable. Drag the scale factor to shift the histogram.
```python
# FP16 with GradScaler: the classic approach (pre-BF16 hardware)
import torch
from torch.amp import GradScaler  # older releases: torch.cuda.amp.GradScaler

model = MyModel().cuda()
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)
scaler = GradScaler()  # Starts with scale = 65536

for batch in dataloader:
    inputs, targets = batch
    optimizer.zero_grad()

    # Forward in FP16
    with torch.autocast(device_type="cuda", dtype=torch.float16):
        logits = model(inputs)
        loss = loss_fn(logits, targets)

    # Scale loss -> backward -> scaled gradients
    scaler.scale(loss).backward()

    # Unscale grads -> clip -> step (skips if overflow detected)
    scaler.unscale_(optimizer)
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
    scaler.step(optimizer)  # Skipped if grads contain Inf/NaN
    scaler.update()         # Adjust scale factor
```
```python
# Manual loss scaling: the logic GradScaler implements
scale = 65536.0
growth_interval = 2000
steps_since_overflow = 0

for batch in dataloader:
    inputs, targets = batch

    with torch.autocast(device_type="cuda", dtype=torch.float16):
        logits = model(inputs)
        loss = loss_fn(logits, targets)

    # Scale and backprop
    scaled_loss = loss * scale
    scaled_loss.backward()

    # Check for overflow
    has_overflow = any(
        torch.isinf(p.grad).any() or torch.isnan(p.grad).any()
        for p in model.parameters() if p.grad is not None
    )

    if has_overflow:
        scale /= 2  # Halve on overflow
        steps_since_overflow = 0
        optimizer.zero_grad()
        continue  # Skip this step entirely

    # Unscale gradients
    for p in model.parameters():
        if p.grad is not None:
            p.grad /= scale

    optimizer.step()
    optimizer.zero_grad()

    steps_since_overflow += 1
    if steps_since_overflow >= growth_interval:
        scale *= 2  # Double if stable
        steps_since_overflow = 0
```
torch.autocast with dtype=torch.bfloat16 doesn't need a GradScaler. If your hardware supports BF16 (A100, H100, TPUs, Apple M-series), skip loss scaling entirely. It's an FP16 bandaid, not a universal technique.
During backprop, you need the activations from the forward pass. Each layer's backward step uses the output of the previous layer to compute gradients. Normally, you store all of these in memory during the forward pass and consume them during the backward pass.
For a 96-layer transformer with a batch of sequences, the stored activations can easily exceed 60 GB, and they grow linearly with batch size and sequence length. For very long sequences, activations dominate everything.
Gradient checkpointing (also called activation checkpointing or rematerialization) solves this with a simple tradeoff: store activations at only every Kth layer, and recompute the rest during backprop. You trade a small amount of extra compute for a massive reduction in memory.
In standard backprop through N layers, the forward pass stores N activations, and the backward pass consumes them in reverse. Memory = O(N).
With checkpointing every K layers, you divide the network into N/K segments of K layers each. During the forward pass, you only save the activation at the boundary of each segment — that's N/K saved activations. During backprop, when you need a layer's activation inside a segment, you re-run the forward pass from the nearest saved checkpoint.
The saved checkpoints cost N/K memory. The recomputation within each segment needs at most K activations at a time (just the current segment). Total memory: N/K + K. The extra compute cost: one additional forward pass per segment = N/K forward passes of K layers each = N extra layer-forward-passes = roughly one full extra forward pass through the entire network.
Standard backprop (K=1, save everything):
Checkpoint every K=8 layers:
Checkpoint every K=16 layers:
Extreme: K=N (save nothing, checkpoint only input):
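The N/K + K estimate can be tabulated for a few intervals. The sketch assumes N = 64 layers purely for illustration and counts memory in units of one layer's activation; it models only the segment scheme described above, and the function name is invented.

```python
import math

def checkpoint_memory(n_layers, k):
    """Peak stored activations under the N/K + K model:
    one saved checkpoint per segment plus the K live activations
    of the segment currently being recomputed."""
    saved = math.ceil(n_layers / k)
    live = k
    return saved + live

N = 64  # assumed depth for this example
print(f"standard backprop: {N} activations stored")
for k in (8, 16):
    units = checkpoint_memory(N, k)
    print(f"K={k:2d}: {units} units ({units / N:.0%} of standard)")

# The interval that minimizes N/K + K sits at sqrt(N)
best_k = min(range(1, N + 1), key=lambda k: checkpoint_memory(N, k))
print(f"best K = {best_k}, sqrt(N) = {math.isqrt(N)}")
```

Past K = √N the live-segment term starts to dominate, so larger intervals cost more memory again rather than less.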
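```python
import torch
from torch.utils.checkpoint import checkpoint

class TransformerWithCheckpointing(torch.nn.Module):
    def __init__(self, n_layers=64, d_model=4096):
        super().__init__()
        # TransformerBlock is assumed to be defined elsewhere
        self.layers = torch.nn.ModuleList([
            TransformerBlock(d_model) for _ in range(n_layers)
        ])
        self.ckpt_every = int(n_layers ** 0.5)  # K = sqrt(N)

    def _run_segment(self, start, end, x):
        # Plain forward through layers [start, end)
        for layer in self.layers[start:end]:
            x = layer(x)
        return x

    def forward(self, x):
        n = len(self.layers)
        for start in range(0, n, self.ckpt_every):
            end = min(start + self.ckpt_every, n)
            if self.training:
                # Save only the segment input; activations inside the
                # segment are recomputed during the backward pass
                x = checkpoint(self._run_segment, start, end, x,
                               use_reentrant=False)
            else:
                x = self._run_segment(start, end, x)
        return x
```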
```python
# Even simpler: checkpoint_sequential for a block of layers
import torch
from torch.utils.checkpoint import checkpoint_sequential

class SimpleCheckpointModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # TransformerBlock is assumed to be defined elsewhere
        self.layers = torch.nn.Sequential(*[
            TransformerBlock(4096) for _ in range(64)
        ])

    def forward(self, x):
        # Split into 8 segments, checkpoint each
        return checkpoint_sequential(self.layers, segments=8, input=x,
                                     use_reentrant=False)
```
Watch checkpointing in action. Configure the model, set the checkpoint interval, and step through forward and backward passes to see exactly which activations are saved, which are recomputed, and how memory compares.
Top: Layer diagram. Green = saved activation, gray = discarded (will recompute).
During backward: yellow = recomputing, blue = currently computing gradient.
Bottom: Memory and compute comparison vs standard backprop.
You've learned each technique in isolation. Now let's see how they all fit together in a single training step of a real LLM. Every one of these techniques is used simultaneously — they're not alternatives, they're a stack.
Here is the complete pipeline for one optimizer step in a modern LLM training run (in the style of openly documented models such as LLaMA or Gemma). Every box is a technique from this lesson.
Remove any one of these techniques and something breaks at scale:
| Without this | What happens at 7B+ parameters |
|---|---|
| Mixed precision | 28 GB weights + 28 GB gradients + 56 GB optimizer + ~40 GB activations = ~152 GB in FP32. Doesn't fit on one H100 (80 GB). Also 2-4× slower. |
| Master weights | Small updates (lr × grad ≈ $10^{-7}$) round to zero in BF16. Model stops learning after initial fast progress. |
| Gradient checkpointing | Activation memory for 96-layer model with long sequences can hit 60+ GB. OOM on anything but multi-GPU setups. |
| Gradient clipping | A single bad batch with anomalous gradients can spike loss and take days to recover. Common in early training. |
| Gradient accumulation | Effective batch size too small. LLMs need batch sizes of 1M-4M tokens. Can't fit that in one forward pass. |
| LR schedule | Constant LR either too high (diverges) or too low (slow). Warmup prevents early instabilities. Decay is essential. |
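The LR schedule row mentions warmup and decay without showing a shape. One common choice is linear warmup followed by cosine decay; the sketch below assumes that shape, and the function name and default values (3e-4 peak, 2000 warmup steps, 100,000 total steps) are illustrative rather than a specific library's scheduler.

```python
import math

def lr_at(step, base_lr=3e-4, warmup=2000, total=100_000, min_lr=0.0):
    """Linear warmup to base_lr, then cosine decay to min_lr."""
    if step < warmup:
        # Ramp from base_lr / warmup up to base_lr
        return base_lr * (step + 1) / warmup
    progress = min(1.0, (step - warmup) / max(1, total - warmup))
    return min_lr + 0.5 * (base_lr - min_lr) * (1.0 + math.cos(math.pi * progress))

for step in (0, 1000, 2000, 51_000, 100_000):
    print(f"step {step:>7}: lr = {lr_at(step):.3e}")
```

The small early steps give Adam's moment estimates time to settle before full-size updates begin, which is the instability the warmup is meant to avoid.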
```python
# Complete modern training loop: all techniques combined.
# LLaMA_7B, CosineWarmupScheduler, split, and dataloader are placeholders.
import torch
from torch.nn.functional import cross_entropy

# -- Setup --
model = LLaMA_7B().cuda().bfloat16()  # BF16 model (checkpointed internally)
# FP32 master copies must be detached leaf tensors for the optimizer
master = {n: p.detach().clone().float().requires_grad_(True)
          for n, p in model.named_parameters()}
optimizer = torch.optim.AdamW(master.values(), lr=3e-4, betas=(0.9, 0.95))
scheduler = CosineWarmupScheduler(optimizer, warmup=2000, total=100000)
accum_steps = 8   # 8 micro-batches per optimizer step
max_norm = 1.0    # gradient clip threshold

# -- Training Loop --
for step, batch in enumerate(dataloader):
    micro_batches = split(batch, accum_steps)

    # -- Accumulate gradients over micro-batches --
    for micro in micro_batches:
        # 1. Forward in BF16 (with checkpointing inside model)
        with torch.autocast("cuda", dtype=torch.bfloat16):
            logits = model(micro.input_ids)
            # 2. Loss over flattened tokens, normalized for accumulation
            loss = cross_entropy(logits.flatten(0, 1).float(),
                                 micro.labels.flatten())
            loss = loss / accum_steps
        # 3. Backward: grads accumulate in .grad buffers
        loss.backward()

    # 4. Copy accumulated BF16 grads -> FP32 master params
    for n, p in model.named_parameters():
        master[n].grad = p.grad.float()

    # 5. Clip gradient norm
    torch.nn.utils.clip_grad_norm_(master.values(), max_norm=max_norm)

    # 6. Optimizer step (FP32)
    optimizer.step()
    optimizer.zero_grad()

    # 7. LR schedule step
    scheduler.step()

    # 8. Copy FP32 master -> BF16 model
    with torch.no_grad():
        for n, p in model.named_parameters():
            p.copy_(master[n])

    # Zero model grads for next accumulation
    model.zero_grad()
```
Click any stage to see its details. Toggle options to see how the pipeline changes.
Each box is a pipeline stage. Color indicates precision: blue = BF16, orange = FP32. Click a stage to highlight it. Toggle options below to modify the pipeline.
| Component | Memory | Precision | Notes |
|---|---|---|---|
| Model weights (BF16) | 14 GB | BF16 | For forward/backward |
| Master weights (FP32) | 28 GB | FP32 | For optimizer |
| Adam m, v states | 56 GB | FP32 | First + second moments |
| Gradients (BF16) | 14 GB | BF16 | Accumulated across micro-batches |
| Activations (ckpt) | ~5 GB | BF16 | With checkpointing (vs ~20 GB without) |
| Total | ~117 GB | — | Need model parallelism (split across GPUs) |
Even with every optimization, a 7B model's training state exceeds 80 GB. That's why real LLM training uses model parallelism (split the model across multiple GPUs) and ZeRO (shard optimizer states). But that's a distributed systems lesson, not a gradient flow lesson.
Let's race them all.
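You've learned five techniques for taming gradients: gradient clipping, gradient accumulation, mixed precision (BF16), loss scaling, and gradient checkpointing. The race below compares six configurations: no intervention, gradient clipping alone, mixed precision (BF16) alone, gradient accumulation alone, gradient checkpointing alone, and the full modern pipeline. Each handles different failure modes, but reading about them is one thing. Watching them train side by side is another.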
This simulation trains six configurations on the same loss landscape. Drag the sliders to create the conditions where each one fails. You'll discover that no single technique is enough — you need the full stack.
Six training configurations race on the same task. Each demonstrates characteristic failure modes. Find where each one breaks.
Experiment 1: Depth = 48 layers. The baseline (no clipping, no mixed precision) explodes almost immediately — the gradient product through 48 layers overflows. With clipping alone, it survives but trains slowly. The full stack trains smoothly because BF16 prevents the intermediate overflow and clipping catches spikes.
Experiment 2: Learning rate = 1.0. Drag the LR slider all the way right. Everything except the full stack diverges. The full stack's gradient clipping caps the step size, and the optimizer's adaptive moments (Adam) further stabilize. Clipping alone helps but isn't sufficient without the momentum buffer.
Experiment 3: Batch size = 1. With a single sample per batch, gradient variance is enormous. The accumulation config (which simulates a larger effective batch) produces a much smoother loss curve. This is the main benefit of gradient accumulation: variance reduction, not just memory savings.
Experiment 4: Compare +BF16 vs +Clip. At moderate depth and LR, both work well. But crank the depth to 48: BF16 prevents intermediate overflow during the forward pass (smaller numbers), while clipping only fixes the backward pass. You need both.
| Config | Fails when... | Symptom | What's missing |
|---|---|---|---|
| Baseline | Depth > 8 or LR > 0.01 | Loss explodes (NaN) | Everything |
| +Clip only | Very deep networks | Trains but slowly (gradients always clipped) | Mixed precision + proper init |
| +BF16 only | Bad batch + no clipping | Single spike derails training | Gradient clipping |
| +Accum only | Deep + high LR | Smoother but still explodes | Clipping + mixed precision |
| +Ckpt only | Same as baseline | Saves memory but doesn't fix gradient dynamics | Clipping + mixed precision |
| Full Stack | Rarely (needs extreme settings) | Stable training across all configs | Nothing — this is the production recipe |
You now understand the complete gradient flow toolkit — from the chain rule to the full modern training pipeline. This chapter is your practical reference. No new concepts. Just the techniques, the decision guide, and the connections to where you go next.
| Technique | What it does | When you need it | Cost |
|---|---|---|---|
| Gradient Clipping | Caps gradient norm to max_norm | Always. Prevents explosion from bad batches. | ~0% compute. One norm computation. |
| Gradient Accumulation | Sums gradients over N micro-batches | When the target batch doesn't fit in GPU memory | N× forward-backward time. No extra memory. |
| Mixed Precision (BF16) | Forward/backward in BF16, optimizer in FP32 | Always. 2× speed, 0.5× activation memory. | Extra memory for FP32 master weights. |
| Loss Scaling | Multiplies loss by S to prevent FP16 underflow | FP16 only. Unnecessary with BF16. | ~0%. One multiply and divide. |
| Gradient Checkpointing | Recomputes activations instead of storing them | When activation memory exceeds GPU budget | ~33% extra compute (at K=√N). |
| Master Weights | FP32 copy of weights for optimizer | Always with mixed precision | 2× weight memory (BF16 + FP32). |
Follow the path that matches your situation:
| Symbol | Meaning | Typical values |
|---|---|---|
| ∇L | Gradient of loss w.r.t. parameters | $10^{-6}$ to $10^{-1}$ |
| max_norm | Gradient clipping threshold | 1.0 (LLMs), 0.5-5.0 (vision) |
| N | Number of accumulation micro-batches | 4-64 |
| K | Checkpoint interval (layers) | √(total layers) |
| S | Loss scale factor (FP16 only) | $2^{10}$ to $2^{16}$ |
| FP32 | 32-bit floating point (8e 23m) | Optimizer, master weights |
| BF16 | 16-bit brain float (8e 7m) | Forward/backward, activations |
| FP16 | 16-bit float (5e 10m) | Legacy. Use BF16 if available. |
```python
import torch
from torch.amp import GradScaler
from torch.utils.checkpoint import checkpoint

# Gradient clipping
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

# Mixed precision (BF16, preferred)
with torch.autocast("cuda", dtype=torch.bfloat16):
    ...

# Mixed precision (FP16 + loss scaling)
scaler = GradScaler()
scaler.scale(loss).backward()
scaler.unscale_(optimizer)
scaler.step(optimizer)
scaler.update()

# Gradient accumulation (step counted from 0)
loss = loss / accum_steps
loss.backward()  # grads accumulate in .grad
if (step + 1) % accum_steps == 0:
    optimizer.step()
    optimizer.zero_grad()

# Gradient checkpointing
x = checkpoint(layer, x, use_reentrant=False)
```
Gradient flow doesn't exist in isolation. Here's where to go next:
| Paper | Year | Contribution |
|---|---|---|
| Hochreiter, "Vanishing Gradient Problem" | 1991 | Identified the fundamental problem of gradient decay in deep networks |
| Pascanu et al., "On the Difficulty of Training RNNs" | 2013 | Gradient clipping for RNNs. Popularized norm clipping. |
| Micikevicius et al., "Mixed Precision Training" | 2018 | FP16 + loss scaling + master weights recipe |
| Chen et al., "Training Deep Nets with Sublinear Memory Cost" | 2016 | Gradient checkpointing (rematerialization) |
| Kalamkar et al., "A Study of BFLOAT16 for DL Training" | 2019 | BF16 as the practical replacement for FP16 + loss scaling |