Austin et al., Part 2

How to Think About TPUs

A TPU is a matrix-multiply machine bolted to fast memory. Understanding its memory hierarchy and networking is the key to writing efficient distributed code.

Prerequisites: Chapter 1 (Roofline Analysis). Familiarity with matrix multiplication and bandwidth concepts.
9 chapters, 2 simulations, 9 quizzes

Chapter 0: What Is a TPU?

Strip away the marketing and a TPU is remarkably simple. It is a compute core that specializes in matrix multiplication (called a TensorCore) attached to a stack of fast memory (called HBM).

The TensorCore has three key units:

MXU (Matrix Multiply Unit)
The heart of the TPU. A 128×128 systolic array that performs one bf16[8,128] × bf16[128,128] → f32[8,128] matmul every 8 cycles.

VPU (Vector Processing Unit)
Handles element-wise operations: ReLU, addition, reductions. An 8×128 SIMD array, much slower than the MXU.

VMEM (Vector Memory)
On-chip scratchpad, ~128 MiB on TPU v5e. Much smaller than HBM, but with much higher bandwidth to the MXU. Data must pass through VMEM before reaching the compute units.

Key mental model: TPUs load weights from HBM into VMEM, then from VMEM into a systolic array that delivers roughly 2e14 bf16 FLOPs per second on TPU v5e (about 100 trillion multiply-adds, at 2 FLOPs each). The bandwidths of HBM↔VMEM and VMEM↔MXU set the fundamental limits.

Why is matmul so special? Because it uses O(n³) compute for O(n²) bytes. As n grows, the time spent on arithmetic grows faster than the time spent moving data, so a large enough matmul becomes compute-bound rather than bandwidth-bound. No other common operation has this property, which is why architectures dominated by matmul are so amenable to scaling.
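To make the O(n³) versus O(n²) argument concrete, here is a minimal sketch that computes FLOPs per byte for a square bf16 matmul and, for contrast, for an element-wise add. The function names are illustrative, and the byte count assumes each input is read once and the output written once.

```python
def matmul_intensity(n: int, bytes_per_element: int = 2) -> float:
    """FLOPs per byte for an [n, n] x [n, n] matmul (bf16 by default)."""
    flops = 2 * n**3                             # n^3 multiply-adds, 2 FLOPs each
    bytes_moved = 3 * n**2 * bytes_per_element   # read two inputs, write one output
    return flops / bytes_moved


def elementwise_add_intensity(bytes_per_element: int = 2) -> float:
    """FLOPs per byte for an element-wise add; independent of array size."""
    # One FLOP per output element, with two reads and one write per element.
    return 1 / (3 * bytes_per_element)


for n in (128, 1024, 8192):
    print(f"n={n}: matmul intensity = {matmul_intensity(n):.1f} FLOPs/byte")
print(f"element-wise add intensity = {elementwise_add_intensity():.3f} FLOPs/byte")
```

Under these assumptions the matmul intensity is n/3, so it grows linearly with the matrix size, while the element-wise intensity stays fixed no matter how large the arrays are.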

A TPU TensorCore has three main units. Which one does the heavy lifting for matrix multiplication?

Chapter 1: The MXU & Systolic Arrays

At the core of the MXU is a 128×128 systolic array (256×256 on TPU v6e). That is 16,384 ALUs, each capable of a multiply-and-add per cycle.

How it works

Imagine a grid of 128×128 processing elements (PEs). Weights are loaded from the top (the "RHS"), filling the array diagonally. Activations are fed from the left (the "LHS"), also diagonally. Each PE multiplies its activation with its weight, adds the result to the partial sum passed from above, and passes the new partial sum down.

[Interactive animation: weights (blue) and activations (green) flowing through the systolic array, starting at cycle 0 while the weights load.]

After the initial pipeline bubble (while weights load diagonally), each subsequent cycle produces a valid output element. New inputs and weights can be streamed in without additional bubbles — the array stays fully saturated.

Performance: When fully saturated, the systolic array performs one bf16[8,128] × bf16[128,128] → f32[8,128] multiply every 8 cycles, which works out to 2 × 128 × 128 FLOPs per cycle. At 1.5 GHz on TPU v5e, that is about 5e13 bf16 FLOPs/s per MXU. Most TensorCores have 2 or 4 MXUs; the TPU v5e total of about 2e14 bf16 FLOPs/s per chip corresponds to 4 MXUs.
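The per-MXU figure can be reproduced from the array size and clock rate. The sketch below is only arithmetic on the numbers quoted above; the count of 4 MXUs per v5e chip is an assumption inferred from the ratio 2e14 / 5e13, not a value stated directly in the text.

```python
MXU_DIM = 128        # systolic array is 128 x 128 on TPU v5e
CLOCK_HZ = 1.5e9     # clock rate quoted in the text
MXUS_PER_CHIP = 4    # assumption: inferred from 2e14 / 5e13


def mxu_flops_per_second(dim: int = MXU_DIM, clock_hz: float = CLOCK_HZ) -> float:
    """Peak bf16 FLOPs/s of one saturated MXU."""
    # One bf16[8, dim] x bf16[dim, dim] matmul is 2 * 8 * dim * dim FLOPs
    # and completes every 8 cycles once the array is saturated.
    flops_per_matmul = 2 * 8 * dim * dim
    cycles_per_matmul = 8
    return flops_per_matmul / cycles_per_matmul * clock_hz


per_mxu = mxu_flops_per_second()
per_chip = MXUS_PER_CHIP * per_mxu
print(f"per MXU:  {per_mxu:.3e} FLOPs/s")
print(f"per chip: {per_chip:.3e} FLOPs/s")
```

The result lands close to the 1.97e14 FLOPs/s listed for TPU v5e in the spec table.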

Padding requirement: Because the systolic array is 128×128, weight matrices must be padded to at least size 128 in both dimensions. Matrices smaller than 128 waste ALUs. On TPU v6e (256×256 MXU), the minimum is 256.
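A rough way to see the cost of small matrices is to compare the real weight elements with the padded tile they occupy. The helper names below are hypothetical, and rounding each dimension up to a multiple of the array size is an assumption about how padding is applied.

```python
import math


def padded_size(dim: int, tile: int = 128) -> int:
    """Round dim up to the next multiple of the systolic array size."""
    return math.ceil(dim / tile) * tile


def mxu_utilization(rows: int, cols: int, tile: int = 128) -> float:
    """Fraction of the padded weight tile(s) that holds real data."""
    return (rows * cols) / (padded_size(rows, tile) * padded_size(cols, tile))


print(mxu_utilization(64, 64))             # 64x64 weights on a 128x128 array
print(mxu_utilization(64, 64, tile=256))   # same weights on a 256x256 array
print(mxu_utilization(200, 128))           # rows padded from 200 up to 256
```

Under this model a 64×64 weight matrix occupies only a quarter of a 128×128 array, and a sixteenth of a 256×256 one.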

Lower precision means higher throughput. TPUs can do int8 OPs roughly 2x faster than bf16 FLOPs, and int4 at 4x.

The TPU v5e systolic array is 128×128. If your weight matrix is 64×64, what happens?

Chapter 2: VPU & VMEM

The VPU

The VPU (Vector Processing Unit) handles everything that is not a matmul: activations like ReLU/GELU, element-wise operations, reductions (sums). It is an 8×128 SIMD unit where each (lane, sublane) pair contains 4 independent ALUs.

At 1.75 GHz on TPU v5p, the VPU achieves:

8 × 128 × 4 ALUs × 1.75e9 Hz = 7e12 FLOPs/s per core

That is about 30x slower than the MXU (~2e14 per core). This is why we try to express as much of the computation as matmul.
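The VPU estimate is the same kind of arithmetic. This sketch uses only the figures quoted above; the variable names are illustrative, and the per-core MXU number is the approximate 2e14 from the text.

```python
SUBLANES = 8
LANES = 128
ALUS_PER_ELEMENT = 4          # independent ALUs per (lane, sublane) pair
VPU_CLOCK_HZ = 1.75e9         # TPU v5p clock rate from the text
MXU_FLOPS_PER_CORE = 2e14     # approximate per-core MXU figure from the text

# One op per ALU per cycle.
vpu_flops = SUBLANES * LANES * ALUS_PER_ELEMENT * VPU_CLOCK_HZ
mxu_to_vpu_ratio = MXU_FLOPS_PER_CORE / vpu_flops

print(f"VPU: {vpu_flops:.3e} FLOPs/s per core")
print(f"MXU is about {mxu_to_vpu_ratio:.0f}x faster")
```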

VMEM

VMEM (Vector Memory) is the on-chip scratchpad, sitting between HBM and the compute units. Think of it as a programmer-controlled L1/L2 cache — but much larger (128 MiB on TPU v5e).

| Memory | Capacity (TPU v5e) | Bandwidth | Role |
|---|---|---|---|
| HBM | 16 GiB | 8.2e11 B/s | Main storage for all tensors |
| VMEM | 128 MiB | ~22x HBM ≈ 1.8e13 B/s | Scratchpad; data must pass through here |
| VREGs | ~256 KiB/core | Cycle-speed | Registers for VPU/MXU input/output |

VMEM and arithmetic intensity: If you can fit your weights in VMEM instead of HBM, the effective bandwidth jumps ~22x. This drops the critical intensity from ~240 down to ~11 (240 / 22 ≈ 10.9). Operations that would be bandwidth-bound from HBM can become compute-bound from VMEM. This is the basis of "VMEM prefetching": loading the next layer's weights into VMEM during the current layer's compute.

For example, during attention we might prefetch the large FFN weights into VMEM. If the weights fit (or are sharded small enough), the following FFN matmul runs at much higher efficiency.
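The shift in critical intensity can be checked with a few lines of arithmetic. This is a sketch using the TPU v5e figures quoted above; the operation with an intensity of 64 FLOPs/byte is a made-up example, not a measurement.

```python
PEAK_FLOPS = 1.97e14    # TPU v5e bf16 FLOPs/s
HBM_BW = 8.2e11         # TPU v5e HBM bandwidth, bytes/s
VMEM_SPEEDUP = 22       # VMEM bandwidth relative to HBM, from the text


def critical_intensity(peak_flops: float, bandwidth: float) -> float:
    """FLOPs per byte above which an operation is compute-bound."""
    return peak_flops / bandwidth


def is_compute_bound(intensity: float, peak_flops: float, bandwidth: float) -> bool:
    return intensity >= critical_intensity(peak_flops, bandwidth)


hbm_ci = critical_intensity(PEAK_FLOPS, HBM_BW)
vmem_ci = critical_intensity(PEAK_FLOPS, HBM_BW * VMEM_SPEEDUP)
print(f"HBM critical intensity:  {hbm_ci:.1f} FLOPs/byte")
print(f"VMEM critical intensity: {vmem_ci:.1f} FLOPs/byte")

# Hypothetical operation with an arithmetic intensity of 64 FLOPs/byte.
op_intensity = 64
print(is_compute_bound(op_intensity, PEAK_FLOPS, HBM_BW))                  # from HBM
print(is_compute_bound(op_intensity, PEAK_FLOPS, HBM_BW * VMEM_SPEEDUP))   # from VMEM
```

The hypothetical operation is bandwidth-bound when its operands stream from HBM and compute-bound when they are already resident in VMEM.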

VMEM bandwidth is ~22x higher than HBM bandwidth. If the HBM-based critical intensity is 240, what is the VMEM-based critical intensity?

Chapter 3: HBM & Pipelining

HBM (High Bandwidth Memory) is the big chunk of fast memory that stores all tensors. Capacity is typically tens of GiB (16 GiB on TPU v5e, 96 GiB on TPU v5p).

The data flow for a matmul X · A → Y looks like this:

1. HBM → VMEM: Stream chunks of X and A from HBM into the VMEM scratchpad.
2. VMEM → VREGs → MXU: Load chunks into registers and feed them into the systolic array.
3. MXU → VREGs → VMEM → HBM: Results flow back through the same path to HBM.

Pipelining is everything: All these steps are pipelined and overlapped. While the MXU processes chunk k, chunk k+1 is being loaded from HBM to VMEM, and chunk k-1's result is being written back. This overlap is what keeps the MXU saturated and matmuls compute-bound.

Element-wise operations (VPU) follow the same pipeline: data streams from HBM → VMEM → VREGs → VPU → VREGs → VMEM → HBM, with partial results pipelined without waiting for the full array.
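The benefit of overlapping load, compute, and write-back can be illustrated with an idealized three-stage pipeline model. This is a simplified sketch, not a description of the actual hardware scheduler: the stage times are hypothetical, every chunk is assumed to take the same time in each stage, and buffering is assumed to be sufficient.

```python
def serial_time(n_chunks: int, t_load: int, t_compute: int, t_store: int) -> int:
    """Total time if each chunk is loaded, computed, and stored one after another."""
    return n_chunks * (t_load + t_compute + t_store)


def pipelined_time(n_chunks: int, t_load: int, t_compute: int, t_store: int) -> int:
    """Total time if the three stages overlap across consecutive chunks."""
    fill = t_load + t_compute + t_store          # first chunk passes through all stages
    bottleneck = max(t_load, t_compute, t_store)  # steady state is paced by the slowest stage
    return fill + (n_chunks - 1) * bottleneck


# Hypothetical per-chunk stage times in nanoseconds.
N, LOAD, COMPUTE, STORE = 100, 800, 1000, 200
print(serial_time(N, LOAD, COMPUTE, STORE))      # no overlap
print(pipelined_time(N, LOAD, COMPUTE, STORE))   # overlapped
```

In this model, as long as the load and store stages are no slower than the compute stage, the pipelined total approaches the pure compute time, which is what it means for the matmul to stay compute-bound.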

PCIe: the slow link

The CPU host connects to its TPU tray via PCIe, which is about 1.6e10 bytes/s per TPU (3.2e10 on v6e). That is on the order of 100x slower than HBM bandwidth (about 50x for v5e at 8.2e11 bytes/s, about 175x for v5p at 2.8e12 bytes/s). Loading data from host RAM to HBM is a bottleneck best avoided.

Why does pipelining HBM-to-VMEM transfers with MXU compute keep matmuls compute-bound?

Chapter 4: Chips, Trays & Hosts

A TPU chip typically has 2 TensorCores that share memory and act as one accelerator (called "megacore" configuration since TPU v4). Exception: inference chips like TPU v5e have just 1 core per chip.

Chips sit on trays. A tray holds 4 chips connected to a single CPU host via PCIe. For TPU v5e, each host has 2 trays (8 chips = 8 cores). For training chips like v5p, each host has a 2×2×1 topology of 4 chips.

| Level | What | Example (v5e) |
|---|---|---|
| Core | 1 TensorCore (MXU + VPU + VMEM) | 1 per chip |
| Chip | 1–2 cores + HBM | 16 GiB HBM, 1.97e14 FLOPs/s |
| Tray | 4 chips | Connected via ICI |
| Host | CPU + 1–2 trays via PCIe | 8 chips per host |
| Slice | ICI-connected chips | Up to 16×16 = 256 chips |
| Pod/Superpod | Maximum ICI topology | v5p: 16×20×28 = 8960 chips |

GPU comparison: NVIDIA GPUs within a node (8 H100s, or up to 72 B200s in an NVL72) are connected via NVLink switches that approximate point-to-point connections. TPUs instead use nearest-neighbor ICI links, which are cheaper and scale to larger topologies but require data to hop through intermediate chips.

A TPU v5p superpod (16×20×28) has how many chips total?

Chapter 5: ICI Networking

ICI (Inter-Chip Interconnect) is the direct chip-to-chip link that forms the TPU's communication fabric. It does NOT go through the CPU host.

Topology

TPU v5e and v6e use a 2D torus (4 nearest neighbours per chip). TPU v4 and v5p use a 3D torus (6 nearest neighbours per chip). Along an axis of N chips, the toroidal wraparound reduces the maximum distance between any two chips from N − 1 hops to N/2.

Wraparound rules: Full-size axes (16 for v5e/v6e, or multiples of 4 on v*p with optical switches) get wraparound links. Smaller topologies (like 4×4 on v5e) do NOT get wraparounds, which doubles communication time for ring-based collectives.
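The effect of wraparound links on distance can be sketched with a small hop counter. The function names are illustrative, and the model assumes traffic takes a shortest path along each axis independently.

```python
def axis_hops(a: int, b: int, size: int, wraparound: bool) -> int:
    """Hops between positions a and b along one axis of the given size."""
    direct = abs(a - b)
    # With a wraparound link, going the other way around the ring may be shorter.
    return min(direct, size - direct) if wraparound else direct


def hops(src, dst, shape, wraparound: bool) -> int:
    """Total hops between two chips, summing the distance along each axis."""
    return sum(axis_hops(a, b, n, wraparound) for a, b, n in zip(src, dst, shape))


def max_axis_hops(size: int, wraparound: bool) -> int:
    """Worst-case hop count along a single axis."""
    return max(axis_hops(0, b, size, wraparound) for b in range(size))


print(hops((0, 0), (3, 3), (4, 4), wraparound=False))      # 4x4 slice, no wraparounds
print(max_axis_hops(16, wraparound=False))                 # 16-chip axis as a line
print(max_axis_hops(16, wraparound=True))                  # 16-chip axis as a ring
print(hops((0, 0), (15, 15), (16, 16), wraparound=True))   # opposite corners of a torus
```

On a full 16×16 torus, opposite corners are only one hop apart per axis thanks to the wraparound links.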

The speed hierarchy

| Link | Bandwidth (per chip) | Relative Speed |
|---|---|---|
| HBM ↔ TensorCore | ~1–3 TB/s | Fastest |
| ICI (per axis) | 45–90 GB/s unidirectional | ~10–30x slower than HBM |
| PCIe (host ↔ chip) | ~16 GB/s | ~100x slower than HBM |
| DCN (between hosts) | ~6 GB/s | Slowest |

Multi-slice training: ICI-connected chips form a "slice." Different slices connect via DCN (data-center networking), which is much slower. DCN goes host-to-host, requiring PCIe transfers on both ends. Minimizing DCN traffic is critical for multi-slice training.
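One way to get a feel for the hierarchy is to time the same payload over each link. The sketch below uses the per-chip TPU v5e bandwidths listed in this document (HBM 8.2e11, one-way ICI link 4.5e10, PCIe 1.6e10, DCN 3.125e9 bytes/s); the 1 GB payload is arbitrary, and latency and protocol overheads are ignored.

```python
# Per-chip TPU v5e bandwidths in bytes/s, as listed in this document.
LINK_BW = {
    "HBM": 8.2e11,
    "ICI": 4.5e10,    # one link, one direction
    "PCIe": 1.6e10,
    "DCN": 3.125e9,
}


def transfer_time_ms(n_bytes: float, link: str) -> float:
    """Bandwidth-only transfer time in milliseconds (latency ignored)."""
    return n_bytes / LINK_BW[link] * 1e3


PAYLOAD = 1e9  # 1 GB, an arbitrary example size
for link in LINK_BW:
    print(f"{link:>4}: {transfer_time_ms(PAYLOAD, link):8.2f} ms")
```

The same gigabyte takes on the order of a millisecond from HBM and hundreds of milliseconds over DCN.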

[Interactive TPU topology visualizer: solid lines show ICI links, dashed lines show wraparound links.]

On a TPU v5e 4×4 slice (no wraparounds), how many ICI hops separate chip (0,0) from chip (3,3)?

Chapter 6: TPU Specs

Here are the key numbers you need for roofline calculations. Memorize the order of magnitude — exact values change slightly between sources.

Compute & Memory

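| Model | Pod Size | HBM/chip | HBM BW/chip (bytes/s) | bf16 FLOPs/s/chip | int8 OPs/s/chip |
|---|---|---|---|---|---|
| v3 | 32×32 | 32 GB | 9.0e11 | 1.4e14 | 1.4e14 |
| v4p | 16×16×16 | 32 GB | 1.2e12 | 2.75e14 | 2.75e14 |
| v5p | 16×20×28 | 96 GB | 2.8e12 | 4.59e14 | 9.18e14 |
| v5e | 16×16 | 16 GB | 8.2e11 | 1.97e14 | 3.94e14 |
| v6e | 16×16 | 32 GB | 1.6e12 | 9.20e14 | 1.84e15 |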
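As an example of using these numbers in a roofline calculation, the sketch below derives each chip's HBM critical intensity (bf16 FLOPs/s divided by HBM bandwidth) from the compute and memory table. The dictionary layout and function name are illustrative choices.

```python
# (HBM bandwidth in bytes/s, bf16 FLOPs/s) per chip, from the table above.
SPECS = {
    "v3": (9.0e11, 1.4e14),
    "v4p": (1.2e12, 2.75e14),
    "v5p": (2.8e12, 4.59e14),
    "v5e": (8.2e11, 1.97e14),
    "v6e": (1.6e12, 9.20e14),
}


def hbm_critical_intensity(model: str) -> float:
    """bf16 FLOPs per HBM byte needed for an op to be compute-bound."""
    hbm_bw, flops = SPECS[model]
    return flops / hbm_bw


for model in SPECS:
    print(f"{model:>4}: {hbm_critical_intensity(model):6.1f} FLOPs/byte")
```

The v5e value comes out at about 240, matching the critical intensity used in the VMEM discussion.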

Interconnects

| Model | ICI BW/link, 1-way (bytes/s) | ICI BW/link, bidi (bytes/s) |
|---|---|---|
| v3 | 1.0e11 | 2.0e11 |
| v4p | 4.5e10 | 9.0e10 |
| v5p | 9.0e10 | 1.8e11 |
| v5e | 4.5e10 | 9.0e10 |
| v6e | 9.0e10 | 1.8e11 |

PCIe: ~1.6e10 bytes/s per TPU (3.2e10 for v6e).

DCN: ~6.25e9 bytes/s per TPU (12.5e9 for v6e, 3.125e9 for v5e).

Bidirectional bandwidth means total bytes that can flow along a link in both directions simultaneously. We use it when a full ring exists (wraparound links present). Each device sends V/N bytes in each direction during ring-based collectives.

What is the total bf16 FLOPs/s for a full TPU v5e pod (16×16 = 256 chips)?

Chapter 7: Worked Problems

Problem 1: Bounding LLM latency

You want to sample from a 200B parameter model in bf16 split across 32 TPU v4p. How long to load all parameters from HBM?

Total bytes = 2 × 200e9 = 400e9 bytes
Per chip = 400e9 / 32 = 12.5e9 bytes
Time = 12.5e9 / 1.2e12 ≈ 10.4 ms

This is a lower bound on sampling latency: each autoregressive step must load all parameters from HBM (at small batch sizes, the matmul is bandwidth-bound). Roughly 10 ms per token is close to achievable in practice at small batch sizes, and the same calculation tells you the minimum hardware needed for a given latency target.
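The same bound is easy to parametrize by chip count. This is a minimal sketch of the calculation above; the function name is illustrative, and it assumes the parameters are split evenly across chips with no other HBM traffic.

```python
def param_load_time_s(n_params: float, bytes_per_param: int,
                      n_chips: int, hbm_bw: float) -> float:
    """Time for every chip to read its shard of the parameters from HBM once."""
    bytes_per_chip = n_params * bytes_per_param / n_chips
    return bytes_per_chip / hbm_bw


V4P_HBM_BW = 1.2e12  # bytes/s per chip, from the spec table

for chips in (32, 64):
    t = param_load_time_s(200e9, 2, chips, V4P_HBM_BW)
    print(f"{chips} chips: {t * 1e3:.1f} ms per step")
```

Because the bound scales inversely with the number of chips, doubling the chip count halves it, provided the extra communication between chips does not become the new bottleneck.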

Problem 2: Full pod numbers

TPU v5e pod (16×16): 256 chips, 1 core each = 256 TensorCores. Hosts: 256/8 = 32 hosts. Total FLOPs/s: 256 × 1.97e14 = 5.04e16. Total HBM: 256 × 16 = 4 TB.

TPU v5p pod (16×20×28): 8960 chips, 2 cores each = 17,920 TensorCores. Hosts: 8960/4 = 2240 hosts. Total FLOPs/s: 8960 × 4.59e14 = 4.1e18. Total HBM: 8960 × 96 = 860 TB.
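These totals follow mechanically from the pod shape and per-chip specs. The helper below is a sketch with illustrative names; it uses decimal units (1 TB = 1000 GB), matching the rounding in the text.

```python
import math


def pod_totals(shape, cores_per_chip, chips_per_host, flops_per_chip, hbm_gb_per_chip):
    """Aggregate chip, core, host, FLOPs/s and HBM counts for a full pod."""
    chips = math.prod(shape)
    return {
        "chips": chips,
        "cores": chips * cores_per_chip,
        "hosts": chips // chips_per_host,
        "flops": chips * flops_per_chip,
        "hbm_tb": chips * hbm_gb_per_chip / 1000,
    }


v5e = pod_totals((16, 16), cores_per_chip=1, chips_per_host=8,
                 flops_per_chip=1.97e14, hbm_gb_per_chip=16)
v5p = pod_totals((16, 20, 28), cores_per_chip=2, chips_per_host=4,
                 flops_per_chip=4.59e14, hbm_gb_per_chip=96)
print(v5e)
print(v5p)
```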

Problem 3: PCIe operational intensity

Weights bf16[D, F] and activations bf16[B, D] stored in host DRAM. Multiplied on a single TPU v6e. With B ≪ D and F = 4D:

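FLOPs = 2BDF = 8BD²
Bytes over PCIe = 2(BD + DF + BF) ≈ 2DF = 8D²
Compute time = 8BD² / 9.2e14
PCIe time = 8D² / 3.2e10 (using the v6e PCIe bandwidth)
Compute > PCIe ⇒ B > 9.2e14 / 3.2e10 ≈ 28,750

You need a batch of ~28,750 tokens before computation outpaces PCIe loading on v6e. (With the 1.6e10 bytes/s PCIe figure of the other generations, the same compute rate would require B > 57,500.) This is why we keep tensors in HBM, not host RAM.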
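Since D cancels out, the break-even batch size depends only on the ratio of peak FLOPs/s to PCIe bandwidth. The sketch below evaluates it for both PCIe figures quoted in this document (1.6e10 bytes/s, and 3.2e10 bytes/s listed for v6e), under the same B ≪ D, F = 4D approximation; the function name is illustrative.

```python
def min_batch_for_compute_bound(peak_flops: float, pcie_bw: float) -> float:
    """Batch size above which compute time exceeds PCIe transfer time."""
    # T_math = 8 * B * D^2 / peak_flops and T_pcie = 8 * D^2 / pcie_bw,
    # so T_math > T_pcie exactly when B > peak_flops / pcie_bw.
    return peak_flops / pcie_bw


V6E_FLOPS = 9.2e14  # bf16 FLOPs/s, from the spec table

for pcie_bw in (1.6e10, 3.2e10):
    b = min_batch_for_compute_bound(V6E_FLOPS, pcie_bw)
    print(f"PCIe {pcie_bw:.1e} B/s -> B > {b:,.0f}")
```

Either way, the required batch is in the tens of thousands of tokens, far beyond what the HBM-based roofline demands.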

Problem 4: ICI transfer time

Send bf16[8, 128, 8192] from TPU{0,0} to TPU{3,3} on a v5e 4×4 slice (no wraparounds).

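Bytes = 2 × 8 × 128 × 8192 ≈ 16.8 MB
Hops = 3 + 3 = 6 (no wraparounds), latency ≈ 6 μs (assuming ~1 μs per hop)
Transfer = 16.8e6 / (2 × 4.5e10) ≈ 186 μs (the factor of 2 assumes the data is split across two outgoing links, one per axis, at 4.5e10 bytes/s each)

First byte arrives in ~6 μs. Full transfer completes in ~186 μs.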
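The calculation can be written out as a small helper. The per-hop latency of 1 μs and the use of two outgoing links are assumptions implied by the numbers in this problem rather than stated hardware specifications, and the function name is illustrative. With the exact byte count, the transfer time comes out at roughly 186 μs.

```python
import math


def ici_transfer(shape, bytes_per_element, n_hops, link_bw, n_links,
                 hop_latency_s=1e-6):
    """Return (bytes, first-byte latency in s, bandwidth-bound transfer time in s)."""
    n_bytes = bytes_per_element * math.prod(shape)
    first_byte_s = n_hops * hop_latency_s        # assumed ~1 microsecond per hop
    transfer_s = n_bytes / (n_links * link_bw)   # data split across n_links links
    return n_bytes, first_byte_s, transfer_s


n_bytes, first_byte_s, transfer_s = ici_transfer(
    shape=(8, 128, 8192),
    bytes_per_element=2,    # bf16
    n_hops=3 + 3,           # (0,0) -> (3,3) without wraparounds
    link_bw=4.5e10,         # v5e one-way ICI bandwidth per link, bytes/s
    n_links=2,              # assumption: one outgoing link per axis
)
print(f"{n_bytes / 1e6:.1f} MB, first byte {first_byte_s * 1e6:.0f} us, "
      f"transfer {transfer_s * 1e6:.0f} us")
```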

A 200B bf16 model on 32 TPU v4p chips takes ~10 ms to load. If we double to 64 chips, the load time becomes:

Chapter 8: Summary

| Concept | Key Takeaway |
|---|---|
| TPU architecture | MXU (systolic array for matmul) + VPU (element-wise) + VMEM (scratchpad) + HBM (main memory) |
| MXU | 128×128 systolic array (256×256 on v6e), ~2e14 bf16 FLOPs/s per chip on v5e |
| VMEM | ~22x faster than HBM but only ~128 MiB. Prefetching weights here enables better efficiency |
| Padding | Matrices must be ≥128 in both dims to fill the MXU |
| Speed hierarchy | HBM >> ICI >> PCIe >> DCN |
| TPU vs GPU | TPUs: nearest-neighbor ICI, cheaper, scales larger. GPUs: NVLink switches, richer connectivity per node |
| Multi-slice | Slices (ICI) connect via DCN (host-to-host). Minimize DCN traffic |

The mental model: A TPU is a matrix-multiply machine (MXU) connected to memory (HBM, fast), other chips (ICI, rather fast), and the datacenter (DCN, slow). Every optimization is about keeping the MXU fed with data fast enough that it never waits.

Which link would you most want to avoid putting on the critical path of a distributed matmul?