From sampling a single token to scaling inference across many accelerators — prefill vs decode, arithmetic intensity, KV cache management, batching, and speculative decoding.
Training and inference are fundamentally different. Training processes large batches and is almost always compute-bound. Inference, especially generation, introduces a new constraint: latency. The user is waiting for each token.
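Autoregressive generation has two distinct phases: prefill, which processes the entire prompt in a single pass and fills the KV cache, and decode, which then generates one token per step.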
This split has profound implications. Prefill throughput scales with FLOPs. Decode throughput scales with memory bandwidth. Optimizing for one often hurts the other, which is why modern serving systems treat them separately.
Here is a concrete example to build intuition. Consider LLaMA 70B on 16 TPU v5e chips:
| Phase | Tokens per step | FLOPs/step | Bytes loaded/step | Arithmetic intensity | Bottleneck |
|---|---|---|---|---|---|
| Prefill (S=8192) | 8192 | 1.15e15 | 140 GB (weights) | 8192 | Compute |
| Decode (BS=1) | 1 | 1.4e11 | 140 GB (weights) | 1 | Memory BW |
| Decode (BS=240) | 240 | 3.36e13 | 140 GB + KV | 240 | Compute (MLP only) |
Notice: prefill and decode load the same weights but prefill does 8192x more useful FLOPs with them. This is why decode at small batch sizes feels so wasteful — we load 140 GB of parameters just to generate a single token.
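The FLOPs and arithmetic-intensity columns in the table above can be reproduced with a few lines of arithmetic. This is a rough sketch, assuming about 2 FLOPs per parameter per token and bf16 weights, and ignoring attention FLOPs and KV cache traffic; the function name `step_profile` is illustrative only.

```python
# Rough per-step FLOPs and arithmetic intensity for a dense decoder.
# Assumptions: ~2 FLOPs per parameter per token, bf16 weights (2 bytes each),
# attention FLOPs and KV cache traffic ignored.
PARAMS = 70e9          # LLaMA 70B parameter count
BYTES_PER_PARAM = 2    # bf16


def step_profile(tokens_per_step: int) -> tuple[float, float, float]:
    """Return (FLOPs, weight bytes loaded, FLOPs per byte) for one step."""
    flops = 2 * PARAMS * tokens_per_step
    weight_bytes = PARAMS * BYTES_PER_PARAM
    return flops, weight_bytes, flops / weight_bytes


for name, tokens in [("prefill S=8192", 8192), ("decode BS=1", 1), ("decode BS=240", 240)]:
    flops, weight_bytes, intensity = step_profile(tokens)
    print(f"{name:>15}: {flops:.2e} FLOPs, {weight_bytes / 1e9:.0f} GB, AI={intensity:.0f}")
```

The weight bytes are the same in every row; only the number of tokens sharing that load changes.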
During prefill, we process the entire prompt at once. For a prompt of S tokens, each layer performs matmuls of size (S, D) × (D, F). This is a large, compute-intensive operation — identical to a training forward pass.
The arithmetic intensity of prefill (FLOPs per byte loaded) for one such matmul is AI = 2SDF / 2DF = S.
The numerator is the FLOPs for one matmul (2SDF). The denominator is the bytes to load the weight matrix (2DF in bf16). The arithmetic intensity is simply S, the sequence length!
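Prefill latency (assuming the step is compute-bound at model FLOPs utilization U): T_prefill = 2 × P × S / (N × C × U),
where P = parameter count, S = prompt length in tokens, N = chip count, C = peak FLOPs/s per chip, U = utilization.
For LLaMA 70B with S = 8192 on 16 TPU v5e chips at 40% MFU: T_prefill = 2 × 70e9 × 8192 / (16 × 1.97e14 × 0.4) ≈ 0.91 s.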
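As a quick check of this estimate, here is a minimal sketch that divides the forward-pass FLOPs (2 per parameter per token) by the FLOPs/s the chips actually deliver. The function name and argument names are illustrative, and the peak rate of 1.97e14 FLOPs/s per v5e chip is the figure used elsewhere in this chapter.

```python
def prefill_latency_s(params: float, seq_len: int, n_chips: int,
                      peak_flops: float, mfu: float) -> float:
    """Compute-bound prefill time: forward FLOPs over delivered FLOPs/s."""
    total_flops = 2 * params * seq_len            # no backward pass
    delivered_flops_per_s = n_chips * peak_flops * mfu
    return total_flops / delivered_flops_per_s


t = prefill_latency_s(params=70e9, seq_len=8192, n_chips=16,
                      peak_flops=1.97e14, mfu=0.40)
print(f"prefill latency ~ {t:.2f} s")
```

Under these assumptions the prompt takes a little under one second, which is the time to first token a user would see for an 8K prompt.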
Prefill is essentially a mini training forward pass. The key difference is that there is no backward pass, so we do 2P FLOPs instead of 6P. The matmul shapes are identical to training: each FFN layer computes (S, D) × (D, F), where S is the full sequence length. This is a large, compute-rich matmul that fully saturates the MXU.
Prefill and the KV cache: During prefill, we compute the K and V projections for every input token and store them in the KV cache. This is a one-time cost. After prefill, the KV cache contains all the context the model needs to start generating.
For LLaMA 70B with S=8192 in int8: 8192 × 160 kB = 1.31 GB. This must fit in HBM alongside the model weights before decode can begin.
Chunked prefill: For very long prompts (>32K tokens), the activation memory during prefill can be large. Chunked prefill processes the prompt in chunks (e.g., 2048 tokens at a time), appending to the KV cache incrementally. This trades a small amount of compute efficiency for bounded activation memory.
During decode, we generate one token at a time. Each step performs a matmul of size (B, D) × (D, F) where B is the batch size. But unlike prefill, B is typically small (1-256), so the arithmetic intensity is very low.
The arithmetic intensity equals the batch size! On TPU v5e, we need B ≥ 240 to be compute-bound. At batch size 1, we are 240x below the compute-bound threshold — the MXU is sitting mostly idle while we wait for weights to load from HBM.
The general formula for decode step time at small batch sizes (memory-bandwidth-bound): T_step = (parameter bytes + KV cache bytes) / (N × HBM bandwidth per chip).
There are two terms because each step loads: (1) the full model parameters, and (2) the full KV cache for the attention computation. The KV cache grows with sequence length and batch size.
Decode throughput (tokens per second): throughput = B / T_step.
To maximize throughput, increase B until compute-bound. But each sequence in the batch needs its own KV cache, so larger B requires more HBM for KV storage.
Worked example — decode on a single TPU v5e:
We can barely fit 1 sequence on a single v5e chip! For a usable batch size, we need multiple chips or quantization.
The "time to first token" (TTFT) and "time between tokens" (TBT) metrics:
| Metric | Determined by | User impact |
|---|---|---|
| TTFT | Prefill latency | How long before the first word appears |
| TBT | Decode step latency | How fast words stream after the first |
A comfortable reading speed is ~5 tokens/s. So TBT < 200 ms feels "instant" to users. Most LLM serving systems achieve TBT of 10-50 ms, well within this budget. TTFT is more variable and often the user-facing bottleneck.
The decode "efficiency gap": During training, MFU of 40-50% is considered good. During decode at BS=1, the equivalent metric is roughly 0.4% (we use 1/240th of the MXU). This is a 100x efficiency gap! This is why LLM inference is so much more expensive per FLOP than training. The entire field of inference optimization is about closing this gap.
The three ways to close the gap:
Let us formalize when each phase transitions from memory-bound to compute-bound. The roofline model gives us a clear threshold.
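For a matmul Y = X × W with X of shape (B, D) and W of shape (D, F) in bf16: FLOPs = 2BDF and bytes moved = 2BD + 2DF + 2BF, so AI = 2BDF / (2BD + 2DF + 2BF).
When B is much smaller than D and F (as in decode), the weight loading (2DF) dominates and AI ≈ B. When B is large (as in prefill), the compute dominates.
The critical batch size B* is where compute time equals memory time: 2 × B* × D × F / C = (bytes per weight) × D × F / W, which gives B* = (C / W) × (bytes per weight / 2), with C the peak FLOPs/s and W the HBM bandwidth.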
| Hardware | bf16 FLOPs/s | HBM BW | B* (bf16 weights) | B* (int8 weights) |
|---|---|---|---|---|
| TPU v5e | 1.97e14 | 820 GB/s | 240 | 120 |
| TPU v5p | 4.59e14 | 2765 GB/s | 166 | 83 |
| H100 | 9.9e14 | 3350 GB/s | 296 | 148 |
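The B* columns follow from setting matmul compute time equal to weight-load time. The sketch below recomputes them from the FLOPs/s and bandwidth figures in the table; it assumes activations are negligible next to the weights, and the names `HARDWARE` and `critical_batch` are illustrative.

```python
# (peak bf16 FLOPs/s, HBM bandwidth in bytes/s), copied from the table above.
HARDWARE = {
    "TPU v5e": (1.97e14, 820e9),
    "TPU v5p": (4.59e14, 2765e9),
    "H100": (9.9e14, 3350e9),
}


def critical_batch(peak_flops: float, hbm_bw: float, bytes_per_weight: int = 2) -> float:
    """Batch size where 2*B*D*F / peak_flops equals bytes_per_weight*D*F / hbm_bw."""
    return peak_flops * bytes_per_weight / (2 * hbm_bw)


for name, (flops, bw) in HARDWARE.items():
    bf16 = critical_batch(flops, bw, bytes_per_weight=2)
    int8 = critical_batch(flops, bw, bytes_per_weight=1)
    print(f"{name:8s}  B*(bf16 weights) = {bf16:5.0f}   B*(int8 weights) = {int8:5.0f}")
```

Halving the bytes per weight halves B*, because the same compute now has to hide only half as much memory traffic.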
For attention (not just the MLP), the picture is different. Attention loads the KV cache, which grows with sequence length. The KV load per decode step is proportional to B × S (batch times sequence length), and this is always memory-bandwidth-bound for reasonable sequence lengths.
Attention arithmetic intensity: For the attention QK^T matmul during decode, the query is (B, 1, N, H) and the keys are (B, S, K, H). The FLOPs are 2BNSH. The bytes loaded are the KV cache: 2 × B × S × K × H × sizeof(dtype). The arithmetic intensity: AI = 2BNSH / (2BSKH × sizeof(dtype)) = N / (K × sizeof(dtype)).
For LLaMA 70B with an int8 KV cache: N=64, K=8, so AI = 8. Since the critical AI on v5e is 240, attention is 30x below the compute-bound threshold. It is always heavily bandwidth-bound during decode, regardless of batch size.
This is why KV cache loading dominates at long contexts. Even if the MLP becomes compute-bound at BS=120, the attention is still bandwidth-bound, and its cost grows linearly with sequence length.
The full roofline picture for one decode step:
The KV cache stores the key and value projections for all previous tokens, so we do not recompute them on every decode step. It is one of the most important (and most memory-hungry) data structures in LLM inference.
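For a model with K KV-heads, head dimension H, and L layers, the KV cache size per token is 2 × K × H × L × sizeof(dtype) bytes, where the factor of 2 covers keys and values.
For LLaMA 3-70B (K=8, H=128, L=80) with int8 KV caches, that is 2 × 8 × 128 × 80 = 163,840 bytes, or roughly 160 kB per token: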
| Sequence length | KV cache / sequence | KV cache @ BS=32 | KV cache @ BS=240 |
|---|---|---|---|
| 2,048 | 328 MB | 10.5 GB | 78.6 GB |
| 8,192 | 1.3 GB | 41.9 GB | 314 GB |
| 32,768 | 5.3 GB | 168 GB | 1.26 TB |
| 131,072 | 21 GB | 671 GB | 5.04 TB |
The KV cache creates a direct trade-off between batch size and sequence length. More sequences in a batch means more KV caches, which means less room for longer contexts.
How GQA helps: Standard MHA (Multi-Head Attention) with 64 heads needs 64 KV heads. GQA-8 (8 groups) needs only 8 KV heads, an 8x reduction. For LLaMA 70B with an int8 cache, that is about 1.3 MB per token with MHA versus about 160 kB with GQA-8.
At 8K context, this is the difference between 10.7 GB vs 1.3 GB per sequence. GQA is one of the most important architecture decisions for inference efficiency.
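The per-token KV size and the MHA versus GQA comparison above can be checked with a short sketch. It assumes a head dimension of 128 and 80 layers for LLaMA 70B (values not stated in the text above) and a 1-byte int8 cache. Note that the tables in this section round the per-token size to 160 kB, so the exact figures printed here come out about 2% higher.

```python
def kv_bytes_per_token(n_kv_heads: int, head_dim: int, n_layers: int,
                       bytes_per_value: int) -> int:
    """Keys and values (factor 2) for every KV head in every layer."""
    return 2 * n_kv_heads * head_dim * n_layers * bytes_per_value


HEAD_DIM, N_LAYERS = 128, 80      # assumed LLaMA 70B configuration
INT8 = 1                          # bytes per cached value

gqa = kv_bytes_per_token(8, HEAD_DIM, N_LAYERS, INT8)    # GQA-8
mha = kv_bytes_per_token(64, HEAD_DIM, N_LAYERS, INT8)   # one KV head per query head

print(f"per token: GQA-8 {gqa / 1e3:.0f} kB, MHA {mha / 1e6:.2f} MB")
for seq_len in (2048, 8192, 32768, 131072):
    per_seq = gqa * seq_len
    print(f"S={seq_len:>6}: {per_seq / 1e9:6.2f} GB/sequence, "
          f"{32 * per_seq / 1e9:7.1f} GB at BS=32, "
          f"MHA would need {mha * seq_len / 1e9:6.1f} GB/sequence")
```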
KV cache compression techniques:
| Technique | Savings | Trade-off |
|---|---|---|
| Quantization (int8/int4) | 2-4x memory | Slight quality loss, especially int4 |
| GQA (fewer KV heads) | N/K reduction | Built into architecture, no runtime cost |
| Paged Attention | Eliminates fragmentation | More complex memory management |
| Sliding Window | Fixed KV size regardless of context | Loses access to early context |
| Token eviction | Variable, task-dependent | May drop important tokens |
Paged Attention (vLLM): Without paged attention, each sequence pre-allocates its maximum KV cache size, even if the sequence has not reached that length yet. This wastes 50-90% of KV memory. Paged Attention allocates KV blocks on demand, like virtual memory pages, achieving near-zero waste.
Worked example — memory waste without paged attention:
Paged Attention eliminates this by allocating KV blocks (e.g., 256 tokens each) only when needed. The 125 GB of waste could support another 95 sequences at 8K average length — tripling effective batch size and throughput.
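To make the on-demand allocation idea concrete, here is a toy block allocator. It is a simplified sketch, not vLLM's actual implementation: the class name, method names, and bookkeeping are all hypothetical, and it tracks only block ownership rather than real KV tensors. The 256-token block size mirrors the example figure above.

```python
class PagedKVAllocator:
    """Toy allocator: a sequence receives a fixed-size KV block only when it needs one."""

    def __init__(self, num_blocks: int, block_tokens: int = 256):
        self.block_tokens = block_tokens
        self.free_blocks = list(range(num_blocks))        # physical block ids
        self.block_table: dict[int, list[int]] = {}       # seq_id -> its block ids
        self.seq_len: dict[int, int] = {}                 # seq_id -> tokens stored

    def append_token(self, seq_id: int) -> None:
        length = self.seq_len.get(seq_id, 0)
        if length % self.block_tokens == 0:               # first token, or last block is full
            if not self.free_blocks:
                raise MemoryError("out of KV blocks")
            self.block_table.setdefault(seq_id, []).append(self.free_blocks.pop())
        self.seq_len[seq_id] = length + 1

    def release(self, seq_id: int) -> None:
        """Return all blocks of a finished sequence to the free list."""
        self.free_blocks.extend(self.block_table.pop(seq_id, []))
        self.seq_len.pop(seq_id, None)

    def wasted_tokens(self) -> int:
        """Allocated-but-unused token slots (at most one partial block per sequence)."""
        allocated = sum(len(blocks) for blocks in self.block_table.values())
        return allocated * self.block_tokens - sum(self.seq_len.values())


alloc = PagedKVAllocator(num_blocks=64)
for _ in range(300):
    alloc.append_token(seq_id=0)
for _ in range(10):
    alloc.append_token(seq_id=1)
print(alloc.block_table, alloc.wasted_tokens(), len(alloc.free_blocks))
```

The waste is bounded by one partially filled block per sequence, instead of the gap between each sequence's current length and its maximum length.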
KV cache and context length: the cost curve
Let us quantify the cost of supporting longer contexts. For LLaMA 70B int8 on 16 v5e at BS=32:
| Context | KV total | % of HBM used by KV | Step time | Tok/s/chip |
|---|---|---|---|---|
| 2K | 10.5 GB | 4.1% | 6.1 ms | 328 |
| 8K | 41.9 GB | 16.4% | 8.5 ms | 235 |
| 32K | 167.8 GB | 65.5% | 18.1 ms | 110 |
| 128K | 671 GB | Does not fit! | — | — |
Going from 2K to 32K context costs a 3x reduction in throughput per chip, entirely due to KV cache loading. At 128K, we cannot even fit BS=32 on 16 chips. We would need 64+ chips or a significantly lower batch size.
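The cost curve can be approximated with a pure bandwidth model: step time is the bytes loaded (weights plus KV cache) divided by aggregate HBM bandwidth. The sketch below makes several assumptions: 16 GB of HBM per v5e chip, the rounded 160 kB per token for the int8 KV cache, and no compute time or ICI overhead. The resulting values should land close to the table, not match it exactly.

```python
WEIGHT_BYTES = 70e9            # int8 weights, 1 byte per parameter
KV_BYTES_PER_TOKEN = 160e3     # rounded int8 figure used in the text
N_CHIPS = 16
HBM_BW = 820e9                 # bytes/s per chip
HBM_PER_CHIP = 16e9            # assumed v5e HBM capacity


def decode_step(batch: int, context: int) -> dict:
    """Bandwidth-bound estimate of one decode step."""
    kv_bytes = batch * context * KV_BYTES_PER_TOKEN
    step_s = (WEIGHT_BYTES + kv_bytes) / (N_CHIPS * HBM_BW)
    return {
        "kv_gb": kv_bytes / 1e9,
        "kv_fraction_of_hbm": kv_bytes / (N_CHIPS * HBM_PER_CHIP),
        "step_ms": step_s * 1e3,
        "tok_per_s_per_chip": batch / step_s / N_CHIPS,
        "fits": WEIGHT_BYTES + kv_bytes <= N_CHIPS * HBM_PER_CHIP,
    }


for context in (2048, 8192, 32768, 131072):
    r = decode_step(batch=32, context=context)
    status = f"{r['step_ms']:.1f} ms, {r['tok_per_s_per_chip']:.0f} tok/s/chip" if r["fits"] else "does not fit"
    print(f"S={context:>6}: KV {r['kv_gb']:6.1f} GB ({r['kv_fraction_of_hbm']:.1%} of HBM), {status}")
```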
Since decode is memory-bandwidth-bound at small batch sizes, the obvious optimization is to increase the batch size. Each step already loads the full model weights — processing more tokens per step amortizes the weight-loading cost.
The general formula for decode step latency with both parameter and KV cache loading:
Breaking this down:
For LLaMA 70B on 16 TPU v5e with int8 weights and 8K context:
| Batch size | Step latency | Throughput (tok/s/chip) | Status |
|---|---|---|---|
| 1 | ~5.5 ms | 11.4 | Bandwidth-bound |
| 32 | ~8 ms | 250 | Bandwidth-bound |
| 120 | ~14 ms | 536 | Near compute-bound |
| 240 | ~20 ms | 750 | Compute-bound (MLP) |
The limit on batch size is HBM capacity. Total HBM usage = model parameters + batch_size × seq_len × KV_per_token. Once this fills the available HBM, we cannot increase batch size further.
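The HBM budget above translates directly into a maximum batch size. Here is a minimal sketch, assuming 16 GB of HBM per v5e chip, int8 weights, and the rounded 160 kB of KV cache per token; it ignores activations and runtime overhead, so real limits will be somewhat lower. The function name is illustrative.

```python
def max_batch_size(total_hbm_bytes: float, weight_bytes: float,
                   seq_len: int, kv_bytes_per_token: float) -> int:
    """Largest batch whose KV caches fit in the HBM left over after the weights."""
    free_bytes = total_hbm_bytes - weight_bytes
    if free_bytes <= 0:
        return 0
    return int(free_bytes // (seq_len * kv_bytes_per_token))


TOTAL_HBM = 16 * 16e9          # 16 chips, assumed 16 GB each
for seq_len in (8192, 32768, 131072):
    b = max_batch_size(TOTAL_HBM, weight_bytes=70e9, seq_len=seq_len,
                       kv_bytes_per_token=160e3)
    print(f"S={seq_len:>6}: max batch = {b}")
```

Quadrupling the context length cuts the maximum batch size by roughly 4x, which is the batch versus sequence-length trade-off in numeric form.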
The full decode step time formula: Let us be precise about what happens in one decode step. For each of the L layers:
The attention component is always bandwidth-bound because its ratio of FLOPs to bytes loaded is fixed at roughly N/K (query heads per KV head), independent of batch size, and that is far below the hardware's critical intensity. The MLP component transitions from bandwidth-bound to compute-bound at B = B*.
Worked example — LLaMA 70B on 16 TPU v5e at BS=120 and 8K context:
When a model does not fit on one chip (or we want better latency), we shard it across multiple chips. During inference, the strategy is almost exclusively model parallelism (tensor parallelism) — not data parallelism.
With Y-way tensor parallelism:
The ICI communication per layer is a ReduceScatter of 2BD bytes. We become ICI-bound when:
Or in the bandwidth-bound regime (small B):
For LLaMA 70B (F=28672) at batch size 64 with 2 ICI axes:
For the KV cache, each chip stores 1/Y of the KV heads. With GQA and K=8 KV heads, we can do up to 8-way TP without splitting individual KV heads. Beyond that, we need to split heads across chips, adding communication.
Worked example — sharding LLaMA 70B for inference on 16 v5e chips:
At BS=32: ICI volume = 32 × 16384 = 0.5 MB per layer per step. With 80 layers: 40 MB total. At 90 GB/s ICI: 0.44 ms. The total decode step is ~17 ms, so ICI is only 2.6% of the step time. Easily within budget.
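The data parallelism alternative for inference: Instead of 16-way TP serving a batch of 32, you could run two 8-way TP replicas, each serving a batch of 16. Same chip count, same total cost. But 16-way TP gives lower latency (17 ms vs 34 ms per step for 8-way), because each 8-way replica must stream the full weights through half the aggregate HBM bandwidth. While decode is bandwidth-bound, that longer step also means fewer tokens per second at the same total batch size; replicas pay off once each one can run a batch large enough to amortize its own weight loads. The choice depends on whether you optimize for latency or throughput.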
When to use replicas vs larger TP:
| Scenario | Better approach | Why |
|---|---|---|
| Latency SLA < 15ms | Larger TP (more chips per replica) | Reduces per-step time linearly |
| High throughput, latency flexible | More replicas (fewer chips each) | More total KV cache capacity, higher aggregate batch |
| Mixed traffic | Replicas with load balancing | Better fault tolerance, easier scaling |
| Very large model (>100B) | Minimum TP to fit in memory | Then replicas for throughput |
In production, most serving deployments run multiple replicas of a minimum-TP configuration behind a load balancer. Each replica independently serves its own batch. This is simpler, more fault-tolerant, and scales horizontally by adding replicas.
Worked example — sizing a serving cluster for 1000 QPS:
About 5 cents per 100 queries. For a chatbot with 1M daily queries, that is $480/day or $14,400/month in inference costs. Not cheap, but far less than the $40M training cost.
Scaling the cluster: the load-shedding problem. At peak traffic (say 3x average), you need 270 replicas. But at 2 AM, you might only need 30. Autoscaling (spinning replicas up and down based on demand) is critical for cost efficiency. The challenge: even at full HBM bandwidth, writing 70 GB of weights takes at least 70 GB / 820 GB/s = 85 ms, and in practice the load is limited by much slower host and storage links. On top of that, initializing the serving runtime (compiling XLA graphs, warming up caches) can take 30-60 seconds. This means autoscaling has a cold-start penalty of about 1 minute per new replica.
Solutions: maintain a "warm pool" of pre-loaded replicas, use spot/preemptible instances for non-critical traffic, or run a mix of model sizes (route simple queries to a smaller model).
Model routing — the "small model first" pattern: Many production systems route easy queries to a small model (e.g., 8B) and only use the large model (70B) for complex queries. If 70% of queries can be handled by the 8B model:
That is a 56% cost reduction compared to sending everything to the 70B model. The routing decision can be made by a lightweight classifier or by the small model itself (if it is "uncertain", escalate to the large model).
In a real serving system, requests arrive at different times, have different prompt lengths, and generate different numbers of tokens. Continuous batching (also called "inflight batching") keeps the hardware busy by dynamically adding and removing sequences from the batch.
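The scheduling idea behind continuous batching can be sketched with a toy simulation. This is a simplified illustration, not how any particular serving framework implements it: prefill cost is ignored, every active sequence emits exactly one token per step, and the function and variable names are hypothetical.

```python
from collections import deque


def continuous_batching(requests, max_batch):
    """requests: list of (request_id, tokens_to_generate).

    Returns a dict mapping request_id to the decode step at which it finished.
    """
    waiting = deque(requests)
    active = {}            # request_id -> tokens still to generate
    finished_at = {}
    step = 0
    while waiting or active:
        # Refill free slots before every decode step instead of waiting
        # for the whole batch to drain.
        while waiting and len(active) < max_batch:
            req_id, n_tokens = waiting.popleft()
            active[req_id] = n_tokens
        step += 1
        for req_id in list(active):
            active[req_id] -= 1                  # one token per sequence per step
            if active[req_id] <= 0:
                del active[req_id]
                finished_at[req_id] = step
    return finished_at


finished = continuous_batching([("a", 3), ("b", 1), ("c", 2), ("d", 2)], max_batch=2)
print(finished)
```

In this example request "c" takes over the slot that "b" frees after the first step. With static batching it would have waited until "a" finished as well.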
Prefix caching is another important optimization. Many requests share common prefixes (system prompts, few-shot examples). By caching the KV states for these shared prefixes, we can skip the prefill computation entirely for that portion.
Consider a system with median prefill length P = 8192 and median decode length G = 4096, batch size B = 32. How often does a slot open up?
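One way to estimate the answer is a steady-state argument: each sequence holds its slot for about G decode steps, so with B slots finishing at evenly spread times, one slot opens roughly every G / B steps. The sketch below encodes that estimate. Treating the medians as averages and assuming evenly spread finish times are simplifications, and the function name is illustrative.

```python
def slot_turnover(prefill_len: int, decode_len: int, batch_size: int) -> tuple[float, float]:
    """Steady-state estimate of how often a decode slot frees up.

    Returns (decode steps between slot openings,
             prefill tokens that must be absorbed per decode step to keep up).
    """
    steps_between_openings = decode_len / batch_size
    prefill_tokens_per_step = prefill_len / steps_between_openings
    return steps_between_openings, prefill_tokens_per_step


steps, prefill_per_step = slot_turnover(prefill_len=8192, decode_len=4096, batch_size=32)
print(f"a slot opens about every {steps:.0f} decode steps")
print(f"keeping up requires ~{prefill_per_step:.0f} prefill tokens per decode step")
```

The second number is the average prefill work the system must absorb per decode step to keep every slot full, whichever mechanism it uses to do that prefill.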
The scheduling challenge: When a sequence finishes decoding, its slot opens up. A new request must be prefilled before it can join the decode batch. But prefill is compute-bound and slow (potentially ~1 second for long prompts). How do we avoid stalling the decode batch while waiting for the new request to prefill?
Solution 1: Chunked prefill. Interleave small chunks of prefill with decode steps. Each decode step also processes a few hundred prefill tokens for the incoming request. This amortizes the prefill cost across many decode steps but slightly increases per-step latency.
Solution 2: Disaggregated prefill. Run prefill on separate servers. The KV cache is shipped to the decode server when ready. The decode batch never stalls.
Solution 3: Speculative insertion. Start the new request decoding with a partial KV cache (e.g., only the system prompt, which was cached). Complete the prefill in the background and swap in the full KV cache when ready.
Real serving systems — JetStream (Google): JetStream is Google's open-source inference engine for TPUs. It implements continuous batching, prefix caching, and quantized KV caches. The architecture separates the "tokenizer" (CPU), "prefill engine" (TPU), and "decode engine" (TPU) into separate async workers that communicate via queues. This mirrors the disaggregated serving pattern we described.
Other major serving frameworks:
| Framework | Hardware | Key innovation |
|---|---|---|
| vLLM | GPU | PagedAttention, continuous batching |
| TensorRT-LLM | NVIDIA GPU | FP8 kernels, inflight batching |
| JetStream | TPU | JAX-based, int8 KV, disaggregated |
| SGLang | GPU | RadixAttention (prefix tree caching) |
| DeepSpeed-FastGen | GPU | SplitFuse (chunked prefill + decode) |
All of these implement the core ideas from this chapter: continuous batching, prefix caching, quantized KV, and either disaggregated or chunked prefill. The engineering challenge is making them work together efficiently on real hardware with real workloads.
Speculative decoding attacks the decode bottleneck from a different angle. Instead of generating one token at a time with the large model, we use a small draft model to speculate multiple tokens ahead, then verify them all at once with the large model.
The acceptance rate α depends on how well the draft model matches the target model. Common approaches:
| Draft model type | Typical acceptance rate | Overhead |
|---|---|---|
| Smaller model (same family) | 60-80% | Separate model weights in HBM |
| Early exit from same model | 70-85% | No extra weights |
| N-gram / lookup table | 40-60% | Minimal |
| Medusa (extra heads) | 65-80% | Small extra heads on target |
The expected tokens per iteration: if we speculate K tokens with acceptance probability α per token, the expected number of tokens produced (accepted draft tokens plus the one token the target model always contributes) is (1 - α^(K+1)) / (1 - α). For α = 0.7, K = 5: expected ≈ 2.9 tokens/iteration.
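The closed form is a geometric series, which a short sketch can evaluate and cross-check against the explicit sum. It assumes each draft token is accepted independently with the same probability, and that every iteration yields one extra token from the target model (either a correction or a bonus token). The function name is illustrative.

```python
def expected_tokens_per_iteration(alpha: float, num_draft: int) -> float:
    """Expected tokens per verify step: sum of alpha**i for i = 0..num_draft."""
    if alpha == 1.0:
        return float(num_draft + 1)              # every draft accepted, plus the bonus token
    return (1 - alpha ** (num_draft + 1)) / (1 - alpha)


alpha, num_draft = 0.7, 5
closed_form = expected_tokens_per_iteration(alpha, num_draft)
# Term i is the probability that the first i draft tokens are all accepted.
explicit_sum = sum(alpha ** i for i in range(num_draft + 1))
print(f"closed form = {closed_form:.3f}, explicit sum = {explicit_sum:.3f}")
```

Because the series saturates at 1 / (1 - α), speculating much further ahead than a handful of tokens gives diminishing returns at a fixed acceptance rate.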
Worked example: Draft model is a 1B model, target is 70B. Both on the same 16 v5e chips.
How disaggregated serving works in practice:
[Interactive visualization: the interplay between prefill and decode phases for a batch of requests, and the latency breakdown between parameter loading, KV cache loading, and compute, as batch size and sequence length vary.]
Inference is fundamentally different from training. Here are the key results:
| Property | Prefill | Decode |
|---|---|---|
| Tokens processed | S (full sequence at once) | 1 per step |
| Arithmetic intensity | S (high) | B (low) |
| Bottleneck | Compute (FLOPs) | Memory bandwidth (HBM) |
| Scales with | More FLOPs per chip | More HBM bandwidth per chip |
| Parallelism | TP or DP (large matmuls) | TP only (model parallelism) |
Design principles for inference systems: