Austin et al., Part 4

All the Transformer Math You Need to Know

Parameter counting, FLOPs per token, memory footprints, KV cache sizing, and MFU — everything you need to predict how a Transformer will behave on real hardware.

Prerequisites: Chapters 1–3. Basic familiarity with the Transformer architecture (attention, MLP, residuals).

Chapter 0: Counting FLOPs

Before we can do any roofline analysis on Transformers, we need to count FLOPs precisely. Let's build from the ground up.

| Operation | Shapes | FLOPs | Data (bytes, bf16) |
| --- | --- | --- | --- |
| Dot product | x[P] · y[P] | 2P | 4P + 2 |
| Matrix-vector | A[N,P] · x[P] | 2NP | 2NP + 2P + 2N |
| Matrix-matrix | A[N,P] · B[P,M] | 2NPM | 2NP + 2PM + 2NM |

The dot product does P multiplies and P−1 adds ≈ 2P FLOPs. Matrix-vector is N dot products: 2NP. Matrix-matrix is M such products: 2NPM.

The general rule for einsums

For two higher-dimensional arrays C[...] and D[...], some dimensions are contracting (summed over) and some are batching (shared between inputs and output). The FLOPs are:

FLOPs = 2 × (product of all unique dimensions)

where batch and contraction dims are counted only once (not double-counted from both inputs).
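As a rough sketch of this rule, the helper below (a hypothetical `einsum_flops`, not part of any library) takes each input's shape as a name-to-size mapping, so a dimension shared by both inputs is counted once. The sizes are illustrative only.

```python
from math import prod


def einsum_flops(dims_a: dict, dims_b: dict) -> int:
    """Estimate FLOPs for contracting two arrays given {dim_name: size} shapes."""
    for name in dims_a.keys() & dims_b.keys():
        # A dimension shared by both inputs (batch or contracting) must agree in size.
        if dims_a[name] != dims_b[name]:
            raise ValueError(f"size mismatch on shared dimension {name!r}")
    # Merging the dicts keeps each dimension name once, so nothing is double-counted.
    unique = {**dims_a, **dims_b}
    return 2 * prod(unique.values())


# The three basic operations, with small illustrative sizes.
dot_product = einsum_flops({"P": 5}, {"P": 5})              # 2P
mat_vec = einsum_flops({"N": 3, "P": 5}, {"P": 5})          # 2NP
mat_mat = einsum_flops({"N": 3, "P": 5}, {"P": 5, "M": 7})  # 2NPM
print(dot_product, mat_vec, mat_mat)
```

The same call works for higher-rank arrays: pass every named dimension of each input and the shared ones collapse automatically.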

The O(N³) vs O(N²) insight: For a square N×N matmul, FLOPs scale as 2N³ but data scales as 3×2N² bytes. Bigger matmuls have higher arithmetic intensity. This is why Transformers scale so well on hardware.
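A minimal sketch of that scaling, assuming bf16 (2 bytes per element) and counting two inputs plus one output; the function name is illustrative.

```python
def square_matmul_intensity(n: int, bytes_per_elem: int = 2) -> float:
    """FLOPs per byte moved for an N x N by N x N matmul."""
    flops = 2 * n**3
    # Two input matrices and one output matrix, each with n*n elements.
    data_bytes = 3 * bytes_per_elem * n**2
    return flops / data_bytes


for n in (128, 768, 4096):
    print(n, square_matmul_intensity(n))
```

Under these assumptions the intensity simplifies to N/3, so it grows linearly with the matrix size.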
A[I,J,K,L] · B[I,J,M,N,O] → C[K,L,M,N,O] where I,J are contracting. Total FLOPs?

Chapter 1: Forward & Backward FLOPs

During training, we care about the backward pass too. For a single matmul C = A · B where A[N,P] are activations and B[P,M] are weights:

Forward: C = A B — costs 2NPM FLOPs.

Backward: We need two gradients:

dL/dB = Aᵀ · (dL/dC): 2NPM FLOPs (contracts over N)
dL/dA = (dL/dC) · Bᵀ: 2NPM FLOPs (contracts over M)

Each gradient computation is itself a matmul of the same cost as the forward pass.

The 1:2 ratio: Forward = 2NPM. Backward = 4NPM (two matmuls). Total training = 6NPM per weight matrix. Since PM is the number of parameters and N is the number of tokens, this gives us the famous approximation: 6 × params × tokens total training FLOPs.

This 1:2 forward:backward ratio holds for every matmul in the network, so the backward pass costs 2x the forward pass in matmul FLOPs. This is a good sanity check for any training FLOPs estimate.
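The accounting for a single weight matrix can be sketched as follows. The function name and the sizes are illustrative, with N playing the role of the token count.

```python
def matmul_train_flops(n: int, p: int, m: int):
    """Return (forward, backward, total) FLOPs for C = A[N,P] . B[P,M]."""
    forward = 2 * n * p * m
    grad_weights = 2 * n * p * m  # dL/dB = A^T . dL/dC, contracts over N
    grad_inputs = 2 * n * p * m   # dL/dA = dL/dC . B^T, contracts over M
    backward = grad_weights + grad_inputs
    return forward, backward, forward + backward


tokens, p, m = 10, 3, 4
forward, backward, total = matmul_train_flops(tokens, p, m)
params = p * m
print(forward, backward, total, 6 * params * tokens)
```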

A model does 1e15 FLOPs in the forward pass. How many total FLOPs per training step?

Chapter 2: MLP Accounting

Modern Transformers use a gated MLP with three weight matrices per layer. Let B = batch, T = sequence length, D = model dimension, F = FFN dimension (typically 4D).

| Operation | Train FLOPs/layer | Params/layer |
| --- | --- | --- |
| A[B,T,D] · W_in1[D,F] | 6BTDF | DF |
| A[B,T,D] · W_in2[D,F] | 6BTDF | DF |
| σ(A_in1) × A_in2 | O(BTF), negligible | 0 |
| A[B,T,F] · W_out[F,D] | 6BTDF | DF |
| Total MLP | 18BTDF | 3DF |

The gating einsum (used by LLaMA, DeepSeek, and many others) splits the up-projection into two matrices whose outputs are element-wise multiplied. Some models use a single Win (2DF params) but scale D and F to compensate.

Note: For a standard model with F=4D, the MLP has 3×D×4D = 12D² parameters per layer. With D=4096, that is ~200M per layer.
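A small sketch of the gated-MLP accounting, using the D=4096, F=4D case from the note. The function names and the batch and sequence sizes are illustrative assumptions.

```python
def gated_mlp_params(d: int, f: int) -> int:
    """Parameters in one gated MLP: two up-projections and one down-projection."""
    return 3 * d * f


def gated_mlp_train_flops(b: int, t: int, d: int, f: int) -> int:
    """Training FLOPs for one gated MLP layer, ignoring the O(BTF) gating op."""
    return 18 * b * t * d * f


d = 4096
f = 4 * d
params = gated_mlp_params(d, f)
flops = gated_mlp_train_flops(b=8, t=2048, d=d, f=f)
print(f"{params / 1e6:.0f}M params per layer")
```

The FLOPs here equal 6 × tokens × params, which is the per-matmul 6NPM rule applied to all three matrices.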
A Transformer with D=8192, F=4D=32768, L=64 layers. MLP params per layer?

Chapter 3: Attention Accounting

Attention has two parts: the QKVO projections (matmuls with weight matrices) and the dot-product attention (QKᵀ, softmax, then times V).

QKVO projections

Let N = number of Q heads, K = number of KV heads (for GQA), H = head dimension, with D = NH typically.

| Operation | Train FLOPs/layer | Params/layer |
| --- | --- | --- |
| Q: A[B,T,D] · W_Q[D,N,H] | 6BTDNH | DNH |
| K: A[B,T,D] · W_K[D,K,H] | 6BTDKH | DKH |
| V: A[B,T,D] · W_V[D,K,H] | 6BTDKH | DKH |
| O: A[B,T,N,H] · W_O[N,H,D] | 6BTDNH | DNH |
| Total QKVO | 12BTD(N+K)H | 2D(N+K)H |

Dot-product attention

The QKᵀ and softmax·V operations are batched matmuls (G = N/K is the number of Q heads per KV head):

| Operation | Train FLOPs |
| --- | --- |
| Q[B,T,K,G,H] · K[B,S,K,H] → [B,T,S,K,G] | 6BTSNH |
| softmax | O(BTSN), negligible |
| S[B,T,S,K,G] · V[B,S,K,H] → [B,T,K,G,H] | 6BTSNH |
| Total dot-product | 12BTSNH ≈ 12BT²NH |

(For self-attention, S = T. With causal masking, effective FLOPs are halved.)

QKVO projections vs dot-product attention: QKVO costs 24BTDNH (assuming N=K, MHA). Dot-product costs 12BT²NH. They are equal when T = 2D. For D=8192, that is T=16384. For most practical training lengths, the QKVO projections dominate.
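The crossover can be checked numerically. This is a sketch under assumed head counts (N = K = 64 heads of dimension H = 128, so D = NH = 8192); only the products matter for the result.

```python
def qkvo_train_flops(b, t, d, n, k, h):
    """Training FLOPs for the Q, K, V and O projections of one layer."""
    return 12 * b * t * d * (n + k) * h


def dot_product_train_flops(b, t, s, n, h):
    """Training FLOPs for QK^T plus softmax.V in one layer (no causal discount)."""
    return 12 * b * t * s * n * h


d, n, k, h = 8192, 64, 64, 128  # assumed MHA configuration with D = N * H
for t in (4096, 16384, 65536):
    proj = qkvo_train_flops(1, t, d, n, k, h)
    attn = dot_product_train_flops(1, t, t, n, h)
    print(t, attn / proj)
```

The printed ratio is T / (2D), so it reaches 1 at T = 16384 for this configuration.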
For GQA with N=32 Q heads and K=8 KV heads (4x fewer), what fraction of attention params are KV projections?

Chapter 4: The 6P Rule

Let's put it all together. Per layer, the total parameters and training FLOPs are:

| Component | Params/layer | Train FLOPs/layer |
| --- | --- | --- |
| MLP (gated) | 3DF | 18BTDF |
| Attention (QKVO) | 2D(N+K)H | 12BTD(N+K)H |
| Dot-product attn | 0 | 12BT²NH |
| LayerNorm | D (negligible) | O(BTD) (negligible) |
| Unembedding (total) | DV | 6BTDV |

Ignoring dot-product attention FLOPs (valid when T < 8D), the total across all L layers is:

Total FLOPs ≈ 6 × BT × [3DF + 2D(N+K)H] × L
= 6 × (num tokens) × (param count)
The 6P rule: Total training FLOPs ≈ 6 × P × (total tokens), where P is the parameter count and the token count is B × T summed over all training steps. Forward-only inference is ~2P FLOPs per token. This approximation holds when the context length is not extreme (T < 8D).
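A small sketch that confirms the per-layer FLOPs collapse to 6 × tokens × params once dot-product attention is ignored. The function names and the model dimensions are illustrative assumptions, not taken from a specific model.

```python
def layer_params(d, f, n, k, h):
    """Weights in one layer: gated MLP plus QKVO projections (LayerNorm ignored)."""
    return 3 * d * f + 2 * d * (n + k) * h


def layer_train_flops(b, t, d, f, n, k, h):
    """Training FLOPs of one layer, excluding dot-product attention."""
    mlp = 18 * b * t * d * f
    qkvo = 12 * b * t * d * (n + k) * h
    return mlp + qkvo


def train_flops(params, tokens):
    """The 6P rule: forward (2P) plus backward (4P) per token."""
    return 6 * params * tokens


b, t = 4, 2048
d, f, n, k, h = 4096, 16384, 32, 32, 128  # assumed dimensions
exact = layer_train_flops(b, t, d, f, n, k, h)
rule = train_flops(layer_params(d, f, n, k, h), tokens=b * t)
print(exact == rule)
```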

A 7B parameter model trained on 1T tokens. Total training FLOPs?

Chapter 5: Attention vs. MLP FLOPs

When does dot-product attention become the dominant cost? Let's derive the crossover.

Assuming F=4D, D=NH, and N=K (standard MHA):

Attention dot-product FLOPs / Matmul FLOPs
= 12BT²NH / (18BTDF + 24BTDNH)
= 12BT²D / (72BTD² + 24BTD²)
= 12T / 96D = T / (8D)
Attention dominates when T > 8D. For D=8192 (large model), this is T > 65,536 tokens. For D=4608 (Gemma-27B), it is T > 36,864. For smaller models, attention costs bite earlier.

This has practical consequences:

1. For most training runs (T ≤ 8K), the MLP dominates and the 6P rule is accurate.

2. For long-context training (T > 64K), attention FLOPs can double the total cost.

3. Flash Attention helps with memory but not FLOPs — the quadratic cost remains.
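The crossover derived above can be reproduced in a few lines. This sketch hard-codes the same assumptions as the derivation (F = 4D, D = NH, N = K) and drops the batch dimension, which cancels.

```python
def attention_to_matmul_ratio(t: int, d: int) -> float:
    """Dot-product attention FLOPs divided by MLP + QKVO FLOPs, per layer."""
    dot_product = 12 * t**2 * d          # 12 B T^2 N H with N H = D
    mlp = 18 * t * d * (4 * d)           # 18 B T D F with F = 4D
    qkvo = 24 * t * d * d                # 24 B T D N H with N H = D
    return dot_product / (mlp + qkvo)


for d, t in ((8192, 8192), (8192, 65536), (4608, 36864)):
    print(d, t, attention_to_matmul_ratio(t, d))
```

A ratio of 1 means dot-product attention costs as much as all the weight matmuls combined, so the total is double the 6P estimate.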

Arithmetic intensity of self-attention

Self-attention loads Q, K, V from HBM and writes the output back. The total bytes are O(BTNH) while FLOPs are O(BT²NH). So:

Intensity(self-attention) ≈ T (during prefill/training)

At T=240, attention is just barely compute-bound. At T=4096, it is solidly compute-bound. During generation (T=1), attention is always bandwidth-bound with intensity ≈ G (number of Q heads per KV head).

For a model with D=4096, at what sequence length do attention FLOPs equal QKVO projection FLOPs?

Chapter 6: KV Cache

During inference, LLMs have two phases:

Prefill
Process the prompt, save K and V projections into the KV cache
Generation
Sample tokens one at a time, using the cached K/V to avoid recomputation

The KV cache stores the key and value projections for every layer, every KV head, and every position in the sequence:

KV cache size = 2 × S × L × K × H (per sequence)

where S = sequence length, L = layers, K = KV heads, H = head dimension, and the factor 2 is for keys + values. This counts elements; multiply by the bytes per element (1 for int8, 2 for bf16) to get bytes.

Worked example

Model with S=8192, L=64, K=N (MHA), D=NH=8192, in int8:

KV cache = 2 × 8192 × 64 × 8192 = 8 GiB per sequence
This is enormous. With just 10 concurrent sequences, the KV cache alone consumes 80 GiB, slightly more than the entire 80 GB of HBM on an H100. This is why GQA/MQA (reducing K relative to N) is so popular. With K=N/4, the cache drops to 2 GiB per sequence.

Per-token KV cache

For each new token generated, we add 2×L×K×H bytes (in int8) to the cache:

Per-token KV = 2 × L × K × H

For D=4096, L=64, MHA, int8: 2 × 64 × 4096 = 512 KB/token.
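Both worked examples can be reproduced with one helper. This is a sketch: the function name is illustrative, and the split of K × H into 64 or 32 heads of dimension 128 is an assumption (only the product matters).

```python
def kv_cache_bytes(seq_len, layers, kv_heads, head_dim, bytes_per_elem=1):
    """Bytes of KV cache for one sequence; the factor 2 covers keys and values."""
    return 2 * seq_len * layers * kv_heads * head_dim * bytes_per_elem


GIB = 2**30
KIB = 2**10

# S=8192, L=64, K*H = 8192 (assumed as 64 heads x 128), int8.
mha = kv_cache_bytes(8192, 64, kv_heads=64, head_dim=128)
gqa = kv_cache_bytes(8192, 64, kv_heads=16, head_dim=128)  # K = N/4

# Per-token growth for D=4096 (assumed as 32 heads x 128), L=64, int8.
per_token = kv_cache_bytes(1, 64, kv_heads=32, head_dim=128)

print(mha / GIB, gqa / GIB, per_token / KIB)
```

Switching to bf16 is a matter of passing `bytes_per_elem=2`, which doubles every figure.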

A model with D=8192, L=80, GQA with K=8 KV heads, H=128, int8. KV cache per token?

Chapter 7: Gradient Checkpointing

Backpropagation trades memory for compute: it saves all intermediate activations from the forward pass so gradients can be computed without redoing forward calculations. But this is extremely memory-hungry.

For a model with BT=4M tokens, L=64, D=8192, saving ~20 intermediate nodes per layer in bf16:

Activation memory = 2 × 20 × BT × D × L = 84 TB

That is far more than any practical amount of HBM. We must be selective about what we save.
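The estimate above is simple enough to script. In this sketch the function name is illustrative and "4M tokens" is taken as 4,000,000.

```python
def activation_bytes(tokens, d, layers, saved_per_layer=20, bytes_per_elem=2):
    """Bytes needed to keep `saved_per_layer` [tokens, D] activations per layer."""
    return bytes_per_elem * saved_per_layer * tokens * d * layers


full = activation_bytes(tokens=4_000_000, d=8192, layers=64)
layer_inputs_only = activation_bytes(4_000_000, 8192, 64, saved_per_layer=1)
print(f"{full / 1e12:.1f} TB saving everything")
print(f"{layer_inputs_only / 1e12:.1f} TB saving one checkpoint per layer")
```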

Two common strategies

| Strategy | Checkpoints saved | FLOPs overhead |
| --- | --- | --- |
| Block remat | 1 per layer (just the input) | +33% (from 6P to ~8P) |
| Big matmuls only | 7 per layer (Q, K, V, O + 3 FFN outputs) | Lower (recompute only attention dot-product + small ops) |

Block remat is the most aggressive: save only the layer input, recompute everything else during backward. This basically repeats the entire forward pass in the backward, increasing total FLOPs from ~6P to ~8P per token.

Big matmuls only saves the 7 matmul outputs. During backward, only the dot-product attention and small ops need to be recomputed, for extra FLOPs of about 4BT²NH per layer.
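A rough sketch of the two overheads, assuming the 2P forward and 4P backward split from earlier. The function names and the example sizes are illustrative.

```python
def train_flops_per_token(params, block_remat=False):
    """Training FLOPs per token under the 6P rule, optionally with block remat."""
    forward = 2 * params
    backward = 4 * params
    # Block remat repeats the whole forward pass during the backward pass.
    recompute = forward if block_remat else 0
    return forward + backward + recompute


def big_matmul_remat_extra_flops(b, t, n, h):
    """Extra per-layer FLOPs when only dot-product attention is recomputed."""
    return 4 * b * t * t * n * h


params = 7_000_000_000
baseline = train_flops_per_token(params)
remat = train_flops_per_token(params, block_remat=True)
print(remat / baseline)
print(big_matmul_remat_extra_flops(b=1, t=4096, n=32, h=128))
```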

In JAX: jax.checkpoint (also available under the alias jax.remat) controls which activations are saved and which are recomputed. The choice is a memory-compute tradeoff that depends on your specific model size and hardware.
Block remat saves only 1 checkpoint per layer. If the forward pass costs 2P FLOPs/token, total training cost with block remat is approximately:

Chapter 8: MoE Math

A Mixture-of-Experts (MoE) model replaces each dense MLP with E independent expert MLPs, of which only k are activated per token.

| Quantity | Dense model | MoE model |
| --- | --- | --- |
| Total params (MLP) | 3DF × L | 3DF × E × L |
| Activated params/token | 3DF | 3DF × k |
| Sparsity ratio | 1 | E / k (typically 8–64) |

The key insight: MoE increases capacity (total params) without proportionally increasing compute (activated FLOPs per token). DeepSeek v3 has k=8, E=256 for a sparsity of 32x.

Roofline for MoE

With E expert copies, we load E×D×F weight bytes but only do 2k×B×D×F FLOPs. For int8 weights and bf16 compute, to be compute-bound:

2k × BDF / (E × DF) > 240
kB / E > 120
B > 120 × E / k

For DeepSeek (E=256, k=8): B > 120 × 256 / 8 = 3,840 tokens.
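The same inequality as a small helper. This is a sketch: the function name is illustrative, and the default critical intensity of 240 FLOPs per byte is the figure used in the inequality above.

```python
def moe_critical_batch(num_experts, active_experts, critical_intensity=240,
                       weight_bytes=1):
    """Tokens per batch at which an MoE layer stops being bandwidth-bound.

    Solves 2*k*B*D*F / (E*D*F*weight_bytes) = critical_intensity for B.
    """
    return critical_intensity * weight_bytes * num_experts / (2 * active_experts)


deepseek = moe_critical_batch(num_experts=256, active_experts=8)
dense = moe_critical_batch(num_experts=1, active_experts=1)
print(deepseek, dense)
```

The dense case (E = k = 1) gives 120 tokens, so the MoE threshold is simply that figure scaled by the sparsity ratio E / k.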

MoE needs large batches. At 3,840 tokens per replica just to be compute-bound, MoE models require substantial batching during inference. This is a major operational challenge — you need enough concurrent requests to fill the batch.

Communication overhead

MoE introduces two AllToAll operations per layer (routing tokens to experts and back). But AllToAll is only 1/4 the cost of an AllGather, making this relatively cheap.

DeepSeek v3 was trained for 2.79M H800 hours on 14.8T tokens with 37B activated params. What FLOPs utilization (MFU) did they achieve? (FP8 no sparsity: 1.513e15 FLOPs/s)
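One way to work a question like this is sketched below. It assumes achieved FLOPs follow the 6P rule on activated parameters, and it treats MFU as achieved FLOPs divided by peak FLOPs over the same device time. The function name is illustrative.

```python
def mfu(active_params, tokens, device_hours, peak_flops_per_s):
    """Model FLOPs utilization: 6 * P * tokens over peak FLOPs available."""
    achieved = 6 * active_params * tokens
    available = device_hours * 3600 * peak_flops_per_s
    return achieved / available


utilization = mfu(
    active_params=37e9,
    tokens=14.8e12,
    device_hours=2.79e6,
    peak_flops_per_s=1.513e15,
)
print(f"{utilization:.1%}")
```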

Chapter 9: Exercises

Exercise 1: Full model parameter count

D=4096, F=4D=16384, V=32000, L=64, N·H=D, MHA.

Per layer: 3DF + 4DNH + D = 3×4096×16384 + 4×4096² + 4096
= 201M + 67M + 4K ≈ 268M per layer
Total: 268M × 64 + 2×4096×32000 = 17.2B + 0.26B ≈ 17.4B

Attention fraction: 4D² / (4D² + 3×4D²) = 4/(4+12) = 1/4 of params are in attention.

KV cache per token (int8): 2 × 64 × 4096 = 512 KB/token.
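The arithmetic in Exercise 1 can be checked directly. This sketch uses the exercise's own values; the variable names are illustrative.

```python
d, vocab, layers = 4096, 32000, 64
f = 4 * d

mlp = 3 * d * f            # gated MLP
attention = 4 * d * d      # QKVO with MHA and N * H = D
layernorm = d
per_layer = mlp + attention + layernorm

embeddings = 2 * d * vocab  # input embedding plus unembedding
total = per_layer * layers + embeddings

attention_fraction = attention / (attention + mlp)
kv_bytes_per_token = 2 * layers * d  # int8, MHA so K * H = D

print(f"{per_layer / 1e6:.0f}M per layer, {total / 1e9:.1f}B total")
print(attention_fraction, kv_bytes_per_token // 1024, "KiB per token")
```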

Exercise 2: Sharded matmul FLOPs

A[B_X, D_Y] · W[D_Y, F] on Mesh{'X':4, 'Y':8, 'Z':4}. Total theoretical FLOPs = 2BDF. But the computation is replicated over Z (not sharded), so total actual FLOPs = 2BDF × Z. Per device: 2BDF / (X×Y).
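A sketch of the bookkeeping for Exercise 2. The function name and the B, D, F sizes are illustrative assumptions; the mesh is the one from the exercise.

```python
from math import prod


def sharded_matmul_flops(b, d, f, mesh, sharded_axes=("X", "Y")):
    """Return (theoretical, per_device, total_actual) FLOPs for A[B,D] . W[D,F]."""
    theoretical = 2 * b * d * f
    # Work is divided only across the axes that shard B or D.
    shards = prod(mesh[axis] for axis in sharded_axes)
    per_device = theoretical // shards
    # Every device does this work, including replicas along unsharded axes.
    total_actual = per_device * prod(mesh.values())
    return theoretical, per_device, total_actual


mesh = {"X": 4, "Y": 8, "Z": 4}
theoretical, per_device, total_actual = sharded_matmul_flops(64, 128, 256, mesh)
print(theoretical, per_device, total_actual // theoretical)
```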

Exercise 3: Flash Attention savings

Flash Attention avoids materializing the full T×T attention matrix. It processes attention in chunks, keeping running statistics (max, sum, output) in VMEM. This reduces memory from O(BT²N) for the attention matrix to O(BTNH); the full attention matrix is never stored. FLOPs are unchanged (still quadratic), but the arithmetic intensity increases dramatically because data stays in fast on-chip memory.

Exercise 4: Remat FLOPs overhead

Saving only the 7 matmul outputs per layer means we recompute only the attention dot-product: 4BT²NH extra FLOPs per layer, plus O(BTD) for gating/norm recomputation and O(BTF) for the output nonlinearity.

A 70B model in bf16 with Adam optimizer (2 states per param, fp32). Total memory for weights + optimizer?

Chapter 10: Summary

| Quantity | Formula | Typical value (7B model) |
| --- | --- | --- |
| MLP params/layer | 3DF | ~200M |
| Attn params/layer | 2D(N+K)H ≈ 4D² (MHA) | ~67M |
| Total params | (3DF + 4D²)L + DV | ~7B |
| Training FLOPs/token | 6P (ignoring attn) | ~42 GFLOPs |
| Attn crossover | T > 8D | T > 32K |
| KV cache/token | 2LKH | 512 KB (int8, MHA, D=4096, L=64) |
| Block remat overhead | +33% FLOPs (6P → 8P) | Saves ~20x activation memory |
| MoE critical batch | 120 × E/k | 3840 (DeepSeek) |
The punchline: Everything about Transformer scaling — from training cost to inference latency to memory requirements — can be derived from a handful of equations. The 6P rule for training FLOPs, the KV cache formula, and the roofline model from Chapter 1 together let you predict performance on any hardware configuration before writing a single line of code.

Armed with these equations and the roofline/sharding tools from Chapters 1–3, you can now analyze any distributed training or inference setup: estimate step time, identify bottlenecks, and choose optimal parallelism strategies.

The single most important takeaway from this chapter is: