Tazi et al., Chapter 2

Training on One GPU

Memory breakdown, activation recomputation, gradient accumulation, and mixed precision — the single-GPU survival kit.

Prerequisites: Chapter 1 (Overview). Familiarity with forward/backward passes and basic matrix math.
9 chapters · 3 simulations · 9 quizzes

Chapter 0: The OOM Problem

You have one GPU. You load a model and start training. Within seconds you see the dreaded message: CUDA out of memory.

This is the single most common failure mode when training neural networks. Before we can dream of scaling to hundreds of GPUs, we need to understand why we run out of memory and what we can do about it on a single device.

A typical H100 GPU has 80 GB of memory. That sounds like a lot. But a 7B parameter model in mixed precision already needs about 112 GB just for weights, gradients, and optimizer states — before a single activation is stored.

The core question: Where does all that memory go? Once you understand the four things stored in GPU memory during training (weights, gradients, optimizer states, and activations), you can make informed trade-offs to fit a larger model, batch, or sequence length on the hardware you have.

In this chapter, we will dissect GPU memory usage during training, learn three techniques to tame it (mixed precision, activation recomputation, gradient accumulation), and build a mental model that will guide every decision in the rest of the book.

Check: Why does a 7B model not fit on an 80 GB GPU for training?

Chapter 1: Memory Anatomy

During training, four categories of data live on your GPU:

1. Model Weights
The learnable parameters — the thing you are training
2. Gradients
Computed in the backward pass, same shape as the weights
3. Optimizer States
Adam stores momentum + variance — two extra copies per parameter
4. Activations
Intermediate results from the forward pass, kept for the backward pass

The first three (weights, gradients, optimizer states) depend only on the number of parameters and the precision each value is stored in. They do not change with batch size or sequence length.

Activations are the wild card. They grow linearly with batch size and quadratically with sequence length. This is why you can load a model just fine but OOM the moment you try to train with a large batch.

Key insight: Think of GPU memory like a suitcase. Weights, gradients, and optimizer states are the heavy items you always pack. Activations are the souvenirs — the more you try to bring home, the harder it gets to close the lid.

The PyTorch profiler reveals this dynamic beautifully. In the first training step, activations spike during the forward pass, then during the backward pass gradients build up while activations are freed. The optimizer state appears after the first step and persists.

Why the first step looks different: PyTorch's caching allocator does prep work on step 1, and optimizer states have not yet been created. This is why you sometimes survive step 1 but OOM on step 2.
Check: Which of the four memory categories grows with batch size?

Chapter 2: Weights & Optimizer Memory

Let us put numbers on the first three memory categories. The number of parameters in a standard transformer is approximately:

P ≈ 12 · L · h^2 + V · h

where L is the number of layers, h is the hidden dimension, and V is the vocabulary size. The h^2 term dominates for large models because it grows quadratically.
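As a quick sanity check on this formula, here is a minimal Python sketch. The configuration (32 layers, hidden size 4096, vocabulary of 32,000) is an assumed shape, roughly that of a 7B model, chosen only for illustration. It is not taken from the text.

```python
def approx_param_count(num_layers: int, hidden: int, vocab: int) -> int:
    """Approximate transformer parameter count: 12 * L * h^2 + V * h."""
    return 12 * num_layers * hidden**2 + vocab * hidden


# Assumed, illustrative configuration (roughly the shape of a 7B model).
params = approx_param_count(num_layers=32, hidden=4096, vocab=32_000)

# Fraction of the total that comes from the quadratic 12 * L * h^2 term.
block_share = 12 * 32 * 4096**2 / params

print(f"approx parameters: {params / 1e9:.2f}B")
print(f"share from the 12*L*h^2 term: {block_share:.1%}")
```

With this assumed shape, almost all parameters come from the quadratic term, which is the sense in which it "dominates".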

Each parameter is stored at some precision. FP32 uses 4 bytes per value; BF16 uses 2 bytes. The memory cost depends on the training regime:

| Item | FP32 training | Mixed precision (BF16) |
|---|---|---|
| Weights | 4P bytes | 2P (BF16) + 4P (FP32 master copy) |
| Gradients | 4P bytes | 2P (BF16) or 4P (FP32 accumulation) |
| Adam optimizer | 8P bytes (m + v) | 8P bytes (m + v in FP32) |
| Total | 16P bytes | 16P–18P bytes |

Surprising fact: Mixed precision does not save total memory for weights + gradients + optimizer. It just redistributes the bytes differently. The real win is in activations (smaller) and compute (faster BF16 ops on modern GPUs).

For concrete sizes: a 1B model needs ~16 GB, a 7B model needs ~112 GB, and a 70B model needs ~1,120 GB — that is 14 H100 GPUs just for the static model state.
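The arithmetic behind these figures can be sketched as follows. This is a back-of-the-envelope calculator, not a measurement: the regime names are made up for this example, 1 GB is taken as 10^9 bytes, and each GPU is assumed to offer its full 80 GB.

```python
import math

# Bytes per parameter for each regime, following the table above.
# The regime names are hypothetical labels for this sketch.
BYTES_PER_PARAM = {
    "fp32": 4 + 4 + 8,                  # weights + gradients + Adam (m, v)
    "mixed_bf16_grads": 2 + 4 + 2 + 8,  # BF16 weights + FP32 master + BF16 grads + Adam
    "mixed_fp32_grads": 2 + 4 + 4 + 8,  # same, but gradients accumulated in FP32
}


def static_memory_gb(num_params: float, regime: str) -> float:
    """Memory for weights + gradients + optimizer states, in GB (1 GB = 1e9 bytes)."""
    return num_params * BYTES_PER_PARAM[regime] / 1e9


def gpus_needed(memory_gb: float, gpu_gb: float = 80.0) -> int:
    """Minimum number of GPUs whose combined memory covers memory_gb."""
    return math.ceil(memory_gb / gpu_gb)


for billions in (1, 7, 70):
    gb = static_memory_gb(billions * 1e9, "mixed_bf16_grads")
    print(f"{billions:>3}B params: {gb:7.0f} GB static state, "
          f"at least {gpus_needed(gb)} x 80 GB GPUs")
```

Switching the regime to `"mixed_fp32_grads"` gives the 18P upper end of the range in the table.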

Check: For a 7B model with Adam in mixed precision, roughly how much memory do weights + gradients + optimizer need?

Chapter 3: Activation Memory

Now for the wild card. The memory required for activations during a forward pass through a transformer in mixed precision is approximately:

M_act ≈ L · s · b · (34h + 5 · s · n_heads)

where L = layers, s = sequence length, b = batch size, h = hidden dimension, and n_heads = number of attention heads. The result is in bytes.

Two things jump out. Activation memory grows linearly with batch size and quadratically with sequence length: the 5 · s · n_heads attention term is itself multiplied by the outer factor of s, which gives an s^2 contribution. The weights, gradients, and optimizer states are constant with respect to both.

Key insight: For short sequences or small batches, activations are negligible. But past ~2–4K tokens, activations become the dominant memory consumer. This is the "activation explosion."
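To get a feel for the scaling, the formula can be evaluated directly. The sketch below assumes the same illustrative shape as a 7B-class model (32 layers, hidden size 4096, 32 attention heads) and a batch size of 4. These values are assumptions for the example, and the result is treated as bytes.

```python
def activation_bytes(num_layers: int, seq_len: int, batch: int,
                     hidden: int, num_heads: int) -> int:
    """M_act ≈ L * s * b * (34h + 5 * s * n_heads), in bytes."""
    return num_layers * seq_len * batch * (34 * hidden + 5 * seq_len * num_heads)


# Assumed, illustrative model shape (not taken from the text).
shape = dict(num_layers=32, hidden=4096, num_heads=32)

for seq_len in (512, 2048, 4096, 8192):
    gb = activation_bytes(seq_len=seq_len, batch=4, **shape) / 1e9
    print(f"s={seq_len:>5}, b=4: {gb:8.1f} GB of activations")

base = activation_bytes(seq_len=4096, batch=4, **shape)
# Doubling the batch doubles activation memory (linear in b).
ratio_batch = activation_bytes(seq_len=4096, batch=8, **shape) / base
# Doubling the sequence length costs more than 2x because of the s^2 term.
ratio_seq = activation_bytes(seq_len=8192, batch=4, **shape) / base
print(f"2x batch -> {ratio_batch:.2f}x, 2x sequence -> {ratio_seq:.2f}x")
```

Under these assumptions the s = 4096 case comes to roughly 417 GB, several times the ~112 GB of static state for a 7B model, and the gap widens as the sequence grows.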

This explosion is exactly why we need the techniques in the rest of this chapter. Activation recomputation caps the memory, gradient accumulation avoids large batches, and mixed precision halves the per-value activation cost.

[Interactive simulation: Activation Memory vs. Sequence Length. A slider sets the sequence length (default 2048) and shows how activation memory grows.]

Check: How does activation memory scale with sequence length?

Chapter 4: Mixed Precision Training

Modern GPUs have specialized hardware for lower-precision arithmetic. An H100 performs BF16 matrix multiplications roughly twice as fast as FP32. So why not do everything in BF16?

The problem is numerical stability. Gradients can be very small, and BF16 has limited precision (about 3 decimal digits). Accumulating many small gradients in BF16 introduces errors that can destabilize training.

The solution is mixed precision training: run the forward and backward passes in BF16 (fast, memory-efficient) but keep a master copy of the weights and accumulate gradients in FP32 (stable).

Forward pass
Use BF16 weights → BF16 activations (fast, small)
Backward pass
BF16 gradients computed, optionally accumulated in FP32
Optimizer step
FP32 master weights updated, then cast to BF16 for next step
Why it works: Mixed precision does not reduce total model-state memory (the FP32 master copy costs an extra 4 bytes per parameter on top of the BF16 weights). Its two real wins are: (1) activations are halved because the forward pass is in BF16, and (2) BF16 matrix multiplications are ~2x faster on modern GPUs.

The FP32 master copy is sometimes called the "master weights" in code and papers. Libraries like Nanotron also accumulate gradients in FP32 because BF16 is lossy for small values.
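The reason for the FP32 master copy can be seen in a small pure-Python emulation. This is a sketch, not how any library implements it: `to_bf16` and `to_fp32` are hypothetical helpers that round a Python float to the nearest BF16 or FP32 value, and the update size of 0.001 is an arbitrary stand-in for a small learning-rate-scaled gradient.

```python
import struct


def to_fp32(x: float) -> float:
    """Round a Python float (FP64) to the nearest FP32 value."""
    return struct.unpack(">f", struct.pack(">f", x))[0]


def to_bf16(x: float) -> float:
    """Round to the nearest BF16 value (ties to even).

    BF16 keeps the top 16 bits of the FP32 encoding. NaN and infinity
    are not handled in this sketch.
    """
    bits = struct.unpack(">I", struct.pack(">f", x))[0]
    bits += 0x7FFF + ((bits >> 16) & 1)  # round-to-nearest-even on the dropped bits
    bits &= 0xFFFF0000                   # keep sign, exponent, and 7 mantissa bits
    return struct.unpack(">f", struct.pack(">I", bits))[0]


update = 0.001  # hypothetical small weight update
steps = 1000

# Weight stored and updated purely in BF16: each update is rounded away.
w_bf16 = 1.0
for _ in range(steps):
    w_bf16 = to_bf16(w_bf16 + update)

# FP32 master weight: updates accumulate, then we cast to BF16 for compute.
master = 1.0
for _ in range(steps):
    master = to_fp32(master + update)
w_cast = to_bf16(master)

print(f"pure BF16 weight after {steps} updates: {w_bf16}")
print(f"FP32 master weight: {master:.6f}, cast to BF16: {w_cast}")
```

Near 1.0, adjacent BF16 values are 2^-7 ≈ 0.0078 apart, so an update of 0.001 rounds back to the starting value every time. The FP32 master copy retains each update, and only the working copy used in the forward and backward passes is rounded.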

Check: What is the primary memory benefit of mixed precision training?

Chapter 5: Activation Recomputation

Now we tackle the activation explosion head-on. The idea is beautifully simple: throw away most activations during the forward pass and recompute them during the backward pass.

This is also called gradient checkpointing or rematerialization. We trade extra compute for less memory.

There are two main strategies:

| Strategy | Activations kept in memory | Memory savings | Compute cost |
|---|---|---|---|
| Full recomputation | Only layer boundary activations | Largest (~90%+ reduction) | ~30–40% extra FLOPs |
| Selective recomputation | MLP activations (attention activations are discarded) | ~70% reduction for GPT-3 scale | ~2.7% extra FLOPs |

Key insight: Selective recomputation targets attention computations specifically, because they produce the largest activations but are cheap to recompute. This is why FlashAttention natively includes selective recomputation — it recomputes attention scores in the backward pass instead of storing them.

The result is dramatic. For a model with sequence length 8192, full recomputation can reduce activation memory from hundreds of GB to a manageable fraction, while selective recomputation achieves most of the savings at a tiny compute cost.
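The mechanics of full recomputation can be illustrated with a toy model in plain Python. This is a sketch under strong simplifications: each "layer" is a scalar function tanh(w · x), the segment length of 4 is arbitrary, and all function names are invented for this example. It stores only the input of each segment during the forward pass and rebuilds the rest during the backward pass.

```python
import math


def layer_forward(w: float, x: float) -> float:
    return math.tanh(w * x)


def layer_backward(w: float, x: float, grad_out: float) -> float:
    """Gradient w.r.t. the layer input; needs the saved input x."""
    y = math.tanh(w * x)
    return grad_out * w * (1.0 - y * y)


def grad_store_everything(weights, x0):
    """Baseline: keep every layer input for the backward pass."""
    inputs, x = [], x0
    for w in weights:
        inputs.append(x)
        x = layer_forward(w, x)
    grad = 1.0
    for w, x_in in zip(reversed(weights), reversed(inputs)):
        grad = layer_backward(w, x_in, grad)
    return grad, len(inputs)


def grad_with_recompute(weights, x0, segment):
    """Keep only the input of each segment; recompute the rest on demand."""
    checkpoints, x = {}, x0
    for i, w in enumerate(weights):
        if i % segment == 0:
            checkpoints[i] = x
        x = layer_forward(w, x)

    grad, peak_stored = 1.0, len(checkpoints)
    for start in sorted(checkpoints, reverse=True):
        end = min(start + segment, len(weights))
        # Re-run the forward pass for this segment only.
        seg_inputs, x = [], checkpoints[start]
        for i in range(start, end):
            seg_inputs.append(x)
            x = layer_forward(weights[i], x)
        # Conservative count: checkpoints plus one segment's activations.
        peak_stored = max(peak_stored, len(checkpoints) + len(seg_inputs))
        for i in reversed(range(start, end)):
            grad = layer_backward(weights[i], seg_inputs[i - start], grad)
    return grad, peak_stored


weights = [0.5 + 0.1 * i for i in range(16)]
g_full, stored_full = grad_store_everything(weights, x0=0.3)
g_ckpt, stored_ckpt = grad_with_recompute(weights, x0=0.3, segment=4)
print(f"gradients match: {math.isclose(g_full, g_ckpt)}")
print(f"values held at peak: {stored_full} (baseline) vs {stored_ckpt} (recompute)")
```

The gradient is unchanged while the number of values held at once drops. The price is that every segment's forward pass runs a second time, which is the extra compute listed in the table.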

Hardware FLOPs vs. Model FLOPs: When measuring GPU efficiency, hardware FLOPs utilization (HFU) counts all operations including recomputation. Model FLOPs utilization (MFU) counts only the necessary forward/backward operations. MFU is more useful for comparing hardware, since it rewards GPUs that have enough memory to skip recomputation.
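A small worked example makes the difference concrete. It relies on assumptions that are not from the text: the common rule of thumb of about 6 FLOPs per parameter per token for forward plus backward (2 forward, 4 backward), one extra forward pass for full recomputation, and made-up values for throughput and peak hardware FLOPs.

```python
def utilization(flops_per_token: float, tokens_per_second: float,
                peak_flops: float) -> float:
    """Fraction of peak FLOPs achieved."""
    return flops_per_token * tokens_per_second / peak_flops


num_params = 7e9
tokens_per_second = 3_000.0  # hypothetical measured throughput
peak_flops = 1e15            # hypothetical peak FLOPs per second

# Assumed rule of thumb: forward ≈ 2P and backward ≈ 4P FLOPs per token.
model_flops_per_token = 6 * num_params
# Full recomputation re-runs the forward pass once more: +2P per token.
hardware_flops_per_token = 8 * num_params

mfu = utilization(model_flops_per_token, tokens_per_second, peak_flops)
hfu = utilization(hardware_flops_per_token, tokens_per_second, peak_flops)

print(f"MFU: {mfu:.1%}  HFU: {hfu:.1%}  ratio: {hfu / mfu:.2f}")
```

Under this rule of thumb, full recomputation adds about one third more FLOPs, in line with the ~30–40% figure quoted above. HFU rises because the GPU is doing more work, but MFU, which counts only the work the model needs, does not.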

[Interactive simulation: Recomputation Impact. Toggles switch between recomputation strategies to show their effect on memory at different sequence lengths.]

Check: Why is selective recomputation preferred over full recomputation in practice?

Chapter 6: Gradient Accumulation

Even with recomputation, activations still grow linearly with batch size. If our target global batch size is 4 million tokens but we can only fit 8K tokens in one pass, what do we do?

The answer is gradient accumulation: split the big batch into smaller micro-batches, run forward and backward on each, sum the gradients, and only then run the optimizer.

Micro-batch 1
Forward → Backward → accumulate gradients
Micro-batch 2
Forward → Backward → accumulate gradients
...
Repeat for all micro-batches
Optimizer Step
Average the accumulated gradients over the micro-batches, update weights

BS_global = BS_micro × grad_acc_steps

The global batch size can be arbitrarily large while the memory footprint stays at the micro-batch level. The cost? Multiple sequential forward/backward passes per optimizer step — it is slower in wall-clock time.
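The equivalence between one large batch and several accumulated micro-batches can be checked on a toy problem. The sketch below uses a one-parameter linear model with a mean-squared-error loss. The data, function names, and the reading of "4 million" and "8K" as powers of two are assumptions for illustration.

```python
def mse_grad(w: float, xs, ys) -> float:
    """d/dw of mean((w * x - y)^2) over the given samples."""
    return sum(2.0 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)


def accumulated_grad(w: float, xs, ys, micro_batch_size: int):
    """Accumulate micro-batch gradients; assumes equal-sized micro-batches."""
    assert len(xs) % micro_batch_size == 0
    grad_acc_steps = len(xs) // micro_batch_size
    grad = 0.0
    for step in range(grad_acc_steps):
        lo = step * micro_batch_size
        hi = lo + micro_batch_size
        # Scaling by 1/grad_acc_steps turns the running sum into the
        # mean over the whole global batch.
        grad += mse_grad(w, xs[lo:hi], ys[lo:hi]) / grad_acc_steps
    return grad, grad_acc_steps


# Toy data: a global batch of 8 samples, processed 2 at a time.
xs = [float(i) for i in range(1, 9)]
ys = [3.0 * x + 0.5 for x in xs]
w = 1.0

full = mse_grad(w, xs, ys)
acc, steps = accumulated_grad(w, xs, ys, micro_batch_size=2)
print(f"full-batch gradient: {full:.4f}")
print(f"accumulated over {steps} micro-batches: {acc:.4f}")

# Number of accumulation steps for the example in the text, reading
# "4 million" as 4 * 1024^2 tokens and "8K" as 8 * 1024 tokens.
steps_for_4m = (4 * 1024**2) // (8 * 1024)
print(f"grad_acc_steps for a 4M-token batch at 8K per pass: {steps_for_4m}")
```

Only one micro-batch is processed at a time, so peak activation memory follows the micro-batch size, while the optimizer sees the same gradient as it would for the full batch.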

Key insight: The forward/backward passes for each micro-batch are independent: they use different data but the same weights. This means they can be run in parallel on different GPUs. That is exactly what data parallelism does, and it is where the book goes next, in its Chapter 3.
Check: Gradient accumulation lets you increase the effective batch size. What does it cost?

Chapter 7: Memory Explorer

Let us bring everything together. This interactive tool lets you configure a model and see how memory is distributed across the four categories. Toggle activation recomputation and mixed precision to see their effects.

[Interactive simulation: GPU Memory Explorer. Controls: parameters in billions (default 7), sequence length (default 2048), and batch size (default 4). The display shows how the four memory categories change.]

Check: For a 7B model with batch size 4 and sequence length 4096, what dominates memory?

Chapter 8: Summary

We have now mastered the single-GPU toolkit. Here is the complete picture:

| Technique | What it does | Cost |
|---|---|---|
| Mixed precision | Halves activation memory; 2x faster matmuls | Extra 4 bytes/param for FP32 master weights |
| Selective recomputation | ~70% activation reduction | ~2.7% extra compute |
| Full recomputation | ~90%+ activation reduction | ~30–40% extra compute |
| Gradient accumulation | Arbitrary batch size at constant memory | Sequential micro-batch passes |

What comes next: Gradient accumulation processes micro-batches sequentially. But the forward/backward passes are independent, so they can be parallelized across GPUs. That is exactly the idea behind data parallelism, the topic of the book's Chapter 3.

The profiler is your best friend. Use PyTorch's built-in profiler to trace memory allocation, identify bottlenecks, and validate that your optimizations are working. The trace shows CPU threads launching GPU kernels, CUDA streams handling compute and communication in parallel, and memory allocation patterns.

Check: A colleague says "mixed precision halves total model memory." Is this correct?