CS224W Lecture 5

GNN Augmentation & Training

Your raw graph is almost never ready to feed into a GNN. Nodes may have no features, edges may be too sparse, and choosing the wrong prediction head throws away information. This lesson covers everything between "I have a graph" and "my model is trained."

Prerequisites: GNN basics (Lec 3) + MSG+AGG framework (Lec 4). That's it.

Chapter 0: The Raw Graph Is Never Ready

You have a citation network. Every paper is a node, every citation is an edge. You want to train a GNN to predict a paper's research topic. You open the dataset — and realize: the nodes have no features. Just IDs. What do you feed the GNN?

Or: you have a protein interaction graph where two proteins are connected if they interact. Most proteins have only 2-3 known interactions — the graph is incredibly sparse. A 2-layer GNN can only see 2 hops out, which might be 4-6 nodes. That's almost no information per node. What do you do?

Or: you have Twitter's social graph. Some accounts have 10 million followers. When you aggregate neighbor embeddings, one famous account dominates the computation for millions of nodes. This isn't wrong, but it's incredibly slow: O(degree) work per layer, and degree can be $10^7$. What do you do?

These are not edge cases. Real-world graphs almost always have at least one of these problems: missing features, too-sparse structure, or degree imbalance. Augmentation is not optional polish — it's necessary preprocessing.

The Gap Between Raw Input and Computational Graph

A GNN needs two things to work: (1) node features, and (2) a graph structure that enables good information flow. The raw input graph often provides neither in usable form. Graph augmentation bridges this gap — it modifies the graph before message passing begins.

Augmentation falls into two categories. Feature augmentation addresses missing or incomplete node features. Structure augmentation addresses problems with the graph topology — too sparse, too dense, or disconnected in ways that make message passing ineffective.

The key insight: The graph you feed to the GNN (the computational graph) does NOT have to be identical to the input graph. You can add features. You can add or remove edges. You can add entirely new virtual nodes. What matters is that the computational graph enables good information flow for your task.
Why can't you just skip augmentation and feed the raw graph directly to the GNN?

Chapter 1: Feature Augmentation

When nodes have no features, you face a fundamental question: what do you put in the GNN's input layer? You can't leave it empty. The answer depends on what you need the features to do: do they need to distinguish individual nodes (expressiveness), or do they need to generalize to new nodes (transferability)?

Option 1: Constant Feature

Give every node the same scalar feature: $h_v^{(0)} = 1$. This is the simplest possible choice. The GNN still learns from structure: with sum aggregation it can distinguish a node with 5 neighbors from a node with 2, because the aggregated sums differ (5 vs 2). With mean aggregation, every neighborhood of 1s averages to 1, so degree is lost unless you supply it another way. In either case, the GNN cannot distinguish two structurally identical nodes.

When constant works: If your task depends only on structural role (e.g., "is this node a hub?"), constant features are enough. The GNN learns to use degree and neighborhood structure. And it transfers to new graphs trivially, since the feature is always 1.
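A minimal sketch in plain PyTorch shows what a GNN can read from constant features. The toy star graph and the helper `aggregate` are illustrative assumptions, not part of the lecture; the helper performs one round of unweighted aggregation with no learned parameters. Sum aggregation recovers each node's degree, while mean aggregation returns 1 for every node.

```python
import torch

# Hypothetical toy graph: a star with center 0 and leaves 1..4, stored as
# directed (source, target) pairs in both directions.
edges = [(0, i) for i in range(1, 5)] + [(i, 0) for i in range(1, 5)]
src = torch.tensor([s for s, _ in edges])
dst = torch.tensor([t for _, t in edges])
num_nodes = 5

x = torch.ones(num_nodes, 1)  # constant feature h_v^(0) = 1 for every node

def aggregate(x, src, dst, num_nodes, reduce):
    """One round of neighbor aggregation (no weights, no nonlinearity)."""
    out = torch.zeros(num_nodes, x.size(1))
    out.index_add_(0, dst, x[src])  # sum of incoming messages per target node
    if reduce == "mean":
        deg = torch.bincount(dst, minlength=num_nodes).clamp(min=1)
        out = out / deg.unsqueeze(1)
    return out

sum_agg = aggregate(x, src, dst, num_nodes, "sum")
mean_agg = aggregate(x, src, dst, num_nodes, "mean")
print(sum_agg.squeeze().tolist())   # [4.0, 1.0, 1.0, 1.0, 1.0]: the degrees
print(mean_agg.squeeze().tolist())  # [1.0, 1.0, 1.0, 1.0, 1.0]: degree is lost
```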

Option 2: One-Hot Node ID

Give each node a unique indicator vector: node v gets a vector of all zeros except position v, which is 1. This is equivalent to giving each node a row of the identity matrix $I_{n \times n}$. Now every node is uniquely identifiable, so the GNN can, in principle, learn different things for every single node.

One-hot is inductive-unfriendly. If you train on a graph with 1000 nodes and then test on a new graph with 1001 nodes, node 1001 has no valid one-hot vector. The model has never seen it. You cannot generalize to new graphs with one-hot features. Use only for transductive settings (same graph at train and test time).
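Why one-hot features tie a model to a single graph can be seen in a few lines of PyTorch (the sizes are arbitrary). Multiplying one-hot IDs by a first-layer weight matrix simply selects one row per node, so that layer acts as a lookup table with one learned row per training node. A new node has no row to select: its one-hot vector does not even match the layer's input size.

```python
import torch

torch.manual_seed(0)
num_train_nodes, hidden = 1000, 16

W = torch.randn(num_train_nodes, hidden)  # first-layer weights for one-hot input
x_onehot = torch.eye(num_train_nodes)     # one-hot node IDs, shape [N, N]

# One-hot times W is a row lookup: node v receives row v of W.
h = x_onehot @ W
print(torch.allclose(h, W))  # True

# A 1001st node would need a length-1001 one-hot vector, which no longer
# matches W's input dimension of 1000.
x_new = torch.zeros(1, num_train_nodes + 1)
x_new[0, -1] = 1.0
try:
    x_new @ W
except RuntimeError as err:
    print("shape mismatch:", err)
```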

Option 3: Structural Features

Compute structural properties of each node and use them as features. These are things the GNN could theoretically learn, but computing them explicitly and giving them as input is much easier for the network to use. Common choices:

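| Feature | What it captures | How to compute |
|---|---|---|
| Node degree | How many neighbors | $\lvert N(v) \rvert$ |
| Clustering coefficient | How densely connected the neighbors are | $e_v / \binom{\lvert N(v) \rvert}{2}$, where $e_v$ is the number of edges among $N(v)$ |
| PageRank score | Importance from random walks | Power iteration on $r = \beta M r + \frac{1-\beta}{N}\mathbf{1}$ ($M$: column-stochastic transition matrix) |
| Cycle participation | Shortest cycle length through $v$ | BFS from $v$ |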

These features are inductive — you can compute them for any new node in any new graph. And they're often more informative than one-hot IDs for tasks like predicting molecular properties, where cycle length and clustering coefficient are chemically meaningful.
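The table's first three features take only a few lines to compute from an adjacency list. This sketch uses plain Python and PyTorch on a hypothetical four-node graph. The damping factor β = 0.85 is an assumed PageRank setting, and the clustering coefficient is set to 0 for nodes with fewer than two neighbors (one common convention). The cycle feature is omitted for brevity.

```python
import itertools
import torch

# Hypothetical toy graph: triangle 0-1-2 plus a pendant node 3 attached to 2.
edges = [(0, 1), (1, 2), (0, 2), (2, 3)]
n = 4
neighbors = {v: set() for v in range(n)}
for u, v in edges:
    neighbors[u].add(v)
    neighbors[v].add(u)

def degree(v):
    return len(neighbors[v])

def clustering(v):
    """Edges among v's neighbors divided by the number of neighbor pairs."""
    nbrs = neighbors[v]
    if len(nbrs) < 2:
        return 0.0  # convention: undefined for degree 0 or 1, use 0
    links = sum(1 for a, b in itertools.combinations(nbrs, 2) if b in neighbors[a])
    return links / (len(nbrs) * (len(nbrs) - 1) / 2)

def pagerank(beta=0.85, iters=100):
    """Power iteration on r = beta * M r + (1 - beta) / n."""
    M = torch.zeros(n, n)
    for v in range(n):
        for u in neighbors[v]:
            M[u, v] = 1.0 / degree(v)  # column-stochastic: v splits its rank evenly
    r = torch.full((n,), 1.0 / n)
    for _ in range(iters):
        r = beta * (M @ r) + (1 - beta) / n
    return r

pr = pagerank()
x = torch.tensor([[degree(v), clustering(v), pr[v].item()] for v in range(n)])
print(x)  # one row per node: [degree, clustering coefficient, PageRank]
```

Every quantity here depends only on the node's neighborhood or the graph itself, never on a node ID, which is what makes these features usable on graphs unseen during training.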

When the Graph Has Features But They're Incomplete

Sometimes nodes have partial features. For example, in a user-item recommendation graph, some users have profile information (age, location) and some don't. The solution: impute missing features with the column mean, or learn an embedding for "missing." Don't drop nodes with missing features — you lose graph structure.
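A minimal sketch of the imputation advice in PyTorch, assuming missing entries are stored as NaN (your dataset may use a different sentinel). Each column is filled with its observed mean, and a 0/1 "was imputed" column is appended so the model can still tell filled-in values from real ones. The indicator column is a simple stand-in for a learned "missing" embedding, and the feature values are made up.

```python
import torch

# Hypothetical user features [age, years_active]; NaN marks a missing value.
x = torch.tensor([
    [25.0, 3.0],
    [float("nan"), 1.0],
    [35.0, float("nan")],
    [40.0, 5.0],
])

missing = torch.isnan(x)
col_mean = torch.nanmean(x, dim=0)  # mean of the observed values per column
x_filled = torch.where(missing, col_mean.expand_as(x), x)

# Per-node indicator: 1.0 if any feature of this node was imputed.
indicator = missing.any(dim=1, keepdim=True).float()
x_aug = torch.cat([x_filled, indicator], dim=1)  # shape [N, F + 1]
print(x_aug)
```

Every node, including the ones with gaps, stays in the graph, so no edges are lost.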

Practical rule of thumb: If you have meaningful domain features (molecule atom type, protein sequence, text embedding), use those. If you have nothing, start with constant features + structural features (degree, clustering). Only use one-hot if your task is strictly transductive and you need maximum expressiveness per node.
```python
# Three feature-augmentation strategies in PyTorch Geometric
import torch
from torch_geometric.transforms import LocalDegreeProfile

# Strategy 1: constant feature (all 1s)
def add_constant_feature(data):
    data.x = torch.ones(data.num_nodes, 1)  # shape [N, 1]
    return data

# Strategy 2: one-hot node ID (transductive only)
def add_onehot_feature(data):
    data.x = torch.eye(data.num_nodes)  # shape [N, N]: O(N^2) memory, not scalable!
    return data

# Strategy 3: local degree profile (inductive, scalable)
def add_structural_features(data):
    # LocalDegreeProfile appends 5 columns per node: the node's own degree plus
    # the min, max, mean, and std of its neighbors' degrees.
    transform = LocalDegreeProfile()
    data = transform(data)  # if data.x was None, data.x now has shape [N, 5]
    return data
```
You're training a GNN for drug discovery on a molecular graph dataset. You'll need to generalize to new molecules never seen in training. Which feature strategy should you use?

Chapter 2: Structure Augmentation — Virtual Nodes

You have a molecular graph where two atoms are 10 bonds apart. With a 3-layer GNN (3-hop receptive field), those atoms cannot "see" each other. But in chemistry, long-range electronic effects are real — atoms across the molecule do influence each other. Standard message passing is structurally blind to this.

One solution: add a virtual node — a synthetic node connected to every single node in the graph. Call it v*. Now every pair of real nodes is at most 2 hops apart: node A → v* → node B. The virtual node creates a "global information highway."

Why 2 hops? Every real node connects to v*, so the path from any real node A to any real node B is A → v* → B: 2 hops. Message passing covers one hop per layer, so a 2-layer GNN (one layer into v*, one layer back out) can propagate information across the entire graph. This matters most in sparse graphs, where shortest paths of 6 or more hops are common.
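A quick check of the 2-hop claim, as a pure-Python BFS sketch on a hypothetical 8-node path graph (the name "vn" for the virtual node is arbitrary). The longest shortest path between two real nodes drops from 7 hops to 2 once the virtual node is added.

```python
from collections import deque

def bfs_distances(adj, source):
    """Hop distance from source to every reachable node."""
    dist = {source: 0}
    queue = deque([source])
    while queue:
        u = queue.popleft()
        for w in adj[u]:
            if w not in dist:
                dist[w] = dist[u] + 1
                queue.append(w)
    return dist

def max_real_distance(adj, real_nodes):
    """Largest shortest-path distance between any two real nodes."""
    return max(bfs_distances(adj, s)[t] for s in real_nodes for t in real_nodes)

# Hypothetical sparse graph: a path 0 - 1 - ... - 7 (8 real nodes).
real = list(range(8))
adj = {v: set() for v in real}
for v in range(7):
    adj[v].add(v + 1)
    adj[v + 1].add(v)

before = max_real_distance(adj, real)

# Add a virtual node connected to every real node.
adj["vn"] = set(real)
for v in real:
    adj[v].add("vn")

after = max_real_distance(adj, real)
print(before, after)  # 7 2
```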

How the Virtual Node Learns

At initialization, the virtual node's embedding $h_{v^*}$ starts at zero (or a learned embedding). During message passing, it aggregates from every real node, forming a global summary of the entire graph, and then sends this summary back to every real node. One round fills v* with the summary and the next round delivers it, so after 2 layers every real node has received information from every other node in the graph.
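To see why the virtual node needs two message-passing rounds, not one, to connect distant nodes, here is a toy propagation in PyTorch. It is a sketch under simplifying assumptions: unweighted sum aggregation with no learned parameters, synchronous updates, and a hypothetical 8-node path with a signal placed on node 0 only. Real implementations typically update the virtual node with a learned MLP and may order the updates within a layer differently.

```python
import torch

# Hypothetical path graph 0 - 1 - ... - 7 as a dense adjacency matrix.
n = 8
A = torch.zeros(n, n)
for v in range(n - 1):
    A[v, v + 1] = A[v + 1, v] = 1.0

def step(h, vn, use_vn):
    """One synchronous round: real nodes read self + neighbors (+ the virtual
    node's current state); the virtual node reads the sum of all real nodes."""
    h_new = h + A @ h + (vn if use_vn else 0.0)
    vn_new = h.sum(dim=0, keepdim=True)
    return h_new, vn_new

def rounds_until_reached(target, use_vn, max_rounds=20):
    """Number of rounds before node 0's signal first reaches `target`."""
    h = torch.zeros(n, 1)
    h[0] = 1.0
    vn = torch.zeros(1, 1)
    for r in range(1, max_rounds + 1):
        h, vn = step(h, vn, use_vn)
        if h[target].item() > 0:
            return r
    return None

print(rounds_until_reached(7, use_vn=False))  # 7: one round per hop along the path
print(rounds_until_reached(7, use_vn=True))   # 2: into the virtual node, then back out
```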

Why This Helps Sparse Graphs

In a molecular graph with 50 atoms, without the virtual node and with K=3 layers, a node might see only the 5 or 6 atoms within 3 bonds. With the virtual node, every node receives a summary (e.g., a sum or mean) of all 50 atoms' embeddings after just 2 layers. For tasks like predicting a molecule's solubility, which depends on the whole molecule and not just the local environment, this is crucial.

The Tradeoff: Signal Dilution

The virtual node aggregates from ALL nodes — including irrelevant ones. In a large graph with 10,000 nodes, the global summary is an average over 10,000 embeddings. Distant nodes contribute noise. This can hurt if the task is purely local. As with most engineering choices: use virtual nodes when your task has global structure; skip them when the task is local.

The virtual node in practice: It's widely used in molecular GNNs (e.g., DMP, MPNN variants). In the Open Graph Benchmark (OGB), adding a virtual node is one of the most reliable tricks for improving performance on graph-level prediction tasks. It's also cheap: one extra node and one extra edge per real node, plus, in common implementations, a small MLP that updates the virtual node's embedding at each layer.
After adding a virtual node connected to all real nodes, what is the maximum hop distance between any two real nodes?

Chapter 3: Structure Augmentation — Neighbor Sampling

Imagine you're building a GNN for Pinterest's recommendation graph: 2 billion pins, 18 billion edges. A single forward pass aggregating ALL neighbors for ALL nodes simultaneously would require more memory than exists on Earth. You cannot run vanilla message passing on this graph. You need neighbor sampling.

The idea is surgical: instead of aggregating from all neighbors, sample a random subset of size k each forward pass. For node v with 10,000 neighbors, pick 25 at random. Aggregate those 25. Ignore the rest — for this forward pass.

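Why sampling works mathematically: For a single layer of mean aggregation, if you sample neighbors uniformly at random, the sampled mean is an unbiased estimate of the full-neighborhood mean. Once sampled aggregates pass through nonlinearities and further layers, the resulting gradient is no longer exactly unbiased, but the bias is small enough to train well in practice. You trade exact computation for a noisy approximation, and in stochastic gradient descent noise is fine: you're updating with stochastic gradients anyway.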
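A small numerical check in PyTorch: uniform sampling gives an unbiased estimate of a mean aggregate, but that unbiasedness does not survive a nonlinearity applied to the aggregate. The sizes are arbitrary (1,000 neighbors with random scalar features, k = 5 sampled without replacement), and squaring stands in for a nonlinear activation.

```python
import torch

torch.manual_seed(0)
n_neighbors, k, trials = 1_000, 5, 4_000

neighbor_feats = torch.rand(n_neighbors)  # one scalar feature per neighbor
full_mean = neighbor_feats.mean()         # exact mean aggregation

# Repeatedly sample k neighbors uniformly without replacement and average them.
sampled_means = torch.stack([
    neighbor_feats[torch.randperm(n_neighbors)[:k]].mean()
    for _ in range(trials)
])

# Linear aggregation: the average sampled mean matches the full mean.
print(f"full mean: {full_mean.item():.4f}   avg sampled mean: {sampled_means.mean().item():.4f}")

# Nonlinear function of the aggregate: E[f(sampled mean)] != f(full mean).
# For f(m) = m**2 the gap is the variance of the sampled mean
# (roughly the feature variance divided by k).
print(f"f(full mean): {(full_mean ** 2).item():.4f}   "
      f"avg f(sampled mean): {(sampled_means ** 2).mean().item():.4f}")
```

Larger k shrinks both the variance and, with it, the bias introduced by the nonlinearity.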

How PinSage Uses This

PinSage (Ying et al., KDD 2018, mentioned in Lecture 1) trains on Pinterest's graph using importance-based neighbor selection rather than uniform sampling. Instead of picking k random neighbors, PinSage simulates short random walks from v, keeps the T nodes visited most often as v's neighborhood, and weights each of them by its normalized visit count during aggregation. This effectively selects the most important neighbors.
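The visit-count idea can be sketched in plain Python. This is an illustrative approximation, not PinSage's production code: the function name, number of walks, walk length, `top_t`, and the toy graph are all assumptions. The nodes it selects need not be direct neighbors of v.

```python
import random
from collections import Counter

def importance_neighborhood(adj, v, num_walks=200, walk_len=3, top_t=3, seed=0):
    """PinSage-style neighborhood (sketch): run short random walks from v,
    count visits, keep the top_t most-visited nodes, and L1-normalize their
    counts into aggregation weights."""
    rng = random.Random(seed)
    visits = Counter()
    for _ in range(num_walks):
        node = v
        for _ in range(walk_len):
            node = rng.choice(adj[node])
            if node != v:
                visits[node] += 1
    top = visits.most_common(top_t)
    total = sum(count for _, count in top)
    return {node: count / total for node, count in top}

# Hypothetical graph: node 0 links to a tight cluster {1, 2, 3} and to a chain 4 - 5 - 6.
adj = {
    0: [1, 2, 3, 4],
    1: [0, 2, 3], 2: [0, 1, 3], 3: [0, 1, 2],
    4: [0, 5], 5: [4, 6], 6: [5],
}
weights = importance_neighborhood(adj, 0)
print(weights)  # top-3 nodes by visit count, with weights summing to 1
```

The resulting weights can replace the uniform 1/|N(v)| of mean aggregation, so frequently visited nodes contribute more.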

The result: PinSage can produce embeddings for 3 billion nodes, something impossible with full-graph message passing. And it generalizes inductively — new pins can be embedded the moment they appear in the graph, by sampling their current neighbors.

The Complexity Win

Without sampling: aggregating for one node costs $O(d)$ work per layer, where $d$ is the degree. A 3-layer GNN on a high-degree node expands to $O(d^3)$ nodes in its computational graph in the worst case: exponential in the number of layers. With sampling k neighbors per layer: $O(k^3)$ nodes, independent of degree. On a graph with max degree 10,000 and k=25, that's $10{,}000^3 = 10^{12}$ vs $25^3 = 15{,}625$, a reduction of $6.4 \times 10^7$ (64 million times).

Depth vs breadth tradeoff: With full aggregation, each extra GNN layer multiplies the work by the degree. With neighbor sampling, each extra layer multiplies it by k instead: L layers cost $O(k^L)$, still exponential in depth but with a small base that you choose rather than one the graph dictates. This is why sampling is what makes multi-layer GNNs feasible on very large graphs.
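The arithmetic behind these numbers, as a few lines of Python under the same worst-case assumption (every node has degree d = 10,000, and k = 25 neighbors are sampled per node per layer):

```python
# Worst-case computational-graph size for one target node.
d, k = 10_000, 25
for L in range(1, 4):
    full, sampled = d ** L, k ** L
    print(f"L={L}: full {full:,}  sampled {sampled:,}  ratio {full // sampled:,}")
# L=1: full 10,000  sampled 25  ratio 400
# L=2: full 100,000,000  sampled 625  ratio 160,000
# L=3: full 1,000,000,000,000  sampled 15,625  ratio 64,000,000
```

Both columns grow exponentially in L; sampling only changes the base from d to k.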

What You Lose

Sampling introduces variance. For low-degree nodes (degree 2-3), sampling k=25 neighbors means you always use all of them: no problem. For high-degree nodes (degree 10,000), you see only 0.25% of neighbors per pass. The gradient is noisier, and any single pass can miss important neighbors purely by bad luck. Practical fix: increase k for high-degree nodes, or use importance-weighted sampling.

```python
# Neighbor sampling in PyTorch Geometric (NeighborLoader)
# Assumes `data` (a torch_geometric.data.Data), `train_mask`, `model`,
# `criterion`, and `optimizer` are already defined.
from torch_geometric.loader import NeighborLoader

# Sample up to 25 neighbors per node at each of 3 hops
loader = NeighborLoader(
    data,
    num_neighbors=[25, 25, 25],  # entry i = neighbors sampled per node at hop i+1
    batch_size=512,              # seed nodes per mini-batch
    input_nodes=train_mask,      # which nodes to predict for
    shuffle=True,
)

# Each batch is a subgraph containing:
# - 512 seed nodes (the ones we want predictions for), listed first
# - up to 25 + 25^2 + 25^3 = 16,275 sampled nodes per seed
#   (fewer after deduplication and around low-degree nodes)
# - only the edges sampled into this subgraph
# Shape of batch.x: [num_sampled_nodes, feature_dim]

for batch in loader:
    optimizer.zero_grad()
    out = model(batch.x, batch.edge_index)  # standard GNN forward
    # Only compute loss on seed nodes (first batch_size rows)
    loss = criterion(out[:batch.batch_size], batch.y[:batch.batch_size])
    loss.backward()
    optimizer.step()
```
Why is neighbor sampling considered an unbiased estimator of the full-graph gradient?

Chapter 4: Prediction Heads — Showcase

Your GNN produces node embeddings $h_v^{(L)}$ for every node v. These are vectors in $\mathbb{R}^d$. But your task might be node-level (predict a label per node), edge-level (predict a score per edge), or graph-level (predict one label for the whole graph). Each task needs a different prediction head: the function that converts embeddings into predictions.

Node-Level Prediction

The simplest case. You have one embedding per node. Apply a linear layer and softmax:

$$\hat{y}_v = \mathrm{Softmax}\big(W^{(H)} h_v^{(L)}\big)$$

$W^{(H)}$ maps from embedding dimension d to the number of classes c, so its shape is $c \times d$. That's it: a linear head applied to each node embedding independently (swap in a small MLP if you need more capacity). For regression, drop the softmax.
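In PyTorch, the node head is a single `nn.Linear` applied row-wise. A minimal sketch with made-up sizes; `bias=False` mirrors the formula above, though in practice a bias is usually kept. `nn.Linear` stores its weight as [out, in], which matches the c × d shape of the head's weight matrix.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
num_nodes, d, c = 6, 16, 3                # illustrative sizes

h_L = torch.randn(num_nodes, d)           # final-layer node embeddings h_v^(L)
head = nn.Linear(d, c, bias=False)        # W^(H): weight shape [c, d]
y_hat = torch.softmax(head(h_L), dim=-1)  # one class distribution per node

print(head.weight.shape)  # torch.Size([3, 16])
print(y_hat.shape)        # torch.Size([6, 3])
```

When training with `nn.CrossEntropyLoss`, you would pass the raw `head(h_L)` scores and leave the softmax to the loss.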

Edge-Level Prediction

For edges, you need to combine embeddings from two nodes into a single score. You have several options:

Concatenation + linear: Concatenate $h_u$ and $h_v$, then apply a linear layer with $W^{(H)} \in \mathbb{R}^{c \times 2d}$. This is expressive: it can capture asymmetric relationships (u links to v ≠ v links to u).

$$\hat{y}_{uv} = W^{(H)} \cdot \mathrm{Concat}(h_u, h_v)$$

Dot product: Just take the inner product $h_u^\top h_v$. It returns one scalar, which suits link prediction (will these two nodes connect?). It forces symmetry: score(u,v) = score(v,u).

$$\hat{y}_{uv} = h_u^\top h_v$$

For k-way edge classification (k types of edge labels), use a multi-head dot product: learn k different matrices $W^{(1)}, \dots, W^{(k)}$, compute $h_u^\top W^{(i)} h_v$ for each i, and apply a softmax over the k scores.
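The three edge heads side by side, as a PyTorch sketch with arbitrary sizes and random, untrained weights. The multi-head weights are a plain tensor for brevity; in a real model they would be an `nn.Parameter`.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
d, c, k = 8, 4, 3  # embedding dim, classes (concat head), edge types (multi-head)
h_u, h_v = torch.randn(d), torch.randn(d)

# 1. Concatenation + linear: W^(H) has shape [c, 2d]; the order of u and v matters.
concat_head = nn.Linear(2 * d, c, bias=False)
score_uv = concat_head(torch.cat([h_u, h_v]))
score_vu = concat_head(torch.cat([h_v, h_u]))

# 2. Dot product: a single symmetric scalar.
dot_uv, dot_vu = h_u @ h_v, h_v @ h_u

# 3. Multi-head dot product for k edge types: one d x d matrix per type.
W = torch.randn(k, d, d)
type_scores = torch.einsum("i,kij,j->k", h_u, W, h_v)  # h_u^T W^(i) h_v for each i
type_probs = torch.softmax(type_scores, dim=0)

print(torch.allclose(score_uv, score_vu))  # asymmetric: the two orders generally differ
print(torch.allclose(dot_uv, dot_vu))      # True: the dot product is symmetric
print(type_probs.shape)                    # torch.Size([3])
```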

Graph-Level Prediction

For graph tasks (e.g., "is this molecule toxic?"), you need to pool all node embeddings into one graph embedding $h_G$. Common strategies:

| Strategy | Formula | When to use |
|---|---|---|
| Mean pooling | $h_G = \frac{1}{\lvert V \rvert}\sum_{v} h_v$ | When every node contributes equally. Good for homogeneous graphs. |
| Max pooling | $h_G = \max_{v} h_v$ (elementwise) | When the most extreme node matters (e.g., detecting whether ANY toxic group exists). Ignores all but the largest value in each dimension. |
| Sum pooling | $h_G = \sum_{v} h_v$ | When graph size matters, since the sum grows with $\lvert V \rvert$. Good for counting tasks. |
| DiffPool (hierarchical) | $S = \mathrm{softmax}(\mathrm{GNN}_{\text{pool}}(A, H))$: soft-assign nodes to clusters, coarsen, repeat | When the graph has community structure. Most expressive but slow. |
Why not just always use mean? Consider two molecules: one has 10 carbon atoms and no toxic groups; the other has 10 carbons and one toxic group. Mean pooling might produce very similar embeddings, because the toxic group contributes only 1/11 of the average. Max pooling preserves the presence of the extreme group. Task-specific pooling matters.
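A toy version of that comparison in PyTorch, assuming a deliberately idealized 1-dimensional embedding where carbons map to 0.0 and the toxic group to 1.0. Mean pooling shrinks the toxic signal to 1/11, max pooling keeps it at full strength, and sum pooling keeps it while also reflecting graph size.

```python
import torch

# Hypothetical 1-D "toxicity" embedding per atom: carbon 0.0, toxic group 1.0.
clean = torch.zeros(10, 1)                                 # 10 carbons
toxic = torch.cat([torch.zeros(10, 1), torch.ones(1, 1)])  # 10 carbons + 1 toxic group

for name, h in [("clean", clean), ("toxic", toxic)]:
    print(name,
          "mean", round(h.mean(0).item(), 3),
          "max", h.max(0).values.item(),
          "sum", h.sum(0).item())
# clean mean 0.0 max 0.0 sum 0.0
# toxic mean 0.091 max 1.0 sum 1.0

# Mean also ignores graph size: 2 identical nodes and 4 identical nodes pool the same.
small, large = torch.ones(2, 1), torch.ones(4, 1)
print(small.mean(0).item() == large.mean(0).item(), small.sum(0).item(), large.sum(0).item())
# True 2.0 4.0
```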
You want to predict whether a molecule contains ANY aromatic ring (a local structural feature). Which graph-level pooling strategy best preserves the signal from a single aromatic node?

Chapter 5: Loss Functions

Once you have predictions ŷ and ground truth labels y, you need a loss function to measure the error and drive backpropagation. The right loss depends on your task type: classification, regression, or link prediction.

Classification: Cross-Entropy

For node or graph classification with c classes, use cross-entropy. The prediction ŷ is a probability vector over c classes (after softmax). The label y is a one-hot vector (all zeros except class k, which is 1).

$$\mathrm{CE}(y, \hat{y}) = -\sum_{i=1}^{c} y_i \log \hat{y}_i$$

Since y is one-hot, only one term survives the sum: $-\log \hat{y}_k$, where k is the true class. Cross-entropy punishes confidently wrong predictions very harshly ($-\log \hat{y}_k \to \infty$ as $\hat{y}_k \to 0$) and rewards high confidence on the correct class.

Binary classification is a special case with c=2. You can use binary cross-entropy (BCE) with a single sigmoid output: $\mathrm{BCE} = -[\,y \log \hat{y} + (1-y)\log(1-\hat{y})\,]$. This is what link prediction uses: "is this edge real (y=1) or fake (y=0)?"
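A quick PyTorch check, with made-up logits, that the library losses compute these formulas. `F.cross_entropy` and `F.binary_cross_entropy_with_logits` take raw scores and apply the softmax or sigmoid internally, which is more numerically stable than taking the log of probabilities yourself.

```python
import torch
import torch.nn.functional as F

logits = torch.tensor([[2.0, 0.5, -1.0]])  # one example, c = 3 classes
target = torch.tensor([0])                 # true class k = 0

probs = torch.softmax(logits, dim=-1)
manual_ce = -torch.log(probs[0, target[0]])   # -log(y_hat_k)
library_ce = F.cross_entropy(logits, target)  # takes raw logits
print(torch.allclose(manual_ce, library_ce))  # True

# Binary case: one logit per edge; the sigmoid is applied inside the loss.
edge_logit = torch.tensor([1.5])
y = torch.tensor([1.0])  # "real edge"
p = torch.sigmoid(edge_logit)
manual_bce = -(y * torch.log(p) + (1 - y) * torch.log(1 - p))
library_bce = F.binary_cross_entropy_with_logits(edge_logit, y)
print(torch.allclose(manual_bce, library_bce))  # True
```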

Regression: Mean Squared Error

For predicting continuous values (e.g., molecular energy, traffic flow), use MSE:

$$\mathrm{MSE}(y, \hat{y}) = \frac{1}{N}\sum_{i=1}^{N} (y_i - \hat{y}_i)^2$$

MSE penalizes large errors quadratically. If you suspect outlier labels in your dataset, use MAE (mean absolute error) instead — it's more robust to extreme values. The GNN architecture doesn't change; only the loss and final activation change (no softmax for regression — just a linear output).
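To make the robustness point concrete, a PyTorch sketch with made-up numbers: corrupt a single label by about 10 units and compare how much each loss grows. With these values MSE grows by roughly three orders of magnitude, MAE by roughly one.

```python
import torch
import torch.nn.functional as F

y_pred = torch.tensor([1.1, 1.9, 3.2, 3.8])
y_true = torch.tensor([1.0, 2.0, 3.0, 4.0])
y_true_outlier = y_true.clone()
y_true_outlier[-1] = 14.0  # one corrupted label, about 10 units off

results = {}
for name, target in [("clean", y_true), ("outlier label", y_true_outlier)]:
    mse = F.mse_loss(y_pred, target).item()
    mae = F.l1_loss(y_pred, target).item()  # L1 loss = mean absolute error
    results[name] = (mse, mae)
    print(f"{name}: MSE={mse:.3f}  MAE={mae:.3f}")

# Relative growth caused by the single bad label
print("MSE grew x", round(results["outlier label"][0] / results["clean"][0]))
print("MAE grew x", round(results["outlier label"][1] / results["clean"][1], 1))
```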

Link Prediction: BCE on Positive and Negative Edges

For link prediction, you have positive edges (edges that actually exist in the graph) and negative edges — pairs of nodes that are NOT connected. You want connected nodes to have high dot-product similarity and non-connected nodes to have low similarity.

$$\mathcal{L}_{\text{link}} = -\sum_{(u,v)\in E} \log \sigma(h_u^\top h_v) \;-\; \sum_{(u,v)\notin E} \log\big(1 - \sigma(h_u^\top h_v)\big)$$

The first sum pushes positive edge scores up. The second pushes negative edge scores down. Because the $O(|V|^2)$ non-edges vastly outnumber the edges in $E$, summing over all of them is infeasible, so you typically sample roughly as many negative edges as positive edges (called negative sampling).
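A minimal rejection-sampling sketch in plain Python, assuming an undirected graph stored as an edge list; the function name and the toy graph are hypothetical. PyTorch Geometric ships a `negative_sampling` utility in `torch_geometric.utils` for the same purpose.

```python
import random

def sample_negative_edges(edges, num_nodes, num_samples, seed=0):
    """Uniformly sample distinct node pairs that are not edges (undirected,
    no self-loops) by rejection. Suited to sparse graphs; it would loop
    forever if num_samples exceeded the number of non-edges."""
    rng = random.Random(seed)
    existing = {frozenset(e) for e in edges}
    negatives = set()
    while len(negatives) < num_samples:
        u, v = rng.randrange(num_nodes), rng.randrange(num_nodes)
        pair = frozenset((u, v))
        if u != v and pair not in existing:
            negatives.add(pair)
    return [tuple(sorted(p)) for p in negatives]

# Hypothetical sparse graph: 5 positive edges on 10 nodes.
pos_edges = [(0, 1), (1, 2), (2, 3), (3, 4), (5, 6)]
neg_edges = sample_negative_edges(pos_edges, num_nodes=10, num_samples=len(pos_edges))
print(neg_edges)  # 5 node pairs, none of which is an edge
```

Rejection works well here because, in a sparse graph, almost every random pair is already a non-edge.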

The loss shapes the embedding space. Cross-entropy pushes nodes of the same class into the same region of embedding space, on the same side of the linear classifier's decision boundaries. BCE link loss pulls connected nodes together and pushes disconnected nodes apart. The GNN architecture is the same; only what the loss rewards changes.
```python
# Three loss functions for three GNN task types
# Assumes logits, labels, predictions, targets, node embeddings h, positive
# edge endpoints src/dst, and sampled negative endpoints neg_src/neg_dst exist.
import torch
import torch.nn as nn

# 1. Node/graph classification
ce_loss = nn.CrossEntropyLoss()
loss = ce_loss(logits, labels)  # logits: [N, C] raw scores (no softmax), labels: [N] int64

# 2. Regression
mse_loss = nn.MSELoss()
loss = mse_loss(predictions.squeeze(-1), targets)  # [N], [N]

# 3. Link prediction (BCE with negative sampling)
bce_loss = nn.BCEWithLogitsLoss()                   # applies the sigmoid internally
pos_scores = (h[src] * h[dst]).sum(dim=-1)          # dot products for real edges, [E_pos]
neg_scores = (h[neg_src] * h[neg_dst]).sum(dim=-1)  # dot products for sampled non-edges, [E_neg]

scores = torch.cat([pos_scores, neg_scores])
labels_bin = torch.cat([torch.ones_like(pos_scores), torch.zeros_like(neg_scores)])
loss = bce_loss(scores, labels_bin)
```
For link prediction, why do we need to explicitly sample "negative edges" (node pairs that are NOT connected)?

Chapter 6: Supervised vs Self-Supervised Training

Where do training labels come from? In some domains, getting labeled data is expensive — labeling a protein's function requires wet-lab experiments. In other domains, the graph structure itself is a label generator. This distinction drives whether you use supervised or self-supervised training.

Supervised Training

In supervised training, you have explicit task labels provided externally: node class labels (e.g., paper topic in a citation network), graph property labels (e.g., molecule toxicity measured in a lab), or edge labels (e.g., interaction type). You compute a loss directly against these labels.

The GNN is trained end-to-end to minimize task loss. The embeddings it produces are optimized for the specific labeled task. This is the most direct path to good task performance — if you have labels. The problem: labels are often scarce, expensive, or impossible to scale.

Self-Supervised Training

Self-supervised methods treat the graph structure itself as the source of supervision signals. No external labels needed. The two main variants:

Link prediction as pretext task: Mask some edges (remove them from the message-passing graph). Train the GNN to predict the masked edges, using BCE on the masked (true) edges versus sampled non-edges. After training, the embeddings encode "which nodes tend to connect", which is useful for recommendation and social network tasks.

Node clustering pretext task: Assign initial cluster IDs using a simple algorithm (e.g., spectral clustering or Louvain). Train the GNN to predict which cluster a node belongs to. Embeddings learn to encode community structure — useful for social networks and biological networks.

Self-supervised ≠ unsupervised. "Unsupervised" methods such as k-means or PCA optimize a fixed criterion with no prediction target at all. Self-supervised methods do define a prediction target, but the supervision signal comes from the data itself rather than from external labels. The distinction matters: because they are trained to predict something about the graph, self-supervised GNN embeddings often transfer well to downstream labeled tasks.

Semi-Supervised: The Best of Both

Often you have a large graph where only a small fraction of nodes are labeled. Semi-supervised training uses both: a self-supervised loss on all nodes (no labels needed) plus a supervised loss on the labeled subset. The self-supervised term regularizes the model so it doesn't overfit to the few labeled examples.

$$\mathcal{L}_{\text{total}} = \mathcal{L}_{\text{supervised}} + \lambda \cdot \mathcal{L}_{\text{self-supervised}}$$

The hyperparameter λ controls the balance; a small value such as λ = 0.1 is a reasonable starting point to tune from. When labels are scarce, the self-supervised term can improve performance substantially: it forces the model to use graph structure rather than memorize the few labeled nodes.

When to use what: Labels abundant → pure supervised. Few labels plus a large unlabeled graph (or collection of graphs) → self-supervised pretraining, then fine-tune on the labeled data. A single graph with only some nodes labeled → semi-supervised (joint loss). No labels at all → self-supervised only (embeddings as output).
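```python
# Semi-supervised GNN training: supervised CE on labeled nodes + link-prediction loss
# Assumes: GNN(in_dim, hidden) is a placeholder encoder returning node embeddings
# of shape [N, hidden]; data, train_mask, positive edges (src, dst), and
# sampled negatives (neg_src, neg_dst) are already defined.
import torch
import torch.nn as nn
import torch.nn.functional as F

encoder = GNN(in_dim=64, hidden=128)  # node embeddings [N, 128]
classifier = nn.Linear(128, 7)        # prediction head for 7 classes
optimizer = torch.optim.Adam(
    list(encoder.parameters()) + list(classifier.parameters()), lr=0.001
)
lam = 0.1  # weight of the self-supervised term

for epoch in range(200):
    encoder.train()
    classifier.train()
    optimizer.zero_grad()
    h = encoder(data.x, data.edge_index)  # [N, 128]

    # Supervised: only on labeled nodes
    logits = classifier(h)                # [N, 7]
    sup_loss = F.cross_entropy(logits[train_mask], data.y[train_mask])

    # Self-supervised: link prediction on the embeddings
    pos_s = (h[src] * h[dst]).sum(-1)
    neg_s = (h[neg_src] * h[neg_dst]).sum(-1)
    self_loss = F.binary_cross_entropy_with_logits(
        torch.cat([pos_s, neg_s]),
        torch.cat([torch.ones_like(pos_s), torch.zeros_like(neg_s)]),
    )

    loss = sup_loss + lam * self_loss
    loss.backward()
    optimizer.step()
```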
A biologist has a protein interaction graph with 50,000 proteins but only 200 experimentally verified function labels. Which training strategy is most appropriate?

Chapter 7: Dataset Splitting

In standard deep learning (images, text), train/val/test splitting is straightforward: shuffle examples, take 80/10/10. But graphs have dependencies between examples — nodes are connected. Randomly splitting nodes or edges can leak information across splits. Graph splitting requires careful thought.

Transductive vs Inductive Settings

The first question is not how to split, but what you're splitting. In a transductive setting, you have one graph. You train on some of its nodes/edges, validate on others, and test on others — but the GNN sees the ENTIRE graph structure (all edges) during training. Only labels are withheld, not graph topology.

In an inductive setting, you have multiple graphs. You train on some graphs, validate on others, and test on held-out graphs the model has never seen. The model must generalize to entirely new graph structures, not just new labels on a known graph.

Transductive ≠ cheating. It's a valid and common setting: for example, training a node classifier on a social network where the graph structure is fixed and known. The GNN can use edge information from test nodes during training (for message passing); you just don't use test node labels. This is how GCN was originally evaluated on Cora, Citeseer, and PubMed. GraphSAGE, by contrast, was designed for and evaluated in the inductive setting.

Node Classification Splitting

For transductive node classification: split LABELS, not the graph. Assign each node to train/val/test. During training, message passing uses all edges (including those touching val/test nodes). Only the loss is computed on train-masked nodes. This is the standard split for Cora, Citeseer, and PubMed benchmarks.
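A minimal sketch of that label split in PyTorch; the 60/20/20 fractions and the helper name are arbitrary. It builds boolean masks over all nodes and leaves the edge set untouched, so training indexes the loss with the mask, e.g. `loss = F.cross_entropy(out[train_mask], y[train_mask])`.

```python
import torch

def random_node_split(num_nodes, train_frac=0.6, val_frac=0.2, seed=0):
    """Transductive split: every node stays in the graph; only which labels
    are used for the loss is split."""
    g = torch.Generator().manual_seed(seed)
    perm = torch.randperm(num_nodes, generator=g)
    n_train = int(train_frac * num_nodes)
    n_val = int(val_frac * num_nodes)
    masks = []
    for idx in (perm[:n_train], perm[n_train:n_train + n_val], perm[n_train + n_val:]):
        m = torch.zeros(num_nodes, dtype=torch.bool)
        m[idx] = True
        masks.append(m)
    return masks  # train_mask, val_mask, test_mask

train_mask, val_mask, test_mask = random_node_split(100)
print(train_mask.sum().item(), val_mask.sum().item(), test_mask.sum().item())  # 60 20 20
```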

Link Prediction Splitting — The Hard Part

Link prediction is the trickiest split. You must split EDGES into train/val/test. But removing edges changes the graph structure used for message passing. The standard approach: split edges into two sets — "message passing edges" (used for GNN computation) and "supervision edges" (used only for the loss).

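All edges E
Split into supervision sets: $E_{\text{train}}$ (85%), $E_{\text{val}}$ (5%), $E_{\text{test}}$ (10%). Within $E_{\text{train}}$, further separate message-passing edges from training supervision edges, so the edges the model learns to predict are not also used to compute the embeddings.
Message passing graph
Training: pass messages over the training message-passing edges only. Validation: pass messages over all of $E_{\text{train}}$. Test: pass messages over $E_{\text{train}}$ (optionally plus $E_{\text{val}}$, which is safe because validation edges are never test targets). In each phase, the edges being predicted stay INVISIBLE to the GNN topology.
Loss computation
Train: predict the training supervision edges. Val: predict $E_{\text{val}}$. Test: predict $E_{\text{test}}$.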
Why not just use all edges for message passing? If the GNN sees a val/test edge during message passing, the two endpoints share information through that edge. Predicting the edge becomes trivial — the model already "knows" it exists via message passing. You'd be testing on training data in disguise.
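A plain-Python sketch of that bookkeeping. The helper name, the 85/5/10 split, and holding out 30% of the training edges as training supervision edges are illustrative assumptions; PyTorch Geometric's `RandomLinkSplit` transform implements a similar scheme. In this variant the test phase also passes messages over validation edges, which is safe because they are never test targets. The invariant to check is that each phase's targets never appear in that phase's message-passing edges.

```python
import itertools
import random

def link_prediction_split(edges, val_frac=0.05, test_frac=0.10, train_sup_frac=0.3, seed=0):
    """Split undirected edges for transductive link prediction (sketch).

    For each phase, returns the edges used for message passing and the
    positive target edges the loss is computed on.
    """
    rng = random.Random(seed)
    edges = list(edges)
    rng.shuffle(edges)
    n_val = int(val_frac * len(edges))
    n_test = int(test_frac * len(edges))
    val = edges[:n_val]
    test = edges[n_val:n_val + n_test]
    train = edges[n_val + n_test:]

    # Hold out part of E_train as training supervision edges.
    n_sup = int(train_sup_frac * len(train))
    train_sup, train_msg = train[:n_sup], train[n_sup:]

    return {
        "train": {"message": train_msg, "targets": train_sup},
        "val": {"message": train, "targets": val},
        "test": {"message": train + val, "targets": test},
    }

# Hypothetical graph: 100 distinct undirected edges among 50 nodes.
all_edges = random.Random(1).sample(list(itertools.combinations(range(50), 2)), 100)
splits = link_prediction_split(all_edges)
for phase, s in splits.items():
    print(f"{phase}: {len(s['message'])} message-passing edges, {len(s['targets'])} targets")
# train: 60 message-passing edges, 25 targets
# val: 85 message-passing edges, 5 targets
# test: 90 message-passing edges, 10 targets
```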

Graph Classification Splitting

For datasets with multiple graphs (e.g., 1000 molecules), split is simple: randomly assign entire graphs to train/val/test. No information leaks across graphs — they're independent. This is the fully inductive setting. Make sure to split at the graph level, not at the node level, or you'll have partial graphs in test.

For link prediction, why must you NOT include val/test edges in the message passing graph?

Chapter 8: Putting It All Together

Every GNN project is a sequence of design decisions, and the choices interact. Add a virtual node and you might need fewer layers. Use neighbor sampling and you must adjust your batch size accordingly. This chapter traces the complete pipeline and its decision points.

The Full GNN Pipeline

1. Raw Input Graph
Nodes V, edges E, optional raw features Xraw. Decide: transductive or inductive? Node, edge, or graph task?
2. Feature Augmentation
If no features: add constant, structural (degree, clustering coeff), or one-hot IDs. If partial features: impute. If good features: normalize (zero mean, unit variance).
3. Structure Augmentation
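Sparse graph: add a virtual node. Dense graph (high-degree nodes): plan for neighbor sampling. Add self-loops if your layer doesn't already combine a node's own embedding with its neighbors' (GCN relies on them; GraphSAGE handles the node's own embedding separately).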
4. K Rounds of Message Passing
Choose MSG + AGG (GCN, GraphSAGE, or GAT). Choose K layers based on task range: local task K=2, global task K=3-4. Add batch norm and residual connections for K≥3.
5. Prediction Head
Node task: linear $W h_v$. Edge task: concat + linear, or dot product. Graph task: mean/max/sum pool, then linear.
6. Loss + Backprop
Classification: cross-entropy. Regression: MSE. Link prediction: BCE with negative sampling. Optional: + self-supervised loss term.
↻ repeat for each mini-batch
7. Trained GNN
Evaluate on val set. Report final metrics on test set. Never tune hyperparameters on test set.

The Design Space Paper's Main Finding

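You, Ying, and Leskovec (NeurIPS 2020) defined a design space of about 315,000 GNN configurations and explored it by sampling configurations (controlled random search) across many tasks, rather than training every one. The dimensions they varied include aggregation (mean/max/sum), activation (ReLU/PReLU/Swish), batch normalization, dropout, layer connectivity (plain stacking vs. skip-sum or skip-concat connections), the number of pre-processing, message-passing, and post-processing layers, and training settings such as learning rate and optimizer. Their findings:

No universal winner. The optimal design on molecular tasks is completely different from the optimal design on social network tasks. Neither GCN nor GAT is uniformly best, and no single aggregation function wins on every task.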

Simple designs often win. Many fancy GNN variants were beaten by a well-tuned simple GCN with the right normalization and skip connections. "Design space" thinking helps you find the right simple design for your problem.

Practical starting point: Start with GraphSAGE (mean aggregation) + 2 layers + batch norm + residual connections + mean pooling for graph tasks. This is a strong baseline that beats many specialized architectures out of the box. Add complexity only when this baseline is insufficient.
```python
# Complete minimal GNN training loop (graph classification)
# Assumes `train_loader` is a torch_geometric.loader.DataLoader over a graph
# dataset with 9 float node features per node and integer graph labels.
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import SAGEConv, global_mean_pool

class GNNPipeline(nn.Module):
    def __init__(self, in_dim, hidden, out_dim):
        super().__init__()
        # Feature pre-processing MLP (optional but often helpful)
        self.pre = nn.Linear(in_dim, hidden)
        # K=2 GNN layers (SAGEConv = GraphSAGE, mean aggregation by default)
        self.conv1 = SAGEConv(hidden, hidden)
        self.conv2 = SAGEConv(hidden, hidden)
        self.bn1 = nn.BatchNorm1d(hidden)
        self.bn2 = nn.BatchNorm1d(hidden)
        # Prediction head (graph-level: pool then linear)
        self.head = nn.Linear(hidden, out_dim)

    def forward(self, x, edge_index, batch):
        x = F.relu(self.pre(x))                          # [N, hidden]
        x = F.relu(self.bn1(self.conv1(x, edge_index)))  # 1-hop
        x = F.relu(self.bn2(self.conv2(x, edge_index)))  # 2-hop
        x = global_mean_pool(x, batch)                   # [B, hidden]: graph-level
        return self.head(x)                              # [B, out_dim]: logits

# Training (one epoch shown)
model = GNNPipeline(in_dim=9, hidden=256, out_dim=2)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

model.train()
for batch in train_loader:
    optimizer.zero_grad()
    logits = model(batch.x, batch.edge_index, batch.batch)
    loss = F.cross_entropy(logits, batch.y)
    loss.backward()
    optimizer.step()
```
The Design Space paper (You et al., 2020) found that across 315,000 GNN configurations, which architecture was universally best?

Chapter 9: Connections & What's Next

Every concept in this lecture connects outward — to other GNN ideas, to other fields, and to open problems. Follow the thread that interests you most.

Backward: Lectures 3 and 4

Lecture 3 introduced message passing and GCN. Lecture 4 showed that all GNNs are MSG+AGG choices. This lecture showed that before you even choose your MSG+AGG, you must decide what graph to feed it — and that choice (augmentation) matters as much as the architecture.

Forward: Lecture 6 — How Expressive Are GNNs?

This lecture took the practical view: given augmentation and training choices, build a working GNN. Lecture 6 asks the theoretical question: what can and cannot a GNN compute? The answer involves the Weisfeiler-Lehman graph isomorphism test: message-passing GNNs are provably at most as expressive as the 1-WL test. Some structural features, like cycle membership, cannot be captured by standard message passing alone (with featureless input, it cannot tell a node on a triangle from a node on a 6-cycle, since every node has degree 2), which is why feature augmentation (adding cycle lengths explicitly) can help so dramatically.
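To see the cycle limitation concretely, this plain-Python sketch runs 1-WL color refinement (the test that bounds message-passing GNNs) on two disjoint triangles and on one 6-cycle, starting every node from the same color to stand in for featureless input. The graphs and helper names are illustrative. Colors are kept as nested tuples so they can be compared across graphs. 1-WL assigns both graphs the same color histogram, while a single augmented feature, the clustering coefficient, separates them immediately.

```python
from collections import Counter

def wl_colors(adj, rounds=3):
    """1-WL refinement from identical initial colors: each round, a node's new
    color is (its old color, sorted multiset of its neighbors' old colors)."""
    colors = {v: () for v in adj}
    for _ in range(rounds):
        colors = {v: (colors[v], tuple(sorted(colors[u] for u in adj[v]))) for v in adj}
    return Counter(colors.values())

def cycle(nodes):
    """Adjacency lists for a simple cycle through `nodes` in order."""
    return {v: [nodes[i - 1], nodes[(i + 1) % len(nodes)]] for i, v in enumerate(nodes)}

two_triangles = {**cycle([0, 1, 2]), **cycle([3, 4, 5])}
hexagon = cycle([0, 1, 2, 3, 4, 5])

print(wl_colors(two_triangles) == wl_colors(hexagon))  # True: 1-WL can't tell them apart

def clustering(adj, v):
    """Fraction of v's neighbor pairs that are themselves connected."""
    nbrs = adj[v]
    links = sum(1 for i, a in enumerate(nbrs) for b in nbrs[i + 1:] if b in adj[a])
    return links / (len(nbrs) * (len(nbrs) - 1) / 2)

print(clustering(two_triangles, 0), clustering(hexagon, 0))  # 1.0 0.0
```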

Data Augmentation Parallels

Graph augmentation mirrors data augmentation in computer vision in spirit: both reshape the input before the model sees it. The parallel is loose, though. Random crops and flips make images more diverse, while neighbor sampling mainly makes graph processing more scalable. The closer analogy is feature augmentation: adding structural features to nodes is like adding edge detectors as input channels to a CNN. The goal is the same: make the input representation richer and more informative for the task.

Neighbor sampling ≈ stochastic gradient descent for graphs. Full-graph training is like computing the exact gradient on all data before taking a step. Neighbor sampling is like taking a mini-batch step: noisier but much faster. The optimization intuition carries over, with the caveat that sampled multi-layer aggregation adds a small bias on top of the usual mini-batch noise.

PinSage Revisited

PinSage (Lecture 1) now makes sense in this framework: it uses (1) neighbor sampling for scalability, (2) importance-based sampling instead of uniform, (3) the inductive setting (new pins can be embedded on demand), and (4) a max-margin (triplet-style) ranking loss over related and unrelated pin pairs, a form of link prediction. Most of the techniques in this lecture appear in PinSage, which makes it a canonical example of a scalable real-world GNN.

Where the Field Is Going

The current frontier: (1) graph transformers that attend globally without being bottlenecked by graph structure, (2) better negative sampling strategies for link prediction (hard negatives), (3) self-supervised pretraining at scale (pretrain on large graphs, fine-tune with few labels), and (4) equivariant GNNs that respect 3D geometry for molecular design.

| Technique | Solves | Cost | Use when |
|---|---|---|---|
| Constant features | No features | Minimal | Structure-only tasks |
| Structural features | No features, inductive | Preprocessing time | Inductive setting |
| One-hot node ID | No features, transductive | Memory $O(N^2)$ | Small graphs, transductive only |
| Virtual node | Sparse graph, long-range dependencies | +1 node, small | Graph-level tasks |
| Neighbor sampling | Dense/large graph | Gradient noise | Millions of nodes |
| Semi-supervised loss | Few labels | Extra loss term | Label-scarce settings |
| Inductive split | Generalization test | Needs multiple graphs | Production GNN systems |

Related Lessons

CS224W Lec 3: GNN Basics — message passing from scratch
CS224W Lec 4: GNN Design Space — MSG+AGG framework, GCN vs SAGE vs GAT
CS224W Lec 2: Node Embeddings — DeepWalk and Node2Vec as self-supervised GNN precursors

"What I cannot create, I do not understand."
— Richard Feynman. After this lesson: you can create a GNN pipeline. Feature augmentation for missing data. Structure augmentation for graph topology problems. Appropriate prediction heads for node/edge/graph tasks. Correct dataset splits that don't leak information. That's the full engineering picture.