Every message-passing GNN is a choice of message function and aggregation function. Once you see this, GCN, GraphSAGE, and GAT become three points in the same design space.
In Lecture 3 you built the GCN — one specific choice for how to aggregate neighbor information. But why that choice? What if neighbors aren't equally important? What if you want to keep your own features separate from your neighbors' features? The answer is that GCN is just one point in a vast GNN design space.
You, Ying, and Leskovec (NeurIPS 2020) formalized this insight: every GNN layer can be decomposed into exactly two steps, and different choices for these steps give you different architectures. Understanding this decomposition means you can reason about ALL GNNs, not just memorize three names.
Each neighbor u computes a message to send to node v. The message is a function of u's current embedding. At layer l:
$$m_u^{(l)} = \mathrm{MSG}^{(l)}\big(h_u^{(l-1)}\big)$$
The simplest MSG is a linear transformation: $\mathrm{MSG}^{(l)}(h_u^{(l-1)}) = W^{(l)} h_u^{(l-1)}$. But you could use a more complex neural network, scale the message by some factor, or weight it by an attention score; all are valid choices.
Node v receives messages from all its neighbors and must collapse them into a single vector, which it then uses to update its own embedding:
$$h_v^{(l)} = \mathrm{AGG}^{(l)}\Big(\{m_u^{(l)} : u \in N(v)\},\; h_v^{(l-1)}\Big)$$
AGG must be a function that (1) takes a multiset of messages, where order must not matter because graphs have no canonical node ordering, and (2) incorporates the node's own previous embedding $h_v^{(l-1)}$. Common choices: mean, sum, max, or a neural network applied after pooling.
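To make the decomposition concrete, here is a minimal NumPy sketch of one generic message-passing layer in which MSG and AGG are ordinary functions passed in as arguments. The names `mp_layer` and `make_agg`, the separate self-weight `W_self`, and the ReLU update are illustrative assumptions, not any library's API.

```python
import numpy as np

def mp_layer(h, neighbors, msg, agg):
    """One generic message-passing layer.

    h:         (n, d) array of embeddings h^{(l-1)}
    neighbors: neighbors[v] is the list of nodes in N(v)
    msg:       maps one neighbor embedding h_u to a message m_u
    agg:       maps (list of messages, own embedding h_v) to h_v^{(l)}
    """
    return np.stack([agg([msg(h[u]) for u in nbrs], h[v])
                     for v, nbrs in enumerate(neighbors)])

# Permutation-invariant pooling functions over a list of messages
POOLS = {
    "sum": lambda msgs: np.sum(msgs, axis=0),
    "mean": lambda msgs: np.mean(msgs, axis=0),
    "max": lambda msgs: np.max(msgs, axis=0),
}

def make_agg(pool, W_self):
    """AGG = pool over neighbor messages + transformed own embedding, then ReLU."""
    def agg(messages, h_v):
        return np.maximum(POOLS[pool](messages) + W_self @ h_v, 0.0)
    return agg

# Path graph 0 - 1 - 2 with 2-dimensional features
h = np.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
neighbors = [[1], [0, 2], [1]]
W = np.array([[1.0, 0.0], [0.0, 2.0]])
msg = lambda h_u: W @ h_u          # MSG(h_u) = W h_u

for pool in POOLS:
    print(pool, mp_layer(h, neighbors, msg, make_agg(pool, W))[1])
# node 1 -> sum: [2, 4], mean: [1, 3], max: [1, 4]
```

Swapping `pool` or `msg` moves you to a different point in the design space without touching `mp_layer` itself.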
Concretely, for MSG you can vary: the weight matrix (shared vs per-layer), whether to normalize by degree, whether to include edge features. For AGG you can vary: mean vs sum vs max vs attention-weighted, whether to concatenate self-embedding vs add it, whether to apply a nonlinearity before or after aggregation. That's a huge combinatorial space — and You et al. found that simple designs often outperform complex ones on many benchmarks.
The design space extends beyond MSG and AGG; it can be organized into four dimensions, of which the MSG and AGG choices are only the first:
1. Intra-layer design: MSG function, AGG function, activation, normalization (none, batch norm, layer norm). These are the choices we discuss in chapters 1-3.
2. Inter-layer design: How layers connect to each other. Options: stack (plain), residual (add skip), dense (all-to-all connections), jumping knowledge (concatenate all layers).
3. Layer connectivity: How many GNN layers vs how many MLP layers before/after. Pre-processing MLPs on raw features often help dramatically. Post-processing MLPs for the final prediction.
4. Learning objective: Supervised (cross-entropy, MSE), self-supervised (contrastive), semi-supervised. The objective drives what the embeddings learn to encode.
You et al. defined a design space of over 315,000 GNN configurations and studied it with a controlled random search rather than training every configuration exhaustively. Their top finding: no single design choice dominates across all tasks. The winning configuration on molecular property prediction looks completely different from the winning configuration on social network link prediction. This means the "design space" framing is not just academic: it is practical guidance to explore your specific problem's sweet spot rather than defaulting to GCN.
The right MSG+AGG choice often follows from the task structure; the decision table near the end of this lecture summarizes common cases.
Now that we have the unified framework, let's re-read GCN through this lens. You already know GCN from Lecture 3 — but seeing it as MSG + AGG reveals the specific choices Kipf & Welling made, and makes it easy to compare with GraphSAGE and GAT.
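In GCN, the message from neighbor u to node v is a linear transformation of u's embedding, normalized by both their degrees (counted in the self-loop-augmented graph Ã introduced below):
$$m_{u \to v}^{(l)} = \frac{1}{\sqrt{|\tilde N(u)|\,|\tilde N(v)|}}\, W^{(l)} h_u^{(l-1)}$$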
This is symmetric normalization — normalize by the geometric mean of both degrees. The motivation: if u has 100 neighbors, it sends weaker messages (its influence is "diluted" across 100 targets). If v has 100 neighbors, each incoming message is weaker (v averages 100 signals). This prevents high-degree nodes from dominating.
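Aggregation is a sum (which, after normalization, behaves like a weighted mean). The self-loop is included via the augmented graph $\tilde A = A + I$, so node v is in its own neighborhood $\tilde N(v)$:
$$h_v^{(l)} = \sigma\Big(\sum_{u \in \tilde N(v)} \frac{1}{\sqrt{|\tilde N(u)|\,|\tilde N(v)|}}\, W^{(l)} h_u^{(l-1)}\Big)$$
Because the messages are normalized, this sum acts like a weighted average in which higher-degree neighbors contribute less (the weights sum to exactly 1 only in special cases, such as when all degrees are equal). The self-loop means v's own previous embedding participates in the sum with weight $1/|\tilde N(v)|$, where $\tilde N(v)$ includes v itself.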
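A quick NumPy check of these weights on a small hypothetical 4-node graph: it builds $\tilde D^{-1/2}\tilde A\tilde D^{-1/2}$ and confirms that the self weight is $1/|\tilde N(v)|$, that a low-degree neighbor receives a larger weight than a higher-degree one, and that a row of weights need not sum to 1. The one-hot features and all-ones weight matrix are placeholders.

```python
import numpy as np

# Hypothetical 4-node graph: node 0 links to 1, 2, 3; nodes 2 and 3 also link to each other.
A = np.array([[0, 1, 1, 1],
              [1, 0, 0, 0],
              [1, 0, 0, 1],
              [1, 0, 1, 0]], dtype=float)

A_tilde = A + np.eye(len(A))                   # Ã = A + I (self-loops)
deg = A_tilde.sum(axis=1)                      # |Ñ(v)|, counting v itself: 4, 2, 3, 3
A_hat = A_tilde / np.sqrt(np.outer(deg, deg))  # entry (v, u) = 1 / sqrt(|Ñ(v)| |Ñ(u)|) if u ∈ Ñ(v)

print(A_hat[0].round(3))        # weights into node 0, approximately 0.25, 0.354, 0.289, 0.289
print(A_hat[0].sum().round(3))  # approximately 1.181, so not an exact mean

# One GCN layer in matrix form: H_next = ReLU(A_hat @ H @ W)
H = np.eye(4)                   # one-hot input features, for illustration
W = np.ones((4, 2))             # hypothetical weight matrix
H_next = np.maximum(A_hat @ H @ W, 0.0)   # shape (4, 2)
```

Node 1 has only one real neighbor, so its message to node 0 is damped less than the messages from nodes 2 and 3.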
Let's be precise about tensor shapes. Graph: n = 1000 nodes. Input features: $d_{in} = 64$. Two GCN layers with $d_{hidden} = 128$ and $d_{out} = 64$.
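A NumPy sketch that builds this configuration with random data and prints each intermediate shape. The random graph (edge probability 0.01), the weight scale, the dense adjacency storage, and the ReLU between the two layers are illustrative assumptions.

```python
import numpy as np

rng = np.random.default_rng(0)
n, d_in, d_hidden, d_out = 1000, 64, 128, 64

# Hypothetical undirected random graph, stored densely for clarity.
upper = np.triu(rng.random((n, n)) < 0.01, k=1)
A = (upper | upper.T).astype(float)                 # (1000, 1000)
A_tilde = A + np.eye(n)
deg = A_tilde.sum(axis=1)
A_hat = A_tilde / np.sqrt(np.outer(deg, deg))       # (1000, 1000)  D̃^-1/2 Ã D̃^-1/2

X = rng.normal(size=(n, d_in))                      # (1000, 64)   input features
W1 = rng.normal(scale=0.1, size=(d_in, d_hidden))   # (64, 128)    layer-1 weights
W2 = rng.normal(scale=0.1, size=(d_hidden, d_out))  # (128, 64)    layer-2 weights

H1 = np.maximum(A_hat @ X @ W1, 0.0)                # (1000, 128)  layer 1 + ReLU
H2 = A_hat @ H1 @ W2                                # (1000, 64)   final embeddings

for name, M in [("A_hat", A_hat), ("X", X), ("W1", W1), ("H1", H1), ("W2", W2), ("H2", H2)]:
    print(f"{name:6s}{M.shape}")
print("trainable parameters:", W1.size + W2.size)   # 64*128 + 128*64 = 16384
```

The parameter count depends only on the feature dimensions, not on n, so the same W1 and W2 can be applied to a graph of any size.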
GCN uses uniform weights — every neighbor contributes equally (after degree normalization). But in a citation network, a paper cited by Nature and a paper cited by an obscure blog are very different. Should they carry the same weight in a neighbor's embedding? GCN says yes. GAT will say no.
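Also, GCN's normalization by degree is a hand-crafted heuristic. Why the geometric mean? Why not just v's degree? Why not something learned? GraphSAGE sidesteps the question with a plain neighbor mean plus a separate path for the node's own embedding, and GAT replaces the fixed weights with learned attention.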
Hamilton, Ying, and Leskovec introduced GraphSAGE (Graph SAmple and aggreGatE) at NeurIPS 2017. The motivation was inductive learning: train on one graph (e.g., a portion of Reddit), run inference on new, unseen graphs or new nodes. GCN was transductive — it needed to see the full graph adjacency matrix during training. GraphSAGE fixes this with a key architectural difference: it keeps the node's own embedding separate from its neighborhood summary, and concatenates them.
GraphSAGE replaces the single pooled sum of GCN with an explicit two-stage operation:
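Stage 1 (aggregate from neighbors):
$$h_{N(v)}^{(l)} = \mathrm{AGG}\big(\{h_u^{(l-1)} : u \in N(v)\}\big)$$
Stage 2 (concatenate self with the neighborhood summary, then transform):
$$h_v^{(l)} = \sigma\Big(W^{(l)} \big[h_v^{(l-1)} \,\Vert\, h_{N(v)}^{(l)}\big]\Big)$$
The symbol ∥ means concatenation. If $h_v^{(l-1)}$ is d-dimensional and $h_{N(v)}^{(l)}$ is d-dimensional, the concatenated vector is 2d-dimensional, and $W^{(l)}$ maps from 2d to d.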
GraphSAGE proposes three aggregation functions for Stage 1:
| AGG Name | Formula | Properties |
|---|---|---|
| Mean | $h_{N(v)} = \frac{1}{\deg(v)} \sum_{u \in N(v)} h_u$ | Fastest. No extra parameters. Loses count information (can't distinguish "2 neighbors" from "5 neighbors" if they average to the same thing). |
| Pool | $h_{N(v)} = \max\big(\{\mathrm{ReLU}(W_{\text{pool}} h_u + b) : u \in N(v)\}\big)$ | Element-wise max after a per-neighbor MLP. Extra parameters $W_{\text{pool}}$. Can pick out the "most extreme" feature across neighbors. |
| LSTM | $h_{N(v)} = \mathrm{LSTM}\big([h_u : u \in \pi(N(v))]\big)$, with π a random permutation | Most expressive but NOT permutation-invariant. Requires random shuffling of neighbors during training to prevent ordering bias. Slower. |
In practice, mean-aggregation and pool-aggregation perform similarly. LSTM is often not worth the extra complexity and broken permutation invariance. The real win of GraphSAGE over GCN comes from the concat architecture, not from fancy AGG functions.
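GraphSAGE also introduced neighbor sampling: instead of aggregating over ALL neighbors (which can be 10,000 for hub nodes in a social graph), sample a fixed-size subset S (typically 25 for layer 1, 10 for layer 2). This makes mini-batch training feasible on massive graphs. For a single mean aggregation, the sampled mean is an unbiased estimate of the full neighbor mean; once sampled aggregates pass through nonlinearities and further layers, the resulting gradient is only an approximation of the full-neighborhood gradient, but the approximation works well in practice.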
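A minimal pure-Python sketch of per-hop uniform neighbor sampling. It assumes an adjacency-list dict and samples without replacement, taking every neighbor when the degree is at most the fanout (that convention, and the names `sample_neighbors` and `sample_hops`, are assumptions for illustration). It also checks numerically that a single sampled mean is an unbiased estimate of the full neighbor mean.

```python
import random

def sample_neighbors(adj, v, fanout, rng):
    """Uniformly sample at most `fanout` distinct neighbors of v."""
    nbrs = adj[v]
    return list(nbrs) if len(nbrs) <= fanout else rng.sample(nbrs, fanout)

def sample_hops(adj, batch, fanouts, rng):
    """Nodes reached at each hop when expanding `batch` with per-hop fanouts, e.g. [25, 10]."""
    hops = [sorted(set(batch))]
    for fanout in fanouts:
        frontier = set()
        for v in hops[-1]:
            frontier.update(sample_neighbors(adj, v, fanout, rng))
        hops.append(sorted(frontier))
    return hops

# Hypothetical hub: node 0 has 10,000 leaf neighbors, and each leaf only sees node 0.
adj = {0: list(range(1, 10_001))}
adj.update({u: [0] for u in range(1, 10_001)})

rng = random.Random(0)
hops = sample_hops(adj, batch=[0], fanouts=[25, 10], rng=rng)
print([len(h) for h in hops])   # [1, 25, 1]: the hub touches 25 leaves, not 10,000

# Unbiasedness of one sampled mean, with features x_u = u
x = {u: float(u) for u in adj}
full_mean = sum(x[u] for u in adj[0]) / len(adj[0])
est = [sum(x[u] for u in sample_neighbors(adj, 0, 25, rng)) / 25 for _ in range(2000)]
print(full_mean, sum(est) / len(est))   # the two numbers should be close
```

The unbiasedness shown here holds for one mean aggregation; stacking sampled layers with nonlinearities in between breaks exact unbiasedness.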
```python
import torch
import torch.nn as nn

class GraphSAGELayer(nn.Module):
    def __init__(self, in_dim, out_dim):
        super().__init__()
        # W maps from concatenated [self || neighbors] to output
        self.W = nn.Linear(in_dim * 2, out_dim)
        self.activation = nn.ReLU()

    def forward(self, h, adj):
        # h:   [N, in_dim] node features
        # adj: [N, N] dense 0/1 adjacency (unnormalized, no self-loops)
        # Stage 1: aggregate neighbors (mean)
        deg = adj.sum(dim=1, keepdim=True).clamp(min=1)  # [N, 1]
        h_neigh = (adj @ h) / deg                         # [N, in_dim], mean of neighbors
        # Stage 2: concatenate self + neighbor, then transform
        h_cat = torch.cat([h, h_neigh], dim=-1)           # [N, 2*in_dim]
        h_out = self.activation(self.W(h_cat))            # [N, out_dim]
        # Optional: L2 normalize (common in inductive settings)
        h_out = h_out / h_out.norm(dim=-1, keepdim=True).clamp(min=1e-6)
        return h_out
```
GraphSAGE's neighbor-sampling design made it the foundation for PinSage, Pinterest's industrial GNN for image recommendation (Ying et al., KDD 2018). PinSage operates on a graph with 3 billion nodes (pins and boards) and 18 billion edges; the full graph would require 100TB of memory just to store. Key adaptations:
Localized convolution: Only sample a fixed-size neighborhood (K=500 for layer 1, K=250 for layer 2). Training mini-batches never touch the full graph.
Importance-based sampling: Instead of uniform random sampling, weight neighbors by their random-walk visit probability to the central node — more important neighbors sampled more often.
Curriculum training: Train with easy negatives first (random negatives from across the corpus), then hard negatives (items that are almost but not quite relevant). Prevents collapse to trivial solutions.
MapReduce inference: After training, compute embeddings for all nodes in a MapReduce pipeline that shares neighborhood computations instead of repeating lookups for every node. The precomputed embeddings can then be served through nearest-neighbor lookup at Pinterest scale.
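A simplified pure-Python sketch of the importance-based neighborhood idea: run short random walks from v, count visits, keep the T most-visited nodes, and normalize their counts into importance weights. The walk count, walk length, normalizing over the top T only, and the function name `importance_neighborhood` are assumptions for illustration; this is not PinSage's production code.

```python
import random
from collections import Counter

def importance_neighborhood(adj, v, T, num_walks, walk_len, rng):
    """Return up to T (node, weight) pairs ranked by random-walk visit counts from v."""
    visits = Counter()
    for _ in range(num_walks):
        node = v
        for _ in range(walk_len):
            node = rng.choice(adj[node])
            if node != v:                  # do not count returns to the start node
                visits[node] += 1
    top = visits.most_common(T)
    total = sum(count for _, count in top)
    return [(u, count / total) for u, count in top]

# Hypothetical graph: a 4-clique {0, 1, 2, 3} with a tail 3 - 4 - 5 - 6.
adj = {0: [1, 2, 3], 1: [0, 2, 3], 2: [0, 1, 3], 3: [0, 1, 2, 4],
       4: [3, 5], 5: [4, 6], 6: [5]}

neigh = importance_neighborhood(adj, v=0, T=3, num_walks=500, walk_len=3, rng=random.Random(0))
print(neigh)   # clique members dominate; node 6 is 4 steps away and can never be visited
```

Unlike uniform sampling, the resulting neighborhood can include nodes that are not direct neighbors, and the weights can be reused to weight neighbors during aggregation.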
GCN and GraphSAGE both fix neighbor importance in advance instead of learning it: GCN through a hand-crafted function of degree, GraphSAGE's mean through uniform weights. In GCN, a neighbor with degree 1 contributes more than a neighbor with degree 100, simply because of the normalization. But surely the content of a neighbor matters more than its structural position? That's the question GAT (Velickovic et al., ICLR 2018) answers.
GAT learns attention weights $\alpha_{vu}$ that tell node v how much to attend to each neighbor u. These weights are computed from the features of both v and u, so the network can give some neighbors small weights and others large ones depending on their content. Crucially, the scoring function is learned, not hand-designed.
First, transform both node embeddings with the same weight matrix W, then score their compatibility with a single-layer feedforward network parametrized by the vector $\mathbf{a}$:
$$e_{vu} = \mathrm{LeakyReLU}\big(\mathbf{a}^\top [W h_v \,\Vert\, W h_u]\big)$$
Breaking this down: $W h_v$ and $W h_u$ project both embeddings into the same d'-dimensional space. Concatenating them gives a vector of length 2d'. The learnable vector $\mathbf{a}$ (length 2d') scores this concatenation with a single dot product. LeakyReLU keeps negative scores distinguishable (ReLU would map every negative score to 0). The result $e_{vu}$ is a scalar: how "compatible" v and u are.
Raw scores are not comparable across neighborhoods (each node's scores come on a different scale). Normalize them into a proper probability distribution over v's neighborhood with a softmax:
$$\alpha_{vu} = \frac{\exp(e_{vu})}{\sum_{w \in N(v)} \exp(e_{vw})}$$
Now $\sum_{u \in N(v)} \alpha_{vu} = 1$ and all weights are positive. Node v's attention over its neighbors is a proper distribution, and we can interpret each $\alpha_{vu}$ as "what fraction of attention v gives to u."
Use the attention weights to compute a weighted sum of the neighbors' transformed embeddings:
$$h_v^{(l)} = \sigma\Big(\sum_{u \in N(v)} \alpha_{vu}\, W h_u^{(l-1)}\Big)$$
Compare to GCN: replace the fixed weight $1/\sqrt{|\tilde N(u)|\,|\tilde N(v)|}$ (or a plain mean's uniform $1/|N(v)|$) with the learned weight $\alpha_{vu}$. Same structure, but now the network decides which neighbors matter most, for EACH node and EACH layer, based on features.
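Let's trace through one step of GAT for a small graph. Say node v has features $h_v = [1.0, 0.5]$ and has neighbors A with $h_A = [0.9, 0.6]$ and B with $h_B = [0.1, 0.8]$.
Setup: $W = I$ (identity, so $Wh = h$) and $\mathbf{a} = [1, -1, 0.5, -0.5]$, where the first two entries multiply $W h_v$ (the center node) and the last two multiply $W h_u$ (the neighbor).
Score for v→A: $[W h_v \,\Vert\, W h_A] = [1.0, 0.5, 0.9, 0.6]$.
$\mathbf{a}^\top$ · this = 1(1.0) + (-1)(0.5) + 0.5(0.9) + (-0.5)(0.6) = 1.0 - 0.5 + 0.45 - 0.30 = 0.65.
LeakyReLU(0.65) = 0.65.
Score for v→B: $[W h_v \,\Vert\, W h_B] = [1.0, 0.5, 0.1, 0.8]$.
$\mathbf{a}^\top$ · this = 1.0 - 0.5 + 0.5(0.1) + (-0.5)(0.8) = 0.5 + 0.05 - 0.40 = 0.15.
LeakyReLU(0.15) = 0.15.
Softmax: exp(0.65) = 1.916, exp(0.15) = 1.162. Sum = 3.078. $\alpha_{vA}$ = 1.916/3.078 = 0.622, $\alpha_{vB}$ = 1.162/3.078 = 0.378.
Output: $h_v^{\text{new}} = \sigma(0.622 \cdot [0.9, 0.6] + 0.378 \cdot [0.1, 0.8]) = \sigma([0.598, 0.676]) \approx [0.645, 0.663]$ (with σ = sigmoid).
Notice: node A gets 62.2% of the attention and node B gets 37.8%. Here W and $\mathbf{a}$ were fixed by hand for illustration; in a trained model they are learned by gradient descent on the task loss, not defined by any manual similarity function. One caveat about why A wins: splitting $\mathbf{a} = [\mathbf{a}_1, \mathbf{a}_2]$, the pre-activation score is $\mathbf{a}_1^\top W h_v + \mathbf{a}_2^\top W h_u$. The $h_v$ term (0.5 here) is identical for every neighbor, so A wins because $\mathbf{a}_2^\top h_A = 0.15$ exceeds $\mathbf{a}_2^\top h_B = -0.35$, not because A is more similar to v. Since LeakyReLU is monotonic, the ranking of any set of neighbors is the same no matter which center node attends to them; Brody et al. (ICLR 2022) call this "static attention" and proposed GATv2 to make the ranking depend on v.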
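The worked example can be reproduced in a few lines of pure Python. W = I, a = [1, -1, 0.5, -0.5], and the sigmoid output come from the example itself; the LeakyReLU negative slope of 0.2 is an assumption (it does not matter here because both scores are positive). The last lines split the score into its center part and neighbor part.

```python
import math

h_v = [1.0, 0.5]
neighbors = {"A": [0.9, 0.6], "B": [0.1, 0.8]}
a = [1.0, -1.0, 0.5, -0.5]            # first half multiplies W h_v, second half W h_u

def leaky_relu(x, slope=0.2):
    return x if x > 0 else slope * x

def dot(x, y):
    return sum(xi * yi for xi, yi in zip(x, y))

# W = I, so W h = h. Score e_vu = LeakyReLU(a^T [h_v || h_u]); list "+" concatenates.
scores = {u: leaky_relu(dot(a, h_v + h_u)) for u, h_u in neighbors.items()}

# Softmax over N(v)
z = sum(math.exp(e) for e in scores.values())
alpha = {u: math.exp(e) / z for u, e in scores.items()}

# Attention-weighted sum of neighbor embeddings, then sigmoid
pre = [sum(alpha[u] * neighbors[u][i] for u in neighbors) for i in range(2)]
out = [1 / (1 + math.exp(-p)) for p in pre]

print({u: round(e, 3) for u, e in scores.items()})              # {'A': 0.65, 'B': 0.15}
print({u: round(w, 3) for u, w in alpha.items()})               # {'A': 0.622, 'B': 0.378}
print([round(p, 3) for p in pre], [round(o, 3) for o in out])   # [0.598, 0.676] [0.645, 0.663]

# Static-attention check: the h_v part of the score is identical for both neighbors.
a1, a2 = a[:2], a[2:]
print(dot(a1, h_v), {u: round(dot(a2, h_u), 2) for u, h_u in neighbors.items()})
# 0.5 {'A': 0.15, 'B': -0.35}
```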
A single set of attention weights asks ONE question about each neighbor: "how important is u to v?" But importance can have multiple dimensions. In a citation network, paper v might attend to paper u because they share a topic — and simultaneously attend to paper w because they use similar methods. A single attention head can only capture one notion of importance at a time.
Multi-head attention runs K independent attention mechanisms in parallel, each with its own parameters Wk and ak. Each head learns to attend to different aspects of the neighborhood.
Run K heads, get K output embeddings per node, then concatenate:
$$h_v^{(l)} = \big\Vert_{k=1}^{K}\; \sigma\Big(\sum_{u \in N(v)} \alpha^{k}_{vu}\, W^{k} h_u^{(l-1)}\Big)$$
The notation ∥ means concatenation. If each head outputs a d/K-dimensional vector, the concatenated output is d-dimensional, the same total size as a single-head model. Each head k has its own attention weights $\alpha^{k}_{vu}$, computed with its own parameters $(W^{k}, \mathbf{a}^{k})$.
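A small NumPy sketch of K-head attention on a toy graph. It assumes self-loops are already included in each neighborhood (as in the GAT paper), uses ELU as σ and a LeakyReLU slope of 0.2, and supports both concatenation (hidden layers) and averaging (an output layer). Random weights stand in for learned ones, and the function names are hypothetical.

```python
import numpy as np

def leaky_relu(x, slope=0.2):
    return np.where(x > 0, x, slope * x)

def elu(x):
    return np.where(x > 0, x, np.exp(np.minimum(x, 0)) - 1)

def gat_head(h, neighbors, W, a):
    """One attention head: returns sum_u alpha_vu W h_u for every node v (before σ)."""
    z = h @ W                                   # (n, d_head)
    d_head = z.shape[1]
    out = np.zeros_like(z)
    for v, nbrs in enumerate(neighbors):
        z_n = z[nbrs]                           # (|N(v)|, d_head)
        e = leaky_relu(z[v] @ a[:d_head] + z_n @ a[d_head:])   # scores e_vu, u in N(v)
        alpha = np.exp(e - e.max())
        alpha /= alpha.sum()                    # softmax over N(v)
        out[v] = alpha @ z_n
    return out

def multi_head_gat(h, neighbors, Ws, As, concat=True):
    heads = [gat_head(h, neighbors, W, a) for W, a in zip(Ws, As)]
    if concat:                                  # hidden layer: σ per head, then concatenate
        return np.concatenate([elu(x) for x in heads], axis=1)
    return elu(np.mean(heads, axis=0))          # output layer: average heads, then σ

rng = np.random.default_rng(0)
n, d_in, K, d = 4, 6, 4, 8                      # K heads, each of size d // K
neighbors = [[0, 1, 2], [1, 0], [2, 0, 3], [3, 2]]   # self-loops included
h = rng.normal(size=(n, d_in))
Ws = [rng.normal(size=(d_in, d // K)) for _ in range(K)]
As = [rng.normal(size=2 * (d // K)) for _ in range(K)]

print(multi_head_gat(h, neighbors, Ws, As).shape)                 # (4, 8)
print(multi_head_gat(h, neighbors, Ws, As, concat=False).shape)   # (4, 2)
```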
Think of K heads as a K-model ensemble. If one head learns a bad set of attention weights early in training, the other K-1 heads can compensate. The gradient signal averaged across heads is lower-variance than a single head — similar to why mini-batch SGD is more stable than single-sample gradient estimates.
Empirically, GAT with K=8 heads typically outperforms K=1, with diminishing returns as more heads are added. In the original paper's transductive experiments (e.g., Cora), the hidden layer uses K=8 concatenated heads and the output layer uses a single head; on the inductive PPI task, the output layer averages K=6 heads.
One GAT layer with $d_{in}$ input dimensions, $d_{out}$ output dimensions, and K heads (bias terms ignored):
| Component | Shape | Count ($d_{in}$=64, $d_{out}$=64, K=8) |
|---|---|---|
| $W^k$ per head (K matrices) | [$d_{in}$, $d_{out}/K$] | 8 × (64 × 8) = 4096 |
| $\mathbf{a}^k$ per head (K vectors) | [2 × $d_{out}/K$] | 8 × 16 = 128 |
| Total | | 4224 parameters |
Vs. one GCN layer: $d_{in} \times d_{out}$ = 64 × 64 = 4096 parameters. GAT adds only 128 extra parameters (the attention vectors $\mathbf{a}^k$) over GCN, a tiny cost for learned attention weights.
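The counts in this table can be reproduced with a tiny helper. Like the table, it ignores bias terms and assumes each head outputs $d_{out}/K$ dimensions; the function names are hypothetical.

```python
def gat_layer_params(d_in, d_out, K):
    d_head = d_out // K
    weights = K * d_in * d_head       # K matrices W^k of shape [d_in, d_out/K]
    attention = K * 2 * d_head        # K vectors a^k of length 2 * d_out/K
    return weights + attention

def gcn_layer_params(d_in, d_out):
    return d_in * d_out               # a single W of shape [d_in, d_out]

def sage_layer_params(d_in, d_out):
    return 2 * d_in * d_out           # W sees the concatenation [h_v || h_N(v)]

print(gat_layer_params(64, 64, 8), gcn_layer_params(64, 64), sage_layer_params(64, 64))
# 4224 4096 8192
```

Because the heads split $d_{out}$, the attention overhead is $2 d_{out}$ for any K that divides $d_{out}$: one head or eight heads both cost 128 extra parameters here.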
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GATConv

class TwoLayerGAT(nn.Module):
    def __init__(self, d_in, d_hid, d_out, K=8):
        super().__init__()
        # Layer 1: K=8 heads, each head outputs d_hid // K → concat = d_hid
        self.gat1 = GATConv(d_in, d_hid // K, heads=K, concat=True,
                            dropout=0.6, negative_slope=0.2)
        # Layer 2: a single head (with heads > 1, concat=False would average them), outputs d_out
        self.gat2 = GATConv(d_hid, d_out, heads=1, concat=False,
                            dropout=0.6, negative_slope=0.2)
        self.dropout = nn.Dropout(p=0.6)

    def forward(self, x, edge_index):
        x = self.dropout(x)
        x = F.elu(self.gat1(x, edge_index))   # [N, d_hid]
        x = self.dropout(x)
        x = self.gat2(x, edge_index)          # [N, d_out], single output head
        return F.log_softmax(x, dim=-1)

# Mirrors the GAT paper's Cora setup (8 heads × 8 features in the hidden layer, 1-head output)
model = TwoLayerGAT(d_in=1433, d_hid=64, d_out=7, K=8)
# d_in=1433 (Cora bag-of-words), d_out=7 (7 classes), K=8 heads in hidden layer
```
Now we can make the comparison crisp. All three operate within the MSG + AGG framework — they're just different choices. The table below is the one mental model you should carry forward from this lecture.
| Property | GCN | GraphSAGE | GAT |
|---|---|---|---|
| MSG function | $W h_u / \sqrt{\deg(u)\,\deg(v)}$ | $h_u$ (no transform before AGG) | $\alpha_{vu} W h_u$ |
| AGG function | Sum (≈ weighted mean after normalization) | Mean / Pool / LSTM, then CONCAT with self | Weighted sum (weights = softmax attention) |
| Neighbor weights | Fixed: $1/\sqrt{\deg(u)\,\deg(v)}$ | Fixed: uniform mean (or max) | Learned: $\alpha_{vu}$ from features |
| Self vs neighbor | Mixed together (self-loop in Ã) | Explicit CONCAT, kept separate until W | Attention can include self (add self-loop) |
| Parameters (1 layer, d×d) | d² | 2d² (2d input to W due to concat) | d² + 2d (K heads of size d/K, plus attention vectors) |
| Inductive? | No (needs full adj matrix) | Yes (computes from features + local neighborhood) | Yes (attention computed from features) |
| Key strength | Simple, fast, strong baseline | Inductive, keeps self separate | Content-aware neighbor weighting |
Here are the three models side by side in PyTorch Geometric. Notice how the aggregation step is the only thing that changes.
```python
import torch
from torch_geometric.nn import GCNConv, SAGEConv, GATConv

# Toy inputs: N = 4 nodes with 64 features; edges in COO form, shape [2, E]
x = torch.randn(4, 64)
edge_index = torch.tensor([[0, 1, 1, 2, 2, 3],
                           [1, 0, 2, 1, 3, 2]])

# GCN: symmetric normalization, single W matrix
gcn_layer = GCNConv(in_channels=64, out_channels=128)
h_gcn = gcn_layer(x, edge_index)    # x: [N, 64] → [N, 128]

# GraphSAGE: mean AGG; self and neighbor summary get separate linear maps,
# which is equivalent to one W applied to the concatenation
sage_layer = SAGEConv(in_channels=64, out_channels=128)
h_sage = sage_layer(x, edge_index)  # x: [N, 64] → [N, 128]

# GAT: learned attention, K=8 heads, concatenated output
gat_layer = GATConv(in_channels=64, out_channels=16, heads=8, concat=True)
h_gat = gat_layer(x, edge_index)    # x: [N, 64] → [N, 128] (16 × 8 heads)

# All produce [N, 128]: same shape, different computation inside
```
With one GNN layer, each node "sees" its 1-hop neighborhood. With K layers, it sees its K-hop neighborhood. More layers should mean richer representations, right? In practice, deeper GNNs often perform WORSE than shallow ones — a phenomenon called over-smoothing.
Let's think about what repeated mean aggregation does. After 1 layer: each node's embedding is the average of its neighbors' features. After 2 layers: it's the average of averages. After K layers: the embedding is a weighted average over the K-hop neighborhood — and as K grows, this neighborhood covers more and more of the graph.
In an Erdős-Rényi random graph with average degree d, after $O(\log_d n)$ layers the K-hop neighborhood covers essentially the entire graph. From then on every node aggregates over nearly the same set of nodes, and each further round of averaging pulls the embeddings closer together until they are nearly identical (with GCN's symmetric normalization, they converge to the same direction, scaled by node degree). The embeddings become fully mixed. This is over-smoothing.
From a signal processing perspective, GCN applies a low-pass filter to the graph signal. Let $\tilde L = I - \tilde D^{-1/2}\tilde A\tilde D^{-1/2}$ be the normalized Laplacian of the self-loop-augmented graph; its eigenvalues λ lie in [0, 2), and small λ correspond to smooth (low-frequency) signals that vary little between neighbors. GCN multiplies the signal by $\tilde D^{-1/2}\tilde A\tilde D^{-1/2} = I - \tilde L$, which scales each frequency component by (1 - λ). After K layers the factor is $(1-\lambda)^K$: the λ = 0 component passes unchanged, while every component with λ > 0 shrinks exponentially (because |1 - λ| < 1 for λ in (0, 2)). The NODES THAT DIFFER FROM THEIR NEIGHBORHOOD are exactly the ones you want to distinguish, and GCN erases them with enough layers.
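A NumPy sketch of the effect on a hypothetical graph of two 5-node cliques joined by a single edge. It applies only the propagation step $\tilde D^{-1/2}\tilde A\tilde D^{-1/2}$ repeatedly (no weights and no nonlinearities, an intentional simplification) and tracks how far apart the two communities' average embedding directions remain.

```python
import numpy as np

def two_cliques(m):
    """Hypothetical graph: two m-node cliques joined by a single bridge edge."""
    n = 2 * m
    A = np.zeros((n, n))
    A[:m, :m] = 1.0
    A[m:, m:] = 1.0
    np.fill_diagonal(A, 0.0)
    A[m - 1, m] = A[m, m - 1] = 1.0
    return A

def gcn_propagation(A):
    """Symmetric-normalized propagation matrix D̃^-1/2 Ã D̃^-1/2 with Ã = A + I."""
    A_tilde = A + np.eye(len(A))
    deg = A_tilde.sum(axis=1)
    return A_tilde / np.sqrt(np.outer(deg, deg))

def community_gap(H, m):
    """Distance between the two communities' mean embedding directions."""
    U = H / np.linalg.norm(H, axis=1, keepdims=True)
    return np.linalg.norm(U[:m].mean(axis=0) - U[m:].mean(axis=0))

m = 5
A_hat = gcn_propagation(two_cliques(m))
X = np.vstack([np.tile([1.0, 0.0], (m, 1)),    # community A starts at [1, 0]
               np.tile([0.0, 1.0], (m, 1))])   # community B starts at [0, 1]

H, gaps = X, {}
for k in range(1, 201):
    H = A_hat @ H                               # one propagation step (no W, no ReLU)
    gaps[k] = community_gap(H, m)

for k in (1, 2, 4, 8, 16, 32, 64, 128, 200):
    print(f"{k:4d} layers: gap = {gaps[k]:.4f}")
```

The gap shrinks geometrically, at a rate set by the second-largest eigenvalue of the propagation matrix. Learnable weights and nonlinearities change the details, so treat this only as an illustration of the trend.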
In practice: 2-3 layers for most tasks. If you need a 7-hop receptive field, don't use 7 GNN layers — use the graph structure more cleverly (multi-hop connections, virtual nodes). This is different from CNNs and Transformers, where going deeper almost always helps.
You hit a wall: shallow GNNs can't see far enough, deep GNNs over-smooth. The solution comes from an unlikely source — ResNets. He et al. (2015) showed that adding a skip connection (identity shortcut) from layer l to layer l+1 prevents degradation in very deep CNNs. The same trick works for GNNs.
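The residual update adds the previous layer's embedding directly:
$$h_v^{(l)} = \mathrm{GNN\_layer}^{(l)}\big(h_v^{(l-1)}, \{h_u^{(l-1)} : u \in N(v)\}\big) + h_v^{(l-1)}$$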
This ensures that even if the GNN layer "messes up" (e.g., over-smooths), the skip connection passes through the pre-smoothed embedding unchanged. The network can learn to use the skip connection when aggregation hurts (set GNN_layer output ≈ 0) and to use the GNN output when aggregation helps.
Xu et al. (2018) proposed a more aggressive version: instead of only adding consecutive layers, collect the embeddings from ALL K layers and combine them at the end. This is called a JK-Net (Jumping Knowledge):
$$h_v^{\text{final}} = \mathrm{AGG}\big(h_v^{(1)}, h_v^{(2)}, \dots, h_v^{(K)}\big)$$
Where AGG is typically concatenation, max-pooling, or an LSTM with attention over the layer sequence. Each layer $h_v^{(l)}$ captures a different scale: layer 1 sees 1-hop, layer 2 sees 2-hop, etc. JK-Net lets each node adaptively select which scale of neighborhood information is most useful for its task.
Residual connection: Simple, adds 0 parameters. Skip from l-1 to l. Like "maintain a running sum of updates."
JK-Net: More expressive. Accesses ALL layer embeddings. Each node can use a different scale adaptively. More parameters (aggregation over K embeddings).
Dense connections (DenseNet-style): Every layer connects to all subsequent layers. Maximally expressive but O(K²) connections.
DropEdge: Randomly drop edges during training. Acts as data augmentation and slows over-smoothing, since each layer mixes information along fewer edges. Works well combined with residuals.
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv

class ResidualGNN(nn.Module):
    def __init__(self, d, K):
        super().__init__()
        self.layers = nn.ModuleList([GCNConv(d, d) for _ in range(K)])

    def forward(self, x, edge_index):
        h = x  # input must already be d-dimensional for the additive skip
        for layer in self.layers:
            h = F.relu(layer(h, edge_index)) + h  # skip connection
        return h

class JKNet(nn.Module):
    def __init__(self, d, K):
        super().__init__()
        self.layers = nn.ModuleList([GCNConv(d, d) for _ in range(K)])

    def forward(self, x, edge_index):
        h = x
        all_hs = []
        for layer in self.layers:
            h = F.relu(layer(h, edge_index))
            all_hs.append(h)
        # Element-wise max-pool across the K layer embeddings
        return torch.stack(all_hs, dim=0).max(dim=0).values
```
Theory tells you what's possible. Practice tells you what actually works. You, Ying, and Leskovec (NeurIPS 2020) ran a massive systematic study across 12 graph datasets over a design space of more than 300,000 GNN configurations. Their key finding: there's no universally best GNN design, but there are reliable rules of thumb.
Aggregation: Sum aggregation often beats mean for node classification. This surprised people — sum preserves more information (it doesn't normalize away degree). But for link prediction, mean can win.
Activation: ReLU works well in most cases. PReLU (learnable slope) sometimes helps. Avoid sigmoid in hidden layers (vanishing gradients).
Layer connectivity: Skip connections almost always help when going beyond 3 layers. Without them, performance degrades monotonically after ~3 layers.
Depth: 2-3 layers is the sweet spot for most citation network benchmarks. On molecular graphs, 4-5 layers can help (molecules are small, less over-smoothing risk). On social networks (large, dense), 2 layers.
Batch normalization: BatchNorm after each GNN layer stabilizes training, especially for deeper models (K≥4). Treat it like you would in CNN design.
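Dropout: Apply dropout to node features BEFORE aggregation, or to the final embedding. Dropout on attention coefficients (edge weights) is also an option, but it makes GAT attention weights stochastic, so turn it off when you want to inspect attention.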
Where you place dropout in a GNN matters more than in standard DNNs. Three options:
Feature dropout (best): Apply dropout to node feature vectors BEFORE aggregation. This randomly masks individual input features. Prevents the GNN from relying on any single feature dimension. The original GAT paper uses this for Cora (p=0.6).
Attention dropout: Apply dropout to attention coefficients αvu before they're used in the weighted sum. Randomly drops edges from the computation graph — similar to DropEdge. Makes the model more robust to missing connections.
Embedding dropout: Apply dropout to node embeddings after each GNN layer. Standard approach, works but slightly less effective for GNNs than feature dropout because embeddings already aggregate information from multiple sources.
What NOT to do: Don't apply dropout directly to the output logits if using cross-entropy; the softmax expects un-scaled logits. And don't apply dropout inside the message function if you're debugging: it makes attention weights stochastic and hard to interpret.
Apply BatchNorm after each GNN layer, before the activation. This normalizes node embeddings across the batch of nodes, analogous to normalizing across the batch in standard DNNs. For mini-batch training (where batch = subgraph), use the mini-batch statistics (standard PyTorch BatchNorm works). Empirically, BatchNorm matters most for deeper models (roughly K≥3 to 4 layers).
```python
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv

class PracticalGNN(nn.Module):
    def __init__(self, d_in, d_hid, d_out, K=3):
        super().__init__()
        dims = [d_in] + [d_hid] * (K - 1) + [d_out]
        self.convs = nn.ModuleList(
            [GCNConv(dims[i], dims[i + 1]) for i in range(K)]
        )
        self.bns = nn.ModuleList(
            [nn.BatchNorm1d(dims[i + 1]) for i in range(K - 1)]
        )
        self.dropout = nn.Dropout(p=0.5)

    def forward(self, x, edge_index):
        for i, conv in enumerate(self.convs):
            x = conv(x, edge_index)
            if i < len(self.bns):
                x = F.relu(self.bns[i](x))  # BN then ReLU
                x = self.dropout(x)
        return x  # final layer: no BN, no ReLU, no dropout
```
| Situation | Recommended Start | Why |
|---|---|---|
| Homophilic graph (neighbors similar to center) | 2-layer GCN | Uniform weights near-optimal when all neighbors are relevant. Simple baseline usually hard to beat. |
| Heterophilic graph (neighbors often dissimilar) | GAT with K=8 heads | Need to selectively attend to the few relevant neighbors among noisy majority. |
| New nodes at inference time (inductive) | GraphSAGE (mean AGG) | GCN requires full adj matrix. GraphSAGE only needs node features + local neighborhood. |
| Task requires counting (graph isomorphism sensitive) | GIN (sum + MLP) | Mean/max lose count information. SUM preserves it. GIN is maximally expressive within MSG+AGG. |
| Very deep model needed (K≥5) | Any + skip connections | Over-smoothing makes deep models fail. Residual or JK-Net allows depth. |
| Large graph (>10M nodes) | GraphSAGE + neighbor sampling | GCN full-batch doesn't fit in memory. GAT is $O(\lvert E\rvert)$ but still needs all edges. |
| Edge features available | MPNN or edge-conditioned GAT | Standard GAT ignores edge features. Concat them into attention scoring. |
You now understand the GNN design space: MSG + AGG, with GCN / GraphSAGE / GAT as three points in it. You understand over-smoothing and how skip connections address it. Here's where this knowledge connects to the broader landscape.
| Lecture | Key Idea | Link |
|---|---|---|
| Lec 1: Intro to Graphs | Graphs as data structures. Node/edge/graph tasks. | Lecture 1 |
| Lec 2: Node Embeddings | DeepWalk, node2vec: shallow transductive embeddings. | Lecture 2 |
| Lec 3: GNNs 1 | Computation graphs, GCN, message passing, matrix form. | Lecture 3 |
| Lec 4: GNNs 2 | Design space, GraphSAGE, GAT, over-smoothing, skip connections. | This lesson |
GAT replaces GCN's fixed, Laplacian-derived propagation weights with a learned, input-dependent filter. This makes it structurally similar to Transformers: a Transformer's self-attention layer can be read as attention-based message passing on a fully connected graph (every token attends to every other token). The GAT attention mechanism (concatenate, then LeakyReLU) is a form of "additive attention"; Transformers use "scaled dot-product attention." Both are valid choices in the same design space.
Read the original GAT paper in full detail:
Graph Attention Networks, Velickovic et al., ICLR 2018
GraphSAGE's inductive formulation was the foundation for PinSage (Ying et al., KDD 2018), Pinterest's production GNN that ran on 3 billion nodes. The key insight was that neighbor sampling + mini-batch training + inductive inference made billion-scale GNNs feasible: the full adjacency matrix never has to fit in memory.
All the GNNs we've covered have a fundamental expressiveness limit: they are at most as powerful as the 1-Weisfeiler-Lehman graph isomorphism test. This means there exist pairs of non-isomorphic graphs that all MSG + AGG GNNs map to the same embedding. GIN (Graph Isomorphism Network, Xu et al. 2019) characterizes exactly when a GNN IS as powerful as 1-WL: it uses SUM aggregation (not mean or max) and applies an MLP with sufficient capacity. Mean and max aggregation are strictly less expressive. This is coming in CS224W Lecture 5.
Before changing architecture, sweep these hyperparameters first (in order of typical impact):
| Hyperparameter | Typical Range | Impact |
|---|---|---|
| Learning rate | 1e-4 to 1e-2 (log scale) | High — wrong LR prevents convergence entirely |
| Dropout rate | 0.0, 0.3, 0.5, 0.6 | High — especially for small training sets |
| Weight decay (L2) | 0, 1e-5, 5e-4, 1e-3 | Medium — regularization; interacts with dropout |
| Hidden dimension | 64, 128, 256, 512 | Medium — more capacity, but diminishing returns |
| Number of layers | 1, 2, 3 (with skip for 3) | Medium — more layers widen the receptive field but raise over-smoothing risk |
| Attention heads K | 1, 4, 8 (GAT only) | Low — modest improvement, K=8 usually fine |
Sometimes the best way to improve GNN performance is not to change the architecture but to change the GRAPH. Graph augmentation adds or removes edges/nodes to make the structure more amenable to message passing; examples from this lecture include DropEdge, multi-hop connections, and virtual nodes.
One underappreciated benefit of the design space framework: parts of a trained GNN can sometimes transfer across graph domains. If you train GraphSAGE on a social network recommendation task and want to apply it to a protein interaction network, the architecture is the same; only the input features and output labels differ, so at minimum the input and output layers must be re-learned. The message function ($W h_u$) and the concat-then-transform aggregation are domain-agnostic operations. Fine-tuning a pretrained GNN on a new graph domain can therefore need less data than training from scratch, although how well this works depends on how related the domains are.
This is analogous to how CNNs trained on ImageNet transfer to medical imaging — the low-level pattern detectors (edges, textures) are domain-general. In GNNs, the message-passing pattern (aggregate-and-combine) is domain-general; the learned weights encode domain-specific notions of what makes a "good" neighborhood summary.
All the GNNs in this lecture (GCN, GraphSAGE, GAT) share a theoretical ceiling on expressiveness: they are at most as powerful as the 1-Weisfeiler-Lehman (1-WL) graph isomorphism test. 1-WL is a fast heuristic rather than a complete isomorphism test: it iteratively recolors each node by combining its current color with the multiset of its neighbors' colors. If two graphs end up with different color histograms, they are certainly non-isomorphic; if the histograms match, the test cannot tell the graphs apart, even when they are not isomorphic.
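A pure-Python sketch of 1-WL color refinement, run jointly on two graphs so that identical signatures receive identical colors in both. The adjacency-list format and the function names are assumptions for illustration. The classic failure case is a 6-cycle versus two disjoint triangles: every node in both graphs has degree 2, so refinement never separates them.

```python
from collections import Counter

def wl_histograms(graphs, rounds):
    """Run 1-WL color refinement jointly on several graphs (adjacency-list dicts).

    A shared palette per round guarantees that identical
    (own color, sorted neighbor colors) signatures get identical new colors in every graph.
    """
    colorings = [{v: 0 for v in g} for g in graphs]        # uniform initial color
    for _ in range(rounds):
        palette = {}
        new_colorings = []
        for g, colors in zip(graphs, colorings):
            new = {}
            for v in g:
                signature = (colors[v], tuple(sorted(colors[u] for u in g[v])))
                new[v] = palette.setdefault(signature, len(palette))
            new_colorings.append(new)
        colorings = new_colorings
    return [Counter(c.values()) for c in colorings]

def wl_distinguishes(g1, g2, rounds=None):
    if rounds is None:
        rounds = max(len(g1), len(g2))     # refinement stabilizes within n rounds
    h1, h2 = wl_histograms([g1, g2], rounds)
    return h1 != h2

cycle6 = {i: [(i - 1) % 6, (i + 1) % 6] for i in range(6)}
two_triangles = {0: [1, 2], 1: [0, 2], 2: [0, 1], 3: [4, 5], 4: [3, 5], 5: [3, 4]}
path4 = {0: [1], 1: [0, 2], 2: [1, 3], 3: [2]}
star4 = {0: [1, 2, 3], 1: [0], 2: [0], 3: [0]}

print(wl_distinguishes(cycle6, two_triangles))   # False: 1-WL cannot separate them
print(wl_distinguishes(path4, star4))            # True
```

Any MSG + AGG GNN inherits the first failure: on cycle6 and two_triangles, every node sees the same neighborhood structure at every depth.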
Xu et al. (2019) proved that any MSG + AGG GNN is at most as expressive as 1-WL. Moreover, if AGG is SUM and MSG is chosen so that summing messages is injective on multisets of neighbor features (Xu et al. show such a function exists for countable feature spaces and can be approximated by an MLP), the GNN is exactly as expressive as 1-WL. This gives us a concrete design target: use SUM aggregation for maximum expressiveness. Mean and Max are strictly less expressive because they lose count information: mean cannot distinguish "one neighbor with value 1" from "two neighbors with value 1," since both average to 1.
| AGG Function | Information Lost | 1-WL equivalent? |
|---|---|---|
| Sum | Nothing (injective with a suitable MLP) | Yes (if MSG makes the sum injective) |
| Mean | Multiset size (count of neighbors) | No — can't count neighbors |
| Max | All but the maximum element | No — ignores all but strongest signal |
| LSTM | Permutation invariance | Unclear (permutation-sensitive) |
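A pure-Python illustration of the table, using scalar neighbor values as messages. It shows which aggregators can tell two neighbor multisets apart, and that SUM is only injective with a suitable MSG: with the identity message, {1, 1} and {2} collide, while a one-hot message (represented here by `Counter`, which is the sum of one-hot vectors) keeps them apart. The helper names are hypothetical.

```python
from collections import Counter

AGGS = {
    "sum": sum,
    "mean": lambda xs: sum(xs) / len(xs),
    "max": max,
}

def distinguishes(agg, m1, m2):
    """True if the aggregator maps the two neighbor multisets to different values."""
    return agg(m1) != agg(m2)

pairs = [
    ([1.0], [1.0, 1.0]),        # same values, different counts
    ([1.0, 2.0], [2.0, 2.0]),   # same max, different mix
]
for m1, m2 in pairs:
    print(m1, m2, {name: distinguishes(f, m1, m2) for name, f in AGGS.items()})
# [1.0] [1.0, 1.0] {'sum': True, 'mean': False, 'max': False}
# [1.0, 2.0] [2.0, 2.0] {'sum': True, 'mean': True, 'max': False}

# SUM needs a suitable MSG: with the identity message, {1, 1} and {2} collide.
print(sum([1.0, 1.0]) == sum([2.0]))                   # True
# Summing one-hot messages yields a count per value, which is injective on multisets.
one_hot_sum = lambda xs: Counter(xs)
print(one_hot_sum([1.0, 1.0]) != one_hot_sum([2.0]))   # True
```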
GIN (Graph Isomorphism Network, Xu et al. 2019) instantiates the maximally expressive GNN: it uses SUM aggregation with an MLP and includes the center node's own embedding:
$$h_v^{(k)} = \mathrm{MLP}^{(k)}\Big((1+\epsilon^{(k)})\, h_v^{(k-1)} + \sum_{u \in N(v)} h_u^{(k-1)}\Big)$$
The (1+ε) weight distinguishes the center from its neighbors. This reaches the theoretical maximum expressiveness (equal to 1-WL) within the MSG + AGG framework.
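A minimal NumPy sketch of one GIN update on a 3-node path where every node starts with the same feature. It assumes a dense adjacency matrix and a hypothetical two-layer MLP with fixed weights (a trained GIN would learn them). It shows GIN separating the middle node (two neighbors) from the end nodes (one neighbor), which mean aggregation cannot do here.

```python
import numpy as np

def gin_layer(H, A, mlp, eps=0.0):
    """h_v' = MLP((1 + eps) * h_v + sum_{u in N(v)} h_u), with A a dense 0/1 adjacency."""
    return mlp((1.0 + eps) * H + A @ H)

W1 = np.array([[1.0, -1.0]])           # hypothetical MLP weights, 1 -> 2
W2 = np.array([[1.0], [0.5]])          # 2 -> 1
mlp = lambda Z: np.maximum(Z @ W1, 0.0) @ W2

A = np.array([[0, 1, 0],
              [1, 0, 1],
              [0, 1, 0]], dtype=float)  # path 0 - 1 - 2
H = np.ones((3, 1))                     # every node starts with the same feature

gin_out = gin_layer(H, A, mlp, eps=0.1)
mean_neigh = (A @ H) / A.sum(axis=1, keepdims=True)   # mean aggregation, for comparison

print(gin_out.ravel())      # end nodes differ from the middle node
print(mean_neigh.ravel())   # [1. 1. 1.]: mean cannot tell the nodes apart
```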
CS224W Lecture 5: GNN Augmentation — how to make GNNs more expressive by augmenting the graph (virtual nodes, position encodings, subgraph features). The lesson also covers GNN expressiveness theory (1-WL test, GIN). Stay tuned.
All three architectures (GCN, GraphSAGE, GAT) are available as drop-in layers in modern graph ML libraries:
PyTorch Geometric: GCNConv, SAGEConv, GATConv, GATv2Conv. Each follows its paper closely, though defaults such as automatically added self-loops are worth checking in the docs. The MessagePassing base class lets you implement your own MSG+AGG GNN in ~20 lines.
DGL: built-in message and reduce functions (fn.copy_u, fn.mean, etc.) map directly to the MSG+AGG framework; reading DGL code IS reading MSG+AGG code.

"The most important insight in GNN design is that there is no universally optimal architecture. The choice of MSG and AGG functions should be guided by the structure of your problem — specifically, by whether neighbor importance is uniform, structure-dependent, or content-dependent."
— You, Ying, Leskovec, NeurIPS 2020