GNNs that scale to billions of nodes, explain their predictions, respect 3D geometry, and adapt to graphs that change over time — the frontier beyond message passing basics.
You've trained a GNN on a citation network with 50,000 papers. It works beautifully. Now your company wants to apply it to its product graph — 500 million items, 10 billion interactions. Your training loop crashes on the first mini-batch. Memory error.
Standard GNN message passing has a brutal scaling problem. To compute the embedding of a single node, you need its neighbors. For a 2-layer GNN, you also need the neighbors of those neighbors. On a graph where each node has 50 neighbors on average, a 2-layer computation can touch up to 50² = 2,500 two-hop nodes per target node (fewer when neighborhoods overlap). Training a batch of 1,000 nodes can then require up to 2.5 million nodes to fit in GPU memory. That's fine at research scale. At Facebook scale, it's impossible.
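The arithmetic above generalizes to any depth. Below is a minimal sketch that assumes every node has exactly the average degree and that no two paths ever reach the same node. That makes it an upper bound: real graphs have overlapping neighborhoods and touch fewer nodes. The function names are illustrative, not from any library.

```python
def worst_case_receptive_field(avg_degree: int, num_layers: int) -> int:
    """Upper bound on the nodes an L-layer GNN touches to embed one target node.

    Assumes every node has exactly `avg_degree` neighbors and no overlap,
    so hop l contributes avg_degree**l new nodes (hop 0 is the target itself).
    """
    return sum(avg_degree ** hop for hop in range(num_layers + 1))


def batch_footprint(batch_size: int, avg_degree: int, num_layers: int) -> int:
    """Worst case for a mini-batch: the seeds' receptive fields do not overlap."""
    return batch_size * worst_case_receptive_field(avg_degree, num_layers)


for layers in range(1, 5):
    print(f"{layers} layers: {batch_footprint(1_000, 50, layers):,} nodes")
```

This count also includes the seeds and their direct neighbors. With that, two layers give 2,551,000 nodes for a batch of 1,000, slightly above the 2.5 million figure. Four layers give more than 6 billion.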
This chapter frames four scaling problems that the rest of Lecture 14 attacks. You can't solve them without first understanding exactly where the bottleneck lives.
In NLP, one model — a pre-trained transformer — works for sentiment analysis, translation, question answering, and summarization. You don't train a separate LSTM for each task. You fine-tune. This is the foundation model paradigm: one giant model, trained on massive data, that transfers to many downstream tasks.
Can we do this for graphs? This is the question that graph foundation models try to answer. The dream: train one model on millions of graphs from biology, chemistry, social networks, and relational databases. Then, for a new graph task, fine-tune for a few hundred steps instead of training from scratch.
The challenge is brutal. Text is text — the input token space is shared across all documents. But graphs from different domains have completely different node features, different edge semantics, different structural patterns. A protein interaction graph and a citation network share nothing obvious. There's no natural "vocabulary" to unify them.
Researchers have proposed three approaches to make foundation models work across graph domains:
In NLP, masked language modeling is a brilliant pre-training objective: mask a random token, predict it. No labels needed. The signal comes from the data itself. This is called self-supervised learning — the supervision signal is derived from the structure of the input, not from human annotations.
Can we design analogous self-supervised objectives for graphs? The challenge: what is the "masked token" equivalent for a graph? There are several answers, each with tradeoffs.
The right pre-training strategy depends on what you want to transfer. Node-level objectives (attribute masking, context prediction) transfer well to node classification tasks. Graph-level objectives (contrastive learning) transfer well to graph classification. There's no one-size-fits-all.
```python
import torch.nn as nn
import torch.nn.functional as F

# Node attribute masking pre-training
class MaskGNN(nn.Module):
    def __init__(self, gnn, feat_dim):
        super().__init__()
        self.gnn = gnn  # backbone; assumed to expose a .hidden_dim attribute
        # Reconstruction head: map embedding back to feature space
        self.decoder = nn.Linear(gnn.hidden_dim, feat_dim)

    def forward(self, x, edge_index, mask):
        # mask: boolean tensor, True = masked node
        x_masked = x.clone()
        x_masked[mask] = 0                      # zero out masked features
        h = self.gnn(x_masked, edge_index)      # encode [N, hidden]
        recon = self.decoder(h[mask])           # decode only masked nodes
        target = x[mask]                        # original features
        return F.mse_loss(recon, target)        # reconstruct them

# The GNN backbone is pre-trained with no labels.
# After pre-training, discard the decoder and fine-tune the GNN on the downstream task.
```
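The masking example covers a node-level objective. For the graph-level contrastive objective, one common choice is a SimCLR-style NT-Xent loss. You embed two random augmentations of each graph, pull each pair together, and push it away from every other graph in the batch. This is a minimal NumPy sketch of the loss alone. It assumes the graph encoder and the augmentations (edge dropping, feature masking, and so on) exist elsewhere. `nt_xent` and the temperature of 0.5 are illustrative choices.

```python
import numpy as np

def nt_xent(z1: np.ndarray, z2: np.ndarray, temperature: float = 0.5) -> float:
    """SimCLR-style NT-Xent loss for two augmented views of the same B graphs.

    z1[i] and z2[i] embed two augmentations of graph i (a positive pair);
    the other 2B - 2 embeddings in the batch act as negatives.
    """
    b = z1.shape[0]
    z = np.concatenate([z1, z2], axis=0)                   # [2B, d]
    z = z / np.linalg.norm(z, axis=1, keepdims=True)       # unit vectors: dot = cosine
    sim = z @ z.T / temperature                            # [2B, 2B] scaled similarities
    np.fill_diagonal(sim, -np.inf)                         # an embedding is not its own negative
    partner = np.concatenate([np.arange(b, 2 * b), np.arange(b)])  # row i's positive
    row_max = sim.max(axis=1, keepdims=True)               # for a stable log-softmax
    log_norm = np.log(np.exp(sim - row_max).sum(axis=1, keepdims=True)) + row_max
    log_prob = sim - log_norm
    return float(-log_prob[np.arange(2 * b), partner].mean())

rng = np.random.default_rng(0)
z1 = rng.normal(size=(8, 16))
aligned = nt_xent(z1, z1 + 0.01 * rng.normal(size=(8, 16)))   # the two views agree
unrelated = nt_xent(z1, rng.normal(size=(8, 16)))              # the two views are unrelated
print(aligned, unrelated)
```

When the two views of each graph agree, the loss is much lower than when they are unrelated. Minimizing it therefore pushes the encoder to map augmentations of the same graph to nearby embeddings.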
Pre-training is only half the story. The other half is transfer: does knowledge learned on one graph actually help on a different graph with a different task? In NLP, transfer is nearly always helpful. In graphs, it's more complicated — and understanding why tells you when to expect it to work.
Positive transfer happens when the source and target graphs share structural regularities. Biological networks and social networks both contain hub-and-spoke patterns, motifs, community structure. A GNN pre-trained on biological networks may have learned to detect these patterns — and they're genuinely useful on social data too.
Negative transfer happens when the source and target differ enough that pre-trained representations mislead the model. If pre-training pushes nodes with high degree to have similar embeddings (because degree predicts many biological properties), but your target task requires distinguishing high-degree nodes that serve different semantic roles, the pre-training actively hurts.
| Strategy | When to use | Mechanism |
|---|---|---|
| Fine-tune all layers | Large target dataset, similar domain | Update all weights from pre-trained initialization |
| Freeze GNN, tune head | Small target dataset, similar domain | GNN embeddings are fixed; only the prediction head trains |
| Linear probe | Evaluate transfer quality | Train only a linear layer on top of frozen embeddings — if this works, the representations are good |
| Prompt tuning | Emerging: very small target data | Learn a small "graph prompt" attached to the input graph or its features; GNN weights unchanged |
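The linear-probe row can be made concrete. This is a minimal sketch that assumes a pre-trained GNN has already produced node embeddings; here they are synthetic. Only a softmax-regression head is trained, using plain gradient descent, and the embeddings are never updated. High probe accuracy suggests the frozen representations already separate the classes. `linear_probe` is a hypothetical helper, not a library function.

```python
import numpy as np

def linear_probe(emb_train, y_train, emb_test, y_test, num_classes, lr=0.1, epochs=200):
    """Fit a softmax-regression head on frozen embeddings and return test accuracy."""
    n, d = emb_train.shape
    W, b = np.zeros((d, num_classes)), np.zeros(num_classes)
    onehot = np.eye(num_classes)[y_train]
    for _ in range(epochs):
        logits = emb_train @ W + b
        logits -= logits.max(axis=1, keepdims=True)        # numerical stability
        probs = np.exp(logits)
        probs /= probs.sum(axis=1, keepdims=True)
        grad = (probs - onehot) / n                        # d(mean cross-entropy)/d(logits)
        W -= lr * emb_train.T @ grad                       # only the head is updated;
        b -= lr * grad.sum(axis=0)                         # the embeddings stay frozen
    pred = (emb_test @ W + b).argmax(axis=1)
    return float((pred == y_test).mean())

# Synthetic stand-in for frozen GNN embeddings: two classes separated along dim 0.
rng = np.random.default_rng(0)

def fake_embeddings(n):
    y = rng.integers(0, 2, size=n)
    emb = rng.normal(size=(n, 8))
    emb[:, 0] += np.where(y == 1, 3.0, -3.0)
    return emb, y

emb_tr, y_tr = fake_embeddings(200)
emb_te, y_te = fake_embeddings(100)
print(linear_probe(emb_tr, y_tr, emb_te, y_te, num_classes=2))
```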
Let's solve the neighborhood explosion problem from Chapter 0. The cause: to train on node v, an L-layer GNN with vanilla message passing needs all L-hop neighbors of v. As L grows, this set grows exponentially. The solution family is graph sampling: instead of using all neighbors, sample a fixed-size subset.
There are three major sampling strategies, each with a different tradeoff between accuracy and efficiency. Understanding the tradeoff helps you pick the right one for your use case.
At each GNN layer, for each node, sample at most k neighbors rather than using all of them. If a node has 200 neighbors, sample 10 at random. The computation tree is now bounded: layer 1 touches at most 10 neighbors, layer 2 touches at most 100 (10×10), not 40,000 (200×200). Memory cost becomes O(k^L) per node — controllable.
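The per-layer fanout idea fits in a few lines. This minimal pure-Python sketch assumes the graph is stored as an adjacency dict; `sample_computation_tree` is an illustrative name. It does not deduplicate nodes, so its edge counts match the O(k^L) bound exactly. Library samplers typically merge repeated nodes into one subgraph.

```python
import random

def sample_computation_tree(adj, seeds, fanouts, rng=None):
    """GraphSAGE-style neighbor sampling.

    adj: dict mapping node -> list of neighbors.
    fanouts: max neighbors sampled per node at each hop, e.g. [10, 10].
    Returns one list of sampled (dst, src) edges per hop.
    """
    rng = rng or random.Random(0)
    frontier = list(seeds)
    hops = []
    for k in fanouts:
        edges, next_frontier = [], []
        for dst in frontier:
            nbrs = adj.get(dst, [])
            # keep every neighbor if there are at most k, else k without replacement
            picked = nbrs if len(nbrs) <= k else rng.sample(nbrs, k)
            for src in picked:
                edges.append((dst, src))
                next_frontier.append(src)
        hops.append(edges)
        frontier = next_frontier
    return hops

# Every node has 200 neighbors, as in the example above.
adj = {i: [(i + j) % 1000 for j in range(1, 201)] for i in range(1000)}
hops = sample_computation_tree(adj, seeds=[0], fanouts=[10, 10])
print([len(h) for h in hops])
```

One seed yields 10 sampled edges at the first hop and 100 at the second, instead of 200 and 40,000.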
Cluster-GCN takes a different approach. Instead of sampling individual neighbors, it partitions the entire graph into dense clusters using METIS or similar algorithms. Each mini-batch is one cluster. Within a cluster, all neighbors are available — no sampling needed. Between clusters, edges are dropped.
The benefit: within-cluster message passing is exact. The cost: between-cluster information is lost, introducing bias. Works well when the graph has natural community structure (social networks, citation graphs). Breaks down on graphs with many long-range dependencies.
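The batching logic is simple once a partition exists. This minimal sketch assumes the cluster assignment is already given; in practice it would come from METIS or a similar partitioner, which is not reimplemented here. `cluster_batches` is a hypothetical helper. It keeps intra-cluster edges per batch and counts the inter-cluster edges that no batch ever sees, which is the source of the bias described above.

```python
from collections import defaultdict

def cluster_batches(edges, cluster_of):
    """Split a graph into per-cluster mini-batches, Cluster-GCN style.

    edges: list of (u, v) pairs.
    cluster_of: dict node -> cluster id (assumed precomputed by a partitioner).
    Returns (batches, dropped): cluster id -> intra-cluster edges, and the
    number of inter-cluster edges that no batch contains.
    """
    batches = defaultdict(list)
    dropped = 0
    for u, v in edges:
        if cluster_of[u] == cluster_of[v]:
            batches[cluster_of[u]].append((u, v))   # exact message passing inside the cluster
        else:
            dropped += 1                            # between-cluster edge: information lost
    return dict(batches), dropped

# Two triangles joined by a single bridge edge (2, 3).
edges = [(0, 1), (1, 2), (0, 2), (3, 4), (4, 5), (3, 5), (2, 3)]
cluster_of = {0: 0, 1: 0, 2: 0, 3: 1, 4: 1, 5: 1}
batches, dropped = cluster_batches(edges, cluster_of)
print({c: len(e) for c, e in batches.items()}, dropped)
```

The original Cluster-GCN paper also groups several randomly chosen clusters into one batch. This restores the edges among those clusters and reduces the bias.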
For very deep GNNs, randomly drop entire GNN layers during training (similar to DropPath in vision transformers). A dropped layer is bypassed through its residual connection, so the trick assumes residual layers, and each node sees a different number of message-passing steps in different batches. This regularizes training and reduces memory, but it doesn't fundamentally solve the neighborhood explosion: you still need neighbor sampling at each kept layer.
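Layer dropping can be sketched as a residual forward pass that skips each layer with some probability. This minimal NumPy sketch makes three assumptions: a residual GCN-style layer `h + relu(A h W)` with hypothetical weights; one drop decision per forward pass, so every node in the batch skips the same layers; and the stochastic-depth convention of scaling each layer's update by its survival probability at inference.

```python
import numpy as np

def forward_with_layer_drop(h, adj_norm, weights, drop_prob, rng, training=True):
    """Residual GCN-style forward pass that randomly skips whole layers.

    h: [N, d] node features, adj_norm: [N, N] normalized adjacency,
    weights: one [d, d] matrix per layer (hypothetical parameters).
    """
    for W in weights:
        if training and rng.random() < drop_prob:
            continue                                   # skipped layer: residual path is identity
        update = np.maximum(adj_norm @ h @ W, 0.0)     # one message-passing step + ReLU
        if not training:
            update = update * (1.0 - drop_prob)        # stochastic-depth rescaling at inference
        h = h + update
    return h

rng = np.random.default_rng(0)
h0 = rng.normal(size=(4, 3))
adj = np.full((4, 4), 0.25)                            # toy fully connected, row-normalized
Ws = [rng.normal(scale=0.3, size=(3, 3)) for _ in range(6)]
print(forward_with_layer_drop(h0, adj, Ws, drop_prob=0.5, rng=rng).shape)
```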
```python
# GraphSAGE-style neighbor sampling with PyG NeighborLoader
from torch_geometric.loader import NeighborLoader

# Assumes `data` (a PyG Data object with a train_mask), `model`,
# `criterion`, and `optimizer` are already defined.
loader = NeighborLoader(
    data,
    num_neighbors=[10, 10],       # sample 10 neighbors per layer, 2-layer GNN
    batch_size=512,               # 512 seed nodes per batch
    input_nodes=data.train_mask,  # seed nodes come from the training set
)

for batch in loader:
    # batch.num_nodes: seed nodes + their sampled 1-hop + sampled 2-hop
    # Worst case: 512 + 512×10 + 512×10×10 = 56,832 nodes
    # Not 512 × (entire graph). Tractable.
    optimizer.zero_grad()
    out = model(batch.x, batch.edge_index)
    # Seed nodes come first in each batch; only they contribute to the loss
    loss = criterion(out[:batch.batch_size], batch.y[:batch.batch_size])
    loss.backward()
    optimizer.step()
```
A GNN predicts that molecule X is toxic. A doctor wants to know: which atoms caused this prediction? A fraud analyst whose GNN flagged a transaction asks: which edges in the transaction graph triggered the alert? This is GNN explainability — identifying which nodes and edges are most responsible for a specific prediction.
Without explainability, a GNN is a black box. You can't debug it, trust it in high-stakes settings, or use its predictions to generate scientific hypotheses. The field of GNN explainability has produced several methods, each with a different theory of what "explanation" means.
GNNExplainer (Ying et al., 2019) asks: what is the minimal subgraph that, when fed to the GNN, produces roughly the same prediction as the full graph? It optimizes a soft mask over edges and node features to find this minimal subgraph. The mask is a differentiable parameter — you backpropagate through the GNN to train it.
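The mask optimization can be shown on a model small enough to differentiate by hand. This minimal NumPy sketch assumes a hypothetical one-layer linear GNN in which the target's logit is `bias + sum_j mask_j * contrib[j]`. The loss keeps the original prediction while penalizing mask size and mask entropy, in the spirit of GNNExplainer's regularizers. The real method backpropagates through the actual GNN with autodiff, and the coefficients here are illustrative.

```python
import numpy as np

def explain_edges(contrib, bias=0.0, size_coef=0.1, ent_coef=0.05, lr=1.0, steps=300):
    """Toy GNNExplainer-style soft edge mask for one target node.

    Hypothetical model: logit = bias + sum_j mask_j * contrib[j].
    Loss = -log sigmoid(logit)           keep the original (positive) prediction
           + size_coef * sum(mask)       prefer small explanations
           + ent_coef * sum(H(mask))     prefer near-binary masks
    """
    contrib = np.asarray(contrib, dtype=float)
    theta = np.zeros_like(contrib)                      # mask = sigmoid(theta), starts at 0.5
    for _ in range(steps):
        mask = 1.0 / (1.0 + np.exp(-theta))
        p = 1.0 / (1.0 + np.exp(-(bias + mask @ contrib)))
        # dLoss/dmask; for m = sigmoid(theta), dH/dm = log((1 - m) / m) = -theta
        grad_mask = -(1.0 - p) * contrib + size_coef - ent_coef * theta
        theta -= lr * grad_mask * mask * (1.0 - mask)   # chain rule through the sigmoid
    return 1.0 / (1.0 + np.exp(-theta))

mask = explain_edges([3.0, 0.0, -2.0, 0.05])
print(mask.round(3))
```

The first edge strongly supports the prediction, so its mask rises above 0.5. The irrelevant, opposing, and negligible edges all fall below 0.5.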
Simpler approach: compute the gradient of the prediction score with respect to each edge's presence. Edges with large gradient magnitude are "important" — perturbing them changes the prediction most. Fast to compute (one backward pass). Noisier than GNNExplainer but requires no separate optimization loop.
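Gradient saliency needs just one backward pass. This minimal NumPy sketch assumes a hypothetical target node whose score is `w_out . relu(sum_j m_j * x_j @ W)`, with every edge weight `m_j` set to 1. The backward pass is written out by hand instead of using autodiff.

```python
import numpy as np

def edge_saliency(x_nbrs, W, w_out):
    """|d score / d m_j| at m = 1 for each incoming edge j of one target node.

    Hypothetical model: score = w_out . relu(sum_j m_j * x_nbrs[j] @ W).
    """
    pre = x_nbrs.sum(axis=0) @ W            # forward pass with every edge weight m_j = 1
    upstream = w_out * (pre > 0)            # backward through the ReLU: d score / d pre
    grads = (x_nbrs @ W) @ upstream         # d pre / d m_j = x_nbrs[j] @ W, chained
    return np.abs(grads)

x_nbrs = np.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
sal = edge_saliency(x_nbrs, W=np.eye(2), w_out=np.array([1.0, -1.0]))
print(sal)
```

Saliency is local. It measures how the score reacts to an infinitesimal change in an edge's weight, not to deleting the edge outright, which is one reason it is noisier than an optimized mask.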
If your GNN uses attention (GAT), the attention weights over edges are a natural "importance score." High attention weight on edge (u, v) means node u relied heavily on node v's features. Simple, free to compute. But attention ≠ importance in general — a node can attend heavily to a neighbor and still not change its output much.
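A three-neighbor example shows how attention and contribution can disagree. This is a minimal NumPy sketch of single-head GAT-style scoring: the attention vector is split into a target half and a neighbor half, and self-loops are omitted. All numbers are hypothetical. A neighbor with all-zero features can win the softmax yet add nothing to the aggregate, because its term is `alpha_j * W h_j = 0`.

```python
import numpy as np

def attention_vs_contribution(h_target, h_nbrs, W, a_target, a_nbr, slope=0.2):
    """Single-head GAT-style attention over one node's neighbors.

    Returns (alpha, contrib): attention weights and the norm of each
    neighbor's actual term alpha_j * W h_j in the aggregated output.
    """
    z_t = h_target @ W                                   # transformed target features
    z_n = h_nbrs @ W                                     # transformed neighbor features, [E, d']
    e = z_t @ a_target + z_n @ a_nbr                     # a^T [W h_i || W h_j], split in two halves
    e = np.where(e > 0, e, slope * e)                    # LeakyReLU
    alpha = np.exp(e - e.max())
    alpha /= alpha.sum()                                 # softmax over the neighbors
    contrib = np.linalg.norm(alpha[:, None] * z_n, axis=1)
    return alpha, contrib

alpha, contrib = attention_vs_contribution(
    h_target=np.array([1.0, 0.0]),
    h_nbrs=np.array([[0.0, 0.0], [1.0, 1.0], [2.0, 0.0]]),  # neighbor 0 has all-zero features
    W=np.eye(2),
    a_target=np.zeros(2),
    a_nbr=np.array([-1.0, -1.0]),
)
print(alpha.round(3), contrib.round(3))
```

The zero-feature neighbor receives the largest attention weight (about 0.43) but contributes nothing. The other two neighbors have smaller weights, yet they determine the output.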
You're building a GNN to predict the energy of a molecule from its 3D structure. You represent each atom as a node with a 3D position (x, y, z). If you rotate the molecule 90 degrees, the energy doesn't change — it's the same molecule. But a standard GNN would compute a different embedding after rotation, because (x, y, z) coordinates change. This is wrong.
This is the problem of geometric equivariance. A model is equivariant to a transformation (like rotation) if rotating the input rotates the output in a predictable way. A model is invariant if the output doesn't change at all under the transformation. For molecular energy prediction, we want invariance: rotate the molecule, energy stays the same.
The fix is to use geometric quantities that are invariant or equivariant by construction. Interatomic distances are invariant to rotation and translation (rotating two atoms doesn't change the distance between them). Bond angles are invariant. Torsion angles are invariant. These are the right features for a molecular GNN.
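The invariance claims can be checked numerically. This minimal NumPy sketch computes a distance, a bond angle, and a signed torsion angle for a random four-atom chain. It then applies a random rotation plus translation and, separately, a mirror reflection. The torsion uses the standard atan2 dihedral formula, and the helper names are illustrative.

```python
import numpy as np

def bond_angle(a, b, c):
    """Angle at atom b (radians) between bonds b->a and b->c."""
    u, v = a - b, c - b
    cos = u @ v / (np.linalg.norm(u) * np.linalg.norm(v))
    return np.arccos(np.clip(cos, -1.0, 1.0))

def torsion(a, b, c, d):
    """Signed dihedral angle (radians) of the chain a-b-c-d."""
    b0, b1, b2 = a - b, c - b, d - c
    b1 = b1 / np.linalg.norm(b1)
    v = b0 - (b0 @ b1) * b1                 # b0 projected onto the plane normal to b1
    w = b2 - (b2 @ b1) * b1                 # b2 projected onto the same plane
    return np.arctan2(np.cross(b1, v) @ w, v @ w)

def features(x):
    """Distance, bond angle, and torsion of a 4-atom chain x with shape [4, 3]."""
    return np.array([
        np.linalg.norm(x[0] - x[1]),
        bond_angle(x[0], x[1], x[2]),
        torsion(x[0], x[1], x[2], x[3]),
    ])

def random_rotation(rng):
    q, r = np.linalg.qr(rng.normal(size=(3, 3)))
    q = q * np.sign(np.diag(r))            # make the QR factor unique
    if np.linalg.det(q) < 0:               # force det = +1: a rotation, not a reflection
        q[:, 0] = -q[:, 0]
    return q

rng = np.random.default_rng(0)
x = rng.normal(size=(4, 3))
moved = x @ random_rotation(rng).T + rng.normal(size=3)   # rotate, then translate
mirrored = x * np.array([1.0, 1.0, -1.0])                 # reflect through the xy-plane
print(features(x), features(moved), features(mirrored))
```

Rotation and translation leave all three values unchanged. The reflection leaves the distance and the angle unchanged but flips the sign of the torsion. Mirror images therefore agree on every distance and angle and differ only in signed quantities like the torsion.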
SchNet uses only interatomic distances as edge features, so it's invariant to rotation and translation by construction. But it loses directional information: it can't distinguish a molecule's 3D shape from a different arrangement with the same pairwise distances. Mirror-image molecules (enantiomers) are the clearest case, since reflecting a molecule leaves every distance unchanged.
EGNN (Equivariant Graph Neural Network) is more powerful. It maintains a feature vector and a 3D coordinate for each node. Messages update both the features and the coordinates. The coordinate updates are designed to be equivariant: rotating the input rotates the coordinates by the same rotation. Directional information is carried in the coordinates instead of being discarded.
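The coordinate update is a weighted sum of displacement vectors (xi − xj), where each weight is computed only from invariant inputs (node features and interatomic distances). If you rotate all positions by a rotation matrix R, the weights stay the same and each displacement vector becomes R(xi − xj), so the update rotates by R. A translation cancels in every difference, so the updated positions simply shift along with the input. The model is equivariant: outputs rotate with inputs.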
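```python
from torch_scatter import scatter_mean

# EGNN coordinate update (simplified)
def coord_update(pos, edge_index, h, phi_x):
    # pos: [N, 3] coordinates, h: [N, F] node features, edge_index: [2, E]
    # phi_x: learned network; it must see only invariant inputs (features, distances)
    src, dst = edge_index
    diff = pos[src] - pos[dst]                     # [E, 3]: displacement vectors
    dist = diff.norm(dim=-1, keepdim=True)         # [E, 1]: distances
    # Message: combine source/dest features + distance
    m_ij = phi_x(h[src], h[dst], dist)             # [E, 1]: scalar weight
    # Weighted mean of displacement vectors per node → equivariant update;
    # dim_size keeps one row per node even if trailing nodes have no incoming edges
    agg = scatter_mean(diff * m_ij, dst, dim=0, dim_size=pos.size(0))  # [N, 3]
    return pos + agg                               # rotate input → rotate output by same rotation
```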
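Equivariance of the coordinate update can be tested directly. This is a minimal NumPy port of the update. It assumes a fixed, hypothetical `toy_weight` function in place of the learned `phi_x`; the function depends only on node features and distances, which is what equivariance requires. Updating rotated-and-translated positions should give the same result as rotating and translating the updated original positions.

```python
import numpy as np

def coord_update_np(pos, edges, h, weight_fn):
    """NumPy version of the EGNN-style coordinate update (mean over incoming edges).

    edges: [E, 2] integer array of (src, dst) pairs.
    weight_fn: maps (h_src, h_dst, dist) -> [E, 1]; must use invariant inputs only.
    """
    src, dst = edges[:, 0], edges[:, 1]
    diff = pos[src] - pos[dst]                                  # [E, 3]
    dist = np.linalg.norm(diff, axis=1, keepdims=True)          # [E, 1]
    msg = diff * weight_fn(h[src], h[dst], dist)                # [E, 3]
    agg = np.zeros_like(pos)
    np.add.at(agg, dst, msg)                                    # scatter-sum per destination
    deg = np.maximum(np.bincount(dst, minlength=len(pos)), 1)
    return pos + agg / deg[:, None]

def toy_weight(h_src, h_dst, dist):
    # Hypothetical stand-in for the learned phi_x: a fixed function of invariants.
    return np.tanh((h_src * h_dst).sum(axis=1, keepdims=True) - dist)

def random_rotation(rng):
    q, r = np.linalg.qr(rng.normal(size=(3, 3)))
    q = q * np.sign(np.diag(r))            # make the QR factor unique
    if np.linalg.det(q) < 0:               # force det = +1: a rotation, not a reflection
        q[:, 0] = -q[:, 0]
    return q

rng = np.random.default_rng(1)
pos, h = rng.normal(size=(5, 3)), rng.normal(size=(5, 4))
edges = np.array([(i, j) for i in range(5) for j in range(5) if i != j])
R, t = random_rotation(rng), rng.normal(size=3)
out = coord_update_np(pos, edges, h, toy_weight)
out_moved = coord_update_np(pos @ R.T + t, edges, h, toy_weight)
print(np.allclose(out_moved, out @ R.T + t))   # equivariance check
```

If `toy_weight` were given raw coordinates instead of distances, the check would generally fail. That is why `phi_x` only sees invariant inputs.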
A friendship graph from 2010 is not the same as the friendship graph from 2024. Edges appear and disappear. Node properties evolve. New nodes join. Standard GNNs are designed for static snapshots — they have no notion of time. For many real applications, this is the critical missing piece.
Dynamic graphs (also called temporal graphs) are graphs where edges carry timestamps. The edge (u, v, t) means u and v interacted at time t. This richer representation enables new tasks: predict whether u and v will interact in the next hour. Detect when u's behavior pattern changed. Rank which past interactions most influenced u's current state.
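Working with timestamped edges mostly means never letting the model see the future. This minimal pure-Python sketch assumes events are `(u, v, t)` tuples; `TemporalGraph` and `chronological_split` are illustrative names. It answers "whom did this node interact with before time t" with binary search, and it splits events by time instead of at random.

```python
from bisect import bisect_left
from collections import defaultdict

class TemporalGraph:
    """Timestamped interactions (u, v, t) with time-respecting neighbor queries."""

    def __init__(self, events):
        self.events = sorted(events, key=lambda e: e[2])    # chronological order
        self.times = defaultdict(list)                       # node -> sorted timestamps
        self.partners = defaultdict(list)                    # node -> partner at each timestamp
        for u, v, t in self.events:
            for a, b in ((u, v), (v, u)):
                self.times[a].append(t)
                self.partners[a].append(b)

    def neighbors_before(self, node, t):
        """(time, partner) pairs for interactions of `node` strictly before t."""
        times = self.times.get(node, [])
        k = bisect_left(times, t)                            # first index with time >= t
        return list(zip(times[:k], self.partners.get(node, [])[:k]))

def chronological_split(events, t_split):
    """Train on the past, evaluate on the future; never shuffle temporal edges."""
    train = [e for e in events if e[2] < t_split]
    test = [e for e in events if e[2] >= t_split]
    return train, test

g = TemporalGraph([("a", "b", 5), ("a", "c", 1), ("b", "c", 3), ("a", "d", 9)])
print(g.neighbors_before("a", 6))
```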
There are two ways to represent a dynamic graph, leading to two families of models:
TGN (Rossi et al., 2020) maintains a memory vector for each node — a running summary of everything that has happened to that node so far. When event (u, v, t) arrives: (1) retrieve memories of u and v, (2) compute messages from the event, (3) update memories, (4) compute temporal embeddings using a GNN on the temporal neighborhood.
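The memory mechanism can be sketched for steps (1) to (3). This minimal NumPy sketch uses loudly simplified components that are not TGN's exact design. The message is [own memory, partner memory, cos(dt * freqs)]; the updater is a single tanh layer where TGN uses a GRU; and step (4), the temporal GNN embedding, is omitted. `NodeMemory` is a hypothetical class with random, untrained weights.

```python
import numpy as np

class NodeMemory:
    """Minimal TGN-style node memory covering steps (1)-(3) of the description above."""

    def __init__(self, num_nodes, dim, time_dim=4, seed=0):
        rng = np.random.default_rng(seed)
        self.memory = np.zeros((num_nodes, dim))
        self.last_update = np.zeros(num_nodes)
        self.freqs = rng.normal(size=time_dim)                       # time-encoding frequencies
        self.W = rng.normal(scale=0.1, size=(2 * dim + time_dim, dim))  # updater (untrained)

    def message(self, me, other, t):
        dt = t - self.last_update[me]                                # time since last update
        return np.concatenate([self.memory[me], self.memory[other], np.cos(dt * self.freqs)])

    def process_event(self, u, v, t):
        # (1) + (2): read both memories and build both messages before writing anything
        msg_u, msg_v = self.message(u, v, t), self.message(v, u, t)
        # (3): update memories (stand-in for TGN's GRU)
        self.memory[u] = np.tanh(msg_u @ self.W)
        self.memory[v] = np.tanh(msg_v @ self.W)
        self.last_update[u] = self.last_update[v] = t

mem = NodeMemory(num_nodes=3, dim=8)
mem.process_event(0, 1, t=2.0)
print(mem.memory.round(3))
```

Only the two endpoints of an event change. Every other node's memory stays as it was until that node takes part in an event.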
GNNs fail in practice for reasons that have nothing to do with the theory. This chapter is the hard-won knowledge that doesn't appear in papers — the debugging patterns and design choices that separate working implementations from broken ones.
Add more GNN layers → nodes see larger neighborhoods → more context → better predictions. Right? Wrong. After too many layers, all node embeddings converge to the same vector. This is over-smoothing: repeated averaging makes every node look like the graph average. In practice, 2-4 layers is almost always optimal. Adding a 5th layer often hurts.
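Over-smoothing shows up even without learned weights. This minimal NumPy sketch assumes a 10-node cycle and parameter-free mean aggregation; real layers add weights and nonlinearities. The spread of node embeddings around their average shrinks with every layer, and after 64 layers the nodes are nearly indistinguishable.

```python
import numpy as np

def mean_aggregate(adj, h):
    """One parameter-free mean-aggregation layer (self-loops included)."""
    a = adj + np.eye(len(adj))
    return (a / a.sum(axis=1, keepdims=True)) @ h

def spread(h):
    """RMS distance of node embeddings from their average; 0 means all nodes identical."""
    return float(np.linalg.norm(h - h.mean(axis=0)) / np.sqrt(len(h)))

n = 10
ring = np.zeros((n, n))
for i in range(n):                                   # cycle graph 0-1-...-9-0
    ring[i, (i + 1) % n] = ring[(i + 1) % n, i] = 1.0

h = np.random.default_rng(0).normal(size=(n, 4))
spreads = {}
for layer in range(1, 65):
    h = mean_aggregate(ring, h)
    spreads[layer] = spread(h)
print({k: round(spreads[k], 4) for k in (1, 2, 4, 8, 64)})
```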
Over-squashing is subtler. To transmit information from a node 4 hops away, that information must be compressed into a single vector as it passes through each intermediate node. If the intermediate nodes have high degree (many neighbors), the 4-hop information gets diluted. Long-range dependencies are poorly captured. Adding more layers doesn't fix this — the bottleneck is at high-degree nodes.
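The dilution at a hub can be measured. Take a linear GNN with mean aggregation. After L layers, the weight of node u's input in node v's output is the (v, u) entry of the L-th power of the row-normalized adjacency matrix with self-loops. This is a common proxy for how sensitive v is to u. Below is a minimal NumPy sketch on a hypothetical path u - hub - v whose hub has extra leaves.

```python
import numpy as np

def influence_through_hub(num_extra_leaves, num_layers=2):
    """Weight of u's input in v's output after `num_layers` mean-aggregation layers.

    Graph: u - hub - v, plus `num_extra_leaves` extra leaves attached to the hub.
    For linear mean aggregation with self-loops this weight is (P^L)[v, u].
    """
    n = 3 + num_extra_leaves                  # node 0 = u, 1 = hub, 2 = v, rest = leaves
    adj = np.zeros((n, n))
    for other in [0, 2] + list(range(3, n)):
        adj[1, other] = adj[other, 1] = 1.0   # every other node connects only to the hub
    a = adj + np.eye(n)
    p = a / a.sum(axis=1, keepdims=True)      # row-normalized: mean aggregation
    return float(np.linalg.matrix_power(p, num_layers)[2, 0])

for leaves in (0, 8, 97):
    print(leaves, influence_through_hub(leaves))
```

With no extra leaves, u's weight in v is 1/6. With 97 extra leaves it drops to 1/200, because every other message entering the hub competes with u's.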
| Problem | Symptom | Fix |
|---|---|---|
| Over-smoothing | Deep GNN worse than shallow | Fewer layers, residual connections, jumping knowledge |
| Over-squashing | Poor long-range task performance | Graph rewiring, graph transformers |
| Feature scale | Training diverges or very slow | Normalize node features; BatchNorm after each GNN layer |
| Data leakage | Test accuracy >> val accuracy | Inductive split — test nodes not seen during training |
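| Message collision | Different neighborhoods get identical embeddings | Sum aggregation with an MLP (as in GIN), or combine several aggregators with principal neighborhood aggregation (PNA); mean or max alone collapses some distinct neighbor multisets |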
Advanced GNNs don't exist in isolation — they connect to the broader landscape of machine learning. Knowing these connections lets you borrow solutions from adjacent fields instead of reinventing them.
| GNN Concept | Analogous ML Concept | Key difference |
|---|---|---|
| Graph pre-training | BERT / GPT pre-training | No natural token vocabulary; structure varies across graphs |
| Neighbor sampling | SGD mini-batching | Samples depend on each other through the graph structure, which breaks the i.i.d. assumption |
| Equivariant GNNs | CNNs (translation equivariance) | Rotation group vs. translation group; 3D vs. 2D |
| Dynamic graphs (TGN) | Sequence models (LSTM, Transformer) | Events occur on a graph, not a linear sequence |
| GNNExplainer | LIME / SHAP for tabular data | Explanation is a subgraph, not a feature importance vector |
| Graph prompt tuning | Soft prompts (prefix tuning) | Prompt is graph-structured, not a sequence of vectors |
"The goal of science is not to open a door to infinite wisdom, but to set a limit to infinite error."
— Bertolt Brecht (apt for GNN explainability)