Training Foundations

Pooling & Aggregation

From max pooling to GeM — how to collapse variable-length features into the fixed-size vectors that power classification, retrieval, and embeddings.

Prerequisites: What a feature map / token embedding is + Mean and max operations. That’s it.
10 chapters · 12+ simulations · 0 assumed knowledge

Chapter 0: The Dimension Problem

A CNN outputs a 7×7×512 feature map — 25,088 numbers describing an image at different spatial locations. But the classifier needs ONE vector: is this a cat or a dog? A BERT encoder outputs 512 vectors, one per token. But the search engine needs ONE embedding for the whole sentence.

How do you go from many vectors to one?

This is the dimension problem. Neural networks produce outputs at every spatial position (CNNs give you H×W feature maps) or at every token position (transformers give you seq_len × d). But downstream tasks — classification, retrieval, regression — need a single fixed-size vector. One prediction per image. One embedding per sentence.

We need to collapse a variable-length dimension into a fixed representation. The question isn't whether to collapse — it's how to collapse without losing critical information.

What Does a Feature Map Look Like?

Let's make this concrete. A convolutional layer with 512 filters applied to a 7×7 spatial grid produces a tensor of shape (7, 7, 512). Each spatial position (i, j) holds a 512-dimensional vector describing what the network "sees" at that location. Position (0, 0) might describe the top-left corner — maybe a patch of sky. Position (3, 3) might describe the center — maybe the cat's nose.

All 49 positions carry information. The challenge: compress those 49 vectors into one, preserving what matters for the task.

The core tension. If you only look at one spatial position, you might miss the cat's ear. If you look at all positions equally, you might drown the cat's nose in a sea of background. Every pooling method is a different answer to: "Which positions matter, and how much?"

Why Not Just Flatten?

The naive approach: reshape the 7×7×512 tensor into a single vector of length 25,088 and feed it to a fully-connected layer. Early CNNs (AlexNet, VGG) did exactly this. It works — but with devastating consequences.

Problem 1: Size depends on input resolution. If your input image is 224×224, the feature map is 7×7 and the flattened vector is 25,088. If your input is 448×448, the feature map is 14×14 and the flattened vector is 100,352. The FC layer has a fixed weight matrix — it can't handle different input sizes. You're locked to one resolution forever.

Problem 2: Parameter explosion. VGG-16 has a 7×7×512 = 25,088 vector feeding into a 4,096-neuron FC layer. That's 25,088 × 4,096 = 102 million parameters in a single layer. The entire convolutional backbone has only 14 million. The FC layers dominate the model, and most of those parameters overfit.

Problem 3: No spatial invariance. Flattening assigns a separate weight to each spatial position. If the cat shifts 2 pixels right, completely different weights activate. The FC layer must re-learn "cat" at every possible position — a waste of capacity.

Hand Calculation: Flatten vs Pool

Let's see the parameter cost difference with real numbers.

Approach         | Feature Map  | Vector Size | FC Params (to 1000 classes)
Flatten (7×7)    | 7×7×512      | 25,088      | 25,088 × 1000 = 25.1M
Flatten (14×14)  | 14×14×512    | 100,352     | 100,352 × 1000 = 100.4M
Global Pool      | Any H×W×512  | 512         | 512 × 1000 = 0.5M

Global pooling collapses any spatial size to 512 values: about 49× fewer classifier parameters than flattening a 7×7 map, and 196× fewer than flattening a 14×14 map. That's not an optimization; it's a fundamentally different approach.
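
As a rough check on the table, the sketch below recomputes the weight counts. The helper names `flattened_size` and `fc_weights` are made up for this illustration, and biases are ignored, as in the table.

```python
def flattened_size(h, w, channels=512):
    """Length of the vector you get by flattening an h×w×channels feature map."""
    return h * w * channels

def fc_weights(vector_size, num_classes=1000):
    """Weights in one fully-connected layer (biases ignored)."""
    return vector_size * num_classes

flat_7 = fc_weights(flattened_size(7, 7))      # flatten a 7×7×512 map
flat_14 = fc_weights(flattened_size(14, 14))   # flatten a 14×14×512 map
pooled = fc_weights(512)                       # global pooling: one value per channel

print(flat_7, flat_14, pooled)                 # 25088000 100352000 512000
print(flat_7 // pooled, flat_14 // pooled)     # 49 196
```

The exact ratios come out to 49 and 196, and the pooled count does not change with input resolution.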

The Three Methods

This lesson covers three families of pooling. Max pooling picks the strongest activation — if any neuron in a region fires, we keep it. Average pooling takes the mean — every position contributes equally. And sequence pooling handles the 1D case in transformers, where we collapse tokens instead of spatial positions.

Let's see them in action before we formalize anything.

Pooling Explorer

A 4×4 feature map. Click any cell to change its value. Toggle between Max, Average, and ??? pooling to see how the output changes. Which method preserves the strongest signal? Which smooths everything out?

Notice: when you click around and create one very high value, max pooling immediately grabs it. The output jumps. Average pooling barely moves — one outlier among many gets diluted. The "???" mode is a preview of something we'll build later (weighted pooling), but for now, focus on max vs average. They answer different questions about the data.

Pooling is NOT just "downsampling." Downsampling (e.g., stride-2 convolution) reduces resolution but keeps per-position features. Pooling aggregates information from multiple positions into one — it's a fundamentally different operation. Stride reduces WHERE you sample. Pooling combines WHAT you sample.
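
To make the distinction concrete, here is a small NumPy sketch. It uses plain strided subsampling as a stand-in for "stride" (a simplification: a real stride-2 convolution also applies learned weights), and compares it with 2×2 max pooling on a map that has one strong activation.

```python
import numpy as np

x = np.zeros((4, 4))
x[1, 1] = 5.0  # a single strong activation at an odd-indexed position

# Strided subsampling: keep every second position and ignore the rest.
subsampled = x[::2, ::2]

# 2x2 max pooling: every position in a block can influence the output.
# The reshape groups the grid into (block_row, row_in_block, block_col, col_in_block).
pooled = x.reshape(2, 2, 2, 2).max(axis=(1, 3))

print(subsampled.max())  # 0.0, the spike was never sampled
print(pooled[0, 0])      # 5.0, the spike survives aggregation
```

Subsampling never looks at position (1, 1), so the spike disappears. Pooling looks at all four positions of the block and keeps it.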

Hand Calculation: Max vs Average

A 4×4 feature map, pooled with 2×2 windows and stride 2. Each window collapses 4 values into 1.

Input:

values
0.2  0.8  |  0.1  0.9
0.5  0.3  |  0.7  0.4
-----|-----
0.6  0.1  |  0.3  0.5
0.4  0.9  |  0.2  0.8

Max pooling (2×2, stride 2): take the max of each block.

Output: [[0.8, 0.9], [0.9, 0.8]]. Max preserves the peaks.

Average pooling (2×2, stride 2): take the mean of each block.

Output: [[0.45, 0.525], [0.50, 0.45]]. Average produces a smoother summary.

See the difference? Max output ranges from 0.8 to 0.9 — it kept the strongest signals. Average output ranges from 0.45 to 0.525 — it smoothed everything toward the mean. If you care about whether a feature exists, max is your friend. If you care about the overall level, average tells a better story.

From Scratch in Code

python
import numpy as np

def max_pool_2x2(x):
    # x: (H, W) feature map, H and W must be even
    H, W = x.shape
    out = np.zeros((H // 2, W // 2))
    for i in range(0, H, 2):
        for j in range(0, W, 2):
            block = x[i:i+2, j:j+2]
            out[i//2, j//2] = np.max(block)
    return out

def avg_pool_2x2(x):
    # Same structure, but mean instead of max
    H, W = x.shape
    out = np.zeros((H // 2, W // 2))
    for i in range(0, H, 2):
        for j in range(0, W, 2):
            block = x[i:i+2, j:j+2]
            out[i//2, j//2] = np.mean(block)
    return out

# Test with our hand calculation
x = np.array([
    [0.2, 0.8, 0.1, 0.9],
    [0.5, 0.3, 0.7, 0.4],
    [0.6, 0.1, 0.3, 0.5],
    [0.4, 0.9, 0.2, 0.8]
])
print(max_pool_2x2(x))  # [[0.8, 0.9], [0.9, 0.8]]
print(avg_pool_2x2(x))  # [[0.45, 0.525], [0.5, 0.45]]
Why can't we just flatten the feature map and feed it to a classifier?

Chapter 1: Max Pooling

Imagine you're scanning a photo for edges. Your CNN has 64 edge detectors, each producing a number at every spatial position. One detector fires strongly at position (3, 4) — there's a vertical edge there. The neighboring positions have weaker responses: maybe 0.1, 0.3, 0.2.

Max pooling says: "I don't care WHERE in this 2×2 region the edge is. I only care that it EXISTS." It grabs the strongest response and discards the rest. If any neuron in the pooling window detected a feature, the max preserves that detection.

This gives max pooling a powerful property: translation invariance within the pooling window. Shift the edge one pixel, and as long as it stays inside the same 2×2 window, the max pool output doesn't change. The network becomes robust to small spatial shifts, which is exactly what you want for recognition tasks.
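
A minimal sketch of that invariance, and of its limit. The helper `max_pool_2x2` here is a vectorized illustration written for this example, not a library function.

```python
import numpy as np

def max_pool_2x2(x):
    """Non-overlapping 2x2 max pooling on an (H, W) array with even H and W."""
    h, w = x.shape
    return x.reshape(h // 2, 2, w // 2, 2).max(axis=(1, 3))

a = np.zeros((4, 4)); a[0, 0] = 1.0  # feature in the top-left window
b = np.zeros((4, 4)); b[0, 1] = 1.0  # shifted one pixel right, same window
c = np.zeros((4, 4)); c[0, 2] = 1.0  # shifted two pixels right, next window

same_window = np.array_equal(max_pool_2x2(a), max_pool_2x2(b))
next_window = np.array_equal(max_pool_2x2(a), max_pool_2x2(c))
print(same_window, next_window)  # True False
```

The output is unchanged only while the feature stays inside its window. Once it crosses a window boundary, the pooled map changes too.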

The Formula

For a 2×2 window with stride 2 applied to a 4×4 input, each output value is:

output[i, j] = max(input[2i : 2i+2,   2j : 2j+2])

The output has shape (H/2, W/2). We halve the spatial dimensions. Each output cell summarizes a 2×2 patch by keeping only its maximum.

More generally, for a pool window of size k with stride s:

output[i, j] = max(input[si : si+k,   sj : sj+k])

Output shape: (⌊(H - k) / s⌋ + 1, ⌊(W - k) / s⌋ + 1), assuming no padding. When k = s = 2 and H, W are even, this simplifies to (H/2, W/2). When k = 3, s = 2, you get overlapping pooling regions, as used in AlexNet.
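
The general formula can be sketched directly in NumPy. This is an illustrative, unoptimized version with no padding; the function name `max_pool2d` and its defaults are choices made for this example.

```python
import numpy as np

def max_pool2d(x, k=2, s=2):
    """Max pooling with a k×k window and stride s on an (H, W) array, no padding."""
    H, W = x.shape
    out_h = (H - k) // s + 1
    out_w = (W - k) // s + 1
    out = np.empty((out_h, out_w), dtype=x.dtype)
    for i in range(out_h):
        for j in range(out_w):
            # Window starts at (s*i, s*j) and spans k rows and k columns.
            out[i, j] = x[s * i : s * i + k, s * j : s * j + k].max()
    return out

x = np.arange(49, dtype=float).reshape(7, 7)
print(max_pool2d(x, k=2, s=2).shape)  # (3, 3): the last row and column are dropped
print(max_pool2d(x, k=3, s=2).shape)  # (3, 3): overlapping windows cover the full map
```

With an odd input size such as 7×7, the non-overlapping 2×2 pool silently ignores the last row and column, while the overlapping 3×3, stride-2 pool covers every position.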

Hand Calculation: Full 4×4 Max Pool

Input feature map (4×4):

values
1.2  0.5  |  3.1  0.8
0.3  2.7  |  0.1  1.5
-----|-----
0.9  0.4  |  2.3  3.8
1.1  0.6  |  0.7  1.2

Top-left block [1.2, 0.5, 0.3, 2.7]: compare all four — max is 2.7 (bottom-right of the block).

Top-right block [3.1, 0.8, 0.1, 1.5]: max is 3.1 (top-left of the block). The strong feature at this exact position is preserved.

Bottom-left block [0.9, 0.4, 1.1, 0.6]: max is 1.1 (bottom-left). This is the weakest-maximum block — none of the positions fired strongly.

Bottom-right block [2.3, 3.8, 0.7, 1.2]: max is 3.8 (top-right). That single strong activation dominates.

Output (2×2):

result
2.7  3.1
1.1  3.8

Four values. Each is the winner of its 2×2 block. The spatial resolution dropped from 4×4 to 2×2, but the strongest features survived.

The Gradient: Winner Take All

During backpropagation, where does the gradient flow? Only to the position that held the maximum value. All other positions in the window receive zero gradient. This is a "winner-take-all" mechanism.

Think about it mathematically. The partial derivative of the max with respect to one of its inputs is:

∂max(a, b, c, d) / ∂a = 1 if a = max,   0 otherwise

For our top-left block, the gradient flows entirely to position (1, 1) where the value was 2.7. Positions (0, 0), (0, 1), and (1, 0) get zero gradient from this output cell. They won't be updated from this particular pooling window.

Let's trace the backward pass for our example. Suppose the loss gradient flowing back to the output is:

gradient at output
0.5  -0.3
0.2  -0.1

The gradient at the input is a 4×4 grid of mostly zeros:

gradient at input
0.0   0.0  |  -0.3  0.0      # 3.1 was max → gets -0.3
0.0   0.5  |   0.0  0.0      # 2.7 was max → gets 0.5
------|------
0.0   0.0  |   0.0  -0.1     # 3.8 was max → gets -0.1
0.2   0.0  |   0.0   0.0     # 1.1 was max → gets 0.2

Only 4 out of 16 positions receive gradient. The rest are silent for this particular pooling operation. Sparse gradients.
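
The backward pass traced above can be reproduced in a few lines. This is a from-scratch sketch for the 2×2, stride-2 case; the function name `max_pool_2x2_backward` is made up for this example, and ties go to the first maximum found by `np.argmax`.

```python
import numpy as np

def max_pool_2x2_backward(x, grad_out):
    """Route each output gradient to the argmax position of its 2x2 block."""
    H, W = x.shape
    grad_in = np.zeros_like(x)
    for i in range(0, H, 2):
        for j in range(0, W, 2):
            block = x[i:i+2, j:j+2]
            # Position of the winner inside the block (first one on ties).
            mi, mj = np.unravel_index(np.argmax(block), block.shape)
            grad_in[i + mi, j + mj] = grad_out[i // 2, j // 2]
    return grad_in

x = np.array([
    [1.2, 0.5, 3.1, 0.8],
    [0.3, 2.7, 0.1, 1.5],
    [0.9, 0.4, 2.3, 3.8],
    [1.1, 0.6, 0.7, 1.2],
])
grad_out = np.array([[0.5, -0.3],
                     [0.2, -0.1]])

grad_in = max_pool_2x2_backward(x, grad_out)
print(grad_in)
print(np.count_nonzero(grad_in))  # 4
```

The printed grid matches the hand-traced gradient: four nonzero entries, one per window, each sitting at the winner's position.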

Interactive Max Pooling

Edit values in the 4×4 grid (click a cell, then use the slider). The 2×2 max-pooled output updates live. Orange borders show which cell "won" each block. Teal arrows show gradient flow (only to winners).

Try making one cell much larger than its neighbors — say, 4.5. Watch the output snap to that value instantly. Now lower it below its neighbor. The "winner" border shifts to the new max. The gradient arrow changes direction. Only the current winner gets the learning signal.

From Scratch and with PyTorch

python
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

# From scratch with gradient tracking
def max_pool_2x2_with_grad(x):
    """Returns pooled output AND a mask of winner positions."""
    H, W = x.shape
    out = np.zeros((H // 2, W // 2))
    mask = np.zeros_like(x)  # 1 at max positions, 0 elsewhere
    for i in range(0, H, 2):
        for j in range(0, W, 2):
            block = x[i:i+2, j:j+2]
            max_val = np.max(block)
            out[i//2, j//2] = max_val
            # Mark the winner (argmax within the block)
            mi, mj = np.unravel_index(np.argmax(block), (2, 2))
            mask[i + mi, j + mj] = 1
    return out, mask

# PyTorch: nn.MaxPool2d
pool = nn.MaxPool2d(kernel_size=2, stride=2)
x_t = torch.tensor([[
    [1.2, 0.5, 3.1, 0.8],
    [0.3, 2.7, 0.1, 1.5],
    [0.9, 0.4, 2.3, 3.8],
    [1.1, 0.6, 0.7, 1.2]
]]).unsqueeze(0).requires_grad_()  # (1, 1, 4, 4); still a leaf tensor, so x_t.grad is populated

out = pool(x_t)    # tensor([[[[2.7, 3.1], [1.1, 3.8]]]])
out.sum().backward()
print(x_t.grad)    # 1s at max positions, 0s elsewhere

# Functional API (equivalent, but no stored state)
out2 = F.max_pool2d(x_t, kernel_size=2, stride=2)

Max pooling's gradient going only to the max position doesn't mean the other positions "don't learn." With non-overlapping windows, each position belongs to exactly one window, so a losing position gets zero gradient from the pooling layer for this input. But the winner changes from input to input and from channel to channel, and with overlapping windows (such as 3×3 with stride 2) a position can lose in one window and win in another. More importantly, the convolution weights that produced the feature map are shared across all spatial positions, so every filter still receives gradient through whichever positions won.

Where Max Pooling Shines

Max pooling appears in nearly every classic CNN. AlexNet (2012) used overlapping 3×3 max pools with stride 2. VGG (2014) used non-overlapping 2×2 with stride 2 after every 2-3 conv layers. ResNet (2015) uses a single 3×3 max pool early in the network, then relies on strided convolutions.

It works best when the presence of a feature matters more than its exact magnitude. Edge detection, object detection, texture recognition — anywhere you want "is this feature here?" rather than "how much of this feature is here?"

But max pooling has a fundamental limitation: it discards everything about the window except its peak. A block with values [0.1, 0.1, 0.1, 5.0] and a block with values [4.8, 4.9, 4.7, 5.0] both produce output 5.0. The first block had one spike amid silence. The second had uniformly strong activation. Max pooling can't tell the difference. Sometimes that distinction matters, and that's where average pooling enters the picture.
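
A quick numerical illustration of that blind spot, using the two blocks from the paragraph above (the variable names are just labels for this example):

```python
import numpy as np

spike = np.array([0.1, 0.1, 0.1, 5.0])    # one spike amid silence
uniform = np.array([4.8, 4.9, 4.7, 5.0])  # uniformly strong activation

print(spike.max(), uniform.max())                      # 5.0 5.0
print(f"{spike.mean():.3f} {uniform.mean():.3f}")      # 1.325 4.850
```

The max is identical for both blocks, while the mean separates them clearly.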

In max pooling, how many positions in each pooling window receive gradient during backpropagation?

Chapter 2: Average Pooling

Max pooling keeps the loudest voice. Average pooling listens to everyone equally. If a region has values [0.1, 0.2, 0.8, 0.1], max says "0.8 — there's a strong feature here." Average says "0.3 — there's a weak-to-moderate signal overall." Different answers. Different use cases.

Average pooling takes the arithmetic mean of all values in the pooling window. Every position contributes equally to the output. There are no winners and no losers — the output reflects the collective response.

The Formula

For a 2×2 window with stride 2:

output[i, j] = (1/4) ∑ input[2i : 2i+2,   2j : 2j+2]

More generally, for a k×k window:

output[i, j] = (1 / k²) · Σ_{m=0}^{k-1} Σ_{n=0}^{k-1} input[s·i + m, s·j + n]

The gradient is simple and elegant: every position in the window receives an equal share. If the output gradient is g, each of the k² inputs gets g / k². No winner-take-all — everyone gets updated.

∂output / ∂input[m, n] = 1 / k²   for all m, n in the window

Average vs Max: Side by Side

Let's compare them on the same data to build intuition for when to use which.

Property               | Max Pooling                      | Average Pooling
What it keeps          | Strongest activation             | Overall activation level
Gradient flow          | Only to winner (sparse)          | Equal to all (dense)
Sensitive to outliers  | Very: one high value dominates   | Robust: outliers are diluted
Best for               | Feature detection (is it there?) | Feature magnitude (how much?)
Classic use            | Local pooling in conv layers     | Global pooling before classifier

In practice, the community converged on a split: local pooling (2×2, stride 2 between conv layers) usually uses max. Global pooling (entire feature map → 1 value per channel) usually uses average. Why? Local max pooling preserves edge-like features as they flow through the network. Global average pooling produces a stable summary of the entire feature map for classification.

Global Average Pooling (GAP)

This is the big one. Global Average Pooling doesn't use a small window — it takes the mean across the ENTIRE spatial dimension. If your feature map is (batch, 512, 7, 7), GAP produces (batch, 512). One number per channel. The entire 7×7 spatial grid collapses into a single value.

Conv Output
(batch, 512, 7, 7) — 25,088 values per image
↓ Global Average Pool
Pooled Vector
(batch, 512) — 512 values per image
↓ Linear Layer
Class Logits
(batch, 1000) — one score per class

This replaced the massive fully-connected layers in classic CNNs. VGG-16 has roughly 124 million parameters in its three FC layers (out of 138M total). A ResNet-18 or ResNet-34 with GAP has zero parameters in pooling and only 512 × 1000 ≈ 0.5M weights in the final linear layer. That is more than a 200× reduction in classifier parameters.

GAP was introduced by Lin et al. in 2013 ("Network in Network") and became standard with GoogLeNet (2014) and ResNet (2015). Most modern classification CNNs use it.

Hand Calculation: Global Average Pooling

One channel of a 3×3 feature map:

values
0.5  0.2  0.8
0.1  0.9  0.3
0.4  0.6  0.7

GAP = (0.5 + 0.2 + 0.8 + 0.1 + 0.9 + 0.3 + 0.4 + 0.6 + 0.7) / 9 = 4.5 / 9 = 0.5.

One number. It says: "this channel has moderate overall activation across the entire spatial extent." If this were the "cat ear" channel and only one position had a high value (say 3.0 while the rest are 0.1), GAP would output (3.0 + 8 × 0.1) / 9 = 3.8 / 9 ≈ 0.42. The strong local detection gets diluted — which is actually appropriate, because the channel's OVERALL response is indeed moderate.

Now consider a channel where MOST positions fire strongly: [2.1, 1.8, 2.3, 1.9, 2.5, 2.0, 1.7, 2.2, 2.4]. GAP = 18.9 / 9 = 2.1. This channel responds strongly across the entire image. GAP captures that distinction perfectly.
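
The three channels discussed above can be checked in one call. This sketch stacks them into a (channels, H, W) array, an ordering chosen here for convenience, and averages over the two spatial axes.

```python
import numpy as np

moderate = np.array([[0.5, 0.2, 0.8],
                     [0.1, 0.9, 0.3],
                     [0.4, 0.6, 0.7]])
one_spike = np.full((3, 3), 0.1)
one_spike[1, 1] = 3.0                      # a single strong detection
strong = np.array([[2.1, 1.8, 2.3],
                   [1.9, 2.5, 2.0],
                   [1.7, 2.2, 2.4]])

features = np.stack([moderate, one_spike, strong])  # (3 channels, 3, 3)
gap = features.mean(axis=(1, 2))                    # one value per channel
print(np.round(gap, 2))  # roughly 0.5, 0.42 and 2.1
```

The single-spike channel ends up below the moderate channel, and far below the channel that fires everywhere.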

Why GAP is Regularization

GAP forces each channel in the last conv layer to correspond to one class-relevant feature. Why? Because the final linear layer takes the 512-dimensional GAP output and produces class scores with a simple matrix multiply. Each channel's pooled value is multiplied by one weight per class.

This means channel 37 can't encode "cat nose at position (3, 4)" — it has to encode "cat-nose-ness" everywhere. The spatial information is explicitly removed. The channel must learn a position-invariant feature detector, or it contributes nothing useful after GAP.

This is built-in structural regularization. No dropout needed. No weight decay on the pooling layer (it has no weights). Just the architectural choice of "average everything" forces better feature learning upstream.
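
One way to see that spatial information is removed: move the contents of a feature map around and check that the GAP output does not change. The sketch below uses a circular shift (`np.roll`), which is an idealization, since a real image translation also brings new content in at the borders.

```python
import numpy as np

rng = np.random.default_rng(0)
features = rng.normal(size=(512, 7, 7))   # random stand-in for a (C, H, W) feature map

def gap(x):
    """Global average pooling over the two spatial axes."""
    return x.mean(axis=(1, 2))

# Shift every channel 2 rows down and 3 columns right, wrapping around.
shifted = np.roll(features, shift=(2, 3), axis=(1, 2))

unchanged = np.allclose(gap(features), gap(shifted))
print(unchanged)  # True
```

After GAP, the classifier cannot tell the original map from the shifted one, so any useful signal has to live in what a channel detects rather than where it detects it.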

Max vs Average vs Global Average

Same 4×4 grid. Toggle between Max, Average, and GAP. Add noise with the slider — watch how max is sensitive to outliers (one high value dominates) while average is robust. GAP collapses everything to a single number.

Try cranking the noise slider up. Max pooling's output swings wildly — whichever cell gets the largest noise spike wins. Average pooling barely moves, because the noise cancels out across positions. This is why average pooling is preferred for global aggregation: it's a more stable summary.
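
Without the interactive widget, the same effect can be approximated with a small simulation. This is a sketch under simple assumptions: a flat 4×4 map with value 1.0, independent Gaussian noise with standard deviation 0.5, and global (whole-map) max and mean instead of 2×2 windows.

```python
import numpy as np

rng = np.random.default_rng(0)
base = np.full((4, 4), 1.0)                        # a flat feature map
noise = rng.normal(scale=0.5, size=(2000, 4, 4))   # 2000 independent noise draws
noisy = base + noise

max_out = noisy.max(axis=(1, 2))    # global max per trial
avg_out = noisy.mean(axis=(1, 2))   # global average per trial

print(f"max: mean {max_out.mean():.2f}, std {max_out.std():.3f}")
print(f"avg: mean {avg_out.mean():.2f}, std {avg_out.std():.3f}")
```

Under these assumptions the average stays centered near 1.0 with a spread of about 0.5 / 4, while the max is pulled well above 1.0 by whichever cell drew the largest noise and varies more from trial to trial.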

From Scratch and with PyTorch

python
import numpy as np
import torch
import torch.nn as nn

# Average pooling from scratch
def avg_pool_2x2(x):
    H, W = x.shape
    out = np.zeros((H // 2, W // 2))
    for i in range(0, H, 2):
        for j in range(0, W, 2):
            out[i//2, j//2] = np.mean(x[i:i+2, j:j+2])
    return out

# Global Average Pooling from scratch
def global_avg_pool(x):
    # x shape: (channels, H, W)
    return np.mean(x, axis=(1, 2))  # shape: (channels,)

# PyTorch: nn.AvgPool2d for local average pooling
avg_pool = nn.AvgPool2d(kernel_size=2, stride=2)
x = torch.randn(1, 512, 8, 8)
out = avg_pool(x)       # (1, 512, 4, 4)
print(out.shape)

# PyTorch: nn.AdaptiveAvgPool2d for GAP
# "Adaptive" means: whatever the input size, produce THIS output size
gap = nn.AdaptiveAvgPool2d(1)  # output is (1, 1) spatially
x7 = torch.randn(1, 512, 7, 7)
x14 = torch.randn(1, 512, 14, 14)
print(gap(x7).shape)    # (1, 512, 1, 1) — works!
print(gap(x14).shape)   # (1, 512, 1, 1) — same output shape!

# ResNet-style classifier head:
# features = backbone(image)      # (batch, 512, 7, 7)
# pooled = gap(features).flatten(1) # (batch, 512)
# logits = fc(pooled)              # (batch, num_classes)

Global Average Pooling is NOT just "a pooling layer." It's an architectural choice that eliminates millions of parameters. VGG-16 has 138M parameters, roughly 124M of them in the FC layers after flattening. ResNet-50 with GAP has about 25.6M parameters in total. GAP also acts as structural regularization: it pushes each channel to encode a position-invariant, class-relevant feature. This is one reason modern CNNs outperform their predecessors with far fewer parameters.

The Gradient of Average Pooling

Unlike max pooling's sparse gradient, average pooling distributes gradient uniformly. For a k×k window, each of the k² inputs receives 1/k² of the output gradient.

For GAP over an H×W feature map: each spatial position receives gradient = output_grad / (H × W). If H = W = 7, each position gets 1/49 of the gradient. Small per position, but EVERY position gets some signal. Dense but diluted.
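
These gradient shares can be confirmed with PyTorch autograd. The sketch below is only a sanity check on random inputs: summing the pooled output makes the upstream gradient equal to 1, so the input gradient shows each position's share directly.

```python
import torch
import torch.nn as nn

# GAP over a 7x7 map: every position should receive 1/49.
x = torch.randn(1, 1, 7, 7, requires_grad=True)
nn.AdaptiveAvgPool2d(1)(x).sum().backward()
gap_grad = x.grad.clone()

# Local 2x2 average pooling on a 4x4 map: every position should receive 1/4.
y = torch.randn(1, 1, 4, 4, requires_grad=True)
nn.AvgPool2d(kernel_size=2, stride=2)(y).sum().backward()
local_grad = y.grad.clone()

print(torch.allclose(gap_grad, torch.full_like(gap_grad, 1 / 49)))    # True
print(torch.allclose(local_grad, torch.full_like(local_grad, 0.25)))  # True
```

Every entry of the gradient is the same constant, regardless of the input values, which is the "dense but diluted" behavior described above.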

This has an interesting consequence for learning dynamics. Max pooling gives strong gradients to a few positions (aggressive learning at peaks). Average pooling gives weak gradients to all positions (gentle, uniform learning). In practice, both work — the choice depends on whether your task cares about peaks or distributions.

Why did Global Average Pooling replace fully-connected layers in modern CNNs?

Chapter 3: Sequence Pooling — CLS Token vs Mean

BERT outputs one vector per token, up to 512 of them. You need ONE vector to classify the sentiment of a movie review. Which token's output do you use? The first? The last? All of them?

BERT's answer: a special [CLS] token prepended to every input. Its output "summarizes" the whole sequence through attention. But is that actually the best approach? Spoiler: for many tasks, it isn't.

The [CLS] Token Approach

In BERT and its descendants, every input sequence is prepended with a special [CLS] token (short for "classification"). The input looks like:

tokens
[CLS]  The  movie  was  terrible  [SEP]

After L layers of self-attention, each token's output vector has "seen" every other token. The [CLS] token has attended to "The," "movie," "was," "terrible" — accumulating information about the whole sequence through the attention mechanism.

To classify, you extract [CLS]'s output and pass it through a linear layer:

BERT Output
(seq_len, d) — one vector per token
↓ Extract position 0
[CLS] Vector
(d,) — 768 dimensions
↓ Linear Layer
Class Logits
(num_classes,) — e.g., positive/negative

Simple. Elegant. One token to rule them all. But there's a catch.

The Cold Start Problem

At the input to the first transformer layer, [CLS] carries no information about the sentence. It starts from a learned embedding that is the same for every input, and it hasn't attended to anything yet. It must build its representation from scratch, layer by layer, through attention.

In the early layers, [CLS]'s representation reflects little of the actual input content. Only in the deeper layers does it start to carry meaningful sequence-level information. One reading is that the lower layers spend some capacity on [CLS] without getting much back: it attends to tokens but contributes little useful information to them.

For classification tasks (where [CLS] is explicitly trained with a classification loss), this works fine — the model learns to route information to [CLS] during fine-tuning. But for embedding tasks (retrieval, semantic search, similarity), where we want the vector to capture the meaning of the text, [CLS] has a fundamental limitation: it represents a single token's "view" of the sequence.

The Mean Pooling Alternative

Mean pooling is dead simple: average all token outputs.

embedding = (1 / T) · Σ_{t=1}^{T} h_t

Where T is the sequence length and h_t is the output vector of token t. Every token contributes equally. No special token needed.

Why does this work better for embeddings? Each token's output h_t is a contextualized representation: it already contains information about the whole sequence through attention. Token 3 ("was") knows about "terrible" and "movie." So when you average all tokens, you're averaging many different contextual views of the same sequence, not just one.

Hand Calculation: CLS vs Mean

Four tokens, embedding dimension d = 3. After the final transformer layer:

token outputs
h0 [CLS]       = [ 0.5,  0.3, -0.1]
h1 "The"       = [ 0.8, -0.2,  0.4]
h2 "movie"     = [ 0.1,  0.9,  0.3]
h3 "terrible"  = [-0.3,  0.5,  0.6]

CLS pooling: Just take h0.

Result: [0.5, 0.3, -0.1]

Mean pooling: Average all four vectors element-wise.

Result: [0.275, 0.375, 0.300]

The two results are completely different. CLS gives you whatever the [CLS] token learned. Mean gives you a balanced summary across all tokens. For a sentiment task, [CLS] might work fine if it was trained for sentiment. But for a general-purpose embedding, the mean captures more information — it includes the signal from "terrible" (h3) and "movie" (h2), not just what [CLS] decided to remember.
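
The same comparison in NumPy, using the four token vectors from the hand calculation (a sketch; in a real pipeline `hidden` would come from the encoder's last layer):

```python
import numpy as np

hidden = np.array([
    [ 0.5,  0.3, -0.1],  # h0 [CLS]
    [ 0.8, -0.2,  0.4],  # h1 "The"
    [ 0.1,  0.9,  0.3],  # h2 "movie"
    [-0.3,  0.5,  0.6],  # h3 "terrible"
])

cls_vec = hidden[0]              # CLS pooling: take position 0
mean_vec = hidden.mean(axis=0)   # mean pooling: average over tokens

print(cls_vec)
print(np.round(mean_vec, 3))
```

CLS pooling is a slice and mean pooling is a reduction over the token axis. Both return a vector of length d = 3.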

The Evidence: Reimers & Gurevych (2019)

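The Sentence-BERT paper (Reimers & Gurevych, 2019) compared CLS, mean, and max pooling for sentence embeddings in its ablation study. Spearman correlation (×100) on the STS benchmark development set, as reported there:

Pooling Method | Trained on NLI (classification) | Trained on STSb (regression)
Mean pooling   | 80.78                           | 87.44
Max pooling    | 79.07                           | 69.92
[CLS] token    | 79.80                           | 86.62

Mean pooling came out on top in both settings. [CLS] was close behind, and max pooling fell far behind under the regression objective.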

Since then, mean pooling has become a common default for embedding models: Sentence-BERT, E5, GTE, and Nomic Embed all use it. It is not universal, though. BGE, for example, takes the [CLS] vector. CLS pooling also remains standard for classification fine-tuning, where the [CLS] token is explicitly trained with a task-specific head.

The Masking Problem

There's a critical implementation detail that trips up beginners. In a batch, sequences have different lengths. Shorter sequences are padded with [PAD] tokens to match the longest sequence:

batch
"The movie was terrible"     [CLS] The movie was terrible [SEP] [PAD] [PAD]
"I loved it so much really"  [CLS] I loved it so much really [SEP]

The padding positions still produce output vectors, and those vectors are generally not zero. If you naively average all positions including padding, you're mixing real signal with vectors that carry no meaning. Sentence 1's embedding is worse because two of its eight "token" outputs are junk.

The fix: use the attention mask to exclude padding.

embedding = Σ_t (h_t × mask_t) / Σ_t mask_t

Where mask_t = 1 for real tokens and 0 for padding. The denominator is the actual number of real tokens, not the padded sequence length.

Hand Calculation: Masked Mean Pooling

Sequence with 3 real tokens and 1 padding token, d = 2:

tokens
h0 "The"   = [0.6, 0.4]   mask = 1
h1 "cat"   = [0.8, 0.2]   mask = 1
h2 "sat"   = [0.3, 0.7]   mask = 1
h3 [PAD]   = [0.1, 0.05]  mask = 0

Without masking: (0.6+0.8+0.3+0.1)/4 = 0.45, (0.4+0.2+0.7+0.05)/4 = 0.3375 → [0.45, 0.3375]

With masking: (0.6+0.8+0.3)/3 = 0.567, (0.4+0.2+0.7)/3 = 0.433 → [0.567, 0.433]

Different results! The masked version correctly ignores the padding and gives a higher-magnitude, cleaner embedding. In production, this difference compounds across thousands of queries and degrades retrieval quality if you forget the mask.
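
The hand calculation above, reproduced in NumPy for a single unbatched sequence (a sketch; the variable names are chosen for this example):

```python
import numpy as np

hidden = np.array([
    [0.6, 0.4],    # "The"
    [0.8, 0.2],    # "cat"
    [0.3, 0.7],    # "sat"
    [0.1, 0.05],   # [PAD]
])
mask = np.array([1, 1, 1, 0])  # 1 for real tokens, 0 for padding

naive = hidden.mean(axis=0)                                  # padding included
masked = (hidden * mask[:, None]).sum(axis=0) / mask.sum()   # padding excluded

print(np.round(naive, 4))   # about [0.45, 0.3375]
print(np.round(masked, 3))  # about [0.567, 0.433]
```

The only differences from a plain mean are the multiplication by the mask before summing and the division by the count of real tokens instead of the padded length.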

CLS vs Mean Pooling

A sequence of tokens, each with an embedding vector. Toggle between CLS (use first token) and Mean (average all tokens). The bottom bar shows the resulting sentence embedding. Click tokens to edit their values and see how each method responds.

Try adding padding tokens with the slider. In CLS mode, padding doesn't affect the result at all — we only use token 0. In Mean mode, watch the embedding shift as padding tokens dilute the average. This is why attention masking is essential for mean pooling in practice. The simulation applies masking automatically; toggle padding to see the difference.

From Scratch and with PyTorch

python
import torch

# CLS pooling: just slice
def cls_pool(hidden_states):
    """hidden_states: (batch, seq_len, d)"""
    return hidden_states[:, 0, :]  # (batch, d)

# Mean pooling with attention mask
def mean_pool(hidden_states, attention_mask):
    """
    hidden_states: (batch, seq_len, d)
    attention_mask: (batch, seq_len) — 1 for real, 0 for pad
    """
    # Expand mask to (batch, seq_len, d) for broadcasting
    mask = attention_mask.unsqueeze(-1).float()  # (batch, seq_len, 1)

    # Zero out padding positions, then sum
    summed = (hidden_states * mask).sum(dim=1)  # (batch, d)

    # Divide by number of real tokens (not padded length!)
    counts = mask.sum(dim=1).clamp(min=1e-9)  # (batch, 1)
    return summed / counts  # (batch, d)

# Example: batch of 2 sentences, max_len=6, d=4
hidden = torch.randn(2, 6, 4)
mask = torch.tensor([
    [1, 1, 1, 1, 0, 0],  # 4 real tokens, 2 padding
    [1, 1, 1, 1, 1, 1],  # 6 real tokens, no padding
])

cls_emb = cls_pool(hidden)          # (2, 4)
mean_emb = mean_pool(hidden, mask)  # (2, 4)
print(cls_emb.shape, mean_emb.shape)

Mean pooling should use an attention mask to EXCLUDE padding tokens. If your batch has sequences of different lengths padded to the max, averaging without masking includes the padding vectors, which carry no meaning and are generally not zero, so they distort the representation. Always: sum(h_t × mask_t) / sum(mask_t). Forgetting this is one of the most common bugs in embedding pipelines.

When to Use Which

Task                             | Best Pooling     | Why
Text classification (fine-tuned) | [CLS]            | Classification head is trained directly on [CLS]; it learns to encode task-specific info
Sentence embeddings / retrieval  | Mean             | Richer representation from all tokens; empirically better similarity scores
Named entity recognition         | None (per-token) | Need per-token predictions; pooling would destroy position info
Semantic search                  | Mean             | Every token contributes to meaning; [CLS] misses nuance
Zero-shot classification         | Mean             | Better generalization to unseen classes

The pattern is clear: if you're fine-tuning for a specific task and can train the [CLS] token directly, CLS works. If you need general-purpose embeddings that transfer across tasks, mean pooling wins.

For sentence embedding tasks (retrieval, semantic search), which pooling method generally produces better results?

Chapter 4: Attention Pooling

Mean pooling treats all tokens equally. But in "The movie was absolutely terrible," the words "absolutely terrible" carry all the sentiment. "The movie was" is filler. What if we could learn to weight the important tokens more heavily?

That's the idea behind attention pooling. Instead of assigning uniform weight 1/T to every position (mean pooling) or picking one winner (max pooling), we learn a query vector q that asks: "which tokens are relevant for my task?" Each token gets scored against this query, and the scores become weights via softmax. The final output is a weighted sum where important tokens dominate.

Think of q as a talent scout at an audition. Every token performs. The scout (q) evaluates each performance, assigns scores, and the final cast is dominated by the top performers — but everyone gets at least a small role (softmax never hits exactly zero).

The Mechanism

Given a sequence of token representations h_1, h_2, ..., h_T, each a d-dimensional vector:

Score: a_t = qᵀ · h_t for each token t
Scale: divide each score by √d to prevent large dot products from saturating softmax
Normalize: w = softmax(a / √d), so the weights w_t sum to 1
Aggregate: output = Σ_t w_t · h_t, a weighted combination of the token vectors

The query vector q is a learned parameter — it starts random and the training loss pushes it toward directions that identify task-relevant tokens. For sentiment analysis, q might learn to attend to adjectives and adverbs. For topic classification, it might attend to nouns and proper names.

Why scale by √d? If the components of q and h_t have roughly unit variance, their dot product has a standard deviation of about √d, so raw scores get larger as d grows. A score of 50 makes softmax output nearly one-hot: one token gets weight ~1.0 and everything else gets ~0.0. Dividing by √d keeps the scores in a range where softmax produces meaningfully different weights instead of a hard argmax. This is the same scaling used in self-attention (Vaswani et al., 2017).
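
The effect is easy to see in a small simulation. This sketch assumes the query and the token vectors have independent standard-normal components, which is an idealization of what a trained model produces; `softmax` is a local helper defined for the example.

```python
import numpy as np

rng = np.random.default_rng(0)

def softmax(z):
    z = z - z.max()   # subtract the max for numerical stability
    e = np.exp(z)
    return e / e.sum()

d = 256
q = rng.normal(size=d)
tokens = rng.normal(size=(10_000, d))   # many random "token" vectors
scores = tokens @ q                      # raw dot products

print(f"std of raw scores:    {scores.std():.2f}")                  # near sqrt(256) = 16
print(f"std of scaled scores: {(scores / np.sqrt(d)).std():.2f}")   # near 1

raw_w = softmax(scores[:8])                   # weights over 8 tokens, unscaled
scaled_w = softmax(scores[:8] / np.sqrt(d))   # same 8 tokens, scaled
print(f"largest weight, unscaled: {raw_w.max():.3f}")
print(f"largest weight, scaled:   {scaled_w.max():.3f}")
```

Dividing the scores by a factor greater than one can only flatten the softmax, so the largest scaled weight is never bigger than the largest unscaled one. With d = 256 the unscaled weights are typically close to one-hot.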

Hand Calculation: Attention Pooling Step by Step

Let's trace through a concrete example. We have 4 tokens with d=2 dimensional representations, and a learned query vector.

Setup. Token representations: h0 = [0.5, 0.3], h1 = [0.8, -0.2], h2 = [0.1, 0.9], h3 = [-0.3, 0.5]. Learned query: q = [0.7, -0.5].

Step 1 — Compute raw scores (dot products):

Token 1 scores highest (0.66) — its representation aligns best with the query. Tokens 2 and 3 score negative — they point away from what the query is looking for.

Step 2 — Scale by √d = √2 ≈ 1.414:

Step 3 — Softmax to get weights:

exp(0.141) = 1.152, exp(0.467) = 1.595, exp(-0.269) = 0.764, exp(-0.325) = 0.722. Sum = 4.233.

Dividing each exponential by the sum gives the weights: w ≈ [0.272, 0.377, 0.181, 0.170] (rounded so the weights total 1).

Token 1 gets 37.7% of the weight — the network decided it's most relevant. Token 0 gets 27.2%. Tokens 2 and 3 together only get about 35%. Compare this to mean pooling where every token gets exactly 25%.

Step 4 — Weighted sum:

output = 0.272 × [0.5, 0.3] + 0.377 × [0.8, -0.2] + 0.181 × [0.1, 0.9] + 0.170 × [-0.3, 0.5]

Final output: [0.405, 0.254]. Compare with mean pooling: [0.275, 0.375]. The attention output is pulled toward token 1's representation ([0.8, -0.2]) because the query decided it was most important.
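The four steps above can be replayed in a few lines of plain Python. This is only a check of the hand calculation, using the same tokens and query; printed values are rounded to two decimals because the hand calculation rounds the weights before the final sum, which can move the third decimal.

```python
import math

tokens = [[0.5, 0.3], [0.8, -0.2], [0.1, 0.9], [-0.3, 0.5]]
query = [0.7, -0.5]
d = len(query)

# Step 1: raw scores (dot product of the query with each token)
scores = [sum(q * h for q, h in zip(query, tok)) for tok in tokens]

# Step 2: scale by sqrt(d)
scaled = [s / math.sqrt(d) for s in scores]

# Step 3: softmax turns scaled scores into weights that sum to 1
exps = [math.exp(s) for s in scaled]
weights = [e / sum(exps) for e in exps]

# Step 4: weighted sum of the token vectors
output = [sum(w * tok[j] for w, tok in zip(weights, tokens)) for j in range(d)]

# Mean pooling for comparison (uniform weight 1/T)
mean_pool = [sum(tok[j] for tok in tokens) / len(tokens) for j in range(d)]

print([round(s, 2) for s in scores])     # [0.2, 0.66, -0.38, -0.46]
print([round(w, 2) for w in weights])    # [0.27, 0.38, 0.18, 0.17]
print([round(v, 2) for v in output])     # [0.4, 0.25]
print([round(v, 3) for v in mean_pool])  # [0.275, 0.375]
```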

See It Live

The simulation below shows attention pooling on 6 tokens. Adjust the query direction with the slider — watch how the attention weights shift, how different tokens light up, and how the output vector changes compared to uniform mean pooling.

Attention Pooling Explorer

Drag the query angle to change which tokens the attention focuses on. Bar heights show attention weights. The output vector (gold) shifts toward the highest-weighted tokens. Gray dashed line shows mean pooling for comparison.

Query angle 135°
Temperature 1.0

From Scratch: Attention Pooling Module

python
import torch
import torch.nn as nn
import torch.nn.functional as F

class AttentionPooling(nn.Module):
    """Single-query attention pooling.
    Input:  (B, T, D) — batch of token sequences
    Output: (B, D)   — one vector per sequence
    """
    def __init__(self, d_model):
        super().__init__()
        # Learned query: what to look for
        self.query = nn.Parameter(torch.randn(d_model))

    def forward(self, x, mask=None):
        # x: (B, T, D), self.query: (D,)
        scores = torch.matmul(x, self.query)  # (B, T)
        scores = scores / (x.size(-1) ** 0.5)  # scale by sqrt(d)

        if mask is not None:
            scores = scores.masked_fill(~mask, float('-inf'))

        weights = F.softmax(scores, dim=-1)  # (B, T)
        return torch.sum(weights.unsqueeze(-1) * x, dim=1)  # (B, D)


# Test it
pool = AttentionPooling(d_model=64)
x = torch.randn(2, 10, 64)  # 2 sequences, 10 tokens, 64-dim
out = pool(x)
print(out.shape)  # torch.Size([2, 64])

# With padding mask (tokens 8,9 are padding for sequence 0)
mask = torch.ones(2, 10, dtype=torch.bool)
mask[0, 8:] = False
out_masked = pool(x, mask=mask)  # Padding tokens get zero weight

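Notice the mask parameter: in real NLP, sequences have different lengths and are padded. Without masking, attention would waste weight on padding tokens. Filling the padded scores with -inf via masked_fill makes softmax assign exactly zero weight to those positions. One caveat: a sequence whose mask is entirely False produces NaN weights, so every sequence needs at least one real token.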

Attention pooling ≠ self-attention. Self-attention computes pairwise interactions between all tokens: each token queries every other token, costing O(T²) computation. Attention pooling uses ONE learned query against all tokens: O(T) computation. It's a lightweight aggregation step, not a full attention layer. You could think of it as a single "read head" that extracts one summary from the sequence. Self-attention is T read heads, one per position.

Where Attention Pooling Shows Up

Model / Paper | How it uses attention pooling
Set Transformer (Lee et al., 2019) | PMA (Pooling by Multihead Attention): uses k learned "seed" vectors as queries to extract k summary vectors from a set
Multi-Instance Learning | Aggregate bag-of-instances into a single bag representation; attention weights tell you which instances matter
Document Classification | Pool sentence embeddings into document embedding; attention weight = sentence importance
Sentence-BERT (optional) | Weighted aggregation of token embeddings; can outperform mean pooling on some tasks
What does the learned query vector q in attention pooling represent?

Chapter 5: Weighted Pooling & Learned Aggregation

Attention pooling with one query asks one question. But what if you need multiple aspects of the sequence — sentiment AND topic AND entity? Multi-head pooling asks multiple questions simultaneously. And for autoregressive models like GPT, there's an even simpler answer: just use the last token.

This chapter covers three more aggregation strategies that go beyond simple attention pooling. Each makes a different bet about where information concentrates in the sequence — and each bet turns out to be right for different architectures.

Multi-Head Attention Pooling

If one query extracts one aspect of the sequence, why not use K queries to extract K different aspects? That's multi-head attention pooling, used in the Set Transformer (Lee et al., 2019). Each query "head" attends to different token patterns, producing K output vectors. You can concatenate them (getting a K·d vector) or average them (back to d).

K Queries
q_1, q_2, ..., q_K, each a learned d-dim vector
K Weighted Sums
Each query produces its own attention weights and output vector
Combine
Concatenate [o_1; o_2; ...; o_K] or mean(o_1, ..., o_K)

Head 1 might attend to the subject of a sentence. Head 2 might attend to the action. Head 3 might attend to negation words. Together, they capture a richer summary than any single query could.
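As a rough sketch of the idea, the snippet below runs K independent single-query pools and combines them. It is a simplification: the queries are hypothetical hand-picked values rather than learned parameters, and it omits the per-head projections that the Set Transformer's PMA block applies.

```python
import math

def softmax(scores):
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def attention_pool(tokens, query):
    """Single-query attention pooling over a list of d-dim token vectors."""
    d = len(query)
    scores = [sum(q * h for q, h in zip(query, tok)) / math.sqrt(d) for tok in tokens]
    weights = softmax(scores)
    return [sum(w * tok[j] for w, tok in zip(weights, tokens)) for j in range(d)]

def multi_head_pool(tokens, queries, combine="concat"):
    """Run one attention pool per query, then concatenate or average the outputs."""
    outputs = [attention_pool(tokens, q) for q in queries]
    if combine == "concat":
        return [v for out in outputs for v in out]       # length K * d
    if combine == "mean":
        d = len(outputs[0])
        return [sum(out[j] for out in outputs) / len(outputs) for j in range(d)]
    raise ValueError("combine must be 'concat' or 'mean'")

tokens = [[0.5, 0.3], [0.8, -0.2], [0.1, 0.9], [-0.3, 0.5]]
queries = [[0.7, -0.5], [-0.5, 0.7]]   # hypothetical: two "heads" looking in opposite directions

concat_out = multi_head_pool(tokens, queries, "concat")
mean_out = multi_head_pool(tokens, queries, "mean")
print(len(concat_out), len(mean_out))  # 4 2
```

Concatenation keeps each head's view separate at the cost of a wider vector; averaging keeps the width at d but blends the heads together.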

Last-Token Pooling

For causal (decoder) models like GPT, there's a beautifully simple pooling strategy: just take the last token's representation.

Why does this work? In causal attention, each token can only attend to previous tokens. Token 0 sees only itself. Token 1 sees tokens 0 and 1. The last token is the only position that has attended to the entire sequence. Its representation is already a summary of everything that came before it.
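A tiny sketch of the causal visibility pattern makes the argument concrete. The helper name `causal_mask` is made up for this illustration; it simply marks which positions each token is allowed to attend to.

```python
def causal_mask(T):
    """mask[i][j] is True when position i may attend to position j (j <= i)."""
    return [[j <= i for j in range(T)] for i in range(T)]

mask = causal_mask(4)
for i, row in enumerate(mask):
    picture = "".join("x" if allowed else "." for allowed in row)
    print(f"token {i}: {picture}  sees {sum(row)} of {len(row)} tokens")

# Number of visible positions per token: only the last one sees everything
visible = [sum(row) for row in mask]
print(visible)  # [1, 2, 3, 4]
```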

The last token is the natural CLS for decoders. In BERT (bidirectional), every token sees every other token, so no single position is special, which is why BERT adds a dedicated [CLS] token. In GPT-style models, the last token naturally fills this role. For classification, GPT-style models typically feed the last token's representation to a classifier head; Hugging Face's GPT2ForSequenceClassification, for example, uses the last non-padding token.

The operation is trivial — no parameters, no computation beyond indexing:

output = h_T (the last token's hidden state)

For variable-length sequences with padding on the right, you need the actual last non-padding position — not just the final array element.
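Here is a minimal plain-Python sketch of that indexing step, assuming right padding and a 0/1 mask per sequence (both assumptions of this example). It contrasts the mask-aware lookup with naively taking the final array element.

```python
def last_token_pool(hidden, mask):
    """Pick the last real token per sequence.

    hidden: list of sequences, each a list of T vectors
    mask:   list of sequences, each a list of T flags (1 = real, 0 = padding)
    Assumes padding sits on the right.
    """
    pooled = []
    for seq, flags in zip(hidden, mask):
        length = sum(flags)
        if length == 0:
            raise ValueError("sequence has no real tokens")
        pooled.append(seq[length - 1])
    return pooled

hidden = [
    [[0.5, 0.3], [0.8, -0.2], [0.0, 0.0], [0.0, 0.0]],   # 2 real tokens + 2 pads
    [[0.5, 0.3], [0.8, -0.2], [0.1, 0.9], [-0.3, 0.5]],  # 4 real tokens
]
mask = [[1, 1, 0, 0], [1, 1, 1, 1]]

correct = last_token_pool(hidden, mask)
naive = [seq[-1] for seq in hidden]
print(correct)  # [[0.8, -0.2], [-0.3, 0.5]]
print(naive)    # [[0.0, 0.0], [-0.3, 0.5]]  (first row is a padding vector)
```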

Gated Pooling

What if each token could independently decide its own importance, without consulting a shared query? That's gated pooling. Each token passes through a small gating network that outputs a scalar weight:

g_t = σ(W · h_t + b),    output = ∑_t g_t · h_t / ∑_t g_t

Here σ is the sigmoid function, so each gate g_t is between 0 and 1. We divide by the sum of gates to normalize. Unlike attention pooling, there's no shared query: each token decides independently whether it's important. The gate learns to suppress noise tokens and amplify signal tokens.

The difference is subtle but real: attention pooling uses a shared query to compare tokens against a template. Gated pooling uses each token's own content to decide its importance. Both produce weighted averages — they just compute the weights differently.
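The gating formula can be sketched in plain Python as follows. The gate weights W and bias b here are hypothetical values chosen for illustration; in a real model they are learned. The second call shows a useful sanity check: with W = 0 and b = 0 every gate equals 0.5, so gated pooling collapses to mean pooling.

```python
import math

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def gated_pool(tokens, W, b):
    """Gated pooling: each token computes its own scalar gate from its content."""
    gates = [sigmoid(sum(w * h for w, h in zip(W, tok)) + b) for tok in tokens]
    total = sum(gates)
    d = len(tokens[0])
    pooled = [sum(g * tok[j] for g, tok in zip(gates, tokens)) / total for j in range(d)]
    return pooled, gates

tokens = [[0.5, 0.3], [0.8, -0.2], [0.1, 0.9], [-0.3, 0.5]]

# Hypothetical gate parameters that favor a large first dimension
pooled, gates = gated_pool(tokens, W=[2.0, -1.0], b=0.0)
print([round(g, 3) for g in gates])
print([round(v, 3) for v in pooled])

# Zero weights and bias: all gates are 0.5, which reproduces mean pooling
uniform, uniform_gates = gated_pool(tokens, W=[0.0, 0.0], b=0.0)
print([round(v, 3) for v in uniform])  # [0.275, 0.375]
```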

Hand Calculation: Comparing All Methods

Let's apply all our pooling methods to the same 4 tokens from Chapter 4. Same representations: h_0 = [0.5, 0.3], h_1 = [0.8, -0.2], h_2 = [0.1, 0.9], h_3 = [-0.3, 0.5].

Method | Output | What it captures
Mean | [0.275, 0.375] | Average of all tokens equally
Max (per-dim) | [0.8, 0.9] | Strongest feature per dimension
CLS (= h_0) | [0.5, 0.3] | Dedicated summary position
Attention (q=[0.7,-0.5]) | [0.405, 0.254] | Query-relevant weighted sum
Last-token (= h_3) | [-0.3, 0.5] | Final position (full causal context)

Five methods, five different answers. Each is "correct" — they just extract different information from the same input. Mean smooths everything. Max grabs peaks. CLS and last-token pick one position. Attention learns what matters.
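The parameter-free rows of the comparison are easy to verify by hand or with a few lines of plain Python. This sketch recomputes mean, max, CLS, and last-token pooling on the same four tokens (attention pooling needs the query and softmax, so it is left out here).

```python
tokens = [[0.5, 0.3], [0.8, -0.2], [0.1, 0.9], [-0.3, 0.5]]
d = len(tokens[0])

mean_pool = [sum(tok[j] for tok in tokens) / len(tokens) for j in range(d)]
max_pool = [max(tok[j] for tok in tokens) for j in range(d)]   # per-dimension max
cls_pool = tokens[0]     # first position
last_pool = tokens[-1]   # final position (no padding in this toy example)

print("Mean:", [round(v, 3) for v in mean_pool])  # [0.275, 0.375]
print("Max: ", max_pool)                          # [0.8, 0.9]
print("CLS: ", cls_pool)                          # [0.5, 0.3]
print("Last:", last_pool)                         # [-0.3, 0.5]
```

Note that the max-pooled vector [0.8, 0.9] does not correspond to any single token: each dimension is taken from a different position.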

See It Live

The simulation shows all five pooling methods applied to the same token sequence. Each method produces a different output point in 2D space. Edit the token vectors by dragging, and watch all five outputs update simultaneously.

Five Pooling Methods, One Input

All five methods pool the same 6 tokens. Colored dots show each method's output in 2D. Drag tokens (circles) to rearrange. Watch how each method responds differently to the same change.

Attn query angle 45°

All Methods in One Script

python
import torch
import torch.nn as nn
import torch.nn.functional as F

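class MeanPool(nn.Module):
    def forward(self, x, mask=None):
        # x: (B, T, D); mask: (B, T) bool, True for real tokens
        if mask is not None:
            m = mask.unsqueeze(-1).float()  # (B, T, 1)
            # clamp guards against division by zero for all-padding rows
            return (x * m).sum(dim=1) / m.sum(dim=1).clamp(min=1e-9)
        return x.mean(dim=1)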

class MaxPool(nn.Module):
    def forward(self, x, mask=None):
        if mask is not None:
            x = x.masked_fill(~mask.unsqueeze(-1), float('-inf'))
        return x.max(dim=1).values

class CLSPool(nn.Module):
    def forward(self, x, mask=None):
        return x[:, 0]  # First token

class LastTokenPool(nn.Module):
    def forward(self, x, mask=None):
        if mask is not None:
            # Index of the last real token per sequence (assumes right padding)
            last_idx = mask.long().sum(dim=1) - 1
            batch_idx = torch.arange(x.size(0), device=x.device)
            return x[batch_idx, last_idx]
        return x[:, -1]

class GatedPool(nn.Module):
    def __init__(self, d_model):
        super().__init__()
        self.gate = nn.Linear(d_model, 1)

    def forward(self, x, mask=None):
        g = torch.sigmoid(self.gate(x))  # (B, T, 1)
        if mask is not None:
            g = g * mask.unsqueeze(-1).float()
        return (g * x).sum(dim=1) / g.sum(dim=1).clamp(min=1e-8)


# Compare all methods on the same input
x = torch.randn(1, 10, 64)
pools = {
    'Mean': MeanPool(),
    'Max': MaxPool(),
    'CLS': CLSPool(),
    'Last': LastTokenPool(),
    'Gated': GatedPool(64),
}
for name, pool in pools.items():
    out = pool(x)
    print(f"{name:6s}: mean={out.mean():.3f}, std={out.std():.3f}")
# Illustrative behavior; no seed is set, so exact numbers vary from run to run.
# Mean:  std ≈ 0.3   (averaging 10 tokens shrinks std by about sqrt(10))
# Max:   mean ≈ 1.5  (largest of 10 standard-normal draws per dimension)
# CLS:   std ≈ 1.0   (single token, full variance)
# Last:  std ≈ 1.0   (single token, full variance)
# Gated: std ≈ 0.3   (close to mean pooling at initialization)
Last-token pooling only works for causal models. In BERT (bidirectional), the last token has no special status — it's seen the same context as every other token via bidirectional attention. CLS was invented specifically because bidirectional models have no natural "summary" position. In GPT-style models, the last token naturally serves this role because causal masking means it's the only token that has seen the entire input. Using last-token pooling on a BERT model would give you an arbitrary token's representation — no better than picking any random position.
Why do GPT-style decoder models use last-token pooling instead of CLS?

Chapter 6: The Aggregation Explorer

You've learned six different ways to compress a variable-length input into a fixed-size output. Time to see them all in action on the same data — and discover what each one sees that the others miss.

This is your aggregation laboratory. Switch between vision mode (2D feature map) and NLP mode (token sequence). Pick any pooling method. Add noise. Watch which methods survive the noise and which crumble. The goal: build intuition for when to reach for which tool.

The single most important insight in this lesson: There is no universally best pooling method. Max pooling is great for detecting the presence of a feature (is there an edge somewhere?). Mean pooling is great for capturing overall magnitude (how active is this region?). Attention pooling is great when different inputs need different weighting. The right choice depends on your task, your data, and your architecture.
The Aggregation Explorer

Switch between vision and NLP mode. Select a pooling method. Add noise to see robustness. The information retention bar shows how much input variance survives the pooling operation.

Mode
Method
Pool size (vision) 2×2
Noise level 0.0

What to Look For

Try these experiments:

Experiment 1 — Noise robustness. Set noise to 0. Note the output. Crank noise to 2.0. Max pooling's output jumps wildly (noise creates outliers that dominate). Mean pooling barely budges (noise averages out). Attention pooling adapts — it down-weights noisy positions if they don't match the query.

Experiment 2 — Vision vs NLP. Switch modes. In vision mode, local pooling (2×2, 3×3) reduces spatial resolution. You can see how max pooling preserves bright spots while mean pooling smooths the whole grid. In NLP mode, there's no spatial locality — the pooling is purely over the token dimension.

Experiment 3 — Information retention. Watch the retention bar at the bottom. GAP (Global Average Pooling) has the lowest retention — it crushes an entire feature map to one number per channel. Attention pooling and gated pooling retain more because they can selectively preserve important information.

Experiment 4 — CLS vs Last. In NLP mode, CLS and last-token give completely different outputs from the same input. CLS takes position 0 (which might be garbage if there's no trained [CLS] embedding). Last-token takes the final position. Neither uses the full sequence unless the upstream model has already mixed information across positions.
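Experiment 1 can also be approximated offline. The sketch below is a rough stand-in for the simulation, not a reproduction of it: it takes one channel of made-up activations, adds Gaussian noise at level 2.0, and measures how far the mean-pooled and max-pooled values move on average.

```python
import random

def pooled_shift(pool_fn, clean, noise_std, trials=500, seed=0):
    """Average absolute change in the pooled value when noise is added."""
    rng = random.Random(seed)
    base = pool_fn(clean)
    total = 0.0
    for _ in range(trials):
        noisy = [v + rng.gauss(0, noise_std) for v in clean]
        total += abs(pool_fn(noisy) - base)
    return total / trials

def mean_fn(values):
    return sum(values) / len(values)

# Hypothetical clean activations for one channel across 16 positions
clean = [0.1 * i for i in range(16)]

mean_shift = pooled_shift(mean_fn, clean, noise_std=2.0)
max_shift = pooled_shift(max, clean, noise_std=2.0)
print(f"mean pooling shift: {mean_shift:.2f}")
print(f"max pooling shift:  {max_shift:.2f}")
```

Averaging cancels much of the noise, while the max latches onto whichever position received the largest positive noise, so its shift is several times larger.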

Chapter 7: Adaptive Pooling & Modern Tricks

What if your model needs to handle images of different sizes? A 224×224 image produces a 7×7 feature map after the convolutional backbone. A 384×384 image produces a 12×12. But the classifier head needs a fixed-size input. You can't change the number of neurons in a linear layer at inference time.

Adaptive pooling solves this: "give me a 1×1 output no matter what size comes in." The pooling window and stride are computed from the input size, so the output size is always what you asked for. Most modern CNN classifiers end with it; torchvision's ResNet, for example, applies nn.AdaptiveAvgPool2d((1, 1)) before its final linear layer, which is why it accepts inputs of varying resolution.

How Adaptive Pooling Works

Regular pooling says: "use a 2×2 window with stride 2." If your input is 7×7, you get a 3×3 output (by default the leftover last row and column are dropped). If your input is 12×12, you get 6×6. The output size depends on the input size.
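The output-size arithmetic for regular pooling can be written down directly. This sketch follows the formula documented for PyTorch's pooling layers; the function name `pool_output_size` is made up for the example. With the default floor rounding the leftover row is dropped, and with ceil rounding a partial window is kept instead.

```python
import math

def pool_output_size(n, kernel, stride, padding=0, ceil_mode=False):
    """Output length along one axis for a fixed-window pooling layer."""
    span = n + 2 * padding - kernel
    if ceil_mode:
        out = math.ceil(span / stride) + 1
        # The last window must start inside the input (or its left padding)
        if (out - 1) * stride >= n + padding:
            out -= 1
        return out
    return span // stride + 1

print(pool_output_size(7, kernel=2, stride=2))                   # 3
print(pool_output_size(12, kernel=2, stride=2))                  # 6
print(pool_output_size(7, kernel=2, stride=2, ceil_mode=True))   # 4
```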

Adaptive pooling says: "give me a 1×1 output." Then it computes the window size and stride that achieve this:

The adaptive recipe. For target output size H_out × W_out from input H_in × W_in: output row i pools over input rows floor(i · H_in / H_out) up to, but not including, ceil((i+1) · H_in / H_out), and columns are handled the same way with W_in and W_out. This is the rule PyTorch uses. The window size adapts to the input. When the division isn't exact, neighboring windows can differ slightly in size and may overlap.

The most common use: nn.AdaptiveAvgPool2d((1,1)) — global average pooling with adaptive sizing. Input: any (C, H, W). Output: always (C, 1, 1). Squeeze the spatial dimensions, and you have a C-dimensional vector ready for a classifier.

Hand Calculation: Adaptive Pool

AdaptiveAvgPool on a 6×6 input, target output 2×2:

Hin = 6, Hout = 2. Each output cell covers 6/2 = 3 input rows and 3 input columns.

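Now with a 7×7 input, target 2×2: 7/2 = 3.5, so the windows can't tile the input evenly. PyTorch takes each window's start with floor and its end with ceil, so output row 0 covers input rows [0:4] and output row 1 covers input rows [3:7]: both windows span 4 rows and share row 3. The adaptive algorithm handles this automatically.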
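The window computation is short enough to sketch along one axis. This follows the start/end rule PyTorch uses for adaptive pooling (start rounded down, end rounded up), which means windows can overlap when the input size is not a multiple of the output size. The helper names are made up for this example.

```python
def adaptive_windows(n_in, n_out):
    """Half-open (start, end) input ranges for each output position along one axis."""
    windows = []
    for i in range(n_out):
        start = (i * n_in) // n_out                     # floor(i * n_in / n_out)
        end = ((i + 1) * n_in + n_out - 1) // n_out     # ceil((i + 1) * n_in / n_out)
        windows.append((start, end))
    return windows

def adaptive_avg_pool_1d(values, n_out):
    """Average each adaptive window of a 1D list."""
    return [sum(values[s:e]) / (e - s) for s, e in adaptive_windows(len(values), n_out)]

print(adaptive_windows(6, 2))   # [(0, 3), (3, 6)]
print(adaptive_windows(7, 2))   # [(0, 4), (3, 7)]
print(adaptive_windows(7, 1))   # [(0, 7)]
print(adaptive_avg_pool_1d([1, 2, 3, 4, 5, 6, 7], 2))  # [2.5, 5.5]
```

The 2D case applies the same rule independently to rows and columns.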

GeM — Generalized Mean Pooling

Here's a beautiful idea from image retrieval (Radenovic et al., 2018): what if we could interpolate between average pooling and max pooling with a single learnable parameter?

GeM(x) = ((1/N) · ∑_i x_i^p)^(1/p)

When p = 1, this is exactly the arithmetic mean (average pooling). As p grows, the result is increasingly dominated by larger values. In the limit p → ∞, it converges to the max. The parameter p is learned — the network discovers whether max or average (or something in between) is best for its task.

Why GeM wins for retrieval. In image retrieval, you want to match images based on distinctive local features (landmarks, logos, textures). Average pooling dilutes these signals by mixing them with featureless backgrounds. Max pooling is too aggressive: one noisy activation dominates. GeM with learned p ≈ 3 emphasizes distinctive features without being dominated by outliers. That's why GeM is a common choice in competitive image retrieval systems, including many evaluated on landmark benchmarks such as Google Landmarks Dataset v2 (GLDv2).

Hand Calculation: GeM

Values: [1.0, 0.5, 2.0, 0.3], power p = 3.

Step 1 — Raise each value to power p:

1.0³ = 1.000, 0.5³ = 0.125, 2.0³ = 8.000, 0.3³ = 0.027

Step 2 — Average the powered values:

Mean = (1.000 + 0.125 + 8.000 + 0.027) / 4 = 9.152 / 4 = 2.288

Step 3 — Take the p-th root:

2.288^(1/3) ≈ 1.318

Compare:

Method | Result | Interpretation
Average (p=1) | 0.950 | All values contribute equally
GeM (p=3) | 1.318 | Biased toward larger values
GeM (p=10) | 1.741 | Strongly biased toward max
Max (p→∞) | 2.000 | Only the largest survives

As p increases, the output slides smoothly from mean toward max. At p = 3, the value 2.0 already dominates: it contributes 8.0 out of 9.152 to the sum of cubes. The 0.3 contributes almost nothing (0.027). GeM naturally amplifies strong features and suppresses weak ones — but smoothly, not all-or-nothing like max.
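To check the arithmetic for any p, here is a plain-Python sketch of the GeM formula applied to the same four values. The small `eps` clamp mirrors the usual guard against non-positive inputs and is an assumption of this example.

```python
def gem(values, p, eps=1e-6):
    """Generalized mean: ((1/N) * sum(x_i ** p)) ** (1/p), for positive inputs."""
    clamped = [max(v, eps) for v in values]
    return (sum(v ** p for v in clamped) / len(clamped)) ** (1.0 / p)

values = [1.0, 0.5, 2.0, 0.3]

# Sweep p upward and watch the result climb from the mean toward the max
results = {p: gem(values, p) for p in (1, 2, 3, 5, 10, 50)}
for p, r in results.items():
    print(f"p = {p:>2}: {r:.3f}")
print(f"max   : {max(values):.3f}")

# Share of the sum of cubes contributed by the largest value (p = 3)
share = 2.0 ** 3 / sum(v ** 3 for v in values)
print(f"largest value's share at p=3: {share:.1%}")
```

The output never exceeds the largest input and never drops below the arithmetic mean for p ≥ 1.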

See It Live

Drag the power slider from 0.5 to 10. Watch the GeM output smoothly interpolate between mean pooling (p=1) and max pooling (p large). Below the visualization: a bar chart showing how much each input value contributes to the final result.

GeM Pooling: The Power Slider

Adjust p to see how GeM interpolates between average and max. At p=1 all bars are equal (average). As p grows, large values dominate. The top shows a feature map; the bottom shows effective contribution weights.

Power p 1.0

From Scratch: Adaptive Pool & GeM

python
import torch
import torch.nn as nn
import torch.nn.functional as F

# --- Adaptive pooling (built-in, but here's what it does) ---
x = torch.randn(1, 64, 7, 7)   # from 224×224 input
y = torch.randn(1, 64, 12, 12)  # from 384×384 input

gap = nn.AdaptiveAvgPool2d((1, 1))
print(gap(x).shape)  # (1, 64, 1, 1) — always!
print(gap(y).shape)  # (1, 64, 1, 1) — same output size

# Adaptive to 3×3 (used in SPPNet-style architectures)
spp = nn.AdaptiveMaxPool2d((3, 3))
print(spp(x).shape)  # (1, 64, 3, 3)
print(spp(y).shape)  # (1, 64, 3, 3) — same!


# --- GeM (Generalized Mean Pooling) from scratch ---
class GeM(nn.Module):
    """Generalized Mean Pooling.
    p=1: average pooling. p→∞: max pooling.
    Input:  (B, C, H, W)
    Output: (B, C, 1, 1)
    """
    def __init__(self, p=3.0, eps=1e-6):
        super().__init__()
        # p is learnable!
        self.p = nn.Parameter(torch.tensor(p))
        self.eps = eps

    def forward(self, x):
        # Clamp to avoid log(0) or negative values
        x_clamped = x.clamp(min=self.eps)
        # Raise to power p, average, take p-th root
        return (x_clamped.pow(self.p)
                .mean(dim=(-2, -1), keepdim=True)
                .pow(1.0 / self.p))


# Use it in a retrieval model
gem = GeM(p=3.0)
features = torch.randn(4, 2048, 7, 7)
pooled = gem(features).squeeze(-1).squeeze(-1)  # (4, 2048)
print(f"p learned value: {gem.p.item():.2f}")  # starts at 3.0

# After training on landmarks: p ≈ 2.5-4.0 typically
# Different channels may benefit from different p values
Adaptive pooling isn't magic — it just computes the window size dynamically. nn.AdaptiveAvgPool2d((1,1)) on a 7×7 input uses a 7×7 window. On a 14×14 input, a 14×14 window. The pooling operation is identical to regular pooling; only the window sizing is automatic. There's no neural network, no learned parameters — it's pure arithmetic that adapts the kernel to hit your target output size.

Stride Pooling: The Learned Alternative

One more modern trick: instead of a fixed pooling operation, use a strided convolution. A regular convolution with stride 2 halves the spatial resolution — just like 2×2 max pooling. But the convolution kernel is learned, so the network decides how to downsample instead of being forced into max or average.

ResNet uses max pooling early on. Many modern architectures (ConvNeXt, EfficientNet-v2) replace this with a strided conv — one less design choice to make, and the model can learn what matters. The tradeoff: strided convolutions add parameters and computation; max/avg pooling is free.
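The cost side of that tradeoff is simple arithmetic. This sketch counts the weights in a strided convolution and computes its output size with the standard convolution formula; the channel count, kernel size, and input size are example values chosen for illustration, not taken from any specific architecture.

```python
def conv2d_params(c_in, c_out, kernel, bias=True):
    """Number of learnable parameters in a square-kernel 2D convolution."""
    return kernel * kernel * c_in * c_out + (c_out if bias else 0)

def conv_output_size(n, kernel, stride, padding=0):
    """Output length along one axis: floor((n + 2*padding - kernel) / stride) + 1."""
    return (n + 2 * padding - kernel) // stride + 1

# Example: 3x3 conv, stride 2, padding 1, 64 -> 64 channels, on a 56x56 map
print(conv2d_params(64, 64, kernel=3))                    # 36928
print(conv_output_size(56, kernel=3, stride=2, padding=1))  # 28

# A 2x2 max pool with stride 2 reaches the same 28x28 with zero parameters
print((56 - 2) // 2 + 1)                                  # 28
```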

Approach | Parameters | Learns? | Use case
Max Pool | 0 | No | Classic CNNs (VGG, ResNet)
Avg Pool | 0 | No | Smooth downsampling, GAP for classifiers
Adaptive Pool | 0 | No | Resolution-agnostic final layer
GeM | 1 (p) | Yes | Image retrieval, fine-grained recognition
Strided Conv | k²·C² | Yes | Modern architectures replacing fixed pooling
Attention Pool | d (query) | Yes | Sequence aggregation, set aggregation
In GeM pooling, what happens as the power parameter p increases?

Chapter 8: The Pooling Arena

Time to put all eight pooling methods head-to-head. Pick a task below — Classification, Similarity, or Retrieval — and watch each method race to show its strengths (and weaknesses).

The Setup: Each method processes the same set of token embeddings. We measure discriminability (how separable are different classes), robustness (how stable under noise), and speed (relative FLOPs). The best method depends on the task.

What you're seeing: Each bar represents one pooling method. The bar length shows a composite score combining discriminability, robustness, and speed — weighted differently for each task. The star marks the winner.

Key takeaways from the arena:
  • Classification — Attention pooling and GeM dominate because they learn to focus on the most class-relevant features.
  • Similarity — Mean pooling and CLS token win because stable, full-sequence representations make better similarity anchors.
  • Retrieval — GeM and attention pooling excel because they produce more distinctive embeddings with less information loss.

No single method wins everything. That's the lesson. The right pooling depends on what you're optimizing for — and the best practitioners choose deliberately.

Chapter 9: Cheat Sheet

Everything from this lesson on one page. Pin it, print it, screenshot it.

Decision Flowchart

Which pooling should I use?

1. Are your inputs fixed-length grids (images)?
  → Yes: Start with Global Average Pooling (GAP). Try GeM if retrieval matters.
  → No: Go to 2.

2. Are you fine-tuning a pretrained transformer?
  → Yes, with [CLS] pretraining: Use the CLS token.
  → Yes, without [CLS]: Use mean pooling (e.g., Sentence-BERT style).
  → No: Go to 3.

3. Do some tokens matter more than others?
  → Yes, and you have labels: Attention pooling (learnable query).
  → Yes, but no extra labels: Weighted mean with IDF or position weights.
  → No, all tokens equal: Mean pooling.

4. Do you need max discrimination for retrieval/ranking?
  → Yes: GeM (p>1 sharpens features). Tune p on validation set.
  → No: Mean or attention pooling.

Comparison Table

Method | Learnable? | Best For | Weakness | FLOPs
Max Pool | No | Edge/texture detection (CNNs) | Discards magnitude, sensitive to outliers | O(n)
Average Pool | No | Smooth features, denoising | Dilutes strong signals | O(n)
GAP | No | CNN classification (ResNet, EfficientNet) | Treats all spatial positions equally | O(HW)
CLS Token | Yes (implicit) | BERT-style fine-tuning | Requires [CLS] pretraining; single bottleneck | O(1) at pool time
Mean Pool | No | Sentence embeddings (SBERT) | Padding tokens dilute signal (mask them!) | O(n)
Attention Pool | Yes | Multi-instance learning, set classification | Extra parameters; can overfit on small data | O(nd)
GeM | Yes (p) | Image retrieval, place recognition | Assumes positive activations (post-ReLU) | O(n)
Last Token | No | Causal LMs (GPT-style) | Only sees leftward context in final position | O(1) at pool time

PyTorch Quick Reference

import torch
import torch.nn as nn

# ── Max Pool (2D, stride=2) ──
pool = nn.MaxPool2d(kernel_size=2, stride=2)
out = pool(x)  # [B, C, H/2, W/2]

# ── Global Average Pooling ──
gap = nn.AdaptiveAvgPool2d(1)
out = gap(x).squeeze(-1).squeeze(-1)  # [B, C]

# ── Mean Pooling (with mask) ──
def mean_pool(hidden, mask):
    mask_exp = mask.unsqueeze(-1).float()
    return (hidden * mask_exp).sum(1) / mask_exp.sum(1).clamp(min=1e-9)

# ── CLS Token ──
cls_embed = hidden_states[:, 0, :]  # [B, D]

# ── Attention Pooling ──
class AttentionPool(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.query = nn.Parameter(torch.randn(dim))
    def forward(self, x):  # x: [B, N, D]
        scores = (x @ self.query) / x.size(-1)**0.5
        weights = scores.softmax(dim=1).unsqueeze(-1)
        return (x * weights).sum(1)

# ── GeM Pooling ──
class GeM(nn.Module):
    def __init__(self, p=3.0):
        super().__init__()
        self.p = nn.Parameter(torch.tensor(p))
    def forward(self, x):  # x: [B, C, H, W]
        return x.clamp(min=1e-6).pow(self.p) \
                .mean(dim=[-2,-1]).pow(1/self.p)

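# ── Last Token (causal LM) ──
# Valid only when no sequence is right-padded; otherwise index the
# last non-padding position of each sequence instead.
last = hidden_states[:, -1, :]  # [B, D]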

Connections

Where pooling shows up next:
  • Sentence Transformers — Mean pooling + contrastive loss = state-of-the-art embeddings.
  • Vision Transformers (ViT) — CLS token vs GAP is an active design choice.
  • Retrieval-Augmented Generation (RAG) — Embedding quality = retrieval quality. Pooling is step zero.
  • Multi-Instance Learning — Attention pooling over bag instances is the standard architecture.
  • Diffusion Models — Cross-attention pooling connects text embeddings to image generation.
  • Graph Neural Networks — Global readout = pooling over nodes. Same ideas, different domain.
🔬 Final Check

You're building a semantic search engine using a BERT model that was pretrained without a [CLS] token objective. Your inputs are variable-length sentences. Which pooling strategy should you use?

Mean pooling with attention mask is the right choice. Without [CLS] pretraining, the first token has no special aggregation role. Mean pooling captures information from all tokens equally, and the attention mask prevents padding tokens from contaminating the representation. This is the default pooling in Sentence-BERT, whose authors report that it outperformed CLS pooling in their experiments, even though BERT was pretrained with [CLS].