Austin et al., Part 5

How to Parallelize a Transformer for Training

Data parallelism, FSDP, tensor parallelism, pipeline parallelism — when each strategy wins, when it breaks, and how to combine them.

Prerequisites: Roofline model basics (arithmetic intensity, compute vs memory bound), sharded matmuls (AllGather, ReduceScatter costs).
10 chapters · 1 simulation · 10 quizzes

Chapter 0: What Do We Mean By Scaling?

You have a model that trains perfectly on one chip. Now you want to make it ten times bigger and train on a thousand chips. Will you get a thousand-fold speedup? Almost certainly not — but the gap between "linear speedup" and "actual speedup" is exactly what this chapter is about.

The goal of "model scaling" is deceptively simple: double the chips, halve the training time. In other words, we want strong scaling — a proportional, linear increase in throughput as we add more accelerators.

On a single chip, performance depends on the trade-off between memory bandwidth and FLOPs. At the cluster level, a new bottleneck emerges: inter-chip communication. Every time we add more chips, we increase the communication load while reducing the per-device computation available to hide it.

Think of it like a restaurant kitchen. One chef can cook a meal in an hour. Ten chefs should cook ten meals in an hour, but only if they do not spend half their time coordinating. The "parallelism strategies" in this chapter are the recipes for organizing the kitchen so the chefs stay busy cooking instead of talking.

The core question: Sharded matmuls require expensive AllGathers or ReduceScatters that can block chips from doing useful work. The goal of this chapter is to figure out when these become too expensive for each parallelism strategy.

We will study four common parallelism schemes:

Data Parallelism (DP): Replicate the model on every chip, split data across chips.
Fully-Sharded DP (FSDP / ZeRO): Shard parameters + optimizer across chips, gather on demand.
Tensor Parallelism (TP): Split weight matrices within each layer across chips.
Pipeline Parallelism (PP): Assign different layers to different chips, pipeline microbatches.

For each scheme, we will derive: (1) the memory savings, (2) the communication volume, and (3) the critical batch size below which we become communication-bound.

Before we dive in, let us establish some notation. For a Transformer model:

| Symbol | Meaning | Example (LLaMA 70B) |
|---|---|---|
| D | Model dimension (hidden size) | 8,192 |
| F | FFN intermediate dimension | 28,672 |
| L | Number of layers | 80 |
| B | Global batch size (tokens) | 4,000,000 |
| N | Number of chips | 8,960 |
| C | Peak FLOPs/s per chip | 4.59e14 (v5p) |
| W | Inter-chip bandwidth | 4800 GB/s (ICI, 3 axes) |

The key insight of this chapter is that every parallelism strategy can be analyzed through the lens of a single ratio: computation time vs communication time. When the compute per chip exceeds the communication per chip, we are in the happy "compute-bound" regime. When communication dominates, we are wasting expensive hardware on waiting.

Let us build intuition with a concrete example. Imagine training a 70B model with a 4M token batch on increasing numbers of TPU v5p chips:

| Chips | Per-chip batch (tokens) | Compute time/layer (ms) | Comms time/layer (ms) | Status |
|---|---|---|---|---|
| 256 | 15,625 | 5.9 | 0.29 | Compute-bound (20x margin) |
| 1024 | 3,906 | 1.5 | 0.29 | Compute-bound (5x margin) |
| 4096 | 977 | 0.37 | 0.29 | Borderline! |
| 8960 | 446 | 0.17 | 0.29 | Comms-bound (1.7x over) |

See how the compute time per layer shrinks as we add chips (each chip does less work), but the communication stays constant (the AllGather always moves the same volume of data)? The crossover point is where we need to get smarter about our parallelism strategy.

Check: What does "strong scaling" mean in distributed training?

Chapter 1: Data Parallelism

Data parallelism is the simplest distributed strategy. Every chip holds a full copy of the model. The global batch is split across N chips so each chip processes B/N samples. After the backward pass, gradients are summed via an AllReduce.

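| Property | Value |
|---|---|
| Model memory | Full copy on every chip (no savings) |
| Data split | Each chip sees B/N of the global batch |
| Communication | AllReduce of gradients once per step |
| Communication volume | ≈ 2P bytes per chip (ring AllReduce) |

Here P is the size of the gradient buffer in bytes (this section counts bytes; the memory tables later use P for the parameter count). The AllReduce has two phases: ReduceScatter (each chip sends (N-1)/N of its gradients and ends up holding 1/N of the summed result) and AllGather (each chip receives the remaining summed shards). In a ring topology, each chip sends and receives exactly 2P · (N-1)/N ≈ 2P bytes total.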

Let us be concrete about what these operations do:

| Operation | Input | Output | Bytes per chip |
|---|---|---|---|
| AllGather | Each chip has 1/N of the data | Every chip has all the data | (N-1)/N × total_size |
| ReduceScatter | Each chip has full data (different values) | Each chip has 1/N of the sum | (N-1)/N × total_size |
| AllReduce | Each chip has full data | Every chip has the sum | ReduceScatter + AllGather = 2 × (N-1)/N × total_size |

In a ring AllReduce with N chips, the data is split into N chunks. Each chip sends one chunk to its neighbor and receives one chunk, repeating N-1 times for the ReduceScatter phase and N-1 times for the AllGather phase. Total per-chip transfer: 2 × (N-1)/N × data_size ≈ 2 × data_size for large N.

Why the ring AllReduce is bandwidth-optimal: Any algorithm that sums N copies of a vector of size S must transfer at least 2S(N-1)/N bytes per chip (information theory lower bound). The ring AllReduce achieves this exactly, making it the most bandwidth-efficient collective for large messages. For small messages, latency (not bandwidth) dominates, and tree-based reductions are preferred.
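The two ring phases can be traced with a small simulation. This is a minimal sketch, assuming each chip's data is already split into N chunks and each chunk is a single number; the function name `ring_allreduce` is illustrative and not taken from any library.

```python
def ring_allreduce(chip_data):
    """Simulate a ring AllReduce over N chips, each holding a list of N chunks.

    Returns (final_data, chunks_sent_per_chip). Chunks are plain numbers here.
    """
    n = len(chip_data)
    data = [list(chunks) for chunks in chip_data]  # copy so the input is untouched
    sent = 0  # every chip sends exactly one chunk per step

    # Phase 1, ReduceScatter: after N-1 steps chip i holds the full sum of chunk (i+1) % n.
    for step in range(n - 1):
        outgoing = [data[i][(i - step) % n] for i in range(n)]  # snapshot before updates
        for i in range(n):
            src = (i - 1) % n  # left neighbor in the ring
            data[i][(src - step) % n] += outgoing[src]
        sent += 1

    # Phase 2, AllGather: circulate the finished chunks around the ring.
    for step in range(n - 1):
        outgoing = [data[i][(i + 1 - step) % n] for i in range(n)]
        for i in range(n):
            src = (i - 1) % n
            data[i][(src + 1 - step) % n] = outgoing[src]
        sent += 1

    return data, sent


chips = [[1, 2, 3, 4], [10, 20, 30, 40], [100, 200, 300, 400], [1000, 2000, 3000, 4000]]
result, chunks_sent = ring_allreduce(chips)
```

Each chip sends 2 × (N-1) chunks of size data_size/N, which is the 2 × (N-1)/N × data_size figure quoted above.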

Gradient accumulation — the DP memory trick: If the per-chip batch B/N is too large to fit in memory (due to activations), we can split it further into micro-steps. Compute gradients for a smaller micro-batch, accumulate them in memory, and only run the AllReduce once all micro-batches are done. This uses the same communication volume but allows us to process large effective batch sizes with limited memory. The trade-off: more sequential compute per step, but the same communication cost.
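A toy check that accumulation reproduces the full-batch gradient. This sketch assumes a one-parameter model y = w·x with squared-error loss; the helper names are hypothetical and the only point is that summing micro-batch gradients before a single (simulated) AllReduce gives the same result as one large batch.

```python
def grad_sum(w, xs, ys):
    """Sum (not mean) of d/dw (w*x - y)^2 over a batch."""
    return sum(2.0 * (w * x - y) * x for x, y in zip(xs, ys))


def full_batch_grad(w, xs, ys):
    """Mean gradient computed in one pass over the whole per-chip batch."""
    return grad_sum(w, xs, ys) / len(xs)


def accumulated_grad(w, xs, ys, micro_batch_size):
    """Same mean gradient, computed micro-batch by micro-batch."""
    acc = 0.0
    for start in range(0, len(xs), micro_batch_size):
        stop = start + micro_batch_size
        acc += grad_sum(w, xs[start:stop], ys[start:stop])  # accumulate locally, no comms
    return acc / len(xs)  # the AllReduce would run once, here


xs = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
ys = [2.0, 4.1, 5.9, 8.2, 9.8, 12.1]
g_full = full_batch_grad(0.5, xs, ys)
g_accum = accumulated_grad(0.5, xs, ys, micro_batch_size=2)
```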

When does DP break? DP is communication-bound when the AllReduce time exceeds the compute time. The compute per chip is proportional to B/N (per-device batch size). As N grows, B/N shrinks while the AllReduce volume stays at 2P. Eventually the AllReduce dominates.

Let us derive the critical point. For a single Transformer layer with FFN width F and model dimension D, the forward + backward FLOPs per chip are roughly:

Tcompute = 6 · 2DF · (B/N) / C

where C is the peak FLOPs/s per chip. The AllReduce transfers 2 · 2DF bytes (a ReduceScatter plus an AllGather of that layer's gradients):

Tcomms = 2 · 2DF / W

where W is the inter-chip bandwidth. We become communication-bound when Tcomms > Tcompute:

2 · 2DF / W > 12DF · B / (N · C)
B/N < C / (3W)

On TPU v5p with ICI bandwidth W = 4800 GB/s across 3 axes: C/3W ≈ 4.59e14 / (3 × 4800e9) ≈ 32 tokens per chip. So if your per-device batch drops below ~32 tokens, pure DP becomes communication-bound.

Practical limit: On a TPU v5p pod (8960 chips) with a 4M token global batch, the per-device batch is 4M/8960 ≈ 446 tokens, well above the ~32-token threshold. But on smaller batches or very large clusters, DP alone fails.

Worked example: You have 256 TPU v5p chips and a 1M token batch. Can you use pure DP?

Per-device batch = 1,000,000 / 256 = 3906 tokens
Threshold = C / (3W) = 4.59e14 / (3 × 4800e9) = 32 tokens
3906 >> 32 ⇒ Compute-bound. Pure DP works!

Now try 8960 chips with the same 1M batch:

Per-device batch = 1,000,000 / 8960 = 112 tokens
112 > 32 ⇒ Still compute-bound, but getting closer to the edge.
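The same check as a short script. This is a sketch of the simplified B/N > C/(3W) rule derived above; the function names are illustrative, and the hardware numbers are the ones used in this chapter.

```python
def dp_critical_batch(peak_flops, bandwidth_bytes_per_s):
    """Per-chip batch (tokens) below which pure DP is communication-bound: C / (3W)."""
    return peak_flops / (3 * bandwidth_bytes_per_s)


def dp_is_compute_bound(global_batch, n_chips, peak_flops, bandwidth_bytes_per_s):
    """True when the per-chip batch B/N exceeds the critical batch."""
    return global_batch / n_chips > dp_critical_batch(peak_flops, bandwidth_bytes_per_s)


C_V5P = 4.59e14   # peak FLOPs/s, from the notation table
W_ICI = 4800e9    # bytes/s, the 3-axis figure used in this chapter

threshold = dp_critical_batch(C_V5P, W_ICI)
ok_256 = dp_is_compute_bound(1_000_000, 256, C_V5P, W_ICI)
ok_8960 = dp_is_compute_bound(1_000_000, 8960, C_V5P, W_ICI)
```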

The real limitation of DP is not the communication threshold — it is memory. Every chip must hold the entire model, optimizer state, and gradient buffers. For a 70B model in bf16 with Adam optimizer, that is 16 × 70B = 1.12 TB per chip. No single chip has that much memory. This is why we need FSDP.

Memory breakdown for pure DP (per chip):

| Component | Bytes per param | 7B model | 70B model |
|---|---|---|---|
| Parameters (bf16) | 2 | 14 GB | 140 GB |
| Gradients (bf16) | 2 | 14 GB | 140 GB |
| Adam master weights (fp32) | 4 | 28 GB | 280 GB |
| Adam momentum m (fp32) | 4 | 28 GB | 280 GB |
| Adam variance v (fp32) | 4 | 28 GB | 280 GB |
| Total | 16 | 112 GB | 1,120 GB |

Even a 7B model's 112 GB of state does not fit on a single H100 (80 GB), and that is before activations. A 70B model needs 1.12 TB, the combined HBM of 14 H100s or about 12 TPU v5p chips (95 GB each). This memory wall is the #1 motivation for moving beyond pure DP.

Check: In pure data parallelism, what is the communication volume per chip during AllReduce?

Chapter 2: Fully-Sharded Data Parallelism (FSDP)

Pure DP wastes memory: every chip holds the full model, full optimizer state, and full gradients. FSDP (also known as ZeRO) fixes this by sharding all three across N chips.

| What is sharded | Memory per chip (before) | Memory per chip (after FSDP-N) |
|---|---|---|
| Parameters | 2P bytes (bf16) | 2P/N bytes |
| Gradients | 2P bytes | 2P/N bytes |
| Optimizer state (Adam) | 12P bytes (fp32 copy + m + v) | 12P/N bytes |
| Total | 16P bytes | 16P/N bytes |

Memory savings are huge. A 70B model needs ~1.12 TB for parameters + optimizer in DP. With 256-way FSDP, that drops to ~4.4 GB per chip — easily fitting on a single TPU or GPU.

The price is more communication. In the forward pass, each layer must AllGather its parameters before computing (since each chip only stores 1/N of the weights). In the backward pass, we do a ReduceScatter on the gradients.

Let us count the bytes moved. For one layer with weight size Wlayer:

Forward (AllGather): each chip gathers the full weights, Wlayer bytes per chip.
Backward, gradients (AllGather): the full weights are needed again for gradient computation, Wlayer bytes.
Backward, reduce (ReduceScatter): sum and shard the gradients, Wlayer bytes.

Total per layer: 3 × Wlayer bytes per chip. Compare this to pure DP which does 2 × Wlayer per chip. FSDP costs 50% more communication but saves (N-1)/N of the memory.

Key insight: FSDP communication is 3/2 of DP communication. The extra cost comes from the forward-pass AllGather that DP does not need (since DP already has all parameters locally).

When does FSDP become communication-bound? Following the same analysis as DP but with 50% more comms:

B/N < C / (2W) ≈ 48 tokens per chip (TPU v5p, 3 ICI axes)

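FSDP can overlap communication with computation (gather layer i+1 while computing layer i), so the threshold marks the point where the collectives can no longer be hidden. The 48-token estimate relies on the aggregate 4800 GB/s bandwidth figure; the book's per-axis accounting gives a much stricter bound. With MX ICI axes dedicated to FSDP:

Bper-chip < 2550 / MX ⇒ communication-bound

where 2550 is a hardware-specific constant for TPU v5p, the ratio of peak FLOPs/s to per-axis ICI bandwidth. With 3 axes the threshold is ≈ 850 tokens per chip, and this is the figure the rest of the chapter uses for FSDP.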

How communication overlap works: While chip i computes the forward pass for layer k, the AllGather for layer k+1 runs in the background on the ICI links. If the AllGather finishes before the computation, we pay zero extra latency — the communication is fully hidden. The threshold tells us when the AllGather takes longer than the computation and starts to block.
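A tiny timeline model makes the "hidden vs blocking" distinction concrete. This is a sketch under simplifying assumptions (every layer has the same compute and gather time, only one gather is in flight at a time); the function name and the millisecond values are illustrative.

```python
def step_time(n_layers, t_compute, t_gather, overlap=True):
    """Forward-pass time for n_layers when each layer needs an AllGather first."""
    if not overlap:
        return n_layers * (t_compute + t_gather)
    gather_done = t_gather  # layer 0 weights must arrive before any compute starts
    compute_done = 0.0
    for layer in range(n_layers):
        compute_start = max(compute_done, gather_done)  # wait for weights if needed
        if layer + 1 < n_layers:
            # Prefetch the next layer's weights while this layer computes.
            gather_done = compute_start + t_gather
        compute_done = compute_start + t_compute
    return compute_done


# Compute-bound case: only the first gather is exposed.
hidden = step_time(4, t_compute=1.5, t_gather=0.29)
# Comms-bound case: every layer waits on its gather.
blocked = step_time(4, t_compute=0.17, t_gather=0.29)
```

When t_gather ≤ t_compute the total is one gather plus all the compute; otherwise the gathers set the pace and compute is the small leftover term.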

Worked example — ZeRO memory savings: Consider a 7B model with Adam optimizer:

| Component | 1 GPU (DP) | 8-way FSDP | 64-way FSDP |
|---|---|---|---|
| Parameters (bf16) | 14 GB | 1.75 GB | 0.22 GB |
| Gradients (bf16) | 14 GB | 1.75 GB | 0.22 GB |
| Optimizer (fp32, Adam) | 84 GB | 10.5 GB | 1.31 GB |
| Total model state | 112 GB | 14 GB | 1.75 GB |

With 8-way FSDP, a 7B model's state fits comfortably in a single H100 (80 GB). With 64-way, even a 70B model is manageable at ~17.5 GB per chip. The key trade-off: memory savings scale linearly with the FSDP degree, while the per-layer communication stays roughly constant and each chip has less compute to hide it behind.
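The memory table can be regenerated for any model size and FSDP degree. A minimal sketch, assuming the 16-bytes-per-parameter breakdown used in this chapter (bf16 parameters and gradients, fp32 Adam state); the names are illustrative.

```python
BYTES_PER_PARAM = {
    "params_bf16": 2,
    "grads_bf16": 2,
    "adam_master_fp32": 4,
    "adam_m_fp32": 4,
    "adam_v_fp32": 4,
}


def model_state_gb(n_params, fsdp_degree=1):
    """Per-chip model state in GB for each component, sharded fsdp_degree ways."""
    return {
        name: nbytes * n_params / fsdp_degree / 1e9
        for name, nbytes in BYTES_PER_PARAM.items()
    }


def total_state_gb(n_params, fsdp_degree=1):
    """Total per-chip model state in GB (activations not included)."""
    return sum(model_state_gb(n_params, fsdp_degree).values())


for degree in (1, 8, 64):
    print(f"7B, {degree}-way: {total_state_gb(7e9, degree):.2f} GB per chip")
```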

ZeRO stages mapped to FSDP: ZeRO Stage 1 shards only the optimizer state. ZeRO Stage 2 adds gradient sharding. ZeRO Stage 3 (FSDP) shards everything including parameters. Stages 1 and 2 keep roughly the same communication volume as pure DP; Stage 3 adds the parameter AllGathers (the 50% increase computed above) in exchange for the largest memory savings. FSDP (Stage 3) is the usual choice once the model state no longer fits otherwise.

Check: How much more communication does FSDP require compared to pure DP?

Chapter 3: Tensor Parallelism

Tensor parallelism (TP) splits individual weight matrices across chips. Instead of each chip holding the full weight and processing a fraction of the batch, each chip holds a fraction of the weight and processes the full batch.

Consider a simple matrix multiply Z = XW where X is (B, D) and W is (D, F). We can split W along either axis:

| Strategy | Split | Per-chip compute | Communication |
|---|---|---|---|
| Column parallel | W split along F (columns) | X × Wshard → (B, F/Y) | AllGather output to get (B, F) |
| Row parallel | W split along D (rows) | Xshard × Wshard → (B, F) partial sums | ReduceScatter to sum partial results |

where Y is the TP degree (number of chips). In practice, Transformer FFN blocks use column-then-row: the first matmul (up-projection) uses column parallelism, and the second (down-projection) uses row parallelism. The column-sharded output of the first matmul, (B, F/Y), is exactly the row-sharded input the second one needs, so no AllGather is required in between and each FFN block needs only one reduction of its (B, D) output (an AllReduce, or a ReduceScatter when the output stays sharded).

Communication per layer (forward pass): One ReduceScatter of size 2BD bytes. That is the partial sum from the row-parallel matmul. In the backward pass, we need the transpose operations, adding another ReduceScatter. Total: 2 × 2BD bytes per layer per chip.
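The column-then-row pattern can be verified on tiny matrices. This sketch uses pure-Python lists and a ReLU nonlinearity as stand-ins (both are assumptions for illustration), and simulates the Y chips with a loop; the only cross-chip step is the final sum of partial outputs.

```python
def matmul(a, b):
    """Plain (rows x inner) @ (inner x cols) matrix product on nested lists."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def relu(m):
    return [[max(0, v) for v in row] for row in m]


def ffn(x, w_in, w_out):
    """Unsharded reference: relu(X @ W_in) @ W_out."""
    return matmul(relu(matmul(x, w_in)), w_out)


def tp_ffn(x, w_in, w_out, tp):
    """Column-parallel W_in, row-parallel W_out, one reduction at the end."""
    shard = len(w_out) // tp  # F / Y columns of W_in (rows of W_out) per chip
    partials = []
    for rank in range(tp):
        cols = slice(rank * shard, (rank + 1) * shard)
        w_in_shard = [row[cols] for row in w_in]   # (D, F/Y) column shard
        w_out_shard = w_out[cols]                  # (F/Y, D) row shard
        hidden = relu(matmul(x, w_in_shard))       # (B, F/Y), no communication
        partials.append(matmul(hidden, w_out_shard))  # (B, D) partial sum
    # The only collective: sum the partial (B, D) outputs across chips.
    return [[sum(vals) for vals in zip(*rows)] for rows in zip(*partials)]


x = [[1, -2], [3, 4]]                      # (B=2, D=2)
w_in = [[1, -1, 2, 0], [0, 3, -1, 1]]      # (D=2, F=4)
w_out = [[1, 0], [2, -1], [0, 1], [-1, 2]]  # (F=4, D=2)
same = tp_ffn(x, w_in, w_out, tp=2) == ffn(x, w_in, w_out)
```

The elementwise nonlinearity is what makes the column split convenient: it can be applied to each (B, F/Y) shard independently.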

When does TP become communication-bound? The per-chip FLOPs for one FFN layer are 2BDF/Y (since we have F/Y columns). The communication is 2 × 2BD bytes. We are compute-bound when:

2BDF / (Y · C) > 4BD / W
F / (2Y) > C / W
Y < F · W / (2C)

For TPU v5p with a single ICI axis (W = 1600 GB/s) and F = 28672 (LLaMA 70B):

Y < 28672 × 1.6e12 / (2 × 4.59e14) ≈ 50

So we could do up to ~50-way TP before becoming communication-bound. But this uses one axis. With dedicated ICI axes MY:

Ymax ≈ F · MY / 2200
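The single-axis bound Y < F · W / (2C) derived above is a one-line calculation. A small sketch; the helper names are illustrative, and rounding down to a power of two is a common practical convention assumed here rather than something the derivation requires.

```python
def max_tp_degree(ffn_dim, bandwidth_bytes_per_s, peak_flops):
    """Upper bound on the TP degree from Y < F * W / (2 * C)."""
    return ffn_dim * bandwidth_bytes_per_s / (2 * peak_flops)


def largest_power_of_two_at_most(x):
    """Largest power of two that does not exceed x (x >= 1)."""
    p = 1
    while p * 2 <= x:
        p *= 2
    return p


bound = max_tp_degree(28672, 1.6e12, 4.59e14)   # LLaMA 70B FFN width, one ICI axis
practical = largest_power_of_two_at_most(bound)
```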
Key difference from FSDP: TP communication is proportional to B×D (activations), while FSDP communication is proportional to D×F (parameters). TP comms scale with batch size; FSDP comms do not. This is why TP is preferred within a node and FSDP across nodes.

Let us make this concrete with numbers for LLaMA 70B (D=8192, F=28672) with B=4096 tokens per chip:

TP comms per layer: 4BD bytes = 4 × 4096 × 8192 = 134 MB (proportional to B)
FSDP comms per layer: 3 × 2 × 8192 × 28672 = 1.41 GB (fixed, independent of B)

TP comms are about 10x smaller than FSDP comms at this batch size. If we doubled the batch to 8192 tokens/chip, TP comms would double to 268 MB while FSDP stays at 1.41 GB. Only at very large per-chip batches (above ~43K tokens, where 4BD = 6DF) would TP comms exceed FSDP comms. In practice, per-chip batches at scale are usually 100-4000 tokens, so TP comms are almost always the smaller term.

Attention parallelism: TP also applies to the attention mechanism. Multi-head attention is embarrassingly parallel: each attention head is independent. With Y-way TP, each chip computes H/Y of the H attention heads. The Q, K, V projections are column-parallel; the output projection is row-parallel. Same pattern as the FFN.

Worked example: LLaMA 70B with 64 attention heads and 4-way TP. Each chip computes 16 heads. The per-chip Q projection: (B, D) × (D, 16×128) = (B, 2048). No communication is needed for the Q/K/V computation; the ReduceScatter happens only at the output projection, so attention follows the same one-reduction-per-block pattern as the FFN.

Sequence parallelism (within TP): In standard TP, the LayerNorm and dropout operations are duplicated across all TP ranks (since they operate on the full (B, D) activation). Sequence parallelism extends TP by sharding these operations too: each chip holds B/Y of the tokens for LayerNorm/dropout, doing an AllGather before and a ReduceScatter after the TP region. The AllGather + ReduceScatter pair takes the place of the AllReduce and moves the same number of bytes, so this saves activation memory at essentially no extra bandwidth cost.

Check: In tensor parallelism, what determines the communication volume per layer?

Chapter 4: Combining FSDP + TP

Neither FSDP nor TP is ideal alone at scale. FSDP becomes communication-bound when the per-device batch is too small. TP becomes communication-bound when the TP degree exceeds F · MY / 2200. The solution is to combine them.

The standard recipe: use TP within a node (where bandwidth is highest) and FSDP across nodes. With Y-way TP and X-way FSDP on N = X × Y total chips:

| Dimension | Degree | What it splits | Communication type |
|---|---|---|---|
| TP | Y | Weight matrices (columns/rows) | ReduceScatter activations, intra-node |
| FSDP | X | Full parameters + optimizer state | AllGather params, ReduceScatter grads, inter-node |

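The combined system is compute-bound when both TP and FSDP are individually compute-bound. For TP: Y < F · MY / 2200. For FSDP: the batch per FSDP shard, B/X, must exceed 2550/MX. The Y chips of a TP group work on the same tokens, so the FSDP collectives are amortized over B/X tokens rather than B/(X·Y).

Balancing the FSDP and TP communication costs gives the optimal FSDP degree: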

Xopt = √(2BN / F)

where B is the global batch size in tokens and N is the total chip count. Rounding to the nearest power of 2 gives the practical value.

Worked example — LLaMA 70B on 8960 TPU v5p chips:
B = 4M tokens (4,000,000), N = 8960, F = 28672.
Xopt = √(2 × 4e6 × 8960 / 28672) = √(2.5e6) ≈ 1581.
The nearest power of 2 is 2048, but 8960 is not divisible by 2048, so use 2240-way FSDP with 4-way TP (2240 × 4 = 8960).
Per-chip batch = 4M / 8960 ≈ 446 tokens. With one axis given to TP, FSDP has 2 axes and its threshold is 2550/2 = 1275. With 4-way TP: per-FSDP-group batch = 446 × 4 ≈ 1786 > 1275. Compute-bound!

What if we tried pure FSDP on all 8960 chips? Per-chip batch = 4M/8960 ≈ 446. Threshold with 3 axes = 850. Since 446 < 850, we are communication-bound. The combination of FSDP + TP rescues us.
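The worked example can be reproduced with two small helpers. This is a sketch of the formulas quoted in this chapter (Xopt = √(2BN/F) and the 2550/MX FSDP bound); the function names are illustrative, and B is taken as exactly 4,000,000 tokens as in the notation table.

```python
import math


def optimal_fsdp_degree(global_batch, n_chips, ffn_dim):
    """X_opt = sqrt(2 * B * N / F), before rounding to a usable mesh size."""
    return math.sqrt(2 * global_batch * n_chips / ffn_dim)


def fsdp_compute_bound(batch_per_fsdp_shard, fsdp_axes):
    """True when the batch per FSDP shard exceeds 2550 / M_X (TPU v5p constant)."""
    return batch_per_fsdp_shard > 2550 / fsdp_axes


B, N, F = 4_000_000, 8960, 28672
x_opt = optimal_fsdp_degree(B, N, F)
implied_tp = N / x_opt  # lands between 4 and 8, so a 4-way or 8-way TP mesh is plausible

pure_fsdp_ok = fsdp_compute_bound(B / N, fsdp_axes=3)       # all chips in one FSDP group
fsdp_plus_tp_ok = fsdp_compute_bound(B / 2240, fsdp_axes=2)  # 2240-way FSDP x 4-way TP
```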

Why the formula Xopt = √(2BN/F) works: It balances the two communication costs. With Y-way TP, each FSDP rank gathers only 1/Y of every weight matrix, so FSDP traffic per layer scales as DF/Y. With X-way FSDP, each TP group handles only 1/X of the batch, so TP traffic per layer scales as BD/X. Following the book's accounting (4DF/Y and 4BD/X bytes per layer, with MX ICI axes carrying FSDP traffic and MY axes carrying TP traffic, each at per-axis bandwidth W), the derivation is:

FSDP comms time: 4DF / (Y · W · MX)
TP comms time: 4BD / (X · W · MY)
Total T(X) = 4DF · X / (N · W · MX) + 4BD / (X · W · MY), using Y = N/X
dT/dX = 4DF / (N · W · MX) − 4BD / (X² · W · MY) = 0
X² = (B · N / F) · (MX / MY)

With MX = 2 axes for FSDP and MY = 1 axis for TP, this gives Xopt = √(2BN/F). The first term grows with X and the second shrinks with X, so the total is minimized where the two are equal. Counting three weight collectives per layer (6DF/Y, as in Chapter 2) instead of two changes the constant under the square root but not the form of the result. In every case, both strategies must still sit inside their individual compute-bound thresholds.

Practical consideration — ICI axis assignment: On a 3D torus topology (like TPU v5p), the three physical axes have different bandwidths. The convention:

| ICI Axis | Typical use | Why |
|---|---|---|
| Fastest axis | Tensor Parallelism | TP runs every layer, needs lowest latency |
| Medium axes | FSDP | AllGather can be pipelined and overlapped |
| Slowest axis | Data Parallelism | Only communicates once per step (AllReduce) |

Check: Why do we typically place TP within a node and FSDP across nodes?

Chapter 5: Pipeline Parallelism

Pipeline parallelism (PP) assigns different layers to different chips (or groups of chips). Chip 0 runs layers 1-10, chip 1 runs layers 11-20, and so on. Activations are sent forward between stages; gradients flow backward.

The naive approach has a devastating problem: the pipeline bubble. While chip 0 is computing the forward pass, all other chips sit idle. Then while chip 1 computes, chip 0 is idle. The utilization is only 1/S where S is the number of pipeline stages.

Microbatching to the rescue. We split the global batch into M microbatches and pipeline them through the stages. While microbatch 1 is on stage 2, microbatch 2 can start on stage 1. This fills the pipeline.

The 1F1B (one forward, one backward) schedule is the standard approach:

Warmup phase: Stage 0 runs the forward passes for microbatches 1, 2, ..., S to fill the pipeline.
Steady state: Each stage alternates 1 forward pass and 1 backward pass. The pipeline stays full.
Cooldown phase: The final S backward passes drain through the pipeline.

The bubble fraction measures wasted time:

Bubble fraction = (S - 1) / (M + S - 1)

where S = number of stages, M = number of microbatches. To keep the bubble at or below 5%, we need M ≥ 19(S-1). For S = 8 stages, that requires M ≥ 133 microbatches.

| Stages (S) | Microbatches for ≤5% bubble | Microbatches for ≤10% bubble |
|---|---|---|
| 4 | 57 | 27 |
| 8 | 133 | 63 |
| 16 | 285 | 135 |
| 32 | 589 | 279 |

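The table follows directly from the bubble formula. A minimal sketch with illustrative function names; the target is passed as an integer percentage so the ceiling is computed in exact integer arithmetic.

```python
def bubble_fraction(stages, microbatches):
    """Idle fraction of a 1F1B (or naive) pipeline: (S - 1) / (M + S - 1)."""
    return (stages - 1) / (microbatches + stages - 1)


def min_microbatches(stages, max_bubble_percent):
    """Smallest M with bubble <= max_bubble_percent.

    (S-1)/(M+S-1) <= p/100  <=>  M >= (S-1) * (100 - p) / p
    """
    numerator = (stages - 1) * (100 - max_bubble_percent)
    return -(-numerator // max_bubble_percent)  # integer ceiling division


for s in (4, 8, 16, 32):
    print(s, min_microbatches(s, 5), min_microbatches(s, 10))
```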
Interleaved schedules assign non-contiguous layers to each stage (e.g., stage 0 gets layers 1-5 and 21-25). This effectively multiplies M by the number of "virtual stages" per device, reducing the bubble at the cost of more point-to-point communication.

PP communication is point-to-point: each stage sends activation tensors of size B × D to the next stage. This is much lighter than the AllReduce/AllGather of DP/FSDP/TP, making PP ideal for cross-node communication where bandwidth is limited.

Memory advantage of 1F1B: In the naive "all forward then all backward" schedule, stage 0 must store activations for all M microbatches simultaneously (since the backward pass has not started yet). In 1F1B, stage 0 only needs to buffer at most S microbatches' activations, since backward passes start as soon as the pipeline fills. This reduces peak activation memory from O(M) to O(S).
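The O(M) versus O(S) claim can be checked by listing the operations a stage executes and counting how many microbatches have run forward but not yet backward. This sketch assumes the common non-interleaved 1F1B ordering in which 0-indexed stage s runs S - 1 - s warmup forwards before alternating; function names are illustrative.

```python
def one_f_one_b(stage, n_stages, n_microbatches):
    """Operation order for one stage under non-interleaved 1F1B."""
    warmup = min(n_stages - 1 - stage, n_microbatches)
    ops = [("F", m) for m in range(warmup)]          # warmup: forwards only
    next_f, next_b = warmup, 0
    for _ in range(n_microbatches - warmup):         # steady state: 1 forward, 1 backward
        ops.append(("F", next_f))
        next_f += 1
        ops.append(("B", next_b))
        next_b += 1
    while next_b < n_microbatches:                   # cooldown: remaining backwards
        ops.append(("B", next_b))
        next_b += 1
    return ops


def naive_schedule(n_microbatches):
    """All forwards, then all backwards."""
    return [("F", m) for m in range(n_microbatches)] + [("B", m) for m in range(n_microbatches)]


def peak_in_flight(ops):
    """Max number of microbatches whose activations are held at once."""
    live = peak = 0
    for kind, _ in ops:
        live += 1 if kind == "F" else -1
        peak = max(peak, live)
    return peak


S, M = 4, 8
peak_1f1b_stage0 = peak_in_flight(one_f_one_b(0, S, M))
peak_naive = peak_in_flight(naive_schedule(M))
```

For stage 0 the 1F1B peak equals S, while the naive schedule peaks at M, matching the O(S) vs O(M) comparison above.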

Worked example — LLaMA 70B with 8-way PP:

Layers per stage = 80 / 8 = 10 layers
Activation transfer per microbatch = 2 bytes × D × microbatch_tokens
= 2 × 8192 × 4096 = 67 MB (for 4K-token microbatch)
At 50 GB/s DCN: transfer time = 67 MB / 50 GB/s = 1.3 ms
Compute per stage per microbatch ≈ 6 × 2 × 8192 × 28672 × 4096 × 10 / 4.59e14 ≈ 250 ms
Compute >> transfer ⇒ Communication is fully hidden!
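The transfer-versus-compute comparison can be scripted so other model sizes are easy to try. A sketch using the same simplified single-chip estimate as the worked example (6 × 2DF FLOPs per token per layer, bf16 activations); the function names are illustrative.

```python
def activation_transfer_s(d_model, microbatch_tokens, link_bytes_per_s, bytes_per_elem=2):
    """Seconds to send one microbatch's (tokens x D) activations to the next stage."""
    return bytes_per_elem * d_model * microbatch_tokens / link_bytes_per_s


def stage_compute_s(d_model, ffn_dim, microbatch_tokens, layers_per_stage, peak_flops):
    """Seconds of forward+backward FFN compute for one microbatch on one stage."""
    flops = 6 * 2 * d_model * ffn_dim * microbatch_tokens * layers_per_stage
    return flops / peak_flops


transfer = activation_transfer_s(8192, 4096, 50e9)             # 50 GB/s DCN
compute = stage_compute_s(8192, 28672, 4096, 10, 4.59e14)      # 10 layers per stage
hidden = transfer < compute  # True means the send can overlap with compute
print(f"transfer {transfer * 1e3:.2f} ms, compute {compute * 1e3:.0f} ms, hidden={hidden}")
```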

The bubble fraction of 1F1B is (S-1)/(M+S-1). But newer schedules like zero-bubble PP (Qi et al., 2023) interleave forward and backward passes more carefully to reduce the bubble to near zero. The key idea: run the backward pass for the weight gradient separately from the backward pass for the activation gradient, allowing more flexible scheduling.

Comparing PP schedules:

| Schedule | Bubble fraction | Memory (peak activations) | Communication |
|---|---|---|---|
| Naive (all-F then all-B) | (S-1)/(M+S-1) | O(M) microbatches | Minimal |
| 1F1B | (S-1)/(M+S-1) | O(S) microbatches | Same as naive |
| Interleaved 1F1B | (S-1)/(v×M+S-1) (v = virtual stages) | O(S) per virtual stage | v× more point-to-point |
| Zero-bubble | ≈ 0 | O(S) | Same as 1F1B |

For most practical cases, 1F1B with enough microbatches (M ≥ 4S) gives an acceptable bubble of <20%. The interleaved schedule is popular in Megatron-LM, where it halves the bubble at the cost of 2x more inter-stage communication (since activations cross the stage boundary twice as often with 2 virtual stages per device).

Choosing the number of microbatches: The global batch B is split into M microbatches of size B/M tokens each. Larger M reduces the bubble but also reduces per-microbatch compute (smaller matmuls, potentially less efficient). There is also a memory trade-off: each pipeline stage must buffer up to S microbatches' worth of activations. The sweet spot is M = 4S to 8S for most configurations.

Check: With 8 pipeline stages and 63 microbatches, what is the bubble fraction?

Chapter 6: Scaling Across Pods

A single TPU v5p pod has up to 8960 chips connected by high-speed ICI. But what if we need more? Connecting multiple pods requires DCN (Data Center Network), which is roughly 10-100x slower than ICI.

| Interconnect | Bandwidth (per chip) | Latency |
|---|---|---|
| ICI (intra-pod) | ~4800 GB/s (3 axes) | ~1 μs |
| DCN (inter-pod) | ~50 GB/s | ~10 μs |

The strategy for multi-pod training follows a strict hierarchy:

Within a node (4-8 chips): Tensor Parallelism, which needs the highest bandwidth.
Within a pod (up to 8960 chips): FSDP, with AllGather/ReduceScatter over ICI.
Across pods: Pipeline Parallelism, which is point-to-point only and has the lowest bandwidth need.

Why PP across pods? PP only requires point-to-point transfers of activation tensors (size B × D). An AllReduce over DCN would be catastrophically slow. PP's communication pattern matches the low-bandwidth, higher-latency inter-pod link perfectly.

For a 4-pod setup training LLaMA 70B:

4-way PP across pods × ~2240-way FSDP within each pod × 4-way TP within each node

The PP bubble with M microbatches and S = 4 stages: bubble = 3/(M+3). With M = 60, that is 3/63 ≈ 4.8%. Acceptable!

What happens at the pod boundary? Each pod computes its assigned pipeline stages using FSDP+TP internally. At the boundary between pods, stage Sk sends its output activation tensor (shape: microbatch_size × D, in bf16) to the first stage of the next pod via DCN.

Bandwidth requirement: Each microbatch transfers 2 × D × microbatch_tokens bytes. For D=8192 and microbatch = 4K tokens: 2 × 8192 × 4096 = 67 MB. At 50 GB/s DCN, this takes ~1.3 ms. Compared to the ~250 ms compute per stage per microbatch, the transfer is comfortably overlapped.

However, if microbatches are very large or D is very large (as in 405B with D=16384), the transfer time grows. For 405B with 16K microbatch: 2 × 16384 × 16384 = 537 MB, taking ~11 ms over DCN. The same single-chip estimate gives about 370 ms of compute per layer for D=16384 and F=53248, so the transfer stays hidden, though it is no longer negligible.

Fault tolerance at scale: Multi-pod training requires robust checkpointing. If any chip in any pod fails, the entire training run must restart from the last checkpoint. Published reports from training runs at 16K-accelerator scale (for example Meta's Llama 3 run) describe interruptions roughly every few hours. Asynchronous checkpointing (writing checkpoints in the background) is essential to avoid losing hours of compute per failure.

Practical tip — measuring communication overhead: On TPUs, you can use the XLA profiler to see exactly how much time is spent in communication collectives vs computation. Look for "all-gather" and "reduce-scatter" operations in the trace. If communication takes more than 15-20% of total step time, you likely need to adjust your parallelism configuration (add TP, reduce FSDP degree, or increase batch size).

On NVIDIA GPUs, setting NCCL_DEBUG=INFO logs NCCL's topology and transport choices, which helps confirm that collectives use the expected links. For timings, the PyTorch profiler (or Nsight Systems) breaks down step time between compute kernels and NCCL collectives. The key metric: communication fraction = time in collectives / total step time. Target: <15%.

Context / Sequence Parallelism (CP): A fifth parallelism dimension not covered in detail here, but worth mentioning. CP splits long sequences across chips: each chip processes a different chunk of the sequence. The key challenge is attention — each token needs to attend to tokens on other chips. Ring Attention solves this by passing KV chunks in a ring, overlapping communication with computation. CP is essential for very long contexts (>32K tokens) where activation memory per sequence exceeds a single chip's HBM.

Expert Parallelism (EP): For Mixture-of-Experts (MoE) models like Mixtral or DeepSeek-V3, a further dimension distributes different experts to different chips. The communication pattern is all-to-all: each chip routes its tokens to the expert chips, and results are routed back. This is unique because the communication is data-dependent (which tokens go to which experts depends on the router).

The full "5D parallelism" picture: Frontier-scale training can combine all five dimensions (counting DP and FSDP as one):

| Dimension | Typical degree | What it splits | Communication |
|---|---|---|---|
| TP | 4-8 | Weight matrices (intra-node) | ReduceScatter per layer |
| FSDP/DP | 64-4096 | Batch + params/optimizer | AllGather + ReduceScatter per layer |
| PP | 4-16 | Layer groups (inter-pod) | Point-to-point activations |
| CP | 2-8 | Sequence length | Ring Attention for KV |
| EP | 8-64 | MoE experts | All-to-all token routing |

The total parallelism is the product: TP × FSDP × PP × CP × EP = total chips. For DeepSeek-V3 on a hypothetical 8-pod setup: 8 TP × 256 FSDP × 8 PP × 2 CP × 1 EP (experts are within FSDP groups) = 32,768 chips. Each dimension is carefully chosen to match the hardware topology and communication characteristics.

Check: Why is pipeline parallelism preferred over FSDP for cross-pod communication?

Chapter 7: The Parallelism Decision Tree

With four strategies available, how do you choose? Here is a practical decision tree derived from the analysis above:

Step 1: Does the model fit on one chip?
  YES → Use pure DP. It has the least communication overhead.
  NO ↓
Step 2: Does it fit with FSDP on one node?
  YES → Use FSDP within the node. Check per-chip batch > 2550/MX.
  NO, or per-chip batch too small ↓
Step 3: Add Tensor Parallelism within the node.
  TP degree Y ≤ F · MY / 2200. Typical: Y = 4 or 8.
  Need more chips than one node? ↓
Step 4: FSDP across nodes within the pod.
  Use Xopt = √(2BN/F) for optimal FSDP degree.
  Need more than one pod? ↓
Step 5: Pipeline parallelism across pods.
  S = number of pods. Need M ≥ 19(S-1) microbatches for ≤5% bubble.

The golden rule: Use the parallelism strategy with the lowest communication cost at each level of the hardware hierarchy. TP within node (highest BW), FSDP within pod (medium BW), PP across pods (lowest BW needed).
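The decision tree can be sketched as a function. This is an illustration under several assumptions that are not in the text: model state is 16 bytes per parameter, TP degrees are powers of two, TP takes one ICI axis when enabled (leaving ici_axes - 1 for FSDP), and the per-chip batch is B/N over all chips. The function name and return format are hypothetical.

```python
def choose_parallelism(n_params, global_batch, n_chips, ffn_dim, hbm_bytes,
                       n_pods=1, ici_axes=3):
    """Walk the decision tree and return a dict of parallelism degrees."""
    chips_per_pod = n_chips // n_pods
    per_chip_batch = global_batch / n_chips
    plan = {"dp": 1, "fsdp": 1, "tp": 1, "pp": n_pods}  # Step 5: one PP stage per pod

    # Step 1: if the whole model state fits on one chip, pure DP is simplest.
    if n_pods == 1 and 16 * n_params < hbm_bytes:
        plan["dp"] = n_chips
        return plan

    # Step 2: pure FSDP within the pod if the per-chip batch clears 2550 / M_X.
    if per_chip_batch > 2550 / ici_axes:
        plan["fsdp"] = chips_per_pod
        return plan

    # Steps 3-4: add TP until the per-FSDP-shard batch clears the (2-axis) bound.
    tp_limit = ffn_dim * 1 / 2200          # Y <= F * M_Y / 2200 with M_Y = 1
    fsdp_threshold = 2550 / (ici_axes - 1)
    tp = 2
    while tp <= tp_limit:
        if chips_per_pod % tp == 0 and per_chip_batch * tp > fsdp_threshold:
            plan["tp"], plan["fsdp"] = tp, chips_per_pod // tp
            return plan
        tp *= 2
    raise ValueError("no compute-bound FSDP+TP split found; increase the batch size")


# LLaMA 70B on one v5p pod, assuming 95 GB of HBM per chip.
plan_70b = choose_parallelism(70e9, 4_000_000, 8960, 28672, hbm_bytes=95e9)
```

The sketch stops at the first TP degree that satisfies both bounds, which favors less TP; a real configuration search would also weigh memory for activations and the mesh shapes the hardware supports.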

Worked example: a 30B model on 512 chips across 2 pods of 256:

Step 1: 30B model, D=6144, F=16384. Batch = 2M tokens.
Step 2: Try pure FSDP on 512 chips. Per-chip = 2M/512 = 3906 tokens.
FSDP threshold (3 axes) = 850. 3906 > 850 ⇒ Pure FSDP works within a pod!
Step 3: But we span 2 pods. FSDP across pods would use slow DCN.
Step 4: Use 2-way PP across pods + 256-way FSDP within each pod.
PP bubble (M=64, S=2) = 1/65 = 1.5%. Negligible!
Per-pod: 30B/2 = 15B params per stage. FSDP = 256-way. Per-chip = 3906 tok. Works!

| Strategy | Comms per layer | Comms type | Best placement |
|---|---|---|---|
| DP | 2 × Wlayer | AllReduce (grads) | Anywhere (if batch large enough) |
| FSDP | 3 × Wlayer | AllGather + ReduceScatter | Within pod (ICI) |
| TP | 4BD bytes | ReduceScatter (activations) | Within node (highest BW) |
| PP | B × D bytes (point-to-point) | Send/Recv | Across pods (DCN) |

Check: You have a 70B model on 4 pods. What parallelism goes across pods?

Chapter 8: Parallelism Visualizer

Use this interactive tool to explore how different parallelism configurations affect communication overhead and memory usage. Adjust the model size, batch size, chip count, and parallelism degrees to see what is compute-bound vs communication-bound.

[Interactive widget: Parallelism Strategy Explorer. Sliders adjust the configuration to show the compute vs comms balance for each strategy; the default slider values shown are 2048, 4M, and 4.]

Check: For LLaMA 70B on a full v5p pod (8960 chips, 4M token batch), what combination keeps us compute-bound?

Chapter 9: Takeaways

Let us crystallize the key results from this chapter.

| Strategy | Memory savings | Comms cost | When to use |
|---|---|---|---|
| DP | None | 2P (AllReduce grads) | Model fits on one chip, large batch |
| FSDP | ~N× (params + optimizer) | 3P (AllGather fwd + AllGather bwd + ReduceScatter) | Model too big for one chip, moderate batch |
| TP | Y× (per-layer weights) | 4BD per layer (activation-proportional) | Large FFN, high intra-node BW |
| PP | S× (layer groups) | BD point-to-point per stage boundary | Cross-pod, many microbatches |

The hierarchy: TP within the node, FSDP within the pod, PP across pods. This mapping follows the hardware bandwidth hierarchy: intra-node > intra-pod ICI > inter-pod DCN.

Why this ordering makes physical sense: TP communicates every layer, both forward and backward — it is the most latency-sensitive and needs the highest bandwidth. FSDP communicates once per layer (AllGather/ReduceScatter of parameters) but can be pipelined — it needs good bandwidth but tolerates more latency. PP communicates once per stage boundary (activations only) with tiny volumes — it is the least communication-intensive.

The communication volume hierarchy (per training step for one layer):

| Strategy | Volume per chip | Example (D=8192, F=28672, B=4096) |
|---|---|---|
| TP (4-way) | 4BD bytes = 4 × 4096 × 8192 | 134 MB, every layer |
| FSDP (1024-way) | 3 × 2DF/TP = 3 × 2 × 8192 × 28672 / 4 | 352 MB, every layer |
| PP (4-way) | 2 × 2BD bytes = 2 × 2 × 4096 × 8192 (activations forward, gradients backward) | 134 MB, once per 20 layers |

TP and FSDP volumes are comparable per layer, but TP runs on faster links (intra-node). PP's per-stage volume is similar in magnitude but occurs far less frequently, making it the best fit for slow DCN links.

Critical formulas to remember:

FSDP bound: Bper-chip > 2550 / MX
TP bound: Y < F · MY / 2200
Optimal FSDP degree: Xopt = √(2BN / F)
PP bubble: (S-1) / (M + S - 1)

Worked problem: You have a 13B model (D=5120, F=13824, L=40) and want to train on 512 TPU v5p chips with a 2M token batch. What parallelism config should you use?

Step 1: Per-chip batch = 2M / 512 = 3906 tokens
Step 2: FSDP threshold (3 axes) = 850 tokens. 3906 > 850 ⇒ FSDP alone works!
Step 3: Verify TP not needed. Xopt = √(2 × 2M × 512 / 13824) ≈ 385.
The implied TP degree is N/Xopt = 512/385 ≈ 1.3, which rounds to 1. Use 512-way FSDP, no TP needed.
Memory per chip: 16 × 13B / 512 = 0.41 GB. Easily fits.

Same model, 4096 chips:

Per-chip batch = 2M / 4096 = 488. Below 850 ⇒ Need TP!
Xopt = √(2 × 2M × 4096 / 13824) = 1089. So ~1024-way FSDP, 4-way TP.

Worked problem 2: You want to train a 405B model (D=16384, F=53248, L=126) on 8 pods (71680 chips) with 8M token batch.

Step 1: Per-chip batch = 8M / 71680 = 112 tokens. Below 850 ⇒ Need TP.
Step 2: Xopt = √(2 × 8M × 71680 / 53248) ≈ 4641. Round to 4096-way FSDP.
Step 3: Within each pod: 71680/8 = 8960 chips. 8960/(4096/8) = 8960/512 = 17.5-way TP. Too much.
Better: 8-way PP across pods × 1120-way FSDP within pod × 8-way TP within node.
Check: 8 × 1120 × 8 = 71680 ✓. Per FSDP group batch = 112 × 8 = 896 > 850 ✓
TP check: Y=8 < F×2/2200 = 53248×2/2200 = 48 ✓
PP bubble: (8-1)/(M+7). Need M ≥ 133 for ≤5%. With 8M batch / 60K microbatch ≈ 133. Works!
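The checks in this worked problem can be bundled into one helper. A sketch only: the function name is hypothetical, the defaults mirror the numbers used in this problem (850-token FSDP bound from 3 axes, MY = 2 in the TP bound), and the per-chip batch is B/N over all chips as in Step 1.

```python
def check_config(global_batch, n_chips, ffn_dim, pp, fsdp, tp, microbatches,
                 fsdp_axes=3, tp_axes=2):
    """Return the individual pass/fail checks for a PP x FSDP x TP configuration."""
    per_chip = global_batch / n_chips
    return {
        "chips_match": pp * fsdp * tp == n_chips,
        "fsdp_ok": per_chip * tp > 2550 / fsdp_axes,   # per-FSDP-group batch vs bound
        "tp_ok": tp < ffn_dim * tp_axes / 2200,        # Y < F * M_Y / 2200
        "bubble": (pp - 1) / (microbatches + pp - 1),  # (S - 1) / (M + S - 1)
    }


report = check_config(8_000_000, 71680, 53248, pp=8, fsdp=1120, tp=8, microbatches=133)
```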

Worked problem 3: A 3B model on 64 H100 GPUs (80GB each) with 500K token batch.

Per-chip batch = 500K / 64 = 7812 tokens.
H100 threshold ≈ C/(3W) = 9.9e14 / (3 × 900e9) = 367.
7812 >> 367 ⇒ Pure DP or FSDP works easily.
Memory: 16 × 3B = 48 GB per chip in DP. Fits in 80GB with room for activations.
Can even use pure DP (no sharding) since model state fits per chip. Simplest possible setup.

Summary of the decision procedure:

1. Compute per-chip batch = B/N. If > threshold (2550/MX), FSDP alone may work.
2. Check memory: 16P/N < HBM? If model state fits per chip, proceed. Otherwise need more chips or TP.
3. Compute Xopt = √(2BN/F). Y = N/Xopt. If Y ≤ F×MY/2200, valid config found.
4. Multi-pod? Add PP. One PP stage per pod. Need M ≥ 4(S-1) for ≤20% bubble.

In the next chapter, we apply all of these to a concrete model: training LLaMA 3 on TPUs.

Check: Which parallelism strategy's communication cost scales with batch size (not parameter count)?