Parameter counting, FLOPs per token, memory footprints, KV cache sizing, and MFU — everything you need to predict how a Transformer will behave on real hardware.
Before we can do any roofline analysis on Transformers, we need to count FLOPs precisely. Let's build from the ground up.
| Operation | Shapes | FLOPs | Data (bytes, bf16) |
|---|---|---|---|
| Dot product | x[P] · y[P] | 2P | 4P + 2 |
| Matrix-vector | A[N,P] · x[P] | 2NP | 2NP + 2P + 2N |
| Matrix-matrix | A[N,P] · B[P,M] | 2NPM | 2NP + 2PM + 2NM |
The dot product does P multiplies and P−1 adds ≈ 2P FLOPs. Matrix-vector is N dot products: 2NP. Matrix-matrix is M matrix-vector products: 2NPM.
For two higher-dimensional arrays C[...] and D[...], some dimensions are contracting (summed over) and some are batching (shared between inputs and output). The FLOPs are:

$$\text{FLOPs}(C \cdot D) = 2 \prod_{i \in \text{dims}(C)} |i| \prod_{j \in \text{dims}(D) \setminus \text{dims}(C)} |j|$$

where batch and contraction dims are counted only once (not double-counted from both inputs).
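The counting rule can be written as a small helper over einsum-style labels. This is a minimal sketch: `einsum_flops` is a hypothetical name of ours, not a library API, and it assumes every label names one dimension with one size.

```python
from math import prod


def einsum_flops(spec: str, *shapes: tuple) -> int:
    """FLOPs for one einsum contraction: 2 x the product of every distinct dimension.

    Batching and contracting dims appear in several operands but are counted once.
    """
    input_specs = spec.split("->")[0].split(",")
    sizes: dict = {}
    for labels, shape in zip(input_specs, shapes):
        if len(labels) != len(shape):
            raise ValueError(f"{labels!r} does not match shape {shape}")
        for label, size in zip(labels, shape):
            if sizes.setdefault(label, size) != size:
                raise ValueError(f"dimension {label!r} has inconsistent sizes")
    return 2 * prod(sizes.values())


# The three rows of the table above, with P=1024, N=512, M=256:
print(einsum_flops("p,p->", (1024,), (1024,)))              # dot product: 2P
print(einsum_flops("np,p->n", (512, 1024), (1024,)))        # matrix-vector: 2NP
print(einsum_flops("np,pm->nm", (512, 1024), (1024, 256)))  # matrix-matrix: 2NPM
```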
During training, we care about the backward pass too. For a single matmul C = A · B where A[N,P] are activations and B[P,M] are weights:
Forward: C = A B — costs 2NPM FLOPs.
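Backward: We need two gradients, writing dX for the gradient of the loss with respect to X:
Weight gradient: dB = Aᵀ · dC, a [P,N] · [N,M] matmul costing 2NPM FLOPs.
Activation gradient: dA = dC · Bᵀ, an [N,M] · [M,P] matmul costing 2NPM FLOPs, which is passed back to earlier layers.
Each gradient computation is itself a matmul of the same cost as the forward pass, so training this matmul costs 2NPM + 4NPM = 6NPM FLOPs. Since B holds PM parameters and A holds N tokens, that is 6 × params × tokens total training FLOPs.
This 1:2 forward:backward ratio holds for every matmul in the network, so the backward pass costs about 2x the forward pass in FLOPs. This is a good sanity check for any training FLOPs estimate.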
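To make the 2NPM + 4NPM accounting concrete, here is a sketch that counts the forward matmul and both gradient matmuls with the same "2 × distinct dims" rule. The helper names are ours, chosen for illustration.

```python
from math import prod


def einsum_flops(spec, *shapes):
    # 2 x product of every distinct dimension size (each label counted once).
    sizes = {}
    for labels, shape in zip(spec.split("->")[0].split(","), shapes):
        sizes.update(zip(labels, shape))
    return 2 * prod(sizes.values())


def matmul_train_flops(n, p, m):
    """Forward + backward FLOPs for C[N,M] = A[N,P] @ B[P,M]."""
    return {
        "forward": einsum_flops("np,pm->nm", (n, p), (p, m)),  # C  = A B
        "grad_A": einsum_flops("nm,pm->np", (n, m), (p, m)),   # dA = dC B^T
        "grad_B": einsum_flops("np,nm->pm", (n, p), (n, m)),   # dB = A^T dC
    }


tokens, d_in, d_out = 4096, 1024, 4096
flops = matmul_train_flops(tokens, d_in, d_out)
print(flops)
print(sum(flops.values()) == 6 * (d_in * d_out) * tokens)  # True
```

All three entries come out equal, which is exactly the 1:2 forward:backward split.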
Modern Transformers use a gated MLP with three weight matrices per layer. Let B = batch, T = sequence length, D = model dimension, F = FFN dimension (typically 4D).
| Operation | Train FLOPs/layer | Params/layer |
|---|---|---|
| A[B,T,D] · Win1[D,F] | 6BTDF | DF |
| A[B,T,D] · Win2[D,F] | 6BTDF | DF |
| σ(Ain1) × Ain2 | O(BTF) — negligible | 0 |
| A[B,T,F] · Wout[F,D] | 6BTDF | DF |
| Total MLP | 18BTDF | 3DF |
The gating einsum (used by LLaMA, DeepSeek, and many others) splits the up-projection into two matrices whose outputs are element-wise multiplied. Some models instead use a non-gated MLP with a single Win (2DF params in total, counting Wout) but scale D and F to compensate.
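A sketch of the MLP rows of the table, assuming each matmul costs 6 × tokens × params and ignoring the elementwise gate. `mlp_counts` is a hypothetical helper; the `gated=False` branch models the single-Win variant.

```python
def mlp_counts(B: int, T: int, D: int, F: int, gated: bool = True) -> dict:
    """Params and training FLOPs for one MLP layer.

    Gated: Win1, Win2 [D,F] and Wout [F,D]. Non-gated: a single Win plus Wout.
    Each matmul costs 6 x tokens x params (2 forward + 4 backward);
    the O(BTF) elementwise gate is ignored.
    """
    tokens = B * T
    matrices = [D * F, D * F, F * D] if gated else [D * F, F * D]
    params = sum(matrices)
    return {"params": params, "train_flops": sum(6 * tokens * p for p in matrices)}


print(mlp_counts(B=1, T=4096, D=4096, F=16384))               # 3DF params, 18BTDF FLOPs
print(mlp_counts(B=1, T=4096, D=4096, F=16384, gated=False))  # 2DF params, 12BTDF FLOPs
```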
Attention has two parts: the QKVO projections (matmuls with weight matrices) and the dot-product attention (softmax(QKᵀ), then multiplied by V).
Let N = number of Q heads, K = number of KV heads (for GQA), H = head dimension, with D = NH typically.
| Operation | Train FLOPs/layer | Params/layer |
|---|---|---|
| Q: A[B,T,D] · WQ[D,N,H] | 6BTDNH | DNH |
| K: A[B,T,D] · WK[D,K,H] | 6BTDKH | DKH |
| V: A[B,T,D] · WV[D,K,H] | 6BTDKH | DKH |
| O: A[B,T,N,H] · WO[N,H,D] | 6BTDNH | DNH |
| Total QKVO | 12BTD(N+K)H | 2D(N+K)H |
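The QKVO rows can be checked the same way. In this sketch (`qkvo_counts` is our name) the head counts N=32, K=8, H=128 in the usage lines are illustrative assumptions, chosen to show how GQA (K < N) shrinks the K and V projections.

```python
def qkvo_counts(B: int, T: int, D: int, N: int, K: int, H: int) -> dict:
    """Params and training FLOPs for the Q, K, V, O projections of one layer.

    N = query heads, K = KV heads (K < N means GQA), H = head dim.
    """
    weights = {
        "Q": D * N * H,  # WQ[D,N,H]
        "K": D * K * H,  # WK[D,K,H]
        "V": D * K * H,  # WV[D,K,H]
        "O": N * H * D,  # WO[N,H,D]
    }
    params = sum(weights.values())
    return {"params": params, "train_flops": 6 * B * T * params}


mha = qkvo_counts(B=1, T=4096, D=4096, N=32, K=32, H=128)
gqa = qkvo_counts(B=1, T=4096, D=4096, N=32, K=8, H=128)
print(mha["params"], gqa["params"])  # GQA shrinks the K and V projections 4x
```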
The QKᵀ and softmax·V operations are batched matmuls. Here G = N/K is the number of Q heads sharing each KV head, so Q is reshaped to [B,T,K,G,H]:
| Operation | Train FLOPs |
|---|---|
| Q[B,T,K,G,H] · K[B,S,K,H] → [B,T,S,K,G] | 6BTSNH |
| softmax | O(BTSN) — negligible |
| S[B,T,S,K,G] · V[B,S,K,H] → [B,T,K,G,H] | 6BTSNH |
| Total dot-product | 12BTSNH ≈ 12BT²NH (S = T) |
(For self-attention, S = T. With causal masking, effective FLOPs are halved.)
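A sketch of the dot-product rows, under the same 6× forward-plus-backward assumption. The causal option models an idealized kernel that skips the masked half; `attention_dot_flops` is a hypothetical helper.

```python
def attention_dot_flops(B: int, T: int, S: int, N: int, H: int, causal: bool = False) -> int:
    """Training FLOPs of QK^T and softmax(.)V for one layer.

    Each batched matmul costs 2*B*T*S*N*H forward and twice that backward.
    Only the total number of query heads N appears, so GQA (fewer KV heads)
    does not change this cost. `causal=True` applies the idealized 1/2
    saving of a kernel that skips the masked half.
    """
    flops = 2 * (6 * B * T * S * N * H)
    return flops // 2 if causal else flops


print(attention_dot_flops(B=1, T=8192, S=8192, N=32, H=128))
print(attention_dot_flops(B=1, T=8192, S=8192, N=32, H=128, causal=True))
```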
Let's put it all together. Per layer, the total parameters and training FLOPs are:
| Component | Params/layer | Train FLOPs/layer |
|---|---|---|
| MLP (gated) | 3DF | 18BTDF |
| Attention (QKVO) | 2D(N+K)H | 12BTD(N+K)H |
| Dot-product attn | 0 | 12BT²NH |
| LayerNorm | D (negligible) | O(BTD) (negligible) |
| Unembedding (total) | DV | 6BTDV |
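Ignoring dot-product attention FLOPs (valid when T < 8D), the totals across all L layers, with P the total parameter count, are:

$$P = L\,\big(3DF + 2D(N+K)H\big) + DV$$

$$\text{Training FLOPs} = 18BTDFL + 12BTD(N+K)HL + 6BTDV = 6 \cdot BT \cdot P$$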
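A static sketch that sums the per-layer table above into totals. The usage line plugs in D=4096, F=4D, L=64, V=32000 with MHA; the split N=32, H=128 is our assumption, since only N·H = D matters for these counts.

```python
def transformer_counts(B, T, D, F, N, K, H, L, V):
    """Total params and training FLOPs for a gated-MLP Transformer, following the per-layer table above."""
    per_layer = 3 * D * F + 2 * D * (N + K) * H  # MLP + QKVO
    params = L * per_layer + D * V  # + unembedding
    matmul_flops = 6 * B * T * params  # the 6 x params x tokens rule
    dot_attn_flops = L * 12 * B * T * T * N * H  # self-attention (S = T), ignored above
    return params, matmul_flops, dot_attn_flops


params, flops, attn = transformer_counts(B=1, T=4096, D=4096, F=16384, N=32, K=32, H=128, L=64, V=32000)
print(f"{params / 1e9:.2f}B params, {flops / 4096 / 1e9:.0f} GFLOPs/token, attn adds {attn / flops:.1%}")
```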
When does dot-product attention become the dominant cost? Let's derive the crossover.
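Assuming F=4D, D=NH, and N=K (standard MHA), compare the per-layer dot-product attention FLOPs with the per-layer matmul FLOPs:

$$\frac{\text{dot-product attention FLOPs}}{\text{matmul FLOPs}} = \frac{12BT^2NH}{18BTDF + 24BTDNH} = \frac{12BT^2D}{72BTD^2 + 24BTD^2} = \frac{T}{8D}$$

So dot-product attention matches the matmul FLOPs at T = 8D (T = 32K for D = 4096) and dominates beyond that.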
This has practical consequences:
1. For most training runs (T ≤ 8K), the MLP and projection matmuls dominate and the 6P rule is a good approximation: dot-product attention adds only a fraction T/(8D) on top.
2. For long-context training (T > 64K), attention FLOPs can double the total cost.
3. Flash Attention helps with memory but not FLOPs: the quadratic cost remains.
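The crossover can be checked numerically by evaluating both FLOP counts in full rather than using the simplified ratio. This is a sketch under the same assumptions (F = 4D, N·H = D, K = N, S = T, unembedding ignored); the function name is ours.

```python
def attention_to_matmul_ratio(T: int, D: int) -> float:
    """Dot-product attention FLOPs / MLP+QKVO FLOPs per layer.

    Assumes F = 4D, N*H = D, K = N (MHA), S = T; B cancels and is set to 1.
    """
    F, NH = 4 * D, D
    matmul = 18 * T * D * F + 12 * T * D * (2 * NH)  # MLP + QKVO
    attn = 12 * T * T * NH  # QK^T and softmax(.)V
    return attn / matmul  # simplifies to T / (8D)


for T in (4096, 8192, 32768, 131072):
    print(T, attention_to_matmul_ratio(T, D=4096))
```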
Self-attention loads Q, K, V from HBM and writes the output back. The total bytes are O(BTNH) while FLOPs are O(BT²NH). So:

$$\text{Intensity}(\text{attention}) = \frac{O(BT^2NH)}{O(BTNH)} = O(T)$$
At T=240, attention is just barely compute-bound. At T=4096, it is solidly compute-bound. During generation (T=1), attention is always bandwidth-bound with intensity ≈ G (number of Q heads per KV head).
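A sketch of the intensity with the constants kept. It assumes a fused kernel that reads Q, K, V once and writes the output once in bf16, with the logits never reaching HBM; that accounting is our assumption. Counted this way, MHA self-attention (S = T, K = N) comes out at T·S/(T+S), i.e. T/2, so the exact crossover length depends on constant factors, while the O(T) scaling is the point. The decode case tends to G = N/K as S grows.

```python
def attention_intensity(T: int, S: int, N: int, K: int, H: int, bytes_per_elem: int = 2) -> float:
    """Forward FLOPs per HBM byte for one fused attention call (batch size cancels).

    Assumes Q[T,N,H], K[S,K,H] and V[S,K,H] are each read once and O[T,N,H]
    written once; the [T,S] logits stay on-chip.
    """
    flops = 4 * T * S * N * H  # QK^T plus probs @ V, 2 FLOPs per multiply-add
    hbm_bytes = bytes_per_elem * (2 * T * N * H + 2 * S * K * H)
    return flops / hbm_bytes


print(attention_intensity(T=4096, S=4096, N=32, K=32, H=128))  # prefill, MHA
print(attention_intensity(T=1, S=8192, N=32, K=8, H=128))      # one decode step, G = 4
```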
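During inference, LLMs have two phases:

1. Prefill: the whole prompt is processed in one forward pass, with T equal to the prompt length. Like training, this is usually compute-bound.
2. Generation (decode): tokens are produced one at a time (T=1), reusing the keys and values cached for earlier positions. Each step re-reads the weights and the KV cache, so this phase is usually memory-bandwidth-bound.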
The KV cache stores the key and value projections for every layer, every KV head, and every position in the sequence:

$$\text{KV cache bytes} = 2 \cdot S \cdot L \cdot K \cdot H \cdot (\text{bytes per element})$$
where S = sequence length, L = layers, K = KV heads, H = head dimension, and the factor 2 is for keys + values.
For a model with S=8192, L=64, K=N (MHA), and D=NH=8192, stored in int8, one sequence needs 2 × 8192 × 64 × 8192 bytes = 2³³ bytes = 8 GiB of KV cache.
For each new token generated, we add 2×L×K×H bytes (in int8) to the cache:
For D=4096, L=64, MHA, int8: 2 × 64 × 4096 = 512 KB/token.
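The cache formula as a one-line helper (our name, for illustration). The head splits in the usage lines (H=128, so K=64 for D=8192 and K=32 for D=4096) are assumptions; only K·H enters the result. The last line shows the 4x saving from GQA with 8 KV heads.

```python
def kv_cache_bytes(S: int, L: int, K: int, H: int, bytes_per_elem: int = 1) -> int:
    """KV cache size: a key and a value vector per layer, KV head and position."""
    return 2 * S * L * K * H * bytes_per_elem


GiB = 2**30
print(kv_cache_bytes(S=8192, L=64, K=64, H=128) / GiB)  # 8.0  (D = N*H = 8192, MHA, int8)
print(kv_cache_bytes(S=1, L=64, K=32, H=128) // 1024)   # 512  (KiB per token, D = 4096)
print(kv_cache_bytes(S=1, L=64, K=8, H=128) // 1024)    # 128  (same D with 8 KV heads)
```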
Backpropagation trades memory for compute: it saves all intermediate activations from the forward pass so gradients can be computed without redoing forward calculations. But this is extremely memory-hungry.
For a model with BT=4M tokens per batch, L=64, and D=8192, saving ~20 intermediate [BT, D]-sized nodes per layer in bf16 takes 2 bytes × 20 × BT × D × L = 2 × 20 × (4×10⁶) × 8192 × 64 ≈ 84 TB.
That is far more than any practical amount of HBM. We must be selective about what we save.
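A sketch of that estimate, with the simplification (ours) that every saved node is a [tokens, D] array; in reality the FFN nodes are [tokens, F] and larger. The second call shows how far saving only the layer input cuts the total.

```python
def saved_activation_bytes(tokens: int, D: int, L: int, nodes_per_layer: int = 20,
                           bytes_per_elem: int = 2) -> int:
    """Memory for activations kept for the backward pass.

    Simplification: every saved node is treated as a [tokens, D] array.
    """
    return nodes_per_layer * tokens * D * L * bytes_per_elem


full = saved_activation_bytes(tokens=4_000_000, D=8192, L=64)
block_remat = saved_activation_bytes(tokens=4_000_000, D=8192, L=64, nodes_per_layer=1)
print(f"{full / 1e12:.0f} TB without remat, {block_remat / 1e12:.1f} TB saving only layer inputs")
```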
| Strategy | Checkpoints saved | FLOPs overhead |
|---|---|---|
| Block remat | 1 per layer (just the input) | +33% (from 6P to ~8P) |
| Big matmuls only | 7 per layer (Q,K,V,O + 3 FFN outputs) | Lower (recompute only attention dot-product + small ops) |
Block remat is the most aggressive: save only the layer input and recompute everything else during the backward pass. This effectively repeats the entire forward pass (about 2P FLOPs per token) inside the backward pass, increasing total FLOPs from ~6P to ~8P per token.
Big matmuls only saves the 7 matmul outputs. During backward, only the dot-product attention and small ops need to be recomputed, costing about 4BT²NH extra FLOPs per layer (one forward pass of QKᵀ and softmax·V).
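The two strategies can be compared with a small per-layer calculator. This is a sketch assuming MHA and ignoring small elementwise ops; the policy strings are our own labels, not the names of JAX checkpoint policies.

```python
def layer_train_flops(B: int, T: int, D: int, F: int, N: int, H: int, policy: str) -> int:
    """Per-layer training FLOPs under different rematerialization choices.

    Assumes MHA (K = N) and ignores small elementwise ops. Policy names are
    our own labels: "none", "block" (save only the layer input), or
    "big_matmuls" (save the 7 matmul outputs).
    """
    tokens = B * T
    fwd_matmul = 2 * tokens * (3 * D * F + 4 * D * N * H)  # MLP + QKVO forward
    fwd_attn = 4 * B * T * T * N * H  # QK^T and softmax(.)V forward
    no_remat = 3 * (fwd_matmul + fwd_attn)  # forward + 2x for backward
    extra = {"none": 0, "block": fwd_matmul + fwd_attn, "big_matmuls": fwd_attn}
    return no_remat + extra[policy]


base = layer_train_flops(1, 4096, 4096, 16384, 32, 128, "none")
for policy in ("none", "block", "big_matmuls"):
    print(policy, layer_train_flops(1, 4096, 4096, 16384, 32, 128, policy) / base)
```

Block remat always lands at 4/3 of the baseline (the +33% in the table); the big-matmuls policy pays only for the attention recompute.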
jax.checkpoint (also exposed as jax.remat) and its policy argument control which activations are saved. The choice is a memory-compute tradeoff that depends on your specific model size and hardware.

A Mixture-of-Experts (MoE) model replaces each dense MLP with E independent expert MLPs, of which only k are activated per token.
| Quantity | Dense model | MoE model |
|---|---|---|
| Total params (MLP) | 3DF × L | 3DF × E × L |
| Activated MLP params/token (per layer) | 3DF | 3DF × k |
| Sparsity ratio | 1 | E / k (typically 8–64) |
The key insight: MoE increases capacity (total params) without proportionally increasing compute (activated FLOPs per token). DeepSeek-V3 has k=8, E=256 for a sparsity of 32x.
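With E expert copies, we load E×D×F weight bytes but only do 2k×B×D×F FLOPs. For int8 weights and bf16 compute, taking the hardware's critical arithmetic intensity as 240 FLOPs/byte, to be compute-bound:

$$\frac{2kBDF}{EDF} = \frac{2kB}{E} > 240 \quad\Longrightarrow\quad B > \frac{120\,E}{k}$$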
For DeepSeek (E=256, k=8): B > 120 × 256 / 8 = 3,840 tokens.
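The same threshold as a helper (a sketch; the name and the default critical intensity of 240 FLOPs/byte follow the derivation above and should be adjusted for other chips):

```python
def moe_critical_batch(E: int, k: int, critical_intensity: float = 240.0) -> float:
    """Smallest token batch for which MoE expert matmuls are compute-bound.

    Intensity = 2*k*B*D*F FLOPs / (E*D*F) int8 weight bytes = 2kB/E;
    compute-bound when 2kB/E > critical_intensity.
    """
    return critical_intensity * E / (2 * k)


print(moe_critical_batch(E=256, k=8))  # 3840.0  (DeepSeek-V3 routing config)
```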
MoE introduces two AllToAll operations per layer (routing tokens to experts and back). On a bidirectional ring, an AllToAll costs only about 1/4 as much as an AllGather of the same array, which makes this relatively cheap.
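Worked example: D=4096, F=4D=16384, V=32000, L=64, N·H=D, MHA.
Params: each layer has 3DF + 4D² = 12D² + 4D² = 16D² ≈ 268M, so 64 layers plus the DV ≈ 131M unembedding come to ≈ 17.3B parameters.
Attention fraction: 4D² / (4D² + 3D·4D) = 4/(4+12) = 1/4 of the per-layer params are in attention.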
KV cache per token (int8): 2 × 64 × 4096 = 512 KB/token.
A[B_X, D_Y] · W[D_Y, F] on Mesh{'X':4, 'Y':8, 'Z':4}. Total theoretical FLOPs = 2BDF. But the computation is replicated over Z (not sharded), so total actual FLOPs = 2BDF × Z. Per device: 2BDF / (X×Y).
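A sketch of that bookkeeping, computing the local block shapes per device. The function name and the dict-based mesh are ours for illustration, not a JAX API.

```python
def sharded_matmul_flops(B: int, D: int, F: int, mesh: dict) -> tuple:
    """FLOPs for A[B_X, D_Y] @ W[D_Y, F] on a mesh with axes X, Y, Z.

    Each device multiplies its local [B/X, D/Y] and [D/Y, F] blocks (the
    partial sums over D still need a reduction across Y, not counted here);
    Z shards nothing, so every Z slice repeats the same work.
    """
    X, Y, Z = mesh["X"], mesh["Y"], mesh["Z"]
    per_device = 2 * (B // X) * (D // Y) * F
    total = per_device * X * Y * Z  # = 2BDF x Z
    return per_device, total


print(sharded_matmul_flops(B=1024, D=8192, F=32768, mesh={"X": 4, "Y": 8, "Z": 4}))
```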
Flash Attention avoids materializing the full T×T attention matrix. It processes attention in chunks, keeping running statistics (max, sum, output) in VMEM. This reduces memory from O(BT²N) for the attention logits to O(BTNH) for Q, K, V, and the output, since the full attention matrix is never stored in HBM. FLOPs are unchanged (still quadratic), but the arithmetic intensity increases dramatically because the intermediate data stays in fast on-chip memory.
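The running-statistics trick (the "online softmax") can be shown in pure Python for a single query. This is a sketch, not a kernel: it omits the 1/√H scaling, batching, and tiling in VMEM/SRAM, and the function names are ours. It checks that one pass over key/value chunks reproduces the materialize-everything result.

```python
import math


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def naive_attention(q, keys, values):
    """Reference: materialize every logit, softmax, then weight the values."""
    logits = [dot(q, k) for k in keys]
    m = max(logits)
    weights = [math.exp(s - m) for s in logits]
    total = sum(weights)
    return [sum(w * v[j] for w, v in zip(weights, values)) / total
            for j in range(len(values[0]))]


def chunked_attention(q, keys, values, chunk=2):
    """One pass over key/value chunks with running max m, normalizer l and
    unnormalized output acc; only one chunk of logits exists at a time."""
    m, l = -math.inf, 0.0
    acc = [0.0] * len(values[0])
    for start in range(0, len(keys), chunk):
        logits = [dot(q, k) for k in keys[start:start + chunk]]
        m_new = max(m, max(logits))
        rescale = math.exp(m - m_new)  # 0.0 on the first chunk, since m = -inf
        l *= rescale
        acc = [a * rescale for a in acc]
        for s, v in zip(logits, values[start:start + chunk]):
            p = math.exp(s - m_new)
            l += p
            acc = [a + p * vj for a, vj in zip(acc, v)]
        m = m_new
    return [a / l for a in acc]


q = [0.5, -1.0, 2.0]
keys = [[1, 0, 0], [0, 1, 0], [0, 0, 1], [1, 1, 1], [-1, 2, 0.5]]
values = [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10]]
print(naive_attention(q, keys, values))
print(chunked_attention(q, keys, values, chunk=2))
```

The FLOPs are the same in both functions; what changes is that the chunked version never holds more than `chunk` logits at once.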
Saving only the 7 matmul outputs per layer means we recompute only the attention dot-product: 4BT²NH extra FLOPs per layer, plus O(BTD) for norm recomputation and O(BTF) for the gated nonlinearity σ(Ain1) × Ain2.
| Quantity | Formula | Example value (D=4096, F=4D, L=64, V=32K, MHA) |
|---|---|---|
| MLP params/layer | 3DF | ~200M |
| Attn params/layer | 2D(N+K)H ≈ 4D² (MHA) | ~67M |
| Total params | (3DF + 4D²)L + DV | ~17.3B |
| Training FLOPs/token | 6P (ignoring attn) | ~104 GFLOPs |
| Attn crossover | T > 8D | T > 32K |
| KV cache/token | 2LKH | 512 KB (int8, MHA) |
| Block remat overhead | +33% FLOPs (6P → 8P) | Saves ~20x activation memory |
| MoE critical batch | 120 × E/k | 3840 (DeepSeek) |
Armed with these equations and the roofline/sharding tools from Chapters 1–3, you can now analyze any distributed training or inference setup: estimate step time, identify bottlenecks, and choose optimal parallelism strategies.