Build a complete neural network training pipeline by hand — tensors, autograd, modules, training loop, GPU — before ever calling torch.nn.
Here's a challenge: build a neural network that classifies handwritten digits. No torch.nn. No torch.optim. Just raw tensors and math you can see, touch, and verify line by line.
By the end of this lesson, you will have reimplemented PyTorch's core from scratch. You'll write the forward pass, the loss function, the backward pass, the optimizer, the data loader, and the training loop — all by hand. Then you'll compare your code to the "real" PyTorch code and see: they produce identical results.
For every nn.Linear, every loss.backward(), every optimizer.step(), you'll know exactly what's happening under the hood, because you built it yourself.
Here's what we're building toward: a 784→128→10 network that takes a 28×28 pixel image (flattened to 784 numbers) and outputs a score for each digit 0-9, which softmax turns into probabilities. The network has just 4 tensors: two weight matrices and two bias vectors. That's it. No classes, no modules, no abstractions.
Click "Draw" to sketch a digit. The network (running in your browser) classifies it in real time. This network was trained using ONLY the code we'll write in this lesson.
Let's look at the complete architecture we're building:
The entire network is 4 tensors: W1 (784×128), b1 (128), W2 (128×10), b2 (10). Total parameters: 784×128 + 128 + 128×10 + 10 = 101,770. That's it. Let's build it.
A tensor is just a multi-dimensional array of numbers. A 1D tensor is a vector. A 2D tensor is a matrix. PyTorch tensors live on a device (CPU or GPU) and can track gradients. That's it — no magic.
Our network needs 4 tensors. Let's create them:
```python
import torch

# Initialize weights with random values (Gaussian, mean=0, std=1)
# requires_grad=True tells PyTorch: "track derivatives for these"
W1 = torch.randn(784, 128, requires_grad=True)  # [784, 128]
b1 = torch.zeros(128, requires_grad=True)       # [128]
W2 = torch.randn(128, 10, requires_grad=True)   # [128, 10]
b2 = torch.zeros(10, requires_grad=True)        # [10]
```
That's the entire network definition. Four tensors. No classes, no inheritance, no __init__ methods. Just numbers arranged in matrices.
Let's understand the memory footprint. Each weight is a 32-bit float (4 bytes):
| Tensor | Shape | Parameters | Memory |
|---|---|---|---|
| W1 | [784, 128] | 100,352 | 401 KB |
| b1 | [128] | 128 | 0.5 KB |
| W2 | [128, 10] | 1,280 | 5 KB |
| b2 | [10] | 10 | 0.04 KB |
| Total | | 101,770 | ~407 KB |
W1 is by far the largest — it maps 784 inputs to 128 hidden neurons. That's 784 × 128 = 100,352 individual connections, each with its own learned weight.
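The table can be reproduced directly from the tensors. The following is a minimal sketch, assuming PyTorch's default float32 dtype and decimal kilobytes (1 KB = 1000 bytes); the `shapes` dictionary is just an illustrative helper.

```python
import torch

# Shapes from the text: a 784 -> 128 -> 10 network
shapes = {"W1": (784, 128), "b1": (128,), "W2": (128, 10), "b2": (10,)}
tensors = {name: torch.zeros(shape) for name, shape in shapes.items()}  # float32 by default

total_params = 0
total_bytes = 0
for name, t in tensors.items():
    n_bytes = t.numel() * t.element_size()  # element_size() is 4 for float32
    total_params += t.numel()
    total_bytes += n_bytes
    print(f"{name}: {t.numel():>7,} params, {n_bytes / 1000:.2f} KB")

print(f"Total: {total_params:,} params, {total_bytes / 1000:.1f} KB")
```

`numel()` counts elements and `element_size()` gives bytes per element, so the same loop works unchanged if you later switch dtype.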
requires_grad=True means: PyTorch builds a computation graph as you do math with this tensor. Every operation (multiply, add, etc.) gets recorded. Later, when you call .backward(), PyTorch walks this graph backwards to compute derivatives. Without requires_grad=True, the tensor is just a plain array: no tracking, no gradients.
Let's look at what's actually stored in a tensor:
```python
print(W1.shape)          # torch.Size([784, 128])
print(W1.dtype)          # torch.float32
print(W1.device)         # device(type='cpu')
print(W1.requires_grad)  # True
print(W1.grad)           # None (no backward pass yet)

# Peek at actual values
print(W1[0, :5])
# e.g. tensor([ 0.3421, -1.2018,  0.8843, -0.4912,  0.1357], grad_fn=...)
# Random numbers ~ N(0,1): some positive, some negative. Your values will differ,
# and the printout carries a grad_fn because W1 requires grad.
```
Each rectangle is a tensor. Width = number of columns, height = number of rows. The numbers inside are the actual float32 values PyTorch stores in memory.
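To make the "recording" idea concrete, here is a small sketch with two made-up tensors (`w` and `c` are illustrative names, not part of the network). It shows that operations on a tracked tensor get a `grad_fn`, and that `.backward()` fills in `.grad` only for tracked tensors.

```python
import torch

w = torch.tensor([2.0, -1.0], requires_grad=True)  # tracked by autograd
c = torch.tensor([3.0, 4.0])                       # plain data, not tracked

y = (w * c).sum()        # y = 2*3 + (-1)*4 = 2
print(y.grad_fn)         # a backward node: the multiply and sum were recorded
print((c * 2).grad_fn)   # None: nothing was recorded for untracked tensors

y.backward()             # walk the recorded graph backwards
print(w.grad)            # tensor([3., 4.]) because dy/dw = c
print(c.grad)            # None: c was never tracked
```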
torch.nn.Linear(784, 128) creates the same two things internally: a weight tensor and a [128] bias tensor. (It stores the weight transposed, as [128, 784], and computes x @ W.T + b.) We're skipping the wrapper and working with the raw numbers directly.
One important detail: weight initialization scale matters. Plain randn (standard deviation = 1) makes activations grow from layer to layer, which causes problems in deep networks. A standard fix for ReLU networks is Kaiming (He) initialization:
```python
# Better initialization for ReLU networks (He et al., 2015)
# Scale by sqrt(2/fan_in) to keep variance stable through layers
W1 = torch.randn(784, 128) * (2.0 / 784) ** 0.5
W1.requires_grad_(True)  # In-place version of requires_grad=True

# For our 784→128 layer: scale = sqrt(2/784) ≈ 0.0505
# Values will be roughly in [-0.15, 0.15] instead of [-3, 3]
```
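The effect of the scale factor can be checked numerically. This is a rough sketch, assuming unit-variance stand-in inputs (real pixel data in [0, 1] has a different distribution); the variable names are illustrative.

```python
import torch

torch.manual_seed(0)
x = torch.randn(1024, 784)  # stand-in inputs with unit variance (an assumption)

W_plain = torch.randn(784, 128)                          # std = 1
W_kaiming = torch.randn(784, 128) * (2.0 / 784) ** 0.5   # std = sqrt(2/784)

h_plain = x @ W_plain
h_kaiming = x @ W_kaiming

# Each output sums 784 products, so its variance is 784 * var(x) * var(W)
print(h_plain.std())    # roughly sqrt(784) = 28: activations blow up
print(h_kaiming.std())  # roughly sqrt(2) ≈ 1.41: activations stay at a sane scale
```

The factor of 2 compensates for ReLU zeroing out about half of the pre-activations on the way to the next layer.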
The forward pass is just matrix multiplication plus an activation function. No magic. Let's trace every step with exact shapes.
Given a batch of 64 images, each flattened to 784 pixels:
```python
# x shape: [64, 784]: 64 images, each with 784 pixel values (0-1)
x = batch_of_images  # normalized to [0, 1]

# Step 1: Linear transformation (hidden layer)
h = x @ W1 + b1  # [64, 784] @ [784, 128] = [64, 128]
# Each of the 64 images now has 128 "features"
# b1 broadcasts: [128] → [64, 128]

# Step 2: ReLU activation
h = torch.clamp(h, min=0)  # [64, 128] → [64, 128] (same shape, negative values → 0)

# Step 3: Linear transformation (output layer)
out = h @ W2 + b2  # [64, 128] @ [128, 10] = [64, 10]
# Each image now has 10 "logits", one per digit class
```
That's the entire forward pass. Three lines. Let's break down what each does:
x @ W1 + b1: Matrix multiply. Each of the 784 input pixels is multiplied by a weight and the results are summed into 128 output values. Think of it as: each of the 128 hidden neurons "looks at" all 784 pixels, with different weights deciding which pixels matter to that neuron. The bias b1 shifts the result (like a y-intercept in y = mx + b).
torch.clamp(h, min=0): This IS the ReLU function. If a value is negative, set it to zero. If positive, keep it. That's it. nn.ReLU() does nothing more than this. Why do we need it? Without a nonlinearity, stacking two linear layers is equivalent to one linear layer (matrix A × matrix B = matrix C). ReLU breaks this linearity, allowing the network to learn curved decision boundaries.
h @ W2 + b2: Another matrix multiply. Maps 128 hidden features down to 10 output values (one per digit). These are called logits: raw, unnormalized scores. A large positive logit means the network is confident about that digit.
Let's trace through a concrete example with actual numbers:
```python
# Single image of a "7": mostly zeros except for the stroke
# (printed values below are illustrative; yours depend on the random weights)
x = torch.zeros(1, 784)
x[0, 350:360] = 1.0  # simulate some bright pixels

h = x @ W1 + b1  # [1, 784] @ [784, 128] = [1, 128]
print(h[0, :5])  # tensor([-0.42,  1.87,  0.03, -1.21,  0.95])

h = torch.clamp(h, min=0)
print(h[0, :5])  # tensor([ 0.00,  1.87,  0.03,  0.00,  0.95])
# the two negatives (positions 0 and 3) become zero (ReLU)

out = h @ W2 + b2  # [1, 128] @ [128, 10] = [1, 10]
print(out)
# tensor([[-0.5, 0.2, -0.1, 0.8, -0.3, 0.1, -0.4, 2.1, -0.2, 0.3]])
# the logit at index 7 (2.1) is highest → the prediction is "7"
```
Watch tensor shapes transform through the network. Adjust batch size to see how it affects only the first dimension.
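The claim that two linear layers without a nonlinearity collapse into one can be verified numerically. The sketch below uses random float64 tensors (chosen here only to keep rounding error negligible) and builds the single equivalent layer explicitly.

```python
import torch

torch.manual_seed(0)
x = torch.randn(4, 784, dtype=torch.float64)
W1 = torch.randn(784, 128, dtype=torch.float64) * 0.05
b1 = torch.randn(128, dtype=torch.float64)
W2 = torch.randn(128, 10, dtype=torch.float64) * 0.05
b2 = torch.randn(10, dtype=torch.float64)

# Two linear layers with NO activation in between
two_layers = (x @ W1 + b1) @ W2 + b2

# The same map as one layer: W = W1 @ W2, b = b1 @ W2 + b2
W = W1 @ W2            # [784, 10]
b = b1 @ W2 + b2       # [10]
one_layer = x @ W + b

print(torch.allclose(two_layers, one_layer))  # True: the hidden layer added nothing

# With ReLU in between, no single (W, b) reproduces the output in general
with_relu = torch.clamp(x @ W1 + b1, min=0) @ W2 + b2
print(torch.allclose(with_relu, one_layer))   # False
```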
Now let's package this into a function:
```python
def forward(x, W1, b1, W2, b2):
    """Complete forward pass: 3 lines, no torch.nn."""
    h = torch.clamp(x @ W1 + b1, min=0)  # hidden + ReLU
    out = h @ W2 + b2                    # output logits
    return out

# Usage:
logits = forward(x_batch, W1, b1, W2, b2)  # [64, 10]
prediction = logits.argmax(dim=1)          # [64]: predicted digit for each image
```
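As a sanity check, the hand-written forward pass can be compared against the functional API. This sketch assumes random stand-in inputs; note that `F.linear` expects the weight as [out_features, in_features], so our [in, out] matrices are transposed when passed in.

```python
import torch
import torch.nn.functional as F

def forward(x, W1, b1, W2, b2):
    h = torch.clamp(x @ W1 + b1, min=0)  # hidden + ReLU
    return h @ W2 + b2                   # output logits

torch.manual_seed(0)
x = torch.rand(64, 784)  # stand-in batch of "images" in [0, 1]
W1 = torch.randn(784, 128) * (2.0 / 784) ** 0.5
b1 = torch.zeros(128)
W2 = torch.randn(128, 10) * (2.0 / 128) ** 0.5
b2 = torch.zeros(10)

ours = forward(x, W1, b1, W2, b2)
# F.linear(x, W, b) computes x @ W.T + b, hence the transposes
theirs = F.linear(F.relu(F.linear(x, W1.T, b1)), W2.T, b2)

print(ours.shape)                               # torch.Size([64, 10])
print(torch.allclose(ours, theirs, atol=1e-4))  # True up to float32 rounding
```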
nn.Linear(784, 128) computes x @ W.T + b (it stores W as [128, 784]), which is the same math as our x @ W1 + b1. nn.ReLU() computes the same thing as torch.clamp(x, min=0). The "layers" in PyTorch are thin wrappers around single-line math operations. The math IS the model.
The forward pass gives us 10 logits: raw scores. But we need a single number that says "how wrong is the network?" That number is the loss. The optimizer will minimize it.
For classification, we use cross-entropy loss. It has two parts: softmax (convert logits to probabilities) and negative log-likelihood (penalize wrong predictions).
Step 1: Softmax — convert logits to probabilities that sum to 1:
```python
# Naive softmax (DON'T USE: numerically unstable!)
def softmax_naive(logits):
    exp_logits = torch.exp(logits)  # [64, 10]
    return exp_logits / exp_logits.sum(dim=1, keepdim=True)

# Problem: float32 tops out near 3.4 × 10^38, but exp(100) ≈ 2.7 × 10^43 → overflow!
logits = torch.tensor([[100.0, 101.0, 102.0]])
print(torch.exp(logits))      # tensor([[inf, inf, inf]])  ← OVERFLOW
print(softmax_naive(logits))  # tensor([[nan, nan, nan]])  (inf / inf)
```
```python
# Stable softmax (what PyTorch actually computes)
def softmax(logits):
    # Subtract max for numerical stability
    shifted = logits - logits.max(dim=1, keepdim=True).values
    exp_logits = torch.exp(shifted)  # now all values ≤ 0, so exp ≤ 1
    return exp_logits / exp_logits.sum(dim=1, keepdim=True)

# Verify: same input, no overflow
logits = torch.tensor([[100.0, 101.0, 102.0]])
print(softmax(logits))  # tensor([[0.0900, 0.2447, 0.6652]]), sums to 1.0 ✓
```
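Subtracting the max is safe because softmax is unchanged when the same constant is added to every logit in a row. A short sketch of that check, compared against the built-in `torch.softmax`:

```python
import torch

def softmax(logits):
    shifted = logits - logits.max(dim=1, keepdim=True).values
    e = torch.exp(shifted)
    return e / e.sum(dim=1, keepdim=True)

logits = torch.tensor([[100.0, 101.0, 102.0]])
small = logits - 100.0  # [[0., 1., 2.]]: same differences, smaller numbers

print(softmax(logits))               # tensor([[0.0900, 0.2447, 0.6652]])
print(softmax(small))                # identical: only differences between logits matter
print(torch.softmax(logits, dim=1))  # the built-in agrees
```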
Step 2: Negative log-likelihood — how bad is the prediction?
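After softmax, we have a probability for each class. If the correct answer is digit 7, we want prob[7] to be close to 1.0. The loss is the negative log of the probability assigned to the correct class:
$$\text{loss} = -\log\big(\text{prob}[\text{correct class}]\big)$$
Why negative log? Because it is 0 when the correct class gets probability 1, and it grows without bound as that probability approaches 0, so confident wrong predictions are penalized heavily.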
```python
def cross_entropy_loss(logits, targets):
    """
    logits: [batch_size, 10], raw output of network
    targets: [batch_size], integer labels (0-9)
    returns: scalar loss (averaged over batch)
    """
    # Stable log-softmax
    shifted = logits - logits.max(dim=1, keepdim=True).values
    log_probs = shifted - torch.log(torch.exp(shifted).sum(dim=1, keepdim=True))
    # log_probs shape: [64, 10]

    # Pick the log-probability of the correct class for each sample
    loss = -log_probs[torch.arange(logits.shape[0]), targets]
    # shape: [64], one loss per sample
    return loss.mean()  # average over batch → scalar

# Verify: matches PyTorch's built-in
import torch.nn.functional as F

logits = torch.randn(64, 10)
targets = torch.randint(0, 10, (64,))
ours = cross_entropy_loss(logits, targets)
theirs = F.cross_entropy(logits, targets)
print(torch.allclose(ours, theirs))  # True ← same value up to float rounding
```
The correct class is highlighted in teal. Drag the logits to see how probabilities and loss change. Notice: the loss ONLY depends on the correct class probability.
We have a loss. Now we need gradients — partial derivatives of the loss with respect to every weight. These tell us: "if I nudge this weight up slightly, does the loss go up or down, and by how much?"
PyTorch computes these automatically with one magic line:
```python
loss.backward()

# After this call:
# W1.grad contains ∂loss/∂W1 (shape [784, 128])
# b1.grad contains ∂loss/∂b1 (shape [128])
# W2.grad contains ∂loss/∂W2 (shape [128, 10])
# b2.grad contains ∂loss/∂b2 (shape [10])
```
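But what does .backward() actually compute? Let's derive it by hand. We'll use the chain rule: if loss depends on out, which depends on W2, then:
$$\frac{\partial\,\text{loss}}{\partial W_2} = \frac{\partial\,\text{loss}}{\partial\,\text{out}} \cdot \frac{\partial\,\text{out}}{\partial W_2}$$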
Let's work through each gradient step by step:
```python
# Forward pass (saving intermediates for backward)
h_pre = x @ W1 + b1            # [64, 128]: pre-activation
h = torch.clamp(h_pre, min=0)  # [64, 128]: post-ReLU
out = h @ W2 + b2              # [64, 10]: logits

# Softmax + loss
probs = softmax(out)  # [64, 10]
loss = cross_entropy_loss(out, targets)

# === BACKWARD PASS (what .backward() computes) ===

# Gradient of loss w.r.t. logits (softmax + NLL combined)
# This is the beautiful simplification: d_out = probs - one_hot(targets)
d_out = probs.clone()                  # [64, 10]
d_out[torch.arange(64), targets] -= 1  # subtract 1 from correct class
d_out /= 64                            # average over batch
# d_out[i,j] = (prob[i,j] - 1{j==target[i]}) / batch_size

# Gradient of loss w.r.t. W2: out = h @ W2 + b2
# ∂loss/∂W2 = h.T @ d_out
d_W2 = h.T @ d_out       # [128, 64] @ [64, 10] = [128, 10] ✓
d_b2 = d_out.sum(dim=0)  # [10] ✓

# Gradient flows back through W2
d_h = d_out @ W2.T  # [64, 10] @ [10, 128] = [64, 128]

# ReLU gradient: 1 if input > 0, else 0
d_h_pre = d_h * (h_pre > 0).float()  # [64, 128]: zero out where ReLU killed

# Gradient of loss w.r.t. W1: h_pre = x @ W1 + b1
# ∂loss/∂W1 = x.T @ d_h_pre
d_W1 = x.T @ d_h_pre       # [784, 64] @ [64, 128] = [784, 128] ✓
d_b1 = d_h_pre.sum(dim=0)  # [128] ✓
```
```python
# Compare our manual gradients to PyTorch's autograd
loss.backward()
print(torch.allclose(d_W1, W1.grad, atol=1e-5))  # True ✓
print(torch.allclose(d_b1, b1.grad, atol=1e-5))  # True ✓
print(torch.allclose(d_W2, W2.grad, atol=1e-5))  # True ✓
print(torch.allclose(d_b2, b2.grad, atol=1e-5))  # True ✓
```
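A third, independent way to check a gradient is by finite differences: nudge one weight, re-run the forward pass, and see how the loss moves. The sketch below does this for a deliberately tiny single-layer model (the sizes and float64 dtype are choices made here to keep the estimate accurate and the loop fast).

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
x = torch.randn(5, 6, dtype=torch.float64)  # 5 samples, 6 features
y = torch.randint(0, 3, (5,))               # 3 classes
W = torch.randn(6, 3, dtype=torch.float64, requires_grad=True)

def loss_fn(W):
    return F.cross_entropy(x @ W, y)

loss_fn(W).backward()  # autograd gradient lands in W.grad

eps = 1e-6
numeric = torch.zeros_like(W)
with torch.no_grad():
    for i in range(W.shape[0]):
        for j in range(W.shape[1]):
            W_plus = W.clone()
            W_plus[i, j] += eps
            W_minus = W.clone()
            W_minus[i, j] -= eps
            # Central difference: (L(w + eps) - L(w - eps)) / (2 * eps)
            numeric[i, j] = (loss_fn(W_plus) - loss_fn(W_minus)) / (2 * eps)

print(torch.allclose(numeric, W.grad, atol=1e-6))  # True
```

This is far too slow for real training (two forward passes per weight), which is exactly why backpropagation exists, but it is a handy debugging tool.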
Let's summarize the key gradients in a table:
| Gradient | Formula | Shape | Intuition |
|---|---|---|---|
| ∂L/∂out | (softmax(out) − one_hot(y)) / batch_size | [64, 10] | How much each logit needs to change |
| ∂L/∂W2 | h.T @ d_out | [128, 10] | Correlation of hidden features with output error |
| ∂L/∂h | d_out @ W2.T | [64, 128] | Error "blamed" back to each hidden unit |
| ∂L/∂W1 | x.T @ (d_h * relu'(h_pre)) | [784, 128] | Correlation of inputs with hidden error |
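The first row of the table is the one result that is not a plain matrix rule, so a short derivation may help. The notation here is introduced for this sketch only: $z$ is one sample's logit vector, $p$ its softmax, and $c$ the index of the correct class.

```latex
p_j = \frac{e^{z_j}}{\sum_k e^{z_k}}, \qquad
L = -\log p_c = -z_c + \log \sum_k e^{z_k}

\frac{\partial L}{\partial z_j}
  = -\mathbb{1}[j = c] + \frac{e^{z_j}}{\sum_k e^{z_k}}
  = p_j - \mathbb{1}[j = c]
```

Averaging the loss over a batch of $B$ samples divides every entry by $B$, which is the `d_out /= 64` step in the code above.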
Forward pass flows left→right (blue). Backward pass flows right→left (orange). Each node shows its gradient value. Click "Step" to animate one backward step at a time.
Gradients tell us the direction of steepest increase for the loss. To minimize the loss, we go the opposite direction: subtract the gradient, scaled by a learning rate.
SGD (Stochastic Gradient Descent) — the simplest optimizer, in 3 lines:
```python
lr = 0.01  # learning rate: how big a step to take

with torch.no_grad():  # don't track these ops in the graph
    W1 -= lr * W1.grad  # take one step down the gradient
    b1 -= lr * b1.grad
    W2 -= lr * W2.grad
    b2 -= lr * b2.grad

    # CRITICAL: zero the gradients after updating!
    W1.grad.zero_()
    b1.grad.zero_()
    W2.grad.zero_()
    b2.grad.zero_()
```
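Both safeguards in the update step, `torch.no_grad()` and `.grad.zero_()`, can be seen in action on a toy problem. The following sketch uses a two-element tensor and a made-up quadratic loss, purely for illustration.

```python
import torch

w = torch.tensor([1.0, 2.0], requires_grad=True)

def loss_fn(w):
    return (w ** 2).sum()  # gradient is 2*w = [2, 4]

loss_fn(w).backward()
first = w.grad.clone()        # tensor([2., 4.])

loss_fn(w).backward()         # no zeroing in between
accumulated = w.grad.clone()  # tensor([4., 8.]): the second gradient was ADDED

w.grad.zero_()
loss_fn(w).backward()
after_zero = w.grad.clone()   # tensor([2., 4.]) again

# Updating a leaf tensor in place outside no_grad() is rejected by autograd
raised = False
try:
    w -= 0.1 * w.grad
except RuntimeError:
    raised = True
print(first, accumulated, after_zero, raised)
```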
Why torch.no_grad()? W1 is a leaf tensor with requires_grad=True, so PyTorch refuses to modify it in place while autograd is recording (it raises a RuntimeError). We're updating weights, not computing a loss, so we don't need gradients of the update itself. no_grad() says "just do the math, don't record it."
Why .grad.zero_()? PyTorch ACCUMULATES gradients. If you call .backward() twice without zeroing, the gradients ADD UP. This is occasionally useful (gradient accumulation for large batches), but usually a bug. Always zero after each update.
SGD works, but it can be slow and can oscillate. Momentum helps by adding "velocity": an exponentially decaying sum of past gradients:
```python
# SGD with Momentum
lr = 0.01
mu = 0.9  # momentum coefficient (how much history to keep)

# Initialize velocity tensors (zeros, same shape as weights)
v_W1 = torch.zeros_like(W1)
v_b1 = torch.zeros_like(b1)
v_W2 = torch.zeros_like(W2)
v_b2 = torch.zeros_like(b2)

# Update step
with torch.no_grad():
    v_W1 = mu * v_W1 + W1.grad  # blend old velocity with new gradient
    W1 -= lr * v_W1             # update using velocity (not raw gradient)
    # ... same for b1, W2, b2
```
Think of momentum like a ball rolling downhill. On a flat surface, it keeps rolling (velocity carries it). In a narrow valley, it damps oscillations because the velocity averages out the back-and-forth gradients.
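The rolling-ball picture can be tested on a toy "narrow valley". The sketch below uses a made-up quadratic bowl that is 25 times steeper in one direction than the other; the loss function, learning rate, and step count are all illustrative choices, not values from the lesson.

```python
import torch

def loss_fn(w):
    # Elongated bowl: curvature 1 along w[0], curvature 25 along w[1]
    return 0.5 * (w[0] ** 2 + 25.0 * w[1] ** 2)

def run(mu, lr=0.03, steps=200):
    """Minimize loss_fn from (1, 1); mu=0 gives plain gradient descent."""
    w = torch.tensor([1.0, 1.0], dtype=torch.float64, requires_grad=True)
    v = torch.zeros_like(w)  # velocity
    for _ in range(steps):
        loss_fn(w).backward()
        with torch.no_grad():
            v = mu * v + w.grad  # same velocity rule as above
            w -= lr * v
            w.grad.zero_()
    return loss_fn(w).item()

sgd_loss = run(mu=0.0)
momentum_loss = run(mu=0.9)
print(f"plain: {sgd_loss:.2e} | momentum: {momentum_loss:.2e}")
```

Plain gradient descent is held back by the shallow direction, where each step shrinks the error by only 3%. With momentum, the velocity builds up along that direction and the final loss comes out much lower for the same number of steps.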
Now let's implement Adam, one of the most widely used optimizers in modern deep learning. It combines momentum with per-parameter adaptive learning rates:
```python
# Adam from scratch (~15 lines)
lr = 0.001
beta1, beta2 = 0.9, 0.999  # decay rates for the two moment estimates
eps = 1e-8                 # prevents division by zero
t = 0                      # timestep counter (one tick per training step)

# First moment (mean of gradients) and second moment (mean of squared gradients)
m_W1 = torch.zeros_like(W1)
v_W1 = torch.zeros_like(W1)
# ... same for b1, W2, b2

def adam_step(param, grad, m, v, t):
    """One Adam update for a single parameter at timestep t (t >= 1)."""
    m = beta1 * m + (1 - beta1) * grad       # update mean estimate
    v = beta2 * v + (1 - beta2) * grad ** 2  # update (uncentered) variance estimate
    # Bias correction (early steps would be biased toward zero)
    m_hat = m / (1 - beta1 ** t)
    v_hat = v / (1 - beta2 ** t)
    # Update: parameters with consistently large gradients get SMALLER steps (adaptive!)
    param -= lr * m_hat / (torch.sqrt(v_hat) + eps)
    return m, v

# Usage in training loop:
t += 1  # increment ONCE per training step, not once per parameter
with torch.no_grad():
    m_W1, v_W1 = adam_step(W1, W1.grad, m_W1, v_W1, t)
    # ... same for b1, W2, b2 (all with the same t)
```
Watch SGD (red), Momentum (yellow), and Adam (teal) minimize a 2D loss landscape. Adam adapts to the terrain shape; SGD oscillates.
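One way to gain confidence in a hand-written Adam is to run it side by side with `torch.optim.Adam` on the same problem. The sketch below does that for a single small parameter and a made-up loss; float64 is used here only so the two trajectories can be compared tightly.

```python
import torch

torch.manual_seed(0)
w_ours = torch.randn(5, 3, dtype=torch.float64, requires_grad=True)
w_ref = w_ours.detach().clone().requires_grad_(True)  # identical starting point
opt = torch.optim.Adam([w_ref], lr=1e-3, betas=(0.9, 0.999), eps=1e-8)

lr, beta1, beta2, eps = 1e-3, 0.9, 0.999, 1e-8
m = torch.zeros_like(w_ours)  # first moment
v = torch.zeros_like(w_ours)  # second moment

def loss_fn(w):
    return (w ** 2).sum() + w.sin().sum()  # arbitrary smooth test loss

for t in range(1, 51):  # t starts at 1 and ticks once per step
    # Hand-written Adam
    loss_fn(w_ours).backward()
    with torch.no_grad():
        g = w_ours.grad
        m = beta1 * m + (1 - beta1) * g
        v = beta2 * v + (1 - beta2) * g ** 2
        m_hat = m / (1 - beta1 ** t)
        v_hat = v / (1 - beta2 ** t)
        w_ours -= lr * m_hat / (v_hat.sqrt() + eps)
        w_ours.grad.zero_()

    # Reference optimizer
    opt.zero_grad()
    loss_fn(w_ref).backward()
    opt.step()

print(torch.allclose(w_ours, w_ref, atol=1e-8))  # True
```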
Training on one image at a time is noisy (high variance gradients). Training on all 60,000 images at once is wasteful (slow iterations, uses tons of memory). The solution: mini-batches. Process 64 images at a time, update weights, repeat.
Here's a complete DataLoader in 5 lines:
```python
# Load all data into memory (MNIST is small enough)
# load_mnist_images / load_mnist_labels are placeholders for your own loader
X_train = load_mnist_images()  # [60000, 784]: all training images
y_train = load_mnist_labels()  # [60000]: all labels (0-9)

batch_size = 64
N = X_train.shape[0]  # 60000

# Shuffle and iterate in batches: that's the core of what DataLoader does
perm = torch.randperm(N)  # random permutation of [0, 1, ..., 59999]
for i in range(0, N, batch_size):
    idx = perm[i : i + batch_size]  # indices for this batch
    x_batch = X_train[idx]          # [64, 784]
    y_batch = y_train[idx]          # [64]
    # ... forward, loss, backward, update ...
```
torch.utils.data.DataLoader adds multiprocessing (load next batch while GPU processes current one), pinned memory (faster CPU→GPU transfer), and collation (combine varied-length samples). But the core logic is just: shuffle, slice, yield.
Let's understand WHY we shuffle:
```python
# WITHOUT shuffling (suppose the data were sorted by label):
# Batch 1: all zeros
# Batch 2: all zeros
# ...roughly 90 batches of zeros (about 6,000 zeros / 64 per batch)...
# Then: batch after batch of all ones
# The network learns "everything is zero" then "everything is one"
# → oscillates wildly, converges poorly

# WITH shuffling:
# Batch 1: mix of [3, 7, 0, 9, 1, 4, ...], a representative sample
# Gradient from each batch points roughly in the direction of the full-dataset gradient
# → smooth convergence
```
Let's also handle the last batch (which might be smaller):
```python
# Complete epoch iteration with proper handling
def get_batches(X, y, batch_size):
    """Yields (x_batch, y_batch) tuples for one epoch."""
    N = X.shape[0]
    perm = torch.randperm(N)
    for i in range(0, N, batch_size):
        idx = perm[i : i + batch_size]
        yield X[idx], y[idx]
        # Last batch may have < batch_size samples, and that's fine
        # e.g., 60000 / 64 = 937.5 → 937 full batches + 1 batch of 32

# Number of batches per epoch:
n_batches = (N + batch_size - 1) // batch_size  # = 938
```
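The batching logic is easy to verify on toy data: every sample should appear exactly once per epoch, and only the final batch should be short. This sketch uses 1,000 made-up samples whose label equals their index, so coverage can be checked directly.

```python
import torch

def get_batches(X, y, batch_size):
    """Yields (x_batch, y_batch) tuples for one shuffled epoch."""
    N = X.shape[0]
    perm = torch.randperm(N)
    for i in range(0, N, batch_size):
        idx = perm[i : i + batch_size]
        yield X[idx], y[idx]

X = torch.arange(1000).float().unsqueeze(1)  # toy "images": one feature each
y = torch.arange(1000)                       # label i identifies sample i

sizes, seen = [], []
for xb, yb in get_batches(X, y, batch_size=64):
    sizes.append(xb.shape[0])
    seen.append(yb)
seen = torch.cat(seen)

print(len(sizes))  # 16 batches: ceil(1000 / 64)
print(sizes[-1])   # 40 samples in the last batch: 1000 - 15 * 64
print(torch.equal(seen.sort().values, y))  # True: each sample seen exactly once
```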
60 "images" colored by class. Click "Shuffle" to see random permutation, then "Batch" to slice into groups of 8. Without shuffling, batches would be all-one-color (bad). With shuffling, each batch is a representative mix (good).
Time to combine everything. Here's a complete neural network training pipeline — from initialization to convergence — in 50 lines of pure PyTorch. No nn. No optim. Just tensors, math, and a for loop.
```python
import torch

# ─── Data ───────────────────────────────────────────────
# (In practice you'd load MNIST here; we simulate it with random data,
#  so accuracy will hover near 10% until you swap in the real dataset)
X_train = torch.randn(60000, 784)         # [60000, 784]
y_train = torch.randint(0, 10, (60000,))  # [60000]

# ─── Initialize parameters ──────────────────────────────
W1 = torch.randn(784, 128) * (2.0/784)**0.5; W1.requires_grad_(True)
b1 = torch.zeros(128, requires_grad=True)
W2 = torch.randn(128, 10) * (2.0/128)**0.5; W2.requires_grad_(True)
b2 = torch.zeros(10, requires_grad=True)

# ─── Hyperparameters ────────────────────────────────────
lr = 0.001
batch_size = 64
epochs = 10

# ─── Adam state ─────────────────────────────────────────
params = [W1, b1, W2, b2]
ms = [torch.zeros_like(p) for p in params]  # first moments
vs = [torch.zeros_like(p) for p in params]  # second moments
t = 0

# ─── Training loop ──────────────────────────────────────
for epoch in range(epochs):
    perm = torch.randperm(60000)
    epoch_loss = 0.0
    correct = 0

    for i in range(0, 60000, batch_size):
        # Get batch
        idx = perm[i:i+batch_size]
        x, y = X_train[idx], y_train[idx]

        # Forward
        h = torch.clamp(x @ W1 + b1, min=0)
        out = h @ W2 + b2

        # Loss (stable cross-entropy)
        shifted = out - out.max(dim=1, keepdim=True).values
        log_probs = shifted - torch.log(torch.exp(shifted).sum(1, keepdim=True))
        loss = -log_probs[torch.arange(x.shape[0]), y].mean()

        # Backward
        loss.backward()

        # Adam update
        t += 1
        with torch.no_grad():
            for j, p in enumerate(params):
                ms[j] = 0.9 * ms[j] + 0.1 * p.grad
                vs[j] = 0.999 * vs[j] + 0.001 * p.grad**2
                m_hat = ms[j] / (1 - 0.9**t)
                v_hat = vs[j] / (1 - 0.999**t)
                p -= lr * m_hat / (v_hat.sqrt() + 1e-8)
                p.grad.zero_()

        # Track metrics
        epoch_loss += loss.item()
        correct += (out.argmax(1) == y).sum().item()

    # Print epoch summary
    n_batches = (60000 + batch_size - 1) // batch_size
    print(f"Epoch {epoch+1}/{epochs} | Loss: {epoch_loss/n_batches:.4f} | Acc: {correct/60000*100:.1f}%")
```
Let's count the lines of actual logic (excluding comments and data loading):
| Component | Lines | What it does |
|---|---|---|
| Parameters | 4 | Create W1, b1, W2, b2 |
| Forward | 2 | Matrix multiply + ReLU |
| Loss | 3 | Cross-entropy |
| Backward | 1 | loss.backward() |
| Optimizer | 6 | Adam update + zero grad |
| Data loop | 4 | Shuffle + batch iteration |
| Total | 20 | Complete training pipeline |
Twenty lines to train a neural network from scratch. Everything else is just plumbing.
Watch a network train in real time (in your browser!). The loss curve drops as the network learns a 2D classification task. This uses the exact same code we wrote — no frameworks.
Here's the moment of truth. We'll train the SAME network two ways: first with our from-scratch code, then with the standard torch.nn approach. Same initial weights, same data, same hyperparameters. They should produce matching results.
Our From-Scratch Code:
```python
# Raw tensors
W1 = torch.randn(784,128) * 0.05
W1.requires_grad_(True)
b1 = torch.zeros(128, requires_grad=True)
W2 = torch.randn(128,10) * 0.05
W2.requires_grad_(True)
b2 = torch.zeros(10, requires_grad=True)

# Forward
h = torch.clamp(x@W1+b1, min=0)
out = h @ W2 + b2

# Loss + backward (cross_entropy_loss is the function we wrote earlier)
loss = cross_entropy_loss(out, y)
loss.backward()

# SGD update
with torch.no_grad():
    W1 -= lr * W1.grad
    W1.grad.zero_()
    # ... same for b1, W2, b2
```
Standard torch.nn Code:
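```python
import torch.nn as nn
import torch.nn.functional as F

# nn.Module
model = nn.Sequential(
    nn.Linear(784, 128),
    nn.ReLU(),
    nn.Linear(128, 10)
)
# Created once, before training; SGD here to mirror the from-scratch update
optimizer = torch.optim.SGD(model.parameters(), lr=lr)

# Forward
out = model(x)

# Loss + backward
loss = F.cross_entropy(out, y)
loss.backward()

# Optimizer step
optimizer.step()
optimizer.zero_grad()
```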
nn.Linear(784, 128) creates a weight matrix (stored as [128, 784], the transpose of ours) and a [128] bias: the same parameters we created manually. nn.ReLU() computes the same result as torch.clamp(x, min=0). F.cross_entropy computes the same stable log-softmax + NLL. optimizer.step() performs the same update we wrote by hand (plain SGD with torch.optim.SGD, our Adam loop with torch.optim.Adam). There is no additional math: torch.nn is convenience.
Both networks train on the SAME 2D classification task with the SAME initial weights. Orange = from scratch. Teal = torch.nn. The curves should overlap, because they ARE the same computation.
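To make the comparison rigorous, both versions must start from the same weights. A shared random seed alone is not enough: nn.Linear uses its own default initialization (Kaiming-uniform, not randn) and stores its weight as [out, in]. So we copy the module's initial weights into our raw tensors:
```python
torch.manual_seed(42)

# torch.nn version
model = nn.Sequential(nn.Linear(784, 128), nn.ReLU(), nn.Linear(128, 10))

# From-scratch version: copy the module's initial weights (transposed to [in, out])
W1_scratch = model[0].weight.detach().T.clone().requires_grad_(True)
b1_scratch = model[0].bias.detach().clone().requires_grad_(True)
# ... same for W2_scratch, b2_scratch from model[2]

# After training both for 10 epochs on the same batches in the same order:
print(torch.allclose(W1_scratch, model[0].weight.T, atol=1e-5))  # expected: True
print(torch.allclose(loss_scratch, loss_nn))  # expected: True (up to float rounding)
```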
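A fully runnable version of this comparison might look like the sketch below. It assumes random stand-in data, full-batch SGD with an arbitrary learning rate of 0.1, and float64 (to keep rounding differences negligible); the raw tensors start from weights copied out of the module so both versions begin at the same point.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(42)
X = torch.rand(256, 784, dtype=torch.float64)  # stand-in data, not MNIST
y = torch.randint(0, 10, (256,))

model = nn.Sequential(nn.Linear(784, 128), nn.ReLU(), nn.Linear(128, 10)).double()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# Copy the module's initial weights (transposed to [in, out]) into raw tensors
W1 = model[0].weight.detach().T.clone().requires_grad_(True)
b1 = model[0].bias.detach().clone().requires_grad_(True)
W2 = model[2].weight.detach().T.clone().requires_grad_(True)
b2 = model[2].bias.detach().clone().requires_grad_(True)
params = [W1, b1, W2, b2]

for step in range(20):
    # From scratch: forward, stable cross-entropy, backward, SGD update
    out = torch.clamp(X @ W1 + b1, min=0) @ W2 + b2
    shifted = out - out.max(dim=1, keepdim=True).values
    log_probs = shifted - torch.log(torch.exp(shifted).sum(dim=1, keepdim=True))
    loss_scratch = -log_probs[torch.arange(X.shape[0]), y].mean()
    loss_scratch.backward()
    with torch.no_grad():
        for p in params:
            p -= 0.1 * p.grad
            p.grad.zero_()

    # torch.nn: the same step through the library
    loss_nn = F.cross_entropy(model(X), y)
    optimizer.zero_grad()
    loss_nn.backward()
    optimizer.step()

print(torch.allclose(W1, model[0].weight.T, atol=1e-8))  # True
print(torch.allclose(loss_scratch, loss_nn))             # True
```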
So why use torch.nn at all? A few reasons:
model.parameters() automatically finds all weight tensors (no manual lists)
torch.save(model.state_dict(), path) saves all weights with one call, and load_state_dict restores them
But the math? Identical. The computation? Identical. The results? Identical. You now know what every line of torch.nn actually does.
Everything we've built so far runs on CPU. Moving to GPU requires exactly ONE change per tensor: .to('cuda'). The math is identical — only the hardware changes.
```python
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Move parameters to GPU
W1 = torch.randn(784, 128, device=device, requires_grad=True)
b1 = torch.zeros(128, device=device, requires_grad=True)
W2 = torch.randn(128, 10, device=device, requires_grad=True)
b2 = torch.zeros(10, device=device, requires_grad=True)

# Move data to GPU
X_train = X_train.to(device)
y_train = y_train.to(device)

# Training loop: ZERO code changes needed!
# x @ W1 + b1 now runs on GPU cores instead of CPU cores
# loss.backward() now computes gradients on GPU
# W1 -= lr * W1.grad now updates weights on GPU
```
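A small device-agnostic check, written so it also runs on a machine without a GPU, confirms that gradients land on the same device as the parameter. The tensor names here are illustrative.

```python
import torch

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Created directly on the device: W is a leaf tensor, so W.grad gets populated
W = torch.randn(784, 128, device=device, requires_grad=True)
x = torch.rand(64, 784).to(device)  # plain data can simply be moved

(x @ W).sum().backward()
print(W.is_leaf)      # True
print(W.grad.shape)   # torch.Size([784, 128])
print(W.grad.device)  # same device as W
```

One pitfall worth knowing: `torch.randn(784, 128, requires_grad=True).to('cuda')` returns a non-leaf copy, and its `.grad` stays `None` after `backward()`. Creating the tensor with `device=device` (as above), or moving it first and calling `requires_grad_(True)` afterwards, avoids this.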
The only change: pass device=device when creating tensors (or call .to(device) on existing ones). Everything else (forward, backward, update) works identically. The GPU just does the matrix multiplications in parallel across thousands of cores instead of sequentially on a few CPU cores.
Why is GPU faster? Matrix multiplication is embarrassingly parallel. Consider x @ W1 where x is [64, 784] and W1 is [784, 128]:
| Operation | Multiply-adds | CPU (8 cores) | GPU (5000 cores) |
|---|---|---|---|
| x @ W1 | 64 × 784 × 128 = 6.4M | ~2 ms | ~0.02 ms |
| Full forward | ~6.5M | ~3 ms | ~0.03 ms |
| Full backward | ~13M | ~6 ms | ~0.05 ms |
| One epoch (938 batches) | ~18B | ~8.4 s | ~0.08 s |
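In this illustrative table, the GPU is ~100× faster per epoch for our small network. Actual numbers depend heavily on hardware, and for a network this small, kernel-launch overhead can eat much of the gain. For larger networks the speedup is typically greater because the GPU stays saturated.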
```python
# Benchmarking CPU vs GPU
import time

# CPU timing
x_cpu = torch.randn(64, 784)
W_cpu = torch.randn(784, 128)
start = time.time()
for _ in range(1000):
    _ = x_cpu @ W_cpu
cpu_time = time.time() - start  # ~0.15 seconds

# GPU timing
x_gpu = torch.randn(64, 784, device='cuda')
W_gpu = torch.randn(784, 128, device='cuda')
torch.cuda.synchronize()
start = time.time()
for _ in range(1000):
    _ = x_gpu @ W_gpu
torch.cuda.synchronize()  # MUST sync before timing (GPU is async!)
gpu_time = time.time() - start  # ~0.003 seconds

print(f"CPU: {cpu_time:.3f}s | GPU: {gpu_time:.3f}s | Speedup: {cpu_time/gpu_time:.0f}x")
# Example output (hardware dependent):
# CPU: 0.150s | GPU: 0.003s | Speedup: 50x
```
Why torch.cuda.synchronize()? GPU operations are asynchronous. When you write x @ W, Python returns immediately and the GPU computes in the background. If you time without synchronizing, you're only measuring the time to LAUNCH the operation, not to complete it. Always sync before stopping the clock. (Copying a result back to the CPU, for example with .item(), synchronizes implicitly.)
Simulated comparison: watch how many batches each device processes in the same time. The GPU processes batches in parallel (multiple columns at once).
You've built everything from scratch. Now let's understand what torch.nn provides on top — and critically, what it does NOT provide. The answer: organization and convenience. Zero additional computation.
1. Automatic parameter tracking:
```python
# From scratch: manually list all parameters
params = [W1, b1, W2, b2]  # add W3, b3 if you add a layer → easy to forget one!

# torch.nn: parameters found automatically
class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(784, 128)  # registers .weight and .bias
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        return self.fc2(torch.relu(self.fc1(x)))

model = Net()
print(list(model.parameters()))  # automatically finds all 4 tensors!
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)  # pass them to Adam
```
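To see what the automatic tracking actually finds, the registered parameters can be listed by name and counted. This sketch uses an `nn.Sequential` version of the same 784→128→10 network.

```python
import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(784, 128), nn.ReLU(), nn.Linear(128, 10))

# named_parameters() walks the module tree and yields every registered tensor
for name, p in model.named_parameters():
    print(name, tuple(p.shape))
# 0.weight (128, 784)   <- stored as [out, in], the transpose of our W1
# 0.bias (128,)
# 2.weight (10, 128)
# 2.bias (10,)

total = sum(p.numel() for p in model.parameters())
print(total)  # 101770, the same count we computed by hand
```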
What nn.Module gives you: When you assign an nn.Linear as an attribute of a Module, it gets registered. Calling model.parameters() recursively finds ALL registered parameters in the entire model tree. For our 2-layer network this saves 1 line. For a 100-layer ResNet with skip connections, it saves you from manually tracking hundreds of tensors.
2. Sequential containers:
```python
# From scratch: explicit function
def forward(x):
    h = torch.clamp(x @ W1 + b1, min=0)
    return h @ W2 + b2

# torch.nn: stack layers like LEGO
model = nn.Sequential(
    nn.Linear(784, 128),
    nn.ReLU(),
    nn.Linear(128, 64),
    nn.ReLU(),
    nn.Linear(64, 10),
)
out = model(x)  # applies each layer in order
```
3. Serialization (save/load):
```python
# From scratch: save each tensor manually
torch.save({'W1': W1, 'b1': b1, 'W2': W2, 'b2': b2}, 'model.pt')

# torch.nn: one-liner
torch.save(model.state_dict(), 'model.pt')
model.load_state_dict(torch.load('model.pt'))
```
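A save/load round trip can be checked without touching the disk by writing to an in-memory buffer. In this sketch the `io.BytesIO` buffer stands in for the `'model.pt'` file, and a freshly initialized second model receives the saved weights.

```python
import io
import torch
import torch.nn as nn

def make_model():
    return nn.Sequential(nn.Linear(784, 128), nn.ReLU(), nn.Linear(128, 10))

model = make_model()

buffer = io.BytesIO()  # in-memory stand-in for a file on disk
torch.save(model.state_dict(), buffer)
buffer.seek(0)

restored = make_model()  # different random weights at this point
restored.load_state_dict(torch.load(buffer))

x = torch.rand(4, 784)
same = torch.equal(model(x), restored(x))
print(list(model.state_dict().keys()))  # ['0.weight', '0.bias', '2.weight', '2.bias']
print(same)                             # True: identical weights, identical outputs
```

Note that `state_dict()` holds only tensors keyed by name, so the receiving model must be built with the same architecture before loading.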
4. Pre-built layers:
| torch.nn Layer | From-Scratch Equivalent | Lines Saved |
|---|---|---|
| nn.Linear(in, out) | W = randn(in, out); b = zeros(out) | 1 |
| nn.ReLU() | torch.clamp(x, min=0) | 0 |
| nn.Conv2d(3, 64, 3) | ~20 lines of manual convolution loops | 19 |
| nn.BatchNorm1d(128) | ~10 lines (running mean/var tracking) | 9 |
| nn.Dropout(0.5) | mask = (torch.rand_like(x) > 0.5); x *= mask / 0.5 | 1 |
| nn.MultiheadAttention | ~50 lines of QKV projection + softmax | 49 |
5. Hooks for debugging:
```python
# Peek at activations during forward pass
def print_activation(module, input, output):
    print(f"{module.__class__.__name__}: {output.shape}, mean={output.mean():.3f}")

# Register on every layer (iterating works because model is an nn.Sequential)
for layer in model:
    layer.register_forward_hook(print_activation)

model(x)
# Example output:
# Linear: torch.Size([64, 128]), mean=0.012
# ReLU: torch.Size([64, 128]), mean=0.401
# Linear: torch.Size([64, 10]), mean=-0.003
```
Click on each torch.nn component to see what it ACTUALLY does under the hood — the raw tensor operations we already know.
You've built PyTorch from scratch. Let's consolidate with a complete equivalence table, a design challenge, and connections to what's next.
Complete Equivalence Table:
| From Scratch | torch.nn Equivalent | They're the same? |
|---|---|---|
| W = randn(in, out, requires_grad=True) | nn.Linear(in, out).weight | Same role (stored as [out, in], different default init) |
| torch.clamp(x, min=0) | nn.ReLU()(x) or F.relu(x) | Identical |
| softmax + NLL (our function) | F.cross_entropy(logits, targets) | Identical |
| loss.backward() | loss.backward() | Same call |
| W -= lr * W.grad | optim.SGD.step() | Identical |
| Our Adam (~15 lines) | optim.Adam.step() | Identical |
| W.grad.zero_() | optimizer.zero_grad() | Identical |
| randperm + slice | DataLoader(shuffle=True) | Same logic + multiprocessing |
| tensor.to('cuda') | model.to('cuda') | Same (but recursive for nn.Module) |
```python
def dropout(x, p=0.5, training=True):
    """
    x: input tensor [batch, features]
    p: probability of DROPPING a neuron (0.5 = drop half)
    training: if False, do nothing (use full network at test time)
    """
    if not training:
        return x  # at test time: use all neurons

    # Create random mask: 1 with probability (1-p), 0 with probability p
    mask = (torch.rand_like(x) > p).float()

    # Scale up survivors to maintain expected value
    # If we drop 50%, surviving neurons are doubled
    # This way, the sum at test time (all neurons) ≈ sum at train time (half neurons × 2)
    return x * mask / (1 - p)

# That's it. nn.Dropout(0.5) does exactly this.
# The "/ (1-p)" is called "inverted dropout", the standard approach.
```
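The "maintain expected value" claim is easy to check empirically. This sketch applies the same dropout function to a large vector of ones (an arbitrary test input) and looks at the statistics of the result.

```python
import torch

def dropout(x, p=0.5, training=True):
    if not training:
        return x
    mask = (torch.rand_like(x) > p).float()
    return x * mask / (1 - p)

torch.manual_seed(0)
x = torch.ones(100_000)
out = dropout(x, p=0.5)

print(out.mean())                 # close to 1.0: the expected value is preserved
print((out == 0).float().mean())  # close to 0.5: about half the entries were dropped
print(out.max())                  # tensor(2.): survivors are scaled by 1 / (1 - p)
print(torch.equal(dropout(x, training=False), x))  # True: a no-op at test time
```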
What we didn't cover (and where to go next):
| Topic | What it adds | Lesson |
|---|---|---|
| Convolutions | Spatial weight sharing for images | CNN architectures |
| Attention | Dynamic routing between tokens | Transformer |
| Batch Normalization | Stabilizes deep network training | Training tricks |
| Residual Connections | Gradient highways for 100+ layers | ResNet architecture |
| Mixed Precision | Float16 on GPU for 2× speed | GPU optimization |
| Distributed Training | Split work across multiple GPUs | Scaling |
"What I cannot create, I do not understand." — Richard Feynman
You created it. You understand it. Nothing in PyTorch is magic anymore.