The shortcut that made deep learning deep — from ResNet's identity highway to the Pre-LN blocks inside every modern transformer.
A 20-layer network gets 7% training error. You add 36 more layers — now it's 56 layers, strictly more powerful. Training error: 10%. Worse. Not worse on test data (that would be overfitting). Worse on training data. The deeper network can't even match what the shallower one learned.
This paradox motivated the most important architectural invention in deep learning. It was observed by Kaiming He and colleagues in 2015, and it violated everything researchers expected about network depth. More layers means more parameters, which means strictly more expressive power. A 56-layer network includes every function a 20-layer network can represent. So how can it perform worse?
The answer is not overfitting. Overfitting means low training error but high test error — the model memorized the data but can't generalize. Here, training error itself is higher. The deeper network is failing to learn, period. Not failing to generalize — failing to optimize.
Here's the argument for why a 56-layer network should always be at least as good as a 20-layer one. Take the trained 20-layer network. Copy its weights into the first 20 layers of the 56-layer network. Now set each of the remaining 36 layers to compute the identity function — just pass the input through unchanged. The 56-layer network now computes the exact same function as the 20-layer one.
Its training error would be identical: 7%. From there, the optimizer can only improve — it has 36 extra layers of capacity to work with. So the 56-layer network should get at most 7% training error, probably lower.
But it doesn't. It gets 10%. The optimizer cannot find the identity-copying solution. Let's see exactly why.
Goal: A single layer should compute y = x (identity). The layer computes y = ReLU(Wx + b). We need ReLU(Wx + b) = x for all inputs x.
Step 1 — What W and b must be: For positive inputs, ReLU is the identity, so we need Wx + b = x. That means W = I (identity matrix) and b = 0.
Step 2 — Check with input x = [0.5, -0.3]:
Problem! The output is [0.5, 0.0], not [0.5, -0.3]. ReLU kills the negative component. Even with the perfect weight matrix W = I, a ReLU layer cannot represent the identity for negative inputs. The nonlinearity destroys information.
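The check in Step 2 can be reproduced in a few lines. This is a minimal NumPy sketch; the helper name `relu_layer` is illustrative and not part of the original text.

```python
import numpy as np

def relu_layer(x, W, b):
    """One plain layer: y = ReLU(Wx + b)."""
    return np.maximum(W @ x + b, 0.0)

W = np.eye(2)                 # the "ideal" identity weights from Step 1
b = np.zeros(2)
x = np.array([0.5, -0.3])

y = relu_layer(x, W, b)
print(y)                      # the negative component is clipped to zero
```

Even with W = I and b = 0, the second component comes back as 0.0 rather than -0.3, which is the information loss described above.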
Step 3 — The optimizer's dilemma. In a real network, the optimizer starts with random small weights (not W = I). It would need to discover that W should be close to the identity matrix through gradient descent. But there's nothing in the loss landscape that says "this layer should do nothing." Every layer is trying to learn a useful transformation. The gradient signal says "change your weights to reduce the loss" — not "set your weights to the identity."
Let's see this play out in a training simulation. Watch what happens as you increase depth in a plain network.
Training curves for plain networks. As depth increases, training loss gets worse, not better. Toggle "Add Residuals" to see the fix.
Notice the pattern. At 8 layers, the plain network trains fine — the loss curve drops steadily. At 20 layers, it's a bit slower but reaches a reasonable minimum. But push past 40 layers and training loss plateaus much higher. The network has more capacity but achieves worse results.
Now toggle "Add Residuals." Suddenly, 56 layers trains better than 20. Deeper is better again. That's the magic of residual connections — but we'll get to the how in Chapter 1.
He et al. (2015) trained plain networks on CIFAR-10 and ImageNet. The results were stark:
| Architecture | Depth | Training Error | Test Error |
|---|---|---|---|
| Plain Network | 20 | 4.67% | 8.82% |
| Plain Network | 56 | 6.97% | 10.56% |
| ResNet | 20 | 4.28% | 8.75% |
| ResNet | 56 | 2.91% | 6.97% |
The 56-layer plain network has higher training error than the 20-layer one — 6.97% vs 4.67%. This is degradation. But the 56-layer ResNet beats the 20-layer ResNet convincingly. Residual connections made depth useful again.
```python
# Demonstrate degradation: deeper plain net → worse training loss
import torch
import torch.nn as nn

class PlainNet(nn.Module):
    def __init__(self, depth, hidden=64):
        super().__init__()
        layers = [nn.Linear(784, hidden), nn.ReLU()]
        for _ in range(depth - 2):
            layers += [nn.Linear(hidden, hidden), nn.BatchNorm1d(hidden), nn.ReLU()]
        layers.append(nn.Linear(hidden, 10))
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        return self.net(x.view(x.size(0), -1))

# Train both and compare:
shallow = PlainNet(depth=20)  # converges to ~4-5% train error
deep = PlainNet(depth=56)     # plateaus at ~7-8% train error!
# The deeper network is WORSE at fitting the training data.
# This is not overfitting; it's an optimization failure.
```
What if, instead of asking each layer to learn the complete transformation, we asked it to learn just the change? The input passes through unchanged via a shortcut, and the layer only needs to learn what to add. If the optimal addition is nothing, the layer just outputs zeros — trivially easy.
That's the core idea of residual learning. Instead of training a block to compute the desired output H(x) directly, we train it to compute the residual F(x) = H(x) - x. The actual output is then:
y = F(x) + x
The "+x" is the skip connection (or shortcut connection). It carries the input around the block, unchanged, and adds it to whatever the block produces. This tiny change — adding one wire that bypasses the block — is the most important architectural innovation in deep learning history.
Remember the degradation problem: extra layers in a plain network can't learn the identity function H(x) = x. They would need W = I and b = 0, which is hard to reach by gradient descent.
With a residual connection, the story changes completely. If the optimal function is identity (H(x) = x), the block needs to learn:
F(x) = H(x) - x = x - x = 0
The block needs to output all zeros. That's trivially easy — weights initialized near zero already do this! The block starts as an approximate identity and only learns a non-zero residual when doing so helps. Extra layers default to "transparent" instead of fighting to learn identity.
Let's trace data through both architectures with concrete numbers.
Plain layer: y = ReLU(Wx + b)
The original signal [0.5, -0.3] is gone. The output [0.11, 0.12] bears little resemblance to the input. The layer completely replaced the signal with its own small, noisy transformation.
Residual layer: y = ReLU(Wx + b) + x
The skip connection preserved the original signal! The layer's small contribution was added, not replaced. The output [0.61, -0.18] is clearly related to the input [0.5, -0.3] — it's the input plus a small adjustment.
The skip connection helps gradients too. Consider the gradient of the output with respect to the input:
Plain layer: y = F(x), so ∂y/∂x = ∂F/∂x
Residual layer: y = F(x) + x, so ∂y/∂x = ∂F/∂x + I
That +I is the identity matrix. It guarantees that the Jacobian always contains an identity term: even if ∂F/∂x is zero (the block's gradient vanishes), the gradient through the skip connection is exactly I, so the incoming gradient passes back unchanged. The gradient always has a highway home.
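The identity term can be checked numerically with autograd. The sketch below is illustrative only: it uses a single linear layer as the branch F and forces its weights to zero so that ∂F/∂x vanishes completely, which is the worst case discussed above.

```python
import torch
import torch.nn as nn
from torch.autograd.functional import jacobian

torch.manual_seed(0)
dim = 3

# Branch F with all-zero weights, so dF/dx is exactly zero.
branch = nn.Linear(dim, dim)
nn.init.zeros_(branch.weight)
nn.init.zeros_(branch.bias)

def plain(x):
    return branch(x)          # y = F(x)

def residual(x):
    return branch(x) + x      # y = F(x) + x

x = torch.randn(dim)
J_plain = jacobian(plain, x)      # all zeros: no gradient reaches x
J_res = jacobian(residual, x)     # identity: the skip path carries the gradient
print(J_plain)
print(J_res)
```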
Left: plain network (signal gets transformed at each layer). Right: residual network (signal splits — one path skips, one path transforms, then they merge). Adjust layer strength to control how much each block contributes.
At layer strength 0, the block contributes nothing — F(x) = 0. The residual output is just x. The block is perfectly transparent. At strength 1, both paths contribute equally. Notice how the residual network always preserves the input signal — it's always visible in the output.
```python
import torch
import torch.nn as nn

# Plain block: input goes through, original signal lost
class PlainBlock(nn.Module):
    def __init__(self, channels):
        super().__init__()
        self.conv1 = nn.Conv2d(channels, channels, 3, padding=1)
        self.bn1 = nn.BatchNorm2d(channels)
        self.conv2 = nn.Conv2d(channels, channels, 3, padding=1)
        self.bn2 = nn.BatchNorm2d(channels)

    def forward(self, x):
        out = torch.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        return torch.relu(out)  # no skip!

# Residual block: input preserved via skip connection
class ResidualBlock(nn.Module):
    def __init__(self, channels):
        super().__init__()
        self.conv1 = nn.Conv2d(channels, channels, 3, padding=1)
        self.bn1 = nn.BatchNorm2d(channels)
        self.conv2 = nn.Conv2d(channels, channels, 3, padding=1)
        self.bn2 = nn.BatchNorm2d(channels)

    def forward(self, x):
        out = torch.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        return torch.relu(out + x)  # "+ x" is the ENTIRE skip connection

# That's it. One line ("+ x") is the difference.
# Test: with zero-initialized conv2, output ≈ ReLU(x) ≈ identity for x > 0
```
In a plain network, the gradient is a product: layer 1's gradient = gL × gL-1 × ... × g1. One small factor kills the whole product. In a residual network, the gradient is a sum of many paths. Even if most paths have tiny gradients, the direct skip-connection path contributes a gradient of exactly 1. The gradient can't vanish.
This isn't hand-waving. Let's derive exactly what happens with concrete math.
Consider a chain of 3 residual blocks. Each block computes y_k = F_k(y_{k-1}) + y_{k-1}. Start with input x:
y_1 = F_1(x) + x
y_2 = F_2(y_1) + y_1
y_3 = F_3(y_2) + y_2
Now let's compute the gradient ∂y_3/∂x using the chain rule. We need to think about all possible paths from x to y_3.
At each block, there are two paths: through the block (via F_k) or through the skip connection (identity, gradient = 1). With 3 blocks, there are 2^3 = 8 possible paths:
| Path | Description | Gradient Contribution |
|---|---|---|
| skip-skip-skip | All skips (direct path) | 1 |
| skip-skip-F3 | Only block 3 | ∂F3/∂x |
| skip-F2-skip | Only block 2 | ∂F2/∂x |
| F1-skip-skip | Only block 1 | ∂F1/∂x |
| skip-F2-F3 | Blocks 2 & 3 | ∂F3·∂F2/∂x |
| F1-skip-F3 | Blocks 1 & 3 | ∂F3·∂F1/∂x |
| F1-F2-skip | Blocks 1 & 2 | ∂F2·∂F1/∂x |
| F1-F2-F3 | All blocks | ∂F3·∂F2·∂F1/∂x |
The total gradient is the sum of all 8 paths. The first row — the all-skips path — always contributes exactly 1. This is the gradient highway that never vanishes.
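One compact way to see where the 8 paths come from is to apply the chain rule block by block. The notation J_k below is introduced here for illustration; it denotes the Jacobian of block k with respect to its own input, which the table abbreviates as ∂F_k/∂x.

```latex
\frac{\partial y_3}{\partial x}
  = (I + J_3)(I + J_2)(I + J_1)
  = I + J_1 + J_2 + J_3 + J_3 J_2 + J_3 J_1 + J_2 J_1 + J_3 J_2 J_1,
\qquad J_k = \frac{\partial F_k}{\partial y_{k-1}},\quad y_0 = x .
```

Each of the eight terms in the expansion corresponds to one row of the table, and the leading I is the all-skips path.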
Plain network gradient (single path, product of factors), taking each block gradient to be 0.3:
0.3 × 0.3 × 0.3 = 0.027
The gradient has decayed to 2.7% of the original signal. After 10 layers with this factor, it would be 0.3^10 ≈ 0.0000059, essentially zero. The network can't learn.
Residual network gradient (sum of all 8 paths):
1 + 3(0.3) + 3(0.3)^2 + (0.3)^3 = 1 + 0.9 + 0.27 + 0.027 = 2.197
The residual gradient is 81 times larger (2.197 vs 0.027). And most of that magnitude — 1.0 out of 2.197 — comes from the direct skip path. Even if every block had zero gradient (∂Fk/∂x = 0), the total gradient would still be 1.0.
The comparison gets more dramatic with depth. Let's compute for various depths with the same block gradient of 0.3:
| Depth (L) | Paths (2^L) | Plain Gradient | Residual Gradient | Ratio |
|---|---|---|---|---|
| 3 | 8 | 0.027 | 2.197 | 81× |
| 5 | 32 | 0.00243 | 3.713 | 1,528× |
| 10 | 1,024 | 5.9 × 10^-6 | 13.79 | 2.3M× |
| 20 | 1,048,576 | 3.5 × 10^-11 | 190.0 | 5.5T× |
At 20 layers, the plain network's gradient is effectively zero (3.5 × 10^-11). The residual network's gradient is 190: healthy and usable. This is why ResNets can train at 100+ layers where plain networks fail at 20.
Every possible gradient path from output to input. Plain network: one path (product). Residual: many paths (sum). The direct skip path (green) always has magnitude 1.
Drag the block gradient slider down toward 0.05 — severe gradient vanishing. In the plain network, the single product shrinks to nothing. In the residual network, the green path (direct skip) stays at 1.0, keeping the total gradient healthy. Now increase the number of blocks — the plain gradient collapses exponentially while the residual gradient grows (more paths contribute).
```python
# Compute gradient magnitudes for plain vs residual networks
import numpy as np
from itertools import product

def plain_gradient(block_grads):
    """Product of all block gradients (single path)."""
    return np.prod(block_grads)

def residual_gradient(block_grads):
    """Sum over all 2^L paths; each path factor is skip (1) or block (g_k)."""
    L = len(block_grads)
    total = 0.0
    for choices in product([0, 1], repeat=L):
        # 0 = skip (multiply by 1), 1 = through block (multiply by g_k)
        path_grad = 1.0
        for k, use_block in enumerate(choices):
            if use_block:
                path_grad *= block_grads[k]
        total += path_grad
    return total

# Example: 5 blocks, each with gradient 0.3
g = [0.3] * 5
print(f"Plain: {plain_gradient(g):.6f}")        # 0.002430
print(f"Residual: {residual_gradient(g):.3f}")  # 3.713
print(f"Ratio: {residual_gradient(g)/plain_gradient(g):.0f}x")  # 1528x

# Shortcut: the sum over all 2^L paths factors into prod(1 + g_k),
# which is (1 + g)^L when every block has the same gradient g.
# With g=0.3, L=5: 1.3^5 ≈ 3.713 ✓
```
Why exactly does optimization succeed with residuals but fail without them? The answer is visual: plot the loss landscape. Without residuals, it looks like a mountain range — jagged peaks, narrow valleys, saddle points everywhere. With residuals, it looks like a smooth bowl. The optimizer rolling downhill actually reaches the bottom.
In 2018, Li et al. published a landmark paper that made this concrete. They developed a technique to visualize loss landscapes and showed that residual connections don't just help gradients — they fundamentally reshape the optimization surface itself.
Here's how they did it. Take a trained network with weights θ*. Pick two random directions in weight space, δ1 and δ2. Then plot the loss as you move along these directions:
f(α, β) = L(θ* + α·δ1 + β·δ2)
This gives you a 2D slice through the high-dimensional loss landscape. Think of it as a topographic map of the terrain the optimizer must navigate. The results were stunning:
The smoothing effect has an elegant explanation. A residual block computes y = F(x) + x. When we initialize the network with small random weights, F(x) is small — close to zero. So the block starts as an approximate identity: y ≈ x.
This means a deep residual network starts as an approximate identity function. Each added block makes a small perturbation to an already-working solution. The loss function at depth L is close to the loss function at depth L - 1, because the extra block barely changes the output.
Without residuals, each added layer can completely scramble the representation. The loss surface is unconstrained — adding a layer can put you in a completely different region of loss space. The landscape becomes chaotic because there's no continuity between what depth L computes and what depth L + 1 computes.
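The slicing procedure described earlier (perturb trained weights along two random directions and record the loss on a grid) can be sketched on a toy problem. Everything here is an assumption made for illustration: the model is plain linear regression rather than a deep network, the "trained" weights are the exact least-squares solution, and the directions are simply normalized to unit length instead of using the filter normalization from Li et al.

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy problem (illustrative assumption): linear regression with MSE loss.
X = rng.normal(size=(200, 5))
true_w = rng.normal(size=5)
y = X @ true_w + 0.1 * rng.normal(size=200)

def loss(theta):
    """Mean squared error of a linear model with weights theta."""
    return np.mean((X @ theta - y) ** 2)

# "Trained" weights theta*: the exact least-squares solution.
theta_star, *_ = np.linalg.lstsq(X, y, rcond=None)

# Two random directions in weight space, scaled to unit length.
d1 = rng.normal(size=5)
d1 /= np.linalg.norm(d1)
d2 = rng.normal(size=5)
d2 /= np.linalg.norm(d2)

# Evaluate f(alpha, beta) = loss(theta* + alpha*d1 + beta*d2) on a grid.
alphas = np.linspace(-1.0, 1.0, 21)
betas = np.linspace(-1.0, 1.0, 21)
surface = np.array([[loss(theta_star + a * d1 + b * d2) for b in betas]
                    for a in alphas])

center = np.unravel_index(surface.argmin(), surface.shape)
print(surface.shape, center)
```

Because this toy loss is a convex quadratic, the slice is a smooth bowl with its minimum at the grid center. For a real network the same loop applies, with `loss` replaced by a forward pass over a fixed batch.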
Plain network: y = W2 · σ(W1 · x).
Original weights: W1 = [[0.8, 0.3], [0.1, 0.9]], W2 = [[1.2, -0.5], [0.4, 1.1]]. Using ReLU.
Now perturb W1 by ε = 0.1 (add 0.1 to every element): W1' = [[0.9, 0.4], [0.2, 1.0]].
Change in output: [0.105, 0.225]. The perturbation was amplified by W2 — a 0.1 change in W1 caused a 0.225 change in output. With more layers, this amplification compounds.
Residual network: y = σ(W1 · x) + x.
Change in output: [0.15, 0.15]. No amplification — the skip connection absorbs the perturbation. The change in the block's output is added directly, not multiplied through downstream layers. The landscape stays smooth because small weight changes cause small output changes.
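The two output changes can be reproduced numerically. The input vector is not stated above, so the sketch assumes x = [1.0, 0.5]; any input whose components sum to 1.5 and keep the pre-activations positive gives the same differences.

```python
import numpy as np

def relu(z):
    return np.maximum(z, 0.0)

W1 = np.array([[0.8, 0.3], [0.1, 0.9]])
W2 = np.array([[1.2, -0.5], [0.4, 1.1]])
W1_perturbed = W1 + 0.1            # add 0.1 to every element
x = np.array([1.0, 0.5])           # assumed input (not given in the text)

def plain(W1_, x):
    return W2 @ relu(W1_ @ x)      # y = W2 · σ(W1 · x)

def residual(W1_, x):
    return relu(W1_ @ x) + x       # y = σ(W1 · x) + x

delta_plain = plain(W1_perturbed, x) - plain(W1, x)
delta_res = residual(W1_perturbed, x) - residual(W1, x)
print(delta_plain, delta_res)
```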
Left: plain network — rugged, chaotic. Right: residual network — smooth, bowl-like. A ball (optimizer) rolls downhill on each surface. As depth increases, the plain landscape gets worse while the residual stays smooth.
Start at depth 4 — both landscapes are relatively smooth. Now drag the depth slider to 32, then 64. The plain landscape becomes increasingly jagged and chaotic. Sharp ridges appear. Local minima multiply. The optimizer ball gets stuck in shallow valleys far from the true minimum.
The residual landscape barely changes. At depth 64, it's still a smooth bowl. The optimizer rolls cleanly to the bottom every time. This is the visual proof of why ResNets train where plain networks fail.
There's a bonus. The smooth residual landscape doesn't just make training easier — it produces better solutions. The minima in smooth landscapes tend to be wide and flat. Wide minima generalize better because small perturbations to the weights (from new data at test time) don't change the loss much.
In contrast, the sharp minima found in chaotic landscapes are narrow — the loss increases steeply if you move even slightly away. A model sitting in a sharp minimum is brittle: it memorizes training data but fails on new examples.
```python
# Simplified loss landscape visualization (concept code)
import numpy as np

def plain_landscape(alpha, beta, depth):
    """Simulate increasingly rugged landscape with depth."""
    # Base bowl + noise that grows with depth
    bowl = alpha**2 + beta**2
    # Add harmonics: more depth = more high-frequency noise
    ruggedness = 0.0
    for k in range(1, depth // 4 + 1):
        freq = k * 2.0
        amp = 0.3 / k  # diminishing but persistent
        ruggedness += amp * np.sin(freq * alpha) * np.cos(freq * beta)
    return bowl + ruggedness * (depth / 10.0)

def resid_landscape(alpha, beta, depth):
    """Residual landscape stays smooth regardless of depth."""
    # Nearly the same bowl: residuals suppress the ruggedness
    bowl = alpha**2 + beta**2
    # Tiny perturbation that does not depend on depth
    noise = 0.02 * np.sin(2*alpha) * np.cos(2*beta)
    return bowl + noise

# At depth 64:
# plain_landscape(0.5, 0.3, 64) → highly irregular
# resid_landscape(0.5, 0.3, 64) → smooth bowl ≈ 0.35
```
Every transformer has residual connections; that's settled. But WHERE you put the layer normalization relative to the residual add changes training stability dramatically. This seemingly minor architectural choice is a common source of training failures in deep transformers.
The original Transformer paper (Vaswani et al., 2017) used what we now call Post-LN: normalize AFTER the residual addition. In 2020, researchers discovered that moving normalization BEFORE the sublayer — Pre-LN — made deep transformers dramatically easier to train. Most modern LLMs (GPT-2, GPT-3, Llama, Mistral) use Pre-LN.
Let's trace the exact data flow through each variant to see why.
Both variants have the same components: a sublayer (attention or FFN), a layer normalization, and a residual add. The only difference is the order.
Post-LN (original): y = LN(x + Sublayer(x))
Pre-LN (modern): y = x + Sublayer(LN(x))
During backpropagation, the gradient flows backward through the computational graph. Let's trace the gradient through the skip connection in each variant.
Post-LN gradient path through the skip:
∂y/∂x = ∂LN/∂(·) · (I + ∂Sublayer/∂x)
The gradient through the skip connection must pass through the LayerNorm Jacobian. LayerNorm is a function of the mean and variance of its input — when the activation statistics are unstable, this Jacobian can amplify or dampen gradients unpredictably. The ∂LN/∂(·) term is the problem.
Pre-LN gradient path through the skip:
∂y/∂x = I + ∂Sublayer/∂LN · ∂LN/∂x
The gradient through the skip connection is I — the identity. No normalization in the direct path. The gradient flows through the skip as clean 1s, exactly like the ResNet skip connection from Chapter 2. The LayerNorm only appears on the branch path, where its effects on the gradient are contained.
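The difference between the two skip paths can be checked with autograd. In this sketch a zero-weight linear layer stands in for the sublayer (an assumption made so that only the skip path carries gradient), which isolates what each variant does to the direct route.

```python
import torch
import torch.nn as nn
from torch.autograd.functional import jacobian

torch.manual_seed(0)
d = 4
ln = nn.LayerNorm(d)

# Sublayer stand-in with zero weights: its own Jacobian contributes nothing.
sublayer = nn.Linear(d, d)
nn.init.zeros_(sublayer.weight)
nn.init.zeros_(sublayer.bias)

def post_ln(x):
    return ln(x + sublayer(x))      # LayerNorm wraps the skip

def pre_ln(x):
    return x + sublayer(ln(x))      # skip bypasses LayerNorm

x = torch.randn(d)
J_post = jacobian(post_ln, x)       # equals the LayerNorm Jacobian at x
J_pre = jacobian(pre_ln, x)         # equals the identity
print(J_post)
print(J_pre)
```

The Pre-LN Jacobian is the identity. The Post-LN Jacobian is the LayerNorm Jacobian, whose rows sum to zero because LayerNorm ignores a constant shift of its input, so it can never act as a clean pass-through.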
Let's trace a concrete gradient through a 4-layer Post-LN vs Pre-LN stack. For simplicity, assume each LayerNorm Jacobian scales the gradient by a factor α (which varies per layer in practice), and each sublayer Jacobian contributes a small term δ.
Post-LN, 4 layers (gradient at layer 1 from loss at layer 4):
Gradient at layer 1: 0.489. After just 4 layers, the gradient is halved. With 24 layers (a small GPT), the product of those α factors can shrink the gradient to near zero — or if any α > 1, it can explode.
Pre-LN, 4 layers (gradient at layer 1 from loss at layer 4):
The gradient floor is 1.0 regardless of depth. No multiplicative decay. This is the same mechanism that made ResNets trainable: the clean additive skip creates a gradient highway, whereas in Post-LN the LayerNorm sits on that highway and obstructs it.
The simulation below shows both block types side by side. Each column shows the data flow through the block: input enters at the top, output exits at the bottom. Gradient magnitude is shown flowing backward (bottom to top) with color coding: green means healthy gradient, red means problematic.
Drag the learning rate slider. At low LR, both variants work fine. As you increase LR, Post-LN starts showing gradient instability (red flashes at early layers) while Pre-LN stays stable (green throughout).
Two transformer blocks process the same input. Watch gradient health (green=stable, red=unstable) as you increase learning rate.
```python
import torch
import torch.nn as nn

class PostLNBlock(nn.Module):
    """Original Transformer: normalize AFTER residual add."""
    def __init__(self, d_model):
        super().__init__()
        self.attn = nn.MultiheadAttention(d_model, 8)
        self.norm = nn.LayerNorm(d_model)

    def forward(self, x):
        # 1. Sublayer
        attn_out, _ = self.attn(x, x, x)
        # 2. Residual add THEN normalize
        return self.norm(x + attn_out)  # LN wraps the skip!

class PreLNBlock(nn.Module):
    """Modern LLM: normalize BEFORE sublayer."""
    def __init__(self, d_model):
        super().__init__()
        self.attn = nn.MultiheadAttention(d_model, 8)
        self.norm = nn.LayerNorm(d_model)

    def forward(self, x):
        # 1. Normalize once, THEN sublayer
        h = self.norm(x)
        attn_out, _ = self.attn(h, h, h)
        # 2. Clean residual add: skip path is untouched
        return x + attn_out  # gradient flows through x directly
```
ResNet proved that shortcuts work. But is addition the only way to connect layers? What if instead of adding the input to the output, you concatenate them? Or what if you let the network learn how much to skip?
Three architectures explored three different points on the connectivity spectrum. Each teaches us something different about how information should flow through deep networks.
You already know this one. The output is the input plus a learned residual:
y = F(x) + x
The input and output must have the same dimension (or you add a projection). The gradient through the skip is 1. Simple, cheap, and it scales to 1000+ layers. This is what the entire industry uses.
DenseNet (Huang et al., 2017) asked: why add when you can keep everything? In a Dense Block, each layer receives the concatenated outputs of ALL previous layers as its input:
x_k = H_k([x_0, x_1, ..., x_{k-1}])
The square brackets mean concatenation along the channel dimension. Layer k doesn't just see the previous layer's output — it sees every feature ever computed. This is maximum feature reuse.
The catch: concatenation grows the channel count linearly. If each layer produces g feature maps (called the growth rate), then layer k has (input_channels + k × g) channels as input. After 50 layers with g=32, that's 1600+ channels. Memory-hungry and computationally expensive for deep networks.
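The linear channel growth is easy to tabulate. The helper name and the 64 initial channels below are assumptions for illustration; the growth rate g = 32 comes from the text.

```python
def dense_input_channels(input_channels, growth_rate, k):
    """Channels seen by layer k (0-indexed) in a dense block: input + k * g."""
    return input_channels + k * growth_rate

# Assumed: the block starts from 64 channels; growth rate g = 32 as above.
channels = [dense_input_channels(64, 32, k) for k in range(51)]
print(channels[0], channels[10], channels[50])
```

After 50 layers the input to the next layer has 64 + 50 × 32 = 1664 channels, consistent with the "1600+" figure.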
Highway Networks (Srivastava et al., 2015) came before ResNets and asked a subtler question: what if the network could learn which layers should transform and which should pass through? They introduced a transform gate:
y = T(x) ⊙ H(x) + (1 - T(x)) ⊙ x
T(x) is a learned gate with sigmoid output (values between 0 and 1). H(x) is the transformation. The ⊙ means element-wise multiplication.
Highways are more flexible than ResNets but also more expensive — you need to learn T(x) in addition to H(x), doubling the parameters per layer.
Let's trace a concrete input through all three patterns. Input: x = [1.0, 2.0]. The transformation produces F(x) = [0.3, -0.1].
ResNet (addition):
DenseNet (concatenation):
Highway (learned gate):
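The three traces can be reproduced with a few lines of NumPy. Two assumptions are made here: the highway transformation H(x) is taken to be the same vector as F(x), and the gate value T(x) = [0.5, 0.5] is chosen purely for illustration, since a real gate is learned.

```python
import numpy as np

x = np.array([1.0, 2.0])
Fx = np.array([0.3, -0.1])       # output of the transformation
gate = np.array([0.5, 0.5])      # assumed gate value T(x); hypothetical

resnet_out = Fx + x                              # addition: same dimension
densenet_out = np.concatenate([x, Fx])           # concatenation: dimension grows
highway_out = gate * Fx + (1.0 - gate) * x       # gated blend

print(resnet_out)      # [1.3 1.9]
print(densenet_out)    # [ 1.   2.   0.3 -0.1]
print(highway_out)     # [0.65 0.95]
```

With the gate at 0.5 the highway output sits halfway between the input and the transformation; a gate near 0 would pass x through almost unchanged.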
The simulation below shows all three connectivity patterns. Each architecture is drawn as a stack of layers with connections between them. Toggle between them to see how information flows. DenseNet has the richest connectivity (arrows from every layer to every later layer), Highway has adaptive connections (arrow width proportional to gate value), and ResNet has the simplest pattern (one skip per block).
Click each architecture to see its connectivity and feature dimension growth. Watch how channels accumulate in DenseNet.
```python
import torch
import torch.nn as nn

class ResBlock(nn.Module):
    """Additive skip: y = F(x) + x"""
    def __init__(self, dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, dim), nn.ReLU(), nn.Linear(dim, dim))

    def forward(self, x):
        return x + self.net(x)  # dim in == dim out

class DenseBlock(nn.Module):
    """Concatenative skip: x_k = H_k([x_0, ..., x_{k-1}])"""
    def __init__(self, in_dim, growth_rate, n_layers):
        super().__init__()
        self.layers = nn.ModuleList()
        for i in range(n_layers):
            self.layers.append(nn.Sequential(
                nn.Linear(in_dim + i * growth_rate, growth_rate),
                nn.ReLU()))

    def forward(self, x):
        features = [x]
        for layer in self.layers:
            out = layer(torch.cat(features, dim=-1))
            features.append(out)  # keep ALL outputs
        return torch.cat(features, dim=-1)  # dim grows!

class HighwayBlock(nn.Module):
    """Learned gate: y = T(x)*H(x) + (1-T(x))*x"""
    def __init__(self, dim):
        super().__init__()
        self.H = nn.Sequential(nn.Linear(dim, dim), nn.ReLU())
        self.T = nn.Sequential(nn.Linear(dim, dim), nn.Sigmoid())

    def forward(self, x):
        t = self.T(x)  # gate: how much to transform
        h = self.H(x)  # transformation
        return t * h + (1 - t) * x  # blend
```
| Property | ResNet | DenseNet | Highway |
|---|---|---|---|
| Skip type | Addition | Concatenation | Learned gate |
| Dim constraint | in == out | None (grows) | in == out |
| Memory per layer | O(d²) | O(d × k×g) | O(2d²) |
| Gradient path | Clean identity | Clean (each path independent) | Scaled by gate |
| Max practical depth | 1000+ | ~100 (memory) | ~50 (gate learning) |
| Used in | Everything modern | Medical imaging, small models | Historical; LSTM cells are highways |
Time to put everything together. You've learned why skip connections matter (Chapter 0-1), how gradients flow through them (Chapter 2-3), where to place normalization (Chapter 4), and what alternative connectivity patterns exist (Chapter 5). Now build your own network and watch it train.
The simulation below trains a configurable neural network on a 2D classification task. You control depth, skip connection type, normalization placement, and learning rate. The visualization shows three things simultaneously: the network architecture with gradient health, the decision boundary evolving in real time, and the training loss curve.
Build a network, train it, and watch architecture choices play out in real time.
Left panel: Network diagram. Each rectangle is a layer. Color indicates gradient health at that layer: green means the gradient magnitude is within a healthy range, yellow means it's weakening, red means it's vanished or exploded. Skip connections are shown as curved arrows around blocks.
Center panel: Classification task. Two spirals of points (orange and teal) that the network must separate. The colored regions show the network's decision boundary — where it predicts each class. A good network produces a spiral boundary that perfectly separates the two classes. A failing network produces a straight line or chaotic noise.
Right panel: Loss curve. Training loss over time. A healthy network shows a smooth downward curve. An unstable network shows spikes or a plateau. The gradient norm is shown as a secondary trace — watch for it collapsing to zero (vanishing) or shooting upward (exploding).
You understand WHY residuals work. Now: exactly HOW to use them. This chapter gives you the copy-pasteable building blocks for every major architecture, plus the initialization tricks that make them train from step zero.
If your network has more than ~6 layers, add residual connections. The cost is zero extra parameters (just an addition), and the benefit is guaranteed gradient flow. There is no reason not to do this.
Place normalization BEFORE the sublayer, not after. Modern recipe:
x = x + Attention(LN(x))
x = x + FFN(LN(x))
Two residual additions per transformer block — one for attention, one for FFN. Each follows the same Pre-LN pattern.
The classic ResNet building block for convolutional networks. Two convolutions with batch normalization and a skip connection:
out = ReLU(BN(Conv(x)))
out = BN(Conv(out))
y = ReLU(out + x)
When spatial dimensions change (e.g., stride=2 for downsampling), the skip path needs a 1×1 convolution with matching stride to align dimensions.
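A minimal sketch of such a projection shortcut follows. The class name `DownsampleResidualBlock` and the choice to put a BatchNorm after the 1×1 convolution are assumptions for illustration, not details taken from the text.

```python
import torch
import torch.nn as nn

class DownsampleResidualBlock(nn.Module):
    """Residual block whose skip path is a strided 1x1 conv (illustrative)."""
    def __init__(self, in_channels, out_channels, stride=2):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, 3,
                               stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, 3,
                               padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)
        # 1x1 projection with the same stride so shapes match at the add.
        self.shortcut = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 1, stride=stride, bias=False),
            nn.BatchNorm2d(out_channels))

    def forward(self, x):
        out = torch.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        return torch.relu(out + self.shortcut(x))

block = DownsampleResidualBlock(16, 32, stride=2)
y = block(torch.randn(2, 16, 8, 8))
print(y.shape)  # torch.Size([2, 32, 4, 4])
```

The skip path is no longer a pure identity in this block, so such projections are usually confined to the few places where the resolution or channel count changes.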
GPT-2 popularized Pre-LN. Each block has two residual additions:
x = x + Attention(LN1(x))
x = x + FFN(LN2(x))
Llama (Meta, 2023) refined the GPT-2 block with two key swaps: RMSNorm instead of LayerNorm (15% faster, same quality), and SwiGLU FFN instead of GELU FFN (better expressiveness). The residual structure is identical.
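To make the RMSNorm swap concrete, here is a hand-written sketch of the operation: divide by the root mean square of the features and apply a learned scale, with no mean subtraction and no bias. The class name `SimpleRMSNorm` and the `eps` default are assumptions; recent PyTorch releases also provide a built-in `nn.RMSNorm`.

```python
import torch
import torch.nn as nn

class SimpleRMSNorm(nn.Module):
    """Illustrative RMSNorm: x / rms(x) * weight (no centering, no bias)."""
    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        rms = torch.sqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        return self.weight * x / rms

torch.manual_seed(0)
norm = SimpleRMSNorm(8)
x = torch.randn(4, 8) * 5.0
out = norm(x)
out_rms = out.pow(2).mean(dim=-1).sqrt()
print(out_rms)  # each row has RMS close to 1
```

Skipping the mean computation is where the cost saving relative to LayerNorm comes from.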
Here's a trick that sounds almost too simple: initialize the last linear layer in each residual branch to all zeros.
Why? At initialization, the residual branch output is F(x) = 0 for any input. So the block computes y = 0 + x = x. The entire network starts as the identity function. Every layer just passes its input through.
Gradients flow perfectly at initialization (the gradient through the skip is 1, and the gradient through the branch is 0). As training progresses, the zero weights gradually become nonzero — the network learns useful residuals at its own pace.
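The identity-at-initialization claim can be verified directly. This sketch uses a small MLP residual block (the class name `MLPResidualBlock` and the sizes are illustrative) and zeroes both the weight and the bias of the last linear layer in each branch.

```python
import torch
import torch.nn as nn

class MLPResidualBlock(nn.Module):
    """Residual MLP block with a zero-initialized output projection."""
    def __init__(self, dim):
        super().__init__()
        self.fc1 = nn.Linear(dim, 4 * dim)
        self.fc2 = nn.Linear(4 * dim, dim)
        nn.init.zeros_(self.fc2.weight)   # branch output is zero at init
        nn.init.zeros_(self.fc2.bias)

    def forward(self, x):
        return x + self.fc2(torch.relu(self.fc1(x)))

torch.manual_seed(0)
stack = nn.Sequential(*[MLPResidualBlock(16) for _ in range(12)])
x = torch.randn(4, 16)
y = stack(x)
print(torch.allclose(y, x))  # True: 12 blocks deep, still the identity
```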
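Many ResNet variants use this trick, commonly by zero-initializing the scale of the last BatchNorm in each block. GPT-2 uses a close relative: it shrinks the initial weights of its residual output projections rather than zeroing them. In code, it takes only a few lines: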
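```python
import torch.nn as nn

# Zero-init the output projection of each residual block.
# Assumes a hypothetical `model` whose blocks expose
# `ffn.out_proj` and `attn.out_proj` linear layers.
for block in model.blocks:
    nn.init.zeros_(block.ffn.out_proj.weight)
    nn.init.zeros_(block.ffn.out_proj.bias)
    nn.init.zeros_(block.attn.out_proj.weight)
    nn.init.zeros_(block.attn.out_proj.bias)
```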
A transformer with N blocks has 2N residual additions (one for attention, one for FFN per block). Each addition grows the residual stream's variance. After 2N additions, the variance has grown by a factor of ~2N.
The fix: scale each sublayer's output projection by 1/√(2N).
For a 32-layer transformer: 2N = 64, scale = 1/√64 = 0.125. Each sublayer contributes 1/8th of its output to the residual stream. The variance stays bounded regardless of depth.
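A quick simulation shows the effect of the scaling. It rests on a strong simplifying assumption, stated loudly: every sublayer output is modeled as independent unit-variance Gaussian noise, which real attention and FFN outputs are not. The function name is illustrative.

```python
import numpy as np

rng = np.random.default_rng(0)
n_blocks = 32
n_adds = 2 * n_blocks            # one add for attention, one for FFN

def residual_stream_variance(n_adds, scale, n_samples=100_000):
    """Variance of the stream after n_adds additions of scaled unit-variance noise."""
    stream = rng.normal(size=n_samples)          # initial embedding, variance 1
    for _ in range(n_adds):
        sublayer_out = rng.normal(size=n_samples)  # assumed independent, variance 1
        stream = stream + scale * sublayer_out
    return stream.var()

var_unscaled = residual_stream_variance(n_adds, scale=1.0)
var_scaled = residual_stream_variance(n_adds, scale=n_adds ** -0.5)
print(f"unscaled: {var_unscaled:.1f}, scaled: {var_scaled:.2f}")
```

Under this assumption the unscaled stream ends with variance near 1 + 2N = 65, while the scaled stream stays near 2 regardless of depth.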
GPT-3 uses this scaling. Llama and Mistral use it. It's another one-liner:
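```python
# Scale output projections by 1/sqrt(2*n_layers).
# Assumes a hypothetical `model` with `n_layers` blocks, each exposing
# `attn.out_proj` and `ffn.out_proj` linear layers.
scale = (2 * n_layers) ** -0.5
for block in model.blocks:
    block.attn.out_proj.weight.data *= scale
    block.ffn.out_proj.weight.data *= scale
```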
The simulation below lets you toggle between four canonical block designs: the ResNet block, GPT-2 Pre-LN block, Post-LN Transformer block, and the Llama block. Each shows the exact data flow with tensor shapes annotated, highlighting where the residual add happens and where normalization lives.
Toggle between canonical block designs. Tensor shapes and data flow annotated at each step.
```python
import torch
import torch.nn as nn

class ResNetBlock(nn.Module):
    """Classic ResNet BasicBlock (vision)."""
    def __init__(self, channels):
        super().__init__()
        self.conv1 = nn.Conv2d(channels, channels, 3, padding=1)
        self.bn1 = nn.BatchNorm2d(channels)
        self.conv2 = nn.Conv2d(channels, channels, 3, padding=1)
        self.bn2 = nn.BatchNorm2d(channels)
        nn.init.zeros_(self.bn2.weight)  # zero init!

    def forward(self, x):
        out = torch.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        return torch.relu(out + x)  # residual add, then ReLU

class GPT2Block(nn.Module):
    """GPT-2 Pre-LN Transformer block."""
    def __init__(self, d, n_heads, n_layers):
        super().__init__()
        self.ln1 = nn.LayerNorm(d)
        self.attn = nn.MultiheadAttention(d, n_heads, batch_first=True)
        self.ln2 = nn.LayerNorm(d)
        self.ffn = nn.Sequential(
            nn.Linear(d, 4*d), nn.GELU(), nn.Linear(4*d, d))
        # Zero init + 1/sqrt(2N) scaling
        nn.init.zeros_(self.ffn[-1].weight)
        self.scale = (2 * n_layers) ** -0.5

    def forward(self, x):
        h = self.ln1(x)
        a, _ = self.attn(h, h, h)
        x = x + self.scale * a                       # residual 1
        x = x + self.scale * self.ffn(self.ln2(x))   # residual 2
        return x

class LlamaBlock(nn.Module):
    """Llama-style Pre-RMSNorm + SwiGLU."""
    def __init__(self, d, n_heads, n_layers):
        super().__init__()
        self.norm1 = nn.RMSNorm(d)  # nn.RMSNorm needs a recent PyTorch release
        self.attn = nn.MultiheadAttention(d, n_heads, batch_first=True)
        self.norm2 = nn.RMSNorm(d)
        # SwiGLU: SiLU(W_gate x) * (W_up x), then project down
        self.w_gate = nn.Linear(d, 4*d, bias=False)
        self.w_up = nn.Linear(d, 4*d, bias=False)
        self.w_down = nn.Linear(4*d, d, bias=False)
        nn.init.zeros_(self.w_down.weight)  # zero init

    def forward(self, x):
        h = self.norm1(x)
        a, _ = self.attn(h, h, h)
        x = x + a
        h = self.norm2(x)
        x = x + self.w_down(
            torch.nn.functional.silu(self.w_gate(h)) * self.w_up(h))
        return x
```
Five architectures enter. One training task. You control the depth. Watch them race — loss curves, gradient health, and memory usage all visible in real time. This is the payoff: seeing every concept from Chapters 0-7 compete head-to-head.
The architectures are: Plain (no skips), ResNet (additive skip), Pre-LN Transformer (Pre-LN + residual), Post-LN Transformer (Post-LN + residual), and DenseNet (concatenative skip).
Five architectures train on the same task. Adjust depth to see which survive. Left: loss curves. Right: gradient health bars and memory usage.
At shallow depths, the differences are subtle — all architectures can handle 8 layers. But as you push past 32, the separation becomes dramatic. This is exactly what He et al. observed in 2015, and what every LLM training team sees today.
| Architecture | Skip Type | Gradient Path | Memory | Max Depth | Used In |
|---|---|---|---|---|---|
| Plain | None | Product (vanishes) | O(d) | ~20 | Shallow MLPs |
| ResNet | Addition | Sum of 2L paths | O(d) | 1000+ | Vision, audio |
| Pre-LN Transformer | Addition + Pre-Norm | Clean identity + branches | O(d) | 100+ blocks | GPT-2/3, Llama, Mistral |
| Post-LN Transformer | Addition + Post-Norm | Through LN Jacobian | O(d) | ~24 blocks | Original Transformer, BERT |
| DenseNet | Concatenation | Independent per feature | O(d×L) | ~100 | Medical imaging |
| Highway | Learned gate | Scaled by gate | O(2d) | ~50 | LSTM cells |
| U-Net | Cross-level concat | Across resolutions | O(d×levels) | ~5 levels | Segmentation, diffusion |
```python
import torch
import torch.nn as nn

# 1. Basic residual block (works for any sublayer)
class ResidualBlock(nn.Module):
    def __init__(self, sublayer):
        super().__init__()
        self.sublayer = sublayer

    def forward(self, x):
        return x + self.sublayer(x)

# 2. Pre-LN transformer block (the modern recipe)
class PreLNBlock(nn.Module):
    def __init__(self, d, n_heads):
        super().__init__()
        self.ln1 = nn.LayerNorm(d)
        self.attn = nn.MultiheadAttention(d, n_heads, batch_first=True)
        self.ln2 = nn.LayerNorm(d)
        self.ffn = nn.Sequential(nn.Linear(d, 4*d), nn.GELU(), nn.Linear(4*d, d))

    def forward(self, x):
        h = self.ln1(x)
        a, _ = self.attn(h, h, h)
        x = x + a                      # residual 1
        x = x + self.ffn(self.ln2(x))  # residual 2
        return x

# Minimal stand-in model so the snippets below run (sizes are arbitrary)
n_layers = 4
model = nn.Module()
model.blocks = nn.ModuleList([PreLNBlock(64, 4) for _ in range(n_layers)])

# 3. Zero-init trick
for block in model.blocks:
    nn.init.zeros_(block.ffn[-1].weight)  # F(x)=0 at init → y=x
    nn.init.zeros_(block.ffn[-1].bias)

# 4. Depth scaling (GPT-3 style)
scale = (2 * n_layers) ** -0.5
with torch.no_grad():
    for block in model.blocks:
        block.attn.out_proj.weight *= scale
        block.ffn[-1].weight *= scale  # no-op if zero-initialized above
```
Where skip connections appear next:
When you next see x + sublayer(norm(x)), you'll know exactly why each piece is there and what happens if you remove it.