Data parallelism, FSDP, tensor parallelism, pipeline parallelism — when each strategy wins, when it breaks, and how to combine them.
You have a model that trains perfectly on one chip. Now you want to make it ten times bigger and train on a thousand chips. Will you get a thousand-fold speedup? Almost certainly not — but the gap between "linear speedup" and "actual speedup" is exactly what this chapter is about.
The goal of "model scaling" is deceptively simple: double the chips, halve the training time. In other words, we want strong scaling — a proportional, linear increase in throughput as we add more accelerators.
On a single chip, performance depends on the trade-off between memory bandwidth and FLOPs. At the cluster level, a new bottleneck emerges: inter-chip communication. Every time we add more chips, we increase the communication load while reducing the per-device computation available to hide it.
Think of it like a restaurant kitchen. One chef can cook a meal in an hour. Ten chefs should cook ten meals in an hour, but only if they do not spend half their time coordinating. The "parallelism strategies" in this chapter are the recipes for organizing the kitchen so the chefs stay busy cooking instead of talking.
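We will study four common parallelism schemes:
- Data parallelism (DP)
- Fully sharded data parallelism (FSDP)
- Tensor parallelism (TP)
- Pipeline parallelism (PP)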
For each scheme, we will derive: (1) the memory savings, (2) the communication volume, and (3) the critical batch size below which we become communication-bound.
Before we dive in, let us establish some notation. For a Transformer model:
| Symbol | Meaning | Example (LLaMA 70B) |
|---|---|---|
| D | Model dimension (hidden size) | 8,192 |
| F | FFN intermediate dimension | 28,672 |
| L | Number of layers | 80 |
| B | Global batch size (tokens) | 4,000,000 |
| N | Number of chips | 8,960 |
| C | Peak FLOPs/s per chip | 4.59e14 (v5p) |
| W | Inter-chip bandwidth | 4800 GB/s (ICI, 3 axes) |
The key insight of this chapter is that every parallelism strategy can be analyzed through the lens of a single ratio: computation time vs communication time. When the compute per chip exceeds the communication per chip, we are in the happy "compute-bound" regime. When communication dominates, we are wasting expensive hardware on waiting.
Let us build intuition with a concrete example. Imagine training a 70B model with a 4M token batch on increasing numbers of TPU v5p chips:
| Chips | Per-chip batch | Compute time/layer (ms) | Comms time/layer (ms) | Status |
|---|---|---|---|---|
| 256 | 15,625 tok | 5.9 | 0.29 | Compute-bound (20x margin) |
| 1024 | 3,906 tok | 1.5 | 0.29 | Compute-bound (5x margin) |
| 4096 | 977 tok | 0.37 | 0.29 | Borderline! |
| 8960 | 446 tok | 0.17 | 0.29 | Comms-bound (1.7x over) |
See how the compute time per layer shrinks as we add chips (each chip does less work), but the communication stays constant (the AllGather always moves the same volume of data)? The crossover point is where we need to get smarter about our parallelism strategy.
Data parallelism is the simplest distributed strategy. Every chip holds a full copy of the model. The global batch is split across N chips so each chip processes B/N tokens. After the backward pass, gradients are summed via an AllReduce.
| Property | Value |
|---|---|
| Model memory | Full copy on every chip (no savings) |
| Data split | Each chip sees B/N of the global batch |
| Communication | AllReduce of gradients once per step |
| Communication volume | 2P bytes per chip (ring AllReduce) |
The AllReduce has two phases: ReduceScatter (each chip sends (N-1)/N of its gradients) and AllGather (each chip receives the summed result). In a ring topology, each chip sends and receives exactly 2P · (N-1)/N ≈ 2P bytes total.
Let us be concrete about what these operations do:
| Operation | Input | Output | Bytes per chip |
|---|---|---|---|
| AllGather | Each chip has 1/N of the data | Every chip has all the data | (N-1)/N × total_size |
| ReduceScatter | Each chip has full data (different values) | Each chip has 1/N of the sum | (N-1)/N × total_size |
| AllReduce | Each chip has full data | Every chip has the sum | = ReduceScatter + AllGather |
In a ring AllReduce with N chips, the data is split into N chunks. Each chip sends one chunk to its neighbor and receives one chunk, repeating N-1 times for the ReduceScatter phase and N-1 times for the AllGather phase. Total per-chip transfer: 2 × (N-1)/N × data_size ≈ 2 × data_size for large N.
Why the ring AllReduce is bandwidth-optimal: Any algorithm that sums N copies of a vector of size S must transfer at least 2S(N-1)/N bytes per chip (information theory lower bound). The ring AllReduce achieves this exactly, making it the most bandwidth-efficient collective for large messages. For small messages, latency (not bandwidth) dominates, and tree-based reductions are preferred.
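The ring algorithm is easy to check in a few lines. The sketch below simulates it in plain Python, with lists standing in for chip buffers (`ring_allreduce` is an illustrative name, not a library API). It runs the ReduceScatter and AllGather phases as described above and counts how many elements each chip sends.

```python
def ring_allreduce(buffers):
    """Simulate a ring AllReduce over N "chips".

    buffers: list of N equal-length lists (one gradient vector per chip);
    the length must be divisible by N. Returns (reduced_buffers, elements_sent),
    where elements_sent[i] counts the elements chip i sent to its neighbor.
    """
    n = len(buffers)
    size = len(buffers[0])
    if size % n:
        raise ValueError("vector length must be divisible by the number of chips")
    chunk = size // n
    bufs = [list(b) for b in buffers]  # each chip's working copy
    sent = [0] * n

    def span(c):
        return slice(c * chunk, (c + 1) * chunk)

    # Phase 1, ReduceScatter: in step s, chip i sends chunk (i - s) mod n to
    # chip i + 1, which adds it into its own copy of that chunk.
    for s in range(n - 1):
        msgs = [(i, (i - s) % n, bufs[i][span((i - s) % n)]) for i in range(n)]
        for i, c, payload in msgs:
            dst = (i + 1) % n
            bufs[dst][span(c)] = [a + b for a, b in zip(bufs[dst][span(c)], payload)]
            sent[i] += chunk
    # Now chip i holds the fully reduced chunk (i + 1) mod n.

    # Phase 2, AllGather: in step s, chip i forwards chunk (i + 1 - s) mod n,
    # and the receiver overwrites its copy with the reduced values.
    for s in range(n - 1):
        msgs = [(i, (i + 1 - s) % n, bufs[i][span((i + 1 - s) % n)]) for i in range(n)]
        for i, c, payload in msgs:
            bufs[(i + 1) % n][span(c)] = payload
            sent[i] += chunk
    return bufs, sent


chips = [[float(i + j) for j in range(8)] for i in range(4)]  # 4 chips, 8 values each
reduced, sent = ring_allreduce(chips)
print(reduced[0])  # elementwise sum over the 4 chips
print(sent)        # 2 * (N - 1) / N * 8 = 12 elements per chip
```

Every chip ends with the same summed vector, and each sends exactly 2(N-1)/N of the vector length, matching the per-chip volume in the table.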
Gradient accumulation — the DP memory trick: If the per-chip batch B/N is too large to fit in memory (due to activations), we can split it further into micro-steps. Compute gradients for a smaller micro-batch, accumulate them in memory, and only run the AllReduce once all micro-batches are done. This uses the same communication volume but allows us to process large effective batch sizes with limited memory. The trade-off: more sequential compute per step, but the same communication cost.
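A minimal sketch of gradient accumulation, assuming a toy one-parameter least-squares model (`grad_mse` and `accumulated_grad` are illustrative names). Summing the gradients of equal-size micro-batches and dividing by their count reproduces the full-batch gradient, which is why the AllReduce only needs to run once, on the accumulated value.

```python
def grad_mse(w, batch):
    """d/dw of mean((w * x - y)^2) over a batch of (x, y) pairs."""
    return sum(2 * (w * x - y) * x for x, y in batch) / len(batch)


def accumulated_grad(w, batch, num_micro):
    """Process the batch in equal-size micro-batches, accumulating gradients."""
    if len(batch) % num_micro:
        raise ValueError("batch must split into equal micro-batches")
    size = len(batch) // num_micro
    acc = 0.0  # the only gradient buffer kept across micro-steps
    for k in range(num_micro):
        acc += grad_mse(w, batch[k * size:(k + 1) * size])
    # Equal-size micro-batches: the mean of micro-batch means is the full mean.
    # In data parallelism, the single AllReduce would happen here, on `acc`.
    return acc / num_micro


data = [(x, 3.0 * x + 1.0) for x in range(1, 9)]  # 8 examples
print(grad_mse(0.5, data), accumulated_grad(0.5, data, num_micro=4))
```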
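Let us derive the critical point. For a single Transformer layer with FFN width F and model dimension D (two D × F matrices, 2DF parameters), the forward + backward pass costs roughly 6 FLOPs per parameter per token, so the compute time per chip is:

$$T_{\text{compute}} \approx \frac{6 \cdot 2DF \cdot B/N}{C} = \frac{12\,BDF}{N\,C}$$

where C is the peak FLOPs/s per chip. The AllReduce transfers 2 · 2DF bytes (a ReduceScatter and an AllGather of one layer's parameters):

$$T_{\text{comms}} \approx \frac{2 \cdot 2DF}{W}$$

where W is the inter-chip bandwidth. We become communication-bound when Tcomms > Tcompute:

$$\frac{4DF}{W} > \frac{12\,BDF}{N\,C} \quad\Longleftrightarrow\quad \frac{B}{N} < \frac{C}{3W}$$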
On TPU v5p with ICI bandwidth W = 4800 GB/s across 3 axes: C/3W ≈ 4.59e14 / (3 × 4800e9) ≈ 32 tokens per chip. So if your per-device batch drops below ~32 tokens, pure DP becomes communication-bound.
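The threshold can be reproduced numerically. This sketch plugs the chapter's simplified cost model (2DF parameters per layer, 6 FLOPs per parameter per token, 2 · 2DF bytes of AllReduce traffic) and the C and W values used in this chapter into two hypothetical helpers. D and F cancel in the comparison, so only the per-chip batch matters.

```python
PEAK_FLOPS = 4.59e14  # C: peak bf16 FLOPs/s per chip (TPU v5p, notation table)
BANDWIDTH = 4800e9    # W: inter-chip bandwidth in bytes/s, as used in this chapter


def dp_layer_times(tokens_per_chip, D, F, c=PEAK_FLOPS, w=BANDWIDTH):
    """(compute seconds, comms seconds) for one layer under pure DP."""
    t_compute = 6 * 2 * D * F * tokens_per_chip / c  # 6 FLOPs/param/token, 2DF params
    t_comms = 2 * 2 * D * F / w                       # ReduceScatter + AllGather
    return t_compute, t_comms


def dp_critical_tokens(c=PEAK_FLOPS, w=BANDWIDTH):
    """Per-chip batch below which pure DP is communication-bound: C / 3W."""
    return c / (3 * w)


print(dp_critical_tokens())  # ~31.9 tokens per chip
for tokens in (16, 32, 64):
    tc, tm = dp_layer_times(tokens, D=8192, F=28672)
    print(tokens, "compute-bound" if tc > tm else "comms-bound")
```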
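Worked example: You have 256 TPU v5p chips and a 1M token batch. Can you use pure DP? The per-chip batch is 1M/256 ≈ 3,906 tokens, far above the ~32-token threshold, so communication is not the bottleneck.
Now try 8960 chips with the same 1M batch: the per-chip batch falls to 1M/8960 ≈ 112 tokens, which is still above ~32, so by this measure pure DP remains compute-bound.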
The real limitation of DP is not the communication threshold — it is memory. Every chip must hold the entire model, optimizer state, and gradient buffers. For a 70B model in bf16 with Adam optimizer, that is 16 × 70B = 1.12 TB per chip. No single chip has that much memory. This is why we need FSDP.
Memory breakdown for pure DP (per chip):
| Component | Bytes per param | 7B model | 70B model |
|---|---|---|---|
| Parameters (bf16) | 2 | 14 GB | 140 GB |
| Gradients (bf16) | 2 | 14 GB | 140 GB |
| Adam master weights (fp32) | 4 | 28 GB | 280 GB |
| Adam momentum m (fp32) | 4 | 28 GB | 280 GB |
| Adam variance v (fp32) | 4 | 28 GB | 280 GB |
| Total | 16 | 112 GB | 1,120 GB |
A 7B model does not fit on a single H100 (80 GB) even before activations: its 112 GB of training state exceeds the chip's memory. A 70B model needs 1.12 TB per chip, the memory of roughly 14 H100s or about 12 TPU v5p chips. This memory wall is the #1 motivation for moving beyond pure DP.
Pure DP wastes memory: every chip holds the full model, full optimizer state, and full gradients. FSDP (also known as ZeRO) fixes this by sharding all three across N chips.
| What is sharded | Memory per chip (before) | Memory per chip (after FSDP-N) |
|---|---|---|
| Parameters | 2P bytes (bf16) | 2P/N bytes |
| Gradients | 2P bytes | 2P/N bytes |
| Optimizer state (Adam) | 12P bytes (fp32 copy + m + v) | 12P/N bytes |
| Total | 16P bytes | 16P/N bytes |
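The price is more communication. In the forward pass, each layer must AllGather its parameters before computing (since each chip only stores 1/N of the weights). In the backward pass, we AllGather the parameters again and ReduceScatter the gradients.
Let us count the bytes moved. For one layer with weight size Wlayer:
- Forward pass: AllGather the parameters (≈ Wlayer bytes per chip).
- Backward pass: AllGather the parameters again, since they were freed after the forward pass (≈ Wlayer).
- Backward pass: ReduceScatter the gradients (≈ Wlayer).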
Total per layer: 3 × Wlayer bytes per chip. Compare this to pure DP which does 2 × Wlayer per chip. FSDP costs 50% more communication but saves (N-1)/N of the memory.
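When does FSDP become communication-bound? Following the same analysis as DP but with 50% more comms (3 · 2DF bytes instead of 2 · 2DF):

$$\frac{3 \cdot 2DF}{W} > \frac{12\,BDF}{N\,C} \quad\Longleftrightarrow\quad \frac{B}{N} < \frac{C}{2W}$$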
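Since FSDP can overlap communication with computation (gather layer i+1 while computing layer i), this simple bound, which plugs in the aggregate bandwidth W, is not the full story. The book's more careful analysis works with the bandwidth of each ICI axis and gives the following condition for staying compute-bound with MX ICI axes dedicated to FSDP:

$$\frac{B}{N} > \frac{2550}{M_X}$$

where 2550 is a hardware-specific constant for TPU v5p: the ratio of peak FLOPs/s to the bandwidth of a single ICI axis. With 3 axes, the threshold is ≈ 850 tokens per chip.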
How communication overlap works: While chip i computes the forward pass for layer k, the AllGather for layer k+1 runs in the background on the ICI links. If the AllGather finishes before the computation, we pay zero extra latency — the communication is fully hidden. The threshold tells us when the AllGather takes longer than the computation and starts to block.
Worked example — ZeRO memory savings: Consider a 7B model with Adam optimizer:
| Component | 1 GPU (DP) | 8-way FSDP | 64-way FSDP |
|---|---|---|---|
| Parameters (bf16) | 14 GB | 1.75 GB | 0.22 GB |
| Gradients (bf16) | 14 GB | 1.75 GB | 0.22 GB |
| Optimizer (fp32, Adam) | 84 GB | 10.5 GB | 1.31 GB |
| Total model state | 112 GB | 14 GB | 1.75 GB |
With 8-way FSDP, a 7B model's state fits comfortably in a single H100 (80 GB). With 64-way, even a 70B model is manageable at ~17.5 GB per chip. The key trade-off: more FSDP degree means more communication, but memory savings scale linearly.
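Both memory tables in this chapter follow from one bytes-per-parameter count. A small sketch (the helper name `state_gb_per_chip` is illustrative), assuming mixed-precision Adam with 16 bytes of state per parameter and ignoring activations and temporary buffers:

```python
# Bytes of training state per parameter (mixed-precision Adam).
BYTES_PER_PARAM = {
    "params (bf16)": 2,
    "grads (bf16)": 2,
    "optimizer (fp32 master + m + v)": 12,
}


def state_gb_per_chip(num_params, fsdp_degree=1):
    """Per-chip model state in GB (1e9 bytes) when sharded fsdp_degree ways."""
    return {name: b * num_params / fsdp_degree / 1e9
            for name, b in BYTES_PER_PARAM.items()}


for degree in (1, 8, 64):
    parts = state_gb_per_chip(7e9, degree)
    print(f"7B, {degree:>2}-way: {sum(parts.values()):7.2f} GB total")
print(f"70B, 64-way: {sum(state_gb_per_chip(70e9, 64).values()):.1f} GB")
```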
Tensor parallelism (TP) splits individual weight matrices across chips. Instead of each chip holding the full weight and processing a fraction of the batch, each chip holds a fraction of the weight and processes the full batch.
Consider a simple matrix multiply Z = XW where X is (B, D) and W is (D, F). We can split W along either axis:
| Strategy | Split | Per-chip compute | Communication |
|---|---|---|---|
| Column parallel | W split along F (columns) | X × Wshard = (B, F/Y) | AllGather output to get (B, F) |
| Row parallel | W split along D (rows) | Xshard × Wshard = (B, F) | ReduceScatter to sum partial results |
where Y is the TP degree (number of chips). In practice, Transformer FFN blocks use column-then-row: the first matmul (up-projection) is column-parallel and the second (down-projection) is row-parallel. The column-parallel output (B, F/Y) is already sharded along F, which is exactly the input sharding the row-parallel matmul needs, so no AllGather is required between them. The only communication is one AllReduce (or a ReduceScatter, if the output stays sharded) of the partial (B, D) results per FFN block.
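To see why column-then-row needs only one collective, here is a toy simulation in plain Python with small integer matrices (all names are illustrative). Each simulated chip applies its column shard of the up-projection, a ReLU, and its row shard of the down-projection with no communication in between; summing the partial outputs, which is what the AllReduce does, recovers the unsharded result. The ReLU stands in for the FFN nonlinearity; any elementwise activation behaves the same way.

```python
def matmul(a, b):
    """Naive (rows x k) @ (k x cols) product on nested lists."""
    return [[sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]


def relu(m):
    return [[max(0, v) for v in row] for row in m]


def ffn(x, w1, w2):
    """Unsharded reference: relu(x @ w1) @ w2."""
    return matmul(relu(matmul(x, w1)), w2)


def ffn_tensor_parallel(x, w1, w2, tp):
    """Column-parallel w1, row-parallel w2, one simulated AllReduce at the end."""
    f = len(w1[0])
    if f % tp:
        raise ValueError("F must be divisible by the TP degree")
    s = f // tp
    partials = []
    for r in range(tp):  # one iteration per simulated chip
        w1_cols = [row[r * s:(r + 1) * s] for row in w1]  # (D, F/Y) column shard
        w2_rows = w2[r * s:(r + 1) * s]                   # (F/Y, D) row shard
        h = relu(matmul(x, w1_cols))                      # (B, F/Y): purely local
        partials.append(matmul(h, w2_rows))               # (B, D) partial sum
    # Simulated AllReduce: elementwise sum of the partial outputs.
    return [[sum(p[i][j] for p in partials) for j in range(len(partials[0][0]))]
            for i in range(len(partials[0]))]


x = [[1, -2, 3], [0, 1, -1]]                                      # B=2, D=3
w1 = [[(3 * i + j) % 5 - 2 for j in range(4)] for i in range(3)]  # D=3, F=4
w2 = [[(2 * i + j) % 7 - 3 for j in range(3)] for i in range(4)]  # F=4, D=3
print(ffn(x, w1, w2) == ffn_tensor_parallel(x, w1, w2, tp=2))     # True
```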
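When does TP become communication-bound? The per-chip FLOPs for one FFN layer are 2BDF/Y (since each chip holds F/Y columns). The communication is 2 × 2BD bytes. We are compute-bound when:

$$\frac{2BDF}{Y\,C} > \frac{2 \cdot 2BD}{W} \quad\Longleftrightarrow\quad Y < \frac{F\,W}{2C}$$

For TPU v5p with a single ICI axis (W = 1600 GB/s) and F = 28672 (LLaMA 70B):

$$Y < \frac{28672 \times 1.6\times10^{12}}{2 \times 4.59\times10^{14}} \approx 50$$

So we could do up to ~50-way TP before becoming communication-bound. But this uses one axis. With MY dedicated ICI axes, the bandwidth available to TP scales by MY:

$$Y < \frac{M_Y\,F\,W}{2C}$$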
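Let us make this concrete with numbers for LLaMA 70B (D=8192, F=28672) with B=4096 tokens per chip: TP moves 4BD values per layer, 4 × 4096 × 8192 × 2 bytes ≈ 268 MB, while FSDP without TP moves 3 × 2DF = 3 × 2 × 8192 × 28672 bytes ≈ 1.41 GB.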
TP comms are about 5x smaller than FSDP comms at this batch size. But if we doubled the batch to 8192 tokens/chip, TP comms would double to 536 MB while FSDP stays at 1.41 GB. Because TP comms grow linearly with the per-chip batch, they overtake FSDP comms at about 21K tokens per chip (where 8BD bytes = 6DF bytes, i.e. B = 0.75F). In practice, per-chip batches at scale are usually 100-4000 tokens, so TP comms are almost always the smaller term.
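Following the chapter's byte accounting (4BD values in bf16 for TP, 3 × 2DF bytes for FSDP without TP), a short sketch confirms the ~5x ratio and locates the crossover; the function names are illustrative.

```python
def tp_bytes_per_layer(tokens_per_chip, D, dtype_bytes=2):
    """TP traffic per layer: 4BD activation values (chapter's count)."""
    return 4 * tokens_per_chip * D * dtype_bytes


def fsdp_bytes_per_layer(D, F, tp_degree=1):
    """FSDP traffic per layer: 3 x 2DF bytes, split across the TP degree."""
    return 3 * 2 * D * F // tp_degree


def tp_fsdp_crossover_tokens(F, dtype_bytes=2):
    """Per-chip batch where 4 * B * D * dtype_bytes == 6 * D * F (D cancels)."""
    return 6 * F / (4 * dtype_bytes)


D_MODEL, D_FF = 8192, 28672  # LLaMA 70B
for b in (4096, 8192):
    tp_mb = tp_bytes_per_layer(b, D_MODEL) / 1e6
    fsdp_gb = fsdp_bytes_per_layer(D_MODEL, D_FF) / 1e9
    print(f"{b} tokens/chip: TP {tp_mb:.0f} MB vs FSDP {fsdp_gb:.2f} GB")
print(tp_fsdp_crossover_tokens(D_FF))  # 21504.0
```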
Attention parallelism: TP also applies to the attention mechanism. Multi-head attention is embarrassingly parallel: each attention head is independent. With Y-way TP, each chip computes H/Y attention heads, where H is the number of heads. The Q, K, V projections are column-parallel; the output projection is row-parallel. Same pattern as the FFN.
Worked example: LLaMA 70B with 64 attention heads and 4-way TP. Each chip computes 16 heads. The per-chip Q projection: (B, D) × (D, 16×128) = (B, 2048). No communication is needed for the Q/K/V projections or the attention computation itself; the only collective is the ReduceScatter at the output projection, the same pattern as the FFN's down-projection. So attention TP adds no communication beyond that one collective per block.
Sequence parallelism (within TP): In standard TP, the LayerNorm and dropout operations are duplicated across all TP ranks (since they operate on the full (B, D) activation). Sequence parallelism extends TP by sharding these operations too: each chip holds B/Y of the sequence for LayerNorm/dropout, doing an AllGather before the TP region and a ReduceScatter after it. Since an AllReduce is exactly a ReduceScatter followed by an AllGather, this pair replaces the TP region's AllReduce, so the activation-memory savings come without extra communication volume.
Neither FSDP nor TP is ideal alone at scale. FSDP becomes communication-bound when per-device batch is too small. TP becomes communication-bound when the TP degree exceeds F/(2200/MY). The solution is to combine them.
The standard recipe: use TP within a node (where bandwidth is highest) and FSDP across nodes. With Y-way TP and X-way FSDP on N = X × Y total chips:
| Dimension | Degree | What it splits | Communication type |
|---|---|---|---|
| TP | Y | Weight matrices (columns/rows) | ReduceScatter activations — intra-node |
| FSDP | X | Full parameters + optimizer state | AllGather params, ReduceScatter grads — inter-node |
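The combined system is compute-bound when both TP and FSDP are individually compute-bound. For TP: Y < F · MY / 2200. For FSDP: B/X > 2550/MX, where B/X is the batch each FSDP shard processes (all Y chips in a TP group work on the same tokens, so at fixed N, adding TP shrinks X and raises B/X).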
Balancing the FSDP and TP communication costs gives the optimal FSDP degree:

$$X_{\text{opt}} = \sqrt{\frac{2BN}{F}}$$
where B is the global batch size in tokens and N is the total chip count. Rounding to the nearest power of 2 gives the practical value.
What if we tried pure FSDP on all 8960 chips? Per-chip batch = 4M/8960 ≈ 446. Threshold with 3 axes = 850. Since 446 < 850, we are communication-bound. The combination of FSDP + TP rescues us.
Why the formula Xopt = √(2BN/F) works: per layer, each FSDP rank gathers its 1/Y slice of the weights over MX axes, so the FSDP communication time scales as DF/(Y · MX). The TP communication scales with the activations each TP group handles, BD/(X · MY). Assuming both terms carry the same constant factor, and with XY = N, minimizing their sum (equivalently, setting them equal) gives:

$$\frac{DF}{Y\,M_X} = \frac{BD}{X\,M_Y}, \qquad Y = \frac{N}{X} \quad\Longrightarrow\quad X^2 = \frac{B\,N\,M_X}{F\,M_Y}$$

With two ICI axes for FSDP and one for TP (MX = 2, MY = 1), this becomes Xopt = √(2BN/F).
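A sketch of the calculation, with MX = 2 and MY = 1 as defaults so that it reduces to Xopt = √(2BN/F). Rounding to a power of two follows the text; the function names are illustrative.

```python
import math


def optimal_fsdp_degree(global_batch, num_chips, F, mx=2, my=1):
    """X_opt = sqrt(B * N * M_X / (F * M_Y)); equals sqrt(2BN/F) for M_X=2, M_Y=1."""
    return math.sqrt(global_batch * num_chips * mx / (F * my))


def nearest_power_of_two(x):
    """Round to the nearest power of two in log space."""
    return 2 ** round(math.log2(x))


x_opt = optimal_fsdp_degree(4_000_000, 8960, 28672)  # LLaMA 70B, 4M tokens, v5p pod
print(round(x_opt), nearest_power_of_two(x_opt))     # 1581 2048
```

Because 8960 is not a power of two, N/2048 is not an integer. In practice you would pick an X near Xopt that divides N and leaves a TP degree compatible with the head count and F.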
Practical consideration — ICI axis assignment: On a 3D torus topology (like TPU v5p), the three physical axes have different bandwidths. The convention:
| ICI Axis | Typical use | Why |
|---|---|---|
| Fastest axis | Tensor Parallelism | TP runs every layer, needs lowest latency |
| Medium axes | FSDP | AllGather can be pipelined and overlapped |
| Slowest axis | Data Parallelism | Only communicates once per step (AllReduce) |
Pipeline parallelism (PP) assigns different layers to different chips (or groups of chips). Chip 0 runs layers 1-10, chip 1 runs layers 11-20, and so on. Activations are sent forward between stages; gradients flow backward.
The naive approach has a devastating problem: the pipeline bubble. While chip 0 is computing the forward pass, all other chips sit idle. Then while chip 1 computes, chip 0 is idle. The utilization is only 1/S where S is the number of pipeline stages.
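The 1F1B (one forward, one backward) schedule is the standard approach:
- Split the global batch into M microbatches.
- Warm-up: each stage runs forward passes until the pipeline is full (the first stage runs S-1 forwards before its first backward; the last stage runs none).
- Steady state: each stage alternates one forward pass and one backward pass.
- Cool-down: once all forwards are done, the stages drain the remaining backward passes.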
The bubble fraction measures wasted time:

$$\text{bubble} = \frac{S-1}{M+S-1}$$
where S = number of stages, M = number of microbatches. To keep the bubble below 5%, we need M ≥ 19(S-1). For S = 8 stages, that requires M ≥ 133 microbatches.
| Stages (S) | Microbatches for <5% bubble | Microbatches for <10% bubble |
|---|---|---|
| 4 | 57 | 27 |
| 8 | 133 | 63 |
| 16 | 285 | 135 |
| 32 | 589 | 279 |
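Both the bubble formula and the table can be reproduced with exact integer arithmetic (taking the bubble limit as a whole percentage avoids floating-point rounding at the boundary). The names are illustrative.

```python
def bubble_fraction(stages, microbatches):
    """1F1B pipeline bubble: (S - 1) / (M + S - 1)."""
    return (stages - 1) / (microbatches + stages - 1)


def min_microbatches(stages, max_bubble_pct):
    """Smallest M with bubble <= max_bubble_pct percent.

    (S-1)/(M+S-1) <= p/100  <=>  M >= (S-1)(100-p)/p, rounded up.
    """
    num = (stages - 1) * (100 - max_bubble_pct)
    return -(-num // max_bubble_pct)  # ceiling division on integers


for s in (4, 8, 16, 32):
    print(s, min_microbatches(s, 5), min_microbatches(s, 10))
print(f"{bubble_fraction(4, 60):.1%}")  # 4.8%
```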
PP communication is point-to-point: each stage sends activation tensors of size B × D to the next stage. This is much lighter than the AllReduce/AllGather of DP/FSDP/TP, making PP ideal for cross-node communication where bandwidth is limited.
Memory advantage of 1F1B: In the naive "all forward then all backward" schedule, stage 0 must store activations for all M microbatches simultaneously (since the backward pass has not started yet). In 1F1B, stage 0 only needs to buffer at most S microbatches' activations, since backward passes start as soon as the pipeline fills. This reduces peak activation memory from O(M) to O(S).
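The O(S) versus O(M) difference can be checked by generating each stage's operation order and counting how many microbatches have run forward but not yet backward. This is a simplified sketch of the non-interleaved 1F1B order (warm-up forwards, then one forward and one backward alternating, then a drain of backwards). It models only the order of operations on each stage, not cross-stage timing, and the function names are illustrative.

```python
def one_f_one_b(stage, num_stages, num_micro):
    """Operation order on one stage under non-interleaved 1F1B.

    Returns a list like [("F", 0), ("F", 1), ("B", 0), ...]. Stage 0 is first.
    """
    warmup = min(num_stages - stage - 1, num_micro)
    ops = [("F", m) for m in range(warmup)]
    next_f, next_b = warmup, 0
    while next_f < num_micro:  # steady state: one forward, then one backward
        ops.append(("F", next_f))
        next_f += 1
        ops.append(("B", next_b))
        next_b += 1
    while next_b < num_micro:  # cool-down: drain the outstanding backwards
        ops.append(("B", next_b))
        next_b += 1
    return ops


def all_forward_then_backward(num_micro):
    """Naive schedule: every forward before any backward."""
    return [("F", m) for m in range(num_micro)] + [("B", m) for m in range(num_micro)]


def peak_buffered(ops):
    """Max number of microbatches whose activations are held at once."""
    live = peak = 0
    for kind, _ in ops:
        live += 1 if kind == "F" else -1
        peak = max(peak, live)
    return peak


S, M = 4, 16
print([peak_buffered(one_f_one_b(s, S, M)) for s in range(S)])  # [4, 3, 2, 1]
print(peak_buffered(all_forward_then_backward(M)))              # 16
```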
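Worked example, LLaMA 70B with 8-way PP: each stage holds 80/8 = 10 layers. From the table above, 1F1B needs M ≥ 133 microbatches to keep the bubble below 5%, or M ≥ 63 to keep it below 10%.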
The bubble fraction of 1F1B is (S-1)/(M+S-1). But newer schedules like zero-bubble PP (Qi et al., 2023) interleave forward and backward passes more carefully to reduce the bubble to near zero. The key idea: run the backward pass for the weight gradient separately from the backward pass for the activation gradient, allowing more flexible scheduling.
Comparing PP schedules:
| Schedule | Bubble fraction | Memory (peak activations) | Communication |
|---|---|---|---|
| Naive (all-F then all-B) | (S-1)/(M+S-1) | O(M) microbatches | Minimal |
| 1F1B | (S-1)/(M+S-1) | O(S) microbatches | Same as naive |
| Interleaved 1F1B | (S-1)/(v×M+S-1) (v = virtual stages) | O(S) per virtual stage | v× more point-to-point |
| Zero-bubble | ≈ 0 | O(S) | Same as 1F1B |
For most practical cases, 1F1B with enough microbatches (M ≥ 4S) gives an acceptable bubble of <20%. The interleaved schedule is popular in Megatron-LM, where it halves the bubble at the cost of 2x more inter-stage communication (since activations cross the stage boundary twice as often with 2 virtual stages per device).
Choosing the number of microbatches: The global batch B is split into M microbatches of size B/M tokens each. Larger M reduces the bubble but also reduces per-microbatch compute (smaller matmuls, potentially less efficient). There is also a memory trade-off: each pipeline stage must buffer up to S microbatches' worth of activations. The sweet spot is M = 4S to 8S for most configurations.
A single TPU v5p pod has up to 8960 chips connected by high-speed ICI. But what if we need more? Connecting multiple pods requires DCN (Data Center Network), which is roughly 10-100x slower than ICI.
| Interconnect | Bandwidth (per chip) | Latency |
|---|---|---|
| ICI (intra-pod) | ~4800 GB/s (3 axes) | ~1 μs |
| DCN (inter-pod) | ~50 GB/s | ~10 μs |
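The strategy for multi-pod training follows a strict hierarchy, placing the most communication-hungry dimensions on the fastest links:
- TP on the fastest links, within a node or along the fastest ICI axis.
- FSDP across the rest of the pod over ICI.
- PP (or plain DP) across pods over DCN, since it only exchanges activations at stage boundaries (or gradients once per step).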
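For a 4-pod setup training LLaMA 70B:
- Each pod holds one pipeline stage, so S = 4 and each stage has 80/4 = 20 layers.
- Within each pod, the stage runs FSDP + TP over ICI.
- Between pods, activations (forward) and their gradients (backward) cross DCN at the stage boundaries.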
The PP bubble with M microbatches and S = 4 stages: bubble = 3/(M+3). With M = 60, that is 3/63 ≈ 4.8%. Acceptable!
What happens at the pod boundary? Each pod computes its assigned pipeline stages using FSDP+TP internally. At the boundary between pods, stage Sk sends its output activation tensor (shape: microbatch_size × D, in bf16) to the first stage of the next pod via DCN.
Bandwidth requirement: Each microbatch transfers 2 × D × microbatch_tokens bytes. For D=8192 and microbatch = 4K tokens: 2 × 8192 × 4096 = 67 MB. At 50 GB/s DCN, this takes ~1.3 ms. Compared to the ~25 ms compute per stage per microbatch, the transfer is comfortably overlapped.
However, if microbatches are very large or D is very large (as in 405B with D=16384), the transfer time grows. For 405B with 16K microbatch: 2 × 16384 × 16384 = 537 MB, taking ~11 ms over DCN. Still overlapped with the ~200 ms compute time, but not as comfortably.
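The transfer-time estimates above are simple arithmetic. A sketch assuming bf16 activations and the ~50 GB/s DCN figure from the table (the helper name is illustrative):

```python
def stage_transfer(D, microbatch_tokens, dcn_gb_per_s=50.0, dtype_bytes=2):
    """Bytes and milliseconds to send one (tokens x D) activation over DCN."""
    nbytes = dtype_bytes * D * microbatch_tokens
    return nbytes, nbytes / (dcn_gb_per_s * 1e9) * 1e3


for name, d, tokens in (("70B", 8192, 4096), ("405B", 16384, 16384)):
    nbytes, ms = stage_transfer(d, tokens)
    print(f"{name}: {nbytes / 1e6:.0f} MB, {ms:.1f} ms")
```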
Fault tolerance at scale: Multi-pod training requires robust checkpointing. If any chip in any pod fails, the entire training run must restart from the last checkpoint. Google reports that at 16K chip scale, failures occur roughly every few hours. Asynchronous checkpointing (writing checkpoints in the background) is essential to avoid losing hours of compute per failure.
Practical tip — measuring communication overhead: On TPUs, you can use the XLA profiler to see exactly how much time is spent in communication collectives vs computation. Look for "all-gather" and "reduce-scatter" operations in the trace. If communication takes more than 15-20% of total step time, you likely need to adjust your parallelism configuration (add TP, reduce FSDP degree, or increase batch size).
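On NVIDIA GPUs, setting NCCL_DEBUG=INFO makes NCCL log its setup (topology, ring and tree construction, transports), which helps confirm that collectives use the expected links, but it does not report per-collective timings. For timing, the PyTorch profiler breaks down time between compute kernels and NCCL communication kernels. The key metric: communication fraction = time in collectives / total step time. Target: <15%.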
Context / Sequence Parallelism (CP): A fifth parallelism dimension not covered in detail here, but worth mentioning. CP splits long sequences across chips: each chip processes a different chunk of the sequence. The key challenge is attention — each token needs to attend to tokens on other chips. Ring Attention solves this by passing KV chunks in a ring, overlapping communication with computation. CP is essential for very long contexts (>32K tokens) where activation memory per sequence exceeds a single chip's HBM.
Expert Parallelism (EP): For Mixture-of-Experts (MoE) models like Mixtral or DeepSeek-V3, a sixth dimension distributes different experts to different chips. The communication pattern is all-to-all: each chip routes its tokens to the expert chips, and results are routed back. This is unique because the communication is data-dependent (which tokens go to which experts depends on the router).
The full "5D parallelism" picture: Modern frontier models combine all five dimensions:
| Dimension | Typical degree | What it splits | Communication |
|---|---|---|---|
| TP | 4-8 | Weight matrices (intra-node) | ReduceScatter per layer |
| FSDP/DP | 64-4096 | Batch + params/optimizer | AllGather + ReduceScatter per layer |
| PP | 4-16 | Layer groups (inter-pod) | Point-to-point activations |
| CP | 2-8 | Sequence length | Ring Attention for KV |
| EP | 8-64 | MoE experts | All-to-all token routing |
The total parallelism is the product: TP × FSDP × PP × CP × EP = total chips. For DeepSeek-V3 on a hypothetical 8-pod setup: 8 TP × 256 FSDP × 8 PP × 2 CP × 1 EP (experts are within FSDP groups) = 32,768 chips. Each dimension is carefully chosen to match the hardware topology and communication characteristics.
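With four strategies available, how do you choose? Here is a practical decision tree derived from the analysis above:
1. If the full model state (about 16 bytes per parameter with Adam) plus activations fits on one chip and the per-chip batch is comfortably above the DP threshold, use pure DP.
2. Otherwise, shard with FSDP, and stay with pure FSDP while the per-chip batch exceeds the FSDP threshold (about 850 tokens on TPU v5p with 3 axes).
3. If the per-chip batch falls below that threshold, add TP on the fastest links, choosing the FSDP degree near √(2BN/F) and keeping the TP degree below its communication limit.
4. If the job spans more than one pod (or ICI domain), add PP (or DP) across pods over DCN, with enough microbatches (M ≥ 4S) to keep the bubble small.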
Worked example — 30B model on 512 GPUs across 2 nodes of 256:
| Strategy | Comms per layer | Comms type | Best placement |
|---|---|---|---|
| DP | 2 × Wlayer | AllReduce (grads) | Anywhere (if batch large enough) |
| FSDP | 3 × Wlayer | AllGather + ReduceScatter | Within pod (ICI) |
| TP | 4BD bytes | ReduceScatter (activations) | Within node (highest BW) |
| PP | B × D bytes (point-to-point) | Send/Recv | Across pods (DCN) |
Let us crystallize the key results from this chapter.
| Strategy | Memory savings | Comms cost | When to use |
|---|---|---|---|
| DP | None | 2P (AllReduce grads) | Model fits on one chip, large batch |
| FSDP | ~N× (params + optimizer) | 3P (AllGather fwd + AllGather bwd + ReduceScatter) | Model too big for one chip, moderate batch |
| TP | Y× (per-layer weights) | 4BD per layer (activation-proportional) | Large FFN, high intra-node BW |
| PP | S× (layer groups) | BD point-to-point per stage boundary | Cross-pod, many microbatches |
Why this ordering makes physical sense: TP communicates in every layer, in both the forward and backward passes, and sits on the critical path, so it is the most latency-sensitive and needs the highest bandwidth. FSDP also communicates in every layer (AllGather/ReduceScatter of parameters) but can be prefetched and overlapped with compute, so it needs good bandwidth but tolerates more latency. PP communicates only at stage boundaries (activations only), so it is the least communication-intensive.
The communication volume hierarchy (per training step for one layer):
| Strategy | Volume per chip | Example (D=8192, F=28672, B=4096) |
|---|---|---|
| TP (4-way) | 4BD × 2 bytes = 4 × 4096 × 8192 × 2 ≈ 268 MB | 268 MB, every layer |
| FSDP (1024-way, with 4-way TP) | 3 × 2DF/Y = 3 × 2 × 8192 × 28672 / 4 ≈ 352 MB | 352 MB, every layer |
| PP (4-way) | 2BD × 2 bytes = 2 × 4096 × 8192 × 2 ≈ 134 MB | 134 MB, once per 20 layers |
TP and FSDP volumes are comparable per layer, but TP runs on faster links (intra-node). PP's per-stage volume is similar in magnitude but occurs far less frequently, making it the best fit for slow DCN links.
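Critical formulas to remember:
- Pure DP is communication-bound when B/N < C/(3W).
- Pure FSDP stays compute-bound when B/N > 2550/MX (about 850 tokens per chip with 3 axes on TPU v5p).
- TP stays compute-bound when Y < F · MY / 2200.
- Optimal FSDP degree with TP: Xopt = √(2BN/F).
- 1F1B pipeline bubble: (S-1)/(M+S-1).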
Worked problem: You have a 13B model (D=5120, F=13824, L=40) and want to train on 512 TPU v5p chips with a 2M token batch. What parallelism config should you use?
Same model, 4096 chips:
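One way to check answers to problems like these is to chain the chapter's rules of thumb into a tiny planner. The sketch below is a hypothetical encoding, not the chapter's own procedure: the name `plan_fsdp_tp` and the axis assignment are assumptions. It tries pure FSDP over 3 ICI axes if B/N > 2550/3; otherwise it puts FSDP on 2 axes and TP on 1 axis with X rounded from √(2BN/F), requiring B/X > 2550/2 for FSDP and Y < F/2200 for TP. It also assumes a power-of-two chip count and does not check memory.

```python
import math

ALPHA = 2550     # FSDP constant for TPU v5p, as given in the text
TP_CONST = 2200  # TP constant as given in the text


def plan_fsdp_tp(global_batch, num_chips, F):
    """Suggest (fsdp_degree, tp_degree) for TPU v5p, or None if both rules fail.

    Assumes num_chips is a power of two.
    """
    if num_chips & (num_chips - 1):
        raise ValueError("sketch assumes a power-of-two chip count")
    # Rule 1: pure FSDP over 3 ICI axes if the per-chip batch is large enough.
    if global_batch / num_chips > ALPHA / 3:
        return num_chips, 1
    # Rule 2: FSDP on 2 axes, TP on 1 axis, X near sqrt(2BN/F) as a power of two.
    x = 2 ** round(math.log2(math.sqrt(2 * global_batch * num_chips / F)))
    x = max(1, min(x, num_chips))
    y = num_chips // x
    fsdp_ok = global_batch / x > ALPHA / 2  # batch per FSDP shard
    tp_ok = y < F / TP_CONST                # one ICI axis for TP
    return (x, y) if fsdp_ok and tp_ok else None


# 13B model (F = 13824) with a 2M-token batch:
print(plan_fsdp_tp(2_000_000, 512, 13824))   # (512, 1): 3906 tokens/chip > 850
print(plan_fsdp_tp(2_000_000, 4096, 13824))  # (1024, 4)
```

The printed configurations are what these particular rules produce; a real plan would also check memory, head-count divisibility, and whether the chip count is a power of two.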
Worked problem 2: You want to train a 405B model (D=16384, F=53248, L=126) on 8 pods (71680 chips) with 8M token batch.
Worked problem 3: A 3B model on 64 H100 GPUs (80GB each) with 500K token batch.
Summary of the decision procedure:
In the next chapter, we apply all of these to a concrete model: training LLaMA 3 on TPUs.