Austin et al., Part 7

All About Transformer Inference

From sampling a single token to scaling inference across many accelerators — prefill vs decode, arithmetic intensity, KV cache management, batching, and speculative decoding.

Prerequisites: Roofline model (arithmetic intensity, compute/memory bound), Transformer architecture (attention, FFN, KV cache).
11 Chapters · 1 Simulation · 11 Quizzes

Chapter 0: The Two Phases of Inference

Training and inference are fundamentally different. Training processes large batches and is almost always compute-bound. Inference, especially generation, introduces a new constraint: latency. The user is waiting for each token.

Autoregressive generation has two distinct phases:

Prefill (Prompt Processing): Process all input tokens in parallel. Compute-bound. Builds the KV cache.
Decode (Token Generation): Generate one token at a time. Memory-bandwidth-bound. Each step reads the full KV cache + model weights.

The fundamental difference: Prefill does a large matmul (B × S tokens, S = sequence length) and is compute-bound like training. Decode does a tiny matmul (B × 1 token per step) and is almost always memory-bandwidth-bound: the bottleneck is loading weights from HBM, not computing with them.

This split has profound implications. Prefill throughput scales with FLOPs. Decode throughput scales with memory bandwidth. Optimizing for one often hurts the other, which is why modern serving systems treat them separately.

Here is a concrete example to build intuition. Consider LLaMA 70B on 16 TPU v5e chips:

| Phase | Tokens per step | FLOPs/step | Bytes loaded/step | Arithmetic intensity | Bottleneck |
|---|---|---|---|---|---|
| Prefill (S=8192) | 8192 | 1.15e15 | 140 GB (weights) | 8192 | Compute |
| Decode (BS=1) | 1 | 1.4e11 | 140 GB (weights) | 1 | Memory BW |
| Decode (BS=240) | 240 | 3.36e13 | 140 GB + KV | 240 | Compute (MLP only) |

Notice: prefill and decode load the same weights but prefill does 8192x more useful FLOPs with them. This is why decode at small batch sizes feels so wasteful — we load 140 GB of parameters just to generate a single token.
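
The rows of the table can be reproduced with a few lines of arithmetic. This is a minimal sketch, assuming a weights-only roofline (KV cache and activation traffic ignored), 70B parameters in bf16, and TPU v5e peak numbers; the function name `phase_roofline` is illustrative, not from any library.

```python
def phase_roofline(tokens_per_step, params=70e9, bytes_per_param=2,
                   peak_flops=1.97e14, hbm_bw=8.2e11):
    """Weights-only roofline estimate for one forward step.

    Ignores KV cache and activation traffic, as the table above does.
    """
    flops = 2 * params * tokens_per_step        # 2 FLOPs per parameter per token
    weight_bytes = params * bytes_per_param     # all weights are read once per step
    intensity = flops / weight_bytes            # FLOPs per byte loaded
    critical = round(peak_flops / hbm_bw)       # ~240 on TPU v5e, rounded as in the text
    bound = "compute" if intensity >= critical else "memory bandwidth"
    return flops, weight_bytes, intensity, bound


for name, tokens in [("prefill S=8192", 8192), ("decode BS=1", 1), ("decode BS=240", 240)]:
    flops, nbytes, ai, bound = phase_roofline(tokens)
    print(f"{name}: {flops:.2e} FLOPs, {nbytes / 1e9:.0f} GB, AI={ai:.0f}, {bound}-bound")
```

The chip count does not appear because both FLOPs and bytes are divided by it, so the ratio is unchanged.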

Check: During autoregressive decode, what is the primary bottleneck?

Chapter 1: Prefill — Compute-Bound

During prefill, we process the entire prompt at once. For a prompt of S tokens, each layer performs matmuls of size (S, D) × (D, F). This is a large, compute-intensive operation — identical to a training forward pass.

The arithmetic intensity of prefill (FLOPs per byte loaded):

AI_prefill = 2 × S × D × F / (2 × D × F) = S

The numerator is the FLOPs for one matmul (2SDF). The denominator is the bytes to load the weight matrix (2DF in bf16). The arithmetic intensity is simply S, the sequence length!

Prefill is compute-bound when S is large. On TPU v5e, the compute-to-bandwidth ratio is 1.97e14 / 8.2e11 = 240. So for sequences longer than ~240 tokens, prefill is compute-bound. Most prompts are well above this. For S = 8192, the arithmetic intensity is 34x higher than needed to saturate the MXU.

Prefill latency (assuming compute-bound operation at model FLOPs utilization U):

T_prefill = 2 × P × S / (N × C × U)

where P = parameter count, N = chip count, C = peak FLOPs/s per chip, U = utilization (MFU).

For LLaMA 70B with S = 8192 on 16 TPU v5e chips at 40% MFU:

T_prefill = 2 × 70e9 × 8192 / (16 × 1.97e14 × 0.4) = 0.91 seconds

That is almost a full second just for prefill. For latency-sensitive applications, reducing prefill time is important. Prefix caching (reusing KV caches for common prompt prefixes) can help significantly.
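
The latency formula is easy to wrap in a helper. The sketch below assumes compute-bound prefill; the `cached_prefix_tokens` parameter is a hypothetical addition that models prefix caching by skipping the cached tokens and ignoring the (smaller) attention cost of attending over them.

```python
def prefill_latency_s(params, seq_len, n_chips, peak_flops, mfu, cached_prefix_tokens=0):
    """T_prefill = 2 * P * S / (N * C * U) for the tokens that still need prefill."""
    new_tokens = seq_len - cached_prefix_tokens
    return 2 * params * new_tokens / (n_chips * peak_flops * mfu)


# LLaMA 70B, S = 8192, 16 TPU v5e chips at 40% MFU
full = prefill_latency_s(70e9, 8192, 16, 1.97e14, 0.4)
print(f"{full:.2f} s")  # 0.91 s

# Same prompt when the first 6144 tokens hit a prefix cache (assumed scenario)
cached = prefill_latency_s(70e9, 8192, 16, 1.97e14, 0.4, cached_prefix_tokens=6144)
print(f"{cached:.2f} s")  # 0.23 s
```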

Prefill is essentially a mini training forward pass. The key difference is that there is no backward pass, so we do 2P FLOPs per token instead of 6P. The matmul shapes are identical to training: each FFN layer computes (S, D) × (D, F), where S is the full sequence length. This is a large, compute-rich matmul that can saturate the MXU.

Prefill and the KV cache: During prefill, we compute the K and V projections for every input token and store them in the KV cache. This is a one-time cost. After prefill, the KV cache contains all the context the model needs to start generating.

KV cache after prefill = S × 2 × K × H × L × sizeof(dtype)

For LLaMA 70B with S=8192 and an int8 KV cache (about 160 kB per token: 2 × 8 KV heads × 128 × 80 layers × 1 byte): 8192 × 160 kB = 1.31 GB. This must fit in HBM alongside the model weights before decode can begin.

Chunked prefill: For very long prompts (>32K tokens), the activation memory during prefill can be large. Chunked prefill processes the prompt in chunks (e.g., 2048 tokens at a time), appending to the KV cache incrementally. This trades a small amount of compute efficiency for bounded activation memory.

Check: What is the arithmetic intensity of prefill for a 4096-token prompt?

Chapter 2: Decode — Memory-Bound

During decode, we generate one token at a time. Each step performs a matmul of size (B, D) × (D, F) where B is the batch size. But unlike prefill, B is typically small (1-256), so the arithmetic intensity is very low.

AI_decode = 2 × B × D × F / (2 × D × F) = B

The arithmetic intensity equals the batch size! On TPU v5e, we need B ≥ 240 to be compute-bound. At batch size 1, we are 240x below the compute-bound threshold — the MXU is sitting mostly idle while we wait for weights to load from HBM.

The decode bottleneck is loading weights. Every decode step must load the entire model from HBM. For a 70B model in bf16 (140 GB), on a single chip this takes 140 GB / 820 GB/s = 171 ms per step. That is just the weight loading — we have not even started computing.

The general formula for decode step time at small batch sizes (memory-bandwidth-bound):

T_step = (param_size + KV_cache_size) / (N × W_HBM)

There are two terms because each step loads: (1) the full model parameters, and (2) the full KV cache for the attention computation. The KV cache grows with sequence length and batch size.

Decode throughput (tokens per second):

Throughput = B / T_step

To maximize throughput, increase B until compute-bound. But each sequence in the batch needs its own KV cache, so larger B requires more HBM for KV storage.
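
The two formulas above translate directly into code. This is a small sketch of the bandwidth-bound regime only; the function names are illustrative, and the example reuses the 70B bf16 model on one chip from earlier in the chapter (a hypothetical chip large enough to hold it).

```python
def decode_step_time_s(param_bytes, kv_bytes, n_chips, hbm_bw=8.2e11):
    """T_step = (param_size + KV_cache_size) / (N * W_HBM), bandwidth-bound regime."""
    return (param_bytes + kv_bytes) / (n_chips * hbm_bw)


def decode_throughput(batch_size, step_time_s):
    """Tokens generated per second across the whole batch."""
    return batch_size / step_time_s


# 70B parameters in bf16 on one chip, ignoring the KV cache
t = decode_step_time_s(param_bytes=140e9, kv_bytes=0, n_chips=1)
print(f"{t * 1e3:.0f} ms per step, {decode_throughput(1, t):.1f} tok/s")
# 171 ms per step, 5.9 tok/s
```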

Worked example — decode on a single TPU v5e:

Model: 7B params, bf16, 14 GB weight file
HBM: 16 GB. Remaining after weights: 2 GB for KV caches.
KV per token (7B model, L=32, K=32, H=128, bf16): 2 × 32 × 128 × 32 × 2 = 524 kB
At 2K context: KV per seq = 524e3 × 2048 = 1.07 GB. Max BS = 2 GB / 1.07 GB ≈ 1.

We can barely fit 1 sequence on a single v5e chip! For a usable batch size, we need multiple chips or quantization.
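
The same capacity check can be scripted to explore other configurations. A minimal sketch, assuming HBM holds only weights and KV cache (no activations or runtime overhead); the int8 variant at the end is an added illustration, not a figure from the text.

```python
def kv_bytes_per_token(n_layers, n_kv_heads, head_dim, dtype_bytes):
    """2 (K and V) x KV heads x head dim x layers x bytes per element."""
    return 2 * n_kv_heads * head_dim * n_layers * dtype_bytes


def max_batch_size(hbm_bytes, weight_bytes, kv_per_token, context_len):
    """Largest number of sequences whose KV caches fit next to the weights."""
    free = hbm_bytes - weight_bytes
    return int(free // (kv_per_token * context_len))


bf16_kv = kv_bytes_per_token(n_layers=32, n_kv_heads=32, head_dim=128, dtype_bytes=2)
print(bf16_kv)                                      # 524288
print(max_batch_size(16e9, 14e9, bf16_kv, 2048))    # 1

# Assumed variant: int8 weights (7 GB) and int8 KV cache on the same chip
int8_kv = kv_bytes_per_token(n_layers=32, n_kv_heads=32, head_dim=128, dtype_bytes=1)
print(max_batch_size(16e9, 7e9, int8_kv, 2048))     # 16
```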

The "time to first token" (TTFT) and "time between tokens" (TBT) metrics:

| Metric | Determined by | User impact |
|---|---|---|
| TTFT | Prefill latency | How long before the first word appears |
| TBT | Decode step latency | How fast words stream after the first |

A comfortable reading speed is ~5 tokens/s. So TBT < 200 ms feels "instant" to users. Most LLM serving systems achieve TBT of 10-50 ms, well within this budget. TTFT is more variable and often the user-facing bottleneck.

The decode "efficiency gap": During training, MFU of 40-50% is considered good. During decode at BS=1, the equivalent metric is roughly 0.4% (we use 1/240th of the MXU). This is a 100x efficiency gap! This is why LLM inference is so much more expensive per FLOP than training. The entire field of inference optimization is about closing this gap.

The three ways to close the gap:

1. Increase batch size
More sequences = higher arithmetic intensity = better MXU utilization. Limited by KV cache memory.
2. Reduce weight bytes
Quantization to int8/FP8 or int4 halves or quarters the bytes loaded relative to bf16, doubling or quadrupling effective AI.
3. Process more tokens per step
Speculative decoding lets the large model verify K drafted tokens in a single step, so each weight load can yield several tokens.

Check: At batch size 1, how does the MXU utilization compare to peak during decode?

Chapter 3: Arithmetic Intensity Deep Dive

Let us formalize when each phase transitions from memory-bound to compute-bound. The roofline model gives us a clear threshold.

For a matmul Y = X × W with X of shape (B, D) and W of shape (D, F) in bf16:

FLOPs = 2 × B × D × F
Bytes loaded = 2(B×D + D×F + B×F)

When B is much smaller than D and F (as in decode), the weight loading (2DF) dominates and AI ≈ B. When B is large (as in prefill), the compute dominates.

The critical batch size B* is where compute time equals memory time:

B* = C / W_HBM

| Hardware | bf16 FLOPs/s | HBM BW | B* (bf16 weights) | B* (int8 weights) |
|---|---|---|---|---|
| TPU v5e | 1.97e14 | 820 GB/s | 240 | 120 |
| TPU v5p | 4.59e14 | 2765 GB/s | 166 | 83 |
| H100 | 9.9e14 | 3350 GB/s | 296 | 148 |

Why int8 halves the critical batch size. With int8 weights, we load half the bytes for the same FLOPs (since we still compute in bf16). This means AI = 2B instead of B, so we become compute-bound at B = B*/2. Quantization does not just save memory; it also makes the hardware more efficient.
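
The table can be regenerated from the peak numbers. A minimal sketch, assuming compute always runs in bf16 and only the weight storage width changes; `critical_batch_size` is an illustrative name.

```python
HARDWARE = {  # name: (bf16 peak FLOPs/s, HBM bytes/s), values from the table above
    "TPU v5e": (1.97e14, 8.2e11),
    "TPU v5p": (4.59e14, 2.765e12),
    "H100": (9.9e14, 3.35e12),
}


def critical_batch_size(peak_flops, hbm_bw, weight_bytes_per_param=2):
    """Batch size where MLP compute time equals weight-loading time.

    Per parameter and per sequence we do 2 FLOPs, and each parameter is loaded
    once per step, so AI = 2 * B / weight_bytes_per_param.
    """
    return (peak_flops / hbm_bw) * weight_bytes_per_param / 2


for name, (c, w) in HARDWARE.items():
    bf16 = round(critical_batch_size(c, w, weight_bytes_per_param=2))
    int8 = round(critical_batch_size(c, w, weight_bytes_per_param=1))
    print(f"{name}: B*(bf16)={bf16}, B*(int8)={int8}")
```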

For attention (not just the MLP), the picture is different. Attention loads the KV cache, which grows with sequence length. The KV load per decode step is proportional to B × S (batch times sequence length), and this is always memory-bandwidth-bound for reasonable sequence lengths.

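Attention arithmetic intensity: For the attention QK^T matmul during decode, the query is (B, 1, N, H) and the keys are (B, S, K, H), with N query heads and K KV heads. The FLOPs are 2BNSH. The bytes loaded are the cached keys: B × S × K × H × sizeof(dtype). The arithmetic intensity:

AI_attention = 2BNSH / (BSKH × sizeof(dtype)) = N/K (for bf16, sizeof = 2)

The attention-weighted sum over the cached values costs the same FLOPs and loads the same number of bytes, so the ratio for the whole attention step is also N/K.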

For LLaMA 70B: N=64, K=8, so AI = 8. Since the critical AI on v5e is 240, attention is 30x below the compute-bound threshold. It is always heavily bandwidth-bound during decode, regardless of batch size.
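
A quick numerical check makes the batch-size independence concrete. This sketch counts only the QK^T FLOPs and the cached keys they read (the value matmul has the same ratio); the function name is illustrative.

```python
def attention_decode_intensity(n_q_heads, n_kv_heads, head_dim, seq_len, batch, dtype_bytes=2):
    """FLOPs per byte for QK^T during one decode step."""
    flops = 2 * batch * n_q_heads * seq_len * head_dim                  # one new query per sequence
    key_bytes = batch * seq_len * n_kv_heads * head_dim * dtype_bytes   # cached keys read from HBM
    return flops / key_bytes


# LLaMA 70B style GQA: 64 query heads, 8 KV heads, bf16 cache
print(attention_decode_intensity(64, 8, 128, seq_len=8192, batch=1))    # 8.0
print(attention_decode_intensity(64, 8, 128, seq_len=8192, batch=240))  # 8.0
# Plain multi-head attention (64 KV heads) is even lower
print(attention_decode_intensity(64, 64, 128, seq_len=8192, batch=240)) # 1.0
```

Batch and sequence length cancel because every sequence brings its own keys, which is why batching cannot push attention toward the compute-bound regime.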

This is why KV cache loading dominates at long contexts. Even if the MLP becomes compute-bound at BS=120, the attention is still bandwidth-bound, and its cost grows linearly with sequence length.

The full roofline picture for one decode step:

MLP: max(compute, param loading)
Transitions from BW-bound to compute-bound at B*. In both cases: time ≈ max(2BP/C, P_bytes/W) per chip.
+
Attention: always BW-bound
Time = B × S × KV_per_token / (N × W), where N here is the chip count. Grows linearly with B and S.
=
Total step time
Both terms contribute. At short context, MLP dominates. At long context, attention dominates.

Check: Why does int8 quantization help inference throughput beyond just saving memory?

Chapter 4: The KV Cache

The KV cache stores the key and value projections for all previous tokens, so we do not recompute them on every decode step. It is one of the most important (and most memory-hungry) data structures in LLM inference.

For a model with K KV-heads, head dimension H, and L layers, the KV cache size per token is:

KV bytes/token = 2 × K × H × L × sizeof(dtype)

For LLaMA 3-70B with int8 KV caches:

= 2 × 8 × 128 × 80 × 1 = 163,840 bytes ≈ 160 kB per token
160 kB per token is enormous. For a sequence of 32K tokens, that is 160e3 × 32768 = 5.2 GB per sequence. At batch size 240, the total KV cache is 1.3 TB, far exceeding the 70 GB of model parameters!

| Sequence length | KV cache / sequence | KV cache @ BS=32 | KV cache @ BS=240 |
|---|---|---|---|
| 2,048 | 328 MB | 10.5 GB | 78.6 GB |
| 8,192 | 1.3 GB | 41.9 GB | 314 GB |
| 32,768 | 5.2 GB | 168 GB | 1.26 TB |
| 131,072 | 21 GB | 671 GB | 5.03 TB |
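
The table follows from the per-token formula. The sketch below assumes LLaMA 3-70B shapes and uses the text's rounded 160 kB per token for the totals, so results differ by about 2% from the exact 163,840 bytes; function names are illustrative.

```python
def kv_bytes_per_token(n_kv_heads=8, head_dim=128, n_layers=80, dtype_bytes=1):
    """KV bytes/token = 2 x K x H x L x sizeof(dtype)."""
    return 2 * n_kv_heads * head_dim * n_layers * dtype_bytes


def kv_cache_gb(seq_len, batch, per_token_bytes=160e3):
    """Total KV cache in GB (1e9 bytes) for `batch` sequences of `seq_len` tokens."""
    return seq_len * batch * per_token_bytes / 1e9


print(kv_bytes_per_token())                 # 163840 (GQA, 8 KV heads, int8)
print(kv_bytes_per_token(n_kv_heads=64))    # 1310720 (what MHA with 64 KV heads would need)

for seq_len in (2048, 8192, 32768, 131072):
    row = [kv_cache_gb(seq_len, b) for b in (1, 32, 240)]
    print(seq_len, " ".join(f"{gb:.1f} GB" for gb in row))
```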

The KV cache creates a direct trade-off between batch size and sequence length. More sequences in a batch means more KV caches, which means less room for longer contexts.

KV cache is the hidden cost of long context. Doubling the context length doubles the KV cache per sequence. This means either halving the batch size (halving throughput) or doubling the number of chips. GQA (as in LLaMA 3) reduces this by using fewer KV heads.

How GQA helps: Standard MHA (Multi-Head Attention) with 64 heads needs 64 KV heads. GQA-8 (8 groups) needs only 8 KV heads, an 8x reduction. For LLaMA 70B:

MHA KV/token = 2 × 64 × 128 × 80 × 1 byte = 1.31 MB (int8)
GQA-8 KV/token = 2 × 8 × 128 × 80 × 1 byte = 164 kB (int8), 8x smaller!

At 8K context, this is the difference between 10.7 GB vs 1.3 GB per sequence. GQA is one of the most important architecture decisions for inference efficiency.

KV cache compression techniques:

| Technique | Savings | Trade-off |
|---|---|---|
| Quantization (int8/int4) | 2-4x memory | Slight quality loss, especially int4 |
| GQA (fewer KV heads) | N/K reduction | Built into architecture, no runtime cost |
| Paged Attention | Eliminates fragmentation | More complex memory management |
| Sliding Window | Fixed KV size regardless of context | Loses access to early context |
| Token eviction | Variable, task-dependent | May drop important tokens |

Paged Attention (vLLM): Without paged attention, each sequence pre-allocates its maximum KV cache size, even if the sequence has not reached that length yet. This wastes 50-90% of KV memory. Paged Attention allocates KV blocks on demand, like virtual memory pages, achieving near-zero waste.

Worked example — memory waste without paged attention:

Scenario: Max context = 32K tokens, BS=32, LLaMA 70B int8 on 16 v5e.
Pre-allocated KV per sequence = 32768 × 160 kB = 5.24 GB
Total pre-allocated = 32 × 5.24 = 167.7 GB
Actual average sequence length = 8K tokens (mix of short and long)
Actual KV used = 32 × 8192 × 160 kB = 41.9 GB
Waste = 167.7 - 41.9 = 125.8 GB = 75% waste!

Paged Attention eliminates this by allocating KV blocks (e.g., 256 tokens each) only when needed. The 125 GB of waste could support another 95 sequences at 8K average length, roughly quadrupling the effective batch size (from 32 to about 127) and substantially raising throughput.
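
The on-demand allocation idea can be illustrated with a toy block allocator. This is a sketch of the bookkeeping only, not vLLM's implementation: the class `PagedKVAllocator` and its methods are hypothetical, and no actual tensors are stored.

```python
class PagedKVAllocator:
    """Toy allocator that hands out fixed-size KV blocks only when needed."""

    def __init__(self, total_blocks, block_tokens=256):
        self.block_tokens = block_tokens
        self.free = list(range(total_blocks))   # ids of unused physical blocks
        self.tables = {}                        # seq_id -> list of block ids
        self.lengths = {}                       # seq_id -> tokens stored so far

    def append_token(self, seq_id):
        """Reserve room for one more token, taking a new block if the last is full."""
        n = self.lengths.get(seq_id, 0)
        table = self.tables.setdefault(seq_id, [])
        if n == len(table) * self.block_tokens:
            if not self.free:
                raise MemoryError("out of KV blocks")
            table.append(self.free.pop())
        self.lengths[seq_id] = n + 1

    def release(self, seq_id):
        """Return all blocks of a finished sequence to the free pool."""
        self.free.extend(self.tables.pop(seq_id, []))
        self.lengths.pop(seq_id, None)


alloc = PagedKVAllocator(total_blocks=8, block_tokens=256)
for _ in range(300):                 # a 300-token sequence needs 2 blocks, not a max-length slab
    alloc.append_token("req-0")
print(len(alloc.tables["req-0"]), len(alloc.free))   # 2 6
alloc.release("req-0")
print(len(alloc.free))                               # 8
```

At most one partially filled block per sequence is wasted, which is where the near-zero waste comes from.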

KV cache and context length: the cost curve

Let us quantify the cost of supporting longer contexts. For LLaMA 70B int8 on 16 v5e at BS=32:

| Context | KV total | % of HBM used by KV | Step time | Tok/s/chip |
|---|---|---|---|---|
| 2K | 10.5 GB | 4.1% | 6.1 ms | 328 |
| 8K | 41.9 GB | 16.4% | 8.5 ms | 235 |
| 32K | 167.8 GB | 65.5% | 18.1 ms | 110 |
| 128K | 671 GB | Does not fit! | n/a | n/a |

Going from 2K to 32K context costs a 3x reduction in throughput per chip, entirely due to KV cache loading. At 128K, we cannot even fit BS=32 on 16 chips. We would need 64+ chips or a significantly lower batch size.

Check: For LLaMA 3-70B at 8K context with int8 KVs, how large is the KV cache for one sequence?

Chapter 5: Batching for Throughput

Since decode is memory-bandwidth-bound at small batch sizes, the obvious optimization is to increase the batch size. Each step already loads the full model weights — processing more tokens per step amortizes the weight-loading cost.

The general formula for decode step latency with both parameter and KV cache loading:

T_step = [B × KV_size / (N × W)] (attention) + [max(2 × B × P / (N × C), P_bytes / (N × W))] (MLP)

Breaking this down:

Attention term
Always bandwidth-bound. Loads B × seq_len × KV_per_token bytes from HBM.
+
MLP term
max(compute time, weight loading time). Transitions from bandwidth-bound to compute-bound at B = B*.

The throughput-latency trade-off. At batch size 1, latency is minimal but throughput is terrible (most FLOPs are wasted). At batch size 240+, we saturate the MXU but latency increases because we must also load 240 KV caches per step. Doubling per-token latency can yield a 100x reduction in per-token cost.

For LLaMA 70B on 16 TPU v5e with int8 weights and 8K context:

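| Batch size | Step latency | Throughput (tok/s/chip) | Status |
|---|---|---|---|
| 1 | ~5.5 ms | 11.4 | Bandwidth-bound |
| 32 | ~8.5 ms | 235 | Bandwidth-bound |
| 120 | ~17 ms | 433 | MLP at the compute-bound transition |
| 240 | ~35 ms | 433 | Compute-bound (MLP); the 314 GB KV cache exceeds the 256 GB of HBM on 16 chips |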

The limit on batch size is HBM capacity. Total HBM usage = model parameters + batch_size × seq_len × KV_per_token. Once this fills the available HBM, we cannot increase batch size further.

The full decode step time formula: Let us be precise about what happens in one decode step. For each of the L layers:

1. Attention: Load KV cache
Load B × S × 2KH bytes from HBM. Compute QK^T and softmax. Always bandwidth-bound.
2. MLP: Load weights
Load D×F + F×D + D×F bytes (for SwiGLU: gate, up, down). Compute 3 matmuls.
3. Write new KV entry
Write B × 2KH bytes to the KV cache for this step's token.

The attention component is always bandwidth-bound because each sequence loads its own KV cache: the ratio of FLOPs to bytes loaded is about N/K (8 for LLaMA 70B) regardless of batch size, far below the critical intensity. The MLP component transitions from bandwidth-bound to compute-bound at B = B*.
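
The two-term step-time model can be written as one small function. This is a sketch under the chapter's assumptions (int8 weights, 160 kB of KV per token, no communication or softmax cost); `decode_step_ms` is an illustrative name. The sample call uses the same inputs as the worked example that follows.

```python
def decode_step_ms(batch, seq_len, n_chips=16, params=70e9, weight_bytes=70e9,
                   kv_per_token=160e3, peak_flops=1.97e14, hbm_bw=8.2e11):
    """Return (attention_ms, mlp_ms) for one decode step."""
    attention = batch * seq_len * kv_per_token / (n_chips * hbm_bw)   # always bandwidth-bound
    mlp_compute = 2 * batch * params / (n_chips * peak_flops)
    mlp_load = weight_bytes / (n_chips * hbm_bw)
    return 1e3 * attention, 1e3 * max(mlp_compute, mlp_load)


attn, mlp = decode_step_ms(batch=120, seq_len=8192)
total = attn + mlp
print(f"attention {attn:.1f} ms + MLP {mlp:.1f} ms = {total:.1f} ms")
# attention 12.0 ms + MLP 5.3 ms = 17.3 ms
print(f"{120 / (total / 1e3) / 16:.0f} tok/s/chip")
# 433 tok/s/chip
```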

Worked example — LLaMA 70B on 16 TPU v5e at BS=120 and 8K context:

Attention: KV load = 120 × 8192 × 160e3 = 157 GB. Time = 157e9 / (16 × 8.2e11) = 12.0 ms
MLP: Compute = 2 × 70e9 × 120 / (16 × 1.97e14) = 5.3 ms. Loading = 70e9 / (16 × 8.2e11) = 5.3 ms.
MLP time = max(5.3, 5.3) = 5.3 ms (right at the transition!)
Total = 12.0 + 5.3 = 17.3 ms. Throughput = 120 / 0.0173 / 16 = 433 tok/s/chip

Check: Why does increasing batch size improve throughput during decode?

Chapter 6: Distributing Inference

When a model does not fit on one chip (or we want better latency), we shard it across multiple chips. During inference, the strategy is almost exclusively model parallelism (tensor parallelism) — not data parallelism.

Why not data parallelism? DP replicates the model and splits the batch. But in inference, the batch is often small, and we need lower latency, not just higher throughput. Model parallelism reduces the weight-loading time per chip (each chip loads P/Y bytes) and the compute time (each chip does 1/Y of the FLOPs).

With Y-way tensor parallelism:

T_step = (P_bytes + KV_cache) / (Y × W_HBM) + ICI_comms

The ICI communication per layer is a ReduceScatter of 2BD bytes. We become ICI-bound when:

Y > F × M_Y / 2200 (compute-bound regime)

where M_Y is the number of ICI axes used for tensor parallelism and 2200 ≈ C / W_ICI = 1.97e14 / 9e10 on TPU v5e.

Or in the bandwidth-bound regime (small B):

Y > F / (8B) (when we can overlap comms with HBM loading)

For LLaMA 70B (F=28672) at batch size 64 with 2 ICI axes:

Compute-bound limit: Y < 28672 × 2 / 2200 = 26
Bandwidth-bound limit: Y < 28672 / (8 × 64) = 56

Scaling beyond the limit helps latency, not throughput. Past the ICI-bound threshold, adding more chips still reduces latency (less weight to load per chip) but does not increase throughput per chip. This is useful if latency is the primary objective.
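
The two limits are simple enough to tabulate for other batch sizes. A minimal sketch that takes the constants 2200 and 8 from the inequalities above as given (they are specific to TPU v5e); the function names are illustrative.

```python
def max_tp_compute_bound(d_ff, ici_axes, flops_per_ici_byte=2200):
    """Largest useful TP degree in the compute-bound regime: F * M_Y / 2200."""
    return d_ff * ici_axes / flops_per_ici_byte


def max_tp_bandwidth_bound(d_ff, batch):
    """Largest useful TP degree in the bandwidth-bound regime: F / (8 * B)."""
    return d_ff / (8 * batch)


# LLaMA 70B: F = 28672, batch size 64, 2 ICI axes
print(int(max_tp_compute_bound(28672, ici_axes=2)))   # 26
print(int(max_tp_bandwidth_bound(28672, batch=64)))   # 56

# The bandwidth-bound limit tightens as the batch grows
for b in (16, 64, 256):
    print(b, int(max_tp_bandwidth_bound(28672, b)))
```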

For the KV cache, each chip stores 1/Y of the KV heads. With GQA and K=8 KV heads, we can do up to 8-way TP without splitting individual KV heads. Beyond that, we need to split heads across chips, adding communication.

Worked example — sharding LLaMA 70B for inference on 16 v5e chips:

Params per chip: 70 GB / 16 = 4.375 GB (int8)
KV cache per chip: (B × S × 160e3) / 16 bytes (with 8 KV heads over 16 chips, each chip stores half of one KV head)
Weight matmul: (B, 8192) × (8192, 28672/16) per chip per FFN up-projection
ICI per layer: ReduceScatter of (B, 8192) × 2 bytes = 16,384 bytes per batch element

At BS=32: ICI volume = 32 × 16384 = 0.5 MB per layer per step. With 80 layers: 40 MB total. At 90 GB/s ICI: 0.44 ms. The total decode step is ~17 ms, so ICI is only 2.6% of the step time. Easily within budget.

The data parallelism alternative for inference: Instead of 16-way TP serving a batch of 32, you could run two 8-way TP replicas, each serving a batch of 16. Same total throughput, same total cost. But 16-way TP gives lower latency (17ms vs 34ms per step for 8-way). The choice depends on whether you optimize for latency or throughput.

When to use replicas vs larger TP:

| Scenario | Better approach | Why |
|---|---|---|
| Latency SLA < 15ms | Larger TP (more chips per replica) | Reduces per-step time linearly |
| High throughput, latency flexible | More replicas (fewer chips each) | More total KV cache capacity, higher aggregate batch |
| Mixed traffic | Replicas with load balancing | Better fault tolerance, easier scaling |
| Very large model (>100B) | Minimum TP to fit in memory | Then replicas for throughput |

In production, most serving deployments run multiple replicas of a minimum-TP configuration behind a load balancer. Each replica independently serves its own batch. This is simpler, more fault-tolerant, and scales horizontally by adding replicas.

Worked example — sizing a serving cluster for 1000 QPS:

Target: 1000 queries/second, LLaMA 70B int8, median 1K prefill + 512 decode tokens.
Per-replica config: 16 v5e chips, BS=64, 8K context.
Step time = (70 + 64 × 8192 × 160e3/1e9) / (16 × 820) = 153.9 / 13,120 = 11.7 ms
Tokens/s per replica = 64 / 0.0117 ≈ 5,460 tok/s
Median decode time per query = 512 × 0.0117 = 6.0 seconds
Queries completing per second per replica = 64 / 6.0 = 10.7 QPS
Replicas needed = 1000 / 10.7 = 94 replicas
Total chips = 94 × 16 = 1,504 TPU v5e
Cost = 1504 × $1.20/hr = $1,805/hour
Cost per query = $1,805 / (1000 × 3600) = $0.0005 = $0.05 per 100 queries

About 5 cents per 100 queries. For a chatbot with 1M daily queries, that is about $500/day or $15,000/month in inference costs. Not cheap, but far less than the $40M training cost.
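
The sizing arithmetic can be packaged so that other targets are easy to try. This is a sketch of the same bandwidth-bound estimate; it ignores prefill cost and ICI overhead, takes the $1.20 per chip-hour price from the example, and carries full precision through each step, so its results can differ slightly from hand-rounded figures. `size_cluster` is an illustrative name.

```python
import math


def size_cluster(target_qps, batch, context_len, decode_tokens, n_chips=16,
                 weight_gb=70.0, kv_bytes_per_token=160e3, hbm_gb_per_s=820.0,
                 usd_per_chip_hour=1.20):
    """Estimate replicas and cost for a decode-dominated workload."""
    kv_gb = batch * context_len * kv_bytes_per_token / 1e9
    step_s = (weight_gb + kv_gb) / (n_chips * hbm_gb_per_s)   # bandwidth-bound step
    seconds_per_query = decode_tokens * step_s
    qps_per_replica = batch / seconds_per_query
    replicas = math.ceil(target_qps / qps_per_replica)
    usd_per_hour = replicas * n_chips * usd_per_chip_hour
    return {
        "step_ms": step_s * 1e3,
        "qps_per_replica": qps_per_replica,
        "replicas": replicas,
        "chips": replicas * n_chips,
        "usd_per_hour": usd_per_hour,
        "usd_per_query": usd_per_hour / (target_qps * 3600),
    }


plan = size_cluster(target_qps=1000, batch=64, context_len=8192, decode_tokens=512)
for key, value in plan.items():
    print(f"{key}: {value:.4g}")
```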

Scaling the cluster: the load-shedding problem. At peak traffic (say 3x average), you need three times as many replicas. But at 2 AM, you might need only a third of the average. Autoscaling (spinning replicas up and down based on demand) is critical for cost efficiency. The challenge: writing a 70B model's weights into HBM is quick (70 GB / 820 GB/s = 85 ms at full HBM bandwidth), but initializing the serving runtime (compiling XLA graphs, warming up caches) can take 30-60 seconds. This means autoscaling has a cold-start penalty of about 1 minute per new replica.

Solutions: maintain a "warm pool" of pre-loaded replicas, use spot/preemptible instances for non-critical traffic, or run a mix of model sizes (route simple queries to a smaller model).

Model routing — the "small model first" pattern: Many production systems route easy queries to a small model (e.g., 8B) and only use the large model (70B) for complex queries. If 70% of queries can be handled by the 8B model:

Blended cost = 0.7 × $0.15/1M tok (8B) + 0.3 × $0.77/1M tok (70B) = $0.34/1M tok

That is a 56% cost reduction compared to sending everything to the 70B model. The routing decision can be made by a lightweight classifier or by the small model itself (if it is "uncertain", escalate to the large model).
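
The blended-cost calculation generalizes to any routing split. A minimal sketch using the per-token prices quoted above; it assumes the router itself is free and that routed queries have similar token counts.

```python
def blended_cost(small_share, small_cost, large_cost):
    """Average cost per 1M tokens when `small_share` of traffic goes to the small model."""
    return small_share * small_cost + (1 - small_share) * large_cost


cost = blended_cost(small_share=0.7, small_cost=0.15, large_cost=0.77)
saving = 1 - cost / 0.77
print(f"${cost:.2f} per 1M tokens, {100 * saving:.0f}% cheaper than 70B-only")
# $0.34 per 1M tokens, 56% cheaper than 70B-only
```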

Check: During inference, why do we prefer tensor parallelism over data parallelism?

Chapter 7: Continuous Batching

In a real serving system, requests arrive at different times, have different prompt lengths, and generate different numbers of tokens. Continuous batching (also called "inflight batching") keeps the hardware busy by dynamically adding and removing sequences from the batch.

Static batching
Wait for a full batch, then process all sequences until the longest one finishes. Slots whose sequences finished early sit idle (padding) until then.
vs
Continuous batching
When a sequence finishes, immediately replace it with a new request. The batch stays full at all times.

Continuous batching can double throughput compared to static batching, because we never have empty slots in the batch. It is now standard in production serving systems (vLLM, TensorRT-LLM, JetStream).
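
A toy simulation shows where the gain comes from. This sketch counts decode steps only: it assumes all requests are already queued, every step costs the same, and prefill is free, none of which holds in a real server. The function names are illustrative.

```python
def static_batching_steps(lengths, batch_size):
    """Decode steps when each batch runs until its longest sequence finishes."""
    steps = 0
    for i in range(0, len(lengths), batch_size):
        steps += max(lengths[i:i + batch_size])
    return steps


def continuous_batching_steps(lengths, batch_size):
    """Decode steps when a finished sequence is replaced before the next step."""
    pending = list(lengths)
    slots = []                      # tokens still to generate for each active sequence
    steps = 0
    while pending or slots:
        while pending and len(slots) < batch_size:
            slots.append(pending.pop(0))        # refill empty slots from the queue
        steps += 1                              # one decode step for every active slot
        slots = [r - 1 for r in slots if r > 1] # drop sequences that just finished
    return steps


lengths = [1, 8, 1, 8]              # output lengths of four queued requests
print(static_batching_steps(lengths, batch_size=2))      # 16
print(continuous_batching_steps(lengths, batch_size=2))  # 10
```

The more the output lengths vary within a batch, the larger the gap between the two schedules.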

Prefix caching is another important optimization. Many requests share common prefixes (system prompts, few-shot examples). By caching the KV states for these shared prefixes, we can skip the prefill computation entirely for that portion.

Consider a system with median prefill length P = 8192 and median decode length G = 4096, batch size B = 32. How often does a slot open up?

Sequences finishing per step = B / G = 32 / 4096 ≈ 1 every 128 steps
Tokens evicted per step = B × (P + G) / G = 32 × 12288 / 4096 = 96 tokens

The scheduling challenge: When a sequence finishes decoding, its slot opens up. A new request must be prefilled before it can join the decode batch. But prefill is compute-bound and slow (potentially ~1 second for long prompts). How do we avoid stalling the decode batch while waiting for the new request to prefill?

Solution 1: Chunked prefill. Interleave small chunks of prefill with decode steps. Each decode step also processes a few hundred prefill tokens for the incoming request. This amortizes the prefill cost across many decode steps but slightly increases per-step latency.

Solution 2: Disaggregated prefill. Run prefill on separate servers. The KV cache is shipped to the decode server when ready. The decode batch never stalls.

Solution 3: Speculative insertion. Start the new request decoding with a partial KV cache (e.g., only the system prompt, which was cached). Complete the prefill in the background and swap in the full KV cache when ready.

Real serving systems — JetStream (Google): JetStream is Google's open-source inference engine for TPUs. It implements continuous batching, prefix caching, and quantized KV caches. The architecture separates the "tokenizer" (CPU), "prefill engine" (TPU), and "decode engine" (TPU) into separate async workers that communicate via queues. This mirrors the disaggregated serving pattern we described.

Other major serving frameworks:

| Framework | Hardware | Key innovation |
|---|---|---|
| vLLM | GPU | PagedAttention, continuous batching |
| TensorRT-LLM | NVIDIA GPU | FP8 kernels, inflight batching |
| JetStream | TPU | JAX-based, int8 KV, disaggregated |
| SGLang | GPU | RadixAttention (prefix tree caching) |
| DeepSpeed-FastGen | GPU | SplitFuse (chunked prefill + decode) |

All of these implement the core ideas from this chapter: continuous batching, prefix caching, quantized KV, and either disaggregated or chunked prefill. The engineering challenge is making them work together efficiently on real hardware with real workloads.

Check: What is the main benefit of continuous batching over static batching?

Chapter 8: Speculative Decoding

Speculative decoding attacks the decode bottleneck from a different angle. Instead of generating one token at a time with the large model, we use a small draft model to speculate multiple tokens ahead, then verify them all at once with the large model.

1. Draft
Small model generates K candidate tokens quickly (each draft step is also bandwidth-bound, but it loads far fewer bytes)
2. Verify
Large model scores all K candidates in one forward pass (the weights are loaded once for all K positions, as in a short prefill)
3. Accept / Reject
Accept the longest prefix where draft and target agree. The expected number of tokens gained per iteration depends on the acceptance rate α.

Why this works: The verification step processes all K tokens in parallel (like prefill), so it costs roughly the same as generating a single token in decode mode. If we accept on average 3-4 tokens per speculation round, we get close to a 3-4x speedup in wall-clock latency, less the time spent drafting.

The acceptance rate α depends on how well the draft model matches the target model. Common approaches:

| Draft model type | Typical acceptance rate | Overhead |
|---|---|---|
| Smaller model (same family) | 60-80% | Separate model weights in HBM |
| Early exit from same model | 70-85% | No extra weights |
| N-gram / lookup table | 40-60% | Minimal |
| Medusa (extra heads) | 65-80% | Small extra heads on target |

The expected tokens per iteration: if we speculate K tokens with acceptance probability α per token, the expected number of tokens produced is (1 - α^(K+1)) / (1 - α), counting the one token the target model always contributes. For α = 0.7, K = 5: expected ≈ 2.94 tokens/iteration.

Worked example: Draft model is a 1B model, target is 70B, both with int8 weights. Both on the same 16 v5e chips.

Draft step time: 1 GB params / (16 × 820 GB/s) = 0.076 ms per token
K=5 draft tokens: 5 × 0.076 = 0.38 ms
Verify step, compute: 2 × 70e9 × 5 / (16 × 1.97e14) = 0.22 ms
Verify step, weight loading: 70 GB / (16 × 820 GB/s) = 5.34 ms
Verify time = max(0.22, 5.34) = 5.34 ms (with only 5 tokens, the pass is still bandwidth-bound)
Total per speculation round: 0.38 + 5.34 = 5.72 ms
At α=0.7, expected 2.94 tokens/round: effective = 2.94 / 5.72 ms = 514 tok/s
Normal decode: 70 GB / (16 × 820 GB/s) = 5.34 ms per token = 187 tok/s at BS=1
Speedup: 514 / 187 = 2.7x at batch size 1

Caveat: This is still an idealized calculation. In practice, the draft model has its own KV cache overhead, the acceptance rate depends on the input, and the verification step has some overhead beyond loading weights. Real-world speculative decoding typically achieves 2-3x speedup at BS=1. At large batch sizes the benefit diminishes because decode is already closer to compute-bound, so the extra verification FLOPs are no longer free.
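
The accept/reject rule and the speedup estimate can be sketched in a few lines. Assumptions: greedy decoding (a draft token is accepted only if it equals the target's own choice), independent acceptance with probability α per token, and a verification pass that costs one bandwidth-bound target step. All function names are illustrative.

```python
def accept_prefix(draft_tokens, target_tokens):
    """Greedy verification: keep drafts until the first mismatch, then take the
    target model's token at that position. `target_tokens` has one more entry
    than `draft_tokens` (the target's choice after each draft position)."""
    n = 0
    for d, t in zip(draft_tokens, target_tokens):
        if d != t:
            break
        n += 1
    return draft_tokens[:n] + target_tokens[n:n + 1]


def expected_tokens_per_round(alpha, k):
    """(1 - alpha^(k+1)) / (1 - alpha): accepted drafts plus the target's own token."""
    return (1 - alpha ** (k + 1)) / (1 - alpha)


def speculative_speedup(alpha, k, draft_step_ms, target_step_ms):
    """Tokens per unit time relative to plain one-token-per-step decoding."""
    round_ms = k * draft_step_ms + target_step_ms   # k draft steps + one verify pass
    return expected_tokens_per_round(alpha, k) * target_step_ms / round_ms


print(accept_prefix([11, 12, 13], [11, 12, 99, 14]))   # [11, 12, 99]
print(f"{expected_tokens_per_round(0.7, 5):.2f}")      # 2.94
print(f"{speculative_speedup(0.7, 5, 0.076, 5.34):.1f}x")  # 2.7x
```

Even a full rejection yields one token per round, so under these assumptions the only downside is the time spent drafting.
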
Disaggregated serving takes the prefill/decode split further: run prefill on compute-optimized chips and decode on bandwidth-optimized chips. The KV cache is transferred between them. This lets each phase use hardware tuned for its bottleneck.

How disaggregated serving works in practice:

1. Request arrives
Load balancer routes to a prefill server with capacity
2. Prefill runs
Compute-bound forward pass. KV cache generated for all prompt tokens.
3. KV cache transfer
Ship KV cache over network to a decode server. Size = S × KV/token.
4. Decode loop
Token-by-token generation on bandwidth-optimized hardware. Results stream back.

Check: Why is the verification step in speculative decoding efficient?

Chapter 9: Prefill/Decode Timeline

[Interactive simulation: Inference Latency Breakdown. Shows the interplay between prefill and decode phases for a batch of requests, with latency split into parameter loading, KV cache loading, and compute. Controls: batch size (32) and sequence length (8192).]

Check: At long context (32K), what dominates the decode step latency?

Chapter 10: Takeaways

Inference is fundamentally different from training. Here are the key results:

| Property | Prefill | Decode |
|---|---|---|
| Tokens processed | S (full sequence at once) | 1 per step |
| Arithmetic intensity | S (high) | B (low) |
| Bottleneck | Compute (FLOPs) | Memory bandwidth (HBM) |
| Scales with | More FLOPs per chip | More HBM bandwidth per chip |
| Parallelism | TP or DP (large matmuls) | TP only (model parallelism) |

Key formulas:
Prefill latency: T = 2 × P × S / (N × C × U)
Decode step: T = (param_bytes + KV_bytes) / (N × W_HBM) [bandwidth-bound]
Critical batch size: B* = C / W_HBM (e.g., 240 on v5e bf16)
KV cache/token: 2 × K × H × L × sizeof(dtype)
Throughput = B / T_step

Design principles for inference systems:

Maximize batch size
Amortize weight loading across more tokens. Limited by KV cache memory.
Minimize KV cache size
Use GQA, quantize KV to int8, and avoid over-allocation with paged attention.
Separate prefill and decode
Different hardware needs. Disaggregate or at least schedule separately.
Continuous batching
Never let batch slots go empty. Replace finished sequences immediately.

Check: A serving system can either decrease per-token latency 2x OR decrease per-token cost 100x. Which is the larger improvement?