The attention kernel that co-designs algorithm and hardware pipelining for NVIDIA Blackwell GPUs — exploiting tensor memory, 2-CTA MMA, software-emulated exponentials, and conditional softmax rescaling to reach 1613 TFLOPS (71% utilization) on B200.
You just got access to an NVIDIA B200 GPU. It has 2.25 PFLOPS of dense BF16 tensor core throughput, more than double the H100's. You port your FlashAttention-3 design to it (FA3 itself relies on Hopper-only MMA instructions and cannot run on Blackwell unmodified) and expect a 2x speedup. Instead, you get... almost nothing. The kernel barely runs faster. What happened?
The answer is asymmetric hardware scaling. NVIDIA more than doubled the tensor core throughput from Hopper to Blackwell (about 1 PFLOPS to 2.25 PFLOPS for dense BF16), but they did not scale everything else. Shared memory bandwidth stayed at 128 bytes/clock/SM. The exponential unit (MUFU) stayed at 16 ops/clock/SM. Integer and floating-point ALU throughput stayed roughly the same as well. The GPU got faster at matrix multiplies, but everything else stayed put.
This means the bottleneck has shifted. On Hopper, the tensor cores were the limiting factor — you designed your kernel to keep them busy. On Blackwell, the tensor cores finish so fast that they spend most of their time waiting — waiting for data from shared memory, waiting for the exponential unit to finish softmax, waiting for non-matmul operations to complete.
FlashAttention has evolved across four GPU generations, each time adapting to the hardware's shifting bottlenecks:
| Version | Year | Target GPU | Key Innovation | Bottleneck Addressed |
|---|---|---|---|---|
| FA1 | 2022 | A100 | Tiled attention + online softmax | HBM bandwidth (no materializing N×N matrix) |
| FA2 | 2023 | A100/H100 | Reduced non-matmul FLOPs, parallelized over seqlen | Non-matmul operations on GPU cores |
| FA3 | 2024 | H100 (Hopper) | Warp specialization, async TMA, FP8 support | Overlap compute with memory via Hopper async features |
| FA4 | 2026 | B200 (Blackwell) | Pingpong pipeline, software exp, conditional rescaling, 2-CTA MMA | Non-matmul bottleneck from asymmetric scaling |
Each generation didn't just "go faster" — it fundamentally restructured which operations are on the critical path and which can be hidden behind tensor core compute. FA4 is the most radical restructuring yet, because the imbalance on Blackwell is the most extreme yet.
How does FA4 pull this off? Through four interlocking innovations that we will build up piece by piece over the next nine chapters. But first, we need to understand the hardware.
To understand FA4's design decisions, you need a mental model of the Blackwell GPU's memory hierarchy and execution units. Every optimization in FA4 targets a specific hardware resource. Let's map them out.
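Data on a Blackwell GPU lives in a hierarchy of storage, each level trading capacity for speed: HBM (global memory), the L2 cache, per-SM shared memory (SMEM), and, closest to the compute units, registers plus the new per-SM tensor memory (TMEM) that holds MMA accumulators.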
Blackwell's tensor cores process 128 × N tiles (N = 128 or 256), compared to Hopper's 64 × N tiles. The throughput is 8192 FLOPs/clock/SM for BF16 — exactly double Hopper's 4096. Each MMA instruction writes its output directly to TMEM asynchronously, meaning computation and other operations can overlap without blocking on register writeback.
The execution hierarchy, from finest to coarsest:
| Unit | Size | Role in FA4 |
|---|---|---|
| Thread | 1 | Smallest execution unit |
| Warp | 32 threads | SIMT execution group |
| Warpgroup | 4 warps = 128 threads | One warpgroup per accumulator tile. FA4 uses 2 warpgroups for pingpong. |
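| CTA (threadblock) | One or more warpgroups (up to 1024 threads) | Co-scheduled on one SM. Shares SMEM. |
| Thread Block Cluster | Multiple CTAs | Co-scheduled on the same GPC. CTAs can access each other's SMEM via distributed shared memory (DSMEM); FA4's 2-CTA MMA mode pairs two CTAs of a cluster. |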
The execution units that matter most for attention:
| Unit | Throughput (per clock per SM) | Role |
|---|---|---|
| Tensor cores (BF16 MMA) | 8192 FLOPs | QKᵀ and PV matrix multiplies |
| Exponential unit (MUFU) | 16 ops | exp() for softmax (same as Hopper!) |
| Shared memory read | 128 bytes | Feeding operands to MMA (same as Hopper!) |
| TMA (Tensor Memory Accelerator) | Async, non-blocking | HBM ↔ SMEM transfers without tying up SM threads |
Blackwell introduces a new MMA mode where a CTA pair within the same thread block cluster cooperatively executes a single MMA. One CTA initiates the operation; the peer CTA must be launched and active. The paired mode supports M = 128 or 256 by partitioning the A tile in dimension M and the B tile across the two CTAs in dimension N, so each CTA stages only half of operand B in its own shared memory while the hardware consumes the combined B tile. This reduces redundant shared memory capacity and bandwidth.
For FA4's backward pass, this 2-CTA mode is critical: it lets the kernel use M = 256 tiles (vs M = 128 in single-CTA mode), roughly halving the shared memory traffic for operand B and halving the number of global atomic reductions for dQ accumulation.
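To see what "each CTA stages only half of operand B" buys, here is a minimal byte count. It assumes a 128 × 128 BF16 operand B tile shared by two CTAs that each produce 128 of the 256 output rows, and it counts only operand B staging (not operand A or the output). `operand_b_bytes_per_cta` is a hypothetical helper for illustration.

```python
BF16_BYTES = 2  # bytes per BF16 element


def operand_b_bytes_per_cta(n: int, k: int, two_cta: bool) -> int:
    """SMEM bytes one CTA stages for an N x K BF16 operand B tile.

    In 2-CTA mode the B tile is split across the CTA pair along N,
    so each CTA holds only half of it.
    """
    full_tile = n * k * BF16_BYTES
    return full_tile // 2 if two_cta else full_tile


# Two CTAs, each producing 128 output rows against the same 128 x 128 B tile.
separate = 2 * operand_b_bytes_per_cta(128, 128, two_cta=False)  # two independent M = 128 MMAs
paired = 2 * operand_b_bytes_per_cta(128, 128, two_cta=True)     # one M = 256 paired MMA
print(separate, paired)  # 65536 32768
```

Two independent M = 128 MMAs stage the same B tile twice; the paired mode stages it once in total, which is the redundancy the 2-CTA mode removes.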
Data flows from HBM through L2 and SMEM to tensor cores. TMEM (new in Blackwell) stores MMA outputs directly, freeing registers.
Before writing a single line of kernel code, FA4's authors do something crucial: a roofline analysis. This means calculating, for each hardware resource, how many clock cycles it needs to process one tile of attention. The resource that takes the most cycles is the bottleneck — and all kernel design effort should go toward either speeding it up or hiding it behind something else.
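Let the tile shape along the sequence length of Q and K be M × N, and the head dimension be d. The forward pass performs two MMAs per iteration of the inner loop:
$$S = QK^\top \quad (M \times d \text{ times } d \times N), \qquad O \leftarrow O + PV \quad (M \times N \text{ times } N \times d)$$
Each MMA requires 2MNd floating-point operations, so at 8192 FLOPs/cycle the total MMA compute time is:
$$T_{\text{MMA}} = \frac{2 \times 2MNd}{8192} \text{ cycles}$$
For shared memory traffic, the QKᵀ MMA reads both operands from SMEM (shared-shared), while PV reads P from tensor memory and V from SMEM (tensor-shared). At 128 bytes/cycle and 2 bytes per BF16 element:
$$T_{\text{SMEM}} = \frac{2\,(Md + Nd) + 2\,Nd}{128} \text{ cycles}$$
The exponential unit computes exp() for M × N values in the softmax, at 16 ops/cycle:
$$T_{\text{exp}} = \frac{MN}{16} \text{ cycles}$$
For M = N = d = 128 these give the first column below. Every entry in the 256 × 128² column is exactly twice the first, consistent with an SM processing two 128-row Q tiles, each issuing its own MMAs against the same K and V tiles.
| Resource (cycles) | 128³ (M = N = d = 128) | 256 × 128² |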
|---|---|---|
| MMA compute | 1024 cycles | 2048 cycles |
| Shared memory | 768 cycles | 1536 cycles |
| Exponential unit | 1024 cycles | 2048 cycles |
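The forward roofline numbers can be reproduced with a few lines of arithmetic. This sketch assumes the per-SM rates quoted earlier (8192 BF16 FLOPs, 128 SMEM bytes and 16 exponentials per clock), that QKᵀ reads Q and K from SMEM, and that PV reads only V from SMEM because P sits in tensor memory. `forward_cycles` is a hypothetical helper; passing Hopper's 4096 FLOPs/clock keeps the same 128-row tile for simplicity, even though Hopper's MMA tiles are 64 rows.

```python
BF16_BYTES = 2
SMEM_BYTES_PER_CYCLE = 128  # per SM, same on Hopper and Blackwell
EXP_OPS_PER_CYCLE = 16      # MUFU, per SM, same on Hopper and Blackwell


def forward_cycles(m: int, n: int, d: int, mma_flops_per_cycle: int = 8192) -> dict:
    """Cycles each resource needs for one M x N x d forward tile step."""
    # Two MMAs, S = Q K^T and O += P V, each 2*M*N*d FLOPs.
    mma = 2 * (2 * m * n * d) / mma_flops_per_cycle
    # Q K^T reads Q (M x d) and K (N x d) from SMEM; P V reads only V (N x d).
    smem_bytes = BF16_BYTES * (m * d + n * d) + BF16_BYTES * n * d
    smem = smem_bytes / SMEM_BYTES_PER_CYCLE
    # One exponential per attention score.
    exp_unit = m * n / EXP_OPS_PER_CYCLE
    return {"mma": mma, "smem": smem, "exp": exp_unit}


blackwell = forward_cycles(128, 128, 128)                         # 8192 FLOPs/clock/SM
hopper = forward_cycles(128, 128, 128, mma_flops_per_cycle=4096)  # 4096 FLOPs/clock/SM
print(blackwell)  # {'mma': 1024.0, 'smem': 768.0, 'exp': 1024.0}
print(hopper)     # {'mma': 2048.0, 'smem': 768.0, 'exp': 1024.0}
```

On the Hopper-rate run the MMA dominates; at Blackwell's rate the exponential unit ties with the MMA, which is the shift this section describes.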
The backward pass performs five MMAs per iteration (recomputing S, plus computing dP, dV, dQ, dK). The shared memory traffic becomes the dominant bottleneck:
| Resource (cycles) | 1-CTA (M=128) | 2-CTA (M=256) |
|---|---|---|
| MMA compute | 2560 | 2560 |
| Total shared memory | 3328 | 2688 |
| Exponential unit | 1024 | 1024 |
Shared memory traffic exceeds MMA compute by 30% in 1-CTA mode (3328 vs 2560 cycles). Even in 2-CTA mode, which halves the operand B reads per CTA, it still exceeds MMA compute by about 5% (2688 vs 2560). This is why FA4 uses the 2-CTA MMA mode for the backward pass: it brings shared memory traffic to within a few percent of MMA compute time.
Cycle counts for each hardware resource on the forward pass, on Hopper and on Blackwell. The tallest bar is the bottleneck.
The core scheduling idea in FA4 is the same as FA3: pingpong between two warpgroups. While one warpgroup's tile runs tensor core operations (QKᵀ or PV), the other warpgroup computes softmax. They alternate, keeping both the tensor cores and the non-matmul units busy simultaneously.
But on Blackwell, this is harder to get right. The accumulators now live in TMEM (not registers), the tile sizes are 128×128 (not 64×128), and the imbalance between MMA speed and everything else is more extreme. FA4 redesigns the pipeline from scratch.
Each SM runs one CTA (threadblock) with two warpgroups of 128 threads each. At any given moment:
| Phase | Warpgroup A | Warpgroup B |
|---|---|---|
| Even iteration | MMA: compute $S^H = Q^H K_j^\top$, then $O^H \mathrel{+}= P^H V_j$ | Softmax: load $S^L$ from TMEM, compute max, rescale, exp, row sum |
| Odd iteration | Softmax: load $S^H$ from TMEM, compute max, rescale, exp, row sum | MMA: compute $S^L = Q^L K_{j+1}^\top$, then $O^L \mathrel{+}= P^L V_{j+1}$ |
The superscripts H and L denote the "high" and "low" Q tiles — each warpgroup owns 128 rows of Q. Together they process 256 query rows per SM. The key constraint: the two softmax warpgroups must not overlap in their critical section (the exponential computation), because both use the same MUFU unit on the SM.
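A toy timeline model makes the benefit of pingpong concrete. This is a simplification, not FA4's actual schedule: each tile alternates an MMA phase (which needs the shared tensor cores) and a softmax phase (which needs the shared exponential unit), each resource serves one phase at a time, and phases are issued in a fixed round-robin order over tiles. `pingpong_makespan` is a hypothetical helper.

```python
def pingpong_makespan(n_iters: int, t_mma: int, t_softmax: int, n_tiles: int = 2) -> int:
    """Total cycles for n_tiles tiles, each running n_iters (MMA -> softmax) rounds.

    The tensor core and the MUFU are each modeled as one exclusive resource;
    phases are issued in a fixed round-robin order over tiles.
    """
    tc_free = 0             # cycle at which the tensor core is next idle
    mufu_free = 0           # cycle at which the exponential unit is next idle
    ready = [0] * n_tiles   # cycle at which each tile may start its next MMA
    for _ in range(n_iters):
        for tile in range(n_tiles):
            # MMA phase: wait for this tile's previous softmax and for the tensor core.
            mma_start = max(ready[tile], tc_free)
            tc_free = mma_start + t_mma
            # Softmax phase: wait for this tile's MMA and for the MUFU.
            sm_start = max(tc_free, mufu_free)
            mufu_free = sm_start + t_softmax
            ready[tile] = mufu_free
    return max(ready)


one_tile = pingpong_makespan(4, 1024, 1024, n_tiles=1)
two_tiles = pingpong_makespan(4, 1024, 1024, n_tiles=2)
print(one_tile, two_tiles)  # 8192 9216
```

A single tile leaves each unit idle half the time. Two tiles finish twice the work in 9216 cycles instead of 16384, because one tile's softmax runs while the other tile's MMA occupies the tensor cores.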
This is where the design gets intricate. TMEM must hold the intermediate results that bridge the MMA and softmax phases: the S tiles that the softmax warpgroups read, the P tiles they write back for the PV MMA, and the O accumulators.
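A rough TMEM budget shows why the allocation is tight. The sketch assumes Blackwell's per-SM tensor memory of 128 lanes × 512 columns of 32-bit cells (256 KB), that each of the two Q tiles keeps one FP32 S tile (N columns) and one FP32 O accumulator (d columns) resident, and that P is written over S in place. FA4's actual layout may differ; `tmem_columns_needed` is a hypothetical helper.

```python
TMEM_LANES = 128     # one lane per row of a 128-row accumulator tile
TMEM_COLUMNS = 512   # 32-bit columns per lane
TMEM_BYTES = TMEM_LANES * TMEM_COLUMNS * 4  # 262144 bytes = 256 KB


def tmem_columns_needed(n: int, d: int, q_tiles: int = 2) -> int:
    """32-bit TMEM columns for q_tiles tiles, each with an FP32 S (n cols) and O (d cols)."""
    s_cols = n  # FP32 scores; BF16 P is assumed to overwrite S in place
    o_cols = d  # FP32 output accumulator
    return q_tiles * (s_cols + o_cols)


for d in (64, 128):
    cols = tmem_columns_needed(n=128, d=d)
    print(d, cols, cols <= TMEM_COLUMNS)
# 64 384 True
# 128 512 True
```

Under these assumptions, d = 128 fills all 512 columns exactly, leaving no room for extra buffers.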
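Each thread in a softmax warpgroup owns one row of the attention tile, so the warpgroup's 128 threads cover all 128 rows. On Blackwell's larger 128×128 tiles, each thread must therefore hold an entire row of 128 elements in registers. The softmax procedure per warpgroup follows the phases in the table above: load S from TMEM, compute the row max, rescale if needed, compute the exponentials, accumulate the row sum, and store P back to TMEM for the PV MMA.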
For BF16 inputs, each thread needs 128 registers for the FP32 S row, 64 registers to stage the BF16 output P (two values per 32-bit register), plus miscellaneous temporaries. To reduce register pressure, FA4 stages the output P in quarters: the first three quarters are stored together (triggering the corresponding MMA operations), and the last quarter is stored separately. This careful staging keeps register usage within the 255-register-per-thread limit.
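The per-thread work can be sketched in scalar Python. This is a simplified model under stated assumptions: one thread owns one row, exponentials are evaluated in base 2 with a log₂ e factor, and the row's P values are produced in four chunks to mimic quarter-by-quarter staging. It does not model BF16 conversion, the TMEM stores, rescaling, or the exact three-plus-one split described above; `softmax_row_in_quarters` is a hypothetical name.

```python
import math

LOG2E = 1.0 / math.log(2.0)  # exp(x) == 2 ** (x * log2(e))


def softmax_row_in_quarters(s_row, scale=1.0):
    """One thread's softmax work on one row of S.

    Returns the row max, the unnormalized P values in four chunks, and the row sum.
    """
    n = len(s_row)
    assert n % 4 == 0, "row length must split into four quarters"
    quarter = n // 4
    row_max = max(s_row)                  # 1. row max
    chunks, row_sum = [], 0.0
    for start in range(0, n, quarter):    # 2. exponentiate one quarter at a time
        chunk = [2.0 ** ((s - row_max) * scale * LOG2E) for s in s_row[start:start + quarter]]
        row_sum += sum(chunk)             # 3. accumulate the row sum
        chunks.append(chunk)              # on the GPU: convert to BF16, store to TMEM
    return row_max, chunks, row_sum


row = [0.5 * i for i in range(8)]
m, chunks, l = softmax_row_in_quarters(row)
print(m, len(chunks), [len(c) for c in chunks])  # 3.5 4 [2, 2, 2, 2]
```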
Two warpgroups alternate between MMA (tensor core) and softmax (MUFU + FMA): one warpgroup's MMA overlaps with the other's softmax.
The MUFU (Multi-Function Unit) computes exp() at 16 operations per clock per SM. For a 128×128 tile, that's 128×128 / 16 = 1024 cycles — matching the MMA compute time exactly. This means even with perfect pingpong overlap, the exponential is right on the edge of being a bottleneck. Any small inefficiency (pipeline bubbles, synchronization) tips it over.
FA4's solution is radical: compute some of the exponentials on the FMA (fused multiply-add) units instead of MUFU. The FMA units are the same ones used for ordinary floating-point arithmetic, and they run in parallel with MUFU, so every exponential moved onto them adds throughput beyond MUFU's 16 ops/clock/SM.
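Softmax needs $e^x$, but $e^x = 2^{x \log_2 e}$, so it suffices to compute powers of two. The key identity is:
$$2^x = 2^{\lfloor x \rfloor} \cdot 2^{x_{\text{frac}}}$$
where ⌊x⌋ is the integer part (floor) and $x_{\text{frac}} = x - \lfloor x \rfloor \in [0, 1)$ is the fractional part. The integer part $2^{\lfloor x \rfloor}$ can be computed without any floating-point arithmetic: it's just a bit manipulation on the IEEE 754 exponent field. Add the exponent bias (127 for FP32) to ⌊x⌋, shift the result into the exponent bits, and you have your answer.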
The fractional part $2^{x_{\text{frac}}}$ is approximated by a polynomial:
$$2^{x_{\text{frac}}} \approx p(x_{\text{frac}}) = \sum_{i=0}^{n} p_i \, x_{\text{frac}}^{\,i}$$
with $p_0 = 1.0$ and the remaining coefficients chosen to minimize relative approximation error. This polynomial is evaluated using Horner's method on the FMA units, where each term costs one fused multiply-add.
The authors benchmark polynomial approximations of different degrees against the hardware MUFU.EX2 instruction:
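| Method | FP32 Max Rel Error | BF16 Max Rel Error |
|---|---|---|
| Hardware MUFU.EX2 | 1.41 × 10⁻⁷ | 3.89 × 10⁻³ |
| Degree 3 polynomial | 8.77 × 10⁻⁵ | 3.90 × 10⁻³ |
| Degree 4 polynomial | 3.05 × 10⁻⁶ | 3.89 × 10⁻³ |
| Degree 5 polynomial | 1.44 × 10⁻⁷ | 3.89 × 10⁻³ |
Once results are rounded to BF16, every method lands at essentially the same error: BF16's own rounding error (about 3.9 × 10⁻³) swamps the polynomial's approximation error, so even a degree-3 polynomial is as accurate as the hardware for a P that is stored in BF16.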
Computing ALL exponentials via polynomial emulation would be wasteful — the additional registers for intermediate values and coefficients would increase register pressure and could cause spills. Instead, FA4 uses partial emulation: only 10–25% of the entries in each softmax row are emulated via polynomial, with the remaining entries computed via hardware MUFU.EX2. The exact fraction is tuned empirically based on the ratio of MMA and exponential throughput for a given tile configuration.
This distributes the exponential computation across both MUFU and FMA units, alleviating the bottleneck while keeping register pressure manageable.
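```python
# Conceptual software-emulated 2^x (simplified, scalar Python)
import math
import struct


def fast_exp2(x: float) -> float:
    # Step 1: Clamp from below so the biased exponent (x_floor + 127) stays >= 0.
    # Softmax passes x <= 0 (score minus row max); this sketch assumes x < 128.
    x = max(x, -127.0)

    # Step 2: Integer part. Use floor, not int(): int() truncates toward zero,
    # which would make the fractional part negative for negative x.
    x_floor = math.floor(x)  # In GPU: add magic number + round-down

    # Step 3: Fractional part
    x_frac = x - x_floor  # x_frac in [0, 1)

    # Step 4: Degree-3 polynomial (Horner's method, one FMA per step on the GPU).
    # Rounded illustrative coefficients; a minimax fit (e.g. from Sollya) is tighter.
    p3, p2, p1, p0 = 0.07944, 0.2274, 0.6931, 1.0
    frac_part = ((p3 * x_frac + p2) * x_frac + p1) * x_frac + p0

    # Step 5: Build 2^x_floor by writing the biased exponent into an FP32 bit
    # pattern (bits 23..30), then multiply by frac_part. x_floor == -127 gives +0.0.
    bits = (x_floor + 127) << 23
    int_part = struct.unpack("<f", struct.pack("<I", bits))[0]
    return int_part * frac_part


print(fast_exp2(-3.0), abs(fast_exp2(-1.5) / 2.0 ** -1.5 - 1.0) < 1e-3)  # 0.125 True
```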
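Partial emulation can be sketched by routing a fixed share of a row through the polynomial path and the rest through a hardware stand-in. Assumptions: Python's `2.0 ** x` plays the role of MUFU.EX2, `math.ldexp` plays the role of the exponent-field write, the coefficients are the same illustrative degree-3 values as in the conceptual code (not a tuned minimax fit), and emulating every k-th entry is a hypothetical selection rule; the text only states that 10–25% of each row is emulated.

```python
import math

# Illustrative degree-3 coefficients for 2^f on [0, 1), highest degree first.
P3, P2, P1, P0 = 0.07944, 0.2274, 0.6931, 1.0


def poly_exp2(x: float) -> float:
    """Software 2^x on the 'FMA path': 2^floor(x) * p(frac(x))."""
    x = max(x, -127.0)
    xi = math.floor(x)
    f = x - xi                                     # fractional part in [0, 1)
    frac_part = ((P3 * f + P2) * f + P1) * f + P0  # Horner: one FMA per step
    return math.ldexp(frac_part, xi)               # frac_part * 2**xi


def exp2_row_mixed(row, emulated_fraction: float = 0.25):
    """Emulate roughly `emulated_fraction` of the entries; send the rest to the MUFU stand-in."""
    stride = max(1, round(1.0 / emulated_fraction))  # hypothetical rule: every stride-th entry
    return [poly_exp2(x) if i % stride == 0 else 2.0 ** x for i, x in enumerate(row)]


row = [-0.37 * i for i in range(16)]
out = exp2_row_mixed(row, emulated_fraction=0.25)  # entries 0, 4, 8, 12 are emulated
worst = max(abs(o / 2.0 ** x - 1.0) for o, x in zip(out, row))
print(worst < 1e-3)  # True
```

Spreading the emulated entries across the row keeps the extra polynomial registers confined to a small share of the work, which is the register-pressure trade-off described above.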
FlashAttention computes attention in blocks. Because we process K/V blocks one at a time, we maintain running statistics for the online softmax: a running maximum $m_j$ and a running sum of exponentials $\ell_j$. When a new block has a larger maximum than the running max, we must rescale all previously accumulated output, multiplying O by $e^{m_{j-1} - m_j}$.
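When processing block j, FA computes:
$$m_j = \max\big(m_{j-1}, \operatorname{rowmax}(S_j)\big), \qquad \tilde{P}_j = \exp(S_j - m_j), \qquad \ell_j = e^{m_{j-1} - m_j}\,\ell_{j-1} + \operatorname{rowsum}(\tilde{P}_j)$$
The intermediate output is updated as:
$$O_j = \operatorname{diag}\big(e^{m_{j-1} - m_j}\big)\, O_{j-1} + \tilde{P}_j V_j$$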
That rescaling factor $e^{m_{j-1} - m_j}$ is a per-row vector that must be multiplied into every element of the accumulated O. On Blackwell, where non-matmul operations are the bottleneck, this rescaling is expensive.
FA4 makes two key observations:
Observation 1: Rescaling is only necessary when $m_j > m_{j-1}$, i.e., when the current block contains larger attention scores than all previous blocks. In practice, the running maximum stabilizes quickly: after the first few blocks, most iterations don't need rescaling at all.
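Observation 2: We can tolerate some "slack" in the rescaling. If $m_j - m_{j-1} > \tau$ (a threshold), we must rescale. But if the difference is small ($m_j - m_{j-1} \le \tau$), we can skip rescaling and keep using $m_{j-1}$ as the reference maximum. Individual exponentials then exceed 1 (by at most a factor set by τ), but O and ℓ are both accumulated against the same reference, so a final normalization fixes everything.
The threshold τ is typically set to log₂(256) = 8.0, corresponding to a rescaling factor of 256.0: the kernel evaluates softmax in base 2 (scores are pre-multiplied by log₂ e so that exp becomes exp2), so a skipped rescale leaves entries no larger than 2⁸ = 256, comfortably within floating-point range. Because the statistics stay consistent with the reference maximum, the final output is exact. At the very end:
$$O = \operatorname{diag}(\ell_{\text{last}})^{-1}\, O_{\text{last}}$$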
Even when rescaling IS needed, FA4 moves it off the critical path. Since accumulators live in TMEM (not registers), rescaling can be performed by a separate "correction" warpgroup that reads the accumulator from TMEM, multiplies by the rescaling factor, and writes it back — all while the main warpgroups continue with the next iteration's MMA and softmax. This decoupling is impossible on Hopper, where accumulators live in registers that are private to each thread.
To avoid warp divergence, FA4 rescales whenever any thread in the warpgroup needs rescaling. This is a conservative choice that trades a small amount of unnecessary rescaling for simpler control flow.
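Both observations can be checked numerically for a single query row. The sketch below is an illustration, not FA4's kernel code: it works in base 2 (matching the τ = 8 threshold), keeps a possibly stale reference maximum, rescales O and ℓ only when a block's max exceeds that reference by more than `tau`, and compares the result with a direct softmax. The function names are hypothetical.

```python
import math


def lazy_rescale_attention(score_blocks, value_blocks, tau=8.0):
    """One query row of attention with base-2 online softmax and thresholded rescaling.

    Returns (output vector, number of rescales performed).
    """
    d = len(value_blocks[0][0])
    m = -math.inf   # reference maximum used by all exponentials so far
    l = 0.0         # running sum of 2^(s - m)
    o = [0.0] * d   # running unnormalized output
    rescales = 0
    for s_blk, v_blk in zip(score_blocks, value_blocks):
        m_blk = max(s_blk)
        if m == -math.inf:
            m = m_blk                 # first block: nothing to rescale yet
        elif m_blk - m > tau:         # new max too far above the reference
            scale = 2.0 ** (m - m_blk)
            l *= scale
            o = [x * scale for x in o]
            m = m_blk
            rescales += 1
        # Otherwise keep the stale m; each 2^(s - m) is then at most 2^tau.
        for s, v in zip(s_blk, v_blk):
            p = 2.0 ** (s - m)
            l += p
            o = [ok + p * vk for ok, vk in zip(o, v)]
    return [x / l for x in o], rescales  # final normalization


def reference_attention(scores, values):
    """Direct base-2 softmax(scores) @ values for the same row."""
    m = max(scores)
    w = [2.0 ** (s - m) for s in scores]
    z = sum(w)
    return [sum(wi * v[k] for wi, v in zip(w, values)) / z for k in range(len(values[0]))]


score_blocks = [[0.0, 1.0], [2.0, 3.0], [4.0, 5.0]]
value_blocks = [[[1.0, 0.0], [0.0, 1.0]], [[1.0, 1.0], [2.0, 0.0]], [[0.0, 3.0], [1.0, 1.0]]]
ref = reference_attention([s for b in score_blocks for s in b],
                          [v for b in value_blocks for v in b])
for tau in (0.0, 8.0):
    out, n = lazy_rescale_attention(score_blocks, value_blocks, tau)
    print(tau, n, all(abs(a - b) < 1e-12 for a, b in zip(out, ref)))
# 0.0 2 True
# 8.0 0 True
```

With τ = 0 every increase of the maximum triggers a rescale; with τ = 8 none of these blocks does, yet both runs match the direct softmax.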
The backward pass is the harder half. It performs five MMAs per iteration (vs two in forward), recomputing S on the fly, and must accumulate gradients dQ, dK, dV. On Blackwell, shared memory traffic is the dominant bottleneck — exceeding MMA compute by 30%.
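Given the output gradient $dO \in \mathbb{R}^{N \times d}$, the backward pass computes:
$$S = QK^\top,\quad P = \operatorname{softmax}(S),\quad dV = P^\top dO,\quad dP = dO\,V^\top,\quad dS = \operatorname{dsoftmax}(dP),\quad dQ = dS\,K,\quad dK = dS^\top Q$$
where dsoftmax(dP) denotes the row-wise softmax gradient: $ds = (\operatorname{diag}(p) - pp^\top)\,dp$ for each row, with $p = \operatorname{softmax}(s)$. This requires recomputing $S = QK^\top$, since we didn't store it in the forward pass (that's the whole point of FlashAttention). So the five MMAs are: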
| MMA | Operand A source | Operand B source | Type |
|---|---|---|---|
| Sᵀ = KQᵀ | SMEM (K) | SMEM (Q) | Shared-Shared |
| dPᵀ = V dOᵀ | TMEM (V) | SMEM (dO) | Tensor-Shared |
| dV = Pᵀ dO | TMEM (P) | SMEM (dO) | Tensor-Shared |
| dQ = dS K | SMEM (dS) | SMEM (K) | Shared-Shared |
| dK = dSᵀ Q | SMEM (dS) | SMEM (Q) | Shared-Shared |
Three of the five (Sᵀ, dQ, dK) are shared-shared: both operands are read from SMEM. This is why shared memory bandwidth dominates the backward pass.
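The row-wise softmax gradient used to form dS never needs the full Jacobian. Expanding (diag(p) − p pᵀ) dp gives ds_i = p_i (dp_i − p·dp), an O(n) elementwise form. The sketch below checks the two forms against each other for one row; it is plain Python for illustration, not FA4 code, and the function names are hypothetical.

```python
import math


def softmax(s):
    m = max(s)
    e = [math.exp(x - m) for x in s]
    z = sum(e)
    return [x / z for x in e]


def dsoftmax_jacobian(p, dp):
    """ds = (diag(p) - p p^T) dp, using the full n x n Jacobian."""
    n = len(p)
    return [
        sum(((p[i] if i == j else 0.0) - p[i] * p[j]) * dp[j] for j in range(n))
        for i in range(n)
    ]


def dsoftmax_elementwise(p, dp):
    """Same result in O(n): ds_i = p_i * (dp_i - sum_j p_j * dp_j)."""
    dot = sum(pi * dpi for pi, dpi in zip(p, dp))
    return [pi * (dpi - dot) for pi, dpi in zip(p, dp)]


p = softmax([0.1, 2.0, -1.0, 0.5])
dp = [0.3, -1.2, 0.7, 2.0]
ds_full = dsoftmax_jacobian(p, dp)
ds_fast = dsoftmax_elementwise(p, dp)
print(all(abs(a - b) < 1e-12 for a, b in zip(ds_full, ds_fast)))  # True
```

For row i the dot product p·dp equals dO_i · O_i (because O_i = Σ_j P_ij V_j), which lets FlashAttention-style backward passes compute it once per row instead of per block.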
In 1-CTA mode (M = 128), the pipeline follows the same pingpong principle as the forward pass, but with 5 MMAs + 2 elementwise operations (softmax and its gradient), staged across a prologue, a main loop, and a tail.
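Even with the 1-CTA pipeline, shared memory is still the bottleneck. FA4 uses Blackwell's 2-CTA MMA mode to address this. With an MMA tile shape of M = 256 and N = K = 128, each CTA stages only half of operand B, which lowers shared memory time in the backward roofline from 3328 to 2688 cycles and halves the number of global atomic reductions for dQ.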
The two CTAs communicate via Distributed Shared Memory (DSMEM) — a Blackwell feature allowing CTAs in the same cluster to access each other's SMEM. FA4 uses DSMEM to exchange half of the dS tile between the two CTAs, since each CTA needs the full dS for its MMA but computes only half of it.
The dQ gradient accumulates via atomic additions to global memory, which are nondeterministic (the order of atomic operations varies between runs). For reproducible training in reinforcement learning, FA4 provides a deterministic mode using semaphore locks: each CTA acquires a lock before writing to a common dQ tile, performs its reduction, and releases the lock. Combined with careful CTA swizzling (processing tiles in an order that minimizes stalls), the deterministic backward pass achieves up to 75% of the nondeterministic version's speed.
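Why the order of reductions matters, and how a semaphore restores determinism, can be shown on the CPU. Floating-point addition is not associative, so the same partial dQ contributions summed in different orders can give different results. A lock by itself serializes the writes but does not fix their order, so this sketch assumes a ticket-style semaphore in which CTA i waits until a shared counter equals i. Python threads stand in for CTAs; it is an analogy, not FA4's implementation, and `ordered_reduce` is a hypothetical name.

```python
import random
import threading


def ordered_reduce(partials, start_seed=None):
    """Sum per-'CTA' partial tiles into one shared tile in a fixed CTA order."""
    width = len(partials[0])
    dq = [0.0] * width           # the shared dQ tile
    turn = 0                     # semaphore value: which CTA may accumulate next
    cond = threading.Condition()

    def cta(i):
        nonlocal turn
        with cond:
            cond.wait_for(lambda: turn == i)  # acquire: wait for ticket i
            for k in range(width):
                dq[k] += partials[i][k]       # reduction into the shared tile
            turn += 1                         # release: hand over to CTA i + 1
            cond.notify_all()

    threads = [threading.Thread(target=cta, args=(i,)) for i in range(len(partials))]
    launch = list(range(len(threads)))
    random.Random(start_seed).shuffle(launch)  # CTAs arrive in arbitrary order
    for i in launch:
        threads[i].start()
    for t in threads:
        t.join()
    return dq


partials = [[1e16], [1.0], [-1e16], [1.0]]
print(((1e16 + 1.0) - 1e16) + 1.0, ((1e16 - 1e16) + 1.0) + 1.0)  # 1.0 2.0
print({ordered_reduce(partials, seed)[0] for seed in range(5)})    # {1.0}
```

The first line shows two summation orders disagreeing; the second shows that with a fixed ticket order the result is the same regardless of the order in which the "CTAs" arrive.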
Every previous FlashAttention version was written in CUDA C++ using CUTLASS templates — a maze of deeply nested C++ template metaprogramming that took 55 seconds to compile a single kernel. FA4 takes a radically different approach: it is written entirely in CuTe-DSL, a domain-specific language embedded in Python.
CuTe-DSL is the Python frontend to CuTe, the layout and tensor abstraction layer at the core of NVIDIA's CUTLASS library.
The CuTe-DSL programming model is isomorphic to CUTLASS C++ — every C++ CUTLASS construct has a direct CuTe-DSL equivalent. You lose nothing in expressiveness. But you gain enormously in productivity:
| Property | FA3 (C++ CUTLASS) | FA4 (CuTe-DSL Python) |
|---|---|---|
| Forward compile time | 55 seconds | 2.5 seconds |
| Backward compile time | 45 seconds | 1.4 seconds |
| Compile-time speedup | 1x | 22–32x |
| Iteration speed | Minutes per experiment | Seconds per experiment |
| Prerequisite expertise | Deep C++ template metaprogramming | Python + GPU concepts |
CuTe-DSL provides direct access to PTX (GPU assembly) as an escape hatch. FA4 uses custom PTX sequences for operations not yet fully exposed in the CuTe-DSL API, demonstrating that the framework doesn't constrain developers to a limited subset of GPU capabilities. This is critical for a performance-critical kernel that needs every last drop of throughput.
FA4's codebase is designed as a library of composable primitives rather than a monolithic kernel.
Developers have already built FlexAttention and block-sparse attention variants on top of FA4 without modifying the core framework. This modularity is a direct benefit of the CuTe-DSL approach — Python's flexibility makes it natural to compose orthogonal features.
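```python
# Conceptual CuTe-DSL kernel structure (simplified pseudocode, not runnable as-is).
# `cute_dsl`, `TiledMMA`, `TMA`, `row_max`, `row_sum`, `exp2` and `exp2_mixed` are
# illustrative placeholders, not the real CuTe-DSL API (which ships in CUTLASS as cutlass.cute).
from cute_dsl import Tensor, Layout, TiledMMA, TMA

# Define tile shapes
BM, BN, BD = 128, 128, 128

# Define MMA operation (on Blackwell this maps to a tcgen05 MMA, not Hopper's WGMMA)
tiled_mma = TiledMMA(
    instruction="tcgen05_mma_bf16",
    tile=(BM, BN, BD),
)

# Define TMA copy from HBM to SMEM
tma_load_k = TMA(
    src="gmem", dst="smem",
    shape=(BN, BD), dtype="bf16",
    swizzle=128,  # 128-byte swizzle to avoid shared memory bank conflicts
)

tau = 8.0  # rescaling threshold in log2 units (factor 2^8 = 256)

# Main loop body (one iteration); m and l hold one value per row of the Q tile
def attention_step(Q_smem, K_smem, V_smem, O_tmem, m, l):
    # 1. QK^T on the tensor cores (writes S to TMEM)
    S = tiled_mma(Q_smem, K_smem.T)  # → TMEM

    # 2. Online softmax (reads S from TMEM, writes P to TMEM)
    m_blk = row_max(S)
    if any(m_blk - m > tau):       # rescale if any row needs it (avoids divergence)
        m_new = max(m, m_blk)      # elementwise, per row
        O_tmem *= exp2(m - m_new)  # conditional rescale of the accumulator
        l = l * exp2(m - m_new)    # keep the running sum consistent
        m = m_new
    # Otherwise keep the stale reference max m for both P and l.
    P = exp2_mixed(S - m)          # MUFU.EX2 + FMA polynomial emulation
    l = l + row_sum(P)

    # 3. PV on the tensor cores (accumulates into O in TMEM)
    O_tmem += tiled_mma(P, V_smem)  # → TMEM
    return O_tmem, m, l
```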
FA4 is benchmarked on a B200 180GB SXM6 (1000W) GPU with CUDA 13.1, PyTorch 2.10.0, and CuTe-DSL 4.4.1. Baselines include cuDNN 9.13.0, cuDNN 9.19.1.2, Triton 3.6, and Gluon (Triton's lower-level GPU programming language). FA3 does not run on Blackwell due to incompatible Hopper MMA instructions.
The setup: BF16 inputs, with and without a causal mask, head dimensions 64, 128, and (192, 128) for the DeepSeek V3 configuration (192 for Q and K, 128 for V). Sequence lengths range from 1K to 32K, with total tokens fixed at 32K (batch size adjusts accordingly). Hidden dimension 2048 gives 32 heads (d = 64), 16 heads (d = 128), or 16 heads with d_qk = 192 and d_v = 128.
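To relate the TFLOPS figures to work per call, it helps to count FLOPs the way attention benchmarks commonly do. This sketch assumes that convention: 2 FLOPs per multiply-add for the two matmuls (QKᵀ over the Q/K head dimension, PV over the V head dimension), halved under a causal mask. It derives batch size and head count from the fixed 32K-token, 2048-hidden setup, taking heads = 2048 / d_v, which reproduces the 32 / 16 / 16 head counts above. The helper names are hypothetical.

```python
TOTAL_TOKENS = 32 * 1024
HIDDEN = 2048


def attention_fwd_flops(batch, seqlen, nheads, d_qk, d_v, causal):
    """Forward FLOPs: QK^T (2*N*N*d_qk) plus PV (2*N*N*d_v) per head, halved if causal."""
    flops = 2 * batch * nheads * seqlen * seqlen * (d_qk + d_v)
    return flops // 2 if causal else flops


def benchmark_config(seqlen, d_qk, d_v):
    """Batch and heads for the fixed-token setup: batch * seqlen = 32K, nheads * d_v = 2048."""
    return TOTAL_TOKENS // seqlen, HIDDEN // d_v


batch, nheads = benchmark_config(8192, 128, 128)  # (4, 16)
flops = attention_fwd_flops(batch, 8192, nheads, 128, 128, causal=True)
print(batch, nheads, flops == 2**40)  # 4 16 True
```

At roughly 1420 TFLOPS (the 8K causal, d = 128 forward figure), this call would take about 2⁴⁰ / 1.42 × 10¹⁵ ≈ 0.77 ms.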
FA4 vs cuDNN and Triton on B200 for causal attention with head dimension 128. FA4 consistently outperforms across all sequence lengths.
| Seq Length | FA4 (TFLOPS) | cuDNN 9.13 (TFLOPS) | Triton (TFLOPS) | FA4 vs cuDNN | FA4 vs Triton |
|---|---|---|---|---|---|
| 1K | ~770 | ~700 | ~280 | 1.1x | 2.7x |
| 2K | ~1000 | ~900 | ~420 | 1.1x | 2.4x |
| 4K | ~1250 | ~1130 | ~540 | 1.1x | 2.3x |
| 8K | ~1420 | ~1280 | ~600 | 1.1x | 2.4x |
| 16K | ~1540 | ~1400 | ~640 | 1.1x | 2.4x |
| 32K | ~1610 | ~1400 | ~570 | 1.15x | 2.8x |
Peak throughput: 1613 TFLOPS, approximately 71% of the B200's theoretical 2.25 PFLOPS. The gains are larger for causal attention, which FA4's LPT scheduler handles particularly well.
With the DeepSeek V3 head dimensions (192 for Q and K, 128 for V, 16 heads), FA4 matches or beats cuDNN 9.13 by up to 1.15x, while the newer cuDNN 9.19.1.2 narrows the gap and is slightly faster at 1K:
| Seq Length | FA4 (TFLOPS) | cuDNN 9.13 (TFLOPS) | cuDNN 9.19.1.2 (TFLOPS) | FA4 vs cuDNN 9.13 |
|---|---|---|---|---|
| 1K | ~770 | ~770 | ~810 | 1.0x |
| 4K | ~1260 | ~1130 | ~1120 | 1.1x |
| 8K | ~1420 | ~1260 | ~1310 | 1.1x |
| 16K | ~1550 | ~1420 | ~1540 | 1.1x |
| 32K | ~1610 | ~1400 | ~1560 | 1.15x |
The backward pass shows larger gains over baselines, especially for causal attention:
| Seq Length | FA4 (TFLOPS) | cuDNN 9.13 (TFLOPS) | FA4 vs cuDNN |
|---|---|---|---|
| 1K (non-causal) | ~580 | ~520 | 1.1x |
| 8K (non-causal) | ~1350 | ~1170 | 1.15x |
| 32K (non-causal) | ~1450 | ~1310 | 1.1x |
| 8K (causal) | ~1200 | ~860 | 1.4x |
| 32K (causal) | ~1400 | ~1200 | 1.17x |
FA4's LPT (Longest-Processing-Time-first) scheduler provides a measurable 4–8% FLOPS gain for MHA and 7–14% for MQA 8 at head dimension 128, by ensuring that the longest-running worktiles are scheduled first, reducing tail-end imbalance across SMs.
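The effect of longest-first ordering shows up even in a tiny static model. The sketch assumes causal worktiles whose cost grows linearly with the query-tile index (tile i visits i + 1 KV blocks), assigns each tile greedily to the least-loaded SM, and compares index order with LPT order. It illustrates classic LPT list scheduling, not FA4's actual scheduler, and `makespan` is a hypothetical helper.

```python
import heapq


def makespan(costs, n_sms, longest_first):
    """Greedy list scheduling: each worktile goes to the currently least-loaded SM."""
    order = sorted(costs, reverse=True) if longest_first else list(costs)
    loads = [(0, sm) for sm in range(n_sms)]  # (accumulated cost, SM id)
    heapq.heapify(loads)
    for cost in order:
        load, sm = heapq.heappop(loads)       # least-loaded SM (ties: lowest id)
        heapq.heappush(loads, (load + cost, sm))
    return max(load for load, _ in loads)


causal_costs = [i + 1 for i in range(8)]  # query tile i visits i + 1 KV blocks
print(makespan(causal_costs, 3, longest_first=False))  # 15
print(makespan(causal_costs, 3, longest_first=True))   # 13
```

The total work is 36 units, so no schedule on 3 SMs can finish before 12. Index order leaves the longest tiles for last and finishes at 15, while LPT finishes at 13, trimming the tail imbalance.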
For the deterministic backward pass, the SPT (Shortest-Processing-Time-first) scheduler with CTA swizzling achieves up to 950 TFLOPS for causal attention at 32K sequence length, and across configurations the deterministic mode reaches up to 75% of the nondeterministic version's speed. This is fast enough for practical use in RL training, where reproducibility is essential.
FlashAttention-4 sits at the intersection of hardware-aware algorithm design, GPU kernel engineering, and ML systems research. Here are the key threads to follow.
| Work | Key Idea | Relation to FA4 |
|---|---|---|
| SageAttention | INT8 quantized attention | Orthogonal precision approach; FA4 targets BF16/FP8 on datacenter GPUs |
| SageAttention2 | INT4/FP8 with outlier smoothing | Lower precision frontier; FA4's polynomial exp could benefit these too |
| SageAttention3 | FP4 on Blackwell consumer GPUs | Targets consumer Blackwell; FA4 targets datacenter B200/GB200 |
| Triton | Python GPU compiler with tiling | FA4 outperforms by 2.1–2.7x due to hand-tuned pipeline scheduling |
| Gluon | Lower-level GPU programming language | FA4 also outperforms Gluon across most configurations |