Training Foundations

Normalization

The layer that tamed deep networks — from Batch Normalization to the RMSNorm inside every modern LLM.

Prerequisites: basic algebra (mean, variance) and what a neural network layer does. That's it.

10 Chapters · 12+ Simulations · 0 Assumed Knowledge

Chapter 0: Why Normalize?

You're training a 10-layer network. After 100 batches, you print the activation statistics at each layer. Layer 1 looks fine — values between -2 and 2. Layer 5 is already weird — values clustered near zero. Layer 10? All zeros or all infinities. Your network is dead.

This isn't a hypothetical. Before 2015, training networks deeper than ~10 layers was a nightmare. Researchers spent weeks tuning learning rates, initialization schemes, and gradient clipping heuristics. The problem was fundamental: deep networks are chains of matrix multiplications, and chains of multiplications do one of two things — explode or collapse.

Let's see exactly why.

The Multiplication Chain Problem

A neural network layer does this: take an input vector, multiply it by a weight matrix, add a bias, and apply a nonlinearity. The output becomes the input to the next layer. Each multiplication either stretches or shrinks the numbers. Stack enough stretches, and values hit infinity. Stack enough shrinks, and values hit zero.

Let's trace this by hand with a tiny example. We'll skip the nonlinearity to see the raw multiplication effect.

Hand calculation setup. We have 3 linear layers, each with a 2×2 weight matrix. Input vector: [1.0, -0.5]. No bias, no activation function — just raw matrix multiplies. Let's see what happens to the numbers.

Hand Calculation: Activation Explosion

Layer 1: Weight matrix W1 = [[1.5, 0.8], [0.6, 1.3]].

Input: [1.0, -0.5]

After layer 1: [1.1, -0.05]. Looks reasonable. Values are still small.

Layer 2: Weight matrix W2 = [[1.4, 0.9], [0.7, 1.6]].

Input: [1.1, -0.05]

After layer 2: [1.495, 0.69]. Still okay, but the magnitudes are growing.

Layer 3: Weight matrix W3 = [[1.3, 1.1], [0.5, 1.4]].

Input: [1.495, 0.69]

After layer 3: [2.70, 1.71]. In just 3 layers, the values roughly tripled.
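To check the arithmetic, here is a minimal NumPy sketch of the same three multiplies. It follows the hand calculation's convention of multiplying each weight matrix by the input as a column vector (W @ x), and it also prints the vector's length so the growth shows up as a single number.

```python
import numpy as np

# The three 2x2 weight matrices from the hand calculation
W1 = np.array([[1.5, 0.8], [0.6, 1.3]])
W2 = np.array([[1.4, 0.9], [0.7, 1.6]])
W3 = np.array([[1.3, 1.1], [0.5, 1.4]])

x = np.array([1.0, -0.5])
print(f"input:         {x}, length = {np.linalg.norm(x):.3f}")
for i, W in enumerate([W1, W2, W3], start=1):
    x = W @ x  # raw matrix multiply: no bias, no nonlinearity
    print(f"after layer {i}: {np.round(x, 4)}, length = {np.linalg.norm(x):.3f}")
```

The length goes from about 1.12 to about 3.20, which is the "roughly tripled" above.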

Three layers, and we've already tripled the magnitude. Now imagine 50 layers. Or 100. If the weight matrices' dominant eigenvalue λ (the largest in magnitude) exceeds 1.0, which is easy to hit with poorly scaled random initialization, each layer multiplies the scale. After n layers, the activation magnitude grows roughly as λ^n. If λ = 1.2:

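| Layers | Scale factor (1.2^n) | What happens |
|---|---|---|
| 5 | 2.49 | Fine |
| 10 | 6.19 | Getting large |
| 20 | 38.3 | Saturating activations |
| 50 | 9,100 | Values in the thousands, within an order of magnitude of the fp16 limit (65,504) |
| 100 | 82,817,974 | Overflows fp16 outright (inf, then NaN); enormous gradients even in fp32 |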

And if λ = 0.8 (weights slightly too small), the same table runs in reverse: after 50 layers, activations are scaled by 0.8^50 ≈ 0.000014. Everything is effectively zero, so any bias terms downstream dominate and the network outputs nearly the same thing regardless of input.
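Both tables are just λ^n under the idealized assumption that every layer scales the signal by exactly λ (real layers vary). A few lines of Python reproduce the λ = 1.2 table and the λ = 0.8 case; the float16 maximum (65,504) is included only as a reference point for where half-precision runs out of room.

```python
# Idealized model: every layer multiplies the activation scale by exactly lam
scales = {}
for lam in (1.2, 0.8):
    for n in (5, 10, 20, 50, 100):
        scales[(lam, n)] = lam ** n
        print(f"lambda={lam}, {n:>3} layers: scale = {lam ** n:.6g}")

FP16_MAX = 65504.0  # largest finite float16 value
print(scales[(1.2, 100)] > FP16_MAX)  # True: past the fp16 range
```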

Internal Covariate Shift

But the explosion/collapse problem is only half the story. There's a subtler issue: even when activations don't explode, their distribution keeps changing.

Think about it from layer 5's perspective. Layer 5 receives its input from layer 4. During training, layer 4's weights are updated every batch. That means the distribution of layer 4's output — which is layer 5's input — changes every single step. Layer 5 is trying to learn a mapping, but the input it's mapping from is a moving target.

This phenomenon is called internal covariate shift: the distribution of each layer's inputs shifts continuously during training because the parameters of the preceding layers change.

Imagine you're learning to hit a baseball, but someone keeps moving the pitcher's mound. You'd adapt, but slowly and painfully. That's what every layer in a deep network experiences.

This is NOT the vanishing gradient problem. Gradients vanish because of chain-rule multiplication in backprop — each layer multiplies the gradient by its local Jacobian, and if those Jacobians have small norms, the gradient shrinks exponentially as it flows backward. Activations drift because each layer sees a different input distribution every time the previous layer updates its weights. These are related but distinct problems. Normalization fixes the activation distribution problem, which also helps gradients as a side effect — but the root causes are different.

What Does a Healthy Activation Look Like?

A healthy activation distribution has two properties: it is centered, and its spread is controlled.

The ideal: each layer's activations have mean ≈ 0 and standard deviation ≈ 1. This keeps the signal strong but bounded, and keeps nonlinearities in their sensitive region (neither saturated nor dead).

The Fix in One Sentence

What if, after every layer, we forcefully rescaled the activations to have mean 0 and standard deviation 1? That's normalization. That single idea, shoved into the middle of the network as a differentiable operation, is a big part of what unlocked deep learning at scale.

Before normalization (2015 and earlier), training a 20-layer network was a research contribution. After it, 100+ layer networks became routine. ResNets (152 layers on ImageNet), Transformers such as GPT-3 (96 blocks), and today's LLMs all use normalization in every single block.

See It With Your Own Eyes

The simulation below shows activation distributions at layer 8 of a deep network during training. Without normalization, watch the histogram drift, collapse, or explode. Toggle normalization on and see it snap to a healthy bell curve.


Notice the pattern. Without normalization, early in training (step 0-20), things look okay. But by step 40-60, the distribution has shifted far from zero — the mean drifts, the variance balloons. By step 80-100, you see one of two failure modes: either the histogram collapses to a spike near zero (vanishing activations) or spreads so wide that most values are in the extreme tails (exploding activations).

With normalization? The histogram stays centered at zero with unit variance at every single step. The shape may change slightly (the network is still learning), but the scale and location are locked. This is what lets the next layer learn a stable mapping.

From Scratch: Watching the Drift

python
import torch
import torch.nn as nn

# 10-layer network with NO normalization
torch.manual_seed(42)
x = torch.randn(64, 256)  # batch of 64, 256 features

for i in range(10):
    W = torch.randn(256, 256) * 0.5  # random weights
    x = x @ W                        # linear transform
    x = torch.relu(x)                # activation
    print(f"Layer {i}: mean={x.mean():.4f}, std={x.std():.4f}")

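Run this yourself. The std doesn't just drift; with this setup it explodes, growing roughly 5-6× per layer until it is around 10^7 by the last layer. The culprit is the weight scale: 0.5 is far too large for a 256-wide layer, so every layer multiplies the variance. Shrink the scale too far (say, to 0.05) and you get the opposite failure, with activations collapsing toward zero. The exact numbers depend on the random seed, but the pattern (instability) is universal.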
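The outcome of that experiment is driven by the weight scale, not by luck. The sketch below reruns the same 10-layer ReLU network with three scales: 0.5 (as above), 0.05, and √(2/256) ≈ 0.088, the He-initialization scale that roughly preserves variance through ReLU layers of width 256. The helper `final_std` is a hypothetical name written for this comparison.

```python
import math
import torch

def final_std(scale, layers=10, width=256, batch=64, seed=42):
    """Hypothetical helper: run `layers` linear+ReLU layers with no normalization,
    weights drawn as randn * scale, and return the std of the final activations."""
    torch.manual_seed(seed)
    x = torch.randn(batch, width)
    for _ in range(layers):
        W = torch.randn(width, width) * scale
        x = torch.relu(x @ W)
    return x.std().item()

he_scale = math.sqrt(2 / 256)  # He initialization for ReLU: ~0.088
for scale in (0.5, 0.05, he_scale):
    print(f"scale = {scale:.3f}: final std = {final_std(scale):.3g}")
```

Expect the first to blow up, the second to shrink toward zero, and the third to stay in a sane range. Careful initialization helps at step 0, but on its own it does nothing to hold the scale there once training starts updating the weights.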

Now add one line per layer:

python
import torch

torch.manual_seed(42)
x = torch.randn(64, 256)

for i in range(10):
    W = torch.randn(256, 256) * 0.5
    x = x @ W
    # Normalize: subtract mean, divide by std
    x = (x - x.mean()) / (x.std() + 1e-5)
    x = torch.relu(x)
    print(f"Layer {i}: mean={x.mean():.4f}, std={x.std():.4f}")

Now the mean hovers near 0.4 (not exactly zero because ReLU clips negatives) and the std stays near 0.6 — stable, layer after layer. The network can actually learn.

The 1e-5 matters. We added a tiny epsilon to the denominator: (x - x.mean()) / (x.std() + 1e-5). Without it, if all the values being normalized are identical (std = 0), we divide zero by zero and get NaN. An epsilon appears in every normalization variant (BatchNorm, LayerNorm, RMSNorm); library layers add it to the variance inside the square root rather than to the std. Its default is usually 1e-5 or 1e-6 (BERT's LayerNorm uses 1e-12): small enough not to distort the normalization, large enough to prevent catastrophe.
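A quick demonstration of why the epsilon is there, using a hypothetical feature whose eight values are all 3.0. It also shows the variance form (ε inside the square root) that library layers use.

```python
import torch

x = torch.full((8,), 3.0)  # hypothetical feature: every value identical, so std = 0

no_eps = (x - x.mean()) / x.std()                                     # 0 / 0
with_eps = (x - x.mean()) / (x.std() + 1e-5)                          # 0 / 1e-5
var_form = (x - x.mean()) / torch.sqrt(x.var(unbiased=False) + 1e-5)  # epsilon inside the sqrt

print(no_eps)    # all nan
print(with_eps)  # all zeros
print(var_form)  # all zeros
```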

The Four Flavors

Over the next chapters, we'll build four normalization methods from scratch. They all do fundamentally the same thing — center and scale activations — but they differ in which activations they compute the statistics over:

Chapter 1: Mean & Variance (the shared math behind ALL normalization)
Chapter 2: Batch Norm (statistics across the batch)
Chapter 3: Train/Eval split (the #1 BN bug)
Chapter 4: Layer Norm (statistics across features)
Chapters 5-7: Group Norm, RMSNorm, and the showcase sim

By the end, you'll understand exactly what happens inside nn.LayerNorm and nn.BatchNorm1d, why many modern LLMs use RMSNorm instead of LayerNorm, and how to pick the right normalization for any architecture.

But first, we need to understand the two operations that every normalizer shares: centering and scaling.


Chapter 1: The Centering Problem — Mean and Variance

Before we build any specific normalizer (BatchNorm, LayerNorm, RMSNorm), we need to understand the two operations they share. Nearly every normalization method in deep learning does exactly two things: subtract the mean and divide by the standard deviation. (RMSNorm, as we'll see, keeps only a version of the second step.)

That's it. The entire field of normalization is variations on these two arithmetic operations. The only question is: which numbers do you compute the mean and standard deviation over? Different answers give different methods. But the math is always the same.

Step 1: Centering (Subtract the Mean)

Consider five activation values from a single layer:

x = [4, 8, 6, 2, 10]

The mean (μ) is the average:

μ = (4 + 8 + 6 + 2 + 10) / 5 = 30 / 5 = 6.0

Centering means subtracting the mean from every value:

x_centered = [4-6, 8-6, 6-6, 2-6, 10-6] = [-2, 2, 0, -4, 4]

Check: the mean of [-2, 2, 0, -4, 4] is (-2+2+0-4+4)/5 = 0/5 = 0. Centering always produces a zero-mean distribution. Why does this help? Because it removes the bias in the activation distribution. If every activation is centered around 50 instead of 0, the next layer has to learn a large bias just to compensate. Centering puts all layers on a neutral footing.

Step 2: Scaling (Divide by Standard Deviation)

After centering, our values are [-2, 2, 0, -4, 4]. The mean is 0, but the spread varies. These values range from -4 to 4. A different layer might have centered values like [-0.001, 0.002, 0, -0.003, 0.002] — same mean (zero) but vastly different spread.

We want unit spread. The measure of spread is the standard deviation (σ), computed from the variance (σ²):

$$\sigma^2 = \frac{1}{N}\sum_{i=1}^{N}(x_i - \mu)^2$$

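Let's compute it step by step for x = [4, 8, 6, 2, 10] with μ = 6. The middle column is the centered values from Step 1:

| x_i | x_i − μ (centered) | (x_i − μ)² |
|---|---|---|
| 4 | -2 | 4 |
| 8 | 2 | 4 |
| 6 | 0 | 0 |
| 2 | -4 | 16 |
| 10 | 4 | 16 |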
σ² = (4 + 4 + 0 + 16 + 16) / 5 = 40 / 5 = 8.0
σ = √8 ≈ 2.828

Now divide each centered value by σ:

x_norm = [-2/2.828, 2/2.828, 0/2.828, -4/2.828, 4/2.828]
x_norm = [-0.707, 0.707, 0, -1.414, 1.414]

Let's verify. Mean of the normalized values: (-0.707 + 0.707 + 0 - 1.414 + 1.414) / 5 = 0 / 5 = 0. Variance: (0.5 + 0.5 + 0 + 2.0 + 2.0) / 5 = 5.0 / 5 = 1.0. Standard deviation = √1 = 1. Perfect. Mean 0, std 1.

The Complete Formula

Combining both steps into one formula:

x̂ = (x - μ) / σ

In practice, implementations compute the variance directly, add a small ε to it, and take a single square root:

x̂ = (x - μ) / √(σ² + ε)

That ε (epsilon, typically 1e-5) prevents division by zero when the variance happens to be exactly zero.

Learnable Scale (γ) and Shift (β)

Here's a subtle problem. We just forced every activation distribution to have mean 0 and std 1. But what if that's NOT the optimal distribution? What if, for some particular layer, the network would actually perform better if activations had mean 0.5 and std 2.3?

By normalizing to (0, 1), we've constrained the network. We've reduced its representational power. The solution is elegant: after normalizing, apply a learnable linear transformation:

y = γ · x̂ + β

Where γ (scale, initialized to 1) and β (shift, initialized to 0) are learned parameters — they receive gradients and are updated by the optimizer just like weights and biases.

At initialization, γ = 1 and β = 0, so y = x̂ (the normalized value). As training progresses, the network can learn whatever scale and shift is optimal. In the extreme case, it can learn γ = σ and β = μ, which would perfectly undo the normalization.
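A small NumPy check of the "undo" claim, using the five values from earlier. With ε included, the exact inverse is γ = √(σ² + ε) rather than σ; that detail is the only assumption here.

```python
import numpy as np

x = np.array([4.0, 8.0, 6.0, 2.0, 10.0])
eps = 1e-5
mu, var = x.mean(), x.var()              # np.var divides by N by default
x_hat = (x - mu) / np.sqrt(var + eps)    # normalized: mean 0, std ~1

# Suppose training drove the learnable parameters to these values
gamma = np.sqrt(var + eps)               # "gamma = sigma", with epsilon folded in
beta = mu                                # "beta = mu"
y = gamma * x_hat + beta
print(y)  # the original values again, up to float rounding
```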

Why not just skip normalization then? If the network can learn to undo the normalization, what's the point? The key is the gradient path. Without normalization, the gradient must flow through arbitrarily scaled activations. With normalization + learnable γ/β, the gradient flows through a well-conditioned normalized path and then through a simple linear transform. The gradient landscape is smoother. The network can represent the same functions as before, but now it's much easier to optimize. Think of normalization as straightening a winding road — same destination, easier drive.

Hand Calculation: The Full Pipeline

Let's trace the complete normalization pipeline with a concrete example. Input: x = [4, 8, 6, 2, 10]. Learnable parameters: γ = 1.5, β = -0.3 (pretend we're mid-training).

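Step 1: Compute statistics. μ = 30/5 = 6.0, σ² = 40/5 = 8.0, σ = √8 ≈ 2.828 (ignoring ε).

Step 2: Normalize. x̂ = (x − 6.0) / 2.828 = [-0.707, 0.707, 0, -1.414, 1.414].

Step 3: Scale and shift. y = γ · x̂ + β = 1.5 · x̂ − 0.3 = [-1.361, 0.761, -0.300, -2.421, 1.821].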

The output has mean = (-1.361 + 0.761 - 0.300 - 2.421 + 1.821)/5 = -1.500/5 = -0.300 = β. The std = 1.5 × 1.0 = 1.5 = γ. This always holds: after applying γ and β, the output has mean β and standard deviation |γ| (a hair less once ε > 0). The learnable parameters directly control the output distribution.

Explore: Centering and Scaling

The widget below shows data points on a number line. Drag them to change their values, then toggle centering and scaling to see the effect. Watch the mean line and standard deviation bars update in real time.


From Scratch: Normalizing in Code

python
import numpy as np

def normalize(x, eps=1e-5):
    """Normalize to mean=0, std=1."""
    mu = x.mean()
    var = ((x - mu) ** 2).mean()
    x_norm = (x - mu) / np.sqrt(var + eps)
    return x_norm

x = np.array([4.0, 8.0, 6.0, 2.0, 10.0])
x_norm = normalize(x)
print(f"Original: mean={x.mean():.2f}, std={x.std():.2f}")
print(f"Normalized: mean={x_norm.mean():.4f}, std={x_norm.std():.4f}")
print(f"Values: {x_norm}")

# With learnable gamma and beta
gamma = 1.5
beta = -0.3
y = gamma * x_norm + beta
print(f"After scale+shift: mean={y.mean():.2f}, std={y.std():.2f}")
# Output: mean=-0.30, std=1.50 — exactly gamma and beta!

Here is the same computation in PyTorch (which we'll use in later chapters):

python
import torch

x = torch.tensor([4.0, 8.0, 6.0, 2.0, 10.0])
x_norm = (x - x.mean()) / x.std(unbiased=False)
print(x_norm)
# tensor([-0.7071,  0.7071,  0.0000, -1.4142,  1.4142])
Biased vs unbiased variance. Notice unbiased=False in the PyTorch code. By default, torch.std() uses Bessel's correction (divides by N-1, not N), which gives an unbiased estimate of population variance from a sample. But normalization layers use biased variance (divide by N) in the forward pass — we're normalizing the actual batch/features we have, not estimating a population parameter. This is a common source of off-by-one numerical mismatches when implementing normalization from scratch.
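The default differs between libraries, which is where many of these mismatches come from. NumPy's np.var divides by N unless you pass ddof=1, while PyTorch's Tensor.var applies Bessel's correction unless you pass unbiased=False. On the same five values:

```python
import numpy as np
import torch

values = [4.0, 8.0, 6.0, 2.0, 10.0]
t = torch.tensor(values)

np_default = np.var(values)                  # divides by N
np_bessel = np.var(values, ddof=1)           # divides by N-1
torch_default = t.var().item()               # divides by N-1 (Bessel's correction)
torch_biased = t.var(unbiased=False).item()  # divides by N, as normalization layers do

print(np_default, np_bessel)        # 8.0 10.0
print(torch_default, torch_biased)  # about 10.0 and 8.0
```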

The Key Insight

Here's the punchline, and it's the single most important idea in this entire lesson:

Every normalization variant — BatchNorm, LayerNorm, RMSNorm, GroupNorm, InstanceNorm — is just this formula with a different definition of "which elements do we compute the mean and std over." That's the ONLY difference. BatchNorm averages across the batch. LayerNorm averages across features. GroupNorm averages across groups of features. RMSNorm drops the mean entirely and only divides by the root-mean-square. Same formula, different axis. Once you understand this chapter, you understand all of normalization.
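To make "same formula, different axis" concrete, here is a NumPy sketch on a small 4×3 batch (4 samples, 3 features), the same matrix Chapter 2 uses. The helper names `standardize` and `rms_scale` are hypothetical, and the learnable γ and β are left out to keep the focus on the axis.

```python
import numpy as np

x = np.array([[1., 5., 3.],
              [3., 3., 7.],
              [5., 7., 1.],
              [3., 5., 5.]])  # [B=4 samples, D=3 features]

def standardize(x, axis, eps=1e-5):
    """Subtract the mean and divide by the std, with statistics taken along `axis`."""
    mu = x.mean(axis=axis, keepdims=True)
    var = x.var(axis=axis, keepdims=True)  # biased: divide by N
    return (x - mu) / np.sqrt(var + eps)

def rms_scale(x, axis, eps=1e-5):
    """RMSNorm-style: no centering, just divide by the root-mean-square along `axis`."""
    rms = np.sqrt((x ** 2).mean(axis=axis, keepdims=True) + eps)
    return x / rms

bn_like = standardize(x, axis=0)  # per feature, across the batch (down columns)
ln_like = standardize(x, axis=1)  # per sample, across features (along rows)
rms_like = rms_scale(x, axis=1)   # per sample, no mean subtraction
print(np.round(bn_like, 3))
print(np.round(ln_like, 3))
print(np.round(rms_like, 3))
```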

In the next chapter, we'll see the first and most famous instantiation: computing the mean and variance across the batch dimension.


Chapter 2: Batch Normalization

What if we normalize using statistics from the entire mini-batch? In 2015, Sergey Ioffe and Christian Szegedy proposed exactly this in one of the most cited papers in deep learning history. The idea was deceptively simple: for each feature, compute the mean and variance across all samples in the batch, then normalize. They called it Batch Normalization (BN).

BN transformed deep learning almost overnight. In the original paper, a batch-normalized Inception network matched the baseline's accuracy with 14 times fewer training steps. Architectures that were hard to optimize suddenly worked. Learning rates that would have caused divergence became safe. Even the paper's full title understated the impact: "Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift."

The Batch Axis

To understand BN, we need to think about the shape of activations in a network. After a linear layer, the activation tensor has shape [B, D] — B samples in the batch, each with D features. Visualize this as a matrix where rows are samples and columns are features:

| | Feature 0 | Feature 1 | Feature 2 |
|---|---|---|---|
| Sample 0 | 1 | 5 | 3 |
| Sample 1 | 3 | 3 | 7 |
| Sample 2 | 5 | 7 | 1 |
| Sample 3 | 3 | 5 | 5 |

Batch Normalization computes statistics down each column — independently for each feature, across all samples in the batch. For feature 0, it looks at the values [1, 3, 5, 3] from all four samples, computes their mean and variance, and normalizes them. Feature 1 gets its own separate mean and variance from [5, 3, 7, 5]. Feature 2 from [3, 7, 1, 5].

Each feature is normalized independently. The statistics come from the batch.

The BN Formula

For each feature j in a mini-batch of B samples:

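$$\mu_j = \frac{1}{B}\sum_{i=1}^{B} x_{ij}$$
$$\sigma_j^2 = \frac{1}{B}\sum_{i=1}^{B}(x_{ij} - \mu_j)^2$$
$$\hat{x}_{ij} = \frac{x_{ij} - \mu_j}{\sqrt{\sigma_j^2 + \epsilon}}$$
$$y_{ij} = \gamma_j\,\hat{x}_{ij} + \beta_j$$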

Note: there are D separate γj and βj parameters — one pair per feature. If you have 512 features, BatchNorm adds 512 learnable scale parameters and 512 learnable shift parameters. That's 1024 extra parameters per BN layer. Compared to the millions of weights in the network, this is negligible.
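You can confirm the count with PyTorch's nn.BatchNorm1d, which stores the learnable γ and β as `weight` and `bias`:

```python
import torch.nn as nn

bn = nn.BatchNorm1d(512)
n_params = sum(p.numel() for p in bn.parameters())  # learnable tensors only
print(bn.weight.shape, bn.bias.shape)  # torch.Size([512]) torch.Size([512])
print(n_params)                        # 1024
```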

Hand Calculation: The Full BN Forward Pass

Let's work through BN on the 4×3 matrix above, step by step. ε = 0 for clarity (no numerical stability term).

Our batch (4 samples × 3 features):
[[1, 5, 3], [3, 3, 7], [5, 7, 1], [3, 5, 5]]

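Feature 0: values across the batch = [1, 3, 5, 3]
μ = 12/4 = 3.0, σ² = (4 + 0 + 4 + 0)/4 = 2.0, σ ≈ 1.414
x̂ = [-2, 0, 2, 0] / 1.414 = [-1.414, 0, 1.414, 0]

Feature 1: values across the batch = [5, 3, 7, 5]
μ = 20/4 = 5.0, σ² = (0 + 4 + 4 + 0)/4 = 2.0, σ ≈ 1.414
x̂ = [0, -2, 2, 0] / 1.414 = [0, -1.414, 1.414, 0]

Feature 2: values across the batch = [3, 7, 1, 5]
μ = 16/4 = 4.0, σ² = (1 + 9 + 9 + 1)/4 = 5.0, σ ≈ 2.236
x̂ = [-1, 3, -3, 1] / 2.236 = [-0.447, 1.342, -1.342, 0.447]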

Final normalized matrix:

| | Feature 0 | Feature 1 | Feature 2 |
|---|---|---|---|
| Sample 0 | -1.414 | 0 | -0.447 |
| Sample 1 | 0 | -1.414 | 1.342 |
| Sample 2 | 1.414 | 1.414 | -1.342 |
| Sample 3 | 0 | 0 | 0.447 |

Check any column: mean = 0, variance = 1. Each feature is independently normalized. But notice something important — sample 0's normalized value for feature 0 is -1.414, which depends on what samples 1, 2, and 3 had for feature 0. Every sample's output depends on every other sample in the batch.

The Backward Pass: Why It's Not Trivial

In backpropagation, we need the gradient of the loss with respect to each input xij. Because the mean and variance are computed from the inputs, the computation graph has three paths:

Path 1 (direct): x_ij → x̂_ij, through the (x − μ) numerator
Path 2 (through μ): x_ij → μ_j → x̂_kj for every sample k
Path 3 (through σ²): x_ij → σ_j² → x̂_kj for every sample k

The gradient is the sum of all three contributions.

The full gradient is:

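$$\frac{\partial L}{\partial x_{ij}} = \frac{1}{\sigma_j}\frac{\partial L}{\partial \hat{x}_{ij}} - \frac{1}{B\,\sigma_j}\sum_{k=1}^{B}\frac{\partial L}{\partial \hat{x}_{kj}} - \frac{x_{ij} - \mu_j}{B\,\sigma_j^2}\sum_{k=1}^{B}\frac{\partial L}{\partial \hat{x}_{kj}}\,\frac{x_{kj} - \mu_j}{\sigma_j}$$

Here σ_j denotes the forward-pass denominator √(σ_j² + ε); with ε = 0 it is simply the standard deviation.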

This looks intimidating, but the key insight is: changing one sample's input affects every other sample's normalized output through the shared mean and variance. The Jacobian is NOT diagonal. This coupling between samples is the fundamental difference between BN and normalizations that operate independently per sample.
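If the formula feels like magic, autograd can check it. This sketch builds the forward pass by hand (no γ or β, so the upstream gradient `g` plays the role of ∂L/∂x̂), lets PyTorch backpropagate, and compares the result with the three-term formula. Double precision keeps float32 rounding out of the comparison, and the random `g` is an arbitrary stand-in for a real loss gradient.

```python
import torch

torch.manual_seed(0)
B, D, eps = 4, 3, 1e-5
x = torch.randn(B, D, dtype=torch.float64, requires_grad=True)

# Forward pass without gamma/beta
mu = x.mean(dim=0)                  # [D]
var = x.var(dim=0, unbiased=False)  # [D], biased as in the forward pass
sigma = torch.sqrt(var + eps)       # [D], the denominator sigma_j
x_hat = (x - mu) / sigma            # [B, D]

g = torch.randn(B, D, dtype=torch.float64)  # arbitrary upstream gradient dL/dx_hat
x_hat.backward(g)                           # autograd's dL/dx lands in x.grad

# The three-term formula, evaluated for every (i, j) at once
with torch.no_grad():
    xc = x - mu
    manual = (g / sigma
              - g.sum(dim=0) / (B * sigma)
              - xc / (B * sigma ** 2) * (g * xc / sigma).sum(dim=0))

print(torch.allclose(x.grad, manual))  # True
```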

BN's dirty secret: your network's output depends on WHICH other samples happen to be in the batch. Feed the same image in two different batches and you get different features. During training, this acts as regularization (the noise is beneficial, like dropout). But for inference, it's a nightmare — we need the output to be deterministic. We'll solve this in Chapter 3.
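The batch dependence is easy to demonstrate. In this sketch, `bn_train` is a hypothetical helper that applies training-mode BN without γ and β. The same row is normalized once alongside two samples from our matrix and once alongside two made-up samples.

```python
import torch

def bn_train(x, eps=1e-5):
    """Hypothetical helper: training-mode BatchNorm without gamma/beta.
    Statistics come from the batch, one mean/variance per feature (column)."""
    mu = x.mean(dim=0)
    var = x.var(dim=0, unbiased=False)
    return (x - mu) / torch.sqrt(var + eps)

sample = torch.tensor([[1., 5., 3.]])
batch_a = torch.cat([sample, torch.tensor([[3., 3., 7.], [5., 7., 1.]])])
batch_b = torch.cat([sample, torch.tensor([[0., 0., 0.], [9., 9., 9.]])])  # made-up rows

out_a = bn_train(batch_a)[0]  # row 0 is the same input in both batches
out_b = bn_train(batch_b)[0]
print(out_a)
print(out_b)  # different, even though the input row is identical
```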

BN's Hidden Superpower: Loss Landscape Smoothing

For years, everyone believed BN worked by reducing internal covariate shift — the shifting input distributions we discussed in Chapter 0. Then in 2018, Santurkar et al. published "How Does Batch Normalization Help Optimization?" and showed something surprising: BN doesn't actually reduce internal covariate shift much.

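What it does do is make the loss landscape smoother. Specifically, BN improves the Lipschitzness of the loss (it changes less abruptly as you move through parameter space) and of its gradients (the gradient itself changes less abruptly). A smoother landscape makes each gradient a more reliable guide to where the loss is heading, so the optimizer can take bigger steps without overshooting.

This is a large part of why BN tolerates much bigger learning rates: the original BN paper trained Inception variants with the learning rate raised 5× and 30×. Santurkar et al. argue the gain comes less from better-scaled gradients (though they are) than from the smoother landscape itself.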

See the Batch Axis

The simulation below shows a batch × features grid. Click on a feature column to highlight it and watch BN compute the mean and variance from the highlighted column, then normalize those values.


From Scratch: BatchNorm in Code

python
import torch

def batchnorm_forward(x, gamma, beta, eps=1e-5):
    """BatchNorm forward pass.
    x: [B, D] — batch of B samples, D features
    gamma: [D] — learnable scale
    beta: [D] — learnable shift
    """
    # Statistics across batch dimension (axis=0)
    mu = x.mean(dim=0)              # [D] — one mean per feature
    var = x.var(dim=0, unbiased=False)  # [D] — one var per feature

    # Normalize
    x_norm = (x - mu) / torch.sqrt(var + eps)  # [B, D]

    # Scale and shift
    out = gamma * x_norm + beta  # [B, D]

    return out, mu, var

# Test with our 4×3 matrix
x = torch.tensor([[1., 5., 3.],
                   [3., 3., 7.],
                   [5., 7., 1.],
                   [3., 5., 5.]])
gamma = torch.ones(3)   # scale = 1 (no scaling)
beta = torch.zeros(3)   # shift = 0 (no shifting)

out, mu, var = batchnorm_forward(x, gamma, beta)
print(f"Means: {mu}")         # tensor([3., 5., 4.])
print(f"Variances: {var}")    # tensor([2., 2., 5.])
print(f"Normalized:\n{out}") # matches our hand calc

Compare to PyTorch's built-in:

python
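import torch
import torch.nn as nn

# The same 4×3 batch as above
x = torch.tensor([[1., 5., 3.],
                  [3., 3., 7.],
                  [5., 7., 1.],
                  [3., 5., 5.]])

bn = nn.BatchNorm1d(3, affine=True)
# At initialization: gamma=1, beta=0
print(bn.weight)  # Parameter containing: tensor([1., 1., 1.], requires_grad=True) (gamma)
print(bn.bias)    # Parameter containing: tensor([0., 0., 0.], requires_grad=True) (beta)

out_pt = bn(x)    # [4, 3]; a new module is in training mode, so it matches our from-scratch version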
Parameter count. In BN with D features, the learnable parameters are γ (D values) and β (D values) = 2D parameters. Plus two buffers (running_mean and running_var, D values each) that are NOT learnable — they're updated by exponential moving average, not gradient descent. We'll cover those buffers in Chapter 3.

Chapter 3: The Training/Inference Split

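You've trained a model with BatchNorm and it works great. You deploy it. A single example comes in. Batch size = 1. The mean of a single value is... itself. The variance is... zero. The numerator x − μ is zero too, so ε spares you a division by zero, but every input normalizes to exactly 0 and the layer outputs β no matter what came in. Your model outputs garbage. What went wrong?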

This is the fundamental problem with Batch Normalization at inference time: there is no batch. BN computes statistics across the batch dimension, but during inference, you typically process one input at a time (or small, variable-sized groups). The batch statistics are undefined or meaningless.

The solution: during training, keep a running average of the batch statistics. At inference time, use those saved statistics instead of computing them from the (non-existent) batch.

Exponential Moving Average (EMA)

During training, after each batch, BN updates two buffers using an exponential moving average (EMA):

running_mean = (1 - m) · running_mean + m · batch_mean
running_var = (1 - m) · running_var + m · batch_var

Where m is the momentum parameter (default 0.1 in PyTorch). This is a weighted average where the current batch gets weight m and all previous history gets weight (1-m). Over time, the running statistics converge toward the true population statistics.

Note the confusing naming: PyTorch calls this "momentum" but it's the opposite convention from optimizer momentum. In PyTorch's BN, momentum = 0.1 means 10% new data, 90% old running average. Higher momentum = more weight on the current batch = faster adaptation but noisier estimate.

Hand Calculation: EMA Convergence

Let's trace the running mean for a single feature across 5 training batches. The running mean starts at 0 (PyTorch default). Momentum = 0.1.

Setup: running_mean starts at 0. We see 5 batches with batch means of 3.0, 5.0, 4.0, 3.5, and 4.2 for this feature. The true population mean (if we had infinite data) is approximately 4.0.

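Batch 1: batch_mean = 3.0
running_mean = 0.9 × 0 + 0.1 × 3.0 = 0.3000

Way off from the true mean. The EMA has barely started.

Batch 2: batch_mean = 5.0
running_mean = 0.9 × 0.3000 + 0.1 × 5.0 = 0.2700 + 0.5000 = 0.7700

Batch 3: batch_mean = 4.0
running_mean = 0.9 × 0.7700 + 0.1 × 4.0 = 0.6930 + 0.4000 = 1.0930

Batch 4: batch_mean = 3.5
running_mean = 0.9 × 1.0930 + 0.1 × 3.5 = 0.9837 + 0.3500 = 1.3337

Batch 5: batch_mean = 4.2
running_mean = 0.9 × 1.3337 + 0.1 × 4.2 = 1.2003 + 0.4200 = 1.6203

After 5 batches, the running mean is about 1.620, still far from the true mean of ~4.0. Let's see how it progresses: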

| After N batches | Approx. running mean | Gap from 4.0 |
|---|---|---|
| 5 | 1.62 | 2.38 |
| 10 | 2.57 | 1.43 |
| 20 | 3.46 | 0.54 |
| 50 | 3.96 | 0.04 |
| 100 | 3.999 | 0.001 |

The EMA is deliberately slow. With momentum 0.1, it takes about 50 batches to get within 1% of the true mean. This is by design — you want the running statistics to be stable, not jerked around by a single unusual batch.
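The hand calculation is a three-line loop. This sketch replays it, then asks how many batches the EMA needs to get within 1% (0.04) of 4.0 under the idealized assumption that every batch mean is exactly 4.0 and the running mean starts at 0. In that case the gap after n batches is 4.0 × 0.9^n.

```python
momentum = 0.1
running_mean = 0.0  # PyTorch's starting value
for step, batch_mean in enumerate([3.0, 5.0, 4.0, 3.5, 4.2], start=1):
    running_mean = (1 - momentum) * running_mean + momentum * batch_mean
    print(f"batch {step}: running_mean = {running_mean:.4f}")
# prints 0.3000, 0.7700, 1.0930, 1.3337, 1.6203

# Idealized convergence: every batch mean is exactly 4.0, so the gap is 4.0 * 0.9**n
n_within_1pct = next(n for n in range(1, 1000) if 4.0 * (1 - momentum) ** n < 0.04)
print(n_within_1pct)  # 44
```

The idealized answer, 44 batches, is consistent with the "about 50" above; noisier batch means shift the exact number.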

model.train() vs model.eval()

PyTorch BatchNorm behaves completely differently in training mode vs evaluation mode:

| | model.train() | model.eval() |
|---|---|---|
| Statistics used | Batch statistics (μ_batch, σ²_batch) | Running statistics (running_mean, running_var) |
| Running stats | Updated after each forward pass | Frozen (no updates) |
| Output depends on batch | Yes: different batch, different output | No: output is deterministic |
| Batch size requirement | B ≥ 2 (need variance) | B ≥ 1 (any size works) |

The switch is a single line: model.eval() before inference, model.train() before training.

Forgetting model.eval() is the #1 BatchNorm bug in production ML. Your model trains perfectly, hits 95% accuracy on the val set, then serves garbage predictions in production. The fix is one line of code; finding it can take hours of debugging. Every production ML tutorial should start here. And when you load a saved model for inference, remember that load_state_dict() restores parameters and buffers but not the train/eval flag. A freshly constructed model starts in training mode, so call model.eval() before you run it.
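A short sketch of the load-then-serve sequence, with a single nn.BatchNorm1d standing in for a whole model. It shows that load_state_dict() leaves the module in training mode, that training mode refuses a batch of one (PyTorch raises a ValueError), and that eval() fixes both problems.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
trained = nn.BatchNorm1d(3)
trained(torch.randn(16, 3))  # one training-mode forward pass updates the running stats

restored = nn.BatchNorm1d(3)  # freshly constructed modules start in training mode
restored.load_state_dict(trained.state_dict())
mode_after_load = restored.training
print(mode_after_load)  # True: the train/eval flag is not part of state_dict

single = torch.randn(1, 3)
raised = False
try:
    restored(single)  # training mode needs more than one value per channel
except ValueError as err:
    raised = True
    print("train mode:", err)

restored.eval()
out = restored(single)  # eval mode uses the loaded running stats
print(out.shape)        # torch.Size([1, 3])
```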

The num_batches_tracked Counter

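PyTorch's BatchNorm keeps a counter, num_batches_tracked, that increments on every training-mode forward pass. It is stored in the state_dict alongside the running statistics. Its most visible job: if you construct the layer with momentum=None, PyTorch replaces the EMA with a cumulative (simple) average, giving each new batch a weight of 1/num_batches_tracked.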
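A minimal check of the momentum=None behavior with nn.BatchNorm1d and three tiny, made-up batches whose means are 3, 5, and 4:

```python
import torch
import torch.nn as nn

bn = nn.BatchNorm1d(1, momentum=None)  # None: cumulative average instead of EMA
for batch in ([[2.], [4.]], [[4.], [6.]], [[3.], [5.]]):  # batch means 3, 5, 4
    bn(torch.tensor(batch))  # training mode: updates running stats and the counter

print(bn.num_batches_tracked)  # tensor(3)
print(bn.running_mean)         # the plain average of 3, 5, 4, i.e. about tensor([4.])
```

With the default momentum of 0.1, the same three batches would leave running_mean far below 4.0, because the EMA starts at 0 and moves only 10% of the way each step.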

The Variance Correction Subtlety

Here's a detail that trips up anyone implementing BN from scratch. During the forward pass, BN computes the biased variance (divides by N). But when updating the running variance, PyTorch applies Bessel's correction:

running_var = (1-m) · running_var + m · (B/(B-1)) · batch_var

The factor B/(B-1) converts the biased variance to an unbiased estimate. Why? The running variance is an estimate of the population variance, and the unbiased estimator is more accurate for that purpose. The forward pass uses biased variance because we're normalizing the actual data we have, not estimating a population parameter.

For a batch of size 32, the correction factor is 32/31 ≈ 1.032, a 3.2% difference. It doesn't grow over time, but it doesn't wash out either: the running variance settles about 3.2% higher than it would without the correction. That's enough to cause numerical mismatches between your implementation and PyTorch's.
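You can watch the correction happen by running one training-mode forward pass through nn.BatchNorm1d on the 4×3 batch from Chapter 2 (B = 4, so the factor is 4/3) and comparing running_var with the formula above. The module starts with running_var = 1 and momentum = 0.1.

```python
import torch
import torch.nn as nn

x = torch.tensor([[1., 5., 3.],
                  [3., 3., 7.],
                  [5., 7., 1.],
                  [3., 5., 5.]])
B = x.size(0)

bn = nn.BatchNorm1d(3)  # running_var starts at 1.0, momentum 0.1
bn(x)                   # one training-mode forward pass

biased = x.var(dim=0, unbiased=False)               # [2, 2, 5]: used to normalize
expected = 0.9 * 1.0 + 0.1 * biased * B / (B - 1)   # corrected by B/(B-1) = 4/3
print(bn.running_var)  # about tensor([1.1667, 1.1667, 1.5667])
print(expected)
```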

Running statistics are NOT learned parameters. They're buffers: they survive model saving and loading (they're in state_dict) but they don't receive gradients. When you call model.load_state_dict(), the running_mean and running_var are restored. If you forget to load them (or load only the weight/bias parameters), your model falls back to the initial running stats (mean 0, variance 1) and produces garbage at inference time. Always save and load the complete state_dict.

See the Train/Eval Split

The simulation below shows how running statistics converge during training and then freeze during inference. On the left, watch the batch statistics jump around (blue dots) while the running average (orange line) smoothly converges. On the right, see how eval mode uses the frozen running statistics.


From Scratch: The Full BN Module

python
import torch
import torch.nn as nn

class MyBatchNorm1d(nn.Module):
    def __init__(self, D, eps=1e-5, momentum=0.1):
        super().__init__()
        self.eps = eps
        self.momentum = momentum
        # Learnable parameters
        self.gamma = nn.Parameter(torch.ones(D))
        self.beta = nn.Parameter(torch.zeros(D))
        # Buffers (not learned, but saved in state_dict)
        self.register_buffer('running_mean', torch.zeros(D))
        self.register_buffer('running_var', torch.ones(D))
        self.register_buffer('num_batches_tracked', torch.tensor(0))

    def forward(self, x):
        if self.training:
            # Use batch statistics
            mu = x.mean(dim=0)
            var = x.var(dim=0, unbiased=False)
            # Update running stats (with Bessel's correction)
            B = x.size(0)
            with torch.no_grad():
                self.running_mean = (1-self.momentum) * self.running_mean \
                                    + self.momentum * mu
                self.running_var = (1-self.momentum) * self.running_var \
                                   + self.momentum * var * B / (B-1)
                self.num_batches_tracked += 1
        else:
            # Use running statistics (frozen)
            mu = self.running_mean
            var = self.running_var

        x_norm = (x - mu) / torch.sqrt(var + self.eps)
        return self.gamma * x_norm + self.beta

Test it:

python
bn = MyBatchNorm1d(3)
x = torch.tensor([[1., 5., 3.],
                   [3., 3., 7.],
                   [5., 7., 1.],
                   [3., 5., 5.]])

# Training mode: uses batch stats, updates running stats
bn.train()
out_train = bn(x)
print(f"running_mean: {bn.running_mean}")
# tensor([0.3000, 0.5000, 0.4000]) — 10% of batch mean

# Eval mode: uses running stats (frozen)
bn.eval()
single = torch.tensor([[2., 4., 3.]])  # batch size 1!
out_eval = bn(single)  # Works! Uses running_mean/running_var
print(f"Eval output: {out_eval}")
The state_dict contains everything. After training, print bn.state_dict() and you'll see gamma and beta (nn.BatchNorm1d names them weight and bias), running_mean, running_var, and num_batches_tracked. All five are saved and loaded together. The running statistics and the counter are buffers, not parameters: they don't appear in bn.parameters() and the optimizer never touches them. But they're critical for inference.
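The split between parameters and buffers is visible directly on PyTorch's nn.BatchNorm1d:

```python
import torch.nn as nn

bn = nn.BatchNorm1d(3)
state_keys = list(bn.state_dict().keys())
param_names = [name for name, _ in bn.named_parameters()]
buffer_names = [name for name, _ in bn.named_buffers()]
print(state_keys)    # weight, bias, running_mean, running_var, num_batches_tracked
print(param_names)   # ['weight', 'bias']: the only tensors an optimizer sees
print(buffer_names)  # ['running_mean', 'running_var', 'num_batches_tracked']
```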

Chapter 4: Layer Normalization

What if we don't want our normalization to depend on other samples at all? What if each token, each sample, each data point normalizes itself?

That's the idea behind Layer Normalization (LN), proposed by Ba, Kiros, and Hinton in 2016. Instead of computing statistics across the batch (down the columns of our B×D matrix), LN computes statistics across the features (along the rows). Each sample is normalized independently using its own mean and variance.

This one change, from the batch axis to the feature axis, eliminates every problem we discussed in Chapter 3. No running statistics. No train/eval split. No batch size dependency. It's why the original Transformer and its descendants normalize per sample, with Layer Normalization or (in most recent LLMs) a simplified variant of it.

The Feature Axis

Let's use the same 4×3 matrix from Chapter 2. But now, instead of looking down each column (BN), we look across each row (LN):

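|  | Feature 0 | Feature 1 | Feature 2 | ← LN direction |
|---|---|---|---|---|
| Sample 0 | 1 | 5 | 3 | μ=3.0, σ²=2.667 |
| Sample 1 | 3 | 3 | 7 | μ=4.333, σ²=3.556 |
| Sample 2 | 5 | 7 | 1 | μ=4.333, σ²=6.222 |
| Sample 3 | 3 | 5 | 5 | μ=4.333, σ²=0.889 |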

Each row is a separate world. Sample 0's normalization uses only the values [1, 5, 3] — it has no knowledge of what samples 1, 2, or 3 look like. This independence is the key property of Layer Normalization.

Hand Calculation: Layer Norm Step by Step

Same batch, new axis. Let's compute LN for each sample. ε = 0 for clarity.

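Sample 0: values across features = [1, 5, 3]. μ = 3.0, σ² = 2.667, σ = 1.633. Normalized: [(1−3)/1.633, (5−3)/1.633, (3−3)/1.633] = [−1.225, 1.225, 0]

Sample 1: values across features = [3, 3, 7]. μ = 4.333, σ² = 3.556, σ = 1.886. Normalized: [−0.707, −0.707, 1.414]

Sample 2: values across features = [5, 7, 1]. μ = 4.333, σ² = 6.222, σ = 2.494. Normalized: [0.267, 1.069, −1.336]

Sample 3: values across features = [3, 5, 5]. μ = 4.333, σ² = 0.889, σ = 0.943. Normalized: [−1.414, 0.707, 0.707]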

Final LN-normalized matrix:

|  | Feature 0 | Feature 1 | Feature 2 |
|---|---|---|---|
| Sample 0 | -1.225 | 1.225 | 0 |
| Sample 1 | -0.707 | -0.707 | 1.414 |
| Sample 2 | 0.267 | 1.069 | -1.336 |
| Sample 3 | -1.414 | 0.707 | 0.707 |

Check any row: mean ≈ 0, variance ≈ 1. Compare this to the BN result from Chapter 2 — the numbers are completely different! BN and LN produce different outputs because they normalize along different axes.
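The two axes can be compared directly. This sketch uses a small hypothetical helper, normalize, with ε = 0 and population variance to match the hand calculations; the only difference between the BN and LN calls is the dim argument:

```python
import torch

x = torch.tensor([[1., 5., 3.],
                  [3., 3., 7.],
                  [5., 7., 1.],
                  [3., 5., 5.]])

def normalize(x, dim, eps=0.0):
    # Zero mean, unit (population) variance along `dim`
    mu = x.mean(dim=dim, keepdim=True)
    var = x.var(dim=dim, keepdim=True, unbiased=False)
    return (x - mu) / torch.sqrt(var + eps)

bn_out = normalize(x, dim=0)   # BN direction: down each column (across samples)
ln_out = normalize(x, dim=-1)  # LN direction: along each row (across features)

print(ln_out)
print(torch.allclose(bn_out, ln_out))  # False: different axes, different results
```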

BN vs LN: Side by Side

Let's make the axis difference crystal clear:

|  | Batch Normalization | Layer Normalization |
|---|---|---|
| Normalizes across | Batch dimension (columns) | Feature dimension (rows) |
| Stats computed from | All B samples for each feature | All D features for each sample |
| Number of means | D (one per feature) | B (one per sample) |
| Sample independence | No: samples affect each other | Yes: each sample normalized alone |
| Running statistics | Yes (EMA during training) | No |
| Train/eval difference | Yes: must call model.eval() | No: identical behavior |
| Batch size requirement | B ≥ 2 during training | Any batch size, always |
| γ, β shape | [D] | [D] (same!) |
Common mistake: thinking LN and BN produce the same result when "just transposed." They don't. BN normalizes each feature to have zero mean across the batch — different samples get compared to each other. LN normalizes each sample to have zero mean across features — different features get compared to each other. These are fundamentally different operations with different inductive biases. BN says "this feature should have a consistent distribution across all inputs." LN says "this sample's features should be balanced relative to each other."

LN in Transformers

In a transformer, the hidden representation has shape [batch, seq_len, d_model]. A single "sample" for LayerNorm is a single token's representation: a vector of d_model numbers (e.g., 768 in BERT-base, 12,288 in the 175B-parameter GPT-3).

LayerNorm normalizes over the last dimension (d_model). This means:

For a batch of 8 sequences, each 512 tokens long, with d_model = 768:

| What | Value |
|---|---|
| Input shape | [8, 512, 768] |
| Number of separate means computed | 8 × 512 = 4,096 |
| Each mean computed from | 768 values (one token's features) |
| γ shape | [768] |
| β shape | [768] |
| Learnable parameters | 768 + 768 = 1,536 |

The γ and β are shared across all tokens and all samples. There's one scale and one shift per feature dimension, applied identically everywhere. This is why the parameter count is so small: just 2 × d_model per LN layer.

LayerNorm has NO training/inference split. No running_mean, no running_var, no num_batches_tracked. model.train() and model.eval() do nothing to LayerNorm layers. This makes LN dramatically simpler to deploy, debug, and reason about than BatchNorm.
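The numbers in the table, and the missing train/eval split, can be checked in a few lines. A minimal sketch with random activations of the shape used above:

```python
import torch
import torch.nn as nn

x = torch.randn(8, 512, 768)  # [batch, seq_len, d_model]
ln = nn.LayerNorm(768)

n_params = sum(p.numel() for p in ln.parameters())
print(n_params)  # 1536 = 768 (gamma) + 768 (beta)

# One mean per token: reduce over the last dimension only
mu = x.mean(dim=-1)
print(mu.shape)  # torch.Size([8, 512]), i.e. 4,096 separate means

# train() vs eval() changes nothing for LayerNorm
ln.train()
out_train = ln(x)
ln.eval()
out_eval = ln(x)
print(torch.allclose(out_train, out_eval))  # True
```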

Why BN Won for CNNs but LN Won for Transformers

For convolutional neural networks processing images, BatchNorm makes intuitive sense. In a CNN, each channel (feature map) has a consistent meaning: channel 37 might detect vertical edges. It makes sense to normalize channel 37 to have consistent statistics across all images — the concept of "how much vertical edge is present" should have a stable distribution.

For transformers processing sequences, the situation is reversed. The "features" are the hidden dimensions of each token's representation. Unlike CNN channels, these don't have stable per-dimension semantics. What matters is the relative pattern across dimensions for each token. And crucially, sequences have variable length: a batch might contain a 10-token sentence and a 500-token document. Batch statistics pooled across such different sequences (plus the padding needed to batch them together) would be unreliable.

There are also the practical advantages listed earlier: no running statistics to track, no train/eval split, and no minimum batch size.

See BN vs LN

The simulation below shows the same batch×features grid, but now you can toggle between BN mode (column highlighting) and LN mode (row highlighting). Watch which cells contribute to each normalization operation.

BN vs LN: Which Axis?

Toggle between Batch Norm and Layer Norm to see which cells (highlighted) contribute to each normalization. Click any cell to highlight its normalization group.

From Scratch: Layer Norm in Code

python
import torch

def layernorm_forward(x, gamma, beta, eps=1e-5):
    """LayerNorm forward pass.
    x: [B, D] — batch of B samples, D features
    gamma: [D] — learnable scale
    beta: [D] — learnable shift
    """
    # Statistics across feature dimension (axis=-1)
    mu = x.mean(dim=-1, keepdim=True)     # [B, 1]
    var = x.var(dim=-1, keepdim=True,
                unbiased=False)              # [B, 1]

    # Normalize
    x_norm = (x - mu) / torch.sqrt(var + eps)  # [B, D]

    # Scale and shift
    out = gamma * x_norm + beta  # [B, D]
    return out

# Test with the same 4×3 matrix
x = torch.tensor([[1., 5., 3.],
                   [3., 3., 7.],
                   [5., 7., 1.],
                   [3., 5., 5.]])
gamma = torch.ones(3)
beta = torch.zeros(3)

out = layernorm_forward(x, gamma, beta)
print(out)
# Row 0: [-1.2247, 1.2247, 0.0000] — matches hand calc!

Compare to PyTorch:

python
import torch.nn as nn

ln = nn.LayerNorm(3)  # normalized_shape = [3] (the feature dim)
out_pt = ln(x)
print(out_pt)  # Same result!

# Key difference from BN:
print(list(ln.state_dict().keys()))
# ['weight', 'bias'] — that's it! No running_mean, no running_var

# Works with batch size 1:
single = torch.tensor([[2., 4., 3.]])
print(ln(single))  # No error, no special mode needed

Notice the nn.LayerNorm(3) constructor takes the normalized shape — the shape of the dimensions to normalize over. For a transformer with d_model=768, you'd write nn.LayerNorm(768). For a tensor with shape [B, T, D], this normalizes over the last dimension D. You can also pass a tuple like (T, D) to normalize over multiple trailing dimensions.
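For the tuple form, here is a minimal check (sizes chosen arbitrarily, and elementwise_affine=False so only the normalization itself is compared) that nn.LayerNorm((T, D)) matches statistics taken over the last two dimensions:

```python
import torch
import torch.nn as nn

B, T, D = 2, 4, 6
x = torch.randn(B, T, D)

# normalized_shape=(T, D): one mean/variance per sample, pooled over T*D values
ln_td = nn.LayerNorm((T, D), elementwise_affine=False)

mu = x.mean(dim=(-2, -1), keepdim=True)
var = x.var(dim=(-2, -1), keepdim=True, unbiased=False)
manual = (x - mu) / torch.sqrt(var + ln_td.eps)

print(torch.allclose(ln_td(x), manual, atol=1e-5))  # True
```

With the default elementwise_affine=True, γ and β would have shape (T, D), so each position gets its own scale and shift.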

The Two Flavors of LN in Transformers

There are two placement patterns for LayerNorm in transformers, and the distinction matters:

Post-LN (original Transformer, 2017): normalization comes after the residual connection.

Input x: token representation
↓
Attention(x): self-attention sublayer
↓ + x (residual)
LayerNorm: normalize AFTER adding residual

Pre-LN (GPT-2 and most modern models): normalization comes before the sublayer, inside the residual branch.

Input x: token representation
↓
LayerNorm(x): normalize BEFORE attention
↓
Attention(LN(x)): attend to normalized input
↓ + x (residual)
Output: un-normalized residual + transformed branch

Pre-LN is more stable for training deep models (100+ layers) because the residual path carries un-normalized signals directly, preventing gradient vanishing through the normalization layers. Post-LN can achieve slightly better final performance but is harder to train and often requires learning rate warmup.

GPT-2 switched to Pre-LN, and the major LLMs since have followed: GPT-3, PaLM, LLaMA, and Mistral all use Pre-LN (some with RMSNorm instead of LayerNorm, which we'll cover in Chapter 5). The reason: Pre-LN is far less sensitive to learning rate warmup and makes training more stable at scale. The cost, according to some studies, is a slight reduction in final performance.
In Layer Normalization applied to a transformer with hidden dim 768, how many separate means are computed per token?

Chapter 5: RMSNorm — The Efficient Simplification

LayerNorm does two things: center (subtract the mean) then scale (divide by the standard deviation). In the last chapter we saw how this makes each sample self-contained — no batch dependency, perfect for transformers. But what if one of those two steps doesn't actually matter?

That's the question Biao Zhang and Rico Sennrich asked in 2019. Their answer: skip the centering step entirely. Just divide by the root-mean-square. The result is RMSNorm — the normalization behind virtually every modern large language model.

The intuition is surprisingly simple. In a trained deep network, residual connections and careful initialization tend to keep activations roughly zero-centered already, so LayerNorm's mean subtraction is often subtracting a number close to zero. Zhang and Sennrich hypothesized that this re-centering is dispensable, so RMSNorm removes it and saves the compute.

The RMS Formula

Root Mean Square (RMS) is exactly what it sounds like: take each value, square it, average those squares, then take the square root. For a vector x with D elements:

$$\mathrm{RMS}(x) = \sqrt{\frac{1}{D}\sum_{i=1}^{D} x_i^2}$$

Then RMSNorm normalizes and scales:

$$y_i = \frac{x_i}{\mathrm{RMS}(x)} \cdot \gamma_i$$

That's it. No mean subtraction. No bias term β. Just divide by the RMS and multiply by a learnable scale γ. Two fewer operations than LayerNorm, and one fewer parameter vector (β is gone).

Hand Calculation: LayerNorm vs RMSNorm

Let's run both methods on the same input and compare. Take:

x = [2, −1, 3, −2, 1]

LayerNorm (from Chapter 4):

Step 1 — Mean: μ = (2 + (−1) + 3 + (−2) + 1) / 5 = 3/5 = 0.6

Step 2 — Center: x − μ = [1.4, −1.6, 2.4, −2.6, 0.4]

Step 3 — Variance: σ² = (1.96 + 2.56 + 5.76 + 6.76 + 0.16) / 5 = 17.2 / 5 = 3.44

Step 4 — Standard deviation: σ = √3.44 = 1.855

Step 5 — Normalize: xLN = [1.4/1.855, −1.6/1.855, 2.4/1.855, −2.6/1.855, 0.4/1.855]

xLN = [0.755, −0.863, 1.294, −1.402, 0.216]

RMSNorm:

Step 1 — Square each element: x² = [4, 1, 9, 4, 1]

Step 2 — Mean of squares: (4 + 1 + 9 + 4 + 1) / 5 = 19/5 = 3.8

Step 3 — RMS: √3.8 = 1.949

Step 4 — Normalize: xRMS = [2/1.949, −1/1.949, 3/1.949, −2/1.949, 1/1.949]

xRMS = [1.026, −0.513, 1.539, −1.026, 0.513]

The outputs are different. LayerNorm's output sums to ~0 (it's centered). RMSNorm's output sums to ~1.54 (not centered). But both produce well-scaled values: nothing is exploding or vanishing. In practice, Zhang and Sennrich found that models trained with RMSNorm reach quality comparable to LayerNorm.
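Both hand calculations can be reproduced directly. A minimal sketch with γ = 1, β = 0 and ε = 0 so the numbers match the steps above exactly:

```python
import torch

x = torch.tensor([2., -1., 3., -2., 1.])

# LayerNorm: center, then divide by the (population) standard deviation
mu = x.mean()
std = torch.sqrt(((x - mu) ** 2).mean())
x_ln = (x - mu) / std

# RMSNorm: no centering, divide by the root-mean-square
rms = torch.sqrt((x ** 2).mean())
x_rms = x / rms

print(x_ln)                     # ≈ [0.755, -0.863, 1.294, -1.402, 0.216]
print(x_rms)                    # ≈ [1.026, -0.513, 1.539, -1.026, 0.513]
print(x_ln.sum(), x_rms.sum())  # ≈ 0 and ≈ 1.539
```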

Don't confuse RMSNorm with "LayerNorm without the mean." The denominators are different! LayerNorm divides by std = √var. RMSNorm divides by RMS = √(mean of squares). When the mean is nonzero, these give different results. Also, RMSNorm has only γ (scale), not β (shift). They are related but not identical operations.

The Mathematical Relationship

There's a clean algebraic connection between RMS and standard deviation. Recall:

Var(x) = E[x²] − (E[x])²

Rearranging: E[x²] = Var(x) + (E[x])². But E[x²] is exactly RMS². So:

RMS² = σ² + μ²

When the mean μ ≈ 0 (which it usually is in deep transformers, thanks to residual connections and symmetric initialization), then RMS² ≈ σ², which means RMS ≈ σ. In that regime, RMSNorm and LayerNorm produce nearly identical outputs. The centering step in LayerNorm was subtracting ~0 and the denominator was effectively the same. RMSNorm just makes this official.
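The identity RMS² = σ² + μ² is easy to confirm on the example vector from the hand calculation, and so is the claim that RMS and σ coincide once the mean is zero:

```python
import torch

x = torch.tensor([2., -1., 3., -2., 1.])

mean_sq = (x ** 2).mean()    # RMS^2 = E[x^2] = 3.8
var = x.var(unbiased=False)  # sigma^2 = 3.44
mu_sq = x.mean() ** 2        # mu^2 = 0.36

print(mean_sq.item(), (var + mu_sq).item())  # both ≈ 3.8

# Remove the mean: now RMS equals the standard deviation
xc = x - x.mean()
rms_centered = torch.sqrt((xc ** 2).mean())
std_x = x.std(unbiased=False)
print(rms_centered.item(), std_x.item())     # both ≈ 1.855
```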

The Speed Advantage

A straightforward LayerNorm implementation needs two passes over the data: one to compute the mean (summing D values), and a second to compute the variance (summing D squared deviations). That's 2D additions, D subtractions, D squarings, plus the final division.

RMSNorm needs one pass: square each element and sum. That's D multiplications and D additions. No mean pass. No subtraction pass. For d_model = 8192 (LLaMA 2 70B), this saves roughly 15% of normalization compute per layer. Across 80 layers, it adds up.

The savings are even more significant on GPU hardware. Two-pass algorithms require either reading the tensor from memory twice (doubling memory bandwidth) or storing intermediate results. RMSNorm's single pass is more cache-friendly and parallelizable.

Who Uses RMSNorm?

| Model | Organization | Year | Normalization |
|---|---|---|---|
| T5 | Google | 2019 | RMSNorm (they called it "simplified LN") |
| PaLM | Google | 2022 | RMSNorm |
| LLaMA / LLaMA 2 / LLaMA 3 | Meta | 2023–24 | RMSNorm |
| Mistral / Mixtral | Mistral AI | 2023–24 | RMSNorm |
| Gemma / Gemma 2 | Google | 2024 | RMSNorm |
| Qwen / Qwen 2 | Alibaba | 2024 | RMSNorm |
| GPT-2, BERT, GPT-3 | OpenAI / Google | 2018–20 | LayerNorm (older era) |

The pattern in this table is clear: the LLM families released from 2022 on use RMSNorm, while the LayerNorm entries date from 2018–20. Not every recent model made the switch, but RMSNorm is the modern default for language models.

LayerNorm vs RMSNorm Pipeline

Compare the two pipelines side by side. Drag the slider to change the input values and watch the outputs update in real time. Notice how both methods produce well-scaled results, but RMSNorm skips the entire centering step.


Code: RMSNorm from Scratch

The beauty of RMSNorm is how short it is. Three lines of math:

python
import torch
import torch.nn as nn

class RMSNorm(nn.Module):
    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.eps = eps
        self.gamma = nn.Parameter(torch.ones(dim))  # learnable scale

    def forward(self, x):
        # x shape: (batch, seq_len, dim)
        rms = torch.sqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        return self.gamma * (x / rms)

Compare to the LLaMA implementation (simplified from the Meta codebase):

python
# From Meta's LLaMA — same logic, production-optimized
class LlamaRMSNorm(nn.Module):
    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x):
        output = self._norm(x.float()).type_as(x)  # float32 for stability
        return output * self.weight

Notice the rsqrt call: that's 1/√x computed in one operation (GPUs provide a fast reciprocal-square-root instruction). And the .float() cast: LLaMA typically runs in half precision (float16 or bfloat16), but normalization is done in float32 to avoid precision issues. This is a common production pattern: mixed-precision normalization.
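The mixed-precision pattern can be sketched as a standalone function. The name rms_norm_fp32 is hypothetical and the shapes are arbitrary; the point is that statistics are computed in float32 while the output keeps the activation dtype:

```python
import torch

def rms_norm_fp32(x, weight, eps=1e-6):
    # Upcast for the reduction, normalize, then cast back to the input dtype
    x32 = x.float()
    normed = x32 * torch.rsqrt(x32.pow(2).mean(-1, keepdim=True) + eps)
    return normed.type_as(x) * weight

x = torch.randn(2, 4, 16).to(torch.bfloat16)
weight = torch.ones(16, dtype=torch.bfloat16)

out = rms_norm_fp32(x, weight)
print(out.dtype)  # torch.bfloat16: the output stays in the activation dtype
```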

RMSNorm is one of those beautiful simplifications in ML. The original formulation (LayerNorm) had a step that seemed essential: centering. But empirically, for large transformers, centering contributes little. Removing it made the layer cheaper with no measurable quality loss (Zhang and Sennrich report 7% to 64% lower running time, depending on the model). The lesson: always question whether every step is earning its compute.
RMSNorm is faster than LayerNorm primarily because it skips computing the ____.

Chapter 6: Group & Instance Normalization

So far we've seen two extremes. BatchNorm normalizes across the batch — one set of statistics per feature, shared across all samples. LayerNorm normalizes across all features — one set of statistics per sample, independent of the batch. What about something in between?

This question matters most for vision models. Images have spatial structure. A convolutional feature map has shape [B, C, H, W] — batch, channels, height, width. Each "channel" detects a different pattern (edges, textures, shapes). The question is: which dimensions do you normalize over?

The Vision Problem

BatchNorm in a CNN computes statistics over [B, H, W] for each channel independently. This works brilliantly when B is large (32, 64, 128) because you get a stable estimate of each channel's mean and variance. But in many vision tasks — object detection, segmentation, 3D reconstruction, medical imaging — you often can't fit more than 1-4 images per GPU. With B=2, your "statistics" are computed from just 2 samples. That's noise, not signal.

LayerNorm computes over [C, H, W] for each sample. No batch dependency, great. But it lumps all channels together. A channel detecting horizontal edges and a channel detecting blue pixels have very different activation distributions. Normalizing them jointly can wash out meaningful differences.

We need a middle ground.

GroupNorm: The Best of Both Worlds

GroupNorm (Wu & He, 2018) splits the C channels into G groups, then normalizes within each group. If you have 256 channels and G=32 groups, each group contains 256/32 = 8 channels. Within each group, you compute the mean and variance across those 8 channels and all spatial positions (H × W). Each sample is processed independently — no batch dependency.

The computation for one sample, one group:

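Stats computed over n = (C/G) × H × W values:

$$\mu_g = \frac{1}{n}\sum_{i \in g} x_i, \qquad \sigma_g = \sqrt{\frac{1}{n}\sum_{i \in g} (x_i - \mu_g)^2 + \epsilon}$$

$$y_i = \gamma_c \cdot \frac{x_i - \mu_g}{\sigma_g} + \beta_c \qquad \text{for each } i \in g \text{, where } c \text{ is the channel containing } x_i$$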

The learnable γ and β are still per-channel (not per-group), so the model can rescale each channel independently after normalization.

InstanceNorm: The Extreme Case

What happens when every channel is its own group? That's G = C. Each channel is normalized independently over just its spatial dimensions (H × W). This is called Instance Normalization (Ulyanov et al., 2016).

InstanceNorm was invented for style transfer. The core insight: visual "style" (texture, color palette, brushstroke patterns) is encoded in per-channel statistics — the mean and variance of each feature map. If you normalize away each channel's own statistics, you effectively strip the style from the image. The network can then re-inject a new style through the learned γ and β.

Hand Calculation: GroupNorm Step by Step

Let's work a concrete example. One image, 4 channels, 2×2 spatial:

data
Channel 0:  [[1, 2],   Channel 1:  [[5, 6],
             [3, 4]]              [7, 8]]

Channel 2:  [[9, 10],  Channel 3:  [[13, 14],
             [11, 12]]             [15, 16]]

GroupNorm with G=2 (2 groups, 2 channels each):

Group 0 (channels 0 and 1) — all values: [1, 2, 3, 4, 5, 6, 7, 8]

μ0 = (1+2+3+4+5+6+7+8) / 8 = 36/8 = 4.5

σ0² = ((1−4.5)² + (2−4.5)² + ... + (8−4.5)²) / 8 = (12.25 + 6.25 + 2.25 + 0.25 + 0.25 + 2.25 + 6.25 + 12.25) / 8 = 42 / 8 = 5.25

σ0 = √5.25 = 2.291

Channel 0 normalized: [(1−4.5)/2.291, (2−4.5)/2.291, (3−4.5)/2.291, (4−4.5)/2.291] = [−1.528, −1.091, −0.655, −0.218]

Channel 1 normalized: [(5−4.5)/2.291, (6−4.5)/2.291, (7−4.5)/2.291, (8−4.5)/2.291] = [0.218, 0.655, 1.091, 1.528]

Group 1 (channels 2 and 3) — all values: [9, 10, 11, 12, 13, 14, 15, 16]

μ1 = 100/8 = 12.5, σ1² = 5.25 (same spread!), σ1 = 2.291

The normalized values for Group 1 have the exact same pattern as Group 0 — just shifted in the original space. After normalization, both groups span the same range [−1.528, +1.528]. That's normalization doing its job.

InstanceNorm (G=4) — each channel alone:

Channel 0: values [1, 2, 3, 4]. μ = 2.5, σ² = 1.25, σ = 1.118. Normalized: [−1.342, −0.447, 0.447, 1.342]

Channel 1: values [5, 6, 7, 8]. μ = 6.5, σ² = 1.25, σ = 1.118. Normalized: [−1.342, −0.447, 0.447, 1.342] — same pattern! (Same spread.)

Notice that InstanceNorm produces the same normalized pattern for every channel because they all have the same internal spread. GroupNorm, by pooling channels, captures the cross-channel differences (channel 0 is low, channel 1 is high) and normalizes them together. Which is better depends on the task.
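Both hand calculations can be checked against PyTorch. A minimal sketch: ε = 0 and no affine parameters so the numbers match the steps above, and InstanceNorm is expressed as GroupNorm with G = C (cross-checked against nn.InstanceNorm2d in the test):

```python
import torch
import torch.nn as nn

# One image, 4 channels, 2x2 spatial: values 1..16 laid out as in the example
x = torch.arange(1., 17.).view(1, 4, 2, 2)

gn = nn.GroupNorm(num_groups=2, num_channels=4, eps=0.0, affine=False)
inorm = nn.GroupNorm(num_groups=4, num_channels=4, eps=0.0, affine=False)  # G = C

print(gn(x)[0, 0].flatten())     # ≈ [-1.528, -1.091, -0.655, -0.218]
print(inorm(x)[0, 0].flatten())  # ≈ [-1.342, -0.447, 0.447, 1.342]
```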

The Normalization Spectrum

All five methods we've seen form a spectrum based on which dimensions they share statistics over:

| Method | Stats computed over | Batch dependent? | Params | Best for |
|---|---|---|---|---|
| BatchNorm | B, H, W (per channel) | Yes | γ, β per channel | CNNs (large batch) |
| GroupNorm | C/G, H, W (per group, per sample) | No | γ, β per channel | CNNs (small batch) |
| InstanceNorm | H, W (per channel, per sample) | No | γ, β per channel | Style transfer |
| LayerNorm | C, H, W (per sample) | No | γ, β per feature | Transformers, RNNs |
| RMSNorm | D (per sample, no mean) | No | γ per feature | LLMs |

Read as a spectrum: BN shares statistics across samples (it needs the batch), LayerNorm pools every channel of one sample, GroupNorm pools a subset of channels, and InstanceNorm shares the least (each channel of each sample stands alone). The trend in modern AI has been away from batch-dependent methods and toward per-sample methods.
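One way to make the spectrum concrete is to look at the shape of the mean tensor each method computes on a [B, C, H, W] feature map. A sketch with arbitrary sizes; the GroupNorm line uses the same reshape trick as the from-scratch class:

```python
import torch

B, C, H, W, G = 2, 8, 4, 4, 2
x = torch.randn(B, C, H, W)

# The shape of each mean tensor = how many separate statistics are computed
bn_mu = x.mean(dim=(0, 2, 3))                           # one per channel
ln_mu = x.mean(dim=(1, 2, 3))                           # one per sample
in_mu = x.mean(dim=(2, 3))                              # one per (sample, channel)
gn_mu = x.view(B, G, C // G, H, W).mean(dim=(2, 3, 4))  # one per (sample, group)

for name, mu in [("BatchNorm", bn_mu), ("LayerNorm", ln_mu),
                 ("InstanceNorm", in_mu), ("GroupNorm", gn_mu)]:
    print(f"{name:12s} stats shape {tuple(mu.shape)}")
```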

The 2018 GroupNorm paper made the problem stark. On ImageNet with a ResNet-50 and 2 images per GPU, BatchNorm's accuracy dropped over 10 points compared to 32 images per GPU, while GroupNorm's accuracy was essentially the same at both batch sizes. For any vision task where you can't guarantee large batches (detection, segmentation, 3D, medical), GroupNorm is the safe default.
Normalization Axes Visualizer

A feature map with shape [B, C, H, W]. The highlighted cells show which values are pooled to compute one set of statistics (μ, σ). Switch between methods to see how they differ. For GroupNorm, adjust G.

Instance Normalization seems extreme: normalizing each channel on its own throws away every channel's mean and variance, including how channels differ in overall level. But for style transfer, that's exactly right. Style is largely encoded in per-channel statistics (the mean and variance of each feature map). Removing those statistics removes style, letting the network re-inject a new one through γ and β. The "extreme" approach is perfectly matched to the task.

Code: GroupNorm from Scratch

python
import torch
import torch.nn as nn

class GroupNorm(nn.Module):
    def __init__(self, num_groups, num_channels, eps=1e-5):
        super().__init__()
        self.G = num_groups
        self.C = num_channels
        self.eps = eps
        self.gamma = nn.Parameter(torch.ones(num_channels))
        self.beta = nn.Parameter(torch.zeros(num_channels))

    def forward(self, x):
        B, C, H, W = x.shape
        # Reshape: group channels together
        x = x.view(B, self.G, C // self.G, H, W)
        # Stats over (C/G, H, W) — dims 2, 3, 4
        mu = x.mean(dim=[2, 3, 4], keepdim=True)
        var = x.var(dim=[2, 3, 4], keepdim=True, unbiased=False)
        x = (x - mu) / torch.sqrt(var + self.eps)
        # Reshape back, apply per-channel scale and shift
        x = x.view(B, C, H, W)
        return self.gamma.view(1, C, 1, 1) * x + self.beta.view(1, C, 1, 1)

The PyTorch built-in is identical in spirit:

python
# Built-in — same thing, CUDA-optimized
norm = nn.GroupNorm(num_groups=32, num_channels=256)  # 8 channels per group
out = norm(x)  # x shape: (B, 256, H, W)

The common default is G=32. Why 32? The GroupNorm paper tested G ∈ {1, 2, 4, 8, 16, 32, 64} on ResNet-50 and found 32 performed best. It's not a magic number: values between 16 and 64 all work well. The key constraint is that C must be divisible by G.
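To confirm the from-scratch reshape logic matches the built-in, and to make the divisibility constraint explicit, here is a functional sketch. The helper group_norm is hypothetical, omits γ and β, and compares against nn.GroupNorm with affine=False:

```python
import torch
import torch.nn as nn

def group_norm(x, G, eps=1e-5):
    # Same reshape trick as the GroupNorm class above, without gamma/beta
    B, C, H, W = x.shape
    if C % G != 0:
        raise ValueError(f"C={C} is not divisible by G={G}")
    xg = x.reshape(B, G, C // G, H, W)
    mu = xg.mean(dim=(2, 3, 4), keepdim=True)
    var = xg.var(dim=(2, 3, 4), keepdim=True, unbiased=False)
    return ((xg - mu) / torch.sqrt(var + eps)).reshape(B, C, H, W)

x = torch.randn(2, 256, 8, 8)
builtin = nn.GroupNorm(num_groups=32, num_channels=256, affine=False)
print(torch.allclose(group_norm(x, 32), builtin(x), atol=1e-4))  # True
```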

GroupNorm with G=1 (one group containing all channels) is equivalent to which other normalization method?

Chapter 7: Pre-LN vs Post-LN — Where You Place It Matters

We know what to normalize. We know how (LayerNorm, RMSNorm, etc.). Now the architectural question that tripped up the entire field for two years: where in the network do you place the normalization layer?

This sounds like a trivial detail. It's not. Moving normalization by one line of code, from after the residual addition to before the sublayer, was one of the key changes that let GPT-2 scale to 1.5 billion parameters and still train stably.

The Two Placements

A standard transformer block has two sublayers: multi-head attention and a feedforward network (FFN). Each sublayer has a residual connection. The question is where LayerNorm (or RMSNorm) goes relative to these components.

Post-LN (the original, "Attention Is All You Need" 2017):

output = LayerNorm( x + Sublayer(x) )

The residual is added first, then normalized. The norm wraps the entire output.

Pre-LN (GPT-2, most modern models):

output = x + Sublayer( LayerNorm(x) )

Normalize first, then run through attention/FFN, then add the residual. The residual connection directly adds the raw input to the sublayer output.

Visually, the difference is tiny. Computationally, it changes everything about gradient flow.

Why Pre-LN is More Stable

The key is the gradient highway. In Pre-LN, the residual connection creates a direct path from the output of any layer all the way back to the input. When gradients flow backward during training, they can travel along this highway without passing through any normalization or nonlinearity.

Let's trace the gradient path through L layers of Pre-LN:

$$x_L = x_0 + \sum_{k=0}^{L-1} \mathrm{Sublayer}_k\big(\mathrm{LayerNorm}(x_k)\big)$$

Taking the derivative of the final output with respect to an early layer's input:

$$\frac{\partial x_L}{\partial x_l} = I + \sum_{k=l}^{L-1} \frac{\partial\, \mathrm{Sublayer}_k}{\partial x_l}$$

That I (identity) term is the gradient highway. No matter what happens in the sublayers — even if their gradients vanish completely — the identity term guarantees a gradient of magnitude 1 flows directly from layer L to layer l. The gradient cannot vanish.

Now contrast with Post-LN. Each layer's output passes through LayerNorm, which has its own Jacobian. The gradient must flow through L consecutive LayerNorm Jacobians:

$$\frac{\partial x_L}{\partial x_l} = \prod_{k=l}^{L-1} J_{\mathrm{LN},k} \cdot \big(I + J_{\mathrm{sublayer},k}\big)$$

Each JLN can shrink or amplify the gradient depending on the activation scale at that layer. Over L=48 layers, these Jacobian products can compound into vanishingly small or explosively large values. That's why Post-LN needs careful learning rate warmup — you start with tiny steps to avoid the explosion, then slowly increase.

A Concrete Gradient Example

Imagine a simplified 4-layer network. The gradient signal starts at 1.0 at the output and flows backward.

Pre-LN: at each layer, the gradient splits. One copy travels along the residual highway (unchanged). The other goes through the sublayer. Even if each sublayer attenuates its gradient by 0.5×:

gradient trace
Layer 4 (output): gradient = 1.0
Layer 3: highway = 1.0, sublayer path = 0.5, total ≥ 1.0
Layer 2: highway = 1.0, sublayer paths = 0.5 + 0.25, total ≥ 1.0
Layer 1: highway = 1.0, sublayer paths accumulate, total ≥ 1.0
Layer 0: highway = 1.0  ← always at least 1.0

Post-LN: no highway. Every path goes through LayerNorm Jacobians. If each LayerNorm attenuates the gradient by 0.9×:

gradient trace
Layer 4 (output): gradient = 1.0
Layer 3: 1.0 × 0.9 = 0.90
Layer 2: 0.90 × 0.9 = 0.81
Layer 1: 0.81 × 0.9 = 0.73
Layer 0: 0.73 × 0.9 = 0.66  ← 34% signal lost in just 4 layers

At 48 layers: $0.9^{48} \approx 0.006$. The gradient reaching layer 0 is 0.6% of its original value. That's vanishing gradients. With Pre-LN, the highway term keeps it from falling below 1.0, regardless of depth.
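The toy arithmetic from both traces, extended to a few depths. The 0.9 attenuation is the illustrative value used above, not a measured property of LayerNorm:

```python
# Illustrative only: Post-LN multiplies by a fixed per-layer attenuation,
# while Pre-LN always keeps the identity (highway) term.
attenuation = 0.9

for depth in (4, 12, 48):
    post_ln = attenuation ** depth  # one factor per layer
    pre_ln_highway = 1.0            # identity term, independent of depth
    print(f"L={depth:2d}  Post-LN: {post_ln:.4f}  Pre-LN highway: {pre_ln_highway:.1f}")
```

The Post-LN column reads 0.6561, 0.2824, and 0.0064 for 4, 12, and 48 layers.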

The original "Attention Is All You Need" paper (2017) used Post-LN and needed careful warmup. But the widely used PyTorch nn.TransformerEncoderLayer defaults to... Post-LN (norm_first=False)! Many practitioners unknowingly train Post-LN networks and blame "transformers being finicky" when Pre-LN (norm_first=True) would usually train more stably out of the box. Always check which variant you're using.
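Switching the built-in layer to Pre-LN is a single constructor argument. A minimal sketch with small, arbitrary sizes:

```python
import torch
import torch.nn as nn

d_model, n_heads = 64, 4

# Default: norm_first=False, i.e. Post-LN (normalize after the residual add)
post_ln_layer = nn.TransformerEncoderLayer(d_model, n_heads, batch_first=True)
# Opt in to Pre-LN (normalize before each sublayer)
pre_ln_layer = nn.TransformerEncoderLayer(d_model, n_heads, batch_first=True, norm_first=True)

print(post_ln_layer.norm_first, pre_ln_layer.norm_first)  # False True

x = torch.randn(2, 10, d_model)  # [batch, seq_len, d_model]
out = pre_ln_layer(x)
print(out.shape)  # torch.Size([2, 10, 64])
```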

The Warmup Requirement

Post-LN's instability manifests at the very start of training. Before the model has learned anything, activations can be large and poorly scaled. The LayerNorm Jacobians are at their most unpredictable. A full-sized learning rate applied to these chaotic early gradients can cause the loss to explode immediately.

The fix: learning rate warmup. Start with a tiny learning rate (e.g., 1/1000th of the target) and linearly increase it over the first few thousand steps. By the time the learning rate reaches its full value, the network has settled enough that the LayerNorm Jacobians are well-behaved.
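A linear warmup like this can be expressed with PyTorch's LambdaLR scheduler. A sketch under stated assumptions: the tiny model, the Adam optimizer, the target LR, and warmup_steps = 1000 (so the first step runs at 1/1000 of the target) are all illustrative choices, and a real loop would compute a loss and call backward before each optimizer step:

```python
import torch

model = torch.nn.Linear(10, 10)
target_lr = 1e-3
warmup_steps = 1000  # illustrative; first step runs at 1/1000 of target_lr

optimizer = torch.optim.Adam(model.parameters(), lr=target_lr)

def warmup_factor(step):
    # Linear ramp from 1/warmup_steps up to 1.0, then hold at the full LR
    return min(1.0, (step + 1) / warmup_steps)

scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=warmup_factor)

lrs = []
for step in range(1500):
    lrs.append(optimizer.param_groups[0]["lr"])
    optimizer.step()  # no gradients here, so parameters are left unchanged
    scheduler.step()

print(lrs[0], lrs[999], lrs[-1])
```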

Pre-LN is far less dependent on this. Because the gradient highway keeps gradient flow well-behaved from step 1, Pre-LN transformers can be trained without warmup (Xiong et al., 2020), although many large-scale runs still use a short warmup as a safety margin. This robustness is one reason Pre-LN became the standard: the warmup schedule stops being a make-or-break hyperparameter.

Gradient Flow: Pre-LN vs Post-LN

Two transformer stacks side by side. Arrows show gradient magnitude flowing backward. Increase the layer count and watch Post-LN gradients vanish while Pre-LN stays stable via the residual highway.


Code: Pre-LN vs Post-LN Blocks

python
import torch
import torch.nn as nn

class PostLNBlock(nn.Module):
    """Original Transformer (2017) — norm AFTER residual add"""
    def __init__(self, d_model, n_heads, d_ff):
        super().__init__()
        self.attn = nn.MultiheadAttention(d_model, n_heads)
        self.ffn = nn.Sequential(
            nn.Linear(d_model, d_ff), nn.GELU(), nn.Linear(d_ff, d_model))
        self.ln1 = nn.LayerNorm(d_model)
        self.ln2 = nn.LayerNorm(d_model)

    def forward(self, x):
        x = self.ln1(x + self.attn(x, x, x)[0])  # add THEN norm
        x = self.ln2(x + self.ffn(x))              # add THEN norm
        return x

class PreLNBlock(nn.Module):
    """GPT-2 / Modern Transformer — norm BEFORE sublayer"""
    def __init__(self, d_model, n_heads, d_ff):
        super().__init__()
        self.attn = nn.MultiheadAttention(d_model, n_heads)
        self.ffn = nn.Sequential(
            nn.Linear(d_model, d_ff), nn.GELU(), nn.Linear(d_ff, d_model))
        self.ln1 = nn.LayerNorm(d_model)
        self.ln2 = nn.LayerNorm(d_model)

    def forward(self, x):
        x = x + self.attn(self.ln1(x), self.ln1(x), self.ln1(x))[0]  # norm THEN add
        x = x + self.ffn(self.ln2(x))  # norm THEN add
        return x

One line of difference. Yet Pre-LN typically trains stably at 48+ layers with little or no warmup, while Post-LN can diverge at 12 layers without careful scheduling.

Recent Developments

Pre-LN solved training stability but introduced a new problem. Each sublayer's output gets added directly to the residual stream without normalization, so the residual magnitude keeps growing with depth. At 100+ layers, this matters.

DeepNorm (Microsoft, 2022) addresses this. It keeps Post-LN but up-weights the residual before adding (equivalently, it down-weights each sublayer's contribution): output = LayerNorm(α · x + Sublayer(x)), with $\alpha = (2N)^{1/4}$ for an N-layer encoder. It also scales down the initialization of the sublayer weights. Together, these keep training stable even at 1,000 layers.
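A DeepNorm-style residual can be sketched as follows. Assumptions to flag: deepnorm_constants and DeepNormResidual are hypothetical names; the constants α = (2N)^{1/4} and β = (8N)^{-1/4} follow the DeepNet paper's encoder-only setting; and a single nn.Linear stands in for the sublayer, whereas the paper applies β to specific projection weights inside attention and the FFN:

```python
import torch
import torch.nn as nn

def deepnorm_constants(num_layers):
    # Encoder-only setting: alpha up-weights the residual,
    # beta down-scales the initialization of sublayer weights
    alpha = (2 * num_layers) ** 0.25
    beta = (8 * num_layers) ** -0.25
    return alpha, beta

class DeepNormResidual(nn.Module):
    """Post-LN residual with DeepNorm scaling: LN(alpha * x + sublayer(x))."""
    def __init__(self, sublayer, d_model, alpha):
        super().__init__()
        self.sublayer = sublayer
        self.alpha = alpha
        self.ln = nn.LayerNorm(d_model)

    def forward(self, x):
        return self.ln(self.alpha * x + self.sublayer(x))

alpha, beta = deepnorm_constants(num_layers=1000)
ffn = nn.Linear(64, 64)                        # stand-in sublayer
nn.init.xavier_normal_(ffn.weight, gain=beta)  # down-scaled init
block = DeepNormResidual(ffn, d_model=64, alpha=alpha)
print(round(alpha, 3), round(beta, 4))
```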

Sandwich Norm applies normalization both before AND after each sublayer. Belt and suspenders. More expensive but extremely stable.

The standard remedy for Pre-LN's residual growth is residual scaling: shrink each sublayer's contribution by about 1/√(2N) for N blocks, typically by scaling down the initialization of its output projection (GPT-2 scales the weights of residual layers at initialization by 1/√N, where N counts residual layers). DeepNorm pairs a related idea, down-scaled initialization plus an up-weighted residual, with Post-LN for the best of both worlds.
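Residual scaling at initialization can be sketched like this. Assumptions: the base std of 0.02 and the 0.02/√(2·n_layer) rule follow common open-source GPT-2 reproductions (such as nanoGPT), and residual_projections is a hypothetical stand-in for each block's attention output projection and FFN down-projection:

```python
import math
import torch
import torch.nn as nn

n_layer, d_model = 48, 64  # illustrative sizes
base_std = 0.02            # assumption: GPT-2-style base init std

# Hypothetical stand-ins: 2 residual projections per block (attention out, FFN down)
residual_projections = [nn.Linear(d_model, d_model) for _ in range(2 * n_layer)]

scaled_std = base_std / math.sqrt(2 * n_layer)  # 2 residual additions per block
for proj in residual_projections:
    nn.init.normal_(proj.weight, mean=0.0, std=scaled_std)
    nn.init.zeros_(proj.bias)

print(f"{scaled_std:.5f}")
```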
| Variant | Formula | Pros | Cons | Used by |
|---|---|---|---|---|
| Post-LN | LN(x + Sub(x)) | Better final quality (some studies) | Needs warmup, unstable when deep | Original Transformer, BERT |
| Pre-LN | x + Sub(LN(x)) | Stable, little or no warmup needed | Residual growth at depth | GPT-2/3, LLaMA, Mistral |
| DeepNorm | LN(αx + Sub(x)) | Stable at 1000+ layers | α and init scaling depend on depth | Very deep models |
| Sandwich | x + LN(Sub(LN(x))) | Very stable | 2× norm compute | CogView, some ViTs |
Why does Pre-LN enable training without learning rate warmup?

Chapter 8: The Normalization Arena

Let's race them all.

You've learned six normalization setups: no normalization (None), BatchNorm, LayerNorm, RMSNorm, GroupNorm, and Pre-LN placement. Each has strengths and weaknesses, but reading about failure modes is one thing. Watching them happen in real time is another.

This simulation trains six identical networks on the same task, differing only in normalization. Drag the sliders to create the conditions where each method fails. You'll discover that there is no universally best normalization — only the right one for your constraints.

How to use the Arena. Hit Play to start training. Adjust batch size, network depth, and learning rate with the sliders. Watch the loss curves and the activation histogram to understand what's happening inside each network. Try these experiments: (1) Set batch_size=1 and watch BN collapse. (2) Set depth=32 and watch "None" explode. (3) Set lr=1.0 and watch Post-LN diverge. (4) Compare LN vs RMSNorm — nearly identical curves.
Normalization Racing Arena

Six methods train simultaneously on the same task, each showing its characteristic training dynamics. Find each method's failure mode.


What to Discover

Experiment 1: Batch size = 1. Drag the batch size slider all the way left. BatchNorm now computes its statistics from a single sample. Each feature's "mean" is just the value itself and its "variance" is zero, so x − μ = 0 and every output collapses to β. BN's loss flatlines or oscillates wildly. Every other method is unaffected because none of them depends on batch statistics. This dependence on the batch is a major reason BN is essentially absent from LLMs. It is also why vision tasks that can only fit small batches, such as high-resolution images, often avoid it.
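The batch-size-1 collapse can be reproduced in a few lines. This sketch assumes a fully connected layer with D features and training-mode BatchNorm, meaning statistics come from the current batch only and no running averages are used. `batch_norm_train` and `layer_norm` are illustrative helpers, not library functions.

```python
import math

def batch_norm_train(batch, gamma, beta, eps=1e-5):
    """Training-mode BatchNorm for a (B, D) batch: statistics per feature, across the batch axis."""
    B, D = len(batch), len(batch[0])
    out = [[0.0] * D for _ in range(B)]
    for j in range(D):
        col = [row[j] for row in batch]
        mu = sum(col) / B
        var = sum((v - mu) ** 2 for v in col) / B   # biased variance, used for normalization
        for i in range(B):
            out[i][j] = gamma[j] * (batch[i][j] - mu) / math.sqrt(var + eps) + beta[j]
    return out

def layer_norm(row, gamma, beta, eps=1e-5):
    """LayerNorm for one sample: statistics across its own features."""
    mu = sum(row) / len(row)
    var = sum((v - mu) ** 2 for v in row) / len(row)
    return [g * (v - mu) / math.sqrt(var + eps) + b for v, g, b in zip(row, gamma, beta)]

sample = [[3.0, -1.0, 0.5, 2.0]]             # batch size 1
gamma, beta = [1.0] * 4, [0.1, 0.2, 0.3, 0.4]
print(batch_norm_train(sample, gamma, beta))  # every feature collapses to beta
print(layer_norm(sample[0], gamma, beta))     # per-sample statistics keep the signal
```

With one sample, x − μ is exactly zero for every feature, so BN returns β no matter what the input was. LayerNorm computes its statistics within the sample, so the input's pattern survives.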

Experiment 2: Depth = 32 layers. The "None" network's loss explodes immediately: activations compound through 32 layers with no normalization until they overflow. BN (given an adequate batch size) and GroupNorm both handle depth well. Pre-LN, with its gradient highway, is the most stable at extreme depth.
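Experiment 2 is Chapter 0's hand calculation, scaled up. The sketch below assumes 32 bias-free linear layers of width 32 with random Gaussian weights. The per-layer `gain` is a hypothetical knob: 1.5 makes activations explode and 0.5 makes them vanish. With `normalize=True`, a parameter-free RMS normalization is applied after each layer.

```python
import math
import random

def rms(v):
    """Root-mean-square magnitude of a vector."""
    return math.sqrt(sum(t * t for t in v) / len(v))

def forward(depth, d=32, gain=1.5, normalize=False, seed=0):
    """Push a random vector through `depth` random linear layers (no bias, no activation)."""
    rng = random.Random(seed)
    x = [rng.gauss(0.0, 1.0) for _ in range(d)]
    std = gain / math.sqrt(d)          # each layer scales the RMS by roughly `gain`
    for _ in range(depth):
        x = [sum(rng.gauss(0.0, std) * t for t in x) for _ in range(d)]
        if normalize:                  # RMS-normalize after every layer (gamma = 1)
            r = math.sqrt(sum(t * t for t in x) / d + 1e-6)
            x = [t / r for t in x]
    return rms(x)

print("no norm, gain 1.5:", forward(32))                  # grows by roughly 1.5x per layer
print("no norm, gain 0.5:", forward(32, gain=0.5))        # shrinks toward zero
print("RMS-normalized:   ", forward(32, normalize=True))  # stays at about 1.0
```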

Experiment 3: Learning rate = 1.0. Drag the LR slider all the way right. With this aggressive learning rate, the unnormalized network diverges instantly. BN diverges next (less stable gradients). LayerNorm and GroupNorm struggle. Pre-LN survives the longest because its gradient highway prevents the amplification that causes divergence.

Experiment 4: Compare LN vs RMSNorm. Under all conditions, the teal (LN) and purple (RMSNorm) curves track each other almost exactly. That's the whole point — RMSNorm matches LayerNorm's quality while being ~15% faster. The speed difference isn't visible in the simulation, but the quality equivalence is.
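One part of the intuition for why the two curves overlap can be checked directly. The sketch below is plain Python and parameter-free (γ = 1, β = 0), and the test vectors are arbitrary. When a vector is already zero-mean, LayerNorm and RMSNorm return the same result, because the variance around the mean equals the mean of squares. When the vector carries an offset, LayerNorm removes it and RMSNorm keeps it. RMSNorm also needs one statistic instead of two.

```python
import math

def layer_norm(x, eps=1e-5):
    """Two statistics: the mean, then the variance around it (gamma = 1, beta = 0)."""
    mu = sum(x) / len(x)
    var = sum((t - mu) ** 2 for t in x) / len(x)
    return [(t - mu) / math.sqrt(var + eps) for t in x]

def rms_norm(x, eps=1e-5):
    """One statistic: the mean of squares. No centering (gamma = 1)."""
    ms = sum(t * t for t in x) / len(x)
    return [t / math.sqrt(ms + eps) for t in x]

centered = [1.5, -0.5, -2.0, 1.0]       # mean is exactly 0
shifted = [t + 3.0 for t in centered]   # same pattern, mean 3

print([round(t, 3) for t in layer_norm(centered)])
print([round(t, 3) for t in rms_norm(centered)])   # identical to the line above
print([round(t, 3) for t in layer_norm(shifted)])  # LayerNorm removes the offset
print([round(t, 3) for t in rms_norm(shifted)])    # RMSNorm keeps it
```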

The Arena reveals a crucial insight: there is NO universally best normalization. BatchNorm is fastest when batch size is large but catastrophic when it's small. LayerNorm is rock-solid but ~15% slower than RMSNorm. GroupNorm is the safe vision default but adds a hyperparameter (G). Pre-LN solves depth scaling but needs residual management. Know the failure modes, choose accordingly.

Failure Mode Summary

| Method | Fails when... | Symptom | Fix |
|---|---|---|---|
| None | Depth > 8 or LR > 0.01 | Loss explodes (NaN) | Add any normalization |
| BatchNorm | Batch size < 4 | Noisy/collapsed statistics | Use GroupNorm or LayerNorm |
| LayerNorm | Extreme depth + Post-LN | Vanishing gradients | Switch to Pre-LN placement |
| RMSNorm | Same as LayerNorm | Same as LayerNorm | Same as LayerNorm |
| GroupNorm | G too large for few channels | Underfitting | Reduce G or use LayerNorm |
| Pre-LN | Very deep (>100 layers) | Residual magnitude growth | Add residual scaling (DeepNorm) |

Chapter 9: Cheat Sheet & Connections

You now understand the complete normalization toolkit — from BatchNorm to RMSNorm, from Post-LN to Pre-LN. This chapter is your practical reference. No new concepts. Just the formulas, the decision guide, and the connections to where you go next.

Every Formula at a Glance

| Method | Formula | Statistics over | Parameters | Compute |
|---|---|---|---|---|
| BatchNorm | γ · (x − μ_B) / σ_B + β | Batch (B, H, W) | γ, β per feature | 2 passes + running stats |
| LayerNorm | γ · (x − μ_L) / σ_L + β | Features (D) or (C, H, W) | γ, β per feature | 2 passes |
| RMSNorm | γ · x / RMS(x) | Features (D) | γ per feature | 1 pass |
| GroupNorm | γ · (x − μ_G) / σ_G + β | Group (C/G, H, W) | γ, β per channel | 2 passes per group |
| InstanceNorm | γ · (x − μ_I) / σ_I + β | Spatial (H, W) | γ, β per channel | 2 passes per channel |
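The four centered methods in this table share one formula and differ only in which elements share a mean and standard deviation. The sketch below illustrates that idea on a tiny (B, C, H, W) tensor stored as a Python dict, with γ = 1 and β = 0. `normalize`, `KEYS`, and `group_key` are hypothetical names. The sketch also shows the spectrum from Chapter 6: GroupNorm with G = 1 pools the same elements as LayerNorm, and with G = C it pools the same elements as InstanceNorm.

```python
import math
import random
from collections import defaultdict

def normalize(x, key, eps=1e-5):
    """Normalize a (B, C, H, W) tensor stored as {(b, c, h, w): value}.

    key(b, c) names the statistics group of each element; every element with the
    same key shares one mean and one standard deviation. gamma = 1, beta = 0.
    """
    groups = defaultdict(list)
    for (b, c, h, w), v in x.items():
        groups[key(b, c)].append(v)
    stats = {}
    for k, vals in groups.items():
        mu = sum(vals) / len(vals)
        var = sum((v - mu) ** 2 for v in vals) / len(vals)
        stats[k] = (mu, math.sqrt(var + eps))
    out = {}
    for (b, c, h, w), v in x.items():
        mu, sd = stats[key(b, c)]
        out[(b, c, h, w)] = (v - mu) / sd
    return out

KEYS = {
    "BatchNorm": lambda b, c: c,          # per channel, pooled over (B, H, W)
    "LayerNorm": lambda b, c: b,          # per sample, pooled over (C, H, W)
    "InstanceNorm": lambda b, c: (b, c),  # per sample and channel, pooled over (H, W)
}

def group_key(G, C):
    """GroupNorm: per sample, channels split into G contiguous groups (C divisible by G)."""
    size = C // G
    return lambda b, c: (b, c // size)

rng = random.Random(0)
B, C, H, W = 2, 4, 3, 3
x = {(b, c, h, w): rng.gauss(2.0, 3.0)
     for b in range(B) for c in range(C) for h in range(H) for w in range(W)}

ln, gn1 = normalize(x, KEYS["LayerNorm"]), normalize(x, group_key(1, C))
inst, gnC = normalize(x, KEYS["InstanceNorm"]), normalize(x, group_key(C, C))
print(max(abs(ln[k] - gn1[k]) for k in x))    # G = 1: same pooling as LayerNorm
print(max(abs(inst[k] - gnC[k]) for k in x))  # G = C: same pooling as InstanceNorm
```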

The Decision Flowchart

Follow the path that matches your situation:

What's your architecture? Start at the first branch point:
Transformer / LLM: use RMSNorm with Pre-LN placement. This is the modern default (LLaMA, Mistral, Gemma).
Transformer encoder (BERT-style): use LayerNorm. Pre-LN and Post-LN both work for encoders.
CNN: is a large batch guaranteed? Yes → BatchNorm (fastest). No → GroupNorm (G=32).
Style transfer / image generation: use InstanceNorm (it strips per-channel style statistics).
Very deep (>48 layers)? Use Pre-LN + residual scaling, or DeepNorm, and test gradient norms.
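The flowchart above can also be read as a small function. `choose_normalization` is a hypothetical helper, not a library API, and it simply encodes the branches as written. The >48-layer check is applied to transformer architectures only, since that branch is about Pre-LN placement.

```python
def choose_normalization(arch, large_batch=True, num_layers=12):
    """Hypothetical helper mirroring the decision flowchart (not a library API).

    arch: "llm", "encoder", "cnn", or "style" (style transfer / image generation).
    """
    if arch in ("llm", "encoder") and num_layers > 48:
        norm = "RMSNorm" if arch == "llm" else "LayerNorm"
        return f"{norm}, Pre-LN + residual scaling (or DeepNorm); test gradient norms"
    if arch == "llm":
        return "RMSNorm, Pre-LN"
    if arch == "encoder":
        return "LayerNorm, Pre-LN or Post-LN"
    if arch == "cnn":
        return "BatchNorm" if large_batch else "GroupNorm (G=32)"
    if arch == "style":
        return "InstanceNorm"
    raise ValueError(f"unknown architecture: {arch!r}")

print(choose_normalization("llm", num_layers=32))
print(choose_normalization("cnn", large_batch=False))
print(choose_normalization("llm", num_layers=96))
```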

Symbol Glossary

| Symbol | Meaning | Typical values |
|---|---|---|
| B | Batch size | 1–4096 |
| D | Feature dimension (d_model in transformers) | 256–8192 |
| C | Number of channels (vision) | 3–2048 |
| H, W | Spatial height and width | 7–224 |
| G | Number of groups (GroupNorm) | 16–64 (default: 32) |
| γ | Learnable scale parameter | Initialized to 1 |
| β | Learnable shift parameter | Initialized to 0 |
| ε | Small constant for numerical stability | 1e-5 or 1e-6 |
| μ | Mean of activations | Varies |
| σ | Standard deviation of activations | Varies |

PyTorch One-Liner Reference

```python
import torch.nn as nn

# BatchNorm: for CNNs with large batches
bn  = nn.BatchNorm2d(num_features=256)

# LayerNorm: for transformers (encoder-style)
ln  = nn.LayerNorm(normalized_shape=768)

# RMSNorm: for LLMs (requires PyTorch 2.4+)
rms = nn.RMSNorm(normalized_shape=4096)

# GroupNorm: for CNNs with small batches (256 channels in 32 groups of 8)
gn  = nn.GroupNorm(num_groups=32, num_channels=256)

# InstanceNorm: for style transfer. Its affine flag defaults to False,
# so pass affine=True to get the learnable per-channel γ and β
isn = nn.InstanceNorm2d(num_features=256, affine=True)
```

Summary of Everything

Ch 0: Why Normalize?
Activations drift and explode without normalization.
Ch 1: Centering
Subtract mean, divide by std, scale & shift. The universal formula.
Ch 2: BatchNorm
Normalize across batch dim. Fast, but needs large batches.
Ch 3: Train/Eval Split
Running stats via EMA. The model.eval() bug.
Ch 4: LayerNorm
Normalize per-sample. No batch dependency. Transformer default.
Ch 5: RMSNorm
Skip centering. 1 pass, 15% faster. Modern LLM default.
Ch 6: Group & Instance
Spectrum of channel grouping for vision tasks.
Ch 7: Pre-LN vs Post-LN
Placement determines gradient flow and training stability.
Ch 8: Arena
Race them all. Know the failure modes.

Connections

Normalization doesn't exist in isolation. Here's where to go next:

Key Papers

| Paper | Year | Contribution |
|---|---|---|
| Ioffe & Szegedy, "Batch Normalization" | 2015 | Introduced BN; matched a state-of-the-art ImageNet model's accuracy in 14× fewer training steps |
| Ba, Kiros & Hinton, "Layer Normalization" | 2016 | Batch-independent normalization, proposed for RNNs and later adopted by Transformers |
| Ulyanov et al., "Instance Normalization" | 2016 | Per-channel normalization for style transfer |
| Wu & He, "Group Normalization" | 2018 | Bridge between BN and LN for small-batch vision |
| Zhang & Sennrich, "Root Mean Square Layer Normalization" | 2019 | Removed centering; comparable quality with 7% to 64% lower running time across the paper's models |
| Xiong et al., "On Layer Normalization in the Transformer Architecture" | 2020 | Formal analysis of Pre-LN vs Post-LN gradient flow |
"Good engineering is not about complexity — it's about knowing which simplifications are safe." RMSNorm removed centering. Pre-LN moved one line of code. GroupNorm added one hyperparameter. Each was a small change with outsized impact. The lesson isn't about normalization — it's about understanding a system well enough to know where to simplify.
For a new LLM project using a 48-layer transformer, which normalization setup would you recommend?