Three modes of multi-device programming: auto sharding, explicit sharding, and shard_map — from "compiler, take the wheel" to total manual control.
JAX supports three fundamentally different philosophies for multi-device programming. Understanding when to use each one is the key to productive TPU development.
| Mode | View | Explicit Sharding? | Explicit Collectives? | Best For |
|---|---|---|---|---|
| Auto | Global | No | No | Prototyping, simple models |
| Explicit | Global | Yes | No | Production models, catching mistakes |
| Manual (shard_map) | Per-device | Yes | Yes | Custom collectives, max performance |
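Think of it as a spectrum of control: auto mode (the compiler decides how to shard and where to communicate), explicit mode (shardings are part of the array types, and you resolve any ambiguous cases), and manual mode (you write the per-device program and every collective yourself).

In auto mode you can guide the compiler with jax.lax.with_sharding_constraint calls. When you need guaranteed performance, you drop into shard_map.

This chapter covers all three modes with real code examples you can run on free TPU v2-8 hardware in Google Colab. By the end, you will know when to use each mode and how to implement the collective matmul pattern, the single most important optimization for overlapping communication with computation in production Transformer training.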
The chapter is structured as a progression: we start with the easiest approach (auto), build understanding of explicit sharding, and then implement increasingly sophisticated patterns in shard_map.
jax.jit plays two roles: it compiles Python to fast XLA code, and — when inputs are sharded — it distributes computation across devices. The XLA compiler's Shardy partitioner automatically decides how to split work and where to insert communication.
Here is a complete sharded matmul in auto mode:
```python
import jax
import jax.numpy as jnp

# Create a 4x2 mesh of 8 devices (e.g., TPU v5e 4x2, or the 8 cores of a TPU v2-8)
mesh = jax.make_mesh(axis_shapes=(4, 2), axis_names=('X', 'Y'))
jax.set_mesh(mesh)

# Create sharded arrays
In = jnp.zeros((8, 2048), dtype=jnp.bfloat16,
               device=jax.NamedSharding(mesh, jax.P('X', 'Y')))
W = jnp.zeros((2048, 8192), dtype=jnp.bfloat16,
              device=jax.NamedSharding(mesh, jax.P('Y', None)))

def matmul_square(In, W):
    return jnp.einsum('bd,df->bf', jnp.square(In), W)

# Compile with output sharding constraint
jit_matmul = jax.jit(matmul_square, out_shardings=jax.P('X', None)).lower(In, W).compile()
out = jit_matmul(In, W)
```
What is actually happening at the hardware level?
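The shardings are:
- W is sharded 2-way along the contracting dimension: W[D_Y, F]
- In is sharded 8-way (4-way along batch, 2-way along the contracting dimension): In[B_X, D_Y]
- The output spec P('X', None) is sharded along batch and replicated across Y: Out[B_X, F]

Using our sharding notation, XLA will emit:
1. A local matmul on every device, In[B_X, D_Y] * W[D_Y, F] → Tmp[B_X, F], where Tmp holds partial sums over that device's slice of D.
2. An AllReduce over Y that adds up the partial sums: Tmp[B_X, F] → Out[B_X, F].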
We can verify this by examining the HLO with jit_matmul.as_text():
```
# The matmul (local on each device)
%fusion = bf16[2,8192]{1,0} fusion(bf16[2,1024] %param, bf16[8192,1024] %copy-done)
# The AllReduce across devices
ROOT %AllReduce = bf16[2,8192]{1,0} AllReduce(bf16[2,8192] %fusion)
```
Note the per-device shapes: bf16[2, 1024] for activations (batch=8 split 4 ways, d_model=2048 split 2 ways).
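To see why this two-step decomposition computes the right answer, here is a single-host NumPy simulation of the 4x2 mesh. It is an illustration we wrote, not what XLA literally executes, and it uses small shapes so it runs anywhere. Each simulated device (x, y) multiplies its In block by the matching row block of W, and summing the partial products over y stands in for the AllReduce.

```python
import numpy as np

def simulate_auto_matmul(In, W, nx=4, ny=2):
    """Simulate In[B_X, D_Y] @ W[D_Y, F] -> AllReduce over Y -> Out[B_X, F]."""
    in_blocks = [np.split(row, ny, axis=1) for row in np.split(In, nx, axis=0)]
    w_blocks = np.split(W, ny, axis=0)  # W is only split along D (by Y)
    out_rows = []
    for x in range(nx):
        # Local matmul on each device (x, y): partial sums over its slice of D.
        partials = [in_blocks[x][y] @ w_blocks[y] for y in range(ny)]
        # AllReduce over Y: every device in row x ends up with this same sum.
        out_rows.append(sum(partials))
    return np.concatenate(out_rows, axis=0)  # Out[B_X, F]

rng = np.random.default_rng(0)
In = rng.standard_normal((8, 16))
W = rng.standard_normal((16, 32))
out = simulate_auto_matmul(In, W)
print(np.allclose(out, In @ W))  # True
```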
If auto mode makes the wrong choice, you can steer it by annotating intermediates with jax.lax.with_sharding_constraint. But this is famously frustrating: you can annotate every intermediate variable and still not know whether you will get the right outcome. This motivates the next mode.

Explicit sharding (or "sharding in types") looks a lot like auto sharding, but with one crucial difference: sharding propagation happens at the JAX level, not the XLA level. This means JAX can catch ambiguous decisions and ask you for clarification instead of silently guessing.
The key concept: each JAX operation has a sharding rule that takes input shardings and produces an output sharding. For most ops, there is only one reasonable choice. For others, such as an einsum whose contracting dimension is sharded, the choice is ambiguous.
```python
import jax
import jax.numpy as jnp
import jax.sharding as shd
import numpy as np

# Create a mesh with Explicit axis types
mesh = jax.make_mesh(
    axis_shapes=(2, 2), axis_names=('X', 'Y'),
    axis_types=(shd.AxisType.Explicit, shd.AxisType.Explicit))
jax.set_mesh(mesh)

# float32 input, so the printed types below really are float32
x = jax.device_put(np.arange(16, dtype=np.float32).reshape(8, 2), jax.P('X', 'Y'))

@jax.jit
def f(x):
    print(jax.typeof(x))    # float32[8@X, 2@Y]
    out = x * 2
    print(jax.typeof(out))  # float32[8@X, 2@Y]
    return out

f(x)  # the prints run once, at trace time
```
Notice the @X and @Y annotations in the type — the sharding is now part of the type system. For elementwise ops like x * 2, the output sharding is obvious: same as the input.
But for a matmul, things get ambiguous:
```python
In = jnp.zeros((8, 2048), out_sharding=jax.P('X', 'Y'))
W = jnp.zeros((2048, 8192), out_sharding=jax.P('Y', None))

@jax.jit
def matmul_square(In, W):
    return jnp.einsum('bd,df->bf', jnp.square(In), W)

matmul_square(In, W)  # ERROR!
```
This produces a clear error:
```
Contracting dimensions are sharded and it is ambiguous
how the output should be sharded.
Please specify the output sharding via the `out_sharding` parameter.
Got lhs_contracting_spec=('Y',) and rhs_contracting_spec=('Y',)
```
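Conceptually, the sharding rule that raised this error behaves like the toy function below. It is a hypothetical plain-Python model of the idea for the einsum 'bd,df->bf', not JAX's actual implementation or error text. If the contracting dimension is unsharded, the output sharding follows from the batch and feature dimensions. If it is sharded, every device holds partial sums that could be AllReduced (replicated output) or ReduceScattered (sharded output), so the rule refuses to guess.

```python
def toy_matmul_sharding_rule(lhs_spec, rhs_spec, out_sharding=None):
    """Toy rule for einsum('bd,df->bf'); NOT JAX's implementation.

    Each spec is a (dim0, dim1) tuple of mesh-axis names or None.
    """
    (b, d_lhs), (d_rhs, f) = lhs_spec, rhs_spec
    if d_lhs != d_rhs:
        raise ValueError("toy rule: contracting dims must use the same axis")
    if out_sharding is not None:
        return tuple(out_sharding)  # the user resolved the choice explicitly
    if d_lhs is None:
        # Contracting dim is unsharded: the output simply inherits b and f.
        return (b, f)
    # Contracting dim is sharded: AllReduce or ReduceScatter are both valid.
    raise ValueError("ambiguous: please specify out_sharding")

print(toy_matmul_sharding_rule(('X', None), (None, None)))  # ('X', None)
```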
The fix is to specify the output sharding:
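```python
@jax.jit
def matmul_square(In, W):
    # out_sharding=jax.P('X', 'Y') induces a ReduceScatter (output sharded along both axes);
    # out_sharding=jax.P('X', None) would instead induce an AllReduce, as in auto mode.
    return jnp.einsum('bd,df->bf', jnp.square(In), W,
                      out_sharding=jax.P('X', 'Y'))
```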
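The difference between this ReduceScatter and the AllReduce of auto mode lies only in what each device keeps after the partial products are summed over Y. The single-host NumPy simulation below (our illustration, on the 2x2 mesh used here) makes that concrete: device (x, y) ends up with only its half of F. A real ReduceScatter never materializes the full sum on any device. The simulation computes it once for clarity.

```python
import numpy as np

def simulate_reduce_scatter_matmul(In, W, nx=2, ny=2):
    """Simulate In[B_X, D_Y] @ W[D_Y, F] -> ReduceScatter over Y -> Out[B_X, F_Y]."""
    in_blocks = [np.split(row, ny, axis=1) for row in np.split(In, nx, axis=0)]
    w_blocks = np.split(W, ny, axis=0)
    shards = {}
    for x in range(nx):
        partials = [in_blocks[x][y] @ w_blocks[y] for y in range(ny)]  # [B/X, F] each
        reduced = sum(partials)  # computed once here only for clarity
        for y, piece in enumerate(np.split(reduced, ny, axis=1)):
            shards[(x, y)] = piece  # device (x, y) keeps 1/ny of F
    return shards

rng = np.random.default_rng(0)
In = rng.standard_normal((4, 6))
W = rng.standard_normal((6, 8))
shards = simulate_reduce_scatter_matmul(In, W)
full = np.block([[shards[(x, y)] for y in range(2)] for x in range(2)])
print(np.allclose(full, In @ W))  # True
```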
Auto and Explicit modes can be composed using the jax.sharding.auto_axes and jax.sharding.explicit_axes APIs, so you can use explicit control on some mesh axes and let the compiler handle others.
jax.shard_map is the "just let me write what I mean" mode. You get a per-device local view of the program and write all communication explicitly. No compiler surprises.
Here is a simple example — slicing and averaging across devices:
```python
import jax
import jax.numpy as jnp
import jax.sharding as shd

mesh = jax.make_mesh((2, 4), ('x', 'y'),
                     axis_types=(shd.AxisType.Explicit, shd.AxisType.Explicit))
jax.set_mesh(mesh)

x = jnp.arange(0, 512, dtype=jnp.int32, out_sharding=jax.P(('x', 'y')))

# This function operates on 1/8th of the array
@jax.shard_map(in_specs=jax.P(('x', 'y')), out_specs=jax.P())
def slice_and_average(x):
    assert x.shape == (512 // 8,)  # Local view!
    return jax.lax.pmean(x[:4], axis_name=('x', 'y'))

out = slice_and_average(x)
assert out.shape == (4,)
```
What is happening? Each device sees only 64 elements (512 / 8). We slice the first 4 elements from each shard, then average them across all 8 devices using pmean (an AllReduce mean). The result is mean(x[:4], x[64:68], x[128:132], ...).
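A single-host NumPy replay of the same computation (a sanity check, not JAX code) confirms the result. Shard k holds the values 64k through 64k + 63, so its first four elements are 64k + j for j = 0..3. Averaging over k = 0..7 gives 64 * 3.5 + j = 224 + j.

```python
import numpy as np

x = np.arange(512, dtype=np.int32)
shards = np.split(x, 8)                 # what each of the 8 devices sees: 64 elements
assert all(s.shape == (64,) for s in shards)
local = [s[:4] for s in shards]         # x[:4] inside shard_map, per device
out = np.mean(np.stack(local), axis=0)  # pmean over ('x', 'y'), i.e. an AllReduce mean
print(out)  # [224. 225. 226. 227.]
```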
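Under plain jax.jit, you would see the global 512-element array instead. Expressing "average the elements at indices [0:4], [64:68], [128:132], ..." that way is awkward, and XLA may not partition it the way you intend. With shard_map, the intent is explicit: each device works on its own chunk.

Reach for shard_map when you need custom collectives, precise control over which chunks live on which device, or guaranteed overlap of communication with computation.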
The key primitives available inside shard_map:
| Primitive | What It Does | Communication |
|---|---|---|
| jax.lax.psum | AllReduce (sum) across named axis | AllReduce |
| jax.lax.pmean | AllReduce (mean) across named axis | AllReduce |
| jax.lax.psum_scatter | Sum across named axis; each device keeps one slice | ReduceScatter |
| jax.lax.all_gather | Gather all shards along axis | AllGather |
| jax.lax.ppermute | Send data between devices as (source, destination) pairs | Point-to-point |
| jax.lax.all_to_all | Move the sharding from one dimension to another | AllToAll |
| jax.lax.axis_index | Get this device's index on axis | None |
| jax.lax.axis_size | Get the number of devices on axis | None |
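The perm argument of ppermute is the easiest part to get backwards. It is a list of (source, destination) pairs of axis indices, and any device that is not a destination receives zeros. The helper below is our own NumPy model of those semantics (not a JAX API). It is applied to the circular shift used in the collective matmul, where device j sends to j - 1 and therefore ends up holding what device j + 1 had.

```python
import numpy as np

def simulate_ppermute(shards, perm):
    """Model of jax.lax.ppermute semantics; shards[i] is device i's local value."""
    out = [np.zeros_like(s) for s in shards]  # devices that receive nothing get zeros
    for src, dst in perm:
        out[dst] = shards[src]
    return out

n = 4
shards = [np.full(2, i) for i in range(n)]   # device i holds [i, i]
perm = [(j, (j - 1) % n) for j in range(n)]  # device j sends to j - 1
shifted = simulate_ppermute(shards, perm)
print([int(s[0]) for s in shifted])  # [1, 2, 3, 0]
```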
Collective matmuls are where shard_map really shines. Consider model parallelism where activations are sharded along the model dimension: A[B_X, D_Y] * W[D, F_Y] → Out[B_X, F_Y].
The naive approach first gathers all of A along Y, then does a local matmul: AllGather_Y(A[B_X, D_Y]) → A[B_X, D], followed by A[B_X, D] * W[D, F_Y] → Out[B_X, F_Y].
The problem: communication and computation happen sequentially. The TPU sits idle during the AllGather.
The collective matmul (Wang et al., 2023) overlaps them. The algorithm:
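1. Start with the local chunk of A and multiply it by the rows of W that match that chunk of D.
2. Add the product into an accumulator.
3. Use ppermute to circularly shift A to get the next chunk, and repeat until all Y chunks have been used (Y matmuls, Y - 1 shifts).

```python
import jax
import jax.numpy as jnp

def collective_matmul_allgather_lhs(lhs, rhs):
    # lhs: local block of A[B_X, D_Y]; rhs: local block of W[D, F_Y] (all of D)
    axis_size = jax.lax.axis_size('Y')
    idx = jax.lax.axis_index('Y')
    chunk_size = lhs.shape[1]

    def f(i, carrys):
        accum, lhs = carrys
        # Matmul with the right chunk of rhs: at step i this device holds
        # the D-chunk that started on device (idx + i) % axis_size
        rhs_chunk = jax.lax.dynamic_slice_in_dim(
            rhs, (idx + i) % axis_size * chunk_size, chunk_size)
        update = lhs @ rhs_chunk
        # Circular shift lhs to get next chunk (device j sends to j - 1)
        lhs = jax.lax.ppermute(
            lhs, axis_name='Y',
            perm=[(j, (j - 1) % axis_size) for j in range(axis_size)])
        return accum + update, lhs

    accum = jnp.zeros((lhs.shape[0], rhs.shape[1]))
    # Mark accum as varying over both mesh axes, like the values added to it
    accum = jax.lax.pvary(accum, ('X', 'Y'))
    accum, lhs = jax.lax.fori_loop(
        0, axis_size - 1, f, (accum, lhs), unroll=True)

    # Last chunk: no shift needed afterwards
    i = axis_size - 1
    rhs_chunk = jax.lax.dynamic_slice_in_dim(
        rhs, (idx + i) % axis_size * chunk_size, chunk_size)
    return accum + lhs @ rhs_chunk

# Run it per device: A[B_X, D_Y] * W[D, F_Y] -> Out[B_X, F_Y]
collective_matmul = jax.shard_map(
    collective_matmul_allgather_lhs,
    in_specs=(jax.P('X', 'Y'), jax.P(None, 'Y')),
    out_specs=jax.P('X', 'Y'))
```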
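Because the index arithmetic (idx + i) % axis_size is easy to get wrong, here is a single-host NumPy replay of the collective_matmul_allgather_lhs loop. It is our illustration, ignoring the X axis and the actual overlap. It only checks the bookkeeping: after Y steps, each device holds its column block of A @ W.

```python
import numpy as np

def simulate_ag_collective_matmul(A, W, ny):
    """Single-host replay of the AG collective matmul along one X row.

    A: [B, D] split along D into ny chunks (A[B, D_Y]);
    W: [D, F] split along F into ny chunks (W[D, F_Y]).
    Returns the ny output shards Out[B, F_Y].
    """
    lhs = np.split(A, ny, axis=1)  # device y starts with D-chunk y
    rhs = np.split(W, ny, axis=1)  # device y holds all of D for its F-chunk
    chunk = lhs[0].shape[1]
    accum = [np.zeros((A.shape[0], r.shape[1])) for r in rhs]
    for i in range(ny):
        for y in range(ny):
            k = (y + i) % ny  # D-chunk that device y holds at step i
            accum[y] += lhs[y] @ rhs[y][k * chunk:(k + 1) * chunk]
        if i < ny - 1:
            # ppermute with perm j -> j - 1: device j receives device j + 1's chunk
            lhs = [lhs[(j + 1) % ny] for j in range(ny)]
    return accum

rng = np.random.default_rng(0)
A = rng.standard_normal((4, 12))
W = rng.standard_normal((12, 8))
shards = simulate_ag_collective_matmul(A, W, ny=4)
print(np.allclose(np.concatenate(shards, axis=1), A @ W))  # True
```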
The profile for the collective matmul shows no AllGather op — it is all useful compute with communication overlapped underneath. FLOPs utilization is much higher.
Before we go further, let us solidify the two core abstractions that all three modes share: meshes and PartitionSpecs.
A mesh is a named, shaped arrangement of your TPU devices. It assigns human-readable names to the physical axes of your hardware:
```python
# A 4x2 mesh: 4 devices along 'X', 2 along 'Y'
mesh = jax.make_mesh(axis_shapes=(4, 2), axis_names=('X', 'Y'))
```
Common mesh configurations and their use:
| Config | Mesh | Use Case |
|---|---|---|
| Pure DP | (8,), ('dp',) | Replicate model, shard data across 8 devices |
| DP + TP | (4, 2), ('dp', 'tp') | 4-way data parallel, 2-way tensor parallel |
| DP + TP + PP | (2, 2, 2), ('dp', 'tp', 'pp') | Add pipeline parallelism as a third axis |
A PartitionSpec (jax.P) maps array dimensions to mesh axes. Each position corresponds to an array dimension:
```python
# For a 2D array [batch, features]:
jax.P('X', 'Y')   # Shard dim 0 along X, dim 1 along Y
jax.P('X', None)  # Shard dim 0 along X, replicate dim 1
jax.P(None, 'Y')  # Replicate dim 0, shard dim 1 along Y
jax.P()           # Fully replicated
```
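The mapping from a PartitionSpec to per-device shapes is simple arithmetic. Each dimension is divided by the product of the mesh-axis sizes named in its entry, and missing trailing entries are treated as None. A tuple entry such as ('X', 'Y') splits one dimension across both axes. The helper below is a hypothetical plain-Python sketch of that arithmetic, not part of JAX. It reproduces the bf16[2, 1024] activation shard from the HLO example and the 64-element shards of the shard_map example.

```python
import math

def local_shape(global_shape, mesh_shape, spec):
    """Per-device shape for a PartitionSpec (hypothetical helper, not a JAX API).

    mesh_shape maps axis name -> size; each spec entry is None, an axis name,
    or a tuple of axis names that jointly split that dimension.
    """
    spec = tuple(spec) + (None,) * (len(global_shape) - len(spec))  # trailing dims replicated
    out = []
    for size, entry in zip(global_shape, spec):
        axes = () if entry is None else (entry,) if isinstance(entry, str) else entry
        n = math.prod(mesh_shape[a] for a in axes)
        assert size % n == 0, f"dim of size {size} not divisible by {n}"
        out.append(size // n)
    return tuple(out)

mesh = {'X': 4, 'Y': 2}
print(local_shape((8, 2048), mesh, ('X', 'Y')))   # (2, 1024)
print(local_shape((8, 2048), mesh, ('X', None)))  # (2, 2048)
print(local_shape((8, 2048), mesh, ()))           # (8, 2048)
print(local_shape((512,), mesh, (('X', 'Y'),)))   # (64,)
```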
None means the dimension is not split, so every device holds all of it. You can also combine axes: jax.P(('X', 'Y')) shards a single dimension across both X and Y.

Shardings are attached to arrays in several ways in the examples above:
- jnp.zeros(..., device=jax.NamedSharding(mesh, jax.P(...))): create an array with a given sharding
- jax.device_put(array, jax.P(...)): move an existing array onto the mesh
- jnp.zeros(..., out_sharding=jax.P(...)): create an array with a given sharding, as in the explicit-mode example
- jax.jit(f, out_shardings=jax.P(...)): constrain the output sharding of a compiled function

Mixture of Experts (MoE) models are one of the most compelling use cases for shard_map. The routing is inherently irregular (each token may go to a different expert), which makes it hard for automatic partitioners to do the right thing.
We have:
- W: float32[E_X, D, F]: E expert weight matrices, sharded across X
- A: float32[S_X, D]: activations, sharded across X
- B: int32[S_X]: routing assignments (which expert each token goes to)

Goal: compute Out[i] = A[i] @ W[B[i]], so that each token is processed by its assigned expert.
```python
import jax
import jax.numpy as jnp

def moe_local(W, A, B):
    S, _ = A.shape
    E, _, F = W.shape

    def expert_forward(carry, e):
        output = carry
        mask = (B == e)[:, None]  # [S, 1]
        expert_result = A @ W[e]  # [S, F]
        output = output + expert_result * mask
        return output, None

    output = jnp.zeros((S, F))
    output, _ = jax.lax.scan(expert_forward, output, jnp.arange(E))
    return output
```
This iterates over experts with masking. Each expert computes its transform of all tokens, then masks out the irrelevant ones. Simple but wasteful — we compute E times more FLOPs than needed.
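The masked loop and the direct definition agree; the only difference is cost. A NumPy sketch (ours, single host, no sharding) makes the E-fold waste visible: the masked version does E full [S, D] x [D, F] matmuls, while the gathered version does one D x F product per token. The gather does materialize a per-token copy of the weights (W[B] has shape [S, D, F]), which is one reason practical implementations group tokens by expert instead.

```python
import numpy as np

def moe_masked(W, A, B):
    """NumPy mirror of moe_local: run every expert on every token, then mask."""
    E, _, F = W.shape
    out = np.zeros((A.shape[0], F))
    for e in range(E):
        out += (A @ W[e]) * (B == e)[:, None]  # E full matmuls, most rows discarded
    return out

def moe_gather(W, A, B):
    """Direct definition Out[i] = A[i] @ W[B[i]]: one D x F product per token."""
    return np.einsum('sd,sdf->sf', A, W[B])

rng = np.random.default_rng(0)
E, S, D, F = 4, 16, 8, 6
W = rng.standard_normal((E, D, F))
A = rng.standard_normal((S, D))
B = rng.integers(0, E, size=S)
print(np.allclose(moe_masked(W, A, B), moe_gather(W, A, B)))  # True
```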
If you simply jax.jit the local implementation with sharded inputs, XLA will likely AllGather the full activations A onto every device. This is expensive in both communication and memory.
The better approach uses shard_map with an AllToAll to send each token to the device that holds its expert.
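A single-host NumPy simulation of this dispatch/combine pattern is sketched below. It is our illustration under stated assumptions, not a production kernel. Experts are split evenly and contiguously across devices (matching W[E_X, D, F]). Each device packs its tokens into one fixed-capacity buffer per destination, and one AllToAll delivers the buffers. Each device then runs only its local experts, and a second AllToAll returns the results. The capacity is chosen large enough that no token is dropped, whereas real systems usually pick a smaller capacity to bound buffer size.

```python
import numpy as np

def all_to_all(send):
    """Simulated AllToAll: send[src][dst] arrives as recv[dst][src]."""
    n = len(send)
    return [[send[src][dst] for src in range(n)] for dst in range(n)]

def moe_alltoall_sim(W, A, B, n_dev):
    E, D, F = W.shape
    e_per, s_per = E // n_dev, A.shape[0] // n_dev
    cap = s_per  # per (source, destination) capacity; large enough that nothing is dropped
    W_loc = np.split(W, n_dev)  # W[E_X, D, F]: device t owns experts t*e_per .. (t+1)*e_per - 1
    A_loc = np.split(A, n_dev)  # A[S_X, D]
    B_loc = np.split(B, n_dev)  # B[S_X]

    # 1. Dispatch: pack local tokens into one fixed-size buffer per destination device.
    send, positions = [], []
    for d in range(n_dev):
        row, pos = [], []
        for t in range(n_dev):
            idx = np.nonzero(B_loc[d] // e_per == t)[0]  # tokens whose expert lives on t
            buf = np.zeros((cap, D))
            eid = np.full(cap, -1)                       # -1 marks an empty slot
            buf[:len(idx)] = A_loc[d][idx]
            eid[:len(idx)] = B_loc[d][idx] - t * e_per   # expert id local to device t
            row.append((buf, eid))
            pos.append(idx)                              # kept locally to undo the packing
        send.append(row)
        positions.append(pos)
    recv = all_to_all(send)

    # 2. Each device applies only its own experts to the tokens it received.
    results = []
    for t in range(n_dev):
        row = []
        for buf, eid in recv[t]:
            out = np.zeros((cap, F))
            ok = eid >= 0
            out[ok] = np.einsum('cd,cdf->cf', buf[ok], W_loc[t][eid[ok]])
            row.append(out)
        results.append(row)

    # 3. A second AllToAll returns results to the devices that own the tokens.
    back = all_to_all(results)  # back[d][t] = results[t][d]
    out_loc = []
    for d in range(n_dev):
        out = np.zeros((s_per, F))
        for t in range(n_dev):
            idx = positions[d][t]
            out[idx] = back[d][t][:len(idx)]
        out_loc.append(out)
    return np.concatenate(out_loc)  # Out[S_X, F]

rng = np.random.default_rng(0)
E, S, D, F = 4, 8, 3, 5
W = rng.standard_normal((E, D, F))
A = rng.standard_normal((S, D))
B = rng.integers(0, E, size=S)
print(np.allclose(moe_alltoall_sim(W, A, B, n_dev=2), np.einsum('sd,sdf->sf', A, W[B])))  # True
```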
You can also use jax.lax.ragged_dot for a more efficient implementation that avoids the masking overhead entirely.
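As we understand its semantics, jax.lax.ragged_dot takes an lhs of shape [M, K] whose rows are grouped contiguously, an rhs of shape [G, K, N], and group_sizes of shape [G], and multiplies each run of rows by its group's matrix. The NumPy reference below is our sketch of that behavior (not JAX code), used for the MoE layer: sort tokens by expert, run one matmul per expert on exactly its tokens, and undo the sort. No FLOPs are spent on masked-out rows.

```python
import numpy as np

def ragged_dot_reference(lhs, rhs, group_sizes):
    """NumPy sketch of ragged_dot semantics: rows grouped contiguously by group."""
    out = np.zeros((lhs.shape[0], rhs.shape[2]))
    start = 0
    for g, size in enumerate(group_sizes):
        out[start:start + size] = lhs[start:start + size] @ rhs[g]
        start += size
    return out

def moe_sorted(W, A, B):
    order = np.argsort(B, kind='stable')  # group tokens by expert id
    group_sizes = np.bincount(B, minlength=W.shape[0])
    out_sorted = ragged_dot_reference(A[order], W, group_sizes)
    out = np.empty_like(out_sorted)
    out[order] = out_sorted  # undo the sort
    return out

rng = np.random.default_rng(0)
E, S, D, F = 3, 10, 4, 2
W = rng.standard_normal((E, D, F))
A = rng.standard_normal((S, D))
B = rng.integers(0, E, size=S)
print(np.allclose(moe_sorted(W, A, B), np.einsum('sd,sdf->sf', A, W[B])))  # True
```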
Let us put everything together and build a complete Transformer MLP block with overlapped communication — the kind of code that runs in production systems.
A Transformer FFN with Megatron-style model parallelism has two matmuls with different communication patterns:
| Operation | Notation | Naive Comm | Collective Comm |
|---|---|---|---|
| Up-projection | In[B_X, D_Y] * W_in[D, F_Y] → H[B_X, F_Y] | AllGather In, then matmul | AG-matmul (overlap) |
| Down-projection | H[B_X, F_Y] * W_out[F_Y, D] → Out[B_X, D_Y] | Matmul, then ReduceScatter | RS-matmul (overlap) |
The AG-matmul is what we implemented in Chapter 4: circular shift + chunk matmul to overlap the AllGather with computation.
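The RS-matmul is the complement: instead of gathering inputs, we scatter results. Each device computes its partial product for one output chunk at a time, adds it to an accumulator, and passes the accumulator to its neighbor with ppermute. After Y steps, each device holds the fully summed chunk it owns, so no separate ReduceScatter is needed.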
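A single-host NumPy replay of the RS-matmul bookkeeping is sketched below. It is our illustration, ignoring the X axis and the overlap itself, and its chunk schedule is one valid choice, not necessarily the one production code uses. At step i, device y adds its contribution to output chunk (y + i + 1) mod Y and then passes the accumulator to device y - 1, so on the final step each device completes its own chunk.

```python
import numpy as np

def simulate_rs_collective_matmul(H, W_out, ny):
    """H: [B, F] split along F (H[B, F_Y]); W_out: [F, D] split along F (W_out[F_Y, D]).
    Returns the ny output shards Out[B, D_Y]."""
    h = np.split(H, ny, axis=1)
    w = np.split(W_out, ny, axis=0)
    d_chunk = W_out.shape[1] // ny
    accum = [np.zeros((H.shape[0], d_chunk)) for _ in range(ny)]
    for i in range(ny):
        for y in range(ny):
            c = (y + i + 1) % ny  # output chunk this accumulator is destined for
            accum[y] = accum[y] + h[y] @ w[y][:, c * d_chunk:(c + 1) * d_chunk]
        if i < ny - 1:
            # ppermute with perm j -> j - 1: pass the partial sum to the left neighbor
            accum = [accum[(j + 1) % ny] for j in range(ny)]
    return accum

rng = np.random.default_rng(0)
H = rng.standard_normal((4, 12))
W_out = rng.standard_normal((12, 8))
shards = simulate_rs_collective_matmul(H, W_out, ny=4)
print(np.allclose(np.concatenate(shards, axis=1), H @ W_out))  # True
```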
Putting both together for a full FFN block:
```python
# Full overlapped FFN: In[B_X, D_Y] -> Out[B_X, D_Y]
# collective_matmul_allgather_lhs is the AG-matmul defined earlier;
# collective_matmul_reducescatter is an RS-matmul built the same way (not shown).
@jax.shard_map(
    in_specs=(jax.P('X', 'Y'), jax.P(None, 'Y'), jax.P('Y', None)),
    out_specs=jax.P('X', 'Y'))
def ffn_collective(x, w_up, w_down):
    # Up-projection: AG-matmul overlaps AllGather with compute
    h = collective_matmul_allgather_lhs(x, w_up)
    h = jax.nn.gelu(h)
    # Down-projection: RS-matmul overlaps compute with ReduceScatter
    out = collective_matmul_reducescatter(h, w_down)
    return out
```
The collective matmuls above use unidirectional circular shifts. An extension uses bidirectional communication — sending data both left and right simultaneously. This can further reduce latency by overlapping both directions of the ring, though the implementation is more complex.
| Scenario | Recommended Approach | Why |
|---|---|---|
| Prototyping a new model | jax.jit with auto sharding | Fastest to write, good enough for initial experiments |
| Production training loop | Explicit sharding + shard_map for hot path | Catches mistakes at trace time; overlapped comms for performance |
| MoE routing | shard_map with explicit AllToAll | Irregular routing patterns confuse auto partitioners |
| Custom attention (e.g., Ring Attention) | shard_map | Need precise control of which chunks go where |
| Debugging why auto mode is slow | Rewrite in shard_map, compare profiles | Eliminates compiler as a variable |
Here are several problems to test your understanding. These are meant to be implemented on a real TPU (free TPU v2-8 available on Google Colab).
Let A: float32[S_X, D_Y] with X * Y = N.
Part a: Write a function that computes the average within each (X, Y) shard, returning an array of shape [X, Y]. Implement in both jax.jit and shard_map. Profile each. Was communication added?
Hint: with shard_map, this is trivial: just x.mean(keepdims=True) on each device. With jax.jit, you need careful reshaping: reshape to [X, S // X, Y, D // Y] and take the mean over axes (1, 3).

Part b: Write a function that returns roll(x, shift, axis=0) - x for some shift, applied within each shard along X. Use shard_map.
Build a complete MoE forward pass:
- Write a jax.jit version with sharded inputs. What communication does it add?
- Write a version using shard_map and an explicit AllToAll that avoids materializing a full [E, S, D] buffer.

Implement a collective AllReduce matmul, A[B_X, D_Y] * W[D_Y, F] → Out[B_X, F], with overlapped communication. Tile over the output dimension and use jax.lax.psum.

Implement a collective ReduceScatter matmul, Tmp[B_X, F_Y] * W2[F_Y, D] → Out[B_X, D_Y]. Be careful about passing only the minimal data needed. Hint: permute the result as you accumulate.

Implement the full FFN block: In[B_X, D_Y] * W_in * W_out → Out[B_X, D_Y].

Rewrite the collective AllReduce and ReduceScatter matmuls to use bidirectional communication (sending data both left and right simultaneously). How much faster are these?