Staff-level interview prep: GPU kernels, quantization, JAX, scaling laws, LLM-guided search, and the full stack from roofline to deployment.
It is 7:30 AM. You badge into the open-plan research floor at a frontier AI lab — think Anthropic, DeepMind, or OpenAI scale. On one monitor, a Pallas kernel you wrote for a custom grouped-query attention variant is 15% slower than the CUTLASS baseline on H100. The kernel profiles fine on a single chip, but when you run it through the full TP=8 sharding, an unexpected AllReduce is eating 2 ms per layer. You need to figure out why before the training run launches at noon — 2048 H100s are reserved, and idle time costs the lab roughly $50 per GPU per hour.
On your second monitor, a colleague on the post-training team has pinged you. They are running a QuIP#-style 2-bit quantization experiment on the latest 70B checkpoint. Eval-set loss (log-perplexity) is 0.4 nats higher than the FP16 baseline, which they consider tolerable for an 8x compression ratio (16 bits down to 2). But on a specific subset of math reasoning benchmarks, accuracy has cratered from 68% to 41%. They need your diagnosis. Is the quantization destroying the precision in the MLP up-projection layers that are critical for chain-of-thought arithmetic? Or is the calibration set too narrow?
On Slack, the scaling laws team has posted a sprawling analysis. They are debating the next model generation: should the lab invest the $80M compute budget in a dense 400B model, or a Mixture-of-Experts with 16 experts and 1.2T total parameters? The MoE activates only 70B parameters per token, giving it roughly the per-token inference FLOPs of a dense 70B model (although all 1.2T parameters must still be held in memory), but training requires careful load balancing and routing. They want your compute analysis (FLOPs, memory, communication overhead, and predicted loss at the target budget) by end of day.
Before lunch, you will fix the kernel (the AllReduce was happening on the wrong axis due to a sharding annotation typo in the Pallas mesh spec), diagnose the quantization issue (the calibration set had zero math-heavy prompts, so the activation ranges for arithmetic-critical layers were miscalibrated), and draft the first half of the scaling comparison (using the C ≈ 6ND relationship to compute Chinchilla-optimal points for both architectures).
This is the daily reality of a Frontier Lab Engineer. You sit at the intersection of three disciplines that rarely overlap in a single person:
| Discipline | What you need | How it shows up daily |
|---|---|---|
| Systems & Hardware | GPU memory hierarchy, kernel writing (Pallas/Triton/CUDA), roofline analysis, communication primitives | You profile a Pallas kernel in the XLA profiler, discover a bank conflict in shared memory, fix it, and cut the custom attention from 4.2 ms to 2.9 ms per layer |
| ML Research | Transformer architecture, scaling laws, training dynamics, loss landscapes, alignment methods | You read the DeepSeek-V3 paper to understand their MoE load-balancing loss, then estimate whether the technique saves enough FLOPs to justify the engineering investment |
| Tooling & Infrastructure | JAX/XLA ecosystem, distributed training frameworks, TPU/GPU cluster management, experiment tracking | You write a JAX training loop with FSDP sharding, add gradient checkpointing to fit within HBM, and instrument it with TensorBoard profiling hooks |
The diagram below traces the lifecycle of a frontier model from pre-training cluster to production serving. Every box is a system you own or co-own. Think of it as the "full-stack map" you would draw on a whiteboard in a system-design interview.
A senior engineer can train a model. They know the APIs, can write a training loop, can debug a loss spike. Give them an architecture and a compute budget, and they will hit a target loss.
A staff engineer decides which architecture to train. They run the scaling law experiments that tell you whether 400B dense or 1.2T MoE is optimal for your budget. They design the sharding strategy that determines whether your training run achieves 42% or 48% MFU — a difference that represents $6M in compute savings at scale. They build the evaluation framework that catches capability regressions before the model ships. And when the next generation of hardware arrives (B200, TPU v6), they redesign the parallelism strategy and kernel library rather than patching the old system.
The distinction is leverage. A senior engineer makes one training run succeed. A staff engineer builds the infrastructure that makes every training run at the lab more efficient, more reliable, and more reproducible.
Every strong interview loop at a frontier lab tests five orthogonal skills. Each chapter in this lesson hits multiple dimensions, but the table below shows the kind of question each dimension produces.
| Dimension | What they test | Example question | What a staff answer adds |
|---|---|---|---|
| CONCEPT | First-principles math and theory | "Derive the roofline bound for multi-head attention on H100" | Connects the bound to practical kernel design decisions |
| DESIGN | System architecture and trade-offs | "Design the sharding strategy for a 400B MoE model across 2048 GPUs" | Discusses expert parallelism, communication topology, failure recovery |
| CODE | Implementation in JAX/Python/Triton | "Write a JAX training loop with FSDP and gradient checkpointing" | Adds sharding annotations, proper PRNG handling, profiling hooks |
| DEBUG | Failure diagnosis under production pressure | "Your 2048-GPU training run's loss spikes every 500 steps. What do you check?" | Systematic bisection: data pipeline, gradient norms, learning rate schedule, hardware faults |
| FRONTIER | Research awareness and taste | "What architecture changes since GPT-4 do you think matter most?" | Discusses GQA, RoPE, SwiGLU, MoE routing, and why each matters at scale with specific numbers |
| Time | Task | Skill used |
|---|---|---|
| 7:30 | Check overnight training run: loss curve, gradient norms, MFU dashboard | Debug |
| 8:00 | Profile custom Pallas attention kernel, find unnecessary AllReduce | Code + Debug |
| 9:00 | Fix sharding annotation, re-benchmark, verify MFU improvement | Code + Concept |
| 10:00 | Standup: present scaling law projections for next model generation | Design + Concept |
| 10:30 | Diagnose quantization anomaly on 70B checkpoint math benchmarks | Debug + Concept |
| 11:30 | Write calibration-aware quantization script, re-run with math-heavy prompts | Code |
| 13:00 | Review colleague's DPO training changes: reward model architecture, data mix | Frontier + Design |
| 14:00 | Design doc: MoE vs dense cost-benefit analysis with FLOPs projections | Design + Concept |
| 15:30 | Prototype ring attention kernel for 128K context sequence parallelism | Code + Frontier |
| 17:00 | Run ablation: GQA head count vs memory savings vs eval degradation | Concept + Design |
This lesson is structured around the advice from Vlad Mnih's influential blog post on how to prepare for frontier lab roles. The core framework: master both "below the stack" (GPU kernels, memory hierarchy, roofline analysis, XLA compilation) and "above the abstraction" (scaling laws, RLHF, agents, evaluation). The specific preparation he recommends:
| Preparation area | Why it matters | Chapters in this lesson |
|---|---|---|
| Roofline analysis | Every kernel, every architecture choice comes down to "are we compute-bound or memory-bound?" | Ch 1 |
| GPU memory hierarchy | You cannot write fast kernels without understanding registers → SRAM → L2 → HBM | Ch 2 |
| Transformer math | Parameter counts, FLOPs, memory budgets — you must be able to derive these from scratch in an interview | Ch 3 |
| JAX fundamentals | Frontier labs use JAX/XLA for training. jit, grad, vmap, pjit are non-negotiable | Ch 4 |
| Custom kernels (Pallas) | When the default XLA lowering is too slow, you write Pallas kernels | Ch 5 |
| Quantization | Serving a 400B model at production scale requires INT4/FP8 without quality loss | Ch 6 |
| FlashAttention | The single most important kernel optimization in modern LLMs | Ch 7 |
| Scaling laws | $50-200M compute decisions depend on predicting loss at target scale | Ch 8 |
| MoE architectures | Every frontier lab is shipping MoE models — the routing/balancing tradeoffs are interview staples | Ch 9 |
| RLHF / DPO | Post-training alignment is where models become useful — you need to understand the full pipeline | Ch 10 |
| Distributed training | DP, TP, PP, FSDP, expert parallelism — 2048+ GPUs require all of them simultaneously | Ch 11 |
| Infra & profiling | XLA profiler, Nsight, MFU dashboards, checkpoint management | Ch 12 |
| LLM agents & search | "Above the abstraction" — tree search, tool use, the LLM as a heuristic function | Ch 13 |
| Eval & safety | How do you know if the model is better? How do you know it is safe? | Ch 14 |
Your colleague submits a Pallas kernel for grouped-query attention. It runs at 180 TFLOPS on an H100. Is that good? Is there room for improvement? Without a framework for answering this question, you are guessing. The roofline model gives you a principled upper bound on performance for any operation, in under 60 seconds of mental math.
Reiner Pope (DeepMind) popularized a style of back-of-the-envelope analysis where every architecture decision — attention variant, parallelism strategy, batch size — is evaluated through the roofline lens. At a frontier lab, this is the first tool you reach for. Master it, and you will never be surprised by a kernel's performance again.
A GPU has two fundamental resources that limit performance. Every operation you run is bottlenecked by one or the other — never both simultaneously (at the ridge point, both bind equally, but this is rare in practice).
Compute is measured in FLOPS (floating-point operations per second). An H100 SXM5 delivers:
| Precision | Peak TFLOPS | Hardware unit |
|---|---|---|
| FP32 | 67 | CUDA cores |
| FP16 / BF16 | 989 | Tensor Cores (4th gen) |
| FP8 (E4M3) | 1,979 | Tensor Cores |
| INT8 | 1,979 | Tensor Cores |
Memory bandwidth is measured in bytes per second. The H100 has HBM3 at 3.35 TB/s (that is 3,350 GB/s). This is the rate at which data can flow between the GPU's main memory (80 GB of HBM) and its compute units.
Think of it this way: the GPU is a factory. Compute is the number of workers on the factory floor. Bandwidth is the width of the loading dock. If you have 1,000 workers but a narrow loading dock, the workers stand idle waiting for materials. If you have a massive loading dock but only 10 workers, materials pile up unprocessed.
Arithmetic intensity (AI) is the ratio of computation to data movement for a given operation: AI = (FLOPs performed) / (bytes moved to and from memory).
The units are FLOPs/byte. This single number tells you whether your operation is compute-bound (limited by how fast the GPU can crunch numbers) or memory-bound (limited by how fast data can flow to the compute units).
Let us derive the roofline model step by step. We want to find the maximum achievable FLOPS for an operation with arithmetic intensity AI.
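The derivation fits in a few lines. This is a sketch under the usual roofline assumption that compute and memory traffic overlap perfectly. The symbols F (FLOPs), M (bytes moved), P (peak compute), B (memory bandwidth), and I (arithmetic intensity) are introduced here for illustration, and the H100 figures are the ones quoted above.

$$
\begin{aligned}
T_{\text{compute}} &= \frac{F}{P}, \qquad T_{\text{memory}} = \frac{M}{B}, \qquad I = \frac{F}{M} \\
T &\ge \max\left(T_{\text{compute}},\, T_{\text{memory}}\right) \\
\text{FLOPS}_{\text{achieved}} &= \frac{F}{T} \le \min\left(P,\; I \cdot B\right) \\
I_{\text{ridge}} &= \frac{P}{B} = \frac{989\ \text{TFLOPS}}{3.35\ \text{TB/s}} \approx 295\ \text{FLOPs/byte}
\end{aligned}
$$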
This is the entire model. Below the ridge point (AI < 295), you are memory-bound and performance scales linearly with arithmetic intensity. Above it (AI > 295), you are compute-bound and performance is flat at the peak compute rate. Plot this on a log-log graph and you get the characteristic "roofline" shape: a diagonal line (memory-bound region) that hits a flat ceiling (compute-bound region).
Matrix multiplication is the most important operation in deep learning. Let us analyze it with the roofline model.
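As a worked sketch, take C[M,N] = A[M,K] × B[K,N] in FP16 (2 bytes per element), and assume each matrix crosses HBM exactly once. That idealization ignores re-reads caused by tiling.

$$
\begin{aligned}
F &= 2MKN, \qquad \text{bytes} = 2\,(MK + KN + MN) \\
I &= \frac{2MKN}{2\,(MK + KN + MN)} = \frac{MKN}{MK + KN + MN} \\
M = K = N = 4096 &: \quad I = \frac{4096}{3} \approx 1365\ \text{FLOPs/byte} \quad (\text{compute-bound}) \\
M = 1,\ K = N = 4096 &: \quad I = \frac{4096^2}{4096 + 4096^2 + 4096} \approx 1.0\ \text{FLOPs/byte} \quad (\text{memory-bound})
\end{aligned}
$$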
The takeaway: large matrix multiplications are compute-bound. This is why Tensor Cores matter — they are the fast lane for matrix multiplication. Small matrix multiplications (batch size 1, short sequences) can become memory-bound, which is exactly why LLM inference at batch=1 is often bandwidth-limited.
Self-attention in a transformer is more nuanced because it involves multiple operations with very different arithmetic intensities. Let us dissect it.
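One way to see the differences is to compute arithmetic intensity per stage. The sketch below assumes unfused single-head attention in FP16, where each stage reads its inputs from HBM and writes its output back. It uses the rough figure of 5 FLOPs per element for softmax that appears elsewhere in this lesson. The function name and the 295 FLOPs/byte threshold (the H100 ridge point) are illustrative choices, not part of any library.

```python
def attention_stage_intensity(seq_len: int, d_head: int, bytes_per_elem: int = 2) -> dict:
    """Arithmetic intensity (FLOPs/byte) of each stage of unfused attention."""
    S, d, b = seq_len, d_head, bytes_per_elem
    # QK^T: read Q and K ([S, d] each), write the [S, S] score matrix.
    qk = (2 * S * S * d) / (b * (2 * S * d + S * S))
    # Softmax: ~5 FLOPs per score; read and write the [S, S] matrix.
    softmax = (5 * S * S) / (b * 2 * S * S)
    # PV: read P ([S, S]) and V ([S, d]), write O ([S, d]).
    pv = (2 * S * S * d) / (b * (S * S + 2 * S * d))
    return {"qk_matmul": qk, "softmax": softmax, "pv_matmul": pv}


H100_RIDGE = 295  # FLOPs/byte, from 989 TFLOPS / 3.35 TB/s

stages = attention_stage_intensity(seq_len=2048, d_head=128)
for name, ai in stages.items():
    bound = "compute" if ai > H100_RIDGE else "memory"
    print(f"{name:10s} AI = {ai:6.1f} FLOPs/byte -> {bound}-bound")
```

Under these assumptions all three stages sit below the ridge point, and the softmax stage is two orders of magnitude lower in intensity than the two matmuls.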
For multi-GPU training and inference, there is a third bound that the single-GPU roofline ignores: communication bandwidth. When your model is sharded across 8 GPUs with tensor parallelism, every attention layer requires an AllReduce of the output. The AllReduce transfers data proportional to the tensor size over NVLink or InfiniBand.
The three-way roofline analysis is what separates a staff engineer from a senior. A senior knows that attention is memory-bound. A staff engineer can tell you whether the AllReduce or the attention kernel itself is the bottleneck for a specific sharding configuration, and can redesign the parallelism strategy accordingly.
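The communication bound can be estimated the same way as the other two. The sketch below assumes a bandwidth-only ring AllReduce, in which each GPU sends 2(n-1)/n of the message, and it ignores latency, NVSwitch topology, and any overlap with compute. The activation shape is hypothetical. The 450 GB/s figure is the per-direction H100 NVLink bandwidth used in this lesson.

```python
def ring_allreduce_ms(message_bytes: float, num_gpus: int, link_gb_per_s: float) -> float:
    """Bandwidth-only estimate: each GPU sends 2*(n-1)/n of the message."""
    traffic = 2 * (num_gpus - 1) / num_gpus * message_bytes
    return traffic / (link_gb_per_s * 1e9) * 1e3


# Hypothetical layer output: 8192 tokens x d_model 8192 in BF16.
tokens, d_model, bytes_per_elem = 8192, 8192, 2
message = tokens * d_model * bytes_per_elem

t_ms = ring_allreduce_ms(message, num_gpus=8, link_gb_per_s=450)
print(f"AllReduce of {message / 1e6:.0f} MB across 8 GPUs: {t_ms:.2f} ms")
```

Comparing this estimate with the compute and HBM times for the same layer shows which of the three bounds dominates for a given sharding configuration.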
```python
import dataclasses

# ── GPU specifications ──
@dataclasses.dataclass
class GPUSpec:
    name: str
    peak_tflops_fp16: float    # Tensor Core peak
    hbm_bandwidth_tbs: float   # HBM bandwidth in TB/s
    sram_per_sm_kb: float      # Shared memory per SM in KB
    num_sms: int
    nvlink_bw_gbs: float       # NVLink bandwidth per direction in GB/s

    @property
    def ridge_point(self) -> float:
        """Arithmetic intensity where compute = bandwidth bound."""
        return self.peak_tflops_fp16 * 1e12 / (self.hbm_bandwidth_tbs * 1e12)

    def achieved_tflops(self, arithmetic_intensity: float) -> float:
        """Roofline model: min(peak, AI * bandwidth)."""
        bandwidth_limited = arithmetic_intensity * self.hbm_bandwidth_tbs
        return min(self.peak_tflops_fp16, bandwidth_limited)

    def efficiency(self, arithmetic_intensity: float) -> float:
        """What fraction of peak compute we achieve."""
        return self.achieved_tflops(arithmetic_intensity) / self.peak_tflops_fp16


H100 = GPUSpec(
    name="H100 SXM5", peak_tflops_fp16=989, hbm_bandwidth_tbs=3.35,
    sram_per_sm_kb=228, num_sms=132, nvlink_bw_gbs=450,
)
A100 = GPUSpec(
    name="A100 SXM4", peak_tflops_fp16=312, hbm_bandwidth_tbs=2.0,
    sram_per_sm_kb=164, num_sms=108, nvlink_bw_gbs=300,
)

# ── Roofline analysis for common operations ──
def matmul_roofline(M, K, N, bytes_per_elem=2, gpu=H100):
    """Roofline analysis for C[M,N] = A[M,K] @ B[K,N]."""
    flops = 2 * M * K * N
    bytes_moved = bytes_per_elem * (M*K + K*N + M*N)
    ai = flops / bytes_moved
    tflops = gpu.achieved_tflops(ai)
    time_ms = flops / (tflops * 1e12) * 1e3
    bound = "COMPUTE" if ai > gpu.ridge_point else "MEMORY"
    print(f"Matmul [{M},{K}]x[{K},{N}] on {gpu.name}:")
    print(f"  FLOPs: {flops:.2e}")
    print(f"  Bytes: {bytes_moved:.2e}")
    print(f"  AI: {ai:.1f} FLOPs/byte")
    print(f"  Ridge: {gpu.ridge_point:.1f} FLOPs/byte")
    print(f"  Bound: {bound}")
    print(f"  Achieved: {tflops:.1f} TFLOPS ({gpu.efficiency(ai)*100:.1f}%)")
    print(f"  Time: {time_ms:.3f} ms")
    return ai, tflops, time_ms


def attention_roofline(seq_len, d_head, n_heads, gpu=H100):
    """Roofline for standard (non-Flash) multi-head attention."""
    S, d, h = seq_len, d_head, n_heads
    # Per head: 2 matmuls (QK^T, PV) + softmax
    flops_per_head = 2 * (2 * S * S * d) + 5 * S * S
    total_flops = h * flops_per_head
    # Without FlashAttention: must materialize S×S attn matrix per head
    bytes_per_head = 2 * (2*S*d + S*S + S*S + S*d + S*d)  # Q,K,P,P_softmax,V,O
    total_bytes = h * bytes_per_head
    ai = total_flops / total_bytes
    tflops = gpu.achieved_tflops(ai)
    bound = "COMPUTE" if ai > gpu.ridge_point else "MEMORY"
    print(f"\nAttention S={S}, d={d}, h={h} on {gpu.name}:")
    print(f"  AI: {ai:.1f} FLOPs/byte ({bound})")
    print(f"  Achieved: {tflops:.1f} TFLOPS ({gpu.efficiency(ai)*100:.1f}%)")
    return ai


# ── Run examples ──
matmul_roofline(4096, 4096, 4096)     # Square matmul
matmul_roofline(1, 4096, 4096)        # Batch=1 inference
matmul_roofline(4096, 4096, 11008)    # LLaMA FFN up-projection
attention_roofline(2048, 128, 32)     # Standard LLM attention
attention_roofline(32768, 128, 32)    # Long-context attention
```
Consider matmul_roofline(1, 4096, 4096) above. The result: AI = 1.0 FLOPs/byte, meaning we are at 0.3% of the ridge point. The GPU achieves only 3.35 TFLOPS out of a possible 989, barely 0.3% efficiency. This is why single-request LLM inference is so slow: each token generation is a batch=1 matrix-vector multiply, and the GPU is almost entirely idle, waiting for data to arrive from HBM. This is the fundamental reason batching, speculative decoding, and quantization (which reduces bytes transferred) matter so much for LLM serving.
The roofline gives you an upper bound. Real kernels fall below the roofline for several reasons:
| Inefficiency | What it means | Typical penalty |
|---|---|---|
| Occupancy | Not enough warps in flight to hide memory latency | 10-50% below roofline if occupancy < 50% |
| Bank conflicts | Multiple threads accessing the same shared memory bank | 2-32x slowdown on shared memory loads |
| Non-coalesced access | Threads in a warp reading non-contiguous memory | 2-32x slowdown on global memory loads |
| Instruction overhead | Control flow, address computation, loop overhead | 5-15% below roofline for simple kernels |
| Tail effects | Last block of a grid is partially filled | Depends on grid dimensions vs SM count |
| L2 thrashing | Working set too large for L2 cache | Effective bandwidth drops below peak HBM rate |
A well-optimized kernel achieves 70-85% of the roofline prediction. If you are below 50%, look for one of the above issues. If you are above 85%, you are doing excellent work. Above 90% is exceptional and usually requires deep hardware-specific tuning.
Here is a table of common deep learning operations and their typical arithmetic intensity. Memorize these — they come up constantly in interviews.
| Operation | FLOPs | Bytes (FP16) | AI (FLOPs/byte) | Bound on H100 |
|---|---|---|---|---|
| Large matmul [4096×4096]×[4096×4096] | 1.37×10¹¹ | 1.01×10⁸ | 1365 | Compute |
| Batch=1 matmul [1×4096]×[4096×4096] | 3.36×10⁷ | 3.36×10⁷ | 1.0 | Memory (0.3%) |
| LayerNorm (d=4096) | ~5 per elem | 4 per elem | 1.25 | Memory (0.4%) |
| Softmax (S=2048) | ~5 per elem | 4 per elem | 1.25 | Memory (0.4%) |
| GELU / SiLU | ~10 per elem | 4 per elem | 2.5 | Memory (0.8%) |
| Embedding lookup | 0 | 2 per elem | 0 | Memory (pure) |
| Attention (standard) | 4S²d | ~8S² | d/2 ≈ 64 | Memory (21%) |
| Attention (Flash) | 4S²d | ~8Sd | S/2 | Compute for S>590 |
A matrix multiply computes C = A×B where A is [2048, 4096] and B is [4096, 2048]. Assume BF16 (2 bytes). What is the arithmetic intensity in FLOPs/byte? (Hint: FLOPs = 2MNK, bytes = 2(MK + KN + MN))
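A worked sketch of the answer, using the formulas from the hint with M = 2048, K = 4096, N = 2048:

$$
\begin{aligned}
F &= 2 \cdot 2048 \cdot 4096 \cdot 2048 \approx 3.44 \times 10^{10}\ \text{FLOPs} \\
\text{bytes} &= 2\,(2048 \cdot 4096 + 4096 \cdot 2048 + 2048 \cdot 2048) \approx 4.19 \times 10^{7} \\
I &= \frac{F}{\text{bytes}} \approx 819\ \text{FLOPs/byte}
\end{aligned}
$$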
This is well above the H100 ridge point (~295), so this matmul is compute-bound. Good — tensor cores will be fully utilized.
Standard attention (no FlashAttention) with S=2048, d=128, batch=1, single head. FLOPs ≈ 4S²d. Bytes ≈ 2(2Sd + S²) × 2 for BF16. What is the arithmetic intensity? Is this compute-bound or memory-bound on H100?
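A worked sketch of the answer, plugging S = 2048 and d = 128 into the formulas given in the question:

$$
\begin{aligned}
F &\approx 4 S^2 d = 4 \cdot 2048^2 \cdot 128 \approx 2.15 \times 10^{9}\ \text{FLOPs} \\
\text{bytes} &\approx 2\,(2Sd + S^2) \cdot 2 = 4\,(524{,}288 + 4{,}194{,}304) \approx 1.89 \times 10^{7} \\
I &= \frac{F}{\text{bytes}} \approx 114\ \text{FLOPs/byte}
\end{aligned}
$$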
The H100 ridge point is ~295 FLOPs/byte. AI=114 < 295, so standard attention is memory-bound. This is exactly why FlashAttention was invented — by tiling the computation and avoiding materializing the S×S matrix to HBM, it increases effective arithmetic intensity.
The roofline model told us that most operations in a transformer are memory-bound. The natural follow-up question: which memory? A GPU does not have one flat memory space. It has a hierarchy of four levels, each with vastly different bandwidth, latency, and capacity. Understanding this hierarchy is the foundation of all kernel optimization.
Think of the GPU memory hierarchy as a city's transportation system. Registers are the desk you are sitting at — instant access, but tiny (a few papers). Shared memory (SRAM) is the filing cabinet in your office — a few steps away, fast to access, holds a decent amount. L2 cache is the building's mailroom — takes a walk, much larger. HBM is the warehouse across town — takes a truck to deliver, but stores everything.
| Level | Technology | Bandwidth | Latency | Capacity per SM | Total | Scope |
|---|---|---|---|---|---|---|
| Registers | Flip-flops | ~78 TB/s | 0 cycles (same clock) | 256 KB | ~33 MB | Per thread |
| Shared Memory (SRAM) | SRAM | ~19 TB/s | ~23 cycles | Up to 228 KB | ~30 MB | Per block (CTA) |
| L2 Cache | SRAM | ~12 TB/s | ~200 cycles | — | 50 MB | All SMs |
| HBM3 | 3D-stacked DRAM | 3.35 TB/s | ~400-600 cycles | — | 80 GB | All SMs |
Look at the bandwidth column. Registers are 23x faster than HBM. Shared memory is 5.7x faster than HBM. These are not small differences — they are order-of-magnitude gaps that completely determine kernel performance.
The hierarchy is not an arbitrary engineering choice. It is dictated by physics.
Speed vs. capacity trade-off. Flip-flops (registers) toggle in a single clock cycle but require 6 transistors each. One 32-bit register = 192 transistors. To store 80 GB at register density would require trillions of transistors and consume absurd power. SRAM (shared memory) uses 6 transistors per bit but with a different layout — it is denser than registers but slower because the read circuitry adds latency. DRAM (the technology behind HBM) uses just 1 transistor + 1 capacitor per bit, making it ~6x denser than SRAM, but the capacitor must be refreshed periodically and the read is destructive (requiring precharge), adding ~400 cycles of latency.
Speed of light limits. An electrical signal travels about 15 cm per nanosecond on a chip. At 2 GHz clock speed, each cycle is 0.5 ns, so a signal can travel at most 7.5 cm in one cycle. The HBM stacks sit on the edge of the GPU die, about 2-3 cm from the SM array. Even at the speed of light, a round trip takes ~0.4 ns = ~1 cycle just for the physical propagation. Add the DRAM access protocol (row activate, column select, data transfer, precharge) and you get 400+ cycles.
The wire delay problem. On a modern chip, transistor switching is no longer the bottleneck — wire delay is. Moving a bit across a 1 mm wire takes longer than a transistor can switch. This is why local memory (registers, shared memory on the same SM) is fast and global memory (HBM on the edge of the die) is slow. Every level of the hierarchy represents a different physical distance from the compute units.
Each thread on an H100 SM has access to up to 255 32-bit registers (the maximum register file allocation per thread). The total register file per SM is 256 KB. Registers are the only memory that can be accessed at full speed — zero additional latency, and the read happens in the same cycle as the instruction that uses the value.
To check register usage, pass --ptxas-options=-v when compiling CUDA kernels, or check the XLA compilation log for Pallas kernels. If register usage per thread exceeds 128, occupancy is already capped at 512 threads per SM (the 256 KB register file holds 65,536 32-bit registers), and kernels that hit the 255-register ceiling spill to local memory and lose significant performance.
Shared memory is the most important memory level for kernel developers because it is the one you explicitly control. Unlike L1/L2 caches (which are hardware-managed), you decide exactly what data goes into shared memory and when.
On H100, each SM has 256 KB of combined L1 cache and shared memory, of which up to 228 KB can be configured as shared memory. You can choose the split per kernel, anywhere from mostly L1 up to the 228 KB shared-memory maximum. For most kernels, maximizing shared memory is the right choice because you get to control the access pattern perfectly.
Shared memory's 32 banks are the source of its high bandwidth — but also its most common performance pitfall. When two threads in the same warp access different addresses that map to the same bank, the accesses serialize. This is called a bank conflict.
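A quick way to see the effect is to compute bank indices for a column read of a row-major float tile, with and without one word of padding per row. The sketch below assumes 32 banks of 4-byte words, and that thread t of a warp reads tile[t][column]. The function name is illustrative.

```python
from collections import Counter

NUM_BANKS = 32  # shared memory banks, each one 4-byte word wide


def max_conflict_degree(row_stride_words: int, column: int = 0) -> int:
    """Worst-case number of threads in one warp that hit the same bank
    when thread t reads tile[t][column] from a row-major float tile."""
    banks = Counter()
    for thread in range(32):
        word_index = thread * row_stride_words + column
        banks[word_index % NUM_BANKS] += 1
    return max(banks.values())


unpadded = max_conflict_degree(row_stride_words=32)  # tile[32][32]
padded = max_conflict_degree(row_stride_words=33)    # tile[32][32 + 1]
print(f"unpadded: {unpadded}-way conflict, padded: {padded}-way conflict")
```

With a row stride of 32 words, every thread lands in the same bank and the reads serialize. With a stride of 33, each row shifts by one bank and the conflict disappears.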
The standard fix is padding: you will see __shared__ float tile[BLOCK_M][BLOCK_K + 1] in virtually every optimized matmul kernel. In Pallas, the equivalent is choosing tile dimensions that are not powers of 2, or using explicit swizzling patterns.
HBM (High Bandwidth Memory) is a stack of DRAM dies connected to the GPU die via a silicon interposer. The H100 SXM5 has five active 16 GB HBM3 stacks, totaling 80 GB at 3.35 TB/s aggregate bandwidth.
Despite its name ("High Bandwidth Memory"), HBM is the slowest memory level. Its bandwidth is "high" relative to traditional DDR5 (which peaks at ~0.1 TB/s), but low relative to the GPU's compute capability.
Tensor Cores are specialized matrix multiply units on each SM. They are the reason H100 achieves 989 TFLOPS for FP16 while standard CUDA cores top out at 67 TFLOPS. A single Tensor Core operation on H100 computes a 16×8×16 matrix multiply-accumulate in one cycle.
H100 introduced a dedicated hardware unit called the Tensor Memory Accelerator (TMA). TMA offloads data movement from the SM's compute pipeline to a separate engine. Instead of threads computing addresses and issuing loads, a single thread programs the TMA with a tensor descriptor (base pointer, shape, strides, data type), and the TMA fetches the entire tile asynchronously.
In Pallas, you specify a BlockSpec for your input and output tensors. The Pallas compiler on Hopper translates these into TMA descriptor loads. This means a well-written Pallas kernel can achieve near-CUTLASS performance without any manual memory management, because the compiler handles TMA for you. But you still need to choose tile sizes wisely: tiles too small waste TMA setup overhead; tiles too large spill out of shared memory.
Let us trace a single matmul through the memory hierarchy to understand where time is spent.
| Metric | A100 | H100 | B200 | Trend |
|---|---|---|---|---|
| FP16 Tensor TFLOPS | 312 | 989 | ~2,250 | ~3x per generation |
| HBM Bandwidth (TB/s) | 2.0 | 3.35 | ~8.0 | ~2x per generation |
| Ridge Point (FLOPs/byte) | 156 | 295 | ~281 | Nearly doubled A100 to H100, then roughly flat: most ops stay memory-bound |
| SRAM per SM (KB) | 164 | 228 | ~256 | Slow growth: physics limits |
| HBM Capacity (GB) | 80 | 80 | 192 | 2.4x: finally fits 70B FP16 |
In a scaling law meeting, someone asks: "How many FLOPs to train a 70B model on 2T tokens?" You need to answer in under 10 seconds. In a system-design interview, the question is: "How much memory does a 400B MoE model need at inference time with batch=64, seq_len=8192?" You need to derive it on the whiteboard in 2 minutes. This chapter gives you the formulas and worked examples to do both.
All of these are derived from first principles — no memorization required. Once you understand where each parameter lives and how each FLOP is spent, you can re-derive any formula on the spot.
A standard transformer decoder (GPT-style) has L identical layers, each containing a multi-head attention block and a feed-forward network (FFN). Let us count every parameter.
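The count can be sketched as a small function. It assumes bias-free projections, ignores normalization parameters, and supports grouped-query attention and a gated (three-matrix) FFN. The function name is illustrative, and the example configuration is an assumption: d_ff = 28672 and a 32,000-token vocabulary are picked because they give a total close to the ~69B figure in this lesson's summary table.

```python
def transformer_params(n_layers: int, d_model: int, d_ff: int, vocab_size: int,
                       n_heads: int, n_kv_heads: int,
                       ffn_matrices: int = 3, tied_embeddings: bool = False) -> int:
    """Approximate parameter count of a decoder-only transformer."""
    d_head = d_model // n_heads
    kv_dim = n_kv_heads * d_head
    # Attention: W_q and W_o are [d, d]; W_k and W_v are [d, kv_dim].
    attn = 2 * d_model * d_model + 2 * d_model * kv_dim
    # FFN: 3 matrices for gated variants (SwiGLU), 2 for a plain MLP.
    ffn = ffn_matrices * d_model * d_ff
    # Input embedding plus output head (shared if tied).
    embed = vocab_size * d_model * (1 if tied_embeddings else 2)
    return n_layers * (attn + ffn) + embed


# Assumed 70B-class configuration (d_ff and vocab_size are illustrative).
total = transformer_params(n_layers=80, d_model=8192, d_ff=28672,
                           vocab_size=32000, n_heads=64, n_kv_heads=8)
print(f"{total / 1e9:.2f}B parameters")
```

With 8 KV heads out of 64, the attention term works out to 2.25·d² per layer, which matches the per-layer formula 2.25d² + 3d·d_ff.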
For every parameter, the forward pass performs approximately 2 FLOPs (one multiply and one add in a matrix multiplication). This gives us a beautifully simple rule:
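Written out, with N parameters and D training tokens, and assuming the standard approximation that the backward pass costs about twice the forward pass (one matmul for the activation gradients, one for the weight gradients):

$$
\begin{aligned}
\text{Forward FLOPs per token} &\approx 2N \\
\text{Backward FLOPs per token} &\approx 4N \\
\text{Training compute} \quad C &\approx (2N + 4N) \cdot D = 6ND
\end{aligned}
$$

The approximation leaves out the attention score computation, which grows with sequence length and becomes significant at long context.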
Training a model requires storing four categories of data in GPU memory. Let us account for each one.
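The four categories are weights, gradients, optimizer state, and activations. The sketch below assumes mixed-precision AdamW with BF16 weights and gradients (2 bytes each) and 12 bytes per parameter of FP32 optimizer state (master weights plus two moments). Those are the same per-parameter costs used in this lesson's summary table. Activation memory depends on batch size, sequence length, and checkpointing, so it is passed in as a number. The function name is illustrative.

```python
def training_memory_gb(n_params: float, activations_gb: float = 0.0) -> dict:
    """Per-category training memory in GB for mixed-precision AdamW."""
    gb = 1e9
    mem = {
        "weights_bf16": 2 * n_params / gb,
        "gradients_bf16": 2 * n_params / gb,
        "adamw_state_fp32": 12 * n_params / gb,  # master weights + m + v
        "activations": activations_gb,
    }
    mem["total"] = sum(mem.values())
    return mem


mem = training_memory_gb(69e9)
for name, size in mem.items():
    print(f"{name:18s} {size:8.0f} GB")
```

For a 69B-parameter model this gives 1,104 GB before activations, which is why the state has to be sharded across many 80 GB GPUs.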
During inference, the model generates one token at a time. For each new token, it needs to attend to all previous tokens. Re-computing K and V for all previous tokens at every step would be wasteful, so we cache them. This KV cache is often the largest memory consumer at inference time.
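The cache size follows directly from its shape: one K and one V tensor per layer, each of shape [batch, kv_heads, seq_len, d_head]. A minimal sketch, with an illustrative function name and the 70B-class GQA configuration used elsewhere in this lesson:

```python
def kv_cache_gb(n_layers: int, n_kv_heads: int, d_head: int,
                seq_len: int, batch: int, bytes_per_elem: int = 2) -> float:
    """KV cache size in GB: the factor of 2 covers the K and V tensors."""
    elems = 2 * n_layers * n_kv_heads * d_head * seq_len * batch
    return elems * bytes_per_elem / 1e9


gqa = kv_cache_gb(n_layers=80, n_kv_heads=8, d_head=128, seq_len=4096, batch=32)
mha = kv_cache_gb(n_layers=80, n_kv_heads=64, d_head=128, seq_len=4096, batch=32)
print(f"GQA (8 KV heads):  {gqa:.1f} GB")
print(f"MHA (64 KV heads): {mha:.1f} GB")
```

The cache grows linearly in both batch size and sequence length, so doubling either one doubles the memory.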
Model FLOPs Utilization (MFU) is the fraction of the GPU's peak compute that is actually used for model computation:
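Using the 6N FLOPs-per-token approximation, MFU is the model FLOPs the run actually delivers per second, divided by the aggregate peak of the cluster. In the sketch below the function name is illustrative and the throughput figure is hypothetical.

```python
def mfu(n_params: float, tokens_per_second: float,
        num_gpus: int, peak_flops_per_gpu: float) -> float:
    """Model FLOPs Utilization = (6 * N * tokens/s) / (GPUs * peak FLOPS)."""
    model_flops_per_second = 6 * n_params * tokens_per_second
    return model_flops_per_second / (num_gpus * peak_flops_per_gpu)


# Hypothetical run: 69B parameters, 2048 H100s at 989 TFLOPS BF16 peak,
# measured throughput of 2.2M tokens per second.
utilization = mfu(n_params=69e9, tokens_per_second=2.2e6,
                  num_gpus=2048, peak_flops_per_gpu=989e12)
print(f"MFU = {utilization:.1%}")
```

Because the numerator counts only the FLOPs the model requires, recomputation from gradient checkpointing does not raise MFU.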
| System | Reported MFU | Notes |
|---|---|---|
| PaLM (Google, 2022) | 46.2% | 540B on 6144 TPUv4, data + model parallelism (no pipeline parallelism) |
| LLaMA (Meta, 2023) | ~38% | 65B on 2048 A100, FSDP |
| GPT-4 (estimated) | ~45-50% | Reported to use ~25K A100s, exact MFU unknown |
| DeepSeek-V3 (2024) | ~52% | 671B MoE on H800, FP8 training, expert parallelism |
In 2022, the Chinchilla paper (Hoffmann et al.) established that for a given compute budget C, there is an optimal split between model size N and training data D. Training a model that is too large on too little data wastes compute; training a small model on excessive data also wastes compute.
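A common way to turn this into numbers is to combine C ≈ 6ND with the rule of thumb of about 20 training tokens per parameter. The sketch below uses that ratio, which is an approximation and not an exact result from the paper, so it is exposed as a parameter. The function name is illustrative.

```python
import math


def chinchilla_optimal(compute_flops: float, tokens_per_param: float = 20.0):
    """Solve C = 6 * N * D with D = tokens_per_param * N for N and D."""
    n_params = math.sqrt(compute_flops / (6 * tokens_per_param))
    n_tokens = tokens_per_param * n_params
    return n_params, n_tokens


budget = 5.88e23  # FLOPs; equals 6 * 70e9 * 1.4e12
n_opt, d_opt = chinchilla_optimal(budget)
print(f"N = {n_opt / 1e9:.0f}B parameters, D = {d_opt / 1e12:.1f}T tokens")
```

Under this rule both N and D scale with the square root of the budget, so a 4x increase in compute roughly doubles each.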
| Quantity | Formula | Value |
|---|---|---|
| Parameters | 80 × (2.25d² + 3d·d_ff) + 2Vd | ~69B |
| Weights (BF16) | 69B × 2 | 138 GB |
| Optimizer (AdamW) | 69B × 12 | 828 GB |
| Gradients (BF16) | 69B × 2 | 138 GB |
| Training memory total | Sum + activations | ~1.1-1.3 TB |
| Min GPUs (training) | ~1200 GB / 80 GB | 16 H100s (ZeRO-3) |
| Forward FLOPs/token | 2P | 138 GFLOPs |
| Training FLOPs (15T tok) | 6 × 69B × 15T | 6.2 × 10²⁴ |
| Training time (2048 H100) | C / (GPUs × FLOPS × MFU) | ~80 days at MFU=0.45 |
| KV cache (B=32, S=4096) | 2×80×8×128×4096×32×2 | 42.9 GB |
| Inference memory (TP=4) | (138 + 42.9) / 4 | ~45 GB/GPU |
GPT-2 Small: 12 layers, d_model=768, 12 attention heads, d_head=64, FFN inner dim=3072. Ignore embeddings and LayerNorm. How many parameters (in millions)?
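A worked sketch of the answer, assuming four d × d attention projections (Q, K, V, O) and a two-matrix FFN, with biases ignored:

$$
\begin{aligned}
\text{Attention per layer} &= 4 d^2 = 4 \cdot 768^2 = 2{,}359{,}296 \\
\text{FFN per layer} &= 2 \cdot d \cdot d_{\text{ff}} = 2 \cdot 768 \cdot 3072 = 4{,}718{,}592 \\
\text{Total} &= 12 \cdot (2{,}359{,}296 + 4{,}718{,}592) = 84{,}934{,}656 \approx 85\text{M}
\end{aligned}
$$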
(Adding embeddings: vocab_size × d = 50257 × 768 = 38.6M, shared with output head. Full GPT-2 Small ≈ 124M.)
LLaMA 3 70B: 80 layers, GQA with 8 KV heads, d_head=128, batch=32, sequence length=8192, BF16 (2 bytes). How many GB for the KV cache?
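A worked sketch of the answer, multiplying out the cache shape in decimal gigabytes:

$$
\begin{aligned}
\text{KV bytes} &= 2 \cdot L \cdot n_{\text{kv}} \cdot d_{\text{head}} \cdot S \cdot B \cdot \text{bytes} \\
&= 2 \cdot 80 \cdot 8 \cdot 128 \cdot 8192 \cdot 32 \cdot 2 \\
&= 85{,}899{,}345{,}920\ \text{bytes} \approx 85.9\ \text{GB}
\end{aligned}
$$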
The factor of 2 at the start is for K and V tensors. With GQA (8 KV heads instead of 64 query heads), the cache is 8x smaller than it would be with MHA.
You are training a 7B parameter model on 2T tokens. Using the 6ND approximation, how many total FLOPs (in units of 10²²)?
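A worked sketch of the answer:

$$
C \approx 6ND = 6 \cdot (7 \times 10^{9}) \cdot (2 \times 10^{12}) = 8.4 \times 10^{22}\ \text{FLOPs}
$$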
At 50% MFU on 1024 H100s (each 989 TFLOPS BF16): effective = 0.5 × 1024 × 989e12 = 5.06e17 FLOPS. Time = 8.4e22 / 5.06e17 = 166,008 seconds ≈ 1.92 days.
At a frontier lab, the training infrastructure is almost certainly built on JAX. Not PyTorch — JAX. Why? Three reasons: (1) JAX compiles your Python code through XLA, which generates fused, optimized kernels for TPUs and GPUs automatically. (2) JAX's functional design (no hidden state, no in-place mutation) makes distributed training with model parallelism radically simpler — the compiler can shard pure functions across devices. (3) JAX's composable transforms (jit, grad, vmap) let you write research code that is simultaneously readable AND production-fast.
If you are coming from PyTorch, JAX will feel strange at first. PyTorch is imperative: you build a computation graph by executing Python statements, and the graph is traced eagerly. JAX is functional: you write pure functions, and JAX transforms (traces, compiles, differentiates) them. The payoff for this discipline is enormous — but the learning curve is real.
The first surprise: inside a function decorated with @jax.jit, jax.numpy operations do not execute immediately (outside jit, in eager mode, they do). When you call the jitted function, JAX traces it: it runs your Python code with abstract "tracer" values instead of real data. The tracers record the operations, building an XLA computation graph. Then XLA compiles this graph into optimized machine code.
```python
import jax
import jax.numpy as jnp

# ── Eager mode (no jit) ──
# Operations execute immediately, like NumPy.
# Useful for debugging, but slow (no fusion, no optimization).
x = jnp.ones((3, 3))
y = x + 1   # Dispatches immediately to GPU/TPU
z = y * 2   # Another dispatch: TWO kernel launches
print(z)    # [[4. 4. 4.], [4. 4. 4.], [4. 4. 4.]]

# ── JIT mode ──
# JAX traces the function, compiles it, then runs the compiled version.
@jax.jit
def f(x):
    y = x + 1   # NOT executed: recorded as a trace operation
    z = y * 2   # NOT executed: recorded
    return z    # The trace is: z = (x + 1) * 2

# First call: trace + compile + execute (slow)
result = f(x)   # XLA fuses add+mul into ONE kernel launch!

# Second call with same shape/dtype: just execute (fast, no recompile)
result2 = f(jnp.zeros((3, 3)))

# Third call with DIFFERENT shape: retrace + recompile!
result3 = f(jnp.zeros((4, 4)))   # New compilation
```
Tracing has three practical consequences: (1) Python control flow that depends on traced values fails under jit; use jax.lax.cond for value-dependent branching. (2) Python side effects (print statements, list appends, global variable mutations) happen at trace time only, not at execution time. (3) Functions are recompiled when input shapes change. Design your data pipeline to produce fixed-shape batches to avoid constant recompilation.
JAX requires your functions to be pure: given the same inputs, they must always return the same outputs, and they must not modify any external state. This is not just a style preference; it is a fundamental requirement for correctness when JAX transforms your code.
python # ── BAD: impure function ── counter = [0] # Mutable external state @jax.jit def bad_f(x): counter[0] += 1 # Side effect! Modifies external state print("traced") # Side effect! Only prints during tracing return x + counter[0] # Call 1: traces, prints "traced", counter becomes 1, returns x+1 # Call 2: DOES NOT retrace (same shape), DOES NOT print, # counter is STILL 1 (the value was captured at trace time), # returns x+1 (not x+2!) # This is a silent correctness bug. # ── GOOD: pure function ── @jax.jit def good_f(x, counter): return x + counter, counter + 1 # Return new counter, don't mutate # State is threaded through as an explicit argument and return value. # No hidden mutation, no trace-time capture bugs. result, new_counter = good_f(x, jnp.array(0)) result2, new_counter2 = good_f(x, new_counter)
Why does purity matter beyond avoiding bugs? Because the XLA compiler assumes pure functions to perform optimizations. It can reorder operations, fuse kernels, and even eliminate dead code — all of which are only correct if the function has no side effects. The compiler also needs purity to shard the function across multiple devices: if a function modifies global state, you cannot safely run it on 2048 TPUs simultaneously.
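To make the trace-time versus run-time distinction concrete, here is a minimal sketch that counts how often a jitted function is traced. The names `trace_count` and `scaled_sum` are illustrative only (not part of JAX), and the counter is deliberately impure so that it moves only when Python re-executes the function body.

```python
import jax
import jax.numpy as jnp

trace_count = 0  # hypothetical counter; mutated only while JAX is tracing


@jax.jit
def scaled_sum(x):
    global trace_count
    trace_count += 1            # Python side effect: runs at trace time only
    return jnp.sum(x * 2.0)


scaled_sum(jnp.ones((3, 3)))    # first call: trace + compile + run
scaled_sum(jnp.zeros((3, 3)))   # same shape and dtype: cached, no retrace
scaled_sum(jnp.ones((4, 4)))    # new shape: retrace + recompile

print("number of traces:", trace_count)
```

Three calls produce two traces: the counter tracks distinct input signatures, not calls. The same mechanism is why a print inside a jitted function appears to fire only occasionally.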
jax.grad computes the gradient of a scalar-valued function with respect to its first argument (by default). It uses reverse-mode automatic differentiation (backpropagation), which computes all parameter gradients in a single backward pass.
python import jax import jax.numpy as jnp # ── Basic gradient ── def loss_fn(w, x, y): """MSE loss: L = mean((w @ x - y)^2)""" pred = w @ x return jnp.mean((pred - y) ** 2) # grad takes the gradient w.r.t. the FIRST argument (w) grad_fn = jax.grad(loss_fn) w = jnp.ones((3, 4)) x = jnp.ones((4, 2)) y = jnp.zeros((3, 2)) dL_dw = grad_fn(w, x, y) # Shape: (3, 4) — same as w print(dL_dw.shape) # (3, 4) # ── Gradient w.r.t. multiple arguments ── grad_fn_multi = jax.grad(loss_fn, argnums=(0, 1)) # grad w.r.t. w AND x dL_dw, dL_dx = grad_fn_multi(w, x, y) # ── Value and gradient together ── # In training, you need both the loss value and its gradient. # Computing them separately wastes work (shared forward pass). loss_val, dL_dw = jax.value_and_grad(loss_fn)(w, x, y) print(f"Loss: {loss_val:.4f}, Grad norm: {jnp.linalg.norm(dL_dw):.4f}")
python # ── Forward mode: Jacobian-Vector Product (JVP) ── # Computes f(x) and J @ v simultaneously (J = Jacobian, v = tangent vector) def f(x): return jnp.array([x[0]**2 + x[1], x[0] * x[1]]) x = jnp.array([3.0, 4.0]) v = jnp.array([1.0, 0.0]) # Tangent vector: derivative in x[0] direction primals, tangents = jax.jvp(f, (x,), (v,)) # primals = f(x) = [9+4, 12] = [13, 12] # tangents = J @ v = [2*x[0]*1 + 0, x[1]*1 + 0] = [6, 4] # ── Reverse mode: Vector-Jacobian Product (VJP) ── # Computes f(x) and a function that maps cotangent to gradient primals, vjp_fn = jax.vjp(f, x) # primals = f(x) = [13, 12] # vjp_fn takes a cotangent vector (same shape as output) # and returns the gradient w.r.t. the input cotangent = jnp.array([1.0, 0.0]) # grad of output[0] only (grad_x,) = vjp_fn(cotangent) # grad_x = [2*x[0], 1] = [6, 1] (df[0]/dx[0], df[0]/dx[1])
jax.vmap automatically vectorizes a function that operates on a single example to work on a batch of examples. Instead of writing explicit batch loops or reshaping tensors, you write your function for one example and vmap handles batching.
python # ── Without vmap: manual batching ── def predict_single(params, x): """Forward pass for a single input x of shape (d,).""" for W, b in params: x = jax.nn.relu(W @ x + b) return x # To handle a batch, you'd either: # (a) Loop: [predict_single(params, x_i) for x_i in batch] ← SLOW # (b) Rewrite: change W @ x to x_batch @ W.T + b ← manual work # ── With vmap: automatic batching ── predict_batch = jax.vmap(predict_single, in_axes=(None, 0)) # in_axes=(None, 0) means: # params: NOT batched (same params for all examples) # x: batched along axis 0 (first dim is batch) # Now predict_batch(params, x_batch) works on a batch # where x_batch has shape (batch_size, d) # The output has shape (batch_size, output_d) # ── vmap + grad: per-example gradients ── # This is one of JAX's superpowers: get the gradient for EACH # example in the batch, not just the average gradient. # Useful for: differential privacy, per-example clipping, debugging. def loss_single(params, x, y): pred = predict_single(params, x) return jnp.sum((pred - y) ** 2) # Per-example gradient function: per_example_grad = jax.vmap(jax.grad(loss_single), in_axes=(None, 0, 0)) # per_example_grad(params, x_batch, y_batch) returns a pytree # with the same structure as params, but each leaf has an extra # batch dimension: shape (batch_size, *param_shape)
JAX's transforms (jit, grad, vmap) compose freely. You can stack them in any order, and JAX will trace through all of them to produce a single optimized program.
python # ── The classic training step ── # This is the pattern you'll see in every JAX training loop. @jax.jit # 3. Compile the whole thing def train_step(params, x_batch, y_batch): def loss_fn(params): # vmap handles the batch dimension automatically preds = jax.vmap(predict_single, in_axes=(None, 0))(params, x_batch) # 1. vmap for batching return jnp.mean((preds - y_batch) ** 2) loss, grads = jax.value_and_grad(loss_fn)(params) # 2. grad for differentiation # Simple SGD update new_params = jax.tree.map(lambda p, g: p - 0.01 * g, params, grads) return new_params, loss # What happens under the hood: # 1. jit traces train_step # 2. Inside, it encounters vmap → adds batch tracing # 3. Inside, it encounters grad → adds differentiation tracing # 4. The entire thing (forward + backward + update) compiles # to a SINGLE XLA program with fused kernels # 5. On subsequent calls: no Python overhead, just the compiled kernel
Neural network parameters are not a single array — they are a nested structure of arrays (a dict of dicts of arrays, a list of tuples, etc.). JAX calls these nested structures PyTrees and provides utilities to work with them uniformly.
```python
import dataclasses

import jax
import jax.numpy as jnp

# ── Model parameters as a PyTree ──
params = {
    'layer1': {'W': jnp.ones((128, 64)), 'b': jnp.zeros(64)},
    'layer2': {'W': jnp.ones((64, 32)), 'b': jnp.zeros(32)},
    'head': {'W': jnp.ones((32, 10)), 'b': jnp.zeros(10)},
}
# This is a PyTree: a nested dict of jax arrays.

# Stand-in gradients so this snippet runs on its own. In training they come
# from jax.grad and have the SAME structure as params (JAX guarantees this).
grads = jax.tree.map(jnp.ones_like, params)

# ── jax.tree.map: apply a function to every leaf ──
# SGD update: params = params - lr * grads
new_params = jax.tree.map(lambda p, g: p - 0.01 * g, params, grads)
# This walks both trees in parallel, applying the lambda to each pair.
# No manual iteration over layers!

# ── Counting parameters ──
n_params = sum(p.size for p in jax.tree.leaves(params))
print(f"Total params: {n_params:,}")  # 10,666 = 8,256 + 2,080 + 330

# ── Flatten and unflatten ──
flat_params, tree_def = jax.tree.flatten(params)
# flat_params is a list of arrays. Dict keys are visited in sorted order, so
# the leaves come out as [head.W, head.b, layer1.W, layer1.b, layer2.W, layer2.b].
# tree_def remembers the structure for reconstruction.
restored = jax.tree.unflatten(tree_def, flat_params)
# restored has the same structure and values as params

# ── Custom PyTree nodes ──
# You can register your own classes as PyTree nodes.
# This is how Flax/Equinox/Haiku make their module objects
# work seamlessly with jax.grad, jit, vmap, etc.
@dataclasses.dataclass
class Linear:
    W: jax.Array
    b: jax.Array

    def __call__(self, x):
        return self.W @ x + self.b

# Register with JAX so jit/grad/vmap know how to traverse it:
jax.tree_util.register_dataclass(Linear, data_fields=['W', 'b'], meta_fields=[])
```
When you shard a model with jax.experimental.pjit or jax.sharding, you specify a sharding for each leaf in the parameter PyTree. JAX walks the tree, assigns each array to the right devices, and inserts the necessary communication (AllGather, AllReduce) between devices. Without the PyTree abstraction, you would have to manually manage the sharding of every parameter array, an error-prone nightmare for a model with thousands of them.

JAX's random number generation is different from NumPy's, and for good reason. NumPy uses a global, stateful PRNG: np.random.seed(42) followed by np.random.randn() advances a global counter. This is fundamentally incompatible with JAX's functional paradigm (no hidden state) and with parallel execution (what happens when two TPUs both try to advance the same counter?).
python # ── JAX PRNG: explicit, functional, splittable ── # Create an initial key (this is the ONLY global state) key = jax.random.PRNGKey(42) # key is a pair of uint32 values, not a global state print(key) # [ 0 42] # WRONG: reusing the same key gives the SAME random numbers! a = jax.random.normal(key, (3,)) b = jax.random.normal(key, (3,)) # a == b! This is intentional: same key = same result (purity) # RIGHT: split the key to get new, independent subkeys key, subkey1, subkey2 = jax.random.split(key, 3) a = jax.random.normal(subkey1, (3,)) b = jax.random.normal(subkey2, (3,)) # a != b, and both are deterministic given the original key # ── Pattern for training loops ── def train_step(params, batch, key): # Split key for this step's randomness key, dropout_key, noise_key = jax.random.split(key, 3) def loss_fn(params): # Use dropout_key for dropout mask logits = model_forward(params, batch, dropout_key) return cross_entropy_loss(logits, batch['labels']) loss, grads = jax.value_and_grad(loss_fn)(params) params = update_params(params, grads) return params, loss, key # Return the updated key! # The key is threaded through the loop as explicit state. # Every step gets a unique, deterministic key. # The entire training run is perfectly reproducible. # ── Why the split model matters for parallelism ── # On 2048 TPUs, each device needs independent random numbers. # With the split model, you split the key into 2048 subkeys # at the start, one per device. No coordination needed. keys = jax.random.split(jax.random.PRNGKey(42), 2048) # Each device gets keys[device_id] — deterministic and independent.
Let us put it all together with a complete, production-style training loop for a small MLP. This is the pattern you will see (with more complexity) in frontier lab training codebases.
```python
import jax
import jax.numpy as jnp
from typing import NamedTuple

# ── Model definition (pure functions + parameter PyTree) ──
class MLPParams(NamedTuple):
    """Parameters for a 2-layer MLP."""
    W1: jax.Array  # [d_in, d_hidden]
    b1: jax.Array  # [d_hidden]
    W2: jax.Array  # [d_hidden, d_out]
    b2: jax.Array  # [d_out]

def init_params(key, d_in, d_hidden, d_out):
    """Xavier initialization. Key is split for each layer."""
    k1, k2 = jax.random.split(key)
    scale1 = jnp.sqrt(2.0 / (d_in + d_hidden))
    scale2 = jnp.sqrt(2.0 / (d_hidden + d_out))
    return MLPParams(
        W1=scale1 * jax.random.normal(k1, (d_in, d_hidden)),
        b1=jnp.zeros(d_hidden),
        W2=scale2 * jax.random.normal(k2, (d_hidden, d_out)),
        b2=jnp.zeros(d_out),
    )

def forward(params: MLPParams, x):
    """Forward pass for a single example x of shape (d_in,)."""
    h = jax.nn.silu(x @ params.W1 + params.b1)  # SiLU activation (like LLaMA)
    return h @ params.W2 + params.b2

# Batched forward: vmap over the data dimension, not over params
forward_batch = jax.vmap(forward, in_axes=(None, 0))

# ── Loss function ──
def loss_fn(params, x_batch, y_batch):
    """Cross-entropy loss for classification."""
    logits = forward_batch(params, x_batch)  # [B, n_classes]
    # Numerically stable log-softmax
    log_probs = jax.nn.log_softmax(logits, axis=-1)
    # Select the log-prob of the correct class
    return -jnp.mean(log_probs[jnp.arange(y_batch.shape[0]), y_batch])

# ── Optimizer state (AdamW, simplified) ──
class AdamState(NamedTuple):
    m: MLPParams     # First moment (same tree structure as params)
    v: MLPParams     # Second moment
    step: jax.Array  # Step counter for bias correction

def init_adam(params):
    m = jax.tree.map(jnp.zeros_like, params)
    v = jax.tree.map(jnp.zeros_like, params)
    return AdamState(m=m, v=v, step=jnp.array(0))

def adam_update(params, grads, state, lr=1e-3, beta1=0.9, beta2=0.999,
                wd=0.01, eps=1e-8):
    step = state.step + 1
    # Update moments
    new_m = jax.tree.map(lambda m, g: beta1 * m + (1 - beta1) * g, state.m, grads)
    new_v = jax.tree.map(lambda v, g: beta2 * v + (1 - beta2) * g**2, state.v, grads)
    # Bias correction
    bc1 = 1 - beta1 ** step
    bc2 = 1 - beta2 ** step
    m_hat = jax.tree.map(lambda m: m / bc1, new_m)
    v_hat = jax.tree.map(lambda v: v / bc2, new_v)
    # AdamW: weight decay applied to params directly, not through the gradient
    new_params = jax.tree.map(
        lambda p, m, v: p * (1 - lr * wd) - lr * m / (jnp.sqrt(v) + eps),
        params, m_hat, v_hat,
    )
    return new_params, AdamState(m=new_m, v=new_v, step=step)

# ── Training step (compiled) ──
@jax.jit
def train_step(params, opt_state, x_batch, y_batch):
    loss, grads = jax.value_and_grad(loss_fn)(params, x_batch, y_batch)
    # Gradient clipping (max norm = 1.0)
    grad_norm = jnp.sqrt(sum(jnp.sum(g**2) for g in jax.tree.leaves(grads)))
    clip_coeff = jnp.minimum(1.0, 1.0 / (grad_norm + 1e-6))
    grads = jax.tree.map(lambda g: g * clip_coeff, grads)
    # Update params
    new_params, new_opt_state = adam_update(params, grads, opt_state)
    return new_params, new_opt_state, loss, grad_norm

# ── Training loop ──
key = jax.random.PRNGKey(42)
key, init_key = jax.random.split(key)
params = init_params(init_key, d_in=784, d_hidden=256, d_out=10)
opt_state = init_adam(params)

for step in range(1000):
    # Separate subkeys for inputs and labels: never reuse a key
    key, x_key, y_key = jax.random.split(key, 3)
    # (In real code, load from data pipeline)
    x_batch = jax.random.normal(x_key, (64, 784))
    y_batch = jax.random.randint(y_key, (64,), 0, 10)
    params, opt_state, loss, grad_norm = train_step(
        params, opt_state, x_batch, y_batch
    )
    if step % 100 == 0:
        print(f"Step {step}: loss={loss:.4f}, grad_norm={grad_norm:.4f}")
```
Four things to notice in this loop: (1) The entire training step is @jax.jit compiled, including the optimizer update and gradient clipping. This means the GPU runs the full step without returning to Python. (2) Gradient clipping is inside the jitted function, not outside, which avoids a device-to-host transfer of the gradient norm. (3) The key is split at each step and threaded through explicitly. (4) value_and_grad avoids computing the forward pass twice. These patterns matter at scale: on 2048 GPUs, Python overhead between steps can dominate training time if the step is fast enough.

| Gotcha | What happens | Fix |
|---|---|---|
| Value-dependent shapes | x[x > 0] produces different-sized outputs for different inputs. jit cannot handle dynamic shapes. | Use jnp.where(x > 0, x, 0) (fixed shape) or jax.lax.dynamic_slice |
| In-place mutation | x[0] = 5 raises an error. JAX arrays are immutable. | Use x = x.at[0].set(5) (returns a new array) |
| Print in jit | print(x) prints a tracer object, not a value. | Use jax.debug.print("{x}", x=x) for runtime printing |
| Recompilation storm | Every new input shape (or new value of a static argument) passed to a jitted function triggers a fresh compilation. | Pad inputs to fixed shapes, or bucket them into a small set of sizes |
| NaN in grad | jnp.where(cond, safe_val, jnp.log(x)) — gradient of log(x) is computed even when x ≤ 0. | Use jnp.where(cond, safe_val, jnp.log(jnp.maximum(x, 1e-8))) |
| PRNG reuse | Using the same key twice gives identical "random" numbers. | Always jax.random.split before each use |
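The "NaN in grad" row is easy to reproduce. The sketch below uses two illustrative helper names, `unsafe_log` and `safe_log`, and applies the "double where" pattern, which is one common alternative to clamping the input: the inner where hands log a harmless value wherever the branch is not selected.

```python
import jax
import jax.numpy as jnp


def unsafe_log(x):
    # log(x) is still differentiated at x <= 0, even though where() discards its value
    return jnp.where(x > 0, jnp.log(x), 0.0)


def safe_log(x):
    # Inner where: feed log the value 1.0 wherever the branch is not selected
    safe_x = jnp.where(x > 0, x, 1.0)
    return jnp.where(x > 0, jnp.log(safe_x), 0.0)


g_unsafe = jax.grad(unsafe_log)(0.0)  # 0 * (1 / 0) inside the backward pass
g_safe = jax.grad(safe_log)(0.0)      # 0 * (1 / 1)

print("unsafe grad:", g_unsafe)
print("safe grad:  ", g_safe)
```

Both functions return the same forward value at x = 0. Only the backward pass differs: the unsafe version yields NaN, the safe version yields 0.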
jnp.where(cond, a, b) computes gradients through BOTH branches, even the one not selected. If b = jnp.log(x) and x = 0, the gradient of log(0) is infinite, even though you "selected" branch a. This is because the gradient of where is grad_a * cond + grad_b * (1-cond), and inf * 0 = NaN. The fix: make sure both branches have safe gradients everywhere, not just safe values.

Suppose a colleague writes @jax.jit def train_step(params, batch, step_num): where step_num is a Python integer that increases each step. They use it for learning rate scheduling: lr = 0.001 * min(1.0, step_num / 1000). The training is mysteriously slow (each step takes 5 seconds instead of the expected 0.5 seconds). What is the bug, and how do you fix it?

It is 10:15 AM. The inference team Slack lights up: the new 70B model does not fit on a single 80 GB A100 in FP16. At 140 GB of weights, it overflows. Two options: buy a second GPU (double the cost, plus the communication overhead of tensor parallelism), or quantize. Shrink the weights from 16 bits to 4 bits, and the model fits with room to spare for the KV cache. But will the quality hold?
This is not a theoretical exercise. Every frontier lab ships quantized models. The question is how many bits and which method. Get it wrong and your chatbot starts hallucinating. Get it right and you serve 4x more users on the same hardware.
Here is the single most important insight in LLM inference: autoregressive decoding is memory-bandwidth-bound, not compute-bound. During generation, each token requires reading the entire weight matrix once (to compute one output vector), but the arithmetic intensity is tiny — you multiply a matrix by a single vector, not a matrix by a matrix.
Let us make this concrete. An A100 GPU has:
| Resource | A100 80GB SXM | H100 80GB SXM |
|---|---|---|
| FP16 Compute | 312 TFLOPS | 989 TFLOPS |
| Memory Bandwidth | 2.0 TB/s | 3.35 TB/s |
| Arithmetic Intensity Crossover | 156 FLOPs/byte | 295 FLOPs/byte |
For a matrix-vector multiply $W x$ where W is (d, d) and x is (d, 1): you read d² · 2 bytes of weights (FP16) and perform 2d² FLOPs. The arithmetic intensity is:
$$\text{AI} = \frac{2d^2 \ \text{FLOPs}}{2d^2 \ \text{bytes}} = 1 \ \text{FLOP/byte}$$
One. One FLOP per byte. The A100 needs 156 to be compute-bound. You are 156x below the threshold. The GPU spends virtually all its time waiting for weights to arrive from HBM. This means: if you halve the weight size, you nearly double generation speed. Not because you do less math — because you move less data.
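The same arithmetic can be turned into a back-of-envelope latency bound. This is a rough sketch, not a benchmark: the helper names are made up for illustration, the bandwidth and FLOP figures are the A100 numbers from the table above, and it assumes weight reads dominate (no KV cache traffic, perfect bandwidth utilization, and ignoring that 140 GB of FP16 weights would not fit on one device).

```python
A100_BANDWIDTH = 2.0e12   # bytes/s, from the table above
A100_FP16_FLOPS = 312e12  # FLOP/s, from the table above


def matvec_arithmetic_intensity(d, bytes_per_weight):
    """FLOPs per byte for y = W x with W of shape (d, d)."""
    flops = 2 * d * d
    bytes_read = d * d * bytes_per_weight
    return flops / bytes_read


def decode_ms_per_token(n_params, bytes_per_weight, bandwidth=A100_BANDWIDTH):
    """Lower bound on per-token latency if every weight is read once."""
    return n_params * bytes_per_weight / bandwidth * 1e3


crossover = A100_FP16_FLOPS / A100_BANDWIDTH
fp16_ai = matvec_arithmetic_intensity(4096, 2.0)
fp16_ms = decode_ms_per_token(70e9, 2.0)
int4_ms = decode_ms_per_token(70e9, 0.5)

print(f"crossover: {crossover:.0f} FLOPs/byte, matvec: {fp16_ai:.1f} FLOPs/byte")
print(f"70B decode bound: FP16 {fp16_ms:.1f} ms/token, INT4 {int4_ms:.1f} ms/token")
```

Under these assumptions the bound scales linearly with bytes per weight, which is the whole argument for quantization at decode time.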
Quantization maps continuous (or high-precision) values to a smaller set of discrete levels. An FP16 weight can take ~65,536 distinct values. An INT4 weight can take only 16. The art is choosing which 16 values best represent the original distribution.
Symmetric quantization maps the range $[-\alpha, \alpha]$ to the integer range $[-2^{b-1}, 2^{b-1} - 1]$, using a single scale factor:
$$s = \frac{\alpha}{2^{b-1} - 1}, \qquad x_q = \mathrm{clamp}\left(\mathrm{round}(x / s),\ -2^{b-1},\ 2^{b-1} - 1\right)$$
Dequantization is simply $\hat{x} = x_q \cdot s$. This is fast because there is no zero-point offset, but wasteful when the distribution is asymmetric (e.g., mostly positive activations after ReLU).
Asymmetric quantization adds a zero-point $z$ to handle shifted distributions:
$$s = \frac{x_{\max} - x_{\min}}{2^{b} - 1}, \qquad z = \mathrm{round}\left(-\frac{x_{\min}}{s}\right), \qquad x_q = \mathrm{clamp}\left(\mathrm{round}(x / s) + z,\ 0,\ 2^{b} - 1\right)$$
Dequantization: $\hat{x} = (x_q - z) \cdot s$. More accurate for skewed distributions, but the extra subtraction costs a few cycles per element.
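A small NumPy experiment shows why the zero-point matters for skewed data. This is a minimal sketch: the function names are illustrative, the input is synthetic all-positive data standing in for post-ReLU activations, and both quantizers use plain min/max ranges with no calibration.

```python
import numpy as np


def fake_quant_symmetric(x, bits):
    """Quantize to a signed grid with one scale, then dequantize."""
    qmin, qmax = -2 ** (bits - 1), 2 ** (bits - 1) - 1
    s = np.abs(x).max() / qmax
    x_q = np.clip(np.round(x / s), qmin, qmax)
    return x_q * s


def fake_quant_asymmetric(x, bits):
    """Quantize to an unsigned grid with scale and zero-point, then dequantize."""
    qmax = 2 ** bits - 1
    s = (x.max() - x.min()) / qmax
    z = np.round(-x.min() / s)
    x_q = np.clip(np.round(x / s) + z, 0, qmax)
    return (x_q - z) * s


def rmse(a, b):
    return float(np.sqrt(np.mean((a - b) ** 2)))


rng = np.random.default_rng(0)
x = np.abs(rng.normal(size=10_000))  # all-positive, like post-ReLU activations

sym_err = rmse(x, fake_quant_symmetric(x, bits=4))
asym_err = rmse(x, fake_quant_asymmetric(x, bits=4))
print(f"4-bit RMSE  symmetric: {sym_err:.4f}  asymmetric: {asym_err:.4f}")
```

On all-positive data the symmetric grid spends half of its 16 levels on negative values that never occur, so its step size is roughly twice as coarse.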
The choice of how many weights share a single (s, z) pair dramatically affects accuracy:
| Granularity | What shares (s, z) | Overhead | Accuracy | When to use |
|---|---|---|---|---|
| Per-tensor | All weights in the entire matrix | 2 values total | Worst — outliers ruin scale for everyone | Almost never for weights |
| Per-channel | One row (output channel) | 2 values per row | Good — each neuron gets its own range | Standard for INT8 |
| Per-group | g consecutive weights (g=32, 64, 128) | 2 values per group | Best — fine-grained adaptation | Standard for INT4 and below |
Per-group quantization with group size g = 128 means that for a weight matrix of shape (4096, 4096), you store 4096 · (4096/128) = 131,072 scale/zero-point pairs. At FP16 each, that is 0.5 MB of overhead on top of the 8 MB of 4-bit weights (versus 32 MB in FP16), or 0.25 extra bits per weight. A small price for dramatically better accuracy.
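The overhead calculation generalizes to any group size. The helper below is a hypothetical sketch, assuming one FP16 scale and one FP16 zero-point per group and ideally packed weights. Real formats may store zero-points at lower precision.

```python
def group_quant_footprint(rows, cols, bits, group_size,
                          meta_bits=16, metas_per_group=2):
    """Return (weight_bytes, metadata_bytes, effective_bits_per_weight)."""
    n_weights = rows * cols
    n_groups = rows * (cols // group_size)
    weight_bytes = n_weights * bits / 8
    meta_bytes = n_groups * metas_per_group * meta_bits / 8
    bits_per_weight = 8 * (weight_bytes + meta_bytes) / n_weights
    return weight_bytes, meta_bytes, bits_per_weight


for g in (32, 64, 128):
    w_bytes, m_bytes, bpw = group_quant_footprint(4096, 4096, bits=4, group_size=g)
    print(f"g={g:>3}: weights {w_bytes / 2**20:.1f} MiB, "
          f"metadata {m_bytes / 2**20:.2f} MiB, {bpw:.2f} bits/weight")
```

Halving the group size doubles the metadata, which is why very small groups erode the gains from a nominal 4-bit format.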
In 2022, Tim Dettmers and collaborators made a pivotal observation: transformer activations have emergent outlier features. A tiny fraction of hidden dimensions (as few as 6 out of 4096) have magnitudes 10-100x larger than the rest. These outliers appear in every layer, in the same feature dimensions, and they are critical for model quality. Quantize them to INT8 and the model breaks. Keep them in FP16, quantize everything else to INT8, and the model is fine.
LLM.int8() implements mixed-precision decomposition:
python import torch def llm_int8_matmul(X, W, threshold=6.0): """Mixed-precision decomposition from Dettmers et al. 2022.""" # Step 1: find outlier feature dimensions outlier_mask = X.abs().max(dim=0).values > threshold # shape: (d_in,) normal_mask = ~outlier_mask # Step 2: decompose into outlier and normal parts X_outlier = X[:, outlier_mask].half() # (batch, d_outlier) in FP16 W_outlier = W[outlier_mask, :].half() # (d_outlier, d_out) in FP16 X_normal = X[:, normal_mask] # (batch, d_normal) W_normal = W[normal_mask, :] # (d_normal, d_out) # Step 3: quantize the normal part to INT8 sx = X_normal.abs().max(dim=1, keepdim=True).values / 127 sw = W_normal.abs().max(dim=0, keepdim=True).values / 127 X_int8 = (X_normal / sx).round().to(torch.int8) W_int8 = (W_normal / sw).round().to(torch.int8) # Step 4: compute both parts out_normal = (X_int8.float() @ W_int8.float()) * (sx * sw) # dequantize out_outlier = X_outlier @ W_outlier # FP16 matmul return (out_normal + out_outlier).half()
While LLM.int8() proved mixed-precision works at 8 bits, a group at Cornell led by Chris De Sa asked a bolder question: can we go to 2 bits per weight without destroying the model? The answer required ideas from information theory and lattice coding.
QuIP (2023) introduced two key ideas:
1. Incoherence processing. The problem with naive quantization is that some weight columns have much larger norms than others. If you apply the same quantization grid to all columns, the high-norm columns suffer. QuIP multiplies the weight matrix by random orthogonal matrices U and V: $\tilde{W} = U W V^T$. This "spreads" the information evenly across all entries (makes the matrix incoherent), so uniform quantization works better. You store U and V as random seeds (cheap) and undo the rotation during inference.
2. LDLQ (adaptive rounding, named for the LDL decomposition it is built on). Instead of rounding each weight independently (which ignores correlations), LDLQ uses the inverse Hessian $H^{-1} \propto (X^T X)^{-1}$ to quantize weights in an order that accounts for their correlations. The Hessian tells you which weight errors are cheap (dimensions the model rarely uses) and which are expensive (dimensions on which loss is sensitive). LDLQ performs an $LDL^T$ decomposition of the Hessian and uses it to guide a greedy per-column rounding scheme similar to GPTQ.
The objective is the layer-wise proxy loss:
$$\min_{\hat{W}} \ \operatorname{tr}\left((\hat{W} - W)\, H\, (\hat{W} - W)^T\right)$$
This is a weighted quantization error where $H = 2X^T X$ is the Hessian of the layer's squared error. LDLQ solves it approximately via the $LDL^T$ factorization of H, processing one column at a time with adaptive rounding.
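To see why a Hessian-weighted error is the right thing to minimize, the NumPy sketch below checks numerically that the weighted error with $H = 2X^TX$ is exactly twice the squared error of the layer's outputs on the calibration inputs. The shapes, the random data, and the crude round-to-0.25 quantizer are assumptions made only for illustration. This is not LDLQ itself.

```python
import numpy as np

rng = np.random.default_rng(0)
X = rng.normal(size=(256, 16))       # calibration activations: (n_samples, d_in)
W = rng.normal(size=(8, 16)) * 0.1   # layer weights: (d_out, d_in)
W_hat = np.round(W * 4) / 4          # stand-in quantizer: nearest multiple of 0.25

H = 2 * X.T @ X                      # Hessian of the layer's squared error
delta = W_hat - W

proxy_loss = np.trace(delta @ H @ delta.T)
output_error = np.sum((X @ delta.T) ** 2)   # ||X (W_hat - W)^T||_F^2

print(f"proxy loss: {proxy_loss:.4f}, 2 x output error: {2 * output_error:.4f}")
```

Because the two quantities agree, any rounding scheme that lowers the weighted error lowers the layer's actual output error, and directions in which the activations have little energy can absorb large weight errors almost for free.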
QuIP# (2024) made two improvements that turned QuIP from a proof of concept into a practical method:
1. Hadamard rotations instead of random orthogonal matrices. A random orthogonal matrix multiply costs O(d²). A Hadamard transform costs O(d log d) and achieves the same incoherence property. This is because Hadamard matrices have the special property that all entries have the same magnitude (1/√d), so every output coordinate mixes all input coordinates equally. QuIP# replaces the expensive random rotations with fast Walsh-Hadamard transforms, a large speedup at inference time.
2. E8 lattice codebook. Instead of rounding to a uniform integer grid, QuIP# rounds to the nearest point in the E8 lattice, the densest sphere packing in 8 dimensions. Why does this matter? The quantization error of a codebook depends on how efficiently it tiles space. A uniform grid in 8D wastes space in the corners of each hypercube. The E8 lattice packs more representable points into the same volume, reducing the average distance from any real weight vector to its nearest codebook entry. At 2 bits per weight, QuIP# with the E8 lattice achieves perplexity close to what older methods achieve at 3 bits.
```python
# Conceptual: E8 lattice quantization in 8D
# The E8 lattice consists of all vectors in R^8 whose coordinates
# are either all integers or all half-integers, and whose sum is even.
import numpy as np

def nearest_e8(x):
    """Find the nearest E8 lattice point to vector x in R^8."""
    x = np.asarray(x, dtype=float)

    # Candidate 1: round to integers
    c1 = np.round(x)
    if np.sum(c1) % 2 != 0:  # enforce even sum
        # Re-round the entry with the LARGEST rounding error (the one
        # closest to a rounding boundary): cheapest way to fix the parity
        idx = np.argmax(np.abs(x - c1))
        c1[idx] += 1 if x[idx] > c1[idx] else -1

    # Candidate 2: round to half-integers
    c2 = np.round(x - 0.5) + 0.5
    if np.sum(c2) % 2 != 0:
        idx = np.argmax(np.abs(x - c2))
        c2[idx] += 1 if x[idx] > c2[idx] else -1

    # Pick the closer candidate
    d1 = np.sum((x - c1) ** 2)
    d2 = np.sum((x - c2) ** 2)
    return c1 if d1 <= d2 else c2
```
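Returning to the first improvement, the effect of a Hadamard rotation on an outlier can be sketched in a few lines of NumPy. This is an illustrative, unoptimized butterfly implementation under the name `fwht` (my own, not from any library), and it omits the random sign flips that practical schemes typically apply before the transform.

```python
import numpy as np


def fwht(x):
    """Orthonormal fast Walsh-Hadamard transform of a 1-D array.

    Runs in O(d log d); the length must be a power of two.
    """
    x = np.asarray(x, dtype=np.float64).copy()
    n = x.shape[0]
    assert n > 0 and (n & (n - 1)) == 0, "length must be a power of two"
    h = 1
    while h < n:
        for start in range(0, n, 2 * h):
            a = x[start:start + h].copy()
            b = x[start + h:start + 2 * h].copy()
            x[start:start + h] = a + b           # butterfly: sum
            x[start + h:start + 2 * h] = a - b   # butterfly: difference
        h *= 2
    return x / np.sqrt(n)  # scale so the transform preserves the norm


w = np.zeros(1024)
w[42] = 10.0            # a single large outlier
w_rot = fwht(w)

print("max |w| before:", np.abs(w).max())
print("max |w| after: ", np.abs(w_rot).max())
print("norm before/after:", np.linalg.norm(w), np.linalg.norm(w_rot))
```

The outlier's energy is spread evenly over all 1024 coordinates, so a uniform grid no longer has to stretch to cover one extreme value. Applying `fwht` a second time recovers the original vector, which is what makes the rotation cheap to undo at inference.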
QTIP (2024) pushed even further. Instead of the E8 lattice (which is optimal in 8D but doesn't generalize easily), QTIP uses trellis codes — a technique borrowed from telecommunications. A trellis code defines a state machine where each transition corresponds to a quantized symbol. The Viterbi algorithm finds the sequence of quantized weights that minimizes total error while respecting the code constraints. This is analogous to how your cell phone decodes noisy signals, but applied to weight quantization.
The key advantage of trellis codes: they naturally capture sequential dependencies between adjacent weights. If weight[i] was rounded up, the trellis can compensate by rounding weight[i+1] down. This error diffusion is automatic and optimal (in the Viterbi sense).
AQLM (2024) takes a completely different approach: multi-codebook additive quantization. Instead of quantizing each weight independently, AQLM groups weights into vectors of size g (typically 8) and represents each vector as the sum of entries from M learned codebooks:
$$\mathbf{w} \approx \sum_{m=1}^{M} C_m[i_m]$$
Each codebook has $2^B$ entries (e.g., B=8 gives 256 entries per codebook). With M=2 codebooks, the total number of representable vectors is 256 × 256 = 65,536, far more than the 16 levels of straight INT4. The indices $(i_1, i_2)$ are stored as integers. The codes are selected with beam search, and the codebooks are learned on calibration data.
The effective bit rate is (M · B) / g bits per weight. With M=2, B=8, g=8: (2 · 8) / 8 = 2 bits per weight. But the representational power is vastly greater than naive 2-bit quantization because the codebook entries are learned from the actual weight distribution.
```python
# AQLM dequantization (simplified)
import torch

def aqlm_dequant(indices, codebooks, group_size=8):
    """
    indices:   (n_groups, M) integer tensor of codebook indices
    codebooks: list of M tensors, each (2^B, group_size)
    Returns:   (n_groups * group_size,) dequantized weights
    """
    M = len(codebooks)
    n_groups = indices.shape[0]
    result = torch.zeros(n_groups, group_size)
    for m in range(M):
        # Look up one entry per group in codebook m and accumulate
        result += codebooks[m][indices[:, m]]  # additive: sum across codebooks
    return result.reshape(-1)
```
Each paper in the quantization lineage solved a specific limitation of its predecessor:
| Method | Year | Bits | Key Idea | Llama-2 70B PPL (WikiText) | Speed vs FP16 |
|---|---|---|---|---|---|
| FP16 baseline | — | 16 | — | 3.32 | 1.0x |
| LLM.int8() | 2022 | 8 (mixed) | Outlier decomposition | 3.33 | ~1.0x (overhead from decomposition) |
| GPTQ | 2022 | 4 | OBS-based layer-wise quant | 3.85 | ~3.2x |
| AWQ | 2023 | 4 | Activation-aware scaling | 3.62 | ~3.4x |
| QuIP | 2023 | 2-4 | Incoherence + LDLQ | 5.40 (2-bit) | ~1.5x (slow rotations) |
| QuIP# | 2024 | 2 | Hadamard + E8 lattice | 4.16 | ~3.0x |
| AQLM | 2024 | 2 | Multi-codebook additive | 4.21 | ~2.5x (lookup overhead) |
| QTIP | 2024 | 2 | Trellis codes + Viterbi | 3.94 | ~2.8x |
[Interactive figure: a transformer weight distribution with outlier features (red spikes) and the quantization grid at an adjustable bit width. Outliers distort the grid at low bit widths.]
When a new model lands on your desk and you need to deploy it, here is the decision process a frontier lab engineer follows:
```python
import torch

def quantize_per_group(W, bits=4, group_size=128):
    """
    Per-group asymmetric quantization.
    W: (out_features, in_features) weight tensor
    Returns: quantized weights, scales, zero_points
    """
    out_f, in_f = W.shape
    assert in_f % group_size == 0, "in_features must be divisible by group_size"
    n_groups = in_f // group_size

    # Reshape to (out_features * n_groups, group_size)
    W_grouped = W.reshape(out_f * n_groups, group_size)

    # Compute per-group min/max
    w_min = W_grouped.min(dim=1).values  # (out_f * n_groups,)
    w_max = W_grouped.max(dim=1).values

    # Compute scale and zero-point
    q_max = 2 ** bits - 1
    scale = (w_max - w_min) / q_max
    scale = torch.clamp(scale, min=1e-8)  # avoid division by zero
    # uint8 covers zero-points up to 255, so bits=8 does not overflow
    zero_point = torch.round(-w_min / scale).clamp(0, q_max).to(torch.uint8)

    # Quantize
    W_q = torch.round(W_grouped / scale.unsqueeze(1)) + zero_point.unsqueeze(1)
    W_q = W_q.clamp(0, q_max).to(torch.uint8)

    # Verify: dequantize and check error
    W_deq = (W_q.float() - zero_point.unsqueeze(1).float()) * scale.unsqueeze(1)
    error = (W_grouped - W_deq).pow(2).mean().sqrt()
    print(f"RMSE: {error:.6f}, relative: {error / W_grouped.abs().mean():.4f}")

    return (W_q.reshape(out_f, in_f),
            scale.reshape(out_f, n_groups),
            zero_point.reshape(out_f, n_groups))

# Example: quantize a 4096x4096 matrix to 4-bit with group_size=128
W = torch.randn(4096, 4096) * 0.02  # typical init scale
# Inject outlier columns to stress the groups that contain them
W[:, 42] *= 15    # feature 42 is an outlier
W[:, 1337] *= 12  # feature 1337 is an outlier

W_q, scales, zps = quantize_per_group(W, bits=4, group_size=128)
# Prints the RMSE and relative error; exact values depend on the random draw.
# Memory if the 4-bit codes are packed two per byte: 4096*4096*0.5 bytes
# ≈ 8.4 MB, plus scales and zero-points, vs 33.6 MB in FP16.
# (W_q above is stored unpacked, one code per uint8.)
```
It is 11:30 AM. You are reviewing a profiling trace from the training cluster. Attention layers consume 42% of wall-clock time and, more importantly, 68% of peak HBM usage. The model processes sequences of length 8192, which means the attention matrix is 8192 × 8192 = 67 million entries per head, per layer. With 32 heads across 80 layers, that is 171 billion intermediate values materialized in memory. Your hardware budget says you need 32K context. Quadrupling the sequence length quadruples the Q, K, V memory but grows the attention matrix sixteenfold. Something has to give.
This is the problem FlashAttention solves. Not by changing the math (it computes exact attention, and the output matches standard attention up to floating-point rounding) but by changing how the math is executed on the hardware.
The textbook attention computation proceeds in three steps:
$$S = \frac{QK^T}{\sqrt{d}}, \qquad P = \mathrm{softmax}(S), \qquad O = PV$$
Where Q, K, V have shape (N, d) — N tokens, d head dimension (typically 64 or 128). The problem is the intermediate matrix S (and P), which has shape (N, N). Let us trace the memory access pattern:
```python
# Standard attention: what actually happens on the GPU
import math
import torch

# Small illustrative sizes so the snippet runs anywhere
N, d = 1024, 64
Q, K, V = torch.randn(N, d), torch.randn(N, d), torch.randn(N, d)

# Step 1: S = Q @ K.T / sqrt(d)
#   Read Q (N×d) from HBM, read K (N×d) from HBM
#   Compute matmul (fast on tensor cores)
#   Write S (N×N) to HBM  ← THE BOTTLENECK
S = Q @ K.T / math.sqrt(d)    # O(N²d) FLOPs, O(N²) HBM writes

# Step 2: P = softmax(S, dim=-1)
#   Read S (N×N) from HBM
#   Compute softmax (row-wise, cheap)
#   Write P (N×N) to HBM
P = torch.softmax(S, dim=-1)  # O(N²) FLOPs, O(N²) HBM reads + writes

# Step 3: O = P @ V
#   Read P (N×N) from HBM, read V (N×d) from HBM
#   Compute matmul
#   Write O (N×d) to HBM
O = P @ V                     # O(N²d) FLOPs, O(N²) HBM reads
```
Total HBM accesses: O(N²) reads and writes for S and P. For N = 8192, d = 128: each of the two matmuls costs 8192² × 128 × 2 ≈ 17.2 billion FLOPs (trivial for a GPU). But reading and writing S and P requires 8192² × 2 bytes × 2 (read + write) × 2 (for S and P) ≈ 0.5 GB of memory traffic. Per head. Per layer. This is the problem.
Tri Dao's key insight: never materialize the N × N matrix. Instead, compute attention in tiles that fit in SRAM (the GPU's fast on-chip memory, ~20 MB on A100 vs 80 GB HBM). Each tile computes a partial result and accumulates it, without ever writing the full S or P to HBM.
The challenge: softmax is a non-decomposable operation. To compute softmax(S[i, :]), you need max(S[i, :]) and sum(exp(S[i, :] - max)), which requires the entire row. If you are processing the row in blocks, you do not have the full row available. You need the online softmax trick.
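Standard softmax of row s = [s_1, ..., s_N]:

$$m = \max_i s_i, \qquad \ell = \sum_{i=1}^{N} e^{s_i - m}, \qquad \operatorname{softmax}(s)_i = \frac{e^{s_i - m}}{\ell}$$

Now suppose you have processed the first block of B columns and computed:

$$m^{(1)} = \max(s_1, \ldots, s_B), \qquad \ell^{(1)} = \sum_{i=1}^{B} e^{s_i - m^{(1)}}, \qquad O^{(1)} = \sum_{i=1}^{B} e^{s_i - m^{(1)}}\, v_i$$

Then you process the second block and find a new maximum m^(2) = max(m^(1), max(s_{B+1}, ..., s_{2B})). The old partial sum ℓ^(1) was computed with respect to m^(1), not m^(2). We need to rescale:

$$\ell^{(2)} = e^{m^{(1)} - m^{(2)}}\, \ell^{(1)} + \sum_{i=B+1}^{2B} e^{s_i - m^{(2)}}$$

And the partial output O^(1) (which was already multiplied by the old softmax weights) must also be rescaled:

$$O^{(2)} = e^{m^{(1)} - m^{(2)}}\, O^{(1)} + \sum_{i=B+1}^{2B} e^{s_i - m^{(2)}}\, v_i$$

Here O is kept unnormalized; the division by ℓ happens once, after the last block.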
This is the online softmax algorithm: maintain a running (m, ℓ, O) triple, and after each new block, rescale the accumulated result by exp(m_old - m_new). The final result is mathematically identical to computing softmax on the full row.
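To see the running (m, ℓ, O) update in isolation, here is a minimal NumPy sketch for a single query row. The function name `online_softmax_weighted_sum` and the block size are illustrative choices, not taken from any library.

```python
import numpy as np


def online_softmax_weighted_sum(scores, values, block_size=4):
    """Return softmax(scores) @ values, reading `scores` one block at a time."""
    running_max = -np.inf                    # m
    running_sum = 0.0                        # l
    acc = np.zeros(values.shape[1])          # unnormalized output O
    for start in range(0, len(scores), block_size):
        s_blk = scores[start:start + block_size]
        v_blk = values[start:start + block_size]
        new_max = max(running_max, s_blk.max())
        scale = np.exp(running_max - new_max)    # exp(m_old - m_new); 0 on the first block
        p_blk = np.exp(s_blk - new_max)
        running_sum = running_sum * scale + p_blk.sum()
        acc = acc * scale + p_blk @ v_blk
        running_max = new_max
    return acc / running_sum                 # normalize once at the end


rng = np.random.default_rng(0)
scores = rng.normal(size=10) * 5
values = rng.normal(size=(10, 3))

weights = np.exp(scores - scores.max())
reference = (weights / weights.sum()) @ values
result = online_softmax_weighted_sum(scores, values)
print(np.allclose(result, reference))    # True
```

The block size does not affect the result, only how much of the row has to be resident at once.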
```python
import torch


def flash_attention_forward(Q, K, V, block_size=64):
    """
    Simplified FlashAttention-1 forward pass.
    Q, K, V: (N, d) tensors
    Returns O: (N, d), same as standard attention
    """
    N, d = Q.shape
    O = torch.zeros_like(Q)
    m = Q.new_full((N,), float('-inf'))  # running max per query
    l = Q.new_zeros(N)                   # running sum per query

    # Iterate over KV blocks
    for j in range(0, N, block_size):
        j_end = min(j + block_size, N)
        K_block = K[j:j_end]  # (B, d)
        V_block = V[j:j_end]  # (B, d)

        # Iterate over Q blocks (in practice, both loops are tiled)
        for i in range(0, N, block_size):
            i_end = min(i + block_size, N)
            Q_block = Q[i:i_end]  # (B, d)

            # Compute tile of S = Q_block @ K_block.T / sqrt(d)
            S_tile = Q_block @ K_block.T / (d ** 0.5)  # (B, B), fits in SRAM!

            # Online softmax: update running stats
            m_new = torch.max(m[i:i_end], S_tile.max(dim=-1).values)
            correction = torch.exp(m[i:i_end] - m_new)
            P_tile = torch.exp(S_tile - m_new.unsqueeze(-1))
            l_new = l[i:i_end] * correction + P_tile.sum(dim=-1)

            # Rescale accumulated output and add new contribution
            O[i:i_end] = O[i:i_end] * correction.unsqueeze(-1) + P_tile @ V_block

            # Update running stats
            m[i:i_end] = m_new
            l[i:i_end] = l_new

    # Final normalization
    O = O / l.unsqueeze(-1)
    return O
```
Let us be precise about the IO complexity. Define M = SRAM size in elements.
| Algorithm | FLOPs | HBM Reads/Writes | Extra HBM for Intermediates |
|---|---|---|---|
| Standard Attention | O(N²d) | O(N² + Nd) | O(N²) — store S and P |
| FlashAttention-1 | O(N²d) | O(N²d² / M) | O(N) — only m and ℓ vectors |
Same FLOPs. Dramatically less memory traffic. And the intermediate memory drops from O(N²) to O(N) — this is what enables 32K, 64K, 128K context lengths.
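To get a feel for the two HBM-access expressions in the table, the following sketch plugs in numbers. It drops all constant factors, so only the ratio is meaningful. The SRAM size of 96K elements (192 KB of FP16 per SM) is an assumption, and `hbm_accesses` is a hypothetical helper name.

```python
def hbm_accesses(n_tokens, head_dim, sram_elems):
    """Asymptotic HBM element accesses, constants dropped."""
    standard = n_tokens**2 + n_tokens * head_dim        # O(N^2 + N d)
    flash = n_tokens**2 * head_dim**2 / sram_elems      # O(N^2 d^2 / M)
    return standard, flash


# Assumed: 192 KB of on-chip memory per SM holding 2-byte elements.
SRAM_ELEMS = 96 * 1024

standard, flash = hbm_accesses(8192, 128, SRAM_ELEMS)
ratio = standard / flash
print(f"standard / flash = {ratio:.1f}x fewer HBM accesses")
```

The saving comes from d² being several times smaller than M for typical head dimensions; it shrinks as the head dimension grows or the usable SRAM shrinks.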
FlashAttention-1 was IO-optimal but left GPU SMs underutilized. FA-2 made three key changes:
1. Swap the loop order. FA-1 iterates over KV blocks in the outer loop and Q blocks in the inner loop. FA-2 reverses this: iterate over Q blocks in the outer loop. Why? Each Q block now runs on a separate thread block, and the inner KV loop is serial within that block. This means the number of thread blocks per head is N / B_Q (where B_Q is the Q block size), which together with the batch and head dimensions is usually much larger than the number of SMs, giving the GPU scheduler plenty to work with.
2. Reduce non-matmul FLOPs. On modern GPUs (A100, H100), matmul operations run on specialized tensor cores at 10-16x the throughput of general-purpose FP32 ops. The rescaling in online softmax uses exp() and division, which run on the slower CUDA cores. FA-2 carefully restructures the computation to minimize these non-matmul operations, achieving ~70% of theoretical peak matmul throughput (vs ~35% for FA-1).
3. Work partitioning within warps. FA-2 splits Q among warps but shares K/V across warps. This minimizes communication between warps (no cross-warp reduction needed for the softmax statistics) while maximizing the reuse of K/V from shared memory.
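The loop-order change in point 1 can be sketched in NumPy. This is an illustration under simplifying assumptions (single head, no causal mask, no warps or shared memory); `attention_q_outer` is a hypothetical name. Each iteration of the outer loop is independent, which is what lets a real kernel map it to its own thread block, and the division by ℓ happens once per Q block rather than inside the inner loop.

```python
import numpy as np


def attention_q_outer(Q, K, V, block_q=4, block_kv=4):
    """Tiled attention with Q blocks in the outer loop (FA-2 style ordering)."""
    n, d = Q.shape
    out = np.empty((n, V.shape[1]))
    for i in range(0, n, block_q):            # independent: one thread block each
        q_blk = Q[i:i + block_q]
        rows = q_blk.shape[0]
        m = np.full(rows, -np.inf)            # running row max
        l = np.zeros(rows)                    # running row sum
        acc = np.zeros((rows, V.shape[1]))    # unnormalized output
        for j in range(0, n, block_kv):       # serial loop over K/V blocks
            s = q_blk @ K[j:j + block_kv].T / np.sqrt(d)
            m_new = np.maximum(m, s.max(axis=1))
            scale = np.exp(m - m_new)
            p = np.exp(s - m_new[:, None])
            l = l * scale + p.sum(axis=1)
            acc = acc * scale[:, None] + p @ V[j:j + block_kv]
            m = m_new
        out[i:i + block_q] = acc / l[:, None]  # single normalization per Q block
    return out


rng = np.random.default_rng(0)
Q, K, V = (rng.normal(size=(10, 8)) for _ in range(3))

S = Q @ K.T / np.sqrt(8)
P = np.exp(S - S.max(axis=1, keepdims=True))
reference = (P / P.sum(axis=1, keepdims=True)) @ V
print(np.allclose(attention_q_outer(Q, K, V), reference))    # True
```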
The H100 Hopper GPU introduced three hardware features that FA-3 exploits:
| H100 Feature | What It Does | How FA-3 Uses It |
|---|---|---|
| WGMMA | Warp Group Matrix Multiply: issues asynchronous matmul instructions across 4 warps (128 threads) | Larger tiles, fewer instructions. A single WGMMA instruction covers a tile of up to 64×256×16 (FP16/BF16), versus 16×8×16 for a per-warp MMA instruction on A100. |
| TMA | Tensor Memory Accelerator — hardware unit for async memory copies | Overlaps memory loads with computation. While WGMMA computes on tile N, TMA prefetches tile N+1 from HBM to shared memory. |
| FP8 Tensor Cores | Native 8-bit floating-point matmul at 2x FP16 throughput | FA-3 supports FP8 attention: Q and K in FP8, with careful handling of the softmax numerics (still in FP32 for stability). |
Warp specialization. FA-3 assigns different roles to different warp groups: "producer" warps issue TMA loads, "consumer" warps issue WGMMA instructions. They communicate through shared memory barriers. This is a software pipeline: while consumers compute on data that producers loaded in the previous step, producers are simultaneously loading the next batch. The result: memory latency is almost entirely hidden behind compute.
pseudocode // FA-3 warp-specialized pipeline (conceptual) // Two warp groups: PRODUCER and CONSUMER // PRODUCER warp group: for tile in kv_tiles: TMA_async_load(K_tile[next], smem_K[next_buf]) // hardware-accelerated TMA_async_load(V_tile[next], smem_V[next_buf]) signal(barrier[next_buf]) // tell consumer data is ready wait(barrier[curr_buf]) // wait for consumer to finish swap(curr_buf, next_buf) // CONSUMER warp group: for tile in kv_tiles: wait(barrier[curr_buf]) // wait for producer to load S = WGMMA(Q_reg, smem_K[curr_buf]) // 128×128×64 matmul P = online_softmax(S, &m, &ℓ) // on CUDA cores O += WGMMA(P, smem_V[curr_buf]) // accumulate output signal(barrier[curr_buf]) // tell producer buffer is free swap(curr_buf, next_buf)
The B200 Blackwell GPU doubles down on the trends from Hopper, and FA-4 exploits new capabilities:
CuTe DSL. FA-4 is written using NVIDIA's CuTe abstraction, which represents tiles as first-class objects with layout algebra. Instead of manually indexing into shared memory with pointer arithmetic, you declare a tile's layout (e.g., "128 rows × 64 cols, row-major, swizzled by 128 bytes") and CuTe generates the correct indexing code. This eliminates an entire class of shared memory bank conflict bugs.
Pingpong scheduling. Blackwell's shared memory is the same size as Hopper's (228 KB per SM), but with a new partitioned structure. FA-4 uses a "pingpong" pattern where two warp groups alternate between computing and loading, with triple-buffering to keep both the compute pipeline and the memory pipeline busy.
FP4 support. Blackwell's tensor cores natively support FP4 matmul. FA-4 can run the Q×Kᵀ multiply in FP4 (with FP32 accumulation), achieving up to 4x the throughput of FP16 at the cost of some precision. For inference with quantized models, this is a game-changer.
| Version | Year | GPU Target | Key Innovation | % of Peak FLOPs | Speedup vs Standard |
|---|---|---|---|---|---|
| Standard | — | Any | Baseline (materializes N×N) | ~15% | 1.0x |
| FA-1 | 2022 | A100 | Tiling + online softmax | ~35% | 2-4x |
| FA-2 | 2023 | A100/H100 | Q-outer loop, warp partitioning | ~70% | 3-5x |
| FA-3 | 2024 | H100 Hopper | WGMMA + TMA + warp specialization + FP8 | ~80% | 1.5-2x over FA-2 on H100 |
| FA-4 | 2025 | B200 Blackwell | CuTe DSL + pingpong scheduling + FP4 | ~85% | ~2x over FA-3 on B200 |
Figure: FlashAttention processes the Q×Kᵀ matrix in tiles without materializing the full N×N matrix. The highlighted tile is the current computation in SRAM. Standard attention would need the entire grid in HBM.
The forward pass avoids materializing S and P, but the backward pass needs them for the gradient computation. FlashAttention solves this by recomputing S and P from Q, K, V during the backward pass, which trades extra FLOPs for HBM savings. Since attention is memory-bound, the extra FLOPs cost far less than the HBM traffic they replace, so the backward pass is still faster than the standard one. To make the recomputation cheap, the forward pass saves a small auxiliary tensor of softmax statistics: FA-1 stores the row max m and row sum ℓ, and from FA-2 onward only the log-sum-exp per row (size N) is kept.
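The recomputation idea can be sketched in a few lines of NumPy: the forward pass returns only O and the per-row log-sum-exp, and the backward pass rebuilds P from Q, K and that vector. The function names are hypothetical, and a real kernel would do both steps tile by tile rather than on the full matrix.

```python
import numpy as np


def forward_with_lse(Q, K, V):
    """Attention forward that keeps only O and the per-row log-sum-exp."""
    s = Q @ K.T / np.sqrt(Q.shape[1])
    row_max = s.max(axis=1, keepdims=True)
    lse = row_max[:, 0] + np.log(np.exp(s - row_max).sum(axis=1))
    p = np.exp(s - lse[:, None])          # softmax(S); discarded after use
    return p @ V, lse


def recompute_probs(Q, K, lse):
    """Backward-pass helper: rebuild P from Q, K and the saved log-sum-exp."""
    s = Q @ K.T / np.sqrt(Q.shape[1])
    return np.exp(s - lse[:, None])


rng = np.random.default_rng(0)
Q, K, V = (rng.normal(size=(6, 4)) for _ in range(3))
O, lse = forward_with_lse(Q, K, V)
P = recompute_probs(Q, K, lse)
print(lse.shape, np.allclose(P.sum(axis=1), 1.0), np.allclose(P @ V, O))
```

The saved state is N numbers per head instead of N², and the recomputed P is exactly the softmax used in the forward pass.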
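```python
# Using FlashAttention in practice
import torch
import torch.nn.functional as F
# torch.nn.attention is not present in PyTorch 2.0; it needs a newer release
from torch.nn.attention import SDPBackend, sdpa_kernel

# Example inputs in SDPA layout: (batch, nheads, seqlen, headdim)
Q, K, V = (torch.randn(1, 32, 4096, 128, device="cuda", dtype=torch.float16)
           for _ in range(3))

# Option 1: PyTorch native SDPA, restricted to the FlashAttention backend
with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
    out = F.scaled_dot_product_attention(Q, K, V, is_causal=True)

# Option 2: Direct flash-attn library
from flash_attn import flash_attn_func
# flash-attn expects (batch, seqlen, nheads, headdim), so swap the head and sequence axes
q, k, v = (t.transpose(1, 2) for t in (Q, K, V))
out = flash_attn_func(q, k, v, causal=True)

# Option 3: FA-3 with FP8 on H100
# (assumes an FA-3 Hopper build whose flash_attn_func accepts FP8 inputs;
#  check the installed version before relying on this call)
out = flash_attn_func(
    q.to(torch.float8_e4m3fn),  # FP8 E4M3 format
    k.to(torch.float8_e4m3fn),
    v.to(torch.float8_e4m3fn),
    causal=True,
    softmax_scale=1.0 / (q.shape[-1] ** 0.5),
)

# Memory comparison for sequence length 32768, 32 heads, dim 128:
# Standard:  32768² × 32 × 2 bytes = 64 GB (!), which does not fit on a single GPU
# FlashAttn: 32768 × 32 × 4 bytes = 4 MB (just the log-sum-exp vector)
```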
It is 1:15 PM, back from lunch. A Slack thread is debating which framework to use for a custom fused kernel: someone advocates Triton, another insists on CUTLASS, a third says ThunderKittens is the only way to get 80%+ utilization on Hopper. The team lead posts a screenshot of a timeline from an AI infrastructure newsletter showing eleven new kernel DSLs released in the past twelve months. "We need a decision by Friday," she writes.
This explosion is not arbitrary. It reflects a fundamental tension in GPU programming that every frontier lab grapples with: productivity vs performance. Writing CUDA by hand gives you total control over shared memory layout, warp scheduling, and instruction selection. But a single fused attention kernel can take 2,000+ lines of CUDA and weeks to debug. Triton gives you 100 lines of Python for the same kernel, but the compiler might leave 30% of the hardware idle because it made suboptimal scheduling choices. Every DSL in the 2025 explosion stakes out a different position on this tradeoff.
Let us map the landscape. At one end: raw CUDA/PTX, where you control every register and every instruction. At the other: PyTorch's autograd, where you write math and hope the compiler figures out the rest.
| Level | Tool | Abstraction | Lines for Fused GEMM+ReLU | % Peak FLOPS (A100) |
|---|---|---|---|---|
| 0: Hardware | PTX / SASS | Individual instructions | ~3000 | 95%+ |
| 1: Low-level | CUDA | Threads, warps, shared memory | ~800 | 85-90% |
| 2: Template | CUTLASS / CuTe | Tiles, layouts, MMA descriptors | ~200 | 80-90% |
| 3: Register-tile | ThunderKittens | 16×16 register tiles as objects | ~100 | 75-85% |
| 4: Compiler | Triton | Block-level programs | ~50 | 65-80% |
| 5: Auto-schedule | TileLang, Helion | Declarative tiles + auto-tuning | ~30 | 60-75% |
| 6: Graph | PyTorch compile | Math expressions | ~5 | 40-60% |
Every kernel ultimately compiles to CUDA (or rather, to PTX and then SASS). Understanding the CUDA programming model is non-negotiable for a frontier lab engineer, even if you rarely write raw CUDA yourself.
The key abstractions: threads execute scalar operations. Warps (32 threads) execute in lockstep (SIMT). Thread blocks (up to 1024 threads) share SRAM. Grids of thread blocks fill the GPU. The programmer must manually manage:
cuda // CUDA GEMM kernel — simplified, no tiling optimizations // This is what you're trying to AVOID writing by using a DSL __global__ void matmul_naive(float* A, float* B, float* C, int M, int N, int K) { int row = blockIdx.y * blockDim.y + threadIdx.y; int col = blockIdx.x * blockDim.x + threadIdx.x; if (row < M && col < N) { float sum = 0.0f; for (int k = 0; k < K; k++) { sum += A[row * K + k] * B[k * N + col]; } C[row * N + col] = sum; } } // Performance: ~5% of peak on A100. Why? No shared memory, // no tiling, no tensor core usage, terrible memory access pattern. // A production GEMM kernel is 50-100x more complex.
Triton (by OpenAI, first released 2021, mature by 2023) lets you write GPU kernels in Python. The key abstraction is the block program: your code operates on blocks (tiles) of data, and the Triton compiler handles the mapping to warps, shared memory, and tensor cores.
```python
import triton
import triton.language as tl


@triton.jit
def matmul_kernel(
    A_ptr, B_ptr, C_ptr,
    M, N, K,
    stride_am, stride_ak,
    stride_bk, stride_bn,
    stride_cm, stride_cn,
    BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
):
    # Which tile of C does this program compute?
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)

    # Compute offsets for this tile
    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    offs_k = tl.arange(0, BLOCK_K)

    # Pointers to first tile of A and B
    a_ptrs = A_ptr + offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak
    b_ptrs = B_ptr + offs_k[:, None] * stride_bk + offs_n[None, :] * stride_bn

    # Accumulate in FP32 for numerical stability
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)

    # Loop over K dimension in blocks
    for k in range(0, K, BLOCK_K):
        # Masked-out elements are filled with 0 so they do not affect the dot product
        a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k, other=0.0)
        b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k, other=0.0)
        acc += tl.dot(a, b)  # Triton automatically uses tensor cores!
        a_ptrs += BLOCK_K * stride_ak
        b_ptrs += BLOCK_K * stride_bk

    # Write result
    c_ptrs = C_ptr + offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn
    tl.store(c_ptrs, acc, mask=(offs_m[:, None] < M) & (offs_n[None, :] < N))
```
What the compiler does for you: shared memory allocation, double-buffering (loading the next tile while computing the current one), warp scheduling, tensor core instruction selection, predicated loads for boundary tiles. What it does NOT do well: cross-SM communication, complex warp-level synchronization, fine-grained instruction scheduling (which is why hand-tuned CUDA can still beat Triton by 10-20% on specific kernels).
CUTLASS (NVIDIA) is a C++ template library for writing high-performance GPU kernels. Its modern incarnation, CuTe, introduces layout algebra — a type system for describing how data is arranged in memory. A layout is a function from logical coordinates to physical memory addresses.
```c++
// CuTe layout example: a 128×64 tile in shared memory
// with a swizzle to avoid bank conflicts
using SmemLayout = Layout<Shape<_128, _64>, Stride<_64, _1>>;

// A swizzled version: XOR low bits of the row index into the column index.
// This avoids bank conflicts when threads walk down a column.
using SmemLayoutSwizzled = decltype(composition(
    Swizzle<3, 3, 3>{},   // XOR 3 row bits into 3 column bits, in units of 8 elements
    SmemLayout{}
));

// The layout algebra handles indexing automatically:
// smem(row, col) returns the correct physical address
// even with the swizzle, so no manual index math is needed.
```
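The idea that a layout is just a function from (row, col) to an offset, and that a swizzle is an XOR applied to that offset, can be modeled in plain Python. This is a toy model, not CuTe: the helper names are hypothetical, and the bank model (32 banks of 4 bytes, 2-byte elements) is an assumption about typical NVIDIA shared memory.

```python
def row_major_offset(row, col, n_cols=64):
    """Plain row-major layout: (row, col) -> element offset."""
    return row * n_cols + col


def swizzled_offset(row, col, n_cols=64, chunk=8, n_chunks=8):
    """Row-major layout with an XOR swizzle on 8-element chunks."""
    # XOR the chunk index within the row with the low 3 bits of the row index.
    chunk_idx = (col // chunk) ^ (row % n_chunks)
    return row * n_cols + chunk_idx * chunk + (col % chunk)


def bank_of(offset, elem_bytes=2, n_banks=32, bank_bytes=4):
    """Shared-memory bank hit by an element offset (assumed bank model)."""
    return (offset * elem_bytes // bank_bytes) % n_banks


# Eight threads reading column 0 of rows 0..7:
row_major_banks = {bank_of(row_major_offset(r, 0)) for r in range(8)}
swizzled_banks = {bank_of(swizzled_offset(r, 0)) for r in range(8)}
print(len(row_major_banks), len(swizzled_banks))    # 1 8
```

With the plain layout every row of a column lands in the same bank, so the eight accesses serialize. With the swizzle they spread across eight different banks, while each row still holds the same 64 elements, only in a permuted chunk order.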
CUTLASS comes from the same NVIDIA teams that build cuBLAS, and its kernels reach comparable performance. If you need absolute peak performance on NVIDIA hardware and are willing to invest in C++ template wizardry, it is the tool.
ThunderKittens (Stanford, 2024) takes a radical position: the right abstraction level is the register tile — a 16×16 matrix that lives in a warp's registers and maps directly to a single tensor core MMA instruction.
c++ // ThunderKittens: the entire abstraction in a few types using namespace kittens; // Register tiles: 16×16 matrices that live in warp registers rt_fl<16, 16> a_tile, b_tile, c_tile; // float register tiles rt_bf<16, 16> a_bf, b_bf; // bfloat16 register tiles // Shared memory tiles st_bf<64, 64> smem_a, smem_b; // shared memory tiles // Operations map directly to hardware mma_ABt(c_tile, a_bf, b_bf, c_tile); // tensor core MMA: C += A @ B^T load(a_bf, smem_a, row_offset, col_offset);// load from shared to register store(smem_a, a_bf, row_offset, col_offset);// store from register to shared
The philosophy: "If it's not a 16×16 tile, you're not thinking at the right level." ThunderKittens forces you to decompose your algorithm into register-tile operations, which guarantees that the tensor cores are always engaged. The downside: you still manage shared memory layout and warp scheduling manually.
Starting in mid-2024, a wave of new DSLs appeared, each targeting a specific niche in the productivity-performance space:
| DSL | Origin | Key Idea | Target Use Case |
|---|---|---|---|
| TileLang | tile-ai (Peking University / Microsoft Research), built on TVM | Python DSL compiled via TVM, multi-backend (CUDA, ROCm) | Cross-platform kernel development |
| Helion | Meta (PyTorch team) | Python DSL with autotuning built in, integrated with torch.compile | PyTorch ecosystem, research prototyping |
| DeepGEMM | DeepSeek | JIT-compiled GEMM library, no installation, pure header | Fast FP8 GEMMs for MoE inference |
| Mirage / MPK | CMU | Superoptimizer that searches over kernel implementations | Automatically finding optimal kernels |
| ThunderKittens 2.0 | Stanford | Blackwell support, TMA integration, multi-GPU tiles | Cutting-edge hardware utilization |
| CUDA Tile / TileIR | NVIDIA | Official tile-level IR, successor to CuTe | NVIDIA's answer to the DSL explosion |
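| Tilus | NVIDIA research | Tile-level DSL with explicit control over registers and shared memory | Low-precision and custom-datatype kernels |
| Gluon / TLX | OpenAI (Triton team) / Meta | Lower-level Triton extensions that expose layouts, warp specialization, and async copies | Triton users who need more control over scheduling |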
To make the tradeoffs concrete, here is the conceptual core of a tiled GEMM in three DSLs. All three compute C += A × B in tiles.
python — triton # Triton: ~50 lines, compiler handles shared memory + scheduling @triton.jit def gemm(A, B, C, M, N, K, BM: tl.constexpr, BN: tl.constexpr, BK: tl.constexpr): pid_m, pid_n = tl.program_id(0), tl.program_id(1) rm = pid_m * BM + tl.arange(0, BM) rn = pid_n * BN + tl.arange(0, BN) acc = tl.zeros((BM, BN), dtype=tl.float32) for k in range(0, K, BK): rk = k + tl.arange(0, BK) a = tl.load(A + rm[:, None] * K + rk[None, :]) b = tl.load(B + rk[:, None] * N + rn[None, :]) acc += tl.dot(a, b) tl.store(C + rm[:, None] * N + rn[None, :], acc)
c++ — cutlass cute // CuTe: ~200 lines, explicit layout algebra, explicit pipeline // (heavily abbreviated — real code has copy engines, pipeline barriers) auto tiled_mma = make_tiled_mma(SM80_16x8x16_F32BF16BF16F32{}, Layout>{}); auto smem_a = make_tensor(smem_ptr, SmemLayoutA{}); auto smem_b = make_tensor(smem_ptr + smem_a_size, SmemLayoutB{}); for (int k = 0; k < K; k += BK) { copy(gmem_a_slice, smem_a); // global → shared (via TMA on Hopper) copy(gmem_b_slice, smem_b); __syncthreads(); gemm(tiled_mma, smem_a, smem_b, acc); // tensor core MMA __syncthreads(); }
c++ — thunderkittens // ThunderKittens: ~100 lines, register-tile-centric using namespace kittens; st_bfsmem_a; st_bf smem_b; rt_fl<16, 16> acc[BM/16][BN/16] = {}; // register tile grid for (int k = 0; k < K; k += BK) { load(smem_a, A + k, K); // HBM → shared load(smem_b, B + k * N, N); __syncthreads(); for (int i = 0; i < BM/16; i++) for (int j = 0; j < BN/16; j++) mma_ABt(acc[i][j], smem_a.get_tile(i), smem_b.get_tile(j), acc[i][j]); __syncthreads(); }
Same algorithm. Three levels of abstraction. Triton is shortest but gives you least control over the schedule. CuTe is longest but lets you specify exactly how data flows through the memory hierarchy. ThunderKittens is in between — you think in register tiles, but the library handles tile loads and MMA instruction selection.
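Stripped of every hardware detail, the algorithm all three versions share looks like the NumPy sketch below. It is only a reference model (the name `tiled_gemm` and the 16×16×16 default tile are illustrative), but it is handy for checking a real kernel's output on awkward shapes.

```python
import numpy as np


def tiled_gemm(A, B, bm=16, bn=16, bk=16):
    """C = A @ B computed tile by tile, with one FP32 accumulator per output tile."""
    M, K = A.shape
    _, N = B.shape
    C = np.zeros((M, N), dtype=np.float32)
    for i in range(0, M, bm):                 # one "program" per output tile
        for j in range(0, N, bn):
            acc = np.zeros((min(bm, M - i), min(bn, N - j)), dtype=np.float32)
            for k in range(0, K, bk):         # march along the K dimension
                a_tile = A[i:i + bm, k:k + bk]
                b_tile = B[k:k + bk, j:j + bn]
                acc += a_tile @ b_tile
            C[i:i + bm, j:j + bn] = acc
    return C


rng = np.random.default_rng(0)
A = rng.normal(size=(40, 24)).astype(np.float32)
B = rng.normal(size=(24, 50)).astype(np.float32)
print(np.allclose(tiled_gemm(A, B), A @ B, atol=1e-4))    # True
```

NumPy slicing clips at the array edge, so boundary tiles come for free here. On a GPU that is exactly what the masks (Triton) or predicated copies (CuTe) have to implement.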
Figure: DSLs positioned on the productivity-performance axis. The vertical axis shows relative performance (% of peak FLOPs). The horizontal axis shows developer productivity (inverse of lines of code).
Despite the explosion, the DSLs are converging on a shared set of abstractions: tiles as first-class objects, pipeline stages for memory latency hiding, layout descriptors for swizzled shared memory, and auto-tuning for block-size selection. The eventual outcome is likely one or two dominant frameworks (Triton + CUTLASS/CuTe) with specialized tools for niche use cases.
It is 2:30 PM. The product team wants to launch a feature that requires 128K context: users upload a 200-page document and ask questions about it. Your current serving stack handles 8K context comfortably. At 128K, the KV cache (the stored key and value tensors from previous tokens) can consume more memory than the model weights themselves. Your 70B model at FP16 uses 140 GB for weights. At 128K tokens the KV cache adds roughly 40 GB per request even with grouped-query attention, so just four concurrent requests need another 160 GB. You need a plan.
This chapter covers every major technique for taming the KV cache: architectural changes (MQA, GQA), system-level tricks (PagedAttention), intelligent eviction (SnapKV, H2O, StreamingLLM), and quantization. By the end, you will know exactly which technique to apply for which use case.
Let us derive the exact memory cost. For a transformer with:
| Symbol | Meaning | Typical (Llama-2 70B) |
|---|---|---|
| L | Number of layers | 80 |
| n_kv | Number of KV heads per layer | 8 (GQA) or 64 (MHA) |
| d_h | Head dimension | 128 |
| S | Sequence length (current + generated) | 8,192 to 128,000+ |
| b | Bytes per element | 2 (FP16) or 1 (FP8) or 0.5 (INT4) |
Each layer stores K and V tensors of shape (n_kv, S, d_h). The total KV cache memory:

$$\text{KV bytes} = 2 \cdot L \cdot n_{kv} \cdot S \cdot d_h \cdot b$$
The factor of 2 is for K and V. Let us plug in numbers:
```python
def kv_cache_memory_gb(L, n_kv, d_h, S, bytes_per_elem=2):
    """Compute KV cache memory in GB."""
    total_bytes = 2 * L * n_kv * S * d_h * bytes_per_elem
    return total_bytes / (1024**3)


# Hypothetical 70B with MHA (64 KV heads)
print(kv_cache_memory_gb(80, 64, 128, 8192))      # → 20 GB at 8K context
print(kv_cache_memory_gb(80, 64, 128, 32768))     # → 80 GB at 32K context
print(kv_cache_memory_gb(80, 64, 128, 131072))    # → 320 GB at 128K context!

# Llama-2 70B with GQA (8 KV heads): 8x reduction
print(kv_cache_memory_gb(80, 8, 128, 8192))       # → 2.5 GB
print(kv_cache_memory_gb(80, 8, 128, 131072))     # → 40 GB, manageable!

# With INT4 KV quantization on top of GQA:
print(kv_cache_memory_gb(80, 8, 128, 131072, 0.5))  # → 10 GB, fits easily
```
The most effective KV cache reduction is architectural — baked into the model during training:
Multi-Head Attention (MHA): standard transformer. Each of the n_h query heads has its own K and V head. The KV cache stores 2 × n_h × S × d_h elements per layer. This is what GPT-3, Llama-1, and Llama-2 7B/13B use.
Multi-Query Attention (MQA): all query heads share ONE K head and ONE V head. The KV cache drops to 2 × 1 × S × d_h elements per layer, an n_h-fold reduction (typically 32-64x). Introduced by Noam Shazeer in 2019. Used by PaLM, Falcon, StarCoder. The downside: some quality degradation because all heads are forced to attend to the same keys/values.
Grouped-Query Attention (GQA): the compromise. Every g query heads share one KV head. With n_h = 64 query heads and g = 8 query heads per group: n_kv = n_h / g = 8 KV heads. An 8x reduction over MHA with minimal quality loss. Used by Llama-2 70B, Llama-3, Mistral, Gemma 2. GQA is now the de facto standard for large models.
| Scheme | KV Heads | KV Cache Size (relative) | Quality Impact | Models Using It |
|---|---|---|---|---|
| MHA | n_h (e.g., 64) | 1.0x | Baseline | GPT-3, Llama-1, Llama-2 7B/13B |
| GQA | n_h/g (e.g., 8) | 1/g (e.g., 0.125x) | Negligible (<0.5% PPL) | Llama-2 70B, Llama-3, Mistral |
| MQA | 1 | 1/n_h (e.g., 0.016x) | Small (1-2% PPL) | PaLM, Falcon, StarCoder |
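The three schemes differ only in how many KV heads exist and how query heads map onto them, which a short NumPy sketch makes concrete. This is a non-causal toy version with a hypothetical function name. Passing as many KV heads as query heads gives MHA, and passing a single KV head gives MQA.

```python
import numpy as np


def gqa_attention(q, k, v):
    """q: (n_q_heads, S, d); k, v: (n_kv_heads, S, d). Non-causal for brevity."""
    n_q_heads, n_kv_heads = q.shape[0], k.shape[0]
    group_size = n_q_heads // n_kv_heads       # g query heads share one KV head
    # Query heads [0, g) read KV head 0, heads [g, 2g) read KV head 1, and so on.
    k_rep = np.repeat(k, group_size, axis=0)
    v_rep = np.repeat(v, group_size, axis=0)
    scores = q @ k_rep.transpose(0, 2, 1) / np.sqrt(q.shape[-1])
    scores -= scores.max(axis=-1, keepdims=True)
    probs = np.exp(scores)
    probs /= probs.sum(axis=-1, keepdims=True)
    return probs @ v_rep


rng = np.random.default_rng(0)
q = rng.normal(size=(8, 5, 4))      # 8 query heads
k = rng.normal(size=(2, 5, 4))      # 2 KV heads, so g = 4
v = rng.normal(size=(2, 5, 4))
out = gqa_attention(q, k, v)
print(out.shape, "cached elements per layer:", k.size + v.size)
```

Only `k` and `v` are cached, so the cache shrinks by the factor g while the number of query heads (and the output shape) stays the same. A real kernel indexes the shared KV head directly instead of materializing the repeated copies.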
Even with GQA, serving multiple concurrent requests requires careful memory management. The naive approach: pre-allocate a contiguous buffer of max_seq_len × n_kv × d_h for each request. If max_seq_len = 128K but the average request uses 4K tokens, you waste 97% of the allocated memory. With 32 concurrent requests at roughly 40 GB reserved per request, that is over a terabyte reserved and almost all of it unused.
PagedAttention (vLLM, 2023) borrows the virtual memory concept from operating systems. Instead of contiguous buffers, KV cache entries are stored in fixed-size pages (blocks) of, say, 16 tokens. Pages are allocated on demand as the sequence grows, and reclaimed when the sequence ends. Pages need not be contiguous in physical GPU memory — a page table maps logical token positions to physical page addresses.
```python
import torch


# PagedAttention: conceptual implementation
class PagedKVCache:
    def __init__(self, total_pages=1024, page_size=16, n_layers=80,
                 n_kv_heads=8, d_head=128, device="cuda"):
        self.page_size = page_size
        self.n_layers = n_layers
        self.n_kv_heads = n_kv_heads
        self.d_head = d_head

        # Physical page pool: pre-allocate the KV budget as fixed-size pages.
        # Each page: (n_layers, 2, n_kv_heads, page_size, d_head), K and V.
        # With the defaults that is 5 MiB per page, so 1024 pages = 5 GiB;
        # size total_pages to the memory left after the weights.
        self.total_pages = total_pages
        self.page_pool = torch.zeros(
            total_pages, n_layers, 2, n_kv_heads, page_size, d_head,
            dtype=torch.float16, device=device,
        )
        self.free_pages = list(range(total_pages))  # free list

        # Per-sequence page tables: seq_id → list of page indices
        self.page_tables = {}  # Dict[int, List[int]]

    def allocate_page(self, seq_id):
        """Allocate a new page for a sequence (called when it grows)."""
        if not self.free_pages:
            raise MemoryError("No free KV cache pages")
        page_idx = self.free_pages.pop()
        self.page_tables.setdefault(seq_id, []).append(page_idx)
        return page_idx

    def free_sequence(self, seq_id):
        """Release all pages for a completed sequence."""
        self.free_pages.extend(self.page_tables.pop(seq_id, []))

    def get_kv(self, seq_id, layer, token_pos):
        """Read K/V for a specific token position."""
        page_num = token_pos // self.page_size
        offset = token_pos % self.page_size
        page_idx = self.page_tables[seq_id][page_num]
        k = self.page_pool[page_idx, layer, 0, :, offset, :]  # (n_kv_heads, d_head)
        v = self.page_pool[page_idx, layer, 1, :, offset, :]
        return k, v

    def utilization(self):
        used = self.total_pages - len(self.free_pages)
        return used / self.total_pages


# Per-token KV size (GQA, FP16): 2 × 80 × 8 × 128 × 2 bytes ≈ 0.31 MiB
# Without paging: 32 requests × 128K tokens × 0.31 MiB ≈ 1,280 GiB reserved
# With paging, average length 4K: 32 × 4K × 0.31 MiB ≈ 40 GiB, a 32x saving
```
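SnapKV (2024) made a key observation: in long-context attention, most tokens receive negligible attention weight. A small set of "important" tokens dominate the attention pattern, and this set can be identified cheaply by looking at how an observation window (the last few tokens of the prompt) attends to the rest of the prefix.
The algorithm:
1. Take the attention scores from the observation window to every earlier prefix token.
2. Vote: sum those scores over the window positions to get one importance score per prefix token, per head.
3. Select the top-k prefix tokens per head.
4. Keep the K/V entries for those tokens plus the whole observation window, and drop the rest.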
```python
import torch


def snapkv_compress(K, V, attn_scores, k=1024, window=64):
    """
    SnapKV: attention-guided KV cache compression.
    K, V: (n_heads, seq_len, d_head)
    attn_scores: (n_heads, window, seq_len), attention from the last `window` tokens
    """
    n_heads, seq_len, d_head = K.shape
    prefix_len = seq_len - window
    k = min(k, prefix_len)  # cannot keep more prefix tokens than exist

    # Step 1: attention scores from observation window to prefix tokens
    prefix_scores = attn_scores[:, :, :prefix_len]  # (n_heads, window, prefix_len)

    # Step 2: vote by summing attention across observation window positions
    votes = prefix_scores.sum(dim=1)  # (n_heads, prefix_len)

    # Step 3: top-k selection per head
    _, top_indices = votes.topk(k, dim=-1)       # (n_heads, k)
    top_indices, _ = top_indices.sort(dim=-1)    # maintain original token order

    # Step 4: gather selected K/V + keep observation window
    gather_idx = top_indices.unsqueeze(-1).expand(-1, -1, d_head)
    K_prefix = K.gather(1, gather_idx)
    V_prefix = V.gather(1, gather_idx)
    K_window = K[:, prefix_len:, :]  # last `window` tokens are always kept
    V_window = V[:, prefix_len:, :]

    K_compressed = torch.cat([K_prefix, K_window], dim=1)  # (n_heads, k+window, d_head)
    V_compressed = torch.cat([V_prefix, V_window], dim=1)
    return K_compressed, V_compressed  # seq_len reduced from S to k+window
```
H2O (2023) takes a different approach to KV eviction. Instead of a one-shot compression at prompt time, H2O maintains a running eviction policy during generation. It tracks the cumulative attention score for each token across all generation steps. Tokens that consistently receive high attention ("heavy hitters") are kept. Tokens with low cumulative scores are evicted.
```python
import torch


class H2OCache:
    """Heavy-Hitter Oracle (H2O): cumulative attention-based eviction."""

    def __init__(self, max_size=2048, window_size=256):
        self.max_size = max_size
        self.window_size = window_size            # recent tokens always kept
        self.cumulative_scores = torch.zeros(0)   # one entry per cached token

    def update_and_evict(self, attn_weights, K, V):
        """
        attn_weights: (n_heads, 1, current_cache_size), attention from the new token
        K, V: (n_heads, current_cache_size, d_head), current KV cache tensors
        """
        # Update cumulative scores (averaged across heads)
        scores = attn_weights.mean(dim=0).reshape(-1).float().cpu()
        # Extend the score vector for tokens appended since the last call
        n_new = scores.numel() - self.cumulative_scores.numel()
        self.cumulative_scores = torch.cat([self.cumulative_scores, torch.zeros(n_new)])
        self.cumulative_scores += scores

        # Check if eviction needed
        current_size = K.shape[1]
        if current_size <= self.max_size:
            return K, V

        # Protect recent tokens (sliding window); evict the older token
        # with the lowest cumulative attention
        n_evictable = current_size - self.window_size
        victim = torch.argmin(self.cumulative_scores[:n_evictable]).item()
        keep_mask = torch.ones(current_size, dtype=torch.bool)
        keep_mask[victim] = False

        # Drop the victim's score as well, so score indices stay aligned
        # with cache positions after the cache shrinks
        self.cumulative_scores = self.cumulative_scores[keep_mask]
        return K[:, keep_mask], V[:, keep_mask]
```
The key insight: in practice, attention patterns follow a power law. A small set of tokens (typically the first few tokens, punctuation, and key content words) receive the lion's share of attention. H2O exploits this by keeping the heavy hitters and evicting the long tail.
StreamingLLM (2023) discovered a remarkable phenomenon: the first few tokens in any sequence always receive disproportionately high attention, regardless of their content. These are called attention sinks. The hypothesis is that softmax needs somewhere to dump attention mass when no token is particularly relevant, and position 0 becomes the default "sink."
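StreamingLLM's algorithm is dead simple:
1. Always keep the KV entries of the first few tokens (the attention sinks).
2. Keep a sliding window of the most recent tokens.
3. Evict everything in between, so the cache size stays fixed no matter how long the stream runs.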
StreamingLLM does NOT help with tasks that require accessing information from the evicted middle (e.g., "what was mentioned in paragraph 5 of a 100-paragraph document?"). It is designed for streaming use cases: chat, real-time transcription, monitoring dashboards.
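The sink-plus-window policy fits in a few lines of standard-library Python. This is a sketch of the bookkeeping only: the class name is hypothetical, the sizes (4 sinks, window of 8) are illustrative, cache entries are stand-ins for real K/V tensors, and positional-encoding handling is left out.

```python
from collections import deque


class SinkWindowCache:
    """Keep the first `n_sink` entries forever plus a sliding window of recent ones."""

    def __init__(self, n_sink=4, window=8):
        self.n_sink = n_sink
        self.sinks = []                       # attention sinks, never evicted
        self.recent = deque(maxlen=window)    # deque drops the oldest entry when full

    def append(self, entry):
        if len(self.sinks) < self.n_sink:
            self.sinks.append(entry)
        else:
            self.recent.append(entry)

    def contents(self):
        return self.sinks + list(self.recent)


cache = SinkWindowCache(n_sink=4, window=8)
for token_pos in range(100):
    cache.append(token_pos)       # stand-in for that token's (K, V) pair
print(cache.contents())           # [0, 1, 2, 3, 92, 93, 94, 95, 96, 97, 98, 99]
```

Memory stays at `n_sink + window` entries no matter how long the stream runs, and tokens 4 through 91 are gone, which is exactly why recall of the middle fails.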
Orthogonal to eviction strategies, you can quantize the KV cache itself from FP16 to INT8 or INT4. This is different from weight quantization because:
1. KV values change every token. You cannot pre-compute quantization parameters. You need an online quantization scheme that is fast enough to run per-token.
2. The error budget is different. Weight quantization errors are fixed — the same error on every inference. KV quantization errors accumulate: an error in the cached K at position 100 affects every subsequent attention computation forever. This makes KV quantization more sensitive.
3. The distribution is friendlier. KV cache values tend to be more uniformly distributed (post-RoPE, K values are roughly Gaussian) compared to weights (which have those pesky outlier features). INT8 KV quantization with per-head scales typically causes <0.1% perplexity increase. INT4 requires per-token-group scales and causes 0.3-1% increase.
```python
import torch


def quantize_kv_cache(K, V, bits=8):
    """
    Per-head quantization of KV cache.
    K, V: (n_heads, seq_len, d_head) in FP16
    Returns quantized tensors + per-head scales (FP32)
    """
    q_max = 2 ** (bits - 1) - 1  # symmetric: -127..127 for INT8, -7..7 for INT4

    # Per-head scale (one scale per head), kept in FP32;
    # clamp_min guards against division by zero for an all-zero head
    K_absmax = K.abs().amax(dim=(1, 2), keepdim=True).float()  # (n_heads, 1, 1)
    V_absmax = V.abs().amax(dim=(1, 2), keepdim=True).float()
    K_scale = (K_absmax / q_max).clamp_min(1e-8)
    V_scale = (V_absmax / q_max).clamp_min(1e-8)

    K_q = (K.float() / K_scale).round().clamp(-q_max, q_max).to(torch.int8)
    V_q = (V.float() / V_scale).round().clamp(-q_max, q_max).to(torch.int8)

    # Memory: INT8 = 1 byte vs FP16 = 2 bytes → 2x compression
    # INT4 = 0.5 bytes → 4x compression, but only once two values are packed
    # per byte; this sketch stores INT4 values unpacked in int8
    return K_q, V_q, K_scale, V_scale


def dequantize_for_attention(K_q, K_scale, Q):
    """Dequantize K on-the-fly during attention computation."""
    K_deq = (K_q.float() * K_scale).to(Q.dtype)  # dequant before Q @ K^T, match Q's dtype
    attn = Q @ K_deq.transpose(-1, -2)
    return attn
```
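The reason INT4 needs finer-grained scales than INT8 shows up in a small NumPy experiment. The sketch below compares one scale for the whole tensor against one scale per group of 32 values when a single outlier is present. The function names, the group size of 32, and the synthetic data are all illustrative assumptions.

```python
import numpy as np


def quantize_groups(x, bits=4, group_size=32):
    """Symmetric quantization with one scale per contiguous group of values."""
    q_max = 2 ** (bits - 1) - 1
    groups = x.reshape(-1, group_size)
    scale = np.abs(groups).max(axis=1, keepdims=True) / q_max
    scale = np.maximum(scale, 1e-8)            # avoid division by zero
    q = np.clip(np.round(groups / scale), -q_max, q_max).astype(np.int8)
    return q, scale


def dequantize_groups(q, scale, shape):
    return (q.astype(np.float32) * scale).reshape(shape)


rng = np.random.default_rng(0)
x = rng.normal(size=(4, 64)).astype(np.float32)
x[0, 0] = 50.0                                  # one outlier

q_t, s_t = quantize_groups(x, group_size=x.size)    # single scale for everything
q_g, s_g = quantize_groups(x, group_size=32)        # one scale per 32 values

err_tensor = np.abs(dequantize_groups(q_t, s_t, x.shape) - x).mean()
err_group = np.abs(dequantize_groups(q_g, s_g, x.shape) - x).mean()
print(f"single scale: {err_tensor:.3f}   per-group scales: {err_group:.3f}")
```

With a single scale, the outlier stretches the 15 available INT4 levels so far that most ordinary values round to zero. Per-group scales confine that damage to the one group containing the outlier, at the cost of storing one extra scale per group.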
In practice, frontier labs combine multiple techniques. Here is the typical stack for long-context serving:
| Layer | Technique | Compression Ratio | Quality Impact |
|---|---|---|---|
| Architecture | GQA (8 groups) | 8x | Negligible (trained in) |
| Memory management | PagedAttention | 2-4x (reduces waste) | Zero (exact) |
| Precision | INT8 KV cache | 2x | <0.1% PPL increase |
| Eviction | SnapKV (top-k=2048) | Up to 64x on 128K | ~1-3% quality loss on long-range tasks |
| Total | All combined | ~128-512x | Depends on task type |
Figure: KV cache memory usage across different optimization strategies as sequence length grows, showing the individual and combined impact of each technique.
The phrase "Barbarians at the Gate" in the context of KV caches refers to the growing realization that the attention mechanism itself may need to change. All the techniques above are engineering patches on a fundamentally O(N) memory mechanism (linear in sequence length, per layer). State-space models (Mamba, RWKV), linear attention variants, and hybrid architectures are the "barbarians" — they offer O(1) memory per token by replacing the KV cache with a fixed-size recurrent state.
The counterargument: the KV cache is not just overhead. It IS the model's explicit memory. Eviction techniques lose information. Recurrent states compress information lossily. The full KV cache is a lossless record of every token the model has seen. The frontier lab engineer's job is to make this memory affordable, not to throw it away.
| Scenario | Best KV Cache Strategy | Why |
|---|---|---|
| Chatbot, 4K context | GQA + PagedAttention | Small cache, no compression needed |
| Document QA, 32K context | GQA + Paged + INT8 KV | Moderate compression, negligible quality loss |
| Book summarization, 128K | GQA + Paged + INT8 + SnapKV (k=4096) | Heavy compression needed; SnapKV preserves retrieval-relevant tokens |
| Infinite streaming (chat) | StreamingLLM (sinks + window) | Fixed memory, infinite context, but no recall of old turns |
| Serving 1000+ concurrent users | GQA + Paged + INT4 KV + prefix caching | Maximize throughput; shared system prompt saves 30%+ memory |
| Beam search (width 4) | PagedAttention with copy-on-write | Beams share prefix pages, 75% savings |
You have a 70-billion parameter model. Each parameter is a BF16 number: 2 bytes. The model weights alone occupy 70B × 2 = 140 GB. A single H100 has 80 GB of HBM. The model does not even fit on one GPU — and we have not yet accounted for optimizer states, gradients, or activations.
This is the fundamental problem of large-scale training: the model, its optimizer state, its gradients, and its activations are collectively far too large for any single device. The solution is to split the work across many GPUs. But how you split matters enormously. A naive approach can leave 90% of your GPUs idle, waiting for communication. A well-engineered approach can achieve 40-55% of peak FLOPS utilization — and the difference between 20% and 50% MFU on a 2048-GPU cluster is millions of dollars in wasted compute.
This chapter covers the four pillars of distributed training: data parallelism, fully sharded data parallelism (FSDP), tensor parallelism, and pipeline parallelism. By the end, you will be able to derive the optimal parallelism configuration for any model-cluster pair.
Before choosing a parallelism strategy, you must know what consumes memory. Let us derive the full breakdown for a 70B parameter model trained in BF16 with AdamW optimizer:
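The breakdown can be sketched in a few lines. This is a minimal sketch assuming the usual mixed-precision AdamW accounting (BF16 weights and gradients, plus FP32 master weights and two FP32 Adam moments); the function name `training_memory_gb` is illustrative, and activations are deliberately left out.

```python
def training_memory_gb(num_params: float) -> dict:
    """Static training memory in GB (1 GB = 1e9 bytes), excluding activations."""
    bytes_per_param = {
        "weights_bf16": 2,         # model weights used in forward/backward
        "grads_bf16": 2,           # gradients
        "fp32_master_weights": 4,  # optimizer's high-precision copy
        "adam_m_fp32": 4,          # first moment
        "adam_v_fp32": 4,          # second moment
    }
    breakdown = {k: num_params * b / 1e9 for k, b in bytes_per_param.items()}
    breakdown["optimizer_total"] = (
        breakdown["fp32_master_weights"]
        + breakdown["adam_m_fp32"]
        + breakdown["adam_v_fp32"]
    )
    breakdown["total"] = (
        breakdown["weights_bf16"] + breakdown["grads_bf16"] + breakdown["optimizer_total"]
    )
    return breakdown


mem = training_memory_gb(70e9)
for name, gb in mem.items():
    print(f"{name:>22}: {gb:8.0f} GB")
```

That is 16 bytes per parameter in total, of which 12 belong to the optimizer.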
Notice the optimizer dominates: 840 GB out of 1,120 GB. This is why ZeRO and FSDP focus primarily on sharding optimizer states.
Data Parallelism is the simplest distributed strategy. Every GPU holds a complete copy of the model. Each GPU processes a different mini-batch. After the forward and backward passes, gradients are synchronized across all GPUs via an all-reduce operation, and every GPU applies the same optimizer step.
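DDP works beautifully for models that fit on one GPU. At 16 bytes per parameter for weights, gradients, and AdamW state, that means up to roughly 4B parameters on an 80 GB H100 once activations are accounted for. The communication cost is a single all-reduce of the gradient tensor per step.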
The all-reduce operation takes a tensor that exists on every GPU, sums (or averages) them, and distributes the result back to every GPU. The efficient implementation is the ring all-reduce, which works in two phases:
Phase 1: Reduce-Scatter. Each GPU sends 1/N of its gradient tensor to the next GPU in the ring. After N-1 steps, each GPU holds the fully reduced sum for its 1/N shard.
Phase 2: All-Gather. Each GPU sends its reduced shard around the ring. After N-1 steps, every GPU holds the complete reduced result.
The total data transferred per GPU in a ring all-reduce is:
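Each of the two phases sends N-1 chunks of size D/N, so every GPU transmits 2 × (N-1)/N × D in total, which approaches 2D for large N. The simulation below is a sketch in NumPy that runs both phases on in-memory "GPUs" and counts the elements each one sends; `ring_all_reduce` is an illustrative name, not a library call.

```python
import numpy as np


def ring_all_reduce(tensors):
    """Simulate a ring all-reduce (sum). Returns (results, elements_sent_per_rank)."""
    n = len(tensors)
    # chunks[rank][i] is rank's local copy of chunk i
    chunks = [np.array_split(t.astype(np.float64), n) for t in tensors]
    sent = [0] * n

    # Phase 1: reduce-scatter. At step s, rank r sends chunk (r - s) mod n to rank r+1.
    for step in range(n - 1):
        msgs = [((r - step) % n, chunks[r][(r - step) % n].copy()) for r in range(n)]
        for r, (idx, payload) in enumerate(msgs):
            dst = (r + 1) % n
            chunks[dst][idx] = chunks[dst][idx] + payload
            sent[r] += payload.size
    # Now rank r holds the fully reduced chunk (r + 1) mod n.

    # Phase 2: all-gather. At step s, rank r forwards chunk (r + 1 - s) mod n.
    for step in range(n - 1):
        msgs = [((r + 1 - step) % n, chunks[r][(r + 1 - step) % n].copy()) for r in range(n)]
        for r, (idx, payload) in enumerate(msgs):
            dst = (r + 1) % n
            chunks[dst][idx] = payload
            sent[r] += payload.size

    return [np.concatenate(c) for c in chunks], sent


n_gpus, d = 4, 8
grads = [np.arange(d) * (rank + 1) for rank in range(n_gpus)]
results, sent = ring_all_reduce(grads)
expected_volume = 2 * (n_gpus - 1) / n_gpus * d  # 2 x (N-1)/N x D
print(sent, expected_volume)
```

Because the per-GPU volume is nearly independent of N, ring all-reduce time is governed by link bandwidth rather than cluster size, ignoring per-step latency.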
DDP requires every GPU to hold the full model, gradients, and optimizer states. For a 70B model, that is 1,120 GB per GPU — impossible. FSDP (PyTorch's implementation of DeepSpeed ZeRO) fixes this by sharding everything across GPUs. Each GPU stores only 1/N of the weights, 1/N of the gradients, and 1/N of the optimizer states.
But if each GPU only has 1/N of the weights, how does it compute the forward pass? It all-gathers the full weights just before each layer's forward pass, uses them, then discards them. During the backward pass, it all-gathers again, computes gradients, then reduce-scatters to keep only its 1/N shard of the gradients.
The communication cost is higher than DDP. Let us derive it:
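A sketch of the comparison, assuming each all-gather or reduce-scatter moves (N-1)/N × D bytes per GPU, where D is the size of the full parameter (or gradient) tensor. FSDP performs three such collectives per step (all-gather in forward, all-gather in backward, reduce-scatter of gradients), while DDP's all-reduce is equivalent to two. The function names are illustrative.

```python
def ddp_bytes_per_gpu(model_bytes: float, n_gpus: int) -> float:
    # all-reduce = reduce-scatter + all-gather
    return 2 * (n_gpus - 1) / n_gpus * model_bytes


def fsdp_bytes_per_gpu(model_bytes: float, n_gpus: int) -> float:
    one_collective = (n_gpus - 1) / n_gpus * model_bytes
    # all-gather (forward) + all-gather (backward) + reduce-scatter (gradients)
    return 3 * one_collective


model_bytes = 70e9 * 2  # 70B parameters in BF16
n_gpus = 64
ddp = ddp_bytes_per_gpu(model_bytes, n_gpus)
fsdp = fsdp_bytes_per_gpu(model_bytes, n_gpus)
print(f"DDP:  {ddp / 1e9:.1f} GB per GPU per step")
print(f"FSDP: {fsdp / 1e9:.1f} GB per GPU per step ({fsdp / ddp:.2f}x DDP)")
```

Under these assumptions FSDP costs 1.5x the communication of DDP in exchange for a 1/N memory footprint, and much of that traffic can be overlapped with compute by prefetching the next layer's weights.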
FSDP shards weights but requires an all-gather before every layer. Tensor parallelism takes a different approach: it splits each individual layer's computation across GPUs so that no GPU ever needs the full weight matrix.
Consider a linear layer Y = XW, where X is the input [B, d_model] and W is the weight matrix [d_model, d_hidden]. Tensor parallelism splits W column-wise across T GPUs:
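The sketch below checks the idea numerically in NumPy, simulating T GPUs as list entries. It assumes a two-layer MLP with a column-parallel first layer and a row-parallel second layer, and uses ReLU as a stand-in for the real activation (any elementwise function works the same way). The final sum plays the role of the all-reduce.

```python
import numpy as np

rng = np.random.default_rng(0)
B, d_model, d_hidden, T = 4, 8, 16, 4

X = rng.standard_normal((B, d_model))
W1 = rng.standard_normal((d_model, d_hidden))  # first linear layer
W2 = rng.standard_normal((d_hidden, d_model))  # second linear layer

# Column-parallel: each GPU holds d_hidden / T columns of W1.
# X is replicated, so each GPU computes its slice of the hidden activations locally.
W1_shards = np.split(W1, T, axis=1)
H_shards = [np.maximum(X @ w, 0.0) for w in W1_shards]  # elementwise op needs no communication

# Row-parallel: each GPU holds the matching d_hidden / T rows of W2
# and produces a partial sum of the full output.
W2_shards = np.split(W2, T, axis=0)
partials = [h @ w for h, w in zip(H_shards, W2_shards)]

# The only communication: one all-reduce (sum) over the partial outputs.
Y_tp = sum(partials)

Y_ref = np.maximum(X @ W1, 0.0) @ W2
print(np.allclose(Y_tp, Y_ref))
```

Pairing a column split with a row split is what lets the block get away with a single all-reduce in the forward pass: the hidden activations never need to be gathered.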
In a transformer, Megatron-LM chains column-parallel and row-parallel layers to minimize communication. Each transformer block requires exactly 2 all-reduces in the forward pass (one after the attention block, one after the FFN) and 2 all-reduces in the backward pass.
Pipeline parallelism splits the model vertically by layers. If you have 80 transformer layers and 8 pipeline stages, each GPU holds 10 consecutive layers. GPU 0 runs layers 0-9, GPU 1 runs layers 10-19, and so on.
The problem: naive pipelining has catastrophic bubble overhead. GPU 7 sits idle while GPUs 0-6 process the forward pass through their layers. Then GPU 0 sits idle while GPUs 1-7 finish the forward pass and start the backward pass.
The solution is microbatching. Instead of sending one large batch through the pipeline, split it into M microbatches. As soon as GPU 0 finishes the forward pass for microbatch 1, it starts microbatch 2 while GPU 1 works on microbatch 1. This fills the pipeline.
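How much of the pipeline stays idle can be estimated with the standard bubble formula for a GPipe or 1F1B schedule, bubble = (P - 1) / (M + P - 1), where P is the number of stages and M the number of microbatches. The sketch assumes every stage takes the same time per microbatch; `bubble_fraction` is an illustrative name.

```python
def bubble_fraction(num_stages: int, num_microbatches: int) -> float:
    """Fraction of each step that a pipeline stage spends idle."""
    return (num_stages - 1) / (num_microbatches + num_stages - 1)


for m in (1, 8, 32, 128):
    print(f"P=8, M={m:>3}: bubble = {bubble_fraction(8, m):.1%}")
```

With a single microbatch, 8 stages are idle 87.5% of the time; the bubble only becomes small once M is several times larger than P.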
Given N GPUs and a model of size M parameters, which strategy do you use? Here is the decision framework that frontier labs actually follow:
| Condition | Strategy | Why |
|---|---|---|
| Model fits on 1 GPU with optimizer | DDP (pure data parallel) | Simplest, lowest communication overhead (2D per step) |
| Model fits on 1 GPU but optimizer does not | FSDP / ZeRO-1 or ZeRO-2 | Shard optimizer states across GPUs, keep model replicated |
| Model does not fit on 1 GPU | FSDP (ZeRO-3) + TP within node | Shard everything. Use TP=8 within each 8-GPU node (fast NVLink), FSDP across nodes |
| Model is very large (>100B) and cluster is huge | 3D parallelism: TP + PP + DP | TP within node, PP across a few nodes, DP across the rest. Minimizes cross-node communication. |
Let us derive the optimal parallelism configuration for training LLaMA 3 70B on a cluster of 2048 NVIDIA H100 GPUs, each with 80 GB HBM, connected by NVLink within 8-GPU nodes and InfiniBand across nodes.
The ultimate metric for training efficiency is MFU — what fraction of the GPU's theoretical peak FLOPS you are actually using for useful computation (excluding communication, bubbles, recomputation, and overhead).
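MFU can be estimated from measured throughput using the C ≈ 6ND approximation (6 FLOPs per parameter per token for forward plus backward). In the sketch below, the 989 TFLOPS figure is an assumption taken from the H100 SXM dense BF16 spec, and the throughput value is a made-up measurement for illustration.

```python
def mfu(num_params: float, tokens_per_sec: float, num_gpus: int, peak_flops_per_gpu: float) -> float:
    """Model FLOPs utilization: useful model FLOPs per second over cluster peak."""
    model_flops_per_sec = 6 * num_params * tokens_per_sec
    return model_flops_per_sec / (num_gpus * peak_flops_per_gpu)


H100_BF16_PEAK = 989e12       # assumption: dense BF16 peak for H100 SXM
measured_tokens_per_sec = 2.0e6  # hypothetical cluster-wide throughput

utilization = mfu(70e9, measured_tokens_per_sec, 2048, H100_BF16_PEAK)
print(f"MFU = {utilization:.1%}")
```

Note that the numerator counts only the FLOPs the model mathematically requires, so recomputation from activation checkpointing lowers MFU even though the GPUs are busy.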
Sometimes the global batch size you need is larger than what fits in GPU memory across all data-parallel replicas. Gradient accumulation solves this: run K forward-backward passes, accumulating gradients, then do one optimizer step.
Gradient accumulation is "free" in terms of communication — you only all-reduce once per K steps. But it increases latency per optimizer step by K×.
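Training in 16-bit precision saves memory and doubles throughput on tensor cores, but it narrows the numeric format. BF16 keeps FP32's 8-bit exponent (the smallest normal number is ~1.2×10^-38), so underflow is rare; its cost is precision, with only 7 mantissa bits (relative precision ~0.008). FP16 makes the opposite trade: 10 mantissa bits but a 5-bit exponent, so the smallest normal number is ~6.1×10^-5 and small gradients from early layers can underflow to zero in FP16.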
Loss scaling solves this: multiply the loss by a large constant S before backpropagation. This scales all gradients by S, pushing them away from the underflow zone. After the backward pass, divide gradients by S before the optimizer step.
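A small numerical sketch of the effect. NumPy has no BF16 type, so this uses FP16, where underflow is easy to trigger; the gradient value and the scale S = 2^16 are illustrative choices.

```python
import numpy as np

true_grad = 1e-8   # a small gradient, as might reach an early layer
S = 2.0 ** 16      # loss scale (illustrative; real trainers adjust it dynamically)

# Without scaling: the value is below FP16's smallest subnormal and flushes to zero.
naive = np.float16(true_grad)

# With scaling: the gradient is computed as S * grad, which FP16 can represent.
scaled = np.float16(true_grad * S)

# Unscale in FP32 before the optimizer step.
recovered = np.float32(scaled) / np.float32(S)

print("naive FP16 gradient: ", naive)
print("scaled FP16 gradient:", scaled)
print("recovered (FP32):    ", recovered)
```

Choosing a power of two for S means scaling and unscaling only shift the exponent, so they introduce no rounding error of their own.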
| Strategy | Per-GPU Memory | Communication per Step | Communication Type | Best Interconnect |
|---|---|---|---|---|
| DDP | Full model + optimizer | 2 × (N-1)/N × D | All-reduce (gradients) | Any (least communication) |
| FSDP/ZeRO-3 | 1/N of everything | 3 × (N-1)/N × D | All-gather + reduce-scatter | Cross-node IB OK |
| Tensor Parallel | 1/T of each layer | 4 × all-reduce per block | All-reduce (activations) | NVLink required (intra-node) |
| Pipeline Parallel | L/P layers | Point-to-point per microbatch | Send/recv (activations) | IB OK, low bandwidth needs |
During the forward pass, every layer's output is stored in memory so it can be used during the backward pass. For a 70B model with 80 layers, batch size 1, and sequence length 8192, the activation memory is enormous:
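A rough estimate, using the per-layer approximation from Korthikanti et al. (2022) for 16-bit activations without tensor or sequence parallelism: s·b·h·(34 + 5·a·s/h) bytes, where s is sequence length, b batch size, h hidden size, and a the number of attention heads. The hidden size of 8192 and 64 heads are assumed here as a LLaMA-70B-like shape, and GQA or SwiGLU shift the constants, so treat the output as an order-of-magnitude figure.

```python
def activation_bytes_per_layer(s: int, b: int, h: int, a: int, flash_attention: bool = False) -> int:
    """Approximate stored activation bytes for one transformer layer (16-bit)."""
    linear_term = 34 * s * b * h  # layernorms, projections, MLP, dropout masks
    # The 5*a*s^2*b term is the attention score matrix and its softmax/dropout copies.
    # FlashAttention never materializes it, so the term disappears.
    attn_term = 0 if flash_attention else 5 * a * s * s * b
    return linear_term + attn_term


s, b, h, a, layers = 8192, 1, 8192, 64, 80  # assumed 70B-like shape

no_flash_gb = activation_bytes_per_layer(s, b, h, a) * layers / 1e9
flash_gb = activation_bytes_per_layer(s, b, h, a, flash_attention=True) * layers / 1e9
print(f"without FlashAttention: {no_flash_gb:,.0f} GB")
print(f"with FlashAttention:    {flash_gb:,.0f} GB")
```

Even with the quadratic term removed, the remaining activations for a single 8K sequence exceed one GPU's HBM, which is what motivates checkpointing.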
Activation checkpointing (also called gradient checkpointing or rematerialization) discards intermediate activations during the forward pass and recomputes them during the backward pass. The trade-off: 33% more compute (one extra forward pass per layer), but activation memory drops from O(L) to O(sqrt(L)) or O(1) depending on the strategy.
| Strategy | Memory Saved | Extra Compute | When to Use |
|---|---|---|---|
| No checkpointing | None | None | Model and activations fit in memory |
| Selective (every N layers) | ~(N-1)/N of activations | ~(N-1)/N extra forward | Most common: checkpoint every 2-4 layers |
| Full (every layer) | ~L× activations | ~33% extra compute | Very large models or long sequences |
| FlashAttention | Eliminates S×S attention matrix | Recomputes in backward | Always use — saves the biggest chunk |
When training with very long sequences (32K, 128K, or even 1M tokens), the attention memory grows quadratically with sequence length. Even with FlashAttention, the KV cache for one layer is B × 2 × H_kv × S × d_head × 2 bytes. At S=128K with GQA (8 KV heads, d_head=128):
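Plugging the numbers into the formula above, as a sketch with batch size 1 and an 80-layer model assumed for the total:

```python
def kv_bytes_per_layer(batch: int, kv_heads: int, seq_len: int, head_dim: int, bytes_per_elem: int = 2) -> int:
    # B x 2 (K and V) x H_kv x S x d_head x bytes per element
    return batch * 2 * kv_heads * seq_len * head_dim * bytes_per_elem


seq_len = 128 * 1024
per_layer = kv_bytes_per_layer(batch=1, kv_heads=8, seq_len=seq_len, head_dim=128)
num_layers = 80  # assumption: 70B-class model

print(f"per layer:  {per_layer / 2**20:.0f} MiB")
print(f"all layers: {per_layer * num_layers / 2**30:.0f} GiB")
```

A single 128K sequence therefore needs about half of an 80 GB GPU just for keys and values, before weights or any other activations.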
Context parallelism (also called sequence parallelism for the attention component) splits the sequence across GPUs. Each GPU processes a shard of the sequence and communicates partial attention results via all-to-all or ring exchanges. This allows training on sequences far longer than what fits on a single GPU.
Ring attention is the most elegant implementation: GPUs are arranged in a ring. Each GPU holds a chunk of Q and a chunk of K,V. The K,V chunks are passed around the ring, and each GPU accumulates the attention for its Q chunk against all K,V chunks as they arrive. After one full ring rotation, every GPU has its complete attention output.
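The accumulation step is the same online-softmax trick used by FlashAttention: each GPU keeps a running max, a running denominator, and an unnormalized output, and rescales them whenever a new K,V chunk arrives. The NumPy sketch below simulates the ring on one machine (non-causal attention, one head, devices processed one after another rather than in lockstep) and compares the result with ordinary attention. The function names are illustrative.

```python
import numpy as np


def full_attention(Q, K, V):
    scores = Q @ K.T / np.sqrt(Q.shape[-1])
    scores = scores - scores.max(axis=-1, keepdims=True)
    P = np.exp(scores)
    return (P / P.sum(axis=-1, keepdims=True)) @ V


def ring_attention(Q, K, V, n_dev):
    d = Q.shape[-1]
    Qs, Ks, Vs = (np.split(M, n_dev) for M in (Q, K, V))
    outputs = []
    for dev in range(n_dev):
        q = Qs[dev]
        m = np.full((q.shape[0], 1), -np.inf)  # running row-wise max
        l = np.zeros((q.shape[0], 1))          # running softmax denominator
        acc = np.zeros((q.shape[0], V.shape[-1]))  # running unnormalized output
        for step in range(n_dev):
            src = (dev - step) % n_dev  # K,V chunk arriving at this ring step
            s = q @ Ks[src].T / np.sqrt(d)
            m_new = np.maximum(m, s.max(axis=-1, keepdims=True))
            rescale = np.exp(m - m_new)  # correct earlier partial sums for the new max
            p = np.exp(s - m_new)
            l = l * rescale + p.sum(axis=-1, keepdims=True)
            acc = acc * rescale + p @ Vs[src]
            m = m_new
        outputs.append(acc / l)
    return np.concatenate(outputs)


rng = np.random.default_rng(0)
Q, K, V = (rng.standard_normal((32, 16)) for _ in range(3))
out_ring = ring_attention(Q, K, V, n_dev=4)
out_full = full_attention(Q, K, V)
print(np.allclose(out_ring, out_full))
```

In a real implementation the send of the next K,V chunk overlaps with the attention compute on the current one, so the communication is largely hidden when chunks are large enough.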
When training on 2048+ GPUs for weeks, hardware failures are guaranteed. A single GPU failure can halt a training run that costs $50K per hour of idle time. Frontier labs engineer for this:
| Failure Type | Frequency (2048 GPUs) | Detection | Recovery |
|---|---|---|---|
| GPU memory error (ECC) | ~1 per day | CUDA runtime error | Replace GPU, resume from checkpoint |
| NIC failure | ~1 per 3 days | NCCL timeout | Reroute traffic, resume |
| Node crash | ~1 per week | Heartbeat timeout | Exclude node, redistribute ranks |
| Software hang | ~2 per day | Training step time 3x normal | Kill, restart from checkpoint |
| Checkpoint corruption | Rare | Checksum validation | Fall back to previous checkpoint |
```python
# Compute parallelism configuration for any model-cluster pair
def compute_parallel_config(
    num_params_B: float,         # billions of parameters
    num_gpus: int,               # total GPU count
    hbm_per_gpu_GB: float = 80,
    num_heads: int = 64,
    num_layers: int = 80,
):
    P = num_params_B * 1e9

    # Memory breakdown (BF16 training with AdamW)
    weights_GB = P * 2 / 1e9     # BF16 weights
    grads_GB = P * 2 / 1e9       # BF16 gradients
    opt_GB = P * 12 / 1e9        # FP32 master + m + v
    total_static_GB = weights_GB + grads_GB + opt_GB

    # Strategy 1: Can we fit on 1 GPU? → DDP (30% headroom for activations)
    if total_static_GB * 1.3 < hbm_per_gpu_GB:
        return {"strategy": "DDP", "DP": num_gpus, "mem_per_gpu_GB": total_static_GB}

    # Strategy 2: Choose TP (must divide heads, max 8 for NVLink)
    tp = 1
    for t in [8, 4, 2]:
        if num_heads % t == 0 and weights_GB / t < hbm_per_gpu_GB * 0.5:
            tp = t
            break

    # Strategy 3: Choose PP (enough stages to fit layers)
    pp = 1
    while total_static_GB / (tp * pp) > hbm_per_gpu_GB * 0.6:
        pp *= 2
        if pp > 16:
            break

    # Remaining GPUs are data parallel
    dp = num_gpus // (tp * pp)

    # Per-GPU memory with FSDP across DP dimension
    layers_per_stage = num_layers // pp
    frac = layers_per_stage / num_layers
    mem_weights = weights_GB * frac / tp       # TP shards weights
    mem_grads = grads_GB * frac / (dp * tp)    # FSDP shards grads
    mem_opt = opt_GB * frac / (dp * tp)        # FSDP shards opt
    mem_per_gpu = mem_weights + mem_grads + mem_opt

    return {
        "strategy": f"TP={tp} x PP={pp} x DP={dp}",
        "tp": tp, "pp": pp, "dp": dp,
        "mem_per_gpu_GB": round(mem_per_gpu, 2),
        "total_static_GB": round(total_static_GB, 1),
        # Bubble fraction (P-1)/(M+P-1), assuming M = 4 * pp microbatches
        "bubble_frac": (pp - 1) / (pp - 1 + 4 * pp) if pp > 1 else 0,
    }


# Example: LLaMA 3 70B on 2048 H100s
config = compute_parallel_config(70, 2048)
print(config)
# {'strategy': 'TP=8 x PP=4 x DP=64', 'tp': 8, 'pp': 4, 'dp': 64, 'mem_per_gpu_GB': 4.85, ...}
```
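Let us examine the parallelism configurations reported for frontier-scale training runs. Some rows come from published papers and technical reports; rows marked as estimates are community guesses, and individual cells (MFU in particular) should be checked against the primary source before you quote them: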
| Model | Params | GPUs | TP | PP | DP | CP | MFU | Global Batch |
|---|---|---|---|---|---|---|---|---|
| LLaMA 3 8B | 8B | 512 H100 | 1 | 1 | 512 | 1 | 43% | 4M tokens |
| LLaMA 3 70B | 70B | 2048 H100 | 8 | 4 | 64 | 1 | 41% | 4M tokens |
| LLaMA 3 405B | 405B | 16384 H100 | 8 | 16 | 128 | 1 | 38% | 8M tokens |
| GPT-4 (estimated) | ~1.7T MoE | ~25000 A100 | 8 | ~64 | ~48 | ? | ~35% | ~60M tokens |
| Gemini 1.5 (est) | >1T | ~4096 TPUv5p | 16 | ~8 | ~32 | 4 | ~45% | ? |
| DeepSeek V3 | 671B MoE | 2048 H800 | 4 | 16 | 32 | 1 | 52% | 15M tokens |
Three major frameworks support distributed training. Knowing their trade-offs matters for system design interviews:
| Feature | PyTorch FSDP | DeepSpeed ZeRO | Megatron-LM |
|---|---|---|---|
| Maintained by | Meta (PyTorch team) | Microsoft | NVIDIA |
| Data parallelism | ZeRO-3 (native) | ZeRO-1/2/3 | DDP |
| Tensor parallelism | Via DTensor (limited) | Via Megatron-DeepSpeed | Native (best support) |
| Pipeline parallelism | Limited | Yes (1F1B, interleaved) | Yes (1F1B, interleaved) |
| Sequence parallelism | No | Ulysses SP | Yes |
| Mixed precision | BF16 native | BF16/FP16 | BF16/FP16 + FP8 (TransformerEngine) |
| Best for | Simple FSDP jobs, <100B params | Flexible configs, research | Max performance, 100B+ models |
| Learning curve | Low (PyTorch native) | Medium (config-driven) | High (custom model code) |
In practice, Meta uses Megatron-LM + FSDP for LLaMA training. Google uses its own internal framework (Pathways) on TPUs. DeepSpeed is popular for research labs and smaller companies. The choice often depends on what your cluster runs (NVIDIA GPUs → Megatron-LM, TPUs → JAX/Pax, budget-constrained → DeepSpeed).
The learning rate schedule interacts with parallelism in non-obvious ways. When you change the global batch size (by changing DP or gradient accumulation), the optimal learning rate changes too.
Training a model costs millions of dollars and takes weeks. Serving that model costs millions of dollars per month and runs indefinitely. At frontier labs, the serving infrastructure team often has a larger headcount than the training team. And the engineering challenges are fundamentally different: training optimizes for throughput (tokens per second across the cluster), serving optimizes for latency under load (time to first token for the user who just typed a query while 10,000 other users are also waiting).
This chapter covers the core systems that make LLM serving efficient: the two-phase serving model, continuous batching, speculative decoding, disaggregated serving, and KV cache management. By the end, you will be able to derive the cost of serving a 70B model and identify exactly where the bottlenecks are.
When a user sends a prompt to an LLM, the server does two fundamentally different things:
Phase 1: Prefill. The entire prompt (say, 500 tokens) is processed in a single forward pass. Every token can attend to all previous tokens in parallel. This is a large matrix multiply — compute-bound. The GPU tensor cores are busy. Memory bandwidth is not the bottleneck.
Phase 2: Decode. Output tokens are generated one at a time. Each forward pass processes exactly 1 new token (attending to all previous tokens via the KV cache). This is a tiny matrix multiply — memory-bandwidth-bound. The tensor cores are mostly idle, waiting for weights to be loaded from HBM.
Batching is the single most important optimization in LLM serving. Let us derive exactly why, starting from first principles.
Consider a 70B model on 4x H100 with TP=4. The model weights occupy 140/4 = 35 GB per GPU. During decode, each step requires loading all weights from HBM. The key insight: loading weights once for a batch of B tokens costs nearly the same as loading them for 1 token (at short context lengths, the extra KV cache and activation traffic is small in comparison).
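A back-of-the-envelope version of this argument, as a sketch. It assumes H100 SXM peak figures (3.35 TB/s HBM bandwidth, 989 TFLOPS dense BF16), and it ignores KV cache reads, tensor-parallel communication, and kernel overheads, so real throughput will be lower.

```python
HBM_BANDWIDTH = 3.35e12   # bytes/s, assumption: H100 SXM peak
PEAK_FLOPS = 989e12       # FLOP/s, assumption: H100 SXM dense BF16 peak

weight_bytes_per_gpu = 35e9  # 140 GB of BF16 weights split over TP=4

# One decode step must stream every weight from HBM once, whatever the batch size.
step_time_s = weight_bytes_per_gpu / HBM_BANDWIDTH

for batch in (1, 8, 64, 256):
    print(f"B={batch:>3}: ~{batch / step_time_s:,.0f} tokens/s (bandwidth-bound estimate)")

# Each BF16 weight (2 bytes) contributes 2 FLOPs per token, so arithmetic intensity
# is about B FLOPs per byte. Decode stays bandwidth-bound until B reaches the
# hardware's FLOPs-to-bandwidth ratio.
ridge_batch = PEAK_FLOPS / HBM_BANDWIDTH
print(f"step time: {step_time_s * 1e3:.1f} ms, compute-bound above B ~ {ridge_batch:.0f}")
```

Under these assumptions throughput scales almost linearly with batch size up to a few hundred concurrent sequences, which is why serving systems work so hard to keep batches full.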
In static batching, you collect B requests, process them together, and wait for ALL B requests to finish generating before accepting new ones. If request 1 generates 20 tokens and request 2 generates 500 tokens, request 1's GPU slot sits idle for 480 decode steps. For a batch of diverse-length requests, average GPU utilization can be as low as 30-40%.
Continuous batching (also called "in-flight batching" or "iteration-level batching") fixes this: as soon as any request finishes, its slot is immediately filled with a new request from the queue. The batch composition changes at every decode step.
Let us quantify the waste. Consider 4 requests with output lengths [30, 80, 150, 500]:
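The arithmetic for static batching, as a short sketch: every slot is held until the longest request finishes, so the batch occupies 4 slots for 500 decode steps.

```python
output_lens = [30, 80, 150, 500]

batch_size = len(output_lens)
steps = max(output_lens)          # static batch runs until the longest request finishes
slot_steps = batch_size * steps   # total slot-steps reserved
useful_steps = sum(output_lens)   # slot-steps that actually produced a token
wasted_steps = slot_steps - useful_steps
utilization = useful_steps / slot_steps

print(f"reserved: {slot_steps}, useful: {useful_steps}, wasted: {wasted_steps}")
print(f"static batching utilization: {utilization:.0%}")
```

With continuous batching and a non-empty queue, each of those 1,240 wasted slot-steps would instead be serving another request.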
The implementation requires a scheduler that tracks the state of every request (prefilling, decoding, waiting, done) and makes admission decisions at every iteration:
```python
class ContinuousBatchScheduler:
    def __init__(self, max_batch: int, max_tokens: int, kv_cache):
        self.max_batch = max_batch    # max concurrent requests
        self.max_tokens = max_tokens  # max total tokens in KV cache
        self.kv_cache = kv_cache      # KV cache manager exposing available() and free()
        self.running = []             # requests currently generating
        self.waiting = []             # queue of pending requests

    def run_prefill(self, requests):
        raise NotImplementedError     # model execution hook, supplied by the engine

    def run_decode(self, requests):
        raise NotImplementedError     # model execution hook, supplied by the engine

    def schedule_step(self):
        # Step 1: Remove finished requests (hit EOS or max_len)
        finished = [r for r in self.running if r.is_done()]
        self.running = [r for r in self.running if not r.is_done()]

        # Step 2: Free KV cache memory from finished requests
        for r in finished:
            self.kv_cache.free(r.request_id)

        # Step 3: Admit new requests from the queue
        while (self.waiting
               and len(self.running) < self.max_batch
               and self.kv_cache.available() > self.waiting[0].est_tokens):
            req = self.waiting.pop(0)
            req.state = "prefill"
            self.running.append(req)

        # Step 4: Separate prefill and decode requests
        prefill_reqs = [r for r in self.running if r.state == "prefill"]
        decode_reqs = [r for r in self.running if r.state == "decode"]

        # Step 5: Run prefill for new requests (compute-bound)
        if prefill_reqs:
            self.run_prefill(prefill_reqs)
            for r in prefill_reqs:
                r.state = "decode"

        # Step 6: Run one decode step for all decoding requests
        if decode_reqs:
            self.run_decode(decode_reqs)

        return finished
```
A subtle problem with continuous batching: when a new request with a long prompt (say, 10,000 tokens) enters the batch, its prefill takes hundreds of milliseconds. During this time, the decode requests in the same batch are blocked — they cannot generate their next token until the prefill completes. This is head-of-line blocking.
Chunked prefill solves this by breaking long prefills into smaller chunks (e.g., 512 tokens at a time). Between chunks, decode requests get a chance to run. The result: prefill spreads over multiple iterations, and decode latency stays bounded.
Many applications need the LLM to output valid JSON, SQL, or code that follows a grammar. Guided decoding constrains the token-by-token generation to follow a formal grammar or regex pattern by masking out invalid tokens at each step.
```python
import torch

# Guided JSON decoding: mask invalid tokens at each step.
# The tokenizer-specific helpers (tokens_starting_with, tokens_for_any_of, tokens_for,
# digit_tokens, all_tokens), current_field_type, and the FSM transitions are not shown.
class JSONGuide:
    def __init__(self, schema: dict):
        self.schema = schema
        self.state = "start"  # FSM state: start, key, value, etc.
        self.stack = []       # nesting stack: [object, array, ...]

    def get_valid_tokens(self, generated_so_far: str) -> set:
        """Return the set of token IDs that are valid at this position."""
        if self.state == "start":
            return tokens_starting_with("{")
        elif self.state == "key":
            # Only allow keys from schema
            valid_keys = self.schema["properties"].keys()
            return tokens_for_any_of([f'"{k}"' for k in valid_keys])
        elif self.state == "value":
            expected_type = self.current_field_type()
            if expected_type == "integer":
                return digit_tokens | tokens_for("-")
            elif expected_type == "string":
                return tokens_starting_with('"')
            # ... more types
        return all_tokens  # fallback: no constraint

    def mask_logits(self, logits, generated_so_far):
        """Push the logits of invalid tokens to a large negative value."""
        valid = self.get_valid_tokens(generated_so_far)
        mask = torch.full_like(logits, -1e9)
        mask[list(valid)] = 0
        return logits + mask
```
Autoregressive decoding is slow because you generate one token at a time. Speculative decoding accelerates this with a clever insight: use a small, fast draft model to generate K candidate tokens, then have the large target model verify all K tokens in a single forward pass. Because decode is memory-bandwidth-bound, that pass costs about the same as generating 1 token: the weights are loaded from HBM once either way, and the extra compute for K positions fits in otherwise idle tensor cores.
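How much this helps depends on how often the target model accepts the draft's tokens. A common estimate, following Leviathan et al. (2023), assumes each draft token is accepted independently with probability α; the expected number of tokens produced per target forward pass is then (1 - α^(K+1)) / (1 - α). The sketch below also folds in the draft model's cost. The α and cost-ratio values are illustrative, not measurements.

```python
def expected_tokens_per_verify(alpha: float, k: int) -> float:
    """Expected tokens emitted per target-model pass with k drafted tokens."""
    if alpha >= 1.0:
        return float(k + 1)  # every draft token accepted, plus one bonus token
    return (1 - alpha ** (k + 1)) / (1 - alpha)


def estimated_speedup(alpha: float, k: int, draft_cost_ratio: float) -> float:
    """Speedup over plain decoding; draft_cost_ratio = draft step time / target step time."""
    cost_per_round = k * draft_cost_ratio + 1  # k draft steps + 1 target verification
    return expected_tokens_per_verify(alpha, k) / cost_per_round


for k in (2, 4, 8):
    print(f"K={k}: {expected_tokens_per_verify(0.8, k):.2f} tokens/pass, "
          f"speedup ~{estimated_speedup(0.8, k, 0.05):.2f}x")
```

The returns diminish as K grows: later draft tokens are only useful if every earlier one was accepted, while the draft cost keeps rising linearly.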
Prefill and decode have opposite hardware requirements. Prefill is compute-bound — it benefits from high FLOPS (many tensor cores). Decode is memory-bandwidth-bound — it benefits from high HBM bandwidth and large batch sizes. Running both on the same GPU creates contention: a long prefill blocks decode for all batched requests, causing latency spikes.
Disaggregated serving runs prefill and decode on separate GPU pools:
| Property | Prefill GPU Pool | Decode GPU Pool |
|---|---|---|
| Workload | Process input prompts | Generate output tokens |
| Bottleneck | Compute (tensor cores) | Memory bandwidth (HBM) |
| Optimal hardware | High FLOPS (H100 SXM) | High bandwidth, less compute (H100 NVL, or even A100) |
| Batch strategy | Large batches of prompts | Large batches of single-token decodes |
| KV cache | Generate KV cache, transfer to decode pool | Store and extend KV cache |
| Utilization | Near 100% compute | Near 100% bandwidth |
The critical challenge is transferring the KV cache from prefill to decode. For a 70B model with 80 layers, 8 KV heads, d_head=128, and a 4096-token prompt in BF16:
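The size follows directly from the shapes. In the sketch below the two link speeds are assumptions chosen for illustration (roughly a 400 Gb/s network link and an NVLink-class connection), not figures from a specific deployment.

```python
def kv_cache_bytes(layers: int, kv_heads: int, head_dim: int, tokens: int, bytes_per_elem: int = 2) -> int:
    # 2 (K and V) x layers x KV heads x head dim x tokens x bytes per element
    return 2 * layers * kv_heads * head_dim * tokens * bytes_per_elem


size = kv_cache_bytes(layers=80, kv_heads=8, head_dim=128, tokens=4096)
print(f"KV cache for one request: {size / 1e9:.2f} GB")

# Assumed link bandwidths in bytes/s (illustrative only)
links = {"network link (~50 GB/s)": 50e9, "NVLink-class (~900 GB/s)": 900e9}
for name, bandwidth in links.items():
    print(f"{name}: {size / bandwidth * 1e3:.1f} ms")
```

Tens of milliseconds over the network is comparable to the TTFT budget itself, which is why disaggregated systems stream the cache layer by layer while prefill is still running.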
The KV cache is the largest per-request memory consumer. For a serving system handling hundreds of concurrent requests, KV cache management becomes a critical systems problem.
PagedAttention (from vLLM) is the breakthrough that made efficient KV cache management possible. Instead of allocating contiguous memory for each request's KV cache (which leads to fragmentation when requests have variable lengths), PagedAttention divides KV cache memory into fixed-size pages (typically 16 tokens per page):
```python
import torch

# PagedAttention KV Cache Manager (simplified)
class PagedKVCacheManager:
    def __init__(self, num_pages: int, page_size: int = 16, num_layers: int = 80,
                 num_heads: int = 8, head_dim: int = 128):
        self.page_size = page_size
        # Physical pages: [num_pages, 2, num_layers, num_heads, page_size, head_dim]
        # 2 = key and value
        self.physical_pages = torch.zeros(
            num_pages, 2, num_layers, num_heads, page_size, head_dim,
            dtype=torch.bfloat16, device="cuda"
        )
        self.free_pages = list(range(num_pages))
        self.page_tables = {}  # request_id → [page_indices]
        self.seq_lens = {}     # request_id → number of tokens stored

    def allocate(self, request_id: str, num_tokens: int):
        num_pages_needed = (num_tokens + self.page_size - 1) // self.page_size
        if len(self.free_pages) < num_pages_needed:
            return False  # out of memory: the caller must preempt a request
        pages = [self.free_pages.pop() for _ in range(num_pages_needed)]
        self.page_tables[request_id] = pages
        self.seq_lens[request_id] = num_tokens
        return True

    def get_seq_len(self, request_id: str) -> int:
        return self.seq_lens[request_id]

    def append_token(self, request_id: str):
        pages = self.page_tables[request_id]
        current_len = self.get_seq_len(request_id)
        # The last page is full when the length is a multiple of page_size
        if current_len % self.page_size == 0:
            # Need a new page
            if not self.free_pages:
                return False
            pages.append(self.free_pages.pop())
        self.seq_lens[request_id] = current_len + 1
        return True

    def free(self, request_id: str):
        pages = self.page_tables.pop(request_id, [])
        self.seq_lens.pop(request_id, None)
        self.free_pages.extend(pages)
```
| Metric | Definition | What It Measures | Target (chatbot) |
|---|---|---|---|
| TTFT | Time from request arrival to first output token | Prefill speed + queueing delay | < 500 ms (p95) |
| ITL | Inter-token latency: time between consecutive output tokens | Decode speed per token | < 50 ms (p95) for real-time feel |
| E2E Latency | Time from request to last token | TTFT + (output_len × ITL) | Depends on output length |
| Throughput | Total output tokens per second across all requests | System capacity | Maximize (cost efficiency) |
| TPS/user | Tokens per second experienced by one user | Perceived speed | > 20 TPS for "fast" feeling |
Cloud providers and API companies price LLM inference in dollars per million tokens. Let us derive this cost from first principles for serving LLaMA 3 70B:
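The core of the calculation is a single ratio: what the hardware costs per second, divided by how many tokens it produces per second. The sketch below uses placeholder inputs (GPU price, GPU count, sustained throughput) that are assumptions for illustration; substitute measured values for a real estimate.

```python
def cost_per_million_tokens(gpu_hour_usd: float, num_gpus: int, tokens_per_sec: float) -> float:
    """Serving cost in USD per 1M output tokens for one model replica."""
    cluster_usd_per_sec = gpu_hour_usd * num_gpus / 3600
    return cluster_usd_per_sec / tokens_per_sec * 1e6


# Hypothetical inputs: 4 GPUs at $2.50 per GPU-hour, 2,000 output tokens/s sustained
cost = cost_per_million_tokens(gpu_hour_usd=2.50, num_gpus=4, tokens_per_sec=2000)
print(f"${cost:.2f} per million output tokens")

# Same hardware at a quarter of the throughput (for example, poorly batched traffic)
print(f"${cost_per_million_tokens(2.50, 4, 500):.2f} per million output tokens")
```

Since hardware cost is fixed per hour, cost per token is inversely proportional to throughput, so every batching and utilization gain described above flows straight into the price.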
Let us derive the full serving profile for LLaMA 3 70B on a TPU v5e pod slice of 8 chips, using tensor parallelism.
When KV cache memory runs out, the scheduler must decide: reject new requests (increasing queue wait time) or preempt an existing request (evict its KV cache, restart it later). Sophisticated schedulers use priority-based preemption:
```python
# Priority preemption: evict lowest-priority request when memory is full
def preempt_if_needed(self, new_request):
    """Evict a running request to make room for a higher-priority one."""
    if self.kv_cache.available() >= new_request.est_tokens:
        return True  # enough memory, no preemption needed

    # Find lowest-priority running request
    candidates = sorted(
        self.running,
        key=lambda r: (r.priority, -r.tokens_generated)
    )

    # Evict requests with LOWER priority than the new one
    for victim in candidates:
        if victim.priority >= new_request.priority:
            break  # don't preempt equal or higher priority

        # Swap out: save KV cache to CPU memory (or discard + recompute later)
        self.kv_cache.swap_out(victim.request_id)
        self.running.remove(victim)
        victim.state = "preempted"
        victim.preempt_count += 1
        self.waiting.insert(0, victim)  # re-queue at front

        if self.kv_cache.available() >= new_request.est_tokens:
            return True

    return False  # cannot make room even after preemption
```
The KV cache is a major memory consumer during serving. Quantizing it from BF16 to FP8 or even INT4 can double or quadruple the number of concurrent requests:
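A sketch of the capacity math for a 70B-class model (80 layers, 8 KV heads, d_head=128). The 180 GB KV budget is an assumption (4x 80 GB GPUs minus 140 GB of weights, with no allowance for activations or fragmentation), and the small overhead of quantization scales is ignored.

```python
def kv_bytes_per_token(layers: int = 80, kv_heads: int = 8, head_dim: int = 128,
                       bytes_per_elem: float = 2.0) -> float:
    # 2 (K and V) x layers x KV heads x head dim x bytes per element
    return 2 * layers * kv_heads * head_dim * bytes_per_elem


def max_concurrent_requests(kv_budget_bytes: float, context_len: int, bytes_per_elem: float) -> int:
    per_request = kv_bytes_per_token(bytes_per_elem=bytes_per_elem) * context_len
    return int(kv_budget_bytes // per_request)


kv_budget = 180e9  # assumed memory left for KV cache across the TP group
capacity = {}
for name, width in [("BF16", 2.0), ("FP8", 1.0), ("INT4", 0.5)]:
    capacity[name] = max_concurrent_requests(kv_budget, context_len=8192, bytes_per_elem=width)
    print(f"{name}: {capacity[name]} concurrent 8K-token requests")
```

Because decode throughput scales with batch size while it is bandwidth-bound, the extra concurrency translates fairly directly into tokens per second.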
In a chatbot, users send multiple messages in a conversation. Each new message requires attending to the full conversation history. Without optimization, every user message triggers a complete prefill of the entire conversation — wasting compute on tokens already processed.
The solution is KV cache persistence: keep the KV cache from previous turns alive between requests. When the user sends a new message, only the new tokens need prefill. This requires:
| Challenge | Solution | Trade-off |
|---|---|---|
| KV cache occupies memory between turns | TTL-based eviction (expire after 5-30 min idle) | Memory vs latency for returning users |
| User may never return | LRU eviction when memory pressure is high | Cold-start penalty for evicted users |
| Long conversations grow unboundedly | Sliding window attention or summary compression | Context accuracy vs memory |
| Scaling across servers | Session affinity (route user to same server) | Load balancing flexibility |
Every system we have studied so far uses LLMs as standalone generators: prompt in, text out. But the most exciting frontier in 2025 is using LLMs as components inside larger optimization loops. The LLM does not solve the problem — it proposes candidate solutions that an automated evaluator scores, and the best candidates are fed back to the LLM to produce better proposals. This is LLM-guided search.
This pattern has already produced genuine scientific discoveries: new algorithms for matrix multiplication, state-of-the-art solutions to open combinatorics problems, and improvements to real-world infrastructure systems. It is the closest thing we have to "AI doing research."
Every LLM-guided search system follows the same four-step loop:
Google DeepMind's FunSearch (2023) was the first system to demonstrate that LLMs can make genuine mathematical discoveries. The name stands for "searching in the space of functions" — the LLM generates Python functions, and an evaluator scores them.
FunSearch tackled the cap set problem, a notorious open problem in combinatorics. A cap set is a subset of points in F_3^n (the n-dimensional vector space over the field with 3 elements) such that no three points are collinear. The question: how large can a cap set be?
Before FunSearch, the largest cap set known in F_3^8 had 496 points. FunSearch found one of size 512, using a construction that humans had not considered. It also improved the best known lower bound on how fast cap sets can grow as the dimension increases.
To understand what FunSearch actually discovered, we need to understand the problem it solved. A cap set in F_3^n is a subset S of {0, 1, 2}^n such that no three distinct elements a, b, c in S satisfy a + b + c = 0 (mod 3). In F_3^n this is the same as saying that no three points lie on a line, or equivalently that no three points form an arithmetic progression.
FunSearch does not search directly for cap sets. Instead, it evolves a construction function — a Python function that takes the dimension n and returns a set of points. The evaluator checks (a) that the returned set is actually a cap set (no collinear triples), and (b) how large it is. Larger = better score.
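The evaluator side of this is simple enough to sketch. The checker below uses the fact that for two distinct points a and b, the unique third point on their line is c = -(a + b) mod 3, so a set is a cap set exactly when that c is never in the set. The names `is_cap_set` and `score_construction` are illustrative, and scoring an invalid set as 0 is an assumption rather than FunSearch's documented behavior.

```python
from itertools import combinations


def is_cap_set(points) -> bool:
    """True if no three distinct points in F_3^n sum to zero (mod 3)."""
    pts = set(map(tuple, points))
    for a, b in combinations(pts, 2):
        # The third point completing the line through a and b
        c = tuple((-(x + y)) % 3 for x, y in zip(a, b))
        if c in pts and c != a and c != b:
            return False
    return True


def score_construction(points) -> int:
    """Evaluator sketch: size of the set if it is a valid cap set, else 0 (assumed)."""
    pts = set(map(tuple, points))
    return len(pts) if is_cap_set(pts) else 0


square = [(0, 0), (0, 1), (1, 0), (1, 1)]  # a maximum cap set in F_3^2
line = [(0, 0), (1, 1), (2, 2)]            # three collinear points
print(score_construction(square), score_construction(line))
```

The check is quadratic in the size of the set, which keeps evaluation cheap enough to score many candidate constructions.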
FunSearch does not use a single population of solutions. It maintains multiple islands — independent populations that evolve separately. Periodically, the best solution from one island migrates to another. This prevents premature convergence (all candidates becoming too similar).
Beyond pure mathematics, FunSearch discovered new heuristics for online bin packing — a classic NP-hard optimization problem with real-world applications (packing containers, scheduling jobs, memory allocation).
The problem: items of various sizes arrive one at a time. You must place each item into a bin (capacity 1.0) immediately, without seeing future items. The goal: minimize the total number of bins used.
The standard online baselines are First Fit and Best Fit (First Fit Decreasing has to sort the items first, so it is an offline heuristic). FunSearch discovered a new priority function that outperformed these baselines on standard benchmarks. The discovered function was non-obvious: a piecewise function with specific breakpoints that no human had designed.
```python
# FunSearch-discovered bin packing heuristic (simplified)
# The LLM evolved this priority function over 10,000+ generations
def priority(item_size: float, bin_remaining: float) -> float:
    """Score for placing item into a bin. Higher = preferred."""
    waste = bin_remaining - item_size
    if waste < 0:
        return -1e9  # does not fit
    if waste < 0.01:
        return 1e6   # nearly perfect fit: always prefer
    if waste < 0.1:
        return 500 - waste * 1000
    if waste < 0.33:
        return 100 - waste * 200
    # Key insight: prefer bins that leave specific "useful" gaps
    # (gaps that commonly-sized future items could fill)
    if 0.49 < waste < 0.51:
        return 80    # half-empty bins are useful
    if 0.32 < waste < 0.35:
        return 60    # one-third gaps are useful
    return 10 - waste * 5
```
Google DeepMind's AlphaEvolve (2025) extends FunSearch in three critical ways:
| Feature | FunSearch | AlphaEvolve |
|---|---|---|
| Code scope | Evolves a single function | Evolves entire multi-file programs |
| Objectives | Single scalar score | Multi-objective (Pareto frontier) |
| LLM | One model (Codey/PaLM) | Ensemble: Gemini Flash (fast exploration) + Gemini Pro (deep reasoning) |
| Evaluation | Sandbox execution | Sandbox + formal verification where possible |
| Context | Best programs + spec | Best programs + spec + diff history + failure analysis |
AlphaEvolve's most famous result: discovering a new algorithm for 4×4 complex matrix multiplication that uses fewer scalar multiplications than Strassen's algorithm. Specifically, it found a decomposition that uses 48 multiplications for 4×4 matrices, improving on the 49 previously known. This is a genuine mathematical discovery — the first improvement in this specific problem in decades.
AlphaEvolve's architecture is worth studying in detail because it represents the state of the art in LLM-guided search as of 2025. Let us trace the lifecycle of a single evolution step:
FunSearch and AlphaEvolve are the most famous, but the LLM-guided search paradigm has spawned many systems:
| System | Year | Lab | Domain | Key Innovation |
|---|---|---|---|---|
| FunSearch | 2023 | DeepMind | Math, combinatorics | Island model + LLM mutation |
| AlphaEvolve | 2025 | DeepMind | Math, infrastructure | Multi-file, multi-objective, LLM ensemble |
| EvoPrompt | 2023 | Tsinghua / Microsoft | Prompt optimization | Evolve prompts instead of code |
| ReEvo | 2024 | Academic | Combinatorial optimization | LLM reflects on failed attempts |
| Eureka | 2023 | NVIDIA | Robot control | LLM evolves reward functions |
| AIDE | 2024 | Weco AI | ML research | LLM writes + runs experiments in a tree search |
The common thread: an LLM generates candidates in a domain where automated evaluation is possible, and selection pressure drives improvement over hundreds or thousands of generations.
Andrej Karpathy's concept of Autoresearch (2024-2025) extends LLM-guided search to the full research workflow. Instead of searching for a single function, the system:
Here is a complete, runnable implementation of the core FunSearch loop for evolving a bin-packing priority function:
```python
import json
import random
import subprocess
import sys

from openai import OpenAI

client = OpenAI()  # reads OPENAI_API_KEY from the environment

# ── Problem: evolve a priority function for bin packing ──
SPEC = """You are evolving a Python function called `priority(item_size, bin_remaining)`
that returns a float score. Higher score = preferred bin for the item.
The function is used to solve online bin packing: items arrive one at a time,
each must be placed in a bin (capacity 1.0) immediately.
Goal: minimize total bins used across the test suite."""

EVAL_HARNESS = """
import json
import random
random.seed(42)

# Generate test instances
def make_instance(n=100):
    return [random.uniform(0.1, 0.9) for _ in range(n)]

def evaluate(priority_fn, instances):
    total_bins = 0
    for items in instances:
        bins = []  # list of remaining capacities
        for item in items:
            scores = [(priority_fn(item, b), i) for i, b in enumerate(bins) if b >= item]
            if scores:
                _, best_i = max(scores)
                bins[best_i] -= item
            else:
                bins.append(1.0 - item)
        total_bins += len(bins)
    return -total_bins  # negative because fewer bins = better, but we maximize score

instances = [make_instance(200) for _ in range(20)]

{candidate_code}

score = evaluate(priority, instances)
print(json.dumps({"score": score}))
"""

FENCE = "`" * 3  # Markdown code fence that LLM replies are often wrapped in


class Island:
    def __init__(self, capacity=50):
        self.programs = []  # (score, code), sorted by score descending
        self.capacity = capacity

    def insert(self, score, code):
        self.programs.append((score, code))
        self.programs.sort(key=lambda x: x[0], reverse=True)
        if len(self.programs) > self.capacity:
            self.programs.pop()

    def sample_best(self, k=2):
        # Truncation selection: sample parents from the top third
        top = self.programs[:max(1, len(self.programs) // 3)]
        return random.sample(top, min(k, len(top)))


def extract_code(text: str) -> str:
    """Strip a Markdown code fence if the LLM wrapped its answer in one."""
    if FENCE in text:
        text = text.split(FENCE)[1].removeprefix("python")
    return text.strip()


def evaluate_candidate(code: str, timeout: int = 30) -> float:
    """Run candidate in a subprocess, return its score (-1e9 on any failure).

    A subprocess with a timeout is NOT a security sandbox; isolate it
    (container, restricted user) before running untrusted code.
    """
    full_code = EVAL_HARNESS.replace("{candidate_code}", code)
    try:
        result = subprocess.run(
            [sys.executable, "-c", full_code],
            capture_output=True, text=True, timeout=timeout,
        )
        # The harness prints the JSON score on the last line of stdout
        return json.loads(result.stdout.strip().splitlines()[-1])["score"]
    except Exception:
        return -1e9  # crashed, timed out, or printed no valid score


def generate_candidate(parents: list) -> str:
    """Ask LLM to evolve a new candidate from parent programs."""
    parent_text = "\n\n".join(f"# Score: {s}\n{c}" for s, c in parents)
    response = client.chat.completions.create(
        model="gpt-4o",
        temperature=1.0,
        messages=[{
            "role": "user",
            "content": f"""{SPEC}

Here are the best solutions found so far:

{parent_text}

Write an improved `priority` function. Be creative: try a different
mathematical form. Output ONLY the Python function.""",
        }],
    )
    return extract_code(response.choices[0].message.content)


# ── Main evolution loop ──
NUM_ISLANDS = 5
islands = [Island() for _ in range(NUM_ISLANDS)]

# Seed with a simple baseline
seed = """def priority(item_size, bin_remaining):
    if bin_remaining < item_size:
        return -1e9
    return -(bin_remaining - item_size)  # best fit: prefer the tightest bin"""

seed_score = evaluate_candidate(seed)
for island in islands:
    island.insert(seed_score, seed)

best_ever = seed_score
for gen in range(500):
    # Pick a random island
    island = random.choice(islands)
    parents = island.sample_best(2)

    # Generate and evaluate
    candidate = generate_candidate(parents)
    score = evaluate_candidate(candidate)
    island.insert(score, candidate)

    if score > best_ever:
        best_ever = score
        print(f"Gen {gen}: NEW BEST = {score}")

    # Migrate every 50 generations: copy one island's best program to another
    if gen % 50 == 0 and gen > 0:
        src, dst = random.sample(islands, 2)
        if src.programs:
            dst.insert(*src.programs[0])
```
Not every problem benefits from LLM-guided search. Understanding the failure modes is as important as understanding the successes:
| Antipattern | Why It Fails | Example |
|---|---|---|
| Non-automatable evaluation | If a human must judge quality, the loop is too slow (hours per generation vs seconds) | Generating creative writing, designing UIs, composing music |
| Tiny search space | If there are only a few valid programs, random search works as well as LLM-guided | Choosing between 3 sorting algorithms |
| Deceptive fitness landscape | If small changes cause large score swings, the LLM cannot learn meaningful gradients | Cryptographic hash optimization |
| Insufficient LLM knowledge | If the domain has no representation in training data, proposals are random | Novel chemistry synthesis (pre-2024 data cutoff) |
| Evaluation is too expensive | If each evaluation takes hours (e.g., training a model), you get too few generations | Neural architecture search with full training |
The quality of the LLM's proposals depends critically on how you prompt it. FunSearch and AlphaEvolve both discovered that specific prompt structures dramatically improve the mutation quality:
If you want to apply LLM-guided search to your own problem, here is the engineering checklist:
You have written the training loop. You have configured the parallelism. The job is running. But the MFU reads 22%. Nearly 80% of the cluster's theoretical capability is wasted. Where is the performance going?
To answer this question, you need to understand how the hardware actually works — not at the CUDA-kernel level (Chapter 2 covered that), but at the architectural level: how data flows through the chip, where bottlenecks form, and how to predict performance before writing a single line of code. This is the roofline model, and it is the single most useful tool in a performance engineer's toolkit.
An NVIDIA H100 SXM is organized as a hierarchy of compute and memory:
| Component | Specification | Role |
|---|---|---|
| Streaming Multiprocessors (SMs) | 132 SMs | Each SM is an independent processor with its own registers, shared memory, and warp schedulers |
| Tensor Cores (per SM) | 4 fourth-gen Tensor Cores | Matrix multiply accelerators. Each performs 16×16 matrix ops per cycle in BF16. |
| Peak BF16 FLOPS | 989 TFLOP/s | Theoretical maximum with all tensor cores active |
| HBM3 Memory | 80 GB, 3.35 TB/s bandwidth | Main memory. All model weights, KV cache, activations live here. |
| L2 Cache | 50 MB | Shared across all SMs. Caches frequently accessed HBM data. |
| Shared Memory (per SM) | 228 KB | Programmer-managed scratchpad. Faster than L2, used for tiling. |
| Register File (per SM) | 256 KB | Fastest storage. Each thread has up to 255 registers. |
| NVLink | 900 GB/s (bidirectional) | GPU-to-GPU interconnect within a node (8 GPUs) |
Google's TPU v5e is a fundamentally different design philosophy. Where GPUs evolved from graphics (many small cores, complex scheduling), TPUs were built from scratch for matrix multiplication.
| Component | Specification | Role |
|---|---|---|
| MXU (Matrix Multiply Unit) | 4 MXUs per chip (one TensorCore) | Systolic array: 128×128 BF16 multiply-accumulate units, i.e. 16,384 multiply-accumulates per cycle per MXU. |
| Peak BF16 FLOPS | 197 TFLOP/s | Lower than H100, but higher utilization in practice |
| HBM Memory | 16 GB, 819 GB/s bandwidth | Smaller and slower than H100. Constrains model size per chip. |
| VMEM (Vector Memory) | 16-32 MB | On-chip SRAM. Equivalent to GPU shared memory, but larger. |
| VPU (Vector Processing Unit) | 1 per chip | Handles non-matmul operations: activations, normalization, softmax |
| ICI (Inter-Chip Interconnect) | ~400 GB/s per link, 6 links | Direct chip-to-chip connection in a 3D torus topology. No switch needed. |
A systolic array is a grid of processing elements (PEs) where data flows rhythmically through the grid, like a heartbeat (hence "systolic"). Each PE performs one multiply-accumulate per cycle and passes the result to its neighbor.
For a 128×128 systolic array computing C = A × B:
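The dataflow is easier to see on a tiny array. The sketch below simulates an output-stationary systolic array in plain Python: each PE keeps its own accumulator, values of A move one PE to the right per cycle, and values of B move one PE down per cycle. The function name `systolic_matmul` and the output-stationary layout are assumptions chosen for illustration; real MXUs may use a different (for example weight-stationary) arrangement.

```python
def systolic_matmul(A, B):
    """Simulate C = A x B on an n x m grid of PEs, one cycle at a time."""
    n, k, m = len(A), len(B), len(B[0])
    acc = [[0] * m for _ in range(n)]    # accumulator held in each PE
    a_reg = [[0] * m for _ in range(n)]  # A value currently in each PE (moves right)
    b_reg = [[0] * m for _ in range(n)]  # B value currently in each PE (moves down)
    cycles = n + m + k - 2               # last PE finishes after this many cycles

    for t in range(cycles):
        # Visit PEs from the far corner so neighbours still hold last cycle's values.
        for i in reversed(range(n)):
            for j in reversed(range(m)):
                if j > 0:
                    a_in = a_reg[i][j - 1]           # from the left neighbour
                else:
                    a_in = A[i][t - i] if 0 <= t - i < k else 0  # row i is skewed by i cycles
                if i > 0:
                    b_in = b_reg[i - 1][j]           # from the neighbour above
                else:
                    b_in = B[t - j][j] if 0 <= t - j < k else 0  # column j is skewed by j cycles
                a_reg[i][j] = a_in
                b_reg[i][j] = b_in
                acc[i][j] += a_in * b_in             # one multiply-accumulate per PE per cycle
    return acc, cycles


C, cycles = systolic_matmul([[1, 2], [3, 4]], [[5, 6], [7, 8]])
print(C, cycles)
```

The skewed inputs are the point: operands for the same inner-product index meet in the right PE at the right cycle without any PE ever addressing memory directly.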
The roofline model predicts the maximum achievable performance of any operation based on a single metric: arithmetic intensity (operations per byte of memory traffic). The model has two regimes:
Let us compute the arithmetic intensity for the key operations in a transformer:
| Operation | FLOPS | Bytes | Intensity (I) | Bound |
|---|---|---|---|---|
| Linear layer (batch B, in d, out d) | $2Bd^2$ | $2d^2 + 2Bd + 2Bd$ | ≈ B (for d ≫ B) | Compute if B > 295 |
| LayerNorm (batch B, dim d) | 5Bd | 4Bd (read + write) | 1.25 | Memory-bound (always) |
| Softmax (batch B, seq S) | 3BS | 4BS | 0.75 | Memory-bound (always) |
| GeLU (batch B, dim d) | ~8Bd | 4Bd | 2 | Memory-bound (always) |
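| Attention $QK^T$ (batch B, heads H, seq S, head dim $d_h$) | $2BHS^2 d_h$ | $4BHSd_h + 2BHS^2$ | $\frac{S d_h}{S + 2d_h}$, which approaches $d_h$ for large S | Memory-bound when the score matrix is written to HBM ($d_h$ = 128 < 295); fused kernels such as FlashAttention avoid that traffic |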
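The thresholds in the table come from the ridge point, peak FLOP/s divided by memory bandwidth. A minimal sketch of the roofline calculation, using the H100 and TPU v5e numbers quoted in the spec tables; the helper names (`ridge_point`, `attainable_flops`, `linear_intensity`) are hypothetical:

```python
# Peak BF16 FLOP/s and HBM bandwidth (bytes/s), as quoted in the spec tables.
CHIPS = {
    "H100": {"peak_flops": 989e12, "bandwidth": 3.35e12},
    "TPU v5e": {"peak_flops": 197e12, "bandwidth": 819e9},
}


def ridge_point(chip):
    """Arithmetic intensity (FLOPs/byte) where the two roofline regimes meet."""
    return chip["peak_flops"] / chip["bandwidth"]


def attainable_flops(intensity, chip):
    """Roofline: performance is capped by bandwidth * intensity or by peak compute."""
    return min(chip["peak_flops"], intensity * chip["bandwidth"])


def linear_intensity(batch_tokens, d, bytes_per_elem=2):
    """Intensity of a [B, d] x [d, d] linear layer in BF16."""
    flops = 2 * batch_tokens * d * d
    bytes_moved = bytes_per_elem * (d * d + 2 * batch_tokens * d)  # weights + input + output
    return flops / bytes_moved


for name, chip in CHIPS.items():
    print(f"{name}: ridge point = {ridge_point(chip):.0f} FLOPs/byte")

for b in (1, 64, 512):
    i = linear_intensity(b, 4096)
    frac = attainable_flops(i, CHIPS["H100"]) / CHIPS["H100"]["peak_flops"]
    print(f"B={b}: intensity={i:.1f}, at most {frac:.0%} of H100 peak")
```

Anything with intensity below the ridge point is limited by bandwidth no matter how well the kernel is tuned, which is why LayerNorm, softmax, and GeLU stay memory-bound on both chips.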
| Dimension | H100 (GPU) | TPU v5e | Winner for... |
|---|---|---|---|
| Peak BF16 FLOPS | 989 TFLOP/s | 197 TFLOP/s | GPU: raw compute power |
| HBM Bandwidth | 3.35 TB/s | 0.819 TB/s | GPU: memory-bound ops |
| HBM Capacity | 80 GB | 16 GB | GPU: large models per chip |
| On-chip SRAM | 50 MB (L2) + 228KB/SM shared | 16-32 MB VMEM | Comparable |
| Interconnect BW | 900 GB/s NVLink (8 GPUs) | ~2.4 TB/s ICI (3D torus) | TPU: multi-chip communication |
| Interconnect topology | All-to-all via NVSwitch (8 GPUs) | 3D torus (thousands of chips) | TPU: massive scale |
| Programming model | CUDA (explicit), Triton | XLA (compiler-driven, implicit) | GPU: flexibility; TPU: ease |
| Matmul efficiency | High but requires careful tuning | Very high out of the box (systolic array) | TPU: less tuning needed |
| Cost per TFLOP | Higher ($) | Lower ($/TFLOP) | TPU: cost efficiency |
| Non-matmul ops | Fast (many SMs, flexible) | Slower (single VPU) | GPU: diverse workloads |
For distributed training, the interconnect between devices is often the bottleneck, not the compute. The two architectures take fundamentally different approaches:
Let us predict the performance of a [4096, 4096] × [4096, 4096] BF16 matmul on both chips, then verify against roofline predictions.
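A back-of-the-envelope version of that prediction, assuming each operand is read from HBM once and the output is written once (real kernels can re-read tiles, so treat this as a lower bound on time). `predict_matmul` is a hypothetical helper, and the chip numbers are the ones from the spec tables:

```python
def predict_matmul(m, k, n, peak_flops, bandwidth, bytes_per_elem=2):
    """Roofline estimate for an [m, k] x [k, n] matmul."""
    flops = 2 * m * k * n                                  # multiply + add per MAC
    bytes_moved = bytes_per_elem * (m * k + k * n + m * n)  # read A, read B, write C
    t_compute = flops / peak_flops
    t_memory = bytes_moved / bandwidth
    return {
        "flops": flops,
        "bytes": bytes_moved,
        "intensity": flops / bytes_moved,
        "time_s": max(t_compute, t_memory),
        "bound": "compute" if t_compute >= t_memory else "memory",
    }


h100 = predict_matmul(4096, 4096, 4096, peak_flops=989e12, bandwidth=3.35e12)
v5e = predict_matmul(4096, 4096, 4096, peak_flops=197e12, bandwidth=819e9)

for name, r in (("H100", h100), ("TPU v5e", v5e)):
    print(f"{name}: intensity={r['intensity']:.0f} FLOPs/byte, "
          f"{r['bound']}-bound, >= {r['time_s'] * 1e3:.3f} ms")
```

With an intensity of roughly 1365 FLOPs/byte, the matmul sits well to the right of both ridge points, so the estimate is simply FLOPs divided by peak: about 0.14 ms on H100 and about 0.70 ms on TPU v5e at 100% utilization.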
On TPUs, you rarely write kernels by hand (Pallas is the escape hatch when you must). Instead, JAX traces your Python code and lowers it to HLO (High-Level Operations), XLA optimizes and compiles that HLO for the chip, and the TPU runtime schedules the execution. Profiling tells you where time is spent in this compiled execution.
```python
import jax
import jax.numpy as jnp
from jax import profiler

N_HEADS = 32  # head count (assumed here; head dim = D / N_HEADS = 128)


# ── Step 1: Define the computation ──
def transformer_block(x, w_q, w_k, w_v, w_o, w_ff1, w_ff2, w_ln):
    B, S, D = x.shape
    d_h = D // N_HEADS

    # Layer norm (memory-bound); w_ln is the learned scale, shared by both norms for brevity
    x_norm = jax.nn.standardize(x, axis=-1) * w_ln

    # Attention projections (compute-bound for large batch)
    q = (x_norm @ w_q).reshape(B, S, N_HEADS, d_h)  # [B,S,D] @ [D,D] → [B,S,H,d_h]
    k = (x_norm @ w_k).reshape(B, S, N_HEADS, d_h)
    v = (x_norm @ w_v).reshape(B, S, N_HEADS, d_h)

    # Attention scores: [B, H, S, S]
    scores = jnp.einsum("bqhd,bkhd->bhqk", q, k) * (d_h ** -0.5)
    attn = jax.nn.softmax(scores, axis=-1)

    # Attention output
    out = jnp.einsum("bhqk,bkhd->bqhd", attn, v).reshape(B, S, D)
    x = x + out @ w_o  # residual (memory-bound)

    # FFN (compute-bound)
    x_norm = jax.nn.standardize(x, axis=-1) * w_ln
    h = jax.nn.gelu(x_norm @ w_ff1)  # [B,S,D] @ [D,4D] → [B,S,4D]
    x = x + h @ w_ff2                # [B,S,4D] @ [4D,D] → [B,S,D]
    return x


# ── Step 2: JIT compile ──
transformer_jit = jax.jit(transformer_block)

# ── Step 3: Profile ──
B, S, D = 8, 2048, 4096
keys = jax.random.split(jax.random.PRNGKey(0), 7)


def init(key, shape):
    """Small random BF16 weights; enough for profiling, not for training."""
    return jax.random.normal(key, shape, dtype=jnp.bfloat16) * 0.02


x = jax.random.normal(keys[0], (B, S, D), dtype=jnp.bfloat16)
w_q, w_k, w_v, w_o = (init(k, (D, D)) for k in keys[1:5])
w_ff1 = init(keys[5], (D, 4 * D))
w_ff2 = init(keys[6], (4 * D, D))
w_ln = jnp.ones((D,), dtype=jnp.bfloat16)
weights = (w_q, w_k, w_v, w_o, w_ff1, w_ff2, w_ln)

# Warm up (trigger compilation)
transformer_jit(x, *weights).block_until_ready()

# Profile 10 iterations
with profiler.trace("/tmp/jax-trace"):
    for _ in range(10):
        y = transformer_jit(x, *weights)
    y.block_until_ready()  # force synchronization before the trace closes

# View in TensorBoard: tensorboard --logdir=/tmp/jax-trace
# Look for:
# - HLO operation breakdown (which ops take the most time)
# - Memory usage timeline (peak vs sustained)
# - ICI communication time (for multi-chip runs)
# - MXU utilization % (target: >80%)
```
When you write `y = x @ w` in JAX, the path to execution is:
When training on multiple TPU chips (or GPUs with tensor parallelism), the large weight matrices are sharded across devices. The matmul C = A × B becomes a distributed operation.
On GPUs, the profiling workflow uses NVIDIA's Nsight Systems (system-level timeline) and Nsight Compute (kernel-level analysis):
```bash
# Nsight Systems: capture a full training step timeline
nsys profile --trace=cuda,nvtx --output=train_step \
    python train.py --steps 10

# What to look for in the timeline:
# 1. GPU idle gaps → CPU-bound or waiting for data
# 2. Small kernels with gaps → launch overhead (use CUDA Graphs)
# 3. Long NCCL kernels → communication bottleneck
# 4. Memory copy (H2D/D2H) overlapping with compute → good!
# 5. Memory copy NOT overlapping → pipeline stall

# Nsight Compute: deep-dive on one kernel
ncu --set full --target-processes all \
    python -c "import torch; a=torch.randn(4096,4096,device='cuda',dtype=torch.bfloat16); b=a@a"

# Key metrics from ncu:
# - SM utilization (% of SMs active)
# - Tensor Core utilization (% of TC active)
# - Memory throughput (% of peak BW achieved)
# - Achieved occupancy (warps per SM)
# - Arithmetic intensity (actual, not theoretical)
```
Our roofline analysis showed that most transformer operations are memory-bound. Let us derive exactly how much time is wasted on non-compute operations for a single transformer block on an H100:
Here is the systematic debugging checklist that production teams at frontier labs follow when MFU is below expectations:
| Step | Check | Tool | Expected Value | If Failing |
|---|---|---|---|---|
| 1 | Batch size large enough? | Config check | B ≥ 8 per GPU (ideally 16+) | Increase microbatch size, use gradient accumulation |
| 2 | Communication overlapped? | Nsight Systems / JAX profiler | NCCL/ICI ops overlap with compute | Enable async communication, check FSDP prefetch |
| 3 | Pipeline bubble acceptable? | Calculate (P-1)/(P-1+M) | < 15% | Increase microbatches M or reduce pipeline stages P |
| 4 | FlashAttention enabled? | Code check | Using FA2 or FA3 | Switch to FlashAttention, massive memory + speed win |
| 5 | Matmul dimensions aligned? | Shape analysis | Multiples of 64 (GPU) or 128 (TPU) | Pad hidden dims to nearest multiple |
| 6 | Activation checkpointing balanced? | Memory profiler | HBM usage 70-90% peak | If <70%: reduce checkpointing. If OOM: increase it. |
| 7 | Data loading bottleneck? | CPU utilization | GPU should never wait for data | More dataloader workers, prefetch, faster storage |
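Step 3 of the checklist is a one-line formula, so it is worth scripting. A small sketch, with hypothetical helper names, that evaluates the bubble fraction (P-1)/(P-1+M) and finds the smallest microbatch count that gets under a target:

```python
def bubble_fraction(pipeline_stages, microbatches):
    """Fraction of a step spent idle in the pipeline bubble: (P-1)/(P-1+M)."""
    p, m = pipeline_stages, microbatches
    return (p - 1) / (p - 1 + m)


def min_microbatches(pipeline_stages, target=0.15):
    """Smallest M whose bubble fraction is strictly below the target."""
    m = 1
    while bubble_fraction(pipeline_stages, m) >= target:
        m += 1
    return m


for m in (8, 32, 64):
    print(f"P=8, M={m}: bubble = {bubble_fraction(8, m):.1%}")
print("P=8 needs at least", min_microbatches(8), "microbatches for a <15% bubble")
```

The formula makes the tradeoff explicit: halving the bubble requires roughly doubling the microbatch count, which in turn shrinks each microbatch and can push the matmuls back toward the memory-bound regime.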
Here is the mental model you should carry. Every computation you write passes through this stack:
| Layer | GPU Path | TPU Path |
|---|---|---|
| User code | PyTorch / Triton | JAX / Flax |
| Graph capture | torch.compile / CUDA Graphs | jax.jit (tracing) |
| Compiler | TorchInductor / Triton compiler | XLA |
| IR | Triton IR → PTX → SASS | HLO → LLO → TPU microcode |
| Runtime | CUDA driver, NCCL | TPU runtime, ICI stack |
| Hardware | SM → Tensor Core → Register → SMEM → L2 → HBM | VPU → MXU → VMEM → HBM |
| Interconnect | NVLink (intra) → IB (inter) | ICI (uniform 3D torus) |
At every layer, performance can be lost. The profiling tools let you pinpoint where. The roofline model tells you the theoretical maximum. The gap between theoretical and achieved is your optimization opportunity.
You have just been handed an uncomfortable reality. The dense transformer you have been training — every parameter active on every token — is hitting diminishing returns. You doubled the parameter count from 7B to 14B, doubled the compute budget, and gained... 0.3 points on your eval suite. The Chinchilla-optimal frontier says you need 4x the tokens to justify 4x the parameters, but your data pipeline cannot produce tokens that fast. There has to be a way to get more parameters without a proportional compute increase.
There is. It is called Mixture of Experts (MoE), and it is the architectural pattern behind many of the largest frontier models: DeepSeek-V3, Mixtral, Grok-1, Gemini 1.5, and reportedly GPT-4. The idea is deceptively simple: instead of one big feedforward network (FFN) that processes every token, you have many smaller FFNs (the "experts"), and a lightweight router decides which experts handle which token. Most of the parameters are dormant for any given token. You pay the storage cost of a huge model but only the compute cost of a small one.
A standard transformer block has two main components: a multi-head attention layer and a feedforward network (FFN). MoE replaces the FFN. The attention layer stays dense — every head sees every token. Only the FFN becomes sparse.
Why only the FFN? Two reasons. First, the FFN typically accounts for 2/3 of the parameters in a transformer block (it expands the hidden dimension by 4x, then contracts back). Replacing it with experts gives you the biggest parameter multiplier for the smallest architectural change. Second, attention is inherently a cross-token operation — all tokens need to interact. The FFN operates token-independently, so routing tokens to different experts does not break any dependency.
The router is the brain of MoE. It takes a token's hidden state x and produces a routing decision. Here is the math, step by step.
Think of it as a committee vote where only the top-k committee members get to speak, and their influence is proportional to how confident the router is in their relevance.
Let's trace through a concrete example. We have 8 experts, top-2 routing, and d_model = 4 (tiny for illustration).
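A minimal sketch of that trace in plain Python. The hidden state `x` and the router matrix `W_GATE` are made-up numbers chosen so the arithmetic is easy to follow; the gate weights are a softmax over the selected logits only, which matches the Flax router later in this chapter.

```python
import math

NUM_EXPERTS, TOP_K = 8, 2

# Hypothetical token hidden state (d_model = 4) and router weights [d_model, num_experts].
x = [1.0, 0.0, -1.0, 0.5]
W_GATE = [
    [2.0, 0.0, 1.0, 0.0, -1.0, 0.0, 0.5, 0.0],
    [0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0],
    [0.0, 0.0, 1.0, 0.0, 0.0, -1.0, 0.0, 0.0],
    [0.0, 3.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0],
]


def softmax(values):
    m = max(values)  # subtract the max for numerical stability
    exps = [math.exp(v - m) for v in values]
    total = sum(exps)
    return [e / total for e in exps]


# Step 1: router logits, one per expert (x @ W_GATE).
logits = [sum(x[i] * W_GATE[i][j] for i in range(len(x))) for j in range(NUM_EXPERTS)]

# Step 2: keep the top-k experts by logit.
top_idx = sorted(range(NUM_EXPERTS), key=lambda j: logits[j], reverse=True)[:TOP_K]

# Step 3: gate weights = softmax over the selected logits only.
gates = softmax([logits[j] for j in top_idx])

print("logits:", logits)
print("selected experts:", top_idx)
print("gate weights:", [round(g, 4) for g in gates])
# The layer output would be gates[0] * expert_0(x) + gates[1] * expert_1(x).
```

Only two of the eight experts run for this token; the other six contribute neither compute nor output.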
Here is the failure mode that kills naive MoE training. Without intervention, the router learns to send almost all tokens to 1-2 "favorite" experts. Those experts get better because they see more data. Because they are better, the router sends even more tokens to them. The other experts never get trained. Within a few thousand steps, you have a "Mixture of 2 Experts plus 6 dead weights." This is called expert collapse.
The standard fix is an auxiliary load-balancing loss that penalizes uneven expert usage. The loss has two terms:
Why multiply $f_i$ by $p_i$? Because $f_i$ depends on the argmax (not differentiable), but $p_i$ is the smooth softmax probability (differentiable). The product lets gradients flow through $p_i$ to encourage the router to spread probability mass across experts. The total loss becomes $L_{\text{total}} = L_{\text{LM}} + L_{\text{aux}}$.
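To make the two terms concrete, here is a small sketch of the loss in plain Python. It assumes the Switch-Transformer-style form, α · N · Σ f_i · p_i with top-1 routing; the function name and the α default are illustrative choices, and the source may scale the loss differently.

```python
def load_balancing_loss(assignments, router_probs, num_experts, alpha=0.01):
    """alpha * N * sum_i f_i * p_i for top-1 routing.

    assignments:  expert index chosen for each token (hard, non-differentiable)
    router_probs: per-token softmax probabilities over experts (differentiable)
    """
    num_tokens = len(assignments)
    # f_i: fraction of tokens actually dispatched to expert i
    f = [assignments.count(i) / num_tokens for i in range(num_experts)]
    # p_i: mean router probability assigned to expert i
    p = [sum(probs[i] for probs in router_probs) / num_tokens for i in range(num_experts)]
    return alpha * num_experts * sum(fi * pi for fi, pi in zip(f, p))


# Balanced routing: each of 4 tokens goes to a different expert, uniform probabilities.
balanced = load_balancing_loss([0, 1, 2, 3], [[0.25] * 4] * 4, num_experts=4)

# Collapsed routing: every token goes to expert 0 with high confidence.
collapsed = load_balancing_loss([0, 0, 0, 0], [[0.97, 0.01, 0.01, 0.01]] * 4, num_experts=4)

print(f"balanced:  {balanced:.4f}")
print(f"collapsed: {collapsed:.4f}")
```

With the factor of N, perfectly uniform routing gives exactly α, and any concentration of tokens on a few experts pushes the loss above that floor.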
Let's be precise about why MoE is efficient. A dense transformer FFN has two weight matrices: $W_1 \in \mathbb{R}^{d \times 4d}$ and $W_2 \in \mathbb{R}^{4d \times d}$. The FLOPs for one token through this FFN are approximately $2 \times d \times 4d + 2 \times 4d \times d = 16d^2$ (counting each multiply-accumulate as 2 FLOPs).
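The same accounting extends to an MoE layer. The sketch below assumes a simple two-matrix expert FFN and a linear router; the configuration (d = 4096, 8 experts, top-2, each expert as wide as the dense FFN) is a hypothetical example, not a description of any particular model.

```python
def ffn_params(d_model, hidden):
    """Two weight matrices: [d, hidden] and [hidden, d] (biases ignored)."""
    return 2 * d_model * hidden


def ffn_flops_per_token(d_model, hidden):
    """2 FLOPs per multiply-accumulate, two matmuls."""
    return 2 * d_model * hidden + 2 * hidden * d_model


def moe_flops_per_token(d_model, hidden, num_experts, top_k):
    """top_k expert FFNs plus the router projection [d, num_experts]."""
    router = 2 * d_model * num_experts
    return top_k * ffn_flops_per_token(d_model, hidden) + router


d, n_experts, k = 4096, 8, 2
dense_params = ffn_params(d, 4 * d)
moe_params = n_experts * ffn_params(d, 4 * d)
dense_flops = ffn_flops_per_token(d, 4 * d)
moe_flops = moe_flops_per_token(d, 4 * d, n_experts, k)

print(f"params: dense {dense_params / 1e6:.0f}M vs MoE {moe_params / 1e6:.0f}M "
      f"({moe_params / dense_params:.0f}x)")
print(f"FLOPs/token: dense {dense_flops / 1e6:.0f}M vs MoE {moe_flops / 1e6:.0f}M "
      f"({moe_flops / dense_flops:.2f}x)")
```

In this configuration the layer stores 8x the FFN parameters but spends only about 2x the FLOPs per token; the router's cost is negligible next to the expert matmuls.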
MoE introduces a unique parallelism challenge. In a dense model, every GPU processes the same layers — you can split by data parallelism (different batches) or tensor parallelism (different parts of the same matrix). In MoE, different GPUs hold different experts, and tokens must physically move to the GPU hosting their assigned expert.
This is called expert parallelism, and the communication pattern is all-to-all: every GPU may need to send tokens to every other GPU, and every GPU may receive tokens from every other GPU. It is the MoE equivalent of traffic at an intersection where every car wants to go in a different direction.
| Parallelism Type | What splits | Communication | Used for |
|---|---|---|---|
| Data Parallel | Batch across GPUs | AllReduce gradients | Dense layers, attention |
| Tensor Parallel | Weight matrices across GPUs | AllReduce activations | Large single layers |
| Expert Parallel | Experts across GPUs | All-to-All tokens | MoE FFN layers only |
| Pipeline Parallel | Layers across GPUs | Point-to-point activations | Very deep models |
The all-to-all operation has two phases. In the dispatch phase, each GPU sends its tokens to the GPUs hosting the selected experts. In the combine phase, each GPU sends the expert outputs back to the originating GPUs. Both phases are all-to-all collectives, and they are the dominant communication overhead in MoE training.
Let's calculate the actual communication overhead for a realistic MoE setup, because this is the number that determines whether MoE is practical at your cluster scale.
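A rough way to run that calculation, under explicit assumptions: routing is uniform across GPUs (so a fraction (G-1)/G of routed tokens leaves the local GPU), activations are BF16, and only the forward pass is counted. The function names and the example configuration are hypothetical; the 450 GB/s figure is half of the 900 GB/s bidirectional NVLink number from the H100 table.

```python
def all_to_all_bytes(tokens_per_gpu, top_k, d_model, num_gpus, bytes_per_elem=2):
    """Bytes one GPU sends in ONE all-to-all phase (dispatch or combine)."""
    routed = tokens_per_gpu * top_k * d_model * bytes_per_elem
    return routed * (num_gpus - 1) / num_gpus  # the local share never crosses the link


def moe_layer_comm_time(tokens_per_gpu, top_k, d_model, num_gpus, bandwidth_bytes_per_s):
    """Forward-pass time for dispatch + combine if nothing overlaps with compute."""
    per_phase = all_to_all_bytes(tokens_per_gpu, top_k, d_model, num_gpus)
    return 2 * per_phase / bandwidth_bytes_per_s


per_phase = all_to_all_bytes(tokens_per_gpu=8192, top_k=2, d_model=4096, num_gpus=8)
t = moe_layer_comm_time(8192, 2, 4096, 8, bandwidth_bytes_per_s=450e9)

print(f"per phase: {per_phase / 1e6:.1f} MB sent per GPU")
print(f"dispatch + combine: {t * 1e3:.3f} ms per MoE layer (forward only)")
```

The backward pass repeats both collectives, so the per-layer training cost is roughly double this estimate, and it grows sharply once experts span nodes and traffic falls back to the slower inter-node fabric.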
Load balancing is imperfect. Even with the auxiliary loss, some experts receive more tokens than others. If expert 3 is assigned 2x the average number of tokens, processing it takes 2x as long, and all other GPUs sit idle waiting — destroying parallelism.
The capacity factor C controls the maximum number of tokens each expert can process. If an expert is assigned more than C × (total_tokens / num_experts) tokens, the excess tokens are dropped — they pass through the layer with only the residual connection (no FFN computation). Typical values: C = 1.0 (strict) to C = 1.5 (generous).
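The dropping rule is simple enough to sketch directly. This version assumes top-1 routing and first-come-first-served admission in token order; the helper names are hypothetical.

```python
import math


def expert_capacity(total_tokens, num_experts, capacity_factor):
    """Maximum tokens one expert may process: C * (total_tokens / num_experts)."""
    return math.ceil(capacity_factor * total_tokens / num_experts)


def apply_capacity(assignments, num_experts, capacity_factor):
    """Return (kept, dropped) token indices for top-1 assignments."""
    cap = expert_capacity(len(assignments), num_experts, capacity_factor)
    load = [0] * num_experts
    kept, dropped = [], []
    for token, expert in enumerate(assignments):
        if load[expert] < cap:
            load[expert] += 1
            kept.append(token)
        else:
            dropped.append(token)  # passes through on the residual connection only
    return kept, dropped


# 8 tokens, 4 experts, and expert 0 is heavily over-subscribed.
assignments = [0, 0, 0, 0, 0, 1, 2, 3]
for c in (1.0, 1.5):
    kept, dropped = apply_capacity(assignments, num_experts=4, capacity_factor=c)
    print(f"C={c}: capacity={expert_capacity(8, 4, c)}, dropped tokens={dropped}")
```

Raising C trades memory and padding compute for fewer dropped tokens; it does not fix the underlying imbalance.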
DeepSeek-V3 (2024) introduced several MoE innovations that frontier labs are rapidly adopting. The most significant: no auxiliary balancing loss. Instead, they use a simple bias term.
| Feature | Standard MoE | DeepSeek-V3 |
|---|---|---|
| Total params | ~47B (Mixtral 8x7B) | 671B |
| Active params | ~13B | 37B |
| Expert count | 8 coarse experts | 256 fine-grained + 1 shared |
| Top-k | 2 | 8 of 256 (3.1% activation ratio) |
| Shared expert | None | 1 expert that processes ALL tokens |
| Balancing | Auxiliary loss (α = 0.01) | Per-expert bias term, no aux loss |
| Expert granularity | Coarse (each expert = full FFN) | Fine-grained (each expert = 1/16 of FFN) |
The key insight behind DeepSeek's auxiliary-free approach: instead of adding a loss term that fights with the main training objective, they add a per-expert bias $b_i$ to each expert's routing score when selecting the top-k experts. When an expert is underutilized, its bias increases slightly; when overutilized, it decreases. The bias is updated by a simple heuristic (not gradient descent) and affects only which experts are selected, not the gate weights, which decouples load balancing from the language modeling loss.
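A minimal sketch of that heuristic in plain Python. The step size `gamma`, the function names, and the normalization of the gate weights are assumptions for illustration; the essential properties are that the bias moves by a fixed step against the load error and that it influences selection only.

```python
def update_biases(biases, tokens_per_expert, gamma=0.001):
    """Nudge each bias up if its expert was under-loaded, down if over-loaded."""
    mean_load = sum(tokens_per_expert) / len(tokens_per_expert)
    new_biases = []
    for bias, load in zip(biases, tokens_per_expert):
        if load < mean_load:
            new_biases.append(bias + gamma)
        elif load > mean_load:
            new_biases.append(bias - gamma)
        else:
            new_biases.append(bias)
    return new_biases


def select_experts(scores, biases, top_k):
    """Pick experts by (score + bias), but weight their outputs by the raw scores."""
    biased = [s + b for s, b in zip(scores, biases)]
    chosen = sorted(range(len(scores)), key=lambda i: biased[i], reverse=True)[:top_k]
    total = sum(scores[i] for i in chosen)
    gates = [scores[i] / total for i in chosen]  # bias never enters the gate value
    return chosen, gates


biases = update_biases([0.0, 0.0, 0.0], tokens_per_expert=[10, 5, 0], gamma=0.1)
chosen, gates = select_experts([0.5, 0.4, 0.3], biases=[0.0, 0.0, 0.3], top_k=2)
print("updated biases:", biases)
print("chosen experts:", chosen, "gates:", gates)
```

Because no gradient flows through the bias, the language modeling loss never sees a balancing term, which is the decoupling the paragraph above describes.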
Why 256 fine-grained experts instead of 8 coarse ones? Finer granularity means each expert can specialize more precisely. An expert that handles "Python variable assignment" is more useful than one that handles "all of programming." The cost is more complex routing, but with 8-of-256 selection, the activation ratio (3.1%) is even sparser than 2-of-8 (25%).
The shared expert is another innovation. One expert processes every single token — it captures common patterns that all tokens need (basic grammar, formatting, common words). The routed experts then only need to capture specialized knowledge. Think of it as a "base layer" that everyone gets, plus "specialist consultants" routed by topic.
Here is a practical problem that matters for your capstone implementation. When you route tokens to experts, each expert receives a different number of tokens. Expert 3 might get 50 tokens while expert 7 gets 120. You need to do batched matrix multiplication where each batch element has a different length. In dense models, every batch element has the same shape, so standard matmul works. In MoE, you need a ragged (variable-length) batched dot product.
JAX provides jax.lax.ragged_dot for exactly this. It takes:
```python
import jax
import jax.numpy as jnp

num_experts, d_model, expert_dim = 8, 16, 32  # small sizes for illustration

# group_sizes: [num_experts], how many tokens are assigned to each expert
group_sizes = jnp.array([50, 80, 120, 45, 90, 110, 65, 75], dtype=jnp.int32)
total_tokens = int(group_sizes.sum())

kx, kw = jax.random.split(jax.random.PRNGKey(0))
# x: [total_tokens, d_model], all tokens concatenated and grouped by expert
x = jax.random.normal(kx, (total_tokens, d_model))
# w: [num_experts, d_model, expert_dim], stacked expert weights
w = jax.random.normal(kw, (num_experts, d_model, expert_dim))

# ragged_dot computes expert_i's matmul on its assigned chunk of x
output = jax.lax.ragged_dot(
    x,            # [total_tokens, d_model]
    w,            # [num_experts, d_model, expert_dim]
    group_sizes,  # [num_experts], token counts per expert
)
# output: [total_tokens, expert_dim]
# Equivalent to: for each expert i, slice x[start:start+group_sizes[i]]
# and compute matmul(slice, w[i]). Concatenate results.
```
The alternative without ragged_dot is to pad every expert's batch to the maximum length, waste compute on padding tokens, and then mask out the results. For highly imbalanced routing, this can waste 2-3x FLOPs. ragged_dot avoids the waste by handling variable-length segments natively on the accelerator.
Before Mixtral and DeepSeek-V3, there was the Switch Transformer (Fedus et al., 2022). It simplified the original MoE design (Shazeer et al., 2017) from top-2 routing to top-1 — each token goes to exactly one expert. This sounds wasteful, but it has a crucial advantage: simpler communication patterns and higher training stability.
The Switch Transformer also popularized the capacity factor and the simplified auxiliary load-balancing loss that became standard. Its key finding: MoE scaling is efficient. Sparse Switch models, scaled up to 1.6 trillion parameters with only a small fraction active per token, pre-trained up to 4x faster than the dense T5-XXL baseline.
The router's gating function is a design choice with significant consequences. Different papers use different gating mechanisms, each with tradeoffs.
| Gating Mechanism | Formula | Pros | Cons | Used In |
|---|---|---|---|---|
| Softmax top-k | $\text{softmax}(h)_{\text{top-}k}$ | Standard, well-understood | Top-k is not differentiable, so it requires STE or aux loss | Mixtral, Switch, GShard |
| Sigmoid top-k | $\text{sigmoid}(h)_{\text{top-}k}$ / sum | Independent expert probabilities (not sum-to-1) | Need explicit normalization | DeepSeek-V3 |
| Expert Choice | Top-k tokens per expert (inverse routing) | Perfect load balance by construction | Some tokens may get zero experts | Zhou et al. (2022) |
| Noisy top-k | $\text{softmax}(h + \text{noise})_{\text{top-}k}$ | Encourages exploration during training | Added noise hurts convergence | Shazeer et al. (2017) |
| Hash routing | expert_id = hash(token_id) % N | Zero overhead, deterministic | No learned specialization | Roller et al. (2021) |
The Expert Choice mechanism deserves special attention. Instead of each token choosing its experts, each expert chooses its top-k tokens. This guarantees perfect load balance (each expert gets exactly the same number of tokens) but creates a different problem: some tokens may be chosen by zero experts and others by many. In practice, Expert Choice often works better than auxiliary-loss-based balancing because it eliminates the hyperparameter sensitivity of the balancing coefficient.
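The inversion is easy to see in code. A small sketch in plain Python, where `scores[t][e]` is a hypothetical token-to-expert affinity and each expert picks its `capacity` highest-scoring tokens:

```python
def expert_choice_routing(scores, capacity):
    """Each expert selects its top-`capacity` tokens by affinity score.

    scores: scores[t][e] = affinity of token t for expert e
    Returns (tokens chosen by each expert, number of experts serving each token).
    """
    num_tokens, num_experts = len(scores), len(scores[0])
    chosen = {}
    for e in range(num_experts):
        ranked = sorted(range(num_tokens), key=lambda t: scores[t][e], reverse=True)
        chosen[e] = ranked[:capacity]  # every expert gets exactly `capacity` tokens
    experts_per_token = [
        sum(t in chosen[e] for e in range(num_experts)) for t in range(num_tokens)
    ]
    return chosen, experts_per_token


# 4 tokens, 2 experts; tokens 0 and 1 score highly for both experts.
scores = [
    [0.9, 0.8],
    [0.7, 0.6],
    [0.2, 0.5],
    [0.1, 0.1],
]
chosen, experts_per_token = expert_choice_routing(scores, capacity=2)
print("tokens per expert:", chosen)
print("experts per token:", experts_per_token)
```

Both experts end up with exactly two tokens, but tokens 2 and 3 are served by no expert at all, which is the tradeoff described above.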
Chinchilla scaling laws for dense models say: for a compute budget C, the optimal model has $N_{\text{opt}} \propto C^{0.5}$ parameters trained on $D_{\text{opt}} \propto C^{0.5}$ tokens. But MoE breaks this equation because active parameters ≠ total parameters.
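For the dense case the exponents turn into concrete numbers once you combine C ≈ 6ND with a tokens-per-parameter ratio. The sketch below assumes the commonly quoted rule of thumb of about 20 tokens per parameter, which is an approximation rather than a constant of nature; for MoE it assumes training FLOPs scale with active parameters, which ignores router and communication overhead.

```python
import math


def chinchilla_optimal(compute_flops, tokens_per_param=20.0):
    """Solve C = 6 * N * D with D = tokens_per_param * N for N and D."""
    n_params = math.sqrt(compute_flops / (6.0 * tokens_per_param))
    n_tokens = tokens_per_param * n_params
    return n_params, n_tokens


def training_flops(active_params, tokens):
    """C ~= 6 * N_active * D (for MoE, only active parameters cost FLOPs)."""
    return 6.0 * active_params * tokens


n, d = chinchilla_optimal(1e24)
print(f"C = 1e24 FLOPs -> N ~= {n / 1e9:.0f}B params, D ~= {d / 1e12:.2f}T tokens")

# Same token count, dense 400B vs an MoE with 70B active parameters (hypothetical budget check).
tokens = 2e12
print(f"dense 400B: {training_flops(400e9, tokens):.2e} FLOPs")
print(f"MoE, 70B active: {training_flops(70e9, tokens):.2e} FLOPs")
```

Quadrupling the compute budget doubles both the optimal parameter count and the optimal token count, which is what the 0.5 exponents mean in practice.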
The practical implication: you cannot just add infinite experts. At some point, the communication overhead and the diminishing scaling returns of inactive parameters make it better to increase expert size or add more active experts. DeepSeek-V3's choice of 256 experts with 8 active represents a point on this frontier that was determined empirically.
```python
import jax
import jax.numpy as jnp
import flax.linen as nn


class TopKRouter(nn.Module):
    """Top-k expert routing with load-balancing auxiliary loss."""
    num_experts: int = 8
    top_k: int = 2

    @nn.compact
    def __call__(self, x):
        # x: [batch, seq_len, d_model]
        # Router: single linear projection to expert logits
        logits = nn.Dense(self.num_experts, name="gate")(x)  # [B, S, N]

        # Top-k selection
        top_k_logits, top_k_indices = jax.lax.top_k(logits, self.top_k)

        # Routing weights: softmax over selected experts only
        gates = jax.nn.softmax(top_k_logits, axis=-1)  # [B, S, k]

        # ── Auxiliary load-balancing loss ──
        router_probs = jax.nn.softmax(logits, axis=-1)  # [B, S, N]
        expert_mask = jax.nn.one_hot(top_k_indices, self.num_experts)  # [B, S, k, N]
        expert_mask = expert_mask.sum(axis=-2)  # [B, S, N], 1 if expert selected
        f_i = expert_mask.mean(axis=(0, 1))  # [N], selection frequency (sums to k)
        p_i = router_probs.mean(axis=(0, 1))  # [N], mean router prob per expert
        aux_loss = self.num_experts * (f_i * p_i).sum()

        return gates, top_k_indices, aux_loss


class MoELayer(nn.Module):
    """Mixture of Experts FFN layer."""
    num_experts: int = 8
    top_k: int = 2
    expert_dim: int = 2048  # hidden dim of each expert FFN

    @nn.compact
    def __call__(self, x):
        B, S, D = x.shape

        # Route tokens to experts
        gates, indices, aux_loss = TopKRouter(
            num_experts=self.num_experts, top_k=self.top_k
        )(x)
        # gates: [B, S, k], indices: [B, S, k]

        # Compute ALL expert outputs (simple but wasteful)
        # Production code uses scatter-gather for true sparsity
        expert_outputs = []
        for i in range(self.num_experts):
            # Each expert is a 2-layer FFN with GELU
            h = nn.Dense(self.expert_dim, name=f"expert_{i}_up")(x)
            h = jax.nn.gelu(h)
            h = nn.Dense(D, name=f"expert_{i}_down")(h)
            expert_outputs.append(h)

        # Stack: [num_experts, B, S, D]
        all_experts = jnp.stack(expert_outputs, axis=0)

        # Gather selected expert outputs using indices
        # indices: [B, S, k] → [k, B, S, 1] → broadcast to [k, B, S, D]
        idx = jnp.broadcast_to(
            indices.transpose(2, 0, 1)[..., None], (self.top_k, B, S, D)
        )
        selected = jnp.take_along_axis(all_experts, idx, axis=0)  # [k, B, S, D]
        selected = selected.transpose(1, 2, 0, 3)  # [B, S, k, D]

        # Weighted sum by gate values
        output = (selected * gates[..., None]).sum(axis=-2)  # [B, S, D]
        return output, aux_loss
```
Note: an efficient implementation sorts tokens by expert, runs a grouped matmul per expert (ragged_dot or padding), then unsorts. The above is clearer for learning; the capstone in chapter 15 will implement the efficient version.
A frontier topic that appears in research discussion interviews: how do you route tokens from different modalities (text, images, audio) through the same MoE? This is relevant because frontier models like Gemini and GPT-4o are multi-modal.
The challenge: text tokens and image patch tokens have very different statistical properties. Text tokens have a discrete, high-entropy distribution. Image patch tokens are continuous, locally correlated, and lower-entropy. If you use a single router for both, the router may route all image tokens to one set of experts and all text tokens to another — creating accidental modality-specific experts rather than concept-specific experts.
This is an active research area with no consensus answer. In an interview, the best response is to present the tradeoffs (shared routing gives more flexibility but risks modality collapse; reserved experts are safer but limit cross-modal transfer) and propose an experiment to decide (train both configurations, measure loss on text-only, image-only, and multi-modal benchmarks).
An important practical consideration: MoE models behave differently under RLHF than dense models. The router's decisions can shift dramatically during the policy optimization phase, because RLHF changes the distribution of "good" outputs.
A natural question: do MoE experts specialize in meaningful ways, or is the routing arbitrary? Empirical studies on trained MoE models reveal striking patterns.
In language models, experts often specialize by domain. One expert handles code, another handles scientific text, another handles conversational language. This happens without any explicit supervision — the router learns to route tokens to the expert that produces the best loss for that domain. Think of it as emergent department formation in a company: people naturally gravitate toward the work they are best at.
In the addition task from our capstone, we expect even finer specialization. Some experts should learn to handle carry operations (the hard part), while others handle simple single-digit addition (the easy part). Some might specialize by digit position — handling the ones place, tens place, etc. Analyzing this specialization after training is one of the most interesting parts of the capstone.
When MoE training goes wrong, the symptoms are distinctive. Here is a diagnostic guide that will save you hours of debugging, both in the capstone and in production.
| Symptom | Likely Cause | Fix |
|---|---|---|
| Loss plateaus after 1K steps, never improves | Expert collapse — all tokens route to 1-2 experts. The "dead" experts have random weights that never receive gradients. | Increase α (balancing coefficient) from 0.01 to 0.1. If already high, check that the balancing loss is actually being added to the total loss (common bug: forgetting to add aux_loss). |
| Loss is good but eval accuracy is bad | Token dropping. Heavily loaded experts are hitting their capacity limit and dropping tokens during evaluation. Dropped tokens get degraded processing. | Increase capacity factor C from 1.0 to 1.5. Or switch to auxiliary-free balancing (DeepSeek-V3 style) which avoids the capacity factor entirely. |
| Training is 3x slower than expected for top-k routing | All-to-all communication is bottlenecked. Tokens are being sent across slow inter-node links instead of fast intra-node NVLink/ICI. | Co-locate frequently-paired experts on the same node. Use expert-slicing: split each expert across fewer GPUs rather than distributing experts across all GPUs. |
| Router probabilities are near-uniform (entropy too high) | Router has not learned to differentiate tokens. Often happens when the balancing loss is too strong — it overwhelms the language modeling signal. | Reduce α by 10x. Warm up the balancing loss: start with α=0 for the first 1000 steps, then ramp up linearly. |
| One expert has 10x the gradient norm of others | That expert is receiving all the "hard" tokens (high loss, high gradients). Common with unbalanced routing or when one domain is much harder than others. | Per-expert gradient clipping. Clip each expert's gradients independently rather than globally. This prevents the "hard" expert from destabilizing the others. |
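The warm-up suggested in the table (hold α at 0, then ramp linearly) can be sketched as a small schedule function. This is a minimal illustration, not the capstone's implementation: the function name `balancing_alpha`, the 1,000-step ramp length, and the target value 0.01 are assumptions chosen for the example.

```python
def balancing_alpha(step, alpha_max=0.01, hold_steps=1000, ramp_steps=1000):
    """Balancing-loss coefficient at a given training step.

    alpha stays at 0 for `hold_steps`, then ramps linearly to `alpha_max`
    over `ramp_steps` (both lengths are illustrative assumptions).
    """
    if step < hold_steps:
        return 0.0
    progress = min(1.0, (step - hold_steps) / ramp_steps)
    return alpha_max * progress


# The total loss would then be: lm_loss + balancing_alpha(step) * aux_loss
for step in (0, 999, 1500, 2000, 5000):
    print(step, balancing_alpha(step))
```

Holding α at zero first lets the router pick up some language-modeling signal before the balancing term starts pulling assignments toward uniform.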
Training MoE is hard. Serving MoE is harder. During training, you have a large batch of tokens that statistically covers all experts — every expert gets enough work to justify its GPU. During inference, a single request might only activate 2 experts, leaving the other 6 GPUs idle.
Interactive simulation (not reproduced here): tokens enter from the left and are routed to their top-2 experts, with a bar chart of expert utilization. With load balancing enabled, tokens spread across experts; with it disabled, the experts collapse.
Three landmark MoE designs represent different points on the design space. Understanding their differences teaches you the tradeoffs that matter.
| Feature | GShard (2020) | Mixtral 8x7B (2023) | DeepSeek-V3 (2024) |
|---|---|---|---|
| Origin | Google | Mistral AI | DeepSeek |
| Total params | 600B | 46.7B | 671B |
| Active params | ~100B | 12.9B | 37B |
| Num experts | 2048 | 8 | 256 + 1 shared |
| Top-k | 2 | 2 | 8 |
| Expert activation ratio (top-k / num experts) | 0.1% | 25% | 3.1% |
| Expert size | Very small (fine-grained) | Large (full 7B FFN each) | Small (fine-grained) |
| Shared expert | No | No | Yes (1) |
| Balancing | Capacity factor + aux loss | Aux loss only | Bias term (no aux loss) |
| Key innovation | Scaled MoE to 600B+ params | Open-weight MoE that beats dense Llama 2 70B | Auxiliary-free balancing + shared experts |
The trend is clear: newer designs use more experts, finer granularity, and more sophisticated balancing. GShard proved MoE works at scale. Mixtral proved it works in the open-weights community. DeepSeek-V3 pushed the frontier with novel routing and training techniques.
Let's make the savings concrete with a worked example comparing training costs.
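As a rough sketch of that comparison, the C ≈ 6ND rule from the introduction can be applied to the Mixtral 8x7B numbers in the table above, treating N as the parameters active per token. The 1-trillion-token budget and the dense baseline with the same total parameter count are assumptions for illustration, and the estimate ignores routing and communication costs.

```python
def training_flops(params_per_token, tokens):
    """Approximate training compute with C ~= 6 * N * D."""
    return 6 * params_per_token * tokens


TOKENS = 1e12            # assumed training budget (hypothetical)
MOE_TOTAL = 46.7e9       # Mixtral 8x7B total parameters (from the table)
MOE_ACTIVE = 12.9e9      # parameters active per token (from the table)

moe_flops = training_flops(MOE_ACTIVE, TOKENS)
dense_flops = training_flops(MOE_TOTAL, TOKENS)   # dense model of equal size
savings = dense_flops / moe_flops

print(f"MoE:   {moe_flops:.2e} FLOPs")
print(f"Dense: {dense_flops:.2e} FLOPs")
print(f"Dense / MoE compute ratio: {savings:.2f}x")
```

Under these assumptions the ratio is simply total parameters divided by active parameters, which is why the active-parameter count is the number to watch when comparing training budgets.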
The efficient MoE implementation avoids computing all experts by physically sorting tokens by expert assignment. This is the permute-compute-unpermute pattern used in production systems.
```python
import jax
import jax.numpy as jnp

def efficient_moe_forward(x, expert_weights_up, expert_weights_down, gates, indices):
    """Efficient MoE: sort tokens by expert, compute, unsort.

    x: [B, S, D]; gates, indices: [B, S, k]
    expert_weights_up: [E, D, F]; expert_weights_down: [E, F, D]
    """
    B, S, D = x.shape
    T = B * S
    k = indices.shape[-1]                       # experts selected per token
    num_experts = expert_weights_up.shape[0]
    x_flat = x.reshape(T, D)                    # [T, D]

    # Step 1: PERMUTE. Sort (token, expert) pairs by expert assignment.
    # For top-k routing each token appears k times (once per selected expert).
    expert_ids = indices.reshape(-1)            # [T*k] flat list of expert assignments
    token_ids = jnp.repeat(jnp.arange(T), k)    # [T*k] which token each entry came from
    sort_order = jnp.argsort(expert_ids)        # groups each expert's tokens together
    sorted_expert_ids = expert_ids[sort_order]
    sorted_token_ids = token_ids[sort_order]
    sorted_x = x_flat[sorted_token_ids]         # [T*k, D]
    # Count tokens per expert (group_sizes for ragged_dot)
    group_sizes = jnp.bincount(sorted_expert_ids, length=num_experts)

    # Step 2: COMPUTE. Run each expert on its contiguous group of tokens.
    h = jax.lax.ragged_dot(sorted_x, expert_weights_up, group_sizes)
    h = jax.nn.gelu(h)
    y = jax.lax.ragged_dot(h, expert_weights_down, group_sizes)

    # Step 3: UNPERMUTE. Weight by gate values and scatter-add the k
    # contributions of each token back to its original position.
    sorted_gates = gates.reshape(-1)[sort_order]   # [T*k]
    y_weighted = y * sorted_gates[:, None]
    output = jnp.zeros_like(x_flat).at[sorted_token_ids].add(y_weighted)
    return output.reshape(B, S, D)
```
Your MoE model is training on a TPU v4 pod. You profile the training step and discover that 35% of the step time is spent inside the MoE layer — specifically, inside jax.lax.ragged_dot. The operation handles variable-length expert batches correctly, but its generic implementation is not exploiting the specific structure of your problem. Your experts all have the same hidden dimension. Your batch sizes per expert, while variable, fall into a predictable range. You know that if you could write a custom kernel that tiles the computation exactly right for your access pattern, you could cut that 35% down to 20%.
This is where Pallas comes in. Pallas is JAX's framework for writing custom accelerator kernels, for both TPUs and GPUs, directly in Python. No CUDA C++. No XLA HLO surgery. You write the kernel body as an ordinary Python function, describe the grid and tiling with a few Pallas constructs, and the framework compiles it to efficient machine code for your target hardware.
Pallas is a kernel authoring API that sits between JAX's high-level NumPy-like interface and the low-level hardware instruction set. Think of it as JAX's answer to Triton (for NVIDIA GPUs) — except Pallas targets both TPUs and GPUs with a unified programming model.
| Abstraction Level | Technology | You write | You control |
|---|---|---|---|
| Highest | JAX / NumPy | jnp.matmul(A, B) | Nothing — XLA decides everything |
| Mid | Pallas | Grid + BlockSpec + kernel body | Tiling, memory placement, accumulation |
| Lowest | Mosaic (TPU) / PTX (GPU) | Raw hardware instructions | Everything — registers, pipelines, barriers |
The key insight of Pallas is declarative data movement. Instead of manually loading tiles from global memory into shared memory (as you would in CUDA), you declare what data each block needs via a BlockSpec, and the compiler figures out how to move it. This is a fundamental difference from CUDA, where you write the loads yourself.
A BlockSpec describes the shape and indexing pattern of one input or output. It tells Pallas: "For grid cell (i, j), give me this slice of the array." The compiler then generates the appropriate memory load/store instructions.
Why is this better than writing loads manually? Three reasons. First, the compiler can prefetch: it knows what data the next grid cell will need before you ask. Second, it can double-buffer: while the kernel computes on one tile, the hardware asynchronously loads the next tile. Third, it can optimize memory placement: on TPU, it decides whether to put your tile in VMEM (fast, small) or HBM (slow, large). You do not have to think about any of this.
A Pallas kernel executes on a grid — a multi-dimensional array of blocks. Each block processes one tile of the computation. The grid dimensions determine the parallelism and tiling strategy.
The k-dimension loop is the reduction axis. On each iteration, we load a tile of A and a tile of B, multiply them, and accumulate into the output tile. The accumulator stays in fast on-chip memory (VMEM on TPU, registers on GPU) across all k iterations.
Understanding the memory hierarchy is essential for writing fast kernels.
| Memory | Hardware | Size | Bandwidth | Latency | You manage? |
|---|---|---|---|---|---|
| HBM | TPU & GPU | 16-80 GB | ~2 TB/s (GPU), ~1.2 TB/s (TPU) | ~400 cycles | Pallas handles via BlockSpec |
| VMEM | TPU only | 16-32 MB per core | ~100 TB/s (on-chip) | ~5 cycles | Pallas auto-manages tiles here |
| SMEM | GPU only | 48-228 KB per SM | ~20 TB/s (on-chip) | ~20 cycles | Pallas uses for tile staging |
| Registers | Both | ~256 KB per core/SM | Infinite (local) | 0 cycles | Compiler assigns |
The key ratio is compute-to-memory. A TPU v4 chip delivers ~275 TFLOPS (BF16) but can only load ~1.2 TB/s from HBM. That means for every byte you load, you need to do ~230 FLOPs to keep the compute units busy. If your kernel does fewer FLOPs per byte loaded, it is memory-bound and the compute units are starving. If it does more, it is compute-bound and the memory system can keep up.
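The compute-to-memory argument above can be checked with a few lines of arithmetic. The sketch below uses the 275 TFLOPS and 1.2 TB/s figures from the text; the helper names and the two example workloads (a BF16 vector add and a 4096×4096×4096 BF16 matmul with ideal tiling) are assumptions for illustration.

```python
def ridge_point(peak_flops, hbm_bytes_per_s):
    """FLOPs per byte needed to saturate compute (the roofline ridge)."""
    return peak_flops / hbm_bytes_per_s


def vector_add_intensity(bytes_per_elem=2):
    """One add per element; two reads and one write per element."""
    return 1 / (3 * bytes_per_elem)


def matmul_intensity(m, n, k, bytes_per_elem=2):
    """2*m*n*k FLOPs; A, B, and C each cross HBM once (ideal tiling)."""
    flops = 2 * m * n * k
    bytes_moved = bytes_per_elem * (m * k + k * n + m * n)
    return flops / bytes_moved


ridge = ridge_point(275e12, 1.2e12)          # figures quoted in the text
add_ai = vector_add_intensity()              # BF16 vector add
mm_ai = matmul_intensity(4096, 4096, 4096)   # BF16 square matmul

for name, ai in [("vector add", add_ai), ("4096^3 matmul", mm_ai)]:
    bound = "compute-bound" if ai >= ridge else "memory-bound"
    print(f"{name}: {ai:.2f} FLOPs/byte -> {bound} (ridge {ridge:.0f})")
```

The vector add sits far below the ridge, while a large, well-tiled matmul sits well above it.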
Let's start with the simplest possible Pallas kernel — adding two vectors. This is memory-bound and will not be faster than jnp.add, but it teaches the API structure.
```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

def vector_add_kernel(x_ref, y_ref, o_ref):
    """Pallas kernel body. Operates on one block at a time.

    x_ref, y_ref, o_ref are Ref objects: views into the current block.
    """
    o_ref[...] = x_ref[...] + y_ref[...]

def vector_add(x, y):
    """Launch the Pallas kernel over a 1D grid."""
    n = x.shape[0]
    BLOCK = 512                # process 512 elements per block
    assert n % BLOCK == 0, "this simple version requires n to be a multiple of BLOCK"
    grid = (n // BLOCK,)       # 1D grid
    return pl.pallas_call(
        vector_add_kernel,
        grid=grid,
        in_specs=[
            pl.BlockSpec(block_shape=(BLOCK,), index_map=lambda i: (i,)),
            pl.BlockSpec(block_shape=(BLOCK,), index_map=lambda i: (i,)),
        ],
        out_specs=pl.BlockSpec(block_shape=(BLOCK,), index_map=lambda i: (i,)),
        out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
    )(x, y)

# Test
x = jnp.ones(4096)
y = jnp.ones(4096) * 2
result = vector_add(x, y)
print(result[:5])  # [3. 3. 3. 3. 3.]
```
Let's dissect every piece:
| Component | What it does | Analogy |
|---|---|---|
| kernel function | The computation for one block. Receives Ref objects (views into tiles), writes output. | The body of a CUDA __global__ function |
| grid | How many blocks to launch. (8,) means 8 blocks, each processing one tile. | CUDA grid dimensions |
| BlockSpec | Declares the shape and index mapping for each input/output tile. | The shared memory load pattern in CUDA |
| index_map | Lambda that maps grid indices to tile indices. lambda i: (i,) means block i reads tile i. | blockIdx.x * blockDim.x in CUDA |
| out_shape | The full shape and dtype of the output. Pallas allocates this. | The output buffer you'd cudaMalloc |
| pallas_call | Wraps everything into a JAX-compatible function. JIT-compiles the kernel. | cudaLaunchKernel |
Now let's write a tiled matrix multiplication. This is the kernel that actually matters for MoE — every expert FFN is a matmul, and how you tile it determines your performance.
```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

def matmul_kernel(a_ref, b_ref, o_ref):
    """Tiled matmul kernel. Accumulates over the K grid axis (axis 2)."""
    # a_ref: [BLOCK_M, BLOCK_K], b_ref: [BLOCK_K, BLOCK_N]
    # Pallas does not zero the output for us, so clear it on the first k step.
    @pl.when(pl.program_id(2) == 0)
    def _init():
        o_ref[...] = jnp.zeros_like(o_ref)

    o_ref[...] += a_ref[...] @ b_ref[...]

def matmul(a, b, BLOCK_M=128, BLOCK_N=128, BLOCK_K=64):
    """Pallas tiled matmul: C[M,N] = A[M,K] @ B[K,N]"""
    M, K = a.shape
    K2, N = b.shape
    assert K == K2
    assert M % BLOCK_M == 0 and N % BLOCK_N == 0 and K % BLOCK_K == 0
    # CRITICAL: the k dimension is a reduction axis. The kernel accumulates
    # across k with +=, which relies on the k steps of one (i, j) tile running
    # sequentially (true on TPU; on GPU, loop over k inside the kernel instead).
    grid = (M // BLOCK_M, N // BLOCK_N, K // BLOCK_K)  # 3D grid
    return pl.pallas_call(
        matmul_kernel,
        grid=grid,
        in_specs=[
            # A: for grid cell (i, j, k), read block at row i, col k
            pl.BlockSpec(block_shape=(BLOCK_M, BLOCK_K), index_map=lambda i, j, k: (i, k)),
            # B: for grid cell (i, j, k), read block at row k, col j
            pl.BlockSpec(block_shape=(BLOCK_K, BLOCK_N), index_map=lambda i, j, k: (k, j)),
        ],
        # Output tile at (i, j); the same tile is revisited for every k
        out_specs=pl.BlockSpec(block_shape=(BLOCK_M, BLOCK_N), index_map=lambda i, j, k: (i, j)),
        out_shape=jax.ShapeDtypeStruct((M, N), a.dtype),
    )(a, b)
```
The kernel body accumulates (+=) into the output tile across all k iterations. The accumulator lives in fast on-chip memory (VMEM/registers) and is only written back to HBM once, after all k blocks are processed. This is the same tiling trick that makes CUDA matmul fast.
Before reaching for Pallas, you should understand what JAX's compiler (XLA) does automatically. XLA already performs aggressive optimizations, and sometimes your "custom kernel" is actually slower than what XLA generates.
| XLA Optimization | What It Does | Example |
|---|---|---|
| Operator fusion | Combines adjacent elementwise ops into a single kernel | jax.nn.gelu(x + bias) becomes one kernel, not two |
| Layout optimization | Rearranges memory layout for optimal access patterns | Transposes a matrix before matmul if that reduces HBM reads |
| Common subexpression elimination | Computes shared subexpressions once | a*b + a*b computes a*b once |
| Constant folding | Pre-computes operations on known constants at compile time | jnp.ones((3,3)) * 2 becomes a constant tensor |
| Buffer aliasing | Reuses memory buffers when a tensor is consumed and no longer needed | After h = gelu(x), x's memory can be reused for the next layer |
XLA's fusion is particularly good for elementwise chains. If your "slow operation" is a chain of x + bias, gelu(x), dropout(x), XLA has already fused them into one kernel. Writing a Pallas version will not help. Where Pallas wins is on operations that XLA cannot fuse: matmul followed by elementwise followed by another matmul. XLA cannot fuse across matmul boundaries because the tiling strategy changes. Pallas can.
How do you know if XLA is already optimizing your code well? Inspect the HLO (High-Level Operations), XLA's intermediate representation. The optimized HLO shows you exactly what kernels XLA generates.
```python
import jax
import jax.numpy as jnp

def my_function(x, w):
    h = x @ w
    h = jax.nn.gelu(h)
    return h

x = jnp.ones((128, 4096))
w = jnp.ones((4096, 11008))

# Method 1: Print the IR for a function
lowered = jax.jit(my_function).lower(x, w)
print(lowered.as_text())            # IR before XLA optimization (the computation graph)
print(lowered.compile().as_text())  # optimized HLO: shows fusion decisions

# Method 2: Full XLA dump (for deep debugging)
# Run with: XLA_FLAGS="--xla_dump_to=/tmp/xla_dump"
# Generates .hlo files showing every optimization pass

# Method 3: Profile with JAX profiler
with jax.profiler.trace("/tmp/jax-trace"):
    result = jax.jit(my_function)(x, w)
    result.block_until_ready()
# Open the trace in Perfetto or TensorBoard to see kernel timings
```
Now the real problem. In your MoE layer, you need to do a batched matmul where each batch element (expert) has a different number of tokens. The straightforward approach is padding:
jax.lax.ragged_dot avoids this waste by natively handling variable-length segments. But it is a general-purpose implementation. When you know your specific problem structure (e.g., expert dimensions are always the same, group sizes fall in a known range), you can write a custom Pallas kernel that is even faster.
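To get a feel for how much work pad-to-max wastes, the sketch below compares the number of token rows multiplied under padding with the number of real rows a ragged matmul touches. Since every expert has the same weight shape, FLOPs scale with the row count. The tokens-per-expert values are hypothetical, and the helper name is an assumption for this example.

```python
def padded_vs_ragged_rows(group_sizes):
    """Rows of matmul work with pad-to-max versus exact (ragged) grouping."""
    padded = len(group_sizes) * max(group_sizes)   # every expert padded to the max
    ragged = sum(group_sizes)                      # only real tokens
    return padded, ragged


# Hypothetical tokens-per-expert counts for 8 experts (sum = 2048)
group_sizes = [512, 96, 300, 40, 640, 128, 200, 132]

padded, ragged = padded_vs_ragged_rows(group_sizes)
wasted = 1 - ragged / padded
print(f"padded rows: {padded}, real rows: {ragged}, wasted: {wasted:.0%}")
```

The more skewed the routing, the larger the gap, because a single busy expert sets the padded size for all of them.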
Writing a Pallas kernel is an investment. Before you start, apply this decision tree:
Let's derive the condition where a custom Pallas kernel outperforms ragged_dot. The key insight: ragged_dot uses a conservative tiling strategy because it must handle arbitrary group sizes. If you know your group sizes at kernel-write time (or they cluster around known values), you can tile more aggressively.
Here is a simplified but complete Pallas kernel for the MoE forward pass. It processes one expert at a time with optimal tiling.
```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental import pallas as pl

def expert_matmul_kernel(
    x_ref,  # [BLOCK_M, BLOCK_K] input token tile
    w_ref,  # [BLOCK_K, BLOCK_N] expert weight tile (reduction axis tiled)
    o_ref,  # [BLOCK_M, BLOCK_N] output accumulator
):
    """Single-expert tiled matmul kernel; grid axis 2 is the reduction axis."""
    @pl.when(pl.program_id(2) == 0)
    def _init():
        o_ref[...] = jnp.zeros_like(o_ref)   # clear accumulator on the first k step

    # Accumulate: matmul the current input and weight tiles, add to output
    o_ref[...] += x_ref[...] @ w_ref[...]

def moe_forward_pallas(
    x,               # [total_tokens, D], pre-sorted by expert assignment
    expert_weights,  # [num_experts, D, expert_dim], stacked weights
    group_sizes,     # [num_experts], tokens per expert (concrete values)
    BLOCK_M=64, BLOCK_K=128, BLOCK_N=64,
):
    """Process each expert's tokens with a tuned Pallas matmul.

    The Python loop needs concrete group sizes, so call this eagerly
    (outside jax.jit). Assumes D % BLOCK_K == 0 and expert_dim % BLOCK_N == 0.
    """
    num_experts, D, expert_dim = expert_weights.shape
    sizes = [int(s) for s in np.asarray(group_sizes)]   # host-side Python ints
    outputs = []
    offset = 0
    for e, n_tokens in enumerate(sizes):
        if n_tokens == 0:
            continue
        # Slice this expert's tokens and weight
        x_e = x[offset:offset + n_tokens]    # [n_tokens, D]
        w_e = expert_weights[e]              # [D, expert_dim]
        # Pad n_tokens up to a multiple of BLOCK_M
        n_padded = -(-n_tokens // BLOCK_M) * BLOCK_M
        x_padded = jnp.pad(x_e, ((0, n_padded - n_tokens), (0, 0)))
        grid = (n_padded // BLOCK_M, expert_dim // BLOCK_N, D // BLOCK_K)
        result = pl.pallas_call(
            expert_matmul_kernel,
            grid=grid,
            in_specs=[
                pl.BlockSpec((BLOCK_M, BLOCK_K), lambda i, j, k: (i, k)),
                pl.BlockSpec((BLOCK_K, BLOCK_N), lambda i, j, k: (k, j)),
            ],
            out_specs=pl.BlockSpec((BLOCK_M, BLOCK_N), lambda i, j, k: (i, j)),
            out_shape=jax.ShapeDtypeStruct((n_padded, expert_dim), x.dtype),
        )(x_padded, w_e)
        outputs.append(result[:n_tokens])    # unpad and collect
        offset += n_tokens
    return jnp.concatenate(outputs, axis=0)
```
Interactive canvas (not reproduced here): a matrix multiply C = A × B tiled into a Pallas grid, one colored block per grid cell. Stepping through the grid cells highlights the tiles of A and B being multiplied and shows the accumulator building up along the K dimension.
Pallas runs on both TPU and GPU, but the programming model has hardware-specific differences that affect how you write kernels.
| Aspect | Pallas on TPU (Mosaic) | Pallas on GPU (Triton backend) |
|---|---|---|
| Compute unit | MXU (Matrix Multiply Unit) — 128x128 systolic array | CUDA Cores + Tensor Cores |
| Natural tile size | 128x128 (matches MXU) | Varies (16x16 to 128x128 depending on Tensor Core generation) |
| On-chip memory | VMEM: 16-32 MB per core. Huge! | SMEM: 48-228 KB per SM. Much smaller. |
| Data movement | Compiler manages via BlockSpec. DMA handles async loads. | You hint via BlockSpec. Compiler maps to shared memory loads. |
| Parallelism model | Grid cells map to TPU cores. Each core has one MXU. | Grid cells map to thread blocks. Each block has multiple warps. |
| Matmul primitive | jax.lax.dot inside kernel maps directly to MXU | Maps to wmma or mma instructions via Tensor Cores |
| Best block sizes | 128-512 (large VMEM allows big tiles) | 32-128 (limited SMEM constrains tile size) |
| Reduction axis | Built into grid_spec — compiler unrolls the k-loop | Manual loop inside kernel body |
The biggest practical difference: TPU's VMEM is 100-500x larger than GPU's SMEM. This means you can use much larger block sizes on TPU, which improves compute utilization (fewer loop iterations, less overhead per tile). On GPU, you are forced to use smaller tiles and more iterations.
In practice, engineers at frontier labs do not write Pallas kernels for standard matmuls — XLA handles those well. They write kernels for operations that fall outside XLA's optimization reach. Here are the most common use cases:
| Use Case | Why XLA is Not Enough | The Pallas Solution |
|---|---|---|
| MoE expert dispatch | Variable-length batches per expert. XLA pads to max length, wasting compute. | Custom kernel that processes each expert's tokens without padding, fusing the permute-compute-unpermute. |
| Flash-style attention | Standard attention materializes the S×S matrix. XLA cannot discover the tiled-softmax trick. | Pallas kernel implementing online softmax with tile-by-tile processing, keeping the running max and sum in registers. |
| Custom activations | Novel activation functions (e.g., SwiGLU with gating) create fusion barriers in XLA. | Fuse the gated activation into the preceding matmul. The activation happens on tiles already in VMEM. |
| Quantized matmul | INT8/INT4 matmul with per-channel dequantization. XLA does not always find the optimal fusion pattern. | Custom kernel that loads quantized weights, dequantizes in-register, and accumulates in BF16/FP32. Avoids writing dequantized weights to HBM. |
| Sparse attention patterns | Block-sparse or sliding-window attention has irregular access patterns that XLA handles conservatively. | Custom kernel with specialized BlockSpec index_maps that only load the non-zero blocks, skipping empty regions entirely. |
The block sizes (BLOCK_M, BLOCK_K, BLOCK_N) are the most important tuning knobs. Too small and you have too many blocks (overhead dominates). Too large and tiles do not fit in fast memory (spill to HBM, killing performance).
| Block Size | TPU v4 Sweet Spot | GPU (A100) Sweet Spot | Why |
|---|---|---|---|
| BLOCK_M | 128-256 | 64-128 | Batch dimension. Larger = better compute utilization, but needs more VMEM/SMEM. |
| BLOCK_K | 128-256 | 32-64 | Reduction dimension. Larger = fewer loop iterations, but wider loads from HBM. |
| BLOCK_N | 128-256 | 64-128 | Output dimension. Larger = more parallelism per block. |
The constraint: BLOCK_M × BLOCK_K + BLOCK_K × BLOCK_N + BLOCK_M × BLOCK_N must fit in fast memory (VMEM on TPU, SMEM on GPU). For TPU v4 with 16 MB VMEM per core at BF16 (2 bytes), you can fit ~8M elements across all three tiles. For an A100 with 192 KB SMEM at FP16, you can fit ~96K elements.
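The fit constraint above is easy to turn into a quick check before launching a kernel. The sketch below uses the 16 MB and 192 KB budgets quoted in the text; the helper names are assumptions, and the check ignores any extra space needed for double-buffering or scratch buffers.

```python
def tile_bytes(block_m, block_n, block_k, bytes_per_elem=2):
    """Bytes needed to hold one A tile, one B tile, and one output tile."""
    elems = block_m * block_k + block_k * block_n + block_m * block_n
    return elems * bytes_per_elem


def fits(block_m, block_n, block_k, budget_bytes, bytes_per_elem=2):
    return tile_bytes(block_m, block_n, block_k, bytes_per_elem) <= budget_bytes


TPU_VMEM = 16 * 1024 ** 2   # 16 MB budget from the text
GPU_SMEM = 192 * 1024       # 192 KB budget from the text

for blocks in [(128, 128, 64), (128, 128, 128), (256, 256, 256)]:
    size_kb = tile_bytes(*blocks) / 1024
    print(blocks, f"{size_kb:.0f} KB",
          "TPU ok" if fits(*blocks, TPU_VMEM) else "TPU too big",
          "GPU ok" if fits(*blocks, GPU_SMEM) else "GPU too big")
```

The 256-cubed configuration illustrates the gap between the two platforms: it is comfortable on the TPU budget and roughly twice too large for the GPU budget.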
Sometimes your kernel needs temporary storage that does not correspond to any input or output. For example, in a fused attention kernel, you need to store the running softmax denominator and maximum. Pallas provides scratch memory for this: on-chip buffers that persist across the sequential grid steps of a kernel, so they can carry running statistics from one tile to the next, but are never backed by an HBM array.
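```python
import jax.numpy as jnp
from jax.experimental import pallas as pl

def attention_kernel(q_ref, k_ref, v_ref, o_ref, m_scratch, l_scratch):
    """Simplified flash-attention-style kernel with scratch memory.

    m_scratch: [BLOCK_M], running max of attention scores
    l_scratch: [BLOCK_M], running sum of exp(scores - max)
    Assumes grid axis 1 iterates sequentially over key/value tiles.
    """
    kv_step = pl.program_id(1)

    # First key/value tile: reset the running statistics and the output.
    @pl.when(kv_step == 0)
    def _init():
        m_scratch[...] = jnp.full(m_scratch.shape, -jnp.inf, m_scratch.dtype)
        l_scratch[...] = jnp.zeros_like(l_scratch)
        o_ref[...] = jnp.zeros_like(o_ref)

    # Compute attention scores for this tile
    scores = q_ref[...] @ k_ref[...].T               # [BLOCK_M, BLOCK_N]

    # Online softmax: update running max and sum
    tile_max = scores.max(axis=-1)                   # [BLOCK_M]
    new_max = jnp.maximum(m_scratch[...], tile_max)
    correction = jnp.exp(m_scratch[...] - new_max)   # rescales previous tiles
    tile_exp = jnp.exp(scores - new_max[:, None])
    new_sum = l_scratch[...] * correction + tile_exp.sum(axis=-1)

    # Update output accumulator with correction
    o_ref[...] = o_ref[...] * correction[:, None] + tile_exp @ v_ref[...]

    # Update scratch buffers
    m_scratch[...] = new_max
    l_scratch[...] = new_sum

    # Last key/value tile: divide by the softmax denominator.
    @pl.when(kv_step == pl.num_programs(1) - 1)
    def _finalize():
        o_ref[...] = o_ref[...] / l_scratch[...][:, None]

# On TPU, scratch buffers are requested via scratch_shapes, e.g.
# pltpu.VMEM((BLOCK_M,), jnp.float32) with
# `from jax.experimental.pallas import tpu as pltpu` (API varies by JAX version).
# They are allocated on-chip and never touch HBM.
# This is how FlashAttention avoids materializing the full S×S matrix.
```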
The scratch memory pattern is what makes Pallas powerful for attention and other reduction operations. Without it, you would need to write intermediate results to HBM and read them back — exactly the inefficiency that FlashAttention eliminates.
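The running-max and running-sum bookkeeping can be verified without any accelerator. The sketch below is a plain-Python, single-query version of online softmax (function names and test values are made up for illustration) that processes scores tile by tile and compares the result against an ordinary softmax-weighted sum.

```python
import math


def softmax_weighted_sum(scores, values):
    """Reference: full softmax over all scores, then weighted sum of values."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return sum(e * v for e, v in zip(exps, values)) / total


def online_softmax_weighted_sum(scores, values, tile=2):
    """Same result, processing `tile` scores at a time with running stats."""
    running_max = float("-inf")   # plays the role of m_scratch
    running_sum = 0.0             # plays the role of l_scratch
    acc = 0.0                     # unnormalized output accumulator
    for start in range(0, len(scores), tile):
        s_tile = scores[start:start + tile]
        v_tile = values[start:start + tile]
        new_max = max(running_max, max(s_tile))
        correction = math.exp(running_max - new_max)   # rescale old partials
        tile_exp = [math.exp(s - new_max) for s in s_tile]
        running_sum = running_sum * correction + sum(tile_exp)
        acc = acc * correction + sum(e * v for e, v in zip(tile_exp, v_tile))
        running_max = new_max
    return acc / running_sum


scores = [0.5, 2.0, -1.0, 3.5, 0.0]
values = [1.0, 2.0, 3.0, 4.0, 5.0]
full = softmax_weighted_sum(scores, values)
tiled = online_softmax_weighted_sum(scores, values, tile=2)
print(full, tiled)
```

Only three scalars survive between tiles, which is why the per-row state fits comfortably in scratch memory however long the sequence is.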
The biggest wins from Pallas come not from writing faster matmuls (XLA already does that well) but from fusing multiple operations into a single kernel. Every time data travels to HBM and back, you pay a latency penalty. Fusion eliminates these round-trips.
Consider the MoE forward pass for a single expert:
For an expert FFN with d_model=4096 and d_ff=11008, the intermediate activation h is 11008 values per token. At BF16, that is 22 KB per token. For a batch of 128 tokens, the intermediate is 2.8 MB, well within VMEM. In the unfused version this data crosses the HBM boundary 3 times (once per kernel), but in the fused version it stays on-chip. At 1.2 TB/s HBM bandwidth, each 2.8 MB transfer takes ~2.3 microseconds, so the three transfers cost ~7 microseconds per expert. Across 256 experts and 60+ layers, this adds up to milliseconds per training step.
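The numbers in this paragraph follow from a short calculation, reproduced below as a sanity check. All inputs are the figures quoted in the text; the variable names are assumptions for the example, and the estimate counts only bandwidth time, not kernel-launch or latency overhead.

```python
D_FF = 11008
TOKENS = 128
BYTES_PER_ELEM = 2        # BF16
HBM_BANDWIDTH = 1.2e12    # bytes per second
UNFUSED_TRANSFERS = 3     # HBM transfers of h in the unfused version

bytes_per_token = D_FF * BYTES_PER_ELEM
intermediate_bytes = bytes_per_token * TOKENS
us_per_transfer = intermediate_bytes / HBM_BANDWIDTH * 1e6
extra_us = UNFUSED_TRANSFERS * us_per_transfer

print(f"per token:    {bytes_per_token / 1e3:.1f} KB")
print(f"intermediate: {intermediate_bytes / 1e6:.2f} MB")
print(f"per transfer: {us_per_transfer:.2f} us")
print(f"unfused cost: {extra_us:.2f} us per expert")
```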
Pallas kernels are compiled, not interpreted. When something goes wrong, you do not get a Python traceback pointing to the offending line. Instead you get cryptic XLA errors or, worse, silently wrong results. Here is a systematic debugging approach.
| Problem | Symptom | Debug Strategy |
|---|---|---|
| Shape mismatch | XLA compilation error mentioning "incompatible shapes" | Print the expected shapes from your BlockSpec. Verify: grid[i] * block_shape[dim] == array_shape[dim] for every dimension. Off-by-one in grid calculation is the #1 bug. |
| Wrong results | Output does not match reference (jnp.matmul) | Test with BLOCK sizes equal to the full matrix (grid=(1,1,1)). If this works, the bug is in tiling/accumulation. Halve block sizes until you find the boundary that breaks. |
| NaN on large inputs | Works on small, NaN on large | The last tile has fewer valid elements than BLOCK_SIZE. Uninitialized memory in the tile causes NaN. Add explicit padding or boundary masking. |
| Performance worse than jnp | Kernel is slower than built-in op | Check if your kernel is memory-bound. Profile with jax.profiler. If FLOP utilization is <30%, try larger block sizes. If HBM utilization is >80%, you are memory-bound — fusion is the answer, not better tiling. |
| Compilation takes minutes | pallas_call hangs during JIT | Reduce grid dimensions. Very large grids (>10000 cells) can cause the compiler to spend excessive time optimizing. Try fusing grid dimensions: (M, N) grid instead of (M, N, K). |
Always validate a Pallas kernel against a reference implementation (jnp.matmul, jnp.einsum). Run both on the same random input. Compare with jnp.allclose(pallas_out, ref_out, atol=1e-5). If they match on random data, test on structured data (all zeros, all ones, identity matrix). If they match everywhere, your kernel is very likely correct, so move on to profiling it for performance.
On TPU, Pallas supports prefetching hints that tell the hardware to start loading the next tile before the current tile's computation finishes. This hides memory latency behind compute.
```python
# Prefetch hint: tell the compiler to load the next k-block
# while the current k-block is being computed.
# This is TPU-specific and requires Pallas's prefetch API.

def matmul_prefetch_kernel(a_ref, b_ref, o_ref, a_scratch, b_scratch):
    """Matmul with double-buffered prefetching (conceptual sketch)."""
    # a_scratch and b_scratch are allocated in VMEM by Pallas.
    # They stand for the buffers that hold the prefetched next tile while we
    # compute on the current one; this simplified body does not touch them.

    # Compute current tile (data already in VMEM from the previous prefetch).
    # As in the tiled matmul, o_ref must be zeroed on the first k step.
    o_ref[...] += a_ref[...] @ b_ref[...]

    # The compiler automatically pipelines: next tile loads overlap
    # with the current tile's compute. On TPU v4, this hides ~80% of
    # memory latency for compute-bound kernels.

# Scratch memory specification (exact API depends on the JAX version):
# scratch buffers are declared with a VMEM memory space, which tells
# Pallas to allocate them in fast on-chip memory.
```
Double-buffering works because TPU has separate data movement units (DMAs) and compute units (MXUs). While the MXU multiplies the current tile, the DMA loads the next tile. When the MXU finishes, the next tile is already in VMEM — zero wait time. This is the same principle as CPU prefetching, but at a much larger scale (megabytes instead of cache lines).
If you are targeting NVIDIA GPUs, you have a choice: Triton (Python DSL, mature ecosystem, broad adoption) or Pallas (JAX-native, works on TPU and GPU, newer). Here is an honest comparison.
| Dimension | Pallas | Triton |
|---|---|---|
| Hardware | TPU + GPU | GPU only (NVIDIA, AMD) |
| Ecosystem | Niche — mostly Google/DeepMind/Anthropic | Broad — PyTorch, vLLM, HuggingFace |
| Maturity | Young (2023-). API still changing. | Established (2019-). Stable API. |
| Memory model | Declarative (BlockSpec — compiler manages) | Explicit (you write tl.load, tl.store) |
| Debugging | Harder — compiled through XLA, cryptic errors | Easier — interpret mode, torch integration |
| Best for | TPU workloads, JAX-native projects, Anthropic/Google shops | NVIDIA GPU kernels, PyTorch projects, broad deployment |
To show that the concepts transfer, here is the same vector add kernel written in both Pallas and Triton side by side.
Pallas (JAX)
```python
# Pallas version: declarative data movement
import jax
from jax.experimental import pallas as pl

def vadd_kernel(x_ref, y_ref, o_ref):
    o_ref[...] = x_ref[...] + y_ref[...]

def vadd(x, y):
    return pl.pallas_call(
        vadd_kernel,
        grid=(x.shape[0] // 512,),
        in_specs=[
            pl.BlockSpec((512,), lambda i: (i,)),
            pl.BlockSpec((512,), lambda i: (i,)),
        ],
        out_specs=pl.BlockSpec((512,), lambda i: (i,)),
        out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
    )(x, y)
```
Triton (PyTorch)
```python
# Triton version: explicit loads and stores
import torch
import triton
import triton.language as tl

@triton.jit
def vadd_kernel(x_ptr, y_ptr, o_ptr, n, BLOCK: tl.constexpr):
    pid = tl.program_id(0)
    offsets = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offsets < n                          # boundary check (Pallas does this via BlockSpec)
    x = tl.load(x_ptr + offsets, mask=mask)     # explicit load
    y = tl.load(y_ptr + offsets, mask=mask)
    tl.store(o_ptr + offsets, x + y, mask=mask) # explicit store

def vadd(x, y):
    o = torch.empty_like(x)
    n = x.numel()
    grid = (triton.cdiv(n, 512),)
    vadd_kernel[grid](x, y, o, n, BLOCK=512)
    return o
```
Notice the key differences: Pallas uses BlockSpec to declare data movement (the compiler generates loads/stores). Triton uses explicit tl.load/tl.store calls (you control the loads). Pallas requires out_shape upfront (it pre-allocates). Triton takes a pre-allocated output pointer. The kernel body logic (the actual addition) is identical in both — because the math does not change between hardware platforms.
A critical question for training: can you differentiate through a Pallas kernel? Yes — but with a caveat. JAX's autodiff system (jax.grad, jax.vjp) works on Pallas kernels, but only if you provide a custom VJP rule (vector-Jacobian product). The compiler cannot automatically differentiate through the low-level kernel code.
```python
import jax
import jax.numpy as jnp
from jax import custom_vjp

# pallas_matmul is assumed to be the tiled Pallas matmul defined earlier.

@custom_vjp
def my_pallas_matmul(a, b):
    """Forward pass: use Pallas kernel."""
    return pallas_matmul(a, b)  # our custom Pallas implementation

def my_pallas_matmul_fwd(a, b):
    """Forward pass + save residuals for backward."""
    result = pallas_matmul(a, b)
    return result, (a, b)  # save inputs for backward

def my_pallas_matmul_bwd(res, g):
    """Backward pass: compute gradients.

    For C = A @ B:
        dA = dC @ B^T
        dB = A^T @ dC
    """
    a, b = res
    # Can use Pallas kernels for backward matmuls too!
    da = pallas_matmul(g, b.T)
    db = pallas_matmul(a.T, g)
    return da, db

my_pallas_matmul.defvjp(my_pallas_matmul_fwd, my_pallas_matmul_bwd)

# Now jax.grad works through my_pallas_matmul!
a = jnp.ones((256, 128))
b = jnp.ones((128, 256))
loss_fn = lambda a, b: my_pallas_matmul(a, b).sum()
grads = jax.grad(loss_fn, argnums=(0, 1))(a, b)  # (dA, dB)
```
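A hand-written backward rule is easy to get subtly wrong (a missing transpose is the usual culprit), so it is worth checking the formulas dA = dC @ B^T and dB = A^T @ dC numerically. The sketch below does this in plain Python on small made-up matrices, using central finite differences; the helper names are assumptions for the example.

```python
def matmul(a, b):
    """Naive matrix product of two lists of lists."""
    return [[sum(a[i][p] * b[p][j] for p in range(len(b)))
             for j in range(len(b[0]))] for i in range(len(a))]


def transpose(m):
    return [list(row) for row in zip(*m)]


def loss(a, b):
    """Scalar loss: sum of all entries of A @ B."""
    return sum(sum(row) for row in matmul(a, b))


A = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]       # [2, 3]
B = [[0.5, -1.0], [2.0, 0.0], [1.5, 3.0]]    # [3, 2]

# The gradient of sum(C) with respect to C is all ones.
dC = [[1.0, 1.0], [1.0, 1.0]]
dA = matmul(dC, transpose(B))                # dA = dC @ B^T
dB = matmul(transpose(A), dC)                # dB = A^T @ dC


def numeric_grad_A(i, j, eps=1e-6):
    """Central finite difference of the loss with respect to A[i][j]."""
    plus = [row[:] for row in A]
    minus = [row[:] for row in A]
    plus[i][j] += eps
    minus[i][j] -= eps
    return (loss(plus, B) - loss(minus, B)) / (2 * eps)


max_err = max(abs(dA[i][j] - numeric_grad_A(i, j))
              for i in range(2) for j in range(3))
print("dA:", dA)
print("dB:", dB)
print("max |analytic - numeric| for dA:", max_err)
```

The same comparison applies to a Pallas-backed custom VJP: compare its gradients against those of a reference implementation on small inputs before using it in training.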
Everything in this lesson has been building to this chapter. Vlad Mikulik's blog post for Anthropic explicitly names this exercise: train a ~10 million parameter transformer from scratch that learns addition, compare dense vs MoE architectures, derive the Chinchilla-optimal training configuration, and write a custom Pallas kernel that beats ragged_dot for the MoE variant. This is the proof of work. Not a toy exercise — a miniature version of what you will do daily at a frontier lab.
We will design every component from first principles: the tokenizer, the architecture, the dataset, the training loop, the scaling law analysis, and the Pallas kernel. By the end of this chapter, you will have a complete implementation plan that you can execute on a free Colab TPU in under an hour.
Addition is the perfect capstone task for three reasons. First, it is learnable by a small model — you do not need billions of parameters or weeks of training. Second, it has a crisp success metric — the model either outputs the correct sum or it does not. Third, it exercises genuine algorithmic reasoning — the model must learn to propagate carries across digit positions, which requires attending to non-local dependencies.
We keep the tokenizer dead simple. Thirteen tokens are sufficient:
| Token | ID | Purpose |
|---|---|---|
| 0-9 | 0-9 | Digit tokens |
| + | 10 | Addition operator |
| = | 11 | Equals sign (transition from input to output) |
| <PAD> | 12 | Padding token for fixed-length sequences |
Why separate digit tokens instead of multi-digit numbers? Because multi-digit tokenization introduces a combinatorial explosion. "123" as a single token means your vocabulary grows to 1000+ for 3-digit numbers. With single-digit tokens, your vocabulary stays at 13 regardless of number length. The model must learn positional alignment, but that is the point — we want it to learn this.
```python
# Tokenizer: dead simple character-level encoding
VOCAB = {'0': 0, '1': 1, '2': 2, '3': 3, '4': 4,
         '5': 5, '6': 6, '7': 7, '8': 8, '9': 9,
         '+': 10, '=': 11, '<PAD>': 12}
VOCAB_SIZE = 13

def encode(s):
    """Convert "1 2 3 + 4 5 6 = 5 7 9" to token IDs."""
    tokens = []
    for ch in s.split():
        if ch in VOCAB:
            tokens.append(VOCAB[ch])
        else:
            raise ValueError(f"Unknown token: {ch}")
    return tokens

def decode(ids):
    """Convert token IDs back to string."""
    inv = {v: k for k, v in VOCAB.items()}
    return ' '.join(inv[i] for i in ids if i != 12)
```
We need to choose d_model, n_layers, n_heads, and d_ff to hit approximately 10 million parameters. Let's derive this.
| Hyperparameter | Value | Rationale |
|---|---|---|
| d_model | 384 | Divisible by 6 heads. Large enough to represent digit relationships. |
| n_heads | 6 | d_head = 384/6 = 64. Standard head dimension. |
| n_layers | 6 | Deep enough for carry propagation (needs multi-hop reasoning). |
| d_ff | 1536 | 4 × d_model. Standard FFN expansion ratio. |
| max_seq_len | 32 | Enough for the longest 5-digit problem, "9 9 9 9 9 + 9 9 9 9 9 = 1 9 9 9 9 8" (18 tokens), with room for padding. |
| vocab_size | 13 | 10 digits + 3 special tokens. |
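The parameter count behind the "approximately 10 million" target can be checked with a short calculation. This is a minimal sketch, assuming a pre-norm block with biases on every Dense and attention projection, two LayerNorms per block, learned positional embeddings, and an untied output projection; `count_params` is a hypothetical helper name, not part of any library.

```python
def count_params(d_model=384, n_layers=6, d_ff=1536, vocab_size=13, max_len=32):
    """Approximate parameter count for the configuration in the table above."""
    # Attention: Q, K, V, and output projections (weights + biases)
    attn = 4 * d_model * d_model + 4 * d_model
    # FFN: up- and down-projection (weights + biases)
    ffn = 2 * d_model * d_ff + d_ff + d_model
    # Two LayerNorms per block, each with scale and bias
    norms = 2 * 2 * d_model
    per_layer = attn + ffn + norms
    # Token embedding + learned positional embedding
    embeddings = vocab_size * d_model + max_len * d_model
    # Final LayerNorm + output projection to the vocabulary
    head = 2 * d_model + d_model * vocab_size + vocab_size
    return n_layers * per_layer + embeddings + head


if __name__ == "__main__":
    print(f"{count_params():,} parameters")
```

The two matrix terms dominate: each layer contributes about 4 × d_model² for attention plus 2 × d_model × d_ff for the FFN, roughly 1.77M per layer, so six layers land just above 10.6M. Embeddings and the output head add only a few tens of thousands.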
We generate addition problems randomly, with a curriculum that starts with small numbers and gradually increases.
```python
import random
import jax.numpy as jnp

def generate_addition_example(max_digits=5):
    """Generate one addition problem with answer."""
    # Random number of digits for each operand
    d1 = random.randint(1, max_digits)
    d2 = random.randint(1, max_digits)
    a = random.randint(0, 10**d1 - 1)
    b = random.randint(0, 10**d2 - 1)
    c = a + b
    # Convert to space-separated digits
    a_str = ' '.join(str(a))
    b_str = ' '.join(str(b))
    c_str = ' '.join(str(c))
    # Full sequence: "1 2 3 + 4 5 6 = 5 7 9"
    return f"{a_str} + {b_str} = {c_str}"

def make_batch(batch_size, max_digits, max_len=32):
    """Generate a padded batch of addition problems."""
    inputs, targets, masks = [], [], []
    for _ in range(batch_size):
        seq = generate_addition_example(max_digits)
        # Append one PAD so the model is trained to emit PAD after the answer;
        # evaluation treats PAD as the stop token.
        tokens = encode(seq) + [12]
        eq_pos = tokens.index(11)  # position of "="
        # Autoregressive training: the input drops the last token,
        # the target is the same sequence shifted left by one.
        inp = tokens[:-1]
        tgt = tokens[1:]
        # Pad to max_len
        pad_len = max_len - len(inp)
        inp = inp + [12] * pad_len
        tgt = tgt + [12] * pad_len
        # Loss mask: only supervise targets after "=" (answer digits + stop PAD)
        mask = [0] * eq_pos + [1] * (len(tokens) - 1 - eq_pos) + [0] * pad_len
        inputs.append(inp)
        targets.append(tgt)
        masks.append(mask)  # keep one mask per example, not just the last one
    return {
        'input_ids': jnp.array(inputs),   # [B, max_len]
        'targets': jnp.array(targets),    # [B, max_len]
        'loss_mask': jnp.array(masks),    # [B, max_len]
    }
```
```python
import jax
import jax.numpy as jnp
import flax.linen as nn
import optax

class TransformerBlock(nn.Module):
    d_model: int = 384
    n_heads: int = 6
    d_ff: int = 1536
    dropout_rate: float = 0.1

    @nn.compact
    def __call__(self, x, mask, deterministic=True):
        # Pre-norm attention
        h = nn.LayerNorm()(x)
        h = nn.MultiHeadDotProductAttention(
            num_heads=self.n_heads,
            qkv_features=self.d_model,
            deterministic=deterministic,
        )(h, h, mask=mask)
        h = nn.Dropout(self.dropout_rate)(h, deterministic=deterministic)
        x = x + h  # residual
        # Pre-norm FFN
        h = nn.LayerNorm()(x)
        h = nn.Dense(self.d_ff)(h)
        h = jax.nn.gelu(h)
        h = nn.Dense(self.d_model)(h)
        h = nn.Dropout(self.dropout_rate)(h, deterministic=deterministic)
        x = x + h  # residual
        return x

class AdditionTransformer(nn.Module):
    vocab_size: int = 13
    d_model: int = 384
    n_heads: int = 6
    n_layers: int = 6
    d_ff: int = 1536
    max_len: int = 32

    @nn.compact
    def __call__(self, x, deterministic=True):
        B, S = x.shape
        # Token + positional embeddings
        tok_emb = nn.Embed(self.vocab_size, self.d_model)(x)
        pos_emb = nn.Embed(self.max_len, self.d_model)(jnp.arange(S))
        h = tok_emb + pos_emb  # [B, S, d_model]
        # Causal mask: each position can only attend to earlier positions
        causal_mask = jnp.tril(jnp.ones((S, S)))[None, None, :, :]
        # Transformer layers
        for _ in range(self.n_layers):
            h = TransformerBlock(
                self.d_model, self.n_heads, self.d_ff
            )(h, causal_mask, deterministic)
        # Final LayerNorm + output projection
        h = nn.LayerNorm()(h)
        logits = nn.Dense(self.vocab_size)(h)  # [B, S, vocab_size]
        return logits

# ── Training setup ──
model = AdditionTransformer()
rng = jax.random.PRNGKey(42)
dummy = jnp.zeros((1, 32), dtype=jnp.int32)
params = model.init(rng, dummy)

# Count parameters
n_params = sum(p.size for p in jax.tree.leaves(params))
print(f"Total parameters: {n_params:,}")  # ~10.6M

# Optimizer: AdamW with cosine decay
schedule = optax.warmup_cosine_decay_schedule(
    init_value=0.0,
    peak_value=3e-4,
    warmup_steps=500,
    decay_steps=50000,
    end_value=1e-5,
)
optimizer = optax.adamw(schedule, weight_decay=0.01)
opt_state = optimizer.init(params)

# ── Loss function ──
def loss_fn(params, batch, dropout_rng):
    logits = model.apply(params, batch['input_ids'], deterministic=False,
                         rngs={'dropout': dropout_rng})
    # Cross-entropy loss, masked to answer tokens only
    labels = jax.nn.one_hot(batch['targets'], 13)
    loss = -jnp.sum(labels * jax.nn.log_softmax(logits), axis=-1)
    loss = (loss * batch['loss_mask']).sum() / batch['loss_mask'].sum()
    return loss

# ── Training step ──
@jax.jit
def train_step(params, opt_state, batch, dropout_rng):
    loss, grads = jax.value_and_grad(loss_fn)(params, batch, dropout_rng)
    updates, opt_state = optimizer.update(grads, opt_state, params)
    params = optax.apply_updates(params, updates)
    return params, opt_state, loss

# ── Training loop ──
for step in range(50000):
    batch = make_batch(batch_size=128, max_digits=5)
    # Fresh dropout key every step; a fixed key would reuse the same mask
    rng, dropout_rng = jax.random.split(rng)
    params, opt_state, loss = train_step(params, opt_state, batch, dropout_rng)
    if step % 500 == 0:
        print(f"Step {step:5d} | Loss: {loss:.4f}")
```
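The Chinchilla paper found that for a compute budget $C \approx 6ND$ (with $N$ parameters and $D$ training tokens), the compute-optimal model size and token count scale at roughly the same rate: $N_{\text{opt}} \propto C^{0.5}$ and $D_{\text{opt}} \propto C^{0.5}$, which works out to about 20 tokens per parameter, $D_{\text{opt}} \approx 20\,N_{\text{opt}}$.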
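For our 10.6M parameter model, the Chinchilla-optimal token budget is approximately 10.6M × 20 = 212M tokens. Each training example is about 15 tokens on average, so we need about 212M / 15 ≈ 14 million training examples. At batch size 128, that is about 110,000 training steps. The 50,000-step loop above therefore covers only about 96M tokens (50,000 × 128 × 15), a little under half of the Chinchilla-optimal budget; extend the loop and the schedule's decay_steps if you want to train at the optimal point.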
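The arithmetic above can be reproduced in a few lines. This is a sketch under the same assumptions as the text (20 tokens per parameter, about 15 tokens per example, batch size 128, and C ≈ 6ND); `chinchilla_budget` is a hypothetical helper name.

```python
def chinchilla_budget(n_params, tokens_per_param=20,
                      tokens_per_example=15, batch_size=128):
    """Translate a parameter count into tokens, examples, steps, and FLOPs."""
    tokens = n_params * tokens_per_param          # D ~ 20 N
    examples = tokens / tokens_per_example
    steps = examples / batch_size
    flops = 6 * n_params * tokens                 # C ~ 6 N D
    return {"tokens": tokens, "examples": examples,
            "steps": steps, "flops": flops}


if __name__ == "__main__":
    budget = chinchilla_budget(10_600_000)
    for name, value in budget.items():
        print(f"{name:>9}: {value:.3e}")
```

Changing `tokens_per_example` matters more than it looks: if your problems average closer to 12 tokens, the same token budget needs about 25% more training steps.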
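For the capstone comparison, we implement the same model in both dense and MoE variants. The MoE variant replaces the FFN in every other layer (layers 1, 3, 5) with an MoE layer. Note that with top-2 routing each token passes through two experts, so active parameters match the dense model exactly only if each expert is half the width of the dense FFN (d_ff = 768); with full-width experts (d_ff = 1536), each MoE layer activates roughly one extra FFN's worth of parameters per token. Treat the "same" entries in the table as approximate.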
| Config | Dense | MoE (8 experts, top-2) |
|---|---|---|
| d_model | 384 | 384 |
| n_layers | 6 | 6 (3 dense FFN + 3 MoE) |
| d_ff | 1536 | 1536 per expert |
| Active params per token | 10.6M | ~10.6M (same) |
| Total params | 10.6M | ~31M (3 layers × 8 experts each) |
| FLOPs per token | ~127M | ~127M (same — top-2 of 8) |
| Memory | ~42 MB (FP32) | ~124 MB (FP32) |
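The total and active parameter counts in this comparison can be sanity-checked with a small calculation. This is a rough sketch, assuming weights-only FFNs, a linear router of size d_model × num_experts (the router shape is an assumption, since the MoE layer's internals are not shown here), and a dense baseline of 10.6M; `moe_accounting` is a hypothetical helper name.

```python
def moe_accounting(dense_total=10_600_000, d_model=384, dense_d_ff=1536,
                   expert_d_ff=1536, n_moe_layers=3, num_experts=8, top_k=2):
    """Return (total_params, active_params) for the alternating-layer MoE."""
    dense_ffn = 2 * d_model * dense_d_ff      # FFN being replaced
    expert_ffn = 2 * d_model * expert_d_ff    # one expert
    router = d_model * num_experts            # assumed linear router
    # Remove the dense FFN from each MoE layer, then add experts + router
    base = dense_total - n_moe_layers * dense_ffn
    total = base + n_moe_layers * (num_experts * expert_ffn + router)
    active = base + n_moe_layers * (top_k * expert_ffn + router)
    return total, active


if __name__ == "__main__":
    for width in (1536, 768):
        total, active = moe_accounting(expert_d_ff=width)
        print(f"expert d_ff={width}: total={total/1e6:.1f}M, "
              f"active={active/1e6:.1f}M")
```

Under these assumptions, full-width experts give roughly 35M total and 14M active parameters, somewhat above the rounded figures in the table, while half-width experts (d_ff = 768) keep active parameters at about 10.6M with roughly 21M total. Whichever variant you build, report the numbers your own parameter count produces.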
```python
class MoETransformerBlock(nn.Module):
    """Transformer block with MoE FFN instead of dense FFN."""
    d_model: int = 384
    n_heads: int = 6
    d_ff: int = 1536
    num_experts: int = 8
    top_k: int = 2

    @nn.compact
    def __call__(self, x, mask, deterministic=True):
        # Pre-norm attention (identical to dense)
        h = nn.LayerNorm()(x)
        h = nn.MultiHeadDotProductAttention(
            num_heads=self.n_heads,
            qkv_features=self.d_model,
            deterministic=deterministic,
        )(h, h, mask=mask)
        x = x + h
        # Pre-norm MoE FFN (replaces dense FFN)
        h = nn.LayerNorm()(x)
        h_flat = h.reshape(-1, self.d_model)  # [B*S, D]
        moe_out, aux_loss = MoELayer(
            num_experts=self.num_experts,
            top_k=self.top_k,
            expert_dim=self.d_ff,
        )(h_flat[jnp.newaxis])  # MoELayer expects [B, S, D]
        moe_out = moe_out.reshape(x.shape)
        x = x + moe_out
        self.sow('intermediates', 'aux_loss', aux_loss)
        return x

class MoEAdditionTransformer(nn.Module):
    """Addition transformer with MoE in alternating layers."""
    vocab_size: int = 13
    d_model: int = 384
    n_layers: int = 6

    @nn.compact
    def __call__(self, x, deterministic=True):
        B, S = x.shape
        tok_emb = nn.Embed(self.vocab_size, self.d_model)(x)
        pos_emb = nn.Embed(32, self.d_model)(jnp.arange(S))
        h = tok_emb + pos_emb
        causal_mask = jnp.tril(jnp.ones((S, S)))[None, None, :, :]
        for i in range(self.n_layers):
            if i % 2 == 1:  # Odd layers get MoE
                h = MoETransformerBlock()(h, causal_mask, deterministic)
            else:  # Even layers stay dense
                h = TransformerBlock()(h, causal_mask, deterministic)
        h = nn.LayerNorm()(h)
        return nn.Dense(self.vocab_size)(h)
```
What does the model actually learn? After training, we can visualize the attention patterns to understand the algorithm the model has discovered.
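For a problem like "4 8 + 5 3 = 1 0 1", the model needs to align the operands' digits by place value, add each aligned pair, and propagate carries from the ones place upward: 8 + 3 = 11 (write 1, carry 1), then 4 + 5 + 1 = 10 (write 0, carry 1), then emit the leading 1.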
Empirically, trained addition transformers develop a characteristic attention pattern: early layers attend strongly to positionally-aligned digits (learning the alignment), middle layers attend to adjacent output positions (propagating carries), and the final layers attend globally (handling edge cases like leading digits).
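For reference, the algorithm the model has to approximate is ordinary schoolbook addition. The sketch below spells it out on digit lists and also records the carry at each position, which is handy when comparing against attention patterns; `add_with_carries` is a hypothetical helper name, not something the training code above uses.

```python
def add_with_carries(a_digits, b_digits):
    """Schoolbook addition on most-significant-first digit lists.

    Returns (sum_digits, carries), where carries[i] is the carry produced
    at place i (0 = ones place).
    """
    a, b = a_digits[::-1], b_digits[::-1]   # process from the ones place
    out, carries, carry = [], [], 0
    for i in range(max(len(a), len(b))):
        da = a[i] if i < len(a) else 0
        db = b[i] if i < len(b) else 0
        s = da + db + carry
        out.append(s % 10)
        carry = s // 10
        carries.append(carry)
    if carry:
        out.append(carry)                   # extra leading digit
    return out[::-1], carries


if __name__ == "__main__":
    print(add_with_carries([4, 8], [5, 3]))
```

Note that the reference algorithm runs right to left, while the model generates the answer left to right. That mismatch is one reason the leading digit is hard: it depends on carries from every lower position.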
Every hyperparameter choice has a reason. Here is the complete training configuration with justifications.
| Hyperparameter | Value | Why |
|---|---|---|
| Optimizer | AdamW | Standard for transformers. Adam handles the varying gradient scales across attention and FFN layers. Weight decay prevents overfitting on the finite addition patterns. |
| Peak LR | 3e-4 | Standard for small transformers. Too high (>1e-3) causes training instability. Too low (<1e-5) slows convergence. |
| Warmup steps | 500 | Prevents early loss spikes from large random gradients. 1% of total training steps is a good rule of thumb. |
| Schedule | Cosine decay | Smooth decay from peak to near-zero. Better than step decay for transformers because it avoids discontinuous learning rate changes. |
| Weight decay | 0.01 | Mild regularization. The model should not memorize specific addition pairs but learn the algorithm. Weight decay encourages simpler solutions. |
| Batch size | 128 | Large enough for stable gradient estimates. Small enough to fit in TPU memory. With 128 problems per batch, the model sees ~1920 tokens per step. |
| Dropout | 0.1 | Regularization during training. Prevents co-adaptation of attention heads. In the code above it is applied to the attention output and the FFN output. |
| Max seq length | 32 | The longest possible sequence is "9 9 9 9 9 + 9 9 9 9 9 = 1 9 9 9 9 8" (18 tokens). 32 gives headroom for padding. |
Based on the scaling laws and empirical results from similar toy experiments, here is what you should expect:
| Metric | Dense 10.6M | MoE 31M (10.6M active) |
|---|---|---|
| Training loss at 50K steps | ~0.05 | ~0.03 |
| Exact-match accuracy (5-digit) | ~92% | ~96% |
| Exact-match accuracy (6-digit, OOD) | ~60% | ~75% |
| Training time (Colab TPU) | ~15 min | ~18 min |
| Step time | ~18 ms | ~22 ms |
The MoE model achieves lower loss at the same compute budget because it has 3x the parameters, giving it more capacity to memorize digit patterns and carry rules. The training time is only slightly longer because the active compute is the same — the overhead comes from routing and the auxiliary loss computation.
For the MoE variant, the ragged_dot inside the MoE layer is the bottleneck. Your Pallas kernel fuses the expert dispatching and matmul into a single kernel, avoiding the intermediate permutation.
```python
import jax
import jax.numpy as jnp

def moe_fused_kernel(
    x_ref,       # [BLOCK_M, D]: sorted input tokens
    w_up_ref,    # [D, D_FF_BLOCK]: expert up-projection tile
    w_down_ref,  # [D_FF_BLOCK, D]: expert down-projection tile
    o_ref,       # [BLOCK_M, D]: output accumulator
):
    """Fused expert forward: up-project, GELU, down-project."""
    # Up-projection
    h = x_ref[...] @ w_up_ref[...]  # [BLOCK_M, D_FF_BLOCK]
    # GELU activation (fused, no round-trip to HBM!)
    # 0.5 * h * erfc(-h / sqrt(2)) equals the exact GELU 0.5 * h * (1 + erf(h / sqrt(2)))
    h = h * jax.lax.erfc(-h / jnp.sqrt(2.0)) * 0.5
    # Down-projection and accumulate over D_FF tiles
    # (o_ref must be zero-initialized before the first tile)
    o_ref[...] += h @ w_down_ref[...]  # [BLOCK_M, D]
```
Before writing a kernel, you should know its theoretical maximum performance. The roofline model gives you this: it plots achievable performance (FLOP/s) against arithmetic intensity (FLOPs per byte of memory traffic).
When you profile a Pallas kernel and find it achieves 50 TFLOPS instead of the theoretical 275 TFLOPS, the roofline model tells you whether the problem is (a) your kernel is memory-bound and 50 TFLOPS is actually close to the ceiling, or (b) your kernel has poor compute utilization and there is room to improve.
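The roofline check described above is a two-line formula: attainable throughput is the smaller of the compute peak and intensity times memory bandwidth. The sketch below applies it to a large matmul and to an elementwise add. The 275 TFLOP/s peak comes from the text; the 1.2 TB/s bandwidth is an assumed figure for illustration, so substitute the numbers for your own hardware. `roofline` is a hypothetical helper name.

```python
def roofline(flops, bytes_moved, peak_flops, peak_bandwidth):
    """Return (intensity, ridge_point, attainable_flops_per_sec)."""
    intensity = flops / bytes_moved             # FLOPs per byte
    ridge = peak_flops / peak_bandwidth         # intensity where the roofs meet
    attainable = min(peak_flops, intensity * peak_bandwidth)
    return intensity, ridge, attainable


PEAK_FLOPS = 275e12       # from the text
PEAK_BANDWIDTH = 1.2e12   # assumed HBM bandwidth in bytes/s

if __name__ == "__main__":
    # 4096 x 4096 x 4096 matmul in bf16: 2*M*N*K FLOPs, three matrices moved
    n = 4096
    mm = roofline(2 * n**3, 3 * n * n * 2, PEAK_FLOPS, PEAK_BANDWIDTH)
    # Elementwise add of float32 vectors: 1 FLOP per 12 bytes (2 reads, 1 write)
    m = 1_000_000
    add = roofline(m, 12 * m, PEAK_FLOPS, PEAK_BANDWIDTH)
    for name, (i, ridge, att) in (("matmul", mm), ("vector add", add)):
        bound = "compute" if i >= ridge else "memory"
        print(f"{name}: intensity={i:.2f} FLOPs/byte, "
              f"ceiling={att/1e12:.2f} TFLOP/s ({bound}-bound)")
```

If a kernel's measured throughput is close to the ceiling this function returns, further tuning of the compute path will not help; the only lever left is moving fewer bytes, for example by fusing adjacent operations.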
Correct benchmarking of accelerator kernels is tricky. Here are the common pitfalls and how to avoid them.
```python
import jax
import jax.numpy as jnp
import time

def benchmark_kernel(fn, *args, warmup=5, repeats=20):
    """Correct benchmarking for JAX/Pallas kernels."""
    # Step 1: Compile the kernel (JIT warmup)
    # First call triggers compilation, so don't time this!
    out = fn(*args)
    out.block_until_ready()  # CRITICAL: wait for async compute

    # Step 2: Warmup runs (saturate caches, warm up the device)
    for _ in range(warmup):
        out = fn(*args)
        out.block_until_ready()

    # Step 3: Timed runs
    times = []
    for _ in range(repeats):
        start = time.perf_counter()
        out = fn(*args)
        out.block_until_ready()  # CRITICAL: JAX is async!
        end = time.perf_counter()
        times.append(end - start)

    # Report median (not mean, which is sensitive to outliers)
    times.sort()
    median_ms = times[len(times) // 2] * 1000
    print(f"Median: {median_ms:.3f} ms")
    print(f"Min: {times[0]*1000:.3f} ms, Max: {times[-1]*1000:.3f} ms")

    # Calculate FLOPS utilization
    # flops = ... (compute for your specific operation)
    # print(f"Achieved: {flops / (median_ms / 1000) / 1e12:.1f} TFLOPS")
    return median_ms

# COMMON MISTAKES:
# 1. Forgetting block_until_ready(): times only CPU dispatch, not compute
# 2. Including JIT compilation in the timing
# 3. Using mean instead of median (one outlier ruins your numbers)
# 4. Not warming up (first few runs are slower due to cache effects)
# 5. Benchmarking tiny inputs (overhead dominates, not representative)
```
JAX dispatches work asynchronously. When you call jnp.matmul(a, b), JAX returns a future immediately while the actual computation happens on the accelerator. If you time the function call without block_until_ready(), you are timing how fast Python can enqueue work (microseconds), not how fast the hardware executes it (milliseconds). This is the #1 benchmarking mistake in JAX. Always call .block_until_ready() on the output before reading the timer.

Putting it all together, here is the workflow for writing a production Pallas kernel from scratch:
Once you have completed the capstone, you need to document it in a way that a frontier lab interviewer can evaluate. Here is the structure that maximizes signal:
Training loss is a proxy. What we actually care about is: does the model get the right answer? For addition, we have a crisp metric: exact-match accuracy. The model's output must be exactly correct — "101" is wrong if the answer is "100." There is no partial credit.
```python
import random
import jax.numpy as jnp

def evaluate(params, model, num_samples=1000, max_digits=5):
    """Evaluate exact-match accuracy on random addition problems."""
    correct = 0
    for _ in range(num_samples):
        # Generate a random problem
        d1 = random.randint(1, max_digits)
        d2 = random.randint(1, max_digits)
        a = random.randint(0, 10**d1 - 1)
        b = random.randint(0, 10**d2 - 1)
        expected = a + b
        # Encode input: "1 2 3 + 4 5 ="
        prompt = ' '.join(str(a)) + ' + ' + ' '.join(str(b)) + ' ='
        tokens = encode(prompt)
        # Autoregressive generation
        generated = []
        for _ in range(8):  # max answer length
            input_ids = jnp.array([tokens + [12] * (32 - len(tokens))])
            logits = model.apply(params, input_ids)
            # tokens already includes everything generated so far, so the
            # logits at the last real position predict the next token.
            next_token = int(jnp.argmax(logits[0, len(tokens) - 1]))
            if next_token == 12:  # PAD = stop
                break
            generated.append(next_token)
            tokens.append(next_token)  # extend the input for the next step
        # Only an all-digit output can be correct; "+" or "=" counts as wrong
        if generated and all(t < 10 for t in generated):
            predicted = int(''.join(str(d) for d in generated))
        else:
            predicted = -1
        if predicted == expected:
            correct += 1
    return correct / num_samples
```
Raw accuracy hides the most interesting patterns. Break down errors by type to understand how the model fails, not just that it fails.
| Error Type | Example | Frequency | What It Tells You |
|---|---|---|---|
| Carry propagation | 999 + 1 = 990 (missed carry chain) | ~60% of errors | Model has not learned multi-hop carry. Need more training or deeper model. |
| Off-by-one digit | 123 + 456 = 579 (correct!) but 123 + 457 = 579 (wrong by 1) | ~20% of errors | Model is "almost right" — the logit margin for the correct vs incorrect digit is tiny. More training data at this difficulty level helps. |
| Wrong length | 99 + 1 = 10 (should be 100, missing a digit) | ~10% of errors | Model has not learned that addition can increase the number of digits. This is fundamentally about predicting the output length before generating. |
| Positional confusion | 12 + 34 = 55 (treated as 21 + 34) | ~5% of errors | Positional encoding is not strong enough to distinguish digit positions. Consider explicit positional markers or reversed digit order. |
| Random/garbage | 123 + 456 = 333 | ~5% of errors | Model has not converged for this input range. Insufficient training. |
```python
import random

def error_analysis(params, model, num_samples=500):
    """Categorize errors by type.

    generate_answer, needs_carry, and carry_error are helper functions
    assumed to be defined elsewhere. The 'positional' bucket is listed for
    completeness but is not detected by the rules below.
    """
    errors = {'carry': 0, 'off_by_one': 0, 'wrong_length': 0,
              'positional': 0, 'random': 0}
    total_errors = 0
    for _ in range(num_samples):
        a, b = random.randint(0, 99999), random.randint(0, 99999)
        expected = a + b
        predicted = generate_answer(params, model, a, b)
        if predicted != expected:
            total_errors += 1
            # Classify the error
            exp_str, pred_str = str(expected), str(predicted)
            if len(exp_str) != len(pred_str):
                errors['wrong_length'] += 1
            elif abs(expected - predicted) == 1:
                errors['off_by_one'] += 1
            elif needs_carry(a, b) and carry_error(expected, predicted):
                errors['carry'] += 1
            else:
                errors['random'] += 1
    print(f"Total errors: {total_errors}/{num_samples}")
    for k, v in errors.items():
        print(f"  {k}: {v} ({v/max(total_errors,1)*100:.1f}%)")
```
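The error analysis code calls `needs_carry` and `carry_error` without defining them. One possible way to write them is sketched below. These are hypothetical implementations, not the author's: in particular, `carry_error` uses a simple heuristic (a single dropped or spurious carry shifts the result by a power of ten), which will misclassify some cases.

```python
def carry_positions(a, b):
    """Digit positions (0 = ones place) where adding a and b produces a carry."""
    positions, carry, pos = [], 0, 0
    while a > 0 or b > 0:
        s = a % 10 + b % 10 + carry
        carry = s // 10
        if carry:
            positions.append(pos)
        a, b, pos = a // 10, b // 10, pos + 1
    return positions


def needs_carry(a, b):
    """True if a + b involves at least one carry."""
    return bool(carry_positions(a, b))


def longest_carry_chain(a, b):
    """Length of the longest run of consecutive carrying positions."""
    longest = run = 0
    prev = None
    for p in carry_positions(a, b):
        run = run + 1 if prev is not None and p == prev + 1 else 1
        longest = max(longest, run)
        prev = p
    return longest


def carry_error(expected, predicted):
    """Heuristic: the answer is off by exactly 10, 100, 1000, ..."""
    diff = abs(expected - predicted)
    return diff >= 10 and str(diff) == '1' + '0' * (len(str(diff)) - 1)
```

Bucketing accuracy by `longest_carry_chain` is often more informative than the coarse error categories, because it shows directly whether accuracy degrades as the carry chain gets longer.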
Throwing 5-digit addition at a randomly initialized model is like teaching calculus before arithmetic. Curriculum learning starts with easy problems and gradually increases difficulty. For addition:
Why does curriculum learning help? The model first masters the basic pattern (single-digit addition, no carry), then learns single carries, then carry chains. Each skill builds on the previous one. Without curriculum, the model sees hard problems before it has learned the prerequisites, wasting gradient signal on problems it has no hope of solving yet.
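A curriculum can be as simple as a function from training step to the maximum operand length. The sketch below ramps linearly from 1 digit to 5 digits over the first 80% of training; the 80% ramp and the linear shape are assumptions chosen for illustration, and `curriculum_max_digits` is a hypothetical helper name.

```python
def curriculum_max_digits(step, total_steps=50_000, max_digits=5):
    """Maximum operand length allowed at a given training step."""
    ramp_steps = (total_steps * 4) // 5        # ramp over the first 80%
    frac = min(step / ramp_steps, 1.0)
    return 1 + int(frac * (max_digits - 1))


if __name__ == "__main__":
    for step in (0, 10_000, 20_000, 30_000, 40_000, 50_000):
        print(step, curriculum_max_digits(step))
```

In the training loop this would replace the fixed argument, for example `make_batch(batch_size=128, max_digits=curriculum_max_digits(step))`. Because the generator samples operand lengths uniformly up to `max_digits`, easier problems stay in the mix as harder ones are unlocked, which helps the model avoid forgetting the early stages.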
To derive Chinchilla-style scaling laws for your toy model, you need to train multiple model sizes at multiple compute budgets. Here is the experiment grid:
| Model Size | d_model | n_layers | Params | Token Budgets to Test |
|---|---|---|---|---|
| Tiny | 128 | 4 | 0.8M | 5M, 10M, 20M, 50M tokens |
| Small | 256 | 4 | 3.1M | 10M, 30M, 60M, 120M tokens |
| Medium | 384 | 6 | 10.6M | 50M, 100M, 200M, 400M tokens |
| Large | 512 | 8 | 25M | 100M, 250M, 500M, 1B tokens |
| XL | 768 | 8 | 56M | 250M, 500M, 1B, 2B tokens |
For each (model_size, token_budget) pair, train to completion and record the final loss. Then plot loss vs compute (= 6 × params × tokens) and fit the power law $L(C) = A \cdot C^{-\alpha} + L_0$. The fitted parameters $A$, $\alpha$, and $L_0$ are your toy-model scaling law. For reference, Chinchilla's parametric fit has a different form, $L(N, D) = E + A/N^{\alpha} + B/D^{\beta}$, with A = 406.4 and α = 0.34 on the parameter term, so its constants are not directly comparable to a fit in $C$. Your values will also differ because addition is not natural language, but the power-law structure should hold.
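One dependency-free way to fit this power law is to scan candidate values of the irreducible loss L0 and, for each, run ordinary least squares in log-log space. This is a sketch of that approach rather than a recommended estimator (a nonlinear least-squares routine would be the usual choice); `fit_power_law` is a hypothetical helper name and the data in the example is synthetic, generated from known constants so that the fit can be checked.

```python
import math


def fit_power_law(compute, loss, l0_grid):
    """Fit L(C) = A * C**(-alpha) + L0; returns (A, alpha, L0)."""
    best = None
    xs = [math.log(c) for c in compute]
    n = len(xs)
    mx = sum(xs) / n
    sxx = sum((x - mx) ** 2 for x in xs)
    for l0 in l0_grid:
        if l0 >= min(loss):
            continue  # log(L - L0) must be defined for every point
        ys = [math.log(l - l0) for l in loss]
        my = sum(ys) / n
        slope = sum((x - mx) * (y - my) for x, y in zip(xs, ys)) / sxx
        intercept = my - slope * mx
        sse = sum((y - (intercept + slope * x)) ** 2 for x, y in zip(xs, ys))
        if best is None or sse < best[0]:
            best = (sse, math.exp(intercept), -slope, l0)
    _, a, alpha, l0 = best
    return a, alpha, l0


if __name__ == "__main__":
    # Synthetic data from L = 50 * C**-0.1 + 0.02 (made-up constants)
    compute = [1e13, 1e14, 1e15, 1e16, 1e17]
    loss = [50 * c ** -0.1 + 0.02 for c in compute]
    grid = [i / 1000 for i in range(50)]
    print(fit_power_law(compute, loss, grid))
```

With real runs the points will be noisy, so fit only the lower envelope (the best loss at each compute level) and check that the fitted L0 is not pinned to an edge of the grid.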
The canvas below simulates the training loss curves for both the dense and MoE models. Watch how the MoE model converges faster and reaches a lower final loss, despite using the same compute per step.
Figure: simulated training loss curves for the 10.6M addition transformer. Dense (orange) vs MoE 8-expert top-2 (teal).
You have spent fifteen chapters absorbing the technical foundation of what frontier AI labs expect from their research engineers. This final chapter is your war chest: the reading list, the interview question bank, the portfolio strategy, and the cross-references to every relevant Engineermaxxing resource. Bookmark this chapter. Return to it before every interview.
Frontier lab engineering roles split into two broad categories. Your path determines which skills to emphasize, but the strongest candidates have depth in both.
| Dimension | Below the Stack | Above the Abstraction |
|---|---|---|
| Focus | Making models run fast: kernels, compilers, distributed training, hardware | Making models do useful things: agents, evals, RLHF, data, reasoning |
| Typical titles | ML Systems Engineer, Kernel Engineer, Training Infrastructure | Research Engineer, Applied ML, Agent Infrastructure |
| Core skills | CUDA/Triton/Pallas, XLA, distributed systems, profiling, quantization | Prompt engineering, RLHF/DPO, evaluation design, tool use, planning |
| Interview emphasis | 70% CODE + DEBUG, 30% CONCEPT + DESIGN | 40% DESIGN + FRONTIER, 40% CODE, 20% CONCEPT |
| Portfolio artifact | Custom kernel that beats a baseline, scaling law reproduction | Agent system with real-world evals, RLHF pipeline, benchmark contribution |
| Labs hiring heavily | Anthropic, Google DeepMind, xAI, Meta FAIR (Infra) | Anthropic, OpenAI, Google DeepMind, Cohere |
Frontier labs filter hundreds of applications. Your portfolio needs to signal three things in under 60 seconds of scanning:
Concrete portfolio items, ranked by signal strength:
| Rank | Artifact | Time Investment | Signal |
|---|---|---|---|
| 1 | Capstone (Ch 15): addition transformer + MoE + Pallas kernel + scaling laws | 20-40 hours | Maximum — covers all dimensions |
| 2 | Reproduce a recent paper's key result (e.g., FlashAttention-4 on a subset) | 40-80 hours | Very high — proves deep understanding |
| 3 | Open-source contribution to JAX, vLLM, or similar project | 20-60 hours | High — proves you work with production code |
| 4 | Blog post explaining a paper with original code and experiments | 10-20 hours | Medium-high — good communication signal |
| 5 | Technical Twitter/blog presence discussing current research | Ongoing | Medium — shows engagement, not depth |
These are the papers you need to have read — and be prepared to discuss in detail — for a frontier lab interview. They are organized into five tracks. Within each track, papers are listed in recommended reading order (build on each other).
The compression frontier. Each paper introduces a technique that the next paper builds on. By the end, you should be able to explain why QTIP achieves near-lossless INT2 quantization.
| # | Paper | Key Contribution | Read After |
|---|---|---|---|
| 1 | LLM.int8() (Dettmers et al., 2022) | Mixed-precision decomposition: outlier features in FP16, rest in INT8 | — |
| 2 | GPTQ (Frantar et al., 2023) | One-shot weight quantization using inverse Hessian for error compensation | LLM.int8() |
| 3 | QuIP (Chee et al., 2023) | Random orthogonal transform before quantization — provably optimal incoherence | GPTQ |
| 4 | QuIP# (Tseng et al., 2024) | Lattice codebooks + incoherence processing for near-lossless 2-bit weights | QuIP |
| 5 | QTIP (Tseng et al., 2024) | Trellis-coded quantization — information-theoretic lower bound on compression | QuIP# |
| 6 | AQLM (Egiazarian et al., 2024) | Additive vector quantization for extreme 1-2 bit compression | QuIP# |
The computational engine of transformers. FlashAttention is not one paper — it is a lineage. Each version solves the previous version's limitation.
| # | Paper | Key Contribution | Read After |
|---|---|---|---|
| 1 | FlashAttention (Dao et al., 2022) | IO-aware exact attention: tiled computation that never materializes the full attention matrix in HBM | — |
| 2 | FlashAttention-2 (Dao, 2023) | Better work partitioning across warps: 2x faster than FA1 by reducing shared memory reads | FA1 |
| 3 | FlashAttention-3 (Shah et al., 2024) | Hopper-specific: warp specialization, FP8 attention, asynchronous softmax | FA2 |
| 4 | FlashAttention-4 (2025) | KV-cache-aware decode attention, variable-length prefill batching | FA3 |
| 5 | SnapKV (Li et al., 2024) | KV cache compression: identify and retain only informative KV pairs | FA2 |
The low-level tools that let you write custom operations. Understanding the progression from Triton to ThunderKittens to CuTe is essential for systems interviews.
| # | Paper/System | Key Contribution | Read After |
|---|---|---|---|
| 1 | Triton (Tillet et al., 2019) | Python-based GPU kernel DSL with automatic tiling and memory management | — |
| 2 | ThunderKittens (Spector et al., 2024) | Embedded C++ DSL for writing CUDA kernels at the warpgroup level | Triton |
| 3 | CuTe (NVIDIA, 2024) | Layout algebra for describing multi-dimensional data layouts on GPU memory hierarchies | ThunderKittens |
| 4 | Pallas (Google, 2023-2024) | JAX kernel authoring for TPU and GPU with BlockSpec-based data movement | Triton |
The "above the abstraction" reading list. These papers show how LLMs are used to build systems, not just generate text.
| # | Paper | Key Contribution | Read After |
|---|---|---|---|
| 1 | Toolformer (Schick et al., 2023) | Self-supervised tool use: model learns when and how to call APIs | — |
| 2 | FunSearch (Romera-Paredes et al., 2024) | LLM-guided program search that discovers new mathematical results | Toolformer |
| 3 | AlphaEvolve (DeepMind, 2025) | Evolution-guided code generation with automated verification | FunSearch |
| 4 | SWE-bench (Jimenez et al., 2024) | Real-world software engineering benchmark for LLM agents | FunSearch |
The theoretical foundation that drives every training decision at frontier labs. Without understanding scaling laws, you cannot reason about resource allocation.
| # | Paper | Key Contribution | Read After |
|---|---|---|---|
| 1 | Scaling Laws for Neural LMs (Kaplan et al., 2020) | Power-law relationships between compute, data, model size, and loss | — |
| 2 | Chinchilla (Hoffmann et al., 2022) | Corrected scaling: tokens matter as much as parameters. Optimal ratio ~20:1. | Kaplan |
| 3 | The Scaling Book: How to Scale Your Model (Austin et al., Google DeepMind, 2025) | Comprehensive roofline analysis, transformer math, parallelism, and inference on TPU/GPU | Chinchilla |
| 4 | Distillation Scaling Laws (Busbridge et al., Apple ICML 2025) | Scaling laws for distillation — the capacity gap phenomenon where stronger teachers hurt | Chinchilla |
Five questions per dimension, calibrated to staff-level expectations. For each question, we note the key insight the interviewer is looking for.
| # | Question | Key Insight They Want |
|---|---|---|
| 1 | "Derive the memory bandwidth needed to serve a 70B parameter model at 100 tokens/sec in FP16." | 70B × 2 bytes = 140 GB. At 100 tok/s, you load all weights 100 times/sec = 14 TB/s. A single H100 has 3.35 TB/s HBM bandwidth, so you need at least ceil(14/3.35) = 5 GPUs minimum. But tensor parallelism overhead adds ~10%, so 6 GPUs in practice. |
| 2 | "Why does KV cache grow linearly with sequence length and batch size? Derive the formula." | KV cache per layer = 2 × batch × seq × n_heads × d_head × dtype_bytes (use the number of KV heads under GQA/MQA). For 32 layers, 70B model (n_heads=64, d_head=128), batch=32, seq=4096, FP16: 2 × 32 × 4096 × 64 × 128 × 2 × 32 ≈ 137 GB for the whole batch. This is why KV cache, not weights, is often the memory bottleneck. |
| 3 | "Explain why the auxiliary load-balancing loss in MoE training uses f_i × p_i instead of just minimizing the variance of f_i." | f_i (token fraction) is not differentiable — it depends on argmax. p_i (mean router probability) is differentiable via softmax. Their product lets gradients flow while still penalizing imbalance. |
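| 4 | "What is the theoretical minimum memory for training a B-parameter model with Adam?" | Weights: 2B bytes (FP16). Gradients: 2B bytes. Adam states (m and v, FP32): 4B bytes each = 8B bytes. Total: 12B bytes minimum, or 16B bytes if you also keep FP32 master weights as standard mixed-precision training does. For 7B params: 84 GB (112 GB with master weights), before activations. This is why FSDP/ZeRO-3 shards everything across GPUs. |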
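| 5 | "Derive the arithmetic intensity of self-attention for sequence length S and head dimension d." | Q@K^T: 2S^2d FLOPs, reading Q and K (2Sd elements). Softmax: O(S^2). Attn@V: 2S^2d FLOPs. Total FLOPs ≈ 4S^2d. Total memory traffic ≈ 4Sd + 2S^2 elements (Q, K, V, O plus writing and re-reading the S×S scores). Intensity = 4S^2d / (4Sd + 2S^2) ≈ 2d for large S. Since d=128 typically, intensity ≈ 256 FLOPs per element, so attention approaches the compute-bound regime for large S. Fused kernels that never write the S×S scores to HBM (FlashAttention) remove the 2S^2 term and raise the intensity further. |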
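The back-of-the-envelope numbers in questions 1 and 2 are easy to get wrong under interview pressure, so it helps to have them as code. This is a minimal sketch using the same assumptions as the answers above (batch size 1 decode that reads every weight once per token, 3.35 TB/s of HBM bandwidth per GPU, FP16 storage); both function names are hypothetical.

```python
import math


def kv_cache_bytes(n_layers, batch, seq_len, n_kv_heads, d_head, dtype_bytes=2):
    """Total KV cache size: K and V for every layer, sequence, and position."""
    return 2 * n_layers * batch * seq_len * n_kv_heads * d_head * dtype_bytes


def min_gpus_for_decode(n_params, tokens_per_sec, dtype_bytes=2, hbm_bw=3.35e12):
    """Bandwidth needed to stream all weights once per token, and GPUs implied."""
    required_bw = n_params * dtype_bytes * tokens_per_sec
    return required_bw, math.ceil(required_bw / hbm_bw)


if __name__ == "__main__":
    kv = kv_cache_bytes(n_layers=32, batch=32, seq_len=4096,
                        n_kv_heads=64, d_head=128)
    print(f"KV cache: {kv / 1e9:.1f} GB")
    bw, gpus = min_gpus_for_decode(70e9, 100)
    print(f"Required bandwidth: {bw / 1e12:.1f} TB/s -> at least {gpus} GPUs")
```

Evaluating the question 2 settings gives about 137 GB. Dropping `n_kv_heads` from 64 to 8, as grouped-query attention would, cuts that by 8x, which is a good follow-up point to raise in the interview.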
| # | Question | Key Insight They Want |
|---|---|---|
| 1 | "Design a system that trains a 100B MoE model on 1000 GPUs." | Layered parallelism: expert parallel for MoE layers, tensor parallel within each expert group, pipeline parallel across layer chunks, data parallel for throughput. The all-to-all communication for expert routing is the bottleneck — co-locate experts that frequently share tokens. |
| 2 | "Your serving system handles 10K requests/sec with a 7B model. The team wants to switch to a 70B MoE. What changes?" | Memory: 10x more weight storage but only 2-3x more active compute. Need multi-GPU inference with expert parallelism. Routing adds latency variance — some requests hit "hot" experts. Solution: disaggregate prefill (compute-bound) from decode (memory-bound) on different hardware pools. |
| 3 | "Design the evaluation pipeline for a new foundation model release." | Multi-level: unit tests (basic capabilities), benchmark suites (MMLU, HumanEval, etc.), red-teaming (adversarial safety), A/B testing (human preference), long-context specific tests, multilingual parity checks. Key: evals must run in CI, not manually. |
| 4 | "Your training run has been running for 3 days. Loss suddenly spikes. Walk me through your debug process." | Systematic: (1) check for NaN/Inf in gradients, (2) check learning rate schedule, (3) inspect the exact batch that caused the spike, (4) check for hardware failures (GPU ECC errors), (5) check data pipeline for corrupted examples, (6) if gradient norm exploded, was the spike in a specific layer? (7) rollback to last good checkpoint, reduce lr, resume. |
| 5 | "Design a system for continuous model improvement using user feedback." | RLHF/DPO pipeline: collect preference pairs from user thumbs up/down, train reward model (or use DPO directly), fine-tune with PPO/DPO, evaluate on held-out preferences + safety benchmarks. Key: prevent reward hacking, maintain base capability, detect distribution shift in user queries. |
| # | Question | Key Insight They Want |
|---|---|---|
| 1 | "Write the top-k router for an MoE layer in JAX. Include the load-balancing loss." | See chapter 13 implementation. Key details: softmax only over selected experts, aux loss uses f_i × p_i product, handle the non-differentiability of top-k selection. |
| 2 | "Implement a Pallas kernel for batched vector add on TPU." | See chapter 14 vector_add example. Show BlockSpec, grid, index_map. Explain the compilation pipeline. |
| 3 | "Write a KV cache implementation that supports dynamic batch sizes." | Pre-allocate max capacity, track per-sequence length, use scatter/gather for non-contiguous sequences. PagedAttention: page table + block allocator. |
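| 4 | "Implement gradient checkpointing for a transformer layer." | jax.checkpoint (or torch.utils.checkpoint): recompute activations during backward instead of storing them. Full checkpointing costs roughly one extra forward pass (about 33% more compute) in exchange for a large cut in activation memory; ~60% is a typical figure, but the saving depends on which layers you checkpoint. Choose which layers to checkpoint based on activation size. |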
| 5 | "Write the data loading pipeline for a multi-host TPU training run." | tf.data or grain: shard dataset across hosts, prefetch to device, handle host-device async, deterministic shuffling (seeded per epoch for reproducibility), handle restarts from checkpoint. |
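The key details listed for the top-k router question can be sketched without any framework. The block below is plain Python rather than JAX so that it runs with no dependencies. It assumes the Switch Transformer form of the auxiliary loss (number of experts times the sum of f_i × p_i), and the function names are hypothetical.

```python
import math

def softmax(xs):
    m = max(xs)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

def route_top_k(logits, k):
    """logits: list of per-token lists, one router logit per expert.

    Returns (assignments, aux_loss) where assignments[t] is a list of
    (expert_index, gate_weight) pairs for token t.
    """
    num_tokens, num_experts = len(logits), len(logits[0])
    counts = [0] * num_experts       # how many (token, slot) pairs each expert receives
    prob_sums = [0.0] * num_experts  # sum of full-softmax router probabilities per expert
    assignments = []
    for token_logits in logits:
        probs = softmax(token_logits)
        for e, p in enumerate(probs):
            prob_sums[e] += p
        top = sorted(range(num_experts), key=lambda e: token_logits[e], reverse=True)[:k]
        # Gate weights: softmax over the selected experts only.
        gates = softmax([token_logits[e] for e in top])
        for e in top:
            counts[e] += 1
        assignments.append(list(zip(top, gates)))
    # f_i: fraction of routed slots sent to expert i (a count, so not differentiable).
    f = [c / (num_tokens * k) for c in counts]
    # p_i: mean router probability for expert i (in an autodiff framework, the gradient flows here).
    p = [s / num_tokens for s in prob_sums]
    aux_loss = num_experts * sum(fi * pi for fi, pi in zip(f, p))
    return assignments, aux_loss

_, balanced = route_top_k([[5.0, 0.0], [0.0, 5.0]], k=1)
_, collapsed = route_top_k([[5.0, 0.0], [5.0, 0.0]], k=1)
print(f"balanced aux loss {balanced:.3f}, collapsed aux loss {collapsed:.3f}")
```

The loss is lowest when load and probability mass are spread evenly, and it grows when both concentrate on the same expert. In an interview, say explicitly that f_i is treated as a constant and the gradient reaches the router only through p_i and the gate weights.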
| # | Question | Key Insight They Want |
|---|---|---|
| 1 | "Your MoE model's loss is 10% worse than expected. All experts have equal utilization. What's wrong?" | Equal utilization is suspicious: it may mean the router has collapsed to uniform random routing (all logits near zero). Check per-token router entropy. A healthy router is balanced across the batch but still expresses a preference for each token, so per-token entropy should sit clearly below its maximum of log(num_experts). If it is at the maximum, the router is not specializing and the MoE is just an ensemble of random experts, losing the specialization benefit. |
| 2 | "A Pallas kernel produces correct results on small inputs but NaN on large inputs." | Likely a tile boundary issue: when input size is not divisible by block size, the last tile has uninitialized memory or goes out of bounds. Check padding, check boundary masking, verify grid dimension calculation. |
| 3 | "Training throughput dropped 30% after a JAX version upgrade." | Check XLA compilation: the new version may choose different fusion patterns. Run with XLA_FLAGS=--xla_dump_to=<dir> on both versions and diff the HLO. Check for new op lowerings that bypass custom patterns. Profile with TensorBoard to find the slow op. |
| 4 | "Your distributed training run hangs intermittently after 2-3 hours." | Classic symptoms of a slow or failing GPU. Check: (1) NCCL/TPU runtime logs for timeout errors, (2) individual GPU compute times (one outlier = hardware issue), (3) memory pressure causing OOM on one host, (4) network congestion in all-reduce (check for competing jobs). |
| 5 | "Your quantized INT8 model is 2% worse on English benchmarks but 15% worse on code generation." | Code tokens have sharper distributions (exact syntax matters). Quantization smooths the distribution, rounding "near-correct" logits to incorrect tokens. Fix: keep the code-heavy layers (likely the last few layers and the LM head) in FP16. Run per-layer sensitivity analysis specifically on code benchmarks. |
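For the equal-utilization MoE question above, two numbers separate "balanced and specialized" from "balanced because the router is uniform": the mean per-token entropy and the entropy of the batch-averaged routing distribution. This is a minimal sketch with a hypothetical `router_diagnostics` helper. Both values are normalized by log(num_experts), so 1.0 means maximum entropy.

```python
import math

def softmax(xs):
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

def entropy(probs):
    return -sum(p * math.log(p) for p in probs if p > 0.0)

def router_diagnostics(logits):
    """logits: list of per-token lists of router logits.

    Returns (mean per-token entropy, entropy of the batch-averaged
    distribution), each divided by log(num_experts).
    """
    num_experts = len(logits[0])
    max_entropy = math.log(num_experts)
    per_token = [softmax(row) for row in logits]
    mean_token_entropy = sum(entropy(p) for p in per_token) / len(per_token)
    averaged = [sum(p[e] for p in per_token) / len(per_token) for e in range(num_experts)]
    return mean_token_entropy / max_entropy, entropy(averaged) / max_entropy

# Collapsed router: every logit near zero, so every token is routed uniformly.
uniform = [[0.0] * 4 for _ in range(8)]
# Specialized router: each token strongly prefers one expert, and load is still even.
specialized = [[8.0 if e == t % 4 else 0.0 for e in range(4)] for t in range(8)]

print("uniform     (token, load):", router_diagnostics(uniform))
print("specialized (token, load):", router_diagnostics(specialized))
```

Both routers show perfectly even load (second number near 1.0), but only the collapsed one also has per-token entropy near 1.0. Utilization alone cannot tell them apart.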
| # | Question | Key Insight They Want |
|---|---|---|
| 1 | "What are the open problems in MoE scaling?" | Expert collapse at scale (>1000 experts), optimal granularity (few large vs many small), routing for multi-modal inputs, expert merging for inference efficiency, MoE-specific distillation. |
| 2 | "Where is quantization headed after INT4/INT2?" | Microscaling (MX) formats from the Open Compute Project (OCP): block-structured formats with per-block exponents. FP4 E2M1 on Blackwell GPUs. The frontier is 1-2 bit quantization with lookup-table decompression. Where the floor lies is an open question: ~1.5 bits per weight is sometimes quoted for current architectures, but treat it as an empirical rule of thumb rather than an information-theoretic bound. |
| 3 | "What is the biggest bottleneck in LLM inference today?" | Memory bandwidth for autoregressive decode. At batch size 1, each generated token requires loading all model weights. For a 70B FP16 model, that is 140 GB of weight reads per token. At 3.35 TB/s (H100 HBM bandwidth), the bandwidth-bound ceiling is ~24 tokens/sec per sequence, before counting KV cache reads or multi-GPU communication. Speculative decoding, compression, and batch scheduling are the active research frontiers. |
| 4 | "How would you design an LLM that can reliably use tools?" | The key challenge is not calling tools (any prompted model can emit a function call) but knowing when to call them and what to do with failures. Research frontiers: ReAct-style reasoning traces, function-calling fine-tuning with execution feedback, tool-use RLHF where the reward includes task completion not just format correctness. |
| 5 | "What changed in the last 6 months that you're most excited about?" | This is a culture-fit question. They want genuine enthusiasm about a specific development — not a list. Pick one thing you actually care about and explain why it matters technically. Good answers name a specific paper or system and articulate its implications. |
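The decode-bandwidth figure in the inference-bottleneck row above is worth being able to reproduce on a whiteboard. Here is a small sketch of the arithmetic, using the 70B and 3.35 TB/s numbers from the table. The function name is hypothetical, and the model ignores KV cache traffic, batching, and communication.

```python
def max_decode_tokens_per_sec(num_params, bytes_per_param, bandwidth_bytes_per_sec):
    """Bandwidth-bound ceiling for batch-size-1 decode.

    Every generated token must stream all weights from memory once, so
    tokens/sec cannot exceed bandwidth / weight_bytes.
    """
    weight_bytes = num_params * bytes_per_param
    return bandwidth_bytes_per_sec / weight_bytes

H100_BANDWIDTH = 3.35e12  # bytes/sec, the figure quoted in the table

fp16 = max_decode_tokens_per_sec(70e9, 2.0, H100_BANDWIDTH)
int4 = max_decode_tokens_per_sec(70e9, 0.5, H100_BANDWIDTH)
print(f"FP16: {fp16:.1f} tokens/sec")  # 23.9
print(f"INT4: {int4:.1f} tokens/sec")  # 95.7
```

The same formula explains why weight quantization speeds up decode almost linearly in bytes per parameter, and why batching helps: the weight read is amortized over every sequence in the batch.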
The papers and concepts in this chapter are not independent. They form a dependency graph — you cannot understand QTIP without understanding QuIP, and you cannot understand QuIP without understanding quantization fundamentals. Here is the critical path through the material.
Common preparation mistakes that waste time or actively hurt your candidacy:
| Anti-Pattern | Why It Fails | What to Do Instead |
|---|---|---|
| Grinding LeetCode | Frontier lab interviews are not about algorithms and data structures. They test ML systems knowledge. A perfect score on LeetCode does not help you derive the Chinchilla scaling law. | Spend that time on the capstone project or paper reading. If you must practice coding, implement ML algorithms (attention, quantization, MoE routing) not two-sum. |
| Reading 50 papers shallowly | Surface-level knowledge of many papers is less useful than deep knowledge of 5. Interviewers probe depth, not breadth. | Read 5-7 papers deeply. For each: implement the key algorithm, reproduce one figure, write a 3-paragraph critique. |
| Memorizing formulas | If you memorize the attention formula but cannot derive it from the query-key-value interpretation, you will fail the follow-up question. | Derive every formula from first principles. If you cannot derive it, you do not understand it. |
| Only studying theory | "I understand the concept of quantization" means nothing if you have never quantized a model and measured the accuracy drop. | Every concept must have a corresponding implementation. Theory + code + measurement = understanding. |
| Ignoring the lab's specific work | Discussing FlashAttention at a lab that uses TPUs shows you did not research the company. | Read the lab's recent publications. Know their stack (JAX vs PyTorch, TPU vs GPU). Reference their work in your answers. |
Frontier lab interview loops are intense. They typically span 4-6 hours across 1-2 days, with each session testing a different dimension. Here is the typical structure:
| Round | Duration | Format | What They Test |
|---|---|---|---|
| Phone Screen | 45-60 min | Video call with one interviewer | Basic competence: can you code, do you know ML fundamentals, are you articulate? |
| Technical Deep-Dive | 60-90 min | Whiteboard or shared doc with 1-2 senior engineers | One topic in depth: derivations, edge cases, failure modes. They want to see how deep your understanding goes. |
| System Design | 60 min | Whiteboard with a senior/staff engineer | Design a training pipeline, inference system, or evaluation framework. They want to see you make tradeoffs and justify decisions. |
| Coding | 60-90 min | Live coding in an IDE or Colab | Implement a non-trivial algorithm: attention mechanism, MoE router, quantization routine, or Pallas kernel. Clean code matters as much as correctness. |
| Research Discussion | 45-60 min | Conversation with a research lead | Discuss recent papers, open problems, your own research interests. They want intellectual curiosity and the ability to identify important problems. |
| Culture/Values | 30-45 min | Conversation with a team lead or manager | Collaboration style, conflict resolution, what motivates you. They want to know if you will thrive on the team. |
| Mistake | Example | Better Approach |
|---|---|---|
| Answering too broadly | "How would you speed up inference?" followed by "Well, there's quantization, pruning, distillation, caching, batching, compilation..." | Pick the most impactful one for the specific scenario. "For a 70B model serving chat at 100 req/s, the bottleneck is memory bandwidth during decode. I'd start with INT8 weight quantization, which halves the bytes read per token for up to a 2x decode speedup, then implement continuous batching for utilization." |
| Not asking clarifying questions | Jumping into a system design without understanding constraints | Spend the first 5 minutes asking: What hardware? What latency budget? What accuracy tolerance? What traffic pattern? These determine the entire design. |
| Memorized definitions without understanding | "FlashAttention uses tiling to reduce memory" (correct but shallow) | "FlashAttention exploits the fact that attention is a streaming reduction: you never need the full S×S matrix in memory. By tiling along the sequence dimension and keeping running softmax statistics, each tile needs only O(BLOCK_SIZE^2) on-chip memory, and the full O(S^2) score matrix is never written to HBM. The key insight is that softmax can be decomposed into local computations with a correction factor." |
| Ignoring failure modes | Presenting a system design without discussing what breaks | For every component, mention: "This fails when X. The mitigation is Y. If Y also fails, we gracefully degrade to Z." |
| Not writing tests | Submitting code without verification | After implementing, immediately write a test case. "Let me verify with a simple example: 12 + 34 should be 46..." Running the test live shows engineering discipline. |
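The "correction factor" in the FlashAttention answer above is easier to explain with code in hand. This is a minimal scalar sketch of the online softmax idea, not the FlashAttention kernel itself. It computes a softmax-weighted sum of values one block at a time and checks the result against the naive two-pass version. The function names are hypothetical.

```python
import math

def online_softmax_weighted_sum(scores, values, block_size):
    """Compute sum_j softmax(scores)_j * values_j without holding all weights at once."""
    running_max = float("-inf")
    normalizer = 0.0   # running sum of exp(score - running_max)
    accumulator = 0.0  # running sum of exp(score - running_max) * value
    for start in range(0, len(scores), block_size):
        s_block = scores[start:start + block_size]
        v_block = values[start:start + block_size]
        new_max = max(running_max, max(s_block))
        # Correction factor: rescale old statistics from the old max to the new max.
        # On the first block this is exp(-inf) = 0.0, which discards the empty state.
        correction = math.exp(running_max - new_max)
        weights = [math.exp(s - new_max) for s in s_block]
        normalizer = normalizer * correction + sum(weights)
        accumulator = accumulator * correction + sum(w * v for w, v in zip(weights, v_block))
        running_max = new_max
    return accumulator / normalizer

def reference_softmax_weighted_sum(scores, values):
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]
    return sum(w * v for w, v in zip(weights, values)) / sum(weights)

scores = [1.0, 3.0, -2.0, 10.0, 0.5, 4.0, 7.0]
values = [0.2, -1.0, 3.0, 0.5, 2.0, -0.7, 1.1]
for block_size in (1, 2, 3, 7):
    streamed = online_softmax_weighted_sum(scores, values, block_size)
    exact = reference_softmax_weighted_sum(scores, values)
    assert abs(streamed - exact) < 1e-9
print("online softmax matches the two-pass reference for every block size")
```

The real kernel keeps the same three statistics per query row, with a vector accumulator instead of a scalar. The loop at the bottom is also the kind of quick live test the "Not writing tests" row recommends.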
This question tests research taste. The interviewer wants to know: can you identify problems that matter? Here are five strong answers with the reasoning behind each.
| Open Problem | Why It Matters | Current Best Approach | Why It Is Not Solved |
|---|---|---|---|
| Reliable reasoning | LLMs can write poetry but struggle with multi-step logic. This limits their use in math, code verification, and scientific discovery. | Chain-of-thought prompting, verification-guided search (AlphaProof), process reward models | We do not understand why in-context reasoning works or when it fails. No theory predicts which problems a model can solve by reasoning vs which require more training data. |
| Efficient long-context | Real-world tasks (codebase understanding, document analysis, multi-turn conversations) require contexts of 100K-1M+ tokens. Current models degrade badly beyond 32K. | RoPE extension, ring attention, sliding window + global tokens | KV cache memory scales linearly with context. Attention quality degrades with distance. No architecture efficiently captures both local and global dependencies at million-token scale. |
| Data efficiency | Chinchilla suggests roughly 20 training tokens per parameter for compute-optimal training. By that rule, a 1T-parameter model needs about 20T tokens. We are running out of high-quality text data. | Synthetic data generation, data deduplication, curriculum learning | Synthetic data risks "model collapse" (self-reinforcing errors). There is no theoretical framework for predicting data quality impact on downstream capabilities. |
| Interpretability at scale | We deploy models we do not understand. This is unacceptable for safety-critical applications (medicine, law, autonomous vehicles). | Sparse autoencoders, probing classifiers, circuit analysis | Current methods work on small models but do not scale to 100B+ parameters. We have no way to verify that an interpretation is complete (covers all model behavior). |
| Compute efficiency cliff | Attention compute scales quadratically with sequence length, and total compute scales linearly with parameter count. We are approaching the limits of what current hardware can sustain. | MoE (this chapter!), linear attention, state-space models, speculative decoding | Every efficiency technique has tradeoffs (MoE = memory, SSM = reduced in-context learning). No single architecture is efficient across all workloads. |
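The data-efficiency row becomes concrete once you combine the 20-tokens-per-parameter rule with the C ≈ 6ND relationship used earlier in this lesson. The sketch below treats both as rules of thumb, not exact laws, and the helper names are hypothetical.

```python
import math

TOKENS_PER_PARAM = 20.0  # Chinchilla rule of thumb quoted in the table

def training_flops(num_params, num_tokens):
    """C ≈ 6 * N * D for dense transformer training."""
    return 6.0 * num_params * num_tokens

def compute_optimal(flops_budget):
    """Solve C = 6 * N * (20 * N) for N, then D = 20 * N."""
    num_params = math.sqrt(flops_budget / (6.0 * TOKENS_PER_PARAM))
    return num_params, TOKENS_PER_PARAM * num_params

n = 1e12                   # the 1T-parameter model from the table
d = TOKENS_PER_PARAM * n   # 20T tokens
c = training_flops(n, d)
print(f"tokens: {d:.2e}, training FLOPs: {c:.2e}")

n_opt, d_opt = compute_optimal(c)
print(f"round trip: N = {n_opt:.2e}, D = {d_opt:.2e}")
```

Under these assumptions the 1T model needs about 1.2e26 training FLOPs. Running the calculation backwards from a FLOP budget is the usual first step in a dense-versus-MoE comparison.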
The Google DeepMind "Scaling Book" is the most comprehensive reference on compute-optimal training. Here is a structured exercise checklist to work through it:
| Exercise | What You Build | What You Learn |
|---|---|---|
| 1. Reproduce Figure 1 | Train 5 model sizes (1M to 100M), fit power law to loss vs compute | How to measure scaling coefficients empirically |
| 2. Chinchilla grid search | For a fixed compute budget, sweep (N, D) pairs and plot the loss curve | The Chinchilla frontier is a valley, not a cliff |
| 3. MoE scaling analysis | Repeat exercise 2 with MoE models, varying expert count and top-k | MoE shifts the frontier: more params at same FLOP budget |
| 4. Downstream prediction | Fit scaling law on pretraining loss, predict downstream task accuracy | The relationship between loss and task performance is noisy but predictable |
| 5. Token quality impact | Train on curated vs web data at same token count, compare scaling curves | Data quality is a multiplier on effective compute |
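Exercise 1 comes down to a straight-line fit in log-log space. The sketch below fits loss = a · C^(−b) by ordinary least squares on (log C, log loss). The data points are synthetic, generated from a known a and b purely to show that the fit recovers them. A real scaling-law fit usually also includes an irreducible-loss offset, which this sketch omits.

```python
import math

def fit_power_law(compute, loss):
    """Fit loss = a * compute**(-b) via least squares in log-log space.

    Returns (a, b).
    """
    xs = [math.log(c) for c in compute]
    ys = [math.log(l) for l in loss]
    n = len(xs)
    mean_x, mean_y = sum(xs) / n, sum(ys) / n
    slope = (sum((x - mean_x) * (y - mean_y) for x, y in zip(xs, ys))
             / sum((x - mean_x) ** 2 for x in xs))
    intercept = mean_y - slope * mean_x
    return math.exp(intercept), -slope

# Synthetic data (not measurements): five compute budgets and a known power law.
TRUE_A, TRUE_B = 50.0, 0.05
compute = [1e15, 1e16, 1e17, 1e18, 1e19]
loss = [TRUE_A * c ** (-TRUE_B) for c in compute]

a, b = fit_power_law(compute, loss)
print(f"fitted a = {a:.3f}, b = {b:.4f}")
```

With real training runs the points will not sit exactly on a line, so plot the residuals before trusting an extrapolation.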
The project discussion is often the highest-signal part of the interview. Your capstone is the perfect project to discuss. Here is how to structure your walkthrough for maximum impact.
Every Engineermaxxing resource relevant to frontier lab preparation, organized by topic.
| Topic | Micro Lesson | Veanor Paper |
|---|---|---|
| Transformers | Transformer, Attention | — |
| Training | Distributed Training | — |
| Inference | LLM Inference | SnapKV |
| Quantization | — | LLM.int8(), QuIP, QuIP#, QTIP, AQLM |
| Attention | Attention | FlashAttention-4 |
| GPU Kernels | GPU Kernel Landscape | ThunderKittens |
| JAX | Thinking in JAX | — |
| Agents | FunSearch | AlphaEvolve |
Different labs emphasize different skills. Tailor your preparation to your target.
| Lab | Primary Stack | Interview Emphasis | Preparation Focus |
|---|---|---|---|
| Anthropic | JAX / TPU / Custom infra | Systems + safety. Deep technical + values alignment. | Pallas kernels, scaling laws, RLHF/Constitutional AI, interpretability basics. Read Anthropic's published research. |
| Google DeepMind | JAX / TPU / Gemini | Research engineering. Strong ML theory + implementation. | Scaling laws, MoE (Gemini uses MoE), JAX internals, distributed training. Be ready for math-heavy derivations. |
| OpenAI | PyTorch / NVIDIA GPUs / Azure | Impact + scale. Ship fast, iterate faster. | CUDA/Triton, distributed training, RLHF, evaluation design. Strong coding is essential — they test implementation speed. |
| Meta FAIR | PyTorch / NVIDIA GPUs / On-prem | Open research. Publishing + artifacts. | Core ML research skills, paper reproduction, open-source contributions. They value academic publication background. |
| xAI | JAX / TPU + GPU / Custom | Full-stack speed. Build fast, break things, fix them. | Broad systems knowledge, comfort with ambiguity, willingness to work on whatever is needed. Startup mentality. |
If you have an interview in 30 days, here is the optimal allocation of your preparation time. This assumes 2-3 hours per day of focused preparation.
Across hundreds of interviews, frontier lab hiring committees consistently cite the same differentiators:
| Signal | No-Hire Pattern | Hire Pattern |
|---|---|---|
| Depth | Knows what FlashAttention is. Cannot explain why online softmax is needed. | Can derive the online softmax correction factor from scratch. Knows when it breaks (very long sequences, numerical stability). |
| Taste | "I would try all the standard techniques and see what works." | "Given the constraints, I would start with X because Y. If that fails, the fallback is Z, which trades off A for B." |
| Ownership | "I worked on the training pipeline." (passive) | "I designed the calibration system. It reduced quantization regressions from 15% to 2%. Here is the specific design decision that mattered most." |
| Curiosity | Recites paper results. | Critiques paper methodology. "They claim 2x speedup, but their baseline is suboptimal. With a tuned Triton kernel as baseline, the improvement is closer to 1.3x." |
| Calibration | Confidently guesses when uncertain. | "I believe the answer is X based on Y, but I am not 100% sure about the constant factor. Let me derive it to check." |
Frontier labs value written communication as much as coding. Design docs, post-mortems, experiment write-ups, and internal papers are how technical decisions get made. Here is how to write effectively in a lab context.
| Document Type | Length | Audience | Structure |
|---|---|---|---|
| Design Doc | 3-8 pages | Tech lead + peer engineers | Problem statement, proposed solution, alternatives considered, risks, implementation plan, success metrics |
| Experiment Report | 1-3 pages | Immediate team | Hypothesis, setup, results (tables + plots), interpretation, next steps. Keep it factual — save opinions for the discussion section. |
| Post-Mortem | 2-4 pages | Full engineering org | Timeline of the incident, root cause analysis, contributing factors, what went well, action items with owners and deadlines |
| Technical Blog Post | 5-15 pages | External (research community) | Motivation, approach, results, ablations, limitations, future work. Must be reproducible from the description alone. |
Staff-level engineers carry a library of mental models that they apply across different problems. Here are the ones that appear most frequently in frontier lab work:
| Mental Model | Core Insight | Applied To |
|---|---|---|
| Roofline Analysis | Every operation is either compute-bound or memory-bound. Know which one before optimizing. | Kernel optimization, hardware selection, batching strategy |
| Amdahl's Law | The overall gain from speeding up one part of a system is capped by the fraction of time spent in that part. If matmul is 40% of step time, a 2x matmul speedup cuts step time by only 20%, and even an infinite speedup cannot save more than 40%. | Profiling, optimization prioritization |
| The Bitter Lesson | Methods that scale with compute beat methods that encode human knowledge. (Rich Sutton, 2019) | Architecture decisions, research direction |
| Goodhart's Law | When a measure becomes a target, it ceases to be a good measure. Applies directly to reward model training in RLHF. | Evaluation design, reward hacking prevention |
| The Swiss Cheese Model | Safety comes from multiple independent layers, each with holes. No single layer is perfect, but stacking them reduces risk exponentially. | AI safety, deployment guardrails, testing strategy |
| Diminishing Returns | The first optimization gives 2x. The next gives 1.3x. The next gives 1.1x. Know when to stop optimizing and start shipping. | Performance engineering, hyperparameter tuning |
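The Amdahl's Law row is a one-line formula that is worth checking numerically before committing to an optimization. Here is a minimal sketch using the 40% matmul example from the table. The function name is hypothetical.

```python
def amdahl_speedup(fraction, local_speedup):
    """Overall speedup when `fraction` of the runtime is made `local_speedup` times faster."""
    return 1.0 / ((1.0 - fraction) + fraction / local_speedup)

matmul_fraction = 0.4  # matmul is 40% of step time, as in the table

overall = amdahl_speedup(matmul_fraction, 2.0)
time_saved = 1.0 - 1.0 / overall
ceiling = 1.0 / (1.0 - matmul_fraction)  # limit as local_speedup goes to infinity

print(f"2x faster matmul: {overall:.2f}x overall, {time_saved:.0%} of step time saved")
print(f"best possible from matmul alone: {ceiling:.2f}x overall")
```

The ceiling is the number to quote when prioritizing. If the best case for one component is about 1.67x, the remaining 60% of the step deserves the next profile.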
Interview preparation is a skill, and skills require practice. Here is a 4-week routine that builds all five dimensions:
| Week | Focus | Daily Practice (2 hours) |
|---|---|---|
| 1 | Concept + Code | Morning: read one paper from the list, take handwritten notes. Evening: implement the key algorithm in JAX. Ship to GitHub. |
| 2 | Design + Debug | Morning: pick a system design question, write a 1-page design doc (whiteboard format). Evening: reproduce a training failure and debug it (introduce NaN intentionally, fix it). |
| 3 | Capstone | Full days on the capstone project from chapter 15. Train, measure, analyze, document. |
| 4 | Frontier + Mock | Morning: read one paper from last 3 months, write a 3-paragraph summary. Evening: mock interview with a friend or self-recorded. Review and iterate. |
Beyond papers and this lesson, here are the highest-value resources for frontier lab preparation, ordered by signal-to-noise ratio.
| Resource | Format | Best For | Time Investment |
|---|---|---|---|
| Andrej Karpathy's "Neural Networks: Zero to Hero" | YouTube series | Building transformers from scratch. His nanoGPT implementation is a masterclass in clarity. | 20 hours |
| JAX documentation + tutorials | Online docs | Understanding jit, vmap, pmap, grad. The "Thinking in JAX" guide is essential. | 10 hours |
| Triton tutorials (OpenAI) | Online docs + notebooks | Writing GPU kernels. The matrix multiply tutorial alone is worth the time. | 15 hours |
| Stanford CS231n (Computer Vision) | Lectures + assignments | CNN architectures, training dynamics, practical optimization. Still one of the best courses for deep learning fundamentals. | 40 hours |
| The Illustrated Transformer (Jay Alammar) | Blog post | Visual intuition for attention. Good supplement to implementation-focused learning. | 2 hours |
| vLLM source code | GitHub repo | Production LLM serving: PagedAttention, continuous batching, speculative decoding. Read the core engine. | 20 hours |
| Chinchilla paper (Hoffmann et al.) | arXiv paper | THE scaling laws paper. Every frontier lab uses these results for training decisions. | 5 hours |
The most effective portfolio is not a collection of finished projects — it is a public learning log. Here is exactly what to build, in what order, with the expected time investment.
Frontier lab compensation is highly competitive. Here is what to expect and how to navigate the offer stage.
| Level | Typical Title | Typical TC (US, 2025) | What They Want |
|---|---|---|---|
| L4 / IC3 | ML Engineer | $300-450K | Strong implementation skills. Can own a component end-to-end. 2-5 years experience. |
| L5 / IC4 | Senior ML Engineer | $450-650K | Technical leadership on a project. Mentors junior engineers. 4-8 years experience. |
| L6 / IC5 | Staff ML Engineer | $600-1M+ | Owns a system and its evolution. Influences team direction. Recognized expert in a domain. 7+ years. |
| L7 / IC6 | Principal / Distinguished | $800K-2M+ | Shapes the technical strategy. Known externally. Papers, talks, or major systems. |
Total compensation (TC) at frontier labs is heavily weighted toward equity (RSUs or options). Base salary is typically $200-300K regardless of level. The rest is equity, which can be highly variable for pre-IPO companies (Anthropic, xAI) vs public companies (Google, Meta).
You got the offer. Now what? Here is what the first 90 days typically look like for a new ML research engineer at a frontier lab.
| Period | Focus | What Success Looks Like |
|---|---|---|
| Days 1-14 | Onboarding + codebase orientation | You can submit a PR that passes CI. You understand the training pipeline: where data comes in, how it flows through the model, where checkpoints are saved. You know the team's Slack channels and meeting cadence. |
| Days 15-30 | First meaningful contribution | You ship a small optimization (fuse two ops, fix a memory leak, improve a data pipeline) that measurably improves training throughput or eval accuracy. It does not need to be groundbreaking — it needs to be correct, well-tested, and shipped. |
| Days 31-60 | Own a component | You are the go-to person for one subsystem (quantization pipeline, evaluation framework, data preprocessing, or a specific model component). You have opinions about its design and a roadmap for improvement. |
| Days 61-90 | Drive a project | You propose and lead a project that spans multiple weeks. Maybe it is implementing MoE routing for the next model, or building a new eval benchmark, or writing a Pallas kernel for the training bottleneck you found in your first month. This is where you transition from "new hire" to "team member." |
The single most important thing in your first 90 days: ship early and often. A small PR on day 7 is worth more than a perfect PR on day 45. Frontier labs move fast. Show that you can too.