Stanford CS 231n · Lecture 11 · Large-Scale Distributed Training

Training Models Across Thousands of GPUs

Llama 3 has 405 billion parameters and needs about 4.86 TB of memory just for training state. A single GPU has 80 GB. How do you bridge a 60× gap? You distribute.

24,576 H100 GPUs · Data & model parallelism · Pipeline & tensor sharding · 4.86 TB → 80 GB

Chapter 01

Why Distributed? — The Scale Problem

Imagine you want to train Llama 3 405B — Meta's largest open-weight language model. It has 405 billion parameters. Let's do some back-of-the-envelope math to see why a single GPU is hopeless.

The Memory Wall

During training, every parameter needs four numbers stored alongside it:

Item | Per Parameter | Precision
--- | --- | ---
Weight | 1 number | BF16 (2 bytes)
Gradient | 1 number | BF16 (2 bytes)
Adam first moment (m) | 1 number | FP32 (4 bytes)
Adam second moment (v) | 1 number | FP32 (4 bytes)

That's 2 + 2 + 4 + 4 = 12 bytes per parameter for mixed-precision training. (Keeping a separate FP32 master copy of the weights, as many mixed-precision setups do, adds another 4 bytes, for about 16 bytes per parameter; this lesson uses the 12-byte figure.) For Llama 3:

Memory for Llama 3 405B Training: 405 × 10⁹ params × (2 + 2 + 4 + 4) bytes ≈ 4.86 TB
An H100 GPU has 80 GB. That's over 60 GPUs' worth, just for the model state.

And this doesn't include activations (the intermediate values saved for backpropagation), communication buffers, or the data itself.
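The arithmetic above is easy to reproduce. The following is a minimal sketch, assuming the 12-bytes-per-parameter accounting from the table; the names `BYTES_PER_PARAM` and `training_state_bytes` are illustrative, not from any library.

```python
# Back-of-the-envelope training-state memory, using 12 bytes per parameter.
BYTES_PER_PARAM = {
    "weight_bf16": 2,
    "grad_bf16": 2,
    "adam_m_fp32": 4,
    "adam_v_fp32": 4,
}


def training_state_bytes(num_params: int) -> int:
    """Total bytes for weights, gradients, and both Adam moments."""
    return num_params * sum(BYTES_PER_PARAM.values())


llama3_params = 405 * 10**9
total_bytes = training_state_bytes(llama3_params)
total_tb = total_bytes / 1e12        # decimal terabytes
gpus_needed = total_bytes / 80e9     # 80 GB of HBM per H100

print(f"{total_tb:.2f} TB, about {gpus_needed:.1f} H100s of memory")
```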

The Compute Wall

How many floating-point operations (FLOPs) does training require? A widely-used estimate for transformer training:

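Training FLOPs Estimate: FLOPs ≈ 6 × N × D
N = number of parameters, D = number of training tokens. The factor of 6 is per parameter per token: 2 FLOPs for the forward pass (one multiply and one add), plus 4 for the backward pass, which costs about twice as much as the forward pass.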
Worked Example — Llama 3 Training Compute

N = 405 × 10⁹, D = 15 × 10¹² tokens.

FLOPs = 6 × 405 × 10⁹ × 15 × 10¹² = 3.645 × 10²⁵ FLOPs.

An H100 does ~989 TFLOP/s (BF16 tensor core peak) = 989 × 10¹² FLOP/s.

On 1 GPU: 3.645 × 10²⁵ / (989 × 10¹²) = 3.69 × 10¹⁰ seconds ≈ 1,170 years.

On 24,576 GPUs: 1,170 years / 24,576 ≈ 17 days at 100% of peak throughput. At a realistic ~40% hardware utilization (MFU), that becomes 17 / 0.4 ≈ 43 days.

Meta actually trained Llama 3 for about 54 days, reflecting real-world overheads, restarts, and the fact that peak utilization is never sustained.
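Here is a small sketch of the same estimate in code, so the effect of utilization is explicit. The function names and the `mfu` parameter are illustrative choices; the inputs are the figures quoted in this section.

```python
def training_flops(num_params: float, num_tokens: float) -> float:
    """The 6 * N * D estimate for transformer training compute."""
    return 6.0 * num_params * num_tokens


def training_days(total_flops: float, num_gpus: int,
                  peak_flops_per_gpu: float, mfu: float) -> float:
    """Wall-clock days if every GPU sustains `mfu` of its peak throughput."""
    seconds = total_flops / (num_gpus * peak_flops_per_gpu * mfu)
    return seconds / 86_400


H100_PEAK = 989e12  # BF16 tensor-core peak, FLOP/s
flops = training_flops(405e9, 15e12)
days_at_peak = training_days(flops, 24_576, H100_PEAK, mfu=1.0)
days_at_40 = training_days(flops, 24_576, H100_PEAK, mfu=0.40)

print(f"{flops:.3e} FLOPs: {days_at_peak:.1f} days at peak, "
      f"{days_at_40:.1f} days at 40% MFU")
```

At 100% of peak the job takes a little over 17 days; at 40% utilization it takes a little over 43, which is much closer to the reported training time.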

The Core Challenge

Two walls block us: memory (the model doesn't fit on one GPU) and compute (training would take centuries). Distributed training solves both by spreading work across thousands of GPUs. The rest of this lesson is about how.

Model FLOPs Utilization (MFU)

No training job achieves 100% of peak GPU throughput. Communication, synchronization, memory copies, and pipeline bubbles all eat into useful compute. The Model FLOPs Utilization (MFU) measures what fraction of theoretical peak is spent on actual model arithmetic:

MFU MFU = (model FLOPs per step) / (peak FLOP/s × step time)
Llama 3 achieved ~38-43% MFU on 24K H100s. Even 40% is considered excellent at this scale.
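As a rough illustration of the MFU formula, the sketch below plugs in made-up step statistics. The 16M tokens per step and the 4-second step time are hypothetical values chosen only to land near 40%; they are not reported numbers.

```python
def mfu(model_flops_per_step: float, num_gpus: int,
        peak_flops_per_gpu: float, step_time_s: float) -> float:
    """Fraction of aggregate peak throughput spent on model arithmetic."""
    return model_flops_per_step / (num_gpus * peak_flops_per_gpu * step_time_s)


# Hypothetical step: 16M tokens through a 405B-parameter model in 4.0 seconds.
tokens_per_step = 16e6
flops_per_step = 6.0 * 405e9 * tokens_per_step
utilization = mfu(flops_per_step, num_gpus=24_576,
                  peak_flops_per_gpu=989e12, step_time_s=4.0)

print(f"MFU = {utilization:.1%}")
```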
Chapter 02

GPU Hardware — The Compute Engine

Before distributing work, you need to understand what a single GPU can do. A modern GPU isn't one big processor — it's thousands of small processors working in parallel, organized into a hierarchy.

Streaming Multiprocessors (SMs)

An NVIDIA H100 contains 132 Streaming Multiprocessors. Each SM is a self-contained compute unit with its own registers, shared memory, and two types of cores:

Definition
CUDA Cores

Scalar processing units that perform one FP32 multiply-add (a × x + b) per clock cycle. The H100 has 128 CUDA cores per SM × 132 SMs = 16,896 CUDA cores total. These handle general-purpose parallel computation.

Definition
Tensor Cores

Specialized matrix-multiply units. Each tensor core computes a [16×4] · [4×8] matrix product per cycle: 16 × 4 × 8 = 512 multiply-adds, or 1,024 FLOPs, in a single operation. The H100 has 4 tensor cores per SM × 132 SMs = 528 tensor cores. This is where the real throughput lives.

Tensor cores are why modern GPUs achieve such absurd throughput for matrix operations. A standard matmul decomposes into many small [16×4]·[4×8] tiles, and the tensor cores chew through them in parallel.

Mixed Precision Training

Tensor cores are optimized for half-precision formats: FP16 (IEEE float16) or BF16 (bfloat16). The trick is to do the forward and backward pass in BF16 for speed, but keep the optimizer states and a master copy of weights in FP32 for numerical stability.

Why BF16 Over FP16?

BF16 has the same exponent range as FP32 (8 exponent bits) but fewer mantissa bits (7 vs 23). This means it can represent the same range of magnitudes — no overflow or underflow surprises — at the cost of some precision. FP16 has only 5 exponent bits and can overflow during training. BF16 is now the standard for LLM training.
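The range difference follows directly from the bit layouts. This sketch computes the largest finite value of each format from its exponent and mantissa widths, assuming the usual IEEE-style encoding (all-ones exponent reserved for inf/NaN); `max_finite` is an illustrative helper.

```python
def max_finite(exponent_bits: int, mantissa_bits: int) -> float:
    """Largest finite value of an IEEE-style binary floating-point format."""
    bias = 2 ** (exponent_bits - 1) - 1
    # The all-ones exponent encodes inf/NaN, so the top usable one is 2^e - 2.
    max_exponent = (2 ** exponent_bits - 2) - bias
    return (2.0 - 2.0 ** -mantissa_bits) * 2.0 ** max_exponent


fp16_max = max_finite(exponent_bits=5, mantissa_bits=10)
bf16_max = max_finite(exponent_bits=8, mantissa_bits=7)
fp32_max = max_finite(exponent_bits=8, mantissa_bits=23)

print(f"FP16 max: {fp16_max:.0f}")
print(f"BF16 max: {bf16_max:.3e}")
print(f"FP32 max: {fp32_max:.3e}")
```

FP16 tops out at 65,504, so a large activation or gradient can overflow; BF16 reaches about the same magnitude as FP32.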

GPU Evolution: 1000× in a Decade

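GPU | Year | BF16 Tensor TFLOP/s | Memory | Mem BW (GB/s)
--- | --- | --- | --- | ---
K40 | 2013 | ~5 (FP32 only) | 12 GB GDDR5 | 288
V100 | 2017 | 125 | 32 GB HBM2 | 900
A100 | 2020 | 312 | 80 GB HBM2e | 2,039
H100 | 2023 | 989 | 80 GB HBM3 | 3,352
B200 | 2025 | ~5,000 | 192 GB HBM3e | 8,000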

From K40 to B200: roughly 1,000× more throughput. This isn't just faster transistors — it's a fundamental architectural shift toward specialized matrix hardware.

Arithmetic Intensity & the Roofline Model

Not every operation can use all that throughput. The key metric is arithmetic intensity: how many FLOPs you do per byte you load from memory.

Arithmetic Intensity Arithmetic Intensity = FLOPs / Bytes Loaded
Unit: FLOPs per byte. Higher = more compute-bound (good for GPUs).
Worked Example — Matmul vs Element-wise

Matrix multiply [M×K] · [K×N]: does 2MKN FLOPs, loads (MK + KN) × 2 bytes (BF16). For M=N=K=4096: intensity = 2×4096³ / (2×4096² × 2) = 2048 FLOPs/byte. Extremely compute-bound.

Element-wise ReLU on a 4096×4096 tensor: does one operation per element and loads 2 bytes per element (BF16). Intensity = 1/2 = 0.5 FLOPs/byte. Completely memory-bandwidth-bound.

The H100 has 989 TFLOP/s compute and 3,352 GB/s bandwidth. The crossover point is 989,000 / 3,352 ≈ 295 FLOPs/byte. Operations below this intensity are bottlenecked by memory bandwidth, not compute. Large matmuls (the core of transformers) are well above this threshold — which is why tensor cores can actually be utilized.

The Roofline

Think of it as a speed limit. At low arithmetic intensity, memory bandwidth is the ceiling — no matter how many tensor cores you have, you're waiting for data. At high intensity, compute is the ceiling — you're limited by TFLOP/s. Large batch matmuls sit firmly in the compute-bound regime, which is exactly where GPUs shine.
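The roofline can be written as a one-line `min`. The sketch below is a simplified model using the H100 figures quoted above and counting only input loads, as in the worked example; the function names are illustrative.

```python
def matmul_intensity(m: int, k: int, n: int, bytes_per_elem: int = 2) -> float:
    """FLOPs per byte loaded for [m x k] @ [k x n], counting input loads only."""
    flops = 2 * m * k * n
    bytes_loaded = (m * k + k * n) * bytes_per_elem
    return flops / bytes_loaded


def attainable_tflops(intensity: float, peak_tflops: float = 989.0,
                      bandwidth_gbs: float = 3352.0) -> float:
    """Roofline: the lower of the compute ceiling and the bandwidth ceiling."""
    memory_ceiling = intensity * bandwidth_gbs / 1000.0  # GFLOP/s to TFLOP/s
    return min(peak_tflops, memory_ceiling)


ridge_point = 989.0 * 1000.0 / 3352.0            # FLOPs/byte where ceilings meet
big_matmul = matmul_intensity(4096, 4096, 4096)  # 2048 FLOPs/byte
relu = 0.5                                       # FLOPs/byte, from the example

print(f"ridge point: {ridge_point:.0f} FLOPs/byte")
print(f"matmul: {attainable_tflops(big_matmul):.0f} TFLOP/s attainable")
print(f"ReLU:   {attainable_tflops(relu):.3f} TFLOP/s attainable")
```

Under this model the element-wise op can reach less than 2 TFLOP/s no matter how many tensor cores the chip has.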

Chapter 03

GPU Servers & Clusters

A single H100 has 80 GB of memory and 989 TFLOP/s. To train a 405B-parameter model, you need to wire up thousands of these chips. The critical question is: how fast can they talk to each other?

The Bandwidth Hierarchy

Not all GPU connections are equal. There's a steep hierarchy of bandwidth, and understanding it is essential for choosing parallelism strategies.

[Interactive figure: GPU → Server → Rack → Pod → Cluster. Each level has dramatically less bandwidth than the one below it.]

Level | What | Bandwidth | Ratio to HBM
--- | --- | --- | ---
HBM ↔ GPU | On-chip memory bus | 3,352 GB/s | 1×
GPU ↔ GPU (intra-server) | NVLink / NVSwitch | 900 GB/s | 0.27×
Server ↔ Server (intra-pod) | InfiniBand (8× 400G) | ~50 GB/s | 0.015×
Pod ↔ Pod (cross-pod) | InfiniBand (spine) | <50 GB/s | <0.015×
The 18× Drop

Going from intra-server NVLink (900 GB/s) to inter-server InfiniBand (~50 GB/s) is an 18× bandwidth drop. This cliff shapes every parallelism decision: put communication-heavy operations on NVLink, communication-light operations across InfiniBand.
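To get a feel for the cliff, the sketch below times the same transfer over each link, using the bandwidths from the table and ignoring latency. The 10 GB payload is an arbitrary, hypothetical size.

```python
# Bandwidths from the hierarchy table, in GB/s.
LINK_GBS = {"HBM": 3352.0, "NVLink": 900.0, "InfiniBand": 50.0}


def transfer_ms(gigabytes: float, link: str) -> float:
    """Bandwidth-only transfer time in milliseconds (latency ignored)."""
    return gigabytes / LINK_GBS[link] * 1000.0


payload_gb = 10.0  # hypothetical payload
times = {link: transfer_ms(payload_gb, link) for link in LINK_GBS}

for link, ms in times.items():
    print(f"{link:>10}: {ms:7.2f} ms")
```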

Cluster Anatomy: Llama 3's 24,576 GPUs

Concrete Numbers — Meta's Grand Teton Cluster

1 Server = 8× H100 GPUs, connected via NVSwitch (900 GB/s all-to-all). 640 GB total GPU memory.

1 Rack = 2 servers = 16 GPUs. 1.28 TB memory.

1 Pod = 192 racks = 3,072 GPUs. Connected by InfiniBand fat-tree network (50 GB/s per GPU). 245 TB memory.

1 Cluster = 8 pods = 24,576 GPUs. Cross-pod InfiniBand (<50 GB/s). Total: 1.97 PB memory, 24.3 EFLOP/s peak.

Network Topology

Within a pod, Meta uses a fat-tree topology: every server connects to a top-of-rack switch, which connects to spine switches, creating multiple redundant paths. This provides full bisection bandwidth — any half of the pod can talk to the other half at full speed.

Across pods, bandwidth is more limited. Meta uses a rail-optimized design: GPU 0 in every server connects to the same network rail, GPU 1 to another rail, etc. This ensures that collective operations (like all-reduce) on the same GPU index across servers use dedicated, non-contending paths.

Failures at Scale

When Things Break

With 24,576 GPUs, failures are routine. Meta reported interruptions every few hours during Llama 3 pre-training. Their solution: aggressive checkpointing (save model state every few minutes), automatic job restart (detect failures, reassign work, resume from last checkpoint), and spare capacity (extra GPUs on standby). At this scale, reliability engineering is as important as the training algorithm itself.

Google TPUs

Google takes a different approach with TPUs (Tensor Processing Units). The v5p chip delivers 459 TFLOP/s (BF16) and pods scale up to 8,960 chips with a custom high-bandwidth interconnect (ICI). TPU pods achieve tighter integration than GPU clusters at equivalent scale, but are only available through Google Cloud.

Chapter 04

Data Parallelism — The Simple Idea

Here's the most intuitive way to use multiple GPUs: give every GPU a complete copy of the model, and split the training data across them. Each GPU processes a different mini-batch, computes gradients, and then they all share results.

Definition
Data Parallelism (DP)

Every GPU holds a full copy of the model (weights, gradients, optimizer states). The training batch is split into N equal chunks, one per GPU. Each GPU computes local gradients on its chunk, then gradients are aggregated across all GPUs before the weight update.

Why Gradient Averaging Works

The key mathematical insight: gradients are linear. The gradient over a batch is the average of gradients over individual samples. So splitting a batch across GPUs and averaging the per-GPU gradients gives the exact same result as computing the gradient on one GPU with the full batch.

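Gradient Linearity
$$\nabla L(B) = \frac{1}{|B|} \sum_{x \in B} \nabla L(x) = \frac{1}{N} \sum_{i=1}^{N} \left[ \frac{N}{|B|} \sum_{x \in B_i} \nabla L(x) \right] = \frac{1}{N} \sum_{i=1}^{N} \nabla L(B_i)$$
Here B_i is the chunk of the batch on GPU i, with |B_i| = |B|/N. Average of local gradients = gradient of the full batch. Mathematically equivalent.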
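The identity can be checked numerically. The sketch below uses a toy one-parameter squared-error loss (an assumption made for brevity) and exact fractions, so the comparison is not blurred by floating-point rounding.

```python
from fractions import Fraction


def sample_grad(w, x, y):
    """d/dw of the per-sample loss (w*x - y)^2."""
    return 2 * (w * x - y) * x


def batch_grad(w, batch):
    """Gradient of the mean loss over a batch."""
    return sum(sample_grad(w, x, y) for x, y in batch) / len(batch)


w = Fraction(1, 2)
batch = [(Fraction(x), Fraction(3 * x + 1)) for x in range(8)]

# Split the batch into equal chunks, one per simulated GPU.
num_gpus = 4
chunk = len(batch) // num_gpus
shards = [batch[i * chunk:(i + 1) * chunk] for i in range(num_gpus)]

local_grads = [batch_grad(w, shard) for shard in shards]
averaged = sum(local_grads) / num_gpus   # what an averaging all-reduce yields
full = batch_grad(w, batch)              # single-GPU reference

print(averaged == full)
```

The equality relies on the chunks being the same size; with unequal chunks the local gradients would need to be weighted by chunk size.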

All-Reduce: The Communication Primitive

After each GPU computes its local gradient, we need to compute the sum (or average) and distribute the result to every GPU. This operation is called all-reduce.

Definition
All-Reduce

A collective communication operation that takes an array from each of N participants, computes an element-wise reduction (sum, average, max), and distributes the result to all participants. Every GPU ends up with the same aggregated result.

[Interactive figure: Ring All-Reduce Algorithm. N GPUs exchange gradients in a ring: scatter-reduce first, then all-gather. Each GPU sends and receives exactly once per step.]

Ring All-Reduce

A naive all-reduce sends all data to one GPU, sums it, and broadcasts back. That bottlenecks on one link. The ring all-reduce is much smarter:

Ring All-Reduce
  1. Setup: Arrange N GPUs in a ring. Each GPU splits its gradient into N chunks.
  2. Scatter-reduce (N−1 steps): Each GPU sends one chunk to its right neighbor and receives one chunk from its left. It adds the received chunk to its own. After N−1 steps, each GPU holds one fully-reduced chunk.
  3. All-gather (N−1 steps): Each GPU sends its fully-reduced chunk rightward. After N−1 steps, every GPU has all N reduced chunks.
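The two phases can be simulated in a few lines. This is a single-process sketch, not a real communication library: each "GPU" is a Python list, each chunk is a single number, and the chunk indexing (GPU r sends chunk r − s in step s) is one common convention that the text does not specify.

```python
def ring_all_reduce(buffers):
    """Simulate a sum all-reduce over a ring.

    buffers[r] is GPU r's gradient, already split into N chunks
    (one number per chunk, to keep the sketch short).
    """
    n = len(buffers)
    data = [list(b) for b in buffers]  # copy so the inputs stay untouched

    # Scatter-reduce: N-1 steps. In step s, GPU r sends chunk (r - s) % n
    # to its right neighbor, which adds it to its own copy of that chunk.
    for s in range(n - 1):
        sends = [((r - s) % n, data[r][(r - s) % n]) for r in range(n)]
        for r, (idx, value) in enumerate(sends):
            data[(r + 1) % n][idx] += value

    # GPU r now holds the fully reduced chunk (r + 1) % n.
    # All-gather: N-1 steps, forwarding finished chunks around the ring.
    for s in range(n - 1):
        sends = [((r + 1 - s) % n, data[r][(r + 1 - s) % n]) for r in range(n)]
        for r, (idx, value) in enumerate(sends):
            data[(r + 1) % n][idx] = value

    return data


grads = [
    [1, 2, 3, 4],
    [10, 20, 30, 40],
    [100, 200, 300, 400],
    [1000, 2000, 3000, 4000],
]
result = ring_all_reduce(grads)
print(result[0])
```

Collecting all sends before applying them models the fact that every GPU sends and receives simultaneously within a step.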
Worked Example — Communication Volume

Model size M bytes, N GPUs. Each GPU's gradient is M bytes.

Naive: All GPUs send M bytes to GPU 0, then GPU 0 broadcasts M bytes. Total: 2(N−1)M bytes through GPU 0. Bottleneck: GPU 0.

Ring: Each step, every GPU sends M/N bytes. There are 2(N−1) steps total. Per GPU: 2(N−1) × M/N ≈ 2M bytes (for large N). Every link carries equal load — no bottleneck.

Communication Is Independent of N

In ring all-reduce, each GPU sends approximately 2M bytes regardless of how many GPUs participate. The bandwidth-bound time is 2(N−1)M/(N × bandwidth) ≈ 2M/bandwidth, so doubling the GPUs doesn't double the communication cost. The number of steps, 2(N−1), does grow with N, so per-step latency eventually matters on very large rings.
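The two traffic formulas from the worked example are compared below for a few ring sizes. This is a minimal sketch that counts bytes only (no latency), with M normalized to 1.

```python
def naive_bytes_through_root(model_bytes: float, num_gpus: int) -> float:
    """Gather-to-one then broadcast: traffic through the root GPU."""
    return 2 * (num_gpus - 1) * model_bytes


def ring_bytes_per_gpu(model_bytes: float, num_gpus: int) -> float:
    """Ring all-reduce: 2(N-1) steps of M/N bytes each, per GPU."""
    return 2 * (num_gpus - 1) * model_bytes / num_gpus


M = 1.0  # measure traffic in units of the model size
for n in (2, 8, 64, 1024):
    print(f"N={n:5d}  naive root: {naive_bytes_through_root(M, n):7.1f} M"
          f"   ring per GPU: {ring_bytes_per_gpu(M, n):.3f} M")
```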

The Problem with Pure DP

Data parallelism has one fatal flaw for large models: every GPU needs a full copy of the model. For Llama 3 405B, that's ~4.86 TB of training state per GPU, and a single H100 has only 80 GB of HBM. We need to shard the model itself.

Chapter 05

Fully Sharded Data Parallelism

The insight behind FSDP is beautifully simple: in data parallelism, every GPU stores the entire model, but at any given moment during the forward pass, it only needs one layer's weights. What if we only store the weights we need right now?

Definition
ZeRO (Zero Redundancy Optimizer)

A family of memory optimization techniques that shard (split) the model's training state across GPUs, eliminating redundancy. Each weight tensor is "owned" by exactly one GPU, which stores its weight, gradient, and optimizer states. Other GPUs request it on demand.

ZeRO Stages

ZeRO comes in three stages, each sharding more aggressively:

Stage | What's Sharded | Memory per GPU | Comm. Overhead
--- | --- | --- | ---
Stage 1 | Optimizer states only (m, v) | ~N_params × (2 + 2 + 8/N) bytes | Same as DP
Stage 2 | Optimizer states + gradients | ~N_params × (2 + 10/N) bytes | Same as DP
Stage 3 (FSDP) | Everything: weights + grads + optimizer | ~N_params × 12/N bytes | 3 collectives per layer (~1.5× DP's bytes)
Worked Example — Llama 3 with FSDP

405B params × 12 bytes = 4.86 TB total training state.

Stage 3 on 1,024 GPUs: 4.86 TB / 1,024 = 4.75 GB per GPU. That fits comfortably in 80 GB, leaving room for activations and buffers.

Stage 1 on 1,024 GPUs: Each GPU stores full weights (810 GB in BF16) + full gradients (810 GB) + sharded optimizer (4 bytes × 2 × 405B / 1024 ≈ 3.16 GB). Still doesn't fit — 810 GB alone exceeds one GPU.

For 405B parameters, Stage 3 is mandatory.
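The per-stage memory formulas from the table can be folded into one function. This is a sketch under the same 2 + 2 + 8 bytes-per-parameter accounting; the function name and the convention that stage 0 means plain DP are assumptions for illustration.

```python
def zero_bytes_per_gpu(num_params: float, num_gpus: int, stage: int) -> float:
    """Training-state bytes per GPU for ZeRO stages 0 (plain DP) through 3."""
    weights, grads, optimizer = 2.0, 2.0, 8.0  # bytes per parameter
    if stage >= 1:
        optimizer /= num_gpus   # shard Adam m and v
    if stage >= 2:
        grads /= num_gpus       # also shard gradients
    if stage >= 3:
        weights /= num_gpus     # also shard the weights themselves
    return num_params * (weights + grads + optimizer)


per_gpu_gb = {
    stage: zero_bytes_per_gpu(405e9, 1024, stage) / 1e9 for stage in range(4)
}
for stage, gb in per_gpu_gb.items():
    verdict = "fits" if gb <= 80 else "does not fit"
    print(f"stage {stage}: {gb:9.2f} GB per GPU ({verdict} in 80 GB)")
```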

How FSDP Works (Stage 3)

Imagine the model has L layers. GPU k "owns" a subset of layers. Here's what happens:

FSDP Forward + Backward
  1. Forward pass, layer i: The owner of W_i broadcasts (all-gather) the weights to all GPUs. Every GPU computes the forward pass for layer i. Then non-owners discard their copy of W_i to free memory.
  2. Backward pass, layer i: The owner broadcasts W_i again (needed for gradient computation). Every GPU computes local dL/dW_i. Each GPU sends its local gradient to the owner via reduce-scatter. The owner accumulates all gradients and updates W_i.
  3. Overlap: While computing layer i, prefetch layer i+1's weights in the background. Communication and computation run in parallel.
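To see why gathering one layer at a time keeps memory low, the sketch below tracks how many parameters are resident on one GPU during the forward pass. It assumes each layer's weights are split evenly across GPUs (a common way to implement sharding) rather than a single owner per layer, ignores prefetching, and uses a hypothetical model of 126 equal layers.

```python
def fsdp_forward_peak_params(layer_sizes, num_gpus):
    """Peak parameter count resident on one GPU during a sharded forward pass."""
    resident_shards = sum(size / num_gpus for size in layer_sizes)  # always kept
    peak = resident_shards
    for size in layer_sizes:
        # All-gather brings in the other GPUs' shards of this one layer.
        gathered_extra = size - size / num_gpus
        peak = max(peak, resident_shards + gathered_extra)
        # The gathered copy is discarded before the next layer is fetched.
    return peak


layers = [3.2e9] * 126  # hypothetical: 126 equal layers, roughly 405B in total
full_model = sum(layers)
peak = fsdp_forward_peak_params(layers, num_gpus=1024)

print(f"full model: {full_model / 1e9:.1f}B params, "
      f"peak resident: {peak / 1e9:.2f}B params")
```

With prefetching, one additional gathered layer would be resident at the same time, so the real peak is somewhat higher than this sketch reports.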
Memory Scales Perfectly

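With N GPUs, each GPU stores 1/N of the model. Double the GPUs, halve the memory per GPU. This is how you fit a 4.86 TB model on 80 GB GPUs: use enough of them. The cost is extra communication: three collectives per layer (one all-gather in forward, one all-gather plus one reduce-scatter in backward) instead of vanilla DP's single all-reduce. Because a ring all-reduce is itself a reduce-scatter followed by an all-gather, that works out to about 1.5× the bytes of vanilla DP.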

Communication vs. Computation Overlap

The magic of FSDP in practice is prefetching. While GPU cores are busy computing the forward pass for layer i, the network interface card (NIC) is already pulling in weights for layer i+1. If the network is fast enough relative to computation, the communication cost is almost entirely hidden.

When Overlap Fails

If the model's layers are small (few parameters, fast to compute) but the network is slow, the GPU finishes computing before the next layer's weights arrive. It sits idle, waiting. This is why FSDP works best on fast interconnects (NVLink) and with models that have large, compute-heavy layers.

Chapter 06

Hybrid & 3D Parallelism

FSDP works brilliantly within a server where NVLink provides 900 GB/s. But across servers, bandwidth drops 18×. If you do FSDP across all 24,576 GPUs, the slow inter-server links become a bottleneck. The solution: HSDP — combine FSDP locally with plain DP globally.

Definition
HSDP (Hybrid Sharded Data Parallelism)

Split N GPUs into M groups of K. Within each group: full FSDP (shard everything). Across groups: plain DP (replicate model, all-reduce gradients). Place intra-group communication on fast links, inter-group on slow links.

Communication Analysis

Within an FSDP group of K GPUs, the forward+backward pass requires 3× the model size in communication (two all-gathers + one reduce-scatter). Across M groups, only gradients need to be all-reduced: 1× the model size.

Worked Example — Llama 3 HSDP

24,576 GPUs. Split into M=3,072 groups of K=8 (one group = one server).

Intra-server (FSDP): 3× model communication over NVLink at 900 GB/s. Fast.

Inter-server (DP): 1× gradient all-reduce over InfiniBand at ~50 GB/s. Slower, but 3× less data.

By routing 3× the traffic on the 18× faster link and only 1× on the slow link, HSDP matches the network hierarchy to the communication pattern.
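The grouping in the example can be written down explicitly. The sketch below assumes global ranks are numbered consecutively within each server, which is a common convention but not something the text states; `hsdp_groups` is a hypothetical helper, not a library call.

```python
def hsdp_groups(num_gpus: int, shard_size: int):
    """Return (shard_groups, replica_groups) as lists of global ranks.

    shard_groups:   GPUs that run FSDP together (fast links).
    replica_groups: GPUs holding the same shard in different groups,
                    which all-reduce gradients with each other (slow links).
    """
    if num_gpus % shard_size != 0:
        raise ValueError("num_gpus must be a multiple of shard_size")
    num_groups = num_gpus // shard_size
    shard_groups = [
        list(range(g * shard_size, (g + 1) * shard_size))
        for g in range(num_groups)
    ]
    replica_groups = [
        list(range(local, num_gpus, shard_size)) for local in range(shard_size)
    ]
    return shard_groups, replica_groups


small_shard, small_replica = hsdp_groups(num_gpus=16, shard_size=8)
print(small_shard)       # two servers of 8 GPUs each
print(small_replica[0])  # local GPU 0 of every server

big_shard, big_replica = hsdp_groups(num_gpus=24_576, shard_size=8)
print(len(big_shard), len(big_replica), len(big_replica[0]))
```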

3D Parallelism

For the largest models, even HSDP isn't enough. You need to combine multiple parallelism strategies simultaneously. This is 3D parallelism (or 4D, or 5D — however many dimensions you need).

Think of the GPU cluster as a multi-dimensional mesh. Each dimension maps to a different parallelism strategy:

Dimension | Strategy | What's Split | Best Link
--- | --- | --- | ---
Dim 1 | TP (Tensor Parallelism) | Weight matrices | NVLink (fastest)
Dim 2 | PP (Pipeline Parallelism) | Model layers | NVLink or fast IB
Dim 3 | DP/FSDP | Data batches | InfiniBand (slowest OK)
Dim 4 | CP (Context Parallelism) | Sequence length | Varies
The Design Principle

Communication-heavy strategies go on fast links. Communication-light strategies go on slow links. TP requires all-reduces after every layer — put it on NVLink. DP only all-reduces gradients once per step — InfiniBand is fine. This is the fundamental insight behind all multi-dimensional parallelism schemes.

Llama 3 405B used 4D parallelism: TP=8 (within server) × PP=16 (across servers in a pod) × CP=1-4 × DP=variable across pods. This mapped the communication requirements to the bandwidth hierarchy almost perfectly.
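One way to picture the mesh is as a mapping from a global rank to a coordinate per dimension. The sketch below is illustrative only: it assumes TP varies fastest (so a TP group is 8 consecutive ranks, that is, one server), omits CP, and uses DP=192 as a hypothetical fixed value so that the three sizes multiply to 24,576.

```python
def mesh_coords(rank: int, tp: int = 8, pp: int = 16, dp: int = 192):
    """Map a global rank to (dp_idx, pp_idx, tp_idx), with TP varying fastest."""
    if not 0 <= rank < tp * pp * dp:
        raise ValueError("rank out of range for this mesh")
    tp_idx = rank % tp           # position inside the server (NVLink)
    pp_idx = (rank // tp) % pp   # pipeline stage (across servers)
    dp_idx = rank // (tp * pp)   # data-parallel replica
    return dp_idx, pp_idx, tp_idx


for rank in (0, 7, 8, 128, 24_575):
    print(rank, mesh_coords(rank))
```

Ranks 0 to 7 share a pipeline stage and differ only in their TP index, so the most frequent communication stays inside one server.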

Chapter 07

Pipeline Parallelism — Split by Layers

Pipeline parallelism takes a different approach: instead of giving every GPU the full model, split the model by layers. GPU 1 gets layers 1-20, GPU 2 gets layers 21-40, and so on. Data flows through the pipeline like an assembly line.

Definition
Pipeline Parallelism (PP)

Partition the model's layers across P stages, one stage per GPU (or group of GPUs). Each stage processes its layers, passes the output (activations) to the next stage, and receives gradients flowing back. The inter-stage communication is just the activation tensor at each layer boundary — much smaller than the full model.

The Bubble Problem

Naive pipelining has a devastating flaw. Stage 2 can't start until Stage 1 finishes. Stage 3 can't start until Stage 2 finishes. During forward pass, only one stage is active at a time — the rest idle. Same during backward. This creates a pipeline bubble.

[Interactive figure: Pipeline Parallelism Schedules. Each colored block is a micro-batch (forward or backward); gray is idle time (the bubble). Schedules shown: naive, GPipe, and 1F1B.]

Micro-batching: The Fix

The solution: split the batch into micro-batches and pipeline them. While Stage 1 processes micro-batch 2, Stage 2 can process micro-batch 1. The pipeline fills up, and most of the time, all stages are busy.

GPipe Schedule

GPipe (2019) is the simplest micro-batching scheme: run all forward micro-batches through the pipeline, then run all backward micro-batches. There's still a bubble at the boundary where forward finishes and backward starts.

1F1B Schedule

The 1F1B (one-forward-one-backward) schedule interleaves forward and backward passes. After the pipeline fills with forward micro-batches, each stage alternates: one forward, one backward. This reduces peak memory because gradients and activations for earlier micro-batches are freed before later ones are computed.

Pipeline Bubble Fraction Bubble fraction ≈ (P − 1) / (P − 1 + M)
P = number of pipeline stages, M = number of micro-batches. More micro-batches → smaller bubble.
Worked Example — Bubble Calculation

P=4 stages, M=8 micro-batches: bubble = 3 / (3 + 8) = 27.3%.

P=4 stages, M=32 micro-batches: bubble = 3 / (3 + 32) = 8.6%.

P=16 stages, M=8: bubble = 15 / (15 + 8) = 65.2%. Terrible!

P=16 stages, M=64: bubble = 15 / (15 + 64) = 19.0%. Manageable.

The rule: M should be several times larger than P to keep the bubble small.
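The bubble formula is simple enough to tabulate. This minimal sketch reproduces the four cases above plus the naive schedule, which corresponds to a single micro-batch.

```python
def bubble_fraction(stages: int, micro_batches: int) -> float:
    """Approximate idle fraction: (P - 1) / (P - 1 + M)."""
    return (stages - 1) / (stages - 1 + micro_batches)


configs = [(4, 1), (4, 8), (4, 32), (16, 8), (16, 64)]
for p, m in configs:
    print(f"P={p:2d}, M={m:2d}: bubble = {bubble_fraction(p, m):.1%}")
```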

Interleaved Pipeline Parallelism

There's an even cleverer trick: assign non-contiguous layers to each stage. Instead of GPU 1 getting layers 1-20, give it layers 1-5 and 21-25. This creates a "virtual pipeline" with more stages but each stage computes faster. The bubble shrinks because there are more, shorter stages to pipeline.

PP Communication Is Light

Pipeline parallelism only sends activations between stages — a single tensor at each layer boundary. For a transformer, this is [batch × seq_len × hidden_dim] × 2 bytes, which is far smaller than the full model weights. That's why PP works across slower InfiniBand links.

Activation Checkpointing

During the forward pass, each layer produces activations that must be saved for the backward pass. For large models with long sequences, these activations can consume more memory than the model itself. Activation checkpointing (gradient checkpointing) saves only selected activations, such as the input to each layer, and recomputes the rest during backward. This costs roughly one extra forward pass (~33% more compute per step, since the forward pass is about a third of the total). Checkpointing only every √L-th layer brings saved-activation memory down to roughly O(√L) instead of O(L).

Chapter 08

Tensor Parallelism — Split by Dimensions

Pipeline parallelism splits across layers. Tensor parallelism goes deeper: it splits individual weight matrices across GPUs. Each GPU holds a slice of every layer's weights and computes a slice of every layer's output.

Definition
Tensor Parallelism (TP)

Partition individual weight matrices across T GPUs along the row or column dimension. Each GPU computes a portion of the matrix multiply, then an all-reduce (or concatenation) combines the partial results. This happens at every layer, requiring extremely fast inter-GPU links.

Column Parallel

Consider a linear layer Y = XW where W is [K × N]. Split W into T column chunks: W = [W_1, W_2, ..., W_T], each of size [K × N/T]. GPU i computes Y_i = X · W_i, getting a [batch × N/T] output. Concatenate all Y_i to get the full [batch × N] output.

Column Parallel Linear: Y = X · [W_1 | W_2 | ... | W_T]
GPU i computes: Y_i = X · W_i   (local matmul, no communication)
Y = concat(Y_1, Y_2, ..., Y_T)   (all-gather)

Row Parallel

Alternatively, split W along rows: W = [W_1; W_2; ...; W_T], each [K/T × N]. Now the input X must be split column-wise. GPU i computes Y_i = X_i · W_i. The results are partial sums that must be added: Y = ∑ Y_i (all-reduce).

Row Parallel Linear: GPU i computes: Y_i = X_i · W_i   (local matmul)
Y = Y_1 + Y_2 + ... + Y_T   (all-reduce)

Transformer MLP Block

The transformer MLP has two linear layers: first projects up (hidden → 4×hidden), then projects down (4×hidden → hidden). The elegant trick: use column parallel on the first layer and row parallel on the second. The column parallel output is already split across GPUs, which is exactly what row parallel needs as input. No communication between the two layers!

Only One All-Reduce per MLP

Column parallel on layer 1 → no communication → row parallel on layer 2 → one all-reduce. The entire MLP block requires only one all-reduce for the forward pass (and one for backward). Without this pairing trick, you'd need two all-reduces.
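The pairing can be verified on a toy example. The sketch below simulates T=2 "GPUs" in one process with tiny integer matrices (so the comparison is exact), uses ReLU as the activation, and a 2→4→2 MLP rather than a true 4× expansion; all of these are simplifications for illustration.

```python
def matmul(a, b):
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def relu(m):
    return [[max(0, v) for v in row] for row in m]


def split_cols(m, parts):
    width = len(m[0]) // parts
    return [[row[i * width:(i + 1) * width] for row in m] for i in range(parts)]


def split_rows(m, parts):
    height = len(m) // parts
    return [m[i * height:(i + 1) * height] for i in range(parts)]


def add(a, b):
    return [[x + y for x, y in zip(ra, rb)] for ra, rb in zip(a, b)]


X = [[1, -2], [3, 4]]                     # [batch=2, hidden=2]
W1 = [[1, 0, -1, 2], [2, 1, 0, -1]]       # up-projection,   [2, 4]
W2 = [[1, 0], [0, 1], [2, -1], [-1, 3]]   # down-projection, [4, 2]

reference = matmul(relu(matmul(X, W1)), W2)   # single-GPU result

T = 2
partials = []
for W1_i, W2_i in zip(split_cols(W1, T), split_rows(W2, T)):
    hidden_i = relu(matmul(X, W1_i))          # column parallel: no communication
    partials.append(matmul(hidden_i, W2_i))   # row parallel: a partial sum

tp_output = partials[0]
for partial in partials[1:]:
    tp_output = add(tp_output, partial)       # the single all-reduce (sum)

print(tp_output == reference, tp_output)
```

The element-wise activation is what makes this work: it can be applied to each column slice independently, so nothing needs to be gathered between the two layers.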

Self-Attention

Multi-head attention is naturally column-parallel. If the model has H attention heads and T GPUs, assign H/T heads to each GPU. Each GPU computes its subset of heads independently, then the output projection (a row-parallel linear) combines them with a single all-reduce.

Why TP Needs NVLink

TP requires an all-reduce at every transformer block — twice per block for forward (once for MLP, once for attention), and twice more for backward. For a 126-layer Llama 3 model, that's ~504 all-reduces per step. At 50 GB/s InfiniBand, this would be catastrophically slow. At 900 GB/s NVLink, it's tolerable. TP must stay within a single server.

Worked Example — TP Communication

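Llama 3 405B, hidden dim = 16,384, batch×seq = 4M tokens. TP=8 across one server.

Each all-reduce moves an activation tensor of ~16,384 × 4M × 2 bytes ≈ 128 GB. With 126 blocks and 4 all-reduces per block (2 forward, 2 backward): 126 × 4 × 128 GB ≈ 64 TB of communication per step. This is an upper bound that treats the full 4M-token batch as passing through a single TP group; with data parallelism, each replica sees only its share of the batch.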

At 900 GB/s NVLink: 64 TB / 900 GB/s ≈ 71 seconds just for communication. This must overlap with computation to be viable.

Context Parallelism

When sequence lengths grow very long (128K+ tokens), even the activations within a single layer become huge. Context parallelism (CP) splits the sequence dimension: each GPU handles a portion of the sequence. In attention, this requires a ring attention scheme where KV chunks are passed around a ring so each GPU can attend to all positions.

Sequence Parallelism (Not Context Parallelism)

Don't confuse context parallelism with sequence parallelism. Sequence parallelism applies TP-style sharding to operations that don't naturally parallelize (LayerNorm, Dropout). Instead of replicating these across all TP ranks, each GPU handles its sequence chunk. It reduces memory but uses the same TP group — it's a complement to TP, not a separate dimension.

Chapter 09

Parallelism Strategy Explorer

Now let's put it all together. The interactive tool below lets you configure a distributed training setup and see the resulting memory usage, communication overhead, and GPU utilization.

[Interactive tool: Distributed Training Strategy Explorer. Configure your model, cluster, and parallelism strategy to see how memory, communication, and utilization change.]

Try These Configurations

Small model (7B), 8 GPUs: Pure DP almost works. Each GPU stores ~84 GB (7B × 12 bytes), slightly over the 80 GB limit. Turn on FSDP to fix it.

Large model (405B), 128 GPUs: Need TP=8 + PP=4 + FSDP. Watch how each dimension reduces memory but adds communication.

Extreme: 405B on 8 GPUs: Even with TP=8 and FSDP, memory is crushing. You need more GPUs or model compression.

Chapter 10

Putting It All Together

Llama 3 Training Recipe

Meta trained Llama 3 405B on 24,576 H100 GPUs using 4D parallelism:

Llama 3 405B Configuration

TP=8: Within each server (8 GPUs on NVLink). Each GPU holds 1/8 of every weight matrix.

PP=16: Across 16 servers. Each stage handles ~8 of the 126 transformer layers.

DP=192: 24,576 / (8×16) = 192 data-parallel replicas, with FSDP for memory efficiency.

CP=variable: Context parallelism added for long-sequence training (up to 128K tokens).

Result: ~38-43% MFU, ~54 days of training, ~15 trillion tokens.

Mixed Precision Training

Every large-scale training job uses mixed precision. The recipe:

Mixed Precision Training
  1. Master weights: Stored in FP32 (the "ground truth" copy, kept by the optimizer).
  2. Forward pass: Cast weights to BF16, compute activations in BF16. 2× memory savings, faster matmul.
  3. Loss computation: In FP32 for numerical stability.
  4. Backward pass: Compute gradients in BF16.
  5. Optimizer step: Convert gradients to FP32, update master weights in FP32. Cast back to BF16 for next forward.
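The reason for the FP32 master copy shows up in a toy experiment. The sketch below rounds values to bfloat16 by manipulating float32 bits (a standard round-to-nearest-even trick that ignores NaN handling), applies a hypothetical constant update of 0.001 for 100 steps, and uses a Python float as a stand-in for the FP32 master weight.

```python
import struct


def to_bf16(x: float) -> float:
    """Round a float to the nearest bfloat16 value (round-to-nearest-even)."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    rounding_bias = 0x7FFF + ((bits >> 16) & 1)
    bits = ((bits + rounding_bias) >> 16) << 16   # keep the top 16 bits
    return struct.unpack("<f", struct.pack("<I", bits & 0xFFFFFFFF))[0]


update = 1e-3   # hypothetical lr * gradient, the same every step
steps = 100

bf16_only = 1.0   # weight stored and updated directly in BF16
master = 1.0      # high-precision master weight (stand-in for FP32)
for _ in range(steps):
    bf16_only = to_bf16(bf16_only - update)   # update is lost to rounding
    master = master - update                  # update accumulates

working_copy = to_bf16(master)   # BF16 cast used for the next forward pass

print(f"BF16-only weight:         {bf16_only}")
print(f"master weight:            {master:.4f}")
print(f"BF16 cast of the master:  {working_copy}")
```

Near 1.0, adjacent BF16 values are several thousandths apart, so each small update rounds away; the master copy accumulates them until the BF16 cast actually changes.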

Gradient Accumulation

Sometimes you want a larger effective batch size than your GPUs can fit in memory. Gradient accumulation runs multiple forward-backward passes, summing gradients, before doing a single optimizer step. Mathematically identical to a larger batch, but uses constant memory.
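A small check of the equivalence claim, again with a toy one-parameter squared-error loss and exact fractions (both are assumptions for illustration). Each micro-batch gradient is divided by the number of accumulation steps so that the running sum is the full-batch mean.

```python
from fractions import Fraction


def mean_grad(w, batch):
    """d/dw of the mean of (w*x - y)^2 over the batch."""
    return sum(2 * (w * x - y) * x for x, y in batch) / len(batch)


w = Fraction(1, 4)
full_batch = [(Fraction(x), Fraction(2 * x - 1)) for x in range(1, 13)]

micro_batch_size = 3
accum_steps = len(full_batch) // micro_batch_size

accumulated = Fraction(0)
for step in range(accum_steps):
    start = step * micro_batch_size
    micro_batch = full_batch[start:start + micro_batch_size]
    # Only one micro-batch is "in memory" at a time.
    accumulated += mean_grad(w, micro_batch) / accum_steps

full = mean_grad(w, full_batch)
print(accumulated == full)
```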

The Parallelism Comparison

Strategy | Splits | Communication | When to Use
--- | --- | --- | ---
DP | Data (batches) | All-reduce grads, 1× per step | Model fits on 1 GPU
FSDP | Data + model shards | 2 all-gathers + 1 reduce-scatter per layer (~1.5× DP's bytes) | Model doesn't fit on 1 GPU
PP | Model layers | Activations between stages (small) | Many layers, moderate inter-node BW
TP | Weight matrices | All-reduce per layer (frequent) | NVLink only, within server
CP | Sequence length | Ring attention (KV rotation) | Very long sequences (128K+)

Memory Breakdown

For a complete picture of GPU memory during training:

Component | Size (Llama 3 405B) | Reducible By
--- | --- | ---
Weights | ~810 GB (BF16) | TP, PP, FSDP
Gradients | ~810 GB (BF16) | TP, PP, FSDP
Optimizer (m, v) | ~3.24 TB (FP32) | FSDP
Activations | Variable (batch×seq×hidden) | Activation checkpointing, PP, CP
Comm buffers | ~1-2 GB | Overlap scheduling
The One Sentence

Distributed training maps the communication hierarchy of your algorithms to the bandwidth hierarchy of your hardware: frequent communication on fast links, infrequent communication on slow links.

Connections

This systems perspective complements the algorithmic side of deep learning. The Transformer lesson covers the architecture being distributed here. SSM/Mamba explores alternatives with different parallelism properties. And understanding RLHF shows why we train these massive models in the first place.

References

1. Rajbhandari et al. "ZeRO: Memory Optimizations Toward Training Trillion Parameter Models." SC 2020.
2. Huang et al. "GPipe: Efficient Training of Giant Neural Networks using Pipeline Parallelism." NeurIPS 2019.
3. Shoeybi et al. "Megatron-LM: Training Multi-Billion Parameter Language Models Using Model Parallelism." 2020.
4. Narayanan et al. "Efficient Large-Scale Language Model Training on GPU Clusters Using Megatron-LM." SC 2021.
5. Dubey et al. "The Llama 3 Herd of Models." 2024.
6. Micikevicius et al. "Mixed Precision Training." ICLR 2018.
7. CS 231n Lecture 11 slides: "Large-Scale Distributed Training." Stanford, 2024.