Count every FLOP, trace every byte. Training costs, inference bottlenecks, KV caching, and speculative decoding — the systems view of transformers.
You've just trained a transformer model. It works. But it takes 14 hours per epoch on 8 GPUs, and inference serves 40 tokens per second. Your boss wants 200. Where is the time going? Which part do you optimize?
Most ML engineers can write a transformer from scratch. Far fewer can answer: how many floating-point operations does a single forward pass cost? Or: why is autoregressive decoding 100x slower than what the hardware should theoretically support?
The answers live in two numbers: FLOPs (floating-point operations — the computational work) and memory bandwidth (bytes moved per second between GPU memory and compute cores). Every bottleneck in training and inference comes down to one of these being the limiting factor.
Drag the batch size slider. Small batches are memory-bound (GPU starves for data). Large batches are compute-bound (math is the bottleneck). The crossover is the arithmetic intensity threshold.
In the chapters ahead, we'll count every FLOP in training (forward + backward), trace every byte in memory, dissect inference into its two phases (prefill and decode), and learn the tricks — KV caching, speculative decoding, continuous batching — that close the gap between theoretical and actual throughput.
Let's count. Every matrix multiply has a precise FLOP cost, and a transformer is just a stack of matrix multiplies with nonlinearities between them. Once you know the rule, you can compute the cost of any model on a napkin.
Multiplying a matrix of shape [M, K] by one of shape [K, N] produces a [M, N] output. Each output element requires K multiplications and K−1 additions. We approximate this as 2MKN FLOPs (the factor of 2 counts both multiplies and adds).
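To make the counting rule concrete, here is a minimal Python sketch that tallies the operations of a naive triple-loop matrix multiply and compares the exact count with the 2MKN approximation. The helper names are our own, not from any library.

```python
# Illustrative sketch; all function names here are hypothetical helpers.

def exact_matmul_flops(m: int, k: int, n: int) -> int:
    """Exact count: each of the m*n outputs needs k multiplies and k-1 adds."""
    return m * n * (k + (k - 1))


def approx_matmul_flops(m: int, k: int, n: int) -> int:
    """The 2MKN rule used throughout this article."""
    return 2 * m * k * n


def naive_matmul(a, b):
    """Triple-loop [M, K] x [K, N] multiply that also counts its FLOPs."""
    m, k, n = len(a), len(b), len(b[0])
    flops = 0
    out = [[0.0] * n for _ in range(m)]
    for i in range(m):
        for j in range(n):
            acc = a[i][0] * b[0][j]  # first product: 1 multiply, no add yet
            flops += 1
            for p in range(1, k):
                acc += a[i][p] * b[p][j]  # 1 multiply + 1 add
                flops += 2
            out[i][j] = acc
    return out, flops


a = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]    # [2, 3]
b = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]  # [3, 2]
out, counted = naive_matmul(a, b)
print(out)  # [[4.0, 5.0], [10.0, 11.0]]
print(counted, exact_matmul_flops(2, 3, 2), approx_matmul_flops(2, 3, 2))  # 20 20 24
```

The exact count is (2K − 1)MN, so for K in the hundreds or thousands the two counts differ by well under 1%, which is why the approximation is safe.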
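Consider one MLP layer in a transformer. The input is a batch of B sequences, each of length S, with hidden dimension H. The MLP first projects up to 4H (the intermediate size), applies an activation, then projects back down:
- Up-projection: [BS, H] × [H, 4H] costs 2 · BS · H · 4H = 8BSH² FLOPs
- Down-projection: [BS, 4H] × [4H, H] costs another 8BSH² FLOPs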
Total MLP FLOPs per layer: 16BSH².
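Attention has four projections (Q, K, V, and output) plus the attention score computation. Each of Q, K, V, O is a [H, H] weight matrix applied to a [BS, H] input:
- Projections: 4 × 2BSH² = 8BSH² FLOPs
- Scores QKᵀ: [S, H] × [H, S] for each sequence, 2BS²H FLOPs (splitting into heads does not change the total)
- Weighted sum of values: [S, S] × [S, H] for each sequence, another 2BS²H FLOPs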
Total attention FLOPs per layer: 8BSH² + 4BS²H.
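A transformer with L layers has total forward FLOPs:
Forward FLOPs ≈ L × (24BSH² + 4BS²H)
Plus the final vocabulary projection: 2BSH·V (where V is vocabulary size). For most models, the 24BSH² term dominates. Each layer holds about 12H² weights, so L × 24BSH² = 2 × (12LH²) × (BS), and the quick rule of thumb is:
Forward FLOPs ≈ 2PT, where P is the parameter count and T = B × S is the number of tokens.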
| Parameter | GPT-2 Small |
|---|---|
| Hidden dim H | 768 |
| Layers L | 12 |
| Vocab V | 50,257 |
| Seq length S | 1,024 |
| Batch B | 1 |
Let's compute step by step. Tokens T = B×S = 1,024.
MLP per layer: 16 × 1,024 × 768² = 16 × 1,024 × 589,824 = 9.66 × 10⁹ FLOPs
Attention projections per layer: 8 × 1,024 × 768² = 4.83 × 10⁹ FLOPs
Attention scores per layer: 4 × 1,024² × 768 = 3.22 × 10⁹ FLOPs
Per layer total: (9.66 + 4.83 + 3.22) × 10⁹ = 17.7 × 10⁹ FLOPs
All 12 layers: 12 × 17.7 × 10⁹ = 212 GFLOPs
Vocab projection: 2 × 1,024 × 768 × 50,257 = 79.0 GFLOPs
Total forward: ≈ 291 GFLOPs per batch of 1,024 tokens.
Quick check via the napkin rule: GPT-2 Small has ~124M parameters, so 2 × 124M × 1,024 ≈ 254 GFLOPs. Close enough: the ~37 GFLOPs gap is mostly the attention score term (4BS²H per layer, about 39 GFLOPs over 12 layers), which the napkin rule ignores.
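The arithmetic above can be reproduced with a short Python sketch. The function name and dictionary keys are our own; it simply encodes the per-layer formulas 16BSH² (MLP), 8BSH² (projections), and 4BS²H (scores) plus the 2BSH·V vocabulary projection, and it ignores layer norms, softmax, and other elementwise work.

```python
# Illustrative sketch; forward_flops is a hypothetical helper, not a library API.

def forward_flops(B: int, S: int, H: int, L: int, V: int) -> dict:
    """Approximate forward-pass FLOPs for a standard transformer (4H MLP)."""
    mlp = 16 * B * S * H**2           # up-projection + down-projection
    attn_proj = 8 * B * S * H**2      # Q, K, V, O projections
    attn_scores = 4 * B * S**2 * H    # QK^T and scores @ V
    per_layer = mlp + attn_proj + attn_scores
    vocab = 2 * B * S * H * V         # final projection to logits
    return {
        "mlp": L * mlp,
        "attn_proj": L * attn_proj,
        "attn_scores": L * attn_scores,
        "vocab": vocab,
        "total": L * per_layer + vocab,
    }


# GPT-2 Small, batch 1, sequence length 1024 (values from the table above)
f = forward_flops(B=1, S=1024, H=768, L=12, V=50_257)
for name, value in f.items():
    print(f"{name:>12}: {value / 1e9:7.1f} GFLOPs")

# Napkin rule: 2 x parameters x tokens, with ~124M parameters
napkin = 2 * 124e6 * 1024
print(f"{'napkin 2PT':>12}: {napkin / 1e9:7.1f} GFLOPs")
```

Printing the breakdown makes the gap visible: the attn_scores entry is the one piece the 2PT rule has no term for.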
Adjust model dimensions and see FLOPs break down by component. The bar chart shows where compute goes.
Training isn't just the forward pass. After computing the loss, we backpropagate gradients through every layer to update weights. The backward pass costs roughly twice the FLOPs of the forward pass. Let's see exactly why.
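Consider a linear layer: z = x · W, where x is [B, H] and W is [H, N]. The forward pass costs 2BHN FLOPs. During backprop, we receive dL/dz (same shape as z: [B, N]) and need two things:
- The input gradient dL/dx = dL/dz · Wᵀ: [B, N] × [N, H], costing 2BHN FLOPs (needed to keep backpropagating into earlier layers)
- The weight gradient dL/dW = xᵀ · dL/dz: [H, B] × [B, N], costing 2BHN FLOPs (needed to update W)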
Two matrix multiplies, each costing 2BHN. Total backward FLOPs for this layer: 4BHN. Compare with the forward pass: 2BHN. The backward pass is exactly 2× the forward pass for a linear layer.
Look at the weight gradient formula: dL/dW = xᵀ · dL/dz. We need x, the activation from the forward pass. This means during forward, we must save every intermediate activation to reuse during backward. This is why training is so memory-hungry: we store activations for every layer.
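A minimal NumPy sketch of a linear layer makes both points visible: the forward pass has to stash x, and the backward pass performs two matmuls of the same size as the forward one. The class and method names are illustrative, not from any framework, and the loss L = sum(z) is chosen only so that dL/dz is easy to write down.

```python
import numpy as np


class Linear:
    """Hypothetical minimal layer: z = x @ W, with a manual backward pass."""

    def __init__(self, H: int, N: int, rng: np.random.Generator):
        self.W = rng.standard_normal((H, N))
        self.x = None  # activation cached for backward

    def forward(self, x: np.ndarray) -> np.ndarray:
        self.x = x          # must stay alive until backward runs
        return x @ self.W   # [B, H] x [H, N]: 2BHN FLOPs

    def backward(self, dz: np.ndarray):
        dx = dz @ self.W.T  # [B, N] x [N, H]: 2BHN FLOPs
        dW = self.x.T @ dz  # [H, B] x [B, N]: 2BHN FLOPs, needs the cached x
        return dx, dW


B, H, N = 4, 8, 16
rng = np.random.default_rng(0)
layer = Linear(H, N, rng)
x = rng.standard_normal((B, H))
z = layer.forward(x)
dx, dW = layer.backward(np.ones_like(z))  # dL/dz for L = sum(z)

fwd_flops = 2 * B * H * N
bwd_flops = 2 * (2 * B * H * N)
print(dx.shape, dW.shape, bwd_flops / fwd_flops)  # (4, 8) (8, 16) 2.0
```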
The chain rule also means intermediate gradients are shared: the gradient with respect to an intermediate value v, dL/dv, feeds the update of every weight upstream of v. An efficient backprop implementation computes each such gradient once and caches it, never recomputing the same thing twice.
Since the backward pass is 2× forward, the total training cost per step is:
Training FLOPs ≈ forward + backward ≈ 1× + 2× = 3× forward FLOPs
For our GPT-2 Small example: 3 × 291 GFLOPs = 873 GFLOPs per training step (batch size 1, sequence length 1024).
With the napkin rule: training FLOPs ≈ 6PT (6 × parameters × tokens), since forward ≈ 2PT and we multiply by 3.
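Both routes to the per-step training cost can be computed side by side as a sanity check. This is a sketch: the 291 GFLOPs input is the forward total from the GPT-2 Small example, 124M is its approximate parameter count, and the function names are hypothetical.

```python
# Illustrative sketch; both helpers are hypothetical, not library functions.

def training_flops_from_forward(forward_flops: float) -> float:
    """Forward (1x) + backward (2x) = 3x the forward FLOPs."""
    return 3 * forward_flops


def training_flops_napkin(params: float, tokens: float) -> float:
    """6PT: 2PT for the forward pass, times 3 for forward + backward."""
    return 6 * params * tokens


exact = training_flops_from_forward(291e9)
napkin = training_flops_napkin(124e6, 1024)
print(f"{exact / 1e9:.0f} GFLOPs vs napkin {napkin / 1e9:.0f} GFLOPs")
# 873 GFLOPs vs napkin 762 GFLOPs
```

The ~110 GFLOPs difference is three times the forward-pass gap, so it comes mostly from the attention-score term that the napkin rule drops.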
This visualization shows a computation graph for L = 2(w₁w₂)w₃. Watch the forward pass compute values left-to-right, then backprop compute gradients right-to-left, reusing cached values.
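The same idea at toy scale: a hand-written forward and backward pass for L = 2(w₁w₂)w₃. The intermediate names u = w₁w₂ and v = 2u are our own choice and may not match the figure's labels. The forward pass caches every intermediate; the backward pass computes dL/du once and reuses it for both w₁ and w₂.

```python
# Illustrative sketch; forward/backward and the cache layout are hypothetical.

def forward(w1: float, w2: float, w3: float):
    """Forward pass for L = 2 * (w1 * w2) * w3, caching intermediates."""
    u = w1 * w2          # intermediate 1
    v = 2.0 * u          # intermediate 2
    loss = v * w3
    cache = {"w1": w1, "w2": w2, "w3": w3, "u": u, "v": v}
    return loss, cache


def backward(cache: dict) -> dict:
    """Reverse-mode pass: each intermediate gradient is computed exactly once."""
    dL_dv = cache["w3"]            # dL/dv
    dL_dw3 = cache["v"]            # uses the cached v from the forward pass
    dL_du = 2.0 * dL_dv            # chain rule through v = 2u
    dL_dw1 = dL_du * cache["w2"]   # dL/du reused ...
    dL_dw2 = dL_du * cache["w1"]   # ... for both w1 and w2
    return {"w1": dL_dw1, "w2": dL_dw2, "w3": dL_dw3}


loss, cache = forward(1.0, 2.0, 3.0)
grads = backward(cache)
print(loss, grads)  # 12.0 {'w1': 12.0, 'w2': 6.0, 'w3': 4.0}
```

The results match the analytic gradients of L = 2w₁w₂w₃: 2w₂w₃, 2w₁w₃, and 2w₁w₂.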
FLOPs tell you how much work a GPU does. Memory tells you whether the work fits. In practice, memory is almost always the binding constraint in training. Let's trace exactly where every byte goes.
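A transformer with L layers and hidden dimension H has roughly 12LH² weights in its layers (4H² for the Q, K, V, and output projections plus 8H² for the two MLP matrices), plus V × H for the token embedding.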
In fp16 (2 bytes per parameter), GPT-2 Small: 12 × 12 × 768² = 84.9M weight-params ≈ 170 MB. Add embeddings and you get ~248 MB for 124M params.
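The same estimate in code, assuming GPT-2's layout: a 4H MLP, a learned position embedding of size max_seq × H, and an output projection tied to the token embedding so it adds no parameters. Biases and layer-norm weights are ignored, so the result lands slightly under the full count. The function name is hypothetical.

```python
# Illustrative sketch; transformer_params is a hypothetical helper.

def transformer_params(L: int, H: int, V: int, max_seq: int) -> dict:
    """Rough GPT-2-style parameter count (no biases or layer norms)."""
    per_layer = 4 * H * H + 8 * H * H   # Q, K, V, O + MLP up (H->4H) and down (4H->H)
    blocks = L * per_layer               # = 12 L H^2
    token_emb = V * H                    # assumed shared with the output projection
    pos_emb = max_seq * H
    return {"blocks": blocks, "embeddings": token_emb + pos_emb,
            "total": blocks + token_emb + pos_emb}


p = transformer_params(L=12, H=768, V=50_257, max_seq=1024)
print(p["blocks"], p["total"])                          # 84934656 124318464
print(f"fp16 weights: {2 * p['total'] / 1e6:.1f} MB")  # fp16 weights: 248.6 MB
```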
Adam keeps two extra values for every parameter: the first moment (an exponential moving average of gradients) and the second moment (an exponential moving average of squared gradients). Both are kept in fp32 for numerical stability. On top of that, mixed-precision training keeps an fp32 master copy of the weights:
| Component | Bytes per param | GPT-2 Small (124M params) |
|---|---|---|
| Weights (fp16) | 2 | 248 MB |
| Gradients (fp16) | 2 | 248 MB |
| Master weights (fp32) | 4 | 496 MB |
| Adam m (fp32) | 4 | 496 MB |
| Adam v (fp32) | 4 | 496 MB |
| Total | 16 | 1.98 GB |
That's 16 bytes per parameter just for weights, gradients, and optimizer state. A 7B model needs 7 × 10⁹ × 16 = 112 GB, already more than a single A100-80GB can hold!
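A minimal calculator for the 16-bytes-per-parameter budget, assuming mixed-precision Adam exactly as laid out in the table (fp16 weights and gradients, fp32 master weights and moments). Activations are deliberately excluded, and the names are hypothetical.

```python
# Illustrative sketch; the component names mirror the table above.
BYTES_PER_PARAM = {
    "weights_fp16": 2,
    "grads_fp16": 2,
    "master_fp32": 4,
    "adam_m_fp32": 4,
    "adam_v_fp32": 4,
}


def training_state_bytes(params: float) -> dict:
    """Per-component memory for mixed-precision Adam, excluding activations."""
    parts = {name: b * params for name, b in BYTES_PER_PARAM.items()}
    parts["total"] = sum(BYTES_PER_PARAM.values()) * params
    return parts


for name, n in [("GPT-2 Small", 124e6), ("7B model", 7e9)]:
    total_gb = training_state_bytes(n)["total"] / 1e9
    print(f"{name}: {total_gb:.2f} GB")
```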
During training, we cache intermediate activations for backprop. Per layer, we store the input to each linear layer, the attention scores, and the MLP intermediate. For a transformer layer with batch B, sequence length S, hidden dimension H, and n_heads attention heads, the linear-layer inputs and the MLP intermediate grow as B·S·H, while the attention scores grow as B·n_heads·S², quadratically in sequence length.
For GPT-2 Small with B = 1, S = 1024, H = 768, n_heads = 12: about 54 MB per layer, or 648 MB for all 12 layers. This is why activation checkpointing (recomputing activations during the backward pass instead of storing them) is essential for long sequences.
Adjust model size and see where GPU memory goes. The stacked bars show weights, optimizer state, gradients, and activations.
Training is expensive but predictable. Inference is where things get weird. An LLM serving a chat request has two distinct phases, and they have completely different performance characteristics.
The user sends a prompt of N tokens. The model processes them all at once in a single forward pass, just like training. This is compute-bound — there's enough parallel work to keep the GPU busy. The output: a KV cache and the first generated token.
Now the model generates tokens one at a time. Each step processes a single token. The Q, K, V projections are tiny: [1, H] × [H, H]. The arithmetic intensity (FLOPs per byte loaded) plummets. We must load the entire model's weights from memory to produce one token.
Arithmetic intensity (AI) is the ratio of FLOPs to bytes moved:
AI = FLOPs / bytes moved
An A100 has a compute-to-bandwidth ratio of 312 TFLOPS / 2 TB/s = 156 FLOP/byte. If your operation has AI < 156, you're memory-bound. If AI > 156, you're compute-bound.
For a matmul [B, H] × [H, H]: FLOPs = 2BH², bytes loaded ≈ 2H² (the weight matrix in fp16). So AI = 2BH² / 2H² = B. When B = 1 (single-token decode), AI = 1 — we're catastrophically memory-bound. We need B ≥ 156 to saturate compute!
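A slightly more careful version of the same calculation also counts the fp16 activations read and written ([B, H] in, [B, H] out). This sketch hard-codes the A100 figures from the text; the function name is hypothetical.

```python
# Illustrative sketch; A100 numbers are the ones quoted in the text.
PEAK_FLOPS = 312e12           # fp16 tensor-core peak
PEAK_BW = 2e12                # ~2 TB/s memory bandwidth
RIDGE = PEAK_FLOPS / PEAK_BW  # 156 FLOP/byte


def matmul_intensity(B: int, H: int, bytes_per_val: int = 2) -> float:
    """FLOPs per byte for [B, H] x [H, H], counting weights and activations."""
    flops = 2 * B * H * H
    bytes_moved = bytes_per_val * (H * H + 2 * B * H)  # weight + input + output
    return flops / bytes_moved


for B in (1, 32, 156, 1024):
    ai = matmul_intensity(B, H=4096)
    regime = "compute-bound" if ai >= RIDGE else "memory-bound"
    print(f"B={B:5d}  AI={ai:7.1f}  {regime}")
```

Counting activations gives AI = B / (1 + 2B/H), slightly below B. For H = 4096 that puts B = 156 at an intensity of about 145, so in this accounting the crossover lands a little higher, near B ≈ 169.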
Watch a prompt get prefilled in parallel, then tokens decoded one-by-one. The orange bar shows GPU utilization — notice how it drops during decode.
| Property | Prefill | Decode |
|---|---|---|
| Tokens processed | All prompt tokens at once | One token per step |
| Arithmetic intensity | High (batch=N) | Low (batch=1) |
| Bottleneck | Compute-bound | Memory-bandwidth-bound |
| GPU utilization | High (>50%) | Low (<5% of peak FLOPS) |
| Latency | Fast (parallel) | Slow (sequential) |
Here's the core insight of KV caching: during autoregressive generation, the keys and values for previously generated tokens don't change. Token 5's K and V only depend on tokens 1–5. When we generate token 6, token 5's K and V are the same as before. So why recompute them?
Without caching, generating token N requires computing Q, K, V for all N tokens, then computing attention. That's O(N × H²) for projections alone. With caching, we only compute Q, K, V for the new token, then append its K, V to the cache and attend over the full sequence:
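A single-head NumPy sketch of that loop, assuming toy random weights, no batching or multi-head split, and stand-in hidden states (in a real model each new token's hidden state depends on earlier outputs). Each step projects only the new token, appends its K and V to the cache, and attends over everything cached; the final check confirms the result matches recomputing K and V over the full prefix.

```python
import numpy as np

# Illustrative sketch; weights and inputs are random stand-ins.
rng = np.random.default_rng(0)
d = 16
Wq, Wk, Wv = (rng.standard_normal((d, d)) / np.sqrt(d) for _ in range(3))


def attend(q, K, V):
    """Softmax attention of one query over all cached keys/values."""
    scores = K @ q / np.sqrt(d)   # [t]
    w = np.exp(scores - scores.max())
    w /= w.sum()
    return w @ V                  # [d]


tokens = rng.standard_normal((10, d))  # stand-in hidden states, one per step
K_cache = np.empty((0, d))
V_cache = np.empty((0, d))
outputs = []
for x in tokens:
    q, k, v = x @ Wq, x @ Wk, x @ Wv    # 3 x 2d^2 FLOPs: new token only
    K_cache = np.vstack([K_cache, k])   # append instead of recomputing
    V_cache = np.vstack([V_cache, v])
    outputs.append(attend(q, K_cache, V_cache))

# Reference: recompute K and V for the whole prefix at the last step.
K_full, V_full = tokens @ Wk, tokens @ Wv
ref = attend(tokens[-1] @ Wq, K_full, V_full)
print(np.allclose(outputs[-1], ref))  # True
```

`np.vstack` copies the cache every step for readability; a serving system would preallocate the cache and write into it.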
The savings are massive. At step 1000, we do 6d² FLOPs instead of 6000d². But there's a cost: we must store all those cached K and V vectors.
For each token in the sequence, each layer stores a K vector and a V vector, each of dimension d_model. In fp16 (2 bytes per value):
KV cache bytes per token = 2 (K and V) × 2 (bytes) × L × d_model = 4 · L · d_model
Model: 7B parameters, L = 32 layers, d_model = 4096, sequence length 1024.
Model weights: 7B × 2 bytes = 14 GB
KV cache per token: 2 × 2 × 32 × 4096 = 524,288 bytes ≈ 0.5 MB
KV cache for 1024 tokens: 0.5 MB × 1024 = 512 MB
Remaining for batching on a 40 GB GPU: 40 − 14 = 26 GB free. At 0.5 MB/token, that's 26 GB / 0.5 MB ≈ 52,000 tokens of KV cache. With S = 1024, we can serve a batch of ~50 requests simultaneously.
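The same arithmetic as a small Python sketch. It assumes fp16 K and V, one full-width K and V vector per layer (no key/value sharing across heads), and the 40 GB budget from this example; sizes use binary MiB/GiB, which is why the numbers come out round. The function name is hypothetical.

```python
# Illustrative sketch; kv_bytes_per_token is a hypothetical helper.

def kv_bytes_per_token(n_layers: int, d_model: int, bytes_per_val: int = 2) -> int:
    """K and V (x2), each d_model values, in every layer."""
    return 2 * bytes_per_val * n_layers * d_model


MiB, GiB = 2**20, 2**30
per_tok = kv_bytes_per_token(n_layers=32, d_model=4096)
per_req = per_tok * 1024                 # S = 1024
free = (40 - 14) * GiB                   # GPU memory minus fp16 weights
print(per_tok / MiB, per_req / MiB)      # 0.5 512.0
print(free // per_tok, free // per_req)  # 53248 52
```

With decimal units (26 × 10⁹ bytes) the same budget gives about 49,600 tokens; either way, roughly 50 requests of 1,024 tokens fit.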
KV caching adds memory traffic (reading the cached K and V) to save compute, which sounds backwards when decode is already memory-bound. The resolution: the trade is lopsided. Without a cache, every decode step would recompute K and V for all N previous tokens, which amounts to rerunning prefill on the whole sequence at every step. The weights are loaded once per step either way, so the cache adds only reads of the K and V vectors while removing almost all of the recomputation.
More precisely, per layer, recomputing K and V for N tokens costs 2 × 2Nd² FLOPs. The cache replaces that work with reading 2 × 2 × N × d bytes (K and V in fp16). That is d FLOPs saved per byte read: for d = 4096, about 4,096 FLOP/byte, far above the A100's 156 FLOP/byte, so reading the cache is much cheaper than recomputing.
Adjust sequence length, number of layers, and hidden dimension to see the KV cache grow. The visualization shows memory allocation on a 40GB GPU. Red means the cache exceeds available memory.
Autoregressive decode is slow because we generate one token at a time, and each token requires loading the entire model from memory. The GPU has compute to spare but nothing to do with it. Speculative decoding asks: what if we guessed the next several tokens and verified them all at once?
Use a small, fast draft model to guess the next K tokens. Then run all K tokens through the large target model in a single forward pass (like prefill). Check which guesses match. Accept the correct prefix, discard the rest, and repeat.
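Below is a minimal sketch of one speculative step under greedy decoding. The draft and target are hypothetical toy functions mapping a token prefix to its next token, not a real model API. A real system scores all K draft positions in one batched target forward pass, which the list comprehension only imitates. Keeping the target's own prediction at the first mismatch (or one bonus token if every guess matches) is a common refinement: that forward pass already computed it.

```python
from typing import Callable, List

# Illustrative sketch; NextToken stands in for "run a model, take argmax".
NextToken = Callable[[List[int]], int]


def speculative_step(prefix: List[int], draft: NextToken,
                     target: NextToken, k: int) -> List[int]:
    """One round of greedy speculative decoding; returns the new tokens."""
    # 1) Draft model proposes k tokens autoregressively (cheap).
    guesses, ctx = [], list(prefix)
    for _ in range(k):
        t = draft(ctx)
        guesses.append(t)
        ctx.append(t)

    # 2) Target predicts every position. In a real system this is a single
    #    forward pass over prefix + guesses; here we just loop.
    target_preds = [target(prefix + guesses[:i]) for i in range(k + 1)]

    # 3) Accept the longest matching prefix, then take the target's own token.
    accepted = []
    for guess, pred in zip(guesses, target_preds):
        if guess != pred:
            break
        accepted.append(guess)
    accepted.append(target_preds[len(accepted)])  # correction or bonus token
    return accepted


def target(ctx: List[int]) -> int:
    """Hypothetical toy target model: always predicts last token + 1."""
    return ctx[-1] + 1


def draft(ctx: List[int]) -> int:
    """Hypothetical toy draft model: agrees except after multiples of 3."""
    return ctx[-1] + (2 if ctx[-1] % 3 == 0 else 1)


print(speculative_step([1], draft, target, k=4))  # [2, 3, 4]
```

Under greedy decoding the output is always exactly what the target alone would have produced; the draft only changes how many tokens each target pass yields.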
Remember: during decode, the GPU is memory-bound. Processing 1 token or K tokens takes almost the same wall-clock time, because the bottleneck is loading weights, not computing. If we batch K guessed tokens, the extra compute is "free" — we're just using idle compute capacity.
The speedup depends on the acceptance rate — what fraction of draft tokens the target model agrees with. In practice, a draft model ~15× smaller achieves 70–80% acceptance on natural language, giving 2–3× speedup.
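One way to turn an acceptance rate into a speedup estimate comes from the speculative decoding analysis of Leviathan et al. (2023). It assumes each draft token is accepted independently with probability α; with k draft tokens, the expected number of tokens produced per target forward pass is (1 − α^(k+1)) / (1 − α). The sketch below evaluates it. It ignores the draft model's own cost, so real speedups are lower. The function name is hypothetical.

```python
# Illustrative sketch; assumes i.i.d. acceptance, which real text only approximates.

def expected_tokens_per_step(alpha: float, k: int) -> float:
    """Expected tokens per target pass, assuming i.i.d. acceptance with rate alpha."""
    if alpha >= 1.0:
        return k + 1.0  # every guess accepted, plus the target's bonus token
    return (1 - alpha ** (k + 1)) / (1 - alpha)


for alpha in (0.7, 0.8):
    for k in (2, 4, 8):
        print(f"alpha={alpha}, k={k}: {expected_tokens_per_step(alpha, k):.2f} tokens/pass")
```

Returns diminish quickly as k grows, and once the draft model's cost is paid, these upper bounds are consistent with the 2–3× figure quoted above.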
| Strategy | How it works | Pros | Cons |
|---|---|---|---|
| Draft model | Smaller LLM trained on same data | Good agreement (~70-80%), easy to implement | Extra model to manage, extra memory |
| Medusa heads | Extra prediction heads on the target model itself | Near-zero cost, no extra model | Guesses degrade quickly (mean-field approx); ~2× max speedup |
| Lossy optimization | Quantized/pruned version of target model | Highly correlated with target, no extra model | Complex code, under-explored |
Watch the draft model generate guesses (fast, purple) and the target model verify them (orange). Green = accepted, red = rejected. Adjust acceptance rate and speculation depth.
We've established that single-token decode is memory-bound. The natural fix: batch multiple requests together. If we're loading the weights anyway, why not process 32 users' tokens at once?
Two very different optimization targets:
Latency is the time for ONE request to complete. It is critical for interactive applications (chatbots, Copilot). A user waiting 2 seconds feels snappy; 10 seconds feels broken. Batching doesn't help individual latency, and it can even hurt.
Throughput is the total number of tokens generated per second across ALL requests. It is critical for serving many users or processing offline batches (summarization, data analysis). Batching directly increases throughput by amortizing weight loads.
With batch size B, the arithmetic intensity of a matmul becomes B (since FLOPs scale with B but weight loads don't). As we increase B:
| Batch size B | Arithmetic intensity | Regime | Bottleneck |
|---|---|---|---|
| 1 | 1 | Memory-bound | Weight loading |
| 32 | 32 | Memory-bound | Weight loading (less idle compute) |
| 156+ | ≥156 | Compute-bound | Arithmetic units |
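A two-line roofline model shows the trade-off between the two targets. It ignores KV-cache traffic and kernel overheads (both assumptions) and reuses the 7B, fp16, A100 numbers from earlier in the article. Below the crossover, a bigger batch raises throughput for free because the step time is set by the weight load. Above it, throughput plateaus while per-token latency grows.

```python
# Illustrative sketch; decode_step is a hypothetical roofline helper.
PARAMS = 7e9                 # 7B model (example from the KV cache section)
WEIGHT_BYTES = 2 * PARAMS    # fp16 weights
PEAK_FLOPS = 312e12          # A100 fp16 peak, from the text
PEAK_BW = 2e12               # ~2 TB/s, from the text


def decode_step(batch: int) -> tuple:
    """Roofline estimate of one decode step; ignores KV-cache reads and overheads."""
    mem_time = WEIGHT_BYTES / PEAK_BW               # weights loaded once per step
    compute_time = 2 * PARAMS * batch / PEAK_FLOPS  # 2P FLOPs per token
    step = max(mem_time, compute_time)              # the slower one dominates
    return step, batch / step                       # (latency per token, tokens/s)


for b in (1, 32, 156, 512):
    step, tput = decode_step(b)
    print(f"B={b:4d}  step={step * 1e3:6.2f} ms  throughput={tput:9.0f} tok/s")
```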
But we can't batch infinitely! Each request in the batch needs its own KV cache. From Chapter 5, a 7B model with S=1024 uses ~512 MB of KV cache per request. On a 40GB GPU with 14GB for weights, we have 26GB for KV caches: 26GB / 512MB = ~50 concurrent requests.
In naive (static) batching, all requests in a batch must finish before new ones enter. If one request needs 500 tokens and another needs 10, the short request's slot sits idle for 490 decode steps.
Continuous batching (also called iteration-level batching) replaces finished requests immediately. As soon as request A finishes, request D takes its slot in the next forward pass. This keeps the batch full and the GPU busy.
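A toy simulation makes the difference measurable. It assumes every request needs a fixed number of decode steps, requests are batched in arrival order, the batch has a fixed number of slots, and one decode step produces one token for every occupied slot; prefill and memory limits are ignored. Both function names are hypothetical.

```python
from collections import deque
from typing import List

# Illustrative sketch; a step-counting model, not a real scheduler.


def naive_batching(lengths: List[int], slots: int) -> int:
    """Static batches: each batch runs until its longest request finishes."""
    steps = 0
    for i in range(0, len(lengths), slots):
        steps += max(lengths[i:i + slots])
    return steps


def continuous_batching(lengths: List[int], slots: int) -> int:
    """Finished requests are replaced by queued ones at every step."""
    queue = deque(lengths)
    active: List[int] = []  # remaining tokens per occupied slot
    steps = 0
    while queue or active:
        while queue and len(active) < slots:
            active.append(queue.popleft())  # fill free slots
        active = [r - 1 for r in active]    # one decode step for everyone
        active = [r for r in active if r > 0]
        steps += 1
    return steps


lengths = [500, 10, 10, 10, 500, 10, 10, 10]
for slots in (2, 4):
    print(slots, naive_batching(lengths, slots), continuous_batching(lengths, slots))
```

With two slots, the continuous scheduler finishes the same 1,060 tokens in 530 steps (every slot busy every step) versus 1,020 steps for static batches.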
The KV cache for each request is a contiguous block of memory. If we pre-allocate for max sequence length, most of it is wasted. PagedAttention (from vLLM) borrows the idea of virtual memory pages from operating systems: allocate KV cache in small blocks, linked via a page table. This eliminates fragmentation and lets us serve 2–4× more concurrent requests.
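The core bookkeeping can be sketched in a few lines of Python. This is not vLLM's code or API: the class name, the 16-token block size, and the method names are assumptions for illustration. Each request owns a block table mapping its logical blocks to physical blocks drawn from a shared free list, and a new block is claimed only when the current one fills up.

```python
from typing import Dict, List, Tuple


class PagedKVAllocator:
    """Hypothetical toy block-table allocator in the spirit of PagedAttention."""

    def __init__(self, num_blocks: int, block_size: int = 16):
        self.block_size = block_size
        self.free: List[int] = list(range(num_blocks))  # physical block ids
        self.tables: Dict[str, List[int]] = {}          # request -> blocks
        self.lengths: Dict[str, int] = {}               # request -> tokens

    def append_token(self, req: str) -> None:
        """Reserve space for one more token, allocating a block if needed."""
        n = self.lengths.get(req, 0)
        if n % self.block_size == 0:  # current block is full (or none yet)
            if not self.free:
                raise MemoryError("KV cache exhausted")
            self.tables.setdefault(req, []).append(self.free.pop())
        self.lengths[req] = n + 1

    def release(self, req: str) -> None:
        """Return a finished request's blocks to the free list."""
        self.free.extend(self.tables.pop(req, []))
        self.lengths.pop(req, None)

    def slot(self, req: str, pos: int) -> Tuple[int, int]:
        """Physical (block, offset) holding token `pos` of `req`."""
        return self.tables[req][pos // self.block_size], pos % self.block_size


alloc = PagedKVAllocator(num_blocks=8, block_size=16)
for _ in range(40):
    alloc.append_token("A")  # 40 tokens -> 3 blocks, not max_len
for _ in range(5):
    alloc.append_token("B")  # 5 tokens -> 1 block
print(len(alloc.tables["A"]), len(alloc.tables["B"]), len(alloc.free))  # 3 1 4
alloc.release("A")
print(len(alloc.free))  # 7
```

In this scheme, the waste per request is at most one partially filled block, rather than the gap between its actual length and the maximum sequence length.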
Increase batch size and watch the GPU move from memory-bound to compute-bound. The teal line is memory time, the orange line is compute time. The bottleneck is whichever is higher.
You now have the systems vocabulary to analyze any transformer workload. Let's consolidate.
| Quantity | Formula | GPT-2 Small (124M) |
|---|---|---|
| Forward FLOPs | ≈ 2PT | 254 GFLOPs |
| Training FLOPs (per step) | ≈ 6PT | 762 GFLOPs |
| Params memory (fp16) | 2P bytes | 248 MB |
| Training memory (Adam) | 16P bytes | 1.98 GB |
| KV cache per token | 4Ld bytes | 36.9 KB |
| Arithmetic intensity (decode) | ≈ batch size | 1 (for B=1) |
| Backward/Forward ratio | 2× | — |
| Topic | What it is | Where to learn |
|---|---|---|
| FlashAttention | Fused kernel that avoids materializing the full attention matrix | CS 229s Lecture 03 |
| Tensor parallelism | Splitting individual layers across GPUs | CS 229s Lecture 05 |
| Pipeline parallelism | Putting different layers on different GPUs | CS 229s Lecture 05 |
| Quantization | Reducing precision (int8, int4, GPTQ, AWQ) | CS 229s Lecture 06 |
| Multi-Query Attention | Sharing K,V heads to shrink the KV cache | Shazeer 2019 |
"The purpose of computing is insight, not numbers." — Richard Hamming