Deep Learning Foundations

The PyTorch Training Loop

Forward, loss, backward, step — the four-beat heartbeat of every neural network ever trained.

Prerequisites: Basic Python + What a neural network is. That's it.
10 Chapters · 8+ Simulations · 0 Assumed Knowledge

Chapter 0: Why a Training Loop?

Keras has model.fit(). scikit-learn has .fit(). You call one function and the model trains. So why does PyTorch make you write a loop by hand?

Because control IS the feature. When training goes wrong — and it will — you need to inspect every gradient, every batch, every parameter update. You need to add custom logging, gradient clipping, mixed precision, curriculum learning. A black box can't give you that.

Here's the proof. Watch this training run. The loss starts decreasing, then suddenly explodes. In Keras, you'd stare at a progress bar and wonder what happened. In PyTorch, you can pause at step 47, inspect the gradients, and find the one layer whose weights grew too large.

The philosophy: PyTorch gives you the atoms — forward pass, loss computation, gradient calculation, parameter update — and trusts you to assemble them. This explicitness is what makes debugging possible and research reproducible.
A Training Run Gone Wrong

Watch the loss curve. It starts well, then something breaks. With loop-level access, you can see exactly which step went wrong.

The four-beat cycle that every training loop follows:

Forward
Pass data through the model to get predictions
Loss
Measure how wrong the predictions are
Backward
Compute gradients: how should each weight change?
Step
Update weights in the direction that reduces loss
↻ repeat until converged
Why does PyTorch require you to write a training loop instead of providing model.fit()?

Chapter 1: Forward Pass

The forward pass is the simplest step: you feed data into your model and get a prediction out. In PyTorch, it looks like this:

python
output = model(x)  # That's it. One line.

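But what's happening under the hood? When you call model(x), Python invokes the model's __call__ method, which (1) runs any registered forward pre-hooks, (2) calls your forward() method, and (3) runs any registered forward hooks on the output (handy for debugging and logging). Call model(x) rather than model.forward(x) directly, or the hooks are skipped.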
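
To make the hook mechanism concrete, here is a minimal sketch that records the output shape of each layer through forward hooks. The helper `make_hook` and the list `shapes` are names invented for this example, and the layer sizes are only illustrative.

```python
import torch
import torch.nn as nn

# A small stand-in model; the sizes are illustrative.
model = nn.Sequential(nn.Linear(784, 128), nn.ReLU(), nn.Linear(128, 10))

shapes = []  # (layer name, output shape) pairs recorded by the hooks

def make_hook(name):
    # A forward hook is called as hook(module, inputs, output) after forward().
    def hook(module, inputs, output):
        shapes.append((name, tuple(output.shape)))
    return hook

# Register one hook per child layer and keep the handles for cleanup.
handles = [layer.register_forward_hook(make_hook(name))
           for name, layer in model.named_children()]

model(torch.randn(32, 784))  # calling the model (not .forward) triggers the hooks
for name, shape in shapes:
    print(name, shape)

for handle in handles:
    handle.remove()  # detach the hooks once they are no longer needed
```

Removing the handles matters in longer sessions: a hook that stays registered keeps firing on every later forward pass.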

Let's trace data through a concrete network. Suppose we're classifying handwritten digits (MNIST). Our input is a flattened 28x28 image — a tensor of shape [batch_size, 784]. Our model is a 3-layer MLP:

python
import torch.nn as nn

class MLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(784, 128),   # 784 inputs → 128 hidden
            nn.ReLU(),               # activation
            nn.Linear(128, 10),    # 128 hidden → 10 classes
        )

    def forward(self, x):
        return self.layers(x)

Here's what happens to the shapes at every step:

Input
[batch, 784] — flattened image pixels
↓ Linear(784, 128): compute x·W1ᵀ + b1, with W1 stored as [128×784] and b1 as [128]
Hidden
[batch, 128] — learned features
↓ ReLU: max(0, x) elementwise
Activated
[batch, 128] — shape unchanged, negatives zeroed
↓ Linear(128, 10): compute x·W2ᵀ + b2, with W2 stored as [10×128] and b2 as [10]
Output (logits)
[batch, 10] — raw scores for each digit class
Logits vs probabilities: The output of the last linear layer is called logits — raw, unnormalized scores. They can be any real number (positive or negative). To turn them into probabilities, you'd apply softmax. But in PyTorch, the loss function often does this internally for numerical stability.
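
The numerical-stability point can be seen with deliberately extreme logits. The sketch below is an illustration with made-up values, not a description of PyTorch's internals: a hand-rolled softmax overflows, while the built-in cross-entropy works directly from the logits and stays finite.

```python
import torch
import torch.nn.functional as F

logits = torch.tensor([[1000.0, 0.0]])  # deliberately extreme values
target = torch.tensor([0])

# Naive softmax: exp(1000) overflows float32 to inf, and inf / inf is nan.
exps = torch.exp(logits)
naive_probs = exps / exps.sum(dim=1, keepdim=True)
naive_loss = -torch.log(naive_probs[0, 0])

# cross_entropy takes the raw logits and computes a log-softmax internally.
stable_loss = F.cross_entropy(logits, target)

print(naive_loss)   # nan
print(stable_loss)  # a finite value very close to zero
```

This is why the model above returns logits and leaves the softmax to the loss function.
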
Forward Pass Shape Tracker

Watch data flow through the network. Each layer transforms the shape. Adjust the batch size to see how it affects dimensions.

A key detail: during the forward pass, PyTorch is secretly building a computation graph — recording every operation so it can later compute gradients. This graph is what makes loss.backward() possible.

python
import torch

model = MLP()                # the MLP class defined above

# Every tensor operation is recorded
x = torch.randn(32, 784)  # input
out = model(x)               # forward pass builds the graph
print(out.shape)             # torch.Size([32, 10])
print(out.grad_fn)           # <AddmmBackward0 object at 0x...>: proof of graph
If input shape is [16, 784] and the first layer is Linear(784, 128), what is the output shape of that layer?

Chapter 2: Loss Functions

The forward pass gives us predictions. The loss function tells us how wrong those predictions are — a single number that we want to minimize. Different problems need different loss functions.

CrossEntropyLoss (classification)

For classification, we use Cross-Entropy Loss. Under the hood, it's two steps:

  1. Softmax: convert raw logits to probabilities that sum to 1
  2. Negative log-likelihood: penalize low probability assigned to the correct class

Let's work through a concrete example. Suppose our model outputs logits [2.0, 1.0, 0.1] for 3 classes, and the true class is 0.

Worked example:
Logits: [2.0, 1.0, 0.1]
Step 1 — Softmax: exp([2.0, 1.0, 0.1]) = [7.389, 2.718, 1.105] → sum = 11.212
Probabilities: [7.389/11.212, 2.718/11.212, 1.105/11.212] = [0.659, 0.242, 0.099]
Step 2 — NLL: target is class 0, so loss = -log(0.659) = 0.417

If the model had assigned probability 0.99 to class 0, loss would be -log(0.99) = 0.01. Lower probability → higher loss.
python
import torch
import torch.nn as nn

# Our logits and target
logits = torch.tensor([[2.0, 1.0, 0.1]])  # [1, 3]
target = torch.tensor([0])                    # class 0

# PyTorch combines softmax + NLL in one function
loss_fn = nn.CrossEntropyLoss()
loss = loss_fn(logits, target)
print(loss)  # tensor(0.4170)

# Manual computation matches:
probs = torch.softmax(logits, dim=1)   # [[0.659, 0.242, 0.099]]
manual_loss = -torch.log(probs[0, 0])  # tensor(0.4170)

MSELoss (regression)

For regression (predicting continuous values), we use Mean Squared Error: the average of squared differences between predictions and targets.

$$\mathrm{MSE} = \frac{1}{N} \sum_{i=1}^{N} (\hat{y}_i - y_i)^2$$
python
prediction = torch.tensor([2.5, 3.1, 4.0])
target = torch.tensor([3.0, 3.0, 4.0])
loss = nn.MSELoss()(prediction, target)
# = ((2.5-3)^2 + (3.1-3)^2 + (4-4)^2) / 3
# = (0.25 + 0.01 + 0) / 3 = 0.0867
Loss Landscape

Drag the slider to change the model's predicted probability for the correct class. Watch how loss responds — steep penalty for confident wrong answers.

Why -log? The logarithm creates an asymmetric penalty. If the model assigns probability 0.9 to the correct class, loss is only 0.105. But if it assigns 0.01 — catastrophically wrong — loss is 4.605. This strongly punishes confident wrong predictions.
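
The asymmetry is easy to tabulate. This short sketch prints the loss for a handful of probabilities assigned to the correct class; the helper name `nll` and the chosen probabilities are just for illustration.

```python
import math

def nll(p_correct):
    """Negative log-likelihood for the probability given to the true class."""
    return -math.log(p_correct)

for p in [0.99, 0.9, 0.66, 0.5, 0.1, 0.01]:
    print(f"P(correct) = {p:<5} loss = {nll(p):.3f}")
```

Going from 0.99 to 0.9 barely moves the loss, while going from 0.1 to 0.01 adds about 2.3 to it.
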
For CrossEntropyLoss with logits [5.0, 0.0, 0.0] and target class 0, the softmax probability for class 0 will be close to 1. What will the loss be?

Chapter 3: Backward Pass

We have a loss — a single number measuring how wrong we are. Now we need to answer: which weights caused the most error, and in which direction should we nudge them?

The answer is gradients. The gradient of the loss with respect to a weight tells you: "if you increase this weight by a tiny amount, the loss changes by this much." In PyTorch, one line computes ALL gradients simultaneously:

python
loss.backward()  # Fills .grad for every parameter in the model

After this call, every parameter's .grad attribute contains ∂loss/∂parameter. Let's verify with a tiny network:

python
# Tiny network: one weight, one bias
w = torch.tensor([2.0], requires_grad=True)
b = torch.tensor([1.0], requires_grad=True)

# Forward: y = w*x + b
x = torch.tensor([3.0])
y = w * x + b        # y = 2*3 + 1 = 7

# Loss: MSE with target=5
target = torch.tensor([5.0])
loss = (y - target)**2  # (7-5)^2 = 4

# Backward
loss.backward()

print(w.grad)  # tensor([12.])
print(b.grad)  # tensor([4.])
Let's verify by hand:
loss = (w·x + b - target)² = (2·3 + 1 - 5)² = 4
∂loss/∂w = 2(w·x + b - target) · x = 2(7-5) · 3 = 12
∂loss/∂b = 2(w·x + b - target) · 1 = 2(7-5) · 1 = 4

The chain rule unfolds automatically through the computation graph PyTorch built during the forward pass.

Autograd is PyTorch's automatic differentiation engine. During the forward pass, it records every operation in a directed acyclic graph (DAG). During .backward(), it traverses this graph in reverse, applying the chain rule at each node.
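
The recorded graph can be inspected directly. As a rough sketch, the code below rebuilds the tiny one-weight network and walks backward from the loss through `grad_fn.next_functions`; the `walk` helper is a name made up for this example.

```python
import torch

w = torch.tensor([2.0], requires_grad=True)
b = torch.tensor([1.0], requires_grad=True)
x = torch.tensor([3.0])
target = torch.tensor([5.0])

loss = (w * x + b - target) ** 2

def walk(fn, depth=0, names=None):
    """Depth-first walk over the autograd graph, printing each node's name."""
    if names is None:
        names = []
    if fn is None:          # inputs that do not require grad show up as None
        return names
    names.append(fn.name())
    print("  " * depth + fn.name())
    for next_fn, _ in fn.next_functions:
        walk(next_fn, depth + 1, names)
    return names

node_names = walk(loss.grad_fn)
```

The root node corresponds to the squaring, and the leaves are AccumulateGrad nodes, one per tensor created with requires_grad=True. Those leaves are where .backward() deposits the values that end up in .grad.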

Gradient Computation

Adjust w and x to see how gradients change. The gradient ∂loss/∂w tells us the direction to move w to reduce loss.

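Critical detail: .backward() accumulates gradients: it adds to whatever is already stored in .grad instead of overwriting it. Run forward and backward twice without clearing, and .grad holds the sum of both gradients. This is why you must call optimizer.zero_grad() before each backward pass, as Chapter 5 shows.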
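
A minimal sketch of the accumulation behaviour, using a single weight so the numbers can be checked by hand. The helper `forward_backward` is a name invented for this example.

```python
import torch

w = torch.tensor([2.0], requires_grad=True)
x = torch.tensor([3.0])

def forward_backward():
    # loss = (w*x)^2, so d(loss)/dw = 2 * w * x^2 = 2 * 2 * 9 = 36
    loss = ((w * x) ** 2).sum()
    loss.backward()

forward_backward()
first = w.grad.clone()    # 36

forward_backward()
second = w.grad.clone()   # 36 + 36 = 72: the new gradient was added to the old one

w.grad.zero_()            # clear the stored gradient, the job zero_grad() does for a model
forward_backward()
third = w.grad.clone()    # back to 36

print(first, second, third)
```
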
After calling loss.backward(), where are the computed gradients stored?

Chapter 4: Optimizer Step

We have gradients. Now we need to actually update the weights. The simplest rule: move each weight in the opposite direction of its gradient, scaled by a learning rate.

SGD from scratch (4 lines)

python
lr = 0.01
with torch.no_grad():                 # the update itself must not be tracked by autograd
    for param in model.parameters():
        param -= lr * param.grad      # That's SGD. Really.

That's Stochastic Gradient Descent (SGD). The weight update rule is:

$$w_{\text{new}} = w_{\text{old}} - \eta \cdot \nabla L$$

Where η is the learning rate and ∇L is the gradient of the loss with respect to w.

In practice, you use PyTorch's optimizer:

python
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
# ... after computing gradients:
optimizer.step()  # applies the update rule to all parameters
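
As a sanity check, the update rule can be verified on the one-weight network from Chapter 3, where the gradients were 12 and 4. This is a small sketch with lr=0.01; passing the tensors to the optimizer as a plain list is just a convenience for the example.

```python
import torch

w = torch.tensor([2.0], requires_grad=True)
b = torch.tensor([1.0], requires_grad=True)
x, target = torch.tensor([3.0]), torch.tensor([5.0])

optimizer = torch.optim.SGD([w, b], lr=0.01)

loss = ((w * x + b - target) ** 2).sum()
optimizer.zero_grad()
loss.backward()       # w.grad = 12, b.grad = 4
optimizer.step()

# Expected by hand: w = 2.0 - 0.01 * 12 = 1.88, b = 1.0 - 0.01 * 4 = 0.96
print(w.item(), b.item())
```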

Adam: the default choice

Adam improves on SGD with two ideas:

  1. Momentum: keep a running average of past gradients (like a ball rolling downhill — it accumulates speed)
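  2. Adaptive learning rates: each parameter's step is divided by a running root-mean-square of its gradients, so parameters with consistently large gradients get a smaller effective learning rate and parameters with small gradients get a larger one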
python
# Adam maintains two moving averages per parameter:
# m = momentum (1st moment of gradient)
# v = squared gradient (2nd moment)
optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)
# 3e-4 is the "Karpathy constant" — a good default for Adam
SGD vs Adam analogy: SGD is like walking downhill blindfolded with fixed step size — you take one step, feel the slope, take another. Adam is like rolling a bowling ball — it builds momentum, adapts to the terrain, and smooths over bumps. Adam converges faster but SGD often generalizes better.
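
To show what the two moving averages do, here is a hand-written Adam update for a single scalar weight, compared against torch.optim.Adam on the toy loss w². It is a sketch using the usual default hyperparameters (beta1=0.9, beta2=0.999, eps=1e-8); the function `adam_step` is written for this example and is not part of PyTorch.

```python
import torch

def adam_step(w, g, m, v, t, lr=3e-4, beta1=0.9, beta2=0.999, eps=1e-8):
    """One Adam update for a scalar weight w with gradient g at step t (1-based)."""
    m = beta1 * m + (1 - beta1) * g        # 1st moment: running mean of gradients
    v = beta2 * v + (1 - beta2) * g * g    # 2nd moment: running mean of squared gradients
    m_hat = m / (1 - beta1 ** t)           # bias correction for the zero initialisation
    v_hat = v / (1 - beta2 ** t)
    w = w - lr * m_hat / (v_hat ** 0.5 + eps)
    return w, m, v

# Reference: PyTorch's Adam on loss = w^2, in float64 to match Python floats.
w_ref = torch.tensor([2.0], dtype=torch.float64, requires_grad=True)
opt = torch.optim.Adam([w_ref], lr=3e-4)

w, m, v = 2.0, 0.0, 0.0
for t in range(1, 6):
    opt.zero_grad()
    (w_ref ** 2).sum().backward()
    opt.step()
    w, m, v = adam_step(w, 2 * w, m, v, t)   # gradient of w^2 is 2w

print(w, w_ref.item())
```

On the first step the bias-corrected ratio m_hat / sqrt(v_hat) is close to ±1, so the weight moves by roughly lr regardless of how large the gradient is.
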
SGD vs Adam on a 2D Loss Surface

Watch how SGD (orange) oscillates while Adam (teal) takes smoother, more direct paths. Click Reset to try a different starting point.

Optimizer | Update Rule | When to Use
SGD | w -= lr * grad | Simple problems, want best generalization
SGD+Momentum | v = 0.9*v + grad; w -= lr*v | Faster SGD, smooths oscillations
Adam | Adaptive lr + momentum | Default choice, fast convergence
AdamW | Adam + decoupled weight decay | Transformers, modern best practice
In SGD with learning rate 0.1 and gradient 2.0, how does the weight change?

Chapter 5: The Complete Loop

Now we assemble the four atoms into the complete training loop. The order of operations is critical — get it wrong and your model either doesn't learn or learns garbage.

python
for epoch in range(num_epochs):
    for batch_x, batch_y in dataloader:
        # 1. Forward pass
        predictions = model(batch_x)

        # 2. Compute loss
        loss = loss_fn(predictions, batch_y)

        # 3. Zero gradients (BEFORE backward)
        optimizer.zero_grad()

        # 4. Backward pass
        loss.backward()

        # 5. Update weights
        optimizer.step()
The order matters. Here's why:

zero_grad() must come before backward() — gradients accumulate by default, so we clear old gradients first.
backward() must come after the loss is computed — it needs the computation graph from the forward pass.
step() must come after backward() — it reads .grad to update weights.

What if you forget zero_grad?

Gradients accumulate. After step 1, grad = 2.0. After step 2, grad = 2.0 + 1.8 = 3.8. After step 3, grad = 3.8 + 1.5 = 5.3. The gradient grows each iteration, and your updates become wildly wrong.

python
# BUG: forgot zero_grad
for batch_x, batch_y in dataloader:
    predictions = model(batch_x)
    loss = loss_fn(predictions, batch_y)
    # optimizer.zero_grad()  ← MISSING!
    loss.backward()         # gradients ADD to existing .grad
    optimizer.step()       # updates use inflated gradients
    # loss may oscillate wildly or explode
Zero Grad vs Accumulated Gradients

Compare training with zero_grad (teal, correct) vs without it (red, broken). Notice how the red curve becomes unstable.

The full picture with logging

python
model = MLP()
optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)
loss_fn = nn.CrossEntropyLoss()

for epoch in range(10):
    total_loss = 0
    for batch_x, batch_y in train_loader:
        predictions = model(batch_x)
        loss = loss_fn(predictions, batch_y)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    print(f"Epoch {epoch}: loss = {total_loss/len(train_loader):.4f}")
What happens if you call optimizer.step() BEFORE loss.backward()?

Chapter 6: DataLoader

In the previous chapter, we glossed over where train_loader comes from. The DataLoader handles three critical jobs: (1) splitting your dataset into mini-batches, (2) shuffling the order each epoch, and (3) loading data in parallel.

Why mini-batches?

You could compute the loss on ALL training examples at once (full-batch gradient descent). But this is memory-prohibitive for large datasets. You could compute it on ONE example at a time (stochastic GD). But this is too noisy. Mini-batches (typically 32-512 examples) are the sweet spot: stable gradients with reasonable memory.
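
The noise trade-off can be measured on a toy problem. The sketch below uses synthetic 1-D regression data (all values are made up for illustration) and estimates how much the mini-batch gradient varies from batch to batch at three batch sizes.

```python
import numpy as np

rng = np.random.default_rng(0)

# Synthetic data: y = 3x + noise. The model is y_hat = w * x, evaluated at w = 0.
X = rng.normal(size=10_000)
y = 3.0 * X + rng.normal(scale=0.5, size=10_000)
w = 0.0

def batch_gradient(batch_size):
    """Gradient of the MSE loss with respect to w on one random mini-batch."""
    idx = rng.choice(len(X), size=batch_size, replace=False)
    residual = w * X[idx] - y[idx]
    return np.mean(2.0 * residual * X[idx])

# Standard deviation of the gradient estimate over 500 random batches.
spread = {bs: np.std([batch_gradient(bs) for _ in range(500)])
          for bs in (1, 32, 512)}

for bs, s in spread.items():
    print(f"batch_size={bs:<4} gradient std = {s:.3f}")
```

The spread shrinks roughly with the square root of the batch size, which is why going from 1 to 32 examples helps far more than going from 32 to 512.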

Minimal DataLoader from scratch

python
import numpy as np

def simple_dataloader(X, y, batch_size, shuffle=True):
    """Yields batches of (X_batch, y_batch)."""
    n = len(X)
    indices = np.arange(n)
    if shuffle:
        np.random.shuffle(indices)
    for start in range(0, n, batch_size):
        idx = indices[start:start + batch_size]
        yield X[idx], y[idx]

# Usage:
for batch_x, batch_y in simple_dataloader(X_train, y_train, 32):
    # train on this batch
    pass

PyTorch's DataLoader

python
from torch.utils.data import DataLoader, TensorDataset

# Wrap tensors in a Dataset
dataset = TensorDataset(X_train, y_train)

# Create DataLoader
train_loader = DataLoader(
    dataset,
    batch_size=32,       # examples per batch
    shuffle=True,        # re-order each epoch
    num_workers=4,      # parallel data loading processes
    pin_memory=True,    # faster GPU transfer
    drop_last=True,     # drop incomplete final batch
)
Why shuffling matters: Without shuffling, the model sees data in the same order every epoch. If your dataset is sorted by class (all dogs, then all cats), the model oscillates — it optimizes for dogs, then "forgets" them while learning cats. Shuffling ensures each batch is a representative sample.
Shuffling vs No Shuffling

Watch training with shuffled batches (teal) vs sorted batches (red). The sorted version oscillates because it keeps "forgetting" earlier classes.

Parameter | What it does | Typical value
batch_size | Examples per gradient update | 32-512
shuffle | Randomize order each epoch | True for train, False for eval
num_workers | Parallel data loading processes | 4-8 (bounded by CPU cores)
pin_memory | Allocate batches in page-locked (pinned) host memory for faster CPU-to-GPU copies | True if using GPU
drop_last | Skip last incomplete batch | True for training
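
A quick way to see what drop_last and shuffle do is to run a DataLoader over a toy dataset whose only feature is the example index. This is a sketch with 100 made-up examples; the seeded generator is there only to make the run repeatable.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# 100 toy examples whose "feature" is their own index, so ordering is visible.
X = torch.arange(100).float().unsqueeze(1)
y = torch.zeros(100, dtype=torch.long)
dataset = TensorDataset(X, y)

keep_last = DataLoader(dataset, batch_size=32, shuffle=False, drop_last=False)
drop_last = DataLoader(dataset, batch_size=32, shuffle=False, drop_last=True)
print(len(keep_last), len(drop_last))  # 4 batches (32+32+32+4) vs 3 batches

g = torch.Generator().manual_seed(0)
shuffled = DataLoader(dataset, batch_size=32, shuffle=True, generator=g)

# Each pass over the loader is one epoch and draws a fresh permutation.
epoch1 = torch.cat([bx for bx, _ in shuffled]).flatten().tolist()
epoch2 = torch.cat([bx for bx, _ in shuffled]).flatten().tolist()
print(epoch1[:5], epoch2[:5])
```

Every example still appears exactly once per epoch; only the order changes.
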
With 1000 training examples and batch_size=32, how many gradient updates happen per epoch?

Chapter 7: Training Dashboard

Time to put it all together. Below is a fully interactive neural network training simulation running in your browser. You're training a tiny 2-layer network to classify two concentric circles (inner ring vs outer ring).

Try these experiments:
1. Set lr=0.01, Adam, batch=16. Train 50 epochs. Watch the decision boundary form.
2. Set lr=10 (the "Explode" button). Watch loss shoot to infinity.
3. Set lr=0.0001. Notice how painfully slow learning is.
4. Compare SGD vs Adam at lr=0.1 — Adam handles it, SGD may oscillate.
5. Try batch_size=1 vs batch_size=128 — notice gradient noise differences.
Live Training Simulation

A 2-layer neural network learning to separate concentric circles. Left: decision boundary. Right: loss curve.

What you should observe: With Adam at lr=0.01, the decision boundary starts as a random line, then gradually curves to wrap around the inner circle. By epoch 30-50 it should fit well. The loss curve drops steeply at first (easy patterns) then slowly (fine-tuning).

Chapter 8: Schedulers & Checkpointing

So far we've used a fixed learning rate. But in practice, you want the learning rate to change over training. Start high (make big moves early), end low (fine-tune without overshooting).

Learning Rate Schedulers

A scheduler adjusts the learning rate based on a schedule. The three most common:

python
from torch.optim.lr_scheduler import (
    StepLR, CosineAnnealingLR, OneCycleLR
)

# 1. Step decay: multiply lr by 0.1 every 30 epochs
scheduler = StepLR(optimizer, step_size=30, gamma=0.1)

# 2. Cosine annealing: smooth decay from lr to 0
scheduler = CosineAnnealingLR(optimizer, T_max=100)

# 3. One-cycle: warmup → peak → decay; call scheduler.step() after every batch, not every epoch
scheduler = OneCycleLR(optimizer, max_lr=0.01,
                        steps_per_epoch=len(train_loader),
                        epochs=100)
Warmup + Cosine Decay is the modern default for transformers:
- First 5-10% of training: linearly increase lr from 0 to max_lr (warmup)
- Remaining 90-95%: decrease lr following a cosine curve to near-zero
This prevents early instability (warmup) and allows fine convergence (decay).
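
One way to build this schedule is with LambdaLR, which multiplies the base learning rate by a function of the step count. The sketch below assumes 100 total steps with a 10-step warmup; `warmup_cosine` and the constants are illustrative choices, not a recommended configuration.

```python
import math
import torch
from torch.optim.lr_scheduler import LambdaLR

total_steps, warmup_steps, max_lr = 100, 10, 1e-3  # illustrative values

def warmup_cosine(step):
    """Multiplier applied to max_lr at a given step."""
    if step < warmup_steps:
        return step / warmup_steps                      # linear ramp from 0 to 1
    progress = (step - warmup_steps) / (total_steps - warmup_steps)
    return 0.5 * (1.0 + math.cos(math.pi * progress))   # cosine decay from 1 to 0

param = torch.nn.Parameter(torch.zeros(1))              # stand-in for model parameters
optimizer = torch.optim.SGD([param], lr=max_lr)
scheduler = LambdaLR(optimizer, lr_lambda=warmup_cosine)

lrs = []
for step in range(total_steps):
    lrs.append(optimizer.param_groups[0]["lr"])
    optimizer.step()     # a real loop would run forward and backward first
    scheduler.step()     # stepped once per batch in this sketch

print(lrs[0], lrs[warmup_steps], lrs[-1])
```

The learning rate starts at zero, peaks at max_lr when warmup ends, and is close to zero by the final step.
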
Learning Rate Schedules

Compare different scheduling strategies. The x-axis is training progress (0% to 100%).

Checkpointing

Training can take hours or days. If your machine crashes at epoch 95 of 100, you want to resume from where you stopped. Checkpoints save the full training state:

python
# Save checkpoint
torch.save({
    'epoch': epoch,
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'scheduler_state_dict': scheduler.state_dict(),
    'loss': loss.item(),  # store a plain float, not a tensor attached to the graph
}, 'checkpoint.pt')

# Load checkpoint
ckpt = torch.load('checkpoint.pt')
model.load_state_dict(ckpt['model_state_dict'])
optimizer.load_state_dict(ckpt['optimizer_state_dict'])
scheduler.load_state_dict(ckpt['scheduler_state_dict'])
start_epoch = ckpt['epoch'] + 1
state_dict() returns a dictionary mapping parameter and buffer names to tensors. This is the recommended way to save and load models in PyTorch. Avoid saving the model object itself: pickling it ties the file to the exact class definitions and file layout, so it breaks when the code changes. Save the state_dict instead.
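
To see what a state_dict actually contains, the sketch below prints the entries for a small model and round-trips them through an in-memory buffer. The buffer stands in for a file on disk purely so the example leaves nothing behind.

```python
import io
import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(784, 128), nn.ReLU(), nn.Linear(128, 10))

# Names come from the module hierarchy; ReLU has no parameters, so index 1 is absent.
for name, tensor in model.state_dict().items():
    print(f"{name:10s} {tuple(tensor.shape)}")

# Save to an in-memory buffer (a stand-in for 'model.pt') and load into a fresh model.
buffer = io.BytesIO()
torch.save(model.state_dict(), buffer)
buffer.seek(0)

restored = nn.Sequential(nn.Linear(784, 128), nn.ReLU(), nn.Linear(128, 10))
restored.load_state_dict(torch.load(buffer))
```

Loading requires a model with the same architecture to exist first, since the file holds only tensors and their names.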

Early Stopping

Monitor validation loss. If it hasn't improved for patience epochs, stop training: the model has plateaued or started to overfit.

python
best_val_loss = float('inf')
patience = 10
counter = 0

for epoch in range(max_epochs):
    train_one_epoch(model, train_loader, optimizer)
    val_loss = evaluate(model, val_loader)

    if val_loss < best_val_loss:
        best_val_loss = val_loss
        counter = 0
        torch.save(model.state_dict(), 'best_model.pt')
    else:
        counter += 1
        if counter >= patience:
            print("Early stopping!")
            break

Mixed Precision (torch.amp)

Modern GPUs have specialized hardware for float16 math. Mixed precision uses float16 for forward/backward (fast) and float32 for weight updates (accurate):

python
scaler = torch.amp.GradScaler()

for batch_x, batch_y in train_loader:
    with torch.amp.autocast(device_type='cuda'):
        predictions = model(batch_x)
        loss = loss_fn(predictions, batch_y)

    optimizer.zero_grad()
    scaler.scale(loss).backward()  # scale loss to prevent underflow
    scaler.step(optimizer)         # unscale grads, then step
    scaler.update()
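
The underflow that loss scaling guards against can be reproduced on the CPU. In this sketch a small gradient value vanishes when cast to float16 but survives if it is scaled up first; the factor 2**16 is chosen for illustration (it is, as far as I know, also GradScaler's default starting scale).

```python
import torch

tiny_grad = torch.tensor(1e-8)   # a small float32 gradient value
scale = 65536.0                  # 2**16, illustrative scale factor

unscaled_fp16 = tiny_grad.to(torch.float16)            # too small for float16: becomes 0
scaled_fp16 = (tiny_grad * scale).to(torch.float16)    # representable after scaling
recovered = scaled_fp16.to(torch.float32) / scale      # unscale again in float32

print(unscaled_fp16.item(), scaled_fp16.item(), recovered.item())
```

Scaling the loss multiplies every gradient by the same factor through the chain rule, which is why the scaler can divide it back out before the optimizer step.
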
Why do we save optimizer_state_dict in a checkpoint, not just the model weights?

Chapter 9: Mastery & Connections

You now understand every component of the PyTorch training loop. Let's consolidate with a production-ready template, a bug reference, and connections to deeper topics.

Production Training Template (~50 lines)

python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader

def train(model, train_loader, val_loader, epochs=100, lr=3e-4, device='cuda'):
    model = model.to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=0.01)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)
    loss_fn = nn.CrossEntropyLoss()
    scaler = torch.amp.GradScaler()
    best_val = float('inf')

    for epoch in range(epochs):
        model.train()
        for x, y in train_loader:
            x, y = x.to(device), y.to(device)
            with torch.amp.autocast(device_type='cuda'):
                loss = loss_fn(model(x), y)
            optimizer.zero_grad()
            scaler.scale(loss).backward()
            scaler.unscale_(optimizer)  # unscale first so clipping sees the true gradient norms
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            scaler.step(optimizer)
            scaler.update()
        scheduler.step()

        # Validation
        model.eval()
        val_loss = 0
        with torch.no_grad():
            for x, y in val_loader:
                x, y = x.to(device), y.to(device)
                val_loss += loss_fn(model(x), y).item()
        val_loss /= len(val_loader)

        if val_loss < best_val:
            best_val = val_loss
            torch.save(model.state_dict(), 'best.pt')
        print(f"Epoch {epoch}: val_loss={val_loss:.4f} lr={scheduler.get_last_lr()[0]:.6f}")

    model.load_state_dict(torch.load('best.pt'))
    return model

Common Bugs Cheat Sheet

Bug | Symptom | Fix
Forgot zero_grad() | Loss oscillates wildly, grows over time | Add optimizer.zero_grad() before backward()
Wrong loss function | Loss doesn't decrease, model outputs garbage | Use CrossEntropyLoss for classification, MSELoss for regression
Data on wrong device | RuntimeError: tensors on different devices | Move both model and data to same device: .to(device)
Not calling model.train() | Dropout/BatchNorm behave wrong | Call model.train() before training, model.eval() before validation
LR too high | Loss explodes to inf/NaN | Reduce lr by 10x, add gradient clipping
LR too low | Loss barely moves after many epochs | Increase lr by 10x, or use a scheduler with warmup
No torch.no_grad() in eval | Slow validation, excess memory use | Wrap eval loop in with torch.no_grad()
Mismatched shapes | RuntimeError: mat1 and mat2 shapes | Print shapes at each layer, check input dimensions
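
The train/eval row is the easiest one to demonstrate. This sketch passes a vector of ones through a standalone Dropout layer in both modes; p=0.5 and the seed are arbitrary choices for the example.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
layer = nn.Dropout(p=0.5)
x = torch.ones(1000)

layer.train()
train_out = layer(x)   # roughly half the entries zeroed, survivors scaled by 1/(1-p) = 2

layer.eval()
eval_out = layer(x)    # dropout is disabled: output equals input

print(train_out[:8])
print(eval_out[:8])
```

Calling model.train() or model.eval() on a full model flips this switch for every submodule at once, which is why both calls appear in the production template.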

Connections

Where to go next:
Autograd deep dive — understand HOW backward() computes gradients (the computation graph)
Transformers — the architecture that dominates NLP/vision, but the training loop is identical
Distributed training — when one GPU isn't enough: DataParallel, DistributedDataParallel
RLHF — when the "loss" comes from a reward model instead of ground-truth labels

"What I cannot create, I do not understand." — Richard Feynman

In the production template, why do we call model.eval() before validation?