Replicate the model, split the data. All-reduce, gradient bucketing, and ZeRO — the workhorse of distributed training.
In the last chapter we learned gradient accumulation: run multiple micro-batches sequentially, accumulate their gradients, then update the model. It works, but it is slow — each micro-batch waits for the previous one to finish.
Here is the key observation: the forward and backward passes for each micro-batch are completely independent. They use different input data but the same model weights. If we have 8 GPUs, why not run 8 micro-batches at the same time?
This idea is called data parallelism (DP). It is the simplest and most widely used form of distributed training, and the conceptual foundation for everything else in this book. But there are subtleties: how do we average gradients across GPUs efficiently? What if the model state does not fit on one GPU? That is where all-reduce and ZeRO come in.
After each GPU finishes its backward pass, it has a set of local gradients. To keep all model replicas in sync, we need to average these gradients across all GPUs before the optimizer step.
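Why is averaging the right operation? If the loss is a mean over samples and every replica gets the same number of samples, then the mean of the per-replica gradients is exactly the gradient over the whole global batch. Here is a tiny pure-Python check. It assumes a one-parameter linear model with mean-squared-error loss (an illustrative choice), and the function names are ours:

```python
def grad_mse(w, xs, ys):
    """Gradient of mean((w*x - y)^2) with respect to w over the given samples."""
    return sum(2 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)


def data_parallel_grad(w, xs, ys, world_size):
    """Split the batch across world_size replicas, compute local grads, then average."""
    shard = len(xs) // world_size
    assert shard * world_size == len(xs), "this sketch assumes equal-sized shards"
    local_grads = [
        grad_mse(w, xs[r * shard:(r + 1) * shard], ys[r * shard:(r + 1) * shard])
        for r in range(world_size)
    ]
    # On real hardware, this cross-replica average is the communication step.
    return sum(local_grads) / world_size


xs = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]
ys = [2.1, 3.9, 6.2, 8.1, 9.8, 12.2, 13.9, 16.1]
full_batch = grad_mse(0.5, xs, ys)
replicated = data_parallel_grad(0.5, xs, ys, world_size=4)
print(full_batch, replicated)  # equal up to floating-point rounding
```

With unequal shard sizes, a plain average would over-weight samples on the smaller shards. That is one reason distributed data loaders typically hand every replica the same number of samples.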
The operation that does this is called all-reduce. It takes a tensor from each GPU, sums (or averages) them, and distributes the result back to every GPU. After all-reduce, every GPU has the identical averaged gradient.
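Collective libraries can implement all-reduce with several algorithms. The classic one is the ring, which runs a reduce-scatter phase followed by an all-gather phase. The sketch below simulates it in pure Python. It assumes ranks are plain lists and that each round's messages are exchanged simultaneously. It illustrates the algorithm and is not how NCCL or any other library is actually written:

```python
def ring_all_reduce(buffers):
    """Sum-all-reduce over simulated ranks using the ring algorithm.

    buffers: one equal-length list of numbers per rank (left untouched).
    Returns new per-rank lists, each holding the elementwise sum over all ranks.
    """
    n, length = len(buffers), len(buffers[0])
    data = [list(b) for b in buffers]
    bounds = [length * c // n for c in range(n + 1)]  # chunk c = data[bounds[c]:bounds[c+1]]

    def chunk(rank, c):
        return data[rank][bounds[c]:bounds[c + 1]]  # slicing copies, so sends are snapshots

    # Phase 1, reduce-scatter: after n-1 rounds, rank r holds the full sum of chunk (r+1) % n.
    for step in range(n - 1):
        msgs = [((r + 1) % n, (r - step) % n, chunk(r, (r - step) % n)) for r in range(n)]
        for dst, c, payload in msgs:
            for i, v in enumerate(payload):
                data[dst][bounds[c] + i] += v  # right neighbour accumulates the chunk

    # Phase 2, all-gather: pass the fully reduced chunks around the ring.
    for step in range(n - 1):
        msgs = [((r + 1) % n, (r + 1 - step) % n, chunk(r, (r + 1 - step) % n)) for r in range(n)]
        for dst, c, payload in msgs:
            data[dst][bounds[c]:bounds[c + 1]] = payload  # overwrite with the final sum
    return data


ranks = [[1, 2, 3, 4, 5], [10, 20, 30, 40, 50], [100, 200, 300, 400, 500]]
print(ring_all_reduce(ranks)[0])  # [111, 222, 333, 444, 555]
```

Each rank sends 2(N-1) chunks of roughly 1/N of the buffer, about 2(N-1)/N times the buffer size in total, which barely grows with N. Dividing the result by the number of ranks turns the sum into the average used for gradients.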
A naive implementation would wait for the entire backward pass to finish, then launch one big all-reduce. But that leaves the GPUs' compute units idle for the whole communication phase. In distributed training, a strictly sequential compute-then-communicate schedule is a BIG NO-NO.
The trick to efficient data parallelism is to overlap communication with computation. As soon as gradients for a layer are ready during the backward pass, start communicating them — do not wait for the entire backward pass to finish.
There are three key optimizations that make this work:
| Optimization | How it works |
|---|---|
| Backward hooks | Register a hook on each parameter that fires as soon as its gradient is computed. The hook launches that gradient's all-reduce right away or, with bucketing, as soon as every gradient in its bucket is ready. |
| Gradient bucketing | Group small gradients into larger buckets before all-reducing. This amortizes the fixed cost of launching a communication operation. |
| Overlap | While the all-reduce for layer N's gradients is in progress, the backward pass for layer N-1 continues computing on the GPU. |
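The three ideas fit together as an event-driven loop. Below is a pure-Python simulation, not PyTorch's implementation. The parameter names, sizes, and the `cap_bytes` threshold are made up, and "launching" an all-reduce just appends to a log. Buckets are filled in reverse parameter order because that is roughly the order in which the backward pass produces gradients (PyTorch DDP buckets in a similar order):

```python
def build_buckets(param_sizes, cap_bytes):
    """Group parameter names into buckets in reverse (backward) order, each up to cap_bytes."""
    buckets, current, current_bytes = [], [], 0
    for name, size in reversed(param_sizes):
        if current and current_bytes + size > cap_bytes:
            buckets.append(current)
            current, current_bytes = [], 0
        current.append(name)
        current_bytes += size
    if current:
        buckets.append(current)
    return buckets


def simulate_backward(param_sizes, cap_bytes):
    """Return an event log of gradient-ready events and all-reduce launches."""
    buckets = build_buckets(param_sizes, cap_bytes)
    bucket_of = {name: i for i, bucket in enumerate(buckets) for name in bucket}
    pending = [len(bucket) for bucket in buckets]  # gradients still missing per bucket
    events = []

    def grad_ready_hook(name):
        # Plays the role of a per-parameter backward hook: runs once this grad exists.
        events.append(f"grad {name}")
        i = bucket_of[name]
        pending[i] -= 1
        if pending[i] == 0:
            # A real system would launch an async all-reduce on a separate
            # communication stream here while backward keeps computing.
            events.append(f"all-reduce bucket {i} {buckets[i]}")

    for name, _ in reversed(param_sizes):  # backward visits layers last to first
        grad_ready_hook(name)
    return events


params = [("embed", 40), ("layer1", 10), ("layer2", 10), ("layer3", 10), ("head", 40)]  # (name, grad bytes)
for event in simulate_backward(params, cap_bytes=25):
    print(event)
```

In the log, bucket 0's all-reduce is launched before `layer3`'s gradient even exists, which is the overlap. Lowering `cap_bytes` produces more, smaller launches. Raising it merges buckets and pushes the first launch later.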
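In a profiler trace, you would see the backward computation on one CUDA stream and the all-reduce communication on another, running in parallel. The bucket size is tunable (in PyTorch's DistributedDataParallel it is the `bucket_cap_mb` argument, 25 MB by default). If buckets are too small, you pay the fixed launch overhead on many communication calls. If they are too large, the first all-reduce cannot start until a big part of the backward pass has finished, which shrinks the overlap.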
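With data parallelism and gradient accumulation combined, the global batch size (in samples) formula becomes:
$$\text{GBS} = \text{DP} \times \text{BS}_{\text{micro}} \times \text{grad\_acc\_steps}$$
where DP is the number of data-parallel replicas (GPUs used for data parallelism), $\text{BS}_{\text{micro}}$ is the micro-batch size per GPU, and grad_acc_steps is the number of gradient accumulation steps.
To get the global batch size in tokens, multiply by the sequence length:
$$\text{GBS}_{\text{tokens}} = \text{GBS} \times \text{seq\_len}$$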
Example: to reach a global batch of about 4M tokens (4 × 2^20 = 4,194,304) with seq_len=4096, you need 1024 samples. With DP=8 and micro-batch=4, each forward/backward pass across all replicas processes 32 samples, so you need grad_acc_steps = 1024 / 32 = 32.
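The same arithmetic can be written as a small helper, which is handy when exploring configurations. The function name is ours. It inverts the global batch size formula and rejects configurations that do not divide evenly:

```python
def grad_acc_steps(target_tokens, seq_len, dp, micro_batch):
    """Gradient-accumulation steps needed to hit a target global batch size in tokens."""
    samples = target_tokens // seq_len   # global batch size in samples
    per_pass = dp * micro_batch          # samples per forward/backward across all replicas
    if samples % per_pass:
        raise ValueError("global batch is not a multiple of dp * micro_batch")
    return samples // per_pass


steps = grad_acc_steps(target_tokens=4 * 1024**2, seq_len=4096, dp=8, micro_batch=4)
print(steps)  # 32
```

Doubling DP to 16 halves the accumulation steps to 16 while keeping the same global batch.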
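Basic data parallelism replicates everything on every GPU: weights, gradients, and optimizer states. With Adam, that comes to roughly 16 bytes per parameter (for example, 4 bytes of FP32 weights, 4 bytes of gradients, and 8 bytes for Adam's momentum and variance), so a 7B model needs ~112 GB per GPU just for the model state. That is extremely wasteful: with N GPUs we hold N identical copies of the same thing.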
The Zero Redundancy Optimizer (ZeRO) eliminates this redundancy by sharding the model state across data-parallel GPUs. There are three stages, each more aggressive than the last.
ZeRO Stage 1 shards the optimizer states across DP replicas.
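Since Adam's momentum and variance (8 bytes/param) are the largest per-parameter cost, sharding them across N GPUs shrinks that term from 8P to 8P/N bytes per GPU, where P is the parameter count. After the gradients are averaged, each GPU updates only its own shard of the parameters, then the updated shards are all-gathered so every GPU again holds the full weights.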
ZeRO Stage 2 shards both optimizer states and gradients. Instead of all-reducing all gradients to every GPU, we use reduce-scatter: each GPU ends up with only 1/N of the averaged gradients — exactly the shard it needs for its optimizer step.
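The ZeRO-1/ZeRO-2 data flow can be written out on simulated ranks. The gradients are reduce-scattered, each rank updates only its slice using only its slice of optimizer state, and then the updated slices are all-gathered (equivalently, every rank broadcasts its shard). This is a toy with simplifications of our own: flat Python lists for parameters, SGD with momentum standing in for Adam, and loops standing in for real collectives:

```python
def shard_bounds(length, n):
    """Start/end offsets of each rank's contiguous slice."""
    return [length * r // n for r in range(n + 1)]


def reduce_scatter_mean(grads):
    """Rank r receives the cross-rank average of its own slice only."""
    n, length = len(grads), len(grads[0])
    b = shard_bounds(length, n)
    return [[sum(g[i] for g in grads) / n for i in range(b[r], b[r + 1])] for r in range(n)]


def all_gather(shards):
    """Every rank receives the concatenation of all ranks' slices."""
    full = [x for s in shards for x in s]
    return [list(full) for _ in shards]


def zero2_step(params, grads, momentum_shards, lr=0.1, beta=0.9):
    """One toy ZeRO-2 step.

    params: the replicated full weights; grads: each rank's full local gradient;
    momentum_shards: per-rank optimizer state covering only that rank's slice
    (updated in place). Returns every rank's full weights after the step.
    """
    n = len(grads)
    b = shard_bounds(len(params), n)
    grad_shards = reduce_scatter_mean(grads)  # real ZeRO-2 frees the full local grads after this
    new_shards = []
    for r in range(n):
        m = momentum_shards[r]
        p = params[b[r]:b[r + 1]]             # copy of this rank's parameter slice
        for i, g in enumerate(grad_shards[r]):
            m[i] = beta * m[i] + g            # optimizer state exists only for this slice
            p[i] -= lr * m[i]
        new_shards.append(p)
    return all_gather(new_shards)             # everyone gets the full updated weights


params = [1.0, 2.0, 3.0, 4.0, 5.0]
grads = [[0.1, 0.2, 0.3, 0.4, 0.5], [0.3, 0.2, 0.1, 0.0, -0.1]]  # local grads on 2 ranks
b = shard_bounds(len(params), 2)
momentum = [[0.0] * (b[r + 1] - b[r]) for r in range(2)]          # 2 and 3 entries
replicas = zero2_step(params, grads, momentum)
print(replicas[0] == replicas[1])  # True: all replicas stay in sync
```

The result matches an unsharded optimizer step on the averaged gradient, but no rank ever stored momentum for the full parameter vector.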
ZeRO Stage 3 goes all the way: it shards optimizer states, gradients, and model weights across GPUs. No GPU holds a complete copy of anything. Before each forward or backward pass through a layer, the GPU gathers that layer's weights from the other GPUs, uses them, then discards them.
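The gather, use, discard cycle of ZeRO-3 can be sketched for a single rank's forward pass. These assumptions are ours: each "layer" is just an elementwise scale y = w * x, weights are split into contiguous slices across ranks, and memory is counted in weight elements. The backward pass (not shown) repeats the gather in reverse layer order and reduce-scatters the gradients:

```python
def shard(vec, n):
    """Split a layer's weight vector into n contiguous slices, one per rank."""
    b = [len(vec) * r // n for r in range(n + 1)]
    return [vec[b[r]:b[r + 1]] for r in range(n)]


def gather_full(slices):
    """Stand-in for an all-gather: concatenate every rank's slice."""
    return [x for s in slices for x in s]


def zero3_forward(x, sharded_layers, rank):
    """Forward pass as `rank` runs it. Returns (output, peak resident weight elements)."""
    persistent = sum(len(layer[rank]) for layer in sharded_layers)  # this rank's shards
    peak = persistent
    for layer_slices in sharded_layers:
        w = gather_full(layer_slices)            # materialize this layer's full weights
        peak = max(peak, persistent + len(w))    # shards plus one temporarily full layer
        x = [wi * xi for wi, xi in zip(w, x)]    # use them
        del w                                    # discard before moving to the next layer
    return x, peak


layers = [[1.0, 2.0, 3.0, 4.0], [0.5, 0.5, 0.5, 0.5], [2.0, 2.0, 2.0, 2.0]]
sharded = [shard(w, 4) for w in layers]          # 4 ranks, 1 element each per layer
out, peak = zero3_forward([1.0, 1.0, 1.0, 1.0], sharded, rank=0)
print(out, peak)  # [1.0, 2.0, 3.0, 4.0] 7
```

Holding all three layers unsharded would take 12 elements. Here the peak is 7 (3 persistent slices plus one gathered layer of 4), and the gap widens with more layers and more ranks.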
| Stage | Shards | Model-state memory per GPU (P params, N GPUs) | Communication |
|---|---|---|---|
| ZeRO-0 (none) | Nothing | 16P bytes | All-reduce gradients |
| ZeRO-1 | Optimizer states | ~8P + 8P/N bytes | All-reduce grads + all-gather updated params |
| ZeRO-2 | Optimizer + grads | ~4P + 12P/N bytes | Reduce-scatter grads + all-gather updated params |
| ZeRO-3 | Everything | ~16P/N bytes | All-gather weights before each layer (forward and backward) + reduce-scatter grads |
Let us see how much memory ZeRO actually saves. The interactive diagram below shows per-GPU memory for a model across different ZeRO stages and DP degrees.
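Notice how ZeRO-1 already cuts optimizer-state memory dramatically, while ZeRO-3 divides everything by the DP degree. At DP=64, a 70B model's ~1,120 GB of model state drops to ~17.5 GB per GPU under ZeRO-3, small enough to fit on a single 80 GB GPU with the rest of the memory left for activations.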
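Per-GPU numbers like these are easy to reproduce. This minimal calculator assumes 16 bytes per parameter, split as 4 B weights, 4 B gradients, and 8 B Adam states, and it ignores activations, temporary buffers, and fragmentation. Mixed-precision setups split the 16 bytes differently (for example 2 B BF16 weights, 2 B BF16 gradients, and 12 B of FP32 master weights plus Adam states). That changes the ZeRO-1 and ZeRO-2 figures but not the ZeRO-0 or ZeRO-3 totals:

```python
# Assumed per-parameter byte breakdown (FP32 weights and grads, Adam momentum + variance).
WEIGHT_B, GRAD_B, OPT_B = 4, 4, 8


def zero_memory_gb(num_params, dp, stage):
    """Model-state memory per GPU in GB (1 GB = 1e9 bytes) for ZeRO stage 0-3."""
    w, g, o = WEIGHT_B, GRAD_B, OPT_B
    if stage >= 1:
        o /= dp   # optimizer states sharded
    if stage >= 2:
        g /= dp   # gradients sharded
    if stage >= 3:
        w /= dp   # weights sharded
    return num_params * (w + g + o) / 1e9


for stage in range(4):
    print(f"7B, DP=8, ZeRO-{stage}: {zero_memory_gb(7e9, 8, stage):.1f} GB")
# ZeRO-0..3: 112.0, 63.0, 38.5, 14.0 GB
print(f"70B, DP=64, ZeRO-3: {zero_memory_gb(70e9, 64, 3):.1f} GB")  # 17.5 GB
```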
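Data parallelism scaling is not free. As we add more GPUs, each all-reduce involves more participants. Its latency grows with the number of ranks, and once replicas span multiple nodes, the slower inter-node links set the pace, so less of the communication can be hidden behind the backward pass. Let us simulate how throughput scales with DP degree.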
Watch how total throughput grows with DP, and how per-GPU efficiency drops due to communication overhead.
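A back-of-the-envelope version of this simulation uses the standard alpha-beta cost of a ring all-reduce. A per-message latency alpha is paid 2(N-1) times, and a bandwidth term approaches twice the gradient size as N grows. Every default below is a made-up illustration, not a measurement: the step compute time, gradient size, latency, bandwidth, and the assumption that communication can hide only behind the backward pass (taken as two thirds of compute). The single bandwidth value also ignores the intra-node versus inter-node gap:

```python
def ring_allreduce_time(n, size_bytes, alpha_s, bandwidth_bytes_per_s):
    """Alpha-beta estimate (seconds) for a ring all-reduce over n ranks."""
    if n == 1:
        return 0.0
    rounds = 2 * (n - 1)                       # reduce-scatter + all-gather rounds
    return rounds * alpha_s + rounds / n * size_bytes / bandwidth_bytes_per_s


def dp_scaling(n, compute_s=0.2, backward_frac=2 / 3, grad_bytes=2e9,
               alpha_s=1e-4, bandwidth_bytes_per_s=25e9):
    """Return (throughput relative to 1 GPU, per-GPU efficiency); all defaults hypothetical."""
    comm = ring_allreduce_time(n, grad_bytes, alpha_s, bandwidth_bytes_per_s)
    # Bucketed all-reduces can hide behind the backward pass, but nothing else.
    exposed = max(0.0, comm - backward_frac * compute_s)
    efficiency = compute_s / (compute_s + exposed)
    return n * efficiency, efficiency


for n in (1, 8, 64, 256, 1024):
    thr, eff = dp_scaling(n)
    print(f"DP={n:5d}  throughput={thr:7.1f}x  efficiency={eff:.0%}")
```

With these numbers, total throughput keeps rising while per-GPU efficiency falls, because the latency term grows linearly with N and eventually outruns the backward pass it was meant to hide behind.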
Data parallelism is the workhorse of distributed training. Combined with ZeRO, it handles most training scenarios up to moderate scale.
| Technique | What it does | When to use |
|---|---|---|
| Basic DP | Replicate model, split data, all-reduce gradients | Full model state (weights, grads, optimizer) fits on one GPU |
| DP + ZeRO-1 | Shard optimizer states | Optimizer states are the memory bottleneck |
| DP + ZeRO-2 | Shard optimizer + gradients | Need more memory savings |
| DP + ZeRO-3 (FSDP) | Shard everything | Model too large for one GPU |