Zadouri, Hoehnerbach, Shah, Liu, Thakkar, Dao (Princeton / Meta / Colfax Research / NVIDIA / Georgia Tech / Together AI) — arXiv 2026

FlashAttention-4: Algorithm and Kernel Pipelining Co-Design for Asymmetric Hardware Scaling

The attention kernel that co-designs algorithm and hardware pipelining for NVIDIA Blackwell GPUs — exploiting tensor memory, 2-CTA MMA, software-emulated exponentials, and conditional softmax rescaling to reach 1613 TFLOPS (71% utilization) on B200.

Prerequisites: Attention mechanism + GPU memory hierarchy + FlashAttention basics (tiling, online softmax)

Chapter 0: The Asymmetric Scaling Problem

You just got access to an NVIDIA B200 GPU. It has 2.25 PFLOPS of BF16 tensor core throughput, more than double the H100's. You port your FlashAttention-3 design to it and expect a 2x speedup. Instead, you get almost nothing: the kernel barely runs faster. What happened?

The answer is asymmetric hardware scaling. NVIDIA more than doubled tensor core throughput from Hopper to Blackwell (roughly 1 PFLOPS to 2.25 PFLOPS for dense BF16; exactly 2x per SM per clock), but it did not scale everything else. Shared memory bandwidth stayed at 128 bytes/clock/SM. The exponential unit (MUFU) stayed at 16 ops/clock/SM. Integer and floating-point ALUs stayed roughly the same. The GPU got faster at matrix multiplies, but everything else stayed put.

This means the bottleneck has shifted. On Hopper, the tensor cores were the limiting factor — you designed your kernel to keep them busy. On Blackwell, the tensor cores finish so fast that they spend most of their time waiting — waiting for data from shared memory, waiting for the exponential unit to finish softmax, waiting for non-matmul operations to complete.

The core insight of FA4: When compute scales faster than memory and special-function units, your kernel must be co-designed with the hardware to overlap matmuls with everything else. You cannot just port an H100 kernel to B200 — you must rethink what is on the critical path.

The FlashAttention lineage

FlashAttention has evolved across four GPU generations, each time adapting to the hardware's shifting bottlenecks:

| Version | Year | Target GPU | Key innovation | Bottleneck addressed |
| --- | --- | --- | --- | --- |
| FA1 | 2022 | A100 | Tiled attention + online softmax | HBM bandwidth (no materialized N×N matrix) |
| FA2 | 2023 | A100/H100 | Reduced non-matmul FLOPs, parallelized over sequence length | Non-matmul operations on GPU cores |
| FA3 | 2024 | H100 (Hopper) | Warp specialization, async TMA, FP8 support | Overlapping compute with memory via Hopper async features |
| FA4 | 2026 | B200 (Blackwell) | Pingpong pipeline, software exp, conditional rescaling, 2-CTA MMA | Non-matmul bottleneck from asymmetric scaling |

Each generation didn't just "go faster" — it fundamentally restructured which operations are on the critical path and which can be hidden behind tensor core compute. FA4 is the most radical restructuring yet, because the imbalance on Blackwell is the most extreme yet.

What FA4 achieves

The numbers tell the story:

1613 TFLOPS peak: 71% of B200's theoretical 2.25 PFLOPS for BF16 attention.
1.1–1.3x vs cuDNN 9.13: faster than NVIDIA's own optimized closed-source library.
2.1–2.7x vs Triton: far faster than the popular open-source GPU compiler.
20–30x faster compile: CuTe-DSL in Python compiles in 2.5 s vs 55 s for FA3's C++ templates.

How does FA4 pull this off? Through four interlocking innovations that we will build up piece by piece over the next nine chapters. But first, we need to understand the hardware.

Why can't you simply run a Hopper-optimized FlashAttention-3 kernel on Blackwell and expect a proportional speedup?

Chapter 1: The Blackwell Hardware Model

To understand FA4's design decisions, you need a mental model of the Blackwell GPU's memory hierarchy and execution units. Every optimization in FA4 targets a specific hardware resource. Let's map them out.

Memory hierarchy: five levels

Data on a Blackwell GPU lives in a hierarchy of storage, each level trading capacity for speed:

HBM (Global Memory / GMEM): 180 GB capacity. Off-chip DRAM accessible to all SMs. This is where Q, K, V, O live. High capacity, lowest bandwidth relative to compute.
↓ TMA async transfer
L2 Cache: On-chip cache, transparent to the programmer. Data from HBM is cached here automatically. FA4's LPT scheduler exploits L2 locality.
↓ TMA async transfer
Shared Memory (SMEM): Programmer-managed, per-SM. 128 bytes/clock read bandwidth. Where K and V tiles are staged for MMA consumption. Directly addressable by all threads in a CTA.
Tensor Memory (TMEM), new in Blackwell: 256 KB per SM. Warp-synchronous, tightly coupled to tensor cores. MMA units write results directly to TMEM without consuming registers. Allocated in 32-column (16 KB) granules.
Register File (RMEM): Max 256 registers per thread, private. Fastest storage but most limited. On Hopper, MMA accumulators lived here. On Blackwell, they live in TMEM instead.
TMEM is the game-changer. On Hopper, MMA results went into registers, creating enormous register pressure and blocking the register file for other operations. On Blackwell, MMA writes go directly to TMEM — 256 KB of dedicated on-chip memory that doesn't compete with registers. This frees up the register file for softmax computation and other non-matmul work, enabling much deeper pipelining.
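As a quick sanity check on the TMEM numbers above, here is a minimal arithmetic sketch. It assumes the standard Blackwell TMEM geometry of 128 lanes × 512 columns of 32-bit cells per SM (not stated in the text above) and that an M = 128 FP32 accumulator is stored one row per lane; the helper names are hypothetical.

```python
# Assumed TMEM geometry per SM: 128 lanes x 512 columns, one 32-bit cell per (lane, column).
LANES = 128
COLUMNS = 512
BYTES_PER_CELL = 4


def tmem_bytes(columns: int) -> int:
    """Bytes covered by `columns` TMEM columns across all lanes."""
    return LANES * columns * BYTES_PER_CELL


def accumulator_columns(m: int, n: int) -> int:
    """Columns used by an m x n FP32 accumulator stored one row per lane."""
    if m > LANES:
        raise ValueError("this sketch assumes one accumulator row per lane (m <= 128)")
    return n  # each row's n FP32 values occupy n columns of its lane


total_kib = tmem_bytes(COLUMNS) // 1024    # whole TMEM per SM
granule_kib = tmem_bytes(32) // 1024       # one 32-column allocation granule
cols = accumulator_columns(128, 128)       # a 128 x 128 FP32 tile such as S
print(total_kib, granule_kib, cols, cols // 32)  # 256 16 128 4
```

Under these assumptions a single 128 × 128 FP32 tile occupies a quarter of TMEM, so the budget for S, P, and O tiles has to be planned carefully.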

Tensor cores: 5th generation

Blackwell's tensor cores process 128 × N tiles (N = 128 or 256), compared to Hopper's 64 × N tiles. The throughput is 8192 FLOPs/clock/SM for BF16 — exactly double Hopper's 4096. Each MMA instruction writes its output directly to TMEM asynchronously, meaning computation and other operations can overlap without blocking on register writeback.

The thread hierarchy

From finest to coarsest:

| Unit | Size | Role in FA4 |
| --- | --- | --- |
| Thread | 1 | Smallest execution unit |
| Warp | 32 threads | SIMT execution group |
| Warpgroup | 4 warps = 128 threads | One warpgroup per accumulator tile. FA4 uses 2 warpgroups for pingpong. |
| CTA (threadblock) | 1–2 warpgroups | Co-scheduled on one SM. Shares SMEM. |
| Thread block cluster | Multiple CTAs | Co-scheduled on the same GPC. Can share SMEM across CTAs via DSMEM in 2-CTA mode. |

Key functional units and their throughput

| Unit | Throughput (per clock per SM) | Role |
| --- | --- | --- |
| Tensor cores (BF16 MMA) | 8192 FLOPs | $QK^T$ and $PV$ matrix multiplies |
| Exponential unit (MUFU) | 16 ops | exp() for softmax (same as Hopper) |
| Shared memory read | 128 bytes | Feeding operands to MMA (same as Hopper) |
| TMA (Tensor Memory Accelerator) | Async, non-blocking | HBM ↔ SMEM transfers without tying up the SM's threads |
The asymmetry in numbers. Tensor cores: 2x faster per SM per clock. MUFU: 1x (same). SMEM bandwidth: 1x (same). This is why an FA3-style design doesn't scale on Blackwell: the non-matmul operations that were hidden behind tensor core time on Hopper are now exposed as bottlenecks, because the matmuls finish twice as fast.

2-CTA tensor core mode

Blackwell introduces a new MMA mode where a CTA pair within the same thread block cluster cooperatively executes a single MMA. One CTA initiates the operation; the peer CTA must be launched and active. The paired mode supports M = 128 or 256 by partitioning the A tile in dimension M and the B tile across the two CTAs in dimension N, so each CTA stages only half of operand B in its own shared memory while the hardware consumes the combined B tile. This reduces redundant shared memory capacity and bandwidth.

For FA4's backward pass, this 2-CTA mode is critical: it lets the kernel use M = 256 tiles (vs M = 128 in single-CTA mode), roughly halving the shared memory traffic for operand B and halving the number of global atomic reductions for dQ accumulation.
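To make the operand split concrete, the sketch below counts the BF16 operand bytes each CTA stages in shared memory, following the partitioning described above (A split along M, B split along N across the CTA pair). The function name and the 128/256 × 128 × 128 shapes are illustrative assumptions, not FA4 code.

```python
BF16_BYTES = 2


def smem_bytes_per_cta(m_mma: int, n: int, k: int, num_ctas: int) -> dict:
    """SMEM bytes one CTA stages for an (m_mma x k) @ (k x n) MMA issued across num_ctas CTAs.

    A is partitioned along M and B along N, so each CTA holds m_mma/num_ctas rows
    of A and n/num_ctas columns of B (both with depth k).
    """
    a_bytes = (m_mma // num_ctas) * k * BF16_BYTES
    b_bytes = (n // num_ctas) * k * BF16_BYTES
    return {"A": a_bytes, "B": b_bytes}


single = smem_bytes_per_cta(m_mma=128, n=128, k=128, num_ctas=1)
paired = smem_bytes_per_cta(m_mma=256, n=128, k=128, num_ctas=2)
print(single)  # {'A': 32768, 'B': 32768}
print(paired)  # {'A': 32768, 'B': 16384}
```

Each CTA still produces 128 output rows in both cases, but in paired mode it stages only half of operand B, which is where the savings in SMEM capacity and bandwidth come from.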

Figure: Blackwell memory hierarchy. Data flows from HBM through L2 and SMEM to the tensor cores; TMEM (new in Blackwell) stores MMA outputs directly, freeing registers.
What is Tensor Memory (TMEM) and why does it matter for attention kernels on Blackwell?

Chapter 2: Roofline Analysis — Finding the Bottleneck

Before writing a single line of kernel code, FA4's authors do something crucial: a roofline analysis. This means calculating, for each hardware resource, how many clock cycles it needs to process one tile of attention. The resource that takes the most cycles is the bottleneck — and all kernel design effort should go toward either speeding it up or hiding it behind something else.

Forward pass analysis

Let the tile shape along the sequence length of Q and K be M × N, and the head dimension be d. The forward pass performs two MMAs per iteration of the inner loop:

$S = QK^T$: shared-shared (SS) MMA; both Q and K are read from shared memory. Output: M × N attention scores.
↓ softmax(S)
$O \mathrel{+}= PV$: tensor-shared (TS) MMA; P from TMEM, V from shared memory. Accumulates into the M × d output.

Each of the two MMAs performs 2MNd floating-point operations, so one iteration costs 4MNd FLOPs. At 8192 FLOPs per cycle, the total MMA compute time is:

$T_{\text{MMA}} = \dfrac{4MNd}{8192}$ cycles

For shared memory traffic, the $QK^T$ MMA reads both operands from SMEM (shared-shared), while $PV$ reads P from tensor memory and V from SMEM (tensor-shared). At 128 bytes/cycle and 2 bytes per BF16 element, counting a 128-row slice of Q plus the K and V tiles for every 128 query rows gives:

$T_{\text{smem}} \approx \dfrac{3MNd}{8192}$ cycles (simplified; assumes N = 128)

The exponential unit computes exp() for M × N values in the softmax, at 16 ops/cycle:

$T_{\text{exp}} = \dfrac{MN}{16}$ cycles

The numbers for typical tile sizes

| Resource | 128³ (M = N = d = 128) | 256 × 128² (M = 256, N = d = 128) |
| --- | --- | --- |
| MMA compute | 1024 cycles | 2048 cycles |
| Shared memory | 768 cycles | 1536 cycles |
| Exponential unit | 1024 cycles | 2048 cycles |
The surprise: for both tile sizes, MMA compute and the exponential unit are tied as bottlenecks, and shared memory is only 25% below them. On Hopper, MMA was the dominant bottleneck with everything else hidden behind it. On Blackwell, three resources are nearly co-bottlenecked. This motivates FA4's three-pronged approach: (1) overlap MMA with softmax via pingpong scheduling, (2) speed up the exponential via software emulation on FMA units, and (3) skip unnecessary rescaling operations.
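The cycle counts in the table can be reproduced with a few lines of arithmetic. This is a sketch of the roofline model above, not FA4 code: the shared-memory term uses the simplified model (a Q slice plus the K and V tiles per 128 query rows, with N = 128), and passing Hopper's 4096 FLOPs/clock/SM shows how the MMA term alone changes between generations. The function name is hypothetical.

```python
def forward_roofline(m: int, n: int, d: int, mma_flops_per_clk: int = 8192) -> dict:
    """Per-SM cycle counts for one forward iteration: S = Q K^T, softmax, O += P V.

    mma_flops_per_clk: BF16 tensor-core FLOPs per clock per SM (8192 Blackwell, 4096 Hopper).
    SMEM bandwidth (128 B/clk) and MUFU throughput (16 exp/clk) are the same on both.
    """
    if n != 128:
        raise ValueError("the simplified SMEM model assumes N = 128")
    t_mma = 4 * m * n * d / mma_flops_per_clk    # two MMAs, 2*M*N*d FLOPs each
    smem_bytes = (m // 128) * 3 * 128 * d * 2    # Q slice + K + V per 128 rows, BF16
    t_smem = smem_bytes / 128                    # 128 bytes per clock
    t_exp = m * n / 16                           # one exp per score, 16 per clock
    return {"mma": t_mma, "smem": t_smem, "exp": t_exp}


print(forward_roofline(128, 128, 128))  # {'mma': 1024.0, 'smem': 768.0, 'exp': 1024.0}
print(forward_roofline(256, 128, 128))  # {'mma': 2048.0, 'smem': 1536.0, 'exp': 2048.0}

hopper = forward_roofline(128, 128, 128, mma_flops_per_clk=4096)
print(hopper["mma"], hopper["exp"])     # 2048.0 1024.0 -> MMA clearly dominates on Hopper
```

Only the MMA term depends on the tensor-core rate, which is why halving it on Blackwell pulls the exponential unit level with the matmuls.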

Backward pass is worse

The backward pass performs five MMAs per iteration (recomputing S, plus computing dP, dV, dQ, dK). The shared memory traffic becomes the dominant bottleneck:

| Resource (cycles) | 1-CTA (M = 128) | 2-CTA (M = 256) |
| --- | --- | --- |
| MMA compute | 2560 | 2560 |
| Total shared memory | 3328 | 2688 |
| Exponential unit | 1024 | 1024 |

Shared memory traffic exceeds MMA compute by 30% in 1-CTA mode (3328 vs 2560 cycles). Even in 2-CTA mode, which halves some shared memory reads, it still exceeds MMA compute by about 5%. This is why FA4 uses the 2-CTA MMA mode for the backward pass: it brings shared memory traffic to within about 5% of MMA compute time.
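The backward ratios follow from the same kind of arithmetic. In this sketch the MMA and exponential terms are computed from the formulas above (five MMAs of 2MNd FLOPs each, with 128 rows per SM in both modes), while the shared-memory cycle counts are copied from the table, since the text does not give their formula. Names are hypothetical.

```python
# Shared-memory cycles per SM, copied from the backward-pass table above.
SMEM_CYCLES = {"1-CTA (M=128)": 3328, "2-CTA (M=256)": 2688}


def backward_mma_cycles(m=128, n=128, d=128, num_mmas=5, flops_per_clk=8192):
    """Five M x N x d MMAs (S, dP, dV, dQ, dK), each 2*M*N*d FLOPs, per SM."""
    return num_mmas * 2 * m * n * d / flops_per_clk


t_mma = backward_mma_cycles()   # 2560.0
t_exp = 128 * 128 / 16          # 1024.0
ratios = {mode: t_smem / t_mma for mode, t_smem in SMEM_CYCLES.items()}
for mode, r in ratios.items():
    print(f"{mode}: SMEM / MMA = {r:.2f}")  # 1.30 and 1.05
```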

Figure: Roofline visualization. Cycle counts for each hardware resource on the forward pass; the tallest bar is the bottleneck, and comparing Hopper with Blackwell shows how the bottleneck shifts.
On Blackwell, which resources are co-bottlenecked for the attention forward pass with tile size 128³ (M = N = d = 128)?

Chapter 3: The Pingpong Pipeline

The core scheduling idea in FA4 is the same as FA3: pingpong between two warpgroups. While one warpgroup's tile runs tensor core operations ($QK^T$ or $PV$), the other warpgroup computes softmax. They alternate, keeping both the tensor cores and the non-matmul units busy simultaneously.

But on Blackwell, this is harder to get right. The accumulators now live in TMEM (not registers), the tile sizes are 128×128 (not 64×128), and the imbalance between MMA speed and everything else is more extreme. FA4 redesigns the pipeline from scratch.

The two warpgroups

Each SM runs one CTA (threadblock) with two warpgroups of 128 threads each. At any given moment:

| Phase | Warpgroup A | Warpgroup B |
| --- | --- | --- |
| Even iteration | MMA: compute $S^H = Q^H K_j^T$, then $O^H \mathrel{+}= P^H V_j$ | Softmax: load $S^L$ from TMEM, compute max, rescale, exp, row sum |
| Odd iteration | Softmax: load $S^H$ from TMEM, compute max, rescale, exp, row sum | MMA: compute $S^L = Q^L K_{j+1}^T$, then $O^L \mathrel{+}= P^L V_{j+1}$ |

The superscripts H and L denote the "high" and "low" Q tiles — each warpgroup owns 128 rows of Q. Together they process 256 query rows per SM. The key constraint: the two softmax warpgroups must not overlap in their critical section (the exponential computation), because both use the same MUFU unit on the SM.
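A toy model helps show why two warpgroups are enough to hide most of the softmax time. The sketch below is a simplification, not FA4's scheduler: each warpgroup repeats one MMA phase and one softmax phase per tile, the tensor cores and the exponential/softmax units each serve one warpgroup at a time (so the two softmax phases never overlap, matching the constraint above), and both phases take 1024 cycles as in the roofline table. `simulate_pingpong` is a hypothetical name.

```python
def simulate_pingpong(num_warpgroups: int, iters: int, t_mma: float, t_softmax: float) -> float:
    """Makespan of a toy schedule: each warpgroup repeats (MMA -> softmax) `iters` times.

    One shared tensor-core pipe and one shared softmax (MUFU) pipe, each serving one
    warpgroup at a time, with phases issued in program order.
    """
    tc_free = 0.0        # time the tensor cores become idle
    softmax_free = 0.0   # time the exp/softmax units become idle
    ready = [0.0] * num_warpgroups
    for _ in range(iters):
        for wg in range(num_warpgroups):
            mma_end = max(ready[wg], tc_free) + t_mma
            tc_free = mma_end
            softmax_end = max(mma_end, softmax_free) + t_softmax
            softmax_free = softmax_end
            ready[wg] = softmax_end
    return max(ready)


TILE_CYCLES = 1024  # MMA and exp cycles per 128 x 128 x 128 tile (roofline table)
total_tiles = 8
serial = simulate_pingpong(1, total_tiles, TILE_CYCLES, TILE_CYCLES)
pingpong = simulate_pingpong(2, total_tiles // 2, TILE_CYCLES, TILE_CYCLES)
tc_util = total_tiles * TILE_CYCLES / pingpong
print(serial, pingpong, f"{tc_util:.0%}")  # 16384.0 9216.0 89%
```

Run serially, 8 tiles cost 16384 cycles; with two warpgroups in pingpong the same work finishes in 9216, keeping the tensor cores busy about 89% of the time. Only the final softmax phase is left uncovered.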

TMEM partitioning

This is where the design gets intricate. TMEM must hold the intermediate results that bridge the MMA and softmax phases. FA4 allocates TMEM as follows:

S tiles (attention scores): two copies of S (one per warpgroup). MMA writes S here; softmax reads from here.
P tiles (softmax output): four copies of P. Softmax writes here; the PV MMA reads from here.
Remaining TMEM: rescale statistics communicated to a separate "correction" warpgroup.
Why P needs four copies. Because FA4 decouples the output rescaling from the critical path. When the softmax max value changes (requiring rescaling of all previously accumulated O values), the rescaling is offloaded to a separate correction warpgroup. To avoid blocking, the P tiles from consecutive iterations must coexist in TMEM — hence four slots in a double-buffered arrangement.

The softmax warpgroup's job

Each softmax warpgroup owns a 128-row tile of S, and each of its 128 threads processes one row of the attention matrix. On Blackwell's larger 128×128 tiles, this means each thread holds an entire row of 128 elements in registers. The softmax procedure per warpgroup:

1. Load S row from TMEM: 128 elements per thread into registers.
2. Compute row max: a thread-local reduction over the 128 columns, since each thread already holds its whole row.
3. Subtract max, exponentiate: S[i] = exp(S[i] - max), using MUFU plus software emulation.
4. Convert to input precision: FP32 → BF16 for the PV matmul.
5. Compute row sum: the denominator for the final normalization.
6. Store P to TMEM: ready for the next PV MMA.
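The six steps can be sketched for a single row in plain Python. This is an illustrative model under stated assumptions: the TMEM load and store (steps 1 and 6) are omitted, exp(x) is computed as $2^{x \log_2 e}$, the BF16 conversion uses round-to-nearest-even, and the row sum is accumulated from the FP32 values. `to_bf16` and `softmax_row` are hypothetical names.

```python
import math
import struct

LOG2E = math.log2(math.e)


def to_bf16(x: float) -> float:
    """Round a finite float to the nearest BF16 value (round-to-nearest-even on FP32 bits)."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    bits = (bits + 0x7FFF + ((bits >> 16) & 1)) & 0xFFFF0000
    return struct.unpack("<f", struct.pack("<I", bits))[0]


def softmax_row(s_row, softmax_scale):
    """Steps 2-5 for one row held entirely by one thread."""
    row_max = max(s_row)                                   # 2. thread-local row max
    p = [2.0 ** ((s - row_max) * softmax_scale * LOG2E)    # 3. exp(x) computed as 2^(x*log2 e)
         for s in s_row]
    p_bf16 = [to_bf16(v) for v in p]                       # 4. FP32 -> BF16 for the PV MMA
    row_sum = sum(p)                                       # 5. row sum, kept in FP32 here
    return p_bf16, row_max, row_sum


p_bf16, row_max, row_sum = softmax_row([0.5, -1.0, 2.0, 1.25], softmax_scale=0.125)
```

The largest entry of each row always maps to exactly 1.0, so the BF16 rounding only affects the smaller probabilities.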

Register pressure

For BF16 inputs, each thread needs: 128 registers for the S row input, 64 registers for output staging, plus miscellaneous temporaries. To reduce register pressure, FA4 stages the output P in quarters: the first three quarters are stored once (triggering the corresponding MMA operations), and the last quarter is stored separately. This careful staging keeps register usage within the 256-register-per-thread limit.

Figure: Pingpong pipeline simulation. Two warpgroups alternate between MMA (tensor core) and softmax (MUFU + FMA), so one warpgroup's MMA overlaps with the other's softmax.
In FA4's pingpong schedule, why must the two softmax warpgroups not overlap their exponential computation?

Chapter 4: Software-Emulated Exponential

The MUFU (Multi-Function Unit) computes exp() at 16 operations per clock per SM. For a 128×128 tile, that's 128×128 / 16 = 1024 cycles — matching the MMA compute time exactly. This means even with perfect pingpong overlap, the exponential is right on the edge of being a bottleneck. Any small inefficiency (pipeline bubbles, synchronization) tips it over.

FA4's solution is radical: compute some of the exponentials on the FMA (fused multiply-add) units instead of MUFU. The FMA units are the same ones used for ordinary floating-point arithmetic, and they run in parallel with MUFU, so offloading part of the work to them raises the effective exponential throughput.

The mathematical decomposition

The kernel evaluates exp(x) as $2^{x \log_2 e}$, so it suffices to emulate $2^x$. The key identity is:

$2^x = 2^{\lfloor x \rfloor} \cdot 2^{x_{\text{frac}}}$

where $\lfloor x \rfloor$ is the integer part (x rounded down) and $x_{\text{frac}} = x - \lfloor x \rfloor \in [0, 1)$ is the fractional part. The integer part $2^{\lfloor x \rfloor}$ can be computed without any floating-point arithmetic: it's just a bit manipulation on the IEEE 754 exponent field. Shift $\lfloor x \rfloor$ into the exponent bits and you have your answer.

The fractional part $2^{x_{\text{frac}}}$ is approximated by a polynomial:

$2^{x_{\text{frac}}} \approx p_0 + p_1 x_{\text{frac}} + p_2 x_{\text{frac}}^2 + \dots + p_n x_{\text{frac}}^n$

with $p_0 = 1.0$ and the remaining coefficients chosen to minimize relative approximation error. This polynomial is evaluated using Horner's method on the FMA units: each term is one fused multiply-add.

The algorithm step by step

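1. Clamp x: clamp x to be at least -127 to avoid underflow.
2. Compute $\lfloor x \rfloor$: in round-down mode, add $2^{23} + 2^{22}$ to x, then subtract it back. At that magnitude the spacing between FP32 values is exactly 1, so the addition discards the fractional bits and rounds x down to an integer.
3. Compute $x_{\text{frac}} = x - \lfloor x \rfloor$: a simple subtraction; now $x_{\text{frac}} \in [0, 1)$.
4. Evaluate the polynomial: Horner's method, $((p_n x_{\text{frac}} + p_{n-1}) x_{\text{frac}} + \dots) x_{\text{frac}} + p_0$, using FMA instructions.
5. Combine integer and fractional parts: shift $\lfloor x \rfloor$ into the exponent field (left by 23 bits) and add it to the bit pattern of $2^{x_{\text{frac}}}$.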

Accuracy: is it good enough?

The authors benchmark polynomial approximations of different degrees against the hardware MUFU.EX2 instruction:

| Method | FP32 max rel. error | BF16 max rel. error |
| --- | --- | --- |
| Hardware MUFU.EX2 | $1.41 \times 10^{-7}$ | $3.89 \times 10^{-3}$ |
| Degree 3 polynomial | $8.77 \times 10^{-5}$ | $3.90 \times 10^{-3}$ |
| Degree 4 polynomial | $3.05 \times 10^{-6}$ | $3.89 \times 10^{-3}$ |
| Degree 5 polynomial | $1.44 \times 10^{-7}$ | $3.89 \times 10^{-3}$ |
The punchline: At BF16 precision, the polynomial approximation error is indistinguishable from the hardware. The BF16 quantization error ($\approx 3.9 \times 10^{-3}$) dominates the polynomial approximation error for all degrees ≥ 3. A degree-3 polynomial matches hardware to within 1 BF16 ULP on 99% of inputs. This is sufficient for attention computation where the softmax output is consumed in BF16 precision.
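The claim that BF16 rounding swamps the polynomial error can be checked in a few lines. This sketch uses its own degree-3 coefficients (an illustrative fit with $p_0 = 1$, not necessarily the ones FA4 ships), evaluates in Python double precision rather than FP32, and uses `math.ldexp` in place of the exponent-field trick. It reports the worst relative error over $x \in [-8, 0]$ and the worst disagreement, in BF16 ulps, after rounding both results to BF16. Helper names are hypothetical.

```python
import math
import struct

# Illustrative degree-3 coefficients for 2^f on [0, 1) with p0 = 1 (ordered p0..p3).
COEFFS = (1.0, 0.695146143436431884765625, 0.227564394474029541015625,
          0.077119089663028717041015625)


def poly_exp2(x: float) -> float:
    """2^x as 2^floor(x) * P(frac), with P evaluated by Horner's method."""
    xi = math.floor(x)
    f = x - xi
    acc = 0.0
    for c in reversed(COEFFS):      # ((p3*f + p2)*f + p1)*f + p0
        acc = acc * f + c
    return math.ldexp(acc, xi)      # exact multiply by 2^xi


def to_bf16(x: float) -> float:
    """Round a finite float to BF16 (round-to-nearest-even on the FP32 bit pattern)."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    bits = (bits + 0x7FFF + ((bits >> 16) & 1)) & 0xFFFF0000
    return struct.unpack("<f", struct.pack("<I", bits))[0]


def bf16_ulp(v: float) -> float:
    """Spacing between adjacent BF16 values (8 significant bits) near v > 0."""
    _, e = math.frexp(v)            # v = m * 2^e with m in [0.5, 1)
    return 2.0 ** (e - 8)


xs = [-8.0 + i / 1000 for i in range(8001)]
max_rel_err = max(abs(poly_exp2(x) - 2.0 ** x) / 2.0 ** x for x in xs)
max_ulp_diff = max(abs(to_bf16(poly_exp2(x)) - to_bf16(2.0 ** x)) / bf16_ulp(to_bf16(2.0 ** x))
                   for x in xs)
print(f"max relative error: {max_rel_err:.2e}")
print(f"max BF16 disagreement: {max_ulp_diff} ulp")
```

With these coefficients the relative error stays around $10^{-4}$, far below BF16's roughly $3.9 \times 10^{-3}$, so the rounded results never differ by more than one BF16 ulp.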

Partial emulation

Computing ALL exponentials via polynomial emulation would be wasteful — the additional registers for intermediate values and coefficients would increase register pressure and could cause spills. Instead, FA4 uses partial emulation: only 10–25% of the entries in each softmax row are emulated via polynomial, with the remaining entries computed via hardware MUFU.EX2. The exact fraction is tuned empirically based on the ratio of MMA and exponential throughput for a given tile configuration.

This distributes the exponential computation across both MUFU and FMA units, alleviating the bottleneck while keeping register pressure manageable.
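A minimal sketch of partial emulation, assuming a simple pattern in which every k-th entry of a row goes to the polynomial path and the rest to the hardware path. Here Python's `2.0 ** x` stands in for MUFU.EX2, the degree-3 coefficients are illustrative, and `exp2_row_mixed` and `emulate_every` are hypothetical names; how FA4 actually picks the emulated entries is not described above.

```python
import math

# Illustrative degree-3 coefficients for 2^f on [0, 1), ordered p0..p3.
COEFFS = (1.0, 0.695146143436431884765625, 0.227564394474029541015625,
          0.077119089663028717041015625)


def poly_exp2(x: float) -> float:
    """FMA-path stand-in: 2^floor(x) * P(frac) with Horner's method."""
    xi = math.floor(x)
    f = x - xi
    acc = 0.0
    for c in reversed(COEFFS):
        acc = acc * f + c
    return math.ldexp(acc, xi)


def exp2_row_mixed(xs, emulate_every=4):
    """Exponentiate one softmax row, emulating every `emulate_every`-th entry.

    emulate_every=4 emulates 25% of entries and emulate_every=10 about 10%,
    the range quoted above; the remaining entries take the 'hardware' path.
    """
    out, n_emulated = [], 0
    for i, x in enumerate(xs):
        if i % emulate_every == 0:
            out.append(poly_exp2(x))   # FMA units (software emulation)
            n_emulated += 1
        else:
            out.append(2.0 ** x)       # stand-in for MUFU.EX2
    return out, n_emulated


row = [-(i % 17) * 0.37 for i in range(128)]   # 128 values of S - max (all <= 0)
vals, n_emu = exp2_row_mixed(row, emulate_every=4)
print(n_emu, n_emu / len(row))                 # 32 0.25
```

Because both paths agree to well within BF16 precision, which entries are emulated affects only load balance between MUFU and the FMA units, not the result seen by the PV matmul.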

python
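# Conceptual software-emulated 2^x (simplified, plain Python)
import math
import struct

def fast_exp2(x: float) -> float:
    # Step 1: Clamp to avoid underflow (assumes x < 128; in softmax, x <= 0)
    x = max(x, -127.0)

    # Step 2: Integer part, rounded DOWN (on the GPU: add a magic number in round-down mode).
    # math.floor, not int() (which truncates toward zero), keeps x_frac in [0, 1) for x < 0.
    x_floor = math.floor(x)

    # Step 3: Fractional part
    x_frac = x - x_floor   # x_frac in [0, 1)

    # Step 4: Degree-3 polynomial for 2^x_frac (Horner's method, one FMA per term).
    # Illustrative rounded coefficients, not a minimax fit; a fit generated with a tool
    # such as Sollya to minimize relative error would be more accurate.
    p3, p2, p1, p0 = 0.07944, 0.2274, 0.6931, 1.0
    frac_part = ((p3 * x_frac + p2) * x_frac + p1) * x_frac + p0   # in [1, 2)

    # Step 5: Combine via IEEE 754 exponent manipulation: adding x_floor << 23 to the
    # FP32 bit pattern of frac_part multiplies it by 2^x_floor without a float multiply.
    bits = struct.unpack("<I", struct.pack("<f", frac_part))[0]
    bits = (bits + (x_floor << 23)) & 0xFFFFFFFF
    return struct.unpack("<f", struct.pack("<I", bits))[0]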
Why does a degree-3 polynomial approximation of $2^x$ have essentially the same accuracy as the hardware MUFU.EX2 when used for BF16 attention?

Chapter 5: Conditional Softmax Rescaling

FlashAttention computes attention in blocks. Because we process K/V blocks one at a time, we maintain running statistics for the online softmax: a running maximum $m_j$ and a running sum of exponentials $l_j$. When a new block has a larger maximum than the running max, we must rescale all previously accumulated output, multiplying O by $e^{m_{j-1} - m_j}$.

The online softmax recap

When processing block j, FA computes:

$m_j = \max(m_{j-1}, \mathrm{rowmax}(S_j))$
$l_j = e^{m_{j-1} - m_j}\, l_{j-1} + \mathrm{rowsum}(e^{S_j - m_j})$

The intermediate output is updated as:

$O_j = e^{m_{j-1} - m_j}\, O_{j-1} + e^{S_j - m_j} V_j$

That rescaling factor $e^{m_{j-1} - m_j}$ is one scalar per query row, but it must be applied to every element of that row of the accumulated O. On Blackwell, where non-matmul operations are the bottleneck, this rescaling is expensive.

FA4's insight: skip rescaling when it's not needed

FA4 makes two key observations:

Observation 1: Rescaling is only necessary when $m_j > m_{j-1}$, i.e., when the current block contains larger attention scores than all previous blocks. In practice, the running maximum stabilizes quickly: after the first few blocks, most iterations don't need rescaling at all.

Observation 2: We can tolerate some "slack" in the rescaling. If $m_j - m_{j-1} > \tau$ (a threshold), we must rescale. But if the difference is small ($m_j - m_{j-1} \le \tau$), we can skip rescaling and keep using $m_{j-1}$ as the reference max. The unnormalized values then grow by at most a bounded factor set by τ, and because O and l share the same stale reference, the final normalization cancels it.

The threshold τ is typically set to $\log_2(256) = 8.0$, corresponding to a rescaling factor of 256.0 (FA4 evaluates these exponentials in base 2, so a max difference of 8 means a factor of $2^8$). As long as O and l are updated consistently, the final output is exact:

$$O_j = \begin{cases} e^{m_{j-1} - m_j}\, O_{j-1} + e^{S_j - m_j} V_j & \text{if } m_j - m_{j-1} > \tau \\ O_{j-1} + e^{S_j - m_{j-1}} V_j & \text{otherwise} \end{cases}$$

At the very end:

$\text{Output} = \dfrac{1}{l_{\text{final}}}\, O_{\text{final}}$
Why this works: The online softmax algorithm is fundamentally about maintaining the correct numerator-denominator ratio. As long as O and l are always expressed relative to the same reference max, we can delay rescaling for as long as the values stay in range, which is exactly what the threshold τ guarantees. The conditional check simply avoids doing expensive vector multiplications on iterations where the max hasn't changed significantly, and the final division by $l_{\text{final}}$ cancels whatever reference max was in use.
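The argument can be checked numerically. The sketch below processes one query row block by block with conditional rescaling and compares the result with an ordinary softmax over the whole row. It assumes scores are already in log2 units (so exponentials are base 2 and τ = 8 means a factor of 256), keeps everything in plain Python lists, and uses hypothetical function names; it illustrates the math, not FA4's kernel code.

```python
import math
import random


def attend_row_blocked(score_blocks, v_blocks, tau=8.0):
    """Online softmax for one query row with conditional rescaling.

    Returns (output vector, number of rescales after the first block).
    """
    d = len(v_blocks[0][0])
    m = -math.inf        # reference max used for every exponential
    l = 0.0              # running denominator, relative to m
    o = [0.0] * d        # running unnormalized output, relative to m
    rescales = 0
    for s_blk, v_blk in zip(score_blocks, v_blocks):
        m_cand = max(m, max(s_blk))
        if m_cand - m > tau:                 # max grew too much: rescale O and l
            alpha = 2.0 ** (m - m_cand)      # 0.0 on the first block (m = -inf)
            if m != -math.inf:
                rescales += 1
            l *= alpha
            o = [alpha * x for x in o]
            m = m_cand
        # Otherwise keep the stale m; every 2**(s - m) is then at most 2**tau.
        p = [2.0 ** (s - m) for s in s_blk]
        l += sum(p)
        for p_i, v_row in zip(p, v_blk):
            o = [o_c + p_i * v_c for o_c, v_c in zip(o, v_row)]
    return [o_c / l for o_c in o], rescales   # final normalization by l


def attend_row_reference(score_blocks, v_blocks):
    """Plain softmax(s) @ V over the whole row, for comparison."""
    s = [x for blk in score_blocks for x in blk]
    v = [row for blk in v_blocks for row in blk]
    m = max(s)
    p = [2.0 ** (x - m) for x in s]
    total = sum(p)
    return [sum(p_i * row[c] for p_i, row in zip(p, v)) / total for c in range(len(v[0]))]


rng = random.Random(0)
scores = [[rng.uniform(-20.0, 20.0) for _ in range(4)] for _ in range(6)]
values = [[[rng.uniform(-1.0, 1.0) for _ in range(3)] for _ in range(4)] for _ in range(6)]
out, n_rescales = attend_row_blocked(scores, values, tau=8.0)
ref = attend_row_reference(scores, values)
print(max(abs(a - b) for a, b in zip(out, ref)), n_rescales)
```

Setting `tau=0.0` recovers the standard online softmax, which rescales every time the max grows; both settings produce the same output up to rounding.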

Decoupling rescaling from the critical path

Even when rescaling IS needed, FA4 moves it off the critical path. Since accumulators live in TMEM (not registers), rescaling can be performed by a separate "correction" warpgroup that reads the accumulator from TMEM, multiplies by the rescaling factor, and writes it back — all while the main warpgroups continue with the next iteration's MMA and softmax. This decoupling is impossible on Hopper, where accumulators live in registers that are private to each thread.

To avoid warp divergence, FA4 rescales whenever any thread in the warpgroup needs rescaling. This is a conservative choice that trades a small amount of unnecessary rescaling for simpler control flow.

How does FA4's conditional softmax rescaling reduce non-matmul overhead?

Chapter 6: The Backward Pass

The backward pass is the harder half. It performs five MMAs per iteration (vs two in forward), recomputing S on the fly, and must accumulate gradients dQ, dK, dV. On Blackwell, shared memory traffic is the dominant bottleneck, exceeding MMA compute by 30% in 1-CTA mode.

The five MMAs

Given the output gradient $dO \in \mathbb{R}^{N \times d}$, the backward pass computes:

$dV = P^T dO, \quad dP = dO\, V^T$
$dS = \mathrm{dsoftmax}(dP)$
$dQ = \alpha\, dS\, K, \quad dK = \alpha\, dS^T Q$

where α is the softmax scale and dsoftmax(dP) denotes the row-wise softmax gradient: for each row i, $dS_i = (\mathrm{diag}(p_i) - p_i p_i^T)\, dP_i$, where $p_i$ is row i of $P = \mathrm{softmax}(S)$. This requires recomputing $S = QK^T$, since it was not stored in the forward pass (that's the whole point of FlashAttention). So the five MMAs are:

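| MMA | Operand A source | Operand B source | Type |
| --- | --- | --- | --- |
| $S^T = KQ^T$ | SMEM (K) | SMEM (Q) | Shared-shared |
| $dP^T = V\, dO^T$ | TMEM (V) | SMEM (dO) | Tensor-shared |
| $dV = P^T dO$ | TMEM (P) | SMEM (dO) | Tensor-shared |
| $dQ = dS\, K$ | SMEM (dS) | SMEM (K) | Shared-shared |
| $dK = dS^T Q$ | SMEM (dS) | SMEM (Q) | Shared-shared |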

Three of the five ($S^T$, dQ, dK) are shared-shared: both operands are read from SMEM. This is why shared memory bandwidth dominates the backward pass.
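The elementwise softmax-gradient step does not need the N × N Jacobian. Expanding $(\mathrm{diag}(p_i) - p_i p_i^T)\, dP_i$ gives the elementwise form $dS_{ij} = P_{ij}\,(dP_{ij} - \sum_k P_{ik}\, dP_{ik})$, which is what keeps this step cheap. The sketch below checks the identity for one row in plain Python; the function names are hypothetical.

```python
import math
import random


def softmax(s):
    m = max(s)
    e = [math.exp(x - m) for x in s]
    z = sum(e)
    return [x / z for x in e]


def dsoftmax_jacobian(p, dp):
    """dS = (diag(p) - p p^T) dp, using the explicit N x N Jacobian."""
    n = len(p)
    return [sum(((p[i] if i == j else 0.0) - p[i] * p[j]) * dp[j] for j in range(n))
            for i in range(n)]


def dsoftmax_elementwise(p, dp):
    """Same result in O(N): dS_i = p_i * (dp_i - sum_k p_k dp_k)."""
    dot = sum(pk * dpk for pk, dpk in zip(p, dp))
    return [pi * (dpi - dot) for pi, dpi in zip(p, dp)]


rng = random.Random(1)
p = softmax([rng.gauss(0.0, 2.0) for _ in range(16)])
dp = [rng.gauss(0.0, 1.0) for _ in range(16)]
diff = max(abs(a - b) for a, b in zip(dsoftmax_jacobian(p, dp), dsoftmax_elementwise(p, dp)))
print(diff)  # the two forms agree up to floating-point rounding
```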

1-CTA pipeline

In 1-CTA mode (M = 128), the pipeline follows the same pingpong principle as the forward pass, but with 5 MMAs + 2 elementwise operations (softmax and its gradient). The pipeline stages across prologue, main loop, and tail:

Prologue: load the first K, V tiles. Compute the initial $S = QK^T$, then $dP = dO\, V^T$.
Main loop (j = 1..B blocks): interleave loading the next K, V tiles, recomputing S, the elementwise softmax and its gradient, and the dP, dV, dQ, dK MMAs. The dQ and dK MMAs of the current tile are pipelined with loading/computing the next tile.
Tail: final dK and dQ for the last tile. Write dQ atomically to global memory.

2-CTA mode: halving shared memory traffic

Even with the 1-CTA pipeline, shared memory is still the bottleneck. FA4 addresses this with Blackwell's 2-CTA MMA mode, using an MMA tile shape of M = 256 and N = K = 128.

The two CTAs communicate via Distributed Shared Memory (DSMEM) — a Blackwell feature allowing CTAs in the same cluster to access each other's SMEM. FA4 uses DSMEM to exchange half of the dS tile between the two CTAs, since each CTA needs the full dS for its MMA but computes only half of it.

Deterministic backward pass

The dQ gradient accumulates via atomic additions to global memory, which are nondeterministic (the order of atomic operations varies between runs). For reproducible training in reinforcement learning, FA4 provides a deterministic mode using semaphore locks: each CTA acquires a lock before writing to a common dQ tile, performs its reduction, and releases the lock. Combined with careful CTA swizzling (processing tiles in an order that minimizes stalls), the deterministic backward pass achieves up to 75% of the nondeterministic version's speed.
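Why atomics break reproducibility comes down to floating-point addition not being associative. The sketch below uses hypothetical per-CTA dQ contributions to a single element, chosen so the effect is visible, and adds them in FP64 for simplicity (GPU dQ atomics would be FP32, which behaves the same way). Summing in every possible arrival order, as unordered atomics might, gives several different results; a fixed order, which is what serializing the CTAs through a lock in a set order provides, always gives the same one. It models the semaphore protocol only as "add in a fixed order".

```python
import itertools

# Hypothetical dQ contributions from four CTAs to the same element.
contributions = [1e16, 1.0, -1e16, 1.0]


def accumulate(order):
    """Add the contributions in the given CTA order, one floating-point add at a time."""
    total = 0.0
    for cta in order:
        total += contributions[cta]
    return total


# Unordered atomics: any arrival order can happen, so the result depends on timing.
atomic_results = {accumulate(order) for order in itertools.permutations(range(4))}
# Lock-ordered reduction: CTAs always add in the same (here: index) order.
fixed_results = {accumulate(range(4)) for _ in range(10)}
print(sorted(atomic_results), fixed_results)
```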

Why does FA4 use Blackwell's 2-CTA MMA mode for the backward pass?

Chapter 7: CuTe-DSL — A New Way to Write Kernels

Every previous FlashAttention version was written in CUDA C++ using CUTLASS templates — a maze of deeply nested C++ template metaprogramming that took 55 seconds to compile a single kernel. FA4 takes a radically different approach: it is written entirely in CuTe-DSL, a domain-specific language embedded in Python.

What is CuTe-DSL?

CuTe-DSL is the Python frontend for CuTe, the layout and tensor abstraction layer at the heart of NVIDIA's CUTLASS library. It provides:

Layout algebra: a compact representation of how data maps from logical coordinates to physical memory. Layouts compose, transform, and tile declaratively.
Tensor abstraction: a Tensor is a pointer + Layout. Operations on tensors automatically handle strides, padding, and alignment.
TiledMMA / TiledCopy: high-level abstractions for matrix multiply-accumulate and memory copy operations. They specify the hardware instruction (for example, WGMMA on Hopper or tcgen05.mma on Blackwell), the thread-to-data mapping, and the tile shape.
Python → PTX → SASS: JIT compilation. Python source → CuTe-DSL compiler → PTX assembly → ptxas → final GPU binary (SASS).

Why Python instead of C++?

The CuTe-DSL programming model is isomorphic to CUTLASS C++ — every C++ CUTLASS construct has a direct CuTe-DSL equivalent. You lose nothing in expressiveness. But you gain enormously in productivity:

| Property | FA3 (C++ CUTLASS) | FA4 (CuTe-DSL Python) |
| --- | --- | --- |
| Forward compile time | 55 seconds | 2.5 seconds |
| Backward compile time | 45 seconds | 1.4 seconds |
| Speedup | 1x | 22–32x |
| Iteration speed | Minutes per experiment | Seconds per experiment |
| Prerequisite expertise | Deep C++ template metaprogramming | Python + GPU concepts |
20–30x faster compile is not just a convenience. Kernel development is inherently iterative: you change a tile size, recompile, benchmark, adjust. At 55 seconds per compile, you get ~65 experiments per hour (counting compile time alone); at 2.5 seconds, ~1440. This roughly 22x increase in iteration speed translates directly into better-optimized kernels, because you can explore more of the design space in the same wall-clock time.
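The speedup and throughput figures are simple arithmetic on the compile times in the table; spelled out (compile time only, ignoring benchmarking time), with illustrative variable names:

```python
# (FA3 C++ CUTLASS seconds, FA4 CuTe-DSL seconds), from the table above
compile_seconds = {"forward": (55.0, 2.5), "backward": (45.0, 1.4)}

speedups = {name: cpp / dsl for name, (cpp, dsl) in compile_seconds.items()}
for name, s in speedups.items():
    print(f"{name}: {s:.0f}x faster")   # forward: 22x faster, backward: 32x faster

compiles_per_hour_fa3 = int(3600 // 55.0)   # 65
compiles_per_hour_fa4 = int(3600 // 2.5)    # 1440
print(compiles_per_hour_fa3, compiles_per_hour_fa4)
```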

Escape hatch to PTX

CuTe-DSL provides direct access to PTX (GPU assembly) as an escape hatch. FA4 uses custom PTX sequences for operations not yet fully exposed in the CuTe-DSL API, demonstrating that the framework doesn't constrain developers to a limited subset of GPU capabilities. This is critical for a performance-critical kernel that needs every last drop of throughput.

Composable primitives

FA4's codebase is designed as a library of composable primitives rather than a monolithic kernel.

Developers have already built FlexAttention and block-sparse attention variants on top of FA4 without modifying the core framework. This modularity is a direct benefit of the CuTe-DSL approach — Python's flexibility makes it natural to compose orthogonal features.

python
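# Conceptual structure of one FA4 main-loop iteration (simplified), in plain Python.
# NOTE: this is not CuTe-DSL syntax. `tiled_mma` is a hypothetical stand-in for a
# tensor-core MMA (tcgen05.mma on Blackwell); tiles are plain lists of lists.

# Tile shapes on Blackwell (M x N attention tile, head dim d); the toy code accepts any size
BM, BN, BD = 128, 128, 128
TAU = 8.0  # rescale threshold in log2 units (factor 2^8 = 256)

def tiled_mma(a, b):
    """Stand-in for a tensor-core MMA: returns the matrix product a @ b."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]

# Main loop body (one iteration). Scores are assumed pre-scaled into log2 units;
# m starts at -inf and l at 0 for every row.
def attention_step(q_smem, k_smem, v_smem, o_tmem, m, l):
    # 1. S = Q K^T (SS MMA on hardware; S is written to TMEM)
    k_t = [list(col) for col in zip(*k_smem)]
    s = tiled_mma(q_smem, k_t)

    # 2. Online softmax per row (reads S, writes P), with conditional rescaling
    o_tmem = [row[:] for row in o_tmem]
    m_new, l_new, p = list(m), list(l), []
    for r, s_row in enumerate(s):
        row_max = max(m_new[r], max(s_row))
        if row_max - m_new[r] > TAU:            # rescale only when the max grew a lot
            alpha = 2.0 ** (m_new[r] - row_max)
            o_tmem[r] = [alpha * x for x in o_tmem[r]]
            l_new[r] *= alpha
            m_new[r] = row_max
        p_row = [2.0 ** (x - m_new[r]) for x in s_row]  # MUFU + FMA emulation on hardware
        l_new[r] += sum(p_row)
        p.append(p_row)

    # 3. O += P V (TS MMA on hardware: P from TMEM, V from SMEM, O accumulates in TMEM)
    pv = tiled_mma(p, v_smem)
    o_tmem = [[o + x for o, x in zip(o_row, pv_row)] for o_row, pv_row in zip(o_tmem, pv)]
    return o_tmem, m_new, l_new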
What concrete advantage does CuTe-DSL provide over C++ CUTLASS templates for FA4's development?

Chapter 8: Benchmarks — Measuring the Payoff

FA4 is benchmarked on a B200 180GB SXM6 (1000W) GPU with CUDA 13.1, PyTorch 2.10.0, and CuTe-DSL 4.4.1. Baselines include cuDNN 9.13.0, cuDNN 9.19.1.2, Triton 3.6, and Gluon (a lower-level GPU language). FA3 does not run on Blackwell because it relies on Hopper-specific MMA instructions.

Benchmark settings

The setup: BF16 inputs, with and without a causal mask, and head dimensions 64, 128, and (192, 128) for the DeepSeek V3 configuration. Sequence lengths range from 1K to 32K, with total tokens fixed at 32K (batch size adjusts accordingly). Hidden dimension 2048, giving 32 heads (d = 64), 16 heads (d = 128), or 16 heads with query/key head dimension 192 and value head dimension 128.

Forward pass results

Figure: Forward pass TFLOPS comparison of FA4 vs cuDNN and Triton on B200 for causal attention with head dimension 128; FA4 outperforms across all sequence lengths.
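| Seq length | FA4 (TFLOPS) | cuDNN 9.13 (TFLOPS) | Triton (TFLOPS) | FA4 vs cuDNN | FA4 vs Triton |
| --- | --- | --- | --- | --- | --- |
| 1K | ~770 | ~700 | ~280 | 1.1x | 2.7x |
| 2K | ~1000 | ~900 | ~420 | 1.1x | 2.4x |
| 4K | ~1250 | ~1130 | ~540 | 1.1x | 2.3x |
| 8K | ~1420 | ~1280 | ~600 | 1.1x | 2.4x |
| 16K | ~1540 | ~1400 | ~640 | 1.1x | 2.4x |
| 32K | ~1610 | ~1400 | ~570 | 1.15x | 2.8x |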

Peak throughput: 1613 TFLOPS, approximately 71% of the B200's theoretical 2.25 PFLOPS. The gains are larger for causal attention, which FA4's LPT scheduler handles particularly well.

For DeepSeek V3 configuration (192, 128)

With the DeepSeek V3 head dimensions (query/key head dimension 192, value head dimension 128, 16 heads), FA4 matches or beats cuDNN 9.13:

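| Seq length | FA4 (TFLOPS) | cuDNN 9.13 (TFLOPS) | cuDNN 9.19.1.2 (TFLOPS) | FA4 vs cuDNN 9.13 |
| --- | --- | --- | --- | --- |
| 1K | ~770 | ~770 | ~810 | 1.0x |
| 4K | ~1260 | ~1130 | ~1120 | 1.1x |
| 8K | ~1420 | ~1260 | ~1310 | 1.1x |
| 16K | ~1550 | ~1420 | ~1540 | 1.1x |
| 32K | ~1610 | ~1400 | ~1560 | 1.15x |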
Note: Newer cuDNN versions (9.19.1.2) have incorporated some FA4 techniques, narrowing the gap. The FA4 authors explicitly worked with the cuDNN team to upstream their innovations. This is open research at its best — the improvements benefit everyone.

Backward pass results

The backward pass shows larger gains over baselines, especially for causal attention:

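| Seq length | FA4 (TFLOPS) | cuDNN 9.13 (TFLOPS) | FA4 vs cuDNN |
| --- | --- | --- | --- |
| 1K (non-causal) | ~580 | ~520 | 1.1x |
| 8K (non-causal) | ~1350 | ~1170 | 1.15x |
| 32K (non-causal) | ~1450 | ~1310 | 1.1x |
| 8K (causal) | ~1200 | ~860 | 1.4x |
| 32K (causal) | ~1400 | ~1200 | 1.17x |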

Scheduling ablations

FA4's LPT (Longest-Processing-Time-first) scheduler provides a measurable 4–8% FLOPS gain for MHA and 7–14% for MQA 8 at head dimension 128, by ensuring that the longest-running worktiles are scheduled first, reducing tail-end imbalance across SMs.
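A toy model shows why longest-first ordering helps. It assumes causal attention where query tile i needs i + 1 units of work (its number of unmasked K/V blocks) and a greedy scheduler that hands each tile to whichever SM frees up first; the function name and the 8-tile, 3-SM sizes are illustrative, not FA4's scheduler.

```python
import heapq


def makespan(tile_work, num_sms):
    """Greedy list scheduling: each tile goes to the SM that becomes free first."""
    finish_times = [0] * num_sms          # a list of zeros is already a valid heap
    for w in tile_work:
        heapq.heappush(finish_times, heapq.heappop(finish_times) + w)
    return max(finish_times)


causal_work = [i + 1 for i in range(8)]                           # tile i sees i + 1 K/V blocks
in_order = makespan(causal_work, num_sms=3)                       # 15
lpt = makespan(sorted(causal_work, reverse=True), num_sms=3)      # 13
ideal = sum(causal_work) / 3                                      # 12.0
print(in_order, lpt, ideal)
```

Scheduling the long tiles first leaves only short tiles to fill the gaps at the end, so the last SM finishes closer to the ideal balanced time.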

For the deterministic backward pass, the SPT (Shortest-Processing-Time-first) scheduler with CTA swizzling achieves up to 950 TFLOPS for causal attention at 32K sequence length — 75% of the nondeterministic version's speed. This is fast enough for practical use in RL training where reproducibility is essential.

What is FA4's peak throughput on B200, and what fraction of theoretical maximum does this represent?

Chapter 9: Connections

FlashAttention-4 sits at the intersection of hardware-aware algorithm design, GPU kernel engineering, and ML systems research. Here are the key threads to follow.

The FlashAttention lineage

FlashAttention (Dao et al., 2022)
The original. Tiled attention with online softmax to avoid materializing the N×N attention matrix. Key insight: recompute in backward pass instead of storing. Targeted A100 with IO-awareness. arXiv
FlashAttention-2 (Dao, 2023)
Reduced non-matmul FLOPs, parallelized over sequence length dimension. Improved GPU occupancy. arXiv
FlashAttention-3 (Shah et al., 2024)
Exploited Hopper's async features: warp specialization, TMA for async data movement, FP8 support. First pingpong schedule. arXiv
FlashAttention-4 (Zadouri et al., 2026)
Co-designed for Blackwell's asymmetric scaling: TMEM-based pipeline, software exponential, conditional rescaling, 2-CTA MMA, CuTe-DSL. arXiv

Related hardware-aware attention

| Work | Key idea | Relation to FA4 |
| --- | --- | --- |
| SageAttention | INT8 quantized attention | Orthogonal precision approach; FA4 targets BF16/FP8 on datacenter GPUs |
| SageAttention2 | INT4/FP8 with outlier smoothing | Lower precision frontier; FA4's polynomial exp could benefit these too |
| SageAttention3 | FP4 on Blackwell consumer GPUs | Targets consumer Blackwell; FA4 targets datacenter B200/GB200 |
| Triton | Python GPU compiler with tiling | FA4 outperforms by 2.1–2.7x due to hand-tuned pipeline scheduling |
| Gluon | Lower-level GPU programming language | FA4 also outperforms Gluon across most configurations |

Key takeaways

1. Hardware co-design is mandatory. The same algorithm runs at vastly different efficiency on different GPUs. FA1 was designed for A100's memory hierarchy. FA3 was designed for H100's async primitives. FA4 is designed for B200's asymmetric compute-to-memory ratio. There is no "universal fast attention" — only attention that is exquisitely tuned to its target hardware.
2. The bottleneck keeps shifting. V100: compute limited. A100: memory bandwidth limited. H100: compute limited again (but with async overlap). B200: non-matmul limited (exp, shared memory). Each generation, the kernel designer must re-identify the bottleneck and restructure the entire pipeline around it.
3. Software can substitute for hardware. FA4's polynomial exponential emulation on FMA units is a beautiful example: when the hardware special-function unit is too slow, approximate it in software on the general-purpose units. The key insight is that BF16 precision masks the approximation error.
4. Open source wins. FA4 is open source, its techniques have been upstreamed to cuDNN, and the CuTe-DSL framework lowers the barrier for future kernel innovation. The code is available at github.com/Dao-AILab/flash-attention.

Open questions

What is the overarching lesson from the FA1 → FA4 evolution?