Andrej Karpathy (Stanford / OpenAI) — Medium 2016

Yes You Should Understand Backprop

A practical argument for why "just use autograd" is insufficient — and the gradient pathologies that will bite you if you don't understand what happens under the hood.

Prerequisites: basic backpropagation and neural network training. That's it.
8 chapters, 8+ simulations

Chapter 0: Why Bother Learning Backprop?

You have PyTorch. You call loss.backward() and gradients magically appear on every parameter. You never touch a derivative by hand. Training works. Why should you care how backpropagation works internally?

Karpathy's answer: because when training doesn't work — and it will stop working — you will have no idea why. Your loss plateaus. Your accuracy stalls at 60%. Your model generates garbage. The error message is just... silence. The loss curve flatlines and tells you nothing about the cause.

Here are the kinds of failures Karpathy describes, where understanding backprop is the difference between debugging in minutes and debugging for days:

  1. A model that never learns: Loss starts high and stays high. The sigmoid activations have saturated — gradients are effectively zero. Without knowing sigmoid's derivative, you'd never diagnose this.
  2. A model that learns, then dies: Loss drops for 1000 steps, then spikes to random chance. 40% of ReLU units have died — they output zero for every input, and their gradients are permanently zero.
  3. A model that trains perfectly on small data but diverges on large data: Gradients are exploding through deep layers. Without understanding the multiplicative nature of backprop, you'd blame the data.
The Autograd Trap

Three training curves. All three networks use correct code: the loss function, the optimizer, and the data are all right. But only one learns. The other two have gradient pathologies that autograd can't warn you about.

Karpathy's thesis: "Backpropagation is a leaky abstraction." Autograd computes the right gradients. But it doesn't tell you whether those gradients are useful. A gradient of 1e-12 is technically correct but practically zero. A gradient of 1e+8 will blow up your weights. Knowing backprop means you can look at your network architecture and predict these problems before they happen.

The blog post covers four specific failure modes, each rooted in how gradients flow through specific operations during backpropagation. Understanding these will upgrade you from "user of autograd" to "debugger of neural networks."

The leaky abstraction

Karpathy borrows Joel Spolsky's concept of "leaky abstractions." An abstraction is a simplification that hides complexity. A leaky abstraction is one where the hidden complexity bleeds through in unexpected ways. Autograd is a leaky abstraction: it hides the mechanics of gradient computation, but the behavior of those gradients (vanishing, exploding, dying) still affects your training. You can use the abstraction without understanding it — until something breaks.

The analogy is apt: you can drive a car without understanding engines. Until the engine fails. Then the person who understands combustion diagnoses the problem in minutes, while the person who doesn't waits for the mechanic. In ML, the "engine" is gradient flow, and the "mechanic" is often yourself at 2 AM before a deadline.

Consider a concrete example: you're training a ResNet-50 on ImageNet and accuracy plateaus at 40% (it should reach about 76%). Without understanding backprop, you try different learning rates, different optimizers, more data augmentation: a random search. With understanding backprop, you check gradient norms per layer and immediately see that the later layers have 1000x larger gradients than the early ones. That is the signature of vanishing gradients, so you look for whatever is blocking gradient flow (initialization, normalization, a broken skip connection) instead of tuning hyperparameters blindly. Hours of guessing are replaced by minutes of diagnosis.

The five stages of ML debugging

Most ML practitioners go through a familiar progression when a model won't train:

  1. Denial: "My code is correct. The data must be wrong." (It's not the data.)
  2. Anger: "Why won't the loss decrease?!" (Check the gradients.)
  3. Bargaining: "Maybe if I change the learning rate / add more layers / try a different optimizer..." (Address the root cause, not the symptoms.)
  4. Depression: "I've been debugging for two days and the model is no better." (You skipped the diagnostic checklist.)
  5. Acceptance: "I need to look at gradient magnitudes, activation distributions, and the update-to-weight ratio." (Now you're debugging, not guessing.)

The post exists to help practitioners skip stages 1-4 and go directly to stage 5. Every failure mode has a signature in the gradient statistics. Learning to read those signatures is the skill this post teaches.

The specific signatures to look for:

| Symptom (loss curve) | Gradient signature | Diagnosis |
|---|---|---|
| Loss doesn't decrease at all | Gradients ~0 everywhere | Sigmoid saturation or dead ReLUs |
| Loss decreases, then plateaus | Early-layer gradients 1000x smaller | Vanishing gradients |
| Loss spikes to NaN | Gradient norm > 1000 | Exploding gradients |
| Loss decreases but accuracy is poor | Gradients look healthy | Wrong loss function or labels |
| Loss oscillates wildly | Update/weight ratio > 0.1 | Learning rate too high |
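The last row relies on the update-to-weight ratio, which is cheap to log alongside per-layer gradient norms. Below is a minimal sketch, assuming plain SGD (so the update is simply lr times the gradient); the helper name `gradient_report` is hypothetical, and with adaptive optimizers such as Adam you would measure the actual parameter change instead.

```python
import torch
import torch.nn as nn


def gradient_report(model, lr):
    """Return {param_name: (grad_norm, update_to_weight_ratio)}.

    Assumes plain SGD, where the update is lr * grad, so the ratio is
    lr * ||grad|| / ||weight||. Parameters without a gradient are skipped.
    """
    report = {}
    for name, p in model.named_parameters():
        if p.grad is None:
            continue
        grad_norm = p.grad.norm().item()
        weight_norm = p.detach().norm().item()
        ratio = lr * grad_norm / max(weight_norm, 1e-12)
        report[name] = (grad_norm, ratio)
        print(f"{name:>10}: grad_norm={grad_norm:.2e}  update/weight={ratio:.2e}")
    return report


# Usage on a toy model (sizes and data are illustrative)
torch.manual_seed(0)
model = nn.Sequential(nn.Linear(20, 64), nn.ReLU(), nn.Linear(64, 1))
x, y = torch.randn(128, 20), torch.randn(128, 1)
loss = nn.functional.mse_loss(model(x), y)
loss.backward()
stats = gradient_report(model, lr=0.01)
```

A rough heuristic from Karpathy's CS231n notes is that this ratio should sit somewhere around 1e-3; ratios orders of magnitude above that suggest the learning rate is too high.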

Chapter 1: Sigmoid Saturation

The sigmoid function $\sigma(x) = 1/(1 + e^{-x})$ maps any real number to the range (0, 1). Its derivative is $\sigma'(x) = \sigma(x)(1 - \sigma(x))$. The maximum value of this derivative occurs at x = 0, where $\sigma'(0) = 0.25$.

Think about what this means for gradient flow. At its very best, the sigmoid multiplies the incoming gradient by 0.25. At most points, it's much worse. When the input is large (say x = 5), σ(5) ≈ 0.9933 and σ'(5) ≈ 0.9933 × 0.0067 ≈ 0.0066. The gradient is multiplied by less than 0.007: it is essentially killed.

The gradient highway analogy: Think of backpropagation as a highway carrying gradient signal from the loss back to early layers. Every sigmoid is a tollbooth that charges a multiplicative toll. The best toll is 0.25 (at x = 0); the toll approaches 0 as |x| grows and is already below 0.007 once |x| > 5. With 10 sigmoid layers, even the best case gives $0.25^{10} \approx 9.5 \times 10^{-7}$: the gradient reaching the first layer is about a millionth of what left the loss. At typical operating points, it's far worse.
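These numbers are easy to reproduce. A minimal sketch in plain Python (no framework needed) that evaluates σ' at a few points and the best-case product over depth; `sigmoid_grad` is just a local helper:

```python
import math


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


def sigmoid_grad(x):
    # sigma'(x) = sigma(x) * (1 - sigma(x))
    s = sigmoid(x)
    return s * (1.0 - s)


for x in (0.0, 1.0, 3.0, 5.0):
    print(f"sigma'({x}) = {sigmoid_grad(x):.4f}")
# 0.2500, 0.1966, 0.0452, 0.0066

# Best case: every layer sits at x = 0 and contributes a factor of 0.25
for depth in (5, 10, 20):
    print(f"{depth} layers: 0.25**{depth} = {0.25 ** depth:.2e}")
# 9.77e-04, 9.54e-07, 9.09e-13
```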

When does saturation happen?

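Saturation happens when the pre-activation (the input to the sigmoid) has a large absolute value. This occurs when the weights are too large, or when the inputs are not normalized, so that the weighted sum lands far from zero.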

Sigmoid Saturation: The Gradient Killer

Larger weight magnitudes push pre-activations into sigmoid saturation. The bottom panel shows the gradient magnitude at each layer, which collapses as the weights increase.

The numerical evidence

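Let's trace the gradient through a 5-layer sigmoid network with different weight scales, counting only the σ' factors (the weight matrices contribute further factors of their own):

| Weight scale | Typical pre-activation magnitude | σ' per layer | σ' factor at layer 1 |
|---|---|---|---|
| 0.1 (too small) | ~0.05 | ~0.25 | $0.25^5 \approx 0.001$ (slow learning) |
| 1.0 (typical) | ~1.0 | ~0.20 | $0.20^5 \approx 0.0003$ (very slow) |
| 3.0 (too large) | ~3.0 | ~0.05 | $0.05^5 \approx 3 \times 10^{-7}$ (dead) |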

With weight scale 3.0, the gradient reaching the first layer is 0.0000003 times the loss gradient. The first layer's weights will barely change. The network is effectively frozen.

The fix: don't use sigmoids (mostly)

Modern networks use ReLU (Rectified Linear Unit) instead of sigmoid for hidden layers. ReLU(x) = max(0, x). Its derivative is 1 for x > 0 and 0 for x ≤ 0. When active, ReLU passes the gradient through unchanged, with no multiplicative shrinkage. This is a key reason much deeper networks became trainable: ReLU doesn't create the tollbooth problem.

The math of exponential decay

Let's make the sigmoid gradient problem precise. In a network with L sigmoid layers, the gradient at layer 1 is:

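$$\frac{\partial \mathcal{L}}{\partial W_1} = \frac{\partial \mathcal{L}}{\partial a_L} \left[ \prod_{l=2}^{L} \operatorname{diag}\big(\sigma'(z_l)\big)\, W_l \right] \operatorname{diag}\big(\sigma'(z_1)\big)\, x$$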

Each $\sigma'(z_l)$ factor is at most 0.25. With L layers each contributing at most 0.25, the activation derivatives alone shrink the gradient by a factor of $0.25^L$. For 10 layers, that's $0.25^{10} \approx 10^{-6}$; for 20 layers, $0.25^{20} \approx 10^{-12}$. This is the vanishing gradient problem, identified independently by Hochreiter (1991) and Bengio et al. (1994).

The sigmoid isn't used in hidden layers anymore, but it is still used where you genuinely want outputs between 0 and 1: the final layer of binary classifiers, gates in LSTM and GRU cells, and the attention mechanism in some architectures. In these cases, saturation is by design — a gate should commit to being "open" or "closed."

```python
import torch

torch.manual_seed(0)

# Sigmoid saturation demo
x = torch.randn(1000)

# At typical pre-activations (standard normal)
sig_grad = torch.sigmoid(x) * (1 - torch.sigmoid(x))
print(f"Sigmoid gradient: mean {sig_grad.mean():.3f}, max {sig_grad.max():.3f}")
# mean ≈ 0.21, max ≈ 0.250

# At large pre-activations (saturated)
x_large = torch.randn(1000) * 5
sig_grad_large = torch.sigmoid(x_large) * (1 - torch.sigmoid(x_large))
print(f"Saturated gradient: mean {sig_grad_large.mean():.4f}")
print(f"Fraction below 0.01: {(sig_grad_large < 0.01).float().mean():.2f}")
# mean ≈ 0.075 (about 3x smaller), and roughly a third of the units
# pass back less than 1% of the incoming gradient

# ReLU for comparison
relu_grad = (x > 0).float()
print(f"ReLU gradient: mean {relu_grad.mean():.3f}")
# mean ≈ 0.5: half active, half zero, no shrinkage on the active half
```

A numerical experiment

Let's verify the gradient decay empirically. We'll create a 5-layer sigmoid network and measure the actual gradient magnitude at each layer:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# 5-layer sigmoid network (default PyTorch initialization)
layers = []
for i in range(5):
    layers.append(nn.Linear(100, 100))
    layers.append(nn.Sigmoid())
layers.append(nn.Linear(100, 10))
model = nn.Sequential(*layers)

# Forward + backward
x = torch.randn(32, 100)
y = model(x)
loss = y.sum()
loss.backward()

# Check gradient magnitudes (module indices 0, 2, 4, 6, 8, 10 are the Linear layers)
for i, layer in enumerate(model):
    if isinstance(layer, nn.Linear):
        print(f"Layer {i}: grad norm = {layer.weight.grad.norm():.6f}")
# Exact values depend on the seed, but the pattern is consistent: every step
# back toward the input shrinks the norm several-fold, and Layer 0 typically
# lands around four orders of magnitude below Layer 10.
```

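The gradient at layer 0 comes out roughly four orders of magnitude (tens of thousands of times) smaller than at layer 10, so the first layer's weights will barely change during training. Note that none of these units is even saturated: the 0.25 ceiling on σ', compounded with the small default weights, is enough to starve the first layer.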


Chapter 2: Dead ReLUs

ReLU fixed the vanishing gradient problem. But it introduced a new one: the dead ReLU problem.

ReLU(x) = max(0, x). Its derivative is:

ReLU'(x) = 1 if x > 0, 0 if x ≤ 0

When a ReLU unit's input is negative, its output is zero and its gradient is zero. If a weight update pushes the pre-activation permanently negative for all training examples, that unit is dead. It outputs zero for every input. Its gradient is zero for every input. It can never recover. It's a permanent zombie in your network.

How ReLUs die: During a weight update, the bias shifts negative enough that the weighted sum is negative for every input in the dataset. The unit outputs 0 for everything. Since its gradient is also 0, no update can ever revive it. The unit is permanently dead. With a high learning rate, this can happen to a large fraction of units in a single update step.

How common is this?

Karpathy reports seeing networks where up to 40% of ReLU units were dead. That's 40% of the network's capacity completely wasted — parameters that consume memory and compute but contribute nothing.


Visualizing the death process

Dead ReLUs typically appear during the first few epochs of training, when gradients are large and weight updates are aggressive. The sequence is:

  1. Large gradient update pushes the bias of unit j strongly negative
  2. Pre-activation becomes negative for all training examples: $w_j^\top x + b_j < 0$ for all x
  3. Output is zero for all inputs: $\mathrm{ReLU}(w_j^\top x + b_j) = 0$
  4. Gradient is zero for all inputs: $\partial\,\mathrm{ReLU}(z)/\partial z = 0$ when $z \le 0$
  5. No recovery possible — the unit can never receive a non-zero gradient again
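The sequence above can be reproduced on a single unit. A minimal sketch, assuming a toy dataset of standard-normal inputs and a bias forced to −10 to stand in for the bad update in step 1: the unit then outputs zero for every example, and autograd returns exactly zero gradient for its weights and bias.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

x = torch.randn(1000, 10)      # toy "training set": 1000 examples, 10 features
unit = nn.Linear(10, 1)        # one ReLU unit: relu(w^T x + b)

# Step 1: simulate a large update that pushes the bias strongly negative
with torch.no_grad():
    unit.bias.fill_(-10.0)

# Steps 2-3: pre-activation is negative for every example, so output is 0
out = torch.relu(unit(x))
print("max output:", out.max().item())                      # 0.0

# Steps 4-5: gradient through ReLU is 0 everywhere, so nothing reaches w or b
out.sum().backward()
print("weight grad norm:", unit.weight.grad.norm().item())  # 0.0
print("bias grad:", unit.bias.grad.item())                  # 0.0
```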

The tragedy is in step 5: even if the loss would decrease with the unit active again, the zero gradient means no information reaches the dead unit's weights. It is a flat region of a very specific kind: one where a degree of freedom has been permanently eliminated.

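Dead ReLUs are typically caused by a learning rate that is too high (one large update can push many units negative at once) or by an initialization that starts units deep in the negative region.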

Dead ReLU Units

A network with 20 ReLU hidden units. Green = alive (produces non-zero output for at least some inputs). Red = dead (output is zero for ALL inputs). At higher learning rates, training steps kill units, and once dead, they never recover.

Detecting dead ReLUs in your network

Here's a simple test: pass your entire training set through the network and record which hidden units produce non-zero output for at least one example. If a unit outputs zero for every single training example, it is dead. It is worth running this check after each epoch and logging the percentage of dead units. If it exceeds 5-10%, something is wrong, and the usual culprit is a learning rate that is too high.

The fix: Leaky ReLU and friends

Leaky ReLU fixes the dead neuron problem by allowing a small gradient when x < 0:

LeakyReLU(x) = x if x > 0, αx if x ≤ 0 (typically α = 0.01)

Now even negative pre-activations get a small gradient (0.01 instead of 0). Units can recover from being pushed negative — the small gradient provides a lifeline for the unit to crawl back to the active region.
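The difference shows up directly in the gradients. A minimal sketch comparing the derivative of `torch.relu` and `torch.nn.functional.leaky_relu` (with `negative_slope=0.01`, which is also its default) at a few inputs:

```python
import torch
import torch.nn.functional as F

x = torch.tensor([-3.0, -0.5, 0.5, 3.0], requires_grad=True)

# ReLU: gradient is exactly 0 for negative inputs
torch.relu(x).sum().backward()
relu_grad = x.grad.clone()
x.grad = None  # reset before the second backward pass

# Leaky ReLU: negative inputs still pass back 1% of the gradient
F.leaky_relu(x, negative_slope=0.01).sum().backward()
leaky_grad = x.grad.clone()

print("ReLU grad:     ", relu_grad.tolist())                         # [0.0, 0.0, 1.0, 1.0]
print("LeakyReLU grad:", [round(g, 4) for g in leaky_grad.tolist()])  # [0.01, 0.01, 1.0, 1.0]
```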

The α parameter is usually fixed at 0.01, but Parametric ReLU (PReLU) learns α from data. In practice, LeakyReLU with α = 0.01 usually works well enough that PReLU's extra parameters are rarely worth it. Many modern Transformers (BERT and GPT-2, for example) use GELU (Gaussian Error Linear Unit), which replaces ReLU's sharp kink with a smooth transition region and gives smoother gradient flow around zero.

Other variants include:

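| Activation | Formula | Gradient for negative inputs | Dead units? |
|---|---|---|---|
| ReLU | max(0, x) | 0 | Yes, permanent |
| Leaky ReLU | max(0.01x, x) | 0.01 | No, can recover |
| ELU | x if x > 0, else $\alpha(e^x - 1)$ | Smooth, up to α | No |
| GELU | x · Φ(x) | Smooth, near 0 | No |
| Swish/SiLU | x · σ(x) | Smooth, small | No |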
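
```python
import torch
import torch.nn as nn

# Check for dead ReLUs in a trained network
def count_dead_relus(model, dataloader):
    """Pass data through the model and count units that never activate.

    Only nn.ReLU modules are hooked (functional F.relu calls are invisible).
    Activation counts are summed over every dimension except dim 1 (features
    or channels), so batches of different sizes can be combined safely.
    """
    hooks = []
    active_counts = {}

    def hook_fn(name):
        def hook(module, inp, out):
            reduce_dims = tuple(d for d in range(out.dim()) if d != 1)
            counts = (out > 0).sum(dim=reduce_dims)
            if name in active_counts:
                active_counts[name] += counts
            else:
                active_counts[name] = counts
        return hook

    for name, module in model.named_modules():
        if isinstance(module, nn.ReLU):
            hooks.append(module.register_forward_hook(hook_fn(name)))

    with torch.no_grad():
        for x, _ in dataloader:
            model(x)

    for h in hooks:
        h.remove()

    results = {}
    for name, counts in active_counts.items():
        dead = int((counts == 0).sum())   # units that never produced output > 0
        total = counts.numel()
        results[name] = (dead, total)
        print(f"{name}: {dead}/{total} dead ({100 * dead / total:.1f}%)")
    return results
```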
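
```python
import torch.nn as nn

model = nn.Sequential(nn.Linear(784, 256), nn.ReLU(), nn.Linear(256, 10))

# He initialization plus a slight positive bias to reduce early ReLU death
for m in model.modules():
    if isinstance(m, nn.Linear):
        nn.init.kaiming_normal_(m.weight, mode='fan_in', nonlinearity='relu')
        if m.bias is not None:
            nn.init.constant_(m.bias, 0.01)  # slight positive bias
# This nudges units toward the active region (output > 0) at the start;
# the effect is modest, since 0.01 is small next to typical pre-activations
```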

Chapter 3: Vanishing and Exploding Gradients

Sigmoid saturation is one cause of vanishing gradients. But the problem is deeper than any single activation function. It's inherent in the multiplicative structure of backpropagation itself.

In a network with L layers, the gradient of the loss with respect to weights in layer 1 is:

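$$\frac{\partial \mathcal{L}}{\partial W_1} = \frac{\partial \mathcal{L}}{\partial a_L} \cdot \frac{\partial a_L}{\partial a_{L-1}} \cdots \frac{\partial a_2}{\partial a_1} \cdot \frac{\partial a_1}{\partial W_1}$$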

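Each factor $\partial a_l / \partial a_{l-1}$ equals $\operatorname{diag}(f'(z_l))\, W_l$, where $a_l = f(z_l)$ and $z_l = W_l a_{l-1} + b_l$: the weight matrix $W_l$ with its rows scaled by the activation derivatives. The product of L − 1 such matrices determines whether gradients grow or shrink.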
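The Jacobian factor described above can be checked numerically. A minimal sketch, assuming a single tanh layer (so f'(z) = 1 − tanh²(z)) with arbitrary sizes: `torch.autograd.functional.jacobian` computes ∂a_l/∂a_{l−1} directly, and it matches the weight matrix with its rows scaled by the activation derivatives.

```python
import torch
from torch.autograd.functional import jacobian

torch.manual_seed(0)

W = torch.randn(4, 3, dtype=torch.float64)     # W_l: maps 3 inputs to 4 outputs
b = torch.randn(4, dtype=torch.float64)        # b_l
a_prev = torch.randn(3, dtype=torch.float64)   # a_{l-1}

def layer(a):
    return torch.tanh(W @ a + b)               # a_l = f(z_l) with f = tanh

J_autograd = jacobian(layer, a_prev)           # shape (4, 3)

z = W @ a_prev + b
J_formula = torch.diag(1 - torch.tanh(z) ** 2) @ W   # diag(f'(z_l)) W_l

print(torch.allclose(J_autograd, J_formula))   # True
```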

The critical threshold: If the typical singular value s of each Jacobian factor is less than 1, gradients shrink exponentially with depth, as $O(s^L)$. If s is greater than 1, they grow exponentially, again as $O(s^L)$. Only at s = 1 do gradients neither vanish nor explode. This is why weight initialization matters so much: it determines where you start on this spectrum.

Vanishing: the signal dies

When gradients vanish, early layers learn much slower than later layers. The network essentially only trains its last few layers. The early layers — which are supposed to learn basic features — stay near their random initialization. You can detect this by comparing gradient norms across layers: if layer 1 has gradient norm 1e-8 and layer 10 has gradient norm 1e-1, you have a 10-million-fold discrepancy. The first layer will take millions of times longer to converge.

The symptoms are subtle: the loss decreases (because the last layers are training), but much slower than expected. The model plateaus at mediocre accuracy. Adding more layers makes things worse, not better. Without understanding vanishing gradients, you might blame the dataset, the learning rate, or the architecture — never suspecting that the first five layers are effectively frozen.

Exploding: the signal bombs

When gradients explode, weight updates become enormous. Weights jump to huge values, activations become NaN, and training crashes. The loss goes to infinity or NaN within a few steps. This is especially severe in recurrent neural networks (RNNs), where the same weight matrix is multiplied in at every timestep. After 100 timesteps, even a singular value of 1.01 becomes $1.01^{100} \approx 2.7$. At 1.1, it's $1.1^{100} \approx 13{,}780$.

Exploding gradients are easier to detect than vanishing ones: the loss suddenly spikes to NaN, or the weights contain inf or NaN values. The fix is also simpler: gradient clipping immediately prevents the explosion, while vanishing gradients require architectural changes.

Quick diagnostic: To check for vanishing or exploding gradients, add this one-liner to your training loop after loss.backward(): for n,p in model.named_parameters(): print(f"{n}: {p.grad.norm():.2e}") (skip parameters whose .grad is None). If the norms differ by more than 100x across layers, you likely have a gradient scaling problem. As a rough rule of thumb, healthy networks show norms within 3-10x of each other.
Gradient Flow Through Deep Layers

A chain of matrix multiplications representing gradient flow through depth. Small changes in the weight scale cause exponential growth or decay, because the gradient at layer 1 is the product of all per-layer factors.

RNNs: the worst case

Recurrent neural networks are uniquely vulnerable to gradient scaling because the same weight matrix $W_h$ is multiplied in at every timestep. If the RNN processes a sequence of length T, the gradient involves a product of T copies of $W_h$: literally raising the weight matrix to the T-th power.

The largest singular value $\sigma_{\max}(W_h)$ determines what happens:

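| $\sigma_{\max}$ | After T = 100 steps | Result |
|---|---|---|
| 0.9 | $0.9^{100} \approx 2.66 \times 10^{-5}$ | Vanishing: can't learn long-range dependencies |
| 1.0 | $1.0^{100} = 1.0$ | Perfect, but an unstable equilibrium |
| 1.1 | $1.1^{100} \approx 13{,}780$ | Exploding: training crashes |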

This is why vanilla RNNs famously struggle to learn dependencies longer than about 10-20 timesteps. LSTMs and GRUs largely mitigate this with gating mechanisms that create additive paths through time (similar to residual connections), allowing gradients to flow without repeated multiplicative degradation.
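The table above can be reproduced with a purely linear recurrence. A minimal sketch, assuming a recurrent matrix built as s·Q for a random orthogonal Q (so every singular value equals s exactly) and ignoring the nonlinearity: applying it 100 times scales any vector's norm by s¹⁰⁰.

```python
import torch

torch.manual_seed(0)
n, T = 64, 100

# Random orthogonal Q via QR decomposition: all singular values of Q are 1
Q, _ = torch.linalg.qr(torch.randn(n, n, dtype=torch.float64))

def norm_after(s, steps=T):
    W_h = s * Q                  # every singular value of W_h equals s
    v = torch.randn(n, dtype=torch.float64)
    start = v.norm()
    for _ in range(steps):
        v = W_h @ v              # one linear "timestep" of the recurrence
    return (v.norm() / start).item()

for s in (0.9, 1.0, 1.1):
    print(f"s = {s}: norm ratio after {T} steps = {norm_after(s):.3e}")
# 2.656e-05, 1.000e+00, 1.378e+04
```

Real RNNs add a nonlinearity and input terms, so their singular values vary over time, but the exponential dependence on T is the same.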

Proper initialization: Xavier and He

The fix is to initialize weights so that the variance of activations stays constant across layers:

| Method | Weight variance | Best for |
|---|---|---|
| Xavier/Glorot (2010) | 2 / (fan_in + fan_out) | Sigmoid, tanh |
| He/Kaiming (2015) | 2 / fan_in | ReLU and variants |

With proper initialization, the product of Jacobians stays near 1, and gradients neither vanish nor explode — at least at the start of training.
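A quick way to see the difference is to push data through a deep ReLU stack and watch the activation scale. A minimal sketch, assuming a 10-layer, 256-unit, bias-free ReLU MLP with standard-normal inputs; `final_rms` is a hypothetical helper, and the too-small std of 0.01 is an arbitrary bad choice for contrast.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

def final_rms(init_fn, depth=10, width=256):
    """RMS of the activations after `depth` bias-free Linear + ReLU layers."""
    h = torch.randn(512, width)
    with torch.no_grad():
        for _ in range(depth):
            layer = nn.Linear(width, width, bias=False)
            init_fn(layer.weight)
            h = torch.relu(layer(h))
    return h.pow(2).mean().sqrt().item()

he = final_rms(lambda w: nn.init.kaiming_normal_(w, mode='fan_in', nonlinearity='relu'))
tiny = final_rms(lambda w: nn.init.normal_(w, std=0.01))

print(f"He init:       final RMS = {he:.3f}")    # stays of order 1
print(f"std=0.01 init: final RMS = {tiny:.3e}")  # collapses toward 0
```

With He initialization each layer doubles the variance in the linear step and ReLU halves it again, so the scale is preserved; with std 0.01 each layer shrinks the signal roughly ninefold.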

Gradient clipping

For exploding gradients (especially in RNNs), gradient clipping caps the gradient norm. Instead of letting the gradient become arbitrarily large, we rescale it to have a maximum norm:

if ||g|| > threshold: g ← g · threshold / ||g||

This preserves the direction of the gradient while limiting its magnitude. A common threshold is 1.0 or 5.0. Gradient clipping is now standard in training RNNs, Transformers, and any model that might encounter occasional gradient spikes.

Two types of clipping: Gradient norm clipping (above) scales the entire gradient vector to have a maximum L2 norm. Gradient value clipping clips each element independently to [−threshold, +threshold]. Norm clipping is preferred because it preserves the gradient direction. Value clipping can distort the direction by clipping some dimensions but not others.
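Both kinds of clipping are one call each in PyTorch (`torch.nn.utils.clip_grad_norm_` and `torch.nn.utils.clip_grad_value_`). A minimal sketch, using a hand-set gradient on a single tensor so the effect is easy to read: norm clipping keeps the direction (cosine similarity 1), while value clipping bends it.

```python
import torch
import torch.nn as nn

g_raw = torch.tensor([30.0, 4.0, -0.5])          # an "exploding" gradient

def clipped(kind, threshold=1.0):
    p = nn.Parameter(torch.zeros(3))
    p.grad = g_raw.clone()
    if kind == "norm":
        nn.utils.clip_grad_norm_([p], max_norm=threshold)        # rescale whole vector
    else:
        nn.utils.clip_grad_value_([p], clip_value=threshold)     # clamp each element
    return p.grad

for kind in ("norm", "value"):
    g = clipped(kind)
    cos = torch.nn.functional.cosine_similarity(g, g_raw, dim=0).item()
    print(f"{kind:>5}: grad={g.tolist()}  norm={g.norm():.3f}  cos={cos:.4f}")
# norm clipping:  norm ≈ 1.000, cos = 1.0000
# value clipping: grad = [1.0, 1.0, -0.5], cos ≈ 0.7543
```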

Skip connections: the modern solution

The most effective solution to vanishing gradients came from He et al. (2015) with ResNet. The idea: instead of computing y = F(x), compute y = F(x) + x. The "+x" is a skip connection (or residual connection) that provides a direct path for gradient flow.

During backpropagation, the gradient through a skip connection is:

∂y/∂x = ∂F(x)/∂x + I

The identity term I means the upstream gradient is always passed back unchanged through the shortcut, in addition to whatever flows through F. Even if the learned transformation F contributes almost nothing to the gradient, the identity shortcut preserves it. This is why ResNets can train networks with 1000+ layers: the skip connections provide gradient highways.
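A scalar version makes the identity term visible. A minimal sketch, assuming a deliberately gradient-killing branch F(x) = 10⁻⁶·tanh(x) (an arbitrary choice for illustration): without the shortcut the gradient is about 10⁻⁶, with it the gradient is about 1.

```python
import torch

def F(x):
    return 1e-6 * torch.tanh(x)   # a branch whose derivative is nearly 0

x = torch.tensor(0.5, dtype=torch.float64, requires_grad=True)

# Plain layer: y = F(x)
(grad_plain,) = torch.autograd.grad(F(x), x)

# Residual layer: y = F(x) + x
(grad_residual,) = torch.autograd.grad(F(x) + x, x)

print(f"plain:    dy/dx = {grad_plain.item():.3e}")     # ~7.9e-07
print(f"residual: dy/dx = {grad_residual.item():.6f}")  # ~1.000001
```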

Most modern deep architectures use this principle: ResNets, Transformers (both the attention sublayer and the FFN sublayer have residual connections), U-Nets, and Mamba blocks.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# A residual block: the "+x" provides a gradient highway
class ResidualBlock(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(dim, dim),
            nn.ReLU(),
            nn.Linear(dim, dim),
        )

    def forward(self, x):
        return x + self.layers(x)  # the identity path contributes I to dy/dx

# Stack 50 residual blocks: gradients still flow
model = nn.Sequential(*[ResidualBlock(128) for _ in range(50)])
x = torch.randn(1, 128)
out = model(x)
out.sum().backward()

# Check gradient norms every 10 blocks
for i, block in enumerate(model):
    gn = block.layers[0].weight.grad.norm().item()
    if i % 10 == 0:
        print(f"Block {i}: grad_norm={gn:.4f}")
# Exact values depend on the seed; the point is that block 0's gradient stays
# in the same order of magnitude as block 40's instead of vanishing.
```

Without skip connections, a plain 50-layer network can easily show gradient ratios of millions between its first and last layers. With skip connections, the ratio stays within the same order of magnitude. This is why ResNet could train 152 layers in 2015, when previous networks struggled with more than 20.

The gradient highway principle: Nearly every successful deep architecture since 2015 includes some form of gradient highway, a path that lets gradients skip over potentially problematic layers. ResNet uses additive skip connections. Transformers use the same pattern in every sublayer. LSTM uses the cell state as a gradient highway through time. Highway Networks (Srivastava et al., 2015) use gated skip connections. DenseNet concatenates all previous layers' outputs. The specific mechanism varies, but the principle is universal: provide at least one path where gradients can flow without multiplicative degradation.
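
```python
import torch
import torch.nn as nn

# Xavier initialization (for sigmoid/tanh)
layer = nn.Linear(256, 256)
nn.init.xavier_uniform_(layer.weight)

# He initialization (for ReLU)
layer = nn.Linear(256, 256)
nn.init.kaiming_normal_(layer.weight, mode='fan_in', nonlinearity='relu')

# Gradient clipping (for RNNs); toy LSTM, data, and loss for illustration
model = nn.LSTM(input_size=32, hidden_size=64, batch_first=True)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
x = torch.randn(8, 50, 32)          # 8 sequences, 50 timesteps, 32 features
out, _ = model(x)
loss = out.pow(2).mean()            # placeholder loss

optimizer.zero_grad()
loss.backward()
nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)  # rescale if total norm > 1
optimizer.step()
```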

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Demonstrating vanishing vs exploding gradients
def check_gradient_flow(depth, weight_scale, width=50):
    """Return the norm of d(loss)/d(activation) after each ReLU, first layer first.

    Activation gradients are used instead of weight gradients: in a bias-free
    ReLU stack, the shrinking backward signal and the shrinking forward
    activations largely cancel in dL/dW, which hides the per-layer decay.
    """
    layers = []
    for _ in range(depth):
        lin = nn.Linear(width, width, bias=False)
        nn.init.normal_(lin.weight, std=weight_scale)
        layers += [lin, nn.ReLU()]

    h = torch.randn(1, width)
    acts = []
    for layer in layers:
        h = layer(h)
        if isinstance(layer, nn.ReLU):
            h.retain_grad()        # keep .grad on this intermediate tensor
            acts.append(h)
    h.sum().backward()
    return [a.grad.norm().item() for a in acts]

def summarize(norms):
    return f"first={norms[0]:.1e}, last={norms[-1]:.1e}, ratio={norms[0] / norms[-1]:.1e}"

# Weight scale 0.1: each layer shrinks the backward signal (vanishing)
print("Vanishing:", summarize(check_gradient_flow(10, 0.1)))

# Weight scale 0.2 = sqrt(2 / 50), i.e. He init for 50 units: roughly constant
print("Healthy:  ", summarize(check_gradient_flow(10, 0.2)))

# Weight scale 0.5: each layer amplifies the backward signal (exploding)
print("Exploding:", summarize(check_gradient_flow(10, 0.5)))
```

Chapter 4: Batch Normalization

In 2015, Ioffe and Szegedy proposed Batch Normalization (BatchNorm) — a technique that addresses the gradient flow problem from a completely different angle. Instead of fixing the initialization or the activation function, BatchNorm fixes the distribution of inputs to each layer during training.

The idea: before each activation function, normalize the pre-activations to have zero mean and unit variance across the current mini-batch:

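$$\mu_B = \frac{1}{m} \sum_{i=1}^{m} x_i$$
$$\sigma_B^2 = \frac{1}{m} \sum_{i=1}^{m} (x_i - \mu_B)^2$$
$$\hat{x}_i = \frac{x_i - \mu_B}{\sqrt{\sigma_B^2 + \epsilon}}$$
$$y_i = \gamma \hat{x}_i + \beta$$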

The last step introduces learnable parameters γ (scale) and β (shift). This lets the network learn to undo the normalization if that's optimal — BatchNorm doesn't constrain what the network can represent, only how it gets there.
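The four equations translate almost line for line into code. A minimal sketch comparing a hand-written version against `torch.nn.BatchNorm1d` in training mode (which, like the equations, normalizes with the biased 1/m batch variance); γ and β are left at their initial values of 1 and 0.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

def batch_norm_manual(x, gamma, beta, eps=1e-5):
    mu = x.mean(dim=0)                           # mu_B
    var = ((x - mu) ** 2).mean(dim=0)            # sigma_B^2 (biased, 1/m)
    x_hat = (x - mu) / torch.sqrt(var + eps)     # normalize
    return gamma * x_hat + beta                  # scale and shift

x = torch.randn(32, 8) * 3 + 2                   # batch of m = 32, 8 features
bn = nn.BatchNorm1d(8)                           # gamma = 1, beta = 0 at init
bn.train()

ours = batch_norm_manual(x, bn.weight, bn.bias, eps=bn.eps)
theirs = bn(x)
print(torch.allclose(ours, theirs, atol=1e-5))   # True
```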

Why this helps gradient flow: By keeping pre-activations centered around zero, BatchNorm prevents sigmoid/tanh from saturating. With unit variance, most pre-activations fall in the range [−2, +2], exactly where sigmoid's derivative is largest. It's like ensuring every car on the gradient highway stays in the lane with the lowest toll.

The gradient perspective

During backpropagation, BatchNorm has an important property: it makes a layer's output invariant to the scale of its weights. Scaling W by a constant a leaves BN(Wu) unchanged, so the gradient flowing back to the layer's input is unaffected, and the gradient with respect to the weights shrinks by a factor of 1/a. Larger weights therefore receive smaller gradients, a self-correcting mechanism that keeps parameter growth, and with it the gradients, in a healthy range.
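One concrete form of this self-correction, described in the original BatchNorm paper, is scale invariance: multiplying the weights by a leaves the BatchNorm output unchanged and divides the weight gradient by a. A minimal sketch with an illustrative linear layer followed by a hand-written BatchNorm without γ and β; the tiny ε is an assumption that makes the invariance essentially exact (with the usual ε = 1e-5 it holds only approximately).

```python
import torch

torch.manual_seed(0)

def batch_norm(z, eps=1e-12):
    # Normalize each feature over the batch (no gamma/beta)
    mu = z.mean(dim=0, keepdim=True)
    var = ((z - mu) ** 2).mean(dim=0, keepdim=True)
    return (z - mu) / torch.sqrt(var + eps)

x = torch.randn(64, 20, dtype=torch.float64)       # layer inputs u
W = torch.randn(20, 10, dtype=torch.float64)       # weights
target = torch.randn(64, 10, dtype=torch.float64)  # arbitrary regression target

def output_and_weight_grad(scale):
    Ws = (scale * W).requires_grad_()              # weights scaled by a
    out = batch_norm(x @ Ws)
    loss = ((out - target) ** 2).mean()
    loss.backward()
    return out.detach(), Ws.grad

out1, g1 = output_and_weight_grad(1.0)
out3, g3 = output_and_weight_grad(3.0)

print(torch.allclose(out1, out3))        # True: output ignores the weight scale
print((g1.norm() / g3.norm()).item())    # ~3.0: 3x larger weights, 1/3 the gradient
```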

BatchNorm vs No BatchNorm: Activation Distributions

Pre-activation distributions at each layer of a 6-layer network. Without BatchNorm (top), distributions drift and spread, pushing activations into saturation. With BatchNorm (bottom), distributions stay centered and compact.

Internal covariate shift

Ioffe and Szegedy's original motivation was internal covariate shift — the phenomenon where the distribution of inputs to each layer changes during training, because the parameters of previous layers are changing. This makes training harder because each layer has to continually adapt to a shifting input distribution.

BatchNorm fixes this by standardizing the distribution at each layer to zero mean and unit variance, regardless of what the previous layers are doing. The "internal covariate shift" explanation has since been debated (Santurkar et al., 2018, argue that BatchNorm primarily smooths the loss landscape), but the practical benefits are undeniable.

BatchNorm during inference

During training, BatchNorm uses the mean and variance of the current mini-batch. But at test time, you may process one example at a time, with no batch to compute statistics from. The solution: during training, keep a running average of batch means and variances (an exponential moving average). At test time, use these stored statistics instead of batch statistics.

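```python
import torch.nn as nn

model = nn.Sequential(nn.Linear(16, 32), nn.BatchNorm1d(32), nn.ReLU())

# BatchNorm handles this automatically:
model.train()   # uses batch statistics (and updates the running averages)
model.eval()    # uses running average statistics

# Common bug: forgetting model.eval() before inference!
# This causes different behavior between train and test
```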
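To see the running averages at work, and why evaluation mode matters for single examples, here is a minimal sketch with illustrative data whose true mean is 5; it relies on `nn.BatchNorm1d`'s default momentum of 0.1. In training mode PyTorch refuses a batch of one (the batch variance would be meaningless), while evaluation mode handles it with the stored statistics.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
bn = nn.BatchNorm1d(4)

# Training: each forward pass updates running_mean / running_var (EMA)
bn.train()
for _ in range(200):
    bn(torch.randn(32, 4) + 5.0)
print("running_mean:", bn.running_mean)   # close to 5.0 in every feature

# A single example cannot be normalized with batch statistics
raised = False
try:
    bn(torch.randn(1, 4) + 5.0)
except ValueError as err:
    raised = True
    print("train mode, batch of 1:", err)

# Evaluation: the stored running statistics are used instead
bn.eval()
single = bn(torch.randn(1, 4) + 5.0)
print("eval mode output shape:", tuple(single.shape))  # (1, 4)
```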

Practical effects of BatchNorm

| Property | Without BatchNorm | With BatchNorm (typical; problem-dependent) |
|---|---|---|
| Max learning rate | ~0.01 | ~0.1 (10x higher) |
| Initialization sensitivity | High: bad init = no learning | Low: robust to init |
| Training speed | Baseline | ~2-3x faster to converge |
| Regularization | None | Mild (mini-batch noise) |

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Standard pattern: Linear → BatchNorm → ReLU
model = nn.Sequential(
    nn.Linear(784, 256),
    nn.BatchNorm1d(256),  # normalizes across batch dim
    nn.ReLU(),
    nn.Linear(256, 128),
    nn.BatchNorm1d(128),
    nn.ReLU(),
    nn.Linear(128, 10),
)

# Inspect gradient norms across layers
x = torch.randn(32, 784)
out = model(x)
out.sum().backward()
for name, p in model.named_parameters():
    if 'weight' in name:
        print(f"{name}: grad_norm={p.grad.norm():.4f}")
# Names include BatchNorm's gamma ("1.weight", "4.weight") as well as the
# Linear weights; exact values depend on the seed and this toy sum loss

# For CNNs: BatchNorm2d after Conv2d
conv_block = nn.Sequential(
    nn.Conv2d(64, 128, 3, padding=1),
    nn.BatchNorm2d(128),
    nn.ReLU(),
)
```

LayerNorm: BatchNorm's sibling

Layer Normalization (Ba et al., 2016) normalizes across the feature dimension instead of the batch dimension. This avoids BatchNorm's dependency on batch size and works correctly with batch size 1 (common in recurrent and autoregressive models).

LayerNorm: normalize across features (dimension d)
BatchNorm: normalize across batch (dimension B)
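The difference is only which axis the statistics are computed over. A minimal sketch on an illustrative (batch, features) tensor: after `nn.LayerNorm` every row (example) has mean ≈ 0, and after `nn.BatchNorm1d` in training mode every column (feature) does. LayerNorm also works unchanged on a batch of one.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
x = torch.randn(8, 16) * 4 + 3          # B = 8 examples, d = 16 features

ln = nn.LayerNorm(16)                   # statistics over the feature dimension
bn = nn.BatchNorm1d(16)                 # statistics over the batch dimension
bn.train()

ln_out, bn_out = ln(x), bn(x)
print("LayerNorm per-example means:", ln_out.mean(dim=1).abs().max().item())  # ~0
print("BatchNorm per-feature means:", bn_out.mean(dim=0).abs().max().item())  # ~0

# LayerNorm needs no batch statistics, so a single example is fine
print("LayerNorm on one example:", ln(x[:1]).shape)  # torch.Size([1, 16])
```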

Transformers typically use LayerNorm (or close variants such as RMSNorm). The choice between BatchNorm and LayerNorm depends on the architecture:

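| Architecture | Normalization | Why |
|---|---|---|
| CNNs (ResNet) | BatchNorm | Large batches available, spatial invariance helps |
| Transformers (GPT, BERT) | LayerNorm | Per-token statistics, independent of batch size (autoregressive decoding often uses batch size 1) |
| RNNs (LSTM) | LayerNorm | Sequence length varies; batch statistics would be needed per timestep |
| GANs | SpectralNorm (discriminator), InstanceNorm (e.g., image-to-image generators) | Spectral normalization stabilizes discriminator training; InstanceNorm avoids coupling examples through batch statistics |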

Chapter 5: Gradient Checking

So you've implemented a custom layer, loss function, or training loop. How do you know your gradients are correct? Autograd handles standard operations, but the moment you write custom backward passes, implement gradient modifications, or use mixed-precision training, bugs can creep in silently. The loss still decreases — just not as fast as it should — and you have no idea you're leaving performance on the table.

Gradient checking (also called "grad check" or "numerical differentiation") is the gold standard for verifying gradient implementations. The idea: approximate the gradient numerically using the definition of a derivative, then compare it to your analytical gradient.

Finite difference approximation

The centered difference formula approximates the derivative of f at x:

f'(x) ≈ (f(x + ε) − f(x − ε)) / (2ε)

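This is $O(\epsilon^2)$ accurate, much better than the one-sided formula (f(x+ε) − f(x))/ε, which is only $O(\epsilon)$ accurate. Use ε ≈ 1e-5 in double precision (float64): not too small, or floating-point round-off dominates; not too large, or the approximation is poor. In float32, round-off at ε = 1e-5 is already large enough to trigger false alarms.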
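The difference in accuracy is visible on a simple cubic. A minimal sketch in plain Python (floats are float64), assuming f(x) = x³ + 2x at x = 1.5, where f'(1.5) = 8.75. With ε = 1e-3, the one-sided error is about f''(x)·ε/2 = 4.5e-3, while the centered error is exactly ε² = 1e-6 for a cubic.

```python
def f(x):
    return x ** 3 + 2 * x

def f_prime(x):
    return 3 * x ** 2 + 2

x, eps = 1.5, 1e-3
one_sided = (f(x + eps) - f(x)) / eps
centered = (f(x + eps) - f(x - eps)) / (2 * eps)

print(f"true:      {f_prime(x):.9f}")   # 8.750000000
print(f"one-sided: {one_sided:.9f}")    # about 8.754501 (error ~4.5e-3)
print(f"centered:  {centered:.9f}")     # about 8.750001 (error ~1e-6)
```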

The practical recipe: For each parameter w, (1) compute loss at w + ε, (2) compute loss at w − ε, (3) numerical gradient = (loss+ − loss−) / (2ε), (4) compare to your analytical gradient. If the relative error is below 1e-5, you're good. Above 1e-3, you have a bug.

When to run gradient checking

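Gradient checking is slow: it needs 2P forward passes for P parameters. Run it on a small model or a small random subset of parameters, on a few examples, and whenever you have written something autograd cannot vouch for, such as a custom backward pass.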

Relative error metric

Don't compare absolute differences — they're scale-dependent. Use relative error:

$$\text{relative\_error} = \frac{|g_{\text{analytic}} - g_{\text{numeric}}|}{\max(|g_{\text{analytic}}|, |g_{\text{numeric}}|)}$$
Relative error | Interpretation
< 1e-7 | Excellent: implementation is correct
1e-7 to 1e-5 | Good: might be fine, check edge cases
1e-5 to 1e-3 | Suspicious: likely a bug, investigate
> 1e-3 | Bug: your gradient is wrong
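A sketch of the metric together with the table's thresholds. The 1e-12 floor in the denominator is an assumption added to avoid dividing by zero when both gradients are exactly 0, and the example values are made up. The second call shows why absolute differences mislead: gradients of 1e-3 and 2e-3 differ by only 1e-3 in absolute terms, yet one is twice the other.

```python
def relative_error(g_analytic, g_numeric, floor=1e-12):
    # floor avoids 0/0 when both gradients are exactly zero (an assumed choice)
    return abs(g_analytic - g_numeric) / max(abs(g_analytic), abs(g_numeric), floor)

def interpret(rel_err):
    # Thresholds taken from the table above
    if rel_err < 1e-7:
        return "excellent"
    if rel_err < 1e-5:
        return "good"
    if rel_err < 1e-3:
        return "suspicious"
    return "bug"

print(interpret(relative_error(8.75, 8.7500000001)))  # "excellent"
print(interpret(relative_error(1e-3, 2e-3)))          # "bug": 50% off
```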
Gradient Check: Analytical vs. Numerical

A simple function f(w) = w³ + 2w. The analytical derivative is 3w² + 2. The numerical derivative is computed via centered difference. Drag ε to see how the numerical approximation quality changes. Watch the relative error.
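The slider experiment can be reproduced offline. This sketch runs in float64 at w = 1.5 (the slider's default). For this cubic, the centered difference works out to exactly 3w² + 2 + ε² in exact arithmetic. The error therefore shrinks as ε shrinks, until round-off in f(w ± ε) takes over at very small ε. It typically bottoms out somewhere around ε ≈ 1e-5 to 1e-6.

```python
def f(w):
    return w ** 3 + 2 * w

def f_prime(w):
    return 3 * w ** 2 + 2          # analytical derivative

w = 1.5
errors = {}
for k in range(1, 13):             # eps = 1e-1 ... 1e-12
    eps = 10.0 ** -k
    numeric = (f(w + eps) - f(w - eps)) / (2 * eps)
    analytic = f_prime(w)
    errors[k] = abs(numeric - analytic) / max(abs(numeric), abs(analytic))
    print(f"log10(eps)={-k:3d}  relative error={errors[k]:.1e}")
```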

python
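import torch

def gradient_check(model, loss_fn, x, y, eps=1e-5, num_checks=10, tol=1e-5):
    """Check analytical vs numerical gradients on a random subset of entries.

    Run in float64 (model.double(), x.double()) with dropout disabled: in
    float32, round-off at eps=1e-5 swamps the comparison. Returns True if
    every checked entry passes.
    """
    # Get analytical gradients from autograd
    model.zero_grad()
    loss = loss_fn(model(x), y)
    loss.backward()

    all_ok = True
    for name, param in model.named_parameters():
        if param.grad is None:
            continue
        analytic = param.grad.detach().clone().view(-1)
        flat = param.data.view(-1)  # shares storage: writes perturb the model

        # Check a random subset (a full check costs 2P forward passes)
        indices = torch.randperm(flat.numel())[:num_checks].tolist()

        param_ok = True
        with torch.no_grad():  # numerical passes need no autograd graph
            for idx in indices:
                orig = flat[idx].item()
                flat[idx] = orig + eps
                loss_plus = loss_fn(model(x), y).item()
                flat[idx] = orig - eps
                loss_minus = loss_fn(model(x), y).item()
                flat[idx] = orig  # restore the original weight

                num_grad = (loss_plus - loss_minus) / (2 * eps)
                ana_grad = analytic[idx].item()

                rel_err = abs(num_grad - ana_grad) / max(abs(num_grad), abs(ana_grad), 1e-8)
                if rel_err > tol:
                    param_ok = False
                    print(f"⚠ {name}[{idx}]: rel_err={rel_err:.2e}")

        # Only report success if no checked entry failed
        if param_ok:
            print(f"✓ {name}: gradient check passed")
        all_ok = all_ok and param_ok
    return all_ok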
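The silent-bug scenario from the start of this chapter (a hand-written backward pass) is exactly where grad check pays off. Below is a hypothetical custom autograd.Function computing x³, in one version with a correct backward and in another with a deliberately wrong one. Both are checked with PyTorch's built-in torch.autograd.gradcheck, using double-precision inputs as gradcheck expects.

```python
import torch

class Cube(torch.autograd.Function):
    """Hypothetical custom op y = x**3 with a correct backward."""
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return x ** 3

    @staticmethod
    def backward(ctx, grad_out):
        (x,) = ctx.saved_tensors
        return grad_out * 3 * x ** 2

class CubeBuggy(torch.autograd.Function):
    """Same op, but the backward has a deliberate bug."""
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return x ** 3

    @staticmethod
    def backward(ctx, grad_out):
        (x,) = ctx.saved_tensors
        return grad_out * 2 * x ** 2   # BUG: should be 3 * x**2

torch.manual_seed(0)
x = torch.randn(6, dtype=torch.double, requires_grad=True)
print(torch.autograd.gradcheck(Cube.apply, (x,)))                              # True
print(torch.autograd.gradcheck(CubeBuggy.apply, (x,), raise_exception=False))  # False
```

Training with CubeBuggy would still reduce the loss, since the gradient points in roughly the right direction. Only the numerical comparison exposes the bug.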

Chapter 6: Debugging Showcase

Let's put it all together. You're training a network and something is wrong. Here's the systematic debugging workflow that Karpathy recommends, enriched with everything we've learned.

Gradient Pathology Debugger

Train a 6-layer network with different configurations and watch for gradient pathologies. The left panel shows per-layer gradient magnitudes (log scale). The right panel shows the loss curve. Toggle pathological settings to see the failures, then apply fixes.


The update-to-weight ratio

Karpathy's single most useful diagnostic: the update-to-weight ratio. For each parameter tensor, compute:

ratio = η · ||∇w|| / ||w||

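This ratio tells you how much the weights change relative to their current magnitude on each update step. The sweet spot is around 1e-3: weights change by about 0.1% per step. Strictly, η · ∇w is the actual update only for plain SGD; with momentum or Adam, measure the real change in the weights instead.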
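As a worked example with illustrative numbers (not from the source), a learning rate of 1e-2, a gradient norm of 0.5, and a weight norm of 5 land exactly on the sweet spot:

```latex
\eta = 10^{-2},\quad \lVert \nabla w \rVert = 0.5,\quad \lVert w \rVert = 5
\;\Rightarrow\;
\text{ratio} = \frac{\eta \, \lVert \nabla w \rVert}{\lVert w \rVert}
             = \frac{10^{-2} \times 0.5}{5} = 10^{-3}
```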

Ratio | Meaning | Action
> 1e-1 | Weights changing by 10%+ per step | Reduce learning rate or clip gradients
~1e-3 | Healthy: weights change by 0.1% per step | Good, keep going
< 1e-5 | Weights barely changing | Increase learning rate or check for vanishing gradients
~0 | Weights frozen: dead gradients | Check activation saturation, dead ReLUs

This one metric catches most training pathologies. If you add only one diagnostic to your training loop, make it this one.

python
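# Implementing the update-to-weight ratio monitor
# (lr * grad equals the real step only for plain SGD; momentum/Adam rescale it)
class GradientMonitor:
    def __init__(self, model, optimizer):
        self.model = model
        self.optimizer = optimizer

    def log_ratios(self):
        # Read the *current* lr from each param group: optimizer.defaults
        # keeps the construction-time value and ignores LR schedulers.
        names = {p: n for n, p in self.model.named_parameters()}
        for group in self.optimizer.param_groups:
            lr = group['lr']
            for p in group['params']:
                if p.grad is None:
                    continue
                update_norm = (lr * p.grad).norm().item()
                weight_norm = p.detach().norm().item()
                ratio = update_norm / (weight_norm + 1e-8)
                # Log to your favorite tool
                print(f"{names.get(p, '?')}: ratio={ratio:.2e}"
                      f" {'✓' if 1e-4 < ratio < 1e-2 else '⚠'}")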
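The monitor above uses lr · grad, which is the real step only for plain SGD; momentum and Adam rescale the gradient before applying it. The sketch below sidesteps this by snapshotting the weights and measuring the actual change across optimizer.step(). The function name step_and_measure is hypothetical.

```python
import torch
import torch.nn as nn

def step_and_measure(model, optimizer):
    """Call optimizer.step() and return {param_name: ||Δw|| / ||w||}.

    Assumes loss.backward() has already populated the gradients.
    """
    before = {n: p.detach().clone() for n, p in model.named_parameters()}
    optimizer.step()
    ratios = {}
    for n, p in model.named_parameters():
        delta = (p.detach() - before[n]).norm().item()  # size of the real update
        ratios[n] = delta / (before[n].norm().item() + 1e-8)
    return ratios

# Usage with Adam, where lr * grad would misstate the step size
torch.manual_seed(0)
model = nn.Linear(16, 4)
opt = torch.optim.Adam(model.parameters(), lr=1e-3)
x, y = torch.randn(32, 16), torch.randn(32, 4)
nn.functional.mse_loss(model(x), y).backward()
for name, r in step_and_measure(model, opt).items():
    print(f"{name}: update/weight = {r:.2e}")
```

The snapshot costs one extra copy of the parameters, so running it every N steps rather than every step is a reasonable trade-off.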

The debugging checklist

1. Check loss at init
For N classes with softmax and cross-entropy, the initial loss should be about −ln(1/N) = ln N (≈ 2.30 for N = 10). If it is much higher, the weights or biases are initialized wrong.
2. Overfit tiny batch
Train on 5-10 examples. Loss should reach ~0. If not, the model or training loop has a bug.
3. Monitor gradient magnitudes
Log ||grad|| per layer per epoch. If early layers have gradients 1000x smaller than late layers, you have vanishing gradients.
4. Check activation distributions
Log mean and std of activations per layer. If saturated (all near 0 or 1 for sigmoid), learning is stalled.
5. Count dead units
For ReLU networks, check what fraction of units output 0 for all inputs. If >10%, reduce learning rate or use Leaky ReLU.
6. Gradient check custom code
Any custom loss, layer, or training modification: run a gradient check in double precision with ε = 1e-5 and require relative error < 1e-5.
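Steps 1 and 2 are cheap enough to automate. Below is a minimal sketch for a classifier trained with cross-entropy. The helper names, the 0.5 tolerance, the Adam settings, and the toy MLP are all assumptions for illustration.

```python
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

def expected_init_loss(num_classes):
    # A uniform softmax over N classes gives cross-entropy -ln(1/N) = ln(N)
    return -math.log(1.0 / num_classes)

def check_init_loss(model, x, y, num_classes, tol=0.5):
    """Step 1: compare the untrained loss to ln(N). tol is an arbitrary choice."""
    with torch.no_grad():
        loss = F.cross_entropy(model(x), y).item()
    expected = expected_init_loss(num_classes)
    print(f"init loss {loss:.3f}, expected ≈ {expected:.3f}")
    return abs(loss - expected) < tol

def overfit_tiny_batch(model, x, y, steps=500, lr=1e-2):
    """Step 2: a healthy model + loop should drive the loss on a few examples to ~0."""
    opt = torch.optim.Adam(model.parameters(), lr=lr)
    for _ in range(steps):
        opt.zero_grad()
        loss = F.cross_entropy(model(x), y)
        loss.backward()
        opt.step()
    return loss.item()

torch.manual_seed(0)
model = nn.Sequential(nn.Linear(20, 64), nn.ReLU(), nn.Linear(64, 10))
x, y = torch.randn(8, 20), torch.randint(0, 10, (8,))
print("init loss OK:", check_init_loss(model, x, y, num_classes=10))
print(f"loss after overfitting 8 examples: {overfit_tiny_batch(model, x, y):.4f}")
```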
python
import torch

def debug_training(model, dataloader, optimizer, loss_fn, epochs=5):
    """Training loop with gradient diagnostics."""
    model.train()
    for epoch in range(epochs):
        for x, y in dataloader:
            optimizer.zero_grad()
            out = model(x)
            loss = loss_fn(out, y)
            loss.backward()

            # === DIAGNOSTIC 1: gradient magnitudes ===
            for name, p in model.named_parameters():
                if p.grad is not None:
                    grad_norm = p.grad.norm().item()
                    if grad_norm < 1e-7:
                        print(f"⚠ VANISHING: {name} grad_norm={grad_norm:.2e}")
                    if grad_norm > 100:
                        print(f"⚠ EXPLODING: {name} grad_norm={grad_norm:.2e}")

            # === DIAGNOSTIC 2: update-to-weight ratio ===
            # Current lr (optimizer.defaults ignores LR schedulers); lr * grad
            # is the exact step only for plain SGD.
            lr = optimizer.param_groups[0]['lr']
            for name, p in model.named_parameters():
                if p.grad is not None:
                    update_mag = (lr * p.grad).norm().item()
                    weight_mag = p.detach().norm().item()
                    ratio = update_mag / (weight_mag + 1e-8)
                    # Healthy ratio: ~1e-3. Too high = unstable. Too low = slow.
                    if ratio > 0.1:
                        print(f"⚠ {name}: update/weight = {ratio:.2e} (too large)")

            optimizer.step()

When autograd actually fails

Karpathy gives specific examples of when autograd computes technically correct but practically useless gradients:

These are all cases where the code runs without errors, the loss has a finite value, autograd returns gradients — but the gradients are useless. Only someone who understands the backward pass can predict these failures.

A related class of problems involves accidental zero-gradient paths, where a model's architecture inadvertently blocks gradient flow. For example, torch.where(condition, a, b) passes gradient only to the branch selected at each position, so the switch between branches is a gradient discontinuity and the unselected branch receives nothing. Or using .detach() in the wrong place silently cuts the gradient graph. These bugs produce models that "kind of work" but never reach their potential.
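Both pitfalls are easy to reproduce. In the minimal sketch below, the function names are hypothetical and made up for illustration. The first function shows a related torch.where trap: a NaN computed in the masked-out branch still leaks into the gradient. The second shows .detach() dropping one of two gradient paths. The "double where" fix is a common community workaround, not an official API.

```python
import torch

def where_sqrt_grad(safe):
    x = torch.tensor([-1.0, 4.0], requires_grad=True)
    if safe:
        # "double where": give sqrt a harmless input where it is masked out
        x_ok = torch.where(x > 0, x, torch.ones_like(x))
        y = torch.where(x > 0, torch.sqrt(x_ok), torch.zeros_like(x))
    else:
        # sqrt(-1) = nan is computed, then masked out in the forward pass
        y = torch.where(x > 0, torch.sqrt(x), torch.zeros_like(x))
    y.sum().backward()
    # In the unsafe case, sqrt's backward computes 0 / (2 * nan) = nan for x = -1
    return x.grad

def detach_grad():
    w = torch.tensor(2.0, requires_grad=True)
    h = (3 * w).detach()   # graph cut: h is treated as a constant
    loss = h * w           # true loss is 3w^2, so dL/dw should be 6w = 12
    loss.backward()
    return w.grad          # only the direct path survives: dL/dw = h = 6

print(where_sqrt_grad(safe=False))  # nan for x = -1, 0.25 for x = 4
print(where_sqrt_grad(safe=True))   # 0 for x = -1, 0.25 for x = 4
print(detach_grad())                # tensor(6.)
```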


Chapter 7: Connections

Karpathy's blog post was written in 2016 but remains relevant today. The specific failure modes he describes — saturated activations, dead units, vanishing/exploding gradients — are fundamental to how backpropagation works. They don't go away with better frameworks; they're inherent in the math.

Karpathy, a founding member of OpenAI, went on to lead the AI team behind Tesla Autopilot and later returned to OpenAI. His consistent message: understanding the fundamentals, especially gradient flow, is what separates engineers who can debug models from those who can only run scripts. The principles in this blog post apply at every scale, from MNIST to GPT-4.

In modern large-scale training (GPT-4 class models), the same pathologies appear but at different scales. Gradient explosions during pre-training can waste thousands of GPU-hours. Dead attention heads (the Transformer equivalent of dead ReLUs) reduce model capacity. Loss spikes from numerical instability require careful mixed-precision engineering. All of these are gradient flow problems — the same problems Karpathy described in 2016, just with more zeros in the parameter count.

What changed since 2016

Problem | 2016 solution | Modern solution
Vanishing gradients | ReLU + proper init | + Residual connections (skip), LayerNorm, GELU
Exploding gradients | Gradient clipping | + Learning rate warmup, gradient scaling (mixed precision)
Dead ReLUs | Leaky ReLU, careful LR | GELU (now standard in Transformers), SiLU
Sigmoid saturation | Don't use sigmoid | Sigmoid still used only for gates (LSTM, Mamba)
Gradient correctness | Manual grad check | torch.autograd.gradcheck() built-in

What hasn't changed

Modern gradient monitoring

In 2024-2025, the best practice is to log gradient statistics during training using tools like Weights & Biases or TensorBoard. The key metrics to track:

Metric | What to watch for | Tool
||∇W|| per layer | Should be similar across layers (within 10x) | wandb.log({"grad_norm/layer_0": ...})
Update/weight ratio | Should be ~1e-3 for all layers | Custom hook on optimizer.step()
Activation distribution | Should not drift to saturated values | Forward hook logging mean/std
Dead unit count | Should be <5% for ReLU networks | Forward hook counting zero outputs
Loss NaN/Inf | Stop immediately; often a gradient explosion | not torch.isfinite(loss).all()
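The two forward-hook rows can be sketched with register_forward_hook. The sketch assumes ReLU is used as an nn.ReLU module, since functional F.relu calls are invisible to module hooks. It also assumes 2-D (batch, features) activations. The helper name is hypothetical; the 5% threshold comes from the table.

```python
import torch
import torch.nn as nn

def attach_activation_monitors(model, stats):
    """Register a forward hook on every nn.ReLU module.

    After each forward pass, stats[name] holds the mean/std of that layer's
    output and the fraction of units that were zero for the whole batch.
    """
    handles = []
    for name, module in model.named_modules():
        if isinstance(module, nn.ReLU):
            def hook(mod, inputs, output, name=name):  # bind name at definition time
                out = output.detach()
                dead = (out <= 0).all(dim=0).float().mean().item()
                stats[name] = {"mean": out.mean().item(),
                               "std": out.std().item(),
                               "dead_frac": dead}
            handles.append(module.register_forward_hook(hook))
    return handles  # call .remove() on each handle to stop monitoring

# Usage
torch.manual_seed(0)
net = nn.Sequential(nn.Linear(20, 64), nn.ReLU(), nn.Linear(64, 10), nn.ReLU())
stats = {}
handles = attach_activation_monitors(net, stats)
net(torch.randn(128, 20))
for name, s in stats.items():
    flag = "⚠" if s["dead_frac"] > 0.05 else "✓"
    print(f"layer {name}: mean={s['mean']:.3f} std={s['std']:.3f} "
          f"dead={s['dead_frac']:.1%} {flag}")
for h in handles:
    h.remove()
```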

Modern frameworks also provide built-in tools: torch.nn.utils.clip_grad_norm_ for clipping, torch.autograd.detect_anomaly() for finding NaN sources, and torch.autograd.gradcheck() for numerical verification.

python
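# Essential PyTorch gradient debugging tools
import torch
import torch.nn as nn

model = nn.Linear(4, 2)            # stand-in model for illustration
criterion = nn.MSELoss()
inputs, target = torch.randn(8, 4), torch.randn(8, 2)

# 1. Detect NaN/Inf sources (slow: enable only while debugging)
with torch.autograd.detect_anomaly():
    output = model(inputs)
    loss = criterion(output, target)
    loss.backward()  # if a backward op yields NaN, raises and shows the forward traceback

# 2. Numerical gradient check (expects float64 inputs with requires_grad=True)
def my_custom_fn(x):               # stand-in for your custom op
    return (x ** 3).sum()

x = torch.randn(5, dtype=torch.double, requires_grad=True)
torch.autograd.gradcheck(my_custom_fn, (x,), eps=1e-6)  # returns True or raises

# 3. Gradient clipping with logging (returns the norm before clipping)
total_norm = nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
if total_norm > 1.0:
    print(f"Clipped! Original norm: {total_norm:.2f}")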
Karpathy's lasting principle: "Yes, you should understand backprop." This isn't about computing gradients by hand forever — it's about building the mental model that lets you predict what will happen when you make an architecture choice. Will GELU have better gradient flow than ReLU? (Yes, slightly.) Will a 100-layer network without skip connections train? (No.) Will batch size 1 break BatchNorm? (Yes.) These predictions require understanding the mechanics, not just the API.


The practitioner's gradient knowledge stack

Karpathy's implicit curriculum for neural network practitioners, from most to least critical:

  1. Know what gradients should look like. Update/weight ratio ~1e-3. Gradient norms similar across layers. No NaN or Inf.
  2. Know common failure modes. Sigmoid saturation, dead ReLUs, vanishing/exploding gradients. Recognize them from the loss curve alone.
  3. Know the fixes. BatchNorm, proper initialization (He/Xavier), gradient clipping, skip connections, Leaky ReLU.
  4. Know how to verify. Gradient checking, activation histograms, the overfit-small-batch test.
  5. Know when autograd isn't enough. Custom backward passes, non-differentiable operations, numerical stability.

You don't need to compute gradients by hand for production code. But you need to think in gradients: look at an architecture diagram and predict where gradient flow will succeed or fail. When someone proposes adding 10 sigmoid layers, you should immediately think "vanishing gradients." When someone suggests a learning rate of 1.0 with ReLU, you should think "dead units." When someone removes all skip connections from a Transformer, you should think "gradients won't reach the embedding layer."
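The "10 sigmoid layers" prediction is easy to check empirically. The sketch below compares the gradient norm of the first layer's weights to that of the last layer; the depth, width, and loss are arbitrary choices. With sigmoid, whose derivative is at most 0.25, the ratio collapses by orders of magnitude. With ReLU and He initialization, it stays far larger.

```python
import torch
import torch.nn as nn

def first_to_last_grad_ratio(activation, depth=10, width=32, he_init=False, seed=0):
    torch.manual_seed(seed)
    layers = []
    for _ in range(depth):
        lin = nn.Linear(width, width)
        if he_init:
            nn.init.kaiming_normal_(lin.weight, nonlinearity="relu")
            nn.init.zeros_(lin.bias)
        layers += [lin, activation()]
    net = nn.Sequential(*layers)
    x = torch.randn(64, width)
    net(x).pow(2).mean().backward()
    first = net[0].weight.grad.norm().item()
    last = net[-2].weight.grad.norm().item()  # net[-1] is the final activation
    return first / last

sig = first_to_last_grad_ratio(nn.Sigmoid)
relu = first_to_last_grad_ratio(nn.ReLU, he_init=True)
print(f"10 sigmoid layers:       first/last grad ratio = {sig:.1e}")
print(f"10 ReLU layers, He init: first/last grad ratio = {relu:.1e}")
```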

This intuition — gradient intuition — is what separates a competent ML engineer from someone who just runs training scripts.

And it starts with understanding backpropagation at the level Karpathy describes in this post. Not because you'll compute gradients by hand, but because you'll know what to look for when things go wrong.

"The most dangerous thing about autograd is that it makes it trivially easy to do the wrong thing." — Andrej Karpathy
