The weights you start with determine whether your network trains or dies — from variance preservation to the recipes inside every modern LLM.
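Initialize a 20-layer network with weights drawn from N(0, 1), the standard normal. Run a single forward pass and print the activation statistics at each layer. With 512-wide ReLU layers, the standard deviation grows by roughly 16× per layer: around a million by layer 5, around a trillion by layer 10, and by layer 20 the activations are near $10^{24}$, so their variance (about $10^{48}$) no longer fits in float32. Training hasn't even started and your network is already dead.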
This isn't an edge case. This is what happens by default when you don't think about initialization. The very first time data flows through your network — before a single gradient is computed, before a single weight is updated — the initialization alone determines whether the network can learn or not.
Let's see exactly why, by hand.
We'll trace a forward pass through a tiny 3-layer network with 2 neurons per layer. All weights are drawn from N(0, 1), the naive default. No bias, ReLU activation. The input is x = [1.0, 0.5].
Layer 1: W1 = [[0.8, -1.2], [0.5, 1.1]]
Compute h1 = W1 · x = [0.8(1.0) + (-1.2)(0.5), 0.5(1.0) + 1.1(0.5)] = [0.2, 1.05]
Apply ReLU: max(0, [0.2, 1.05]) = [0.2, 1.05]. Both positive, so unchanged.
Layer 2: W2 = [[1.5, -0.3], [0.9, 1.8]]
Compute h2 = W2 · [0.2, 1.05] = [1.5(0.2) + (-0.3)(1.05), 0.9(0.2) + 1.8(1.05)] = [-0.015, 2.07]
Apply ReLU: max(0, [-0.015, 2.07]) = [0, 2.07]. One neuron is already dead.
Layer 3: W3 = [[1.3, 0.7], [-0.4, 1.6]]
Compute h3 = W3 · [0, 2.07] = [1.3(0) + 0.7(2.07), -0.4(0) + 1.6(2.07)] = [1.449, 3.312]
Apply ReLU: [1.449, 3.312].
Input scale was ~1. After just 3 layers, the largest output value is above 3. In a real network with 512 neurons per layer, the effect compounds exponentially with depth. Why? Because each layer multiplies the signal by a matrix whose entries have variance 1, and summing 512 such products makes every output far larger than its inputs.
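The hand trace can be checked in a few lines of NumPy. This is a minimal sketch: the input x = [1.0, 0.5] is an assumption, chosen because it is the vector that produces h1 = [0.2, 1.05] with the W1 given, and forward_trace is a hypothetical helper name.

```python
import numpy as np

# Assumption: x = [1.0, 0.5], the input that yields h1 = [0.2, 1.05] with W1 below.
x = np.array([1.0, 0.5])
weights = [
    np.array([[0.8, -1.2], [0.5, 1.1]]),   # W1
    np.array([[1.5, -0.3], [0.9, 1.8]]),   # W2
    np.array([[1.3, 0.7], [-0.4, 1.6]]),   # W3
]

def forward_trace(x, weights):
    """Return the post-ReLU activations after each bias-free layer."""
    activations = []
    h = x
    for W in weights:
        h = np.maximum(0.0, W @ h)  # linear map, then ReLU
        activations.append(h)
    return activations

trace = forward_trace(x, weights)
for i, h in enumerate(trace, start=1):
    print(f"Layer {i}: {np.round(h, 3)}")
```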
Here is the core mathematical fact. A single linear layer computes y = W · x, where W is a d×d weight matrix and x is a d-dimensional input. If the weights have variance $\sigma_w^2$ and the input has variance $\sigma_x^2$, then the output variance is:

$$\mathrm{Var}(y) = d \cdot \sigma_w^2 \cdot \sigma_x^2$$
With d = 512 neurons and $\sigma_w^2 = 1$ (standard normal), a single layer multiplies the variance by 512. After two layers: $512^2 = 262{,}144$. After 10 layers: $512^{10} \approx 1.2 \times 10^{27}$. By layer 15 the variance itself exceeds float32's largest finite value (about $3.4 \times 10^{38}$).
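| Layers | Variance Factor ($512^n$) | What Happens |
|---|---|---|
| 1 | 512 | Std already ~23× too large |
| 2 | 262,144 | Std ~512× too large; tanh/sigmoid saturate |
| 5 | $3.5 \times 10^{13}$ | Std ~$6 \times 10^{6}$ |
| 10 | $1.2 \times 10^{27}$ | Std ~$3.5 \times 10^{13}$ |
| 20 | $1.5 \times 10^{54}$ | Variance exceeds float32's max (~$3.4 \times 10^{38}$); computing it in float32 gives inf |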
And if we make the weights too small, say $\sigma_w = 0.01$ so that $\sigma_w^2 = 0.0001$, each layer multiplies variance by 512 × 0.0001 = 0.0512. After 10 layers: $0.0512^{10} \approx 10^{-13}$. All activations collapse to zero. The network outputs the same thing regardless of input.
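A small sketch of the arithmetic behind the table and the too-small case, assuming purely linear layers of width 512 as in the derivation: the expected variance multiplier after n layers is $(d \cdot \sigma_w^2)^n$, and FLOAT32_MAX is NumPy's largest finite float32, used to see where the variance stops fitting. variance_factor is an illustrative helper name.

```python
import numpy as np

def variance_factor(width, weight_var, n_layers):
    """Expected variance multiplier after n_layers linear layers (no activation).

    Each layer multiplies the variance by width * weight_var, so n layers
    multiply it by (width * weight_var) ** n_layers.
    """
    return (width * weight_var) ** n_layers

# Largest finite float32 value, about 3.4e38.
FLOAT32_MAX = float(np.finfo(np.float32).max)

for n in (1, 2, 5, 10, 15, 20):
    exploding = variance_factor(512, 1.0, n)   # sigma_w = 1
    vanishing = variance_factor(512, 1e-4, n)  # sigma_w = 0.01
    print(f"n={n:2d}  sigma_w=1: {exploding:.2e}  "
          f"sigma_w=0.01: {vanishing:.2e}  "
          f"variance fits in float32: {exploding <= FLOAT32_MAX}")
```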
The simulation below shows a 10-layer network. Input enters at the top. As it flows through each layer, the activation magnitude is shown as a colored bar. Green means healthy (magnitude near 1), yellow means concerning, red means exploding or vanishing. Drag the slider to control the initial weight standard deviation.
Drag the weight std slider. At std≈1.0, activations explode. At std≈0.01, they vanish. Find the sweet spot where all 10 layers stay green.
Notice the pattern. At std = 1.0 (a plain standard-normal draw, like torch.randn), the bars turn red almost immediately: variance explodes through the layers. At std = 0.01, the bars shrink to nothing: variance vanishes. There's a narrow sweet spot around std ≈ 0.04-0.07 (for 512-dim layers) where all bars stay green.
That sweet spot is not something you find by trial and error. It has an exact mathematical formula. The next chapter derives it.
```python
import torch

torch.manual_seed(42)
x = torch.randn(64, 512)  # batch of 64, 512 features

print("=== N(0, 1) weights: EXPLOSION ===")
for i in range(20):
    W = torch.randn(512, 512)  # std = 1.0
    x = x @ W
    x = torch.relu(x)
    print(f"Layer {i:2d}: mean={x.mean():12.2f} std={x.std():12.2f}")

# Reset and try too-small weights
torch.manual_seed(42)
x = torch.randn(64, 512)
print("\n=== N(0, 0.01) weights: VANISHING ===")
for i in range(20):
    W = torch.randn(512, 512) * 0.01
    x = x @ W
    x = torch.relu(x)
    print(f"Layer {i:2d}: mean={x.mean():12.6f} std={x.std():12.6f}")

# Reset and try the sweet spot
torch.manual_seed(42)
x = torch.randn(64, 512)
print("\n=== N(0, sqrt(2/512)) weights: STABLE ===")
for i in range(20):
    W = torch.randn(512, 512) * (2/512)**0.5
    x = x @ W
    x = torch.relu(x)
    print(f"Layer {i:2d}: mean={x.mean():.4f} std={x.std():.4f}")
```
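Run this yourself. In the first block, the std grows by roughly 16× per layer (about √(512/2), since ReLU halves the variance of a 512-term sum), reaching the order of $10^{24}$ by the last layer. In the second, it shrinks by roughly 6× per layer and is effectively zero within a handful of layers. The third, using √(2/512) as the weight standard deviation, stays stable (std around 0.8) across all 20 layers. That √(2/n) is He initialization, which we'll derive in Chapter 3. But first, we need to understand the underlying principle: variance preservation.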
The explosion/collapse problem from Chapter 0 has a clean mathematical solution. If we can make each layer preserve the variance of its input — variance in equals variance out — then even a 100-layer network will have healthy activations at every layer. The question is: what weight variance achieves this?
The answer turns out to be beautifully simple. And once you see it, the entire zoo of initialization methods (Xavier, He, LeCun) becomes obvious.
Consider a single neuron with $n_{in}$ input connections. It computes:

$$y = \sum_{i=1}^{n_{in}} w_i \, x_i$$
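We need three assumptions (all standard and approximately true in practice):

1. The weights $w_i$ and the inputs $x_i$ are independent of each other.
2. Both have zero mean: $\mathbb{E}[w_i] = \mathbb{E}[x_i] = 0$.
3. The $w_i$ are i.i.d. with variance Var(w), and the $x_i$ are i.i.d. with variance Var(x).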
Under these assumptions, the variance of a product of two independent zero-mean random variables is:

$$\mathrm{Var}(w_i x_i) = \mathrm{Var}(w_i) \, \mathrm{Var}(x_i)$$
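And the variance of a sum of independent random variables is the sum of their variances:

$$\mathrm{Var}(y) = \sum_{i=1}^{n_{in}} \mathrm{Var}(w_i) \, \mathrm{Var}(x_i)$$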
Since all weights share the same variance Var(w) and all inputs share the same variance Var(x), this simplifies to:

$$\mathrm{Var}(y) = n_{in} \cdot \mathrm{Var}(w) \cdot \mathrm{Var}(x)$$
Example 1: fan_in = 512, naive initialization N(0, 1)
Weight variance Var(w) = 1. Input variance Var(x) = 1 (standardized input). So Var(y) = 512 × 1 × 1 = 512.

Output variance is 512× the input. After 10 layers: $512^{10} \approx 1.2 \times 10^{27}$. Catastrophic explosion.
Example 2: fan_in = 512, correct initialization Var(w) = 1/512
Var(y) = 512 × (1/512) × 1 = 1. Output variance equals input variance. After 10 layers: $1^{10} = 1$. After 100 layers: $1^{100} = 1$. Perfectly preserved.
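Example 3: fan_in = 64

Var(w) = 1/64 ≈ 0.0156, so std(w) = 1/8 = 0.125, almost 3× the std of 1/√512 ≈ 0.0442 needed for fan_in = 512.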
Notice the weights are larger for smaller layers. This makes intuitive sense: with fewer inputs, each weight contributes more to the output, so each individual weight can be larger without blowing up the sum.
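To see the 1/fan_in rule hold empirically, the sketch below pushes unit-variance inputs through one random linear layer and measures the output variance. The helper empirical_output_variance and the batch size of 4096 are illustrative choices, not anything from a library.

```python
import numpy as np

rng = np.random.default_rng(0)

def empirical_output_variance(fan_in, weight_var, batch=4096):
    """Estimate Var(y) for one linear layer y = x @ W with unit-variance inputs."""
    x = rng.standard_normal((batch, fan_in))
    W = rng.standard_normal((fan_in, fan_in)) * np.sqrt(weight_var)
    return float(np.var(x @ W))

for fan_in in (64, 512):
    naive = empirical_output_variance(fan_in, 1.0)
    scaled = empirical_output_variance(fan_in, 1.0 / fan_in)
    print(f"fan_in={fan_in:3d}  std(w)={np.sqrt(1 / fan_in):.4f}  "
          f"Var(y) naive={naive:7.1f}  Var(y) with 1/fan_in={scaled:.3f}")
```

The naive column lands near fan_in itself, while the 1/fan_in column lands near 1 for both widths.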
The simulation below shows two panels. Left: variance at each layer for the weight variance you choose (at the naive N(0,1), it grows exponentially). Right: variance for Var(w) = 1/$n_{in}$ (flat line at 1). Use the fan_in slider to see how layer width affects the explosion rate, and the weight variance slider to find the sweet spot yourself.
Left panel uses your chosen weight variance. Right panel uses the theoretically correct 1/fan_in. Adjust fan_in to see how wider layers make the explosion worse.
Play with this. When fan_in = 512 and Var(w) = 1/512 ≈ 0.00195, the left panel shows a flat line at variance = 1 — perfect preservation. Increase Var(w) even slightly, and the left panel curves upward exponentially. Decrease it, and it curves downward toward zero.
The right panel always shows the correct initialization. Notice how the theoretically correct Var(w) changes with fan_in: wider layers need smaller weights.
```python
import numpy as np

np.random.seed(42)
fan_in = 512
n_layers = 10
batch_size = 1000

# --- Naive: Var(w) = 1 ---
print("Naive N(0,1):")
x = np.random.randn(batch_size, fan_in)
for i in range(n_layers):
    W = np.random.randn(fan_in, fan_in)  # Var = 1
    x = x @ W
    print(f"  Layer {i}: Var = {np.var(x):.2e}")

# --- Correct: Var(w) = 1/fan_in ---
print(f"\nCorrect Var(w) = 1/{fan_in} = {1/fan_in:.6f}:")
x = np.random.randn(batch_size, fan_in)
for i in range(n_layers):
    W = np.random.randn(fan_in, fan_in) * np.sqrt(1/fan_in)
    x = x @ W
    print(f"  Layer {i}: Var = {np.var(x):.4f}")
```
Run this. The naive version shows variance multiplying by roughly 512 at every layer: about $5 \times 10^{2}$, $2.6 \times 10^{5}$, $1.3 \times 10^{8}$, and so on, reaching about $10^{27}$ by the last layer. The correct version shows variance hovering near 1.0 at every layer, with small random fluctuations.
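We've arrived at the fundamental principle behind all modern initialization:

$$n_{in} \cdot \mathrm{Var}(w) = 1 \quad\Longleftrightarrow\quad \mathrm{Var}(w) = \frac{1}{n_{in}}$$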
This single equation, $\mathrm{Var}(w) = 1/n_{in}$ (often called LeCun initialization), is the seed from which Xavier, He, and every other initialization method grows. Xavier extends it to account for the backward pass. He extends it to account for ReLU. But the core idea is always: make each layer preserve variance.
Chapter 1 derived Var(w) = 1/fan_in for the forward pass. But gradients flow backward. For gradients to preserve variance during backpropagation, we need a different condition: Var(w) = 1/fan_out. These are different numbers when the layer isn't square.
Xavier Glorot and Yoshua Bengio noticed this problem in 2010. Their solution is elegant: use the average of the two requirements.
During backpropagation, the gradient at layer l flows from layer l+1. For y = Wx, the gradient with respect to the input is:

$$\frac{\partial L}{\partial x} = W^T \frac{\partial L}{\partial y}$$
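This is another matrix multiplication, but now the matrix is $W^T$, which has shape fan_in × fan_out. Each component of the gradient is a sum over fan_out terms, so the "effective fan_in" from the gradient's perspective is fan_out. By the same variance argument from Chapter 1:

$$\mathrm{Var}\!\left(\frac{\partial L}{\partial x}\right) = \text{fan\_out} \cdot \mathrm{Var}(w) \cdot \mathrm{Var}\!\left(\frac{\partial L}{\partial y}\right)$$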
For gradient variance preservation, we need fan_out · Var(w) = 1, so Var(w) = 1/fan_out.
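The backward claim can be checked the same way. This sketch assumes a layer that maps fan_in = 512 to fan_out = 128 and stores W as (fan_out, fan_in), like PyTorch's nn.Linear does; it feeds a unit-variance upstream gradient through $W^T$ and compares the two candidate weight variances. backward_gradient_variance is a hypothetical helper.

```python
import numpy as np

rng = np.random.default_rng(0)

def backward_gradient_variance(fan_in, fan_out, weight_var, batch=4096):
    """Estimate Var(dL/dx) when dL/dy has unit variance and dL/dx = dL/dy @ W."""
    # W maps fan_in -> fan_out, stored as (fan_out, fan_in).
    W = rng.standard_normal((fan_out, fan_in)) * np.sqrt(weight_var)
    grad_y = rng.standard_normal((batch, fan_out))  # upstream gradient, one row per example
    grad_x = grad_y @ W                             # row-vector form of W^T · dL/dy
    return float(np.var(grad_x))

fan_in, fan_out = 512, 128
print("Var(w) = 1/fan_in :", backward_gradient_variance(fan_in, fan_out, 1 / fan_in))
print("Var(w) = 1/fan_out:", backward_gradient_variance(fan_in, fan_out, 1 / fan_out))
```

With 1/fan_in the gradient variance shrinks to about fan_out/fan_in = 0.25 per layer; with 1/fan_out it stays near 1.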
We have two conflicting requirements:
| Direction | Requirement | Preserves |
|---|---|---|
| Forward | Var(w) = 1/fan_in | Activation variance |
| Backward | Var(w) = 1/fan_out | Gradient variance |
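We can satisfy both simultaneously only when fan_in = fan_out (a square weight matrix). For non-square layers, Xavier uses the harmonic mean of the two variances, which amounts to averaging fan_in and fan_out in the denominator:

$$\mathrm{Var}(w) = \frac{2}{\text{fan\_in} + \text{fan\_out}}$$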
This doesn't perfectly satisfy either requirement, but it comes close to both. The forward pass variance is slightly off, and the backward pass variance is slightly off, but as long as fan_in and fan_out are within a small factor of each other, neither explodes nor vanishes.
Xavier initialization can be implemented with either a normal or uniform distribution:
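Xavier Normal:

$$W \sim \mathcal{N}(0, \sigma^2), \qquad \sigma^2 = \frac{2}{\text{fan\_in} + \text{fan\_out}}$$

Xavier Uniform:

$$W \sim U\!\left(-\sqrt{\frac{6}{\text{fan\_in} + \text{fan\_out}}},\ \sqrt{\frac{6}{\text{fan\_in} + \text{fan\_out}}}\right)$$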
Why √6? Because the variance of a uniform distribution U(-a, a) is $a^2/3$. Setting $a^2/3 = 2/(\text{fan\_in} + \text{fan\_out})$ gives $a^2 = 6/(\text{fan\_in} + \text{fan\_out})$, so $a = \sqrt{6/(\text{fan\_in} + \text{fan\_out})}$.
Layer: fan_in = 512, fan_out = 256.

Xavier Uniform bound: a = √(6/(512 + 256)) = √(6/768) = √0.0078125 ≈ 0.0884
So W ~ U(-0.0884, 0.0884). Every weight is between -0.09 and +0.09.
Compare with naive N(0, 1): that's weights with std = 1.0, about 20× Xavier's std of √(2/768) ≈ 0.051 for this layer. It's like setting the volume on a guitar amp to 11 before checking if the mic is plugged in.
Layer: fan_in = fan_out = 512 (square).

Var(w) = 2/(512 + 512) = 2/1024 = 1/512, so std(w) = 1/√512 ≈ 0.0442
When the layer is square, Xavier reduces to exactly the 1/fan_in rule from Chapter 1. The forward and backward requirements agree.
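The Xavier numbers for both example layers, plus how far the compromise lands from each requirement, fit in a few lines of plain Python. The helper names are illustrative; variance_multipliers returns fan_in·Var(w) and fan_out·Var(w), the per-layer factors applied to activation variance and gradient variance respectively.

```python
import math

def xavier_std(fan_in, fan_out):
    """Std of Xavier normal: sqrt(2 / (fan_in + fan_out))."""
    return math.sqrt(2.0 / (fan_in + fan_out))

def xavier_uniform_bound(fan_in, fan_out):
    """Bound a of Xavier uniform U(-a, a), chosen so that a**2 / 3 == 2 / (fan_in + fan_out)."""
    return math.sqrt(6.0 / (fan_in + fan_out))

def variance_multipliers(fan_in, fan_out):
    """Per-layer variance multipliers under Xavier init: (forward, backward)."""
    var_w = 2.0 / (fan_in + fan_out)
    return fan_in * var_w, fan_out * var_w

for fan_in, fan_out in [(512, 256), (512, 512)]:
    fwd, bwd = variance_multipliers(fan_in, fan_out)
    print(f"{fan_in}->{fan_out}: std={xavier_std(fan_in, fan_out):.4f} "
          f"bound={xavier_uniform_bound(fan_in, fan_out):.4f} "
          f"forward x{fwd:.3f} backward x{bwd:.3f}")
```

For the 512 → 256 layer, activation variance grows by about 1.33× and gradient variance shrinks to about 0.67× per layer: off in both directions, but by a modest constant rather than an exponential. For the square layer both multipliers are exactly 1.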
The simulation below shows a 10-layer tanh network. We show activation histograms at layers 1, 5, and 10. With Xavier, all layers have healthy bell-curve distributions. With naive N(0,1), the later layers saturate — all outputs crowd near ±1, where tanh is flat and gradients vanish.
Toggle between Xavier and naive initialization. Watch how the activation histograms change at deep layers. Switch between Normal and Uniform Xavier variants.
The naive initialization produces a devastating pattern: at layer 1, the histogram looks fine. By layer 5, values are pushed toward the extremes. By layer 10, every activation is saturated at -1 or +1. The tanh function is completely flat at these extremes, so the gradient is essentially zero. Learning stops.
With Xavier, the histogram at layer 10 looks nearly identical to layer 1 — a smooth bell curve well within the sensitive region of tanh. Gradients flow freely. Learning works.
```python
import torch
import torch.nn as nn
import numpy as np

# --- Xavier from scratch ---
def xavier_normal(fan_in, fan_out):
    std = np.sqrt(2.0 / (fan_in + fan_out))
    return torch.randn(fan_out, fan_in) * std

def xavier_uniform(fan_in, fan_out):
    a = np.sqrt(6.0 / (fan_in + fan_out))
    return torch.empty(fan_out, fan_in).uniform_(-a, a)

# --- Compare activations through 10 tanh layers ---
torch.manual_seed(42)
fan_in = 512

# Naive
x = torch.randn(64, fan_in)
print("Naive N(0,1) + tanh:")
for i in range(10):
    W = torch.randn(fan_in, fan_in)
    x = torch.tanh(x @ W)
    print(f"  Layer {i}: mean={x.mean():.4f} std={x.std():.4f}")

# Xavier
torch.manual_seed(42)
x = torch.randn(64, fan_in)
print("\nXavier Normal + tanh:")
for i in range(10):
    W = xavier_normal(fan_in, fan_in)
    x = torch.tanh(x @ W)
    print(f"  Layer {i}: mean={x.mean():.4f} std={x.std():.4f}")

# PyTorch built-ins (same distributions as above, fresh random values)
W = torch.empty(512, 512)
nn.init.xavier_uniform_(W)  # in-place, uniform variant
nn.init.xavier_normal_(W)   # in-place, normal variant (overwrites W)
```
With naive initialization, std jumps to nearly 1.0 at the first layer and stays there: the pre-activations are huge, so almost every output is pinned at -1 or +1, exactly where tanh's gradient vanishes. With Xavier, std starts around 0.6 and drifts down gradually over the 10 layers instead of saturating. That slow decline is expected: tanh squashes its input, so each layer loses a little variance. This is why PyTorch's nn.init.calculate_gain('tanh') returns 5/3, a gain meant to counteract the shrinkage.
Xavier assumes the activation function is linear or symmetric. ReLU is neither — it kills everything below zero. If half your neurons output zero, the effective fan_in is halved, and Xavier's variance estimate is too small by a factor of 2. He initialization adds that missing factor.
Published by Kaiming He, Xiangyu Zhang, Shaoqing Ren, and Jian Sun in 2015 (the same team behind ResNets), this initialization made it possible to train much deeper ReLU networks from scratch: their paper shows a 30-layer plain network that converges with He init but stalls with Xavier. Before He init, deep ReLU networks suffered from the slow variance collapse that Xavier couldn't prevent.
Recall from Chapter 1: $\mathrm{Var}(y) = n_{in} \cdot \mathrm{Var}(w) \cdot \mathrm{Var}(x)$. This assumed linear activations. Now add ReLU:
ReLU(z) = max(0, z). If the pre-activation z follows a symmetric distribution around zero (which it does with zero-mean weights and inputs), then on average half the values are negative and get zeroed out. The other half pass through unchanged.
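What does this do to variance? If we zero out half the values and keep the other half, the expected squared value is halved:

$$\mathbb{E}[\mathrm{ReLU}(z)^2] = \tfrac{1}{2} \, \mathbb{E}[z^2]$$

Since E[z] = 0 (zero-mean), $\mathbb{E}[z^2] = \mathrm{Var}(z)$. So:

$$\mathbb{E}[\mathrm{ReLU}(z)^2] = \tfrac{1}{2} \, \mathrm{Var}(z)$$

(Strictly, ReLU outputs are not zero-mean, so what gets halved is the second moment $\mathbb{E}[x^2]$ rather than the variance. That is exactly the quantity the next layer's sum depends on, because $\mathrm{Var}(w\,x) = \mathrm{Var}(w) \, \mathbb{E}[x^2]$ when w has zero mean and is independent of x.)

ReLU halves the variance at every layer. With Xavier's Var(w) = 1/fan_in, the variance after ReLU is:

$$\mathrm{Var}(\text{out}) = \tfrac{1}{2} \cdot n_{in} \cdot \frac{1}{n_{in}} \cdot \mathrm{Var}(x) = \tfrac{1}{2} \, \mathrm{Var}(x)$$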
Each layer halves the variance. After 10 layers: $0.5^{10} \approx 0.001$. After 20 layers: $0.5^{20} \approx 10^{-6}$. The signal slowly bleeds out. Not as dramatic as the N(0,1) explosion, but equally fatal for deep networks.
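The fix is almost comically simple. If ReLU halves the variance, double the weight variance to compensate:

$$\mathrm{Var}(w) = \frac{2}{n_{in}}$$

Check it:

$$\mathrm{Var}(\text{out}) = \tfrac{1}{2} \cdot n_{in} \cdot \frac{2}{n_{in}} \cdot \mathrm{Var}(x) = \mathrm{Var}(x)$$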
The factor of 2 from He exactly cancels the factor of 0.5 from ReLU. Variance is preserved.
Layer with fan_in = 512, ReLU activation.
| Method | Var(w) | Std(w) | After ReLU: Var(out) |
|---|---|---|---|
| Xavier | 1/512 = 0.00195 | 0.0442 | 0.5 × 512 × 0.00195 × 1 = 0.5 |
| He | 2/512 = 0.00391 | 0.0625 | 0.5 × 512 × 0.00391 × 1 = 1.0 |
With Xavier, each layer loses half its variance. With He, variance is perfectly preserved despite ReLU zeroing out half the neurons.
After 10 ReLU layers:
| Method | Cumulative Variance Factor | Signal Strength |
|---|---|---|
| Xavier | $0.5^{10} \approx 0.001$ | 0.1% of original: nearly dead |
| He | $1.0^{10} = 1.0$ | 100% of original: perfectly healthy |
He initialization assumes standard ReLU, where exactly half the values are zeroed. For variants:
| Activation | Fraction Passed | Correction Factor | Var(w) |
|---|---|---|---|
| ReLU | 50% | 2 | 2/fan_in |
| Leaky ReLU (α=0.01) | ~50% at full + 50% at α | 2/(1+α²) | ≈ 2/fan_in |
| Leaky ReLU (α=0.2) | mixed | 2/(1+0.04) ≈ 1.92 | 1.92/fan_in |
| GELU / SiLU | ~57% | ~1.7 | ~1.7/fan_in |
| Linear / tanh | 100% | 1 (use Xavier) | 1/fan_in |
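In practice, He init (factor = 2) works well enough for GELU, SiLU, and Leaky ReLU too. For Leaky ReLU with a small slope, the exact factor is within a few percent of 2. GELU and SiLU, unlike ReLU, are not scale-invariant, so their exact factor depends on the pre-activation scale; the table's values are rough estimates, and 2 is a reasonable default.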
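PyTorch exposes these correction factors through torch.nn.init.calculate_gain, which returns the gain applied to the standard deviation; the table's variance factor is its square. A short check, with the caveat that in recent PyTorch releases calculate_gain only covers a fixed list of nonlinearities (linear and conv variants, sigmoid, tanh, relu, leaky_relu, selu) and has no entry for GELU or SiLU:

```python
import torch.nn as nn

# calculate_gain returns the recommended *std* multiplier (gain);
# the variance correction factor in the table is gain ** 2.
for slope in (0.01, 0.2):
    gain = nn.init.calculate_gain("leaky_relu", slope)
    print(f"Leaky ReLU slope={slope}: gain^2 = {gain ** 2:.4f}")

relu_factor = nn.init.calculate_gain("relu") ** 2
tanh_gain = nn.init.calculate_gain("tanh")
print(f"ReLU: gain^2 = {relu_factor:.4f}, tanh gain = {tanh_gain:.4f}")
```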
The simulation below shows a 10-layer ReLU network. Side-by-side activation histograms at layers 1, 5, and 10. With Xavier, variance dies across layers (histograms narrow progressively). With He, variance is preserved (histograms keep the same spread).
Toggle between Xavier and He initialization. Watch the activation histograms at layers 1, 5, and 10 — Xavier slowly dies, He stays healthy.
The difference is striking at layer 10. Xavier's histogram has collapsed to a narrow spike near zero — the signal has been halved 10 times, leaving only 0.1% of the original variance. He's histogram at layer 10 looks nearly identical to layer 1. The ReLU is still zeroing half the values (visible as a tall bar at 0), but the surviving values maintain their spread.
```python
import torch
import torch.nn as nn
import numpy as np

# --- He init from scratch ---
def he_normal(fan_in, fan_out):
    std = np.sqrt(2.0 / fan_in)
    return torch.randn(fan_out, fan_in) * std

def he_uniform(fan_in, fan_out):
    a = np.sqrt(6.0 / fan_in)  # Var = a²/3 = 2/fan_in
    return torch.empty(fan_out, fan_in).uniform_(-a, a)

# --- Compare Xavier vs He through ReLU layers ---
torch.manual_seed(42)
fan_in = 512

# Xavier + ReLU (slow death)
x = torch.randn(64, fan_in)
print("Xavier + ReLU:")
for i in range(10):
    W = torch.randn(fan_in, fan_in) * np.sqrt(1/fan_in)
    x = torch.relu(x @ W)
    print(f"  Layer {i}: mean={x.mean():.4f} std={x.std():.4f}")

# He + ReLU (stable)
torch.manual_seed(42)
x = torch.randn(64, fan_in)
print("\nHe + ReLU:")
for i in range(10):
    W = he_normal(fan_in, fan_in)
    x = torch.relu(x @ W)
    print(f"  Layer {i}: mean={x.mean():.4f} std={x.std():.4f}")

# PyTorch built-in
W = torch.empty(512, 512)
nn.init.kaiming_normal_(W, mode='fan_in', nonlinearity='relu')
nn.init.kaiming_uniform_(W, mode='fan_in', nonlinearity='relu')
# mode='fan_in' (default) preserves forward pass variance
# mode='fan_out' preserves backward pass variance
# nonlinearity='relu' uses gain=√2; 'leaky_relu' adjusts for slope
```
With Xavier + ReLU, std drops to about 0.03 by layer 10: the variance has been halved ten times. With He + ReLU, std stays near 0.8 across all 10 layers (He keeps the pre-activation variance at 2, and ReLU applied to a variance-2 Gaussian has std √(1 - 1/π) ≈ 0.83). The network lives.
kaiming_normal_ has a mode parameter: 'fan_in' preserves forward variance, 'fan_out' preserves backward variance. For most networks, 'fan_in' is the right choice: it keeps activations healthy, which matters more than perfect gradient flow in practice. Use 'fan_out' only if you're specifically worried about backward pass stability (rare with modern optimizers like Adam).
He initialization is the standard choice for ReLU networks. PyTorch's nn.Linear and nn.Conv2d even call kaiming_uniform_ in their default reset_parameters(), but with a=√5, which works out to U(-1/√fan_in, 1/√fan_in): a variance of 1/(3·fan_in), one sixth of He's 2/fan_in. So nn.Linear(512, 256) does not give you true He init out of the box; if you want it, apply nn.init.kaiming_normal_ or nn.init.kaiming_uniform_ with nonlinearity='relu' yourself.
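A quick way to confirm what nn.Linear actually does out of the box. This sketch relies on PyTorch's default reset_parameters, which in recent releases calls kaiming_uniform_ with a=√5; if that default ever changes, the assertions will catch it.

```python
import math
import torch
import torch.nn as nn

torch.manual_seed(0)
layer = nn.Linear(512, 256)  # default reset_parameters(); fan_in = 512
w = layer.weight.detach()

bound = 1.0 / math.sqrt(512)  # kaiming_uniform_ with a=sqrt(5) works out to this bound
print(f"max |w|          = {w.abs().max().item():.5f} (bound {bound:.5f})")
print(f"empirical Var(w) = {w.var().item():.6f}")
print(f"1 / (3 * fan_in) = {1 / (3 * 512):.6f}")
print(f"He: 2 / fan_in   = {2 / 512:.6f}")
```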
Xavier and He use statistical arguments — on average, variance is preserved. Roll the dice a million times and the mean comes out right. But averages are cold comfort when a single unlucky sample explodes layer 47 and kills your training run.
What if we could guarantee norm preservation exactly? Not on average. Not in expectation. For every single input vector, the output has exactly the same magnitude as the input. No stretching, no shrinking, ever.
Such a guarantee exists, and it comes from linear algebra, not statistics. The tool is the orthogonal matrix.
A square matrix Q is orthogonal if its transpose is its inverse:

$$Q^T Q = Q Q^T = I$$
This single property has a powerful geometric consequence: multiplying any vector by Q preserves its length. Here's why. Take any vector x and compute the squared norm of Qx:

$$\|Qx\|^2 = (Qx)^T (Qx) = x^T Q^T Q \, x = x^T x = \|x\|^2$$
So ||Qx|| = ||x||. The matrix Q rotates (and possibly reflects) the vector, but never stretches or shrinks it. Think of it like spinning a ball on a table — the ball's size doesn't change, only its orientation.
You can't just write down a random orthogonal matrix — random matrices are almost never orthogonal. But you can extract one from any random matrix using QR decomposition.
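Every matrix A can be factored as A = QR, where Q is orthogonal and R is upper-triangular. We want Q; we throw away R. The process is:

1. Draw a random matrix A with entries from N(0, 1).
2. Compute its QR decomposition, A = QR.
3. Keep Q (optionally flipping each column's sign to match the sign of R's diagonal, so the result is uniformly distributed) and discard R.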
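For rectangular weight matrices (fan_in ≠ fan_out), we generate a larger square matrix, decompose it, and take the appropriate slice of Q. If the slice is tall (fan_out > fan_in), its columns are orthonormal and ||Wx|| = ||x|| for every input. If it is wide (fan_out < fan_in), its rows are orthonormal and W acts as a projection: it never increases a vector's norm, and preserves it exactly for inputs in its row space.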
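A sketch of the tall and wide cases with NumPy, assuming a 6 × 3 random matrix: np.linalg.qr in its default reduced mode returns a 6 × 3 Q with orthonormal columns, and its transpose serves as a wide matrix with orthonormal rows.

```python
import numpy as np

rng = np.random.default_rng(0)

# Tall matrix (fan_out=6 > fan_in=3): reduced QR gives Q with orthonormal columns.
Q_tall, _ = np.linalg.qr(rng.standard_normal((6, 3)))
x = rng.standard_normal(3)
print(np.linalg.norm(x), np.linalg.norm(Q_tall @ x))  # equal up to rounding

# Wide matrix (fan_out=3 < fan_in=6): the transpose has orthonormal rows.
W_wide = Q_tall.T
y = rng.standard_normal(6)
print(np.linalg.norm(y), np.linalg.norm(W_wide @ y))  # output norm never exceeds input norm
```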
Let's walk through every step with a concrete 2×2 example.
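Step 1: Random matrix.

A = [[0.3, -0.8], [1.2, 0.5]]

Step 2: QR decomposition. We use Gram-Schmidt. First column of A is a1 = [0.3, 1.2]. Normalize it:

||a1|| = √(0.3² + 1.2²) = √1.53 ≈ 1.237, so q1 = a1 / ||a1|| ≈ [0.243, 0.970]

Second column of A is a2 = [-0.8, 0.5]. Remove the component along q1, then normalize:

a2 · q1 ≈ (-0.8)(0.243) + (0.5)(0.970) ≈ 0.291
v = a2 - 0.291 q1 ≈ [-0.871, 0.218], ||v|| ≈ 0.897, so q2 = v / ||v|| ≈ [-0.970, 0.243]

Step 3: Our orthogonal matrix Q has q1 and q2 as its columns:

Q = [[0.243, -0.970], [0.970, 0.243]]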
Verification: QTQ = I?
| Check | Computation | Result |
|---|---|---|
| Column 1 · Column 1 | 0.243² + 0.970² | 0.059 + 0.941 = 1.000 |
| Column 2 · Column 2 | (-0.970)² + 0.243² | 0.941 + 0.059 = 1.000 |
| Column 1 · Column 2 | 0.243(-0.970) + 0.970(0.243) | -0.236 + 0.236 = 0.000 |
Columns are unit-length and mutually perpendicular. That's orthogonality (and for a square matrix, the rows are then orthonormal too).
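Now test norm preservation. Input x = [1, 0], so ||x|| = 1. Then Qx = [0.243, 0.970] (the first column of Q), and ||Qx|| = √(0.243² + 0.970²) = √(0.059 + 0.941) = 1.000.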
The vector rotated from pointing along the x-axis to pointing mostly along the y-axis, but its length is exactly 1. Try any other input — the norm will always be preserved.
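The same 2 × 2 walkthrough in NumPy, using a small classical Gram-Schmidt loop (gram_schmidt is a hypothetical helper, not a library function). np.linalg.qr would also work, but since it is Householder-based it may return columns with flipped signs.

```python
import numpy as np

A = np.array([[0.3, -0.8],
              [1.2,  0.5]])  # columns a1 = [0.3, 1.2], a2 = [-0.8, 0.5]

def gram_schmidt(A):
    """Classical Gram-Schmidt on the columns of A; returns Q with orthonormal columns."""
    Q = np.zeros_like(A, dtype=float)
    for j in range(A.shape[1]):
        v = A[:, j].astype(float)
        for k in range(j):
            v = v - (Q[:, k] @ A[:, j]) * Q[:, k]  # remove the component along q_k
        Q[:, j] = v / np.linalg.norm(v)            # normalize to unit length
    return Q

Q = gram_schmidt(A)
x = np.array([1.0, 0.0])
print(np.round(Q, 3))
print(np.round(Q.T @ Q, 6))
print(np.round(Q @ x, 3), np.linalg.norm(Q @ x))
```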
Pure orthogonal initialization preserves norms with gain = 1. But sometimes you want controlled scaling. PyTorch's nn.init.orthogonal_ accepts a gain parameter that multiplies Q:

$$W = \text{gain} \cdot Q$$
With gain = √2 (appropriate for ReLU), the output norm is √2 times the input norm. This compensates for ReLU zeroing half the activations — exactly like He initialization, but with the rotational structure of an orthogonal matrix.
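A sketch confirming the gain behavior with PyTorch's nn.init.orthogonal_ and torch.linalg.svdvals, assuming a square 256 × 256 weight: every singular value of gain · Q equals the gain, so every input's norm is scaled by exactly √2.

```python
import math
import torch
import torch.nn as nn

torch.manual_seed(0)
W = torch.empty(256, 256)
nn.init.orthogonal_(W, gain=nn.init.calculate_gain("relu"))  # gain = sqrt(2)

# All singular values equal the gain, so every input is stretched by exactly sqrt(2).
singular_values = torch.linalg.svdvals(W)
x = torch.randn(256)
ratio = (W @ x).norm() / x.norm()
print(singular_values.min().item(), singular_values.max().item(), ratio.item())
```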
Orthogonal initialization shines in three specific scenarios:
| Scenario | Why Orthogonal Works |
|---|---|
| RNNs / LSTMs | The hidden state is multiplied by the recurrence matrix at every timestep. 100 timesteps = 100 multiplications by the same matrix. Orthogonal ensures no explosion or vanishing through time. |
| Very deep networks without normalization | If you can't use BatchNorm or LayerNorm (rare today), orthogonal init is the strongest guarantee of signal preservation. |
| GAN training | GANs are notoriously unstable. Orthogonal init on the discriminator helps prevent mode collapse by ensuring gradients flow reliably. |
The widget below shows what happens to a 2D vector when multiplied by a weight matrix. With random init, the vector stretches or shrinks. With orthogonal init, it rotates but keeps the same length. Adjust the gain slider to see controlled scaling. Below: activation variance through 20 layers.
Watch how orthogonal init preserves vector norms exactly, while random init causes drift.
```python
import numpy as np

def orthogonal_init(shape, gain=1.0):
    """Orthogonal initialization via QR decomposition."""
    # Step 1: Random matrix from N(0, 1)
    rows, cols = shape
    # Make square matrix (larger dimension)
    n = max(rows, cols)
    A = np.random.randn(n, n)

    # Step 2: QR decomposition
    Q, R = np.linalg.qr(A)

    # Fix sign ambiguity (ensures uniform distribution)
    d = np.diag(R)
    Q *= np.sign(d)  # flip columns where R's diagonal is negative

    # Step 3: Slice to desired shape
    Q = Q[:rows, :cols]

    # Step 4: Apply gain
    return gain * Q

# Test: verify norm preservation
W = orthogonal_init((4, 4), gain=1.0)
x = np.random.randn(4)
print(f"||x||  = {np.linalg.norm(x):.4f}")
print(f"||Wx|| = {np.linalg.norm(W @ x):.4f}")
# They will be identical (up to floating-point precision)
```
```python
# PyTorch built-in
import torch.nn as nn

layer = nn.Linear(256, 256)
nn.init.orthogonal_(layer.weight, gain=1.0)

# For ReLU networks, use gain=sqrt(2):
nn.init.orthogonal_(layer.weight, gain=nn.init.calculate_gain('relu'))
```
Transformers have a structural feature that changes the initialization game entirely: residual connections. Every sublayer — attention and feedforward — adds its output to a running residual stream. And addition accumulates.
Picture a river. Each transformer layer is a tributary that pours into it. With 96 layers (GPT-3), that's 192 tributaries (attention + FFN each). Even if each tributary carries a modest flow, 192 of them together create a flood.
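In a standard (GPT-2-style, pre-LayerNorm) transformer block, the computation is:

$$x \leftarrow x + \mathrm{Attn}(\mathrm{LN}(x)), \qquad x \leftarrow x + \mathrm{FFN}(\mathrm{LN}(x))$$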
After one block, x contains the original signal plus two additive terms. After N blocks, x has accumulated 2N additive terms:

$$x_{final} = x_0 + \sum_{i=1}^{2N} f_i(x_{i-1})$$
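If the original $x_0$ has variance 1, and each sublayer output $f_i(x)$ has variance σ², then, treating the terms as roughly uncorrelated, the total variance of $x_{final}$ is approximately:

$$\mathrm{Var}(x_{final}) \approx 1 + 2N\sigma^2$$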
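We want this to stay near 1 no matter how deep the network is, which means the accumulated term 2Nσ² must not grow with N. Solving for σ²:

$$\sigma^2 \propto \frac{1}{2N} \quad\Longrightarrow\quad \sigma \propto \frac{1}{\sqrt{2N}}$$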
Each sublayer's output projection must be initialized with a standard deviation scaled down by 1/√(2N). More layers → smaller initial contributions. This is the core insight behind many modern LLM initialization recipes.
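A toy Monte Carlo model of the residual stream. It assumes, as the derivation does, that every sublayer output is independent of $x_0$ and of the others, and models each one as Gaussian noise with variance σ²; the value σ² = 0.1 is arbitrary. residual_stream_variance is a hypothetical helper.

```python
import numpy as np

rng = np.random.default_rng(0)

def residual_stream_variance(n_blocks, sublayer_var, n_samples=200_000):
    """Var of x0 plus 2N sublayer outputs, each modeled as independent N(0, sublayer_var).

    Simplifying assumption: sublayer outputs are uncorrelated with x0 and with each other.
    """
    x = rng.standard_normal(n_samples)  # residual stream starts with variance 1
    for _ in range(2 * n_blocks):
        x = x + rng.standard_normal(n_samples) * np.sqrt(sublayer_var)
    return float(np.var(x))

sigma2 = 0.1
for n_blocks in (12, 48):
    unscaled = residual_stream_variance(n_blocks, sigma2)
    scaled = residual_stream_variance(n_blocks, sigma2 / (2 * n_blocks))
    print(f"N={n_blocks:2d}  unscaled: {unscaled:6.2f}  scaled by 1/(2N): {scaled:.3f}")
```

Unscaled, the variance tracks 1 + 2Nσ² (about 3.4 at N = 12 and 10.6 at N = 48); with σ² divided by 2N it stays near 1 + σ² regardless of depth.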
GPT-2's initialization code (and GPT-3, which follows the same pattern) uses a base standard deviation of 0.02 for most weights. But the output projections in attention and FFN get a special scaling, where N is the number of layers:

$$\sigma_{\text{residual}} = \frac{0.02}{\sqrt{2N}}$$
This 0.02 base isn't arbitrary. For GPT-2's d_model = 768, the 1/fan_in rule gives:

$$\sigma = \frac{1}{\sqrt{768}} \approx 0.036$$
GPT-2 uses 0.02, which is conservatively smaller than 0.036. This extra conservatism gives a margin of safety — it's better to start too small (slow initial learning) than too large (unstable training).
Let's compute the residual scaling for different GPT-2 variants:
| Model | Layers | d_model | Base σ | Residual σ |
|---|---|---|---|---|
| GPT-2 Small | 12 | 768 | 0.02 | 0.02 / √24 = 0.00408 |
| GPT-2 Medium | 24 | 1024 | 0.02 | 0.02 / √48 = 0.00289 |
| GPT-2 Large | 36 | 1280 | 0.02 | 0.02 / √72 = 0.00236 |
| GPT-2 XL | 48 | 1600 | 0.02 | 0.02 / √96 = 0.00204 |
The pattern is clear: deeper models use progressively smaller initialization for their residual projections. A 48-layer model's residual weights start 2× smaller than a 12-layer model's (√(96/24) = 2), and about 10× smaller than the 0.02 base.
Let's verify the math for GPT-2 XL (48 layers). Does residual scaling actually keep the variance stable?
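Without scaling, a naive count treats each of the 2N = 96 sublayers as adding variance 0.02²:

Var ≈ 1 + 96 × 0.0004 ≈ 1.04

Wait, that only looks harmless because 0.02 is already small. The real problem is what happens inside each sublayer. The output projection multiplies by a d_model × d_model matrix. With standard init (σ = 0.02 per element), the output variance is approximately d_model × 0.02² × input_variance. For d_model = 1600 and unit input variance, that is 1600 × 0.0004 = 0.64 per sublayer, so after 96 sublayers:

Var ≈ 1 + 96 × 0.64 ≈ 62

With scaling, σ = 0.02/√96 ≈ 0.00204, so each sublayer contributes about 1600 × 0.00204² ≈ 0.0067:

Var ≈ 1 + 96 × 0.0067 ≈ 1.64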
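The table and the GPT-2 XL check reduce to a few lines of arithmetic. This sketch uses the same simplified model as the text: each sublayer adds d_model · σ² · input_variance to the residual stream, with input_variance taken as 1. The helper names are illustrative.

```python
import math

BASE_STD = 0.02
D_MODEL_XL = 1600

def residual_std(n_layers, base_std=BASE_STD):
    """GPT-2-style residual projection std: base_std / sqrt(2 * n_layers)."""
    return base_std / math.sqrt(2 * n_layers)

def stream_variance(n_layers, d_model, proj_std, input_var=1.0):
    """Approximate residual-stream variance: 1 + 2N * (d_model * proj_std**2 * input_var)."""
    per_sublayer = d_model * proj_std ** 2 * input_var
    return 1.0 + 2 * n_layers * per_sublayer

for n_layers in (12, 24, 36, 48):
    print(f"{n_layers} layers: residual std = {residual_std(n_layers):.5f}")

print("GPT-2 XL without scaling:", stream_variance(48, D_MODEL_XL, BASE_STD))
print("GPT-2 XL with scaling:   ", stream_variance(48, D_MODEL_XL, residual_std(48)))
```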
The GPT-2 recipe works but has a limitation: when you change model size, you need to re-tune learning rate and other hyperparameters. μP (Maximal Update Parametrization) from Yang & Hu (2021) fixes this by adjusting initialization, learning rate, and activation scaling together in a principled way.
The key insight: μP sets init and LR so that the size of parameter updates relative to parameter magnitudes stays constant across model widths. You tune hyperparameters on a small model and they transfer directly to the large model. Cerebras and some research labs use μP for efficient hyperparameter search.
Adjust the number of layers and toggle residual scaling on/off. Watch how variance grows through the network.
Without scaling, variance grows linearly with depth. With scaling, it stays controlled.
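```python
import torch
import torch.nn as nn

class GPT2Block(nn.Module):
    """Pre-LN transformer block with GPT-2-style initialization (sketch)."""

    def __init__(self, d_model, n_heads, n_layers):
        super().__init__()
        self.ln1 = nn.LayerNorm(d_model)
        self.attn = nn.MultiheadAttention(d_model, n_heads, batch_first=True)
        self.ln2 = nn.LayerNorm(d_model)
        self.ffn = nn.Sequential(
            nn.Linear(d_model, 4 * d_model),
            nn.GELU(),
            nn.Linear(4 * d_model, d_model),  # output projection
        )
        self.n_layers = n_layers
        # Mark the LAST linear layers in attention & FFN as residual projections.
        # (_is_residual_proj is a hypothetical marker attribute, not a PyTorch API.)
        self.attn.out_proj._is_residual_proj = True
        self.ffn[2]._is_residual_proj = True
        # nn.MultiheadAttention stores Q/K/V in one raw parameter, not an nn.Linear.
        nn.init.normal_(self.attn.in_proj_weight, std=0.02)
        nn.init.zeros_(self.attn.in_proj_bias)
        self.apply(self._init_weights)

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            # Standard init: N(0, 0.02)
            std = 0.02
            # Special: scale down residual projections by 1/sqrt(2N)
            if getattr(module, "_is_residual_proj", False):
                std = 0.02 / ((2 * self.n_layers) ** 0.5)
            nn.init.normal_(module.weight, std=std)
            if module.bias is not None:
                nn.init.zeros_(module.bias)

    def forward(self, x):
        # x: (batch, seq, d_model). Causal mask omitted for brevity.
        h = self.ln1(x)
        attn_out, _ = self.attn(h, h, h, need_weights=False)
        x = x + attn_out
        x = x + self.ffn(self.ln2(x))
        return x

# For GPT-2 XL (48 layers, d_model=1600):
#   Standard layers:      std = 0.02
#   Residual projections: std = 0.02 / sqrt(96) = 0.00204
```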
Time to put everything together. This is the showcase simulation — a full deep network where you pick the initialization method, activation function, number of layers, and layer width, then watch what happens to activations and gradients as signals propagate through the network.
You've learned the theory: Xavier for tanh, He for ReLU, orthogonal for exact preservation, residual scaling for transformers. Now break things. Pick the wrong init for the wrong activation. Make the network absurdly deep. Watch activations collapse to zero or explode to infinity. Then fix it with the right initialization and see everything turn green.
Here's what to look for with each combination. Try these in order:
| Init | Activation | What Happens |
|---|---|---|
| Zeros | Any | Everything is dead immediately. All activations are zero, all gradients are zero. The network is a brick wall. |
| N(0,1) | None (Linear) | Activations explode exponentially: variance is multiplied by roughly the layer width at each layer, so values become enormous within a few layers. |
| N(0,1) | Sigmoid | Activations saturate. Outputs pile up near 0 and 1, where sigmoid is flat, so gradients vanish. |
| Xavier | Tanh | Healthy! Green all the way through. This is what Xavier was designed for. |
| Xavier | ReLU | Gradients slowly die. Xavier doesn't account for ReLU's zeroing — variance drops by ~half each layer. By layer 16+, things are dim. |
| He | ReLU | Healthy! Green all the way through. He's factor-of-2 correction fixes Xavier's shortcoming. |
| Orthogonal | None (Linear) | Perfect preservation. Every layer has variance exactly 1.0. The flattest line you'll ever see. |
Pick an init method, activation, depth, and width. Then run the forward and backward pass.
No quiz for the showcase chapter — the simulation is the test. If you can predict what will happen before clicking each combination, you've mastered initialization.
We've covered the general strategies: Xavier for tanh/sigmoid, He for ReLU, orthogonal for exact preservation, residual scaling for transformers. Now the specific tricks. What do you initialize to zero? What do you initialize to a small constant? And what does the actual init code in Llama look like?
Here's a counterintuitive idea: initialize the last layer of each residual block to all zeros.
In a residual block, the computation is:

$$y = x + f(x)$$

If f's last layer has weights W = 0, then f(x) = 0, so:

$$y = x + 0 = x$$
The network starts as the identity function. Every residual block passes its input straight through, unchanged. The network then gradually learns to deviate from identity — each block making small, incremental modifications rather than random large ones.
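A minimal PyTorch sketch of a zero-initialized residual branch. ResidualMLPBlock is a hypothetical module (a two-layer MLP branch, chosen for illustration); its last Linear starts at zero, so the block is exactly the identity at initialization, yet that layer still receives a nonzero gradient and can move away from zero.

```python
import torch
import torch.nn as nn

class ResidualMLPBlock(nn.Module):
    """Hypothetical residual block y = x + f(x) whose last layer starts at zero."""

    def __init__(self, dim, hidden):
        super().__init__()
        self.f = nn.Sequential(nn.Linear(dim, hidden), nn.ReLU(), nn.Linear(hidden, dim))
        nn.init.zeros_(self.f[-1].weight)  # last layer of f starts at zero...
        nn.init.zeros_(self.f[-1].bias)    # ...so f(x) = 0 at initialization

    def forward(self, x):
        return x + self.f(x)

torch.manual_seed(0)
block = ResidualMLPBlock(dim=16, hidden=64)
x = torch.randn(8, 16)
y = block(x)
print(torch.equal(y, x))  # identity at init

# The zeroed layer still gets a gradient, so training can move it away from zero.
y.pow(2).sum().backward()
print(block.f[-1].weight.grad.abs().sum().item() > 0)
```

On this first step the earlier layer's gradient is exactly zero, because it flows back through the zeroed weights; it starts learning once the last layer has moved.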
This technique is used in:
Biases are simpler than weights. The standard rule: initialize all biases to zero. Here's why this works:
A bias adds a constant to every neuron's pre-activation: z = Wx + b. If the weights are properly initialized (Xavier, He, etc.), the pre-activations already have zero mean. A zero bias preserves this property. A nonzero bias would shift all neurons in the same direction, which the network would need to unlearn.
There's one historical exception. With ReLU activation, some practitioners initialized biases to a small positive value like 0.01:

$$b = 0.01$$
The argument: if a neuron's pre-activation is slightly negative at init, ReLU kills it (output = 0). A small positive bias shifts the distribution so more neurons fire. In practice, this made little difference when combined with proper He initialization, and modern networks almost always use zero biases (or none at all).
One more case: many modern LLMs have no biases at all. Llama, PaLM, and other recent architectures remove biases entirely from linear layers. The argument: biases add parameters but contribute little when combined with layer normalization, which re-centers activations anyway.
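Let's look at what modern LLMs actually do. Here's the initialization recipe for three major architectures. (The Llama papers don't spell out an initialization scheme, so treat the Llama column as one reasonable recipe rather than Meta's official code; Hugging Face's Llama implementation, for example, draws Linear and Embedding weights from N(0, 0.02) by default.)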
| Component | GPT-2 | Llama | BERT |
|---|---|---|---|
| Embeddings | N(0, 0.02) | N(0, 1/√d) | N(0, 0.02) |
| Linear layers | N(0, 0.02) | He normal | N(0, 0.02) |
| Normalization γ | ones | ones (RMSNorm) | ones (LayerNorm) |
| Normalization β | zeros | N/A (RMSNorm has no β) | zeros |
| Output projections | N(0, 0.02/√2N) | scaled by 1/√2N | N(0, 0.02) |
| Biases | zeros | none (removed) | zeros |
The pattern across all three: (1) most weights use a small normal distribution, (2) normalization layers start as identity (γ=1, β=0), and (3) GPT-2 and Llama additionally scale down output projections for deep models.
Llama 2 7B has d_model = 4096 and 32 layers. Let's compute every init value:
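Embeddings: σ = 1/√4096 = 1/64 ≈ 0.0156

Linear layers (Q, K, V, gate, up projections):

He normal: σ = √(2/fan_in). For a 4096 → 4096 layer: σ = √(2/4096) ≈ 0.0221. The gate and up projections map 4096 → 11008, but their fan_in is still 4096, so they get the same σ.

Output projections (attn_out, ffn_down): He normal divided by √(2N) = √64 = 8. For attn_out (fan_in = 4096): σ ≈ 0.0221 / 8 ≈ 0.00276. For ffn_down (fan_in = 11008): σ = √(2/11008) / 8 ≈ 0.0017.

RMSNorm weights: all ones.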
That's the entire init. No biases (Llama removes them). No β (RMSNorm doesn't have one). Clean and minimal.
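The same arithmetic fits in a few lines of plain Python. The sketch encodes this section's recipe and assumes Llama 2 7B's FFN hidden size of 11008. That hidden size is why `ffn_down` comes out smaller than `attn_out`. The dictionary keys are just labels.

```python
import math

# Llama 2 7B shapes: model width, FFN hidden size, number of transformer layers
d_model, d_ffn, n_layers = 4096, 11008, 32


def he_std(fan_in):
    """He normal standard deviation: sqrt(2 / fan_in)."""
    return math.sqrt(2.0 / fan_in)


stds = {
    "embedding": 1.0 / math.sqrt(d_model),
    "q/k/v/gate/up": he_std(d_model),                      # all have fan_in = d_model
    "attn_out": he_std(d_model) / math.sqrt(2 * n_layers),
    "ffn_down": he_std(d_ffn) / math.sqrt(2 * n_layers),   # fan_in = d_ffn
}
for name, std in stds.items():
    print(f"{name:>14}: {std:.5f}")
```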
Let's verify the symmetry problem with a quick hand calculation. Take a 2-layer network with 2 neurons per layer, all weights initialized to 0.5 and all biases 0:
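Forward pass with input [1, 1]:
Layer 1: each neuron computes 0.5·1 + 0.5·1 = 1.0, giving [1.0, 1.0].
Layer 2: each neuron computes 0.5·1.0 + 0.5·1.0 = 1.0, giving [1.0, 1.0].
Both neurons produce identical output: 1.0.
Backward pass:
Both hidden neurons have the same incoming weights, the same activation, and the same outgoing weights, so the chain rule hands each of them the same error signal. Their gradient rows are therefore equal, ∂L/∂W1[0, :] = ∂L/∂W1[1, :], and one SGD step subtracts the same amount from both.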
After the update, the weights are still identical. Repeat forever — the two neurons are permanently locked in sync. You paid for two neurons but only got one.
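The lock-step behavior can also be checked numerically. This pure-Python sketch runs 100 SGD steps on the 2-2 network above. It assumes a ReLU hidden layer, a squared-error loss, a learning rate of 0.1, and a made-up target of [1, 0]; all of these are hypothetical choices. Hidden neuron j owns row j of `W1` and column j of `W2`, and both stay identical across the two neurons after every step.

```python
def relu(v):
    return [max(0.0, z) for z in v]


def matvec(W, v):
    return [sum(w * z for w, z in zip(row, v)) for row in W]


def train_step(W1, W2, x, target, lr=0.1):
    """One SGD step on L = 0.5 * ||W2 relu(W1 x) - target||^2; returns updated (W1, W2)."""
    h = matvec(W1, x)
    a = relu(h)
    y = matvec(W2, a)
    dy = [yi - ti for yi, ti in zip(y, target)]
    # dL/dW2[i][j] = dy[i] * a[j]
    dW2 = [[dyi * aj for aj in a] for dyi in dy]
    # Error reaching hidden unit j: sum_i W2[i][j] * dy[i], gated by the ReLU
    da = [sum(W2[i][j] * dy[i] for i in range(len(dy))) for j in range(len(a))]
    dh = [daj if hj > 0 else 0.0 for daj, hj in zip(da, h)]
    dW1 = [[dhj * xk for xk in x] for dhj in dh]
    W1 = [[w - lr * g for w, g in zip(wr, gr)] for wr, gr in zip(W1, dW1)]
    W2 = [[w - lr * g for w, g in zip(wr, gr)] for wr, gr in zip(W2, dW2)]
    return W1, W2


W1 = [[0.5, 0.5], [0.5, 0.5]]
W2 = [[0.5, 0.5], [0.5, 0.5]]
x, target = [1.0, 1.0], [1.0, 0.0]  # hypothetical training example
for _ in range(100):
    W1, W2 = train_step(W1, W2, x, target)

print("hidden incoming weights:", W1)
print("hidden outgoing weights:", [[row[j] for row in W2] for j in range(2)])
```

The output neurons do drift apart because their targets differ. The hidden layer is the one that never breaks symmetry.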
```python
import math

import torch
import torch.nn as nn


def init_llama(model: nn.Module, d_model: int, n_layers: int) -> None:
    """Initialize a Llama-style model with the recipe from the table above."""
    for name, param in model.named_parameters():
        if param.dim() < 2:
            continue  # skip 1D params (norm weights); Llama has no biases
        if 'embed' in name:
            # Embeddings: N(0, 1/sqrt(d_model))
            nn.init.normal_(param, std=1.0 / math.sqrt(d_model))
        elif any(key in name for key in ('out_proj', 'o_proj', 'down_proj')):
            # Residual output projections: He normal / sqrt(2 * n_layers).
            # 'o_proj' is the Hugging Face Llama name for the attention output.
            fan_in = param.shape[1]  # nn.Linear weights are (out_features, in_features)
            std = math.sqrt(2.0 / fan_in) / math.sqrt(2 * n_layers)
            nn.init.normal_(param, std=std)
        else:
            # All other linear layers: He normal, std = sqrt(2 / fan_in)
            nn.init.kaiming_normal_(param, mode='fan_in', nonlinearity='relu')
    # RMSNorm weights: all ones (also the PyTorch default)
    for name, param in model.named_parameters():
        if 'norm' in name and param.dim() == 1:
            nn.init.ones_(param)


# Llama 2 7B: d_model=4096, n_layers=32, FFN hidden size 11008
# Embedding std   = 1/sqrt(4096)              ≈ 0.0156
# Linear std      = sqrt(2/4096)              ≈ 0.0221
# attn_out std    = 0.0221 / sqrt(64)         ≈ 0.00276
# ffn_down std    = sqrt(2/11008) / sqrt(64)  ≈ 0.00168
# RMSNorm weights = [1, 1, 1, ..., 1]
# Biases          = none (Llama has no biases)
```
Let's race them all.
You've learned five initialization strategies: Zero, Random, Xavier, He/Kaiming, and Orthogonal. Each has strengths and weaknesses — but reading about failure modes is one thing. Watching them happen in real time is another.
This simulation trains five identical networks on the same task, differing only in initialization. Drag the sliders to create the conditions where each method fails. You'll discover that He + ReLU isn't always king — context matters.
Experiment 1: Depth = 32, ReLU. Xavier's loss flatlines or oscillates wildly — its variance halves every layer through ReLU (the ½ factor). He stays stable because it compensates with 2/fan_in. This is the textbook case.
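The ½ factor is easy to reproduce outside the simulation. This NumPy sketch pushes random inputs through 32 square Linear+ReLU layers. It uses width 256 and no biases; the sizes and helper names are hypothetical. It reports the mean squared activation at the output for each scheme.

```python
import numpy as np


def xavier_std(fan_in, fan_out):
    return np.sqrt(2.0 / (fan_in + fan_out))


def he_std(fan_in, fan_out):
    return np.sqrt(2.0 / fan_in)  # fan_out unused; kept for a matching signature


def output_second_moment(std_fn, depth=32, width=256, n_samples=256, seed=0):
    """Mean of x**2 after `depth` Linear+ReLU layers, starting from unit-variance inputs."""
    rng = np.random.default_rng(seed)
    x = rng.standard_normal((n_samples, width))
    for _ in range(depth):
        W = rng.normal(0.0, std_fn(width, width), size=(width, width))
        x = np.maximum(0.0, x @ W)
    return float(np.mean(x ** 2))


xavier_out = output_second_moment(xavier_std)
he_out = output_second_moment(he_std)
print(f"Xavier: {xavier_out:.3e}   He: {he_out:.3f}")
```

For square layers, Xavier's variance is 1/fan_in, so each ReLU layer multiplies the second moment by roughly ½. Its result should land in the neighborhood of 0.5³² ≈ 2×10⁻¹⁰, while He stays on the order of 1. Passing `depth=4` shrinks Xavier's signal only to about 1/16, which is the shallow regime of Experiment 4.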
Experiment 2: Tanh activation. Switch to tanh. Now Xavier is the winner — it was designed for symmetric activations. He overshoots because its larger variance pushes activations into tanh's saturation regions.
Experiment 3: Depth = 64. At extreme depth, even He struggles. Orthogonal initialization keeps gradients flowing because each orthogonal weight matrix preserves vector norms exactly (no expansion, no shrinkage) in the linear part of every layer. This is why Orthogonal is popular in RNNs and very deep networks.
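A minimal NumPy sketch of the norm-preservation property is shown below. It builds orthogonal matrices from the QR decomposition of a Gaussian matrix, with a sign fix on the columns; PyTorch's `nn.init.orthogonal_` follows a similar QR-based approach. It then chains 64 of them with no activation in between. The helper `orthogonal` is hypothetical. Nonlinearities such as ReLU still change norms, so "exact" applies to the linear maps only.

```python
import numpy as np


def orthogonal(n, gain=1.0, rng=None):
    """Random n x n orthogonal matrix (times `gain`) via QR of a Gaussian matrix."""
    if rng is None:
        rng = np.random.default_rng()
    A = rng.standard_normal((n, n))
    Q, R = np.linalg.qr(A)
    # Flip column signs so diag(R) is positive; this makes Q uniformly distributed
    # over orthogonal matrices instead of biased by the QR sign convention.
    Q = Q * np.sign(np.diag(R))
    return gain * Q


rng = np.random.default_rng(0)
n, depth = 256, 64
x = rng.standard_normal(n)
h = x
for _ in range(depth):
    h = orthogonal(n, rng=rng) @ h  # purely linear layers, no activation

ratio = np.linalg.norm(h) / np.linalg.norm(x)
print(f"norm ratio after {depth} layers: {ratio:.12f}")
```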
Experiment 4: Shallow (depth = 4), wide (fan_in = 1024). All methods converge to similar performance. Initialization matters less for shallow networks because a mis-scaled variance has only a few layers to compound over: Xavier's ½ factor per ReLU layer shrinks the signal by 2⁴ = 16× at depth 4, versus 2³² at depth 32. This is why you rarely hear about initialization for 3-layer MLPs.
| Method | Fails when... | Symptom | Fix |
|---|---|---|---|
| Zero | Always | All neurons identical forever | Use any non-zero init |
| Random N(0,0.01) | Depth > 8 | Activations vanish exponentially | Scale with fan_in |
| Xavier | ReLU + depth > 16 | Variance halves per layer | Switch to He/Kaiming |
| He/Kaiming | Tanh/sigmoid activations | Saturation, slow convergence | Switch to Xavier |
| Orthogonal | Never truly "fails" | Same as Xavier/He in shallow nets | Use Xavier/He for simplicity |
You now understand the complete initialization toolkit — from the symmetry-breaking problem to the recipes used in every modern LLM. This chapter is your practical reference. No new concepts. Just the formulas, the decision guide, and the connections to where you go next.
| Method | Distribution | Scale | Best for |
|---|---|---|---|
| Xavier Normal | N(0, σ²) | σ² = 2 / (fan_in + fan_out) | Tanh, sigmoid, linear |
| Xavier Uniform | U(-a, a) | a = √(6 / (fan_in + fan_out)) | Same as Xavier Normal |
| He Normal | N(0, σ²) | σ² = 2 / fan_in | ReLU, Leaky ReLU |
| He Uniform | U(-a, a) | a = √(6 / fan_in) | Same as He Normal |
| Orthogonal | QR decomposition × gain | Preserves norms exactly | Very deep nets, RNNs |
| Truncated Normal | N(0, σ²), \|x\| < 2σ | Varies | Transformers (avoids outliers) |
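The uniform bounds in the table are chosen so that the uniform distribution hits the same variance as its normal counterpart, because Var[U(−a, a)] = a²/3. This NumPy sketch checks that identity exactly and against samples. The layer sizes 512 → 256 are arbitrary.

```python
import numpy as np


def xavier_uniform_bound(fan_in, fan_out):
    return np.sqrt(6.0 / (fan_in + fan_out))


def he_uniform_bound(fan_in):
    return np.sqrt(6.0 / fan_in)


fan_in, fan_out = 512, 256
rng = np.random.default_rng(0)
for name, a, target in [
    ("Xavier", xavier_uniform_bound(fan_in, fan_out), 2.0 / (fan_in + fan_out)),
    ("He", he_uniform_bound(fan_in), 2.0 / fan_in),
]:
    w = rng.uniform(-a, a, size=(fan_out, fan_in))
    # a^2 / 3 is the exact variance of U(-a, a); w.var() is the sampled estimate
    print(f"{name}: a^2/3 = {a**2 / 3:.6f}  target = {target:.6f}  sampled = {w.var():.6f}")
```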
Follow the path that matches your situation:
| Symbol | Meaning | Typical values |
|---|---|---|
| fan_in | Number of input connections to a neuron | 64–12288 |
| fan_out | Number of output connections from a neuron | 64–12288 |
| σ² | Variance of the weight distribution | ~0.001–0.1 |
| gain | Scaling factor for the activation function | 1.0 (linear), √2 (ReLU) |
| d_model | Hidden dimension in transformers | 256–8192 |
| N | Number of transformer layers | 6–128 |
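```python
import math

import torch.nn as nn

layer = nn.Linear(512, 512)  # example layer (sizes are illustrative)
n_layers = 12                # example depth for residual scaling

# Each call below overwrites layer.weight; pick the one that fits your model.

# Xavier for tanh/sigmoid
nn.init.xavier_normal_(layer.weight)
nn.init.xavier_uniform_(layer.weight)

# He/Kaiming for ReLU
nn.init.kaiming_normal_(layer.weight, mode='fan_in', nonlinearity='relu')
nn.init.kaiming_uniform_(layer.weight, mode='fan_in', nonlinearity='relu')

# Orthogonal for deep nets / RNNs
nn.init.orthogonal_(layer.weight, gain=1.0)

# Truncated normal for transformer embeddings. a and b are absolute cutoffs,
# so pass ±2σ explicitly; the defaults (±2.0) would not truncate at std=0.02.
nn.init.trunc_normal_(layer.weight, std=0.02, a=-0.04, b=0.04)

# LLM residual scaling (GPT-2 style)
nn.init.normal_(layer.weight, std=0.02 / math.sqrt(2 * n_layers))
```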
| If you want to learn... | Go to |
|---|---|
| How normalization keeps activations stable during training | Normalization |
| How optimizers adapt learning rates per-parameter | Optimizers |
| How loss functions shape the gradient landscape | Loss Functions |
| How attention mechanisms use scaled init | Transformer |
| How residual connections interact with init | Universal Architecture |