Day In The Life — AI Research Infrastructure

Frontier Lab Engineer

Staff-level interview prep: GPU kernels, quantization, JAX, scaling laws, LLM-guided search, and the full stack from roofline to deployment.

Prerequisites: PyTorch basics + Linear algebra + Some Python. That's it.
17 Chapters · 16+ Simulations · 5 Interview Dimensions

Chapter 0: The Role

It is 7:30 AM. You badge into the open-plan research floor at a frontier AI lab — think Anthropic, DeepMind, or OpenAI scale. On one monitor, a Pallas kernel you wrote for a custom grouped-query attention variant is 15% slower than the CUTLASS baseline on H100. The kernel profiles fine on a single chip, but when you run it through the full TP=8 sharding, an unexpected AllReduce is eating 2 ms per layer. You need to figure out why before the training run launches at noon — 2048 H100s are reserved, and idle time costs the lab roughly $50 per GPU per hour.

On your second monitor, a colleague on the post-training team has pinged you. They are running a QuIP#-style 2-bit quantization experiment on the latest 70B checkpoint. Loss on the eval set is 0.4 nats higher than the FP16 baseline, which is tolerable for an 8x compression ratio (16 bits down to 2). But on a specific subset of math reasoning benchmarks, accuracy has cratered from 68% to 41%. They need your diagnosis. Is the quantization destroying the precision in the MLP up-projection layers that are critical for chain-of-thought arithmetic? Or is the calibration set too narrow?

On Slack, the scaling laws team has posted a sprawling analysis. They are debating the next model generation: should the lab invest the $80M compute budget in a dense 400B model, or a Mixture-of-Experts with 16 experts and 1.2T total parameters? The MoE activates only 70B parameters per token, giving it effectively the same inference cost as a dense 70B model, but training requires careful load balancing and routing. They want your compute analysis — FLOPs, memory, communication overhead, and predicted loss at the target budget — by end of day.

Before lunch, you will fix the kernel (the AllReduce was happening on the wrong axis due to a sharding annotation typo in the Pallas mesh spec), diagnose the quantization issue (the calibration set had zero math-heavy prompts, so the activation ranges for arithmetic-critical layers were miscalibrated), and draft the first half of the scaling comparison (using the C ≈ 6ND relationship to compute Chinchilla-optimal points for both architectures).

This is the daily reality of a Frontier Lab Engineer. You sit at the intersection of three disciplines that rarely overlap in a single person:

Discipline | What you need | How it shows up daily
Systems & Hardware | GPU memory hierarchy, kernel writing (Pallas/Triton/CUDA), roofline analysis, communication primitives | You profile a Pallas kernel in the XLA profiler, discover a bank conflict in shared memory, fix it, and cut the custom attention from 4.2 ms to 2.9 ms per layer
ML Research | Transformer architecture, scaling laws, training dynamics, loss landscapes, alignment methods | You read the DeepSeek-V3 paper to understand their MoE load-balancing loss, then estimate whether the technique saves enough FLOPs to justify the engineering investment
Tooling & Infrastructure | JAX/XLA ecosystem, distributed training frameworks, TPU/GPU cluster management, experiment tracking | You write a JAX training loop with FSDP sharding, add gradient checkpointing to fit within HBM, and instrument it with TensorBoard profiling hooks
Two sides of the same coin. Vlad Mnih's career advice blog (the inspiration for this lesson) frames frontier lab engineering as working "below the stack" (kernels, hardware, XLA) and "above the abstraction" (agents, search, deployment). A Staff Engineer at a frontier lab is expected to be fluent in both. You might spend the morning debugging an XLA lowering issue in a custom Pallas kernel, and the afternoon designing a tree-search strategy that uses the model as a heuristic function. This lesson covers the full range: from transistor-level memory latency to RLHF reward modeling.

The System You Build

The diagram below traces the lifecycle of a frontier model from pre-training cluster to production serving. Every box is a system you own or co-own. Think of it as the "full-stack map" you would draw on a whiteboard in a system-design interview.

1. Pre-Training Infrastructure
2048+ GPUs/TPUs orchestrated with JAX pjit + GSPMD sharding. Data pipeline: tokenization, packing, curriculum. Checkpointing every N steps with async serialization. You measure throughput in MFU (Model FLOPs Utilization), not just loss. Target: MFU > 0.45 on H100 clusters.
2. Scaling Law Analysis
Before committing $50-200M in compute, you run dozens of small models (70M to 7B) to fit a parametric loss curve. Chinchilla: L(N, D) = a·N^(−α) + b·D^(−β) + c, with α, β > 0 so that loss falls toward the irreducible term c as parameters N and tokens D grow. You predict the loss at 400B parameters and 10T tokens, then decide: dense vs MoE, context length, vocabulary size.
3. Custom Kernel Development
Standard attention is too slow for your 128K context window. You write Pallas/Triton kernels for FlashAttention-3 with ring attention for cross-device sequence parallelism. You benchmark with roofline analysis, profile with Nsight/XLA profiler, and iterate until MFU hits target.
4. Post-Training Pipeline
SFT (supervised fine-tuning) on curated instruction data. Then RLHF or DPO for alignment. Then tool-use training, function calling, and safety training. Each stage has its own reward model, data pipeline, and eval suite. You own the training infrastructure for all of them.
5. Quantization & Distillation
The 400B FP16 model is 800 GB — too large for a single node. You apply GPTQ/AWQ for INT4 weights, FP8 KV cache, speculative decoding with a 7B draft model. Target: 50 tokens/sec on a single 8xH100 node at INT4, within 1% of FP16 eval scores.
6. Serving & Inference
vLLM/TGI serving with PagedAttention, continuous batching, tensor parallelism. Autoscaling on Kubernetes. Latency SLOs: p50 < 200ms TTFT, p99 < 20ms inter-token. Cost tracking per million tokens. A/B testing new model versions against the production baseline.
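
The 800 GB figure in item 5 is plain arithmetic, and the same arithmetic shows why INT4 weights fit on one node. A minimal sketch, assuming weight-only storage (no KV cache, activations, or quantization scales) and an 8×80 GB node; the function names are illustrative and not taken from any library:

```python
# Weight-only memory footprint by precision.
# Assumption: ignores KV cache, activations, and per-group quantization metadata.
BYTES_PER_PARAM = {"fp16": 2.0, "fp8": 1.0, "int4": 0.5}


def weight_footprint_gb(n_params: float, precision: str) -> float:
    """Weight storage in GB (1 GB = 1e9 bytes)."""
    return n_params * BYTES_PER_PARAM[precision] / 1e9


def fits_on_node(n_params: float, precision: str,
                 gpus_per_node: int = 8, hbm_per_gpu_gb: float = 80.0) -> bool:
    """True if the weights alone fit in the node's aggregate HBM."""
    return weight_footprint_gb(n_params, precision) <= gpus_per_node * hbm_per_gpu_gb


for precision in BYTES_PER_PARAM:
    gb = weight_footprint_gb(400e9, precision)
    print(f"400B @ {precision}: {gb:.0f} GB, fits on 8x80GB: {fits_on_node(400e9, precision)}")
```

In practice the usable budget is smaller than aggregate HBM, because the KV cache and activations need room too.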

What Separates Staff from Senior

A senior engineer can train a model. They know the APIs, can write a training loop, can debug a loss spike. Give them an architecture and a compute budget, and they will hit a target loss.

A staff engineer decides which architecture to train. They run the scaling law experiments that tell you whether 400B dense or 1.2T MoE is optimal for your budget. They design the sharding strategy that determines whether your training run achieves 42% or 48% MFU — a difference that represents $6M in compute savings at scale. They build the evaluation framework that catches capability regressions before the model ships. And when the next generation of hardware arrives (B200, TPU v6), they redesign the parallelism strategy and kernel library rather than patching the old system.

The distinction is leverage. A senior engineer makes one training run succeed. A staff engineer builds the infrastructure that makes every training run at the lab more efficient, more reliable, and more reproducible.

The staff-level mental model. In every system-design interview at a frontier lab, think in three time horizons: (1) what ships this quarter (a concrete model with target evals), (2) what scales to the next model generation (automated scaling law sweeps, reproducible training recipes), and (3) what survives a hardware transition (hardware-agnostic abstractions, kernel libraries with pluggable backends). Interviewers at these labs are specifically looking for all three levels of thinking.

The Five Interview Dimensions

Every strong interview loop at a frontier lab tests five orthogonal skills. Each chapter in this lesson hits multiple dimensions, but the table below shows the kind of question each dimension produces.

Dimension | What they test | Example question | What a staff answer adds
CONCEPT | First-principles math and theory | "Derive the roofline bound for multi-head attention on H100" | Connects the bound to practical kernel design decisions
DESIGN | System architecture and trade-offs | "Design the sharding strategy for a 400B MoE model across 2048 GPUs" | Discusses expert parallelism, communication topology, failure recovery
CODE | Implementation in JAX/Python/Triton | "Write a JAX training loop with FSDP and gradient checkpointing" | Adds sharding annotations, proper PRNG handling, profiling hooks
DEBUG | Failure diagnosis under production pressure | "Your 2048-GPU training run's loss spikes every 500 steps. What do you check?" | Systematic bisection: data pipeline, gradient norms, learning rate schedule, hardware faults
FRONTIER | Research awareness and taste | "What architecture changes since GPT-4 do you think matter most?" | Discusses GQA, RoPE, SwiGLU, MoE routing, and why each matters at scale with specific numbers

A Day in the Life: Hour by Hour

Time | Task | Skill used
7:30 | Check overnight training run: loss curve, gradient norms, MFU dashboard | Debug
8:00 | Profile custom Pallas attention kernel, find unnecessary AllReduce | Code + Debug
9:00 | Fix sharding annotation, re-benchmark, verify MFU improvement | Code + Concept
10:00 | Standup: present scaling law projections for next model generation | Design + Concept
10:30 | Diagnose quantization anomaly on 70B checkpoint math benchmarks | Debug + Concept
11:30 | Write calibration-aware quantization script, re-run with math-heavy prompts | Code
13:00 | Review colleague's DPO training changes: reward model architecture, data mix | Frontier + Design
14:00 | Design doc: MoE vs dense cost-benefit analysis with FLOPs projections | Design + Concept
15:30 | Prototype ring attention kernel for 128K context sequence parallelism | Code + Frontier
17:00 | Run ablation: GQA head count vs memory savings vs eval degradation | Concept + Design

The Blog Recipe

This lesson is structured around the advice from Vlad Mnih's influential blog post on how to prepare for frontier lab roles. The core framework: master both "below the stack" (GPU kernels, memory hierarchy, roofline analysis, XLA compilation) and "above the abstraction" (scaling laws, RLHF, agents, evaluation). The specific preparation he recommends:

Preparation area | Why it matters | Chapters in this lesson
Roofline analysis | Every kernel, every architecture choice comes down to "are we compute-bound or memory-bound?" | Ch 1
GPU memory hierarchy | You cannot write fast kernels without understanding registers → SRAM → L2 → HBM | Ch 2
Transformer math | Parameter counts, FLOPs, memory budgets: you must be able to derive these from scratch in an interview | Ch 3
JAX fundamentals | Frontier labs use JAX/XLA for training. jit, grad, vmap, pjit are non-negotiable | Ch 4
Custom kernels (Pallas) | When the default XLA lowering is too slow, you write Pallas kernels | Ch 5
Quantization | Serving a 400B model at production scale requires INT4/FP8 without quality loss | Ch 6
FlashAttention | The single most important kernel optimization in modern LLMs | Ch 7
Scaling laws | $50-200M compute decisions depend on predicting loss at target scale | Ch 8
MoE architectures | Every frontier lab is shipping MoE models; the routing/balancing tradeoffs are interview staples | Ch 9
RLHF / DPO | Post-training alignment is where models become useful; you need to understand the full pipeline | Ch 10
Distributed training | DP, TP, PP, FSDP, expert parallelism: 2048+ GPUs require all of them simultaneously | Ch 11
Infra & profiling | XLA profiler, Nsight, MFU dashboards, checkpoint management | Ch 12
LLM agents & search | "Above the abstraction": tree search, tool use, the LLM as a heuristic function | Ch 13
Eval & safety | How do you know if the model is better? How do you know it is safe? | Ch 14
This lesson has 17 chapters. Chapters 1-4 cover foundational analysis (roofline, GPU memory, transformer math, JAX). Chapters 5-7 cover kernel engineering (Pallas, quantization, FlashAttention). Chapters 8-10 cover ML research (scaling laws, MoE, RLHF). Chapters 11-14 cover infrastructure and deployment (distributed training, profiling, agents, evaluation). Chapters 15-16 cover integration and interview preparation. Every chapter follows the same pattern: concept derivation, worked example with real numbers, code, failure modes, frontier research, staff-level quiz.
Staff-level warm-up: A frontier lab is planning a 400B-parameter pre-training run. A senior engineer calculates the compute budget as C = 6ND = 6 × 400B × 8T tokens = 1.92 × 10^25 FLOPs. They conclude: "On 2048 H100s at 989 TFLOPS FP16 each, this takes about 2,630 hours (roughly 110 days)." What critical assumptions are they missing, and what would you estimate for the actual wall-clock time?
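
One way to explore the warm-up numerically is to make hardware utilization an explicit parameter instead of silently assuming peak throughput. The sketch below is an illustration under stated assumptions: C ≈ 6ND, a single sustained MFU for the whole run, and no downtime for restarts or checkpointing. The function name `training_days` is hypothetical.

```python
def training_days(n_params: float, n_tokens: float, n_gpus: int,
                  peak_tflops: float, mfu: float) -> float:
    """Wall-clock days for a run costing C = 6*N*D FLOPs.

    Assumption: sustained throughput = n_gpus * peak * mfu for the whole run
    (no failures, restarts, or evaluation pauses).
    """
    total_flops = 6 * n_params * n_tokens
    sustained_flops_per_sec = n_gpus * peak_tflops * 1e12 * mfu
    return total_flops / sustained_flops_per_sec / 86_400


# Sweep MFU for the 400B-parameter, 8T-token run on 2048 H100s.
for mfu in (1.0, 0.48, 0.45, 0.42):
    days = training_days(400e9, 8e12, 2048, 989, mfu)
    print(f"MFU {mfu:.2f}: {days:6.1f} days")
```

Sweeping MFU this way also shows how a few points of utilization translate into weeks of cluster time.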

Chapter 1: Roofline Analysis

Your colleague submits a Pallas kernel for grouped-query attention. It runs at 180 TFLOPS on an H100. Is that good? Is there room for improvement? Without a framework for answering this question, you are guessing. The roofline model gives you a principled upper bound on performance for any operation, in under 60 seconds of mental math.

Reiner Pope (DeepMind) popularized a style of back-of-the-envelope analysis where every architecture decision — attention variant, parallelism strategy, batch size — is evaluated through the roofline lens. At a frontier lab, this is the first tool you reach for. Master it, and you will never be surprised by a kernel's performance again.

The Two Resources: Compute and Bandwidth

A GPU has two fundamental resources that limit performance. Every operation you run is bottlenecked by one or the other. Only exactly at the ridge point do both bind equally, and that is rare in practice.

Compute is measured in FLOPS (floating-point operations per second). An H100 SXM5 delivers:

Precision | Peak TFLOPS | Hardware unit
FP32 | 67 | CUDA cores
FP16 / BF16 | 989 | Tensor Cores (4th gen)
FP8 (E4M3) | 1,979 | Tensor Cores
INT8 | 1,979 | Tensor Cores

Memory bandwidth is measured in bytes per second. The H100 has HBM3 at 3.35 TB/s (that is 3,350 GB/s). This is the rate at which data can flow between the GPU's main memory (80 GB of HBM) and its compute units.

Think of it this way: the GPU is a factory. Compute is the number of workers on the factory floor. Bandwidth is the width of the loading dock. If you have 1,000 workers but a narrow loading dock, the workers stand idle waiting for materials. If you have a massive loading dock but only 10 workers, materials pile up unprocessed.

Arithmetic Intensity: The Single Number That Matters

Arithmetic intensity (AI) is the ratio of computation to data movement for a given operation:

AI = FLOPs performed / Bytes transferred to/from memory

The units are FLOPs/byte. This single number tells you whether your operation is compute-bound (limited by how fast the GPU can crunch numbers) or memory-bound (limited by how fast data can flow to the compute units).

Deriving the Roofline from First Principles

Let us derive the roofline model step by step. We want to find the maximum achievable FLOPS for an operation with arithmetic intensity AI.

// Given:
Peak Compute: Π FLOPS    (e.g., 989 TFLOPS for H100 FP16)
Peak Bandwidth: β bytes/s    (e.g., 3.35 TB/s for H100 HBM3)
Operation arithmetic intensity: AI = F/B    (FLOPs/byte)

// The operation transfers B bytes and computes F FLOPs.
// Time to transfer data: t_mem = B / β
// Time to compute: t_comp = F / Π

// The total time is at least max(t_mem, t_comp).
// (Data transfer and compute can overlap, so we take the max.)

// Achieved FLOPS = F / time = F / max(B/β, F/Π)

// Case 1: Memory-bound (t_mem > t_comp, i.e., B/β > F/Π)
Achieved FLOPS = F / (B/β) = (F/B) × β = AI × β

// Case 2: Compute-bound (t_comp > t_mem)
Achieved FLOPS = F / (F/Π) = Π

// Combining both cases:
Achieved FLOPS = min(Π, AI × β)

// The transition happens at the "ridge point":
AI_ridge = Π / β
AI_ridge = (989 × 10^12) / (3.35 × 10^12) = 295 FLOPs/byte    (H100, FP16)

This is the entire model. Below the ridge point (AI < 295), you are memory-bound and performance scales linearly with arithmetic intensity. Above it (AI > 295), you are compute-bound and performance is flat at the peak compute rate. Plot this on a log-log graph and you get the characteristic "roofline" shape: a diagonal line (memory-bound region) that hits a flat ceiling (compute-bound region).

The fundamental optimization question. Once you know which side of the ridge you are on, the optimization strategy is completely different. Memory-bound? Reduce data movement (fuse operations, tile into shared memory, reduce precision). Compute-bound? Use faster math (Tensor Cores instead of CUDA cores, FP8 instead of FP16, better algorithms with fewer FLOPs). Applying a compute-bound optimization to a memory-bound kernel does nothing. This is the single most common mistake in kernel optimization.

Worked Example 1: Matrix Multiply on H100

Matrix multiplication is the most important operation in deep learning. Let us analyze it with the roofline model.

// Operation: C = A × B where A is [M, K], B is [K, N], C is [M, N]
// All matrices in FP16 (2 bytes per element)

// FLOPs:
// Each element of C requires K multiply-adds = 2K FLOPs
// There are M × N elements in C
F = 2 × M × K × N

// Bytes transferred (each matrix touched exactly once in HBM, i.e. ideal on-chip reuse):
// Read A: M × K × 2 bytes
// Read B: K × N × 2 bytes
// Write C: M × N × 2 bytes
B = 2(MK + KN + MN)

// Arithmetic Intensity:
AI = 2MKN / 2(MK + KN + MN) = MKN / (MK + KN + MN)

// For square matrices M = K = N = n:
AI = n^3 / (3n^2) = n/3

// ═══ Concrete examples on H100 (ridge = 295 FLOPs/byte) ═══

// Case 1: Small matmul n = 256
AI = 256/3 = 85.3    < 295 → MEMORY-BOUND
Achieved FLOPS = 85.3 × 3.35 × 10^12 = 286 TFLOPS
Efficiency = 286/989 = 28.9% of peak

// Case 2: Medium matmul n = 1024
AI = 1024/3 = 341.3    > 295 → COMPUTE-BOUND
Achieved FLOPS = 989 TFLOPS (peak)
Efficiency = 100% of peak (in theory)

// Case 3: Typical LLM matmul M=4096, K=4096, N=11008 (LLaMA FFN)
AI = 4096×4096×11008 / (4096×4096 + 4096×11008 + 4096×11008)
   = 1.846×10^11 / (1.678×10^7 + 4.509×10^7 + 4.509×10^7)
   = 1.846×10^11 / 1.070×10^8 = 1726 FLOPs/byte
// Massively compute-bound. Tensor Cores dominate performance.

The takeaway: large matrix multiplications are compute-bound. This is why Tensor Cores matter — they are the fast lane for matrix multiplication. Small matrix multiplications (batch size 1, short sequences) can become memory-bound, which is exactly why LLM inference at batch=1 is often bandwidth-limited.
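
To make the takeaway concrete, the sketch below searches for the smallest M at which an FP16 [M, K] × [K, N] matmul crosses the H100 ridge point, using the same one-touch-per-matrix byte model as the worked example. The helper names are hypothetical and the result is a roofline estimate, not a measured crossover.

```python
H100_RIDGE = 989 / 3.35  # FLOPs/byte: peak FP16 TFLOPS / HBM TB/s


def matmul_ai(M: int, K: int, N: int, bytes_per_elem: int = 2) -> float:
    """Arithmetic intensity of C[M,N] = A[M,K] @ B[K,N], each matrix touched once."""
    flops = 2 * M * K * N
    bytes_moved = bytes_per_elem * (M * K + K * N + M * N)
    return flops / bytes_moved


def min_compute_bound_rows(K: int, N: int, ridge: float = H100_RIDGE,
                           max_rows: int = 1 << 20):
    """Smallest M whose AI exceeds the ridge, or None if no M up to max_rows does."""
    for M in range(1, max_rows + 1):
        if matmul_ai(M, K, N) > ridge:
            return M
    return None


print(min_compute_bound_rows(4096, 4096))
```

For K = N = 4096 the search lands at M = 345, so under this model a weight matrix of that shape needs a few hundred tokens in the batch before the matmul stops being bandwidth-limited.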

Worked Example 2: Self-Attention Roofline

Self-attention in a transformer is more nuanced because it involves multiple operations with very different arithmetic intensities. Let us dissect it.

// Transformer self-attention for one head:
// Q, K, V each have shape [seq_len, d_head] = [S, d]
// All in FP16 (2 bytes per element)

// Step 1: Compute attention scores P = Q × K^T
// Shapes: [S, d] × [d, S] = [S, S]
FLOPs_1 = 2 × S × d × S = 2·S^2·d
Bytes_1 = 2(Sd + Sd + S^2) = 2S(2d + S)

// Step 2: Softmax over each row of P
// ~5 FLOPs per element (exp, sum, divide), S^2 elements
FLOPs_2 = 5·S^2
Bytes_2 = 2 × 2 × S^2 = 4·S^2    (read + write)

// Step 3: Weighted sum O = softmax(P) × V
// Shapes: [S, S] × [S, d] = [S, d]
FLOPs_3 = 2·S^2·d
Bytes_3 = 2(S^2 + Sd + Sd) = 2S(S + 2d)

// Total attention (ignoring QKV projection, just the core):
F_total = 4·S^2·d + 5·S^2
B_total = 2S(2d + S) + 4·S^2 + 2S(S + 2d) = 2S(4d + 2S) + 4·S^2 = 8Sd + 8·S^2

// AI = (4·S^2·d + 5·S^2) / (8Sd + 8·S^2)
// For large S, the S^2 terms dominate:
AI ≈ (4·S^2·d) / (8·S^2) = d/2

// Typical d_head = 128:
AI ≈ 128/2 = 64 FLOPs/byte
// 64 << 295 (H100 ridge) → attention is MEMORY-BOUND!

// This is why FlashAttention exists: it reduces the memory traffic
// by never materializing the S×S attention matrix to HBM.
// Without FlashAttention: must write and read S^2 elements from HBM.
// With FlashAttention: O(S×d) HBM traffic; the S^2 term vanishes.
The FlashAttention insight, in one sentence. Attention is memory-bound because the S×S attention matrix dominates the data movement. FlashAttention computes attention in tiles, keeping each tile of the S×S matrix in SRAM (shared memory) and never writing it to HBM. This changes the memory traffic from O(S²) to O(S·d), making the kernel much faster despite doing the same number of FLOPs. We will derive FlashAttention in full in Chapter 7.
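
The effect of dropping the S×S traffic can be checked with the exact expression from the derivation rather than the large-S approximation. This is a sketch under the same byte model as above. The fused case assumes Q, K, V, and O each touch HBM exactly once, which is an idealization: a real tiled kernel may re-read K/V tiles depending on SRAM size. The function name is hypothetical.

```python
def attention_ai(S: int, d: int, materialize_scores: bool = True) -> float:
    """Arithmetic intensity (FLOPs/byte, FP16) of one attention head.

    FLOPs: two matmuls (4*S^2*d) plus ~5 FLOPs per softmax element.
    Bytes: 8*S*d for Q, K, V, O, plus 8*S^2 if the score matrix
    round-trips through HBM (written, read + rewritten by softmax, read again).
    """
    flops = 4 * S * S * d + 5 * S * S
    bytes_moved = 8 * S * d + (8 * S * S if materialize_scores else 0)
    return flops / bytes_moved


for S in (512, 2048, 32768):
    standard = attention_ai(S, 128, materialize_scores=True)
    fused = attention_ai(S, 128, materialize_scores=False)
    print(f"S={S:6d}: standard AI = {standard:6.1f}, fused AI = {fused:8.1f}")
```

The standard variant saturates near d/2 as S grows, while the fused variant keeps growing with S, which is what moves it across the ridge point.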

The Three Bounds: Compute, Bandwidth, Communication

For multi-GPU training and inference, there is a third bound that the single-GPU roofline ignores: communication bandwidth. When your model is sharded across 8 GPUs with tensor parallelism, every attention layer requires an AllReduce of the output. The AllReduce transfers data proportional to the tensor size over NVLink or InfiniBand.

// Three-way roofline for distributed operations:

// Time to compute: t_comp = FLOPs / (N_GPUs × Π)
// Time for memory: t_mem = Bytes / β
// Time for comm: t_comm = MessageSize / BW_network

// For an AllReduce of a [4096, 4096] FP16 matrix over 8 GPUs:
MessageSize = 2 × 4096 × 4096 = 33.6 MB
// Ring AllReduce: each GPU sends (and receives) 2×(N-1)/N × MessageSize
// On 8 GPUs: 2 × 7/8 × 33.6 = 58.8 MB per GPU

// H100 NVLink: 900 GB/s bidirectional (450 GB/s per direction)
t_comm = 58.8 MB / 450 GB/s = 0.131 ms

// InfiniBand 400G (inter-node): 50 GB/s effective
t_comm = 58.8 MB / 50 GB/s = 1.18 ms

// For a single transformer layer at 70B scale:
// Compute time: ~2.5 ms (TP=8, H100)
// AllReduce time (NVLink): ~0.13 ms (5% overhead)
// AllReduce time (InfiniBand): ~1.18 ms (47% overhead!)

// This is why NVLink matters so much for tensor parallelism,
// and why pipeline parallelism (fewer, larger messages) is
// preferred for inter-node communication.

The three-way roofline analysis is what separates a staff engineer from a senior. A senior knows that attention is memory-bound. A staff engineer can tell you whether the AllReduce or the attention kernel itself is the bottleneck for a specific sharding configuration, and can redesign the parallelism strategy accordingly.
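
The communication term is easy to script so it can be compared against compute time for any sharding configuration. A minimal sketch, assuming a bandwidth-only ring AllReduce model (no per-hop latency, no overlap with compute) and the link speeds quoted above; `ring_allreduce_ms` is a hypothetical helper, not a library call.

```python
def ring_allreduce_ms(message_bytes: float, n_gpus: int, link_gbs: float) -> float:
    """Bandwidth-only ring AllReduce time in milliseconds.

    Each GPU sends 2*(N-1)/N of the message over its link
    (reduce-scatter phase + all-gather phase). Latency terms are ignored.
    """
    per_gpu_bytes = 2 * (n_gpus - 1) / n_gpus * message_bytes
    return per_gpu_bytes / (link_gbs * 1e9) * 1e3


message = 2 * 4096 * 4096          # [4096, 4096] FP16 tensor, in bytes
layer_compute_ms = 2.5             # per-layer compute time quoted in the text

for name, gbs in (("NVLink", 450), ("InfiniBand", 50)):
    t = ring_allreduce_ms(message, 8, gbs)
    print(f"{name:10s}: {t:.3f} ms ({t / layer_compute_ms:.0%} of layer compute)")
```

Small messages are latency-dominated in practice, so treat this as a lower bound on communication time.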

The Roofline Calculator

python
import dataclasses

# ── GPU specifications ──
@dataclasses.dataclass
class GPUSpec:
    name: str
    peak_tflops_fp16: float    # Tensor Core peak
    hbm_bandwidth_tbs: float  # HBM bandwidth in TB/s
    sram_per_sm_kb: float     # Shared memory per SM in KB
    num_sms: int
    nvlink_bw_gbs: float     # NVLink bandwidth per direction in GB/s

    @property
    def ridge_point(self) -> float:
        """Arithmetic intensity where compute = bandwidth bound."""
        return self.peak_tflops_fp16 * 1e12 / (self.hbm_bandwidth_tbs * 1e12)

    def achieved_tflops(self, arithmetic_intensity: float) -> float:
        """Roofline model: min(peak, AI * bandwidth)."""
        bandwidth_limited = arithmetic_intensity * self.hbm_bandwidth_tbs
        return min(self.peak_tflops_fp16, bandwidth_limited)

    def efficiency(self, arithmetic_intensity: float) -> float:
        """What fraction of peak compute we achieve."""
        return self.achieved_tflops(arithmetic_intensity) / self.peak_tflops_fp16

H100 = GPUSpec(
    name="H100 SXM5",
    peak_tflops_fp16=989,
    hbm_bandwidth_tbs=3.35,
    sram_per_sm_kb=228,
    num_sms=132,
    nvlink_bw_gbs=450,
)

A100 = GPUSpec(
    name="A100 SXM4",
    peak_tflops_fp16=312,
    hbm_bandwidth_tbs=2.0,
    sram_per_sm_kb=164,
    num_sms=108,
    nvlink_bw_gbs=300,
)

# ── Roofline analysis for common operations ──

def matmul_roofline(M, K, N, bytes_per_elem=2, gpu=H100):
    """Roofline analysis for C[M,N] = A[M,K] @ B[K,N]."""
    flops = 2 * M * K * N
    bytes_moved = bytes_per_elem * (M*K + K*N + M*N)
    ai = flops / bytes_moved
    tflops = gpu.achieved_tflops(ai)
    time_ms = flops / (tflops * 1e12) * 1e3
    bound = "COMPUTE" if ai > gpu.ridge_point else "MEMORY"
    print(f"Matmul [{M},{K}]x[{K},{N}] on {gpu.name}:")
    print(f"  FLOPs:      {flops:.2e}")
    print(f"  Bytes:      {bytes_moved:.2e}")
    print(f"  AI:         {ai:.1f} FLOPs/byte")
    print(f"  Ridge:      {gpu.ridge_point:.1f} FLOPs/byte")
    print(f"  Bound:      {bound}")
    print(f"  Achieved:   {tflops:.1f} TFLOPS ({gpu.efficiency(ai)*100:.1f}%)")
    print(f"  Time:       {time_ms:.3f} ms")
    return ai, tflops, time_ms

def attention_roofline(seq_len, d_head, n_heads, gpu=H100):
    """Roofline for standard (non-Flash) multi-head attention."""
    S, d, h = seq_len, d_head, n_heads
    # Per head: 2 matmuls (QK^T, PV) + softmax
    flops_per_head = 2 * (2 * S * S * d) + 5 * S * S
    total_flops = h * flops_per_head
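    # Without FlashAttention: the S×S score matrix round-trips through HBM.
    # Elements moved: Q, K, V, O once each (4*S*d) + scores written by QK^T,
    # read and rewritten by softmax, read by PV (4*S*S). Matches 8Sd + 8S^2 above.
    bytes_per_head = 2 * (4*S*d + 4*S*S)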
    total_bytes = h * bytes_per_head
    ai = total_flops / total_bytes
    tflops = gpu.achieved_tflops(ai)
    bound = "COMPUTE" if ai > gpu.ridge_point else "MEMORY"
    print(f"\nAttention S={S}, d={d}, h={h} on {gpu.name}:")
    print(f"  AI:       {ai:.1f} FLOPs/byte  ({bound})")
    print(f"  Achieved: {tflops:.1f} TFLOPS ({gpu.efficiency(ai)*100:.1f}%)")
    return ai

# ── Run examples ──
matmul_roofline(4096, 4096, 4096)      # Square matmul
matmul_roofline(1, 4096, 4096)         # Batch=1 inference
matmul_roofline(4096, 4096, 11008)    # LLaMA FFN up-projection

attention_roofline(2048, 128, 32)     # Standard LLM attention
attention_roofline(32768, 128, 32)    # Long-context attention
Why batch=1 matmul is memory-bound. Run matmul_roofline(1, 4096, 4096) above. The result: AI = 1.0 FLOPs/byte, meaning we are at 0.3% of the ridge point. The GPU achieves only 3.35 TFLOPS out of a possible 989 — barely 0.3% efficiency. This is why single-request LLM inference is so slow: each token generation is a batch=1 matrix-vector multiply, and the GPU is almost entirely idle, waiting for data to arrive from HBM. This is the fundamental reason batching, speculative decoding, and quantization (which reduces bytes transferred) matter so much for LLM serving.
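
The same reasoning gives a quick upper bound on single-stream decoding speed: if every weight must be read from HBM once per generated token, tokens per second cannot exceed bandwidth divided by weight bytes. The sketch below assumes exactly that and ignores KV-cache reads, activations, and kernel overheads; the 7B model size is only an example and the function name is hypothetical.

```python
def decode_tokens_per_sec_bound(n_params: float, bytes_per_param: float,
                                hbm_tbs: float = 3.35) -> float:
    """Upper bound on batch=1 decode throughput for a bandwidth-bound model.

    Assumption: each token requires streaming all weights from HBM once.
    """
    weight_bytes = n_params * bytes_per_param
    return hbm_tbs * 1e12 / weight_bytes


for label, bpp in (("FP16", 2.0), ("FP8", 1.0), ("INT4", 0.5)):
    bound = decode_tokens_per_sec_bound(7e9, bpp)
    print(f"7B @ {label}: <= {bound:.0f} tokens/s on one H100")
```

Halving the bytes per parameter doubles the bound, which is why weight quantization speeds up batch=1 decoding even though it does nothing for peak FLOPS.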

Interactive Roofline Plot

[Interactive figure: Roofline Explorer. Controls for matrix dimensions M, K, N and GPU model (defaults: M=1024, K=4096, N=4096, H100) move the operation's point on the roofline plot. Below the ridge: memory-bound. Above: compute-bound.]

Predicting Real-World Throughput

The roofline gives you an upper bound. Real kernels fall below the roofline for several reasons:

Inefficiency | What it means | Typical penalty
Occupancy | Not enough warps in flight to hide memory latency | 10-50% below roofline if occupancy < 50%
Bank conflicts | Multiple threads accessing the same shared memory bank | 2-32x slowdown on shared memory loads
Non-coalesced access | Threads in a warp reading non-contiguous memory | 2-32x slowdown on global memory loads
Instruction overhead | Control flow, address computation, loop overhead | 5-15% below roofline for simple kernels
Tail effects | Last block of a grid is partially filled | Depends on grid dimensions vs SM count
L2 thrashing | Working set too large for L2 cache | Effective bandwidth drops below peak HBM rate

A well-optimized kernel achieves 70-85% of the roofline prediction. If you are below 50%, look for one of the above issues. If you are above 85%, you are doing excellent work. Above 90% is exceptional and usually requires deep hardware-specific tuning.

The Reiner Pope methodology. When analyzing any new operation: (1) Compute the FLOPs. (2) Compute the bytes transferred. (3) Take the ratio (AI). (4) Compare to the ridge point. (5) Predict the time. (6) Measure and compare. If measured is >2x slower than predicted, you have a bug — probably non-coalesced access or a bank conflict. If measured is within 1.3x of predicted, your kernel is good enough and further optimization has diminishing returns. This 60-second analysis saves hours of aimless profiling.
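
Steps (5) and (6) of the methodology can be captured in a few lines. This is a sketch using the thresholds stated above (within 1.3x is good enough, more than 2x suggests a bug); the middle label and both function names are illustrative choices, not part of the original methodology.

```python
def roofline_time_ms(flops: float, bytes_moved: float,
                     peak_tflops: float = 989.0, hbm_tbs: float = 3.35) -> float:
    """Predicted lower-bound time: max of compute time and memory time (H100 defaults)."""
    t_comp = flops / (peak_tflops * 1e12)
    t_mem = bytes_moved / (hbm_tbs * 1e12)
    return max(t_comp, t_mem) * 1e3


def triage(measured_ms: float, predicted_ms: float) -> str:
    """Classify a measurement against the roofline prediction."""
    ratio = measured_ms / predicted_ms
    if ratio <= 1.3:
        return "good enough"
    if ratio > 2.0:
        return "likely bug"
    return "room to optimize"


# Example: 4096^3 FP16 matmul with a hypothetical measured time of 0.21 ms.
n = 4096
predicted = roofline_time_ms(2 * n**3, 2 * 3 * n * n)
print(f"predicted {predicted:.3f} ms -> {triage(0.21, predicted)}")
```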

Roofline for Different Operations

Here is a table of common deep learning operations and their typical arithmetic intensity. Memorize these — they come up constantly in interviews.

Operation | FLOPs | Bytes (FP16) | AI (FLOPs/byte) | Bound on H100
Large matmul [4096×4096]×[4096×4096] | 1.37×10^11 | 1.01×10^8 | 1365 | Compute
Batch=1 matmul [1×4096]×[4096×4096] | 3.36×10^7 | 3.36×10^7 | 1.0 | Memory (0.3%)
LayerNorm (d=4096) | ~5 per elem | 4 per elem | 1.25 | Memory (0.4%)
Softmax (S=2048) | ~5 per elem | 4 per elem | 1.25 | Memory (0.4%)
GELU / SiLU | ~10 per elem | 4 per elem | 2.5 | Memory (0.8%)
Embedding lookup | 0 | 2 per elem | 0 | Memory (pure)
Attention (standard) | 4·S^2·d | ~8·S^2 | d/2 ≈ 64 | Memory (21%)
Attention (Flash) | 4·S^2·d | ~8Sd | S/2 | Compute for S>590
The key pattern. Matrix multiplications are compute-bound (except at batch=1). Everything else — LayerNorm, softmax, activation functions, attention score computation — is memory-bound. This is why kernel fusion matters: fusing LayerNorm+bias+dropout into a single kernel cuts memory traffic by 3x for those operations. But fusing two large matmuls is pointless — they are already compute-bound.
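
The 3x figure for fusion follows from counting passes over the tensor. A minimal sketch, assuming each unfused elementwise op reads and writes the full activation once and that small parameter vectors and the dropout mask are negligible; the function name is hypothetical.

```python
def elementwise_traffic_bytes(n_elems: int, n_ops: int, fused: bool,
                              bytes_per_elem: int = 2) -> int:
    """HBM bytes moved by a chain of elementwise ops over one tensor.

    Unfused: every op does one read + one write of the tensor.
    Fused:   the whole chain does a single read + a single write.
    """
    passes = 1 if fused else n_ops
    return passes * 2 * n_elems * bytes_per_elem


n = 4096 * 4096  # activation elements (illustrative size)
unfused = elementwise_traffic_bytes(n, n_ops=3, fused=False)  # LayerNorm, bias, dropout
fused = elementwise_traffic_bytes(n, n_ops=3, fused=True)
print(f"unfused: {unfused / 1e6:.0f} MB, fused: {fused / 1e6:.0f} MB, "
      f"ratio: {unfused / fused:.0f}x")
```

Because these ops are memory-bound, the reduction in bytes translates almost directly into a reduction in runtime.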
Scaling Book Exercises — Work these with pen and paper first, then check your answer.
Exercise 1.1: Arithmetic Intensity

A matrix multiply computes C = A×B where A is [2048, 4096] and B is [4096, 2048]. Assume BF16 (2 bytes). What is the arithmetic intensity in FLOPs/byte? (Hint: FLOPs = 2MNK, bytes = 2(MK + KN + MN))

Derivation:
FLOPs = 2 × 2048 × 2048 × 4096 = 34,359,738,368
Bytes = 2 × (2048×4096 + 4096×2048 + 2048×2048) = 2 × (8M + 8M + 4M) = 41,943,040
AI = 34.36B / 41.94M ≈ 819 FLOPs/byte

This is well above the H100 ridge point (~295), so this matmul is compute-bound. Good — tensor cores will be fully utilized.

Exercise 1.2: Is Attention Memory-Bound?

Standard attention (no FlashAttention) with S=2048, d=128, batch=1, single head. FLOPs ≈ 4S²d. Bytes ≈ 2(2Sd + S²) × 2 for BF16. What is the arithmetic intensity? Is this compute-bound or memory-bound on H100?

Derivation:
FLOPs = 4 × 2048² × 128 = 2,147,483,648
Bytes = 2 × (2 × 2048 × 128 + 2048²) × 2 = 2 × (524,288 + 4,194,304) × 2 = 18,874,368
AI = 2.15B / 18.87M ≈ 113.8 FLOPs/byte

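The H100 ridge point is ~295 FLOPs/byte. AI=114 < 295, so standard attention is memory-bound. (Worked Example 2 arrives at d/2 ≈ 64 because it also counts the softmax pass reading and rewriting the S×S matrix; the conclusion is the same under either byte count.) This is exactly why FlashAttention was invented: by tiling the computation and avoiding materializing the S×S matrix to HBM, it increases effective arithmetic intensity.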

Staff-level question: You are benchmarking a custom Pallas kernel for GQA (Grouped-Query Attention) with S=4096, d_head=128, 8 KV heads, 32 query heads, on an H100. The kernel achieves 320 TFLOPS. Your colleague says "that's only 32% of peak, the kernel is bad." Using roofline analysis, what is your response?

Chapter 2: GPU Memory Hierarchy

The roofline model told us that most operations in a transformer are memory-bound. The natural follow-up question: which memory? A GPU does not have one flat memory space. It has a hierarchy of four levels, each with vastly different bandwidth, latency, and capacity. Understanding this hierarchy is the foundation of all kernel optimization.

Think of the GPU memory hierarchy as the places you keep documents at work. Registers are the desk you are sitting at: instant access, but tiny (a few papers). Shared memory (SRAM) is the filing cabinet in your office: a few steps away, fast to access, holds a decent amount. L2 cache is the building's mailroom: it takes a walk, but it is much larger. HBM is the warehouse across town: it takes a truck to deliver, but it stores everything.

The Four Levels: Real Numbers on H100

Level | Technology | Bandwidth | Latency | Capacity per SM | Total | Scope
Registers | Flip-flops | ~78 TB/s | 0 cycles (same clock) | 256 KB | ~33 MB | Per thread
Shared Memory (SRAM) | SRAM | ~19 TB/s | ~23 cycles | Up to 228 KB | ~30 MB | Per block (CTA)
L2 Cache | SRAM | ~12 TB/s | ~200 cycles | n/a | 50 MB | All SMs
HBM3 | 3D-stacked DRAM | 3.35 TB/s | ~400-600 cycles | n/a | 80 GB | All SMs

Look at the bandwidth column. Registers are 23x faster than HBM. Shared memory is 5.7x faster than HBM. These are not small differences — they are order-of-magnitude gaps that completely determine kernel performance.

Data movement is the bottleneck. If you remember only one thing from this chapter, let it be this: the vast majority of your kernel optimization effort goes into data placement — getting data to the right level of the memory hierarchy before computation needs it. The actual arithmetic (multiply, add) is nearly free on a modern GPU. The cost is getting the numbers to the ALU in the first place.

Why This Hierarchy Exists: Physics

The hierarchy is not an arbitrary engineering choice. It is dictated by physics.

Speed vs. capacity trade-off. Flip-flops (registers) toggle in a single clock cycle but require 6 transistors each. One 32-bit register = 192 transistors. To store 80 GB at register density would require trillions of transistors and consume absurd power. SRAM (shared memory) uses 6 transistors per bit but with a different layout — it is denser than registers but slower because the read circuitry adds latency. DRAM (the technology behind HBM) uses just 1 transistor + 1 capacitor per bit, making it ~6x denser than SRAM, but the capacitor must be refreshed periodically and the read is destructive (requiring precharge), adding ~400 cycles of latency.

Speed of light limits. An electrical signal travels about 15 cm per nanosecond on a chip. At 2 GHz clock speed, each cycle is 0.5 ns, so a signal can travel at most 7.5 cm in one cycle. The HBM stacks sit on the edge of the GPU die, about 2-3 cm from the SM array. Even at the speed of light, a round trip takes ~0.4 ns = ~1 cycle just for the physical propagation. Add the DRAM access protocol (row activate, column select, data transfer, precharge) and you get 400+ cycles.

The wire delay problem. On a modern chip, transistor switching is no longer the bottleneck — wire delay is. Moving a bit across a 1 mm wire takes longer than a transistor can switch. This is why local memory (registers, shared memory on the same SM) is fast and global memory (HBM on the edge of the die) is slow. Every level of the hierarchy represents a different physical distance from the compute units.

Registers: The Fastest Memory You Never Think About

Each thread on an H100 SM has access to up to 255 32-bit registers (the maximum register file allocation per thread). The total register file per SM is 256 KB. Registers are the only memory that can be accessed at full speed — zero additional latency, and the read happens in the same cycle as the instruction that uses the value.

// Register budget calculation:
// SM has 256 KB = 65,536 registers (32-bit each)
// If your kernel uses 64 registers per thread:
Max threads per SM = 65,536 / 64 = 1024 threads
Max warps per SM = 1024 / 32 = 32 warps

// H100 can schedule up to 64 warps per SM.
// With 64 regs/thread: occupancy = 32/64 = 50%
// With 32 regs/thread: occupancy = 64/64 = 100%

// The tradeoff: fewer registers = more warps = better latency hiding
// more registers = fewer warps = but each warp computes faster
// Typically: 50-75% occupancy is optimal. 100% often wastes registers on
// redundant loads. Below 25%, latency hiding is insufficient.
Register spilling. When your kernel needs more registers than the compiler may allocate per thread (255 at most, fewer if launch bounds or a register cap apply), the compiler "spills" excess values to local memory, which actually lives in HBM with L1 caching. A spilled access that misses the cache goes from 0 cycles to 400+ cycles. This is why you should check register usage with --ptxas-options=-v when compiling CUDA kernels (the output reports spill stores and spill loads in bytes), or check the XLA compilation log for Pallas kernels. Even without spilling, a kernel that uses more than 128 registers per thread is limited to at most 16 warps per SM (25% occupancy), so high register counts deserve a second look.
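The register budget arithmetic above can be sketched as a small helper. This is a simplified model that assumes the 65,536-register file and 64-warp limit quoted in this section; it ignores allocation granularity and the shared memory and block-count limits that also cap occupancy. The function name is illustrative, not part of any CUDA tool.

```python
REGS_PER_SM = 65_536       # 256 KB register file / 4 bytes per register
MAX_WARPS_PER_SM = 64      # H100 scheduler limit quoted in this section
WARP_SIZE = 32


def register_limited_occupancy(regs_per_thread: int) -> float:
    """Occupancy if registers were the only limit (simplified model)."""
    max_threads = REGS_PER_SM // regs_per_thread
    warps = min(max_threads // WARP_SIZE, MAX_WARPS_PER_SM)
    return warps / MAX_WARPS_PER_SM


for regs in (32, 64, 128):
    print(f"{regs} regs/thread -> {register_limited_occupancy(regs):.0%} occupancy")
```

Treat the result as an upper bound: the occupancy a profiler reports can only be equal or lower.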

Shared Memory (SRAM): The Programmer-Controlled Cache

Shared memory is the most important memory level for kernel developers because it is the one you explicitly control. Unlike L1/L2 caches (which are hardware-managed), you decide exactly what data goes into shared memory and when.

On H100, each SM has 256 KB of combined L1 cache and shared memory, of which up to 228 KB can be configured as shared memory. For most hand-written kernels, maximizing shared memory is the right choice because you control the access pattern exactly.

// Shared memory is organized into 32 banks, each 4 bytes wide.
// All 32 banks can be accessed simultaneously = 32 × 4 = 128 bytes per cycle.
// At ~1.8 GHz boost clock:
SRAM bandwidth per SM = 128 bytes × 1.8 GHz = 230.4 GB/s
Total across 132 SMs = 132 × 230.4 = ~30.4 TB/s

// Compare to HBM: 3.35 TB/s
// Shared memory aggregate bandwidth: ~9x faster than HBM

// But the capacity is tiny: 228 KB per SM vs 80 GB for HBM
// Ratio: HBM is ~2,700x larger per SM equivalent

// The kernel developer's job: tile the computation so that
// each tile fits in 228 KB of shared memory, then stream
// tiles from HBM one at a time, reusing each element as
// many times as possible before evicting.

Bank Conflicts: The Shared Memory Performance Trap

Shared memory's 32 banks are the source of its high bandwidth — but also its most common performance pitfall. When two threads in the same warp access different addresses that map to the same bank, the accesses serialize. This is called a bank conflict.

// Bank assignment: bank = (byte_address / 4) % 32

// Example: float array[32][32] in shared memory
// array[0][0] = bank 0
// array[0][1] = bank 1
// ...
// array[0][31] = bank 31
// array[1][0] = bank 0 (stride 32 elements = 128 bytes, 128/4 % 32 = 0)

// Reading a ROW: array[row][threadIdx.x]
// Thread 0 reads bank 0, thread 1 reads bank 1, ... no conflicts

// Reading a COLUMN: array[threadIdx.x][col]
// Thread 0 reads array[0][col] = bank col
// Thread 1 reads array[1][col] = bank col (same bank!)
// All 32 threads hit the same bank = 32-way conflict = 32x slower

// FIX: pad the inner dimension
// float array[32][32 + 1]; // 33 elements per row
// array[0][col] = bank col
// array[1][col] = bank (col + 33) % 32 = (col + 1) % 32
// Each row shifts by one bank = conflict-free column access
The padding trick. Adding +1 to the inner dimension of shared memory arrays is perhaps the most common optimization in CUDA kernels. It costs negligible extra memory (one float per row) but completely eliminates column-access bank conflicts. You will see __shared__ float tile[BLOCK_M][BLOCK_K + 1] in virtually every optimized matmul kernel. In Pallas, the equivalent is choosing tile dimensions that are not powers of 2, or using explicit swizzling patterns.
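The effect of padding can be checked without a GPU by applying the bank formula above to every thread in a warp. The sketch below assumes 32 banks of 4-byte words and a float array; `column_conflict_degree` is a hypothetical helper name.

```python
from collections import Counter

NUM_BANKS = 32
WORD_BYTES = 4


def column_conflict_degree(row_stride_words: int, col: int = 0) -> int:
    """Worst-case number of threads in one warp that land on the same bank
    when thread t reads array[t][col] from a float array with this row stride."""
    banks = []
    for t in range(32):
        byte_addr = (t * row_stride_words + col) * WORD_BYTES
        banks.append((byte_addr // WORD_BYTES) % NUM_BANKS)
    return max(Counter(banks).values())


print(column_conflict_degree(32))  # unpadded [32][32]: 32-way conflict
print(column_conflict_degree(33))  # padded [32][33]: 1, conflict-free
```

Any odd row stride gives the same conflict-free result, because an odd number shares no factor with 32.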

HBM: The Vast, Slow Main Memory

HBM (High Bandwidth Memory) is a stack of DRAM dies connected to the GPU die through a silicon interposer. The H100 SXM has five active 16 GB HBM3 stacks, totaling 80 GB at 3.35 TB/s aggregate bandwidth.

Despite its name ("High Bandwidth Memory"), HBM is the slowest memory level. Its bandwidth is "high" relative to traditional DDR5 (which peaks at ~0.1 TB/s), but low relative to the GPU's compute capability.

// HBM capacity planning for a 70B parameter model in FP16:
Model weights: 70B × 2 bytes = 140 GB    does not fit on one H100 (80 GB)
// Need tensor parallelism TP=2 minimum, TP=4 comfortable

// With TP=4: each GPU holds 140/4 = 35 GB of weights
// Remaining: 80 - 35 = 45 GB for KV cache + activations

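// KV cache for batch=32, seq_len=4096, assuming full multi-head attention
// (every head caches its own K and V, so the per-token width is d_model):
// 2 (K+V) × n_layers × d_model × seq_len × batch × 2 bytes
// For a 70B-class model (80 layers, d=8192):
KV cache = 2 × 80 × 8192 × 4096 × 32 × 2 = 343.6 GB total
Per GPU (TP=4) = 343.6/4 = 85.9 GB    does not fit!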

// This is why KV cache compression (quantization to INT8 or FP8,
// GQA to reduce heads, PagedAttention to avoid fragmentation)
// is critical for serving large models.

How Tensor Cores Fit In

Tensor Cores are specialized matrix multiply units on each SM. They are the reason H100 achieves 989 TFLOPS for FP16 while standard CUDA cores top out at 67 TFLOPS. A single Tensor Core instruction on H100 computes a 16×8×16 matrix multiply-accumulate.

// H100 4th-gen Tensor Cores:
// Each SM has 4 Tensor Cores
// Each Tensor Core instruction computes: D[16,8] += A[16,16] × B[16,8]
// FLOPs per instruction: 2 × 16 × 16 × 8 = 4096 FLOPs
// Naive assumption: one instruction per Tensor Core per cycle
// Per SM: 4 × 4096 = 16,384 FLOPs per cycle
// H100 has 132 SMs at ~1.83 GHz boost:
Total TC FLOPS = 132 × 16384 × 1.83 GHz = ~3.96 × 10^15
// Wait, that's ~3960 TFLOPS, not 989. What gives?

// The 989 TFLOPS spec is the dense FP16 Tensor Core rate.
// To match it, the 16x8x16 instruction (HMMA) must take 4 cycles, not 1,
// i.e. 4096 / 4 = 1024 FLOPs per Tensor Core per cycle.
// So: 132 × 4 × (2 × 16 × 16 × 8) / 4 cycles × 1.83 GHz
= 132 × 4 × 1024 × 1.83e9 = ~989 TFLOPS

// The WGMMA instruction (new in Hopper) can read its operands
// directly from shared memory instead of staging them in registers.
// This reduces register pressure and enables higher throughput.
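The same back-of-the-envelope check can be written as a few lines of Python. This is only a sketch using the SM count, Tensor Core count, and boost clock quoted in this section; the function name and the per-cycle figures are illustrative assumptions.

```python
def tensor_core_tflops(flops_per_tc_per_cycle: float,
                       num_sms: int = 132,
                       tcs_per_sm: int = 4,
                       clock_ghz: float = 1.83) -> float:
    """Aggregate Tensor Core throughput in TFLOPS."""
    flops_per_second = num_sms * tcs_per_sm * flops_per_tc_per_cycle * clock_ghz * 1e9
    return flops_per_second / 1e12


naive = tensor_core_tflops(4096)        # one 16x8x16 MMA per cycle
actual = tensor_core_tflops(4096 / 4)   # one 16x8x16 MMA every 4 cycles
print(f"naive: {naive:.0f} TFLOPS, actual: {actual:.0f} TFLOPS")
```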

TMA: The Tensor Memory Accelerator (Hopper/Blackwell)

H100 introduced a dedicated hardware unit called the Tensor Memory Accelerator (TMA). TMA offloads data movement from the SM's compute pipeline to a separate engine. Instead of threads computing addresses and issuing loads, a single thread programs the TMA with a tensor descriptor (base pointer, shape, strides, data type), and the TMA fetches the entire tile asynchronously.

// Without TMA (traditional approach):
// 1. Each thread computes its address: addr = base + row*stride + col
// 2. Each thread issues a load: ld.global
// 3. Synchronize: __syncthreads()
// 4. Data is now in shared memory
// Cost: address computation uses ALU cycles, loads occupy warp slots

// With TMA (Hopper):
// 1. One thread programs TMA: cp.async.bulk.tensor.shared.global
// 2. TMA fetches the tile independently of SM compute
// 3. SM continues computing on the previous tile
// 4. Arrive barrier signals when the tile is ready
// Cost: near-zero SM overhead, perfect overlap of compute and transfer

// In Pallas (JAX's kernel language for TPU/GPU):
// TMA is abstracted through BlockSpec and grid_spec
// You specify tile shapes and the compiler generates TMA calls
Why TMA matters for Pallas. When you write a Pallas kernel, you specify BlockSpec for your input and output tensors. The Pallas compiler on Hopper translates these into TMA descriptor loads. This means a well-written Pallas kernel can achieve near-CUTLASS performance without any manual memory management — the compiler handles TMA for you. But you still need to choose tile sizes wisely: tiles too small waste TMA setup overhead; tiles too large spill out of shared memory.

Worked Example: Where Does a Matmul Spend Its Time?

Let us trace a single matmul through the memory hierarchy to understand where time is spent.

// Operation: C[4096, 4096] = A[4096, 4096] × B[4096, 4096] in FP16
// Using a BLOCK_M=128, BLOCK_N=128, BLOCK_K=32 tiling strategy

// ── Step 1: How many tiles? ──
Tiles in M: 4096/128 = 32
Tiles in N: 4096/128 = 32
Tiles in K: 4096/32 = 128
Total tile computations: 32 × 32 = 1024 output tiles
Each output tile requires 128 K-steps

// ── Step 2: Data per K-step ──
Load A tile: 128 × 32 × 2 bytes = 8 KB
Load B tile: 32 × 128 × 2 bytes = 8 KB
Total per K-step: 16 KB from HBM → shared memory

// ── Step 3: HBM transfer time per K-step ──
// If one SM had the whole 3.35 TB/s bus to itself:
t_HBM,alone = 16 KB / 3.35 TB/s = 4.9 ns
// With all 132 SMs streaming tiles at once, each SM's fair share
// is 3.35 TB/s / 132 = 25.4 GB/s:
t_HBM,shared = 16 KB / 25.4 GB/s = 645 ns

// ── Step 4: Compute per K-step ──
FLOPs = 2 × 128 × 128 × 32 = 1,048,576 = 1.05M FLOPs
// One SM with 4 Tensor Cores delivers 989/132 = 7.49 TFLOPS
t_compute = 1.05e6 / 7.49e12 = 140 ns

// ── Step 5: Compare ──
// Double buffering: while computing on tile k, load tile k+1.
// Effective time per K-step = max(t_HBM, t_compute).
// One SM alone: t_compute / t_HBM,alone = 140 / 4.9 = 29x → compute-bound.
// All 132 SMs together: t_HBM,shared / t_compute = 645 / 140 = 4.6x → memory-bound.

// The same conclusion from global totals:
// With 1024 tiles across 132 SMs, each SM gets ~8 tiles.
// Total HBM reads: 1024 tiles × 128 steps × 16 KB = 2.15 GB
// At 3.35 TB/s: 0.64 ms for all HBM traffic
// Total FLOPs: 2 × 4096^3 = 1.374e11
// At 989 TFLOPS: 0.139 ms for all compute
// 0.64 ms > 0.139 ms → Globally, this matmul is memory-bound!
// (This counts every tile load as HBM traffic; reuse in the L2 cache
// absorbs part of it in practice.)

// Hmm, but we computed AI = 1365 earlier and said it was compute-bound.
// Resolution: AI = 1365 counts total FLOPs / total bytes IF you only
// read each element once. With tiling, each A-tile row and B-tile column
// is read 32 times (once per output tile in that row/column).
// Effective bytes = 2 × 2 × 4096 × 4096 × 32 = 2.15 GB (matches above)
// Effective AI = 1.374e11 / 2.15e9 = 64 FLOPs/byte → MEMORY-BOUND
// (64 is well below the H100 ridge point of ~295 FLOPs/byte.)

// The fix: increase tile sizes (BLOCK_M=256, BLOCK_N=256) to reduce
// the number of times each row/column is re-read. This is why
// CUTLASS uses large tiles with multiple pipeline stages.
The tiling paradox. A naive FLOPs/bytes calculation says large matmul is compute-bound. But if you tile it and count the actual HBM traffic (each tile re-read multiple times), it can become memory-bound. The key optimization is maximizing data reuse: larger tiles, double buffering, and on Hopper, using TMA to overlap loads with computation via warp specialization (some warps compute while others manage data movement). This is exactly what CUTLASS 3.x and FlashAttention-3 do.
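To see how tile size moves a matmul along the roofline, here is a minimal sketch of the per-tile arithmetic intensity. It assumes every A and B tile is fetched from HBM on each K-step (no L2 reuse) and FP16 operands; the helper name and the tile shapes tried are illustrative.

```python
def tile_arithmetic_intensity(block_m: int, block_n: int, bytes_per_elem: int = 2) -> float:
    """FLOPs per HBM byte for one K-step of a tiled matmul (BLOCK_K cancels out)."""
    flops_per_k = 2 * block_m * block_n                  # multiply + add per output element
    bytes_per_k = (block_m + block_n) * bytes_per_elem   # one A column slice + one B row slice
    return flops_per_k / bytes_per_k


H100_RIDGE = 989e12 / 3.35e12  # ~295 FLOPs/byte

for bm, bn in [(64, 64), (128, 128), (256, 256)]:
    ai = tile_arithmetic_intensity(bm, bn)
    bound = "compute" if ai >= H100_RIDGE else "memory"
    print(f"{bm}x{bn}: AI = {ai:.0f} FLOPs/byte -> {bound}-bound without cache reuse")
```

Doubling both tile dimensions doubles the intensity, which is why large tiles plus on-chip reuse are needed to approach the ridge point.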

Blackwell and Beyond: What Changes

Metric | A100 | H100 | B200 | Trend
FP16 Tensor TFLOPS | 312 | 989 | ~2,250 | ~2.3-3.2x per generation
HBM Bandwidth (TB/s) | 2.0 | 3.35 | ~8.0 | ~1.7-2.4x per generation
Ridge Point (FLOPs/byte) | 156 | 295 | ~281 | Nearly doubled A100 to H100, then roughly flat
SRAM per SM (KB) | 164 | 228 | ~256 | Slow growth: physics limits
HBM Capacity (GB) | 80 | 80 | 192 | 2.4x: finally fits 70B FP16
The widening gap. From A100 to H100, FP16 compute grew ~3.2x while bandwidth grew only ~1.7x, so the ridge point nearly doubled. B200 raises bandwidth about as fast as FP16 compute, which holds the ridge point near 280-295 FLOPs/byte rather than lowering it. Any operation with lower arithmetic intensity than that stays memory-bound. This is why FlashAttention-style algorithms, kernel fusion, and memory-efficient training (gradient checkpointing, mixed precision) stay important with each GPU generation. If you are a staff engineer planning for the next hardware generation, your strategy should assume a roofline at least as memory-constrained as today's.
Staff-level question: You are profiling a fused LayerNorm+GELU kernel on H100. The kernel processes a tensor of shape [4096, 8192] in FP16. Nsight Compute reports the kernel achieves 2.9 TB/s of HBM bandwidth and 8.2 TFLOPS of compute. The kernel takes 0.055 ms. Is this kernel well-optimized? What is the bottleneck, and what would you check next?

Chapter 3: Transformer Math

In a scaling law meeting, someone asks: "How many FLOPs to train a 70B model on 2T tokens?" You need to answer in under 10 seconds. In a system-design interview, the question is: "How much memory does a 400B MoE model need at inference time with batch=64, seq_len=8192?" You need to derive it on the whiteboard in 2 minutes. This chapter gives you the formulas and worked examples to do both.

All of these are derived from first principles — no memorization required. Once you understand where each parameter lives and how each FLOP is spent, you can re-derive any formula on the spot.

Parameter Count: Deriving It Layer by Layer

A standard transformer decoder (GPT-style) has L identical layers, each containing a multi-head attention block and a feed-forward network (FFN). Let us count every parameter.

// Notation:
d = model dimension (hidden size)    e.g., 8192 for LLaMA 70B
L = number of layers    e.g., 80 for LLaMA 70B
h = number of attention heads    e.g., 64
h_kv = number of KV heads (for GQA)    e.g., 8
d_head = d / h    e.g., 8192/64 = 128
d_ff = FFN intermediate dimension    e.g., 28672 for LLaMA 70B (3.5d)
V = vocabulary size    e.g., 32000

// ═══ ATTENTION BLOCK (per layer) ═══

// Q projection: W_Q is [d, h × d_head] = [d, d]
// (h heads × d_head = d, so the query space is d-dimensional)
P_Q = d × d = d^2

// K projection: for GQA, only h_kv heads, not h
// W_K is [d, h_kv × d_head]
P_K = d × h_kv × d_head = d × (h_kv/h) × d = (h_kv/h) × d^2

// V projection: same as K for GQA
P_V = (h_kv/h) × d^2

// Output projection: W_O is [d, d]
P_O = d^2

// Total attention parameters per layer:
P_attn = d^2 + 2(h_kv/h)d^2 + d^2 = d^2(2 + 2h_kv/h)

// For standard MHA (h_kv = h): P_attn = 4d^2
// For GQA with h_kv = h/8 (LLaMA 70B): P_attn = d^2(2 + 2/8) = 2.25d^2

// ═══ FFN BLOCK (per layer) ═══

// Standard FFN: two linear layers
// W_up: [d, d_ff], W_down: [d_ff, d]
P_ffn_standard = 2 × d × d_ff

// SwiGLU FFN (LLaMA, Mistral): THREE linear layers
// W_gate: [d, d_ff], W_up: [d, d_ff], W_down: [d_ff, d]
// SwiGLU: output = (SiLU(x @ W_gate) × (x @ W_up)) @ W_down
P_ffn_swiglu = 3 × d × d_ff

// For LLaMA 70B: d_ff = 28672, d = 8192
P_ffn = 3 × 8192 × 28672 = 704,643,072 ≈ 705M per layer

// ═══ PER-LAYER NORMS ═══
// RMSNorm (no bias, just scale): d parameters per norm
// 2 norms per layer (pre-attention, pre-FFN)
P_norm = 2d (negligible)

// ═══ TOTAL PER LAYER ═══
P_layer = P_attn + P_ffn + P_norm
// For LLaMA 70B: 2.25 × 8192^2 + 3 × 8192 × 28672 + 2 × 8192
= 150,994,944 + 704,643,072 + 16,384
= 855,654,400 ≈ 856M per layer

// ═══ EMBEDDING ═══
P_embed = V × d = 32000 × 8192 = 262,144,000 ≈ 262M
// Often the embedding and output projection share weights (weight tying)
// LLaMA does NOT use weight tying, so add another V×d for the lm_head

// ═══ GRAND TOTAL ═══
P_total = L × P_layer + P_embed + P_lm_head + P_final_norm
= 80 × 855,654,400 + 262,144,000 + 262,144,000 + 8,192
= 68,452,352,000 + 524,296,192
= 68,976,648,192 ≈ 69.0B parameters

// Close to the "70B" marketing number: the name simply rounds ~69B up.
Quick parameter counting rules. For a standard transformer with MHA: P ≈ 12d^2·L (4d^2 attention + 8d^2 FFN per layer). For GQA with SwiGLU (modern LLMs): P ≈ (2 + 2h_kv/h)d^2·L + 3d·d_ff·L. The FFN typically dominates (about two thirds of the parameters in the standard layout, ~82% for LLaMA 70B) because d_ff is usually 2.5-4x larger than d.
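The layer-by-layer count above translates directly into a short function. This sketch assumes the GQA + SwiGLU + RMSNorm layout described in this section with no biases; the function and argument names are illustrative, not taken from any library.

```python
def transformer_params(d: int, n_layers: int, n_heads: int, n_kv_heads: int,
                       d_ff: int, vocab: int, tied_embeddings: bool = False) -> int:
    """Parameter count for a decoder-only transformer (GQA, SwiGLU, RMSNorm, no biases)."""
    d_head = d // n_heads
    attn = 2 * d * d + 2 * d * n_kv_heads * d_head   # W_Q and W_O, plus W_K and W_V
    ffn = 3 * d * d_ff                               # gate, up, down projections
    norms = 2 * d                                    # two RMSNorm scale vectors
    per_layer = attn + ffn + norms
    embed = vocab * d
    lm_head = 0 if tied_embeddings else vocab * d
    final_norm = d
    return n_layers * per_layer + embed + lm_head + final_norm


llama_70b = transformer_params(d=8192, n_layers=80, n_heads=64, n_kv_heads=8,
                               d_ff=28672, vocab=32000)
print(f"{llama_70b:,}")  # 68,976,648,192
```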

FLOPs Per Token: The 2P and 6P Rules

For every parameter, the forward pass performs approximately 2 FLOPs (one multiply and one add in a matrix multiplication). This gives us a beautifully simple rule:

// ═══ FORWARD PASS FLOPs ═══
// Each linear layer W[m, n] processing input x[batch, m]:
// Output = x @ W: FLOPs = 2 × batch × m × n
// The "2" is for the multiply-add in the dot product.
// For a single token (batch=1), FLOPs = 2 × m × n = 2 × parameters(W)

// Summing over all linear layers in the model:
Forward FLOPs per token ≈ 2P    where P = total parameters

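// Plus the attention score computation, which has no weights and so
// is not captured in "2P". For one token attending to S keys, per layer:
// Q·K^T: 2 × S × d_head per head → 2 × S × d across all h heads
// scores·V: another 2 × S × d
// This adds 4 × S × d × L FLOPs per token.
// For S=4096, d=8192, L=80: 4 × 4096 × 8192 × 80 = 1.07e10
// vs 2P = 2 × 69e9 = 1.38e11
// Attention FLOPs are ~8% of the parameter FLOPs at S=4096.
// Per token they grow as O(S) (O(S^2) per sequence), so they are
// negligible at short sequences and overtake 2P only when
// S > P / (2 × L × d) ≈ 53K tokens for this model.

// More precise formula:
Forward FLOPs per token = 2P + 4 × L × n_heads × d_head × S
                      = 2P + 4 × L × d × S
// (With causal masking a token sees S/2 keys on average, which halves
// the attention term to 2 × L × d × S.)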

// ═══ BACKWARD PASS FLOPs ═══
// The backward pass computes gradients w.r.t. both inputs and weights.
// For a linear layer y = xW:
// dL/dx = dL/dy @ W^T (same FLOPs as forward: 2mn)
// dL/dW = x^T @ dL/dy (same FLOPs as forward: 2mn)
// Total backward = 2 × forward per layer

Backward FLOPs per token ≈ 4P
Total training FLOPs per token ≈ 6P    (forward + backward)

// ═══ TOTAL TRAINING FLOPs ═══
// For D tokens of training data:
C = 6 × N × D    where N = parameters, D = tokens

// LLaMA 3 70B trained on ~15T tokens:
C = 6 × 70e9 × 15e12 = 6.3 × 10^24 FLOPs
// At 989 TFLOPS per H100, MFU=0.45:
// Effective per GPU: 445 TFLOPS
// On 2048 H100s: 2048 × 445e12 = 9.11e17 FLOPS
// Time: 6.3e24 / 9.11e17 = 6.91e6 seconds = 80 days
The 6ND rule. C = 6ND is the most important equation in frontier lab engineering. It connects three quantities that determine every training run: compute budget C (in FLOPs), model size N (parameters), and data size D (tokens). Given any two, you can derive the third. In a scaling law meeting, this is your starting point for everything.
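As a quick sanity check on the rule, the sketch below turns C = 6ND into wall-clock time. It assumes a constant MFU over the whole run and the 989 TFLOPS H100 peak used in this chapter; `training_days` is a hypothetical helper name.

```python
def training_days(n_params: float, n_tokens: float, n_gpus: int,
                  peak_flops: float = 989e12, mfu: float = 0.45):
    """Return (total training FLOPs, wall-clock days) under the C = 6ND approximation."""
    total_flops = 6 * n_params * n_tokens
    seconds = total_flops / (n_gpus * peak_flops * mfu)
    return total_flops, seconds / 86_400


flops, days = training_days(n_params=70e9, n_tokens=15e12, n_gpus=2048)
print(f"{flops:.2e} FLOPs, {days:.0f} days")
```

Holding the budget fixed and solving for D or N instead is the same one-line rearrangement.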

Memory Footprint: Where Every Byte Goes

Training a model requires storing four categories of data in GPU memory. Let us account for each one.

// ═══ 1. MODEL WEIGHTS ═══
// FP32: 4 bytes per param
// BF16: 2 bytes per param (standard for training)
// FP16: 2 bytes per param
Weights (BF16) = N × 2 bytes
// 70B model: 70e9 × 2 = 140 GB

// ═══ 2. OPTIMIZER STATES ═══
// AdamW maintains per-parameter:
// - First moment (m): same shape as weights, FP32 = 4 bytes
// - Second moment (v): same shape as weights, FP32 = 4 bytes
// - Master weights (FP32 copy for precision): 4 bytes
// Total optimizer per param: 12 bytes (with master weights)
Optimizer = N × 12 bytes
// 70B model: 70e9 × 12 = 840 GB

// ═══ 3. GRADIENTS ═══
// Same shape as weights, typically BF16 or FP32
Gradients (BF16) = N × 2 bytes
// 70B model: 140 GB

// ═══ 4. ACTIVATIONS ═══
// Per layer, per token, per micro-batch:
// Attention: Q, K, V, attention scores, output = ~4Sd + S^2 per head
// FFN: input, gate, up, down = ~3d_ff + d per token
// With activation checkpointing: only store input to each layer
// Without checkpointing: store all intermediate activations

// Full activations (no checkpointing), per layer:
A_layer ≈ S × B × (10d + 2S×h) × 2 bytes    (rough estimate)
// For S=4096, B=4, d=8192, h=64:
A_layer ≈ 4096 × 4 × (81920 + 524288) × 2 = 19.9 GB per layer
// 80 layers: ~1,590 GB, which is impossible without checkpointing!

// With activation checkpointing (recompute forward in backward):
// Only store layer inputs: S × B × d × 2 bytes per layer
A_ckpt = L × S × B × d × 2
// = 80 × 4096 × 4 × 8192 × 2 = 21.5 GB
// Cost: ~33% more compute (recompute forward per layer in backward)

// ═══ TOTAL TRAINING MEMORY ═══
Total = Weights + Optimizer + Gradients + Activations
// 70B model, BF16, AdamW, ckpt, B=4, S=4096:
= 140 + 840 + 140 + 21.5 = 1141.5 GB

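// On 8 H100 GPUs (80 GB each = 640 GB total): does not fit!
// Need: FSDP/ZeRO-3 to shard weights, gradients, and optimizer states across GPUs
// With ZeRO-3 across 16 GPUs: 1141.5 / 16 = 71.3 GB per GPU ✓
// (This treats the 21.5 GB of activations as split 16 ways too. ZeRO-3 does not
// shard activations, so if B=4 is the per-GPU micro-batch, budget
// 1120/16 + 21.5 = 91.5 GB per GPU and use more GPUs or a smaller micro-batch.)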
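The four memory categories can be tallied with a small helper. This is a rough sketch under the assumptions used in this section (BF16 weights and gradients, AdamW with FP32 master weights, activation checkpointing that keeps only layer inputs); the names are illustrative.

```python
def training_memory_gb(n_params: float, n_layers: int, seq_len: int,
                       micro_batch: int, d_model: int) -> dict:
    """Approximate training memory by category, in GB (1 GB = 1e9 bytes)."""
    gb = 1e9
    mem = {
        "weights_bf16": n_params * 2 / gb,
        "optimizer_adamw": n_params * 12 / gb,   # m, v, and FP32 master weights
        "gradients_bf16": n_params * 2 / gb,
        # With checkpointing, only each layer's BF16 input is kept.
        "activations_ckpt": n_layers * seq_len * micro_batch * d_model * 2 / gb,
    }
    mem["total"] = sum(mem.values())
    return mem


mem_70b = training_memory_gb(n_params=70e9, n_layers=80, seq_len=4096,
                             micro_batch=4, d_model=8192)
for name, value in mem_70b.items():
    print(f"{name}: {value:.1f} GB")
```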

KV Cache: The Inference Memory Bottleneck

During inference, the model generates one token at a time. For each new token, it needs to attend to all previous tokens. Re-computing K and V for all previous tokens at every step would be wasteful, so we cache them. This KV cache is often the largest memory consumer at inference time.

// KV cache size per token, per layer:
// K: [n_kv_heads, d_head] = h_kv × d_head values
// V: [n_kv_heads, d_head] = h_kv × d_head values
// Per layer: 2 × h_kv × d_head values

// Total KV cache:
KV = 2 × L × h_kv × d_head × S × B × bytes_per_value

// For LLaMA 3 70B (GQA: h_kv=8, d_head=128, L=80):
// FP16: bytes_per_value = 2
KV = 2 × 80 × 8 × 128 × S × B × 2
   = 327,680 × S × B bytes

// Example scenarios:
// S=4096, B=1: 327,680 × 4096 × 1 = 1.34 GB
// S=4096, B=32: 327,680 × 4096 × 32 = 42.9 GB
// S=32768, B=32: 327,680 × 32768 × 32 = 343 GB!

// With FP8 KV cache (saves 2x):
// S=4096, B=32: 21.5 GB
// S=32768, B=32: 171.8 GB

// With GQA (8 KV heads instead of 64 MHA heads):
// Already factored in above — GQA gives 8x KV cache savings vs MHA!
// Without GQA (h_kv=64): S=4096, B=32 would be 343 GB, not 42.9 GB.
// This is a primary motivation for GQA in modern LLMs.
GQA is a memory optimization, not a quality optimization. Grouped-Query Attention reduces the number of KV heads from h to h_kv (typically h/4 to h/8). This saves KV cache memory proportionally. The quality impact is small: the Llama 2 paper's ablations found GQA close to the MHA baseline, and the 70B model shipped with 8 KV heads, cutting the KV cache by 8x. At 32K context with batch=32, that is the difference between a ~344 GB cache that fits on one 8×H100 node alongside the weights and a ~2.7 TB cache that needs several nodes.
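The KV cache formula is easy to wrap in a helper for capacity planning. The sketch below uses the LLaMA 70B shape from this section as an example; the function name is illustrative, and it ignores allocator overhead and fragmentation.

```python
def kv_cache_bytes(n_layers: int, n_kv_heads: int, d_head: int,
                   seq_len: int, batch: int, bytes_per_value: int = 2) -> int:
    """Total KV cache size: K and V, for every layer, KV head, position, and sequence."""
    return 2 * n_layers * n_kv_heads * d_head * seq_len * batch * bytes_per_value


gqa = kv_cache_bytes(n_layers=80, n_kv_heads=8, d_head=128, seq_len=4096, batch=32)
mha = kv_cache_bytes(n_layers=80, n_kv_heads=64, d_head=128, seq_len=4096, batch=32)
print(f"GQA: {gqa / 1e9:.1f} GB, MHA: {mha / 1e9:.1f} GB")  # GQA: 42.9 GB, MHA: 343.6 GB
```

Passing `bytes_per_value=1` models an FP8 or INT8 cache.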

MFU: How to Measure Training Efficiency

Model FLOPs Utilization (MFU) is the fraction of the GPU's peak compute that is actually used for model computation:

// Definition:
MFU = (Model FLOPs per step) / (Peak FLOPs per step)

// Model FLOPs per step = 6 × N × B × S
// (forward + backward, for a global batch of B sequences of length S)

// Peak FLOPs per step = n_GPUs × Peak_TFLOPS × step_time

// Example: LLaMA 70B training
// Global batch = 1024 sequences × 4096 tokens = 4.19M tokens/step
// Model FLOPs/step = 6 × 70e9 × 4.19e6 = 1.76e18 FLOPs
// On 2048 H100s, step time = 28 seconds (hypothetical):
// Peak FLOPs/step = 2048 × 989e12 × 28 = 5.67e19 FLOPs
// MFU = 1.76e18 / 5.67e19 = 0.031 = 3.1%

// Wait, 3.1%? That seems terrible. Let's check:
// The step time should be much shorter. Let's work backwards.
// Target MFU = 0.45:
// step_time = Model_FLOPs / (n_GPUs × Peak × MFU)
// = 1.76e18 / (2048 × 989e12 × 0.45)
// = 1.76e18 / 9.11e17 = 1.93 seconds per step

// So at MFU=0.45, each step takes ~2 seconds.
// Total training time for 15T tokens:
// Steps = 15e12 / 4.19e6 = 3.58e6 steps
// Time = 3.58e6 × 1.93 = 6.91e6 seconds = 80 days
System | Reported MFU | Notes
PaLM (Google, 2022) | 46.2% | 540B on 6144 TPUv4, data + model parallelism without pipeline parallelism
LLaMA (Meta, 2023) | ~38% | 65B on 2048 A100, FSDP
GPT-4 (estimated) | ~45-50% | Reported to use ~25K A100s, exact MFU unknown
DeepSeek-V3 (2024) | ~52% | 671B MoE on H800, FP8 training, expert parallelism
MFU vs HFU. Some papers report "Hardware FLOPs Utilization" (HFU), which includes the extra FLOPs from activation recomputation (gradient checkpointing). HFU is at least as high as MFU, and strictly higher whenever recomputation is used, because it counts recomputed FLOPs as useful work. MFU is the stricter metric: it counts only FLOPs that contribute to model training. When reading papers, always check which metric is reported.
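Both directions of the MFU calculation (measuring it from a step time, or predicting a step time from a target) fit in a few lines. This sketch assumes the 6·N·tokens approximation and the 989 TFLOPS peak used in this chapter; the helper names are hypothetical.

```python
def mfu(n_params: float, tokens_per_step: float, n_gpus: int,
        step_time_s: float, peak_flops: float = 989e12) -> float:
    """Model FLOPs Utilization from a measured step time."""
    model_flops = 6 * n_params * tokens_per_step
    return model_flops / (n_gpus * peak_flops * step_time_s)


def step_time_for_mfu(n_params: float, tokens_per_step: float, n_gpus: int,
                      target_mfu: float, peak_flops: float = 989e12) -> float:
    """Step time (seconds) implied by a target MFU."""
    model_flops = 6 * n_params * tokens_per_step
    return model_flops / (n_gpus * peak_flops * target_mfu)


tokens = 1024 * 4096  # global batch of 1024 sequences × 4096 tokens
measured = mfu(70e9, tokens, n_gpus=2048, step_time_s=28.0)
target_step = step_time_for_mfu(70e9, tokens, n_gpus=2048, target_mfu=0.45)
print(f"MFU at 28 s/step: {measured:.1%}; step time at 45% MFU: {target_step:.2f} s")
```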

Chinchilla Scaling Laws: How Much Data for How Many Parameters?

In 2022, the Chinchilla paper (Hoffmann et al.) established that for a given compute budget C, there is an optimal split between model size N and training data D. Training a model that is too large on too little data wastes compute; training a small model on excessive data also wastes compute.

// The Chinchilla loss model:
L(N, D) = a × N^(-α) + b × D^(-β) + c

// Where:
// a, b, c, α, β are fitted constants from small-scale experiments
// Chinchilla paper values: α ≈ 0.34, β ≈ 0.28
// (Later work by Llama team found slightly different values)

// Given compute budget C = 6ND:
// Substitute D = C/(6N) and minimize L(N, C/(6N)) w.r.t. N:
// dL/dN = -aα N^(-α-1) + bβ (C/6)^(-β) N^(β-1) = 0

// Solving:
N_opt ∝ C^(β/(α+β))
D_opt ∝ C^(α/(α+β))

// With α=0.34, β=0.28:
N_opt ∝ C^0.452
D_opt ∝ C^0.548

// The key ratio:
D_opt/N_opt ≈ 20    (Chinchilla: ~20 tokens per parameter)

// Example: compute budget of 10^24 FLOPs
// C = 6ND, D = 20N:
// 10^24 = 6 × N × 20N = 120N^2
// N = sqrt(10^24/120) = 9.13 × 10^10 ≈ 91B parameters
// D = 20 × 91.3B = 1.83T tokens

// Modern practice (LLaMA 3, 2024): D/N ≈ 200
// The Chinchilla ratio of 20 is compute-optimal for a FIXED budget.
// But if inference cost dominates (you will serve the model millions
// of times), it is cheaper to train a SMALLER model on MORE data.
// LLaMA 3 8B was trained on 15T tokens: D/N = 1875!
Chinchilla-optimal vs inference-optimal. Chinchilla tells you the cheapest way to reach a given loss during training. But training is a one-time cost; inference is ongoing. If an 8B model trained on 15T tokens achieves the same loss as a 70B model trained on 2T tokens, the 8B model is vastly cheaper to serve (8.75x fewer FLOPs per token, 8.75x less memory). This is why Meta trained LLaMA 3 far past the Chinchilla-optimal point: they optimized for inference cost, not training cost. The staff-level insight: the right scaling law to optimize depends on the ratio of training cost to lifetime inference cost.
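For whiteboard estimates, the fixed-ratio shortcut (C = 6ND with D = r·N) is usually enough. The sketch below solves it for N and D; the default ratio of 20 is the Chinchilla rule of thumb, the larger ratio is only an illustrative over-training setting, and `chinchilla_optimal` is a hypothetical helper name.

```python
import math


def chinchilla_optimal(compute_flops: float, tokens_per_param: float = 20.0):
    """Solve C = 6 * N * D with D = tokens_per_param * N for (N, D)."""
    n_params = math.sqrt(compute_flops / (6 * tokens_per_param))
    n_tokens = tokens_per_param * n_params
    return n_params, n_tokens


n_opt, d_opt = chinchilla_optimal(1e24)
print(f"Chinchilla ratio: N = {n_opt / 1e9:.1f}B params, D = {d_opt / 1e12:.2f}T tokens")

# Same budget, deliberately over-trained for cheaper inference (illustrative ratio).
n_small, d_large = chinchilla_optimal(1e24, tokens_per_param=200.0)
print(f"Ratio 200:        N = {n_small / 1e9:.1f}B params, D = {d_large / 1e12:.2f}T tokens")
```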

Putting It All Together: LLaMA 3 70B Worksheet

Quantity | Formula | Value
Parameters | 80 × (2.25d^2 + 3d·d_ff) + 2Vd | ~69B
Weights (BF16) | 69B × 2 | 138 GB
Optimizer (AdamW) | 69B × 12 | 828 GB
Gradients (BF16) | 69B × 2 | 138 GB
Training memory total | Sum + activations | ~1.1-1.3 TB
Min GPUs (training) | ~1200 GB / 80 GB | 16 H100s (ZeRO-3)
Forward FLOPs/token | 2P | 138 GFLOPs
Training FLOPs (15T tok) | 6 × 69B × 15T | 6.2 × 10^24
Training time (2048 H100) | C / (GPUs × FLOPS × MFU) | ~80 days at MFU=0.45
KV cache (B=32, S=4096) | 2×80×8×128×4096×32×2 | 42.9 GB
Inference memory (TP=4) | (138 + 42.9) / 4 | ~45 GB/GPU
Scaling Book Exercises — These map directly to Chapter 4 of the scaling book. Do them on paper first.
Exercise 3.1: Parameter Count

GPT-2 Small: 12 layers, d_model=768, 12 attention heads, d_head=64, FFN inner dim=3072. Ignore embeddings and LayerNorm. How many parameters (in millions)?

Show derivation
Per layer: Attention = 4 × d² = 4 × 768² = 2,359,296
Per layer: FFN = 2 × d × 4d = 2 × 768 × 3072 = 4,718,592
Per layer total = 7,077,888
12 layers = 84,934,656 ≈ 84.9M

(Adding embeddings: vocab_size × d = 50257 × 768 = 38.6M, shared with output head. Full GPT-2 Small ≈ 124M.)

Exercise 3.2: KV Cache Memory

LLaMA 3 70B: 80 layers, GQA with 8 KV heads, d_head=128, batch=32, sequence length=8192, BF16 (2 bytes). How many GB for the KV cache?

Show derivation
KV cache = 2 × n_layers × n_kv_heads × d_head × seq_len × batch × bytes
= 2 × 80 × 8 × 128 × 8192 × 32 × 2
= 85,899,345,920 bytes ≈ 85.9 GB

The factor of 2 at the start is for K and V tensors. With GQA (8 KV heads instead of 64 query heads), the cache is 8x smaller than it would be with MHA.

Exercise 3.3: Training FLOPs

You are training a 7B parameter model on 2T tokens. Using the 6ND approximation, how many total FLOPs (in units of 10^22)?

Show derivation
C = 6 × N × D = 6 × 7×10^9 × 2×10^12 = 8.4 × 10^22

At 50% MFU on 1024 H100s (each 989 TFLOPS BF16): effective = 0.5 × 1024 × 989e12 = 5.06e17 FLOPS. Time = 8.4e22 / 5.06e17 = 166,008 seconds ≈ 1.92 days.

Staff-level question: You are planning a new model training run. Budget: 10^24 FLOPs. The research lead wants a 400B dense model. Using the Chinchilla scaling law (D/N ≈ 20), how many tokens can you train this model on? Is this compute-optimal? What would you recommend instead, and why?

Chapter 4: JAX Fundamentals

At several frontier labs, the training infrastructure is built on JAX rather than PyTorch. Why? Three reasons: (1) JAX compiles your Python code through XLA, which generates fused, optimized kernels for TPUs and GPUs automatically. (2) JAX's functional design (no hidden state, no in-place mutation) makes distributed training with model parallelism radically simpler, because the compiler can shard pure functions across devices. (3) JAX's composable transforms (jit, grad, vmap) let you write research code that is simultaneously readable AND production-fast.

If you are coming from PyTorch, JAX will feel strange at first. PyTorch is imperative: you build a computation graph by executing Python statements, and the graph is traced eagerly. JAX is functional: you write pure functions, and JAX transforms (traces, compiles, differentiates) them. The payoff for this discipline is enormous — but the learning curve is real.

jax.numpy: Tracing, Not Executing

The first surprise: inside a jitted function, jax.numpy operations do not execute when Python reaches them. When you call a function decorated with @jax.jit, JAX traces the function: it runs your Python code with abstract "tracer" values instead of real data. The tracers record the operations, building a computation graph. Then XLA compiles this graph into optimized machine code. (Outside jit, operations dispatch one at a time, like NumPy.)

// What actually happens when you call a jitted function:

// 1. TRACING: JAX replaces your input arrays with "tracers"
// Tracers carry shape and dtype, but NO values
// Your Python code runs once with these tracers
// Every jnp operation records itself into a trace

// 2. LOWERING: The trace is converted to StableHLO IR
// (High-Level Operations — a hardware-agnostic graph format)

// 3. COMPILATION: XLA compiles StableHLO to target hardware code
// On GPU: generates CUDA PTX → cubin
// On TPU: generates TPU ISA machine code
// Includes operator fusion, buffer reuse, layout optimization

// 4. EXECUTION: Compiled code runs on the accelerator
// First call is slow (compile), subsequent calls are fast
python
import jax
import jax.numpy as jnp

# ── Eager mode (no jit) ──
# Operations execute immediately, like NumPy.
# Useful for debugging, but slow (no fusion, no optimization).
x = jnp.ones((3, 3))
y = x + 1       # Dispatches immediately to GPU/TPU
z = y * 2       # Another dispatch — TWO kernel launches
print(z)        # [[4. 4. 4.], [4. 4. 4.], [4. 4. 4.]]

# ── JIT mode ──
# JAX traces the function, compiles it, then runs the compiled version.
@jax.jit
def f(x):
    y = x + 1   # NOT executed — recorded as a trace operation
    z = y * 2   # NOT executed — recorded
    return z    # The trace is: z = (x + 1) * 2

# First call: trace + compile + execute (slow)
result = f(x)   # XLA fuses add+mul into ONE kernel launch!

# Second call with same shape/dtype: just execute (fast, no recompile)
result2 = f(jnp.zeros((3, 3)))

# Third call with DIFFERENT shape: retrace + recompile!
result3 = f(jnp.zeros((4, 4)))  # New compilation
The tracing model is the key to everything. JAX traces your function once per input signature (shapes, dtypes, and static arguments) to build a graph, then compiles the graph. This means: (1) A Python if or while on a traced value raises a concretization error under jit, because a tracer has no concrete value to test. Python control flow on static quantities (shapes, or arguments marked with static_argnums) is resolved at trace time, so only the branch taken is compiled. Use jax.lax.cond for value-dependent branching. (2) Python side effects (print statements, list appends, global variable mutations) happen at trace time only, not at execution time. (3) Functions are recompiled when input shapes change. Design your data pipeline to produce fixed-shape batches to avoid constant recompilation.
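A minimal sketch of points (1) to (3), assuming a recent JAX install. The function name flip_or_double and the trace_log list are illustrative, not part of any library. The list append is a deliberate side effect used to count traces.

```python
import jax
import jax.numpy as jnp

trace_log = []  # filled by a Python side effect, so only while tracing

@jax.jit
def flip_or_double(x):
    trace_log.append(x.shape)  # runs once per trace, not once per call
    # Value-dependent branch: both branches are compiled, one is picked at run time.
    return jax.lax.cond(
        x.sum() > 0,
        lambda v: v * 2.0,  # taken when the sum is positive
        lambda v: -v,       # taken otherwise
        x,
    )

a = flip_or_double(jnp.array([1.0, 2.0]))    # traces for shape (2,)
b = flip_or_double(jnp.array([-1.0, -2.0]))  # same shape: cached, other branch runs
c = flip_or_double(jnp.ones(3))              # new shape: traces again
print(a, b, c, trace_log)
```

Three calls produce two entries in trace_log, one per distinct input shape, while the branch taken changes between the first two calls without any retrace.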

Pure Functions: Why No Side Effects Matter

JAX requires your functions to be pure: given the same inputs, they must always return the same outputs, and they must not modify any external state. This is not just a style preference — it is a fundamental requirement for correctness when JAX transforms your code.

python
# ── BAD: impure function ──
counter = [0]  # Mutable external state

@jax.jit
def bad_f(x):
    counter[0] += 1   # Side effect! Modifies external state
    print("traced")   # Side effect! Only prints during tracing
    return x + counter[0]

# Call 1: traces, prints "traced", counter becomes 1, returns x+1
# Call 2: DOES NOT retrace (same shape), DOES NOT print,
#         counter is STILL 1 (the value was captured at trace time),
#         returns x+1 (not x+2!)
# This is a silent correctness bug.

# ── GOOD: pure function ──
@jax.jit
def good_f(x, counter):
    return x + counter, counter + 1  # Return new counter, don't mutate

# State is threaded through as an explicit argument and return value.
# No hidden mutation, no trace-time capture bugs.
result, new_counter = good_f(x, jnp.array(0))
result2, new_counter2 = good_f(x, new_counter)

Why does purity matter beyond avoiding bugs? Because the XLA compiler assumes pure functions to perform optimizations. It can reorder operations, fuse kernels, and even eliminate dead code — all of which are only correct if the function has no side effects. The compiler also needs purity to shard the function across multiple devices: if a function modifies global state, you cannot safely run it on 2048 TPUs simultaneously.

jax.grad: Automatic Differentiation

jax.grad computes the gradient of a scalar-valued function with respect to its first argument (by default). It uses reverse-mode automatic differentiation (backpropagation), which computes all parameter gradients in a single backward pass.

python
import jax
import jax.numpy as jnp

# ── Basic gradient ──
def loss_fn(w, x, y):
    """MSE loss: L = mean((w @ x - y)^2)"""
    pred = w @ x
    return jnp.mean((pred - y) ** 2)

# grad takes the gradient w.r.t. the FIRST argument (w)
grad_fn = jax.grad(loss_fn)

w = jnp.ones((3, 4))
x = jnp.ones((4, 2))
y = jnp.zeros((3, 2))

dL_dw = grad_fn(w, x, y)  # Shape: (3, 4) — same as w
print(dL_dw.shape)  # (3, 4)

# ── Gradient w.r.t. multiple arguments ──
grad_fn_multi = jax.grad(loss_fn, argnums=(0, 1))  # grad w.r.t. w AND x
dL_dw, dL_dx = grad_fn_multi(w, x, y)

# ── Value and gradient together ──
# In training, you need both the loss value and its gradient.
# Computing them separately wastes work (shared forward pass).
loss_val, dL_dw = jax.value_and_grad(loss_fn)(w, x, y)
print(f"Loss: {loss_val:.4f}, Grad norm: {jnp.linalg.norm(dL_dw):.4f}")

Forward Mode vs Reverse Mode: When Each Shines

// Consider f: R^n → R^m (n inputs, m outputs)

// REVERSE MODE (jax.grad, backpropagation):
// Computes df/dx for ALL n inputs in ONE backward pass
// Cost: O(1) backward passes regardless of n
// BUT: only works when m=1 (scalar output) for plain grad
// For m > 1: need one backward pass per output = O(m) passes
// Best when: n >> m (many params, one loss) = TRAINING

// FORWARD MODE (jax.jvp = Jacobian-vector product):
// Computes df/dx for ONE input direction in ONE forward pass
// Cost: O(n) forward passes to get full Jacobian
// BUT: each pass costs about the same as the original function
// Best when: n << m (few params, many outputs) or n=1

// In practice for training:
// n = number of parameters = billions
// m = 1 (scalar loss)
// Reverse mode (jax.grad) wins by a factor of billions.
python
# ── Forward mode: Jacobian-Vector Product (JVP) ──
# Computes f(x) and J @ v simultaneously (J = Jacobian, v = tangent vector)
def f(x):
    return jnp.array([x[0]**2 + x[1], x[0] * x[1]])

x = jnp.array([3.0, 4.0])
v = jnp.array([1.0, 0.0])  # Tangent vector: derivative in x[0] direction

primals, tangents = jax.jvp(f, (x,), (v,))
# primals = f(x) = [9+4, 12] = [13, 12]
# tangents = J @ v = [2*x[0]*1 + 0, x[1]*1 + 0] = [6, 4]

# ── Reverse mode: Vector-Jacobian Product (VJP) ──
# Computes f(x) and a function that maps cotangent to gradient
primals, vjp_fn = jax.vjp(f, x)
# primals = f(x) = [13, 12]

# vjp_fn takes a cotangent vector (same shape as output)
# and returns the gradient w.r.t. the input
cotangent = jnp.array([1.0, 0.0])  # grad of output[0] only
(grad_x,) = vjp_fn(cotangent)
# grad_x = [2*x[0], 1] = [6, 1]  (df[0]/dx[0], df[0]/dx[1])
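When the full Jacobian is needed, JAX offers both modes as ready-made transforms. The sketch below reuses the same f and x and checks that jax.jacfwd and jax.jacrev agree; which one is cheaper depends on whether the function has fewer inputs or fewer outputs.

```python
import jax
import jax.numpy as jnp

def f(x):
    return jnp.array([x[0] ** 2 + x[1], x[0] * x[1]])

x = jnp.array([3.0, 4.0])

J_fwd = jax.jacfwd(f)(x)  # assembled column by column from JVPs (one per input)
J_rev = jax.jacrev(f)(x)  # assembled row by row from VJPs (one per output)

print(J_fwd)  # rows are outputs, columns are inputs
print(J_rev)
```

For this 2-by-2 case both cost the same; with billions of inputs and one output, only the reverse-mode row is affordable.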

jax.vmap: Auto-Vectorization

jax.vmap automatically vectorizes a function that operates on a single example to work on a batch of examples. Instead of writing explicit batch loops or reshaping tensors, you write your function for one example and vmap handles batching.

python
# ── Without vmap: manual batching ──
def predict_single(params, x):
    """Forward pass for a single input x of shape (d,)."""
    for W, b in params:
        x = jax.nn.relu(W @ x + b)
    return x

# To handle a batch, you'd either:
# (a) Loop: [predict_single(params, x_i) for x_i in batch]  ← SLOW
# (b) Rewrite: change W @ x to x_batch @ W.T + b  ← manual work

# ── With vmap: automatic batching ──
predict_batch = jax.vmap(predict_single, in_axes=(None, 0))
# in_axes=(None, 0) means:
#   params: NOT batched (same params for all examples)
#   x: batched along axis 0 (first dim is batch)

# Now predict_batch(params, x_batch) works on a batch
# where x_batch has shape (batch_size, d)
# The output has shape (batch_size, output_d)

# ── vmap + grad: per-example gradients ──
# This is one of JAX's superpowers: get the gradient for EACH
# example in the batch, not just the average gradient.
# Useful for: differential privacy, per-example clipping, debugging.

def loss_single(params, x, y):
    pred = predict_single(params, x)
    return jnp.sum((pred - y) ** 2)

# Per-example gradient function:
per_example_grad = jax.vmap(jax.grad(loss_single), in_axes=(None, 0, 0))
# per_example_grad(params, x_batch, y_batch) returns a pytree
# with the same structure as params, but each leaf has an extra
# batch dimension: shape (batch_size, *param_shape)
How vmap works internally. vmap does NOT create a Python loop. It adds a "batch dimension" tracer to the input, traces the function once, and rewrites each operation into its batched equivalent, which XLA then compiles into vectorized kernels. The overhead is one trace (at compile time), and the runtime is typically on par with hand-written batched code. This is why vmap + jit is production-quality: the compiler sees the full vectorized computation and can optimize it globally.
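A small check of that claim, as a sketch: a single-example affine layer is batched with vmap and compared against the hand-batched formula. The names affine_single and loss_single are illustrative.

```python
import jax
import jax.numpy as jnp

def affine_single(W, b, x):
    """One example: x has shape (d_in,)."""
    return W @ x + b

key_w, key_x = jax.random.split(jax.random.PRNGKey(0))
W = jax.random.normal(key_w, (4, 3))
b = jnp.zeros(4)
x_batch = jax.random.normal(key_x, (5, 3))

# vmap over the data axis only; W and b are shared across the batch.
batched = jax.vmap(affine_single, in_axes=(None, None, 0))(W, b, x_batch)
by_hand = x_batch @ W.T + b  # the manual rewrite vmap saves you from

def loss_single(W, b, x):
    return jnp.sum(affine_single(W, b, x) ** 2)

# Per-example gradients: one (4, 3) gradient for each of the 5 examples.
per_example_dW = jax.vmap(jax.grad(loss_single), in_axes=(None, None, 0))(W, b, x_batch)
print(batched.shape, per_example_dW.shape)
```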

Composing Transforms: The Power Move

JAX's transforms (jit, grad, vmap) compose freely. You can stack them in any order, and JAX will trace through all of them to produce a single optimized program.

python
# ── The classic training step ──
# This is the pattern you'll see in every JAX training loop.

@jax.jit                    # 3. Compile the whole thing
def train_step(params, x_batch, y_batch):
    def loss_fn(params):
        # vmap handles the batch dimension automatically
        preds = jax.vmap(predict_single, in_axes=(None, 0))(params, x_batch)  # 1. vmap for batching
        return jnp.mean((preds - y_batch) ** 2)

    loss, grads = jax.value_and_grad(loss_fn)(params)  # 2. grad for differentiation
    # Simple SGD update
    new_params = jax.tree.map(lambda p, g: p - 0.01 * g, params, grads)
    return new_params, loss

# What happens under the hood:
# 1. jit traces train_step
# 2. Inside, it encounters vmap → adds batch tracing
# 3. Inside, it encounters grad → adds differentiation tracing
# 4. The entire thing (forward + backward + update) compiles
#    to a SINGLE XLA program with fused kernels
# 5. On subsequent calls: no Python overhead, just the compiled kernel
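One way to see the single program that composed transforms produce is jax.make_jaxpr, which prints the traced intermediate representation. The sketch below is illustrative: loss_single is a made-up toy loss, and the printed jaxpr format varies between JAX versions.

```python
import jax
import jax.numpy as jnp

def loss_single(w, x):
    return jnp.sum((w * x) ** 2)

# grad inside vmap: per-example gradients with respect to w.
per_example_grad = jax.vmap(jax.grad(loss_single), in_axes=(None, 0))

w = jnp.ones(3)
x_batch = jnp.arange(6.0).reshape(2, 3)

jaxpr = jax.make_jaxpr(per_example_grad)(w, x_batch)
print(jaxpr)  # one flat list of primitives covering forward and backward

grads = jax.jit(per_example_grad)(w, x_batch)  # jit compiles that same trace
print(grads)
```

The analytic gradient here is 2 * w * x**2, so each row of grads is twice the squared input row.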

PyTrees: The Tree-of-Arrays Abstraction

Neural network parameters are not a single array — they are a nested structure of arrays (a dict of dicts of arrays, a list of tuples, etc.). JAX calls these nested structures PyTrees and provides utilities to work with them uniformly.

python
# ── Model parameters as a PyTree ──
params = {
    'layer1': {'W': jnp.ones((128, 64)), 'b': jnp.zeros(64)},
    'layer2': {'W': jnp.ones((64, 32)),  'b': jnp.zeros(32)},
    'head':   {'W': jnp.ones((32, 10)),  'b': jnp.zeros(10)},
}
# This is a PyTree: a nested dict of jax arrays.

# ── jax.tree.map: apply a function to every leaf ──
# SGD update: params = params - lr * grads
# grads has the SAME structure as params (jax.grad guarantees this).
# Placeholder gradients so this snippet runs on its own:
grads = jax.tree.map(jnp.ones_like, params)
new_params = jax.tree.map(lambda p, g: p - 0.01 * g, params, grads)
# This walks both trees in parallel, applying the lambda to each pair.
# No manual iteration over layers!

# ── Counting parameters ──
n_params = sum(p.size for p in jax.tree.leaves(params))
print(f"Total params: {n_params:,}")  # 10,666

# ── Flatten and unflatten ──
flat_params, tree_def = jax.tree.flatten(params)
# flat_params is a list of arrays. Dict keys are flattened in sorted order:
# [head.W, head.b, layer1.W, layer1.b, layer2.W, layer2.b]
# tree_def remembers the structure for reconstruction
restored = jax.tree.unflatten(tree_def, flat_params)
# restored == params (same structure and values)

# ── Custom PyTree nodes ──
# You can register your own classes as PyTree nodes.
# This is how Flax/Equinox/Haiku make their module objects
# work seamlessly with jax.grad, jit, vmap, etc.
import dataclasses

@dataclasses.dataclass
class Linear:
    W: jax.Array
    b: jax.Array

    def __call__(self, x):
        return self.W @ x + self.b

# Register with JAX so jit/grad/vmap know how to traverse it:
jax.tree_util.register_dataclass(Linear, data_fields=['W', 'b'], meta_fields=[])
Why PyTrees matter for distributed training. When you shard a model across devices with jax.jit and jax.sharding (the successor to jax.experimental.pjit), you specify a sharding for each leaf in the parameter PyTree. JAX walks the tree, assigns each array to the right devices, and the compiler inserts the necessary communication (AllGather, AllReduce) between devices. Without the PyTree abstraction, you would have to manually manage the sharding of every parameter array, an error-prone nightmare for a model with thousands of them.
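A minimal sketch of per-leaf sharding, assuming a JAX version that provides jax.sharding.NamedSharding and jax.tree. The rule in spec_for and the axis name "model" are hypothetical choices for illustration, and the mesh is restricted to one device so the example runs on a laptop.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

params = {
    "layer1": {"W": jnp.ones((128, 64)), "b": jnp.zeros(64)},
    "head":   {"W": jnp.ones((64, 10)),  "b": jnp.zeros(10)},
}

# One-device mesh so the sketch runs anywhere; a real job would use all devices.
mesh = Mesh(np.array(jax.devices()[:1]), axis_names=("model",))

def spec_for(leaf):
    # Hypothetical rule: split matrices along their last axis, replicate vectors.
    return P(None, "model") if leaf.ndim == 2 else P()

# Build a tree of shardings with the same structure as the parameter tree.
shardings = jax.tree.map(lambda leaf: NamedSharding(mesh, spec_for(leaf)), params)

# device_put accepts that tree and places every leaf accordingly.
sharded_params = jax.device_put(params, shardings)
print(jax.tree.map(lambda s: s.spec, shardings))
```

The point of the design is that the sharding plan is itself a PyTree, so it can be produced by the same tree utilities that walk the parameters.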

Random Numbers: The Split Model

JAX's random number generation is different from NumPy's, and for good reason. NumPy uses a global, stateful PRNG: np.random.seed(42) followed by np.random.randn() advances a global counter. This is fundamentally incompatible with JAX's functional paradigm (no hidden state) and with parallel execution (what happens when two TPUs both try to advance the same counter?).

python
# ── JAX PRNG: explicit, functional, splittable ──

# Create an initial key. This is the ONLY source of randomness, and it is
# an ordinary value you pass around, not hidden global state.
key = jax.random.PRNGKey(42)
# key is a pair of uint32 values
print(key)  # [ 0 42]

# WRONG: reusing the same key gives the SAME random numbers!
a = jax.random.normal(key, (3,))
b = jax.random.normal(key, (3,))
# a == b! This is intentional: same key = same result (purity)

# RIGHT: split the key to get new, independent subkeys
key, subkey1, subkey2 = jax.random.split(key, 3)
a = jax.random.normal(subkey1, (3,))
b = jax.random.normal(subkey2, (3,))
# a != b, and both are deterministic given the original key

# ── Pattern for training loops ──
def train_step(params, batch, key):
    # Split key for this step's randomness
    key, dropout_key, noise_key = jax.random.split(key, 3)

    def loss_fn(params):
        # Use dropout_key for dropout mask
        logits = model_forward(params, batch, dropout_key)
        return cross_entropy_loss(logits, batch['labels'])

    loss, grads = jax.value_and_grad(loss_fn)(params)
    params = update_params(params, grads)
    return params, loss, key  # Return the updated key!

# The key is threaded through the loop as explicit state.
# Every step gets a unique, deterministic key.
# The entire training run is perfectly reproducible.

# ── Why the split model matters for parallelism ──
# On 2048 TPUs, each device needs independent random numbers.
# With the split model, you split the key into 2048 subkeys
# at the start, one per device. No coordination needed.
keys = jax.random.split(jax.random.PRNGKey(42), 2048)
# Each device gets keys[device_id] — deterministic and independent.
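An alternative to threading a split key through every step is jax.random.fold_in, which derives a key from a base key plus an integer. The sketch below is illustrative; key_for is a hypothetical helper, not a JAX function.

```python
import jax

base_key = jax.random.PRNGKey(42)

def key_for(step, device_id):
    """Derive a key from (step, device) without carrying state. Hypothetical helper."""
    step_key = jax.random.fold_in(base_key, step)
    return jax.random.fold_in(step_key, device_id)

a = jax.random.normal(key_for(0, 0), (3,))
b = jax.random.normal(key_for(0, 1), (3,))        # different device, different stream
a_again = jax.random.normal(key_for(0, 0), (3,))  # same inputs, same numbers
print(a, b, a_again)
```

This is convenient after a restart: the key for step 10,000 can be recomputed directly instead of replaying 10,000 splits.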

A Complete JAX Training Loop

Let us put it all together with a complete, production-style training loop for a small MLP. This is the pattern you will see (with more complexity) in frontier lab training codebases.

python
import jax
import jax.numpy as jnp
from typing import NamedTuple

# ── Model definition (pure functions + parameter PyTree) ──

class MLPParams(NamedTuple):
    """Parameters for a 2-layer MLP."""
    W1: jax.Array   # [d_in, d_hidden]
    b1: jax.Array   # [d_hidden]
    W2: jax.Array   # [d_hidden, d_out]
    b2: jax.Array   # [d_out]

def init_params(key, d_in, d_hidden, d_out):
    """Xavier initialization. Key is split for each layer."""
    k1, k2 = jax.random.split(key)
    scale1 = jnp.sqrt(2.0 / (d_in + d_hidden))
    scale2 = jnp.sqrt(2.0 / (d_hidden + d_out))
    return MLPParams(
        W1=scale1 * jax.random.normal(k1, (d_in, d_hidden)),
        b1=jnp.zeros(d_hidden),
        W2=scale2 * jax.random.normal(k2, (d_hidden, d_out)),
        b2=jnp.zeros(d_out),
    )

def forward(params: MLPParams, x):
    """Forward pass for a single example x of shape (d_in,)."""
    h = jax.nn.silu(x @ params.W1 + params.b1)  # SiLU activation (like LLaMA)
    return h @ params.W2 + params.b2

# Batched forward: vmap over the data dimension, not over params
forward_batch = jax.vmap(forward, in_axes=(None, 0))

# ── Loss function ──
def loss_fn(params, x_batch, y_batch):
    """Cross-entropy loss for classification."""
    logits = forward_batch(params, x_batch)  # [B, n_classes]
    # Numerically stable log-softmax
    log_probs = jax.nn.log_softmax(logits, axis=-1)
    # Select the log-prob of the correct class
    loss = -jnp.mean(log_probs[jnp.arange(y_batch.shape[0]), y_batch])
    return loss

# ── Optimizer state (AdamW, simplified) ──
class AdamState(NamedTuple):
    m: MLPParams    # First moment (same tree structure as params)
    v: MLPParams    # Second moment
    step: jax.Array # Step counter for bias correction

def init_adam(params):
    m = jax.tree.map(jnp.zeros_like, params)
    v = jax.tree.map(jnp.zeros_like, params)
    return AdamState(m=m, v=v, step=jnp.array(0))

def adam_update(params, grads, state, lr=1e-3, beta1=0.9, beta2=0.999, wd=0.01, eps=1e-8):
    step = state.step + 1
    # Update moments
    new_m = jax.tree.map(lambda m, g: beta1 * m + (1 - beta1) * g, state.m, grads)
    new_v = jax.tree.map(lambda v, g: beta2 * v + (1 - beta2) * g**2, state.v, grads)
    # Bias correction
    bc1 = 1 - beta1 ** step
    bc2 = 1 - beta2 ** step
    m_hat = jax.tree.map(lambda m: m / bc1, new_m)
    v_hat = jax.tree.map(lambda v: v / bc2, new_v)
    # AdamW: weight decay applied to params directly, not through gradient
    new_params = jax.tree.map(
        lambda p, m, v: p * (1 - lr * wd) - lr * m / (jnp.sqrt(v) + eps),
        params, m_hat, v_hat
    )
    return new_params, AdamState(m=new_m, v=new_v, step=step)

# ── Training step (compiled) ──
@jax.jit
def train_step(params, opt_state, x_batch, y_batch):
    loss, grads = jax.value_and_grad(loss_fn)(params, x_batch, y_batch)
    # Gradient clipping (max norm = 1.0)
    grad_norm = jnp.sqrt(sum(
        jnp.sum(g**2) for g in jax.tree.leaves(grads)
    ))
    clip_coeff = jnp.minimum(1.0, 1.0 / (grad_norm + 1e-6))
    grads = jax.tree.map(lambda g: g * clip_coeff, grads)
    # Update params
    new_params, new_opt_state = adam_update(params, grads, opt_state)
    return new_params, new_opt_state, loss, grad_norm

# ── Training loop ──
key = jax.random.PRNGKey(42)
key, init_key = jax.random.split(key)
params = init_params(init_key, d_in=784, d_hidden=256, d_out=10)
opt_state = init_adam(params)

for step in range(1000):
    key, x_key, y_key = jax.random.split(key, 3)
    # (In real code, load from data pipeline)
    # Separate subkeys for inputs and labels: never reuse one key twice.
    x_batch = jax.random.normal(x_key, (64, 784))
    y_batch = jax.random.randint(y_key, (64,), 0, 10)

    params, opt_state, loss, grad_norm = train_step(
        params, opt_state, x_batch, y_batch
    )
    if step % 100 == 0:
        print(f"Step {step}: loss={loss:.4f}, grad_norm={grad_norm:.4f}")
Production patterns to note. (1) The entire train_step is @jax.jit compiled — including the optimizer update and gradient clipping. This means the GPU runs the full step without returning to Python. (2) Gradient clipping is inside the jitted function, not outside — this avoids a device-to-host transfer of the gradient norm. (3) The key is split at each step and threaded through explicitly. (4) value_and_grad avoids computing the forward pass twice. These patterns matter at scale: on 2048 GPUs, Python overhead between steps can dominate training time if the step is fast enough.

Common JAX Gotchas

Gotcha | What happens | Fix
Value-dependent shapes | x[x > 0] produces different-sized outputs for different inputs. jit cannot handle dynamic shapes. | Use jnp.where(x > 0, x, 0) (fixed shape) or jax.lax.dynamic_slice
In-place mutation | x[0] = 5 raises an error. JAX arrays are immutable. | Use x = x.at[0].set(5) (returns a new array)
Print in jit | print(x) prints a tracer object, not a value. | Use jax.debug.print("{x}", x=x) for runtime printing
Recompilation storm | Passing inputs with a new shape (or a new static argument value) to jit triggers a recompilation each time. | Pad or bucket inputs to a small set of fixed shapes; pass changing scalars as arrays rather than as static arguments
NaN in grad | jnp.where(cond, safe_val, jnp.log(x)): the gradient of log(x) is computed even when x ≤ 0. | Use jnp.where(cond, safe_val, jnp.log(jnp.maximum(x, 1e-8)))
PRNG reuse | Using the same key twice gives identical "random" numbers. | Always jax.random.split before each use
The NaN gradient trap deserves special attention. In JAX (as in most AD systems), jnp.where(cond, a, b) propagates gradients through BOTH branches, even the one not selected. If b = jnp.log(x) and x = 0, the local derivative 1/x is infinite, even though you "selected" branch a. The gradient of where is grad_a * cond + grad_b * (1-cond), so the unselected branch contributes inf * 0 = NaN. The fix: make sure both branches have safe gradients everywhere, not just safe values.
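The trap is easy to reproduce. The sketch below uses the "double where" variant of the fix (sanitize the input before the risky operation), which is an alternative to the jnp.maximum form in the table; the function names are illustrative.

```python
import jax
import jax.numpy as jnp

def unsafe_log(x):
    # The value is fine at x = 0, but the gradient still flows through log(0).
    return jnp.where(x > 0, jnp.log(x), 0.0)

def safe_log(x):
    # "Double where": replace bad inputs first so the unselected branch stays finite.
    safe_x = jnp.where(x > 0, x, 1.0)
    return jnp.where(x > 0, jnp.log(safe_x), 0.0)

g_unsafe = jax.grad(unsafe_log)(0.0)  # NaN
g_safe = jax.grad(safe_log)(0.0)      # 0.0
g_pos = jax.grad(safe_log)(2.0)       # 1/x = 0.5, unchanged where log is valid
print(g_unsafe, g_safe, g_pos)
```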
Staff-level question: You are reviewing a colleague's JAX training code. They define train_step(params, batch, step_num) and jit it with static_argnames="step_num", where step_num is a Python integer that increases each step. They use it for learning rate scheduling: lr = 0.001 * min(1.0, step_num / 1000). The training is mysteriously slow (each step takes 5 seconds instead of the expected 0.5 seconds). What is the bug, and how do you fix it?

Chapter 5: Quantization — The Full Landscape

It is 10:15 AM. The inference team Slack lights up: the new 70B model does not fit on a single 80 GB A100 in FP16. At 140 GB of weights, it overflows. Two options: add a second GPU (double the cost, plus communication latency from tensor parallelism), or quantize. Shrink the weights from 16 bits to 4 bits, and the 35 GB model fits with room to spare for the KV cache. But will the quality hold?

This is not a theoretical exercise. Every frontier lab ships quantized models. The question is how many bits and which method. Get it wrong and your chatbot starts hallucinating. Get it right and you serve 4x more users on the same hardware.
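The sizing arithmetic behind the scenario is short enough to script. This is a back-of-the-envelope sketch: weight_gb is an illustrative helper, and it counts weights only, ignoring quantization scales, the KV cache, and activations.

```python
def weight_gb(n_params, bits_per_weight):
    """Weight memory in GB (10^9 bytes), weights only."""
    return n_params * bits_per_weight / 8 / 1e9

GPU_GB = 80  # one A100 80GB

for bits in (16, 8, 4, 2):
    gb = weight_gb(70e9, bits)
    print(f"{bits:>2}-bit: {gb:6.1f} GB  fits on one 80 GB GPU: {gb < GPU_GB}")
```

At 8 bits the 70 GB of weights technically fit, but leave little headroom for the KV cache, which is why 4-bit is the usual target for single-GPU serving in this scenario.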

Why Memory Bandwidth, Not Compute

Here is the single most important insight in LLM inference: autoregressive decoding is memory-bandwidth-bound, not compute-bound. During generation, each token requires reading the entire weight matrix once (to compute one output vector), but the arithmetic intensity is tiny — you multiply a matrix by a single vector, not a matrix by a matrix.

Let us make this concrete. An A100 GPU has:

Resource | A100 80GB SXM | H100 80GB SXM
FP16 Compute | 312 TFLOPS | 989 TFLOPS
Memory Bandwidth | 2.0 TB/s | 3.35 TB/s
Arithmetic Intensity Crossover | 156 FLOPs/byte | 295 FLOPs/byte

For a matrix-vector multiply W · x where W is (d, d), x is (d, 1): you read d² · 2 bytes of weights (FP16) and perform 2d² FLOPs. The arithmetic intensity is:

Arithmetic Intensity = 2d² FLOPs ÷ (2d² bytes) = 1 FLOP/byte

One. One FLOP per byte. The A100 needs 156 to be compute-bound. You are 156x below the threshold. The GPU spends virtually all its time waiting for weights to arrive from HBM. This means: if you halve the weight size, you nearly double generation speed. Not because you do less math — because you move less data.
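The same calculation extends to batched decoding, where the weights are read once but reused for every sequence in the batch. The sketch below is a simplified model (FP16 weights and activations, no caching effects), and matmul_intensity is an illustrative helper.

```python
def matmul_intensity(d, batch, bytes_per_weight=2.0, bytes_per_act=2.0):
    """FLOPs per byte for Y = W @ X with W of shape (d, d) and X of shape (d, batch)."""
    flops = 2 * d * d * batch
    # Bytes moved: the weights once, plus the input and output activations.
    bytes_moved = d * d * bytes_per_weight + 2 * d * batch * bytes_per_act
    return flops / bytes_moved

A100_CROSSOVER = 312e12 / 2.0e12  # 156 FLOPs/byte, from the table above

for batch in (1, 16, 64, 256):
    ai = matmul_intensity(4096, batch)
    bound = "compute" if ai >= A100_CROSSOVER else "memory"
    print(f"batch={batch:>3}: {ai:6.1f} FLOPs/byte -> {bound}-bound")
```

Under this model the intensity grows roughly linearly with batch size, so a d = 4096 layer on an A100 only crosses into compute-bound territory somewhere between batch 64 and batch 256.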

The bandwidth wall. This is why quantization matters so much for LLM inference. Cutting from FP16 (16 bits) to INT4 (4 bits) means 4x less data to move from memory. On an A100, that is the difference between generating 30 tokens/second and 120 tokens/second for a 7B model. The compute is basically free — the bottleneck is the memory bus.
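A rough ceiling on single-stream decode speed follows directly from bandwidth divided by weight bytes. The sketch below computes that upper bound only; measured throughput sits below it because of KV-cache reads, kernel launch overhead, and dequantization cost. decode_ceiling_tok_per_s is an illustrative helper.

```python
def decode_ceiling_tok_per_s(n_params, bits_per_weight, bandwidth_bytes_per_s):
    """Upper bound on decode speed: every token reads all weights exactly once."""
    weight_bytes = n_params * bits_per_weight / 8
    return bandwidth_bytes_per_s / weight_bytes

A100_BW = 2.0e12  # bytes/s, from the table above

fp16 = decode_ceiling_tok_per_s(7e9, 16, A100_BW)
int4 = decode_ceiling_tok_per_s(7e9, 4, A100_BW)
print(f"FP16 ceiling: {fp16:.0f} tok/s, INT4 ceiling: {int4:.0f} tok/s, "
      f"ratio {int4 / fp16:.1f}x")
```

The absolute numbers are ceilings, but the 4x ratio between them is the part that tends to carry over to real systems.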

The Basics: How Quantization Works

Quantization maps continuous (or high-precision) values to a smaller set of discrete levels. An FP16 weight has 65,536 possible bit patterns. An INT4 weight has only 16. The art is choosing which 16 values best represent the original distribution.

Symmetric quantization maps the range $[-\alpha, \alpha]$ to the integer range $[-2^{b-1},\ 2^{b-1} - 1]$, using a single scale factor:

$x_q = \mathrm{round}(x / s)$     where   $s = \max(|x|) / (2^{b-1} - 1)$

Dequantization is simply $\hat{x} = x_q \cdot s$. This is fast because there is no zero-point offset, but wasteful when the distribution is asymmetric (e.g., mostly positive activations after ReLU).
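A minimal NumPy sketch of the symmetric scheme, assuming per-tensor scaling and signed integers stored in int8 (4-bit values simply use a narrower range). The function names are illustrative.

```python
import numpy as np

def quantize_symmetric(x, bits=8):
    qmax = 2 ** (bits - 1) - 1  # 127 for INT8, 7 for INT4
    scale = np.max(np.abs(x)) / qmax
    q = np.clip(np.round(x / scale), -qmax - 1, qmax).astype(np.int8)
    return q, scale

def dequantize_symmetric(q, scale):
    return q.astype(np.float32) * scale

x = np.array([-1.0, 0.4, 0.05, 0.9], dtype=np.float32)

q8, s8 = quantize_symmetric(x, bits=8)
q4, s4 = quantize_symmetric(x, bits=4)

err8 = np.max(np.abs(x - dequantize_symmetric(q8, s8)))
err4 = np.max(np.abs(x - dequantize_symmetric(q4, s4)))
print(q8, q4, err8, err4)
```

The worst-case rounding error is half a step (s/2), so dropping from 8 to 4 bits widens the step, and the error bound, by roughly 18x for the same range.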

Asymmetric quantization adds a zero-point z to handle shifted distributions:

$x_q = \mathrm{round}(x / s) + z$     where   $s = (\max(x) - \min(x)) / (2^b - 1)$,    $z = \mathrm{round}(-\min(x) / s)$

Dequantization: $\hat{x} = (x_q - z) \cdot s$. More accurate for skewed distributions, but the extra subtraction costs a few cycles per element.

The worked example. Suppose you have weights [-0.8, 0.3, 0.5, -0.1, 0.7] and want INT4 (4-bit, range 0..15 unsigned asymmetric). The min is -0.8, max is 0.7. Scale s = (0.7 - (-0.8)) / 15 = 0.1. Zero-point z = round(0.8 / 0.1) = 8. Quantized: [-0.8 → round(-0.8/0.1)+8 = 0, 0.3 → round(0.3/0.1)+8 = 11, 0.5 → 13, -0.1 → 7, 0.7 → 15]. Dequantized: [(0-8)·0.1 = -0.8, (11-8)·0.1 = 0.3, (13-8)·0.1 = 0.5, (7-8)·0.1 = -0.1, (15-8)·0.1 = 0.7]. Perfect reconstruction in this case, but only because every weight happens to be an exact multiple of s. For real weights the rounding error is up to s/2 per element, and it grows as the bit width shrinks.
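The worked example can be reproduced in a few lines of NumPy. This is a sketch of the asymmetric formulas above with illustrative function names.

```python
import numpy as np

def quantize_asymmetric(x, bits=4):
    qmax = 2 ** bits - 1  # 15 for 4-bit unsigned
    scale = (x.max() - x.min()) / qmax
    zero_point = int(round(-x.min() / scale))
    q = np.clip(np.round(x / scale) + zero_point, 0, qmax).astype(np.uint8)
    return q, scale, zero_point

def dequantize_asymmetric(q, scale, zero_point):
    return (q.astype(np.float64) - zero_point) * scale

w = np.array([-0.8, 0.3, 0.5, -0.1, 0.7])
q, s, z = quantize_asymmetric(w)
w_hat = dequantize_asymmetric(q, s, z)
print(q, s, z, w_hat)
```

Changing any weight to a value that is not a multiple of 0.1 (say 0.33) is a quick way to see a nonzero reconstruction error appear.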

Granularity: Per-Tensor vs Per-Channel vs Per-Group

The choice of how many weights share a single (s, z) pair dramatically affects accuracy:

Granularity | What shares (s, z) | Overhead | Accuracy | When to use
Per-tensor | All weights in the entire matrix | 2 values total | Worst: outliers ruin scale for everyone | Almost never for weights
Per-channel | One row (output channel) | 2 values per row | Good: each neuron gets its own range | Standard for INT8
Per-group | g consecutive weights (g=32, 64, 128) | 2 values per group | Best: fine-grained adaptation | Standard for INT4 and below

Per-group quantization with group size g = 128 means that for a weight matrix of shape (4096, 4096), you store 4096 · (4096/128) = 131,072 scale/zero-point pairs. At FP16 each, that is 0.5 MB of overhead on top of 8 MB of INT4 weights (the FP16 original is 32 MB), or 0.25 extra bits per weight. Tiny price for dramatically better accuracy.
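To see why granularity matters, the sketch below quantizes a weight vector containing a single outlier, once with one scale for the whole tensor and once with one scale per group of 128. It uses symmetric quantization for brevity, and the synthetic weights and helper name are assumptions for illustration. The last lines reproduce the overhead arithmetic for the (4096, 4096) example.

```python
import numpy as np

def fake_quant_symmetric(w, bits, group_size=None):
    """Quantize then dequantize; one scale per group (or per tensor if None)."""
    qmax = 2 ** (bits - 1) - 1
    groups = w.reshape(1, -1) if group_size is None else w.reshape(-1, group_size)
    scale = np.abs(groups).max(axis=1, keepdims=True) / qmax
    q = np.clip(np.round(groups / scale), -qmax - 1, qmax)
    return (q * scale).reshape(w.shape)

rng = np.random.default_rng(0)
w = rng.normal(scale=0.02, size=1024)
w[0] = 1.0  # a single outlier stretches the per-tensor scale

mse_tensor = np.mean((w - fake_quant_symmetric(w, bits=4)) ** 2)
mse_group = np.mean((w - fake_quant_symmetric(w, bits=4, group_size=128)) ** 2)
print(f"per-tensor MSE: {mse_tensor:.2e}, per-group MSE: {mse_group:.2e}")

# Storage overhead: one FP16 (s, z) pair per group of 128 weights.
n_pairs = 4096 * (4096 // 128)
overhead_bytes = n_pairs * 2 * 2
extra_bits_per_weight = overhead_bytes * 8 / (4096 * 4096)
print(n_pairs, overhead_bytes, extra_bits_per_weight)
```

With per-group scales the outlier only degrades its own group of 128 weights instead of the entire tensor.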

LLM.int8(): The Outlier Discovery

In 2022, Tim Dettmers and colleagues made a pivotal observation: transformer activations have emergent outlier features. A tiny fraction of hidden dimensions (as few as 6 out of 4096) have magnitudes 10-100x larger than the rest. These outliers appear in every layer, in the same feature dimensions, and they are critical for model quality. Quantize them to INT8 and the model breaks. Keep them in FP16, quantize the rest to INT8, and the model is fine.

LLM.int8() implements mixed-precision decomposition:

1. Detect Outliers
Scan each input activation vector. Any feature dimension with magnitude > 6.0 is flagged as an outlier. Typically 0.1-1% of dimensions.
2. Decompose
Split the matrix multiply: $X W^T = X_{\text{outlier}} W_{\text{outlier}}^T + X_{\text{normal}} W_{\text{normal}}^T$. The outlier part uses the original FP16 weights. The normal part uses INT8.
3. Compute
The INT8 matmul runs on tensor cores (fast). The FP16 matmul is tiny (few columns). Sum the results in FP16.
python
import torch

def llm_int8_matmul(X, W, threshold=6.0):
    """Mixed-precision decomposition from Dettmers et al. 2022."""
    # Step 1: find outlier feature dimensions
    outlier_mask = X.abs().max(dim=0).values > threshold  # shape: (d_in,)
    normal_mask = ~outlier_mask

    # Step 2: decompose into outlier and normal parts
    X_outlier = X[:, outlier_mask].half()   # (batch, d_outlier) in FP16
    W_outlier = W[outlier_mask, :].half()   # (d_outlier, d_out)  in FP16

    X_normal = X[:, normal_mask]            # (batch, d_normal)
    W_normal = W[normal_mask, :]            # (d_normal, d_out)

    # Step 3: quantize the normal part to INT8
    sx = X_normal.abs().max(dim=1, keepdim=True).values / 127
    sw = W_normal.abs().max(dim=0, keepdim=True).values / 127
    X_int8 = (X_normal / sx).round().to(torch.int8)
    W_int8 = (W_normal / sw).round().to(torch.int8)

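    # Step 4: compute both parts
    # The INT8 product is emulated in float here for clarity; a real kernel
    # runs it on INT8 tensor cores with INT32 accumulation.
    out_normal = (X_int8.float() @ W_int8.float()) * (sx * sw)  # dequantize
    out_outlier = X_outlier @ W_outlier                          # FP16 matmul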

    return (out_normal + out_outlier).half()
Why 6.0? Dettmers tested thresholds from 1.0 to 10.0. Below 6.0, too many features are treated as outliers (slow FP16 path dominates). Above 8.0, actual outliers get crushed by INT8 quantization. 6.0 hits the sweet spot: captures the problematic features while keeping 99%+ in the fast INT8 path. This threshold is remarkably consistent across model sizes from 1B to 175B.

The Chris De Sa Quantization Series: QuIP, QuIP#, QTIP

While LLM.int8() proved mixed-precision works at 8 bits, a group at Cornell led by Chris De Sa asked a bolder question: can we go to 2 bits per weight without destroying the model? The answer required ideas from information theory and lattice coding.

QuIP (2023) introduced two key ideas:

1. Incoherence processing. The problem with naive quantization is that some weight columns have much larger norms than others. If you apply the same quantization grid to all columns, the high-norm columns suffer. QuIP multiplies the weight matrix by random orthogonal matrices $U$ and $V$: $\tilde{W} = U W V^T$. This "spreads" the information evenly across all entries (makes the matrix incoherent), so uniform quantization works better. You store $U$ and $V$ as random seeds (cheap) and undo the rotation during inference.

2. LDLQ (adaptive rounding, named for the $LDL^T$ decomposition it relies on). Instead of rounding each weight independently (which ignores correlations), LDLQ uses the layer Hessian $H = 2X^TX$ to quantize weights sequentially, adjusting the weights not yet quantized to compensate for the error already committed. The Hessian tells you which weight errors are cheap (dimensions the model rarely uses) and which are expensive (dimensions on which loss is sensitive). LDLQ performs the $LDL^T$ decomposition of the Hessian and uses it to guide a greedy per-column rounding scheme similar to GPTQ.

$\min_{\hat{W}} \ (W - \hat{W})^T H (W - \hat{W})$    subject to   $\hat{W}_{ij} \in Q$

This is a weighted quantization error where $H = 2X^TX$ is the Hessian of the layer's squared error. LDLQ solves it approximately via the $LDL^T$ factorization of $H$, processing one column at a time with adaptive rounding.
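The incoherence idea from item 1 can be illustrated numerically. The sketch below is not the QuIP implementation: it draws dense orthogonal matrices from a QR decomposition, uses a synthetic weight matrix with one planted outlier, and measures "spikiness" with an illustrative ratio defined in the incoherence helper.

```python
import numpy as np

def random_orthogonal(n, rng):
    """Orthogonal matrix from the QR decomposition of a Gaussian matrix."""
    q, _ = np.linalg.qr(rng.normal(size=(n, n)))
    return q

def incoherence(W):
    """max|W_ij| * sqrt(m*n) / ||W||_F: 1 is perfectly flat, larger is spikier."""
    m, n = W.shape
    return np.abs(W).max() * np.sqrt(m * n) / np.linalg.norm(W)

rng = np.random.default_rng(0)
W = rng.normal(scale=0.02, size=(64, 64))
W[3, 5] = 2.0  # one dominant outlier weight

U, V = random_orthogonal(64, rng), random_orthogonal(64, rng)
W_rot = U @ W @ V.T      # incoherence processing
W_back = U.T @ W_rot @ V  # undoing the rotation recovers W

mu_before, mu_after = incoherence(W), incoherence(W_rot)
print(f"before: {mu_before:.1f}, after: {mu_after:.1f}")
```

The rotation is lossless (W_back matches W up to floating-point error); it only changes the basis in which rounding happens, so the outlier's energy is spread across many entries before the grid is applied.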

QuIP# (2024) made two improvements that turned QuIP from a proof of concept into a practical method:

1. Hadamard rotations instead of random orthogonal matrices. A dense random orthogonal matrix multiply costs O(d²) per vector. A Hadamard transform costs O(d log d) and, combined with random sign flips, achieves the same incoherence property. This is because every entry of a normalized Hadamard matrix has the same magnitude (1/√d), so no coordinate is favored. QuIP# replaces the expensive random rotations with fast Walsh-Hadamard transforms, a large speedup at inference time.

2. E8 lattice codebook. Instead of rounding to a uniform integer grid, QuIP# rounds groups of 8 weights to the nearest point in the E8 lattice, the densest sphere packing in 8 dimensions. Why does this matter? The quantization error of a codebook depends on how efficiently it tiles space. A uniform grid in 8D wastes space in the corners of each hypercube. The E8 lattice packs more representable points into the same volume, reducing the average distance from any real weight vector to its nearest codebook entry. At 2 bits per weight, QuIP# with E8 lattice achieves perplexity close to what older methods achieve at 3 bits.

python
# Conceptual: E8 lattice quantization in 8D
# The E8 lattice consists of all vectors in R^8 whose coordinates
# are either all integers or all half-integers, and whose sum is even.

import numpy as np

def nearest_e8(x):
    """Find the nearest E8 lattice point to vector x in R^8."""
    # Two candidates: round to integers, or round to half-integers
    c1 = np.round(x)                           # nearest integers
    if np.sum(c1) % 2 != 0:                     # enforce even sum
        idx = np.argmax(np.abs(x - c1))         # re-round the entry with the largest rounding error
        c1[idx] += 1 if x[idx] > c1[idx] else -1

    c2 = np.round(x - 0.5) + 0.5              # nearest half-integers
    if np.sum(c2) % 2 != 0:
        idx = np.argmax(np.abs(x - c2))         # same rule: cheapest entry to push the other way
        c2[idx] += 1 if x[idx] > c2[idx] else -1

    # Pick the closer candidate
    d1 = np.sum((x - c1)**2)
    d2 = np.sum((x - c2)**2)
    return c1 if d1 <= d2 else c2
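Returning to the first QuIP# improvement, the O(d log d) cost comes from the butterfly structure of the fast Walsh-Hadamard transform. The sketch below is a plain NumPy illustration, not the QuIP# kernel: it checks the fast version against a dense Sylvester-construction Hadamard matrix and shows how a spiky vector is flattened after random sign flips.

```python
import numpy as np

def fwht(x):
    """Orthonormal fast Walsh-Hadamard transform; len(x) must be a power of two."""
    x = np.array(x, dtype=np.float64)
    n = x.shape[0]
    h = 1
    while h < n:  # log2(n) stages of n/2 butterflies each
        for start in range(0, n, 2 * h):
            a = x[start:start + h].copy()
            b = x[start + h:start + 2 * h].copy()
            x[start:start + h] = a + b
            x[start + h:start + 2 * h] = a - b
        h *= 2
    return x / np.sqrt(n)

def hadamard_matrix(n):
    """Dense Sylvester construction, used here only to check the fast version."""
    H = np.array([[1.0]])
    while H.shape[0] < n:
        H = np.block([[H, H], [H, -H]])
    return H / np.sqrt(n)

rng = np.random.default_rng(0)
d = 16
v = rng.normal(size=d)
signs = rng.choice([-1.0, 1.0], size=d)  # random sign flips (the "randomized" part)

x = np.zeros(d)
x[2] = 4.0  # all the energy in one coordinate
y = fwht(signs * x)
print(y)  # every entry has magnitude 4 / sqrt(16) = 1
```

Because the normalized Hadamard matrix is symmetric and orthogonal, applying fwht twice returns the input, so undoing the rotation at inference time reuses the same routine.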

QTIP (2024) pushed even further. Instead of the E8 lattice (which is optimal in 8D but doesn't generalize easily), QTIP uses trellis codes — a technique borrowed from telecommunications. A trellis code defines a state machine where each transition corresponds to a quantized symbol. The Viterbi algorithm finds the sequence of quantized weights that minimizes total error while respecting the code constraints. This is analogous to how your cell phone decodes noisy signals, but applied to weight quantization.

The key advantage of trellis codes: they naturally capture sequential dependencies between adjacent weights. If weight[i] was rounded up, the trellis can compensate by rounding weight[i+1] down. This error diffusion is automatic and optimal (in the Viterbi sense).
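To make the Viterbi search concrete, here is a toy trellis-coded quantizer in the classic textbook style: 8 reconstruction levels split into 4 subsets, and a 4-state trellis in which each state may only use 2 of the subsets. It spends 2 bits per weight (1 bit for the branch, 1 bit for the level inside the subset) while drawing on an 8-level alphabet. The trellis labeling is a hypothetical choice for illustration and is not the construction used in QTIP.

```python
import itertools
import numpy as np

LEVELS = np.linspace(-1.75, 1.75, 8)        # 8 reconstruction levels
SUBSETS = [LEVELS[k::4] for k in range(4)]  # D0..D3, two levels each
# TRELLIS[state][bit] = (next_state, subset_id)
TRELLIS = [
    [(0, 0), (1, 2)],
    [(2, 1), (3, 3)],
    [(0, 2), (1, 0)],
    [(2, 3), (3, 1)],
]

def branch_cost(target, subset_id):
    """Best level in the subset for this weight, and its squared error."""
    levels = SUBSETS[subset_id]
    best = levels[np.argmin((levels - target) ** 2)]
    return best, (best - target) ** 2

def viterbi_quantize(weights, start_state=0):
    n_states = len(TRELLIS)
    cost = np.full(n_states, np.inf)
    cost[start_state] = 0.0
    paths = [[] for _ in range(n_states)]  # surviving reconstruction per state
    for w in weights:
        new_cost = np.full(n_states, np.inf)
        new_paths = [None] * n_states
        for s in range(n_states):
            if np.isinf(cost[s]):
                continue  # state not reachable yet
            for nxt, subset_id in TRELLIS[s]:
                level, err = branch_cost(w, subset_id)
                if cost[s] + err < new_cost[nxt]:  # keep the cheaper survivor
                    new_cost[nxt] = cost[s] + err
                    new_paths[nxt] = paths[s] + [level]
        cost, paths = new_cost, new_paths
    best = int(np.argmin(cost))
    return np.array(paths[best]), float(cost[best])

def brute_force(weights, start_state=0):
    """Try every branch sequence; only feasible for very short inputs."""
    best = np.inf
    for bits in itertools.product((0, 1), repeat=len(weights)):
        state, total = start_state, 0.0
        for w, bit in zip(weights, bits):
            state, subset_id = TRELLIS[state][bit]
            total += branch_cost(w, subset_id)[1]
        best = min(best, total)
    return best

rng = np.random.default_rng(0)
weights = rng.uniform(-1.5, 1.5, size=8)
w_hat, tcq_err = viterbi_quantize(weights)
print(w_hat, tcq_err)
```

Viterbi keeps only the cheapest path into each state at every step, which is why it finds the same optimum as the exhaustive search over all 2^8 branch sequences at a cost linear in the sequence length.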

AQLM: Additive Codebook Quantization

AQLM (2024) takes a completely different approach: multi-codebook additive quantization. Instead of quantizing each weight independently, AQLM groups weights into vectors of size g (typically 8) and represents each vector as the sum of entries from M learned codebooks:

ŵ = c1[i1] + c2[i2] + ... + cM[iM]

Each codebook has 2^B entries (e.g., B=8 gives 256 entries per codebook). With M=2 codebooks, the total number of representable vectors is 256 × 256 = 65,536, far more than the 16 levels of straight INT4. The indices (i1, i2) are stored as integers. The codebooks are learned on calibration data, with beam search used to pick the index combination for each weight group.

The effective bit rate is (M · B) / g bits per weight. With M=2, B=8, g=8: (2 · 8) / 8 = 2 bits per weight. But the representational power is vastly greater than naive 2-bit quantization because the codebook entries are learned from the actual weight distribution.
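
The bit-rate arithmetic above can be checked in a few lines. This is a minimal sketch: the helper names are hypothetical, and the codebook overhead term assumes FP16 codebook entries amortized over a single 4096×4096 matrix, which is an illustrative choice rather than something stated in the text.

```python
def aqlm_bits_per_weight(num_codebooks, bits_per_index, group_size):
    """Index storage per weight: M indices of B bits shared by g weights."""
    return num_codebooks * bits_per_index / group_size


def codebook_overhead_bits(num_codebooks, bits_per_index, group_size,
                           n_weights, entry_bits=16):
    """Amortized cost of storing the codebooks themselves (assumed FP16 entries)."""
    stored_values = num_codebooks * (2 ** bits_per_index) * group_size
    return stored_values * entry_bits / n_weights


rate = aqlm_bits_per_weight(num_codebooks=2, bits_per_index=8, group_size=8)
n_vectors = (2 ** 8) ** 2                      # distinct sums of two codebook entries
overhead = codebook_overhead_bits(2, 8, 8, n_weights=4096 * 4096)

print(f"index bits per weight:    {rate}")
print(f"representable vectors:    {n_vectors}")
print(f"codebook bits per weight: {overhead:.4f}")
```

Under these assumptions the codebooks add well under 0.01 bits per weight, which is why the index term dominates the quoted rate.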

python
# AQLM dequantization (simplified)
import torch

def aqlm_dequant(indices, codebooks, group_size=8):
    """
    indices: (n_groups, M) integer tensor of codebook indices
    codebooks: list of M tensors, each (2^B, group_size)
    Returns: (n_groups * group_size,) dequantized weights
    """
    M = len(codebooks)
    n_groups = indices.shape[0]
    # Accumulate in the codebooks' dtype so the in-place add below does not mix dtypes
    result = torch.zeros(n_groups, group_size, dtype=codebooks[0].dtype)
    for m in range(M):
        result += codebooks[m][indices[:, m]]  # additive: sum across codebooks
    return result.reshape(-1)

The Progression: 8-bit to 2-bit

Each paper in the quantization lineage solved a specific limitation of its predecessor:

| Method | Year | Bits | Key Idea | Llama-2 70B PPL (WikiText) | Speed vs FP16 |
| --- | --- | --- | --- | --- | --- |
| FP16 baseline | | 16 | | 3.32 | 1.0x |
| LLM.int8() | 2022 | 8 (mixed) | Outlier decomposition | 3.33 | ~1.0x (overhead from decomposition) |
| GPTQ | 2022 | 4 | OBS-based layer-wise quant | 3.85 | ~3.2x |
| AWQ | 2023 | 4 | Activation-aware scaling | 3.62 | ~3.4x |
| QuIP | 2023 | 2-4 | Incoherence + LDLQ | 5.40 (2-bit) | ~1.5x (slow rotations) |
| QuIP# | 2024 | 2 | Hadamard + E8 lattice | 4.16 | ~3.0x |
| AQLM | 2024 | 2 | Multi-codebook additive | 4.21 | ~2.5x (lookup overhead) |
| QTIP | 2024 | 2 | Trellis codes + Viterbi | 3.94 | ~2.8x |

The trend line. In 2022, 8-bit quantization was considered aggressive. By 2024, 2-bit quantization matches what 4-bit achieved just two years earlier. The theoretical floor for lossless compression of transformer weights is estimated at ~1.5 bits per weight (based on the intrinsic dimensionality of the weight space). We are approaching it.

[Interactive figure: Weight Distribution & Quantization Grid. Shows a transformer weight distribution with outlier features (red spikes) and how the quantization grid resolves it as bit width, outlier magnitude, and group size are adjusted. Outliers distort the grid at low bit widths.]
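
The effect the figure illustrates can be reproduced numerically. The sketch below assumes one group of 128 Gaussian weights and a single injected outlier at 8x the largest ordinary weight (mirroring the figure's default outlier magnitude). The helper names are hypothetical. It measures the error on the ordinary weights only, so the number reflects how much the outlier stretches the grid for everyone else.

```python
import numpy as np

def quantize_dequantize(w, bits):
    """Asymmetric min/max uniform quantization followed by dequantization."""
    q_max = 2 ** bits - 1
    lo, hi = w.min(), w.max()
    scale = max((hi - lo) / q_max, 1e-12)
    q = np.clip(np.round((w - lo) / scale), 0, q_max)
    return q * scale + lo

def rmse(a, b):
    return float(np.sqrt(np.mean((a - b) ** 2)))

rng = np.random.default_rng(0)
bulk = rng.normal(0.0, 0.02, size=128)          # one group of ordinary weights
with_outlier = bulk.copy()
with_outlier[0] = 8.0 * np.abs(bulk).max()      # assumed outlier magnitude: 8x

for bits in (8, 4, 3, 2):
    # Index 0 (the outlier slot) is excluded from both error measurements
    clean = rmse(bulk[1:], quantize_dequantize(bulk, bits)[1:])
    distorted = rmse(with_outlier[1:], quantize_dequantize(with_outlier, bits)[1:])
    print(f"{bits}-bit  RMSE without outlier: {clean:.5f}   with outlier: {distorted:.5f}")
```

Smaller groups limit the damage because an outlier then only stretches the grid of the few weights that share its group.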

Practical Decision Tree

When a new model lands on your desk and you need to deploy it, here is the decision process a frontier lab engineer follows:

Does FP16 fit in VRAM?
Model size in bytes = params × 2 (FP16). A 70B model = 140 GB. Add ~30% overhead for KV cache and activations. If total < GPU memory, skip quantization.
↓ No
Is latency critical?
If yes: AWQ or GPTQ at 4-bit with per-group quantization (g=128). Fast kernels exist (exllama, Marlin). If throughput matters more than latency: FP8 with per-tensor scales (H100 native).
↓ Need fewer bits
How much quality loss can you tolerate?
< 0.5 PPL increase: QTIP or QuiP# at 3-bit. < 1.0 PPL increase: QTIP or AQLM at 2-bit. For evals that matter (your specific benchmark), always compare against FP16 — perplexity alone does not capture task-specific degradation.
↓ Need more quality
Quantization-aware training (QAT)
If post-training quantization is not enough, fine-tune with quantization in the loop (STE gradients through the rounding operation). 10-100x more expensive but recovers most lost quality. This is what Apple, Meta, and Google do for on-device models.

Implementation: Quantizing a Weight Matrix

python
import torch
import torch.nn.functional as F

def quantize_per_group(W, bits=4, group_size=128):
    """
    Per-group asymmetric quantization.
    W: (out_features, in_features) weight tensor
    Returns: quantized weights, scales, zero_points
    """
    out_f, in_f = W.shape
    assert in_f % group_size == 0, "in_features must be divisible by group_size"
    n_groups = in_f // group_size

    # Reshape to (out_features * n_groups, group_size)
    W_grouped = W.reshape(out_f * n_groups, group_size)

    # Compute per-group min/max
    w_min = W_grouped.min(dim=1).values  # (out_f * n_groups,)
    w_max = W_grouped.max(dim=1).values

    # Compute scale and zero-point
    q_max = 2 ** bits - 1
    scale = (w_max - w_min) / q_max
    scale = torch.clamp(scale, min=1e-8)  # avoid division by zero
    # uint8 (not int8) so that bits=8, where zero points reach 255, does not overflow
    zero_point = torch.round(-w_min / scale).clamp(0, q_max).to(torch.uint8)

    # Quantize
    W_q = torch.round(W_grouped / scale.unsqueeze(1)) + zero_point.unsqueeze(1)
    W_q = W_q.clamp(0, q_max).to(torch.uint8)

    # Verify: dequantize and check error
    W_deq = (W_q.float() - zero_point.unsqueeze(1).float()) * scale.unsqueeze(1)
    error = (W_grouped - W_deq).pow(2).mean().sqrt()
    print(f"RMSE: {error:.6f}, relative: {error / W_grouped.abs().mean():.4f}")

    return W_q.reshape(out_f, in_f), scale.reshape(out_f, n_groups), zero_point.reshape(out_f, n_groups)

# Example: quantize a 4096x4096 matrix to 4-bit with group_size=128
W = torch.randn(4096, 4096) * 0.02  # typical init scale
# Inject outlier features (simulating the LLM.int8() discovery)
W[:, 42] *= 15   # feature 42 is an outlier
W[:, 1337] *= 12 # feature 1337 is an outlier
W_q, scales, zps = quantize_per_group(W, bits=4, group_size=128)
# RMSE: ~0.0029, relative: ~0.18
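# Memory once packed two values per byte: 4096*4096*0.5 bytes ≈ 8.4 MB of weights
# + ~0.5 MB of FP16 scales and zero-points ≈ 8.9 MB, vs 33.6 MB (FP16).
# Note: W_q above is stored unpacked as uint8 (one byte per weight) for clarity.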
The memory accounting. For a 70B parameter model at 4-bit with group_size=128: Weight storage = 70B × 0.5 bytes = 35 GB. Scale/zero-point overhead = 70B / 128 groups × 4 bytes (an FP16 scale plus an FP16 zero-point per group) ≈ 2.2 GB. Total ≈ 37.2 GB. This fits on a single 80 GB A100 with ~40 GB remaining for KV cache and activations, enough for roughly 32K tokens of context, though the exact figure depends on the number of KV heads, the KV-cache precision, and the batch size. At FP16, the same model needs 140 GB, requiring at least two GPUs with tensor parallelism.
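
The accounting above is easy to script. In this sketch the helper names are hypothetical, and the KV-cache estimate assumes a Llama-2-70B-like shape (80 layers, head dimension 128, FP16 cache) with either 8 KV heads (grouped-query attention) or 64 (full multi-head attention). Those shape values are assumptions for illustration.

```python
def quantized_weight_bytes(n_params, bits=4, group_size=128, meta_bytes_per_group=4):
    """Return (packed weight bytes, scale + zero-point bytes)."""
    weights = n_params * bits / 8
    metadata = n_params / group_size * meta_bytes_per_group
    return weights, metadata

def kv_cache_bytes_per_token(n_layers, n_kv_heads, head_dim, bytes_per_value=2):
    """K and V for one token across all layers."""
    return 2 * n_layers * n_kv_heads * head_dim * bytes_per_value

weights, metadata = quantized_weight_bytes(70e9)
total_gb = (weights + metadata) / 1e9
free_bytes = 80e9 - weights - metadata           # what is left on an 80 GB card
print(f"weights {weights / 1e9:.1f} GB + metadata {metadata / 1e9:.2f} GB = {total_gb:.1f} GB")

# Assumed model shape: 80 layers, head_dim 128; compare GQA (8 KV heads) with MHA (64)
for n_kv_heads in (8, 64):
    per_token = kv_cache_bytes_per_token(80, n_kv_heads, 128)
    tokens = int(free_bytes // per_token)
    print(f"{n_kv_heads:>2} KV heads: {per_token / 1e6:.2f} MB per token, about {tokens:,} cached tokens")
```

The spread between the two head counts is why a context-length estimate only makes sense alongside the attention configuration it was computed for.
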
You have a 70B model that needs to run on a single 80 GB A100. The FP16 weights are 140 GB. You quantize to INT4 per-group (g=128) and the weights + metadata fit in ~37 GB. During generation, you observe that the model produces gibberish specifically on math-heavy prompts, but casual conversation is fine. What is the most likely explanation?

Chapter 6: Attention Kernels — FlashAttention 1 → 4

It is 11:30 AM. You are reviewing a profiling trace from the training cluster. Attention layers consume 42% of wall-clock time and, more importantly, 68% of peak HBM usage. The model processes sequences of length 8192, which means the attention matrix is 8192 × 8192 = 67 million entries per head, per layer. With 32 heads across 80 layers, that is 171 billion intermediate values materialized in memory. Your hardware budget says you need 32K context. Quadrupling the sequence length quadruples the per-token activation memory but makes each attention matrix sixteen times larger. Something has to give.

This is the problem FlashAttention solves. Not by changing the math (it computes exact attention, and its output matches the standard implementation up to floating-point rounding) but by changing how the math is executed on the hardware.

Standard Attention: The Memory Materialization Problem

The textbook attention computation proceeds in three steps:

$S = QK^{\top} / \sqrt{d}, \qquad P = \mathrm{softmax}(S), \qquad O = PV$

Where Q, K, V have shape (N, d) — N tokens, d head dimension (typically 64 or 128). The problem is the intermediate matrix S (and P), which has shape (N, N). Let us trace the memory access pattern:

python
# Standard attention: what actually happens on the GPU
import math
import torch

N, d = 1024, 128                                  # small example shapes
Q, K, V = (torch.randn(N, d) for _ in range(3))

# Step 1: S = Q @ K.T / sqrt(d)
# Read Q (N×d) from HBM, read K (N×d) from HBM
# Compute matmul (fast on tensor cores)
# Write S (N×N) to HBM  ← THE BOTTLENECK
S = Q @ K.T / math.sqrt(d)  # O(N²d) FLOPs, O(N²) HBM writes

# Step 2: P = softmax(S, dim=-1)
# Read S (N×N) from HBM
# Compute softmax (row-wise max, exp, and sum; cheap in FLOPs)
# Write P (N×N) to HBM
P = torch.softmax(S, dim=-1)  # O(N²) FLOPs, O(N²) HBM reads + writes

# Step 3: O = P @ V
# Read P (N×N) from HBM, read V (N×d) from HBM
# Compute matmul
# Write O (N×d) to HBM
O = P @ V  # O(N²d) FLOPs, O(N²) HBM reads

Total HBM accesses: O(N²) reads and writes for S and P. For N = 8192, d = 128: each of the two matmuls costs 8192² × 128 × 2 ≈ 17.2 billion FLOPs, about 34 billion in total (trivial for a GPU). But reading and writing S and P requires 8192² × 2 bytes × 2 (read + write) × 2 (for S and P) ≈ 0.5 GB of memory traffic. Per head. Per layer. This is the problem.

The O(N²) wall. Standard attention's memory usage grows as N². With FP16 scores, at 8K tokens: 128 MB per head. At 32K tokens: 2 GB per head. At 128K tokens: 32 GB per head. With 32 heads, that is 1 TB just for the attention matrices of a single layer. This is a large part of why pre-FlashAttention transformers typically stopped at a few thousand tokens of context. The binding constraint was memory, not compute.
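
The per-head figures follow directly from N² × bytes per value. A short sketch, assuming FP16 (2-byte) scores and a hypothetical helper name:

```python
def attention_matrix_bytes(seq_len, bytes_per_value=2):
    """Size of one materialized N x N score matrix (one head, one layer)."""
    return seq_len * seq_len * bytes_per_value

for n in (2048, 8192, 32768, 131072):
    per_head_mib = attention_matrix_bytes(n) / 2**20
    all_heads_gib = 32 * attention_matrix_bytes(n) / 2**30
    print(f"N={n:>6}: {per_head_mib:>8.0f} MiB per head, {all_heads_gib:>7.2f} GiB for 32 heads")
```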

FlashAttention-1 (2022): Tiling + Online Softmax

Tri Dao's key insight: never materialize the N × N matrix. Instead, compute attention in tiles that fit in SRAM (the GPU's fast on-chip memory: 192 KB per SM on A100, about 20 MB summed over its 108 SMs, vs 40-80 GB of HBM). Each tile computes a partial result and accumulates it, without ever writing the full S or P to HBM.

The challenge: softmax is a non-decomposable operation. To compute softmax(S[i, :]), you need max(S[i, :]) and sum(exp(S[i, :] - max)), which requires the entire row. If you are processing the row in blocks, you do not have the full row available. You need the online softmax trick.

Deriving Online Softmax

Standard softmax of row $s = [s_1, \dots, s_N]$:

$p_i = \exp(s_i - m) / \ell, \quad \text{where} \quad m = \max_j s_j, \quad \ell = \sum_{j=1}^{N} \exp(s_j - m)$

Now suppose you have processed the first block of B columns and computed:

$m^{(1)} = \max(s_1, \dots, s_B), \qquad \ell^{(1)} = \sum_{j=1}^{B} \exp(s_j - m^{(1)})$

Then you process the second block and find a new maximum $m^{(2)} = \max(m^{(1)}, \max(s_{B+1}, \dots, s_{2B}))$. The old partial sum $\ell^{(1)}$ was computed with respect to $m^{(1)}$, not $m^{(2)}$. We need to rescale:

$\ell^{(2)} = \ell^{(1)} \cdot \exp(m^{(1)} - m^{(2)}) + \sum_{j=B+1}^{2B} \exp(s_j - m^{(2)})$

The partial output $O^{(1)} = \sum_{j=1}^{B} \exp(s_j - m^{(1)}) \, v_j$ (the value rows $v_j$ weighted by the old, still unnormalized softmax numerators) must be rescaled by the same factor:

$O^{(2)} = O^{(1)} \cdot \exp(m^{(1)} - m^{(2)}) + \sum_{j=B+1}^{2B} \exp(s_j - m^{(2)}) \, v_j$

This is the online softmax algorithm: maintain a running $(m, \ell, O)$ triple, and after each new block, rescale the accumulated $\ell$ and $O$ by $\exp(m^{\text{old}} - m^{\text{new}})$. Dividing the final $O$ by the final $\ell$ gives a result mathematically identical to computing softmax on the full row.
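
The recurrence can be checked numerically on a single query row. This NumPy sketch uses a hypothetical helper name and arbitrary sizes. It keeps the output unnormalized until the end and divides by the running sum once, then compares against an ordinary full-row softmax.

```python
import numpy as np

def online_softmax_weighted_sum(s, v, block_size):
    """One query row: s is (N,) scores, v is (N, d) values. Returns softmax(s) @ v."""
    m = -np.inf                      # running max
    l = 0.0                          # running sum of exp(s_j - m)
    o = np.zeros(v.shape[1])         # running unnormalized output
    for start in range(0, len(s), block_size):
        s_blk = s[start:start + block_size]
        v_blk = v[start:start + block_size]
        m_new = max(m, s_blk.max())
        correction = np.exp(m - m_new)       # exp(-inf) = 0 on the first block
        p_blk = np.exp(s_blk - m_new)
        l = l * correction + p_blk.sum()     # rescale old sum, add new block
        o = o * correction + p_blk @ v_blk   # rescale old output, add new block
        m = m_new
    return o / l                     # normalize once at the end

rng = np.random.default_rng(0)
s = rng.normal(size=37) * 5          # length deliberately not a multiple of the block size
v = rng.normal(size=(37, 4))

p = np.exp(s - s.max())
reference = (p / p.sum()) @ v        # full-row softmax for comparison
blocked = online_softmax_weighted_sum(s, v, block_size=8)
print(np.max(np.abs(blocked - reference)))
```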

python
import torch

def flash_attention_forward(Q, K, V, block_size=64):
    """
    Simplified FlashAttention-1 forward pass.
    Q, K, V: (N, d) tensors
    Returns O: (N, d) — same as standard attention
    """
    N, d = Q.shape
    O = torch.zeros(N, d)
    m = torch.full((N,), float('-inf'))  # running max per query
    l = torch.zeros(N)                    # running sum per query

    # Iterate over KV blocks
    for j in range(0, N, block_size):
        j_end = min(j + block_size, N)
        K_block = K[j:j_end]  # (B, d)
        V_block = V[j:j_end]  # (B, d)

        # Iterate over Q blocks (in practice, both loops are tiled)
        for i in range(0, N, block_size):
            i_end = min(i + block_size, N)
            Q_block = Q[i:i_end]  # (B, d)

            # Compute tile of S = Q_block @ K_block.T / sqrt(d)
            S_tile = Q_block @ K_block.T / (d ** 0.5)  # (B, B) — fits in SRAM!

            # Online softmax: update running stats
            m_new = torch.max(m[i:i_end], S_tile.max(dim=-1).values)
            correction = torch.exp(m[i:i_end] - m_new)
            P_tile = torch.exp(S_tile - m_new.unsqueeze(-1))
            l_new = l[i:i_end] * correction + P_tile.sum(dim=-1)

            # Rescale accumulated output and add new contribution
            O[i:i_end] = O[i:i_end] * correction.unsqueeze(-1) + P_tile @ V_block

            # Update running stats
            m[i:i_end] = m_new
            l[i:i_end] = l_new

    # Final normalization
    O = O / l.unsqueeze(-1)
    return O
Why this works. The N × N matrix S is never fully materialized. At any moment, only a B × B tile exists in SRAM (on-chip). For B = 64 and d = 128: the tile is 64 × 64 = 4096 entries = 8 KB in FP16, so it sits comfortably in an SM's shared memory next to the Q, K, and V blocks. The HBM I/O drops from O(N² + Nd) to O(N²d² / M), where M is the SRAM available to one thread block, measured in elements (on the order of 100 KB on an A100, not the ~20 MB aggregate). Because d² is several times smaller than M for d between 64 and 128, this works out to a several-fold to roughly 10x reduction in memory traffic.

IO Complexity: Why FlashAttention Is Faster

Let us be precise about the IO complexity. Define M = SRAM size in elements.

| Algorithm | FLOPs | HBM Reads/Writes | Extra HBM for Intermediates |
| --- | --- | --- | --- |
| Standard Attention | O(N²d) | O(N² + Nd) | O(N²): store S and P |
| FlashAttention-1 | O(N²d) | O(N²d² / M) | O(N): only the m and ℓ vectors |

Same FLOPs. Dramatically less memory traffic. And the intermediate memory drops from O(N²) to O(N) — this is what enables 32K, 64K, 128K context lengths.
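
Plugging numbers into the two asymptotic expressions gives a feel for the gap. This is an order-of-magnitude sketch only: constant factors are dropped, the function names are hypothetical, and the SRAM budget (about 100 KB usable per thread block, holding FP16 elements) is an assumption.

```python
def standard_hbm_accesses(n, d):
    """Asymptotic element accesses for standard attention: N^2 + N*d."""
    return n * n + n * d

def flash_hbm_accesses(n, d, sram_elements):
    """Asymptotic element accesses for FlashAttention: N^2 * d^2 / M."""
    return n * n * d * d / sram_elements

SRAM_ELEMENTS = 100 * 1024 // 2      # assumption: ~100 KB per thread block, 2-byte elements

n = 8192
for d in (64, 128):
    ratio = standard_hbm_accesses(n, d) / flash_hbm_accesses(n, d, SRAM_ELEMENTS)
    print(f"N={n}, d={d}: roughly {ratio:.1f}x fewer HBM accesses")
```

The advantage shrinks as the head dimension grows, since the d² term eats into the fixed SRAM budget.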

FlashAttention-2 (2023): Better Parallelism

FlashAttention-1 was IO-optimal but left GPU SMs underutilized. FA-2 made three key changes:

1. Swap the loop order. FA-1 iterates over KV blocks in the outer loop and Q blocks in the inner loop. FA-2 reverses this: iterate over Q blocks in the outer loop. Why? Each Q block now runs on a separate thread block, and the inner KV loop is serial within that block. The number of thread blocks becomes (N / B_Q) × batch × heads, which stays well above the number of SMs even for long sequences at small batch sizes, giving the GPU scheduler plenty to work with.

2. Reduce non-matmul FLOPs. On modern GPUs (A100, H100), matmul operations run on specialized tensor cores at 10-16x the throughput of general-purpose FP32 ops. The rescaling in online softmax uses exp() and division, which run on the slower CUDA cores. FA-2 carefully restructures the computation to minimize these non-matmul operations, achieving ~70% of theoretical peak matmul throughput (vs ~35% for FA-1).

3. Work partitioning within warps. FA-2 splits Q among warps but shares K/V across warps. This minimizes communication between warps (no cross-warp reduction needed for the softmax statistics) while maximizing the reuse of K/V from shared memory.
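
The loop-order change in item 1 can be sketched in NumPy. This is an illustrative reference under assumed block sizes and a hypothetical function name, not FA-2's kernel. Each iteration of the outer loop stands in for one independent thread block: its running statistics and accumulator stay local, and its slice of the output is written exactly once.

```python
import numpy as np

def attention_q_outer(Q, K, V, block_q=4, block_kv=4):
    """Tiled attention with Q blocks in the outer loop (FA-2 style ordering)."""
    N, d = Q.shape
    O = np.empty((N, V.shape[1]))
    for i in range(0, N, block_q):               # independent work item per Q block
        q = Q[i:i + block_q]
        m = np.full(q.shape[0], -np.inf)         # running max, local to this block
        l = np.zeros(q.shape[0])                 # running sum, local to this block
        acc = np.zeros((q.shape[0], V.shape[1])) # unnormalized output accumulator
        for j in range(0, N, block_kv):          # serial loop over KV blocks
            s = q @ K[j:j + block_kv].T / np.sqrt(d)
            m_new = np.maximum(m, s.max(axis=1))
            corr = np.exp(m - m_new)
            p = np.exp(s - m_new[:, None])
            l = l * corr + p.sum(axis=1)
            acc = acc * corr[:, None] + p @ V[j:j + block_kv]
            m = m_new
        O[i:i + block_q] = acc / l[:, None]      # single write of this output slice
    return O

rng = np.random.default_rng(0)
Q, K, V = (rng.normal(size=(10, 4)) for _ in range(3))

S = Q @ K.T / np.sqrt(Q.shape[1])
P = np.exp(S - S.max(axis=1, keepdims=True))
reference = (P / P.sum(axis=1, keepdims=True)) @ V
print(np.max(np.abs(attention_q_outer(Q, K, V) - reference)))
```

With KV blocks in the outer loop, by contrast, every pass revisits the running statistics and partial output of all query rows, which is the extra traffic and synchronization the reordering removes.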

FlashAttention-3 (2024): Hopper Architecture

The H100 Hopper GPU introduced three hardware features that FA-3 exploits:

| H100 Feature | What It Does | How FA-3 Uses It |
| --- | --- | --- |
| WGMMA | Warp Group Matrix Multiply: issues matmul instructions across 4 warps (128 threads) | Larger tiles, fewer instructions. One WGMMA instruction covers a 64×N×16 FP16 tile (N up to 256) for a whole warp group, versus the per-warp 16×8×16 MMA instructions on A100. |
| TMA | Tensor Memory Accelerator: hardware unit for async memory copies | Overlaps memory loads with computation. While WGMMA computes on tile N, TMA prefetches tile N+1 from HBM to shared memory. |
| FP8 Tensor Cores | Native 8-bit floating-point matmul at 2x FP16 throughput | FA-3 supports FP8 attention: Q and K in FP8, with careful handling of the softmax numerics (still in FP32 for stability). |

Warp specialization. FA-3 assigns different roles to different warp groups: "producer" warps issue TMA loads, "consumer" warps issue WGMMA instructions. They communicate through shared memory barriers. This is a software pipeline: while consumers compute on data that producers loaded in the previous step, producers are simultaneously loading the next batch. The result: memory latency is almost completely hidden behind compute.

pseudocode
// FA-3 warp-specialized pipeline (conceptual)
// Two warp groups, PRODUCER and CONSUMER, share a ring of NUM_STAGES smem buffers.
// full[s]  : producer -> consumer, "buffer s holds fresh K/V"
// empty[s] : consumer -> producer, "buffer s may be overwritten" (all start signaled)

// PRODUCER warp group:
for t in 0 .. num_kv_tiles - 1:
    s = t % NUM_STAGES
    wait(empty[s])                              // wait for consumer to release the buffer
    TMA_async_load(K_tile[t], smem_K[s])        // hardware-accelerated
    TMA_async_load(V_tile[t], smem_V[s])
    signal(full[s])                             // tell consumer data is ready

// CONSUMER warp group:
for t in 0 .. num_kv_tiles - 1:
    s = t % NUM_STAGES
    wait(full[s])                               // wait for producer to load
    S = WGMMA(Q_reg, smem_K[s])                 // S tile = Q @ K^T
    P = online_softmax(S, &m, &ℓ)               // on CUDA cores
    O += WGMMA(P, smem_V[s])                    // accumulate output
    signal(empty[s])                            // tell producer buffer is free

FlashAttention-4 (2025): Blackwell Architecture

The B200 Blackwell GPU doubles down on the trends from Hopper, and FA-4 exploits new capabilities:

CuTe DSL. FA-4 is written using NVIDIA's CuTe (Cute Templates) abstraction, which represents tiles as first-class objects with layout algebra. Instead of manually indexing into shared memory with pointer arithmetic, you declare a tile's layout (e.g., "128 rows × 64 cols, row-major, swizzled by 128 bytes") and CuTe generates the correct indexing code. This eliminates an entire class of shared memory bank conflict bugs.

Pingpong scheduling. Blackwell keeps the same 228 KB of shared memory per SM as Hopper but adds a dedicated tensor memory (TMEM) that holds MMA accumulators outside the register file. FA-4 uses a "pingpong" pattern where two warp groups alternate between computing and loading, with triple-buffering to keep both the compute pipeline and the memory pipeline busy.

FP4 support. Blackwell's tensor cores natively support FP4 matmul. FA-4 can run the Q×Kᵀ multiply in FP4 (with FP32 accumulation), achieving 4x the throughput of FP16 at the cost of some precision. For inference with quantized models, this is a game-changer.

The Evolution in One Table

| Version | Year | GPU Target | Key Innovation | % of Peak FLOPs | Speedup vs Standard |
| --- | --- | --- | --- | --- | --- |
| Standard | | Any | Baseline (materializes N×N) | ~15% | 1.0x |
| FA-1 | 2022 | A100 | Tiling + online softmax | ~35% | 2-4x |
| FA-2 | 2023 | A100/H100 | Q-outer loop, warp partitioning | ~70% | 3-5x |
| FA-3 | 2024 | H100 Hopper | WGMMA + TMA + warp specialization + FP8 | ~80% | 1.5-2x over FA-2 on H100 |
| FA-4 | 2025 | B200 Blackwell | CuTe DSL + pingpong scheduling + FP4 | ~85% | ~2x over FA-3 on B200 |

The compounding gains. From standard attention to FA-4: roughly 15-20x faster, O(N) memory instead of O(N²), and exact attention outputs (identical to the standard computation up to floating-point rounding at the same precision). This is what enabled the jump from 2K to 128K to 1M+ context windows. None of the math changed: the innovation was entirely in how the computation is mapped to the hardware memory hierarchy.

[Interactive figure: Attention Tiling Visualization. Shows FlashAttention processing the Q×Kᵀ matrix tile by tile without materializing the full N×N matrix, for an adjustable sequence length N and tile size B. The highlighted tile is the current computation in SRAM. Standard attention would need the entire grid in HBM.]

Backward Pass: Why It's Harder

The forward pass avoids materializing S and P, but the backward pass needs them for the gradient computation. FlashAttention solves this by recomputing S and P tile by tile from Q, K, V during the backward pass, which trades extra FLOPs for HBM savings. Since attention is memory-bound, the extra matmul is largely hidden behind the memory traffic it replaces, and the backward pass still ends up faster than the standard one. To make the recomputation cheap, the forward pass keeps a small auxiliary tensor of softmax statistics per row (size N). FA-2 reduces this to a single log-sum-exp value per row.

The log-sum-exp trick. During the forward pass, FlashAttention stores L[i] = m[i] + log(ℓ[i]) for each query row i. This single scalar per row encodes both the max and the normalization constant. During the backward pass, you can reconstruct the softmax values from S (recomputed), L (stored), and O (stored), without ever materializing P. The memory overhead: N floats per head — negligible compared to the N² of standard attention.
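
A small NumPy check of this reconstruction, with arbitrary sizes and tile boundaries chosen for illustration: the forward pass keeps only L, and the backward pass recovers any tile of P from a recomputed tile of S, since exp(S − L) = exp(S − m) / ℓ.

```python
import numpy as np

rng = np.random.default_rng(0)
N, d = 16, 8
Q = rng.normal(size=(N, d))
K = rng.normal(size=(N, d))

# Forward pass: compute the row statistics and keep only L (N scalars)
S = Q @ K.T / np.sqrt(d)
m = S.max(axis=1)
l = np.exp(S - m[:, None]).sum(axis=1)
L = m + np.log(l)

# Backward pass: recompute one tile of S from Q and K, then recover that tile of P
rows, cols = slice(4, 8), slice(8, 12)
S_tile = Q[rows] @ K[cols].T / np.sqrt(d)
P_tile = np.exp(S_tile - L[rows, None])

# Reference: full softmax, materialized only for this check
P_full = np.exp(S - m[:, None]) / l[:, None]
print(np.max(np.abs(P_tile - P_full[rows, cols])))
```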

Practical Integration

python
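# Using FlashAttention in practice
# (F.scaled_dot_product_attention: PyTorch 2.0+; torch.nn.attention.sdpa_kernel: PyTorch 2.3+)
import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

batch, nheads, seqlen, d = 2, 32, 4096, 128
# FlashAttention kernels need FP16/BF16 tensors on a CUDA device
Q, K, V = (torch.randn(batch, nheads, seqlen, d, device="cuda", dtype=torch.bfloat16)
           for _ in range(3))

# Option 1: PyTorch native SDPA. It picks a backend automatically;
# the context manager restricts the choice to the FlashAttention backend.
# Q, K, V shape: (batch, nheads, seqlen, headdim)
with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
    out = F.scaled_dot_product_attention(Q, K, V, is_causal=True)

# Option 2: Direct flash-attn library
from flash_attn import flash_attn_func

# Q, K, V shape: (batch, seqlen, nheads, headdim), so swap the head and sequence axes
Qf, Kf, Vf = (x.transpose(1, 2).contiguous() for x in (Q, K, V))
out = flash_attn_func(Qf, Kf, Vf, causal=True)

# Option 3: FA-3 with FP8 on H100
# FA-3 is built separately (the hopper/ directory of the flash-attention repo), and its
# import path and FP8 entry point differ between releases. Treat this call as a sketch
# and check the version you have installed before relying on it.
out = flash_attn_func(
    Qf.to(torch.float8_e4m3fn),   # FP8 E4M3 format
    Kf.to(torch.float8_e4m3fn),
    Vf.to(torch.float8_e4m3fn),
    causal=True,
    softmax_scale=1.0 / (d ** 0.5),
)

# Memory comparison for sequence length 32768, 32 heads, dim 128:
# Standard:  32768² × 32 × 2 bytes = 64 GB (!) per layer, for S alone
# FlashAttn: 32768 × 32 × 4 bytes  = 4 MB  (just the L vector)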
FlashAttention-1 computes mathematically identical results to standard attention but uses O(N) memory instead of O(N²). The core trick is "online softmax." Why can't you just compute softmax independently on each tile of S without the rescaling step?

Chapter 7: Kernel DSLs — The 2025 Explosion

It is 1:15 PM, back from lunch. A Slack thread is debating which framework to use for a custom fused kernel: someone advocates Triton, another insists on CUTLASS, a third says ThunderKittens is the only way to get 80%+ utilization on Hopper. The team lead posts a screenshot of a timeline from an AI infrastructure newsletter showing eleven new kernel DSLs released in the past twelve months. "We need a decision by Friday," she writes.

This explosion is not arbitrary. It reflects a fundamental tension in GPU programming that every frontier lab grapples with: productivity vs performance. Writing CUDA by hand gives you total control over shared memory layout, warp scheduling, and instruction selection. But a single fused attention kernel can take 2,000+ lines of CUDA and weeks to debug. Triton gives you 100 lines of Python for the same kernel, but the compiler might leave 30% of the hardware idle because it made suboptimal scheduling choices. Every DSL in the 2025 explosion stakes out a different position on this tradeoff.

The Productivity-Performance Spectrum

Let us map the landscape. At one end: raw CUDA/PTX, where you control every register and every instruction. At the other: PyTorch's torch.compile, where you write math and hope the compiler figures out the rest.

| Level | Tool | Abstraction | Lines for Fused GEMM+ReLU | % Peak FLOPs (A100) |
| --- | --- | --- | --- | --- |
| 0: Hardware | PTX / SASS | Individual instructions | ~3000 | 95%+ |
| 1: Low-level | CUDA | Threads, warps, shared memory | ~800 | 85-90% |
| 2: Template | CUTLASS / CuTe | Tiles, layouts, MMA descriptors | ~200 | 80-90% |
| 3: Register-tile | ThunderKittens | 16×16 register tiles as objects | ~100 | 75-85% |
| 4: Compiler | Triton | Block-level programs | ~50 | 65-80% |
| 5: Auto-schedule | TileLang, Helion | Declarative tiles + auto-tuning | ~30 | 60-75% |
| 6: Graph | torch.compile | Math expressions | ~5 | 40-60% |

The 80% rule. For most frontier lab applications, getting to 80% of peak FLOPs is good enough. The last 20% requires heroic effort (hand-tuned CUDA, architecture-specific scheduling). The question is: which tool gets you to 80% fastest? This is where the explosion comes from — different teams answered this question differently and built their own DSLs.

CUDA: The Baseline

Every kernel ultimately compiles to CUDA (or rather, to PTX and then SASS). Understanding the CUDA programming model is non-negotiable for a frontier lab engineer, even if you rarely write raw CUDA yourself.

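The key abstractions: threads execute scalar operations. Warps (32 threads) execute in lockstep (SIMT). Thread blocks (up to 1024 threads) share SRAM. Grids of thread blocks fill the GPU. The programmer must manually manage shared memory tiling, global memory access patterns, and tensor core usage, none of which the naive kernel below attempts: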

cuda
// CUDA GEMM kernel — simplified, no tiling optimizations
// This is what you're trying to AVOID writing by using a DSL
__global__ void matmul_naive(float* A, float* B, float* C,
                             int M, int N, int K) {
    int row = blockIdx.y * blockDim.y + threadIdx.y;
    int col = blockIdx.x * blockDim.x + threadIdx.x;
    if (row < M && col < N) {
        float sum = 0.0f;
        for (int k = 0; k < K; k++) {
            sum += A[row * K + k] * B[k * N + col];
        }
        C[row * N + col] = sum;
    }
}
// Performance: ~5% of peak on A100. Why? No shared memory,
// no tiling, no tensor core usage, terrible memory access pattern.
// A production GEMM kernel is 50-100x more complex.

Triton: Python-Level Kernel Writing

Triton (by OpenAI, first released 2021, mature by 2023) lets you write GPU kernels in Python. The key abstraction is the block program: your code operates on blocks (tiles) of data, and the Triton compiler handles the mapping to warps, shared memory, and tensor cores.

python
import triton
import triton.language as tl

@triton.jit
def matmul_kernel(
    A_ptr, B_ptr, C_ptr,
    M, N, K,
    stride_am, stride_ak, stride_bk, stride_bn, stride_cm, stride_cn,
    BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
):
    # Which tile of C does this program compute?
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)

    # Compute offsets for this tile
    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    offs_k = tl.arange(0, BLOCK_K)

    # Pointers to first tile of A and B
    a_ptrs = A_ptr + offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak
    b_ptrs = B_ptr + offs_k[:, None] * stride_bk + offs_n[None, :] * stride_bn

    # Accumulate in FP32 for numerical stability
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)

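    # Loop over K dimension in blocks
    for k in range(0, K, BLOCK_K):
        # Mask out-of-bounds rows/columns and the K tail; masked lanes read 0.0
        a = tl.load(a_ptrs, mask=(offs_m[:, None] < M) & (offs_k[None, :] < K - k), other=0.0)
        b = tl.load(b_ptrs, mask=(offs_k[:, None] < K - k) & (offs_n[None, :] < N), other=0.0)
        acc += tl.dot(a, b)  # lowered to tensor core instructions when shapes and dtypes allow
        a_ptrs += BLOCK_K * stride_ak
        b_ptrs += BLOCK_K * stride_bk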

    # Write result
    c_ptrs = C_ptr + offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn
    tl.store(c_ptrs, acc, mask=(offs_m[:, None] < M) & (offs_n[None, :] < N))

What the compiler does for you: shared memory allocation, double-buffering (loading the next tile while computing the current one), warp scheduling, tensor core instruction selection, predicated loads for boundary tiles. What it does NOT do well: cross-SM communication, complex warp-level synchronization, fine-grained instruction scheduling (which is why hand-tuned CUDA can still beat Triton by 10-20% on specific kernels).
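
To see the block-program model without a GPU, the kernel's structure can be emulated in NumPy: one Python call per program ID, with a helper standing in for a masked load. This is only an emulation sketch with hypothetical helper names (`masked_load`, `matmul_program`, `launch`). It mirrors the indexing and masking logic, not Triton's execution or performance.

```python
import numpy as np

def masked_load(X, rows, cols):
    """Emulates tl.load(ptrs, mask=in_bounds, other=0.0) for a 2D tile."""
    tile = np.zeros((len(rows), len(cols)), dtype=X.dtype)
    r_ok, c_ok = rows < X.shape[0], cols < X.shape[1]
    tile[np.ix_(r_ok, c_ok)] = X[np.ix_(rows[r_ok], cols[c_ok])]
    return tile

def matmul_program(pid_m, pid_n, A, B, C, BLOCK_M, BLOCK_N, BLOCK_K):
    """One 'program instance': computes a single BLOCK_M x BLOCK_N tile of C."""
    M, K = A.shape
    N = B.shape[1]
    offs_m = pid_m * BLOCK_M + np.arange(BLOCK_M)
    offs_n = pid_n * BLOCK_N + np.arange(BLOCK_N)
    acc = np.zeros((BLOCK_M, BLOCK_N), dtype=np.float32)
    for k in range(0, K, BLOCK_K):
        offs_k = k + np.arange(BLOCK_K)
        a = masked_load(A, offs_m, offs_k)        # (BLOCK_M, BLOCK_K), zero-padded
        b = masked_load(B, offs_k, offs_n)        # (BLOCK_K, BLOCK_N), zero-padded
        acc += a @ b
    ok_m, ok_n = offs_m < M, offs_n < N           # masked store for boundary tiles
    C[np.ix_(offs_m[ok_m], offs_n[ok_n])] = acc[np.ix_(ok_m, ok_n)]

def launch(A, B, BLOCK_M=32, BLOCK_N=32, BLOCK_K=16):
    """Emulates the launch grid: one program per output tile."""
    M, N = A.shape[0], B.shape[1]
    C = np.zeros((M, N), dtype=np.float32)
    for pid_m in range(-(-M // BLOCK_M)):         # ceil division
        for pid_n in range(-(-N // BLOCK_N)):
            matmul_program(pid_m, pid_n, A, B, C, BLOCK_M, BLOCK_N, BLOCK_K)
    return C

rng = np.random.default_rng(0)
A = rng.normal(size=(70, 33)).astype(np.float32)  # sizes chosen to hit boundary tiles
B = rng.normal(size=(33, 50)).astype(np.float32)
C = launch(A, B)
print(np.max(np.abs(C - A @ B)))
```

On a GPU the two launch loops disappear: every (pid_m, pid_n) pair runs concurrently, which is why each program may only touch its own tile of C.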

CUTLASS / CuTe: C++ Template Metaprogramming

CUTLASS (NVIDIA) is a C++ template library for writing high-performance GPU kernels. Its modern incarnation, CuTe, introduces layout algebra — a type system for describing how data is arranged in memory. A layout is a function from logical coordinates to physical memory addresses.

c++
// CuTe layout example: a 128×64 tile in shared memory
// with a 128-byte swizzle to avoid bank conflicts
using SmemLayout = Layout<Shape<_128, _64>,
                          Stride<_64, _1>>;   // row-major

// A swizzled version: XOR low row-index bits into the column-chunk bits.
// This removes bank conflicts when a warp walks down a column.
using SmemLayoutSwizzled = decltype(composition(
    Swizzle<3, 3, 3>{},  // XOR 3 bits; the lowest 3 bits (8 elements) stay contiguous
    SmemLayout{}
));

// The layout algebra handles indexing automatically:
// smem(row, col) returns the correct physical address
// even with the swizzle, so no manual index math is needed.

CUTLASS is developed by NVIDIA and is designed to reach performance comparable to cuBLAS. If you need absolute peak performance on NVIDIA hardware and are willing to invest in C++ template wizardry, it is the tool.
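
The idea of a layout as a function from coordinates to offsets, and of a swizzle as an XOR on that offset, can be sketched in plain Python. The bit positions below follow my reading of Swizzle<3, 3, 3> (XOR offset bits 6-8 into bits 3-5). The bank model assumes 2-byte elements and 32 banks of 4 bytes. Both are stated assumptions, and the function names are illustrative.

```python
ROWS, COLS = 128, 64                    # tile shape from the layout example

def row_major(row, col):
    """Layout: logical (row, col) -> linear element offset, strides (64, 1)."""
    return row * COLS + col

def swizzle_3_3_3(offset):
    """XOR bits [6, 9) of the offset (low row bits) into bits [3, 6) (8-element chunk index)."""
    return offset ^ (((offset >> 6) & 0b111) << 3)

def bank(offset, bytes_per_elem=2, n_banks=32, bank_bytes=4):
    """Shared memory bank that holds the element at this offset."""
    return (offset * bytes_per_elem // bank_bytes) % n_banks

# Eight threads read the same column from eight consecutive rows
col = 0
plain_banks = [bank(row_major(r, col)) for r in range(8)]
swizzled_banks = [bank(swizzle_3_3_3(row_major(r, col))) for r in range(8)]
print("plain:   ", plain_banks)         # every access lands in the same bank
print("swizzled:", swizzled_banks)      # accesses spread over distinct banks

# The swizzle only permutes offsets within the tile, so no two elements collide
all_offsets = {swizzle_3_3_3(row_major(r, c)) for r in range(ROWS) for c in range(COLS)}
print("bijective:", len(all_offsets) == ROWS * COLS)
```

Because each row is exactly 128 bytes, a column walk without the swizzle hits one bank repeatedly. The XOR moves each row's chunk to a different position while leaving row-wise reads contiguous in 8-element chunks.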

ThunderKittens: Register Tiles as First-Class Citizens

ThunderKittens (Stanford, 2024) takes a radical position: the right abstraction level is the register tile — a 16×16 matrix that lives in a warp's registers and maps directly to a single tensor core MMA instruction.

c++
// ThunderKittens: the entire abstraction in a few types
using namespace kittens;

// Register tiles: 16×16 matrices that live in warp registers
rt_fl<16, 16> a_tile, b_tile, c_tile;     // float register tiles
rt_bf<16, 16> a_bf, b_bf;                  // bfloat16 register tiles

// Shared memory tiles
st_bf<64, 64> smem_a, smem_b;              // shared memory tiles

// Operations map directly to hardware
mma_ABt(c_tile, a_bf, b_bf, c_tile);      // tensor core MMA: C += A @ B^T
load(a_bf, smem_a, row_offset, col_offset);// load from shared to register
store(smem_a, a_bf, row_offset, col_offset);// store from register to shared

The philosophy: "If it's not a 16×16 tile, you're not thinking at the right level." ThunderKittens forces you to decompose your algorithm into register-tile operations, which guarantees that the tensor cores are always engaged. The downside: you still manage shared memory layout and warp scheduling manually.

The 2025-2026 Explosion

Starting in mid-2024, a wave of new DSLs appeared, each targeting a specific niche in the productivity-performance space:

| DSL | Origin | Key Idea | Target Use Case |
| --- | --- | --- | --- |
| TileLang | HKUST / TVM | Python DSL compiled via TVM, multi-backend (CUDA, ROCm) | Cross-platform kernel development |
| Helion | Meta (PyTorch team) | Python DSL with autotuning built in, integrated with torch.compile | PyTorch ecosystem, research prototyping |
| DeepGEMM | DeepSeek | JIT-compiled GEMM library, no installation, pure header | Fast FP8 GEMMs for MoE inference |
| Mirage / MPK | CMU | Superoptimizer that searches over kernel implementations | Automatically finding optimal kernels |
| ThunderKittens 2.0 | Stanford | Blackwell support, TMA integration, multi-GPU tiles | Cutting-edge hardware utilization |
| CUDA Tile / TileIR | NVIDIA | Official tile-level IR, successor to CuTe | NVIDIA's answer to the DSL explosion |
| Tilus | Modular (Mojo team) | Tile-level DSL compiled to MLIR | Multi-hardware (GPU + TPU + custom) |
| Gluon / TLX | Triton project (OpenAI) / Meta | Lower-level Triton extensions that expose layouts, warp specialization, and async copies | Expert tuning beyond stock Triton |

Why so many, why now? Three converging forces: (1) New hardware generations (Hopper, Blackwell, MI300X) each require different scheduling strategies, making old kernels suboptimal. (2) The rise of MoE, long-context attention, and quantized inference creates demand for custom kernels that generic libraries do not cover. (3) The success of Triton proved that compiler-based approaches can get within striking distance of hand-tuned code, inspiring many teams to try their own variation.

When to Use What: The Decision Tree

Is it a standard operation (GEMM, attention, conv)?
Use the vendor library: cuBLAS, cuDNN, or FlashAttention. Do not write a custom kernel for standard ops unless you have profiling evidence that the library is suboptimal for your specific shapes.
↓ No, it's custom
Do you need > 80% peak utilization?
If no: Triton. 50 lines, auto-tuning, works on A100/H100/MI300X. If yes: continue below.
↓ Yes, need peak perf
Is it a matmul variant (quantized GEMM, MoE dispatch, fused GEMM+activation)?
CUTLASS/CuTe (NVIDIA) or DeepGEMM (for FP8 MoE). These have the best matmul performance.
↓ It's not a matmul variant
Is it attention-like (custom masking, sliding window, sparse pattern)?
ThunderKittens or fork FlashAttention. Both provide the online-softmax + tiling infrastructure.
↓ It's something entirely new
Multi-hardware target (NVIDIA + AMD + TPU)?
TileLang or Tilus. They compile to multiple backends.

The Same Matmul, Three Ways

To make the tradeoffs concrete, here is the conceptual core of a tiled GEMM in three DSLs. All three compute C += A × B in tiles.

python — triton
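# Triton: ~50 lines, compiler handles shared memory + scheduling
# (core only: assumes row-major contiguous A, B, C and M, N, K divisible by the block sizes, so no masks)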
@triton.jit
def gemm(A, B, C, M, N, K, BM: tl.constexpr, BN: tl.constexpr, BK: tl.constexpr):
    pid_m, pid_n = tl.program_id(0), tl.program_id(1)
    rm = pid_m * BM + tl.arange(0, BM)
    rn = pid_n * BN + tl.arange(0, BN)
    acc = tl.zeros((BM, BN), dtype=tl.float32)
    for k in range(0, K, BK):
        rk = k + tl.arange(0, BK)
        a = tl.load(A + rm[:, None] * K + rk[None, :])
        b = tl.load(B + rk[:, None] * N + rn[None, :])
        acc += tl.dot(a, b)
    tl.store(C + rm[:, None] * N + rn[None, :], acc)

c++ — cutlass cute
// CuTe: ~200 lines, explicit layout algebra, explicit pipeline
// (heavily abbreviated: real code has copy engines, pipeline barriers)
auto tiled_mma = make_tiled_mma(SM80_16x8x16_F32BF16BF16F32_TN{},
                                Layout<Shape</* warps in M, N, K */>>{});
auto smem_a = make_tensor(smem_ptr, SmemLayoutA{});
auto smem_b = make_tensor(smem_ptr + smem_a_size, SmemLayoutB{});
for (int k = 0; k < K; k += BK) {
    copy(gmem_a_slice, smem_a);  // global → shared (via TMA on Hopper)
    copy(gmem_b_slice, smem_b);
    __syncthreads();
    gemm(tiled_mma, smem_a, smem_b, acc);  // tensor core MMA
    __syncthreads();
}

c++ — thunderkittens
// ThunderKittens: ~100 lines, register-tile-centric
using namespace kittens;
st_bf<BM, BK> smem_a; st_bf<BN, BK> smem_b;  // shared tiles; shapes assumed from the loops below
rt_fl<16, 16> acc[BM/16][BN/16] = {};  // register tile grid
for (int k = 0; k < K; k += BK) {
    load(smem_a, A + k, K);    // HBM → shared
    load(smem_b, B + k * N, N);
    __syncthreads();
    for (int i = 0; i < BM/16; i++)
        for (int j = 0; j < BN/16; j++)
            mma_ABt(acc[i][j], smem_a.get_tile(i), smem_b.get_tile(j), acc[i][j]);
    __syncthreads();
}

Same algorithm. Three levels of abstraction. Triton is shortest but gives you least control over the schedule. CuTe is longest but lets you specify exactly how data flows through the memory hierarchy. ThunderKittens is in between — you think in register tiles, but the library handles tile loads and MMA instruction selection.

Kernel DSL Landscape

DSLs positioned on the productivity-performance axis. The vertical axis shows relative performance (% of peak FLOPs). The horizontal axis shows developer productivity (inverse of lines of code).

The Meta-Trend: Convergence

Despite the explosion, the DSLs are converging on a shared set of abstractions: tiles as first-class objects, pipeline stages for memory latency hiding, layout descriptors for swizzled shared memory, and auto-tuning for block-size selection. The eventual outcome is likely one or two dominant frameworks (Triton + CUTLASS/CuTe) with specialized tools for niche use cases.

The staff engineer's perspective. You do not need to master all eleven DSLs. You need to deeply understand ONE (Triton for most teams, CUTLASS for NVIDIA-focused teams), have working knowledge of the landscape to make technology choices, and know enough CUDA fundamentals to read and debug any of them. The decision of which DSL to adopt is a strategic choice that affects hiring (Triton has a larger talent pool than ThunderKittens), maintainability (Python beats C++ for most teams), and vendor lock-in (CUTLASS is NVIDIA-only).
Your team needs to write a custom kernel for a fused operation: quantized GEMM (INT4 weights, FP16 activations) + SiLU activation + residual add. The kernel must run on both NVIDIA H100 and AMD MI300X. You need at least 70% of peak FLOPs. Which tool is the best fit?

Chapter 8: KV Cache Optimization — The Memory Crisis

It is 2:30 PM. The product team wants to launch a feature that requires 128K context: users upload a 200-page document and ask questions about it. Your current serving stack handles 8K context comfortably. At 128K, the KV cache (the stored key and value tensors from previous tokens) can consume more memory than the model weights themselves. Your 70B model at FP16 uses 140 GB for weights. With standard multi-head attention, the KV cache for a single 128K-token sequence adds another 320 GB; even with grouped-query attention it is 40 GB per sequence. You need a plan.

This chapter covers every major technique for taming the KV cache: architectural changes (MQA, GQA), system-level tricks (PagedAttention), intelligent eviction (SnapKV, H2O, StreamingLLM), and quantization. By the end, you will know exactly which technique to apply for which use case.

The KV Cache Memory Formula

Let us derive the exact memory cost. For a transformer with:

| Symbol | Meaning | Typical (Llama-2 70B) |
| --- | --- | --- |
| L | Number of layers | 80 |
| nkv | Number of KV heads per layer | 8 (GQA) or 64 (MHA) |
| dh | Head dimension | 128 |
| S | Sequence length (prompt + generated) | 8,192 to 128,000+ |
| b | Bytes per element | 2 (FP16) or 1 (FP8) or 0.5 (INT4) |

Each layer stores K and V tensors of shape (nkv, S, dh). The total KV cache memory:

KV Memory = 2 · L · nkv · S · dh · b

The factor of 2 is for K and V. Let us plug in numbers:

python
def kv_cache_memory_gb(L, n_kv, d_h, S, bytes_per_elem=2):
    """Compute KV cache memory in GB."""
    total_bytes = 2 * L * n_kv * S * d_h * bytes_per_elem
    return total_bytes / (1024**3)

# Llama-2 70B with MHA (64 KV heads)
print(kv_cache_memory_gb(80, 64, 128, 8192))      # → 20 GB at 8K context
print(kv_cache_memory_gb(80, 64, 128, 32768))     # → 80 GB at 32K context
print(kv_cache_memory_gb(80, 64, 128, 131072))    # → 320 GB at 128K context!

# Llama-2 70B with GQA (8 KV heads) — 8x reduction
print(kv_cache_memory_gb(80, 8, 128, 8192))       # → 2.5 GB
print(kv_cache_memory_gb(80, 8, 128, 131072))     # → 40 GB — manageable!

# With INT4 KV quantization on top of GQA:
print(kv_cache_memory_gb(80, 8, 128, 131072, 0.5)) # → 10 GB — fits easily
The quadruple whammy. KV cache memory scales linearly with: (1) number of layers, (2) number of KV heads, (3) sequence length, (4) batch size. For a serving system with batch size 32 and 128K context: multiply the single-sequence number by 32. That is 10.24 TB for Llama-2 70B with MHA, and still 1.28 TB with GQA. This is why GQA was invented: it cuts the second multiplier by 8x. It is also why KV cache optimization is among the most impactful system-level problems in LLM inference.
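The batch-size multiplier is easy to check numerically. The sketch below is a minimal illustration of the formula above; the helper names and the 500 GiB budget in the usage note are assumptions made for this example, not part of any serving framework.

```python
def kv_bytes_per_token(n_layers, n_kv_heads, d_head, bytes_per_elem=2):
    """Bytes of K and V stored for one token across all layers."""
    return 2 * n_layers * n_kv_heads * d_head * bytes_per_elem


def batch_kv_gib(n_layers, n_kv_heads, d_head, seq_len, batch, bytes_per_elem=2):
    """Total KV cache in GiB for `batch` sequences of `seq_len` tokens each."""
    per_token = kv_bytes_per_token(n_layers, n_kv_heads, d_head, bytes_per_elem)
    return per_token * seq_len * batch / 1024**3


def max_concurrent_sequences(budget_gib, n_layers, n_kv_heads, d_head, seq_len,
                             bytes_per_elem=2):
    """How many full-length sequences fit in a KV budget (hypothetical helper)."""
    per_seq = batch_kv_gib(n_layers, n_kv_heads, d_head, seq_len, 1, bytes_per_elem)
    return int(budget_gib // per_seq)


# Batch 32 at 128K context, FP16
print(batch_kv_gib(80, 64, 128, 131072, batch=32))  # MHA: 10240.0 GiB
print(batch_kv_gib(80, 8, 128, 131072, batch=32))   # GQA: 1280.0 GiB
```

As a usage example, `max_concurrent_sequences(500, 80, 8, 128, 131072)` asks how many 128K GQA sequences fit in an assumed 500 GiB KV budget; at 40 GiB per sequence the answer is 12.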

Architectural Solutions: MHA → MQA → GQA

The most effective KV cache reduction is architectural — baked into the model during training:

Multi-Head Attention (MHA): standard transformer. Each of the nh query heads has its own K and V head. KV cache stores nh × S × dh per layer. This is what GPT-3, Llama-1, and Llama-2 7B/13B use.

Multi-Query Attention (MQA): all query heads share ONE K head and ONE V head. KV cache drops to 1 × S × dh per layer — an nh-fold reduction (typically 32-64x). Introduced by Noam Shazeer in 2019. Used by PaLM, Falcon, StarCoder. The downside: some quality degradation because all heads are forced to attend to the same keys/values.

Grouped-Query Attention (GQA): the compromise. Every group of g query heads shares one KV head, so nkv = nh / g. With nh = 64 query heads and g = 8 query heads per group: nkv = 8 KV heads. An 8x reduction over MHA with minimal quality loss. Used by Llama-2 70B, Llama-3, Mistral, Gemma. GQA is now the de facto standard for large models.

| Scheme | KV Heads | KV Cache Size (relative) | Quality Impact | Models Using It |
| --- | --- | --- | --- | --- |
| MHA | nh (e.g., 64) | 1.0x | Baseline | GPT-3, Llama-1, Llama-2 7B/13B |
| GQA | nh/g (e.g., 8) | 1/g (e.g., 0.125x) | Negligible (<0.5% PPL) | Llama-2 70B, Llama-3, Mistral |
| MQA | 1 | 1/nh (e.g., 0.016x) | Small (1-2% PPL) | PaLM, Falcon, StarCoder |
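To make the head sharing concrete, here is a minimal pure-Python sketch of one decode step in which several query heads read the same cached KV head. The function names are hypothetical and nested lists stand in for tensors; setting the number of KV heads to 1 gives MQA, and setting it equal to the number of query heads gives MHA.

```python
import math


def kv_head_index(q_head, n_q_heads, n_kv_heads):
    """Map a query head to the KV head its group shares."""
    group_size = n_q_heads // n_kv_heads
    return q_head // group_size


def gqa_decode_step(q, K, V):
    """
    q:    [n_q_heads][d]        query vectors for the new token
    K, V: [n_kv_heads][seq][d]  cached keys and values
    Returns [n_q_heads][d] attention outputs.
    """
    n_q, n_kv, d = len(q), len(K), len(q[0])
    out = []
    for h in range(n_q):
        kv = kv_head_index(h, n_q, n_kv)
        # Scaled dot-product scores against the shared keys
        scores = [sum(a * b for a, b in zip(q[h], k_t)) / math.sqrt(d) for k_t in K[kv]]
        m = max(scores)                        # subtract the max for a stable softmax
        w = [math.exp(s - m) for s in scores]
        z = sum(w)
        out.append([sum(w[t] / z * V[kv][t][j] for t in range(len(w))) for j in range(d)])
    return out


# 4 query heads sharing 2 KV heads: heads 0-1 read KV head 0, heads 2-3 read KV head 1
K = [[[1.0, 0.0]], [[0.0, 1.0]]]
V = [[[1.0, 2.0]], [[3.0, 4.0]]]
q = [[1.0, 0.0]] * 4
print(gqa_decode_step(q, K, V))
```

Only `K` and `V` are cached, so the cache shrinks by the group size while the number of query heads (and the attention compute) stays the same.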

PagedAttention: Virtual Memory for KV Caches

Even with GQA, serving multiple concurrent requests requires careful memory management. The naive approach: pre-allocate a contiguous buffer of max_seq_len × n_kv × d_h elements for K and for V in every layer, for each request. If max_seq_len = 128K but the average request uses 4K tokens, you waste 97% of the allocated memory. With 32 concurrent requests on the GQA model above, that is about 1,280 GB reserved for roughly 40 GB of live KV data.

PagedAttention (vLLM, 2023) borrows the virtual memory concept from operating systems. Instead of contiguous buffers, KV cache entries are stored in fixed-size pages (blocks) of, say, 16 tokens. Pages are allocated on demand as the sequence grows, and reclaimed when the sequence ends. Pages need not be contiguous in physical GPU memory — a page table maps logical token positions to physical page addresses.

python
# PagedAttention: conceptual implementation
import torch

class PagedKVCache:
    def __init__(self, page_size=16, n_layers=80, n_kv_heads=8, d_head=128,
                 total_pages=8192, device="cuda"):
        self.page_size = page_size
        self.n_layers = n_layers
        self.n_kv_heads = n_kv_heads
        self.d_head = d_head

        # Physical page pool: pre-allocate the KV memory budget as pages.
        # With these defaults one page holds 80*2*8*16*128 FP16 values (5 MiB),
        # so 8192 pages reserve 40 GiB. Size total_pages to the HBM you can spare.
        self.total_pages = total_pages
        # Each page: (n_layers, 2, n_kv_heads, page_size, d_head), i.e. K and V
        self.page_pool = torch.zeros(
            self.total_pages, n_layers, 2, n_kv_heads, page_size, d_head,
            dtype=torch.float16, device=device
        )
        self.free_pages = list(range(self.total_pages))  # free list

        # Per-sequence page tables: seq_id → list of page indices
        self.page_tables = {}  # Dict[int, List[int]]

    def allocate_page(self, seq_id):
        """Allocate a new page for a sequence (called when it grows)."""
        if not self.free_pages:
            raise MemoryError("No free KV cache pages")
        page_idx = self.free_pages.pop()
        if seq_id not in self.page_tables:
            self.page_tables[seq_id] = []
        self.page_tables[seq_id].append(page_idx)
        return page_idx

    def free_sequence(self, seq_id):
        """Release all pages for a completed sequence."""
        for page_idx in self.page_tables.pop(seq_id, []):
            self.free_pages.append(page_idx)

    def get_kv(self, seq_id, layer, token_pos):
        """Read K/V for a specific token position."""
        page_num = token_pos // self.page_size
        offset = token_pos % self.page_size
        page_idx = self.page_tables[seq_id][page_num]
        k = self.page_pool[page_idx, layer, 0, :, offset, :]  # (n_kv_heads, d_head)
        v = self.page_pool[page_idx, layer, 1, :, offset, :]
        return k, v

    def utilization(self):
        used = self.total_pages - len(self.free_pages)
        return used / self.total_pages

# Per-token KV for this model: 2 × 80 × 8 × 128 × 2 bytes ≈ 0.31 MB
# Without paging: 32 requests × 128K tokens × 0.31 MB ≈ 1,280 GB reserved up front
# With paging: only allocate pages for actual tokens used
# If average length is 4K: 32 × 4K × 0.31 MB ≈ 40 GB, a 32x saving
Copy-on-write for beam search. PagedAttention enables copy-on-write semantics: when beam search creates branches, all beams share the same pages for the prefix (which is identical). A new page is only allocated when a beam diverges. For beam width 4 with 90% shared prefix, the beams store 0.9 + 4 × 0.1 = 1.3 sequence-lengths of KV instead of 4, which cuts KV cache memory by about 68% (the saving approaches 75% as the shared fraction approaches 100%). The same trick enables prefix caching: if multiple requests share the same system prompt, their KV cache pages for the prompt are shared (read-only), allocated only once.
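Copy-on-write can be sketched with reference counts on pages. The toy class below is an assumption-laden illustration, not vLLM's implementation: pages are plain integer ids, `CowPageTable` and `beam_savings` are hypothetical names, and a real cache would also copy the page contents when a shared page is privatized.

```python
class CowPageTable:
    """Reference-counted page sharing (toy model: pages are integer ids)."""

    def __init__(self, total_pages):
        self.free = list(range(total_pages))
        self.refcount = {}   # page id -> number of sequences using it
        self.tables = {}     # seq id -> list of page ids

    def _alloc(self):
        if not self.free:
            raise MemoryError("No free KV cache pages")
        page = self.free.pop()
        self.refcount[page] = 1
        return page

    def new_sequence(self, seq_id, n_pages):
        self.tables[seq_id] = [self._alloc() for _ in range(n_pages)]

    def fork(self, parent_id, child_id):
        """Child shares every parent page; only the refcounts change."""
        pages = list(self.tables[parent_id])
        for p in pages:
            self.refcount[p] += 1
        self.tables[child_id] = pages

    def write_last_page(self, seq_id):
        """Before writing, take a private page if the last one is shared."""
        pages = self.tables[seq_id]
        if self.refcount[pages[-1]] > 1:
            self.refcount[pages[-1]] -= 1
            pages[-1] = self._alloc()  # a real cache copies the contents here
        return pages[-1]

    def free_sequence(self, seq_id):
        for p in self.tables.pop(seq_id):
            self.refcount[p] -= 1
            if self.refcount[p] == 0:
                del self.refcount[p]
                self.free.append(p)

    def pages_in_use(self):
        return len(self.refcount)


def beam_savings(beam_width, shared_fraction):
    """Fraction of KV memory saved when beams share a common prefix."""
    stored = shared_fraction + beam_width * (1 - shared_fraction)
    return 1 - stored / beam_width
```

Forking three beams from a 9-page parent leaves 9 pages in use rather than 36; the first write to a shared last page allocates exactly one new page for that beam.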

SnapKV: Attention-Guided KV Compression

SnapKV (2024) made a key observation: in long-context attention, most tokens receive negligible attention weight. A small set of "important" tokens dominate the attention pattern, and this set can be identified cheaply from the attention pattern of a short observation window: the last few tokens of the prompt.

The algorithm:

1. Identify Observation Window
Take the last W tokens of the prompt as the "observation window" (W = 32 or 64). These tokens' attention patterns reveal which prefix tokens are important.
2. Vote on Important Tokens
For each attention head, compute attention scores from the observation window to all prefix tokens. Sum across the observation window positions to get a "vote" score per prefix token per head.
3. Select Top-K
Keep the top-k prefix tokens (per head) based on vote scores. Concatenate with the observation window tokens. Discard the rest.
4. Compressed KV Cache
The KV cache now contains only k + W tokens instead of S tokens. For k = 1024 and W = 64 on a 128K context: that is a 117x compression ratio.
python
import torch

def snapkv_compress(K, V, attn_scores, k=1024, window=64):
    """
    SnapKV: attention-guided KV cache compression.
    K, V: (n_heads, seq_len, d_head)
    attn_scores: (n_heads, window, seq_len), attention from the last `window` tokens
    """
    n_heads, seq_len, d_head = K.shape
    prefix_len = seq_len - window
    k = min(k, prefix_len)  # topk fails if k exceeds the prefix length

    # Step 1: attention scores from observation window to prefix tokens
    prefix_scores = attn_scores[:, :, :prefix_len]  # (n_heads, window, prefix_len)

    # Step 2: vote by summing attention across observation window positions
    votes = prefix_scores.sum(dim=1)  # (n_heads, prefix_len)

    # Step 3: top-k selection per head
    _, top_indices = votes.topk(k, dim=-1)  # (n_heads, k)
    top_indices, _ = top_indices.sort(dim=-1)  # maintain order

    # Step 4: gather selected K/V + keep observation window
    K_prefix = K.gather(1, top_indices.unsqueeze(-1).expand(-1, -1, d_head))
    V_prefix = V.gather(1, top_indices.unsqueeze(-1).expand(-1, -1, d_head))

    K_window = K[:, prefix_len:, :]  # last `window` tokens — always kept
    V_window = V[:, prefix_len:, :]

    K_compressed = torch.cat([K_prefix, K_window], dim=1)  # (n_heads, k+window, d_head)
    V_compressed = torch.cat([V_prefix, V_window], dim=1)

    return K_compressed, V_compressed  # seq_len reduced from S to k+window

H2O: Heavy Hitter Oracle

H2O (2023) takes a different approach to KV eviction. Instead of a one-shot compression at prompt time, H2O maintains a running eviction policy during generation. It tracks the cumulative attention score for each token across all generation steps. Tokens that consistently receive high attention ("heavy hitters") are kept. Tokens with low cumulative scores are evicted.

python
import torch

class H2OCache:
    """Heavy Hitter Oracle: cumulative attention-based eviction."""
    def __init__(self, max_size=2048, window_size=256):
        self.max_size = max_size
        self.window_size = window_size  # recent tokens always kept
        self.cumulative_scores = None   # (cache_size,) cumulative attention per cached slot

    def update_and_evict(self, attn_weights, K, V):
        """
        attn_weights: (n_heads, 1, current_cache_size), attention from the new token
        K, V: (n_heads, current_cache_size, d_head), current KV cache tensors
        """
        current_size = K.shape[1]
        # Average across heads and accumulate in FP32
        scores = attn_weights.float().mean(dim=0).reshape(-1)  # (cache_size,)

        # Grow the score buffer for slots appended since the last call, then accumulate
        if self.cumulative_scores is None:
            self.cumulative_scores = torch.zeros_like(scores)
        n_new = current_size - self.cumulative_scores.shape[0]
        if n_new > 0:
            self.cumulative_scores = torch.cat(
                [self.cumulative_scores, scores.new_zeros(n_new)])
        self.cumulative_scores += scores

        # Check if eviction needed
        if current_size <= self.max_size:
            return K, V

        # Protect recent tokens (sliding window); evict the lowest-scoring older token
        n_evictable = current_size - self.window_size
        victim = torch.argmin(self.cumulative_scores[:n_evictable])
        keep_mask = torch.ones(current_size, dtype=torch.bool, device=K.device)
        keep_mask[victim] = False

        # Compact the scores with the cache so slot i keeps describing the same token
        self.cumulative_scores = self.cumulative_scores[keep_mask]
        return K[:, keep_mask], V[:, keep_mask]

The key insight: in practice, attention patterns follow a power law. A small set of tokens (typically the first few tokens, punctuation, and key content words) receive the lion's share of attention. H2O exploits this by keeping the heavy hitters and evicting the long tail.

StreamingLLM: Attention Sinks + Sliding Window

StreamingLLM (2023) discovered a remarkable phenomenon: the first few tokens in any sequence always receive disproportionately high attention, regardless of their content. These are called attention sinks. The hypothesis is that softmax needs somewhere to dump attention mass when no token is particularly relevant, and position 0 becomes the default "sink."

StreamingLLM's algorithm is dead simple:

Keep First 4 Tokens
Always retain the first 4 tokens as attention sinks. These anchor the attention pattern regardless of context length.
+
Sliding Window of Last W Tokens
Keep the most recent W tokens (e.g., W = 2044 for a total cache of 2048). Evict everything in between.
=
Infinite Streaming
The model can process infinite-length streams with a fixed-size cache. Quality degrades gracefully — you lose access to middle-context information, but local coherence is maintained.

StreamingLLM does NOT help with tasks that require accessing information from the evicted middle (e.g., "what was mentioned in paragraph 5 of a 100-paragraph document?"). It is designed for streaming use cases: chat, real-time transcription, monitoring dashboards.
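The sink-plus-window rule above fits in a few lines. This is a minimal sketch of the selection step only; `streaming_keep_indices` is a hypothetical helper name. The StreamingLLM paper also assigns positions by slot within the cache rather than by original token index, which this sketch leaves out.

```python
def streaming_keep_indices(cache_len, n_sink=4, window=2044):
    """Indices of cache entries to keep: the first n_sink tokens plus the last `window`."""
    if cache_len <= n_sink + window:
        return list(range(cache_len))          # nothing to evict yet
    sinks = list(range(n_sink))                # attention sinks at the start
    recent = list(range(cache_len - window, cache_len))
    return sinks + recent


kept = streaming_keep_indices(10_000)
print(len(kept), kept[:5], kept[-1])           # 2048 [0, 1, 2, 3, 7956] 9999
```

The cache size is bounded by `n_sink + window` no matter how long the stream runs, which is exactly why the evicted middle becomes unreachable.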

Quantized KV Caches

Orthogonal to eviction strategies, you can quantize the KV cache itself from FP16 to INT8 or INT4. This is different from weight quantization because:

1. New KV entries arrive every token. You cannot pre-compute quantization parameters offline. You need an online quantization scheme that is fast enough to run per-token.

2. The error budget is different. Weight quantization errors are fixed: the same error on every inference. KV quantization errors persist within a sequence: an error in the cached K at position 100 affects every subsequent attention computation for that sequence. This makes KV quantization more sensitive.

3. The distribution is often friendlier. Within a head, KV cache values tend to be well behaved (post-RoPE, K values are roughly Gaussian), although keys can carry a few large-magnitude channels, which is why schemes such as KIVI quantize keys per channel and values per token. INT8 KV quantization with per-head scales typically causes <0.1% perplexity increase. INT4 requires per-token-group scales and causes 0.3-1% increase.

python
import torch

def quantize_kv_cache(K, V, bits=8):
    """
    Per-head quantization of KV cache.
    K, V: (n_heads, seq_len, d_head) in FP16
    Returns quantized tensors + per-head scales.
    For bits < 8 the codes are still stored one per int8 (unpacked).
    """
    q_max = 2 ** (bits - 1) - 1  # symmetric: -127..127 for INT8

    # Per-head scale (one scale per head)
    K_absmax = K.abs().amax(dim=(1, 2), keepdim=True)  # (n_heads, 1, 1)
    V_absmax = V.abs().amax(dim=(1, 2), keepdim=True)

    # Keep scales in FP32 and clamp so an all-zero head cannot cause division by zero
    K_scale = (K_absmax.float() / q_max).clamp(min=1e-8)
    V_scale = (V_absmax.float() / q_max).clamp(min=1e-8)

    K_q = (K / K_scale).round().clamp(-q_max, q_max).to(torch.int8)
    V_q = (V / V_scale).round().clamp(-q_max, q_max).to(torch.int8)

    # Memory: INT8 = 1 byte vs FP16 = 2 bytes → 2x compression
    # INT4 (packed) = 0.5 bytes → 4x compression
    return K_q, V_q, K_scale, V_scale

def dequantize_for_attention(K_q, K_scale, Q):
    """Dequantize K on-the-fly during attention computation."""
    # Cast to Q's dtype: multiplying an FP16 Q by a float32 K would raise a dtype error
    K_deq = K_q.to(Q.dtype) * K_scale.to(Q.dtype)  # dequant before Q @ K^T
    attn = Q @ K_deq.transpose(-1, -2)
    return attn
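The INT4 case mentioned above needs finer-grained scales and nibble packing. The following is a minimal pure-Python sketch of symmetric group-wise 4-bit quantization on a flat list of values; the function names and the group size of 32 are assumptions for illustration, and a production kernel would do this on the GPU with vectorized packing.

```python
def quantize_int4_groups(values, group_size=32):
    """Symmetric 4-bit quantization with one scale per group of `group_size` values."""
    q_max = 7  # symmetric range -7..7
    codes, scales = [], []
    for start in range(0, len(values), group_size):
        group = values[start:start + group_size]
        absmax = max(abs(x) for x in group)
        scale = absmax / q_max if absmax > 0 else 1.0
        scales.append(scale)
        codes.extend(max(-q_max, min(q_max, round(x / scale))) for x in group)
    return codes, scales


def pack_int4(codes):
    """Pack two signed 4-bit codes per byte, low nibble first."""
    if len(codes) % 2:
        codes = codes + [0]
    return bytes((codes[i] & 0xF) | ((codes[i + 1] & 0xF) << 4)
                 for i in range(0, len(codes), 2))


def unpack_int4(packed, n):
    """Inverse of pack_int4: recover the first n signed codes."""
    out = []
    for byte in packed:
        for nib in (byte & 0xF, byte >> 4):
            out.append(nib - 16 if nib >= 8 else nib)  # sign-extend the nibble
    return out[:n]


def dequantize_int4_groups(codes, scales, group_size=32):
    return [c * scales[i // group_size] for i, c in enumerate(codes)]
```

Smaller groups track local magnitude better but store more scales; the per-group rounding error is at most half a scale step.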

Combining Techniques: The Full Strategy

In practice, frontier labs combine multiple techniques. Here is the typical stack for long-context serving:

| Layer | Technique | Compression Ratio | Quality Impact |
| --- | --- | --- | --- |
| Architecture | GQA (8 groups) | 8x | Negligible (trained in) |
| Memory management | PagedAttention | 2-4x (reduces waste) | Zero (exact) |
| Precision | INT8 KV cache | 2x | <0.1% PPL increase |
| Eviction | SnapKV (top-k=2048) | Up to 64x on 128K | ~1-3% quality loss on long-range tasks |
| Total | All combined | ~128-512x | Depends on task type |
The practical impact. A Llama-2 70B with MHA at 128K FP16: 320 GB KV cache per sequence. With the full stack (GQA + PagedAttention + INT8 + SnapKV top-4096): approximately 0.6 GB per sequence (320 GB / 8 / 2 / 32). That is roughly 512x compression. Enough to serve 100+ concurrent 128K conversations on a single 8×H100 node.
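The way these factors multiply can be checked with a short calculation. This is a rough sketch under simplifying assumptions: `stacked_kv_gib` is a hypothetical helper, it ignores the SnapKV observation window and the storage for quantization scales, and PagedAttention is left out because it reduces allocation waste rather than per-sequence size.

```python
def stacked_kv_gib(n_layers=80, n_q_heads=64, d_head=128, seq_len=131072,
                   gqa_kv_heads=None, kv_bytes=2, snapkv_keep=None):
    """Per-sequence KV cache in GiB after applying the chosen techniques."""
    n_kv = gqa_kv_heads or n_q_heads                      # GQA shrinks the head count
    tokens = min(seq_len, snapkv_keep) if snapkv_keep else seq_len  # eviction shrinks tokens
    return 2 * n_layers * n_kv * tokens * d_head * kv_bytes / 1024**3


baseline = stacked_kv_gib()                                           # MHA, FP16
gqa = stacked_kv_gib(gqa_kv_heads=8)
gqa_int8 = stacked_kv_gib(gqa_kv_heads=8, kv_bytes=1)
full_4096 = stacked_kv_gib(gqa_kv_heads=8, kv_bytes=1, snapkv_keep=4096)
full_2048 = stacked_kv_gib(gqa_kv_heads=8, kv_bytes=1, snapkv_keep=2048)

for name, gib in [("MHA FP16", baseline), ("+GQA", gqa), ("+INT8", gqa_int8),
                  ("+SnapKV k=4096", full_4096), ("+SnapKV k=2048", full_2048)]:
    print(f"{name:>16}: {gib:8.4f} GiB  ({baseline / gib:.0f}x)")
```

Because each technique acts on a different factor of the memory formula (heads, bytes, tokens), the ratios multiply rather than add.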
KV Cache Memory vs Sequence Length

Interactive figure: KV cache memory usage across different optimization strategies as sequence length grows, showing each technique's individual and combined impact.

The "Barbarians at the Gate" Perspective

The phrase "Barbarians at the Gate" in the context of KV caches refers to the growing realization that the attention mechanism itself may need to change. All the techniques above are engineering patches on a mechanism whose memory is fundamentally O(N): linear in sequence length, per layer. State-space models (Mamba, RWKV), linear attention variants, and hybrid architectures are the "barbarians": they replace the KV cache with a fixed-size recurrent state, so memory is O(1) in sequence length.

The counterargument: the KV cache is not just overhead. It IS the model's explicit memory. Eviction techniques lose information. Recurrent states compress information lossily. The KV cache is a perfect, lossless memory of every token the model has seen. The frontier lab engineer's job is to make this memory affordable, not to throw it away.

The hybrid future. The likely outcome is hybrid architectures: a few attention layers with full KV cache for tasks requiring precise recall (retrieval, citation, code), and many SSM/linear-attention layers with constant-size state for bulk processing. Jamba (AI21) and Zamba are early examples. The KV cache optimization techniques in this chapter remain critical for the attention layers in these hybrids.

The Engineer's Cheat Sheet

| Scenario | Best KV Cache Strategy | Why |
| --- | --- | --- |
| Chatbot, 4K context | GQA + PagedAttention | Small cache, no compression needed |
| Document QA, 32K context | GQA + Paged + INT8 KV | Moderate compression, negligible quality loss (<0.1% PPL) |
| Book summarization, 128K | GQA + Paged + INT8 + SnapKV (k=4096) | Heavy compression needed; SnapKV preserves retrieval-relevant tokens |
| Infinite streaming (chat) | StreamingLLM (sinks + window) | Fixed memory, unbounded stream length, but no recall of old turns |
| Serving 1000+ concurrent users | GQA + Paged + INT4 KV + prefix caching | Maximize throughput; shared system prompt saves 30%+ memory |
| Beam search (width 4) | PagedAttention with copy-on-write | Beams share prefix pages, up to 75% savings |
You are serving a 70B GQA model (8 KV heads, 80 layers, d=128) to 50 concurrent users with 8K context. Each user's KV cache in FP16 is ~2.5 GB. Total: 125 GB, which exceeds your 80 GB A100. You apply INT8 KV quantization (2x compression). Now each user needs ~1.25 GB, total 62.5 GB. It fits! But a PM asks for 128K context support without adding GPUs. What is the best next step?

Chapter 9: Training at Scale

You have a 70-billion parameter model. Each parameter is a BF16 number: 2 bytes. The model weights alone occupy 70B × 2 = 140 GB. A single H100 has 80 GB of HBM. The model does not even fit on one GPU — and we have not yet accounted for optimizer states, gradients, or activations.

This is the fundamental problem of large-scale training: the model, its optimizer state, its gradients, and its activations are collectively far too large for any single device. The solution is to split the work across many GPUs. But how you split matters enormously. A naive approach can leave 90% of your GPUs idle, waiting for communication. A well-engineered approach can achieve 40-55% of peak FLOPS utilization — and the difference between 20% and 50% MFU on a 2048-GPU cluster is millions of dollars in wasted compute.

This chapter covers the four pillars of distributed training: data parallelism, fully sharded data parallelism (FSDP), tensor parallelism, and pipeline parallelism. By the end, you will be able to derive the optimal parallelism configuration for any model-cluster pair.

The Memory Budget: What Goes Where

Before choosing a parallelism strategy, you must know what consumes memory. Let us derive the full breakdown for a 70B parameter model trained in BF16 with AdamW optimizer:

Model weights: 70B × 2 bytes (BF16) = 140 GB

Gradients: 70B × 2 bytes (BF16) = 140 GB

Optimizer states (AdamW):
  First moment (m): 70B × 4 bytes (FP32) = 280 GB
  Second moment (v): 70B × 4 bytes (FP32) = 280 GB
  FP32 master weights: 70B × 4 bytes = 280 GB
  Total optimizer: 840 GB

Grand total (excluding activations): 140 + 140 + 840 = 1,120 GB

Per-GPU HBM (H100): 80 GB
Minimum GPUs just for static memory: 1120 / 80 = 14 GPUs
With activations and fragmentation (~1.5x): ~21 GPUs minimum

Notice the optimizer dominates: 840 GB out of 1,120 GB. This is why ZeRO and FSDP focus primarily on sharding optimizer states.

The 16x rule of thumb. A model with P parameters in BF16 needs roughly 16P bytes of memory for training with AdamW. For 70B parameters: 70B × 16 = 1,120 GB. Memorize this multiplier — it lets you instantly estimate cluster requirements. Inference needs only 2P bytes (weights only) or 4P for FP32.
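The 16-bytes-per-parameter rule can be turned into a quick estimator. This is a minimal sketch using decimal gigabytes as in the breakdown above; the function names and the `overhead` parameter are assumptions for illustration.

```python
import math


def training_memory_gb(n_params, weight_bytes=2, grad_bytes=2, optimizer_bytes=12):
    """Static training memory in GB: BF16 weights and gradients, FP32 AdamW state."""
    gb = 1e9
    return {
        "weights": n_params * weight_bytes / gb,
        "gradients": n_params * grad_bytes / gb,
        "optimizer": n_params * optimizer_bytes / gb,   # m, v, FP32 master weights
        "total": n_params * (weight_bytes + grad_bytes + optimizer_bytes) / gb,
    }


def min_gpus(total_gb, hbm_gb=80, overhead=1.0):
    """Smallest GPU count whose combined HBM covers total_gb times an overhead factor."""
    return math.ceil(total_gb * overhead / hbm_gb)


mem = training_memory_gb(70_000_000_000)
print(mem)                                   # total: 1120.0 GB
print(min_gpus(mem["total"]))                # 14
print(min_gpus(mem["total"], overhead=1.5))  # 21
```

Changing `optimizer_bytes` shows how much of the budget the optimizer drives: an 8-bit optimizer state, for example, would shrink the 12-byte term considerably.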

Strategy 1: Data Parallelism (DDP)

Data Parallelism is the simplest distributed strategy. Every GPU holds a complete copy of the model. Each GPU processes a different mini-batch. After the forward and backward passes, gradients are synchronized across all GPUs via an all-reduce operation, and every GPU applies the same optimizer step.

1. Replicate
Every GPU holds full model weights, optimizer states, and gradients. For 70B in BF16, each GPU needs 1,120 GB. This clearly does not work for large models — DDP requires the model to fit on a single GPU.
2. Partition Data
The training dataset is split into N shards (one per GPU). Each GPU processes its own mini-batch of size B/N, where B is the global batch size.
3. Forward + Backward
Each GPU independently computes the forward pass and backward pass on its local mini-batch. Gradients are computed locally.
4. All-Reduce Gradients
Gradients are averaged across all N GPUs using an all-reduce collective. After this, every GPU has identical averaged gradients.
5. Optimizer Step
Every GPU applies the same optimizer update. Since gradients are identical, all replicas stay in sync. No model divergence.

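DDP works beautifully for models that fit on one GPU together with their gradients and optimizer states. At 16 bytes per parameter, that means roughly 4B parameters or fewer on an 80 GB H100 once activations are included. The communication cost is a single all-reduce of the gradient tensor per step.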
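The five steps above can be simulated without any GPUs. The sketch below uses a one-parameter linear model and plain Python lists; `grad_mse` and `ddp_step` are hypothetical names, and the all-reduce is modeled as a simple average. It illustrates why averaged shard gradients reproduce the full-batch gradient when shards are equal in size.

```python
def grad_mse(w, xs, ys):
    """d/dw of mean((w*x - y)^2) over the given samples."""
    return sum(2 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)


def ddp_step(w, xs, ys, n_gpus, lr=0.01):
    """One DDP step; returns the updated weight held by each replica."""
    shard = len(xs) // n_gpus
    # Steps 2-3: each replica computes a gradient on its own data shard
    local_grads = [
        grad_mse(w, xs[i * shard:(i + 1) * shard], ys[i * shard:(i + 1) * shard])
        for i in range(n_gpus)
    ]
    # Step 4: all-reduce (average) so every replica sees the same gradient
    avg = sum(local_grads) / n_gpus
    # Step 5: identical update on every replica keeps them in sync
    return [w - lr * avg for _ in range(n_gpus)]


xs = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]
ys = [2.0 * x for x in xs]
replicas = ddp_step(0.5, xs, ys, n_gpus=4)
single = 0.5 - 0.01 * grad_mse(0.5, xs, ys)
print(replicas, single)
```

All replicas end with the same weight, and that weight matches a single-device step on the global batch up to floating-point rounding.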

All-Reduce: The Core Communication Primitive

The all-reduce operation takes a tensor that exists on every GPU, sums (or averages) them, and distributes the result back to every GPU. The efficient implementation is the ring all-reduce, which works in two phases:

Phase 1: Reduce-Scatter. Each GPU sends 1/N of its gradient tensor to the next GPU in the ring. After N-1 steps, each GPU holds the fully reduced sum for its 1/N shard.

Phase 2: All-Gather. Each GPU sends its reduced shard around the ring. After N-1 steps, every GPU holds the complete reduced result.

The total data transferred per GPU in a ring all-reduce is:

Ring All-Reduce Communication Volume per GPU:

Reduce-Scatter: each GPU sends (N-1)/N × D bytes
All-Gather: each GPU sends (N-1)/N × D bytes

Total per GPU = 2 × (N-1)/N × D

Where D = total gradient size in bytes, N = number of GPUs

Key insight: as N → ∞, the per-GPU volume approaches 2D.
This is independent of cluster size! Ring all-reduce scales nearly perfectly.
Why ring all-reduce is magical. A naive approach where every GPU sends its gradient to a single "parameter server" would require that server to receive N×D bytes — a bottleneck that grows linearly with cluster size. Ring all-reduce keeps per-GPU communication at 2D regardless of N. This is why DDP scales to thousands of GPUs for models that fit on a single device.
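The two phases can be traced in a small simulation. This is a sketch on Python lists, assuming the buffer length is divisible by the number of simulated GPUs; `ring_all_reduce` is a hypothetical name and no real communication library is involved. It counts the elements each GPU sends so the 2 × (N-1)/N × D volume can be checked directly.

```python
def ring_all_reduce(buffers):
    """
    buffers: list of N equal-length lists, one per simulated GPU.
    Returns (reduced_buffers, elements_sent_per_gpu).
    """
    n = len(buffers)
    chunk = len(buffers[0]) // n
    data = [list(b) for b in buffers]
    sent = [0] * n

    def sl(c):
        return slice(c * chunk, (c + 1) * chunk)

    # Phase 1: reduce-scatter. In step s, GPU i sends chunk (i - s) mod n
    # to GPU i+1, which adds it to its own copy of that chunk.
    for s in range(n - 1):
        msgs = [(i, (i - s) % n, data[i][sl((i - s) % n)]) for i in range(n)]
        for i, c, payload in msgs:
            dst = (i + 1) % n
            data[dst][sl(c)] = [a + b for a, b in zip(data[dst][sl(c)], payload)]
            sent[i] += chunk
    # Now GPU i holds the fully reduced chunk (i + 1) mod n.

    # Phase 2: all-gather. In step s, GPU i forwards chunk (i + 1 - s) mod n
    # to GPU i+1, which overwrites its stale copy.
    for s in range(n - 1):
        msgs = [(i, (i + 1 - s) % n, data[i][sl((i + 1 - s) % n)]) for i in range(n)]
        for i, c, payload in msgs:
            data[(i + 1) % n][sl(c)] = payload
            sent[i] += chunk
    return data, sent


grads = [[float(i + 10 * j) for i in range(8)] for j in range(4)]
reduced, sent = ring_all_reduce(grads)
print(reduced[0], sent)
```

With N = 4 and D = 8 elements, each GPU sends 12 elements, which is 2 × (3/4) × 8, and every GPU ends with the same summed buffer.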

Strategy 2: Fully Sharded Data Parallelism (FSDP / ZeRO)

DDP requires every GPU to hold the full model, gradients, and optimizer states. For a 70B model, that is 1,120 GB per GPU — impossible. FSDP (PyTorch's implementation of DeepSpeed ZeRO) fixes this by sharding everything across GPUs. Each GPU stores only 1/N of the weights, 1/N of the gradients, and 1/N of the optimizer states.

But if each GPU only has 1/N of the weights, how does it compute the forward pass? It all-gathers the full weights just before each layer's forward pass, uses them, then discards them. During the backward pass, it all-gathers again, computes gradients, then reduce-scatters to keep only its 1/N shard of the gradients.

1. Shard Everything
Each GPU stores 1/N of parameters, 1/N of gradients, 1/N of optimizer states. For 70B on 64 GPUs: 1120/64 = 17.5 GB per GPU. Fits on an 80 GB H100 with room for activations.
2. Before Forward (per layer)
All-gather the full weight tensor for this layer from all GPUs. Now every GPU has the full layer weights temporarily in memory.
3. Forward Pass
Compute forward pass for this layer using the gathered weights. Then discard the non-local weight shards to free memory.
4. Before Backward (per layer)
All-gather weights again for this layer (they were discarded after forward).
5. Backward Pass
Compute gradients. Reduce-scatter: each GPU keeps only its 1/N shard of the gradients. Discard non-local weights again.
6. Optimizer Step
Each GPU updates only its 1/N shard of parameters using its 1/N shard of gradients and optimizer states. Memory-efficient.

The communication cost is higher than DDP. Let us derive it:

FSDP Communication per GPU per Step:

Forward pass: 1 all-gather per layer = total (N-1)/N × D bytes
Backward pass: 1 all-gather + 1 reduce-scatter per layer
  = (N-1)/N × D + (N-1)/N × D bytes

Total = 3 × (N-1)/N × D ≈ 3D for large N

Compare to DDP: 2 × (N-1)/N × D ≈ 2D

FSDP communicates 50% more than DDP, but uses 1/N the memory.
ZeRO Stages explained. DeepSpeed introduced three stages of memory optimization. ZeRO-1 shards only optimizer states (saves ~4x memory). ZeRO-2 shards optimizer states + gradients (saves ~8x). ZeRO-3 (= FSDP) shards everything including parameters (saves ~N×). Most large-scale training uses ZeRO-3 / FSDP because the memory savings enable training models that could never fit on a single GPU.
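The memory effect of each ZeRO stage follows from which of the three terms is divided by N. The sketch below assumes the same 2 + 2 + 12 bytes-per-parameter breakdown used earlier; both function names are hypothetical, and activations are ignored.

```python
def zero_memory_per_gpu_gb(n_params, n_gpus, stage):
    """Static memory per GPU in GB for ZeRO stage 0 (plain DDP) through 3 (FSDP)."""
    weights = 2 * n_params        # BF16 weights
    grads = 2 * n_params          # BF16 gradients
    optimizer = 12 * n_params     # FP32 m, v, master weights
    if stage >= 1:
        optimizer /= n_gpus       # ZeRO-1: shard optimizer states
    if stage >= 2:
        grads /= n_gpus           # ZeRO-2: also shard gradients
    if stage >= 3:
        weights /= n_gpus         # ZeRO-3: also shard parameters
    return (weights + grads + optimizer) / 1e9


def comm_volume_per_gpu(d_bytes, n_gpus, fsdp):
    """Bytes sent per GPU per step: 2 collective passes for DDP, 3 for FSDP."""
    passes = 3 if fsdp else 2
    return passes * (n_gpus - 1) / n_gpus * d_bytes


for stage in range(4):
    print(stage, round(zero_memory_per_gpu_gb(70e9, 64, stage), 2))
```

For 70B parameters on 64 GPUs this gives 1,120 GB, about 293 GB, about 155 GB, and 17.5 GB for stages 0 through 3, which is where the approximate 4x, 8x, and N-fold savings come from.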

Strategy 3: Tensor Parallelism (Megatron-Style)

FSDP shards weights but requires an all-gather before every layer. Tensor parallelism takes a different approach: it splits each individual layer's computation across GPUs so that no GPU ever needs the full weight matrix.

Consider a linear layer Y = XW, where X is the input [B, dmodel] and W is the weight matrix [dmodel, dhidden]. Tensor parallelism splits W column-wise across T GPUs:

Column-Parallel Linear (used in FFN first layer and Q/K/V projections):

W = [W1 | W2 | ... | WT]   split columns, each Wi is [dmodel, dhidden/T]

GPU i computes: Yi = X × Wi   output shape [B, dhidden/T]

Each GPU has a partial result. No communication needed for column-parallel forward!
(Each GPU gets the full X via an identity or all-gather, depending on the preceding layer.)

Row-Parallel Linear (used in FFN second layer and attention output projection):

W = [W1; W2; ...; WT]   split rows, each Wi is [dhidden/T, dmodel]

GPU i computes: Yi = Xi × Wi   Xi is the local partition from previous layer

Y = Y1 + Y2 + ... + YT   all-reduce to sum partial products

In a transformer, Megatron-LM chains column-parallel and row-parallel layers to minimize communication. Each transformer block requires exactly 2 all-reduces in the forward pass (one after the attention block, one after the FFN) and 2 all-reduces in the backward pass.

Attention head parallelism. Multi-head attention is naturally parallelizable: if you have H heads and T GPUs, each GPU computes H/T heads. The Q, K, V projections are column-parallel (each GPU produces its heads' Q, K, V), and the output projection is row-parallel (each GPU contributes its heads' output, summed via all-reduce). This is why tensor parallelism degrees are almost always divisors of the head count.
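The column-then-row pairing can be verified numerically on a tiny FFN. This is a pure-Python sketch with nested lists standing in for tensors; all function names are hypothetical, ReLU stands in for the real activation, and the final sum plays the role of the all-reduce.

```python
def matmul(A, B):
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*B)] for row in A]


def relu(M):
    return [[max(0.0, v) for v in row] for row in M]


def split_columns(W, t):
    step = len(W[0]) // t
    return [[row[i * step:(i + 1) * step] for row in W] for i in range(t)]


def split_rows(W, t):
    step = len(W) // t
    return [W[i * step:(i + 1) * step] for i in range(t)]


def tp_mlp(X, W1, W2, t):
    """FFN with W1 column-parallel and W2 row-parallel across t simulated GPUs."""
    partials = []
    for W1_i, W2_i in zip(split_columns(W1, t), split_rows(W2, t)):
        H_i = relu(matmul(X, W1_i))         # column-parallel: no communication needed
        partials.append(matmul(H_i, W2_i))  # row-parallel: partial product on this GPU
    # The only communication: all-reduce (sum) of the partial products
    rows, cols = len(X), len(W2[0])
    return [[sum(p[r][c] for p in partials) for c in range(cols)] for r in range(rows)]


X = [[1.0, -2.0, 3.0, 0.5], [0.0, 1.0, -1.0, 2.0]]
W1 = [[1.0, 0.0, -1.0, 2.0], [0.5, 1.0, 0.0, -1.0],
      [-1.0, 2.0, 1.0, 0.0], [0.0, -1.0, 2.0, 1.0]]
W2 = [[1.0, 2.0], [0.0, -1.0], [3.0, 1.0], [-2.0, 0.5]]
print(tp_mlp(X, W1, W2, t=2))
print(matmul(relu(matmul(X, W1)), W2))  # single-device reference
```

The elementwise activation is what makes this pairing work: each GPU can apply it to its own column slice of the hidden state without seeing the other slices.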

Strategy 4: Pipeline Parallelism

Pipeline parallelism splits the model vertically by layers. If you have 80 transformer layers and 8 pipeline stages, each GPU holds 10 consecutive layers. GPU 0 runs layers 0-9, GPU 1 runs layers 10-19, and so on.

The problem: naive pipelining has catastrophic bubble overhead. GPU 7 sits idle while GPUs 0-6 process the forward pass through their layers. Then GPU 0 sits idle while GPUs 1-7 finish the forward pass and start the backward pass.

The solution is microbatching. Instead of sending one large batch through the pipeline, split it into M microbatches. As soon as GPU 0 finishes the forward pass for microbatch 1, it starts microbatch 2 while GPU 1 works on microbatch 1. This fills the pipeline.

Pipeline Bubble Fraction:

With P pipeline stages and M microbatches:

Bubble fraction = (P - 1) / (P - 1 + M)

Example: P=8 stages, M=32 microbatches
Bubble = (8-1) / (8-1+32) = 7/39 = 17.9%

To get bubble < 5%, you need M ≥ 19×(P-1):
M ≥ 19 × 7 = 133 microbatches   (often impractical)

Rule of thumb: M ≥ 4×P gives bubble < 20%.
1F1B schedule. The one-forward-one-backward schedule interleaves forward and backward passes of different microbatches to reduce peak memory. Instead of running all forward passes first (which keeps all M microbatches' activations in memory), 1F1B starts backward passes as soon as possible, freeing activation memory. This reduces peak activation memory from O(M) to O(P), where P is the number of pipeline stages. Megatron-LM and DeepSpeed both use 1F1B.
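The bubble formula is simple enough to explore in code. This is a minimal sketch of the expression above; `min_microbatches` is a hypothetical helper that searches for the smallest M meeting a target bubble.

```python
def bubble_fraction(p, m):
    """Idle fraction of a pipeline with p stages and m microbatches."""
    return (p - 1) / (p - 1 + m)


def min_microbatches(p, max_bubble):
    """Smallest m whose bubble fraction is at most max_bubble."""
    m = 1
    while bubble_fraction(p, m) > max_bubble:
        m += 1
    return m


print(bubble_fraction(8, 32))        # 7/39, about 0.179
print(min_microbatches(8, 0.05))     # 133
print(bubble_fraction(4, 16))        # 3/19, about 0.158
```

Since the bubble depends only on the ratio of P - 1 to M, halving the number of pipeline stages is often cheaper than doubling the number of microbatches.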

The Parallelism Decision Tree

Given N GPUs and a model of size M parameters, which strategy do you use? Here is the decision framework that frontier labs actually follow:

| Condition | Strategy | Why |
| --- | --- | --- |
| Model fits on 1 GPU with optimizer | DDP (pure data parallel) | Simplest, lowest communication overhead (2D per step) |
| Model fits on 1 GPU but optimizer does not | FSDP / ZeRO-1 or ZeRO-2 | Shard optimizer states across GPUs, keep model replicated |
| Model does not fit on 1 GPU | FSDP (ZeRO-3) + TP within node | Shard everything. Use TP=8 within each 8-GPU node (fast NVLink), FSDP across nodes |
| Model is very large (>100B) and cluster is huge | 3D parallelism: TP + PP + DP | TP within node, PP across a few nodes, DP across the rest. Minimizes cross-node communication. |
The golden rule of parallelism placement. Always assign the parallelism dimension with the most communication to the fastest interconnect. Tensor parallelism communicates after every layer (2 all-reduces per block) — put it within a node where NVLink provides 900 GB/s. Pipeline parallelism communicates once per microbatch (point-to-point activation transfer) — put it across nearby nodes. Data parallelism communicates once per step (all-reduce of gradients) — it can span the entire cluster because it is the most communication-efficient.

Worked Example: LLaMA 3 70B on 2048 H100s

Let us derive the optimal parallelism configuration for training LLaMA 3 70B on a cluster of 2048 NVIDIA H100 GPUs, each with 80 GB HBM, connected by NVLink within 8-GPU nodes and InfiniBand across nodes.

Step 1: Memory Analysis
Parameters: 70B, BF16 training
Weights: 70B × 2 = 140 GB
Gradients: 140 GB
Optimizer (AdamW, FP32): 70B × 12 = 840 GB
Total static: 1,120 GB

Step 2: Choose Tensor Parallelism (TP)
70B model has 80 layers, 64 heads, d_model=8192, d_FFN=28672
TP must divide head count: TP ∈ {1, 2, 4, 8, 16, 32, 64}
Within an 8-GPU NVLink domain: TP=8
Per-GPU weight memory with TP=8: 140/8 = 17.5 GB ✓

Step 3: Choose Pipeline Parallelism (PP)
With TP=8, each "TP group" of 8 GPUs shares one model replica.
80 layers, try PP=4: each stage has 80/4 = 20 layers.
Each pipeline stage's weight shard: 140 × (20/80) / 8 = 4.375 GB per GPU
Optimizer shard per GPU: 840 × (20/80) / (DP × 8) — need to determine DP first.

Step 4: Choose Data Parallelism (DP)
Total GPUs: 2048
DP = Total / (TP × PP) = 2048 / (8 × 4) = 64

With FSDP (ZeRO-3) across the DP dimension (each GPU's TP/PP shard is split a further 64 ways):
Optimizer per GPU: 840 × (20/80) / (64 × 8) = 0.41 GB
Gradients per GPU: 140 × (20/80) / (64 × 8) = 0.07 GB
Weights (gathered for compute): 140 × (20/80) / 8 = 4.375 GB
Total static per GPU: ~4.9 GB (out of 80 GB, leaving plenty for activations)

Step 5: Verify Activation Memory
Per-layer activation: B_micro × seq × d_model × 2 bytes
With B_micro=1, seq=8192: 1 × 8192 × 8192 × 2 = 128 MB per layer
20 layers per pipeline stage: 20 × 128 MB = 2.56 GB
With activation checkpointing (recompute every 2 layers): ~1.3 GB
Total per GPU: ~6.2 GB ✓ (well within 80 GB)

Final Configuration:
TP=8 (within node, NVLink) × PP=4 (across 4 nodes, IB) × DP=64 (FSDP, across cluster)
Global batch size: 64 × M microbatches × B_micro
Microbatches per pipeline: M=16 (bubble = 3/19 = 15.8%)

MFU: Model FLOPS Utilization

The ultimate metric for training efficiency is MFU — what fraction of the GPU's theoretical peak FLOPS you are actually using for useful computation (excluding communication, bubbles, recomputation, and overhead).

MFU Calculation:

Step 1: Compute model FLOPs per token
For a transformer: FLOPs/token ≈ 6 × P   (2P for the forward pass, 4P for the backward pass)
70B model: 6 × 70×10^9 = 4.2×10^11 FLOPs = 420 GFLOP per token

Step 2: Measure tokens per second (throughput)
Suppose we measure: 3,200 tokens/sec/GPU

Step 3: Compute achieved FLOPS per GPU
Achieved per GPU = 3200 × (6 × 70×10^9) = 3200 × 4.2×10^11 = 1.344 PFLOP/s per GPU

Step 4: H100 peak BF16 FLOPS = 989 TFLOP/s (with sparsity: 1,979 TFLOP/s, ignore sparsity)

Step 5: MFU = achieved / peak = 1,344 / 989 = 136%?!

An MFU above 100% is impossible, so recheck at cluster level.
3200 tokens/sec/GPU is the global throughput divided by GPU count.
Global throughput: 3200 × 2048 = 6.55M tokens/sec
Global FLOPS = 6.55×10^6 × 4.2×10^11 = 2.75×10^18 FLOP/s = 2,750 PFLOP/s
Cluster peak = 2048 × 989 TFLOP/s = 2,025 PFLOP/s

Still over 100%. The arithmetic is fine; the assumed tokens/sec was too high.
Take a more realistic ~400 tokens/sec/GPU for a 70B model.

Corrected:
Achieved per GPU = 400 × 4.2×10^11 = 168 TFLOP/s
MFU = 168 / 989 = 17%

Meta reported ~38-43% MFU for LLaMA 3 training, which for a 70B model
would correspond to roughly 900-1,000 tokens/sec/GPU. Closing the gap from
17% takes optimizations: overlapped communication, optimal microbatch sizing,
activation checkpointing, and custom CUDA kernels.
What good MFU looks like. State of the art for large-scale training in 2024-2025: 35-55% MFU. Google's PaLM achieved 46.2% on TPU v4. Meta's LLaMA 3 405B achieved 38-43% on H100s. Anything below 30% usually points to a configuration problem (wrong parallelism, too small a microbatch, unoptimized kernels). Anything above 55% is exceptional and usually means a simpler model architecture (fewer attention layers, more FFN).
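The MFU arithmetic above fits in two short functions. This is a sketch under the same approximation used in the text (6 FLOPs per parameter per token); the default peak of 989 TFLOP/s is the dense BF16 H100 figure quoted above, and the function names are illustrative.

```python
def mfu(tokens_per_sec_per_gpu: float, num_params: float,
        peak_flops_per_gpu: float = 989e12) -> float:
    """Model FLOPs utilization using the 6 * params FLOPs-per-token approximation."""
    flops_per_token = 6 * num_params
    return tokens_per_sec_per_gpu * flops_per_token / peak_flops_per_gpu


def tokens_per_sec_for_mfu(target_mfu: float, num_params: float,
                           peak_flops_per_gpu: float = 989e12) -> float:
    """Per-GPU throughput needed to reach a target MFU (inverse of mfu)."""
    return target_mfu * peak_flops_per_gpu / (6 * num_params)


if __name__ == "__main__":
    for tps in (3200, 400):
        print(f"{tps} tok/s/GPU on 70B -> MFU = {mfu(tps, 70e9):.0%}")
    print(f"40% MFU on 70B needs ~{tokens_per_sec_for_mfu(0.40, 70e9):.0f} tok/s/GPU")
```

Running the inverse function first is a quick sanity check on any throughput number: if the claimed tokens/sec implies an MFU above 100%, the measurement or the unit is wrong.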

Gradient Accumulation and Its Interaction with Parallelism

Sometimes the global batch size you need is larger than what fits in GPU memory across all data-parallel replicas. Gradient accumulation solves this: run K forward-backward passes, accumulating gradients, then do one optimizer step.

Effective batch size with gradient accumulation:

B_effective = B_micro × DP × K_accum × M_pipeline

Where:
  B_micro = microbatch size per GPU
  DP = data parallel degree
  K_accum = gradient accumulation steps
  M_pipeline = microbatches per pipeline (if using PP)

Example: LLaMA 3 uses B_effective = 4M tokens
seq_len=8192, so 4M/8192 = 488 sequences per step
With DP=64, M=16, B_micro=1:
K_accum = 488 / (64 × 16 × 1) ≈ 0.48

So no accumulation is needed: one pass already covers 64 × 16 = 1024 > 488 sequences. We can reduce M to 8:
K_accum = 488 / (64 × 8 × 1) ≈ 0.95 ≈ 1 (no accumulation)
The cost: with PP=4, dropping M from 16 to 8 raises the pipeline bubble from 3/19 = 15.8% to 3/11 = 27.3%.

Gradient accumulation is "free" in terms of communication — you only all-reduce once per K steps. But it increases latency per optimizer step by K×.
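The batch-size bookkeeping can be checked with a few lines of code. The helper below is a sketch with an illustrative name (`accumulation_steps`); it rounds up to a whole number of accumulation steps and never returns less than 1.

```python
import math


def accumulation_steps(global_batch_tokens: int, seq_len: int, dp: int,
                       microbatches: int, micro_batch_size: int = 1):
    """Return (sequences per optimizer step, gradient accumulation steps K)."""
    sequences_per_step = global_batch_tokens // seq_len
    # Sequences processed by one full pass over all DP replicas and microbatches.
    sequences_per_pass = dp * microbatches * micro_batch_size
    k = max(1, math.ceil(sequences_per_step / sequences_per_pass))
    return sequences_per_step, k


if __name__ == "__main__":
    for dp, m in [(64, 16), (64, 8), (8, 8)]:
        seqs, k = accumulation_steps(4_000_000, 8192, dp, m)
        print(f"DP={dp}, M={m}: {seqs} sequences/step, K_accum={k}")
```

With a much smaller data-parallel degree (the DP=8 case), the same global batch does require accumulation.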

Mixed Precision Training: The Loss Scaling Dance

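Training in 16-bit precision saves memory and roughly doubles throughput on tensor cores. The two 16-bit formats fail differently. FP16 has a narrow exponent range (smallest normal number ~6.1×10^-5), so small gradients from early layers can underflow to zero. BF16 keeps FP32's exponent range (smallest normal number ~1.2×10^-38), so underflow is rare, but its 7-bit mantissa gives a relative precision of only ~0.01.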

Loss scaling solves this: multiply the loss by a large constant S before backpropagation. This scales all gradients by S, pushing them away from the underflow zone. After the backward pass, divide gradients by S before the optimizer step.

Dynamic Loss Scaling Algorithm:

1. Initialize scale S = 2^16 = 65536
2. Forward pass in reduced precision (FP16 or BF16)
3. Compute loss L, then L_scaled = L × S
4. Backward pass computes ∇W_scaled = S × ∇W
5. Check for inf/NaN in gradients:
   If inf/NaN found: skip optimizer step, S = S / 2
   If no inf/NaN for 2000 consecutive steps: S = S × 2
6. Unscale: ∇W = ∇W_scaled / S
7. Optimizer step in FP32 (master weights updated in FP32, cast back to 16-bit)
BF16 and loss scaling. In practice, BF16's larger exponent range means loss scaling is rarely needed for BF16 training (unlike FP16 which absolutely requires it). Most frontier labs train in BF16 without loss scaling. However, if you see training instabilities (loss spikes, divergence), loss scaling is the first thing to try before reducing the learning rate.
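The dynamic loss scaling loop described above can be sketched in plain Python. This is a framework-free illustration of the bookkeeping only: gradients are a flat list of floats, and the class name `DynamicLossScaler` and its parameters are assumptions made for this example rather than any library's API.

```python
import math


class DynamicLossScaler:
    """Tracks the loss scale S: halve on overflow, double after a clean streak."""

    def __init__(self, init_scale: float = 2.0 ** 16, growth_interval: int = 2000):
        self.scale = init_scale
        self.growth_interval = growth_interval
        self.good_steps = 0

    def unscale(self, scaled_grads):
        """Return unscaled gradients, or None if the optimizer step must be skipped."""
        if any(not math.isfinite(g) for g in scaled_grads):
            self.scale /= 2          # overflow: back off and skip this step
            self.good_steps = 0
            return None
        grads = [g / self.scale for g in scaled_grads]  # undo the scaling
        self.good_steps += 1
        if self.good_steps >= self.growth_interval:
            self.scale *= 2          # long clean streak: try a larger scale
            self.good_steps = 0
        return grads


if __name__ == "__main__":
    scaler = DynamicLossScaler(growth_interval=3)
    print(scaler.unscale([65536.0, 131072.0]))   # gradients divided by S
    print(scaler.unscale([float("inf")]))         # overflow: step skipped
    print("scale after overflow:", scaler.scale)
```

Unscaling uses the scale that was active during the backward pass; the scale is only grown afterwards, so a growth step never corrupts the current gradients.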

Communication Cost Summary

Strategy | Per-GPU Memory | Communication per Step | Communication Type | Best Interconnect
DDP | Full model + optimizer | 2 × (N-1)/N × D | All-reduce (gradients) | Any (least communication)
FSDP/ZeRO-3 | 1/N of everything | 3 × (N-1)/N × D | All-gather + reduce-scatter | Cross-node IB OK
Tensor Parallel | 1/T of each layer | 4 × all-reduce per block (2 forward + 2 backward) | All-reduce (activations) | NVLink required (intra-node)
Pipeline Parallel | L/P layers | Point-to-point per microbatch | Send/recv (activations) | IB OK, low bandwidth needs

Interactive: Parallelism Strategy Explorer

[Interactive widget: adjust model size (default 70B params), GPU count (2048), TP degree (8), and PP stages (4) to see per-GPU memory and communication volume.]

Activation Checkpointing: Trading Compute for Memory

During the forward pass, every layer's output is stored in memory so it can be used during the backward pass. For a 70B model with 80 layers, batch size 1, and sequence length 8192, the activation memory is enormous:

Activation Memory per Layer (transformer):

Input activations: B × S × d_model × 2 bytes = 1 × 8192 × 8192 × 2 = 128 MB
Attention scores (Q K^T): B × H × S × S × 2 bytes = 1 × 64 × 8192 × 8192 × 2 = 8 GB (!)
FFN intermediate: B × S × 4·d_model × 2 bytes = 1 × 8192 × 32768 × 2 = 512 MB

Total per layer: ~8.6 GB
All 80 layers: 80 × 8.6 = 688 GB

This is why you cannot train long-sequence transformers without activation checkpointing.

Activation checkpointing (also called gradient checkpointing or rematerialization) discards intermediate activations during the forward pass and recomputes them during the backward pass. The trade-off: 33% more compute (one extra forward pass per layer), but activation memory drops from O(L) to O(sqrt(L)) or O(1) depending on the strategy.

Strategy | Memory Saved | Extra Compute | When to Use
No checkpointing | None | None | Model and activations fit in memory
Selective (every N layers) | ~(N-1)/N of activations | ~(N-1)/N extra forward | Most common: checkpoint every 2-4 layers
Full (every layer) | ~L× activations | ~33% extra compute | Very large models or long sequences
FlashAttention | Eliminates S×S attention matrix | Recomputes in backward | Always use: saves the biggest chunk
FlashAttention is a form of activation checkpointing. The S×S attention score matrix (8 GB per layer in our example) is the single largest activation tensor. FlashAttention never materializes it — it computes attention in tiles, keeping only the current tile in SRAM. During the backward pass, it recomputes the attention scores from Q, K, V. This is why FlashAttention uses more FLOPS than standard attention but is faster: it trades the enormous memory write/read of the S×S matrix for a cheaper recomputation.
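The per-layer activation estimate can be reproduced with a short function, which also shows how much of the total disappears once the S×S score matrix is never materialized. This is a sketch of the three-term estimate used above (layer inputs, attention scores, FFN intermediate); the function name and the `flash_attention` flag are illustrative.

```python
def activation_bytes_per_layer(batch: int, seq: int, d_model: int, n_heads: int,
                               d_ffn: int, bytes_per_el: int = 2,
                               flash_attention: bool = False) -> dict:
    """Rough per-layer activation memory: inputs + attention scores + FFN intermediate."""
    inputs = batch * seq * d_model * bytes_per_el
    # FlashAttention never stores the S x S score matrix in HBM.
    scores = 0 if flash_attention else batch * n_heads * seq * seq * bytes_per_el
    ffn = batch * seq * d_ffn * bytes_per_el
    return {"inputs": inputs, "scores": scores, "ffn": ffn,
            "total": inputs + scores + ffn}


if __name__ == "__main__":
    GiB = 2 ** 30
    for flash in (False, True):
        mem = activation_bytes_per_layer(1, 8192, 8192, 64, 32768, flash_attention=flash)
        print(f"flash_attention={flash}: {mem['total'] / GiB:.3f} GiB per layer, "
              f"{80 * mem['total'] / GiB:.0f} GiB for 80 layers")
```

The 4·d_model FFN width (32768) matches the estimate in the text; passing LLaMA 3 70B's actual d_FFN of 28672 gives a slightly smaller FFN term.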

Sequence Parallelism and Context Parallelism

When training with very long sequences (32K, 128K, or even 1M tokens), attention compute grows quadratically with sequence length, and even the tensors that grow only linearly become large. With FlashAttention removing the S×S matrix, the keys and values for one layer still take B × 2 × H_kv × S × d_head × 2 bytes (the same size as an inference KV cache). At S=128K with GQA (8 KV heads, d_head=128):

KV cache per layer = 1 × 2 × 8 × 131072 × 128 × 2 = 512 MB
All 80 layers = 80 × 512 MB = 40 GB

This is half an H100's memory — just for one sequence's KV cache.

Context parallelism (also called sequence parallelism for the attention component) splits the sequence across GPUs. Each GPU processes a shard of the sequence and communicates partial attention results via all-to-all or ring exchanges. This allows training on sequences far longer than what fits on a single GPU.

Ring attention is the most elegant implementation: GPUs are arranged in a ring. Each GPU holds a chunk of Q and a chunk of K,V. The K,V chunks are passed around the ring, and each GPU accumulates the attention for its Q chunk against all K,V chunks as they arrive. After one full ring rotation, every GPU has its complete attention output.
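The ring rotation can be simulated on one machine to check that it reproduces ordinary attention. The sketch below is a pure-Python, single-head, non-causal toy: "devices" are list indices, the ring send is a list rotation, and partial results are merged with a running max and running sum (online softmax). It illustrates the idea only and makes no claim about how any production implementation is written.

```python
import math


def full_attention(Q, K, V):
    """Reference softmax(Q K^T / sqrt(d)) V for a single head."""
    scale = 1.0 / math.sqrt(len(Q[0]))
    out = []
    for q in Q:
        s = [scale * sum(a * b for a, b in zip(q, k)) for k in K]
        mx = max(s)
        w = [math.exp(x - mx) for x in s]
        z = sum(w)
        out.append([sum(wi * v[j] for wi, v in zip(w, V)) / z for j in range(len(V[0]))])
    return out


def ring_attention(Q, K, V, n_dev):
    """Each device keeps its Q chunk; K,V chunks rotate around the ring."""
    c = len(Q) // n_dev                      # tokens per device (assumes divisibility)
    dv = len(V[0])
    scale = 1.0 / math.sqrt(len(Q[0]))
    chunk = lambda X: [X[i * c:(i + 1) * c] for i in range(n_dev)]
    Qc = chunk(Q)
    kv = list(zip(chunk(K), chunk(V)))       # kv[r] = block currently held by device r
    m = [[-math.inf] * c for _ in range(n_dev)]            # running max per query
    l = [[0.0] * c for _ in range(n_dev)]                  # running softmax denominator
    acc = [[[0.0] * dv for _ in range(c)] for _ in range(n_dev)]  # running numerator
    for _ in range(n_dev):
        for r in range(n_dev):
            Kb, Vb = kv[r]
            for i, q in enumerate(Qc[r]):
                s = [scale * sum(a * b for a, b in zip(q, k)) for k in Kb]
                new_m = max(m[r][i], max(s))
                corr = math.exp(m[r][i] - new_m)   # rescale old partial sums
                w = [math.exp(x - new_m) for x in s]
                l[r][i] = l[r][i] * corr + sum(w)
                acc[r][i] = [a * corr + sum(wi * v[j] for wi, v in zip(w, Vb))
                             for j, a in enumerate(acc[r][i])]
                m[r][i] = new_m
        kv = kv[-1:] + kv[:-1]               # device r receives the block from device r-1
    return [[a / l[r][i] for a in acc[r][i]] for r in range(n_dev) for i in range(c)]


# Deterministic toy data: 8 tokens, head dimension 4, 4 simulated devices.
Q = [[math.sin(1.0 + 4 * i + j) for j in range(4)] for i in range(8)]
K = [[math.cos(2.0 + 4 * i + j) for j in range(4)] for i in range(8)]
V = [[math.sin(0.5 * (4 * i + j)) for j in range(4)] for i in range(8)]
ref, out = full_attention(Q, K, V), ring_attention(Q, K, V, n_dev=4)
max_err = max(abs(a - b) for ra, rb in zip(ref, out) for a, b in zip(ra, rb))
print("max abs difference vs full attention:", max_err)
```

The running-max rescaling is what lets each device fold in one K,V block at a time without ever seeing the full row of scores.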

The parallelism zoo summary. Modern large-scale training combines five forms of parallelism: (1) Data Parallelism: replicate model, partition data. (2) FSDP: shard model + optimizer + gradients. (3) Tensor Parallelism: shard each layer's weights. (4) Pipeline Parallelism: shard layers across stages. (5) Context Parallelism: shard sequence length. Meta's LLaMA 3 405B combined them: TP=8, PP=16, DP=128 on 16,384 GPUs for 8K-token sequences (CP=1), switching to CP=16 with DP=8 for 128K-token long-context training, with FSDP across the DP dimension.

Failure Recovery at Scale

When training on 2048+ GPUs for weeks, hardware failures are guaranteed. A single GPU failure can halt a training run that costs $50K per hour of idle time. Frontier labs engineer for this:

Failure Type | Frequency (2048 GPUs) | Detection | Recovery
GPU memory error (ECC) | ~1 per day | CUDA runtime error | Replace GPU, resume from checkpoint
NIC failure | ~1 per 3 days | NCCL timeout | Reroute traffic, resume
Node crash | ~1 per week | Heartbeat timeout | Exclude node, redistribute ranks
Software hang | ~2 per day | Training step time 3x normal | Kill, restart from checkpoint
Checkpoint corruption | Rare | Checksum validation | Fall back to previous checkpoint
Checkpointing cost. Saving a checkpoint for a 70B model: 1,120 GB of state must be written to persistent storage. At 10 GB/s write bandwidth: 112 seconds. During this time, training is paused. If you checkpoint every 100 steps, and each step takes 30 seconds, checkpointing overhead is 112 / (100 × 30 + 112) = 3.6%. Frontier labs use asynchronous checkpointing (write to a staging buffer in CPU RAM, then flush to disk while training continues) to reduce this to near zero.
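The synchronous checkpoint overhead above is a one-line ratio, shown here as a small sketch so other intervals and bandwidths can be tried. The function name is illustrative, and the model assumes training is fully paused for the duration of the write.

```python
def checkpoint_overhead(state_gb: float, write_gb_per_s: float,
                        steps_between: int, step_time_s: float):
    """Return (pause in seconds, fraction of wall-clock time spent checkpointing)."""
    pause_s = state_gb / write_gb_per_s
    fraction = pause_s / (steps_between * step_time_s + pause_s)
    return pause_s, fraction


if __name__ == "__main__":
    pause, frac = checkpoint_overhead(1120, 10, steps_between=100, step_time_s=30)
    print(f"pause = {pause:.0f} s, overhead = {frac:.1%}")
    pause, frac = checkpoint_overhead(1120, 10, steps_between=500, step_time_s=30)
    print(f"every 500 steps: overhead = {frac:.1%}")
```

Checkpointing less often lowers the overhead but increases the work lost on each failure, which is the trade-off asynchronous checkpointing removes.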
python
# Compute parallelism configuration for any model-cluster pair
def compute_parallel_config(
    num_params_B: float,   # billions of parameters
    num_gpus: int,         # total GPU count
    hbm_per_gpu_GB: float = 80,
    num_heads: int = 64,
    num_layers: int = 80,
):
    P = num_params_B * 1e9

    # Memory breakdown (BF16 training with AdamW)
    weights_GB = P * 2 / 1e9          # BF16 weights
    grads_GB = P * 2 / 1e9            # BF16 gradients
    opt_GB = P * 12 / 1e9             # FP32 master + m + v
    total_static_GB = weights_GB + grads_GB + opt_GB

    # Strategy 1: Can we fit on 1 GPU? → DDP
    # (1.3x leaves headroom for activations and fragmentation)
    if total_static_GB * 1.3 < hbm_per_gpu_GB:
        return {"strategy": "DDP", "dp": num_gpus,
                "mem_per_gpu_GB": round(total_static_GB, 2)}

    # Strategy 2: Choose TP (must divide heads, max 8 for NVLink)
    tp = 1
    for t in [8, 4, 2]:
        if num_heads % t == 0 and weights_GB / t < hbm_per_gpu_GB * 0.5:
            tp = t
            break

    # Strategy 3: Choose PP (enough stages to fit layers)
    pp = 1
    while total_static_GB / (tp * pp) > hbm_per_gpu_GB * 0.6:
        pp *= 2
        if pp > 16: break

    # Remaining GPUs are data parallel
    dp = num_gpus // (tp * pp)

    # Per-GPU memory with FSDP across DP dimension
    layers_per_stage = num_layers // pp
    frac = layers_per_stage / num_layers
    mem_weights = weights_GB * frac / tp          # TP shards weights
    mem_grads = grads_GB * frac / (dp * tp)       # FSDP shards grads
    mem_opt = opt_GB * frac / (dp * tp)           # FSDP shards opt
    mem_per_gpu = mem_weights + mem_grads + mem_opt

    return {
        "strategy": f"TP={tp} x PP={pp} x DP={dp}",
        "tp": tp, "pp": pp, "dp": dp,
        "mem_per_gpu_GB": round(mem_per_gpu, 2),
        "total_static_GB": round(total_static_GB, 1),
        "bubble_frac": (pp - 1) / (pp - 1 + 4 * pp) if pp > 1 else 0,
    }

# Example: LLaMA 3 70B on 2048 H100s
config = compute_parallel_config(70, 2048)
print(config)
# {'strategy': 'TP=8 x PP=4 x DP=64', 'tp': 8, 'pp': 4, 'dp': 64, 'mem_per_gpu_GB': 4.85, ...}

Real-World Training Configurations (2024-2025)

Let us examine the actual parallelism configurations used by frontier labs. These numbers come from published papers and technical reports:

Model | Params | GPUs | TP | PP | DP | CP | MFU | Global Batch
LLaMA 3 8B | 8B | 512 H100 | 1 | 1 | 512 | 1 | 43% | 4M tokens
LLaMA 3 70B | 70B | 2048 H100 | 8 | 4 | 64 | 1 | 41% | 4M tokens
LLaMA 3 405B | 405B | 16384 H100 | 8 | 16 | 128 | 1 | 38% | 8M tokens
GPT-4 (estimated) | ~1.7T MoE | ~25000 A100 | 8 | ~64 | ~48 | ? | ~35% | ~60M tokens
Gemini 1.5 (est) | >1T | ~4096 TPUv5p | 16 | ~8 | ~32 | 4 | ~45% | ?
DeepSeek V3 | 671B MoE | 2048 H800 | 1 | 16 | 128 | 1 | 52% | 15M tokens
DeepSeek's 52% MFU breakthrough. DeepSeek V3 achieved the highest reported MFU for a frontier model by using several innovations: (1) FP8 mixed-precision training for most matrix multiplies (not just BF16), which doubles tensor core throughput, (2) aggressive communication-computation overlap using a custom DualPipe scheduler that overlaps the forward and backward computation of different microbatches with their communication, (3) cross-node expert parallelism over InfiniBand (64-way EP spanning 8 nodes, unconventional because most labs keep their heaviest communication on NVLink), made possible by custom all-to-all kernels for MoE dispatch, which let them drop tensor parallelism entirely. The reported training cost was only about $5.6M of GPU time, roughly 10x cheaper than comparable models.

DeepSpeed vs Megatron-LM vs PyTorch FSDP

Three major frameworks support distributed training. Knowing their trade-offs matters for system design interviews:

Feature | PyTorch FSDP | DeepSpeed ZeRO | Megatron-LM
Maintained by | Meta (PyTorch team) | Microsoft | NVIDIA
Data parallelism | ZeRO-3 (native) | ZeRO-1/2/3 | DDP
Tensor parallelism | Via DTensor (limited) | Via Megatron-DeepSpeed | Native (best support)
Pipeline parallelism | Limited | Yes (1F1B, interleaved) | Yes (1F1B, interleaved)
Sequence parallelism | No | Ulysses SP | Yes
Mixed precision | BF16 native | BF16/FP16 | BF16/FP16 + FP8 (TransformerEngine)
Best for | Simple FSDP jobs, <100B params | Flexible configs, research | Max performance, 100B+ models
Learning curve | Low (PyTorch native) | Medium (config-driven) | High (custom model code)

In practice, Meta uses Megatron-LM + FSDP for LLaMA training. Google uses its own internal framework (Pathways) on TPUs. DeepSpeed is popular for research labs and smaller companies. The choice often depends on what your cluster runs (NVIDIA GPUs → Megatron-LM, TPUs → JAX/Pax, budget-constrained → DeepSpeed).

The Learning Rate Schedule: An Overlooked Performance Factor

The learning rate schedule interacts with parallelism in non-obvious ways. When you change the global batch size (by changing DP or gradient accumulation), the optimal learning rate changes too.

Linear scaling rule:
If you multiply global batch size by k, multiply learning rate by k.
LR_new = LR_base × (B_new / B_base)

Square root scaling rule (more conservative):
LR_new = LR_base × √(B_new / B_base)

Example: LLaMA 3 70B
B_base = 4M tokens, LR_base = 1.5×10^-4
If you double DP (double batch to 8M):
Linear: LR = 3.0×10^-4 (risky, may diverge)
Sqrt: LR = 2.12×10^-4 (safer, commonly used)

Warmup is critical: ramp LR from 0 to target over the first 2000 steps.
Without warmup, large-batch training with scaled LR reliably diverges.
Warmup steps must scale with batch size. When you increase the batch size, each step covers more data, so the warmup should cover the same amount of data (not steps). If you warmup over 2000 steps with batch 4M, and then switch to batch 8M, warmup should be 1000 steps (same 8B tokens of warmup data).
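The two scaling rules and the token-based warmup adjustment can be written down directly. This is a sketch with illustrative function names; `lr_at_step` assumes a plain linear ramp from 0 to the peak learning rate, which is one common warmup shape among several.

```python
import math


def scale_lr(base_lr: float, base_batch: float, new_batch: float,
             rule: str = "sqrt") -> float:
    """Rescale the learning rate when the global batch size changes."""
    ratio = new_batch / base_batch
    if rule == "linear":
        return base_lr * ratio
    if rule == "sqrt":
        return base_lr * math.sqrt(ratio)
    raise ValueError(f"unknown rule: {rule}")


def warmup_steps(base_steps: int, base_batch: float, new_batch: float) -> int:
    """Keep the number of warmup tokens fixed when the batch size changes."""
    return round(base_steps * base_batch / new_batch)


def lr_at_step(step: int, peak_lr: float, warmup: int) -> float:
    """Linear warmup from 0 to peak_lr, then constant (decay omitted)."""
    return peak_lr * min(1.0, step / warmup)


if __name__ == "__main__":
    print(f"linear: {scale_lr(1.5e-4, 4e6, 8e6, 'linear'):.2e}")
    print(f"sqrt:   {scale_lr(1.5e-4, 4e6, 8e6, 'sqrt'):.2e}")
    print("warmup steps at 8M batch:", warmup_steps(2000, 4e6, 8e6))
```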
You are training a 13B parameter model in BF16 with AdamW on a cluster of 64 H100 GPUs (80 GB each). The total static memory (weights + gradients + optimizer) is 13B × 16 = 208 GB. Each GPU has 80 GB. What is the most efficient parallelism strategy, and why?

Chapter 10: Serving Infrastructure

Training a model costs millions of dollars and takes weeks. Serving that model costs millions of dollars per month and runs indefinitely. At frontier labs, the serving infrastructure team often has a larger headcount than the training team. And the engineering challenges are fundamentally different: training optimizes for throughput (tokens per second across the cluster), serving optimizes for latency under load (time to first token for the user who just typed a query while 10,000 other users are also waiting).

This chapter covers the core systems that make LLM serving efficient: the two-phase serving model, continuous batching, speculative decoding, disaggregated serving, and KV cache management. By the end, you will be able to derive the cost of serving a 70B model and identify exactly where the bottlenecks are.

Prefill vs Decode: The Two-Phase Model

When a user sends a prompt to an LLM, the server does two fundamentally different things:

Phase 1: Prefill. The entire prompt (say, 500 tokens) is processed in a single forward pass. Every token can attend to all previous tokens in parallel. This is a large matrix multiply — compute-bound. The GPU tensor cores are busy. Memory bandwidth is not the bottleneck.

Phase 2: Decode. Output tokens are generated one at a time. Each forward pass processes exactly 1 new token (attending to all previous tokens via the KV cache). This is a tiny matrix multiply — memory-bandwidth-bound. The tensor cores are mostly idle, waiting for weights to be loaded from HBM.

Arithmetic Intensity Analysis:

Prefill (processing S prompt tokens):
FLOPs = 2 × S × d_model^2 × num_layers × 4   (4 projections: Q, K, V, O)
Bytes loaded = d_model^2 × 2 × num_layers × 4   (weight matrices, BF16)
Arithmetic intensity = FLOPs / Bytes = 2S / 2 = S

For S=512: intensity = 512 ops/byte
H100 roofline crossover: 989 TFLOP/s / 3.35 TB/s ≈ 295 ops/byte
512 > 295 → compute-bound

Decode (generating 1 token):
FLOPs = 2 × 1 × d_model^2 × num_layers × 4
Bytes loaded = same weight matrices
Arithmetic intensity = 2 × 1 / 2 = 1 op/byte

1 << 295 → memory-bandwidth-bound
The GPU is loading 140 GB of weights to do 1 token of compute.
The decode bottleneck is fundamental. During decoding, the GPU loads the entire model's weights from HBM for every single output token. For a 70B model in BF16, that is 140 GB per token. An H100 has 3.35 TB/s HBM bandwidth. Maximum decode speed: 3,350 / 140 = 24 tokens/sec for a single request. This is a hard ceiling set by physics (memory bandwidth), not by compute. Batching multiple requests together amortizes the weight load — processing 32 requests simultaneously gives 32 × 24 = ~768 tokens/sec total, though per-request latency increases.
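The roofline reasoning reduces to three small functions. The sketch below uses the H100 figures quoted in this document (989 TFLOP/s dense BF16, 3.35 TB/s HBM) and the same simplification as the text, counting only weight traffic; the function names are illustrative.

```python
def ridge_point(peak_flops: float, mem_bw_bytes: float) -> float:
    """Arithmetic intensity (FLOPs per byte) at which compute and bandwidth balance."""
    return peak_flops / mem_bw_bytes


def weight_matmul_intensity(tokens: int, bytes_per_weight: int = 2) -> float:
    """FLOPs per weight byte when `tokens` tokens share one pass over the weights."""
    return 2 * tokens / bytes_per_weight


def max_decode_tokens_per_sec(weight_bytes: float, mem_bw_bytes: float) -> float:
    """Bandwidth ceiling for single-request decode: one full weight read per token."""
    return mem_bw_bytes / weight_bytes


if __name__ == "__main__":
    ridge = ridge_point(989e12, 3.35e12)
    print(f"H100 ridge point: {ridge:.0f} FLOPs/byte")
    for tokens in (512, 1):
        ai = weight_matmul_intensity(tokens)
        bound = "compute-bound" if ai > ridge else "memory-bandwidth-bound"
        print(f"{tokens} token(s): intensity {ai:.0f} -> {bound}")
    print(f"70B BF16 decode ceiling: {max_decode_tokens_per_sec(140e9, 3.35e12):.0f} tokens/sec")
```

The same `weight_matmul_intensity` call explains batching: a decode batch of B requests has intensity B, so it takes a batch in the hundreds before decode stops being bandwidth-bound on this hardware.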

Batching Theory: Deriving the Throughput-Latency Curve

Batching is the single most important optimization in LLM serving. Let us derive exactly why, starting from first principles.

Consider a 70B model on 4x H100 with TP=4. The model weights occupy 140/4 = 35 GB per GPU. During decode, each token requires loading all weights from HBM. The key insight: loading weights once for B tokens costs nearly the same as loading them for 1 token (the extra KV cache and activation memory is tiny in comparison).

Decode latency as a function of batch size B:

Weight load time: 35 GB / 3.35 TB/s = 10.4 ms (constant, independent of B)
KV cache load per request: ~0.3 MB per cached token across all layers (negligible for small B)
Compute per token: 2 × d^2 × layers × 4 projections = tiny vs load time

Under TP=4 all four GPUs work on the same tokens, so throughput is B tokens per step for the whole group:
For B=1: Total = 10.4 ms → 96 tokens/sec
For B=8: Total ≈ 10.4 ms + 0.3 ms overhead → ~750 tokens/sec
For B=32: Total ≈ 10.4 ms + 1.5 ms → ~2,690 tokens/sec
For B=128: Total ≈ 10.4 ms + 8 ms → ~6,960 tokens/sec (4x the batch of B=32, only 2.6x the throughput)

Why the gains flatten:
At small B: adding requests costs almost nothing (amortize weight load)
At large B: KV cache reads grow with B and start to rival the weight load, so every step (and every user's next token) gets slower
Practical B ≈ 32-64 for this config: most of the amortization benefit at a small latency cost
The throughput plateau. As batch size grows past a threshold, KV cache reads for all B requests become comparable to the weight load, and KV cache capacity starts to limit how many requests fit at all. Throughput per request drops, and total throughput plateaus. The sweet spot depends on model size, sequence length, and GPU memory. For 70B models, it is typically B=32-64. For 7B models, it can be B=128-256.
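A bandwidth-only model of the decode step makes the shape of the curve easy to explore. The sketch below assumes each step reads the per-GPU weight shard once plus every request's KV cache, ignores compute and kernel overheads, and uses the 70B GQA shape from this chapter (80 layers, 8 KV heads, d_head=128, TP=4). The context length of 2048 tokens is an assumption chosen for illustration.

```python
def decode_step_ms(batch: int, context_len: int, weight_bytes_per_gpu: float = 35e9,
                   num_layers: int = 80, num_kv_heads: int = 8, head_dim: int = 128,
                   tp: int = 4, hbm_bw: float = 3.35e12) -> float:
    """Step time if HBM reads (weights + all KV caches) were the only cost."""
    # K and V, BF16, all layers, with KV heads sharded across the TP group.
    kv_bytes_per_token = 2 * num_kv_heads * head_dim * 2 * num_layers / tp
    bytes_read = weight_bytes_per_gpu + batch * context_len * kv_bytes_per_token
    return 1e3 * bytes_read / hbm_bw


def decode_throughput(batch: int, context_len: int) -> float:
    """Tokens per second for the whole TP group (one token per request per step)."""
    return batch / (decode_step_ms(batch, context_len) / 1e3)


if __name__ == "__main__":
    for b in (1, 8, 32, 128):
        print(f"B={b:>3}: step = {decode_step_ms(b, 2048):5.1f} ms, "
              f"throughput = {decode_throughput(b, 2048):7.0f} tokens/sec")
```

Under this simplified model throughput keeps rising with B but with shrinking returns, while step time (the latency every user sees per token) climbs steadily; longer contexts move the flattening to smaller batch sizes.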

Continuous Batching: Why Static Batching Wastes GPUs

In static batching, you collect B requests, process them together, and wait for ALL B requests to finish generating before accepting new ones. If request 1 generates 20 tokens and request 2 generates 500 tokens, request 1's GPU slot sits idle for 480 decode steps. For a batch of diverse-length requests, average GPU utilization can be as low as 30-40%.

Continuous batching (also called "in-flight batching" or "iteration-level batching") fixes this: as soon as any request finishes, its slot is immediately filled with a new request from the queue. The batch composition changes at every decode step.

Static Batching
Collect 4 requests → process all 4 together → wait for slowest to finish (200 tokens) → request A (finished at 30 tokens) wastes 170 decode steps → only then accept new batch. Utilization: ~45%.
vs
Continuous Batching
Start 4 requests → Request A finishes at step 30 → immediately replace A with Request E → Request B finishes at step 50 → replace with F → batch always full. Utilization: ~95%.

Let us quantify the waste. Consider 4 requests with output lengths [30, 80, 150, 500]:

Static batching waste:
All 4 requests run for 500 decode steps (waiting for the longest).
Total "GPU-steps" consumed: 4 × 500 = 2,000
Useful work: 30 + 80 + 150 + 500 = 760 token-generations
Utilization: 760 / 2000 = 38%

Continuous batching:
Step 0-30: 4 requests active (4 useful token-generations per step)
Step 30: Request A done, immediately add Request E (length 200)
Step 30-80: 4 requests active
Step 80: Request B done, add Request F (length 100)...
GPU is always processing 4 requests. Utilization: ~95%

Throughput improvement: 95% / 38% = 2.5x for this workload.
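The utilization comparison can be reproduced with a tiny step-by-step simulation. This sketch counts one decode step per generated token, ignores prefill cost, and refills slots in FIFO order; the function names and the repeated workload are illustrative assumptions.

```python
from collections import deque


def static_utilization(lengths, slots: int) -> float:
    """Each batch of `slots` requests runs until its longest request finishes."""
    steps = useful = 0
    for i in range(0, len(lengths), slots):
        batch = lengths[i:i + slots]
        steps += max(batch)
        useful += sum(batch)
    return useful / (slots * steps)


def continuous_utilization(lengths, slots: int) -> float:
    """A freed slot is refilled from the queue at the very next decode step."""
    queue = deque(lengths)
    active = []                      # remaining tokens for each running request
    steps = useful = 0
    while queue or active:
        while queue and len(active) < slots:
            active.append(queue.popleft())
        useful += len(active)        # every active request emits one token
        steps += 1
        active = [r - 1 for r in active if r > 1]
    return useful / (slots * steps)


if __name__ == "__main__":
    print(f"static, one batch:  {static_utilization([30, 80, 150, 500], 4):.0%}")
    workload = [30, 80, 150, 500] * 8
    print(f"static, 32 reqs:    {static_utilization(workload, 4):.0%}")
    print(f"continuous, 32 reqs: {continuous_utilization(workload, 4):.0%}")
```

With a finite queue the continuous scheduler still loses some utilization at the very end, when the last long request runs alone; in a server with a steady stream of arrivals that tail never occurs.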

The implementation requires a scheduler that tracks the state of every request (prefilling, decoding, waiting, done) and makes admission decisions at every iteration:

python
class ContinuousBatchScheduler:
    def __init__(self, max_batch: int, max_tokens: int, kv_cache):
        self.max_batch = max_batch     # max concurrent requests
        self.max_tokens = max_tokens   # max total tokens in KV cache
        self.kv_cache = kv_cache       # allocator exposing free(request_id) and available()
        self.running = []              # requests currently generating
        self.waiting = []              # queue of pending requests

    def schedule_step(self):
        # Step 1: Remove finished requests (hit EOS or max_len)
        finished = [r for r in self.running if r.is_done()]
        self.running = [r for r in self.running if not r.is_done()]

        # Step 2: Free KV cache memory from finished requests
        for r in finished:
            self.kv_cache.free(r.request_id)

        # Step 3: Admit new requests from the queue
        while (self.waiting
               and len(self.running) < self.max_batch
               and self.kv_cache.available() > self.waiting[0].est_tokens):
            req = self.waiting.pop(0)
            req.state = "prefill"
            self.running.append(req)

        # Step 4: Separate prefill and decode requests
        prefill_reqs = [r for r in self.running if r.state == "prefill"]
        decode_reqs = [r for r in self.running if r.state == "decode"]

        # Step 5: Run prefill for new requests (compute-bound)
        if prefill_reqs:
            self.run_prefill(prefill_reqs)
            for r in prefill_reqs:
                r.state = "decode"

        # Step 6: Run one decode step for all decoding requests
        if decode_reqs:
            self.run_decode(decode_reqs)

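        return finished

    def run_prefill(self, reqs):
        # Model-runner hook (not shown): one forward pass over each new prompt
        raise NotImplementedError

    def run_decode(self, reqs):
        # Model-runner hook (not shown): one decode step for the whole batch
        raise NotImplementedError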

Chunked Prefill: Reducing Head-of-Line Blocking

A subtle problem with continuous batching: when a new request with a long prompt (say, 10,000 tokens) enters the batch, its prefill takes hundreds of milliseconds. During this time, the decode requests in the same batch are blocked — they cannot generate their next token until the prefill completes. This is head-of-line blocking.

Chunked prefill solves this by breaking long prefills into smaller chunks (e.g., 512 tokens at a time). Between chunks, decode requests get a chance to run. The result: prefill spreads over multiple iterations, and decode latency stays bounded.

Chunked Prefill Impact:

Without chunked prefill:
  Prefill 10K tokens: ~120 ms (compute-bound)
  During this 120 ms: all 32 decode requests stalled
  ITL spike: 120 ms (vs a normal ~10 ms decode step)

With chunked prefill (chunk=512):
  Each chunk: ~6 ms
  After each chunk: run 1 decode step (~10 ms)
  Total prefill time: 20 chunks × (6 + 10) ms = 320 ms (2.7x slower)
  But ITL stays bounded at ~16 ms (6ms chunk + 10ms decode)

Trade-off: TTFT increases (320 vs 120 ms) but ITL stays smooth.
For chatbots: smooth typing > fast first token. Chunked prefill wins.
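The chunked prefill trade-off is simple enough to parameterize. The sketch below reproduces the arithmetic above; the function name is illustrative, and the per-chunk and per-decode-step times are the example's assumed values rather than measurements.

```python
import math


def chunked_prefill(prompt_tokens: int, chunk: int,
                    ms_per_chunk: float, decode_ms: float) -> dict:
    """TTFT and worst-case inter-token latency when prefill chunks alternate with decode."""
    n_chunks = math.ceil(prompt_tokens / chunk)
    return {
        "chunks": n_chunks,
        "ttft_ms": n_chunks * (ms_per_chunk + decode_ms),
        "max_itl_ms": ms_per_chunk + decode_ms,
    }


if __name__ == "__main__":
    for chunk, ms in [(512, 6.0), (2048, 24.0)]:
        r = chunked_prefill(10_000, chunk, ms, decode_ms=10.0)
        print(f"chunk={chunk}: {r['chunks']} chunks, TTFT={r['ttft_ms']:.0f} ms, "
              f"max ITL={r['max_itl_ms']:.0f} ms")
```

Larger chunks shorten time to first token and lengthen the worst-case gap between tokens for everyone else, so chunk size is the knob that trades one against the other.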

Guided Decoding and Structured Output

Many applications need the LLM to output valid JSON, SQL, or code that follows a grammar. Guided decoding constrains the token-by-token generation to follow a formal grammar or regex pattern by masking out invalid tokens at each step.

python
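# Guided JSON decoding: mask invalid tokens at each step
# Tokenizer helpers (tokens_starting_with, tokens_for_any_of, tokens_for,
# digit_tokens, all_tokens) are assumed to exist and to return sets of token IDs.
import torch

class JSONGuide: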
    def __init__(self, schema: dict):
        self.schema = schema
        self.state = "start"  # FSM state: start, key, value, etc.
        self.stack = []        # nesting stack: [object, array, ...]

    def get_valid_tokens(self, generated_so_far: str) -> set:
        """Return the set of token IDs that are valid at this position."""
        if self.state == "start":
            return tokens_starting_with("{")
        elif self.state == "key":
            # Only allow keys from schema
            valid_keys = self.schema["properties"].keys()
            return tokens_for_any_of([f'"{k}"' for k in valid_keys])
        elif self.state == "value":
            expected_type = self.current_field_type()
            if expected_type == "integer":
                return digit_tokens | tokens_for("-")
            elif expected_type == "string":
                return tokens_starting_with('"')
            # ... more types
        return all_tokens  # fallback: no constraint

    def mask_logits(self, logits, generated_so_far):
        """Push logits of invalid tokens to a large negative value."""
        valid = self.get_valid_tokens(generated_so_far)
        mask = torch.full_like(logits, -1e9)
        mask[list(valid)] = 0
        return logits + mask
Performance impact of guided decoding. Constraining tokens adds ~5-10% overhead per decode step (the grammar FSM transition + logit masking). But it eliminates the need for retry loops when the LLM generates invalid output. For JSON output, unguided generation fails ~15-30% of the time on complex schemas, requiring a full regeneration. Guided decoding produces syntactically valid output by construction (as long as generation is not cut off at the length limit), so net throughput is higher because no retries are needed.

Speculative Decoding: Small Model Drafts, Large Model Verifies

Autoregressive decoding is slow because you generate one token at a time. Speculative decoding accelerates this with a clever insight: use a small, fast draft model to generate K candidate tokens, then have the large target model verify all K tokens in a single forward pass. That pass costs about the same as decoding 1 token: decode is memory-bandwidth-bound, so the weights are loaded once either way and the extra K tokens of compute ride along almost for free.

1. Draft
Small model (e.g., 1B params) generates K=5 candidate tokens autoregressively. This is fast because the model is small. Takes ~5 ms for 5 tokens.
2. Verify
Large model (70B) processes all 5 candidates + context in a single forward pass and computes the probability of each candidate token. Takes ~30 ms, about the same as a 1-token decode, because the step is dominated by loading the weights, not by the few extra tokens of compute.
3. Accept/Reject
Compare draft probabilities p_small(t) with target probabilities p_large(t). If p_large(t_i) ≥ p_small(t_i), always accept token i. If p_large(t_i) < p_small(t_i), accept with probability p_large(t_i)/p_small(t_i). The first rejection terminates the chain, and a replacement token for that position is sampled from the normalized positive part of p_large − p_small.
4. Result
If 4 out of 5 tokens are accepted, we generated 4 tokens in one large-model forward pass instead of 4 separate passes. Speedup: ~4x on that step (before counting the draft time). Average speedup depends on acceptance rate.
Speculative Decoding Speedup:

Let α = acceptance rate (probability a draft token is accepted)
Let K = number of draft tokens per speculation round
Let T_draft = time for draft model to generate K tokens
Let T_verify = time for one target model forward pass
Let T_decode_one = time for one ordinary decode step of the target model

Expected accepted draft tokens per round = ∑_{i=0..K-1} α^i × (1-α) × i + α^K × K
  = (1 - α^(K+1)) / (1 - α) - 1   (geometric series)
  ≈ α/(1-α) for large K
Every round also yields one token from the target itself (the replacement after a rejection, or a bonus token if all K are accepted), so:
Expected tokens per round = (1 - α^(K+1)) / (1 - α)

Speedup = (Expected tokens per round × T_decode_one) / (T_draft + T_verify)

Example: α=0.8, K=5, T_draft=5ms, T_verify=30ms, T_decode_one=30ms
Expected accepted = (1 - 0.8^6)/0.2 - 1 = 2.69 draft tokens, so 3.69 tokens per round
Normal time for 3.69 tokens: 3.69 × 30 = 110.7 ms
Speculative time: 5 + 30 = 35 ms
Speedup: 110.7 / 35 = 3.16x
When speculative decoding works best. High acceptance rate α requires the draft model to closely match the target model's distribution. This works best for: (1) code generation (highly structured, predictable), (2) translation (constrained by source), (3) summarization (output strongly conditioned on input). It works worst for creative writing (low predictability). Google's Gemini uses speculative decoding in production with a dedicated draft model trained to match the large model's distribution.
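The expected-speedup calculation is worth having as code, since it is the first thing to check when tuning K. The sketch below assumes each draft token is accepted independently with probability α, and uses the standard accounting in which every round also yields one token sampled from the target model (a replacement after a rejection, or a bonus token when all K drafts are accepted). Function names are illustrative.

```python
def expected_accepted(alpha: float, k: int) -> float:
    """Expected number of accepted draft tokens: sum over i of P(first i all accepted)."""
    return sum(alpha ** i for i in range(1, k + 1))


def expected_tokens_per_round(alpha: float, k: int) -> float:
    """Accepted draft tokens plus the one token the target model always contributes."""
    return expected_accepted(alpha, k) + 1.0


def speculative_speedup(alpha: float, k: int, t_draft_ms: float,
                        t_verify_ms: float, t_decode_ms: float) -> float:
    """Time for the same tokens with plain decoding, divided by one round's time."""
    baseline_ms = expected_tokens_per_round(alpha, k) * t_decode_ms
    return baseline_ms / (t_draft_ms + t_verify_ms)


if __name__ == "__main__":
    for alpha in (0.5, 0.8, 0.9):
        s = speculative_speedup(alpha, 5, t_draft_ms=5, t_verify_ms=30, t_decode_ms=30)
        print(f"alpha={alpha}: {expected_tokens_per_round(alpha, 5):.2f} tokens/round, "
              f"speedup {s:.2f}x")
```

Because the draft time grows with K while the expected gain saturates, sweeping K with this function for a measured α is a quick way to pick the draft length.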

Disaggregated Serving: Split Prefill and Decode

Prefill and decode have opposite hardware requirements. Prefill is compute-bound — it benefits from high FLOPS (many tensor cores). Decode is memory-bandwidth-bound — it benefits from high HBM bandwidth and large batch sizes. Running both on the same GPU creates contention: a long prefill blocks decode for all batched requests, causing latency spikes.

Disaggregated serving runs prefill and decode on separate GPU pools:

Property | Prefill GPU Pool | Decode GPU Pool
Workload | Process input prompts | Generate output tokens
Bottleneck | Compute (tensor cores) | Memory bandwidth (HBM)
Optimal hardware | High FLOPS (H100 SXM) | High bandwidth, less compute (H100 NVL, or even A100)
Batch strategy | Large batches of prompts | Large batches of single-token decodes
KV cache | Generate KV cache, transfer to decode pool | Store and extend KV cache
Utilization | Near 100% compute | Near 100% bandwidth

The critical challenge is transferring the KV cache from prefill to decode. For a 70B model with 80 layers, 8 KV heads, d_head=128, and a 4096-token prompt in BF16:

KV Cache Size for One Request:

KV cache per layer = 2 × num_kv_heads × d_head × seq_len × 2 bytes
= 2 × 8 × 128 × 4096 × 2 = 16.78 MB per layer

Total = 80 layers × 16.78 MB = 1.34 GB per request

At 100 Gbps network: transfer time = 1.34 GB / 12.5 GB/s = 107 ms
At 400 Gbps InfiniBand: transfer time = 1.34 / 50 = 26.8 ms

This transfer latency adds to TTFT. Disaggregation only helps
when the prefill-decode interference cost exceeds the transfer cost.
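
To try other model shapes or link speeds, the same calculation can be written as a small helper. This is a sketch under the assumptions used above (BF16, decimal GB, full link bandwidth with no protocol overhead). The function names are illustrative.

```python
def kv_cache_bytes(num_layers, num_kv_heads, head_dim, seq_len, bytes_per_elem=2):
    """Total KV cache size for one request; the leading 2 covers keys and values."""
    return 2 * num_layers * num_kv_heads * head_dim * seq_len * bytes_per_elem


def transfer_ms(num_bytes, link_gbps):
    """Idealized transfer time over a link rated in gigabits per second."""
    bytes_per_sec = link_gbps * 1e9 / 8
    return num_bytes / bytes_per_sec * 1e3


size = kv_cache_bytes(num_layers=80, num_kv_heads=8, head_dim=128, seq_len=4096)
print(f"KV cache: {size / 1e9:.2f} GB")
for gbps in (100, 400):
    print(f"{gbps} Gbps: {transfer_ms(size, gbps):.1f} ms")
```

Because the size scales linearly with prompt length, a 32K-token prompt moves eight times as much data, so long-context workloads make the transfer cost harder to hide.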

KV Cache Management in Production

The KV cache is the largest per-request memory consumer. For a serving system handling hundreds of concurrent requests, KV cache management becomes a critical systems problem.

PagedAttention (from vLLM) is the breakthrough that made efficient KV cache management possible. Instead of allocating contiguous memory for each request's KV cache (which leads to fragmentation when requests have variable lengths), PagedAttention divides KV cache memory into fixed-size pages (typically 16 tokens per page):

python
# PagedAttention KV Cache Manager (simplified)
import torch


class PagedKVCacheManager:
    def __init__(self, num_pages: int, page_size: int = 16,
                 num_layers: int = 80, num_heads: int = 8,
                 head_dim: int = 128, device: str = "cuda"):
        self.page_size = page_size
        # Physical pages: [num_pages, 2, num_layers, num_heads, page_size, head_dim]
        # 2 = key and value
        self.physical_pages = torch.zeros(
            num_pages, 2, num_layers, num_heads, page_size, head_dim,
            dtype=torch.bfloat16, device=device
        )
        self.free_pages = list(range(num_pages))
        self.page_tables = {}  # request_id → [page_indices]
        self.seq_lens = {}     # request_id → number of tokens stored

    def get_seq_len(self, request_id: str) -> int:
        return self.seq_lens[request_id]

    def allocate(self, request_id: str, num_tokens: int) -> bool:
        num_pages_needed = (num_tokens + self.page_size - 1) // self.page_size
        if len(self.free_pages) < num_pages_needed:
            return False  # out of memory: caller must preempt a request
        pages = [self.free_pages.pop() for _ in range(num_pages_needed)]
        self.page_tables[request_id] = pages
        self.seq_lens[request_id] = num_tokens
        return True

    def append_token(self, request_id: str) -> bool:
        pages = self.page_tables[request_id]
        current_len = self.get_seq_len(request_id)
        # The last page is full exactly when the length is a multiple of page_size
        if current_len % self.page_size == 0:
            if not self.free_pages:
                return False  # no page available for the new token
            pages.append(self.free_pages.pop())
        self.seq_lens[request_id] = current_len + 1
        return True

    def free(self, request_id: str):
        pages = self.page_tables.pop(request_id, [])
        self.seq_lens.pop(request_id, None)
        self.free_pages.extend(pages)
Prefix caching. Many requests share common prefixes (system prompts, few-shot examples). PagedAttention enables prefix caching: the KV cache pages for the shared prefix are computed once and reused across requests through shared page table entries, with copy-on-write if a request needs to modify a shared page. For a system prompt of 2000 tokens, this saves a 2000-token prefill for every subsequent request. vLLM's automatic prefix caching can reduce TTFT by roughly 50-80% for workloads with shared prefixes; the gain depends on how much of each prompt is shared.
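
One common way to detect a shared prefix is to hash each full page of tokens, chaining in the hash of the previous page so that a page's identity depends on everything before it. The sketch below illustrates that idea only; it is not vLLM's implementation, and `block_hashes` and `shared_prefix_pages` are hypothetical names.

```python
import hashlib

PAGE_SIZE = 16  # tokens per page, matching the page size used above


def block_hashes(token_ids, page_size=PAGE_SIZE):
    """Hash every full page of tokens, chained with the previous page's hash."""
    hashes, prev = [], b""
    for start in range(0, len(token_ids) - page_size + 1, page_size):
        block = token_ids[start:start + page_size]
        payload = prev + ",".join(map(str, block)).encode()
        prev = hashlib.sha256(payload).digest()
        hashes.append(prev)
    return hashes


def shared_prefix_pages(tokens_a, tokens_b):
    """Number of leading pages whose chained hashes match."""
    count = 0
    for ha, hb in zip(block_hashes(tokens_a), block_hashes(tokens_b)):
        if ha != hb:
            break
        count += 1
    return count


system_prompt = list(range(2000))            # stand-in for a 2000-token system prompt
request_a = system_prompt + [9001, 9002]
request_b = system_prompt + [7, 8, 9]
pages = shared_prefix_pages(request_a, request_b)
print(pages, "pages shared =", pages * PAGE_SIZE, "tokens of prefill skipped")
```

Only full pages are shared in this sketch, so a partially filled final page is always recomputed.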

Latency Metrics: What You Measure, What You Optimize

| Metric | Definition | What It Measures | Target (chatbot) |
|---|---|---|---|
| TTFT | Time from request arrival to first output token | Prefill speed + queueing delay | < 500 ms (p95) |
| ITL | Inter-token latency: time between consecutive output tokens | Decode speed per token | < 50 ms (p95) for real-time feel |
| E2E Latency | Time from request to last token | TTFT + (output_len × ITL) | Depends on output length |
| Throughput | Total output tokens per second across all requests | System capacity | Maximize (cost efficiency) |
| TPS/user | Tokens per second experienced by one user | Perceived speed | > 20 TPS for "fast" feeling |
The throughput-latency tradeoff. Larger batches improve throughput but increase per-request latency (more contention for memory bandwidth). A serving system must navigate this tradeoff based on SLA requirements. For real-time chat: optimize for low ITL (small batches). For batch processing (summarizing 10K documents): optimize for throughput (large batches, higher latency acceptable).
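
The metrics in the table are linked by simple arithmetic, which a short sketch makes explicit. The helper names are illustrative, and the E2E formula is the approximation from the table.

```python
def e2e_latency_ms(ttft_ms: float, itl_ms: float, output_len: int) -> float:
    """End-to-end latency using the table's approximation TTFT + output_len * ITL."""
    return ttft_ms + output_len * itl_ms


def tps_per_user(itl_ms: float) -> float:
    """Tokens per second seen by a single user once decoding has started."""
    return 1000.0 / itl_ms


# At the chatbot targets: 500 ms TTFT, 50 ms ITL, 200 output tokens
print(e2e_latency_ms(500, 50, 200) / 1000, "seconds end to end")
print(tps_per_user(50), "tokens/sec per user")
```

The 50 ms ITL target and the 20 TPS "fast" threshold are the same constraint expressed two ways.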

The Cost Model: $/1M Tokens

Cloud providers and API companies price LLM inference in dollars per million tokens. Let us derive this cost from first principles for serving LLaMA 3 70B:

Cost Derivation for LLaMA 3 70B on H100:

Hardware: 4x H100 80GB (TP=4 for 70B), 1 server
Cost: ~$4/hr per H100 (cloud spot) × 4 = $16/hr

Decode throughput:
Single request: ~24 tokens/sec (memory-bandwidth limited, derived earlier)
With batch=32: 24 × 32 = ~768 tokens/sec × 4 GPUs × 0.7 efficiency = 2,150 tokens/sec
(0.7 accounts for prefill interference, scheduling overhead, KV cache memory)

Tokens per hour: 2,150 × 3600 = 7.74M tokens/hr

Cost per 1M tokens: $16 / 7.74 = $2.07/1M output tokens

Input tokens are cheaper: prefill processes ~10x more tokens/sec
Input cost ≈ $0.20/1M tokens

Compare to API pricing (2025):
GPT-4o: $2.50/1M input, $10.00/1M output
Claude 3.5 Sonnet: $3.00/1M input, $15.00/1M output
LLaMA 3.1 70B (Fireworks): $0.90/1M input, $0.90/1M output
Self-hosted estimate above ($2.07/1M output): roughly 5-7x below the GPT-4o and Claude 3.5 Sonnet output prices, but above the Fireworks hosted rate
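
The derivation above reduces to one formula, shown here as a minimal sketch. The throughput and price inputs are the estimates from this section, and the function name is illustrative.

```python
def cost_per_million_tokens(tokens_per_sec: float, num_chips: int,
                            dollars_per_chip_hour: float) -> float:
    """Dollars per 1M tokens for a server running at a steady throughput."""
    dollars_per_hour = num_chips * dollars_per_chip_hour
    millions_per_hour = tokens_per_sec * 3600 / 1e6
    return dollars_per_hour / millions_per_hour


output_cost = cost_per_million_tokens(2150, num_chips=4, dollars_per_chip_hour=4.0)
# Prefill is assumed to process about 10x more tokens per second than decode
input_cost = cost_per_million_tokens(2150 * 10, num_chips=4, dollars_per_chip_hour=4.0)
print(f"output: ${output_cost:.2f}/1M, input: ${input_cost:.2f}/1M")
```

The result is inversely proportional to sustained throughput, so a server that averages half the assumed load costs twice as much per token.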

Interactive: Continuous Batching Visualizer

Continuous Batching Lifecycle

Watch requests flow through a serving system. Green = prefilling, blue = decoding, gray = waiting. Compare static vs continuous batching utilization. Click "Add Request" to queue new work.


Worked Example: Serving LLaMA 3 70B on 8x TPU v5e

Let us derive the full serving profile for LLaMA 3 70B on a TPU v5e pod slice of 8 chips, using tensor parallelism.

Step 1: Model Memory
70B params × 2 bytes (BF16) = 140 GB
TPU v5e: 16 GB HBM per chip, 8 chips → 128 GB total
140 GB > 128 GB → Model does not fit in BF16!

Solution: INT8 weight quantization
70B × 1 byte = 70 GB. Across 8 chips: 70/8 = 8.75 GB per chip. ✓

Step 2: KV Cache Budget
Remaining HBM per chip: 16 - 8.75 = 7.25 GB
KV cache per token per layer: 2 × 8 × 128 × 2 bytes = 4 KB (GQA with 8 KV heads)
KV per token (all 80 layers): 80 × 4 KB = 320 KB
Per chip (TP=8 shards the 8 KV heads, 1 head per chip across all 80 layers): 320 KB / 8 = 40 KB per token

Max tokens in cache per chip: 7.25 GB / 40 KB = 190,000 tokens
With 4096-token avg sequence: ~46 concurrent requests max

Step 3: Decode Throughput
TPU v5e HBM bandwidth: 819 GB/s (INT8 weights halve the bytes read per token versus BF16)
Weight load per token: 70 GB / 8 chips = 8.75 GB per chip
Minimum time per token: 8.75 GB / 819 GB/s = 10.7 ms

With batch=32: all 32 requests share the weight load
Tokens per second: 32 / 10.7 ms = 2,991 tokens/sec total

Step 4: Cost
TPU v5e cost: ~$1.20/hr per chip × 8 = $9.60/hr
Tokens per hour: 2,991 × 3,600 = 10.77M tokens/hr
Cost: $0.89/1M output tokens

Compare to H100 ($2.07/1M) → TPU is 2.3x cheaper for this workload.
Why TPU v5e wins for serving. The TPU v5e has lower HBM bandwidth than H100 (819 vs 3,350 GB/s), but it also costs 3-4x less per chip. For memory-bandwidth-bound decode, the cost per token depends on bandwidth per dollar, not raw bandwidth. TPU v5e: 819 GB/s / $1.20 per hour ≈ 682 GB/s per $/hr. H100: 3,350 GB/s / $4.00 per hour ≈ 838 GB/s per $/hr. On a raw BW/$ basis, H100 is actually about 23% more cost-efficient. Part of the 2.3x gap in the worked example therefore comes from the setup rather than the silicon: the TPU estimate reads INT8 weights and omits the 0.7 efficiency factor applied in the H100 estimate. TPU v5e's lower per-chip cost also means you can use more chips, enabling larger batches and better amortization. The real winner depends on your precision, batch size, and SLA constraints.
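
The four steps of the worked example can be folded into one parameterized function, which makes it easy to see how sensitive the result is to batch size or precision. This is a sketch under the same simplifications (weights dominate memory traffic, no efficiency factor). Binary units are assumed for the KV budget because that reproduces the ~190,000-token figure. All names are illustrative.

```python
GB = 2 ** 30   # assumption: binary units, which reproduce the figures above
KB = 2 ** 10


def serving_profile(params_b, bytes_per_param, chips, hbm_gb, hbm_bw_gbs,
                    kv_bytes_per_token, avg_seq_len, batch, dollars_per_chip_hour):
    weights_gb = params_b * bytes_per_param / chips      # weight shard per chip
    kv_budget_gb = hbm_gb - weights_gb                   # HBM left for KV cache
    kv_per_chip = kv_bytes_per_token / chips             # TP shards the KV heads
    max_tokens = kv_budget_gb * GB / kv_per_chip
    step_s = weights_gb / hbm_bw_gbs                     # each decode step reads every local weight
    tokens_per_sec = batch / step_s
    cost = chips * dollars_per_chip_hour / (tokens_per_sec * 3600 / 1e6)
    return {
        "max_requests": int(max_tokens // avg_seq_len),
        "step_ms": step_s * 1e3,
        "tokens_per_sec": tokens_per_sec,
        "cost_per_million": cost,
    }


profile = serving_profile(params_b=70, bytes_per_param=1, chips=8, hbm_gb=16,
                          hbm_bw_gbs=819, kv_bytes_per_token=320 * KB,
                          avg_seq_len=4096, batch=32, dollars_per_chip_hour=1.20)
for name, value in profile.items():
    print(f"{name}: {value:.2f}")
```

The throughput comes out a few tokens per second above the hand calculation because the step time is not rounded to 10.7 ms first.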

Request Preemption and Priority Scheduling

When KV cache memory runs out, the scheduler must decide: reject new requests (increasing queue wait time) or preempt an existing request (evict its KV cache, restart it later). Sophisticated schedulers use priority-based preemption:

python
# Priority preemption: evict lowest-priority request when memory is full
def preempt_if_needed(self, new_request):
    """Evict a running request to make room for a higher-priority one."""
    if self.kv_cache.available() >= new_request.est_tokens:
        return True  # enough memory, no preemption needed

    # Find lowest-priority running request
    candidates = sorted(
        self.running,
        key=lambda r: (r.priority, -r.tokens_generated)
    )
    # Evict requests with LOWER priority than the new one
    for victim in candidates:
        if victim.priority >= new_request.priority:
            break  # don't preempt equal or higher priority

        # Swap out: save KV cache to CPU memory (or discard + recompute later)
        self.kv_cache.swap_out(victim.request_id)
        self.running.remove(victim)
        victim.state = "preempted"
        victim.preempt_count += 1
        self.waiting.insert(0, victim)  # re-queue at front

        if self.kv_cache.available() >= new_request.est_tokens:
            return True

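    # Caveat: victims evicted above stay preempted even though the new request
    # still does not fit; a production scheduler would check feasibility first.
    return False  # cannot make room even after preemption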

Quantized KV Cache: FP8 and INT4

The KV cache is a major memory consumer during serving. Quantizing it from BF16 to FP8 or even INT4 can double or quadruple the number of concurrent requests:

KV Cache Memory Savings:

LLaMA 3 70B, 80 layers, 8 KV heads, d_head=128:

BF16: 2 bytes per element
Per token: 2 × 80 × 8 × 128 × 2 = 327,680 bytes = 320 KB

FP8 (E4M3): 1 byte per element
Per token: 2 × 80 × 8 × 128 × 1 = 163,840 bytes = 160 KB
Savings: 50%, accuracy loss: <0.1% on most benchmarks

INT4 (per-group quantized, group_size=128):
Per token: 2 × 80 × 8 × 128 × 0.5 + scales_overhead ≈ 85 KB
Savings: 73%, accuracy loss: 0.5-2% (task-dependent)

For latency-sensitive serving, FP8 KV cache is the best tradeoff.
For throughput-optimized batch serving, INT4 KV is acceptable.
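
The per-token figures above follow from one formula. In the sketch below, the INT4 overhead is modeled as one FP16 scale plus one FP16 zero-point per group of 128 elements. That metadata layout is an assumption, chosen because it reproduces the ≈ 85 KB figure; real kernels differ.

```python
def kv_bytes_per_token(layers, kv_heads, head_dim, bits,
                       group_size=None, metadata_bytes_per_group=4):
    """Bytes of KV cache per token, including quantization metadata if grouped."""
    elems = 2 * layers * kv_heads * head_dim      # 2 = keys and values
    data = elems * bits // 8
    metadata = (elems // group_size) * metadata_bytes_per_group if group_size else 0
    return data + metadata


bf16 = kv_bytes_per_token(80, 8, 128, bits=16)
fp8 = kv_bytes_per_token(80, 8, 128, bits=8)
int4 = kv_bytes_per_token(80, 8, 128, bits=4, group_size=128)

for name, size in [("BF16", bf16), ("FP8", fp8), ("INT4", int4)]:
    print(f"{name}: {size / 1024:.1f} KB/token, saves {1 - size / bf16:.0%}")
```

Multiplying the concurrency limit by `bf16 / size` gives the capacity gain for a fixed KV memory budget.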

Multi-Turn Conversation Optimization

In a chatbot, users send multiple messages in a conversation. Each new message requires attending to the full conversation history. Without optimization, every user message triggers a complete prefill of the entire conversation — wasting compute on tokens already processed.

The solution is KV cache persistence: keep the KV cache from previous turns alive between requests. When the user sends a new message, only the new tokens need prefill. This requires:

| Challenge | Solution | Trade-off |
|---|---|---|
| KV cache occupies memory between turns | TTL-based eviction (expire after 5-30 min idle) | Memory vs latency for returning users |
| User may never return | LRU eviction when memory pressure is high | Cold-start penalty for evicted users |
| Long conversations grow unboundedly | Sliding window attention or summary compression | Context accuracy vs memory |
| Scaling across servers | Session affinity (route user to same server) | Load balancing flexibility |
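
The first two rows of the table, TTL expiry and LRU eviction, combine naturally in one bookkeeping structure. The sketch below tracks only token counts, not actual tensors. `SessionKVStore` and its methods are hypothetical names, not part of any serving framework.

```python
from collections import OrderedDict


class SessionKVStore:
    """Tracks which sessions still have KV cache resident (token counts only)."""

    def __init__(self, capacity_tokens: int, ttl_s: float):
        self.capacity = capacity_tokens
        self.ttl = ttl_s
        self.sessions = OrderedDict()   # session_id -> (num_tokens, last_used), oldest first
        self.used = 0

    def _evict(self, session_id):
        tokens, _ = self.sessions.pop(session_id)
        self.used -= tokens

    def touch(self, session_id, total_tokens: int, now: float) -> int:
        """Record a new turn and return how many tokens still need prefill."""
        # TTL eviction: drop sessions idle for longer than ttl
        expired = [s for s, (_, t) in self.sessions.items() if now - t > self.ttl]
        for s in expired:
            self._evict(s)
        cached = 0
        if session_id in self.sessions:
            cached = self.sessions[session_id][0]
            self._evict(session_id)     # re-inserted below as most recently used
        # LRU eviction: make room by dropping the least recently used sessions
        while self.sessions and self.used + total_tokens > self.capacity:
            self._evict(next(iter(self.sessions)))
        self.sessions[session_id] = (total_tokens, now)
        self.used += total_tokens
        return total_tokens - cached


store = SessionKVStore(capacity_tokens=10_000, ttl_s=600)
print(store.touch("alice", 3000, now=0))      # cold start: full prefill
print(store.touch("alice", 3200, now=60))     # warm: only the new tokens
print(store.touch("alice", 3400, now=700))    # idle past the TTL: full prefill again
```

Session affinity, the last row, is what makes this cache useful at all: the user must land on the server that holds their pages.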
A serving system handles a 70B model on 4x H100s with TP=4. During a traffic spike, TTFT (time to first token) increases from 200ms to 3 seconds, but ITL (inter-token latency) stays constant at 40ms. A junior engineer says "the GPUs are overloaded." What is actually happening, and what would you check?

Chapter 11: LLM-Guided Search

Every system we have studied so far uses LLMs as standalone generators: prompt in, text out. But the most exciting frontier in 2025 is using LLMs as components inside larger optimization loops. The LLM does not solve the problem — it proposes candidate solutions that an automated evaluator scores, and the best candidates are fed back to the LLM to produce better proposals. This is LLM-guided search.

This pattern has already produced genuine scientific discoveries: new algorithms for matrix multiplication, state-of-the-art solutions to open combinatorics problems, and improvements to real-world infrastructure systems. It is the closest thing we have to "AI doing research."

The General Pattern

Every LLM-guided search system follows the same four-step loop:

1. Specification
Define the problem formally: what is the input, what is the output, and what makes one output better than another. This must be machine-evaluable — no human judgment in the loop.
2. LLM Generation
Prompt the LLM with the specification, the best solutions found so far, and (optionally) a description of what has been tried. The LLM proposes new candidate solutions — typically as code.
3. Automated Evaluation
Run the candidate code in a sandbox. Measure its performance on a test suite. Score it. This step must be fully automatic and fast — thousands of evaluations per hour.
4. Selection + Feedback
Keep the best candidates. Discard the rest. Feed the best back into step 2 as examples. Repeat for hundreds or thousands of generations.
↻ repeat
Why this works. LLMs are remarkably good at making small, creative modifications to existing code. They are bad at solving hard problems from scratch. LLM-guided search exploits this asymmetry: the LLM does not need to find the optimal solution — it just needs to propose variations that are occasionally better than the current best. The automated evaluator handles the hard part (judging quality), and selection pressure handles the optimization.

FunSearch: LLM as Mutation Operator

Google DeepMind's FunSearch (2023) was the first system to demonstrate that LLMs can make genuine mathematical discoveries. The name stands for "searching in the space of functions" — the LLM generates Python functions, and an evaluator scores them.

FunSearch tackled the cap set problem, a notorious open problem in combinatorics. A cap set is a subset of points in F_3^n (the n-dimensional space over the field with 3 elements) such that no three points are collinear. The question: how large can a cap set be?

The largest cap set previously known in F_3^8 had 496 points. FunSearch found one with 512 points, using a construction that humans had not considered. It also improved the best known lower bound on how fast cap sets can grow as n increases.

Worked Example: The Cap Set Problem

To understand what FunSearch actually discovered, we need to understand the problem it solved. A cap set in F_3^n is a subset S of {0, 1, 2}^n such that no three distinct elements a, b, c in S satisfy a + b + c = 0 (mod 3). In other words, no three points form an arithmetic progression (equivalently, no three lie on a line).

Cap Set Example in F_3^2 (2D, easy to visualize):

All 9 points: {(0,0), (0,1), (0,2), (1,0), (1,1), (1,2), (2,0), (2,1), (2,2)}

Candidate set: {(0,0), (0,1), (1,0), (1,1), (2,2)} (5 points)

Check: do any 3 distinct points sum to (0,0) mod 3?
(0,0) + (1,1) + (2,2) = (3,3) = (0,0) mod 3 → Yes! These 3 are collinear.
So (0,0), (1,1), (2,2) cannot all be in a cap set. Remove (2,2).

Corrected cap set: {(0,0), (0,1), (1,0), (1,1)} (4 points)
Verify: the four possible triples sum to (1,1), (1,2), (2,1), (2,2); none is (0,0) mod 3. ✓
Maximum cap set size in F_3^2: 4, so this set is optimal

The challenge grows exponentially with n:
F_3^3: 27 points, max cap set = 9
F_3^6: 729 points, max cap set = 112
F_3^8: 6,561 points, best known = 512 (found by FunSearch; previously 496)
F_3^n for large n: best lower bound ≈ 2.22^n (FunSearch), upper bound ≈ 2.756^n (Ellenberg-Gijswijt)

FunSearch does not search directly for cap sets. Instead, it evolves a construction function — a Python function that takes the dimension n and returns a set of points. The evaluator checks (a) that the returned set is actually a cap set (no collinear triples), and (b) how large it is. Larger = better score.
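
The evaluator's two checks are easy to write down. The sketch below is a brute-force version for small dimensions, not the evaluator used by FunSearch. It tests every triple, so it is only practical for small sets. The function names are illustrative.

```python
from itertools import combinations, product


def is_cap_set(points) -> bool:
    """True if no three distinct points sum to zero mod 3 in every coordinate."""
    for a, b, c in combinations(set(points), 3):
        if all((x + y + z) % 3 == 0 for x, y, z in zip(a, b, c)):
            return False
    return True


def score(points) -> int:
    """Evaluator score: the set's size if it is a valid cap set, otherwise 0."""
    return len(set(points)) if is_cap_set(points) else 0


print(score([(0, 0), (0, 1), (1, 0), (1, 1)]))   # a valid cap set in F_3^2
print(score([(0, 0), (1, 1), (2, 2)]))           # three collinear points

# Exhaustive search is feasible in 2D: only 2^9 = 512 subsets exist
plane = list(product(range(3), repeat=2))
largest = max(len(s) for r in range(1, 10)
              for s in combinations(plane, r) if is_cap_set(s))
print("largest cap set in F_3^2:", largest)
```

The same exhaustive search in F_3^8 would have to consider 2^6561 subsets, which is the motivation for searching over programs instead.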

The insight: search over programs, not solutions. Searching directly over subsets of F_3^8 would be intractable: there are 2^6561 possible subsets. But searching over programs that generate subsets is tractable because the LLM can express structured constructions (symmetry-based, layered, recursive) that implicitly encode huge sets. The program is a compressed representation of the solution.

FunSearch's Island Model

FunSearch does not use a single population of solutions. It maintains multiple islands — independent populations that evolve separately. Periodically, the best solution from one island migrates to another. This prevents premature convergence (all candidates becoming too similar).

FunSearch Island Architecture:

K islands, each with a population of N best solutions
Each island has its own "programs database" (ranked by score)

Per iteration:
1. Pick an island i uniformly at random
2. Sample 2-3 high-scoring programs from island i (tournament selection)
3. Prompt the LLM with these programs + the problem specification
4. LLM generates a new candidate program
5. Evaluate the candidate in a sandbox
6. If score > worst in island i's database: insert, evict worst

Every T iterations: migrate best program from random island to another

Key hyperparameters:
K = 10-20 islands
N = 50-100 programs per island
T = 100 iterations between migrations
Temperature = 1.0-1.2 (high creativity)
Why islands matter. Without islands, the LLM quickly converges to making tiny variations of one good solution. With islands, different populations explore different regions of the solution space. When a breakthrough in island 3 migrates to island 7, it cross-pollinates with island 7's different approach, sometimes producing solutions better than either parent. This is directly inspired by island-model evolutionary algorithms from the 1990s.

FunSearch Applied: Online Bin Packing

Beyond pure mathematics, FunSearch discovered new heuristics for online bin packing — a classic NP-hard optimization problem with real-world applications (packing containers, scheduling jobs, memory allocation).

The problem: items of various sizes arrive one at a time. You must place each item into a bin (capacity 1.0) immediately, without seeing future items. The goal: minimize the total number of bins used.

The standard online baselines are First Fit and Best Fit (First Fit Decreasing sorts the items first, so it is an offline algorithm). FunSearch discovered a new priority function that outperformed both baselines on standard benchmarks. The discovered function was non-obvious, with specific breakpoints that no human had designed.

python
# Illustrative sketch in the style of a FunSearch-evolved heuristic (simplified;
# the thresholds are examples, not the published function)
def priority(item_size: float, bin_remaining: float) -> float:
    """Score for placing item into a bin. Higher = preferred."""
    waste = bin_remaining - item_size
    if waste < 0:
        return -1e9  # does not fit
    if waste < 0.01:
        return 1e6   # nearly perfect fit — always prefer
    if waste < 0.1:
        return 500 - waste * 1000
    if waste < 0.33:
        return 100 - waste * 200
    # Key insight: prefer bins that leave specific "useful" gaps
    # (gaps that commonly-sized future items could fill)
    if 0.49 < waste < 0.51:
        return 80   # half-empty bins are useful
    if 0.33 <= waste < 0.35:  # waste below 0.33 was already handled above
        return 60   # one-third gaps are useful
    return 10 - waste * 5

AlphaEvolve: FunSearch's Successor

Google DeepMind's AlphaEvolve (2025) extends FunSearch in several critical ways:

| Feature | FunSearch | AlphaEvolve |
|---|---|---|
| Code scope | Evolves a single function | Evolves entire multi-file programs |
| Objectives | Single scalar score | Multi-objective (Pareto frontier) |
| LLM | One model (Codey/PaLM) | Ensemble: Gemini Flash (fast exploration) + Gemini Pro (deep reasoning) |
| Evaluation | Sandbox execution | Sandbox + formal verification where possible |
| Context | Best programs + spec | Best programs + spec + diff history + failure analysis |

AlphaEvolve's most famous result: discovering a new algorithm for 4×4 complex matrix multiplication that uses fewer scalar multiplications than Strassen's algorithm. Specifically, it found a decomposition that uses 48 multiplications for 4×4 matrices, improving on the 49 previously known. This is a genuine mathematical discovery — the first improvement in this specific problem in decades.

AlphaEvolve in production at Google. Beyond pure math, AlphaEvolve is deployed inside Google's infrastructure. It found a data center scheduling heuristic that recovers about 0.7% of Google's fleet-wide compute on average. At that scale, 0.7% is a very large absolute saving. It also improved the Gemini training pipeline itself: a better tiling heuristic for a matrix multiplication kernel sped that kernel up by 23% and cut overall Gemini training time by about 1%.

AlphaEvolve: A Deeper Look at the Architecture

AlphaEvolve's architecture is worth studying in detail because it represents the state of the art in LLM-guided search as of 2025. Let us trace the lifecycle of a single evolution step:

AlphaEvolve Evolution Step (detailed):

1. Problem Context Assembly
  - Full problem specification (natural language + formal constraints)
  - Current best programs (top 5, with scores)
  - Diff history: last 10 accepted mutations with annotations
  - Failure log: last 5 rejected mutations with error analysis
  - Domain hints: "try varying the inner loop structure" or "explore piecewise functions"

2. LLM Selection
  - 80% of iterations: Gemini Flash (fast, high temperature=1.2, broad exploration)
  - 20% of iterations: Gemini Pro (slow, lower temperature=0.7, deeper reasoning)
  - Switch to Pro exclusively when best score has not improved in 100 iterations

3. Multi-File Generation
  - Unlike FunSearch (one function), AlphaEvolve generates patches to a codebase
  - The LLM outputs a unified diff that modifies multiple files
  - A syntax checker validates the diff before evaluation

4. Multi-Objective Evaluation
  - Primary objective: correctness (does the program solve the problem?)
  - Secondary objectives: efficiency (runtime), simplicity (code length)
  - Pareto dominance: a new solution is accepted if it improves on the current
    best in at least one objective without worsening any other
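
The Pareto acceptance rule in step 4 can be sketched in a few lines. This is a generic illustration of Pareto dominance, not AlphaEvolve's code. It assumes every objective is oriented so that higher is better, so runtime and code length are stored negated.

```python
def dominates(a, b) -> bool:
    """a dominates b: at least as good on every objective, strictly better on one."""
    return all(x >= y for x, y in zip(a, b)) and any(x > y for x, y in zip(a, b))


def update_front(front, candidate):
    """Return the Pareto front after considering one new candidate."""
    if any(dominates(f, candidate) or f == candidate for f in front):
        return front                                   # candidate adds nothing new
    survivors = [f for f in front if not dominates(candidate, f)]
    return survivors + [candidate]


# Objectives: (correctness, -runtime_ms, -lines_of_code)
front = []
for cand in [(1.0, -10.0, -50.0),    # first solution
             (1.0, -8.0, -50.0),     # faster, same length: replaces the first
             (1.0, -12.0, -30.0),    # slower but shorter: a genuine trade-off
             (0.9, -12.0, -60.0)]:   # worse everywhere: rejected
    front = update_front(front, cand)
print(front)
```

Keeping the whole front rather than a single best program gives the LLM parents that trade off objectives in different ways.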
The Strassen discovery in detail. Multiplying two n×n matrices takes n^3 scalar multiplications with the naive algorithm. Strassen's 1969 algorithm multiplies 2×2 blocks with 7 multiplications instead of 8, giving O(n^2.807). Applied recursively to 4×4 matrices, it costs 7 × 7 = 49 multiplications, and that remained the best known count. AlphaEvolve searched over tensor decompositions (factorizations of the matrix multiplication tensor into rank-1 terms). After 200,000+ generations over several days, it found a decomposition that multiplies 4×4 complex-valued matrices with only 48 multiplications, the first improvement on Strassen in this setting in 56 years. The discovered decomposition uses non-intuitive coefficients that are hard to arrive at by hand.
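
To make the baseline concrete, here is Strassen's original 2×2 scheme with its seven products written out, plus the multiplication count when it is applied recursively. This shows the 49-multiplication baseline only; the 48-multiplication decomposition found by AlphaEvolve is not reproduced here.

```python
def strassen_2x2(A, B):
    """Multiply two 2x2 matrices with 7 multiplications instead of 8."""
    (a11, a12), (a21, a22) = A
    (b11, b12), (b21, b22) = B
    m1 = (a11 + a22) * (b11 + b22)
    m2 = (a21 + a22) * b11
    m3 = a11 * (b12 - b22)
    m4 = a22 * (b21 - b11)
    m5 = (a11 + a12) * b22
    m6 = (a21 - a11) * (b11 + b12)
    m7 = (a12 - a22) * (b21 + b22)
    return [[m1 + m4 - m5 + m7, m3 + m5],
            [m2 + m4, m1 - m2 + m3 + m6]]


def strassen_mults(n: int) -> int:
    """Scalar multiplications for an n x n product (n a power of 2), applied recursively."""
    return 1 if n == 1 else 7 * strassen_mults(n // 2)


print(strassen_2x2([[1, 2], [3, 4]], [[5, 6], [7, 8]]))
print("4x4: naive", 4 ** 3, "vs Strassen", strassen_mults(4))
```

Each entry can itself be a matrix block, which is what lets the scheme recurse and why saving a single multiplication at a small size compounds at larger sizes.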

Beyond FunSearch: Other LLM-as-Optimizer Systems

FunSearch and AlphaEvolve are the most famous, but the LLM-guided search paradigm has spawned many systems:

| System | Year | Lab | Domain | Key Innovation |
|---|---|---|---|---|
| FunSearch | 2023 | DeepMind | Math, combinatorics | Island model + LLM mutation |
| AlphaEvolve | 2025 | DeepMind | Math, infrastructure | Multi-file, multi-objective, LLM ensemble |
| EvoPrompt | 2023 | Tsinghua / Microsoft | Prompt optimization | Evolve prompts instead of code |
| ReEvo | 2024 | Academic | Combinatorial optimization | LLM reflects on failed attempts |
| Eureka | 2023 | NVIDIA | Robot control | LLM evolves reward functions |
| AIDE | 2024 | Weco AI | ML research | LLM writes + runs experiments in a tree search |

The common thread: an LLM generates candidates in a domain where automated evaluation is possible, and selection pressure drives improvement over hundreds or thousands of generations.

The AutoResearch Pattern

Andrej Karpathy's concept of Autoresearch (2024-2025) extends LLM-guided search to the full research workflow. Instead of searching for a single function, the system:

1. Literature Review
LLM reads papers, identifies gaps, proposes hypotheses. Automated: arxiv search + paper parsing + claim extraction.
2. Experiment Design
LLM writes experiment code: model architecture, training loop, evaluation metrics. Follows templates from prior successful experiments.
3. Run Experiments
Execute on a compute cluster. Fully automated: launch job, monitor training, collect metrics, detect failures.
4. Analyze Results
LLM reads training logs, identifies trends, compares to baselines, generates plots. Proposes next experiments based on findings.
↻ repeat
When LLM search works (and when it does not). The pattern succeeds when: (1) evaluation is fully automatable (code runs, score appears — no human judgment needed), (2) the search space is vast (many possible programs), (3) small improvements compound (0.1% better per generation × 1000 generations = significant), and (4) the LLM can understand the domain well enough to make non-random proposals. It fails when: evaluation requires human judgment (aesthetics, safety), the search space is tiny (few valid programs), or the domain is so novel that the LLM's training data contains no relevant patterns.

Minimal FunSearch-Style Loop

Here is a compact implementation of the core FunSearch loop for evolving a bin packing priority function:

python
import random
import subprocess
import json
from openai import OpenAI

client = OpenAI()

# ── Problem: evolve a priority function for bin packing ──
SPEC = """You are evolving a Python function called `priority(item_size, bin_remaining)`
that returns a float score. Higher score = preferred bin for the item.
The function is used to solve online bin packing: items arrive one at a time,
each must be placed in a bin (capacity 1.0) immediately.
Goal: minimize total bins used across the test suite."""

EVAL_HARNESS = """
import json
import random
random.seed(42)

# Generate test instances
def make_instance(n=100):
    return [random.uniform(0.1, 0.9) for _ in range(n)]

def evaluate(priority_fn, instances):
    total_bins = 0
    for items in instances:
        bins = []  # list of remaining capacities
        for item in items:
            scores = [(priority_fn(item, b), i) for i, b in enumerate(bins) if b >= item]
            if scores:
                _, best_i = max(scores)
                bins[best_i] -= item
            else:
                bins.append(1.0 - item)
        total_bins += len(bins)
    return -total_bins  # negative because lower bins = better, but we maximize score

instances = [make_instance(200) for _ in range(20)]
{candidate_code}
score = evaluate(priority, instances)
print(json.dumps({"score": score}))
"""

class Island:
    def __init__(self, capacity=50):
        self.programs = []  # (score, code) sorted by score descending
        self.capacity = capacity

    def insert(self, score, code):
        self.programs.append((score, code))
        self.programs.sort(key=lambda x: x[0], reverse=True)
        if len(self.programs) > self.capacity:
            self.programs.pop()

    def sample_best(self, k=2):
        # Truncation selection: sample parents from the top third
        top = self.programs[:max(1, len(self.programs) // 3)]
        return random.sample(top, min(k, len(top)))

def evaluate_candidate(code: str, timeout: int = 30) -> float:
    """Run candidate in a subprocess with a timeout and return its score.

    Note: a bare subprocess is not a real sandbox; isolate untrusted code properly.
    """
    full_code = EVAL_HARNESS.replace("{candidate_code}", code)
    try:
        result = subprocess.run(
            ["python3", "-c", full_code],
            capture_output=True, text=True, timeout=timeout
        )
        return json.loads(result.stdout)["score"]
    except Exception:
        return -1e9  # crashed, timed out, or printed no valid JSON

def generate_candidate(parents: list) -> str:
    """Ask LLM to evolve a new candidate from parent programs."""
    parent_text = "\n\n".join(
        f"# Score: {s}\n{c}" for s, c in parents
    )
    response = client.chat.completions.create(
        model="gpt-4o",
        temperature=1.0,
        messages=[{
            "role": "user",
            "content": f"""{SPEC}

Here are the best solutions found so far:
{parent_text}

Write an improved `priority` function. Be creative — try a
different mathematical form. Output ONLY the Python function."""
        }]
    )
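    text = response.choices[0].message.content or ""
    # Models often wrap code in markdown fences; drop those lines so the harness can run it
    lines = [ln for ln in text.splitlines() if not ln.strip().startswith("```")]
    return "\n".join(lines)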

# ── Main evolution loop ──
NUM_ISLANDS = 5
islands = [Island() for _ in range(NUM_ISLANDS)]

# Seed with a simple baseline
seed = """def priority(item_size, bin_remaining):
    if bin_remaining < item_size:
        return -1e9
    return -(bin_remaining - item_size)  # best fit: prefer the tightest bin"""
seed_score = evaluate_candidate(seed)
for island in islands:
    island.insert(seed_score, seed)

best_ever = seed_score
for gen in range(500):
    # Pick a random island
    island = random.choice(islands)
    parents = island.sample_best(2)

    # Generate and evaluate
    candidate = generate_candidate(parents)
    score = evaluate_candidate(candidate)
    island.insert(score, candidate)

    if score > best_ever:
        best_ever = score
        print(f"Gen {gen}: NEW BEST = {score}")

    # Migrate every 50 generations
    if gen % 50 == 0 and gen > 0:
        src, dst = random.sample(islands, 2)
        if src.programs:
            dst.insert(*src.programs[0])

When LLM Search Fails: The Antipatterns

Not every problem benefits from LLM-guided search. Understanding the failure modes is as important as understanding the successes:

| Antipattern | Why It Fails | Example |
|---|---|---|
| Non-automatable evaluation | If a human must judge quality, the loop is too slow (hours per generation vs seconds) | Generating creative writing, designing UIs, composing music |
| Tiny search space | If there are only a few valid programs, random search works as well as LLM-guided | Choosing between 3 sorting algorithms |
| Deceptive fitness landscape | If small changes cause large score swings, the LLM cannot learn meaningful gradients | Cryptographic hash optimization |
| Insufficient LLM knowledge | If the domain has no representation in training data, proposals are random | Novel chemistry synthesis (pre-2024 data cutoff) |
| Evaluation is too expensive | If each evaluation takes hours (e.g., training a model), you get too few generations | Neural architecture search with full training |

The Prompt Engineering of LLM Search

The quality of the LLM's proposals depends critically on how you prompt it. FunSearch and AlphaEvolve both discovered that specific prompt structures dramatically improve the mutation quality:

Show the diff, not just the code. Including the changes between parent generations (what was modified, what improved) helps the LLM understand the direction of productive mutations. AlphaEvolve includes a "diff history" of recent improvements, annotated with which changes caused score increases.
Include failure analysis. When a candidate scores poorly, include a brief explanation of why (e.g., "timeout — function took 30x longer than baseline"). This steers the LLM away from known failure modes. Without this, the LLM will repeatedly propose the same broken approaches.
Temperature matters enormously. Low temperature (0.2-0.5) produces conservative mutations — small tweaks to existing code. High temperature (1.0-1.5) produces creative leaps — entirely new approaches. The optimal strategy is to use high temperature early (explore) and lower it as the best score plateaus (exploit). AlphaEvolve uses Gemini Flash at high temperature for broad exploration and Gemini Pro at lower temperature for targeted refinement.

Building Your Own LLM Search System

If you want to apply LLM-guided search to your own problem, here is the engineering checklist:

1. Define the Evaluation Function
This is the hardest part. It must: (a) return a scalar score, (b) run in under 30 seconds, (c) be deterministic (same code → same score), (d) have a smooth-ish landscape (small code changes → small score changes). If you cannot build this, LLM search will not work for your problem.
2. Build the Sandbox
LLM-generated code can be malicious, buggy, or infinite-looping. Run it in a sandboxed environment with: CPU/memory limits, network isolation, filesystem restrictions, and a hard timeout. Docker containers or gVisor work well.
3. Choose the Representation
What does the LLM generate? A single function is simplest (FunSearch). A config file is easier to validate. A multi-file program is most expressive but hardest to evaluate (AlphaEvolve). Start with the simplest representation that can express your solution space.
4. Implement the Population
Start with 5 islands of 30 programs each. Use tournament selection (pick from top 30%). Migrate every 50 generations. Seed each island with a different baseline heuristic if possible.
5. Scale and Monitor
Run 1,000-10,000 generations. Log every candidate's score, even failed ones. Plot best-score-over-time per island. If all islands converge to the same score, increase temperature or add more diverse seeds.

Interactive: Evolutionary Loop Visualizer

LLM-Guided Evolution Simulator

Watch islands of candidate solutions evolve over generations. Each dot is a candidate program. Color = island. Height = fitness score. Click "Evolve" to run one generation. Watch how migration spreads breakthroughs across islands.

FunSearch uses multiple "islands" that evolve independently with periodic migration. Why not just use a single population with a larger size? What failure mode does the island model prevent?

Chapter 12: The Scaling Book — TPU/GPU Rooflines

You have written the training loop. You have configured the parallelism. The job is running. But the MFU reads 22%. Nearly 80% of the cluster's theoretical capability is wasted. Where is the performance going?

To answer this question, you need to understand how the hardware actually works — not at the CUDA-kernel level (Chapter 2 covered that), but at the architectural level: how data flows through the chip, where bottlenecks form, and how to predict performance before writing a single line of code. This is the roofline model, and it is the single most useful tool in a performance engineer's toolkit.

GPU Architecture: The H100

An NVIDIA H100 SXM is organized as a hierarchy of compute and memory:

| Component | Specification | Role |
| --- | --- | --- |
| Streaming Multiprocessors (SMs) | 132 SMs | Each SM is an independent processor with its own registers, shared memory, and warp schedulers |
| Tensor Cores (per SM) | 4 fourth-gen Tensor Cores | Matrix multiply accelerators. Each performs 16×16 matrix ops per cycle in BF16. |
| Peak BF16 FLOPS | 989 TFLOP/s | Theoretical maximum with all tensor cores active |
| HBM3 Memory | 80 GB, 3.35 TB/s bandwidth | Main memory. All model weights, KV cache, activations live here. |
| L2 Cache | 50 MB | Shared across all SMs. Caches frequently accessed HBM data. |
| Shared Memory (per SM) | 228 KB | Programmer-managed scratchpad. Faster than L2, used for tiling. |
| Register File (per SM) | 256 KB | Fastest storage. Each thread has up to 255 registers. |
| NVLink | 900 GB/s (bidirectional) | GPU-to-GPU interconnect within a node (8 GPUs) |
The memory hierarchy speed ladder. Registers: ~19 TB/s effective per SM. Shared memory: ~15 TB/s per SM. L2 cache: ~12 TB/s aggregate. HBM: 3.35 TB/s. NVLink: 0.9 TB/s. InfiniBand: 0.05 TB/s. The on-chip levels sit within 2x of each other, but every step from L2 outward costs roughly 4-18x in bandwidth. The art of GPU programming is keeping data as high in this hierarchy as possible.

TPU Architecture: The v5e

Google's TPU v5e is a fundamentally different design philosophy. Where GPUs evolved from graphics (many small cores, complex scheduling), TPUs were built from scratch for matrix multiplication.

| Component | Specification | Role |
| --- | --- | --- |
| MXU (Matrix Multiply Unit) | 4 MXUs per chip | Systolic array of 128×128 BF16 multiply-accumulate units. Each MXU performs 128 × 128 = 16,384 multiply-accumulates per cycle. |
| Peak BF16 FLOPS | 197 TFLOP/s | Lower than H100, but higher utilization in practice |
| HBM Memory | 16 GB, 819 GB/s bandwidth | Smaller and slower than H100. Constrains model size per chip. |
| VMEM (Vector Memory) | 16-32 MB | On-chip SRAM. Equivalent to GPU shared memory, but larger. |
| VPU (Vector Processing Unit) | 1 per chip | Handles non-matmul operations: activations, normalization, softmax |
| ICI (Inter-Chip Interconnect) | ~400 GB/s per link, 6 links | Direct chip-to-chip connection in a 3D torus topology. No switch needed. |

Systolic Arrays: How TPU MXUs Work

A systolic array is a grid of processing elements (PEs) where data flows rhythmically through the grid, like a heartbeat (hence "systolic"). Each PE performs one multiply-accumulate per cycle and passes the result to its neighbor.

For a 128×128 systolic array computing C = A × B:

Systolic Array Data Flow:

Matrix A rows flow left-to-right, staggered by one cycle per row.
Matrix B columns flow top-to-bottom, staggered by one cycle per column.
Each PE(i,j) accumulates: C[i,j] += A[i,k] × B[k,j] for each k

For 128×128 × 128×128 matmul:
Total multiply-accumulates: 128^3 = 2,097,152
PEs in array: 128 × 128 = 16,384
Cycles needed: 128 (pipeline fill) + 128 (compute) + 128 (drain) = 384 cycles
But steady-state: 128 cycles for the core computation

Utilization: 128^3 useful MACs / (16,384 PEs × 384 cycles) = 33.3%
(The fill and drain phases waste cycles. Larger matrices amortize this.)

For 4096×4096 × 4096×4096 (tiled into 128×128 blocks):
Tiles: 32×32×32 = 32,768 tiles
Each tile: 384 cycles, but pipelined: ~128 cycles steady-state
Effective utilization: >90%
Why TPUs love large matmuls. The systolic array has a fixed 128×128 tile size. If your matmul dimensions are not multiples of 128, you waste PEs on padding. A [127, 127] × [127, 127] matmul occupies a full 128×128 array but only uses 127/128 = 99.2% of each dimension, so 99.2% × 99.2% = 98.4% of the PEs do useful work. A [65, 65] matmul uses 50.8% × 50.8% = 25.8% of the array. This is why TPU-optimized models use hidden dimensions that are multiples of 128.
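The padding and fill/drain arithmetic above can be checked with two small helpers. This is a sketch of the simplified model used in this section (output tile padded up to multiples of 128, fill, compute, and drain each costing one tile-length of cycles), not a model of real MXU scheduling; the function names are illustrative.

```python
import math


def array_fill(m, n, tile=128):
    """Fraction of PEs doing useful work when an m x n output is padded to tile multiples."""
    padded_m = math.ceil(m / tile) * tile
    padded_n = math.ceil(n / tile) * tile
    return (m * n) / (padded_m * padded_n)


def single_tile_utilization(tile=128):
    """Useful MACs / (PEs x cycles) for one isolated tile x tile matmul."""
    fill = compute = drain = tile  # cycles, per the simplified model above
    return tile ** 3 / (tile * tile * (fill + compute + drain))


fill_127 = array_fill(127, 127)   # just under one tile
fill_65 = array_fill(65, 65)      # barely past half a tile
fill_129 = array_fill(129, 129)   # just over one tile: pads to 256 x 256
isolated = single_tile_utilization()
```

The 129 case is the one that hurts most in practice: a dimension one element past a tile boundary pads all the way to the next multiple of 128.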

The Roofline Model

The roofline model predicts the maximum achievable performance of any operation based on a single metric: arithmetic intensity (operations per byte of memory traffic). The model has two regimes:

Arithmetic Intensity:
I = FLOPS / Bytes   (operations per byte of data moved)

Roofline Performance:
P = min(Peak_FLOPS,  I × Memory_Bandwidth)

Crossover point (ridge point):
I_ridge = Peak_FLOPS / Memory_Bandwidth

H100: I_ridge = 989 TFLOP/s / 3.35 TB/s = 295 ops/byte
TPU v5e: I_ridge = 197 TFLOP/s / 0.819 TB/s = 240 ops/byte

If I < I_ridge: operation is memory-bandwidth-bound
If I ≥ I_ridge: operation is compute-bound
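The roofline formulas translate directly into code. The sketch below uses the peak and bandwidth figures quoted in this chapter; the dictionary layout and function names are assumptions made for the example.

```python
CHIPS = {
    # peak BF16 FLOP/s and HBM bytes/s, as quoted in the tables above
    "H100": {"peak_flops": 989e12, "mem_bw": 3.35e12},
    "TPU v5e": {"peak_flops": 197e12, "mem_bw": 0.819e12},
}


def ridge_point(peak_flops, mem_bw):
    """Arithmetic intensity (ops/byte) at which the two roofline regimes meet."""
    return peak_flops / mem_bw


def roofline(intensity, peak_flops, mem_bw):
    """Attainable FLOP/s for an operation with the given arithmetic intensity."""
    return min(peak_flops, intensity * mem_bw)


def bound(intensity, peak_flops, mem_bw):
    """Classify an operation as memory-bound or compute-bound."""
    return "compute" if intensity >= ridge_point(peak_flops, mem_bw) else "memory"


h100 = CHIPS["H100"]
layernorm_ceiling = roofline(1.25, **h100)   # LayerNorm-like intensity
big_matmul_ceiling = roofline(1000.0, **h100)
```

At an intensity of 1.25 ops/byte the H100 ceiling is about 4.2 TFLOP/s, under half a percent of peak, which is the quantitative reason elementwise ops are worth fusing rather than optimizing in isolation.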

Let us compute the arithmetic intensity for the key operations in a transformer:

| Operation | FLOPS | Bytes | Intensity (I) | Bound |
| --- | --- | --- | --- | --- |
| Linear layer (batch B, in d, out d) | 2Bd^2 | 2d^2 + 2Bd + 2Bd | ≈ B (for d >> B) | Compute if B > 295 |
| LayerNorm (batch B, dim d) | 5Bd | 4Bd (read + write) | 1.25 | Memory-bound (always) |
| Softmax (batch B, seq S) | 3BS | 4BS | 0.75 | Memory-bound (always) |
| GeLU (batch B, dim d) | ~8Bd | 4Bd | 2 | Memory-bound (always) |
| Attention QK^T, unfused (batch B, heads H, seq S, d_h) | 2BHS^2·d_h | 4BHS·d_h + 2BHS^2 | S·d_h / (2d_h + S) ≈ d_h (for S >> d_h) | Memory-bound when d_h < 295, because the S×S scores are written to HBM |

Most of a transformer is memory-bound. Only the large matrix multiplies (linear layers with large batch, and attention over long sequences once it is fused so the S×S score matrix never reaches HBM) are compute-bound. LayerNorm, softmax, GeLU, residual adds, and dropout are ALL memory-bandwidth-bound. This is why kernel fusion matters so much: fusing GeLU into the preceding linear layer eliminates one round-trip to HBM, and FlashAttention fuses the entire attention computation (QK^T, softmax, AV) into a single kernel that keeps data in SRAM.
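To see how batch size moves a linear layer across the ridge point, the intensity formula for the linear-layer row can be evaluated directly. This sketch assumes BF16 (2 bytes per element), a square d × d weight, and that weights, inputs, and outputs each cross HBM exactly once; `tokens` stands for the B in the table (batch × sequence positions).

```python
def linear_intensity(tokens, d, bytes_per_el=2):
    """Arithmetic intensity (FLOPs per byte) of a [tokens, d] @ [d, d] matmul."""
    flops = 2 * tokens * d * d
    # weights (d*d) + input activations (tokens*d) + output activations (tokens*d)
    bytes_moved = bytes_per_el * (d * d + 2 * tokens * d)
    return flops / bytes_moved


H100_RIDGE = 989e12 / 3.35e12  # ~295 ops/byte

decode_step = linear_intensity(1, 8192)      # one token: weight-read dominated
at_256 = linear_intensity(256, 8192)
at_512 = linear_intensity(512, 8192)
```

The "≈ B" approximation is only exact for d much larger than B: with d = 8192, 256 tokens give an intensity of about 241, still below the H100 ridge, while 512 tokens give about 455.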

TPU vs GPU: Where Each Wins

| Dimension | H100 (GPU) | TPU v5e | Winner for... |
| --- | --- | --- | --- |
| Peak BF16 FLOPS | 989 TFLOP/s | 197 TFLOP/s | GPU: raw compute power |
| HBM Bandwidth | 3.35 TB/s | 0.819 TB/s | GPU: memory-bound ops |
| HBM Capacity | 80 GB | 16 GB | GPU: large models per chip |
| On-chip SRAM | 50 MB (L2) + 228 KB/SM shared | 16-32 MB VMEM | Comparable |
| Interconnect BW | 900 GB/s NVLink (8 GPUs) | ~2.4 TB/s ICI (3D torus) | TPU: multi-chip communication |
| Interconnect topology | All-to-all via NVSwitch (8 GPUs) | 3D torus (thousands of chips) | TPU: massive scale |
| Programming model | CUDA (explicit), Triton | XLA (compiler-driven, implicit) | GPU: flexibility; TPU: ease |
| Matmul efficiency | High but requires careful tuning | Very high out of the box (systolic array) | TPU: less tuning needed |
| Cost per TFLOP | Higher ($) | Lower ($/TFLOP) | TPU: cost efficiency |
| Non-matmul ops | Fast (many SMs, flexible) | Slower (single VPU) | GPU: diverse workloads |
The practical takeaway. TPUs excel at large, regular matrix multiplies at massive scale (training LLMs with thousands of chips). GPUs excel at diverse workloads, smaller models, inference, and anything that requires custom kernels or non-standard operations. Google trains Gemini on TPUs. Meta trains LLaMA on GPUs. Both achieve similar MFU — the engineering effort just goes to different places.

NVLink/NVSwitch vs ICI: Interconnect Battle

For distributed training, the interconnect between devices is often the bottleneck, not the compute. The two architectures take fundamentally different approaches:

NVLink/NVSwitch (NVIDIA):
8 H100s connected via NVSwitch → full bisection bandwidth
Any GPU can talk to any other at 900 GB/s
Cross-node: InfiniBand 400 Gbps (50 GB/s) — 18x slower than NVLink
This cliff between intra-node and inter-node is why TP must stay within a node.

ICI (Google TPU):
Each chip has 6 ICI links, each ~400 GB/s
Arranged in a 3D torus topology: [X, Y, Z] dimensions
A 256-chip pod: 4 × 8 × 8 torus
Bisection bandwidth of the torus: O(N^(2/3)) × link_bandwidth
No NVSwitch-style full crossbar, but every chip can reach every other
via at most (X/2 + Y/2 + Z/2) hops.

Key difference: NVLink is fast within 8 GPUs, slow outside.
ICI is moderately fast across thousands of chips with no cliff.
This affects parallelism strategy: GPU clusters use TP=8 inside node + PP/DP outside.
TPU pods can use TP across 64+ chips because ICI bandwidth is uniform.
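The intra-node versus inter-node cliff can be made concrete with the standard ring all-reduce cost model, in which each device moves 2 × (N-1)/N of the payload. This is a bandwidth-only sketch that ignores latency and overlap; the payload size (one [8, 2048, 4096] BF16 activation) and the function name are assumptions for illustration, and the link speeds are the ones quoted above.

```python
def ring_allreduce_seconds(payload_bytes, n_devices, link_bytes_per_s):
    """Bandwidth-only estimate of a ring all-reduce: 2*(N-1)/N of the payload per device."""
    traffic = 2 * (n_devices - 1) / n_devices * payload_bytes
    return traffic / link_bytes_per_s


payload = 8 * 2048 * 4096 * 2  # bytes in one [8, 2048, 4096] BF16 activation

t_nvlink = ring_allreduce_seconds(payload, 8, 900e9)     # within a node
t_infiniband = ring_allreduce_seconds(payload, 8, 50e9)  # across nodes
slowdown = t_infiniband / t_nvlink
```

Tensor parallelism issues collectives like this in every layer, so the slowdown factor applies once per layer per step, which is why TP is kept inside the NVLink domain.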

Worked Example: Same Matmul on H100 vs TPU v5e

Let us predict the performance of a [4096, 4096] × [4096, 4096] BF16 matmul on both chips, then verify against roofline predictions.

Operation: C = A × B, where A and B are [4096, 4096] BF16 matrices

FLOPS: 2 × 4096^3 = 137.4 GFLOP
Bytes loaded: 2 × 4096^2 × 2 bytes = 67.1 MB (A and B)
Bytes stored: 4096^2 × 2 bytes = 33.6 MB (C, BF16 output)
Total memory traffic: 100.7 MB

Arithmetic intensity: 137.4×10^9 / 100.7×10^6 ≈ 1,365 ops/byte

Both chips' ridge points are ~240-295 ops/byte.
1,365 >> 295 → compute-bound on both.

H100 predicted time:
Time = FLOPS / Peak = 137.4×10^9 / 989×10^12 = 0.139 ms
Realistic (with overhead): ~0.15-0.18 ms (~80% peak utilization)

TPU v5e predicted time:
Time = FLOPS / Peak = 137.4×10^9 / 197×10^12 = 0.697 ms
Realistic: ~0.72-0.80 ms (~90% peak utilization)
(Higher utilization because systolic array has less scheduling overhead)

H100 is ~4.5x faster in raw time.
But TPU v5e is ~2-3x cheaper per chip.
Performance per dollar: roughly comparable.
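The same prediction can be scripted so it is easy to rerun for other shapes. This sketch assumes a square n × n BF16 matmul where A, B, and C each cross HBM once, and takes the larger of the compute time and the memory time as the roofline estimate; the function name and return layout are illustrative.

```python
def matmul_estimate(n, peak_flops, mem_bw, bytes_per_el=2):
    """Roofline estimate for C = A @ B with square n x n operands."""
    flops = 2 * n ** 3
    traffic = 3 * n * n * bytes_per_el  # read A, read B, write C
    t_compute = flops / peak_flops
    t_memory = traffic / mem_bw
    return {
        "flops": flops,
        "bytes": traffic,
        "intensity": flops / traffic,
        "seconds": max(t_compute, t_memory),
        "bound": "compute" if t_compute >= t_memory else "memory",
    }


h100 = matmul_estimate(4096, peak_flops=989e12, mem_bw=3.35e12)
v5e = matmul_estimate(4096, peak_flops=197e12, mem_bw=0.819e12)
peak_ratio = v5e["seconds"] / h100["seconds"]
```

At peak the ratio between the two chips is about 5x; the ~4.5x figure above comes from applying the different realistic utilization factors to each chip.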

The Profiling Workflow: JAX + XLA

On TPUs, you rarely write kernels by hand. Instead, JAX traces your Python code, XLA compiles it to optimized HLO (High-Level Operations), and the TPU runtime schedules the execution. Profiling tells you where time is spent in this compiled execution.

python
import jax
import jax.numpy as jnp
from jax import profiler

# ── Step 1: Define the computation (single-head attention, for clarity) ──
def transformer_block(x, w_q, w_k, w_v, w_o, w_ff1, w_ff2, w_ln):
    # Layer norm (memory-bound); w_ln is the learned scale
    x_norm = jax.nn.standardize(x, axis=-1) * w_ln

    # Attention projections (compute-bound for large batch)
    q = x_norm @ w_q   # [B, S, D] @ [D, D] → [B, S, D]
    k = x_norm @ w_k
    v = x_norm @ w_v

    # Attention scores: [B, S, D] @ [B, D, S] → [B, S, S]
    d_k = q.shape[-1]
    scores = (q @ jnp.swapaxes(k, -1, -2)) / (d_k ** 0.5)
    attn = jax.nn.softmax(scores, axis=-1)

    # Attention output
    out = (attn @ v) @ w_o
    x = x + out  # residual (memory-bound)

    # FFN (compute-bound)
    x_norm = jax.nn.standardize(x, axis=-1) * w_ln
    h = jax.nn.gelu(x_norm @ w_ff1)  # [B,S,D] @ [D,4D] → [B,S,4D]
    x = x + h @ w_ff2                 # [B,S,4D] @ [4D,D] → [B,S,D]
    return x

# ── Step 2: JIT compile ──
transformer_jit = jax.jit(transformer_block)

# ── Step 3: Profile ──
B, S, D = 8, 2048, 4096
keys = jax.random.split(jax.random.PRNGKey(0), 7)
x = jax.random.normal(keys[0], (B, S, D), dtype=jnp.bfloat16)

# Placeholder weights: random values scaled by 1/sqrt(fan_in), for profiling only
def init(key, shape):
    w = jax.random.normal(key, shape, dtype=jnp.float32) * shape[0] ** -0.5
    return w.astype(jnp.bfloat16)

w_q, w_k, w_v, w_o = (init(k, (D, D)) for k in keys[1:5])
w_ff1 = init(keys[5], (D, 4 * D))
w_ff2 = init(keys[6], (4 * D, D))
w_ln = jnp.ones((D,), dtype=jnp.bfloat16)

# Warm up (trigger compilation) and wait for it to finish
transformer_jit(x, w_q, w_k, w_v, w_o, w_ff1, w_ff2, w_ln).block_until_ready()

# Profile 10 iterations
with profiler.trace("/tmp/jax-trace"):
    for _ in range(10):
        x = transformer_jit(x, w_q, w_k, w_v, w_o, w_ff1, w_ff2, w_ln)
        x.block_until_ready()  # force synchronization

# View in TensorBoard: tensorboard --logdir=/tmp/jax-trace
# Look for:
#   - HLO operation breakdown (which ops take the most time)
#   - Memory usage timeline (peak vs sustained)
#   - ICI communication time (for multi-chip runs)
#   - MXU utilization % (target: >80%)

XLA Compilation: How JAX Lowers to Hardware

When you write `y = x @ w` in JAX, the path to execution is:

1. Python Tracing
JAX traces your Python function, recording operations as a computation graph (Jaxpr). No actual computation happens — just graph construction.
2. HLO Generation
The Jaxpr is lowered to HLO (High-Level Operations) — XLA's intermediate representation. HLO is like assembly for accelerators: dot products, broadcasts, reduces, etc.
3. XLA Optimization
XLA runs 50+ optimization passes: operation fusion (combine elementwise ops), layout optimization (row-major vs column-major), memory scheduling (minimize peak usage), sharding propagation (distribute across devices).
4. Backend Codegen
Optimized HLO is compiled to device-specific code: PTX for GPUs, TPU microcode for TPUs. The compiler chooses tile sizes, schedules memory transfers, and generates the final executable.
5. Execution
The compiled program runs on the hardware. Subsequent calls with the same shapes reuse the compiled binary (cache hit). Shape changes trigger recompilation.
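The five stages can be inspected one at a time from Python. The sketch below uses JAX's ahead-of-time API (`jax.make_jaxpr`, `jax.jit(...).lower(...)`, `.as_text()`, `.compile()`); the toy function and shapes are assumptions for illustration, and the exact IR text printed depends on the installed JAX version and backend.

```python
import jax
import jax.numpy as jnp


def f(x, w):
    return jax.nn.gelu(x @ w)


x = jnp.ones((8, 128), dtype=jnp.float32)
w = jnp.ones((128, 256), dtype=jnp.float32)

# 1. Python tracing: record the operations as a Jaxpr without running them.
jaxpr = jax.make_jaxpr(f)(x, w)

# 2. Lowering: turn the traced function into XLA's input IR for these shapes.
lowered = jax.jit(f).lower(x, w)
ir_text = lowered.as_text()

# 3-4. XLA optimization and backend codegen for the current device.
compiled = lowered.compile()

# 5. Execution: calling the compiled object runs the device code.
y = compiled(x, w)
```

Printing `jaxpr` and `ir_text` before and after a code change is a quick way to confirm whether an edit changed the traced graph or only the Python around it.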
Why XLA matters for TPUs. GPUs have CUDA: a flexible programming model where you write explicit kernels. TPUs have no CUDA equivalent: almost all TPU code is generated by XLA, with Pallas as the escape hatch for hand-written kernels. This means XLA must do far more than a typical GPU compiler. It fuses operations automatically (covering much of what hand-fused kernels do on GPUs), it inserts multi-device collectives from sharding annotations (work that on GPUs means writing NCCL calls), and it plans memory automatically (in place of a custom allocator). The tradeoff: less control, but less engineering effort.

Sharded Matrices: How Matmul Distributes

When training on multiple TPU chips (or GPUs with tensor parallelism), the large weight matrices are sharded across devices. The matmul C = A × B becomes a distributed operation.

Column-Sharded Weight Matmul (TP across 4 chips):

A = [B, S, D] on all chips (replicated input)
W = [D, 4D] (the weight, written B in C = A × B) sharded column-wise: chip i holds W_i = [D, D]

Each chip computes: C_i = A × W_i → [B, S, D]
Result C = [C_0 | C_1 | C_2 | C_3] = [B, S, 4D]
No communication needed! Each chip has its shard of the output.

Row-Sharded Weight Matmul (for the second FFN layer):

A = [B, S, 4D] sharded: chip i holds A_i = [B, S, D] (its output from above)
W = [4D, D] sharded row-wise: chip i holds W_i = [D, D]

Each chip computes: C_i = A_i × W_i → [B, S, D]
But we need C = C_0 + C_1 + C_2 + C_3 → all-reduce

Communication: all-reduce of [B, S, D] across 4 chips
Volume: 2 × (4-1)/4 × B × S × D × 2 bytes
For B=8, S=2048, D=4096: 2 × 0.75 × 128 MB = 192 MB
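The column-then-row sharding pattern can be verified numerically on one machine by simulating the four chips with NumPy array slices. This is a sketch with tiny illustrative shapes: the Python `sum` over partial results stands in for the all-reduce, and ReLU stands in for the activation between the two FFN layers.

```python
import numpy as np

rng = np.random.default_rng(0)
n_chips, tokens, d = 4, 16, 8

x = rng.standard_normal((tokens, d))       # replicated input
w1 = rng.standard_normal((d, 4 * d))       # first FFN weight
w2 = rng.standard_normal((4 * d, d))       # second FFN weight

w1_shards = np.split(w1, n_chips, axis=1)  # column-wise: each [d, d]
w2_shards = np.split(w2, n_chips, axis=0)  # row-wise: each [d, d]

# Layer 1: every "chip" produces its own slice of the hidden activations.
# The elementwise activation acts on each slice independently, so no communication.
h_shards = [np.maximum(x @ w, 0.0) for w in w1_shards]

# Layer 2: every chip produces a full-shape partial sum; adding them is the all-reduce.
partials = [h @ w for h, w in zip(h_shards, w2_shards)]
y_sharded = sum(partials)

y_reference = np.maximum(x @ w1, 0.0) @ w2


def ring_allreduce_bytes(n_chips, n_elements, bytes_per_el=2):
    """Bytes each chip moves in a ring all-reduce of n_elements values."""
    return 2 * (n_chips - 1) / n_chips * n_elements * bytes_per_el


volume = ring_allreduce_bytes(4, 8 * 2048 * 4096)  # the B=8, S=2048, D=4096 case
```

Pairing a column-sharded layer with a row-sharded one is what keeps the whole FFN down to a single all-reduce per forward pass.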

Interactive: TPU vs GPU Architecture Comparison

Roofline Model Comparison

Visualizes the roofline model for H100 and TPU v5e. Drag the "operation" dot to see where different workloads fall. Operations below the roofline are underperforming — the gap shows wasted potential.


GPU Profiling with Nsight

On GPUs, the profiling workflow uses NVIDIA's Nsight Systems (system-level timeline) and Nsight Compute (kernel-level analysis):

bash
# Nsight Systems: capture a full training step timeline
nsys profile --trace=cuda,nvtx --output=train_step \
    python train.py --steps 10

# What to look for in the timeline:
# 1. GPU idle gaps → CPU-bound or waiting for data
# 2. Small kernels with gaps → launch overhead (use CUDA Graphs)
# 3. Long NCCL kernels → communication bottleneck
# 4. Memory copy (H2D/D2H) overlapping with compute → good!
# 5. Memory copy NOT overlapping → pipeline stall

# Nsight Compute: deep-dive on one kernel
ncu --set full --target-processes all \
    python -c "import torch; a=torch.randn(4096,4096,device='cuda',dtype=torch.bfloat16); b=a@a"

# Key metrics from ncu:
# - SM utilization (% of SMs active)
# - Tensor Core utilization (% of TC active)
# - Memory throughput (% of peak BW achieved)
# - Achieved occupancy (warps per SM)
# - Arithmetic intensity (actual, not theoretical)
The profiling checklist for a training run. When MFU is low, check in this order: (1) Are you compute-bound or memory-bound? (roofline analysis). (2) Is communication overlapped with compute? (Nsight Systems timeline). (3) Are individual kernels efficient? (Nsight Compute on the top 5 kernels by time). (4) Is the pipeline bubble too large? (count idle cycles). (5) Is the batch size large enough? (arithmetic intensity of matmuls). Fix the biggest bottleneck first — Amdahl's law means fixing a 5% bottleneck when a 40% bottleneck exists is pointless.

Memory-Bound Optimization: The Fusion Imperative

Our roofline analysis showed that most transformer operations are memory-bound. Let us derive exactly how much time is wasted on non-compute operations for a single transformer block on an H100:

Time Breakdown: 1 Transformer Block (LLaMA 3 70B, B=8, S=2048)

d_model = 8192, d_FFN = 28672, H = 64, H_KV = 8

Compute-bound operations (limited by 989 TFLOP/s):
QKV projection: 2 × 8 × 2048 × 8192 × (8192+1024+1024) = 2.75T FLOP → 2.78 ms
Attn scores QK^T: 2 × 8 × 64 × 2048^2 × 128 = 0.55T FLOP → 0.56 ms
Attn output AV: same as QK^T → 0.56 ms
Output projection: 2 × 8 × 2048 × 8192^2 = 2.20T FLOP → 2.22 ms
FFN up: 2 × 8 × 2048 × 8192 × 28672 = 7.70T FLOP → 7.79 ms
FFN down: same → 7.79 ms
Total compute time: ~21.7 ms

Memory-bound operations (limited by 3.35 TB/s):
2× LayerNorm: 2 × 8 × 2048 × 8192 × 4 bytes (read+write) = 1.07 GB → 0.32 ms
Softmax: 8 × 64 × 2048^2 × 4 = 8.59 GB → 2.56 ms
GeLU: 8 × 2048 × 28672 × 4 = 1.88 GB → 0.56 ms
2× Residual add: 2 × 8 × 2048 × 8192 × 4 = 1.07 GB → 0.32 ms
Total memory time: ~3.76 ms

Overhead: 3.76 / (21.7 + 3.76) = 14.8% spent on memory-bound ops

With kernel fusion (FlashAttention + fused LayerNorm + fused GeLU):
The memory-bound ops are folded into compute kernels.
Overhead drops from ~15% to ~2%. This is worth roughly a 13% MFU improvement.
13% MFU from fusion alone. On a 2048 H100 cluster costing $8,000/hr, a 13% efficiency improvement saves $1,040/hr = $24,960/day ≈ $870K over a 35-day training run. This is why kernel fusion engineers are among the highest-paid specialists in ML infrastructure. A single well-fused kernel can save a frontier lab hundreds of thousands of dollars.

Practical Profiling: What to Do When MFU Is Low

Here is the systematic debugging checklist that production teams at frontier labs follow when MFU is below expectations:

| Step | Check | Tool | Expected Value | If Failing |
| --- | --- | --- | --- | --- |
| 1 | Batch size large enough? | Config check | B ≥ 8 per GPU (ideally 16+) | Increase microbatch size, use gradient accumulation |
| 2 | Communication overlapped? | Nsight Systems / JAX profiler | NCCL/ICI ops overlap with compute | Enable async communication, check FSDP prefetch |
| 3 | Pipeline bubble acceptable? | Calculate (P-1)/(P-1+M) | < 15% | Increase microbatches M or reduce pipeline stages P |
| 4 | FlashAttention enabled? | Code check | Using FA2 or FA3 | Switch to FlashAttention, massive memory + speed win |
| 5 | Matmul dimensions aligned? | Shape analysis | Multiples of 64 (GPU) or 128 (TPU) | Pad hidden dims to nearest multiple |
| 6 | Activation checkpointing balanced? | Memory profiler | HBM usage 70-90% peak | If <70%: reduce checkpointing. If OOM: increase it. |
| 7 | Data loading bottleneck? | CPU utilization | GPU should never wait for data | More dataloader workers, prefetch, faster storage |
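Step 3 of the checklist is a one-line formula, and it is often useful to invert it: given P pipeline stages, how many microbatches M keep the bubble under the 15% target? The sketch below uses the (P-1)/(P-1+M) expression from the table; the function names are illustrative.

```python
def bubble_fraction(stages, microbatches):
    """Idle fraction of a pipeline schedule: (P-1) / (P-1+M)."""
    return (stages - 1) / (stages - 1 + microbatches)


def min_microbatches(stages, max_bubble=0.15):
    """Smallest M whose bubble fraction is strictly below max_bubble."""
    m = 1
    while bubble_fraction(stages, m) >= max_bubble:
        m += 1
    return m


bubble_4_8 = bubble_fraction(4, 8)   # 4 stages, 8 microbatches
needed_4 = min_microbatches(4)
needed_8 = min_microbatches(8)
```

The required microbatch count grows linearly with the number of stages, which is why deep pipelines push teams toward interleaved schedules or fewer stages.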

Putting It All Together: From Code to FLOPS

Here is the mental model you should carry. Every computation you write passes through this stack:

| Layer | GPU Path | TPU Path |
| --- | --- | --- |
| User code | PyTorch / Triton | JAX / Flax |
| Graph capture | torch.compile / CUDA Graphs | jax.jit (tracing) |
| Compiler | TorchInductor / Triton compiler | XLA |
| IR | Triton IR → PTX → SASS | HLO → LLO → TPU microcode |
| Runtime | CUDA driver, NCCL | TPU runtime, ICI stack |
| Hardware | SM → Tensor Core → Register → SMEM → L2 → HBM | VPU → MXU → VMEM → HBM |
| Interconnect | NVLink (intra) → IB (inter) | ICI (uniform 3D torus) |

At every layer, performance can be lost. The profiling tools let you pinpoint where. The roofline model tells you the theoretical maximum. The gap between theoretical and achieved is your optimization opportunity.

You profile a transformer training step and find that LayerNorm takes 18% of total GPU time. LayerNorm has an arithmetic intensity of ~1.25 ops/byte. On an H100, the ridge point is ~295 ops/byte. A colleague suggests "we should optimize LayerNorm by using a more efficient algorithm." Is this the right approach?

Chapter 13: Mixture-of-Experts Architectures

You have just been handed an uncomfortable reality. The dense transformer you have been training — every parameter active on every token — is hitting diminishing returns. You doubled the parameter count from 7B to 14B, doubled the compute budget, and gained... 0.3 points on your eval suite. The Chinchilla-optimal frontier says you need 4x the tokens to justify 4x the parameters, but your data pipeline cannot produce tokens that fast. There has to be a way to get more parameters without a proportional compute increase.

There is. It is called Mixture of Experts (MoE), and it is the architectural pattern behind many of the largest frontier models: DeepSeek-V3, Mixtral, and Grok are openly MoE, and GPT-4 and Gemini are widely reported to be. The idea is deceptively simple: instead of one big feedforward network (FFN) that processes every token, you have many smaller FFNs (the "experts"), and a lightweight router decides which experts handle which token. Most of the parameters are dormant for any given token. You pay the storage cost of a huge model but only the compute cost of a small one.

The core trade-off. A dense 7B model uses 7 billion parameters and ~14 GFLOPs per token. Mixtral 8x7B, an MoE with 8 experts and top-2 routing, has ~47B total parameters but activates only ~13B of them per token (~26 GFLOPs), because only 2 of 8 experts fire. You get the representational capacity of a much larger model at roughly the compute cost of a 13B dense model. The catch: you need enough memory to store all 47B parameters, even though most are idle at any moment.

Where MoE Lives in the Transformer

A standard transformer block has two main components: a multi-head attention layer and a feedforward network (FFN). MoE replaces the FFN. The attention layer stays dense — every head sees every token. Only the FFN becomes sparse.

Why only the FFN? Two reasons. First, the FFN typically accounts for 2/3 of the parameters in a transformer block (it expands the hidden dimension by 4x, then contracts back). Replacing it with experts gives you the biggest parameter multiplier for the smallest architectural change. Second, attention is inherently a cross-token operation — all tokens need to interact. The FFN operates token-independently, so routing tokens to different experts does not break any dependency.

Input Token Embeddings
Shape: [batch, seq_len, d_model]. Every token enters the block identically.
Multi-Head Attention (Dense)
All tokens attend to all other tokens. Same as any standard transformer. Nothing changes here.
Router Network
A tiny linear layer: W_gate ∈ ℝ^(d_model × num_experts). For each token, produces a probability distribution over experts.
↓ top-k selection
Expert FFNs (Sparse)
Only top-k experts (typically k=2) are activated per token. Each expert is an independent FFN with its own weights. Outputs are weighted-summed by router probabilities.
Residual Add + LayerNorm
The weighted expert output is added back to the residual stream, same as a dense FFN.

The Router: Gated Softmax Selection

The router is the brain of MoE. It takes a token's hidden state x and produces a routing decision. Here is the math, step by step.

Step 1: Compute logits
h = x · W_gate     W_gate ∈ ℝ^(d_model × N), N = num experts

Step 2: Select top-k experts
T = TopK(h, k)     indices of k largest logits

Step 3: Normalize routing weights (only among selected experts)
g_i = softmax(h_T)_i = exp(h_i) / ∑_{j ∈ T} exp(h_j)     for i ∈ T

Step 4: Weighted combination of expert outputs
y = ∑_{i ∈ T} g_i · E_i(x)     E_i(x) = the i-th expert FFN applied to x

Think of it as a committee vote where only the top-k committee members get to speak, and their influence is proportional to how confident the router is in their relevance.

Worked Example: Routing a Single Token

Let's trace through a concrete example. We have 8 experts, top-2 routing, and d_model = 4 (tiny for illustration).

// Token hidden state (d_model=4)
x = [0.5, -1.2, 0.8, 0.3]

// Gate weights W_gate shape [4, 8] (8 experts)
// After matmul: logits = x · W_gate, shape [8]
logits = [1.2, -0.5, 3.1, 0.2, -1.0, 2.8, 0.7, -0.3]

// Step 1: Top-2 selection
Top-2 indices: [2, 5]    (logits 3.1 and 2.8)

// Step 2: Softmax over selected logits only
g_2 = exp(3.1) / (exp(3.1) + exp(2.8)) = 22.20 / (22.20 + 16.44) = 0.574
g_5 = exp(2.8) / (exp(3.1) + exp(2.8)) = 16.44 / 38.64 = 0.426

// Step 3: Compute expert outputs (each is a full FFN forward pass)
E_2(x) = FFN_2(x) = [0.9, -0.4, 1.1, 0.2]    expert 2 output
E_5(x) = FFN_5(x) = [0.3, 0.8, -0.5, 1.0]    expert 5 output

// Step 4: Weighted sum
y = 0.574 × [0.9, -0.4, 1.1, 0.2] + 0.426 × [0.3, 0.8, -0.5, 1.0]
y = [0.516 + 0.128, -0.230 + 0.341, 0.631 - 0.213, 0.115 + 0.426]
y = [0.644, 0.111, 0.418, 0.541]

// Experts 0,1,3,4,6,7 were NEVER computed. Zero FLOPs for them.
Why top-k and not full softmax? If you used all 8 experts with softmax weights, you would compute all 8 FFN forward passes — the entire point of MoE (sparse computation) is lost. Top-k selection is what makes MoE efficient. The unselected experts contribute exactly zero to both the forward pass and the gradient. They are not "weighted low" — they are completely skipped.
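The worked example can be reproduced in a few lines of plain Python. This is a sketch of top-k gating only: the expert outputs are hard-coded to the vectors from the example rather than computed by real FFNs, and the function names are illustrative.

```python
import math


def route_top_k(logits, k):
    """Return {expert_index: gate_weight} for the k largest logits."""
    top = sorted(range(len(logits)), key=lambda i: logits[i], reverse=True)[:k]
    m = max(logits[i] for i in top)                  # subtract max for stability
    exps = {i: math.exp(logits[i] - m) for i in top}
    z = sum(exps.values())
    return {i: e / z for i, e in exps.items()}       # softmax over selected only


def combine(gates, expert_outputs):
    """Weighted sum of the selected experts' output vectors."""
    dim = len(next(iter(expert_outputs.values())))
    return [sum(g * expert_outputs[i][j] for i, g in gates.items()) for j in range(dim)]


logits = [1.2, -0.5, 3.1, 0.2, -1.0, 2.8, 0.7, -0.3]
gates = route_top_k(logits, k=2)

# Only the selected experts are ever evaluated; these outputs are from the example.
expert_outputs = {2: [0.9, -0.4, 1.1, 0.2], 5: [0.3, 0.8, -0.5, 1.0]}
y = combine(gates, expert_outputs)
```

Because `expert_outputs` only needs entries for the indices in `gates`, the code makes the sparsity explicit: nothing is computed for the other six experts.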

Load Balancing: The Expert Collapse Problem

Here is the failure mode that kills naive MoE training. Without intervention, the router learns to send almost all tokens to 1-2 "favorite" experts. Those experts get better because they see more data. Because they are better, the router sends even more tokens to them. The other experts never get trained. Within a few thousand steps, you have a "Mixture of 2 Experts plus 6 dead weights." This is called expert collapse.

The standard fix is an auxiliary load-balancing loss that penalizes uneven expert usage. The loss is built from two per-expert quantities:

Auxiliary load-balancing loss (Switch Transformer formulation):

f_i = fraction of tokens routed to expert i
p_i = average router probability assigned to expert i

L_aux = α · N · ∑_{i=1}^{N} f_i · p_i

where N = number of experts, α = balancing coefficient (typically 0.01)

Perfect balance: f_i = p_i = 1/N for all i, so the loss = α · N · N · (1/N) · (1/N) = α
Total collapse: f_1 = 1, all others 0 ⇒ loss = α · N · 1 · p_1 ≫ α

Why multiply f_i by p_i? Because f_i depends on the hard top-k selection (not differentiable), but p_i is the smooth softmax probability (differentiable). The product lets gradients flow through p_i to encourage the router to spread probability mass across experts. The total loss becomes L_total = L_LM + L_aux.
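The two limiting cases of the auxiliary loss are easy to check numerically. The sketch below computes the loss from a list of per-token expert assignments and per-token router probabilities; it assumes top-1 assignments for simplicity, and the collapsed router probabilities (0.93 on expert 0) are made-up illustration values.

```python
def load_balance_loss(assignments, router_probs, n_experts, alpha=0.01):
    """alpha * N * sum_i f_i * p_i for one batch of tokens.

    assignments:  expert index chosen for each token (top-1 here)
    router_probs: per-token softmax probabilities over all experts
    """
    n_tokens = len(assignments)
    f = [0.0] * n_experts
    for expert in assignments:
        f[expert] += 1.0 / n_tokens            # fraction of tokens per expert
    p = [sum(row[i] for row in router_probs) / n_tokens for i in range(n_experts)]
    return alpha * n_experts * sum(fi * pi for fi, pi in zip(f, p))


N = 8
# Perfect balance: one token per expert, uniform router probabilities.
balanced = load_balance_loss(list(range(N)), [[1.0 / N] * N] * N, N)

# Collapse: every token goes to expert 0, which also gets most of the probability.
collapsed_probs = [[0.93] + [0.01] * (N - 1)] * N
collapsed = load_balance_loss([0] * N, collapsed_probs, N)
```

With these inputs the balanced loss equals α and the collapsed loss is α · N · 0.93, several times larger, which is the gradient signal that pushes the router back toward uniform usage.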

Deriving the FLOPs Savings

Let's be precise about why MoE is efficient. A dense transformer FFN has two weight matrices: W_1 ∈ ℝ^(d × 4d) and W_2 ∈ ℝ^(4d × d). The FLOPs for one token through this FFN are approximately 2 × d × 4d + 2 × 4d × d = 16d^2 (counting multiply-accumulate as 2 FLOPs).

// Dense FFN: one big FFN with hidden_dim = 4d
FLOPs_dense = 16d^2 per token
Parameters_dense = 8d^2

// MoE FFN: N experts, each with hidden_dim = 4d/N (common) or 4d (full-size)
// Case 1: Full-size experts (each expert = same size as the original FFN)
Parameters_MoE = N × 8d^2    N× more parameters
FLOPs_MoE = k × 16d^2 + FLOPs_router
FLOPs_router = 2 × d × N    tiny: just one matmul

// For top-2 of 8 full-size experts:
FLOPs_MoE = 2 × 16d^2 + 2dN = 32d^2 + 16d
FLOPs_dense_equivalent = 8 × 16d^2 = 128d^2

// Savings ratio (ignoring the negligible router term):
FLOPs_MoE / FLOPs_dense_equivalent = k/N = 2/8 = 25%

// You get 8× the parameters at 2× the dense compute, not 8×.
// Or equivalently: 4× parameter-to-FLOP efficiency.
The MoE efficiency equation. For top-k routing with N experts of identical size to the dense FFN: effective FLOPs ≈ k/N × dense-equivalent FLOPs. With top-2-of-8, you compute 25% of what a dense model with the same parameter count would need. The overhead (router computation, all-to-all communication) is real but typically adds only 5-10%.
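The FLOP accounting above can be wrapped in two helpers to explore other expert counts and top-k values. This is a sketch of the full-size-expert case only (each expert has the 4× expansion of the dense FFN), counting a multiply-accumulate as 2 FLOPs and ignoring communication; the function names are illustrative.

```python
def ffn_flops(d, expansion=4):
    """Per-token FLOPs of a two-matrix FFN: d -> expansion*d -> d."""
    return 2 * d * (expansion * d) + 2 * (expansion * d) * d   # = 16 d^2 for expansion=4


def moe_flops(d, n_experts, k, expansion=4):
    """Per-token FLOPs of an MoE layer with full-size experts and top-k routing."""
    router = 2 * d * n_experts
    return k * ffn_flops(d, expansion) + router


d = 4096
dense_small = ffn_flops(d)                 # the original dense FFN
moe = moe_flops(d, n_experts=8, k=2)
dense_equivalent = 8 * ffn_flops(d)        # dense model with the same FFN parameters
ratio = moe / dense_equivalent
router_share = (2 * d * 8) / moe
```

At d = 4096 the router accounts for about 0.01% of the layer's FLOPs, so the k/N ratio is accurate to well within the 5-10% overhead attributed to communication.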

Expert Parallelism: All-to-All Communication

MoE introduces a unique parallelism challenge. In a dense model, every GPU processes the same layers — you can split by data parallelism (different batches) or tensor parallelism (different parts of the same matrix). In MoE, different GPUs hold different experts, and tokens must physically move to the GPU hosting their assigned expert.

This is called expert parallelism, and the communication pattern is all-to-all: every GPU may need to send tokens to every other GPU, and every GPU may receive tokens from every other GPU. It is the MoE equivalent of traffic at an intersection where every car wants to go in a different direction.

| Parallelism Type | What splits | Communication | Used for |
| --- | --- | --- | --- |
| Data Parallel | Batch across GPUs | AllReduce gradients | Dense layers, attention |
| Tensor Parallel | Weight matrices across GPUs | AllReduce activations | Large single layers |
| Expert Parallel | Experts across GPUs | All-to-All tokens | MoE FFN layers only |
| Pipeline Parallel | Layers across GPUs | Point-to-point activations | Very deep models |

The all-to-all operation has two phases. In the dispatch phase, each GPU sends its tokens to the GPUs hosting the selected experts. In the combine phase, each GPU sends the expert outputs back to the originating GPUs. Both phases are all-to-all collectives, and they are the dominant communication overhead in MoE training.

GPU 0: Experts 0-1
Holds tokens from its local batch. Router says: token A → expert 0, token B → expert 5 (on GPU 2), token C → expert 3 (on GPU 1).
↓ All-to-All Dispatch
Token Redistribution
Token B physically moves from GPU 0 to GPU 2. Token C moves to GPU 1. GPU 0 receives tokens from other GPUs that need expert 0 or 1.
↓ Expert Computation
Each GPU runs its local experts
GPU 0 runs expert 0 and 1 on all received tokens. GPU 2 runs expert 4 and 5. Each GPU processes a variable number of tokens.
↓ All-to-All Combine
Results return to origin GPUs
Expert outputs are sent back to the GPU that owns each token. Weighted sum by router probabilities. Resume dense computation.

Worked Example: All-to-All Communication Cost

Let's calculate the actual communication overhead for a realistic MoE setup, because this is the number that determines whether MoE is practical at your cluster scale.

// Setup: 8 GPUs, 8 experts (1 per GPU), top-2 routing
// Batch: 4096 tokens, d_model = 4096, BF16 (2 bytes/element)

// Each GPU has 4096 tokens. With top-2 routing, each token
// needs to go to 2 GPUs. So each GPU sends 2 × 4096 = 8192
// token copies total, distributed across 8 destination GPUs.

// On average, each GPU sends 8192/8 = 1024 token copies to each other GPU.
// Each token copy is d_model × 2 bytes = 4096 × 2 = 8 KB.

// Total data per GPU in ALL-TO-ALL DISPATCH:
sent_per_gpu = 8192 × 8 KB = 64 MB
// With NVLink at 600 GB/s bidirectional between GPUs:
dispatch_time = 64 MB / 600 GB/s ≈ 0.11 ms

// ALL-TO-ALL COMBINE (sending results back): same size
combine_time ≈ 0.11 ms

// Total communication overhead: ~0.22 ms per MoE layer
// Expert computation: 2 experts × matmul(1024, 4096, 16384)
expert_compute = 2 × 2 × 1024 × 4096 × 16384 / (275 × 10^12) ≈ 1.00 ms

// Communication / Compute ratio = 0.22 / 1.00 = 22%
// Significant but manageable. On inter-node Ethernet (100 GB/s):
// dispatch_time = 64 MB / 100 GB/s = 0.64 ms, so dispatch + combine = 1.28 ms
// and the ratio jumps to ~128%: communication now takes longer than compute.
// This is why expert parallelism should stay INTRA-NODE.
The intra-node rule. Expert parallelism should always use the fastest interconnect available. On NVIDIA DGX systems, intra-node NVLink is 6x faster than the inter-node link in the example above, and up to 18x faster than 400 Gbps InfiniBand on H100 nodes. On TPU pods, ICI within a pod is 10x faster than inter-host DCN. Rule of thumb: put expert parallelism on the fastest interconnect, data parallelism on the slowest. If you have 8 GPUs per node and 8 experts, put all experts on one node. If you have 64 experts, use 8 experts per node with data parallelism across nodes.
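The dispatch and combine volumes depend only on tokens per GPU, top-k, hidden size, and element width, so they are easy to recompute for a different cluster. This sketch follows the simplification used in the worked example (every token copy is counted as sent, including the share that stays on the local GPU) and is bandwidth-only; the function name is illustrative.

```python
def all_to_all_cost(tokens_per_gpu, top_k, d_model, link_bytes_per_s, bytes_per_el=2):
    """Bytes sent per GPU in one all-to-all phase, and the bandwidth-only time."""
    bytes_sent = tokens_per_gpu * top_k * d_model * bytes_per_el
    return bytes_sent, bytes_sent / link_bytes_per_s


# Setup from the worked example: 4096 tokens per GPU, top-2, d_model = 4096, BF16.
sent, t_fast = all_to_all_cost(4096, 2, 4096, link_bytes_per_s=600e9)
_, t_slow = all_to_all_cost(4096, 2, 4096, link_bytes_per_s=100e9)

per_layer_fast = 2 * t_fast   # dispatch + combine
per_layer_slow = 2 * t_slow
```

Multiplying the per-layer figure by the number of MoE layers gives the communication budget per forward pass, which is the number to compare against expert compute time.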

The Capacity Factor and Token Dropping

Load balancing is imperfect. Even with the auxiliary loss, some experts receive more tokens than others. If expert 3 is assigned 2x the average number of tokens, processing it takes 2x as long, and all other GPUs sit idle waiting — destroying parallelism.

The capacity factor C controls the maximum number of tokens each expert can process. If an expert is assigned more than C × (total_tokens / num_experts) tokens, the excess tokens are dropped — they pass through the layer with only the residual connection (no FFN computation). Typical values: C = 1.0 (strict) to C = 1.5 (generous).

// Expert capacity calculation
tokens_per_expert_ideal = total_tokens / num_experts
expert_buffer_size = ceil(C × tokens_per_expert_ideal)

// Example: 1024 tokens, 8 experts, C=1.25
ideal = 1024 / 8 = 128 tokens per expert
buffer = ceil(1.25 × 128) = 160 tokens max per expert

// If expert 3 is routed 200 tokens, 40 are dropped.
// Those 40 tokens get only the residual connection for this layer.
Token dropping is information loss. Dropped tokens skip the FFN entirely. In a 32-layer model where every other layer is MoE, a token might get dropped in 2-3 of the 16 MoE layers. Each drop means that token gets less processing. This is why load balancing matters so much: poor balance leads to high drop rates, which degrade quality for tokens that happen to prefer popular experts.
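The capacity calculation above is easy to check numerically. The sketch below uses two hypothetical helpers (`expert_capacity`, `dropped_tokens`) and a made-up routing histogram in which expert 3 is overloaded; only the formula itself comes from the text.

```python
import math

def expert_capacity(total_tokens, num_experts, capacity_factor):
    """Maximum tokens one expert may process in this layer."""
    return math.ceil(capacity_factor * total_tokens / num_experts)

def dropped_tokens(tokens_routed, capacity):
    """Tokens over capacity per expert; these get only the residual path."""
    return [max(0, n - capacity) for n in tokens_routed]

capacity = expert_capacity(total_tokens=1024, num_experts=8, capacity_factor=1.25)
# Hypothetical routing counts that sum to 1024, with expert 3 overloaded.
routed = [120, 110, 130, 200, 100, 124, 120, 120]
drops = dropped_tokens(routed, capacity)
drop_rate = sum(drops) / sum(routed)
print(capacity, drops, f"{drop_rate:.1%}")
```

Tracking `drop_rate` per layer during training is one simple way to see whether the capacity factor is set too tightly.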

DeepSeek-V3's MoE Design: Auxiliary-Free Balancing

DeepSeek-V3 (2024) introduced several MoE innovations that frontier labs are rapidly adopting. The most significant: no auxiliary balancing loss. Instead, they use a simple bias term.

| Feature | Standard MoE | DeepSeek-V3 |
| --- | --- | --- |
| Total params | ~47B (Mixtral 8x7B) | 671B |
| Active params | ~13B | 37B |
| Expert count | 8 coarse experts | 256 fine-grained + 1 shared |
| Top-k | 2 | 8 of 256 (3.1% activation ratio) |
| Shared expert | None | 1 expert that processes ALL tokens |
| Balancing | Auxiliary loss (α = 0.01) | Per-expert bias term, no aux loss |
| Expert granularity | Coarse (each expert = full FFN) | Fine-grained (each expert = 1/16 of FFN) |

The key insight behind DeepSeek's auxiliary-free approach: instead of adding a loss term that fights with the main training objective, they add a bias b_i to each expert's routing score. When an expert is underutilized, its bias increases slightly; when overutilized, it decreases. The bias is updated by a simple heuristic (not gradient descent), decoupling load balancing from the language modeling loss entirely. The bias affects only which experts are selected; the gate values that weight the expert outputs are computed from the unbiased scores.

DeepSeek-V3 routing with bias correction:

s_i = sigmoid(x · w_i)     per-expert affinity score (note: sigmoid, not softmax)

T = TopK(s + b, k=8)     select 8 of 256 experts; b_i adjusted by usage stats, not gradients

g_i = s_i / ∑_{j ∈ T} s_j     gates use the unbiased scores, normalized over the selected experts

y = E_shared(x) + ∑_{i ∈ T} g_i · E_i(x)     shared expert always active

Why 256 fine-grained experts instead of 8 coarse ones? Finer granularity means each expert can specialize more precisely. An expert that handles "Python variable assignment" is more useful than one that handles "all of programming." The cost is more complex routing, but with 8-of-256 selection, the activation ratio (3.1%) is even sparser than 2-of-8 (25%).

The shared expert is another innovation. One expert processes every single token — it captures common patterns that all tokens need (basic grammar, formatting, common words). The routed experts then only need to capture specialized knowledge. Think of it as a "base layer" that everyone gets, plus "specialist consultants" routed by topic.
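To make the bias mechanism concrete, here is a small NumPy sketch of bias-adjusted top-k selection with a sign-based bias update. It is an illustration under stated assumptions, not DeepSeek's implementation: the function names, the step size of 0.01, and the toy sizes are all made up, and the sketch assumes the bias is used only for expert selection while the gates come from the unbiased sigmoid scores. The shared expert is omitted.

```python
import numpy as np

def route(x, w_gate, bias, k):
    """Bias-adjusted top-k selection with sigmoid gates.

    x: [T, d], w_gate: [d, E], bias: [E]. Returns (indices [T, k], gates [T, k]).
    Assumption: the bias influences only which experts are selected; the gate
    values are computed from the unbiased sigmoid scores.
    """
    scores = 1.0 / (1.0 + np.exp(-(x @ w_gate)))          # [T, E]
    idx = np.argsort(-(scores + bias), axis=-1)[:, :k]    # top-k by biased score
    picked = np.take_along_axis(scores, idx, axis=-1)     # unbiased scores
    gates = picked / picked.sum(axis=-1, keepdims=True)   # normalize over selected
    return idx, gates

def update_bias(bias, idx, num_experts, step=0.01):
    """Heuristic update: nudge underloaded experts up, overloaded ones down."""
    load = np.bincount(idx.ravel(), minlength=num_experts)
    return bias + step * np.sign(load.mean() - load)

rng = np.random.default_rng(0)
T, d, E, k = 512, 16, 8, 2
x = rng.normal(size=(T, d))
w_gate = rng.normal(size=(d, E))
bias = np.zeros(E)
for _ in range(200):                      # no gradients involved
    idx, gates = route(x, w_gate, bias, k)
    bias = update_bias(bias, idx, E)
print(np.bincount(idx.ravel(), minlength=E))   # per-expert load after the updates
```

Because the update only looks at token counts, it can run outside the autodiff graph, which is what decouples balancing from the language modeling loss.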

jax.lax.ragged_dot: Variable-Length Expert Batches

Here is a practical problem that matters for your capstone implementation. When you route tokens to experts, each expert receives a different number of tokens. Expert 3 might get 50 tokens while expert 7 gets 120. You need to do batched matrix multiplication where each batch element has a different length. In dense models, every batch element has the same shape, so standard matmul works. In MoE, you need a ragged (variable-length) batched dot product.

JAX provides jax.lax.ragged_dot for exactly this. It takes:

python
import jax
import jax.numpy as jnp

# x: [total_tokens, d_model], all tokens concatenated and grouped by expert
# w: [num_experts, d_model, expert_dim], stacked expert weights
# group_sizes: [num_experts], how many tokens are assigned to each expert
group_sizes = jnp.array([50, 80, 120, 45, 90, 110, 65, 75], dtype=jnp.int32)
total_tokens, d_model, expert_dim = 635, 64, 128  # 635 = sum of group_sizes

kx, kw = jax.random.split(jax.random.PRNGKey(0))
x = jax.random.normal(kx, (total_tokens, d_model))
w = jax.random.normal(kw, (group_sizes.shape[0], d_model, expert_dim))

# ragged_dot computes expert_i's matmul on its assigned chunk of x
output = jax.lax.ragged_dot(
    x,              # [total_tokens, d_model]
    w,              # [num_experts, d_model, expert_dim]
    group_sizes,    # [num_experts], token counts per expert
)
# output: [total_tokens, expert_dim]
# Equivalent to: for each expert i, slice x[start:start+group_sizes[i]]
# and compute matmul(slice, w[i]). Concatenate results.

The alternative without ragged_dot is to pad every expert's batch to the maximum length, waste compute on padding tokens, and then mask out the results. For highly imbalanced routing, this can waste 2-3x FLOPs. ragged_dot avoids the waste by handling variable-length segments natively on the accelerator.
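The semantics described above can be spelled out as a plain loop. The following NumPy reference is a sketch for understanding only (it is not how the accelerator executes the op), and `ragged_dot_reference` is a hypothetical name. It also computes how many rows the padding alternative would process for the example group sizes.

```python
import numpy as np

def ragged_dot_reference(x, w, group_sizes):
    """Loop-based reference: rows of x are grouped contiguously by expert."""
    out = np.empty((x.shape[0], w.shape[2]), dtype=x.dtype)
    start = 0
    for expert, size in enumerate(group_sizes):
        out[start:start + size] = x[start:start + size] @ w[expert]
        start += size
    return out

group_sizes = [50, 80, 120, 45, 90, 110, 65, 75]     # sums to 635
rng = np.random.default_rng(0)
x = rng.normal(size=(sum(group_sizes), 32))
w = rng.normal(size=(len(group_sizes), 32, 64))
out = ragged_dot_reference(x, w, group_sizes)

# Padding alternative: every expert processes max(group_sizes) rows.
padded_rows = max(group_sizes) * len(group_sizes)
waste = padded_rows / sum(group_sizes)
print(out.shape, f"padding does {waste:.2f}x the rows")
```

For this mildly imbalanced example the padding overhead is about 1.5x; the overhead grows with the ratio of the largest group to the average group.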

The Switch Transformer: Where It All Began

Before Mixtral and DeepSeek-V3, there was the Switch Transformer (Fedus et al., 2022). It simplified the original MoE design (Shazeer et al., 2017) from top-2 routing to top-1 — each token goes to exactly one expert. This sounds wasteful, but it has a crucial advantage: simpler communication patterns and higher training stability.

// Switch Transformer routing (top-1):
h = x · W_gate     [batch × seq, num_experts]
i = argmax(h)     single expert selection per token
g = softmax(h)_i     gate value for the selected expert
y = g · E_i(x)     no weighted sum: just one expert, scaled

// Compare with top-2:
// Top-1: each token to 1 expert. Simpler, faster, but less capacity per token.
// Top-2: each token to 2 experts. More capacity, but 2x the expert compute.
// The Switch Transformer showed that top-1 can match top-2 quality
// by using more experts (128 instead of 8).

The Switch Transformer also popularized the capacity factor and a simplified auxiliary load-balancing loss that became standard. Its key finding: MoE scaling is efficient. The largest Switch Transformer has 1.6 trillion parameters but activates only a small fraction of them per token, and the paper reports roughly a 4x pre-training speedup over the dense T5-XXL baseline.

Gating Mechanisms: Beyond Simple Softmax

The router's gating function is a design choice with significant consequences. Different papers use different gating mechanisms, each with tradeoffs.

| Gating Mechanism | Formula | Pros | Cons | Used In |
| --- | --- | --- | --- | --- |
| Softmax top-k | softmax(h)_top-k | Standard, well-understood | Top-k is not differentiable; requires STE or aux loss | Mixtral, Switch, GShard |
| Sigmoid top-k | sigmoid(h)_top-k / sum | Independent expert probabilities (not sum-to-1) | Need explicit normalization | DeepSeek-V3 |
| Expert Choice | Top-k tokens per expert (inverse routing) | Perfect load balance by construction | Some tokens may get zero experts | Zhou et al. (2022) |
| Noisy top-k | softmax(h + noise)_top-k | Encourages exploration during training | Added noise hurts convergence | Shazeer et al. (2017) |
| Hash routing | expert_id = hash(token_id) % N | Zero overhead, deterministic | No learned specialization | Roller et al. (2021) |

The Expert Choice mechanism deserves special attention. Instead of each token choosing its experts, each expert chooses its top-k tokens. This guarantees perfect load balance (each expert gets exactly the same number of tokens) but creates a different problem: some tokens may be chosen by zero experts and others by many. In practice, Expert Choice often works better than auxiliary-loss-based balancing because it eliminates the hyperparameter sensitivity of the balancing coefficient.

Expert Choice routing. Think of it as a two-sided matching problem. Token routing asks: "Which experts should serve this token?" Expert Choice asks: "Which tokens should this expert serve?" By letting experts choose, you guarantee utilization. The tokens that no expert claims are processed by a small shared dense FFN as a fallback. DeepSeek-V3's shared expert serves a similar purpose — it catches the tokens that the routed experts might miss.
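The "experts choose tokens" idea fits in a few lines of NumPy. This is a minimal sketch with random affinity scores and a hypothetical `expert_choice` helper; it shows the two properties discussed above: every expert gets exactly the same number of tokens, while the number of experts per token varies and can be zero.

```python
import numpy as np

def expert_choice(scores, capacity):
    """Each expert (column) selects its `capacity` highest-scoring tokens.

    scores: [T, E] token-to-expert affinities. Returns a boolean [T, E] mask.
    """
    T, E = scores.shape
    top_tokens = np.argsort(-scores, axis=0)[:capacity]   # [capacity, E]
    mask = np.zeros((T, E), dtype=bool)
    mask[top_tokens, np.arange(E)] = True
    return mask

rng = np.random.default_rng(0)
T, E, k = 64, 8, 2
scores = rng.normal(size=(T, E))
capacity = T * k // E                    # same total work as token-choice top-2
mask = expert_choice(scores, capacity)

per_expert = mask.sum(axis=0)            # identical for every expert
per_token = mask.sum(axis=1)             # varies; can be 0
print(per_expert, "tokens with no expert:", int((per_token == 0).sum()))
```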

Chinchilla Laws for MoE

Chinchilla scaling laws for dense models say: for a compute budget C, the optimal model has N_opt ∝ C^0.5 parameters trained on D_opt ∝ C^0.5 tokens. But MoE breaks this equation because active parameters ≠ total parameters.

// Dense Chinchilla: N_opt and D_opt scale with sqrt(C)
// Using C ≈ 6 × N × D and the rule of thumb D ≈ 20 × N:
N_opt ≈ (C / 120)^0.5 ≈ 0.09 × C^0.5
D_opt ≈ 20 × N_opt ≈ 1.8 × C^0.5

// MoE adjustment: effective compute per token depends on active params
N_active = N_expert_params × (k / N_experts) + N_dense_layers
C_effective = 6 × N_active × D    (6 FLOPs per active param per token for forward+backward)

// MoE sweet spot: more total params, same compute
// Scaling-law form in the spirit of Clark et al. (2022) and Krajewski et al. (2024)
// (A, B, C here are fitted constants, not the compute budget):
L(N_active, N_total, D) = A / N_active^α + B / N_total^β + C / D^γ + L_0

// Key finding: β < α, meaning adding total params (more experts)
// has diminishing returns compared to adding active params.
// The optimal number of experts scales sub-linearly with compute.

The practical implication: you cannot just add infinite experts. At some point, the communication overhead and the diminishing scaling returns of inactive parameters make it better to increase expert size or add more active experts. DeepSeek-V3's choice of 256 experts with 8 active represents a point on this frontier that was determined empirically.
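A short calculation makes the dense-versus-MoE bookkeeping tangible. The sketch below assumes the common approximations C ≈ 6·N·D and D ≈ 20·N for the dense optimum; the helper names and the example MoE sizes (32B of expert parameters, 4B of always-on parameters) are hypothetical.

```python
import math

def chinchilla_optimal(compute_flops, tokens_per_param=20.0):
    """Solve C = 6*N*D with D = tokens_per_param * N for (N, D)."""
    n = math.sqrt(compute_flops / (6.0 * tokens_per_param))
    return n, tokens_per_param * n

def moe_active_params(expert_params_total, num_experts, top_k, dense_params):
    """Parameters touched per token: k of E experts plus always-on layers."""
    return expert_params_total * top_k / num_experts + dense_params

n_opt, d_opt = chinchilla_optimal(1e24)
print(f"N_opt ~ {n_opt/1e9:.0f}B params, D_opt ~ {d_opt/1e12:.2f}T tokens")

# Hypothetical MoE: 8 experts totalling 32B params, top-2, 4B dense params.
active = moe_active_params(32e9, 8, 2, 4e9)
print(f"active ~ {active/1e9:.0f}B of {(32e9 + 4e9)/1e9:.0f}B total")
```

Training FLOPs follow the active count, while memory and communication follow the total count, which is why the two numbers have to be tracked separately.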

MoE Router Implementation in JAX

jax
import jax
import jax.numpy as jnp
import flax.linen as nn

class TopKRouter(nn.Module):
    """Top-k expert routing with load-balancing auxiliary loss."""
    num_experts: int = 8
    top_k: int = 2

    @nn.compact
    def __call__(self, x):
        # x: [batch, seq_len, d_model]
        B, S, D = x.shape

        # Router: single linear projection to expert logits
        logits = nn.Dense(self.num_experts, name="gate")(x)  # [B, S, N]

        # Top-k selection
        top_k_logits, top_k_indices = jax.lax.top_k(logits, self.top_k)

        # Routing weights: softmax over selected experts only
        gates = jax.nn.softmax(top_k_logits, axis=-1)  # [B, S, k]

        # ── Auxiliary load-balancing loss ──
        # f_i = fraction of tokens routed to expert i
        router_probs = jax.nn.softmax(logits, axis=-1)  # [B, S, N]
        expert_mask = jax.nn.one_hot(top_k_indices, self.num_experts)  # [B,S,k,N]
        expert_mask = expert_mask.sum(axis=-2)  # [B, S, N] — 1 if expert selected

        f_i = expert_mask.mean(axis=(0, 1))     # [N] — fraction of tokens per expert
        p_i = router_probs.mean(axis=(0, 1))  # [N] — mean router prob per expert

        aux_loss = self.num_experts * (f_i * p_i).sum()

        return gates, top_k_indices, aux_loss

class MoELayer(nn.Module):
    """Mixture of Experts FFN layer."""
    num_experts: int = 8
    top_k: int = 2
    expert_dim: int = 2048  # hidden dim of each expert FFN

    @nn.compact
    def __call__(self, x):
        B, S, D = x.shape

        # Route tokens to experts
        gates, indices, aux_loss = TopKRouter(
            self.num_experts, self.top_k
        )(x)  # gates: [B,S,k], indices: [B,S,k]

        # Compute ALL expert outputs (simple but wasteful)
        # Production code uses scatter-gather for true sparsity
        expert_outputs = []
        for i in range(self.num_experts):
            # Each expert is a 2-layer FFN with GELU
            h = nn.Dense(self.expert_dim, name=f"expert_{i}_up")(x)
            h = jax.nn.gelu(h)
            h = nn.Dense(D, name=f"expert_{i}_down")(h)
            expert_outputs.append(h)

        # Stack: [num_experts, B, S, D]
        all_experts = jnp.stack(expert_outputs, axis=0)

        # Gather selected expert outputs using indices
        # indices: [B, S, k] — which experts for each token
        selected = jnp.take_along_axis(
            all_experts,
            jnp.broadcast_to(
                indices.transpose(2, 0, 1)[..., None],  # [k, B, S, 1]
                (self.top_k, B, S, D)),
            axis=0
        ).transpose(1, 2, 0, 3)  # [B, S, k, D]

        # Weighted sum by gate values
        output = (selected * gates[..., None]).sum(axis=-2)  # [B, S, D]

        return output, aux_loss
Note on the implementation. This code computes all expert outputs and then selects — this is correct but not efficient for large N. Production MoE uses permute-compute-unpermute: sort tokens by expert, run each expert only on its assigned tokens (via ragged_dot or padding), then unsort. The above is clearer for learning; the capstone in chapter 15 will implement the efficient version.

MoE and Multi-Modality: Routing Across Modalities

A frontier topic that appears in research discussion interviews: how do you route tokens from different modalities (text, images, audio) through the same MoE? This is relevant because frontier models like Gemini and GPT-4o are multi-modal.

The challenge: text tokens and image patch tokens have very different statistical properties. Text tokens have a discrete, high-entropy distribution. Image patch tokens are continuous, locally correlated, and lower-entropy. If you use a single router for both, the router may route all image tokens to one set of experts and all text tokens to another — creating accidental modality-specific experts rather than concept-specific experts.

// Multi-modal routing approaches:

// Approach 1: Shared router (simplest)
// All modalities use the same W_gate
h = [text_tokens; image_tokens] · W_gate
// Problem: modality bias in routing decisions

// Approach 2: Modality-specific routers
h_text = text_tokens · W_gate_text
h_image = image_tokens · W_gate_image
h = concat(h_text, h_image)
// Better: each modality gets routing weights tuned to its distribution

// Approach 3: Modality-reserved experts (DeepMind style)
// Experts 0-3: reserved for text. Experts 4-7: reserved for image.
// Text tokens only routed among text experts.
// Prevents cross-modal interference but limits cross-modal learning.

// Approach 4: Shared + modality-specific experts
// 2 shared experts (all modalities) + 4 text + 4 image
// Shared experts learn cross-modal patterns (e.g., "this image matches this text")
// Modality experts learn modality-specific patterns

This is an active research area with no consensus answer. In an interview, the best response is to present the tradeoffs (shared routing gives more flexibility but risks modality collapse; reserved experts are safer but limit cross-modal transfer) and propose an experiment to decide (train both configurations, measure loss on text-only, image-only, and multi-modal benchmarks).

MoE and Reinforcement Learning from Human Feedback

An important practical consideration: MoE models behave differently under RLHF than dense models. The router's decisions can shift dramatically during the RL fine-tuning phase, because optimizing against the reward model changes the distribution of outputs the policy is pushed toward.

The RLHF routing shift. During pretraining, the router learns a stable routing pattern based on the data distribution. During RLHF, the reward model penalizes certain outputs, which changes the gradient signal to the router. Experts that handled "unsafe" or "low-quality" patterns during pretraining may see dramatically fewer tokens during RLHF. This can trigger a mini expert collapse: the "safety-penalized" experts get starved of gradient signal and degrade. The fix: freeze the router during RLHF (only fine-tune the expert weights) or use a lower learning rate for the router than for the experts.

MoE Routing Visualization

Expert Specialization: What Do Experts Learn?

A natural question: do MoE experts specialize in meaningful ways, or is the routing arbitrary? Empirical studies on trained MoE models reveal striking patterns.

In language models, experts often specialize by domain. One expert handles code, another handles scientific text, another handles conversational language. This happens without any explicit supervision — the router learns to route tokens to the expert that produces the best loss for that domain. Think of it as emergent department formation in a company: people naturally gravitate toward the work they are best at.

In the addition task from our capstone, we expect even finer specialization. Some experts should learn to handle carry operations (the hard part), while others handle simple single-digit addition (the easy part). Some might specialize by digit position — handling the ones place, tens place, etc. Analyzing this specialization after training is one of the most interesting parts of the capstone.

// Expert specialization analysis (run after training)
// For each token in a batch, record which experts were selected
// and what type of token it was:

Token types for addition:
- DIGIT_INPUT: digits before "="
- OPERATOR: the "+" token
- SEPARATOR: the "=" token
- DIGIT_OUTPUT_NOCARRY: output digits with no carry from previous position
- DIGIT_OUTPUT_CARRY: output digits that required a carry

// Expected findings after training:
// Expert 2 handles 80% of DIGIT_OUTPUT_CARRY tokens
// Expert 5 handles 70% of DIGIT_INPUT tokens at even positions
// Expert 0 is the "generalist" — handles mixed types uniformly
// This is the kind of analysis that impresses interviewers.
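The analysis described above amounts to a cross-tabulation of token type against selected expert. Here is a minimal sketch in plain Python; the function name and the tiny routing trace are made up for illustration, and in practice the two lists would be recorded from a trained model's router.

```python
from collections import Counter, defaultdict

def specialization_table(token_types, expert_choices):
    """For each token type, the share of routing slots going to each expert.

    token_types: list of labels, one per token.
    expert_choices: list of tuples, the experts selected for each token.
    """
    counts = defaultdict(Counter)
    for label, experts in zip(token_types, expert_choices):
        counts[label].update(experts)
    table = {}
    for label, counter in counts.items():
        total = sum(counter.values())
        table[label] = {e: n / total for e, n in sorted(counter.items())}
    return table

# Made-up routing trace for illustration only (top-2 routing).
types = ["DIGIT_INPUT", "OPERATOR", "DIGIT_INPUT", "SEPARATOR",
         "DIGIT_OUTPUT_CARRY", "DIGIT_OUTPUT_CARRY", "DIGIT_OUTPUT_NOCARRY"]
choices = [(5, 0), (0, 1), (5, 3), (0, 1), (2, 0), (2, 4), (0, 3)]
table = specialization_table(types, choices)
for label, shares in table.items():
    print(label, shares)
```

A row dominated by one expert suggests specialization; a flat row suggests that token type is handled by generalists.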

MoE Failure Modes: A Field Guide

When MoE training goes wrong, the symptoms are distinctive. Here is a diagnostic guide that will save you hours of debugging, both in the capstone and in production.

| Symptom | Likely Cause | Fix |
| --- | --- | --- |
| Loss plateaus after 1K steps, never improves | Expert collapse: all tokens route to 1-2 experts. The "dead" experts have random weights that never receive gradients. | Increase α (balancing coefficient) from 0.01 to 0.1. If already high, check that the balancing loss is actually being added to the total loss (common bug: forgetting to add aux_loss). |
| Loss is good but eval accuracy is bad | Token dropping. Overloaded experts are hitting their capacity limit and dropping tokens during evaluation. Dropped tokens get degraded processing. | Increase capacity factor C from 1.0 to 1.5. Or switch to auxiliary-free balancing (DeepSeek-V3 style), which avoids the capacity factor entirely. |
| Training is 3x slower than expected for top-k routing | All-to-all communication is bottlenecked. Tokens are being sent across slow inter-node links instead of fast intra-node NVLink/ICI. | Co-locate frequently-paired experts on the same node. Use expert-slicing: split each expert across fewer GPUs rather than distributing experts across all GPUs. |
| Router probabilities are near-uniform (entropy too high) | Router has not learned to differentiate tokens. Often happens when the balancing loss is too strong and overwhelms the language modeling signal. | Reduce α by 10x. Warm up the balancing loss: start with α=0 for the first 1000 steps, then ramp up linearly. |
| One expert has 10x the gradient norm of others | That expert is receiving all the "hard" tokens (high loss, high gradients). Common with unbalanced routing or when one domain is much harder than others. | Per-expert gradient clipping. Clip each expert's gradients independently rather than globally. This prevents the "hard" expert from destabilizing the others. |
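Several of the symptoms above can be caught early by logging a few routing statistics every few hundred steps. The sketch below is one possible set of diagnostics, not a standard API: the function name, the dictionary keys, and the synthetic "collapsed" router are assumptions for illustration.

```python
import numpy as np

def routing_diagnostics(router_probs, top_k_indices):
    """Summaries for spotting collapse or an undertrained router.

    router_probs: [T, E] softmax outputs; top_k_indices: [T, k] chosen experts.
    """
    T, E = router_probs.shape
    load = np.bincount(top_k_indices.ravel(), minlength=E) / top_k_indices.size
    entropy = -(router_probs * np.log(router_probs + 1e-9)).sum(axis=-1).mean()
    return {
        "load_fraction": load,                        # uniform would be 1/E each
        "max_over_mean_load": load.max() * E,         # 1.0 = perfectly balanced
        "dead_experts": int((load == 0).sum()),
        "normalized_entropy": entropy / np.log(E),    # near 1.0 = near-uniform
    }

# Illustrative collapsed router: every token picks experts 0 and 1.
T, E = 128, 8
probs = np.full((T, E), 0.01)
probs[:, 0], probs[:, 1] = 0.55, 0.39
idx = np.tile(np.array([0, 1]), (T, 1))
report = routing_diagnostics(probs, idx)
print(report)
```

A rising `dead_experts` count points at collapse (first row of the table), while `normalized_entropy` stuck near 1.0 points at an over-regularized router (fourth row).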

MoE at Inference Time: The Serving Challenge

Training MoE is hard. Serving MoE is harder. During training, you have a large batch of tokens that statistically covers all experts — every expert gets enough work to justify its GPU. During inference, a single request might only activate 2 experts, leaving the other 6 GPUs idle.

// Serving efficiency calculation:
// Dense model: every GPU does useful work on every token
GPU utilization (dense) = 100%

// MoE with 8 experts on 8 GPUs, single request, top-2:
// Only 2 of 8 GPUs do expert computation per token
GPU utilization (MoE, single request) = 2/8 = 25%

// MoE with batching: 32 concurrent requests
// Each request activates 2 experts. With 32 requests,
// each expert gets ~32 × 2/8 = 8 tokens on average
GPU utilization (MoE, batched) ≈ 80-95%

// Key insight: MoE inference REQUIRES batching for efficiency.
// Continuous batching, as implemented in serving stacks such as vLLM and TGI,
// is not just about throughput; for MoE it is also about GPU utilization.
Expert offloading. An emerging technique for serving very large MoE models (100B+ total params) on limited hardware: keep only a few experts in GPU memory and offload the rest to CPU memory or SSD. The router predicts which experts the next token will need, and they are prefetched from CPU while the current token is being processed. This works because expert access is sparse and somewhat predictable from context. Mixtral-8x7B can run on a single GPU this way, at the cost of higher latency.
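A rough way to see why batching matters is to ask how many experts receive any work at all for a given number of in-flight tokens. The sketch below assumes each token picks its top-k experts uniformly at random and independently, which real routers do not do; the helper name is hypothetical and the result is an optimistic upper bound on balance.

```python
def expected_active_fraction(batch_tokens, num_experts=8, top_k=2):
    """Expected share of experts that receive at least one token.

    Assumes each token picks top_k distinct experts uniformly at random,
    independently of other tokens (real routers are correlated).
    """
    p_miss = 1.0 - top_k / num_experts          # one token skips a given expert
    return 1.0 - p_miss ** batch_tokens

for tokens in (1, 4, 8, 32):
    print(tokens, f"{expected_active_fraction(tokens):.1%}")
```

Note that "receives at least one token" is weaker than "fully utilized": an expert with a single token still runs a very small, memory-bound matmul, so real utilization climbs more slowly than this fraction.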

The simulation below shows tokens flowing through an MoE layer. Each token is routed to its top-2 experts. Watch how load balancing distributes tokens, and see what happens when you disable it — experts collapse.

MoE Token Routing

Tokens enter from the left and route to their top-2 experts. Bar chart shows expert utilization. Toggle load balancing to see expert collapse.

Mixtral vs DeepSeek-V3 vs GShard: Comparing MoE Designs

Three landmark MoE designs represent different points on the design space. Understanding their differences teaches you the tradeoffs that matter.

| Feature | GShard (2020) | Mixtral 8x7B (2023) | DeepSeek-V3 (2024) |
| --- | --- | --- | --- |
| Origin | Google | Mistral AI | DeepSeek |
| Total params | 600B | 46.7B | 671B |
| Active params | Order of 1B (top-2 of 2048 experts) | 12.9B | 37B |
| Num experts | 2048 | 8 | 256 + 1 shared |
| Top-k | 2 | 2 | 8 |
| Activation ratio | 0.1% | 25% | 3.1% |
| Expert size | Very small (fine-grained) | Large (full 7B FFN each) | Small (fine-grained) |
| Shared expert | No | No | Yes (1) |
| Balancing | Capacity factor + aux loss | Aux loss only | Bias term (no aux loss) |
| Key innovation | Scaled MoE to 600B+ params | Open-weight MoE that beats dense Llama 2 70B | Auxiliary-free balancing + shared experts |

The trend is clear: newer designs use more experts, finer granularity, and more sophisticated balancing. GShard proved MoE works at scale. Mixtral proved it works in the open-weights community. DeepSeek-V3 pushed the frontier with novel routing and training techniques.

MoE Training Budget: How Much Compute Do You Save?

Let's make the savings concrete with a worked example comparing training costs.

// Scenario: you have a compute budget of C ≈ 10^22 FLOPs
// (roughly: 10 H100 GPUs for 1 month at ~40% utilization)

// Option A: Dense model
// Chinchilla-optimal with D ≈ 20 N and C ≈ 6 N D: N = (C/120)^0.5 = (8.3 × 10^19)^0.5
N_dense ≈ 9B parameters
D_dense ≈ 180B tokens
// Expected loss: L ≈ 2.8 (on a typical LM benchmark)

// Option B: MoE model (8 experts, top-2)
// Active parameters = same as dense: 9B
// Active params = attention_params + 2/8 × expert_params
// Total parameters ≈ 36B (about 4x the active count)
// FLOPs per token = same as the dense model, because only the active params do work
N_MoE_active ≈ 9B
N_MoE_total ≈ 36B
D_MoE ≈ 180B tokens (same compute budget)
// Expected loss: L ≈ 2.5 (0.3 lower than dense!)
// The extra parameters provide more capacity at same compute cost.

// What if you wanted dense to match MoE's loss?
// You'd need N_dense ≈ 20B and D ≈ 400B tokens
// Compute: 6 × 20B × 400B = 4.8 × 10^22 FLOPs
// That's ~5x MORE compute than MoE needed for the same loss.
The MoE value proposition, quantified. In this illustrative example, MoE reaches a loss that a dense model would need ~5x more compute to match. Equivalently: MoE lets you train a model of equivalent quality using ~5x fewer GPU-hours. For scale, a 1000-GPU H100 cluster at $2/GPU-hour costs about $1.4M per month, so every month of cluster time avoided saves roughly that much. This is a major reason frontier labs use MoE for their largest models.
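The compute comparison can be recomputed directly from the 6·N·D approximation. This is a minimal sketch using the parameter and token counts from the worked example; the $2 per GPU-hour price and the 30-day month are assumptions, and `train_flops` is a hypothetical helper.

```python
def train_flops(params, tokens):
    """Approximate training FLOPs: 6 per (active) parameter per token."""
    return 6.0 * params * tokens

moe = train_flops(9e9, 180e9)       # MoE with 9B active params
dense = train_flops(20e9, 400e9)    # dense model assumed to reach the same loss
print(f"MoE:   {moe:.2e} FLOPs")
print(f"dense: {dense:.2e} FLOPs ({dense / moe:.1f}x)")

# Cluster cost at an assumed $2 per GPU-hour, 1000 GPUs, 30-day month.
monthly_cost = 2.0 * 1000 * 24 * 30
print(f"${monthly_cost / 1e6:.2f}M per month")
```

The loss values themselves are not computed here; they depend on fitted scaling-law constants that this sketch does not assume.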

Implementing Permute-Compute-Unpermute

The efficient MoE implementation avoids computing all experts by physically sorting tokens by expert assignment. This is the permute-compute-unpermute pattern used in production systems.

jax
import jax
import jax.numpy as jnp

def efficient_moe_forward(x, expert_weights_up, expert_weights_down, gates, indices):
    """Efficient MoE: sort tokens by expert, compute, unsort."""
    B, S, D = x.shape
    x_flat = x.reshape(-1, D)  # [B*S, D]
    T = B * S
    k = indices.shape[-1]  # experts selected per token
    num_experts = expert_weights_up.shape[0]

    # Step 1: PERMUTE. Sort tokens by expert assignment.
    # Each token appears k times (once per selected expert)
    expert_ids = indices.reshape(-1)  # [T*k], flat list of expert assignments
    token_ids = jnp.repeat(jnp.arange(T), k)  # [T*k], which token each entry came from

    # Sort by expert id to group tokens for each expert together
    sort_order = jnp.argsort(expert_ids)
    sorted_expert_ids = expert_ids[sort_order]
    sorted_token_ids = token_ids[sort_order]

    # Gather sorted tokens
    sorted_x = x_flat[sorted_token_ids]  # [T*k, D]

    # Count tokens per expert (group_sizes for ragged_dot)
    group_sizes = jnp.bincount(sorted_expert_ids, length=num_experts)

    # Step 2: COMPUTE. Run each expert on its tokens.
    # expert_weights_up: [E, D, H], expert_weights_down: [E, H, D]
    h = jax.lax.ragged_dot(sorted_x, expert_weights_up, group_sizes)
    h = jax.nn.gelu(h)
    y = jax.lax.ragged_dot(h, expert_weights_down, group_sizes)

    # Step 3: UNPERMUTE. Scatter results back to original token order.
    # Weight by gate values and sum the k contributions per token
    gate_flat = gates.reshape(-1)  # [T*k]
    sorted_gates = gate_flat[sort_order]
    y_weighted = y * sorted_gates[:, None]

    # Scatter-add back to original positions
    output = jnp.zeros_like(x_flat)
    output = output.at[sorted_token_ids].add(y_weighted)

    return output.reshape(B, S, D)
Why permute-compute-unpermute? The naive approach (loop over experts, mask tokens) computes all experts on all tokens and masks out the irrelevant ones: O(N × T) compute instead of O(k × T). The efficient approach physically moves tokens to their assigned experts, computes only what is needed, and moves results back. The overhead is one gather and one scatter (each becomes an all-to-all when experts are sharded across devices), but the compute savings dominate for large N.
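A quick way to build confidence in the permutation logic is to check it against the naive "compute everything" version on small random inputs. The NumPy sketch below does that under simplifying assumptions: each expert is a single linear map (no activation or down-projection), and both function names are hypothetical.

```python
import numpy as np

def moe_naive(x, w, gates, indices):
    """Run every expert on every token, then keep only the selected ones."""
    all_out = np.einsum("td,edh->eth", x, w)              # [E, T, H]
    out = np.zeros((x.shape[0], w.shape[2]))
    for slot in range(indices.shape[1]):
        picked = all_out[indices[:, slot], np.arange(x.shape[0])]
        out += gates[:, slot:slot + 1] * picked
    return out

def moe_permuted(x, w, gates, indices):
    """Sort (token, expert) pairs by expert, compute per group, scatter back."""
    T, k = indices.shape
    expert_ids = indices.reshape(-1)
    token_ids = np.repeat(np.arange(T), k)
    order = np.argsort(expert_ids, kind="stable")          # permute
    sorted_tokens = token_ids[order]
    sorted_gates = gates.reshape(-1)[order]
    sizes = np.bincount(expert_ids, minlength=w.shape[0])
    out = np.zeros((T, w.shape[2]))
    start = 0
    for e, size in enumerate(sizes):                        # compute
        rows = sorted_tokens[start:start + size]
        y = x[rows] @ w[e]
        g = sorted_gates[start:start + size, None]
        np.add.at(out, rows, g * y)                         # unpermute
        start += size
    return out

rng = np.random.default_rng(0)
T, D, H, E, k = 32, 8, 8, 4, 2
x, w = rng.normal(size=(T, D)), rng.normal(size=(E, D, H))
indices = np.stack([rng.permutation(E)[:k] for _ in range(T)])
gates = rng.random(size=(T, k))
print(np.allclose(moe_naive(x, w, gates, indices), moe_permuted(x, w, gates, indices)))
```

The naive version performs E matmuls over all T tokens, while the permuted version performs matmuls over T × k rows in total, which is the O(N × T) versus O(k × T) difference described above.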
MoE fundamentals: An MoE model has 64 experts, top-4 routing, and each expert is the same size as the dense FFN it replaces. What is the approximate FLOPs savings compared to a dense model with the same total parameter count?

Chapter 14: Pallas Kernels

Your MoE model is training on a TPU v4 pod. You profile the training step and discover that 35% of the step time is spent inside the MoE layer — specifically, inside jax.lax.ragged_dot. The operation handles variable-length expert batches correctly, but its generic implementation is not exploiting the specific structure of your problem. Your experts all have the same hidden dimension. Your batch sizes per expert, while variable, fall into a predictable range. You know that if you could write a custom kernel that tiles the computation exactly right for your access pattern, you could cut that 35% down to 20%.

This is where Pallas comes in. Pallas is JAX's framework for writing custom accelerator kernels — for both TPUs and GPUs — directly in Python. No CUDA C++. No XLA HLO surgery. You write Python functions decorated with a few Pallas-specific annotations, and the framework compiles them to efficient machine code for your target hardware.

Why Pallas matters for your interview. Several frontier labs rely heavily on JAX (Google DeepMind, Anthropic, xAI). The ability to write Pallas kernels is a rare and highly valued skill because it lets you bypass the performance ceiling of JAX's built-in operations. When an interviewer at one of these labs asks "what do you do when a JAX op is too slow?", the answer is not "switch to PyTorch." The answer is Pallas.

What is Pallas?

Pallas is a kernel authoring API that sits between JAX's high-level NumPy-like interface and the low-level hardware instruction set. Think of it as JAX's answer to Triton (for NVIDIA GPUs) — except Pallas targets both TPUs and GPUs with a unified programming model.

| Abstraction Level | Technology | You write | You control |
| --- | --- | --- | --- |
| Highest | JAX / NumPy | jnp.matmul(A, B) | Nothing: XLA decides everything |
| Mid | Pallas | Grid + BlockSpec + kernel body | Tiling, memory placement, accumulation |
| Lowest | Mosaic (TPU) / PTX (GPU) | Raw hardware instructions | Everything: registers, pipelines, barriers |

The key insight of Pallas is declarative data movement. Instead of manually loading tiles from global memory into shared memory (as you would in CUDA), you declare what data each block needs via a BlockSpec, and the compiler figures out how to move it. This is a fundamental difference from CUDA, where you write the loads yourself.

The BlockSpec Abstraction

A BlockSpec describes the shape and indexing pattern of one input or output. It tells Pallas: "For grid cell (i, j), give me this slice of the array." The compiler then generates the appropriate memory load/store instructions.

// BlockSpec anatomy
BlockSpec(
  block_shape=(BLOCK_M, BLOCK_K), # shape of the tile this kernel sees
  index_map=lambda i, j: (i, j), # given grid indices, which tile?
)

// Example: matrix A of shape [M, K], tiled into BLOCK_M × BLOCK_K blocks
// Grid cell (2, 3) maps to A[2*BLOCK_M : 3*BLOCK_M, 3*BLOCK_K : 4*BLOCK_K]
// The index_map lambda tells Pallas exactly this mapping.

Why is this better than writing loads manually? Three reasons. First, the compiler can prefetch: it knows what data the next grid cell will need before you ask. Second, it can double-buffer: while the kernel computes on one tile, the hardware asynchronously loads the next tile. Third, it can optimize memory placement: on TPU, it decides whether to put your tile in VMEM (fast, small) or HBM (slow, large). You do not have to think about any of this.
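The index mapping is simple enough to emulate in plain Python. The sketch below is not Pallas itself; `block_slices` is a hypothetical helper that assumes the blocked indexing described above, where the index_map returns block indices that are scaled by the block shape.

```python
def block_slices(block_shape, index_map, grid_index):
    """Translate a grid index into the array slices one kernel instance sees."""
    block_index = index_map(*grid_index)
    return tuple(slice(b * size, (b + 1) * size)
                 for b, size in zip(block_index, block_shape))

BLOCK_M, BLOCK_K = 128, 64

# 2D grid over A[M, K]: grid cell (2, 3) reads block row 2, block column 3.
rows, cols = block_slices((BLOCK_M, BLOCK_K), lambda i, j: (i, j), (2, 3))
print(rows, cols)

# 3D matmul grid (i, j, k): operand A ignores j and reads block (i, k).
a_rows, a_cols = block_slices((BLOCK_M, BLOCK_K), lambda i, j, k: (i, k), (1, 5, 2))
print(a_rows, a_cols)
```

The second call shows why index_map is a function rather than a fixed offset: different operands of the same kernel can ignore different grid axes.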

Grid and Block Dimensions

A Pallas kernel executes on a grid — a multi-dimensional array of blocks. Each block processes one tile of the computation. The grid dimensions determine the parallelism and tiling strategy.

Grid dimensions for matrix multiply C = A · B:
A: [M, K]    B: [K, N]    C: [M, N]

Block sizes: BLOCK_M, BLOCK_K, BLOCK_N
Grid: (M / BLOCK_M, N / BLOCK_N)     2D grid, one cell per output tile

For each grid cell (i, j):
  C_tile = zeros(BLOCK_M, BLOCK_N)
  for k in range(K / BLOCK_K):
    A_tile = A[i*BM : (i+1)*BM, k*BK : (k+1)*BK]
    B_tile = B[k*BK : (k+1)*BK, j*BN : (j+1)*BN]
    C_tile += A_tile @ B_tile
  C[i*BM : (i+1)*BM, j*BN : (j+1)*BN] = C_tile

The k-dimension loop is the reduction axis. On each iteration, we load a tile of A and a tile of B, multiply them, and accumulate into the output tile. The accumulator stays in fast on-chip memory (VMEM on TPU, registers on GPU) across all k iterations.
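The pseudocode above translates directly into a NumPy emulation, which is a convenient way to check tiling logic before writing a real kernel. This is a sketch only: the tile sizes are tiny, shapes are assumed to divide evenly, and the Python loops stand in for grid cells that real hardware would run in parallel.

```python
import numpy as np

def tiled_matmul(a, b, bm=4, bn=4, bk=2):
    """Blocked C = A @ B; shapes are assumed to divide evenly by the tiles."""
    M, K = a.shape
    _, N = b.shape
    c = np.zeros((M, N), dtype=a.dtype)
    for i in range(M // bm):                 # grid axis 0: output row tiles
        for j in range(N // bn):             # grid axis 1: output column tiles
            acc = np.zeros((bm, bn), dtype=a.dtype)   # stays "on chip"
            for k in range(K // bk):         # reduction over K tiles
                a_tile = a[i*bm:(i+1)*bm, k*bk:(k+1)*bk]
                b_tile = b[k*bk:(k+1)*bk, j*bn:(j+1)*bn]
                acc += a_tile @ b_tile
            c[i*bm:(i+1)*bm, j*bn:(j+1)*bn] = acc     # one write per tile
    return c

rng = np.random.default_rng(0)
a, b = rng.normal(size=(8, 6)), rng.normal(size=(6, 12))
print(np.allclose(tiled_matmul(a, b), a @ b))
```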

Memory Spaces: VMEM, SMEM, HBM

Understanding the memory hierarchy is essential for writing fast kernels.

| Memory | Hardware | Size | Bandwidth | Latency | You manage? |
| --- | --- | --- | --- | --- | --- |
| HBM | TPU & GPU | 16-80 GB | ~2 TB/s (GPU), ~1.2 TB/s (TPU) | ~400 cycles | Pallas handles via BlockSpec |
| VMEM | TPU only | 16-32 MB per core | ~100 TB/s (on-chip) | ~5 cycles | Pallas auto-manages tiles here |
| SMEM | GPU only | 48-228 KB per SM | ~20 TB/s (on-chip) | ~20 cycles | Pallas uses for tile staging |
| Registers | Both | ~256 KB per core/SM | Infinite (local) | 0 cycles | Compiler assigns |

The key ratio is compute-to-memory. A TPU v4 chip does ~275 TFLOPS but can only load ~1.2 TB/s from HBM. That means for every byte you load, you need to do ~230 FLOPs to keep the compute units busy. If your kernel does fewer FLOPs per byte loaded, it is memory-bound and the compute units are starving. If it does more, it is compute-bound and the memory system can keep up.

The tiling insight. Matrix multiplication has an arithmetic intensity of O(N) FLOPs per byte (for NxN matrices), which is why it benefits enormously from tiling: load a block into fast memory, do O(BLOCK_SIZE) FLOPs per byte loaded, and touch slow memory as rarely as possible. Elementwise operations like ReLU have O(1) FLOPs per byte, so they are always memory-bound and tiling barely helps. This is why kernel fusion (combining elementwise ops with matmuls) is so important: the elementwise ops "ride for free" on data that is already in fast memory for the matmul.
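The roofline reasoning above reduces to comparing a kernel's arithmetic intensity with the hardware's ridge point. The sketch below uses the 275 TFLOPS and 1.2 TB/s figures from the text; the helper names are hypothetical, and the matmul model assumes ideal reuse (each operand read from HBM once, 2-byte elements).

```python
def ridge_point(peak_flops, hbm_bytes_per_s):
    """FLOPs per byte needed to move from memory-bound to compute-bound."""
    return peak_flops / hbm_bytes_per_s

def matmul_intensity(m, n, k, bytes_per_elem=2):
    """FLOPs per HBM byte if A, B are read once and C is written once."""
    flops = 2 * m * n * k
    traffic = bytes_per_elem * (m * k + k * n + m * n)
    return flops / traffic

def elementwise_intensity(flops_per_elem=1, bytes_per_elem=2):
    """One read and one write per element, e.g. ReLU."""
    return flops_per_elem / (2 * bytes_per_elem)

ridge = ridge_point(275e12, 1.2e12)
for n in (128, 1024, 8192):
    ai = matmul_intensity(n, n, n)
    print(n, f"{ai:.0f} FLOPs/byte", "compute-bound" if ai > ridge else "memory-bound")
print("relu", elementwise_intensity(), "FLOPs/byte")
```

Under these assumptions a square bf16 matmul has intensity N/3, so small matmuls (and small per-expert batches in MoE) fall on the memory-bound side of the ridge.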

A First Pallas Kernel: Vector Add

Let's start with the simplest possible Pallas kernel — adding two vectors. This is memory-bound and will not be faster than jnp.add, but it teaches the API structure.

jax/pallas
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

def vector_add_kernel(x_ref, y_ref, o_ref):
    """Pallas kernel body. Operates on one block at a time.
    x_ref, y_ref, o_ref are Ref objects — views into the block.
    """
    o_ref[...] = x_ref[...] + y_ref[...]

def vector_add(x, y):
    """Launch the Pallas kernel over a 1D grid."""
    n = x.shape[0]
    BLOCK = 512  # process 512 elements per block
    grid = (n // BLOCK,)  # 1D grid

    return pl.pallas_call(
        vector_add_kernel,
        grid=grid,
        in_specs=[
            pl.BlockSpec(block_shape=(BLOCK,), index_map=lambda i: (i,)),
            pl.BlockSpec(block_shape=(BLOCK,), index_map=lambda i: (i,)),
        ],
        out_specs=pl.BlockSpec(block_shape=(BLOCK,), index_map=lambda i: (i,)),
        out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
    )(x, y)

# Test
x = jnp.ones(4096)
y = jnp.ones(4096) * 2
result = vector_add(x, y)  # [3.0, 3.0, 3.0, ...]
print(result[:5])  # [3. 3. 3. 3. 3.]

Let's dissect every piece:

Component | What it does | Analogy
kernel function | The computation for one block. Receives Ref objects (views into tiles), writes output. | The body of a CUDA __global__ function
grid | How many blocks to launch. (8,) means 8 blocks, each processing one tile. | CUDA grid dimensions
BlockSpec | Declares the shape and index mapping for each input/output tile. | The shared memory load pattern in CUDA
index_map | Lambda that maps grid indices to tile indices. lambda i: (i,) means block i reads tile i. | blockIdx.x * blockDim.x in CUDA
out_shape | The full shape and dtype of the output. Pallas allocates this. | The output buffer you'd cudaMalloc
pallas_call | Wraps everything into a JAX-compatible function. JIT-compiles the kernel. | cudaLaunchKernel

Matrix Multiply in Pallas

Now let's write a tiled matrix multiplication. This is the kernel that actually matters for MoE — every expert FFN is a matmul, and how you tile it determines your performance.

jax/pallas
def matmul_kernel(a_ref, b_ref, o_ref):
    """Tiled matmul kernel. Accumulates over the K grid axis."""
    # a_ref: [BLOCK_M, BLOCK_K], b_ref: [BLOCK_K, BLOCK_N], o_ref: [BLOCK_M, BLOCK_N]
    # Pallas does NOT zero the output for us: clear the tile on the first k step.
    @pl.when(pl.program_id(2) == 0)
    def _init():
        o_ref[...] = jnp.zeros_like(o_ref)

    o_ref[...] += a_ref[...] @ b_ref[...]

def matmul(a, b, BLOCK_M=128, BLOCK_N=128, BLOCK_K=64):
    """Pallas tiled matmul: C[M,N] = A[M,K] @ B[K,N]"""
    M, K = a.shape
    K2, N = b.shape
    assert K == K2
    # This sketch assumes every dimension divides evenly into blocks.
    # Note: TPU backends may also require block dims to be multiples of (8, 128).
    assert M % BLOCK_M == 0 and N % BLOCK_N == 0 and K % BLOCK_K == 0

    # 3D grid; k is the last axis so it iterates fastest for a fixed (i, j)
    grid = (M // BLOCK_M, N // BLOCK_N, K // BLOCK_K)

    return pl.pallas_call(
        matmul_kernel,
        grid=grid,
        in_specs=[
            # A: for grid cell (i, j, k), read block at row i, col k
            pl.BlockSpec(
                block_shape=(BLOCK_M, BLOCK_K),
                index_map=lambda i, j, k: (i, k),
            ),
            # B: for grid cell (i, j, k), read block at row k, col j
            pl.BlockSpec(
                block_shape=(BLOCK_K, BLOCK_N),
                index_map=lambda i, j, k: (k, j),
            ),
        ],
        # CRITICAL: the k dimension is a reduction axis.
        # The output tile (i, j) is the same for every k, so the kernel
        # zeroes it at k == 0 and then accumulates with +=.
        # This relies on grid steps running in order, as they do on TPU;
        # on the GPU backend, loop over K inside the kernel body instead.
        out_specs=pl.BlockSpec(
            block_shape=(BLOCK_M, BLOCK_N),
            index_map=lambda i, j, k: (i, j),  # output tile at (i, j)
        ),
        out_shape=jax.ShapeDtypeStruct((M, N), a.dtype),
    )(a, b)
The accumulator pattern. Notice that the output BlockSpec maps grid cell (i, j, k) to tile (i, j): the same output tile for all values of k. That makes k a reduction dimension: the kernel body accumulates (+=) into the output tile across all k iterations. Pallas does not zero the output for you, so the tile has to be cleared on the first k step before accumulating. On TPU the accumulator stays in fast on-chip memory (VMEM/registers) while k advances and is only written back to HBM once, after all k blocks are processed. This is the same tiling trick that makes CUDA matmul fast.
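To make the accumulator pattern concrete without a TPU, here is a plain-Python emulation of the 3D grid loop. It is only an illustration of the iteration order and the zero-on-first-k rule; `tiled_matmul` and `reference_matmul` are hypothetical helper names, and nothing here touches the Pallas API.

```python
def reference_matmul(a, b):
    """Plain triple loop, used as ground truth."""
    M, K, N = len(a), len(b), len(b[0])
    return [[sum(a[r][t] * b[t][c] for t in range(K)) for c in range(N)]
            for r in range(M)]

def tiled_matmul(a, b, bm, bn, bk):
    """Emulate a (M//bm, N//bn, K//bk) grid where k is the reduction axis."""
    M, K, N = len(a), len(b), len(b[0])
    assert M % bm == 0 and N % bn == 0 and K % bk == 0
    out = [[None] * N for _ in range(M)]   # stand-in for uninitialized output
    for i in range(M // bm):
        for j in range(N // bn):
            for k in range(K // bk):       # same output tile (i, j) for every k
                rows = range(i * bm, (i + 1) * bm)
                cols = range(j * bn, (j + 1) * bn)
                if k == 0:                 # first reduction step: clear the tile
                    for r in rows:
                        for c in cols:
                            out[r][c] = 0
                for r in rows:             # accumulate this k-block's partial product
                    for c in cols:
                        out[r][c] += sum(a[r][t] * b[t][c]
                                         for t in range(k * bk, (k + 1) * bk))
    return out

a = [[(r + 2 * t) % 5 for t in range(8)] for r in range(4)]
b = [[(3 * t + c) % 7 for c in range(6)] for t in range(8)]
assert tiled_matmul(a, b, bm=2, bn=3, bk=4) == reference_matmul(a, b)
```

Removing the `k == 0` branch makes the first `+=` fail on the `None` placeholder, which mirrors what happens when a kernel accumulates into an output tile nobody initialized.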

Understanding XLA: What Happens Before Pallas

Before reaching for Pallas, you should understand what JAX's compiler (XLA) does automatically. XLA already performs aggressive optimizations, and sometimes your "custom kernel" is actually slower than what XLA generates.

XLA Optimization | What It Does | Example
Operator fusion | Combines adjacent elementwise ops into a single kernel | jax.nn.gelu(x + bias) becomes one kernel, not two
Layout optimization | Rearranges memory layout for optimal access patterns | Transposes a matrix before matmul if that reduces HBM reads
Common subexpression elimination | Computes shared subexpressions once | a*b + a*b computes a*b once
Constant folding | Pre-computes operations on known constants at compile time | jnp.ones((3,3)) * 2 becomes a constant tensor
Buffer aliasing | Reuses memory buffers when a tensor is consumed and no longer needed | After h = gelu(x), x's memory can be reused for the next layer

XLA's fusion is particularly good for elementwise chains. If your "slow operation" is a chain of x + bias, gelu(x), dropout(x), XLA has very likely already fused them into one kernel, and writing a Pallas version will not help. Where Pallas wins is on patterns XLA generally does not fuse: a matmul followed by an elementwise op followed by another matmul. XLA can often fold the elementwise op into the first matmul's output, but it does not merge the two matmuls into one kernel because their tiling strategies differ, so the intermediate still goes through HBM. Pallas lets you keep it on-chip.

// When Pallas wins vs XLA:

// Case 1: XLA WINS — pure elementwise chain
y = gelu(layer_norm(x + bias))    XLA fuses all of this into 1 kernel
// Pallas kernel for this = same speed or slower (XLA already optimal)

// Case 2: PALLAS WINS — matmul-elementwise-matmul
h = x @ W_up    kernel 1: matmul (writes h to HBM)
h = gelu(h)     kernel 2: elementwise (reads h, writes h to HBM)
y = h @ W_down   kernel 3: matmul (reads h from HBM)
// XLA does not merge the two matmuls: up to 3 kernels, and h round-trips through HBM.
// Pallas fuses all 3 into 1 kernel: h stays in VMEM.

// Case 3: PALLAS WINS — non-standard access patterns
// ragged_dot, sparse attention, custom scatter patterns
// XLA has no built-in template for these. It falls back to generic code.

Inspecting XLA's Output: The HLO

How do you know if XLA is already optimizing your code well? Inspect the HLO (High-Level Optimizer intermediate representation). This shows you exactly what kernels XLA generates.

python
import jax
import jax.numpy as jnp

def my_function(x, w):
    h = x @ w
    h = jax.nn.gelu(h)
    return h

x = jnp.ones((128, 4096))
w = jnp.ones((4096, 11008))
jitted = jax.jit(my_function)

# Method 1a: the lowered module is what JAX hands to XLA.
# It is printed BEFORE optimization, so it does not show fusion yet.
lowered = jitted.lower(x, w)
print(lowered.as_text())

# Method 1b: compile, then print the optimized HLO (shows fusion decisions)
compiled = lowered.compile()
print(compiled.as_text())

# Method 2: Full XLA dump (for deep debugging)
# Run with: XLA_FLAGS="--xla_dump_to=/tmp/xla_dump" python your_script.py
# Generates HLO files showing the module before and after optimization

# Method 3: Profile with JAX profiler (profile the jitted function)
with jax.profiler.trace("/tmp/jax-trace"):
    result = jitted(x, w)
    result.block_until_ready()
# Open the trace in perfetto or TensorBoard to see kernel timings
When to inspect HLO. You do not need to read HLO for every operation. Inspect it when: (1) your Pallas kernel is slower than the JAX equivalent (maybe XLA already fused it), (2) profiling shows an unexpected kernel boundary (XLA is not fusing something you expected it to), or (3) you are debugging a performance regression after a JAX version update (XLA's fusion heuristics may have changed).

The Ragged Dot Challenge

Now the real problem. In your MoE layer, you need to do a batched matmul where each batch element (expert) has a different number of tokens. The straightforward approach is padding:

// Padding approach (wasteful)
max_tokens = max(group_sizes) # e.g., 200 if the busiest expert gets 200
x_padded = pad_each_group(x, group_sizes, max_tokens) # [N_experts, max_tokens, D]
output = jnp.einsum('emd,edk->emk', x_padded, expert_weights)

// If group_sizes = [50, 200, 80, 30, 150, 90, 110, 60]
// We pad everyone to 200. Total compute: 8 × 200 = 1600 token-matmuls
// Actual useful compute: 50+200+80+30+150+90+110+60 = 770 token-matmuls
// Wasted compute: 830 / 1600 = 52% of FLOPs are on padding!
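The padding arithmetic above is easy to reproduce for any routing distribution. The helper below is a small sketch (the name `padding_waste` is made up here) that counts token-matmuls only; it ignores the cost of the padding operation itself.

```python
def padding_waste(group_sizes):
    """Return (padded, useful, wasted_fraction) in token-matmul units."""
    padded = len(group_sizes) * max(group_sizes)  # every expert padded to the max
    useful = sum(group_sizes)                     # tokens that actually exist
    return padded, useful, (padded - useful) / padded

padded, useful, wasted = padding_waste([50, 200, 80, 30, 150, 90, 110, 60])
print(padded, useful, f"{wasted:.0%}")  # 1600 770 52%
```

A perfectly balanced router gives zero waste, so the number doubles as a quick load-balance diagnostic.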

jax.lax.ragged_dot avoids this waste by natively handling variable-length segments. But it is a general-purpose implementation. When you know your specific problem structure (e.g., expert dimensions are always the same, group sizes fall in a known range), you can write a custom Pallas kernel that is even faster.

When to Write a Custom Kernel: The Decision Framework

Writing a Pallas kernel is an investment. Before you start, apply this decision tree:

Step 1: Profile
Is the operation actually a bottleneck? If it is <5% of step time, optimization elsewhere gives better ROI.
↓ yes, it is >10% of step time
Step 2: Check JAX builtins
Does jax.lax have a specialized op? Does jnp compose existing ops efficiently? Try before writing custom code.
↓ builtins are too slow or don't exist
Step 3: Analyze the bottleneck
Is it compute-bound or memory-bound? If memory-bound, kernel fusion may help more than better tiling.
↓ compute-bound, tiling matters
Step 4: Estimate the ceiling
Calculate the theoretical minimum time: FLOPs / peak FLOP/s. If the current op already runs at 80% or more of peak, a custom kernel won't help much.
↓ current op is <50% of peak
Step 5: Write the Pallas kernel
Start with the simplest correct version. Profile. Tune block sizes. Add prefetching hints if needed.
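Steps 1 and 4 of the decision tree reduce to two numbers: the share of step time and the achieved fraction of peak. A rough sketch of that check follows; the function names are hypothetical, the 10% and 50% cutoffs are the ones on the arrows above, and the 2 ms measurement is an invented example value, not a benchmark.

```python
def matmul_flops(m, k, n):
    """FLOPs for C[m, n] = A[m, k] @ B[k, n]."""
    return 2 * m * k * n

def utilization(flops, measured_s, peak_flops_per_s):
    """Achieved fraction of peak: ideal time divided by measured time."""
    ideal_s = flops / peak_flops_per_s
    return ideal_s / measured_s

def worth_custom_kernel(step_fraction, util,
                        min_step_fraction=0.10, max_util=0.50):
    """True only if the op is a real bottleneck AND far from peak."""
    return step_fraction > min_step_fraction and util < max_util

flops = matmul_flops(2048, 4096, 11008)
util = utilization(flops, measured_s=2e-3, peak_flops_per_s=275e12)
print(f"utilization: {util:.1%}")
print(worth_custom_kernel(step_fraction=0.15, util=util))
```

The check is deliberately conservative: an op at 3% of step time is rejected no matter how poorly it uses the hardware.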

Deriving When Pallas Beats ragged_dot

Let's derive the condition where a custom Pallas kernel outperforms ragged_dot. The key insight: ragged_dot uses a conservative tiling strategy because it must handle arbitrary group sizes. If you know your group sizes at kernel-write time (or they cluster around known values), you can tile more aggressively.

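// ragged_dot: tiles by expert, but must handle the worst case
// For N experts with group sizes g_i
// (K = contraction dimension, D = output dimension, t_block = time per tile):
T_ragged = Σ_i ceil(g_i / BLOCK_M) × (K / BLOCK_K) × (D / BLOCK_N) × t_block

// Custom kernel: you know max(g_i) ≤ G, so you pre-allocate
// and avoid per-expert dispatch overhead
T_custom = N × ceil(G / BLOCK_M) × (K / BLOCK_K) × (D / BLOCK_N) × t_block

// T_custom counts at least as many tiles as T_ragged, so Pallas wins only when
// dispatch overhead dominates AND you can use a larger BLOCK_M because you know the structure
// Condition: F = arithmetic intensity per expert (FLOPs per byte moved)
//   F = 2 × g × d_model × expert_dim / bytes(x + W + y) ≈ g when g << d_model (BF16)
// If F > 128 (high arithmetic intensity per expert), tiling wins.
// If F < 32 (small experts, few tokens), overhead dominates: use ragged_dot.
The F > 128 rule of thumb. Arithmetic intensity here is FLOPs per byte moved; for an expert matmul with few tokens it is roughly the number of tokens per expert, because weight traffic dominates. The hardware ridge point is peak compute divided by HBM bandwidth, about 230 FLOPs/byte on TPU v4 using the figures above, so treat 128 as a working threshold rather than a hardware constant. For the typical MoE configuration (d_model=4096, expert_dim=11008, 50-200 tokens per expert), the intensity works out to roughly 50-190 FLOPs/byte: the busiest experts clear the 128 threshold, while lightly loaded experts are bound by weight loads, where a custom kernel gains more from avoiding padding and dispatch overhead than from larger tiles.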
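The per-expert arithmetic intensity is worth computing rather than guessing. The sketch below assumes BF16 (2 bytes per element), counts one read of the token block and the weight matrix plus one write of the output, and uses the TPU v4 figures quoted earlier; `expert_matmul_intensity` is an illustrative name.

```python
def expert_matmul_intensity(tokens, d_model, expert_dim, bytes_per_elem=2):
    """FLOPs per byte for x[tokens, d_model] @ W[d_model, expert_dim]."""
    flops = 2 * tokens * d_model * expert_dim
    bytes_moved = bytes_per_elem * (
        tokens * d_model          # read the token block
        + d_model * expert_dim    # read the expert weights (dominant term)
        + tokens * expert_dim     # write the output
    )
    return flops / bytes_moved

RIDGE = 275e12 / 1.2e12  # ~229 FLOPs/byte on TPU v4

for tokens in (50, 128, 200):
    f = expert_matmul_intensity(tokens, 4096, 11008)
    print(f"{tokens:>3} tokens: {f:6.1f} FLOPs/byte (ridge {RIDGE:.0f})")
```

Because the weight matrix dominates the bytes moved, the intensity tracks the token count closely, which is why load balance matters as much as tile shape for expert matmuls.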

A Complete Pallas Kernel: Fused Expert Matmul

Here is a simplified but complete Pallas kernel for the MoE forward pass. It processes one expert at a time with a fixed tile configuration.

jax/pallas
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

def expert_matmul_kernel(
    x_ref,    # [BLOCK_M, BLOCK_K]: input token tile
    w_ref,    # [BLOCK_K, BLOCK_N]: expert weight tile
    o_ref,    # [BLOCK_M, BLOCK_N]: output accumulator
):
    """Single-expert tiled matmul kernel."""
    # Pallas does not zero the output: clear the tile on the first k step
    @pl.when(pl.program_id(2) == 0)
    def _init():
        o_ref[...] = jnp.zeros_like(o_ref)

    # Accumulate this k-block's partial product into the output tile
    o_ref[...] += x_ref[...] @ w_ref[...]

def moe_forward_pallas(
    x,              # [total_tokens, D], pre-sorted by expert assignment
    expert_weights, # [num_experts, D, expert_dim], stacked weights
    group_sizes,    # sequence of Python ints (static), tokens per expert
    BLOCK_M=64,
    BLOCK_K=128,
    BLOCK_N=64,
):
    """Process each expert's tokens with a tiled Pallas matmul.

    group_sizes must be concrete (a list or NumPy array), not a traced JAX
    array: slice shapes and grid sizes have to be known at trace time.
    """
    num_experts, D, expert_dim = expert_weights.shape
    assert D % BLOCK_K == 0 and expert_dim % BLOCK_N == 0
    outputs = []
    offset = 0

    for e in range(num_experts):
        n_tokens = int(group_sizes[e])
        if n_tokens == 0:
            continue

        # Slice this expert's tokens and weight
        x_e = x[offset:offset + n_tokens]  # [n_tokens, D]
        w_e = expert_weights[e]            # [D, expert_dim]

        # Pad n_tokens to multiple of BLOCK_M
        n_padded = ((n_tokens + BLOCK_M - 1) // BLOCK_M) * BLOCK_M
        x_padded = jnp.pad(x_e, ((0, n_padded - n_tokens), (0, 0)))

        grid = (n_padded // BLOCK_M, expert_dim // BLOCK_N, D // BLOCK_K)

        result = pl.pallas_call(
            expert_matmul_kernel,
            grid=grid,
            in_specs=[
                pl.BlockSpec((BLOCK_M, BLOCK_K), lambda i, j, k: (i, k)),
                pl.BlockSpec((BLOCK_K, BLOCK_N), lambda i, j, k: (k, j)),
            ],
            out_specs=pl.BlockSpec((BLOCK_M, BLOCK_N), lambda i, j, k: (i, j)),
            out_shape=jax.ShapeDtypeStruct((n_padded, expert_dim), x.dtype),
        )(x_padded, w_e)

        # Unpad and collect
        outputs.append(result[:n_tokens])
        offset += n_tokens

    return jnp.concatenate(outputs, axis=0)
Production vs learning code. This kernel processes experts sequentially in a Python loop, which adds Python overhead and prevents cross-expert parallelism. A production Pallas kernel would encode the expert index as a grid dimension, use advanced BlockSpec indexing to handle variable group sizes within a single kernel launch, and fuse the up-projection, GELU activation, and down-projection into one kernel. The version above is correct and clear — optimize for clarity first, speed second.

Pallas Grid/Block Visualization

The canvas below shows how a matrix multiply is tiled into a Pallas grid. Each colored block represents one grid cell. Watch the accumulation pattern as blocks iterate along the K dimension.

[Interactive figure: Pallas Grid Tiling Visualization. A matrix C = A × B is tiled into blocks; stepping through grid cells, starting at (0,0,0), highlights the tiles of A and B being multiplied while the accumulator builds up.]

Pallas on TPU vs GPU: The Key Differences

Pallas runs on both TPU and GPU, but the programming model has hardware-specific differences that affect how you write kernels.

Aspect | Pallas on TPU (Mosaic) | Pallas on GPU (Triton backend)
Compute unit | MXU (Matrix Multiply Unit), a 128x128 systolic array | CUDA Cores + Tensor Cores
Natural tile size | 128x128 (matches MXU) | Varies (16x16 to 128x128 depending on Tensor Core generation)
On-chip memory | VMEM: 16-32 MB per core. Huge! | SMEM: 48-228 KB per SM. Much smaller.
Data movement | Compiler manages via BlockSpec. DMA handles async loads. | You hint via BlockSpec. Compiler maps to shared memory loads.
Parallelism model | Grid cells run as sequential steps on a TensorCore (and can be split across cores). Each core has one or more MXUs (four on TPU v4). | Grid cells map to thread blocks. Each block has multiple warps.
Matmul primitive | jax.lax.dot inside kernel maps directly to MXU | Maps to wmma or mma instructions via Tensor Cores
Best block sizes | 128-512 (large VMEM allows big tiles) | 32-128 (limited SMEM constrains tile size)
Reduction axis | Can be a grid axis: steps run in order, so the output tile stays in VMEM across k | Manual loop inside kernel body

The biggest practical difference: TPU's VMEM is roughly two orders of magnitude larger than GPU's SMEM (16 MB against 192 KB on an A100 is about 85x). This means you can use much larger block sizes on TPU, which improves compute utilization (fewer loop iterations, less overhead per tile). On GPU, you are forced to use smaller tiles and more iterations.

// Maximum tile size comparison:

// TPU v4: 16 MB VMEM, BF16
// 3 tiles (A, B, C) of 512x512:
3 × 512 × 512 × 2 bytes = 1.5 MB    Fits easily in 16 MB

// GPU A100: 192 KB SMEM, FP16
// 3 tiles of 512x512:
3 × 512 × 512 × 2 bytes = 1.5 MB    DOES NOT FIT in 192 KB!
// Maximum on A100: ~128x128 per tile
3 × 128 × 128 × 2 bytes = 96 KB    Fits in 192 KB

// Implication: TPU can accumulate over 4x fewer k-iterations than GPU
// for the same matrix. This means TPU kernels have less loop overhead.

Real-World Pallas Usage: What Frontier Labs Write Kernels For

In practice, engineers at frontier labs do not write Pallas kernels for standard matmuls — XLA handles those well. They write kernels for operations that fall outside XLA's optimization reach. Here are the most common use cases:

Use Case | Why XLA is Not Enough | The Pallas Solution
MoE expert dispatch | Variable-length batches per expert. XLA pads to max length, wasting compute. | Custom kernel that processes each expert's tokens without padding, fusing the permute-compute-unpermute.
Flash-style attention | Standard attention materializes the S×S matrix. XLA cannot discover the tiled-softmax trick. | Pallas kernel implementing online softmax with tile-by-tile processing, keeping the running max and sum in registers.
Custom activations | Novel activation functions (e.g., SwiGLU with gating) create fusion barriers in XLA. | Fuse the gated activation into the preceding matmul. The activation happens on tiles already in VMEM.
Quantized matmul | INT8/INT4 matmul with per-channel dequantization. XLA does not always find the optimal fusion pattern. | Custom kernel that loads quantized weights, dequantizes in-register, and accumulates in BF16/FP32. Avoids writing dequantized weights to HBM.
Sparse attention patterns | Block-sparse or sliding-window attention has irregular access patterns that XLA handles conservatively. | Custom kernel with specialized BlockSpec index_maps that only load the non-zero blocks, skipping empty regions entirely.
The 90% rule. In a typical large-scale training codebase, 90% of the code uses standard JAX/PyTorch operations that XLA/CUDA handles well. Only 10% of the code benefits from custom kernels — but that 10% often accounts for 30-40% of the step time. Your job as a kernel engineer is to identify that 10%, measure it precisely, and optimize it aggressively while leaving the other 90% to the compiler.

Block Size Tuning: The Art of Tiling

The block sizes (BLOCK_M, BLOCK_K, BLOCK_N) are the most important tuning knobs. Too small and you have too many blocks (overhead dominates). Too large and tiles do not fit in fast memory (spill to HBM, killing performance).

Block Size | TPU v4 Sweet Spot | GPU (A100) Sweet Spot | Why
BLOCK_M | 128-256 | 64-128 | Batch dimension. Larger = better compute utilization, but needs more VMEM/SMEM.
BLOCK_K | 128-256 | 32-64 | Reduction dimension. Larger = fewer loop iterations, but wider loads from HBM.
BLOCK_N | 128-256 | 64-128 | Output dimension. Larger = more parallelism per block.

The constraint: (BLOCK_M × BLOCK_K + BLOCK_K × BLOCK_N + BLOCK_M × BLOCK_N) × bytes per element must fit in fast memory (VMEM on TPU, SMEM on GPU). For TPU v4 with 16 MB VMEM per core at BF16 (2 bytes), you can fit ~8M elements across all three tiles. For an A100 with 192 KB SMEM at FP16, you can fit ~96K elements. Double-buffered inputs need room for two copies of the A and B tiles, so the usable budget is smaller in practice.

// Memory budget calculation for TPU v4:
VMEM = 16 MB = 8M BF16 elements

// For BLOCK_M=128, BLOCK_K=128, BLOCK_N=128:
A_tile = 128 × 128 = 16,384 elements
B_tile = 128 × 128 = 16,384 elements
C_tile = 128 × 128 = 16,384 elements
Total = 49,152 elements = 96 KB    << 16 MB. Fits easily.

// For BLOCK_M=512, BLOCK_K=512, BLOCK_N=512:
Total = 3 × 262,144 = 786,432 elements = 1.5 MB    Still fits!

// For GPU A100 with SMEM=192KB:
// BLOCK_M=128, BLOCK_K=64, BLOCK_N=128:
Total = 128×64 + 64×128 + 128×128 = 32,768 elements = 64 KB    Fits.
// BLOCK_M=256, BLOCK_K=128, BLOCK_N=256:
Total = 256×128 + 128×256 + 256×256 = 131,072 elements = 256 KB    DOES NOT FIT in 192 KB
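The same budget check can be scripted so that block-size sweeps reject infeasible configurations before compiling anything. This is a back-of-the-envelope sketch: the helper names are invented, it counts one copy of each of the A, B and C tiles, and it treats the 16 MB and 192 KB capacities from the text as given.

```python
def tile_footprint(bm, bk, bn, bytes_per_elem=2):
    """Return (elements, bytes) for one A, one B and one C tile."""
    elems = bm * bk + bk * bn + bm * bn
    return elems, elems * bytes_per_elem

def fits(bm, bk, bn, capacity_bytes, bytes_per_elem=2):
    """True if the three tiles fit in the given fast-memory capacity."""
    return tile_footprint(bm, bk, bn, bytes_per_elem)[1] <= capacity_bytes

VMEM_TPU_V4 = 16 * 1024 * 1024   # 16 MB per core, as quoted in the text
SMEM_A100 = 192 * 1024           # 192 KB per SM, as quoted in the text

for blocks in [(128, 128, 128), (512, 512, 512), (128, 64, 128), (256, 128, 256)]:
    elems, nbytes = tile_footprint(*blocks)
    print(blocks, elems, f"{nbytes / 1024:.0f} KB",
          "TPU:", fits(*blocks, VMEM_TPU_V4),
          "A100:", fits(*blocks, SMEM_A100))
```

A real autotuner would also reserve space for double buffering and any scratch buffers, so treat a passing check as necessary rather than sufficient.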

Pallas Scratch Memory: Working with Temporary Buffers

Sometimes your kernel needs temporary storage that does not correspond to any input or output. For example, in a fused attention kernel, you need to store the running softmax denominator and maximum. Pallas provides scratch memory for this: on-chip buffers that are allocated for the kernel call and never written to HBM. On TPU their contents carry over from one sequential grid step to the next, which is what lets a kernel keep running statistics across tiles.

jax/pallas
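import jax.numpy as jnp
from jax.experimental import pallas as pl
from jax.experimental.pallas import tpu as pltpu  # TPU memory spaces (VMEM)

def attention_kernel(q_ref, k_ref, v_ref, o_ref, m_scratch, l_scratch):
    """Simplified flash-attention-style kernel with scratch memory.
    Assumes grid = (num_q_tiles, num_kv_tiles), with the kv axis iterated last.
    m_scratch: [BLOCK_M], running max of attention scores
    l_scratch: [BLOCK_M], running sum of exp(scores - max)
    """
    kv_step = pl.program_id(1)

    # First kv tile for this query tile: reset the running statistics
    @pl.when(kv_step == 0)
    def _init():
        m_scratch[...] = jnp.full(m_scratch.shape, -jnp.inf, m_scratch.dtype)
        l_scratch[...] = jnp.zeros_like(l_scratch)
        o_ref[...] = jnp.zeros_like(o_ref)

    # Compute attention scores for this tile
    scores = q_ref[...] @ k_ref[...].T  # [BLOCK_M, BLOCK_N]

    # Online softmax: update running max and sum
    tile_max = scores.max(axis=-1)  # [BLOCK_M]
    new_max = jnp.maximum(m_scratch[...], tile_max)

    # Correction factor for previous tiles
    correction = jnp.exp(m_scratch[...] - new_max)

    # Update running sum
    tile_exp = jnp.exp(scores - new_max[:, None])
    new_sum = l_scratch[...] * correction + tile_exp.sum(axis=-1)

    # Update output accumulator with correction
    o_ref[...] = o_ref[...] * correction[:, None] + tile_exp @ v_ref[...]

    # Update scratch buffers
    m_scratch[...] = new_max
    l_scratch[...] = new_sum

    # Last kv tile: divide by the softmax denominator
    @pl.when(kv_step == pl.num_programs(1) - 1)
    def _finalize():
        o_ref[...] = o_ref[...] / l_scratch[...][:, None]

# Scratch buffers are declared through the scratch_shapes argument,
# e.g. scratch_shapes=[pltpu.VMEM((BLOCK_M,), jnp.float32)] * 2 on TPU.
# They are allocated on-chip and never touch HBM.
# This is how FlashAttention avoids materializing the full S×S matrix.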

The scratch memory pattern is what makes Pallas powerful for attention and other reduction operations. Without it, you would need to write intermediate results to HBM and read them back — exactly the inefficiency that FlashAttention eliminates.
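The online-softmax update is easier to trust once you have seen it reproduce the ordinary softmax exactly. Below is a scalar, single-query sketch in plain Python: `m`, `l` and `acc` play the roles of the running max, running sum and output accumulator, and the function names are illustrative only.

```python
import math

def full_softmax_weighted_sum(scores, values):
    """Reference: materialize all weights, then take the weighted sum."""
    m = max(scores)
    p = [math.exp(s - m) for s in scores]
    return sum(pi * vi for pi, vi in zip(p, values)) / sum(p)

def online_softmax_weighted_sum(scores, values, tile):
    """Stream over key tiles, keeping only running max m, sum l and acc."""
    m, l, acc = -math.inf, 0.0, 0.0
    for start in range(0, len(scores), tile):
        s_tile = scores[start:start + tile]
        v_tile = values[start:start + tile]
        new_m = max(m, max(s_tile))
        correction = math.exp(m - new_m)   # rescales everything seen so far
        p = [math.exp(s - new_m) for s in s_tile]
        l = l * correction + sum(p)
        acc = acc * correction + sum(pi * vi for pi, vi in zip(p, v_tile))
        m = new_m
    return acc / l                         # final normalization

scores = [0.5, 2.0, -1.0, 3.5, 0.0, 1.5]
values = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
ref = full_softmax_weighted_sum(scores, values)
assert abs(online_softmax_weighted_sum(scores, values, tile=2) - ref) < 1e-12
```

On the first tile the correction factor is exp(-inf) = 0, which is why initializing the running max to negative infinity and the running sum to zero is enough.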

Kernel Fusion: Combining Operations

The biggest wins from Pallas come not from writing faster matmuls (XLA already does that well) but from fusing multiple operations into a single kernel. Every time data travels to HBM and back, you pay a latency penalty. Fusion eliminates these round-trips.

Consider the MoE forward pass for a single expert:

Without Fusion (3 kernels, 3 HBM round-trips)
1. Load x from HBM, matmul with W_up, write h to HBM
2. Load h from HBM, apply GELU, write h' to HBM
3. Load h' from HBM, matmul with W_down, write y to HBM
Total: 3 reads + 3 writes = 6 HBM operations on activations
↓ Fuse into single kernel
With Fusion (1 HBM round-trip)
1. Load x from HBM, matmul W_up, GELU, matmul W_down, write y to HBM
Intermediate h never leaves on-chip memory (VMEM/registers)!
Total: 1 read + 1 write = 2 HBM operations on activations

For an expert FFN with d_model=4096 and d_ff=11008, the intermediate activation h is 11008 values per token. At BF16, that is 22 KB per token. For a batch of 128 tokens, the intermediate is 2.8 MB, well within VMEM. In the unfused version this data crosses the HBM boundary 4 times (h is written then read, and h' is written then read) but stays on-chip in the fused version. At 1.2 TB/s HBM bandwidth, each write-plus-read round-trip costs ~4.7 microseconds. Across the active experts and MoE layers of a large model, this adds up to milliseconds per training step.

// Bandwidth savings from fusion:
bytes_intermediate = batch_tokens × d_ff × 2 (BF16)
bytes_intermediate = 128 × 11008 × 2 = 2.82 MB

// Unfused: the intermediate crosses HBM 4 times (write h, read h, write h', read h')
unfused_traffic = 2.82 MB × 4 = 11.3 MB extra HBM traffic

// At 1.2 TB/s:
unfused_overhead = 11.3 MB / 1.2 TB/s = 9.4 μs per expert
total_overhead = 9.4 μs × 8 active experts × 30 MoE layers = 2.3 ms per step

// If step time is 50 ms, that's 4.5%, which is significant!
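The fusion estimate can be parameterized so you can plug in your own batch size, layer count and bandwidth. In this sketch the number of HBM crossings is an explicit argument; the value 4 used below assumes one write and one read for each of h and h', and the function name is invented for illustration.

```python
def fusion_overhead(tokens, d_ff, hbm_crossings, hbm_bytes_per_s,
                    bytes_per_elem=2):
    """Return (intermediate_bytes, extra_traffic_bytes, extra_seconds)."""
    intermediate = tokens * d_ff * bytes_per_elem
    extra = intermediate * hbm_crossings
    return intermediate, extra, extra / hbm_bytes_per_s

# Assumption: write h, read h, write h', read h' = 4 crossings when unfused
inter, extra, per_expert_s = fusion_overhead(
    tokens=128, d_ff=11008, hbm_crossings=4, hbm_bytes_per_s=1.2e12)

per_step_s = per_expert_s * 8 * 30   # 8 active experts, 30 MoE layers
print(f"intermediate: {inter / 1e6:.2f} MB")
print(f"per expert:   {per_expert_s * 1e6:.1f} us")
print(f"per step:     {per_step_s * 1e3:.2f} ms "
      f"({per_step_s / 50e-3:.1%} of a 50 ms step)")
```

The estimate is pure bandwidth; kernel launch gaps between the unfused kernels would add to it.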

Debugging Pallas Kernels

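Pallas kernels are compiled, not interpreted. When something goes wrong, you do not get a Python traceback pointing to the offending line. Instead you get cryptic XLA errors or, worse, silently wrong results. Passing interpret=True to pallas_call runs the kernel body through the regular JAX interpreter, which is a useful first check of the logic. Beyond that, here is a systematic debugging approach.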

Problem | Symptom | Debug Strategy
Shape mismatch | XLA compilation error mentioning "incompatible shapes" | Print the expected shapes from your BlockSpec. Verify: grid[i] * block_shape[dim] == array_shape[dim] for every dimension. Off-by-one in grid calculation is the #1 bug.
Wrong results | Output does not match reference (jnp.matmul) | Test with BLOCK sizes equal to the full matrix (grid=(1,1,1)). If this works, the bug is in tiling/accumulation. Halve block sizes until you find the boundary that breaks.
NaN on large inputs | Works on small, NaN on large | The last tile has fewer valid elements than BLOCK_SIZE. Uninitialized memory in the tile causes NaN. Add explicit padding or boundary masking.
Performance worse than jnp | Kernel is slower than built-in op | Check if your kernel is memory-bound. Profile with jax.profiler. If FLOP utilization is <30%, try larger block sizes. If HBM utilization is >80%, you are memory-bound: fusion is the answer, not better tiling.
Compilation takes minutes | pallas_call hangs during JIT | Reduce grid dimensions. Very large grids (>10000 cells) can cause the compiler to spend excessive time optimizing. Try fusing grid dimensions: (M, N) grid instead of (M, N, K).
The golden debugging rule. Always write a reference implementation in pure JAX first (jnp.matmul, jnp.einsum). Run both on the same random input. Compare with jnp.allclose(pallas_out, ref_out, atol=1e-5) for FP32; BF16 needs a looser tolerance (around 1e-2). If they match on random data, test on structured data (all zeros, all ones, identity matrix). If they match everywhere, your kernel is very likely correct, and the next step is to profile it for performance.

Advanced: Prefetching and Pipelining

On TPU, Pallas supports prefetching hints that tell the hardware to start loading the next tile before the current tile's computation finishes. This hides memory latency behind compute.

jax/pallas (advanced)
# On TPU, Pallas pipelines BlockSpec-managed inputs automatically:
# while the kernel body computes on the current k-block, the next block
# is copied from HBM into a second VMEM buffer (double buffering).

def matmul_prefetch_kernel(a_ref, b_ref, o_ref):
    """Matmul body running under the double-buffered pipeline."""
    # Clear the output tile on the first k step
    @pl.when(pl.program_id(2) == 0)
    def _init():
        o_ref[...] = jnp.zeros_like(o_ref)

    # a_ref and b_ref already point at VMEM: the copy for this step was
    # issued while the previous step was computing.
    o_ref[...] += a_ref[...] @ b_ref[...]

    # Next-tile loads overlap with the current tile's compute. On TPU v4
    # this hides roughly 80% of memory latency for compute-bound kernels.

# Explicit scratch buffers are only needed when you manage copies yourself.
# They are declared via scratch_shapes (e.g. pltpu.VMEM(shape, dtype)),
# which tells Pallas to allocate them in fast on-chip memory.

Double-buffering works because TPU has separate data movement units (DMAs) and compute units (MXUs). While the MXU multiplies the current tile, the DMA loads the next tile. When the MXU finishes, the next tile is already in VMEM — zero wait time. This is the same principle as CPU prefetching, but at a much larger scale (megabytes instead of cache lines).
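A simple timing model shows how much double-buffering can buy. The sketch below is idealized and uses invented helper names: each tile has a fixed load time and compute time in arbitrary units, the first load cannot be hidden, and afterwards every step costs the slower of the two.

```python
def serial_time(n_tiles, load, compute):
    """No overlap: every tile pays load + compute."""
    return n_tiles * (load + compute)

def pipelined_time(n_tiles, load, compute):
    """Double-buffered: load of tile t+1 overlaps compute of tile t."""
    # first load is exposed, last compute has nothing left to overlap with
    return load + (n_tiles - 1) * max(load, compute) + compute

for load, compute in [(1, 2), (3, 2)]:
    s = serial_time(64, load, compute)
    p = pipelined_time(64, load, compute)
    kind = "compute" if compute >= load else "memory"
    print(f"load={load} compute={compute}: serial={s} pipelined={p} "
          f"({kind}-bound, {s / p:.2f}x)")
```

When compute is the slower stage the loads disappear almost entirely; when the load is slower, pipelining still helps but the kernel stays bandwidth-limited, which is the case where fusion matters more than prefetching.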

Pallas vs Triton: When to Use Which

If you are targeting NVIDIA GPUs, you have a choice: Triton (Python DSL, mature ecosystem, broad adoption) or Pallas (JAX-native, works on TPU and GPU, newer). Here is an honest comparison.

Dimension | Pallas | Triton
Hardware | TPU + GPU | GPU only (NVIDIA, AMD)
Ecosystem | Niche: mostly Google/DeepMind/Anthropic | Broad: PyTorch, vLLM, HuggingFace
Maturity | Young (2023-). API still changing. | Established (2019-). Stable API.
Memory model | Declarative (BlockSpec, compiler manages) | Explicit (you write tl.load, tl.store)
Debugging | Harder: compiled through XLA, cryptic errors | Easier: interpret mode, torch integration
Best for | TPU workloads, JAX-native projects, Anthropic/Google shops | NVIDIA GPU kernels, PyTorch projects, broad deployment
For frontier lab interviews: if the lab uses JAX/TPU (Anthropic, Google DeepMind, xAI), know Pallas. If they use PyTorch/CUDA (Meta FAIR, most startups), know Triton. If you want to maximize your options, learn both — the concepts (tiling, memory hierarchy, fusion) transfer directly between them. The capstone in chapter 15 uses Pallas because it targets Colab TPU, but the same kernel logic translates to Triton with minimal changes.

Translating a Pallas Kernel to Triton

To show that the concepts transfer, here is the same vector add kernel written in both Pallas and Triton side by side.

pallas (JAX)
# Pallas version: declarative data movement
def vadd_kernel(x_ref, y_ref, o_ref):
    o_ref[...] = x_ref[...] + y_ref[...]

def vadd(x, y):
    return pl.pallas_call(
        vadd_kernel,
        grid=(x.shape[0] // 512,),
        in_specs=[
            pl.BlockSpec((512,), lambda i: (i,)),
            pl.BlockSpec((512,), lambda i: (i,)),
        ],
        out_specs=pl.BlockSpec((512,), lambda i: (i,)),
        out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
    )(x, y)
triton (PyTorch)
# Triton version: explicit loads and stores
import torch
import triton
import triton.language as tl

@triton.jit
def vadd_kernel(x_ptr, y_ptr, o_ptr, n, BLOCK: tl.constexpr):
    pid = tl.program_id(0)
    offsets = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offsets < n  # boundary check (Pallas does this via BlockSpec)
    x = tl.load(x_ptr + offsets, mask=mask)  # explicit load
    y = tl.load(y_ptr + offsets, mask=mask)
    tl.store(o_ptr + offsets, x + y, mask=mask)  # explicit store

def vadd(x, y):
    o = torch.empty_like(x)
    n = x.numel()
    grid = (triton.cdiv(n, 512),)
    vadd_kernel[grid](x, y, o, n, BLOCK=512)
    return o

Notice the key differences: Pallas uses BlockSpec to declare data movement (the compiler generates loads/stores). Triton uses explicit tl.load/tl.store calls (you control the loads). Pallas requires out_shape upfront (it pre-allocates). Triton takes a pre-allocated output pointer. The kernel body logic (the actual addition) is identical in both — because the math does not change between hardware platforms.

Pallas and Automatic Differentiation

A critical question for training: can you differentiate through a Pallas kernel? Yes — but with a caveat. JAX's autodiff system (jax.grad, jax.vjp) works on Pallas kernels, but only if you provide a custom VJP rule (vector-Jacobian product). The compiler cannot automatically differentiate through the low-level kernel code.

jax/pallas
import jax
import jax.numpy as jnp
from jax import custom_vjp

# pallas_matmul is assumed to be the tiled Pallas matmul wrapper defined
# earlier (the `matmul` function). If you are following along:
# pallas_matmul = matmul

@custom_vjp
def my_pallas_matmul(a, b):
    """Forward pass: use Pallas kernel."""
    return pallas_matmul(a, b)  # our custom Pallas implementation

def my_pallas_matmul_fwd(a, b):
    """Forward pass + save residuals for backward."""
    result = pallas_matmul(a, b)
    return result, (a, b)  # save inputs for backward

def my_pallas_matmul_bwd(res, g):
    """Backward pass: compute gradients.
    For C = A @ B:
      dA = dC @ B^T
      dB = A^T @ dC
    """
    a, b = res
    # Can use Pallas kernels for backward matmuls too!
    da = pallas_matmul(g, b.T)
    db = pallas_matmul(a.T, g)
    return da, db

my_pallas_matmul.defvjp(my_pallas_matmul_fwd, my_pallas_matmul_bwd)

# Now jax.grad works through my_pallas_matmul!
a = jnp.ones((256, 512))
b = jnp.ones((512, 256))
loss_fn = lambda a, b: my_pallas_matmul(a, b).sum()
# argnums=(0, 1) returns gradients for both inputs (the default is only a)
da, db = jax.grad(loss_fn, argnums=(0, 1))(a, b)
Why custom VJP matters for MoE. Your MoE layer uses Pallas kernels for the expert matmuls. For training, gradients must flow back through these kernels. By defining custom VJP rules, you ensure that the backward pass also uses your optimized Pallas kernels — not falling back to generic JAX matmuls. This is why production MoE implementations always define both forward and backward kernels.
Pallas fundamentals: You are writing a Pallas kernel for a matmul C = A @ B where A is [2048, 4096] and B is [4096, 1024]. You choose BLOCK_M=128, BLOCK_K=64, BLOCK_N=128. How many total grid cells (kernel invocations) are there, and how many times is each output tile accumulated into?
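One way to work through a question like this is to compute the grid directly from the shapes. The snippet below is a plain-Python sketch (the helper name `matmul_grid` is made up) that assumes the 3D grid layout (M blocks, N blocks, K blocks) used throughout this chapter.

```python
def matmul_grid(M, K, N, block_m, block_k, block_n):
    """Grid for C[M, N] = A[M, K] @ B[K, N], with k as the reduction axis."""
    assert M % block_m == 0 and K % block_k == 0 and N % block_n == 0
    return (M // block_m, N // block_n, K // block_k)

grid = matmul_grid(M=2048, K=4096, N=1024,
                   block_m=128, block_k=64, block_n=128)
total_cells = grid[0] * grid[1] * grid[2]
accumulations_per_tile = grid[2]   # one += per k step for each (i, j) tile

print(grid, total_cells, accumulations_per_tile)  # (16, 8, 64) 8192 64
```

The number of output tiles is the product of the first two grid axes, and the last axis tells you how many partial products each of them receives.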

Chapter 15: The Capstone — 10M Param Addition Transformer

Everything in this lesson has been building to this chapter. Vlad Mikulik's blog post for Anthropic explicitly names this exercise: train a ~10 million parameter transformer from scratch that learns addition, compare dense vs MoE architectures, derive the Chinchilla-optimal training configuration, and write a custom Pallas kernel that beats ragged_dot for the MoE variant. This is the proof of work. Not a toy exercise — a miniature version of what you will do daily at a frontier lab.

We will design every component from first principles: the tokenizer, the architecture, the dataset, the training loop, the scaling law analysis, and the Pallas kernel. By the end of this chapter, you will have a complete implementation plan that you can execute on a free Colab TPU in under an hour.

This is the interview differentiator. Any strong candidate can talk about transformers, MoE, and Pallas in the abstract. Fewer than 5% have actually built all three from scratch and measured their behavior. Completing this capstone — and documenting it in a public repo — is the single highest-signal artifact you can show a frontier lab interviewer. It proves you can move from theory to working code on real hardware.

The Task: Learning Addition

Addition is the perfect capstone task for three reasons. First, it is learnable by a small model — you do not need billions of parameters or weeks of training. Second, it has a crisp success metric — the model either outputs the correct sum or it does not. Third, it exercises genuine algorithmic reasoning — the model must learn to propagate carries across digit positions, which requires attending to non-local dependencies.

// Example training data:
Input: "1 2 3 + 4 5 6 ="
Target: "5 7 9"

Input: "9 9 9 + 1 ="
Target: "1 0 0 0"

Input: "4 8 + 5 3 ="
Target: "1 0 1"

// The model sees individual digits as tokens.
// It must learn: (a) positional alignment, (b) single-digit addition,
// (c) carry propagation. The carry is the hard part.

Tokenizer Design

We keep the tokenizer dead simple. Thirteen tokens are sufficient:

Token | ID | Purpose
0-9 | 0-9 | Digit tokens
+ | 10 | Addition operator
= | 11 | Equals sign (transition from input to output)
<PAD> | 12 | Padding token for fixed-length sequences

Why separate digit tokens instead of multi-digit numbers? Because multi-digit tokenization introduces a combinatorial explosion. "123" as a single token means your vocabulary grows to 1000+ for 3-digit numbers. With single-digit tokens, your vocabulary stays at 13 regardless of number length. The model must learn positional alignment, but that is the point — we want it to learn this.

python
# Tokenizer: dead simple character-level encoding
VOCAB = {'0':0, '1':1, '2':2, '3':3, '4':4,
         '5':5, '6':6, '7':7, '8':8, '9':9,
         '+':10, '=':11, '<PAD>':12}
VOCAB_SIZE = 13

def encode(s):
    """Convert "1 2 3 + 4 5 6 = 5 7 9" to token IDs."""
    tokens = []
    for ch in s.split():
        if ch in VOCAB:
            tokens.append(VOCAB[ch])
        else:
            raise ValueError(f"Unknown token: {ch}")
    return tokens

def decode(ids):
    """Convert token IDs back to string."""
    inv = {v: k for k, v in VOCAB.items()}
    return ' '.join(inv[i] for i in ids if i != 12)

Architecture Decisions: Hitting ~10M Parameters

We need to choose d_model, n_layers, n_heads, and d_ff to hit approximately 10 million parameters. Let's derive this.

// Parameter count for a decoder-only transformer:
P_embed = V × d    (token embeddings, V = 13)
P_pos = L_max × d    (positional embeddings, L_max = max seq length)
P_attn = 4 × d^2    (Q, K, V, O projections per layer)
P_ffn = 2 × d × d_ff + d + d_ff ≈ 2 × d × d_ff    (up + down projections)
P_ln = 4 × d    (2 LayerNorms per layer, scale + bias each)
P_head = d × V    (output projection, often tied with embeddings)

P_total = P_embed + P_pos + n_layers × (P_attn + P_ffn + P_ln) + P_head

// For d_ff = 4d (standard ratio):
P_per_layer ≈ 4d^2 + 8d^2 = 12d^2
P_total ≈ n_layers × 12d^2 + (2V + L_max) × d    (untied output head)

// Target: ~10M parameters. Try d=256, n_layers=8:
P ≈ 8 × 12 × 256^2 + 13 × 256 + 32 × 256 + 256 × 13
P ≈ 8 × 786,432 + 3,328 + 8,192 + 3,328
P ≈ 6,291,456 + 14,848 = 6.3M    Too few.

// Try d=384, n_layers=6:
P ≈ 6 × 12 × 384^2 + overhead
P ≈ 6 × 1,769,472 + ~25K = 10.6M    Close!

// Final architecture:
d_model = 384, n_heads = 6, n_layers = 6, d_ff = 1536
max_seq_len = 32, vocab_size = 13
Total parameters ≈ 10.6M

Hyperparameter | Value | Rationale
d_model | 384 | Divisible by 6 heads. Large enough to represent digit relationships.
n_heads | 6 | d_head = 384/6 = 64. Standard head dimension.
n_layers | 6 | Deep enough for carry propagation (needs multi-hop reasoning).
d_ff | 1536 | 4 × d_model. Standard FFN expansion ratio.
max_seq_len | 32 | Enough for "9 9 9 9 9 + 9 9 9 9 9 = 1 9 9 9 9 8" (18 tokens), with headroom for padding.
vocab_size | 13 | 10 digits + 3 special tokens.
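
The derivation above can be turned into a small calculator for trying other shapes. This is a minimal sketch that follows the same per-component formulas; the function name `count_params` is illustrative, and like the hand calculation it ignores attention biases and the final LayerNorm, so a real Flax model will report a slightly different total.

```python
def count_params(d_model, n_layers, d_ff=None, vocab=13, max_len=32):
    """Approximate parameter count using the formulas from the derivation."""
    d_ff = 4 * d_model if d_ff is None else d_ff
    embed = vocab * d_model                          # token embeddings
    pos = max_len * d_model                          # positional embeddings
    attn = 4 * d_model * d_model                     # Q, K, V, O projections
    ffn = 2 * d_model * d_ff + d_model + d_ff        # up + down, with biases
    ln = 4 * d_model                                 # two LayerNorms per layer
    head = d_model * vocab                           # untied output projection
    return embed + pos + n_layers * (attn + ffn + ln) + head


for d, layers in [(256, 8), (384, 6)]:
    print(f"d_model={d}, n_layers={layers}: {count_params(d, layers) / 1e6:.2f}M")
```

Sweeping `d_model` and `n_layers` this way is a quick check that a candidate shape lands near the 10M target before any model is built.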

Dataset Generation

We generate addition problems randomly. The max_digits argument controls difficulty, which the curriculum described later in this chapter raises gradually.

python
import random
import jax.numpy as jnp

def generate_addition_example(max_digits=5):
    """Generate one addition problem with answer."""
    # Random number of digits for each operand
    d1 = random.randint(1, max_digits)
    d2 = random.randint(1, max_digits)

    a = random.randint(0, 10**d1 - 1)
    b = random.randint(0, 10**d2 - 1)
    c = a + b

    # Convert to space-separated digits
    a_str = ' '.join(str(a))
    b_str = ' '.join(str(b))
    c_str = ' '.join(str(c))

    # Full sequence: "1 2 3 + 4 5 6 = 5 7 9"
    seq = f"{a_str} + {b_str} = {c_str}"
    return seq

def make_batch(batch_size, max_digits, max_len=32):
    """Generate a padded batch of addition problems."""
    inputs, targets, masks = [], [], []

    for _ in range(batch_size):
        seq = generate_addition_example(max_digits)
        # Trailing <PAD> doubles as the stop token that evaluation looks for
        tokens = encode(seq) + [12]
        eq_pos = tokens.index(11)  # position of "="

        # Autoregressive training: position t predicts token t+1
        inp = tokens[:-1]  # all tokens except last
        tgt = tokens[1:]   # all tokens except first

        # Pad to max_len
        pad_len = max_len - len(inp)
        inp = inp + [12] * pad_len
        tgt = tgt + [12] * pad_len

        # Loss mask: only supervise targets after "=" (answer digits + stop)
        mask = [0] * eq_pos + [1] * (len(tokens) - 1 - eq_pos) + [0] * pad_len

        inputs.append(inp)
        targets.append(tgt)
        masks.append(mask)  # one mask per example, not just the last one

    return {
        'input_ids': jnp.array(inputs),    # [B, max_len]
        'targets': jnp.array(targets),      # [B, max_len]
        'loss_mask': jnp.array(masks),      # [B, max_len]
    }
The loss mask is critical. We only compute loss on the answer tokens (after "="), plus the single trailing <PAD> that teaches the model when to stop. Without this mask, much of the loss would come from the operand tokens, which are random and cannot be predicted from context. That irreducible noise would dilute the gradient signal from the addition itself, and the reported loss would no longer measure whether the model can actually compute the sum.
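
To make the alignment concrete, here is a minimal sketch of the shift-and-mask step on one example, ignoring padding. The helper name `shift_and_mask` is illustrative and not part of the lesson's code; the token IDs follow the 13-token vocabulary defined earlier.

```python
EQ = 11  # token ID of "=" in the 13-token vocabulary


def shift_and_mask(tokens):
    """Return (inputs, targets, mask) for next-token training on one sequence."""
    inp, tgt = tokens[:-1], tokens[1:]
    eq_pos = tokens.index(EQ)
    # Position i predicts tgt[i] = tokens[i + 1], so supervision starts at "=".
    mask = [0] * eq_pos + [1] * (len(inp) - eq_pos)
    return inp, tgt, mask


tokens = [4, 8, 10, 5, 3, 11, 1, 0, 1]   # "4 8 + 5 3 = 1 0 1"
inp, tgt, mask = shift_and_mask(tokens)
supervised = [t for t, m in zip(tgt, mask) if m]
print(supervised)  # [1, 0, 1]
```

The first supervised position is the "=" token itself, because its target is the first answer digit. Appending a stop token to `tokens` before the shift extends the supervised span by one position.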

Training Loop in JAX/Flax/Optax

jax/flax
import jax
import jax.numpy as jnp
import flax.linen as nn
import optax

class TransformerBlock(nn.Module):
    d_model: int = 384
    n_heads: int = 6
    d_ff: int = 1536
    dropout_rate: float = 0.1

    @nn.compact
    def __call__(self, x, mask, deterministic=True):
        # Pre-norm attention
        h = nn.LayerNorm()(x)
        h = nn.MultiHeadDotProductAttention(
            num_heads=self.n_heads,
            qkv_features=self.d_model,
            deterministic=deterministic,
        )(h, h, mask=mask)
        h = nn.Dropout(self.dropout_rate)(h, deterministic=deterministic)
        x = x + h  # residual

        # Pre-norm FFN
        h = nn.LayerNorm()(x)
        h = nn.Dense(self.d_ff)(h)
        h = jax.nn.gelu(h)
        h = nn.Dense(self.d_model)(h)
        h = nn.Dropout(self.dropout_rate)(h, deterministic=deterministic)
        x = x + h  # residual
        return x

class AdditionTransformer(nn.Module):
    vocab_size: int = 13
    d_model: int = 384
    n_heads: int = 6
    n_layers: int = 6
    d_ff: int = 1536
    max_len: int = 32

    @nn.compact
    def __call__(self, x, deterministic=True):
        B, S = x.shape

        # Token + positional embeddings
        tok_emb = nn.Embed(self.vocab_size, self.d_model)(x)
        pos_emb = nn.Embed(self.max_len, self.d_model)(jnp.arange(S))
        h = tok_emb + pos_emb  # [B, S, d_model]

        # Causal mask: each position can only attend to earlier positions
        causal_mask = jnp.tril(jnp.ones((S, S)))[None, None, :, :]

        # Transformer layers
        for _ in range(self.n_layers):
            h = TransformerBlock(
                self.d_model, self.n_heads, self.d_ff
            )(h, causal_mask, deterministic)

        # Final LayerNorm + output projection
        h = nn.LayerNorm()(h)
        logits = nn.Dense(self.vocab_size)(h)  # [B, S, vocab_size]
        return logits

# ── Training setup ──
model = AdditionTransformer()
rng = jax.random.PRNGKey(42)
dummy = jnp.zeros((1, 32), dtype=jnp.int32)
params = model.init(rng, dummy)

# Count parameters
n_params = sum(p.size for p in jax.tree.leaves(params))
print(f"Total parameters: {n_params:,}")  # ~10.6M

# Optimizer: AdamW with cosine decay
schedule = optax.warmup_cosine_decay_schedule(
    init_value=0.0,
    peak_value=3e-4,
    warmup_steps=500,
    decay_steps=50000,
    end_value=1e-5,
)
optimizer = optax.adamw(schedule, weight_decay=0.01)
opt_state = optimizer.init(params)

# ── Loss function ──
def loss_fn(params, batch, dropout_rng):
    logits = model.apply(params, batch['input_ids'], deterministic=False,
                         rngs={'dropout': dropout_rng})
    # Cross-entropy loss, masked to answer tokens only
    labels = jax.nn.one_hot(batch['targets'], 13)
    loss = -jnp.sum(labels * jax.nn.log_softmax(logits), axis=-1)
    loss = (loss * batch['loss_mask']).sum() / batch['loss_mask'].sum()
    return loss

# ── Training step ──
@jax.jit
def train_step(params, opt_state, batch, dropout_rng):
    loss, grads = jax.value_and_grad(loss_fn)(params, batch, dropout_rng)
    updates, opt_state = optimizer.update(grads, opt_state, params)
    params = optax.apply_updates(params, updates)
    return params, opt_state, loss

# ── Training loop ──
for step in range(50000):
    # Fresh dropout key each step; a fixed key would reuse the same dropout mask
    rng, dropout_rng = jax.random.split(rng)
    batch = make_batch(batch_size=128, max_digits=5)
    params, opt_state, loss = train_step(params, opt_state, batch, dropout_rng)
    if step % 500 == 0:
        print(f"Step {step:5d} | Loss: {loss:.4f}")

Chinchilla Scaling Laws for This Toy Model

The Chinchilla paper found that for a compute budget C, the optimal model size and token count scale as:

N_opt ≈ (C / 120)^0.5 ≈ 0.22 × (C / 6)^0.5     optimal parameters
D_opt ≈ 20 × N_opt ≈ 4.5 × (C / 6)^0.5     optimal tokens
D_opt / N_opt ≈ 20     ~20 tokens per parameter (from C = 6 × N × D)

For our 10.6M parameter model, the Chinchilla-optimal token budget is approximately 10.6M × 20 = 212M tokens. Each training example is about 15 tokens on average, so we need about 212M / 15 ≈ 14 million training examples. At batch size 128, that is about 110,000 training steps.

// Chinchilla analysis for our capstone:
N = 10.6M params
D_opt = 20 × N = 212M tokens
Avg tokens per example = 15
Examples needed = 212M / 15 = 14.1M
Steps at batch 128 = 14.1M / 128 ≈ 110K steps

// Compute budget:
C = 6 × N × D = 6 × 10.6M × 212M = 1.35 × 10^16 FLOPs
// On TPU v2 (~45 TFLOPS): 1.35e16 / 45e12 = ~300 seconds = 5 minutes
// On Colab TPU: maybe 10-20 minutes with overhead
Sanity check. Training costs about 6 FLOPs per parameter per token, so 212M tokens × (6 × 10.6M) FLOPs per token ≈ 1.35 × 10^16 FLOPs. A Colab TPU v2 chip peaks at roughly 45 TFLOPS, so the Chinchilla-optimal training run takes ~300 seconds of raw compute, or about 10-20 minutes with data loading and Python overhead. This is genuinely doable on a free Colab TPU in a single session. Note that the training loop above stops at 50,000 steps (about 96M tokens at 128 × 15 tokens per step); to train to the Chinchilla-optimal point, raise both the loop length and decay_steps to about 110,000.
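
The same arithmetic as a reusable helper, so the plan can be recomputed for other model sizes. This is a sketch under the text's assumptions (20 tokens per parameter, 15 tokens per example, batch size 128, 45 TFLOPS peak); the function name `chinchilla_plan` and its keys are illustrative.

```python
def chinchilla_plan(n_params, tokens_per_param=20, tokens_per_example=15,
                    batch_size=128, peak_flops=45e12):
    """Token budget, step count, and compute for a given model size."""
    tokens = tokens_per_param * n_params
    examples = tokens / tokens_per_example
    flops = 6 * n_params * tokens              # C = 6 * N * D
    return {
        "tokens": tokens,
        "examples": examples,
        "steps": examples / batch_size,
        "flops": flops,
        "seconds_at_peak": flops / peak_flops,  # lower bound: assumes 100% utilization
    }


plan = chinchilla_plan(10.6e6)
for key, value in plan.items():
    print(f"{key:>16}: {value:.3g}")
```

The `seconds_at_peak` figure is a floor, not a forecast: a model this small rarely keeps the accelerator fully busy, which is why the text budgets 10-20 minutes.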

Dense vs MoE: Implementing Both

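For the capstone comparison, we implement the same model in both dense and MoE variants. The MoE variant replaces the FFN in every other layer (layers 1, 3, 5) with an MoE layer. With top-2 routing and full-width experts (d_ff = 1536 each), every MoE layer runs two FFNs per token, so active parameters rise from 10.6M to about 14.2M. Halving the expert width to 768 would bring the active count back to the dense model's 10.6M.

Config | Dense | MoE (8 experts, top-2)
d_model | 384 | 384
n_layers | 6 | 6 (3 dense FFN + 3 MoE)
d_ff | 1536 | 1536 per expert
Active params per token | 10.6M | ~14.2M (2 of 8 experts in each MoE layer)
Total params | 10.6M | ~35M (3 MoE layers × 8 experts each)
Training FLOPs per token (≈ 6 × active params) | ~64M | ~85M
Memory (FP32 weights) | ~42 MB | ~142 MB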
jax/flax
class MoETransformerBlock(nn.Module):
    """Transformer block with MoE FFN instead of dense FFN."""
    d_model: int = 384
    n_heads: int = 6
    d_ff: int = 1536
    num_experts: int = 8
    top_k: int = 2

    @nn.compact
    def __call__(self, x, mask, deterministic=True):
        # Pre-norm attention (identical to dense)
        h = nn.LayerNorm()(x)
        h = nn.MultiHeadDotProductAttention(
            num_heads=self.n_heads,
            qkv_features=self.d_model,
            deterministic=deterministic,
        )(h, h, mask=mask)
        x = x + h

        # Pre-norm MoE FFN (replaces dense FFN)
        h = nn.LayerNorm()(x)
        h_flat = h.reshape(-1, self.d_model)  # [B*S, D]
        moe_out, aux_loss = MoELayer(
            num_experts=self.num_experts,
            top_k=self.top_k,
            expert_dim=self.d_ff,
        )(h_flat[jnp.newaxis])  # MoELayer expects [B, S, D]
        moe_out = moe_out.reshape(x.shape)

        x = x + moe_out
        self.sow('intermediates', 'aux_loss', aux_loss)
        return x

class MoEAdditionTransformer(nn.Module):
    """Addition transformer with MoE in alternating layers."""
    vocab_size: int = 13
    d_model: int = 384
    n_layers: int = 6

    @nn.compact
    def __call__(self, x, deterministic=True):
        B, S = x.shape
        tok_emb = nn.Embed(self.vocab_size, self.d_model)(x)
        pos_emb = nn.Embed(32, self.d_model)(jnp.arange(S))
        h = tok_emb + pos_emb

        causal_mask = jnp.tril(jnp.ones((S, S)))[None, None, :, :]

        for i in range(self.n_layers):
            if i % 2 == 1:  # Odd layers get MoE
                h = MoETransformerBlock()(h, causal_mask, deterministic)
            else:  # Even layers stay dense
                h = TransformerBlock()(h, causal_mask, deterministic)

        h = nn.LayerNorm()(h)
        return nn.Dense(self.vocab_size)(h)
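
Total and active parameter counts for the MoE variant depend heavily on the expert width, so it is worth computing them rather than estimating. The sketch below assumes each expert is a two-matrix FFN with biases and that the router is a bias-free d_model × num_experts projection; `MoELayer` is defined elsewhere in this lesson, so treat those details, and the helper names, as assumptions.

```python
def ffn_params(d_model, d_ff):
    """Up + down projections with biases."""
    return 2 * d_model * d_ff + d_model + d_ff


def moe_param_counts(dense_total, d_model=384, dense_d_ff=1536,
                     expert_d_ff=1536, n_moe_layers=3, num_experts=8, top_k=2):
    """Return (total, active-per-token) parameters for the MoE variant."""
    # Everything except the FFNs that the MoE layers replace.
    shared = dense_total - n_moe_layers * ffn_params(d_model, dense_d_ff)
    expert = ffn_params(d_model, expert_d_ff)
    router = d_model * num_experts
    total = shared + n_moe_layers * (num_experts * expert + router)
    active = shared + n_moe_layers * (top_k * expert + router)
    return total, active


for width in (1536, 768):
    total, active = moe_param_counts(10.6e6, expert_d_ff=width)
    print(f"expert d_ff={width}: total {total / 1e6:.1f}M, active {active / 1e6:.1f}M")
```

With top-2 routing, experts at half the dense FFN width keep the active count equal to the dense model, while full-width experts roughly double the FFN work in each MoE layer.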

Understanding the Attention Pattern

What does the model actually learn? After training, we can visualize the attention patterns to understand the algorithm the model has discovered.

For a problem like "4 8 + 5 3 = 1 0 1", the model needs to:

Step 1: Align digits by position
The ones digit of the first number (8) must attend to the ones digit of the second number (3). This is a positional alignment problem — the model uses positional embeddings to learn which positions correspond.
Step 2: Compute single-digit sums
8 + 3 = 11. This requires looking up the addition table from the training data. The model stores this in its weights.
Step 3: Determine carry
11 ≥ 10, so carry = 1, output digit = 1. The carry information must propagate to the next position.
Step 4: Propagate carry
Tens place: 4 + 5 + carry(1) = 10. Again ≥ 10, so carry = 1, output digit = 0. This chain can continue for arbitrarily many digits.
Step 5: Handle leading digit
If the final carry is 1, output an extra leading 1. The model must learn to predict the output length, which changes when carries overflow the most significant digit.
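
The five steps above are the schoolbook algorithm, sketched here in plain Python as a reference for what the transformer has to approximate. The function name `add_digit_lists` is illustrative; digits are most-significant-first, matching the training format.

```python
def add_digit_lists(a_digits, b_digits):
    """Add two most-significant-first digit lists, following the five steps."""
    # Step 1: align digits by position (pad the shorter operand with zeros).
    width = max(len(a_digits), len(b_digits))
    a = [0] * (width - len(a_digits)) + list(a_digits)
    b = [0] * (width - len(b_digits)) + list(b_digits)

    out, carry = [], 0
    for x, y in zip(reversed(a), reversed(b)):   # ones column first
        total = x + y + carry    # Steps 2 and 4: column sum plus incoming carry
        carry = total // 10      # Step 3: carry out of this column
        out.append(total % 10)
    if carry:                    # Step 5: extra leading digit
        out.append(carry)
    return out[::-1]


print(add_digit_lists([4, 8], [5, 3]))   # [1, 0, 1]
```

The loop runs from the ones column upward, while the model must emit the most significant digit first. That mismatch is why the leading output digit depends on every carry below it.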

Empirically, trained addition transformers develop a characteristic attention pattern: early layers attend strongly to positionally-aligned digits (learning the alignment), middle layers attend to adjacent output positions (propagating carries), and the final layers attend globally (handling edge cases like leading digits).

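// Attention analysis code (run after training)
// Extract attention weights from each layer and head
_, intermediates = model.apply(params, test_input, capture_intermediates=True)
// Note: capture_intermediates records each module's output, not the attention
// probabilities. The attention call has to store the weights itself (recent
// Flax versions take a sow_weights flag for this; check your installed version).
// For each layer l and head h:
// attn_weights[l][h] is [seq_len, seq_len]
// Look at row i (what position i attends to) at the "=" token
// and the output digit positions.

// Token positions for "4 8 + 5 3 = 1 0 1":
// 0:"4" 1:"8" 2:"+" 3:"5" 4:"3" 5:"=" 6:"1" 7:"0" 8:"1"
// Position i predicts token i+1 and, under the causal mask, sees only positions ≤ i.

// Plausible pattern:
// Layer 0: Position 7 (predicts the ones digit "1") attends to positions 1("8") and 4("3")
// Layer 2: Position 6 (predicts the tens digit "0") attends to positions 1 and 4 to recover the carry out of the ones column
// Layer 5: Position 6 also attends to positions 0("4") and 3("5") to finish the tens-column sum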

The Training Schedule: Hyperparameter Choices

Every hyperparameter choice has a reason. Here is the complete training configuration with justifications.

Hyperparameter | Value | Why
Optimizer | AdamW | Standard for transformers. Adam handles the varying gradient scales across attention and FFN layers. Weight decay prevents overfitting on the finite addition patterns.
Peak LR | 3e-4 | Standard for small transformers. Too high (>1e-3) causes training instability. Too low (<1e-5) slows convergence.
Warmup steps | 500 | Prevents early loss spikes from large random gradients. 1% of total training steps is a good rule of thumb.
Schedule | Cosine decay | Smooth decay from peak to near-zero. Better than step decay for transformers because it avoids discontinuous learning rate changes.
Weight decay | 0.01 | Mild regularization. The model should not memorize specific addition pairs but learn the algorithm. Weight decay encourages simpler solutions.
Batch size | 128 | Large enough for stable gradient estimates. Small enough to fit in TPU memory. With 128 problems per batch, the model sees ~1920 tokens per step.
Dropout | 0.1 | Regularization during training. Prevents co-adaptation of attention heads. Applied to the attention output and the FFN output.
Max seq length | 32 | The longest possible sequence is "9 9 9 9 9 + 9 9 9 9 9 = 1 9 9 9 9 8" (18 tokens). 32 gives headroom for padding.
The overfitting test. Before training the full model, train a tiny model (d_model=64, 2 layers, ~100K params) on a fixed set of 1000 addition problems for 10K steps. If it cannot memorize this small set (loss does not approach 0), something is wrong with your code — likely a bug in the loss mask, the causal mask, or the autoregressive generation loop. This sanity check takes 30 seconds and saves hours of debugging.

The Expected Results

Based on the scaling laws and empirical results from similar toy experiments, here is what you should expect:

Metric | Dense 10.6M | MoE ~35M (~14.2M active)
Training loss at 50K steps | ~0.05 | ~0.03
Exact-match accuracy (5-digit) | ~92% | ~96%
Exact-match accuracy (6-digit, OOD) | ~60% | ~75%
Training time (Colab TPU) | ~15 min | ~18 min
Step time | ~18 ms | ~22 ms

The MoE model achieves lower loss at the same step count because it has roughly 3x the total parameters, giving it more capacity to memorize digit patterns and carry rules. The step time is only about 20% longer because each token still touches just 2 of the 8 experts. The overhead comes from the second active expert in each MoE layer, routing, and the auxiliary loss computation.

Writing the Pallas Kernel for the MoE Variant

For the MoE variant, the ragged_dot inside the MoE layer is the bottleneck. Your Pallas kernel takes the tokens already sorted by expert and fuses each expert's up-projection, GELU, and down-projection into a single kernel, so the intermediate activation is never written back to HBM.

jax/pallas
import jax
import jax.numpy as jnp

def moe_fused_kernel(
    x_ref,       # [BLOCK_M, D]: sorted input tokens
    w_up_ref,    # [D, D_FF_BLOCK]: expert up-projection tile
    w_down_ref,  # [D_FF_BLOCK, D]: expert down-projection tile
    o_ref,       # [BLOCK_M, D]: output accumulator
):
    """Fused expert forward: up-project, GELU, down-project."""
    # Up-projection
    h = x_ref[...] @ w_up_ref[...]  # [BLOCK_M, D_FF_BLOCK]

    # Exact GELU, 0.5 * h * erfc(-h / sqrt(2)), fused: no round-trip to HBM
    h = h * jax.lax.erfc(-h / jnp.sqrt(2.0)) * 0.5

    # Down-projection, accumulated across D_FF blocks.
    # o_ref must be zeroed on the first D_FF block before this += runs.
    o_ref[...] += h @ w_down_ref[...]  # [BLOCK_M, D]
Why this fused kernel matters. Without fusion, the standard MoE forward pass is: (1) sort tokens by expert, (2) up-project (matmul, write to HBM), (3) read from HBM, apply GELU (write to HBM), (4) read from HBM, down-project (write to HBM), (5) unsort. That is four HBM transfers of the [tokens, d_ff] activation tensor (two writes and two reads) that exist only to hand data from one op to the next. The fused kernel does steps 2-4 in a single kernel: the intermediate activation h never leaves fast on-chip memory. On memory-bound expert sizes, this can give a 2-3x speedup.
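
A rough traffic model shows how much HBM movement the fusion removes. This is a back-of-the-envelope sketch, not a measurement: it assumes BF16 (2 bytes per element), counts each tensor as moved once per op, and ignores re-reads caused by tiling. The function name `expert_ffn_traffic` is illustrative.

```python
def expert_ffn_traffic(tokens, d_model=384, d_ff=1536, bytes_per_elem=2):
    """Return (unfused_bytes, fused_bytes) of HBM traffic for one expert FFN."""
    x = tokens * d_model * bytes_per_elem      # input (and, equally, output) activations
    h = tokens * d_ff * bytes_per_elem         # hidden activations
    w = 2 * d_model * d_ff * bytes_per_elem    # up + down projection weights
    fused = x + w + x                          # read x, read weights, write output
    unfused = fused + 4 * h                    # h and GELU(h): 2 writes + 2 reads
    return unfused, fused


unfused, fused = expert_ffn_traffic(tokens=4096)
print(f"unfused: {unfused / 1e6:.1f} MB, fused: {fused / 1e6:.1f} MB")
```

The saved traffic grows with the number of tokens routed to the expert, while the weight traffic stays fixed, so the benefit is largest when each expert sees many tokens.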

The Roofline Model: Predicting Kernel Performance

Before writing a kernel, you should know its theoretical maximum performance. The roofline model gives you this: it plots achievable performance (FLOPS) against arithmetic intensity (FLOPS per byte of memory traffic).

// Roofline model for TPU v4:
Peak compute: 275 TFLOPS (BF16)
Peak bandwidth: 1.2 TB/s (HBM)

// The "ridge point", where memory-bound meets compute-bound:
Ridge = Peak compute / Peak bandwidth
Ridge = (275 × 10^12) / (1.2 × 10^12) ≈ 229 FLOPs/byte

// For a kernel with arithmetic intensity I:
// If I < 229: memory-bound. Achievable = I × 1.2 TB/s
// If I ≥ 229: compute-bound. Achievable = 275 TFLOPS

// Square N × N matmul in BF16: 2N^3 FLOPs over 3 × N^2 × 2 bytes (read A and B, write C)
// I = 2N^3 / 6N^2 = N/3 FLOPs/byte
// For N=4096: I ≈ 1365 FLOPs/byte >> 229. Compute-bound. Good.
// For N=32: I ≈ 11 FLOPs/byte << 229. Memory-bound. Kernel will starve.

// Vector add (BF16): 1 FLOP per 6 bytes moved, I ≈ 0.17 FLOPs/byte. Always memory-bound.
// GELU: I = ~5 FLOPs/byte. Always memory-bound.
// This is why fusing GELU into matmul helps: GELU rides for free
// on data already in fast memory.

When you profile a Pallas kernel and find it achieves 50 TFLOPS instead of the theoretical 275 TFLOPS, the roofline model tells you whether the problem is (a) your kernel is memory-bound and 50 TFLOPS is actually close to the ceiling, or (b) your kernel has poor compute utilization and there is room to improve.
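
The roofline check is a one-liner once the two hardware numbers are fixed. The sketch below uses the TPU v4 figures quoted in this section and counts matmul traffic as reading A and B and writing C once each in BF16 (an idealized assumption). The function names are illustrative.

```python
PEAK_FLOPS = 275e12   # TPU v4 BF16 peak, as quoted in this section
PEAK_BW = 1.2e12      # TPU v4 HBM bandwidth in bytes/s, as quoted in this section


def attainable_flops(intensity):
    """Roofline ceiling for a kernel with the given FLOPs-per-byte intensity."""
    return min(PEAK_FLOPS, intensity * PEAK_BW)


def matmul_intensity(n, bytes_per_elem=2):
    """Arithmetic intensity of a square n x n matmul (read A, B; write C)."""
    return (2 * n**3) / (3 * n * n * bytes_per_elem)


def roof_fraction(achieved_flops, intensity):
    """How close a measured kernel is to its own roofline ceiling."""
    return achieved_flops / attainable_flops(intensity)


ridge = PEAK_FLOPS / PEAK_BW
for n in (32, 4096):
    i = matmul_intensity(n)
    bound = "compute" if i >= ridge else "memory"
    print(f"N={n}: I={i:.1f} FLOPs/byte, {bound}-bound, "
          f"ceiling {attainable_flops(i) / 1e12:.1f} TFLOPS")
```

A kernel measured at 50 TFLOPS is in good shape if `roof_fraction` is near 1 for its intensity, and has headroom if the fraction is small.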

Benchmarking Pallas Kernels

Correct benchmarking of accelerator kernels is tricky. Here are the common pitfalls and how to avoid them.

jax
import jax
import jax.numpy as jnp
import time

def benchmark_kernel(fn, *args, warmup=5, repeats=20):
    """Correct benchmarking for JAX/Pallas kernels."""

    # Step 1: Compile the kernel (JIT warmup)
    # First call triggers compilation — don't time this!
    out = fn(*args)
    out.block_until_ready()  # CRITICAL: wait for async compute

    # Step 2: Warmup runs (saturate caches, warm up the device)
    for _ in range(warmup):
        out = fn(*args)
        out.block_until_ready()

    # Step 3: Timed runs
    times = []
    for _ in range(repeats):
        start = time.perf_counter()
        out = fn(*args)
        out.block_until_ready()  # CRITICAL: JAX is async!
        end = time.perf_counter()
        times.append(end - start)

    # Report median (not mean: a single outlier would skew the mean)
    times.sort()
    median_s = times[len(times) // 2]
    median_ms = median_s * 1000
    print(f"Median: {median_ms:.3f} ms")
    print(f"Min: {times[0]*1000:.3f} ms, Max: {times[-1]*1000:.3f} ms")

    # Calculate FLOPS utilization
    # flops = ... (compute for your specific operation)
    # print(f"Achieved: {flops / median_s / 1e12:.1f} TFLOPS")

    return median_ms

# COMMON MISTAKES:
# 1. Forgetting block_until_ready() — times only CPU dispatch, not compute
# 2. Including JIT compilation in the timing
# 3. Using mean instead of median (one outlier ruins your numbers)
# 4. Not warming up (first few runs are slower due to cache effects)
# 5. Benchmarking tiny inputs (overhead dominates, not representative)
The block_until_ready() trap. JAX operations are asynchronous: when you call jnp.matmul(a, b), JAX returns a future immediately while the actual computation happens on the accelerator. If you time the function call without block_until_ready(), you are timing how fast Python can enqueue work (microseconds), not how fast the hardware executes it (milliseconds). This is the #1 benchmarking mistake in JAX. Always call .block_until_ready() on the output before reading the timer.

The Complete Kernel Development Workflow

Putting it all together, here is the workflow for writing a production Pallas kernel from scratch:

1. Write Reference in JAX
Implement the operation using jnp.matmul, jnp.einsum, or other high-level ops. This is your correctness oracle. Profile it to establish the baseline speed.
2. Compute the Roofline
Calculate arithmetic intensity. Is this operation compute-bound or memory-bound? If memory-bound, fusion (not a custom kernel) is probably the answer.
3. Choose Block Sizes
Start with the hardware's natural tile size (128x128 for TPU v4). Check VMEM/SMEM budget. Adjust if tiles don't fit or if grid is too large.
4. Write the Simplest Correct Kernel
One grid cell, one accumulation loop. Test against the reference on random + structured inputs. Don't optimize yet.
5. Benchmark and Profile
Use benchmark_kernel() above. Compare against the JAX reference. If slower, check: are you memory-bound? Is the grid too fine-grained?
6. Optimize
Try different block sizes (sweep 64, 128, 256). Add prefetch hints (TPU). Fuse adjacent operations. Re-benchmark after each change.
7. Integration Test
Drop the kernel into the actual training loop. Verify that the model trains the same way (the loss curve matches the reference within numerical tolerance). Measure end-to-end step time improvement.

The "Proof of Work" Documentation

Once you have completed the capstone, you need to document it in a way that a frontier lab interviewer can evaluate. Here is the structure that maximizes signal:

1. README.md
One-paragraph summary: "I trained a 10.6M parameter transformer to learn addition, compared dense vs MoE, and wrote a custom Pallas kernel." Link to Colab notebook.
2. Scaling Law Analysis
Plot: loss vs compute for 3-5 model sizes. Fit Chinchilla power law. Show predicted vs actual optimal training duration. This demonstrates you understand why, not just how.
3. Dense vs MoE Comparison
Same compute budget, different architectures. Table: loss, accuracy, step time. Plot: learning curves overlaid. Analysis of expert specialization (which experts handle carry digits?).
4. Pallas Kernel + Benchmark
Side-by-side comparison: ragged_dot vs custom kernel. Benchmark table across group size distributions. Roofline model showing achieved vs theoretical performance.
5. Expert Analysis
Visualize routing patterns. Do experts specialize by digit position? By carry vs no-carry? By operand length? This is the kind of analysis that shows intellectual curiosity.

Evaluation: Measuring What Matters

Training loss is a proxy. What we actually care about is: does the model get the right answer? For addition, we have a crisp metric: exact-match accuracy. The model's output must be exactly correct — "101" is wrong if the answer is "100." There is no partial credit.

python
def evaluate(params, model, num_samples=1000, max_digits=5):
    """Evaluate exact-match accuracy on random addition problems."""
    correct = 0
    for _ in range(num_samples):
        # Generate a random problem
        d1 = random.randint(1, max_digits)
        d2 = random.randint(1, max_digits)
        a = random.randint(0, 10**d1 - 1)
        b = random.randint(0, 10**d2 - 1)
        expected = a + b

        # Encode input: "1 2 3 + 4 5 = "
        prompt = ' '.join(str(a)) + ' + ' + ' '.join(str(b)) + ' ='
        tokens = encode(prompt)
        input_ids = jnp.array([tokens + [12] * (32 - len(tokens))])

        # Autoregressive generation
        generated = []
        for _step in range(8):  # max answer length (6 digits) plus stop token
            logits = model.apply(params, input_ids)
            # Logits at the last real token predict the next token
            next_token = int(jnp.argmax(logits[0, len(tokens) - 1]))
            if next_token == 12:  # PAD = stop
                break
            generated.append(next_token)
            # Append to input for next step
            tokens.append(next_token)
            input_ids = jnp.array([tokens + [12] * (32 - len(tokens))])

        # Correct only if every generated token is a digit and the number matches
        is_number = bool(generated) and all(t < 10 for t in generated)
        predicted = int(''.join(str(d) for d in generated)) if is_number else -1
        if predicted == expected:
            correct += 1

    return correct / num_samples

Error Analysis: Where the Model Fails

Raw accuracy hides the most interesting patterns. Break down errors by type to understand how the model fails, not just that it fails.

Error Type | Example | Frequency | What It Tells You
Carry propagation | 999 + 1 = 990 (missed carry chain) | ~60% of errors | Model has not learned multi-hop carry. Need more training or deeper model.
Off-by-one digit | 123 + 456 = 579 (correct!) but 123 + 457 = 579 (wrong by 1) | ~20% of errors | Model is "almost right": the logit margin for the correct vs incorrect digit is tiny. More training data at this difficulty level helps.
Wrong length | 99 + 1 = 10 (should be 100, missing a digit) | ~10% of errors | Model has not learned that addition can increase the number of digits. This is fundamentally about predicting the output length before generating.
Positional confusion | 12 + 34 = 55 (treated as 21 + 34) | ~5% of errors | Positional encoding is not strong enough to distinguish digit positions. Consider explicit positional markers or reversed digit order.
Random/garbage | 123 + 456 = 333 | ~5% of errors | Model has not converged for this input range. Insufficient training.
python
def error_analysis(params, model, num_samples=500):
    """Categorize errors by type."""
    errors = {'carry': 0, 'off_by_one': 0, 'wrong_length': 0,
              'positional': 0, 'random': 0}
    total_errors = 0

    for _ in range(num_samples):
        a, b = random.randint(0, 99999), random.randint(0, 99999)
        expected = a + b
        predicted = generate_answer(params, model, a, b)

        if predicted != expected:
            total_errors += 1

            # Classify the error
            exp_str, pred_str = str(expected), str(predicted)
            if len(exp_str) != len(pred_str):
                errors['wrong_length'] += 1
            elif abs(expected - predicted) == 1:
                errors['off_by_one'] += 1
            elif needs_carry(a, b) and carry_error(expected, predicted):
                errors['carry'] += 1
            else:
                errors['random'] += 1

    print(f"Total errors: {total_errors}/{num_samples}")
    for k, v in errors.items():
        print(f"  {k}: {v} ({v/max(total_errors,1)*100:.1f}%)")
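
The error_analysis function above calls needs_carry and carry_error without defining them. One hypothetical way to fill them in is sketched below. The carry_error rule is a heuristic assumption (every wrong digit is exactly one too low, modulo 10, which is what a dropped carry looks like), not a definition taken from the lesson.

```python
def needs_carry(a, b):
    """True if adding a and b produces a carry in at least one column."""
    while a > 0 and b > 0:
        # The first carry always comes from a column whose two digits sum to >= 10.
        if a % 10 + b % 10 >= 10:
            return True
        a //= 10
        b //= 10
    return False


def carry_error(expected, predicted):
    """Heuristic: same length, and each wrong digit is one too low (mod 10)."""
    if predicted < 0:
        return False
    e, p = str(expected), str(predicted)
    if len(e) != len(p) or e == p:
        return False
    return all((int(x) - int(y)) % 10 == 1 for x, y in zip(e, p) if x != y)


print(needs_carry(48, 53), needs_carry(123, 456))   # True False
print(carry_error(580, 570), carry_error(580, 333))  # True False
```

Because error_analysis tests the length mismatch first, a missed carry that also shortens the answer (such as 99 + 1 = 10) is counted as wrong_length rather than carry. Reordering the checks changes the reported breakdown.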
The reversed-digit trick. A known technique for improving addition accuracy: reverse the digit order so the ones digit comes first. "123 + 456 = 579" becomes "3 2 1 + 6 5 4 = 9 7 5". This makes carry propagation easier because the model processes digits from least significant to most significant — the same order a human adds. Papers report 10-15% accuracy improvement from this simple change. In your capstone, try both orderings and compare.
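
Switching between the two orderings only requires changing how examples are formatted and how answers are parsed. A minimal sketch, with illustrative helper names `format_example` and `parse_answer`:

```python
def format_example(a, b, reverse=False):
    """Format "a + b = c" as space-separated digits, optionally ones-digit first."""
    def digits(n):
        s = str(n)
        return ' '.join(reversed(s) if reverse else s)
    return f"{digits(a)} + {digits(b)} = {digits(a + b)}"


def parse_answer(text, reverse=False):
    """Turn a space-separated digit string back into an integer."""
    s = ''.join(text.split())
    return int(s[::-1] if reverse else s)


print(format_example(123, 456))                 # 1 2 3 + 4 5 6 = 5 7 9
print(format_example(123, 456, reverse=True))   # 3 2 1 + 6 5 4 = 9 7 5
```

Evaluation must use the same `reverse` setting as training. Otherwise a correct reversed answer is scored as wrong.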

Curriculum Learning: Teaching Addition Gradually

Throwing 5-digit addition at a randomly initialized model is like teaching calculus before arithmetic. Curriculum learning starts with easy problems and gradually increases difficulty. For addition:

// Curriculum schedule:
Steps 0-5K:       max_digits = 2    (0+0 through 99+99)
Steps 5K-15K:     max_digits = 3    (add 3-digit numbers)
Steps 15K-30K:    max_digits = 4    (add 4-digit numbers)
Steps 30K-50K:    max_digits = 5    (full difficulty)

// Implementation: just change the argument to make_batch()
// One extra digit at each boundary of the schedule above (5K, 15K, 30K)
max_d = 2 + sum(step >= s for s in (5000, 15000, 30000))
batch = make_batch(batch_size=128, max_digits=max_d)

Why does curriculum learning help? The model first masters the basic pattern (single-digit addition, no carry), then learns single carries, then carry chains. Each skill builds on the previous one. Without curriculum, the model sees hard problems before it has learned the prerequisites, wasting gradient signal on problems it has no hope of solving yet.

Scaling Law Experiment Design

To derive Chinchilla-style scaling laws for your toy model, you need to train multiple model sizes at multiple compute budgets. Here is the experiment grid:

Model Size | d_model | n_layers | Params | Token Budgets to Test
Tiny | 128 | 4 | 0.8M | 5M, 10M, 20M, 50M tokens
Small | 256 | 4 | 3.1M | 10M, 30M, 60M, 120M tokens
Medium | 384 | 6 | 10.6M | 50M, 100M, 200M, 400M tokens
Large | 512 | 8 | 25M | 100M, 250M, 500M, 1B tokens
XL | 768 | 8 | 56M | 250M, 500M, 1B, 2B tokens

For each (model_size, token_budget) pair, train to completion and record the final loss. Then plot loss vs compute (= 6 × params × tokens) and fit the power law L(C) = A × C^(-α) + L_0. The fitted parameters A, α, and L_0 are your toy-model scaling law. For reference, Chinchilla's parametric fit L(N, D) = E + A/N^α + B/D^β uses A = 406.4 and α = 0.34 for its parameter term. Your values will differ because addition is not natural language and your fit is in C rather than N, but the power-law structure should hold.
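
Once the irreducible loss L_0 is fixed, the power law becomes a straight line in log-log space and can be fit with ordinary least squares. The sketch below assumes L_0 is known or scanned over a small grid by the caller. The data points are synthetic, generated from made-up constants purely to check that the fit recovers them; they are not experimental results.

```python
import math


def fit_power_law(compute, loss, l0):
    """Fit L = A * C**(-alpha) + l0 by least squares on log(L - l0) vs log(C)."""
    xs = [math.log(c) for c in compute]
    ys = [math.log(l - l0) for l in loss]
    n = len(xs)
    mx, my = sum(xs) / n, sum(ys) / n
    slope = (sum((x - mx) * (y - my) for x, y in zip(xs, ys))
             / sum((x - mx) ** 2 for x in xs))
    intercept = my - slope * mx
    return math.exp(intercept), -slope   # (A, alpha)


# Synthetic check with made-up constants (not measured values).
true_a, true_alpha, true_l0 = 50.0, 0.2, 0.02
cs = [1e14, 1e15, 1e16, 1e17]
ls = [true_a * c ** -true_alpha + true_l0 for c in cs]
a_fit, alpha_fit = fit_power_law(cs, ls, l0=true_l0)
print(f"A = {a_fit:.3f}, alpha = {alpha_fit:.3f}")
```

With real runs, a simple way to choose L_0 is to try several candidates below the smallest observed loss and keep the one with the lowest residual in log space.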

All 20 runs on Colab TPU. The total compute is the sum of 6 × N × D over all 20 pairs, about 1.6 × 10^18 FLOPs. At the 45 TFLOPS peak of a TPU v2, the largest run (56M params, 2B tokens) alone needs about 4 hours and the whole grid about 10 hours, so budget more than that in practice and spread the runs over several Colab sessions. This is still feasible as a weekend project.
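
The compute for the whole grid follows directly from C = 6 × N × D. The sketch below tabulates it from the experiment table, assuming a 45 TFLOPS peak and perfect utilization, so the hours it prints are lower bounds. The names `GRID` and `run_hours` are illustrative.

```python
PEAK_FLOPS = 45e12  # assumed TPU v2 peak; real throughput will be lower

# (params, token budgets) copied from the experiment grid table
GRID = {
    "Tiny":   (0.8e6,  [5e6, 10e6, 20e6, 50e6]),
    "Small":  (3.1e6,  [10e6, 30e6, 60e6, 120e6]),
    "Medium": (10.6e6, [50e6, 100e6, 200e6, 400e6]),
    "Large":  (25e6,   [100e6, 250e6, 500e6, 1e9]),
    "XL":     (56e6,   [250e6, 500e6, 1e9, 2e9]),
}


def run_hours(params, tokens, flops_per_s=PEAK_FLOPS):
    """Hours of raw compute for one run at the given sustained throughput."""
    return 6 * params * tokens / flops_per_s / 3600


n_runs = sum(len(budgets) for _, budgets in GRID.values())
total_hours = sum(run_hours(p, d) for p, budgets in GRID.values() for d in budgets)
largest_hours = max(run_hours(p, d) for p, budgets in GRID.values() for d in budgets)
print(f"{n_runs} runs, total {total_hours:.1f} h, largest {largest_hours:.1f} h at peak")
```

The two largest model sizes dominate the total, so trimming their biggest token budgets is the quickest way to shrink the grid if Colab time runs short.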

Training Loss Visualization

The figure below simulates the training loss curves for both the dense and MoE models. The MoE model converges faster and reaches a lower final loss, despite a similar cost per step.

[Figure: Training Loss: Dense vs MoE. Simulated training loss curves for the 10.6M addition transformer over 50,000 steps. Dense (orange) vs MoE 8-expert top-2 (teal).]

Capstone design: You are training the 10.6M addition transformer and notice that after 20K steps, the model gets 95% accuracy on 3-digit addition but only 40% on 5-digit addition. The loss is still decreasing slowly. What is the most likely bottleneck, and what change would you make?

Chapter 16: Frontier Lab Arsenal

You have spent fifteen chapters absorbing the technical foundation of what frontier AI labs expect from their research engineers. This final chapter is your war chest: the reading list, the interview question bank, the portfolio strategy, and the cross-references to every relevant Engineermaxxing resource. Bookmark this chapter. Return to it before every interview.

The Two Paths

Frontier lab engineering roles split into two broad categories. Your path determines which skills to emphasize, but the strongest candidates have depth in both.

Dimension | Below the Stack | Above the Abstraction
Focus | Making models run fast: kernels, compilers, distributed training, hardware | Making models do useful things: agents, evals, RLHF, data, reasoning
Typical titles | ML Systems Engineer, Kernel Engineer, Training Infrastructure | Research Engineer, Applied ML, Agent Infrastructure
Core skills | CUDA/Triton/Pallas, XLA, distributed systems, profiling, quantization | Prompt engineering, RLHF/DPO, evaluation design, tool use, planning
Interview emphasis | 70% CODE + DEBUG, 30% CONCEPT + DESIGN | 40% DESIGN + FRONTIER, 40% CODE, 20% CONCEPT
Portfolio artifact | Custom kernel that beats a baseline, scaling law reproduction | Agent system with real-world evals, RLHF pipeline, benchmark contribution
Labs hiring heavily | Anthropic, Google DeepMind, xAI, Meta FAIR (Infra) | Anthropic, OpenAI, Google DeepMind, Cohere
The strongest signal. The single most impressive thing you can show a frontier lab interviewer is a project that spans both paths. The capstone from chapter 15 does exactly this: you build a model (above), optimize it with a custom kernel (below), and analyze the scaling laws (research). If you complete it and write it up publicly, you are already in the top 5% of applicants.

Building Your Public Portfolio

Frontier labs filter hundreds of applications. Your portfolio needs to signal three things in under 60 seconds of scanning:

1. You Can Build
Working code on real hardware. Not toy scripts — systems that train models, serve inference, or optimize kernels. A Colab notebook that trains and evaluates is better than a README describing what you would build.
2. You Understand Why
Scaling law analysis. Ablation studies. "I tried X, it failed because Y, so I did Z." This separates engineers who follow tutorials from engineers who understand first principles.
3. You Stay Current
You engage with recent papers (last 6 months). You have opinions about open problems. You can discuss what is broken in current approaches and what might fix it.

Concrete portfolio items, ranked by signal strength:

Rank | Artifact | Time Investment | Signal
1 | Capstone (Ch 15): addition transformer + MoE + Pallas kernel + scaling laws | 20-40 hours | Maximum: covers all dimensions
2 | Reproduce a recent paper's key result (e.g., FlashAttention-4 on a subset) | 40-80 hours | Very high: proves deep understanding
3 | Open-source contribution to JAX, vLLM, or similar project | 20-60 hours | High: proves you work with production code
4 | Blog post explaining a paper with original code and experiments | 10-20 hours | Medium-high: good communication signal
5 | Technical Twitter/blog presence discussing current research | Ongoing | Medium: shows engagement, not depth

The Paper Reading List

These are the papers you need to have read, and be prepared to discuss in detail, for a frontier lab interview. They are organized into five tracks. Within each track, papers are listed in recommended reading order, because later papers build on earlier ones.

Track 1: Quantization

The compression frontier. Each paper introduces a technique that the next paper builds on. By the end, you should be able to explain why QTIP achieves near-lossless INT2 quantization.

# | Paper | Key Contribution | Read After
1 | LLM.int8() (Dettmers et al., 2022) | Mixed-precision decomposition: outlier features in FP16, rest in INT8 | (none)
2 | GPTQ (Frantar et al., 2023) | One-shot weight quantization using inverse Hessian for error compensation | LLM.int8()
3 | QuIP (Chee et al., 2023) | Random orthogonal transform before quantization: provably optimal incoherence | GPTQ
4 | QuIP# (Tseng et al., 2024) | Lattice codebooks + incoherence processing for near-lossless 2-bit weights | QuIP
5 | QTIP (Tseng et al., 2024) | Trellis-coded quantization: information-theoretic lower bound on compression | QuIP#
6 | AQLM (Egiazarian et al., 2024) | Additive vector quantization for extreme 1-2 bit compression | QuIP#
Engineermaxxing cross-references (Quantization track):
Veanor: LLM.int8() — Full interactive breakdown of mixed-precision decomposition
Veanor: QuIP — Incoherence processing derivation with canvas demos
Veanor: QuIP# — Lattice codebooks and E8 lattice explained
Veanor: QTIP — Trellis codes and information-theoretic quantization bounds
Veanor: AQLM — Additive vector quantization deep-dive

Track 2: Attention Mechanisms

The computational engine of transformers. FlashAttention is not one paper — it is a lineage. Each version solves the previous version's limitation.

# | Paper | Key Contribution | Read After
1 | FlashAttention (Dao et al., 2022) | IO-aware exact attention: tiled computation that never materializes the full attention matrix in HBM | (none)
2 | FlashAttention-2 (Dao, 2023) | Better work partitioning across warps: 2x faster than FA1 by reducing shared memory reads | FA1
3 | FlashAttention-3 (Shah et al., 2024) | Hopper-specific: warp specialization, FP8 attention, asynchronous softmax | FA2
4 | FlashAttention-4 (2025) | KV-cache-aware decode attention, variable-length prefill batching | FA3
5 | SnapKV (Li et al., 2024) | KV cache compression: identify and retain only informative KV pairs | FA2
Engineermaxxing cross-references (Attention track):
Veanor: FlashAttention-4 — Full evolution FA1 → FA4 with hardware-specific optimizations
Veanor: SnapKV — KV cache compression strategies and observation-window analysis
Micro: Attention & Transformers — Build attention from scratch with interactive sims

Track 3: Kernel DSLs and GPU Programming

The low-level tools that let you write custom operations. Understanding the progression from Triton to ThunderKittens to CuTe is essential for systems interviews.

# | Paper/System | Key Contribution | Read After
1 | Triton (Tillet et al., 2019) | Python-based GPU kernel DSL with automatic tiling and memory management | (none)
2 | ThunderKittens (Spector et al., 2024) | Embedded C++ DSL for writing CUDA kernels at the warpgroup level | Triton
3 | CuTe (NVIDIA, introduced with CUTLASS 3.0 in 2023) | Layout algebra for describing multi-dimensional data layouts on GPU memory hierarchies | ThunderKittens
4 | Pallas (Google, 2023-2024) | JAX kernel authoring for TPU and GPU with BlockSpec-based data movement | Triton
Engineermaxxing cross-references (Kernel track):
Veanor: ThunderKittens — Warpgroup-level programming and the register tile abstraction
Micro: GPU Kernel Landscape — From CUDA to Triton to ThunderKittens: interactive comparison

Track 4: Agents and Automated Discovery

The "above the abstraction" reading list. These papers show how LLMs are used to build systems, not just generate text.

# | Paper | Key Contribution | Read After
1 | Toolformer (Schick et al., 2023) | Self-supervised tool use: model learns when and how to call APIs | (none)
2 | FunSearch (Romera-Paredes et al., 2024) | LLM-guided program search that discovers new mathematical results | Toolformer
3 | AlphaEvolve (DeepMind, 2025) | Evolution-guided code generation with automated verification | FunSearch
4 | SWE-bench (Jimenez et al., 2024) | Real-world software engineering benchmark for LLM agents | FunSearch
Engineermaxxing cross-references (Agent track):
Veanor: AlphaEvolve — Evolution-guided search with LLM-generated mutations
Micro: FunSearch — From LLM-driven search to new math discoveries
Micro: FunSearch Applied — Build your own program search system
Veanor: FunSearch for Competition — 8 years of Hash Code problems solved
Veanor: Barbarians at the Gate — AI-driven systems research with OpenEvolve

Track 5: Scaling Laws

The theoretical foundation that drives every training decision at frontier labs. Without understanding scaling laws, you cannot reason about resource allocation.

# | Paper | Key Contribution | Read After
1 | Scaling Laws for Neural LMs (Kaplan et al., 2020) | Power-law relationships between compute, data, model size, and loss | (none)
2 | Chinchilla (Hoffmann et al., 2022) | Corrected scaling: tokens matter as much as parameters. Optimal ratio ~20:1. | Kaplan
3 | The Scaling Book, "How to Scale Your Model" (Austin et al., Google DeepMind, 2025) | Comprehensive roofline analysis, transformer math, parallelism, and inference on TPU/GPU | Chinchilla
4 | Distillation Scaling Laws (Busbridge et al., Apple, ICML 2025) | Scaling laws for distillation: the capacity gap phenomenon where stronger teachers hurt | Chinchilla
Engineermaxxing cross-references (Scaling & Training):
Parminces: The Scaling Book — All 12 chapters with interactive exercises
Micro: Scaling Book Workbook — 54 fill-in-the-blank exercises from the scaling book
Micro: Thinking in JAX — JAX fundamentals: jit, grad, vmap, pmap
Veanor: Distillation Scaling Laws — Capacity gap and compute-optimal distillation
Micro: Distributed Training — Data, tensor, pipeline, and expert parallelism from scratch
Micro: LLM Inference — KV cache, batching strategies, and serving optimization
Micro: Transformer — Build a transformer from absolute zero

Interview Question Bank

Five questions per dimension, calibrated to staff-level expectations. For each question, we note the key insight the interviewer is looking for.

CONCEPT Questions (First Principles)

# | Question | Key Insight They Want
1 | "Derive the memory bandwidth needed to serve a 70B parameter model at 100 tokens/sec in FP16." | 70B × 2 bytes = 140 GB. At batch size 1, every generated token streams all weights once, so 100 tok/s needs 140 GB × 100 = 14 TB/s. A single H100 has 3.35 TB/s HBM bandwidth, so you need at least ceil(14/3.35) = 5 GPUs. Tensor parallelism overhead adds ~10%, so 6 GPUs in practice.
2 | "Why does KV cache grow linearly with sequence length and batch size? Derive the formula." | KV cache per layer = 2 × batch × seq × n_heads × d_head × dtype_bytes, where the leading 2 covers K and V and n_heads counts key/value heads. Every factor enters linearly, which is where the linear growth in batch and seq comes from. For 32 layers, 70B model (n_heads=64, d_head=128), batch=32, seq=4096, FP16: 2 × 32 × 4096 × 64 × 128 × 2 × 32 = 137.4 GB for the whole batch. This is why KV cache, not weights, is often the memory bottleneck.
3 | "Explain why the auxiliary load-balancing loss in MoE training uses f_i × p_i instead of just minimizing the variance of f_i." | f_i (token fraction) is not differentiable because it depends on argmax. p_i (mean router probability) is differentiable via softmax. Their product lets gradients flow while still penalizing imbalance.
4 | "What is the theoretical minimum memory for training a B-parameter model with Adam?" | Model: 2B bytes (FP16). Gradients: 2B. Adam states (m and v): 4B each = 8B. Total: 12B bytes minimum, where B is the parameter count. For 7B params: 84 GB. Standard mixed-precision training also keeps an FP32 master copy of the weights (4B more), for 16B bytes. This is why FSDP/ZeRO-3 shards everything across GPUs.
5 | "Derive the arithmetic intensity of self-attention for sequence length S and head dimension d." | Q@K^T: 2S^2d FLOPs, reads 2Sd elements. Softmax: O(S^2). Attn@V: 2S^2d FLOPs. Total FLOPs ≈ 4S^2d. Total traffic ≈ 4Sd + 2S^2 elements when the S×S score matrix is written to HBM and read back. Intensity = 4S^2d / (4Sd + 2S^2) ≈ 2d FLOPs per element for large S. With d=128 in FP16 (2 bytes per element) that is ≈ 128 FLOPs/byte, below the H100 ridge point of roughly 295 FLOPs/byte (989 TFLOP/s ÷ 3.35 TB/s), so unfused attention is memory-bound. A fused kernel such as FlashAttention never writes the S×S matrix: traffic drops to ≈ 4Sd elements and intensity grows to ≈ S FLOPs per element, which is compute-bound for large S.
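The arithmetic behind questions 1, 2, 4, and 5 is easy to get wrong under pressure, so it helps to have it as code you can rerun with different numbers. The helper names below are illustrative assumptions; the formulas are the ones quoted in the table, with the attention estimate converted from elements to bytes.

```python
import math

def decode_bandwidth_tb_per_s(params, bytes_per_param, tokens_per_sec):
    """Bandwidth needed when each decoded token streams all weights once (batch size 1)."""
    return params * bytes_per_param * tokens_per_sec / 1e12

def kv_cache_bytes(layers, batch, seq, kv_heads, d_head, dtype_bytes):
    """K and V for every layer: 2 * batch * seq * kv_heads * d_head * dtype_bytes per layer."""
    return 2 * batch * seq * kv_heads * d_head * dtype_bytes * layers

def adam_training_bytes(params, fp32_master_copy=False):
    """FP16 weights (2) + FP16 grads (2) + Adam m and v in FP32 (4 + 4) bytes per parameter."""
    per_param = 2 + 2 + 4 + 4 + (4 if fp32_master_copy else 0)
    return params * per_param

def attention_flops_per_byte(seq, d_head, dtype_bytes=2, fused=False):
    """Single-head attention intensity. Unfused: the S x S scores are written and re-read."""
    flops = 4 * seq ** 2 * d_head
    elements = 4 * seq * d_head + (0 if fused else 2 * seq ** 2)
    return flops / (elements * dtype_bytes)

bw = decode_bandwidth_tb_per_s(70e9, 2, 100)
print(f"Q1: {bw:.1f} TB/s -> at least {math.ceil(bw / 3.35)} H100s by bandwidth alone")
print(f"Q2: {kv_cache_bytes(32, 32, 4096, 64, 128, 2) / 1e9:.1f} GB of KV cache")
print(f"Q4: {adam_training_bytes(7e9) / 1e9:.0f} GB without an FP32 master copy, "
      f"{adam_training_bytes(7e9, fp32_master_copy=True) / 1e9:.0f} GB with one")
print(f"Q5: unfused {attention_flops_per_byte(8192, 128):.0f} FLOPs/byte, "
      f"fused {attention_flops_per_byte(8192, 128, fused=True):.0f} FLOPs/byte")
```

In an interview, state the assumptions out loud (batch size, number of key/value heads, whether the score matrix is materialized) before quoting a number, because each one changes the answer by a large factor.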

DESIGN Questions (System Architecture)

# | Question | Key Insight They Want
1 | "Design a system that trains a 100B MoE model on 1000 GPUs." | Layered parallelism: expert parallel for MoE layers, tensor parallel within each expert group, pipeline parallel across layer chunks, data parallel for throughput. The all-to-all communication for expert routing is the bottleneck, so co-locate experts that frequently share tokens.
2 | "Your serving system handles 10K requests/sec with a 7B model. The team wants to switch to a 70B MoE. What changes?" | Memory: 10x more weight storage but only 2-3x more active compute. Need multi-GPU inference with expert parallelism. Routing adds latency variance, since some requests hit "hot" experts. Solution: disaggregate prefill (compute-bound) from decode (memory-bound) on different hardware pools.
3 | "Design the evaluation pipeline for a new foundation model release." | Multi-level: unit tests (basic capabilities), benchmark suites (MMLU, HumanEval, etc.), red-teaming (adversarial safety), A/B testing (human preference), long-context specific tests, multilingual parity checks. Key: evals must run in CI, not manually.
4 | "Your training run has been running for 3 days. Loss suddenly spikes. Walk me through your debug process." | Systematic: (1) check for NaN/Inf in gradients, (2) check learning rate schedule, (3) inspect the exact batch that caused the spike, (4) check for hardware failures (GPU ECC errors), (5) check data pipeline for corrupted examples, (6) if gradient norm exploded, was the spike in a specific layer? (7) rollback to last good checkpoint, reduce lr, resume.
5 | "Design a system for continuous model improvement using user feedback." | RLHF/DPO pipeline: collect preference pairs from user thumbs up/down, train reward model (or use DPO directly), fine-tune with PPO/DPO, evaluate on held-out preferences + safety benchmarks. Key: prevent reward hacking, maintain base capability, detect distribution shift in user queries.
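For question 4, step (3) presumes you can pin down the exact step where the spike began. A minimal sketch of one way to do that from a logged loss series is shown below. The function name, window size, and threshold are assumptions for illustration; a production run would more likely alert on gradient norm as well.

```python
from statistics import mean, stdev

def find_loss_spikes(losses, window=50, k=6.0):
    """Return steps whose loss exceeds the trailing-window mean by more than k standard deviations."""
    spikes = []
    for t in range(window, len(losses)):
        recent = losses[t - window:t]
        mu, sigma = mean(recent), stdev(recent)
        if losses[t] > mu + k * max(sigma, 1e-8):  # floor sigma so a flat window cannot divide the scale to zero
            spikes.append(t)
    return spikes

# Synthetic loss curve: slow decay plus small deterministic jitter, with one injected spike.
demo_losses = [2.0 - 0.001 * t + 0.01 * (((t * 37) % 11) - 5) / 5 for t in range(300)]
demo_losses[200] += 1.0

print(find_loss_spikes(demo_losses))
```

Once you have the step index, you can replay that step's batch from a deterministic data loader and check it for corrupted or degenerate examples.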

CODE Questions (Implementation)

# | Question | Key Insight They Want
1 | "Write the top-k router for an MoE layer in JAX. Include the load-balancing loss." | See chapter 13 implementation. Key details: softmax only over selected experts, aux loss uses f_i × p_i product, handle the non-differentiability of top-k selection.
2 | "Implement a Pallas kernel for batched vector add on TPU." | See chapter 14 vector_add example. Show BlockSpec, grid, index_map. Explain the compilation pipeline.
3 | "Write a KV cache implementation that supports dynamic batch sizes." | Pre-allocate max capacity, track per-sequence length, use scatter/gather for non-contiguous sequences. PagedAttention: page table + block allocator.
4 | "Implement gradient checkpointing for a transformer layer." | jax.checkpoint (or torch.utils.checkpoint): recompute activations during backward instead of storing. Trade 33% more compute for ~60% less memory. Choose which layers to checkpoint based on activation size.
5 | "Write the data loading pipeline for a multi-host TPU training run." | tf.data or grain: shard dataset across hosts, prefetch to device, handle host-device async, deterministic shuffling (seeded per epoch for reproducibility), handle restarts from checkpoint.
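Before writing the router in JAX, it helps to have the logic straight in plain Python. The sketch below is not the chapter 13 implementation: it is a framework-free illustration of top-k selection, gate renormalization over the selected experts, and an auxiliary loss of the form n_experts × Σ f_i × p_i (the loss coefficient is omitted). All names are illustrative assumptions.

```python
import math

def softmax(xs):
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    z = sum(exps)
    return [e / z for e in exps]

def route_top_k(router_logits, k):
    """router_logits: one list of per-expert logits for each token.

    Returns (assignments, f, p, aux_loss), where assignments[t] is a list of
    (expert, gate_weight) pairs for token t.
    """
    n_tokens = len(router_logits)
    n_experts = len(router_logits[0])
    counts = [0] * n_experts       # dispatch counts: come from a hard top-k, so not differentiable
    prob_sums = [0.0] * n_experts  # router probabilities: differentiable in an autodiff framework
    assignments = []
    for logits in router_logits:
        probs = softmax(logits)
        top = sorted(range(n_experts), key=lambda e: probs[e], reverse=True)[:k]
        # Renormalizing over the selected experts equals a softmax over their logits only.
        norm = sum(probs[e] for e in top)
        assignments.append([(e, probs[e] / norm) for e in top])
        for e in top:
            counts[e] += 1
        for e in range(n_experts):
            prob_sums[e] += probs[e]
    f = [c / (n_tokens * k) for c in counts]  # fraction of dispatch slots per expert
    p = [s / n_tokens for s in prob_sums]     # mean router probability per expert
    aux_loss = n_experts * sum(fi * pi for fi, pi in zip(f, p))
    return assignments, f, p, aux_loss

# Balanced routing (token t prefers expert t % 4) versus collapse onto expert 0.
balanced = [[10.0 if e == t % 4 else 0.0 for e in range(4)] for t in range(8)]
collapsed = [[10.0, 0.0, 0.0, 0.0] for _ in range(8)]
print("balanced aux loss :", route_top_k(balanced, k=1)[3])
print("collapsed aux loss:", route_top_k(collapsed, k=1)[3])
```

The loss sits near 1 when load is even and approaches the number of experts under collapse. In an autodiff framework only p carries gradient; f acts as a constant weight, which is the point of using the product.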

DEBUG Questions (Failure Diagnosis)

# | Question | Key Insight They Want
1 | "Your MoE model's loss is 10% worse than expected. All experts have equal utilization. What's wrong?" | Equal utilization is suspicious: it might mean the router has collapsed to uniform random routing (all logits near zero). Check router entropy: it should be high but not maximum. If the router is not specializing, the MoE is just an ensemble of random experts, losing the specialization benefit.
2 | "A Pallas kernel produces correct results on small inputs but NaN on large inputs." | Likely a tile boundary issue: when input size is not divisible by block size, the last tile has uninitialized memory or goes out of bounds. Check padding, check boundary masking, verify grid dimension calculation.
3 | "Training throughput dropped 30% after a JAX version upgrade." | Check XLA compilation: new version may choose different fusion patterns. Run with XLA_FLAGS=--xla_dump_to=<dir> under both versions to compare HLO before/after. Check for new op lowerings that bypass custom patterns. Profile with TensorBoard to find the slow op.
4 | "Your distributed training run hangs intermittently after 2-3 hours." | Classic symptoms of a slow or failing GPU. Check: (1) NCCL/TPU runtime logs for timeout errors, (2) individual GPU compute times (one outlier = hardware issue), (3) memory pressure causing OOM on one host, (4) network congestion in all-reduce (check for competing jobs).
5 | "Your quantized INT8 model is 2% worse on English benchmarks but 15% worse on code generation." | Code tokens have sharper distributions (exact syntax matters). Quantization smooths the distribution, rounding "near-correct" logits to incorrect tokens. Fix: keep the code-heavy layers (likely the last few layers and the LM head) in FP16. Run per-layer sensitivity analysis specifically on code benchmarks.
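The tile-boundary failure in question 2 usually comes down to two lines: how the grid size is computed and what the tail tile contains. The sketch below uses plain Python lists as a stand-in for a kernel; it is not Pallas code, and the function names are illustrative. It shows ceiling division for the grid and padding the tail with the identity element of the reduction.

```python
def grid_size(n, block):
    """Number of tiles needed to cover n elements (ceiling division)."""
    return (n + block - 1) // block

def pad_to_block(values, block, fill=0.0):
    """Pad so the length is a multiple of block; fill must be neutral for the reduction."""
    padded_len = grid_size(len(values), block) * block
    return values + [fill] * (padded_len - len(values))

def blocked_sum(values, block):
    """Tile-by-tile sum over padded input, mimicking one program instance per tile."""
    padded = pad_to_block(values, block, fill=0.0)
    total = 0.0
    for tile in range(grid_size(len(values), block)):
        start = tile * block
        total += sum(padded[start:start + block])
    return total

n, block = 1000, 128
print("floor division tiles:", n // block, "(drops the last", n % block, "elements)")
print("ceiling division tiles:", grid_size(n, block))
print("blocked sum:", blocked_sum([1.0] * n, block))
```

Python slicing silently truncates at the end of a list, which hides the bug a real kernel would hit. On hardware, the tail tile reads whatever is in memory unless you pad or mask, and for a max or softmax reduction the neutral fill value is negative infinity rather than zero.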

FRONTIER Questions (Research Awareness)

# | Question | Key Insight They Want
1 | "What are the open problems in MoE scaling?" | Expert collapse at scale (>1000 experts), optimal granularity (few large vs many small), routing for multi-modal inputs, expert merging for inference efficiency, MoE-specific distillation.
2 | "Where is quantization headed after INT4/INT2?" | Microscaling (MX) formats from the OCP consortium: block-structured formats with per-block exponents. FP4 E2M1 on Blackwell GPUs. The frontier is 1-2 bit quantization with lookup-table decompression. Information theory says the limit is ~1.5 bits for weights with current model architectures.
3 | "What is the biggest bottleneck in LLM inference today?" | Memory bandwidth for autoregressive decode. Each generated token requires loading all model weights. For a 70B FP16 model, that is 140 GB per token. At 3.35 TB/s (H100), theoretical max is ~24 tokens/sec. Speculative decoding, compression, and batch scheduling are the active research frontiers.
4 | "How would you design an LLM that can reliably use tools?" | The key challenge is not calling tools (any prompted model can emit a function call) but knowing when to call them and what to do with failures. Research frontiers: ReAct-style reasoning traces, function-calling fine-tuning with execution feedback, tool-use RLHF where the reward includes task completion not just format correctness.
5 | "What changed in the last 6 months that you're most excited about?" | This is a culture-fit question. They want genuine enthusiasm about a specific development, not a list. Pick one thing you actually care about and explain why it matters technically. Good answers name a specific paper or system and articulate its implications.
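To make "block-structured formats with per-block exponents" from question 2 concrete, here is a simplified sketch of MX-style FP4 quantization for one block. It assumes a shared power-of-two scale per block and E2M1 element values, with the shared exponent chosen as floor(log2(max|x|)) minus the largest E2M1 exponent. Tie-breaking and special values are simplified relative to the OCP specification, and the function names are illustrative.

```python
import math

# Magnitudes representable by an FP4 E2M1 element (sign is stored separately).
E2M1_MAGNITUDES = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]
E2M1_EMAX = 2  # largest power-of-two exponent among E2M1 values (4.0 = 2**2)

def mx_quantize_block(block):
    """Return (shared_exponent, elements): one power-of-two scale plus E2M1 values."""
    max_abs = max(abs(x) for x in block)
    if max_abs == 0.0:
        return 0, [0.0] * len(block)
    shared_exp = math.floor(math.log2(max_abs)) - E2M1_EMAX
    scale = 2.0 ** shared_exp
    elements = []
    for x in block:
        mag = min(abs(x) / scale, E2M1_MAGNITUDES[-1])  # saturate at the largest magnitude
        nearest = min(E2M1_MAGNITUDES, key=lambda m: abs(m - mag))  # simplified rounding
        elements.append(math.copysign(nearest, x))
    return shared_exp, elements

def mx_dequantize_block(shared_exp, elements):
    scale = 2.0 ** shared_exp
    return [e * scale for e in elements]

block = [12.0, 6.0, -3.0, 1.0, 0.7, -10.0]
exp, elems = mx_quantize_block(block)
print("shared exponent:", exp)
print("elements       :", elems)
print("dequantized    :", mx_dequantize_block(exp, elems))
```

The MX specification uses 32-element blocks; the sketch accepts any length so the example stays readable. Note how a single large value sets the scale for the whole block, which is why outliers within a block cost precision for their neighbors.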

Organizing Your Study: A Concept Dependency Graph

The papers and concepts in this chapter are not independent. They form a dependency graph — you cannot understand QTIP without understanding QuIP, and you cannot understand QuIP without understanding quantization fundamentals. Here is the critical path through the material.

Foundation Layer (Week 1)
Transformer mechanics, Attention, Training dynamics
Can you implement a transformer from scratch? Can you derive attention? Do you understand Adam, warmup, cosine decay?
If not: start with Micro lessons Transformer and Attention.
Systems Layer (Week 2)
GPU memory hierarchy, CUDA/Triton basics, Distributed training
Can you explain why matmul is compute-bound but attention is memory-bound? Can you calculate GPU utilization?
If not: start with Micro lesson GPU Kernel Landscape.
Specialization Layer (Weeks 3-4)
Quantization track OR Kernel DSL track OR Agent track
Pick one track based on your target role. Go deep: read every paper in order. Implement the key algorithms.
The capstone integrates all tracks.

The Anti-Patterns: What NOT to Do

Common preparation mistakes that waste time or actively hurt your candidacy:

Anti-Pattern | Why It Fails | What to Do Instead
Grinding LeetCode | Frontier lab interviews are not about algorithms and data structures. They test ML systems knowledge. A perfect score on LeetCode does not help you derive the Chinchilla scaling law. | Spend that time on the capstone project or paper reading. If you must practice coding, implement ML algorithms (attention, quantization, MoE routing) not two-sum.
Reading 50 papers shallowly | Surface-level knowledge of many papers is less useful than deep knowledge of 5. Interviewers probe depth, not breadth. | Read 5-7 papers deeply. For each: implement the key algorithm, reproduce one figure, write a 3-paragraph critique.
Memorizing formulas | If you memorize the attention formula but cannot derive it from the query-key-value interpretation, you will fail the follow-up question. | Derive every formula from first principles. If you cannot derive it, you do not understand it.
Only studying theory | "I understand the concept of quantization" means nothing if you have never quantized a model and measured the accuracy drop. | Every concept must have a corresponding implementation. Theory + code + measurement = understanding.
Ignoring the lab's specific work | Discussing FlashAttention at a lab that uses TPUs shows you did not research the company. | Read the lab's recent publications. Know their stack (JAX vs PyTorch, TPU vs GPU). Reference their work in your answers.

The Interview Process: What to Expect

Frontier lab interview loops are intense. They typically span 4-6 hours across 1-2 days, with each session testing a different dimension. Here is the typical structure:

Round | Duration | Format | What They Test
Phone Screen | 45-60 min | Video call with one interviewer | Basic competence: can you code, do you know ML fundamentals, are you articulate?
Technical Deep-Dive | 60-90 min | Whiteboard or shared doc with 1-2 senior engineers | One topic in depth: derivations, edge cases, failure modes. They want to see how deep your understanding goes.
System Design | 60 min | Whiteboard with a senior/staff engineer | Design a training pipeline, inference system, or evaluation framework. They want to see you make tradeoffs and justify decisions.
Coding | 60-90 min | Live coding in an IDE or Colab | Implement a non-trivial algorithm: attention mechanism, MoE router, quantization routine, or Pallas kernel. Clean code matters as much as correctness.
Research Discussion | 45-60 min | Conversation with a research lead | Discuss recent papers, open problems, your own research interests. They want intellectual curiosity and the ability to identify important problems.
Culture/Values | 30-45 min | Conversation with a team lead or manager | Collaboration style, conflict resolution, what motivates you. They want to know if you will thrive on the team.
The hidden evaluation axis. Beyond the five explicit dimensions (CONCEPT, DESIGN, CODE, DEBUG, FRONTIER), interviewers are assessing a sixth: calibration. Do you know what you know? Do you admit what you do not know? Can you distinguish "I am confident because I have derived this" from "I think this is right but have not verified"? Frontier labs deal with models that are confidently wrong. They do not want engineers who are confidently wrong too. Saying "I don't know, but here is how I would figure it out" is a stronger answer than a plausible-sounding guess.

Common Interview Mistakes (and How to Avoid Them)

Mistake | Example | Better Approach
Answering too broadly | "How would you speed up inference?" answered with "Well, there's quantization, pruning, distillation, caching, batching, compilation..." | Pick the most impactful one for the specific scenario. "For a 70B model serving chat at 100 req/s, the bottleneck is memory bandwidth during decode. I'd start with INT8 quantization for a 2x speedup, then implement continuous batching for utilization."
Not asking clarifying questions | Jumping into a system design without understanding constraints | Spend the first 5 minutes asking: What hardware? What latency budget? What accuracy tolerance? What traffic pattern? These determine the entire design.
Memorized definitions without understanding | "FlashAttention uses tiling to reduce memory" (correct but shallow) | "FlashAttention exploits the fact that attention is a streaming reduction: you never need the full S×S matrix in memory. By tiling along the sequence dimension and keeping running softmax statistics, each tile only needs O(BLOCK_SIZE^2) on-chip memory for its scores instead of O(S^2) in HBM. The key insight is that softmax can be decomposed into local computations with a correction factor."
Ignoring failure modes | Presenting a system design without discussing what breaks | For every component, mention: "This fails when X. The mitigation is Y. If Y also fails, we gracefully degrade to Z."
Not writing tests | Submitting code without verification | After implementing, immediately write a test case. "Let me verify with a simple example: 12 + 34 should be 46..." Running the test live shows engineering discipline.
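The "local computations with a correction factor" claim in the FlashAttention answer is worth being able to demonstrate. The sketch below computes a softmax-weighted sum two ways: all at once, and tile by tile with a running max and running sums that are rescaled whenever the max changes. It handles a single query with scalar values to keep the idea visible; the function names and block size are illustrative assumptions.

```python
import math

def softmax_weighted_sum(scores, values):
    """Reference: softmax over all scores at once, then a weighted sum of the values."""
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]
    return sum(w * v for w, v in zip(weights, values)) / sum(weights)

def online_softmax_weighted_sum(scores, values, block=4):
    """Same result, visiting scores one tile at a time with constant-size running state."""
    running_max = float("-inf")
    running_sum = 0.0  # sum of exp(score - running_max) over tiles seen so far
    running_out = 0.0  # sum of exp(score - running_max) * value over tiles seen so far
    for start in range(0, len(scores), block):
        s_tile = scores[start:start + block]
        v_tile = values[start:start + block]
        new_max = max(running_max, max(s_tile))
        # Correction factor: re-express the old accumulators relative to the new max.
        correction = math.exp(running_max - new_max)
        running_sum *= correction
        running_out *= correction
        for s, v in zip(s_tile, v_tile):
            w = math.exp(s - new_max)
            running_sum += w
            running_out += w * v
        running_max = new_max
    return running_out / running_sum

demo_scores = [float((i * 7) % 13 - 6) for i in range(37)]
demo_values = [float(i % 5) for i in range(37)]
print(softmax_weighted_sum(demo_scores, demo_values))
print(online_softmax_weighted_sum(demo_scores, demo_values, block=4))
```

And, in the spirit of the last row of the table, the check to run live is that both functions agree to floating-point tolerance for several block sizes, including one that does not divide the sequence length.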

Answering "What is the Biggest Open Problem?"

This question tests research taste. The interviewer wants to know: can you identify problems that matter? Here are five strong answers with the reasoning behind each.

Open Problem | Why It Matters | Current Best Approach | Why It Is Not Solved
Reliable reasoning | LLMs can write poetry but struggle with multi-step logic. This limits their use in math, code verification, and scientific discovery. | Chain-of-thought prompting, verification-guided search (AlphaProof), process reward models | We do not understand why in-context reasoning works or when it fails. No theory predicts which problems a model can solve by reasoning vs which require more training data.
Efficient long-context | Real-world tasks (codebase understanding, document analysis, multi-turn conversations) require contexts of 100K-1M+ tokens. Current models degrade badly beyond 32K. | RoPE extension, ring attention, sliding window + global tokens | KV cache memory scales linearly with context. Attention quality degrades with distance. No architecture efficiently captures both local and global dependencies at million-token scale.
Data efficiency | Chinchilla says we need 20 tokens per parameter. A 1T parameter model needs 20T tokens. We are running out of high-quality text data. | Synthetic data generation, data deduplication, curriculum learning | Synthetic data risks "model collapse" (self-reinforcing errors). There is no theoretical framework for predicting data quality impact on downstream capabilities.
Interpretability at scale | We deploy models we do not understand. This is unacceptable for safety-critical applications (medicine, law, autonomous vehicles). | Sparse autoencoders, probing classifiers, circuit analysis | Current methods work on small models but do not scale to 100B+ parameters. We have no way to verify that an interpretation is complete (covers all model behavior).
Compute efficiency cliff | Transformer compute scales quadratically with sequence length and linearly with parameters. We are approaching the limits of what current hardware can sustain. | MoE (chapter 13), linear attention, state-space models, speculative decoding | Every efficiency technique has tradeoffs (MoE = memory, SSM = reduced in-context learning). No single architecture is efficient across all workloads.
How to present an open problem in an interview. Do not just name the problem. Show: (1) why it matters to this lab specifically, (2) what has been tried, (3) why existing solutions are insufficient, and (4) what you think a promising direction might be. The fourth point is the hardest and most valuable — it shows you can think beyond existing literature.

The Scaling Book Exercise Checklist

The Google DeepMind "Scaling Book" is the most comprehensive reference on compute-optimal training. Here is a structured exercise checklist to work through it:

Exercise | What You Build | What You Learn
1. Reproduce Figure 1 | Train 5 model sizes (1M to 100M), fit power law to loss vs compute | How to measure scaling coefficients empirically
2. Chinchilla grid search | For a fixed compute budget, sweep (N, D) pairs and plot the loss curve | The Chinchilla frontier is a valley, not a cliff
3. MoE scaling analysis | Repeat exercise 2 with MoE models, varying expert count and top-k | MoE shifts the frontier: more params at same FLOP budget
4. Downstream prediction | Fit scaling law on pretraining loss, predict downstream task accuracy | The relationship between loss and task performance is noisy but predictable
5. Token quality impact | Train on curated vs web data at same token count, compare scaling curves | Data quality is a multiplier on effective compute
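Exercise 2 can be previewed on paper before you train anything. The sketch below sweeps model size at a fixed compute budget using the parametric loss reported by Hoffmann et al. (2022), L(N, D) = E + A / N^α + B / D^β, with tokens fixed by C = 6 × N × D. The constants are the published fit for their setup and will not match your toy task; the budget, sweep range, and function names are assumptions for illustration.

```python
# Published parametric fit from Hoffmann et al. (2022); not tuned to any toy task.
E, A, ALPHA, B, BETA = 1.69, 406.4, 0.34, 410.7, 0.28

def chinchilla_loss(n_params, n_tokens):
    return E + A / n_params ** ALPHA + B / n_tokens ** BETA

def sweep_isoflop(compute, n_points=41, lo=1e8, hi=1e11):
    """Sweep model size log-uniformly; the token count follows from C = 6 * N * D."""
    rows = []
    for i in range(n_points):
        n = lo * (hi / lo) ** (i / (n_points - 1))
        d = compute / (6 * n)
        rows.append((n, d, chinchilla_loss(n, d)))
    return rows

rows = sweep_isoflop(1e21)
best = min(rows, key=lambda r: r[2])
print(f"best N ~ {best[0]:.2e}, D ~ {best[1]:.2e}, loss ~ {best[2]:.4f}")
for n, d, loss in rows[::8]:
    print(f"N={n:.2e}  D={d:.2e}  loss={loss:.4f}")
```

Print the rows around the minimum and you will see the loss change only slightly when N is off by a factor of two in either direction, which is the "valley, not a cliff" observation. The tokens-per-parameter ratio at the minimum depends on which of the paper's fits you use, so treat the 20:1 figure as a rule of thumb.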

Answering "Walk Me Through Your Project"

The project discussion is often the highest-signal part of the interview. Your capstone is the perfect project to discuss. Here is how to structure your walkthrough for maximum impact.

Opening (30 seconds)
"I trained a 10-million parameter transformer to learn addition, then compared a dense architecture against a Mixture-of-Experts variant with 8 experts and top-2 routing. I wrote a custom Pallas kernel for the MoE forward pass that beats JAX's built-in ragged_dot by 1.4x on TPU v2. The whole thing runs on a free Colab TPU in 20 minutes."
Technical Depth (2-3 minutes)
Discuss ONE interesting finding in detail. Example: "The most surprising result was expert specialization. Expert 2 handled 78% of carry-propagation tokens, while expert 5 specialized in final-digit generation. I discovered this by analyzing routing decisions conditioned on whether the output digit required a carry. This emergent specialization — without any explicit supervision — is exactly the mechanism that makes DeepSeek-V3's fine-grained experts work."
What Was Hard (1 minute)
"The hardest part was not the implementation — it was getting the Pallas kernel to handle the variable-length expert batches correctly. My first version had a boundary bug where the last tile of each expert's batch read uninitialized memory, producing NaN. I fixed it by explicit padding to block-size multiples, but only after two hours of debugging by bisecting block sizes."
What You Would Change (30 seconds)
"If I were doing this again, I would add Expert Choice routing and compare it against top-k. I would also try reversed digit order, which should make carry propagation easier. And I would scale to 100M parameters to see if the scaling law fit improves with more data points."
The "what was hard" moment. This is where you separate yourself from candidates who follow tutorials. Interviewers do not just want to know you built something — they want to know you struggled with something, diagnosed the root cause, and fixed it. The NaN-from-uninitialized-memory story is exactly the kind of detail that signals real implementation experience. Prepare 2-3 such stories from your capstone work.

Complete Engineermaxxing Cross-Reference Map

Every Engineermaxxing resource relevant to frontier lab preparation, organized by topic.

Topic | Micro Lesson | Veanor Paper
Transformers | Transformer, Attention | (none)
Training | Distributed Training | (none)
Inference | LLM Inference | SnapKV
Quantization | (none) | LLM.int8(), QuIP, QuIP#, QTIP, AQLM
Attention | Attention | FlashAttention-4
GPU Kernels | GPU Kernel Landscape | ThunderKittens
JAX | Thinking in JAX | (none)
Agents | FunSearch | AlphaEvolve

Lab-Specific Preparation

Different labs emphasize different skills. Tailor your preparation to your target.

Lab | Primary Stack | Interview Emphasis | Preparation Focus
Anthropic | JAX / TPU / Custom infra | Systems + safety. Deep technical + values alignment. | Pallas kernels, scaling laws, RLHF/Constitutional AI, interpretability basics. Read Anthropic's published research.
Google DeepMind | JAX / TPU / Gemini | Research engineering. Strong ML theory + implementation. | Scaling laws, MoE (Gemini uses MoE), JAX internals, distributed training. Be ready for math-heavy derivations.
OpenAI | PyTorch / NVIDIA GPUs / Azure | Impact + scale. Ship fast, iterate faster. | CUDA/Triton, distributed training, RLHF, evaluation design. Strong coding is essential because they test implementation speed.
Meta FAIR | PyTorch / NVIDIA GPUs / On-prem | Open research. Publishing + artifacts. | Core ML research skills, paper reproduction, open-source contributions. They value academic publication background.
xAI | JAX / TPU + GPU / Custom | Full-stack speed. Build fast, break things, fix them. | Broad systems knowledge, comfort with ambiguity, willingness to work on whatever is needed. Startup mentality.
Read their papers. Before interviewing at any lab, read their last 5 published papers. Not for memorization — for understanding what problems they care about. When the interviewer asks "what are you excited about?", you can reference their own work and suggest extensions. This demonstrates genuine interest, not just job-seeking.

The 30-Day Countdown

If you have an interview in 30 days, here is the optimal allocation of your preparation time. This assumes 2-3 hours per day of focused preparation.

Days 1-5: Foundation Audit
Take the quiz at the end of each chapter in this lesson. Any chapter where you cannot answer confidently: re-read it. Build a list of your weak spots. Priority: chapters 1-4 (CUDA, quantization, attention) are tested most frequently.
Days 6-12: Paper Deep-Dives
Read 1 paper per day from the reading list. For each paper, write a 1-paragraph summary and implement 1 key equation in code. Ship to a public GitHub repo. By day 12, you have 7 paper summaries with code.
Days 13-20: Capstone Project
Complete the chapter 15 capstone. Train the addition transformer (dense + MoE), write the Pallas kernel, run the scaling law analysis. Document everything in a public repo with a clear README.
Days 21-25: System Design Practice
Pick 3 system design questions from the bank above. For each, write a 2-page design doc (diagrams, tradeoffs, failure modes). Time yourself: 45 minutes per design, then 15 minutes to review and improve.
Days 26-28: Mock Interviews
Run 3 mock interviews with different partners: one focused on coding (implement an MoE router live), one on system design (whiteboard a training pipeline), and one on research discussion (present your capstone). Get feedback. Iterate.
Days 29-30: Final Review
Re-read your paper summaries. Re-run your capstone notebook. Review the question bank. Get a good night's sleep. You are ready.

What Separates a Hire from a No-Hire

Across hundreds of interviews, these are the differentiators that frontier lab hiring committees consistently cite:

Signal | No-Hire Pattern | Hire Pattern
Depth | Knows what FlashAttention is. Cannot explain why online softmax is needed. | Can derive the online softmax correction factor from scratch. Knows when it breaks (very long sequences, numerical stability).
Taste | "I would try all the standard techniques and see what works." | "Given the constraints, I would start with X because Y. If that fails, the fallback is Z, which trades off A for B."
Ownership | "I worked on the training pipeline." (passive) | "I designed the calibration system. It reduced quantization regressions from 15% to 2%. Here is the specific design decision that mattered most."
Curiosity | Recites paper results. | Critiques paper methodology. "They claim 2x speedup, but their baseline is suboptimal. With a tuned Triton kernel as baseline, the improvement is closer to 1.3x."
Calibration | Confidently guesses when uncertain. | "I believe the answer is X based on Y, but I am not 100% sure about the constant factor. Let me derive it to check."
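
The Depth row asks for the online softmax correction factor. As a minimal sketch in plain Python (the function names and the block size are illustrative assumptions, not FlashAttention's actual kernel code), the running sum is rescaled by exp(m_old - m_new) whenever a new block raises the running maximum. That lets the normalizer be computed in one blocked pass without ever exponentiating a large raw score:

```python
import math


def online_softmax_stats(scores, block_size=2):
    """One blocked pass over finite scores; returns (running max, running sum)."""
    m = float("-inf")  # running maximum seen so far
    l = 0.0            # running sum of exp(score - m)
    for start in range(0, len(scores), block_size):
        block = scores[start:start + block_size]
        m_new = max(m, max(block))
        # Correction factor: re-express the old sum relative to the new maximum.
        correction = math.exp(m - m_new)
        l = l * correction + sum(math.exp(s - m_new) for s in block)
        m = m_new
    return m, l


def online_softmax(scores, block_size=2):
    """Softmax using the statistics gathered by the blocked pass."""
    m, l = online_softmax_stats(scores, block_size)
    return [math.exp(s - m) / l for s in scores]


def reference_softmax(scores):
    """Standard two-pass softmax (max first, then sum) for comparison."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


# Scores near 1000 would overflow a naive math.exp(score).
example = [1.0, 3.0, 2.0, 1000.0, 999.0]
blocked = online_softmax(example)
two_pass = reference_softmax(example)
max_diff = max(abs(a - b) for a, b in zip(blocked, two_pass))
```

With finite inputs the blocked result agrees with the two-pass reference up to rounding error. The sketch also hints at one breakage case: if a block is entirely -inf while the running maximum is still -inf, the correction becomes exp(-inf - (-inf)), which is NaN.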

Technical Writing for Frontier Labs

Frontier labs value written communication as much as coding. Design docs, post-mortems, experiment write-ups, and internal papers are how technical decisions get made. Here is how to write effectively in a lab context.

Document Type | Length | Audience | Structure
Design Doc | 3-8 pages | Tech lead + peer engineers | Problem statement, proposed solution, alternatives considered, risks, implementation plan, success metrics
Experiment Report | 1-3 pages | Immediate team | Hypothesis, setup, results (tables + plots), interpretation, next steps. Keep it factual; save opinions for the interpretation section.
Post-Mortem | 2-4 pages | Full engineering org | Timeline of the incident, root cause analysis, contributing factors, what went well, action items with owners and deadlines
Technical Blog Post | 5-15 pages | External (research community) | Motivation, approach, results, ablations, limitations, future work. Must be reproducible from the description alone.
The design doc test. Write a 1-page design doc for your capstone project. It should answer: (1) What problem does this solve? (2) Why this approach and not alternatives? (3) What could go wrong? (4) How will we know it worked? If you can write this clearly, you can write any technical document a frontier lab will ask for. Practice this before your interview — some labs include a "design doc review" round.

The Mental Model Library

Staff-level engineers carry a library of mental models that they apply across different problems. Here are the ones that appear most frequently in frontier lab work:

Mental Model | Core Insight | Applied To
Roofline Analysis | Every operation is either compute-bound or memory-bound. Know which one before optimizing. | Kernel optimization, hardware selection, batching strategy
Amdahl's Law | The overall gain from speeding up one part of a system is capped by the fraction of time spent in that part. If matmul is 40% of step time, a 2x matmul speedup saves 20% of step time (1.25x overall), and even an infinitely fast matmul saves at most 40%. | Profiling, optimization prioritization
The Bitter Lesson | Methods that scale with compute beat methods that encode human knowledge. (Rich Sutton, 2019) | Architecture decisions, research direction
Goodhart's Law | When a measure becomes a target, it ceases to be a good measure. Applies directly to reward model training in RLHF. | Evaluation design, reward hacking prevention
The Swiss Cheese Model | Safety comes from multiple independent layers, each with holes. No single layer is perfect, but stacking them reduces risk exponentially. | AI safety, deployment guardrails, testing strategy
Diminishing Returns | The first optimization gives 2x. The next gives 1.3x. The next gives 1.1x. Know when to stop optimizing and start shipping. | Performance engineering, hyperparameter tuning
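
Two of the models in this table reduce to short formulas, so they are easy to sanity-check in code. The sketch below is plain Python with hypothetical helper names. The peak throughput and bandwidth are round illustrative numbers chosen for the example, not the specification of any particular chip, and the byte counts assume 2-byte elements with each operand read or written exactly once.

```python
def amdahl_speedup(fraction, local_speedup):
    """Overall speedup when `fraction` of the runtime gets `local_speedup` times faster."""
    return 1.0 / ((1.0 - fraction) + fraction / local_speedup)


def time_saved(fraction, local_speedup):
    """Share of the original runtime that disappears after the speedup."""
    return fraction * (1.0 - 1.0 / local_speedup)


def roofline_bound(flops, bytes_moved, peak_flops, peak_bandwidth):
    """Classify an op under a simple roofline and return its attainable FLOP/s."""
    intensity = flops / bytes_moved       # arithmetic intensity, FLOPs per byte
    ridge = peak_flops / peak_bandwidth   # intensity where the two roofs meet
    attainable = min(peak_flops, intensity * peak_bandwidth)
    kind = "compute-bound" if intensity >= ridge else "memory-bound"
    return kind, attainable


# Illustrative accelerator (assumed round numbers, not a real datasheet).
PEAK_FLOPS = 1e15      # FLOP/s
PEAK_BANDWIDTH = 3e12  # bytes/s

n = 4096
bytes_per_element = 2
# Square matmul: 2*n^3 FLOPs, three n x n matrices moved once each.
matmul_kind, _ = roofline_bound(
    2 * n**3, 3 * n * n * bytes_per_element, PEAK_FLOPS, PEAK_BANDWIDTH
)
# Elementwise add: 1 FLOP per element, two reads and one write per element.
add_kind, _ = roofline_bound(
    n * n, 3 * n * n * bytes_per_element, PEAK_FLOPS, PEAK_BANDWIDTH
)
print(matmul_kind, add_kind)  # compute-bound memory-bound

# The table's example: matmul is 40% of step time and gets 2x faster.
overall = amdahl_speedup(0.4, 2.0)
saved = time_saved(0.4, 2.0)
```

Under these assumptions the large matmul sits well above the ridge point while the elementwise add sits far below it, which is why fusing elementwise ops into neighbouring kernels tends to pay off more than tuning their arithmetic.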

The Daily Practice Routine

Interview preparation is a skill, and skills require practice. Here is a 4-week routine that builds all five dimensions:

Week | Focus | Daily Practice (2 hours)
1 | Concept + Code | Morning: read one paper from the list, take handwritten notes. Evening: implement the key algorithm in JAX. Ship to GitHub.
2 | Design + Debug | Morning: pick a system design question, write a 1-page design doc (whiteboard format). Evening: reproduce a training failure and debug it (introduce NaN intentionally, fix it).
3 | Capstone | Full days on the capstone project from chapter 15. Train, measure, analyze, document.
4 | Frontier + Mock | Morning: read one paper from last 3 months, write a 3-paragraph summary. Evening: mock interview with a friend or self-recorded. Review and iterate.
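
Week 2 suggests introducing a NaN on purpose and then fixing it. One small, self-contained way to practice, sketched here in plain Python with hypothetical helper names, is a softmax over an attention row in which every position is masked to -inf. Subtracting the row maximum then computes -inf - (-inf), which is NaN. Returning zeros for such a row is an assumed convention for this sketch; a real model may need different handling.

```python
import math

NEG_INF = float("-inf")


def softmax(scores):
    """Max-subtracted softmax. Produces NaN when every score is -inf."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]  # -inf - (-inf) is NaN
    total = sum(exps)
    return [e / total for e in exps]


def safe_softmax(scores):
    """Same as softmax, but a fully masked row yields zeros (assumed convention)."""
    m = max(scores)
    if m == NEG_INF:
        return [0.0] * len(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def first_nan(named_values):
    """Return the name of the first entry that contains a NaN, or None."""
    for name, values in named_values.items():
        if any(math.isnan(v) for v in values):
            return name
    return None


# Row 0 has one visible position; row 1 is fully masked (the injected bug).
rows = {
    "row0": [0.0, NEG_INF, NEG_INF],
    "row1": [NEG_INF, NEG_INF, NEG_INF],
}
broken = {name: softmax(r) for name, r in rows.items()}
fixed = {name: safe_softmax(r) for name, r in rows.items()}

culprit = first_nan(broken)      # locates the first bad row
after_fix = first_nan(fixed)     # None once the guard is in place
```

The debugging habit worth building is the `first_nan` step: find the earliest tensor that goes bad before reasoning about the cause, since NaNs propagate through everything downstream.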

Resources for Self-Study

Beyond papers and this lesson, here are the highest-value resources for frontier lab preparation, ordered by signal-to-noise ratio.

Resource | Format | Best For | Time Investment
Andrej Karpathy's "Neural Networks: Zero to Hero" | YouTube series | Building transformers from scratch. His nanoGPT implementation is a masterclass in clarity. | 20 hours
JAX documentation + tutorials | Online docs | Understanding jit, vmap, pmap, grad. The "Thinking in JAX" guide is essential. | 10 hours
Triton tutorials (OpenAI) | Online docs + notebooks | Writing GPU kernels. The matrix multiply tutorial alone is worth the time. | 15 hours
Stanford CS231n (Computer Vision) | Lectures + assignments | CNN architectures, training dynamics, practical optimization. Still one of the best courses on deep learning fundamentals. | 40 hours
The Illustrated Transformer (Jay Alammar) | Blog post | Visual intuition for attention. Good supplement to implementation-focused learning. | 2 hours
vLLM source code | GitHub repo | Production LLM serving: PagedAttention, continuous batching, speculative decoding. Read the core engine. | 20 hours
Chinchilla paper (Hoffmann et al.) | arXiv paper | The compute-optimal scaling laws paper. Its results widely inform training-budget decisions at frontier labs. | 5 hours

Building in Public: The Portfolio Playbook

The most effective portfolio is not a collection of finished projects — it is a public learning log. Here is exactly what to build, in what order, with the expected time investment.

Week 1: "I Built a Transformer"
Implement a decoder-only transformer from scratch using only core JAX (no neural-network libraries on top). Train on character-level Shakespeare or addition. Push to GitHub with a clear README showing the training curve and generated text.
Signal: You can go from theory to working code.
Time: 10-15 hours
Week 2: "I Wrote a Kernel"
Write a Triton or Pallas matmul kernel. Benchmark against PyTorch/JAX built-in. Create a plot showing TFLOPS vs matrix size. Blog post explaining your tiling strategy.
Signal: You understand hardware-level optimization.
Time: 15-20 hours
Week 3: "I Reproduced a Paper"
Pick ONE paper (e.g., LoRA, FlashAttention, or SmoothQuant). Implement the core algorithm. Reproduce one key figure on a small-scale model. Write up what was harder than expected.
Signal: You can go from paper to implementation and think critically.
Time: 20-30 hours
Week 4: "I Analyzed Scaling"
The capstone: train the addition transformer at multiple model sizes and token budgets. Fit the scaling law. Compare dense vs MoE. Write the Pallas kernel. Public repo with all results.
Signal: You have the full stack: theory + systems + research taste.
Time: 30-40 hours
The 4-week challenge. If you dedicate ~25 hours per week for 4 weeks, you will have a portfolio that puts you in the top 5% of frontier lab applicants. That is 100 hours of focused work. Not 100 hours of reading — 100 hours of building, measuring, and documenting. The artifacts are your proof of work. The interviewer can run your notebook and verify every claim.
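
Week 4 asks you to fit the scaling law. A common first pass, sketched below in plain Python, is a least-squares line in log-log space. This assumes a pure power law, loss = a * compute^(-b), with no irreducible-loss term, which is a simplification of the forms used in the scaling-law literature. The function name is hypothetical and the data points are synthetic, generated from a known a and b only to show that the fit recovers them.

```python
import math


def fit_power_law(xs, ys):
    """Least-squares fit of y = a * x**(-b) on log-transformed data. Returns (a, b)."""
    log_x = [math.log(x) for x in xs]
    log_y = [math.log(y) for y in ys]
    n = len(log_x)
    mean_x = sum(log_x) / n
    mean_y = sum(log_y) / n
    sxx = sum((u - mean_x) ** 2 for u in log_x)
    sxy = sum((u - mean_x) * (v - mean_y) for u, v in zip(log_x, log_y))
    slope = sxy / sxx                      # slope of the log-log line equals -b
    intercept = mean_y - slope * mean_x    # intercept equals log(a)
    return math.exp(intercept), -slope


def predict(a, b, x):
    """Evaluate the fitted power law at a new compute budget."""
    return a * x ** (-b)


# Synthetic example: compute budgets in FLOPs and losses from a known law.
TRUE_A, TRUE_B = 5.0, 0.05
compute = [1e15, 1e16, 1e17, 1e18]
loss = [TRUE_A * c ** (-TRUE_B) for c in compute]

a_hat, b_hat = fit_power_law(compute, loss)
extrapolated = predict(a_hat, b_hat, 1e19)  # one decade past the data
```

On real runs the points will not sit on a perfect line, so it is worth reporting the residuals and holding out the largest run to check how well the fit extrapolates.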

The Negotiation: What to Know

Frontier lab compensation is highly competitive. Here is what to expect and how to navigate the offer stage.

Level | Typical Title | Typical TC (US, 2025) | What They Want
L4 / IC3 | ML Engineer | $300-450K | Strong implementation skills. Can own a component end-to-end. 2-5 years experience.
L5 / IC4 | Senior ML Engineer | $450-650K | Technical leadership on a project. Mentors junior engineers. 4-8 years experience.
L6 / IC5 | Staff ML Engineer | $600-1M+ | Owns a system and its evolution. Influences team direction. Recognized expert in a domain. 7+ years.
L7 / IC6 | Principal / Distinguished | $800K-2M+ | Shapes the technical strategy. Known externally. Papers, talks, or major systems.

Total compensation (TC) at frontier labs is heavily weighted toward equity (RSUs or options). Base salary is typically $200-300K regardless of level. The rest is equity, whose value is far less predictable at pre-IPO companies (Anthropic, xAI) than at public companies (Google, Meta).

Negotiation tip. The strongest negotiation lever is a competing offer from another frontier lab. If you interview at 2-3 labs simultaneously and receive multiple offers, each lab will typically match or exceed the competing offer. Do not share the exact number — say "I have a competitive offer that I am seriously considering" and let them come to you. The second strongest lever is a strong performance in the technical interview — ask your recruiter for the feedback summary, which often includes a "strong hire" or "exceptional" rating that can justify an above-band offer.

Your First 90 Days at a Frontier Lab

You got the offer. Now what? Here is what the first 90 days typically look like for a new ML research engineer at a frontier lab.

Period | Focus | What Success Looks Like
Days 1-14 | Onboarding + codebase orientation | You can submit a PR that passes CI. You understand the training pipeline: where data comes in, how it flows through the model, where checkpoints are saved. You know the team's Slack channels and meeting cadence.
Days 15-30 | First meaningful contribution | You ship a small optimization (fuse two ops, fix a memory leak, improve a data pipeline) that measurably improves training throughput or eval accuracy. It does not need to be groundbreaking; it needs to be correct, well-tested, and shipped.
Days 31-60 | Own a component | You are the go-to person for one subsystem (quantization pipeline, evaluation framework, data preprocessing, or a specific model component). You have opinions about its design and a roadmap for improvement.
Days 61-90 | Drive a project | You propose and lead a project that spans multiple weeks. Maybe it is implementing MoE routing for the next model, or building a new eval benchmark, or writing a Pallas kernel for the training bottleneck you found in your first month. This is where you transition from "new hire" to "team member."

The single most important thing in your first 90 days: ship early and often. A small PR on day 7 is worth more than a perfect PR on day 45. Frontier labs move fast. Show that you can too.

Closing: The Feynman Standard

"What I cannot create, I do not understand." These words were found on Richard Feynman's blackboard at the time of his death. This is the standard that frontier labs hold their engineers to. It is not enough to read about transformers, MoE, Pallas, and scaling laws. You must build them. From scratch. On real hardware. With measured results.

The fifteen chapters of this lesson have given you the knowledge. The capstone exercise gives you the opportunity. The paper list gives you the frontier. The question bank gives you the practice.

The only thing left is the work.

Go build something.
Final reflection: An interviewer at a frontier lab asks: "You have 2 weeks to prepare a technical presentation for our team. You can pick any topic. What do you choose, and why?" Which answer best demonstrates the qualities frontier labs value?