Split weight matrices across GPUs. Column & row sharding, sequence parallelism, and the critical-path communication trade-off.
ZeRO-3 can shard model parameters, gradients, and optimizer states across GPUs. But it has a limitation: before computing each layer, it must all-gather the full layer weights onto every GPU. The activations during the matrix multiplication are still unsharded.
When activations become the memory bottleneck — for large models with long sequences — we need something different. We need to split the actual computation of a matrix multiplication across GPUs, so that no single GPU ever holds the full activation tensor.
This works because of a fundamental mathematical property: matrix multiplication can be decomposed along either the column or row dimension. If Y = XW, we can either split W by columns (and concatenate results) or split W by rows (and sum results). Let us see how.
Start with a weight matrix W of shape (in_features, out_features). Column-parallel sharding splits W along the output dimension into N shards: W = [W1, W2, ..., WN].
Each GPU receives the full input X (via broadcast) and one column shard Wi. It computes Yi = X · Wi. The result Yi is a column slice of the output, of shape (batch, out_features/N). To reconstruct the full result, we all-gather the slices and concatenate them along the output dimension: Y = [Y1, Y2, ..., YN].
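The column-parallel case can be sketched on a single machine by simulating N = 2 GPUs with plain Python lists. The helper names (`matmul`, `split_columns`, `concat_columns`) and the example matrices are illustrative assumptions; a real implementation would use a tensor library and a collective all-gather instead of list concatenation.

```python
def matmul(a, b):
    """Multiply two matrices stored as lists of rows."""
    return [[sum(x * w for x, w in zip(row, col)) for col in zip(*b)] for row in a]

def split_columns(m, n):
    """Split matrix m into n equal column shards (assumes n divides the width)."""
    width = len(m[0]) // n
    return [[row[i * width:(i + 1) * width] for row in m] for i in range(n)]

def concat_columns(parts):
    """Place matrices side by side; stands in for an all-gather."""
    return [sum(rows, []) for rows in zip(*parts)]

X = [[1, 2, 3], [4, 5, 6]]                      # (batch=2, in_features=3)
W = [[1, 0, 2, 1], [0, 1, 1, 3], [2, 1, 0, 1]]  # (in_features=3, out_features=4)

shards = split_columns(W, 2)                    # one column shard per simulated GPU
partials = [matmul(X, w_i) for w_i in shards]   # every GPU sees the full X
Y = concat_columns(partials)

assert Y == matmul(X, W)                        # same result as the unsharded matmul
```

Each entry of `partials` has shape (2, 2), half the output width, which is where the per-GPU saving comes from.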
Now split W along the input dimension (rows): W = [W1; W2; ...; WN]. This also requires splitting the input along its feature dimension, X = [X1, X2, ..., XN], via scatter, so that the columns of Xi line up with the rows of Wi.
Each GPU computes Yi = Xi · Wi. Every partial result already has the full output shape, but the partial results must be summed (not concatenated) to give the final result: Y = Y1 + Y2 + ... + YN. This requires an all-reduce.
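The row-parallel case can be sketched the same way, again simulating N = 2 GPUs with Python lists. The helpers and matrices are assumptions for illustration, and `all_reduce_sum` is only a stand-in for a real all-reduce collective.

```python
def matmul(a, b):
    """Multiply two matrices stored as lists of rows."""
    return [[sum(x * w for x, w in zip(row, col)) for col in zip(*b)] for row in a]

def split_columns(m, n):
    """Split matrix m into n equal column shards (assumes n divides the width)."""
    width = len(m[0]) // n
    return [[row[i * width:(i + 1) * width] for row in m] for i in range(n)]

def split_rows(m, n):
    """Split matrix m into n equal row shards (assumes n divides the height)."""
    height = len(m) // n
    return [m[i * height:(i + 1) * height] for i in range(n)]

def all_reduce_sum(parts):
    """Elementwise sum of same-shaped matrices; stands in for an all-reduce."""
    return [[sum(vals) for vals in zip(*rows)] for rows in zip(*parts)]

X = [[1, 2, 3, 4], [5, 6, 7, 8]]                   # (batch=2, in_features=4)
W = [[1, 0, 1], [2, 1, 0], [0, 1, 1], [1, 1, 2]]   # (in_features=4, out_features=3)

x_shards = split_columns(X, 2)   # scatter the input along its feature dimension
w_shards = split_rows(W, 2)      # matching row shards of the weight
partials = [matmul(x_i, w_i) for x_i, w_i in zip(x_shards, w_shards)]
Y = all_reduce_sum(partials)     # partials are full-shaped, so they are summed

assert Y == matmul(X, W)
```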
A transformer layer has two main blocks: the MLP (feedforward) and the multi-head attention (MHA). We can apply TP to both by cleverly pairing column and row sharding.
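MLP block: The first linear layer (FC1) uses column sharding; the second (FC2) uses row sharding. Because the activation function between them is elementwise, each GPU applies it to its own column shard of FC1's output, and that shard is exactly the input slice that row-sharded FC2 expects. We therefore need only one all-reduce per MLP block in the forward pass (at the output of FC2). The input broadcast for FC1 costs nothing, since every GPU already holds the same input.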
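A small sketch of this pairing, simulating two GPUs with Python lists. ReLU is used here only as a convenient elementwise activation (an assumption; the text does not name one), and all helper names and matrices are illustrative.

```python
def matmul(a, b):
    """Multiply two matrices stored as lists of rows."""
    return [[sum(x * w for x, w in zip(row, col)) for col in zip(*b)] for row in a]

def relu(m):
    """Elementwise ReLU, so it can be applied to each shard independently."""
    return [[max(0, v) for v in row] for row in m]

def split_columns(m, n):
    width = len(m[0]) // n
    return [[row[i * width:(i + 1) * width] for row in m] for i in range(n)]

def split_rows(m, n):
    height = len(m) // n
    return [m[i * height:(i + 1) * height] for i in range(n)]

def all_reduce_sum(parts):
    """Elementwise sum of same-shaped matrices; stands in for an all-reduce."""
    return [[sum(vals) for vals in zip(*rows)] for rows in zip(*parts)]

X = [[1, -2], [3, 1]]                  # (batch=2, hidden=2)
W1 = [[1, -1, 2, 0], [0, 1, -1, 1]]    # FC1: (hidden=2, ffn=4), column-sharded
W2 = [[1, 0], [0, 1], [1, 1], [-1, 2]] # FC2: (ffn=4, hidden=2), row-sharded

reference = matmul(relu(matmul(X, W1)), W2)  # unsharded MLP

# Each simulated GPU runs FC1 shard -> ReLU -> FC2 shard with no communication.
partials = [
    matmul(relu(matmul(X, w1_i)), w2_i)
    for w1_i, w2_i in zip(split_columns(W1, 2), split_rows(W2, 2))
]
Y = all_reduce_sum(partials)           # the single all-reduce of the block

assert Y == reference
```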
Attention block: The Q, K, V projections are column-parallel (each GPU handles a subset of attention heads). The output projection is row-parallel. Again, one all-reduce per attention block.
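To make the head split concrete, the following sketch computes which heads each GPU would own and the resulting weight-shard shapes. The function name `shard_attention_heads` and the sizes (hidden 4096, 32 heads, TP=8) are hypothetical, and the sketch assumes the TP degree divides the number of heads.

```python
def shard_attention_heads(hidden, num_heads, tp):
    """Return per-GPU head ranges and weight-shard shapes.

    Assumes tp divides num_heads and num_heads divides hidden.
    """
    if num_heads % tp != 0:
        raise ValueError("num_heads must be divisible by the TP degree")
    head_dim = hidden // num_heads
    heads_per_gpu = num_heads // tp
    local_width = heads_per_gpu * head_dim
    plan = []
    for rank in range(tp):
        first = rank * heads_per_gpu
        plan.append({
            "rank": rank,
            "heads": range(first, first + heads_per_gpu),
            # Column-parallel Q/K/V projections: output dimension is split.
            "qkv_shard_shape": (hidden, local_width),
            # Row-parallel output projection: input dimension is split.
            "out_proj_shard_shape": (local_width, hidden),
        })
    return plan

plan = shard_attention_heads(hidden=4096, num_heads=32, tp=8)
for entry in plan[:2]:
    print(entry["rank"], list(entry["heads"]),
          entry["qkv_shard_shape"], entry["out_proj_shard_shape"])
```

Because each head's attention is computed independently, a GPU that owns whole heads needs no communication until the row-parallel output projection.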
Figure: how column-linear and row-linear layers combine inside the MLP and attention blocks.
Here is the critical difference between TP and data parallelism: in DP, communication (the all-reduce of gradients) happens between layers and can be overlapped with backward computation. In TP, the all-reduce happens within each layer, on the critical path of the forward pass, so the computation that follows has to wait for it.
Benchmarks show significant throughput drops when scaling TP beyond 8 GPUs (the typical node size):
| TP Degree | Relative Throughput | Communication |
|---|---|---|
| TP=1 | 100% (baseline) | None |
| TP=4 | ~90% | Intra-node NVLink |
| TP=8 | ~80% | Intra-node NVLink |
| TP=16 | ~45% | Crosses node boundary (InfiniBand) |
| TP=32 | ~25% | Multi-node InfiniBand |
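To see where the drop is steepest, the approximate values from the table can be compared step by step. This is only arithmetic on the table's rounded numbers; the variable names are illustrative.

```python
# Approximate relative throughput (%) per TP degree, copied from the table above.
relative_throughput = {1: 100, 4: 90, 8: 80, 16: 45, 32: 25}
GPUS_PER_NODE = 8  # typical node size mentioned in the text

degrees = sorted(relative_throughput)
drops = {}
for prev, cur in zip(degrees, degrees[1:]):
    drop = relative_throughput[prev] - relative_throughput[cur]
    drops[(prev, cur)] = drop
    crosses = prev <= GPUS_PER_NODE < cur
    note = "(crosses node boundary)" if crosses else ""
    print(f"TP={prev:>2} -> TP={cur:>2}: -{drop} points {note}")

largest_step = max(drops, key=drops.get)
print("largest drop:", largest_step)
```

With these numbers the largest single step is the one from TP=8 to TP=16, the point where communication leaves NVLink.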
TP shards the activations along the hidden dimension for MLP and attention computations. But operations like LayerNorm and dropout need the full hidden dimension to compute correctly (LayerNorm computes mean and variance across all hidden features).
This means after the TP region, we must gather the full activations for LayerNorm, partially negating the memory savings.
Sequence parallelism (SP) solves this by splitting activations along the sequence dimension for operations outside the TP region. Since LayerNorm operates independently on each token, splitting tokens across GPUs works perfectly.
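The per-token independence is easy to check with a small sketch. The `layer_norm` below is a simplified version without learnable scale and bias (an assumption made to keep the example short), applied to a toy sequence of four tokens split across two simulated GPUs.

```python
import math

def layer_norm(tokens, eps=1e-5):
    """Normalize each token (row) over its hidden features; no scale or bias."""
    out = []
    for t in tokens:
        mean = sum(t) / len(t)
        var = sum((v - mean) ** 2 for v in t) / len(t)
        out.append([(v - mean) / math.sqrt(var + eps) for v in t])
    return out

# Toy activations: seq=4 tokens, hidden=3 features.
tokens = [[1.0, 2.0, 3.0], [4.0, 0.0, -4.0], [0.5, 0.5, 2.0], [-1.0, 3.0, 7.0]]

full = layer_norm(tokens)

# Sequence parallelism: each simulated GPU holds half of the tokens.
shards = [tokens[:2], tokens[2:]]
sharded = [row for shard in shards for row in layer_norm(shard)]

assert sharded == full  # identical, because no token depends on another
```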
Let us track exactly what happens to the activation shape as data flows through a transformer layer with TP+SP:
| Location | TP only | TP + SP |
|---|---|---|
| Enter column-linear | h: sharded, s: full | h: sharded, s: full (after all-gather) |
| TP region | h: sharded, s: full | h: sharded, s: full |
| Exit row-linear | h: full, s: full (after all-reduce) | h: full, s: sharded (after reduce-scatter) |
| SP region (LN, dropout) | h: full, s: full | h: full, s: sharded |
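The two collectives in the TP+SP column can be simulated with Python lists to show how they relate to the all-reduce used by TP alone. The functions below are single-process stand-ins written for illustration, not a real communication library, and the partial outputs are made-up numbers.

```python
def elementwise_sum(mats):
    return [[sum(vals) for vals in zip(*rows)] for rows in zip(*mats)]

def all_reduce(partials):
    """Every rank ends up with the full summed tensor."""
    total = elementwise_sum(partials)
    return [total for _ in partials]

def reduce_scatter(partials):
    """Sum across ranks, then rank r keeps only its slice of the sequence."""
    total = elementwise_sum(partials)
    chunk = len(total) // len(partials)  # assumes the rank count divides seq
    return [total[r * chunk:(r + 1) * chunk] for r in range(len(partials))]

def all_gather(chunks):
    """Every rank ends up with all sequence slices concatenated."""
    full = [row for chunk in chunks for row in chunk]
    return [full for _ in chunks]

# Partial row-linear outputs on 2 ranks, each of shape (seq=4, hidden=2).
partials = [
    [[1, 2], [3, 4], [5, 6], [7, 8]],
    [[10, 20], [30, 40], [50, 60], [70, 80]],
]

tp_only = all_reduce(partials)        # h: full, s: full on every rank
sp_chunks = reduce_scatter(partials)  # h: full, s: sharded (2 tokens per rank)
# LayerNorm and dropout would run on sp_chunks here.
regathered = all_gather(sp_chunks)    # back to s: full before the next TP region

assert regathered == tp_only
```

In this sketch, a reduce-scatter followed by an all-gather reproduces the all-reduce result, while the tensors held between the two steps are only `seq / N` tokens long.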
The largest activation tensor on any GPU is now [seq/N, hidden] in the SP region (or [seq, hidden/N] in the TP region) rather than [seq, hidden]. With a TP+SP degree of 16, this allows fitting sequence lengths of 16K tokens that would be impossible with TP alone.
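A back-of-the-envelope calculation shows the size of the effect. The hidden size of 8192, batch size of 1, and 2 bytes per element (as for bf16) are assumptions chosen for illustration, and the function counts a single activation tensor rather than a full model's memory.

```python
def peak_activation_bytes(seq, hidden, tp, use_sp, bytes_per_elem=2, batch=1):
    """Largest single activation tensor held by one GPU, in bytes.

    TP only: the LayerNorm/dropout region holds the full [seq, hidden] tensor.
    TP + SP: every region holds [seq/tp, hidden] or [seq, hidden/tp].
    """
    full = batch * seq * hidden * bytes_per_elem
    return full // tp if use_sp else full

seq, hidden, tp = 16 * 1024, 8192, 16  # hidden size is a hypothetical value
tp_only = peak_activation_bytes(seq, hidden, tp, use_sp=False)
tp_sp = peak_activation_bytes(seq, hidden, tp, use_sp=True)

print(f"TP only: {tp_only / 2**20:.0f} MiB per tensor")
print(f"TP + SP: {tp_sp / 2**20:.0f} MiB per tensor")
```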
Benchmarks confirm: TP+SP enables significantly larger batch sizes per GPU through activation memory savings, with the same communication cost as vanilla TP. The performance drop beyond TP=8 (crossing node boundaries) remains the same limiting factor.
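One way to see why the communication cost matches is a simple volume count. The sketch below assumes ring-based collectives, where each rank sends (N-1)/N of the tensor in a reduce-scatter or an all-gather and twice that in an all-reduce; the tensor size is a hypothetical example, and real libraries may choose other algorithms.

```python
def ring_reduce_scatter_volume(size, n):
    """Elements each rank sends in a ring reduce-scatter (assumes n divides size)."""
    return (n - 1) * size // n

def ring_all_gather_volume(size, n):
    """Elements each rank sends in a ring all-gather (assumes n divides size)."""
    return (n - 1) * size // n

def ring_all_reduce_volume(size, n):
    """A ring all-reduce is a reduce-scatter followed by an all-gather."""
    return 2 * (n - 1) * size // n

size = 4096 * 8192  # hypothetical seq * hidden elements
n = 8               # TP degree

tp_only = ring_all_reduce_volume(size, n)
tp_sp = ring_reduce_scatter_volume(size, n) + ring_all_gather_volume(size, n)

assert tp_only == tp_sp
print(tp_only, tp_sp)
```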
Interactive figure: the trade-off between TP degree, throughput, and per-GPU memory. Higher TP reduces per-GPU memory but adds communication overhead.
| Technique | What it shards | Communication | Best for |
|---|---|---|---|
| Column TP | Weights along output dim | Broadcast + all-gather | First linear in MLP, Q/K/V projections |
| Row TP | Weights along input dim | Scatter + all-reduce | Second linear in MLP, output projection |
| TP+SP | Weights + activations (hidden & seq) | All-gather + reduce-scatter | Maximum activation memory savings |