Why train from scratch when someone already learned to see? Reuse a pretrained network and adapt it to your task.
You have 200 photos of rare birds and you want to build a classifier. Training a deep CNN from scratch on 200 images is hopeless — the network has millions of parameters and will memorize your tiny dataset instantly. Validation accuracy: garbage.
Meanwhile, someone already trained a network on ImageNet — 1.2 million images across 1,000 categories. That network spent days on eight GPUs learning to detect edges, textures, shapes, and objects. All that knowledge is sitting in its weights, freely downloadable.
Left: training from random weights on a small dataset overfits and reaches low accuracy. Right: starting from pretrained weights converges faster and generalizes better.
This works because the early layers of a CNN learn universal visual features — edges, corners, textures — that are useful for almost any vision task. Only the later layers specialize for the original task. By reusing the universal layers and replacing the specialized ones, you get a massive head start.
Before we can reuse a CNN's knowledge, we need to understand what it learned. A typical CNN trained on ImageNet develops a layered hierarchy of features, from simple to complex.
Early layers (conv1, conv2) detect low-level patterns: edges at various orientations, color gradients, small textures. These are generic — every visual task needs edge detection. A network trained on animals and a network trained on buildings learn nearly identical early filters.
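To make "edge detector" concrete, here is a hand-written 3×3 vertical-edge kernel of the kind early conv filters tend to resemble, applied with a plain valid cross-correlation (the operation deep-learning "convolution" layers compute). This is an illustrative NumPy sketch: the kernel values and the toy image are made up for the example and are not weights from any trained network.

```python
import numpy as np

def cross_correlate2d(image: np.ndarray, kernel: np.ndarray) -> np.ndarray:
    """Valid 2D cross-correlation (no padding, stride 1), as a conv layer applies it."""
    kh, kw = kernel.shape
    h, w = image.shape
    out = np.zeros((h - kh + 1, w - kw + 1))
    for i in range(out.shape[0]):
        for j in range(out.shape[1]):
            out[i, j] = np.sum(image[i:i + kh, j:j + kw] * kernel)
    return out

# Hypothetical vertical-edge kernel: responds to dark-to-bright transitions, left to right
vertical_edge = np.array([[-1.0, 0.0, 1.0],
                          [-1.0, 0.0, 1.0],
                          [-1.0, 0.0, 1.0]])
horizontal_edge = vertical_edge.T  # the same filter rotated 90 degrees

# Toy 6x6 image: dark left half (0), bright right half (1)
image = np.zeros((6, 6))
image[:, 3:] = 1.0

response = cross_correlate2d(image, vertical_edge)
print(response)                                    # every row is [0, 3, 3, 0]: peaks where the window straddles the edge
print(cross_correlate2d(image, horizontal_edge))   # all zeros: this image has no horizontal edge
```

The orientation selectivity is the point: the same weights rotated by 90 degrees see nothing in this image, which is why early layers learn a bank of filters at many orientations.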
Middle layers (conv3, conv4) combine edges into parts: corners, circles, grid patterns, repeated textures. Still fairly generic, but starting to specialize. A "wheel" detector might emerge here for a vehicle dataset.
Late layers (conv5, fc) combine parts into objects: faces, cars, text, specific object categories. These are highly task-specific. A network trained on faces has "eye" and "nose" detectors here that are useless for classifying flowers.
Early layers are universal; late layers are task-specific.
This hierarchy was famously visualized by Zeiler & Fergus (2014), who showed that layer 1 filters look like Gabor filters (oriented edges), layer 3 captures textures, and layers 4 and 5 respond to object parts and whole objects. The universality of early features is why transfer learning works so well.
The simplest form of transfer learning: take a pretrained CNN, chop off the last classification layer, and use everything else as a fixed feature extractor. Run your images through the frozen network, collect the output vectors, and train a simple classifier (like a linear SVM or softmax) on those features.
Why does this work? The pretrained network transforms a 224×224×3 image (150,528 raw pixel values) into a compact 2048-dimensional feature vector (for ResNet-50). That vector captures semantic information such as "this looks furry," "this has stripes," or "this is round" instead of raw pixel intensities. A linear classifier on these features typically far outperforms one trained on raw pixels.
```python
import torch
from torchvision import models

# Load pretrained ResNet-50
model = models.resnet50(weights='IMAGENET1K_V2')
model.eval()  # use stored batch-norm statistics instead of per-batch ones

# Remove final classification layer
feature_extractor = torch.nn.Sequential(
    *list(model.children())[:-1]  # everything except fc
)

# Extract features (no gradients needed)
# my_image_batch: normalized images, shape [batch_size, 3, 224, 224]
with torch.no_grad():
    features = feature_extractor(my_image_batch)
# features.shape: [batch_size, 2048, 1, 1]
features = torch.flatten(features, 1)  # [batch_size, 2048]
```
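The extraction code stops at the feature vectors. The second half of the recipe, training a simple classifier on them, might look like the sketch below: a softmax (multinomial logistic regression) head built from `torch.nn.Linear` and cross-entropy, trained on precomputed features. The function name, epoch count, and learning rate are illustrative assumptions, and random tensors stand in for real features and labels.

```python
import torch

def train_linear_head(features: torch.Tensor, labels: torch.Tensor,
                      num_classes: int, epochs: int = 100, lr: float = 0.1) -> torch.nn.Linear:
    """Fit a softmax classifier on frozen feature vectors of shape [N, D] (full-batch SGD)."""
    head = torch.nn.Linear(features.shape[1], num_classes)
    optimizer = torch.optim.SGD(head.parameters(), lr=lr, momentum=0.9)
    loss_fn = torch.nn.CrossEntropyLoss()  # softmax + negative log-likelihood
    for _ in range(epochs):
        optimizer.zero_grad()
        loss = loss_fn(head(features), labels)
        loss.backward()
        optimizer.step()
    return head

# Stand-in data: 200 "images" already reduced to 2048-d feature vectors, 5 classes
torch.manual_seed(0)
feats = torch.randn(200, 2048)
labels = torch.randint(0, 5, (200,))
head = train_linear_head(feats, labels, num_classes=5)

with torch.no_grad():
    train_acc = (head(feats).argmax(dim=1) == labels).float().mean().item()
print(f"train accuracy: {train_acc:.2f}")
```

Because only the small head is trained, this step runs in seconds even on a CPU. With random labels, high training accuracy only shows that the head can fit 200 points in 2048 dimensions; judging generalization needs a held-out split.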
An image enters the frozen CNN. Each layer transforms it into higher-level features. The final feature vector is handed to a new, small classifier.
Feature extraction treats the pretrained network as a black box. But what if your target domain is different enough that the pretrained features aren't quite right? Medical images, satellite photos, or microscopy look nothing like ImageNet's cats and cars. The frozen features might be a poor fit.
Fine-tuning goes further: instead of freezing all pretrained weights, you unfreeze some (or all) of them and continue training on your data with a small learning rate. The pretrained weights serve as a smart initialization rather than a fixed feature extractor.
Fine-tuning is more powerful but more dangerous. With a large learning rate, you'll destroy the pretrained features — the network "forgets" what it learned on ImageNet. With too few images, you'll overfit. The art is in how much to unfreeze and how gently to update.
| Approach | What Trains | Speed | Risk | Best When |
|---|---|---|---|---|
| Feature extraction | New head only | Fast | Low | Small data, similar domain |
| Fine-tuning (partial) | Head + late layers | Medium | Medium | Medium data, moderate domain shift |
| Fine-tuning (full) | All layers | Slow | High | Large data, different domain |
```python
import torch
from torchvision import models

num_classes = 5  # placeholder: number of classes in your task

# Fine-tuning: replace head, unfreeze late layers
model = models.resnet50(weights='IMAGENET1K_V2')

# Replace the 1000-class head with one for our task
model.fc = torch.nn.Linear(2048, num_classes)

# Freeze early layers, unfreeze late ones
for name, param in model.named_parameters():
    if 'layer4' in name or 'fc' in name:
        param.requires_grad = True
    else:
        param.requires_grad = False

# Use a SMALL learning rate to avoid destroying features
optimizer = torch.optim.SGD(
    filter(lambda p: p.requires_grad, model.parameters()),
    lr=0.001, momentum=0.9
)
```
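A quick sanity check after freezing is to count how many parameters will actually receive updates. The helper below is a hypothetical name, not a torchvision API; it only sums `numel()` over parameters grouped by their `requires_grad` flag. A tiny stand-in model keeps the example download-free.

```python
import torch

def count_parameters(model: torch.nn.Module) -> tuple[int, int]:
    """Return (trainable, total) parameter counts."""
    total = sum(p.numel() for p in model.parameters())
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    return trainable, total

# Stand-in "backbone + head": freeze the first layer, train the last one
toy = torch.nn.Sequential(torch.nn.Linear(10, 20), torch.nn.ReLU(), torch.nn.Linear(20, 3))
for p in toy[0].parameters():
    p.requires_grad = False

trainable, total = count_parameters(toy)
print(f"{trainable:,} of {total:,} parameters are trainable")  # 63 of 283 parameters are trainable
```

Calling the same helper on the partially unfrozen ResNet-50 shows how much of the network `layer4` and the new `fc` head account for.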
The central decision in transfer learning: how many layers do you freeze? Freeze too many and the pretrained features might not fit your domain. Freeze too few and you risk overfitting on limited data. The answer depends on two factors: how much data you have, and how similar your data is to the pretrained dataset.
The recommended strategy depends on where your task sits on two axes, dataset size and how similar your data is to ImageNet, which gives four quadrants.
Quadrant 1: Small data, similar domain. Example: classifying dog breeds (ImageNet already has dogs). Use feature extraction — freeze everything, train only the new head. The pretrained features are already a great fit.
Quadrant 2: Large data, similar domain. Example: a massive e-commerce product dataset. Fine-tune the later layers. You have enough data to safely adapt without overfitting, and the domain is close enough that early features transfer well.
Quadrant 3: Small data, different domain. This is the hardest case. Example: 100 medical X-rays. Try extracting features from an earlier layer (before specialization sets in), or use heavy data augmentation. Fine-tuning risks overfitting catastrophically.
Quadrant 4: Large data, different domain. Example: millions of satellite images. Fine-tune the entire network. You have enough data to learn new features from scratch, but pretrained initialization still helps convergence speed.
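The four quadrants can be condensed into a small lookup. The cut-off between "small" and "large" data below (10,000 images) is a hypothetical placeholder, not a rule from this lesson; in practice the boundary is fuzzy and depends on the model and the task.

```python
def recommend_strategy(num_images: int, similar_to_imagenet: bool,
                       small_data_threshold: int = 10_000) -> str:
    """Map (data size, domain similarity) to the four quadrant recommendations."""
    small = num_images < small_data_threshold  # hypothetical cut-off
    if small and similar_to_imagenet:
        return "feature extraction: freeze everything, train a new head"
    if not small and similar_to_imagenet:
        return "fine-tune the later layers"
    if small and not similar_to_imagenet:
        return "features from an earlier layer, plus heavy augmentation"
    return "fine-tune the entire network"

print(recommend_strategy(200, similar_to_imagenet=True))     # dog breeds
print(recommend_strategy(100, similar_to_imagenet=False))    # medical X-rays
```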
When fine-tuning, not all layers should learn at the same speed. Early layers hold universal features that need minimal adjustment. Late layers hold task-specific features that need significant rewriting. The new classification head starts from random weights and needs the fastest updates.
This calls for discriminative learning rates (also called differential learning rates): assign different learning rates to different layer groups, with earlier layers getting smaller rates.
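```python
import torch

# Discriminative learning rates in PyTorch
# (assumes `model` is the ResNet-50 with a replaced `fc` head from above)
optimizer = torch.optim.SGD([
    {'params': model.layer1.parameters(), 'lr': 1e-4},
    {'params': model.layer2.parameters(), 'lr': 1e-4},
    {'params': model.layer3.parameters(), 'lr': 1e-3},
    {'params': model.layer4.parameters(), 'lr': 1e-2},
    {'params': model.fc.parameters(), 'lr': 1e-1},
], momentum=0.9)
# Note: the stem (conv1, bn1) is not listed, so this optimizer never updates it
```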
Another common pattern is gradual unfreezing. Start by training only the new head for a few epochs. Then unfreeze the last conv block and train. Then unfreeze the next block. This lets each layer group stabilize before the layers below it start changing.
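Gradual unfreezing can be driven by a per-epoch schedule. The sketch below assumes the model's layer groups are given as an ordered list from input to output (for a torchvision ResNet, something like `[model.layer1, model.layer2, model.layer3, model.layer4, model.fc]`); the function name and the pace of one extra group per epoch are illustrative choices.

```python
import torch

def apply_gradual_unfreezing(groups: list[torch.nn.Module], epoch: int) -> None:
    """Epoch 0 trains only the last group (the head); each later epoch unfreezes one more group."""
    num_trainable = min(epoch + 1, len(groups))
    for i, group in enumerate(groups):
        trainable = i >= len(groups) - num_trainable
        for p in group.parameters():
            p.requires_grad = trainable

# Toy stand-in for [block1, block2, block3, head]
groups = [torch.nn.Linear(8, 8) for _ in range(3)] + [torch.nn.Linear(8, 2)]
for epoch in range(4):
    apply_gradual_unfreezing(groups, epoch)
    print(epoch, [next(g.parameters()).requires_grad for g in groups])
```

If the optimizer was created over only the initially trainable parameters, newly unfrozen groups must also be registered with it, for example via `optimizer.add_param_group(...)`.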
Discriminative rates cascade down from a base learning rate: each layer group closer to the input gets a smaller rate than the group after it, so the earliest layers receive the smallest updates.
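Instead of typing each rate by hand as in the optimizer above, the cascade can be generated from a base rate. This sketch assumes groups ordered from input to output and a constant 10x step between neighbours; both the helper name and the factor are assumptions (the hand-written example instead gives `layer1` and `layer2` the same rate).

```python
import torch

def discriminative_param_groups(groups: list[torch.nn.Module], base_lr: float,
                                decay: float = 10.0) -> list[dict]:
    """The last group (the head) gets base_lr; each earlier group gets `decay` times less."""
    n = len(groups)
    return [{"params": list(g.parameters()), "lr": base_lr / decay ** (n - 1 - i)}
            for i, g in enumerate(groups)]

# Toy stand-in for [layer1, layer2, layer3, layer4, fc]
groups = [torch.nn.Linear(4, 4) for _ in range(5)]
optimizer = torch.optim.SGD(discriminative_param_groups(groups, base_lr=1e-1), momentum=0.9)
print([g["lr"] for g in optimizer.param_groups])  # approximately [1e-05, 1e-04, 1e-03, 1e-02, 1e-01]
```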
Transfer learning reduces your data needs, but when you only have a few hundred images, even fine-tuning can overfit. Data augmentation artificially expands your training set by applying random transformations to each image: flips, rotations, crops, color jitter, scaling.
The key constraint: augmentations must preserve the label. Flipping a cat horizontally still gives you a cat. But rotating a "6" by 180° gives you a "9", so the label would be wrong. Choose augmentations that make sense for your domain.
```python
from torchvision import transforms

# Standard augmentation pipeline for transfer learning
train_transform = transforms.Compose([
    transforms.RandomResizedCrop(224, scale=(0.8, 1.0)),
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(
        brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1
    ),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],  # ImageNet per-channel means
        std=[0.229, 0.224, 0.225]    # ImageNet per-channel standard deviations
    ),
])
```
A single training image is transformed in multiple ways, and each version is treated as an additional training example.
ImageNet normalization matters. When using a pretrained ImageNet model, always normalize your images with ImageNet's per-channel mean and standard deviation. The pretrained weights were learned on inputs with exactly those statistics. Feeding unnormalized images is like speaking the wrong language to the network.
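`transforms.Normalize` computes (value - mean) / std per channel, on inputs that `ToTensor` has already scaled to [0, 1]. A plain-Python sketch of that arithmetic for a single RGB pixel (the helper name is illustrative):

```python
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)

def normalize_pixel(rgb_0_255: tuple[int, int, int]) -> list[float]:
    """Scale an 8-bit RGB pixel to [0, 1], then standardize with ImageNet statistics."""
    scaled = [c / 255.0 for c in rgb_0_255]
    return [(v - m) / s for v, m, s in zip(scaled, IMAGENET_MEAN, IMAGENET_STD)]

print([round(v, 3) for v in normalize_pixel((255, 255, 255))])  # white: [2.249, 2.429, 2.64]
print([round(v, 3) for v in normalize_pixel((0, 0, 0))])        # black: [-2.118, -2.036, -1.804]
```

So the network expects values of roughly -2 to +2.6 per channel; raw 0 to 255 inputs would be orders of magnitude outside the range its weights were trained on.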
Now put it all together. Three choices interact: how many layer groups to freeze, the learning rate, and the dataset size, and each one affects both accuracy and training cost.
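The pieces above meet in an ordinary training loop. This is a minimal sketch assuming a model whose frozen parameters already have `requires_grad=False`, an optimizer built over the trainable ones, and a data loader yielding `(images, labels)`. The function name is hypothetical, and a toy model with random data stands in for a real ResNet and dataset.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

def train_one_epoch(model: torch.nn.Module, loader: DataLoader,
                    optimizer: torch.optim.Optimizer) -> float:
    """Run one epoch of supervised training; return the mean batch loss."""
    model.train()  # in a ResNet, batch-norm running stats in frozen blocks still update in train mode
    loss_fn = torch.nn.CrossEntropyLoss()
    total = 0.0
    for images, labels in loader:
        optimizer.zero_grad()
        loss = loss_fn(model(images), labels)
        loss.backward()
        optimizer.step()
        total += loss.item()
    return total / len(loader)

# Toy stand-in: frozen "backbone" layer + trainable "head"
torch.manual_seed(0)
model = torch.nn.Sequential(torch.nn.Linear(16, 32), torch.nn.ReLU(), torch.nn.Linear(32, 3))
for p in model[0].parameters():
    p.requires_grad = False
optimizer = torch.optim.SGD([p for p in model.parameters() if p.requires_grad],
                            lr=0.01, momentum=0.9)
data = TensorDataset(torch.randn(64, 16), torch.randint(0, 3, (64,)))
loader = DataLoader(data, batch_size=16, shuffle=True)

frozen_before = model[0].weight.clone()
for epoch in range(3):
    print(epoch, round(train_one_epoch(model, loader, optimizer), 4))
```

Swapping in the fine-tuning ResNet-50, the discriminative optimizer, and a `DataLoader` over an augmented dataset turns this into the full recipe; the frozen layer's weights stay exactly as loaded.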
Transfer learning isn't just a CNN trick. The same principle — pretrain on a large dataset, fine-tune on a small one — has become the dominant paradigm across all of deep learning.
In NLP, models like BERT and GPT are pretrained on billions of words of text and then fine-tuned for sentiment analysis, question answering, or translation. In speech, wav2vec is pretrained on unlabeled audio. In reinforcement learning, agents pretrained in simulation are fine-tuned in the real world.
| Strategy | Layers Trained | Data Needed | Risk | Best For |
|---|---|---|---|---|
| Feature extraction | New head only | Very little (100s) | Underfitting | Similar domain, tiny data |
| Partial fine-tuning | Head + late blocks | Moderate (1000s) | Balanced | Most practical cases |
| Full fine-tuning | All layers | Large (10,000s+) | Overfitting | Different domain, ample data |
| Train from scratch | Everything (random init) | Huge (millions) | Slow convergence | Very different domain, massive data |
Key lessons from this lesson: