The N×N attention matrix never touches slow memory. Tiled computation, online softmax, and the IO-aware algorithm that changed how we run transformers.
You want a language model that can read an entire textbook. Not a paragraph — an entire 300-page mathematics textbook, all at once. Or you want it to process a full second of raw audio (thousands of timesteps), or analyze long-range interactions in the human genome across 100,000+ nucleotide pairs.
The obstacle is attention. The core operation in a transformer computes pairwise similarities between every token and every other token. If your sequence has N tokens, the attention matrix is N × N. Double the sequence length and you quadruple the memory and compute.
At N = 1,024 tokens, that matrix has about 1 million entries. Manageable. At N = 16,384, it has 268 million entries. At N = 131,072 (a modest book), it has 17 billion entries. The naive approach collapses.
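To check these numbers, here is a minimal sketch that computes the entry count and memory footprint of one attention matrix. It assumes a single head and 2 bytes per entry (BF16). A real model multiplies this by the number of heads and the batch size.

```python
def attention_matrix_size(n_tokens: int, bytes_per_entry: int = 2) -> tuple[int, int]:
    """Return (entries, bytes) for one N x N attention matrix (a single head)."""
    entries = n_tokens * n_tokens
    return entries, entries * bytes_per_entry


for n in (1_024, 16_384, 131_072):
    entries, nbytes = attention_matrix_size(n)
    print(f"N = {n:>7,}: {entries:>14,} entries, {nbytes / 2**20:>9,.1f} MiB in BF16")
```

At N = 131,072, one head's matrix is already 32 GiB in BF16.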
[Interactive figure: memory versus sequence length N. The N×N attention matrix (red) grows quadratically; Q, K, V storage (teal) grows linearly.]
This chapter's question: is the quadratic cost fundamental? Do we really need every token to talk to every other token? And even if we do, must we store that enormous matrix in memory?
We will keep the quadratic compute, but the answer to the last question, as we'll discover, is no. FlashAttention computes exact attention (identical results to the naive algorithm) without ever materializing the full N×N matrix. But to understand how, we first need to survey the landscape of approaches people tried before it.
Before FlashAttention, researchers attacked the quadratic bottleneck by changing the attention algorithm itself — computing something cheaper that approximates attention. These approaches fall into three families: sparse, low-rank, and kernel.
Instead of letting every token attend to every other token, sparse attention restricts which pairs of tokens interact. Think of it as poking holes in the N×N matrix — only computing certain entries.
Sparse Transformers (Child et al., 2019) use factorized patterns. Each token attends to a local window of nearby tokens plus a strided subset of earlier tokens, with both sized around √N. Complexity drops from O(N²) to O(N√N). BigBird (Zaheer et al., 2020) combines three patterns: local windows (nearby tokens), global tokens (a few special tokens attend to everything), and random connections (to reduce graph diameter). Complexity: O(N).
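To see why such patterns scale linearly, the sketch below builds a simplified BigBird-style mask and counts the computed entries. It keeps only the local-window and global-token components and omits the random links. The window size and number of global tokens are illustrative assumptions.

```python
def bigbird_style_mask(n: int, window: int, n_global: int) -> list[list[bool]]:
    """N x N boolean mask: True where query i attends to key j.

    Simplified BigBird-style pattern (random connections omitted):
      - local: |i - j| <= window
      - global: the first n_global tokens attend to, and are attended by, every token
    """
    return [
        [abs(i - j) <= window or i < n_global or j < n_global for j in range(n)]
        for i in range(n)
    ]


def computed_entries(mask: list[list[bool]]) -> int:
    """Number of (query, key) pairs the pattern actually computes."""
    return sum(sum(row) for row in mask)


for n in (128, 256, 512):
    count = computed_entries(bigbird_style_mask(n, window=3, n_global=2))
    print(f"N = {n}: {count} of {n * n} entries computed ({count / (n * n):.1%})")
```

When N doubles, the computed count roughly doubles, while N² quadruples.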
Low-rank methods observe that the N×N attention matrix often has low effective rank, so it can be approximated by projecting K and V to a much smaller dimension k ≪ N. Linformer (Wang et al., 2020) uses learned projection matrices to project the N×d key and value matrices down to k×d. This shrinks the attention matrix from N×N to N×k.
The downside: hard to maintain causality (the model shouldn't peek at future tokens), and quality degrades when the true rank is higher than k.
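The shape bookkeeping is the whole trick, so here is a toy sketch in plain Python. The random matrices E and F stand in for Linformer's learned projections, and nothing is trained. The sizes N = 64, d = 8, k = 4 are arbitrary.

```python
import math
import random


def matmul(a, b):
    """Naive matrix product for lists of lists."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def transpose(a):
    return [list(col) for col in zip(*a)]


def softmax_rows(m):
    """Numerically stable softmax applied to each row."""
    out = []
    for row in m:
        mx = max(row)
        exps = [math.exp(x - mx) for x in row]
        total = sum(exps)
        out.append([e / total for e in exps])
    return out


def shape(a):
    return len(a), len(a[0])


rng = random.Random(0)
N, d, k = 64, 8, 4
Q, K, V = ([[rng.gauss(0.0, 1.0) for _ in range(d)] for _ in range(N)] for _ in range(3))
E = [[rng.gauss(0.0, 1.0) / math.sqrt(N) for _ in range(N)] for _ in range(k)]  # stand-in key projection
F = [[rng.gauss(0.0, 1.0) / math.sqrt(N) for _ in range(N)] for _ in range(k)]  # stand-in value projection

K_proj, V_proj = matmul(E, K), matmul(F, V)       # each k x d
scores = matmul(Q, transpose(K_proj))             # N x k, not N x N
out = matmul(softmax_rows(scores), V_proj)        # N x d
print(shape(K_proj), shape(scores), shape(out))   # (4, 8) (64, 4) (64, 8)
```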
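Kernel methods replace the $\exp(QK^\top)$ similarity with a decomposable feature map: $\text{sim}(q_i, k_j) \approx \phi(q_i)^\top \phi(k_j)$. The Linear Transformer (Katharopoulos et al., 2020) uses this trick to avoid computing the full N×N matrix. Without the softmax, you can exploit the associative property of matrix multiplication:

$$\big(\phi(Q)\,\phi(K)^\top\big)\,V = \phi(Q)\,\big(\phi(K)^\top V\big)$$

The left side builds an N×N matrix at O(N²d) cost. The right side builds only a d×d matrix, at O(Nd²) cost.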
When d ≪ N (which it usually is), this is a win. But the approximation quality suffers, especially on language tasks where the sharp attention patterns that softmax produces matter.
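Here is a minimal sketch of the associativity trick. It assumes the feature map φ(x) = elu(x) + 1 from the Linear Transformer paper. Both functions compute the same normalized kernel attention and differ only in the order of multiplication, so the first never builds an N×N matrix.

```python
import math
import random


def matmul(a, b):
    """Naive matrix product for lists of lists."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def transpose(a):
    return [list(col) for col in zip(*a)]


def phi(m):
    """Feature map elu(x) + 1, applied elementwise; always positive."""
    return [[x + 1.0 if x > 0 else math.exp(x) for x in row] for row in m]


def linear_attention(Q, K, V):
    """phi(Q) (phi(K)^T V): only a d x d_v summary is built, never an N x N matrix."""
    pq, pk = phi(Q), phi(K)
    kv = matmul(transpose(pk), V)            # d x d_v, cost O(N d d_v)
    k_sum = [sum(col) for col in zip(*pk)]   # sum over j of phi(k_j), length d
    out = []
    for q_row, num_row in zip(pq, matmul(pq, kv)):
        denom = sum(q * s for q, s in zip(q_row, k_sum))
        out.append([x / denom for x in num_row])
    return out


def linear_attention_quadratic(Q, K, V):
    """(phi(Q) phi(K)^T) V: the same quantity via an explicit N x N similarity matrix."""
    sim = matmul(phi(Q), transpose(phi(K)))  # N x N, cost O(N^2 d)
    out = []
    for sim_row, num_row in zip(sim, matmul(sim, V)):
        denom = sum(sim_row)
        out.append([x / denom for x in num_row])
    return out


rng = random.Random(1)
Q, K, V = ([[rng.uniform(-1, 1) for _ in range(4)] for _ in range(16)] for _ in range(3))
fast, slow = linear_attention(Q, K, V), linear_attention_quadratic(Q, K, V)
print(max(abs(a - b) for ra, rb in zip(fast, slow) for a, b in zip(ra, rb)))  # rounding-level
```

The two functions agree with each other, not with softmax attention. That gap is exactly the approximation this family accepts.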
[Interactive figure: attention patterns. Orange cells are computed and dark cells are skipped; full attention fills the whole grid.]
| Method | Type | Complexity | Exact? | Tradeoff |
|---|---|---|---|---|
| Standard | — | O(N²d) | Yes | Slow for long N |
| Sparse Transformer | Sparse | O(N√N) | No | Misses global patterns |
| BigBird | Sparse | O(N) | No | Fixed pattern, not learned |
| Linformer | Low-rank | O(Nk) | No | Rank assumption, no causality |
| Linear Transformer | Kernel | O(Nd²) | No | No softmax → weaker attention |
| FlashAttention | IO-aware | O(N²d) | Yes | Same FLOPs, fewer memory accesses |
Notice that FlashAttention doesn't reduce the FLOP count. It has the same O(N²d) compute. Its insight is entirely different: the problem isn't that we're doing too many operations — it's that we're moving data too slowly. The bottleneck is memory bandwidth, not compute.
Here's the surprise. Standard attention on a modern GPU like the A100 isn't slow because of arithmetic. The A100 can perform 312 TFLOPS (trillion floating-point operations per second) of BF16 tensor-core math. Take a sequence of N = 1024 with head dimension d = 64. Each of the two matrix multiplications ($QK^\top$ and $PV$) takes about $2N^2d \approx 134$ million FLOPs, or roughly 268 million per head. At 312 TFLOPS, that is under a microsecond of arithmetic. So why does attention run so much slower than its FLOP count suggests?
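The answer: memory traffic. Let's trace what happens step by step in a standard PyTorch attention implementation, where each operation is a separate GPU kernel:

1. $S = QK^\top$: read Q and K from HBM, write S (N×N) to HBM.
2. Mask: read S, apply the mask, write S back.
3. Softmax: read S, compute $P = \text{softmax}(S)$, write P.
4. Dropout: read P, apply dropout, write P back.
5. $O = PV$: read P and V, write O.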
Count the memory trips. The N×N matrix S gets written to HBM after Step 1, read back for Step 2, written again, read back for Step 3, and so on. Each of those lightweight element-wise operations (masking, softmax, dropout) is sandwiched between expensive HBM reads and writes. The compute in each step is trivial; the data movement is not.
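To make the count concrete, here is a simplified traffic model of the five unfused kernels (matmul, mask, softmax, dropout, matmul). It assumes each kernel reads its inputs from HBM and writes its output back exactly once, for one head in BF16. It ignores caching and the storage of the dropout mask.

```python
def unfused_attention_traffic(n: int, d: int, bytes_per_elem: int = 2) -> dict[str, int]:
    """HBM bytes moved per step by an unfused attention forward pass (one head).

    Simplified model: every kernel reads its inputs and writes its output once.
    """
    nn, nd = n * n, n * d
    elems = {
        "1. S = Q @ K.T": 2 * nd + nn,    # read Q, K; write S
        "2. mask(S)": 2 * nn,             # read S; write masked S
        "3. P = softmax(S)": 2 * nn,      # read S; write P
        "4. dropout(P)": 2 * nn,          # read P; write P
        "5. O = P @ V": nn + 2 * nd,      # read P, V; write O
    }
    return {step: count * bytes_per_elem for step, count in elems.items()}


traffic = unfused_attention_traffic(n=4096, d=64)
for step, nbytes in traffic.items():
    print(f"{step:<18} {nbytes / 2**20:8.1f} MiB")
print(f"{'total':<18} {sum(traffic.values()) / 2**20:8.1f} MiB")
```

Under this model, the N×N terms add up to 8N² elements, against only 4Nd for Q, K, V, and O. At N = 4096 and d = 64, that is 256 MiB versus 2 MiB.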
An operation is compute-bound when the time is dominated by arithmetic: matmuls with large inner dimensions fall here. An operation is memory-bound when the time is dominated by moving bytes between memory levels: element-wise ops (masking, softmax, dropout) with large tensors fall here.
The way to tell: compute the arithmetic intensity — the ratio of FLOPs to bytes transferred.
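The A100 has a peak BF16 compute throughput of 312 TFLOPS and a memory bandwidth of 1,935 GB/s. The crossover ratio is:

$$\frac{312 \times 10^{12}\ \text{FLOPs/s}}{1{,}935 \times 10^{9}\ \text{bytes/s}} \approx 161\ \text{FLOPs/Byte}$$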
If your operation's arithmetic intensity is below 161, it's memory-bound. If above, it's compute-bound. Let's check attention.
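The crossover test and a simple roofline time bound fit in a few lines. This sketch uses the A100 numbers quoted above and the per-step softmax estimates from the next paragraph (5N² FLOPs, 4N² bytes). The time is a lower bound under those assumptions, not a measured runtime.

```python
PEAK_BF16_FLOPS = 312e12   # A100 peak BF16 tensor-core throughput, FLOPs per second
HBM_BANDWIDTH = 1.935e12   # A100 HBM bandwidth, bytes per second

RIDGE_POINT = PEAK_BF16_FLOPS / HBM_BANDWIDTH   # ~161 FLOPs/Byte


def bound(flops: float, bytes_moved: float) -> str:
    """Classify an operation by comparing its arithmetic intensity to the ridge point."""
    return "compute-bound" if flops / bytes_moved > RIDGE_POINT else "memory-bound"


def min_runtime_s(flops: float, bytes_moved: float) -> float:
    """Roofline lower bound: the slower of the compute limit and the bandwidth limit."""
    return max(flops / PEAK_BF16_FLOPS, bytes_moved / HBM_BANDWIDTH)


n = 4096
softmax_flops, softmax_bytes = 5 * n * n, 4 * n * n
print(round(RIDGE_POINT), bound(softmax_flops, softmax_bytes))   # 161 memory-bound
print(f"{min_runtime_s(softmax_flops, softmax_bytes) * 1e6:.1f} us")  # ~34.7 us, set by bandwidth
```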
For the softmax step: N² values must be read and N² values written. Each BF16 value is 2 bytes. Total memory: 2 × N² × 2 = 4N² bytes. FLOPs: roughly 5N² (exp, subtract max, divide). Arithmetic intensity: 5N² / 4N² ≈ 1.25 FLOPs/Byte. That's 128× below the A100's crossover. Massively memory-bound.
For comparison, a large matrix multiplication like (8192 × 8192) × (8192 × 8192) has arithmetic intensity around 2,730 FLOPs/Byte — solidly compute-bound.
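Both intensities follow from counting FLOPs and bytes. A minimal sketch, assuming BF16 (2 bytes per element) and a matmul that reads each input once and writes the output once:

```python
def softmax_intensity(n: int, bytes_per_elem: int = 2, flops_per_elem: int = 5) -> float:
    """FLOPs per byte for a row-wise softmax over an N x N score matrix."""
    flops = flops_per_elem * n * n
    bytes_moved = 2 * n * n * bytes_per_elem                 # read N^2 scores, write N^2 probabilities
    return flops / bytes_moved


def matmul_intensity(m: int, k: int, n: int, bytes_per_elem: int = 2) -> float:
    """FLOPs per byte for C (m x n) = A (m x k) @ B (k x n)."""
    flops = 2 * m * k * n                                    # one multiply and one add per term
    bytes_moved = (m * k + k * n + m * n) * bytes_per_elem   # read A and B, write C
    return flops / bytes_moved


print(softmax_intensity(4096))             # 1.25
print(matmul_intensity(8192, 8192, 8192))  # about 2730.7, i.e. 8192 / 3
```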
[Interactive figure: arithmetic intensity of each attention step as a function of sequence length and head dimension. A red dashed line marks the A100's compute/memory ratio (161 FLOPs/Byte); everything below it is memory-bound.]
To understand FlashAttention, you need a mental model of where data lives inside a GPU. A modern GPU has (at least) three levels of memory, and the speed difference between them is staggering.
| Level | Size (A100) | Bandwidth | Analogy |
|---|---|---|---|
| Registers | 256 KB / SM | >19 TB/s | Your hands |
| SRAM (shared memory) | 192 KB / SM (~20 MB total) | ~19 TB/s | Your desk |
| HBM (DRAM) | 40–80 GB | 1.5–2 TB/s | The warehouse |
HBM (High Bandwidth Memory) is the big, slow memory. 40–80 GB is plenty of room, but you can only read from it at about 1.5–2 TB/s. SRAM (on-chip shared memory) is tiny but fast — about 20 MB total across all streaming multiprocessors (SMs), but accessible at roughly 19 TB/s. That's a 10× bandwidth gap.
Here's the problem for standard attention. The N×N attention matrix is large. For N = 4096 in BF16, it's 4096² × 2 bytes = 32 MB. That exceeds SRAM capacity. So standard implementations have no choice: they write S to HBM after the QKT matmul, read it back for softmax, write the result back, read it again for dropout, and so on.
But the Q, K, V tiles? A single tile of, say, 128 × 64 in BF16 is only 128 × 64 × 2 = 16 KB. A pair of tiles (one from Q, one from K) fits comfortably in SRAM. This is the opening that FlashAttention exploits.
[Interactive figure: how much of each memory level the attention matrix occupies at different sequence lengths. HBM fills up quickly, while SRAM can hold only small tiles.]
A GPU runs kernels — functions that execute in parallel across thousands of threads. Threads are grouped into warps (32 threads), warps into thread blocks, and thread blocks are assigned to streaming multiprocessors (SMs). Each thread block has access to its SM's shared memory (SRAM). Data that multiple threads need can be loaded once into shared memory, avoiding redundant HBM reads. This is called tiling.
Kernel fusion is the other key technique: instead of launching five separate GPU kernels (matmul, mask, softmax, dropout, matmul) with HBM writes between each one, we fuse them into a single kernel that does all the work while data is still in SRAM.
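Here is a CPU-side sketch of the fusion idea, not a GPU kernel. Scaling (by the usual $1/\sqrt{d}$), causal masking, softmax, and the product with V all happen inside one loop, so the output is the only full-size array ever stored. Dropout is left out, as at inference time. Each row's scores live in a temporary length-N list; a real fused kernel works on SRAM tiles instead.

```python
import math


def fused_attention(Q, K, V, causal=True):
    """Attention computed row by row, with scale, mask, softmax and P @ V in one pass.

    Q, K, V are lists of equal-length rows. Only the output is kept; each row's
    scores exist only in a short-lived buffer, never in an N x N matrix.
    """
    d = len(Q[0])
    scale = 1.0 / math.sqrt(d)
    out = []
    for i, q in enumerate(Q):
        n_keys = i + 1 if causal else len(K)              # causal mask: keys 0..i only
        row = [scale * sum(a * b for a, b in zip(q, K[j])) for j in range(n_keys)]
        row_max = max(row)
        weights = [math.exp(s - row_max) for s in row]    # stable softmax numerators
        total = sum(weights)
        out.append([sum(w * V[j][c] for j, w in enumerate(weights)) / total
                    for c in range(len(V[0]))])
    return out


print(fused_attention([[1.0, 0.0], [0.0, 1.0]], [[1.0, 0.0], [0.0, 1.0]], [[1.0, 2.0], [3.0, 4.0]]))
```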
Here's the plan. Instead of computing the full N×N matrix and storing it, we break Q, K, V into small blocks that fit in SRAM. For each pair of blocks, we compute a partial attention result, then combine the partial results to get the exact final answer.
Concretely, split Q into $T_r = \lceil N / B_r \rceil$ blocks of $B_r$ rows, and split K and V into $T_c = \lceil N / B_c \rceil$ blocks of $B_c$ rows. The outer loop iterates over K/V blocks ($j = 1, \ldots, T_c$). The inner loop iterates over Q blocks ($i = 1, \ldots, T_r$). For each (i, j) pair, we compute the $B_r \times B_c$ tile of the attention matrix, immediately use it to update the running output, and then discard it.
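Here is a minimal sketch of that loop structure, assuming 0-based indices and a last block that simply stops at N when N is not a multiple of the block size. It yields tiles in the FlashAttention-1 order (K/V blocks outside, Q blocks inside).

```python
import math


def tile_schedule(n: int, b_r: int, b_c: int):
    """Yield (i, j, q_rows, kv_rows) for every tile, K/V blocks in the outer loop."""
    t_r, t_c = math.ceil(n / b_r), math.ceil(n / b_c)
    for j in range(t_c):                                  # outer loop: K/V blocks
        kv_rows = range(j * b_c, min((j + 1) * b_c, n))
        for i in range(t_r):                              # inner loop: Q blocks
            q_rows = range(i * b_r, min((i + 1) * b_r, n))
            yield i, j, q_rows, kv_rows


tiles = list(tile_schedule(n=1000, b_r=128, b_c=128))
print(len(tiles))  # 64, since T_r = T_c = ceil(1000 / 128) = 8
```

The tiles cover every (query, key) pair exactly once, so no part of the N×N matrix is missed or double-counted.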
There's one serious obstacle. Softmax is a global operation. For each row i, softmax divides by the sum of $\exp(s_{ij})$ across ALL columns j = 1, ..., N. If we only have one tile at a time, we don't have the global sum. We can't do softmax.
Or can we?
The answer is online softmax — an algorithm that computes softmax incrementally, one tile at a time, maintaining a running maximum and running sum. When a new tile arrives with larger values, it rescales all the previous partial results. At the end, the answer is identical to computing softmax over the entire row at once.
This is the algorithmic heart of FlashAttention, and we'll derive it step by step in the next chapter.
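Standard attention reads and writes the N×N matrix multiple times, so its total HBM access is $O(Nd + N^2)$. FlashAttention never reads or writes the N×N matrix; instead, it streams Q, K, V, and O through SRAM in tiles. K and V are read once, but because of the loop order, Q and the partial output O are re-read and rewritten once per K/V block. Total HBM access is $O(N^2 d^2 M^{-1})$, where M is the SRAM available to a thread block. On an A100, M is at most 192 KB per SM, not the ~20 MB summed across all SMs. For typical head dimensions (d = 64–128), $d^2$ is many times smaller than M, so FlashAttention makes far fewer HBM accesses.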
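The benchmark below is from the FlashAttention paper: GPT-2 medium attention on an A100, with sequence length 1024, head dimension 64, 16 heads, and batch size 64. There, this is a 9× reduction in HBM reads and writes. FlashAttention does 13% more FLOPs, due to recomputation in the backward pass, yet it is about 6× faster in wall-clock time.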
| Metric | Standard | FlashAttention |
|---|---|---|
| GFLOPs | 66.6 | 75.2 (↑13%) |
| HBM reads/writes (GB) | 40.3 | 4.4 (↓9×) |
| Runtime (ms) | 41.7 | 7.3 (↓6×) |
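The headline ratios follow directly from the table. This small sketch recomputes them. It also derives achieved throughput (GFLOPs divided by runtime), which shows both versions running far below the A100's 312 TFLOPS peak.

```python
standard = {"gflops": 66.6, "hbm_gb": 40.3, "runtime_ms": 41.7}
flash = {"gflops": 75.2, "hbm_gb": 4.4, "runtime_ms": 7.3}

extra_flops = flash["gflops"] / standard["gflops"] - 1   # fraction of extra FLOPs
hbm_reduction = standard["hbm_gb"] / flash["hbm_gb"]     # how many times fewer bytes
speedup = standard["runtime_ms"] / flash["runtime_ms"]   # wall-clock speedup


def achieved_tflops(row):
    """GFLOP per millisecond is the same unit as TFLOP per second."""
    return row["gflops"] / row["runtime_ms"]


print(f"{extra_flops:.1%} more FLOPs, {hbm_reduction:.1f}x less HBM traffic, {speedup:.1f}x faster")
# 12.9% more FLOPs, 9.2x less HBM traffic, 5.7x faster
print(f"achieved: {achieved_tflops(standard):.1f} vs {achieved_tflops(flash):.1f} TFLOPS")
```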
This is the core derivation. If you understand this chapter, you understand FlashAttention. Let's build online softmax from scratch.
Given a vector $x = [x_1, x_2, \ldots, x_N]$, the softmax is:

$$\text{softmax}(x)_i = \frac{\exp(x_i)}{\sum_{j=1}^{N} \exp(x_j)}$$
This requires two passes over the data. First, compute every $\exp(x_j)$ and sum them; then divide each by the sum. The second pass cannot start until the first has seen all N values.
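In practice, we first subtract the maximum value for numerical stability. If $x_{\max} = 100$, then $\exp(100) \approx 2.7 \times 10^{43}$ overflows float32, whose largest finite value is about $3.4 \times 10^{38}$. But $\exp(100 - 100) = \exp(0) = 1$ is fine. The result is identical because:

$$\frac{\exp(x_i - m)}{\sum_{j=1}^{N} \exp(x_j - m)} = \frac{\exp(x_i)\,\exp(-m)}{\exp(-m)\sum_{j=1}^{N} \exp(x_j)} = \text{softmax}(x)_i$$

where $m = \max(x)$. The common factor $\exp(-m)$ cancels in the numerator and denominator. So stable softmax needs three passes: (1) find the max, (2) compute the exponentials and their sum, (3) divide.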
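Here are the naive and stable versions side by side, as a minimal sketch in plain Python. Python floats are 64-bit and overflow only above about $e^{709}$, so the example uses inputs near 1000 rather than float32's threshold of about 88.7.

```python
import math


def softmax_naive(x):
    exps = [math.exp(v) for v in x]          # overflows for large inputs
    total = sum(exps)
    return [e / total for e in exps]


def softmax_stable(x):
    m = max(x)                               # pass 1: find the max
    exps = [math.exp(v - m) for v in x]      # pass 2: exponentiate and sum
    total = sum(exps)
    return [e / total for e in exps]         # pass 3: normalize


print(softmax_stable([1000.0, 1001.0]))      # about [0.269, 0.731]
# softmax_naive([1000.0, 1001.0]) raises OverflowError: math.exp(1000) exceeds the float64 range
```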
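Now suppose we split x into two blocks, $x = [x^{(1)}, x^{(2)}]$. Each score $x_j$ comes with a value vector $v_j$; in attention, x is one row of S and the $v_j$ are rows of V. We process block 1 first, then block 2. After block 1, we have:

$$m^{(1)} = \max_{j \in \text{block 1}} x_j, \qquad l^{(1)} = \sum_{j \in \text{block 1}} \exp\big(x_j - m^{(1)}\big), \qquad o^{(1)} = \frac{1}{l^{(1)}} \sum_{j \in \text{block 1}} \exp\big(x_j - m^{(1)}\big)\, v_j$$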
This is a valid softmax over block 1 alone, applied to block 1's values. But the denominator $l^{(1)}$ only reflects block 1, so the result is not yet the global softmax.
When we process block 2, we discover a new maximum and new sum. We need to fix our earlier result. Here's the algebra, step by step.
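After processing block 2, we have its local statistics and the new overall maximum:

$$m^{(2)} = \max_{j \in \text{block 2}} x_j, \qquad l^{(2)} = \sum_{j \in \text{block 2}} \exp\big(x_j - m^{(2)}\big), \qquad m_{\text{new}} = \max\big(m^{(1)}, m^{(2)}\big)$$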
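We need to rescale block 1's exponentials. Originally we computed $\exp(x_j - m^{(1)})$, but we need $\exp(x_j - m_{\text{new}})$. The correction factor is:

$$\exp\big(x_j - m_{\text{new}}\big) = \exp\big(x_j - m^{(1)}\big) \cdot \exp\big(m^{(1)} - m_{\text{new}}\big)$$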
So we multiply all of block 1's contributions by $\exp(m^{(1)} - m_{\text{new}})$. Since $m_{\text{new}} \ge m^{(1)}$, this factor is at most 1; it shrinks old values that were computed with a smaller max. Similarly, block 2's terms are multiplied by $\exp(m^{(2)} - m_{\text{new}})$.
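The corrected sum:

$$l_{\text{new}} = e^{m^{(1)} - m_{\text{new}}}\, l^{(1)} + e^{m^{(2)} - m_{\text{new}}}\, l^{(2)}$$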
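And the corrected output, where $o^{(1)}$ is the normalized output after block 1:

$$o_{\text{new}} = \frac{e^{m^{(1)} - m_{\text{new}}}\, l^{(1)}\, o^{(1)} + e^{m^{(2)} - m_{\text{new}}} \sum_{j \in \text{block 2}} e^{x_j - m^{(2)}}\, v_j}{l_{\text{new}}}$$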
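For B blocks processed sequentially, start from $m_0 = -\infty$, $l_0 = 0$, $o_0 = 0$. After processing block j, with scores $x^{(j)}_k$ and values $v^{(j)}_k$, we update:

$$\begin{aligned}
m_j &= \max\big(m_{j-1},\ \max_k x^{(j)}_k\big) \\
l_j &= e^{m_{j-1} - m_j}\, l_{j-1} + \sum_k e^{x^{(j)}_k - m_j} \\
o_j &= \frac{e^{m_{j-1} - m_j}\, l_{j-1}\, o_{j-1} + \sum_k e^{x^{(j)}_k - m_j}\, v^{(j)}_k}{l_j}
\end{aligned}$$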
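Each row carries three running quantities: two scalars (the maximum m and the sum l) and the output vector o of length d. These are tiny (d + 2 numbers per row), small enough to keep in registers while a block of rows is being processed. The N×N matrix tiles are computed and discarded in SRAM.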
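The recurrence translates almost line for line into code. Here is a minimal sketch for a single query row in plain Python, alongside a reference that computes the global softmax in one go for comparison. The block size and test data are arbitrary.

```python
import math


def online_softmax_attention_row(scores, values, block_size):
    """Attention output for one query row, streaming over key/value blocks.

    Keeps only the running max m, running sum l and normalized output o.
    """
    d = len(values[0])
    m, l, o = -math.inf, 0.0, [0.0] * d
    for start in range(0, len(scores), block_size):
        s_blk = scores[start:start + block_size]
        v_blk = values[start:start + block_size]
        m_new = max(m, max(s_blk))
        alpha = math.exp(m - m_new)                   # rescales the old state; 0 on the first block
        p = [math.exp(s - m_new) for s in s_blk]      # this block's unnormalized weights
        l_new = alpha * l + sum(p)
        o = [(alpha * l * o[c] + sum(pk * vk[c] for pk, vk in zip(p, v_blk))) / l_new
             for c in range(d)]
        m, l = m_new, l_new
    return o, m, l


def reference_row(scores, values):
    """Global softmax over the whole row, then the weighted sum of values."""
    m = max(scores)
    w = [math.exp(s - m) for s in scores]
    total = sum(w)
    return [sum(wk * vk[c] for wk, vk in zip(w, values)) / total for c in range(len(values[0]))]


scores = [1.0, 2.0, 0.5, 0.1, 3.0, -1.0, 0.7]
values = [[float(k), float(k * k)] for k in range(len(scores))]
o, m, l = online_softmax_attention_row(scores, values, block_size=3)
print(o, reference_row(scores, values))   # equal up to floating-point rounding
```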
Let's compute attention for a concrete 4×4 score matrix, tiled into 2×2 blocks, to see the online softmax in action.
Score matrix S (after $QK^\top$):
| | j=1 | j=2 | j=3 | j=4 |
|---|---|---|---|---|
| i=1 | 1.0 | 2.0 | 0.5 | 0.1 |
| i=2 | 0.3 | 1.5 | 2.0 | 0.8 |
(Showing rows 1–2 for clarity. The process is identical for rows 3–4.)
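Block 1 (columns 1–2 for row 1): scores = [1.0, 2.0]. The block max is $m^{(1)} = 2.0$, and the exponentials are $e^{1.0 - 2.0} = 0.368$ and $e^{2.0 - 2.0} = 1.000$. So $l^{(1)} = 1.368$ and $o^{(1)} = (0.368\,v_1 + 1.000\,v_2)/1.368$.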
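Block 2 (columns 3–4 for row 1): scores = [0.5, 0.1]. Locally, $m^{(2)} = 0.5$, and the exponentials are $e^{0} = 1.000$ and $e^{0.1 - 0.5} = 0.670$, so $l^{(2)} = 1.670$. The new max is $m_{\text{new}} = \max(2.0, 0.5) = 2.0$. Block 1 needs no rescaling ($e^{0} = 1$), while block 2 is scaled by $e^{0.5 - 2.0} = 0.223$, turning its terms into $e^{0.5 - 2.0} = 0.223$ and $e^{0.1 - 2.0} = 0.150$. That gives $l_{\text{new}} = 1.368 + 0.223 + 0.150 = 1.741$ and $o_{\text{new}} = (0.368\,v_1 + 1.000\,v_2 + 0.223\,v_3 + 0.150\,v_4)/1.741$.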
That final $o_{\text{new}}$ is exactly the same as computing softmax([1.0, 2.0, 0.5, 0.1]) and multiplying by $[v_1, v_2, v_3, v_4]$. You can verify: 0.368/1.741 = 0.211, 1.000/1.741 = 0.574, 0.223/1.741 = 0.128, 0.150/1.741 = 0.086. These sum to 1.0 (within rounding).
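Here is the same arithmetic in floating point, as a minimal sketch. It recomputes the two block statistics and the final weights for row 1. The last digit can differ slightly from the hand calculation, which rounds intermediate values.

```python
import math

row = [1.0, 2.0, 0.5, 0.1]
block1, block2 = row[:2], row[2:]

m1 = max(block1)                                   # 2.0
l1 = sum(math.exp(s - m1) for s in block1)         # e^-1 + e^0, about 1.368
m2 = max(block2)                                   # 0.5
l2 = sum(math.exp(s - m2) for s in block2)         # e^0 + e^-0.4, about 1.670

m_new = max(m1, m2)                                # 2.0
l_new = math.exp(m1 - m_new) * l1 + math.exp(m2 - m_new) * l2   # about 1.741
weights = [math.exp(s - m_new) / l_new for s in row]
print([round(w, 4) for w in weights])              # [0.2114, 0.5745, 0.1282, 0.0859]
```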
[Interactive figure: online softmax processing one K/V block at a time, showing the running max (m), running sum (l), and output weights after each tile. The final weights match the global softmax exactly.]
Now we put it all together. We've established three principles: (1) tile to fit in SRAM, (2) online softmax to combine tiles exactly, (3) never write the N×N matrix to HBM. Here's the complete forward pass.
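```python
import numpy as np


def flash_attention_forward(Q, K, V, B_r=128, B_c=128):
    """FlashAttention-1 forward pass (no mask, no dropout), simulated in NumPy.

    Q, K, V: (N, d) arrays. Scores are unscaled, as in the text; pass
    Q / sqrt(d) for standard scaled attention. Full-size arrays (Q, K, V, O,
    m, l) play the role of HBM; the per-iteration tiles play the role of SRAM.
    """
    N, d = Q.shape
    O = np.zeros((N, d))        # in HBM
    m = np.full(N, -np.inf)     # in HBM (running max, one per row)
    l = np.zeros(N)             # in HBM (running sum, one per row)

    # Divide K, V into T_c blocks of size B_c and Q into T_r blocks of size B_r
    for j in range(0, N, B_c):                      # outer loop over K/V blocks
        K_j, V_j = K[j:j + B_c], V[j:j + B_c]       # load to SRAM
        for i in range(0, N, B_r):                  # inner loop over Q blocks
            rows = slice(i, i + B_r)
            Q_i, O_i, m_i, l_i = Q[rows], O[rows], m[rows], l[rows]

            # Compute tile of attention scores (in SRAM!)
            S_ij = Q_i @ K_j.T                      # B_r x B_c, never written to HBM

            # Online softmax update (per-row statistics broadcast over columns)
            m_new = np.maximum(m_i, S_ij.max(axis=1))
            P_ij = np.exp(S_ij - m_new[:, None])    # unnormalized weights for this tile
            scaled_l = l_i * np.exp(m_i - m_new)    # old sum rescaled to the new max
            l_new = scaled_l + P_ij.sum(axis=1)
            O_new = (scaled_l[:, None] * O_i + P_ij @ V_j) / l_new[:, None]

            # Write updated O, m, l back to HBM
            O[rows], m[rows], l[rows] = O_new, m_new, l_new
    return O


# Sanity check against standard attention that materializes the full N x N matrix
rng = np.random.default_rng(0)
Q, K, V = (rng.standard_normal((300, 64)) for _ in range(3))
S = Q @ K.T
P = np.exp(S - S.max(axis=1, keepdims=True))
P /= P.sum(axis=1, keepdims=True)
assert np.allclose(flash_attention_forward(Q, K, V, B_r=64, B_c=128), P @ V)
```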
| Data | Lives in | Size | Lifetime |
|---|---|---|---|
| Q, K, V, O | HBM | N × d each | Entire computation |
| m, l | HBM (+ registers) | N each (one per row) | Entire computation |
| $Q_i$, $K_j$, $V_j$ tiles | SRAM | $B_r \times d$, $B_c \times d$ | One iteration |
| $S_{ij}$ tile | SRAM | $B_r \times B_c$ | One iteration, then discarded |
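Total SRAM needed for the tiles: $B_r \times d + 2 \times B_c \times d + B_r \times B_c$ elements. On an A100 with ~192 KB of shared memory per SM, typical block sizes are $B_r = B_c = 128$ with d = 64. That's 128 × 64 × 3 × 2 bytes = 48 KB for the Q, K, V tiles, plus 128 × 128 × 2 = 32 KB for $S_{ij}$, for a total of 80 KB. Even counting a 128 × 64 output block $O_i$ (another 16 KB in BF16), it fits comfortably.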
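Here is the same budget as a small calculator. The accounting is simplified and is an assumption, not how any particular kernel lays out memory: it ignores FP32 accumulators, double buffering, and padding. The candidate block sizes are illustrative.

```python
def tile_sram_bytes(b_r, b_c, d, bytes_per_elem=2, include_output=False):
    """Shared memory for the Q_i, K_j, V_j and S_ij tiles (optionally O_i), in bytes."""
    elems = b_r * d + 2 * b_c * d + b_r * b_c
    if include_output:
        elems += b_r * d
    return elems * bytes_per_elem


def largest_square_block(d, budget_bytes, candidates=(16, 32, 64, 128, 256)):
    """Largest B_r = B_c among the candidates whose tiles (including O_i) fit the budget."""
    fitting = [b for b in candidates
               if tile_sram_bytes(b, b, d, include_output=True) <= budget_bytes]
    return max(fitting) if fitting else None


SRAM_PER_SM = 192 * 1024
print(tile_sram_bytes(128, 128, 64) // 1024)                       # 80
print(tile_sram_bytes(128, 128, 64, include_output=True) // 1024)  # 96
print(largest_square_block(64, SRAM_PER_SM))                       # 128
```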
Training requires a backward pass. Normally, the forward pass saves the N×N attention matrix for use in backpropagation. FlashAttention doesn't have it — we threw it away. The solution: recompute the attention tiles during the backward pass.
This sounds wasteful. We do more FLOPs (about 13% more). But remember: the bottleneck is memory bandwidth, not compute. By recomputing tiles in SRAM instead of reading a huge cached matrix from HBM, the backward pass is still faster.
The only things saved from the forward pass: the output O, and the per-row m and l vectors (total: 2N scalars plus the N × d output). From these, plus the original Q, K, V, we can reconstruct any tile of the attention matrix on the fly.
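Here is a minimal sketch of the recomputation. It assumes the saved statistics are the final per-row max m and sum l from the forward pass. With only those two numbers per row plus Q and K, any tile of the probability matrix can be rebuilt exactly as $\exp(s_{ij} - m_i)/l_i$. Scores are unscaled, as in the text, and the toy matrices are arbitrary.

```python
import math


def recompute_probabilities(q_rows, k_rows, m, l):
    """Rebuild one tile of P = softmax(Q K^T) from saved per-row statistics.

    q_rows: the tile's query vectors; k_rows: the tile's key vectors;
    m, l: final running max and running sum for those query rows.
    """
    tile = []
    for q, m_i, l_i in zip(q_rows, m, l):
        scores = [sum(a * b for a, b in zip(q, k)) for k in k_rows]
        tile.append([math.exp(s - m_i) / l_i for s in scores])
    return tile


Q = [[0.1, 0.4], [0.9, -0.3], [0.5, 0.5]]
K = [[0.2, 0.7], [-0.6, 0.1], [0.3, 0.3], [1.0, -0.2]]

# Forward-pass statistics; a real kernel would have saved these, here we compute them directly.
S = [[sum(a * b for a, b in zip(q, k)) for k in K] for q in Q]
m = [max(row) for row in S]
l = [sum(math.exp(s - m_i) for s in row) for row, m_i in zip(S, m)]

# Rebuild the tile for query rows 1-2 and key columns 2-3 without the full matrix.
tile = recompute_probabilities(Q[1:3], K[2:4], m[1:3], l[1:3])
print(tile)
```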
[Interactive figure: the forward pass processing one tile at a time. Orange marks the current tile in SRAM, teal marks completed tiles (discarded from SRAM), and purple marks the currently loaded Q and K/V blocks. The full N×N matrix never exists at once.]
FlashAttention-1 was a breakthrough. FlashAttention-2 (Dao, 2023) made it even faster by fixing inefficiencies in how work is distributed across GPU thread blocks and warps.
1. Better work partitioning. In FA-1, the outer loop runs over K/V blocks. As a result, each Q block's partial output $O_i$ and its statistics $m_i$, $l_i$ are read from and written back to slow HBM once per K/V block. FA-2 swaps the loops: the outer loop iterates over Q blocks and the inner loop over K/V blocks. Each thread block then owns one Q block, keeps its running output on chip, and writes it to HBM once at the end. Within a thread block, FA-2 also splits the work across warps by Q rows instead of by K/V, so warps no longer have to exchange partial results through shared memory.
2. Reduced non-matmul FLOPs. Modern GPUs have tensor cores that accelerate matrix multiplications dramatically. On an A100, matmuls run at up to 312 TFLOPS in BF16. Non-matmul FP32 operations (the rescaling, max-finding, and bookkeeping) run on regular CUDA cores at 19.5 TFLOPS, so each non-matmul FLOP costs about 16× as much. FA-2 reduces these non-matmul operations, keeping the tensor cores busier.
3. More parallelism. FA-1 parallelizes over batch size and number of attention heads. FA-2 also parallelizes over the sequence length dimension. For long sequences with small batch sizes (common in inference), this better saturates the GPU.
FlashAttention-3 (2024) targets the H100 GPU and exploits its new features: asynchronous memory copies (overlap compute and data movement), low-precision matmuls (FP8 tensor cores), and the TMA (Tensor Memory Accelerator) unit for efficient tiling. It also introduces software pipelining — while one tile is being computed, the next tile is being loaded from HBM to SRAM simultaneously.
Here's an elegant insight: FlashAttention and sparse attention are complementary. FlashAttention makes exact attention faster. Sparse attention reduces the number of tiles that need processing. Combine them: use FlashAttention for the tiles you DO compute, and skip the tiles you don't. Mistral 7B (2023) uses this approach — sliding window attention with FlashAttention — and it became one of the most efficient open-source models.
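Here is a sketch of tile skipping under an assumed causal sliding-window mask, where each query sees itself and the previous window - 1 keys. It lists the (i, j) tiles that contain at least one unmasked entry. A block-sparse kernel would run the FlashAttention update only on those and skip the rest. The sequence length, block size, and window below are illustrative.

```python
import math


def active_tiles(n, block, window):
    """Tiles (i, j) with at least one unmasked entry, where query q may attend
    to key k only if q - window < k <= q (causal sliding window)."""
    t = math.ceil(n / block)
    active = []
    for i in range(t):
        q_lo, q_hi = i * block, min((i + 1) * block, n) - 1
        for j in range(t):
            k_lo, k_hi = j * block, min((j + 1) * block, n) - 1
            # Some key must be at or before some query (k_lo <= q_hi) and
            # inside the window of some query (k_hi > q_lo - window).
            if k_lo <= q_hi and k_hi > q_lo - window:
                active.append((i, j))
    return active


n, block, window = 16_384, 128, 4_096
t = math.ceil(n / block)
sliding = active_tiles(n, block, window)
causal = active_tiles(n, block, window=n)   # window >= n: plain causal attention
print(f"{len(sliding)} sliding-window tiles, {len(causal)} causal tiles, {t * t} total")
# 3696 sliding-window tiles, 8256 causal tiles, 16384 total
```

Tiles on the diagonal and at the window edge are only partly masked. The kernel still has to apply the mask element by element there; only fully masked tiles can be skipped outright.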
| Version | Year | Key Innovation | Speedup vs Previous |
|---|---|---|---|
| FlashAttention-1 | 2022 | Tiling + online softmax + recomputation | 6× vs standard |
| FlashAttention-2 | 2023 | Better parallelism + work partitioning | ~2× vs FA-1 |
| FlashAttention-3 | 2024 | H100 features: FP8, TMA, pipelining | ~1.5× vs FA-2 |
| FA + Sparse | 2023+ | Skip tiles via sparsity masks | Additional gains on top |
FlashAttention didn't just save memory — it changed what's possible:
| Concept | One-Liner |
|---|---|
| The bottleneck | Standard attention is memory-bound: the N×N matrix keeps bouncing between SRAM and HBM |
| Arithmetic intensity | FLOPs / Bytes accessed. If below the GPU's ratio, you're memory-bound |
| Tiling | Break Q, K, V into blocks that fit in SRAM. Compute tile of S, use it, discard it |
| Online softmax | Running max + running sum let you compute softmax incrementally across tiles |
| Rescaling | $\exp(m_{\text{old}} - m_{\text{new}})$ corrects old partial results when a new tile has larger values |
| Kernel fusion | One GPU kernel does matmul + mask + softmax + dropout + matmul instead of five separate ones |
| Recomputation | Recompute attention tiles in backward pass instead of caching N×N matrix. Trades cheap FLOPs for expensive memory |
| Exact output | FlashAttention computes the same result as standard attention. Zero approximation |
| HBM access | Standard: O(Nd + N²). FlashAttention: O(N²d²/M). About 9× reduction in practice |
| Resource | Value |
|---|---|
| HBM capacity | 40–80 GB |
| HBM bandwidth | ~2 TB/s |
| SRAM capacity (total) | ~20 MB |
| SRAM bandwidth | ~19 TB/s |
| BF16 compute | 312 TFLOPS |
| Compute/Memory ratio | 161 FLOPs/Byte |