The layer that tamed deep networks — from Batch Normalization to the RMSNorm inside every modern LLM.
You're training a 10-layer network. After 100 batches, you print the activation statistics at each layer. Layer 1 looks fine — values between -2 and 2. Layer 5 is already weird — values clustered near zero. Layer 10? All zeros or all infinities. Your network is dead.
This isn't a hypothetical. Before 2015, training networks deeper than ~10 layers was a nightmare. Researchers spent weeks tuning learning rates, initialization schemes, and gradient clipping heuristics. The problem was fundamental: deep networks are chains of matrix multiplications, and chains of multiplications do one of two things — explode or collapse.
Let's see exactly why.
A neural network layer does this: take an input vector, multiply it by a weight matrix, add a bias, and apply a nonlinearity. The output becomes the input to the next layer. Each multiplication either stretches or shrinks the numbers. Stack enough stretches, and values hit infinity. Stack enough shrinks, and values hit zero.
Let's trace this by hand with a tiny example. We'll skip the nonlinearity to see the raw multiplication effect.
Layer 1: Weight matrix W1 = [[1.5, 0.8], [0.6, 1.3]].
Input: [1.0, -0.5]
After layer 1: [1.1, -0.05]. Looks reasonable. Values are still small.
Layer 2: Weight matrix W2 = [[1.4, 0.9], [0.7, 1.6]].
Input: [1.1, -0.05]
After layer 2: [1.495, 0.69]. Still okay, but the magnitudes are growing.
Layer 3: Weight matrix W3 = [[1.3, 1.1], [0.5, 1.4]].
Input: [1.495, 0.69]
After layer 3: [2.70, 1.71]. In just 3 layers, the values roughly tripled.
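The hand trace is easy to check in code. Here is a minimal NumPy sketch that uses the same three matrices and input, with no bias and no nonlinearity to match the trace. It also prints the vector's length at each step:

```python
import numpy as np

# Same weights and input as the hand trace (no bias, no nonlinearity)
W1 = np.array([[1.5, 0.8], [0.6, 1.3]])
W2 = np.array([[1.4, 0.9], [0.7, 1.6]])
W3 = np.array([[1.3, 1.1], [0.5, 1.4]])

x = np.array([1.0, -0.5])
print(f"input:   {x}, length={np.linalg.norm(x):.3f}")
for i, W in enumerate([W1, W2, W3], start=1):
    x = W @ x  # one linear layer: matrix-vector product
    print(f"layer {i}: {np.round(x, 4)}, length={np.linalg.norm(x):.3f}")
```

The length goes from about 1.12 to about 3.20, which is roughly 2.9× growth in three layers.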
Three layers, and we've already roughly tripled the magnitude. Now imagine 50 layers. Or 100. Suppose the weights tend to stretch vectors, meaning their dominant eigenvalue is above 1.0 (careless random initialization produces this easily). Then each layer multiplies the scale again. After n layers, the activation magnitude grows roughly as λⁿ, where λ is the dominant eigenvalue. If λ = 1.2:
| Layers | Scale Factor (1.2ⁿ) | What Happens |
|---|---|---|
| 5 | 2.49 | Fine |
| 10 | 6.19 | Getting large |
| 20 | 38.3 | Saturating activations |
| 50 | 9,100 | Activations and gradients thousands of times too large |
| 100 | 82,817,974 | Beyond float16's maximum (≈ 65,504): overflow, then NaN |
And if λ = 0.8 (weights slightly too small), the same table runs in reverse: after 50 layers, activations are scaled by 0.8⁵⁰ ≈ 0.000014. Everything is effectively zero. The network outputs the same thing regardless of input.
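The table is just λ raised to the number of layers. The quick sketch below reproduces both the growing case (λ = 1.2) and the shrinking case (λ = 0.8). It treats λ as one fixed per-layer factor, which is a simplification: real layers stretch different directions by different amounts. It also compares the result against the largest finite float16 value:

```python
import numpy as np

# Magnitude after n layers if every layer scales the signal by lam: lam ** n
for lam in (1.2, 0.8):
    for n in (5, 10, 20, 50, 100):
        print(f"lambda={lam}, layers={n:>3}: scale = {lam ** n:.6g}")

# Largest finite float16 value; 1.2 ** 100 is far beyond it
fp16_max = float(np.finfo(np.float16).max)
print(fp16_max, 1.2 ** 100 > fp16_max)
```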
But the explosion/collapse problem is only half the story. There's a subtler issue: even when activations don't explode, their distribution keeps changing.
Think about it from layer 5's perspective. Layer 5 receives its input from layer 4. During training, layer 4's weights are updated every batch. That means the distribution of layer 4's output — which is layer 5's input — changes every single step. Layer 5 is trying to learn a mapping, but the input it's mapping from is a moving target.
This phenomenon is called internal covariate shift: the distribution of each layer's inputs shifts continuously during training because the parameters of the preceding layers change.
Imagine you're learning to hit a baseball, but someone keeps moving the pitcher's mound. You'd adapt, but slowly and painfully. That's what every layer in a deep network experiences.
A healthy activation distribution has two properties: it stays centered, and its spread stays bounded. The ideal: each layer's activations have mean ≈ 0 and standard deviation ≈ 1. This keeps the signal strong but bounded, and it keeps nonlinearities in their sensitive region (neither saturated nor dead).
What if, after every layer, we forcefully rescaled the activations to have mean 0 and standard deviation 1? That's normalization. That single idea — shoved into the middle of the network as a differentiable operation — is what unlocked deep learning at scale.
Before normalization (2015 and earlier), training a 20-layer network was a research contribution. After normalization, 100+ layer networks became routine. ResNets (up to 152 layers), Transformers such as GPT-3 (96 layers), and today's LLMs all use normalization in every single block.
The simulation below shows activation distributions at layer 8 of a deep network during training. Without normalization, watch the histogram drift, collapse, or explode. Toggle normalization on and see it snap to a healthy bell curve.
[Interactive simulation: layer 8 activation histogram across training steps, with and without normalization.]
Notice the pattern. Without normalization, early in training (step 0-20), things look okay. But by step 40-60, the distribution has shifted far from zero — the mean drifts, the variance balloons. By step 80-100, you see one of two failure modes: either the histogram collapses to a spike near zero (vanishing activations) or spreads so wide that most values are in the extreme tails (exploding activations).
With normalization? The histogram stays centered at zero with unit variance at every single step. The shape may change slightly (the network is still learning), but the scale and location are locked. This is what lets the next layer learn a stable mapping.
```python
import torch
import torch.nn as nn

# 10-layer network with NO normalization
torch.manual_seed(42)
x = torch.randn(64, 256)  # batch of 64, 256 features
for i in range(10):
    W = torch.randn(256, 256) * 0.5  # random weights
    x = x @ W                        # linear transform
    x = torch.relu(x)                # activation
    print(f"Layer {i}: mean={x.mean():.4f}, std={x.std():.4f}")
```
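Run this yourself. A weight scale of 0.5 is far above He initialization's √(2/256) ≈ 0.09. With that scale, the std grows by a factor of about 5 to 6 per layer and reaches the tens of millions by Layer 9. Shrink the scale (say, to 0.05) and the activations decay toward zero instead. The exact trajectory depends on the seed and the weight scale, but the pattern (instability) is universal.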
Now add one line per layer:
```python
import torch

torch.manual_seed(42)
x = torch.randn(64, 256)
for i in range(10):
    W = torch.randn(256, 256) * 0.5
    x = x @ W
    # Normalize: subtract mean, divide by std
    x = (x - x.mean()) / (x.std() + 1e-5)
    x = torch.relu(x)
    print(f"Layer {i}: mean={x.mean():.4f}, std={x.std():.4f}")
```
Now the mean hovers near 0.4 (not exactly zero because ReLU clips negatives) and the std stays near 0.6 — stable, layer after layer. The network can actually learn.
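Where do 0.4 and 0.6 come from? If the normalized pre-activations are roughly standard normal, ReLU keeps only the positive half. For z ~ N(0, 1), E[ReLU(z)] = 1/√(2π) ≈ 0.399 and Var[ReLU(z)] = 1/2 − 1/(2π), so the std is about 0.584. The Monte Carlo check below assumes exactly Gaussian inputs, which the real layer only approximates:

```python
import math
import torch

torch.manual_seed(0)
z = torch.randn(1_000_000)  # stand-in for normalized pre-activations
r = torch.relu(z)

mean_theory = 1 / math.sqrt(2 * math.pi)          # about 0.399
std_theory = math.sqrt(0.5 - 1 / (2 * math.pi))   # about 0.584
print(f"empirical: mean={r.mean():.4f}, std={r.std():.4f}")
print(f"theory:    mean={mean_theory:.4f}, std={std_theory:.4f}")
```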
Notice the + 1e-5 in (x - x.mean()) / (x.std() + 1e-5). Without it, if all the values are identical (std = 0), we divide zero by zero and get NaN. This epsilon appears in every normalization variant: BatchNorm, LayerNorm, and RMSNorm. The layers usually add it to the variance, inside the square root, and its default is typically around 1e-5 or 1e-6. That is small enough not to distort the normalization and large enough to prevent catastrophe.
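The demonstration below shows what the epsilon guards against. It uses a deliberately constant tensor, so the standard deviation is exactly zero:

```python
import torch

x = torch.full((4,), 3.0)  # every value identical, so std = 0

no_eps = (x - x.mean()) / x.std()             # 0 / 0: every entry is nan
with_eps = (x - x.mean()) / (x.std() + 1e-5)  # 0 / 1e-5: every entry is 0

print(no_eps)
print(with_eps)
```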
Over the next chapters, we'll build four normalization methods from scratch. They all do fundamentally the same thing (center and/or scale activations). Where they differ is in which activations they compute the statistics over.
By the end, you'll understand exactly what happens inside nn.LayerNorm and nn.BatchNorm1d. You'll also see why many modern transformers use RMSNorm instead of LayerNorm, and how to pick the right normalization for any architecture.
But first, we need to understand the two operations that every normalizer shares: centering and scaling.
Before we build any specific normalizer (BatchNorm, LayerNorm, RMSNorm), we need to understand the two operations they are built from. Every normalization method in deep learning combines two steps: subtract the mean and divide by the standard deviation. (RMSNorm, as we'll see, keeps only a variant of the second.)
That's it. The entire field of normalization is variations on these two arithmetic operations. The only question is: which numbers do you compute the mean and standard deviation over? Different answers give different methods. But the math is always the same.
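Consider five activation values from a single layer: x = [4, 8, 6, 2, 10].
The mean (μ) is the average:
$$\mu = \frac{1}{N}\sum_{i=1}^{N} x_i = \frac{4 + 8 + 6 + 2 + 10}{5} = \frac{30}{5} = 6$$
Centering means subtracting the mean from every value:
$$x_i - \mu = [4-6,\ 8-6,\ 6-6,\ 2-6,\ 10-6] = [-2,\ 2,\ 0,\ -4,\ 4]$$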
Check: the mean of [-2, 2, 0, -4, 4] is (-2+2+0-4+4)/5 = 0/5 = 0. Centering always produces a zero-mean distribution. Why does this help? Because it removes the bias in the activation distribution. If every activation is centered around 50 instead of 0, the next layer has to learn a large bias just to compensate. Centering puts all layers on a neutral footing.
After centering, our values are [-2, 2, 0, -4, 4]. The mean is 0, but the spread varies. These values range from -4 to 4. A different layer might have centered values like [-0.001, 0.002, 0, -0.003, 0.002] — same mean (zero) but vastly different spread.
We want unit spread. The measure of spread is the standard deviation (σ), computed from the variance (σ²):
$$\sigma^2 = \frac{1}{N}\sum_{i=1}^{N}(x_i - \mu)^2, \qquad \sigma = \sqrt{\sigma^2}$$
Let's compute it step by step for our centered values [-2, 2, 0, -4, 4]:
| xᵢ | xᵢ − μ (μ = 0, already centered) | (xᵢ − μ)² |
|---|---|---|
| -2 | -2 | 4 |
| 2 | 2 | 4 |
| 0 | 0 | 0 |
| -4 | -4 | 16 |
| 4 | 4 | 16 |
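Summing the last column gives 4 + 4 + 0 + 16 + 16 = 40, so σ² = 40 / 5 = 8 and σ = √8 ≈ 2.828. Now divide each centered value by σ:
$$\hat{x} = \left[\tfrac{-2}{2.828},\ \tfrac{2}{2.828},\ \tfrac{0}{2.828},\ \tfrac{-4}{2.828},\ \tfrac{4}{2.828}\right] = [-0.707,\ 0.707,\ 0,\ -1.414,\ 1.414]$$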
Let's verify. Mean of the normalized values: (-0.707 + 0.707 + 0 - 1.414 + 1.414) / 5 = 0 / 5 = 0. Variance: (0.5 + 0.5 + 0 + 2.0 + 2.0) / 5 = 5.0 / 5 = 1.0. Standard deviation = √1 = 1. Perfect. Mean 0, std 1.
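Combining both steps into one formula:
$$\hat{x}_i = \frac{x_i - \mu}{\sigma}$$
Implementations actually compute the variance form. They calculate σ² directly and add a small ε inside the square root, which keeps both the output and the gradient of the square root finite even when σ² = 0:
$$\hat{x}_i = \frac{x_i - \mu}{\sqrt{\sigma^2 + \epsilon}}$$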
That ε (epsilon, typically 1e-5) prevents division by zero when the variance happens to be exactly zero.
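ε matters in the backward pass too. The small sketch below uses inputs that are all equal (so σ² = 0) and a toy loss: a weighted sum of the normalized values, chosen only for illustration. When ε is added to σ, the derivative of √σ² is still infinite at zero and autograd returns NaN. When ε goes inside the square root, the gradient stays finite:

```python
import torch

def grad_of_normalize(eps_inside_sqrt: bool) -> torch.Tensor:
    x = torch.full((5,), 2.0, requires_grad=True)   # identical values: variance = 0
    mu = x.mean()
    var = ((x - mu) ** 2).mean()                     # biased variance
    if eps_inside_sqrt:
        x_hat = (x - mu) / torch.sqrt(var + 1e-5)    # (x - mu) / sqrt(sigma^2 + eps)
    else:
        x_hat = (x - mu) / (torch.sqrt(var) + 1e-5)  # (x - mu) / (sigma + eps)
    w = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0])     # toy weights for a scalar loss
    loss = (x_hat * w).sum()
    loss.backward()
    return x.grad

print(grad_of_normalize(eps_inside_sqrt=False))  # nan: d sqrt(v)/dv is infinite at v = 0
print(grad_of_normalize(eps_inside_sqrt=True))   # finite values
```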
Here's a subtle problem. We just forced every activation distribution to have mean 0 and std 1. But what if that's NOT the optimal distribution? What if, for some particular layer, the network would actually perform better if activations had mean 0.5 and std 2.3?
By normalizing to (0, 1), we've constrained the network. We've reduced its representational power. The solution is elegant: after normalizing, apply a learnable linear transformation:
$$y_i = \gamma\,\hat{x}_i + \beta$$
Where γ (scale, initialized to 1) and β (shift, initialized to 0) are learned parameters — they receive gradients and are updated by the optimizer just like weights and biases.
At initialization, γ = 1 and β = 0, so y = x̂ (the normalized value). As training progresses, the network can learn whatever scale and shift is optimal. In the extreme case, it can learn γ = σ and β = μ, which would perfectly undo the normalization.
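Here is a quick numerical check of that extreme case, reusing the five values from the worked example below. In this sketch γ and β are set by hand to the input's own statistics rather than learned:

```python
import numpy as np

x = np.array([4.0, 8.0, 6.0, 2.0, 10.0])
mu, var = x.mean(), x.var()  # np.var divides by N (biased variance)
eps = 1e-5
x_hat = (x - mu) / np.sqrt(var + eps)

# If the "learned" parameters happen to equal the input's own statistics...
gamma = np.sqrt(var + eps)
beta = mu
y = gamma * x_hat + beta
print(y)  # ...the original x comes back (up to floating-point rounding)
```

In a real BatchNorm layer, μ and σ change from batch to batch. Fixed learned values of γ and β can therefore only undo the normalization approximately.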
Let's trace the complete normalization pipeline with a concrete example. Input: x = [4, 8, 6, 2, 10]. Learnable parameters: γ = 1.5, β = -0.3 (pretend we're mid-training).
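Step 1 — Compute statistics: μ = (4 + 8 + 6 + 2 + 10)/5 = 6, σ² = (4 + 4 + 0 + 16 + 16)/5 = 8, σ = √8 ≈ 2.828
Step 2 — Normalize: x̂ = (x − 6)/2.828 = [−0.707, 0.707, 0, −1.414, 1.414]
Step 3 — Scale and shift: y = 1.5 · x̂ − 0.3 = [−1.361, 0.761, −0.300, −2.421, 1.821]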
The output has mean = (-1.361 + 0.761 - 0.300 - 2.421 + 1.821)/5 = -1.500/5 = -0.300 = β. The std = 1.5 × 1.0 = 1.5 = γ. This is always true, up to the tiny ε (and with std = |γ| if γ is negative): after applying γ and β, the output has mean = β and std = γ. The learnable parameters directly control the output distribution.
The widget below shows data points on a number line. Drag them to change their values, then toggle centering and scaling to see the effect. Watch the mean line and standard deviation bars update in real time.
[Interactive widget: data points on a number line with centering and scaling toggles, showing the mean and ±1σ markers before and after normalization.]
```python
import numpy as np

def normalize(x, eps=1e-5):
    """Normalize to mean=0, std=1."""
    mu = x.mean()
    var = ((x - mu) ** 2).mean()
    x_norm = (x - mu) / np.sqrt(var + eps)
    return x_norm

x = np.array([4.0, 8.0, 6.0, 2.0, 10.0])
x_norm = normalize(x)
print(f"Original: mean={x.mean():.2f}, std={x.std():.2f}")
print(f"Normalized: mean={x_norm.mean():.4f}, std={x_norm.std():.4f}")
print(f"Values: {x_norm}")

# With learnable gamma and beta
gamma = 1.5
beta = -0.3
y = gamma * x_norm + beta
print(f"After scale+shift: mean={y.mean():.2f}, std={y.std():.2f}")
# Output: mean=-0.30, std=1.50 (mean = beta, std = gamma)
```
Compare to the same computation with PyTorch tensor operations (which we'll use in later chapters):
```python
import torch

x = torch.tensor([4.0, 8.0, 6.0, 2.0, 10.0])
x_norm = (x - x.mean()) / x.std(unbiased=False)
print(x_norm)  # tensor([-0.7071, 0.7071, 0.0000, -1.4142, 1.4142])
```
Notice unbiased=False in the PyTorch code. By default, torch.std() uses Bessel's correction (divides by N−1, not N), which gives an unbiased estimate of the population variance from a sample. Normalization layers, however, use the biased variance (divide by N) in the forward pass, because we're normalizing the actual batch or features we have, not estimating a population parameter. This N vs. N−1 difference is a common source of small numerical mismatches when implementing normalization from scratch.
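Here's the punchline, and it's the single most important idea in this entire lesson: every normalization layer computes the same formula, y = γ · (x − μ)/√(σ² + ε) + β. The only thing that distinguishes one method from another is which set of values μ and σ² are computed over. (RMSNorm goes one step further and drops μ.)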
In the next chapter, we'll see the first and most famous instantiation: computing the mean and variance across the batch dimension.
What if we normalize using statistics from the entire mini-batch? In 2015, Sergey Ioffe and Christian Szegedy proposed exactly this in one of the most cited papers in deep learning history. The idea was deceptively simple: for each feature, compute the mean and variance across all samples in the batch, then normalize. They called it Batch Normalization (BN).
BN transformed deep learning overnight. In the original paper, a BN version of an Inception image classifier matched the baseline's accuracy in 14 times fewer training steps. Architectures that were impossible to optimize suddenly worked. Learning rates that would have caused divergence became safe. The paper's title understated the impact: "Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift."
To understand BN, we need to think about the shape of activations in a network. After a linear layer, the activation tensor has shape [B, D] — B samples in the batch, each with D features. Visualize this as a matrix where rows are samples and columns are features:
| | Feature 0 | Feature 1 | Feature 2 |
|---|---|---|---|
| Sample 0 | 1 | 5 | 3 |
| Sample 1 | 3 | 3 | 7 |
| Sample 2 | 5 | 7 | 1 |
| Sample 3 | 3 | 5 | 5 |
Batch Normalization computes statistics down each column — independently for each feature, across all samples in the batch. For feature 0, it looks at the values [1, 3, 5, 3] from all four samples, computes their mean and variance, and normalizes them. Feature 1 gets its own separate mean and variance from [5, 3, 7, 5]. Feature 2 from [3, 7, 1, 5].
Each feature is normalized independently. The statistics come from the batch.
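For each feature j in a mini-batch of B samples:
$$\mu_j = \frac{1}{B}\sum_{i=1}^{B} x_{ij}, \qquad \sigma_j^2 = \frac{1}{B}\sum_{i=1}^{B}(x_{ij} - \mu_j)^2$$
$$\hat{x}_{ij} = \frac{x_{ij} - \mu_j}{\sqrt{\sigma_j^2 + \epsilon}}, \qquad y_{ij} = \gamma_j\,\hat{x}_{ij} + \beta_j$$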
Note: there are D separate γj and βj parameters — one pair per feature. If you have 512 features, BatchNorm adds 512 learnable scale parameters and 512 learnable shift parameters. That's 1024 extra parameters per BN layer. Compared to the millions of weights in the network, this is negligible.
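The parameter count is easy to confirm with PyTorch's built-in layer. In this sketch, note that nn.BatchNorm1d stores γ under the name weight and β under the name bias:

```python
import torch.nn as nn

bn = nn.BatchNorm1d(512)

n_params = sum(p.numel() for p in bn.parameters())
print(n_params)                                     # gamma (weight) + beta (bias)
print([name for name, _ in bn.named_parameters()])  # ['weight', 'bias']
print(bn.weight.shape, bn.bias.shape)               # one entry per feature
```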
Let's work through BN on the 4×3 matrix above, step by step. ε = 0 for clarity (no numerical stability term).
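Feature 0: values across the batch = [1, 3, 5, 3]. μ = 12/4 = 3, σ² = (4 + 0 + 4 + 0)/4 = 2, σ = √2 ≈ 1.414. Normalized: [−1.414, 0, 1.414, 0].
Feature 1: values across the batch = [5, 3, 7, 5]. μ = 20/4 = 5, σ² = (0 + 4 + 4 + 0)/4 = 2, σ ≈ 1.414. Normalized: [0, −1.414, 1.414, 0].
Feature 2: values across the batch = [3, 7, 1, 5]. μ = 16/4 = 4, σ² = (1 + 9 + 9 + 1)/4 = 5, σ = √5 ≈ 2.236. Normalized: [−0.447, 1.342, −1.342, 0.447].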
Final normalized matrix:
| | Feature 0 | Feature 1 | Feature 2 |
|---|---|---|---|
| Sample 0 | -1.414 | 0 | -0.447 |
| Sample 1 | 0 | -1.414 | 1.342 |
| Sample 2 | 1.414 | 1.414 | -1.342 |
| Sample 3 | 0 | 0 | 0.447 |
Check any column: mean = 0, variance = 1. Each feature is independently normalized. But notice something important — sample 0's normalized value for feature 0 is -1.414, which depends on what samples 1, 2, and 3 had for feature 0. Every sample's output depends on every other sample in the batch.
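The coupling is easy to see directly. This sketch uses the same 4×3 matrix with γ = 1 and β = 0, so only the normalization itself is shown. It changes only sample 3, then compares sample 0's output before and after:

```python
import torch

def bn_normalize(x, eps=1e-5):
    # Per-feature statistics across the batch (dim 0); no gamma/beta
    mu = x.mean(dim=0)
    var = x.var(dim=0, unbiased=False)
    return (x - mu) / torch.sqrt(var + eps)

x = torch.tensor([[1., 5., 3.],
                  [3., 3., 7.],
                  [5., 7., 1.],
                  [3., 5., 5.]])
before = bn_normalize(x)[0]                 # sample 0's output

x_changed = x.clone()
x_changed[3] = torch.tensor([9., 9., 9.])   # change ONLY sample 3
after = bn_normalize(x_changed)[0]

print(before)
print(after)  # sample 0's output moved even though its own input did not
```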
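In backpropagation, we need the gradient of the loss with respect to each input xᵢⱼ. Because the mean and variance are computed from the inputs, the computation graph has three paths from xᵢⱼ to the loss: directly through x̂ᵢⱼ, through the shared mean μⱼ, and through the shared variance σⱼ². Let gᵢⱼ = γⱼ · ∂L/∂yᵢⱼ denote the gradient that arrives at x̂ᵢⱼ. The full gradient is:
$$\frac{\partial L}{\partial x_{ij}} = \frac{1}{B\sqrt{\sigma_j^2 + \epsilon}}\left(B\,g_{ij} - \sum_{k=1}^{B} g_{kj} - \hat{x}_{ij}\sum_{k=1}^{B} g_{kj}\,\hat{x}_{kj}\right)$$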
This looks intimidating, but the key insight is: changing one sample's input affects every other sample's normalized output through the shared mean and variance. The Jacobian is NOT diagonal. This coupling between samples is the fundamental difference between BN and normalizations that operate independently per sample.
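Autograd can confirm the standard closed-form BN backward pass:

∂L/∂xᵢⱼ = (1/(B·√(σⱼ² + ε))) · (B·gᵢⱼ − Σₖ gₖⱼ − x̂ᵢⱼ · Σₖ gₖⱼ x̂ₖⱼ), with gᵢⱼ = γⱼ · ∂L/∂yᵢⱼ.

The two sums over k are exactly what couple the samples. This sketch uses random float64 data, and the upstream gradient dy is arbitrary test data:

```python
import torch

torch.manual_seed(0)
B, D, eps = 8, 4, 1e-5
x = torch.randn(B, D, dtype=torch.float64, requires_grad=True)
gamma = torch.randn(D, dtype=torch.float64)
beta = torch.randn(D, dtype=torch.float64)

# Forward pass with per-feature batch statistics
mu = x.mean(dim=0)
var = x.var(dim=0, unbiased=False)
x_hat = (x - mu) / torch.sqrt(var + eps)
y = gamma * x_hat + beta

# Arbitrary upstream gradient dL/dy, realized via the loss L = sum(dy * y)
dy = torch.randn(B, D, dtype=torch.float64)
(y * dy).sum().backward()

# Closed-form gradient, using g = dL/dx_hat = gamma * dL/dy
with torch.no_grad():
    g = gamma * dy
    inv_std = 1.0 / torch.sqrt(var + eps)
    dx = (inv_std / B) * (B * g - g.sum(dim=0) - x_hat * (g * x_hat).sum(dim=0))

print(torch.allclose(dx, x.grad))  # True: the formula matches autograd
```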
For years, everyone believed BN worked by reducing internal covariate shift — the shifting input distributions we discussed in Chapter 0. Then in 2018, Santurkar et al. published "How Does Batch Normalization Help Optimization?" and showed something surprising: BN doesn't actually reduce internal covariate shift much.
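What it does do is make the loss landscape smoother. Specifically, BN improves the Lipschitzness of the loss, so the loss changes at a more controlled rate. It also improves the smoothness of the gradients, so the gradient doesn't change too abruptly as you move through parameter space. On a smoother landscape, each gradient step is a more reliable predictor of where the loss is heading, so larger steps stay safe.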
This helps explain why BN lets you train with much larger learning rates; the original paper trained Inception variants with learning rates 5× and 30× the baseline's. The reason is not that the gradients are better-scaled (though they are), but that the landscape itself is smoother.
The simulation below shows a batch × features grid. Click on a feature column to highlight it and watch BN compute the mean and variance from the highlighted column, then normalize those values.
[Interactive grid: raw values (left) and BatchNorm output (right). Each column is normalized independently using statistics from all rows (samples) in that column.]
```python
import torch

def batchnorm_forward(x, gamma, beta, eps=1e-5):
    """BatchNorm forward pass.

    x:     [B, D]  batch of B samples, D features
    gamma: [D]     learnable scale
    beta:  [D]     learnable shift
    """
    # Statistics across batch dimension (axis=0)
    mu = x.mean(dim=0)                          # [D]: one mean per feature
    var = x.var(dim=0, unbiased=False)          # [D]: one var per feature
    # Normalize
    x_norm = (x - mu) / torch.sqrt(var + eps)   # [B, D]
    # Scale and shift
    out = gamma * x_norm + beta                 # [B, D]
    return out, mu, var

# Test with our 4×3 matrix
x = torch.tensor([[1., 5., 3.],
                  [3., 3., 7.],
                  [5., 7., 1.],
                  [3., 5., 5.]])
gamma = torch.ones(3)   # scale = 1 (no scaling)
beta = torch.zeros(3)   # shift = 0 (no shifting)
out, mu, var = batchnorm_forward(x, gamma, beta)
print(f"Means: {mu}")         # tensor([3., 5., 4.])
print(f"Variances: {var}")    # tensor([2., 2., 5.])
print(f"Normalized:\n{out}")  # matches our hand calc
```
Compare to PyTorch's built-in:
```python
import torch.nn as nn

# x: the 4×3 tensor from the previous example
bn = nn.BatchNorm1d(3, affine=True)
# At initialization: gamma=1, beta=0
print(bn.weight)  # Parameter containing: tensor([1., 1., 1.], requires_grad=True)  (gamma)
print(bn.bias)    # Parameter containing: tensor([0., 0., 0.], requires_grad=True)  (beta)
out_pt = bn(x)    # [4, 3]: in train mode, same result as our from-scratch version
```
You've trained a model with BatchNorm and it works great. You deploy it. A single example comes in. Batch size = 1. The mean of a single value is... itself. The variance is... zero. Every centered value is exactly zero, and with only ε left in the denominator, the layer outputs β no matter what the input was. Your model outputs garbage. What went wrong?
This is the fundamental problem with Batch Normalization at inference time: there is no batch. BN computes statistics across the batch dimension, but during inference, you typically process one input at a time (or small, variable-sized groups). The batch statistics are undefined or meaningless.
The solution: during training, keep a running average of the batch statistics. At inference time, use those saved statistics instead of computing them from the (non-existent) batch.
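During training, after each batch, BN updates two buffers using an exponential moving average (EMA):
$$\text{running\_mean} \leftarrow (1 - m)\cdot\text{running\_mean} + m\cdot\mu_{\text{batch}}$$
$$\text{running\_var} \leftarrow (1 - m)\cdot\text{running\_var} + m\cdot\sigma^2_{\text{batch}}$$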
Where m is the momentum parameter (default 0.1 in PyTorch). This is a weighted average where the current batch gets weight m and all previous history gets weight (1-m). Over time, the running statistics converge toward the true population statistics.
Note the confusing naming: PyTorch calls this "momentum" but it's the opposite convention from optimizer momentum. In PyTorch's BN, momentum = 0.1 means 10% new data, 90% old running average. Higher momentum = more weight on the current batch = faster adaptation but noisier estimate.
Let's trace the running mean for a single feature across 5 training batches. The running mean starts at 0 (PyTorch default). Momentum = 0.1.
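Batch 1: batch_mean = 3.0 → running_mean = 0.9 × 0 + 0.1 × 3.0 = 0.300
Way off from the true mean. The EMA has barely started.
Batch 2: batch_mean = 5.0 → running_mean = 0.9 × 0.300 + 0.1 × 5.0 = 0.770
Batch 3: batch_mean = 4.0 → running_mean = 0.9 × 0.770 + 0.1 × 4.0 = 1.093
Batch 4: batch_mean = 3.5 → running_mean = 0.9 × 1.093 + 0.1 × 3.5 = 1.334
Batch 5: batch_mean = 4.2 → running_mean = 0.9 × 1.334 + 0.1 × 4.2 = 1.621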
After 5 batches, the running mean is 1.621 — still far from the true mean of ~4.0. Let's see how it progresses:
| After N Batches | Approx. Running Mean | Gap from 4.0 |
|---|---|---|
| 5 | 1.62 | 2.38 |
| 10 | 2.57 | 1.43 |
| 20 | 3.46 | 0.54 |
| 50 | 3.96 | 0.04 |
| 100 | 3.999 | 0.001 |
The EMA is deliberately slow. With momentum 0.1, it takes about 50 batches to get within 1% of the true mean. This is by design — you want the running statistics to be stable, not jerked around by a single unusual batch.
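The running-mean trace can be reproduced with a few lines of plain Python. The first five batch means come from the trace above. After that, the sketch assumes every batch has a mean of exactly 4.0 (real batch means are noisy) and counts how many batches it takes to get within 1% of 4.0:

```python
momentum = 0.1
running_mean = 0.0  # PyTorch's initial running_mean
batch_means = [3.0, 5.0, 4.0, 3.5, 4.2]

history = []
for m_batch in batch_means:
    # 10% new batch, 90% old running value
    running_mean = (1 - momentum) * running_mean + momentum * m_batch
    history.append(running_mean)
    print(f"batch_mean={m_batch}: running_mean={running_mean:.4f}")

# Assumption: every later batch has mean exactly 4.0
n = len(batch_means)
while abs(running_mean - 4.0) > 0.01 * 4.0:
    running_mean = (1 - momentum) * running_mean + momentum * 4.0
    n += 1
print(f"within 1% of 4.0 after {n} batches")
```

The code carries full precision, so its fifth value is 1.6203 rather than the hand trace's rounded 1.621. Under the exactly-4.0 assumption, the loop stops at 44 batches, in line with the "about 50" estimate.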
PyTorch BatchNorm behaves completely differently in training mode vs evaluation mode:
| | model.train() | model.eval() |
|---|---|---|
| Statistics used | Batch statistics (μbatch, σ²batch) | Running statistics (running_mean, running_var) |
| Running stats | Updated after each forward pass | Frozen (no updates) |
| Output depends on batch | Yes — different batch = different output | No — output is deterministic |
| Batch size requirement | B ≥ 2 (need variance) | B ≥ 1 (any size works) |
The switch is a single line: model.eval() before inference, model.train() before training. Forgetting it is one of the most common BatchNorm bugs in production ML.
Always call model.eval() before inference. One line of code, hours of debugging. Every production ML tutorial should start here. When you load a saved model, the order of model.load_state_dict() and model.eval() doesn't matter, because loading weights doesn't change the train/eval flag. What matters is that model.eval() is in effect before the first inference forward pass.
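The sketch below shows the difference using synthetic data (a batch with mean around 10 and std around 4, chosen only to make the effect visible) and a freshly constructed layer. It also shows what happens with batch size 1 in each mode:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
bn = nn.BatchNorm1d(3)
x = torch.randn(16, 3) * 4 + 10    # synthetic batch: mean ~10, std ~4

bn.train()
out_train = bn(x)                  # normalized with this batch's own statistics
bn.eval()
out_eval = bn(x)                   # normalized with running stats after one update

print(out_train.mean(dim=0))       # close to 0
print(out_eval.mean(dim=0))        # far from 0: running_mean has moved only 10% of the way

# Batch size 1 in training mode: PyTorch refuses to compute batch statistics
bn.train()
raised = False
try:
    bn(torch.randn(1, 3))
except ValueError as err:
    raised = True
    print("train mode, B=1:", err)

bn.eval()
print(bn(torch.randn(1, 3)).shape)  # eval mode, B=1 works
```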
PyTorch's BatchNorm keeps a counter, num_batches_tracked, that increments on every training-mode forward pass. It is stored in the state_dict alongside the running statistics. If you construct the layer with momentum=None, BN uses 1/num_batches_tracked as the momentum instead of a fixed value. This gives equal weight to all batches seen so far (a cumulative simple average instead of an exponential one). It is rarely used in practice.
Here's a detail that trips up anyone implementing BN from scratch. During the forward pass, BN computes the biased variance (divides by B). But when updating the running variance, PyTorch applies Bessel's correction:
$$\text{running\_var} \leftarrow (1 - m)\cdot\text{running\_var} + m\cdot\frac{B}{B-1}\,\sigma^2_{\text{batch}}$$
The factor B/(B-1) converts the biased variance to an unbiased estimate. Why? The running variance is an estimate of the population variance, and the unbiased estimator is more accurate for that purpose. The forward pass uses biased variance because we're normalizing the actual data we have, not estimating a population parameter.
For a batch of size 32, the correction factor is 32/31 ≈ 1.032, a 3.2% difference. Small, but systematic: skip the correction and your running_var will sit about 3% below PyTorch's. That is a classic source of numerical mismatches between a from-scratch implementation and the built-in layer.
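You can confirm the correction against PyTorch itself. This sketch runs one training-mode forward pass on a random batch of 32 (synthetic data), starting from the default running_var of 1. Afterward, the stored value should equal 0.9 · 1 + 0.1 · (unbiased batch variance):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
B = 32
x = torch.randn(B, 3) * 2.0

bn = nn.BatchNorm1d(3, momentum=0.1)   # running_var starts at 1
bn.train()
bn(x)                                   # one training step updates the running stats

biased = x.var(dim=0, unbiased=False)   # what the forward pass normalizes with
unbiased = x.var(dim=0, unbiased=True)  # = biased * B / (B - 1)

expected = 0.9 * torch.ones(3) + 0.1 * unbiased
print(bn.running_var)
print(expected)            # matches: the running update uses the corrected variance
print(unbiased / biased)   # B / (B - 1) = 32/31, about 1.032
```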
Running statistics are buffers: they live in the state_dict, but they don't receive gradients. When you call model.load_state_dict(), the running_mean and running_var are restored. Suppose you forget to load them, or load only the weight/bias parameters. Your model then falls back to the default running stats (mean 0, variance 1), which almost certainly don't match your data, and it produces garbage at inference time. Always save and load the complete state_dict.
The simulation below shows how running statistics converge during training and then freeze during inference. On the left, watch the batch statistics jump around (blue dots) while the running average (orange line) smoothly converges. On the right, see how eval mode uses the frozen running statistics.
[Interactive chart: per-batch means (dots) and the running mean (smooth EMA line), with a train/eval toggle. In eval mode, the running statistics freeze and new batches don't update them.]
```python
import torch
import torch.nn as nn

class MyBatchNorm1d(nn.Module):
    def __init__(self, D, eps=1e-5, momentum=0.1):
        super().__init__()
        self.eps = eps
        self.momentum = momentum
        # Learnable parameters
        self.gamma = nn.Parameter(torch.ones(D))
        self.beta = nn.Parameter(torch.zeros(D))
        # Buffers (not learned, but saved in state_dict)
        self.register_buffer('running_mean', torch.zeros(D))
        self.register_buffer('running_var', torch.ones(D))
        self.register_buffer('num_batches_tracked', torch.tensor(0))

    def forward(self, x):
        if self.training:
            # Use batch statistics
            mu = x.mean(dim=0)
            var = x.var(dim=0, unbiased=False)
            # Update running stats (with Bessel's correction)
            B = x.size(0)
            with torch.no_grad():
                self.running_mean = (1 - self.momentum) * self.running_mean \
                    + self.momentum * mu
                self.running_var = (1 - self.momentum) * self.running_var \
                    + self.momentum * var * B / (B - 1)
                self.num_batches_tracked += 1
        else:
            # Use running statistics (frozen)
            mu = self.running_mean
            var = self.running_var
        x_norm = (x - mu) / torch.sqrt(var + self.eps)
        return self.gamma * x_norm + self.beta
```
Test it:
```python
bn = MyBatchNorm1d(3)
x = torch.tensor([[1., 5., 3.],
                  [3., 3., 7.],
                  [5., 7., 1.],
                  [3., 5., 5.]])

# Training mode: uses batch stats, updates running stats
bn.train()
out_train = bn(x)
print(f"running_mean: {bn.running_mean}")
# tensor([0.3000, 0.5000, 0.4000]): 10% of the batch mean

# Eval mode: uses running stats (frozen)
bn.eval()
single = torch.tensor([[2., 4., 3.]])  # batch size 1!
out_eval = bn(single)                   # Works! Uses running_mean/running_var
print(f"Eval output: {out_eval}")
```
Inspect the state_dict() of PyTorch's nn.BatchNorm1d and you'll see five entries: weight (γ), bias (β), running_mean, running_var, and num_batches_tracked. (Our MyBatchNorm1d stores the same five, with γ and β under the names gamma and beta.) All five are saved and loaded together. The running statistics and the counter are buffers, not parameters: they don't appear in bn.parameters(), and the optimizer never touches them. But they're critical for inference.
What if we don't want our normalization to depend on other samples at all? What if each token, each sample, each data point normalizes itself?
That's the idea behind Layer Normalization (LN), proposed by Ba, Kiros, and Hinton in 2016. Instead of computing statistics across the batch (down the columns of our B×D matrix), LN computes statistics across the features (along the rows). Each sample is normalized independently using its own mean and variance.
This one change (batch axis → feature axis) eliminates every problem we discussed in Chapter 3. No running statistics. No train/eval split. No batch size dependency. And it's why transformers normalize along the feature axis, with LayerNorm or its simplified cousin RMSNorm.
Let's use the same 4×3 matrix from Chapter 2. But now, instead of looking down each column (BN), we look across each row (LN):
| | Feature 0 | Feature 1 | Feature 2 | Row statistics (LN direction) |
|---|---|---|---|---|
| Sample 0 | 1 | 5 | 3 | μ=3.0, σ²=2.667 |
| Sample 1 | 3 | 3 | 7 | μ=4.333, σ²=3.556 |
| Sample 2 | 5 | 7 | 1 | μ=4.333, σ²=6.222 |
| Sample 3 | 3 | 5 | 5 | μ=4.333, σ²=0.889 |
Each row is a separate world. Sample 0's normalization uses only the values [1, 5, 3] — it has no knowledge of what samples 1, 2, or 3 look like. This independence is the key property of Layer Normalization.
Same batch, new axis. Let's compute LN for each sample. ε = 0 for clarity.
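Sample 0: values across features = [1, 5, 3]. μ = 3, σ² = 2.667, σ ≈ 1.633. Normalized: [−1.225, 1.225, 0].
Sample 1: values across features = [3, 3, 7]. μ = 4.333, σ² = 3.556, σ ≈ 1.886. Normalized: [−0.707, −0.707, 1.414].
Sample 2: values across features = [5, 7, 1]. μ = 4.333, σ² = 6.222, σ ≈ 2.494. Normalized: [0.267, 1.069, −1.336].
Sample 3: values across features = [3, 5, 5]. μ = 4.333, σ² = 0.889, σ ≈ 0.943. Normalized: [−1.414, 0.707, 0.707].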
Final LN-normalized matrix:
| | Feature 0 | Feature 1 | Feature 2 |
|---|---|---|---|
| Sample 0 | -1.225 | 1.225 | 0 |
| Sample 1 | -0.707 | -0.707 | 1.414 |
| Sample 2 | 0.267 | 1.069 | -1.336 |
| Sample 3 | -1.414 | 0.707 | 0.707 |
Check any row: mean ≈ 0, variance ≈ 1. Compare this to the BN result from Chapter 2 — the numbers are completely different! BN and LN produce different outputs because they normalize along different axes.
Let's make the axis difference crystal clear:
| | Batch Normalization | Layer Normalization |
|---|---|---|
| Normalizes across | Batch dimension (columns) | Feature dimension (rows) |
| Stats computed from | All B samples for each feature | All D features for each sample |
| Number of means | D (one per feature) | B (one per sample) |
| Sample independence | No — samples affect each other | Yes — each sample normalized alone |
| Running statistics | Yes (EMA during training) | No |
| Train/eval difference | Yes — must call model.eval() | No — identical behavior |
| Batch size requirement | B ≥ 2 during training | Any batch size, always |
| γ, β shape | [D] | [D] (same!) |
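The sample-independence row of the table is easy to check. This sketch changes only sample 3 and checks whether sample 0's output under nn.LayerNorm moves (γ = 1 and β = 0 at initialization):

```python
import torch
import torch.nn as nn

ln = nn.LayerNorm(3)  # gamma = 1, beta = 0 at initialization

x = torch.tensor([[1., 5., 3.],
                  [3., 3., 7.],
                  [5., 7., 1.],
                  [3., 5., 5.]])
x_changed = x.clone()
x_changed[3] = torch.tensor([9., 9., 9.])  # change only sample 3

with torch.no_grad():
    before = ln(x)[0]
    after = ln(x_changed)[0]
    constant_row = ln(x_changed)[3]

print(before)        # approximately [-1.2247, 1.2247, 0.0000]
print(after)         # unchanged: LN never looks at other samples
print(constant_row)  # a constant row normalizes to zeros (epsilon keeps it finite)
```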
In a transformer, the hidden representation has shape [batch, seq_len, d_model]. A single "sample" for LayerNorm is a single token's representation: a vector of d_model numbers (e.g., 768 in BERT-base, 12,288 in the 175B-parameter GPT-3).
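LayerNorm normalizes over the last dimension (d_model). This means every token is normalized independently, using only its own d_model values; the batch and sequence dimensions never mix.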
For a batch of 8 sequences, each 512 tokens long, with d_model = 768:
| What | Value |
|---|---|
| Input shape | [8, 512, 768] |
| Number of separate means computed | 8 × 512 = 4,096 |
| Each mean computed from | 768 values (one token's features) |
| γ shape | [768] |
| β shape | [768] |
| Learnable parameters | 768 + 768 = 1,536 |
The γ and β are shared across all tokens and all samples. There's one scale and one shift per feature dimension, applied identically everywhere. This is why the parameter count is so small: just 2 × d_model per LN layer.
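This sketch checks the table's numbers on a synthetic [8, 512, 768] tensor, with random data standing in for real hidden states:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
h = torch.randn(8, 512, 768) * 3 + 1  # [batch, seq_len, d_model], synthetic
ln = nn.LayerNorm(768)                # normalizes over the last dim only

with torch.no_grad():
    out = ln(h)

per_token_mean = out.mean(dim=-1)               # shape [8, 512]: one value per token
per_token_std = out.std(dim=-1, unbiased=False)
n_params = sum(p.numel() for p in ln.parameters())

print(per_token_mean.shape)                     # torch.Size([8, 512])
print(per_token_mean.abs().max().item(), per_token_std.mean().item())
print(n_params)                                 # 2 * 768 = 1536
```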
Because LayerNorm keeps no running statistics, model.train() and model.eval() change nothing about how LayerNorm layers compute their output. This makes LN dramatically simpler to deploy, debug, and reason about. It's also one reason LN-based transformers are easier to work with than older architectures that relied on BatchNorm.
For convolutional neural networks processing images, BatchNorm makes intuitive sense. In a CNN, each channel (feature map) has a consistent meaning: channel 37 might detect vertical edges. It makes sense to normalize channel 37 to have consistent statistics across all images — the concept of "how much vertical edge is present" should have a stable distribution.
For transformers processing sequences, the situation is reversed. The "features" are the hidden dimensions of each token's representation. Unlike CNN channels, these don't have stable per-dimension semantics. What matters is the relative pattern across dimensions for each token. And crucially, sequences have variable length — a batch might contain a 10-token sentence and a 500-token document. Computing batch statistics across these wildly different sequences would be meaningless.
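There are also practical advantages, all visible in the comparison table above: no running statistics to track, no train/eval switch, and no minimum batch size.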
The simulation below shows the same batch×features grid, but now you can toggle between BN mode (column highlighting) and LN mode (row highlighting). Watch which cells contribute to each normalization operation.
[Interactive grid: cells that contribute to each normalization group under Batch Norm (columns) vs. Layer Norm (rows).]
```python
import torch

def layernorm_forward(x, gamma, beta, eps=1e-5):
    """LayerNorm forward pass.

    x:     [B, D]  batch of B samples, D features
    gamma: [D]     learnable scale
    beta:  [D]     learnable shift
    """
    # Statistics across feature dimension (axis=-1)
    mu = x.mean(dim=-1, keepdim=True)                  # [B, 1]
    var = x.var(dim=-1, keepdim=True, unbiased=False)  # [B, 1]
    # Normalize
    x_norm = (x - mu) / torch.sqrt(var + eps)          # [B, D]
    # Scale and shift
    out = gamma * x_norm + beta                        # [B, D]
    return out

# Test with the same 4×3 matrix
x = torch.tensor([[1., 5., 3.],
                  [3., 3., 7.],
                  [5., 7., 1.],
                  [3., 5., 5.]])
gamma = torch.ones(3)
beta = torch.zeros(3)
out = layernorm_forward(x, gamma, beta)
print(out)  # Row 0: [-1.2247, 1.2247, 0.0000], matches hand calc!
```
Compare to PyTorch:
```python
import torch.nn as nn

ln = nn.LayerNorm(3)  # normalized_shape = [3] (the feature dim)
out_pt = ln(x)
print(out_pt)         # Same result!

# Key difference from BN:
print(list(ln.state_dict().keys()))
# ['weight', 'bias']: that's it! No running_mean, no running_var

# Works with batch size 1:
single = torch.tensor([[2., 4., 3.]])
print(ln(single))     # No error, no special mode needed
```
Notice the nn.LayerNorm(3) constructor takes the
normalized shape — the shape of the dimensions to
normalize over. For a transformer with d_model=768, you'd write
nn.LayerNorm(768). For a tensor with shape [B, T, D], this
normalizes over the last dimension D. You can also pass a tuple like
(T, D) to normalize over multiple trailing dimensions.
There are two placement patterns for LayerNorm in transformers, and the distinction matters:
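Post-LN (original Transformer, 2017): normalization comes after the residual connection. With F the sublayer (attention or feed-forward):
$$x_{l+1} = \mathrm{LN}\big(x_l + F(x_l)\big)$$
Pre-LN (GPT-2 and most modern models): normalization comes before the sublayer, inside the residual branch.
$$x_{l+1} = x_l + F\big(\mathrm{LN}(x_l)\big)$$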
Pre-LN is more stable for training deep models (100+ layers) because the residual path carries un-normalized signals directly, preventing gradient vanishing through the normalization layers. Post-LN can achieve slightly better final performance but is harder to train and often requires learning rate warmup.
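The two placements differ only in where the LayerNorm call sits relative to the residual addition. The minimal sketch below uses a small feed-forward network as a stand-in for the attention or MLP sublayer. The class names are hypothetical and don't come from any library:

```python
import torch
import torch.nn as nn

class PostLNBlock(nn.Module):
    """Post-LN: x_next = LN(x + F(x)). The residual stream itself gets normalized."""
    def __init__(self, d_model: int, sublayer: nn.Module):
        super().__init__()
        self.sublayer = sublayer
        self.norm = nn.LayerNorm(d_model)

    def forward(self, x):
        return self.norm(x + self.sublayer(x))

class PreLNBlock(nn.Module):
    """Pre-LN: x_next = x + F(LN(x)). The residual stream passes through untouched."""
    def __init__(self, d_model: int, sublayer: nn.Module):
        super().__init__()
        self.sublayer = sublayer
        self.norm = nn.LayerNorm(d_model)

    def forward(self, x):
        return x + self.sublayer(self.norm(x))

# Hypothetical stand-in for an attention or MLP sublayer
d_model = 16
ffn = nn.Sequential(nn.Linear(d_model, 4 * d_model), nn.GELU(),
                    nn.Linear(4 * d_model, d_model))
x = torch.randn(2, 5, d_model)  # [batch, seq_len, d_model]

post = PostLNBlock(d_model, ffn)
pre = PreLNBlock(d_model, ffn)
print(post(x).shape, pre(x).shape)
```

Only the Post-LN output is re-normalized per token. The Pre-LN output keeps whatever scale the residual stream has accumulated.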
LayerNorm does two things: center (subtract the mean) then scale (divide by the standard deviation). In the last chapter we saw how this makes each sample self-contained — no batch dependency, perfect for transformers. But what if one of those two steps doesn't actually matter?
That's the question Biao Zhang and Rico Sennrich asked in 2019. Their answer: skip the centering step entirely. Just divide by the root-mean-square. The result is RMSNorm — the normalization behind virtually every modern large language model.
The intuition is surprisingly simple. In a deep network with residual connections and careful initialization, activations tend to stay roughly zero-centered already. When the mean is already close to zero, LayerNorm's mean subtraction is subtracting a number very close to zero, so it does almost nothing. Remove it and save the compute.
Root Mean Square (RMS) is exactly what it sounds like: take each value, square it, average those squares, then take the square root. For a vector x with D elements:

$$\text{RMS}(x) = \sqrt{\frac{1}{D}\sum_{i=1}^{D} x_i^2}$$
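Then RMSNorm normalizes and scales:

$$\bar{x}_i = \frac{x_i}{\text{RMS}(x)}\,\gamma_i$$

In code, a small ε is added inside the square root to avoid dividing by zero.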
That's it. No mean subtraction. No bias term β. Just divide by the RMS and multiply by a learnable scale γ. That skips two of LayerNorm's steps (computing the mean and subtracting it) and drops one parameter vector (β is gone).
Let's run both methods on the same input and compare. Take x = [2, −1, 3, −2, 1].
LayerNorm (from Chapter 4):
Step 1 — Mean: μ = (2 + (−1) + 3 + (−2) + 1) / 5 = 3/5 = 0.6
Step 2 — Center: x − μ = [1.4, −1.6, 2.4, −2.6, 0.4]
Step 3 — Variance: σ² = (1.96 + 2.56 + 5.76 + 6.76 + 0.16) / 5 = 17.2 / 5 = 3.44
Step 4 — Standard deviation: σ = √3.44 = 1.855
Step 5 — Normalize: x_LN = [1.4/1.855, −1.6/1.855, 2.4/1.855, −2.6/1.855, 0.4/1.855] ≈ [0.755, −0.863, 1.294, −1.402, 0.216]
RMSNorm:
Step 1 — Square each element: x² = [4, 1, 9, 4, 1]
Step 2 — Mean of squares: (4 + 1 + 9 + 4 + 1) / 5 = 19/5 = 3.8
Step 3 — RMS: √3.8 = 1.949
Step 4 — Normalize: x_RMS = [2/1.949, −1/1.949, 3/1.949, −2/1.949, 1/1.949] ≈ [1.026, −0.513, 1.539, −1.026, 0.513]
The outputs are different. LayerNorm's output sums to ~0 (it's centered). RMSNorm's output sums to ~1.54 (not centered). But both produce well-scaled values: nothing is exploding or vanishing. In practice, the rest of the network learns to work with either output, and Zhang & Sennrich report model quality comparable to LayerNorm across the tasks they tested.
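The hand calculation above can be checked in a few lines. This sketch (assuming PyTorch) sets γ = 1, β = 0 and ε = 0 so the numbers line up with the steps above; real layers add ε inside the square root.

```python
import torch

x = torch.tensor([2., -1., 3., -2., 1.])

# LayerNorm with gamma=1, beta=0, eps=0 to mirror the hand calculation
mu = x.mean()
sigma = ((x - mu) ** 2).mean().sqrt()     # population std: sqrt(3.44)
x_ln = (x - mu) / sigma

# RMSNorm with gamma=1, eps=0
rms = (x ** 2).mean().sqrt()              # sqrt(3.8)
x_rms = x / rms

print(x_ln)    # ≈ [0.755, -0.863, 1.294, -1.402, 0.216]
print(x_rms)   # ≈ [1.026, -0.513, 1.539, -1.026, 0.513]
print(x_ln.sum().item(), x_rms.sum().item())   # ≈ 0 and ≈ 1.539
```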
There's a clean algebraic connection between RMS and standard deviation. Recall:

$$\sigma^2 = \operatorname{Var}(x) = \mathbb{E}[x^2] - (\mathbb{E}[x])^2$$
Rearranging: E[x²] = Var(x) + (E[x])². But E[x²] is exactly RMS². So, writing μ = E[x]:

$$\text{RMS}(x)^2 = \sigma^2 + \mu^2$$
When the mean μ ≈ 0 (which it usually is in deep transformers, thanks to residual connections and symmetric initialization), then RMS² ≈ σ², which means RMS ≈ σ. In that regime, RMSNorm and LayerNorm produce nearly identical outputs. The centering step in LayerNorm was subtracting ~0 and the denominator was effectively the same. RMSNorm just makes this official.
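A quick numerical check of the identity RMS² = σ² + μ², and of the claim that the two normalizations coincide once the mean is zero. The helpers `layernorm` and `rmsnorm` are minimal illustrative versions without γ or β, and the 0.3 offset is an arbitrary choice to make the mean clearly non-zero.

```python
import torch

torch.manual_seed(0)
x = torch.randn(1000) + 0.3             # arbitrary data with a clearly non-zero mean

mu = x.mean()
sigma_sq = ((x - mu) ** 2).mean()        # population variance, sigma^2
rms_sq = (x ** 2).mean()                 # RMS^2 = mean of squares
print(torch.allclose(rms_sq, sigma_sq + mu ** 2, atol=1e-5))   # RMS^2 = sigma^2 + mu^2

def layernorm(v, eps=1e-6):              # minimal LayerNorm, no gamma/beta
    c = v - v.mean()
    return c / torch.sqrt((c ** 2).mean() + eps)

def rmsnorm(v, eps=1e-6):                # minimal RMSNorm, no gamma
    return v / torch.sqrt((v ** 2).mean() + eps)

x0 = x - x.mean()                        # zero-mean version (up to float rounding)
print(torch.allclose(layernorm(x0), rmsnorm(x0), atol=1e-5))   # mu = 0: the two coincide
print((layernorm(x) - rmsnorm(x)).abs().max().item())          # mu ≈ 0.3: they differ
```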
A straightforward LayerNorm implementation makes two passes over the data: one to compute the mean (summing D values), and a second to compute the variance (summing D squared deviations). That's 2D additions, D subtractions, D squarings, plus the final division.
RMSNorm needs one pass: square each element and sum. That's D multiplications and D additions. No mean pass. No subtraction pass. For d_model = 8192 (LLaMA 70B), this saves roughly 15% of normalization compute per layer. Across 80 layers, it adds up.
The savings are even more significant on GPU hardware. Two-pass algorithms require either reading the tensor from memory twice (doubling memory bandwidth) or storing intermediate results. RMSNorm's single pass is more cache-friendly and parallelizable.
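To make the pass structure concrete, here is a plain-Python sketch with no framework. It follows the textbook picture described above; production kernels can also compute LayerNorm's mean and variance in a single sweep (for example with Welford's online algorithm), so treat this as an illustration of the data dependencies rather than a benchmark.

```python
import math

def layernorm_two_pass(x, eps=1e-5):
    """Textbook LayerNorm on a list of floats (no gamma/beta)."""
    d = len(x)
    mean = sum(x) / d                                  # reduction pass 1: mean
    var = sum((v - mean) ** 2 for v in x) / d          # reduction pass 2: needs the mean first
    inv_std = 1.0 / math.sqrt(var + eps)
    return [(v - mean) * inv_std for v in x]           # elementwise: subtract, then scale

def rmsnorm_one_pass(x, eps=1e-5):
    """RMSNorm on a list of floats (no gamma): one reduction, no subtraction."""
    d = len(x)
    mean_sq = sum(v * v for v in x) / d                # the only reduction pass
    inv_rms = 1.0 / math.sqrt(mean_sq + eps)
    return [v * inv_rms for v in x]                    # elementwise: scale only

row = [1.0, 5.0, 3.0]   # row 0 of the 4×3 matrix used earlier
print(layernorm_two_pass(row, eps=0.0))
print(rmsnorm_one_pass(row, eps=0.0))
```

The variance pass cannot start until the mean is known; the RMSNorm reduction has no such dependency.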
| Model | Organization | Year | Normalization |
|---|---|---|---|
| T5 | Google | 2019 | RMSNorm (they called it "simplified LN") |
| PaLM | Google | 2022 | RMSNorm |
| LLaMA / LLaMA 2 / LLaMA 3 | Meta | 2023–24 | RMSNorm |
| Mistral / Mixtral | Mistral AI | 2023–24 | RMSNorm |
| Gemma / Gemma 2 | Google | 2024 | RMSNorm |
| Qwen / Qwen 2 | Alibaba | 2024 | RMSNorm |
| GPT-2, BERT, GPT-3 | OpenAI / Google | 2018–20 | LayerNorm (older era) |
The pattern is clear: every major LLM family in this table released since 2022 uses RMSNorm, while the LayerNorm models date from 2018–2020. RMSNorm has become the default for new language models, though not every recent model uses it.
The beauty of RMSNorm is how short it is. The entire module is a few lines:
```python
import torch
import torch.nn as nn

class RMSNorm(nn.Module):
    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.eps = eps
        self.gamma = nn.Parameter(torch.ones(dim))  # learnable scale

    def forward(self, x):
        # x shape: (batch, seq_len, dim)
        rms = torch.sqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        return self.gamma * (x / rms)
```
Compare to the LLaMA implementation (simplified from the Meta codebase):
```python
import torch
import torch.nn as nn

# From Meta's LLaMA — same logic, production-optimized
class LlamaRMSNorm(nn.Module):
    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        # rsqrt(v) = 1 / sqrt(v)
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x):
        output = self._norm(x.float()).type_as(x)  # float32 for stability
        return output * self.weight
```
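Notice the rsqrt call: it computes 1/√x in one op, and GPUs have fast hardware support for
reciprocal square root, so it replaces a separate sqrt and division.
And the .float() cast: LLaMA runs in half precision (bfloat16 or float16), but normalization is done
in float32 to avoid precision issues. This is a common production pattern:
mixed-precision normalization.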
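The mixed-precision pattern can be sketched as a standalone function. The name `rmsnorm_mixed_precision`, the hidden size 4096, and the ×50 activation scale are illustrative assumptions; the point is that the reduction runs in float32 while the input and output stay in bfloat16.

```python
import torch

def rmsnorm_mixed_precision(x, weight, eps=1e-6):
    """Hypothetical helper: normalize in float32, cast back to x's dtype, then scale."""
    x32 = x.float()                                            # upcast for the reduction
    normed = x32 * torch.rsqrt(x32.pow(2).mean(-1, keepdim=True) + eps)
    return normed.type_as(x) * weight                          # back to bf16 before scaling

torch.manual_seed(0)
x = (torch.randn(2, 4096) * 50).to(torch.bfloat16)   # large-ish bf16 activations
w = torch.ones(4096, dtype=torch.bfloat16)

out = rmsnorm_mixed_precision(x, w)
print(out.dtype)                                       # torch.bfloat16
print(out.float().pow(2).mean(-1).sqrt())              # per-row RMS, close to 1.0
```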
So far we've seen two extremes. BatchNorm normalizes across the batch — one set of statistics per feature, shared across all samples. LayerNorm normalizes across all features — one set of statistics per sample, independent of the batch. What about something in between?
This question matters most for vision models. Images have spatial structure. A convolutional feature map has shape [B, C, H, W] — batch, channels, height, width. Each "channel" detects a different pattern (edges, textures, shapes). The question is: which dimensions do you normalize over?
BatchNorm in a CNN computes statistics over [B, H, W] for each channel independently. This works brilliantly when B is large (32, 64, 128) because you get a stable estimate of each channel's mean and variance. But in many vision tasks — object detection, segmentation, 3D reconstruction, medical imaging — you often can't fit more than 1-4 images per GPU. With B=2, your "statistics" are computed from just 2 samples. That's noise, not signal.
LayerNorm computes over [C, H, W] for each sample. No batch dependency, great. But it lumps all channels together. A channel detecting horizontal edges and a channel detecting blue pixels have very different activation distributions. Normalizing them jointly can wash out meaningful differences.
We need a middle ground.
GroupNorm (Wu & He, 2018) splits the C channels into G groups, then normalizes within each group. If you have 256 channels and G=32 groups, each group contains 256/32 = 8 channels. Within each group, you compute the mean and variance across those 8 channels and all spatial positions (H × W). Each sample is processed independently — no batch dependency.
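The computation for one sample, one group, where S_g is the set of (C/G)·H·W values in group g:

$$\mu_g = \frac{1}{|S_g|}\sum_{i \in S_g} x_i, \qquad \sigma_g^2 = \frac{1}{|S_g|}\sum_{i \in S_g} (x_i - \mu_g)^2, \qquad \hat{x}_i = \frac{x_i - \mu_g}{\sqrt{\sigma_g^2 + \epsilon}}$$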
The learnable γ and β are still per-channel (not per-group), so the model can rescale each channel independently after normalization.
What happens when every channel is its own group? That's G = C. Each channel is normalized independently over just its spatial dimensions (H × W). This is called Instance Normalization (Ulyanov et al., 2016).
InstanceNorm was invented for style transfer. The core insight: visual "style" (texture, color palette, brushstroke patterns) is encoded in per-channel statistics — the mean and variance of each feature map. If you normalize away each channel's own statistics, you effectively strip the style from the image. The network can then re-inject a new style through the learned γ and β.
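Both endpoints of the GroupNorm spectrum can be checked against PyTorch's built-ins. In this sketch (arbitrary shapes), G = C matches `nn.InstanceNorm2d` and G = 1 matches `nn.LayerNorm` over (C, H, W). The match holds at initialization, where every γ is 1 and every β is 0; after training the layers are parameterized differently, because GroupNorm's γ is per channel while this LayerNorm's γ is per element. Note that `nn.InstanceNorm2d` needs `affine=True` to have γ and β at all.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
B, C, H, W = 2, 6, 4, 4                    # arbitrary sizes
x = torch.randn(B, C, H, W)

# G = C: each channel is its own group, i.e. InstanceNorm
gn_c = nn.GroupNorm(num_groups=C, num_channels=C)
inorm = nn.InstanceNorm2d(C, affine=True)  # affine=True gives it gamma/beta like GroupNorm
print(torch.allclose(gn_c(x), inorm(x), atol=1e-5))

# G = 1: one group holding every channel, i.e. LayerNorm over (C, H, W)
gn_1 = nn.GroupNorm(num_groups=1, num_channels=C)
ln = nn.LayerNorm((C, H, W))
print(torch.allclose(gn_1(x), ln(x), atol=1e-5))
```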
Let's work a concrete example. One image, 4 channels, 2×2 spatial:
```text
Channel 0: [[1, 2],      Channel 1: [[5, 6],
            [3, 4]]                  [7, 8]]
Channel 2: [[9, 10],     Channel 3: [[13, 14],
            [11, 12]]                [15, 16]]
```
GroupNorm with G=2 (2 groups, 2 channels each):
Group 0 (channels 0 and 1) — all values: [1, 2, 3, 4, 5, 6, 7, 8]
μ₀ = (1+2+3+4+5+6+7+8) / 8 = 36/8 = 4.5
σ₀² = ((1−4.5)² + (2−4.5)² + ... + (8−4.5)²) / 8 = (12.25 + 6.25 + 2.25 + 0.25 + 0.25 + 2.25 + 6.25 + 12.25) / 8 = 42 / 8 = 5.25
σ₀ = √5.25 = 2.291
Channel 0 normalized: [(1−4.5)/2.291, (2−4.5)/2.291, (3−4.5)/2.291, (4−4.5)/2.291] = [−1.528, −1.091, −0.655, −0.218]
Channel 1 normalized: [(5−4.5)/2.291, (6−4.5)/2.291, (7−4.5)/2.291, (8−4.5)/2.291] = [0.218, 0.655, 1.091, 1.528]
Group 1 (channels 2 and 3) — all values: [9, 10, 11, 12, 13, 14, 15, 16]
μ₁ = 100/8 = 12.5, σ₁² = 5.25 (same spread!), σ₁ = 2.291
The normalized values for Group 1 have the exact same pattern as Group 0 — just shifted in the original space. After normalization, both groups span the same range [−1.528, +1.528]. That's normalization doing its job.
InstanceNorm (G=4) — each channel alone:
Channel 0: values [1, 2, 3, 4]. μ = 2.5, σ² = 1.25, σ = 1.118. Normalized: [−1.342, −0.447, 0.447, 1.342]
Channel 1: values [5, 6, 7, 8]. μ = 6.5, σ² = 1.25, σ = 1.118. Normalized: [−1.342, −0.447, 0.447, 1.342] — same pattern! (Same spread.)
Notice that InstanceNorm produces the same normalized pattern for every channel because they all have the same internal spread. GroupNorm, by pooling channels, preserves the cross-channel differences (channel 0 is low, channel 1 is high) while normalizing them together. Which is better depends on the task.
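The worked example can be reproduced with `nn.GroupNorm`. Setting `eps=0` is an assumption made only so the numbers match the hand calculation exactly; with G = 4 the same layer becomes InstanceNorm.

```python
import torch
import torch.nn as nn

# One image, 4 channels, 2x2 spatial: channel c holds the values 4c+1 ... 4c+4
x = torch.arange(1., 17.).reshape(1, 4, 2, 2)

gn = nn.GroupNorm(num_groups=2, num_channels=4, eps=0.0)     # 2 groups of 2 channels
inorm = nn.GroupNorm(num_groups=4, num_channels=4, eps=0.0)  # G = C, i.e. InstanceNorm

print(gn(x)[0, 0].flatten())     # ≈ [-1.528, -1.091, -0.655, -0.218]
print(gn(x)[0, 1].flatten())     # ≈ [ 0.218,  0.655,  1.091,  1.528]
print(gn(x)[0, 2].flatten())     # same as channel 0: identical spread within group 1
print(inorm(x)[0, 0].flatten())  # ≈ [-1.342, -0.447,  0.447,  1.342]
```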
All five methods we've seen form a spectrum based on which dimensions they share statistics over:
| Method | Stats computed over | Batch dependent? | Params | Best for |
|---|---|---|---|---|
| BatchNorm | B, H, W (per channel) | Yes | γ, β per channel | CNNs (large batch) |
| GroupNorm | C/G, H, W (per group, per sample) | No | γ, β per channel | CNNs (small batch) |
| InstanceNorm | H, W (per channel, per sample) | No | γ, β per channel | Style transfer |
| LayerNorm | C, H, W (per sample) | No | γ, β per feature | Transformers, RNNs |
| RMSNorm | D (per sample, no mean) | No | γ per feature | LLMs |
Ordered by how many values each statistic pools: BN shares the most (it pools across samples, so it needs the batch), LayerNorm pools every feature of one sample, GroupNorm pools a subset of channels, and InstanceNorm shares the least (each channel of each sample is independent). The trend in modern AI has been away from batch-dependent methods and toward per-sample methods.
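Each method differs only in which axes of the [B, C, H, W] tensor it averages over. This sketch (arbitrary sizes, G = 4) computes just the means and prints how many statistics each method produces, which makes the table's "Stats computed over" column concrete.

```python
import torch

torch.manual_seed(0)
B, C, H, W, G = 8, 16, 5, 5, 4
x = torch.randn(B, C, H, W)

# Axes averaged over to get one mean per statistic
bn_mu = x.mean(dim=(0, 2, 3), keepdim=True)                            # BatchNorm: B, H, W
ln_mu = x.mean(dim=(1, 2, 3), keepdim=True)                            # LayerNorm: C, H, W
in_mu = x.mean(dim=(2, 3), keepdim=True)                               # InstanceNorm: H, W
gn_mu = x.view(B, G, C // G, H, W).mean(dim=(2, 3, 4), keepdim=True)  # GroupNorm: C/G, H, W

for name, mu in [("BatchNorm", bn_mu), ("LayerNorm", ln_mu),
                 ("InstanceNorm", in_mu), ("GroupNorm", gn_mu)]:
    print(f"{name:12s} stats shape {tuple(mu.shape)}: {mu.numel()} means")
```

Only BatchNorm's statistics lack a batch axis, which is exactly why it is the one method that couples samples together.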
```python
import torch
import torch.nn as nn

class GroupNorm(nn.Module):
    def __init__(self, num_groups, num_channels, eps=1e-5):
        super().__init__()
        self.G = num_groups
        self.C = num_channels
        self.eps = eps
        self.gamma = nn.Parameter(torch.ones(num_channels))
        self.beta = nn.Parameter(torch.zeros(num_channels))

    def forward(self, x):
        B, C, H, W = x.shape
        # Reshape: group channels together
        x = x.view(B, self.G, C // self.G, H, W)
        # Stats over (C/G, H, W) — dims 2, 3, 4
        mu = x.mean(dim=[2, 3, 4], keepdim=True)
        var = x.var(dim=[2, 3, 4], keepdim=True, unbiased=False)
        x = (x - mu) / torch.sqrt(var + self.eps)
        # Reshape back, apply per-channel scale and shift
        x = x.view(B, C, H, W)
        return self.gamma.view(1, C, 1, 1) * x + self.beta.view(1, C, 1, 1)
```
The PyTorch built-in is identical in spirit:
```python
# Built-in — same thing, CUDA-optimized
norm = nn.GroupNorm(num_groups=32, num_channels=256)  # 8 channels per group
out = norm(x)  # x shape: (B, 256, H, W)
```
The common default is G=32. Why 32? The GroupNorm paper tested G={1, 2, 4, 8, 16, 32, 64} and found 32 performed best in its ResNet-50 ImageNet experiments. It's not a magic number: values between 16 and 64 all work well. The key constraint is that C must be divisible by G.
We know what to normalize. We know how (LayerNorm, RMSNorm, etc.). Now the architectural question that tripped up the entire field for two years: where in the network do you place the normalization layer?
This sounds like a trivial detail. It's not. Moving normalization by one line of code, from after the residual addition to before the sublayer, was one of the key changes (alongside an extra LayerNorm after the final block and scaled-down residual initialization) that let GPT-2 scale to 1.5 billion parameters without collapsing during training.
A standard transformer block has two sublayers: multi-head attention and a feedforward network (FFN). Each sublayer has a residual connection. The question is where LayerNorm (or RMSNorm) goes relative to these components.
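Post-LN (the original, "Attention Is All You Need" 2017):

$$x' = \text{LN}\big(x + \text{Attention}(x)\big), \qquad y = \text{LN}\big(x' + \text{FFN}(x')\big)$$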
The residual is added first, then normalized. The norm wraps the entire output.
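Pre-LN (GPT-2, most modern models):

$$x' = x + \text{Attention}\big(\text{LN}(x)\big), \qquad y = x' + \text{FFN}\big(\text{LN}(x')\big)$$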
Normalize first, then run through attention/FFN, then add the residual. The residual connection directly adds the raw input to the sublayer output.
Visually, the difference is tiny. Computationally, it changes everything about gradient flow.
The key is the gradient highway. In Pre-LN, the residual connection creates a direct path from the output of any layer all the way back to the input. When gradients flow backward during training, they can travel along this highway without passing through any normalization or nonlinearity.
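Let's trace the gradient path through L layers of Pre-LN. Each block computes x_{k+1} = x_k + F_k(LN(x_k)), where F_k is the attention or FFN sublayer, so unrolling from layer l to the output gives:

$$x_L = x_l + \sum_{k=l}^{L-1} F_k\big(\text{LN}(x_k)\big)$$

Taking the derivative of the final output with respect to an early layer's input:

$$\frac{\partial x_L}{\partial x_l} = I + \frac{\partial}{\partial x_l}\sum_{k=l}^{L-1} F_k\big(\text{LN}(x_k)\big)$$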
That I (identity) term is the gradient highway. No matter what happens in the sublayers, even if their gradients vanish completely, the identity term carries the upstream gradient from layer L back to layer l unchanged. Depth alone cannot make the gradient vanish.
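The identity term is easy to verify with autograd. In this sketch, each hypothetical `PreLNResidual` block computes x + F(LN(x)) with F's weights deliberately set to zero: the worst case, in which the sublayers pass back no gradient at all. After 48 such blocks, the gradient reaching the input is exactly the gradient that arrived from the loss. Sizes are arbitrary.

```python
import torch
import torch.nn as nn

class PreLNResidual(nn.Module):
    """Hypothetical Pre-LN block x + F(LN(x)), with F a linear layer zeroed on purpose."""
    def __init__(self, d):
        super().__init__()
        self.ln = nn.LayerNorm(d)
        self.f = nn.Linear(d, d)
        nn.init.zeros_(self.f.weight)   # worst case: the sublayer passes back no gradient
        nn.init.zeros_(self.f.bias)

    def forward(self, x):
        return x + self.f(self.ln(x))

torch.manual_seed(0)
D, depth = 16, 48
stack = nn.Sequential(*[PreLNResidual(D) for _ in range(depth)])

x = torch.randn(4, D, requires_grad=True)
upstream = torch.randn(4, D)              # stand-in for the gradient arriving from the loss
(stack(x) * upstream).sum().backward()    # d(loss)/d(output) = upstream

# The sublayer terms contribute nothing, so d x_L / d x_0 = I:
print(torch.allclose(x.grad, upstream))
```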
Now contrast with Post-LN, where each block computes x_{k+1} = LN(x_k + F_k(x_k)). Each layer's output passes through LayerNorm, which has its own Jacobian J_LN, so the gradient must flow through L consecutive LayerNorm Jacobians:

$$\frac{\partial x_L}{\partial x_l} = \prod_{k=l}^{L-1} J_{\text{LN}}^{(k)}\left(I + \frac{\partial F_k}{\partial x_k}\right)$$

Each J_LN can shrink or amplify the gradient depending on the activation scale at that layer. Over L=48 layers, these Jacobian products can compound into vanishingly small or explosively large values. That's why Post-LN needs careful learning rate warmup: you start with tiny steps to avoid the explosion, then slowly increase.
Imagine a simplified 4-layer network. The gradient signal starts at 1.0 at the output and flows backward.
Pre-LN: at each layer, the gradient splits. One copy travels along the residual highway (unchanged). The other goes through the sublayer. Even if each sublayer attenuates its gradient by 0.5×:
```text
Layer 4 (output): gradient = 1.0
Layer 3: highway = 1.0, sublayer path = 0.5, total ≥ 1.0
Layer 2: highway = 1.0, sublayer paths = 0.5 + 0.25, total ≥ 1.0
Layer 1: highway = 1.0, sublayer paths accumulate, total ≥ 1.0
Layer 0: highway = 1.0   ← always at least 1.0
```
Post-LN: no highway. Every path goes through LayerNorm Jacobians. If each LayerNorm attenuates the gradient by 0.9×:
```text
Layer 4 (output): gradient = 1.0
Layer 3: 1.0 × 0.9 = 0.90
Layer 2: 0.90 × 0.9 = 0.81
Layer 1: 0.81 × 0.9 = 0.73
Layer 0: 0.73 × 0.9 = 0.66   ← 34% signal lost in just 4 layers
```
At 48 layers: 0.9⁴⁸ ≈ 0.006. The gradient reaching layer 0 is 0.6% of its original value. That's vanishing gradients. With Pre-LN, the highway keeps at least 1.0 flowing back regardless of depth.
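The two toy traces can be reproduced in a few lines of plain Python. Both functions are illustrative models, not real networks: `post_ln_gradient` assumes every LayerNorm Jacobian scales the gradient by 0.9, and `pre_ln_gradient` models a scalar chain x_{k+1} = x_k + f(x_k) with the same sublayer slope at every layer.

```python
def post_ln_gradient(depth, ln_factor=0.9):
    """Toy Post-LN: each of `depth` LayerNorm Jacobians scales the gradient by ln_factor."""
    return ln_factor ** depth

def pre_ln_gradient(depth, sublayer_slope=0.0):
    """Toy Pre-LN on a scalar chain x_{k+1} = x_k + f(x_k), with f'(x) = sublayer_slope.
    Each layer contributes a factor (1 + slope); the 1 is the residual highway.
    slope = 0 models sublayers whose gradients vanish completely."""
    return (1.0 + sublayer_slope) ** depth

for depth in (4, 48):
    print(f"depth={depth:2d}  Post-LN: {post_ln_gradient(depth):.4f}  "
          f"Pre-LN (dead sublayers): {pre_ln_gradient(depth):.4f}")
```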
nn.TransformerEncoderLayer defaults to... Post-LN! (Its norm_first argument defaults to False.) Many
practitioners unknowingly train Post-LN networks and blame "transformers being
finicky" when Pre-LN would train stably out of the box. Always check which
variant you're using, and pass norm_first=True for Pre-LN.
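In code, the switch is a single constructor argument. This sketch assumes PyTorch 1.10 or newer, where `nn.TransformerEncoderLayer` accepts `norm_first`; `d_model=512` and `nhead=8` are arbitrary.

```python
import torch
import torch.nn as nn

default_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8)
print(default_layer.norm_first)   # False: Post-LN

pre_ln_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8, norm_first=True)
print(pre_ln_layer.norm_first)    # True: Pre-LN

x = torch.randn(10, 2, 512)       # (seq_len, batch, d_model), since batch_first=False by default
print(pre_ln_layer(x).shape)      # torch.Size([10, 2, 512])
```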
Post-LN's instability manifests at the very start of training. Before the model has learned anything, activations can be large and poorly scaled. The LayerNorm Jacobians are at their most unpredictable. A full-sized learning rate applied to these chaotic early gradients can cause the loss to explode immediately.
The fix: learning rate warmup. Start with a tiny learning rate (e.g., 1/1000th of the target) and linearly increase it over the first few thousand steps. By the time the learning rate reaches its full value, the network has settled enough that the LayerNorm Jacobians are well-behaved.
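A linear warmup like this can be sketched with PyTorch's `LambdaLR` scheduler. The target learning rate, the 4,000 warmup steps, the AdamW optimizer, and the stand-in `Linear` model are illustrative assumptions; the ramp starts at 1/1000 of the target, as described above. A real loop would compute a loss and call `backward()` before each `opt.step()`.

```python
import torch

model = torch.nn.Linear(8, 8)       # stand-in model
target_lr = 1e-3                    # illustrative values
warmup_steps = 4000
start_factor = 1e-3                 # start at 1/1000 of the target

opt = torch.optim.AdamW(model.parameters(), lr=target_lr)

def warmup_factor(step):
    # Linear ramp from start_factor to 1.0 over warmup_steps, then hold at 1.0
    if step >= warmup_steps:
        return 1.0
    return start_factor + (1.0 - start_factor) * step / warmup_steps

sched = torch.optim.lr_scheduler.LambdaLR(opt, lr_lambda=warmup_factor)

lrs = []
for step in range(warmup_steps + 2):
    lrs.append(opt.param_groups[0]["lr"])
    opt.step()                      # no gradients here, so parameters are untouched
    sched.step()
print(lrs[0], lrs[warmup_steps // 2], lrs[warmup_steps])
```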
Pre-LN is much less dependent on this. Because the gradient highway keeps gradients well-scaled from step 1, Xiong et al. (2020) showed that Pre-LN transformers can be trained without the warmup stage. This is one reason Pre-LN became the standard: warmup goes from a requirement to an optional safety margin, which many large-scale training recipes still keep.
```python
import torch.nn as nn

class PostLNBlock(nn.Module):
    """Original Transformer (2017) — norm AFTER residual add"""
    def __init__(self, d_model, n_heads, d_ff):
        super().__init__()
        self.attn = nn.MultiheadAttention(d_model, n_heads)
        self.ffn = nn.Sequential(
            nn.Linear(d_model, d_ff), nn.GELU(), nn.Linear(d_ff, d_model))
        self.ln1 = nn.LayerNorm(d_model)
        self.ln2 = nn.LayerNorm(d_model)

    def forward(self, x):  # x: (seq_len, batch, d_model), since batch_first=False
        x = self.ln1(x + self.attn(x, x, x)[0])  # add THEN norm
        x = self.ln2(x + self.ffn(x))            # add THEN norm
        return x

class PreLNBlock(nn.Module):
    """GPT-2 / Modern Transformer — norm BEFORE sublayer"""
    def __init__(self, d_model, n_heads, d_ff):
        super().__init__()
        self.attn = nn.MultiheadAttention(d_model, n_heads)
        self.ffn = nn.Sequential(
            nn.Linear(d_model, d_ff), nn.GELU(), nn.Linear(d_ff, d_model))
        self.ln1 = nn.LayerNorm(d_model)
        self.ln2 = nn.LayerNorm(d_model)

    def forward(self, x):  # x: (seq_len, batch, d_model), since batch_first=False
        h = self.ln1(x)                          # normalize once, reuse for Q, K, V
        x = x + self.attn(h, h, h)[0]            # norm THEN add
        x = x + self.ffn(self.ln2(x))            # norm THEN add
        return x
```
The only difference is the order of the norm and the residual add in forward(). Yet Pre-LN trains stably at 48+ layers with no warmup, while Post-LN can diverge at 12 layers without careful scheduling.
Pre-LN solved training stability but introduced a new problem. Each sublayer's output gets added directly to the residual stream without normalization, so the residual stream's magnitude keeps growing with depth. At 100+ layers, this matters.
DeepNorm (Microsoft, 2022) takes a different route. It keeps Post-LN but up-weights the residual before adding: output = LayerNorm(α · x + Sublayer(x)), with α > 1 (α = (2N)^(1/4) for an N-layer encoder). It also scales down the sublayer weights at initialization by a second constant β. With both changes, the DeepNet paper trained transformers up to 1,000 layers deep. GPT-4 is rumored to use a variant.
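A minimal sketch of a DeepNorm-style block, with only an FFN sublayer for brevity. The constants follow the encoder-only setting reported in the DeepNet paper: α = (2N)^(1/4) for the residual and a Xavier initialization gain β = (8N)^(-1/4) for the sublayer weights. The class name `DeepNormFFNBlock`, the layer sizes, and the zero bias initialization are assumptions for illustration. For N = 8 layers, α = 16^(1/4) = 2.

```python
import torch
import torch.nn as nn

class DeepNormFFNBlock(nn.Module):
    """Hypothetical Post-LN block with DeepNorm scaling: LN(alpha * x + F(x))."""
    def __init__(self, d_model, d_ff, num_layers):
        super().__init__()
        self.alpha = (2 * num_layers) ** 0.25          # residual up-weighting, > 1
        beta = (8 * num_layers) ** -0.25               # init gain for sublayer weights, < 1
        self.ffn = nn.Sequential(
            nn.Linear(d_model, d_ff), nn.GELU(), nn.Linear(d_ff, d_model))
        for m in self.ffn:
            if isinstance(m, nn.Linear):
                nn.init.xavier_normal_(m.weight, gain=beta)
                nn.init.zeros_(m.bias)                 # bias init is an assumption here
        self.ln = nn.LayerNorm(d_model)

    def forward(self, x):
        # Up-weight the residual, add the sublayer output, THEN normalize (Post-LN order)
        return self.ln(self.alpha * x + self.ffn(x))

N = 8
blocks = nn.Sequential(*[DeepNormFFNBlock(64, 256, num_layers=N) for _ in range(N)])
x = torch.randn(4, 10, 64)
y = blocks(x)
print(blocks[0].alpha, y.shape)
```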
Sandwich Norm applies normalization both before AND after each sublayer. Belt and suspenders. More expensive but extremely stable.
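Written out, Sandwich Norm is x + LN_out(Sub(LN_in(x))). A minimal sketch with an FFN standing in for the sublayer; `SandwichFFNBlock` and the sizes are hypothetical. Because of the second LayerNorm, the branch added to the residual stream has roughly unit variance per token (at initialization, where γ = 1 and β = 0), which limits how fast the residual can grow.

```python
import torch
import torch.nn as nn

class SandwichFFNBlock(nn.Module):
    """Hypothetical Sandwich-style block: a norm before AND after the sublayer."""
    def __init__(self, d_model, d_ff):
        super().__init__()
        self.ln_in = nn.LayerNorm(d_model)
        self.ffn = nn.Sequential(
            nn.Linear(d_model, d_ff), nn.GELU(), nn.Linear(d_ff, d_model))
        self.ln_out = nn.LayerNorm(d_model)

    def branch(self, x):
        # Everything added to the residual stream passes through ln_out
        return self.ln_out(self.ffn(self.ln_in(x)))

    def forward(self, x):
        return x + self.branch(x)

torch.manual_seed(0)
block = SandwichFFNBlock(64, 256)
x = torch.randn(4, 10, 64)
y = block(x)
branch_var = block.branch(x).var(dim=-1, unbiased=False)
print(y.shape, branch_var.mean().item())   # per-token branch variance is ~1 at init
```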
| Variant | Formula | Pros | Cons | Used by |
|---|---|---|---|---|
| Post-LN | LN(x + Sub(x)) | Better final quality (some studies) | Needs warmup, unstable deep | Original Transformer, BERT |
| Pre-LN | x + Sub(LN(x)) | Stable, no warmup needed | Residual growth at depth | GPT-2/3, LLaMA, Mistral |
| DeepNorm | LN(αx + Sub(x)) | Stable at 1000+ layers | α and init gain β must be set from depth | Very deep models |
| Sandwich | x + LN(Sub(LN(x))) | Very stable | 2× norm compute | CogView, some ViTs |
Let's race them all.
You've learned six normalization strategies: None, BatchNorm, LayerNorm, RMSNorm, GroupNorm, and Pre-LN placement. Each has strengths and weaknesses — but reading about failure modes is one thing. Watching them happen in real time is another.
This simulation trains six identical networks on the same task, differing only in normalization. Drag the sliders to create the conditions where each method fails. You'll discover that there is no universally best normalization — only the right one for your constraints.
Experiment 1: Batch size = 1. Drag the batch size slider all the way left. BatchNorm computes statistics from a single sample: the "mean" is just the value itself, and the "variance" is zero. BN's loss flatlines or oscillates wildly. Every other method is unaffected because none of them depends on batch statistics. This training-time batch dependence is a major reason BN is not used in LLMs, and why it struggles in vision tasks that fit only a few high-resolution images per GPU. (At inference, BN switches to its running averages, so a batch of one is a training-time problem.)
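The batch-of-one failure is not only a simulation artifact; PyTorch enforces it for 2-D inputs. A small sketch: `nn.BatchNorm1d` in training mode raises a `ValueError` on a single sample, `nn.LayerNorm` does not, and in eval mode BatchNorm falls back to its running statistics. The feature size 8 is arbitrary.

```python
import torch
import torch.nn as nn

bn = nn.BatchNorm1d(8)          # modules start in training mode
ln = nn.LayerNorm(8)
single = torch.randn(1, 8)      # a batch containing one sample

raised = False
try:
    bn(single)
except ValueError as err:       # "Expected more than 1 value per channel when training..."
    raised = True
    print("BatchNorm1d (train):", err)

print("LayerNorm:", ln(single).shape)

bn.eval()                       # inference: use running_mean / running_var instead
print("BatchNorm1d (eval):", bn(single).shape)
```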
Experiment 2: Depth = 32 layers. The "None" network's loss immediately explodes — activations compound through 32 layers with no normalization, causing overflow. BN and GroupNorm handle depth well when batch size is adequate. Pre-LN (with its gradient highway) is the most stable at extreme depth.
Experiment 3: Learning rate = 1.0. Drag the LR slider all the way right. With this aggressive learning rate, the unnormalized network diverges instantly. BN diverges next (less stable gradients). LayerNorm and GroupNorm struggle. Pre-LN survives the longest because its gradient highway prevents the amplification that causes divergence.
Experiment 4: Compare LN vs RMSNorm. Under all conditions, the teal (LN) and purple (RMSNorm) curves track each other almost exactly. That's the whole point — RMSNorm matches LayerNorm's quality while being ~15% faster. The speed difference isn't visible in the simulation, but the quality equivalence is.
| Method | Fails when... | Symptom | Fix |
|---|---|---|---|
| None | Depth > 8 or LR > 0.01 | Loss explodes (NaN) | Add any normalization |
| BatchNorm | Batch size < 4 | Noisy/collapsed statistics | Use GroupNorm or LayerNorm |
| LayerNorm | Extreme depth + Post-LN | Vanishing gradients | Switch to Pre-LN placement |
| RMSNorm | Same as LayerNorm | Same as LayerNorm | Same as LayerNorm |
| GroupNorm | G too large for few channels | Underfitting | Reduce G or use LayerNorm |
| Pre-LN | Very deep (>100 layers) | Residual magnitude growth | Add residual scaling (DeepNorm) |
You now understand the complete normalization toolkit — from BatchNorm to RMSNorm, from Post-LN to Pre-LN. This chapter is your practical reference. No new concepts. Just the formulas, the decision guide, and the connections to where you go next.
| Method | Formula | Stats axis | Parameters | Compute |
|---|---|---|---|---|
| BatchNorm | γ · (x − μ_B) / σ_B + β | Batch (B, H, W) | γ, β per feature | 2 passes + running stats |
| LayerNorm | γ · (x − μ_L) / σ_L + β | Features (D) or (C,H,W) | γ, β per feature | 2 passes |
| RMSNorm | γ · x / RMS(x) | Features (D) | γ per feature | 1 pass |
| GroupNorm | γ · (x − μ_G) / σ_G + β | Group (C/G, H, W) | γ, β per channel | 2 passes per group |
| InstanceNorm | γ · (x − μ_I) / σ_I + β | Spatial (H, W) | γ, β per channel | 2 passes per channel |
Follow the path that matches your situation:
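As a rough companion to that guide, the rules of thumb from this chapter's tables can be condensed into a small function. `suggest_norm` is a hypothetical helper, and the batch-size cutoff of 32 is an assumption drawn from the "large batch (32, 64, 128)" discussion, not a hard rule.

```python
def suggest_norm(model_type, batch_size=None, task=None):
    """Hypothetical rule-of-thumb helper based on this chapter's tables.

    model_type: "cnn" or "transformer"
    batch_size: per-GPU batch size (only used for CNNs)
    task: optional hint such as "llm" or "style_transfer"
    """
    if model_type == "transformer":
        norm = "RMSNorm" if task == "llm" else "LayerNorm"
        return f"{norm}, Pre-LN placement"
    if model_type == "cnn":
        if task == "style_transfer":
            return "InstanceNorm"
        if batch_size is not None and batch_size >= 32:   # assumed "large batch" cutoff
            return "BatchNorm"
        return "GroupNorm (G = 32; C must be divisible by G)"
    raise ValueError(f"unknown model_type: {model_type!r}")

print(suggest_norm("transformer", task="llm"))
print(suggest_norm("cnn", batch_size=2))
print(suggest_norm("cnn", batch_size=128))
```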
| Symbol | Meaning | Typical values |
|---|---|---|
| B | Batch size | 1–4096 |
| D | Feature dimension (d_model in transformers) | 256–8192 |
| C | Number of channels (vision) | 3–2048 |
| H, W | Spatial height and width | 7–224 |
| G | Number of groups (GroupNorm) | 16–64 (default: 32) |
| γ | Learnable scale parameter | Initialized to 1 |
| β | Learnable shift parameter | Initialized to 0 |
| ε | Small constant for numerical stability | 1e-5 or 1e-6 |
| μ | Mean of activations | Varies |
| σ | Standard deviation of activations | Varies |
```python
import torch.nn as nn

# BatchNorm — for CNNs with large batches
bn = nn.BatchNorm2d(num_features=256)

# LayerNorm — for transformers (encoder-style)
ln = nn.LayerNorm(normalized_shape=768)

# RMSNorm — for LLMs (PyTorch 2.4+)
rms = nn.RMSNorm(normalized_shape=4096)

# GroupNorm — for CNNs with small batches
gn = nn.GroupNorm(num_groups=32, num_channels=256)

# InstanceNorm — for style transfer
# (affine defaults to False; set it to True for learnable per-channel γ, β)
isn = nn.InstanceNorm2d(num_features=256, affine=True)
```
Normalization doesn't exist in isolation. Here's where to go next:
| Paper | Year | Contribution |
|---|---|---|
| Ioffe & Szegedy, "Batch Normalization" | 2015 | Introduced BN; matched an ImageNet model's accuracy with 14× fewer training steps |
| Ba, Kiros & Hinton, "Layer Normalization" | 2016 | Batch-independent normalization for RNNs/Transformers |
| Ulyanov et al., "Instance Normalization" | 2016 | Per-channel normalization for style transfer |
| Wu & He, "Group Normalization" | 2018 | Bridge between BN and LN for small-batch vision |
| Zhang & Sennrich, "Root Mean Square Layer Normalization" | 2019 | Removed re-centering; comparable quality with 7%–64% less running time than LayerNorm across the paper's models |
| Xiong et al., "On Layer Normalization in the Transformer Architecture" | 2020 | Formal analysis of Pre-LN vs Post-LN gradient flow |