When arrays don't fit on one chip, we split them. Understanding distributed matrix multiplication reduces to four cases and four communication primitives.
When you train an LLM on ten thousand chips, you are doing the same computation, in the abstract, as on one chip. The difference is that your arrays no longer fit in the HBM of a single chip.
We "shard" (split) arrays across devices. Sometimes for memory — the model simply does not fit. Sometimes for speed — even if it fits on fewer chips, using more gives us more FLOPs/s. During inference, we often choose larger topologies to reduce latency rather than because we need the memory.
Consider a 2D array A[I, J] sharded across 4 TPUs in a 2×2 mesh:
| | Y=0 | Y=1 |
|---|---|---|
| X=0 | A[0:I/2, 0:J/2] | A[0:I/2, J/2:J] |
| X=1 | A[I/2:I, 0:J/2] | A[I/2:I, J/2:J] |
The global (logical) shape is still (I, J). But the local shape on each device is (I/2, J/2) — each chip holds 1/4 of the total array.
We use a clean notation: subscripts on array dimensions tell you which mesh axis they are sharded across.
A device mesh is a named grid of devices. Example: Mesh({'X': 4, 'Y': 2}) is an 8-device grid with axis names X and Y.
A sharding assigns mesh axes to array dimensions using subscripts:
| Notation | Meaning | Local Shape (for I=1024, J=4096) |
|---|---|---|
| A[I, J] | Fully replicated on every device | (1024, 4096) × 8 copies |
| A[IX, J] | I sharded across X, J replicated across Y | (256, 4096) |
| A[IX, JY] | I sharded across X, J across Y | (256, 2048) |
| A[IXY, J] | I sharded across both X and Y (flattened) | (128, 4096) |
A[IX, JX] is forbidden: you cannot shard two different tensor dimensions along the same mesh axis.
When a mesh axis does not appear in any subscript, the data is replicated along that axis. For example, A[IX, J] on Mesh({'X': 4, 'Y': 2}) means I is split 4 ways across X while J is fully present on every device, so the mesh holds 2 complete copies of A (one per value of Y).
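```python
import jax
import jax.numpy as jnp
from jax.sharding import NamedSharding, PartitionSpec

# Needs 8 devices: a 4x2 mesh with axes named X and Y.
mesh = jax.make_mesh((4, 2), ('X', 'Y'))

def P(*spec):
    return NamedSharding(mesh, PartitionSpec(*spec))

# A[I_X, J_Y]: I sharded across X, J sharded across Y.
A = jnp.zeros((1024, 4096), device=P('X', 'Y'))
# B[I, J_Y]: I unsharded, J sharded across Y.
B = jnp.zeros((2048, 4096), device=P(None, 'Y'))
```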
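The local shapes in the table above can be reproduced with a few lines of plain Python. This is a minimal sketch, not part of JAX: `local_shape` and its spec format (one entry per array dimension: `None`, an axis name, or a tuple of axis names) are hypothetical helpers introduced only for illustration.

```python
from math import prod

def local_shape(global_shape, spec, mesh):
    """Per-device shape of an array under a sharding spec (illustrative helper)."""
    used = []
    shape = []
    for dim, axes in zip(global_shape, spec):
        if axes is None:
            axes = ()
        elif isinstance(axes, str):
            axes = (axes,)
        for axis in axes:
            # A mesh axis may shard at most one array dimension.
            if axis in used:
                raise ValueError(f"mesh axis {axis!r} used more than once")
            used.append(axis)
        ways = prod(mesh[a] for a in axes)  # prod of nothing is 1 (unsharded)
        if dim % ways != 0:
            raise ValueError(f"dimension {dim} is not divisible by {ways}")
        shape.append(dim // ways)
    return tuple(shape)

mesh = {"X": 4, "Y": 2}
replicated = local_shape((1024, 4096), (None, None), mesh)       # A[I, J]
rows_on_x = local_shape((1024, 4096), ("X", None), mesh)         # A[I_X, J]
both = local_shape((1024, 4096), ("X", "Y"), mesh)               # A[I_X, J_Y]
flattened = local_shape((1024, 4096), (("X", "Y"), None), mesh)  # A[I_XY, J]
print(replicated, rows_on_x, both, flattened)
```

Passing `("X", "X")` raises a `ValueError`, which mirrors the rule that A[IX, JX] is forbidden.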
The simplest case: neither input has a sharded contracting dimension. No communication is needed at all.
The contracting dimension is the one being summed over. In A[I, J] · B[J, K] → C[I, K], J is the contracting dimension.
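All of these work with zero communication:
- A[I, J] · B[J, K] → C[I, K]
- A[IX, J] · B[J, K] → C[IX, K]
- A[I, J] · B[J, KY] → C[I, KY]
- A[IX, J] · B[J, KY] → C[IX, KY]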
Think about why: each device has a complete slice of the contracting dimension. The multiplications along J are fully local. The non-contracting dimensions just come along for the ride.
The last example is particularly powerful: A[IX, J] · B[J, KY] → C[IX, KY] gives you a result sharded across both mesh axes with zero comms. This is the basis of many efficient parallelism strategies.
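To make the zero-communication claim concrete, here is a small NumPy simulation of A[IX, J] · B[J, KY] → C[IX, KY]. It is a sketch that models devices as list entries rather than real chips; the sizes and variable names are arbitrary choices for illustration.

```python
import numpy as np

I, J, K = 8, 6, 4
X, Y = 4, 2  # mesh axis sizes
rng = np.random.default_rng(0)
A = rng.standard_normal((I, J))
B = rng.standard_normal((J, K))
C = A @ B  # reference result, as if computed on one chip

A_shards = np.split(A, X, axis=0)  # A[I_X, J]: device row x holds row block x
B_shards = np.split(B, Y, axis=1)  # B[J, K_Y]: device column y holds column block y

# Device (x, y) multiplies only what it already holds: no data is exchanged.
C_blocks = [[A_shards[x] @ B_shards[y] for y in range(Y)] for x in range(X)]

# Stitching the local results together recovers the full product.
C_assembled = np.block(C_blocks)
print(np.allclose(C_assembled, C))
```

Each local block has shape (I/X, K/Y), which is exactly the local shape of C[IX, KY].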
When one input has its contracting dimension sharded, we cannot do a local multiply directly. We need to first gather the complete contracting dimension onto every device.
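For example, in A[I, JX] · B[J, K] → C[I, K], A is sharded along J (the contracting dimension), but B expects the full J. The solution is to AllGather A along X first and then multiply locally: AllGatherX(A[I, JX]) → A[I, J], then A[I, J] · B[J, K] → C[I, K].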
An AllGather removes a sharding subscript: each device sends its shard around a ring until every device has a full copy.
For example, with 8 devices exchanging shards around a ring, each device starts with 1/8 of the array and ends with a full copy.
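The ring exchange can be simulated in a few lines of Python. This is a toy model under stated assumptions (a bidirectional ring with wraparound, one shard forwarded per direction per hop); `ring_all_gather` is a hypothetical name, not a library call.

```python
def ring_all_gather(shards):
    """Simulate a bidirectional ring AllGather; returns (gathered, hops)."""
    n = len(shards)
    buffers = [{d: shards[d]} for d in range(n)]  # per device: shard index -> data
    # Index of the shard each device forwards next in each direction.
    send_right = list(range(n))
    send_left = list(range(n))
    hops = 0
    while any(len(buf) < n for buf in buffers):
        new_right = [None] * n
        new_left = [None] * n
        for d in range(n):
            right, left = (d + 1) % n, (d - 1) % n
            # Clockwise: device d passes a shard to its right neighbour.
            buffers[right][send_right[d]] = shards[send_right[d]]
            new_right[right] = send_right[d]
            # Counter-clockwise: device d passes a shard to its left neighbour.
            buffers[left][send_left[d]] = shards[send_left[d]]
            new_left[left] = send_left[d]
        send_right, send_left = new_right, new_left
        hops += 1
    gathered = [[buf[i] for i in range(n)] for buf in buffers]
    return gathered, hops

gathered, hops = ring_all_gather([f"shard{i}" for i in range(8)])
print(hops)  # 4
```

With 8 devices the gather finishes after 4 hops, i.e. half the ring, because shards travel in both directions at once.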
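For V total bytes sharded across X devices using a bidirectional ring (with wraparounds), each device forwards shards of V / X bytes in both directions for X / 2 hops, so the total time is T ≈ V / Wici, where Wici is the bidirectional ICI bandwidth. The number of devices X cancels out.
When do we enter a latency-bound regime? When each shard is so small that per-hop overhead (~1 μs) dominates. For TPU v5e with 4.5e10 bytes/s of unidirectional bandwidth, any buffer under ~45 kB (4.5e10 bytes/s × 1 μs) will be latency-bound.
Gathering over multiple axes increases the available bandwidth by a factor of Naxes: T ≈ V / (Wici × Naxes).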
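The cost model above is simple enough to write down as a function. The sketch below is a rough estimate under assumptions that go beyond the text: the bidirectional bandwidth is taken to be twice the quoted unidirectional figure, and the latency floor is modelled as (ring size / 2) hops at 1 μs each. `all_gather_time` is a hypothetical helper.

```python
def all_gather_time(total_bytes, w_ici, n_axes=1, ring_size=None, hop_latency=1e-6):
    """Rough AllGather time in seconds.

    total_bytes: V, the size of the full (gathered) array.
    w_ici: bidirectional link bandwidth in bytes/s.
    n_axes: number of mesh axes gathered over.
    ring_size: devices along the ring; if given, a latency floor of
        ring_size / 2 hops is applied (assumes wraparound links).
    """
    bandwidth_time = total_bytes / (w_ici * n_axes)
    if ring_size is None:
        return bandwidth_time
    return max(bandwidth_time, hop_latency * ring_size / 2)

W_UNI = 4.5e10     # bytes/s, the TPU v5e unidirectional figure quoted above
W_ICI = 2 * W_UNI  # ASSUMPTION: bidirectional bandwidth = 2 x unidirectional
threshold_bytes = W_UNI * 1e-6  # per-hop buffer size where 1 us of latency matters

# Bandwidth-bound: a 1 GiB gather costs the same on 8 or 16 devices.
big_8 = all_gather_time(2**30, W_ICI, ring_size=8)
big_16 = all_gather_time(2**30, W_ICI, ring_size=16)
# Latency-bound: a 256-byte gather gets slower as the ring grows.
small_4 = all_gather_time(256, W_ICI, ring_size=4)
small_16 = all_gather_time(256, W_ICI, ring_size=16)
# Gathering over two axes halves the bandwidth term.
big_2axes = all_gather_time(2**30, W_ICI, n_axes=2)
print(threshold_bytes, big_8, big_16, small_4, small_16, big_2axes)
```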
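The third case is when both inputs are sharded along the contracting dimension on the same mesh axis: A[I, JX] · B[JX, K] → C[I, K].
Here, each device can do a local matmul of its partial shard. But the result on each device is only a partial sum, so we need to add them all up.
We write this using the "unreduced" notation {UX}: A[I, JX] ·LOCAL B[JX, K] → C[I, K] {UX}.
The partial sums are then resolved with an AllReduce: AllReduceX(C[I, K] {UX}) → C[I, K].
An AllReduce can be decomposed into two cheaper operations: a ReduceScatter, C[I, K] {UX} → C[IX, K], followed by an AllGather, C[IX, K] → C[I, K].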
This decomposition is crucial because often we want the result sharded anyway. In that case, we can skip the final AllGather and just use the ReduceScatter, saving half the communication.
The full AllReduce takes T ≈ 2V / Wici. That is 2x the cost of an AllGather, because we do a ReduceScatter (V / Wici) followed by an AllGather (V / Wici).
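A small NumPy simulation shows both the partial sums and the AllReduce decomposition. It is a sketch with devices modelled as list entries; the sizes are arbitrary and the communication itself is replaced by plain Python sums and concatenations.

```python
import numpy as np

I, J, K, X = 4, 8, 6, 4
rng = np.random.default_rng(0)
A = rng.standard_normal((I, J))
B = rng.standard_normal((J, K))

A_shards = np.split(A, X, axis=1)  # A[I, J_X]
B_shards = np.split(B, X, axis=0)  # B[J_X, K]

# Local matmuls: each device holds a full-shape but partial sum, C[I, K] {U_X}.
partials = [a @ b for a, b in zip(A_shards, B_shards)]

# AllReduce: every device ends with the complete sum.
all_reduced = sum(partials)

# ReduceScatter: device x ends with only row block x of the sum, C[I_X, K].
scattered = [sum(np.split(p, X, axis=0)[x] for p in partials) for x in range(X)]
# AllGather of those row blocks recovers the replicated result.
regathered = np.concatenate(scattered, axis=0)

print(np.allclose(all_reduced, A @ B), np.allclose(regathered, A @ B))
```

If the next operation can consume C[IX, K] directly, the `scattered` list is already the answer and the final concatenation (the AllGather) can be skipped.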
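The fourth case is when both non-contracting dimensions are sharded along the same mesh axis: A[IX, J] · B[J, KX] → C[IX, KX].
This is invalid because device i along X would hold only the (i, i)-th block of C, a diagonal block. The off-diagonal blocks are never computed, so there is not enough information to reconstruct the full matrix.
The fix is to AllGather one of the inputs first. Either gather A: AllGatherX(A[IX, J]) → A[I, J], then A[I, J] · B[J, KX] → C[I, KX]
or gather B: AllGatherX(B[J, KX]) → B[J, K], then A[IX, J] · B[J, K] → C[IX, K]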
In both cases, the result only mentions X once. Which you pick depends on which sharding the downstream operations need, and on the relative sizes of the arrays (gather the smaller one if comms is the bottleneck).
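The diagonal-block problem and both fixes can be checked with NumPy. As before, this is an illustrative simulation with devices as list entries and arbitrary sizes, not code for a real mesh.

```python
import numpy as np

I, J, K, X = 4, 3, 4, 2
rng = np.random.default_rng(0)
A = rng.standard_normal((I, J))
B = rng.standard_normal((J, K))
C = A @ B

A_shards = np.split(A, X, axis=0)  # A[I_X, J]
B_shards = np.split(B, X, axis=1)  # B[J, K_X]

# Without communication, device x can only form block (x, x) of C.
diag_blocks = [A_shards[x] @ B_shards[x] for x in range(X)]
covered = sum(block.size for block in diag_blocks)  # entries of C actually computed

# Fix 1: AllGather A, so each device computes a column block: C[I, K_X].
A_full = np.concatenate(A_shards, axis=0)
C_cols = [A_full @ B_shards[x] for x in range(X)]

# Fix 2: AllGather B, so each device computes a row block: C[I_X, K].
B_full = np.concatenate(B_shards, axis=1)
C_rows = [A_shards[x] @ B_full for x in range(X)]

print(covered, C.size)
```

Only `C.size / X` entries are covered by the diagonal blocks, while either fix covers all of C with one subscript on X.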
| Case | Condition | Communication |
|---|---|---|
| 1 | Neither input sharded on contracting dim | None |
| 2 | One input sharded on contracting dim | AllGather the sharded input |
| 3 | Both sharded on contracting dim (same axis) | Local matmul + AllReduce (or ReduceScatter) |
| 4 | Both sharded on same axis (non-contracting) | AllGather one input first |
Here is the complete cost model for all four communication primitives, assuming we are in the bandwidth-bound regime (arrays large enough that per-hop latency is negligible).
| Operation | What It Does | Syntax | Time |
|---|---|---|---|
| AllGather | Removes sharding subscript, replicates | [AX, B] → [A, B] | V / (Wici × Naxes) |
| ReduceScatter | Sums partial products, introduces sharding | [A, B] {UX} → [AX, B] | Same as AllGather |
| AllReduce | Sums partial products, keeps replicated | [A, B] {UX} → [A, B] | 2 × AllGather |
| AllToAll | Moves subscript between dims | [A, BX] → [AX, B] | AllGather / 4 |
In an AllGather, every shard must reach every device: each shard hops across the full ring. In an AllToAll, shard i only needs to reach device i. On average, that is N/4 hops (half the ring, with bidirectional sending). This gives a factor of 4 savings.
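The N/4 figure can be sanity-checked numerically. The sketch below assumes each piece of data travels along the shorter direction of a bidirectional ring and that destinations are spread uniformly over the ring; `mean_ring_distance` is a hypothetical helper, and this checks only the average hop count, not the full cost derivation.

```python
def mean_ring_distance(n):
    """Average shortest-path hop count from one device to every device on a ring of n."""
    return sum(min(k, n - k) for k in range(n)) / n

ring_sizes = [4, 8, 16, 64]
averages = {n: mean_ring_distance(n) for n in ring_sizes}
print(averages)
```

For even ring sizes the average comes out to exactly N/4.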
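When arrays are very small (< ~45 kB per hop on TPU v5e), per-hop latency (~1 μs) dominates, and the time is roughly the number of hops times the per-hop latency: about X / 2 hops for a ring of X devices with wraparound.
In this regime, adding devices does increase communication time. This matters during autoregressive generation, where buffers are tiny.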
The AllToAll moves a subscript from one dimension to another: AllToAllX([A, BX]) → [AX, B].
Think of it as a distributed transpose. It arises naturally in Mixture-of-Experts models (routing tokens to experts on different devices) and when resharding between computation phases that need different layouts.
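Here is a NumPy sketch of the resharding [A, BX] → [AX, B]. Devices are modelled as list entries, and the `pieces[src][dst]` layout is an illustrative choice for showing who sends what to whom.

```python
import numpy as np

X = 4
arr = np.arange(8 * 12).reshape(8, 12)

# Start: [A, B_X], device x holds all rows of column block x.
col_shards = np.split(arr, X, axis=1)

# Each device cuts its local shard into X row pieces; piece dst is sent to device dst.
pieces = [np.split(shard, X, axis=0) for shard in col_shards]  # pieces[src][dst]

# Each device concatenates what it received, ordered by source device.
row_shards = [
    np.concatenate([pieces[src][dst] for src in range(X)], axis=1)
    for dst in range(X)
]

# End: [A_X, B], device x holds all columns of row block x.
expected = np.split(arr, X, axis=0)
print(all(np.array_equal(r, e) for r, e in zip(row_shards, expected)))
```

Note that piece `pieces[d][d]` never leaves device d, and each device sends only 1/X of its local data to any other device, which is why this is cheaper than gathering everything.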
For a 1D mesh, the cost is V / (4 × Wici). For an ND mesh with axes of sizes A, B, C:
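ReduceScatter and AllGather are transposes of each other, and this is a deeper fact than it first appears. If the forward pass does an AllGather, A[IX] → A[I],
then the backward pass does a ReduceScatter on the gradient, A'[I] {UX} → A'[IX].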
And vice versa. This is because broadcast and reduce are transposes of each other as linear operators, and AllGather/ReduceScatter are their Kronecker products with the identity.
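The transpose relationship can be verified by writing both collectives as explicit matrices. This is a sketch under one particular modelling choice: the state of all devices is stacked into a single vector, so a sharded array has length X·n and a replicated (or unreduced) one has length X·X·n.

```python
import numpy as np

X, n = 4, 3   # devices, elements per shard
full = X * n  # length of the complete array

# AllGather: stacked shards (length full) -> X stacked full copies.
all_gather = np.kron(np.ones((X, 1)), np.eye(full))      # shape (X*full, full)
# ReduceScatter: X stacked unreduced buffers -> their sum, left sharded.
reduce_scatter = np.kron(np.ones((1, X)), np.eye(full))  # shape (full, X*full)

v = np.arange(full, dtype=float)
gathered = all_gather @ v                 # every device now holds v

rng = np.random.default_rng(0)
bufs = rng.standard_normal((X, full))     # one unreduced buffer per device
reduced = reduce_scatter @ bufs.ravel()   # sum over devices

print(np.array_equal(all_gather.T, reduce_scatter))
```

Because the two matrices are transposes, the inner products `gathered · bufs` and `v · reduced` agree, which is the property automatic differentiation relies on when it turns a forward AllGather into a backward ReduceScatter.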
In practice, we overlap the AllGather/ReduceScatter with the matmul itself using a technique called collective matmul. The idea: start the matmul on available chunks while the remaining chunks are still being gathered. Each chunk's matmul overlaps with the next chunk's network transfer.
This is how we approach the max(Tmath, Tcomms) lower bound in practice.
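The chunked structure of a collective matmul can be sketched in NumPy for the Case 2 layout A[I, JX] · B[J, K]. This simulation shows only the arithmetic: each arriving chunk of A is multiplied against the matching rows of B and accumulated. The actual overlap of transfer and compute happens in hardware and is not modelled here; `collective_matmul` is a hypothetical name.

```python
import numpy as np

I, J, K, X = 4, 8, 5, 4
rng = np.random.default_rng(0)
A = rng.standard_normal((I, J))
B = rng.standard_normal((J, K))

A_shards = np.split(A, X, axis=1)  # A[I, J_X]: one column chunk per device
B_rows = np.split(B, X, axis=0)    # rows of the replicated B matching each chunk

def collective_matmul(device):
    """Accumulate the product chunk by chunk, in ring arrival order."""
    acc = np.zeros((I, K))
    for step in range(X):
        # Step 0 uses the local chunk; later chunks would arrive over the ring
        # while the previous chunk's matmul is still running.
        src = (device + step) % X
        acc += A_shards[src] @ B_rows[src]
    return acc

results = [collective_matmul(d) for d in range(X)]
print(all(np.allclose(r, A @ B) for r in results))
```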
Array A[IX, J, K, ...] on Mesh({'X': 4, 'Y': 8, 'Z': 2}). Only sharded across X. What is the ratio of total bytes across all chips to one copy of A?
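One way to work this out: every chip holds 1/4 of A (it is split only across X), and there are 4 × 8 × 2 = 64 chips. A short check, using a hypothetical `replication_factor` helper:

```python
from math import prod

def replication_factor(mesh, sharded_axes):
    """Total bytes across all chips divided by the bytes of one copy."""
    n_devices = prod(mesh.values())
    n_shards = prod(mesh[a] for a in sharded_axes)
    return n_devices // n_shards

mesh = {"X": 4, "Y": 8, "Z": 2}
ratio = replication_factor(mesh, ["X"])
print(ratio)  # 16
```

The ratio is the product of the unused axis sizes, 8 × 2 = 16.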
AllGatherX([BX, DY]) on TPU v4p 4×4×4, B=1024, D=4096, bf16. Mesh{'X':4, 'Y':4, 'Z':4}.
We gather over X only. The array on each Y-shard: bf16[256, 1024] = 0.5 MB. Total gathered = bf16[1024, 1024] = 2 MB per Y-shard.
AllGatherX([BX]) with B=128 in bf16 on TPU v4p 4×4×4. Total = 256 bytes, 64 bytes per device. Each hop takes ~0 bandwidth time. With wraparound on X=4, just 2 hops needed: ~2 μs.
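The sizes in the two AllGather problems above can be checked with a little arithmetic. The byte counts follow directly from the shapes; the timings are only a rough sketch, because the bandwidth used here is an assumption (the 4.5e10 bytes/s unidirectional figure quoted earlier for TPU v5e, doubled for bidirectional links) rather than a number stated for TPU v4p in this text.

```python
BF16_BYTES = 2
W_ICI = 2 * 4.5e10   # ASSUMPTION: bidirectional bytes/s, borrowed from the v5e figure
HOP_LATENCY = 1e-6   # ~1 us per hop, as stated above
X = Y = 4            # mesh axis sizes

# Large problem: [B_X, D_Y] with B=1024, D=4096.
local_bytes = (1024 // X) * (4096 // Y) * BF16_BYTES  # bf16[256, 1024]
gathered_bytes = 1024 * (4096 // Y) * BF16_BYTES      # bf16[1024, 1024]
t_large = max(gathered_bytes / W_ICI, HOP_LATENCY * X / 2)

# Small problem: [B_X] with B=128.
small_total = 128 * BF16_BYTES     # 256 bytes
small_per_device = small_total // X  # 64 bytes
t_small = max(small_total / W_ICI, HOP_LATENCY * X / 2)

print(local_bytes, gathered_bytes, t_large, t_small)
```

Under these assumptions the large gather is bandwidth-bound (tens of microseconds), while the small one sits on the latency floor of 2 hops, about 2 μs.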
X[B, D] · Y[DX, F] → Z[B, F]. Two strategies:
Strategy 1: AllGather Y first, then matmul. Cost: max(2BDF/C, 2DF/W).
Strategy 2: Treat as Case 3 (local matmul + AllReduce). Cost: max(2BDF/(X×C), 4BF/W).
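Strategy 2 does 1/X as many FLOPs per device, but its AllReduce costs 4BF/W. The comms term 4BF/W is smaller than Strategy 1's 2DF/W exactly when D > 2B, so Strategy 2 can be better in comms-bound cases. But it requires the contracting dimension to be sharded on both inputs, which is uncommon in practice (e.g., FSDP shards parameters and activations along the same mesh axis, but the activations are sharded on the batch dimension rather than the contracting one).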
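The two cost formulas can be compared directly. In this sketch C is per-chip FLOPs/s and W is ICI bandwidth in bytes/s; the hardware numbers and shapes below are illustrative placeholders, not measurements, and `strategy_costs` is a hypothetical helper.

```python
def strategy_costs(B, D, F, X, C, W):
    """Return (strategy 1 time, strategy 2 time) in seconds for bf16 arrays."""
    # Strategy 1: AllGather Y (2DF bytes), then do the full matmul on each device.
    s1 = max(2 * B * D * F / C, 2 * D * F / W)
    # Strategy 2: local matmul on 1/X of D, then AllReduce Z (2 x 2BF bytes).
    s2 = max(2 * B * D * F / (X * C), 4 * B * F / W)
    return s1, s2

C_FLOPS = 2e14  # ASSUMPTION: illustrative per-chip FLOPs/s
W_ICI = 9e10    # ASSUMPTION: illustrative bidirectional bytes/s

# Small batch, large contracting dimension: D > 2B, comms-bound.
t1_small_batch, t2_small_batch = strategy_costs(64, 8192, 8192, 4, C_FLOPS, W_ICI)
# Large batch: D < 2B, so the AllReduce moves more bytes than the AllGather.
t1_large_batch, t2_large_batch = strategy_costs(65536, 8192, 8192, 4, C_FLOPS, W_ICI)

print(t1_small_batch, t2_small_batch)
print(t1_large_batch, t2_large_batch)
```

In the small-batch case Strategy 2 wins because its AllReduce moves far fewer bytes; in the large-batch case its comms term 4BF/W exceeds Strategy 1's 2DF/W.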
| Primitive | Effect | Cost |
|---|---|---|
| AllGather | Removes subscript: [AX] → [A] | V / W |
| ReduceScatter | Sums + shards: [A]{UX} → [AX] | V / W |
| AllReduce | Sums: [A]{UX} → [A] | 2V / W |
| AllToAll | Moves subscript: [AX, B] → [A, BX] | V / (4W) |
Key insights:
1. Communication costs do not depend on the number of devices (in the bandwidth-bound regime).
2. ReduceScatter and AllGather are transposes of each other. Every AllGather in forward = ReduceScatter in backward.
3. AllToAll is 4x cheaper than AllGather — use it when you just need to move a sharding subscript.
4. Collective matmul overlaps comms with compute, approaching the max(Tmath, Tcomms) bound.