Memory breakdown, activation recomputation, gradient accumulation, and mixed precision — the single-GPU survival kit.
You have one GPU. You load a model and start training. Within seconds you see the dreaded message: CUDA out of memory.
This is the single most common failure mode when training neural networks. Before we can dream of scaling to hundreds of GPUs, we need to understand why we run out of memory and what we can do about it on a single device.
A typical H100 GPU has 80 GB of memory. That sounds like a lot. But a 7B parameter model in mixed precision already needs about 112 GB just for weights, gradients, and optimizer states — before a single activation is stored.
In this chapter, we will dissect GPU memory usage during training, learn three techniques to tame it (mixed precision, activation recomputation, gradient accumulation), and build a mental model that will guide every decision in the rest of the book.
During training, four categories of data live on your GPU: model weights, gradients, optimizer states, and activations.
The first three (weights, gradients, optimizer states) depend only on the number of parameters. They do not change with batch size or sequence length.
Activations are the wild card. They grow linearly with batch size and quadratically with sequence length. This is why you can load a model just fine but OOM the moment you try to train with a large batch.
The PyTorch profiler reveals this dynamic beautifully. In the first training step, activations spike during the forward pass, then during the backward pass gradients build up while activations are freed. The optimizer state appears after the first step and persists.
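Let us put numbers on the first three memory categories. To leading order, the number of parameters in a standard transformer (with an MLP width of 4h) is approximately:

$$P \approx 12 L h^2 + V h$$

where $L$ is the number of layers, $h$ is the hidden dimension, and $V$ is the vocabulary size. Smaller terms such as biases and layer-norm parameters, which scale only linearly with $h$, are omitted here. The $h^2$ term dominates for large models because it grows quadratically.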
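As a quick sanity check, the estimate can be written as a few lines of Python. This is a minimal sketch that assumes the common leading-order form 12·L·h² + V·h (attention plus a 4h-wide MLP per layer, plus the embedding matrix). The function name and the example configuration are hypothetical and chosen only for illustration.

```python
def approx_transformer_params(num_layers: int, hidden: int, vocab: int) -> int:
    """Leading-order parameter count: 12 * L * h^2 for the layers plus V * h for embeddings."""
    per_layer = 12 * hidden * hidden  # about 4h^2 for attention + 8h^2 for a 4h-wide MLP
    embeddings = vocab * hidden       # token embedding matrix
    return num_layers * per_layer + embeddings


# Hypothetical 7B-class configuration (not taken from the text).
params = approx_transformer_params(num_layers=32, hidden=4096, vocab=32000)
print(f"{params / 1e9:.2f}B parameters")  # prints "6.57B parameters"
```

Doubling `hidden` roughly quadruples the result, while doubling `vocab` barely moves it, which is the sense in which the quadratic term dominates.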
Each parameter is stored at some precision. FP32 uses 4 bytes per value; BF16 uses 2 bytes. The memory cost depends on the training regime:
| Item | FP32 Training | Mixed Precision (BF16) |
|---|---|---|
| Weights | 4P bytes | 2P (BF16) + 4P (FP32 master copy) |
| Gradients | 4P bytes | 2P (BF16) or 4P (FP32 accumulation) |
| Adam optimizer | 8P bytes (m + v) | 8P bytes (m + v in FP32) |
| Total | 16P bytes | 16P–18P bytes |
For concrete sizes, at 16 bytes per parameter: a 1B model needs ~16 GB, a 7B model needs ~112 GB, and a 70B model needs ~1,120 GB. That is 14 H100 GPUs (80 GB each) just for the static model state.
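The arithmetic behind these figures is easy to reproduce. The sketch below reads the bytes-per-parameter values off the table and uses decimal gigabytes (1 GB = 10^9 bytes) to match the round numbers in the text. The names `BYTES_PER_PARAM`, `static_memory_gb`, and `min_gpus` are illustrative, not part of any library.

```python
import math

GB = 1e9             # decimal gigabytes, matching the round numbers in the text
H100_MEMORY_GB = 80  # memory of one H100, as stated above

# Bytes per parameter for weights + gradients + Adam states, per the table.
BYTES_PER_PARAM = {
    "fp32": 4 + 4 + 8,                    # FP32 weights, FP32 grads, Adam m + v
    "mixed_bf16_grads": (2 + 4) + 2 + 8,  # BF16 weights + FP32 master, BF16 grads, Adam
    "mixed_fp32_grads": (2 + 4) + 4 + 8,  # same, with gradients accumulated in FP32
}


def static_memory_gb(num_params: int, regime: str = "mixed_bf16_grads") -> float:
    """Memory for weights, gradients, and optimizer states (no activations)."""
    return num_params * BYTES_PER_PARAM[regime] / GB


def min_gpus(num_params: int, regime: str = "mixed_bf16_grads",
             gpu_gb: int = H100_MEMORY_GB) -> int:
    """Smallest number of GPUs whose combined memory holds the static state."""
    return math.ceil(static_memory_gb(num_params, regime) / gpu_gb)


for n in (1_000_000_000, 7_000_000_000, 70_000_000_000):
    print(f"{n / 1e9:>4.0f}B params: {static_memory_gb(n):>7.0f} GB, "
          f"at least {min_gpus(n)} H100(s)")
```

With FP32 gradient accumulation the per-parameter cost rises to 18 bytes, so the same 7B model needs about 126 GB before any activations.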
Now for the wild card. The memory required for activations during a forward pass through a transformer in mixed precision is approximately:

$$m_{act} \approx L \cdot s \cdot b \cdot h \cdot \left(34 + \frac{5 \cdot n_{heads} \cdot s}{h}\right) \text{ bytes}$$

where $L$ = layers, $s$ = sequence length, $b$ = batch size, $h$ = hidden dimension, and $n_{heads}$ = number of attention heads.

Two things jump out. Activation memory grows linearly with batch size and quadratically with sequence length (from the $s \cdot n_{heads}$ attention term, which is multiplied by the factor of $s$ outside the parentheses). The weights, gradients, and optimizer states are constant with respect to both.
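To see the scaling numerically, here is a minimal sketch that assumes the widely used mixed-precision estimate L·s·b·h·(34 + 5·n_heads·s/h) bytes. The function name and the model configuration are hypothetical, picked only to produce readable numbers.

```python
def activation_bytes(num_layers: int, seq_len: int, batch_size: int,
                     hidden: int, num_heads: int) -> int:
    """Approximate activation memory in bytes for one forward pass.

    Implements L * s * b * h * (34 + 5 * n_heads * s / h), rearranged so that
    the arithmetic stays in integers.
    """
    per_token = 34 * hidden + 5 * num_heads * seq_len
    return num_layers * seq_len * batch_size * per_token


# Hypothetical 7B-class configuration (not taken from the text).
config = dict(num_layers=32, hidden=4096, num_heads=32)

base = activation_bytes(seq_len=2048, batch_size=1, **config)
double_batch = activation_bytes(seq_len=2048, batch_size=2, **config)
double_seq = activation_bytes(seq_len=4096, batch_size=1, **config)

print(f"s=2048, b=1: {base / 1e9:.1f} GB")
print(f"s=2048, b=2: {double_batch / 1e9:.1f} GB ({double_batch / base:.2f}x)")
print(f"s=4096, b=1: {double_seq / 1e9:.1f} GB ({double_seq / base:.2f}x)")
```

Doubling the batch size exactly doubles the estimate, while doubling the sequence length more than doubles it because of the attention term.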
This explosion is exactly why we need the techniques in the rest of this chapter. Activation recomputation caps how many activations are stored, gradient accumulation keeps each forward/backward pass small even when the global batch is large, and mixed precision halves the per-value activation cost.
Modern GPUs have specialized hardware for lower-precision arithmetic. An H100 performs BF16 matrix multiplications roughly twice as fast as FP32. So why not do everything in BF16?
The problem is numerical stability. Gradients can be very small, and BF16 has limited precision (an 8-bit significand, roughly 2 to 3 decimal digits). Accumulating many small gradients in BF16 introduces rounding errors that can destabilize training.
The solution is mixed precision training: run the forward and backward passes in BF16 (fast, memory-efficient) but keep a master copy of the weights and accumulate gradients in FP32 (stable).
The FP32 master copy is sometimes called the "master weights" in code and papers. Libraries like Nanotron also accumulate gradients in FP32 because BF16 is lossy for small values.
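The effect is easy to reproduce without a GPU. The sketch below simulates BF16 by rounding an FP32 bit pattern to its top 16 bits (round-to-nearest-even) and then adds the same small value many times. The helper `to_bf16` is a hypothetical illustration for finite inputs, not a library function, and a 64-bit Python float stands in for the higher-precision accumulator.

```python
import struct


def to_bf16(x: float) -> float:
    """Round x to the nearest BF16 value (round-to-nearest-even), returned as a float.

    Illustrative helper for finite inputs only; NaN and infinity are not handled.
    """
    bits = struct.unpack("<I", struct.pack("<f", x))[0]  # FP32 bit pattern
    bits += 0x7FFF + ((bits >> 16) & 1)                  # round on the 16 bits being dropped
    bits &= 0xFFFF0000                                   # keep sign, exponent, 7 mantissa bits
    return struct.unpack("<f", struct.pack("<I", bits))[0]


grad = to_bf16(1e-3)  # one small gradient contribution, already in BF16
steps = 1000

bf16_total = 0.0       # running sum, rounded to BF16 after every addition
reference_total = 0.0  # 64-bit Python float as a stand-in for a wider accumulator
for _ in range(steps):
    bf16_total = to_bf16(bf16_total + grad)
    reference_total += grad

print(f"BF16 accumulator:      {bf16_total}")
print(f"reference accumulator: {reference_total}")
```

The BF16 sum stops growing once it reaches 0.5: from there on, neighbouring BF16 values are 2^-8 apart, the increment is less than half of that gap, and every further addition rounds back to the same value. The wider accumulator ends close to 1.0. The same mechanism explains why a weight near 1.0 stored only in BF16 would ignore an update of 0.001.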
Now we tackle the activation explosion head-on. The idea is beautifully simple: throw away most activations during the forward pass and recompute them during the backward pass.
This is also called gradient checkpointing or rematerialization. We trade extra compute for less memory.
There are two main strategies:
| Strategy | What is saved | Memory savings | Compute cost |
|---|---|---|---|
| Full recomputation | Only layer boundary activations | Largest (~90%+ reduction) | ~30–40% extra FLOPs |
| Selective recomputation | MLP activations (discard attention) | ~70% reduction for GPT-3 scale | ~2.7% extra FLOPs |
The result is dramatic. For a model with sequence length 8192, full recomputation can reduce activation memory from hundreds of GB to a manageable fraction, while selective recomputation achieves most of the savings at a tiny compute cost.
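The percentages in the table can be checked with a rough per-layer estimate. This sketch assumes the activation formula s·b·h·(34 + 5·n_heads·s/h) bytes per layer, that selective recomputation discards the quadratic attention term, and that full recomputation keeps only the 16-bit layer input (2·s·b·h bytes). The function name is hypothetical, and the configuration uses the publicly reported GPT-3 175B shape as an assumption.

```python
def stored_activation_bytes_per_layer(seq_len: int, batch_size: int, hidden: int,
                                      num_heads: int, strategy: str = "none") -> int:
    """Approximate activation bytes kept per layer under a recomputation strategy."""
    linear_term = 34 * seq_len * batch_size * hidden                  # grows with s
    attention_term = 5 * num_heads * seq_len * seq_len * batch_size   # grows with s^2
    if strategy == "none":
        return linear_term + attention_term
    if strategy == "selective":
        return linear_term                      # attention internals are recomputed
    if strategy == "full":
        return 2 * seq_len * batch_size * hidden  # only the layer input is kept
    raise ValueError(f"unknown strategy: {strategy}")


# Assumed GPT-3 175B-like shape: h = 12288, 96 heads, sequence length 2048.
shape = dict(seq_len=2048, batch_size=1, hidden=12288, num_heads=96)
baseline = stored_activation_bytes_per_layer(**shape, strategy="none")
selective_reduction = 1 - stored_activation_bytes_per_layer(**shape, strategy="selective") / baseline
full_reduction = 1 - stored_activation_bytes_per_layer(**shape, strategy="full") / baseline

print(f"selective: {selective_reduction:.0%} less activation memory")
print(f"full:      {full_reduction:.0%} less activation memory")
```

The estimate counts stored activations only. During the backward pass, full recomputation also needs room for the activations of the one layer currently being recomputed, so the real peak is somewhat higher than this figure suggests.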
Even with recomputation, activations still grow linearly with batch size. If our target global batch size is 4 million tokens but we can only fit 8K tokens in one pass, what do we do?
The answer is gradient accumulation: split the big batch into smaller micro-batches, run forward and backward on each, sum the gradients, and only then run the optimizer.
The global batch size can be arbitrarily large while the memory footprint stays at the micro-batch level. The cost? Multiple sequential forward/backward passes per optimizer step — it is slower in wall-clock time.
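The key property is that averaging micro-batch gradients reproduces the full-batch gradient. The sketch below shows this on a toy one-parameter model in plain Python; the model, data, and names are hypothetical and exist only to make the arithmetic visible. The token counts in the first lines read "4 million" and "8K" as round decimal numbers, which is an assumption.

```python
def accumulation_steps(global_batch_tokens: int, micro_batch_tokens: int) -> int:
    """Number of micro-batches needed per optimizer step (ceiling division)."""
    return -(-global_batch_tokens // micro_batch_tokens)


print(accumulation_steps(4_000_000, 8_000))  # prints 500


def grad_mse(w: float, xs: list, ys: list) -> float:
    """Gradient of the mean squared error for the toy model y = w * x."""
    return sum(2 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)


xs = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]
ys = [2.0 * x for x in xs]
w = 0.5
micro_batch_size = 2

# Reference: gradient over the whole batch in a single pass.
full_grad = grad_mse(w, xs, ys)

# Gradient accumulation: one micro-batch at a time, summing scaled gradients.
steps = len(xs) // micro_batch_size
accumulated = 0.0
for i in range(steps):
    lo, hi = i * micro_batch_size, (i + 1) * micro_batch_size
    accumulated += grad_mse(w, xs[lo:hi], ys[lo:hi]) / steps  # scale so the sum is a mean

# Only now does the optimizer step run, once for the whole global batch.
learning_rate = 0.01
w_new = w - learning_rate * accumulated

print(full_grad, accumulated)  # prints -76.5 -76.5
```

In PyTorch the same pattern usually amounts to dividing each micro-batch loss by the number of accumulation steps, calling `loss.backward()` on every micro-batch so that gradients add up in `.grad`, and calling `optimizer.step()` followed by `optimizer.zero_grad()` only after the last micro-batch.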
Let us bring everything together. For any given model size and sequence length, memory is distributed across the four categories (weights, gradients, optimizer states, and activations), and enabling activation recomputation or mixed precision shifts the balance between them.
We have now mastered the single-GPU toolkit. Here is the complete picture:
| Technique | What it does | Cost |
|---|---|---|
| Mixed precision | Halves activation memory; 2x faster matmuls | Extra 4 bytes/param for FP32 master weights |
| Selective recomputation | ~70% activation reduction | ~2.7% extra compute |
| Full recomputation | ~90%+ activation reduction | ~30–40% extra compute |
| Gradient accumulation | Arbitrary batch size at constant memory | Sequential micro-batch passes |
The profiler is your best friend. Use PyTorch's built-in profiler to trace memory allocation, identify bottlenecks, and validate that your optimizations are working. The trace shows CPU threads launching GPU kernels, CUDA streams handling compute and communication in parallel, and memory allocation patterns.