Workbook — JAX Scaling Book Companion

Scaling Book Workbook

Every exercise from the JAX Scaling Book, solvable in-browser. Pencil-and-paper math meets instant feedback.

Prerequisites: Basic linear algebra + Python familiarity. That's it.
12 Chapters · 60+ Exercises · 0 Hand-Waving

Chapter 0: Roofline Exercises

You are an ML engineer staring at a GPU profiler trace. A matrix multiply that should take 0.14 ms is taking 0.8 ms. Is the bottleneck memory bandwidth? Compute throughput? Launch overhead? Without a mental model, you are guessing. The roofline model gives you that mental model.

Before you can optimize anything, you need to know what limits performance. Is your operation waiting for data to arrive from memory, or waiting for the arithmetic units to finish crunching? The roofline model answers this question with a single number: arithmetic intensity.

Arithmetic intensity (AI) is the ratio of floating-point operations to bytes moved from memory. If you do many FLOPs per byte loaded, you are compute-bound. If you do few FLOPs per byte, you are memory-bound. The dividing line is the hardware's ridge point — the arithmetic intensity where the processor's compute throughput exactly matches its memory bandwidth.

The roofline in one sentence. If your operation's arithmetic intensity is below the ridge point, adding more compute won't help — you're starved for data. If it's above, adding more memory bandwidth won't help — the ALUs are the bottleneck.

Key Formulas

Arithmetic Intensity = FLOPs / Bytes Transferred
Ridge Point = Peak FLOP/s / Peak Bandwidth (bytes/s)
Attainable FLOP/s = min(Peak FLOP/s, AI × Peak Bandwidth)

For a matrix multiplication C = A × B where A is [M, K] and B is [K, N]:

FLOPs = 2 × M × K × N // multiply + add per output element
Bytes = (M×K + K×N + M×N) × bytes_per_element
AI = 2MKN / ((MK + KN + MN) × bytes_per_element)
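
The formulas above translate directly into a few lines of Python. This is a minimal sketch, not a profiler: the function names are invented for this workbook, the accelerator numbers are hypothetical round figures, and each operand is assumed to be read or written exactly once.

```python
def matmul_arithmetic_intensity(m, k, n, bytes_per_element=2):
    """FLOPs per byte for C[m, n] = A[m, k] @ B[k, n], each matrix touched once."""
    flops = 2 * m * k * n                      # one multiply + one add per MAC
    bytes_moved = (m * k + k * n + m * n) * bytes_per_element
    return flops / bytes_moved


def ridge_point(peak_flops, peak_bandwidth):
    """Arithmetic intensity at which the compute and memory limits meet."""
    return peak_flops / peak_bandwidth


def attainable_flops(ai, peak_flops, peak_bandwidth):
    """Roofline: the lower of the compute roof and the bandwidth slope."""
    return min(peak_flops, ai * peak_bandwidth)


# Hypothetical accelerator, chosen only to make the numbers easy to read.
PEAK_FLOPS = 100e12      # 100 TFLOP/s
PEAK_BW = 1e12           # 1 TB/s

ai = matmul_arithmetic_intensity(512, 512, 512)
print(f"AI = {ai:.1f} FLOPs/byte, ridge = {ridge_point(PEAK_FLOPS, PEAK_BW):.0f}")
print(f"attainable = {attainable_flops(ai, PEAK_FLOPS, PEAK_BW) / 1e12:.1f} TFLOP/s")
```

Swapping in the specs of a real chip and the shapes of a real operation gives a first-order answer to "compute-bound or memory-bound?" before you open a profiler.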
Exercise 0.1: Matmul Arithmetic Intensity

Compute the arithmetic intensity for a square matmul: [4096, 4096] × [4096, 4096] in BF16 (2 bytes per element).

Hint: FLOPs = 2 × 4096^3. Bytes = 3 × 4096^2 × 2.

FLOPs/byte
Show derivation
FLOPs = 2 × 4096 × 4096 × 4096 = 137,438,953,472 ≈ 1.37 × 10^11
Bytes = (4096^2 + 4096^2 + 4096^2) × 2 = 3 × 16,777,216 × 2 = 100,663,296
AI = 137,438,953,472 / 100,663,296 ≈ 1365.3 FLOPs/byte

For square matmuls, AI simplifies to 2N/6 = N/3 when element size = 2 bytes. So 4096/3 ≈ 1365. This is deeply compute-bound on any modern accelerator.

Exercise 0.2: H100 Ridge Point

The H100 SXM has ~990 TFLOP/s BF16 Tensor Core throughput and ~3.35 TB/s HBM bandwidth. What is the ridge point?

FLOPs/byte
Show derivation
Ridge Point = 990 × 10^12 / (3.35 × 10^12) ≈ 295.5 FLOPs/byte

Any operation with AI above ~296 is compute-bound on the H100. Our matmul from Exercise 0.1 (AI ≈ 1365) is well above this, so it is solidly compute-bound. Element-wise ops like FP32 ReLU (1 FLOP per 8 bytes read and written, AI = 0.125) are deeply memory-bound.

Exercise 0.3: Matmul Throughput Prediction

Given the [4096, 4096] × [4096, 4096] BF16 matmul is compute-bound on the H100, predict the theoretical execution time assuming 990 TFLOP/s peak.

ms
Show derivation
Time = FLOPs / Peak FLOP/s = 1.374 × 10^11 / (990 × 10^12) = 0.0001388 s ≈ 0.139 ms

In practice, you will see ~0.15-0.2 ms due to kernel launch overhead and imperfect occupancy. The ratio of achieved FLOP/s to peak FLOP/s (equivalently, theoretical time divided by measured time) is your MFU (Model FLOPs Utilization). Getting above 50% MFU on real workloads is considered good.

Exercise 0.4: Is Self-Attention Memory-Bound?

Consider the QK^T computation in self-attention: Q and K are both [batch, heads, seq, d_head]. With batch=1, heads=32, seq=2048, d_head=128, BF16. What is the arithmetic intensity of the QK^T matmul (treating each head independently)?

Hint: Per head, this is a [2048, 128] × [128, 2048] matmul. Compute AI for one head.

FLOPs/byte
Show derivation

Per head: M=2048, K=128, N=2048, element_size=2 bytes.

FLOPs = 2 × 2048 × 128 × 2048 = 1,073,741,824
Bytes = (MK + KN + MN) × 2 = (2048×128 + 128×2048 + 2048×2048) × 2
= (262,144 + 262,144 + 4,194,304) × 2 = 4,718,592 × 2 = 9,437,184
AI = 1,073,741,824 / 9,437,184 ≈ 113.8

On its own, the QK^T matmul at AI ≈ 114 already sits below the H100 ridge point of ~296. Full attention is worse, because we must also account for the softmax and the AV matmul: the standard implementation writes the full S = QK^T matrix of shape [2048, 2048] to HBM and reads it back. With this materialization, total bytes grow much larger, dropping the effective AI of the whole attention operation to roughly 62. This is why FlashAttention exists: it avoids materializing S in HBM by computing attention in tiles, fusing the QK^T, softmax, and AV multiply into a single kernel. On an H100, standard attention at AI ~62 is memory-bound, while FlashAttention raises the effective AI well above the ridge point.
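
The effect of materializing the score matrix can be estimated with a small traffic model. The sketch below is an assumption-laden estimate, not a measurement: it ignores softmax FLOPs, and it assumes the unfused path writes S, reads S for the softmax, writes the probabilities P, and reads P again for the AV matmul. The function name `attention_ai` is invented here.

```python
def attention_ai(seq, d_head, bytes_per_element=2, materialize_scores=True):
    """Rough FLOPs/byte for one attention head (softmax FLOPs ignored)."""
    flops = 2 * (2 * seq * d_head * seq)       # QK^T plus AV
    qkvo = 4 * seq * d_head                    # read Q, K, V; write O
    # Assumed traffic: S written, S read by softmax, P written, P read by AV.
    scores = 4 * seq * seq if materialize_scores else 0
    return flops / ((qkvo + scores) * bytes_per_element)


standard = attention_ai(2048, 128, materialize_scores=True)
fused = attention_ai(2048, 128, materialize_scores=False)
print(f"standard: {standard:.0f} FLOPs/byte, fused: {fused:.0f} FLOPs/byte")
```

Under these assumptions the unfused path lands near 60 FLOPs/byte, the same ballpark as the rough figure quoted above, while the fused path reaches about 1,000 FLOPs/byte. The exact values depend on which reads and writes you count.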

Exercise 0.5: Element-wise Op Arithmetic Intensity

ReLU activation on a tensor of 10M elements in FP32 (4 bytes). ReLU does one comparison per element (1 FLOP). What is the arithmetic intensity?

FLOPs/byte
Show derivation
FLOPs = 10,000,000 (one comparison per element)
Bytes = 10,000,000 × 4 (read) + 10,000,000 × 4 (write) = 80,000,000
AI = 10,000,000 / 80,000,000 = 0.125 FLOPs/byte

This is 2,364× below the H100's ridge point of 296. ReLU is absurdly memory-bound. This is why activation functions are never standalone kernels in optimized code — they are always fused into the preceding or following matmul.

Exercise 0.6: Tall-and-Skinny Matmul

During autoregressive decode, the main operation per layer is a matrix-vector product: [1, d] × [d, d] where d=8192. BF16. What is the arithmetic intensity? Is it compute-bound or memory-bound on H100?

FLOPs/byte
Show derivation
FLOPs = 2 × 1 × 8192 × 8192 = 134,217,728
Bytes = (1 × 8192 + 8192 × 8192 + 1 × 8192) × 2
= (8192 + 67,108,864 + 8192) × 2 ≈ 134,250,496
AI = 134,217,728 / 134,250,496 ≈ 1.0 FLOPs/byte

AI ≈ 1 — deeply memory-bound (ridge point is 296). The weight matrix dominates the byte count. This is why decode is memory-bandwidth-limited: each token requires reading the entire weight matrix for a trivial number of FLOPs. Batching helps by amortizing the weight read across B tokens, raising AI to ~B.
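
The batching claim is easy to check numerically. A minimal sketch, assuming the same touch-each-operand-once byte count as the derivation above; `decode_matmul_ai` is a name made up for this example.

```python
def decode_matmul_ai(batch, d, bytes_per_element=2):
    """FLOPs/byte for [batch, d] @ [d, d] with every operand touched once."""
    flops = 2 * batch * d * d
    bytes_moved = (batch * d + d * d + batch * d) * bytes_per_element
    return flops / bytes_moved


H100_RIDGE = 296  # FLOPs/byte, from Exercise 0.2

for batch in (1, 8, 64, 256, 512):
    ai = decode_matmul_ai(batch, 8192)
    regime = "compute-bound" if ai > H100_RIDGE else "memory-bound"
    print(f"batch={batch:4d}  AI={ai:7.1f}  {regime}")
```

The intensity works out to B·d / (2B + d), which stays close to B while B is much smaller than d. In this simplified model the crossover on an H100 falls somewhere between batch 256 and batch 512.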

The roofline lesson. Large square matmuls are always compute-bound. Attention (without FlashAttention) is memory-bound because the intermediate S matrix forces extra HBM traffic. Element-wise ops are always memory-bound. Decode (matrix-vector) is memory-bound until batch size reaches the ridge point. The roofline tells you where to focus your optimization effort.
An operation has an arithmetic intensity of 50 FLOPs/byte and runs on an H100 (ridge point ~296). Is it compute-bound or memory-bound?

Chapter 1: TPU Architecture

Google's Tensor Processing Units (TPUs) are custom ASICs designed for one thing: dense matrix multiplication at massive scale. Unlike GPUs, which evolved from graphics workloads and carry a lot of general-purpose flexibility, TPUs are laser-focused on the matmul-softmax-matmul pipeline that dominates transformer training and inference.

The heart of a TPU is the MXU (Matrix Multiply Unit), a systolic array that streams operands through a fixed grid of multiply-accumulate cells. A TPU v5e has four 128×128 BF16 MXUs running at ~1.5 GHz. Each cycle, one MXU performs 128×128 = 16,384 multiply-accumulates, and each MAC is 2 FLOPs.

Systolic array intuition. Picture a 128×128 grid of tiny calculators. Data flows in from the left (A matrix rows) and from the top (B matrix columns). Each calculator multiplies its two inputs, adds to its running sum, and passes data to its neighbor. After the array fills, results pour out from the bottom. No random access, no caches, just a river of numbers.

Exercise 1.1: MXU Peak Throughput

A TPU v5e MXU is a 128×128 BF16 systolic array running at 1.5 GHz. How many TFLOP/s of BF16 matmul throughput does one MXU produce?

Hint: Each cycle produces 128 × 128 MAC operations. Each MAC = 2 FLOPs (multiply + add).

TFLOP/s
Show derivation
FLOPs/cycle = 128 × 128 × 2 = 32,768
FLOPs/s = 32,768 × 1.5 × 10^9 = 49.15 × 10^12
≈ 49 TFLOP/s
Four MXUs give 4 × 49.15 ≈ 197 TFLOP/s, the chip's quoted BF16 peak.

For comparison, an H100 achieves ~990 TFLOP/s BF16, but that figure is for the whole chip: hundreds of 4th-gen Tensor Cores, each operating on much smaller tiles. The like-for-like number for a TPU v5e chip is ~197 TFLOP/s (all MXUs combined). The TPU v5e compensates by being cheaper and packing more chips per pod.

Exercise 1.2: TPU v5e Ridge Point

TPU v5e specs: ~197 TFLOP/s peak BF16 for the whole chip (all MXUs combined) and ~819 GB/s HBM bandwidth. What is the ridge point?

FLOPs/byte
Show derivation
Ridge = 197 × 10^12 / (819 × 10^9) ≈ 240.5 FLOPs/byte

The TPU v5e has a lower ridge point than the H100 (~296), meaning operations become compute-bound at lower arithmetic intensity. This is because the TPU's ratio of compute-to-bandwidth is more balanced — relatively more bandwidth per FLOP.

Exercise 1.3: Memory Bandwidth Utilization

You run a layer normalization on a tensor of shape [2048, 8192] in BF16 on a TPU v5e (819 GB/s bandwidth). LayerNorm reads the input once, writes the output once, plus reads/writes for mean and variance (negligible for large tensors). Approximate the minimum execution time.

Hint: LayerNorm is memory-bound. Total bytes ≈ 2 × tensor size (one read plus one write).

ms
Show derivation
Tensor bytes = 2048 × 8192 × 2 = 33,554,432 bytes = 33.55 MB
Total transfer = 2 × 33.55 MB = 67.1 MB (read + write)
Time = 67.1 × 10^6 / (819 × 10^9) = 0.0000819 s ≈ 0.082 ms

In practice it will be longer due to kernel launch overhead and the fact that LayerNorm also computes a reduction (mean, variance), but the memory transfer dominates. This is why fusing LayerNorm with the preceding or following operation is one of the most impactful optimizations — you eliminate one full round-trip to HBM.
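
The same bandwidth-only estimate applies to any memory-bound op. A minimal sketch, assuming HBM traffic is the only cost and the tensor is streamed a whole number of times; `memory_bound_time_ms` and its `passes` parameter are names invented for this example.

```python
def memory_bound_time_ms(tensor_elements, bytes_per_element, bandwidth_bytes_per_s,
                         passes=2):
    """Lower-bound runtime in ms when HBM traffic is the only cost.

    passes=2 means the tensor is read once and written once.
    """
    total_bytes = tensor_elements * bytes_per_element * passes
    return total_bytes / bandwidth_bytes_per_s * 1e3


TPU_V5E_BW = 819e9  # bytes/s, from the exercise statement

layernorm_ms = memory_bound_time_ms(2048 * 8192, 2, TPU_V5E_BW)
print(f"LayerNorm on [2048, 8192] BF16: at least {layernorm_ms:.3f} ms")
```

Treat the result as a floor: real kernels add launch overhead and rarely reach peak bandwidth.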

Exercise 1.4: TPU v5e vs H100 Comparison

Fill in the comparison table from memory, then check your answers.

Spec              | TPU v5e         | H100 SXM
BF16 peak TFLOP/s | ~197            | ~990
HBM bandwidth     | ~819 GB/s       | ~3,350 GB/s
HBM capacity      | 16 GB           | 80 GB
Ridge point       | ~240            | ~296
Interconnect      | ICI (4800 Gbps) | NVLink (900 GB/s)

At what model size (in BF16 parameters) does a single TPU v5e run out of memory?

B parameters
Show derivation
Max params in BF16 = 16 GB / 2 bytes = 8 × 10^9 = 8B parameters

A TPU v5e can hold at most an 8B model in weights alone (no KV cache, no activations). In practice, inference of even a 7B model requires careful memory management. This is why TPU pods aggregate many chips — a v5e pod with 256 chips has 4 TB of total HBM.

TPU vs GPU mental model. GPUs: many small cores, massive parallelism, flexible. TPUs: fewer but larger systolic arrays, less flexible, cheaper per FLOP for matmul-heavy workloads. The ridge point comparison (TPU v5e ~240 vs H100 ~296) tells you TPUs have relatively better bandwidth per compute unit, making them friendlier to slightly memory-bound workloads.
A TPU's MXU is a 128x128 systolic array. If you need to multiply a [64, 128] matrix by a [128, 64] matrix, what happens to the utilization of the MXU?

Chapter 2: Sharded Matrix Multiply

A single accelerator can only hold so much data in HBM. A 70B parameter model in BF16 needs 140 GB just for weights, well beyond the 80 GB of an H100. The solution: shard the matrices across multiple devices and coordinate the computation.

There are three fundamental ways to shard a matmul C = A × B where A is [M, K] and B is [K, N]:

Strategy        | Shard A along | Shard B along | Communication needed
Row-parallel    | M (rows)      | Replicated    | None (outputs are independent)
Column-parallel | Replicated    | N (columns)   | None (outputs are independent)
K-sharded       | K (inner dim) | K (inner dim) | All-reduce on partial sums

The tradeoff. Row-parallel and column-parallel (as defined in this table) avoid communication but require replicating one full matrix. Sharding along K reduces memory per device for both matrices but requires an all-reduce to sum partial products. In practice, transformer layers use a clever combination: the first linear shards its weight along N (column-parallel), and the second shards its weight along K. Megatron-LM calls that second layout "row-parallel" because it splits the rows of the weight matrix. The first layer's sharded output feeds the second layer directly, so the only communication is a single all-reduce after the second linear.

Exercise 2.1: Sharding Strategy

You have a weight matrix W of shape [8192, 32768] and input X of shape [2048, 8192]. You want to compute Y = X × W across 8 devices with minimum communication. Which sharding minimizes communication?

Show reasoning

Column-parallel: replicate X (2048×8192×2 bytes = 32 MB per device), shard W into 8 chunks of [8192, 4096] (64 MB each, counting 1 MB as 2^20 bytes as the rest of this exercise does). Each device computes Y_chunk = X × W_chunk with no communication. Output Y is already distributed column-wise.

Row-parallel: shard X into rows (each device gets [256, 8192]), replicate W (8192×32768×2 = 512 MB per device!). Wastes memory replicating the huge weight matrix.

K-sharding: requires an all-reduce. Bad for minimum communication.

Column-parallel wins: zero communication, and the smaller matrix (X) is the one replicated.

Exercise 2.2: All-Reduce Volume

You shard a [8192, 8192] × [8192, 8192] BF16 matmul along K across 8 devices. Each device computes a partial [8192, 8192] result that must be summed via all-reduce. Using the ring all-reduce algorithm, how many bytes does each device send in total?

Hint: Ring all-reduce sends 2 × (P-1)/P × tensor_size bytes per device, where P = number of devices.

MB
Show derivation
Tensor size = 8192 × 8192 × 2 bytes = 134,217,728 bytes = 128 MB
All-reduce volume per device = 2 × (8-1)/8 × 128 MB
= 2 × 0.875 × 128 = 224 MB

The ring all-reduce has two phases: reduce-scatter (each device sends (P-1)/P of the data) and all-gather (same volume again). Total = 2 × (P-1)/P × data_size. As P grows, this approaches 2× the tensor size — the lower bound for any all-reduce algorithm.
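
The formula is short enough to tabulate. A minimal sketch; the function name is invented here, and sizes are printed in MiB (2^20 bytes), which is the unit behind the "128 MB" figure in the derivation above.

```python
def ring_all_reduce_bytes_per_device(tensor_bytes, num_devices):
    """Bytes each device sends: reduce-scatter phase plus all-gather phase."""
    per_phase = (num_devices - 1) / num_devices * tensor_bytes
    return 2 * per_phase


tensor_bytes = 8192 * 8192 * 2  # BF16 partial result from Exercise 2.2
for p in (2, 8, 64, 1024):
    sent = ring_all_reduce_bytes_per_device(tensor_bytes, p)
    print(f"P={p:5d}: {sent / 2**20:7.1f} MiB sent ({sent / tensor_bytes:.3f}x tensor size)")
```

The per-device volume barely changes past a handful of devices, so adding devices to a ring does not reduce how much each one has to send.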

Exercise 2.3: Communication-to-Computation Ratio

Using the same K-sharded matmul from 2.2 on 8 devices, each device computes a [8192, 1024] × [1024, 8192] matmul. The all-reduce transfers 224 MB per device. If each device runs the matmul in 0.139 ms (compute-bound at ~990 TFLOP/s) and the interconnect has 450 GB/s bidirectional bandwidth, what is the communication time? Is the overall operation communication-bound?

ms (comm time)
Show derivation
Compute time = 2 × 8192 × 1024 × 8192 / (990 × 10^12) = 1.374 × 10^11 / (990 × 10^12) ≈ 0.139 ms
Comm time = 224 MB / (450 GB/s) = 224 × 10^6 / (450 × 10^9) ≈ 0.498 ms
Ratio = 0.498 / 0.139 ≈ 3.6×

Communication takes about 3.6× longer than computation, so the operation is communication-bound unless the all-reduce can be overlapped with other work. This is why K-sharding is avoided when possible: the all-reduce dominates. Column-parallel or row-parallel sharding avoids this entirely for individual layers, at the cost of replicating one matrix.

Exercise 2.4: FSDP All-Gather vs TP All-Reduce

Consider one transformer layer of LLaMA 3 70B. The layer has ~856M parameters in BF16 = 1.71 GB of weights. Compare two strategies on 8 devices:

(A) FSDP: all-gather the full layer weights before forward. Volume = (P-1)/P × layer_size per device.
(B) TP: two all-reduces per layer (attention + MLP) on output tensors of shape [batch×seq, d]. With batch=4, seq=4096, d=8192, BF16.

Which transfers fewer bytes per device?

Show derivation

FSDP all-gather:

Volume = (8-1)/8 × 1.71 GB = 0.875 × 1.71 = 1.50 GB per device

TP all-reduce (per operation):

Tensor size = 4 × 4096 × 8192 × 2 = 268 MB
Per all-reduce = 2 × (8-1)/8 × 268 MB = 469 MB
Two all-reduces = 2 × 469 = 938 MB = 0.94 GB

FSDP transfers 1.50 GB, TP transfers 0.94 GB. TP wins on volume for this config. But TP's all-reduces sit on the critical path of every layer, so they need fast NVLink within a node, while FSDP's all-gathers can be prefetched and overlapped with compute, so they tolerate slower inter-node links. This is why TP is used within nodes and FSDP across nodes.

In a standard Transformer MLP (two linear layers: up-projection then down-projection), Megatron-LM uses column-parallel for the first linear and row-parallel for the second. How many all-reduce operations are needed per MLP block?

Chapter 3: Transformer Parameter Counting

Before you can estimate training cost, inference memory, or sharding strategy, you need to count parameters precisely. No approximations, no hand-waving. Every weight matrix, every bias, every embedding. This is the foundation of all scaling analysis.

The Standard Transformer Block

One transformer block contains a self-attention layer and an MLP (feed-forward) layer, each with layer norms. Here is the parameter breakdown:

Self-Attention:
Q projection: d × d // = d × (n_heads × d_head)
K projection: d × d_kv // d_kv = n_kv_heads × d_head for GQA
V projection: d × d_kv
O projection: d × d

MLP (SwiGLU):
Gate projection: d × d_ff
Up projection: d × d_ff
Down projection: d_ff × d

Layer Norms:
2 × d // weight only, no bias in modern transformers

Embeddings:
Token embedding: vocab × d
(Positional: usually RoPE = 0 params)
Exercise 3.1: GPT-2 Parameter Count

GPT-2 (small): 12 layers, d=768, 12 heads, d_head=64, d_ff=3072 (4×d), vocab=50257. Standard MHA (not GQA), standard MLP (not SwiGLU, so 2 projections not 3), with biases on all linear layers, and separate LN weight+bias.

Count total parameters. Include: 4 attention projections + biases, 2 MLP projections + biases, 2 LayerNorms (weight+bias each), token embedding, position embedding (1024×768), and final LN.

parameters
Show derivation

Per transformer block:

Attention: 4 × (768 × 768 + 768) = 4 × 590,592 = 2,362,368
MLP: (768 × 3072 + 3072) + (3072 × 768 + 768) = 2,362,368 + 2,360,064 = 4,722,432
LayerNorms: 2 × (768 + 768) = 3,072
Total per block = 2,362,368 + 4,722,432 + 3,072 = 7,087,872

Global:

12 blocks = 12 × 7,087,872 = 85,054,464
Token embedding = 50,257 × 768 = 38,597,376
Position embedding = 1,024 × 768 = 786,432
Final LN = 768 + 768 = 1,536
Total = 85,054,464 + 38,597,376 + 786,432 + 1,536 = 124,439,808 ≈ 124M

OpenAI reported 124M. Note that GPT-2 ties the output head with the token embedding, so no extra parameters there.
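
The same count can be scripted, which makes it easy to try other GPT-2 sizes. A minimal sketch that mirrors the derivation above term by term; `gpt2_param_count` is a name invented for this example, not a library function.

```python
def gpt2_param_count(n_layers=12, d=768, d_ff=3072, vocab=50257, n_positions=1024):
    """Parameter count for a GPT-2 style model with biases and a tied output head."""
    attention = 4 * (d * d + d)                    # Q, K, V, O weights + biases
    mlp = (d * d_ff + d_ff) + (d_ff * d + d)       # up and down projections
    layer_norms = 2 * (d + d)                      # weight + bias, two per block
    per_block = attention + mlp + layer_norms
    embeddings = vocab * d + n_positions * d       # token + learned positions
    final_ln = d + d
    return n_layers * per_block + embeddings + final_ln


print(f"{gpt2_param_count():,}")
```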

Exercise 3.2: LLaMA 3 70B Parameter Count

LLaMA 3 70B: 80 layers, d=8192, 64 attention heads, 8 KV heads (GQA), d_head=128, d_ff=28672 (SwiGLU, so 3 projections), vocab=128256. No biases anywhere. RoPE (0 params). RMSNorm (weight only, no bias).

Count total parameters in billions.

B params
Show derivation

Per block attention:

Q: 8192 × 8192 = 67,108,864 (64 heads × 128)
K: 8192 × 1024 = 8,388,608 (8 KV heads × 128)
V: 8192 × 1024 = 8,388,608
O: 8192 × 8192 = 67,108,864
Attention total = 150,994,944

Per block MLP (SwiGLU):

Gate: 8192 × 28672 = 234,881,024
Up: 8192 × 28672 = 234,881,024
Down: 28672 × 8192 = 234,881,024
MLP total = 704,643,072

Per block norms:

2 × 8192 = 16,384 (RMSNorm, weight only)

Total per block:

150,994,944 + 704,643,072 + 16,384 = 855,654,400

Global:

80 blocks = 80 × 855,654,400 = 68,452,352,000
Token embedding = 128,256 × 8,192 = 1,050,673,152
Output head = 8,192 × 128,256 = 1,050,673,152 (untied)
Final RMSNorm = 8,192
Total = 68,452,352,000 + 1,050,673,152 + 1,050,673,152 + 8,192
= 70,553,706,496 ≈ 70.6B
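
A LLaMA-style counter follows the same pattern. This is a sketch under the assumptions stated in the exercise (GQA, SwiGLU, RMSNorm, no biases, RoPE); the function and its `tied_embeddings` flag are hypothetical names used only here.

```python
def llama_param_count(n_layers, d, n_heads, n_kv_heads, d_head, d_ff, vocab,
                      tied_embeddings=False):
    """Parameter count for a LLaMA-style decoder (GQA, SwiGLU, RMSNorm, no biases)."""
    d_q = n_heads * d_head
    d_kv = n_kv_heads * d_head
    attention = d * d_q + 2 * d * d_kv + d_q * d   # Q, K, V, O
    mlp = 3 * d * d_ff                             # gate, up, down
    norms = 2 * d                                  # two RMSNorms per block
    per_block = attention + mlp + norms
    embeddings = vocab * d * (1 if tied_embeddings else 2)
    return n_layers * per_block + embeddings + d   # + final RMSNorm


total = llama_param_count(n_layers=80, d=8192, n_heads=64, n_kv_heads=8,
                          d_head=128, d_ff=28672, vocab=128256)
print(f"{total:,} parameters = {total / 1e9:.2f}B")
```

Changing `n_kv_heads` from 8 to 64 shows how little GQA changes the total: the K and V projections are a small slice of each block next to the MLP.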
Exercise 3.3: Forward Pass FLOPs Per Token

The standard approximation for forward pass FLOPs per token is 2N (where N = number of parameters), because each parameter participates in one multiply-add (2 FLOPs). For LLaMA 3 70B (N ≈ 70.6B), how many FLOPs per token in the forward pass?

× 10^11 FLOPs
Show derivation
Forward FLOPs/token ≈ 2N = 2 × 70.6 × 10^9 = 1.412 × 10^11

This approximation counts the weight matmul FLOPs only. Attention QK^T and softmax add roughly 2 × n_layers × seq_len × d FLOPs per token, which for seq=8192 is ~10^10, about 7-8% of the 2N term. For back-of-envelope calculations, 2N is sufficient.

Exercise 3.4: Training FLOPs (6ND)

The standard training FLOP estimate is 6ND: forward (2N) + backward (4N, because gradients require ~2× the forward FLOPs), times D tokens. LLaMA 3 70B was trained on 15T tokens. How many total FLOPs?

× 10^24 FLOPs
Show derivation
Total FLOPs = 6 × N × D = 6 × 70.6 × 10^9 × 15 × 10^12
= 6 × 70.6 × 15 × 10^21 = 6,354 × 10^21 = 6.354 × 10^24

That is about 6.35 yottaFLOPs (1 yotta = 10^24). For context, common estimates of the number of stars in the observable universe fall between 10^22 and 10^24, so this run performed at least a few operations for every star.

Exercise 3.5: KV Cache Sizing

LLaMA 3 70B: 80 layers, GQA with 8 KV heads, d_head=128, batch=32, seq=8192, BF16. How many GB for the KV cache?

Formula: 2 (K and V) × n_layers × n_kv_heads × d_head × seq_len × batch × bytes_per_element

GB
Show derivation
KV cache = 2 × 80 × 8 × 128 × 8192 × 32 × 2 bytes
= 2 × 80 × 8 × 128 × 8192 × 64
= 2 × 80 × 8 × 128 × 524,288
= 2 × 80 × 536,870,912 = 85,899,345,920 bytes
≈ 85.9 GB

GQA helps enormously here. Without it (64 KV heads instead of 8), the cache would be about 687 GB, 8× larger. The model weights alone are 141 GB in BF16, so even with GQA the KV cache at this batch size and context length is more than half the size of the weights.
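
Working in bytes end to end and converting only at the last step makes the unit conversion easy to double-check. A minimal sketch of the formula above; `kv_cache_bytes` is a name invented for this example, and the 64-KV-head case is the hypothetical no-GQA variant.

```python
def kv_cache_bytes(n_layers, n_kv_heads, d_head, seq_len, batch, bytes_per_element=2):
    """Bytes needed to cache K and V for every layer, position, and sequence."""
    return 2 * n_layers * n_kv_heads * d_head * seq_len * batch * bytes_per_element


gqa = kv_cache_bytes(80, 8, 128, 8192, 32)
mha = kv_cache_bytes(80, 64, 128, 8192, 32)   # hypothetical: no GQA, 64 KV heads
print(f"GQA: {gqa / 1e9:.1f} GB, full MHA: {mha / 1e9:.1f} GB")
```

Because the cache is linear in both `batch` and `seq_len`, halving either one halves the footprint.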

Exercise 3.6: MFU Calculation

You are training LLaMA 3 70B on 2048 H100s. You measure 1,000 tokens/second/GPU. Each H100 has 990 TFLOP/s BF16 peak. What is your Model FLOPs Utilization (MFU)?

MFU = (actual model FLOPs/s per GPU) / (peak FLOPs/s per GPU). Model FLOPs per token = 6N (forward + backward).

%
Show derivation
Model FLOPs per token = 6N = 6 × 70.6 × 10^9 = 4.236 × 10^11
Throughput per GPU = 1,000 tokens/s
FLOPs/s per GPU = 1,000 × 4.236 × 10^11 = 4.236 × 10^14 = 423.6 TFLOP/s
MFU = 423.6 × 10^12 / (990 × 10^12) = 0.428 = 42.8%

Sanity check: at 6N FLOPs per token, an H100 at 100% MFU could process at most 990 × 10^12 / (4.236 × 10^11) ≈ 2,337 tokens/s. Any measured training throughput above that signals a bookkeeping error (for example, counting tokens per node rather than per GPU).

Note: MFU conventions vary in the details, mainly in whether attention FLOPs are counted on top of 6N. For inference, the corresponding per-token figure is 2N (forward only). Meta reported ~38-43% MFU for LLaMA 3 training. The key point: anything above 40% on a large distributed run is excellent. The gap to 100% comes from communication overhead, pipeline bubbles, memory-bound operations (layer norm, softmax), and suboptimal occupancy.

GQA (Grouped-Query Attention) reduces the number of KV heads from n_heads to n_kv_heads. Which component does this NOT reduce?

Chapter 4: Training Parallelism

You cannot train a 70B model on one GPU. You need parallelism — splitting the work across many devices. There are four main flavors, and real training runs combine all of them.

Strategy               | What's split                   | Communication                           | Memory saving
Data Parallel (DP)     | Batch                          | All-reduce gradients                    | None (each GPU has full model)
FSDP (ZeRO-3)          | Batch + params/grads/optimizer | All-gather params, reduce-scatter grads | ~P× (P = # GPUs)
Tensor Parallel (TP)   | Individual weight matrices     | All-reduce per layer                    | ~TP_degree×
Pipeline Parallel (PP) | Layers across stages           | Point-to-point activations              | ~PP_degree×

The memory equation. Per-GPU memory during training: model_params + gradients + optimizer_states + activations. With Adam in BF16 mixed precision: params (2 bytes) + grads (2 bytes) + optimizer (master weights 4B + momentum 4B + variance 4B = 12 bytes) = 16 bytes per parameter, plus activation memory. FSDP shards all of these across GPUs.

Exercise 4.1: FSDP Memory Per GPU

70B model, Adam optimizer, BF16 training. Without FSDP, per-GPU memory for params+grads+optimizer = 16 bytes × N. With FSDP across 64 GPUs, how much memory per GPU for these components?

GB
Show derivation
Without FSDP: 16 × 70 × 10^9 = 1,120 GB (impossible on one GPU!)
With FSDP on 64 GPUs: 1,120 / 64 = 17.5 GB per GPU

This is just model state. Activation memory is additional and depends on sequence length, batch size, and whether you use activation checkpointing. An H100 has 80 GB HBM, so 17.5 GB for model state leaves ~62 GB for activations — comfortable.
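
The 16-bytes-per-parameter budget can be written out explicitly, which makes it easy to see how the per-GPU figure moves with the FSDP degree. A minimal sketch: the dictionary keys and function name are invented here, and activations are deliberately left out.

```python
BYTES_PER_PARAM = {
    "bf16_params": 2,
    "bf16_grads": 2,
    "fp32_master_weights": 4,
    "fp32_adam_momentum": 4,
    "fp32_adam_variance": 4,
}


def model_state_gb(n_params, fsdp_degree=1):
    """Per-GPU params + grads + optimizer state, in GB (10^9 bytes)."""
    total_bytes = n_params * sum(BYTES_PER_PARAM.values())
    return total_bytes / fsdp_degree / 1e9


for gpus in (1, 8, 64, 512):
    print(f"FSDP over {gpus:3d} GPUs: {model_state_gb(70e9, gpus):8.1f} GB per GPU")
```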

Exercise 4.2: Tensor Parallelism Communication

In tensor-parallel attention with TP=8, each device computes 8 of 64 attention heads for LLaMA 3 70B. The output projection produces a partial result that requires an all-reduce. The output tensor per token has shape [1, 8192] in BF16. For a microbatch of 4 sequences, each of length 8192, how many bytes are transferred in the all-reduce?

Hint: all-reduce on the output of shape [batch × seq, d]. Use the ring formula: 2 × (P-1)/P × tensor_bytes.

MB per device
Show derivation
Output tensor shape = [4 × 8192, 8192] = [32768, 8192] in BF16
Tensor bytes = 32768 × 8192 × 2 = 536,870,912 bytes = 512 MB
All-reduce per device = 2 × (8-1)/8 × 512 MB = 2 × 0.875 × 512 = 896 MB

Note that 896 MB is the cost of a single all-reduce: 448 MB per device in the reduce-scatter phase plus 448 MB in the all-gather phase. TP needs one all-reduce per row-parallel layer, and a full transformer block has two of them (the attention output projection and the MLP down-projection), so the forward pass moves 2 × 896 = 1,792 MB per device per block.

At these volumes the all-reduce sits on the critical path of every layer, which is why TP is typically done within a single node over NVLink (900 GB/s) rather than across nodes.

Exercise 4.3: Pipeline Bubble Fraction

Pipeline parallelism with PP=4 stages. If you use m microbatches, the bubble fraction (wasted time) is (PP-1)/(m+PP-1). How many microbatches do you need to keep the bubble fraction below 5%?

microbatches (minimum)
Show derivation
Bubble fraction = (PP-1) / (m + PP - 1) < 0.05
3 / (m + 3) < 0.05
3 < 0.05 × (m + 3)
60 < m + 3
m > 57

The inequality is strict (at m = 57 the bubble is exactly 5%), so you need at least 58 microbatches. This means your global batch size must be at least 58 × microbatch_size. This is why PP works best with very large global batch sizes, and why modern systems prefer FSDP + TP over heavy PP when possible.
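
A brute-force search confirms the boundary case. The sketch below uses exact fractions so that the comparison at m = 57, where the bubble is exactly 5%, is not blurred by floating-point rounding. The function names are invented for this example.

```python
from fractions import Fraction


def bubble_fraction(pp_stages, microbatches):
    """Fraction of pipeline time spent idle, per the formula in the exercise."""
    return Fraction(pp_stages - 1, microbatches + pp_stages - 1)


def min_microbatches(pp_stages, max_bubble):
    """Smallest microbatch count whose bubble fraction is strictly below max_bubble."""
    m = 1
    while bubble_fraction(pp_stages, m) >= max_bubble:
        m += 1
    return m


m = min_microbatches(pp_stages=4, max_bubble=Fraction(5, 100))
print(m, f"{float(bubble_fraction(4, m)):.4f}")
```

The search returns the smallest integer strictly greater than 57, because 57 microbatches give a bubble of exactly 3/60.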

Exercise 4.4: Optimal Parallel Config for 70B on 2048 H100s

Meta trained LLaMA 3 70B on 2048 H100s. Given: TP must stay within a node (8 GPUs per node), PP adds bubbles, FSDP handles the rest. Which configuration makes sense?

Show reasoning

Option B (TP=8, PP=4, FSDP=64) is closest to what Meta actually used. Here is why:

TP=8: Uses all 8 GPUs within a node via NVLink (900 GB/s). Maximum bandwidth for TP all-reduces.

PP=4: Splits 80 layers into 4 stages of 20 layers. Reduces activation memory. With enough microbatches (Meta used ~100), bubble fraction is ~3%.

FSDP=64: 2048 / (8 × 4) = 64 FSDP groups handle the remaining parallelism with gradient communication over the inter-node network.

Option A (no PP) would work but requires more activation memory per stage. Option C (pure FSDP) has too much inter-node communication. Option D (TP=2) wastes NVLink bandwidth.

Exercise 4.5: Activation Memory

Activation memory per layer (no activation checkpointing): you must store the input to each layer for the backward pass. For LLaMA 3 70B with batch=4, seq=4096, d=8192, BF16: how much activation memory per layer?

Simplified: store the input tensor [batch × seq, d] plus the attention intermediate [batch, n_heads, seq, seq] (for the backward through attention).

GB per layer
Show derivation
Input activation: 4 × 4096 × 8192 × 2 = 268,435,456 bytes = 0.25 GB
Attention scores: 4 × 64 × 4096 × 4096 × 2 = 8,589,934,592 bytes = 8.0 GB
Total per layer ≈ 8.25 GB

The attention score matrix dominates! For 80 layers: 80 × 8.25 = 660 GB of activation memory — impossible without mitigation. Solutions: (1) Activation checkpointing: recompute activations in the backward pass instead of storing them. Trades ~33% extra compute for ~80% memory savings. (2) FlashAttention: never materializes the full attention matrix, saving the 8 GB/layer term entirely. (3) Sequence parallelism: shard the activation across the TP dimension.

FSDP (ZeRO-3) shards parameters, gradients, AND optimizer states across GPUs. Before each forward/backward computation, it must all-gather the full parameters. What is the cost of this compared to standard data parallelism?

Chapter 5: Training Cost Estimation

Now we combine everything: parameter count, FLOPs, hardware specs, and utilization to answer the question every ML lead asks: "How long will this take, and how much will it cost?"

The Core Formula

Training time = 6ND / (n_GPUs × peak_FLOPs × MFU)

Where N = parameters, D = tokens, n_GPUs = GPU count, peak_FLOPs = per-GPU peak FLOP/s, MFU = model FLOPs utilization (typically 0.35-0.50).
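
As a quick sketch of how the formula behaves, the snippet below sweeps MFU for a hypothetical run. The function name and the 7B / 1T-token / 256-GPU configuration are illustrative assumptions, not figures from any real training job.

```python
def training_days(n_params, n_tokens, n_gpus, peak_flops_per_gpu, mfu):
    """Wall-clock days from the 6ND estimate at a given utilization."""
    total_flops = 6 * n_params * n_tokens
    sustained_flops_per_s = n_gpus * peak_flops_per_gpu * mfu
    return total_flops / sustained_flops_per_s / 86_400


# Hypothetical run: 7B parameters, 1T tokens, 256 GPUs at 990 TFLOP/s peak.
for mfu in (0.35, 0.40, 0.50):
    days = training_days(7e9, 1e12, 256, 990e12, mfu)
    print(f"MFU={mfu:.2f}: {days:.1f} days")
```

Training time is inversely proportional to MFU, so moving from 0.35 to 0.50 cuts the schedule by 30%. That is often a cheaper win than buying more GPUs.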

Exercise 5.1: Training Time for LLaMA 3 70B

LLaMA 3 70B: N=70.6B, D=15T tokens, 2048 H100s at 990 TFLOP/s peak BF16, MFU=0.40. How many days to train?

days
Show derivation
Total FLOPs = 6 × 70.6 × 10^9 × 15 × 10^12 = 6.354 × 10^24
Aggregate throughput = 2048 × 990 × 10^12 × 0.40 = 8.11 × 10^17 FLOP/s
Time = 6.354 × 10^24 / (8.11 × 10^17) = 7.83 × 10^6 seconds
= 7,830,000 s / 86,400 s/day ≈ 90.6 days

Meta's model card lists about 6.4M H100 GPU-hours for LLaMA 3 70B. Our estimate corresponds to 2048 × 90.6 × 24 ≈ 4.45M GPU-hours, so it is in the right ballpark. The gap is what you would expect if the realized MFU was below 0.40 or if the total includes restarts and other overhead.

Exercise 5.2: Dollar Cost

Using your answer from 5.1 (~91 days), at $2/GPU-hour for H100 cloud pricing, what is the total training cost?

USD
Show derivation
GPU-hours = 2048 GPUs × 91 days × 24 hours/day = 4,472,832 GPU-hours
Cost = 4,472,832 × $2 = $8,945,664

Roughly $9M. In practice, costs are higher due to failed runs, hyperparameter sweeps, evaluation compute, data preprocessing, and the cost of the cluster sitting idle during maintenance. A realistic total project cost is 2-3× the pure training compute cost, so roughly $18-27M for a 70B model.

Exercise 5.3: Chinchilla-Optimal Tokens

The Chinchilla scaling law says the compute-optimal number of tokens D is approximately 20 × N. For a 70B model, what is the Chinchilla-optimal token count?

T tokens
Show derivation
D_optimal = 20 × N = 20 × 70 × 10^9 = 1.4 × 10^12 = 1.4T tokens

But Meta trained on 15T tokens — more than 10× the Chinchilla optimal! Why? Because Chinchilla optimizes for training compute efficiency (best model per FLOP). But inference cost depends on model size. A smaller model trained on more tokens (the "over-trained" regime) gives better inference cost per quality. LLaMA's philosophy: spend more on training (one-time cost) to get a smaller, cheaper-to-serve model. This is the inference-optimal scaling regime.

Exercise 5.4: Compute-Optimal Model Size

You have a fixed compute budget C = 10^22 FLOPs. Under Chinchilla scaling (C = 6ND, D = 20N), what is the compute-optimal model size?

Hint: Substitute D = 20N into C = 6ND.

B parameters
Show derivation
C = 6ND = 6N × 20N = 120N^2
N^2 = C / 120 = 10^22 / 120 = 8.33 × 10^19
N = √(8.33 × 10^19) = 9.13 × 10^9 = 9.13B

A 9B model trained on 183B tokens (20 × 9.13B). This is remarkably close to LLaMA 3 8B, which Meta trained on 15T tokens (way past Chinchilla-optimal, for inference efficiency).
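
The substitution generalizes to any budget. A minimal sketch, assuming the simplified rule D = 20N rather than the fitted Chinchilla coefficients; `chinchilla_optimal` and its `tokens_per_param` parameter are names made up for this example.

```python
import math


def chinchilla_optimal(compute_flops, tokens_per_param=20):
    """Solve C = 6 * N * D with D = tokens_per_param * N for N and D."""
    n_params = math.sqrt(compute_flops / (6 * tokens_per_param))
    return n_params, tokens_per_param * n_params


for budget in (1e21, 1e22, 1e24):
    n, d = chinchilla_optimal(budget)
    print(f"C={budget:.0e}: N = {n / 1e9:.2f}B params, D = {d / 1e9:.0f}B tokens")
```

Because N scales with the square root of C, a 100× larger budget buys only a 10× larger compute-optimal model. The other 10× goes into training tokens.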

Exercise 5.5: Learning Rate Scaling

A common scaling rule for learning rate with batch size is: when you double the batch size, you can increase the LR by √2 (square root scaling). If your base LR is 3 × 10^-4 at batch=256, what LR should you use at batch=2048?

learning rate
Show derivation
Batch increase factor = 2048 / 256 = 8
LR scale = √8 = 2.828
New LR = 3 × 10^-4 × 2.828 = 8.49 × 10^-4

This is the square-root scaling rule, a heuristic most often applied with adaptive optimizers such as Adam. Linear scaling (multiply LR by 8, as in Goyal et al., 2017, for SGD) also works in some regimes but is more aggressive. In practice, most frontier labs use a cosine LR schedule with warmup, where the peak LR is chosen via hyperparameter sweep.

Chinchilla vs inference-optimal. Chinchilla says D = 20N minimizes loss per training FLOP. But if you will serve the model to millions of users, the cost of each inference call matters more than training cost. Training a 70B model on 15T tokens uses about 6.3 × 10^24 FLOPs, roughly 10× what a Chinchilla-optimal 70B run (1.4T tokens) would need. Spent the Chinchilla way, that same budget would buy a model of about 230B parameters trained on about 4.6T tokens: lower loss per training FLOP, but more than 3× the weights to read and multiply for every generated token. Over-training the smaller model trades one-time training compute for a model that is permanently cheaper to serve.

You have a fixed training compute budget of 10^23 FLOPs. Under Chinchilla scaling (D = 20N, total FLOPs = 6ND), what model size maximizes quality?

Chapter 6: Inference Math

Inference has two distinct phases with completely different performance characteristics. Prefill processes the entire prompt in parallel — it is a batch matmul, compute-bound. Decode generates one token at a time — it is a matrix-vector product per layer, memory-bound.

Prefill vs decode intuition. Prefill is like reading a book all at once (batch processing). Decode is like writing a book one word at a time, looking up the entire story so far for each word (sequential, memory-bandwidth limited). This fundamental asymmetry drives the entire serving architecture.

Key Formulas

Prefill throughput (compute-bound):
tokens/s = n_GPUs × peak_FLOPs / (2N) // limited by compute

Decode throughput (memory-bound):
tokens/s/seq = mem_bandwidth / (N × bytes_per_param) // = bandwidth / 2N bytes in BF16; must read all weights per token
tokens/s total = tokens/s/seq × batch_size // batching amortizes weight reads

TTFT (time to first token) = prefill_tokens / prefill_throughput
ITL (inter-token latency) = 1 / decode_tokens_per_sec_per_seq
Exercise 6.1: Prefill Throughput

70B model on a single H100 (990 TFLOP/s BF16). Using 2N FLOPs per token for prefill, what is the theoretical prefill throughput?

tokens/s
Show derivation
Prefill tokens/s = Peak FLOPs / (2N) = 990 × 10¹² / (2 × 70.6 × 10⁹)
= 990 × 10¹² / (141.2 × 10⁹) = 7,011 tokens/s

In practice, you get 50-70% of this due to attention overhead, memory-bound operations within the forward pass, and kernel launch overhead. So ~3,500-5,000 tokens/s is realistic for 70B on a single H100.
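A quick way to explore the compute-bound estimate and the realistic range is the sketch below. The `efficiency` parameter is an assumption that stands in for the 50-70% figure quoted above, and the function name is illustrative.

```python
def prefill_tokens_per_sec(n_params, peak_flops, n_gpus=1, efficiency=1.0):
    """Compute-bound prefill estimate: 2*N FLOPs per token."""
    return n_gpus * peak_flops * efficiency / (2.0 * n_params)

ideal = prefill_tokens_per_sec(70.6e9, 990e12)
low = prefill_tokens_per_sec(70.6e9, 990e12, efficiency=0.5)
high = prefill_tokens_per_sec(70.6e9, 990e12, efficiency=0.7)
print(f"ideal: {ideal:,.0f} tok/s, realistic: {low:,.0f}-{high:,.0f} tok/s")
```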

Exercise 6.2: Time to First Token

You send a 4096-token prompt to a 70B model on a single H100. Assuming realistic prefill throughput of 4000 tokens/s, what is the time to first token (TTFT)?

seconds
Show derivation
TTFT = prompt_length / prefill_throughput = 4096 / 4000 = 1.024 seconds

Over a second to start responding! This is why long-context applications use chunked prefill (process the prompt in chunks to interleave with decoding requests from other users) and why prefix caching (reusing KV cache for common system prompts) is so important.

Exercise 6.3: Decode Throughput and Inter-Token Latency

70B model in BF16 on a single H100 (3.35 TB/s bandwidth). At batch_size=1, the decode step reads all model weights once per token. What is the decode speed (tokens/s) and inter-token latency?

tokens/s
ms (inter-token latency)
Show derivation
Model weights in BF16 = 70.6 × 10⁹ × 2 bytes = 141.2 GB
Time per token = 141.2 GB / (3.35 TB/s) = 141.2 / 3350 = 0.04215 s = 42.2 ms
Tokens/s = 1 / 0.04215 = 23.7 tokens/s

At batch_size=1, decode is entirely memory-bandwidth-bound — the MXU/Tensor Cores sit mostly idle! This is why batching is critical: at batch_size=32, you read the weights once but generate 32 tokens, giving 32 × 23.7 = 758 tokens/s total throughput (though per-user latency stays ~42 ms until you become compute-bound).
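The decode arithmetic can be wrapped in a small helper. This is a sketch under the same assumption as the exercise (each decode step reads every weight once and KV cache traffic is ignored); the name `decode_stats` is illustrative.

```python
def decode_stats(n_params, bytes_per_param, mem_bw_bytes_per_s, batch_size=1):
    """Memory-bound decode estimate: one full weight read per decode step."""
    weight_bytes = n_params * bytes_per_param
    step_time_s = weight_bytes / mem_bw_bytes_per_s
    return {
        "itl_ms": step_time_s * 1e3,                  # inter-token latency
        "tokens_per_s_per_seq": 1.0 / step_time_s,
        "tokens_per_s_total": batch_size / step_time_s,
    }

single = decode_stats(70.6e9, 2, 3.35e12, batch_size=1)
batched = decode_stats(70.6e9, 2, 3.35e12, batch_size=32)
print(single)
print(batched)
```

Note that `itl_ms` is identical in both calls: in this model batching raises total throughput without changing per-sequence latency, which holds only while decode stays memory-bound.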

Exercise 6.4: When Does Disaggregated Serving Win?

Disaggregated serving uses separate GPU pools for prefill and decode. Prefill is compute-bound, decode is memory-bound. At what batch size does decode become compute-bound on an H100 with 70B BF16?

Hint: decode becomes compute-bound when the time to do the math (batch_size × 2N FLOPs / peak FLOP/s) exceeds the time to read the weights (weight bytes / memory bandwidth).

batch size (crossover)
Show derivation

The crossover point is the ridge point! At the crossover, FLOPs per byte equals the H100's ridge point.

Per-token decode: FLOPs = 2N, Bytes = 2N (N weights × 2 bytes in BF16)
With batch B: FLOPs = B × 2N, Bytes = 2N (weights read once, shared)
AI = B × 2N / (2N) = B FLOPs/byte
Crossover: B = ridge point = 990 × 10¹² / (3.35 × 10¹²) ≈ 296

At batch_size ~296, decode becomes compute-bound! Below this, memory-bound decode underutilizes the tensor cores. This is the sweet spot where disaggregated serving wins: prefill GPUs run at high utilization (always compute-bound), while decode GPUs can be cheaper (memory-optimized) or run at smaller batch sizes for lower latency. In practice, KV cache memory limits batch size long before 296.
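The crossover can be computed for other hardware or weight precisions with a short sketch. The `bytes_per_param` generalization is an assumption beyond the exercise: it supposes the matmul still runs at the same peak FLOP/s while only the stored weight width changes.

```python
def decode_crossover_batch(peak_flops, mem_bw_bytes_per_s, bytes_per_param=2):
    """Batch size at which weight-read time equals matmul time.

    At batch B: FLOPs = B * 2N, bytes = N * bytes_per_param,
    so arithmetic intensity = 2B / bytes_per_param FLOPs per byte.
    """
    ridge_point = peak_flops / mem_bw_bytes_per_s
    return ridge_point * bytes_per_param / 2.0

h100_bf16 = decode_crossover_batch(990e12, 3.35e12, bytes_per_param=2)
print(f"H100, BF16 weights: crossover at batch ~{h100_bf16:.0f}")
```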

Exercise 6.5: Continuous Batching Throughput

With static batching (batch=32, all sequences decode together), if the shortest response is 50 tokens and the longest is 500, the GPU sits idle for finished sequences. Average output length is 200 tokens. What is the effective utilization compared to continuous batching (which immediately fills empty slots)?

Hint: In static batching, all sequences run for max_length steps. GPU utilization = average_length / max_length.

% effective utilization
Show derivation
Static batching utilization = avg_length / max_length = 200 / 500 = 40%

60% of GPU decode steps are wasted on padding! Continuous batching (introduced by the Orca paper) immediately ejects completed sequences and inserts new ones, achieving near-100% slot utilization. This alone can give a 2-2.5× throughput improvement. Combined with PagedAttention (which eliminates KV cache fragmentation), this is the foundation of modern serving systems like vLLM.
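The utilization formula follows from counting decode slots. The sketch below uses a made-up batch of 32 output lengths chosen only so that min = 50, max = 500, and mean = 200, matching the exercise.

```python
def static_batch_utilization(output_lengths):
    """Fraction of decode slots doing useful work when every sequence
    is held in the batch until the longest one finishes."""
    steps = max(output_lengths)              # the batch runs this many steps
    useful_slots = sum(output_lengths)       # slots that produce a real token
    return useful_slots / (len(output_lengths) * steps)

# Hypothetical batch: one short, one long, and 30 mid-length responses.
lengths = [50, 500] + [195] * 30
util = static_batch_utilization(lengths)
max_gain = 1.0 / util   # upper bound on what perfect slot refilling can recover
print(f"utilization = {util:.0%}, max gain from continuous batching = {max_gain:.1f}x")
```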

You are serving a 70B model. A user sends a 100-token prompt and expects a 500-token response. Which takes longer, prefill or decode?

Chapter 7: Serving Cost

Training is a one-time expense. Serving is forever. The cost to serve one million tokens determines whether your model is a viable product or a science experiment. Let's do the math.

The Serving Cost Formula

$/1M tokens = (GPU_cost_per_hour / tokens_per_hour) × 10⁶

tokens_per_hour = tokens/s × 3600 × utilization

For decode-bound serving (most common):
tokens/s = batch × mem_bandwidth / (N × bytes_per_param)
Exercise 7.1: Cost per Million Output Tokens

70B model in BF16 on 2× H100s (to fit in memory). Achievable decode throughput at batch=64: ~1500 tokens/s. GPU utilization: 70% (includes time waiting for requests). Cost: $4/hour for the 2-GPU setup. What is the cost per million output tokens?

$ per 1M tokens
Show derivation
Effective tokens/hour = 1500 × 3600 × 0.70 = 3,780,000
$/1M tokens = ($4 / 3,780,000) × 10⁶ = $1.06

For context, OpenAI charges $4-15/1M output tokens for GPT-4 class models. The hardware cost is ~$1, so the API price includes engineering costs, safety, margins, and the amortized training cost.
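The serving cost formula is short enough to keep as a reusable function. A minimal sketch, with an illustrative function name:

```python
def cost_per_million_tokens(gpu_cost_per_hour, tokens_per_sec, utilization):
    """Hardware cost in dollars to generate one million tokens."""
    tokens_per_hour = tokens_per_sec * 3600 * utilization
    return gpu_cost_per_hour / tokens_per_hour * 1e6

cost = cost_per_million_tokens(gpu_cost_per_hour=4.0, tokens_per_sec=1500, utilization=0.70)
print(f"${cost:.2f} per 1M output tokens")
```

Utilization enters linearly, so dropping from 70% to 35% busy time doubles the cost per token.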

Exercise 7.2: Break-Even Batch Size

You need to serve at $0.50/1M tokens to be competitive. Same setup as 7.1 ($4/hour for 2× H100, 70% utilization). At batch=1, throughput is ~24 tok/s. Throughput scales linearly with batch until the compute-bound crossover. What minimum batch size do you need?

sequences
Show derivation
Need: tokens_per_hour ≥ $4 / ($0.50 / 10⁶) = 8,000,000
tokens/s = 8,000,000 / (3600 × 0.70) = 3,174
batch = 3,174 / 24 = 132.25 → batch ≥ 133

You need a batch size of at least 133. But remember: each sequence in the batch needs its own KV cache. At seq_len=8192 with GQA, that is batch × 8.59/32 GB = 133 × 0.27 GB = 35.7 GB of KV cache alone, on top of the 141 GB of model weights. This is why high-batch serving requires careful memory management (PagedAttention, KV cache offloading).
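The break-even calculation inverts the cost formula. The sketch below assumes, as the exercise does, that throughput scales linearly with batch size; the function name is illustrative.

```python
import math

def min_batch_for_target(target_cost_per_million, gpu_cost_per_hour,
                         utilization, tokens_per_sec_at_batch_1):
    """Smallest batch size that meets a target $/1M tokens,
    assuming throughput grows linearly with batch size."""
    required_tokens_per_hour = gpu_cost_per_hour / (target_cost_per_million / 1e6)
    required_tokens_per_sec = required_tokens_per_hour / (3600 * utilization)
    return math.ceil(required_tokens_per_sec / tokens_per_sec_at_batch_1)

batch = min_batch_for_target(0.50, 4.0, 0.70, 24)
print(f"minimum batch size: {batch}")
```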

Exercise 7.3: Speculative Decoding Break-Even

Speculative decoding uses a small "draft" model to propose K tokens, then the large model verifies all K in one forward pass. If the draft model is 10× faster but its acceptance rate is α, the speedup is approximately Kα/(1 + K/speedup_ratio). With K=5, draft model 10× faster, what acceptance rate α do you need for a 2× overall speedup?

Simplified: expected accepted tokens per step = Kα. Cost per step = 1 (verify) + K/10 (draft). Speedup = Kα / (1 + K/10). Set this ≥ 2.

α (acceptance rate, 0-1)
Show derivation
Speedup = Kα / (1 + K/10) ≥ 2
5α / (1 + 0.5) ≥ 2
5α / 1.5 ≥ 2
α ≥ 3.0 / 5.0 = 0.60

You need at least 60% acceptance rate. In practice, well-matched draft models (like a distilled version of the target) achieve 70-85% acceptance on typical text. Acceptance drops on code, math, and creative writing where the small model diverges more from the large model's distribution.
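The simplified speedup model can be solved for α directly. This sketch implements only the approximation stated in the exercise (Kα accepted tokens per step, cost 1 + K/ratio); it is not a full model of speculative decoding, and both function names are illustrative.

```python
def spec_decode_speedup(k, alpha, draft_speed_ratio):
    """Approximate speedup: accepted tokens per step / cost per step."""
    return k * alpha / (1.0 + k / draft_speed_ratio)

def min_acceptance_rate(k, target_speedup, draft_speed_ratio):
    """Acceptance rate needed to reach target_speedup under the same model."""
    return target_speedup * (1.0 + k / draft_speed_ratio) / k

alpha_needed = min_acceptance_rate(k=5, target_speedup=2.0, draft_speed_ratio=10)
print(f"alpha >= {alpha_needed:.2f}")
print(f"speedup at alpha=0.8: {spec_decode_speedup(5, 0.8, 10):.2f}x")
```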

Exercise 7.4: Input vs Output Token Cost

API providers charge differently for input (prefill) and output (decode) tokens. For a 70B model on H100:
Prefill throughput: ~4000 tok/s (compute-bound, high GPU utilization).
Decode throughput at batch=64: ~1500 tok/s (memory-bound).
Same $4/hour GPU cost, 70% utilization for both. What is the cost per 1M input tokens vs output tokens?

$/1M input tokens
$/1M output tokens
Show derivation
Input: tokens/hour = 4000 × 3600 × 0.70 = 10,080,000
$/1M input = ($4 / 10,080,000) × 10⁶ = $0.397
Output: tokens/hour = 1500 × 3600 × 0.70 = 3,780,000
$/1M output = ($4 / 3,780,000) × 10⁶ = $1.06

Output tokens cost ~2.7× more than input tokens. This is why every API provider charges more for output: prefill is compute-bound (high utilization, fast), decode is memory-bandwidth-bound (lower throughput). This asymmetry also explains why prompt caching is valuable — you can amortize the cheap prefill cost across many users with the same system prompt.

Exercise 7.5: Quantization Impact on Serving Cost

You quantize the 70B model from BF16 to INT4 (4 bits per weight). This cuts the model size to a quarter (from 141 GB to ~35 GB), so it fits on a single H100. Decode throughput scales with bandwidth/model_size: approximately 3.35 TB/s / 35 GB ≈ 95 tokens/s at batch=1, or ~6000 tok/s at batch=64. GPU cost is now $2/hour (single GPU). What is the new $/1M output tokens?

$/1M output tokens
Show derivation
Tokens/hour = 6000 × 3600 × 0.70 = 15,120,000
$/1M output = ($2 / 15,120,000) × 10⁶ = $0.132

INT4 quantization reduces serving cost by ~8× (from $1.06 to $0.13 per 1M tokens). This comes from two factors: (1) 4× less data to read from memory per token (faster decode), and (2) the model fits on one GPU instead of two (halved hardware cost). This is why quantization is the single highest-impact optimization for LLM serving.
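The size and batch-1 decode numbers for each precision can be tabulated with a short sketch. It assumes a pure bits-per-weight model and ignores the extra bytes real quantization formats spend on scales and zero points.

```python
def weight_gb(n_params, bits_per_weight):
    """Weight footprint in GB (1 GB = 1e9 bytes), ignoring quantization metadata."""
    return n_params * bits_per_weight / 8 / 1e9

def batch1_decode_tokens_per_s(n_params, bits_per_weight, mem_bw_gb_per_s):
    """Memory-bound decode rate at batch size 1: one weight read per token."""
    return mem_bw_gb_per_s / weight_gb(n_params, bits_per_weight)

for name, bits in [("BF16", 16), ("INT8", 8), ("INT4", 4)]:
    size = weight_gb(70.6e9, bits)
    rate = batch1_decode_tokens_per_s(70.6e9, bits, 3350)
    print(f"{name}: {size:6.1f} GB, {rate:5.1f} tok/s at batch=1")
```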

Why is INT8/INT4 quantization so impactful for serving but less critical for training?

Chapter 8: Profiling Exercises

A profile tells you where time goes. Without a profile, optimization is guessing. With one, you can calculate exactly how much speedup is possible and where to focus.

The first rule of optimization. Measure before you optimize. The second rule: if you cannot explain WHY a kernel takes the time it does from first principles, you do not understand the system well enough to optimize it.
Exercise 8.1: Identify the Bottleneck

Here is a simplified profile of one transformer layer's forward pass on an H100, processing batch=32, seq=2048:

Kernel | Time (ms) | Category
QKV projection (fused gemm) | 0.42 | Compute
FlashAttention fwd | 0.38 | Compute
Output projection | 0.14 | Compute
LayerNorm + residual add | 0.12 | Memory
MLP up + gate (fused gemm) | 0.51 | Compute
SiLU activation | 0.08 | Memory
MLP down projection | 0.26 | Compute
LayerNorm + residual add | 0.12 | Memory
NCCL all-reduce (TP) | 0.35 | Communication

Total: 2.38 ms per layer. What is the biggest optimization opportunity?

Show reasoning

The memory-bound ops (2× LayerNorm + residual = 0.24 ms, SiLU = 0.08 ms) total 0.32 ms — 13.4% of the layer time. These can often be fused into the adjacent gemm kernels, effectively eliminating them. FlashAttention is already heavily optimized. NCCL can be overlapped with computation. Fusing the elementwise/normalization ops is the highest-ROI optimization with the least risk.
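Grouping the profile by category makes the breakdown explicit. The rows below are copied from the profile table in this exercise; the helper name is illustrative.

```python
profile = [
    ("QKV projection (fused gemm)", 0.42, "Compute"),
    ("FlashAttention fwd", 0.38, "Compute"),
    ("Output projection", 0.14, "Compute"),
    ("LayerNorm + residual add", 0.12, "Memory"),
    ("MLP up + gate (fused gemm)", 0.51, "Compute"),
    ("SiLU activation", 0.08, "Memory"),
    ("MLP down projection", 0.26, "Compute"),
    ("LayerNorm + residual add", 0.12, "Memory"),
    ("NCCL all-reduce (TP)", 0.35, "Communication"),
]

def time_by_category(rows):
    """Sum kernel times (ms) per category."""
    totals = {}
    for _name, ms, category in rows:
        totals[category] = totals.get(category, 0.0) + ms
    return totals

totals = time_by_category(profile)
layer_ms = sum(totals.values())
shares = {cat: ms / layer_ms for cat, ms in totals.items()}
for cat, ms in totals.items():
    print(f"{cat:14s} {ms:.2f} ms  ({shares[cat]:.1%})")
```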

Exercise 8.2: Expected vs Measured Kernel Time

The MLP up+gate projection is a [batch×seq, d] × [d, 2×d_ff] matmul with batch=32, seq=2048, d=8192, d_ff=28672. BF16 on H100. What is the theoretical minimum time (compute-bound at 990 TFLOP/s)?

ms
Show derivation
M = 32 × 2048 = 65536, K = 8192, N = 2 × 28672 = 57344
FLOPs = 2 × 65536 × 8192 × 57344 = 6.16 × 10¹³
Time = 6.16 × 10¹³ / (990 × 10¹²) = 0.0622 s = 62.2 ms

That is far above the 0.51 ms in the profile. With TP=8, each device computes 1/8 of the N dimension:

N per device = 57344 / 8 = 7168
FLOPs per device = 2 × 65536 × 8192 × 7168 = 7.70 × 10¹²
Time = 7.70 × 10¹² / (990 × 10¹²) = 0.00778 s = 7.78 ms

That is still about 15× the profiled 0.51 ms. Check the throughput the profile would imply:

Achieved FLOP/s = 7.70 × 10¹² / 0.00051 s = 1.51 × 10¹⁶ FLOP/s ≈ 15,100 TFLOP/s, about 15× the H100's peak.

No kernel can exceed peak, so the profile cannot have been taken at these shapes. Its numbers imply a much smaller effective batch (for example, a per-microbatch measurement under pipeline parallelism, or a different TP degree). The exercise point: always compare measured kernel time against the theoretical minimum (FLOPs / peak). If measured is >2× theoretical, there is optimization headroom. If measured is within 1.2×, the kernel is near-optimal.
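The comparison against the theoretical minimum is a one-line formula. A minimal sketch, with illustrative helper names, that also reports the throughput a measured time would imply:

```python
def matmul_min_time_ms(m, k, n, peak_flops):
    """Compute-bound lower bound for an [m, k] x [k, n] matmul."""
    return 2 * m * k * n / peak_flops * 1e3

def implied_tflops(m, k, n, measured_ms):
    """Throughput a measured kernel time would imply for this shape."""
    return 2 * m * k * n / (measured_ms * 1e-3) / 1e12

M, K, N = 32 * 2048, 8192, 2 * 28672
full_ms = matmul_min_time_ms(M, K, N, 990e12)
tp8_ms = matmul_min_time_ms(M, K, N // 8, 990e12)
print(f"no TP: {full_ms:.1f} ms, TP=8: {tp8_ms:.2f} ms")
print(f"implied by 0.51 ms at TP=8: {implied_tflops(M, K, N // 8, 0.51):,.0f} TFLOP/s")
```

An implied throughput above the hardware peak is a sign that the assumed shapes do not match what was actually profiled.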

Exercise 8.3: SM Occupancy

An H100 has 132 SMs, each supporting up to 2048 threads (64 warps of 32 threads). Your kernel uses 128 threads per block and 48 KB of shared memory per block. Each SM has 228 KB of shared memory. How many blocks can run per SM, and what is the occupancy?

blocks per SM
% occupancy
Show derivation

Shared memory limit: 228 KB / 48 KB = 4.75 → 4 blocks per SM

Thread limit: 2048 / 128 = 16 blocks per SM

Limiting factor: shared memory → 4 blocks per SM

Active threads = 4 × 128 = 512
Occupancy = 512 / 2048 = 25%

25% occupancy is low. To improve: reduce shared memory per block (use tiling to process in smaller chunks), increase threads per block (if the algorithm allows), or restructure the kernel to use less shared memory. Note that low occupancy is not always bad — if the kernel is compute-bound and achieves high IPC, occupancy matters less than if it is memory-bound.
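The two limits in this exercise can be turned into a small calculator. This sketch models only the thread and shared-memory limits; real occupancy also depends on register usage and the per-SM block cap, which are deliberately left out here.

```python
def occupancy(threads_per_block, smem_per_block_kb,
              max_threads_per_sm=2048, smem_per_sm_kb=228):
    """Blocks per SM and occupancy from thread and shared-memory limits only."""
    blocks_by_threads = max_threads_per_sm // threads_per_block
    blocks_by_smem = smem_per_sm_kb // smem_per_block_kb
    blocks = min(blocks_by_threads, blocks_by_smem)
    return blocks, blocks * threads_per_block / max_threads_per_sm

blocks, occ = occupancy(threads_per_block=128, smem_per_block_kb=48)
print(f"{blocks} blocks/SM, {occ:.0%} occupancy")

# Same shared memory, but 256 threads per block:
print(occupancy(threads_per_block=256, smem_per_block_kb=48))
```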

Exercise 8.4: Amdahl's Law for Kernel Fusion

From the profile in Exercise 8.1, the total layer time is 2.38 ms. If you perfectly fuse all memory-bound ops (0.32 ms total) into adjacent compute kernels, eliminating them entirely, what is the speedup for the full layer?

× speedup
Show derivation
New layer time = 2.38 - 0.32 = 2.06 ms
Speedup = 2.38 / 2.06 = 1.155×

A 15.5% speedup from fusing element-wise ops. This is Amdahl's Law in action: eliminating 13.4% of the execution time gives 1 / (1 − 0.134) ≈ 1.155×, because the remaining 86.6% is unchanged. For a model with 80 layers, this saves 80 × 0.32 ms = 25.6 ms per forward pass, which is meaningful at production scale.
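Amdahl's Law in general form covers both full elimination and partial speedups of the fused portion. A minimal sketch with an illustrative function name:

```python
def amdahl_speedup(fraction_improved, local_speedup=float("inf")):
    """Overall speedup when a fraction of runtime is sped up by local_speedup.

    local_speedup=inf means that fraction is eliminated entirely.
    """
    return 1.0 / ((1.0 - fraction_improved) + fraction_improved / local_speedup)

frac = 0.32 / 2.38                      # memory-bound share of the layer
full_fusion = amdahl_speedup(frac)      # ops eliminated entirely
half_fusion = amdahl_speedup(frac, 2)   # ops only made 2x faster
print(f"eliminated: {full_fusion:.3f}x, halved: {half_fusion:.3f}x")
```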

Exercise 8.5: Communication Overlap

The NCCL all-reduce in the profile takes 0.35 ms. If you overlap this communication with the next layer's QKV computation (0.42 ms), what fraction of the communication is hidden?

%
Show derivation
Communication time = 0.35 ms
Overlapping computation = 0.42 ms
Since 0.42 > 0.35, 100% of communication is hidden

The all-reduce completes before the QKV computation finishes, so the communication is fully hidden. This is the idea behind communication-computation overlap — start the all-reduce as soon as the current layer's output is ready, while the next layer's forward pass is already running on the same GPU. This requires careful CUDA stream management but can eliminate communication overhead entirely when compute time exceeds communication time.
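The hidden fraction depends only on the two durations. The sketch below assumes ideal overlap, meaning the communication can start at the same moment as the overlapping compute with no scheduling delay or contention.

```python
def hidden_fraction(comm_ms, overlap_compute_ms):
    """Fraction of communication time hidden behind concurrent compute."""
    return min(1.0, overlap_compute_ms / comm_ms)

def exposed_ms(comm_ms, overlap_compute_ms):
    """Communication time left on the critical path."""
    return max(0.0, comm_ms - overlap_compute_ms)

print(hidden_fraction(0.35, 0.42), exposed_ms(0.35, 0.42))   # overlap with QKV projection
print(hidden_fraction(0.35, 0.14), exposed_ms(0.35, 0.14))   # overlap with a shorter kernel
```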

You profile a CUDA kernel and find it achieves 40% of peak memory bandwidth but only 5% of peak compute. What does this tell you?

Chapter 9: JAX Programming

JAX is the framework of choice for scaling research at Google, DeepMind, and many frontier labs. It is not PyTorch with a different API. JAX is a functional transformation system: you write pure functions, and JAX transforms them (differentiation, compilation, parallelization, vectorization).

The JAX mental model. Think of JAX as a compiler, not a runtime. jax.jit traces your Python function into XLA HLO (an intermediate representation), which XLA compiles into optimized device code. jax.grad transforms the traced graph by adding backward nodes. jax.pmap replicates the graph across devices. All transforms compose.

Guided: JAX Training Loop for Linear Regression

Here is the simplest possible JAX training loop. Study each line — the patterns here scale to billion-parameter models.

jax
import jax
import jax.numpy as jnp

# 1. Generate synthetic data: y = 3x + 1 + noise
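key = jax.random.PRNGKey(42)
# Split the key: reusing one key for both draws would make the noise identical to X
key_x, key_noise = jax.random.split(key)
X = jax.random.normal(key_x, (100, 1))
y = 3.0 * X + 1.0 + 0.1 * jax.random.normal(key_noise, (100, 1))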

# 2. Initialize parameters (JAX uses explicit state, no hidden .grad)
params = {'w': jnp.zeros((1, 1)), 'b': jnp.zeros((1,))}

# 3. Define loss as a PURE FUNCTION of params (no side effects!)
def loss_fn(params, X, y):
    pred = X @ params['w'] + params['b']
    return jnp.mean((pred - y) ** 2)

# 4. jax.grad returns a FUNCTION that computes gradients
grad_fn = jax.grad(loss_fn)  # differentiates w.r.t. first arg (params)

# 5. Training loop — functional update, no in-place mutation
lr = 0.1
for step in range(100):
    grads = grad_fn(params, X, y)
    params = jax.tree.map(lambda p, g: p - lr * g, params, grads)
    if step % 20 == 0:
        print(f"Step {step}: loss={loss_fn(params, X, y):.4f}")

print(f"Learned: w={params['w'][0,0]:.3f}, b={params['b'][0]:.3f}")
# Expected: w≈3.0, b≈1.0
Key JAX patterns. (1) jax.grad differentiates pure functions — no .backward() on tensors. (2) jax.tree.map applies a function across all leaves of a pytree (nested dict/list of arrays). (3) No in-place updates — every step creates new arrays. (4) jax.jit can wrap the training step for compilation; the loop above runs uncompiled for simplicity.
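For pattern (4), a jitted version of the same training step might look like the sketch below. It assumes a JAX installation and redefines the data and loss so that it runs on its own; `train_step` is an illustrative name, and `jax.value_and_grad` is used so that the loss comes back from the same compiled call as the gradients.

```python
import jax
import jax.numpy as jnp

def loss_fn(params, X, y):
    pred = X @ params["w"] + params["b"]
    return jnp.mean((pred - y) ** 2)

@jax.jit  # traced once per input shape/dtype, then reused as compiled code
def train_step(params, X, y, lr):
    loss, grads = jax.value_and_grad(loss_fn)(params, X, y)
    new_params = jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)
    return new_params, loss

# Synthetic data: y = 3x + 1 + noise, with independent keys for X and the noise
key_x, key_noise = jax.random.split(jax.random.PRNGKey(0))
X = jax.random.normal(key_x, (100, 1))
y = 3.0 * X + 1.0 + 0.1 * jax.random.normal(key_noise, (100, 1))

params = {"w": jnp.zeros((1, 1)), "b": jnp.zeros((1,))}
for _ in range(200):
    params, loss = train_step(params, X, y, 0.1)

print(f"w={float(params['w'][0, 0]):.3f}, b={float(params['b'][0]):.3f}, loss={float(loss):.4f}")
```

The whole update, including the pytree map, sits inside the jitted function, so Python overhead is paid once at trace time rather than on every step.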

Guided: Data Parallel Training with pmap

jax
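from functools import partial

import jax
import jax.numpy as jnp

# Reuses loss_fn, params, X, and y from the linear-regression example above.
n_devices = jax.local_device_count()

# Replicate params to all devices (adds a leading device axis)
params = jax.tree.map(lambda x: jnp.stack([x] * n_devices), params)

# Define one step: forward, backward, all-reduce, update
# axis_name must be declared on pmap so that pmean below can refer to it
@partial(jax.pmap, axis_name='batch')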
def train_step(params, X_shard, y_shard):
    grads = jax.grad(loss_fn)(params, X_shard, y_shard)
    # All-reduce: average gradients across devices
    grads = jax.lax.pmean(grads, axis_name='batch')
    return jax.tree.map(lambda p, g: p - 0.1 * g, params, grads)

# Split data across devices (axis 0 = device dimension).
# The reshape requires the number of examples to be divisible by n_devices.
X_sharded = X.reshape(n_devices, -1, 1)
y_sharded = y.reshape(n_devices, -1, 1)

# Train!
for step in range(100):
    params = train_step(params, X_sharded, y_sharded)

jax.pmap maps the function across devices. jax.lax.pmean inserts an all-reduce into the compiled graph. This is replicated (DDP-style) data parallelism in about 10 lines; unlike FSDP, every device holds a full copy of the parameters.

Guided: A Minimal Pallas Kernel

Pallas is JAX's kernel authoring framework. You write kernels in Python that lower through backends such as Triton (GPU) and Mosaic (TPU) to device code. Here is a vector addition kernel, the "hello world" of custom kernels.

jax/pallas
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

def add_kernel(x_ref, y_ref, o_ref):
    """Pallas kernel: reads from x_ref and y_ref, writes to o_ref.
    Refs are BlockSpec views — they point to a tile of the full array."""
    o_ref[...] = x_ref[...] + y_ref[...]

def add_vectors(x, y):
    return pl.pallas_call(
        add_kernel,
        out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
        grid=(8,),  # launch 8 blocks
        in_specs=[
            pl.BlockSpec((128,), lambda i: (i,)),  # each block gets 128 elements
            pl.BlockSpec((128,), lambda i: (i,)),
        ],
        out_specs=pl.BlockSpec((128,), lambda i: (i,)),
    )(x, y)

# Test
x = jnp.arange(1024, dtype=jnp.float32)
y = jnp.ones(1024, dtype=jnp.float32)
result = add_vectors(x, y)
# result == jnp.arange(1024) + 1

Guided: Automatic Vectorization with vmap

jax.vmap automatically vectorizes a function that operates on a single example to work on a batch. This is how you write clean, single-example logic and get batched execution for free.

jax
import jax
import jax.numpy as jnp

# Function that works on ONE example (no batch dim)
def predict_single(params, x):
    """x: shape [D], returns scalar"""
    return jnp.dot(params['w'], x) + params['b']

# vmap it to handle batches: x now shape [B, D]
predict_batch = jax.vmap(predict_single, in_axes=(None, 0))
# in_axes=(None, 0): don't batch over params, batch over axis 0 of x

# This compiles to the SAME efficient batched matmul as manual batching
params = {'w': jnp.ones(768), 'b': jnp.zeros(())}
x_batch = jax.random.normal(jax.random.PRNGKey(0), (32, 768))
out = predict_batch(params, x_batch)  # shape [32]

vmap composes with jit and grad: jax.jit(jax.vmap(jax.grad(loss_fn))) gives you batched gradient computation compiled into a single XLA graph. This composability is JAX's superpower.

Why Pallas matters. XLA's compiler is good at fusing small ops, but it cannot invent new algorithms. When you need a tiled matmul with custom masking, a ragged-batch attention kernel, or an operation that does not exist in XLA, Pallas lets you write it in Python and compile it to efficient device code. This is the "write custom Pallas kernels" that Vlad Feinberg recommends.
In JAX, what does jax.jit(f) do?

Chapter 10: The Capstone — Addition Transformer

This is the exercise Vlad Feinberg specifically recommends: implement a ~10M parameter transformer in JAX/Flax/Optax that learns integer addition. It sounds trivial. It is not. This single exercise tests your understanding of tokenization, architecture design, scaling laws, training dynamics, and JAX fluency.

The meta-lesson. The addition transformer is a microcosm of frontier model development. Every decision you make here (tokenizer design, model sizing, learning rate schedule, data distribution) has a direct analog at the 70B scale. The person who has built a 10M model from scratch understands scaling in their bones, not just in their notes.

Step 1: Design the Tokenizer

The vocabulary is tiny: digits 0-9, plus sign, equals sign, and a padding token. That is 13 tokens total.

Token vocabulary (13 tokens):
0-9: digit tokens (IDs 0-9)
+ : addition operator (ID 10)
= : equals sign (ID 11)
<pad>: padding (ID 12)

Input format: "123+456=579<pad>"
All examples padded to fixed length (e.g., max_len = 12 for 3-digit addition)

The model sees the entire string but only predicts the answer portion (after the = sign). During training, you mask the loss on the prompt tokens.
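A tokenizer for this vocabulary fits in a few lines of plain Python. In this sketch `encode_example` and the mask layout are illustrative choices: the mask marks which token positions hold answer digits, and in a next-token-prediction setup you would apply it to the shifted targets. MAX_LEN = 12 is taken from the longest example, "999+999=1998".

```python
VOCAB = {str(d): d for d in range(10)}
VOCAB.update({"+": 10, "=": 11, "<pad>": 12})
MAX_LEN = 12  # len("999+999=1998")

def encode_example(a, b, max_len=MAX_LEN):
    """Return (token_ids, loss_mask) for the string 'a+b=sum', padded to max_len."""
    text = f"{a}+{b}={a + b}"
    ids = [VOCAB[ch] for ch in text]
    ids += [VOCAB["<pad>"]] * (max_len - len(ids))
    eq_pos = text.index("=")
    # 1 for answer digits (after '='), 0 for prompt tokens and padding
    mask = [1 if eq_pos < i < len(text) else 0 for i in range(max_len)]
    return ids, mask

ids, mask = encode_example(123, 456)
print(ids)
print(mask)
```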

Exercise 10.1: Maximum Sequence Length

For addition of up to 3-digit numbers, what is the maximum input sequence length? Consider: "999+999=1998". Count each character as one token.

tokens
Show derivation
"999+999=1998" = 3 + 1 + 3 + 1 + 4 = 12 characters

The maximum result of 3-digit addition is 999+999=1998, which is 4 digits. So max_len = 3 + 1 + 3 + 1 + 4 = 12. Pad all sequences to length 12. Shorter examples (like "1+2=3") get padded: "1+2=3<pad><pad><pad><pad><pad><pad><pad>".

Step 2: Architecture Sizing

You want ~10M parameters. For a decoder-only transformer, the parameter count is approximately:

Params ≈ vocab × d + L × (4d² + 2 × d × d_ff + norms) + d × vocab

With d_ff = 4d (standard) and small vocab (13):
Params ≈ 2 × 13d + L × (4d² + 8d²) = 26d + L × 12d²
Params ≈ L × 12d² (embedding terms are negligible for small vocab)
Exercise 10.2: Choose Architecture Parameters

Target 10M parameters. Try L=6 layers. What value of d (model dimension) gets you closest to 10M? Use the simplified formula: Params ≈ L × 12d².

d (model dimension)
Show derivation
10 × 10⁶ = 6 × 12 × d² = 72d²
d² = 10,000,000 / 72 = 138,889
d = √138,889 ≈ 373

Round to d=384 (divisible by common head counts). With 6 heads, d_head = 384/6 = 64. Actual params: 6 × 12 × 384² = 10,616,832 ≈ 10.6M. Close enough!

A practical config: L=6, d=384, n_heads=6, d_head=64, d_ff=1536, vocab=13, max_seq=12.
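The sizing formula is easy to sweep over candidate configs. This sketch follows the simplified count above (attention 4d², MLP 2 × d × d_ff, two embedding tables) and, like that formula, ignores biases, norms, and positional embeddings; the function names are illustrative.

```python
import math

def approx_params(n_layers, d_model, vocab=13, d_ff_mult=4):
    """Simplified decoder-only parameter count."""
    attn = 4 * d_model ** 2                       # Q, K, V, and output projections
    mlp = 2 * d_model * (d_ff_mult * d_model)     # up and down projections
    embed = 2 * vocab * d_model                   # input embedding + output head
    return n_layers * (attn + mlp) + embed

def d_for_target(target_params, n_layers):
    """Model dimension that hits target_params under Params = L * 12 * d^2."""
    return math.sqrt(target_params / (12 * n_layers))

print(f"d for 10M at L=6: {d_for_target(10e6, 6):.1f}")
print(f"params at L=6, d=384: {approx_params(6, 384):,}")
```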

Step 3: The Flax Implementation

Here is the core transformer in Flax. This is what you would code in a Colab notebook.

python/flax
import jax
import jax.numpy as jnp
import flax.linen as nn
import optax

class CausalSelfAttention(nn.Module):
    n_heads: int
    d_model: int

    @nn.compact
    def __call__(self, x):
        B, T, C = x.shape
        d_head = C // self.n_heads
        # Fused QKV projection
        qkv = nn.Dense(3 * C)(x)  # [B, T, 3C]
        q, k, v = jnp.split(qkv, 3, axis=-1)
        # Reshape for multi-head: [B, T, C] -> [B, n_heads, T, d_head]
        q = q.reshape(B, T, self.n_heads, d_head).transpose(0,2,1,3)
        k = k.reshape(B, T, self.n_heads, d_head).transpose(0,2,1,3)
        v = v.reshape(B, T, self.n_heads, d_head).transpose(0,2,1,3)
        # Scaled dot-product attention with causal mask
        attn = (q @ k.transpose(0,1,3,2)) / jnp.sqrt(d_head)
        mask = jnp.tril(jnp.ones((T, T)))[None, None, :, :]
        attn = jnp.where(mask == 0, -1e9, attn)
        attn = jax.nn.softmax(attn, axis=-1)
        out = (attn @ v).transpose(0,2,1,3).reshape(B, T, C)
        return nn.Dense(C)(out)  # output projection

class TransformerBlock(nn.Module):
    n_heads: int
    d_model: int
    d_ff: int

    @nn.compact
    def __call__(self, x):
        # Pre-norm architecture (like GPT-2, LLaMA)
        x = x + CausalSelfAttention(self.n_heads, self.d_model)(nn.LayerNorm()(x))
        x = x + nn.Dense(self.d_model)(nn.gelu(nn.Dense(self.d_ff)(nn.LayerNorm()(x))))
        return x

class AdditionTransformer(nn.Module):
    vocab: int = 13
    d_model: int = 384
    n_heads: int = 6
    n_layers: int = 6
    d_ff: int = 1536
    max_len: int = 12

    @nn.compact
    def __call__(self, tokens):
        B, T = tokens.shape
        # Token + positional embeddings
        tok_emb = nn.Embed(self.vocab, self.d_model)(tokens)
        pos_emb = nn.Embed(self.max_len, self.d_model)(jnp.arange(T))
        x = tok_emb + pos_emb  # [B, T, d_model]
        # Transformer blocks
        for _ in range(self.n_layers):
            x = TransformerBlock(self.n_heads, self.d_model, self.d_ff)(x)
        x = nn.LayerNorm()(x)
        logits = nn.Dense(self.vocab)(x)  # [B, T, vocab]
        return logits
Exercise 10.3: Chinchilla-Optimal Training for 10M

Using D = 20N, how many tokens should you train your 10M model on? How many 3-digit addition examples is that (each example ~12 tokens)?

tokens
examples
Show derivation
D = 20 × 10 × 10⁶ = 200,000,000 tokens
Examples = 200,000,000 / 12 ≈ 16,666,667 examples

There are only 999 × 999 = 998,001 unique 3-digit addition problems. At 16.7M examples, you will see each problem ~17 times on average. This is fine — language models see each training example multiple times too. With 3-digit numbers and random sampling, the model should converge to near-perfect accuracy well before exhausting 200M tokens.

Exercise 10.2b: Training Compute for the Addition Transformer

Your 10.6M parameter model will train on 200M tokens. What is the total training compute in GFLOPs? How long will this take on a single T4 GPU (65 TFLOP/s FP16, assuming 30% MFU)?

GFLOP (total training FLOPs)
minutes
Show derivation
FLOPs = 6 × 10.6 × 10⁶ × 200 × 10⁶ = 1.272 × 10¹⁶ FLOPs = 1.272 × 10⁷ GFLOP (12.72 PFLOP)
Effective throughput = 65 × 10¹² × 0.30 = 19.5 × 10¹² FLOP/s
Time = 1.272 × 10¹⁶ / (19.5 × 10¹²) = 652 seconds ≈ 10.9 minutes

Under 11 minutes on a free Colab T4! This is exactly the kind of experiment Vlad recommends: small enough to iterate quickly, but large enough to teach real scaling concepts. If your training takes longer, you are probably not batching efficiently or your data pipeline is the bottleneck.
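The training-time estimate generalizes to any model, dataset, and accelerator. A minimal sketch using the 6ND rule and an assumed MFU, with an illustrative function name:

```python
def training_estimate(n_params, n_tokens, peak_flops, mfu):
    """Return (total training FLOPs, wall-clock minutes) under the 6*N*D rule."""
    total_flops = 6 * n_params * n_tokens
    seconds = total_flops / (peak_flops * mfu)
    return total_flops, seconds / 60

total_flops, minutes = training_estimate(10.6e6, 200e6, peak_flops=65e12, mfu=0.30)
print(f"{total_flops:.3e} FLOPs = {total_flops / 1e9:.3e} GFLOP, ~{minutes:.1f} minutes")
```

Comparing this prediction with the measured wall-clock time of a real run gives you the achieved MFU.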

Step 4: Dense vs MoE

Vlad's blog also asks you to derive Chinchilla scaling for MoE (Mixture of Experts). The key insight: an MoE model has more total parameters (N_total) but only activates a fraction per token (N_active). Chinchilla's compute law uses the active parameter count.

Dense: FLOPs/token = 2N, Training FLOPs = 6ND
MoE: FLOPs/token = 2N_active, Training FLOPs = 6 × N_active × D

If top-2 routing with E experts: N_active ≈ N_shared + 2/E × N_expert
Chinchilla optimal: D = 20 × N_active (based on compute, not total params)
Exercise 10.4: MoE vs Dense Scaling

You have a compute budget of 6 × 10¹⁷ FLOPs. Compare two options:
(A) Dense 1B model, trained on D_A tokens.
(B) MoE with 8 experts, top-2 routing, N_active = 1B (but total params = ~3.5B). Trained on D_B tokens.
Under Chinchilla scaling, how many tokens does each get trained on?

B tokens (Dense)
B tokens (MoE)
Show derivation
Budget = 6 × N_active × D
D = Budget / (6 × N_active) = 6 × 10¹⁷ / (6 × 10⁹) = 10⁸ = 100M tokens (0.1B)

Both get 100M tokens (0.1B)! Because both have N_active = 1B, and the training compute formula uses N_active, the token count is identical. Note that this budget is far short of the Chinchilla-optimal D = 20 × N_active = 20B tokens for either model. The MoE advantage is that it has 3.5B total parameters (more capacity to store knowledge) while using the same training compute as the 1B dense model. At inference, however, the MoE model needs more memory for the extra expert weights but the same compute per token.
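The active-parameter bookkeeping can be checked with a short sketch. The split between shared and expert parameters below (0.2B shared, 3.2B across experts) is a made-up example chosen only so that N_active comes out to 1B with top-2 of 8 experts; it is not taken from the exercise.

```python
def moe_active_params(n_shared, n_expert_total, n_experts, top_k):
    """N_active = N_shared + (top_k / E) * N_expert."""
    return n_shared + top_k / n_experts * n_expert_total

def tokens_for_budget(budget_flops, n_active):
    """Tokens affordable under training FLOPs = 6 * N_active * D."""
    return budget_flops / (6 * n_active)

# Hypothetical split: 0.2B shared (attention, embeddings) + 3.2B in 8 experts
n_active = moe_active_params(0.2e9, 3.2e9, n_experts=8, top_k=2)
dense_tokens = tokens_for_budget(6e17, 1e9)
moe_tokens = tokens_for_budget(6e17, n_active)
print(f"N_active = {n_active / 1e9:.1f}B")
print(f"dense: {dense_tokens:.1e} tokens, MoE: {moe_tokens:.1e} tokens")
```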

Step 5: When Does Pallas Beat ragged_dot?

MoE routing creates "ragged" batches: each expert gets a different number of tokens. JAX's jax.lax.ragged_dot handles this, but a custom Pallas kernel can beat it when the number of experts (or the "fan-out" F) exceeds the model dimension D.

Exercise 10.5: The F > D Condition

In MoE, after routing, you combine up-projection and down-projection for each expert. With E=64 experts, top-2 routing, d=1024: the combined projection is essentially a grouped matmul with F (fan-out) = E × top_k = 128 independent matmuls. When F > D, each individual matmul is tiny ([tokens_per_expert, D] × [D, D_ff]). Why does Pallas win here?

Show reasoning

With F=128 experts and a batch of, say, 2048 tokens routed top-2, each expert gets ~32 tokens on average. Each expert's matmul is [32, 1024] × [1024, 4096] — far too small to saturate the H100's tensor cores. ragged_dot launches these as padded grouped GEMMs with wasted compute on padding. A Pallas kernel can:

1. Tile across experts AND tokens simultaneously
2. Fuse the up and down projections (compute up, apply activation, compute down in one pass without writing intermediate to HBM)
3. Handle ragged batch sizes natively without padding
4. Amortize kernel launch overhead (1 launch instead of 128)

This is exactly the kind of optimization that differentiates a frontier lab engineer from someone who just calls library functions.

You have trained your 10M addition transformer and it reaches 95% accuracy on 3-digit addition. You want 99.9%. What should you try first?

Chapter 11: Self-Assessment Checklist

Here is every skill this workbook covers. Be honest with yourself: can you do each of these from scratch, with paper and pencil, without looking up the formula? Check off each one as you master it.

The proof-of-work standard. Vlad Feinberg's advice: "Record yourself doing exercises with paper and pencil." We cannot record video here, but the equivalent is doing each exercise below from memory. Close this page, open a blank notebook, and work through them. If you cannot reproduce the derivation from scratch, you have not learned it yet.

Roofline & Hardware

Sharding & Communication

Parameter Counting & FLOPs

Training Parallelism

Training Cost

Inference & Serving

Profiling & Kernels

The Capstone


Related Lessons

Topic | Lesson
Transformer internals | Transformer — From Absolute Zero
GPT architecture | GPT — From Absolute Zero
Diffusion models | Diffusion — From Absolute Zero
Reward & alignment | RLHF — From Absolute Zero
Inference engineering | ML Inference Engineer — Day In The Life

The Study Protocol

Here is the protocol that maximizes retention from this workbook. It is based on spaced retrieval practice — the most evidence-backed learning technique in cognitive science.

Day 1: Work Through
Complete every exercise in this workbook with the hints and solutions available. Take notes on any formula you did not know. The goal is exposure, not mastery.
Day 2: Blank Paper
Open a blank notebook. Without looking at this page, try to re-derive every key formula: ridge point, 6ND, Chinchilla D=20N, KV cache sizing, MFU, serving cost. Check your work afterward. Mark the ones you got wrong.
Day 4: Teach It
Explain three concepts from this workbook to a friend, a rubber duck, or a voice recorder. Speaking forces deeper processing than passive review. Focus on the ones you missed on Day 2.
Day 7: Full Test
Blank paper, no hints. Work through 10 randomly selected exercises. If you score at least 8/10, move on. If not, repeat the cycle for the missed topics.
Day 14: Implement
Code the addition transformer in a Colab notebook. Run training. Measure actual throughput and compare to your theoretical prediction. This is your proof of work.
The standard. "What I cannot create, I do not understand." — Richard Feynman. If you can work through every exercise in this workbook from a blank sheet of paper, you understand scaling at the level that frontier labs expect. That is the proof of work.
What is the single most important formula to internalize from this workbook?