Vaswani, Shazeer, Parmar, Uszkoreit, Jones, Gomez, Kaiser, Polosukhin (Google Brain / Google Research) — NeurIPS 2017

Attention Is All You Need

The paper that introduced the Transformer architecture, replacing recurrence entirely with self-attention. The foundation of GPT, BERT, Claude, and nearly every modern large language model.

Prerequisites: Matrix multiplication + Neural network basics + Softmax. That's it.
10 Chapters · 10+ Simulations · 0 Assumed Knowledge

Chapter 0: The Recurrence Bottleneck

You're translating a sentence from English to French. The English sentence is: "The cat that the dog that the boy owned chased ran away." To translate "ran," you need to know it refers to "the cat" — not "the dog" or "the boy." But "the cat" and "ran" are separated by 8 words. In an RNN, information about "the cat" must survive 8 sequential hidden-state updates before it can influence the translation of "ran."

We've seen why this is problematic. The gradient connecting "ran" to "the cat" decays exponentially through the Jacobian product chain. Even with LSTMs, long-range dependencies remain difficult to learn (gradient clipping helps with exploding gradients, not vanishing ones). But there's a second problem with recurrence that's arguably even more damaging: sequential computation.

An RNN processes tokens one at a time, left to right. To compute h_5, you must first compute h_1, h_2, h_3, h_4. This is inherently sequential: you can't parallelize across time steps. On modern GPUs with thousands of cores, this means most of the hardware sits idle, waiting for the previous time step to finish.
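
To make both problems concrete, here is a minimal sketch of a scalar tanh RNN. The weights `w` and `u` and the constant input are hypothetical values chosen only for illustration. The loop shows the step-by-step dependency, and the second function multiplies the per-step Jacobians to show how the gradient shrinks as the chain gets longer.

```python
import math

def rnn_forward(xs, w=0.9, u=1.0):
    """Run a scalar tanh RNN and return all hidden states h_1..h_T."""
    h, states = 0.0, []
    for x in xs:                      # step t cannot start until step t-1 is done
        h = math.tanh(w * h + u * x)
        states.append(h)
    return states

def gradient_through_time(states, w=0.9):
    """|d h_T / d h_1| as the product of per-step Jacobians w * (1 - h_t^2)."""
    grad = 1.0
    for h in states[1:]:              # one factor per step from h_2 to h_T
        grad *= abs(w * (1.0 - h * h))
    return grad

short_run = rnn_forward([0.5] * 5)
long_run = rnn_forward([0.5] * 50)
print(gradient_through_time(short_run))
print(gradient_through_time(long_run))   # far smaller: the chain is ten times longer
```

Every factor in the product is at most |w|, so with |w| < 1 the gradient can only shrink as more steps are added.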

Sequential vs Parallel Computation

Top: an RNN must process tokens sequentially (each step depends on the previous). Bottom: self-attention can process all tokens simultaneously, with every token looking at every other token in parallel.

| Property | RNN / LSTM | Transformer |
|---|---|---|
| Computation | Sequential (O(T) steps) | Parallel (O(1) steps) |
| Max path length | O(T): gradient traverses T steps | O(1): attention connects any pair directly |
| GPU utilization | Low: one step at a time | High: all tokens computed in parallel |
| Memory | O(1) per step (just the hidden state) | O(T²): attention matrix over all pairs |
| Training speed | Slow on modern hardware | Fast: designed for GPU parallelism |

The Transformer's bet: What if we replaced recurrence entirely? Instead of processing tokens sequentially and relying on hidden states to carry information forward, what if every token could directly attend to every other token? The path length between any two tokens would be O(1) — no chain of hidden states, no vanishing gradients, no sequential bottleneck. The cost is O(T²) memory for the attention matrix, but for typical sequence lengths (512-2048 tokens in 2017), this tradeoff is overwhelmingly worth it.

Vaswani et al. called their architecture the Transformer, and its impact was rapid and far-reaching. Within about two years, most state-of-the-art NLP models were based on the Transformer. Within five years, it had spread to computer vision, speech, protein folding, robotics, and code generation. It is arguably the most influential architecture in the history of deep learning.

The context before 2017

Before the Transformer, the dominant approach for sequence-to-sequence tasks (translation, summarization, etc.) was the encoder-decoder with attention architecture (Bahdanau et al., 2015). This used:

Encoder: Bidirectional LSTM reads the input sequence
Attention: At each decoder step, compute attention over encoder states
Decoder: Unidirectional LSTM generates the output sequence

This architecture was powerful but still fundamentally limited by the sequential nature of the RNNs. The encoder processed the input left-to-right (and right-to-left for bidirectional), one token at a time. Training was slow because the LSTM steps couldn't be parallelized.

The key question behind Vaswani et al.'s design: what if we used only the attention mechanism and threw away the RNNs entirely? It seemed radical at the time, since attention was viewed as an auxiliary mechanism, not a standalone one. The conventional wisdom was that recurrence was essential for processing sequences. How else would the model know about order?

The answer was surprising: you don't need recurrence for order. You can just tell the model about position through positional encodings — a fixed or learned vector added to each token that encodes its position in the sequence. With positional encodings providing order information and attention providing inter-token communication, recurrence becomes unnecessary.

Let's understand it piece by piece, starting with the core mechanism: attention.

What are the two fundamental problems with recurrent architectures that the Transformer solves?

Chapter 1: Attention Intuition

Before we get into the math, let's build intuition for what attention does. The core idea is remarkably simple.

Imagine you're at a cocktail party. Many people are talking simultaneously, but you're able to selectively focus on the conversation that matters to you. Your brain computes a kind of "relevance score" for each speaker and allocates your attention proportionally. The most relevant voice gets the most attention; background chatter gets almost none.

Attention in a neural network does exactly this. Given a collection of values (the "speakers"), and a query ("what am I looking for?"), attention computes a relevance score between the query and each value, then returns a weighted sum of the values, where the weights are the relevance scores.

Think of attention as a soft dictionary lookup. In a regular dictionary, you look up a key and get exactly one value. In attention, your query partially matches multiple keys. Instead of returning one value, attention returns a weighted blend of all values, where the weights reflect how well each key matches your query.

More concretely, attention involves three components:

Query (Q): "What am I looking for?" The representation of the current position.
↓ compare with
Key (K): "What do I contain?" The representation of each other position.
↓ retrieve from
Value (V): "What information do I carry?" The actual content at each position.

The query asks: "Given where I am, what should I attend to?" The keys answer: "Here's how relevant I am." The values provide: "Here's what I actually contain." The output is a weighted sum of values, weighted by how well each key matches the query.

Attention as Soft Dictionary Lookup

A query (warm) is compared against four keys (teal). The similarity scores determine how much each value contributes to the output.

In a language model, the query, key, and value are all derived from the same sequence of tokens (this is why it's called self-attention: the sequence attends to itself). Each token generates its own Q, K, and V by multiplying its embedding by learned weight matrices W_Q, W_K, W_V. These are learned parameters: during training, the model discovers what to query for, what to advertise, and what to provide.

An important subtlety: Q, K, and V are linear projections of the same input. The token "cat" generates a query ("I'm looking for my verb"), a key ("I'm an animal noun"), and a value ("here's my semantic content") — all from its single embedding vector, through three different learned linear transformations. The separation into Q, K, V gives the model the flexibility to look for one thing (via Q/K similarity) while retrieving another thing (via V).

Consider the sentence "The cat sat on the mat." When processing the word "sat," the query represents "what did the sitting?" The key for "cat" represents "I'm the subject — I did the action." The attention mechanism computes a high score between the "sat" query and the "cat" key, causing "sat" to attend strongly to "cat" and pull in information about the actor.

This is fundamentally different from an RNN, where "sat" can only access "cat" indirectly, through the chain of hidden states h_1 → h_2 → h_3. In attention, "sat" accesses "cat" directly through the attention weight. The path length is O(1), not O(T). This directness is what makes attention far less prone to vanishing gradients across long distances.

Moreover, the attention weights are interpretable. You can literally look at the attention matrix and see which words each word is attending to. This transparency is rare in neural networks and has spawned an entire subfield of "attention analysis" (Vig 2019, Clark et al. 2019), where researchers study what attention heads learn to do.

```python
# Attention intuition: soft dictionary lookup
import torch
import torch.nn.functional as F

# 4 items in our "dictionary"
keys   = torch.tensor([[1.0, 0.0],   # key 0: "cat"
                       [0.0, 1.0],   # key 1: "dog"
                       [0.7, 0.7],   # key 2: "pet" (mix)
                       [-1.0, 0.0]]) # key 3: "table"

values = torch.tensor([[10.0, 20.0],  # value 0: cat's info
                       [30.0, 40.0],  # value 1: dog's info
                       [15.0, 25.0],  # value 2: pet's info
                       [50.0, 60.0]]) # value 3: table's info

# Query: "I'm looking for cat-like things"
query = torch.tensor([0.9, 0.1])

# Scores = dot product of query with each key
scores = query @ keys.T  # [0.9, 0.1, 0.7, -0.9]

# Weights = softmax of scores (sum to 1)
weights = F.softmax(scores, dim=0)
# approx. tensor([0.411, 0.185, 0.336, 0.068])
# Cat gets most attention, table gets least

# Output = weighted sum of values
output = weights @ values
# approx. tensor([18.09, 28.09]): a blend dominated by cat's and pet's info
```

In the attention mechanism, what do the Query, Key, and Value represent?

Chapter 2: Scaled Dot-Product Attention

Now let's formalize what we just built intuitively. The Transformer uses scaled dot-product attention, which is defined by a single elegant equation:

$$\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left(\frac{Q K^\top}{\sqrt{d_k}}\right) V$$

Let's break this down step by step, with concrete tensor shapes.

Suppose we have a sequence of T tokens, each represented by a d-dimensional embedding. The input X has shape [T, d]. We project X into queries, keys, and values using three learned weight matrices:

$$Q = X W_Q \qquad K = X W_K \qquad V = X W_V$$

Where W_Q, W_K have shape [d, d_k] and W_V has shape [d, d_v]. This gives Q and K shape [T, d_k] and V shape [T, d_v]. In the original Transformer, d_k = d_v = d_model/h = 512/8 = 64.

Why three separate projections instead of using X directly? Because the raw embedding X is designed to represent the meaning of a token, not its role in attention. The projections let each token adopt three different "personas": as a query (what it's looking for), as a key (what it offers), and as a value (what it carries). A word like "cat" might have a query that asks "where is my verb?", a key that advertises "I'm a subject noun", and a value that carries rich semantic information about cat-ness.

Step 1: Compute attention scores

$$S = Q K^\top \qquad [T, d_k] \times [d_k, T] = [T, T]$$

The attention score matrix S has shape [T, T]. Entry S_ij is the dot product between query i and key j: it measures how much token i should attend to token j. This is the core computation: every query is compared against every key. For a sequence of 512 tokens, this produces a 512 × 512 = 262,144-entry matrix. This quadratic cost (O(T²) in sequence length) is the fundamental computational bottleneck of the Transformer. It's also why longer context windows are expensive: doubling the sequence length quadruples the attention cost.
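
The quadratic growth is easy to check with a few lines of arithmetic. This is a rough sketch for one attention head and one sequence. The helper names and the 4-bytes-per-entry (fp32) assumption are ours, not the paper's.

```python
def attention_score_entries(seq_len):
    """Number of entries in the [T, T] score matrix for one head."""
    return seq_len * seq_len

def score_matrix_bytes(seq_len, bytes_per_entry=4):
    """Memory for one score matrix, assuming fp32 entries (an assumption)."""
    return attention_score_entries(seq_len) * bytes_per_entry

for T in (512, 1024, 2048):
    print(T, attention_score_entries(T), score_matrix_bytes(T))
```

Each doubling of T multiplies both numbers by four.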

But here's the crucial tradeoff: this O(T²) matrix multiplication is massively parallel. Every entry can be computed simultaneously on a GPU. An RNN does only O(T) sequential steps, but it can't parallelize across them, because each step depends on the previous one. On modern GPUs with thousands of cores, the Transformer's parallel O(T²) work is typically faster in wall-clock time than the RNN's sequential O(T) steps at moderate sequence lengths. This is the engineering insight that made the Transformer practical.

Step 2: Scale

$$S \leftarrow \frac{S}{\sqrt{d_k}}$$

Why divide by √d_k? Without scaling, the dot products can become very large when d_k is large. If the entries of Q and K are independent with zero mean and unit variance, then each entry of Q K^T has variance d_k (it is a sum of d_k products of unit-variance terms). Large values push the softmax into saturation, where the gradients are near zero. Dividing by √d_k normalizes the variance back to 1.

Why √d_k specifically? If the components of q and k are independent with mean 0 and variance 1, then q · k = ∑_l q_l k_l has mean 0 and variance d_k (by independence). Dividing by √d_k gives variance 1, keeping the softmax inputs out of the saturated regime so gradients stay healthy. This is why Vaswani et al. call it "scaled" dot-product attention. Without the scaling, larger models (with larger d_k) would have worse gradient flow, the opposite of what you want.
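
The variance argument can be checked empirically. The sketch below assumes the idealized setting from the text (independent standard-normal entries). The function name, sample count, and seed are arbitrary choices for illustration.

```python
import math
import random

def dot_product_variance(d_k, n_samples=2000, seed=0):
    """Empirical variance of q.k when q and k have i.i.d. N(0, 1) entries."""
    rng = random.Random(seed)
    dots = []
    for _ in range(n_samples):
        q = [rng.gauss(0.0, 1.0) for _ in range(d_k)]
        k = [rng.gauss(0.0, 1.0) for _ in range(d_k)]
        dots.append(sum(a * b for a, b in zip(q, k)))
    mean = sum(dots) / n_samples
    return sum((x - mean) ** 2 for x in dots) / n_samples

for d_k in (4, 64):
    raw = dot_product_variance(d_k)
    scaled = raw / d_k              # Var(x / sqrt(d_k)) = Var(x) / d_k
    print(d_k, round(raw, 1), round(scaled, 2))
```

The raw variance should land near d_k and the scaled variance near 1, up to sampling noise.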

Step 3: Softmax

$$A = \mathrm{softmax}(S) \qquad [T, T],\ \text{each row sums to 1}$$

The softmax converts raw scores into a probability distribution. Row i of A gives the attention weights for token i: how much it attends to each token in the sequence (itself included). These weights are non-negative and sum to 1.

Step 4: Weighted sum

$$\mathrm{Output} = A V \qquad [T, T] \times [T, d_v] = [T, d_v]$$

Finally, each token's output is a weighted combination of all value vectors, where the weights come from the attention matrix. Token i's output is ∑_j A_ij V_j, a blend of all values, weighted by attention.

Why dot product as the similarity function?

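The paper compares the two most commonly used attention functions, additive and dot-product (multiplicative) attention. The table below also lists the bilinear form q^T W k for comparison:

| Method | Formula | Complexity | Performance |
|---|---|---|---|
| Dot product | q^T k | O(d) | Fast and space-efficient; needs scaling at large d |
| Additive | v^T tanh(W_1 q + W_2 k) | O(d) | Similar at small d; better than unscaled dot product at large d |
| Multiplicative (bilinear) | q^T W k | O(d²) | More parameters, marginal gains |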

The dot product is preferred because it can be computed as a single matrix multiplication (QKT), which is extremely efficient on GPUs. The additive method requires two separate projections and a nonlinearity, making it slower despite having similar theoretical expressiveness.

Vaswani et al. noted that for large d_k, the dot product without scaling performs worse than the additive method, because the dot products grow with d_k and push the softmax into saturation. The √d_k scaling factor counteracts this effect.
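
For a single query-key pair, the scoring functions discussed above can be written out directly. This is a plain-Python sketch rather than an efficient implementation. The weight matrices `W1`, `W2`, `W` and the vector `v` are hypothetical parameters that a real model would learn.

```python
import math

def dot_score(q, k):
    """Dot-product score: q^T k."""
    return sum(a * b for a, b in zip(q, k))

def scaled_dot_score(q, k):
    """Scaled dot-product score: q^T k / sqrt(d_k)."""
    return dot_score(q, k) / math.sqrt(len(q))

def matvec(W, x):
    """Multiply a matrix (list of rows) by a vector."""
    return [sum(w * xi for w, xi in zip(row, x)) for row in W]

def additive_score(q, k, W1, W2, v):
    """Additive score: v^T tanh(W1 q + W2 k)."""
    hidden = [math.tanh(a + b) for a, b in zip(matvec(W1, q), matvec(W2, k))]
    return dot_score(v, hidden)

def bilinear_score(q, k, W):
    """Bilinear score: q^T W k."""
    return dot_score(q, matvec(W, k))

q, k = [1.0, 2.0], [0.5, -1.0]
identity = [[1.0, 0.0], [0.0, 1.0]]
print(dot_score(q, k))
print(scaled_dot_score(q, k))
print(bilinear_score(q, k, identity))      # equals dot_score when W is the identity
print(additive_score(q, k, identity, identity, [1.0, 1.0]))
```

Only the dot-product forms reduce to one matrix multiplication over all pairs, which is the efficiency argument made above.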

Attention as information retrieval

There's a deep connection between attention and information retrieval. In a search engine, you have a query and a database of documents. Each document has a key (its content) and a value (the information you want). The search engine computes similarity between the query and all keys, then returns the most relevant values.

Self-attention does exactly this, but differentiably and within a single sequence. Each token is simultaneously a query (looking for information), a key (advertising its relevance), and a value (providing information). The O(T²) computation is the cost of comparing every query against every key, the same cost as an exhaustive search over the sequence.

A key realization: Self-attention computes a complete bipartite graph between all positions in a single matrix multiplication. Every token gets a direct information channel to every other token. This is why the gradient path is O(1) — the gradient can flow directly from any output to any input through the attention weights, without passing through intermediate hidden states.
Scaled Dot-Product Attention: Step Through

Walk through the four steps of scaled dot-product attention on a 4-token sequence: (1) Q K^T scores, (2) scale by √d_k, (3) softmax, (4) multiply by V. Watch the tensor shapes transform at each step.

```python
import torch
import torch.nn.functional as F
import math

def scaled_dot_product_attention(Q, K, V, mask=None):
    """
    Q: [batch, seq_len, d_k]
    K: [batch, seq_len, d_k]
    V: [batch, seq_len, d_v]
    Returns: output [batch, seq_len, d_v] and attention weights [batch, seq_len, seq_len]
    """
    d_k = Q.size(-1)

    # Step 1: Raw attention scores
    scores = torch.matmul(Q, K.transpose(-2, -1))  # [B, T, T]

    # Step 2: Scale by sqrt(d_k)
    scores = scores / math.sqrt(d_k)

    # Optional: apply mask (for decoder self-attention)
    if mask is not None:
        scores = scores.masked_fill(mask == 0, -1e9)

    # Step 3: Softmax over keys
    attn_weights = F.softmax(scores, dim=-1)  # [B, T, T]

    # Step 4: Weighted sum of values
    output = torch.matmul(attn_weights, V)  # [B, T, d_v]

    return output, attn_weights

# Example: 4 tokens, d_k = d_v = 8
T, d_k = 4, 8
Q = torch.randn(1, T, d_k)
K = torch.randn(1, T, d_k)
V = torch.randn(1, T, d_k)

out, weights = scaled_dot_product_attention(Q, K, V)
# out.shape: [1, 4, 8]
# weights.shape: [1, 4, 4], each row sums to 1
```

Why does scaled dot-product attention divide the scores by √dk?

Chapter 3: Multi-Head Attention

A single attention head computes one set of attention weights — one way of looking at the relationships between tokens. But language has many types of relationships simultaneously: syntactic (subject-verb agreement), semantic (word meaning similarity), positional (nearby words), and more.

Multi-head attention runs multiple attention heads in parallel, each with its own learned W_Q, W_K, W_V projections. Each head can learn to attend to different types of relationships. The outputs of all heads are concatenated and projected back to the model dimension.

$$\mathrm{MultiHead}(Q, K, V) = \mathrm{Concat}(\mathrm{head}_1, \ldots, \mathrm{head}_h)\, W_O$$
$$\mathrm{head}_i = \mathrm{Attention}(Q W_{Q,i},\ K W_{K,i},\ V W_{V,i})$$

In the original Transformer:

| Parameter | Value / shape | Description |
|---|---|---|
| d_model | 512 | Model dimension |
| h | 8 | Number of attention heads |
| d_k = d_v | 64 | d_model/h per head |
| W_Q,i and W_K,i | [512, 64] | Per-head Q/K projection |
| W_V,i | [512, 64] | Per-head V projection |
| W_O | [512, 512] | Output projection (concat of 8 × 64 = 512) |

Why multiple heads? A single attention head can only compute a single pattern of attention weights. But when processing "The cat sat on the mat," the word "sat" needs to simultaneously attend to "cat" (who sat?), "on" (sat where?), and "mat" (on what?). Multiple heads let the model maintain several independent "attention lenses" — one head might focus on syntactic relationships, another on semantic similarity, another on proximity.
Multi-Head Attention Visualization

Four attention heads processing a 5-token sequence. Each head learns different attention patterns. Notice how different heads focus on different relationships.

The cost of multi-head attention

Here's a beautiful efficiency insight: multi-head attention with h heads of dimension d_k = d_model/h has roughly the same total computation as single-head attention with dimension d_model. We're not adding heads on top of single-head attention; we're splitting the same computation into h parallel streams.

Single head: one [d, d] projection each for Q, K, V = 3d² parameters.

Multi-head (h heads): 3 × h projections of shape [d, d/h] + one [d, d] output projection = 3d² + d² = 4d² parameters.

The extra cost is just the output projection W_O. In return, we get h independent attention patterns instead of one.
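
The parameter arithmetic above can be verified for any number of heads. This sketch counts weights only (biases are ignored, which is our simplifying assumption), and the function name is illustrative.

```python
def attention_projection_params(d_model, n_heads):
    """Weight counts (biases ignored) for the multi-head attention projections."""
    d_head = d_model // n_heads
    per_head_qkv = 3 * d_model * d_head      # W_Q, W_K, W_V for one head
    qkv = n_heads * per_head_qkv             # 3 * d_model^2 regardless of n_heads
    output = d_model * d_model               # W_O
    return qkv, output

for h in (1, 8):
    qkv, out = attention_projection_params(512, h)
    print(h, qkv, out, qkv + out)
```

Both settings give the same totals, which is the sense in which the heads split the computation rather than add to it.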

What happens inside each head

Each head operates on a lower-dimensional space (d_k = 64 instead of d_model = 512). This might seem like a limitation, but it can be viewed as a feature. The paper's argument is that a single head has to average everything into one set of attention weights, which inhibits attending to different representation subspaces at different positions; a 64-dimensional head can instead specialize in one kind of relationship.

Think of it this way: with a single 512-dimensional head, the model must use one set of attention weights for everything. With 8 heads of 64 dimensions each, each head gets its own 64-dimensional "subspace" to specialize in. One head might project tokens into a subspace where syntactic relationships are prominent; another might project into a subspace where semantic relationships dominate.

An analogy: Multi-head attention is like having 8 different pairs of glasses, each with a different prescription. Through one pair, you see syntactic structure clearly. Through another, you see semantic similarity. Through a third, you see positional relationships. The final output combines all 8 views into a comprehensive understanding.

Head pruning and redundancy

Research after this paper (Michel et al., 2019) found that many attention heads can be removed without significantly hurting performance. Some heads are highly redundant. But the most important heads are critical — removing them causes large performance drops. This suggests that multi-head attention has built-in redundancy, which improves robustness.

```python
import torch
import torch.nn as nn
import math

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model=512, n_heads=8):
        super().__init__()
        self.d_k = d_model // n_heads  # 64
        self.n_heads = n_heads
        # One big projection for all heads (efficient)
        self.W_qkv = nn.Linear(d_model, 3 * d_model)
        self.W_o   = nn.Linear(d_model, d_model)

    def forward(self, x, mask=None):
        B, T, D = x.shape
        # Project to Q, K, V and split into heads
        qkv = self.W_qkv(x)  # [B, T, 3*D]
        qkv = qkv.reshape(B, T, 3, self.n_heads, self.d_k)
        qkv = qkv.permute(2, 0, 3, 1, 4)  # [3, B, h, T, d_k]
        Q, K, V = qkv[0], qkv[1], qkv[2]

        # Scaled dot-product attention per head
        scores = Q @ K.transpose(-2, -1) / math.sqrt(self.d_k)
        if mask is not None:
            # mask must broadcast to [B, h, T, T]
            scores = scores.masked_fill(mask == 0, -1e9)
        attn = scores.softmax(dim=-1)      # [B, h, T, T]
        out  = attn @ V                      # [B, h, T, d_k]

        # Concatenate heads and project
        out = out.transpose(1, 2).reshape(B, T, D)  # [B, T, D]
        return self.W_o(out)  # [B, T, D]
```

Why does the Transformer use multiple attention heads instead of a single large attention head?

Chapter 4: Positional Encoding

Self-attention has a curious property: it's permutation-equivariant. If you shuffle the order of the input tokens, the outputs are shuffled in exactly the same way and are otherwise unchanged, so the mechanism has no inherent notion of order. Without position information, "The cat chased the dog" and "The dog chased the cat" would produce the same attention weights between the same pairs of words (same Q-K dot products, just in different rows and columns).
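
This property can be checked numerically. The sketch below uses single-head self-attention with identity projections (Q = K = V = X), an assumption made to keep it short; learned W_Q, W_K, W_V act on each token separately, so the same behavior holds with them. The toy embeddings and the permutation are arbitrary.

```python
import math

def softmax(row):
    """Numerically stable softmax over a list of scores."""
    m = max(row)
    exps = [math.exp(x - m) for x in row]
    total = sum(exps)
    return [e / total for e in exps]

def self_attention(X):
    """Single-head self-attention with identity projections (Q = K = V = X)."""
    d = len(X[0])
    out = []
    for q in X:
        scores = [sum(a * b for a, b in zip(q, k)) / math.sqrt(d) for k in X]
        weights = softmax(scores)
        out.append([sum(w * v[j] for w, v in zip(weights, X)) for j in range(d)])
    return out

X = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]   # three toy token embeddings
perm = [2, 0, 1]                            # a reordering of the tokens
X_shuffled = [X[i] for i in perm]

Y = self_attention(X)
Y_shuffled = self_attention(X_shuffled)

# Row i of the shuffled output should equal row perm[i] of the original output.
same = all(
    abs(a - b) < 1e-9
    for i, p in enumerate(perm)
    for a, b in zip(Y_shuffled[i], Y[p])
)
print(same)  # prints True
```

Nothing in the computation refers to a token's index, which is why position has to be supplied separately.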

But word order obviously matters! We need some way to tell the model about position. The Transformer's solution: add a positional encoding to each token's embedding before feeding it into the attention layers.

$$\mathrm{input}_i = \mathrm{embedding}_i + \mathrm{positional\_encoding}_i$$

Vaswani et al. used sinusoidal positional encodings — a fixed function of position, not learned parameters:

$$\mathrm{PE}(pos, 2i) = \sin\!\left(\frac{pos}{10000^{2i/d_{\text{model}}}}\right)$$
$$\mathrm{PE}(pos, 2i+1) = \cos\!\left(\frac{pos}{10000^{2i/d_{\text{model}}}}\right)$$

Where pos is the position in the sequence, i indexes the dimension pairs, and d_model is the embedding dimension. Each pair uses a sinusoid of a different frequency. High-frequency sinusoids (small i) encode fine position; low-frequency sinusoids (large i) encode coarse position.
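
To see how the frequency varies across dimensions, we can compute the wavelength (positions per full cycle) of a few dimension pairs directly from the formula. The helper name and the sampled indices are our own choices for illustration.

```python
import math

def wavelength(i, d_model):
    """Positions per full cycle for dimension pair (2i, 2i+1)."""
    return 2 * math.pi * 10000 ** (2 * i / d_model)

d_model = 512
for i in (0, 64, 128, 255):
    print(i, round(wavelength(i, d_model), 1))
```

The wavelengths form a geometric progression that starts at 2π for i = 0 and approaches 10000 · 2π for the last pair.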

Why sinusoids? Three crucial properties:
1. Unique: Each position gets a unique encoding vector — no two positions have the same PE.
2. Relative position via linear transformation: The PE at position pos+k can be expressed as a linear function of PE at position pos. This means the model can learn to attend to relative positions (e.g., "the word 3 positions to my left") through the learned Q/K projections.
3. Bounded: All values are in [-1, 1], so they don't dominate the learned embeddings.
Positional Encoding Heatmap

Each row is a position (0-31), each column is a dimension. Color represents the PE value (warm = positive, teal = negative). Notice how lower dimensions oscillate faster (fine position) while higher dimensions oscillate slower (coarse position).

Learned vs sinusoidal

The paper also experimented with learned positional embeddings — treating the position encoding as a learnable parameter matrix of shape [max_positions, d_model]. They found that learned and sinusoidal encodings performed nearly identically on translation tasks.

In practice, many later Transformers use learned positional embeddings (GPT-2, BERT) or more advanced schemes like rotary positional encoding (RoPE, used in LLaMA). RoPE encodes relative position directly into the Q/K dot product rather than adding a position vector to the embedding.

The relative position property (detailed)

The sinusoidal PE has a remarkable mathematical property. For any fixed offset k, there exists a linear transformation Mk such that:

$$\mathrm{PE}(pos + k) = M_k \cdot \mathrm{PE}(pos)$$

This M_k is a block-diagonal matrix of 2×2 rotation matrices. Each pair of dimensions (2i, 2i+1) rotates by the angle k / 10000^(2i/d_model). This means the model can learn to attend to relative positions through a linear operation: the W_Q and W_K projections can encode "attend to the token k positions to my left" as a simple matrix multiply.

To verify this, note that the PE uses sin and cos at the same frequency for each pair of dimensions. With the pair ordered as (sin, cos), the rotation block for offset k at frequency ω is:

$$M_k^{(\omega)} = \begin{bmatrix} \cos(k\omega) & \sin(k\omega) \\ -\sin(k\omega) & \cos(k\omega) \end{bmatrix}$$

By the angle-addition identities, multiplying (sin(ω·pos), cos(ω·pos)) by this block gives (sin(ω(pos + k)), cos(ω(pos + k))), which is exactly a shift of the position by k. The model can learn W_Q and W_K that implement this rotation for any k.
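
The relative-position property can be checked numerically for one dimension pair. This sketch builds the 2×2 block from the angle-addition identities; note that where the minus sign sits depends on whether the pair is ordered (sin, cos) or (cos, sin), and the code uses (sin, cos). The frequency, position, and offset are arbitrary example values.

```python
import math

def pe_pair(pos, omega):
    """The (sin, cos) pair of the encoding at one frequency."""
    return (math.sin(omega * pos), math.cos(omega * pos))

def shift_matrix(k, omega):
    """2x2 block mapping pe_pair(pos) to pe_pair(pos + k), for (sin, cos) ordering."""
    c, s = math.cos(k * omega), math.sin(k * omega)
    return ((c, s), (-s, c))

def apply(M, v):
    """Multiply a 2x2 matrix by a 2-vector."""
    return (M[0][0] * v[0] + M[0][1] * v[1],
            M[1][0] * v[0] + M[1][1] * v[1])

omega = 1.0 / 10000 ** (2 * 3 / 64)    # frequency of pair i = 3 when d_model = 64
pos, k = 17, 5
shifted = apply(shift_matrix(k, omega), pe_pair(pos, omega))
target = pe_pair(pos + k, omega)
print(shifted)
print(target)   # agrees with shifted up to floating-point rounding
```

The matrix depends only on the offset k and the frequency, never on pos, which is what lets one linear map express "k positions away" everywhere in the sequence.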

Visualizing position similarity

One way to verify the PE works is to compute dot products between position encodings. Positions that are close should have similar encodings (high dot product) while distant positions should be less similar:

```python
# Position similarity via PE dot products
# Requires sinusoidal_pe(max_len, d_model), defined in the code listing
# later in this chapter; run that definition first.
import torch

pe = sinusoidal_pe(128, 64)
# Compute pairwise cosine similarities
pe_norm = pe / pe.norm(dim=1, keepdim=True)
sim_matrix = pe_norm @ pe_norm.T

# Positions 0 and 1 are similar (nearby)
print(f"sim(0,1)  = {sim_matrix[0,1]:.3f}")   # close to 1
# Positions 0 and 50 are less similar
print(f"sim(0,50) = {sim_matrix[0,50]:.3f}")  # noticeably lower
# Positions 10 and 11 have the same similarity as 0 and 1
print(f"sim(10,11)= {sim_matrix[10,11]:.3f}") # matches sim(0,1): depends only on the offset
```

| PE Type | Used By | Pros | Cons |
|---|---|---|---|
| Sinusoidal (fixed) | Original Transformer | No parameters, defined for any length | Not as expressive as learned |
| Learned | GPT-2, BERT | Can learn arbitrary patterns | Fixed max length, more parameters |
| RoPE | LLaMA | Relative position, length generalization | Slightly more complex implementation |
| ALiBi | MPT, BLOOM | Very simple, good length generalization | Linear bias can limit expressivity |

```python
import torch
import math

def sinusoidal_pe(max_len, d_model):
    """Create sinusoidal positional encoding matrix (d_model must be even)."""
    pe = torch.zeros(max_len, d_model)
    position = torch.arange(0, max_len).unsqueeze(1).float()
    div_term = torch.exp(
        torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)
    )
    pe[:, 0::2] = torch.sin(position * div_term)  # even dims
    pe[:, 1::2] = torch.cos(position * div_term)  # odd dims
    return pe  # [max_len, d_model]

# Example: 128 positions, d_model = 64
pe = sinusoidal_pe(128, 64)
print(pe.shape)  # torch.Size([128, 64])
print(pe[0, :4])  # tensor([0., 1., 0., 1.]) since sin(0)=0, cos(0)=1
```

Why does the Transformer need positional encoding?

Chapter 5: The Encoder Block

We now have all the pieces for the encoder. A Transformer encoder block combines multi-head attention with a feedforward network, connected by residual connections and layer normalization.

Input X: Token embeddings + positional encoding [T, d_model]
Multi-Head Self-Attention: Each token attends to all tokens in the sequence
↓ + residual → LayerNorm
Feed-Forward Network: Two linear layers with ReLU: d_model → d_ff → d_model
↓ + residual → LayerNorm
Output: Contextualized representations [T, d_model]

Each sub-layer (attention, FFN) is wrapped in a residual connection and layer norm:

output = LayerNorm(x + SubLayer(x))

Why residual connections?

Residual connections greatly ease the vanishing gradient problem for deep networks. The Jacobian of a residual layer x + F(x) is I + J_F, where I is the identity matrix and J_F is the Jacobian of the sub-layer. Even if J_F is small (vanishing sublayer gradient), the total Jacobian is I plus something small, which has eigenvalues near 1. The gradient flows freely through the identity path.
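
A scalar caricature makes the difference visible. The sketch below treats each sub-layer Jacobian as a single number (0.1 here, a hypothetical value) and multiplies them through a stack, with and without the identity path. Real Jacobians are matrices, so this is only an illustration of the I + J_F argument.

```python
def chain_gradient(sublayer_jacobians, residual):
    """Product of per-layer scalar Jacobians, with or without the identity path."""
    grad = 1.0
    for j in sublayer_jacobians:
        grad *= (1.0 + j) if residual else j
    return grad

jacobians = [0.1] * 12      # 12 sub-layers, each with a small Jacobian J_F
print(chain_gradient(jacobians, residual=False))   # about 1e-12: the signal vanishes
print(chain_gradient(jacobians, residual=True))    # roughly 3.1: the signal survives
```

Twelve sub-layers corresponds to a 6-block encoder, since each block has an attention sub-layer and an FFN sub-layer.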

Without residual connections, a 6-layer encoder would suffer gradient decay similar to an RNN over 6 steps. With residual connections, the gradient has a "highway" that bypasses each sub-layer, ensuring it reaches all layers with minimal attenuation. This is exactly the same principle as the LSTM's constant error carousel — additive connections preserve gradient flow.

The original Transformer paper (2017) used 6 encoder and 6 decoder layers. Modern Transformers use 32-96 layers (GPT-3: 96, LLaMA-70B: 80). This massive depth is practical largely because of residual connections. Without them, training a 96-layer network would be extremely difficult due to vanishing gradients.

Why layer normalization?

Layer normalization (LayerNorm) normalizes the activations across the feature dimension (the dmodel dimension), stabilizing training by preventing the internal representations from drifting to extreme values. Unlike batch normalization (which normalizes across the batch dimension), LayerNorm works on each example independently. This is essential for two reasons:

First, sequences have variable lengths, so batch statistics would be computed over different numbers of tokens in different batches — making normalization inconsistent. Second, at inference time (generating text token by token), there's no "batch" to compute statistics over. LayerNorm avoids both issues by normalizing each token's representation independently.

$$\mathrm{LayerNorm}(x) = \gamma \cdot \frac{x - \mu}{\sqrt{\sigma^2 + \varepsilon}} + \beta$$

Where μ and σ are the mean and standard deviation computed across the d_model dimension (not across the batch or sequence dimensions), ε is a small constant for numerical stability, and γ, β are learned scale and shift parameters of shape [d_model]. For d_model = 512, LayerNorm adds only 1024 parameters (512 for γ, 512 for β), negligible compared to the attention and FFN weights.
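
The formula is short enough to implement by hand for a single token vector. This is a plain-Python sketch; the `eps` default of 1e-5 is an assumption (a common choice), and the 4-dimensional input is a toy example.

```python
import math

def layer_norm(x, gamma, beta, eps=1e-5):
    """Normalize one token's feature vector, then scale and shift."""
    d = len(x)
    mu = sum(x) / d
    var = sum((v - mu) ** 2 for v in x) / d        # biased variance over features
    return [g * (v - mu) / math.sqrt(var + eps) + b
            for v, g, b in zip(x, gamma, beta)]

x = [2.0, 4.0, 6.0, 8.0]
y = layer_norm(x, gamma=[1.0] * 4, beta=[0.0] * 4)
print([round(v, 3) for v in y])   # [-1.342, -0.447, 0.447, 1.342]
```

With γ = 1 and β = 0 the output has mean 0 and variance close to 1, computed from this one token alone, with no batch statistics involved.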

LayerNorm serves two purposes: (1) it stabilizes the forward pass by keeping activations in a normalized range, preventing the "internal covariate shift" that makes training unstable, and (2) it stabilizes the backward pass by preventing gradient magnitudes from drifting, which is particularly important for deep Transformer stacks.

The feedforward network

The FFN is applied independently to each position — there is no cross-position interaction in the FFN. This is a deliberate separation of concerns: attention handles inter-token communication (which tokens should influence each other), while the FFN handles per-token processing (what to do with the information once gathered). This clean separation makes each component easier to understand and optimize.

Interestingly, later research has argued that the FFN layers act as "key-value memories" (Geva et al., 2021). With the convention FFN(x) = ReLU(x W_1 + b_1) W_2 + b_2, each column of the first weight matrix W_1 acts as a "key" for a particular pattern, and the corresponding row of W_2 provides the "value" to be added. When the input matches a key pattern, the ReLU activates and the corresponding value is added to the representation. This offers one explanation for how FFN layers can store factual knowledge (e.g., "Paris is the capital of France").

The FFN equation is:

$$\mathrm{FFN}(x) = \mathrm{ReLU}(x W_1 + b_1)\, W_2 + b_2$$

The inner dimension d_ff = 2048 is 4x larger than d_model = 512. This expansion-compression pattern is crucial: the FFN serves as a "memory" where the model stores and retrieves factual knowledge. The 4x expansion gives it capacity to encode many patterns; the compression forces it to select only the relevant ones. In modern Transformers, the ratio is often closer to 8/3x (e.g., LLaMA uses d_ff ≈ 2.67 × d_model with the SwiGLU activation, which has three weight matrices instead of two).
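
One way to read the 8/3 ratio is as a parameter-matching choice: three matrices at 8/3 × d_model hold as many weights as two matrices at 4 × d_model. The sketch below checks that arithmetic. The value d_model = 768 is a hypothetical choice so that 8/3 × d_model is an integer, and the names "gate, up, down" for the three SwiGLU matrices follow common usage rather than this text.

```python
def ffn_weights(d_model, d_ff, n_matrices):
    """Weight count (biases ignored) for an FFN built from n_matrices of size d_model x d_ff."""
    return n_matrices * d_model * d_ff

d_model = 768                                                        # hypothetical size
relu_ffn = ffn_weights(d_model, 4 * d_model, n_matrices=2)           # W_1, W_2
swiglu_ffn = ffn_weights(d_model, 8 * d_model // 3, n_matrices=3)    # gate, up, down
print(relu_ffn, swiglu_ffn)   # both equal 8 * d_model**2 = 4718592
```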

The Encoder Block

Data flows through a complete encoder block. The residual connections (dashed lines) allow gradients to bypass each sub-layer.

```python
import torch
import torch.nn as nn

# Uses the MultiHeadAttention class defined in Chapter 3.
class EncoderBlock(nn.Module):
    def __init__(self, d_model=512, n_heads=8, d_ff=2048, dropout=0.1):
        super().__init__()
        self.attention = MultiHeadAttention(d_model, n_heads)
        self.norm1 = nn.LayerNorm(d_model)
        self.ffn = nn.Sequential(
            nn.Linear(d_model, d_ff),     # expand 512 → 2048
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(d_ff, d_model),     # compress 2048 → 512
        )
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        # Sub-layer 1: Multi-head self-attention + residual + norm
        attn_out = self.attention(x)     # [B, T, D]
        x = self.norm1(x + self.dropout(attn_out))

        # Sub-layer 2: FFN + residual + norm
        ffn_out = self.ffn(x)            # [B, T, D]
        x = self.norm2(x + self.dropout(ffn_out))

        return x  # [B, T, D]
```

The Transformer encoder stack: The original Transformer uses N=6 encoder blocks stacked on top of each other. The output of each block feeds into the next. Thanks to residual connections, gradients flow cleanly through all 6 blocks. Each block refines the representations — early blocks capture local patterns, later blocks capture global patterns.

Parameter count breakdown

Let's count exactly how many parameters are in one encoder block (dmodel = 512, dff = 2048, h = 8):

Component | Parameters | Count
WQ, WK, WV | 3 × (512 × 512) | 786,432
WO | 512 × 512 | 262,144
LayerNorm 1 (γ, β) | 2 × 512 | 1,024
FFN W1 | 512 × 2048 | 1,048,576
FFN W2 | 2048 × 512 | 1,048,576
FFN biases | 2048 + 512 | 2,560
LayerNorm 2 | 2 × 512 | 1,024
Total per block | | ~3.15M
6 blocks total | | ~18.9M

The FFN accounts for about 2/3 of the parameters! This is why scaling papers often focus on the FFN dimension. Increasing dff adds capacity cheaply (no quadratic attention cost).
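
The totals above can be reproduced with a few lines of arithmetic. This sketch follows the same accounting as the table (attention projections without biases, FFN with biases). Note that the number of heads does not appear: splitting 512 dimensions into 8 heads does not change the size of the projection matrices.

```python
d_model, d_ff = 512, 2048

attention = 4 * d_model * d_model           # W_Q, W_K, W_V, W_O (no biases, as in the table)
layer_norms = 2 * (2 * d_model)             # gamma and beta for each of the two LayerNorms
ffn = 2 * d_model * d_ff + d_ff + d_model   # W1, W2 and their biases

per_block = attention + layer_norms + ffn
ffn_share = ffn / per_block

print(per_block, 6 * per_block, f"{ffn_share:.1%}")
# 3150336 18902016 66.7%
```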

Pre-norm vs post-norm

The original Transformer uses post-norm: LayerNorm is applied after the residual addition. Later work found that pre-norm (applying LayerNorm before the sub-layer) is more stable for training deep Transformers:

Post-norm (original): output = LayerNorm(x + SubLayer(x))
Pre-norm (modern): output = x + SubLayer(LayerNorm(x))

Pre-norm puts the normalization inside the residual branch, which means the gradient through the skip connection is completely unmodified — pure identity. The gradient through the residual path is exactly 1.0, regardless of what happens in the sub-layer. This makes training dramatically more stable, especially for deep models (24+ layers). GPT-2, GPT-3, LLaMA, and most modern LLMs use pre-norm.

The difference matters at scale: the original Transformer with post-norm required careful warmup to train 6 layers. With pre-norm, you can train 96 layers (GPT-3) with a straightforward training recipe. The architectural change is a single line of code — moving the LayerNorm before the sub-layer instead of after — but its impact on trainability is profound.
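
The two wirings can be compared side by side in plain Python. In this sketch `sublayer` is a hypothetical stand-in for attention or the FFN, and the LayerNorm omits the learned γ and β for brevity.

```python
import math

def layer_norm(x, eps=1e-5):
    """Normalize one vector to zero mean and unit variance (gamma=1, beta=0)."""
    mean = sum(x) / len(x)
    var = sum((v - mean) ** 2 for v in x) / len(x)
    return [(v - mean) / math.sqrt(var + eps) for v in x]

def sublayer(x):
    """Hypothetical stand-in for attention or the FFN."""
    return [0.5 * v + 1.0 for v in x]

def post_norm(x):
    # Original Transformer: normalize after the residual addition.
    return layer_norm([a + b for a, b in zip(x, sublayer(x))])

def pre_norm(x):
    # GPT-2 style: normalize only the sub-layer input; the skip path is untouched.
    return [a + b for a, b in zip(x, sublayer(layer_norm(x)))]

x = [10.0, -4.0, 2.0, 0.0]
print("post-norm:", [round(v, 3) for v in post_norm(x)])
print("pre-norm: ", [round(v, 3) for v in pre_norm(x)])
```

In the pre-norm output the original x is still present unchanged (the output is x plus a correction), whereas post-norm rescales the sum, including the skip path, at every block.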

What is the purpose of the feedforward network (FFN) in each encoder block?

Chapter 6: The Decoder Block

The Transformer was originally designed for sequence-to-sequence tasks (machine translation). The encoder processes the input sentence; the decoder generates the output sentence. The decoder is more complex than the encoder because it has three sub-layers instead of two.

Masked Self-Attention
Decoder attends to previously generated tokens only (causal mask)
↓ + residual → LayerNorm
Cross-Attention
Q from decoder, K/V from encoder output — the decoder "reads" the input
↓ + residual → LayerNorm
Feed-Forward Network
Same expand-compress pattern as encoder
↓ + residual → LayerNorm
Output
Contextualized decoder representations [Tdec, dmodel]

Masked self-attention

In the encoder, every token can attend to every other token (bidirectional). In the decoder, this would be cheating — position i shouldn't be able to see positions i+1, i+2, ... because those tokens haven't been generated yet. The causal mask prevents this by setting future positions to -∞ before the softmax:

mask[i][j] = 0 if j ≤ i, else -∞

After softmax, positions with -∞ get attention weight exactly 0 (since e^(-∞) = 0). This ensures the model is autoregressive: each token can only depend on itself and the tokens that came before it. This is the same constraint as in an RNN (where ht can only depend on x1, ..., xt), but implemented through masking rather than sequential processing.

Viewed as a matrix of allowed positions, the causal mask is lower-triangular: 1 where attention is permitted (j ≤ i, which corresponds to adding 0) and 0 where it is blocked (which corresponds to adding -∞). Token 0 can only see itself. Token 1 can see tokens 0 and 1. Token T-1 can see all tokens. This creates a "cone" of visibility that grows with position, just like an RNN, where later hidden states have access to more of the past.

An important practical detail: the mask is applied before the softmax, by adding -∞ to masked positions. Zeroing out weights after the softmax is not a drop-in substitute. Without renormalization, the remaining weights no longer sum to 1. With renormalization, the result is mathematically the same but less numerically stable, because a large score at a masked position can dominate the softmax and push the visible weights toward underflow. The -∞ approach ensures that masked positions contribute exactly zero to both the forward pass and the backward pass.
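
Here is a minimal sketch of the additive mask in plain Python. The scores are arbitrary illustrative numbers. The point is only to show that adding -∞ before the softmax yields exact zeros for future positions while every row still sums to 1.

```python
import math

NEG_INF = float("-inf")

def softmax(row):
    m = max(row)                                # subtract the max for stability
    exps = [math.exp(v - m) for v in row]       # math.exp(-inf) evaluates to 0.0
    total = sum(exps)
    return [e / total for e in exps]

def causal_softmax(scores):
    """Add 0 where j <= i and -inf where j > i, then softmax each row."""
    T = len(scores)
    masked = [[scores[i][j] + (0.0 if j <= i else NEG_INF) for j in range(T)]
              for i in range(T)]
    return [softmax(row) for row in masked]

scores = [[0.2, 1.5, -0.3],
          [1.0, 0.1, 2.0],
          [0.5, 0.5, 0.5]]      # arbitrary attention scores, T = 3

weights = causal_softmax(scores)
for row in weights:
    print([round(w, 3) for w in row])
```

Token 0 puts all of its weight on itself, and the large score 2.0 at position (1, 2) has no influence on row 1 because it is masked before normalization.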

Causal Mask in Decoder Self-Attention

The attention matrix for decoder self-attention. The mask (dark triangular region) prevents tokens from attending to future positions. Token i can only attend to tokens 0 through i. Click "Toggle Mask" to see the difference.


Cross-attention

The second sub-layer is cross-attention: the decoder attends to the encoder's output. This is how the decoder "reads" the input sentence. The queries come from the decoder, but the keys and values come from the encoder:

CrossAttention: Q = decoder, K = encoder, V = encoder

This is the bridge between the encoder and decoder. When translating "The cat sat" to "Le chat assis," the decoder token "chat" generates a query that attends to the encoder token "cat," pulling in the relevant information for translation. The cross-attention mechanism learns which input tokens are relevant for generating each output token — this is the mechanism that handles word reordering, one-to-many translations, and other alignment challenges.

Cross-attention is not masked (unlike decoder self-attention). Every decoder token can attend to every encoder token, because the entire input is available when generating the output. The decoder can look anywhere in the input at any point during generation.
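
The shape bookkeeping of cross-attention is easy to get wrong, so here is a single-head sketch in plain Python. It is deliberately simplified: the learned projections W_Q, W_K, W_V are omitted, the vectors are toy values, and the function name `cross_attention` is just an illustrative choice.

```python
import math

def softmax(row):
    m = max(row)
    exps = [math.exp(v - m) for v in row]
    total = sum(exps)
    return [e / total for e in exps]

def cross_attention(decoder_states, encoder_states):
    """Single-head attention with Q from the decoder and K = V from the encoder."""
    d_k = len(encoder_states[0])
    outputs, all_weights = [], []
    for q in decoder_states:                        # one query per decoder position
        scores = [sum(a * b for a, b in zip(q, k)) / math.sqrt(d_k)
                  for k in encoder_states]          # compare with every encoder key
        weights = softmax(scores)                   # no mask: all encoder positions visible
        out = [sum(w * v[i] for w, v in zip(weights, encoder_states))
               for i in range(d_k)]
        outputs.append(out)
        all_weights.append(weights)
    return outputs, all_weights

encoder_out = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]   # T_enc = 3
decoder_x = [[2.0, 0.0], [0.0, 2.0]]                  # T_dec = 2

out, attn = cross_attention(decoder_x, encoder_out)
print(len(out), len(attn[0]))   # T_dec output vectors, each attending over T_enc positions
```

The attention matrix here is rectangular, T_dec × T_enc, which is the main visible difference from self-attention.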

Why three sub-layers instead of two?

The encoder has two sub-layers (self-attention + FFN). The decoder adds a third (cross-attention) because it needs to do something the encoder doesn't: read the input. The encoder processes the input independently; the decoder must condition its output on the encoder's representation.

The information flow is:

Masked self-attention
"What have I generated so far?" (internal coherence)
Cross-attention
"What does the input say?" (translation/conditioning)
FFN
"Given both, what comes next?" (generation)

Each sub-layer serves a distinct role. Removing any one significantly hurts performance. The masked self-attention ensures the output is coherent; the cross-attention ensures it's faithful to the input; the FFN provides the capacity for complex reasoning.

Decoder-only models (GPT)

An important simplification: if you don't need an encoder-decoder (no separate input/output like translation), you can use a decoder-only architecture. This keeps only the masked self-attention + FFN, removing cross-attention entirely. The input and output are concatenated into a single sequence processed autoregressively.

This is what GPT, Claude, LLaMA, and most modern LLMs use. In practice it has proven simpler to train at scale. The "encoder" is implicit: the model processes the input prompt and generated output as a single sequence with causal masking. Each prompt token attends to the prompt tokens before it (playing the role of an encoder, though causally rather than bidirectionally), while the generated tokens attend to both the prompt and previously generated tokens (acting like a decoder).

The key advantage of decoder-only: you only need one type of attention (causal self-attention), one stack of blocks, and one training objective (next-token prediction). This simplicity enables cleaner scaling — fewer hyperparameters, fewer failure modes, and more straightforward parallelization across thousands of GPUs.

python
# Decoder-only Transformer (GPT-style)
class GPTBlock(nn.Module):
    def __init__(self, d_model, n_heads, d_ff):
        super().__init__()
        self.attn = MultiHeadAttention(d_model, n_heads)
        self.ffn = nn.Sequential(
            nn.Linear(d_model, d_ff), nn.GELU(),  # GPT uses GELU, not ReLU
            nn.Linear(d_ff, d_model))
        self.ln1 = nn.LayerNorm(d_model)
        self.ln2 = nn.LayerNorm(d_model)

    def forward(self, x, mask):
        # Pre-norm (GPT-2 style) not post-norm
        x = x + self.attn(self.ln1(x), mask=mask)
        x = x + self.ffn(self.ln2(x))
        return x
Encoder vs decoder attention patterns:
Encoder self-attention: Bidirectional. Every input token attends to every other input token. "cat" can see "sat" and vice versa.
Decoder masked self-attention: Unidirectional (causal). Each output token can only attend to previous output tokens. "assis" can see "Le chat" but not future tokens.
Cross-attention: Decoder-to-encoder. Each output token can attend to any input token. "chat" attends to "cat" in the input.
python
class DecoderBlock(nn.Module):
    def __init__(self, d_model=512, n_heads=8, d_ff=2048, dropout=0.1):
        super().__init__()
        # Sub-layer 1: Masked self-attention
        self.self_attn = MultiHeadAttention(d_model, n_heads)
        self.norm1 = nn.LayerNorm(d_model)
        # Sub-layer 2: Cross-attention (decoder queries, encoder keys/values)
        self.cross_attn = MultiHeadAttention(d_model, n_heads)
        self.norm2 = nn.LayerNorm(d_model)
        # Sub-layer 3: FFN
        self.ffn = nn.Sequential(
            nn.Linear(d_model, d_ff), nn.ReLU(),
            nn.Dropout(dropout), nn.Linear(d_ff, d_model))
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, encoder_out, causal_mask):
        # Masked self-attention (decoder looks at itself, causally)
        sa = self.self_attn(x, mask=causal_mask)
        x = self.norm1(x + self.dropout(sa))

        # Cross-attention (decoder queries, encoder K/V)
        # Intended computation: Q = x, K = V = encoder_out.
        # Simplified: this call does not yet use encoder_out. A real
        # implementation needs an attention module that takes K/V separately.
        ca = self.cross_attn(x)
        x = self.norm2(x + self.dropout(ca))

        # FFN
        ff = self.ffn(x)
        x = self.norm3(x + self.dropout(ff))
        return x

# Create causal mask: lower triangular matrix
def causal_mask(T):
    return torch.tril(torch.ones(T, T))  # 1s on and below diagonal
What is the purpose of the causal mask in the decoder's self-attention?

Chapter 7: Training Details

The Transformer's architecture is elegant, but training it requires several careful engineering choices. The paper introduced innovations in learning rate scheduling, regularization, and optimization that became standard practice.

The warmup learning rate schedule

The paper introduced a learning rate schedule that became one of the most widely used in deep learning: linear warmup followed by inverse square root decay. This schedule, often called the "Noam schedule" (after co-author Noam Shazeer), was widely adopted by later Transformer training pipelines, often with modifications (cosine decay instead of inverse square root, different warmup durations).

lr = dmodel^(-0.5) · min(step^(-0.5), step · warmup_steps^(-1.5))

This looks complex but has a simple two-phase behavior:

Phase | Steps | Learning Rate | Why
Warmup | 0 to 4000 | Linearly increases from 0 | A large initial LR would destabilize attention weights before they have settled into meaningful patterns
Decay | 4000 onward | Decays as 1/√step | Gradual reduction for fine convergence
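
Plugging numbers into the formula shows the two phases directly. This sketch uses the base-model settings (dmodel = 512, 4000 warmup steps). The function name `noam_lr` is an illustrative choice.

```python
def noam_lr(step, d_model=512, warmup_steps=4000):
    """lr = d_model^-0.5 * min(step^-0.5, step * warmup_steps^-1.5)."""
    return d_model ** -0.5 * min(step ** -0.5, step * warmup_steps ** -1.5)

for step in (1, 1000, 4000, 16000, 100000):
    print(f"step {step:>6}: lr = {noam_lr(step):.2e}")
```

The two arguments of min meet exactly at step = warmup_steps, where the rate peaks at roughly 7 × 10^-4. Four times as many steps later it has halved, as expected from the 1/√step decay.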
Learning Rate Schedule

The Transformer's "noam" learning rate schedule. Linear warmup for the first few thousand steps, then inverse square root decay. Drag to change the warmup steps.


Label smoothing

Instead of training with hard targets (100% probability on the correct token, 0% on everything else), the paper used label smoothing with ε = 0.1. The motivation: with hard targets, the model is incentivized to make its predictions infinitely confident (push softmax outputs toward 0 and 1). This leads to overfitting and poor calibration. Label smoothing softens the targets:

y_smooth = (1 - ε) · y_hard + ε / V

Where V is the vocabulary size (typically 32K-50K for subword vocabularies). The correct token gets probability 0.9 + 0.1/V (approximately 0.9) instead of 1.0, and every other token gets 0.1/V, so the smoothing mass is spread uniformly over the vocabulary. This prevents the model from becoming overconfident and improves generalization.
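
A toy vocabulary makes the smoothed target concrete. The sketch below applies the formula literally with V = 5 (far smaller than a real vocabulary, chosen so the numbers are readable). The helper name `smooth_targets` is illustrative.

```python
def smooth_targets(correct_index, vocab_size, eps=0.1):
    """y_smooth = (1 - eps) * y_hard + eps / V, returned as a list of probabilities."""
    uniform = eps / vocab_size
    return [(1.0 - eps) + uniform if i == correct_index else uniform
            for i in range(vocab_size)]

targets = smooth_targets(correct_index=2, vocab_size=5)
print([round(p, 3) for p in targets])   # [0.02, 0.02, 0.92, 0.02, 0.02]
```

Because the ε/V term is added to every entry, the correct token ends up slightly above 1 - ε. With a realistic vocabulary of tens of thousands of tokens the difference is negligible.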

Label smoothing has an interesting side effect: it slightly increases perplexity (the model assigns lower probability to the correct token) but improves BLEU score (the model generates better translations). This is because the perplexity metric rewards extreme confidence, while BLEU rewards accurate generation. The model becomes better at the actual task (translation) even though it appears worse on the proxy metric (perplexity).

In modern LLM pre-training, label smoothing is less commonly used (GPT-style models typically omit it). For machine translation and other structured output tasks, it remains common.

Optimization and regularization

Technique | Value | Purpose
Optimizer | Adam (β1 = 0.9, β2 = 0.98, ε = 10^-9) | Adaptive learning rates per parameter
Dropout | Pdrop = 0.1 | Applied to sub-layer outputs, attention weights, embeddings
Label smoothing | ε = 0.1 | Prevents overconfidence; hurts perplexity but improves BLEU
Gradient clipping | Not explicitly mentioned | Commonly used in subsequent Transformer training
Training scale: The base Transformer was trained on 8 NVIDIA P100 GPUs for 12 hours (100K steps, batch size ~25K tokens). The big Transformer was trained for 3.5 days (300K steps). By today's standards, this is tiny — GPT-3 used thousands of GPUs for weeks. But in 2017, it was enough to set state-of-the-art on English-to-German and English-to-French translation.

Why warmup is necessary for Transformers

The warmup phase is more important for Transformers than for other architectures. Here's why: at initialization, the attention weights are essentially random. The softmax distributes attention roughly uniformly across all positions. In this regime, the model needs to learn which positions to attend to before it can learn what to do with the attended information.

If the learning rate is large at initialization, the model makes large updates to the attention weights based on random attention patterns, in effect learning from noise. This can push the attention weights into bad local optima or cause numerical instability (large pre-softmax scores saturate the softmax, so its outputs collapse toward 0 and 1 and its gradients vanish).

The warmup gives the model time to establish stable attention patterns at a low learning rate, then ramps up once those patterns are meaningful.

The Adam optimizer and Transformers

The paper uses Adam with non-standard β2 = 0.98 (instead of the default 0.999). This means the second-moment estimate updates faster, making Adam more responsive to recent gradient magnitudes. This is important because Transformer gradients can change rapidly as attention patterns shift during training.
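
The effect of β2 is easiest to see on a synthetic signal. The sketch below tracks only Adam's second-moment moving average (bias correction and the rest of the optimizer are left out), and the squared-gradient sequence is made up for illustration.

```python
def second_moment_trace(beta2, grads_squared):
    """Exponential moving average v_t = beta2 * v_{t-1} + (1 - beta2) * g_t^2."""
    v = 0.0
    trace = []
    for g2 in grads_squared:
        v = beta2 * v + (1.0 - beta2) * g2
        trace.append(v)
    return trace

# Synthetic signal: the squared gradient jumps from 1.0 to 4.0 at step 200.
signal = [1.0] * 200 + [4.0] * 200

estimates = {}
for beta2 in (0.98, 0.999):
    window = 1.0 / (1.0 - beta2)        # rough averaging horizon in steps
    estimates[beta2] = second_moment_trace(beta2, signal)[-1]
    print(f"beta2={beta2}: horizon ~{window:.0f} steps, "
          f"estimate 200 steps after the jump = {estimates[beta2]:.2f}")
```

With β2 = 0.98 the estimate has essentially caught up with the new gradient scale after 200 steps. With β2 = 0.999 it is still far below it, because the averaging horizon is about 1000 steps.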

Hyperparameter | Paper value | Default Adam | Why different
β1 | 0.9 | 0.9 | Same: standard momentum
β2 | 0.98 | 0.999 | Faster adaptation to changing gradient statistics
ε | 10^-9 | 10^-8 | Smaller than the default; the paper gives no rationale

Dropout patterns

Dropout is applied at three places in the Transformer, each serving a different purpose:

Attention dropout
Applied to attention weights after softmax — randomly zeroes some attention connections, forcing the model to not rely on any single token relationship
Sub-layer dropout
Applied to the output of each sub-layer (attention, FFN) before adding the residual — prevents co-adaptation between sub-layers
Embedding dropout
Applied to the sum of token embedding + positional encoding — regularizes the input representation
python
# Transformer training recipe
import torch
import torch.nn as nn

# Noam learning rate schedule
class NoamScheduler:
    def __init__(self, optimizer, d_model=512, warmup=4000):
        self.optimizer = optimizer
        self.d_model = d_model
        self.warmup = warmup
        self.step_num = 0

    def step(self):
        self.step_num += 1
        lr = self.d_model ** (-0.5) * min(
            self.step_num ** (-0.5),
            self.step_num * self.warmup ** (-1.5)
        )
        for p in self.optimizer.param_groups:
            p['lr'] = lr

# Label smoothing loss
class LabelSmoothingLoss(nn.Module):
    def __init__(self, vocab_size, smoothing=0.1):
        super().__init__()
        self.confidence = 1.0 - smoothing
        self.smoothing = smoothing
        self.vocab_size = vocab_size

    def forward(self, pred, target):
        # pred: [B*T, V], target: [B*T]
        log_probs = pred.log_softmax(dim=-1)
        nll = -log_probs.gather(dim=-1, index=target.unsqueeze(1))
        smooth = -log_probs.sum(dim=-1) / self.vocab_size
        loss = self.confidence * nll.squeeze(1) + self.smoothing * smooth
        return loss.mean()
Why does the Transformer's learning rate schedule start with a warmup phase?

Chapter 8: Attention Explorer

This is the payoff. An interactive visualization where you can see the complete self-attention mechanism in action — from Q, K, V computation through attention weights to the final output. Step through each stage and see exactly how a sentence is processed.

Self-Attention Step-Through

Watch self-attention process a sentence step by step. Select a query token to see its Q vector compared against all K vectors, the resulting attention weights, and the weighted V sum. Click tokens on the left to change the query.

What to explore:
1. Click "sat": notice it attends most strongly to "cat" (who sat?) and "on" (sat where?).
2. Click "The": articles attend broadly since they don't carry specific meaning.
3. Lower the temperature to 0.3: attention becomes sharp (one token dominates). This approaches "hard" attention (argmax).
4. Raise the temperature to 3.0: attention flattens (all tokens get similar weight). This approaches a uniform average.
5. Temperature 1.0 is the standard setting, a balanced mix.

What attention patterns emerge

Researchers have analyzed trained Transformer attention heads and found remarkably interpretable patterns:

Head Type | What It Does | Example
Syntactic | Attends to grammatical dependencies | "sat" strongly attends to "cat" (subject-verb)
Positional | Attends to adjacent tokens | Each token attends to its immediate neighbor
Semantic | Attends to semantically related words | "cat" attends to "mat" (both concrete nouns)
Copy | Attends to the same or similar token | "the" (second occurrence) attends to "the" (first)
Induction | Looks for repeated patterns | If "AB" appeared before, attends to B when A appears again

These patterns are not hardcoded — they emerge from training. Different heads in different layers specialize in different types of attention, creating a rich, multi-faceted understanding of the input.

The computational cost of attention

Self-attention has O(T² · d) complexity, where T is the sequence length and d is the model dimension. The T² comes from computing all pairwise attention scores. For T = 2048 (the context window GPT-3 used in 2020; sentences in the 2017 translation experiments were far shorter), this means about 4 million attention score computations per head per layer.

For comparison:

Operation | Time complexity | Memory
Self-attention | O(T² · d) | O(T²) for attention matrix
FFN | O(T · d · dff) | O(T · dff)
RNN step | O(d²) per step, O(T · d²) total | O(d) per step

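Whether self-attention is more expensive than an RNN depends on the ratio of T to d: O(T² · d) exceeds O(T · d²) only when the sequence is longer than the model dimension (T > d), although self-attention always needs O(T²) memory for the attention matrix. For long sequences, then, the Transformer trades compute and memory for parallelism and gradient flow. Modern hardware (GPUs, TPUs) strongly favors parallel operations, so the Transformer's approach is a net win in practice even when its theoretical cost is higher.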
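
The asymptotic counts in the table can be compared directly. This sketch ignores constant factors and only contrasts the leading terms, T² · d for self-attention against T · d² for a recurrent layer, at a fixed illustrative width.

```python
def attention_ops(T, d):
    return T * T * d        # all pairwise scores over the sequence

def recurrent_ops(T, d):
    return T * d * d        # one d x d matrix-vector product per step

d = 512
for T in (64, 512, 4096):
    ratio = attention_ops(T, d) / recurrent_ops(T, d)   # simplifies to T / d
    print(f"T={T:5d}: attention / recurrent = {ratio:.3f}")
# T=   64: attention / recurrent = 0.125
# T=  512: attention / recurrent = 1.000
# T= 4096: attention / recurrent = 8.000
```

The crossover sits at T = d. Sentence-length inputs fall on the cheap side of it, while long-context models fall well on the expensive side.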

Temperature and sharpness

The temperature parameter (which we control with the slider) is equivalent to dividing the scores by an extra factor beyond √dk. Lower temperature = sharper attention (approaches argmax). Higher temperature = softer attention (approaches uniform). The standard Transformer uses temperature 1.0 inside attention. The temperature that is often lowered during text generation (0.7-0.9) is a separate knob: it rescales the output logits before sampling the next token, not the attention scores, but it sharpens the distribution in the same way.

python
# Temperature effect on attention distribution
import torch
import torch.nn.functional as F

scores = torch.tensor([2.0, 1.0, 0.5, -0.5])

for temp in [0.1, 0.5, 1.0, 2.0, 5.0]:
    weights = F.softmax(scores / temp, dim=0)
    print(f"T={temp:.1f}: {weights.numpy().round(3)}")
# T=0.1: [1.000, 0.000, 0.000, 0.000]  ← hard (argmax)
# T=0.5: [0.839, 0.114, 0.042, 0.006]  ← sharp
# T=1.0: [0.598, 0.220, 0.133, 0.049]  ← standard
# T=2.0: [0.423, 0.256, 0.200, 0.121]  ← soft
# T=5.0: [0.316, 0.259, 0.234, 0.192]  ← near-uniform
What happens to the attention weights when you lower the temperature (scale parameter) toward 0?

Chapter 9: Connections

The Transformer didn't just improve on RNNs; it created a new paradigm. Within a few years of its publication, Transformer-based models had reached the state of the art across NLP and made major inroads into computer vision, speech, and scientific computing.

The Transformer family tree

Year | Model | Variant | Key Innovation
2017 | Transformer | Encoder-Decoder | Self-attention replaces recurrence
2018 | GPT-1 | Decoder-only | Pre-training on unlabeled text + fine-tuning
2018 | BERT | Encoder-only | Bidirectional pre-training via masked language modeling
2019 | GPT-2 | Decoder-only | Larger scale (1.5B params), zero-shot capabilities
2020 | GPT-3 | Decoder-only | 175B params, in-context learning
2020 | ViT | Encoder-only | Transformers for images (patches as tokens)
2020 | T5 | Encoder-Decoder | Text-to-text framework for all NLP tasks
2022 | ChatGPT | Decoder-only | RLHF alignment for conversational AI
2023 | LLaMA | Decoder-only | Open weights, efficient training, RoPE
2023 | Claude | Decoder-only | Constitutional AI, long context
The three Transformer variants:
Encoder-only (BERT): Bidirectional attention. Best for understanding tasks (classification, NER, question answering).
Decoder-only (GPT, Claude): Causal (left-to-right) attention. Best for generation tasks (text generation, code, chat).
Encoder-Decoder (original Transformer, T5): Full architecture. Best for sequence-to-sequence tasks (translation, summarization).

Results from the paper

On WMT 2014 English-to-German translation, the Transformer achieved 28.4 BLEU — surpassing all previous models including deep ensembles. On English-to-French, it achieved 41.0 BLEU with less than 1/4 the training cost of the previous state-of-the-art.

Model | EN-DE BLEU | EN-FR BLEU | Training Cost (FLOPs)
ByteNet | 23.75 | n/a | n/a
Deep-Att + PosUnk | n/a | 39.2 | 1.0×10^20 (EN-FR)
ConvS2S (Gehring) | 25.16 | 40.46 | 1.5×10^20 (EN-FR)
GNMT + RL (Google) | 24.6 | 39.92 | 2.3×10^19 (EN-DE)
Transformer (base) | 27.3 | 38.1 | 3.3×10^18
Transformer (big) | 28.4 | 41.0 | 2.3×10^19

The Transformer (base) outperformed every earlier single model on EN-DE at a fraction of the training cost (3.3×10^18 FLOPs, roughly 7x fewer than GNMT + RL). The big Transformer achieved the best results at a cost comparable to or below that of earlier models. This efficiency was the key insight: self-attention enables massive parallelism, which translates directly to faster training on GPU hardware.

Model configurations

The paper tested two model sizes:

Config | dmodel | dff | Heads | Layers | Params | Training
Base | 512 | 2048 | 8 | 6 | 65M | 12h on 8 P100s
Big | 1024 | 4096 | 16 | 6 | 213M | 3.5d on 8 P100s

For context, GPT-3 (2020) has 175 billion parameters, roughly 800x more than the "big" Transformer. The architecture is essentially the same; the scale is radically different. This suggests that the Transformer's power comes more from its architecture enabling efficient scaling than from the base model's inherent capability.

Ablation study highlights

The paper included an ablation study that reveals which components matter most:

Variation | BLEU change (dev set, vs. base at 25.8) | What it tells us
1 head instead of 8 | -0.9 | Multiple heads help, but even 1 head works decently
16 heads instead of 8 | 0.0 | No further gain from more heads (32 heads: -0.4)
Smaller dk (32) | -0.4 | Shrinking the key dimension hurts quality
Larger model (dmodel = 1024) | +0.2 | Bigger models are better (the full "big" configuration: +0.6)
Learned PE | -0.1 | Learned and sinusoidal PE perform nearly identically
Transformer Architecture Overview

The complete Transformer architecture. Encoder (left) processes the input; decoder (right) generates the output. Cross-attention bridges them. Click components to highlight their role.

The Transformer's dominance

No single architecture in the history of machine learning has been as dominant as the Transformer. Before the Transformer, each domain had its own preferred architecture: RNNs for text, CNNs for images, graph neural networks for relational data, and specialized architectures for speech, music, protein structure, and code.

The Transformer displaced most of them. Not by being specialized for each domain, but by being so good at learning from data that specialization became far less necessary. Vision Transformers (ViT) match or surpass CNNs on image classification when given enough pre-training data. Audio Spectrogram Transformers compete with or surpass specialized audio models. AlphaFold 2 uses attention-based, Transformer-style blocks for protein structure prediction. The Transformer is the closest thing to a "universal learning machine" that the field has produced.

Why is one architecture so universal? The key insight is that self-attention is a very general operation: compute pairwise relationships and aggregate information accordingly. This pattern appears in language (word relationships), vision (pixel relationships), chemistry (atom interactions), and virtually every other domain. The Transformer doesn't encode any domain-specific assumptions — it just learns relationships from data.

What changed after this paper

The Transformer didn't just improve existing benchmarks — it changed how AI research is conducted:

Before: task-specific
Each NLP task had its own architecture. Translation, summarization, question answering — all different models.
After: pre-train + fine-tune
One large Transformer trained on lots of text. Fine-tune for any task. BERT, GPT.
Now: prompt + generate
One massive Transformer that can do any task via natural language instructions. No fine-tuning needed. GPT-4, Claude.

Limitations acknowledged in the paper

Vaswani et al. noted that self-attention has O(T²) complexity in sequence length, making it expensive for very long sequences. This limitation spurred research into efficient attention variants:

Method | Year | Complexity | Idea
Sparse Attention | 2019 | O(T√T) | Only attend to nearby + strided positions
Longformer | 2020 | O(T) | Local window + global tokens
Linear Attention | 2020 | O(T) | Replace softmax with kernel trick
FlashAttention | 2022 | O(T²) but fast | IO-aware tiling, no materialized attention matrix
Mamba / SSM | 2023 | O(T) | Replace attention with selective state spaces

FlashAttention (Dao et al., 2022) deserves special mention. It doesn't change the theoretical complexity (still O(T²) computations), but it avoids materializing the full T×T attention matrix in GPU memory by computing attention in tiles. Each tile fits in the GPU's fast on-chip SRAM, avoiding slow reads/writes to GPU HBM (global memory). This reduces memory usage from O(T²) to O(T) and gives a 2-4x wall-clock speedup.
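
The trick that makes tiling possible is an "online" softmax: keep a running maximum, a running normalizer, and an unnormalized output, and rescale them whenever a new tile raises the maximum. The sketch below shows only that idea for a single query in plain Python. It is not the FlashAttention kernel itself, the √dk scaling is omitted, and the function names and toy data are illustrative.

```python
import math

def naive_attention(q, keys, values):
    """Materializes every score, then applies softmax and the weighted sum."""
    scores = [sum(a * b for a, b in zip(q, k)) for k in keys]
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [sum(e / total * v[i] for e, v in zip(exps, values))
            for i in range(len(values[0]))]

def tiled_attention(q, keys, values, tile=2):
    """Processes keys/values one tile at a time, keeping only running statistics."""
    running_max = float("-inf")
    running_sum = 0.0
    acc = [0.0] * len(values[0])                  # unnormalized output so far
    for start in range(0, len(keys), tile):
        k_tile = keys[start:start + tile]
        v_tile = values[start:start + tile]
        scores = [sum(a * b for a, b in zip(q, k)) for k in k_tile]
        new_max = max(running_max, max(scores))
        scale = math.exp(running_max - new_max)   # rescale earlier tiles to the new max
        running_sum *= scale
        acc = [a * scale for a in acc]
        for s, v in zip(scores, v_tile):
            w = math.exp(s - new_max)
            running_sum += w
            acc = [a + w * vi for a, vi in zip(acc, v)]
        running_max = new_max
    return [a / running_sum for a in acc]

q = [0.5, -1.0]
keys = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [-1.0, 2.0], [2.0, 0.5]]
values = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0], [9.0, 10.0]]

full = naive_attention(q, keys, values)
tiled = tiled_attention(q, keys, values, tile=2)
print(all(abs(a - b) < 1e-9 for a, b in zip(full, tiled)))   # True
```

The tiled version never holds more than one tile of scores at a time, which is the property that lets the real kernel keep its working set in fast on-chip memory.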

FlashAttention made it practical to train with context lengths of 8K, 32K, and eventually 100K+ tokens. Without it, the O(T²) memory cost would make long-context training prohibitively expensive. With FlashAttention 2 (2023) and FlashAttention 3 (2024), the implementation has been further optimized for newer GPU architectures, achieving substantially higher hardware utilization.

State-space models like Mamba take a different approach: they replace the O(T²) attention with O(T) linear recurrence, using a selective scan algorithm. This gives linear scaling with sequence length at the cost of reduced expressiveness (no direct pairwise comparisons). Whether this tradeoff is worthwhile depends on the task; for very long sequences (100K+ tokens), linear complexity may be essential.

Hybrid architectures that interleave attention and SSM layers show promise: attention handles the tasks requiring precise token-to-token comparisons (like copying, pattern matching, and reasoning), while SSM layers handle the tasks where linear processing suffices (like language modeling with local context). The Jamba architecture (AI21 Labs, 2024) and Mamba-2 (Dao and Gu, 2024) explore this direction.

Regardless of what comes next, the Transformer's core insight — that parallelizable attention can replace sequential recurrence — will remain foundational. Even if a better mechanism is discovered, it will be evaluated against the Transformer's remarkable combination of simplicity, scalability, and effectiveness.

The key efficiency innovations since the original paper tell a story of making the same architecture work at scales that once looked impractical:

Innovation | Year | What it enables
Mixed precision (FP16) | 2018 | 2x memory reduction, faster matmuls
Gradient checkpointing | 2016/2019 | Trade compute for memory in deep models
FlashAttention | 2022 | Long context without O(T²) memory
Tensor parallelism | 2019 | Split one layer across GPUs
Pipeline parallelism | 2019 | Split layers across GPUs
KV-cache | 2018+ | Fast autoregressive inference
Grouped-query attention | 2023 | Smaller KV-cache for long inference

The scaling hypothesis

Perhaps the most profound implication of the Transformer is the scaling hypothesis: larger Transformers trained on more data consistently get better, with smooth and predictable (though diminishing) returns rather than a hard plateau. This was quantified by Kaplan et al. (2020) in their "Scaling Laws for Neural Language Models" paper, which showed that loss decreases as a power law in model size, dataset size, and compute.

The scaling hypothesis suggests that the Transformer architecture itself is not the bottleneck — the bottleneck is compute and data. This perspective drove the creation of GPT-3 (175B), PaLM (540B), and eventually Claude and GPT-4 at even larger scales. The architecture is essentially unchanged from 2017; only the scale has changed.

Kaplan's scaling laws show that loss decreases as a power law in each of three factors:

L(N) ≈ (Nc/N)^αN,   L(D) ≈ (Dc/D)^αD,   L(C) ≈ (Cc/C)^αC
(each law applies when the other two factors are not the bottleneck)

Where N is model size (parameters), D is dataset size (tokens), and C is compute (FLOPs). The exponents α are remarkably stable across model sizes, suggesting a smooth, predictable relationship between resources and performance. This predictability — unusual in ML, where results are often unpredictable — gave researchers confidence to invest billions of dollars in scaling Transformer training.
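
A single power-law term of the form (x_c / x)^α is simple to evaluate. In the sketch below the exponent and constant are illustrative assumptions in the range usually quoted for model size, not values to rely on. The point is the shape of the curve: every 10x in parameters multiplies the loss by the same factor.

```python
def power_law_loss(x, x_c, alpha):
    """L(x) = (x_c / x) ** alpha, the single-factor form of a scaling law."""
    return (x_c / x) ** alpha

alpha_n = 0.076     # illustrative exponent; treat the exact value as an assumption
n_c = 8.8e13        # illustrative constant; likewise an assumption

ratios = []
for n in (1e8, 1e9, 1e10):
    ratio = power_law_loss(10 * n, n_c, alpha_n) / power_law_loss(n, n_c, alpha_n)
    ratios.append(ratio)
    print(f"N = {n:.0e} -> {10 * n:.0e}: loss multiplied by {ratio:.3f}")
```

The ratio is 10^(-α) regardless of the starting size (about 0.84 for this exponent), which is what makes extrapolation to larger models possible.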

The scaling laws also carried a practical message about how to spend a fixed compute budget. Kaplan et al. concluded that model size matters more than training time: it is better to train a large model for fewer steps than a small model for many steps. Hoffmann et al. (2022) later revised this allocation, finding that model size and training tokens should be scaled up in roughly equal proportion. This "Chinchilla optimal" training regime rebalanced the compute allocation and led to more efficient training protocols.

All of this flows from the original Transformer architecture. The scaling laws wouldn't hold if the architecture couldn't absorb additional capacity efficiently. The Transformer's combination of parallel computation (for training speed), residual connections (for gradient flow through deep stacks), and multi-head attention (for rich inter-token communication) creates an architecture that scales smoothly from 65M to over a trillion parameters without fundamental changes.

No previous architecture had this property. RNNs hit gradient walls at moderate depth. CNNs required increasingly complex skip connections and normalization tricks at large scale. The Transformer's clean, uniform structure — stack more identical blocks, add more heads, increase the dimension — made scaling almost boring in its predictability. And that predictability is exactly what enables billion-dollar training investments.

The Transformer's success in scaling also has a sociological dimension. Because scaling is predictable, companies can make rational decisions about resource allocation. A 10x increase in compute yields a predictable improvement in loss, which translates to a predictable improvement in capabilities. This reliability turned AI scaling from a research question ("will it work?") into an engineering question ("how much compute can we afford?"). The architecture's scalability made the commercial AI revolution possible.

It's worth reflecting on how unlikely this outcome was. In 2017, the deep learning community was exploring dozens of architectural innovations: dilated convolutions, memory-augmented networks, neural Turing machines, highway networks, and various attention hybrids. Of all these, the simplest and most general architecture — pure attention with residual connections — won so decisively that virtually everything else was abandoned within three years. Simplicity and generality, it turns out, are the most important architectural virtues.

A remarkable fact: The core Transformer architecture from 2017 — multi-head self-attention, FFN, residual connections, layer normalization — is essentially unchanged in 2024's frontier models. The "big model" Transformer in the paper had 213 million parameters. Claude and GPT-4 likely have 100-1000x more parameters, but the fundamental building block is the same. This is unprecedented architectural longevity in a field known for rapid churn.
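
The parameter counts quoted for the paper's two configurations can be roughly reproduced with a back-of-the-envelope count. This is a sketch under stated assumptions, not the paper's own accounting: it ignores biases and layer-norm parameters, assumes a single embedding matrix shared by the source embedding, target embedding, and output projection, and assumes a vocabulary of about 37,000 tokens. The function name is hypothetical.

```python
def approx_transformer_params(d_model: int, d_ff: int, n_layers: int, vocab: int) -> int:
    """Approximate encoder-decoder parameter count (no biases, no LayerNorm)."""
    attn = 4 * d_model * d_model           # W_Q, W_K, W_V, W_O
    ffn = 2 * d_model * d_ff               # two linear layers
    encoder = n_layers * (attn + ffn)      # self-attention + FFN
    decoder = n_layers * (2 * attn + ffn)  # self-attention + cross-attention + FFN
    embedding = vocab * d_model            # assumed shared across all three uses
    return encoder + decoder + embedding


base = approx_transformer_params(d_model=512, d_ff=2048, n_layers=6, vocab=37000)
big = approx_transformer_params(d_model=1024, d_ff=4096, n_layers=6, vocab=37000)
print(f"base: {base / 1e6:.0f}M, big: {big / 1e6:.0f}M")
```

The estimates land close to the 65M (base) and 213M (big) figures reported for the paper's models; the small gaps come from the terms this sketch leaves out and from the assumed vocabulary size.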

Related lessons

The vanishing gradient problem that motivated the Transformer
Gradient clipping — still used when training Transformers
This paper (Vaswani 2017)
The Transformer — attention is all you need
Transformers conquer computer vision

The eight authors

An unusual aspect of this paper is its eight co-authors, many of whom were relatively junior at the time. The paper came out of Google Brain and Google Research, and was driven by a remarkably collaborative process. Several authors have spoken about how different parts of the architecture were contributed by different people:

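Author | Key Contribution (per the paper's contribution footnote) | Later Impact
Ashish Vaswani | Designed and implemented the first Transformer models (with Illia Polosukhin) | Co-founded Adept AI, then Essential AI
Noam Shazeer | Scaled dot-product attention, multi-head attention; namesake of the "Noam" LR schedule | Co-founded Character.AI
Niki Parmar | Designed, implemented, tuned, and evaluated model variants | Co-founded Adept AI, then Essential AI
Jakob Uszkoreit | Proposed replacing RNNs with self-attention | Co-founded Inceptive
Llion Jones | Initial codebase, efficient inference, visualizations | Co-founded Sakana AI
Aidan Gomez | Design and implementation of tensor2tensor | Co-founded Cohere
Lukasz Kaiser | Design and implementation of tensor2tensor | Later joined OpenAI
Illia Polosukhin | Designed and implemented the first Transformer models (with Ashish Vaswani) | Co-founded NEAR Protocol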

The fact that seven of the eight authors went on to found companies, most of them in AI, speaks to the transformative nature of their contribution. They didn't just write a paper; they catalyzed an entire industry.

The title as a manifesto

The paper's title, "Attention Is All You Need," was deliberately provocative. At the time, attention was mostly viewed as an auxiliary mechanism that helped RNNs focus on relevant parts of the input, and few researchers treated it as the basis for a standalone architecture. The title declared that attention alone, without any recurrence or convolution, was sufficient.

The boldness of this claim turned out to be prophetic. Attention-based models proved sufficient not only for translation but also for language understanding (BERT), language generation (GPT), image recognition (ViT), speech recognition (Whisper), protein structure prediction (AlphaFold 2), robotics (RT-2), video generation (Sora), and a wide range of other AI tasks attempted since.

The complete Transformer in code

Let's put everything together — the complete Transformer encoder-decoder architecture in PyTorch:

python
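import math

import torch
import torch.nn as nn

# NOTE: sinusoidal_pe, EncoderBlock, and DecoderBlock are assumed to be the
# helpers built in the earlier chapters; they are not redefined here.

class Transformer(nn.Module):
    """The complete Transformer from Vaswani et al. 2017."""
    def __init__(self, src_vocab, tgt_vocab, d_model=512,
                 n_heads=8, n_layers=6, d_ff=2048, dropout=0.1,
                 max_len=5000, share_weights=True):
        super().__init__()
        # Embeddings
        self.src_embed = nn.Embedding(src_vocab, d_model)
        self.tgt_embed = nn.Embedding(tgt_vocab, d_model)
        # Fixed (non-learned) positional table, assumed shape [max_len, d_model].
        # Registered as a buffer so it moves with the module across devices.
        self.register_buffer("pos_enc", sinusoidal_pe(max_len, d_model),
                             persistent=False)
        self.scale = math.sqrt(d_model)

        # Encoder stack
        self.encoder = nn.ModuleList([
            EncoderBlock(d_model, n_heads, d_ff, dropout)
            for _ in range(n_layers)
        ])

        # Decoder stack
        self.decoder = nn.ModuleList([
            DecoderBlock(d_model, n_heads, d_ff, dropout)
            for _ in range(n_layers)
        ])

        # Output projection
        self.output_proj = nn.Linear(d_model, tgt_vocab)
        self.dropout = nn.Dropout(dropout)

        # The paper shares one weight matrix between both embeddings and the
        # pre-softmax projection; this requires a shared vocabulary.
        if share_weights and src_vocab == tgt_vocab:
            self.tgt_embed.weight = self.src_embed.weight
            self.output_proj.weight = self.tgt_embed.weight

    def encode(self, src):
        # src: [B, T_src] → [B, T_src, d_model]
        x = self.src_embed(src) * self.scale
        x = x + self.pos_enc[:x.size(1)]  # broadcasts over the batch dimension
        x = self.dropout(x)
        for layer in self.encoder:
            x = layer(x)
        return x  # encoder output: [B, T_src, d_model]

    def decode(self, tgt, enc_out):
        # tgt: [B, T_tgt] → [B, T_tgt, tgt_vocab]
        T = tgt.size(1)
        # Causal mask: position i may attend only to positions <= i
        mask = torch.tril(torch.ones(T, T, device=tgt.device))
        x = self.tgt_embed(tgt) * self.scale
        x = x + self.pos_enc[:T]
        x = self.dropout(x)
        for layer in self.decoder:
            x = layer(x, enc_out, mask)
        return self.output_proj(x)  # logits: [B, T_tgt, tgt_vocab]

    def forward(self, src, tgt):
        enc_out = self.encode(src)
        return self.decode(tgt, enc_out)

# Instantiate the base config. With shared weights the count is roughly 65M;
# without sharing, the two extra vocab-sized matrices push it to about 100M.
model = Transformer(
    src_vocab=37000, tgt_vocab=37000,
    d_model=512, n_heads=8, n_layers=6, d_ff=2048
)
n_params = sum(p.numel() for p in model.parameters())  # shared tensors counted once
print(f"Parameters: {n_params:,}")  # roughly 65M; exact value depends on the block definitions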
Closing thought: "Attention is all you need" is more than a paper title — it's a design philosophy. The Transformer showed that a single, elegant mechanism (attention) can replace the complex, task-specific architectures that dominated NLP for decades. Recurrence, convolution, hand-crafted features — none of them survived. The field converged on a single building block, and the results speak for themselves. Every word you're reading right now was likely touched, at some point, by a Transformer.
What are the three main variants of the Transformer architecture, and what is each best suited for?