Jimmy Lei Ba, Jamie Ryan Kiros, Geoffrey E. Hinton (University of Toronto) — arXiv 2016

Layer Normalization

The normalization technique that freed neural networks from batch dependence — normalizing across features within each example, making it the default choice for Transformers, RNNs, and any setting where batch statistics are unreliable.

Prerequisites: Mean and variance + Neural network basics + Backpropagation. That's it.
8 chapters · 8+ simulations · 0 assumed knowledge

Chapter 0: The Batch Problem

You're training a language model on sentences of varying length. Your batch has 32 sentences: some are 5 words long, others are 50. You're using Batch Normalization (BatchNorm), the technique that revolutionized image classification in 2015. BatchNorm computes the mean and variance of each feature across the batch — averaging over those 32 sentences — and uses them to normalize activations.

Here's the problem: at timestep 40, only 8 of your 32 sentences are still alive (the short ones have already ended). You're computing batch statistics from just 8 examples. At timestep 48, maybe only 2 sentences remain. Your "batch mean" is now the average of two numbers — a terrible estimate of the true mean. The variance estimate is even worse.

This isn't a minor inconvenience. It's a fundamental incompatibility between BatchNorm and sequential data. The batch statistics become noisy and unreliable as sequence positions get sparse. And it gets worse: during inference, you typically process one sentence at a time (batch size = 1). BatchNorm's running statistics, computed during training on batches of 32+, don't match the single-example regime at all.

BatchNorm's Sequence Problem

Each row is a sentence in the batch. Colored cells are active tokens; gray cells mean the sentence has ended. BatchNorm computes statistics DOWN each column (across the batch). Watch how the number of active examples shrinks at later timesteps, making the statistics unreliable. Drag the slider to change the timestep.

| Property | BatchNorm | Problem for Sequences |
|---|---|---|
| Normalizes over | Batch dimension (across examples) | Batch size varies per timestep |
| Batch size = 1 | Uses running statistics | Statistics mismatch with training |
| Variable lengths | Padding masks needed | Statistics from padding are meaningless |
| Online learning | Impossible (needs batch) | Many real applications need online |

The core tension: BatchNorm normalizes across the batch — it asks "how does this feature compare to the same feature in other examples?" But for sequences with variable lengths, different examples contribute to different timesteps. The batch statistics at later timesteps come from a dwindling, biased subset of examples. We need a normalization that works within a single example, independent of what else is in the batch.
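
To make the shrinking-batch effect concrete, here is a minimal sketch that counts how many sentences are still active at each timestep. The batch of 32 random lengths between 5 and 50 is a hypothetical stand-in for the scenario described above, not data from the paper.

```python
import torch

torch.manual_seed(0)

# Hypothetical batch: 32 sentences with lengths between 5 and 50 tokens.
lengths = torch.randint(low=5, high=51, size=(32,))
max_len = int(lengths.max())

# mask[b, t] is True while sentence b still has a real token at timestep t.
timesteps = torch.arange(max_len)
mask = lengths.unsqueeze(1) > timesteps.unsqueeze(0)    # [32, max_len]

# Number of examples a per-timestep BatchNorm could average over.
active_per_timestep = mask.sum(dim=0)                   # [max_len]

for t in (0, 4, max_len - 1):
    print(f"timestep {t:2d}: {int(active_per_timestep[t])} active sentences")
```

Every sentence contributes at the first five timesteps, but the count can only fall after that, so the statistics at the final timesteps rest on whichever few long sentences happen to be in the batch.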

Ba, Kiros, and Hinton proposed Layer Normalization (LayerNorm) — a technique that normalizes across the feature dimension within a single training example. Instead of asking "how does this feature compare across the batch?", it asks "how does this feature compare to the other features in the same token?" This makes it completely independent of the batch, and equally valid whether the batch has 1 example or 1000.
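
The batch-independence claim is easy to check. The sketch below uses PyTorch's `torch.nn.functional.layer_norm` without learned parameters and normalizes the same example twice: once inside a batch of 32 and once on its own. The tensor sizes are arbitrary.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
batch = torch.randn(32, 8)          # 32 examples, 8 features

# Normalize the first example together with the rest of the batch...
in_batch = F.layer_norm(batch, normalized_shape=(8,))[0]
# ...and again completely on its own (batch size 1).
alone = F.layer_norm(batch[:1], normalized_shape=(8,))[0]

same_result = torch.allclose(in_batch, alone, atol=1e-6)
print(same_result)
```

The two results agree because each row's mean and variance are computed from that row alone.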

LayerNorm became the default normalization in Transformers, underpinning the entire GPT/BERT revolution. The large language models with published architectures, from GPT-2 and GPT-3 to Llama, use LayerNorm or its variant, RMSNorm. Understanding it is essential for understanding how modern neural networks stabilize training.

The internal covariate shift hypothesis

The original motivation for BatchNorm was internal covariate shift: as the parameters of a lower layer change during training, the distribution of inputs to higher layers shifts. Each layer must constantly readapt to new input distributions, slowing convergence. BatchNorm fixes this by normalizing each feature to zero mean and unit variance at every layer.

LayerNorm achieves the same stabilization effect but through a different mechanism. Instead of normalizing across the batch (ensuring each feature has consistent statistics across examples), it normalizes across features (ensuring each example has consistent activation magnitude). The effect on training stability is similar: activations don't explode or vanish, gradient magnitudes stay reasonable, and the loss landscape becomes smoother.

python
# The BatchNorm vs LayerNorm difference in one snippet
import torch

x = torch.randn(4, 8)  # 4 examples, 8 features

# BatchNorm: mean/var across batch (dim=0)
bn_mean = x.mean(dim=0)                  # shape [8]: one stat per feature
bn_var  = x.var(dim=0, unbiased=False)   # shape [8]; normalization uses the biased variance

# LayerNorm: mean/var across features (dim=1)
ln_mean = x.mean(dim=1)                  # shape [4]: one stat per example
ln_var  = x.var(dim=1, unbiased=False)   # shape [4]

# BatchNorm needs the batch; LayerNorm doesn't
Why does BatchNorm fail for recurrent neural networks processing variable-length sequences?

Chapter 1: The LayerNorm Definition

Now that we understand the problem, let's see Ba et al.'s solution. The key insight is deceptively simple: instead of normalizing across the batch, normalize across the features.

Consider a single hidden layer activation vector h with H features: h = [h_1, h_2, ..., h_H]. In BatchNorm, we'd compute the mean of h_1 across all examples in the batch. In LayerNorm, we compute the mean across the features within this single example:

$$\mu = \frac{1}{H}\sum_{i=1}^{H} h_i$$
$$\sigma^2 = \frac{1}{H}\sum_{i=1}^{H} (h_i - \mu)^2$$

Then normalize each feature by subtracting the mean and dividing by the standard deviation:

$$\hat{h}_i = \frac{h_i - \mu}{\sqrt{\sigma^2 + \epsilon}}$$

Where ε is a small constant (typically $10^{-5}$) to prevent division by zero. Finally, we apply learned gain (g) and bias (b) parameters, one per feature, so the model can undo the normalization if needed:

$$y_i = g_i \cdot \hat{h}_i + b_i$$

The gain and bias are initialized to 1 and 0, respectively. At initialization, LayerNorm just standardizes activations. During training, the model learns to shift and scale each feature to whatever distribution helps the task. If the normalization hurts a particular feature, the model can learn g and b to effectively bypass it.

Think of LayerNorm as a "feature equalizer." Imagine a mixing board where each slider is a feature. Some features might fire very strongly (loud channels), others very weakly (quiet channels). LayerNorm first normalizes all channels to the same volume (zero mean, unit variance), then lets learned gain/bias knobs fine-tune each channel's volume. This prevents any single feature from dominating, while still allowing the model to control the distribution.
BatchNorm vs LayerNorm: What Gets Averaged

A batch of 4 examples, each with 6 features. BatchNorm normalizes DOWN each column (across examples, shown in blue). LayerNorm normalizes ACROSS each row (across features, shown in orange). Toggle between them to see which cells participate in computing each statistic.

The crucial difference is the dimension of normalization. Let's make this concrete with tensor shapes:

| Quantity | BatchNorm | LayerNorm |
|---|---|---|
| Input shape | [B, H] | [B, H] |
| Mean computed over | B (batch dim) | H (feature dim) |
| Mean shape | [H] (one mean per feature) | [B] (one mean per example) |
| Variance shape | [H] | [B] |
| Gain/bias shape | [H] (scale per feature) | [H] (scale per feature) |
| Depends on batch? | Yes, needs multiple examples | No, works with batch size 1 |

For a Transformer with d_model = 768, LayerNorm computes the mean of 768 features for each token independently. A batch of 32 sequences, each 512 tokens long, means 32 × 512 = 16,384 independent normalizations, each averaging over 768 features. No information flows between tokens, between sequences, or between batch elements.
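
As a rough check of the count above, the sketch below builds a random [32, 512, 768] tensor, counts the per-token means, and then perturbs a single token in a small slice to see which outputs move. The random data and the size of the perturbation are arbitrary choices for illustration.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, T, D = 32, 512, 768
x = torch.randn(B, T, D)

# One mean per token: B * T independent statistics, each over D features.
mu = x.mean(dim=-1, keepdim=True)
n_stats = mu.numel()
print(tuple(mu.shape), n_stats)

# Perturb one token in a small slice and check which outputs move.
small = x[:2, :8].clone()                 # [2, 8, D]
perturbed = small.clone()
perturbed[0, 0] += 5.0 * torch.randn(D)
y_before = F.layer_norm(small, (D,))
y_after = F.layer_norm(perturbed, (D,))
changed = ~torch.isclose(y_before, y_after, atol=1e-6).all(dim=-1)   # [2, 8]
print(int(changed.sum()))
```

Only the perturbed token's output changes; every other token in the slice is untouched.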

python
import torch
import torch.nn as nn

# LayerNorm from scratch
def layer_norm(x, gamma, beta, eps=1e-5):
    # x: [B, T, H]  (batch, sequence, features)
    # gamma, beta: [H]  (learned gain and bias)

    # Compute mean and variance across last dim (features)
    mu = x.mean(dim=-1, keepdim=True)     # [B, T, 1]
    var = x.var(dim=-1, keepdim=True,
              unbiased=False)                # [B, T, 1]

    # Normalize
    x_hat = (x - mu) / torch.sqrt(var + eps)  # [B, T, H]

    # Scale and shift
    return gamma * x_hat + beta              # [B, T, H]

# Test: batch of 2 sequences, length 3, dim 4
x = torch.randn(2, 3, 4)
gamma = torch.ones(4)
beta = torch.zeros(4)
out = layer_norm(x, gamma, beta)

# Each token now has mean≈0, var≈1 across features
print(out[0, 0].mean().item())                # ≈ 0.0
print(out[0, 0].var(unbiased=False).item())   # ≈ 1.0 (biased var; the default unbiased var would give ≈ 1.33 for 4 features)

# Compare with PyTorch's built-in
ln = nn.LayerNorm(4)
out_pt = ln(x)  # same result
Why does normalizing across features make sense? Features within a layer tend to be on different scales — some might range from -10 to 10, others from -0.01 to 0.01. The features with larger magnitudes dominate the gradient updates, slowing convergence for the smaller features. LayerNorm puts all features on the same scale, giving each feature equal footing in gradient-based learning. It's the same reason you standardize input features in linear regression — but applied at every layer.

Computational cost

LayerNorm is remarkably cheap. For a vector of H features, it requires:

| Operation | FLOPs | Note |
|---|---|---|
| Compute mean | H additions | One pass over features |
| Compute variance | 2H (subtract + square + sum) | Second pass |
| Normalize | 2H (subtract mean, divide by std) | Elementwise |
| Scale + shift | 2H (multiply by g, add b) | Elementwise |
| Total | ~7H FLOPs | Linear in feature dim |

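For a Transformer with d_model = 768, that's ~5,400 FLOPs per token. Compare that with the attention scores of a single head: T × d_k = 512 × 64 = 32,768 multiply-accumulates per query token, and that is before counting the other heads, the value aggregation, the projections, and the feed-forward network, which together run to millions of FLOPs per token. Next to that total, LayerNorm is a rounding error: it adds well under 1% to the compute.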

The parameters are also minimal: just 2 × H = 2 × 768 = 1,536 parameters for the gain and bias vectors. In a 125M-parameter GPT-2, the 24 per-layer LayerNorm instances (two in each of the 12 layers) total 24 × 1,536 = 36,864 parameters, about 0.03% of the model; the final LayerNorm before the output head adds another 1,536. LayerNorm provides enormous training stability for negligible cost.
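
The arithmetic above can be reproduced in a few lines. This sketch counts the parameters of one `nn.LayerNorm(768)` and scales up using the figures quoted in the text; the 125M model size and the ~7H FLOP estimate are taken from the text as given, not measured.

```python
import torch.nn as nn

d_model = 768
ln = nn.LayerNorm(d_model)
params_per_ln = sum(p.numel() for p in ln.parameters())   # gain + bias

n_instances = 24                  # 12 layers x 2 LayerNorms, as counted in the text
total_ln_params = n_instances * params_per_ln
share = total_ln_params / 125e6   # approximate model size from the text

flops_per_token = 7 * d_model     # the ~7H estimate from the table above
print(params_per_ln, total_ln_params, f"{share:.4%}", flops_per_token)
```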

What is the key difference in how BatchNorm and LayerNorm compute their normalization statistics?

Chapter 2: The Forward Pass Step by Step

Let's work through the LayerNorm forward pass with actual numbers, so there's no ambiguity about what happens at each step.

Suppose we have a single token with a 4-dimensional feature vector:

h = [2.0, -1.0, 0.5, 3.5]

Step 1: Compute the mean across features

μ = (2.0 + (-1.0) + 0.5 + 3.5) / 4 = 5.0 / 4 = 1.25

Step 2: Compute the variance across features

$$\sigma^2 = \frac{(2.0-1.25)^2 + (-1.0-1.25)^2 + (0.5-1.25)^2 + (3.5-1.25)^2}{4}$$
$$\sigma^2 = \frac{0.5625 + 5.0625 + 0.5625 + 5.0625}{4} = \frac{11.25}{4} = 2.8125$$

Step 3: Normalize

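$$\sigma = \sqrt{2.8125 + 10^{-5}} \approx 1.6771$$
$$\hat{h} = \frac{h - \mu}{\sigma} = \left[\frac{2.0-1.25}{1.677},\ \frac{-1.0-1.25}{1.677},\ \frac{0.5-1.25}{1.677},\ \frac{3.5-1.25}{1.677}\right]$$
$$\hat{h} \approx [0.447,\ -1.342,\ -0.447,\ 1.342]$$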

Check: mean of ĥ is (0.447 - 1.342 - 0.447 + 1.342) / 4 ≈ 0. Variance is (0.200 + 1.801 + 0.200 + 1.801) / 4 ≈ 1.0. The normalization worked — as it always must, by construction.

A useful mental model: LayerNorm is a two-step operation. Step 1 (centering): subtract the mean to move the distribution to be centered at zero. Step 2 (scaling): divide by the standard deviation to make the spread exactly 1. Together, these ensure every token's feature vector has the same "shape" — zero mean, unit variance — regardless of what the raw activations looked like. The gain and bias then let the model choose the optimal mean and variance for each feature.

Step 4: Scale and shift

$$y_i = g_i \cdot \hat{h}_i + b_i$$

With default initialization (g = [1,1,1,1], b = [0,0,0,0]), the output equals ĥ. After training, the model might learn g = [0.5, 2.0, 1.0, 0.8] and b = [0.1, -0.3, 0.0, 0.5] to scale features that matter more.
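
The four steps can be replayed in code with the same numbers. This is a minimal sketch; the gain and bias vectors are the illustrative "after training" values from the paragraph above, not parameters from any real model.

```python
import torch

h = torch.tensor([2.0, -1.0, 0.5, 3.5])
eps = 1e-5

mu = h.mean()                                  # Step 1
var = ((h - mu) ** 2).mean()                   # Step 2 (biased: divide by H)
h_hat = (h - mu) / torch.sqrt(var + eps)       # Step 3

# Step 4 with the illustrative gain and bias from the text.
g = torch.tensor([0.5, 2.0, 1.0, 0.8])
b = torch.tensor([0.1, -0.3, 0.0, 0.5])
y = g * h_hat + b

print(mu.item(), var.item())
print(h_hat)
print(y)
```

With these values the second feature is amplified (gain 2.0) and shifted down, while the third passes through unchanged.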

LayerNorm Forward Pass Visualizer

Watch the 4 features get normalized step by step. The bars show raw values (left), centered (middle), and normalized (right). Drag sliders to change the input values and see how the normalization adapts.

What does the ε = 10^-5 do? If all features happen to be identical (say h = [3, 3, 3, 3]), then σ² = 0 and we'd divide by zero. The ε prevents this. It also stabilizes gradient flow when the variance is very small: without it, the gradient of 1/σ would explode near σ = 0, causing numerical issues. In practice, features are rarely identical, but ε provides essential safety.
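
The degenerate case is easy to reproduce. The sketch below normalizes h = [3, 3, 3, 3] with and without ε, writing the formula out by hand instead of calling a library layer.

```python
import torch

h = torch.tensor([3.0, 3.0, 3.0, 3.0])
mu = h.mean()
var = ((h - mu) ** 2).mean()          # exactly 0 for identical features

without_eps = (h - mu) / torch.sqrt(var)          # 0 / 0
with_eps = (h - mu) / torch.sqrt(var + 1e-5)      # 0 / small positive number

print(without_eps)
print(with_eps)
```

Without ε every entry is NaN, which would poison everything downstream; with ε the output is a clean vector of zeros.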

The full computation in one line

For a Transformer layer with input x of shape [B, T, D]:

python
# Full forward pass — what PyTorch's nn.LayerNorm does
def layernorm_forward(x, gamma, beta, eps=1e-5):
    # x shape: [B, T, D] = [batch, seq_len, d_model]
    # gamma, beta shape: [D]

    # Step 1: mean across features (last dim)
    mu = x.mean(dim=-1, keepdim=True)        # [B, T, 1]

    # Step 2: variance across features
    var = ((x - mu) ** 2).mean(dim=-1,
           keepdim=True)                       # [B, T, 1]

    # Step 3: normalize
    x_hat = (x - mu) / torch.sqrt(var + eps) # [B, T, D]

    # Step 4: scale and shift
    y = gamma * x_hat + beta                  # [B, T, D]

    # Cache for backward pass
    cache = (x_hat, gamma, mu, var, eps)
    return y, cache

Notice that the normalization is per-token. Token [0, 3] and token [1, 7] each get their own mean and variance — computed independently from 768 features. No information leaks between tokens, between sequences, or between timesteps. This is what makes LayerNorm work for autoregressive generation, where tokens are generated one at a time.

A common source of bugs: biased vs unbiased variance

PyTorch's torch.var() computes the unbiased variance (divides by N-1) by default, while LayerNorm uses the biased variance (divides by N). This can cause subtle numerical mismatches if you implement LayerNorm manually:

python
# Common bug: using default (unbiased) variance
wrong_var = x.var(dim=-1, keepdim=True)              # divides by N-1
right_var = x.var(dim=-1, keepdim=True, unbiased=False)  # divides by N

# Or equivalently:
right_var = ((x - x.mean(-1, keepdim=True)) ** 2).mean(-1, keepdim=True)

# For d_model=768, the difference is small (768 vs 767)
# but for small dims (d=4), the unbiased variance is 4/3 of the biased one (33% larger)!

For d_model = 768, dividing by 768 vs 767 is a 0.13% difference, which is negligible. But for a 4-dimensional feature vector (as in our worked example), dividing by 3 instead of 4 inflates the variance by a third and shrinks every normalized value by about 13%, enough to cause visible discrepancies. Always use unbiased=False when implementing LayerNorm manually.

Mixed precision considerations

In practice, modern LLM training uses mixed precision (FP16 or BF16 for activations, FP32 for master weights). LayerNorm is sensitive to precision because it involves computing a mean and variance, then dividing — small errors in the variance can be amplified by the division. Most implementations compute LayerNorm statistics in FP32 even when activations are in FP16:

python
# Mixed-precision LayerNorm (a sketch of the common pattern: statistics in FP32)
def layernorm_mixed(x_fp16, gamma, beta, eps=1e-5):
    # Upcast to FP32 for statistics
    x_fp32 = x_fp16.float()
    mu = x_fp32.mean(dim=-1, keepdim=True)
    var = ((x_fp32 - mu) ** 2).mean(dim=-1, keepdim=True)
    x_norm = (x_fp32 - mu) / torch.sqrt(var + eps)
    # Downcast the result back to the input dtype (FP16 or BF16)
    return (gamma * x_norm + beta).to(x_fp16.dtype)
If the input vector is h = [4.0, 0.0, 2.0, 2.0], what is the mean μ used by LayerNorm?

Chapter 3: The Backward Pass — Gradients Through Normalization

Training requires backpropagation through LayerNorm. This is trickier than it looks, because the output depends on the mean and variance, which themselves depend on all inputs. Changing one feature changes the statistics, which changes the normalization of every other feature. Let's derive the gradients carefully.

We need three gradients: ∂L/∂h_i (to propagate to the previous layer), ∂L/∂g_i (to update the gain), and ∂L/∂b_i (to update the bias).

Easy ones first: gain and bias gradients

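Since y_i = g_i · ĥ_i + b_i, the gradients of the loss L with respect to g and b are straightforward:

$$\frac{\partial L}{\partial b_i} = \frac{\partial L}{\partial y_i}$$
$$\frac{\partial L}{\partial g_i} = \frac{\partial L}{\partial y_i} \cdot \hat{h}_i$$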

The bias gradient is just the upstream gradient. The gain gradient is the upstream gradient times the normalized activation. These are elementwise — no coupling between features.
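
These two formulas can be checked against autograd. In the sketch below the random tensors stand in for real activations and upstream gradients. Because g and b are shared by every token, the per-token formulas are summed over the batch and sequence dimensions.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, T, D = 2, 3, 4
x = torch.randn(B, T, D, dtype=torch.float64)
gamma = torch.randn(D, dtype=torch.float64, requires_grad=True)
beta = torch.randn(D, dtype=torch.float64, requires_grad=True)

y = F.layer_norm(x, (D,), weight=gamma, bias=beta, eps=1e-5)
dy = torch.randn_like(y)          # stand-in for the upstream gradient dL/dy
y.backward(dy)

# Closed-form gradients, summed over every token that shares gamma and beta.
mu = x.mean(dim=-1, keepdim=True)
var = ((x - mu) ** 2).mean(dim=-1, keepdim=True)
x_hat = (x - mu) / torch.sqrt(var + 1e-5)
dgamma = (dy * x_hat).sum(dim=(0, 1))
dbeta = dy.sum(dim=(0, 1))

print(torch.allclose(gamma.grad, dgamma), torch.allclose(beta.grad, dbeta))
```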

The hard one: input gradient

This is where it gets interesting. We need ∂L/∂h_i. The normalization ĥ_i = (h_i - μ) / σ depends on h_i in three ways, and the total gradient is the sum of the three paths:

Direct: h_i appears in the numerator (h_i - μ).
Through μ: μ = (1/H) ∑_j h_j, so ∂μ/∂h_i = 1/H.
Through σ²: σ² = (1/H) ∑_j (h_j - μ)², so ∂σ²/∂h_i = (2/H)(h_i - μ).

Applying the chain rule and simplifying (which takes about a page of algebra), the gradient of the loss with respect to the input is:

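$$\frac{\partial L}{\partial h_i} = \frac{1}{\sigma}\left[\, g_i \frac{\partial L}{\partial y_i} \;-\; \frac{1}{H}\sum_j g_j \frac{\partial L}{\partial y_j} \;-\; \frac{\hat{h}_i}{H}\sum_j g_j \frac{\partial L}{\partial y_j}\,\hat{h}_j \right]$$

Here σ stands for √(σ² + ε). Note that each gain g_j multiplies its own upstream gradient inside the brackets; only the 1/σ factor is shared by all three terms.

Let's understand each term:

| Term | What it does |
|---|---|
| g_i · ∂L/∂y_i | The "direct" gradient: how the loss changes when this feature changes, ignoring statistics |
| -(1/H) ∑_j g_j · ∂L/∂y_j | The "mean correction": subtracts the mean gradient, from ∂μ/∂h_i |
| -(ĥ_i/H) ∑_j g_j · ∂L/∂y_j · ĥ_j | The "variance correction": accounts for the gradient through σ² |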
The gradient has a "centering" effect. The mean and variance correction terms ensure that the gradient is itself zero-mean and decorrelated with the normalized activations. This is not a coincidence — it's exactly what makes LayerNorm stabilize training. The corrections prevent any single feature from accumulating disproportionate gradient, the same way the forward pass prevents any feature from having disproportionate magnitude.

Concrete gradient computation

Let's verify with numbers. Using our earlier example: h = [2.0, -1.0, 0.5, 3.5], μ = 1.25, σ = 1.677, ĥ = [0.447, -1.342, -0.447, 1.342]. Suppose the upstream gradient is dL/dy = [1.5, 0.5, -0.8, 0.3] and g = [1, 1, 1, 1] (default gain).

dx_hat = dy · g = [1.5, 0.5, -0.8, 0.3]
mean(dx_hat) = (1.5 + 0.5 - 0.8 + 0.3) / 4 = 0.375
mean(dx_hat · ĥ) = (1.5×0.447 + 0.5×(-1.342) + (-0.8)×(-0.447) + 0.3×1.342) / 4 = (0.671 - 0.671 + 0.358 + 0.403) / 4 = 0.190
dh_i = (1/σ) · [dx_hat_i - 0.375 - ĥ_i × 0.190]

For feature 0: dh_0 = (1/1.677) × [1.5 - 0.375 - 0.447×0.190] = (1/1.677) × [1.040] = 0.620

For feature 1: dh_1 = (1/1.677) × [0.5 - 0.375 - (-1.342)×0.190] = (1/1.677) × [0.380] = 0.227

Notice: the input gradient dh has smaller magnitude than the upstream gradient dy. LayerNorm attenuates gradients by the factor 1/σ — when activations have high variance (large σ), gradients are dampened. This is the automatic gradient scaling that prevents explosion.
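
The hand computation can be cross-checked with autograd. This sketch feeds the same h and upstream gradient through `torch.nn.functional.layer_norm` with no weight or bias, which corresponds to the default g = 1, b = 0 used above.

```python
import torch
import torch.nn.functional as F

h = torch.tensor([2.0, -1.0, 0.5, 3.5], dtype=torch.float64, requires_grad=True)
dy = torch.tensor([1.5, 0.5, -0.8, 0.3], dtype=torch.float64)

y = F.layer_norm(h, (4,), eps=1e-5)   # no weight/bias: same as g = 1, b = 0
y.backward(dy)
print(h.grad)
```

The first two entries should reproduce the 0.620 and 0.227 derived by hand, and the four entries together sum to zero, which is the centering effect described earlier.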

Gradient Flow Through LayerNorm

Upstream gradient (top, teal) flows through LayerNorm to produce the input gradient (bottom, warm). Watch how the mean-correction and variance-correction terms modify the gradient. Drag to change the upstream gradient and see the effect.

python
# LayerNorm backward pass
def layernorm_backward(dy, cache):
    # dy: [B, T, D] — upstream gradient
    x_hat, gamma, mu, var, eps = cache
    H = x_hat.shape[-1]  # feature dim
    std = torch.sqrt(var + eps)  # [B, T, 1]

    # Gain and bias gradients (summed over batch & sequence)
    dgamma = (dy * x_hat).sum(dim=(0, 1))  # [D]
    dbeta  = dy.sum(dim=(0, 1))              # [D]

    # Input gradient
    dx_hat = dy * gamma           # [B, T, D]
    dvar   = (dx_hat * x_hat * -0.5 / std**2).sum(dim=-1, keepdim=True)  # d(x_hat)/d(var) = -0.5 * x_hat / (var + eps)
    dmu    = (-dx_hat / std).sum(dim=-1, keepdim=True)

    # Combine three paths
    dx = dx_hat / std + dvar * 2 * x_hat * std / H + dmu / H

    return dx, dgamma, dbeta
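
A hand-written backward pass is easy to get subtly wrong, so it is worth comparing against autograd before trusting it. The sketch below is a self-contained check under assumed shapes: `layernorm_input_grad` is a hypothetical helper that implements the closed-form input gradient, and the gain is deliberately non-uniform so that a misplaced g factor would show up.

```python
import torch
import torch.nn.functional as F

def layernorm_input_grad(dy, x, gamma, eps=1e-5):
    """Closed-form dL/dx for y = gamma * x_hat + beta, normalizing over the last dim."""
    mu = x.mean(dim=-1, keepdim=True)
    var = ((x - mu) ** 2).mean(dim=-1, keepdim=True)
    std = torch.sqrt(var + eps)
    x_hat = (x - mu) / std
    dx_hat = dy * gamma
    mean_term = dx_hat.mean(dim=-1, keepdim=True)
    var_term = x_hat * (dx_hat * x_hat).mean(dim=-1, keepdim=True)
    return (dx_hat - mean_term - var_term) / std

torch.manual_seed(0)
x = torch.randn(2, 3, 8, dtype=torch.float64, requires_grad=True)
gamma = torch.rand(8, dtype=torch.float64) + 0.5     # deliberately non-uniform
beta = torch.zeros(8, dtype=torch.float64)
dy = torch.randn(2, 3, 8, dtype=torch.float64)

F.layer_norm(x, (8,), gamma, beta, eps=1e-5).backward(dy)
dx_manual = layernorm_input_grad(dy, x.detach(), gamma)

matches_autograd = torch.allclose(x.grad, dx_manual, atol=1e-8)
max_row_sum = dx_manual.sum(dim=-1).abs().max().item()   # each token's gradient sums to ~0
print(matches_autograd, max_row_sum)
```

Using float64 keeps rounding error far below the tolerance, so any mismatch points to a real error in the formula and not to numerical noise.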
Why is the backward pass through LayerNorm more complex than a simple elementwise operation?

Chapter 4: The Normalization Zoo

LayerNorm is one member of a family of normalization techniques. Each normalizes over different dimensions, and each works best for different data types. Understanding the full zoo helps you pick the right one for your architecture.

The key to understanding all normalization variants is this: given a 4D tensor of activations (for images) or 3D tensor (for sequences), which dimensions do you average over to compute the mean and variance? Everything else — the formula, the gain/bias, the ε — is identical.

For a feature map of shape [B, C, H, W] (batch, channels, height, width):

| Method | Normalize over | Mean shape | Best for |
|---|---|---|---|
| BatchNorm | B, H, W | [C] | CNNs, large batches |
| LayerNorm | C, H, W | [B] | Transformers, RNNs |
| InstanceNorm | H, W | [B, C] | Style transfer, image generation |
| GroupNorm | C/G, H, W | [B, G] | Small batches, detection |

Normalization Zoo: What Gets Averaged

A [B=4, C=6, H, W] tensor visualized as a 4×6 grid. Each normalization method averages over different cells (shown in color). Cells that share color get averaged together to compute one mean/variance pair. Click each method to see the difference.
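
The table's "Mean shape" column can be reproduced directly by averaging over the listed dimensions. In this sketch the spatial size 5×5 and the group count G = 3 are arbitrary assumptions chosen to match the [B=4, C=6] layout.

```python
import torch

B, C, H, W = 4, 6, 5, 5
x = torch.randn(B, C, H, W)
G = 3                                   # hypothetical number of groups

stats = {
    "BatchNorm":    x.mean(dim=(0, 2, 3)),                            # [C]
    "LayerNorm":    x.mean(dim=(1, 2, 3)),                            # [B]
    "InstanceNorm": x.mean(dim=(2, 3)),                               # [B, C]
    "GroupNorm":    x.view(B, G, C // G, H, W).mean(dim=(2, 3, 4)),   # [B, G]
}
for name, mean in stats.items():
    print(f"{name:12s} mean shape: {tuple(mean.shape)}")
```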

BatchNorm (Ioffe & Szegedy, 2015)

Batch Normalization normalizes each channel independently, computing statistics across the batch and spatial dimensions. For channel c, it computes the mean across all B × H × W pixels in that channel. This works brilliantly for CNNs because channels represent consistent features (edge detectors, texture patterns), and averaging over many spatial locations and batch examples gives stable statistics.

The problem: it needs a large enough batch to get good statistics. With batch size 1, the "batch mean" is just the single example's mean — useless for normalization. BatchNorm also requires maintaining running statistics for inference, adding complexity.

InstanceNorm (Ulyanov et al., 2016)

Instance Normalization normalizes each channel of each example independently — averaging over H × W only. Think of it as BatchNorm with batch size 1. It was developed for style transfer, where the "style" is encoded in the feature statistics (mean and variance). By normalizing away these statistics, InstanceNorm strips the style from the content, allowing a different style to be applied.

GroupNorm (Wu & He, 2018)

Group Normalization divides channels into G groups and normalizes within each group. It's a middle ground between LayerNorm (one group = all channels) and InstanceNorm (each channel is its own group). GroupNorm was designed for object detection and segmentation, where small batch sizes (due to large image crops) make BatchNorm unreliable.
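
The "middle ground" claim can be verified at its two extremes. The sketch below uses PyTorch's functional normalization calls without affine parameters; the tensor sizes are arbitrary.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, C, H, W = 2, 6, 4, 4
x = torch.randn(B, C, H, W)

# One group containing every channel: same statistics as LayerNorm over (C, H, W).
gn_one_group = F.group_norm(x, num_groups=1)
ln_all = F.layer_norm(x, (C, H, W))

# One group per channel: same statistics as InstanceNorm.
gn_per_channel = F.group_norm(x, num_groups=C)
inorm = F.instance_norm(x)

print(torch.allclose(gn_one_group, ln_all, atol=1e-5))
print(torch.allclose(gn_per_channel, inorm, atol=1e-5))
```

Any group count between 1 and C interpolates between these two behaviors.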

Why LayerNorm won for Transformers: Transformers process sequences, not images. There's no spatial dimension — each token is a 1D feature vector. BatchNorm would normalize across the batch (unreliable for variable-length sequences). InstanceNorm would normalize each feature independently (no channel sharing). GroupNorm could work but adds a hyperparameter (number of groups). LayerNorm normalizes across all features within each token — the natural choice for 1D feature vectors with no spatial structure.

RMSNorm: The modern simplification

Root Mean Square Layer Normalization (Zhang & Sennrich, 2019) simplifies LayerNorm by dropping the mean centering:

$$\mathrm{RMSNorm}(h) = \frac{h}{\mathrm{RMS}(h)} \cdot g \qquad \text{where } \mathrm{RMS}(h) = \sqrt{\frac{1}{H}\sum_{i} h_i^2}$$

RMSNorm skips the mean subtraction, normalizing only by the root mean square. This makes it cheaper than full LayerNorm because there is no mean to compute or subtract; Zhang & Sennrich report running-time reductions of roughly 7% to 64% depending on the model. Empirically, the re-centering (mean subtraction) contributes little to performance; the re-scaling (division by standard deviation) does most of the work.

RMSNorm is used in Llama, Llama 2, Llama 3, Gemma, and many other modern LLMs. It's becoming the de facto replacement for LayerNorm.

python
def rms_norm(x, gamma, eps=1e-6):
    # x: [B, T, D],  gamma: [D]
    rms = torch.sqrt((x ** 2).mean(dim=-1, keepdim=True) + eps)
    return x / rms * gamma  # no mean subtraction, no bias

# Compare operations:
# LayerNorm: mean, variance, subtract, divide, scale, shift → 6 ops
# RMSNorm:  square, mean, sqrt, divide, scale → 5 ops (no mean, no shift)
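
One way to see what dropping the mean subtraction changes is to compare the two on inputs with and without an offset. In this sketch, `rms_norm_sketch` is an illustrative helper that mirrors the formula above (with ε = 1e-5 so it matches LayerNorm's default), and the offset of 3.0 is arbitrary.

```python
import torch
import torch.nn.functional as F

def rms_norm_sketch(x, gamma, eps=1e-5):
    rms = torch.sqrt((x ** 2).mean(dim=-1, keepdim=True) + eps)
    return x / rms * gamma

def layer_norm_plain(x, eps=1e-5):
    return F.layer_norm(x, (x.shape[-1],), eps=eps)

torch.manual_seed(0)
D = 16
gamma = torch.ones(D, dtype=torch.float64)
x = torch.randn(4, D, dtype=torch.float64)

x_centered = x - x.mean(dim=-1, keepdim=True)   # force zero mean per row
x_shifted = x_centered + 3.0                    # same rows with a large offset

same_when_centered = torch.allclose(rms_norm_sketch(x_centered, gamma),
                                    layer_norm_plain(x_centered), atol=1e-6)
same_when_shifted = torch.allclose(rms_norm_sketch(x_shifted, gamma),
                                   layer_norm_plain(x_shifted), atol=1e-6)
print(same_when_centered, same_when_shifted)
```

When the input already has zero mean, the RMS equals the standard deviation and the two methods coincide. They diverge only when the features carry a shared offset, which RMSNorm leaves in place.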
What makes LayerNorm the natural choice for Transformers over BatchNorm, InstanceNorm, or GroupNorm?

Chapter 5: RNN Experiments — Where LayerNorm Was Born

LayerNorm was originally designed for RNNs, not Transformers (which wouldn't exist for another year). Ba et al. evaluated it on six tasks, most of them recurrent. The three covered here were considered challenging for recurrent networks in 2016: image-sentence ranking, question answering, and skip-thought vectors.

The key application was to the recurrent transition. In a standard RNN, the hidden state update is:

$$h_t = f(W_h h_{t-1} + W_x x_t + b)$$

Ba et al. applied LayerNorm to the pre-activation (before the nonlinearity f):

$$h_t = f(\mathrm{LayerNorm}(W_h h_{t-1} + W_x x_t))$$

This normalizes the summed inputs at every timestep, preventing the hidden state from exploding or vanishing over long sequences. Unlike BatchNorm, it works regardless of how many sequences in the batch are still active at timestep t.
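
The update rule above might look like this as a PyTorch module. This is a minimal sketch, not the paper's implementation: the class name `LayerNormRNNCell`, the layer sizes, and the choice of tanh for f are assumptions made for illustration.

```python
import torch
import torch.nn as nn

class LayerNormRNNCell(nn.Module):
    """Vanilla RNN step with LayerNorm on the summed pre-activation (a sketch)."""

    def __init__(self, input_size, hidden_size):
        super().__init__()
        self.W_x = nn.Linear(input_size, hidden_size, bias=False)
        self.W_h = nn.Linear(hidden_size, hidden_size, bias=False)
        # LayerNorm's own bias plays the role of b in the un-normalized RNN.
        self.ln = nn.LayerNorm(hidden_size)

    def forward(self, x_t, h_prev):
        pre_activation = self.W_h(h_prev) + self.W_x(x_t)
        return torch.tanh(self.ln(pre_activation))

torch.manual_seed(0)
cell = LayerNormRNNCell(input_size=10, hidden_size=20)
x = torch.randn(3, 30, 10)            # 3 sequences, 30 timesteps
h = torch.zeros(3, 20)
for t in range(x.shape[1]):
    h = cell(x[:, t], h)
print(h.shape)
```

Each of the three sequences is normalized using only its own 20 hidden features, so the cell behaves identically whether the batch holds 3 sequences or 1.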

Hidden State Dynamics: With vs Without LayerNorm

A simple RNN processing a 30-step sequence. Without normalization (red), the hidden state magnitude can grow unbounded or collapse. With LayerNorm (teal), the activations stay in a stable range. Click "Run" to simulate.


Results from the paper

Results on the three tasks covered here:

| Task | Metric | Without LN | With LN | Improvement |
|---|---|---|---|---|
| Image-Sentence Ranking | Recall@1 | 36.7 | 40.0 | +3.3 points |
| Skip-Thought Vectors | Pearson r (STS) | 0.690 | 0.716 | Faster convergence |
| Question Answering | Accuracy (%) | 31.9 | 36.2 | +4.3 points |

The improvements weren't dramatic, but they were consistent — and LayerNorm converged faster in every case. The real value was robustness: LayerNorm made training less sensitive to learning rate and initialization. Models that would diverge without normalization trained stably with it.

The invariance properties

Ba et al. proved two important invariance properties of LayerNorm:

Weight matrix re-scaling invariance: If you multiply all weights W by a positive scalar α, LayerNorm produces the same output. The scalar multiplies all pre-activations by α, which scales the mean by α and the variance by α². The normalization subtracts the scaled mean and divides by σ = α · σ_original, canceling the α. This means the network is robust to the overall scale of the weights; only the directions matter.
Data re-scaling/shifting invariance: If you add a constant c to every feature of the vector being normalized, or multiply it by a positive scalar α, LayerNorm gives the same normalized output. Adding c shifts the mean by c (canceled by mean subtraction). Multiplying by α scales the variance by α² (canceled by division by σ). This means LayerNorm is invariant to the scale and shift of its input vector, exactly so up to the small ε in the denominator.

These invariance properties explain why LayerNorm helps optimization: the loss landscape becomes smoother because the network's output doesn't change under rescaling of weights or inputs. The optimizer can focus on learning useful representations rather than fighting against magnitude issues.
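
Both invariances can be checked numerically. In the sketch below, the vector `a` plays the role of the pre-activations fed into LayerNorm, and the values α = 7 and c = 100 are arbitrary. Rescaling the weight matrix by α rescales `a` by the same factor, so the first check covers the weight case too.

```python
import torch
import torch.nn.functional as F

def ln(t, eps=1e-5):
    return F.layer_norm(t, (t.shape[-1],), eps=eps)

torch.manual_seed(0)
H = 64
a = torch.randn(5, H, dtype=torch.float64)     # vectors fed into LayerNorm

rescaled = ln(7.0 * a)        # multiply every feature by alpha = 7
shifted = ln(a + 100.0)       # add the same constant c = 100 to every feature

scale_invariant = torch.allclose(rescaled, ln(a), atol=1e-4)
shift_invariant = torch.allclose(shifted, ln(a), atol=1e-4)
print(scale_invariant, shift_invariant)
```

The tolerance is loose on purpose: ε makes the scale invariance approximate, although the deviation is tiny for activations of ordinary magnitude.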

Where to apply LayerNorm in an RNN

There are multiple choices for where to apply LayerNorm within a recurrent cell, and the paper experiments with them:

| Placement | What's normalized | Effect |
|---|---|---|
| Pre-activation | W_h h + W_x x | Most effective: stabilizes the gate inputs |
| Recurrent only | W_h h only | Partial benefit: misses input contribution |
| Cell state | LSTM cell state c_t | Prevents cell state from growing unboundedly |
| Output | Hidden state h_t | Stabilizes outputs but doesn't help gating |

The paper found that applying LayerNorm to both the pre-activation (gate inputs) and the cell state gave the best results. This makes sense: the pre-activation normalization keeps the gate activations in the linear regime of the sigmoid/tanh, and the cell state normalization prevents the memory from growing without bound over long sequences.

Connection to weight normalization

The paper draws a connection to Weight Normalization (Salimans & Kingma, 2016), which reparameterizes weight matrices as w = g · v/||v|| (separating direction from magnitude). Both techniques aim to decouple the direction of updates from their magnitude. LayerNorm achieves this implicitly through activation normalization; WeightNorm achieves it explicitly through weight reparameterization.

A subtle but important difference: Weight Normalization normalizes the weight vectors (decoupling their direction from magnitude), while LayerNorm normalizes the activations (the output of the weight-input product). Both improve optimization, but LayerNorm is more general because it also corrects for the scale of the inputs, not just the weights. In practice, LayerNorm became far more widely adopted — it's the default in every Transformer, while WeightNorm is used mainly in generative models.
python
# LayerNorm in an LSTM (the original use case)
class LN_LSTM(nn.Module):
    def __init__(self, input_size, hidden_size):
        super().__init__()
        self.hidden_size = hidden_size
        self.W_x = nn.Linear(input_size, 4 * hidden_size)
        self.W_h = nn.Linear(hidden_size, 4 * hidden_size, bias=False)
        # Separate LayerNorm for input and recurrent paths
        self.ln_x = nn.LayerNorm(4 * hidden_size)
        self.ln_h = nn.LayerNorm(4 * hidden_size)
        self.ln_c = nn.LayerNorm(hidden_size)

    def forward(self, x, hx=None):
        # x: [B, T, input_size]
        B, T, _ = x.shape
        if hx is None:
            h = x.new_zeros(B, self.hidden_size)  # same device and dtype as x
            c = x.new_zeros(B, self.hidden_size)
        else:
            h, c = hx

        outputs = []
        for t in range(T):
            # Apply LayerNorm to gates
            gates = self.ln_x(self.W_x(x[:, t])) + \
                    self.ln_h(self.W_h(h))
            i, f, g, o = gates.chunk(4, dim=-1)
            i = torch.sigmoid(i)
            f = torch.sigmoid(f)
            g = torch.tanh(g)
            o = torch.sigmoid(o)
            c = f * c + i * g
            h = o * torch.tanh(self.ln_c(c))
            outputs.append(h)
        return torch.stack(outputs, dim=1)  # [B, T, H]
What key invariance property does LayerNorm provide for optimization?

Chapter 6: Pre-LN vs Post-LN — The Transformer Placement Debate

When the Transformer arrived in 2017, it used LayerNorm in a specific position: after the residual connection. This became known as Post-LN. Later models such as GPT-2 moved LayerNorm before the sub-layer (Pre-LN), and Xiong et al. (2020) showed why this makes training dramatically easier. This seemingly minor change has profound effects on gradient flow and training stability.

Post-LN: The original placement

In the original Transformer (Vaswani et al., 2017), the computation in each sub-layer is:

$$x_{\text{out}} = \mathrm{LayerNorm}(x + \mathrm{SubLayer}(x))$$

The input x goes through the sub-layer (attention or FFN), gets added back (residual connection), and then the sum is normalized. This means the normalization sees the combined signal: the original input plus the sub-layer's contribution.

Pre-LN: The modern default

In the Pre-LN variant (used by GPT-2, GPT-3, and most modern LLMs), the computation is:

$$x_{\text{out}} = x + \mathrm{SubLayer}(\mathrm{LayerNorm}(x))$$

The input x is normalized before being fed to the sub-layer, and the residual connection adds the un-normalized input to the sub-layer's output. This seemingly small change has a critical consequence for gradient flow.
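
The two placements differ by one line of code. The sketch below wraps an arbitrary sub-layer in each style; `make_ffn` is a hypothetical stand-in for the attention or feed-forward sub-layer, and the input is scaled up on purpose so the difference in output statistics is visible.

```python
import torch
import torch.nn as nn

class PostLNBlock(nn.Module):
    """x_out = LayerNorm(x + SubLayer(x))"""
    def __init__(self, d_model, sublayer):
        super().__init__()
        self.sublayer = sublayer
        self.norm = nn.LayerNorm(d_model)

    def forward(self, x):
        return self.norm(x + self.sublayer(x))

class PreLNBlock(nn.Module):
    """x_out = x + SubLayer(LayerNorm(x))"""
    def __init__(self, d_model, sublayer):
        super().__init__()
        self.sublayer = sublayer
        self.norm = nn.LayerNorm(d_model)

    def forward(self, x):
        return x + self.sublayer(self.norm(x))

def make_ffn(d_model):
    # Hypothetical stand-in for the attention or FFN sub-layer.
    return nn.Sequential(nn.Linear(d_model, 4 * d_model), nn.GELU(),
                         nn.Linear(4 * d_model, d_model))

torch.manual_seed(0)
d_model = 16
x = 10.0 * torch.randn(2, 5, d_model)          # deliberately large activations
post_out = PostLNBlock(d_model, make_ffn(d_model))(x)
pre_out = PreLNBlock(d_model, make_ffn(d_model))(x)

print(post_out.std(dim=-1).mean().item())   # close to 1: output is normalized
print(pre_out.std(dim=-1).mean().item())    # stays near the input scale
```

The Post-LN output is always normalized, while the Pre-LN output carries the raw residual stream forward. That is why Pre-LN stacks usually add one final LayerNorm before the output head.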

Pre-LN vs Post-LN: Gradient Flow

Compare the two placements. In Post-LN (left), gradients must flow through LayerNorm at every layer. In Pre-LN (right), the residual connection provides a clean gradient highway that bypasses normalization. Toggle between them and trace the gradient path (warm arrows).

Why Pre-LN is easier to train

The gradient analysis reveals the key difference. In Post-LN, the gradient from the output to the input of layer l must pass through all the LayerNorm operations in layers l+1 through L. Each LayerNorm includes a division by σ, which can amplify or attenuate gradients unpredictably. With 96 layers (the depth of GPT-3) and two sub-layers per layer, that is up to 192 LayerNorm operations on the path, each introducing potential instability.

In Pre-LN, the residual connection provides a gradient highway. The gradient from any layer to any earlier layer has a direct path through the residual connections that bypasses all normalization layers. The gradient through this path is exactly 1 (identity). This is the same principle that made ResNets work — the residual connection ensures that gradients can flow unchanged over arbitrary depth.

The practical impact is dramatic: Post-LN typically needs a careful learning-rate warmup (the original Transformer used 4,000 steps) to avoid divergence, whereas Xiong et al. (2020) found that Pre-LN can train stably with little or no warmup. Post-LN also tends to rely on gradient clipping, while Pre-LN is more robust without it. This stability is what lets Pre-LN models such as GPT-2 scale to deep stacks (48+ layers) without training instability.

| Property | Post-LN | Pre-LN |
| --- | --- | --- |
| Placement | LN(x + SubLayer(x)) | x + SubLayer(LN(x)) |
| Warmup needed? | Yes, 4000+ steps | No (or minimal) |
| Gradient flow | Must pass through all LNs | Clean residual highway |
| Final layer output | Already normalized | Unnormalized (needs final LN) |
| Used by | Original Transformer, BERT | GPT-2, GPT-3, Llama |
| Quality at convergence | Slightly better (when it trains) | Slightly worse (but much easier) |

An interesting finding: Post-LN, when it successfully trains, can achieve slightly better final performance. This has led to hybrid approaches like Sandwich-LN (adding a second LayerNorm after the sub-layer in Pre-LN architecture) and careful initialization schemes that make Post-LN trainable at scale. But for most practitioners, Pre-LN remains the safe default.

The gradient math

Let's see why Pre-LN has better gradient flow. Consider L layers. In Post-LN, the gradient from loss to layer l's input passes through:

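$$\frac{\partial \mathcal{L}}{\partial x_l} = \frac{\partial \mathcal{L}}{\partial x_L} \cdot \prod_{k=l}^{L-1} \frac{\partial\, \mathrm{LN}(x_k + f(x_k))}{\partial x_k}$$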

Each LayerNorm derivative involves a 1/σ factor that can amplify or attenuate. After L layers, the product of L such factors can be very large or very small — gradient explosion or vanishing.

In Pre-LN, the residual connection gives:

$$x_{l+1} = x_l + f(\mathrm{LN}(x_l))$$
$$\frac{\partial x_{l+1}}{\partial x_l} = I + \frac{\partial f(\mathrm{LN}(x_l))}{\partial x_l}$$

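The identity matrix I provides a "gradient highway": multiplying these per-layer Jacobians together over many layers always leaves a bare identity term, so there is a path from the loss to every layer that no LayerNorm rescales, even when the f(LN(...)) terms are small. Post-LN has no such term, because every path runs through every LayerNorm above it. This is the mathematical reason Pre-LN enables training very deep Transformers (96+ layers) with little or no warmup.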
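Unrolling the Pre-LN recursion over several layers shows where the identity term comes from. As a sketch, write $J_k$ for the Jacobian of the sub-layer branch at layer $k$ (a shorthand introduced here for illustration), with later layers placed on the left of the matrix product:

```latex
\frac{\partial x_L}{\partial x_l}
  = \prod_{k=l}^{L-1} \left( I + J_k \right)
  = I \;+\; \sum_{k=l}^{L-1} J_k \;+\; \sum_{l \le k < j \le L-1} J_j J_k \;+\; \cdots,
\qquad
J_k = \frac{\partial f(\mathrm{LN}(x_k))}{\partial x_k}
```

The leading $I$ is the highway: it carries $\partial \mathcal{L} / \partial x_L$ back to $x_l$ unchanged. Every remaining term passes through at least one sub-layer Jacobian, so those terms can shrink or grow, but they cannot remove the identity path.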

Empirical comparison

Xiong et al. (2020) in "On Layer Normalization in the Transformer Architecture" provided a thorough empirical comparison:

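| Metric | Post-LN (6L) | Pre-LN (6L) | Post-LN (12L) | Pre-LN (12L) |
| --- | --- | --- | --- | --- |
| Converges without warmup? | No (diverges) | Yes | No (diverges) | Yes |
| Best BLEU (WMT) | 28.6 | 28.0 | Fails | 29.2 |
| Gradient norm (init) | ~10^3 | ~10^0 | ~10^6 | ~10^0 |
| Training stability | Fragile | Robust | Very fragile | Robust |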

The gradient norm column is telling: at initialization, Post-LN gradients are 1000x larger than Pre-LN gradients for 6 layers, and a million times larger for 12 layers. This exponential growth is exactly the product of LayerNorm Jacobians. Pre-LN keeps gradients near O(1) regardless of depth.

The takeaway is clear: if you're building a Transformer and don't have a specific reason to use Post-LN, use Pre-LN. It trains stably even in very deep stacks, needs little or no warmup, and is the architecture used by the vast majority of production LLMs. Post-LN is only worth the complexity if you've verified it gives meaningfully better results for your specific task and you have the engineering resources to handle the fragile training dynamics.

DeepNorm: making Post-LN work at depth

Microsoft's DeepNorm (Wang et al., 2022) showed that Post-LN can work at extreme depth (up to 1000 layers) with a simple modification: scale the residual connection by a constant α:

$$x_{\text{out}} = \mathrm{LayerNorm}(\alpha \cdot x + \mathrm{SubLayer}(x))$$

With $\alpha = (2N)^{1/4}$, where N is the number of layers, and a matching down-scaling of the sub-layer weight initialization, this keeps the magnitude of the model updates bounded. DeepNorm aims to combine the stability of Pre-LN with the slightly better convergence quality of Post-LN. It was adopted for training GLM-130B.

python
# DeepNorm: scaled Post-LN for very deep models
import torch.nn as nn

class DeepNorm_Block(nn.Module):
    def __init__(self, d_model, n_layers):
        super().__init__()
        self.alpha = (2 * n_layers) ** 0.25   # residual scale, (2N)^(1/4)
        self.norm = nn.LayerNorm(d_model)
        # The sub-layer weights should also have their initialization
        # scaled by beta = (8 * n_layers) ** -0.25 (not shown here)

    def forward(self, x, sublayer_out):
        # Scale the residual by alpha, add the sub-layer output, then normalize
        return self.norm(self.alpha * x + sublayer_out)
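The comment about β can be made concrete. The sketch below computes both constants from the formulas given in this section and applies β as the gain of Xavier initialization on a single linear layer. The helper name `deepnorm_constants`, the use of `nn.init.xavier_normal_`, and the choice of one `nn.Linear` as the sub-layer are assumptions made for illustration; which projections should be scaled in a real model is specified in the DeepNorm paper, not here.

```python
import torch
import torch.nn as nn

def deepnorm_constants(n_layers):
    """Return (alpha, beta) as defined in this section."""
    alpha = (2 * n_layers) ** 0.25    # residual scale
    beta = (8 * n_layers) ** -0.25    # initialization gain for sub-layer weights
    return alpha, beta

alpha, beta = deepnorm_constants(n_layers=8)   # 2 * 8 = 16, so alpha = 16 ** 0.25

torch.manual_seed(0)
ffn = nn.Linear(64, 64)                        # stand-in sub-layer
nn.init.xavier_normal_(ffn.weight, gain=beta)  # smaller initial weights for deeper stacks
nn.init.zeros_(ffn.bias)

norm = nn.LayerNorm(64)
x = torch.randn(2, 10, 64)
out = norm(alpha * x + ffn(x))                 # scaled Post-LN update

print(f"alpha={alpha:.3f}  beta={beta:.3f}  out shape={tuple(out.shape)}")
```

As depth grows, alpha rises and beta falls, so the residual branch dominates the sum more strongly at initialization. That is the intuition behind the stability claim.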

The final LayerNorm

Pre-LN has one important caveat: because the output of the last Transformer block is not normalized (the residual adds unnormalized input), you need a final LayerNorm after the last block. Without it, the output magnitudes can vary significantly, destabilizing the final linear projection to logits. Pre-LN models such as GPT-2, GPT-3, and Llama (which uses RMSNorm for this) all include a final normalization layer.

python
# Pre-LN Transformer (GPT-2 style)
import torch.nn as nn

class PreLN_Transformer(nn.Module):
    def __init__(self, d_model, n_layers, vocab_size):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, d_model)
        # PreLN_Block is defined in the Post-LN vs Pre-LN listing in this chapter
        self.blocks = nn.ModuleList([
            PreLN_Block(d_model) for _ in range(n_layers)
        ])
        self.final_ln = nn.LayerNorm(d_model)  # CRITICAL
        self.head = nn.Linear(d_model, vocab_size)

    def forward(self, x):
        h = self.embed(x)              # token ids -> [batch, tokens, d_model]
        for block in self.blocks:
            h = block(h)               # residual stream is never normalized here
        h = self.final_ln(h)  # normalize before logits
        return self.head(h)
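To see what the final LayerNorm buys, the sketch below tracks the per-token standard deviation of the residual stream in a toy Pre-LN stack. The sub-layers are plain `nn.Linear` stand-ins (an assumption to keep the example short, not real attention or FFN blocks), so the printed numbers are only illustrative. The point is structural: nothing inside a Pre-LN block rescales the running sum, while the final LayerNorm does.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
d_model, n_layers = 64, 12

norms = nn.ModuleList([nn.LayerNorm(d_model) for _ in range(n_layers)])
subs = nn.ModuleList([nn.Linear(d_model, d_model) for _ in range(n_layers)])  # stand-in sub-layers
final_ln = nn.LayerNorm(d_model)

h = torch.randn(2, 10, d_model)              # [batch, tokens, d_model]
with torch.no_grad():
    for i, (ln, f) in enumerate(zip(norms, subs)):
        h = h + f(ln(h))                     # Pre-LN update: the sum itself is never normalized
        if (i + 1) % 4 == 0:
            print(f"after block {i + 1:2d}: mean per-token std = {h.std(-1).mean().item():.2f}")
    out = final_ln(h)                        # brings every token back to zero mean, unit variance

print(f"after final LN: mean per-token var = {out.var(-1, unbiased=False).mean().item():.3f}")
```

Whatever scale the residual stream has drifted to, the logits projection after `final_ln` always sees inputs with a fixed per-token mean and variance.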
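
python
import torch.nn as nn

class _Block(nn.Module):
    # Shared setup for both variants. The Linear defaults are placeholders
    # so the sketch runs; pass real attention / FFN modules in practice.
    def __init__(self, d_model, attn=None, ffn=None):
        super().__init__()
        self.attn = attn if attn is not None else nn.Linear(d_model, d_model)
        self.ffn = ffn if ffn is not None else nn.Linear(d_model, d_model)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

# Post-LN (original Transformer)
class PostLN_Block(_Block):
    def forward(self, x):
        # x → sublayer → add → normalize
        x = self.norm1(x + self.attn(x))      # LN after residual
        x = self.norm2(x + self.ffn(x))       # LN after residual
        return x

# Pre-LN (GPT-2, modern default)
class PreLN_Block(_Block):
    def forward(self, x):
        # x → normalize → sublayer → add
        x = x + self.attn(self.norm1(x))      # LN before sublayer
        x = x + self.ffn(self.norm2(x))       # LN before sublayer
        return x
    # Note: Pre-LN needs a final LayerNorm after the last block

Why does Pre-LN make Transformer training more stable than Post-LN?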

Chapter 7: Normalization Explorer

Now let's bring everything together in a comprehensive interactive simulation. You'll see all four normalization methods applied to the same data, with full control over the input distribution. This is where the differences become visceral — not just theoretical, but visible.

Full Normalization Explorer

A [4 batch × 8 features] activation matrix. Each cell is a neuron activation. The heatmap shows raw values (top) and normalized values (bottom) under each normalization method. Drag sliders to change the input distribution and watch how each method responds differently. Click a method to highlight which cells it groups together.

What to look for: (1) Increase "Scale spread" to make some features much larger than others — watch how LayerNorm equalizes them within each row. (2) Increase "Feature bias" to shift all features positive — watch how LayerNorm centers them but RMSNorm doesn't. (3) Click between methods to see which cells share statistics. (4) Notice how BatchNorm groups cells vertically (across batch) while LayerNorm groups them horizontally (across features).
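The grouping the explorer highlights can be reproduced in a few lines. The sketch below assumes PyTorch and hand-written helpers (`layer_norm`, `batch_norm`, and `rms_norm` are illustrative names, with the learnable gain and bias left out) applied to a [4 batch × 8 features] matrix. It checks two of the observations listed above: LayerNorm's output for an example does not depend on the rest of the batch, and a shared feature bias is removed by LayerNorm but not by RMSNorm.

```python
import torch
import torch.nn.functional as F

def layer_norm(x, eps=1e-5):
    # Statistics along each row (across features)
    return F.layer_norm(x, x.shape[-1:], eps=eps)

def batch_norm(x, eps=1e-5):
    # Statistics down each column (across the batch), training-mode style
    mean = x.mean(dim=0, keepdim=True)
    var = x.var(dim=0, unbiased=False, keepdim=True)
    return (x - mean) / torch.sqrt(var + eps)

def rms_norm(x, eps=1e-6):
    # Divide each row by its root-mean-square, with no centering
    return x / torch.sqrt(x.pow(2).mean(-1, keepdim=True) + eps)

torch.manual_seed(0)
x = torch.randn(4, 8)                        # [4 batch x 8 features]

# (a) Normalize the full batch vs. only the first two examples
ln_row_independent = torch.allclose(layer_norm(x)[:2], layer_norm(x[:2]), atol=1e-6)
bn_row_independent = torch.allclose(batch_norm(x)[:2], batch_norm(x[:2]), atol=1e-6)

# (b) Add the same bias to every feature
shifted = x + 3.0
ln_ignores_shift = torch.allclose(layer_norm(shifted), layer_norm(x), atol=1e-4)
rms_row_means = rms_norm(shifted).mean(-1)

print("LayerNorm independent of batch:", ln_row_independent)
print("BatchNorm independent of batch:", bn_row_independent)
print("LayerNorm removes shared bias: ", ln_ignores_shift)
print("RMSNorm row means after shift: ", rms_row_means)
```

The first and third flags should come out true and the second false. The RMSNorm row means stay well away from zero because nothing subtracts the shift.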

When to use which normalization

| Scenario | Best Choice | Why |
| --- | --- | --- |
| Transformer / LLM | LayerNorm or RMSNorm | Batch-independent, works with autoregressive generation |
| CNN with large batch | BatchNorm | Stable statistics, well-tested, slight accuracy edge |
| CNN with small batch | GroupNorm | Batch-independent, works well for detection/segmentation |
| Style transfer | InstanceNorm | Strips style statistics from content features |
| Modern LLM (Llama-class) | RMSNorm | 10-15% faster than LayerNorm, same quality |
| RNN / LSTM | LayerNorm | Works at any timestep, any batch size |

The bigger picture: why normalization matters

Normalization isn't just a training trick — it fundamentally changes the optimization landscape. Without normalization, the loss surface has sharp ravines where some dimensions have much larger gradients than others. The optimizer oscillates across the ravine walls instead of moving along the ravine floor toward the minimum.

Normalization smooths these ravines. By ensuring that all features have similar magnitudes, it makes the loss surface more isotropic — the gradients point more directly toward the minimum. This allows larger learning rates, faster convergence, and more stable training.

Ba et al.'s contribution was recognizing that this smoothing doesn't need batch statistics. You can get the same benefit by normalizing within each example — and in doing so, free neural networks from the tyranny of the batch. This insight, simple as it seems, was essential for the Transformer revolution.

Practical implementation notes

If you're implementing LayerNorm in a production system, here are the details that matter:

python
# Production-grade LayerNorm with fused kernel hints
import torch
import torch.nn as nn
import torch.nn.functional as F

class LayerNorm(nn.Module):
    def __init__(self, normalized_shape, eps=1e-5):
        super().__init__()
        self.eps = eps
        # Learnable parameters
        self.weight = nn.Parameter(torch.ones(normalized_shape))
        self.bias = nn.Parameter(torch.zeros(normalized_shape))

    def forward(self, x):
        # F.layer_norm dispatches to PyTorch's native kernel, which on GPU
        # fuses mean, var, normalize, scale, and shift instead of launching
        # one op per step; typically well ahead of a naive op-by-op version
        return F.layer_norm(
            x, self.weight.shape, self.weight, self.bias, self.eps
        )

# For RMSNorm (Llama-style):
class RMSNorm(nn.Module):
    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))
        # No bias parameter: one less thing to learn

    def forward(self, x):
        # Root-mean-square over the feature dimension; eps sits inside the sqrt
        rms = torch.sqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
        return x / rms * self.weight
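One detail the RMSNorm listing glosses over is the dtype of the statistics. The eps term guards against division by zero, but in fp16 the squares themselves can overflow, since the largest finite fp16 value is 65504. A common remedy is to compute the statistics in float32 and cast back afterwards. The sketch below illustrates that idea; the function name `rms_norm_fp32` and the test values are assumptions chosen for the example.

```python
import torch

def rms_norm_fp32(x, weight, eps=1e-6):
    """RMSNorm that computes its statistics in float32 and returns x's dtype."""
    x32 = x.float()
    inv_rms = torch.rsqrt(x32.pow(2).mean(-1, keepdim=True) + eps)
    return (x32 * inv_rms * weight.float()).to(x.dtype)

# Activations of 300 are representable in fp16, but 300 ** 2 = 90000 is not
x = torch.full((2, 8), 300.0, dtype=torch.float16)
weight = torch.ones(8, dtype=torch.float16)

squares_fp16 = x * x                     # overflows to inf in fp16
out = rms_norm_fp32(x, weight)           # statistics in float32, result back in fp16

print("fp16 squares overflow:", bool(torch.isinf(squares_fp16.float()).all()))
print("fp32-statistics output:", out[0])
```

The extra casts cost some bandwidth, which is one reason fused kernels that accumulate in float32 internally are attractive.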

| Detail | LayerNorm | RMSNorm |
| --- | --- | --- |
| Parameters | 2D (gain + bias) | D (gain only) |
| Operations | mean, var, normalize, scale, shift | square, mean, sqrt, scale |
| Memory passes (unfused) | 2 (compute stats, then normalize) | 2 (same) |
| FP16 safe? | Yes with eps=1e-5 | Yes with eps=1e-6 |
| Fused kernel? | Yes (PyTorch, Triton) | Yes (custom needed) |

The legacy of this paper: LayerNorm, or its descendant RMSNorm, sits in every Transformer block of the LLMs whose architectures are public (GPT-2, GPT-3, Llama), and closed models such as GPT-4, Claude, and Gemini are widely assumed to do the same. That's two normalization operations per block, times dozens of blocks (96 in GPT-3), times billions of tokens. LayerNorm is arguably the most-executed normalization operation in the history of computing. And it all started from a simple question: "What if we normalized across features instead of across the batch?"

The timeline of normalization innovations

| Year | Technique | Key Idea | Impact |
| --- | --- | --- | --- |
| 2015 | BatchNorm | Normalize per feature across batch | Enabled deep CNNs (ResNet, Inception) |
| 2016 | LayerNorm | Normalize per example across features | Enabled Transformer, GPT, BERT |
| 2016 | InstanceNorm | Normalize per channel per example | Enabled neural style transfer |
| 2016 | Weight Norm | Reparameterize weights (direction + magnitude) | Improved generative models |
| 2018 | GroupNorm | Normalize within channel groups | Enabled small-batch detection |
| 2019 | RMSNorm | Skip mean centering, normalize by RMS only | 10-15% faster, used in Llama |
| 2023 | QK-Norm | Normalize Q and K before attention | Stabilizes very large models |

The progression tells a clear story: the field moved from batch-dependent normalization (BatchNorm) to batch-independent normalization (LayerNorm, GroupNorm) to simplified normalization (RMSNorm). Each step removed unnecessary computation while maintaining the core benefit: keeping activations in a reasonable range to stabilize gradient-based training.

QK-Norm: the latest evolution

The newest normalization technique, used in some very large models, is QK-Norm: applying LayerNorm to the Q and K projections before computing attention scores. Without QK-Norm, as models get very large (100B+ parameters), the dot products in attention can become extremely large, causing numerical instability in fp16/bf16 training. QK-Norm fixes the scale of the Q and K vectors (up to a learned gain), which bounds the attention logits and prevents this.

$$\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left(\frac{\mathrm{LN}(Q)\,\mathrm{LN}(K)^{\top}}{\sqrt{d_k}}\right) V$$

This is a direct descendant of Ba et al.'s original idea: when activations grow too large, normalize them. The same principle, applied at a different point in the computation, solving the same fundamental problem — numerical stability through normalization.
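The effect on the attention logits is easy to check numerically. The sketch below is a minimal single-head version of the formula above, assuming PyTorch; the function name `qk_norm_attention` is illustrative, and the learnable gain that real implementations usually attach to each LayerNorm is omitted. Without that gain, each normalized vector has squared norm at most d_k, so by Cauchy-Schwarz every logit is at most √d_k in magnitude no matter how large Q and K were.

```python
import math
import torch
import torch.nn.functional as F

def qk_norm_attention(q, k, v, eps=1e-5):
    """Single-head attention with LayerNorm (no gain or bias) applied to Q and K."""
    d_k = q.shape[-1]
    q = F.layer_norm(q, (d_k,), eps=eps)
    k = F.layer_norm(k, (d_k,), eps=eps)
    logits = q @ k.transpose(-2, -1) / math.sqrt(d_k)
    return torch.softmax(logits, dim=-1) @ v, logits

torch.manual_seed(0)
d_k = 16
q = 100.0 * torch.randn(2, 5, d_k)   # deliberately oversized activations
k = 100.0 * torch.randn(2, 5, d_k)
v = torch.randn(2, 5, d_k)

raw_logits = q @ k.transpose(-2, -1) / math.sqrt(d_k)
out, normed_logits = qk_norm_attention(q, k, v)

print(f"max |logit| without QK-Norm: {raw_logits.abs().max().item():.1f}")
print(f"max |logit| with QK-Norm:    {normed_logits.abs().max().item():.2f}")
print(f"bound sqrt(d_k):             {math.sqrt(d_k):.2f}")
```

With a learned gain the bound scales accordingly, so the logits can still grow during training, but only as fast as the gain parameters do.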

From a simple question in 2016 — "what if we normalized across features?" — to the universal presence of LayerNorm in every modern AI system: in LLMs, vision Transformers, diffusion models, robotics policies, and speech models. Ba, Kiros, and Hinton's paper is one of those quiet foundations that the entire field stands on.

The paper has been cited over 15,000 times. Its core equation — subtract the mean, divide by the standard deviation, scale and shift — is executed trillions of times per day across the world's data centers. For a 7-page arXiv preprint from 2016, that's an extraordinary legacy. And its descendants — RMSNorm, QK-Norm, DeepNorm — continue to evolve as models grow deeper and wider, each solving the same fundamental problem of activation instability that Ba, Kiros, and Hinton first addressed.

You're building a 70B-parameter LLM that will be deployed for single-request inference (batch size 1). Which normalization should you choose and why?