Epochs, batches, DataLoaders, and the six operations that make up every training step — the complete anatomy of how models learn.
You read a tutorial that says "train for 3 epochs with batch size 32." What does that actually mean? How many times does each example get seen? How many gradient updates happen? How many forward passes? These questions have precise answers, and getting them wrong means your training run does something completely different from what you intended.
This isn't pedantry. If you set max_steps=1000 thinking that means "1000 epochs" when it actually means "1000 gradient updates," your model trains for a fraction of what you planned. If you confuse "iteration" with "epoch," your learning rate schedule fires at the wrong time. The vocabulary matters because the code uses it literally.
Let's nail down every term with concrete numbers. No ambiguity. No "it depends." Just arithmetic.
Suppose you have a dataset of 10,000 training examples. Images, sentences, whatever — 10,000 of them, each with a label.
You pick a batch size of 32. This means the model sees 32 examples at once, computes the average loss across those 32, and makes one weight update. Why 32 and not all 10,000? Because (a) a full dataset usually won't fit in GPU memory at once, and (b) updating after every small batch typically converges in far fewer passes over the data than making a single update per pass.
One iteration (also called a step) is one cycle of: load a batch of 32 → forward pass → compute loss → backward pass → update weights. One iteration processes one batch.
How many iterations make up one pass through the entire dataset? That's simple division: 10,000 / 32 = 312.5. You can't process half a batch, so the last batch has only 16 examples (the remainder). That gives us ceil(10000 / 32) = 313 iterations to see every example once.
One complete pass through the entire dataset is called an epoch. One epoch = 313 iterations (with our numbers). After one epoch, every example has been seen exactly once.
If you train for 3 epochs, the total number of gradient updates is 3 × 313 = 939 steps. Each example gets seen 3 times total (once per epoch). The order is reshuffled each epoch, so the batches are different — but every example appears exactly once per epoch.
Let's work through a concrete example by hand.
Setup: Dataset has 1,000 examples. Batch size = 64. Train for 5 epochs.
Iterations per epoch: ceil(1000 / 64) = ceil(15.625) = 16 iterations.
Let's be precise about that last batch. The first 15 batches each have 64 examples: 15 × 64 = 960. The 16th batch has the remaining 1000 - 960 = 40 examples. It's a partial batch — smaller than the rest, but still processed as one step.
Total steps across all epochs: 5 × 16 = 80 gradient updates.
Total forward passes: 80 batches total. 75 batches of 64 + 5 batches of 40 = 4800 + 200 = 5000 example-level forward passes. That's exactly 5 × 1000 — every example seen exactly 5 times. The math checks out.
Memory implication: At any given step, only 64 examples (or 40 for the last batch) live in GPU memory. The model computes gradients on this small sample and updates. Then the batch is discarded and the next one is loaded. This is why you can train on datasets that are far larger than your GPU memory.
Sometimes you want a larger effective batch size but can't fit it in GPU memory. Gradient accumulation solves this by splitting the large batch into smaller micro-batches, computing gradients on each, and summing them before updating.
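With batch size 32 and gradient accumulation steps = 4, the effective batch size is 32 × 4 = 128. Each micro-batch of 32 still gets its own forward and backward pass, but the optimizer updates the weights only once every 4 micro-batches.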
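A minimal sketch of that setup, assuming a hypothetical toy dataset of 10,000 random examples and a single linear layer (neither comes from the original text). It counts micro-batch steps and weight updates separately so the two numbers are not confused.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(0)

# Hypothetical toy data: 10,000 examples, 4 features, 2 classes.
X = torch.randn(10_000, 4)
y = torch.randint(0, 2, (10_000,))
loader = DataLoader(TensorDataset(X, y), batch_size=32, shuffle=True)

model = nn.Linear(4, 2)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

accum_steps = 4      # effective batch size = 32 * 4 = 128
micro_steps = 0      # forward + backward passes
weight_updates = 0   # optimizer.step() calls

optimizer.zero_grad()
for i, (inputs, targets) in enumerate(loader, start=1):
    loss = criterion(model(inputs), targets)
    # Divide by accum_steps so the summed gradients approximate the
    # average over the effective batch rather than 4x that average.
    (loss / accum_steps).backward()
    micro_steps += 1

    # Update every accum_steps micro-batches; also flush the leftover
    # micro-batches at the end of the epoch (a simplification: the
    # leftover group is still scaled by 1 / accum_steps).
    if i % accum_steps == 0 or i == len(loader):
        optimizer.step()
        optimizer.zero_grad()
        weight_updates += 1

print(micro_steps, weight_updates)  # 313 79
```

One epoch runs 313 forward/backward passes but only 79 optimizer updates, so a scheduler stepped per update and a logger stepped per micro-batch will disagree by roughly a factor of 4.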
scheduler.step() counts weight updates. In some logging setups, the step counter increments on every micro-batch. Confusing the two means your learning rate schedule, logging frequency, and checkpoint interval are all wrong.
Adjust the parameters and watch the arithmetic update in real time. The timeline below shows epochs as large colored blocks with individual iterations as ticks. Click Animate to watch a single step in action: 32 examples selected, processed, and weights updated.
Enter dataset size, batch size, and epochs. See exactly how many steps, forward passes, and gradient updates happen.
```python
import math


def training_arithmetic(dataset_size, batch_size, epochs, accum_steps=1):
    """Calculate every key number for a training run."""
    # Iterations per epoch = how many batches to see all data once
    iters_per_epoch = math.ceil(dataset_size / batch_size)

    # Last batch may be partial
    last_batch_size = dataset_size % batch_size
    if last_batch_size == 0:
        last_batch_size = batch_size  # divides evenly

    # Total micro-batch steps (forward+backward passes)
    total_micro_steps = iters_per_epoch * epochs

    # Weight updates (accounting for gradient accumulation)
    effective_batch = batch_size * accum_steps
    updates_per_epoch = math.ceil(dataset_size / effective_batch)
    total_weight_updates = updates_per_epoch * epochs

    # Each example is seen exactly once per epoch
    total_examples_seen = dataset_size * epochs

    return {
        "iters_per_epoch": iters_per_epoch,
        "last_batch_size": last_batch_size,
        "total_micro_steps": total_micro_steps,
        "total_weight_updates": total_weight_updates,
        "effective_batch_size": effective_batch,
        "total_examples_seen": total_examples_seen,
    }


# Example: 10K dataset, batch 32, 3 epochs, no accumulation
result = training_arithmetic(10000, 32, 3)
print(f"Iters/epoch: {result['iters_per_epoch']}")       # 313
print(f"Last batch: {result['last_batch_size']}")        # 16
print(f"Total steps: {result['total_weight_updates']}")  # 939
print(f"Examples: {result['total_examples_seen']}")      # 30000

# With gradient accumulation = 4
result2 = training_arithmetic(10000, 32, 3, accum_steps=4)
print(f"Effective batch: {result2['effective_batch_size']}")  # 128
print(f"Weight updates: {result2['total_weight_updates']}")   # 237
```
| Term | Definition | Formula (our example) |
|---|---|---|
| Dataset size | Total number of training examples | N = 10,000 |
| Batch size (B) | Examples processed per forward pass | B = 32 |
| Iteration / Step | One forward + backward + update cycle | ceil(N/B) = 313 per epoch |
| Epoch | One complete pass through the dataset | 313 iterations |
| Total steps | Epochs × iterations per epoch | 3 × 313 = 939 |
| Gradient accumulation | Micro-batches summed before update | Accum=4 → eff. batch=128 |
| Last batch | Partial batch at end of epoch | 10000 mod 32 = 16 |
Your data lives in files — images on disk, text in JSONLs, tensors in .pt files. Your model expects tensors on GPU. Between these two worlds sits the DataLoader: it reads, transforms, batches, shuffles, and delivers data to your training loop. Understanding this pipeline is essential because data bottlenecks waste more GPU hours than bad hyperparameters.
But the DataLoader doesn't work alone. It wraps a Dataset — the object that knows how to read one example. This separation of concerns is powerful: change the data format, change only the Dataset. Change the batching strategy, change only the DataLoader. The model never knows the difference.
A PyTorch Dataset is any Python object that implements two methods:
__len__(): returns the total number of examples. The DataLoader calls this to know when an epoch ends.
__getitem__(idx): given an integer index, returns one example. Typically a tuple of (input_tensor, target_tensor).
That's the entire contract. Two methods. The simplest Dataset is a wrapper around two lists:
```python
from torch.utils.data import Dataset


class SimpleDataset(Dataset):
    def __init__(self, inputs, targets):
        self.inputs = inputs    # list of tensors
        self.targets = targets  # list of labels

    def __len__(self):
        return len(self.inputs)  # e.g. 10000

    def __getitem__(self, idx):
        return self.inputs[idx], self.targets[idx]


# Usage: ds[47] returns (inputs[47], targets[47])
#        len(ds) returns 10000
```
A more realistic Dataset reads from disk. Each call to __getitem__ opens a file, applies transforms (resize, normalize, augment), and returns tensors. The key insight: data lives on disk, but enters the training loop one example at a time through this method.
```python
from PIL import Image
from torch.utils.data import Dataset


class ImageDataset(Dataset):
    def __init__(self, image_paths, labels, transform=None):
        self.paths = image_paths    # ["/data/img001.jpg", ...]
        self.labels = labels        # [0, 1, 0, 2, ...]
        self.transform = transform  # torchvision.transforms

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):
        img = Image.open(self.paths[idx])  # disk read
        if self.transform:
            img = self.transform(img)      # resize, normalize, augment
        return img, self.labels[idx]
```
The DataLoader wraps a Dataset and handles everything the Dataset doesn't: batching, shuffling, parallel loading, and collation. Here are the critical parameters:
| Parameter | Default | What it does |
|---|---|---|
| batch_size | 1 | How many examples per batch |
| shuffle | False | Randomize order each epoch |
| num_workers | 0 | Parallel data loading processes |
| collate_fn | default | How to merge examples into a batch tensor |
| drop_last | False | Skip incomplete final batch |
| pin_memory | False | Copy batches into page-locked (pinned) CPU memory for faster CPU→GPU transfer |
```python
from torch.utils.data import DataLoader

loader = DataLoader(
    dataset,           # your Dataset object
    batch_size=32,     # 32 examples per batch
    shuffle=True,      # randomize order each epoch
    num_workers=4,     # 4 parallel loading processes
    pin_memory=True,   # faster GPU transfer
    drop_last=False,   # keep the partial final batch
)

# Iterate: each 'batch' is a tuple (inputs, targets)
# inputs.shape = (32, ...), targets.shape = (32,)
for inputs, targets in loader:
    outputs = model(inputs)
    loss = criterion(outputs, targets)
    # ... backward, step, zero_grad
```
Setup: Dataset with 100 examples, batch_size=16, shuffle=True, drop_last=False.
Epoch start: DataLoader generates a random permutation of indices [0..99]. Let's say: [47, 3, 82, 91, 12, 56, 7, 33, 68, 24, 99, 61, 45, 78, 15, 2, 88, ...].
Batch 1: indices [47, 3, 82, 91, 12, 56, 7, 33, 68, 24, 99, 61, 45, 78, 15, 2] — 16 examples. Calls dataset[47], dataset[3], ... , dataset[2], then stacks them into tensors.
Batch 2: next 16 indices from the permutation. And so on.
How many batches? ceil(100 / 16) = 7 batches. First 6 have 16 examples each (96 total). Last batch has 100 - 96 = 4 examples.
With drop_last=True: only 6 batches. Those last 4 examples are never seen this epoch (with shuffle=True, a different 4 are dropped next epoch). This ensures all batches have identical size, which matters for some layers (like batch normalization, where statistics from a batch of 4 are unreliable).
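The batch counts above can be checked directly. This is a small sketch assuming a hypothetical stand-in dataset of 100 one-feature examples; only the batch sizes matter here.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Hypothetical stand-in dataset: 100 examples, one feature each.
dataset = TensorDataset(torch.arange(100).float().unsqueeze(1), torch.zeros(100))


def batch_sizes(drop_last):
    """Return the size of every batch produced in one epoch."""
    loader = DataLoader(dataset, batch_size=16, shuffle=True, drop_last=drop_last)
    return [inputs.size(0) for inputs, _ in loader]


keep = batch_sizes(drop_last=False)
drop = batch_sizes(drop_last=True)

print(len(keep), keep[-1])  # 7 4   (six full batches plus a partial one)
print(len(drop), drop[-1])  # 6 16  (partial batch discarded)
```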
If __getitem__ is fast (reading from memory, doing minimal transforms), the overhead of spawning processes and transferring data between them can make things slower. Rule of thumb: num_workers=0 for in-memory data, num_workers=4-8 for disk-heavy loading (images, audio). Profile before assuming more workers = more speed.
Watch how data flows from the Dataset through the DataLoader to the model. Adjust batch size and toggle shuffle/drop_last to see the effect on batch composition.
Left: numbered examples in the Dataset. Middle: DataLoader shuffles and groups into batches. Right: batches emerge ready for the model. Toggle options to see effects.
If you feed the model cats in batches 1-50 and dogs in batches 51-100, it will learn "first comes cats, then dogs", a completely spurious pattern. The loss oscillates wildly as the model overfits to one class, then scrambles to relearn when the next class arrives. Shuffling randomizes the order so each batch is a representative sample of the whole dataset.
But shuffling is just one sampling strategy. What if your dataset has 90% dogs and 10% cats? Even with shuffling, most batches are overwhelmingly dogs. What about distributing data across 8 GPUs so none sees duplicates? Each of these problems has a different sampler.
Without shuffling, examples in the same class are often adjacent. Files sorted by name, data collected chronologically, labels grouped for human convenience — all create order in the data. The model sees many similar examples in a row, producing biased gradients that overfit to whatever class it's currently seeing.
With shuffling, each batch contains a mix of classes. The gradient from each batch points in a more representative direction. Training is smoother and converges faster.
A Sampler is the object that generates the sequence of indices. The DataLoader delegates all ordering decisions to it.
SequentialSampler — iterates indices 0, 1, 2, ..., N-1 in order. This is the default when shuffle=False. Used for validation and test sets where you want reproducible, deterministic order.
RandomSampler — generates a random permutation of [0..N-1]. This is the default when shuffle=True. Each epoch produces a different permutation.
WeightedRandomSampler — each example has a weight. Higher-weight examples are sampled more often. This is how you fix class imbalance: give rare-class examples higher weights so they appear more frequently.
DistributedSampler — for multi-GPU training. Splits the dataset across N GPUs so each GPU sees a unique, non-overlapping shard. GPU 0 gets examples [0, N, 2N, ...], GPU 1 gets [1, N+1, 2N+1, ...], etc. Each GPU processes 1/N of the data per epoch.
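The index sequences these samplers produce can be inspected without a model or a GPU. The sketch below assumes a hypothetical 12-example dataset and simulates 4 GPUs by passing num_replicas and rank explicitly, with shuffle=False so the strided pattern is visible.

```python
from torch.utils.data import DistributedSampler, RandomSampler, SequentialSampler

data = list(range(12))  # hypothetical 12-example dataset

sequential = list(SequentialSampler(data))
print(sequential)  # [0, 1, 2, ..., 11]

perm = list(RandomSampler(data))  # a random permutation of 0..11
print(sorted(perm) == data)       # True: every index appears exactly once

# Assumption: 4 GPUs, simulated by constructing one sampler per rank.
shards = [
    list(DistributedSampler(data, num_replicas=4, rank=r, shuffle=False))
    for r in range(4)
]
print(shards[0], shards[1])  # [0, 4, 8] [1, 5, 9]
```

When the dataset size is not divisible by the number of replicas, DistributedSampler pads with repeated indices by default, so shards can overlap slightly in that case.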
Problem: You have 900 "dog" images and 100 "cat" images. Without weighting, a random batch of 32 will have about 29 dogs and 3 cats. The model barely learns cats.
Solution: Give each example a sampling weight inversely proportional to its class frequency.
The sampler picks each example with probability proportional to its weight. After normalization, a cat example is 9× more likely to be selected than any individual dog example. Since there are 9× fewer cats, the expected composition of each batch becomes roughly 50/50.
```python
import torch
from torch.utils.data import DataLoader, WeightedRandomSampler

# 900 dogs (label 0), 100 cats (label 1)
labels = [0] * 900 + [1] * 100

# Weight per CLASS (inverse frequency)
class_counts = [900, 100]
class_weights = [1.0 / c for c in class_counts]  # [0.00111, 0.01]

# Weight per EXAMPLE (look up by its label)
sample_weights = [class_weights[l] for l in labels]

sampler = WeightedRandomSampler(
    weights=sample_weights,
    num_samples=len(labels),  # draw 1000 samples per epoch
    replacement=True,         # lets rare examples be drawn more than once
)

# dataset is assumed to be defined elsewhere, in the same order as labels
loader = DataLoader(dataset, batch_size=32, sampler=sampler)
# Now each batch has ~16 dogs + ~16 cats
```
The DataLoader calls dataset[idx] for each index in a batch, getting a list of individual examples. The collate function merges them into batch tensors.
The default collation uses torch.stack — it works when all examples have the same shape. Images of size (3, 224, 224) stack into a batch of (B, 3, 224, 224). Simple.
But what about variable-length sequences? One sentence has 12 tokens, another has 47. You can't stack tensors of different lengths. The custom collate function pads shorter sequences to the maximum length in the batch and creates an attention mask telling the model which positions are real and which are padding.
```python
import torch
from torch.utils.data import DataLoader


def collate_text(batch):
    """Pad variable-length sequences to the same length."""
    inputs, labels = zip(*batch)  # unpack list of (input, label)

    # Find max length in this batch
    max_len = max(len(x) for x in inputs)

    # Pad each sequence and build attention mask
    padded = []
    masks = []
    for x in inputs:
        pad_len = max_len - len(x)
        padded.append(torch.cat([x, torch.zeros(pad_len, dtype=x.dtype)]))
        masks.append(torch.cat([torch.ones(len(x)), torch.zeros(pad_len)]))

    return {
        "input_ids": torch.stack(padded),      # (B, max_len)
        "attention_mask": torch.stack(masks),  # (B, max_len)
        "labels": torch.tensor(labels),        # (B,)
    }


loader = DataLoader(dataset, batch_size=32, collate_fn=collate_text)
```
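The same padding can also be sketched with torch.nn.utils.rnn.pad_sequence, which is an alternative to the manual loop rather than what the collate function above uses. The three token sequences are hypothetical.

```python
import torch
from torch.nn.utils.rnn import pad_sequence

# Hypothetical token-id sequences of lengths 3, 1, and 2.
seqs = [torch.tensor([5, 6, 7]), torch.tensor([8]), torch.tensor([9, 10])]
lengths = torch.tensor([len(s) for s in seqs])

# Pad to the longest sequence in the batch: shape (3, 3).
padded = pad_sequence(seqs, batch_first=True, padding_value=0)

# Mask is 1 where a position holds a real token, 0 where it is padding.
positions = torch.arange(padded.size(1))
mask = (positions[None, :] < lengths[:, None]).long()

print(padded.tolist())  # [[5, 6, 7], [8, 0, 0], [9, 10, 0]]
print(mask.tolist())    # [[1, 1, 1], [1, 0, 0], [1, 1, 0]]
```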
A dataset with two imbalanced classes (90% blue, 10% orange). Toggle between sequential, shuffled, and weighted sampling to see how batch composition changes. Below: variable-length sequences being padded to the same length.
Watch how different sampling strategies change batch composition for an imbalanced dataset. Bottom section shows how variable-length sequences get padded.
A single training step looks simple — forward, loss, backward, update. Four words. But the order matters. Call zero_grad at the wrong time and you erase the gradients you just computed. Forget model.eval() during validation and dropout corrupts your metrics. Skip torch.no_grad() and you waste memory storing computation graphs you'll never use.
Let's trace every operation, in order, with actual numbers. A tiny network. Real weights. Real gradients. No magic.
One training step has exactly 6 operations in a fixed order. Skip one or swap two and training silently breaks.
Let's trace each one in detail.
The input tensor flows through every layer of the model. Each layer applies its weights and produces an output. PyTorch records every operation in a computation graph — a directed acyclic graph of operations that connects inputs to the final output. This graph is what makes backward() possible.
Input shape: (B, features) — for a batch of 32 examples with 784 features each: (32, 784).
Output shape: (B, num_classes) — for 10-class classification: (32, 10).
Every intermediate tensor in the forward pass is stored in memory (because backward needs them). This is why larger batch sizes use more GPU memory — more intermediate tensors to store.
The loss function compares the model's output to the true labels and produces a single scalar. This scalar is the root of the computation graph. When we call backward() on it, gradients flow from this single number back through every operation to every parameter.
The loss tensor has a grad_fn attribute — a reference to the function that created it. This is PyTorch's autograd recording: "this scalar was created by a cross-entropy operation on these inputs."
loss.backward() walks the computation graph in reverse. At each operation, it applies the chain rule to compute the gradient of the loss with respect to that operation's inputs. When it reaches a leaf tensor (a model parameter), it stores the gradient in param.grad.
Critical detail: gradients are accumulated, not replaced. If param.grad already has a value from a previous step, the new gradient is added to it. This is why zero_grad() exists.
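The accumulation behavior is easy to see on a two-element tensor. In this sketch (values chosen only for illustration) the loss is sum(w * x), so the true gradient with respect to w is always x.

```python
import torch

w = torch.tensor([1.0, 2.0], requires_grad=True)
x = torch.tensor([3.0, 4.0])

(w * x).sum().backward()
first = w.grad.clone()   # tensor([3., 4.])

(w * x).sum().backward() # no reset in between: the new gradient is added
second = w.grad.clone()  # tensor([6., 8.])

w.grad = None            # manual equivalent of zeroing the gradient
(w * x).sum().backward()
third = w.grad.clone()   # tensor([3., 4.]) again

print(first.tolist(), second.tolist(), third.tolist())
```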
The optimizer reads every parameter's .grad attribute and updates .data according to its rule. For SGD: param.data -= lr * param.grad. For Adam: a more complex update involving first and second moment estimates.
After this call, the weights have changed. The model is slightly different from before.
optimizer.zero_grad() sets every parameter's .grad to zero (or None with set_to_none=True, which is slightly faster). This must happen before the next backward() call. Without it, gradients from consecutive steps accumulate — step 2's gradient gets added to step 1's, step 3's gets added to the sum, and so on.
You can call zero_grad() before forward() or after step(). Both work; what matters is that gradients are zero before the next backward(). The PyTorch convention places zero_grad() first in the loop body, before forward(). This reads more clearly: "start fresh, compute, update."
Learning rate schedulers adjust the learning rate over the course of training, typically starting high and decaying over time. scheduler.step() must be called after optimizer.step(). If you call it before, the learning rate changes before the current update is applied, which is subtly wrong (especially for warmup schedules).
Let's trace one step through the smallest possible network: a single linear layer with 2 inputs and 2 outputs. No hidden layers, no activation function. Just raw matrix multiplication.
Step 1 — Forward:
Step 2 — Loss (MSE):
Step 3 — Backward:
First, the gradient of the loss w.r.t. the output:
Now the gradient of the loss w.r.t. the weights. Since y = W·x, the gradient is the outer product:
And the gradient w.r.t. bias is just ∂L/∂y itself: [-1.1, 1.8].
Step 4 — SGD update (lr = 0.01):
Bias: b[0] = 0 + 0.011 = 0.011, b[1] = 0 - 0.018 = -0.018.
Step 5 — Zero grad: Set all .grad tensors to zero. Done.
Verify: Forward again with new weights: y[0] = 0.511 × 1 + (-0.278) × 2 + 0.011 = 0.511 - 0.556 + 0.011 = -0.034. Closer to target 1 than the old -0.1? Yes — the update moved y[0] in the right direction. Loss went from 2.225 to... compute it yourself!
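The trace can be reproduced in PyTorch. The values recoverable from the text are x = [1, 2], the first weight row [0.5, -0.3], zero bias, target[0] = 1, and lr = 0.01. The second weight row and target[1] below are assumptions, picked only so that the output gradient comes out as [-1.1, 1.8] as stated.

```python
import torch
import torch.nn as nn

layer = nn.Linear(2, 2)
with torch.no_grad():
    # Row 0 follows the worked numbers; row 1 is an assumed value.
    layer.weight.copy_(torch.tensor([[0.5, -0.3], [0.8, 0.2]]))
    layer.bias.zero_()

x = torch.tensor([1.0, 2.0])
target = torch.tensor([1.0, -0.6])  # target[1] is an assumed value
optimizer = torch.optim.SGD(layer.parameters(), lr=0.01)

optimizer.zero_grad()
y = layer(x)                              # forward: [-0.1, 1.2]
loss = nn.functional.mse_loss(y, target)  # mean of squared errors: 2.225
loss.backward()

grad_w = layer.weight.grad.clone()  # outer(y - target, x)
grad_b = layer.bias.grad.clone()    # y - target = [-1.1, 1.8]

optimizer.step()                    # w -= lr * grad

print(loss.item())
print(grad_w.tolist())              # approx [[-1.1, -2.2], [1.8, 3.6]]
print(layer.weight[0].tolist())     # approx [0.511, -0.278]
print(layer.bias.tolist())          # approx [0.011, -0.018]
```

Running the forward pass again on the updated layer gives the new loss for comparison with 2.225.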
Click through each of the 6 operations. At each stage, see the actual tensors: inputs, outputs, gradients, and updated weights. Blue = forward data flow, orange = backward gradient flow, green = weight update.
Step through a single training iteration on a tiny network. Watch data flow forward (blue), gradients flow backward (orange), and weights update (green).
```python
import torch
import torch.nn as nn

# Setup
model = nn.Linear(784, 10)                 # simple classifier
criterion = nn.CrossEntropyLoss()          # loss function
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=100)

# One epoch
model.train()  # enable dropout, batchnorm in training mode
for inputs, targets in train_loader:
    # 1. Zero gradients (start fresh)
    optimizer.zero_grad()

    # 2. Forward pass
    outputs = model(inputs)             # (B, 10)

    # 3. Compute loss
    loss = criterion(outputs, targets)  # scalar

    # 4. Backward pass
    loss.backward()                     # computes all gradients

    # 5. Update weights
    optimizer.step()                    # applies SGD update

    # 6. Adjust learning rate
    scheduler.step()                    # decays lr per schedule
```
During validation, you want to compute loss (to track progress) but never update weights. Two critical changes:
```python
model.eval()  # disable dropout, use running stats for batchnorm
with torch.no_grad():  # don't build computation graph
    for inputs, targets in val_loader:
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        # NO backward. NO step. Just measure.
```
model.eval() disables dropout (which randomly zeroes neurons — useful during training, destructive during evaluation) and switches batch normalization to use stored running statistics instead of batch statistics. torch.no_grad() tells PyTorch not to build the computation graph, which saves memory — there's no backward pass, so no graph is needed.
Forgetting model.eval() means dropout is still active during validation. Your accuracy metrics will be noisy and lower than the model's true performance. Forgetting torch.no_grad() wastes memory but doesn't affect correctness.
You've trained for one epoch. Time to check how well the model generalizes. You
run it on the validation set and get... different numbers every time? That's
because dropout is still randomly zeroing neurons. You didn't call
model.eval(). Or you called it but forgot torch.no_grad(),
and now every validation batch builds a computation graph you will never use, which can push a large model past your GPU memory.
Evaluation requires exactly three changes from training. Miss model.eval() and your metrics are silently corrupted. No error message. No crash. Just wrong numbers that look plausible enough to waste a week of your time. Miss torch.no_grad() and the numbers are still correct, but you pay for memory you don't need.
Switch 1: model.eval() — toggles layer behavior.
Specifically, it changes Dropout from randomly zeroing neurons to
passing all values through, and it changes BatchNorm from computing
statistics on the current batch to using stored running statistics. Every other
layer — Linear, Conv2d, ReLU, Attention — behaves identically in both modes.
Switch 2: torch.no_grad() — disables gradient tracking.
During training, PyTorch builds a computation graph that records every
operation so it can compute gradients via backprop. This graph uses enormous memory.
For a 7B-parameter model, the stored activations alone can run to tens of gigabytes of VRAM, depending on batch size and sequence length.
During evaluation, you never call .backward(), so this graph is pure waste.
Wrapping your eval loop in torch.no_grad() tells PyTorch to skip building it.
Switch 3: No optimizer step. This one sounds obvious, but it's the semantic contract: you never update weights on validation data. The purpose of evaluation is to measure generalization, not to train further.
model.eval() does NOT disable gradient
computation. It only changes the behavior of Dropout and BatchNorm. You STILL need
torch.no_grad() to disable gradients. And model.eval()
does NOT freeze the model — you could still call .backward() and update
weights in eval mode. It's purely about layer behavior, not gradient tracking.
These are two completely independent switches.
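A quick check of that independence, using a hypothetical 3-input linear layer: in eval mode the output is still attached to the autograd graph, and only torch.no_grad() detaches it.

```python
import torch
import torch.nn as nn

model = nn.Linear(3, 1)
x = torch.randn(2, 3)

model.eval()                          # changes layer behavior only
out_eval = model(x)
tracked_in_eval = out_eval.requires_grad        # True: graph is still built

with torch.no_grad():                 # disables gradient tracking
    out_no_grad = model(x)
tracked_in_no_grad = out_no_grad.requires_grad  # False: no graph

print(tracked_in_eval, tracked_in_no_grad)  # True False
```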
Let's trace through a concrete example. Your model has a Dropout layer with
p=0.5, meaning it randomly zeroes 50% of neurons during training.
Input tensor: [0.8, 1.2, 0.3, 0.9].
During training (model.train()): Dropout randomly picks
which neurons to kill. Suppose it kills positions 1 and 3. The surviving values get
scaled up by 1/(1-p) = 2 to compensate for the missing
neurons. Result: [1.6, 0.0, 0.6, 0.0]. Run it again and you get
a different mask — maybe [0.0, 2.4, 0.0, 1.8]. Every forward pass
produces different output.
During eval (model.eval()): Dropout becomes a no-op.
All values pass through unchanged: [0.8, 1.2, 0.3, 0.9]. Every forward
pass produces the same output. This is what you want for evaluation —
deterministic, reproducible results.
If you evaluate with model.train() still active, your "accuracy" is
noisy garbage — it fluctuates randomly depending on which neurons get zeroed. Run
the same validation set twice and you get different numbers.
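The dropout behavior described above can be reproduced with nn.Dropout on the same input tensor. The training-mode mask is random, so the sketch only states what each element can be, not a specific output.

```python
import torch
import torch.nn as nn

drop = nn.Dropout(p=0.5)
x = torch.tensor([0.8, 1.2, 0.3, 0.9])

drop.train()
out_train = drop(x)  # each element is either 0 or x[i] * 1 / (1 - p) = 2 * x[i]

drop.eval()
out_eval = drop(x)   # dropout is a no-op: identical to x

print(out_train)
print(out_eval)
```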
torch.no_grad() vs torch.inference_mode()
Both disable gradient tracking, but torch.inference_mode() is stricter.
Tensors created inside it are inference tensors: they cannot be used later in
computations that autograd records, and they cannot be modified in place outside
inference mode. For pure evaluation, inference_mode is faster
because the stricter constraints let PyTorch skip more bookkeeping. For evaluation
during training (where you might want gradients later), no_grad is safer.
| Context Manager | Disables Gradients | Outputs Usable in Autograd Later | Speed | Use When |
|---|---|---|---|---|
| torch.no_grad() | Yes | Yes | Fast | Eval during training |
| torch.inference_mode() | Yes | No | Fastest | Pure inference / deployment |
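A small sketch of the difference between the two context managers, using a hypothetical three-element tensor. Both produce outputs with requires_grad=False, but only inference_mode marks its outputs as inference tensors.

```python
import torch

x = torch.ones(3, requires_grad=True)

with torch.no_grad():
    a = x * 2

with torch.inference_mode():
    b = x * 2

print(a.requires_grad, a.is_inference())  # False False
print(b.requires_grad, b.is_inference())  # False True
```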
Watch data flow through a network in training mode versus eval mode. On the left,
dropout randomly grays out neurons on every pass — the output changes each time.
On the right, eval mode passes everything through deterministically. Below both
networks, the memory bar shows how no_grad slashes memory usage by
eliminating the computation graph.
Left: training mode with active dropout. Right: eval mode with all neurons active. Click "Forward Pass" to send data through both networks simultaneously. Watch how dropout produces different outputs each time in training mode, but eval mode is deterministic. The memory bar shows the cost of gradient tracking.
```python
import torch


def evaluate(model, val_loader, criterion, device):
    model.eval()  # Switch 1: deterministic layers
    total_loss = 0.0
    correct = 0
    total = 0

    with torch.no_grad():  # Switch 2: no computation graph
        for X, y in val_loader:
            X, y = X.to(device), y.to(device)
            out = model(X)  # Forward only
            loss = criterion(out, y)
            total_loss += loss.item() * X.size(0)
            preds = out.argmax(dim=1)
            correct += (preds == y).sum().item()
            total += X.size(0)
            # Switch 3: no optimizer.step()

    model.train()  # CRITICAL: switch back for next epoch!
    return total_loss / total, correct / total
```
Forgetting model.train() after evaluation means your next training epoch runs
without dropout. Your training loss looks great (no regularization = lower
training loss), but your validation loss gets worse because the model overfits.
The symptom: train loss drops, val loss rises, and you blame the learning rate
instead of a missing one-line call.
Consider a model with 100M parameters (each 4 bytes as FP32). The model weights use 400 MB. During a forward pass, the computation graph stores intermediate activations for backprop. For a typical transformer with 12 layers, activations use roughly 2-3x the model size (about 800 MB to 1.2 GB), though the exact figure grows with batch size and sequence length.
With torch.no_grad(): no graph is stored. Memory usage is just the
model weights (400 MB) plus the current batch's tensors. For a 7B model held in
16-bit precision (about 14 GB of weights), activations of a similar size would mean
roughly 28 GB with the graph versus 14 GB without: the difference
between fitting on one GPU or needing two.
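The weight-memory part of this arithmetic is just parameters times bytes per parameter. The helper below is a hypothetical name, uses decimal gigabytes, and counts weights only; activations, gradients, and optimizer state come on top.

```python
def weight_memory_gb(num_params, bytes_per_param):
    """Memory taken by the weights alone, in decimal GB."""
    return num_params * bytes_per_param / 1e9


small_fp32 = weight_memory_gb(100e6, 4)  # 100M params, FP32
large_fp32 = weight_memory_gb(7e9, 4)    # 7B params, FP32
large_16bit = weight_memory_gb(7e9, 2)   # 7B params, 16-bit

print(small_fp32, large_fp32, large_16bit)  # 0.4 28.0 14.0
```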
Your training loss isn't going down. There's no error message. The code runs fine. The bug is invisible: one misplaced line, and your model trains on garbage while reporting success. This chapter shows you the six most common silent bugs so you can spot them before they waste a week of GPU time.
These bugs are silent — they don't crash, don't raise exceptions, and often produce plausible-looking metrics. They're the reason experienced ML engineers check gradient norms, weight distributions, and learning rate schedules before they ever look at accuracy.
Recall that loss.backward() accumulates gradients into
each parameter's .grad attribute. It adds to whatever was already there.
If you don't call optimizer.zero_grad() before each backward pass, the
gradients from step 1 persist into step 2, step 3, step 4...
The effective gradient at step N becomes the sum of ALL previous gradients. By step 100, you're applying the sum of 100 gradients: up to 100x larger than intended, and dominated by directions that were computed for weights the model no longer has. Training becomes unstable: the loss oscillates wildly, then diverges.
Two failure modes: (A) Evaluating with model.train() — dropout randomly
zeroes neurons during validation, making your accuracy noisy and unreliable.
(B) Training with model.eval() — no dropout means less regularization,
and BatchNorm uses stale running statistics instead of fresh batch statistics.
Both silently degrade performance with no error message.
You compute model output and store it for logging: all_preds.append(output).
Seems innocent. But output is connected to the computation graph —
it holds references to every intermediate tensor from the forward pass. Python's
garbage collector can't free any of it because your list holds a reference.
Memory grows linearly with each step. After 1000 steps, you have 1000 computation
graphs in memory. Then: CUDA out of memory. The fix is one method call:
all_preds.append(output.detach()). The .detach() method
creates a new tensor that shares the same data but is disconnected from the graph.
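A minimal sketch of the safe logging pattern, assuming a hypothetical 4-to-2 linear model and random inputs. It also shows what the unsafe version would hold on to.

```python
import torch
import torch.nn as nn

model = nn.Linear(4, 2)
all_preds = []

for _ in range(3):
    output = model(torch.randn(8, 4))
    # Store a graph-free copy; the forward graph can be freed after each step.
    all_preds.append(output.detach())

leaky = model(torch.randn(8, 4))
print(leaky.grad_fn is not None)                  # True: still attached to its graph
print(all(p.grad_fn is None for p in all_preds))  # True: detached, safe to keep
```

For scalar metrics such as the loss, .item() gives a plain Python number and avoids the same problem.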
You construct your optimizer, then later freeze some layers:
```python
from torch.optim import Adam

# model is assumed to be defined elsewhere, with a .backbone submodule
optimizer = Adam(model.parameters(), lr=1e-3)

# ... later ...
for p in model.backbone.parameters():
    p.requires_grad = False  # Freeze backbone

# BUG: optimizer still has backbone params!
# It updates them with zero gradients (wasted compute)
# Or worse: if you add new parameters, optimizer doesn't know
```
The optimizer captured a snapshot of parameters at construction time. Changing
requires_grad after that doesn't remove them from the optimizer's
parameter groups. The right fix: construct the optimizer AFTER deciding which
parameters to train, and pass only the trainable ones.
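One way to apply that fix, sketched on a hypothetical two-layer model where the first layer plays the role of the backbone: freeze first, then build the optimizer from the parameters that still require gradients.

```python
import torch
import torch.nn as nn

# Hypothetical model: model[0] acts as the backbone, model[2] as the head.
model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))

# 1. Decide what is trainable.
for p in model[0].parameters():
    p.requires_grad = False

# 2. Only then construct the optimizer, passing trainable parameters only.
trainable = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.Adam(trainable, lr=1e-3)

num_optimized = sum(len(g["params"]) for g in optimizer.param_groups)
print(num_optimized)  # 2: the head's weight and bias
```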
Your validation set overlaps with your training set. Or your data preprocessing computes statistics (mean, std) across the entire dataset including validation. Or your data augmentation pipeline applies the same random crop to train and test images, so the model memorizes crop positions.
The symptom: validation accuracy is 95%. You deploy to production and get 70%. The model never learned to generalize — it learned to memorize. No code error. No warning. Just a model that looks great on paper and fails in the real world.
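A leak-free preprocessing order can be sketched as: split first, compute statistics on the training split only, then apply them to both splits. The data here is synthetic and the 800/200 split is an assumption for illustration.

```python
import torch

torch.manual_seed(0)
data = torch.randn(1000, 5) * 3 + 7  # hypothetical raw features

# 1. Split BEFORE computing any statistics.
perm = torch.randperm(1000)
train_idx, val_idx = perm[:800], perm[800:]

# 2. Statistics come from the training split only.
mean = data[train_idx].mean(dim=0)
std = data[train_idx].std(dim=0)

# 3. Apply the same training statistics to both splits.
train = (data[train_idx] - mean) / std
val = (data[val_idx] - mean) / std

overlap = set(train_idx.tolist()) & set(val_idx.tolist())
print(len(overlap))  # 0: no example appears in both splits
```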
Model is on GPU. Data arrives from the DataLoader on CPU. In most cases, PyTorch
raises an error when CPU and GPU tensors interact, including when you forget
.to(device) on the targets. But if you accidentally move the computation to CPU,
for example by calling .cpu() on the outputs so they match CPU targets,
training "works" but can run 10x slower because the loss computation
and its backward pass execute on CPU instead of GPU. GPU utilization sits at
30% and you blame the DataLoader.
Gradient accumulation is the deliberate version of the missing zero_grad bug: you skip zero_grad for K steps, then call it. The bug becomes
a feature. The critical difference: intentional accumulation divides the loss by K
(loss = loss / K) before .backward() to maintain correct
gradient magnitude. Accidental accumulation doesn't scale — the gradients just
explode.
This simulation trains a small network with toggleable bugs. Enable each bug and watch the training loss curve change. Normal training converges smoothly. Each bug has a unique visual signature — oscillation, divergence, memory explosion, or suspiciously good metrics. Learn these signatures and you'll diagnose training failures at a glance.
Toggle each bug on/off and watch the training loss curve. Enable one bug at a time to learn its signature. Then try combining bugs to see how they interact.
| Bug | Loss Curve Signature | Other Symptom | Fix |
|---|---|---|---|
| No zero_grad | Oscillates, then explodes | Gradient norm grows linearly | optimizer.zero_grad() before backward |
| Wrong eval mode | Val accuracy is noisy | Same val set gives different results | model.eval() + model.train() |
| No detach | Normal loss, then OOM crash | Memory grows linearly per step | .detach() or .item() |
| Stale optimizer | Frozen layers still have gradients | Extra compute, wrong updates | Construct optimizer after freezing |
| Data leakage | Perfect val, terrible production | Val and train metrics suspiciously close | Split data before any preprocessing |
| Wrong device | Normal convergence, 10x slower | GPU util is low | .to(device) on all tensors |
Time to put it all together. This is the payoff — a complete training loop running live in your browser. You configure the dataset, batch size, learning rate, optimizer, and scheduler. You watch training progress step by step. And you can inject every bug from the previous chapter to see its effect in real time.
This isn't a toy. The simulator runs actual forward passes, computes real losses, tracks real gradients, and updates real weights. The dataset is synthetic (a 2D classification spiral), but every component of the loop is identical to what runs on GPU clusters — just at a scale your browser can handle.
Experiment 1: Learning rate sweep. Start at LR = 0.01 with Adam. Train 5 epochs — loss drops smoothly. Now reset and set LR = 1.0. The loss immediately jumps to NaN. The network's weights overflow because the updates are too large. The critical insight: Adam's adaptive learning rate helps, but it can't save you from a globally insane LR.
Experiment 2: SGD vs Adam. Reset. Set optimizer to SGD, LR = 0.01. Train 10 epochs. The loss drops slowly, especially on the spiral arms. Now switch to Adam at the same LR. It converges 3-5x faster because Adam adapts the learning rate per-parameter — dimensions with small gradients get bigger steps.
Experiment 3: The zero_grad bug. Reset. Enable "Skip zero_grad." Train. The loss oscillates for the first few epochs, then explodes. Compare the gradient norm readout between bugged and normal runs: in the bugged run it grows linearly with steps, confirming gradients are accumulating.
Experiment 4: Memory leak. Reset. Enable "No no_grad." Train several epochs. Watch the memory bar climb with every evaluation step. In a real training run, this would hit OOM within minutes. Disable it and watch memory stay flat.
Experiment 5: Cosine scheduling. Reset. Set scheduler to "Cosine." The LR starts at your configured value and decays toward zero over training. Compare the final loss: cosine typically reaches a lower final loss than constant LR because the shrinking step size lets the optimizer settle closer to the bottom of a minimum instead of bouncing around it.
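The cosine decay in Experiment 5 can be written down directly. This is a minimal sketch of the standard cosine annealing formula (the same closed form PyTorch's CosineAnnealingLR documents, with `min_lr` playing the role of `eta_min`); the function name `cosine_lr` and the 10-step horizon are assumptions for illustration.

```python
import math

def cosine_lr(step, total_steps, base_lr, min_lr=0.0):
    """Learning rate after `step` of `total_steps` under cosine annealing."""
    progress = min(step, total_steps) / total_steps       # 0.0 -> 1.0
    return min_lr + 0.5 * (base_lr - min_lr) * (1.0 + math.cos(math.pi * progress))

# LR at the start of each of 10 epochs, plus the final value.
schedule = [cosine_lr(t, 10, 0.01) for t in range(11)]
for t, lr in enumerate(schedule):
    print(f"epoch {t:2d}  lr = {lr:.5f}")
```

The schedule starts at the configured value, passes through half of it at the midpoint, and reaches `min_lr` at the end, so the largest steps happen early and the smallest ones late.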
The 6-line training loop from Chapter 3 is a toy. A real training script has checkpointing (save model every N steps), early stopping (stop when validation loss plateaus), mixed precision (BF16 forward, FP32 weight update), gradient accumulation, distributed training hooks, experiment logging (WandB or TensorBoard), reproducibility (seed everything), and progress bars. This chapter shows you the loop that actually runs on GPU clusters.
We'll build from the minimal loop to the production loop, adding one component at a time. Each component solves a real problem — nothing is decoration. By the end, you'll understand every line of a production training script.
Training a large model takes hours to days. If your machine crashes at hour 23 of a 24-hour run, you've lost everything — unless you saved checkpoints. A checkpoint captures the complete training state so you can resume exactly where you left off.
What to save:
model.state_dict() is NOT enough for resuming training.
You also need the optimizer state dict, because Adam's moment estimates (the running
averages of gradients and squared gradients) take hundreds of steps to rebuild, and
the scheduler state dict, which records where you are in the learning rate schedule.
If you resume without them, the optimizer resets to zero moments, the learning rate
schedule restarts from its initial position, and training quality degrades. "Resuming"
without full state is actually starting a new run from pretrained weights.
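```python
import random

import numpy as np
import torch

def save_checkpoint(model, optimizer, scheduler, epoch, step, best_val, path):
    torch.save({
        'model': model.state_dict(),
        'optimizer': optimizer.state_dict(),
        'scheduler': scheduler.state_dict(),
        'epoch': epoch,
        'step': step,
        'best_val': best_val,
        'rng_torch': torch.get_rng_state(),
        'rng_cuda': torch.cuda.get_rng_state_all(),
        'rng_numpy': np.random.get_state(),
        'rng_python': random.getstate(),
    }, path)

def load_checkpoint(path, model, optimizer, scheduler):
    # weights_only=False unpickles arbitrary objects (needed for the NumPy
    # and Python RNG states), so only load checkpoints you trust.
    ckpt = torch.load(path, weights_only=False)
    model.load_state_dict(ckpt['model'])
    optimizer.load_state_dict(ckpt['optimizer'])
    scheduler.load_state_dict(ckpt['scheduler'])
    # Restore every RNG that was saved, not just the CPU torch generator
    torch.set_rng_state(ckpt['rng_torch'])
    torch.cuda.set_rng_state_all(ckpt['rng_cuda'])
    np.random.set_state(ckpt['rng_numpy'])
    random.setstate(ckpt['rng_python'])
    return ckpt['epoch'], ckpt['step'], ckpt['best_val']
```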
You don't know in advance how many epochs to train. Too few and the model underfits. Too many and it overfits — the training loss keeps dropping but the validation loss starts rising. Early stopping tracks validation loss and stops training when it hasn't improved for P epochs (the patience).
```python
class EarlyStopping:
    def __init__(self, patience=5, min_delta=1e-4):
        self.patience = patience
        self.min_delta = min_delta
        self.best = float('inf')
        self.counter = 0

    def __call__(self, val_loss):
        if val_loss < self.best - self.min_delta:
            self.best = val_loss
            self.counter = 0
            return False  # Keep training
        self.counter += 1
        return self.counter >= self.patience  # Stop?
```
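To see the patience counter in action, here is a short trace. The class is repeated so the snippet runs on its own; the validation losses and `patience=3` are made-up values chosen to show the counter resetting and then expiring.

```python
class EarlyStopping:
    """Same logic as the class above, repeated so this snippet is standalone."""
    def __init__(self, patience=5, min_delta=1e-4):
        self.patience = patience
        self.min_delta = min_delta
        self.best = float('inf')
        self.counter = 0

    def __call__(self, val_loss):
        if val_loss < self.best - self.min_delta:
            self.best = val_loss
            self.counter = 0
            return False
        self.counter += 1
        return self.counter >= self.patience

# Hypothetical per-epoch validation losses: improves, then plateaus.
val_losses = [1.0, 0.8, 0.7, 0.71, 0.70, 0.72, 0.60]

stopper = EarlyStopping(patience=3)
stopped_at = None
for epoch, val_loss in enumerate(val_losses):
    if stopper(val_loss):
        stopped_at = epoch
        break
    print(f"epoch {epoch}: val_loss={val_loss:.2f} counter={stopper.counter}")

print(f"stopped at epoch {stopped_at}, best val_loss {stopper.best}")
```

Epoch 4 matches the best loss of 0.70 but does not beat it by `min_delta`, so it counts as a non-improving epoch. The run stops at epoch 5 and never sees the 0.60 that would have followed, which is the tradeoff patience controls.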
Your GPU has 8 GB. Your batch size of 64 runs out of memory. You could use batch size 8 — but that gives noisier gradients and often worse convergence. Gradient accumulation lets you simulate batch size 64 using 8 forward passes of batch size 8 each.
The trick: skip optimizer.step() and zero_grad() for K-1
forward passes. Gradients accumulate naturally (because .backward()
adds). On the K-th pass, step and zero. The only subtlety: divide the loss by K
before calling .backward() so the total gradient magnitude
matches what you'd get from a single large batch.
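```python
accum_steps = 8  # micro_batch=8, effective_batch=64

for i, (X, y) in enumerate(train_loader):
    loss = criterion(model(X), y) / accum_steps  # Scale!
    loss.backward()                              # Accumulate
    if (i + 1) % accum_steps == 0:
        optimizer.step()                         # Update
        optimizer.zero_grad()                    # Clear
# Caveat: if len(train_loader) is not a multiple of accum_steps, the
# leftover micro-batches' gradients carry into the next epoch.
```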
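The claim that dividing by K reproduces the large-batch gradient can be checked with plain arithmetic. The sketch below assumes a one-parameter model `y_hat = w * x` with mean squared error and equal-sized micro-batches; the data and the `mean_grad` helper are hypothetical and exist only for this check.

```python
import math
import random

random.seed(0)
xs = [random.uniform(-1.0, 1.0) for _ in range(64)]
ys = [3.0 * x + random.gauss(0.0, 0.1) for x in xs]
w = 0.5  # current weight of the toy model y_hat = w * x

def mean_grad(w, xs, ys):
    """d/dw of the mean squared error over one batch."""
    return sum(2.0 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)

# Gradient from one batch of 64 examples.
full_batch_grad = mean_grad(w, xs, ys)

# Gradient accumulated over 8 micro-batches of 8 examples each.
accum_steps, micro = 8, 8
scaled, unscaled = 0.0, 0.0
for k in range(accum_steps):
    sl = slice(k * micro, (k + 1) * micro)
    g = mean_grad(w, xs[sl], ys[sl])
    scaled += g / accum_steps   # intentional accumulation: loss / K
    unscaled += g               # accidental accumulation: no scaling

print(f"full batch : {full_batch_grad:.6f}")
print(f"scaled     : {scaled:.6f}")
print(f"unscaled   : {unscaled:.6f}")
```

The scaled sum matches the full-batch gradient up to floating-point rounding, while the unscaled sum is K times larger. The equivalence relies on the loss being a mean over each micro-batch and on the micro-batches having equal size.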
Modern GPUs have special hardware (Tensor Cores) that compute 16-bit operations
2-4x faster than 32-bit. Mixed precision training runs the forward
pass in BFloat16 (faster, less memory) but keeps the weight update in FP32
(more precise). PyTorch's autocast handles the switching automatically.
```python
# GradScaler guards against FP16 gradient underflow; with bfloat16 it is
# usually unnecessary (BF16 has FP32's exponent range) but harmless.
scaler = torch.amp.GradScaler()

for X, y in train_loader:
    with torch.amp.autocast(device_type='cuda', dtype=torch.bfloat16):
        loss = criterion(model(X), y)   # Forward in BF16
    scaler.scale(loss).backward()       # Backward with scaling
    scaler.step(optimizer)              # Unscale + step in FP32
    scaler.update()
    optimizer.zero_grad()
```
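One way to see why the weight update stays in FP32 is to simulate BF16 rounding by hand. The sketch below is a standard-library approximation, not PyTorch's implementation: `to_bf16` keeps the top 16 bits of the float32 pattern with round-to-nearest-even, and the weight and update values are hypothetical.

```python
import struct

def to_bf16(x: float) -> float:
    """Round to the nearest bfloat16 value (simulation; finite inputs only)."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]   # float32 bit pattern
    bias = 0x7FFF + ((bits >> 16) & 1)                    # ties round to even
    bits = ((bits + bias) & 0xFFFFFFFF) & 0xFFFF0000      # keep the top 16 bits
    return struct.unpack("<f", struct.pack("<I", bits))[0]

def to_fp32(x: float) -> float:
    """Round a Python float (float64) to float32 precision."""
    return struct.unpack("<f", struct.pack("<f", x))[0]

weight, update = 1.0, 1e-3   # hypothetical weight and (lr * grad)

bf16_result = to_bf16(weight + update)   # update is below BF16 resolution near 1.0
fp32_result = to_fp32(weight + update)   # update survives in FP32

print(f"BF16 weight after update: {bf16_result}")
print(f"FP32 weight after update: {fp32_result:.6f}")
```

Near 1.0, adjacent BF16 values are 2^-7 (about 0.0078) apart, so a 0.001 update rounds away entirely and the weight never moves. Keeping the master weights in FP32 lets small updates accumulate.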
Seed everything. PyTorch, NumPy, Python's random module, CUDA. Without this, re-running the same script gives different results because weight initialization, data shuffling, and dropout masks all depend on random number generators.
```python
import random

import numpy as np
import torch

def seed_everything(seed=42):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
```
cudnn.deterministic = True restricts cuDNN to deterministic
algorithms, and cudnn.benchmark = False disables cuDNN's auto-tuner, which picks
the fastest convolution algorithm for your input sizes.
Deterministic mode can be 10-30% slower. Use it for debugging and final results;
disable it for development speed.
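The effect of a seed on data order can be illustrated without any framework. This sketch uses Python's `random.Random` as a stand-in for the generator that drives shuffling; the `shuffled_order` helper is hypothetical.

```python
import random

def shuffled_order(seed, n=10):
    """Return the order in which n examples would be visited for a given seed."""
    rng = random.Random(seed)      # independent generator, seeded explicitly
    order = list(range(n))
    rng.shuffle(order)
    return order

run_a = shuffled_order(seed=42)
run_b = shuffled_order(seed=42)    # same seed: identical order
run_c = shuffled_order(seed=7)     # different seed: a different generator state

print("seed 42, run 1:", run_a)
print("seed 42, run 2:", run_b)
print("seed 7        :", run_c)
```

The same principle applies to weight initialization and dropout masks: each draws from a generator whose output is fully determined by its seed.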
This interactive flowchart shows the complete production training loop. Toggle components on and off to see how each one integrates into the loop. Hover over any box to see the actual code for that component.
```python
# Assumes train_loader, val_loader, and criterion are defined at module level.
def train(config):
    seed_everything(config.seed)                      # Reproducibility
    model = build_model(config).to(config.device)
    optimizer = Adam(model.parameters(), lr=config.lr)
    scheduler = CosineAnnealingLR(optimizer, T_max=config.epochs)
    scaler = torch.amp.GradScaler()                   # Mixed precision
    stopper = EarlyStopping(patience=config.patience)
    start_epoch, best_val = 0, float('inf')

    # Resume from checkpoint if exists
    if config.resume and Path(config.ckpt_path).exists():
        saved_epoch, _, best_val = load_checkpoint(
            config.ckpt_path, model, optimizer, scheduler)
        start_epoch = saved_epoch + 1                 # Saved epoch already finished

    for epoch in range(start_epoch, config.epochs):
        model.train()
        for i, (X, y) in enumerate(train_loader):
            X, y = X.to(config.device), y.to(config.device)
            with torch.amp.autocast('cuda', torch.bfloat16):
                loss = criterion(model(X), y) / config.accum_steps
            scaler.scale(loss).backward()

            if (i + 1) % config.accum_steps == 0:
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(       # Clip gradients
                    model.parameters(), config.max_grad_norm)
                scaler.step(optimizer)
                scaler.update()
                optimizer.zero_grad()

        scheduler.step()                              # LR schedule (per epoch)

        # Evaluate
        val_loss, val_acc = evaluate(model, val_loader, criterion, config.device)

        # Log metrics (undo the accumulation scaling on the last micro-batch loss)
        wandb.log({'train_loss': loss.item() * config.accum_steps,
                   'val_loss': val_loss, 'val_acc': val_acc,
                   'lr': scheduler.get_last_lr()[0]})

        # Checkpoint
        if val_loss < best_val:
            best_val = val_loss
            save_checkpoint(model, optimizer, scheduler, epoch, i,
                            best_val, config.ckpt_path)

        # Early stopping
        if stopper(val_loss):
            print(f"Early stop at epoch {epoch}")
            break
```
That's 45 lines. Every line solves a real problem. The minimal loop was 6 lines. The difference — seeding, mixed precision, gradient accumulation, gradient clipping, checkpointing, early stopping, logging — is what separates a research prototype from a production training script.
| Component | Lines Added | Memory Impact | Speed Impact | What It Prevents |
|---|---|---|---|---|
| Checkpointing | ~10 | +Disk (model_size x2) | ~1% (I/O) | Lost training on crash |
| Early Stopping | ~8 | None | Saves hours (stops early) | Overfitting, wasted compute |
| Grad Accumulation | ~4 | -60% (smaller micro-batch) | ~Same (same total FLOPs) | OOM on large effective batches |
| Mixed Precision | ~5 | -40% (BF16 activations) | +50-100% faster | Slow training, memory limits |
| Seeding | ~6 | None | -10-30% (deterministic) | Irreproducible results |
| Logging (WandB) | ~3 | None | Negligible | Flying blind |
| Gradient Clipping | ~2 | None | Negligible | Gradient explosion |
You've learned every component of a training loop. Now prove it. The Arena gives you a dataset, a model, and a goal — configure the training loop to hit the target accuracy in the fewest steps. This isn't a quiz with one right answer. It's an optimization problem with tradeoffs: larger batches converge smoother but take more memory. Adam converges faster but might overfit. Cosine scheduling helps at the end but wastes early capacity.
Each configuration choice affects convergence speed, memory usage, and final accuracy. The scoreboard tracks your best run. Can you beat the baseline?
Everything from the entire lesson on one page. Print this. Pin it next to your monitor. Reference it every time you write a training loop until the patterns become muscle memory.
| Term | Definition | Formula |
|---|---|---|
| Dataset size (N) | Total training examples | — |
| Batch size (B) | Examples per forward pass | — |
| Iteration / Step | One forward + backward + update | — |
| Epoch | One full pass through the dataset | ceil(N / B) iterations |
| Total steps | All gradient updates in a run | epochs × ceil(N / B) |
| Effective batch | With gradient accumulation (K steps) | B × K |
| Last batch | Partial batch at epoch end | N mod B examples (no partial batch if B divides N) |
| drop_last | Skip the partial final batch | floor(N / B) batches instead of ceil |
| Micro-batch | Actual batch per forward pass (with accum) | B (the physical batch) |
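The formulas in the table can be checked against the running example of 10,000 examples and batch size 32. The helper below is a small sketch; the function name `training_arithmetic` and its dictionary keys are assumptions for illustration.

```python
import math

def training_arithmetic(n, b, epochs, k=1, drop_last=False):
    """Apply the table's formulas for dataset size n, batch size b, accumulation k."""
    batches = n // b if drop_last else math.ceil(n / b)
    remainder = n % b
    # With drop_last or an even split, every batch is full size.
    last_batch = b if (drop_last or remainder == 0) else remainder
    return {
        "iterations_per_epoch": batches,
        "total_iterations": epochs * batches,
        "last_batch_size": last_batch,
        "effective_batch": b * k,
    }

stats = training_arithmetic(10_000, 32, epochs=3)
stats_drop = training_arithmetic(10_000, 32, epochs=3, drop_last=True)
stats_accum = training_arithmetic(10_000, 8, epochs=3, k=8)

print(stats)        # 313 iterations per epoch, 939 total, last batch of 16
print(stats_drop)   # 312 iterations per epoch, 936 total, all batches full
print(stats_accum)  # micro-batch 8 with k=8 gives an effective batch of 64
```

With `drop_last=True` the 16 leftover examples are skipped each epoch, which is why the per-epoch count falls from 313 to 312.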
```python
model.eval()                      # Disable dropout, use BN running stats
with torch.no_grad():             # Don't build computation graph
    for X, y in val_loader:
        out = model(X)
        loss = criterion(out, y)
        # Measure only: no backward, no step
model.train()                     # ALWAYS switch back!
```
```python
model.train()
for epoch in range(num_epochs):
    for i, (X, y) in enumerate(train_loader):
        X, y = X.to(device), y.to(device)
        optimizer.zero_grad()          # 1. Clear
        out = model(X)                 # 2. Forward
        loss = criterion(out, y)       # 3. Loss
        loss.backward()                # 4. Backward
        optimizer.step()               # 5. Update
        scheduler.step()               # 6. Schedule

    # Evaluate at end of each epoch
    val_loss, val_acc = evaluate(model, val_loader, criterion, device)
```
| Bug | Symptom | Fix |
|---|---|---|
| Missing zero_grad() | Gradients accumulate, loss oscillates then explodes | optimizer.zero_grad() before each backward |
| Missing model.eval() | Noisy validation metrics (dropout still active) | model.eval() before val, model.train() after |
| Missing no_grad() | Memory grows linearly during eval → OOM | with torch.no_grad(): wrapping eval loop |
| Missing .detach() | Storing outputs leaks computation graphs → OOM | output.detach() or loss.item() |
| Stale optimizer | Frozen layers still consuming compute | Create optimizer AFTER freezing layers |
| Data leakage | 95% val accuracy, 70% production accuracy | Split data before any preprocessing |
| Wrong device | Normal loss but 10x slower, low GPU util | .to(device) on ALL tensors |
| Argument | Default | When to Change |
|---|---|---|
| batch_size | 1 | Always; set to the largest power of 2 that fits in memory |
| shuffle | False | True for training, False for eval |
| num_workers | 0 | 4-8 for disk I/O heavy data (images, audio) |
| pin_memory | False | True when using GPU (faster CPU→GPU transfer) |
| drop_last | False | True when using BatchNorm with small last batch |
| collate_fn | default_collate | Custom when sequences have variable length (padding) |
| sampler | None (sequential, or random if shuffle=True) | WeightedRandomSampler for imbalanced classes |
| persistent_workers | False | True with num_workers > 0 to avoid respawning |