Introduction

A 70-billion parameter model takes millions of GPU-hours to train. But the real engineering challenge — the one that determines whether that model can serve production traffic — is inference. Every token your chatbot generates requires a full forward pass through the entire network. At 80 layers and billions of parameters, this pass is expensive, and you need to do it hundreds of times per response, for thousands of concurrent users.

The central insight that makes modern LLM inference tractable is the key-value cache: instead of recomputing attention over the entire context at each generation step, we cache the intermediate key and value tensors from previous steps. This single optimization transforms a quadratic-cost process into a linear one. But it introduces its own challenge: the KV cache itself becomes the dominant consumer of GPU memory, often exceeding the model weights in size for long sequences.

Around this caching primitive, an entire ecosystem of optimizations has emerged — continuous batching to maximize GPU utilization, speculative decoding to trade cheap compute for reduced latency, sophisticated sampling strategies that shape output quality, and parallelism schemes that distribute inference across multiple devices. This article covers all of them.

ℹ What this article covers
We'll dissect the autoregressive decode loop, the prefill/decode split, KV cache memory math with real model calculations, continuous batching, speculative decoding, sampling strategies (temperature, top-k, top-p, beam search), tensor and pipeline parallelism for inference, and the major inference frameworks (vLLM, TGI, TensorRT-LLM, SGLang).

Autoregressive Generation

Transformer language models generate text one token at a time. Given a prompt of n tokens, the model produces a probability distribution over the entire vocabulary for position n + 1; a token is selected from that distribution, appended to the context, and the process repeats. This is the autoregressive decode loop, and understanding its structure is essential to understanding every optimization that follows.

At each step, the model computes a full forward pass: the input token passes through every layer, each layer computes self-attention over the full sequence seen so far, applies feed-forward transformations, and produces a hidden state. The final hidden state is projected to vocabulary logits via the language model head. From those logits, a sampling strategy (greedy, top-k, nucleus) selects the next token. The loop terminates when the model emits an end-of-sequence token or a maximum length is reached.

The critical observation: in a naive implementation, generating the k-th token requires attending over all k - 1 previous tokens. Generating a 1,000-token response from a 2,000-token prompt means the model processes the equivalent of 2001 + 2002 + ... + 3000 tokens of attention computation — roughly 2.5 million token-pair interactions. This quadratic growth is what the KV cache eliminates.
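The arithmetic is easy to check:

```python
prompt_len, gen_len = 2000, 1000

# Naive decoding: step t reprocesses the whole context of prompt_len + t tokens,
# i.e. 2001 + 2002 + ... + 3000 tokens across the 1,000 generation steps.
naive_work = sum(prompt_len + t for t in range(1, gen_len + 1))

# With a KV cache, each token's keys and values are computed exactly once.
cached_work = prompt_len + gen_len

print(naive_work)                 # 2500500 — roughly 2.5 million
print(naive_work // cached_work)  # 833x more token processing than necessary
```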

Prefill vs Decode Phases

Inference has two distinct computational phases, each with radically different performance characteristics:

Prefill (also called the "prompt processing" phase) processes the entire input prompt in a single forward pass. Because the prompt is known in advance, all positions can be computed in parallel — the model processes all n prompt tokens simultaneously through each layer. This phase is compute-bound: the GPU's tensor cores are fully saturated performing large matrix multiplications. Prefill produces the KV cache entries for every prompt token across every layer.

Decode (the "generation" phase) produces tokens one at a time. Each decode step processes a single new token, attending to all previous tokens via the cached keys and values. This phase is memory-bandwidth-bound: the matrices involved are thin (batch dimension of 1 per sequence), so the GPU spends most of its time loading weights and KV cache entries from HBM rather than computing. The arithmetic intensity (FLOPs per byte loaded) is extremely low.

This prefill/decode split is fundamental to inference system design. Prefill wants large batch sizes and maximum compute utilization. Decode wants to minimize memory access latency. Some systems (like Splitwise and DistServe) physically separate prefill and decode onto different hardware optimized for each workload.
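A back-of-envelope arithmetic-intensity calculation makes the split concrete. This counts only weight traffic for a single square matmul and uses approximate A100 numbers; real kernels also move activations and KV entries:

```python
def arithmetic_intensity(n_tokens, d_model, bytes_per_param=2):
    """FLOPs per byte of weight traffic for one [n, d] x [d, d] matmul in FP16."""
    flops = 2 * n_tokens * d_model * d_model          # multiply-accumulate count
    bytes_moved = d_model * d_model * bytes_per_param # loading W dominates
    return flops / bytes_moved                        # simplifies to n_tokens

# A100 balance point: ~312 TFLOP/s FP16 / ~2.0 TB/s HBM ≈ 156 FLOPs/byte.
# Above it, the GPU is compute-bound; below it, memory-bandwidth-bound.
print(arithmetic_intensity(2048, 8192))  # prefill, 2048 tokens: compute-bound
print(arithmetic_intensity(1, 8192))     # decode, 1 token: badly memory-bound
```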

TTFT vs TPS: The Two Latency Metrics

Two metrics define the user experience of LLM inference:

Time to First Token (TTFT) measures the delay from request submission to the first generated token arriving. This is dominated by the prefill phase — the time to process the entire prompt and generate the first output token. For interactive applications, TTFT determines how "responsive" the system feels. Users perceive delays above ~500ms as sluggish. A 100K-token context might take several seconds to prefill, creating noticeable lag.

Tokens Per Second (TPS), whose reciprocal is reported as inter-token latency (ITL), measures the decode speed after the first token. Each subsequent token requires one forward pass through the model with the KV cache. For a typical 70B deployment, decode speed ranges from 30-60 tokens/second per user. Since humans read at roughly 250 words/minute (about 5-6 tokens/second), decode speeds above ~15 TPS feel "instant" for streaming interfaces.

There is a fundamental tension: maximizing throughput (total tokens/second across all users) requires large batches, which increase per-user latency. Inference systems must navigate this tradeoff constantly, balancing between serving individual requests quickly and processing many requests efficiently.

The KV Cache

Why Cache Keys and Values?

Recall how self-attention works: for each token at position i, we compute a query vector q_i, then dot it against key vectors k_j for all positions j ≤ i to get attention weights, then use those weights to combine value vectors v_j. The crucial property is that the key and value vectors for position j depend only on the input at position j (and the positions before it in causal attention), not on any future tokens.

This means that once we've computed k_j and v_j during the processing of position j, those vectors are immutable — they will never change regardless of what tokens come later. So instead of recomputing them at every subsequent generation step, we store them in GPU memory. At step t, the model only needs to compute q_t, k_t, and v_t for the new token, then attend against the cached K = [k_1, ..., k_{t-1}] and V = [v_1, ..., v_{t-1}] matrices.
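A minimal single-head check (toy dimensions, no multi-head structure or positional encoding) confirms that attending with cached K, V reproduces the full recompute:

```python
import torch
torch.manual_seed(0)

d = 16
Wq, Wk, Wv = (torch.randn(d, d) for _ in range(3))
x = torch.randn(10, d)                 # 10-token sequence; token 10 is "new"

def attend(q, K, V):
    w = torch.softmax(q @ K.T / d**0.5, dim=-1)
    return w @ V

# Full recompute: project K, V for all 10 tokens, attend from the last query
K_full, V_full = x @ Wk, x @ Wv
out_full = attend(x[-1:] @ Wq, K_full, V_full)

# Cached: K, V for the first 9 tokens were computed on earlier steps and stored;
# only the new token's q, k, v are computed now
K_cache, V_cache = x[:9] @ Wk, x[:9] @ Wv
k_new, v_new = x[9:] @ Wk, x[9:] @ Wv
out_cached = attend(x[9:] @ Wq,
                    torch.cat([K_cache, k_new]),
                    torch.cat([V_cache, v_new]))

print(torch.allclose(out_full, out_cached, atol=1e-5))  # True
```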

Without the KV cache, every decode step must recompute the key and value projections for the entire context before it can attend, so generating T tokens from a prompt of length P redundantly reprocesses ∑_{t=1}^{T} (P + t) = P*T + T(T+1)/2 context tokens. With the KV cache, each token's keys and values are computed exactly once: prefill costs O(P^2) attention, and each decode step t costs O(P + t) for a single query against the cached context. For typical prompt and generation lengths, this is a massive saving.

Σ Complexity comparison
Without KV cache: step t recomputes K, V for all P + t context tokens, for a total of ∑_{t=1}^{T} (P + t) = P*T + T(T+1)/2 redundant token projections.
With KV cache: each of the P + T tokens has its K, V computed exactly once; each decode step is one query against the cached context.
For P=4096, T=512: naive recomputation reprocesses ~2.23M tokens versus 4,608 with the cache, roughly 480x more work.
[Interactive demo: KV Cache Growth During Generation]

Memory Mathematics

The KV cache stores, for each token in the sequence, one key vector and one value vector per attention head per layer. The memory formula is:

KV Cache Memory = 2 × n_layers × n_kv_heads × d_head × seq_len × bytes_per_param

The factor of 2 accounts for both keys and values. n_kv_heads may differ from n_heads when using Grouped Query Attention (GQA) or Multi-Query Attention (MQA) — techniques that share key/value heads across multiple query heads to reduce cache size. In standard Multi-Head Attention (MHA), n_kv_heads = n_heads. In MQA, n_kv_heads = 1. In GQA, n_kv_heads = n_heads / group_size.
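The effect on cache size is easy to quantify. A quick sketch of the bytes-per-token term under the three attention variants, using a hypothetical 32-layer model with 32 query heads and head dimension 128 at FP16:

```python
def kv_bytes_per_token(n_layers, n_kv_heads, d_head, bytes_per_param=2):
    """KV cache bytes per token: 2 (K and V) x layers x KV heads x head dim."""
    return 2 * n_layers * n_kv_heads * d_head * bytes_per_param

mha = kv_bytes_per_token(32, 32, 128)  # MHA: n_kv_heads = n_heads
gqa = kv_bytes_per_token(32, 8, 128)   # GQA: group size 4
mqa = kv_bytes_per_token(32, 1, 128)   # MQA: one shared KV head

print(mha, gqa, mqa)  # 524288 131072 16384 bytes/token
```

GQA with group size 4 cuts the cache 4x; MQA cuts it 32x, at some quality cost.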

The precision term depends on the serving format. FP16/BF16 uses 2 bytes per parameter. FP8 and INT8 quantization reduce this to 1 byte. INT4 (GPTQ, AWQ) gets it to 0.5 bytes. The choice of quantization for the KV cache is independent of the model weight quantization — you can run INT4 weights with FP16 KV cache, or FP16 weights with INT8 KV cache.

Real Model Calculations

Let's work through the KV cache sizes for models commonly served in production:

LLaMA 3 8B: 32 layers, 8 KV heads (GQA with group size 4), head dimension 128. At FP16 with 4K context: 2 × 32 × 8 × 128 × 4096 × 2 bytes = 512 MB per sequence. At 128K context: 16 GB per sequence, as large as the model's entire FP16 weights (8B × 2 bytes ≈ 16 GB).

LLaMA 3 70B: 80 layers, 8 KV heads (GQA), head dimension 128. At FP16 with 4K context: 2 × 80 × 8 × 128 × 4096 × 2 bytes = 1.25 GB per sequence. At 128K: 40 GB per sequence. The FP16 weights alone take ~140 GB (requiring multiple GPUs), and even in INT4 (~35 GB), a single 128K sequence's 40 GB KV cache leaves almost no headroom alongside the model on an 80 GB A100.

Mixtral 8x7B (MoE): 32 layers, 8 KV heads, head dimension 128. Despite having 47B total parameters, its KV cache matches the 7B dense model because the MoE routing only affects the FFN experts, not the attention heads. At 4K context in FP16: 512 MB per sequence.

💡 KV cache is the bottleneck
For a batch of 64 concurrent users with 4K context on LLaMA 70B at FP16, the total KV cache alone requires 64 × 1.25 GB = 80 GB — an entire A100's memory before a single weight is loaded. This is why KV cache quantization, GQA, paged attention, and aggressive memory management are not optional optimizations but hard requirements for production serving.
[Interactive: KV Cache Memory Calculator — per-sequence and batch totals, entries and bytes per token]

Batching Strategies

A single decode step for one sequence barely utilizes a modern GPU. The weight matrices must be loaded from HBM regardless — whether they multiply against one vector or sixty-four. Batching multiple sequences together amortizes this memory bandwidth cost, transforming the memory-bound decode phase into something closer to compute-bound. The question is how to batch.

Static Batching: The Naive Approach

Static batching (also called "synchronous batching") groups incoming requests into a fixed batch, processes the entire batch until every sequence has finished generating, then returns all results and accepts a new batch. This is simple to implement but deeply wasteful.

The problem is obvious: sequences finish at different times. If one request generates 20 tokens and another generates 500, the 20-token response is complete in seconds but must wait for the 500-token response to finish before the slot can be reused. The GPU processes "padding" — empty computation for completed sequences — wasting both compute and memory. Worse, new requests must queue until the entire batch completes, inflating tail latency.

In practice, static batching achieves only 30-50% GPU utilization for typical workloads with high variance in output length. For the long-context era, where some requests have 100K-token prefills and others have 500-token prompts, static batching is essentially unusable.

Continuous Batching: Iteration-Level Scheduling

Continuous batching (introduced in the Orca paper, 2022) makes scheduling decisions at the granularity of individual decode iterations rather than entire request lifecycles. After each decode step, the scheduler can:

  • Evict completed sequences, immediately freeing their KV cache memory and batch slot
  • Insert new sequences that were waiting in the queue, running their prefill and joining the decode batch
  • Preempt lower-priority sequences if memory is tight, swapping their KV cache to CPU and restoring later

This iteration-level scheduling keeps batch slots occupied nearly 100% of the time. When a short response finishes at iteration 20, a new request immediately takes its slot — no waiting for the rest of the batch. GPU utilization jumps from 30-50% to 85-95%, and throughput increases 2-3x compared to static batching for the same hardware.
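An idealized simulation illustrates why iteration-level scheduling wins. This ignores prefill cost and request arrival times, and the output lengths are made up; it only models slot reuse:

```python
import random
random.seed(0)

def static_batching_steps(lengths, batch_size):
    """Each batch runs until its longest sequence finishes."""
    steps = 0
    for i in range(0, len(lengths), batch_size):
        steps += max(lengths[i:i + batch_size])
    return steps

def continuous_batching_steps(lengths, batch_size):
    """A finished sequence's slot is refilled on the very next iteration."""
    pending, active, steps = list(lengths), [], 0
    while pending or active:
        while pending and len(active) < batch_size:
            active.append(pending.pop())
        steps += 1
        active = [r - 1 for r in active if r > 1]  # drop finished sequences
    return steps

lengths = [random.randint(20, 500) for _ in range(64)]  # high-variance outputs
s = static_batching_steps(lengths, 8)
c = continuous_batching_steps(lengths, 8)
print(f"static: {s} iterations, continuous: {c} ({s / c:.2f}x fewer)")
```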

The engineering challenge is memory management. Each new sequence needs contiguous KV cache space, but sequences have unpredictable lengths. This is where PagedAttention (introduced by vLLM) becomes critical: it manages KV cache memory in fixed-size "pages" (like OS virtual memory), eliminating fragmentation and enabling near-zero waste of GPU memory. Without paged attention, continuous batching must conservatively pre-allocate maximum-length KV cache per sequence, limiting batch sizes.
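The page-table idea can be sketched in a few lines. This is a toy allocator to show the mechanism, not vLLM's actual implementation:

```python
class PagedKVAllocator:
    """Toy PagedAttention-style allocator: fixed-size pages, per-sequence block tables."""
    def __init__(self, n_pages, page_size=16):
        self.page_size = page_size
        self.free = list(range(n_pages))
        self.block_tables = {}   # seq_id -> list of physical page ids
        self.lengths = {}        # seq_id -> number of cached tokens

    def append_token(self, seq_id):
        table = self.block_tables.setdefault(seq_id, [])
        n = self.lengths.get(seq_id, 0)
        if n == len(table) * self.page_size:   # current page full: grab a new one
            if not self.free:
                raise MemoryError("out of KV pages (would trigger preemption)")
            table.append(self.free.pop())
        self.lengths[seq_id] = n + 1

    def free_sequence(self, seq_id):
        """Return a finished sequence's pages to the pool immediately."""
        self.free.extend(self.block_tables.pop(seq_id, []))
        self.lengths.pop(seq_id, None)

alloc = PagedKVAllocator(n_pages=8, page_size=16)
for _ in range(40):                          # 40 tokens -> ceil(40/16) = 3 pages
    alloc.append_token("seq-A")
print(alloc.block_tables["seq-A"], len(alloc.free))
alloc.free_sequence("seq-A")
print(len(alloc.free))                       # all 8 pages free again
```

Because pages need not be contiguous, no sequence ever needs a max-length pre-allocation, and a finished sequence's memory is reusable on the next iteration.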

[Interactive demo: Static vs Continuous Batching]

Speculative Decoding

Standard autoregressive decoding has an inherent inefficiency: each token requires a full forward pass through the large model, but the GPU is barely utilized during decode (low arithmetic intensity). Speculative decoding exploits this gap by using a small, fast "draft" model to predict multiple tokens ahead, then verifying all of them in a single forward pass of the large "target" model.

The Draft-and-Verify Protocol

The algorithm works as follows:

  1. Draft: Run a small model (e.g., a 1B parameter model for a 70B target) autoregressively for K steps (typically K = 4-8), generating candidate tokens [d_1, d_2, ..., d_K].
  2. Verify: Feed the original context plus all K draft tokens into the target model in a single forward pass. This gives us the target model's probability distributions at all K positions simultaneously.
  3. Accept/Reject: Compare each draft token against the target model's distribution. Accept tokens sequentially as long as they match (or are acceptable under a stochastic acceptance criterion). On the first rejection, resample from an adjusted distribution and discard remaining draft tokens.
  4. Repeat: Continue from the last accepted position.

The key insight is that the verification step processes all K tokens in parallel at nearly the cost of a single decode step: decode is memory-bandwidth-bound, so loading the weights dominates, and computing K + 1 positions instead of 1 adds little. If the draft model matches the target well and most tokens are accepted, we generate up to K + 1 tokens for the cost of one target-model forward pass plus K cheap draft-model passes.

Acceptance Criteria

The stochastic acceptance scheme preserves the exact output distribution of the target model. For each draft token d_i with draft probability q(d_i) and target probability p(d_i):

Accept with probability: min(1, p(d_i) / q(d_i))

If the target model assigns higher probability to the draft token than the draft model did, it's always accepted. If lower, it's accepted proportionally. On rejection, we sample from a corrected distribution norm(max(0, p(x) - q(x))) to ensure the final distribution is exactly p. This mathematical guarantee means speculative decoding is lossless — the output distribution is identical to standard sampling from the target model alone.
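The losslessness claim can be checked empirically with a toy 4-token vocabulary and a deliberately mismatched draft distribution:

```python
import random
random.seed(42)

p = [0.5, 0.3, 0.15, 0.05]    # target distribution
q = [0.25, 0.25, 0.25, 0.25]  # mismatched draft distribution

def residual(p, q):
    """Corrected distribution norm(max(0, p - q)) used on rejection."""
    r = [max(0.0, pi - qi) for pi, qi in zip(p, q)]
    z = sum(r)
    return [ri / z for ri in r]

def spec_sample(p, q):
    d = random.choices(range(4), weights=q)[0]     # draft proposes a token
    if random.random() < min(1.0, p[d] / q[d]):    # accept with prob min(1, p/q)
        return d
    return random.choices(range(4), weights=residual(p, q))[0]

N = 200_000
counts = [0] * 4
for _ in range(N):
    counts[spec_sample(p, q)] += 1
print([round(c / N, 3) for c in counts])  # close to [0.5, 0.3, 0.15, 0.05]
```

Despite sampling through the poorly matched draft, the empirical output distribution converges to p, exactly as the math guarantees.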

In practice, well-matched draft/target pairs (e.g., LLaMA 1B drafting for LLaMA 70B) achieve acceptance rates of 70-85% on typical text, yielding effective speedups of 2-3x. Code generation tends to have higher acceptance rates (more predictable tokens) while creative writing has lower rates. The optimal draft length K depends on the acceptance rate — higher acceptance rates justify longer speculative sequences.

[Interactive demo: Speculative Decoding — Draft & Verify]

Sampling Strategies

The model's forward pass produces a vector of logits over the vocabulary — raw, unnormalized scores for each possible next token. Converting these logits into a chosen token is the job of the sampling strategy, and the choice profoundly affects output quality, creativity, and coherence.

Temperature Scaling

Temperature divides the logits by a scalar T before applying softmax: p_i = exp(z_i / T) / ∑ exp(z_j / T). Temperature controls the "sharpness" of the distribution:

  • T = 1.0: The model's natural distribution. This is the default calibration from training.
  • T < 1.0: Sharpens the distribution — high-probability tokens become even more likely, low-probability tokens become negligible. At T → 0, this converges to greedy decoding (always pick the argmax).
  • T > 1.0: Flattens the distribution — lower-probability tokens get a meaningful chance. The output becomes more "creative" but also more likely to produce incoherent text.

A common misconception: temperature does not change which tokens the model considers possible, only their relative probabilities. A token with logit -100 remains negligible at any reasonable temperature. Temperature is most effective when the model is genuinely uncertain between several plausible continuations.
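A few lines make both points concrete; the logits here are illustrative:

```python
import math

def softmax_with_temperature(logits, T):
    exps = [math.exp(z / T) for z in logits]
    s = sum(exps)
    return [e / s for e in exps]

logits = [2.0, 1.0, 0.0, -100.0]   # last token is effectively impossible

for T in (0.5, 1.0, 2.0):
    probs = softmax_with_temperature(logits, T)
    # low T sharpens the top token; the logit -100 token stays negligible at any T
    print(f"T={T}: top={probs[0]:.3f}, impossible={probs[3]:.2e}")
```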

Top-k and Top-p (Nucleus) Sampling

Top-k sampling restricts the candidate pool to the k most probable tokens, then renormalizes and samples. With k = 50, the model can only choose among its top 50 predictions regardless of how the probability mass is distributed. The problem: for some contexts, the model is very confident (top 3 tokens hold 99% mass) and k = 50 includes too many garbage options. For other contexts, the model is genuinely uncertain (top 200 tokens each hold ~0.5%) and k = 50 cuts off reasonable options.

Top-p (nucleus) sampling addresses this by dynamically sizing the candidate pool. It sorts tokens by probability and includes the smallest set whose cumulative probability exceeds p. With p = 0.95, the model samples from however many tokens are needed to cover 95% of the probability mass — sometimes 3 tokens, sometimes 300. This adapts to the model's confidence at each step, a property top-k lacks.

In practice, most production systems combine temperature with top-p: first scale logits by temperature, then apply nucleus sampling. Typical production settings are T = 0.7, top-p = 0.9 for a balance of quality and diversity. Some systems add a repetition penalty that reduces the logits of tokens that have appeared recently in the output, combating the tendency of language models to enter repetitive loops.

Min-p sampling is a newer approach that sets a minimum probability threshold relative to the top token: any token with probability less than min_p * max_prob is excluded. This shares the adaptive property of nucleus sampling while being simpler to reason about and tune.
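Minimal sketches of the three filters, operating directly on a probability vector rather than logits; production implementations differ in details such as tie handling:

```python
import torch

def top_k_filter(probs, k):
    """Keep only the k most probable tokens, then renormalize."""
    vals, idx = torch.topk(probs, k)
    out = torch.zeros_like(probs)
    out[idx] = vals
    return out / out.sum()

def top_p_filter(probs, p):
    """Keep the smallest set whose cumulative probability reaches p."""
    vals, idx = torch.sort(probs, descending=True)
    keep = torch.cumsum(vals, 0) - vals < p   # mass before this token < p
    out = torch.zeros_like(probs)
    out[idx[keep]] = vals[keep]
    return out / out.sum()

def min_p_filter(probs, min_p):
    """Keep tokens with probability >= min_p times the top token's probability."""
    keep = probs >= min_p * probs.max()
    out = torch.where(keep, probs, torch.zeros_like(probs))
    return out / out.sum()

probs = torch.tensor([0.50, 0.25, 0.15, 0.06, 0.04])
print(top_k_filter(probs, 2))     # only the top two survive
print(top_p_filter(probs, 0.90))  # smallest set covering 90% of the mass
print(min_p_filter(probs, 0.10))  # anything below 10% of the max is dropped
```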

Beam search maintains B candidate sequences (beams) at each step, expanding each by all vocabulary tokens, scoring the results, and keeping the top B candidates. It finds high-probability sequences but has notable failure modes for open-ended generation: outputs tend to be repetitive, generic, and short. Beam search implicitly optimizes for the most probable sequence, which often means bland, high-frequency text.

For open-ended generation (chatbots, creative writing), sampling with temperature and top-p consistently produces better results. Beam search remains useful for tasks with a "correct" answer — machine translation, summarization, and code generation where you want the most likely output. Many code-generation systems use beam search or a hybrid approach where multiple samples are generated and ranked.

Parallelism for Inference

A 70B model's FP16 weights alone occupy ~140 GB, far beyond any single GPU, and even with INT4 quantization (~35 GB) the KV cache for realistic batch sizes quickly exhausts one device. Two primary parallelism strategies serve inference, each with distinct tradeoffs.

Tensor Parallelism (TP)

Tensor parallelism splits individual weight matrices across GPUs. For a linear layer Y = XW, the weight matrix W is sharded column-wise across N GPUs. Each GPU computes its partition of the output, then an all-reduce operation combines results. For attention layers, different heads are assigned to different GPUs — a natural split since attention heads are independent computations.

Tensor parallelism is the standard approach for latency-sensitive inference because it reduces per-layer computation time by a factor of N. The cost is communication: each layer requires an all-reduce across all participating GPUs, demanding high-bandwidth interconnects (NVLink at 900 GB/s, not PCIe at 64 GB/s). TP works well within a single node (8 GPUs on one server with NVLink) but poorly across nodes where network bandwidth is 100-400 Gb/s.
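The sharding math can be verified on one device by simulating the GPUs with tensor chunks; `torch.cat` stands in for all-gather and the Python `sum` for all-reduce:

```python
import torch
torch.manual_seed(0)

X = torch.randn(4, 64)    # activations: [batch, d_model]
W = torch.randn(64, 256)  # weight: [d_model, d_ff]
ref = X @ W               # the unsharded result we must reproduce

# Column-parallel: each of 4 "GPUs" holds a slice of W's columns and
# computes its slice of the output independently.
shards = torch.chunk(W, 4, dim=1)
partials = [X @ s for s in shards]
col_out = torch.cat(partials, dim=1)          # all-gather along features
print(torch.allclose(ref, col_out, atol=1e-5))

# Row-parallel: shard W's rows and X's features; partial outputs are
# full-shaped and must be summed (the all-reduce the article describes).
w_rows = torch.chunk(W, 4, dim=0)
x_cols = torch.chunk(X, 4, dim=1)
row_out = sum(xc @ wr for xc, wr in zip(x_cols, w_rows))
print(torch.allclose(ref, row_out, atol=1e-4))
```

In a real transformer MLP the column-parallel and row-parallel layers are stacked so that only one all-reduce is needed per block.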

For the KV cache, tensor parallelism means each GPU stores 1/N of the KV heads. With LLaMA 70B using 8 KV heads across TP=8, each GPU holds exactly one KV head per layer, the minimum granularity. This is one reason KV head counts are typically powers of two that divide evenly into common tensor-parallel degrees.

Pipeline Parallelism (PP)

Pipeline parallelism assigns different layers to different GPUs. With an 80-layer model across 4 GPUs, each GPU handles 20 contiguous layers. Activation tensors flow from GPU to GPU as data moves through the network. The KV cache for each layer lives on whichever GPU owns that layer.

Pipeline parallelism requires much less inter-GPU communication than tensor parallelism — only point-to-point transfers between adjacent stages rather than all-to-all communication at every layer. This makes it viable across nodes connected by standard networking.

The downside is pipeline bubbles: during decode, only one pipeline stage is active at a time per sequence (since each stage depends on the previous stage's output). With PP=4 and a single sequence, each GPU is idle 75% of the time. Large batch sizes mitigate this by filling the pipeline, but this conflicts with low-latency requirements. Pipeline parallelism is therefore best suited for throughput-oriented workloads rather than real-time interactive use.

Most production deployments combine both: TP within a node (8-way across NVLink-connected GPUs) and PP across nodes. A 4-node cluster with 8 GPUs each might use TP=8, PP=4 for a total of 32 GPUs serving a single large model.

The Inference Stack

Several open-source and commercial frameworks implement the optimizations described above. Each makes different design choices reflecting different priorities.

vLLM introduced PagedAttention and remains the most widely deployed open-source inference engine. Its core innovation is managing KV cache memory like OS virtual memory — non-contiguous physical pages mapped to logical sequence positions. This eliminates memory fragmentation and enables near-optimal batch sizes. vLLM supports continuous batching, tensor parallelism, and a wide range of models. Its throughput for online serving workloads is typically the benchmark that other systems measure against.

TGI (Text Generation Inference) by Hugging Face provides a production-ready serving solution with built-in support for flash attention, continuous batching, quantization (GPTQ, AWQ, EETQ), tensor parallelism, and a gRPC/HTTP API with streaming. TGI integrates tightly with the Hugging Face ecosystem, making it the simplest path from model training to deployment for teams already using Transformers.

TensorRT-LLM by NVIDIA compiles model graphs into highly optimized CUDA kernels using the TensorRT compiler. It achieves the highest single-stream performance on NVIDIA hardware by fusing operations, using custom attention kernels, and leveraging hardware-specific features (FP8 on Hopper, INT4 on Ada). The tradeoff is complexity: model conversion, engine building, and configuration require significant NVIDIA-specific expertise. TensorRT-LLM supports in-flight batching (NVIDIA's term for continuous batching), speculative decoding, and multi-GPU serving.

SGLang takes a compiler-centric approach, introducing a domain-specific language for structured generation. Its RadixAttention mechanism caches KV states at a prefix tree level, enabling efficient reuse of shared prefixes across requests — critical for systems that prepend the same system prompt to every request. SGLang achieves up to 5x speedup on workloads with shared prefixes and structured output constraints (JSON mode, regex-constrained generation).

Framework     | Key Innovation                  | Best For                            | KV Management
vLLM          | PagedAttention                  | General-purpose high throughput     | Paged (virtual memory)
TGI           | HF ecosystem integration        | Quick deployment, HF models         | Flash Attention + contiguous
TensorRT-LLM  | Compiled CUDA kernels           | Maximum single-stream perf          | In-flight batching
SGLang        | RadixAttention, prefix caching  | Shared prefixes, structured output  | Radix tree prefix cache

Code Examples

Let's implement the core concepts from this article in code. We'll build a minimal KV cache, implement speculative decoding logic, and show practical inference configurations.

Implementing a KV Cache

python
import torch

class KVCache:
    """Minimal KV cache for a single attention layer."""
    def __init__(self, max_seq_len, n_kv_heads, d_head, dtype=torch.float16):
        self.max_seq_len = max_seq_len
        # Pre-allocate buffers: [batch=1, n_kv_heads, max_seq, d_head]
        self.k = torch.zeros(1, n_kv_heads, max_seq_len, d_head, dtype=dtype)
        self.v = torch.zeros(1, n_kv_heads, max_seq_len, d_head, dtype=dtype)
        self.seq_len = 0   # current number of cached tokens

    def update(self, k_new, v_new):
        """Append new key/value vectors. k_new: [1, n_kv_heads, n_new, d_head]"""
        n_new = k_new.shape[2]
        end = self.seq_len + n_new
        assert end <= self.max_seq_len, "KV cache overflow"
        self.k[:, :, self.seq_len:end] = k_new
        self.v[:, :, self.seq_len:end] = v_new
        self.seq_len = end

    def get(self):
        """Return cached K, V up to current sequence length."""
        return self.k[:, :, :self.seq_len], self.v[:, :, :self.seq_len]

    def memory_bytes(self):
        """Report current memory usage."""
        element_size = self.k.element_size()  # 2 for fp16
        return 2 * self.seq_len * self.k.shape[1] * self.k.shape[3] * element_size

# ── Usage: simulate a 32-layer model ──────────────────────
n_layers, n_kv_heads, d_head = 32, 8, 128
caches = [KVCache(4096, n_kv_heads, d_head) for _ in range(n_layers)]

# Prefill: process 512 prompt tokens
for cache in caches:
    k = torch.randn(1, n_kv_heads, 512, d_head, dtype=torch.float16)
    v = torch.randn(1, n_kv_heads, 512, d_head, dtype=torch.float16)
    cache.update(k, v)

print(f"After prefill (512 tokens):")
print(f"  Per-layer cache: {caches[0].memory_bytes() / 1024:.0f} KB")
print(f"  Total cache:     {sum(c.memory_bytes() for c in caches) / 1024**2:.1f} MB")

# Decode: generate tokens one at a time
for step in range(100):
    for cache in caches:
        k = torch.randn(1, n_kv_heads, 1, d_head, dtype=torch.float16)
        v = torch.randn(1, n_kv_heads, 1, d_head, dtype=torch.float16)
        cache.update(k, v)

print(f"\nAfter 100 decode steps (612 total tokens):")
print(f"  Per-layer cache: {caches[0].memory_bytes() / 1024:.0f} KB")
print(f"  Total cache:     {sum(c.memory_bytes() for c in caches) / 1024**2:.1f} MB")

Speculative Decoding Logic

python
import torch
import torch.nn.functional as F

@torch.no_grad()
def speculative_decode(target_model, draft_model, input_ids, K=5):
    """
    Speculative decoding: use draft_model to propose K tokens,
    verify with target_model in one pass.
    Returns accepted tokens (at least 1, at most K+1).
    """
    draft_tokens = []
    draft_dists = []          # full draft distribution at each position
    current = input_ids.clone()

    # Step 1: Draft K tokens autoregressively with the small model
    for _ in range(K):
        logits = draft_model(current).logits[:, -1, :]    # [1, vocab]
        probs = F.softmax(logits, dim=-1)
        token = torch.multinomial(probs, 1)                # [1, 1]
        draft_tokens.append(token.item())
        draft_dists.append(probs[0])                       # keep q(.) for rejection
        current = torch.cat([current, token], dim=1)

    # Step 2: Verify all K draft tokens with target model in one pass.
    # `current` is now [prompt ... d1 d2 ... dK]
    target_logits = target_model(current).logits           # [1, seq+K, vocab]

    # Step 3: Accept/reject each draft token
    accepted = []
    n_prompt = input_ids.shape[1]
    for i, (d_tok, q_dist) in enumerate(zip(draft_tokens, draft_dists)):
        # Target distribution at the position that generated this token
        # (logits at index j predict the token at index j + 1)
        p_dist = F.softmax(target_logits[0, n_prompt + i - 1], dim=-1)
        p_prob, q_prob = p_dist[d_tok].item(), q_dist[d_tok].item()

        # Stochastic acceptance: accept with prob min(1, p/q)
        if torch.rand(1).item() < min(1.0, p_prob / max(q_prob, 1e-10)):
            accepted.append(d_tok)
        else:
            # Rejection: sample from the corrected distribution norm(max(0, p - q)),
            # reusing the draft distribution saved during drafting
            adjusted = torch.clamp(p_dist - q_dist, min=0)
            adjusted = adjusted / adjusted.sum()
            accepted.append(torch.multinomial(adjusted, 1).item())
            break   # discard the remaining draft tokens

    # If all K accepted, sample one bonus token from the target's last position
    if len(accepted) == K:
        bonus_probs = F.softmax(target_logits[0, n_prompt + K - 1], dim=-1)
        accepted.append(torch.multinomial(bonus_probs, 1).item())

    return accepted   # 1 to K+1 tokens per round

Practical vLLM Configuration

python
from vllm import LLM, SamplingParams

# ── Launch LLaMA 70B with tensor parallelism ──────────────
llm = LLM(
    model="meta-llama/Meta-Llama-3-70B-Instruct",
    tensor_parallel_size=4,        # Split across 4 GPUs
    dtype="bfloat16",
    max_model_len=8192,            # Max context length
    gpu_memory_utilization=0.92,   # Reserve 8% for overhead
    enable_prefix_caching=True,    # Cache shared system prompts
    # KV cache quantization (optional — saves ~50% KV memory)
    # kv_cache_dtype="fp8_e5m2",
)

# ── Sampling parameters ──────────────────────────────────
params = SamplingParams(
    temperature=0.7,
    top_p=0.9,
    max_tokens=1024,
    repetition_penalty=1.1,
    stop=["<|eot_id|>"],
)

# ── Generate with continuous batching (automatic) ─────────
prompts = [
    "Explain the KV cache in transformer inference.",
    "Write a Python function to compute Fibonacci numbers.",
    "What causes GPU memory fragmentation during LLM serving?",
]
outputs = llm.generate(prompts, params)

for out in outputs:
    print(f"Prompt: {out.prompt[:60]}...")
    print(f"Output: {out.outputs[0].text[:200]}...")
    print(f"Tokens generated: {len(out.outputs[0].token_ids)}")
    print()

KV Cache Memory Estimation Utility

python
def kv_cache_memory(
    n_layers: int,
    n_kv_heads: int,
    d_head: int,
    seq_len: int,
    batch_size: int = 1,
    precision: str = "fp16",
) -> dict:
    """Calculate KV cache memory requirements.

    Returns dict with per-sequence and total memory in bytes.
    """
    bytes_per_param = {"fp32": 4, "fp16": 2, "bf16": 2,
                       "fp8": 1, "int8": 1, "int4": 0.5}[precision]

    # 2 for K and V, per layer, per head, per position, per dimension
    per_seq = 2 * n_layers * n_kv_heads * d_head * seq_len * bytes_per_param
    total = per_seq * batch_size

    return {
        "per_sequence_bytes": per_seq,
        "per_sequence_gb": per_seq / (1024**3),
        "total_bytes": total,
        "total_gb": total / (1024**3),
        "bytes_per_token": 2 * n_layers * n_kv_heads * d_head * bytes_per_param,
    }

# ── Common model configurations ──────────────────────────
models = {
    "LLaMA-3-8B":  dict(n_layers=32,  n_kv_heads=8,   d_head=128),
    "LLaMA-3-70B": dict(n_layers=80,  n_kv_heads=8,   d_head=128),
    "Mistral-7B":  dict(n_layers=32,  n_kv_heads=8,   d_head=128),
    "Qwen2-72B":   dict(n_layers=80,  n_kv_heads=8,   d_head=128),
    "GPT-4 (est)": dict(n_layers=120, n_kv_heads=16,  d_head=128),
}

for name, cfg in models.items():
    for seq_len in [4096, 32768, 131072]:
        info = kv_cache_memory(**cfg, seq_len=seq_len, batch_size=1)
        print(f"{name:16s} @ {seq_len:>7,} ctx: "
              f"{info['per_sequence_gb']:.2f} GB/seq "
              f"({info['bytes_per_token']:,} bytes/tok)")
    print()

📌 What comes next

With inference fundamentals covered, the next frontier is making these systems cheaper. In Article 08: Quantization & Compression, we'll explore how INT8, INT4, and sub-4-bit quantization reduce both model weights and KV cache memory, the tradeoffs between PTQ and QAT, and techniques like GPTQ, AWQ, and SqueezeLLM that make 70B models run on a single consumer GPU.

References

Seminal papers and key works referenced in this article.

  1. Pope et al. "Efficiently Scaling Transformer Inference." MLSys, 2023.
  2. Kwon et al. "Efficient Memory Management for Large Language Model Serving with PagedAttention." SOSP, 2023.
  3. Leviathan et al. "Fast Inference from Transformers via Speculative Decoding." ICML, 2023.
  4. Yu et al. "Orca: A Distributed Serving System for Transformer-Based Generative Models." OSDI, 2022.
  5. Shazeer et al. "Outrageously Large Neural Networks: The Sparsely-Gated Mixture-of-Experts Layer." ICLR, 2017.
  6. Fedus et al. "Switch Transformers: Scaling to Trillion Parameter Models with Simple and Efficient Sparsity." JMLR, 2022.