From back-of-the-envelope calculations to real measurements — the JAX profiler, TensorBoard traces, HLO reading, and worked examples.
So far the scaling book has been entirely theoretical: back-of-the-envelope calculations based on hardware rooflines. That understanding gets you remarkably far. You can predict matmul runtimes, estimate communication costs, and choose sharding strategies — all from first principles.
But a lot of real optimization comes down to practical details: how the XLA compiler actually lowers your code, what fusions it creates, where it inserts copies, and whether your theoretical predictions match reality. When they do not, you need tools to understand what went wrong.
Think of profiling as the experimental counterpart to roofline analysis. The roofline model is your hypothesis: "this matmul should take 96 ms." The profile is your experiment: "it actually took 96 ms" (great!) or "it actually took 250 ms" (something is wrong). Without profiling, you are flying blind.
This chapter introduces the JAX/TensorBoard Profiler — a multi-purpose instrument for understanding what happens on a TPU when your program runs. We will cover:
Google exposes a stack of APIs for programming TPUs, from high-level JAX code to low-level Pallas or HLO. Most programmers write JAX code exclusively, which lets you write abstract NumPy-style linear algebra that is compiled automatically to run efficiently on TPUs.
Here is a simple example — a JAX program that multiplies two matrices:
```python
import jax
import jax.numpy as jnp

def multiply(x, y):
    return jnp.einsum('bf,fd->db', x, y)

y = jax.jit(multiply)(jnp.ones((128, 256)), jnp.ones((256, 16), dtype=jnp.bfloat16))
```
By calling jax.jit, we tell JAX to trace this function and emit a lower-level IR called StableHLO, a platform-agnostic IR for ML computation. This is then lowered to HLO by the XLA compiler.
The compiler runs many optimization passes to determine fusions (grouping related ops together), layouts (how arrays are tiled in memory), and other factors. The resulting HLO is what you observe in a JAX profile.
Here is the HLO for the matmul above (abridged):
```
ENTRY %main.5 (Arg_0.1: f32[128,256], Arg_1.2: bf16[256,16]) -> f32[16,128] {
  %Arg_1.2 = bf16[256,16]{1,0} parameter(1)
  %convert.3 = f32[256,16]{1,0} convert(bf16[256,16] %Arg_1.2)
  %Arg_0.1 = f32[128,256]{1,0} parameter(0)
  ROOT %dot.4 = f32[16,128]{1,0} dot(
      f32[256,16] %convert.3, f32[128,256] %Arg_0.1),
      lhs_contracting_dims={0}, rhs_contracting_dims={1}
}
```
Notice how closely the HLO matches the original JAX code. The ROOT %dot.4 instruction is the actual matmul: it contracts dimension 0 of its first operand (the converted weights) with dimension 1 of its second operand (the f32 input).
You can print the HLO for any jitted function with jax.jit(f).lower(*args).compile().as_text(). Reading HLO is one of the most valuable skills for TPU performance engineering.

When something goes wrong at a lower level, you have one more escape hatch: writing custom kernels in Pallas. But most of the time, you work at the JAX level and use the profiler to understand the HLO below.
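As a minimal sketch of that as_text() call, the snippet below lowers the multiply function from earlier and prints the start of both IR stages. The variable names are illustrative, and the exact text depends on your JAX version and backend, so treat the output as something to inspect, not a fixed format.

```python
import jax
import jax.numpy as jnp

def multiply(x, y):
    return jnp.einsum('bf,fd->db', x, y)

x = jnp.ones((128, 256))
y = jnp.ones((256, 16), dtype=jnp.bfloat16)

lowered = jax.jit(multiply).lower(x, y)   # traced, not yet compiled
stablehlo_text = lowered.as_text()        # platform-agnostic StableHLO
hlo_text = lowered.compile().as_text()    # optimized HLO for this backend

print(stablehlo_text[:400])
print(hlo_text[:400])
```

Comparing the two printouts is a quick way to see what the compiler changed between the traced program and the optimized one.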
JAX provides a multi-purpose TPU profiler that records the duration of each subcomponent, the HLO of each program, memory usage, and more. You use the jax.profiler module to capture a trace while your program runs.
Here is the minimal recipe for capturing a profile:
```python
import jax

with jax.profiler.trace("/tmp/tensorboard"):
    key = jax.random.key(0)
    x = jax.random.normal(key, (1024, 1024))
    y = x @ x
    y.block_until_ready()

# View in TensorBoard:
#   pip install tensorboard tensorboard-plugin-profile
#   tensorboard --logdir=/tmp/tensorboard
```
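The same blocking call matters when you time code by hand instead of through the profiler. Here is a rough sketch, assuming a hypothetical helper named time_fn (it is not part of JAX): run the function once so compilation is excluded, then block on every result before reading the clock.

```python
import time
import jax
import jax.numpy as jnp

def time_fn(fn, *args, iters=5):
    """Average wall-clock seconds per call of fn(*args), excluding compilation."""
    jax.block_until_ready(fn(*args))      # warm-up: triggers tracing and compilation
    start = time.perf_counter()
    for _ in range(iters):
        jax.block_until_ready(fn(*args))  # wait for the device to finish each call
    return (time.perf_counter() - start) / iters

x = jnp.ones((1024, 1024))
seconds = time_fn(jax.jit(lambda a: a @ a), x)
print(f"{seconds * 1e3:.3f} ms per matmul")
```

Wall-clock timing like this only gives a total; the profiler is still what tells you which op inside the step is slow.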
The block_until_ready() call is essential. JAX uses asynchronous dispatch: without blocking, the profiler context may close before the computation actually runs on the TPU.

Once in TensorBoard, the profiler exposes several key views:
| View | What It Shows | Use Case |
|---|---|---|
| Trace Viewer | Chronological timeline of all TPU operations | Understanding what each part of your program does and how long it takes |
| Graph Viewer | HLO dataflow graph with shape/sharding info | Understanding how operations feed into each other |
| Memory Profile | Memory usage over time | Debugging OOMs, understanding peak memory |
| Memory Viewer | Detailed buffer-level memory breakdown | Finding which tensors consume the most memory |
For annotating your traces, use jax.named_scope to label regions of code. These names show up in the Trace Viewer as nested scopes, making it much easier to identify which part of a Transformer corresponds to which operations:
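```python
import jax
import jax.numpy as jnp

def block(x, w_q, w_k, w_v, w_up, w_down):
    # Ops inside each scope are grouped under its name in the Trace Viewer.
    with jax.named_scope("attention"):
        q = x @ w_q
        k = x @ w_k
        v = x @ w_v
        scores = q @ k.T / jnp.sqrt(q.shape[-1])
        attn = jax.nn.softmax(scores, axis=-1) @ v
    with jax.named_scope("ffn"):
        h = jax.nn.gelu(attn @ w_up)
        out = h @ w_down
    return out
```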
The Trace Viewer is probably the most useful part of the profiler. It shows a chronological timeline of all actions on each TPU core, letting you see exactly what the hardware is doing and for how long.
A few key things to understand about the Trace Viewer:
In a typical Transformer profile, you will see repeated blocks — one per layer. Within each block, you can identify the attention portion and the MLP/FFN portion. The XLA ops row tells you what is actually executing, while the scope rows help you map back to your code.
When you click on an XLA op in the trace, you get detailed information:
Since all TPUs typically execute the same instructions in SPMD programs, you usually only need to look at TPU:0. The timeline for other cores will be nearly identical.
| Pattern | What It Looks Like | What It Means |
|---|---|---|
| Dense compute blocks | Long, solid fusion bars | Good — the TPU is doing useful matmul work |
| Many small ops | Rapid alternation of tiny bars | Kernel launch overhead may dominate; consider fusing |
| Large collectives | Wide bar labeled all-gather or all-reduce | Expensive communication; check if sharding is optimal |
| Copy ops | Bars labeled "copy" or "dynamic-update-slice" | Retiling between operations; may indicate layout mismatches |
| Gaps / idle time | Empty space between ops | Pipeline bubbles, host-device sync stalls, or data starvation |
Good annotations are the difference between a readable profile and a wall of inscrutable HLO names. The key tools:
- jax.named_scope("name"): creates a visual scope in the Trace Viewer that groups all ops inside it
- jax.named_call(fn, name="name"): similar, but wraps a function call
- Scopes nest, for example with named_scope("layer_0"): with named_scope("attention"):

For production Transformer code, annotate at minimum: each layer, attention vs MLP, each projection, and each collective operation. This makes the profile immediately navigable.
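The annotation tools above can be combined as in the sketch below. The function names (mlp, layer, model), the tiny shapes, and the shared weights are all illustrative assumptions, chosen only so the example runs on any backend.

```python
import jax
import jax.numpy as jnp

def mlp(x, w_up, w_down):
    return jax.nn.gelu(x @ w_up) @ w_down

def layer(i, x, w_up, w_down):
    # Outer scope per layer; named_call adds an inner "mlp" scope around the call.
    with jax.named_scope(f"layer_{i}"):
        return x + jax.named_call(mlp, name="mlp")(x, w_up, w_down)

@jax.jit
def model(x, w_up, w_down):
    for i in range(2):  # two layers sharing weights, purely for illustration
        x = layer(i, x, w_up, w_down)
    return x

out = model(jnp.ones((4, 8)), jnp.ones((8, 16)), jnp.ones((16, 8)))
print(out.shape)
```

Annotations only label ops in the trace; they do not change what the compiler emits, so they are cheap to leave in.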
HLO is not actually very hard to read, and it is immensely helpful for understanding what a given part of the trace corresponds to. Let us break down the anatomy of an HLO op with a real example.
Consider this op called fusion.3:
```
%fusion.3 = bf16[32,32,4096]{2,1,0:T(8,128)(2,1)S(1)} fusion(
    bf16[32,32,8192]{2,1,0:T(8,128)(2,1)S(1)} %fusion.32),
    kind=kCustom, calls=%all-reduce-scatter.3
```
Let us unpack each component:
| Component | Example | Meaning |
|---|---|---|
| Op Name | fusion.3 | A fusion or dot op: a set of operations containing at most 1 matmul plus related pointwise VPU-ops |
| Output Shape | bf16[32,32,4096] | Output dtype is bf16 (2 bytes/element), shape is [32, 32, 4096] |
| Layout | {2,1,0:T(8,128)(2,1)} | Physical memory ordering and tiling of the array |
| Memory Location | S(1) | S(0) = HBM, S(1) = VMEM, S(2)/S(3) = other memory spaces |
| Arguments | bf16[32,32,8192] %fusion.32 | This op takes one input: a bf16 array called fusion.32 |
The layout notation tells us how an N-dimensional array is laid out sequentially in memory. Consider a simpler example:
f32[3,5]{1,0:T(2,2)}
This means:
- {1,0}: read right-to-left, so dimension 0 is most major and dimension 1 is most minor. The physical layout is [3, 5], identical to the logical shape in this case.
- T(2,2): the array is tiled in 2x2 chunks. Within each chunk, elements are stored row-major.

For the more complex tiling bf16[32,32,8192]{2,1,0:T(8,128)(2,1)S(1)}, we have two levels of tiling: an outer (8, 128) tiling and an inner (2, 1) tiling within that. The inner (2, 1) tiling is used for bf16 so that loads are always multiples of 4 bytes.
Let us walk through a concrete padding calculation. Consider bf16[3, 5]{1,0:T(8,128)(2,1)}:
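One way to work through the padding is the sketch below. It assumes each trailing dimension is rounded up to a multiple of the outer tile, and that the inner (2, 1) tile adds no further padding because 8 is already a multiple of 2. The helper names padded_shape and padded_bytes are hypothetical.

```python
import math

def padded_shape(shape, tile):
    """Round the trailing dims of `shape` up to multiples of `tile`."""
    lead = shape[: len(shape) - len(tile)]
    tail = shape[len(shape) - len(tile):]
    padded_tail = tuple(math.ceil(d / t) * t for d, t in zip(tail, tile))
    return tuple(lead) + padded_tail

def padded_bytes(shape, tile, itemsize):
    return math.prod(padded_shape(shape, tile)) * itemsize

logical = 3 * 5 * 2                                    # bf16 is 2 bytes per element
physical = padded_bytes((3, 5), (8, 128), itemsize=2)  # one full (8, 128) tile
print(padded_shape((3, 5), (8, 128)), logical, physical, physical / logical)
```

Under these assumptions the 30-byte logical array occupies a full 8x128 tile (2048 bytes), which is why very small or oddly shaped arrays can cost far more memory than their logical size suggests. The same helper gives [4, 6] for the f32[3,5] array with T(2,2).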
| Op Type | HLO Name Pattern | What It Does |
|---|---|---|
| Matrix multiply | dot, fusion with a dot inside | The workhorse — runs on the MXU/Tensor Core |
| Pointwise ops | add, multiply, tanh, fused into a fusion | Runs on the VPU/CUDA cores |
| AllReduce | all-reduce | Sum/mean across devices |
| AllGather | all-gather | Collect shards from all devices |
| ReduceScatter | reduce-scatter | Reduce then distribute shards |
| Copy / retile | copy, bitcast | Layout conversion between operations |
| Dynamic slice | dynamic-slice | Indexing into an array (e.g., embedding lookup) |
One way to avoid retiling is the AUTO layout setting. When you jax.jit a function, you can let XLA compute its preferred layout for inputs, avoiding retiling copies within the program.

While some fusions in the Trace Viewer can look complicated, the XLA Graph Viewer makes them much easier to parse. It shows the HLO as a visual dataflow graph, with boxes for operations and arrows for data dependencies.
The Graph Viewer is especially helpful for:
A common pattern you will see in the Graph Viewer is a fusion containing a dot (matmul) followed by some pointwise operations. For example, an MLP up-projection might show:
XLA's fusion pass is quite aggressive. It will fuse the matmul with subsequent pointwise operations (ReLU, GeLU, bias adds, etc.) into a single operation, eliminating intermediate memory traffic. This is one of XLA's main advantages: you write clean, modular code and the compiler fuses it for free.
However, fusions are not always beneficial. Sometimes XLA fuses operations that would be better left separate, or it fails to fuse operations you expect. The Graph Viewer helps you verify that the compiler is doing what you want.
Let us walk through a real profile of a Transformer and verify our theoretical predictions. We will look at a model running on 8 TPU v2 cores with 4-way data parallelism and 2-way model parallelism.
Zooming into the FFN (feed-forward network) block, we see the up-projection matmul. The HLO tells us the per-shard shapes:
```
// Per-shard computation:
bf16[8, 1024, 8192] * bf16[8192, 16384] -> bf16[8, 1024, 16384]

// True (unsharded) shapes (4-way DP, 2-way MP):
X: bf16[32, 1024, 8192] * Win: bf16[8192, 32768] -> Tmp: bf16[32, 1024, 32768]
```
How long should this take? Our batch size per DP shard is 8 * 1024 = 8192 tokens, so we are solidly compute-bound.
The profile reports 96 ms. That is essentially perfect FLOPs utilization — we are hitting the roofline.
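The prediction behind that comparison can be reproduced with a few lines of arithmetic. This sketch assumes the 23 TFLOP/s per-core figure that this walkthrough uses for TPU v2 and counts 2 FLOPs per multiply-accumulate; the constant names are illustrative.

```python
PEAK_FLOPS_PER_CORE = 23e12   # TPU v2 FLOP/s per core, as used in this chapter
NUM_CORES = 8

# True (unsharded) shapes: X[32, 1024, 8192] @ W_in[8192, 32768]
batch, seq, d_model, d_ff = 32, 1024, 8192, 32768
total_flops = 2 * batch * seq * d_model * d_ff   # one multiply and one add per MAC

predicted_ms = total_flops / (PEAK_FLOPS_PER_CORE * NUM_CORES) * 1e3
measured_ms = 96.0                               # from the profile
utilization = predicted_ms / measured_ms
print(f"predicted {predicted_ms:.1f} ms, utilization {utilization:.1%}")
```

Dividing the total FLOPs by the aggregate peak of all 8 cores is equivalent to dividing the per-shard FLOPs by one core's peak, since the work is split evenly.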
At the end of the second matmul, there is a small fusion that is actually a ReduceScatter:
```
%fusion.1 = bf16[8,1024,4096]{2,1,0:T(8,128)(2,1)S(1)} fusion(
    bf16[8,1024,8192]{...} %fusion.31),
    kind=kCustom, calls=%all-reduce-scatter.1
```
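Expected time: the full array has size 2 * 32 * 1024 * 8192 bytes, and with the batch axis sharded 4 ways each shard holds 2 * 8 * 1024 * 8192 bytes = 128 MiB (about 134 MB). On a TPU v2 4x2 (1.2 × 10^11 bytes/s bidirectional bandwidth) with one hop required, T ≈ 134 × 10^6 / (1.2 × 10^11) ≈ 1.1 ms.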
Profile reports: 1.13 ms. Again, essentially at the roofline.
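The communication estimate follows the same pattern as the compute one. A minimal sketch, assuming the transfer is purely bandwidth-bound over a single hop and taking the per-shard input shape bf16[8, 1024, 8192] from the HLO above:

```python
BYTES_PER_BF16 = 2
ICI_BANDWIDTH = 1.2e11   # bytes/s, bidirectional, TPU v2 4x2 (figure from the text)

# Per-shard input to the ReduceScatter: bf16[8, 1024, 8192]
shard_bytes = BYTES_PER_BF16 * 8 * 1024 * 8192
predicted_ms = shard_bytes / ICI_BANDWIDTH * 1e3   # one hop, bandwidth-bound

measured_ms = 1.13                                 # from the profile
print(f"{shard_bytes / 2**20:.0f} MiB, predicted {predicted_ms:.2f} ms, measured {measured_ms} ms")
```

Latency and multi-hop effects are ignored here, which is reasonable for a transfer this large but not for small messages.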
The attention component shows the Q, K, V projections as separate matmuls. The Q projection uses weight matrix W_Q of shape [d_model=8192, n_heads=32, d_qkv=256], with Megatron sharding along the head dimension.
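Let us do the exercise. The Q projection has true shape:
[32, 1024, 8192] × [8192, 32, 256] → [32, 1024, 32, 256]
With Megatron sharding along the 32 heads (2-way MP), each shard computes 16 heads. The FLOPs per shard:
2 × (8 × 1024) × 8192 × (16 × 256) = 2^39 ≈ 5.5 × 10^11
At 23 TFLOPS per core: T ≈ 5.5 × 10^11 / (23 × 10^12) ≈ 24 ms. Equivalently, over the whole model, T ≈ 4.4 × 10^12 / (23 × 10^12 × 8) ≈ 24 ms. If the profile shows a similar number, we are at the roofline. If much longer, check for retiling or communication overhead.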
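To check a FLOP count directly from shapes, the sketch below counts 2 FLOPs per multiply-accumulate over every distinct einsum index. The helper einsum_flops is hypothetical and only valid for a single two-operand contraction; the index letters are an assumed labeling of the Q projection.

```python
import math

def einsum_flops(spec, *shapes):
    """2 x the product of every distinct index size (single contraction only)."""
    inputs = spec.split("->")[0].split(",")
    sizes = {}
    for labels, shape in zip(inputs, shapes):
        sizes.update(zip(labels, shape))
    return 2 * math.prod(sizes.values())

# b=batch, t=sequence, d=d_model, n=heads, h=d_qkv
q_flops = einsum_flops("btd,dnh->btnh", (32, 1024, 8192), (8192, 32, 256))
q_ms = q_flops / (23e12 * 8) * 1e3   # 8 cores at 23 TFLOP/s each
print(f"{q_flops:.2e} FLOPs, {q_ms:.1f} ms")
```

The same helper reproduces the FFN up-projection count when called with "btd,df->btf" and the shapes from the earlier example.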
The workflow for any operation in a profile:
The Memory Profile view shows program memory as a function of time. This is invaluable for debugging out-of-memory (OOM) errors — the most common failure mode when scaling models.
In a typical profile, you will see:
| Memory Region | Typical Contents | Behavior Over Time |
|---|---|---|
| Parameters | Model weights | Constant — allocated at initialization |
| Optimizer states | Adam moments, gradient buffers | Constant after first step |
| Activations | Intermediate tensors from forward pass | Grows during forward, freed during backward |
| Temporaries | Buffers for collectives, retiling | Short-lived spikes |
The Memory Profile also helps answer practical questions:
The Memory Viewer (a separate tab) provides a more granular view, showing individual buffer allocations. This is useful for identifying which specific tensors are consuming the most memory.
Short-lived temporaries often come from copy operations that convert between different tilings.

When you hit an OOM, follow this systematic approach:
| Component | Bytes (bf16 weights, Adam) | Example: 7B model |
|---|---|---|
| Weights | 2 × params | 14 GB |
| Gradients | 2 × params | 14 GB |
| Adam m (first moment) | 4 × params (fp32) | 28 GB |
| Adam v (second moment) | 4 × params (fp32) | 28 GB |
| Master weights (fp32) | 4 × params | 28 GB |
| Total model state | 16 × params | 112 GB |
This 112 GB already exceeds a single H100's 80 GB HBM — before any activations. This is why sharding (ZeRO, TP, PP) is essential for models ≥ 7B.
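The table translates directly into a small estimator. This is a sketch under the same assumptions as the table (bf16 weights and gradients, fp32 Adam moments and master weights, 1 GB = 10^9 bytes); the names are illustrative, and activations are deliberately left out.

```python
BYTES_PER_PARAM = {
    "weights (bf16)": 2,
    "gradients (bf16)": 2,
    "adam_m (fp32)": 4,
    "adam_v (fp32)": 4,
    "master_weights (fp32)": 4,
}

def model_state_gb(num_params):
    """Per-component and total model-state size in GB (1 GB = 1e9 bytes)."""
    sizes = {name: b * num_params / 1e9 for name, b in BYTES_PER_PARAM.items()}
    sizes["total"] = sum(sizes.values())
    return sizes

state = model_state_gb(7e9)
# Minimum device count that can hold the model state alone, at 80 GB of HBM each.
devices_needed = -(-state["total"] // 80)   # ceiling division
print(state, devices_needed)
```

The device count is a lower bound only: activations, temporaries, and tiling padding all come on top of it.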
Let us put everything together with some worked problems that test your ability to read profiles and reason about what is happening.
You see a profile with two big fusions, a reduce, and an AllReduce. One fusion has this HLO:
```
%fusion.1 = bf16[4096]{0:T(1024)(128)(2,1)} fusion(
    bf16[4096,8192]{1,0:T(8,128)(2,1)} %param.1,
    bf16[8192]{0:T(1024)(128)(2,1)} %reduce.6
  ), kind=kLoop, calls=%fused_computation.1
```
The final AllReduce has replica_groups={{0,16,32,48,64,80,96,112}, ...}.
Per shard, this fusion computes bf16[8192] * bf16[4096, 8192] → bf16[4096] (contracting over the 8192 dimension). The replica groups show 8-way model parallelism (stride of 16 in a 128-chip setup), so the true shapes are 8 times larger along the model-parallel axis.

This is two chained matrix multiplications with 8-way MP sharding:
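```python
import jax.numpy as jnp

def matmul(w1, w2, x):
    # Inner einsum: x[b, w] with w1[f, w] -> [b, f]; outer: with w2[w, f] -> [b, w]
    return jnp.einsum('wf,bf->bw', w2, jnp.einsum('fw,bw->bf', w1, x))
```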
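Reading the parallelism degree off the replica groups is mechanical enough to script. A small sketch, assuming evenly spaced device ids within a group and the 128-chip setup from this problem; describe_group is a hypothetical helper:

```python
def describe_group(group):
    """Return (group size, stride) for one evenly spaced replica group."""
    strides = {b - a for a, b in zip(group, group[1:])}
    assert len(strides) == 1, "expected evenly spaced device ids"
    return len(group), strides.pop()

size, stride = describe_group([0, 16, 32, 48, 64, 80, 96, 112])
num_groups = 128 // size   # assuming the 128-chip setup described above
print(size, stride, num_groups)
```

The group size is the number of devices that participate in each AllReduce, which is the model-parallel degree here.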
You profile a Transformer and find an unexpectedly large AllGather taking 80% of the step time. This is a sign that XLA's automatic partitioner (GSPMD/Shardy) chose a poor sharding for some intermediate tensor.
The fix: use jax.lax.with_sharding_constraint to guide the compiler:
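```python
import jax
import jax.numpy as jnp
from jax.sharding import PartitionSpec as P

def ffn(x, w_up, w_down):
    h = jnp.einsum('bd,df->bf', x, w_up)
    # Force h to be sharded along F, not gathered.
    # A bare PartitionSpec needs an enclosing mesh with axes named 'X' and 'Y'.
    h = jax.lax.with_sharding_constraint(h, P('X', 'Y'))
    h = jax.nn.gelu(h)
    return jnp.einsum('bf,df->bd', h, w_down)
```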
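To try the constraint end to end, the sketch below builds a mesh from whatever devices are available and passes an explicit NamedSharding, so no surrounding mesh context is needed. The mesh layout (all devices on 'Y'), the array sizes, and the 0.01 scaling are assumptions made only so the example runs anywhere, including a single CPU device.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Put every available device on the 'Y' axis; 'X' has size 1.
devices = np.array(jax.devices()).reshape(1, -1)
mesh = Mesh(devices, axis_names=('X', 'Y'))
n = devices.shape[1]

@jax.jit
def ffn(x, w_up, w_down):
    h = jnp.einsum('bd,df->bf', x, w_up)
    # Keep the hidden activation sharded along F (mesh axis 'Y').
    h = jax.lax.with_sharding_constraint(h, NamedSharding(mesh, P('X', 'Y')))
    h = jax.nn.gelu(h)
    return jnp.einsum('bf,df->bd', h, w_down)

x = jnp.ones((8, 16))
w_up = jnp.ones((16, 32 * n)) * 0.01     # F is a multiple of the 'Y' axis size
w_down = jnp.ones((16, 32 * n)) * 0.01
out = ffn(x, w_up, w_down)
print(out.shape, out.sharding)
```

On one device the constraint is a no-op; the point of the sketch is the wiring, and the effect only shows up in a profile taken on a multi-device mesh.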
When the partitioner chooses a poor sharding like this, use with_sharding_constraint or shard_map to fix it.

Here are real-world patterns you will encounter and how to diagnose them from a profile:
| Symptom | Profile Appearance | Root Cause | Fix |
|---|---|---|---|
| Step time 3x expected | Giant AllGather taking 70% of step | XLA gathered full weight matrix instead of using TP | Add with_sharding_constraint to intermediate activations |
| Matmul 2x slower than expected | Copy op before the dot | Input tiling mismatch causing retile | Use AUTO input layouts or restructure computation |
| OOM even though the memory estimate says it should fit | Memory Profile shows 1.5x expected peak | Tiling padding + XLA temporaries | Reduce batch size, enable activation recomputation, or reduce TP degree |
| Device idle 30% of the time | Gaps between ops in timeline | Python-side data loading too slow | Prefetch data, use async data loaders, profile host side |
| What to Check | What to Expect | Red Flag |
|---|---|---|
| Matmul duration | Within 5-10% of roofline | More than 2x roofline → wrong sharding or layout |
| Collective duration | Close to B/W prediction | Much longer → check for unexpected AllGathers |
| Gaps between ops | Minimal idle time | Large gaps → pipeline bubbles or stalls |
| Copy ops | Rare and small | Many copies → tiling mismatches, use AUTO layout |
| Peak memory | Close to theoretical estimate | Much higher → check for retained activations |
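The matmul row of this checklist can be turned into a small helper for triaging ops in bulk. This is a sketch only: the 10% and 2x thresholds come from the table, while the function name and the middle "investigate" label are assumptions.

```python
def check_against_roofline(measured_ms, predicted_ms):
    """Classify an op using the thresholds from the checklist above."""
    ratio = measured_ms / predicted_ms
    if ratio <= 1.10:
        return "at roofline"
    if ratio > 2.0:
        return "red flag: check sharding and layout"
    return "investigate"

print(check_against_roofline(96.0, 95.6))    # the FFN matmul from the walkthrough
print(check_against_roofline(250.0, 96.0))   # the "something is wrong" case
```

Running every large op in a step through a check like this quickly narrows down where to look in the Trace Viewer.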