Training Foundations

Gradient Flow

Why gradients vanish and explode — and the five techniques that keep every modern LLM training stable.

Prerequisites: What backpropagation does + Chain rule (multiply derivatives). That's it.
10 Chapters · 12+ Simulations · 0 Assumed Knowledge

Chapter 0: The Telephone Game

Imagine a chain of 50 people. The first person whispers a number, say 1.0, to the next. Each person multiplies the number by their own factor before passing it along. If every factor is 0.9, the number after 50 steps is 0.9^50 ≈ 0.005. The original message is almost gone.

If every factor is 1.1 instead, the number after 50 steps is 1.1^50 ≈ 117. The message has become a scream.

This is exactly what happens to gradients in a deep neural network. During backpropagation, the gradient passes through each layer, getting multiplied by that layer's local derivative. After 50 layers, the gradient at the first layer is the product of 50 multiplications. If most of those factors are less than 1, the gradient vanishes. If they're greater than 1, it explodes.
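The arithmetic of the telephone game is easy to check directly. The sketch below is a minimal illustration that assumes every layer contributes the same scalar factor (real layers contribute Jacobian matrices); `chain_product` is a name made up for this example.

```python
def chain_product(factor: float, n_layers: int = 50) -> float:
    """Multiply a starting signal of 1.0 by `factor` once per layer."""
    signal = 1.0
    for _ in range(n_layers):
        signal *= factor
    return signal


for factor in (0.9, 0.95, 1.0, 1.1):
    print(f"factor={factor}: signal after 50 layers = {chain_product(factor):.4g}")
```

Even a factor of 0.95, which looks harmless for a single layer, leaves well under a tenth of the signal after 50 layers.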

This is the central problem of deep learning. Every technique in this lesson — gradient clipping, mixed precision, loss scaling, checkpointing, gradient accumulation — exists because gradients are fragile. They vanish, they explode, they underflow to zero in low-precision formats, and they eat memory. Understanding gradient flow is understanding why training works at all.

The Simulation

Watch the telephone game in action. Each person in the chain multiplies by a factor. Drag the slider to change that factor and see how the signal decays or explodes.

The Gradient Telephone Game

50-layer chain. Each layer multiplies by the factor below. The bar height shows the gradient magnitude at each layer.


With factor = 0.9, the gradient at layer 1 is less than 1% of the gradient at layer 50. The first layers barely learn. With factor = 1.1, the gradient at layer 1 is 117× the gradient at layer 50. The optimizer step is enormous and training diverges. Only at exactly 1.0 does the gradient survive intact — and real networks never achieve this perfectly.

In a 50-layer network where each layer multiplies the gradient by 0.95, what is the approximate gradient magnitude at layer 1 relative to layer 50?

Chapter 1: The Chain Rule is Multiplication

Chapter 0 showed the symptom. Now let's understand the cause. Backpropagation computes gradients using the chain rule: to find how a change in layer 1's weights affects the loss, you multiply the local derivatives of every layer in between.

For a network with L layers, the gradient at layer k is:

∂L/∂W_k = (∂L/∂a_L) · (∂a_L/∂a_{L-1}) · … · (∂a_{k+1}/∂a_k) · (∂a_k/∂W_k)

Each factor (∂a_{i+1}/∂a_i) is the Jacobian of layer i: how much the output changes per unit change in the input. For a simple linear layer y = Wx + b followed by ReLU, the Jacobian has two parts: the weight matrix W and the derivative of ReLU (which is 1 for positive inputs, 0 for negative).

Hand Calculation: 3-Layer Network

Let's trace the gradient through a tiny network with 3 layers. Each layer is y = σ(Wx) where σ is the sigmoid function.

Layer 3 (output): Gradient from loss = 1.0 (starting point).
Layer 2: Multiply by σ'(z_3) · W_3. Sigmoid's maximum derivative is 0.25 (at z=0). With a weight magnitude of ~0.5, the factor is 0.25 × 0.5 = 0.125.
Layer 1: Multiply again: 0.125 × 0.125 = 0.0156.

After just 3 layers with sigmoid, the gradient is already 64× smaller. After 10 layers: 0.125^10 ≈ 10^-9. The gradient is effectively dead. This is why deep sigmoid networks were so hard to train beyond a handful of layers, and why the adoption of ReLU (whose derivative is 1 for positive inputs, rather than at most 0.25) was transformative.
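The hand calculation above can be reproduced in a few lines. This is a sketch under the same simplifying assumptions as the text: every pre-activation sits at z = 0 (sigmoid's best case) and every weight has magnitude 0.5. The function names are illustrative only.

```python
import math


def sigmoid(z: float) -> float:
    return 1.0 / (1.0 + math.exp(-z))


def sigmoid_grad(z: float) -> float:
    s = sigmoid(z)
    return s * (1.0 - s)  # largest at z = 0


def gradient_after(n_factors: int, weight: float = 0.5, z: float = 0.0) -> float:
    """Gradient magnitude after multiplying by n_factors layer factors,
    starting from 1.0 at the output."""
    return (sigmoid_grad(z) * weight) ** n_factors


print(sigmoid_grad(0.0))    # the per-layer activation derivative at its peak
print(gradient_after(2))    # two factors: output layer back to layer 1 of 3
print(gradient_after(10))   # ten factors
```

Moving z away from 0 only makes things worse, since `sigmoid_grad` shrinks on both sides of its peak.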

ReLU's derivative is exactly 1 for positive inputs. That means the gradient multiplication factor is just the weight magnitude. If weights are initialized near 1.0, the gradient neither vanishes nor explodes. This is why ReLU replaced sigmoid in deep networks: it solved the vanishing gradient problem for the activation function. The remaining problem is the weight matrices themselves.
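To see what is left once the activation stops shrinking the gradient, the sketch below compares the running product for sigmoid (at its best-case derivative of 0.25) and for an active ReLU unit across 10 layers. It assumes scalar weights of equal magnitude in every layer, which is a simplification; `running_product` is a hypothetical helper.

```python
def relu_grad(z: float) -> float:
    return 1.0 if z > 0 else 0.0


def running_product(local_derivative: float, weight: float, n_layers: int) -> float:
    """Product of (activation derivative * weight) over n_layers layers."""
    grad = 1.0
    for _ in range(n_layers):
        grad *= local_derivative * weight
    return grad


sigmoid_best_case = 0.25       # sigmoid'(0), its largest possible value
relu_active = relu_grad(1.0)   # any positive pre-activation

for weight in (0.8, 1.0, 1.2):
    s = running_product(sigmoid_best_case, weight, 10)
    r = running_product(relu_active, weight, 10)
    print(f"weight={weight}: sigmoid={s:.3g}  relu={r:.3g}")
```

With ReLU the product depends only on the weight scale: below 1 it still decays, above 1 it grows, which is the remaining problem the text refers to.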

The Simulation

See exactly how the chain rule multiplies derivatives layer by layer. Pick the activation function and watch the gradient shrink or survive.

Backpropagation: Layer-by-Layer Gradient Flow

Each layer shows its local derivative. The gradient at each layer is the running product of all derivatives to the right. Watch how sigmoid kills the gradient while ReLU preserves it.

Why did ReLU largely solve the vanishing gradient problem?

Chapter 2: Gradient Clipping

The telephone game showed that gradients can explode. In practice, a single bad batch — an outlier example, a corrupted sample, or an unlucky combination — can produce a gradient 100× or 1000× normal magnitude. The optimizer applies that enormous gradient as a weight update, the loss spikes, and it can take days of training to recover. In some cases, training never recovers.

Gradient clipping is the safety valve. Before the optimizer step, you check the gradient's magnitude and cap it at a maximum threshold. If the gradient is within bounds, nothing changes. If it's too large, you scale it down to the threshold.

Two Types of Clipping

Norm clipping (the standard approach): compute the global norm of all gradients, that is √(∑ g_i^2) across every parameter in the entire model. If the global norm exceeds a threshold (typically 1.0), scale every gradient by (threshold / norm). This preserves the direction of the gradient while capping its magnitude.

g ← g · min(1, max_norm / ‖g‖)

Value clipping (less common): clamp each gradient element independently to [-threshold, +threshold]. This changes the direction of the gradient vector, which is usually undesirable. Norm clipping is preferred for this reason.

Hand Calculation: Norm Clipping

A small model has 2 parameter tensors with gradients [3.0, 4.0] and [0.0]. The global norm is √(3^2 + 4^2 + 0^2) = √25 = 5.0.

With max_norm = 1.0: since 5.0 > 1.0, we scale by 1.0 / 5.0 = 0.2. The clipped gradients become [0.6, 0.8] and [0.0]. Same direction, magnitude capped at 1.0.

With max_norm = 10.0: since 5.0 < 10.0, min(1, 10/5) = min(1, 2) = 1. No clipping occurs. The gradients pass through unchanged.
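The same numbers can be checked with a framework-free sketch. It assumes gradients are given as plain lists of floats, one list per parameter tensor; `clip_by_global_norm` is an illustrative name, not a library function.

```python
import math


def clip_by_global_norm(grads, max_norm):
    """Scale all gradient tensors by one shared factor so the global L2 norm
    does not exceed max_norm. Returns (clipped_grads, original_norm)."""
    global_norm = math.sqrt(sum(g * g for tensor in grads for g in tensor))
    scale = min(1.0, max_norm / global_norm) if global_norm > 0 else 1.0
    clipped = [[g * scale for g in tensor] for tensor in grads]
    return clipped, global_norm


grads = [[3.0, 4.0], [0.0]]
clipped, norm = clip_by_global_norm(grads, max_norm=1.0)
unchanged, _ = clip_by_global_norm(grads, max_norm=10.0)
print(norm, clipped, unchanged)
```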

Clipping preserves direction. When the gradient norm exceeds the threshold, norm clipping scales ALL gradients by the same factor. The optimizer still moves in the same direction — just with a shorter step. This is why norm clipping is preferred over value clipping, which distorts the direction.
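One way to make the direction argument concrete is to measure the cosine similarity between the original gradient and its clipped version. The sketch below uses a made-up 2-D gradient [3.0, 0.5] and a threshold of 1.0 for both methods; all names are illustrative.

```python
import math


def cosine(u, v):
    dot = sum(a * b for a, b in zip(u, v))
    return dot / (math.hypot(*u) * math.hypot(*v))


def norm_clip(g, max_norm):
    scale = min(1.0, max_norm / math.hypot(*g))
    return [x * scale for x in g]


def value_clip(g, threshold):
    return [max(-threshold, min(threshold, x)) for x in g]


g = [3.0, 0.5]
print(cosine(g, norm_clip(g, 1.0)))    # same direction as g
print(cosine(g, value_clip(g, 1.0)))   # rotated: only the large component was cut
```

Value clipping shrinks only the components that exceed the threshold, so the relative weight of the small components grows and the vector turns.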

The Simulation

Visualize norm clipping as a circle. Any gradient vector landing outside the circle gets pulled back to the boundary.

Gradient Clipping: Norm vs Value

Red dot = unclipped gradient. Green dot = after norm clipping. Yellow dot = after value clipping. The circle is the max_norm boundary. Click/drag to move the gradient.


Code: Gradient Clipping in PyTorch

python
import torch

model = MyModel()
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)

for batch in dataloader:
    loss = model(batch)
    loss.backward()

    # Norm clipping — the standard approach for LLMs
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

    optimizer.step()
    optimizer.zero_grad()
python
# Manual implementation: roughly what clip_grad_norm_ does
def clip_grad_norm(parameters, max_norm):
    # Materialize the iterable first: model.parameters() is a generator,
    # and we walk it twice (once to measure, once to scale).
    params = [p for p in parameters if p.grad is not None]

    # Global L2 norm across every gradient tensor
    total_norm = sum(p.grad.norm(2).item() ** 2 for p in params) ** 0.5

    # Scale only when the norm exceeds the threshold
    clip_coef = max_norm / (total_norm + 1e-6)
    if clip_coef < 1.0:
        for p in params:
            p.grad.mul_(clip_coef)
    return total_norm
Why is norm clipping preferred over value clipping?

Chapter 3: Gradient Accumulation

Large language models need large batch sizes to train stably. GPT-3 used an effective batch size of 3.2 million tokens. LLaMA-2 used 4 million tokens. Depending on the model size, a single H100 GPU with 80 GB of memory may fit only around 4 sequences of length 2048 per pass, which is about 8,000 tokens. How do you bridge the gap from 8,000 to 4,000,000?

Gradient accumulation. Instead of updating weights after every forward-backward pass, you run N micro-batches, accumulate (sum) the gradients, and then do ONE optimizer step using the total. Mathematically, this is identical to running one giant batch of size N × micro_batch_size — because gradients are linear in the loss, and the loss is typically averaged over the batch.

How It Works

Micro-batch 1
Forward + backward. Gradients stored in .grad buffers. Don't step.
Micro-batch 2
Forward + backward. New gradients add to existing .grad buffers.
… N times
Keep accumulating. Memory usage = 1 micro-batch, not N.
Optimizer step
Divide accumulated grads by N. Clip. Update weights. Zero grads.
↻ repeat

Hand Calculation: Why It's Mathematically Identical

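True big batch (B=8): Loss = (1/8) ∑_{i=1..8} L_i. Gradient = (1/8) ∑_{i=1..8} ∇L_i.

Accumulated (2 micro-batches of 4): micro-batch 1 gives g_1 = (1/4) ∑_{i=1..4} ∇L_i and micro-batch 2 gives g_2 = (1/4) ∑_{i=5..8} ∇L_i. The .grad buffers hold the sum g_1 + g_2. Dividing by the 2 accumulation steps gives (1/2)(g_1 + g_2) = (1/8) ∑_{i=1..8} ∇L_i, the same gradient as the big batch.

Equivalently, if you divide the loss by the accumulation steps before backward, you don't need the final division: loss = loss / accum_steps before loss.backward().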
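The equivalence can be verified numerically on a toy problem. The sketch below assumes a one-parameter model w·x with squared-error loss and eight made-up (x, y) pairs; the data and the helper `grad_of_mean_loss` are hypothetical. Note that the identity relies on all micro-batches having the same size.

```python
def grad_of_mean_loss(w, examples):
    """Gradient with respect to w of mean((w*x - y)^2) over `examples`."""
    return sum(2 * (w * x - y) * x for x, y in examples) / len(examples)


examples = [(1.0, 2.0), (2.0, 3.5), (3.0, 6.5), (4.0, 8.0),
            (5.0, 9.0), (6.0, 12.5), (7.0, 13.0), (8.0, 16.5)]
w = 0.5

# One big batch of 8
big_batch_grad = grad_of_mean_loss(w, examples)

# Two micro-batches of 4, accumulated
accum_steps = 2
micro_size = len(examples) // accum_steps
accumulated = 0.0
for i in range(accum_steps):
    micro = examples[i * micro_size:(i + 1) * micro_size]
    # Dividing by accum_steps mirrors `loss = loss / accum_steps`
    accumulated += grad_of_mean_loss(w, micro) / accum_steps

print(big_batch_grad, accumulated)
```

Dropping the `/ accum_steps` makes `accumulated` exactly `accum_steps` times too large, which is the bug to watch for in a real training loop.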

Accumulation trades time for memory. With 8 accumulation steps, you use 8× the training time per optimizer step (8 forward-backward passes instead of 1), but you use the same memory as a single micro-batch. This is how a single GPU can simulate batch sizes of thousands.

The Simulation

Watch gradients accumulate across micro-batches. Each micro-batch contributes a noisy gradient estimate; the accumulation averages out the noise.

Gradient Accumulation Visualizer

Each arrow is one micro-batch's gradient (noisy). The thick arrow is the accumulated mean. More micro-batches = less noise, bigger effective batch.


Code: Gradient Accumulation in PyTorch

python
accum_steps = 8  # Effective batch = micro_batch x 8

for step, batch in enumerate(dataloader):
    loss = model(batch)
    loss = loss / accum_steps  # Normalize BEFORE backward
    loss.backward()            # Gradients ADD to .grad buffers

    if (step + 1) % accum_steps == 0:
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        optimizer.zero_grad()  # Reset .grad to zero
Common bug: forgetting to divide the loss. If you call loss.backward() N times without dividing by N, the accumulated gradient is N× too large. With plain SGD that is the same as an N× higher learning rate, and training can diverge; with clipping or Adam the effect is partly masked, which makes the bug harder to spot. Always either (a) divide the loss by accum_steps before backward, or (b) divide the gradients by accum_steps before the optimizer step.
With gradient accumulation over 8 micro-batches of size 4, what is the effective batch size?

Chapter 4: Mixed Precision Training

A 7B-parameter model in FP32 takes 28 GB just for the weights. In FP16, it takes 14 GB. Same model, half the memory, roughly 2× the throughput on modern GPUs. The catch: FP16 can only represent magnitudes up to 65,504, values below ~0.00006 lose precision (they become subnormal), and anything below ~0.00000003 rounds to zero. Gradients are often that small.

The solution is mixed precision training: use half-precision (FP16 or BF16) for the fast parts — forward pass, most of backward pass — but keep a master copy of the weights in full FP32 for the optimizer update. You get the speed of half precision and the accuracy of full precision.

But to understand why this works, we need to understand the three floating point formats and what each one sacrifices.

The Three Formats

Every floating point number is stored as three fields: a sign bit (positive or negative), exponent bits (the scale — powers of 2), and mantissa bits (the precision — significant digits). More exponent bits means wider range. More mantissa bits means finer precision.

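| Format | Total bits | Sign | Exponent | Mantissa | Range | Precision | Memory |
|---|---|---|---|---|---|---|---|
| FP32 | 32 | 1 | 8 | 23 | ±3.4×10^38 | ~7 digits | 4 bytes |
| FP16 | 16 | 1 | 5 | 10 | ±65,504 | ~3.3 digits | 2 bytes |
| BF16 | 16 | 1 | 8 | 7 | ±3.4×10^38 | ~2.4 digits | 2 bytes |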

Notice the key difference between FP16 and BF16: they're both 16 bits, but they split those bits differently. FP16 gives 5 exponent bits (limited range) and 10 mantissa bits (decent precision). BF16 gives 8 exponent bits (same range as FP32!) and only 7 mantissa bits (less precision). BF16 = FP32's range in FP16's size.
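The range and precision figures follow directly from the bit counts. The sketch below derives them using the standard IEEE-style layout (implicit leading 1, exponent bias 2^(e-1) − 1, top exponent code reserved for infinity and NaN); the function and dictionary names are made up for this example.

```python
import math


def float_format_stats(exponent_bits: int, mantissa_bits: int) -> dict:
    bias = 2 ** (exponent_bits - 1) - 1
    return {
        # Largest finite value: all mantissa bits set, largest usable exponent
        "max_value": (2 - 2.0 ** -mantissa_bits) * 2.0 ** bias,
        # Smallest positive value that still has full precision
        "min_normal": 2.0 ** (1 - bias),
        # Smallest positive value at all (precision is down to one bit here)
        "min_subnormal": 2.0 ** (1 - bias - mantissa_bits),
        # +1 accounts for the implicit leading bit
        "decimal_digits": (mantissa_bits + 1) * math.log10(2),
    }


FORMATS = {"FP32": (8, 23), "FP16": (5, 10), "BF16": (8, 7)}
STATS = {name: float_format_stats(e, m) for name, (e, m) in FORMATS.items()}

for name, s in STATS.items():
    print(f"{name}: max={s['max_value']:.4g}  min_normal={s['min_normal']:.4g}  "
          f"digits={s['decimal_digits']:.1f}")
```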

Why range beats precision for training. Gradients span a huge dynamic range: some are 10^-7, others are 10^2. You need range to represent all of them without underflow. But you don't need 7 decimal digits of precision in each gradient; 2-3 digits is plenty when you're making thousands of noisy updates anyway. BF16 makes the right tradeoff.

Hand Calculation: Underflow in Action

Let's trace a real gradient through all three formats. Suppose a gradient has value 0.00003 in FP32. This is typical for deep layers in a large model.

In FP32: The smallest positive normal number is ~1.2 × 10^-38. Our value 0.00003 = 3 × 10^-5 is enormous by comparison. It is stored to about 7 digits of precision. No problem.

In FP16: With only 5 exponent bits, the smallest positive normal number is 2^-14 ≈ 0.00006103. Our gradient is 0.00003, which is below the smallest normal number. It enters the subnormal range, where values are spaced 2^-24 ≈ 5.96 × 10^-8 apart and relative precision degrades as values shrink. Our 0.00003 survives as a subnormal (~0.00002998) with under 3 significant digits, a gradient of 0.0000003 keeps roughly 1 digit, and anything below about 3 × 10^-8 rounds to exactly zero. A gradient of 10^-8 would be dead.

In BF16: With 8 exponent bits (same as FP32), the smallest positive normal number is ~1.2 × 10^-38. Our gradient 0.00003 is comfortably representable. It rounds to the nearest BF16 value: approximately 0.00003004 (only ~2.4 digits of precision, but the value is preserved). The weight update happens.
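These roundings can be reproduced without a GPU. The sketch below uses Python's `struct` module, whose `e` format is IEEE half precision, and emulates BF16 by rounding a float32 bit pattern to its top 16 bits. The BF16 helper is an illustrative emulation rather than a library call, and it does not handle NaN or infinity.

```python
import struct


def to_fp16(x: float) -> float:
    """Round to the nearest FP16 value (values too large raise OverflowError)."""
    return struct.unpack("<e", struct.pack("<e", x))[0]


def to_bf16(x: float) -> float:
    """Round to the nearest BF16 value via the float32 bit pattern."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    bits += 0x7FFF + ((bits >> 16) & 1)   # round to nearest, ties to even
    bits &= 0xFFFF0000                    # keep sign, exponent, top 7 mantissa bits
    return struct.unpack("<f", struct.pack("<I", bits))[0]


for value in (3e-5, 3e-7, 1e-8):
    print(f"{value:g}: fp16={to_fp16(value):.6g}  bf16={to_bf16(value):.6g}")
```

The FP16 column loses digits as the value shrinks and eventually hits zero, while the BF16 column keeps a steady two to three digits throughout.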

BF16 is NOT just "better FP16." It sacrifices precision (only ~2.4 decimal digits vs FP16's ~3.3) for range. This means BF16 math is less precise for medium-sized numbers: if you multiply 1.5 × 3.7, BF16 gives you more rounding error than FP16. The tradeoff: underflow and overflow become rare, but your intermediate computations have more rounding error. For gradients, range matters more than precision, so BF16 wins.

The Mixed Precision Recipe

Here's the complete recipe used by every modern training framework:

Forward Pass
Run in BF16/FP16. Fast, half the memory for activations.
Loss Computation
Cast to FP32 for the loss. Small numbers matter here.
Backward Pass
Run in BF16/FP16. Gradients computed in half precision.
Optimizer Step
Convert gradients to FP32. Apply to FP32 master weights. Copy back to BF16.
↻ repeat

The master weights are the key insight. The optimizer (Adam, AdamW, etc.) needs full precision because weight updates are tiny: a learning rate of 10^-4 times a gradient of 10^-3 gives a delta of 10^-7. In BF16, adjacent values near 1.0 are several thousandths apart, so adding that delta to a typical weight rounds straight back to the old value. In FP32, it accumulates over thousands of steps.
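The lost-update effect can be simulated in plain Python. This sketch emulates FP32 and BF16 storage by rounding after every step; the helpers are illustrative emulations (no NaN or infinity handling), and the weight value 1.0 and the 10,000 steps are assumptions chosen for the demonstration.

```python
import struct


def to_fp32(x: float) -> float:
    return struct.unpack("<f", struct.pack("<f", x))[0]


def to_bf16(x: float) -> float:
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    bits += 0x7FFF + ((bits >> 16) & 1)   # round to nearest, ties to even
    bits &= 0xFFFF0000
    return struct.unpack("<f", struct.pack("<I", bits))[0]


update = 1e-4 * 1e-3      # learning rate * gradient, about 1e-7

w_bf16 = 1.0              # weight stored only in BF16
w_master = 1.0            # FP32 master copy
for _ in range(10_000):
    w_bf16 = to_bf16(w_bf16 - update)       # rounds back to the old value
    w_master = to_fp32(w_master - update)   # small steps accumulate

print(w_bf16, w_master)
```

The FP32 copy drifts by roughly a thousandth over the run, while the BF16-only weight never moves. FP32 rounds each step too, so its total is not exact, but the update is not lost.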

Memory breakdown for a 7B model with Adam:

| Component | Pure FP32 | Mixed Precision (BF16) |
|---|---|---|
| Model weights | 28 GB (FP32) | 14 GB (BF16) + 28 GB (FP32 master) |
| Adam states (m, v) | 56 GB (FP32) | 56 GB (FP32) |
| Activations | ~40 GB (FP32) | ~20 GB (BF16) |
| Gradients | 28 GB (FP32) | 14 GB (BF16) |
| Total | ~152 GB | ~132 GB |

Wait — that's only 13% savings on paper? The real win is throughput. BF16 matrix multiplications run 2-4× faster on Tensor Cores (A100, H100). And the activation memory drop from 40 to 20 GB means you can double your batch size.
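The totals in the table are simple byte counting. The sketch below reproduces them, taking 1 GB as 10^9 bytes and treating the activation figures (~40 GB and ~20 GB) as given inputs, since they depend on batch size and sequence length. `training_memory_gb` is a hypothetical helper.

```python
def training_memory_gb(n_params, weight_bytes, grad_bytes, activations_gb,
                       fp32_master_copy=False):
    """Rough training-state memory in GB for an Adam-trained model."""
    gb = 1e9
    weights = n_params * weight_bytes / gb
    master = n_params * 4 / gb if fp32_master_copy else 0.0
    adam_states = 2 * n_params * 4 / gb      # first and second moments in FP32
    grads = n_params * grad_bytes / gb
    return weights + master + adam_states + grads + activations_gb


fp32_total = training_memory_gb(7e9, weight_bytes=4, grad_bytes=4, activations_gb=40)
mixed_total = training_memory_gb(7e9, weight_bytes=2, grad_bytes=2, activations_gb=20,
                                 fp32_master_copy=True)
print(fp32_total, mixed_total, 1 - mixed_total / fp32_total)
```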

The Simulation

Explore the representable values of each format. Pick a number and see what happens.

Floating Point Number Lines

Pick a value with the slider. Green = representable, yellow = rounded, red = underflow to zero. The density of tick marks shows where each format has precision.


Code: Mixed Precision Training Loop

python
# Modern approach: BF16 with torch.autocast
# (MyModel and dataloader are placeholders)
import torch

model = MyModel().cuda()  # parameters stay in FP32: they are the master weights
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)
loss_fn = torch.nn.CrossEntropyLoss()

for batch in dataloader:
    inputs, targets = batch

    # Forward pass: autocast runs matmuls in BF16 on the fly
    with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
        logits = model(inputs)
        loss = loss_fn(logits, targets)  # autocast runs the loss in FP32

    # Backward: each op's gradient is computed in the dtype autocast
    # chose for that op in the forward pass
    loss.backward()

    # Parameters are FP32, so their .grad tensors and the update are FP32
    optimizer.step()
    optimizer.zero_grad()
python
# Manual approach: BF16 model plus hand-managed FP32 master weights
# (MyModel and dataloader are placeholders)
import torch
import torch.nn.functional as F

model_bf16 = MyModel().cuda().to(torch.bfloat16)
# Detach so each master copy is a leaf tensor the optimizer can update
master_weights = {
    n: p.detach().float().requires_grad_(True)
    for n, p in model_bf16.named_parameters()
}
optimizer = torch.optim.AdamW(master_weights.values(), lr=3e-4)

for batch in dataloader:
    inputs, targets = batch

    # Forward in BF16 (assumes float inputs; integer token IDs need no cast)
    logits = model_bf16(inputs.bfloat16())
    loss = F.cross_entropy(logits.float(), targets)  # loss in FP32

    # Backward: gradients flow in BF16
    loss.backward()

    # Copy BF16 grads -> FP32 master params
    for n, p in model_bf16.named_parameters():
        master_weights[n].grad = p.grad.float()

    # Optimizer updates FP32 master weights
    optimizer.step()
    optimizer.zero_grad()
    model_bf16.zero_grad()  # otherwise BF16 grads accumulate across steps

    # Copy FP32 master -> BF16 model
    with torch.no_grad():
        for n, p in model_bf16.named_parameters():
            p.copy_(master_weights[n].bfloat16())
Why is BF16 preferred over FP16 for training?

Chapter 5: Loss Scaling

Mixed precision training in FP16 still has a problem: many gradients live in the range 10^-5 to 10^-8, where FP16 values are subnormal and lose precision, and anything smaller than about 3 × 10^-8 flushes to zero. A large share of your gradient values can be degraded or lost; the network trains, but slowly and poorly. Loss scaling is the fix.

The idea is beautifully simple. Loss scaling multiplies the loss by a large constant S (say 1024) before backprop. Because backprop is linear in the loss, this scales every single gradient by the same factor S. All those tiny 10^-7 gradients become 10^-4 gradients, safely inside FP16's normal range. After backprop, you divide every gradient by S to get the true values back.

Same math. No underflow. That's the whole trick.

How It Works

Scale
Multiply loss by S: L_scaled = S · L
Backprop
Gradients are S · ∇L (all shifted into FP16 range)
Unscale
Divide: g = g_scaled / S (back to true values, now in FP32)
Optimizer
Apply true gradients to FP32 master weights

Hand Calculation: Rescuing a Gradient

True gradient: g = 0.00002. In FP16, the smallest normal number is ~0.00006. Our gradient enters the subnormal range and loses precision. A value like 0.00000001 would underflow to exactly zero.

With loss scaling (S = 1024): the scaled gradient is 0.00002 × 1024 = 0.02048, comfortably inside FP16's normal range. After backprop, unscaling in FP32 gives 0.02048 / 1024 = 0.00002.

The gradient survived. Without scaling, it would have been degraded or zeroed. Over thousands of training steps, those lost gradients compound into a slower, worse model.
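The rescue can be checked with an FP16 round trip in plain Python. The sketch below uses `struct`'s half-precision `e` format to emulate storing the scaled gradient in FP16 and then unscaling in higher precision. `through_fp16` is an illustrative helper, and the test values are assumptions chosen to show both precision loss and outright underflow.

```python
import struct


def to_fp16(x: float) -> float:
    """Round to the nearest FP16 value (values too large raise OverflowError)."""
    return struct.unpack("<e", struct.pack("<e", x))[0]


def through_fp16(true_grad: float, scale: float = 1.0) -> float:
    """Store scale * true_grad in FP16, then unscale in higher precision."""
    return to_fp16(true_grad * scale) / scale


def rel_error(true_grad: float, scale: float = 1.0) -> float:
    return abs(through_fp16(true_grad, scale) - true_grad) / true_grad


print(rel_error(2e-5), rel_error(2e-5, scale=1024))         # precision improves
print(through_fp16(1e-8), through_fp16(1e-8, scale=65536))  # zero vs recovered
```

In this emulation a scale that pushes a value past 65,504 raises `OverflowError`; on real hardware the same situation produces Inf, which is the signal dynamic loss scaling reacts to.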

Dynamic Loss Scaling

What scale factor should you use? Too small, and gradients still underflow. Too large, and gradients overflow (exceed FP16's max of 65,504), producing NaN or Inf. The answer: dynamic loss scaling.

Start with a large scale factor, like S = 2^16 = 65,536. Every training step, check if any gradient overflowed (is NaN or Inf). If yes: skip that optimizer step and halve S. If no overflow for N consecutive steps (typically N = 2000): double S.

This automatically finds the sweet spot — the largest scale factor that doesn't cause overflow. PyTorch's GradScaler does exactly this.

The dynamic scaling dance. In early training, gradients are large and S quickly drops from 65,536 to maybe 1,024. As training stabilizes and gradients shrink, S creeps back up. The scaler hunts for the maximum usable scale, always pushing gradients as high into FP16's range as possible without overflow.

The Simulation

See how loss scaling shifts the gradient histogram from the underflow danger zone into the representable zone.

Gradient Histogram: With & Without Scaling

Gradients from a real-ish distribution. Red region = underflows to zero in FP16. Green = safely representable. Drag the scale factor to shift the histogram.


Code: Loss Scaling in PyTorch

python
# FP16 with GradScaler — the classic approach (pre-BF16 hardware)
import torch
from torch.amp import GradScaler

model = MyModel().cuda()
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)
scaler = GradScaler()  # Starts with scale = 65536

for batch in dataloader:
    inputs, targets = batch
    optimizer.zero_grad()

    # Forward in FP16
    with torch.autocast(device_type="cuda", dtype=torch.float16):
        logits = model(inputs)
        loss = loss_fn(logits, targets)  # loss_fn: e.g. torch.nn.CrossEntropyLoss()

    # Scale loss -> backward -> every gradient is multiplied by the scale
    scaler.scale(loss).backward()

    # Unscale grads -> clip -> step (skips if overflow detected)
    scaler.unscale_(optimizer)
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
    scaler.step(optimizer)   # No-op if grads overflowed
    scaler.update()          # Adjust scale factor
python
# Manual loss scaling — what GradScaler does internally
scale = 65536.0
growth_interval = 2000
steps_since_overflow = 0

for batch in dataloader:
    inputs, targets = batch
    with torch.autocast(device_type="cuda", dtype=torch.float16):
        logits = model(inputs)
        loss = loss_fn(logits, targets)

    # Scale and backprop
    scaled_loss = loss * scale
    scaled_loss.backward()

    # Check for overflow
    has_overflow = any(
        torch.isinf(p.grad).any() or torch.isnan(p.grad).any()
        for p in model.parameters() if p.grad is not None
    )

    if has_overflow:
        scale /= 2       # Halve on overflow
        steps_since_overflow = 0
        optimizer.zero_grad()
        continue        # Skip this step entirely

    # Unscale gradients
    for p in model.parameters():
        if p.grad is not None:
            p.grad /= scale

    optimizer.step()
    optimizer.zero_grad()

    steps_since_overflow += 1
    if steps_since_overflow >= growth_interval:
        scale *= 2       # Double if stable
        steps_since_overflow = 0
Loss scaling is ONLY needed for FP16. BF16 has enough exponent range that gradients don't underflow, so loss scaling is unnecessary. PyTorch's torch.autocast with dtype=torch.bfloat16 doesn't use a GradScaler. If your hardware supports BF16 (A100, H100, TPUs, Apple M-series), skip loss scaling entirely. It's an FP16 bandaid, not a universal technique.
What does loss scaling prevent during FP16 training?

Chapter 6: Gradient Checkpointing

During backprop, you need the activations from the forward pass. Each layer's backward step uses the output of the previous layer to compute gradients. Normally, you store all of these in memory during the forward pass and consume them during the backward pass.

For a 96-layer transformer with a batch of sequences, the stored activations can easily exceed 60 GB. For very long sequences, activations dominate everything.

Gradient checkpointing (also called activation checkpointing or rematerialization) solves this with a simple tradeoff: store activations at only every Kth layer, and recompute the rest during backprop. You trade a small amount of extra compute for a massive reduction in memory.

The Core Idea

In standard backprop through N layers, the forward pass stores N activations, and the backward pass consumes them in reverse. Memory = O(N).

With checkpointing every K layers, you divide the network into N/K segments of K layers each. During the forward pass, you only save the activation at the boundary of each segment — that's N/K saved activations. During backprop, when you need a layer's activation inside a segment, you re-run the forward pass from the nearest saved checkpoint.

The saved checkpoints cost N/K memory. The recomputation within each segment needs at most K activations at a time (just the current segment). Total memory: N/K + K. The extra compute cost: one additional forward pass per segment = N/K forward passes of K layers each = N extra layer-forward-passes = roughly one full extra forward pass through the entire network.

The optimal K = √N. Memory is N/K + K. To minimize: take the derivative with respect to K and set it to zero: -N/K^2 + 1 = 0, so K = √N. At this setting, memory = 2√N instead of N, and the extra compute is √N segment recomputations of √N layers each, about one extra forward pass. For a 64-layer model: standard backprop stores 64 activations. Checkpoint every 8 (= √64): store 8 checkpoints + recompute at most 8 at a time = 16. That's a 4× memory reduction for roughly a third more compute (one extra forward pass, assuming the backward pass costs about twice the forward pass).

Hand Calculation: 64-Layer Model

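Standard backprop (K=1, save everything): 64 stored activations, no recomputation.

Checkpoint every K=8 layers: 64/8 + 8 = 8 checkpoints + 8 live activations = 16 (4× less memory), plus about one extra forward pass.

Checkpoint every K=16 layers: 64/16 + 16 = 4 checkpoints + 16 live activations = 20 (3.2× less memory), plus about one extra forward pass.

Extreme: K=N (save nothing, checkpoint only input): 64/64 + 64 = 1 checkpoint + 64 live activations = 65. Recomputing the single segment needs every activation at once, so nothing is saved.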

Why √N is optimal. At K = √N, you have √N segments with √N layers each. You store √N checkpoints and buffer at most √N activations during recomputation. Both terms are equal at the minimum. Any other split has one term larger than the other.
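The N/K + K model is small enough to search exhaustively. The sketch below evaluates it for every interval and picks the minimum; it counts activations in units of "one layer's output" and rounds the number of segments up when K does not divide N. Both function names are illustrative.

```python
import math


def checkpoint_memory(n_layers: int, k: int) -> int:
    """Peak stored activations: one checkpoint per segment plus one live segment."""
    return math.ceil(n_layers / k) + k


def best_interval(n_layers: int) -> int:
    """Checkpoint interval K that minimizes checkpoint_memory."""
    return min(range(1, n_layers + 1), key=lambda k: checkpoint_memory(n_layers, k))


for k in (8, 16, 64):
    print(f"K={k}: {checkpoint_memory(64, k)} activations")
print("best K for 64 layers:", best_interval(64))
```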

Code: Checkpointing in PyTorch

python
import torch
from torch.utils.checkpoint import checkpoint

class TransformerWithCheckpointing(torch.nn.Module):
    def __init__(self, n_layers=64, d_model=4096):
        super().__init__()
        self.layers = torch.nn.ModuleList([
            TransformerBlock(d_model) for _ in range(n_layers)
        ])
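        self.ckpt_every = max(1, int(n_layers ** 0.5))  # sqrt(N)

    def _run_segment(self, start, end, x):
        # Plain forward through layers[start:end]
        for layer in self.layers[start:end]:
            x = layer(x)
        return x

    def forward(self, x):
        n = len(self.layers)
        for start in range(0, n, self.ckpt_every):
            end = min(start + self.ckpt_every, n)
            if self.training:
                # Save only the segment's input; activations inside the
                # segment are recomputed during the backward pass
                x = checkpoint(self._run_segment, start, end, x, use_reentrant=False)
            else:
                x = self._run_segment(start, end, x)
        return x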
python
# Even simpler: checkpoint_sequential for a block of layers
from torch.utils.checkpoint import checkpoint_sequential

class SimpleCheckpointModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = torch.nn.Sequential(*[
            TransformerBlock(4096) for _ in range(64)
        ])

    def forward(self, x):
        # Split into 8 segments, checkpoint each
        return checkpoint_sequential(self.layers, segments=8, input=x, use_reentrant=False)

The Showcase

Watch checkpointing in action. Configure the model, set the checkpoint interval, and step through forward and backward passes to see exactly which activations are saved, which are recomputed, and how memory compares.

Gradient Checkpointing Explorer

Top: Layer diagram. Green = saved activation, gray = discarded (will recompute). During backward: yellow = recomputing, blue = currently computing gradient.
Bottom: Memory and compute comparison vs standard backprop.


Chapter 7: The Modern Training Pipeline

You've learned each technique in isolation. Now let's see how they all fit together in a single training step of a real LLM. Every one of these techniques is used simultaneously — they're not alternatives, they're a stack.

Here is the complete pipeline for one optimizer step in a modern LLM training run (in the style of open models such as LLaMA or Gemma). Every box is a technique from this lesson.

The Complete Pipeline

1. Forward (BF16)
Run forward pass in BF16. Gradient checkpointing saves activations every K layers, discards the rest.
2. Loss (FP32)
Compute loss in FP32 for numerical stability. If gradient accumulation: add to running total.
3. Scale & Backward
Scale loss (FP16 only). Backward pass in BF16 — recomputes discarded activations from checkpoints.
↻ repeat 1-3 for N micro-batches (gradient accumulation)
4. Unscale
Divide accumulated gradients by scale factor (FP16 only) and by N micro-batches.
5. Clip
Clip gradient global norm to max_norm (typically 1.0). Prevents explosion from bad batches.
6. Optimizer (FP32)
Adam/AdamW updates FP32 master weights using FP32 gradients + momentum + velocity.
7. LR Schedule
Step the learning rate scheduler (warmup → cosine decay → minimum LR).
8. Cast Back
Copy FP32 master weights → BF16 model weights for next forward pass.
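Step 7 of the pipeline above is the only stage without code elsewhere in this lesson, so here is a minimal sketch of a warmup-then-cosine schedule. The warmup length (2,000 steps), total length (100,000 steps), and peak learning rate (3e-4) match the values used in this lesson's code; the minimum learning rate (10% of peak) and the function name `lr_at` are assumptions for illustration.

```python
import math


def lr_at(step: int, peak_lr: float = 3e-4, min_lr: float = 3e-5,
          warmup_steps: int = 2000, total_steps: int = 100_000) -> float:
    """Learning rate for a given optimizer step."""
    if step < warmup_steps:
        # Linear warmup from near zero up to peak_lr
        return peak_lr * (step + 1) / warmup_steps
    # Cosine decay from peak_lr down to min_lr over the remaining steps
    progress = min(1.0, (step - warmup_steps) / (total_steps - warmup_steps))
    cosine = 0.5 * (1.0 + math.cos(math.pi * progress))   # goes from 1 to 0
    return min_lr + (peak_lr - min_lr) * cosine


for step in (0, 1000, 2000, 51_000, 100_000):
    print(step, lr_at(step))
```

In PyTorch the same function could be wrapped in `torch.optim.lr_scheduler.LambdaLR` by returning `lr_at(step) / peak_lr` as the multiplier.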

Why Every Step Matters

Remove any one of these techniques and something breaks at scale:

| Without this | What happens at 7B+ parameters |
|---|---|
| Mixed precision | 28 GB weights + 28 GB gradients + 56 GB optimizer + ~40 GB activations ≈ 152 GB in FP32. Doesn't fit on one H100 (80 GB). Also 2-4× slower. |
| Master weights | Small updates (lr × grad ≈ 10^-7) round to zero in BF16. Model stops learning after initial fast progress. |
| Gradient checkpointing | Activation memory for 96-layer model with long sequences can hit 60+ GB. OOM on anything but multi-GPU setups. |
| Gradient clipping | A single bad batch with anomalous gradients can spike loss and take days to recover. Common in early training. |
| Gradient accumulation | Effective batch size too small. LLMs need batch sizes of 1M-4M tokens. Can't fit that in one forward pass. |
| LR schedule | Constant LR either too high (diverges) or too low (slow). Warmup prevents early instabilities. Decay is essential. |
These techniques aren't optional luxuries. They're necessary for training models above ~1B parameters. In pure FP32, a 7B model needs 112 GB for weights, gradients, and Adam states alone (weights + gradients + first moment + second moment = 4× the 28 GB model size), before counting activations. Without gradient accumulation, you can't reach the batch sizes needed for stable training. Without checkpointing, the activation memory exceeds even an H100's 80 GB on long sequences.

The Complete Code

python
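# Complete modern training loop: all techniques combined
# (LLaMA_7B, CosineWarmupScheduler, split, and dataloader are placeholders)
import torch
from torch.nn.functional import cross_entropy
from torch.utils.checkpoint import checkpoint  # used inside the model's blocks

# -- Setup --
model = LLaMA_7B().cuda().bfloat16()  # BF16 model
# Detach so each FP32 master copy is a leaf tensor the optimizer can update
master = {n: p.detach().float().requires_grad_(True) for n, p in model.named_parameters()}
optimizer = torch.optim.AdamW(master.values(), lr=3e-4, betas=(0.9, 0.95))
scheduler = CosineWarmupScheduler(optimizer, warmup=2000, total=100000)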

accum_steps = 8      # 8 micro-batches per optimizer step
max_norm = 1.0       # gradient clip threshold

# -- Training Loop --
for step, batch in enumerate(dataloader):
    micro_batches = split(batch, accum_steps)

    # -- Accumulate gradients over micro-batches --
    for i, micro in enumerate(micro_batches):
        # 1. Forward in BF16 (with checkpointing inside model)
        with torch.autocast("cuda", dtype=torch.bfloat16):
            logits = model(micro.input_ids)    # checkpointed internally
            loss = cross_entropy(logits, micro.labels)
            loss = loss / accum_steps          # normalize for accumulation

        # 3. Backward — grads accumulate in .grad buffers
        loss.backward()

    # 4. Copy accumulated BF16 grads -> FP32 master params
    for n, p in model.named_parameters():
        master[n].grad = p.grad.float()

    # 5. Clip gradient norm
    torch.nn.utils.clip_grad_norm_(master.values(), max_norm=max_norm)

    # 6. Optimizer step (FP32)
    optimizer.step()
    optimizer.zero_grad()

    # 7. LR schedule step
    scheduler.step()

    # 8. Copy FP32 master -> BF16 model
    with torch.no_grad():
        for n, p in model.named_parameters():
            p.copy_(master[n].bfloat16())

    # Zero model grads for next accumulation
    model.zero_grad()

The Pipeline Visualizer

Click any stage to see its details. Toggle options to see how the pipeline changes.

Modern LLM Training Pipeline

Each box is a pipeline stage. Color indicates precision: blue = BF16, orange = FP32. Click a stage to highlight it. Toggle options below to modify the pipeline.

Memory Budget for a 7B Model on H100 (80 GB)

| Component | Memory | Precision | Notes |
|---|---|---|---|
| Model weights (BF16) | 14 GB | BF16 | For forward/backward |
| Master weights (FP32) | 28 GB | FP32 | For optimizer |
| Adam m, v states | 56 GB | FP32 | First + second moments |
| Gradients (BF16) | 14 GB | BF16 | Accumulated across micro-batches |
| Activations (ckpt) | ~5 GB | BF16 | With checkpointing (vs ~20 GB without) |
| Total | ~117 GB | | Need model parallelism (split across GPUs) |

Even with every optimization, a 7B model's training state exceeds 80 GB. That's why real LLM training uses model parallelism (split the model across multiple GPUs) and ZeRO (shard optimizer states). But that's a distributed systems lesson, not a gradient flow lesson.

What you've learned in this lesson. Gradients flow backward through a chain of Jacobians, where they can vanish or explode (Ch 0-1). We tame them with clipping (Ch 2), reach large effective batches with accumulation (Ch 3), store them efficiently with mixed precision (Ch 4), rescue them from underflow with loss scaling (Ch 5), save memory for them with checkpointing (Ch 6), and combine everything into the modern training pipeline (Ch 7). Most LLMs you use were trained with some version of this stack.
In a modern LLM training pipeline, which precision is used for the optimizer's master copy of weights?

Chapter 8: The Gradient Flow Arena

Let's race them all.

The Arena pits six configurations against each other: no intervention (the baseline), gradient clipping alone, mixed precision (BF16) alone, gradient accumulation alone, gradient checkpointing alone, and the full modern pipeline. Each handles different failure modes, but reading about them is one thing. Watching them train side by side is another.

This simulation trains six configurations on the same loss landscape. Drag the sliders to create the conditions where each one fails. You'll discover that no single technique is enough — you need the full stack.

How to use the Arena. Hit Play to start training. Adjust network depth, learning rate, and batch size with the sliders. Watch the loss curves diverge. Try these experiments: (1) Set depth=48 with no clipping and watch it explode. (2) Set LR=1.0 and see which configs survive. (3) Set batch=1 and compare accumulation vs single-batch. (4) Toggle FP16 mode and watch underflow kill gradients.
Gradient Flow Racing Arena

Six training configurations race on the same task. Each demonstrates characteristic failure modes. Find where each one breaks.

Controls (default values): Depth 16 · Log LR 0.010 · Batch size 16 · Speed 3

What to Discover

Experiment 1: Depth = 48 layers. The baseline (no clipping, no mixed precision) explodes almost immediately because the gradient product through 48 layers overflows. With clipping alone, it survives but trains slowly. The full stack trains smoothly: clipping catches the spikes, and BF16 keeps the same exponent range as FP32, so dropping to 16 bits does not introduce the overflow that FP16 would.

Experiment 2: Learning rate = 1.0. Drag the LR slider all the way right. Everything except the full stack diverges. The full stack's gradient clipping caps the step size, and the optimizer's adaptive moments (Adam) further stabilize. Clipping alone helps but isn't sufficient without the momentum buffer.
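To make "clipping caps the step" concrete, here is a minimal pure-Python sketch of clipping by global norm. It is an illustration and not the Arena's code; the helper name `clip_by_global_norm` and the `eps` guard are assumptions chosen for this example.

```python
import math

def clip_by_global_norm(grads, max_norm=1.0, eps=1e-6):
    """Rescale grads so their L2 norm is at most max_norm; direction is unchanged."""
    total_norm = math.sqrt(sum(g * g for g in grads))
    scale = min(1.0, max_norm / (total_norm + eps))  # never scales up
    return [g * scale for g in grads], total_norm

# A gradient with norm 50 is shrunk onto the max_norm ball.
clipped, norm_before = clip_by_global_norm([30.0, -40.0], max_norm=1.0)
norm_after = math.sqrt(sum(g * g for g in clipped))
print(norm_before, norm_after, clipped)

# A gradient already inside the ball is left alone.
small, _ = clip_by_global_norm([0.3, 0.4], max_norm=1.0)
print(small)
```

Because every component is multiplied by the same scalar, the ratio between components (the update direction) is preserved; only the magnitude changes.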

Experiment 3: Batch size = 1. With a single sample per batch, gradient variance is enormous. The accumulation config (which simulates a larger effective batch) produces a much smoother loss curve. This is the main benefit of gradient accumulation: variance reduction, not just memory savings.
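The variance claim can be checked with a toy simulation. This sketch assumes each micro-batch gradient is the true gradient plus independent Gaussian noise; `TRUE_GRAD` and `NOISE_STD` are made-up values for illustration, not measurements from the Arena.

```python
import random
import statistics

random.seed(0)
TRUE_GRAD = 1.0   # hypothetical true gradient
NOISE_STD = 2.0   # hypothetical per-micro-batch noise

def noisy_grad():
    return random.gauss(TRUE_GRAD, NOISE_STD)

def accumulated_grad(accum_steps):
    # Each micro-batch contributes grad / accum_steps,
    # which is what dividing the loss by accum_steps achieves.
    return sum(noisy_grad() / accum_steps for _ in range(accum_steps))

single = [accumulated_grad(1) for _ in range(2000)]
accum16 = [accumulated_grad(16) for _ in range(2000)]

std_single = statistics.stdev(single)
std_accum16 = statistics.stdev(accum16)
print(f"std with batch=1:        {std_single:.3f}")
print(f"std with 16 accumulated: {std_accum16:.3f}")
```

Under these assumptions the standard deviation should shrink by roughly √16 = 4, which is the smoother loss curve the accumulation config shows.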

Experiment 4: Compare +BF16 vs +Clip. At moderate depth and LR, both work well. But crank the depth to 48: BF16 keeps FP32's exponent range, so it avoids FP16-style overflow, but it does nothing to bound the gradient norm, while clipping bounds the norm but does nothing for memory or throughput. You need both.

The Arena reveals a crucial insight: each technique fixes ONE failure mode. Clipping fixes gradient explosion. BF16 fixes memory and throughput. Accumulation fixes batch size. Checkpointing fixes activation memory. Loss scaling fixes FP16 underflow. No single technique is enough for large-scale training — you need the full stack working together.
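The FP16 underflow that loss scaling fixes can be reproduced with the standard library alone, since Python's `struct` module supports IEEE half precision through the `"e"` format. This is a sketch: the gradient value and scale are illustrative, and the gradient is scaled directly here, which is equivalent to scaling the loss because backpropagation is linear in the loss.

```python
import struct

def to_fp16(x):
    """Round-trip a Python float through IEEE half precision."""
    return struct.unpack("<e", struct.pack("<e", x))[0]

grad = 1e-8        # a small but meaningful gradient
S = 2.0 ** 16      # loss scale factor

unscaled = to_fp16(grad)      # below half the smallest FP16 subnormal (~6e-8): becomes 0.0
scaled = to_fp16(grad * S)    # ~6.6e-4, comfortably representable in FP16
recovered = scaled / S        # unscale in higher precision before the optimizer step

print("stored without scaling:", unscaled)
print("stored with scaling:   ", scaled)
print("recovered after unscale:", recovered)
```

The recovered value matches the original gradient to within FP16's relative precision, whereas the unscaled copy has lost the signal entirely.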

Failure Mode Summary

| Config | Fails when... | Symptom | What's missing |
| --- | --- | --- | --- |
| Baseline | Depth > 8 or LR > 0.01 | Loss explodes (NaN) | Everything |
| +Clip only | Very deep networks | Trains but slowly (gradients always clipped) | Mixed precision + proper init |
| +BF16 only | Bad batch + no clipping | Single spike derails training | Gradient clipping |
| +Accum only | Deep + high LR | Smoother but still explodes | Clipping + mixed precision |
| +Ckpt only | Same as baseline | Saves memory but doesn't fix gradient dynamics | Clipping + mixed precision |
| Full Stack | Rarely (needs extreme settings) | Stable training across all configs | Nothing; this is the production recipe |

Chapter 9: Cheat Sheet & Connections

You now understand the complete gradient flow toolkit — from the chain rule to the full modern training pipeline. This chapter is your practical reference. No new concepts. Just the techniques, the decision guide, and the connections to where you go next.

Every Technique at a Glance

| Technique | What it does | When you need it | Cost |
| --- | --- | --- | --- |
| Gradient Clipping | Caps gradient norm to max_norm | Always. Prevents explosion from bad batches. | ~0% compute. One norm computation. |
| Gradient Accumulation | Sums gradients over N micro-batches | When the target batch does not fit in GPU memory | N× forward-backward time. No extra memory. |
| Mixed Precision (BF16) | Forward/backward in BF16, optimizer in FP32 | Always. Up to ~2× speed, 0.5× activation memory. | Extra memory for FP32 master weights. |
| Loss Scaling | Multiplies loss by S to prevent FP16 underflow | FP16 only. Unnecessary with BF16. | ~0%. One multiply and divide. |
| Gradient Checkpointing | Recomputes activations instead of storing them | When activation memory exceeds GPU budget | ~33% extra compute (at K=√N). |
| Master Weights | FP32 copy of weights for optimizer | Always with mixed precision | 2× weight memory (BF16 + FP32). |
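Why K = √N for checkpointing? A simplified cost model makes it visible. The sketch below assumes every layer's activation is one equal-sized "unit", that one checkpoint is stored per segment of K layers, and that one segment is fully rematerialized at a time during backward. The function name and the 64-layer network are hypothetical.

```python
import math

def peak_activation_units(n_layers, k):
    """Simplified peak activation memory when checkpointing every k layers."""
    checkpoints = math.ceil(n_layers / k)  # one stored activation per segment
    live_segment = k                       # activations recomputed for the current segment
    return checkpoints + live_segment

n_layers = 64
costs = {k: peak_activation_units(n_layers, k) for k in range(1, n_layers + 1)}
best_k = min(costs, key=costs.get)

print("best K:", best_k, "sqrt(N):", math.isqrt(n_layers))
print("units with checkpointing:", costs[best_k], "vs without:", n_layers)
```

In this model the stored checkpoints shrink as N/K while the live segment grows as K, so the sum is smallest when the two are equal, at K = √N.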

The Decision Flowchart

Follow the path that matches your situation:

Starting a training run?
Apply these in order. Each is nearly always beneficial.
1. Mixed Precision
Use BF16 if hardware supports it (A100, H100, TPU, M-series). Otherwise FP16 + loss scaling.
2. Gradient Clipping
Set max_norm = 1.0. This is a common default for LLM training. Adjust only if you see clipping on >50% of steps.
3. Gradient Accumulation
Set accum_steps = target_batch / micro_batch. Divide loss by accum_steps before backward.
4. Checkpointing
Enable if activations cause OOM. Start with K = √N. Accept ~33% compute overhead.
5. Monitor gradient norms
Log grad_norm every step. If it spikes >10× average, investigate the data batch.
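The two monitoring rules in the checklist (flag norms above 10× the running average, and watch how often clipping fires) can be sketched as a small tracker. This is one possible design under stated assumptions: the class name `GradNormMonitor`, the 100-step window, and the sample norms are all hypothetical.

```python
from collections import deque

class GradNormMonitor:
    """Tracks pre-clip gradient norms, flags spikes, and counts clipped steps."""

    def __init__(self, window=100, spike_factor=10.0, max_norm=1.0):
        self.history = deque(maxlen=window)
        self.spike_factor = spike_factor
        self.max_norm = max_norm
        self.steps = 0
        self.clipped_steps = 0

    def update(self, grad_norm):
        """Record one step's norm; return True if it is a spike."""
        avg = sum(self.history) / len(self.history) if self.history else None
        is_spike = avg is not None and grad_norm > self.spike_factor * avg
        self.steps += 1
        self.clipped_steps += int(grad_norm > self.max_norm)
        if not is_spike:  # keep spikes out of the running average
            self.history.append(grad_norm)
        return is_spike

    @property
    def clip_fraction(self):
        return self.clipped_steps / self.steps if self.steps else 0.0

monitor = GradNormMonitor()
norms = [0.8, 0.9, 0.7, 1.1, 0.8, 25.0, 0.9]  # made-up per-step norms
spikes = [step for step, n in enumerate(norms) if monitor.update(n)]
print("spike steps:", spikes)
print("fraction of steps clipped:", round(monitor.clip_fraction, 3))
```

In a real loop, the value fed to `update` would be the total norm returned by the clipping call, and a flagged step is the cue to inspect that step's data batch.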

Symbol Glossary

| Symbol | Meaning | Typical values |
| --- | --- | --- |
| ∇L | Gradient of loss w.r.t. parameters | 10^-6 to 10^-1 |
| max_norm | Gradient clipping threshold | 1.0 (LLMs), 0.5-5.0 (vision) |
| N | Number of accumulation micro-batches | 4-64 |
| K | Checkpoint interval (layers) | √(total layers) |
| S | Loss scale factor (FP16 only) | 2^10 to 2^16 |
| FP32 | 32-bit floating point (8e 23m) | Optimizer, master weights |
| BF16 | 16-bit brain float (8e 7m) | Forward/backward, activations |
| FP16 | 16-bit float (5e 10m) | Legacy. Use BF16 if available. |
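The "8e 23m" style notation (exponent bits, mantissa bits) determines each format's range and precision. The sketch below derives the limits from those two numbers, assuming the standard IEEE-style layout with an all-ones exponent reserved for infinity and NaN; the function name is illustrative.

```python
def float_format_limits(exp_bits, man_bits):
    """Return (largest normal, smallest positive normal, epsilon) for a float format."""
    bias = 2 ** (exp_bits - 1) - 1
    max_normal = (2 - 2.0 ** -man_bits) * 2.0 ** bias
    min_normal = 2.0 ** (1 - bias)
    epsilon = 2.0 ** -man_bits  # spacing between 1.0 and the next representable value
    return max_normal, min_normal, epsilon

formats = {"FP32": (8, 23), "BF16": (8, 7), "FP16": (5, 10)}
limits = {name: float_format_limits(e, m) for name, (e, m) in formats.items()}

for name, (hi, lo, eps) in limits.items():
    print(f"{name}: max={hi:.3e}  min normal={lo:.3e}  epsilon={eps:.3e}")
```

BF16 shares FP32's smallest normal value and nearly its maximum, which is why it does not need loss scaling, while FP16 tops out at 65504 and bottoms out near 6e-5 but resolves finer steps than BF16.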

PyTorch One-Liner Reference

python
import torch
from torch.amp import GradScaler
from torch.utils.checkpoint import checkpoint

# Gradient clipping
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

# Mixed precision (BF16 — preferred)
with torch.autocast("cuda", dtype=torch.bfloat16): ...

# Mixed precision (FP16 + loss scaling)
scaler = GradScaler()
scaler.scale(loss).backward()
scaler.unscale_(optimizer)
scaler.step(optimizer)
scaler.update()

# Gradient accumulation (assumes step counts micro-batches starting from 0)
loss = loss / accum_steps
loss.backward()  # grads accumulate in .grad
if (step + 1) % accum_steps == 0: optimizer.step(); optimizer.zero_grad()

# Gradient checkpointing
x = checkpoint(layer, x, use_reentrant=False)

Summary of Everything

Ch 0: The Telephone Game
Gradients are products of per-layer factors. They vanish or explode exponentially.
Ch 1: Chain Rule
Backprop multiplies Jacobians. Sigmoid shrinks gradients (its derivative is at most 0.25). ReLU passes them through unchanged for active units.
Ch 2: Clipping
Cap gradient norm to prevent explosion. Preserves direction.
Ch 3: Accumulation
Sum gradients over micro-batches. Simulates large batch. Trades time for memory.
Ch 4: Mixed Precision
BF16 forward/backward + FP32 optimizer. 2× speed, range beats precision.
Ch 5: Loss Scaling
FP16-only bandaid. Multiply loss to rescue underflowing gradients.
Ch 6: Checkpointing
Save every √N layers. Recompute the rest. 4× memory savings.
Ch 7: Pipeline
All techniques together. The complete modern LLM training recipe.
Ch 8: Arena
Race them all. Each technique fixes ONE failure mode.

Connections

Gradient flow doesn't exist in isolation. Here's where to go next:

Key Papers

| Paper | Year | Contribution |
| --- | --- | --- |
| Hochreiter, "Vanishing Gradient Problem" | 1991 | Identified the fundamental problem of gradient decay in deep networks |
| Pascanu et al., "On the Difficulty of Training RNNs" | 2013 | Gradient clipping for RNNs. Popularized norm clipping. |
| Micikevicius et al., "Mixed Precision Training" | 2018 | FP16 + loss scaling + master weights recipe |
| Chen et al., "Training Deep Nets with Sublinear Memory Cost" | 2016 | Gradient checkpointing (rematerialization) |
| Kalamkar et al., "A Study of BFLOAT16 for DL Training" | 2019 | BF16 as the practical replacement for FP16 + loss scaling |

"The art of training deep networks is the art of keeping gradients alive." Every technique in this lesson serves one purpose: ensuring that the error signal from the loss function reaches every parameter in the network, at the right magnitude, in the right precision, without running out of memory. Master this, and you understand why modern training pipelines look the way they do.
You're setting up a 13B parameter LLM training run on 8 H100 GPUs. Which combination of techniques do you need?