From roofline models to 5D parallelism: a complete guide to training and serving LLMs on TPUs and GPUs, with profiling, JAX programming, and real benchmarks.
Arithmetic intensity, compute-bound vs. memory-bound regimes, and back-of-the-envelope performance modeling.
TPU architecture: MXU, VPU, HBM, VMEM, ICI links, and chip-to-chip topologies.
1D and 2D weight sharding, collective communication, and the GSPMD framework.
Attention, MLPs, and full Transformer roofline analysis across sharding strategies.
Optimizer states, activation memory, gradient checkpointing, and data parallelism.
End-to-end analysis of training a real LLM: memory budget, sharding plan, throughput.
Prefill vs decode, KV cache, batched inference, and memory-bound decode analysis.
Serving a real LLM at scale: throughput, latency, and optimizing for production.
The JAX profiler, TensorBoard traces, reading HLO ops, memory profiles, worked examples.
Auto, explicit, and manual sharding modes, plus shard_map, collective matmuls, and worked problems.