Every exercise from the JAX Scaling Book, solvable in-browser. Pencil-and-paper math meets instant feedback.
You are an ML engineer staring at a GPU profiler trace. A matrix multiply that should take 0.14 ms is taking 0.8 ms. Is the bottleneck memory bandwidth? Compute throughput? Launch overhead? Without a mental model, you are guessing. The roofline model gives you that mental model.
Before you can optimize anything, you need to know what limits performance. Is your operation waiting for data to arrive from memory, or waiting for the arithmetic units to finish crunching? The roofline model answers this question with a single number: arithmetic intensity.
Arithmetic intensity (AI) is the ratio of floating-point operations to bytes moved from memory. If you do many FLOPs per byte loaded, you are compute-bound. If you do few FLOPs per byte, you are memory-bound. The dividing line is the hardware's ridge point — the arithmetic intensity where the processor's compute throughput exactly matches its memory bandwidth.
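For a matrix multiplication C = A × B where A is [M, K] and B is [K, N]: FLOPs = 2 × M × K × N (one multiply and one add per inner-product term), bytes moved = (M×K + K×N + M×N) × element_size, and AI = FLOPs / bytes.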
Compute the arithmetic intensity for a square matmul: [4096, 4096] × [4096, 4096] in BF16 (2 bytes per element).
Hint: FLOPs = 2 × 4096^3. Bytes = 3 × 4096^2 × 2.
For square matmuls, AI simplifies to 2N^3 / (6N^2) = N/3 when element size = 2 bytes. So 4096/3 ≈ 1365. This is deeply compute-bound on any modern accelerator.
The H100 SXM has ~990 TFLOP/s BF16 Tensor Core throughput and ~3.35 TB/s HBM bandwidth. What is the ridge point?
Any operation with AI above ~296 is compute-bound on the H100. Our matmul from Exercise 0.1 (AI ≈ 1365) is well above this, so it is solidly compute-bound. Element-wise ops like FP32 ReLU (1 FLOP per 8 bytes moved, counting the 4-byte read and the 4-byte write, so AI = 0.125) are deeply memory-bound.
Given the [4096, 4096] × [4096, 4096] BF16 matmul is compute-bound on the H100, predict the theoretical execution time assuming 990 TFLOP/s peak.
In practice, you will see ~0.15-0.2 ms due to kernel launch overhead and imperfect occupancy. The ratio of achieved to theoretical is your MFU (Model FLOPs Utilization). Getting above 50% MFU on real workloads is considered good.
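The three calculations above (arithmetic intensity, ridge point, and compute-bound time) fit in a few lines of Python. This is a minimal sketch that uses only the hardware numbers quoted in the exercises; the function names are illustrative and do not come from any library.

```python
def matmul_ai(m, k, n, bytes_per_elem=2):
    """Arithmetic intensity (FLOPs per byte) of an [m, k] x [k, n] matmul."""
    flops = 2 * m * k * n
    bytes_moved = (m * k + k * n + m * n) * bytes_per_elem
    return flops / bytes_moved


def ridge_point(peak_flops, bandwidth):
    """Arithmetic intensity at which compute time equals memory time."""
    return peak_flops / bandwidth


H100_PEAK = 990e12  # BF16 FLOP/s, as quoted in the text
H100_BW = 3.35e12   # HBM bytes/s, as quoted in the text

ai = matmul_ai(4096, 4096, 4096)
ridge = ridge_point(H100_PEAK, H100_BW)
time_ms = 2 * 4096**3 / H100_PEAK * 1e3  # compute-bound lower bound

print(f"AI = {ai:.0f}, ridge = {ridge:.1f}, time = {time_ms:.3f} ms")
```

Comparing `ai` against `ridge` is the whole roofline test: above the ridge the FLOP count sets the time, below it the byte count does.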
Consider the QK^T computation in self-attention: Q and K are both [batch, heads, seq, d_head]. With batch=1, heads=32, seq=2048, d_head=128, BF16. What is the arithmetic intensity of the QK^T matmul (treating each head independently)?
Hint: Per head, this is a [2048, 128] × [128, 2048] matmul. Compute AI for one head.
Per head: M=2048, K=128, N=2048, element_size=2 bytes.
But we must also account for the softmax and the AV matmul. Standard (unfused) attention writes the full S = QK^T matrix of shape [2048, 2048] to HBM and reads it back, and typically does the same for the softmax output. With this materialization, total bytes for the full attention become much larger, dropping effective AI to roughly 62. This is why FlashAttention exists: it avoids materializing S in HBM by computing attention in tiles, fusing the QK^T, softmax, and AV multiply into a single kernel. On an H100 with ridge point ~296, standard attention at AI ≈ 62 is memory-bound. FlashAttention raises the effective AI well above the ridge point.
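The effect of materializing the score matrix can be estimated with a small sketch. The number of times a [seq, seq] matrix crosses HBM in an unfused implementation is an assumption here (four passes: write S, read S, write softmax(S), read softmax(S)); real kernels may differ, so treat the result as an order-of-magnitude check rather than an exact figure.

```python
def attention_head_ai(seq, d_head, bytes_per_elem=2, s_passes=0):
    """Arithmetic intensity of one attention head.

    s_passes is the number of times a [seq, seq] matrix crosses HBM.
    s_passes=0 models a fused kernel; s_passes=4 is an ASSUMED count
    for unfused attention.
    """
    flops = 2 * (2 * seq * d_head * seq)            # QK^T and AV matmuls
    qkvo_bytes = 4 * seq * d_head * bytes_per_elem  # read Q, K, V; write O
    s_bytes = s_passes * seq * seq * bytes_per_elem
    return flops / (qkvo_bytes + s_bytes)


unfused = attention_head_ai(2048, 128, s_passes=4)
fused = attention_head_ai(2048, 128, s_passes=0)
print(f"unfused AI ~ {unfused:.0f}, fused AI ~ {fused:.0f}")
```

Under these assumptions the unfused value lands near 60, in line with the rough figure quoted above, while the fused value is far above the H100 ridge point.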
ReLU activation on a tensor of 10M elements in FP32 (4 bytes). ReLU does one comparison per element (1 FLOP). What is the arithmetic intensity?
AI = 1 FLOP / 8 bytes (a 4-byte read plus a 4-byte write per element) = 0.125. This is about 2,364× below the H100's ridge point of ~296. ReLU is absurdly memory-bound. This is why activation functions are rarely standalone kernels in optimized code: they are usually fused into the preceding or following matmul.
During autoregressive decode, the main operation per layer is a matrix-vector product: [1, d] × [d, d] where d=8192. BF16. What is the arithmetic intensity? Is it compute-bound or memory-bound on H100?
AI ≈ 1 — deeply memory-bound (ridge point is 296). The weight matrix dominates the byte count. This is why decode is memory-bandwidth-limited: each token requires reading the entire weight matrix for a trivial number of FLOPs. Batching helps by amortizing the weight read across B tokens, raising AI to ~B.
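A short sketch makes the batching claim concrete. It counts the weight matrix plus the input and output activations; the function name is hypothetical.

```python
def matvec_ai(d, batch=1, bytes_per_elem=2):
    """AI of a [batch, d] x [d, d] product, counting weights, inputs, and outputs."""
    flops = 2 * batch * d * d
    bytes_moved = (d * d + 2 * batch * d) * bytes_per_elem
    return flops / bytes_moved


for b in (1, 8, 64, 256):
    print(f"batch={b:4d}  AI={matvec_ai(8192, b):7.1f}")
```

For small batches the weight read dominates the byte count, so AI tracks the batch size closely; the activation terms only start to matter once the batch is a noticeable fraction of d.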
Google's Tensor Processing Units (TPUs) are custom ASICs designed for one thing: dense matrix multiplication at massive scale. Unlike GPUs, which evolved from graphics workloads and carry a lot of general-purpose flexibility, TPUs are laser-focused on the matmul-softmax-matmul pipeline that dominates transformer training and inference.
The heart of a TPU is the MXU (Matrix Multiply Unit) — a systolic array that computes a matrix multiply in a single pass. A TPU v5e has a 128×128 BF16 systolic array running at ~1.1 GHz. Each cycle, the MXU produces 128×128 = 16,384 multiply-accumulate results, and each MAC is 2 FLOPs.
A TPU v5e has a 128×128 BF16 MXU running at 1.1 GHz. How many TFLOP/s of BF16 matmul throughput does one MXU produce?
Hint: Each cycle produces 128 × 128 MAC operations. Each MAC = 2 FLOPs (multiply + add).
For comparison, an H100 achieves ~990 TFLOP/s BF16 — but that is with 4th-gen Tensor Cores operating on much smaller tiles at higher clock and much wider parallelism. The TPU v5e compensates by being cheaper and packing more chips per pod.
TPU v5e specs: ~197 TFLOP/s peak BF16 (chip total, summed over all of its MXUs) and ~819 GB/s HBM bandwidth. What is the ridge point?
The TPU v5e has a lower ridge point than the H100 (~296), meaning operations become compute-bound at lower arithmetic intensity. This is because the TPU's ratio of compute-to-bandwidth is more balanced — relatively more bandwidth per FLOP.
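The two TPU calculations can be checked with a few lines. This sketch uses the clock rate and chip-level figures exactly as stated in the exercises; `mxu_tflops` is an illustrative helper, not an API.

```python
def mxu_tflops(rows, cols, clock_hz):
    """Peak throughput of one systolic array: rows*cols MACs per cycle, 2 FLOPs per MAC."""
    return rows * cols * 2 * clock_hz / 1e12


one_mxu = mxu_tflops(128, 128, 1.1e9)  # clock as stated in the exercise
v5e_ridge = 197e12 / 819e9             # chip-level peak / HBM bandwidth

print(f"one MXU: {one_mxu:.1f} TFLOP/s, v5e ridge point: {v5e_ridge:.1f}")
```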
You run a layer normalization on a tensor of shape [2048, 8192] in BF16 on a TPU v5e (819 GB/s bandwidth). LayerNorm reads the input once, writes the output once, plus reads/writes for mean and variance (negligible for large tensors). Approximate the minimum execution time.
Hint: LayerNorm is memory-bound. Total bytes ≈ 2 × tensor size (one full read plus one full write).
In practice it will be longer due to kernel launch overhead and the fact that LayerNorm also computes a reduction (mean, variance), but the memory transfer dominates. This is why fusing LayerNorm with the preceding or following operation is one of the most impactful optimizations — you eliminate one full round-trip to HBM.
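For a memory-bound op the lower bound on time is simply bytes moved divided by bandwidth. A minimal sketch, assuming exactly one read and one write of the tensor as in the hint:

```python
def memory_bound_time_s(num_elems, bytes_per_elem, bandwidth, passes=2):
    """Lower bound on runtime when every element crosses HBM `passes` times."""
    return passes * num_elems * bytes_per_elem / bandwidth


t = memory_bound_time_s(2048 * 8192, 2, 819e9)  # BF16 tensor on TPU v5e
print(f"LayerNorm lower bound: {t * 1e6:.1f} microseconds")
```

Fusing LayerNorm into a neighboring op removes one of the two passes, which is where the saving described above comes from.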
Fill in the comparison table from memory, then check your answers.
| Spec | TPU v5e | H100 SXM |
|---|---|---|
| BF16 peak TFLOP/s | ~197 | ~990 |
| HBM bandwidth | ~819 GB/s | ~3,350 GB/s |
| HBM capacity | 16 GB | 80 GB |
| Ridge point | ~240 | ~296 |
| Interconnect | ICI (4800 Gbps) | NVLink (900 GB/s) |
At what model size (in BF16 parameters) does a single TPU v5e run out of memory?
A TPU v5e can hold at most an 8B model in weights alone (no KV cache, no activations). In practice, inference of even a 7B model requires careful memory management. This is why TPU pods aggregate many chips — a v5e pod with 256 chips has 4 TB of total HBM.
A single accelerator can only hold so much data in HBM. A 70B parameter model in BF16 needs 140 GB just for weights, far more than the 80 GB on an H100. The solution: shard the matrices across multiple devices and coordinate the computation.
There are three fundamental ways to shard a matmul C = A × B where A is [M, K] and B is [K, N]:
| Strategy | Shard A along | Shard B along | Communication needed |
|---|---|---|---|
| Row-parallel | M (rows) | Replicated | None (outputs are independent) |
| Column-parallel | Replicated | N (columns) | None (outputs are independent) |
| K-sharded (inner dim) | K (columns of A) | K (rows of B) | All-reduce on partial sums |
You have a weight matrix W of shape [8192, 32768] and input X of shape [2048, 8192]. You want to compute Y = X × W across 8 devices with minimum communication. Which sharding minimizes communication?
Column-parallel: replicate X (2048×8192×2 bytes = 32 MiB per device), shard W into 8 chunks of [8192, 4096] (64 MiB each). Each device computes Y_chunk = X × W_chunk with no communication. Output Y is already distributed column-wise.
Row-parallel: shard X into rows (each device gets [256, 8192]), replicate W (8192×32768×2 bytes = 512 MiB per device!). Wastes memory replicating the huge weight matrix.
K-sharding: requires an all-reduce. Bad for minimum communication.
Column-parallel wins: zero communication, and the smaller matrix (X) is the one replicated.
You shard a [8192, 8192] × [8192, 8192] BF16 matmul along K across 8 devices. Each device computes a partial [8192, 8192] result that must be summed via all-reduce. Using the ring all-reduce algorithm, how many bytes does each device send in total?
Hint: Ring all-reduce sends 2 × (P-1)/P × tensor_size bytes per device, where P = number of devices.
The ring all-reduce has two phases: reduce-scatter (each device sends (P-1)/P of the data) and all-gather (same volume again). Total = 2 × (P-1)/P × data_size. As P grows, this approaches 2× the tensor size — the lower bound for any all-reduce algorithm.
Using the same K-sharded matmul from 2.2 on 8 devices, each device computes a [8192, 1024] × [1024, 8192] matmul. The all-reduce transfers 224 MiB per device. If each device runs the matmul in 0.139 ms (2 × 8192 × 1024 × 8192 ≈ 1.37 × 10^11 FLOPs, compute-bound at ~990 TFLOP/s) and the interconnect has 450 GB/s bidirectional bandwidth, what is the communication time? Is the overall operation communication-bound?
Communication takes about 0.52 ms, nearly 4× longer than the computation. This is why K-sharding is avoided when possible: the all-reduce dominates unless it can be overlapped with other work. Column-parallel or row-parallel sharding avoids this entirely for individual layers, at the cost of replicating one matrix.
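The ring all-reduce volume and the resulting communication time can be computed directly. This sketch recomputes the per-device matmul time from its FLOP count rather than taking it as given; the bandwidth and peak figures are the ones quoted in the exercise.

```python
def ring_allreduce_bytes(tensor_bytes, p):
    """Bytes each device sends in a ring all-reduce over p devices."""
    return 2 * (p - 1) / p * tensor_bytes


tensor_bytes = 8192 * 8192 * 2                      # [8192, 8192] partial result in BF16
sent = ring_allreduce_bytes(tensor_bytes, 8)
comm_ms = sent / 450e9 * 1e3                        # 450 GB/s interconnect
compute_ms = 2 * 8192 * 1024 * 8192 / 990e12 * 1e3  # per-device K-shard matmul

print(f"sent = {sent / 2**20:.0f} MiB, comm = {comm_ms:.3f} ms, "
      f"compute = {compute_ms:.3f} ms, ratio = {comm_ms / compute_ms:.1f}x")
```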
Consider one transformer layer of LLaMA 3 70B. The layer has ~856M parameters in BF16 = 1.71 GB of weights. Compare two strategies on 8 devices:
(A) FSDP: all-gather the full layer weights before forward. Volume = (P-1)/P × layer_size per device.
(B) TP: two all-reduces per layer (attention + MLP) on output tensors of shape [batch×seq, d]. With batch=4, seq=4096, d=8192, BF16.
Which transfers fewer bytes per device?
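FSDP all-gather: (P-1)/P × layer_size = (7/8) × 1.71 GB ≈ 1.50 GB per device.
TP all-reduce (per operation): the output tensor is 4 × 4096 × 8192 × 2 bytes ≈ 268 MB, so a ring all-reduce sends 2 × (7/8) × 268 MB ≈ 470 MB per device; two per layer gives ≈ 0.94 GB.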
FSDP transfers 1.50 GB, TP transfers 0.94 GB. TP wins on volume for this config. But TP must happen over fast NVLink (within a node), while FSDP can use slower inter-node links since it is less latency-sensitive (one big transfer vs many small ones). This is why TP is used within nodes and FSDP across nodes.
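The comparison is easy to rerun for other batch sizes or sequence lengths. A minimal sketch using the layer size and shapes from the exercise (variable names are illustrative):

```python
P = 8                         # devices
layer_bytes = 856e6 * 2       # ~856M parameters in BF16

# FSDP: all-gather the layer's weights once before the forward pass.
fsdp_bytes = (P - 1) / P * layer_bytes

# TP: two ring all-reduces per layer on a [batch * seq, d] activation tensor.
act_bytes = 4 * 4096 * 8192 * 2
tp_bytes = 2 * (2 * (P - 1) / P * act_bytes)

print(f"FSDP: {fsdp_bytes / 1e9:.2f} GB, TP: {tp_bytes / 1e9:.2f} GB per device")
```

Note that the TP volume grows with batch × seq while the FSDP volume does not, so the winner flips once the activation tensor gets large enough.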
Before you can estimate training cost, inference memory, or sharding strategy, you need to count parameters precisely. No approximations, no hand-waving. Every weight matrix, every bias, every embedding. This is the foundation of all scaling analysis.
One transformer block contains a self-attention layer and an MLP (feed-forward) layer, each with layer norms. Here is the parameter breakdown:
GPT-2 (small): 12 layers, d=768, 12 heads, d_head=64, d_ff=3072 (4×d), vocab=50257. Standard MHA (not GQA), standard MLP (not SwiGLU, so 2 projections not 3), with biases on all linear layers, and separate LN weight+bias.
Count total parameters. Include: 4 attention projections + biases, 2 MLP projections + biases, 2 LayerNorms (weight+bias each), token embedding, position embedding (1024×768), and final LN.
Per transformer block:
Global:
OpenAI reported 124M. Note that GPT-2 ties the output head with the token embedding, so no extra parameters there.
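The count can be reproduced term by term. This sketch follows the inventory listed in the exercise (biases everywhere, learned position embeddings, tied output head); the function name is hypothetical.

```python
def gpt2_params(n_layers=12, d=768, d_ff=3072, vocab=50257, n_ctx=1024):
    """Parameter count for a GPT-2 style model with biases and a tied output head."""
    attn = 4 * (d * d + d)                    # Q, K, V, O projections + biases
    mlp = (d * d_ff + d_ff) + (d_ff * d + d)  # up and down projections + biases
    norms = 2 * (2 * d)                       # two LayerNorms, weight + bias each
    block = attn + mlp + norms
    embeddings = vocab * d + n_ctx * d        # token + position embeddings
    final_ln = 2 * d
    return n_layers * block + embeddings + final_ln


print(f"{gpt2_params():,} parameters")
```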
LLaMA 3 70B: 80 layers, d=8192, 64 attention heads, 8 KV heads (GQA), d_head=128, d_ff=28672 (SwiGLU, so 3 projections), vocab=128256. No biases anywhere. RoPE (0 params). RMSNorm (weight only, no bias).
Count total parameters in billions.
Per block attention:
Per block MLP (SwiGLU):
Per block norms:
Total per block:
Global:
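The same bookkeeping for LLaMA 3 70B, as a sketch. It assumes an untied input embedding and output head plus one final RMSNorm, which is what makes the total land near the 70.6B used elsewhere on this page; the function name is illustrative.

```python
def llama_params(n_layers=80, d=8192, n_heads=64, n_kv_heads=8,
                 d_head=128, d_ff=28672, vocab=128256):
    """Parameter count for a GQA + SwiGLU transformer with no biases."""
    q_and_o = 2 * d * n_heads * d_head
    k_and_v = 2 * d * n_kv_heads * d_head
    mlp = 3 * d * d_ff                # gate, up, down
    norms = 2 * d                     # two RMSNorms per block, weight only
    block = q_and_o + k_and_v + mlp + norms
    # ASSUMPTION: untied embedding and output head, plus a final RMSNorm.
    global_params = 2 * vocab * d + d
    return n_layers * block + global_params


total = llama_params()
print(f"{total:,} parameters ({total / 1e9:.2f}B)")
```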
The standard approximation for forward pass FLOPs per token is 2N (where N = number of parameters), because each parameter participates in one multiply-add (2 FLOPs). For LLaMA 3 70B (N ≈ 70.6B), how many FLOPs per token in the forward pass?
This approximation counts matmul FLOPs only. Attention QK^T and softmax add roughly 2 × n_layers × seq_len × d FLOPs per token, which for seq=8192 is ~10^10, about 7-8% of the 2N term. For back-of-envelope calculations, 2N is sufficient.
The standard training FLOP estimate is 6ND: forward (2N) + backward (4N, because gradients require ~2× the forward FLOPs), times D tokens. LLaMA 3 70B was trained on 15T tokens. How many total FLOPs?
That is about 6.35 yottaFLOPs (1 yotta = 10^24). For context, common estimates put the number of stars in the observable universe at 10^22 to 10^24, so the training run performs at least as many floating-point operations as there are stars.
LLaMA 3 70B: 80 layers, GQA with 8 KV heads, d_head=128, batch=32, seq=8192, BF16. How many GB for the KV cache?
Formula: 2 (K and V) × n_layers × n_kv_heads × d_head × seq_len × batch × bytes_per_element
GQA helps enormously here: the cache is about 85.9 GB (80 GiB). Without it (64 KV heads instead of 8), the cache would be about 687 GB, 8× larger. The model weights alone are 141 GB in BF16, so even with GQA the KV cache at this batch size and sequence length is more than half the size of the weights.
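The formula translates directly into code, which makes it easy to see the per-token and per-sequence cost as well as the total. A minimal sketch with an illustrative function name:

```python
def kv_cache_bytes(n_layers, n_kv_heads, d_head, seq_len, batch, bytes_per_elem=2):
    """KV cache size: K and V for every layer, KV head, position, and sequence."""
    return 2 * n_layers * n_kv_heads * d_head * seq_len * batch * bytes_per_elem


gqa = kv_cache_bytes(80, 8, 128, 8192, 32)
mha = kv_cache_bytes(80, 64, 128, 8192, 32)
per_token = kv_cache_bytes(80, 8, 128, 1, 1)

print(f"GQA: {gqa / 1e9:.1f} GB, MHA: {mha / 1e9:.0f} GB, "
      f"per token: {per_token / 1024:.0f} KiB")
```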
You are training LLaMA 3 70B on 2048 H100s. You measure 3200 tokens/second/GPU. Each H100 has 990 TFLOP/s BF16 peak. What is your Model FLOPs Utilization (MFU)?
MFU = (actual model FLOPs/s per GPU) / (peak FLOPs/s per GPU). Model FLOPs per token = 6N (forward + backward).
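With 6N FLOPs per token, 3200 tokens/s/GPU would imply 3200 × 6 × 70.6 × 10^9 ≈ 1,355 TFLOP/s per GPU, which is 137% of the 990 TFLOP/s peak and therefore physically impossible. The stated throughput is only consistent with the hardware if you count forward-only FLOPs (2N) per token: 3200 × 2 × 70.6 × 10^9 ≈ 452 TFLOP/s, or about 46% of peak. Under the 6N convention, 40% MFU on this hardware corresponds to roughly 935 tokens/s/GPU.
Note: the exact MFU definition varies. Training reports usually use 6N (forward plus backward); some use 2N (forward only), so always check which one a number refers to. Meta reported ~38-43% MFU for LLaMA 3 training. The key point: anything above 40% on a large distributed run is excellent. The gap to 100% comes from communication overhead, pipeline bubbles, memory-bound operations (layer norm, softmax), and suboptimal occupancy.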
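Because the definition matters so much, it helps to make the FLOPs-per-parameter convention an explicit argument. A small sketch (the function name and parameter are illustrative):

```python
def mfu(tokens_per_s_per_gpu, n_params, peak_flops, flops_per_param=6):
    """Model FLOPs utilization; flops_per_param is 6 for training, 2 for forward-only."""
    return tokens_per_s_per_gpu * flops_per_param * n_params / peak_flops


train_convention = mfu(3200, 70.6e9, 990e12, flops_per_param=6)
forward_only = mfu(3200, 70.6e9, 990e12, flops_per_param=2)

print(f"6N convention: {train_convention:.0%}, 2N convention: {forward_only:.1%}")
```

A result above 100% is a useful sanity check: it signals that the throughput figure and the FLOP convention do not belong together.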
You cannot train a 70B model on one GPU. You need parallelism — splitting the work across many devices. There are four main flavors, and real training runs combine all of them.
| Strategy | What's split | Communication | Memory saving |
|---|---|---|---|
| Data Parallel (DP) | Batch | All-reduce gradients | None (each GPU has full model) |
| FSDP (ZeRO-3) | Batch + params/grads/optimizer | All-gather params, reduce-scatter grads | ~P× (P = # GPUs) |
| Tensor Parallel (TP) | Individual weight matrices | All-reduce per layer | ~TP_degree× |
| Pipeline Parallel (PP) | Layers across stages | Point-to-point activations | ~PP_degree× |
Memory per GPU: model_params + gradients + optimizer_states + activations. With Adam in BF16 mixed precision: params (2 bytes) + grads (2 bytes) + optimizer (master weights 4 bytes + momentum 4 bytes + variance 4 bytes = 12 bytes) = 16 bytes per parameter, plus activation memory. FSDP shards all of these across GPUs.

70B model, Adam optimizer, BF16 training. Without FSDP, per-GPU memory for params+grads+optimizer = 16 bytes × N. With FSDP across 64 GPUs, how much memory per GPU for these components?
This is just model state. Activation memory is additional and depends on sequence length, batch size, and whether you use activation checkpointing. An H100 has 80 GB HBM, so 17.5 GB for model state leaves ~62 GB for activations — comfortable.
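The model-state arithmetic is a one-liner once the bytes per parameter are fixed. This sketch assumes the 16 bytes per parameter breakdown given above and ignores activations entirely.

```python
BYTES_PER_PARAM = 2 + 2 + 12  # BF16 params + BF16 grads + FP32 Adam state


def model_state_gb(n_params, n_shards=1):
    """Per-GPU memory for params, grads, and optimizer state, in GB."""
    return n_params * BYTES_PER_PARAM / n_shards / 1e9


unsharded = model_state_gb(70e9)
sharded = model_state_gb(70e9, n_shards=64)
print(f"no FSDP: {unsharded:.0f} GB, FSDP over 64 GPUs: {sharded:.1f} GB")
```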
In tensor-parallel attention with TP=8, each device computes 8 of 64 attention heads for LLaMA 3 70B. The output projection produces a partial result that requires an all-reduce. The output tensor per token has shape [1, 8192] in BF16. For a microbatch of 4 sequences, each of length 8192, how many bytes are transferred in the all-reduce?
Hint: all-reduce on the output of shape [batch × seq, d]. Use the ring formula: 2 × (P-1)/P × tensor_bytes.
For one all-reduce, the tensor is 4 × 8192 × 8192 × 2 bytes = 512 MiB, so each device sends 2 × (7/8) × 512 MiB = 896 MiB (448 MiB in the reduce-scatter phase and 448 MiB in the all-gather phase). A full transformer block has TWO such all-reduces in the forward pass, one after the attention output projection and one after the MLP down projection. Each of those layers is sharded along its input (K) dimension, which Megatron-style TP calls row-parallel and which produces partial sums; each is paired with a column-parallel layer in front of it that needs no communication.
Per block that is 2 × 896 MiB ≈ 1.75 GiB sent per device. This is why TP is typically done within a single node over NVLink (900 GB/s) rather than across nodes.
Pipeline parallelism with PP=4 stages. If you use m microbatches, the bubble fraction (wasted time) is (PP-1)/(m+PP-1). How many microbatches do you need to keep the bubble fraction below 5%?
Solving 3/(m+3) < 0.05 gives m > 57: m = 57 yields exactly 5%, so you need at least 58 microbatches to get strictly below it. This means your global batch size must be at least 58 × microbatch_size. This is why PP works best with very large global batch sizes, and why modern systems prefer FSDP + TP over heavy PP when possible.
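The boundary case is easy to get wrong by one, so here is a sketch that searches for the smallest microbatch count whose bubble is strictly below a target. It uses exact fractions to avoid floating-point ties; the helper names are illustrative.

```python
from fractions import Fraction


def bubble_fraction(pp, m):
    """Idle fraction of a pipeline with pp stages and m microbatches."""
    return Fraction(pp - 1, m + pp - 1)


def min_microbatches(pp, target):
    """Smallest m whose bubble fraction is strictly below target."""
    m = 1
    while bubble_fraction(pp, m) >= target:
        m += 1
    return m


m = min_microbatches(4, Fraction(5, 100))
print(m, float(bubble_fraction(4, m)))
```

With 57 microbatches the bubble is exactly 5%, so the strict inequality pushes the answer to the next integer.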
Meta trained LLaMA 3 70B on 2048 H100s. Given: TP must stay within a node (8 GPUs per node), PP adds bubbles, FSDP handles the rest. Which configuration makes sense?
Option B (TP=8, PP=4, FSDP=64) is closest to what Meta actually used. Here is why:
TP=8: Uses all 8 GPUs within a node via NVLink (900 GB/s). Maximum bandwidth for TP all-reduces.
PP=4: Splits 80 layers into 4 stages of 20 layers. Reduces activation memory. With enough microbatches (Meta used ~100), bubble fraction is ~3%.
FSDP=64: 2048 / (8 × 4) = 64 FSDP groups handle the remaining parallelism with gradient communication over the inter-node network.
Option A (no PP) would work but requires more activation memory per stage. Option C (pure FSDP) has too much inter-node communication. Option D (TP=2) wastes NVLink bandwidth.
Activation memory per layer (no activation checkpointing): you must store the input to each layer for the backward pass. For LLaMA 3 70B with batch=4, seq=4096, d=8192, BF16: how much activation memory per layer?
Simplified: store the input tensor [batch × seq, d] plus the attention intermediate [batch, n_heads, seq, seq] (for the backward through attention).
The attention score matrix dominates! For 80 layers: 80 × 8.25 = 660 GB of activation memory — impossible without mitigation. Solutions: (1) Activation checkpointing: recompute activations in the backward pass instead of storing them. Trades ~33% extra compute for ~80% memory savings. (2) FlashAttention: never materializes the full attention matrix, saving the 8 GB/layer term entirely. (3) Sequence parallelism: shard the activation across the TP dimension.
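The two terms of the simplified activation estimate can be computed separately to see which one dominates. A minimal sketch under the same simplification (layer input plus materialized attention scores only):

```python
def activation_bytes_per_layer(batch, seq, d, n_heads, bytes_per_elem=2):
    """Return (layer input bytes, attention score bytes) for one layer."""
    layer_input = batch * seq * d * bytes_per_elem
    attn_scores = batch * n_heads * seq * seq * bytes_per_elem
    return layer_input, attn_scores


inp, scores = activation_bytes_per_layer(batch=4, seq=4096, d=8192, n_heads=64)
per_layer_gib = (inp + scores) / 2**30
print(f"input: {inp / 2**30:.2f} GiB, scores: {scores / 2**30:.2f} GiB, "
      f"80 layers: {80 * per_layer_gib:.0f} GiB")
```

The score term grows with seq^2, which is why it dwarfs the input term at long sequence lengths.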
Now we combine everything: parameter count, FLOPs, hardware specs, and utilization to answer the question every ML lead asks: "How long will this take, and how much will it cost?"
Training time = 6 × N × D / (n_GPUs × peak_FLOPs × MFU), where N = parameters, D = tokens, n_GPUs = GPU count, peak_FLOPs = per-GPU peak FLOP/s, MFU = model FLOPs utilization (typically 0.35-0.50).
LLaMA 3 70B: N=70.6B, D=15T tokens, 2048 H100s at 990 TFLOP/s peak BF16, MFU=0.40. How many days to train?
Meta's model card lists about 6.4M H100 GPU-hours for Llama 3 70B. Our estimate works out to 2048 × 90.7 × 24 ≈ 4.5M GPU-hours, so it is in the right ballpark; the gap is what you would expect if the effective MFU over the whole run, including restarts and other overheads, was below 0.40.
Using your answer from 5.1 (~91 days), at $2/GPU-hour for H100 cloud pricing, what is the total training cost?
Roughly $9M. In practice, costs are higher due to failed runs, hyperparameter sweeps, evaluation compute, data preprocessing, and the cost of the cluster sitting idle during maintenance. A realistic total project cost is 2-3× the pure training compute cost, so $20-30M for a 70B model.
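Both the time and the cost estimate follow from the same formula. A minimal sketch using the exercise's inputs; the price per GPU-hour is the assumed $2 from the question, not a quoted market rate.

```python
def training_days(n_params, n_tokens, n_gpus, peak_flops, mfu):
    """Wall-clock days for 6*N*D FLOPs at the given utilization."""
    seconds = 6 * n_params * n_tokens / (n_gpus * peak_flops * mfu)
    return seconds / 86400


days = training_days(70.6e9, 15e12, 2048, 990e12, 0.40)
gpu_hours = days * 24 * 2048
cost = gpu_hours * 2.0  # ASSUMED $2 per GPU-hour

print(f"{days:.1f} days, {gpu_hours / 1e6:.2f}M GPU-hours, ${cost / 1e6:.1f}M")
```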
The Chinchilla scaling law says the compute-optimal number of tokens D is approximately 20 × N. For a 70B model, what is the Chinchilla-optimal token count?
But Meta trained on 15T tokens — more than 10× the Chinchilla optimal! Why? Because Chinchilla optimizes for training compute efficiency (best model per FLOP). But inference cost depends on model size. A smaller model trained on more tokens (the "over-trained" regime) gives better inference cost per quality. LLaMA's philosophy: spend more on training (one-time cost) to get a smaller, cheaper-to-serve model. This is the inference-optimal scaling regime.
You have a fixed compute budget C = 10^22 FLOPs. Under Chinchilla scaling (C = 6ND, D = 20N), what is the compute-optimal model size?
Hint: Substitute D = 20N into C = 6ND.
A 9B model trained on 183B tokens (20 × 9.13B). This is remarkably close to LLaMA 3 8B, which Meta trained on 15T tokens (way past Chinchilla-optimal, for inference efficiency).
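Substituting D = 20N into C = 6ND gives C = 120 N^2, so N is a square root away. A small sketch, with the tokens-per-parameter ratio left as a parameter since 20 is only the rule-of-thumb value:

```python
import math


def chinchilla_optimal(compute_flops, tokens_per_param=20):
    """Return (params, tokens) that satisfy C = 6*N*D with D = tokens_per_param * N."""
    n = math.sqrt(compute_flops / (6 * tokens_per_param))
    return n, tokens_per_param * n


n, d = chinchilla_optimal(1e22)
print(f"N = {n / 1e9:.2f}B parameters, D = {d / 1e9:.1f}B tokens")
```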
The standard scaling rule for learning rate with batch size is: when you double the batch size, you can increase the LR by √2 (square root scaling). If your base LR is 3 × 10^-4 at batch=256, what LR should you use at batch=2048?
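The batch size grows 8×, so the LR grows by √8 ≈ 2.83, to about 8.5 × 10^-4. The square-root rule is usually attributed to Krizhevsky (2014) and Hoffer et al. (2017); "Don't Decay the Learning Rate, Increase the Batch Size" (Smith et al., 2018) instead relates batch size and learning rate proportionally, which is the linear rule. Linear scaling (multiply LR by 8) works too in some regimes but is more aggressive. In practice, most frontier labs use a cosine LR schedule with warmup, where the peak LR is chosen via hyperparameter sweep.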
Inference has two distinct phases with completely different performance characteristics. Prefill processes the entire prompt in parallel — it is a batch matmul, compute-bound. Decode generates one token at a time — it is a matrix-vector product per layer, memory-bound.
70B model on a single H100 (990 TFLOP/s BF16). Using 2N FLOPs per token for prefill, what is the theoretical prefill throughput?
In practice, you get 50-70% of this due to attention overhead, memory-bound operations within the forward pass, and kernel launch overhead. So ~3,500-5,000 tokens/s is realistic for 70B on a single H100.
You send a 4096-token prompt to a 70B model on a single H100. Assuming realistic prefill throughput of 4000 tokens/s, what is the time to first token (TTFT)?
Over a second to start responding! This is why long-context applications use chunked prefill (process the prompt in chunks to interleave with decoding requests from other users) and why prefix caching (reusing KV cache for common system prompts) is so important.
70B model in BF16 on a single H100 (3.35 TB/s bandwidth). At batch_size=1, the decode step reads all model weights once per token. What is the decode speed (tokens/s) and inter-token latency?
At batch_size=1, decode is entirely memory-bandwidth-bound — the MXU/Tensor Cores sit mostly idle! This is why batching is critical: at batch_size=32, you read the weights once but generate 32 tokens, giving 32 × 23.7 = 758 tokens/s total throughput (though per-user latency stays ~42 ms until you become compute-bound).
Disaggregated serving uses separate GPU pools for prefill and decode. Prefill is compute-bound, decode is memory-bound. At what batch size does decode become compute-bound on an H100 with 70B BF16?
Hint: decode becomes compute-bound when the compute time (batch_size × 2N FLOPs / peak FLOP/s) exceeds the time to read the weights (2N bytes in BF16 / bandwidth).
The crossover point is the ridge point! At the crossover, FLOPs per byte equals the H100's ridge point.
At batch_size ~296, decode becomes compute-bound! Below this, memory-bound decode underutilizes the tensor cores. This is the sweet spot where disaggregated serving wins: prefill GPUs run at high utilization (always compute-bound), while decode GPUs can be cheaper (memory-optimized) or run at smaller batch sizes for lower latency. In practice, KV cache memory limits batch size long before 296.
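The decode exercises reduce to comparing two times per step: reading the weights and doing the math. This sketch ignores KV cache traffic and attention FLOPs, which is the same simplification the exercises make; the constants are the ones quoted above.

```python
PEAK = 990e12          # H100 BF16 FLOP/s, as quoted in the text
BW = 3.35e12           # H100 HBM bytes/s, as quoted in the text
WEIGHT_BYTES = 141e9   # 70B-class model in BF16
N_PARAMS = WEIGHT_BYTES / 2


def decode_step_s(batch):
    """Time for one decode step: the slower of reading weights and computing."""
    memory_s = WEIGHT_BYTES / BW
    compute_s = batch * 2 * N_PARAMS / PEAK
    return max(memory_s, compute_s)


latency_ms = decode_step_s(1) * 1e3
tokens_per_s = 1 / decode_step_s(1)
crossover_batch = WEIGHT_BYTES * PEAK / (2 * N_PARAMS * BW)

print(f"{tokens_per_s:.1f} tok/s, {latency_ms:.1f} ms/token, "
      f"compute-bound above batch ~{crossover_batch:.0f}")
```

Below the crossover the step time is flat, so every extra sequence in the batch is nearly free throughput.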
With static batching (batch=32, all sequences decode together), if the shortest response is 50 tokens and the longest is 500, the GPU sits idle for finished sequences. Average output length is 200 tokens. What is the effective utilization compared to continuous batching (which immediately fills empty slots)?
Hint: In static batching, all sequences run for max_length steps. GPU utilization = average_length / max_length.
60% of GPU decode steps are wasted on padding! Continuous batching (introduced by the Orca paper) immediately ejects completed sequences and inserts new ones, achieving near-100% slot utilization. This alone can give a 2-2.5× throughput improvement. Combined with PagedAttention (which eliminates KV cache fragmentation), this is the foundation of modern serving systems like vLLM.
Training is a one-time expense. Serving is forever. The cost to serve one million tokens determines whether your model is a viable product or a science experiment. Let's do the math.
70B model in BF16 on 2× H100s (to fit in memory). Achievable decode throughput at batch=64: ~1500 tokens/s. GPU utilization: 70% (includes time waiting for requests). Cost: $4/hour for the 2-GPU setup. What is the cost per million output tokens?
For context, OpenAI charges $4-15/1M output tokens for GPT-4 class models. The hardware cost is ~$1, so the API price includes engineering costs, safety, margins, and the amortized training cost.
You need to serve at $0.50/1M tokens to be competitive. Same setup as 7.1 ($4/hour for 2× H100, 70% utilization). At batch=1, throughput is ~24 tok/s. Throughput scales linearly with batch until the compute-bound crossover. What minimum batch size do you need?
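You need a batch size of at least 133. But remember: each sequence in the batch needs its own KV cache. At seq_len=8192 with GQA, that is 2 × 80 × 8 × 128 × 8192 × 2 bytes ≈ 2.68 GB per sequence, so 133 × 2.68 GB ≈ 357 GB of KV cache alone, on top of the 141 GB of model weights. That does not fit on two 80 GB H100s, so in practice this batch size is only reachable with shorter sequences, more GPUs, or a compressed KV cache. This is why high-batch serving requires careful memory management (PagedAttention, KV cache offloading).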
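The cost-per-token and break-even calculations share one helper. A minimal sketch with the hourly price, throughput, and utilization from the exercises; the linear throughput scaling with batch size is the exercise's stated assumption.

```python
import math


def cost_per_million(dollars_per_hour, tokens_per_s, utilization):
    """Hardware cost in dollars per one million generated tokens."""
    tokens_per_hour = tokens_per_s * utilization * 3600
    return dollars_per_hour / tokens_per_hour * 1e6


def min_batch(target_cost, dollars_per_hour, tok_s_at_batch1, utilization):
    """Smallest batch meeting target_cost, ASSUMING throughput scales linearly with batch."""
    needed_tok_s = dollars_per_hour / target_cost * 1e6 / 3600 / utilization
    return math.ceil(needed_tok_s / tok_s_at_batch1)


baseline = cost_per_million(4.0, 1500, 0.70)
batch = min_batch(0.50, 4.0, 24, 0.70)
print(f"${baseline:.2f} per 1M tokens; batch >= {batch} for $0.50 per 1M")
```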
Speculative decoding uses a small "draft" model to propose K tokens, then the large model verifies all K in one forward pass. If the draft model is 10× faster but its acceptance rate is α, the speedup is approximately Kα/(1 + K/speedup_ratio). With K=5, draft model 10× faster, what acceptance rate α do you need for a 2× overall speedup?
Simplified: expected accepted tokens per step = Kα. Cost per step = 1 (verify) + K/10 (draft). Speedup = Kα / (1 + K/10). Set this ≥ 2.
You need at least 60% acceptance rate. In practice, well-matched draft models (like a distilled version of the target) achieve 70-85% acceptance on typical text. Acceptance drops on code, math, and creative writing where the small model diverges more from the large model's distribution.
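The simplified speedup model above can be inverted for the required acceptance rate. This sketch implements exactly that simplified model (expected accepted tokens Kα, cost 1 + K/ratio); it is not a full model of speculative decoding.

```python
def spec_speedup(k, alpha, draft_speed_ratio):
    """Speedup under the simplified model: k*alpha tokens per (1 + k/ratio) cost."""
    return k * alpha / (1 + k / draft_speed_ratio)


def min_acceptance(k, draft_speed_ratio, target_speedup):
    """Acceptance rate needed to reach target_speedup under the same model."""
    return target_speedup * (1 + k / draft_speed_ratio) / k


alpha = min_acceptance(k=5, draft_speed_ratio=10, target_speedup=2)
print(f"alpha >= {alpha:.2f}, check: {spec_speedup(5, alpha, 10):.2f}x")
```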
API providers charge differently for input (prefill) and output (decode) tokens. For a 70B model on H100:
Prefill throughput: ~4000 tok/s (compute-bound, high GPU utilization).
Decode throughput at batch=64: ~1500 tok/s (memory-bound).
Same $4/hour GPU cost, 70% utilization for both. What is the cost per 1M input tokens vs output tokens?
Output tokens cost ~2.7× more than input tokens. This is why every API provider charges more for output: prefill is compute-bound (high utilization, fast), decode is memory-bandwidth-bound (lower throughput). This asymmetry also explains why prompt caching is valuable — you can amortize the cheap prefill cost across many users with the same system prompt.
You quantize the 70B model from BF16 to INT4 (4 bits per weight). This cuts the model size to a quarter (from 141 GB to ~35 GB), so it fits on a single H100. Decode throughput scales with bandwidth/model_size: approximately 3.35 TB/s / 35 GB ≈ 95 tokens/s at batch=1, or ~6000 tok/s at batch=64. GPU cost is now $2/hour (single GPU). What is the new $/1M output tokens?
INT4 quantization reduces serving cost by ~8× (from $1.06 to $0.13 per 1M tokens). This comes from two factors: (1) 4× less data to read from memory per token (faster decode), and (2) the model fits on one GPU instead of two (halved hardware cost). This is why quantization is the single highest-impact optimization for LLM serving.
A profile tells you where time goes. Without a profile, optimization is guessing. With one, you can calculate exactly how much speedup is possible and where to focus.
Here is a simplified profile of one transformer layer's forward pass on an H100, processing batch=32, seq=2048:
| Kernel | Time (ms) | Category |
|---|---|---|
| QKV projection (fused gemm) | 0.42 | Compute |
| FlashAttention fwd | 0.38 | Compute |
| Output projection | 0.14 | Compute |
| LayerNorm + residual add | 0.12 | Memory |
| MLP up + gate (fused gemm) | 0.51 | Compute |
| SiLU activation | 0.08 | Memory |
| MLP down projection | 0.26 | Compute |
| LayerNorm + residual add | 0.12 | Memory |
| NCCL all-reduce (TP) | 0.35 | Communication |
Total: 2.38 ms per layer. What is the biggest optimization opportunity?
The memory-bound ops (2× LayerNorm + residual = 0.24 ms, SiLU = 0.08 ms) total 0.32 ms — 13.4% of the layer time. These can often be fused into the adjacent gemm kernels, effectively eliminating them. FlashAttention is already heavily optimized. NCCL can be overlapped with computation. Fusing the elementwise/normalization ops is the highest-ROI optimization with the least risk.
The MLP up+gate projection is a [batch×seq, d] × [d, 2×d_ff] matmul with batch=32, seq=2048, d=8192, d_ff=28672. BF16 on H100. What is the theoretical minimum time (compute-bound at 990 TFLOP/s)?
On a single device the matmul is 2 × 65,536 × 8,192 × 57,344 ≈ 6.16 × 10^13 FLOPs, about 62 ms at peak. With TP=8, each device computes 1/8 of the N dimension, about 7.8 ms.
The profile reports 0.51 ms, which at 990 TFLOP/s corresponds to at most about 5 × 10^11 FLOPs, roughly 15× less than the per-device work above. The profile therefore cannot describe the full batch=32 × seq=2048 workload on 8-way TP; it implies a much smaller effective batch per device (for example one microbatch inside a pipeline schedule, or a different TP degree).
The exercise point: always compare measured kernel time against the theoretical minimum (FLOPs / peak). If measured is >2× theoretical, there is optimization headroom. If measured is within 1.2×, the kernel is near-optimal.
An H100 has 132 SMs, each supporting up to 2048 threads (64 warps of 32 threads). Your kernel uses 128 threads per block and 48 KB of shared memory per block. Each SM has 228 KB of shared memory. How many blocks can run per SM, and what is the occupancy?
Shared memory limit: 228 KB / 48 KB = 4.75 → 4 blocks per SM
Thread limit: 2048 / 128 = 16 blocks per SM
Limiting factor: shared memory → 4 blocks per SM
25% occupancy is low. To improve: reduce shared memory per block (use tiling to process in smaller chunks), increase threads per block (if the algorithm allows), or restructure the kernel to use less shared memory. Note that low occupancy is not always bad — if the kernel is compute-bound and achieves high IPC, occupancy matters less than if it is memory-bound.
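The occupancy calculation is a min over per-resource limits. This sketch considers only the two limits named in the exercise (shared memory and threads); real kernels are also constrained by registers and by a hardware cap on resident blocks per SM, which are deliberately left out here.

```python
def occupancy(threads_per_block, smem_per_block_kb,
              max_threads_per_sm=2048, smem_per_sm_kb=228):
    """Return (resident blocks per SM, occupancy) from thread and shared-memory limits."""
    by_smem = smem_per_sm_kb // smem_per_block_kb
    by_threads = max_threads_per_sm // threads_per_block
    blocks = min(by_smem, by_threads)
    return blocks, blocks * threads_per_block / max_threads_per_sm


blocks, occ = occupancy(threads_per_block=128, smem_per_block_kb=48)
print(f"{blocks} blocks per SM, occupancy = {occ:.0%}")
```

Rerunning with 256 threads per block and the same shared memory doubles the occupancy, which illustrates the "increase threads per block" suggestion.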
From the profile in Exercise 8.1, the total layer time is 2.38 ms. If you perfectly fuse all memory-bound ops (0.32 ms total) into adjacent compute kernels, eliminating them entirely, what is the speedup for the full layer?
A 15.5% speedup from fusing element-wise ops. This is Amdahl's Law in action: removing 13.4% of the execution time gives a speedup of 1 / (1 − 0.134) ≈ 1.155×, because the remaining 86.6% is unchanged. For a model with 80 layers, this saves 80 × 0.32 ms = 25.6 ms per forward pass, which is meaningful at production scale.
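Working from the profile table directly avoids transcription slips. A minimal sketch that sums the kernels by category and applies the speedup formula; the kernel names and times are copied from the profile in Exercise 8.1.

```python
profile_ms = [
    ("QKV projection", 0.42, "compute"),
    ("FlashAttention fwd", 0.38, "compute"),
    ("Output projection", 0.14, "compute"),
    ("LayerNorm + residual add", 0.12, "memory"),
    ("MLP up + gate", 0.51, "compute"),
    ("SiLU activation", 0.08, "memory"),
    ("MLP down projection", 0.26, "compute"),
    ("LayerNorm + residual add", 0.12, "memory"),
    ("NCCL all-reduce (TP)", 0.35, "communication"),
]

total = sum(t for _, t, _ in profile_ms)
memory_bound = sum(t for _, t, cat in profile_ms if cat == "memory")
speedup = total / (total - memory_bound)  # Amdahl: remove the memory-bound part

print(f"total = {total:.2f} ms, memory-bound = {memory_bound:.2f} ms "
      f"({memory_bound / total:.1%}), speedup = {speedup:.3f}x")
```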
The NCCL all-reduce in the profile takes 0.35 ms. If you overlap this communication with the next layer's QKV computation (0.42 ms), what fraction of the communication is hidden?
The all-reduce (0.35 ms) is shorter than the QKV computation (0.42 ms), so the communication can be fully hidden. This is the idea behind communication-computation overlap: start the all-reduce as soon as the current layer's output is ready and keep the GPU busy with other work while it is in flight. Because the next layer's QKV projection consumes the all-reduced output, the overlapped work in practice usually comes from an independent microbatch or from splitting the tensor into chunks. This requires careful CUDA stream management but can eliminate communication overhead entirely when compute time exceeds communication time.
JAX is the framework of choice for scaling research at Google, DeepMind, and many frontier labs. It is not PyTorch with a different API. JAX is a functional transformation system: you write pure functions, and JAX transforms them (differentiation, compilation, parallelization, vectorization).
jax.jit traces your Python function into XLA HLO (an intermediate representation), which XLA compiles into optimized device code. jax.grad turns a function into a new function that computes its gradient. jax.pmap replicates the computation across devices. All transforms compose.

Here is the simplest possible JAX training loop. Study each line: the patterns here scale to billion-parameter models.
```python
import jax
import jax.numpy as jnp

# 1. Generate synthetic data: y = 3x + 1 + noise.
# Split the key so X and the noise use independent random streams;
# reusing one key would make the noise an exact copy of X.
key = jax.random.PRNGKey(42)
x_key, noise_key = jax.random.split(key)
X = jax.random.normal(x_key, (100, 1))
y = 3.0 * X + 1.0 + 0.1 * jax.random.normal(noise_key, (100, 1))

# 2. Initialize parameters (JAX uses explicit state, no hidden .grad)
params = {'w': jnp.zeros((1, 1)), 'b': jnp.zeros((1,))}

# 3. Define loss as a PURE FUNCTION of params (no side effects!)
def loss_fn(params, X, y):
    pred = X @ params['w'] + params['b']
    return jnp.mean((pred - y) ** 2)

# 4. jax.grad returns a FUNCTION that computes gradients w.r.t. the
#    first argument (params); jax.jit compiles that function with XLA.
grad_fn = jax.jit(jax.grad(loss_fn))

# 5. Training loop: functional update, no in-place mutation
lr = 0.1
for step in range(100):
    grads = grad_fn(params, X, y)
    params = jax.tree.map(lambda p, g: p - lr * g, params, grads)
    if step % 20 == 0:
        print(f"Step {step}: loss={loss_fn(params, X, y):.4f}")

print(f"Learned: w={params['w'][0,0]:.3f}, b={params['b'][0]:.3f}")
# Expected: w≈3.0, b≈1.0
```
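Four patterns to notice: (1) jax.grad differentiates pure functions; there is no .backward() on tensors. (2) jax.tree.map applies a function across all leaves of a pytree (nested dict/list of arrays). (3) No in-place updates: every step creates new arrays. (4) In real code you would wrap the training step in jax.jit for compilation; it is left out above for clarity.
The same training step runs data-parallel across devices with jax.pmap:
```python
import functools
import jax
import jax.numpy as jnp

# Reuses loss_fn, params, X, and y from the previous example.
n_devices = jax.local_device_count()

# Replicate params to all devices (adds a leading device axis)
params = jax.tree.map(lambda x: jnp.stack([x] * n_devices), params)

# Define one step: forward, backward, all-reduce, update.
# axis_name must be declared on pmap so that pmean can refer to it.
@functools.partial(jax.pmap, axis_name='batch')
def train_step(params, X_shard, y_shard):
    grads = jax.grad(loss_fn)(params, X_shard, y_shard)
    # All-reduce: average gradients across devices
    grads = jax.lax.pmean(grads, axis_name='batch')
    return jax.tree.map(lambda p, g: p - 0.1 * g, params, grads)

# Split data across devices (axis 0 = device dimension).
# Drop any remainder so the batch divides evenly among devices.
n_per_device = X.shape[0] // n_devices
X_sharded = X[:n_devices * n_per_device].reshape(n_devices, n_per_device, 1)
y_sharded = y[:n_devices * n_per_device].reshape(n_devices, n_per_device, 1)

# Train!
for step in range(100):
    params = train_step(params, X_sharded, y_sharded)
```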
jax.pmap maps the function across devices. jax.lax.pmean inserts an all-reduce into the compiled graph. This is plain data parallelism in about 10 lines: every device holds a full replica of the parameters, whereas FSDP would also shard them.
Pallas is JAX's kernel authoring framework. You write kernels in Python that compile to GPU or TPU kernels (via Triton on GPU and Mosaic on TPU). Here is a vector addition kernel, the "hello world" of custom kernels.
```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

def add_kernel(x_ref, y_ref, o_ref):
    """Pallas kernel: reads from x_ref and y_ref, writes to o_ref.
    Refs are BlockSpec views: they point to a tile of the full array."""
    o_ref[...] = x_ref[...] + y_ref[...]

def add_vectors(x, y):
    return pl.pallas_call(
        add_kernel,
        out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
        grid=(8,),  # launch 8 blocks
        in_specs=[
            pl.BlockSpec((128,), lambda i: (i,)),  # each block gets 128 elements
            pl.BlockSpec((128,), lambda i: (i,)),
        ],
        out_specs=pl.BlockSpec((128,), lambda i: (i,)),
    )(x, y)

# Test
x = jnp.arange(1024, dtype=jnp.float32)
y = jnp.ones(1024, dtype=jnp.float32)
result = add_vectors(x, y)
# result == jnp.arange(1024) + 1
```
jax.vmap automatically vectorizes a function that operates on a single example to work on a batch. This is how you write clean, single-example logic and get batched execution for free.
```python
import jax
import jax.numpy as jnp

# Function that works on ONE example (no batch dim)
def predict_single(params, x):
    """x: shape [D], returns scalar"""
    return jnp.dot(params['w'], x) + params['b']

# vmap it to handle batches: x now shape [B, D]
predict_batch = jax.vmap(predict_single, in_axes=(None, 0))
# in_axes=(None, 0): don't batch over params, batch over axis 0 of x

# This compiles to the SAME efficient batched matmul as manual batching
params = {'w': jnp.ones(768), 'b': jnp.zeros(())}
x_batch = jax.random.normal(jax.random.PRNGKey(0), (32, 768))
out = predict_batch(params, x_batch)  # shape [32]
```
vmap composes with jit and grad: jax.jit(jax.vmap(jax.grad(loss_fn))) gives you batched gradient computation compiled into a single XLA graph. This composability is JAX's superpower.
This is the exercise Vlad Feinberg specifically recommends: implement a ~10M parameter transformer in JAX/Flax/Optax that learns integer addition. It sounds trivial. It is not. This single exercise tests your understanding of tokenization, architecture design, scaling laws, training dynamics, and JAX fluency.
The vocabulary is tiny: digits 0-9, plus sign, equals sign, and a padding token. That is 13 tokens total.
The model sees the entire string but only predicts the answer portion (after the = sign). During training, you mask the loss on the prompt tokens.
For addition of up to 3-digit numbers, what is the maximum input sequence length? Consider: "999+999=1998". Count each character as one token.
The maximum result of 3-digit addition is 999+999=1998, which is 4 digits. So max_len = 3 + 1 + 3 + 1 + 4 = 12. Pad all sequences to length 12. Shorter examples (like "1+2=3") get padded: "1+2=3<pad><pad><pad><pad><pad><pad><pad>".
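The tokenization, padding, and loss-masking rules above can be sketched in plain Python. The id assignment (digits first, then `+`, `=`, and `<pad>`) and the `encode` helper are hypothetical choices for illustration; the mask marks which tokens belong to the answer, and the one-position shift needed for next-token prediction is left out.

```python
# Hypothetical vocabulary layout: ids 0-9 are the digits, then '+', '=', padding.
VOCAB = {str(d): d for d in range(10)}
VOCAB.update({"+": 10, "=": 11, "<pad>": 12})
MAX_LEN = 12  # len("999+999=1998")


def encode(a: int, b: int):
    """Return (token_ids, loss_mask), both padded to MAX_LEN."""
    text = f"{a}+{b}={a + b}"
    ids = [VOCAB[ch] for ch in text]
    eq = text.index("=")
    # 1 where the token is part of the answer (after '='), 0 on the prompt.
    mask = [1 if i > eq else 0 for i in range(len(ids))]
    n_pad = MAX_LEN - len(ids)
    # Padding positions never contribute to the loss.
    return ids + [VOCAB["<pad>"]] * n_pad, mask + [0] * n_pad


ids, mask = encode(1, 2)
print(ids)   # token ids for "1+2=3" followed by seven pad tokens
print(mask)  # only the position holding "3" is marked
```

Multiplying the per-token loss by this mask and dividing by its sum gives an average over answer tokens only.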
You want ~10M parameters. For a decoder-only transformer with:
Target 10M parameters. Try L=6 layers. What value of d (model dimension) gets you closest to 10M? Use the simplified formula: Params ≈ L × 12d².
Solving 6 × 12d² = 10⁷ gives d = √(10⁷ / 72) ≈ 373. Round up to d=384 (divisible by common head counts). With 6 heads, d_head = 384/6 = 64. Actual params: 6 × 12 × 384² = 10,616,832 ≈ 10.6M. Close enough!
A practical config: L=6, d=384, n_heads=6, d_head=64, d_ff=1536, vocab=13, max_seq=12.
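The sizing arithmetic is easy to check in a few lines. This sketch uses only the simplified Params ≈ L × 12d² formula from the exercise, so it ignores embeddings, biases, and LayerNorm parameters; `approx_params` is a hypothetical helper name.

```python
import math


def approx_params(n_layers: int, d_model: int) -> int:
    """Simplified count: 12 * d^2 per layer (4d^2 attention + 8d^2 MLP with d_ff = 4d)."""
    return n_layers * 12 * d_model ** 2


target, n_layers = 10_000_000, 6

# Invert the formula to find the d that would hit the target exactly.
d_exact = math.sqrt(target / (12 * n_layers))

# Round up to a head-friendly size, as in the config above.
d_model = 384

print(f"exact d = {d_exact:.1f}")
print(f"params at d={d_model}: {approx_params(n_layers, d_model):,}")
```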
Here is the core transformer in Flax. This is what you would code in a Colab notebook.
```python
import jax
import jax.numpy as jnp
import flax.linen as nn
import optax

class CausalSelfAttention(nn.Module):
    n_heads: int
    d_model: int

    @nn.compact
    def __call__(self, x):
        B, T, C = x.shape
        d_head = C // self.n_heads
        # Fused QKV projection
        qkv = nn.Dense(3 * C)(x)  # [B, T, 3C]
        q, k, v = jnp.split(qkv, 3, axis=-1)
        # Reshape for multi-head: [B, T, C] -> [B, n_heads, T, d_head]
        q = q.reshape(B, T, self.n_heads, d_head).transpose(0, 2, 1, 3)
        k = k.reshape(B, T, self.n_heads, d_head).transpose(0, 2, 1, 3)
        v = v.reshape(B, T, self.n_heads, d_head).transpose(0, 2, 1, 3)
        # Scaled dot-product attention with causal mask
        attn = (q @ k.transpose(0, 1, 3, 2)) / jnp.sqrt(d_head)
        mask = jnp.tril(jnp.ones((T, T)))[None, None, :, :]
        attn = jnp.where(mask == 0, -1e9, attn)
        attn = jax.nn.softmax(attn, axis=-1)
        out = (attn @ v).transpose(0, 2, 1, 3).reshape(B, T, C)
        return nn.Dense(C)(out)  # output projection

class TransformerBlock(nn.Module):
    n_heads: int
    d_model: int
    d_ff: int

    @nn.compact
    def __call__(self, x):
        # Pre-norm architecture (like GPT-2, LLaMA)
        x = x + CausalSelfAttention(self.n_heads, self.d_model)(nn.LayerNorm()(x))
        x = x + nn.Dense(self.d_model)(nn.gelu(nn.Dense(self.d_ff)(nn.LayerNorm()(x))))
        return x

class AdditionTransformer(nn.Module):
    vocab: int = 13
    d_model: int = 384
    n_heads: int = 6
    n_layers: int = 6
    d_ff: int = 1536
    max_len: int = 12

    @nn.compact
    def __call__(self, tokens):
        B, T = tokens.shape
        # Token + positional embeddings
        tok_emb = nn.Embed(self.vocab, self.d_model)(tokens)
        pos_emb = nn.Embed(self.max_len, self.d_model)(jnp.arange(T))
        x = tok_emb + pos_emb  # [B, T, d_model]
        # Transformer blocks
        for _ in range(self.n_layers):
            x = TransformerBlock(self.n_heads, self.d_model, self.d_ff)(x)
        x = nn.LayerNorm()(x)
        logits = nn.Dense(self.vocab)(x)  # [B, T, vocab]
        return logits
```
Using D = 20N, how many tokens should you train your 10M model on? How many 3-digit addition examples is that (each example ~12 tokens)?
There are only 999 × 999 = 998,001 unique 3-digit addition problems. At 16.7M examples, you will see each problem ~17 times on average. This is fine — language models see each training example multiple times too. With 3-digit numbers and random sampling, the model should converge to near-perfect accuracy well before exhausting 200M tokens.
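The numbers in this answer follow from three divisions. A quick check, using N = 10M as in the question and the 12-token example length from the tokenization section:

```python
n_params = 10_000_000            # N, the ~10M target from the question
tokens = 20 * n_params           # D = 20N
tokens_per_example = 12          # padded sequence length

examples = tokens / tokens_per_example
unique_problems = 999 * 999      # as counted in the text
repeats = examples / unique_problems

print(f"tokens:   {tokens:,}")
print(f"examples: {examples / 1e6:.1f}M")
print(f"average repeats per problem: {repeats:.1f}")
```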
Your 10.6M parameter model will train on 200M tokens. What is the total training compute in GFLOPs? How long will this take on a single T4 GPU (65 TFLOP/s FP16, assuming 30% MFU)?
Under 11 minutes on a free Colab T4! This is exactly the kind of experiment Vlad recommends: small enough to iterate quickly, but large enough to teach real scaling concepts. If your training takes longer, you are probably not batching efficiently or your data pipeline is the bottleneck.
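The time estimate can be reproduced with the standard C ≈ 6ND approximation for training compute. The sketch below plugs in the parameter count, token budget, peak throughput, and MFU stated in the question; nothing else is assumed.

```python
n_params = 10_616_832        # from the d=384, L=6 config
tokens = 200_000_000         # D = 20N, rounded as in the question

# Forward + backward costs roughly 6 FLOPs per parameter per token.
total_flops = 6 * n_params * tokens
gflops = total_flops / 1e9

peak_flops_per_s = 65e12     # T4 FP16 peak, as given
mfu = 0.30                   # assumed utilization, as given
seconds = total_flops / (peak_flops_per_s * mfu)

print(f"compute: {gflops:.3e} GFLOPs")
print(f"time:    {seconds:.0f} s = {seconds / 60:.1f} min")
```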
Vlad's blog also asks you to derive Chinchilla scaling for MoE (Mixture of Experts). The key insight: an MoE model has more total parameters (N_total) but only activates a fraction per token (N_active). Chinchilla's compute law uses the active parameter count.
You have a compute budget of 6 × 10¹⁷ FLOPs. Compare two options:
(A) Dense 1B model, trained on D_A tokens.
(B) MoE with 8 experts, top-2 routing, N_active = 1B (but total params = ~3.5B). Trained on D_B tokens.
Under Chinchilla scaling, how many tokens does each get trained on?
Both get the same token count. With C ≈ 6 × N_active × D, D = 6 × 10¹⁷ / (6 × 10⁹) = 10⁸, so each model trains on 100M tokens. Because both have N_active = 1B, and the training compute formula uses N_active, the token count is identical. The MoE advantage is that it has 3.5B total parameters (more capacity to store knowledge) while using the same training compute as the 1B dense model. At inference, however, the MoE model needs more memory for the extra expert weights but the same compute per token.
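The comparison can be sketched in a few lines, assuming the C ≈ 6 × N_active × D approximation. It shows that the token count depends only on the active parameters, while weight memory depends on the total. The helper names are hypothetical, and the 2 bytes per parameter (BF16 weights) is an assumption not stated in the exercise.

```python
def tokens_for_budget(compute_flops: float, n_active: float) -> float:
    """Tokens D that fit in a compute budget, from C = 6 * N_active * D."""
    return compute_flops / (6 * n_active)


def weight_bytes(n_total: float, bytes_per_param: int = 2) -> float:
    """Memory for the weights alone (assumes BF16, 2 bytes per parameter)."""
    return n_total * bytes_per_param


budget = 6e17  # FLOPs, as stated in the exercise

dense = {"n_active": 1e9, "n_total": 1e9}
moe = {"n_active": 1e9, "n_total": 3.5e9}

d_dense = tokens_for_budget(budget, dense["n_active"])
d_moe = tokens_for_budget(budget, moe["n_active"])
print("same token count:", d_dense == d_moe)
print("weight memory ratio (MoE / dense):",
      weight_bytes(moe["n_total"]) / weight_bytes(dense["n_total"]))
```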
MoE routing creates "ragged" batches: each expert gets a different number of tokens. JAX's jax.lax.ragged_dot handles this, but a custom Pallas kernel can beat it when the number of experts (or the "fan-out" F) exceeds the model dimension D.
In MoE, after routing, you combine up-projection and down-projection for each expert. With E=64 experts, top-2 routing, d=1024: the combined projection is essentially a grouped matmul with F (fan-out) = E × top_k = 128 independent matmuls. When F > D, each individual matmul is tiny ([tokens_per_expert, D] × [D, D_ff]). Why does Pallas win here?
With F=128 experts and a batch of, say, 2048 tokens routed top-2, each expert gets ~32 tokens on average. Each expert's matmul is [32, 1024] × [1024, 4096] — far too small to saturate the H100's tensor cores. ragged_dot launches these as padded grouped GEMMs with wasted compute on padding. A Pallas kernel can:
1. Tile across experts AND tokens simultaneously
2. Fuse the up and down projections (compute up, apply activation, compute down in one pass without writing intermediate to HBM)
3. Handle ragged batch sizes natively without padding
4. Amortize kernel launch overhead (1 launch instead of 128)
This is exactly the kind of optimization that differentiates a frontier lab engineer from someone who just calls library functions.
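To get a feel for how ragged the groups are, the sketch below simulates routing with the numbers from the answer above (2048 tokens, top-2, 128 groups). Uniform random routing is an assumption, and padding every group to the size of the largest one is just one simple scheme chosen for illustration; it is not a claim about what jax.lax.ragged_dot does internally.

```python
import random
from collections import Counter

random.seed(0)
n_tokens, top_k, n_groups = 2048, 2, 128  # numbers from the text above

# Assumption: each token picks top_k distinct groups uniformly at random.
counts = Counter()
for _ in range(n_tokens):
    for g in random.sample(range(n_groups), top_k):
        counts[g] += 1

group_sizes = [counts[g] for g in range(n_groups)]

ragged_rows = sum(group_sizes)             # rows of real work
padded_rows = n_groups * max(group_sizes)  # rows if every group is padded to the largest
wasted = 1 - ragged_rows / padded_rows

print("mean tokens per group:", ragged_rows / n_groups)
print("min / max tokens per group:", min(group_sizes), max(group_sizes))
print(f"fraction of padded rows that is padding: {wasted:.1%}")
```

A `group_sizes` list like this is the same shape of information a ragged grouped matmul consumes: one row count per group, with the rows of all groups stored contiguously.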
Here is every skill this workbook covers. Be honest with yourself: can you do each of these from scratch, with paper and pencil, without looking up the formula? Check off each one as you master it.
| Topic | Lesson |
|---|---|
| Transformer internals | Transformer — From Absolute Zero |
| GPT architecture | GPT — From Absolute Zero |
| Diffusion models | Diffusion — From Absolute Zero |
| Reward & alignment | RLHF — From Absolute Zero |
| Inference engineering | ML Inference Engineer — Day In The Life |
Here is the protocol that maximizes retention from this workbook. It is based on spaced retrieval practice — the most evidence-backed learning technique in cognitive science.