CS224W Lecture 4

A General Perspective on GNNs

Every GNN ever invented is just a choice of message function and aggregation function. Once you see this, GCN, GraphSAGE, and GAT become three points in the same design space.

Prerequisites: GNN basics (Lec 3) + Attention intuition (optional but helpful). That's it.
10 chapters · 4+ simulations · 0 assumed knowledge

Chapter 0: The Design Space

In Lecture 3 you built the GCN — one specific choice for how to aggregate neighbor information. But why that choice? What if neighbors aren't equally important? What if you want to keep your own features separate from your neighbors' features? The answer is that GCN is just one point in a vast GNN design space.

You, Ying, and Leskovec (NeurIPS 2020) formalized this insight: every GNN layer can be decomposed into exactly two steps, and different choices for these steps give you different architectures. Understanding this decomposition means you can reason about ALL GNNs, not just memorize three names.

The universal GNN layer: Every GNN layer = MESSAGE computation + AGGREGATION. Pick your MSG function. Pick your AGG function. That's your architecture. GCN, GraphSAGE, and GAT are all choices of these two functions.

Step 1: Message Computation

Each neighbor u computes a message to send to node v. The message is a function of u's current embedding. At layer l:

$$m_u^{(l)} = \mathrm{MSG}^{(l)}\left(h_u^{(l-1)}\right)$$

The simplest MSG is a linear transformation: $\mathrm{MSG}(h_u) = W^{(l)} h_u$. But you could use a more complex neural network, scale the message by some factor, or weight it by an attention score. All are valid choices.

Step 2: Aggregation

Node v receives messages from all its neighbors, and must collapse them into a single vector. Then it updates its own embedding:

$$h_v^{(l)} = \mathrm{AGG}^{(l)}\left(\{m_u^{(l)} : u \in N(v)\},\; h_v^{(l-1)}\right)$$

AGG must be a function that (1) takes a set of messages (order shouldn't matter, since graphs have no canonical ordering) and (2) incorporates the node's own previous embedding $h_v^{(l-1)}$. Common choices: mean, sum, max, or a neural network applied after pooling.

Why sets? Graphs have no canonical ordering of neighbors. If you process neighbors left-to-right, you've implicitly assumed an ordering that doesn't exist. AGG must be a permutation-invariant function — mean, sum, and max are all permutation-invariant. MLPs and LSTMs over raw neighbor lists are not (unless you sort them, which introduces its own assumptions).
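The two steps and the permutation-invariance requirement can be sketched in a few lines of plain Python. This is a minimal illustration, not a reference implementation: the helper names (`msg_identity`, `agg_first`, `gnn_layer`) are made up for this sketch, MSG is the identity (W = I), and the node's own previous embedding is left out of AGG to keep the example short.

```python
from typing import Callable, Dict, List, Sequence

Vector = List[float]

def msg_identity(h_u: Vector) -> Vector:
    """MSG: send the neighbor embedding unchanged (stands in for W h_u with W = I)."""
    return list(h_u)

def agg_sum(messages: Sequence[Vector]) -> Vector:
    return [sum(col) for col in zip(*messages)]

def agg_mean(messages: Sequence[Vector]) -> Vector:
    return [sum(col) / len(messages) for col in zip(*messages)]

def agg_max(messages: Sequence[Vector]) -> Vector:
    return [max(col) for col in zip(*messages)]

def agg_first(messages: Sequence[Vector]) -> Vector:
    """Order-dependent on purpose: NOT a valid AGG for graphs."""
    return list(messages[0])

def gnn_layer(
    h: Dict[str, Vector],
    neighbors: Dict[str, List[str]],
    msg: Callable[[Vector], Vector],
    agg: Callable[[Sequence[Vector]], Vector],
) -> Dict[str, Vector]:
    """One layer: every node aggregates the messages sent by its neighbors."""
    return {v: agg([msg(h[u]) for u in neighbors[v]]) for v in h}

# Toy triangle graph with 2-dimensional embeddings
h = {"a": [1.0, 0.0], "b": [0.0, 2.0], "c": [3.0, 1.0]}
neighbors = {"a": ["b", "c"], "b": ["a", "c"], "c": ["a", "b"]}
# Same graph, neighbor lists stored in the opposite order
reordered = {v: ns[::-1] for v, ns in neighbors.items()}

for agg in (agg_sum, agg_mean, agg_max, agg_first):
    same = (gnn_layer(h, neighbors, msg_identity, agg)
            == gnn_layer(h, reordered, msg_identity, agg))
    print(f"{agg.__name__}: order-independent = {same}")
```

Sum, mean, and max give the same output for both neighbor orderings; the order-dependent `agg_first` does not, which is exactly the failure mode the set formulation rules out.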

The Design Space Is Huge

Concretely, for MSG you can vary: the weight matrix (shared vs per-layer), whether to normalize by degree, whether to include edge features. For AGG you can vary: mean vs sum vs max vs attention-weighted, whether to concatenate self-embedding vs add it, whether to apply a nonlinearity before or after aggregation. That's a huge combinatorial space — and You et al. found that simple designs often outperform complex ones on many benchmarks.

The Full Framework (You et al., 2020)

The design space paper identifies four dimensions of GNN design beyond just MSG and AGG:

1. Intra-layer design: MSG function, AGG function, activation, normalization (none, batch norm, layer norm). These are the choices we discuss in chapters 1-3.

2. Inter-layer design: How layers connect to each other. Options: stack (plain), residual (add skip), dense (all-to-all connections), jumping knowledge (concatenate all layers).

3. Layer connectivity: How many GNN layers vs how many MLP layers before/after. Pre-processing MLPs on raw features often help dramatically. Post-processing MLPs for the final prediction.

4. Learning objective: Supervised (cross-entropy, MSE), self-supervised (contrastive), semi-supervised. The objective drives what the embeddings learn to encode.

You et al. define a design space of 315,000 possible GNN configurations and evaluate samples from it across 32 tasks. Their top finding: no single design choice dominates across all tasks. The winning configuration on molecular property prediction looks completely different from the winning configuration on social network link prediction. This means the "design space" framing is not just academic: it is practical guidance to explore your specific problem's sweet spot rather than defaulting to GCN.

Task Type Shapes MSG and AGG Choices

The right MSG+AGG choice often follows from the task structure:

The MSG + AGG Framework

[Interactive simulation: a small graph with 5 nodes. Selecting a node shows its message computation and aggregation step; the AGG function can be toggled (default: mean) to see how the output changes.]

In the MSG + AGG framework, why must the AGG function be permutation-invariant?

Chapter 1: GCN as MSG + AGG

Now that we have the unified framework, let's re-read GCN through this lens. You already know GCN from Lecture 3 — but seeing it as MSG + AGG reveals the specific choices Kipf & Welling made, and makes it easy to compare with GraphSAGE and GAT.

GCN's Message Function

In GCN, the message from neighbor u to node v is a linear transformation of u's embedding, normalized by both their degrees:

$$m_{u \to v}^{(l)} = \frac{1}{\sqrt{|N(u)|\,|N(v)|}}\; W^{(l)} h_u^{(l-1)}$$

This is symmetric normalization — normalize by the geometric mean of both degrees. The motivation: if u has 100 neighbors, it sends weaker messages (its influence is "diluted" across 100 targets). If v has 100 neighbors, each incoming message is weaker (v averages 100 signals). This prevents high-degree nodes from dominating.

Simpler version you'll see most often: In practice, GCN is often written as just $\mathrm{MSG} = W^{(l)} h_u^{(l-1)} / |N(v)|$, normalizing only by v's degree. The symmetric version is theoretically cleaner and is what you get from the graph Laplacian derivation, but both work in practice.

GCN's Aggregation Function

Aggregation is a sum (which, because the messages are already normalized, behaves like a weighted mean). A self-loop is included via the augmented graph $\tilde{A} = A + I$, so node v is in its own neighborhood:

$$h_v^{(l)} = \sigma\left(\sum_{u \in N(v) \cup \{v\}} m_{u \to v}^{(l)}\right)$$

Because the messages are normalized, this sum acts as a weighted mean in which higher-degree neighbors contribute less. With the self-loop, degrees are counted in $\tilde{A}$, so v's own previous embedding participates in the sum with weight $1/|\tilde{N}(v)|$, where $\tilde{N}(v)$ includes v itself.
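To make the normalization concrete, here is a small plain-Python sketch that computes the symmetric GCN weights on a star graph. The function name `gcn_weights` and the toy graph are assumptions made for illustration; degrees are counted with the self-loop included.

```python
import math

def gcn_weights(edges, num_nodes):
    """Return {v: {u: weight}} with weight = 1 / sqrt(deg(u) * deg(v)).

    Degrees are counted in the self-loop-augmented graph (A + I).
    """
    nbrs = {v: {v} for v in range(num_nodes)}  # every node starts with a self-loop
    for u, v in edges:
        nbrs[u].add(v)
        nbrs[v].add(u)
    deg = {v: len(ns) for v, ns in nbrs.items()}
    return {
        v: {u: 1.0 / math.sqrt(deg[u] * deg[v]) for u in sorted(ns)}
        for v, ns in nbrs.items()
    }

# Star graph: node 0 is the hub, nodes 1-4 are leaves
edges = [(0, 1), (0, 2), (0, 3), (0, 4)]
w = gcn_weights(edges, num_nodes=5)

print("hub  <- self:", round(w[0][0], 4))   # 1 / deg(hub)
print("hub  <- leaf:", round(w[0][1], 4))
print("leaf <- self:", round(w[1][1], 4))   # 1 / deg(leaf)
print("leaf <- hub: ", round(w[1][0], 4))
print("sum of hub's incoming weights:", round(sum(w[0].values()), 4))
```

Note that with symmetric normalization the incoming weights of a node do not generally sum to exactly 1, so "weighted mean" is a loose description; dividing by $|N(v)|$ alone would give a true mean.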

Data Flow in GCN

Let's be precise about tensor shapes. Graph: n = 1000 nodes. Input features: din = 64. Two GCN layers with dhidden = 128, dout = 64.

Input X
Shape [1000, 64]: raw features. Also needed: the adjacency with self-loops $\tilde{A}$, normalized as $\tilde{D}^{-1/2}\tilde{A}\tilde{D}^{-1/2}$, shape [1000, 1000] (sparse).
↓ Layer 1: $W^{(1)} \in \mathbb{R}^{64 \times 128}$
$H^{(1)}$
Shape [1000, 128]. Each row encodes a 1-hop neighborhood. Parameters: 64×128 = 8192, shared by all 1000 nodes.
↓ Layer 2: $W^{(2)} \in \mathbb{R}^{128 \times 64}$
$H^{(2)} = Z$
Shape [1000, 64]. Each row is the final node embedding. Parameters: 128×64 = 8192. Total: 16384 parameters (ignoring biases), regardless of graph size.

GCN's strength: $O(|E| \cdot d)$ compute per layer for the aggregation (proportional to edges, not nodes squared). Works on graphs with millions of nodes as long as the graph is sparse. Parameters are $O(d^2)$, completely independent of graph size.

What GCN Gets Wrong

GCN uses uniform weights — every neighbor contributes equally (after degree normalization). But in a citation network, a paper cited by Nature and a paper cited by an obscure blog are very different. Should they carry the same weight in a neighbor's embedding? GCN says yes. GAT will say no.

Also, GCN's normalization by degree is a hand-crafted heuristic. Why geometric mean? Why not just v's degree? Why not something learned? GraphSAGE and GAT replace this heuristic with learned components.

GCN vs GAT Aggregation Weight Comparison

[Interactive simulation: a star graph with central node v connected to 5 neighbors of varying degree. GCN assigns weights purely from degree (a fixed heuristic), GCN weight = 1/√(deg_u × deg_v), so higher-degree neighbors contribute less; GAT weights would instead depend on feature similarity. A slider changes the central node's degree and the GCN weights recalculate.]

In GCN, what is the purpose of normalizing messages by node degrees (dividing by √(|N(u)| · |N(v)|))?

Chapter 2: GraphSAGE

Hamilton, Ying, and Leskovec introduced GraphSAGE (Graph SAmple and aggreGatE) at NeurIPS 2017. The motivation was inductive learning: train on one graph (e.g., a portion of Reddit), run inference on new, unseen graphs or new nodes. GCN was transductive — it needed to see the full graph adjacency matrix during training. GraphSAGE fixes this with a key architectural difference: it keeps the node's own embedding separate from its neighborhood summary, and concatenates them.

The Two-Stage Update

GraphSAGE replaces the single pooled sum of GCN with an explicit two-stage operation:

Stage 1 — Aggregate from neighbors:

$$h_{N(v)}^{(l)} = \mathrm{AGG}\left(\{h_u^{(l-1)} : u \in N(v)\}\right)$$

Stage 2 — Concatenate self with neighborhood summary, then transform:

$$h_v^{(l)} = \sigma\left(W^{(l)} \cdot \left[h_v^{(l-1)} \,\|\, h_{N(v)}^{(l)}\right]\right)$$

The symbol ∥ means concatenation. If $h_v^{(l-1)}$ is d-dimensional and $h_{N(v)}^{(l)}$ is d-dimensional, then the concatenated vector is 2d-dimensional. $W^{(l)}$ maps from 2d to d.

The CONCAT is the key difference. GCN sums everything together: self and neighbors are mixed in a single operation. GraphSAGE keeps them separate until after aggregation, then concatenates. This lets the network "know" which part of the vector came from itself vs from the neighborhood, and weight them differently via $W^{(l)}$. Empirically, this often improves performance over GCN.
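One way to see why the concatenation gives self and neighborhood separate weights: a single matrix applied to the concatenated vector is the same as two matrices, one per half, whose outputs are added. The sketch below checks this numerically in plain Python; the matrix values and the names `W_self` and `W_neigh` are made up for illustration.

```python
def matvec(W, x):
    """Multiply matrix W (list of rows) by vector x."""
    return [sum(w_ij * x_j for w_ij, x_j in zip(row, x)) for row in W]

h_self = [1.0, 2.0]    # the node's own embedding (d = 2)
h_neigh = [0.5, -1.0]  # its aggregated neighborhood summary (d = 2)

# W maps the 2d = 4 concatenated inputs to d = 2 outputs
W = [[1.0, 0.0, 2.0, 0.0],
     [0.0, 1.0, 0.0, -1.0]]

# Option 1: concatenate, then apply W
out_concat = matvec(W, h_self + h_neigh)

# Option 2: split W into its self half and neighbor half, apply separately, add
W_self = [row[:2] for row in W]
W_neigh = [row[2:] for row in W]
out_split = [a + b for a, b in zip(matvec(W_self, h_self), matvec(W_neigh, h_neigh))]

print(out_concat, out_split)
```

Both options produce the same vector, so the concat form is equivalent to learning one transform for the node itself and a different one for its neighborhood.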

Three AGG Options

GraphSAGE proposes three aggregation functions for Stage 1:

| AGG name | Formula | Properties |
| --- | --- | --- |
| Mean | $h_{N(v)} = \frac{1}{\lvert N(v) \rvert} \sum_{u \in N(v)} h_u$ | Fastest. No extra parameters. Loses count information (can't distinguish "2 neighbors" from "5 neighbors" if they average to the same thing). |
| Pool | $h_{N(v)} = \max(\{\mathrm{ReLU}(W_{\text{pool}} h_u + b) : u \in N(v)\})$ | Element-wise max after a per-neighbor MLP. Extra parameters $W_{\text{pool}}$. Can pick out the "most extreme" feature across neighbors. |
| LSTM | $h_{N(v)} = \mathrm{LSTM}(\{h_u : u \in N(v)\})$ | Most expressive but NOT permutation-invariant. Requires random shuffling of neighbors during training to prevent ordering bias. Slower. |

In practice, mean-aggregation and pool-aggregation perform similarly. LSTM is often not worth the extra complexity and broken permutation invariance. The real win of GraphSAGE over GCN comes from the concat architecture, not from fancy AGG functions.
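The "loses count information" point is easy to verify by hand. In this small plain-Python sketch (toy values chosen for illustration), a node with two identical neighbors and a node with five identical neighbors are indistinguishable under mean and max, while sum keeps them apart.

```python
def agg_mean(messages):
    return [sum(col) / len(messages) for col in zip(*messages)]

def agg_max(messages):
    return [max(col) for col in zip(*messages)]

def agg_sum(messages):
    return [sum(col) for col in zip(*messages)]

# Two neighborhoods whose members all carry the same embedding
two_neighbors = [[1.0, 0.5]] * 2
five_neighbors = [[1.0, 0.5]] * 5

for agg in (agg_mean, agg_max, agg_sum):
    a, b = agg(two_neighbors), agg(five_neighbors)
    print(f"{agg.__name__}: {a} vs {b} -> distinguishable = {a != b}")
```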

Inductive Learning via Sampling

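GraphSAGE also introduced neighbor sampling: instead of aggregating over ALL neighbors (which can be 10,000 for hub nodes in a social graph), sample a fixed-size subset S (typically 25 for layer 1, 10 for layer 2). This makes mini-batch training feasible on massive graphs. With uniform sampling and mean aggregation, the sampled mean is an unbiased estimate of the full neighborhood mean; after nonlinear layers, the resulting gradient is a stochastic approximation of the full gradient rather than an exactly unbiased one.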
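A minimal sketch of uniform neighbor sampling, in plain Python. The helper `sample_neighbors`, the toy scalar feature, and the fixed seed are assumptions for illustration; real implementations sample per mini-batch and per layer.

```python
import random

def sample_neighbors(neighbors, k, rng):
    """Uniformly sample k neighbors without replacement; keep all if there are at most k."""
    if len(neighbors) <= k:
        return list(neighbors)
    return rng.sample(neighbors, k)

def mean(xs):
    return sum(xs) / len(xs)

rng = random.Random(0)  # fixed seed so the run is repeatable

# A hub node with 10,000 neighbors, each carrying one toy scalar feature
hub_neighbors = list(range(10_000))
feature = {u: float(u % 7) for u in hub_neighbors}

full_mean = mean([feature[u] for u in hub_neighbors])

# Each estimate touches only 25 neighbors instead of 10,000
estimates = [
    mean([feature[u] for u in sample_neighbors(hub_neighbors, 25, rng)])
    for _ in range(2000)
]
avg_estimate = mean(estimates)

print("full neighborhood mean:   ", round(full_mean, 3))
print("one sampled estimate:     ", round(estimates[0], 3))
print("average of 2000 estimates:", round(avg_estimate, 3))
```

A single sampled mean is noisy, but averaged over many draws it lands close to the full neighborhood mean, which is what makes a fixed budget of 25 neighbors workable.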

```python
import torch
import torch.nn as nn

class GraphSAGELayer(nn.Module):
    def __init__(self, in_dim, out_dim):
        super().__init__()
        # W maps from concatenated [self || neighbors] to output
        self.W = nn.Linear(in_dim * 2, out_dim)
        self.activation = nn.ReLU()

    def forward(self, h, adj):
        # h: [N, in_dim] node features
        # adj: [N, N] float 0/1 adjacency (unnormalized, no self-loops);
        #      the division by degree below does the normalization

        # Stage 1: aggregate neighbors (mean)
        deg = adj.sum(dim=1, keepdim=True).clamp(min=1)  # [N, 1]
        h_neigh = (adj @ h) / deg  # [N, in_dim], mean of neighbors

        # Stage 2: concatenate self + neighbor, then transform
        h_cat = torch.cat([h, h_neigh], dim=-1)  # [N, 2*in_dim]
        h_out = self.activation(self.W(h_cat))    # [N, out_dim]

        # Optional: L2 normalize (common in inductive settings)
        h_out = h_out / h_out.norm(dim=-1, keepdim=True).clamp(min=1e-6)
        return h_out
```

PinSage: GraphSAGE at Billion Scale

GraphSAGE's neighbor-sampling design made it the foundation for PinSage, Pinterest's industrial GNN for image recommendation (Ying et al., KDD 2018). PinSage operates on a graph with 3 billion nodes (pins and boards) and 18 billion edges; the full graph would require 100TB of memory just to store. Key adaptations:

Localized convolution: Only sample a fixed-size neighborhood (K=500 for layer 1, K=250 for layer 2). Training mini-batches never touch the full graph.

Importance-based sampling: Instead of uniform random sampling, weight neighbors by their random-walk visit probability to the central node — more important neighbors sampled more often.

Curriculum training: Train with easy negatives first (random negatives from across the corpus), then hard negatives (items that are almost but not quite relevant). Prevents collapse to trivial solutions.

MapReduce inference: At test time, compute all embeddings in a single pass using MapReduce — no repeated neighborhood lookups. This enables real-time serving at Pinterest scale.

Why PinSage uses GraphSAGE not GCN: GCN requires the full normalized adjacency matrix $\tilde{D}^{-1/2}\tilde{A}\tilde{D}^{-1/2}$. On a 3-billion-node graph this is impossible to materialize. GraphSAGE's local inductive computation (just look at your local neighbors) is the key property that makes billion-scale GNNs feasible.

What is the key architectural difference between GraphSAGE and GCN?

Chapter 3: Graph Attention Networks (Showcase)

GCN and GraphSAGE both treat neighbor importance as a fixed, hand-crafted function of degree. A neighbor with degree 1 contributes more than a neighbor with degree 100, simply because of the normalization. But surely the content of a neighbor matters more than its structural position? That's the question GAT (Velickovic et al., ICLR 2018) answers.

GAT learns attention weights αvu that tell node v how much to attend to each neighbor u. These weights depend on the features of both v and u — two nodes with very different features might warrant a small weight, while two similar nodes warrant a large weight. Crucially, these weights are learned, not hand-designed.

Step 1: Compute Raw Attention Coefficients

First, transform both node embeddings with the same weight matrix W, then score their compatibility with a small neural network parametrized by vector a:

$$e_{vu} = \mathrm{LeakyReLU}\left(a^{\top} \left[W h_v \,\|\, W h_u\right]\right)$$

Breaking this down: $W h_v$ and $W h_u$ project both embeddings into the same space. Concatenating them gives a vector of length 2d'. The learnable vector a (length 2d') scores this concatenation with a single dot product. LeakyReLU keeps a small nonzero slope for negative inputs, so negative scores stay distinguishable (plain ReLU would map every negative score to 0). The result $e_{vu}$ is a scalar: how "compatible" v and u are.

Step 2: Normalize with Softmax

Raw scores are incomparable across different nodes (different scales). Normalize them into a proper probability distribution over v's neighborhood:

$$\alpha_{vu} = \mathrm{softmax}_u(e_{vu}) = \frac{\exp(e_{vu})}{\sum_{k \in N(v)} \exp(e_{vk})}$$

Now $\sum_{u \in N(v)} \alpha_{vu} = 1$, and all weights are positive. Node v's attention over its neighbors is a proper distribution: it sums to 1, and we can interpret each $\alpha_{vu}$ as "what fraction of attention v gives to u."

Step 3: Weighted Aggregation

Use the attention weights to compute a weighted sum of neighbors' transformed embeddings:

$$h_v^{(l)} = \sigma\left(\sum_{u \in N(v)} \alpha_{vu}\, W^{(l)} h_u^{(l-1)}\right)$$

Compare to GCN: replace the uniform weight $1/|N(v)|$ with the learned weight $\alpha_{vu}$. Same structure, but now the network decides which neighbors matter most, for each node and each layer, based on features.

What GAT learns to attend to: In citation networks, papers tend to attend more to papers with similar topics. In social networks, users attend more to friends with similar interests. The attention weights are a soft, differentiable way for the network to select relevant neighbors — without any manual feature engineering.

A Concrete Worked Example

Let's trace through one step of GAT for a small graph. Say node v has features hv = [1.0, 0.5] and has neighbors A with hA = [0.9, 0.6] and B with hB = [0.1, 0.8].

Setup: W = I (identity, so Wh = h) and a = [1, -1, 0.5, -0.5]. The first two entries of a multiply the receiving node's features $W h_v$; the last two multiply the neighbor's features $W h_u$.

Score for v→A: $[W h_v \,\|\, W h_A]$ = [1.0, 0.5, 0.9, 0.6].
$a^{\top}$ · this = 1(1.0) + (-1)(0.5) + 0.5(0.9) + (-0.5)(0.6) = 1.0 - 0.5 + 0.45 - 0.30 = 0.65.
LeakyReLU(0.65) = 0.65.

Score for v→B: $[W h_v \,\|\, W h_B]$ = [1.0, 0.5, 0.1, 0.8].
$a^{\top}$ · this = 1.0 - 0.5 + 0.5(0.1) + (-0.5)(0.8) = 0.5 + 0.05 - 0.40 = 0.15.
LeakyReLU(0.15) = 0.15.

Softmax: exp(0.65) = 1.916, exp(0.15) = 1.162. Sum = 3.078. $\alpha_{vA}$ = 1.916/3.078 = 0.622, $\alpha_{vB}$ = 1.162/3.078 = 0.378.

Output: $h_v^{\text{new}}$ = σ(0.622 · [0.9, 0.6] + 0.378 · [0.1, 0.8]) = σ([0.598, 0.676]) ≈ [0.645, 0.663] (with sigmoid).

Notice: node A (more similar to v) gets 62.2% of the attention, node B (less similar) gets 37.8%. In this example W and a were set by hand for illustration. In a trained GAT they are learned by gradient descent on the task loss, so the preference for more compatible neighbors comes from training, not from a manually designed similarity function.
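The arithmetic of the worked example can be replayed in plain Python. This sketch follows the three GAT steps (score, softmax, weighted sum) with the same hand-set W = I and a; the LeakyReLU slope of 0.2 is an assumed value and does not matter here because both scores are positive.

```python
import math

def leaky_relu(x, slope=0.2):
    return x if x > 0 else slope * x

def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))

h_v = [1.0, 0.5]
neighbors = {"A": [0.9, 0.6], "B": [0.1, 0.8]}
a = [1.0, -1.0, 0.5, -0.5]  # W = I, so W h = h

# Step 1: raw scores e_vu = LeakyReLU(a^T [h_v || h_u])
scores = {
    name: leaky_relu(sum(a_i * x_i for a_i, x_i in zip(a, h_v + h_u)))
    for name, h_u in neighbors.items()
}

# Step 2: softmax over v's neighborhood
z = sum(math.exp(e) for e in scores.values())
alpha = {name: math.exp(e) / z for name, e in scores.items()}

# Step 3: attention-weighted sum of neighbor embeddings, then sigmoid
combined = [sum(alpha[n] * neighbors[n][i] for n in neighbors) for i in range(2)]
h_new = [sigmoid(x) for x in combined]

print("scores:", {k: round(v, 3) for k, v in scores.items()})
print("alpha: ", {k: round(v, 3) for k, v in alpha.items()})
print("h_new: ", [round(x, 3) for x in h_new])
```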

GAT Attention Visualizer (Showcase)

[Interactive simulation: a 6-node graph. Selecting a node shows its attention distribution over neighbors as edge thickness (edge thickness = attention weight $\alpha_{vu}$). Feature sliders for individual nodes change the node features, and the attention weights update in real time.]

In GAT, what does the softmax normalization of attention coefficients guarantee?

Chapter 4: Multi-Head Attention

A single set of attention weights asks ONE question about each neighbor: "how important is u to v?" But importance can have multiple dimensions. In a citation network, paper v might attend to paper u because they share a topic — and simultaneously attend to paper w because they use similar methods. A single attention head can only capture one notion of importance at a time.

Multi-head attention runs K independent attention mechanisms in parallel, each with its own parameters Wk and ak. Each head learns to attend to different aspects of the neighborhood.

The Multi-Head Formula

Run K heads, get K output embeddings per node, then concatenate:

$$h_v^{(l)} = \Big\Vert_{k=1}^{K}\; \sigma\left(\sum_{u \in N(v)} \alpha_{vu}^{k}\, W^{k} h_u^{(l-1)}\right)$$

The notation ∥ means concatenation. If each head outputs a d/K-dimensional vector, the concatenated output is d-dimensional, the same total size as a single-head model. Each head k has its own attention weights $\alpha_{vu}^{k}$ computed with its own parameters ($W^{k}$, $a^{k}$).

For the final layer, concatenation can produce very high-dimensional outputs. A common alternative is to average across heads instead: $h_v = \sigma\left(\frac{1}{K} \sum_{k} \sum_{u \in N(v)} \alpha_{vu}^{k} W^{k} h_u\right)$. Veličković et al. use concatenation for intermediate layers and averaging for the final layer.

Why Multiple Heads Stabilize Training

Think of K heads as a K-model ensemble. If one head learns a bad set of attention weights early in training, the other K-1 heads can compensate. The gradient signal averaged across heads is lower-variance than a single head — similar to why mini-batch SGD is more stable than single-sample gradient estimates.

In practice, multi-head GAT (for example K=8) usually outperforms a single head, and adding many more heads tends to add parameters without much accuracy. In the original paper's Cora model: K=8 heads in the hidden layer and a single head in the output layer.

Parameter Count

One GAT layer with din input dimensions, dout output dimensions, K heads:

| Component | Shape | Count (d_in=64, d_out=64, K=8) |
| --- | --- | --- |
| $W^{k}$ per head (K matrices) | [d_in, d_out/K] | 8 × (64 × 8) = 4096 |
| $a^{k}$ per head (K vectors) | [2 × d_out/K] | 8 × 16 = 128 |
| Total | | 4224 parameters |

Vs GCN one layer: din × dout = 64 × 64 = 4096 parameters. GAT adds only 128 extra parameters (the attention vectors ak) over GCN — a tiny cost for learned attention weights.
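The same count as a small calculator, in case you want to try other sizes. This is a sketch with made-up function names; like the table, it ignores bias terms and assumes each head outputs d_out / K dimensions.

```python
def gat_param_count(d_in, d_out, heads):
    """Weights plus attention vectors for one GAT layer (biases ignored)."""
    head_dim = d_out // heads
    weight_params = heads * d_in * head_dim   # K matrices of shape [d_in, d_out/K]
    attention_params = heads * 2 * head_dim   # K vectors of length 2 * d_out/K
    return weight_params + attention_params

def gcn_param_count(d_in, d_out):
    """Single weight matrix for one GCN layer (bias ignored)."""
    return d_in * d_out

gat = gat_param_count(64, 64, heads=8)
gcn = gcn_param_count(64, 64)
print("GAT:", gat, "GCN:", gcn, "extra for attention:", gat - gcn)
```

Because the per-head width shrinks as K grows, the attention overhead stays at 2 × d_out regardless of the number of heads.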

Implementing Multi-Head GAT in PyTorch Geometric

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GATConv

class TwoLayerGAT(nn.Module):
    def __init__(self, d_in, d_hid, d_out, K=8):
        super().__init__()
        # Layer 1: K heads, each head outputs d_hid // K, so concat = d_hid
        # (the dropout argument of GATConv is applied to the attention coefficients)
        self.gat1 = GATConv(d_in, d_hid // K, heads=K, concat=True,
                            dropout=0.6, negative_slope=0.2)
        # Layer 2: a single output head; with heads > 1, concat=False averages them
        self.gat2 = GATConv(d_hid, d_out, heads=1, concat=False,
                            dropout=0.6, negative_slope=0.2)
        self.dropout = nn.Dropout(p=0.6)  # feature dropout on each layer's input

    def forward(self, x, edge_index):
        x = self.dropout(x)
        x = F.elu(self.gat1(x, edge_index))  # [N, d_hid]
        x = self.dropout(x)
        x = self.gat2(x, edge_index)          # [N, d_out]
        return F.log_softmax(x, dim=-1)

# Mirrors the Cora setup from the GAT paper (8-head hidden layer, 1-head output)
model = TwoLayerGAT(d_in=1433, d_hid=64, d_out=7, K=8)
# d_in=1433 (Cora bag-of-words), d_out=7 (7 classes), K=8 heads in hidden layer
```

In multi-head GAT, each head has its own weight matrix $W^{k}$ and attention vector $a^{k}$. What can different heads learn?

Chapter 5: GCN vs GraphSAGE vs GAT

Now we can make the comparison crisp. All three operate within the MSG + AGG framework — they're just different choices. The table below is the one mental model you should carry forward from this lecture.

| Property | GCN | GraphSAGE | GAT |
| --- | --- | --- | --- |
| MSG function | $W h_u / \sqrt{\deg_u \deg_v}$ | $h_u$ (no transform before AGG) | $\alpha_{vu} W h_u$ |
| AGG function | Sum (= weighted mean after normalization) | Mean / Pool / LSTM, then CONCAT with self | Weighted sum (weights = softmax attention) |
| Neighbor weights | Fixed: $1/\sqrt{\deg_u \deg_v}$ | Fixed: uniform mean (or max) | Learned: $\alpha_{vu}$ from features |
| Self vs neighbor | Mixed together (self-loop in $\tilde{A}$) | Explicit CONCAT, kept separate until W | Attention can include self (add self-loop) |
| Parameters (1 layer, d×d) | $d^2$ | $2d^2$ (2d input to W due to concat) | $d^2 + 2d$ (K attention vectors of length 2d/K) |
| Inductive? | No (needs full adj matrix) | Yes (computes from features + local neighborhood) | Yes (attention computed from features) |
| Key strength | Simple, fast, strong baseline | Inductive, keeps self separate | Content-aware neighbor weighting |

Which to use? Start with 2-layer GCN; it's often surprisingly strong. If you need inductive learning (new nodes at test time), switch to GraphSAGE. If you suspect not all neighbors are equally relevant, add GAT. You et al. (NeurIPS 2020) found that the best architecture is highly task-dependent, with no single winner across all benchmarks.

Code Comparison: One Forward Pass

Here are the three models side by side in PyTorch Geometric. Notice how the aggregation step is the only thing that changes.

```python
import torch
from torch_geometric.nn import GCNConv, SAGEConv, GATConv

# Toy inputs for illustration: 4 nodes with 64 features on an undirected path 0-1-2-3
x = torch.randn(4, 64)
edge_index = torch.tensor([[0, 1, 1, 2, 2, 3],
                           [1, 0, 2, 1, 3, 2]])

# GCN: symmetric normalization, single W matrix
gcn_layer = GCNConv(in_channels=64, out_channels=128)
h_gcn = gcn_layer(x, edge_index)  # x: [N, 64] → [N, 128]

# GraphSAGE: mean AGG; separate weights for self and neighbor mean,
# which is equivalent to one W applied to the concatenation
sage_layer = SAGEConv(in_channels=64, out_channels=128)
h_sage = sage_layer(x, edge_index)  # x: [N, 64] → [N, 128]

# GAT: learned attention, K=8 heads, concat output
gat_layer = GATConv(in_channels=64, out_channels=16, heads=8, concat=True)
h_gat = gat_layer(x, edge_index)  # x: [N, 64] → [N, 128] (16 × 8 heads)

# All produce [N, 128]: same shape, different computation inside
```

Which GNN variant is best suited for predicting properties of nodes that were NOT present during training (e.g., new proteins added to a database)?

Chapter 6: Stacking GNN Layers

With one GNN layer, each node "sees" its 1-hop neighborhood. With K layers, it sees its K-hop neighborhood. More layers should mean richer representations, right? In practice, deeper GNNs often perform WORSE than shallow ones — a phenomenon called over-smoothing.

What Over-Smoothing Is

Let's think about what repeated mean aggregation does. After 1 layer: each node's embedding is the average of its neighbors' features. After 2 layers: it's the average of averages. After K layers: the embedding is a weighted average over the K-hop neighborhood — and as K grows, this neighborhood covers more and more of the graph.

In an Erdős-Rényi random graph with average degree d, the K-hop neighborhood covers essentially the entire graph after $O(\log_d n)$ layers. At that point every node's embedding is a weighted average over all nodes, and with further layers these averages converge toward the same vector. The embeddings are fully mixed. This is over-smoothing.

Over-smoothing is not just theoretical. Experiments commonly show GCN performance peaking at K=2 or K=3 layers on citation network benchmarks, with additional layers (4, 5, 6, ...) typically degrading it. The network becomes a low-pass filter that removes the high-frequency (discriminative) structure from the graph signal.
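The convergence is easy to reproduce on a toy graph. The sketch below, in plain Python, applies repeated mean aggregation (with self-loops, no weights or nonlinearity) to two 4-node cliques joined by one bridge edge. The graph, the scalar embeddings, and the "gap between community means" metric are all choices made for this illustration.

```python
def mean_aggregate(h, neighbors):
    """One propagation step: each node takes the mean over itself and its neighbors."""
    return [sum(h[u] for u in neighbors[v]) / len(neighbors[v]) for v in range(len(h))]

community_a = [0, 1, 2, 3]
community_b = [4, 5, 6, 7]

# Each community is a clique; a single bridge edge (3, 4) connects them
edges = [
    (u, v)
    for group in (community_a, community_b)
    for i, u in enumerate(group)
    for v in group[i + 1:]
]
edges.append((3, 4))

neighbors = {v: [v] for v in range(8)}  # start with self-loops
for u, v in edges:
    neighbors[u].append(v)
    neighbors[v].append(u)

# Scalar embedding: +1 for community A, -1 for community B
h = [1.0] * 4 + [-1.0] * 4

gaps = []  # gaps[k] = difference between community means after k layers
for layer in range(101):
    gap = sum(h[v] for v in community_a) / 4 - sum(h[v] for v in community_b) / 4
    gaps.append(gap)
    h = mean_aggregate(h, neighbors)

for k in (0, 1, 2, 5, 10, 50, 100):
    print(f"after {k:3d} layers: gap between communities = {gaps[k]:.4f}")
```

The gap starts at 2 and shrinks toward 0 as layers are added: after enough rounds of averaging, the two communities can no longer be told apart from their embeddings.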

A Spectral View

From a signal processing perspective, GCN applies a low-pass filter to the graph signal. The unnormalized graph Laplacian L = D - A has eigenvalues in [0, 2d], where d is the max degree; the normalized Laplacian $L_{\text{sym}} = I - D^{-1/2} A D^{-1/2}$ has eigenvalues in [0, 2]. One GCN propagation step multiplies the graph signal by the normalized adjacency with self-loops, $\tilde{D}^{-1/2}\tilde{A}\tilde{D}^{-1/2} = I - \tilde{L}_{\text{sym}}$, a filter that attenuates high-frequency components (signals in which nodes differ from their neighbors). After K layers, this filter is applied K times, suppressing high-frequency information exponentially. The nodes that differ from their neighborhood are exactly the ones you want to distinguish, and with enough layers GCN erases that difference.
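A short derivation makes the low-pass claim concrete. This is a sketch for the purely linear case (no weight matrices, no nonlinearity); the symbols $\hat{A}$, $\lambda_i$, and $u_i$ are introduced here and denote the self-loop-normalized adjacency and the eigenvalues and orthonormal eigenvectors of the corresponding normalized Laplacian $\tilde{L}_{\text{sym}}$.

```latex
\begin{aligned}
\hat{A} &= \tilde{D}^{-1/2}\tilde{A}\tilde{D}^{-1/2} = I - \tilde{L}_{\mathrm{sym}},
\qquad
\tilde{L}_{\mathrm{sym}} = \sum_{i=1}^{n} \lambda_i\, u_i u_i^{\top},
\quad 0 = \lambda_1 \le \dots \le \lambda_n < 2, \\
\hat{A}^{K} x &= \sum_{i=1}^{n} (1 - \lambda_i)^{K}\, (u_i^{\top} x)\, u_i .
\end{aligned}
```

Since $|1 - \lambda_i| < 1$ whenever $\lambda_i > 0$, every component except the $\lambda = 0$ one shrinks geometrically in K. What survives depends only on node degrees and connected components, not on the features that distinguished individual nodes.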

How Deep Should You Go?

In practice: 2-3 layers for most tasks. If you need a 7-hop receptive field, don't use 7 GNN layers — use the graph structure more cleverly (multi-hop connections, virtual nodes). This is different from CNNs and Transformers, where going deeper almost always helps.

Over-Smoothing Demo

[Interactive simulation: a graph with two communities (warm = community A, teal = community B). Each step applies one GCN mean-aggregation layer, and the node colors (embeddings) converge as layers increase. At layer 0 the two communities are distinct.]

Why does stacking many GNN layers cause over-smoothing?

Chapter 7: Skip Connections

You hit a wall: shallow GNNs can't see far enough, deep GNNs over-smooth. The solution comes from an unlikely source — ResNets. He et al. (2015) showed that adding a skip connection (identity shortcut) from layer l to layer l+1 prevents degradation in very deep CNNs. The same trick works for GNNs.

Residual GNN

The residual update adds the previous layer's embedding directly:

$$h_v^{(l)} = \mathrm{GNN\_layer}\left(h_v^{(l-1)}\right) + h_v^{(l-1)}$$

This ensures that even if the GNN layer "messes up" (e.g., over-smooths), the skip connection passes through the pre-smoothed embedding unchanged. The network can learn to use the skip connection when aggregation hurts (set GNN_layer output ≈ 0) and to use the GNN output when aggregation helps.

Why residual connections help against over-smoothing: The skip adds $h_v^{(l-1)}$ directly. Unrolling the recursion, the output at layer K is a sum of contributions from ALL layers 0 through K, not only the most smoothed one. Even if the pure GNN contribution is over-smoothed, the early-layer embeddings (less smoothed) are still present in the sum. The network sees the graph from multiple "distances" simultaneously.

Jumping Knowledge Networks (JK-Net)

Xu et al. (2018) proposed a more aggressive version: instead of only linking consecutive layers, collect the embeddings from ALL K layers and combine them at the end. This is called a JK-Net (Jumping Knowledge):

$$h_v^{\text{final}} = \mathrm{AGG}\left(\{h_v^{(1)}, h_v^{(2)}, \dots, h_v^{(K)}\}\right)$$

Here AGG is typically concatenation, element-wise max-pooling, or an LSTM with attention over the layer sequence. Each layer $h_v^{(l)}$ captures a different scale: layer 1 sees 1-hop, layer 2 sees 2-hop, etc. JK-Net lets each node adaptively select which scale of neighborhood information is most useful for its task.

Comparison of Approaches

Residual connection: Simple, adds 0 parameters. Skip from l-1 to l. Like keeping a running sum of every layer's contribution.

JK-Net: More expressive. Accesses ALL layer embeddings. Each node can use a different scale adaptively. More parameters (aggregation over K embeddings).

Dense connections (DenseNet-style): Every layer connects to all subsequent layers. Maximally expressive but $O(K^2)$ connections.

DropEdge: Randomly drop edges during training. Acts as data augmentation. Prevents co-adaptation of edge weights. Works well combined with residuals.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv

class ResidualGNN(nn.Module):
    def __init__(self, d, K):
        super().__init__()
        # Input and output widths match (d) so the skip can be a plain addition
        self.layers = nn.ModuleList([GCNConv(d, d) for _ in range(K)])

    def forward(self, x, edge_index):
        h = x
        for layer in self.layers:
            h = F.relu(layer(h, edge_index)) + h  # skip connection
        return h

class JKNet(nn.Module):
    def __init__(self, d, K):
        super().__init__()
        self.layers = nn.ModuleList([GCNConv(d, d) for _ in range(K)])

    def forward(self, x, edge_index):
        h = x
        all_hs = []
        for layer in self.layers:
            h = F.relu(layer(h, edge_index))
            all_hs.append(h)
        # Element-wise max-pool across all layer embeddings
        return torch.stack(all_hs, dim=0).max(dim=0).values
```

How does a skip connection (residual) help combat over-smoothing in deep GNNs?

Chapter 8: Practical GNN Design

Theory tells you what's possible. Practice tells you what actually works. You, Ying, and Leskovec (NeurIPS 2020) ran a massive systematic evaluation over a design space of 315,000 possible GNN configurations and 32 tasks, sampling configurations rather than training every one. Their key finding: there's no universally best GNN design, but there are reliable rules of thumb.

What You et al. (2020) Found

Aggregation: Sum aggregation often beats mean for node classification. This surprised people, but it has a simple explanation: sum preserves more information (it doesn't normalize away degree). But for link prediction, mean can win.

Activation: ReLU works well in most cases. PReLU (learnable slope) sometimes helps. Avoid sigmoid in hidden layers (vanishing gradients).

Layer connectivity: Skip connections almost always help when going beyond 3 layers. Without them, performance degrades monotonically after ~3 layers.

Depth: 2-3 layers is the sweet spot for most citation network benchmarks. On molecular graphs, 4-5 layers can help (molecules are small, less over-smoothing risk). On social networks (large, dense), 2 layers.

Batch normalization: BatchNorm after each GNN layer stabilizes training, especially for deeper models (K≥4). Treat it like you would in CNN design.

Dropout: Apply dropout to node features BEFORE aggregation, or to the final embedding. Dropout on attention coefficients is also used (the original GAT does this), but it makes the attention weights noisier during training, so tune it with care.

The Practical Recipe

Start Here
2-layer GCN with mean aggregation, ReLU, no skip connections. This is your baseline. ~80% of the way to the best possible performance on most tasks.
↓ if baseline underfits
Add Expressivity
Switch to GraphSAGE (concat) or GAT (attention). Use sum aggregation if you need to count (e.g., isomorphism-sensitive tasks). Add BatchNorm between layers.
↓ if you need more depth (K≥4)
Add Skip Connections
Residual connections or JK-Net. Without these, adding layers hurts. With them, you can safely go to K=6-8.
↓ if still underperforming
Augment the Graph
Virtual nodes (add a node connected to everything), higher-order features (triangle counts), edge features. These are often more impactful than changing the GNN architecture.

The uncomfortable truth: In the large-scale study, a carefully tuned 2-layer GCN often matched or beat sophisticated 6-layer models with attention and skip connections. The hyperparameters (learning rate, dropout, hidden dim) often matter more than the architecture choice. Don't optimize architecture until you've optimized your hyperparameters.

Dropout Placement in GNNs

Where you place dropout in a GNN matters more than in standard DNNs. Three options:

Feature dropout (best): Apply dropout to node feature vectors BEFORE aggregation. This randomly masks individual input features. Prevents the GNN from relying on any single feature dimension. The original GAT paper uses this for Cora (p=0.6).

Attention dropout: Apply dropout to attention coefficients αvu before they're used in the weighted sum. Randomly drops edges from the computation graph — similar to DropEdge. Makes the model more robust to missing connections.

Embedding dropout: Apply dropout to node embeddings after each GNN layer. Standard approach, works but slightly less effective for GNNs than feature dropout because embeddings already aggregate information from multiple sources.

What NOT to do: Don't apply dropout to the final output layer if using cross-entropy — the softmax expects un-scaled logits. And don't apply dropout inside the message function if you're debugging — it makes attention weights stochastic and hard to interpret.

Batch Normalization for GNNs

Apply BatchNorm after each GNN layer, before the activation. This normalizes node embeddings across the batch of nodes — analogous to normalizing across the batch in standard DNNs. For mini-batch training (where batch = subgraph), use the mini-batch statistics (standard PyTorch BatchNorm works). Empirically, BatchNorm is especially important for K≥3 layers.

python
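import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv  # PyTorch Geometric


class PracticalGNN(nn.Module):
    """K-layer GCN with BatchNorm, ReLU, and dropout between layers."""

    def __init__(self, d_in, d_hid, d_out, K=3):
        super().__init__()
        # Layer widths: d_in -> d_hid -> ... -> d_hid -> d_out
        dims = [d_in] + [d_hid] * (K - 1) + [d_out]
        self.convs = nn.ModuleList(
            [GCNConv(dims[i], dims[i+1]) for i in range(K)]
        )
        # One BatchNorm per hidden layer (none after the output layer)
        self.bns = nn.ModuleList(
            [nn.BatchNorm1d(dims[i+1]) for i in range(K - 1)]
        )
        self.dropout = nn.Dropout(p=0.5)

    def forward(self, x, edge_index):
        for i, conv in enumerate(self.convs):
            x = conv(x, edge_index)
            if i < len(self.bns):
                x = F.relu(self.bns[i](x))  # BN then ReLU
                x = self.dropout(x)
        return x  # final layer: no BN, no ReLU, no dropout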

When to Use Each Architecture: Decision Guide

Situation | Recommended Start | Why
Homophilic graph (neighbors similar to center) | 2-layer GCN | Uniform weights are near-optimal when all neighbors are relevant. A simple baseline is usually hard to beat.
Heterophilic graph (neighbors often dissimilar) | GAT with K=8 heads | Need to selectively attend to the few relevant neighbors among a noisy majority.
New nodes at inference time (inductive) | GraphSAGE (mean AGG) | GCN requires the full adjacency matrix. GraphSAGE only needs node features + local neighborhood.
Task requires counting (graph isomorphism sensitive) | GIN (sum + MLP) | Mean/max lose count information. SUM preserves it. GIN is maximally expressive within MSG+AGG.
Very deep model needed (K≥5) | Any + skip connections | Over-smoothing makes deep models fail. Residual or JK-Net allows depth.
Large graph (>10M nodes) | GraphSAGE + neighbor sampling | GCN full-batch doesn't fit in memory. GAT is O(|E|) but still needs all edges.
Edge features available | MPNN or edge-conditioned GAT | Standard GAT ignores edge features. Concat them into attention scoring.

Chapter 9: Connections

You now understand the GNN design space: MSG + AGG, with GCN / GraphSAGE / GAT as three points in it. You understand over-smoothing and how skip connections address it. Here's where this knowledge connects to the broader landscape.

Where We've Been

Lecture | Key Idea | Link
Lec 1: Intro to Graphs | Graphs as data structures. Node/edge/graph tasks. | Lecture 1
Lec 2: Node Embeddings | DeepWalk, node2vec: shallow transductive embeddings. | Lecture 2
Lec 3: GNNs 1 | Computation graphs, GCN, message passing, matrix form. | Lecture 3
Lec 4: GNNs 2 | Design space, GraphSAGE, GAT, over-smoothing, skip connections. | This lesson

Where GAT Lives in the Broader Picture

GAT replaces the graph Laplacian with a learned, input-dependent filter. This makes it structurally similar to Transformers — in fact, a Transformer is just GAT on a fully-connected graph (every token attends to every other token). The GAT attention mechanism (concat then LeakyReLU) is the "additive attention" variant; Transformers use "scaled dot-product attention." Both are valid choices in the same design space.
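The two scoring rules can be compared side by side in a small pure-Python sketch. To keep it short, the learned projections (GAT's W, the Transformer's query/key matrices) are assumed to be the identity, and the attention vector `a` and the toy embeddings are made-up values for illustration only.

```python
import math

def leaky_relu(x, slope=0.2):
    return x if x > 0 else slope * x

def softmax(scores):
    m = max(scores)  # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def additive_scores(h_v, neighbors, a):
    """GAT-style score: LeakyReLU(a . [h_v || h_u]) for each neighbor u."""
    return [leaky_relu(sum(w * x for w, x in zip(a, h_v + h_u))) for h_u in neighbors]

def dot_product_scores(h_v, neighbors):
    """Transformer-style score: (h_v . h_u) / sqrt(d) for each neighbor u."""
    d = len(h_v)
    return [sum(x * y for x, y in zip(h_v, h_u)) / math.sqrt(d) for h_u in neighbors]

h_v = [1.0, 0.0]
neighbors = [[1.0, 0.0], [0.0, 1.0], [-1.0, 0.0]]
a = [0.5, -0.5, 1.0, 0.5]  # hypothetical attention vector of length 2d

alpha_add = softmax(additive_scores(h_v, neighbors, a))
alpha_dot = softmax(dot_product_scores(h_v, neighbors))
print("additive:   ", alpha_add)
print("dot-product:", alpha_dot)
```

Both variants end in the same softmax over the neighborhood; only the scoring function differs. Restricting `neighbors` to graph neighbors gives GAT-style attention, while passing every other token gives the fully connected Transformer case.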

Read the original GAT paper in full detail:

Graph Attention Networks, Veličković et al., ICLR 2018

Where GraphSAGE Lives

GraphSAGE's inductive formulation was the foundation for PinSage (Ying et al., KDD 2018), Pinterest's production GNN, which ran on a graph with 3 billion nodes. The key insight was that neighbor sampling, mini-batch training, and inductive inference together make billion-scale GNNs feasible: the full adjacency matrix never has to fit in memory.
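A minimal sketch of the sampling idea, in pure Python: for a mini-batch of target nodes, sample a bounded number of neighbors per hop so that only the touched nodes need to be loaded. This uses uniform sampling in the GraphSAGE style; it is not PinSage's production sampler, and the function names, fanouts, and toy graph are assumptions for illustration.

```python
import random

def sample_neighbors(adj, node, fanout, rng):
    """Return at most `fanout` neighbors of `node`, sampled without replacement."""
    nbrs = adj[node]
    if len(nbrs) <= fanout:
        return list(nbrs)
    return rng.sample(nbrs, fanout)

def sample_computation_graph(adj, batch, fanouts, rng):
    """Sample a K-hop computation graph for a mini-batch of target nodes.

    Returns one dict per hop, mapping each frontier node to its sampled neighbors.
    """
    layers = []
    frontier = list(batch)
    for fanout in fanouts:
        sampled = {v: sample_neighbors(adj, v, fanout, rng) for v in frontier}
        layers.append(sampled)
        # The next hop expands from the union of the sampled neighbors.
        frontier = sorted({u for nbrs in sampled.values() for u in nbrs})
    return layers

# Toy adjacency list (hypothetical graph): node 0 is a hub.
adj = {0: [1, 2, 3, 4, 5], 1: [0, 2], 2: [0, 1], 3: [0], 4: [0, 5], 5: [0, 4]}
layers = sample_computation_graph(adj, batch=[0], fanouts=[2, 2], rng=random.Random(0))
for hop, sampled in enumerate(layers, start=1):
    print(f"hop {hop}: {sampled}")
```

With fanouts of 2 per hop, the cost per target node is bounded by the fanouts rather than by the hub's degree, which is what keeps memory use independent of graph size.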

GNN Expressiveness and the WL Test

All the GNNs we've covered have a fundamental expressiveness limit: they are at most as powerful as the 1-Weisfeiler-Lehman graph isomorphism test. This means there exist pairs of non-isomorphic graphs that all MSG + AGG GNNs map to the same embedding. GIN (Graph Isomorphism Network, Xu et al. 2019) characterizes exactly when a GNN IS as powerful as 1-WL: it uses SUM aggregation (not mean or max) and applies an MLP with sufficient capacity. Mean and max aggregation are strictly less expressive. This is coming in CS224W Lecture 5.

The design space lens: Once you see MSG + AGG, you can read ANY new graph learning paper and immediately identify: what is their message function? what is their aggregation function? what graph structure do they operate on? This decomposition is the most useful mental model in graph ML.

Hyperparameter Grid for GNN Experiments

Before changing architecture, sweep these hyperparameters first (in order of typical impact):

Hyperparameter | Typical Range | Impact
Learning rate | 1e-4 to 1e-2 (log scale) | High: a wrong LR prevents convergence entirely
Dropout rate | 0.0, 0.3, 0.5, 0.6 | High: especially for small training sets
Weight decay (L2) | 0, 1e-5, 5e-4, 1e-3 | Medium: regularization, interacts with dropout
Hidden dimension | 64, 128, 256, 512 | Medium: more capacity, but diminishing returns
Number of layers | 1, 2, 3 (with skip for 3) | Medium: more layers for larger graphs
Attention heads K | 1, 4, 8 (GAT only) | Low: modest improvement, K=8 usually fine
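As a rough sketch, the table can be turned into an explicit search space with the standard library. The three learning rates are an assumed log-scale discretization of the 1e-4 to 1e-2 range, and `train_and_evaluate` is a hypothetical placeholder for your own training loop.

```python
import itertools

# Values taken from the table above; attention heads are omitted because
# they apply to GAT only.
grid = {
    "lr": [1e-4, 1e-3, 1e-2],
    "dropout": [0.0, 0.3, 0.5, 0.6],
    "weight_decay": [0.0, 1e-5, 5e-4, 1e-3],
    "hidden_dim": [64, 128, 256, 512],
    "num_layers": [1, 2, 3],
}

def iter_configs(grid):
    """Yield one dict per point in the Cartesian product of the grid."""
    keys = list(grid)
    for values in itertools.product(*(grid[k] for k in keys)):
        yield dict(zip(keys, values))

def train_and_evaluate(config):
    """Hypothetical placeholder: train a model and return validation accuracy."""
    raise NotImplementedError("plug in your own training loop")

configs = list(iter_configs(grid))
print(len(configs))  # 3 * 4 * 4 * 4 * 3 = 576 configurations
print(configs[0])
```

A full grid of 576 runs is often too expensive. Sweeping one hyperparameter at a time in the order listed, or sampling a random subset of `configs`, is a cheaper way to follow the same priorities.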

Graph Augmentation: Beyond Architecture

Sometimes the best way to improve GNN performance is not to change the architecture but to change the GRAPH. Graph augmentation adds or removes edges/nodes to make the structure more amenable to message passing, for example by adding a virtual node connected to every other node.
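The virtual-node augmentation mentioned earlier in this lesson is simple enough to sketch on a plain edge list. The function name and the edge-list representation (directed `(src, dst)` pairs) are assumptions for illustration.

```python
def add_virtual_node(num_nodes, edges):
    """Return (new_num_nodes, new_edges) with one virtual node linked to every node.

    The virtual node gets index `num_nodes` and is connected in both directions.
    """
    virtual = num_nodes
    new_edges = list(edges)
    for v in range(num_nodes):
        new_edges.append((virtual, v))
        new_edges.append((v, virtual))
    return num_nodes + 1, new_edges

# Path graph 0 - 1 - 2 - 3, stored as directed edges in both directions.
edges = [(0, 1), (1, 0), (1, 2), (2, 1), (2, 3), (3, 2)]
n, aug_edges = add_virtual_node(4, edges)
print(n, len(aug_edges))  # 5 14
```

After augmentation, any two nodes are at most two hops apart through the virtual node, so a 2-layer GNN can pass information between the ends of the path.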

The GNN Design Space as Transfer Learning

One underappreciated benefit of the design space framework: the same GNN architecture, and sometimes its trained weights, can carry over to related graph domains. If you train GraphSAGE on a social network recommendation task and want to apply it to a protein interaction network, the architecture is the same; only the input features and output labels differ. The message function ($W h_u$) and the concat-then-transform aggregation are domain-agnostic operations. When the input feature spaces are compatible, fine-tuning a pretrained GNN on a new graph domain can require far less data than training from scratch.

This is analogous to how CNNs trained on ImageNet transfer to medical imaging — the low-level pattern detectors (edges, textures) are domain-general. In GNNs, the message-passing pattern (aggregate-and-combine) is domain-general; the learned weights encode domain-specific notions of what makes a "good" neighborhood summary.

The 1-WL Test and GNN Expressiveness

All the GNNs in this lecture (GCN, GraphSAGE, GAT) have a shared theoretical ceiling on expressiveness: they are at most as powerful as the 1-Weisfeiler-Lehman (1-WL) graph isomorphism test. This test is an efficient heuristic for graph isomorphism: it iteratively relabels each node with its own label together with the sorted multiset of its neighbors' labels. If two graphs end up with different colorings, they are certainly not isomorphic. If they end up with the same coloring, the test cannot distinguish them, even though they may still be non-isomorphic.
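A minimal sketch of 1-WL color refinement in pure Python, using the classic failure case of a 6-cycle versus two disjoint triangles. The function names and the choice of nested tuples as colors (instead of hashed labels) are illustrative assumptions.

```python
from collections import Counter

def wl_colors(adj, rounds=3):
    """Run 1-WL color refinement and return the multiset of final node colors."""
    colors = {v: 0 for v in adj}  # every node starts with the same color
    for _ in range(rounds):
        # New color = (own color, sorted multiset of neighbor colors).
        colors = {
            v: (colors[v], tuple(sorted(colors[u] for u in adj[v])))
            for v in adj
        }
    return Counter(colors.values())

def cycle(n, offset=0):
    """Adjacency list of an n-cycle on nodes offset .. offset + n - 1."""
    return {offset + i: [offset + (i - 1) % n, offset + (i + 1) % n] for i in range(n)}

hexagon = cycle(6)
two_triangles = {**cycle(3), **cycle(3, offset=3)}
path3 = {0: [1], 1: [0, 2], 2: [1]}

print(wl_colors(hexagon) == wl_colors(two_triangles))  # True: 1-WL cannot separate them
print(wl_colors(cycle(3)) == wl_colors(path3))         # False: degrees differ
```

The hexagon and the two triangles are not isomorphic, yet every node in both has two neighbors that all look alike, so the colorings coincide. Any MSG + AGG GNN without distinguishing node features inherits the same blind spot.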

Xu et al. (2019) proved that any MSG + AGG GNN is at most as expressive as 1-WL. Moreover, if the aggregation and update functions are injective (for example, SUM aggregation over injective messages followed by a sufficiently expressive MLP), the GNN is exactly as expressive as 1-WL. This gives us a concrete design target: use SUM aggregation for maximum expressiveness. Mean and max are strictly less expressive because they lose count information: mean can't distinguish "one neighbor with value 1" from "two neighbors with value 1" (both average to 1).

AGG Function | Information Lost | 1-WL equivalent?
Sum | Nothing (injective with MLP) | Yes (if MSG is injective)
Mean | Multiset size (count of neighbors) | No: can't count neighbors
Max | All but the maximum element | No: ignores all but the strongest signal
LSTM | Permutation invariance | Unclear: permutation-sensitive
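The information-loss column can be checked on a few toy neighbor multisets. This is a small pure-Python sketch with scalar messages; the multisets are made up for illustration.

```python
def aggregate(multiset):
    """Apply the three standard aggregators to a multiset of scalar messages."""
    return {
        "sum": sum(multiset),
        "mean": sum(multiset) / len(multiset),
        "max": max(multiset),
    }

a = [1.0]        # one neighbor with value 1
b = [1.0, 1.0]   # two neighbors with value 1
c = [1.0, 2.0]   # values 1 and 2
d = [2.0, 2.0]   # two neighbors with value 2

for name, m in [("a", a), ("b", b), ("c", c), ("d", d)]:
    print(name, m, aggregate(m))
```

Mean and max both map `a` and `b` to 1.0, and max additionally maps `c` and `d` to 2.0, while the sums (1, 2, 3, 4) are all different. Note that sum over raw scalars can still collide (for example `[1.0, 1.0]` and `[2.0]` both sum to 2), which is why the guarantee requires an injective message function applied before summing.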

GIN (Graph Isomorphism Network, Xu et al. 2019) instantiates the maximally expressive GNN: it uses SUM aggregation followed by an MLP, and includes the center node's own embedding: $h_v^{(k)} = \mathrm{MLP}\big((1+\epsilon)\, h_v^{(k-1)} + \sum_{u \in N(v)} h_u^{(k-1)}\big)$. The $(1+\epsilon)$ weight distinguishes the center from its neighbors. In terms of expressiveness, this is the theoretical optimum within the MSG + AGG framework.
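The update rule can be traced by hand with a small pure-Python sketch. The two-layer MLP below uses fixed, made-up weights (in a real GIN they are learned), and the helper names and toy star graph are assumptions for illustration.

```python
def relu(x):
    return [max(0.0, v) for v in x]

def linear(W, b, x):
    """Dense layer: W @ x + b, with W given as a list of rows."""
    return [sum(w * xi for w, xi in zip(row, x)) + bi for row, bi in zip(W, b)]

def gin_update(h, adj, mlp, eps=0.0):
    """One GIN layer: h_v <- MLP((1 + eps) * h_v + sum of neighbor embeddings)."""
    new_h = {}
    for v, nbrs in adj.items():
        dim = len(h[v])
        pooled = [
            (1.0 + eps) * h[v][d] + sum(h[u][d] for u in nbrs) for d in range(dim)
        ]
        new_h[v] = mlp(pooled)
    return new_h

# Hypothetical fixed weights for a 2-layer MLP (learned in practice).
W1, b1 = [[1.0, -1.0], [0.5, 0.5]], [0.0, 0.0]
W2, b2 = [[1.0, 0.0], [0.0, 1.0]], [0.1, -0.1]

def mlp(x):
    return linear(W2, b2, relu(linear(W1, b1, x)))

adj = {0: [1, 2], 1: [0], 2: [0]}  # star: node 0 in the center
h = {0: [1.0, 0.0], 1: [0.0, 1.0], 2: [0.0, 1.0]}
print(gin_update(h, adj, mlp, eps=0.1))
```

The two leaves have identical features and identical neighborhoods, so they receive identical embeddings, while the center differs because it sums over two neighbors rather than one.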

What's Next

CS224W Lecture 5: GNN Augmentation — how to make GNNs more expressive by augmenting the graph (virtual nodes, position encodings, subgraph features). The lesson also covers GNN expressiveness theory (1-WL test, GIN). Stay tuned.

Software Ecosystem

All three architectures (GCN, GraphSAGE, GAT) are available as drop-in layers in modern graph ML libraries such as PyTorch Geometric (GCNConv, SAGEConv, GATConv) and DGL.

"The most important insight in GNN design is that there is no universally optimal architecture. The choice of MSG and AGG functions should be guided by the structure of your problem — specifically, by whether neighbor importance is uniform, structure-dependent, or content-dependent."
— You, Ying, Leskovec, NeurIPS 2020