Austin et al., Part 6

Training LLaMA 3 on TPUs

Counting parameters, estimating FLOPs, choosing parallelism configs, and calculating training time — all by hand.

Prerequisites: Parallelism strategies (Ch 5: FSDP, TP, PP), Transformer architecture (attention, FFN, parameter counting).

Chapter 0: The LLaMA 3 Family

The LLaMA 3 family includes three main models: 8B, 70B, and 405B parameters. In this chapter we focus on the 70B variant, leaving the others as exercises.

Here is the architecture for LLaMA 3-70B, taken from the HuggingFace model config:

| Hyperparameter | Symbol | Value |
|---|---|---|
| Layers | L | 80 |
| Model dimension | D | 8,192 |
| FFN dimension | F | 28,672 |
| Attention heads | N | 64 |
| KV heads | K | 8 (GQA) |
| Head dimension | H | 128 |
| Vocabulary size | V | 128,256 |

Grouped Query Attention (GQA): LLaMA 3 uses 64 query heads but only 8 KV heads. Each group of 8 query heads shares one KV head. This reduces the KV cache size by 8x during inference while barely affecting quality.

Comparing the three LLaMA 3 models:

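| Hyperparameter | 8B | 70B | 405B |
|---|---|---|---|
| Layers (L) | 32 | 80 | 126 |
| Model dim (D) | 4,096 | 8,192 | 16,384 |
| FFN dim (F) | 14,336 | 28,672 | 53,248 |
| Attention heads (N) | 32 | 64 | 128 |
| KV heads (K) | 8 | 8 | 8 |
| Head dim (H) | 128 | 128 | 128 |
| Vocab (V) | 128,256 | 128,256 | 128,256 |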

Notice the patterns: H is always 128, V is always 128,256, and K is always 8 (GQA). The models scale primarily by increasing D, F, and L. F/D is 3.5 for the 8B and 70B models and 3.25 for 405B, somewhat above the 8/3 ≈ 2.67 SwiGLU convention. Also note that 405B has K=8 like 70B: the number of KV heads is the same regardless of model size. This is a deliberate design choice, since more KV heads do not help quality much but increase KV cache size during inference in direct proportion.

The configuration is all you need to derive every important number: parameter count, FLOPs, memory, training time, and cost. Let us start.

Where to find these numbers: Every model on HuggingFace has a config.json that lists these hyperparameters. The mapping from HuggingFace names to our symbols:

| HuggingFace key | Symbol |
|---|---|
| num_hidden_layers | L |
| hidden_size | D |
| intermediate_size | F |
| num_attention_heads | N |
| num_key_value_heads | K |
| vocab_size | V |

The head dimension H is typically D/N = 8192/64 = 128 for LLaMA models, but it is also listed as head_dim in some configs.
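
As a minimal sketch, the mapping above can be applied to a config dictionary in Python. The dictionary below is typed in by hand from the LLaMA 3-70B table in this chapter rather than downloaded, and the helper name `symbols_from_config` is our own, not part of any library.

```python
# Values copied by hand from the LLaMA 3-70B table in this chapter
# (an assumption: nothing is fetched from the HuggingFace Hub here).
LLAMA3_70B_CONFIG = {
    "num_hidden_layers": 80,
    "hidden_size": 8192,
    "intermediate_size": 28672,
    "num_attention_heads": 64,
    "num_key_value_heads": 8,
    "vocab_size": 128256,
}

# HuggingFace key -> symbol used in this chapter.
HF_TO_SYMBOL = {
    "num_hidden_layers": "L",
    "hidden_size": "D",
    "intermediate_size": "F",
    "num_attention_heads": "N",
    "num_key_value_heads": "K",
    "vocab_size": "V",
}


def symbols_from_config(config: dict) -> dict:
    """Translate a HuggingFace-style config dict into the chapter's symbols."""
    sym = {symbol: config[key] for key, symbol in HF_TO_SYMBOL.items()}
    # Use head_dim when the config lists it; otherwise fall back to D / N.
    sym["H"] = config.get("head_dim", sym["D"] // sym["N"])
    return sym


cfg = symbols_from_config(LLAMA3_70B_CONFIG)
print(cfg)
```

To use it on a real model, load the model's config.json with `json.load` and pass the resulting dictionary in place of the hand-typed one.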

It is useful to make a spreadsheet with these numbers for many open-source LLMs. You will quickly see patterns: most models use H=128, most use SwiGLU with F ≈ 8/3 × D (rounded to a multiple of 256), and GQA is now standard.

Comparison with other model families:

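| Model | Params | D | F | L | F/D | KV heads |
|---|---|---|---|---|---|---|
| LLaMA 3-70B | 70B | 8,192 | 28,672 | 80 | 3.5 | 8 |
| Mistral-7B | 7B | 4,096 | 14,336 | 32 | 3.5 | 8 |
| GPT-3 | 175B | 12,288 | 49,152 | 96 | 4.0 | 96 (MHA) |
| DeepSeek-V3 | 671B | 7,168 | 18,432 | 61 | 2.6 | MLA |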

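Most recent models use a gated FFN (SwiGLU) with F/D between roughly 2.5 and 4, together with GQA or another scheme that keeps the KV cache small. GPT-3 predates both: it used a standard (non-gated) FFN with F = 4D and full MHA with 96 KV heads, so its per-layer KV cache is 12x larger than LLaMA 3-70B's (96 vs 8 KV heads at the same head dimension of 128). DeepSeek-V3 uses Multi-head Latent Attention (MLA), which compresses KV representations differently.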

Check: How many KV head groups does LLaMA 3-70B use?

Chapter 1: Counting Parameters

Let us derive the 70B parameter count from the config table. Every parameter in the Transformer falls into one of three groups:

| Component | Formula | Count |
|---|---|---|
| FFN (SwiGLU) | D × F × 3 × L (gate + up + down projections) | 8,192 × 28,672 × 3 × 80 = 56.4B |
| Attention | L × [2 × D × N × H + 2 × D × K × H] (Q, O projections + K, V projections) | 80 × (2 × 8192 × 64 × 128 + 2 × 8192 × 8 × 128) = 12.1B |
| Embeddings | 2 × V × D (input + output embeddings) | 2 × 128,256 × 8,192 = 2.1B |
| RMSNorm | 2 × D × L (pre-attn + pre-FFN) | 2 × 8192 × 80 = 1.3M (negligible) |
| Total | | 70.6B |

Notice RMSNorm adds only 1.3M parameters — completely negligible. Rotary positional embeddings (RoPE) have zero stored parameters; they are computed on the fly. The parameter count is entirely dominated by the linear projections.

The MLP dominates everything. The three FFN weight matrices account for 56.4B of 70.6B total parameters, roughly 80%. This means when reasoning about memory or FLOPs, you can almost ignore everything else. This is true for virtually all modern dense LLMs.

Why SwiGLU has 3 weight matrices instead of 2: A standard FFN has two matrices: up-projection (D → F) and down-projection (F → D). SwiGLU adds a gate matrix (D → F) whose SiLU-activated output multiplies the up-projection element-wise before the down-projection: output = W_down · (SiLU(W_gate · x) ⊙ (W_up · x)). That is three matrices with D × F entries each, which is why the FFN count is D × F × 3 × L, not D × F × 2 × L.

Note on F for SwiGLU: Since SwiGLU has 3 matrices instead of 2, the FFN width F is typically reduced to keep the total parameter count similar to a standard FFN. The rule of thumb: F ≈ 8D/3 rounded to a multiple of 256. For LLaMA 70B: 8 × 8192 / 3 = 21845, but the actual value 28672 is larger, likely chosen for efficiency reasons (nice multiple of chip dimensions).

Let us verify the attention count step by step. Each layer has:

Q projection: D × (N × H) = 8192 × 8192 = 67.1M
K projection: D × (K × H) = 8192 × 1024 = 8.4M
V projection: D × (K × H) = 8192 × 1024 = 8.4M
O projection: (N × H) × D = 8192 × 8192 = 67.1M
Per layer: 67.1M + 8.4M + 8.4M + 67.1M = 151M
All layers: 151M × 80 = 12.1B ✓
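
The count above can be reproduced with a few lines of Python. This is a sketch of the table's formulas only; the function name `count_params` is ours, and it ignores small terms such as the final RMSNorm before the output projection.

```python
def count_params(L, D, F, N, K, H, V):
    """Parameter counts per group, following the formulas in this chapter."""
    ffn = 3 * D * F * L                         # gate + up + down projections
    attn = L * (2 * D * N * H + 2 * D * K * H)  # Q, O projections + K, V projections
    emb = 2 * V * D                             # input + output embeddings
    norm = 2 * D * L                            # two RMSNorm scale vectors per layer
    return {
        "ffn": ffn,
        "attention": attn,
        "embeddings": emb,
        "rmsnorm": norm,
        "total": ffn + attn + emb + norm,
    }


counts = count_params(L=80, D=8192, F=28672, N=64, K=8, H=128, V=128256)
for name, value in counts.items():
    print(f"{name:>10}: {value / 1e9:7.3f}B")
print(f"FFN share: {counts['ffn'] / counts['total']:.1%}")
```

Printing three decimals makes the effect of rounding visible: each group is slightly larger than its one-decimal figure suggests, which is why hand totals can differ by 0.1-0.2B depending on when you round.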
Check: What fraction of LLaMA 3-70B's parameters are in the FFN blocks?

Chapter 2: FLOPs per Token

A standard rule of thumb: a training step (forward + backward) uses approximately 6 × parameter count FLOPs per token. The factor of 6 comes from: 2 for the forward pass (multiply-accumulate) × 3 for forward + two backward passes.

FLOPs per token = 6 × 70e9 = 4.2 × 10^11

That is about half a teraFLOP per token per step. On a single TPU v5p chip (459 TFLOPS bf16):

Time per token = 4.2e11 / 4.59e14 ≈ 0.9 ms

This assumes we are compute-bound and achieving near-peak FLOPs. In practice, we target 30-50% MFU (Model FLOPs Utilization).

What is MFU? Model FLOPs Utilization measures what fraction of the hardware's peak FLOPs is actually spent on useful model computation. An MFU of 40% means 60% of the FLOPs capacity is lost to communication, memory loading, pipeline bubbles, or other overhead. MFU = 40-50% is considered good for large-scale training.

Breaking down the 6x factor: Where does the "6 × params" rule come from?

Forward pass: 2P FLOPs. Each parameter participates in one multiply and one add (MAC) per token.
Backward (activation grads): 2P FLOPs. Same matmuls as forward, but transposed.
Backward (weight grads): 2P FLOPs. Outer product of activations and activation gradients.
Total = 6P FLOPs/token

This is an approximation. It ignores the attention score computation (QK^T and the softmax-weighted sum over V, about 4 × N × S × H forward FLOPs per token per layer), RMSNorm, and activation functions. For long sequences, the attention FLOPs can be significant. But for the standard regime (seq < 8K, large FFN), the 6P rule is accurate to within about 5-6%.

Checking the 6x rule against an exact forward-pass count:

MLP FLOPs per token per layer = 6 × D × F (3 matmuls, each 2DF, forward only)
Attention FLOPs per token per layer ≈ 4 × D × (N+K) × H + 4 × N × S × H
For LLaMA 70B at S=4096:
MLP: 6 × 8192 × 28672 = 1.41e9 FLOPs/token/layer
Attention projections: 4 × 8192 × 72 × 128 = 3.02e8
Attention scores (QK^T and softmax×V): 4 × 64 × 4096 × 128 = 1.34e8 (small!)
Total per layer: 1.41e9 + 3.02e8 + 1.34e8 = 1.85e9
All layers: 1.85e9 × 80 ≈ 1.48e11 per token (the output embedding matmul adds about 2e9 more)
The 6P approximation, forward share: 6 × 70e9 / 3 = 2 × 70e9 = 1.4e11 per token

The exact forward count is about 5% higher than the 2P estimate. The gap comes mainly from the attention score FLOPs, which involve no parameters and are therefore invisible to a rule based on parameter count; it grows with sequence length. For back-of-the-envelope total training time, 6P is good enough.

Detailed FLOPs breakdown by operation (forward pass only, per token):

| Operation | FLOPs/token/layer | All 80 layers |
|---|---|---|
| FFN (3 projections) | 3 × 2 × 8192 × 28672 = 1.41e9 | 1.13e11 |
| Attn projections (Q,K,V,O) | 2 × 8192 × (2×8192 + 2×1024) = 3.02e8 | 2.41e10 |
| Attn QK^T + softmax×V | 4 × 64 × 4096 × 128 = 1.34e8 | 1.07e10 |
| Total forward | | 1.48e11 |
| Total fwd+bwd (×3) | | 4.44e11 |

The exact count of 4.44e11 vs our 6P estimate of 4.2e11 shows the rule is accurate to within 6%. Good enough for all practical purposes.
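
The breakdown above is easy to recompute. The following sketch uses the same per-operation formulas (no causal-masking discount, no embedding or normalization FLOPs); the function name and the dictionary layout are ours.

```python
def forward_flops_per_token(L, D, F, N, K, H, S):
    """Forward-pass matmul FLOPs per token, split by operation."""
    ffn = 3 * 2 * D * F                           # gate, up, down: 2*D*F each
    attn_proj = 2 * D * (2 * N * H + 2 * K * H)   # Q, O and K, V projections
    attn_scores = 4 * N * S * H                   # QK^T plus softmax-weighted V
    per_layer = ffn + attn_proj + attn_scores
    return {
        "ffn": L * ffn,
        "attn_proj": L * attn_proj,
        "attn_scores": L * attn_scores,
        "total": L * per_layer,
    }


fwd = forward_flops_per_token(L=80, D=8192, F=28672, N=64, K=8, H=128, S=4096)
exact_train = 3 * fwd["total"]   # forward + two backward passes
rule_of_thumb = 6 * 70e9         # the 6P estimate
ratio = exact_train / rule_of_thumb

for name, value in fwd.items():
    print(f"{name:>12}: {value:.3e}")
print(f"exact / 6P = {ratio:.3f}")
```

Changing `S` shows how quickly the attention-score term grows with sequence length, which is the main case where 6P stops being a good estimate.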

LLaMA 3 was trained for approximately 15 trillion tokens. Total training FLOPs:

Total FLOPs = 4.2e11 × 15e12 = 6.3 × 10^24

That is 6.3 yottaFLOPs. On a single TPU v5p running at peak, this would take:

6.3e24 / 4.59e14 = 1.37 × 10^10 seconds ≈ 435 years
Check: How many FLOPs does one training step use per token for a 70B model?

Chapter 3: Total Training Cost

Now let us estimate the training time on a full TPU v5p pod (8960 chips) at 40% MFU:

T = Total FLOPs / (N × C × MFU), where N is the number of chips (not the head count) and C is the peak FLOPs/s per chip
T = 6.3e24 / (8960 × 4.59e14 × 0.4)
T = 6.3e24 / 1.645e18 = 3.83 × 10^6 seconds
T ≈ 44 days

Let us sanity-check this. Each training step processes 4M tokens. FLOPs per step = 4.2e11 × 4e6 = 1.68e18. Time per step = 1.68e18 / (8960 × 4.59e14 × 0.4) = 1.02 seconds. Number of steps = 15e12 / 4e6 = 3.75M steps. Total time = 3.75e6 × 1.02 = 3.83e6 seconds = 44.3 days. The two approaches agree.
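
The same sanity check, written as a short Python sketch. All inputs (459 TFLOPS peak, 40% MFU, 4M-token steps, 15T tokens) are the values assumed in this chapter; the function name is ours.

```python
PEAK_FLOPS = 4.59e14  # TPU v5p bf16 peak, as used in this chapter


def training_seconds(total_flops, chips, mfu, peak=PEAK_FLOPS):
    """Wall-clock seconds if every chip sustains mfu * peak."""
    return total_flops / (chips * peak * mfu)


flops_per_token = 6 * 70e9
tokens, batch_tokens = 15e12, 4e6
chips, mfu = 8960, 0.40

# Route 1: total FLOPs divided by sustained cluster throughput.
t_direct = training_seconds(flops_per_token * tokens, chips, mfu)

# Route 2: time per step multiplied by the number of steps.
step_seconds = training_seconds(flops_per_token * batch_tokens, chips, mfu)
steps = tokens / batch_tokens
t_steps = step_seconds * steps

print(f"direct:   {t_direct / 86400:.1f} days")
print(f"per-step: {t_steps / 86400:.1f} days "
      f"({steps:.2e} steps at {step_seconds:.2f} s)")
```

The two routes are algebraically identical, so any disagreement points to an arithmetic slip rather than a modeling difference.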

Steps per day: 3.75e6 steps / 44.3 days = ~84,600 steps/day, or about 1 step per second. This is important for monitoring — if your training run suddenly drops to 0.5 steps/second, you know something is wrong (communication bottleneck, hardware failure, or checkpointing overhead).

44 days on one pod. That is fairly reasonable, assuming we can actually achieve 40% MFU. The critical enabler is choosing the right parallelism configuration to keep communication overhead low.

Let us also estimate the dollar cost. TPU v5p chips cost approximately $4.20/hour on Google Cloud:

Cost = 8960 chips × $4.20/hr × (3.83e6 / 3600) hrs
Cost = 8960 × $4.20 × 1064 = $40.0 million

This is a rough estimate. Real costs include fault tolerance overhead (checkpointing, restarts), networking costs, storage, and the engineering team. The actual LLaMA 3 report quotes training on 16,384 H100 GPUs.

Hidden costs beyond compute:

| Cost component | Multiplier | Notes |
|---|---|---|
| GPU/TPU compute | 1x (base) | The $40M we computed |
| Fault tolerance overhead | +5-15% | Time lost to restarts, checkpointing |
| Data preparation | +5-10% | Tokenization, dedup, quality filtering |
| Evaluation runs | +10-20% | Benchmark suites, ablation studies |
| Infrastructure (networking, storage) | +10-20% | High-speed interconnects, data storage |
| Engineering team | +20-50% | Salaries for 10-50 engineers over 6-12 months |
| Realistic total | 1.5-2.5x | $60-100M for a 70B model |

Training time and cost by cluster size (every row at 40% MFU and $4.20 per chip-hour):

| Topology | Chips | Training time (40% MFU) | Approximate cost |
|---|---|---|---|
| 1 TPU v5p | 1 | ~1,090 years | $40M |
| 1/4 pod | 2240 | 176 days | $40M |
| Full pod | 8960 | 44 days | $40M |
| 4 pods | 35840 | 11 days | $40M |

Notice: The total cost is roughly the same regardless of how many chips you use (assuming the same MFU). More chips = faster training but same total chip-hours. The real benefit of more chips is time to completion, not cost savings.
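
A short sketch makes the invariance explicit: the chip count cancels out of the cost formula. The price ($4.20 per chip-hour), peak FLOPs, and 40% MFU are this chapter's assumptions, and every cluster size is evaluated at that same MFU.

```python
def time_and_cost(chips, total_flops=6.3e24, peak_flops=4.59e14,
                  mfu=0.40, usd_per_chip_hour=4.20):
    """Return (days, dollars) for a run at constant MFU."""
    seconds = total_flops / (chips * peak_flops * mfu)
    cost = chips * (seconds / 3600) * usd_per_chip_hour
    return seconds / 86400, cost


results = {chips: time_and_cost(chips) for chips in (1, 2240, 8960, 35840)}
for chips, (days, cost) in results.items():
    print(f"{chips:>6} chips: {days:12.1f} days, ${cost / 1e6:.1f}M")
```

Only the time column changes between rows. Any real difference in cost between cluster sizes therefore has to come from a difference in achieved MFU or in price, not from the chip count itself.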

Why is time-to-completion so important? In practice, faster training is worth paying a premium for. Reasons:

1. Iteration speed: If an experiment takes 44 days, you get feedback 4x faster than on a quarter-pod (176 days). Over a year of development, this means 4x more experiments.

2. Competitive pressure: In the frontier lab race, shipping a model 3 months earlier can be the difference between leading and lagging.

3. Reliability: Longer training runs are more likely to encounter hardware failures. Meta's LLaMA 3 report mentions that their 16K GPU cluster experienced an interruption roughly every 3 hours on average.

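4. MFU shifts with scale: If the global batch can grow with the cluster, utilization can hold steady or improve, and total cost stays flat or falls. At a fixed batch size, however, each chip sees fewer tokens as the cluster grows, which pushes MFU down (see Chapter 5), so cost invariance is only approximate.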

Check: If you double the number of chips (at the same MFU), what happens to total training cost?

Chapter 4: Memory Budget

How many chips do we need at minimum? This is a memory question, not a compute question. During training, HBM holds three things:

| Component | Formula | Size (LLaMA 70B) |
|---|---|---|
| Parameters (bf16) | 2 × P | 140 GB |
| Optimizer state (Adam, fp32) | 8 × P (momentum + variance; an fp32 master copy of the weights would add 4 × P) | 560 GB |
| Gradient checkpoints | 2 × D × B × n_ckpt × L (4 checkpoints/layer, 4M token batch) | ~20.9 TB |
| Total | | ~21.6 TB |

Gradient checkpoints dominate. Even with a conservative 4 checkpoints per layer, activations use 20.9 TB, dwarfing the 700 GB of model state. This is because the batch is enormous (4M tokens).

Let us derive the gradient checkpoint size. Each checkpoint saves the activation tensor at that point: shape (B, D) in bf16. With 4 checkpoints per layer:

Checkpoint memory = 2 bytes × D × B_tokens × 4 × L
= 2 × 8192 × 4,000,000 × 4 × 80
= 20.97 × 10^12 bytes = 20.97 TB

With 96 GB HBM per TPU v5p chip, minimum chips = 21.6e12 / 96e9 = 225 chips. That is tiny compared to 8960! We are using those extra chips not because we need the memory but because we need the FLOPs to finish training in reasonable time.

On 8960 chips, memory per chip = 21.6 TB / 8960 ≈ 2.4 GB per chip. We are using only 2.5% of HBM. Even with 12 checkpoints per layer (about 63 TB of activations), we would still only be at ~7 GB per chip.
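
The memory budget can be sketched in Python as follows. The function name and the `opt_bytes_per_param` parameter are ours; the default of 8 bytes matches the 560 GB optimizer figure in the table, and passing 12 adds an fp32 master copy of the weights.

```python
import math

HBM_BYTES = 96e9  # TPU v5p HBM per chip, as used in this chapter


def memory_budget_bytes(P, D, L, batch_tokens, ckpt_per_layer,
                        opt_bytes_per_param=8):
    """Return (parameter, optimizer, checkpoint) memory in bytes."""
    params = 2 * P                                     # bf16 weights
    optimizer = opt_bytes_per_param * P                # fp32 Adam state
    ckpt = 2 * D * batch_tokens * ckpt_per_layer * L   # bf16 activations
    return params, optimizer, ckpt


params, optimizer, ckpt = memory_budget_bytes(
    P=70e9, D=8192, L=80, batch_tokens=4_000_000, ckpt_per_layer=4)
total = params + optimizer + ckpt
min_chips = math.ceil(total / HBM_BYTES)
per_chip_on_pod = total / 8960

print(f"checkpoints: {ckpt / 1e12:.2f} TB, total: {total / 1e12:.2f} TB")
print(f"minimum chips: {min_chips}")
print(f"per chip on 8960 chips: {per_chip_on_pod / 1e9:.2f} GB")
```

Because the code does not truncate the total to 21.6 TB before dividing, it reports 226 chips rather than 225. The difference is rounding only and does not affect the conclusion.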

How many chips do we need at minimum for each LLaMA 3 model?

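| Model | Param memory (bf16) | Optimizer (Adam fp32, 12 × P incl. master weights) | Total (no checkpoints) | Min TPU v5p chips |
|---|---|---|---|---|
| 8B | 16 GB | 96 GB | 112 GB | 2 |
| 70B | 140 GB | 840 GB | 980 GB | 11 |
| 405B | 810 GB | 4860 GB | 5670 GB | 60 |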

Even the 405B model only needs 60 chips for the model state alone. The gradient checkpoints (which depend on batch size) can add significantly more, but the point remains: memory is not the binding constraint at production scales.

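Why do we use bf16 for parameters and fp32 for optimizer? The parameters and gradients use bf16 (2 bytes each) because the forward and backward passes can tolerate reduced precision. But the Adam optimizer maintains a running average of gradients (momentum m) and squared gradients (variance v), plus a master copy of the weights. These accumulators require fp32 (4 bytes each) to avoid numerical problems: bf16 keeps only about 8 bits of precision, so small updates added to a large weight are rounded away. Total per parameter: 4 (master weight) + 4 (m) + 4 (v) = 12 bytes, the figure used in the per-model table above. The 70B budget at the start of this chapter counts 8 bytes per parameter, so adding the master copy raises its optimizer line from 560 GB to 840 GB, which changes the total by only about 1%.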

Key takeaway: At large scale, we are almost never memory-bound during training. The cluster size is determined by how fast we want to train, not by memory constraints.

What are gradient checkpoints? During the backward pass, we need the activations from the forward pass to compute gradients. Normally we store all activations, but for large models this is prohibitive. Gradient checkpointing (also called "activation recomputation") stores only a few checkpoint activations and recomputes the rest during the backward pass. The trade-off: extra compute (~33% more FLOPs) for much less memory.

With 4 checkpoints per layer, we store 4 activation tensors of shape (B_micro, D) per layer. The "B" in the formula above is the full batch in tokens: with FSDP each chip stores checkpoints only for its own share of the batch, but summed over all chips and microbatches the checkpoints cover the whole batch.

What if we used only 1 checkpoint per layer?

Checkpoint memory = 2 × 8192 × 4,000,000 × 1 × 80 = 5.24 TB
Total = 140 GB + 560 GB + 5.24 TB = 5.94 TB
Per chip (8960): 5.94 TB / 8960 = 0.66 GB — still trivial!

Even with aggressive checkpointing and microbatching, memory per chip stays well under the 96 GB limit on v5p. This confirms that at these scales, we have enormous memory headroom.

Check: What dominates HBM usage during LLaMA 3-70B training with 4M token batch?

Chapter 5: Sharding LLaMA 70B

Let us work through the parallelism config for LLaMA 3-70B on a full TPU v5p pod (8960 chips) with a 4M token batch (1024 sequences of length 4096).

Attempt 1: Pure FSDP. Can we shard everything with FSDP alone?

Per-chip batch = 4M / 8960 ≈ 447 tokens (468 with the exact batch of 1024 × 4096 = 4,194,304 tokens)
FSDP comms-bound threshold (3 ICI axes) = 2550 / 3 = 850
468 < 850 ⇒ Communication-bound!
Note that with only 1024 sequences, plain batch sharding stops at 1024-way, so 8960-way FSDP already requires sharding over both the batch and sequence axes. Even then each chip sees only 468 tokens (4 × 1024 × 1024 / 8960), well below 850. Pure FSDP fails.

Attempt 2: FSDP + TP. Use the optimal FSDP formula:

X_opt = √(2 × 4.19e6 × 8960 / 28672) = √(2.62e6) ≈ 1618

Round to 2048-way FSDP, giving TP = 8960 / 2048 ≈ 4-way TP.

Let us verify both dimensions are compute-bound:

TP check: Y = 4 < F × M_Y / 2200 = 28672 × 2 / 2200 = 26 ✓
FSDP check: per-FSDP-group batch = 468 × 4 = 1872. Threshold = 2550/2 = 1275. 1872 > 1275 ✓
The winning config: ~1024-way data parallelism + 2-way sequence parallelism + 4-way tensor parallelism. This keeps us compute-bound on all dimensions.

In practice, the split is: 1024-way DP (one sequence per DP rank), 2-way sequence/context parallelism (splitting 4096-length sequences in half), and 4-way TP (splitting weight matrices across 4 chips within each node).
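
The checks in this chapter can be collected into a small script. This is a sketch that takes the thresholds exactly as the text states them (2550 tokens per ICI axis for FSDP, and the bound Y < F × M_Y / 2200 for TP); the function names are ours, and the batch uses the exact 1024 × 4096 tokens.

```python
import math

CRITICAL_TOKENS = 2550  # FSDP token threshold for one ICI axis (from the text)
TP_CONSTANT = 2200      # constant in the text's TP bound Y < F * M_Y / 2200


def fsdp_is_compute_bound(tokens, ici_axes):
    """True if the token count exceeds the FSDP threshold for these axes."""
    return tokens > CRITICAL_TOKENS / ici_axes


def tp_is_compute_bound(tp_degree, F, tp_axes):
    """True if the TP degree stays under the bound stated in the text."""
    return tp_degree < F * tp_axes / TP_CONSTANT


def optimal_fsdp_degree(batch_tokens, chips, F):
    """X_opt = sqrt(2 * B * N / F)."""
    return math.sqrt(2 * batch_tokens * chips / F)


batch_tokens = 1024 * 4096
chips, F, tp = 8960, 28672, 4
tokens_per_chip = batch_tokens / chips

pure_fsdp_ok = fsdp_is_compute_bound(tokens_per_chip, ici_axes=3)
x_opt = optimal_fsdp_degree(batch_tokens, chips, F)
mixed_fsdp_ok = fsdp_is_compute_bound(tokens_per_chip * tp, ici_axes=2)
mixed_tp_ok = tp_is_compute_bound(tp, F, tp_axes=2)

print(f"tokens/chip = {tokens_per_chip:.0f}, pure FSDP ok: {pure_fsdp_ok}")
print(f"X_opt = {x_opt:.0f}")
print(f"FSDP + {tp}-way TP: FSDP ok = {mixed_fsdp_ok}, TP ok = {mixed_tp_ok}")
```

Rerunning with `chips = 2240` shows pure FSDP passing its check, which is the quarter-pod case discussed later in this chapter.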

Why not more TP? We could do 8-way TP with 1120-way FSDP. The TP limit is Y < F × M_Y / 2200. Taking the conservative case of M_Y = 1 axis: Y < 28672/2200 = 13. So 8-way TP is fine. With more TP, the per-FSDP-group batch would be 468 × 8 = 3744, and the FSDP degree would be 1120. Even with a single axis left for FSDP, 3744 > 2550/1 = 2550, so 8-way TP + 1120-way FSDP also works. But 4-way TP is simpler and has less TP communication, so it is preferred.

Can we train on fewer chips? Say we only had 2240 chips (a quarter-pod):

Per-chip batch = 4M / 2240 = 1786 tokens
FSDP threshold (3 axes) = 850. 1786 > 850 ⇒ Pure FSDP works!

On a quarter-pod, we do not even need TP. Pure FSDP suffices because the per-chip batch is large enough. Training time = 44 × 4 = 176 days.

What about 225 chips (the minimum for memory)?

Per-chip batch = 4M / 225 = 17,778 tokens
FSDP threshold = 850. 17,778 >> 850 ⇒ Pure FSDP, compute-bound, easy.
But training time = 44 × (8960/225) = 1752 days = 4.8 years!

Technically possible. Practically absurd. This illustrates why large clusters exist: not for memory, but for speed.

The "Goldilocks zone" for chip count:

| Chips | Per-chip batch | Parallelism needed | Training time | Status |
|---|---|---|---|---|
| 225 | 17,778 | FSDP only | 4.8 years | Too slow |
| 2240 | 1,786 | FSDP only | 176 days | Slow but feasible |
| 8960 | 447 | FSDP + 4-way TP | 44 days | Production sweet spot |
| 35840 | 112 | PP + FSDP + TP | 11 days | Fast but complex |

Check: Why does pure FSDP fail for LLaMA 70B on 8960 chips?

Chapter 6: FSDP vs FSDP+TP — The Numbers

Let us make the communication cost concrete. For a single FFN layer of LLaMA 70B (D=8192, F=28672):

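Pure FSDP (8960-way), forward pass:

Comms per layer = 3 × W_layer = 3 × 2 bytes × D × F = 3 × 2 × 8192 × 28672 = 1.41 GB (AllGather of the three FFN matrices)
ICI bandwidth = 180 GB/s per axis (the value behind the 2550 threshold: 4.59e14 / 1.8e11 = 2550), so 540 GB/s over 3 axes
T_comms = 1.41 GB / 540 GB/s = 2.6 ms
T_compute = 3 × 2 × 8192 × 28672 × 468 / 4.59e14 = 1.4 ms
T_comms > T_compute ⇒ Communication-bound!

FSDP (2048-way) + TP (4-way), forward pass:

FSDP comms per layer = 3 × 2 × D × F / 4 = 0.35 GB (parameter volume reduced by TP)
TP comms per layer = 2 collectives × 2 bytes × (4 × 468) × D = 61.3 MB (activations of the 4-chip group)
T_FSDP-comms = 0.35 GB / (2 axes × 180 GB/s) = 0.98 ms
T_TP-comms = 61.3 MB / (1 axis × 180 GB/s) = 0.34 ms
T_compute = 3 × 2 × 8192 × 28672 × (4 × 468) / (4 × 4.59e14) = 1.4 ms per chip (unchanged: the group sees 4x the tokens, and each chip does 1/4 of every matmul)
T_compute > T_comms ⇒ Compute-bound!

The backward pass roughly doubles both the communication and the compute, so the comparison comes out the same way.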
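
The comparison can be sketched in Python for a single FFN layer's forward pass. The per-axis ICI bandwidth here is an assumption derived from the 2550 threshold used in Chapter 5 (4.59e14 FLOPs/s divided by 2550 gives 1.8e11 bytes/s). The function name, and the choice to count one AllGather plus one ReduceScatter of activations for TP, are ours.

```python
PEAK_FLOPS = 4.59e14
# Assumption: bandwidth implied by the 2550-token threshold, 1.8e11 bytes/s.
ICI_BYTES_PER_S_PER_AXIS = PEAK_FLOPS / 2550

D, F = 8192, 28672


def ffn_forward_times(tokens_per_chip, tp, fsdp_axes, tp_axes=0):
    """Seconds for compute, FSDP comms and TP comms of one FFN forward pass."""
    # Each chip does 1/tp of the matmuls for tp * tokens_per_chip tokens.
    compute = 6 * D * F * tokens_per_chip / PEAK_FLOPS
    # AllGather of three bf16 matrices, already split tp ways.
    fsdp_bytes = 3 * 2 * D * F / tp
    fsdp_comms = fsdp_bytes / (fsdp_axes * ICI_BYTES_PER_S_PER_AXIS)
    if tp > 1:
        # Two collectives over the group's [tokens, D] bf16 activations.
        tp_bytes = 2 * 2 * (tokens_per_chip * tp) * D
        tp_comms = tp_bytes / (tp_axes * ICI_BYTES_PER_S_PER_AXIS)
    else:
        tp_comms = 0.0
    return {"compute": compute, "fsdp_comms": fsdp_comms, "tp_comms": tp_comms}


pure = ffn_forward_times(tokens_per_chip=468, tp=1, fsdp_axes=3)
mixed = ffn_forward_times(tokens_per_chip=468, tp=4, fsdp_axes=2, tp_axes=1)

for label, t in (("pure FSDP", pure), ("FSDP + 4-way TP", mixed)):
    print(label, {k: f"{v * 1e3:.2f} ms" for k, v in t.items()})
```

The ratio of compute time to FSDP communication time equals the per-group token count divided by the relevant threshold (468/850 and 1872/1275), so this is the same test as in Chapter 5 expressed in milliseconds.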

Let us also verify the memory picture. With 2048-way FSDP and 4-way TP:

Params per FSDP shard = 70B / 2048 = 34.2M params = 68.4 MB (bf16)
Optimizer per shard = 34.2M × 12 = 410 MB (Adam fp32)
Per shard: 68.4 MB + 68.4 MB (grads) + 410 MB = 547 MB (an upper bound per chip, since 4-way TP splits each shard across 4 chips)
Add activations (checkpointed): 2 × 8192 × 468 × 4 = 30.7 MB per layer × 80 = 2.46 GB
Total per chip: at most ~3.0 GB out of 96 GB. Only 3.1% utilized!

Why so much wasted memory? This is a consequence of using 8960 chips for a problem that only needs 225 chips of memory. The remaining 93 GB per chip (97% of HBM!) sits unused. In theory, we could use this spare memory for:

Denser checkpointing: store more intermediate activations to reduce recomputation.
Larger micro-batch sizes: process more tokens per chip per step for better hardware utilization.
Pre-fetching next layers: cache future AllGather results for latency hiding.

In practice, frameworks like Megatron-LM and MaxText do use some of this spare memory for communication buffers and prefetching, which helps overlap communication with compute.

Practical training frameworks for LLaMA-scale models:

| Framework | Hardware | Parallelism support | Used by |
|---|---|---|---|
| Megatron-LM | NVIDIA GPU | TP, PP, DP, CP, EP | NVIDIA, Meta (LLaMA) |
| MaxText | TPU | FSDP, TP, PP (via JAX) | Google DeepMind |
| DeepSpeed | GPU | ZeRO-1/2/3, PP, TP | Microsoft |
| FSDP (PyTorch) | GPU | ZeRO-3 style | Meta, community |
| Nanotron | GPU | TP, PP, DP | Hugging Face |

All of these implement the same fundamental parallelism strategies we derived in this chapter. The differences are in engineering: how well they overlap communication with compute, how they handle fault tolerance, and how easy they are to configure. For a new project, the choice is usually determined by your hardware (TPU ⇒ MaxText, NVIDIA GPU ⇒ Megatron-LM or DeepSpeed).

Putting it all together — a complete training recipe for LLaMA 3-70B:

1. Hardware: 8960 TPU v5p (1 full pod). 96 GB HBM each, 459 TFLOPS bf16, 3 ICI axes.
2. Batch: 4M tokens (1024 seqs × 4096 len). Large enough to stay compute-bound under FSDP + TP.
3. Parallelism: 2048-way FSDP + 4-way TP. Compute-bound on both dimensions.
4. Precision: bf16 params, fp32 optimizer. At most ~0.55 GB of model state per chip.
5. Checkpointing: 4/layer + async save every 1K steps. ~2.4 GB activations per chip, ~3 GB total (3.1% of HBM).
6. Result: ~44 days, ~$40M, ~40% MFU. 3.75M steps at ~1 step/second.

Every number in this recipe was derived from first principles in this chapter. You can repeat this analysis for any model by plugging in the architecture hyperparameters and hardware specs.

The combined approach wins. By adding 4-way TP, we reduce the FSDP parameter volume by 4x (since TP already distributes the weights) and the TP comms are tiny compared to compute. The total communication drops below compute time.

What about scaling to 4 pods (35,840 chips)? We would add 4-way pipeline parallelism across pods:

4-way PP × 2048-way FSDP × 4-way TP = 32,768 chips
PP bubble = (4-1) / (M + 3). With M = 60 microbatches: 3/63 = 4.8% ✓
Training time = 44 days / 4 × (1 + 0.048) ≈ 11.5 days
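
The bubble estimate follows the usual formula (p − 1) / (m + p − 1) for p pipeline stages and m microbatches. A minimal sketch, with a function name of our choosing and the same simple "scale time by 1 + bubble" model used above:

```python
def pipeline_bubble_fraction(stages, microbatches):
    """Idle fraction of a simple pipeline schedule: (p - 1) / (m + p - 1)."""
    return (stages - 1) / (microbatches + stages - 1)


bubble_4 = pipeline_bubble_fraction(stages=4, microbatches=60)
days_4_pods = 44 / 4 * (1 + bubble_4)

print(f"4-stage bubble with 60 microbatches: {bubble_4:.1%}")
print(f"estimated time on 4 pods: {days_4_pods:.1f} days")

# Sweep the microbatch count to see how quickly the bubble shrinks.
for m in (8, 16, 32, 60, 128):
    print(f"m = {m:>3}: bubble = {pipeline_bubble_fraction(4, m):.1%}")
```

The sweep shows why PP needs many microbatches: with only 8 microbatches, a 4-stage pipeline idles for more than a quarter of each step.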
Check: What makes the FSDP+TP combination more efficient than pure FSDP?

Chapter 7: Training Cost Calculator

LLaMA Training Estimator: given a model size, a token budget, a chip count, and an MFU, estimate training time and cost for different model sizes and hardware configurations.

Default inputs: model size 70B, training tokens 15T, chips 8960, MFU 40%.
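
A minimal offline version of such an estimator can be written in a few lines of Python. It is a sketch built only from this chapter's formulas (6 × params × tokens, constant MFU, flat hourly price); the `Chip` class, the `estimate` function, and the preset names are ours, and the peak FLOPs and prices are the figures quoted in this chapter.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Chip:
    name: str
    peak_flops: float    # peak bf16 FLOPs/s per chip
    usd_per_hour: float  # assumed on-demand price per chip-hour


TPU_V5P = Chip("TPU v5p", 4.59e14, 4.20)
H100 = Chip("H100", 9.9e14, 10.80)


def estimate(params, tokens, chips, mfu, chip=TPU_V5P):
    """Estimate total FLOPs, training days and dollar cost."""
    total_flops = 6 * params * tokens
    seconds = total_flops / (chips * chip.peak_flops * mfu)
    return {
        "total_flops": total_flops,
        "days": seconds / 86400,
        "cost_usd": chips * (seconds / 3600) * chip.usd_per_hour,
    }


default = estimate(params=70e9, tokens=15e12, chips=8960, mfu=0.40)
on_h100 = estimate(params=70e9, tokens=15e12, chips=16384, mfu=0.40, chip=H100)

for label, r in (("8960 x TPU v5p", default), ("16384 x H100", on_h100)):
    print(f"{label}: {r['days']:.1f} days, ${r['cost_usd'] / 1e6:.1f}M")
```

The estimator deliberately leaves MFU as an input rather than predicting it, since achievable MFU depends on the parallelism analysis in Chapters 5 and 6.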
Check: At 40% MFU on 8960 TPU v5p chips, roughly how long does it take to train LLaMA 70B on 15T tokens?

Chapter 8: Summary

Let us collect the key results for training LLaMA 3-70B:

| Quantity | Value | How derived |
|---|---|---|
| Parameter count | 70.6B | Sum of FFN (56.4B) + Attention (12.1B) + Embeddings (2.1B) |
| FLOPs/token/step | 4.2e11 | 6 × param count |
| Total FLOPs (15T tokens) | 6.3e24 | FLOPs/token × token count |
| Training time (8960 chips, 40% MFU) | 44 days | Total FLOPs / (N × C × MFU) |
| Minimum chips (memory) | 225 | 21.6 TB / 96 GB per chip |
| Parallelism config | 2048 FSDP × 4 TP | X_opt = √(2BN/F) |
| Estimated cost | ~$40M | N × $/hr × hours |

The big lessons:
1. The MLP dominates parameter count (~80%) and therefore FLOPs.
2. At large scale, memory is abundant: the cluster size is for FLOPs, not memory.
3. Pure FSDP fails at 8960 chips; combining FSDP + TP keeps us compute-bound.
4. Total cost is invariant to cluster size (at constant MFU); more chips buys time, not dollars.

Extension: LLaMA 3-405B. Using the 405B config (D=16384, F=53248, L=126, V=128256):

Params: FFN = 16384 × 53248 × 3 × 126 = 330B. Attention = 126 × (2 × 16384 × 128 × 128 + 2 × 16384 × 8 × 128) ≈ 72B. Embeddings = 4.2B. Total ≈ 406B.
FLOPs/token = 6 × 406e9 = 2.44e12
Total FLOPs (15T tokens) = 2.44e12 × 15e12 = 3.65e25
Training time (8 pods, 71680 chips, 35% MFU) = 3.65e25 / (71680 × 4.59e14 × 0.35) = 3.17e6 s ≈ 37 days

This would use 8-way PP across pods, roughly 2048-way FSDP within each pod (8960 / 4 = 2240 chips per TP group), and 4-way TP. The PP bubble with 128 microbatches: (8-1)/(128+7) = 5.2%. Total efficiency: 0.35 × (1-0.052) ≈ 33% MFU after bubble.
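
As a cross-check on the extension, the three parameter-count formulas from Chapter 1 can be applied to all three configs in the Chapter 0 comparison table. This is a sketch with names of our choosing; it omits the negligible RMSNorm terms.

```python
# Hyperparameters copied from the comparison table in Chapter 0.
CONFIGS = {
    "8B":   dict(L=32,  D=4096,  F=14336, N=32,  K=8, H=128, V=128256),
    "70B":  dict(L=80,  D=8192,  F=28672, N=64,  K=8, H=128, V=128256),
    "405B": dict(L=126, D=16384, F=53248, N=128, K=8, H=128, V=128256),
}


def param_groups(L, D, F, N, K, H, V):
    """Return (FFN, attention, embedding) parameter counts."""
    ffn = 3 * D * F * L
    attn = L * (2 * D * N * H + 2 * D * K * H)
    emb = 2 * V * D
    return ffn, attn, emb


totals = {}
for name, cfg in CONFIGS.items():
    ffn, attn, emb = param_groups(**cfg)
    totals[name] = ffn + attn + emb
    print(f"{name:>5}: FFN {ffn / 1e9:6.1f}B  attn {attn / 1e9:5.1f}B  "
          f"emb {emb / 1e9:4.1f}B  total {totals[name] / 1e9:6.1f}B  "
          f"(FFN share {ffn / totals[name]:.0%})")
```

The totals land close to the nominal model names, which is a useful check that the hyperparameters were copied correctly before any FLOPs or cost estimate is built on them.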

Extension: LLaMA 3-8B. Much simpler (D=4096, F=14336, L=32):

Params ≈ 8B. FLOPs/token = 48e9. Total = 7.2e23 FLOPs.
On 2240 chips at 45% MFU: T = 7.2e23 / (2240 × 4.59e14 × 0.45) = 1.6e6 s = 18 days.
Pure FSDP: per-chip batch = 4M/2240 = 1786. Threshold = 850. Works!

Scaling laws and the cost of training: The Chinchilla scaling law suggests training on ~20 tokens per parameter for compute-optimal training. For 70B: 20 × 70B = 1.4T tokens. But LLaMA 3 was trained on 15T tokens — roughly 10x the Chinchilla-optimal amount. Why?

Because the Chinchilla law optimizes for training compute, not inference cost. A smaller model trained on more data can match the quality of a larger, compute-optimal one, and the smaller model is cheaper to serve. Since inference cost often exceeds training cost over a model's lifetime, over-training a smaller model is economically rational. This is sometimes called the "inference-aware" or "overtrained" scaling regime.

Comparing training costs across model sizes:

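| Model | Params | FLOPs (15T tokens) | v5p chips | MFU | Time | Cost |
|---|---|---|---|---|---|---|
| LLaMA 3-8B | 8B | 7.2e23 | 2,240 | 45% | ~18 days | ~$4.1M |
| LLaMA 3-70B | 70B | 6.3e24 | 8,960 | 40% | ~44 days | ~$40M |
| LLaMA 3-405B | 405B | 3.65e25 | 71,680 | 35% | ~37 days | ~$265M |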

The 405B model costs roughly 6.5x more than the 70B despite having only 5.8x more parameters, because MFU tends to decrease at larger scales (more communication overhead, pipeline bubbles, etc.).

What limits MFU at scale?

| Overhead source | Typical cost | How to minimize |
|---|---|---|
| FSDP AllGather/ReduceScatter | 5-15% | Overlap with compute, use TP to reduce volume |
| TP ReduceScatter | 3-8% | Keep TP degree low, use fast intra-node links |
| PP bubble | 3-10% | More microbatches, interleaved schedules |
| Activation recomputation | ~33% | Store more activations when HBM allows (memory/compute trade-off) |
| Data loading / preprocessing | 1-5% | Prefetch data in background |
| Checkpointing overhead | 1-3% | Async checkpointing |
| Non-matmul ops (RMSNorm, etc.) | 2-5% | Fused kernels |

Activation recomputation is the single largest overhead. With full recomputation, the forward pass runs twice (once forward, then again during the backward pass), so total compute is 8P instead of 6P per token, about 1.33x. Since MFU counts only the 6P of useful work, this alone limits theoretical max MFU to ~75%. Add communication and other overheads, and 40-50% is genuinely good.

What about training on H100 GPUs instead? Meta's actual LLaMA 3 training used 16,384 H100 GPUs. Let us estimate:

H100 bf16 peak = 990 TFLOPS. Total: 16384 × 9.9e14 = 1.62e19 FLOP/s
At 40% MFU: 6.3e24 / (1.62e19 × 0.4) = 9.72e5 s = 11.3 days
Cost: 16384 × $10.80/hr × (9.72e5 / 3600) = $47.8M

Comparable to our TPU estimate. The training time is shorter (11 vs 44 days) because the cluster has nearly four times the total peak FLOPs (1.62e19 vs 4.11e18), but the dollar cost is similar because H100s cost more per chip-hour.

Efficiency comparison:

| Metric | 8960 TPU v5p | 16384 H100 |
|---|---|---|
| Total peak FLOP/s | 4.11e18 | 1.62e19 |
| Training time (40% MFU) | 44 days | 11 days |
| Total chip-hours | 9.5M | 4.4M |
| Estimated cost | ~$40M | ~$48M |

Check: Why is the cluster size for LLaMA 70B determined by FLOPs rather than memory?