Forward, loss, backward, step — the four-beat heartbeat of every neural network ever trained.
Keras has model.fit(). scikit-learn has .fit(). You call one function and the model trains. So why does PyTorch make you write a loop by hand?
Because control IS the feature. When training goes wrong — and it will — you need to inspect every gradient, every batch, every parameter update. You need to add custom logging, gradient clipping, mixed precision, curriculum learning. A black box can't give you that.
Here's the proof. Watch this training run. The loss starts decreasing, then suddenly explodes. In Keras, you'd stare at a progress bar and wonder what happened. In PyTorch, you can pause at step 47, inspect the gradients, and find the one layer whose weights grew too large.
Watch the loss curve. It starts well, then something breaks. With loop-level access, you can see exactly which step went wrong.
The four-beat cycle that every training loop follows: forward pass, loss, backward pass, optimizer step.
The forward pass is the simplest step: you feed data into your model and get a prediction out. In PyTorch, it looks like this:
```python
output = model(x)  # That's it. One line.
```

But what's happening under the hood? When you call model(x), Python invokes the module's __call__ method, which wraps your forward() method: it runs any registered forward pre-hooks, calls forward(), and then runs any registered forward hooks (useful for debugging and logging).
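To make the path through __call__ concrete, here is a minimal sketch using register_forward_hook. The Doubler module, the log_shapes function, and the calls list are hypothetical names invented for this illustration.

```python
import torch
import torch.nn as nn

class Doubler(nn.Module):
    """Hypothetical toy module: multiplies its input by 2."""
    def forward(self, x):
        return x * 2

calls = []  # records what the hook observed

def log_shapes(module, inputs, output):
    # A forward hook receives the module, a tuple of positional inputs, and the output.
    calls.append((type(module).__name__, tuple(inputs[0].shape), tuple(output.shape)))

model = Doubler()
handle = model.register_forward_hook(log_shapes)

out = model(torch.ones(4, 3))             # goes through __call__, so the hook fires
direct = model.forward(torch.ones(4, 3))  # bypasses __call__, so the hook does not fire

handle.remove()  # detach the hook when it is no longer needed
print(calls)
```

Only one entry is recorded even though forward() ran twice, which is why calling model(x) rather than model.forward(x) is the usual convention.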
Let's trace data through a concrete network. Suppose we're classifying handwritten digits (MNIST). Our input is a flattened 28x28 image — a tensor of shape [batch_size, 784]. Our model is a 3-layer MLP:
```python
import torch.nn as nn

class MLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(784, 128),  # 784 inputs → 128 hidden
            nn.ReLU(),            # activation
            nn.Linear(128, 10),   # 128 hidden → 10 classes
        )

    def forward(self, x):
        return self.layers(x)
```
Here's what happens to the shapes at every step:
Each layer transforms the shape as data flows through the network. The batch size only changes the first dimension; the feature dimension goes 784 → 128 → 10 regardless.
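The shape changes can also be traced directly in code. The sketch below rebuilds the same three layers as a bare nn.Sequential and records the tensor shape after each one; the batch size of 32 is an arbitrary choice.

```python
import torch
import torch.nn as nn

layers = nn.Sequential(
    nn.Linear(784, 128),
    nn.ReLU(),
    nn.Linear(128, 10),
)

batch_size = 32  # arbitrary; only the first dimension depends on it
x = torch.randn(batch_size, 784)

shapes = [tuple(x.shape)]
print(f"{'input':<8} {tuple(x.shape)}")
for layer in layers:
    x = layer(x)  # apply one layer at a time
    shapes.append(tuple(x.shape))
    print(f"{type(layer).__name__:<8} {tuple(x.shape)}")
```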
A key detail: during the forward pass, PyTorch is secretly building a computation graph — recording every operation so it can later compute gradients. This graph is what makes loss.backward() possible.
```python
import torch

# Every tensor operation is recorded
x = torch.randn(32, 784)  # input
out = model(x)            # forward pass builds the graph
print(out.shape)          # torch.Size([32, 10])
print(out.grad_fn)        # <AddmmBackward0 object at 0x...>, proof that a graph exists
```
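One way to see that the graph is a side effect of the forward pass is to run the same computation with recording switched off. This sketch uses a single nn.Linear as a stand-in for the MLP (an assumption made to keep the example short).

```python
import torch
import torch.nn as nn

model = nn.Linear(784, 10)  # stand-in for the MLP
x = torch.randn(32, 784)

out = model(x)
tracked = out.grad_fn is not None  # a graph node was recorded

with torch.no_grad():              # recording disabled inside this block
    out_ng = model(x)
untracked = out_ng.grad_fn is None and not out_ng.requires_grad

print(tracked, untracked)
```

The two outputs hold the same values; only the bookkeeping differs, which is why validation code usually runs under torch.no_grad().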
The forward pass gives us predictions. The loss function tells us how wrong those predictions are — a single number that we want to minimize. Different problems need different loss functions.
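For classification, we use Cross-Entropy Loss. Under the hood, it's two steps: softmax turns the logits into probabilities, then negative log-likelihood (NLL) takes -log of the probability assigned to the true class.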
Let's work through a concrete example. Suppose our model outputs logits [2.0, 1.0, 0.1] for 3 classes, and the true class is 0.
```python
import math

import torch
import torch.nn as nn

# Our logits and target
logits = torch.tensor([[2.0, 1.0, 0.1]])  # shape [1, 3]
target = torch.tensor([0])                # class 0

# PyTorch combines log-softmax + NLL in one function
loss_fn = nn.CrossEntropyLoss()
loss = loss_fn(logits, target)
print(loss)  # tensor(0.4170)

# Manual computation matches:
probs = torch.softmax(logits, dim=1)         # [0.659, 0.242, 0.099]
manual_loss = -math.log(probs[0, 0].item())  # 0.417
```
For regression (predicting continuous values), we use Mean Squared Error: the average of squared differences between predictions and targets.
```python
import torch
import torch.nn as nn

prediction = torch.tensor([2.5, 3.1, 4.0])
target = torch.tensor([3.0, 3.0, 4.0])
loss = nn.MSELoss()(prediction, target)
# = ((2.5-3)^2 + (3.1-3)^2 + (4-4)^2) / 3
# = (0.25 + 0.01 + 0) / 3 = 0.0867
```

For cross-entropy, the loss rises steeply as the model's predicted probability for the correct class falls: confident wrong answers are penalized hardest.
We have a loss — a single number measuring how wrong we are. Now we need to answer: which weights caused the most error, and in which direction should we nudge them?
The answer is gradients. The gradient of the loss with respect to a weight tells you: "if you increase this weight by a tiny amount, the loss changes by this much." In PyTorch, one line computes ALL gradients simultaneously:
```python
loss.backward()  # Fills .grad for every parameter in the model
```
After this call, every parameter's .grad attribute contains ∂loss/∂parameter. Let's verify with a tiny network:
```python
import torch

# Tiny network: one weight, one bias
w = torch.tensor([2.0], requires_grad=True)
b = torch.tensor([1.0], requires_grad=True)

# Forward: y = w*x + b
x = torch.tensor([3.0])
y = w * x + b  # y = 2*3 + 1 = 7

# Loss: squared error with target=5
target = torch.tensor([5.0])
loss = (y - target)**2  # (7-5)^2 = 4

# Backward
loss.backward()
print(w.grad)  # tensor([12.])  = 2*(y - target)*x
print(b.grad)  # tensor([4.])   = 2*(y - target)
```
Autograd is PyTorch's automatic differentiation engine. During the forward pass, it records every operation in a directed acyclic graph (DAG). During .backward(), it traverses this graph in reverse, applying the chain rule at each node.
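A quick sanity check on autograd is to compare its result against a numerical estimate of the "tiny nudge" definition of a gradient. The sketch below does this for the same one-weight network; toy_loss is a hypothetical helper, and the step size eps is an arbitrary choice.

```python
import torch

def toy_loss(w, b, x=3.0, target=5.0):
    """Squared error of y = w*x + b against the target."""
    return (w * x + b - target) ** 2

# float64 keeps the finite-difference estimate accurate
w = torch.tensor(2.0, dtype=torch.float64, requires_grad=True)
b = torch.tensor(1.0, dtype=torch.float64, requires_grad=True)
toy_loss(w, b).backward()  # autograd: chain rule through the recorded graph

eps = 1e-6
with torch.no_grad():
    # Central differences: nudge one input up and down, measure the change in loss
    fd_w = (toy_loss(w + eps, b) - toy_loss(w - eps, b)) / (2 * eps)
    fd_b = (toy_loss(w, b + eps) - toy_loss(w, b - eps)) / (2 * eps)

print(w.grad.item(), fd_w.item())  # both close to 12
print(b.grad.item(), fd_b.item())  # both close to 4
```

Finite differences need two extra forward passes per parameter, which is why autograd's single reverse traversal is the practical choice for real models.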
Changing w or x changes the gradients. The sign of ∂loss/∂w tells us which way the loss moves as w increases, so stepping against the gradient reduces the loss.

.backward() accumulates gradients: each call adds to whatever is already stored in .grad rather than overwriting it. Two backward passes without clearing in between leave .grad holding the sum of both gradients. This is why you must call optimizer.zero_grad() before each backward pass, as we'll see in Chapter 5.

We have gradients. Now we need to actually update the weights. The simplest rule: move each weight in the opposite direction of its gradient, scaled by a learning rate.

```python
lr = 0.01
with torch.no_grad():  # the update itself must not be tracked by autograd
    for param in model.parameters():
        param -= lr * param.grad  # That's SGD. Really.
```
That's Stochastic Gradient Descent (SGD). The weight update rule is:

$$w \leftarrow w - \eta \, \nabla L$$

Where η is the learning rate and ∇L is the gradient of the loss with respect to w.
In practice, you use PyTorch's optimizer:
```python
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

# ... after computing gradients:
optimizer.step()  # applies the update rule to all parameters
```
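To check that the hand-written rule and the optimizer really do the same thing, the sketch below applies both to two identical copies of a small layer and compares the resulting weights. The layer sizes, random data, and seed are arbitrary choices for illustration.

```python
import copy

import torch
import torch.nn as nn

torch.manual_seed(0)
model_a = nn.Linear(4, 2)
model_b = copy.deepcopy(model_a)  # identical starting weights
x, y = torch.randn(8, 4), torch.randn(8, 2)
lr = 0.01

# A: hand-written update, w <- w - lr * grad
nn.functional.mse_loss(model_a(x), y).backward()
with torch.no_grad():
    for p in model_a.parameters():
        p -= lr * p.grad

# B: the same step through torch.optim.SGD (no momentum, no weight decay)
optimizer = torch.optim.SGD(model_b.parameters(), lr=lr)
nn.functional.mse_loss(model_b(x), y).backward()
optimizer.step()

same = all(
    torch.allclose(pa, pb)
    for pa, pb in zip(model_a.parameters(), model_b.parameters())
)
print(same)
```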
Adam improves on SGD with two ideas: momentum (a running average of past gradients smooths the update direction) and per-parameter adaptive learning rates (each step is scaled by a running average of squared gradients).

```python
# Adam maintains two moving averages per parameter:
#   m = momentum (1st moment of gradient)
#   v = squared gradient (2nd moment)
optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)
# 3e-4 is the "Karpathy constant", a good default for Adam
```

In the comparison plot, SGD (orange) oscillates while Adam (teal) takes smoother, more direct paths.
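The two moving averages can be written out by hand. The sketch below follows the standard Adam update with bias correction on a toy loss, sum(w^2), and compares the result against torch.optim.Adam after five steps. The toy loss, starting values, and step count are assumptions chosen for illustration.

```python
import torch

lr, beta1, beta2, eps = 3e-4, 0.9, 0.999, 1e-8

# Reference: PyTorch's Adam on a toy loss sum(w^2), whose gradient is 2w
w_ref = torch.tensor([1.0, -2.0], dtype=torch.float64, requires_grad=True)
opt = torch.optim.Adam([w_ref], lr=lr, betas=(beta1, beta2), eps=eps)

# Manual copy of the same parameter and its two moving averages
w = w_ref.detach().clone()
m = torch.zeros_like(w)  # 1st moment (momentum)
v = torch.zeros_like(w)  # 2nd moment (squared gradient)

for t in range(1, 6):
    opt.zero_grad()
    (w_ref ** 2).sum().backward()
    opt.step()

    g = 2 * w                                # gradient of the toy loss
    m = beta1 * m + (1 - beta1) * g          # update momentum
    v = beta2 * v + (1 - beta2) * g * g      # update squared-gradient average
    m_hat = m / (1 - beta1 ** t)             # bias correction for zero init
    v_hat = v / (1 - beta2 ** t)
    w = w - lr * m_hat / (v_hat.sqrt() + eps)  # per-parameter scaled step

print(w, w_ref.detach())
```

Dividing by the square root of v_hat is what makes the step size adaptive: parameters with consistently large gradients take proportionally smaller steps.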
| Optimizer | Update Rule | When to Use |
|---|---|---|
| SGD | w -= lr * grad | Simple problems, want best generalization |
| SGD+Momentum | v = 0.9*v + grad; w -= lr*v | Faster SGD, smooths oscillations |
| Adam | Adaptive lr + momentum | Default choice, fast convergence |
| AdamW | Adam + decoupled weight decay | Transformers, modern best practice |
Now we assemble the four atoms into the complete training loop. The order of operations is critical — get it wrong and your model either doesn't learn or learns garbage.
```python
for epoch in range(num_epochs):
    for batch_x, batch_y in dataloader:
        # 1. Forward pass
        predictions = model(batch_x)

        # 2. Compute loss
        loss = loss_fn(predictions, batch_y)

        # 3. Zero gradients (BEFORE backward)
        optimizer.zero_grad()

        # 4. Backward pass
        loss.backward()

        # 5. Update weights
        optimizer.step()
```

If you skip zero_grad(), gradients accumulate across steps. Suppose the per-step gradients are 2.0, 1.8, and 1.5: after step 1, grad = 2.0; after step 2, grad = 2.0 + 1.8 = 3.8; after step 3, grad = 3.8 + 1.5 = 5.3. The stored gradient keeps growing, so each update uses a sum of stale gradients instead of the current one, and your updates become wildly wrong.

```python
# BUG: forgot zero_grad
for batch_x, batch_y in dataloader:
    predictions = model(batch_x)
    loss = loss_fn(predictions, batch_y)
    # optimizer.zero_grad()  ← MISSING!
    loss.backward()   # gradients ADD to existing .grad
    optimizer.step()  # updates use inflated gradients
# loss may oscillate wildly or explode
```
Compare training with zero_grad (teal, correct) vs without it (red, broken). Notice how the red curve becomes unstable.
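The accumulation behaviour is easy to reproduce on a single parameter. This is a minimal sketch with a toy loss whose gradient is always 3, so the running sum is easy to read; the values are arbitrary.

```python
import torch

w = torch.tensor([2.0], requires_grad=True)
x = torch.tensor([3.0])

grads = []
for _ in range(3):
    loss = (w * x).sum()  # d(loss)/dw = x = 3 on every pass
    loss.backward()       # adds to w.grad; nothing clears it
    grads.append(w.grad.item())

w.grad.zero_()            # what optimizer.zero_grad() does for each parameter
(w * x).sum().backward()
after_reset = w.grad.item()

print(grads)        # [3.0, 6.0, 9.0]
print(after_reset)  # 3.0
```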
```python
model = MLP()
optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)
loss_fn = nn.CrossEntropyLoss()

for epoch in range(10):
    total_loss = 0
    for batch_x, batch_y in train_loader:
        predictions = model(batch_x)
        loss = loss_fn(predictions, batch_y)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    print(f"Epoch {epoch}: loss = {total_loss/len(train_loader):.4f}")
```
In the previous chapter, we glossed over where train_loader comes from. The DataLoader handles three critical jobs: (1) splitting your dataset into mini-batches, (2) shuffling the order each epoch, and (3) loading data in parallel.
You could compute the loss on ALL training examples at once (full-batch gradient descent). But this is memory-prohibitive for large datasets. You could compute it on ONE example at a time (stochastic GD). But this is too noisy. Mini-batches (typically 32-512 examples) are the sweet spot: stable gradients with reasonable memory.
```python
import numpy as np

def simple_dataloader(X, y, batch_size, shuffle=True):
    """Yields batches of (X_batch, y_batch)."""
    n = len(X)
    indices = np.arange(n)
    if shuffle:
        np.random.shuffle(indices)
    for start in range(0, n, batch_size):
        idx = indices[start:start + batch_size]
        yield X[idx], y[idx]

# Usage:
for batch_x, batch_y in simple_dataloader(X_train, y_train, 32):
    # train on this batch
    pass
```

```python
from torch.utils.data import DataLoader, TensorDataset

# Wrap tensors in a Dataset
dataset = TensorDataset(X_train, y_train)

# Create DataLoader
train_loader = DataLoader(
    dataset,
    batch_size=32,    # examples per batch
    shuffle=True,     # re-order each epoch
    num_workers=4,    # parallel data loading processes
    pin_memory=True,  # faster GPU transfer
    drop_last=True,   # drop incomplete final batch
)
```
Watch training with shuffled batches (teal) vs sorted batches (red). The sorted version oscillates because it keeps "forgetting" earlier classes.
| Parameter | What it does | Typical value |
|---|---|---|
| batch_size | Examples per gradient update | 32-512 |
| shuffle | Randomize order each epoch | True for train, False for eval |
| num_workers | Parallel data loading processes | 4-8 (CPU cores) |
| pin_memory | Allocate batches in page-locked (pinned) host memory for faster CPU-to-GPU copies | True if using GPU |
| drop_last | Skip last incomplete batch | True for training |
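The effect of batch_size and drop_last is easiest to see by counting batches. The sketch below uses 100 random examples (an arbitrary size that does not divide evenly by 32) and the default num_workers=0 so it runs anywhere.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# 100 fake examples: not a multiple of the batch size
X = torch.randn(100, 784)
y = torch.randint(0, 10, (100,))
dataset = TensorDataset(X, y)

keep = DataLoader(dataset, batch_size=32, shuffle=True, drop_last=False)
drop = DataLoader(dataset, batch_size=32, shuffle=True, drop_last=True)

keep_sizes = [len(batch_x) for batch_x, _ in keep]
drop_sizes = [len(batch_x) for batch_x, _ in drop]

print(keep_sizes)            # [32, 32, 32, 4]
print(drop_sizes)            # [32, 32, 32]
print(len(keep), len(drop))  # 4 3
```

Dropping the final batch of 4 avoids one unusually noisy gradient step per epoch, and also sidesteps problems with layers such as BatchNorm that behave poorly on very small batches.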
Time to put it all together. Below is a fully interactive neural network training simulation running in your browser. You're training a tiny 2-layer network to classify two concentric circles (inner ring vs outer ring).
A 2-layer neural network learning to separate concentric circles. Left: decision boundary. Right: loss curve.
What you're seeing:
So far we've used a fixed learning rate. But in practice, you want the learning rate to change over training. Start high (make big moves early), end low (fine-tune without overshooting).
A scheduler adjusts the learning rate based on a schedule. The three most common:
```python
from torch.optim.lr_scheduler import StepLR, CosineAnnealingLR, OneCycleLR

# Pick ONE of the following (each assignment replaces the previous one).

# 1. Step decay: multiply lr by 0.1 every 30 epochs
scheduler = StepLR(optimizer, step_size=30, gamma=0.1)

# 2. Cosine annealing: smooth decay from lr to 0 over T_max epochs
scheduler = CosineAnnealingLR(optimizer, T_max=100)

# 3. One-cycle: warmup → peak → decay (a common choice for modern training)
#    Unlike the two above, call scheduler.step() after every batch, not every epoch.
scheduler = OneCycleLR(optimizer, max_lr=0.01,
                       steps_per_epoch=len(train_loader), epochs=100)
```
Compare different scheduling strategies. The x-axis is training progress (0% to 100%).
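To see a schedule as numbers rather than a curve, you can step a scheduler against a dummy optimizer and record the learning rate each epoch. This sketch uses StepLR with a deliberately short step_size of 2 so the decay shows up within six epochs; the single dummy parameter exists only so the optimizer has something to hold.

```python
import torch
from torch.optim.lr_scheduler import StepLR

param = torch.nn.Parameter(torch.zeros(1))  # dummy parameter
optimizer = torch.optim.SGD([param], lr=0.1)
scheduler = StepLR(optimizer, step_size=2, gamma=0.1)

lrs = []
for epoch in range(6):
    lrs.append(scheduler.get_last_lr()[0])  # lr used during this epoch
    optimizer.step()   # stands in for a full epoch of training
    scheduler.step()   # advance the schedule once per epoch

print([f"{lr:.4f}" for lr in lrs])
# ['0.1000', '0.1000', '0.0100', '0.0100', '0.0010', '0.0010']
```

Calling optimizer.step() before scheduler.step() matches the order PyTorch expects; reversing it skips the first value of the schedule.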
Training can take hours or days. If your machine crashes at epoch 95 of 100, you want to resume from where you stopped. Checkpoints save the full training state:
```python
# Save checkpoint
torch.save({
    'epoch': epoch,
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'scheduler_state_dict': scheduler.state_dict(),
    'loss': loss.item(),  # store a plain float, not a tensor attached to the graph
}, 'checkpoint.pt')

# Load checkpoint
ckpt = torch.load('checkpoint.pt')
model.load_state_dict(ckpt['model_state_dict'])
optimizer.load_state_dict(ckpt['optimizer_state_dict'])
scheduler.load_state_dict(ckpt['scheduler_state_dict'])
start_epoch = ckpt['epoch'] + 1
```
Monitor validation loss. If it hasn't improved for patience epochs, stop training — you're overfitting.
```python
best_val_loss = float('inf')
patience = 10
counter = 0

for epoch in range(max_epochs):
    train_one_epoch(model, train_loader, optimizer)
    val_loss = evaluate(model, val_loader)

    if val_loss < best_val_loss:
        best_val_loss = val_loss
        counter = 0
        torch.save(model.state_dict(), 'best_model.pt')
    else:
        counter += 1
        if counter >= patience:
            print("Early stopping!")
            break
```
Modern GPUs have specialized hardware for float16 math. Mixed precision uses float16 for forward/backward (fast) and float32 for weight updates (accurate):
```python
scaler = torch.amp.GradScaler()

for batch_x, batch_y in train_loader:
    with torch.amp.autocast(device_type='cuda'):
        predictions = model(batch_x)
        loss = loss_fn(predictions, batch_y)

    optimizer.zero_grad()
    scaler.scale(loss).backward()  # scale loss to prevent underflow
    scaler.step(optimizer)         # unscale grads, then step
    scaler.update()
```
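The underflow that loss scaling guards against can be reproduced on the CPU without any AMP machinery. The sketch below is an illustration of the idea only, not of GradScaler's internals: the gradient value 1e-8 is made up, and the scale factor 2**16 is chosen to match what I understand to be GradScaler's default initial scale.

```python
import torch

tiny_grad = torch.tensor(1e-8)         # float32 value, too small for float16
as_half = tiny_grad.to(torch.float16)  # rounds to exactly 0: the gradient is lost

scale = 65536.0                        # 2**16, illustrative scale factor
scaled_half = (tiny_grad * scale).to(torch.float16)  # now within float16 range
recovered = scaled_half.to(torch.float32) / scale    # unscale back in float32

print(as_half.item())    # 0.0
print(recovered.item())  # close to 1e-08
```

Scaling the loss before backward() scales every gradient by the same factor, so small gradients survive the float16 backward pass and are divided back down in float32 before the weight update.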
You now understand every component of the PyTorch training loop. Let's consolidate with a production-ready template, a bug reference, and connections to deeper topics.
```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader

def train(model, train_loader, val_loader, epochs=100, lr=3e-4, device='cuda'):
    model = model.to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=0.01)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)
    loss_fn = nn.CrossEntropyLoss()
    scaler = torch.amp.GradScaler()
    best_val = float('inf')

    for epoch in range(epochs):
        model.train()
        for x, y in train_loader:
            x, y = x.to(device), y.to(device)
            with torch.amp.autocast(device_type='cuda'):
                loss = loss_fn(model(x), y)
            optimizer.zero_grad()
            scaler.scale(loss).backward()
            scaler.unscale_(optimizer)  # unscale first so clipping sees true gradient norms
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            scaler.step(optimizer)
            scaler.update()
        scheduler.step()  # once per epoch, matching T_max=epochs

        # Validation
        model.eval()
        val_loss = 0
        with torch.no_grad():
            for x, y in val_loader:
                x, y = x.to(device), y.to(device)
                val_loss += loss_fn(model(x), y).item()
        val_loss /= len(val_loader)

        if val_loss < best_val:
            best_val = val_loss
            torch.save(model.state_dict(), 'best.pt')

        print(f"Epoch {epoch}: val_loss={val_loss:.4f} lr={scheduler.get_last_lr()[0]:.6f}")

    model.load_state_dict(torch.load('best.pt'))
    return model
```
| Bug | Symptom | Fix |
|---|---|---|
| Forgot zero_grad() | Loss oscillates wildly, grows over time | Add optimizer.zero_grad() before backward() |
| Wrong loss function | Loss doesn't decrease, model outputs garbage | Use CrossEntropyLoss for classification, MSELoss for regression |
| Data on wrong device | RuntimeError: tensors on different devices | Move both model and data to same device: .to(device) |
| Not calling model.train() | Dropout/BatchNorm behave wrong | Call model.train() before training, model.eval() before validation |
| LR too high | Loss explodes to inf/NaN | Reduce lr by 10x, add gradient clipping |
| LR too low | Loss barely moves after many epochs | Increase lr by 10x, or use a scheduler with warmup |
| No torch.no_grad() in eval | Slow validation, excess memory use (possible out-of-memory errors) | Wrap eval loop in with torch.no_grad() |
| Mismatched shapes | RuntimeError: mat1 and mat2 shapes | Print shapes at each layer, check input dimensions |
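The model.train() / model.eval() row is worth seeing in action, because the bug is silent. This sketch runs a standalone nn.Dropout layer in both modes on a tensor of ones; the drop probability, tensor size, and seed are arbitrary.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
drop = nn.Dropout(p=0.5)
x = torch.ones(1000)

drop.train()
train_out = drop(x)  # random elements zeroed, survivors scaled by 1/(1-p) = 2

drop.eval()
eval_out = drop(x)   # dropout is a no-op in eval mode

print(train_out.unique().tolist())  # values drawn from {0.0, 2.0}
print(torch.equal(eval_out, x))     # True
```

Forgetting model.eval() during validation therefore means evaluating a randomly thinned network, which adds noise to the validation loss without raising any error.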
"What I cannot create, I do not understand." — Richard Feynman