The complete guide: what happens when you hit "send," and every trick that makes it fast.
You type "Write me a poem about the ocean" into ChatGPT. Within a second, words start appearing. One by one, they stream onto the screen, assembling into something that sounds almost human. Behind that casual second of waiting: 70 billion numbers were multiplied together, a probability distribution over 128,000 possible next tokens was computed, one token was chosen from that distribution, and the whole process started over. It happened again. And again. Two hundred times before the poem was done.
Every single token — every "the," every "crashing," every comma — required the model to read its entire 140 GB weight matrix from memory. That is 200 × 140 GB = 28 terabytes of memory traffic for a single poem. And your poem was one of thousands of requests hitting that server in the same minute.
This lesson is about what happens inside the machine during that second. Not the math of transformers (that is a separate lesson), but the engineering of making transformers fast enough to be useful, cheap enough to be profitable, and scalable enough to serve billions of users.
When you hit "send," your text embarks on a six-stage pipeline. Each stage transforms the data into a different form, and each stage has different performance characteristics. Understanding this pipeline is the foundation of everything we will cover in this lesson.
That loop — transform, project, sample, append, repeat — is the beating heart of LLM inference. Everything we will learn in this lesson is about making that loop faster, cheaper, and more efficient.
The visualization below shows tokens flowing through each stage of the pipeline. Watch how each token must wait for all previous tokens to be processed before it can be generated — this is the autoregressive bottleneck that dominates inference cost.
Watch tokens flow through the inference pipeline. Each new token requires a full pass through all transformer layers.
Here is the calculation that keeps every AI company's CFO awake at night. Let us work through it step by step for a 70-billion parameter model stored in FP16 (half precision, 2 bytes per parameter).
Model size in memory: 70 billion parameters × 2 bytes per parameter = 140 GB.
An NVIDIA A100 GPU has 80 GB of HBM (High Bandwidth Memory). Our 70B model does not fit on one GPU. We need at least two A100s just to hold the weights, using tensor parallelism to split the model across GPUs.
Time to generate one token:
To generate a single token, the model must read every weight from GPU memory at least once (for the matrix multiplications in each transformer layer). The A100 has a memory bandwidth of approximately 2 TB/s. So the minimum time to read all weights is: 140 GB ÷ 2,000 GB/s = 0.07 s = 70 ms per token.
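That gives us a throughput of approximately 14 tokens per second, which is roughly the speed you see words appearing in ChatGPT. This is not a coincidence. For a single user with a large model, memory bandwidth is the bottleneck, and that 70 ms per token is a hard physical limit imposed by the speed of the wires connecting the GPU chip to its memory. (The estimate treats the model as if it sat behind a single 2 TB/s memory interface. Split across two A100s with tensor parallelism, each GPU streams only its half of the weights, so the ideal floor is closer to 35 ms before communication overhead.)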
Tokens per second: 1 ÷ 0.070 s per token ≈ 14.3 tokens/second.
An average English word is about 1.3 tokens. So 14 tokens/second is roughly 11 words per second — fast enough to feel conversational, but barely.
Let us put some real numbers on this.
| Metric | Training (one-time) | Inference (ongoing) |
|---|---|---|
| Cost per run | $50M-$100M+ for frontier models | ~$0.005-$0.06 per 1K tokens |
| Frequency | Once per model version | Millions of times per day |
| Annual cost | $100M (amortized) | $500M-$4B+ for major providers |
| Hardware utilization | ~40-60% GPU compute | ~5-15% GPU compute (memory-bound) |
| Optimization impact | 2x faster training saves months | 2x faster inference saves billions/year |
This table reveals something surprising: inference hardware is massively underutilized. An A100 can perform 312 trillion floating-point operations per second, but during autoregressive decoding, we are using only 5-15% of that compute capacity. The GPU cores are sitting idle, waiting for data to arrive from memory. This is the memory-bound regime, and it is the central challenge of LLM inference.
Here is a simple Python script that estimates inference throughput for any model and GPU combination. Try changing the parameters to build intuition for how model size and hardware bandwidth interact.
```python
def estimate_throughput(
    params_billions: float,
    bytes_per_param: float = 2.0,   # FP16 = 2, INT8 = 1, INT4 = 0.5
    bandwidth_tb_s: float = 2.0,    # A100=2.0, H100=3.35, H200=4.8
    batch_size: int = 1,
):
    """Estimate tokens/second for autoregressive decoding."""
    model_gb = params_billions * bytes_per_param  # model size in GB
    bandwidth_gb_s = bandwidth_tb_s * 1000        # convert TB/s to GB/s

    # Time to read all weights once (= time for one token, batch=1)
    time_per_token_s = model_gb / bandwidth_gb_s

    # With batching, we read weights once but produce batch_size tokens
    tokens_per_second = batch_size / time_per_token_s

    print(f"Model: {params_billions}B params, {bytes_per_param} bytes/param")
    print(f"Model memory: {model_gb:.1f} GB")
    print(f"GPU bandwidth: {bandwidth_tb_s} TB/s")
    print(f"Time per token (batch=1): {time_per_token_s*1000:.1f} ms")
    print(f"Tokens/sec (batch={batch_size}): {tokens_per_second:.1f}")
    print(f"Words/sec (approx): {tokens_per_second/1.3:.1f}")

# Example: Llama 3 70B on A100
estimate_throughput(70, 2.0, 2.0, batch_size=1)
# Model: 70B params, 2.0 bytes/param
# Model memory: 140.0 GB
# Time per token (batch=1): 70.0 ms
# Tokens/sec (batch=1): 14.3

# Same model with INT4 quantization on H100
estimate_throughput(70, 0.5, 3.35, batch_size=1)
# Model memory: 35.0 GB
# Time per token (batch=1): 10.4 ms
# Tokens/sec (batch=1): 95.7  ← 6.7x faster!

# INT4 on H100 with batch size 32
estimate_throughput(70, 0.5, 3.35, batch_size=32)
# Tokens/sec (batch=32): 3062.9  ← batching is the real multiplier
```
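Notice the two levers that matter most: quantization (reducing bytes per parameter from 2.0 to 0.5 cuts the per-token weight read by 4x) and batching (processing 32 users simultaneously yields 32x the aggregate throughput from the same weight read, although each individual user still waits just as long per token). We will explore both in depth later in this lesson. Keep in mind that this simple model ignores KV-cache reads and the compute ceiling, both of which erode the batching gain at large batch sizes.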
Inference is not just chatbots. Every AI-powered feature you use runs inference continuously:
| Application | Queries per day | Latency requirement | Why inference cost matters |
|---|---|---|---|
| ChatGPT / Claude | 100M+ | <100ms time-to-first-token | Revenue per query is cents; GPU cost per query must be less |
| GitHub Copilot | Billions of completions | <200ms to feel responsive | $10/month subscription must cover GPU cost |
| Google Search (AI Overview) | 8.5B searches/day | <500ms total | Adding LLM to search triples compute cost per query |
| AI coding agents | Long-running sessions | Throughput > latency | Agents generate 10K-100K tokens per task |
| Autonomous vehicles | Continuous | <100ms hard deadline | Late response = safety hazard |
In every case, inference is the ongoing cost center. A 2x improvement in inference efficiency is worth billions of dollars per year across the industry. This is why inference optimization is the hottest area in systems ML right now.
"Inference" sounds fancy. It is not. It means: give the model some text, get the next word. But how it picks that word — the specific sequence of operations from raw text to chosen token — is where it gets interesting. And understanding this sequence is the key to making it faster.
Let us trace every step, starting from the very beginning: the text itself.
A language model does not see words. It sees tokens — integer indices into a fixed vocabulary. You might think the obvious approach is: one word = one token. But this breaks immediately. Consider the word "unbelievably." If we assign one token per word, we need a separate vocabulary entry for "unbelievably," "unbelievable," "believably," "believe," "believing," "believed," and every other variation. English has over 170,000 words in common use. Many languages have far more (German compound words, Turkish agglutination). The vocabulary would be enormous, and rare words would barely be seen during training.
Byte-Pair Encoding (BPE) is the solution used by essentially all modern LLMs. The idea is simple and elegant: start with individual characters as tokens, then repeatedly merge the most common adjacent pair into a new token. After 32,000-128,000 merges, you have a vocabulary where common words like "the" are single tokens, common subwords like "ing" and "tion" are single tokens, and rare words are broken into familiar pieces.
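The merge procedure can be sketched in a few lines. This is a toy illustration rather than any production tokenizer: the four-word corpus and its counts are made up, the function names are hypothetical, and real implementations typically start from bytes and apply their own tie-breaking rules.

```python
from collections import Counter

def most_frequent_pair(words):
    """Count adjacent symbol pairs, weighted by how often each word occurs."""
    pairs = Counter()
    for symbols, freq in words.items():
        for a, b in zip(symbols, symbols[1:]):
            pairs[(a, b)] += freq
    # On ties, max() returns the pair that was counted first.
    return max(pairs, key=pairs.get) if pairs else None

def merge_pair(words, pair):
    """Rewrite every word so that each occurrence of `pair` becomes one symbol."""
    merged = {}
    for symbols, freq in words.items():
        out, i = [], 0
        while i < len(symbols):
            if i + 1 < len(symbols) and (symbols[i], symbols[i + 1]) == pair:
                out.append(symbols[i] + symbols[i + 1])
                i += 2
            else:
                out.append(symbols[i])
                i += 1
        key = tuple(out)
        merged[key] = merged.get(key, 0) + freq
    return merged

def train_bpe(corpus_counts, num_merges):
    """Start from single characters and apply `num_merges` greedy merges."""
    words = {tuple(word): count for word, count in corpus_counts.items()}
    merges = []
    for _ in range(num_merges):
        pair = most_frequent_pair(words)
        if pair is None:
            break
        words = merge_pair(words, pair)
        merges.append(pair)
    return merges, words

# Made-up corpus: word -> number of occurrences
toy_corpus = {"low": 5, "lower": 2, "newest": 6, "widest": 3}
merges, segmented = train_bpe(toy_corpus, num_merges=3)
print(merges)            # the learned merge rules, in order
print(list(segmented))   # each word as a tuple of current tokens
```

After three merges on this corpus the frequent suffix "est" has already become a single token, while the rarer "lower" is still mostly individual characters. That is the same effect, in miniature, that leaves "the" whole and splits rare words into pieces.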
Let us trace a concrete example. Consider the sentence: "BentoML supports custom LLM inference."
Notice several things. "BentoML" is not in the vocabulary, so it gets split into "Bent" + "o" + "ML" — three familiar subpieces. Common words like "supports" stay whole (with a leading space baked into the token). "LLM" becomes "LL" + "M" because "LLM" as a trigram was not frequent enough in the training data to earn its own token. The period is its own token.
Here is how to verify this yourself:
```python
import tiktoken

# Load the GPT-4 tokenizer
enc = tiktoken.get_encoding("cl100k_base")

text = "BentoML supports custom LLM inference."
tokens = enc.encode(text)

print(f"Text: {text}")
print(f"Token IDs: {tokens}")
print(f"Number of tokens: {len(tokens)}")
print("Decoded tokens:")
for t in tokens:
    print(f"  {t} → '{enc.decode([t])}'")

# Output:
# Text: BentoML supports custom LLM inference.
# Token IDs: [60471, 78, 2735, 11815, 7559, 16384, 44, 45478, 13]
# Number of tokens: 9
# Decoded tokens:
#   60471 → 'Bent'
#   78 → 'o'
#   2735 → 'ML'
#   11815 → ' supports'
#   ...
```
Once we have token IDs, each one is used to look up a row from the embedding matrix. This is a simple table lookup — no multiplication, no computation. The embedding matrix has shape [vocab_size × hidden_dim]. For Llama 3 70B, that is [128,256 × 8,192], and each row is a learned 8,192-dimensional vector that represents the "meaning" of that token in the model's learned space.
The embedding table for Llama 3 70B contains 128,256 × 8,192 × 2 bytes = 2.1 GB of parameters. This is a small fraction of the total 140 GB model.
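To make the "table lookup" point concrete, here is a small sketch. The three-row table and the helper names are invented for illustration; only the vocabulary size, hidden dimension, and 2-byte precision come from the text above.

```python
def embedding_table_bytes(vocab_size, hidden_dim, bytes_per_param=2):
    """Memory footprint of a [vocab_size x hidden_dim] embedding matrix."""
    return vocab_size * hidden_dim * bytes_per_param

def embed(token_ids, table):
    """Embedding is row selection: each token ID picks out one row, no arithmetic."""
    return [table[token_id] for token_id in token_ids]

# Toy table: vocabulary of 3 tokens, hidden dimension of 2 (values are arbitrary)
toy_table = [
    [0.1, 0.2],  # token 0
    [0.3, 0.4],  # token 1
    [0.5, 0.6],  # token 2
]
vectors = embed([2, 0, 2], toy_table)
print(vectors)

# Llama 3 70B figures from the text: 128,256 x 8,192 in FP16
size_bytes = embedding_table_bytes(128_256, 8_192)
print(f"{size_bytes / 1e9:.1f} GB")  # 2.1 GB
```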
The embedded vectors now pass through the transformer layers — the core of the model. Each layer performs two main operations: self-attention and a feed-forward network (FFN). We will not derive the full attention math here (see the Transformers lesson), but we need to understand the key operations at a high level to reason about inference performance.
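Each attention layer computes three projections from the input: Queries (Q), Keys (K), and Values (V), each obtained by multiplying the input by its own learned weight matrix.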
After attention, the result passes through a feed-forward network, typically two linear layers with a nonlinearity (SiLU/GELU) in between. In Llama 3 70B, this FFN has an intermediate dimension of 28,672, which is 3.5x the model dimension of 8,192, and it uses a gated (SwiGLU) variant with three weight matrices rather than two. The FFN is where the majority of each layer's parameters live.
Each of the 80 layers reads its weight matrices from GPU memory, performs the matrix multiplications, and passes the result to the next layer. The input enters as a sequence of vectors and exits as a sequence of vectors of the same shape — but now encoded with the "understanding" of the entire context.
Here is the crucial fact about LLM inference: each token depends on ALL previous tokens. The model cannot generate token 5 until it knows what tokens 1-4 are, because the attention mechanism in token 5 needs to attend to all prior tokens.
This creates a sequential bottleneck that is fundamentally different from training. During training, the model processes the entire sequence in parallel because all tokens are known upfront (from the training data). During inference, each new token must be generated one at a time, waiting for the previous token to be produced.
After the final transformer layer, the model outputs a vector of 8,192 dimensions for the last position. This vector is multiplied by the language model head (also called the "unembedding" matrix), producing a logit for every token in the vocabulary — 128,256 numbers, each indicating how likely the model thinks that token should come next.
These logits are converted to probabilities using softmax: P(token i) = exp(z_i / T) / Σ_j exp(z_j / T).
Here z_i is the logit for token i and T is the temperature parameter. Temperature reshapes the probability distribution:
| Temperature | Effect | Analogy | Use case |
|---|---|---|---|
| T → 0 | All probability mass on the most likely token | A laser beam hitting one point | Factual Q&A, code generation |
| T = 1.0 | Original model distribution | Normal confidence level | General conversation |
| T > 1.0 | Flatter distribution, more randomness | A flashlight spreading light evenly | Creative writing, brainstorming |
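The effect in the table is easy to check numerically. The sketch below is a plain-Python temperature-scaled softmax; the function name and the three logit values are made up for illustration.

```python
import math

def softmax_with_temperature(logits, temperature=1.0):
    """Divide logits by the temperature, then apply a numerically stable softmax."""
    if temperature <= 0:
        # T -> 0 is the greedy limit; handle it as an argmax rather than dividing by zero.
        raise ValueError("temperature must be positive; use argmax for greedy decoding")
    scaled = [z / temperature for z in logits]
    top = max(scaled)                            # subtract the max to avoid overflow
    exps = [math.exp(s - top) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]

logits = [2.0, 1.0, 0.0]  # hypothetical logits for three candidate tokens
for t in (0.5, 1.0, 2.0):
    probs = softmax_with_temperature(logits, t)
    print(f"T={t}: " + ", ".join(f"{p:.3f}" for p in probs))
```

Lower temperatures push probability mass toward the top token, and higher temperatures spread it across the alternatives, without ever changing the ranking of the candidates.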
Temperature alone is not enough for good generation quality. Two additional strategies filter the probability distribution before sampling:
Top-k sampling: Keep only the k most probable tokens, set all others to probability zero, and renormalize. If k=50, only the 50 most likely tokens are candidates. This prevents the model from ever picking extremely unlikely tokens (which tend to be incoherent).
Top-p (nucleus) sampling: Sort tokens by probability, then keep the smallest set whose cumulative probability reaches p. If p=0.9, you keep adding tokens (from most to least probable) until their probabilities sum to at least 0.9. This is adaptive: when the model is confident (one token has probability 0.85), top-p keeps only 2-3 candidates. When the model is uncertain (many tokens at 0.01), top-p keeps 50+.
Let us trace a concrete example:
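As one way to trace the two filters by hand, the sketch below applies them to two five-token distributions. The tokens and probabilities are invented for illustration, and the helper names are hypothetical.

```python
def top_k_filter(probs, k):
    """Keep the k most probable tokens and renormalize."""
    ranked = sorted(probs.items(), key=lambda item: item[1], reverse=True)[:k]
    total = sum(p for _, p in ranked)
    return {token: p / total for token, p in ranked}

def top_p_filter(probs, p):
    """Keep the smallest prefix of the ranked tokens whose mass reaches p, then renormalize."""
    ranked = sorted(probs.items(), key=lambda item: item[1], reverse=True)
    kept, cumulative = [], 0.0
    for token, prob in ranked:
        kept.append((token, prob))
        cumulative += prob
        if cumulative >= p:
            break
    total = sum(q for _, q in kept)
    return {token: q / total for token, q in kept}

# A confident step: one token dominates
confident = {"waves": 0.85, "tide": 0.08, "sea": 0.04, "foam": 0.02, "salt": 0.01}
# An uncertain step: probability is spread out
uncertain = {"waves": 0.25, "tide": 0.22, "sea": 0.20, "foam": 0.18, "salt": 0.15}

print(len(top_k_filter(confident, 3)), len(top_k_filter(uncertain, 3)))      # fixed: 3 and 3
print(len(top_p_filter(confident, 0.9)), len(top_p_filter(uncertain, 0.9)))  # adaptive: 2 and 5
```

Top-k keeps three candidates in both cases, while top-p shrinks to two when the model is confident and widens to all five when it is not.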
| Strategy | Deterministic? | Adaptive? | Risk |
|---|---|---|---|
| Greedy | Yes | No | Repetitive, boring outputs |
| Top-k | No | No — fixed k | k too large = incoherent; k too small = repetitive |
| Top-p | No | Yes — varies with confidence | p too high = too creative; p too low = too conservative |
| Temperature | If T→0 | No | Changes distribution shape globally |
| Top-p + T | No | Yes | Most common combination in practice |
The visualization below shows the autoregressive decode loop in action. Each step generates one token, showing the probability distribution over the top candidates. Adjust the temperature to see how it affects the distribution shape and which tokens get selected.
Each step shows the probability distribution over candidate tokens. The highlighted token is the one sampled. Adjust temperature to see how it reshapes the distribution.
How does the model know when to stop generating? Three mechanisms:
1. End-of-sequence (EOS) token. During training, every sequence ends with a special EOS token (e.g., token ID 128001 in Llama 3). When the model samples this token, generation stops. This is the "natural" stopping condition — the model decided it was done.
2. Maximum token limit. A hard cap on the number of generated tokens (e.g., max_tokens=4096 in the API). This prevents runaway generation and bounds compute cost. If the model has not produced EOS by this limit, generation is truncated.
3. Stop sequences. User-specified strings that terminate generation when they appear in the output. For example, in a chat format, you might set stop=["Human:", "User:"] so the model stops when it starts generating the next turn of dialogue.
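The three conditions can be combined into a single check that a generation loop calls after every token. This is a minimal sketch under assumed names: `stop_reason` and its parameters are hypothetical, not the API of any particular serving framework.

```python
def stop_reason(generated_ids, generated_text, eos_token_id, max_tokens, stop_sequences=()):
    """Return why generation should stop, or None if it should continue."""
    # 1. The model sampled the end-of-sequence token.
    if generated_ids and generated_ids[-1] == eos_token_id:
        return "eos"
    # 2. A user-specified stop string appeared in the decoded output.
    for stop in stop_sequences:
        if stop and stop in generated_text:
            return "stop_sequence"
    # 3. The hard cap on generated tokens was reached.
    if len(generated_ids) >= max_tokens:
        return "max_tokens"
    return None

EOS = 128001  # Llama 3 EOS ID quoted in the text
print(stop_reason([791, 128001], "The", EOS, max_tokens=4096))
print(stop_reason([791, 220], "Sure.\nHuman:", EOS, 4096, stop_sequences=["Human:", "User:"]))
print(stop_reason([1, 2, 3], "abc", EOS, max_tokens=3))
print(stop_reason([1], "a", EOS, max_tokens=3))
```

A production system would usually also trim the matched stop string from the returned text; that step is omitted here.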
Here is the entire inference algorithm in under 30 lines of Python. No optimization, no tricks: just the raw loop that every serving system is built around:
```python
import torch
import torch.nn.functional as F

def generate(model, prompt_tokens, max_new_tokens=256, temperature=1.0, top_p=0.9):
    """Minimal autoregressive generation loop (batch size 1, temperature > 0)."""
    tokens = prompt_tokens.clone()  # [1, seq_len]
    for _ in range(max_new_tokens):
        # Forward pass: logits for the LAST position (assumes model returns a [1, seq_len, vocab_size] tensor)
        logits = model(tokens)[:, -1, :]  # [1, vocab_size]

        # Apply temperature
        logits = logits / temperature

        # Apply top-p (nucleus) filtering: drop tokens once the mass BEFORE them reaches top_p
        sorted_logits, sorted_idx = torch.sort(logits, descending=True)
        sorted_probs = F.softmax(sorted_logits, dim=-1)
        cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
        mask = cumulative_probs - sorted_probs >= top_p
        sorted_logits[mask] = -float("inf")
        probs = F.softmax(sorted_logits, dim=-1)

        # Sample from the filtered distribution
        next_token_sorted = torch.multinomial(probs, num_samples=1)
        next_token = sorted_idx.gather(-1, next_token_sorted)  # map back to vocabulary IDs

        # Append and check for EOS (assumes the model exposes config.eos_token_id)
        tokens = torch.cat([tokens, next_token], dim=1)
        if next_token.item() == model.config.eos_token_id:
            break
    return tokens
```
There is a critical performance problem hiding in line 9, the forward pass `model(tokens)`. Every iteration, we pass all tokens (prompt + all generated tokens so far) through the model. If we have generated 200 tokens, the 200th iteration recomputes attention for all 200 positions, even though the first 199 positions have not changed. This is extremely wasteful. The KV cache (Chapter 4) eliminates this redundancy, and understanding it is essential to understanding modern inference systems.
Here is a paradox: processing your 500-word prompt takes 50 milliseconds. But generating a single word of the response takes 70 milliseconds. Reading 500 words is faster than writing one word. How can that be?
The answer reveals the most important distinction in LLM inference engineering: the prefill phase and the decode phase have fundamentally different computational characteristics. Prefill is compute-bound (limited by how fast the GPU can multiply). Decode is memory-bound (limited by how fast the GPU can read weights from memory). Understanding why requires diving into the shape of the matrix operations at each phase.
When your prompt arrives, all tokens are known upfront. The model can process them all in parallel. Here is what happens:
For a prompt of length S, each attention layer performs matrix multiplications of shape [S × d_model] × [d_model × d_model]. This is a matrix-matrix multiply (GEMM — General Matrix Multiply). The key property: for each weight element read from memory, we perform S multiply-accumulate operations (one for each token in the sequence). The arithmetic intensity (ratio of compute to memory access) is proportional to S.
With 500 tokens, we are well above the A100's compute/bandwidth threshold of 156 ops/byte. The GPU cores are fully busy multiplying, and memory bandwidth is not the constraint. This is why prefill is fast — the GPU is doing what it was designed to do: massive parallel matrix multiplication.
During prefill, the model also builds the KV cache: for every token at every layer, it computes and stores the Key and Value vectors. These will be reused during decoding so we do not have to recompute them. Think of it as the model "reading and remembering" your prompt.
Once the prompt is processed and the KV cache is built, generation begins. Now only ONE new token is produced at a time. Here is the critical difference:
For a single new token, each attention layer performs a matrix multiplication of shape [1 × d_model] × [d_model × d_model]. This is a matrix-vector multiply (GEMV). For each weight element read from memory, we perform only ONE multiply-accumulate operation. The arithmetic intensity is 1.
An arithmetic intensity of 1 is deep in memory-bound territory. The GPU can perform 312 trillion operations per second, but it only needs to do 2d² ≈ 134 million operations per layer (2 × 8,192²). The GPU cores execute those operations in microseconds and then sit idle, waiting for the next weight matrix to arrive from HBM. This is why decode is slow: the GPU is starving for data.
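The prefill and decode intensities can be compared against the A100 threshold with a short calculation. This sketch counts only the weight reads of a single [tokens × d] by [d × d] projection and treats each multiply-add as two operations; activation traffic and KV-cache reads are ignored, and the function names are hypothetical.

```python
def weight_arithmetic_intensity(num_tokens, bytes_per_param=2.0):
    """Operations per byte of weights read; the matrix size d cancels out."""
    ops_per_weight = 2 * num_tokens          # one multiply and one add per token
    return ops_per_weight / bytes_per_param  # each weight costs bytes_per_param to read

def ridge_point(peak_tflops, bandwidth_tb_s):
    """Intensity at which a GPU shifts from memory-bound to compute-bound."""
    return peak_tflops / bandwidth_tb_s      # (TFLOP/s) / (TB/s) = ops per byte

def regime(num_tokens, peak_tflops, bandwidth_tb_s, bytes_per_param=2.0):
    intensity = weight_arithmetic_intensity(num_tokens, bytes_per_param)
    return "compute-bound" if intensity >= ridge_point(peak_tflops, bandwidth_tb_s) else "memory-bound"

A100_TFLOPS, A100_TB_S = 312, 2.0            # figures used in the text
print(ridge_point(A100_TFLOPS, A100_TB_S))   # 156.0
print(weight_arithmetic_intensity(500), regime(500, A100_TFLOPS, A100_TB_S))  # prefill, 500 tokens
print(weight_arithmetic_intensity(1), regime(1, A100_TFLOPS, A100_TB_S))      # decode, 1 token
```

With FP16 weights the intensity equals the number of tokens processed together, which is why a 500-token prefill clears the 156 ops/byte threshold and a single-token decode step does not.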
During prefill, the model computes Key (K) and Value (V) vectors for every token at every layer and every attention head. During decoding, the new token only needs to compute its own Q, K, V vectors. But the attention mechanism requires the new token's Query to attend to the Keys of all previous tokens. Rather than recomputing those Keys and Values every time, we cache them.
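This is the KV cache, the most important data structure in LLM inference. Let us calculate exactly how much memory it consumes. Per request, the cache holds one Key and one Value vector per token, per KV head, per layer: KV cache bytes = 2 × n_layers × n_kv_heads × d_head × seq_len × bytes per element.
Let us calculate for Llama 3 70B with a 4K context: 2 × 80 layers × 8 KV heads × 128 dimensions × 4,096 tokens × 2 bytes ≈ 1.34 billion bytes, or 1.25 GB per request (counting 1 GB as 1024³ bytes).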
Notice that Llama 3 70B uses Grouped-Query Attention (GQA) with only 8 KV heads instead of 64. Without GQA (standard Multi-Head Attention), the KV cache would be 8x larger: 10 GB at 4K context, 320 GB at 128K. GQA was specifically invented to reduce KV cache size for inference efficiency.
Let us trace exactly what happens when generating the 501st token (after a 500-token prompt):
Without the KV cache, step 4 would require recomputing K and V for all 500 previous tokens at every layer — that is 500x more work. The KV cache trades memory (storing the cached vectors) for compute (avoiding the recomputation). This is the fundamental space-time trade-off of LLM inference.
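A toy single-head attention in plain Python can illustrate the trade. The weight matrices, the `KVCache` class, and the helper names below are all invented for illustration and are not taken from any library. The cached step projects one token, the uncached step re-projects the whole context, and both return the same output.

```python
import math

def project(x, w):
    """Compute x @ w for a vector x and a matrix w stored as a list of rows."""
    return [sum(xi * wi for xi, wi in zip(x, col)) for col in zip(*w)]

def attend(query, keys, values):
    """Scaled dot-product attention of one query over all cached keys and values."""
    scale = math.sqrt(len(query))
    scores = [sum(q * k for q, k in zip(query, key)) / scale for key in keys]
    top = max(scores)
    weights = [math.exp(s - top) for s in scores]
    norm = sum(weights)
    return [sum(w * v[i] for w, v in zip(weights, values)) / norm
            for i in range(len(values[0]))]

class KVCache:
    """Hypothetical per-layer cache: one K and one V vector per token seen so far."""
    def __init__(self):
        self.keys, self.values = [], []
        self.projections = 0  # how many tokens have had K/V computed

    def add(self, x, wk, wv):
        self.keys.append(project(x, wk))
        self.values.append(project(x, wv))
        self.projections += 1

def step_with_cache(x_new, cache, wq, wk, wv):
    cache.add(x_new, wk, wv)              # only the newest token is projected
    return attend(project(x_new, wq), cache.keys, cache.values)

def step_without_cache(xs, wq, wk, wv):
    keys = [project(x, wk) for x in xs]   # every token is projected again
    values = [project(x, wv) for x in xs]
    return attend(project(xs[-1], wq), keys, values), len(xs)

# Toy 2-dimensional weights and a 3-token context (values are arbitrary)
WQ = [[1.0, 0.0], [0.0, 1.0]]
WK = [[0.5, 0.1], [0.2, 0.7]]
WV = [[0.3, 0.9], [0.8, 0.4]]
context = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]

cache = KVCache()
for x in context[:-1]:
    cache.add(x, WK, WV)                  # "prefill": build the cache once
before = cache.projections
cached_out = step_with_cache(context[-1], cache, WQ, WK, WV)
projected_in_step = cache.projections - before
uncached_out, recomputed = step_without_cache(context, WQ, WK, WV)

print(projected_in_step, recomputed)      # 1 projection with the cache, 3 without
```

Scaled up to a 500-token context, the same gap becomes 1 projection versus roughly 500 per layer per step, paid for with the memory needed to hold the cached vectors.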
The first visualization shows the prefill and decode phases side by side. Watch how GPU utilization differs between the two phases — prefill keeps the compute units busy, while decode leaves them mostly idle.
Compare GPU utilization between the compute-bound prefill phase and the memory-bound decode phase. Adjust prompt and output length to see how they affect total latency.
This second visualization shows how the KV cache grows linearly as tokens are generated. For long sequences, the KV cache can consume more memory than the model weights themselves.
Watch GPU memory fill as the context grows. The KV cache grows linearly with sequence length, eventually dominating memory usage for long contexts.
In production, the server must handle a mix of requests simultaneously. Some requests are in the prefill phase (a new prompt just arrived), while others are in the decode phase (generating output tokens for an earlier request). These two phases compete for resources:
| Resource | Prefill wants | Decode wants | Conflict |
|---|---|---|---|
| GPU compute | All of it — large matrix multiplies | Very little — tiny matrix-vector multiplies | Prefill starves decode of scheduling slots |
| Memory bandwidth | Some — reads weights once for parallel batch | All of it — bottleneck is reading weights for each token | Prefill bursts compete with decode's steady reads |
| GPU memory | Temporary activations scale with prompt length | KV cache grows with generated length | Both need HBM; peak usage is hard to predict |
This conflict is why some advanced serving systems use prefill-decode disaggregation: they run prefill on one set of GPUs and decode on a different set, each optimized for its workload. Prefill GPUs maximize compute utilization; decode GPUs maximize memory bandwidth utilization. We will return to this in Chapter 9.
```python
def kv_cache_size_gb(
    n_layers: int,
    n_kv_heads: int,
    d_head: int,
    seq_len: int,
    bytes_per_param: float = 2.0,
    batch_size: int = 1,
) -> float:
    """Calculate KV cache memory in GB (binary: 1 GB = 1024**3 bytes)."""
    # 2 for K and V, times layers, heads, dim, sequence, bytes
    total_bytes = 2 * n_layers * n_kv_heads * d_head * seq_len * bytes_per_param * batch_size
    return total_bytes / (1024 ** 3)

# Llama 3 8B (GQA: 8 KV heads)
print(f"Llama 3 8B, 4K ctx: {kv_cache_size_gb(32, 8, 128, 4096):.2f} GB")
# → 0.50 GB

# Llama 3 70B (GQA: 8 KV heads)
print(f"Llama 3 70B, 4K ctx: {kv_cache_size_gb(80, 8, 128, 4096):.2f} GB")
# → 1.25 GB

# Llama 3 70B at 128K context
print(f"Llama 3 70B, 128K ctx: {kv_cache_size_gb(80, 8, 128, 131072):.2f} GB")
# → 40.00 GB, larger than many models!

# GPT-4 class (estimated: 120 layers, 128 KV heads with MHA)
print(f"GPT-4 class, 8K ctx: {kv_cache_size_gb(120, 128, 128, 8192):.2f} GB")
# → 60.00 GB, per request!

# The power of GQA: Llama 3 70B with MHA (64 KV heads) instead of GQA (8)
print(f"70B with MHA, 4K ctx: {kv_cache_size_gb(80, 64, 128, 4096):.2f} GB")
# → 10.00 GB, 8x more than GQA!
```
Your laptop CPU can multiply two numbers in 0.3 nanoseconds. An H100 GPU takes 0.5 nanoseconds for the same multiply. So why is the GPU 100x faster at inference? Because it does 16,896 multiplies at the same time. The secret is not speed — it is parallelism.
But parallelism alone is not enough. The numbers being multiplied have to arrive at the processing cores fast enough to keep them busy. This is the memory bandwidth problem, and it is the single most important hardware concept for understanding LLM inference performance. In this chapter, we will build a mental model of the hardware landscape from the ground up.
A modern CPU (say, an AMD EPYC 9654 with 96 cores) is a marvel of general-purpose engineering. Each core has deep instruction pipelines, branch predictors, speculative execution, large caches, and support for arbitrary instruction sequences. A single core can execute complex control flow, random memory accesses, and irregular computations efficiently.
But "efficiently" means "well for a single sequential task." For LLM inference, the math is thousands of identical multiply-accumulate operations across millions of weight elements. A CPU's branch predictor, out-of-order execution engine, and speculative execution hardware are all wasted — there are no branches to predict, no instructions to reorder, nothing to speculate on. Just multiply and add, millions of times.
CPUs are used for inference in three scenarios: tiny models (under 1B parameters), prototyping, and edge deployment on devices without GPUs. For anything serious, GPUs dominate.
A GPU flips the CPU's trade-off. Instead of a few powerful cores, it has thousands of simple cores, each capable of one multiply-accumulate per clock cycle. An NVIDIA H100 SXM has 16,896 CUDA cores and 528 Tensor Cores. The Tensor Cores are specialized units that perform small matrix multiplies (4x4 or 8x4 blocks) in a single operation, enabling enormous throughput on the exact kind of math that LLMs need.
The GPU's architecture maps perfectly to matrix multiplication. Consider multiplying a [4096 × 4096] weight matrix by a [4096 × 1] input vector (one decode step of one attention layer). That is 4096 × 4096 = 16.7 million multiply-adds. The GPU assigns chunks of the output to different streaming multiprocessors (SMs), each SM splits its chunk across its CUDA cores, and the entire operation completes in microseconds.
But here is the catch. Those 16.7 million multiply-adds require reading the 4096 × 4096 weight matrix from memory, which is about 33.6 million bytes (32 MiB) in FP16. The H100 has 3.35 TB/s of HBM bandwidth, so reading it takes 33.6 MB / 3,350 GB/s ≈ 10 microseconds. Counting each multiply-add as two floating-point operations, the computation itself takes approximately 33.6M / 990 TFLOPS ≈ 0.034 microseconds. The compute is roughly 300x faster than the memory read, so the cores are idle about 99.7% of the time during decode. This is the memory-bound regime we discussed in Chapter 2, now quantified at the hardware level.
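The comparison can be reproduced with a few lines of arithmetic. This sketch counts each multiply-add as two floating-point operations, which is an assumption about convention (it matches how peak TFLOPS figures are usually quoted); counting one operation per multiply-add would halve the compute time. The function name is hypothetical, and the peak numbers are the H100 figures from the text.

```python
def gemv_times_us(rows, cols, bytes_per_param, bandwidth_tb_s, peak_tflops):
    """Ideal memory-read time and compute time, in microseconds, for one matrix-vector multiply."""
    weight_bytes = rows * cols * bytes_per_param
    flops = 2 * rows * cols                          # multiply + add per weight
    memory_us = weight_bytes / (bandwidth_tb_s * 1e12) * 1e6
    compute_us = flops / (peak_tflops * 1e12) * 1e6
    return memory_us, compute_us

memory_us, compute_us = gemv_times_us(4096, 4096, 2, bandwidth_tb_s=3.35, peak_tflops=990)
ratio = memory_us / compute_us
idle_fraction = 1 - compute_us / memory_us

print(f"memory read: {memory_us:.2f} us")
print(f"compute:     {compute_us:.4f} us")
print(f"ratio:       {ratio:.0f}x, cores idle {idle_fraction:.1%} of the time")
```

Both times are ideal lower bounds: real kernels reach neither peak bandwidth nor peak FLOPS, but the gap of more than two orders of magnitude is what matters.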
Tensor Processing Units (TPUs) are custom ASICs designed by Google specifically for neural network workloads. Unlike GPUs, which evolved from graphics rendering and still carry that legacy (texture units, rasterizers), TPUs were built from scratch for matrix multiplication.
The key architectural difference: TPUs use a systolic array architecture. Imagine a 2D grid of multiply-accumulate units. Data flows through the grid in waves — weights enter from one side, activations enter from the top, and partial sums accumulate as data passes through each cell. This design maximizes data reuse: each weight is read once from memory and used across an entire row of the array, and each activation is used across an entire column.
TPU v5e (2023) has 128x128 systolic arrays operating in BF16, delivering approximately 197 TFLOPS per chip. They power Google's internal inference for Gemini, Search AI Overviews, and many other products. However, TPUs are only available through Google Cloud (you cannot buy them), and they work best with JAX or TensorFlow (PyTorch support exists but is less mature).
| Property | CPU | GPU (H100) | TPU (v5e) |
|---|---|---|---|
| Design purpose | General-purpose sequential tasks | Massively parallel numeric computation | Matrix multiplication (neural networks) |
| Parallelism | 8-96 cores (wide SIMD) | 16,896 CUDA + 528 Tensor Cores | 128×128 systolic array |
| Memory type | DDR5 DRAM | HBM3 | HBM2e |
| Memory capacity | 128-2048 GB | 80 GB | 16 GB |
| Memory bandwidth | ~100-300 GB/s | 3,350 GB/s | 819 GB/s |
| Peak compute (BF16) | ~5 TFLOPS | 990 TFLOPS | 197 TFLOPS |
| Power | 200-400W | 700W | ~100W |
| Software ecosystem | Everything | CUDA, PyTorch, TensorRT | JAX, TensorFlow, XLA |
| Availability | Everywhere | Cloud + purchase | Google Cloud only |
| Best for inference when | Tiny models, edge, no GPU | Almost always | Google-scale, JAX stack |
This is the single most important concept in LLM inference hardware. The GPU has multiple levels of memory, each faster and smaller than the last. Understanding this hierarchy explains why FlashAttention works, why batching helps, why quantization helps, and why the roofline model matters.
Nested boxes showing each memory level with bandwidth and capacity. The 10x jump between levels is what makes memory-aware algorithms like FlashAttention possible. Hover over each level to see detailed stats.
Let us quantify each level for an H100 SXM:
| Level | Capacity | Bandwidth | Latency | Time to read 1 GB |
|---|---|---|---|---|
| Registers | ~256 KB per SM | >100 TB/s (estimated) | <1 cycle (<1 ns) | — |
| L1 / Shared Memory (SRAM) | 256 KB per SM, ~33 MB total | ~33 TB/s aggregate | ~20-30 cycles (~20 ns) | 0.03 ms |
| L2 Cache | 50 MB | ~12 TB/s | ~200 cycles (~150 ns) | 0.08 ms |
| HBM3 (Main GPU Memory) | 80 GB | 3.35 TB/s | ~400-600 ns | 0.30 ms |
| CPU DRAM (via PCIe/NVLink) | 100s of GB | ~64 GB/s (PCIe 5.0 x16) | ~1-10 μs | 15.6 ms |
| NVMe SSD | TBs | ~7 GB/s | ~10-100 μs | 143 ms |
The critical observation: bandwidth drops sharply at each step away from the compute units, and capacity grows in exchange. On-chip SRAM is about 10x faster than HBM, and HBM is about 50x faster than PCIe. Each jump represents a physical constraint: closer memory is faster but more expensive per bit, so there is less of it.
The roofline model is a visual framework for understanding whether a given computation is limited by compute or memory bandwidth. It plots achievable performance (FLOPS) as a function of arithmetic intensity (FLOPs per byte of memory accessed).
If your operation's arithmetic intensity is below 295 ops/byte, you are memory-bound — performance is limited by how fast you can read data. If it is above 295, you are compute-bound — the GPU cores are the bottleneck.
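The 295 ops/byte figure is the H100's peak compute divided by its memory bandwidth (990 TFLOPS / 3.35 TB/s). A minimal roofline sketch follows; the function names are hypothetical, and the batch sizes assume the rough rule that FP16 decode intensity equals the batch size.

```python
def ridge_point(peak_tflops, bandwidth_tb_s):
    """Arithmetic intensity (ops/byte) where the memory roof meets the compute roof."""
    return peak_tflops / bandwidth_tb_s

def attainable_tflops(intensity, peak_tflops, bandwidth_tb_s):
    """Roofline: performance is capped by bandwidth x intensity or by peak compute."""
    return min(peak_tflops, intensity * bandwidth_tb_s)

def regime(intensity, peak_tflops, bandwidth_tb_s):
    return "compute-bound" if intensity >= ridge_point(peak_tflops, bandwidth_tb_s) else "memory-bound"

H100_TFLOPS, H100_TB_S = 990, 3.35  # H100 SXM figures used in the text
print(f"ridge point: {ridge_point(H100_TFLOPS, H100_TB_S):.1f} ops/byte")
for intensity in (1, 32, 256, 512, 4096):
    perf = attainable_tflops(intensity, H100_TFLOPS, H100_TB_S)
    print(f"{intensity:>5} ops/byte: {perf:7.1f} TFLOPS  ({regime(intensity, H100_TFLOPS, H100_TB_S)})")
```

Below the ridge point, attainable performance scales linearly with intensity, which is why raising the batch size during decode pays off almost one-for-one until the compute roof is reached.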
Let us plot where different inference operations land:
| Operation | Arithmetic intensity | Regime | Implication |
|---|---|---|---|
| Decode (batch=1) | ~1 op/byte | Deeply memory-bound | GPU cores idle 99%+ of the time |
| Decode (batch=32) | ~32 ops/byte | Memory-bound | Better utilization, but still below threshold |
| Decode (batch=256) | ~256 ops/byte | Near crossover | Approaching compute saturation |
| Prefill (512 tokens) | ~512 ops/byte | Compute-bound | GPU cores are the bottleneck |
| Prefill (4K tokens) | ~4096 ops/byte | Deeply compute-bound | Memory is not a factor |
| Attention (naive) | ~1 op/byte | Memory-bound | Why FlashAttention matters |
| Attention (FlashAttention) | ~100 ops/byte | Much closer to threshold | Keeps data in SRAM, avoids HBM round-trips |
Not all GPUs are created equal for inference. Here is the landscape as of 2025:
| Category | Example | VRAM | BW (TB/s) | BF16 TFLOPS | Price | Best for |
|---|---|---|---|---|---|---|
| Consumer | RTX 4090 | 24 GB | 1.01 | 165 | $1,600 | 7B-13B models, hobbyist |
| Consumer | RTX 5090 | 32 GB | 1.79 | 209 | $2,000 | 13B-30B quantized models |
| Workstation | RTX 6000 Ada | 48 GB | 0.96 | 91 | $6,800 | 34B models, small batch |
| Data center | A100 SXM | 80 GB | 2.04 | 312 | ~$15K | 70B models (2 GPU), standard |
| Data center | H100 SXM | 80 GB | 3.35 | 990 | ~$30K | 70B-400B, highest throughput |
| Data center | H200 | 141 GB | 4.80 | 990 | ~$35K | Large models, long contexts (more VRAM) |
| Data center | B200 | 192 GB | 8.00 | 2250 | ~$40K | Next-gen, FP4 support |
For inference specifically, the three numbers that matter most are:
1. VRAM (GB): Determines which models fit. A 70B FP16 model needs 140 GB for weights alone: two H100s or one H200. INT4 quantization cuts this to 35 GB, which fits on one H100 but not on a single 32 GB RTX 5090.
2. Memory bandwidth (TB/s): Determines decode speed (tokens/second for batch=1). More bandwidth = faster token generation.
3. Compute (TFLOPS): Determines prefill speed and throughput at large batch sizes.
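To make the bandwidth point concrete, here is a rough upper bound on batch=1 decode speed. It assumes every generated token requires reading all weights once and ignores KV cache reads, kernel overheads, and whether the model actually fits on the card. The function name is hypothetical.

```python
def decode_tokens_per_sec(model_gb: float, bandwidth_tbs: float) -> float:
    """Upper bound on batch=1 decode speed: one full weight read per token."""
    bandwidth_gb_s = bandwidth_tbs * 1000.0
    return bandwidth_gb_s / model_gb


# 70B weights at FP16 (140 GB) vs INT4 (35 GB) at H100 bandwidth (3.35 TB/s)
fp16_rate = decode_tokens_per_sec(140.0, 3.35)
int4_rate = decode_tokens_per_sec(35.0, 3.35)
print(f"FP16: {fp16_rate:.1f} tok/s, INT4: {int4_rate:.1f} tok/s")
```

Under this simplified model, quartering the bytes per weight quadruples the ceiling on decode speed, which is why quantization helps latency even when compute is unchanged.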
The formula is straightforward. Total GPU memory required is the sum of model weights, KV cache, and activation memory:
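Written out, under the same assumptions as the estimator in this section (the symbol names are chosen here for illustration: $P$ parameters, $b_w$ weight bits, $L$ layers, $H_{kv}$ KV heads, $d_{head}$ head dimension, $S$ context length, $B$ batch size, $b_{kv}$ KV cache bits):

$$
M_{\text{total}} = M_{\text{weights}} + M_{\text{KV}} + M_{\text{act}}, \qquad
M_{\text{weights}} = P \cdot \frac{b_w}{8}, \qquad
M_{\text{KV}} = 2 \cdot L \cdot H_{kv} \cdot d_{head} \cdot S \cdot B \cdot \frac{b_{kv}}{8}
$$

The factor of 2 in $M_{\text{KV}}$ counts keys and values separately. The activation term $M_{\text{act}}$ is only a rough allowance and depends on the serving engine.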
Let us work through several configurations:
```python
import math


def estimate_gpu_memory(
    params_b: float,        # billions of parameters
    quant_bits: int = 16,   # 4, 8, or 16
    n_layers: int = 80,     # transformer layers
    n_kv_heads: int = 8,    # KV heads (GQA)
    d_head: int = 128,      # dim per head
    d_model: int = 8192,    # hidden dimension
    seq_len: int = 4096,    # context length
    batch_size: int = 1,
    kv_bits: int = 16,      # KV cache precision
):
    """Estimate total GPU memory for inference (decimal GB throughout)."""
    # Model weights: billions of params x bytes per param
    weight_gb = params_b * (quant_bits / 8)

    # KV cache: 2 (K and V) x layers x KV heads x head dim x tokens x bytes
    kv_bytes = 2 * n_layers * n_kv_heads * d_head * seq_len * (kv_bits / 8) * batch_size
    kv_gb = kv_bytes / 1e9

    # Activations (rough estimate: ~4x one layer's FP16 activations)
    act_bytes = batch_size * seq_len * d_model * 2 * 4
    act_gb = act_bytes / 1e9

    total_gb = weight_gb + kv_gb + act_gb

    print(f"{'=' * 50}")
    print(f"Model: {params_b}B params, {quant_bits}-bit quantization")
    print(f"Context: {seq_len} tokens, batch size {batch_size}")
    print(f"{'=' * 50}")
    print(f"  Weights:     {weight_gb:7.2f} GB")
    print(f"  KV cache:    {kv_gb:7.2f} GB")
    print(f"  Activations: {act_gb:7.2f} GB")
    print(f"  Total:       {total_gb:7.2f} GB")
    print()

    # GPU fit check
    gpus = {"RTX 4090": 24, "RTX 5090": 32, "A100": 80,
            "H100": 80, "H200": 141, "B200": 192}
    for name, vram in gpus.items():
        n_gpus = max(1, math.ceil(total_gb / vram))  # true ceiling division
        fits = "✓" if n_gpus == 1 else f"needs {n_gpus}x"
        print(f"  {name:10s} ({vram}GB): {fits}")
    return total_gb


# Examples (positional order: params_b, quant_bits, n_layers, n_kv_heads,
# d_head, d_model, seq_len)
estimate_gpu_memory(8, 16, 32, 8, 128, 4096, 4096)     # Llama 3 8B FP16
estimate_gpu_memory(70, 4, 80, 8, 128, 8192, 4096)     # Llama 3 70B INT4
estimate_gpu_memory(70, 16, 80, 8, 128, 8192, 131072)  # 70B FP16, 128K ctx
```
Let us trace exactly what happens to arithmetic intensity as we increase batch size during decode, using real H100 numbers:
At batch size 1, we read 128 MB of weights and produce 1 token. At batch size 512, we read the same 128 MB and produce 512 tokens. The memory cost is amortized over the batch, and the GPU finally gets to do what it was built for: compute.
But there is a limit. Each user in the batch needs their own KV cache. At 4K context with Llama 3 70B (80 layers, 8 KV heads, head dimension 128, FP16), that is 2 × 80 × 8 × 128 × 4,096 × 2 bytes ≈ 1.34 GB per user. A batch of 512 would need about 687 GB of KV cache alone, far more than any single GPU's VRAM. In practice, batch sizes of 32-128 are common for large models, balancing GPU utilization against KV cache memory.
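The KV cache ceiling on batch size can be estimated directly. This is a minimal sketch with hypothetical function names; it assumes FP16 KV entries, a fixed context length for every request, and no memory reserved for activations or fragmentation.

```python
def kv_bytes_per_token(n_layers: int, n_kv_heads: int, d_head: int, kv_bits: int = 16) -> int:
    """Bytes of KV cache one token occupies across all layers (K and V)."""
    return 2 * n_layers * n_kv_heads * d_head * kv_bits // 8


def max_batch(vram_gb: float, weights_gb: float, seq_len: int, **arch) -> int:
    """Largest batch whose KV cache fits in the VRAM left after weights."""
    free_bytes = (vram_gb - weights_gb) * 1e9
    per_request = kv_bytes_per_token(**arch) * seq_len
    return max(0, int(free_bytes // per_request))


# Llama 3 70B shape (80 layers, 8 KV heads, head dim 128), INT4 weights (35 GB),
# one 80 GB GPU, 4K context per request
llama70b = {"n_layers": 80, "n_kv_heads": 8, "d_head": 128}
print(kv_bytes_per_token(**llama70b))          # bytes per token
print(max_batch(80, 35, 4096, **llama70b))     # concurrent 4K-context requests
```

With these numbers the cache, not compute, caps the batch in the low tens, which is consistent with the 32-128 range quoted above once multi-GPU setups or shorter contexts are considered.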
Your CEO asks: "Is our chatbot fast?" You check average latency: 2 seconds. Looks fine. But you dig deeper and find that the P99 latency is 45 seconds — 1 in 100 users waits nearly a minute. Which number do you report?
Both. And neither alone. Inference performance is multidimensional. A single number can hide catastrophic user experiences or mask excellent tail behavior. This chapter builds your full vocabulary for measuring LLM inference speed — the metrics that actually matter when you're on-call at 2am and the pager is screaming.
TTFT (Time to First Token) measures the elapsed time from when a request is submitted to when the first output token is generated. It captures the entire prefill phase: tokenizing the prompt, computing KV caches for every input token, and producing the first logit distribution.
Why does TTFT matter? Because it's what the user feels. When you type a question and stare at a blank screen, that silence is TTFT. Even if the rest of the response streams smoothly, a long TTFT makes the system feel broken.
TTFT scales with prompt length because the prefill phase processes all input tokens. A 100-token prompt prefills in maybe 50ms. A 100K-token document? Several seconds. This is why prompt caching (reusing KV caches from repeated system prompts) slashes TTFT for production systems.
E2EL (End-to-End Latency) is the total wall-clock time from request submission to the final token. It includes everything: network transit, queuing, prefill, and every single decode step.
E2EL = TTFT + Token Generation Time, where Token Generation Time is the total time spent generating all output tokens after the first one. If your model generates 200 tokens at 40ms each, the generation time is roughly 200 × 40ms = 8000ms. Add a 300ms TTFT and you get E2EL ≈ 8.3 seconds.
E2EL matters most for non-streaming use cases — API calls that return a complete response, batch processing, or any pipeline where you need the full output before proceeding.
TPOT (Time Per Output Token) is the average time between consecutive output tokens, excluding TTFT. It captures the steady-state decode speed: TPOT = (E2EL - TTFT) / (output_tokens - 1).
Why output_tokens - 1? Because we're measuring intervals. If you generate 5 tokens, there are 4 gaps between them. Think of fence posts: 5 posts, 4 gaps.
Let's work a real example. A request has TTFT = 200ms, E2EL = 2200ms, and 50 output tokens, so TPOT = (2200 - 200) / (50 - 1) ≈ 40.8ms.
At 40.8ms per token, that's about 24.5 tokens per second. Comfortable reading speed is roughly 4-5 words per second, and with an average of ~1.3 tokens per word, you need about 5-7 tokens/second to keep up with reading. So 24.5 tok/s is roughly 4× faster than a human reads — the stream will feel instant.
ITL (Inter-Token Latency) is the exact measured pause between two consecutive tokens. It's not an average — it's the raw per-gap measurement.
For a single request, average ITL equals TPOT. They're the same number. The distinction only matters when you're looking across multiple requests:
This matters in practice. Suppose you have two requests: one generates 10 tokens (ITL = 30ms each) and another generates 1000 tokens (ITL = 50ms each). The average TPOT across requests is (30+50)/2 = 40ms. But the average ITL across all intervals is (10×30 + 1000×50) / 1010 ≈ 49.8ms. Long requests dominate ITL statistics.
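The two averages in the example above differ only in what gets weighted equally: requests or intervals. A small sketch (function names are illustrative, and each request is represented simply as a list of its inter-token gaps in milliseconds):

```python
def mean_tpot_across_requests(requests):
    """Each request counts once: average the per-request mean gap."""
    per_request = [sum(gaps) / len(gaps) for gaps in requests]
    return sum(per_request) / len(per_request)


def mean_itl_across_intervals(requests):
    """Each interval counts once: pool every gap from every request."""
    all_gaps = [gap for gaps in requests for gap in gaps]
    return sum(all_gaps) / len(all_gaps)


short_request = [30.0] * 10     # 10 gaps of 30 ms
long_request = [50.0] * 1000    # 1000 gaps of 50 ms
requests = [short_request, long_request]

tpot_mean = mean_tpot_across_requests(requests)
itl_mean = mean_itl_across_intervals(requests)
print(f"mean TPOT across requests: {tpot_mean:.1f} ms")
print(f"mean ITL across intervals: {itl_mean:.1f} ms")
```

Which one to report depends on the question: per-request TPOT describes what a typical user sees, while pooled ITL describes what a typical token experiences.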
A mean (average) is what your dashboard shows. A median (P50) is what a typical user experiences. A P99 is what your angriest user experiences.
All three are needed. Here's why:
| Metric | Value | What It Tells You |
|---|---|---|
| Mean TTFT | 180ms | Average looks great |
| P50 TTFT | 120ms | Typical user is happy |
| P99 TTFT | 4200ms | 1 in 100 users waits 4+ seconds — they'll leave |
The mean is dragged up by those P99 outliers, which is why it's higher than the median. But neither the mean nor median warns you about the tail. Always report P50 and P99. If P99/P50 > 10×, you have a tail latency problem.
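A minimal sketch of the check described above, using a nearest-rank percentile so it needs no dependencies. The sample data is synthetic and chosen to mirror the table (most requests at 120 ms, a 2% tail at 4200 ms); the function names and the 10× threshold parameter are illustrative.

```python
import math


def percentile(values, p):
    """Nearest-rank percentile for p in (0, 100]."""
    ordered = sorted(values)
    rank = max(1, math.ceil(p * len(ordered) / 100))
    return ordered[rank - 1]


def tail_report(ttft_ms, ratio_threshold=10.0):
    """Summarize mean, P50, P99 and flag a tail-latency problem."""
    p50 = percentile(ttft_ms, 50)
    p99 = percentile(ttft_ms, 99)
    return {
        "mean": sum(ttft_ms) / len(ttft_ms),
        "p50": p50,
        "p99": p99,
        "tail_problem": p99 / p50 > ratio_threshold,
    }


samples = [120.0] * 98 + [4200.0] * 2   # synthetic TTFT samples in ms
report = tail_report(samples)
print(report)
```

Note how two slow requests out of a hundred lift the mean well above the median while leaving P50 untouched, which is the pattern the table illustrates.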
RPS (Requests Per Second) counts completed requests per second. It's coarse — a 10-token response and a 2000-token response both count as one request. But it's simple and useful for capacity planning.
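TPS (Tokens Per Second) is finer-grained. Split it into input TPS (prompt tokens processed per second during prefill) and output TPS (tokens generated per second during decode).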
Input TPS is typically much higher than output TPS because prefill is parallelizable — you process all prompt tokens in one forward pass. Decode generates one token per request per step.
Raw throughput without latency constraints is meaningless. A system can achieve 10,000 tok/s by batching so aggressively that every request takes 30 seconds. Is that good? Not if your SLO says TPOT < 100ms.
An SLO (Service Level Objective) is a target you set for acceptable performance. For example: "P99 TTFT < 2s AND P99 TPOT < 100ms." Any request that violates the SLO is a failure, even if it eventually completes.
Goodput is throughput that meets your SLO: the number of requests per second that both completed and met every latency target. Dividing goodput by total throughput gives the goodput rate, the fraction of requests that passed.
In practice, you graph goodput as a function of offered load. At low load, goodput = throughput (everything meets SLO). As load increases, latency rises, SLO violations start, and goodput plateaus or drops. The knee of that curve is your system's true capacity.
```python
import numpy as np


def compute_metrics(ttft_ms, token_timestamps_ms):
    """Compute all inference metrics from raw timing data.

    Args:
        ttft_ms: Time to first token in milliseconds.
        token_timestamps_ms: List of timestamps for each output token
            (first entry = first token time).
    """
    n_tokens = len(token_timestamps_ms)
    if n_tokens < 2:
        # No intervals to measure; keep the same key names as the full result
        return {"ttft_ms": ttft_ms, "e2el_ms": ttft_ms, "tpot_ms": 0.0}

    # Inter-token latencies: gap between consecutive tokens
    itls = np.diff(token_timestamps_ms)

    e2el = token_timestamps_ms[-1]  # last token timestamp = total time
    gen_time = e2el - ttft_ms
    tpot = gen_time / (n_tokens - 1)

    # float(...) keeps the printed dict free of NumPy scalar wrappers
    return {
        "ttft_ms": ttft_ms,
        "e2el_ms": round(float(e2el), 2),
        "tpot_ms": round(float(tpot), 2),
        "tokens_per_sec": round(1000 / float(tpot), 1),
        "mean_itl_ms": round(float(np.mean(itls)), 2),
        "p50_itl_ms": round(float(np.percentile(itls, 50)), 2),
        "p99_itl_ms": round(float(np.percentile(itls, 99)), 2),
    }


# Example: 50 tokens, TTFT=200ms, E2EL=2200ms
# Simulate evenly spaced tokens after TTFT
ttft = 200
n = 50
e2el = 2200
tpot_approx = (e2el - ttft) / (n - 1)  # ~40.8ms
timestamps = [ttft + i * tpot_approx for i in range(n)]

result = compute_metrics(ttft, timestamps)
print(result)
# {'ttft_ms': 200, 'e2el_ms': 2200.0, 'tpot_ms': 40.82,
#  'tokens_per_sec': 24.5, 'mean_itl_ms': 40.82,
#  'p50_itl_ms': 40.82, 'p99_itl_ms': 40.82}
```
```python
import random


def compute_goodput(requests, slo_ttft_ms, slo_tpot_ms):
    """Compute goodput: fraction of requests meeting SLO.

    Args:
        requests: List of dicts with 'ttft_ms' and 'tpot_ms' keys.
        slo_ttft_ms: Max acceptable TTFT.
        slo_tpot_ms: Max acceptable TPOT.
    """
    total = len(requests)
    passing = sum(
        1 for r in requests
        if r["ttft_ms"] <= slo_ttft_ms and r["tpot_ms"] <= slo_tpot_ms
    )
    # The *_rps values are counts; they equal requests per second only
    # when `requests` covers a one-second window.
    return {
        "total_rps": total,
        "goodput_rps": passing,
        "goodput_rate": round(passing / total, 3) if total else 0.0,
        "slo_violations": total - passing,
    }


# Example: 100 requests, SLO = TTFT<500ms, TPOT<100ms
requests = []
for _ in range(100):
    # Simulate: 90% of requests are fast, 10% are slow
    if random.random() < 0.9:
        requests.append({"ttft_ms": random.uniform(100, 300),
                         "tpot_ms": random.uniform(30, 80)})
    else:
        requests.append({"ttft_ms": random.uniform(800, 5000),
                         "tpot_ms": random.uniform(100, 300)})

result = compute_goodput(requests, slo_ttft_ms=500, slo_tpot_ms=100)
print(result)
# ~{'total_rps': 100, 'goodput_rps': 90, 'goodput_rate': 0.9,
#   'slo_violations': 10}  (exact counts vary from run to run)
```
| Strategy | Effect on Latency | Effect on Throughput | When to Use |
|---|---|---|---|
| Increase batch size | ↑ Higher (more queuing) | ↑ Higher (GPU utilized) | Throughput-sensitive batch jobs |
| Decrease batch size | ↓ Lower (less queuing) | ↓ Lower (GPU underused) | Latency-sensitive chat apps |
| Quantize model (INT8) | ↓ Lower (fewer bytes read per decode step) | ↑ Higher (more fits in memory) | Almost always beneficial |
| Tensor parallelism | ↓ Lower (split across GPUs) | Neutral or slight ↓ | Single-request latency critical |
| More GPUs (replicas) | Neutral per-request | ↑ Linear scaling | Scale throughput without latency cost |
| Speculative decoding | ↓ Lower (multi-token steps) | Neutral or slight ↑ | Autoregressive decode-bound cases |
Watch tokens stream in real time. The counters show TTFT, ITL, TPOT, and E2EL updating live. Increase batch load to see how metrics degrade under contention.
X-axis is batch size. Left Y-axis (teal) shows output tokens/sec. Right Y-axis (warm) shows TPOT. The shaded region below the SLO line is your goodput zone — throughput that actually meets quality targets.
A production LLaMA 70B serving system handles a request with these raw timings:
Let's compute every metric:
If the SLO requires TTFT < 500ms and TPOT < 100ms, this request passes both targets. It contributes to goodput. But under heavy batch load, that same request might see TTFT = 3200ms (prefill queued behind other requests), and suddenly it's an SLO violation despite generating tokens at the same speed.
Imagine a restaurant where the kitchen won't serve any table until every table's food is ready. Table 1 ordered a salad (2 minutes). Table 4 ordered a slow-roasted duck (45 minutes). Table 1 waits 43 minutes for a salad. That's static batching.
The restaurant analogy isn't just cute — it's structurally identical to the problem. Requests arrive at different times, require different amounts of work, and finish at different times. How you schedule them determines whether your GPU sits idle or stays maximally utilized.
A single LLM decode step for one request looks like this: load 140GB of model weights from HBM into compute units, multiply by one token's hidden state, produce one output token. The GPU performs maybe 1 floating-point operation for every byte it reads from memory — an arithmetic intensity of ~1 op/byte. That's absurdly low. An A100 can do 312 TFLOPS but only read at 2 TB/s, giving a machine balance of ~156 ops/byte. At 1 op/byte, you're using 0.6% of the GPU's compute capacity.
Batching fixes this. If you process 32 requests simultaneously, you load the weights once and multiply by 32 hidden states. Same memory traffic, 32× the useful compute. The arithmetic intensity jumps to ~32 ops/byte, and your GPU goes from 0.6% utilization to ~20%.
But batching introduces a tension: bigger batches improve throughput but increase latency (each request waits for others). The question isn't whether to batch — it's how.
Static batching is the simplest strategy: collect exactly N requests, then process them as a fixed batch until every request in the batch finishes.
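The problems are severe: head-of-line blocking (the first request waits until the batch fills), padding waste (every prompt is padded to the longest one in the batch), and wasted slots (finished requests keep occupying their slot until the slowest request finishes).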
```python
# Static batching pseudocode
# (prefill and decode_step stand in for the model's forward passes)
def static_batch_serve(queue, batch_size):
    while True:
        # Wait until we have enough requests
        batch = []
        while len(batch) < batch_size:
            batch.append(queue.get())  # BLOCKS until request arrives

        # Pad all prompts to same length
        max_prompt_len = max(r.prompt_len for r in batch)
        max_output_len = max(r.max_tokens for r in batch)

        # Run prefill for all (padded)
        kv_caches = prefill(batch, pad_to=max_prompt_len)

        # Decode until ALL requests hit max_output_len or EOS
        for step in range(max_output_len):
            tokens = decode_step(kv_caches, batch)
            # Even finished requests still occupy a slot

        # Only NOW return all results
        for r in batch:
            r.respond()
```
Static batching is like a printer that won't start until you queue exactly 8 documents, then prints all 8 even if the first 7 finish on page 1.
Dynamic batching improves on static by adding a time limit. Instead of waiting for exactly N requests, it waits for either N requests or T milliseconds, whichever comes first.
This solves head-of-line blocking: if requests arrive slowly, the time limit kicks in and processing starts with however many requests have accumulated. But the tail-latency problem remains — all requests in the batch still wait for the slowest to finish.
```python
import queue
import time


# Dynamic batching sketch; process_batch is a placeholder supplied by the caller
def dynamic_batch_serve(request_queue, max_batch, max_wait_ms, process_batch):
    while True:
        batch = []
        deadline = time.monotonic() + max_wait_ms / 1000.0

        # Collect until batch full OR time expires
        while len(batch) < max_batch:
            remaining = deadline - time.monotonic()
            if remaining <= 0:
                break
            try:
                # queue.Queue.get takes its timeout in seconds
                batch.append(request_queue.get(timeout=remaining))
            except queue.Empty:
                break

        if batch:
            # Still pads and waits for slowest, but at least
            # we didn't wait forever for the batch to fill
            process_batch(batch)
```
Dynamic batching is like a bus that departs either when full or on schedule — whichever comes first. Better than waiting for every seat to fill, but still takes the slowest route once it departs.
Continuous batching (also called iteration-level scheduling) changes the fundamental unit of scheduling from "the entire request" to "a single decode step."
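Here's the key insight: every decode step across all requests in a batch takes roughly the same time (one forward pass). So at each step, the scheduler can retire any request that just emitted EOS or hit its token limit, admit waiting requests into the freed slots (running their prefill), and then run one decode step for everything that is active.
No more waiting for the slowest request. The moment a short request finishes, its slot is recycled. As long as requests are waiting in the queue, no slot sits idle.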
```python
# Continuous batching pseudocode
# (prefill, decode_step, sleep and EOS stand in for engine internals)
def continuous_batch_serve(queue, max_batch):
    active = []  # Currently generating requests

    while True:
        # Step 1: Fill empty slots with waiting requests
        while len(active) < max_batch and not queue.empty():
            new_req = queue.get()
            new_req.kv_cache = prefill(new_req)  # Prefill just this one
            active.append(new_req)

        if not active:
            sleep(0.001)  # No work, brief pause
            continue

        # Step 2: One decode step for ALL active requests
        tokens = decode_step([r.kv_cache for r in active])

        # Step 3: Check for finished requests
        still_active = []
        for r, tok in zip(active, tokens):
            r.append_token(tok)
            if tok == EOS or r.output_len >= r.max_tokens:
                r.respond()  # Return result immediately
                # Slot is now FREE for next iteration
            else:
                still_active.append(r)
        active = still_active
```
Continuous batching is now the standard in production serving engines. vLLM, SGLang, TensorRT-LLM, LMDeploy, and Text Generation Inference (TGI) all implement it.
| Property | Static | Dynamic | Continuous |
|---|---|---|---|
| Scheduling granularity | Request-level | Request-level | Iteration-level |
| Head-of-line blocking | Severe | Reduced (timeout) | None |
| Tail latency | Slowest in batch | Slowest in batch | Per-request |
| GPU utilization | Low (wasted slots) | Medium | High (always full) |
| Throughput | 1× (baseline) | ~2-4× | ~10-23× |
| Complexity | Trivial | Low | Medium (KV cache mgmt) |
| Implementations | Naive PyTorch | Triton Server | vLLM, SGLang, TGI, TRT-LLM |
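The gap between request-level and iteration-level scheduling shows up even in a toy simulation. The sketch below counts decode steps only. It assumes all requests are already queued at time zero, every decode step costs the same, and prefill is free, so the ratio it prints is illustrative rather than a benchmark. Function names are made up for this example.

```python
from collections import deque


def static_steps(output_lens, batch_size):
    """Each fixed batch runs until its longest request finishes."""
    total = 0
    for i in range(0, len(output_lens), batch_size):
        total += max(output_lens[i:i + batch_size])
    return total


def continuous_steps(output_lens, batch_size):
    """Freed slots are refilled before every decode step."""
    waiting = deque(output_lens)
    active = []          # remaining tokens for each in-flight request
    steps = 0
    while waiting or active:
        while waiting and len(active) < batch_size:
            active.append(waiting.popleft())
        steps += 1       # one decode step for every active request
        active = [r - 1 for r in active if r > 1]
    return steps


lens = [10, 200, 15, 20, 180, 12, 25, 30]   # output lengths in tokens
s = static_steps(lens, batch_size=4)
c = continuous_steps(lens, batch_size=4)
print(f"static: {s} steps, continuous: {c} steps, ratio {s / c:.2f}x")
```

The more the output lengths vary, the larger the ratio becomes, since static batching pays for the longest request in every batch.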
Three side-by-side timelines show how requests flow through static, dynamic, and continuous batching. Colored bars represent requests; height is output length. Watch GPU utilization and average latency update in real time. Higher output variance makes the difference more dramatic.
A 4,096-token sequence produces an attention matrix with 4,096 × 4,096 = 16.8 million entries. At FP16 (2 bytes each), that's 32 MB. Manageable. Now scale to 128K tokens: 128,000 × 128,000 = 16.4 billion entries. At 2 bytes each: 32 GB for a single attention matrix in a single layer. The model has 80 layers. You'd need 2.5 TB just for attention matrices. There has to be a better way.
There is. Two complementary innovations transformed attention from a memory bottleneck into a tractable operation: FlashAttention (2022) made attention fast by keeping computation in fast SRAM instead of slow HBM, and PagedAttention (2023) made KV caches memory-efficient by borrowing virtual memory ideas from operating systems.
Recall the attention formula:

$$\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left(\frac{QK^\top}{\sqrt{d_k}}\right)V$$

Where Q, K, V are the query, key, and value matrices, each of shape (N, d_k) for N tokens and head dimension d_k (typically 64 or 128).
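The naive implementation does this in three steps, each requiring a round trip to HBM (the GPU's main memory): (1) compute the scores QKᵀ / √dk and write the N×N matrix to HBM, (2) read the matrix back to apply the row-wise softmax, and (3) read the softmaxed matrix again to multiply it by V.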
The total memory reads and writes are dominated by the N×N attention matrix. For every attention head, you write it once (Step 1) and read it twice (Steps 2 and 3). That's 3 × N² × 2 bytes of HBM traffic per head.
Let's do the arithmetic for a real model. LLaMA 70B has 64 attention heads with dk = 128. For a sequence of N = 4096 tokens, each head moves 3 × 4096² × 2 bytes ≈ 100.7 MB, so 64 heads move 64 × 100.7 MB ≈ 6.4 GB per layer.
An A100 has HBM bandwidth of 2 TB/s. So just the attention memory transfers take 6.4 GB / 2000 GB/s = 3.2ms per layer. With 80 layers, that's 256ms just for attention memory traffic — and we haven't even done any useful compute yet. The actual matrix multiplies are fast; the bottleneck is moving data between HBM and compute units.
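The same arithmetic as a small calculator, so other sequence lengths can be tried. It follows this section's accounting of three passes over the N×N matrix per head; the function names and defaults are illustrative.

```python
def naive_attention_hbm_bytes(seq_len: int, n_heads: int,
                              bytes_per_elem: int = 2, passes: int = 3) -> int:
    """HBM traffic for the N x N matrix across all heads of one layer."""
    return passes * seq_len * seq_len * bytes_per_elem * n_heads


def transfer_ms(n_bytes: float, bandwidth_gb_s: float) -> float:
    """Time to move n_bytes at the given bandwidth, in milliseconds."""
    return n_bytes / (bandwidth_gb_s * 1e9) * 1e3


layer_bytes = naive_attention_hbm_bytes(seq_len=4096, n_heads=64)
layer_ms = transfer_ms(layer_bytes, bandwidth_gb_s=2000)   # A100-class HBM
print(f"{layer_bytes / 1e9:.2f} GB per layer, {layer_ms:.2f} ms per layer, "
      f"{layer_ms * 80:.0f} ms over 80 layers")
```

Because the traffic grows with the square of the sequence length, doubling the context quadruples these numbers.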
This is the crux of the problem, so let's understand it precisely. A modern GPU like the A100 has two levels of memory:
| Memory | Capacity | Bandwidth | Latency |
|---|---|---|---|
| HBM (High Bandwidth Memory) | 80 GB | 2 TB/s | ~400 ns |
| SRAM (on-chip, per SM) | 20 MB total | ~33 TB/s | ~30 ns |
SRAM is 16× faster than HBM but 4000× smaller. The full N×N attention matrix doesn't fit in SRAM. So the naive algorithm writes it to HBM and reads it back — paying the slow-memory tax on every step.
The core idea of FlashAttention is deceptively simple: never write the full N×N attention matrix to HBM. Instead, compute attention in small tiles that fit in SRAM, accumulating the output incrementally.
Three techniques make this possible:
1. Tiling: Break Q, K, and V into blocks of size Br × d and Bc × d (where Br, Bc are chosen to fit in SRAM). For each block of Q, iterate over all blocks of K and V. The intermediate attention scores (a Br × Bc tile) live entirely in SRAM — never written to HBM.
2. Online softmax (the Milakov-Gimelshein trick): Normal softmax needs the maximum value across the entire row to compute stable exponentials. But we're processing K in blocks — we don't see the whole row at once. The online softmax algorithm maintains a running maximum and scaling factor, updating them as each new K block arrives. This produces the exact same result as full-row softmax.
Where m is the running row-max, ℓ is the running sum of exponentials, and O is the running output accumulator. Each time we process a new tile, we rescale the previous accumulator by the ratio of old-to-new exponentials. Exact. No approximation.
3. Kernel fusion: Instead of launching three separate GPU kernels (QKᵀ, softmax, PV), FlashAttention fuses everything into a single kernel. No intermediate results are written to HBM between operations. Data stays in registers and shared memory throughout.
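The tiling and online-softmax ideas can be checked numerically on the CPU. The NumPy sketch below is not the FlashAttention kernel: it tiles only over K and V (not Q), runs no GPU code, and exists purely to show that the running max `m`, running sum `l`, and rescaled accumulator `O` reproduce ordinary softmax attention. The function names and block size are arbitrary choices for this example.

```python
import numpy as np


def naive_attention(Q, K, V):
    """Reference: materialize the full N x N score matrix."""
    S = Q @ K.T / np.sqrt(Q.shape[-1])
    P = np.exp(S - S.max(axis=-1, keepdims=True))
    return (P / P.sum(axis=-1, keepdims=True)) @ V


def tiled_attention(Q, K, V, block=16):
    """Process K/V in blocks, never holding more than N x block scores."""
    N, d = Q.shape
    m = np.full((N, 1), -np.inf)      # running row max
    l = np.zeros((N, 1))              # running sum of exponentials
    O = np.zeros((N, V.shape[-1]))    # running unnormalized output
    for start in range(0, K.shape[0], block):
        Kb, Vb = K[start:start + block], V[start:start + block]
        S = Q @ Kb.T / np.sqrt(d)                     # scores for this tile
        m_new = np.maximum(m, S.max(axis=-1, keepdims=True))
        scale = np.exp(m - m_new)                     # rescale old statistics
        P = np.exp(S - m_new)
        l = scale * l + P.sum(axis=-1, keepdims=True)
        O = scale * O + P @ Vb
        m = m_new
    return O / l                                      # normalize once at the end


rng = np.random.default_rng(0)
Q, K, V = (rng.standard_normal((64, 32)) for _ in range(3))
max_diff = np.abs(tiled_attention(Q, K, V) - naive_attention(Q, K, V)).max()
print(f"max abs difference: {max_diff:.2e}")
```

The difference is at the level of floating-point rounding, which is what "exact, no approximation" means in practice.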
The speedup is dramatic because we eliminated the O(N²) HBM traffic:
| Metric | Standard Attention | FlashAttention |
|---|---|---|
| HBM reads/writes | O(N²) | O(N² / B) where B = SRAM block size |
| Extra memory | O(N²) for attention matrix | O(N) for running stats |
| Wall-clock time (4K seq) | ~15ms per layer | ~4ms per layer |
| Wall-clock time (16K seq) | ~240ms per layer | ~30ms per layer |
The longer the sequence, the bigger the win. At 16K tokens, FlashAttention is 8× faster because the O(N²) HBM traffic dominates at longer lengths.
| Version | Year | Key Improvement | Hardware |
|---|---|---|---|
| FlashAttention-1 | 2022 | Tiling + online softmax + kernel fusion | A100 (Ampere) |
| FlashAttention-2 | 2023 | Better work partitioning across warps, reduced non-matmul FLOPs by 4× | A100, H100 |
| FlashAttention-3 | 2024 | FP8 support, warp-specialization for Hopper TMA, asynchronous softmax | H100 (Hopper) |
FA-2 achieved 50-73% of theoretical max FLOPS on A100, up from roughly 25-40% for FlashAttention-1. FA-3 on H100 reaches about 740 TFLOPS in FP16 (~75% of peak) and close to 1.2 PFLOPS in FP8.
Left: naive attention materializes the full N×N matrix in HBM (red = written to slow memory). Right: FlashAttention processes tiles that fit in SRAM (green = stays in fast memory). Increase sequence length to see how the naive approach blows up while FlashAttention stays efficient.
FlashAttention is available through PyTorch's scaled_dot_product_attention (SDPA), which automatically dispatches to the FA kernel when available:
```python
import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel  # PyTorch 2.3+

# Q, K, V shape: (batch, num_heads, seq_len, head_dim)
batch, heads, seq_len, d_k = 1, 32, 4096, 128
Q = torch.randn(batch, heads, seq_len, d_k, device="cuda", dtype=torch.float16)
K = torch.randn(batch, heads, seq_len, d_k, device="cuda", dtype=torch.float16)
V = torch.randn(batch, heads, seq_len, d_k, device="cuda", dtype=torch.float16)

# PyTorch 2.0+ automatically uses FlashAttention when:
# - Inputs are FP16 or BF16
# - Running on compatible GPU (Ampere+)
# - Head dim is <= 256
# - No explicit attn_mask is passed (is_causal=True is fine)
output = F.scaled_dot_product_attention(
    Q, K, V,
    is_causal=True,            # Causal mask for autoregressive decoding
    scale=1.0 / (d_k ** 0.5),  # Same as the default; `scale` needs PyTorch 2.1+
)
# output shape: (1, 32, 4096, 128), same as standard attention

# Force a specific backend instead of letting PyTorch choose.
# (Older releases used torch.backends.cuda.sdp_kernel, now deprecated.)
with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
    out_flash = F.scaled_dot_product_attention(Q, K, V, is_causal=True)
    # This ONLY uses FlashAttention and raises an error if unavailable
```
FlashAttention solves the compute problem. But there's a memory management problem too: the KV cache.
In standard serving, each request's KV cache is stored as a single contiguous block in GPU memory. The problem: you don't know in advance how long a response will be. So you pre-allocate for the maximum possible length. If max_tokens = 2048 but the actual response is 50 tokens, 97.5% of that allocation is wasted.
Worse, contiguous allocation causes memory fragmentation. As requests start and finish, they leave differently-sized holes in memory. Even if total free memory is sufficient, you can't fit a new request because the free space is scattered.
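PagedAttention, introduced by vLLM (2023), works like this: the KV cache is split into fixed-size blocks that can live anywhere in GPU memory, a per-request block table maps logical token positions to physical blocks, and a new block is allocated only when the previous one fills up.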
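Let's walk through the memory savings with real numbers. Consider serving LLaMA 13B (40 layers, 40 heads, dk=128) with max_tokens=2048. Each token needs 2 × 40 × 40 × 128 × 2 bytes = 819,200 bytes ≈ 0.8 MB of KV cache. Pre-allocating for 2,048 tokens reserves about 1.68 GB per request, while a 50-token response actually uses about 41 MB.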
Contiguous allocation wastes 1.6 GB per request! With PagedAttention, you allocate only what's used: 40 MB for a 50-token response. That means you can serve 40× more concurrent requests with the same GPU memory.
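A toy block allocator makes the bookkeeping concrete. This is a sketch of the idea only, not vLLM's implementation: the class name, method names, and the block size of 16 tokens are assumptions for illustration, and no actual KV tensors are stored.

```python
class PagedKVAllocator:
    """Tracks which physical blocks hold each request's KV entries."""

    def __init__(self, num_blocks: int, block_size: int = 16):
        self.block_size = block_size
        self.free_blocks = list(range(num_blocks))
        self.block_tables = {}   # request id -> list of physical block ids
        self.lengths = {}        # request id -> number of tokens stored

    def append_token(self, req_id: str) -> None:
        """Reserve space for one more token, allocating a block if needed."""
        table = self.block_tables.setdefault(req_id, [])
        n = self.lengths.get(req_id, 0)
        if n == len(table) * self.block_size:   # last block is full (or none yet)
            if not self.free_blocks:
                raise MemoryError("no free KV blocks")
            table.append(self.free_blocks.pop())
        self.lengths[req_id] = n + 1

    def locate(self, req_id: str, pos: int):
        """Translate a logical token position to (physical block, offset)."""
        block = self.block_tables[req_id][pos // self.block_size]
        return block, pos % self.block_size

    def free(self, req_id: str) -> None:
        """Return all of a finished request's blocks to the pool."""
        self.free_blocks.extend(self.block_tables.pop(req_id, []))
        self.lengths.pop(req_id, None)


alloc = PagedKVAllocator(num_blocks=8, block_size=16)
for _ in range(50):                      # a 50-token response
    alloc.append_token("req-a")
print(len(alloc.block_tables["req-a"]), "blocks used,",
      len(alloc.free_blocks), "blocks free")
```

A 50-token response occupies four 16-token blocks, so at most 15 token slots are wasted, regardless of how large `max_tokens` was.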
Left: contiguous allocation wastes memory with pre-allocated but unused space. Right: paged allocation uses fixed-size blocks that can be scattered in memory, with a block table for lookup. Toggle between modes to see how fragmentation affects capacity.
A subtle point: during training, the backward pass needs the attention matrix to compute gradients. But FlashAttention never stored it! The solution: recompute it from Q, K, V during the backward pass.
This sounds wasteful — you're doing the forward computation twice. But it's actually faster than saving and loading the N×N matrix from HBM. The recomputation is done tile-by-tile in SRAM, which is 16× faster than the HBM round trip the naive approach requires.
This is a textbook example of the compute-memory tradeoff: sometimes it's faster to recompute a result than to store and retrieve it, because memory access is the bottleneck.
What if you had a junior writer who drafts paragraphs really fast, and a senior editor who reviews them? If the junior's draft is mostly right, the senior just nods — you skipped the slow part. If a word is wrong, the senior rewrites just that word and the junior starts over from there.
That's speculative decoding. And it's one of the cleverest ideas in modern inference because it speeds up generation without changing the output distribution at all.
Autoregressive decoding generates one token per forward pass. Each forward pass through a 70B-parameter model takes ~35ms on 2× A100 with tensor parallelism (loading 140GB of weights at ~4 TB/s aggregate bandwidth). To generate 100 tokens, that's 100 × 35ms = 3.5 seconds of sequential, memory-bound work.
But here's the thing: the GPU isn't actually busy during decode. As we learned in Chapter 5, the arithmetic intensity of single-request decode is ~1 op/byte — the GPU spends most of its time waiting for data from HBM. The compute units are idle 99% of the time.
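Two key observations enable speculative decoding: first, because decode is memory-bound, the target model can score several candidate tokens in one forward pass for roughly the cost of generating one; second, many tokens are predictable enough that a much smaller model guesses them correctly.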
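Speculative decoding uses two models: a large target model (for example 70B parameters) whose output distribution we want to preserve, and a small draft model (for example 1B parameters) that shares its tokenizer and runs many times faster.
The algorithm works in rounds: (1) the draft model proposes γ tokens autoregressively, (2) the target model scores all γ positions in a single forward pass, (3) draft tokens are accepted left to right until the first rejection, where a corrected token is sampled instead, and (4) if every draft token is accepted, the target contributes one bonus token.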
This is the remarkable part: speculative decoding produces output that is statistically identical to what the target model would generate alone. No approximation, no quality loss.
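The acceptance criterion uses modified rejection sampling. For each draft token x̂i with draft probability q(x̂i) and target probability p(x̂i), accept the token with probability

$$\min\!\left(1, \frac{p(\hat{x}_i)}{q(\hat{x}_i)}\right)$$

On rejection, discard the remaining draft tokens and sample a replacement from the adjusted distribution

$$p'(x) = \frac{\max\bigl(0,\, p(x) - q(x)\bigr)}{\sum_{x'} \max\bigl(0,\, p(x') - q(x')\bigr)}$$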
This guarantees that the accepted token is distributed exactly according to p(x), the target distribution. The math works out because the adjusted distribution covers exactly the probability mass that the draft model overestimated.
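The exactness claim can be checked empirically on a tiny vocabulary. The sketch below applies the accept-or-resample rule to a single position many times and compares the resulting frequencies to the target distribution. The distributions `p` and `q`, the function name, and the sample count are arbitrary choices for illustration.

```python
import random


def speculative_sample(p, q, rng):
    """Return one token distributed as p, using a draft proposal from q."""
    x = rng.choices(range(len(q)), weights=q)[0]       # draft proposes x ~ q
    if rng.random() < min(1.0, p[x] / q[x]):           # accept with min(1, p/q)
        return x
    # Rejected: resample from the mass the draft under-covered
    residual = [max(0.0, pi - qi) for pi, qi in zip(p, q)]
    return rng.choices(range(len(p)), weights=residual)[0]


p = [0.6, 0.3, 0.1]   # target distribution
q = [0.3, 0.3, 0.4]   # draft distribution (deliberately mismatched)
rng = random.Random(0)
n = 100_000

counts = [0] * len(p)
for _ in range(n):
    counts[speculative_sample(p, q, rng)] += 1
freqs = [c / n for c in counts]
print("target:   ", p)
print("empirical:", [round(f, 3) for f in freqs])
```

Even though the draft puts 40% of its mass on a token the target gives only 10%, the empirical frequencies track the target distribution.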
Three quantities determine how much speedup you get:
α (acceptance rate): The probability that any given draft token is accepted. This depends on how well the draft model matches the target. Higher α = more tokens accepted per round = bigger speedup.
γ (speculative count / lookahead): How many tokens the draft model proposes per round. A tunable hyperparameter. Larger γ means more potential tokens per round, but also more wasted work if the draft is rejected early.
τ (expected acceptance length): The average number of tokens accepted per round. This is what determines actual speedup.
The formula for expected acceptance length, assuming independent token-level acceptance with probability α:

$$\tau = \sum_{k=0}^{\gamma} \alpha^{k} = \frac{1 - \alpha^{\gamma+1}}{1 - \alpha}$$
This is a geometric series. If the draft model is perfect (α=1), then τ = γ+1 (all tokens accepted plus one bonus from the target). If the draft model is terrible (α→0), then τ → 1 (only the target's corrective token).
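Let's compute the expected speedup with realistic numbers. Suppose α = 0.8, γ = 5, the draft model takes 2ms per token (10ms for five drafts), and the target model takes 35ms per forward pass.
Expected acceptance length: τ = (1 - 0.8⁶) / (1 - 0.8) = 0.738 / 0.2 ≈ 3.69 tokens per round.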
Each round produces ~3.69 tokens and costs 10ms (drafting) + 35ms (verification) = 45ms.
Without speculative decoding: 3.69 tokens × 35ms/token = 129ms.
Speedup: 129ms / 45ms = 2.87×
Let's check the boundary cases:
| α | γ | τ (tokens/round) | Time/round | Baseline time | Speedup |
|---|---|---|---|---|---|
| 0.5 | 5 | 1.97 | 45ms | 69ms | 1.53× |
| 0.6 | 5 | 2.38 | 45ms | 83ms | 1.85× |
| 0.7 | 5 | 2.94 | 45ms | 103ms | 2.29× |
| 0.8 | 5 | 3.69 | 45ms | 129ms | 2.87× |
| 0.9 | 5 | 4.69 | 45ms | 164ms | 3.65× |
| 0.8 | 3 | 2.95 | 41ms | 103ms | 2.52× |
| 0.8 | 8 | 4.33 | 51ms | 152ms | 2.97× |
The sweet spot is α ≥ 0.7 with γ between 4-8. Below α = 0.5 the gains shrink quickly, and at very low acceptance rates (around α ≈ 0.2 under this cost model) the time spent drafting exceeds what parallel verification saves.
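The cost model behind these numbers fits in two small functions. This sketch assumes independent acceptance with probability α, a fixed draft cost of 2 ms per token, and 35 ms per target forward pass, as in the worked example; the function names and defaults are illustrative.

```python
def expected_tokens_per_round(alpha: float, gamma: int) -> float:
    """Geometric-series estimate of tokens produced per draft-verify round."""
    if alpha >= 1.0:
        return gamma + 1.0          # every draft accepted, plus the bonus token
    return (1.0 - alpha ** (gamma + 1)) / (1.0 - alpha)


def speedup(alpha: float, gamma: int,
            draft_ms: float = 2.0, target_ms: float = 35.0) -> float:
    """Ratio of plain decode time to speculative time for the same tokens."""
    tau = expected_tokens_per_round(alpha, gamma)
    round_ms = gamma * draft_ms + target_ms   # gamma drafts + one verification
    return tau * target_ms / round_ms


tau = expected_tokens_per_round(0.8, 5)
print(f"alpha=0.8, gamma=5: tau={tau:.2f}, speedup={speedup(0.8, 5):.2f}x")
print(f"alpha=0.2, gamma=5: speedup={speedup(0.2, 5):.2f}x")
```

Sweeping `alpha` and `gamma` through this function is a quick way to pick a lookahead before measuring on real traffic, keeping in mind that real acceptance is not independent across positions.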
```python
import torch


@torch.no_grad()
def speculative_decode(
    target_model,     # Large model (70B): 1-D token IDs -> (seq_len, vocab) logits
    draft_model,      # Small model (1B): same interface, same tokenizer
    prompt_tokens,    # 1-D LongTensor of input token IDs
    max_new_tokens,   # Max tokens to generate
    gamma=5,          # Speculative lookahead
):
    # Simplified sketch: no KV caching, so each call re-reads the full context.
    generated = []
    context = prompt_tokens.clone()

    while len(generated) < max_new_tokens:
        # === DRAFT PHASE ===
        # Draft model generates gamma candidate tokens autoregressively
        draft_tokens = []   # each a LongTensor of shape (1,)
        draft_dists = []    # full draft distribution q(.) at each position
        draft_ctx = context.clone()
        for _ in range(gamma):
            logits = draft_model(draft_ctx)
            q = torch.softmax(logits[-1], dim=-1)
            tok = torch.multinomial(q, 1)
            draft_tokens.append(tok)
            draft_dists.append(q)
            draft_ctx = torch.cat([draft_ctx, tok])

        # === VERIFY PHASE ===
        # Target model scores context + all candidates in ONE forward pass
        target_logits = target_model(draft_ctx)
        n_ctx = len(context)

        # === ACCEPT / REJECT ===
        new_tokens = []
        for i in range(gamma):
            # Logits at position n_ctx + i - 1 predict the i-th draft token
            p = torch.softmax(target_logits[n_ctx + i - 1], dim=-1)
            q = draft_dists[i]
            tok = draft_tokens[i]

            # Modified rejection sampling: accept with probability min(1, p/q)
            if torch.rand(1).item() < min(1.0, (p[tok] / q[tok]).item()):
                new_tokens.append(tok)
            else:
                # REJECT: sample from the adjusted distribution max(0, p - q),
                # computed over the full vocabulary
                adjusted = torch.clamp(p - q, min=0)
                adjusted = adjusted / adjusted.sum()
                new_tokens.append(torch.multinomial(adjusted, 1))
                break
        else:
            # All gamma tokens accepted! Bonus: sample one more from target
            p = torch.softmax(target_logits[n_ctx + gamma - 1], dim=-1)
            new_tokens.append(torch.multinomial(p, 1))

        # Update context with accepted + corrected (or bonus) tokens
        generated.extend(new_tokens)
        context = torch.cat([context] + new_tokens)

    return generated[:max_new_tokens]
```
Memory overhead: You need both models in GPU memory simultaneously. A 70B target + 1B draft in FP16 = 140GB + ~2GB. On 2× A100 80GB (tensor parallel, 160GB total), this is tight: under 20GB remains for KV cache and activations. Consider quantizing the draft model to INT4 to save space.
Choosing the draft model: The draft model must have the same tokenizer as the target (otherwise token-level acceptance is meaningless). Use the same model family at a smaller size, such as LLaMA-1B for LLaMA-70B. Models from different generations often use different tokenizers (GPT-2 and GPT-4 do, for example), which rules them out as a pair. Fine-tuning the draft model on your specific data distribution raises α significantly.
Dynamic γ: Rather than fixing γ=5, adapt it based on observed acceptance rate. If recent rounds have high α, increase γ to be more aggressive. If α drops, decrease γ to avoid wasting draft compute.
Self-drafting (Medusa, EAGLE): Instead of a separate draft model, attach small prediction heads to the target model itself. These heads predict tokens 2, 3, ..., K positions ahead. No extra model needed, but requires fine-tuning the prediction heads.
| Variant | Draft Source | Memory Overhead | Typical Speedup |
|---|---|---|---|
| Standard speculative | Separate small model | +1-7B params | 2-3× |
| Medusa | Extra prediction heads on target | +~1% params | 2-3× |
| EAGLE | Autoregressive head on target features | +~1% params | 2-4× |
| Lookahead decoding | N-gram cache from Jacobi iteration | None | 1.5-2.5× |
| Prompt lookup | N-grams from the prompt itself | None | 1.5-3× (high overlap) |
Watch the draft-verify loop in action. The draft model (top) proposes γ tokens quickly. The target model (bottom) verifies in one pass. Accepted tokens glow teal, rejected tokens glow red, and the correction is shown in warm. Adjust α and γ to see how acceptance rate and lookahead affect speedup.
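The theoretical maximum speedup is γ + 1 (if every draft token is accepted, plus the bonus token). In practice, the speedup is lower: it depends on the acceptance rate α and on how cheap the draft model is relative to the target.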
The key takeaway: speculative decoding helps most when the output is predictable. For highly creative or diverse generation, the draft model misses too often and the speedup diminishes.
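To make the dependence on α and γ concrete, the sketch below uses the standard closed form for the expected number of tokens produced per draft-verify round, under the simplifying assumption that each draft token is accepted independently with probability α. The `draft_cost_ratio` parameter (cost of one draft step relative to one target step) is a name introduced here for illustration.

```python
def expected_tokens_per_round(alpha, gamma):
    """Expected tokens emitted per round: (1 - alpha^(gamma+1)) / (1 - alpha)."""
    if alpha >= 1.0:
        return gamma + 1          # every draft accepted, plus the bonus token
    return (1 - alpha ** (gamma + 1)) / (1 - alpha)


def estimated_speedup(alpha, gamma, draft_cost_ratio):
    """Tokens per round divided by the round's cost in target-forward-pass units."""
    round_cost = gamma * draft_cost_ratio + 1   # gamma draft steps + 1 target pass
    return expected_tokens_per_round(alpha, gamma) / round_cost
```

With α = 0.8 and γ = 5 this gives about 3.7 tokens per round, well short of the γ + 1 = 6 ceiling, and the estimate shrinks further as the draft model gets more expensive.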
A 70B model in FP32 needs 280 GB — four H100 GPUs just for weights. In INT4? 35 GB — it fits on ONE GPU. Same model. Same 70 billion parameters. How can we throw away 87.5% of the precision and still get good outputs?
The answer is that not all precision is equally important. Most weights in a neural network cluster near zero, in a tight bell curve. The outliers — the weights far from zero — carry disproportionate information. If you can protect those outliers and compress the rest, you keep most of the model's intelligence while slashing memory by 4-8x.
This chapter covers the full quantization landscape: formats, methods, tradeoffs, and the arithmetic that lets you calculate exactly how much VRAM any model needs at any precision.
Every neural network weight is just a number. How many bits you use to store that number determines precision, memory, and speed. Here is the full spectrum, with real numbers for a 7B-parameter model:
| Format | Bits | 7B Model Size | vs FP32 | Accuracy Impact | Hardware Support |
|---|---|---|---|---|---|
| FP32 | 32 | 28.0 GB | 1.00x | Baseline (full) | All GPUs |
| FP16 | 16 | 14.0 GB | 0.50x | Negligible | All modern GPUs |
| BF16 | 16 | 14.0 GB | 0.50x | Negligible | A100+, 3090+ |
| FP8 (E4M3) | 8 | 7.0 GB | 0.25x | Small (~0.1-0.5%) | H100+ (Hopper) |
| INT8 | 8 | 7.0 GB | 0.25x | Small (~0.1-0.5%) | All modern GPUs |
| INT4 | 4 | 3.5 GB | 0.125x | Moderate (~0.5-2%) | Via software (GPTQ, AWQ) |
| INT2 | 2 | 1.75 GB | 0.0625x | Large (3-10%+) | Experimental |
Let's unpack the two 16-bit formats because the difference matters:
FP16 (half precision) uses 1 sign bit, 5 exponent bits, and 10 mantissa bits. It can represent numbers as small as ~6 × 10⁻⁸ and as large as 65,504. The problem: that max value of 65,504 is easy to exceed during training (gradient accumulation, large activations), causing overflow to infinity.
BF16 (brain floating point, invented at Google Brain) uses 1 sign bit, 8 exponent bits, and only 7 mantissa bits. The larger exponent range means it can represent numbers up to ~3.4 × 10³⁸ (the same range as FP32), so overflow is virtually impossible. The tradeoff: less mantissa precision (7 bits vs 10), meaning each value is slightly less precise. For LLM inference, this barely matters. BF16 is now the standard training format.
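The formula for weight memory is dead simple:
$$\text{memory (bytes)} = \text{num\_params} \times \frac{\text{bits\_per\_param}}{8}$$
Or equivalently:
$$\text{memory (GB)} = \text{params (billions)} \times \text{bytes\_per\_param}$$
Where bytes_per_param = 4 for FP32, 2 for FP16/BF16, 1 for INT8, 0.5 for INT4.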
Let's do worked examples for real models:
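A small calculator makes the arithmetic reproducible. This is a sketch that counts weights only, using decimal gigabytes (10⁹ bytes) and the bytes-per-parameter values given above; it ignores quantization scale factors, KV cache, and framework overhead. The function and dictionary names are illustrative.

```python
BYTES_PER_PARAM = {"FP32": 4.0, "FP16": 2.0, "BF16": 2.0, "INT8": 1.0, "INT4": 0.5}


def weight_memory_gb(params_billions, fmt):
    """Weight memory in GB: parameters (in billions) times bytes per parameter."""
    return params_billions * BYTES_PER_PARAM[fmt]


# Model sizes used elsewhere in this lesson
for name, size in [("7B", 7), ("70B", 70), ("671B", 671)]:
    row = ", ".join(
        f"{fmt}: {weight_memory_gb(size, fmt):.1f} GB"
        for fmt in ("FP16", "INT8", "INT4")
    )
    print(f"{name} -> {row}")
```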
At first glance, going from 32-bit to 4-bit should destroy a model. You're mapping 2³² (4.3 billion) possible values down to 2⁴ (16) possible values. How can anything survive?
The answer comes from the weight distribution. If you plot all the weights in a typical LLM layer, you see a tight bell curve centered near zero. Most weights are small. A few are large. The small weights can be snapped to coarse grid points without changing the model's behavior much, because neighboring grid points produce nearly identical activations.
Think of it like audio compression. A 16-bit WAV file stores 65,536 volume levels per sample. An MP3 uses psychoacoustic modeling to discard inaudible details and store far less. You can't tell the difference in your headphones. Quantization does the same for neural network weights — it discards precision that doesn't meaningfully affect outputs.
The critical insight: not all weights are equally important. About 1% of weights — the outliers, the ones with unusually large magnitudes — carry a disproportionate share of the model's capability. The best quantization methods identify these salient weights and protect them.
There are two things you can quantize in a neural network:
Weight quantization is the most common approach. Weights are static — they don't change between requests. You quantize them once after training and store them in the lower precision format. Every inference uses the quantized weights. This is straightforward because the weight distribution is fixed and can be analyzed offline.
Activation quantization is harder. Activations are the intermediate values computed during inference — the outputs of each layer. They change with every input. Worse, activations often have outlier channels: specific channels whose values are 10-100x larger than the rest. These outliers make naive quantization fail because the quantization range must be stretched to accommodate them, leaving very few levels for the normal-magnitude values.
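The effect of a single outlier on naive quantization is easy to reproduce. The sketch below applies symmetric absmax INT8 quantization (one scale for the whole tensor) to synthetic data, once without and once with an injected outlier of 100. The data and the outlier value are made up for illustration.

```python
import numpy as np


def absmax_quantize_int8(x):
    """Symmetric INT8 quantization with a single scale for the whole tensor."""
    scale = np.abs(x).max() / 127.0
    q = np.clip(np.round(x / scale), -127, 127).astype(np.int8)
    return q, scale


def mean_abs_error(x):
    """Mean absolute error after a quantize/dequantize round trip."""
    q, scale = absmax_quantize_int8(x)
    return float(np.abs(q.astype(np.float32) * scale - x).mean())


rng = np.random.default_rng(0)
normal = rng.normal(0.0, 1.0, size=4096).astype(np.float32)
with_outlier = normal.copy()
with_outlier[0] = 100.0   # one outlier stretches the quantization range

err_normal = mean_abs_error(normal)
err_outlier = mean_abs_error(with_outlier)
print(f"no outlier: {err_normal:.4f}   with outlier: {err_outlier:.4f}")
```

The outlier forces a much larger scale, so the ordinary values share far fewer distinct levels and their rounding error grows accordingly.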
W8A8 means "weights in INT8, activations in INT8." W4A16 means "weights in INT4, activations in FP16." W4A16 is the most common for LLM inference today — you get the 4x memory savings from quantized weights, but compute each matrix multiply by dequantizing weights on the fly and using FP16 arithmetic.
GPTQ (GPT Quantization) is the workhorse method for post-training quantization to 3-4 bits. Published in 2023, it can quantize a 175B model in approximately 4 GPU-hours — no retraining needed.
The core idea: quantize weights one layer at a time, adjusting the remaining weights to compensate for quantization error. It treats quantization as an optimization problem: given a calibration dataset (typically 128 samples of text), find the INT4 weights that minimize the squared difference in layer outputs compared to FP16.
Concretely, for each column of the weight matrix, GPTQ:
This "compensate as you go" approach is why GPTQ outperforms naive round-to-nearest quantization. The error from quantizing one weight is partially absorbed by adjusting other weights.
```python
# GPTQ quantization with AutoGPTQ
from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig
from transformers import AutoTokenizer

model_id = "meta-llama/Llama-3.1-70B"

# Configure: 4-bit, groups of 128 weights share one scale factor
quantize_config = BaseQuantizeConfig(
    bits=4,            # INT4 quantization
    group_size=128,    # scale factor per 128 weights
    desc_act=True,     # activation-order quantization
)

# Load model in FP16
model = AutoGPTQForCausalLM.from_pretrained(
    model_id,
    quantize_config=quantize_config,
)

# Build the calibration set (typically 128 text samples).
# calibration_texts is a placeholder: replace it with your own samples.
calibration_texts = ["Replace this with representative text from your workload."]
tokenizer = AutoTokenizer.from_pretrained(model_id)
calibration_dataset = [tokenizer(text) for text in calibration_texts]

# Quantize with the calibration examples.
# This runs forward passes to collect activation statistics,
# then quantizes each layer using the Hessian-based method.
model.quantize(calibration_dataset)  # expect GPU-hours for a 70B model

# Save quantized model: weights now INT4 with FP16 scales
model.save_quantized("Llama-3.1-70B-GPTQ-INT4")

# Result: 140 GB → 35 GB (+ ~1 GB for scale factors)
# Custom CUDA kernels achieve 3.25x speedup over FP16
```
The group_size parameter controls granularity. With group_size=128, every 128 weights share one scale factor and one zero point. Smaller groups (64, 32) give better accuracy but add overhead because you store more scale factors. Group size 128 is the standard tradeoff.
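The accuracy side of the group-size tradeoff can be seen with a toy round-to-nearest quantizer. This is a sketch, not GPTQ itself: it uses plain asymmetric 4-bit rounding per group, stores the group minimum as a float in place of an integer zero point, and runs on synthetic weights with a few injected outliers.

```python
import numpy as np


def quantize_groupwise_int4(w, group_size=128):
    """Round-to-nearest 4-bit quantization with one scale and offset per group."""
    groups = w.reshape(-1, group_size)
    w_min = groups.min(axis=1, keepdims=True)
    w_max = groups.max(axis=1, keepdims=True)
    scale = (w_max - w_min) / 15.0              # 16 levels: 0..15
    scale = np.where(scale == 0, 1.0, scale)    # guard against constant groups
    q = np.clip(np.round((groups - w_min) / scale), 0, 15).astype(np.uint8)
    dequant = q * scale + w_min
    return q, scale, dequant.reshape(w.shape)


rng = np.random.default_rng(0)
w = rng.normal(0.0, 0.02, size=4096).astype(np.float32)
w[::512] = 0.5   # a handful of synthetic outliers

errors = {}
for g in (4096, 128, 32):
    _, _, w_hat = quantize_groupwise_int4(w, g)
    errors[g] = float(np.abs(w_hat - w).mean())
    print(f"group_size={g:5d}  mean abs error={errors[g]:.5f}  groups={w.size // g}")
```

Smaller groups confine each outlier's damage to its own group, at the price of storing more scale factors.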
AWQ (Activation-aware Weight Quantization) starts from a different insight. Instead of compensating for quantization error across columns (like GPTQ), it asks: which weights matter most, and how can we protect them?
The answer: salient weights are the ones connected to large-magnitude activations. If an activation channel typically has values around 50 while most channels are around 1, the weights feeding that channel are 50x more important — a small error in those weights gets amplified 50x in the output.
AWQ's approach:
The mathematical trick is elegant. For a weight w with corresponding activation a and a per-channel scale factor s > 0:
$$w \cdot a = (w \cdot s) \cdot \left(\frac{a}{s}\right)$$
The input-output relationship doesn't change. But quantizing w · s (a larger number) to INT4 produces less relative error than quantizing w directly. The scale factor s is chosen per channel to minimize total quantization error.
AWQ is generally preferred over GPTQ for deployment because it produces slightly better accuracy at the same bit width and the quantized models are faster to load.
SmoothQuant solves a specific problem: quantizing activations. Remember, activations have outlier channels — values 100x larger than the median. If you try naive INT8 quantization on activations, those outliers force the quantization range to be enormous, and normal values get crushed to zero or one.
SmoothQuant's insight: migrate the quantization difficulty from activations to weights. Weights are easy to quantize (smooth distribution). Activations are hard (outlier channels). So multiply weights by the per-channel activation scale, and divide activations by the same scale. Now the activations are "smooth" (no outliers) and the weights have absorbed the variation.
After this transformation, both X and W have similar magnitude ranges and can be cleanly quantized to INT8. The result is W8A8: both weights AND activations in INT8. This enables integer matrix multiplication on tensor cores, which is significantly faster than FP16 matmul on supported hardware.
SmoothQuant is training-free — you just need a small calibration set to compute the per-channel activation scales. It works out of the box with most transformer models.
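The smoothing transformation can be checked numerically. The sketch below follows the per-channel scale from the SmoothQuant paper, s_j = max|X_j|^α / max|W_j|^(1−α), where α is the migration strength (0.5 here); the matrices are synthetic and the function name is illustrative.

```python
import numpy as np


def smooth(X, W, alpha=0.5):
    """Move activation outliers into the weights without changing X @ W.

    X: [tokens, in_features] activations
    W: [in_features, out_features] weights
    """
    act_max = np.abs(X).max(axis=0)     # per input channel, over tokens
    w_max = np.abs(W).max(axis=1)       # per input channel, over outputs
    s = act_max ** alpha / w_max ** (1 - alpha)
    return X / s, W * s[:, None], s


rng = np.random.default_rng(0)
X = rng.normal(size=(16, 8))
X[:, 3] *= 100.0                        # one synthetic outlier channel
W = rng.normal(size=(8, 4))

X_s, W_s, s = smooth(X, W)
print("max |X| before:", np.abs(X).max(), " after:", np.abs(X_s).max())
print("outputs match:", np.allclose(X @ W, X_s @ W_s))
```

The product is unchanged because each channel's scale cancels between X and W, but the activation range that INT8 has to cover is much smaller.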
These two compression techniques are complementary:
| Dimension | Quantization | Pruning |
|---|---|---|
| What it reduces | Bits per weight | Number of weights |
| Example | 70B FP16 → 70B INT4 | 70B → 35B (50% sparse) |
| Memory savings | 4x (FP16→INT4) | 2x (50% sparsity, structured) |
| Speed improvement | 2-4x with custom kernels | Variable (requires sparse kernels) |
| Accuracy impact | Small at 4-bit, large at 2-bit | Depends on what's pruned |
| Combined? | Yes — prune to 50%, then quantize to INT4 = 8x compression | |
Select a model size, precision, and GPU to see memory usage and whether the model fits. The remaining memory is available for KV cache and batching.
Quantize when:
Don't quantize when:
DeepSeek-V3 has 671 billion parameters. In FP16 that's 1.34 TB. The biggest single GPU has 192 GB. You need at minimum 7 GPUs just for weights — and that leaves zero room for KV cache, activations, or framework overhead. You probably need 10-16 GPUs. How do you split a model across GPUs without breaking it?
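The GPU count is a ceiling division. In the sketch below, `usable_fraction` is an assumed knob for how much of each GPU is left for weights after reserving space for KV cache, activations, and framework overhead; the value 0.6 is illustrative only.

```python
import math


def min_gpus_for_weights(params_billions, bytes_per_param, gpu_memory_gb,
                         usable_fraction=1.0):
    """Smallest GPU count whose combined usable memory holds the weights."""
    weight_gb = params_billions * bytes_per_param
    return math.ceil(weight_gb / (gpu_memory_gb * usable_fraction))


# DeepSeek-V3 in FP16 on 192 GB GPUs, weights only
print(min_gpus_for_weights(671, 2, 192))
# Same model, assuming only 60% of each GPU is available for weights
print(min_gpus_for_weights(671, 2, 192, usable_fraction=0.6))
```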
This is the problem of model parallelism, and there isn't a single solution — there are four strategies, each with different tradeoffs. Real production deployments combine them. Understanding when and why to use each one is the difference between a system that serves 100 users and one that serves 100,000.
Data Parallelism is the simplest strategy: replicate the entire model on every GPU. Split the incoming batch of requests across GPUs. Each GPU processes its subset independently, in parallel.
Imagine a restaurant with 4 identical kitchens, each with the same chef and the same recipe book. When 40 orders come in, you send 10 to each kitchen. Each kitchen works independently. Throughput scales linearly: 4 kitchens serve 4x the customers.
Pros: Dead simple. No inter-GPU communication during inference (each replica is self-contained). Throughput scales linearly with GPU count. Each GPU can serve different requests at different stages of generation.
Cons: Every replica needs the FULL model in memory. If your model is 140 GB (70B FP16) and your GPU has 80 GB, data parallelism won't help — the model doesn't fit on a single GPU in the first place. You also can't reduce per-request latency (each request still runs on one GPU).
Best for: Models that already fit on one GPU but need higher aggregate throughput. A 7B model in INT4 (3.5 GB) on 8x H100s gives you 8x the requests per second.
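Let's put numbers on it. Say one 7B INT4 model on one H100 handles 50 requests/second. With DP=8, aggregate throughput is 8 × 50 = 400 requests/second, while each individual request is no faster than before.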
Tensor Parallelism is the answer when a single layer is too large for one GPU. Instead of replicating the model, you slice individual weight matrices across GPUs. Each GPU computes a portion of every matrix multiply, then the partial results are combined via an all-reduce communication operation.
Back to the restaurant analogy: instead of 4 identical kitchens, you have 4 chefs in ONE kitchen. Each chef handles a different part of every dish — one does the protein, one the sauce, one the sides, one the plating. They must coordinate on every dish (communication), but together they can make dishes too complex for any single chef.
In a transformer, the biggest weight matrices are in the attention layers (Q, K, V projections) and the MLP (up-project, down-project). With TP=4, each matrix is split into 4 slices. GPU 0 computes the first quarter of each output, GPU 1 the second quarter, and so on. After the attention block and again after the MLP block, an all-reduce combines the partial results.
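The slicing can be simulated on one machine with NumPy. This sketch follows the common column-parallel then row-parallel arrangement for the MLP: the up-projection is split by columns, the down-projection by rows, and the final sum plays the role of the all-reduce. ReLU stands in for the real activation function, and the loop stands in for separate GPUs.

```python
import numpy as np


def tensor_parallel_mlp(x, w_up, w_down, tp=4):
    """Compute relu(x @ w_up) @ w_down with weights sharded across `tp` workers."""
    up_shards = np.split(w_up, tp, axis=1)       # column-parallel: split outputs
    down_shards = np.split(w_down, tp, axis=0)   # row-parallel: split inputs
    partials = []
    for w_u, w_d in zip(up_shards, down_shards):     # each iteration = one GPU
        h = np.maximum(x @ w_u, 0.0)                 # local slice of hidden state
        partials.append(h @ w_d)                     # partial output, full width
    return np.sum(partials, axis=0)                  # all-reduce (sum)


rng = np.random.default_rng(0)
x = rng.normal(size=(2, 8))
w_up = rng.normal(size=(8, 16))
w_down = rng.normal(size=(16, 8))

reference = np.maximum(x @ w_up, 0.0) @ w_down
out = tensor_parallel_mlp(x, w_up, w_down, tp=4)
print("matches single-GPU result:", np.allclose(out, reference))
```

Because the activation is elementwise, each worker can apply it to its own slice, so only one communication step is needed for the whole MLP block.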
Pros: Enables serving models that don't fit on one GPU. Reduces per-GPU memory proportionally (TP=4 means ~4x less memory per GPU). Can also reduce latency because each GPU does less compute per layer.
Cons: Requires fast inter-GPU communication. Each transformer layer needs 2 all-reduce operations (one for attention, one for MLP). On NVLink (900 GB/s on H100), this takes microseconds. On PCIe (64 GB/s), it's 14x slower. Across nodes over InfiniBand (400 Gb/s = 50 GB/s), it's even worse. This is why TP is almost exclusively used within a single node.
Best for: Models too large for one GPU, deployed on a multi-GPU node with fast NVLink interconnect. TP=8 within a single 8-GPU node is the standard configuration for 70B models.
Pipeline Parallelism assigns different layers to different GPUs. Data flows through GPUs like an assembly line: GPU 0 processes layers 0-15, passes the output to GPU 1 which processes layers 16-31, and so on.
Think of an automobile assembly line. Station 1 does the frame, Station 2 adds the engine, Station 3 does the interior, Station 4 does the paint. Each car moves through all stations sequentially. But multiple cars are in the pipeline simultaneously — Station 1 starts a new car while Station 4 finishes the previous one.
The pipeline bubble problem: The naive approach has terrible utilization. When GPU 3 is processing the first request, GPUs 0-2 are idle. The solution is microbatching: split a batch into microbatches and pipeline them. While GPU 1 processes microbatch 1, GPU 0 starts microbatch 2. With enough microbatches, all stages stay busy.
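How many microbatches is "enough" can be estimated with the usual idealized bubble formula for a simple fill-and-drain schedule: with p stages and m microbatches of equal cost, the idle fraction is (p − 1) / (m + p − 1). The sketch assumes every stage takes the same time per microbatch, which real deployments only approximate.

```python
def pipeline_bubble_fraction(stages, microbatches):
    """Idle fraction of an idealized fill-and-drain pipeline schedule."""
    return (stages - 1) / (microbatches + stages - 1)


for m in (1, 4, 13, 61):
    frac = pipeline_bubble_fraction(stages=4, microbatches=m)
    print(f"4 stages, {m:2d} microbatches -> {frac:.1%} idle")
```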
Pros: Minimal communication — only send activation tensors between adjacent stages (once per stage transition, not once per layer like TP). Works well across nodes with slower networks.
Cons: Pipeline bubbles reduce utilization. The pipeline has startup latency (must fill all stages before the first output). Load balancing is tricky — early layers with embeddings and later layers with the LM head may have different compute costs.
Best for: Multi-node deployments where inter-node bandwidth is limited. PP across nodes + TP within nodes is a common hybrid pattern.
Expert Parallelism is specific to Mixture-of-Experts (MoE) models like DeepSeek-V3 (671B, 256 experts) and Mixtral (46.7B, 8 experts). In MoE models, each layer has multiple "expert" sub-networks, but only a few are activated for any given token (2 of 8 in Mixtral, 8 of the 256 routed experts in DeepSeek-V3). Expert Parallelism distributes different experts to different GPUs.
Think of a hospital with specialist doctors. You don't need every specialist for every patient. A patient with a broken bone goes to orthopedics. A patient with a rash goes to dermatology. The hospital distributes specialists across offices (GPUs), and a triage nurse (router) sends each patient to the right one.
Pros: Efficient for MoE models — each GPU only needs to store a fraction of the experts. Compute scales well because only activated experts run. Memory-efficient because you don't need to replicate all experts everywhere.
Cons: Load balancing is a real challenge. If the router sends most tokens to experts 0-31, GPU 0 is overloaded while GPUs 1-7 sit idle. MoE models use auxiliary loss functions during training to encourage balanced routing, but perfect balance is never achieved. The all-to-all communication pattern (tokens scatter to GPUs, results gather back) requires careful networking.
Best for: MoE models exclusively. If your model has a standard dense architecture, EP doesn't apply.
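The load-balancing problem can be illustrated with a toy router. The sketch below picks the top-k experts per token from random router logits, then counts how many token-expert assignments land on each GPU when experts are placed in contiguous blocks. The logits, the bias toward experts 0-31, and the placement scheme are all synthetic assumptions.

```python
import numpy as np


def route_top_k(router_logits, k):
    """Indices of the k highest-scoring experts for each token."""
    return np.argsort(router_logits, axis=-1)[:, -k:]


def assignments_per_gpu(expert_ids, num_experts, num_gpus):
    """Count token-expert assignments per GPU with contiguous expert placement."""
    experts_per_gpu = num_experts // num_gpus
    gpu_ids = expert_ids // experts_per_gpu
    return np.bincount(gpu_ids.ravel(), minlength=num_gpus)


rng = np.random.default_rng(0)
logits = rng.normal(size=(1024, 256))     # 1024 tokens, 256 experts
logits[:, :32] += 1.0                     # router favors experts 0-31 (GPU 0)

load = assignments_per_gpu(route_top_k(logits, k=8), num_experts=256, num_gpus=8)
print("assignments per GPU:", load.tolist())
```

Even a modest bias in the router concentrates work on one GPU, which is why balanced routing matters so much for expert parallelism.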
Real production deployments almost never use a single strategy. They combine them based on hardware topology. The most common pattern:
Let's walk through a concrete example:
Notice how the constraints force your hand. The weight memory alone eliminates most configurations. You have to do the arithmetic to find what actually fits.
Here's the general process for choosing a parallelism strategy:
Switch between parallelism strategies to see how the model is distributed across GPUs. Arrows show data flow and communication patterns.
| Property | Data (DP) | Tensor (TP) | Pipeline (PP) | Expert (EP) |
|---|---|---|---|---|
| What's split | Batch | Weight matrices | Layers | Expert sub-networks |
| Communication | None during inference | All-reduce per layer (heavy) | Between stages only (light) | All-to-all per MoE layer |
| Memory savings | None — full copy per GPU | ~N× with N GPUs | ~N× with N GPUs | ~N× for expert weights |
| Latency impact | No change per request | Lower compute time per layer, partly offset by all-reduce | Increase (pipeline startup) | Slight increase (routing) |
| Throughput | N× linear scaling | Slight improvement | Depends on microbatching | Good for MoE |
| Best interconnect | Any (no communication) | NVLink (intra-node) | InfiniBand (inter-node OK) | Fast intra-node preferred |
| Best for | Small models, need throughput | Large models, single node | Very large models, multi-node | MoE models only |
Your chatbot serves 10,000 users. Every request includes the same 2,000-token system prompt. Even at one request per user per day, you're re-processing 20 million tokens of identical text every day. What if you could compute it once and reuse it forever?
This chapter covers the serving-side optimizations that turn a working model into a production system: prefix caching, prefill-decode disaggregation, KV cache offloading, prefix-aware routing, and the frameworks that tie it all together. These aren't academic ideas — they're what separates a $50K/month GPU bill from a $5K/month one.
Recall from earlier chapters: prefill is the expensive first phase where the model processes the entire input prompt and builds the KV cache. For a 2,000-token system prompt, prefill runs all 2,000 tokens through every layer. At roughly 2 floating-point operations per parameter per token, that is about 280 trillion operations for a 70B model.
Prefix caching exploits a simple observation: if two requests share the same prefix (e.g., the same system prompt), their KV cache entries for that prefix are identical. Compute the KV cache once for the shared prefix, store it, and reuse it for every subsequent request that starts with the same tokens.
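One common way to implement this is to cut the token sequence into fixed-size blocks and key each block by a hash of itself plus everything before it. The sketch below is a toy version of that idea: the block size of 16, the class name, and the placeholder stored in place of real KV tensors are all assumptions for illustration.

```python
import hashlib

BLOCK = 16  # tokens per cache block (illustrative)


def block_hashes(tokens, block=BLOCK):
    """Hash each full block together with all tokens that precede it."""
    hashes, running = [], hashlib.sha256()
    full_len = len(tokens) - len(tokens) % block
    for start in range(0, full_len, block):
        running.update(str(tokens[start:start + block]).encode())
        hashes.append(running.copy().hexdigest())
    return hashes


class PrefixCache:
    def __init__(self):
        self.blocks = {}   # prefix hash -> KV block (placeholder string here)

    def match(self, tokens):
        """Number of leading tokens whose KV blocks are already cached."""
        matched = 0
        for h in block_hashes(tokens):
            if h not in self.blocks:
                break
            matched += BLOCK
        return matched

    def insert(self, tokens):
        for h in block_hashes(tokens):
            self.blocks.setdefault(h, "kv-placeholder")


system_prompt = list(range(2000))
request_a = system_prompt + [9001] * 50
request_b = system_prompt + [9002] * 80

cache = PrefixCache()
cache.insert(request_a)                 # first request pays the full prefill
reused = cache.match(request_b)         # second request reuses the shared prefix
print(f"request B reuses {reused} of {len(request_b)} tokens")
```

Because each hash covers the whole prefix, two requests only share a block if every token before it is identical too.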
Here's how it works step by step:
[system prompt (2000 tokens)] + [user message A (50 tokens)]
[system prompt (2000 tokens)] + [user message B (80 tokens)]
The savings are dramatic. Let's put numbers on it:
Best practices for maximizing prefix cache hits:
These sound similar but are fundamentally different:
| Property | KV Cache (within a request) | Prefix Cache (across requests) |
|---|---|---|
| Scope | Single request's decode phase | Across different requests |
| What's reused | Previously-computed K,V from prefill, reused during decode so we only process 1 new token at a time | K,V from a shared prefix, reused by new requests starting with the same tokens |
| Lifetime | Dies when the request completes | Persists for minutes/hours in cache |
| Without it | Decode is quadratic (recompute all attention each step) | Every request re-prefills the shared prefix |
| Everyone uses it? | Yes — required for efficient autoregressive decode | No — optional optimization, most useful with shared prefixes |
Compare TTFT with and without prefix caching. Adjust the prefix length and toggle cache hit/miss to see the impact on time-to-first-token.
Here's a subtle but critical problem with colocating prefill and decode on the same GPU: they fight over resources.
Prefill is compute-bound. It processes hundreds or thousands of tokens in one big batch of matrix multiplications. It wants the GPU's tensor cores running at full blast. GPU utilization during prefill is often 60-80%.
Decode is memory-bandwidth-bound. It processes exactly one token per request per step. The matrix multiplies are tiny — most of the time is spent loading weights from HBM. GPU compute utilization during decode is often only 5-15%.
When you colocate them, prefill hogs the compute units and starves decode of memory bandwidth. A user in the decode phase (streaming tokens) sees their token generation stall whenever the GPU starts a prefill for a new request. This is called prefill-decode interference, and it causes unpredictable latency spikes.
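The memory-bandwidth bound on decode can be estimated with a back-of-envelope calculation: each decode step must stream all weights from HBM at least once, so weight size divided by bandwidth gives a floor on step time. The sketch ignores KV cache reads and kernel overheads, and the example numbers (a 14 GB model on 3.35 TB/s of HBM) are illustrative.

```python
def decode_step_ms(weight_gb, hbm_bandwidth_gb_per_s):
    """Lower bound on one decode step: time to read all weights from HBM once."""
    return weight_gb / hbm_bandwidth_gb_per_s * 1000.0


def max_tokens_per_second(weight_gb, hbm_bandwidth_gb_per_s, batch_size=1):
    """Upper bound on decode throughput; one token per request per step."""
    return batch_size * 1000.0 / decode_step_ms(weight_gb, hbm_bandwidth_gb_per_s)


step = decode_step_ms(14, 3350)   # 7B model in FP16 on 3.35 TB/s HBM
print(f"step time >= {step:.2f} ms")
print(f"batch 1:  <= {max_tokens_per_second(14, 3350, 1):.0f} tokens/s")
print(f"batch 32: <= {max_tokens_per_second(14, 3350, 32):.0f} tokens/s")
```

The step time does not change with batch size in this model, which is why batching more decode requests together is nearly free until compute becomes the limit.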
Prefill-decode disaggregation (PD disaggregation) solves this by separating them onto different GPUs:
Let's quantify the KV cache transfer cost to see when disaggregation makes sense:
Benefits of disaggregation:
When NOT to disaggregate:
Toggle between colocated and disaggregated architectures. The colocated view shows interference between prefill and decode. The disaggregated view shows smooth operation.
As we discussed in the KV cache chapter, long contexts create enormous KV caches that eat into GPU memory. A single 128K-context request on Llama 3.1 70B generates ~40 GB of KV cache — half of an H100's total memory. With just two such requests, you've consumed all available VRAM for KV cache alone.
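The ~40 GB figure can be reproduced from the model's shape. The sketch below assumes the published Llama 3.1 70B configuration (80 layers, 8 key-value heads under grouped-query attention, head dimension 128) and FP16 cache entries; the function name is illustrative.

```python
def kv_cache_bytes(num_layers, num_kv_heads, head_dim, seq_len, bytes_per_value=2):
    """KV cache size: keys and values for every layer, KV head, and token."""
    return 2 * num_layers * num_kv_heads * head_dim * bytes_per_value * seq_len


llama_70b = dict(num_layers=80, num_kv_heads=8, head_dim=128)

per_token = kv_cache_bytes(seq_len=1, **llama_70b)
total = kv_cache_bytes(seq_len=128 * 1024, **llama_70b)
print(f"per token: {per_token / 1024:.0f} KiB")
print(f"128K context: {total / 2**30:.0f} GiB")
```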
KV cache offloading moves inactive KV data to cheaper, larger storage tiers:
| Tier | Capacity | Bandwidth | Latency | Cost ($/GB) |
|---|---|---|---|---|
| GPU HBM | 80-192 GB | 2-3.35 TB/s | ~ns | Highest |
| CPU DRAM | 512 GB - 2 TB | 100-200 GB/s | ~100 ns | Medium |
| NVMe SSD | 2-32 TB | 5-14 GB/s | ~10 μs | Low |
| Remote (network) | Unlimited | 12-50 GB/s | ~100 μs | Lowest |
The strategy: keep the KV cache for the tokens being actively decoded in GPU HBM. Move the KV cache for earlier (older) tokens to CPU DRAM or SSD. When the model's attention mechanism needs those older tokens, prefetch them back to GPU just before they're needed.
This is analogous to virtual memory in operating systems. Not all data needs to be in the fastest storage at all times. A smart paging system keeps hot data in fast memory and cold data in slow-but-large memory.
LMCache is a popular extension for vLLM that implements multi-tier KV cache offloading. It manages KV cache across GPU/CPU/disk tiers with automatic paging. Results from production deployments show 3-10x latency reduction for long-context workloads, because without offloading, those requests would either be rejected (OOM) or queued waiting for GPU memory to free up.
NVIDIA reports 14x faster TTFT with KV cache offloading for sequences exceeding 100K tokens, because the alternative — recomputing from scratch each time — is far more expensive than loading cached KV from CPU DRAM.
You don't implement all of this from scratch. Production inference frameworks package these optimizations into deployable systems. Here's the landscape:
High-throughput serving (cloud/datacenter):
| Framework | Key Innovation | Strengths | Limitations |
|---|---|---|---|
| vLLM | PagedAttention (pioneered it) | Best KV cache management, continuous batching, widest model support, huge community | Python overhead, kernel efficiency behind SGLang in some benchmarks |
| SGLang | RadixAttention (tree-based prefix cache) | Fastest throughput in many benchmarks, structured output, prefix-aware routing via radix tree | Smaller community, fewer model architectures |
| TensorRT-LLM | NVIDIA-optimized kernels | Best per-GPU kernel performance (NVIDIA hardware), INT4/FP8 support, flight recorder profiling | NVIDIA-only, complex configuration, slower iteration |
| LMDeploy | Turbomind engine | Good throughput, persistent batch scheduling, KV cache quantization built-in | Smaller ecosystem |
| TGI | Hugging Face integration | Easy deployment, model hub integration | Maintenance mode — HF shifting focus to other projects |
Local/edge serving (laptops, phones, embedded):
| Framework | Language | Key Feature | Limitations |
|---|---|---|---|
| llama.cpp | C/C++ | Runs anywhere: CPU, Metal, CUDA, Vulkan. GGUF format. Highly optimized quantization kernels. | Single-request focus, less suited for serving |
| Ollama | Go + llama.cpp | Dead simple UX: `ollama run llama3.1`. Auto-downloads models, manages VRAM, REST API. | Single-request only. Not a production serving solution. |
| MLC-LLM | TVM-based | Cross-platform compilation: iOS, Android, browser (WebGPU), desktop. Universal deployment. | More complex setup, smaller model coverage |
In a distributed serving setup with multiple workers, each worker builds up its own prefix cache. If a request with system prompt X goes to Worker 3 the first time, Worker 3 caches that prefix. But if the next request with the same prefix goes to Worker 7, it's a cache miss — Worker 7 has to compute the prefix from scratch.
Prefix-aware routing solves this by sending requests to the worker that already has the right prefix cached. Three approaches have emerged:
1. Worker-reported (NVIDIA Dynamo): Each worker reports its cached prefixes to the router. The router maintains a mapping: "Worker 3 has prefix hash ABC123, Worker 7 has prefix hash DEF456." When a new request arrives, the router hashes its prefix and looks up which worker has it cached. Simple and accurate, but workers must continuously update the router.
2. Router-predicted (SGLang radix tree): The router maintains a radix tree (prefix tree) of all cached prefixes across all workers. It can compute the longest matching prefix for any incoming request without asking workers. More sophisticated but requires the router to track every cache insertion and eviction.
3. Hybrid (llm-d): Combines heuristic routing with worker feedback. The router makes an initial guess based on consistent hashing, and workers asynchronously report cache state to correct future decisions. Balances simplicity with accuracy.
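A toy router in the spirit of the second approach shows the core decision. This sketch keeps a plain list of token prefixes per worker in place of a radix tree, scores each worker by the longest shared prefix, and breaks ties by load. The class, its fields, and the never-decremented load counter are simplifications invented for illustration, not the API of any framework named above.

```python
def common_prefix_len(a, b):
    """Length of the shared leading run of two token sequences."""
    n = 0
    for x, y in zip(a, b):
        if x != y:
            break
        n += 1
    return n


class PrefixAwareRouter:
    def __init__(self, num_workers):
        self.cached = {w: [] for w in range(num_workers)}   # worker -> prefixes
        self.load = {w: 0 for w in range(num_workers)}      # requests routed so far

    def route(self, tokens):
        def score(worker):
            best = max((common_prefix_len(tokens, p) for p in self.cached[worker]),
                       default=0)
            return (best, -self.load[worker])   # longest match, then least loaded

        worker = max(self.cached, key=score)
        self.cached[worker].append(list(tokens))   # router predicts the new cache entry
        self.load[worker] += 1
        return worker


router = PrefixAwareRouter(num_workers=4)
prompt_x, prompt_y = [1] * 100, [2] * 100
first = router.route(prompt_x + [7])
second = router.route(prompt_y + [8])
third = router.route(prompt_x + [9])
print(first, second, third)
```

The third request shares prompt X with the first, so it is sent to the same worker even though that worker is no longer the least loaded.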
Putting it all together, here's what a production LLM serving system looks like end-to-end:
Let's calculate the real cost of serving Llama 3.1 70B to 1,000 concurrent users, comparing naive vs optimized deployment:
You're running a production LLM server. Requests are flooding in. Your job: keep latency low, throughput high, and GPUs busy. Every optimization we've covered — continuous batching, PagedAttention, speculative decoding, prefix caching, quantization, parallelism — is a knob you can turn. Let's see how they all fit together.
The simulator below models a complete inference serving pipeline. On the left, requests arrive from users at a configurable rate. In the center, the scheduler batches them and routes them through prefill and decode stages. On the right, completed responses stream out. The dashboard at the bottom tracks every metric that matters in production.
This is your sandbox. There are no wrong answers — only configurations that reveal tradeoffs.
Requests arrive (left), get scheduled and processed (center), and stream out (right). The dashboard below tracks latency, throughput, utilization, and memory in real time.
Follow this sequence to build intuition for how each optimization affects the system. Watch the metrics dashboard — especially throughput (TPS), TTFT, and GPU utilization.
Experiment 1: Baseline suffering. Set 1 GPU, static batching, all optimizations off, request rate 10/s. Hit Start. Watch the queue build up, latency spike, and GPU utilization hover at a pitiful level. This is what happens when you serve a model naively — most of the time, the GPU is either waiting for a batch to fill or blocked on the slowest sequence.
Experiment 2: Continuous batching saves the day. Switch to continuous batching. Immediately, throughput jumps. The queue drains faster. Why? Because the scheduler no longer waits for every sequence in a batch to finish before starting new ones. As soon as one sequence completes, a waiting request takes its slot. The GPU never sits idle waiting for stragglers.
Experiment 3: Prefix caching crushes TTFT. Enable prefix caching. Watch the Time to First Token metric drop dramatically. When multiple requests share a common system prompt (which they almost always do in production), the KV cache for that prefix is computed once and reused. Instead of re-running prefill on 2000 shared tokens for every request, the server looks them up instantly.
Experiment 4: Speculative decoding smooths ITL. Enable speculative decoding. The inter-token latency improves because the draft model proposes several tokens cheaply and the target model verifies them all in a single forward pass. On average, 2–4 tokens get accepted per step instead of 1. The output feels noticeably snappier.
Experiment 5: Scale GPUs. Increase GPU count to 4, then 8. More GPUs means more memory for KV cache (larger batches), plus tensor parallelism reduces per-token latency for large models. Watch the throughput ceiling rise. But notice: doubling GPUs doesn't double throughput — communication overhead and Amdahl's law eat some of the gains.
Experiment 6: Find the breaking point. With all optimizations on and 8 GPUs, crank the request rate up. Find the point where the queue starts growing without bound: that is your saturation throughput. Beyond this rate, no configuration saves you. You need either a larger deployment (more GPUs or replicas) or request shedding.
Experiment 7: The model size wall. Switch from 7B to 70B. Even with 8 GPUs and all optimizations, the throughput drops dramatically. Larger models need more memory bandwidth per token. Try enabling INT8 quantization — the model fits more comfortably, freeing memory for larger batches, and compute runs faster on quantized weights.
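The gap between static and continuous batching in Experiments 1 and 2 can be reproduced with a few lines of simulation. This sketch counts decode steps only: every active request produces one token per step, static batching waits for the longest sequence in each batch, and continuous batching refills a slot as soon as it frees up. The output lengths are made-up numbers, and prefill and arrival times are ignored.

```python
def static_batching_steps(lengths, batch_size):
    """Each batch runs until its longest sequence finishes."""
    steps = 0
    for i in range(0, len(lengths), batch_size):
        steps += max(lengths[i:i + batch_size])
    return steps


def continuous_batching_steps(lengths, batch_size):
    """A waiting request takes over a slot the moment a sequence finishes."""
    queue = list(lengths)
    active = []          # remaining tokens for each running request
    steps = 0
    while queue or active:
        while queue and len(active) < batch_size:
            active.append(queue.pop(0))
        steps += 1                                   # one decode step for all slots
        active = [r - 1 for r in active if r > 1]    # drop requests that just finished
    return steps


output_lengths = [10, 200, 15, 20, 180, 12, 30, 25]   # tokens per request (synthetic)
static_steps = static_batching_steps(output_lengths, batch_size=4)
continuous_steps = continuous_batching_steps(output_lengths, batch_size=4)
print(f"static: {static_steps} steps, continuous: {continuous_steps} steps")
```

For these lengths the static scheduler needs 380 steps and the continuous one needs 200, because short requests no longer wait behind the two long ones.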
Real production inference systems add layers of complexity that this simulation abstracts away. Here's what sits between our toy server and a real deployment:
Autoscaling. Cloud deployments don't run a fixed number of GPUs. They monitor queue depth and latency, spinning up new replicas when load spikes and scaling down during quiet periods. The challenge: GPU instances take 2–5 minutes to start, so you need predictive scaling, not just reactive.
Health checks and circuit breakers. GPU processes crash. CUDA OOM errors happen. A production system continuously health-checks each worker, drains unhealthy replicas, and reroutes traffic — all without dropping user requests.
Load balancing. Not all requests are equal. A 4000-token prompt needs far more prefill compute than a 50-token prompt. Smart load balancers estimate per-request cost and route to the least-loaded replica, rather than round-robin.
Observability. You can't optimize what you can't measure. Production systems export detailed metrics — per-request latency breakdowns (queue time, prefill time, decode time), GPU memory fragmentation, KV cache hit rates, batch utilization histograms — to monitoring systems like Prometheus/Grafana.
Cost attribution. When serving multiple customers or applications, you need to track which requests consumed which resources. This feeds into billing, capacity planning, and priority scheduling.
Multi-region routing. For global applications, requests route to the nearest datacenter with available GPU capacity. This adds a geo-routing layer, cross-region failover, and the challenge of keeping model weights synchronized across regions.
You now understand the complete LLM inference stack — from the memory wall that makes generation slow, through KV caching, attention optimizations, batching strategies, quantization, speculative decoding, parallelism, prefix caching, and disaggregated serving, all the way to production deployment. Here's your reference card.
Every technique from this lesson in one table. When you're tuning a production system, start here.
| Technique | What It Fixes | Tradeoff | When to Use |
|---|---|---|---|
| Continuous Batching | GPU idle time between batches; head-of-line blocking from slow sequences | More complex scheduler; slightly higher per-request overhead | Always. No reason not to use it in production serving |
| FlashAttention | Quadratic memory in attention; slow HBM reads for Q/K/V | Custom CUDA kernel required; limited to supported GPU architectures | Always for long contexts. Standard in all modern serving frameworks |
| PagedAttention | KV cache memory fragmentation; wasted pre-allocated memory | Page table management overhead; slightly complex memory bookkeeping | Always when serving variable-length sequences (i.e., always) |
| Speculative Decoding | High inter-token latency during autoregressive decode | Needs a fast draft model; wasted compute on rejected tokens; less benefit at high batch sizes | Latency-sensitive applications; low-batch scenarios; chatbots where perceived speed matters |
| Quantization (INT8/INT4) | Model too large for GPU memory; memory bandwidth bottleneck | Small quality degradation; calibration data needed for best results | When model doesn't fit in FP16; when you need higher throughput on fixed hardware |
| Tensor Parallelism | Model too large for one GPU; single-GPU decode latency too high | All-reduce communication overhead; requires NVLink for efficiency | Models >~13B params that don't fit on one GPU; latency-sensitive serving |
| Pipeline Parallelism | Model too large even for TP across available GPUs | Pipeline bubbles reduce utilization; higher latency than TP alone | Very large models (100B+); when you have more GPUs than TP can use efficiently |
| Data Parallelism | Throughput ceiling on a single model replica | Linear cost scaling; each replica needs full model memory | When one replica can't handle the request rate; simplest way to scale throughput |
| Expert Parallelism | MoE models have too many total parameters for one GPU | All-to-all communication for token routing; load imbalance across experts | MoE architectures (Mixtral, Switch Transformer, DeepSeek-V3) |
| Prefix Caching | Redundant prefill computation for shared system prompts | Memory used to store cached prefixes; cache management complexity | When many requests share common prefixes (system prompts, few-shot examples) |
| PD Disaggregation | Prefill and decode have conflicting resource needs on shared GPUs | Network transfer of KV cache between pools; more complex orchestration | High-scale deployments where prefill/decode interference hurts both phases |
| KV Cache Offloading | GPU memory limits the number of concurrent sequences | CPU/disk read latency when swapping back; needs fast interconnect | Long-context workloads where KV cache exceeds GPU memory; lower-priority requests |
The equations that govern LLM inference performance. Bookmark this table.
| Formula | Variables | When to Use |
|---|---|---|
| Memory = P × bytes_per_param | P = parameter count; bytes_per_param = 2 (FP16), 1 (INT8), 0.5 (INT4) | Estimating GPU memory needed for model weights |
| KV = 2 × L × H × D × S × B × bytes | L = layers, H = KV heads, D = head dim, S = seq len, B = batch size, bytes = bytes per cached element; the leading 2 counts keys and values | Estimating KV cache memory consumption |
| TPOT = (P × bytes_per_param) / BW | P = params, BW = memory bandwidth (bytes/s); assumes memory-bound regime | Estimating per-token decode latency (single sequence) |
| Speedup_spec = (1 − α^(k+1)) / (1 − α) | α = acceptance rate, k = draft length | Expected tokens per verification step in speculative decoding |
| AI = FLOPs / bytes transferred | AI = arithmetic intensity (ops/byte); compare to machine's compute/BW ratio | Determining if a workload is compute-bound or memory-bound (roofline model) |
| TPS_mem = BW / (P × bytes_per_param) | TPS = tokens per second; BW = memory bandwidth | Maximum single-stream decode throughput (memory-bandwidth ceiling) |
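Here is a small calculator for these formulas, offered as a sketch rather than a benchmark. It uses the convention that weights occupy P × bytes_per_param bytes, which matches the 70B-parameter, 140 GB FP16 figure from the start of the lesson, and it computes expected tokens per speculative step as (1 − α^(k+1)) / (1 − α), following Leviathan et al. The function names are illustrative, and the example numbers in the note below (a 3.35 TB/s accelerator and an 80-layer model with 8 KV heads of dimension 128) are assumptions.

```python
def weight_memory_bytes(params, bytes_per_param):
    """Bytes needed to hold the model weights."""
    return params * bytes_per_param


def kv_cache_bytes(layers, kv_heads, head_dim, seq_len, batch, bytes_per_elem):
    """Bytes of KV cache; the leading 2 counts keys and values."""
    return 2 * layers * kv_heads * head_dim * seq_len * batch * bytes_per_elem


def tpot_seconds(params, bytes_per_param, bandwidth_bytes_per_s):
    """Per-token decode latency when every step must stream all weights."""
    return weight_memory_bytes(params, bytes_per_param) / bandwidth_bytes_per_s


def tps_ceiling(params, bytes_per_param, bandwidth_bytes_per_s):
    """Single-stream tokens/s ceiling: the reciprocal of TPOT."""
    return bandwidth_bytes_per_s / weight_memory_bytes(params, bytes_per_param)


def expected_tokens_per_step(alpha, k):
    """Expected tokens per verification step with acceptance rate alpha."""
    if alpha == 1:
        return k + 1  # every draft token accepted, plus the verifier's token
    return (1 - alpha ** (k + 1)) / (1 - alpha)


def arithmetic_intensity(flops, bytes_transferred):
    """FLOPs performed per byte moved; compare to the hardware's ratio."""
    return flops / bytes_transferred
```

Plugging in the assumed numbers: a 70B model in FP16 needs 140 GB for weights, gives a TPOT of about 42 ms (roughly 24 tokens/s) at 3.35 TB/s, and uses about 1.34 GB of KV cache for one 4096-token sequence. With α = 0.8 and k = 4, speculative decoding yields about 3.36 tokens per verification step. Single-sequence decode performs about 2 FLOPs per parameter while reading 2 bytes per parameter in FP16, an arithmetic intensity of about 1 FLOP/byte, which is why decode sits deep in the memory-bound regime.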
When deploying a model to production, walk through this decision tree. Each question narrows your configuration.
Once parallelism is sorted, layer on optimizations:
LLM inference is moving fast. Here are the frontiers worth watching:
Diffusion LLMs (dLLMs). Instead of generating tokens one at a time left-to-right, diffusion-based language models generate text by iteratively denoising a sequence of noise tokens in parallel. Models like Mercury and Gemini Diffusion can generate hundreds of tokens simultaneously, sidestepping the autoregressive bottleneck entirely. Early results show 5–10× faster generation for certain workloads, but quality is still catching up to autoregressive models for complex reasoning.
Kernel-level optimizations. Custom CUDA kernels and Triton programs squeeze the last drops of performance from GPU hardware. Fused kernels that combine attention, RoPE, and normalization into a single GPU launch eliminate memory round-trips. Libraries like FlashAttention-3 and ThunderKittens, along with hand-written Triton kernels, push throughput beyond what general-purpose frameworks achieve.
Model routing and Mixture of Agents (MoA). Instead of making one model faster, spread queries across several models. In a routing setup, a small, fast model handles simple queries and a large, expensive model handles hard ones; the router itself can be a lightweight classifier trained on difficulty estimation. Mixture of Agents goes a step further and combines the outputs of several models rather than choosing one. Both approaches trade model complexity for system complexity, but can dramatically improve cost-efficiency.
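The routing idea can be sketched as a threshold on a difficulty score. In the example below, `toy_difficulty` is a deliberately crude stand-in for the trained classifier mentioned above, and the model names, keywords, and threshold are all hypothetical.

```python
def toy_difficulty(query):
    """Placeholder difficulty score in [0, 1]; a real router would be learned."""
    score = min(1.0, len(query.split()) / 50)
    if any(word in query.lower() for word in ("prove", "derive", "debug")):
        score += 0.5
    return min(1.0, score)


def pick_model(query, difficulty_fn=toy_difficulty, threshold=0.5):
    """Route easy queries to the small model and hard ones to the large model."""
    return "large" if difficulty_fn(query) >= threshold else "small"


def blended_cost(frac_small, cost_small, cost_large):
    """Average cost per query when frac_small of the traffic takes the cheap path."""
    return frac_small * cost_small + (1 - frac_small) * cost_large
```

If the large model costs ten times as much per query as the small one and 80% of traffic can be served by the small model, the blended cost is about 2.8 units per query instead of 10. The saving depends entirely on how well the router separates easy from hard queries.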
Hardware co-design. Next-generation accelerators (Groq LPU, Cerebras WSE, custom ASICs) are designed specifically for inference workloads. They trade general-purpose compute for massive memory bandwidth and deterministic latency. The inference stack will need to adapt to radically different hardware characteristics.
This lesson connects to the broader Engineermaxxing curriculum. Continue your journey:
This lesson was built from the BentoML LLM Inference Handbook, the vLLM paper (Kwon et al., 2023), the FlashAttention papers (Dao et al., 2022, 2023), the Speculative Decoding paper (Leviathan et al., 2023), and related materials from the open-source inference community.