Introduction

In Article 01, we turned text into sequences of vectors. Each token became a point in a high-dimensional embedding space, carrying semantic meaning but utterly ignorant of its neighbors. The word "bank" in "river bank" and "bank account" had the same initial embedding. Context had not yet entered the picture.

This article is about how context enters. The self-attention mechanism is the core operation that lets tokens communicate: each token broadcasts a query ("what am I looking for?"), every token broadcasts a key ("here's what I have"), and attention scores determine how much information flows from one position to another. The result is a context-aware representation where "bank" near "river" becomes an entirely different vector than "bank" near "account."

But attention alone is not enough. The transformer block wraps attention with a feed-forward network, residual connections, and layer normalization to form a complete processing unit. Stack 32 to 128 of these blocks, and you have a modern large language model.

ℹ What this article covers
We'll start with the intuition behind attention, derive the scaled dot-product formulation with full math, explore multi-head attention and causal masking, then assemble all the components into a complete transformer block. We'll cover modern variants like RMSNorm, SwiGLU, and Grouped Query Attention that appear in LLaMA, Mistral, and other recent models. Four interactive visualizations and complete PyTorch implementations are included.

Self-Attention Intuition

Consider the sentence "The cat sat on the mat because it was tired." What does "it" refer to? A human reader instantly knows "it" refers to "the cat," not "the mat." But how could a neural network figure this out?

Before transformers, the dominant approach was recurrence (LSTMs, GRUs). Information flowed sequentially: token 1 to token 2 to token 3 and so on. By the time the model reached "it" at position 9, the representation of "cat" at position 2 had been compressed through seven sequential steps, diluted by every intervening token. Long-range dependencies were theoretically possible but practically difficult.

Self-attention solves this with direct connections. Every token can attend directly to every other token in a single step, regardless of distance. "It" at position 9 can look directly at "cat" at position 2 with the same ease as looking at "was" at position 8. The computational path length for any dependency is O(1), not O(n).

The key insight is that attention is content-based. Rather than looking at fixed positions, each token computes a compatibility score with every other token based on their content (their vector representations). The model learns which content patterns should attend to which other patterns. "Pronoun" patterns learn to attend to "noun" patterns; "verb" patterns learn to attend to "subject" patterns; "adjective" patterns learn to attend to the "noun" they modify.

This is fundamentally different from a fixed wiring pattern. The attention pattern is dynamic -- it changes with every input. The same model will produce completely different attention weights for different sentences, because the content determines the routing.

Scaled Dot-Product Attention

The Q, K, V Formulation

Self-attention operates through three learned linear projections applied to each input token. Given an input sequence X of shape (seq_len, d_model), we compute:

Q = X WQ     K = X WK     V = X WV

where WQ, WK are (d_model, d_k) and WV is (d_model, d_v). Typically d_k = d_v = d_model / n_heads.

The three projections serve distinct roles:

  • Query (Q) — "What am I looking for?" Each token's query encodes what kind of information this position needs from other tokens.
  • Key (K) — "What do I contain?" Each token's key encodes what kind of information this position offers to other tokens.
  • Value (V) — "What information do I send?" When a query matches a key, the corresponding value is what actually gets transmitted.

The analogy is a soft database lookup. In a hard lookup, you search for an exact key match and retrieve the associated value. In attention, every query matches every key to some degree, and the output is a weighted sum of all values, where the weights come from query-key compatibility.

The full attention computation is:

Attention(Q, K, V) = softmax( Q K^T / √d_k ) V

Let's decompose this step by step. First, Q K^T is a matrix multiplication of shape (seq_len, d_k) × (d_k, seq_len) = (seq_len, seq_len). Each element (i, j) is the dot product of query i with key j, measuring how compatible position i's question is with position j's content. High dot products mean high compatibility.

The result is a seq_len x seq_len matrix of raw attention scores. Each row corresponds to one query position, and the values in that row indicate how much that position wants to attend to every other position.

Next, we divide by √d_k (the scaling factor), apply softmax row-wise to get a proper probability distribution, and multiply by V to produce the final output. Each output position is a weighted average of all value vectors, where the weights are the attention probabilities.
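
These three steps can be sketched directly in PyTorch with toy dimensions (the sizes here are purely illustrative):

```python
import math
import torch

torch.manual_seed(0)
seq_len, d_k = 4, 8          # toy sizes, purely illustrative

Q = torch.randn(seq_len, d_k)
K = torch.randn(seq_len, d_k)
V = torch.randn(seq_len, d_k)

# Step 1: raw compatibility scores, one row per query position
scores = Q @ K.T                                  # (seq_len, seq_len)

# Step 2: scale by sqrt(d_k), then softmax each row into a distribution
weights = torch.softmax(scores / math.sqrt(d_k), dim=-1)

# Step 3: each output row is a weighted average of the value rows
out = weights @ V                                 # (seq_len, d_k)

print(weights.sum(dim=-1))   # every row sums to 1
```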

Why Scale by √d_k

This scaling factor is easy to overlook but critically important. Without it, training becomes unstable for large d_k. Here's why.

Assume the components of Q and K are independent random variables with mean 0 and variance 1. The dot product q · k = ∑_{i=1}^{d_k} q_i k_i is then a sum of d_k independent terms, each with mean 0 and variance 1. Since variances of independent terms add, the dot product has mean 0 and variance d_k.

As d_k grows, the dot products grow in magnitude. When d_k = 64, the standard deviation of the dot products is 8. When d_k = 128, it's about 11.3. These large magnitudes push the softmax into regions where its gradient is extremely small (the "saturation" zones), making gradient-based learning almost impossible.

∑ The variance argument
If q_i, k_i ~ N(0, 1) independently, then:

Var(q · k) = Var(∑_i q_i k_i) = ∑_i Var(q_i k_i) = ∑_i E[q_i²] E[k_i²] = d_k · 1 · 1 = d_k

Dividing by √d_k normalizes: Var(q · k / √d_k) = Var(q · k) / d_k = d_k / d_k = 1

Dividing by √d_k rescales the dot products to have unit variance regardless of the dimensionality, keeping softmax in its well-behaved regime where gradients flow properly. This is such a simple fix that it's easy to underestimate its importance -- but without it, transformers with large d_k simply don't train.
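
The variance argument is easy to verify empirically (sample sizes here are arbitrary):

```python
import torch

torch.manual_seed(0)
n_samples, d_k = 100_000, 64

q = torch.randn(n_samples, d_k)
k = torch.randn(n_samples, d_k)

raw = (q * k).sum(dim=-1)          # 100k dot products of d_k-dim vectors
scaled = raw / d_k ** 0.5

# Theory predicts: std(raw) ≈ sqrt(64) = 8, std(scaled) ≈ 1
print(f"raw std: {raw.std():.2f}   scaled std: {scaled.std():.2f}")
```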

Attention Heatmap Interactive

Hover over cells to see attention weights. Each row shows which tokens a given query attends to. Brighter colors indicate stronger attention. Notice how "it" attends most strongly to "cat."


Multi-Head Attention

Parallel Attention Heads

A single attention head can only focus on one type of relationship at a time. But language is full of simultaneous relationships: syntactic (subject-verb agreement), semantic (coreference like "it" to "cat"), positional (adjacent word interactions), and more.

Multi-head attention addresses this by running multiple attention operations in parallel, each with its own learned Q, K, V projections. If the model dimension is d_model = 768 and we use h = 12 heads, each head operates on d_k = d_v = 768 / 12 = 64 dimensions.

Crucially, the total computation is approximately the same as a single head with full dimensionality. We're splitting the representation space, not expanding it. Each head gets a 64-dimensional slice of the 768-dimensional space, and the heads collectively cover the full space.

head_i = Attention(X WQ_i, X WK_i, X WV_i)

MultiHead(X) = Concat(head_1, ..., head_h) WO

Each head independently computes attention over its own subspace, free to specialize in different linguistic phenomena. Research has shown that heads do indeed specialize:

  • Positional heads — attend primarily to adjacent or nearby tokens, capturing local syntax.
  • Syntactic heads — track subject-verb relationships, relative clauses, and dependency parse edges.
  • Semantic heads — resolve coreference ("it" to "cat") and semantic role assignments.
  • Induction heads — detect repeated patterns and predict continuations, a key mechanism for in-context learning.

Concatenation & Output Projection

After each head produces its output of shape (seq_len, d_v), we concatenate all heads along the feature dimension to get (seq_len, h * d_v) = (seq_len, d_model). This concatenated tensor is then projected through WO of shape (d_model, d_model) to produce the final multi-head attention output.

The output projection WO serves a critical role: it allows the model to learn how to combine information from different heads. Without it, the heads would be completely independent streams. With it, the model can learn inter-head interactions, routing information from one head's subspace into another's.
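
The split-concat-project pipeline is just tensor reshaping. A minimal sketch with GPT-2-style dimensions (names are illustrative):

```python
import torch
import torch.nn as nn

B, L, d_model, h = 2, 5, 768, 12   # GPT-2-style sizes, illustrative
d_k = d_model // h                 # 64 dims per head

x = torch.randn(B, L, d_model)

# Split the feature dimension into h heads of d_k dims each
heads = x.view(B, L, h, d_k).transpose(1, 2)           # (B, h, L, d_k)
# ... each head would attend independently over its 64-dim slice ...

# Concatenate the heads back along the feature dimension
concat = heads.transpose(1, 2).reshape(B, L, h * d_k)  # (B, L, d_model)

# Output projection mixes information across head subspaces
W_o = nn.Linear(d_model, d_model, bias=False)
out = W_o(concat)
print(out.shape)
```

Note that the split and concatenation are exact inverses: without the attention step in between, `concat` recovers `x` exactly.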

💡 Why not just one big head?

A single attention head with d_k = 768 would have 768-dimensional queries and keys. The dot product between two 768-dimensional vectors has high variance (standard deviation √768 ≈ 27.7), requiring aggressive scaling. More importantly, the single attention distribution must be a compromise across all the different relationships the model needs to capture. Multi-head attention avoids this by letting each head specialize. The total parameter count is identical (the projection matrices simply get split), but the representational capacity increases because each head can learn an independent attention pattern.

Multi-Head Attention Patterns Interactive

Each head learns to attend to different aspects of the input. Select a head to see its attention pattern. Notice how Head 1 tracks local syntax while Head 3 captures long-range coreference.


The Attention Mask: Causal Masking

In an autoregressive language model (GPT, LLaMA, Mistral), each token should only attend to tokens at earlier positions -- never to future tokens. If position 5 could attend to position 8, the model would be "cheating": it would know the answer before predicting it.

This constraint is enforced through a causal mask (also called a look-ahead mask). Before applying softmax, we add a mask matrix to the attention scores:

mask_ij = 0  if  j ≤ i,    mask_ij = -∞  if  j > i

Adding -∞ to the scores for future positions causes softmax to assign them exactly zero probability. The resulting attention pattern is a lower-triangular matrix: each token attends only to itself and all preceding tokens.

In practice, the mask is a boolean tensor or a float tensor filled with 0 and -1e9 (or float('-inf')), created once and reused across all layers and heads. For a sequence of length n, the mask places negative infinities in the strictly upper-triangular part of an n × n matrix.
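
Both mask variants and their effect on softmax can be sketched in a few lines:

```python
import torch

torch.manual_seed(0)
seq_len = 5

# Boolean mask: True where attention is allowed (lower triangle)
bool_mask = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))

# Additive float mask: 0 where allowed, -inf strictly above the diagonal
float_mask = torch.zeros(seq_len, seq_len)
float_mask.masked_fill_(~bool_mask, float('-inf'))

scores = torch.randn(seq_len, seq_len)
weights = torch.softmax(scores + float_mask, dim=-1)

# Future positions receive exactly zero probability
print(weights.triu(diagonal=1).abs().sum())   # tensor(0.)
```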

Encoder-only models like BERT use bidirectional attention -- no mask at all (or only a padding mask). Each token can attend to every other token. This is possible because BERT is trained with masked language modeling (predict the [MASK] token), not autoregressive generation. Encoder-decoder models like T5 use causal masking in the decoder but bidirectional attention in the encoder.

ℹ Prefix caching and the mask
Modern serving systems exploit the causal mask for efficiency. Since the attention computation for token i depends only on tokens 0..i, the key-value pairs for a shared prefix can be computed once and reused across multiple completions. This is the basis of KV-cache and prefix caching in systems like vLLM and TensorRT-LLM.

Causal Attention Mask Interactive

Hover over a row to see which tokens that position can attend to. Green cells are visible (attend), dark cells are masked (blocked). Toggle between causal and bidirectional modes.


Sliding Window Attention

Full causal attention lets every token attend to all previous tokens. This is powerful but expensive: the cost is O(n²) in both memory and compute, where n is the sequence length. For a 128K-token context, that means roughly 16 billion attention entries per head per layer -- most of which contribute negligible weight.

Sliding window attention replaces the full lower-triangular mask with a band-diagonal mask of width W. Each token attends only to the W most recent tokens (including itself). Attention weights outside the window are forced to zero:

Attention(Q, K, V)_i = softmax( Q_i K_{i-W+1:i}^T / √d_k ) · V_{i-W+1:i}

Full causal mask (lower-triangular) vs. Sliding window mask (band-diagonal):

Full causal:      Sliding (W=3):
[1 0 0 0 0 0]     [1 0 0 0 0 0]
[1 1 0 0 0 0]     [1 1 0 0 0 0]
[1 1 1 0 0 0]     [1 1 1 0 0 0]
[1 1 1 1 0 0]     [0 1 1 1 0 0]
[1 1 1 1 1 0]     [0 0 1 1 1 0]
[1 1 1 1 1 1]     [0 0 0 1 1 1]

This reduces cost from O(n²) to O(n · W). With a window size of W = 4096, Mistral 7B (Jiang et al., 2023) processes long sequences with a fraction of the memory required by full attention. The KV cache per layer is bounded by W entries rather than growing linearly with sequence length.

Information Propagation Across Layers

A natural concern: if each token only sees W positions back, can distant tokens ever influence each other? Yes -- through the cascade effect across layers. At layer 1, token t sees tokens [t-W+1, t]. But at layer 2, each of those tokens has already absorbed information from their window, so token t now has indirect access to tokens as far back as t - 2W + 1. After L layers, the effective receptive field extends to L · W tokens.

💡 Cascade receptive field
With L transformer layers and window size W, the theoretical receptive field is L × W tokens. Mistral 7B uses 32 layers with W = 4096, giving an effective context of 32 × 4096 = 131,072 tokens -- despite each individual attention layer only looking back 4,096 positions. Information from token 0 can reach token 131,072 by propagating through one layer at a time, each hop extending the reach by W positions.

Comparison to other sparse patterns. Sliding window attention is the simplest sparse pattern, but it is not the only one. Longformer (Beltagy et al., 2020) and BigBird (Zaheer et al., 2020) combine local sliding windows with a small number of global attention tokens (e.g., the [CLS] token attends to all positions and all positions attend to it). Some architectures use dilated windows -- skipping every other token to double the window span without doubling the cost. In practice, the pure sliding window used by Mistral has proven remarkably effective for autoregressive generation.

Implementation. A sliding window is just a different attention mask -- band-diagonal instead of lower-triangular. With FlashAttention, you never materialize the full n × n mask; instead, the kernel simply skips blocks that fall outside the band, yielding wall-clock speedups proportional to the sparsity.

# Creating a sliding window attention mask in PyTorch
import torch

def sliding_window_mask(seq_len: int, window_size: int) -> torch.Tensor:
    """Build a causal sliding window mask.

    Returns a (seq_len, seq_len) boolean tensor where True = attend.
    Each token attends to at most `window_size` previous tokens
    (including itself).
    """
    # Start with a standard causal mask (lower-triangular)
    causal = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))
    # Remove entries beyond the window
    window_band = torch.triu(
        torch.ones(seq_len, seq_len, dtype=torch.bool),
        diagonal=-(window_size - 1),
    )
    return causal & window_band

# Example: seq_len=8, window_size=3
mask = sliding_window_mask(8, 3)
# Row 5 attends to columns [3, 4, 5] only

The Feed-Forward Network Sublayer

After attention mixes information across positions, each position passes through an identical feed-forward network (FFN) independently. This is sometimes called the "MLP sublayer" or "position-wise FFN." It applies the same two-layer network to every position independently -- no cross-position interaction here, just per-token transformation.

FFN(x) = W2 · σ(W1 x + b1) + b2

The inner dimension (often called d_ff) is typically 4 times d_model. So if d_model = 4096, the FFN expands to d_ff = 16384 dimensions and contracts back. This expansion-contraction pattern is universal across all transformer architectures.

Why the 4x expansion? The FFN acts as a key-value memory (Geva et al., 2021). Each row of W1 is a "key" that matches a pattern in the residual stream, and the corresponding column of W2 is the "value" that gets written back. The 4x expansion gives the FFN 4 times as many key-value slots as the model dimension, providing substantial storage capacity for factual knowledge and learned transformations.

Activation Functions: ReLU to GELU to SwiGLU

The original transformer (Vaswani et al., 2017) used ReLU: σ(x) = max(0, x). Simple and effective, but ReLU zeroes out every negative value, so those dimensions contribute nothing to the output for that input. This means roughly half the expanded dimensions are inactive on any given forward pass.

GELU (Gaussian Error Linear Unit, Hendrycks & Gimpel, 2016) replaced ReLU in most models starting with BERT. Rather than a hard cutoff at 0, GELU smoothly gates the input based on its value:

GELU(x) = x · Φ(x) ≈ 0.5x(1 + tanh(√(2/π)(x + 0.044715x³)))

where Φ is the Gaussian CDF. GELU allows small negative values to pass through with reduced magnitude, providing smoother gradients and slightly better training dynamics. GPT-2, GPT-3, and BERT all use GELU.
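
For comparison with the SwiGLU implementation later in this article, here is a sketch of the classic 4x GELU FFN (the class name is illustrative):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class GELU_FFN(nn.Module):
    """Classic position-wise FFN with 4x expansion and GELU,
    as used in GPT-2/BERT-style models. (Class name is illustrative.)"""
    def __init__(self, d_model: int, d_ff: int = None):
        super().__init__()
        d_ff = d_ff or 4 * d_model           # the standard 4x expansion
        self.w1 = nn.Linear(d_model, d_ff)   # up projection
        self.w2 = nn.Linear(d_ff, d_model)   # down projection

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.w2(F.gelu(self.w1(x)))

ffn = GELU_FFN(d_model=64)
x = torch.randn(2, 10, 64)
print(ffn(x).shape)   # torch.Size([2, 10, 64])
```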

SwiGLU: The Modern Standard

SwiGLU (Shazeer, 2020) combines the Swish activation with a Gated Linear Unit. It splits the FFN into two parallel linear projections and uses one to gate the other:

SwiGLU(x) = Swish(W1 x) ⊙ (Wgate x)

where Swish(x) = x · σ(x) and σ is the sigmoid function

This requires three weight matrices instead of two (W1, Wgate, W2), so the inner dimension is typically reduced from 4 * d_model to (8/3) * d_model to keep the parameter count comparable. LLaMA, Mistral, Gemma, and most modern models use SwiGLU. Empirically, it trains 5--10% faster to reach the same loss compared to GELU FFNs.
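
A quick sanity check of that parameter accounting (pure arithmetic):

```python
# Parameter parity: 4x GELU FFN vs (8/3)x SwiGLU FFN at d_model = 4096
d_model = 4096

# GELU FFN: W1 (d -> 4d) and W2 (4d -> d), no biases
gelu_params = 2 * d_model * (4 * d_model)

# SwiGLU FFN: W1 and Wgate (d -> d_ff) plus W2 (d_ff -> d)
d_ff = int(8 * d_model / 3)
swiglu_params = 3 * d_model * d_ff

print(gelu_params, swiglu_params)   # ~134.2M each, within 0.01%
```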

Layer Normalization

Layer normalization (Ba et al., 2016) normalizes the activations across the feature dimension for each token independently. Given a vector x of dimension d:

LayerNorm(x) = γ ⊙ (x - μ) / √(σ² + ε) + β

where μ = (1/d) ∑_i x_i,    σ² = (1/d) ∑_i (x_i - μ)²

The learned parameters γ (scale) and β (shift) are per-dimension, allowing the model to undo the normalization where beneficial. The constant ε (typically 1e-5 or 1e-6) prevents division by zero.

Why normalize? Without normalization, activations can grow or shrink exponentially as they pass through many layers. This makes training unstable: gradients explode or vanish, and the learning rate that works for one layer is wrong for another. Layer normalization keeps every layer's inputs on a consistent scale.
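
The formula above can be checked against PyTorch's built-in in a few lines:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
d = 16
x = torch.randn(3, d) * 50 + 10    # wildly scaled activations

# Manual LayerNorm (gamma = 1, beta = 0)
mu = x.mean(dim=-1, keepdim=True)
var = x.var(dim=-1, unbiased=False, keepdim=True)
normed = (x - mu) / torch.sqrt(var + 1e-5)

# PyTorch's built-in agrees
ref = F.layer_norm(x, (d,))
print(torch.allclose(normed, ref, atol=1e-4))   # True
```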

Pre-LN vs Post-LN

The original transformer used Post-LN: normalization is applied after the residual addition. This means the residual stream itself is unnormalized, and the sublayer output is LayerNorm(x + Sublayer(x)).

Pre-LN (now standard) applies normalization before the sublayer: x + Sublayer(LayerNorm(x)). This arrangement has a crucial advantage: the residual stream is never normalized, so gradients can flow through it unimpeded. The sublayer operates on normalized inputs, which stabilizes training, but the skip connection provides a direct gradient highway from the loss back to early layers.

Pre-LN is more stable during training (it's harder to make it diverge) and eliminates the need for careful learning rate warmup that Post-LN requires. Virtually all modern LLMs use Pre-LN. GPT-2 was one of the first to adopt it, and GPT-3, LLaMA, Mistral, and others followed.

RMSNorm: Simplifying Further

RMSNorm (Zhang & Sennrich, 2019) drops the mean-centering step entirely, normalizing only by the root mean square:

RMSNorm(x) = γ ⊙ x / √((1/d) ∑_i x_i² + ε)

No β parameter, no mean subtraction. This is both simpler and slightly faster than full LayerNorm, while empirically performing just as well. The intuition is that the re-centering provided by mean subtraction is redundant -- the model can learn to adjust the mean through other parameters. LLaMA (all versions), Mistral, Gemma, and Qwen all use RMSNorm instead of LayerNorm.

Residual Connections & Gradient Flow

Every sublayer in a transformer block is wrapped in a residual connection (He et al., 2016): output = x + Sublayer(x). This seemingly trivial addition is one of the most important design decisions in deep learning.

Without residual connections, a 96-layer transformer would need gradients to backpropagate through 96 sequential nonlinear transformations. In practice, gradients would vanish long before reaching the early layers, making them effectively untrainable.

The residual connection provides a gradient superhighway. During backpropagation, the gradient of the loss with respect to layer l's input is:

∂L/∂x_l = ∂L/∂x_L · ∏_{i=l}^{L-1} (I + ∂F_i/∂x_i)

The identity matrix I in each factor ensures that even if the sublayer gradients ∂F_i/∂x_i are small, the gradient through the residual path is preserved. There is always a direct path from the loss to any layer that doesn't pass through any nonlinearity.

This is why modern transformers can be stacked to extreme depths (96 layers in GPT-3, 80 layers in LLaMA 2 70B, 128 layers in some recent models) while remaining trainable. The residual stream acts as a "memory bus" that all layers read from and write to, and each layer's contribution is additive.
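
A toy experiment illustrates the effect (layer count, width, and seed are arbitrary; this is an illustration, not a rigorous benchmark):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
depth, d = 50, 64
layers = nn.ModuleList(
    [nn.Sequential(nn.Linear(d, d), nn.Tanh()) for _ in range(depth)]
)

def run(x: torch.Tensor, residual: bool) -> torch.Tensor:
    for layer in layers:
        x = (x + layer(x)) if residual else layer(x)
    return x

# Gradient reaching the input WITH residual connections
x_res = torch.randn(1, d, requires_grad=True)
run(x_res, residual=True).sum().backward()
g_res = x_res.grad.norm().item()

# Gradient reaching the input WITHOUT residual connections
x_plain = torch.randn(1, d, requires_grad=True)
run(x_plain, residual=False).sum().backward()
g_plain = x_plain.grad.norm().item()

print(f"with residuals: {g_res:.3e}   without: {g_plain:.3e}")
```

With residuals, a healthy gradient reaches the input; without them, it collapses toward zero after 50 tanh layers.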

💡 The residual stream perspective

Elhage et al. (2021) introduced the "residual stream" framing: instead of thinking of a transformer as a sequence of layers that transform a hidden state, think of it as a stream of vectors that layers read from and write to. Each attention head and FFN sublayer reads from the stream, computes something, and adds its result back. The final output is the initial embedding plus the sum of all sublayer contributions. This perspective makes it easier to understand circuits, feature superposition, and how information flows through deep transformers.

The Full Transformer Block

Now we can assemble all the pieces. A single Pre-LN transformer block processes an input x of shape (batch, seq_len, d_model) through two sublayers:

h = x + MultiHeadAttention(RMSNorm(x))
output = h + FFN(RMSNorm(h))

That's it. The complete data flow is:

  1. RMSNorm the input.
  2. Multi-head attention on the normalized input (Q, K, V all come from the same normalized tensor for self-attention).
  3. Residual add the attention output back to the original input.
  4. RMSNorm the result.
  5. Feed-forward network (SwiGLU) on the normalized result.
  6. Residual add the FFN output back.

A complete model stacks N of these blocks sequentially. LLaMA 2 7B uses 32 blocks; LLaMA 2 70B uses 80. The output of the final block passes through one more RMSNorm and then a linear projection to vocabulary size for next-token prediction.

The elegance of this design is its uniformity. Every block has the exact same structure. The only thing that changes between early and late blocks is the learned weights. This makes transformers straightforward to implement, parallelize, and scale.

Transformer Block — Data Flow Animated

Watch data flow through a single Pre-LN transformer block. The animated particle follows the computation path: normalize, attend, add residual, normalize, FFN, add residual.


Grouped Query Attention & Multi-Query Attention

Standard multi-head attention has separate Q, K, and V projections for each head. During autoregressive inference, the K and V tensors from all previous positions must be stored in the KV cache to avoid recomputation. With 32 heads and a 4096-dimensional model (d_k = 128), that's 2 (K and V) * 32 heads * 128 dims * 2 bytes = 16 KB per token per layer (in float16). For a 32-layer model generating 4096 tokens, the KV cache alone is 32 layers * 16 KB * 4096 tokens = 2 GB. This becomes the primary memory bottleneck during inference, especially for long contexts and high batch sizes.

Multi-Query Attention (MQA, Shazeer, 2019) takes a radical approach: all heads share a single set of K and V projections, while each head still has its own Q projection. This reduces the KV cache by a factor of n_heads (e.g., 32x). The cost is some quality degradation, since all heads must use the same keys and values.

Grouped Query Attention (GQA, Ainslie et al., 2023) is the middle ground. Heads are divided into groups, and each group shares K and V projections. If you have 32 query heads and 8 KV groups, each group of 4 query heads shares one set of keys and values. The KV cache shrinks by 4x (not 32x), but quality is much closer to full MHA.

Model         Attention type   Query heads   KV heads   KV cache ratio
GPT-3 175B    MHA              96            96         1x (baseline)
LLaMA 2 7B    MHA              32            32         1x
LLaMA 2 70B   GQA              64            8          8x reduction
LLaMA 3 8B    GQA              32            8          4x reduction
Mistral 7B    GQA              32            8          4x reduction
Falcon 40B    MQA              64            1          64x reduction

GQA has become the standard for models above ~7B parameters, offering the best tradeoff between quality and inference efficiency. The quality difference between GQA and full MHA is typically less than 0.5% on standard benchmarks, while the inference throughput improvement can be 2-4x at large batch sizes.
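
The cache sizes above can be reproduced with simple arithmetic (the helper function is illustrative):

```python
def kv_cache_bytes(n_layers: int, n_kv_heads: int, d_head: int,
                   seq_len: int, bytes_per_value: int = 2) -> int:
    """Total KV cache size: K and V tensors for every layer (float16)."""
    return 2 * n_layers * n_kv_heads * d_head * seq_len * bytes_per_value

# 32 layers, d_head = 128, 4096 generated tokens:
# full MHA (32 KV heads) vs GQA with 8 KV heads
mha = kv_cache_bytes(n_layers=32, n_kv_heads=32, d_head=128, seq_len=4096)
gqa = kv_cache_bytes(n_layers=32, n_kv_heads=8, d_head=128, seq_len=4096)

print(mha / 2**30, "GiB vs", gqa / 2**30, "GiB")   # 2.0 GiB vs 0.5 GiB
```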

Code Examples

Theory is best anchored by working code. Here are complete implementations of every component we've discussed, building up to a full transformer block in PyTorch.

Scaled Dot-Product Attention

import torch
import torch.nn as nn
import torch.nn.functional as F
import math

def scaled_dot_product_attention(
    Q: torch.Tensor,    # (batch, heads, seq_len, d_k)
    K: torch.Tensor,    # (batch, heads, seq_len, d_k)
    V: torch.Tensor,    # (batch, heads, seq_len, d_v)
    mask: torch.Tensor = None,  # (seq_len, seq_len) or broadcastable
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Scaled dot-product attention as defined in 'Attention Is All You Need'.
    Returns (output, attention_weights).
    """
    d_k = Q.size(-1)

    # Step 1: Compute raw attention scores
    # Q @ K^T → (batch, heads, seq_len, seq_len)
    scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(d_k)

    # Step 2: Apply mask (causal or padding)
    if mask is not None:
        scores = scores.masked_fill(mask == 0, float('-inf'))

    # Step 3: Softmax over the key dimension
    attn_weights = F.softmax(scores, dim=-1)

    # Step 4: Weighted sum of values
    output = torch.matmul(attn_weights, V)

    return output, attn_weights

Multi-Head Attention with GQA Support

class MultiHeadAttention(nn.Module):
    """
    Multi-head attention supporting MHA, GQA, and MQA.

    - n_kv_heads == n_heads → standard MHA
    - n_kv_heads == 1       → MQA (Multi-Query Attention)
    - 1 < n_kv_heads < n_heads → GQA (Grouped Query Attention)
    """
    def __init__(self, d_model: int, n_heads: int, n_kv_heads: int = None):
        super().__init__()
        self.n_heads = n_heads
        self.n_kv_heads = n_kv_heads or n_heads
        self.d_k = d_model // n_heads
        self.n_rep = n_heads // self.n_kv_heads  # how many Q heads per KV head

        # Query projection: full n_heads
        self.W_q = nn.Linear(d_model, n_heads * self.d_k, bias=False)
        # Key/Value projections: only n_kv_heads (shared across groups)
        self.W_k = nn.Linear(d_model, self.n_kv_heads * self.d_k, bias=False)
        self.W_v = nn.Linear(d_model, self.n_kv_heads * self.d_k, bias=False)
        # Output projection
        self.W_o = nn.Linear(n_heads * self.d_k, d_model, bias=False)

    def forward(self, x: torch.Tensor, mask: torch.Tensor = None):
        B, L, _ = x.shape

        # Project to Q, K, V
        Q = self.W_q(x).view(B, L, self.n_heads, self.d_k).transpose(1, 2)
        K = self.W_k(x).view(B, L, self.n_kv_heads, self.d_k).transpose(1, 2)
        V = self.W_v(x).view(B, L, self.n_kv_heads, self.d_k).transpose(1, 2)

        # Repeat K, V for grouped query attention
        if self.n_rep > 1:
            K = K.repeat_interleave(self.n_rep, dim=1)  # (B, n_heads, L, d_k)
            V = V.repeat_interleave(self.n_rep, dim=1)

        # Scaled dot-product attention
        out, attn = scaled_dot_product_attention(Q, K, V, mask)

        # Concatenate heads and project
        out = out.transpose(1, 2).contiguous().view(B, L, -1)
        return self.W_o(out)

RMSNorm & SwiGLU FFN

class RMSNorm(nn.Module):
    """Root Mean Square Layer Normalization (Zhang & Sennrich, 2019)."""
    def __init__(self, d_model: int, eps: float = 1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(d_model))
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # RMS = sqrt(mean(x^2))
        rms = torch.sqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        return x / rms * self.weight


class SwiGLU_FFN(nn.Module):
    """
    SwiGLU Feed-Forward Network (Shazeer, 2020).
    Uses 3 linear projections with 8/3 expansion ratio
    to match parameter count of standard 4x FFN.
    """
    def __init__(self, d_model: int, d_ff: int = None):
        super().__init__()
        # Default: (8/3) * d_model, rounded up to a multiple of 256
        if d_ff is None:
            d_ff = int(2 * (4 * d_model) / 3)
            d_ff = 256 * ((d_ff + 255) // 256)  # round up

        self.w1   = nn.Linear(d_model, d_ff, bias=False)  # gate projection
        self.w2   = nn.Linear(d_ff, d_model, bias=False)  # down projection
        self.w3   = nn.Linear(d_model, d_ff, bias=False)  # up projection

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # SwiGLU: w2(swish(w1(x)) * w3(x))
        return self.w2(F.silu(self.w1(x)) * self.w3(x))

Complete Transformer Block

class TransformerBlock(nn.Module):
    """
    A single Pre-LN transformer block with:
    - RMSNorm (pre-normalization)
    - Multi-Head Attention with GQA support
    - SwiGLU FFN
    - Residual connections
    """
    def __init__(self, d_model: int, n_heads: int, n_kv_heads: int = None,
                 d_ff: int = None):
        super().__init__()
        self.attn_norm = RMSNorm(d_model)
        self.attn = MultiHeadAttention(d_model, n_heads, n_kv_heads)
        self.ffn_norm = RMSNorm(d_model)
        self.ffn = SwiGLU_FFN(d_model, d_ff)

    def forward(self, x: torch.Tensor, mask: torch.Tensor = None):
        # Sublayer 1: Attention with residual
        h = x + self.attn(self.attn_norm(x), mask)
        # Sublayer 2: FFN with residual
        out = h + self.ffn(self.ffn_norm(h))
        return out


# ── Usage example: LLaMA-style model ──────────────────────
# LLaMA 2 7B configuration
config = dict(d_model=4096, n_heads=32, n_kv_heads=32, n_layers=32)

# LLaMA 2 70B configuration (with GQA)
config_70b = dict(d_model=8192, n_heads=64, n_kv_heads=8, n_layers=80)

# Build a single block
block = TransformerBlock(
    d_model=config['d_model'],
    n_heads=config['n_heads'],
    n_kv_heads=config['n_kv_heads'],
)

# Forward pass
x = torch.randn(2, 128, config['d_model'])  # (batch=2, seq_len=128, d=4096)
causal_mask = torch.tril(torch.ones(128, 128)).unsqueeze(0).unsqueeze(0)
output = block(x, mask=causal_mask)
print(f"Input:  {x.shape}")       # torch.Size([2, 128, 4096])
print(f"Output: {output.shape}")   # torch.Size([2, 128, 4096])

🧭 What comes next

We've now covered the full transformer block -- the repeating unit of every modern LLM. In Article 03: Training at Scale, we'll explore how these blocks are actually trained: the next-token prediction objective, cross-entropy loss, the mechanics of backpropagation through attention, and the distributed training strategies (data parallelism, tensor parallelism, pipeline parallelism) that make billion-parameter models possible.

References

Seminal papers and key works referenced in this article.

  1. Vaswani et al. "Attention Is All You Need." NeurIPS, 2017. arXiv
  2. Bahdanau et al. "Neural Machine Translation by Jointly Learning to Align and Translate." ICLR, 2015. arXiv
  3. Ainslie et al. "GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints." EMNLP, 2023. arXiv
  4. Shazeer. "Fast Transformer Decoding: One Write-Head is All You Need." 2019. arXiv
  5. Press et al. "Train Short, Test Long: Attention with Linear Biases Enables Input Length Generalization." ICLR, 2022. arXiv