CS 229s — Systems for Machine Learning

FlashAttention & Efficient Attention

The N×N attention matrix never touches slow memory. Tiled computation, online softmax, and the IO-aware algorithm that changed how we run transformers.

Prerequisites: Attention mechanism + Memory hierarchy. That's it.
9 chapters · 6+ simulations · 0 assumed knowledge

Chapter 0: The Long Sequence Problem

You want a language model that can read an entire textbook. Not a paragraph — an entire 300-page mathematics textbook, all at once. Or you want it to process a full second of raw audio (thousands of timesteps), or analyze long-range interactions in the human genome across 100,000+ nucleotide pairs.

The obstacle is attention. The core operation in a transformer computes pairwise similarities between every token and every other token. If your sequence has N tokens, the attention matrix is N × N. Double the sequence length and you quadruple the memory and compute.

$$\text{Memory}(\text{attention}) = O(N^2), \qquad \text{Compute}(\text{attention}) = O(N^2 d)$$

At N = 1,024 tokens, that matrix has about 1 million entries. Manageable. At N = 16,384, it has 268 million entries. At N = 131,072 (a modest book), it has 17 billion entries. The naive approach collapses.
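A quick back-of-the-envelope check of these numbers, as a minimal Python sketch. It assumes 2-byte BF16 entries and counts a single attention matrix, meaning one head in one layer. Real models multiply this by heads, layers, and batch size. The helper name `attention_matrix_bytes` is hypothetical.

```python
def attention_matrix_bytes(n, bytes_per_entry=2):
    """Size of one N x N attention matrix (one head, one layer), BF16 by default."""
    return n * n * bytes_per_entry

for n in (1_024, 16_384, 131_072):
    entries = n * n
    gib = attention_matrix_bytes(n) / 2**30
    print(f"N={n:>7,}: {entries:>14,} entries, {gib:8.3f} GiB in BF16")
```

At N = 131,072, a single head of a single layer already needs 32 GiB (about 34 GB), most of a 40 GB A100's HBM.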

The core tension: Attention is what makes transformers powerful — it lets any token attend to any other, capturing long-range dependencies. But the N×N cost is what makes them expensive. Can we keep the power without the quadratic penalty?
[Simulation: Attention Memory Scaling. Memory vs. sequence length N: the N×N attention matrix (red) grows quadratically, while Q, K, V storage (teal) grows linearly.]

This chapter's question: is the quadratic cost fundamental? Do we really need every token to talk to every other token? And even if we do, must we store that enormous matrix in memory?

The answer to the last question, as we'll discover, is no. FlashAttention computes exact attention, with results mathematically identical to the naive algorithm, without ever materializing the full N×N matrix in slow memory. The O(N²d) compute remains, but the N×N memory traffic does not. To understand how it achieves this, we first need to survey the approaches people tried before it.

Check: Why does doubling the sequence length quadruple the attention cost?

Chapter 1: The Efficient Attention Taxonomy

Before FlashAttention, researchers attacked the quadratic bottleneck by changing the attention algorithm itself — computing something cheaper that approximates attention. These approaches fall into three families: sparse, low-rank, and kernel.

Sparse Attention

Instead of letting every token attend to every other token, sparse attention restricts which pairs of tokens interact. Think of it as poking holes in the N×N matrix — only computing certain entries.

Sparse Transformers (Child et al., 2019) use factorized patterns. Each token attends to a local window of about √N nearby positions plus a strided set of about √N positions further back, so information can still reach any token in two hops. Complexity drops from O(N²) to O(N√N). BigBird (Zaheer et al., 2020) combines three patterns: local windows (nearby tokens), global tokens (a few special tokens that attend to, and are attended by, everything), and random connections (to reduce graph diameter). Complexity: O(N).
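To make the BigBird-style combination concrete, here is a minimal Python sketch that builds a boolean attention mask from a local window, a few global tokens, and a few random keys per query. Several details are illustrative assumptions rather than BigBird's exact block-based construction: the window width, placing the global tokens at the first `n_global` positions, and the random sampling scheme.

```python
import random

def bigbird_style_mask(n, window, n_global, n_random, seed=0):
    """N x N boolean mask: mask[i][j] is True if query i may attend to key j."""
    rng = random.Random(seed)
    mask = [[False] * n for _ in range(n)]
    for i in range(n):
        # Local band: keys within `window` positions of query i.
        for j in range(max(0, i - window), min(n, i + window + 1)):
            mask[i][j] = True
        # Random connections: a few arbitrary keys per query.
        for j in rng.sample(range(n), n_random):
            mask[i][j] = True
    # Global tokens (assumed to be the first n_global positions):
    # they attend to every key, and every query attends to them.
    for g in range(n_global):
        for k in range(n):
            mask[g][k] = True
            mask[k][g] = True
    return mask

def density(mask):
    """Fraction of the N x N grid that is actually computed."""
    n = len(mask)
    return sum(map(sum, mask)) / (n * n)

for n in (64, 256, 1024):
    print(n, round(density(bigbird_style_mask(n, window=3, n_global=2, n_random=2)), 3))
```

Each ordinary row keeps at most 2·window + 1 + n_random + n_global entries. The number of computed entries therefore grows linearly with N, and the computed fraction of the grid shrinks as N grows.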

Analogy: Full attention is like a conference where every person talks to every other person. Sparse attention is like organizing small breakout rooms with a few ambassadors who visit all rooms.

Low-Rank Attention

Low-rank methods observe that the N×N attention matrix often has low effective rank. It can therefore be approximated by projecting K and V along the sequence dimension down to a much smaller length k ≪ N. Linformer (Wang et al., 2020) uses learned projection matrices to project the N×d key and value matrices down to k×d, so the attention matrix is N×k instead of N×N.

The downside: hard to maintain causality (the model shouldn't peek at future tokens), and quality degrades when the true rank is higher than k.
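A shape-level sketch of the Linformer idea in plain Python. It makes two simplifying assumptions: a random (rather than learned) projection matrix E of shape k×N, and a single E shared by K and V (Linformer can use separate projections). The point is only the shapes: the score matrix is N×k, not N×N.

```python
import math
import random

def matmul(a, b):
    """Naive product of nested lists: (n x m) @ (m x p) -> (n x p)."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]

def transpose(a):
    return [list(col) for col in zip(*a)]

def softmax_rows(s):
    out = []
    for row in s:
        mx = max(row)
        e = [math.exp(x - mx) for x in row]
        total = sum(e)
        out.append([x / total for x in e])
    return out

def linformer_attention(Q, K, V, E):
    """Q, K, V: N x d. E: k x N projection along the sequence axis."""
    K_proj = matmul(E, K)                  # k x d
    V_proj = matmul(E, V)                  # k x d
    scores = matmul(Q, transpose(K_proj))  # N x k, not N x N
    return matmul(softmax_rows(scores), V_proj), scores

rng = random.Random(0)

def rand_matrix(rows, cols):
    return [[rng.gauss(0, 1) for _ in range(cols)] for _ in range(rows)]

N, d, k = 32, 4, 8
Q, K, V = rand_matrix(N, d), rand_matrix(N, d), rand_matrix(N, d)
E = [[x / math.sqrt(N) for x in row] for row in rand_matrix(k, N)]  # hypothetical random projection
out, scores = linformer_attention(Q, K, V, E)
print(len(scores), "x", len(scores[0]), "scores;", len(out), "x", len(out[0]), "output")
```

Every row of E blends all N positions. Each projected key therefore mixes past and future tokens, and no mask on the N×k scores can stop query i from seeing keys j > i. That is the causality problem mentioned above.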

Kernel Attention

Kernel methods replace the exp(qᵢ · kⱼ) similarity inside the softmax with a decomposable feature map: sim(qᵢ, kⱼ) ≈ φ(qᵢ)ᵀφ(kⱼ). The Linear Transformer (Katharopoulos et al., 2020) uses this trick to avoid computing the full N×N matrix. Without the softmax, you can exploit the associativity of matrix multiplication:

$$\text{Standard: } \mathrm{softmax}(QK^\top)\,V \qquad O(N^2 d)$$
$$\text{Linear: } \phi(Q)\,\big(\phi(K)^\top V\big) \qquad O(N d^2)$$

When d ≪ N (which it usually is), this is a win. But the approximation quality suffers, especially on language tasks where the sharp attention patterns that softmax produces matter.
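The associativity trick can be checked numerically. This plain-Python sketch assumes the feature map φ(x) = elu(x) + 1, the choice used by Katharopoulos et al. To match the two lines above, it omits the row normalization that the Linear Transformer also applies. Both groupings give the same N×d result; only the intermediate shapes, and hence the cost, differ.

```python
import math
import random

def matmul(a, b):
    """Naive product of nested lists: (n x m) @ (m x p) -> (n x p)."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]

def transpose(a):
    return [list(col) for col in zip(*a)]

def phi(m):
    """elu(x) + 1, elementwise: x + 1 for x > 0, exp(x) otherwise (always positive)."""
    return [[x + 1 if x > 0 else math.exp(x) for x in row] for row in m]

rng = random.Random(1)

def rand_matrix(rows, cols):
    return [[rng.gauss(0, 1) for _ in range(cols)] for _ in range(rows)]

N, d = 16, 4
Q, K, V = rand_matrix(N, d), rand_matrix(N, d), rand_matrix(N, d)
fQ, fK = phi(Q), phi(K)

# (phi(Q) phi(K)^T) V: builds an N x N matrix first, O(N^2 d) work.
left = matmul(matmul(fQ, transpose(fK)), V)
# phi(Q) (phi(K)^T V): builds a d x d summary first, O(N d^2) work.
kv = matmul(transpose(fK), V)
right = matmul(fQ, kv)

max_diff = max(abs(a - b) for ra, rb in zip(left, right) for a, b in zip(ra, rb))
print(f"max |left - right| = {max_diff:.1e}")
```

The d×d summary `kv` does not grow with N, which is where the linear cost comes from. This trick needs the similarity to factor as φ(q)ᵀφ(k), and exp(q·k) inside a softmax does not factor that way.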

[Simulation: Sparsity Patterns. Attention patterns for each method. Orange cells are computed, dark cells are skipped, and full attention fills the whole grid.]

The big tradeoff: All three families trade quality for speed. They compute something different from standard attention. On the Pile language modeling benchmark, these efficient variants consistently underperform exact attention by measurable perplexity gaps. Is there an algorithm that's both fast AND exact?
Method | Type | Complexity | Exact? | Tradeoff
Standard | Dense | O(N²d) | Yes | Slow for long N
Sparse Transformer | Sparse | O(N√N) | No | Misses global patterns
BigBird | Sparse | O(N) | No | Fixed pattern, not learned
Linformer | Low-rank | O(Nk) | No | Rank assumption, no causality
Linear Transformer | Kernel | O(Nd²) | No | No softmax → weaker attention
FlashAttention | IO-aware | O(N²d) | Yes | Same FLOPs, fewer memory accesses

Notice that FlashAttention doesn't reduce the FLOP count. It has the same O(N²d) compute. Its insight is entirely different: the problem isn't that we're doing too many operations — it's that we're moving data too slowly. The bottleneck is memory bandwidth, not compute.

Check: What do sparse, low-rank, and kernel attention methods have in common?

Chapter 2: Why Standard Attention is Slow

Here's the surprise. Standard attention on a modern GPU like the A100 isn't slow because of arithmetic. The A100's tensor cores can perform 312 TFLOPS (trillion floating-point operations per second) of BF16 computation. For a sequence of N = 1024 with head dimension d = 64, the two attention matmuls need about 268 million FLOPs per head (2N²d each for QKᵀ and AV). At 312 TFLOPS, that takes under a microsecond. So why does standard attention run so much slower than that?
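The arithmetic behind that claim, as a short Python sketch. It makes three simplifications: it counts only the two matrix multiplications (a multiply-add counts as two FLOPs), ignores softmax, and assumes the 312 TFLOPS peak is actually reached. The result is therefore a lower bound on runtime for one head.

```python
PEAK_BF16_FLOPS = 312e12   # A100 dense BF16 tensor-core peak, FLOP/s

def attention_matmul_flops(n, d):
    """FLOPs for S = Q K^T plus O = A V, counting a multiply-add as 2 FLOPs."""
    qk = 2 * n * n * d     # (N x d) @ (d x N)
    av = 2 * n * n * d     # (N x N) @ (N x d)
    return qk + av

flops = attention_matmul_flops(1024, 64)
seconds = flops / PEAK_BF16_FLOPS
print(f"{flops:,} FLOPs -> {seconds * 1e6:.2f} microseconds at peak")
```

For N = 1024 and d = 64 this is 268,435,456 FLOPs, or about 0.86 μs at peak.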

The answer: memory traffic. Let's trace what happens step by step in a standard PyTorch attention implementation.

Step 1: Matmul
Read Q, K from HBM → Compute S = QKᵀ → Write S to HBM
Step 2: Masking
Read S from HBM → Apply mask → Write masked S to HBM
Step 3: Softmax
Read S from HBM → Compute A = softmax(S) → Write A to HBM
Step 4: Dropout
Read A from HBM → Apply dropout → Write dropped-out A to HBM
Step 5: Matmul
Read A, V from HBM → Compute O = AV → Write O to HBM

Count the memory trips. The N×N matrix S gets written to HBM after Step 1, read back for Step 2, written again, read back for Step 3, and so on. Each of those lightweight element-wise operations (masking, softmax, dropout) is sandwiched between expensive HBM reads and writes. The compute in each step is trivial; the data movement is not.
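A rough traffic model of the five steps above, as a Python sketch. It counts tensor elements read from and written to HBM under three assumptions: each step is a separate kernel, each kernel reads its inputs once and writes its output once, and the mask is generated on the fly rather than read. For comparison, `fused_lower_bound` models a hypothetical kernel that keeps every intermediate on-chip and only reads Q, K, V and writes O.

```python
def unfused_hbm_elements(n, d):
    """Elements moved between HBM and the SMs by the five-step standard attention."""
    return {
        "matmul QK^T": 2 * n * d + n * n,      # read Q, K; write S
        "mask":        n * n + n * n,          # read S; write masked S
        "softmax":     n * n + n * n,          # read S; write A
        "dropout":     n * n + n * n,          # read A; write dropped-out A
        "matmul AV":   n * n + n * d + n * d,  # read A, V; write O
    }

def fused_lower_bound(n, d):
    """Read Q, K, V once and write O once; nothing N x N touches HBM."""
    return 4 * n * d

n, d = 1024, 64
steps = unfused_hbm_elements(n, d)
for name, elems in steps.items():
    print(f"{name:12s} {elems * 2 / 2**20:7.2f} MiB")  # 2 bytes per BF16 element
total = sum(steps.values())
print(f"total unfused: {total * 2 / 2**20:.2f} MiB, "
      f"fused: {fused_lower_bound(n, d) * 2 / 2**20:.2f} MiB")
```

Almost all of the traffic comes from the N×N matrix making four round trips. At N = 1024, d = 64 this model moves 33× more data than the fused lower bound.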

The key realization: Standard attention is memory-bound, not compute-bound. The GPU spends most of its time waiting for data to arrive from slow memory, not doing arithmetic. This is the insight that FlashAttention exploits.

Compute-bound vs. Memory-bound

An operation is compute-bound when the time is dominated by arithmetic: matmuls with large inner dimensions fall here. An operation is memory-bound when the time is dominated by moving bytes between memory levels: element-wise ops (masking, softmax, dropout) with large tensors fall here.

The way to tell: compute the arithmetic intensity — the ratio of FLOPs to bytes transferred.

Arithmetic Intensity = FLOPs / Bytes accessed

The A100 has a peak BF16 tensor-core throughput of 312 TFLOPS and, in its 80 GB PCIe version, an HBM bandwidth of 1,935 GB/s. The crossover ratio is:

312 TFLOPS / 1.935 TB/s = 161 FLOPs/Byte

If your operation's arithmetic intensity is below 161, it's memory-bound. If above, it's compute-bound. Let's check attention.

Worked Example: Is Attention Memory-Bound?

For the softmax step: N² values must be read and N² values written. Each BF16 value is 2 bytes. Total memory: 2 × N² × 2 = 4N² bytes. FLOPs: roughly 5N², since max, subtract, exp, sum, and divide make about five operations per element. Arithmetic intensity: 5N² / 4N² ≈ 1.25 FLOPs/Byte. That's about 128× below the A100's crossover. Massively memory-bound.

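For comparison, consider a large matrix multiplication like (8192 × 8192) × (8192 × 8192). It performs 2 × 8192³ FLOPs while reading two 8192 × 8192 BF16 matrices and writing one (3 × 8192² × 2 bytes). Its arithmetic intensity is 8192 / 3 ≈ 2,730 FLOPs/Byte, solidly compute-bound.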
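The three numbers in this worked example can be reproduced with a few lines of Python. The FLOP and byte counts use the same simplifications as the text: five operations per softmax element, 2 bytes per BF16 value, and a square matmul that reads two operands and writes one result exactly once.

```python
PEAK_FLOPS = 312e12         # A100 BF16 tensor-core peak, FLOP/s
HBM_BYTES_PER_S = 1.935e12  # A100 80 GB PCIe HBM bandwidth, B/s
BYTES = 2                   # BF16

ridge = PEAK_FLOPS / HBM_BYTES_PER_S  # crossover arithmetic intensity

def softmax_intensity(n):
    flops = 5 * n * n                 # max, subtract, exp, sum, divide
    bytes_moved = 2 * n * n * BYTES   # read S, write A
    return flops / bytes_moved

def square_matmul_intensity(n):
    flops = 2 * n ** 3                # n^3 multiply-adds
    bytes_moved = 3 * n * n * BYTES   # read A and B, write C
    return flops / bytes_moved

for name, ai in [("softmax", softmax_intensity(1024)),
                 ("matmul 8192", square_matmul_intensity(8192))]:
    kind = "memory-bound" if ai < ridge else "compute-bound"
    print(f"{name:12s} {ai:8.2f} FLOPs/Byte -> {kind} (ridge {ridge:.0f})")
```

The softmax intensity does not depend on N at all. A longer sequence never moves softmax out of the memory-bound regime.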

[Simulation: Arithmetic Intensity Calculator. Arithmetic intensity of each attention step as a function of sequence length N and head dimension d, plotted against the A100's compute/memory ratio of 161 FLOPs/Byte (red dashed line). Steps below the line are memory-bound.]
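Key insight: Lower FLOPs does NOT mean faster wall-clock time. On a memory-bound workload, runtime tracks bytes moved. An algorithm with 2× the FLOPs but 10× fewer memory accesses can therefore run up to 10× faster, as long as the extra arithmetic doesn't make it compute-bound. This is the entire thesis of FlashAttention.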
Check: Why is the softmax step of attention memory-bound on an A100?

Chapter 3: GPU Memory Hierarchy

To understand FlashAttention, you need a mental model of where data lives inside a GPU. A modern GPU has (at least) three levels of memory, and the speed difference between them is staggering.

LevelSize (A100)BandwidthAnalogy
Registers256 KB / SM>19 TB/sYour hands
SRAM (shared memory)192 KB / SM (~20 MB total)~19 TB/sYour desk
HBM (DRAM)40–80 GB1.5–2 TB/sThe warehouse

HBM (High Bandwidth Memory) is the big, slow memory. 40–80 GB is plenty of room, but you can only read from it at about 1.5–2 TB/s. SRAM (on-chip shared memory) is tiny but fast — about 20 MB total across all streaming multiprocessors (SMs), but accessible at roughly 19 TB/s. That's a 10× bandwidth gap.

Analogy: HBM is a warehouse full of supplies. SRAM is the small workbench right next to you. Fetching a tool from the warehouse takes about ten times as long as grabbing one from the workbench. Standard attention keeps walking to the warehouse for every tiny operation. FlashAttention organizes all the tools on the workbench first.

Here's the problem for standard attention: the N×N attention matrix is large. For N = 4096 in BF16, it's 4096² × 2 bytes = 32 MB. That is far more than the 192 KB of SRAM on one SM, and more than the ~20 MB across all SMs combined. So standard implementations have no choice. They write S to HBM after the QKᵀ matmul, read it back for softmax, write the result back, read it again for dropout, and so on.

But the Q, K, V tiles? A single tile of, say, 128 × 64 in BF16 is only 128 × 64 × 2 = 16 KB. A pair of tiles (one from Q, one from K) fits comfortably in SRAM. This is the opening that FlashAttention exploits.

[Simulation: Memory Hierarchy Visualizer. How much of each memory level the attention matrix occupies at different sequence lengths N. HBM fills up quickly, while SRAM can hold only small tiles.]

How GPU Execution Works

A GPU runs kernels — functions that execute in parallel across thousands of threads. Threads are grouped into warps (32 threads), warps into thread blocks, and thread blocks are assigned to streaming multiprocessors (SMs). Each thread block has access to its SM's shared memory (SRAM). Data that multiple threads need can be loaded once into shared memory, avoiding redundant HBM reads. This is called tiling.

Kernel fusion is the other key technique: instead of launching five separate GPU kernels (matmul, mask, softmax, dropout, matmul) with HBM writes between each one, we fuse them into a single kernel that does all the work while data is still in SRAM.

FlashAttention's strategy in one sentence: Tile the attention computation so that each tile fits in SRAM, fuse all the operations (matmul + mask + softmax + dropout) into one kernel, and never write the N×N matrix to HBM.
Check: Why can't we just hold the full N×N attention matrix in SRAM?

Chapter 4: FlashAttention — The Key Idea

Here's the plan. Instead of computing the full N×N matrix and storing it, we break Q, K, V into small blocks that fit in SRAM. For each pair of blocks, we compute a partial attention result, then combine the partial results to get the exact final answer.

Concretely: split Q into Tr = N/Br blocks of Br rows each, and K, V into Tc = N/Bc blocks of Bc rows each. The outer loop iterates over K/V blocks (j = 1, ..., Tc). The inner loop iterates over Q blocks (i = 1, ..., Tr). For each (i, j) pair, we compute the Br × Bc tile of the attention matrix, immediately use it to update the running output, then discard it.

1. Load tiles
Load Kj, Vj from HBM to SRAM (one outer-loop block)
2. Inner loop
For each Qi: load Qi, compute Sij = QiKjᵀ in SRAM
3. Local softmax
Compute local max mij and local sum lij for this tile
4. Update output
Rescale running Oi with correction factors, add new tile's contribution
5. Discard tile
The Br × Bc tile of S never leaves SRAM. Only the updated Oi, with its running m and l, is written back to HBM.
The magic: The Br × Bc tile of the attention matrix is computed in SRAM, used immediately, and thrown away. It is never written to HBM. The full N×N matrix never exists anywhere at any time.

But Wait — What About Softmax?

There's one serious obstacle. Softmax is a global operation. For each row i, softmax divides by the sum of exp(sij) across ALL columns j = 1, ..., N. If we only have one tile at a time, we don't have the global sum. We can't do softmax.

Or can we?

The answer is online softmax — an algorithm that computes softmax incrementally, one tile at a time, maintaining a running maximum and running sum. When a new tile arrives with larger values, it rescales all the previous partial results. At the end, the answer is identical to computing softmax over the entire row at once.

This is the algorithmic heart of FlashAttention, and we'll derive it step by step in the next chapter.

Memory Analysis

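Standard attention reads and writes the N×N matrix multiple times, so its HBM traffic grows with N². FlashAttention reads each K and V block once, but re-reads Q and re-reads and re-writes O once per outer-loop iteration. The K/V block size Bc is limited by the SRAM size M (roughly Bc ≈ M/(4d)). That gives about N/Bc ≈ 4Nd/M outer iterations, each moving O(Nd) elements. Here M is the SRAM available to one thread block, on the order of 100 KB on an A100, not the ~20 MB summed over all SMs:

$$\text{Standard HBM access: } O(Nd + N^2)$$
$$\text{FlashAttention HBM access: } O\!\left(\frac{N^2 d^2}{M}\right)$$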

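For typical head dimensions (d = 64 to 128), d² is many times smaller than M, so FlashAttention makes many times fewer HBM accesses. The FlashAttention paper benchmarks the forward plus backward pass of GPT-2 medium attention (sequence length 1024, head dimension 64, 16 heads, batch size 64) on an A100. There, FlashAttention cuts HBM reads and writes by 9×. It does 13% more FLOPs, mostly from recomputing attention in the backward pass, yet it is about 5.7× faster in wall-clock time.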

Metric | Standard | FlashAttention
GFLOPs | 66.6 | 75.2 (↑13%)
HBM reads/writes (GB) | 40.3 | 4.4 (↓9×)
Runtime (ms) | 41.7 | 7.3 (↓5.7×)
The lesson: Fewer FLOPs ≠ faster wall-clock. FlashAttention does MORE computation but moves LESS data. On a memory-bound workload, that's the correct trade.
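To see where the N²d²/M term comes from, here is a rough Python traffic model based on FlashAttention-1's loop order. K and V are read once. Each of the Tc = ⌈N/Bc⌉ outer iterations re-reads all of Q and O, writes O back, and reads and writes the per-row m and l. The standard count reuses the five-step model from Chapter 2 (4Nd + 8N² elements). Masking, dropout, and the backward pass are ignored, so the ratios are illustrative and are not expected to match the measured 9×.

```python
import math

def standard_hbm_elements(n, d):
    """Five-step standard attention: Q, K, V, O plus four N x N round trips."""
    return 4 * n * d + 8 * n * n

def flash1_hbm_elements(n, d, block_c):
    """FlashAttention-1 loop order: K/V outer loop, Q inner loop."""
    t_c = math.ceil(n / block_c)
    kv_reads = 2 * n * d               # each K_j, V_j loaded once
    per_outer = 3 * n * d + 4 * n      # read Q and O, write O; read + write m, l
    return kv_reads + t_c * per_outer

n, d = 4096, 64
for block_c in (32, 64, 128, 256):
    ratio = standard_hbm_elements(n, d) / flash1_hbm_elements(n, d, block_c)
    print(f"B_c={block_c:4d}: {ratio:5.1f}x less HBM traffic")
```

Doubling the block size (which needs roughly double the SRAM) roughly halves FlashAttention's traffic in this model. This is the 1/M in the bound.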
Check: What is the key obstacle to tiling the attention computation?

Chapter 5: Online Softmax — The Algorithmic Heart

This is the core derivation. If you understand this chapter, you understand FlashAttention. Let's build online softmax from scratch.

Standard Softmax

Given a vector x = [x1, x2, ..., xN], the softmax is:

$$\mathrm{softmax}(x)_i = \frac{\exp(x_i)}{\sum_{j=1}^{N} \exp(x_j)}$$

This requires two passes over the data: first compute all exp(xj) and sum them, then divide each by the sum. No output can be finalized until the first pass has seen all N values, because every output depends on the full sum.

Step 1: Numerical Stability

In practice, we first subtract the maximum value for numerical stability. If xmax = 100, then exp(100) ≈ 2.7 × 10⁴³ overflows float32, whose largest value is about 3.4 × 10³⁸. But exp(100 − 100) = exp(0) = 1 is fine. The result is identical because:

$$\frac{\exp(x_i - m)}{\sum_j \exp(x_j - m)} = \frac{\exp(x_i)}{\sum_j \exp(x_j)}$$

where m = max(x). The factor exp(−m) cancels between numerator and denominator. So stable softmax needs three passes: (1) find max, (2) compute exp and sum, (3) divide.
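A plain-Python sketch of the three passes. Python floats are 64-bit, so overflow needs larger scores than in float32 (exp overflows a double above roughly 709). Also, `math.exp` raises `OverflowError` rather than returning infinity. The inputs below are chosen with that in mind.

```python
import math

def naive_softmax(xs):
    exps = [math.exp(x) for x in xs]          # overflows for large x
    total = sum(exps)
    return [e / total for e in exps]

def stable_softmax(xs):
    m = max(xs)                               # pass 1: find the max
    exps = [math.exp(x - m) for x in xs]      # pass 2: exponentiate and sum
    total = sum(exps)
    return [e / total for e in exps]          # pass 3: normalize

scores = [1000.0, 1001.0]
try:
    naive_softmax(scores)
except OverflowError as err:
    print("naive softmax failed:", err)
print(stable_softmax(scores))
```

On small inputs both versions agree. Only the stable one survives large scores.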

Step 2: The Tiling Problem

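Now suppose we split x into two blocks: x = [x(1), x(2)]. In attention, each score xj comes paired with a value vector vj, and what we actually want is the softmax-weighted sum of the vj. We process block 1 first, then block 2. After block 1, we have:

$$m^{(1)} = \max\big(x^{(1)}\big)$$
$$l^{(1)} = \sum_{j \in \text{block } 1} \exp\big(x_j - m^{(1)}\big)$$
$$o^{(1)} = \frac{1}{l^{(1)}} \sum_{j \in \text{block } 1} \exp\big(x_j - m^{(1)}\big)\, v_j$$

This is a valid softmax-weighted average of block 1's values. But the denominator l(1) only reflects block 1, so as a global softmax it's wrong.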

Step 3: Rescaling After Block 2

When we process block 2, we discover a new maximum and new sum. We need to fix our earlier result. Here's the algebra, step by step.

After processing block 2, we have:

$$m^{(2)} = \max\big(x^{(2)}\big)$$
$$m_{\text{new}} = \max\big(m^{(1)}, m^{(2)}\big)$$

We need to rescale block 1's exponentials. Originally we computed exp(xj − m(1)). We need exp(xj − mnew). The correction factor is:

$$\exp(x_j - m_{\text{new}}) = \exp\big(x_j - m^{(1)}\big) \cdot \exp\big(m^{(1)} - m_{\text{new}}\big)$$

So we multiply all of block 1's contributions by exp(m(1) − mnew). Since mnew ≥ m(1), this factor is ≤ 1 (it shrinks old values that were computed with a smaller max). Similarly for block 2's terms: multiply by exp(m(2) − mnew).

The corrected sums:

$$l_{\text{new}} = l^{(1)} \exp\big(m^{(1)} - m_{\text{new}}\big) + l^{(2)} \exp\big(m^{(2)} - m_{\text{new}}\big)$$

And the corrected output:

$$o_{\text{new}} = \frac{1}{l_{\text{new}}} \Big[\, l^{(1)} \exp\big(m^{(1)} - m_{\text{new}}\big)\, o^{(1)} + l^{(2)} \exp\big(m^{(2)} - m_{\text{new}}\big)\, o^{(2)} \Big]$$
This is exact. No approximation whatsoever. The rescaling undoes the effect of the old max and applies the new global max. At the end of all blocks, mnew is the true global max, lnew is the true global sum, and the output is identical to standard attention.
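The two-block update can be written as a merge of two summaries (m, l, o), each computed from one block alone. To keep it short, this Python sketch uses scalar values vⱼ; vectors work the same way, elementwise. It checks the merged result against a single summary over the whole row. The function names are illustrative.

```python
import math

def block_summary(xs, vs):
    """Local max m, local sum l, and local softmax-weighted average o for one block."""
    m = max(xs)
    weights = [math.exp(x - m) for x in xs]
    l = sum(weights)
    o = sum(w * v for w, v in zip(weights, vs)) / l
    return m, l, o

def merge(s1, s2):
    """Combine two block summaries exactly, using the rescaling formulas above."""
    (m1, l1, o1), (m2, l2, o2) = s1, s2
    m_new = max(m1, m2)
    a1 = l1 * math.exp(m1 - m_new)    # block 1's sum, rescaled to the new max
    a2 = l2 * math.exp(m2 - m_new)    # block 2's sum, rescaled to the new max
    l_new = a1 + a2
    o_new = (a1 * o1 + a2 * o2) / l_new
    return m_new, l_new, o_new

x = [0.3, 4.0, -1.2, 2.5, 5.1, 0.0]
v = [1.0, -2.0, 0.5, 3.0, 1.5, -0.7]
merged = merge(block_summary(x[:3], v[:3]), block_summary(x[3:], v[3:]))
direct = block_summary(x, v)   # one block covering the whole row
print(merged)
print(direct)
```

Because the merge is exact, blocks can be combined in any grouping. The sequential recurrence in Step 4 folds one block at a time into a running summary.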

Step 4: The General Recurrence

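For B blocks processed sequentially, start from m0 = −∞, l0 = 0, O0 = 0. After processing block j (scores Sj, values Vj), we update:

$$m_j = \max\big(m_{j-1}, \max(S_j)\big)$$
$$l_j = l_{j-1} \exp(m_{j-1} - m_j) + \sum \exp(S_j - m_j)$$
$$O_j = \frac{l_{j-1}}{l_j} \exp(m_{j-1} - m_j)\, O_{j-1} + \frac{1}{l_j} \exp(S_j - m_j)\, V_j$$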

Each row carries three running quantities: two scalars (the maximum m and the sum l) and the output vector O of length d. These are small enough to stay on-chip, in registers or SRAM. The N×N matrix tiles are computed and discarded in SRAM.
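The recurrence translates almost line for line into code. This Python sketch processes one query row's scores in tiles of a chosen width, using d-dimensional value vectors stored as lists. It starts from m₀ = −∞, l₀ = 0, O₀ = 0 and compares the result with a direct softmax-weighted sum. The tile width and inputs are arbitrary.

```python
import math
import random

def online_attention_row(scores, values, tile):
    """One query row: scores has length N, values is N vectors of length d."""
    d = len(values[0])
    m, l, o = -math.inf, 0.0, [0.0] * d
    for start in range(0, len(scores), tile):
        s_blk = scores[start:start + tile]
        v_blk = values[start:start + tile]
        m_new = max(m, max(s_blk))
        scale = l * math.exp(m - m_new)              # rescale old running sum
        p = [math.exp(s - m_new) for s in s_blk]     # this tile's unnormalized weights
        l_new = scale + sum(p)
        o = [(scale * o[k] + sum(pi * v[k] for pi, v in zip(p, v_blk))) / l_new
             for k in range(d)]
        m, l = m_new, l_new
    return o

def direct_attention_row(scores, values):
    m = max(scores)
    w = [math.exp(s - m) for s in scores]
    total = sum(w)
    return [sum(wi * v[k] for wi, v in zip(w, values)) / total
            for k in range(len(values[0]))]

rng = random.Random(0)
N, d = 10, 3
scores = [rng.uniform(-3, 3) for _ in range(N)]
values = [[rng.gauss(0, 1) for _ in range(d)] for _ in range(N)]
print(online_attention_row(scores, values, tile=4))
print(direct_attention_row(scores, values))
```

On the first tile, exp(m₀ − m₁) = exp(−∞) = 0, so the empty initial state contributes nothing and no special case is needed.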

Worked Example: 4×4 Matrix in 2×2 Tiles

Let's compute attention for a concrete 4×4 score matrix, tiled into 2×2 blocks, to see the online softmax in action.

Score matrix S (after QKᵀ):

row | j=1 | j=2 | j=3 | j=4
i=1 | 1.0 | 2.0 | 0.5 | 0.1
i=2 | 0.3 | 1.5 | 2.0 | 0.8

(Showing rows 1–2 for clarity. The process is identical for rows 3–4.)

Block 1 (columns 1–2 for row 1): scores = [1.0, 2.0]

m(1) = max(1.0, 2.0) = 2.0
exp(1.0 − 2.0) = 0.368,   exp(2.0 − 2.0) = 1.000
l(1) = 0.368 + 1.000 = 1.368
o(1) = (0.368 · v1 + 1.000 · v2) / 1.368

Block 2 (columns 3–4 for row 1): scores = [0.5, 0.1]

m(2) = max(0.5, 0.1) = 0.5
mnew = max(2.0, 0.5) = 2.0  (no change)
Correction: exp(m(1) − mnew) = exp(0) = 1.0  (old values unchanged)
exp(0.5 − 2.0) = 0.223,   exp(0.1 − 2.0) = 0.150
lnew = 1.368 · 1.0 + (0.223 + 0.150) = 1.741
onew = (1.368/1.741) · o(1) + (0.223 · v3 + 0.150 · v4) / 1.741

That final onew is exactly the same as computing softmax([1.0, 2.0, 0.5, 0.1]) and multiplying by [v1, v2, v3, v4]. You can verify: 0.368/1.741 = 0.211, 1.000/1.741 = 0.574, 0.223/1.741 = 0.128, 0.150/1.741 = 0.086. These sum to 1.0 (within rounding).
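The worked example can be replayed in a few lines of Python to check the hand-computed numbers. This sketch tracks only the weights, not value vectors. It prints m and l after each tile and then the final weights exp(sⱼ − m)/l.

```python
import math

row = [1.0, 2.0, 0.5, 0.1]       # scores for query row 1
m, l = -math.inf, 0.0
for start in (0, 2):             # two tiles of width 2
    tile = row[start:start + 2]
    m_new = max(m, max(tile))
    l = l * math.exp(m - m_new) + sum(math.exp(s - m_new) for s in tile)
    m = m_new
    print(f"after tile {start // 2 + 1}: m = {m:.1f}, l = {l:.3f}")

weights = [math.exp(s - m) / l for s in row]
print([round(w, 3) for w in weights])
```

At full precision the second weight rounds to 0.575 rather than 0.574. The text's value comes from dividing already-rounded intermediates.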

[Simulation: Online Softmax, Tile-by-Tile Computation. Processes one K/V block at a time, showing the running max (m), running sum (l), and output weights after each of 4 tiles. The final weights match global softmax exactly.]
Verification: The online softmax produces IDENTICAL results to standard softmax. It's not an approximation. The rescaling factors are exact correction terms. This is why FlashAttention can be a drop-in replacement for standard attention.
Check: When processing a new tile whose max is smaller than the running max, what happens to the correction factor?

Chapter 6: The Full FlashAttention Algorithm

Now we put it all together. We've established three principles: (1) tile to fit in SRAM, (2) online softmax to combine tiles exactly, (3) never write the N×N matrix to HBM. Here's the complete forward pass.

Algorithm: FlashAttention Forward Pass

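python
# FlashAttention-1 forward pass, emulated with NumPy on the CPU (requires NumPy).
# Comments mark where each array would live on a GPU (HBM vs. SRAM).
# Assumes the 1/sqrt(d) softmax scale is already folded into Q; no mask or dropout.
import numpy as np

def flash_attention_forward(Q, K, V, B_r=64, B_c=64):
    N, d = Q.shape
    O = np.zeros((N, d))        # in HBM
    m = np.full(N, -np.inf)     # in HBM, running max (one per row)
    l = np.zeros(N)             # in HBM, running sum (one per row)

    for j in range(0, N, B_c):                    # outer loop over K/V blocks (T_c of them)
        K_j, V_j = K[j:j + B_c], V[j:j + B_c]     # load to SRAM

        for i in range(0, N, B_r):                # inner loop over Q blocks (T_r of them)
            rows = slice(i, i + B_r)
            Q_i, O_i, m_i, l_i = Q[rows], O[rows], m[rows], l[rows]  # load to SRAM

            # Tile of attention scores (B_r x B_c); never written to HBM
            S_ij = Q_i @ K_j.T

            # Online softmax update
            m_new = np.maximum(m_i, S_ij.max(axis=1))
            P_ij = np.exp(S_ij - m_new[:, None])
            l_scaled = l_i * np.exp(m_i - m_new)  # old sum, rescaled to the new max
            l_new = l_scaled + P_ij.sum(axis=1)
            O_new = (l_scaled[:, None] * O_i + P_ij @ V_j) / l_new[:, None]

            # Write updated O, m, l back to HBM
            O[rows], m[rows], l[rows] = O_new, m_new, l_new

    return O


# Sanity check against the naive version that materializes the N x N matrix.
rng = np.random.default_rng(0)
Q, K, V = (rng.standard_normal((256, 64)) for _ in range(3))
S = Q @ K.T
A = np.exp(S - S.max(axis=1, keepdims=True))
A /= A.sum(axis=1, keepdims=True)
assert np.allclose(flash_attention_forward(Q, K, V), A @ V)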
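Loop order matters. The outer loop is over K/V blocks and the inner loop over Q blocks, so each K/V block is loaded from HBM exactly once and reused across all Q blocks. The cost of this order is per-iteration traffic for the Q side. Qi, Oi, mi, and li are read from HBM, and Oi, mi, li written back, once per outer iteration, Tc times in total. FlashAttention-2 revisits this choice (Chapter 7).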

What Gets Stored Where?

Data | Lives in | Size | Lifetime
Q, K, V, O | HBM | N × d each | Entire computation
m, l | HBM (+ registers) | N each (one per row) | Entire computation
Qi, Oi, Kj, Vj tiles | SRAM | Br × d, Bc × d | One iteration
Sij tile | SRAM | Br × Bc | One iteration, then discarded

Total SRAM needed: Br × d for Qi, Br × d for Oi, 2 × Bc × d for Kj and Vj, and Br × Bc for Sij, plus 2Br values for the running m and l. An A100 has ~192 KB of shared memory per SM. One plausible configuration there is Br = Bc = 128 with d = 64. In BF16 that's 128 × 64 × 4 × 2 bytes = 64 KB for the Q, O, K, V tiles, plus 128 × 128 × 2 = 32 KB for Sij. Total: about 96 KB. Fits comfortably.
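A small budget calculator, as a Python sketch. It counts the Oi tile and the per-row m and l, which the forward pass above loads into SRAM, and it assumes 2-byte BF16 storage for everything. The limit it checks against is the up-to-164 KB of the A100's 192 KB per-SM L1/shared pool that can be configured as shared memory. Real kernels often keep S and the output accumulator in FP32 or in registers, so treat the totals as rough.

```python
SHARED_MEM_PER_SM = 164 * 1024   # A100: max configurable shared memory per SM

def tile_bytes(b_r, b_c, d, bytes_per_elem=2):
    q_and_o = 2 * b_r * d        # Q_i and O_i tiles
    k_and_v = 2 * b_c * d        # K_j and V_j tiles
    s_tile = b_r * b_c           # S_ij scores for this tile pair
    stats = 2 * b_r              # running m and l for the rows of Q_i
    return (q_and_o + k_and_v + s_tile + stats) * bytes_per_elem

for b_r, b_c, d in [(128, 128, 64), (128, 128, 128), (256, 256, 128)]:
    used = tile_bytes(b_r, b_c, d)
    verdict = "fits" if used <= SHARED_MEM_PER_SM else "does not fit"
    print(f"B_r={b_r}, B_c={b_c}, d={d}: {used / 1024:.1f} KB -> {verdict}")
```

Larger head dimensions push the budget toward the limit, which is one reason block sizes are tuned per head dimension.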

Backward Pass: Recomputation

Training requires a backward pass. Normally, the forward pass saves the N×N attention matrix for use in backpropagation. FlashAttention doesn't have it — we threw it away. The solution: recompute the attention tiles during the backward pass.

This sounds wasteful. We do more FLOPs (about 13% more). But remember: the bottleneck is memory bandwidth, not compute. By recomputing tiles in SRAM instead of reading a huge cached matrix from HBM, the backward pass is still faster.

The only things saved from the forward pass: the output O, and the per-row m and l vectors (total: 2N scalars plus the N × d output). From these, plus the original Q, K, V, we can reconstruct any tile of the attention matrix on the fly.

Key insight: Recomputation beats caching when you're memory-bound. It trades cheap FLOPs for expensive memory accesses. This is the "caching vs. recomputation" principle from systems design — for compute-bound workloads, cache; for memory-bound workloads, recompute.
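Reconstructing a tile from the saved statistics costs one extra exponential per entry: Pij = exp(Sij − mi) / li. This Python sketch uses small random Q and K as nested lists, with the 1/√d scale omitted as elsewhere in this section. It recomputes one probability tile from Q, K, m, l and checks it against the same entries recomputed for the full matrix. Helper names are illustrative.

```python
import math
import random

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

rng = random.Random(0)
N, d = 8, 4
Q = [[rng.gauss(0, 1) for _ in range(d)] for _ in range(N)]
K = [[rng.gauss(0, 1) for _ in range(d)] for _ in range(N)]

# Forward pass: keep only per-row statistics m and l. Each score row is
# built, summarized, and discarded; the N x N matrix is never stored.
m, l = [], []
for q in Q:
    row = [dot(q, k) for k in K]
    m.append(max(row))
    l.append(sum(math.exp(s - m[-1]) for s in row))

def recompute_probs(rows, cols):
    """Backward pass: rebuild the softmax probabilities of one tile from Q, K, m, l."""
    return [[math.exp(dot(Q[i], K[j]) - m[i]) / l[i] for j in cols] for i in rows]

tile = recompute_probs(range(2, 4), range(4, 8))   # a 2 x 4 tile
full = recompute_probs(range(N), range(N))          # only to check the result
print([round(sum(r), 12) for r in full])            # every row sums to 1
```

The only state carried from the forward pass is 2N numbers (m and l). Everything else is recomputed from Q and K.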
[Simulation: FlashAttention Tiling Simulator. Processes an 8×8 score grid in 2×2 tiles, one tile at a time. Orange is the current tile in SRAM, teal marks completed (discarded) tiles, and purple marks the currently loaded Q and K/V blocks. The full N×N matrix never exists at once.]
Check: In FlashAttention's backward pass, why is recomputing the attention tiles faster than reading a cached N×N matrix from HBM?

Chapter 7: FlashAttention-2 & Beyond

FlashAttention-1 was a breakthrough. FlashAttention-2 (Dao, 2023) made it even faster by fixing inefficiencies in how work is distributed across GPU thread blocks and warps.

FlashAttention-2 Improvements

1. Better work partitioning. FA-1's outer loop runs over K/V blocks, so every Q block's partial output Oi, along with its mi and li, is written to HBM and read back once per outer iteration. FA-2 swaps the loops: the outer loop iterates over Q blocks and the inner loop over K/V blocks. Each thread block then owns one Q block and keeps its output accumulator on-chip until it is finished. Within a thread block, FA-2 also splits the work across warps by Q rows rather than by K/V columns, so warps no longer have to exchange partial results through shared memory.

2. Reduced non-matmul FLOPs. GPU tensor cores make matrix multiplication far cheaper than other arithmetic. On an A100, BF16 matmul runs at 312 TFLOPS versus 19.5 TFLOPS for non-matmul FP32, so each non-matmul FLOP costs about 16× more. The rescaling, max-finding, and bookkeeping in online softmax are non-matmul work. FA-2 trims them, for example by dividing the output by the running sum once at the end instead of at every tile, which keeps the tensor cores busier.

3. More parallelism. FA-1 parallelizes over batch size and number of attention heads. FA-2 also parallelizes over the sequence length dimension. Long sequences usually come with a small batch size or few heads, and this extra parallelism keeps more of the GPU's SMs busy.

The result: FlashAttention-2 runs roughly 2× faster than FlashAttention-1. It reaches 50–73% of the theoretical maximum FLOPS on A100 GPUs, close to the efficiency of optimized matrix-multiplication (GEMM) kernels.
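The deferred-normalization idea from point 2 can be sketched for a single query row in Python. Both variants run the same online-softmax recurrence. With `defer_division=False`, the output is divided by the running sum after every tile, as in the FlashAttention-1 formulas. With `defer_division=True`, an unnormalized accumulator is kept and divided once at the end. The function and flag names are illustrative, and this is scalar Python, not a GPU kernel.

```python
import math
import random

def online_row(scores, values, tile, defer_division):
    d = len(values[0])
    m, l, acc = -math.inf, 0.0, [0.0] * d
    for start in range(0, len(scores), tile):
        s_blk, v_blk = scores[start:start + tile], values[start:start + tile]
        m_new = max(m, max(s_blk))
        alpha = math.exp(m - m_new)                  # rescale factor for old state
        p = [math.exp(s - m_new) for s in s_blk]
        l_new = l * alpha + sum(p)
        contrib = [sum(pi * v[k] for pi, v in zip(p, v_blk)) for k in range(d)]
        if defer_division:
            # FA-2 style: acc stays unnormalized; no division inside the loop.
            acc = [alpha * a + c for a, c in zip(acc, contrib)]
        else:
            # FA-1 style: acc is a normalized output, divided by l_new every tile.
            acc = [(l * alpha * a + c) / l_new for a, c in zip(acc, contrib)]
        m, l = m_new, l_new
    return [a / l for a in acc] if defer_division else acc

rng = random.Random(2)
scores = [rng.uniform(-4, 4) for _ in range(12)]
values = [[rng.gauss(0, 1) for _ in range(4)] for _ in range(12)]
fa1 = online_row(scores, values, tile=5, defer_division=False)
fa2 = online_row(scores, values, tile=5, defer_division=True)
print(fa1)
print(fa2)
```

The outputs agree. The deferred version does d divisions per row in total instead of d per tile.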

FlashAttention-3

FlashAttention-3 (2024) targets the H100 GPU and exploits its new features: asynchronous memory copies (overlap compute and data movement), low-precision matmuls (FP8 tensor cores), and the TMA (Tensor Memory Accelerator) unit for efficient tiling. It also introduces software pipelining — while one tile is being computed, the next tile is being loaded from HBM to SRAM simultaneously.

Combining FlashAttention with Sparse Attention

Here's an elegant insight: FlashAttention and sparse attention are complementary. FlashAttention makes exact attention faster. Sparse attention reduces the number of tiles that need processing. Combine them: use FlashAttention for the tiles you DO compute, and skip the tiles you don't. Mistral 7B (2023) takes this approach. It pairs sliding-window attention, in which each token attends to at most the previous 4,096 tokens, with FlashAttention kernels.
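How much work can block-sparse masking save? This Python sketch counts the tiles a tiled kernel must visit under a causal sliding-window mask, where query i attends to keys j with i − w < j ≤ i. Tiles that contain no allowed entry are skipped. The tile size, window, and function name are illustrative assumptions, not Mistral's or FlashAttention's actual configuration.

```python
def tiles_needed(n, tile, window=None):
    """Count (query-tile, key-tile) pairs containing at least one allowed entry.

    Allowed: key j <= query i (causal) and, if window is set, j > i - window.
    """
    count = 0
    for qi in range(0, n, tile):
        q_lo, q_hi = qi, min(qi + tile, n) - 1
        for kj in range(0, n, tile):
            k_lo, k_hi = kj, min(kj + tile, n) - 1
            if k_lo > q_hi:
                continue        # entirely in the future: fully masked
            if window is not None and k_hi <= q_lo - window:
                continue        # entirely outside the sliding window
            count += 1
    return count

n, tile = 4096, 128
full = (n // tile) ** 2
for label, w in [("causal", None), ("causal + window 1024", 1024)]:
    k = tiles_needed(n, tile, w)
    print(f"{label:22s}: {k} of {full} tiles ({k / full:.0%})")
```

With a window, the number of visited tiles grows linearly in N instead of quadratically, and every visited tile is still computed exactly.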

Version | Year | Key Innovation | Speedup vs Previous
FlashAttention-1 | 2022 | Tiling + online softmax + recomputation | ~5.7× vs standard (GPT-2 medium benchmark)
FlashAttention-2 | 2023 | Better parallelism + work partitioning | ~2× vs FA-1
FlashAttention-3 | 2024 | H100 features: FP8, TMA, pipelining | ~1.5–2× vs FA-2 (on H100)
FA + Sparse | 2022+ | Skip tiles via sparsity masks | Additional gains on top

Impact on Training and Inference

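FlashAttention didn't just save memory; it changed what's possible. The FlashAttention paper reports:

  • 15% faster end-to-end training of BERT-large (sequence length 512) than the MLPerf 1.1 training speed record
  • 3× faster training of GPT-2 (sequence length 1K), and 2.4× faster on Long-Range Arena (sequence lengths 1K–4K)
  • The first Transformers to do better than chance on Path-X (sequence length 16K) and Path-256 (sequence length 64K)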

The meta-lesson: FlashAttention is not a new model architecture. It computes the exact same thing as standard attention. The entire speedup comes from understanding and respecting the hardware memory hierarchy. Algorithms and systems are not separate disciplines — the best algorithms are designed with hardware in mind.
Check: What is the main improvement in FlashAttention-2 over FlashAttention-1?

Chapter 8: Connections & Cheat Sheet

FlashAttention Cheat Sheet

Concept | One-Liner
The bottleneck | Standard attention is memory-bound: the N×N matrix keeps bouncing between SRAM and HBM
Arithmetic intensity | FLOPs / Bytes accessed. If below the GPU's ratio, you're memory-bound
Tiling | Break Q, K, V into blocks that fit in SRAM. Compute a tile of S, use it, discard it
Online softmax | Running max + running sum let you compute softmax incrementally across tiles
Rescaling | exp(m_old − m_new) corrects old partial results when a new tile has larger values
Kernel fusion | One GPU kernel does matmul + mask + softmax + dropout + matmul instead of five separate ones
Recomputation | Recompute attention tiles in the backward pass instead of caching the N×N matrix. Trades cheap FLOPs for expensive memory accesses
Exact output | FlashAttention computes the same result as standard attention. Zero approximation
HBM access | Standard: O(Nd + N²). FlashAttention: O(N²d²/M). About 9× fewer bytes moved in the GPT-2 medium benchmark

Key Numbers (A100)

Resource | Value
HBM capacity | 40–80 GB
HBM bandwidth | ~2 TB/s
SRAM capacity (total) | ~20 MB
SRAM bandwidth | ~19 TB/s
BF16 compute | 312 TFLOPS
Compute/Memory ratio | 161 FLOPs/Byte

Related Topics

Upstream:

  • Attention mechanism
  • Memory hierarchy

Downstream:

  • Quantization — reducing model size while FlashAttention reduces memory access
  • State Space Models — an alternative to attention for long sequences
  • SSM / Mamba — linear-time sequence modeling
Closing thought: "The purpose of abstraction is not to be vague, but to create a new semantic level in which one can be absolutely precise." — Edsger Dijkstra. FlashAttention succeeds because it breaks the abstraction between algorithm and hardware, then builds a better one.
Final check: FlashAttention computes exact attention 6× faster. Where does the speedup come from?