From max pooling to GeM — how to collapse variable-length features into the fixed-size vectors that power classification, retrieval, and embeddings.
A CNN outputs a 7×7×512 feature map — 25,088 numbers describing an image at different spatial locations. But the classifier needs ONE vector: is this a cat or a dog? A BERT encoder outputs 512 vectors, one per token. But the search engine needs ONE embedding for the whole sentence.
How do you go from many vectors to one?
This is the dimension problem. Neural networks produce outputs at every spatial position (CNNs give you H×W feature maps) or at every token position (transformers give you seq_len × d). But downstream tasks — classification, retrieval, regression — need a single fixed-size vector. One prediction per image. One embedding per sentence.
We need to collapse a variable-length dimension into a fixed representation. The question isn't whether to collapse — it's how to collapse without losing critical information.
Let's make this concrete. A convolutional layer with 512 filters applied to a 7×7 spatial grid produces a tensor of shape (7, 7, 512). Each spatial position (i, j) holds a 512-dimensional vector describing what the network "sees" at that location. Position (0, 0) might describe the top-left corner — maybe a patch of sky. Position (3, 3) might describe the center — maybe the cat's nose.
All 49 positions carry information. The challenge: compress those 49 vectors into one, preserving what matters for the task.
The naive approach: reshape the 7×7×512 tensor into a single vector of length 25,088 and feed it to a fully-connected layer. Early CNNs (AlexNet, VGG) did exactly this. It works — but with devastating consequences.
Problem 1: Size depends on input resolution. If your input image is 224×224, the feature map is 7×7 and the flattened vector is 25,088. If your input is 448×448, the feature map is 14×14 and the flattened vector is 100,352. The FC layer has a fixed weight matrix — it can't handle different input sizes. You're locked to one resolution forever.
Problem 2: Parameter explosion. VGG-16 has a 7×7×512 = 25,088 vector feeding into a 4,096-neuron FC layer. That's 25,088 × 4,096 = 102 million parameters in a single layer. The entire convolutional backbone has only 14 million. The FC layers dominate the model, and most of those parameters overfit.
Problem 3: No spatial invariance. Flattening assigns a separate weight to each spatial position. If the cat shifts 2 pixels right, completely different weights activate. The FC layer must re-learn "cat" at every possible position — a waste of capacity.
Let's see the parameter cost difference with real numbers.
| Approach | Feature Map | Vector Size | FC Params (to 1000 classes) |
|---|---|---|---|
| Flatten (7×7) | 7×7×512 | 25,088 | 25,088 × 1000 = 25.1M |
| Flatten (14×14) | 14×14×512 | 100,352 | 100,352 × 1000 = 100.4M |
| Global Pool | Any H×W×512 | 512 | 512 × 1000 = 0.5M |
Global pooling collapses any spatial size to 512 values: roughly a 50× reduction in classifier parameters compared to flattening a 7×7 map (49× exactly), and roughly 200× compared to 14×14 (196× exactly). That's not an optimization; it's a fundamentally different approach.
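The table's arithmetic can be reproduced in a few lines. This is a minimal sketch: `fc_params` is a hypothetical helper that counts weights only and ignores biases, as the table does.

```python
def fc_params(in_features, out_features):
    """Weights in a fully-connected layer (biases ignored, as in the table)."""
    return in_features * out_features

channels, num_classes = 512, 1000

flatten_7 = fc_params(7 * 7 * channels, num_classes)     # 25,088,000
flatten_14 = fc_params(14 * 14 * channels, num_classes)  # 100,352,000
global_pool = fc_params(channels, num_classes)           # 512,000

print(flatten_7 // global_pool)   # 49
print(flatten_14 // global_pool)  # 196
```

The pooled classifier's size depends only on the channel count, so it stays at 512,000 weights no matter how large the feature map grows.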
This lesson covers three families of pooling. Max pooling picks the strongest activation — if any neuron in a region fires, we keep it. Average pooling takes the mean — every position contributes equally. And sequence pooling handles the 1D case in transformers, where we collapse tokens instead of spatial positions.
Let's see them in action before we formalize anything.
A 4×4 feature map. Click any cell to change its value. Toggle between Max, Average, and ??? pooling to see how the output changes. Which method preserves the strongest signal? Which smooths everything out?
Notice: when you click around and create one very high value, max pooling immediately grabs it. The output jumps. Average pooling barely moves — one outlier among many gets diluted. The "???" mode is a preview of something we'll build later (weighted pooling), but for now, focus on max vs average. They answer different questions about the data.
A 4×4 feature map, pooled with 2×2 windows and stride 2. Each window collapses 4 values into 1.
Input:
```text
0.2  0.8 | 0.1  0.9
0.5  0.3 | 0.7  0.4
---------|---------
0.6  0.1 | 0.3  0.5
0.4  0.9 | 0.2  0.8
```
Max pooling (2×2, stride 2): take the max of each block.
Output: [[0.8, 0.9], [0.9, 0.8]]. Max preserves the peaks.
Average pooling (2×2, stride 2): take the mean of each block.
Output: [[0.45, 0.525], [0.50, 0.45]]. Average produces a smoother summary.
See the difference? Max output ranges from 0.8 to 0.9 — it kept the strongest signals. Average output ranges from 0.45 to 0.525 — it smoothed everything toward the mean. If you care about whether a feature exists, max is your friend. If you care about the overall level, average tells a better story.
```python
import numpy as np

def max_pool_2x2(x):
    # x: (H, W) feature map, H and W must be even
    H, W = x.shape
    out = np.zeros((H // 2, W // 2))
    for i in range(0, H, 2):
        for j in range(0, W, 2):
            block = x[i:i+2, j:j+2]
            out[i//2, j//2] = np.max(block)
    return out

def avg_pool_2x2(x):
    # Same structure, but mean instead of max
    H, W = x.shape
    out = np.zeros((H // 2, W // 2))
    for i in range(0, H, 2):
        for j in range(0, W, 2):
            block = x[i:i+2, j:j+2]
            out[i//2, j//2] = np.mean(block)
    return out

# Test with our hand calculation
x = np.array([
    [0.2, 0.8, 0.1, 0.9],
    [0.5, 0.3, 0.7, 0.4],
    [0.6, 0.1, 0.3, 0.5],
    [0.4, 0.9, 0.2, 0.8]
])
print(max_pool_2x2(x))  # [[0.8, 0.9], [0.9, 0.8]]
print(avg_pool_2x2(x))  # [[0.45, 0.525], [0.5, 0.45]]
```
Imagine you're scanning a photo for edges. Your CNN has 64 edge detectors, each producing a number at every spatial position. One detector fires strongly at position (3, 4) — there's a vertical edge there. The neighboring positions have weaker responses: maybe 0.1, 0.3, 0.2.
Max pooling says: "I don't care WHERE in this 2×2 region the edge is. I only care that it EXISTS." It grabs the strongest response and discards the rest. If any neuron in the pooling window detected a feature, the max preserves that detection.
This gives max pooling a powerful property: translation invariance within the pooling window. Shift the edge one pixel right? It's still inside the same 2×2 window. The max pool output doesn't change. The network becomes robust to small spatial shifts — exactly what you want for recognition tasks.
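A small experiment makes the scope of this invariance concrete. The sketch below assumes non-overlapping 2×2 windows and a single active pixel standing in for a detected edge; the array names are illustrative only.

```python
import numpy as np

def max_pool_2x2(x):
    """Non-overlapping 2x2 max pooling on an (H, W) array with even H and W."""
    H, W = x.shape
    return x.reshape(H // 2, 2, W // 2, 2).max(axis=(1, 3))

edge = np.zeros((4, 4))
edge[0, 0] = 1.0        # feature detected at the top-left pixel

shifted = np.zeros((4, 4))
shifted[0, 1] = 1.0     # same feature one pixel to the right (same window)

crossed = np.zeros((4, 4))
crossed[0, 2] = 1.0     # two pixels to the right (lands in the next window)

print(np.array_equal(max_pool_2x2(edge), max_pool_2x2(shifted)))  # True
print(np.array_equal(max_pool_2x2(edge), max_pool_2x2(crossed)))  # False
```

The invariance is local: a shift that stays inside one window leaves the output unchanged, while a shift that crosses a window boundary moves the response to a neighboring output cell.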
For a 2×2 window with stride 2 applied to a 4×4 input, each output value is:
$$y_{i,j} = \max\left(x_{2i,\,2j},\; x_{2i,\,2j+1},\; x_{2i+1,\,2j},\; x_{2i+1,\,2j+1}\right)$$
The output has shape (H/2, W/2). We halve the spatial dimensions. Each output cell summarizes a 2×2 patch by keeping only its maximum.
More generally, for a pool window of size k with stride s:
$$y_{i,j} = \max_{0 \le a,\, b < k} \; x_{si+a,\; sj+b}$$
Output shape: (⌊(H - k) / s⌋ + 1, ⌊(W - k) / s⌋ + 1). When k = s = 2 and H, W are even, this simplifies to (H/2, W/2). When k = 3, s = 2, you get overlapping pooling regions, as used in AlexNet.
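The output-shape formula is easy to check numerically. The helper below is a hypothetical one-dimensional version, assuming no padding and floor rounding (the behavior of PyTorch's `MaxPool2d` with its default `ceil_mode=False`).

```python
def pool_output_size(size, k, s):
    """Output length along one dimension for window k and stride s (no padding)."""
    return (size - k) // s + 1

print(pool_output_size(4, k=2, s=2))    # 2
print(pool_output_size(224, k=2, s=2))  # 112
print(pool_output_size(55, k=3, s=2))   # 27
```

The last call uses an overlapping 3-wide window with stride 2, the configuration the text attributes to AlexNet.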
Input feature map (4×4):
```text
1.2  0.5 | 3.1  0.8
0.3  2.7 | 0.1  1.5
---------|---------
0.9  0.4 | 2.3  3.8
1.1  0.6 | 0.7  1.2
```
Top-left block [1.2, 0.5, 0.3, 2.7]: compare all four — max is 2.7 (bottom-right of the block).
Top-right block [3.1, 0.8, 0.1, 1.5]: max is 3.1 (top-left of the block). The strong feature at this exact position is preserved.
Bottom-left block [0.9, 0.4, 1.1, 0.6]: max is 1.1 (bottom-left). This is the weakest-maximum block — none of the positions fired strongly.
Bottom-right block [2.3, 3.8, 0.7, 1.2]: max is 3.8 (top-right). That single strong activation dominates.
Output (2×2):
```text
2.7  3.1
1.1  3.8
```
Four values. Each is the winner of its 2×2 block. The spatial resolution dropped from 4×4 to 2×2, but the strongest features survived.
During backpropagation, where does the gradient flow? Only to the position that held the maximum value. All other positions in the window receive zero gradient. This is a "winner-take-all" mechanism.
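Think about it mathematically. The max function passes gradient only through its largest argument:
$$\frac{\partial \max(x_1, \dots, x_n)}{\partial x_i} = \begin{cases} 1 & \text{if } x_i = \max_j x_j \\ 0 & \text{otherwise} \end{cases}$$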
For our top-left block, the gradient flows entirely to position (1, 1) where the value was 2.7. Positions (0, 0), (0, 1), and (1, 0) get zero gradient from this output cell. They won't be updated from this particular pooling window.
Let's trace the backward pass for our example. Suppose the loss gradient flowing back to the output is:
```text
 0.5  -0.3
 0.2  -0.1
```
The gradient at the input is a 4×4 grid of mostly zeros:
```text
0.0  0.0 | -0.3  0.0    # 3.1 was max → gets -0.3
0.0  0.5 |  0.0  0.0    # 2.7 was max → gets 0.5
---------|-----------
0.0  0.0 |  0.0 -0.1    # 3.8 was max → gets -0.1
0.2  0.0 |  0.0  0.0    # 1.1 was max → gets 0.2
```
Only 4 out of 16 positions receive gradient. The rest are silent for this particular pooling operation. Sparse gradients.
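The backward pass traced above can be sketched directly in NumPy. This is an illustration for non-overlapping 2×2 windows only; ties are broken by `np.argmax`, which picks the first maximum, and that tie rule is an assumption of this sketch.

```python
import numpy as np

x = np.array([
    [1.2, 0.5, 3.1, 0.8],
    [0.3, 2.7, 0.1, 1.5],
    [0.9, 0.4, 2.3, 3.8],
    [1.1, 0.6, 0.7, 1.2],
])
grad_out = np.array([[0.5, -0.3],
                     [0.2, -0.1]])

grad_in = np.zeros_like(x)
for i in range(2):
    for j in range(2):
        block = x[2*i:2*i+2, 2*j:2*j+2]
        # Position of the winner inside this 2x2 block
        di, dj = np.unravel_index(np.argmax(block), block.shape)
        # The whole output gradient is routed to the winner
        grad_in[2*i + di, 2*j + dj] = grad_out[i, j]

print(grad_in)
print(np.count_nonzero(grad_in))  # 4
```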
Edit values in the 4×4 grid (click a cell, then use the slider). The 2×2 max-pooled output updates live. Orange borders show which cell "won" each block. Teal arrows show gradient flow (only to winners).
Try making one cell much larger than its neighbors — say, 4.5. Watch the output snap to that value instantly. Now lower it below its neighbor. The "winner" border shifts to the new max. The gradient arrow changes direction. Only the current winner gets the learning signal.
```python
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

# From scratch with gradient tracking
def max_pool_2x2_with_grad(x):
    """Returns pooled output AND a mask of winner positions."""
    H, W = x.shape
    out = np.zeros((H // 2, W // 2))
    mask = np.zeros_like(x)  # 1 at max positions, 0 elsewhere
    for i in range(0, H, 2):
        for j in range(0, W, 2):
            block = x[i:i+2, j:j+2]
            max_val = np.max(block)
            out[i//2, j//2] = max_val
            # Mark the winner (argmax within the block)
            mi, mj = np.unravel_index(np.argmax(block), (2, 2))
            mask[i + mi, j + mj] = 1
    return out, mask

# PyTorch: nn.MaxPool2d
pool = nn.MaxPool2d(kernel_size=2, stride=2)
# Build the (1, 1, 4, 4) tensor directly so it is a leaf tensor.
# Calling .unsqueeze(0) after requires_grad=True would return a non-leaf
# tensor whose .grad stays None after backward().
x_t = torch.tensor([[[
    [1.2, 0.5, 3.1, 0.8],
    [0.3, 2.7, 0.1, 1.5],
    [0.9, 0.4, 2.3, 3.8],
    [1.1, 0.6, 0.7, 1.2]
]]], requires_grad=True)  # (1, 1, 4, 4)
out = pool(x_t)  # tensor([[[[2.7, 3.1], [1.1, 3.8]]]])
out.sum().backward()
print(x_t.grad)  # 1s at max positions, 0s elsewhere

# Functional API (equivalent, but no stored state)
out2 = F.max_pool2d(x_t, kernel_size=2, stride=2)
```
Max pooling appears in nearly every classic CNN. AlexNet (2012) used overlapping 3×3 max pools with stride 2. VGG (2014) used non-overlapping 2×2 with stride 2 after every 2-3 conv layers. ResNet (2015) uses a single 3×3 max pool early in the network, then relies on strided convolutions.
It works best when the presence of a feature matters more than its exact magnitude. Edge detection, object detection, texture recognition — anywhere you want "is this feature here?" rather than "how much of this feature is here?"
But max pooling has a fundamental limitation: it throws away magnitude information. A block with values [0.1, 0.1, 0.1, 5.0] and a block with values [4.8, 4.9, 4.7, 5.0] both produce output 5.0. The first block had one spike amid silence. The second had uniformly strong activation. Max pooling can't tell the difference. Sometimes that distinction matters — and that's where average pooling enters the picture.
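The two blocks from this paragraph make the point in a couple of lines. A minimal check, using the values from the text:

```python
import numpy as np

spike = np.array([0.1, 0.1, 0.1, 5.0])     # one spike amid silence
uniform = np.array([4.8, 4.9, 4.7, 5.0])   # uniformly strong activation

print(spike.max(), uniform.max())  # 5.0 5.0
print(f"{spike.mean():.3f} vs {uniform.mean():.3f}")  # 1.325 vs 4.850
```

The maxima are identical, while the means differ by more than a factor of three.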
Max pooling keeps the loudest voice. Average pooling listens to everyone equally. If a region has values [0.1, 0.2, 0.8, 0.1], max says "0.8 — there's a strong feature here." Average says "0.3 — there's a weak-to-moderate signal overall." Different answers. Different use cases.
Average pooling takes the arithmetic mean of all values in the pooling window. Every position contributes equally to the output. There are no winners and no losers — the output reflects the collective response.
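For a 2×2 window with stride 2:
$$y_{i,j} = \frac{1}{4}\left(x_{2i,\,2j} + x_{2i,\,2j+1} + x_{2i+1,\,2j} + x_{2i+1,\,2j+1}\right)$$
More generally, for a k×k window with stride s:
$$y_{i,j} = \frac{1}{k^2} \sum_{a=0}^{k-1} \sum_{b=0}^{k-1} x_{si+a,\; sj+b}$$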
The gradient is simple and elegant: every position in the window receives an equal share. If the output gradient is g, each of the k² inputs gets g / k². No winner-take-all — everyone gets updated.
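As a sketch of that rule, the backward pass for non-overlapping average pooling (stride equal to k) can be written with `np.kron`. The function name `avg_pool_backward` is hypothetical.

```python
import numpy as np

def avg_pool_backward(grad_out, k=2):
    """Spread each output gradient evenly over its k x k input window."""
    # np.kron with a block of ones copies every entry into a k x k tile.
    return np.kron(grad_out, np.ones((k, k))) / (k * k)

grad_out = np.array([[0.4, -0.8],
                     [1.2,  0.0]])
grad_in = avg_pool_backward(grad_out)
print(grad_in)
# Each 2x2 tile holds g / 4: 0.1, -0.2, 0.3, and 0.0 respectively.
```

The total gradient is conserved: the entries of `grad_in` sum to the same value as the entries of `grad_out`.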
Let's compare them on the same data to build intuition for when to use which.
| Property | Max Pooling | Average Pooling |
|---|---|---|
| What it keeps | Strongest activation | Overall activation level |
| Gradient flow | Only to winner (sparse) | Equal to all (dense) |
| Sensitive to outliers | Very — one high value dominates | Robust — outliers are diluted |
| Best for | Feature detection (is it there?) | Feature magnitude (how much?) |
| Classic use | Local pooling in conv layers | Global pooling before classifier |
In practice, the community converged on a split: local pooling (2×2, stride 2 between conv layers) usually uses max. Global pooling (entire feature map → 1 value per channel) usually uses average. Why? Local max pooling preserves edge-like features as they flow through the network. Global average pooling produces a stable summary of the entire feature map for classification.
This is the big one. Global Average Pooling doesn't use a small window — it takes the mean across the ENTIRE spatial dimension. If your feature map is (batch, 512, 7, 7), GAP produces (batch, 512). One number per channel. The entire 7×7 spatial grid collapses into a single value.
This replaced the massive fully-connected layers in classic CNNs. VGG-16 has about 124 million parameters in its FC layers (out of 138M total). A ResNet with a 512-channel final feature map (ResNet-18 or ResNet-34) has zero parameters in pooling and only 512 × 1000 ≈ 0.5M weights in the final linear layer. That is a reduction of more than 200× in classifier parameters.
GAP was introduced by Lin et al. in 2013 ("Network in Network") and became standard with GoogLeNet (2014) and ResNet (2015). Most modern CNN classifiers use it.
One channel of a 3×3 feature map:
```text
0.5  0.2  0.8
0.1  0.9  0.3
0.4  0.6  0.7
```
GAP = (0.5 + 0.2 + 0.8 + 0.1 + 0.9 + 0.3 + 0.4 + 0.6 + 0.7) / 9 = 4.5 / 9 = 0.5.
One number. It says: "this channel has moderate overall activation across the entire spatial extent." If this were the "cat ear" channel and only one position had a high value (say 3.0 while the rest are 0.1), GAP would output (3.0 + 8 × 0.1) / 9 = 3.8 / 9 ≈ 0.42. The strong local detection gets diluted — which is actually appropriate, because the channel's OVERALL response is indeed moderate.
Now consider a channel where MOST positions fire strongly: [2.1, 1.8, 2.3, 1.9, 2.5, 2.0, 1.7, 2.2, 2.4]. GAP = 18.9 / 9 = 2.1. This channel responds strongly across the entire image. GAP captures that distinction perfectly.
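The three channels discussed above can be pooled side by side. This sketch stacks them into a (channels, H, W) array; placing the 3.0 spike at the center is an arbitrary choice, since GAP ignores position.

```python
import numpy as np

moderate = np.array([[0.5, 0.2, 0.8],
                     [0.1, 0.9, 0.3],
                     [0.4, 0.6, 0.7]])
one_spike = np.full((3, 3), 0.1)
one_spike[1, 1] = 3.0                      # single strong detection
strong = np.array([[2.1, 1.8, 2.3],
                   [1.9, 2.5, 2.0],
                   [1.7, 2.2, 2.4]])

features = np.stack([moderate, one_spike, strong])  # (channels, H, W)
gap = features.mean(axis=(1, 2))                    # (channels,)
gmp = features.max(axis=(1, 2))                     # global max, for contrast

print(np.round(gap, 2))  # approximately 0.5, 0.42, 2.1
print(gmp)               # 0.9, 3.0, 2.5
```

Global max pooling would rank the single-spike channel highest, while GAP ranks the uniformly strong channel highest.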
GAP forces each channel in the last conv layer to correspond to one class-relevant feature. Why? Because the final linear layer takes the 512-dimensional GAP output and produces class scores with a simple matrix multiply. Each channel's pooled value is multiplied by one weight per class.
This means channel 37 can't encode "cat nose at position (3, 4)" — it has to encode "cat-nose-ness" everywhere. The spatial information is explicitly removed. The channel must learn a position-invariant feature detector, or it contributes nothing useful after GAP.
This is built-in structural regularization. No dropout needed. No weight decay on the pooling layer (it has no weights). Just the architectural choice of "average everything" forces better feature learning upstream.
Same 4×4 grid. Toggle between Max, Average, and GAP. Add noise with the slider — watch how max is sensitive to outliers (one high value dominates) while average is robust. GAP collapses everything to a single number.
Try cranking the noise slider up. Max pooling's output swings wildly — whichever cell gets the largest noise spike wins. Average pooling barely moves, because the noise cancels out across positions. This is why average pooling is preferred for global aggregation: it's a more stable summary.
```python
import numpy as np
import torch
import torch.nn as nn

# Average pooling from scratch
def avg_pool_2x2(x):
    H, W = x.shape
    out = np.zeros((H // 2, W // 2))
    for i in range(0, H, 2):
        for j in range(0, W, 2):
            out[i//2, j//2] = np.mean(x[i:i+2, j:j+2])
    return out

# Global Average Pooling from scratch
def global_avg_pool(x):
    # x shape: (channels, H, W)
    return np.mean(x, axis=(1, 2))  # shape: (channels,)

# PyTorch: nn.AvgPool2d for local average pooling
avg_pool = nn.AvgPool2d(kernel_size=2, stride=2)
x = torch.randn(1, 512, 8, 8)
out = avg_pool(x)  # (1, 512, 4, 4)
print(out.shape)

# PyTorch: nn.AdaptiveAvgPool2d for GAP
# "Adaptive" means: whatever the input size, produce THIS output size
gap = nn.AdaptiveAvgPool2d(1)  # output is (1, 1) spatially
x7 = torch.randn(1, 512, 7, 7)
x14 = torch.randn(1, 512, 14, 14)
print(gap(x7).shape)   # (1, 512, 1, 1): works!
print(gap(x14).shape)  # (1, 512, 1, 1): same output shape!

# ResNet-style classifier head:
# features = backbone(image)         # (batch, 512, 7, 7)
# pooled = gap(features).flatten(1)  # (batch, 512)
# logits = fc(pooled)                # (batch, num_classes)
```
Unlike max pooling's sparse gradient, average pooling distributes gradient uniformly. For a k×k window, each of the k² inputs receives 1/k² of the output gradient.
For GAP over an H×W feature map: each spatial position receives gradient = output_grad / (H × W). If H = W = 7, each position gets 1/49 of the gradient. Small per position, but EVERY position gets some signal. Dense but diluted.
This has an interesting consequence for learning dynamics. Max pooling gives strong gradients to a few positions (aggressive learning at peaks). Average pooling gives weak gradients to all positions (gentle, uniform learning). In practice, both work — the choice depends on whether your task cares about peaks or distributions.
BERT outputs 512 vectors, one per token. You need ONE vector to classify the sentiment of a movie review. Which token's output do you use? The first? The last? All of them?
BERT's answer: a special [CLS] token prepended to every input. Its output "summarizes" the whole sequence through attention. But is that actually the best approach? Spoiler: for many tasks, it isn't.
In BERT and its descendants, every input sequence is prepended with a special [CLS] token (short for "classification"). The input looks like:
```text
[CLS] The movie was terrible [SEP]
```
After L layers of self-attention, each token's output vector has "seen" every other token. The [CLS] token has attended to "The," "movie," "was," "terrible" — accumulating information about the whole sequence through the attention mechanism.
To classify, you extract [CLS]'s output and pass it through a linear layer:
$$\hat{y} = \mathrm{softmax}\left(W\, h_{[\mathrm{CLS}]} + b\right)$$
Simple. Elegant. One token to rule them all. But there's a catch.
At the input to the first transformer layer, [CLS] has zero information about the sequence. It's initialized with a learned embedding, but it hasn't attended to anything yet. It must build its representation from scratch, layer by layer, through attention.
In the early layers, [CLS]'s representation is essentially random with respect to the actual input content. Only in the deeper layers does it start to carry meaningful sequence-level information. This means the lower layers of the transformer are wasting capacity: [CLS] is attending to tokens but not contributing useful information back.
For classification tasks (where [CLS] is explicitly trained with a classification loss), this works fine — the model learns to route information to [CLS] during fine-tuning. But for embedding tasks (retrieval, semantic search, similarity), where we want the vector to capture the meaning of the text, [CLS] has a fundamental limitation: it represents a single token's "view" of the sequence.
Mean pooling is dead simple: average all token outputs.
$$\bar{h} = \frac{1}{T} \sum_{t=1}^{T} h_t$$
Where T is the sequence length and $h_t$ is the output vector of token t. Every token contributes equally. No special token needed.
Why does this work better for embeddings? Each token's output $h_t$ is a contextualized representation: it already contains information about the whole sequence through attention. Token 3 ("was") knows about "terrible" and "movie." So when you average all tokens, you're averaging many different contextual views of the same sequence, not just one.
Four tokens, embedding dimension d = 3. After the final transformer layer:
```text
h0 [CLS]      = [ 0.5,  0.3, -0.1]
h1 "The"      = [ 0.8, -0.2,  0.4]
h2 "movie"    = [ 0.1,  0.9,  0.3]
h3 "terrible" = [-0.3,  0.5,  0.6]
```
CLS pooling: Just take h0.
Result: [0.5, 0.3, -0.1]
Mean pooling: Average all four vectors element-wise.
Result: [0.275, 0.375, 0.300]
The two results are completely different. CLS gives you whatever the [CLS] token learned. Mean gives you a balanced summary across all tokens. For a sentiment task, [CLS] might work fine if it was trained for sentiment. But for a general-purpose embedding, the mean captures more information — it includes the signal from "terrible" (h3) and "movie" (h2), not just what [CLS] decided to remember.
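Both results can be verified with NumPy. A minimal sketch using the four token vectors from the example above:

```python
import numpy as np

hidden = np.array([
    [ 0.5,  0.3, -0.1],   # h0 [CLS]
    [ 0.8, -0.2,  0.4],   # h1 "The"
    [ 0.1,  0.9,  0.3],   # h2 "movie"
    [-0.3,  0.5,  0.6],   # h3 "terrible"
])

cls_vec = hidden[0]              # CLS pooling: take the first row
mean_vec = hidden.mean(axis=0)   # mean pooling: average over the token axis

print(cls_vec)                   # the [CLS] row: 0.5, 0.3, -0.1
print(np.round(mean_vec, 3))     # approximately 0.275, 0.375, 0.3
```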
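The Sentence-BERT paper (Reimers & Gurevych, 2019) compared CLS pooling, mean pooling, and max pooling for sentence embeddings in its ablation study. Mean pooling came out ahead under both training objectives, and max pooling fell far behind when training with the regression objective:
| Pooling Method | STS Benchmark (Spearman), trained on NLI | STS Benchmark (Spearman), trained on STSb | Notes |
|---|---|---|---|
| [CLS] token | 79.80 | 86.62 | Slightly behind mean pooling |
| Max pooling | 79.07 | 69.92 | Much worse with the regression objective |
| Mean pooling | 80.78 | 87.44 | Best in both settings |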
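Since then, mean pooling has been a common default for embedding models, including Sentence-BERT, E5, GTE, and Nomic Embed. Some models, such as BGE, instead use the [CLS] vector after training it contrastively to act as a sentence embedding. CLS pooling also remains standard for classification fine-tuning, where the [CLS] token is explicitly trained with a task-specific head.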
There's a critical implementation detail that trips up beginners. In a batch, sequences have different lengths. Shorter sequences are padded with [PAD] tokens to match the longest sequence:
```text
"The movie was terrible"    → [CLS] The movie was terrible [SEP] [PAD] [PAD]
"I loved it so much really" → [CLS] I loved it so much really [SEP]
```
The padding tokens still have output vectors, and those vectors are generally not zero and carry no meaning. If you naively average all positions including padding, you're diluting real signal with padding noise. Sentence 1's embedding is worse because two of its eight "token" outputs are junk.
The fix: use the attention mask to exclude padding.
$$\bar{h} = \frac{\sum_{t=1}^{T} \mathrm{mask}_t \, h_t}{\sum_{t=1}^{T} \mathrm{mask}_t}$$
Where $\mathrm{mask}_t$ = 1 for real tokens and 0 for padding. The denominator is the actual number of real tokens, not the padded sequence length.
Sequence with 3 real tokens and 1 padding token, d = 2:
```text
h0 "The" = [0.6, 0.4]    mask = 1
h1 "cat" = [0.8, 0.2]    mask = 1
h2 "sat" = [0.3, 0.7]    mask = 1
h3 [PAD] = [0.1, 0.05]   mask = 0
```
Without masking: (0.6+0.8+0.3+0.1)/4 = 0.45, (0.4+0.2+0.7+0.05)/4 = 0.3375 → [0.45, 0.3375]
With masking: (0.6+0.8+0.3)/3 = 0.567, (0.4+0.2+0.7)/3 = 0.433 → [0.567, 0.433]
Different results! The masked version correctly ignores the padding and gives a higher-magnitude, cleaner embedding. In production, this difference compounds across thousands of queries and degrades retrieval quality if you forget the mask.
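The same comparison in NumPy, as a minimal single-sequence sketch (no batch dimension), using the vectors and mask from the example:

```python
import numpy as np

hidden = np.array([
    [0.6, 0.4],    # "The"
    [0.8, 0.2],    # "cat"
    [0.3, 0.7],    # "sat"
    [0.1, 0.05],   # [PAD]
])
mask = np.array([1, 1, 1, 0])

naive = hidden.mean(axis=0)                                   # divides by 4
masked = (hidden * mask[:, None]).sum(axis=0) / mask.sum()    # divides by 3

print(np.round(naive, 4))    # approximately 0.45, 0.3375
print(np.round(masked, 3))   # approximately 0.567, 0.433
```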
A sequence of tokens, each with an embedding vector. Toggle between CLS (use first token) and Mean (average all tokens). The bottom bar shows the resulting sentence embedding. Click tokens to edit their values and see how each method responds.
Try adding padding tokens with the slider. In CLS mode, padding doesn't affect the result at all — we only use token 0. In Mean mode, watch the embedding shift as padding tokens dilute the average. This is why attention masking is essential for mean pooling in practice. The simulation applies masking automatically; toggle padding to see the difference.
```python
import torch

# CLS pooling: just slice
def cls_pool(hidden_states):
    """hidden_states: (batch, seq_len, d)"""
    return hidden_states[:, 0, :]  # (batch, d)

# Mean pooling with attention mask
def mean_pool(hidden_states, attention_mask):
    """
    hidden_states: (batch, seq_len, d)
    attention_mask: (batch, seq_len), 1 for real, 0 for pad
    """
    # Expand mask to (batch, seq_len, 1) so it broadcasts over d
    mask = attention_mask.unsqueeze(-1).float()  # (batch, seq_len, 1)
    # Zero out padding positions, then sum
    summed = (hidden_states * mask).sum(dim=1)   # (batch, d)
    # Divide by number of real tokens (not padded length!)
    counts = mask.sum(dim=1).clamp(min=1e-9)     # (batch, 1)
    return summed / counts                       # (batch, d)

# Example: batch of 2 sentences, max_len=6, d=4
hidden = torch.randn(2, 6, 4)
mask = torch.tensor([
    [1, 1, 1, 1, 0, 0],  # 4 real tokens, 2 padding
    [1, 1, 1, 1, 1, 1],  # 6 real tokens, no padding
])
cls_emb = cls_pool(hidden)          # (2, 4)
mean_emb = mean_pool(hidden, mask)  # (2, 4)
print(cls_emb.shape, mean_emb.shape)
```
The rule to remember: compute sum(h_t × mask_t) / sum(mask_t), never a plain mean over the padded length. Forgetting this is one of the most common bugs in embedding pipelines.
| Task | Best Pooling | Why |
|---|---|---|
| Text classification (fine-tuned) | [CLS] | Classification head is trained directly on [CLS]; it learns to encode task-specific info |
| Sentence embeddings / retrieval | Mean | Richer representation from all tokens; empirically better similarity scores |
| Named entity recognition | None (per-token) | Need per-token predictions; pooling would destroy position info |
| Semantic search | Mean | Every token contributes to meaning; [CLS] misses nuance |
| Zero-shot classification | Mean | Better generalization to unseen classes |
The pattern is clear: if you're fine-tuning for a specific task and can train the [CLS] token directly, CLS works. If you need general-purpose embeddings that transfer across tasks, mean pooling wins.
Mean pooling treats all tokens equally. But in "The movie was absolutely terrible," the words "absolutely terrible" carry all the sentiment. "The movie was" is filler. What if we could learn to weight the important tokens more heavily?
That's the idea behind attention pooling. Instead of assigning uniform weight 1/T to every position (mean pooling) or picking one winner (max pooling), we learn a query vector q that asks: "which tokens are relevant for my task?" Each token gets scored against this query, and the scores become weights via softmax. The final output is a weighted sum where important tokens dominate.
Think of q as a talent scout at an audition. Every token performs. The scout (q) evaluates each performance, assigns scores, and the final cast is dominated by the top performers — but everyone gets at least a small role (softmax never hits exactly zero).
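Given a sequence of token representations $h_1, h_2, \dots, h_T$, each a d-dimensional vector, attention pooling scores every token against the query, normalizes the scores with softmax, and takes the weighted sum:
$$s_t = \frac{q^\top h_t}{\sqrt{d}}, \qquad \alpha_t = \frac{\exp(s_t)}{\sum_{j=1}^{T} \exp(s_j)}, \qquad z = \sum_{t=1}^{T} \alpha_t \, h_t$$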
The query vector q is a learned parameter — it starts random and the training loss pushes it toward directions that identify task-relevant tokens. For sentiment analysis, q might learn to attend to adjectives and adverbs. For topic classification, it might attend to nouns and proper names.
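Let's trace through a concrete example. We have 4 tokens with d = 2 dimensional representations, h0 = [0.5, 0.3], h1 = [0.8, -0.2], h2 = [0.1, 0.9], h3 = [-0.3, 0.5], and a learned query vector q = [0.7, -0.5].
Step 1: Compute raw scores (dot products):
$$\begin{aligned} q \cdot h_0 &= (0.7)(0.5) + (-0.5)(0.3) = 0.20 \\ q \cdot h_1 &= (0.7)(0.8) + (-0.5)(-0.2) = 0.66 \\ q \cdot h_2 &= (0.7)(0.1) + (-0.5)(0.9) = -0.38 \\ q \cdot h_3 &= (0.7)(-0.3) + (-0.5)(0.5) = -0.46 \end{aligned}$$
Token 1 scores highest (0.66): its representation aligns best with the query. Tokens 2 and 3 score negative: they point away from what the query is looking for.
Step 2: Scale by √d = √2 ≈ 1.414:
$$s = [0.141,\; 0.467,\; -0.269,\; -0.325]$$
Step 3: Softmax to get weights:
exp(0.141) = 1.152, exp(0.467) = 1.595, exp(-0.269) = 0.764, exp(-0.325) = 0.722. Sum = 4.233.
$$\alpha \approx [0.272,\; 0.377,\; 0.181,\; 0.171]$$
Token 1 gets 37.7% of the weight, so the query treats it as most relevant. Token 0 gets 27.2%. Tokens 2 and 3 together only get about 35%. Compare this to mean pooling where every token gets exactly 25%.
Step 4: Weighted sum:
$$z = 0.272\,h_0 + 0.377\,h_1 + 0.181\,h_2 + 0.171\,h_3 \approx [0.404,\; 0.254]$$
Final output: [0.404, 0.254]. Compare with mean pooling: [0.275, 0.375]. The attention output is pulled toward token 1's representation ([0.8, -0.2]) because the query decided it was most important.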
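The four steps can be replayed in NumPy. This sketch uses the token vectors and the query q = [0.7, -0.5] that the method comparison later in this lesson lists for the same example; subtracting the maximum before `exp` is a standard numerical-stability step that does not change the weights.

```python
import numpy as np

H = np.array([[0.5, 0.3], [0.8, -0.2], [0.1, 0.9], [-0.3, 0.5]])  # (T, d)
q = np.array([0.7, -0.5])

scores = H @ q                           # Step 1: raw dot products
scaled = scores / np.sqrt(H.shape[1])    # Step 2: divide by sqrt(d)
weights = np.exp(scaled - scaled.max())  # Step 3: softmax (stabilized)
weights /= weights.sum()
pooled = weights @ H                     # Step 4: weighted sum, shape (d,)

print(np.round(scores, 2))    # approximately 0.2, 0.66, -0.38, -0.46
print(np.round(weights, 3))   # approximately 0.272, 0.377, 0.181, 0.171
print(np.round(pooled, 2))    # approximately 0.40, 0.25
```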
The simulation below shows attention pooling on 6 tokens. Adjust the query direction with the slider — watch how the attention weights shift, how different tokens light up, and how the output vector changes compared to uniform mean pooling.
Drag the query angle to change which tokens the attention focuses on. Bar heights show attention weights. The output vector (gold) shifts toward the highest-weighted tokens. Gray dashed line shows mean pooling for comparison.
```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class AttentionPooling(nn.Module):
    """Single-query attention pooling.

    Input:  (B, T, D), a batch of token sequences
    Output: (B, D), one vector per sequence
    """
    def __init__(self, d_model):
        super().__init__()
        # Learned query: what to look for
        self.query = nn.Parameter(torch.randn(d_model))

    def forward(self, x, mask=None):
        # x: (B, T, D), self.query: (D,)
        scores = torch.matmul(x, self.query)    # (B, T)
        scores = scores / (x.size(-1) ** 0.5)   # scale by sqrt(d)
        if mask is not None:
            # mask: (B, T) bool tensor, True for real tokens
            scores = scores.masked_fill(~mask, float('-inf'))
        weights = F.softmax(scores, dim=-1)     # (B, T)
        return torch.sum(weights.unsqueeze(-1) * x, dim=1)  # (B, D)

# Test it
pool = AttentionPooling(d_model=64)
x = torch.randn(2, 10, 64)  # 2 sequences, 10 tokens, 64-dim
out = pool(x)
print(out.shape)  # torch.Size([2, 64])

# With padding mask (tokens 8,9 are padding for sequence 0)
mask = torch.ones(2, 10, dtype=torch.bool)
mask[0, 8:] = False
out_masked = pool(x, mask=mask)  # Padding tokens get zero weight
```
Notice the mask parameter: in real NLP, sequences have different lengths and are padded. Without masking, attention would waste weight on padding tokens. Filling the padded scores with -inf via masked_fill makes softmax assign exactly zero weight to padded positions.
| Model / Paper | How it uses attention pooling |
|---|---|
| Set Transformer (Lee et al., 2019) | PMA (Pooling by Multi-head Attention) — uses k learned "seed" vectors as queries to extract k summary vectors from a set |
| Multi-Instance Learning | Aggregate bag-of-instances into a single bag representation; attention weights tell you which instances matter |
| Document Classification | Pool sentence embeddings into document embedding; attention weight = sentence importance |
| Sentence-BERT (optional) | Weighted aggregation of token embeddings, outperforms mean pooling on some tasks |
Attention pooling with one query asks one question. But what if you need multiple aspects of the sequence — sentiment AND topic AND entity? Multi-head pooling asks multiple questions simultaneously. And for autoregressive models like GPT, there's an even simpler answer: just use the last token.
This chapter covers three more aggregation strategies that go beyond simple attention pooling. Each makes a different bet about where information concentrates in the sequence — and each bet turns out to be right for different architectures.
If one query extracts one aspect of the sequence, why not use K queries to extract K different aspects? That's multi-head attention pooling, used in the Set Transformer (Lee et al., 2019). Each query "head" attends to different token patterns, producing K output vectors. You can concatenate them (getting a K·d vector) or average them (back to d).
Head 1 might attend to the subject of a sentence. Head 2 might attend to the action. Head 3 might attend to negation words. Together, they capture a richer summary than any single query could.
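A stripped-down sketch of the idea follows. It uses K independent query vectors and nothing else, so it is simpler than the Set Transformer's PMA block (which adds learned projections and a feed-forward layer); `multi_query_pool` is a hypothetical name, and the random queries stand in for learned parameters.

```python
import numpy as np

def softmax(z, axis=-1):
    z = z - z.max(axis=axis, keepdims=True)   # numerical stability
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def multi_query_pool(H, Q):
    """H: (T, d) token vectors, Q: (K, d) query vectors -> (K, d) summaries."""
    d = H.shape[1]
    weights = softmax(Q @ H.T / np.sqrt(d), axis=-1)  # (K, T), rows sum to 1
    return weights @ H                                 # (K, d)

rng = np.random.default_rng(0)
H = rng.normal(size=(6, 4))   # 6 tokens, d = 4
Q = rng.normal(size=(3, 4))   # K = 3 queries (random stand-ins)

heads = multi_query_pool(H, Q)
concat = heads.reshape(-1)       # concatenate heads: (K*d,)
averaged = heads.mean(axis=0)    # average heads: (d,)
print(heads.shape, concat.shape, averaged.shape)  # (3, 4) (12,) (4,)
```

Concatenation keeps each head's view separate at the cost of a K times larger vector; averaging keeps the dimension at d but blends the heads together.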
For causal (decoder) models like GPT, there's a beautifully simple pooling strategy: just take the last token's representation.
Why does this work? In causal attention, each token can only attend to previous tokens. Token 0 sees only itself. Token 1 sees tokens 0 and 1. The last token is the only position that has attended to the entire sequence. Its representation is already a summary of everything that came before it.
The operation is trivial, with no parameters and no computation beyond indexing:
$$z = h_T$$
For variable-length sequences with padding on the right, you need the actual last non-padding position — not just the final array element.
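A minimal NumPy sketch of that indexing, assuming right padding and a 0/1 attention mask; `last_token_pool` is a hypothetical helper name.

```python
import numpy as np

def last_token_pool(hidden, mask):
    """hidden: (B, T, d); mask: (B, T) with 1 for real tokens, right-padded."""
    last_idx = mask.sum(axis=1) - 1   # index of the last real token per row
    return hidden[np.arange(hidden.shape[0]), last_idx]   # (B, d)

hidden = np.arange(2 * 4 * 2, dtype=float).reshape(2, 4, 2)
mask = np.array([[1, 1, 1, 0],    # 3 real tokens -> use position 2
                 [1, 1, 1, 1]])   # 4 real tokens -> use position 3

pooled = last_token_pool(hidden, mask)
print(pooled)   # rows [4, 5] and [14, 15]
```

With left padding the last real token always sits at index -1, so plain `hidden[:, -1]` would suffice; the mask arithmetic is only needed for right padding.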
What if each token could independently decide its own importance, without consulting a shared query? That's gated pooling. Each token passes through a small gating network that outputs a scalar weight. With a single linear layer as the gate:
$$g_t = \sigma\left(w^\top h_t + b\right), \qquad z = \frac{\sum_{t=1}^{T} g_t \, h_t}{\sum_{t=1}^{T} g_t}$$
Here σ is the sigmoid function, so each gate $g_t$ is between 0 and 1. We divide by the sum of gates to normalize. Unlike attention pooling, there's no shared query: each token decides independently whether it's important. The gate learns to suppress noise tokens and amplify signal tokens.
The difference is subtle but real: attention pooling uses a shared query to compare tokens against a template. Gated pooling uses each token's own content to decide its importance. Both produce weighted averages — they just compute the weights differently.
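A small NumPy sketch shows how the gate reshapes the average. It assumes a single linear gate followed by a sigmoid; the gate weights below are hand-picked, hypothetical values rather than learned ones.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def gated_pool(H, w, b):
    """H: (T, d); w: (d,); b: scalar. Returns a gate-weighted average, shape (d,)."""
    g = sigmoid(H @ w + b)   # (T,), one gate per token, each in (0, 1)
    return (g[:, None] * H).sum(axis=0) / g.sum()

H = np.array([[0.5, 0.3], [0.8, -0.2], [0.1, 0.9], [-0.3, 0.5]])

# With w = 0 and b = 0 every gate is 0.5, so the result equals mean pooling.
uniform = gated_pool(H, np.zeros(2), 0.0)
# Hypothetical weights that favor tokens with a large second coordinate.
biased = gated_pool(H, np.array([0.0, 4.0]), 0.0)

print(np.round(uniform, 3))   # approximately 0.275, 0.375
print(np.round(biased, 3))    # second coordinate pulled above 0.375
```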
Let's apply all our pooling methods to the same 4 tokens from Chapter 4. Same representations: h0 = [0.5, 0.3], h1 = [0.8, -0.2], h2 = [0.1, 0.9], h3 = [-0.3, 0.5].
| Method | Output | What it captures |
|---|---|---|
| Mean | [0.275, 0.375] | Average of all tokens equally |
| Max (per-dim) | [0.8, 0.9] | Strongest feature per dimension |
| CLS (= h0) | [0.5, 0.3] | Dedicated summary position |
| Attention (q=[0.7,-0.5]) | [0.405, 0.254] | Query-relevant weighted sum |
| Last-token (= h3) | [-0.3, 0.5] | Final position (full causal context) |
Five methods, five different answers. Each is "correct" — they just extract different information from the same input. Mean smooths everything. Max grabs peaks. CLS and last-token pick one position. Attention learns what matters.
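The table rows can be checked with plain Python. The sketch below assumes the attention row uses dot-product scores scaled by √d before the softmax; under that assumption the result agrees with the table to within rounding.

```python
import math

tokens = [[0.5, 0.3], [0.8, -0.2], [0.1, 0.9], [-0.3, 0.5]]
query = [0.7, -0.5]
d = len(query)

# Parameter-free methods
mean_pool = [sum(col) / len(tokens) for col in zip(*tokens)]
max_pool = [max(col) for col in zip(*tokens)]
cls_pool = tokens[0]
last_pool = tokens[-1]

# Attention pooling: scaled dot-product scores, softmax over the 4 tokens
scores = [sum(q * h for q, h in zip(query, tok)) / math.sqrt(d) for tok in tokens]
exps = [math.exp(s) for s in scores]
weights = [e / sum(exps) for e in exps]
attn_pool = [sum(w * tok[j] for w, tok in zip(weights, tokens)) for j in range(d)]

for name, out in [("Mean", mean_pool), ("Max", max_pool), ("CLS", cls_pool),
                  ("Attention", attn_pool), ("Last", last_pool)]:
    print(f"{name:10s} {[round(v, 3) for v in out]}")
```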
The simulation shows all five pooling methods applied to the same token sequence. Each method produces a different output point in 2D space. Edit the token vectors by dragging, and watch all five outputs update simultaneously.
All five methods pool the same 6 tokens. Colored dots show each method's output in 2D. Drag tokens (circles) to rearrange. Watch how each method responds differently to the same change.
```python
import torch
import torch.nn as nn


class MeanPool(nn.Module):
    def forward(self, x, mask=None):  # x: (B, T, D), mask: (B, T)
        if mask is not None:
            m = mask.unsqueeze(-1).float()  # (B, T, 1)
            # Sum real tokens only, then divide by the real-token count
            return (x * m).sum(dim=1) / m.sum(dim=1).clamp(min=1e-9)
        return x.mean(dim=1)


class MaxPool(nn.Module):
    def forward(self, x, mask=None):
        if mask is not None:
            # Cast to bool so ~ is a logical NOT even for 0/1 integer masks
            x = x.masked_fill(~mask.bool().unsqueeze(-1), float('-inf'))
        return x.max(dim=1).values


class CLSPool(nn.Module):
    def forward(self, x, mask=None):
        return x[:, 0]  # First token


class LastTokenPool(nn.Module):
    def forward(self, x, mask=None):
        if mask is not None:
            # Index of the last real token per sequence (assumes right padding)
            last = mask.long().sum(dim=1) - 1
            return x[torch.arange(x.size(0), device=x.device), last]
        return x[:, -1]


class GatedPool(nn.Module):
    def __init__(self, d_model):
        super().__init__()
        self.gate = nn.Linear(d_model, 1)

    def forward(self, x, mask=None):
        g = torch.sigmoid(self.gate(x))  # (B, T, 1)
        if mask is not None:
            g = g * mask.unsqueeze(-1).float()
        return (g * x).sum(dim=1) / g.sum(dim=1).clamp(min=1e-8)


# Compare all methods on the same input
x = torch.randn(1, 10, 64)
pools = {
    'Mean': MeanPool(),
    'Max': MaxPool(),
    'CLS': CLSPool(),
    'Last': LastTokenPool(),
    'Gated': GatedPool(64),
}
for name, pool in pools.items():
    out = pool(x)
    print(f"{name:6s}: mean={out.mean():.3f}, std={out.std():.3f}")

# Illustrative output (the input is unseeded, so exact values vary per run):
# Mean:  mean=-0.012, std=0.298  (smoothed, low variance)
# Max:   mean= 1.782, std=0.421  (positive, high values)
# CLS:   mean=-0.034, std=0.986  (single token, full variance)
# Last:  mean= 0.087, std=1.024  (single token, full variance)
# Gated: mean=-0.008, std=0.315  (similar to mean, slightly adapted)
```
You've learned six different ways to compress a variable-length input into a fixed-size output. Time to see them all in action on the same data — and discover what each one sees that the others miss.
This is your aggregation laboratory. Switch between vision mode (2D feature map) and NLP mode (token sequence). Pick any pooling method. Add noise. Watch which methods survive the noise and which crumble. The goal: build intuition for when to reach for which tool.
Switch between vision and NLP mode. Select a pooling method. Add noise to see robustness. The information retention bar shows how much input variance survives the pooling operation.
Try these experiments:
Experiment 1 — Noise robustness. Set noise to 0. Note the output. Crank noise to 2.0. Max pooling's output jumps wildly (noise creates outliers that dominate). Mean pooling barely budges (noise averages out). Attention pooling adapts — it down-weights noisy positions if they don't match the query.
Experiment 2 — Vision vs NLP. Switch modes. In vision mode, local pooling (2×2, 3×3) reduces spatial resolution. You can see how max pooling preserves bright spots while mean pooling smooths the whole grid. In NLP mode, there's no spatial locality — the pooling is purely over the token dimension.
Experiment 3 — Information retention. Watch the retention bar at the bottom. GAP (Global Average Pooling) has the lowest retention — it crushes an entire feature map to one number per channel. Attention pooling and gated pooling retain more because they can selectively preserve important information.
Experiment 4 — CLS vs Last. In NLP mode, CLS and last-token give completely different outputs from the same input. CLS takes position 0 (which might be garbage if there's no trained [CLS] embedding). Last-token takes the final position. Neither uses the full sequence unless the upstream model has already mixed information across positions.
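Experiment 1 can be approximated offline. The sketch below is not the interactive widget: it assumes additive Gaussian noise with standard deviation 2.0 on 64 random tokens, and measures how far the mean-pooled and max-pooled outputs move on average.

```python
import torch

torch.manual_seed(0)
clean = torch.randn(64, 32)                  # 64 tokens, 32 features (assumed sizes)
noisy = clean + 2.0 * torch.randn(64, 32)    # additive Gaussian noise, std 2.0

# Average absolute change of each pooled output, across the 32 features
mean_shift = (noisy.mean(dim=0) - clean.mean(dim=0)).abs().mean().item()
max_shift = (noisy.max(dim=0).values - clean.max(dim=0).values).abs().mean().item()

print(f"mean pooling shift: {mean_shift:.2f}")
print(f"max pooling shift:  {max_shift:.2f}")
```

Averaging over 64 tokens shrinks the noise by roughly a factor of 8, while the max is pulled toward whichever noisy token happens to be the largest outlier.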
What if your model needs to handle images of different sizes? A 224×224 image produces a 7×7 feature map after the convolutional backbone. A 384×384 image produces a 12×12. But the classifier head needs a fixed-size input. You can't change the number of neurons in a linear layer at inference time.
Adaptive pooling solves this: "give me a 1×1 output no matter what size comes in." The pooling window and stride are computed dynamically based on the input size, so the output size is always what you asked for. Most modern CNN implementations use this. It's why torchvision's ResNet, for example, accepts inputs of different resolutions.
Regular pooling says: "use a 2×2 window with stride 2." If your input is 7×7, you get 3×3 output (the default floor rounding drops the last row and column; `ceil_mode=True` would give 4×4 instead). If your input is 12×12, you get 6×6. The output size depends on the input size.
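Adaptive pooling says: "give me a 1×1 output." Then it computes the windows that achieve this. In PyTorch, output index $i$ along an axis pools the input indices from $\lfloor i \cdot H_{in} / H_{out} \rfloor$ up to, but not including, $\lceil (i+1) \cdot H_{in} / H_{out} \rceil$.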
The most common use: `nn.AdaptiveAvgPool2d((1,1))`, which is global average pooling with adaptive sizing. Input: any (C, H, W). Output: always (C, 1, 1). Squeeze the spatial dimensions, and you have a C-dimensional vector ready for a classifier.
AdaptiveAvgPool on a 6×6 input, target output 2×2:
$H_{in} = 6$, $H_{out} = 2$. Each output cell covers 6/2 = 3 input rows and 3 input columns.
Now with a 7×7 input, target 2×2: 7/2 = 3.5, so the windows can't tile the input evenly. In PyTorch, output (0,0) covers input rows [0:4] and output (1,0) covers input rows [3:7]: both windows span 4 rows and overlap on row 3. The adaptive algorithm handles this automatically.
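The window boundaries can be computed by hand. This sketch uses the floor-start, ceil-end index rule that PyTorch's adaptive pooling follows along each axis; the helper name `adaptive_windows` is made up for this example.

```python
def adaptive_windows(in_size: int, out_size: int):
    """Return the [start, end) input range for each output cell along one axis."""
    windows = []
    for i in range(out_size):
        start = (i * in_size) // out_size               # floor(i * in / out)
        end = -((-(i + 1) * in_size) // out_size)       # ceil((i + 1) * in / out)
        windows.append((start, end))
    return windows


for in_size, out_size in [(6, 2), (7, 2), (7, 1), (12, 3)]:
    print(f"{in_size} -> {out_size}: {adaptive_windows(in_size, out_size)}")
```

When the input size is a multiple of the output size the windows tile the axis exactly; otherwise neighbouring windows share a row or column.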
Here's a beautiful idea from image retrieval (Radenovic et al., 2018): what if we could interpolate between average pooling and max pooling with a single learnable parameter?
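$$\mathrm{GeM}(x) = \left( \frac{1}{N} \sum_{i=1}^{N} x_i^{\,p} \right)^{1/p}$$
Here $x_1, \dots, x_N$ are the non-negative activations of one channel across its N spatial positions. When p = 1, this is exactly the arithmetic mean (average pooling). As p grows, the result is increasingly dominated by larger values. In the limit p → ∞, it converges to the max. The parameter p is learned: the network discovers whether max or average (or something in between) is best for its task.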
Values: [1.0, 0.5, 2.0, 0.3], power p = 3.
Step 1: Raise each value to power p:
1.0³ = 1.000, 0.5³ = 0.125, 2.0³ = 8.000, 0.3³ = 0.027
Step 2: Average the powered values:
Mean = (1.000 + 0.125 + 8.000 + 0.027) / 4 = 9.152 / 4 = 2.288
Step 3: Take the p-th root:
$2.288^{1/3} \approx 1.318$
Compare:
| Method | Result | Interpretation |
|---|---|---|
| Average (p=1) | 0.950 | All values contribute equally |
| GeM (p=3) | 1.318 | Biased toward larger values |
| GeM (p=10) | 1.741 | Strongly biased toward max |
| Max (p→∞) | 2.000 | Only the largest survives |
As p increases, the output slides smoothly from mean toward max. At p = 3, the value 2.0 already dominates: it contributes 8.0 out of 9.152 to the sum of cubes. The 0.3 contributes almost nothing (0.027). GeM naturally amplifies strong features and suppresses weak ones — but smoothly, not all-or-nothing like max.
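The slide from mean toward max is easy to reproduce with plain Python. The helper `gem` below is a scalar sketch of the generalized mean for illustration, not a pooling layer.

```python
def gem(values, p):
    """Generalized mean of positive values with exponent p."""
    return (sum(v ** p for v in values) / len(values)) ** (1.0 / p)


values = [1.0, 0.5, 2.0, 0.3]
for p in (1, 3, 10, 100):
    print(f"p={p:>3}: {gem(values, p):.3f}")
print(f"max  : {max(values):.3f}")

# Share of the p=3 sum contributed by each value
cubes = [v ** 3 for v in values]
shares = [c / sum(cubes) for c in cubes]
print([round(s, 3) for s in shares])
```

The shares make the dominance concrete: at p = 3 the value 2.0 supplies roughly 87% of the sum of cubes.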
Drag the power slider from 0.5 to 10. Watch the GeM output smoothly interpolate between mean pooling (p=1) and max pooling (p large). Below the visualization: a bar chart showing how much each input value contributes to the final result.
Adjust p to see how GeM interpolates between average and max. At p=1 all bars are equal (average). As p grows, large values dominate. The top shows a feature map; the bottom shows effective contribution weights.
```python
import torch
import torch.nn as nn

# --- Adaptive pooling (built-in, but here's what it does) ---
x = torch.randn(1, 64, 7, 7)     # from 224×224 input
y = torch.randn(1, 64, 12, 12)   # from 384×384 input

gap = nn.AdaptiveAvgPool2d((1, 1))
print(gap(x).shape)  # torch.Size([1, 64, 1, 1])
print(gap(y).shape)  # torch.Size([1, 64, 1, 1]), same output size

# Adaptive to 3×3 (used in SPPNet-style architectures)
spp = nn.AdaptiveMaxPool2d((3, 3))
print(spp(x).shape)  # torch.Size([1, 64, 3, 3])
print(spp(y).shape)  # torch.Size([1, 64, 3, 3]), same again


# --- GeM (Generalized Mean Pooling) from scratch ---
class GeM(nn.Module):
    """Generalized Mean Pooling.

    p=1: average pooling. p→∞: max pooling.
    Input: (B, C, H, W)  Output: (B, C, 1, 1)
    """

    def __init__(self, p=3.0, eps=1e-6):
        super().__init__()
        # p is learnable!
        self.p = nn.Parameter(torch.tensor(p))
        self.eps = eps

    def forward(self, x):
        # Clamp to eps so fractional powers stay defined; negative
        # activations are treated as (almost) zero
        x_clamped = x.clamp(min=self.eps)
        # Raise to power p, average over H and W, take the p-th root
        return (x_clamped.pow(self.p)
                .mean(dim=(-2, -1), keepdim=True)
                .pow(1.0 / self.p))


# Use it in a retrieval model
gem = GeM(p=3.0)
features = torch.randn(4, 2048, 7, 7)
pooled = gem(features).squeeze(-1).squeeze(-1)  # (4, 2048)
print(f"p learned value: {gem.p.item():.2f}")   # starts at 3.0
# After training on landmarks: p ≈ 2.5-4.0 typically
# Different channels may benefit from different p values
```
`nn.AdaptiveAvgPool2d((1,1))` on a 7×7 input uses a 7×7 window. On a 14×14 input, a 14×14 window. The pooling operation is identical to regular pooling; only the window sizing is automatic. There's no neural network and no learned parameters: it's pure arithmetic that adapts the kernel to hit your target output size.
One more modern trick: instead of a fixed pooling operation, use a strided convolution. A regular convolution with stride 2 halves the spatial resolution — just like 2×2 max pooling. But the convolution kernel is learned, so the network decides how to downsample instead of being forced into max or average.
ResNet uses max pooling early on. Many modern architectures (ConvNeXt, EfficientNet-v2) replace this with a strided conv: one less design choice to make, and the model can learn what matters. The tradeoff: strided convolutions add parameters and computation, while max/avg pooling has no parameters and negligible compute cost.
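A quick side-by-side makes the tradeoff visible. The channel count, kernel size, and input resolution below are arbitrary choices for illustration, not values taken from any particular architecture.

```python
import torch
import torch.nn as nn

x = torch.randn(1, 64, 56, 56)

max_pool = nn.MaxPool2d(kernel_size=2, stride=2)
# 3×3 kernel, stride 2, padding 1: halves H and W like the pool above
strided_conv = nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1, bias=False)


def n_params(module):
    """Count the learnable parameters in a module."""
    return sum(p.numel() for p in module.parameters())


pool_out = max_pool(x)
conv_out = strided_conv(x)
print(pool_out.shape, n_params(max_pool))
print(conv_out.shape, n_params(strided_conv))
```

Both layers halve the spatial resolution, but the convolution carries k²·C² = 9 · 64 · 64 learned weights while the pooling layer carries none.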
| Approach | Parameters | Learns? | Use case |
|---|---|---|---|
| Max Pool | 0 | No | Classic CNNs (VGG, ResNet) |
| Avg Pool | 0 | No | Smooth downsampling, GAP for classifiers |
| Adaptive Pool | 0 | No | Resolution-agnostic final layer |
| GeM | 1 (p) | Yes | Image retrieval, fine-grained recognition |
| Strided Conv | k²·C² | Yes | Modern architectures replacing fixed pooling |
| Attention Pool | d (query) | Yes | Sequence aggregation, set aggregation |
Time to put all eight pooling methods head-to-head. Pick a task below — Classification, Similarity, or Retrieval — and watch each method race to show its strengths (and weaknesses).
What you're seeing: Each bar represents one pooling method. The bar length shows a composite score combining discriminability, robustness, and speed — weighted differently for each task. The star marks the winner.
No single method wins everything. That's the lesson. The right pooling depends on what you're optimizing for — and the best practitioners choose deliberately.
Everything from this lesson on one page. Pin it, print it, screenshot it.
| Method | Learnable? | Best For | Weakness | FLOPs |
|---|---|---|---|---|
| Max Pool | No | Edge/texture detection (CNNs) | Discards all non-maximal activations; sensitive to outliers | O(n) |
| Average Pool | No | Smooth features, denoising | Dilutes strong signals | O(n) |
| GAP | No | CNN classification (ResNet, EfficientNet) | Treats all spatial positions equally | O(HW) |
| CLS Token | Yes (implicit) | BERT-style fine-tuning | Requires [CLS] pretraining; single bottleneck | O(1) at pool time |
| Mean Pool | No | Sentence embeddings (SBERT) | Padding tokens dilute signal (mask them!) | O(n) |
| Attention Pool | Yes | Multi-instance learning, set classification | Extra parameters; can overfit on small data | O(nd) |
| GeM | Yes (p) | Image retrieval, place recognition | Assumes positive activations (post-ReLU) | O(n) |
| Last Token | No | Causal LMs (GPT-style) | Only sees leftward context in final position | O(1) at pool time |
```python
import torch
import torch.nn as nn

# Example inputs so the snippets run as-is (shapes are arbitrary)
x = torch.randn(2, 64, 8, 8)            # [B, C, H, W]
hidden_states = torch.randn(2, 10, 32)  # [B, T, D]

# ── Max Pool (2D, stride=2) ──
pool = nn.MaxPool2d(kernel_size=2, stride=2)
out = pool(x)  # [B, C, H/2, W/2]

# ── Global Average Pooling ──
gap = nn.AdaptiveAvgPool2d(1)
out = gap(x).squeeze(-1).squeeze(-1)  # [B, C]

# ── Mean Pooling (with mask) ──
def mean_pool(hidden, mask):
    mask_exp = mask.unsqueeze(-1).float()
    return (hidden * mask_exp).sum(1) / mask_exp.sum(1).clamp(min=1e-9)

# ── CLS Token ──
cls_embed = hidden_states[:, 0, :]  # [B, D]

# ── Attention Pooling ──
class AttentionPool(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.query = nn.Parameter(torch.randn(dim))

    def forward(self, x):  # x: [B, N, D]
        scores = (x @ self.query) / x.size(-1) ** 0.5
        weights = scores.softmax(dim=1).unsqueeze(-1)
        return (x * weights).sum(1)

# ── GeM Pooling ──
class GeM(nn.Module):
    def __init__(self, p=3.0):
        super().__init__()
        self.p = nn.Parameter(torch.tensor(p))

    def forward(self, x):  # x: [B, C, H, W] -> [B, C]
        return x.clamp(min=1e-6).pow(self.p).mean(dim=[-2, -1]).pow(1 / self.p)

# ── Last Token (causal LM) ──
# With right padding, index the last non-padding position instead
last = hidden_states[:, -1, :]  # [B, D]
```
You're building a semantic search engine using a BERT model that was pretrained without a [CLS] token objective. Your inputs are variable-length sentences. Which pooling strategy should you use?