Austin et al., Chapter 9

How to Profile TPU Code

From back-of-the-envelope calculations to real measurements — the JAX profiler, TensorBoard traces, HLO reading, and worked examples.

Prerequisites: Basic JAX familiarity (jax.jit, jnp operations) and TPU architecture (MXU, VMEM, HBM from Chapter 2).

Chapter 0: Why Profile?

So far the scaling book has been entirely theoretical: back-of-the-envelope calculations based on hardware rooflines. That understanding gets you remarkably far. You can predict matmul runtimes, estimate communication costs, and choose sharding strategies — all from first principles.

But a lot of real optimization comes down to practical details: how the XLA compiler actually lowers your code, what fusions it creates, where it inserts copies, and whether your theoretical predictions match reality. When they do not, you need tools to understand what went wrong.

The profiler closes the loop. Theory tells you what should happen. Profiling tells you what did happen. The gap between the two is where all the interesting optimization lives.

Think of profiling as the experimental counterpart to roofline analysis. The roofline model is your hypothesis: "this matmul should take 96 ms." The profile is your experiment: "it actually took 96 ms" (great!) or "it actually took 250 ms" (something is wrong). Without profiling, you are flying blind.

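This chapter introduces the JAX/TensorBoard Profiler, a multi-purpose instrument for understanding what happens on a TPU when your program runs. We will cover the TPU software stack, capturing a profile, the Trace Viewer, reading HLO ops, the Graph Viewer, a worked profile, the Memory Profile, and a set of worked problems.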

Check: What is the primary purpose of profiling TPU code?

Chapter 1: The TPU Software Stack

Google exposes a stack of APIs for programming TPUs, from high-level JAX code to low-level Pallas or HLO. Most programmers write JAX code exclusively, which lets you write abstract NumPy-style linear algebra that is compiled automatically to run efficiently on TPUs.

Here is a simple example — a JAX program that multiplies two matrices:

import jax
import jax.numpy as jnp

def multiply(x, y):
  return jnp.einsum('bf,fd->db', x, y)

y = jax.jit(multiply)(jnp.ones((128, 256)),
                       jnp.ones((256, 16), dtype=jnp.bfloat16))

By calling jax.jit, we tell JAX to trace this function and emit a lower-level IR called StableHLO, a platform-agnostic IR for ML computation. This is then lowered to HLO by the XLA compiler.

JAX (Python)
NumPy-style linear algebra. What you write.
↓ jax.jit traces
StableHLO
Platform-agnostic ML IR. Portable across backends.
↓ XLA compiler lowers
HLO (High Level Optimizer)
XLA's graph IR. Fusions, layouts, sharding. What you see in profiles.
↓ further lowering
LLO (Low Level Optimizer)
Programs TPU directly: DMA scheduling, systolic array push/pull.
↓ compile
Machine Code
Loaded into TPU IMEM and executed.

The compiler runs many optimization passes to determine fusions (grouping related ops together), layouts (how arrays are tiled in memory), and other factors. The resulting HLO is what you observe in a JAX profile.

Here is the HLO for the matmul above (abridged):

ENTRY %main.5 (Arg_0.1: f32[128,256], Arg_1.2: bf16[256,16]) -> f32[16,128] {
  %Arg_1.2 = bf16[256,16]{1,0} parameter(1)
  %convert.3 = f32[256,16]{1,0} convert(bf16[256,16] %Arg_1.2)
  %Arg_0.1 = f32[128,256]{1,0} parameter(0)
  ROOT %dot.4 = f32[16,128]{1,0} dot(
    f32[256,16] %convert.3,
    f32[128,256] %Arg_0.1),
    lhs_contracting_dims={0}, rhs_contracting_dims={1}
}

Notice how closely the HLO matches the original JAX code. The ROOT %dot.4 instruction is the actual matmul, contracting along dimensions 0 and 1 of the two f32 matrices.

Key insight: To get the HLO of any jitted function, use jax.jit(f).lower(*args).compile().as_text(). Reading HLO is one of the most valuable skills for TPU performance engineering.
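
As a minimal sketch of that recipe, the snippet below lowers the multiply function from earlier and keeps both intermediate forms as text. It assumes only a working JAX install; the exact text varies by backend and JAX version, so treat the output as something to read rather than something to match.

```python
import jax
import jax.numpy as jnp

def multiply(x, y):
  return jnp.einsum('bf,fd->db', x, y)

x = jnp.ones((128, 256))
y = jnp.ones((256, 16), dtype=jnp.bfloat16)

lowered = jax.jit(multiply).lower(x, y)   # traced, not yet compiled
stablehlo_text = lowered.as_text()        # StableHLO emitted by JAX
hlo_text = lowered.compile().as_text()    # optimized HLO for this backend

print(stablehlo_text)
print(hlo_text)
```

Comparing the two printouts is a quick way to see what the compiler changed between the code you wrote and the ops that show up in a profile.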

When something goes wrong at a lower level, you have one more escape hatch: writing custom kernels in Pallas. But most of the time, you work at the JAX level and use the profiler to understand the HLO below.

Check: At which level of the stack do you primarily see operations when looking at a JAX profile?

Chapter 2: The JAX Profiler

JAX provides a multi-purpose TPU profiler that records the duration of each subcomponent, the HLO of each program, memory usage, and more. You use the jax.profiler module to capture a trace while your program runs.

Here is the minimal recipe for capturing a profile:

import jax

with jax.profiler.trace("/tmp/tensorboard"):
  key = jax.random.key(0)
  x = jax.random.normal(key, (1024, 1024))
  y = x @ x
  y.block_until_ready()

# View in TensorBoard:
# pip install tensorboard tensorboard-plugin-profile
# tensorboard --logdir=/tmp/tensorboard
Important: The block_until_ready() call is essential. JAX uses asynchronous dispatch — without blocking, the profiler context may close before computation actually runs on the TPU.
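
The same asynchronous dispatch affects any hand-rolled timing, not just profiler contexts. The sketch below times a jitted matmul twice: once when the Python call returns and once after block_until_ready(). The function name square and the 512 × 512 size are arbitrary choices for illustration, and on some backends the two numbers may be close.

```python
import time
import jax
import jax.numpy as jnp

@jax.jit
def square(x):
  return x @ x

x = jnp.ones((512, 512))
square(x).block_until_ready()        # warm-up so compilation is not timed

t0 = time.perf_counter()
y = square(x)                        # may return before the device finishes
t_dispatch = time.perf_counter() - t0
y.block_until_ready()                # wait for the result to be ready
t_total = time.perf_counter() - t0

print(f"dispatch returned after {t_dispatch * 1e3:.3f} ms")
print(f"result ready after      {t_total * 1e3:.3f} ms")
```

Only the second number reflects the time the computation actually took.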

Once in TensorBoard, the profiler exposes several key views:

View | What It Shows | Use Case
Trace Viewer | Chronological timeline of all TPU operations | Understanding what each part of your program does and how long it takes
Graph Viewer | HLO dataflow graph with shape/sharding info | Understanding how operations feed into each other
Memory Profile | Memory usage over time | Debugging OOMs, understanding peak memory
Memory Viewer | Detailed buffer-level memory breakdown | Finding which tensors consume the most memory
Pro tip: You can also share profiles using Perfetto. While it only supports the Trace Viewer component (not Graph Viewer or Memory Profile), it is excellent for sharing and collaborative debugging.

For annotating your traces, use jax.named_scope to label regions of code. These names show up in the Trace Viewer as nested scopes, making it much easier to identify which part of a Transformer corresponds to which operations:

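import jax
import jax.numpy as jnp

def block(x, w_q, w_k, w_v, w_up, w_down):
  with jax.named_scope("attention"):
    q = x @ w_q
    k = x @ w_k
    v = x @ w_v
    # Scaled dot-product attention (single head, no mask)
    attn = jax.nn.softmax(q @ k.T / jnp.sqrt(q.shape[-1]), axis=-1) @ v

  with jax.named_scope("ffn"):
    h = jax.nn.gelu(attn @ w_up)
    out = h @ w_down
  return out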
Check: Why is block_until_ready() needed when profiling?

Chapter 3: The Trace Viewer

The Trace Viewer is probably the most useful part of the profiler. It shows a chronological timeline of all actions on each TPU core, letting you see exactly what the hardware is doing and for how long.

A few key things to understand about the Trace Viewer:

Top row (XLA Ops): the actual TPU operations with HLO names. This is ground truth.
Lower rows (Python scopes): approximate traces from jax.named_scope, jax.named_call, and Python stack traces.

In a typical Transformer profile, you will see repeated blocks — one per layer. Within each block, you can identify the attention portion and the MLP/FFN portion. The XLA ops row tells you what is actually executing, while the scope rows help you map back to your code.

Navigation tip: Use "video game" style controls in the Trace Viewer: A/D to pan left/right, W/S to zoom in/out. These make navigating much easier than dragging.

When you click on an XLA op in the trace, a panel opens with detailed information about that op, including a link to it in the Graph Viewer.

Since all TPUs typically execute the same instructions in SPMD programs, you usually only need to look at TPU:0. The timeline for other cores will be nearly identical.

What to look for: Unexpected gaps between ops (stalls), unexpectedly long operations, large collectives (AllGather, AllReduce) that dominate the timeline, and copy/retiling ops that should not be there.

Common Trace Patterns

Pattern | What It Looks Like | What It Means
Dense compute blocks | Long, solid fusion bars | Good: the TPU is doing useful matmul work
Many small ops | Rapid alternation of tiny bars | Kernel launch overhead may dominate; consider fusing
Large AllGather | Wide bar labeled all-gather or all-reduce | Expensive communication; check if sharding is optimal
Copy ops | Bars labeled "copy" or "dynamic-update-slice" | Retiling between operations; may indicate layout mismatches
Gaps / idle time | Empty space between ops | Pipeline bubbles, host-device sync stalls, or data starvation
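
To make two of these patterns quantitative, here is a rough sketch that computes the idle fraction and the share of time spent in collectives from a list of op events. The events list is made up for illustration and the summarize helper is hypothetical; it also assumes ops on one core do not overlap, which a real trace may violate.

```python
# Hypothetical (name, start_us, duration_us) events for one TPU core,
# invented for illustration; a real trace would supply these.
events = [
  ("fusion.1", 0.0, 400.0),
  ("all-gather.2", 400.0, 900.0),
  ("fusion.3", 1500.0, 300.0),   # starts 200 us after the previous op ends
  ("copy.4", 1800.0, 50.0),
]

COLLECTIVES = ("all-gather", "all-reduce", "reduce-scatter")

def summarize(events):
  """Return idle and collective time as fractions of the covered span."""
  events = sorted(events, key=lambda e: e[1])
  span = events[-1][1] + events[-1][2] - events[0][1]
  busy = sum(duration for _, _, duration in events)
  collective = sum(duration for name, _, duration in events
                   if name.startswith(COLLECTIVES))
  return {
    "idle_fraction": (span - busy) / span,
    "collective_fraction": collective / span,
  }

stats = summarize(events)
print(stats)
```

In this toy timeline roughly half the span is spent in the all-gather, which is the kind of signal that would send you to check the sharding.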

Annotating Your Code

Good annotations are the difference between a readable profile and a wall of inscrutable HLO names. The key tools are jax.named_scope and jax.named_call, whose names appear in the scope rows of the Trace Viewer.

For production Transformer code, annotate at minimum: each layer, attention vs MLP, each projection, and each collective operation. This makes the profile immediately navigable.
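
One way to get per-layer labels is to nest scopes inside the layer loop. The sketch below uses a toy residual MLP stack; the function names, shapes, and scope names are all hypothetical, and the nested names should appear in the trace as layer_0, layer_1, and so on, each containing an mlp scope.

```python
import jax
import jax.numpy as jnp

def mlp(x, w_up, w_down):
  with jax.named_scope("mlp"):
    return jax.nn.gelu(x @ w_up) @ w_down

@jax.jit
def forward(x, layers):
  # One scope per layer, with the mlp scope nested inside it.
  for i, (w_up, w_down) in enumerate(layers):
    with jax.named_scope(f"layer_{i}"):
      x = x + mlp(x, w_up, w_down)
  return x

d, f = 64, 256
layers = [(jnp.ones((d, f)) * 0.01, jnp.ones((f, d)) * 0.01) for _ in range(2)]
out = forward(jnp.ones((8, d)), layers)
print(out.shape)
```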

Check: In the Trace Viewer, which row shows ground truth about what the TPU is actually doing?

Chapter 4: Reading HLO Ops

HLO is not actually very hard to read, and it is immensely helpful for understanding what a given part of the trace corresponds to. Let us break down the anatomy of an HLO op with a real example.

Consider this op called fusion.3:

%fusion.3 = bf16[32,32,4096]{2,1,0:T(8,128)(2,1)S(1)}
  fusion(bf16[32,32,8192]{2,1,0:T(8,128)(2,1)S(1)} %fusion.32),
  kind=kCustom, calls=%all-reduce-scatter.3

Let us unpack each component:

Component | Example | Meaning
Op Name | fusion.3 | A fusion or dot op: a set of operations containing at most 1 matmul plus related pointwise VPU ops
Output Shape | bf16[32,32,4096] | Output dtype is bf16 (2 bytes/element), shape is [32, 32, 4096]
Layout | {2,1,0:T(8,128)(2,1)} | Physical memory ordering and tiling of the array
Memory Location | S(1) | S(0) = HBM, S(1) = VMEM, S(2)/S(3) = other memory spaces
Arguments | bf16[32,32,8192] %fusion.32 | This op takes one input: a bf16 array called fusion.32
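
The same breakdown can be done mechanically. The following is a rough sketch of a parser for a single HLO shape string, written only against the examples in this chapter; the regular expression, the parse_shape helper, and the small dtype table are my own, and treating a missing S(...) as memory space 0 is an assumption.

```python
import re
from math import prod

DTYPE_BYTES = {"bf16": 2, "f16": 2, "f32": 4, "s32": 4}

SHAPE_RE = re.compile(
  r"(?P<dtype>[a-z]+\d+)\[(?P<dims>[\d,]*)\]"
  r"(?:\{(?P<order>[\d,]*)(?::(?P<rest>[^}]*))?\})?"
)

def parse_shape(text):
  """Split an HLO shape string into dtype, dims, layout, tiles, memory space."""
  m = SHAPE_RE.search(text)
  dims = tuple(int(d) for d in m["dims"].split(",") if d)
  order = tuple(int(d) for d in (m["order"] or "").split(",") if d)
  rest = m["rest"] or ""
  space = re.search(r"S\((\d+)\)", rest)
  tile_text = re.sub(r"S\(\d+\)", "", rest)   # drop S(n) before reading tiles
  tiles = [tuple(int(v) for v in t.split(","))
           for t in re.findall(r"\(([\d,]+)\)", tile_text)]
  return {
    "dtype": m["dtype"],
    "dims": dims,
    "minor_to_major": order,
    "tiles": tiles,
    "memory_space": int(space.group(1)) if space else 0,  # assumed default
    "bytes_unpadded": DTYPE_BYTES[m["dtype"]] * prod(dims),
  }

info = parse_shape("bf16[32,32,4096]{2,1,0:T(8,128)(2,1)S(1)}")
print(info)
```

For fusion.3 this reports a bf16 array of 8,388,608 unpadded bytes in memory space 1, matching the table above.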

Understanding Tilings

The layout notation tells us how an N-dimensional array is laid out sequentially in memory. Consider a simpler example:

f32[3,5]{1,0:T(2,2)}

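This means: the dtype is f32 and the logical shape is [3, 5]. The {1,0} part gives the dimension order in physical memory from most minor to most major, so dimension 1 varies fastest (row-major). T(2,2) means the array is stored in 2 × 2 tiles.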

Padding from tiling: The T(2,2) tiling requires the array to be padded to [4, 6], expanding memory usage by about 1.6x. Tiling can significantly affect memory efficiency and may cause XLA to insert retiling copies at non-trivial overhead.

For the more complex tiling bf16[32,32,8192]{2,1,0:T(8,128)(2,1)S(1)}, we have two levels of tiling: an outer (8, 128) tiling and an inner (2, 1) tiling within that. The inner (2, 1) tiling is used for bf16 so that loads are always multiples of 4 bytes.

Worked Tiling Example

Let us walk through a concrete padding calculation. Consider bf16[3, 5]{1,0:T(8,128)(2,1)}:

  1. The outer tiling T(8, 128) means each tile is 8 rows × 128 columns
  2. To fit our [3, 5] array, we need ceil(3/8)=1 row tile and ceil(5/128)=1 col tile
  3. Padded shape: [8, 128] — much larger than [3, 5]
  4. Memory overhead: (8×128) / (3×5) ≈ 68x the memory of the unpadded array
Why this matters for real models: Tiling overhead is negligible for large tensors (a [4096, 8192] array padded to [4096, 8192] has zero waste), but for small tensors like bias vectors, embedding lookups, or MoE routing tables, tiling can waste significant memory. Watch for unexpectedly large memory usage in the Memory Profile — tiling overhead is often the culprit.
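
The padding arithmetic is easy to script. The sketch below rounds the trailing dimensions of a shape up to multiples of the outer tile and reports the resulting overhead; the helper names are hypothetical, and it deliberately ignores any inner tile such as (2,1), so it is only an estimate.

```python
from math import ceil, prod

def padded_shape(shape, tile):
  """Round the trailing len(tile) dims of shape up to multiples of tile."""
  lead = len(shape) - len(tile)
  padded = [ceil(s / t) * t for s, t in zip(shape[lead:], tile)]
  return tuple(shape[:lead]) + tuple(padded)

def padding_overhead(shape, tile):
  """Ratio of padded element count to logical element count."""
  return prod(padded_shape(shape, tile)) / prod(shape)

print(padded_shape((3, 5), (2, 2)), padding_overhead((3, 5), (2, 2)))
print(padded_shape((3, 5), (8, 128)), padding_overhead((3, 5), (8, 128)))
print(padding_overhead((4096, 8192), (8, 128)))
```

Running it on the shapes you find in the Memory Viewer is a quick way to tell whether a surprisingly large buffer is explained by tiling alone.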

Common HLO Op Types

Op Type | HLO Name Pattern | What It Does
Matrix multiply | dot, fusion with a dot inside | The workhorse; runs on the MXU/Tensor Core
Pointwise ops | add, multiply, tanh, fused into a fusion | Runs on the VPU/CUDA cores
AllReduce | all-reduce | Sum/mean across devices
AllGather | all-gather | Collect shards from all devices
ReduceScatter | reduce-scatter | Reduce then distribute shards
Copy / retile | copy, bitcast | Layout conversion between operations
Dynamic slice | dynamic-slice | Indexing into an array (e.g., embedding lookup)
Layout tip: JAX provides an experimental feature to specify input layouts as AUTO. When you jax.jit a function, you can let XLA compute its preferred layout for inputs, avoiding retiling copies within the program.
Check: What does S(1) mean in an HLO op annotation?

Chapter 5: The Graph Viewer

While some fusions in the Trace Viewer can look complicated, the XLA Graph Viewer makes them much easier to parse. It shows the HLO as a visual dataflow graph — boxes for operations, arrows for data dependencies.

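The Graph Viewer is especially helpful for seeing how operations feed into each other and for working out what a complicated fusion actually contains.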

Workflow tip: The best way to use the Graph Viewer is to click on an XLA op in the Trace Viewer, then follow the "View in Graph" link. This drops you directly into the right part of the graph, from where you can explore neighbors.

A common pattern you will see in the Graph Viewer is a fusion containing a dot (matmul) followed by some pointwise operations. For example, an MLP up-projection might show:

parameter: bf16[B, D]
Input activations from the previous layer
dot: bf16[B, D] * bf16[D, F] → bf16[B, F]
The actual matmul (up-projection)
gelu / multiply: bf16[B, F]
Activation function fused into the same op

XLA's fusion pass is quite aggressive. It will fuse the matmul with subsequent pointwise operations (ReLU, GeLU, bias adds, etc.) into a single operation, eliminating intermediate memory traffic. This is one of XLA's main advantages: you write clean, modular code and the compiler fuses it for free.

However, fusions are not always beneficial. Sometimes XLA fuses operations that would be better left separate, or it fails to fuse operations you expect. The Graph Viewer helps you verify that the compiler is doing what you want.

Practice recommendation: Spend time staring at HLO graphs and trying to map HLO ops onto the code you are profiling. This skill — reading the compiler's output — is one of the most transferable skills in performance engineering.
Check: What is the primary benefit of XLA's fusion pass for TPU performance?

Chapter 6: A Worked Profile

Let us walk through a real profile of a Transformer and verify our theoretical predictions. We will look at a model running on 8 TPU v2 cores with 4-way data parallelism and 2-way model parallelism.

The FFN Block

Zooming into the FFN (feed-forward network) block, we see the up-projection matmul. The HLO tells us the per-shard shapes:

// Per-shard computation:
bf16[8, 1024, 8192] * bf16[8192, 16384] -> bf16[8, 1024, 16384]

// True (unsharded) shapes (4-way DP, 2-way MP):
X: bf16[32, 1024, 8192] * Win: bf16[8192, 32768]
  -> Tmp: bf16[32, 1024, 32768]

How long should this take? Our batch size per DP shard is 8 * 1024 = 8192 tokens, so we are solidly compute-bound.

$T_{\text{expected}} = \frac{2 \times 32 \times 1024 \times 8192 \times 32768}{23 \times 10^{12} \times 8} \approx 95.6\ \text{ms}$

The profile reports 96 ms. That is essentially perfect FLOPs utilization — we are hitting the roofline.

This is the power of profiling + roofline analysis. We predicted 95.6 ms from first principles. The hardware delivered 96 ms. When these numbers match, you know there is nothing left to optimize for that particular op.
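
The prediction is short enough to reproduce in a few lines. This sketch computes it two ways, from the unsharded shapes spread over all 8 cores and from the per-shard shapes on one core, and the two should agree. The 23e12 FLOPs/s per-core figure is the one used in the text; the helper name is hypothetical.

```python
def matmul_flops(rows, contracting, cols):
  """FLOPs for [rows, contracting] @ [contracting, cols]: 2 per multiply-add."""
  return 2 * rows * contracting * cols

PEAK_FLOPS_PER_CORE = 23e12   # FLOPs/s per core, the figure used in the text
NUM_CORES = 8

# Unsharded up-projection: [32 * 1024, 8192] @ [8192, 32768] over all cores
t_expected = matmul_flops(32 * 1024, 8192, 32768) / (PEAK_FLOPS_PER_CORE * NUM_CORES)

# Per-shard shapes from the HLO: [8 * 1024, 8192] @ [8192, 16384] on one core
t_per_shard = matmul_flops(8 * 1024, 8192, 16384) / PEAK_FLOPS_PER_CORE

print(f"{t_expected * 1e3:.1f} ms (unsharded), {t_per_shard * 1e3:.1f} ms (per shard)")
```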

Communication: The ReduceScatter

At the end of the second matmul, there is a small fusion that is actually a ReduceScatter:

%fusion.1 = bf16[8,1024,4096]{2,1,0:T(8,128)(2,1)S(1)}
  fusion(bf16[8,1024,8192]{...} %fusion.31),
  kind=kCustom, calls=%all-reduce-scatter.1

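Expected time: the full array is 2 × 32 × 1024 × 8192 bytes (512 MiB), but the batch axis is sharded 4 ways, so each shard holds 2 × 8 × 1024 × 8192 bytes ≈ 134 MB (128 MiB). We are on a TPU v2 4x2 with $1.2 \times 10^{11}$ bytes/s of bidirectional bandwidth, and one hop is required:

$T_{\text{expected}} = \frac{134 \times 10^{6}\ \text{bytes}}{1.2 \times 10^{11}\ \text{bytes/s}} \approx 1.1\ \text{ms}$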

Profile reports: 1.13 ms. Again, essentially at the roofline.
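
The communication estimate can be checked the same way. The sketch below takes the per-shard input shape from the HLO, bf16[8, 1024, 8192], as the payload and divides by the bandwidth figure quoted in the text. Treating the transfer as a single hop at full bandwidth is the simplification made here, not a general model of a ReduceScatter.

```python
BYTES_PER_BF16 = 2
ICI_BANDWIDTH = 1.2e11   # bytes/s bidirectional, the figure quoted in the text

# Per-shard input to the ReduceScatter, from the HLO: bf16[8, 1024, 8192]
shard_bytes = BYTES_PER_BF16 * 8 * 1024 * 8192
t_expected = shard_bytes / ICI_BANDWIDTH

print(f"{shard_bytes / 2**20:.0f} MiB per shard, {t_expected * 1e3:.2f} ms expected")
```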

Attention

The attention component shows the Q, K, V projections as separate matmuls. The Q projection uses a weight matrix W_Q of shape [d_model = 8192, n_heads = 32, d_qkv = 256], with Megatron sharding along the head dimension.

Let us do the exercise. The Q projection has true shape:

X: bf16[32, 1024, 8192] × W_Q: bf16[8192, 32, 256] → Q: bf16[32, 1024, 32, 256]

With Megatron sharding along the 32 heads (2-way MP), each shard computes 16 heads. The FLOPs per shard:

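$2 \times 8 \times 1024 \times 8192 \times 16 \times 256 = 2^{39} \approx 5.5 \times 10^{11}$ FLOPs per shard

At 23 TFLOPS per core, with one shard per core: $T \approx 5.5 \times 10^{11} / (23 \times 10^{12}) \approx 23.9$ ms. This is a quarter of the up-projection time, as expected, since the per-shard output width is 16 × 256 = 4096 rather than 16384. If the profile shows a similar number, we are at the roofline. If much longer, check for retiling or communication overhead.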

Verifying Communication Predictions

The workflow for any operation in a profile:

1. Read the HLO op name and shapes. Identify the true (unsharded) shapes and sharding strategy.
2. Calculate the roofline prediction: T = max(FLOPs / peak_FLOPs, bytes / bandwidth).
3. Compare against the profile's reported duration. Within 10%? Great. Much worse? Investigate.
4. If far from roofline, check for retiling copies, unexpected collectives, fusion failures, and memory bandwidth limits.
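
Steps 2 and 3 of this workflow can be sketched as two small helpers. The function names and the 10% default tolerance are illustrative choices taken from the wording above, and the example numbers are the ones used earlier in the chapter.

```python
def roofline_time(flops, bytes_moved, peak_flops, bandwidth):
  """Lower bound on runtime: the slower of compute and data movement."""
  return max(flops / peak_flops, bytes_moved / bandwidth)

def verdict(measured_s, predicted_s, tolerance=0.10):
  """Compare a profiled duration against the roofline prediction."""
  ratio = measured_s / predicted_s
  if ratio <= 1 + tolerance:
    return "at roofline"
  return f"investigate ({ratio:.1f}x roofline)"

print(verdict(0.096, 0.0956))   # the FFN matmul from the worked profile
print(verdict(0.250, 0.096))    # the "something is wrong" case from Chapter 0
```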
Check: If a profiled matmul takes 96 ms and our roofline prediction was 95.6 ms, what does this tell us?

Chapter 7: Memory Profile

The Memory Profile view shows program memory as a function of time. This is invaluable for debugging out-of-memory (OOM) errors — the most common failure mode when scaling models.

In a typical profile, you will see:

Memory Region | Typical Contents | Behavior Over Time
Parameters | Model weights | Constant; allocated at initialization
Optimizer states | Adam moments, gradient buffers | Constant after first step
Activations | Intermediate tensors from forward pass | Grows during forward, freed during backward
Temporaries | Buffers for collectives, retiling | Short-lived spikes
Peak memory is what matters. OOMs happen at peak memory, which typically occurs at the boundary between the forward pass and backward pass — where all activations are still live and the first gradients are being computed.

The Memory Profile also helps answer practical questions, such as where in the step the peak occurs and how far it sits above your estimate.

The Memory Viewer (a separate tab) provides a more granular view, showing individual buffer allocations. This is useful for identifying which specific tensors are consuming the most memory.

Common pitfall: XLA may allocate extra memory for retiling (layout conversion) buffers. If you see mysterious memory consumption, check the HLO for copy operations that convert between different tilings.

Practical Memory Debugging Workflow

When you hit an OOM, follow this systematic approach:

  1. Estimate theoretical memory: Calculate expected memory for weights (2 bytes/param for bf16), optimizer states (8 bytes/param for Adam), and activations (depends on batch size and sequence length)
  2. Capture a profile just before the OOM: Reduce batch size or sequence length until training succeeds, then profile
  3. Compare Memory Profile peak to your estimate: If the peak is much higher, the difference is overhead from tiling, temporaries, or XLA's memory allocator
  4. Use Memory Viewer for the culprits: Sort by buffer size to find the largest allocations
  5. Check for activation recomputation: If enabled, verify the sawtooth pattern in the Memory Profile

Memory Estimation Formulas

Component | Formula (bf16 weights, Adam) | Example: 7B model
Weights | 2 × params | 14 GB
Gradients | 2 × params | 14 GB
Adam m (first moment) | 4 × params (fp32) | 28 GB
Adam v (second moment) | 4 × params (fp32) | 28 GB
Master weights (fp32) | 4 × params | 28 GB
Total model state | 16 × params | 112 GB

This 112 GB already exceeds a single H100's 80 GB HBM — before any activations. This is why sharding (ZeRO, TP, PP) is essential for models ≥ 7B.
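
The table translates directly into a small estimator. The sketch below hard-codes the bytes-per-parameter figures from the table and takes 1 GB as 1e9 bytes; the function name is hypothetical, and activations and temporaries are deliberately left out.

```python
BYTES_PER_PARAM = {
  "weights (bf16)": 2,
  "gradients (bf16)": 2,
  "Adam m (fp32)": 4,
  "Adam v (fp32)": 4,
  "master weights (fp32)": 4,
}

def model_state_gb(num_params):
  """Per-component and total model-state memory in GB (1 GB = 1e9 bytes)."""
  sizes = {name: b * num_params / 1e9 for name, b in BYTES_PER_PARAM.items()}
  sizes["total"] = sum(sizes.values())
  return sizes

for name, gb in model_state_gb(7e9).items():
  print(f"{name:24s} {gb:6.0f} GB")
```

Dividing the total by the number of devices the state is sharded over gives a first estimate of the per-device footprint to compare against the Memory Profile.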

Check: When does peak memory typically occur during a training step?

Chapter 8: Worked Problems

Let us put everything together with some worked problems that test your ability to read profiles and reason about what is happening.

Problem 1: Mystery Profile

You see a profile with two big fusions, a reduce, and an AllReduce. One fusion has this HLO:

%fusion.1 = bf16[4096]{0:T(1024)(128)(2,1)}
  fusion(
    bf16[4096,8192]{1,0:T(8,128)(2,1)} %param.1,
    bf16[8192]{0:T(1024)(128)(2,1)} %reduce.6
  ), kind=kLoop, calls=%fused_computation.1

The final AllReduce has replica_groups={{0,16,32,48,64,80,96,112}, ...}.

Analysis: The per-shard computation is bf16[8192] * bf16[4096, 8192] → bf16[4096] (contracting over the 8192 dimension). The replica groups show 8-way model parallelism (stride of 16 in a 128-chip setup), so the true shapes are:
bf16[8, 8192] × bf16[32768, 8192] → bf16[8, 32768]

This is two chained matrix multiplications with 8-way MP sharding:

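import jax.numpy as jnp

def matmul(w1, w2, x):
  # Inner einsum: [F, W] x [B, W] -> [B, F]; outer: [W, F] x [B, F] -> [B, W]
  return jnp.einsum('wf,bf->bw', w2,
                     jnp.einsum('fw,bw->bf', w1, x))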
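
A quick way to sanity-check this reconstruction is to push the unsharded shapes through the function abstractly. The sketch below uses jax.eval_shape, which traces without allocating the arrays. Assigning 8192 to the f axis and 32768 to the w axis is my reading of the analysis above, and the abstract helper is just a local convenience.

```python
import jax
import jax.numpy as jnp

def matmul(w1, w2, x):
  return jnp.einsum('wf,bf->bw', w2,
                     jnp.einsum('fw,bw->bf', w1, x))

B, F, W = 8, 8192, 32768

def abstract(*shape):
  """Shape-and-dtype placeholder; no device memory is allocated."""
  return jax.ShapeDtypeStruct(shape, jnp.bfloat16)

# Traces the function on placeholders only, so the full-size shapes are cheap.
out = jax.eval_shape(matmul, abstract(F, W), abstract(W, F), abstract(B, W))
print(out.shape, out.dtype)
```

The result matches the bf16[8, 32768] output that the analysis derives for the unsharded computation.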

Problem 2: Fixing Sharding

You profile a Transformer and find an unexpectedly large AllGather taking 80% of the step time. This is a sign that XLA's automatic partitioner (GSPMD/Shardy) chose a poor sharding for some intermediate tensor.

The fix: use jax.lax.with_sharding_constraint to guide the compiler:

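import jax
import jax.numpy as jnp
from jax.sharding import PartitionSpec as P

def ffn(x, w_up, w_down):
  h = jnp.einsum('bd,df->bf', x, w_up)
  # Force h to stay sharded (batch along 'X', F along 'Y') rather than gathered.
  # A bare PartitionSpec assumes a mesh with axes 'X' and 'Y' is in context.
  h = jax.lax.with_sharding_constraint(h, P('X', 'Y'))
  h = jax.nn.gelu(h)
  return jnp.einsum('bf,df->bd', h, w_down)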
The profiling workflow:
1. Run with profiling enabled
2. Check the Trace Viewer for unexpected ops or timing
3. Compare measured times against roofline predictions
4. When something is wrong, read the HLO to understand what XLA is doing
5. Use with_sharding_constraint or shard_map to fix it
6. Re-profile and verify the fix
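
For step 5, here is a self-contained sketch of the constraint applied with an explicit mesh, so it runs on whatever devices are available, including a single CPU. The mesh layout (all devices along 'X', a size-1 'Y' axis) and the toy shapes are hypothetical; a real model would use the mesh it was sharded with.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Hypothetical mesh: all local devices along 'X', a trivial 'Y' axis of size 1.
devices = np.array(jax.devices()).reshape(-1, 1)
mesh = Mesh(devices, axis_names=("X", "Y"))

@jax.jit
def ffn(x, w_up, w_down):
  h = jnp.einsum("bd,df->bf", x, w_up)
  # Pin the intermediate: batch split over 'X', hidden dim F split over 'Y'.
  h = jax.lax.with_sharding_constraint(h, NamedSharding(mesh, P("X", "Y")))
  h = jax.nn.gelu(h)
  return jnp.einsum("bf,df->bd", h, w_down)

batch = 4 * mesh.shape["X"]   # keep the batch divisible by the 'X' axis size
x = jnp.ones((batch, 32))
out = ffn(x, jnp.ones((32, 128)), jnp.ones((32, 128)))
print(out.shape, out.sharding)
```

After adding a constraint like this, re-profile and confirm that the large AllGather has actually disappeared from the trace.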

Problem 3: Common Performance Anti-Patterns

Here are real-world patterns you will encounter and how to diagnose them from a profile:

Symptom | Profile Appearance | Root Cause | Fix
Step time 3x expected | Giant AllGather taking 70% of step | XLA gathered full weight matrix instead of using TP | Add with_sharding_constraint to intermediate activations
Matmul 2x slower than expected | Copy op before the dot | Input tiling mismatch causing retile | Use AUTO input layouts or restructure computation
OOM even though the memory estimate says it should fit | Memory Profile shows 1.5x expected peak | Tiling padding + XLA temporaries | Reduce batch size, enable activation recomputation, or increase the sharding degree
TPU idle 30% of the time | Gaps between ops in timeline | Python-side data loading too slow | Prefetch data, use async data loaders, profile host side

Key Profiling Checklist

What to Check | What to Expect | Red Flag
Matmul duration | Within 5-10% of roofline | More than 2x roofline → wrong sharding or layout
Collective duration | Close to B/W prediction | Much longer → check for unexpected AllGathers
Gaps between ops | Minimal idle time | Large gaps → pipeline bubbles or stalls
Copy ops | Rare and small | Many copies → tiling mismatches, use AUTO layout
Peak memory | Close to theoretical estimate | Much higher → check for retained activations
The profiling loop is never done. Even after fixing the obvious issues, re-profile. Each fix may reveal new bottlenecks that were previously hidden behind larger ones. The best TPU/GPU engineers profile obsessively — every change gets measured.
Check: If you see an unexpectedly large AllGather dominating your profile, what is the most likely cause and fix?