CS224W Lecture 14

Advanced Topics in GNNs

GNNs that scale to billions of nodes, explain their predictions, respect 3D geometry, and adapt to graphs that change over time — the frontier beyond message passing basics.

Prerequisites: GNN message passing + basic graph concepts. That's it.
10 chapters · 5+ simulations · 0 assumed knowledge

Chapter 0: The Scaling Challenge

You've trained a GNN on a citation network with 50,000 papers. It works beautifully. Now your company wants to apply it to its product graph — 500 million items, 10 billion interactions. Your training loop crashes on the first mini-batch. Memory error.

Standard GNN message passing has a brutal scaling problem. To compute the embedding of a single node, you need its neighbors. For a 2-layer GNN, you also need the neighbors of neighbors. On a graph where each node has 50 neighbors on average, a 2-layer computation touches up to 50 + 50² = 2,550 nodes per target node (ignoring overlap). Training a batch of 1,000 nodes then requires roughly 2.5 million nodes to fit in GPU memory. That's fine at research scale. At Facebook scale, it's impossible.
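The arithmetic above fits in a tiny helper. It is a sketch that assumes every node has exactly `avg_degree` neighbors and that no neighborhoods overlap, so it gives an upper bound rather than an exact count; it also counts the target node itself. The name `receptive_field_size` is hypothetical.

```python
def receptive_field_size(avg_degree, num_layers):
    """Upper bound on nodes touched by an L-layer GNN for one target node,
    assuming a regular graph with no overlap between neighborhoods."""
    # hop 0 is the target node itself; hop h contributes avg_degree ** h nodes
    return sum(avg_degree ** hop for hop in range(num_layers + 1))

print(receptive_field_size(50, 2))           # 2551  (1 + 50 + 2,500)
print(1_000 * receptive_field_size(50, 2))   # 2551000 for a batch of 1,000 targets
```

Real graphs have overlapping neighborhoods, so actual counts are lower, but on large sparse graphs the bound is a useful first estimate of memory pressure.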

This chapter frames four scaling problems that the rest of Lecture 14 attacks. You can't solve them without first understanding exactly where the bottleneck lives.

The Four Bottlenecks

Neighborhood explosion. k-hop neighborhoods grow exponentially with k. A 3-layer GNN on a social graph doesn't just touch your friends; it touches friends-of-friends-of-friends. On a scale-free graph, where hubs connect large parts of the network, this can mean 80% of all nodes within 3 hops. In the worst case, every forward pass becomes effectively a full-graph operation.
Memory wall. GPU memory is measured in gigabytes. The world's largest graphs are measured in terabytes. You cannot load the full adjacency matrix. You cannot load all node features. Mini-batching must be redesigned from scratch for graphs.
Heterogeneous schemas. Real graphs mix many node types (users, products, reviews, categories) and edge types (bought, viewed, reviewed). Each type needs different weight matrices. The number of parameters grows with schema complexity, not just graph size.
Retraining cost. A new graph means new structure, new features, possibly new task. In NLP, you fine-tune a pre-trained transformer. In GNNs, you typically retrain from scratch. Pre-training a graph foundation model would eliminate this cost — if it's possible.
A social graph has average degree 100. How many nodes must a 3-layer GNN touch to compute one node's embedding (ignoring overlap)?

Chapter 1: Graph Foundation Models

In NLP, one model — a pre-trained transformer — works for sentiment analysis, translation, question answering, and summarization. You don't train a separate LSTM for each task. You fine-tune. This is the foundation model paradigm: one giant model, trained on massive data, that transfers to many downstream tasks.

Can we do this for graphs? This is the question that graph foundation models try to answer. The dream: train one model on millions of graphs from biology, chemistry, social networks, and relational databases. Then, for a new graph task, fine-tune for a few hundred steps instead of training from scratch.

The challenge is brutal. Text is text — the input token space is shared across all documents. But graphs from different domains have completely different node features, different edge semantics, different structural patterns. A protein interaction graph and a citation network share nothing obvious. There's no natural "vocabulary" to unify them.

Three Strategies for Unification

Researchers have proposed three approaches to make foundation models work across graph domains:

Strategy 1: Text-Attributed Graphs
Every node and edge gets a text description. Use a language model to encode these descriptions into a shared embedding space. The LLM vocabulary is the shared "language" across graphs.
Strategy 2: Structure-Only Pre-training
Ignore node features entirely. Pre-train on structural patterns: triangle counts, degree distributions, motifs. The model learns that "hub node in a star pattern" means something regardless of domain.
Strategy 3: Hybrid Tokenization
Serialize graph structure into a sequence (like code). Feed to a transformer that already understands sequences. No custom GNN needed — the transformer is the foundation model.
The key bottleneck is feature alignment. Strategy 1 requires that every graph has text metadata — not always true. Strategy 2 discards node features that might be essential. Strategy 3 depends on how well transformers generalize to graph-serialized sequences. None is perfect yet. Active research area.
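To make Strategy 2 concrete, here is a minimal sketch of domain-agnostic structural features (degree, triangle count, local clustering coefficient) computed from the adjacency structure alone. The helper name `structural_features` is hypothetical; a structure-only pre-training pipeline might feed statistics like these, instead of domain-specific node features, into the model.

```python
from itertools import combinations

def structural_features(adj):
    """Per-node (degree, triangle count, local clustering coefficient).
    adj: dict node -> set of neighbors (undirected graph, no self-loops)."""
    feats = {}
    for v, nbrs in adj.items():
        deg = len(nbrs)
        # Each edge between two of v's neighbors closes one triangle through v
        tri = sum(1 for a, b in combinations(sorted(nbrs), 2) if b in adj[a])
        clustering = 2 * tri / (deg * (deg - 1)) if deg > 1 else 0.0
        feats[v] = (deg, tri, clustering)
    return feats

# A triangle 0-1-2 with a pendant node 3 attached to node 0
adj = {0: {1, 2, 3}, 1: {0, 2}, 2: {0, 1}, 3: {0}}
print(structural_features(adj))
```

The same three numbers mean the same thing whether node 0 is a protein, a paper, or a user, which is exactly what makes them transferable; the price is that any information carried by the original node features is ignored.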

Why can't you simply train a GNN on many different graphs and call it a "foundation model"?

Chapter 2: Pre-training Strategies

In NLP, masked language modeling is a brilliant pre-training objective: mask a random token, predict it. No labels needed. The signal comes from the data itself. This is called self-supervised learning — the supervision signal is derived from the structure of the input, not from human annotations.

Can we design analogous self-supervised objectives for graphs? The challenge: what is the "masked token" equivalent for a graph? There are several answers, each with tradeoffs.

Self-Supervised Objectives for Graphs

Node attribute masking. Randomly select a fraction of nodes and mask their features (for example, zero them out). Train the GNN to reconstruct the original features. The GNN must propagate information from neighboring nodes to infer what is missing, which makes this a natural structural task. Analogous to BERT's masked token prediction.
Context prediction. Choose a central node. Define its "context" as its k-hop neighborhood subgraph. Train the GNN to predict: does this subgraph context belong to this central node? Negative examples: randomly pair subgraphs with wrong center nodes. Forces the GNN to learn meaningful local structure.
Graph-level contrastive learning. Take a graph. Create two augmented views (drop edges, perturb features, mask nodes). Train the encoder so the two views of the same graph are close in embedding space, while different graphs are far apart. Like SimCLR, but for graphs.
Edge prediction. Remove some edges. Train the GNN to distinguish the removed edges from randomly sampled node pairs that were never connected. This is classic link prediction used as a pre-training signal: the model must learn that connected nodes share properties.

The right pre-training strategy depends on what you want to transfer. Node-level objectives (attribute masking, context prediction) transfer well to node classification tasks. Graph-level objectives (contrastive learning) transfer well to graph classification. There's no one-size-fits-all.

```python
import torch.nn as nn
import torch.nn.functional as F

# Node attribute masking pre-training
class MaskGNN(nn.Module):
    def __init__(self, gnn, feat_dim):
        super().__init__()
        self.gnn = gnn
        # Reconstruction head: map embedding back to feature space
        # (assumes the backbone exposes its output width as gnn.hidden_dim)
        self.decoder = nn.Linear(gnn.hidden_dim, feat_dim)

    def forward(self, x, edge_index, mask):
        # mask: boolean tensor of shape [N], True = masked node
        x_masked = x.clone()
        x_masked[mask] = 0                  # zero out the masked nodes' features
        h = self.gnn(x_masked, edge_index)  # encode: [N, hidden]
        recon = self.decoder(h[mask])       # decode only the masked nodes
        target = x[mask]                    # original (unmasked) features
        return F.mse_loss(recon, target)    # reconstruction loss

# The GNN backbone is pre-trained with no labels.
# After pre-training, discard the decoder and fine-tune the GNN on the downstream task.
```
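The graph-level contrastive objective from the list above can be sketched in plain Python. This is a simplified, one-directional InfoNCE loss; NT-Xent as used in SimCLR additionally treats other embeddings from the same view as negatives and averages over both directions. The encoder that turns each augmented graph into an embedding is assumed and omitted, and `drop_edges`, `cosine`, and `info_nce` are hypothetical helper names.

```python
import math
import random

def drop_edges(edges, p, rng):
    """Augmentation: independently drop each edge with probability p."""
    return [e for e in edges if rng.random() >= p]

def cosine(a, b):
    dot = sum(x * y for x, y in zip(a, b))
    return dot / (math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b)))

def info_nce(z1, z2, tau=0.5):
    """Simplified InfoNCE over a batch of graph embeddings.
    z1[i] and z2[i] embed two augmented views of graph i (the positive pair);
    every z2[j] with j != i serves as a negative for z1[i]."""
    loss = 0.0
    for i, anchor in enumerate(z1):
        logits = [cosine(anchor, other) / tau for other in z2]
        log_denom = math.log(sum(math.exp(s) for s in logits))
        loss += log_denom - logits[i]   # -log softmax probability of the positive
    return loss / len(z1)

rng = random.Random(0)
view = drop_edges([(0, 1), (1, 2), (2, 0)], p=0.2, rng=rng)   # one augmented view
z = [[1.0, 0.0], [0.0, 1.0]]
print(info_nce(z, z), info_nce(z, [[0.0, 1.0], [1.0, 0.0]]))  # matched views score lower
```

The loss is low when each graph's two views are closer to each other than to any other graph's view, which is the property the encoder is trained to produce.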
You pre-train a GNN with graph-level contrastive loss on 10,000 molecular graphs. You then fine-tune for node-level protein function prediction. Why might this pre-training not help much?

Chapter 3: Transfer Learning for Graphs

Pre-training is only half the story. The other half is transfer: does knowledge learned on one graph actually help on a different graph with a different task? In NLP, transfer is nearly always helpful. In graphs, it's more complicated — and understanding why tells you when to expect it to work.

Positive transfer happens when the source and target graphs share structural regularities. Biological networks and social networks both contain hub-and-spoke patterns, motifs, community structure. A GNN pre-trained on biological networks may have learned to detect these patterns — and they're genuinely useful on social data too.

Negative transfer happens when the source and target differ enough that pre-trained representations mislead the model. If pre-training pushes nodes with high degree to have similar embeddings (because degree predicts many biological properties), but your target task requires distinguishing high-degree nodes that serve different semantic roles, the pre-training actively hurts.

The Domain Shift Problem

The key question for transfer: do the useful structural patterns in the source domain align with the useful structural patterns in the target domain? If yes, transfer helps. If no, transfer can hurt. This is not unique to graphs: in vision, how much ImageNet pre-training helps also varies with how closely the target images (for example, medical scans or satellite imagery) resemble the natural photos it was trained on.

Strategies That Actually Work

| Strategy | When to use | Mechanism |
|---|---|---|
| Fine-tune all layers | Large target dataset, similar domain | Update all weights from the pre-trained initialization |
| Freeze GNN, tune head | Small target dataset, similar domain | GNN embeddings are fixed; only the prediction head trains |
| Linear probe | Evaluating transfer quality | Train only a linear layer on top of frozen embeddings; if this works, the representations are good |
| Prompt tuning | Emerging: very small target data | Learn a small "graph prompt" prepended to the input; GNN weights unchanged |
Graph prompt tuning is an emerging frontier. It is analogous to soft prompts in NLP: instead of updating GNN weights, you learn a small set of virtual nodes/edges that are attached to the input graph. The GNN processes the prompted graph, and only the prompt parameters are updated, a tiny fraction of the total parameters. This works well when target data is very scarce.
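The linear-probe row of the table can be sketched as logistic regression on frozen embeddings: the embeddings are treated as fixed inputs and only the head's weights `w` and bias `b` are updated. This plain-Python version with binary labels is a minimal sketch; `linear_probe` and `probe_accuracy` are hypothetical names, and in practice you would use your framework's linear layer and optimizer.

```python
import math

def linear_probe(embeddings, labels, lr=0.5, epochs=200):
    """Full-batch gradient descent on binary cross-entropy.
    The embeddings are frozen: only the head parameters w and b change."""
    dim, n = len(embeddings[0]), len(embeddings)
    w, b = [0.0] * dim, 0.0
    for _ in range(epochs):
        grad_w, grad_b = [0.0] * dim, 0.0
        for z, y in zip(embeddings, labels):
            logit = sum(wi * zi for wi, zi in zip(w, z)) + b
            err = 1.0 / (1.0 + math.exp(-logit)) - y   # d(BCE)/d(logit)
            for k in range(dim):
                grad_w[k] += err * z[k] / n
            grad_b += err / n
        w = [wi - lr * g for wi, g in zip(w, grad_w)]
        b -= lr * grad_b
    return w, b

def probe_accuracy(w, b, embeddings, labels):
    correct = 0
    for z, y in zip(embeddings, labels):
        logit = sum(wi * zi for wi, zi in zip(w, z)) + b
        correct += int((logit > 0) == (y == 1))
    return correct / len(labels)

emb = [[1.0, 2.0], [2.0, 1.0], [-1.0, -2.0], [-2.0, -1.0]]   # stand-in frozen embeddings
labels = [1, 1, 0, 0]
w, b = linear_probe(emb, labels)
print(probe_accuracy(w, b, emb, labels))
```

Because the head is linear, high probe accuracy means the classes are already linearly separable in the pre-trained embedding space, which is why the probe is a cheap test of transfer quality.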
You have a pre-trained GNN and 50 labeled examples in your target domain. Which transfer strategy is most appropriate?

Chapter 4: GNN Scalability

Let's solve the neighborhood explosion problem from Chapter 0. The cause: to train on node v, vanilla message passing needs all k-hop neighbors of v. As k grows, this set grows exponentially. The solution family is graph sampling: instead of using all neighbors, sample a fixed-size subset.

There are three major strategies (two sampling schemes and a layer-dropping regularizer), each with a different tradeoff between accuracy and efficiency. Understanding the tradeoff helps you pick the right one for your use case.

Neighbor Sampling (GraphSAGE)

At each GNN layer, for each node, sample at most k neighbors rather than using all of them. If a node has 200 neighbors, sample 10 at random. The computation tree is now bounded: layer 1 touches at most 10 neighbors, layer 2 touches at most 100 (10×10), not 40,000 (200×200). Memory cost becomes O(k^L) per node — controllable.
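A minimal sketch of layer-wise uniform neighbor sampling, building the computation tree level by level. It is not PyG's implementation; `sample_computation_tree` and the dict-of-lists adjacency format are assumptions for illustration.

```python
import random

def sample_computation_tree(adj, seeds, fanouts, rng):
    """GraphSAGE-style uniform neighbor sampling (illustrative sketch).
    adj: dict node -> list of neighbors; fanouts[h]: max neighbors sampled at hop h+1.
    Returns one list per hop holding the sampled nodes (the tree's levels)."""
    levels = [list(seeds)]
    for fanout in fanouts:
        nxt = []
        for v in levels[-1]:
            nbrs = adj[v]
            # Sample without replacement; low-degree nodes keep all their neighbors
            nxt.extend(rng.sample(nbrs, min(fanout, len(nbrs))))
        levels.append(nxt)
    return levels

n = 201   # complete graph: every node has 200 neighbors
adj = {v: [u for u in range(n) if u != v] for v in range(n)}
levels = sample_computation_tree(adj, [0], [10, 10], random.Random(0))
print([len(level) for level in levels])   # [1, 10, 100], not [1, 200, 40000]
```

The level sizes are bounded by the product of the fanouts regardless of node degree, which is what makes the memory cost controllable.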

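The trick: each sampled aggregation is a stochastic estimator of the true message. If you sample uniformly, the expected sampled mean equals the full-neighborhood mean, so each layer's aggregation is unbiased. Strictly speaking, once several layers and nonlinearities are stacked, the final embedding and its gradient are only approximately unbiased, but in practice training behaves like stochastic gradient descent: noisy, yet close to the true gradient on average, even though each mini-batch only sees a small fraction of the graph.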
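The single-layer unbiasedness claim is easy to check numerically: average many sampled means and compare with the full-neighborhood mean. The scalar features and the helper `sampled_mean` are hypothetical; this illustrates one aggregation step, not a full multi-layer GNN.

```python
import random

def sampled_mean(values, k, rng):
    """Mean of k neighbor values sampled uniformly without replacement."""
    return sum(rng.sample(values, k)) / k

rng = random.Random(0)
neighbor_feats = [rng.random() for _ in range(200)]   # one scalar feature per neighbor
true_mean = sum(neighbor_feats) / len(neighbor_feats)

trials = 20_000
avg_estimate = sum(sampled_mean(neighbor_feats, 10, rng) for _ in range(trials)) / trials
print(true_mean, avg_estimate)   # the two agree closely; the gap shrinks as trials grows
```

Any single 10-neighbor sample can be far off, but the estimates are centered on the true mean: the same noisy-but-centered behavior that SGD relies on.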

Cluster-GCN

Cluster-GCN takes a different approach. Instead of sampling individual neighbors, it partitions the entire graph into dense clusters using METIS or similar algorithms. Each mini-batch is one cluster. Within a cluster, all neighbors are available — no sampling needed. Between clusters, edges are dropped.

The benefit: within-cluster message passing is exact. The cost: between-cluster information is lost, introducing bias. Works well when the graph has natural community structure (social networks, citation graphs). Breaks down on graphs with many long-range dependencies.
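A sketch of how Cluster-GCN forms its batches, assuming a node-to-cluster assignment is already available. Here it is written by hand; in practice it would come from a partitioner such as METIS. `cluster_batches` is a hypothetical helper.

```python
def cluster_batches(edges, part):
    """One batch per cluster, keeping only intra-cluster edges.
    part: dict node -> cluster id. Returns (batches, number of dropped edges)."""
    batches, dropped = {}, 0
    for u, v in edges:
        if part[u] == part[v]:
            batches.setdefault(part[u], []).append((u, v))   # exact message passing
        else:
            dropped += 1   # between-cluster edge: its information is lost
    return batches, dropped

# Two triangles joined by a single bridge edge (2, 3)
edges = [(0, 1), (1, 2), (0, 2), (3, 4), (4, 5), (3, 5), (2, 3)]
part = {0: 0, 1: 0, 2: 0, 3: 1, 4: 1, 5: 1}
print(cluster_batches(edges, part))
```

With strong community structure, only a few edges (here, the one bridge) are dropped; on graphs with many long-range edges, the dropped fraction and the resulting bias grow.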

Stochastic Depth / Layer Dropout

For very deep GNNs, randomly drop entire GNN layers during training (similar to DropPath in vision transformers). Each node sees a different number of message-passing steps in different batches. This regularizes training and reduces memory, but doesn't fundamentally solve the neighborhood explosion — you still need neighbor sampling at each kept layer.
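A minimal sketch of layer dropout over a residual GNN stack. Residual connections are what make skipping a layer well-defined (a skipped layer simply passes its input through). The layers here are hypothetical callables that map the list of node vectors to an update of the same shape; following the usual stochastic-depth convention, inference runs every layer with its update scaled by the survival probability.

```python
import random

def forward_stochastic_depth(h, layers, p_drop, rng, training=True):
    """h: list of node vectors. Training: skip each layer with probability p_drop.
    Inference: run every layer, scaling its update by 1 - p_drop."""
    for layer in layers:
        if training:
            if rng.random() < p_drop:
                continue                  # this batch skips the layer entirely
            scale = 1.0
        else:
            scale = 1.0 - p_drop
        update = layer(h)                 # message passing would happen here
        h = [[a + scale * b for a, b in zip(hv, uv)] for hv, uv in zip(h, update)]
    return h

def add_one(h):
    """Toy stand-in for a GNN layer: the update is 1 in every coordinate."""
    return [[1.0 for _ in hv] for hv in h]

print(forward_stochastic_depth([[0.0, 0.0]], [add_one] * 4, 0.5, random.Random(0)))
```

Different batches see different effective depths, which regularizes training; as the text notes, the layers that do run still need neighbor sampling on large graphs.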

```python
# GraphSAGE-style neighbor sampling with PyG's NeighborLoader.
# Assumes `data` (a PyG Data object with a boolean `train_mask`), `model`,
# `optimizer`, and `criterion` are already defined.
from torch_geometric.loader import NeighborLoader

loader = NeighborLoader(
    data,
    num_neighbors=[10, 10],       # sample up to 10 neighbors per hop, 2-layer GNN
    batch_size=512,               # 512 seed nodes per batch
    input_nodes=data.train_mask,  # seed nodes come from the training set
    shuffle=True,
)

model.train()
for batch in loader:
    optimizer.zero_grad()
    # batch holds the seed nodes plus their sampled 1-hop and 2-hop neighbors.
    # Worst case: 512 + 512×10 + 512×10×10 = 56,832 nodes,
    # not 512 × (entire graph). Tractable.
    out = model(batch.x, batch.edge_index)
    # The first batch.batch_size nodes are the seeds; only they contribute to the loss.
    loss = criterion(out[:batch.batch_size], batch.y[:batch.batch_size])
    loss.backward()
    optimizer.step()
```
GraphSAGE-style uniform neighbor sampling gives an unbiased estimate of each layer's full-neighborhood mean aggregation. What makes it "unbiased"?

Chapter 5: GNN Explainability

A GNN predicts that molecule X is toxic. A doctor wants to know: which atoms caused this prediction? A fraud analyst whose GNN flagged a transaction asks: which edges in the transaction graph triggered the alert? This is GNN explainability — identifying which nodes and edges are most responsible for a specific prediction.

Without explainability, a GNN is a black box. You can't debug it, trust it in high-stakes settings, or use its predictions to generate scientific hypotheses. The field of GNN explainability has produced several methods, each with a different theory of what "explanation" means.

GNNExplainer

GNNExplainer (Ying et al., 2019) asks: what is the minimal subgraph that, when fed to the GNN, produces roughly the same prediction as the full graph? It optimizes a soft mask over edges and node features to find this minimal subgraph. The mask is a differentiable parameter — you backpropagate through the GNN to train it.

The key insight: instead of asking "what does this GNN compute?" (hard — requires opening the black box), ask "what inputs matter for this prediction?" (tractable — keep the GNN frozen, optimize the inputs). This is input attribution, not mechanistic interpretability.

Gradient-Based Attribution

A simpler approach: give every existing edge a continuous weight (initially 1) and compute the gradient of the prediction score with respect to these weights. Edges with large gradient magnitude are "important": to first order, perturbing them changes the prediction most. This is fast to compute (one backward pass) and needs no separate optimization loop, but it is noisier than GNNExplainer.
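A minimal sketch of gradient-based edge attribution on a hypothetical one-layer model whose score for the target node is tanh(a · Σ w_e · x_u) over edges e = (u, target). The derivative with respect to each edge weight is written out by hand with the chain rule; in a real framework, autograd would return the same values from one backward pass. `toy_score` and `edge_gradients` are illustrative names, not a library API.

```python
import math

def toy_score(x, edges, edge_w, target, a=0.8):
    """Hypothetical one-layer GNN score: tanh(a * sum of w_e * x_u over edges (u, target))."""
    s = sum(w * x[u] for (u, v), w in zip(edges, edge_w) if v == target)
    return math.tanh(a * s)

def edge_gradients(x, edges, edge_w, target, a=0.8):
    """d(score)/d(w_e) for every edge, by the chain rule."""
    s = sum(w * x[u] for (u, v), w in zip(edges, edge_w) if v == target)
    dscore_ds = a * (1.0 - math.tanh(a * s) ** 2)   # derivative of tanh(a * s)
    return [dscore_ds * x[u] if v == target else 0.0 for (u, v) in edges]

x = [0.0, 2.0, -0.5, 1.0]                  # one scalar feature per node
edges = [(1, 0), (2, 0), (3, 0), (1, 3)]   # directed (source, target) pairs
grads = edge_gradients(x, edges, [1.0] * len(edges), target=0)
ranking = sorted(range(len(edges)), key=lambda i: -abs(grads[i]))
print([edges[i] for i in ranking])         # [(1, 0), (3, 0), (2, 0), (1, 3)]
```

Edge (1, 3) gets zero attribution because, in a one-layer model, it cannot influence node 0 at all; with deeper models, multi-hop edges receive nonzero gradients too.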

Attention as Explanation

If your GNN uses attention (GAT), the attention weights over edges are a natural "importance score." High attention weight on edge (u, v) means node u relied heavily on node v's features. Simple, free to compute. But attention ≠ importance in general — a node can attend heavily to a neighbor and still not change its output much.

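GNNExplainer objective: maximize $I(Y; G_S)$ subject to $|G_S| \le K$. In words: find a subgraph $G_S$ with at most $K$ edges that maximizes the mutual information between the subgraph and the prediction $Y$. Since $I(Y; G_S) = H(Y) - H(Y \mid G = G_S)$ and $H(Y)$ is fixed for the trained model, this amounts to keeping the few edges that make the model most confident in its original prediction: the "core" of the prediction. In practice, the size constraint is relaxed into a penalty on the total mask weight.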
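A toy sketch of this optimization for a frozen, hypothetical one-layer model with hand-derived gradients. It minimizes −log p(original class | masked graph) plus a size penalty `beta * sum(mask)`. The real GNNExplainer also adds a mask-entropy term and works with any differentiable GNN through autograd; `explain_edges` and its parameters are illustrative assumptions.

```python
import math

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def explain_edges(x, edges, target, a=1.0, beta=0.1, lr=0.5, steps=300):
    """Soft edge mask for a frozen model p(class 1) = sigmoid(a * sum mask_e * x_u)
    over edges e = (u, target). Assumes the original prediction is class 1.
    Loss = -log p(class 1 | masked graph) + beta * sum(mask)."""
    logits = [0.0] * len(edges)             # mask parameters; sigmoid(0) = 0.5
    for _ in range(steps):
        mask = [sigmoid(m) for m in logits]
        s = sum(mk * x[u] for (u, v), mk in zip(edges, mask) if v == target)
        p = sigmoid(a * s)                  # model weight a stays frozen
        for i, (u, v) in enumerate(edges):
            d_mask = beta                   # gradient of the size penalty
            if v == target:
                d_mask -= (1.0 - p) * a * x[u]   # gradient of -log p w.r.t. mask_i
            # Chain rule through mask_i = sigmoid(logit_i)
            logits[i] -= lr * d_mask * mask[i] * (1.0 - mask[i])
    return [sigmoid(m) for m in logits]

x = [0.0, 3.0, -1.0, 0.1]                   # scalar node features
edges = [(1, 0), (2, 0), (3, 0), (1, 3)]    # explain node 0's prediction
mask = explain_edges(x, edges, target=0)
print([round(m, 2) for m in mask])   # (1, 0) kept above 0.5; the others pushed below 0.5
```

Only the mask parameters move; the model is untouched. The edge that drives the prediction survives, while the edge that argues against it, the nearly irrelevant one, and the one that cannot reach node 0 are all pruned by the size penalty.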
GNNExplainer keeps the GNN frozen and optimizes a mask over edges/nodes. Why keep the GNN frozen?

Chapter 6: Equivariant GNNs

You're building a GNN to predict the energy of a molecule from its 3D structure. You represent each atom as a node with a 3D position (x, y, z). If you rotate the molecule 90 degrees, the energy doesn't change — it's the same molecule. But a standard GNN would compute a different embedding after rotation, because (x, y, z) coordinates change. This is wrong.

This is the problem of geometric equivariance. A model f is equivariant to a transformation (like a rotation R) if transforming the input transforms the output in the corresponding way: f(Rx) = R f(x). A model is invariant if the output doesn't change at all under the transformation: f(Rx) = f(x). For molecular energy prediction, we want invariance: rotate the molecule, energy stays the same.

Why Standard GNNs Fail on 3D Graphs

The coordinate problem. If you naively include (x, y, z) as node features, the model is not rotation-invariant. Two identical molecules rotated differently produce different outputs — a physically meaningless distinction. The model wastes capacity learning that rotations don't matter, and still doesn't do it perfectly.

The fix is to use geometric quantities that are invariant or equivariant by construction. Interatomic distances are invariant to rotation and translation (rotating two atoms doesn't change the distance between them). Bond angles are invariant. Torsion angles are invariant. These are the right features for a molecular GNN.
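A quick numerical check of the invariance argument: rotate a toy three-atom configuration by 90 degrees around the z-axis and compare raw coordinates with pairwise distances. The positions and the `rotate_z` helper are illustrative assumptions.

```python
import math

def rotate_z(p, theta):
    """Rotate a 3D point (x, y, z) by angle theta (radians) around the z-axis."""
    x, y, z = p
    c, s = math.cos(theta), math.sin(theta)
    return (c * x - s * y, s * x + c * y, z)

atoms = [(0.0, 0.0, 0.0), (0.96, 0.0, 0.0), (-0.24, 0.93, 0.0)]   # toy positions
rotated = [rotate_z(p, math.pi / 2) for p in atoms]               # rotate by 90 degrees

print(atoms[1], rotated[1])   # the raw coordinates change
for i, j in [(0, 1), (0, 2), (1, 2)]:
    # pairwise distances agree up to floating-point rounding
    print(math.dist(atoms[i], atoms[j]), math.dist(rotated[i], rotated[j]))
```

A model that only ever sees the distances therefore cannot tell the two orientations apart, which is exactly the invariance we want for energy prediction.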

EGNN and SchNet

SchNet uses only interatomic distances as edge features, so it is invariant to rotation by construction. But it discards directional information. Distances are also unchanged under reflection, so it cannot tell a chiral molecule from its mirror image, and when messages are restricted to a cutoff radius, some genuinely different 3D geometries can produce identical local distance inputs.

EGNN (Equivariant Graph Neural Network) is more powerful. It maintains a feature vector and a 3D coordinate for each node. Messages update both the features and the coordinates. The coordinate updates are designed to be equivariant: rotating the input rotates the output coordinates by the same rotation. Because coordinates are carried through the network, directional information is kept in the coordinate channel instead of being collapsed into distances.

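$$m_{ij} = \phi_e\left(h_i^l, h_j^l, \lVert x_i^l - x_j^l \rVert^2\right), \qquad h_i^{l+1} = \phi_h\Big(h_i^l, \sum_{j \in \mathcal{N}(i)} m_{ij}\Big), \qquad x_i^{l+1} = x_i^l + C \sum_{j \in \mathcal{N}(i)} \left(x_i^l - x_j^l\right) \phi_x(m_{ij})$$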

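The coordinate update is a weighted sum of displacement vectors (x_i − x_j). The weights φ_x(m_ij) depend only on invariant node features and interatomic distances, so rotation leaves them unchanged. If you rotate all positions by a rotation matrix R, each displacement vector becomes R(x_i − x_j), so the whole update is rotated by R. (Translating all positions leaves the displacements unchanged, so the updated positions simply translate too.) The model is equivariant: outputs rotate with inputs.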

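```python
from torch_scatter import scatter_mean

# EGNN coordinate update (simplified; edge attributes omitted)
def coord_update(pos, edge_index, h, phi_x):
    # phi_x: learnable module mapping (h_i, h_j, squared distance) -> [E, 1] weight
    src, dst = edge_index               # messages flow src (j) -> dst (i)
    diff = pos[dst] - pos[src]          # [E, 3]: displacement vectors x_i - x_j
    dist2 = (diff ** 2).sum(dim=-1, keepdim=True)  # [E, 1]: squared distances (invariant)

    # Scalar weight per edge, computed from invariant inputs only
    w_ij = phi_x(h[dst], h[src], dist2)  # [E, 1]

    # Mean over incoming edges (i.e. C = 1/|N(i)|) of weighted displacements
    agg = scatter_mean(diff * w_ij, dst, dim=0, dim_size=pos.size(0))  # [N, 3]
    return pos + agg   # rotate input → output rotates by the same rotation
```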
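The equivariance argument can be checked numerically without any deep learning framework: apply a plain-Python version of the coordinate update, then rotate, and compare with rotating first and then updating. The weight function `phi` sees only the squared distance (a hypothetical toy stand-in for a learned φ_x, with node features omitted), and `egnn_coord_update` is an illustrative name.

```python
import math

def rotate_z(p, theta):
    """Rotate a 3D point around the z-axis by theta radians."""
    x, y, z = p
    c, s = math.cos(theta), math.sin(theta)
    return (c * x - s * y, s * x + c * y, z)

def egnn_coord_update(pos, edges, phi_x):
    """x_i <- x_i + mean over edges (j, i) of (x_i - x_j) * phi_x(|x_i - x_j|^2)."""
    sums = [[0.0, 0.0, 0.0] for _ in pos]
    counts = [0] * len(pos)
    for j, i in edges:
        diff = [pos[i][k] - pos[j][k] for k in range(3)]
        w = phi_x(sum(d * d for d in diff))      # invariant scalar weight
        for k in range(3):
            sums[i][k] += w * diff[k]
        counts[i] += 1
    return [tuple(pos[i][k] + (sums[i][k] / counts[i] if counts[i] else 0.0)
                  for k in range(3)) for i in range(len(pos))]

def phi(d2):
    return 1.0 / (1.0 + d2)

pos = [(0.0, 0.0, 0.0), (1.0, 0.5, 0.2), (-0.3, 1.2, 0.7)]
edges = [(0, 1), (1, 0), (1, 2), (2, 1), (0, 2), (2, 0)]
theta = 0.7

a = [rotate_z(p, theta) for p in egnn_coord_update(pos, edges, phi)]   # update, then rotate
b = egnn_coord_update([rotate_z(p, theta) for p in pos], edges, phi)   # rotate, then update
print(max(abs(p[k] - q[k]) for p, q in zip(a, b) for k in range(3)))  # float rounding only
```

Replacing the displacement vectors with raw coordinates in the update would break this check, which is a quick way to see why EGNN builds its update from differences.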
Why are interatomic distances "invariant" to rotation, but (x, y, z) coordinates are not?

Chapter 7: Dynamic Graphs

A friendship graph from 2010 is not the same as the friendship graph from 2024. Edges appear and disappear. Node properties evolve. New nodes join. Standard GNNs are designed for static snapshots — they have no notion of time. For many real applications, this is the critical missing piece.

Dynamic graphs (also called temporal graphs) are graphs where edges carry timestamps. The edge (u, v, t) means u and v interacted at time t. This richer representation enables new tasks: predict whether u and v will interact in the next hour. Detect when u's behavior pattern changed. Rank which past interactions most influenced u's current state.

Two Representations

There are two ways to represent a dynamic graph, leading to two families of models:

Discrete snapshots. Divide time into bins (daily, hourly). Each bin is a static graph snapshot. Apply a GNN (typically with shared weights) to each snapshot, then run a recurrent network (LSTM or GRU) across the snapshot embeddings to model evolution. The GNN handles structure within each time step; the RNN handles evolution across time steps.
Continuous-time events. Treat the graph as a stream of timestamped events. When edge (u, v, t) arrives, update the representations of u and v immediately. No discretization needed. TGN (Temporal Graph Network) is a representative architecture here.
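The discrete-snapshot view can be sketched by binning an event stream into static edge lists. The helper `to_snapshots` and the example events are hypothetical.

```python
from collections import defaultdict

def to_snapshots(events, bin_size):
    """Discretize (u, v, t) events into static edge lists, one per time bin.
    Bin b covers times in [b * bin_size, (b + 1) * bin_size)."""
    snapshots = defaultdict(list)
    for u, v, t in events:
        snapshots[int(t // bin_size)].append((u, v))
    return dict(snapshots)

events = [(0, 1, 0.5), (1, 2, 3.0), (0, 2, 25.0), (2, 3, 26.5)]   # t in hours
print(to_snapshots(events, bin_size=24))   # {0: [(0, 1), (1, 2)], 1: [(0, 2), (2, 3)]}
```

Binning is the tradeoff between the two families: within a bin, the exact timestamps and the ordering of events are thrown away, which is precisely the information continuous-time models keep.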

Temporal Graph Networks (TGN)

TGN (Rossi et al., 2020) maintains a memory vector for each node — a running summary of everything that has happened to that node so far. When event (u, v, t) arrives: (1) retrieve memories of u and v, (2) compute messages from the event, (3) update memories, (4) compute temporal embeddings using a GNN on the temporal neighborhood.

Event arrives (u, v, t): user u interacted with item v at time t.
Message computation: m_u = msg(s_u, s_v, Δt, e_uv), combining the memories and the time gap.
Memory update: s_u ← GRU(s_u, m_u), folding the new event into u's memory.
Temporal embedding: z_u = GNN(s_u, temporal neighbors of u), embedding u together with its context.
The time encoding trick: TGN encodes the time gap Δt = t − t_last as a vector of features cos(w_i · Δt + b_i), i = 1..d, where w_i and b_i are learnable. This is analogous to positional encodings in transformers, and it lets the model learn that "interacted 5 minutes ago" and "interacted 5 years ago" have different implications, without hardcoding the timescale.
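A minimal plain-Python sketch of the memory loop and time encoding described above. Only the cos(w_i · Δt + b_i) encoding follows the text; everything else is a labeled simplification. The message is an elementwise mean of (own memory, partner memory, time encoding) instead of a learned function, and the update `s ← (1 − alpha) · s + alpha · m` stands in for TGN's GRU. `ToyTemporalMemory` and its parameters are hypothetical.

```python
import math

def time_encoding(dt, w, b):
    """Time-gap features [cos(w_i * dt + b_i) for i = 1..d]; TGN learns w and b."""
    return [math.cos(wi * dt + bi) for wi, bi in zip(w, b)]

class ToyTemporalMemory:
    """Per-node memory plus last-update time, updated event by event."""

    def __init__(self, w, b, alpha=0.5):
        self.w, self.b, self.alpha = w, b, alpha
        self.dim = len(w)
        self.mem = {}    # node -> memory vector s
        self.last = {}   # node -> time of the node's last event

    def memory(self, node):
        return self.mem.get(node, [0.0] * self.dim)

    def process(self, u, v, t):
        """Handle event (u, v, t): both endpoints read old memories, then both update."""
        updates = {}
        for node, partner in ((u, v), (v, u)):
            s_self, s_other = self.memory(node), self.memory(partner)
            dt = t - self.last.get(node, t)        # 0.0 for a node's first event
            phi = time_encoding(dt, self.w, self.b)
            msg = [(a + c + d) / 3.0 for a, c, d in zip(s_self, s_other, phi)]
            updates[node] = [(1 - self.alpha) * s + self.alpha * m
                             for s, m in zip(s_self, msg)]
        for node, s in updates.items():            # write only after both reads
            self.mem[node] = s
            self.last[node] = t

memory = ToyTemporalMemory(w=[1.0, 0.01], b=[0.0, 0.0])
memory.process(0, 1, t=10.0)   # first event for both nodes: dt = 0
memory.process(0, 2, t=15.0)   # node 0 now sees dt = 5.0 since its last event
print(memory.memory(0), memory.last)
```

The two frequencies in `w` illustrate the design choice: a fast component reacts to short gaps while a slow one distinguishes long gaps, and in TGN these frequencies are learned rather than fixed.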
TGN's "memory" vector for each node serves what purpose?

Chapter 8: Practical GNN Tips

GNNs fail in practice for reasons that have nothing to do with the theory. This chapter is the hard-won knowledge that doesn't appear in papers — the debugging patterns and design choices that separate working implementations from broken ones.

Over-Smoothing: The Depth Problem

Add more GNN layers → nodes see larger neighborhoods → more context → better predictions. Right? Wrong. After too many layers, all node embeddings converge to the same vector. This is over-smoothing: repeated averaging makes every node look like the graph average. In practice, 2 to 4 layers is usually best for standard message-passing GNNs, and adding a 5th layer often hurts.

Diagnosis test: compute the mean cosine similarity between all pairs of node embeddings in your last GNN layer. As a rough rule of thumb, values above about 0.95 indicate over-smoothing. The fix: reduce depth, add residual connections (skip connections that add each layer's input back to its output, so part of the signal bypasses the aggregation), or use jumping knowledge (concatenate embeddings from all layers, not just the last).
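The diagnostic, together with a tiny demonstration of over-smoothing, fits in plain Python. The propagation step here is parameter-free mean aggregation over each node and its neighbors (no weights or nonlinearities), a deliberate simplification that isolates the averaging effect; the helper names are hypothetical.

```python
import math
from itertools import combinations

def cosine(a, b):
    dot = sum(x * y for x, y in zip(a, b))
    return dot / (math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b)))

def mean_pairwise_cosine(h):
    """Over-smoothing diagnostic: mean cosine similarity over all node pairs."""
    pairs = list(combinations(range(len(h)), 2))
    return sum(cosine(h[i], h[j]) for i, j in pairs) / len(pairs)

def mean_aggregate(h, adj):
    """One propagation step: each node averages itself and its neighbors."""
    out = []
    for v, nbrs in enumerate(adj):
        group = [v] + list(nbrs)
        out.append([sum(h[u][k] for u in group) / len(group) for k in range(len(h[v]))])
    return out

adj = [[1], [0, 2], [1, 3], [2]]                      # path graph 0-1-2-3
x = [[1.0, 0.0], [0.0, 1.0], [1.0, 0.0], [0.0, 1.0]]  # alternating features
h = x
for _ in range(20):
    h = mean_aggregate(h, adj)
print(mean_pairwise_cosine(x))   # 0.333...: nodes still distinguishable
print(mean_pairwise_cosine(h))   # close to 1.0: over-smoothed
```

Twenty rounds of averaging on a connected graph drive every node toward the same vector, which is the failure mode the 0.95 heuristic is designed to catch.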

Over-Squashing: The Bottleneck Problem

Over-squashing is subtler. To transmit information from a node 4 hops away, that information must be compressed into a fixed-size vector at each intermediate node, along with everything else arriving there. Because the receptive field grows exponentially with distance, long-range signals get diluted, especially when many paths must pass through a few bottleneck nodes or edges (for example, a bridge between two dense communities). Long-range dependencies are poorly captured. Adding more layers doesn't fix this: the bottleneck lies in the graph topology, not the depth.

Fix for over-squashing: graph rewiring, i.e. adding "shortcut" edges between nodes that need to communicate but are far apart. Algorithms like SDRF (Stochastic Discrete Ricci Flow) use a discrete curvature measure to detect bottleneck edges and add supporting edges automatically. Alternatively, use graph transformers with full attention, paying a quadratic cost in the number of nodes but giving every pair of nodes a direct path, which sidesteps the bottleneck.

Practical Checklist

| Problem | Symptom | Fix |
|---|---|---|
| Over-smoothing | Deep GNN worse than shallow | Fewer layers, residual connections, jumping knowledge |
| Over-squashing | Poor long-range task performance | Graph rewiring, graph transformers |
| Feature scale | Training diverges or is very slow | Normalize node features; BatchNorm after each GNN layer |
| Data leakage | Test accuracy >> val accuracy | Inductive split: test nodes not seen during training |
| Message collision | Different neighborhoods aggregate to the same message (e.g., mean or max cannot count neighbors) | Prefer sum aggregation (GIN-style), or combine several aggregators with principal neighborhood aggregation (PNA) |
The most important debugging step: before training any GNN, verify that a 2-layer MLP on node features alone gives a reasonable baseline. If the MLP already achieves 90% accuracy, the graph structure may add little, and you may not need a GNN. If the MLP achieves 60% and the GNN achieves 85%, the graph structure is genuinely useful.
Your GNN's node embeddings after layer 6 all have cosine similarity > 0.97 with each other. What has happened and what should you try?

Chapter 9: Connections

Advanced GNNs don't exist in isolation — they connect to the broader landscape of machine learning. Knowing these connections lets you borrow solutions from adjacent fields instead of reinventing them.

Where These Ideas Come From

| GNN concept | Analogous ML concept | Key difference |
|---|---|---|
| Graph pre-training | BERT / GPT pre-training | No natural token vocabulary; structure varies across graphs |
| Neighbor sampling | SGD mini-batching | Dependencies between samples (graph structure) complicate the i.i.d. assumption |
| Equivariant GNNs | CNNs (translation equivariance) | Rotation group vs. translation group; 3D vs. 2D |
| Dynamic graphs (TGN) | Sequence models (LSTM, Transformer) | Events occur on a graph, not a linear sequence |
| GNNExplainer | LIME / SHAP for tabular data | Explanation is a subgraph, not a feature importance vector |
| Graph prompt tuning | Soft prompts (prefix tuning) | Prompt is graph-structured, not a sequence of vectors |

What Comes Next in CS224W

Lecture 15 goes deeper on graph foundation models specifically for knowledge graphs — graphs where nodes are entities and edges are named relations. The question becomes: can one model reason over any KG schema without retraining? → KG Foundation Models
Lecture 16 combines GNNs with large language models. LLMs can provide rich text-based node features; GNNs can provide structural context that LLMs lack. The two architectures are complementary, and combining them achieves things neither can do alone. → LLM + GNN

Related Lessons

→ GNN Theory — The Weisfeiler-Lehman test and why some graphs are indistinguishable to standard GNNs.
→ Heterogeneous GNNs — Graphs with multiple node and edge types, which underlie most foundation model challenges.
→ Advanced RDL — Universal encoders for relational databases — the schema-transfer problem in a structured data context.
"The goal of science is not to open a door to infinite wisdom, but to set a limit to infinite error."
— Bertolt Brecht (apt for GNN explainability)