Introduction

Self-attention computes pairwise interactions between every token in a sequence. For a sequence of length N, this produces an N x N attention matrix, which is quadratic in both compute and memory. For N = 128K (131,072) tokens, a modest context window by 2025 standards, that matrix contains about 17.2 billion entries. Materializing it in float16 would require 32 GiB of GPU memory for a single attention head in a single layer.
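The arithmetic is easy to check. The following is a minimal sketch. It assumes "128K" means 131,072 tokens and 2 bytes per float16 entry, and the helper name is illustrative:

```python
def attention_matrix_bytes(n_tokens: int, bytes_per_entry: int = 2) -> int:
    """Bytes needed to materialize one N x N attention score matrix."""
    return n_tokens * n_tokens * bytes_per_entry

n = 128 * 1024                      # "128K" tokens = 131,072
entries = n * n                     # one score per (query, key) pair
gib = attention_matrix_bytes(n) / 2**30
print(f"{entries:,} entries, {gib:.0f} GiB in float16")  # 17,179,869,184 entries, 32 GiB in float16
```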

The naive reaction is to assume this is a compute problem — that we need faster math. The deeper truth, discovered by Tri Dao and collaborators in 2022, is that it is primarily an IO problem. Modern GPUs can compute arithmetic operations far faster than they can shuttle data between memory levels. The attention matrix doesn't need to exist in HBM at all. FlashAttention proved this by computing exact attention 2-4x faster while using O(N) memory instead of O(N^2), changing nothing about the mathematical output — only where each intermediate result lives.

PagedAttention (Kwon et al., 2023) attacks a different but equally critical bottleneck: the KV cache. During autoregressive generation, every past token's key and value vectors must be stored for future attention computations. Systems that allocate the KV cache contiguously waste 60-80% of its memory through internal and external fragmentation. PagedAttention borrows the operating system's virtual memory abstraction (page tables, non-contiguous allocation, copy-on-write) and applies it to KV cache management.

ℹ What this article covers
We'll start with GPU memory architecture to build physical intuition for why memory access patterns matter. Then we'll walk through FlashAttention's tiling algorithm and online softmax trick in detail, analyze its IO complexity, cover the v2 and v3 improvements, pivot to PagedAttention's virtual memory approach for KV cache, and finish with Ring Attention for ultra-long sequences, benchmarks, and working code.

GPU Memory Hierarchy

Understanding FlashAttention requires understanding where data lives on a GPU. Modern GPUs have a layered memory hierarchy, and the performance difference between levels is enormous:

| Memory level | Size (A100) | Bandwidth | Latency |
|---|---|---|---|
| Registers | ~27 MB total | >19 TB/s | ~1 cycle |
| SRAM (shared memory) | 20 MB (192 KB/SM) | ~19 TB/s | ~20 cycles |
| HBM (global memory) | 40/80 GB | 2.0 TB/s | ~400 cycles |
| CPU DRAM | 512+ GB | ~50 GB/s | ~10K cycles |

The critical ratio: SRAM is roughly 10x faster than HBM, with ~20x lower latency. HBM is roughly 40x faster than CPU DRAM. A GPU kernel that fits its working set entirely in SRAM can run an order of magnitude faster than one that repeatedly reads from and writes to HBM.

But SRAM is tiny. On an A100, each of the 108 streaming multiprocessors (SMs) has 192 KB of shared memory. That is 20 MB total — enough to hold roughly 10 million fp16 values. For a sequence of length 8192 with head dimension 128, a single Q-K tile of size 128 x 128 uses just 32 KB. This is the key insight: if we can break the attention computation into tiles that fit in SRAM, we avoid most HBM traffic.

The Bandwidth Bottleneck

The A100's peak compute throughput is 312 TFLOPS for fp16 tensor core operations. Its HBM bandwidth is 2.0 TB/s. The arithmetic intensity (ratio of compute to memory access) threshold is:

Operational Intensity = FLOPs / Bytes Accessed

A100 ridge point = 312 TFLOPS / 2.0 TB/s = 156 FLOPs/byte

Any operation below 156 FLOPs per byte of HBM access is memory-bound: the GPU spends most of its time waiting for data, not computing. In standard attention, the softmax pass reads and writes every element of the N x N matrix while doing only a few FLOPs per element. That gives an arithmetic intensity of O(1) FLOPs/byte. Add the round trips of the N x N intermediates through HBM, and standard attention is profoundly memory-bound.
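A rough roofline check makes the contrast concrete. This sketch makes several assumptions:

- It uses the A100 figures above and fp16 operands (2 bytes).
- It counts only compulsory HBM traffic: each operand is read once and the result is written once.
- The ~5 FLOPs per softmax element is a loose estimate (max, subtract, exp, sum, divide).

```python
PEAK_FLOPS = 312e12            # A100 dense fp16 tensor-core peak, FLOP/s
HBM_BW = 2.0e12                # A100 80GB HBM bandwidth, bytes/s
RIDGE = PEAK_FLOPS / HBM_BW    # 156 FLOPs/byte

def matmul_intensity(m, k, n, bytes_per=2):
    """FLOPs per byte for C = A @ B, A (m, k), B (k, n): read A and B, write C once."""
    flops = 2 * m * k * n
    traffic = bytes_per * (m * k + k * n + m * n)
    return flops / traffic

def softmax_intensity(flops_per_elem=5, bytes_per=2):
    """Row softmax over the N x N matrix: each element read once and written once."""
    return flops_per_elem / (2 * bytes_per)

N, d = 8192, 128
qk = matmul_intensity(N, d, N)     # S = Q @ K.T, with the N x N result written to HBM
sm = softmax_intensity()
for name, ai in [("QK^T matmul", qk), ("softmax", sm)]:
    bound = "compute-bound" if ai >= RIDGE else "memory-bound"
    print(f"{name:12s} {ai:7.1f} FLOPs/byte -> {bound}")
```

With d = 128, even the Q K^T matmul lands at roughly 124 FLOPs/byte, just below the ridge point. Its N x N output has to be written to HBM, so its intensity is approximately d. The softmax pass sits near 1 FLOP/byte.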

This is why FlashAttention targets HBM reads and writes specifically, not FLOPs. It actually performs more total floating-point operations (due to recomputation in the backward pass), yet runs faster because it dramatically reduces HBM traffic.

Standard Attention IO Analysis

Let's trace the exact HBM reads and writes in a standard attention implementation. Given queries Q, keys K, values V each of shape (N, d), where N is the sequence length and d is the head dimension:

python
import torch
N, d = 1024, 64
Q, K, V = (torch.randn(N, d) for _ in range(3))
# Standard attention (on a GPU, every intermediate lives in HBM)
S = Q @ K.T / d**0.5           # (N, d) @ (d, N) → (N, N): write N² to HBM
P = torch.softmax(S, dim=-1)   # read N² from HBM, write N² to HBM
O = P @ V                      # (N, N) @ (N, d) → (N, d): read N² + Nd from HBM

The total HBM access is dominated by the N x N intermediate matrices:

$$\text{HBM reads/writes} = O(Nd + N^2)$$

Step 1: Read Q, K from HBM ($2Nd$), write S to HBM ($N^2$)
Step 2: Read S from HBM ($N^2$), write P to HBM ($N^2$)
Step 3: Read P, V from HBM ($N^2 + Nd$), write O to HBM ($Nd$)

Total: $4N^2 + 4Nd = O(Nd + N^2)$ reads + writes

For typical LLM configurations (N = 8192, d = 128), the N^2 term is 67 million entries per head. That dwarfs the roughly 4.2 million entries in Q, K, V, and O combined (about 1 million each). S and P are each written once and read once. The N x N traffic is therefore 4N^2, about 268 million element accesses. That is 64x the 4Nd (about 4.2 million) accesses for the inputs and outputs. This matrix is computed, written to HBM, immediately read back, used once, and discarded. It is the prototypical example of wasted IO.

Concrete Numbers
N = 8192, d = 128, fp16 (2 bytes per element):
Q, K, V, O: 4 x 8192 x 128 x 2 = 8 MB
S, P matrices: 2 x 8192 x 8192 x 2 = 256 MB

The intermediate matrices are 32x larger than the inputs/outputs.
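The three-step accounting above can be tallied in a few lines. This sketch counts HBM element accesses for a single head, and the helper name is illustrative:

```python
def standard_attention_hbm_accesses(N: int, d: int) -> dict:
    """Element reads/writes to HBM for the three-step standard attention of one head."""
    step1 = 2 * N * d + N * N          # read Q, K; write S
    step2 = N * N + N * N              # read S; write P
    step3 = N * N + N * d + N * d      # read P, V; write O
    total = step1 + step2 + step3
    nd_terms = 4 * N * d               # traffic for Q, K, V, O
    return {"total": total, "nd_terms": nd_terms, "n2_terms": total - nd_terms}

acc = standard_attention_hbm_accesses(8192, 128)
print(f"N^2 traffic: {acc['n2_terms']:,}  Nd traffic: {acc['nd_terms']:,}")
print(f"ratio: {acc['n2_terms'] / acc['nd_terms']:.0f}x")   # N / d = 64
```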

FlashAttention v1

FlashAttention (Dao et al., 2022) eliminates the N x N materialization entirely. The key idea: compute attention block by block, keeping intermediate results in SRAM and never writing the full attention matrix to HBM. The algorithm requires two crucial innovations: tiling (breaking the computation into SRAM-sized blocks) and online softmax (computing softmax incrementally without needing the full row).

The Tiling Strategy

Partition Q into blocks of size Br x d (row blocks) and K, V into blocks of size Bc x d (column blocks). The block sizes are chosen so that a Q block, a K block, a V block, and the partial output all fit in SRAM simultaneously:

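$$B_r d + 2B_c d + B_r B_c + B_r d \le M$$

where M is the SRAM size (in elements). The terms are the Q block, the K and V blocks, the $B_r \times B_c$ score tile, and the partial output block.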
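To see what this budget allows, here is a small search over block sizes. It counts every buffer listed above: the Q block, the K and V blocks, the score tile, and the partial output.

The sketch rests on several assumptions:

- The whole 192 KB of an A100 SM holds fp16 tiles (98,304 elements). This overstates what a real kernel gets once registers and other buffers are accounted for.
- Block sizes are restricted to powers of two. This is an illustrative choice, not FlashAttention's actual tuning rule.
- The helper names are hypothetical.

```python
def sram_elements(br: int, bc: int, d: int) -> int:
    """Elements resident at once: Q_i and O_i (br x d), K_j and V_j (bc x d), S_ij (br x bc)."""
    return 2 * br * d + 2 * bc * d + br * bc

def largest_bc(br: int, d: int, m: int, max_bc: int = 4096) -> int:
    """Largest power-of-two B_c that fits next to a B_r-row query block (0 if none fits)."""
    best, bc = 0, 1
    while bc <= max_bc:
        if sram_elements(br, bc, d) <= m:   # usage grows with bc, so the last fit wins
            best = bc
        bc *= 2
    return best

M = 192 * 1024 // 2          # 192 KB of SRAM expressed as fp16 elements
for br in (32, 64, 128):
    bc = largest_bc(br, d=128, m=M)
    print(f"B_r={br:3d}  B_c={bc:3d}  uses {sram_elements(br, bc, 128):,} of {M:,} elements")
```

Larger query blocks leave less room for key/value blocks. This tension is why real kernels tune $B_r$ and $B_c$ per head dimension and GPU.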

The outer loop iterates over K, V column blocks. The inner loop iterates over Q row blocks. For each (Q-block, K-block) pair, we compute the local attention scores, apply the running softmax correction, and accumulate the weighted V contribution — all in SRAM.

python
# FlashAttention forward pass (simplified, runnable NumPy reference of the v1 loop order)
# Q, K, V "in HBM": each (N, d)
# Block sizes: B_r (query block rows), B_c (key block rows)
import numpy as np

def flash_attention_forward(Q, K, V, B_r=64, B_c=64):
    N, d = Q.shape
    O = np.zeros((N, d))              # output
    m = np.full(N, -np.inf)           # running row max
    l = np.zeros(N)                   # running row sum of exponentials

    for j in range(0, N, B_c):          # Outer loop: iterate over K/V blocks
        K_j = K[j : j+B_c]              # Load K block from HBM to SRAM
        V_j = V[j : j+B_c]              # Load V block from HBM to SRAM

        for i in range(0, N, B_r):      # Inner loop: iterate over Q blocks
            Q_i = Q[i : i+B_r]          # Load Q block from HBM to SRAM
            O_i = O[i : i+B_r]          # Load partial output from HBM
            m_i = m[i : i+B_r]          # Load running max
            l_i = l[i : i+B_r]          # Load running sum

            # Compute local attention (all in SRAM)
            S_ij = Q_i @ K_j.T / np.sqrt(d)          # (B_r, B_c): fits in SRAM

            # Online softmax update
            m_new = np.maximum(m_i, S_ij.max(axis=1))
            P_ij = np.exp(S_ij - m_new[:, None])     # Local softmax numerator
            alpha = np.exp(m_i - m_new)              # Rescales the old statistics
            l_new = alpha * l_i + P_ij.sum(axis=1)

            # Rescale previous output and accumulate
            O_i = ((alpha * l_i / l_new)[:, None] * O_i
                   + (P_ij @ V_j) / l_new[:, None])

            # Write updated state back to HBM
            O[i : i+B_r] = O_i
            m[i : i+B_r] = m_new
            l[i : i+B_r] = l_new
    return O

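The critical property: each K/V block is loaded from HBM exactly once, and each Q block once per outer iteration. The B_r x B_c attention tile S_ij exists only in SRAM and is never written to HBM. The partial output O (along with m and l) is read and written once per outer iteration. Its total traffic is therefore about (N/B_c) · Nd = N^2 d / B_c element accesses. That is still quadratic in N. However, it replaces the repeated write-then-read of full N x N matrices, and it shrinks as the block size B_c grows.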

Online Softmax

Numerically stable softmax needs the row maximum before it can exponentiate, and the row sum before it can normalize. So it makes multiple passes over the data: one for the maximum, one to exponentiate and sum, and one to normalize. The full N-length row must be accessible at once, which normally requires materializing the entire attention row in HBM.

FlashAttention uses the online softmax trick (Milakov & Gimelshein, 2018): maintain a running maximum and running sum, and rescale previous partial results when a new block introduces a larger maximum. The algebra:

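Given blocks $S_1, S_2, \ldots, S_T$ of the attention row, after block t:

$$m^{(t)} = \max\left(m^{(t-1)}, \max(S_t)\right)$$

$$l^{(t)} = e^{m^{(t-1)} - m^{(t)}} \, l^{(t-1)} + \sum e^{S_t - m^{(t)}}$$

$$O^{(t)} = \frac{e^{m^{(t-1)} - m^{(t)}} \, l^{(t-1)}}{l^{(t)}} \, O^{(t-1)} + \frac{1}{l^{(t)}} \, P_t V_t, \qquad P_t = e^{S_t - m^{(t)}}$$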

The rescaling factor exp(m_old - m_new) corrects previous partial sums whenever a new block has a larger maximum. When the maximum doesn't change (which becomes increasingly likely as more blocks are processed), the rescaling factor is exactly 1 and no correction is needed. The final output is mathematically identical to standard softmax attention. In floating point, the two differ only by rounding error, not by any approximation.
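The recurrence can be checked numerically. This minimal NumPy sketch handles a single query row whose scores arrive in blocks (4 by default). It applies exactly the m, l, O updates above and compares the result against softmax computed on the full row. The function name is illustrative.

```python
import numpy as np

def online_softmax_attention(s: np.ndarray, v: np.ndarray, block: int = 4) -> np.ndarray:
    """One query row: s is (N,) scores, v is (N, d) values, processed block by block."""
    m = -np.inf                      # running max m^(t)
    l = 0.0                          # running sum l^(t)
    o = np.zeros(v.shape[1])         # normalized running output O^(t)
    for t in range(0, len(s), block):
        s_t, v_t = s[t:t + block], v[t:t + block]
        m_new = max(m, s_t.max())
        p_t = np.exp(s_t - m_new)                 # P_t = exp(S_t - m^(t))
        scale = np.exp(m - m_new)                 # exp(m^(t-1) - m^(t)); 0 on the first block
        l_new = scale * l + p_t.sum()
        o = (scale * l / l_new) * o + (p_t @ v_t) / l_new
        m, l = m_new, l_new
    return o

rng = np.random.default_rng(0)
s = rng.normal(size=16) * 3
v = rng.normal(size=(16, 8))
weights = np.exp(s - s.max())
reference = (weights / weights.sum()) @ v          # ordinary softmax over the whole row
print(np.allclose(online_softmax_attention(s, v), reference))   # True
```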

💡 Why recomputation beats storage

FlashAttention does not store the attention matrix for the backward pass. Instead, it keeps only the output and the softmax statistics (m and l), and recomputes S (and from it P) from Q and K, block by block, during backpropagation. This costs extra FLOPs, since the Q K^T matmul is computed again for every tile. However, the recomputation happens in SRAM with no HBM writes for the N x N matrix, so it is faster in wall-clock time. This is a textbook example of the compute-memory tradeoff: spending cheap FLOPs to save expensive memory bandwidth.

IO Complexity Analysis

FlashAttention's HBM access pattern is fundamentally different from standard attention:

| Algorithm | HBM reads/writes | Extra memory |
|---|---|---|
| Standard attention | $O(Nd + N^2)$ | $O(N^2)$ |
| FlashAttention | $O(N^2 d^2 / M)$ | $O(N)$ |

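Here, M is the SRAM size. The derivation: each of the $N/B_c$ outer iterations loads all of Q ($Nd$ elements) and one K, V block pair ($2B_c d$ elements). Reading and writing O in each outer iteration adds traffic of the same order as the Q term. Total reads:

$$\text{Reads} = \frac{N}{B_c} \cdot Nd + 2Nd = \frac{N^2 d}{B_c} + 2Nd$$

With $B_c = \Theta(M/d)$, this becomes $O(N^2 d^2 / M)$.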

Ignoring constant factors, FlashAttention's saving over standard attention is roughly a factor of M/d². That is why it helps most when the head dimension is small relative to SRAM. For the A100 (M ~ 100K fp16 elements in practice), the factor is about 6x at d = 128 and about 24x at d = 64. Dao et al. also proved a matching lower bound. No exact attention algorithm can asymptotically beat $O(N^2 d^2 / M)$ HBM accesses for all SRAM sizes M between d and Nd.
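Plugging numbers into the two O-expressions from the table shows where the M/d² factor comes from. This is a rough sketch that drops all constant factors, assumes M = 100K fp16 elements of usable SRAM, and uses an illustrative function name:

```python
def io_ratio(N: int, d: int, M: int) -> float:
    """Asymptotic HBM-access ratio, standard / FlashAttention, constants dropped."""
    standard = N * d + N * N          # O(Nd + N^2)
    flash = N * N * d * d / M         # O(N^2 d^2 / M)
    return standard / flash

M = 100_000                           # assumed usable SRAM, in fp16 elements
for d in (64, 128):
    print(f"d={d}: ratio ~{io_ratio(8192, d, M):.1f}x, M/d^2 = {M / d**2:.1f}")
```

For N much larger than d, the ratio collapses to M/d². Halving the head dimension therefore quadruples the asymptotic saving.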

FlashAttention v2 & v3

FlashAttention v1 proved the concept. Subsequent versions focused on extracting maximum throughput from the hardware through better parallelism and work partitioning.

FlashAttention v2 (Dao, 2023)

Three major improvements over v1:

  • Swapped loop order. v1 iterates over K/V blocks in the outer loop and Q blocks in the inner loop. v2 reverses this: Q blocks in the outer loop, K/V blocks in the inner loop. This eliminates the need to write partial output O back to HBM after each inner iteration — each Q block's output is fully accumulated before being written once.
  • Better work partitioning across warps. v1 splits the K/V block across the warps of a thread block while every warp reads the same Q block (a "split-K" scheme). Each warp therefore produces partial results for the same Q rows, which must be written to shared memory, synchronized, and summed. v2 instead splits the Q block across warps while all warps read the same K/V blocks. Each warp then owns distinct output rows, and the reduction needs no inter-warp communication through shared memory.
  • Sequence-length parallelism. v1 uses batch size and number of heads as its two parallelism dimensions. v2 adds parallelism across sequence-length blocks, which matters for long sequences with small batch sizes — exactly the regime of modern LLM serving.
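The swapped loop order can be sketched in NumPy. Query blocks form the outer loop, and each block's running statistics and accumulator stay in local variables, standing in for on-chip memory, until its output is written exactly once.

This sketch makes a few choices that go beyond the text above:

- It keeps the accumulator unnormalized and divides by l once at the end, rather than renormalizing on every step.
- It counts output-block writes to make the difference from v1 visible. Under the v1 order, the same input would need T_r x T_c writes.
- The function name is hypothetical, and it is not the real CUDA kernel.

```python
import numpy as np

def flash_forward_q_outer(Q, K, V, B_r=32, B_c=32):
    """FlashAttention-2-style loop order: Q blocks outer, K/V blocks inner (NumPy sketch)."""
    N, d = Q.shape
    O = np.empty((N, d))
    hbm_output_writes = 0
    for i in range(0, N, B_r):                     # outer: one Q block at a time
        Q_i = Q[i:i + B_r]
        m_i = np.full(Q_i.shape[0], -np.inf)       # running stats stay "on chip"
        l_i = np.zeros(Q_i.shape[0])
        acc = np.zeros((Q_i.shape[0], d))          # unnormalized output accumulator
        for j in range(0, N, B_c):                 # inner: stream K/V blocks
            S_ij = Q_i @ K[j:j + B_c].T / np.sqrt(d)
            m_new = np.maximum(m_i, S_ij.max(axis=1))
            P_ij = np.exp(S_ij - m_new[:, None])
            alpha = np.exp(m_i - m_new)            # rescale old stats to the new max
            l_i = alpha * l_i + P_ij.sum(axis=1)
            acc = alpha[:, None] * acc + P_ij @ V[j:j + B_c]
            m_i = m_new
        O[i:i + B_r] = acc / l_i[:, None]          # single write of the finished block
        hbm_output_writes += 1
    return O, hbm_output_writes

rng = np.random.default_rng(0)
N, d = 128, 16
Q, K, V = (rng.normal(size=(N, d)) for _ in range(3))
S = Q @ K.T / np.sqrt(d)
P = np.exp(S - S.max(axis=1, keepdims=True))
reference = (P / P.sum(axis=1, keepdims=True)) @ V
O, writes = flash_forward_q_outer(Q, K, V)
print(np.allclose(O, reference), writes)   # True 4
```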

Result: FlashAttention-2 achieves 50-73% of theoretical maximum FLOPS throughput on A100, up from 25-40% in v1. This translates to roughly a 2x speedup over v1.

FlashAttention v3 (Shah et al., 2024)

FlashAttention-3 targets Hopper GPUs (H100) and exploits three Hopper-specific techniques:

  • Asynchronous execution via warp-specialization. Different warps are specialized for different roles. Producer warps issue asynchronous loads from global memory through Hopper's Tensor Memory Accelerator (TMA), while consumer warps perform tensor core math. The two overlap, hiding memory latency behind computation.
  • Interleaving matmul and softmax. The softmax of one block is scheduled to overlap with the asynchronous tensor core matmuls of another, so the comparatively low-throughput exponentials do not stall the tensor cores.
  • FP8 with block quantization and incoherent processing. H100 tensor cores run FP8 at 2x the throughput of FP16. The FP8 path runs both matmuls (QK^T and PV) in FP8, computes softmax and accumulates in higher precision, and quantizes per block (one scale per tile) rather than per tensor. Per-block scaling limits the dynamic range each tile must represent. "Incoherent processing" multiplies Q and K by a random orthogonal transform before quantization (in practice a Hadamard transform with random signs), which spreads out outlier values and reduces quantization error.
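The reason incoherent processing helps can be shown with a small simulation. Rotating Q and K by the same orthogonal matrix R leaves Q K^T unchanged, because R R^T = I. At the same time, the rotation spreads a large outlier channel across all dimensions, which shrinks the quantization scale.

This is a sketch under stated assumptions:

- A symmetric 8-bit uniform quantizer with one scale per tensor stands in for FP8. It is not an FP8 format.
- The rotation is a random orthogonal matrix from a QR decomposition, not the randomized Hadamard transform.
- The outlier channel is synthetic.

```python
import numpy as np

def quantize_per_tensor(x: np.ndarray) -> np.ndarray:
    """Symmetric 8-bit uniform quantize/dequantize with a single scale (stand-in for FP8)."""
    scale = np.abs(x).max() / 127.0
    return np.round(x / scale) * scale

rng = np.random.default_rng(0)
n, d = 64, 128
Q = rng.normal(size=(n, d))
K = rng.normal(size=(n, d))
Q[:, 5] = 50.0 * rng.uniform(0.5, 1.0, size=n)   # synthetic outlier channel
K[:, 5] = 50.0 * rng.uniform(0.5, 1.0, size=n)

rot, _ = np.linalg.qr(rng.normal(size=(d, d)))    # random orthogonal matrix
Qr, Kr = Q @ rot, K @ rot
exact = Q @ K.T

err_plain = np.abs(quantize_per_tensor(Q) @ quantize_per_tensor(K).T - exact).mean()
err_rot = np.abs(quantize_per_tensor(Qr) @ quantize_per_tensor(Kr).T - exact).mean()
print(f"QK^T preserved by rotation: {np.allclose(Qr @ Kr.T, exact)}")
print(f"max |Q|: {np.abs(Q).max():.1f} -> {np.abs(Qr).max():.1f}")
print(f"mean |error| in QK^T: plain {err_plain:.3f}, rotated {err_rot:.3f}")
```

The rotation lowers the largest magnitude in each tensor, so the same 8 bits cover a narrower range. The error in Q K^T drops as a result.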

FlashAttention-3 reaches 740 TFLOPS on H100 (75% of peak), which is 1.5-2x faster than FlashAttention-2 on the same hardware.

ℹ Adoption
FlashAttention kernels now sit behind PyTorch's torch.nn.functional.scaled_dot_product_attention, which selects the FlashAttention backend automatically when the inputs and hardware are eligible. They are also available in HuggingFace Transformers via attn_implementation="flash_attention_2", and are used by virtually every major training and inference framework. If you're running a transformer model on a recent NVIDIA GPU today, you are very likely using FlashAttention whether you know it or not.

PagedAttention

FlashAttention optimizes the computation of attention. PagedAttention optimizes the storage of attention's most expensive artifact: the KV cache.

During autoregressive generation, every token's key and value vectors must be stored so that future tokens can attend to them. For a model with L layers and H heads, each with head dimension d, storing the KV cache for a single sequence of length N requires:

KV cache size = 2 · L · H · N · d · sizeof(dtype)

LLaMA-2 70B (L=80, H=8 KV heads with GQA, d=128, fp16):
= 2 · 80 · 8 · 128 · 2 bytes = 327,680 bytes = 320 KiB per token

For N = 4096: 1.25 GiB per sequence
For N = 128K (131,072): 40 GiB per sequence
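The formula translates directly into a small calculator. This sketch reuses the LLaMA-2 70B shape from above (80 layers, 8 KV heads under GQA, d = 128, fp16), and the function name is illustrative:

```python
def kv_cache_bytes(n_tokens, n_layers, n_kv_heads, head_dim, bytes_per=2):
    """2 (K and V) * L * H * N * d * sizeof(dtype)."""
    return 2 * n_layers * n_kv_heads * n_tokens * head_dim * bytes_per

llama2_70b = dict(n_layers=80, n_kv_heads=8, head_dim=128)   # GQA: 8 KV heads
per_token = kv_cache_bytes(1, **llama2_70b)
print(per_token / 1024, "KiB per token")                      # 320.0 KiB per token
for n in (4096, 128 * 1024):
    print(f"N={n:>6}: {kv_cache_bytes(n, **llama2_70b) / 2**30:.2f} GiB")
```

Without GQA (64 KV heads instead of 8), the same model would need 8x as much cache per token.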

In a serving system handling dozens of concurrent requests with different sequence lengths, the KV cache dominates GPU memory. And here's the problem: standard implementations pre-allocate contiguous memory for the maximum possible sequence length of each request. A request that might generate up to 2048 tokens allocates 2048 tokens worth of KV cache at the start, even if it ultimately only generates 50 tokens. This leads to massive fragmentation.

The Virtual Memory Analogy

PagedAttention's core insight is that the KV cache fragmentation problem is structurally identical to the memory fragmentation problem that operating systems solved decades ago with virtual memory.

In an OS, each process sees a contiguous virtual address space, but the underlying physical memory is divided into fixed-size pages (typically 4 KB). A page table maps virtual pages to physical frames. Pages can be scattered anywhere in physical memory — the process doesn't know or care. This eliminates external fragmentation entirely.

PagedAttention applies the same abstraction: each sequence's KV cache has a contiguous logical address space (token 0 through token N), but the underlying GPU memory is divided into fixed-size KV blocks. A block table maps each sequence's logical blocks to physical blocks in GPU memory. Blocks can be non-contiguous.
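A toy allocator makes the block table concrete. This is a hypothetical sketch, not vLLM's code, and the class and method names are invented for illustration:

- Each sequence has a list of physical block ids.
- A new physical block is taken from a free list only when the sequence's last block is full.
- A logical token position translates to a (physical block, offset) pair.

```python
class BlockTable:
    """Toy PagedAttention-style mapping from logical KV blocks to physical blocks (hypothetical)."""

    def __init__(self, num_physical_blocks: int, block_size: int = 16):
        self.block_size = block_size
        self.free = list(range(num_physical_blocks))[::-1]   # pop() hands out 0, 1, 2, ...
        self.tables = {}                                     # seq_id -> list of physical ids
        self.lengths = {}                                    # seq_id -> tokens stored

    def append_token(self, seq_id: int) -> None:
        table = self.tables.setdefault(seq_id, [])
        n = self.lengths.get(seq_id, 0)
        if n % self.block_size == 0:                         # last block full, or no block yet
            if not self.free:
                raise MemoryError("out of KV blocks")
            table.append(self.free.pop())                    # allocate on demand
        self.lengths[seq_id] = n + 1

    def translate(self, seq_id: int, token_idx: int) -> tuple:
        """Logical token position -> (physical block id, offset within block)."""
        block, offset = divmod(token_idx, self.block_size)
        return self.tables[seq_id][block], offset

    def free_sequence(self, seq_id: int) -> None:
        self.free.extend(reversed(self.tables.pop(seq_id)))  # return blocks to the pool
        del self.lengths[seq_id]

bt = BlockTable(num_physical_blocks=8, block_size=4)
for _ in range(4):
    bt.append_token(seq_id=0)       # fills physical block 0
for _ in range(3):
    bt.append_token(seq_id=1)       # physical block 1
for _ in range(2):
    bt.append_token(seq_id=0)       # seq 0 needs a new block: physical block 2
print(bt.tables)                    # {0: [0, 2], 1: [1]}
print(bt.translate(0, 5))           # (2, 1)
```

Sequence 0 is logically contiguous (tokens 0 through 5) but physically split across blocks 0 and 2, exactly like a process's pages.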

KV Cache Blocks

A KV block stores the key and value vectors for a fixed number of tokens (the block size, typically 16 tokens). For each attention layer and each KV head, the block contains:

Block content = K[block_start : block_start + block_size] and V[block_start : block_start + block_size]

Block memory = 2 · block_size · d · sizeof(dtype)
= 2 · 16 · 128 · 2 = 8 KB per layer per KV head

Blocks are allocated on demand as new tokens are generated. When a sequence finishes, its blocks are returned to a free pool. Internal fragmentation is limited to the last block of each sequence — at most block_size - 1 wasted slots, compared to potentially thousands of wasted slots with contiguous pre-allocation.
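The difference in reserved-but-unused slots can be estimated for a batch of requests. This sketch makes several assumptions:

- The generated lengths are hypothetical.
- Contiguous allocation reserves 2048 slots per request.
- Only reservation and internal fragmentation are modeled; external fragmentation between differently sized reservations is ignored.

```python
import math

def kv_slot_waste(actual_lengths, reserved_len, block_size=16):
    """Fraction of reserved KV slots left unused: contiguous pre-allocation vs. paged blocks."""
    used = sum(actual_lengths)
    contiguous = reserved_len * len(actual_lengths)
    paged = sum(math.ceil(n / block_size) * block_size for n in actual_lengths)
    return 1 - used / contiguous, 1 - used / paged

lengths = [50, 300, 812, 1500, 97]          # hypothetical generated lengths
contig_waste, paged_waste = kv_slot_waste(lengths, reserved_len=2048)
print(f"contiguous: {contig_waste:.0%} wasted, paged: {paged_waste:.1%} wasted")
```

With paging, each sequence wastes at most block_size - 1 slots, no matter how far its length falls short of the maximum.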

The attention kernel is modified to work with non-contiguous memory. Instead of reading K, V from a single contiguous buffer, it looks up each block's physical address in the block table and gathers the keys and values from scattered locations. This indirection has a cost: the vLLM paper reports 20-26% higher attention-kernel latency than FasterTransformer's contiguous kernel. However, it affects only the attention operator and is far outweighed by the memory savings.
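Gathering through the block table gives exactly the same result as attending over a contiguous buffer. This NumPy sketch shows that for one query, one layer, and one head.

A real kernel reads each block in place inside the GPU kernel. This sketch instead materializes the gathered copy for clarity. The pool sizes, block table, and function name are all hypothetical.

```python
import numpy as np

def paged_attention(q, k_pool, v_pool, block_table, seq_len):
    """Single-query attention over K/V stored in scattered physical blocks."""
    dim = q.shape[0]
    k = k_pool[block_table].reshape(-1, dim)[:seq_len]   # gather blocks, drop empty slots
    v = v_pool[block_table].reshape(-1, dim)[:seq_len]
    s = k @ q / np.sqrt(dim)
    p = np.exp(s - s.max())
    return (p / p.sum()) @ v

rng = np.random.default_rng(0)
block_size, d, num_blocks = 4, 8, 10
k_pool = rng.normal(size=(num_blocks, block_size, d))    # physical KV pool
v_pool = rng.normal(size=(num_blocks, block_size, d))
block_table = [7, 2, 5]      # logical block i lives in physical block block_table[i]
seq_len = 10                 # 2.5 blocks: the last block is only half full
q = rng.normal(size=d)

# Reference: copy the same tokens into one contiguous buffer and attend normally
k_contig = np.concatenate([k_pool[b] for b in block_table])[:seq_len]
v_contig = np.concatenate([v_pool[b] for b in block_table])[:seq_len]
s = k_contig @ q / np.sqrt(d)
w = np.exp(s - s.max())
w /= w.sum()
print(np.allclose(paged_attention(q, k_pool, v_pool, block_table, seq_len), w @ v_contig))
```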

Memory Sharing & Copy-on-Write

The block-level indirection enables powerful memory sharing. Consider beam search with beam width B: all B beams share the same prompt prefix. With contiguous allocation, each beam requires a full copy of the prefix's KV cache. With PagedAttention, all beams simply share the same physical blocks for the prefix — their block tables point to the same physical addresses.

When beams diverge (a beam generates a different token than another), PagedAttention uses copy-on-write: only the block containing the diverging token is copied. All other shared blocks remain shared. This reduces beam search memory from O(B x N) to O(N + B x diverged_tokens).
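Copy-on-write can be modeled with a reference count per physical block:

- Forking a sequence shares all of its blocks and bumps their counts.
- Writing into a block whose count is above 1 first moves the writer to a fresh block.

This is a hypothetical sketch with invented names. A real system also copies the block's key/value contents, which is omitted here.

```python
class CowBlockManager:
    """Toy copy-on-write sharing of KV blocks via reference counts (hypothetical)."""

    def __init__(self, num_blocks: int):
        self.free = list(range(num_blocks))[::-1]   # pop() hands out 0, 1, 2, ...
        self.refcount = [0] * num_blocks
        self.tables = {}

    def _alloc(self) -> int:
        block = self.free.pop()
        self.refcount[block] = 1
        return block

    def new_sequence(self, seq_id, n_blocks):
        self.tables[seq_id] = [self._alloc() for _ in range(n_blocks)]

    def fork(self, parent, child):
        """Child shares every parent block; no KV data is copied."""
        self.tables[child] = list(self.tables[parent])
        for b in self.tables[child]:
            self.refcount[b] += 1

    def write(self, seq_id, logical_block):
        """Before writing into a block, move to a private copy if it is still shared."""
        table = self.tables[seq_id]
        b = table[logical_block]
        if self.refcount[b] > 1:
            self.refcount[b] -= 1
            table[logical_block] = self._alloc()    # (KV contents would be copied here)
        return table[logical_block]

mgr = CowBlockManager(num_blocks=8)
mgr.new_sequence("beam0", n_blocks=3)       # prompt occupies physical blocks 0, 1, 2
mgr.fork("beam0", "beam1")                  # beam1 shares all three blocks
mgr.write("beam1", logical_block=2)         # beams diverge inside the last block
print(mgr.tables)   # {'beam0': [0, 1, 2], 'beam1': [0, 1, 3]}
```

Only the diverging block is duplicated. The shared prefix blocks 0 and 1 keep a reference count of 2.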

The same mechanism enables parallel sampling (generating multiple completions from the same prompt) and prefix caching (reusing KV blocks across requests that share a common system prompt). In practice, the vLLM paper measured KV cache waste under 4% with PagedAttention, compared to 60-80% in systems that use contiguous allocation.

💡 Impact on throughput

By eliminating most fragmentation, PagedAttention allows serving systems to pack 2-4x more concurrent sequences into the same GPU memory. LLM serving is typically memory-capacity-bound rather than compute-bound, so this translates fairly directly into higher throughput. At launch, vLLM (the reference implementation of PagedAttention) reported up to 24x higher throughput than HuggingFace Transformers and up to 3.5x higher than HuggingFace Text Generation Inference. The accompanying paper reports 2-4x higher throughput than FasterTransformer and Orca at the same latency.

Ring Attention

FlashAttention fits the computation into a single GPU's SRAM. But what about sequences so long that even the Q, K, V matrices don't fit in a single GPU's HBM? Ring Attention (Liu et al., 2023) extends the FlashAttention tiling pattern across multiple devices.

The setup: distribute the sequence across D devices, where device i holds tokens [iN/D, (i+1)N/D). Each device has its local Q, K, V blocks. To compute full attention, every Q block must interact with every K, V block — including those on other devices.

Ring Attention arranges devices in a logical ring. Each device holds its local K, V blocks and passes them to the next device in the ring while simultaneously computing attention against the K, V blocks it currently holds. After D-1 communication steps, every device has seen every K, V block.

Round r (on device i), for r = 0, 1, ..., D-1:
1. Compute: FlashAttention(Q_local, K and V originating from device (i-r) mod D)
2. Send: the K, V blocks currently held to device (i+1) mod D
3. Receive: K, V blocks from device (i-1) mod D

After D rounds, each device has computed full attention for its Q rows.
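The ring schedule can be simulated on one machine. In this sketch, each "device" is a slot in a Python list. Device i starts with its own K/V chunk, merges each visiting block into its running (m, l, accumulator) state, and then the K/V chunks rotate one position around the ring.

The sketch rests on these assumptions:

- It is non-causal.
- There is no real communication or compute/transfer overlap.
- D must divide N.
- The function name is illustrative.

```python
import numpy as np

def ring_attention(Q, K, V, D):
    """Simulate Ring Attention: D 'devices' each own N/D rows of Q, K, V (non-causal)."""
    N, d = Q.shape
    c = N // D                                          # tokens per device
    q = [Q[i*c:(i+1)*c] for i in range(D)]
    kv = [(K[i*c:(i+1)*c], V[i*c:(i+1)*c]) for i in range(D)]   # block held by device i
    m = [np.full(c, -np.inf) for _ in range(D)]
    l = [np.zeros(c) for _ in range(D)]
    acc = [np.zeros((c, d)) for _ in range(D)]
    for r in range(D):
        for i in range(D):                              # all devices compute in parallel
            k_blk, v_blk = kv[i]                        # originated on device (i - r) mod D
            s = q[i] @ k_blk.T / np.sqrt(d)
            m_new = np.maximum(m[i], s.max(axis=1))
            p = np.exp(s - m_new[:, None])
            alpha = np.exp(m[i] - m_new)
            l[i] = alpha * l[i] + p.sum(axis=1)
            acc[i] = alpha[:, None] * acc[i] + p @ v_blk
            m[i] = m_new
        kv = [kv[(i - 1) % D] for i in range(D)]        # send to i+1, receive from i-1
    return np.concatenate([acc[i] / l[i][:, None] for i in range(D)])

rng = np.random.default_rng(0)
N, d = 64, 16
Q, K, V = (rng.normal(size=(N, d)) for _ in range(3))
S = Q @ K.T / np.sqrt(d)
P = np.exp(S - S.max(axis=1, keepdims=True))
full = (P / P.sum(axis=1, keepdims=True)) @ V
print(np.allclose(ring_attention(Q, K, V, D=4), full))   # True
```

The rotation after the last round is redundant: only D-1 transfers are needed, matching the description above.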

The key insight: FlashAttention's online softmax lets each round's computation merge into the running output (m, l, O) without any extra correction pass. Communication (sending K, V to the next device) is overlapped with computation (attention against the current K, V). As long as computing on a block takes at least as long as transferring it, the communication cost is effectively hidden.

Ring Attention lets the maximum context length scale roughly linearly with the number of devices. With D = 8 GPUs, each with 80 GB of HBM, a model can handle sequences about 8x longer than a single GPU permits. Sequence-parallel schemes of this kind are one route to million-token contexts such as Gemini 1.5's, although model providers have not fully disclosed how their long-context systems are implemented.

Benchmarks

The practical impact of these techniques is substantial. Here are representative numbers across sequence lengths:

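| Sequence length | Standard attention memory | FlashAttention memory | Wall-clock speedup |
|---|---|---|---|
| 1K | 4 MB | 0.5 MB | 1.3x |
| 4K | 64 MB | 2 MB | 2.1x |
| 16K | 1 GB | 8 MB | 3.5x |
| 64K | 16 GB | 32 MB | 4.2x |
| 128K | 64 GB (OOM) | 64 MB | n/a (only FlashAttention fits) |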

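Memory values are for a single head in a single layer and correspond to 4-byte (fp32) values with d = 128. For standard attention, they cover the N x N score matrix. For FlashAttention, they cover the N x d output plus the O(N) running statistics m and l. FlashAttention's memory therefore scales linearly with N, while standard attention scales quadratically.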

For PagedAttention, the throughput gains come from memory capacity, not speed:

| System | Concurrent sequences (LLaMA-13B, A100 80GB) | Throughput (tokens/s) |
|---|---|---|
| Naive (contiguous) | ~8 | ~1,200 |
| HuggingFace TGI | ~16 | ~3,400 |
| vLLM (PagedAttention) | ~40 | ~8,500 |

Code Examples

Using FlashAttention in practice is straightforward — it's integrated into PyTorch's scaled_dot_product_attention and can be enabled explicitly in HuggingFace models.

PyTorch — FlashAttention via SDPA

python
import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel  # newer PyTorch; older releases used torch.backends.cuda.sdp_kernel

# PyTorch >= 2.0 picks FlashAttention automatically when inputs are eligible
# (CUDA tensors, fp16/bf16, supported head dim); FlashAttention-2 kernels since 2.2
batch, heads, seq_len, d_head = 4, 32, 8192, 128

Q = torch.randn(batch, heads, seq_len, d_head, device='cuda', dtype=torch.float16)
K = torch.randn(batch, heads, seq_len, d_head, device='cuda', dtype=torch.float16)
V = torch.randn(batch, heads, seq_len, d_head, device='cuda', dtype=torch.float16)

# Restrict SDPA to the FlashAttention backend only: it raises an error instead
# of silently falling back if FlashAttention cannot handle these inputs
with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
    output = F.scaled_dot_product_attention(Q, K, V)
    # output shape: (batch, heads, seq_len, d_head)

# Reports whether the flash backend is enabled globally, not which kernel ran
print(torch.backends.cuda.flash_sdp_enabled())  # True (enabled by default)

# With causal mask (for autoregressive models)
output_causal = F.scaled_dot_product_attention(Q, K, V, is_causal=True)

HuggingFace — Enabling FlashAttention-2

python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Load model with FlashAttention-2 backend
# (requires the flash-attn package, a supported NVIDIA GPU, and fp16/bf16 weights)
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",
    torch_dtype=torch.float16,
    attn_implementation="flash_attention_2",  # Explicit FA2
    device_map="auto",
)

# FA2 is opt-in: without attn_implementation, recent Transformers versions
# default to PyTorch SDPA ("sdpa") when available. Check the active backend:
print(model.config._attn_implementation)  # "flash_attention_2"

# Long-context inference without materializing the N x N attention matrix
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
long_input = tokenizer(
    "..." * 4096, return_tensors="pt", truncation=True, max_length=4096  # Llama-2 context limit
).to(model.device)
with torch.no_grad():
    output = model(**long_input)  # attention memory grows O(N) rather than O(N^2)

vLLM — PagedAttention in Practice

python
from vllm import LLM, SamplingParams

# vLLM uses PagedAttention by default — no configuration needed
llm = LLM(
    model="meta-llama/Llama-2-13b-hf",
    tensor_parallel_size=1,        # Number of GPUs
    gpu_memory_utilization=0.90,   # Use 90% of GPU memory
    max_model_len=4096,            # Maximum sequence length
    block_size=16,                 # KV cache block size (tokens per block)
    swap_space=4,                  # GiB of CPU swap space for preempted seqs
)

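# Parallel sampling: the n completions of each prompt share the prompt's KV
# blocks; a shared block is copied only when a completion writes into it
# (copy-on-write). Beam search relies on the same mechanism, but its API
# differs across vLLM versions.
sampling_params = SamplingParams(
    n=4,              # 4 completions per prompt
    temperature=0.8,
    max_tokens=256,
)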

prompts = [
    "Explain the difference between FlashAttention and standard attention:",
    "What is the purpose of the KV cache in transformer inference?",
]

outputs = llm.generate(prompts, sampling_params)
for output in outputs:
    print(f"Prompt: {output.prompt[:60]}...")
    print(f"Output: {output.outputs[0].text[:200]}")
    print(f"Tokens generated: {len(output.outputs[0].token_ids)}")
    print()

FlashAttention Direct — Using the flash-attn Package

python
import torch
from flash_attn import flash_attn_func, flash_attn_varlen_func

# Basic FlashAttention call
# Note: flash_attn expects (batch, seqlen, nheads, headdim) layout
batch, seq_len, n_heads, d_head = 4, 8192, 32, 128

q = torch.randn(batch, seq_len, n_heads, d_head, device='cuda', dtype=torch.float16)
k = torch.randn(batch, seq_len, n_heads, d_head, device='cuda', dtype=torch.float16)
v = torch.randn(batch, seq_len, n_heads, d_head, device='cuda', dtype=torch.float16)

# Forward pass — causal mask for autoregressive models
output = flash_attn_func(q, k, v, causal=True)
# output: (batch, seq_len, n_heads, d_head)

# Variable-length sequences (packed into single batch dimension)
# Useful for serving batches with different sequence lengths
cu_seqlens_q = torch.tensor([0, 1024, 3072, 8192], dtype=torch.int32, device='cuda')
cu_seqlens_k = torch.tensor([0, 1024, 3072, 8192], dtype=torch.int32, device='cuda')
max_seqlen_q = 5120
max_seqlen_k = 5120

q_packed = torch.randn(8192, n_heads, d_head, device='cuda', dtype=torch.float16)
k_packed = torch.randn(8192, n_heads, d_head, device='cuda', dtype=torch.float16)
v_packed = torch.randn(8192, n_heads, d_head, device='cuda', dtype=torch.float16)

output_varlen = flash_attn_varlen_func(
    q_packed, k_packed, v_packed,
    cu_seqlens_q, cu_seqlens_k,
    max_seqlen_q, max_seqlen_k,
    causal=True,
)

🧭 What comes next

FlashAttention and PagedAttention attack the memory wall from the compute and storage sides respectively. The next frontier combines these with speculative decoding (Article 09) to break the latency wall — using draft models to generate candidate tokens in parallel, then verifying them in a single forward pass. Together, these three techniques form the systems backbone of modern LLM serving.

References

Seminal papers and key works referenced in this article.

  1. Dao et al. "FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness." NeurIPS, 2022. arXiv
  2. Dao. "FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning." ICLR, 2024. arXiv
  3. Kwon et al. "Efficient Memory Management for Large Language Model Serving with PagedAttention." SOSP, 2023. arXiv
  4. Liu et al. "Ring Attention with Blockwise Transformers for Near-Infinite Context." ICLR, 2024. arXiv