Austin et al., Chapter 10

Programming TPUs in JAX

Three modes of multi-device programming: auto sharding, explicit sharding, and shard_map — from "compiler, take the wheel" to total manual control.

Prerequisites: JAX basics (jax.jit, jnp, einsum), sharding concepts (data/model parallelism from Chapters 3-5), and collectives (AllGather, AllReduce, ReduceScatter).

Chapter 0: Three Modes of Parallelism

JAX supports three fundamentally different philosophies for multi-device programming. Understanding when to use each one is the key to productive TPU development.

Mode | View | Explicit Sharding? | Explicit Collectives? | Best For
Auto | Global | No | No | Prototyping, simple models
Explicit | Global | Yes | No | Production models, catching mistakes
Manual (shard_map) | Per-device | Yes | Yes | Custom collectives, max performance

Think of it as a spectrum of control:

Auto: "Compiler, take the wheel!"
Write single-device code. XLA (Shardy) automatically partitions arrays, decides communication. Magical when it works.
↓ more control
Explicit: "JAX, take the wheel!"
Write single-device code but sharding is part of the type system. JAX propagates shardings and errors when ambiguous.
↓ maximum control
Manual: "Let me write what I mean!"
Per-device view. You write all communication explicitly — AllGather, psum, ppermute. No compiler surprises.
In practice: Most JAX parallel programming (about 60% of the work) is done in Auto mode with strategic jax.lax.with_sharding_constraint calls to guide the compiler. When you need guaranteed performance, you drop into shard_map.

This chapter covers all three modes with real code examples you can run on free TPU v2-8 hardware in Google Colab. By the end, you will know when to use each mode and how to implement the collective matmul pattern — the single most important optimization for overlapping communication with computation in production Transformer training.

The chapter is structured as a progression: we start with the easiest approach (auto), build understanding of explicit sharding, and then implement increasingly sophisticated patterns in shard_map.

Check: Which JAX parallelism mode gives you a per-device local view and requires explicit communication calls?

Chapter 1: Auto Sharding

jax.jit plays two roles: it compiles Python to fast XLA code, and — when inputs are sharded — it distributes computation across devices. The XLA compiler's Shardy partitioner automatically decides how to split work and where to insert communication.

Here is a complete sharded matmul in auto mode:

import jax
import jax.numpy as jnp

# Create a 4x2 mesh (e.g., TPU v5e 4x2)
mesh = jax.make_mesh(axis_shapes=(4, 2), axis_names=('X', 'Y'))
jax.set_mesh(mesh)

# Create sharded arrays
In = jnp.zeros((8, 2048), dtype=jnp.bfloat16,
               device=jax.NamedSharding(mesh, jax.P('X', 'Y')))
W = jnp.zeros((2048, 8192), dtype=jnp.bfloat16,
              device=jax.NamedSharding(mesh, jax.P('Y', None)))

def matmul_square(In, W):
  return jnp.einsum('bd,df->bf', jnp.square(In), W)

# Compile with output sharding constraint
jit_matmul = jax.jit(matmul_square,
  out_shardings=jax.P('X', None)).lower(In, W).compile()

out = jit_matmul(In, W)

What is actually happening at the hardware level?

Using our sharding notation, XLA will emit:

Out[B_X, F] {U_Y} = In[B_X, D_Y] *_D W[D_Y, F]
Local matmul on each device, producing a partial sum that is still unreduced over Y
Out[B_X, F] = AllReduce_Y(Out[B_X, F] {U_Y})
Sum the partial products across the Y axis

We can verify this by examining the HLO with jit_matmul.as_text():

# The matmul (local on each device)
%fusion = bf16[2,8192]{1,0} fusion(
  bf16[2,1024] %param,
  bf16[8192,1024] %copy-done)

# The AllReduce across devices
ROOT %AllReduce = bf16[2,8192]{1,0} AllReduce(
  bf16[2,8192] %fusion)

Note the per-device shapes: bf16[2, 1024] for activations (batch=8 split 4 ways, d_model=2048 split 2 ways).
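The local-matmul-then-AllReduce pattern can be checked on a single host. The following is a minimal NumPy sketch, not real multi-device code: it loops over a simulated 4x2 mesh, slices the operands the way the shardings above would, and sums the partial products across Y. The hidden size F is scaled down to 128 (an assumption to keep the example light); B and D match the text so the per-device activation shape comes out as (2, 1024).

```python
import numpy as np

# Scaled-down simulation: same B and D as the text, smaller F.
B, D, F = 8, 2048, 128
X, Y = 4, 2  # mesh axis sizes

rng = np.random.default_rng(0)
In = rng.integers(-3, 4, size=(B, D)).astype(np.float64)
W = rng.integers(-3, 4, size=(D, F)).astype(np.float64)

def matmul_square(a, w):
    return np.square(a) @ w

out = np.zeros((B, F))
for x in range(X):
    rows = slice(x * B // X, (x + 1) * B // X)
    partials = []
    for y in range(Y):
        cols = slice(y * D // Y, (y + 1) * D // Y)
        in_local = In[rows, cols]   # activations: rows split over X, columns over Y
        w_local = W[cols, :]        # weights: rows split over Y
        assert in_local.shape == (2, 1024)
        # Each simulated device holds a partial sum over its slice of D.
        partials.append(matmul_square(in_local, w_local))
    # Summing over y plays the role of the AllReduce across the Y axis.
    out[rows] = sum(partials)

reference = matmul_square(In, W)
assert np.array_equal(out, reference)
```

Squaring is elementwise, so it commutes with the slicing; only the contraction over D needs the cross-device sum.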

"Compiler tickling": When Shardy makes a bad choice, you fix it with jax.lax.with_sharding_constraint. But this is famously frustrating — you can annotate every intermediate variable and still not know if you will get the right outcome. This motivates the next mode.
Check: In auto sharding mode, who decides what communication (AllReduce, AllGather, etc.) to insert?

Chapter 2: Explicit Sharding

Explicit sharding (or "sharding in types") looks a lot like auto sharding, but with one crucial difference: sharding propagation happens at the JAX level, not the XLA level. This means JAX can catch ambiguous decisions and ask you for clarification instead of silently guessing.

The key concept: each JAX operation has a sharding rule that takes input shardings and produces an output sharding. For most ops, there is only one reasonable choice. For some (like einsum), it is ambiguous.

import jax
import jax.sharding as shd
import numpy as np

# Create mesh with Explicit axis types
mesh = jax.make_mesh(
  axis_shapes=(2, 2), axis_names=('X', 'Y'),
  axis_types=(shd.AxisType.Explicit, shd.AxisType.Explicit))
jax.set_mesh(mesh)

# Use float32 so the dtype matches the types printed below
x = jax.device_put(
  np.arange(16, dtype=np.float32).reshape(8, 2), jax.P('X', 'Y'))

@jax.jit
def f(x):
  print(jax.typeof(x))   # float32[8@X, 2@Y]
  out = x * 2
  print(jax.typeof(out))  # float32[8@X, 2@Y]
  return out

f(x)  # the prints run once, while jit traces the function

Notice the @X and @Y annotations in the type — the sharding is now part of the type system. For elementwise ops like x * 2, the output sharding is obvious: same as the input.

But for a matmul, things get ambiguous:

In = jnp.zeros((8, 2048), out_sharding=jax.P('X', 'Y'))
W = jnp.zeros((2048, 8192), out_sharding=jax.P('Y', None))

@jax.jit
def matmul_square(In, W):
  return jnp.einsum('bd,df->bf', jnp.square(In), W)

matmul_square(In, W)  # ERROR!

This produces a clear error:

Contracting dimensions are sharded and it is ambiguous
how the output should be sharded.
Please specify the output sharding via the `out_sharding` parameter.
Got lhs_contracting_spec=('Y',) and rhs_contracting_spec=('Y',)

This is the key advantage of explicit mode. In auto mode, the compiler would silently choose one option. In explicit mode, JAX tells you: "I don't know if you want an AllReduce (output replicated) or a ReduceScatter (output sharded). You decide."

The fix is to specify the output sharding:

@jax.jit
def matmul_square(In, W):
  return jnp.einsum('bd,df->bf', jnp.square(In), W,
                     out_sharding=jax.P('X', 'Y'))
# This induces a ReduceScatter (output sharded along both axes)
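To make the two choices concrete, here is a small NumPy sketch of what each output sharding asks for, under the assumption that every device along Y already holds a partial sum of the same shape. The list-of-arrays representation and the tiny shapes are illustrative only.

```python
import numpy as np

Y = 2
rng = np.random.default_rng(0)
# Partial sums held by the Y devices in one row of the mesh, each [B/X, F].
partials = [rng.integers(0, 10, size=(2, 8)) for _ in range(Y)]
total = sum(partials)

# Output replicated over Y: an AllReduce leaves the full block on every device.
all_reduce = [total.copy() for _ in range(Y)]

# Output sharded over Y: a ReduceScatter leaves device y with only its column slice.
f_shard = total.shape[1] // Y
reduce_scatter = [total[:, y * f_shard:(y + 1) * f_shard] for y in range(Y)]

assert all_reduce[0].shape == (2, 8)
assert reduce_scatter[0].shape == (2, 4)
assert np.array_equal(np.concatenate(reduce_scatter, axis=1), total)
```

Both options compute the same sum; they differ in how much of it each device keeps, which is exactly the decision explicit mode asks you to make.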

Auto and Explicit modes can be composed using jax.sharding.auto_axes and jax.sharding.explicit_axes APIs, so you can use explicit control on some mesh axes and let the compiler handle others.

Check: Why does jnp.einsum('bd,df->bf', In, W) error in explicit mode when the contracting dimension is sharded?

Chapter 3: Manual Mode via shard_map

jax.shard_map is the "just let me write what I mean" mode. You get a per-device local view of the program and write all communication explicitly. No compiler surprises.

Here is a simple example — slicing and averaging across devices:

import jax
import jax.numpy as jnp
import jax.sharding as shd

mesh = jax.make_mesh((2, 4), ('x', 'y'),
  (shd.AxisType.Explicit, shd.AxisType.Explicit))
jax.set_mesh(mesh)

x = jnp.arange(0, 512, dtype=jnp.int32,
               out_sharding=jax.P(('x', 'y')))

# This function operates on 1/8th of the array
@jax.shard_map(in_specs=jax.P(('x', 'y')),
               out_specs=jax.P())
def slice_and_average(x):
  assert x.shape == (512 // 8,)  # Local view!
  return jax.lax.pmean(x[:4], axis_name=('x', 'y'))

out = slice_and_average(x)
assert out.shape == (4,)

What is happening? Each device sees only 64 elements (512 / 8). We slice the first 4 elements from each shard, then average them across all 8 devices using pmean (an AllReduce mean). The result is mean(x[:4], x[64:68], x[128:132], ...).
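The same arithmetic can be reproduced on one host as a sanity check. This NumPy sketch treats each row of a reshaped array as one device's local view; it is a reference for the expected values, not a replacement for the shard_map version (note that NumPy returns floats here).

```python
import numpy as np

x = np.arange(512, dtype=np.int32)
shards = x.reshape(8, 64)      # one row per device: the local view
local = shards[:, :4]          # x[:4] as seen on each device
out = local.mean(axis=0)       # stands in for pmean across the 8 devices
print(out)                     # [224. 225. 226. 227.]
```

Device i contributes the values 64*i + j for j in 0..3, and the mean of 64*i over i = 0..7 is 224, which gives 224 + j.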

Why not jax.jit? With jax.jit, you would see the global 512-element array. Expressing "take elements at indices [0:4, 64:68, 128:132, ...]" on that global view is awkward, and nothing guarantees that XLA will partition it so that each device touches only its own chunk. With shard_map, the intent is explicit: each device works on its own chunk.

When to Use shard_map

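Use shard_map when you need the per-device view and explicit control over communication, for example to write custom collectives or to guarantee that communication overlaps with compute.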

The key primitives available inside shard_map:

Primitive | What It Does | Communication
jax.lax.psum | AllReduce (sum) across named axis | AllReduce
jax.lax.pmean | AllReduce (mean) across named axis | AllReduce
jax.lax.all_gather | Gather all shards along axis | AllGather
jax.lax.ppermute | Permute data between devices | Point-to-point
jax.lax.all_to_all | Transpose a sharded dimension | AllToAll
jax.lax.axis_index | Get this device's index on axis | None
jax.lax.axis_size | Get the number of devices on axis | None
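For intuition about what three of these primitives move, here is a host-side sketch in which each "device" is one entry of a Python list. The helper functions are hypothetical stand-ins written for this illustration; they only mimic the data movement and are not the jax.lax implementations.

```python
import numpy as np

# Hypothetical helpers: shards[i] is the array held by device i.
def psum(shards):
    total = sum(shards)
    return [total.copy() for _ in shards]

def all_gather(shards, axis=0):
    full = np.concatenate(shards, axis=axis)
    return [full.copy() for _ in shards]

def ppermute(shards, perm):
    # perm is a list of (source, destination) pairs, as in jax.lax.ppermute.
    # Devices that receive nothing end up with zeros.
    out = [np.zeros_like(s) for s in shards]
    for src, dst in perm:
        out[dst] = shards[src]
    return out

n = 4
shards = [np.full((2,), i) for i in range(n)]   # device i holds [i, i]

summed = psum(shards)                           # every device: [6, 6]
gathered = all_gather(shards)                   # every device: all 8 values
shifted = ppermute(shards, [(j, (j - 1) % n) for j in range(n)])
print(shifted[0], shifted[3])                   # [1 1] [0 0]
```

The circular shift in the last line is the same permutation the collective matmul uses to pass chunks around the ring.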
Check: Inside a shard_map function operating on a [512,] array sharded across 8 devices, what shape does each device see?

Chapter 4: The Collective Matmul

This is where shard_map really shines. Consider model parallelism where activations are sharded along the model dimension: A[B_X, D_Y] * W[D, F_Y] → Out[B_X, F_Y].

The naive approach first gathers all of A, then does a local matmul:

A[B_X, D] = AllGather_Y(A[B_X, D_Y])
Gather full activations (blocking communication)
Out[B_X, F_Y] = A[B_X, D] * W[D, F_Y]
Local matmul after communication completes

The problem: communication and computation happen sequentially. The TPU sits idle during the AllGather.

The collective matmul (Wang et al., 2023) overlaps them. The algorithm:

  1. For each Y shard: perform a matmul of the local chunk of A with the corresponding chunk of W, producing a partial result
  2. Simultaneously, use ppermute to circularly shift A to get the next chunk
  3. Accumulate partial results

def collective_matmul_allgather_lhs(lhs, rhs):
  # lhs: this device's chunk of A (circulates around the Y ring)
  # rhs: this device's shard of W (stays in place)
  axis_size = jax.lax.axis_size('Y')
  idx = jax.lax.axis_index('Y')
  chunk_size = lhs.shape[1]
  assert rhs.shape[0] % chunk_size == 0

  def f(i, carrys):
    accum, lhs = carrys
    # After i shifts we hold the chunk that started on device idx + i,
    # so multiply it by the matching rows of rhs
    rhs_chunk = jax.lax.dynamic_slice_in_dim(
      rhs, (idx + i) % axis_size * chunk_size, chunk_size)
    update = lhs @ rhs_chunk
    # Circular shift: device j sends its chunk to device j - 1
    lhs = jax.lax.ppermute(lhs, axis_name='Y',
      perm=[(j, (j - 1) % axis_size)
            for j in range(axis_size)])
    return accum + update, lhs

  accum = jnp.zeros((lhs.shape[0], rhs.shape[1]))
  # Mark the accumulator as varying across devices to match the loop carry
  accum = jax.lax.pvary(accum, ('X', 'Y'))
  accum, lhs = jax.lax.fori_loop(
    0, axis_size - 1, f, (accum, lhs), unroll=True)

  # Last chunk: no further shift is needed after this matmul
  i = axis_size - 1
  rhs_chunk = jax.lax.dynamic_slice_in_dim(
    rhs, (idx + i) % axis_size * chunk_size, chunk_size)
  return accum + lhs @ rhs_chunk

Performance win: Benchmarking on a TPU v5e 4x2, the naive approach takes 311 μs (with a blocking AllGather visible in the profile). The collective matmul takes 244 μs, a 22% improvement. The unsharded matmul takes 224 μs, so the collective matmul adds only 9% overhead despite distributing across devices.

The profile for the collective matmul shows no AllGather op — it is all useful compute with communication overlapped underneath. FLOPs utilization is much higher.
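The index arithmetic in the ring schedule is easy to get wrong, so it helps to check it on one host. The sketch below simulates the Y axis with Python lists and small integer matrices (sizes chosen arbitrarily for illustration) and confirms that the shifted, chunked matmuls add up to the plain product. It says nothing about performance; it only verifies the schedule.

```python
import numpy as np

n = 4                      # size of mesh axis Y
B, D, F = 2, 8, 12         # local batch, model dim, hidden dim
chunk = D // n

rng = np.random.default_rng(0)
A = rng.integers(-3, 4, size=(B, D))
W = rng.integers(-3, 4, size=(D, F))

# Device y starts with column chunk y of A and column shard y of W.
lhs = [A[:, y * chunk:(y + 1) * chunk] for y in range(n)]
rhs = [W[:, y * (F // n):(y + 1) * (F // n)] for y in range(n)]
accum = [np.zeros((B, F // n), dtype=A.dtype) for _ in range(n)]

for i in range(n):
    for idx in range(n):
        # After i shifts, device idx holds the chunk that started on (idx + i) % n.
        start = ((idx + i) % n) * chunk
        accum[idx] = accum[idx] + lhs[idx] @ rhs[idx][start:start + chunk]
    if i < n - 1:
        # Circular shift: device j sends its chunk to device j - 1,
        # so device j now holds what device j + 1 held.
        lhs = [lhs[(j + 1) % n] for j in range(n)]

out = np.concatenate(accum, axis=1)
assert np.array_equal(out, A @ W)
```

As in the shard_map version, the shift happens n - 1 times: the last chunk is multiplied without a further permute.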

Check: What is the key advantage of a collective matmul over a naive AllGather-then-matmul approach?

Chapter 5: Meshes and PartitionSpecs

Before we go further, let us solidify the two core abstractions that all three modes share: meshes and PartitionSpecs.

Meshes

A mesh is a named, shaped arrangement of your TPU devices. It assigns human-readable names to the physical axes of your hardware:

# A 4x2 mesh: 4 devices along 'X', 2 along 'Y'
mesh = jax.make_mesh(
  axis_shapes=(4, 2),
  axis_names=('X', 'Y')
)

Common mesh configurations and their use:

Config | Mesh | Use Case
Pure DP | (8,), ('dp',) | Replicate model, shard data across 8 devices
DP + TP | (4, 2), ('dp', 'tp') | 4-way data parallel, 2-way tensor parallel
DP + TP + PP | (2, 2, 2), ('dp', 'tp', 'pp') | Add pipeline parallelism as a third axis

PartitionSpecs

A PartitionSpec (jax.P) maps array dimensions to mesh axes. Each position corresponds to an array dimension:

# For a 2D array [batch, features]:
jax.P('X', 'Y')     # Shard dim 0 along X, dim 1 along Y
jax.P('X', None)   # Shard dim 0 along X, replicate dim 1
jax.P(None, 'Y')   # Replicate dim 0, shard dim 1 along Y
jax.P()            # Fully replicated

Key mental model: A PartitionSpec answers the question: "Which mesh axis is this array dimension split across?" None means the dimension is not split. The array is replicated across any mesh axis that the spec does not mention. You can also combine axes: P(('X', 'Y')) shards a single dimension across both X and Y.
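One way to internalize this is to compute per-device shapes by hand. The helper below is a hypothetical pure-Python sketch (shard_shape is not a JAX function) that takes a global shape, a dict of mesh axis sizes, and a PartitionSpec-like tuple, and assumes every dimension divides evenly.

```python
def shard_shape(global_shape, mesh_shape, spec):
    """Per-device shape of an array under a PartitionSpec-like tuple.

    mesh_shape maps axis name to size, e.g. {'X': 4, 'Y': 2}.
    spec entries are None, an axis name, or a tuple of axis names.
    """
    # Trailing dimensions the spec does not mention are unsharded.
    spec = tuple(spec) + (None,) * (len(global_shape) - len(spec))
    local = []
    for dim, entry in zip(global_shape, spec):
        if entry is None:
            axes = ()
        elif isinstance(entry, str):
            axes = (entry,)
        else:
            axes = tuple(entry)
        ways = 1
        for axis in axes:
            ways *= mesh_shape[axis]
        assert dim % ways == 0, "this sketch assumes even division"
        local.append(dim // ways)
    return tuple(local)

mesh_shape = {'X': 4, 'Y': 2}
print(shard_shape((8, 2048), mesh_shape, ('X', 'Y')))      # (2, 1024)
print(shard_shape((2048, 8192), mesh_shape, ('Y', None)))  # (1024, 8192)
print(shard_shape((512,), mesh_shape, (('X', 'Y'),)))      # (64,)
print(shard_shape((8, 2048), mesh_shape, ()))              # (8, 2048)
```

The first two lines reproduce the per-device shapes from the auto-sharding example, and the third reproduces the 64-element local view from the shard_map example (there on a 2x4 mesh, which gives the same 8-way split).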

Ways to Create Sharded Arrays

  1. At creation time: jnp.zeros(..., device=jax.NamedSharding(mesh, P(...)))
  2. Via device_put: jax.device_put(array, P(...))
  3. Via out_sharding: jnp.zeros(..., out_sharding=P(...))
  4. Via jit output: jax.jit(f, out_shardings=P(...))
Check: What does jax.P('X', None) mean for a 2D array on a mesh with axes X and Y?

Chapter 6: MoE with shard_map

Mixture of Experts (MoE) models are one of the most compelling use cases for shard_map. The routing is inherently irregular — each token goes to a different expert — which makes it hard for automatic partitioners to do the right thing.

The Problem Setup

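We have:

• W: [E, D, F], the weight matrices of E experts
• A: [S, D], the activations of S tokens
• B: [S], the index of the expert assigned to each token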

Goal: compute Out[i] = A[i] @ W[B[i]] — each token is processed by its assigned expert.

Local Implementation (No Sharding)

def moe_local(W, A, B):
  S, _ = A.shape
  E, _, F = W.shape

  def expert_forward(carry, e):
    output = carry
    mask = (B == e)[:, None]     # [S, 1]
    expert_result = A @ W[e]    # [S, F]
    output = output + expert_result * mask
    return output, None

  output = jnp.zeros((S, F))
  output, _ = jax.lax.scan(expert_forward, output, jnp.arange(E))
  return output

This iterates over experts with masking. Each expert computes its transform of all tokens, then masks out the irrelevant ones. Simple but wasteful: every expert processes every token, so we spend E times as many FLOPs as needed.

The Sharded Challenge

If you just jax.jit the local implementation with sharded inputs, XLA will likely AllGather the full activations A onto every device. This is expensive in both communication and memory.

The better approach uses shard_map with an AllToAll to route tokens to the correct devices:

  1. Sort tokens by expert assignment: group tokens headed for the same device together
  2. AllToAll to route tokens: send each group to the device hosting that expert
  3. Local matmul with local expert: each device multiplies incoming tokens by its expert weights
  4. AllToAll to return results: send results back to the originating devices

Ragged AllToAll: In practice, not all tokens go to every expert. If only k of N experts are needed for a given device's tokens, the AllToAll moves roughly k/N of the data a dense exchange would. This is called a sparse or ragged AllToAll, and it is a major advantage for top-k routing (e.g., DeepSeek-V3 routes each token to 8 of its 256 routed experts).

You can also use jax.lax.ragged_dot for a more efficient implementation that avoids the masking overhead entirely.
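The four routing steps can be traced on a single host. The sketch below is a NumPy simulation under simplifying assumptions that are not from the chapter: one expert per device, top-1 routing, and Python lists standing in for the two AllToAll exchanges. It checks the result against the definition Out[i] = A[i] @ W[B[i]].

```python
import numpy as np

n = 4                       # devices; assume one expert per device (E = n)
S, D, F = 16, 8, 6
rng = np.random.default_rng(0)
A = rng.integers(-3, 4, size=(S, D))       # tokens
B = rng.integers(0, n, size=S)             # expert assignment per token
W = rng.integers(-3, 4, size=(n, D, F))    # expert weights
tok_per_dev = S // n

# Steps 1-2: each source device sorts its local tokens by expert and
# "sends" each one to the device hosting that expert (first AllToAll).
inbox = [[] for _ in range(n)]             # inbox[e]: (origin index, token) pairs
for src in range(n):
    lo = src * tok_per_dev
    local_B = B[lo:lo + tok_per_dev]
    order = np.argsort(local_B, kind="stable")
    for pos in order:
        inbox[local_B[pos]].append((lo + pos, A[lo + pos]))

# Step 3: one local matmul per expert over only the tokens it received.
# Step 4: results return to the positions the tokens came from (second AllToAll).
out = np.zeros((S, F), dtype=A.dtype)
for e in range(n):
    if not inbox[e]:
        continue
    origin = np.array([i for i, _ in inbox[e]])
    tokens = np.stack([t for _, t in inbox[e]])
    out[origin] = tokens @ W[e]

reference = np.stack([A[i] @ W[B[i]] for i in range(S)])
rows_processed = sum(len(box) for box in inbox)
assert np.array_equal(out, reference)
assert rows_processed == S                 # the masked version processes n * S rows
```

Each token is multiplied by exactly one expert matrix, which is where the saving over the scan-and-mask version comes from.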

Check: Why is an AllToAll better than an AllGather for MoE routing?

Chapter 7: End-to-End Transformer

Let us put everything together and build a complete Transformer MLP block with overlapped communication — the kind of code that runs in production systems.

The Two Matmuls

A Transformer FFN with Megatron-style model parallelism has two matmuls with different communication patterns:

Operation | Notation | Naive Comm | Collective Comm
Up-projection | In[B_X, D_Y] * W_in[D, F_Y] → H[B_X, F_Y] | AllGather In, then matmul | AG-matmul (overlap)
Down-projection | H[B_X, F_Y] * W_out[F_Y, D] → Out[B_X, D_Y] | Matmul, then ReduceScatter | RS-matmul (overlap)

AllGather Collective Matmul (Up-projection)

This is what we implemented in Chapter 4: circular shift + chunk matmul to overlap the AllGather with computation.

ReduceScatter Collective Matmul (Down-projection)

The complement: instead of gathering inputs, we scatter results. The idea:

  1. Each device already holds one tile of the contracting (F) dimension, so loop over chunks of the output (D) dimension
  2. Compute the local partial product for one output chunk per step
  3. Permute the accumulator as you go, so each device ends up holding only its own fully reduced output shard

Key insight for ReduceScatter matmul: You are permuting the result as you accumulate, not the input. This means each device only ever holds a 1/Y-sized result buffer. Be careful to only send the minimal amount of data needed.
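A host-side simulation makes the accumulator schedule concrete. This NumPy sketch is one possible schedule, written under the assumption that accumulators move in the same direction as the AllGather version (device j sends to device j - 1); it is a check of the indexing, not the chapter's implementation, and the sizes are arbitrary.

```python
import numpy as np

n = 4                      # size of mesh axis Y
B, F, D = 2, 8, 12
rng = np.random.default_rng(0)
H = rng.integers(-3, 4, size=(B, F))
W_out = rng.integers(-3, 4, size=(F, D))

f_chunk, d_chunk = F // n, D // n
h = [H[:, y * f_chunk:(y + 1) * f_chunk] for y in range(n)]       # F split over Y
w = [W_out[y * f_chunk:(y + 1) * f_chunk, :] for y in range(n)]   # F split over Y
accum = [np.zeros((B, d_chunk), dtype=H.dtype) for _ in range(n)]

for i in range(n):
    for idx in range(n):
        # The accumulator currently on device idx belongs to output chunk c.
        c = (idx + 1 + i) % n
        accum[idx] = accum[idx] + h[idx] @ w[idx][:, c * d_chunk:(c + 1) * d_chunk]
    if i < n - 1:
        # Pass the partially reduced result on: device j sends to device j - 1.
        accum = [accum[(j + 1) % n] for j in range(n)]

# After the last step device idx holds output chunk idx, fully reduced.
out = np.concatenate(accum, axis=1)
assert np.array_equal(out, H @ W_out)
```

Only the [B, D/Y] accumulator travels between devices; the activations and weights never move.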

Combined Pipeline

Putting both together for a full FFN block:

# Full overlapped FFN: In[B_X, D_Y] -> Out[B_X, D_Y]
@jax.shard_map(
  in_specs=(jax.P('X', 'Y'), jax.P(None, 'Y'), jax.P('Y', None)),
  out_specs=jax.P('X', 'Y'))
def ffn_collective(x, w_up, w_down):
  # Up-projection: AG-matmul from Chapter 4 overlaps AllGather with compute
  h = collective_matmul_allgather_lhs(x, w_up)
  h = jax.nn.gelu(h)
  # Down-projection: RS-matmul overlaps compute with ReduceScatter.
  # collective_matmul_reducescatter is not defined in this chapter;
  # implementing it is Problem 3 in Chapter 8.
  out = collective_matmul_reducescatter(h, w_down)
  return out

Performance comparison on TPU v5e 4x2:
• Naive jax.jit FFN: 622 μs (blocking AllGather + blocking AllReduce visible in profile)
• Collective matmul FFN: ~490 μs (communication hidden behind compute)
• Theoretical minimum: 448 μs (pure compute, zero communication overhead)

The collective matmul approach gets within 10% of the theoretical minimum.
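The percentages follow directly from the three timings quoted above. A short calculation, using only those numbers (with the approximate 490 μs figure taken at face value):

```python
naive_us, collective_us, ideal_us = 622.0, 490.0, 448.0

# Fraction of the naive runtime removed by overlapping communication.
time_saved = (naive_us - collective_us) / naive_us
# How far the overlapped version still sits above pure compute.
overhead = collective_us / ideal_us - 1.0

print(f"time saved vs naive: {time_saved:.1%}")   # about 21%
print(f"overhead vs ideal:   {overhead:.1%}")     # about 9%
```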

Bidirectional Communication

The collective matmuls above use unidirectional circular shifts. An extension uses bidirectional communication — sending data both left and right simultaneously. This can further reduce latency by overlapping both directions of the ring, though the implementation is more complex.

Choosing the Right Approach

Scenario | Recommended Approach | Why
Prototyping a new model | jax.jit with auto sharding | Fastest to write, good enough for initial experiments
Production training loop | Explicit sharding + shard_map for hot path | Catches mistakes at trace time; overlapped comms for performance
MoE routing | shard_map with explicit AllToAll | Irregular routing patterns confuse auto partitioners
Custom attention (e.g., Ring Attention) | shard_map | Need precise control of which chunks go where
Debugging why auto mode is slow | Rewrite in shard_map, compare profiles | Eliminates compiler as a variable
Check: In an end-to-end Transformer FFN with Megatron sharding, which collective does the up-projection overlap, and which does the down-projection overlap?

Chapter 8: Worked Problems

Here are several problems to test your understanding. These are meant to be implemented on a real TPU (free TPU v2-8 available on Google Colab).

Problem 1: Shard-local Operations

Let A: float32[S_X, D_Y] with X * Y = N.

Part a: Write a function that computes the average within each (X, Y) shard, returning an array of shape [X, Y]. Implement in both jax.jit and shard_map. Profile each. Was communication added?

Hint: With shard_map, this is trivial — just x.mean(keepdims=True) on each device. With jax.jit, you need careful reshaping: reshape to [X, S//X, Y, D//Y] and mean over axes (1, 3).
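The reshape in the hint can be checked in NumPy before moving to a TPU. The sketch below uses small arbitrary sizes and compares the global-view reshape against an explicit loop over shards; it only validates the indexing and does not show what communication jax.jit would insert.

```python
import numpy as np

X, Y = 4, 2
S, D = 8, 6
A = np.arange(S * D, dtype=np.float32).reshape(S, D)

# Global view: expose the shard structure as extra axes, then reduce within shards.
shard_means = A.reshape(X, S // X, Y, D // Y).mean(axis=(1, 3))   # shape [X, Y]

# Per-device view: the mean of each local block, computed shard by shard.
expected = np.array([[A[x * (S // X):(x + 1) * (S // X),
                        y * (D // Y):(y + 1) * (D // Y)].mean()
                      for y in range(Y)] for x in range(X)])

assert shard_means.shape == (X, Y)
assert np.allclose(shard_means, expected)
```

Axes 1 and 3 of the reshaped array index positions inside a shard, so reducing over them leaves one value per (X, Y) device.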

Part b: Write a function that returns roll(x, shift, axis=0) - x for some shift, applied within each shard along X. Use shard_map.

Problem 2: Mixture of Experts

Build a complete MoE forward pass:

  1. Start with a local (single-device) implementation using scan + masking
  2. Profile the jax.jit version with sharded inputs — what communication does it add?
  3. Implement with shard_map and explicit AllToAll to avoid materializing a full [E, S, D] buffer
  4. Extend to top-k routing where each token goes to k experts

Problem 3: Collective Matmul Variations

  1. AllReduce collective matmul: Implement A[B_X, D_Y] * W[D_Y, F] → Out[B_X, F] with overlapped communication. Tile over the output dimension and use jax.lax.psum.
  2. ReduceScatter collective matmul: Implement Tmp[B_X, F_Y] * W2[F_Y, D] → Out[B_X, D_Y]. Be careful about passing only the minimal data needed. Hint: permute the result as you accumulate.
  3. End-to-end: Combine both into a full Transformer block: In[B_X, D_Y] * W_in * W_out → Out[B_X, D_Y]

Problem 4: Bidirectional Collective Matmul

Rewrite the collective AllReduce and ReduceScatter matmuls to use bidirectional communication (sending data both left and right simultaneously). How much faster are these?

Summary of Chapter 10: JAX provides three modes for multi-device programming, each with different trade-offs between ease of use and control. Auto mode is great for prototyping. Explicit mode catches mistakes. shard_map gives you total control for maximum performance. The collective matmul pattern — overlapping communication with computation — is the single most important optimization for production Transformer training.
Check: What is the most commonly used approach in practice for JAX parallel programming?