Your raw graph is almost never ready to feed into a GNN. Nodes may have no features, edges may be too sparse, and choosing the wrong prediction head throws away information. This lesson covers everything between "I have a graph" and "my model is trained."
You have a citation network. Every paper is a node, every citation is an edge. You want to train a GNN to predict a paper's research topic. You open the dataset — and realize: the nodes have no features. Just IDs. What do you feed the GNN?
Or: you have a protein interaction graph where two proteins are connected if they interact. Most proteins have only 2-3 known interactions — the graph is incredibly sparse. A 2-layer GNN can only see 2 hops out, which might be 4-6 nodes. That's almost no information per node. What do you do?
Or: you have Twitter's social graph. Some accounts have 10 million followers. When you aggregate neighbor embeddings, one famous account dominates the computation for millions of nodes. This isn't wrong, but it's incredibly slow: aggregation costs O(degree) per node per layer, and degree can reach $10^7$. What do you do?
A GNN needs two things to work: (1) node features, and (2) a graph structure that enables good information flow. The raw input graph often provides neither in usable form. Graph augmentation bridges this gap — it modifies the graph before message passing begins.
Augmentation falls into two categories. Feature augmentation addresses missing or incomplete node features. Structure augmentation addresses problems with the graph topology — too sparse, too dense, or disconnected in ways that make message passing ineffective.
Three problem graphs. Toggle between them to see what makes each difficult and what augmentation helps. This motivates the next four chapters.
When nodes have no features, you face a fundamental question: what do you put in the GNN's input layer? You can't leave it empty. The answer depends on what you need the features to do: do they need to distinguish individual nodes (expressiveness), or do they need to generalize to new nodes (transferability)?
Give every node the same scalar feature: $h_v^{(0)} = 1$. This is the simplest possible choice. The GNN still learns from structure: it can distinguish a node with 5 neighbors from a node with 2, because their aggregation results differ. But it cannot distinguish two structurally identical nodes.
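A minimal sketch of why constant features still carry structural signal, assuming plain sum aggregation on a toy star graph (the graph and the helper name `sum_aggregate` are illustrative, not from the lesson):

```python
def sum_aggregate(adj, features):
    """One round of sum aggregation over each node's neighbors."""
    return {v: sum(features[u] for u in nbrs) for v, nbrs in adj.items()}

# Star graph: node 0 is the hub, nodes 1-3 are leaves.
adj = {0: [1, 2, 3], 1: [0], 2: [0], 3: [0]}
features = {v: 1.0 for v in adj}  # the same constant feature for every node

aggregated = sum_aggregate(adj, features)
# The hub's aggregate equals its degree, while every leaf gets the same value,
# so the three leaves stay indistinguishable from one another.
```

With mean aggregation instead of sum, even the hub and the leaves would collapse to the same value, which is one reason the choice of aggregator matters when features are constant.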
Give each node a unique indicator vector: node v gets a vector of all zeros except position v, which is 1. This is equivalent to giving each node a row of the identity matrix $I_{n \times n}$. Now every node is uniquely identifiable: the GNN can, in principle, learn different things for every single node.
Compute structural properties of each node and use them as features. These are things the GNN could theoretically learn, but computing them explicitly and giving them as input is much easier for the network to use. Common choices:
| Feature | What it captures | Formula |
|---|---|---|
| Node degree | How many neighbors | $\lvert N(v) \rvert$ |
| Clustering coefficient | How densely connected neighbors are | (edges among $N(v)$) / $\binom{\lvert N(v) \rvert}{2}$ |
| PageRank score | Importance from random walk | iterative formula |
| Cycle participation | Shortest cycle length through v | BFS from v |
These features are inductive — you can compute them for any new node in any new graph. And they're often more informative than one-hot IDs for tasks like predicting molecular properties, where cycle length and clustering coefficient are chemically meaningful.
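As a rough illustration of the first two rows of the table, here is a dependency-free sketch that computes degree and clustering coefficient from an adjacency dictionary. The toy graph and function names are assumptions made for the example.

```python
from itertools import combinations

def degree(adj, v):
    """Number of neighbors of v."""
    return len(adj[v])

def clustering_coefficient(adj, v):
    """Fraction of neighbor pairs of v that are themselves connected."""
    nbrs = adj[v]
    k = len(nbrs)
    if k < 2:
        return 0.0  # no neighbor pairs exist, so define the coefficient as 0
    links = sum(1 for a, b in combinations(nbrs, 2) if b in adj[a])
    return links / (k * (k - 1) / 2)

# Undirected toy graph: a triangle 0-1-2 with a pendant node 3 attached to 0.
adj = {
    0: {1, 2, 3},
    1: {0, 2},
    2: {0, 1},
    3: {0},
}

# One structural feature vector per node: [degree, clustering coefficient]
features = {v: [degree(adj, v), clustering_coefficient(adj, v)] for v in adj}
```

Because both functions only look at a node's local neighborhood, they can be applied unchanged to a node or graph that was never seen during training.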
Sometimes nodes have partial features. For example, in a user-item recommendation graph, some users have profile information (age, location) and some don't. The solution: impute missing features with the column mean, or learn an embedding for "missing." Don't drop nodes with missing features — you lose graph structure.
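One way to sketch column-mean imputation is shown below. Appending a 0/1 "was missing" flag per column is an added assumption here, a simple stand-in for the learned "missing" embedding; the data and the name `impute_with_mean` are illustrative.

```python
def impute_with_mean(rows):
    """Replace None entries with the column mean and append missing-value flags."""
    n_cols = len(rows[0])
    means = []
    for j in range(n_cols):
        observed = [r[j] for r in rows if r[j] is not None]
        means.append(sum(observed) / len(observed) if observed else 0.0)

    imputed = []
    for r in rows:
        values = [means[j] if r[j] is None else r[j] for j in range(n_cols)]
        flags = [1.0 if r[j] is None else 0.0 for j in range(n_cols)]
        imputed.append(values + flags)  # every node keeps a full feature row
    return imputed

# Hypothetical user features: [age, has_location]; None marks a missing value.
users = [[25.0, 1.0], [None, 0.0], [35.0, None]]
x = impute_with_mean(users)
```

Every user keeps a row in the feature matrix, so no node (and none of its edges) has to be dropped from the graph.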
```python
# Three augmentation strategies in PyTorch Geometric
import torch
from torch_geometric.transforms import LocalDegreeProfile

# Strategy 1: constant feature (all 1s)
def add_constant_feature(data):
    data.x = torch.ones(data.num_nodes, 1)  # shape [N, 1]
    return data

# Strategy 2: one-hot node ID
def add_onehot_feature(data):
    data.x = torch.eye(data.num_nodes)  # shape [N, N]; not scalable!
    return data

# Strategy 3: degree + local degree profile (inductive, scalable)
def add_structural_features(data):
    # LocalDegreeProfile appends [deg, min, max, mean, std]: the node's own
    # degree plus four statistics over the degrees of its neighbors
    transform = LocalDegreeProfile()
    data = transform(data)  # data.x shape: [N, 5] when data.x was empty
    return data
```
You have a molecular graph where two atoms are 10 bonds apart. With a 3-layer GNN (3-hop receptive field), those atoms cannot "see" each other. But in chemistry, long-range electronic effects are real — atoms across the molecule do influence each other. Standard message passing is structurally blind to this.
One solution: add a virtual node — a synthetic node connected to every single node in the graph. Call it v*. Now every pair of real nodes is at most 2 hops apart: node A → v* → node B. The virtual node creates a "global information highway."
At initialization, the virtual node's embedding $h_{v^*}$ starts at zero (or a learned embedding). During message passing, it aggregates from every real node, forming a global summary of the entire graph. Then it sends back this summary to every real node. So at layer 1, every real node has seen information from every other node in the graph.
In a molecular graph with 50 atoms, without the virtual node and with K=3 layers, a node might see only its 5 or 6 nearby neighbors. With the virtual node, every node sees the global average of all 50 atoms' embeddings after just 1 layer. For tasks like predicting a molecule's solubility — which depends on the whole molecule, not just local environment — this is crucial.
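The hop-count effect can be checked with a small breadth-first search. This is a sketch on a toy path graph whose endpoints are 10 edges apart; the helper names `hops` and `add_virtual_node` are made up for the example.

```python
from collections import deque

def hops(adj, src, dst):
    """Shortest-path length from src to dst via BFS (None if unreachable)."""
    seen = {src}
    queue = deque([(src, 0)])
    while queue:
        node, dist = queue.popleft()
        if node == dst:
            return dist
        for nbr in adj[node]:
            if nbr not in seen:
                seen.add(nbr)
                queue.append((nbr, dist + 1))
    return None

def add_virtual_node(adj, name="v*"):
    """Return a copy of the graph with one extra node linked to every real node."""
    augmented = {v: set(nbrs) | {name} for v, nbrs in adj.items()}
    augmented[name] = set(adj)
    return augmented

# Path graph 0 - 1 - ... - 10: the two endpoints are 10 edges apart.
path = {i: set() for i in range(11)}
for i in range(10):
    path[i].add(i + 1)
    path[i + 1].add(i)

before = hops(path, 0, 10)                    # distance in the raw graph
after = hops(add_virtual_node(path), 0, 10)   # distance through the virtual node
```

In the raw graph, a 3-layer GNN could never connect the endpoints; with the virtual node, two layers are enough for information to travel between them.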
The virtual node aggregates from ALL nodes — including irrelevant ones. In a large graph with 10,000 nodes, the global summary is an average over 10,000 embeddings. Distant nodes contribute noise. This can hurt if the task is purely local. As with most engineering choices: use virtual nodes when your task has global structure; skip them when the task is local.
A sparse graph of 8 nodes. Toggle the virtual node on/off. Watch how many hops are needed for the highlighted source node (warm) to reach the target node (teal). With the virtual node, every pair of nodes is at most 2 hops apart.
Imagine you're building a GNN for Pinterest's recommendation graph: 2 billion pins, 18 billion edges. A single forward pass aggregating ALL neighbors for ALL nodes simultaneously would need far more memory than any single GPU or machine provides. You cannot run vanilla message passing on this graph. You need neighbor sampling.
The idea is surgical: instead of aggregating from all neighbors, sample a random subset of size k each forward pass. For node v with 10,000 neighbors, pick 25 at random. Aggregate those 25. Ignore the rest — for this forward pass.
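A minimal sketch of uniform neighbor sampling, assuming an adjacency dictionary of lists. The function name and toy graph are illustrative; real libraries implement this far more efficiently.

```python
import random

def sample_neighbors(adj, v, k, rng):
    """Return at most k neighbors of v, drawn uniformly without replacement."""
    nbrs = adj[v]
    if len(nbrs) <= k:
        return list(nbrs)        # low-degree node: keep every neighbor
    return rng.sample(nbrs, k)   # high-degree node: keep a random subset

# Node 0 is a hub with 10,000 neighbors; node 1 has only two.
adj = {0: list(range(1, 10001)), 1: [0, 2], 2: [1]}
rng = random.Random(0)           # fixed seed so the sketch is reproducible

hub_sample = sample_neighbors(adj, 0, 25, rng)
low_sample = sample_neighbors(adj, 1, 25, rng)
```

Calling the function again with a fresh random state gives a different subset for the hub, which is how the remaining neighbors get their turn on later forward passes.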
PinSage (Ying et al., KDD 2018 — mentioned in Lecture 1) trains on Pinterest's graph using importance-based neighbor sampling rather than uniform sampling. Instead of picking k random neighbors, PinSage simulates random walks from v and samples neighbors proportionally to how often they appear in the walk — effectively sampling important neighbors.
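The random-walk idea can be sketched as follows. This is a simplified illustration rather than PinSage's actual implementation: it assumes short fixed-length walks that restart at v, and it keeps the k most-visited nodes as the "important" neighborhood. All names and parameters are hypothetical.

```python
import random
from collections import Counter

def random_walk_neighbors(adj, v, k, num_walks, walk_len, rng):
    """Rank nodes by how often short random walks from v visit them; keep the top k."""
    visits = Counter()
    for _ in range(num_walks):
        node = v
        for _ in range(walk_len):
            node = rng.choice(adj[node])   # assumes every node has a neighbor
            if node != v:
                visits[node] += 1
    return [n for n, _ in visits.most_common(k)]

# Small undirected toy graph stored as adjacency lists.
adj = {0: [1, 2, 3], 1: [0, 2], 2: [0, 1, 3], 3: [0, 2]}
rng = random.Random(0)
top = random_walk_neighbors(adj, 0, k=2, num_walks=200, walk_len=3, rng=rng)
```

The visit counts could also serve as aggregation weights, so that frequently visited neighbors contribute more to the embedding of v.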
The result: PinSage can produce embeddings for 3 billion nodes, something impossible with full-graph message passing. And it generalizes inductively — new pins can be embedded the moment they appear in the graph, by sampling their current neighbors.
Without sampling: aggregating for one node costs O(degree) work per layer. A 3-layer GNN on a high-degree node expands to $O(\text{degree}^3)$ neighbors in the worst case, a blowup that is exponential in the number of layers. With k sampled neighbors per layer: $O(k^3)$ neighbors, independent of degree. On a graph with max degree 10,000 and k=25, that's up to $10{,}000^3 = 10^{12}$ nodes versus $25^3 = 15{,}625$, a reduction of roughly $6.4 \times 10^7$.
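The worst-case counts for a 3-layer model are easy to verify directly. A quick arithmetic sketch (the helper name is illustrative):

```python
def receptive_field_bound(fanout, num_layers):
    """Worst-case number of nodes reached at the outermost hop."""
    return fanout ** num_layers

full = receptive_field_bound(10_000, 3)   # every neighbor, max degree 10,000
sampled = receptive_field_bound(25, 3)    # k = 25 sampled neighbors per layer
reduction = full // sampled               # how many times smaller the sampled bound is

print(f"full: {full:,}  sampled: {sampled:,}  reduction: {reduction:,}x")
```

The sampled bound depends only on k and the depth, so it stays the same no matter how large the highest-degree node in the graph is.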
Sampling introduces variance. For low-degree nodes (degree 2-3), sampling k=25 neighbors means you always use all of them — no problem. For high-degree nodes (degree 10,000), you see only 0.25% of neighbors per pass. The gradient is noisier, and you might consistently miss important neighbors if they're unlucky in sampling. Practical fix: increase k for high-degree nodes, or use importance-weighted sampling.
```python
# Neighbor sampling in PyTorch Geometric (NeighborLoader)
from torch_geometric.loader import NeighborLoader

# Sample up to 25 neighbors per hop, 3 hops deep
loader = NeighborLoader(
    data,
    num_neighbors=[25, 25, 25],  # one entry per hop, starting from the seed nodes
    batch_size=512,              # seed nodes per mini-batch
    input_nodes=train_mask,      # which nodes to predict for
    shuffle=True,
)

# Each batch is a subgraph:
# - 512 seed nodes (the ones we want predictions for), stored first
# - Up to 25 + 25^2 + 25^3 = 16,275 sampled neighbors per seed node
# - Only edges within this sampled subgraph
# Shape of batch.x: [num_sampled_nodes, feature_dim]
# model, criterion, and optimizer are assumed to be defined elsewhere
for batch in loader:
    optimizer.zero_grad()
    out = model(batch.x, batch.edge_index)  # standard GNN forward
    # Only compute loss on seed nodes (first batch_size rows)
    loss = criterion(out[:batch.batch_size], batch.y[:batch.batch_size])
    loss.backward()
    optimizer.step()
```
Your GNN produces node embeddings $h_v^{(L)}$ for every node v. These are vectors in $\mathbb{R}^d$. But your task might be node-level (predict a label per node), edge-level (predict a score per edge), or graph-level (predict one label for the whole graph). Each task needs a different prediction head: the function that converts embeddings into predictions.
The simplest case. You have one embedding per node. Apply a linear layer and softmax:

$$\hat{y}_v = \mathrm{softmax}\left(W^{(H)} h_v^{(L)}\right)$$

$W^{(H)}$ maps from embedding dimension $d$ to the number of classes $c$. Shape: $[c \times d]$. That's it: a standard linear head (or a small MLP) over individual node embeddings. For regression, drop the softmax.
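A dependency-free sketch of the node-level head, a linear map followed by softmax. The weight matrix and embedding below are made-up toy values with c = 3 classes and d = 2 dimensions.

```python
import math

def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

def node_head(W, h):
    """Apply a [c x d] weight matrix to a d-dimensional embedding, then softmax."""
    logits = [sum(w * x for w, x in zip(row, h)) for row in W]
    return softmax(logits)

W = [[1.0, 0.0], [0.0, 1.0], [-1.0, -1.0]]   # shape [c=3, d=2]
h_v = [2.0, 0.5]                             # final-layer embedding of node v
y_hat = node_head(W, h_v)                    # probability vector over 3 classes
predicted_class = max(range(len(y_hat)), key=lambda i: y_hat[i])
```

For regression, `node_head` would return the raw linear output instead of passing it through `softmax`.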
For edges, you need to combine embeddings from two nodes into a single score. You have several options:
Concatenation + linear: Concatenate $h_u$ and $h_v$, then apply a linear layer. Shape: $W^{(H)} \in \mathbb{R}^{c \times 2d}$. Expressive: it can capture asymmetric relationships (u links to v ≠ v links to u).
Dot product: Just take the inner product $h_u \cdot h_v$. Returns one scalar, which suits link prediction (will these two nodes connect?). Forces symmetry: score(u,v) = score(v,u).
For k-way edge classification (k types of edge labels), use multi-head dot product: learn k different linear maps $W_1, \dots, W_k$, compute $h_u^\top W_i h_v$ for each, softmax over k scores.
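The three edge heads can be compared on toy vectors. This is a sketch with hand-picked weights, shown with a single output score per head for brevity; all values and function names are illustrative.

```python
def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

def concat_linear_score(w, h_u, h_v):
    """Linear layer over the concatenation [h_u ; h_v] (one output unit)."""
    return dot(w, h_u + h_v)   # list concatenation, then a weighted sum

def dot_score(h_u, h_v):
    """Plain inner product of the two embeddings."""
    return dot(h_u, h_v)

def bilinear_score(W, h_u, h_v):
    """One head of the multi-head form: h_u^T W h_v."""
    return dot(h_u, [dot(row, h_v) for row in W])

h_u, h_v = [1.0, 2.0], [3.0, -1.0]
w = [1.0, 0.0, 0.0, 1.0]          # weights over the 2d = 4 concatenated entries
W = [[0.0, 1.0], [0.0, 0.0]]      # a deliberately non-symmetric d x d matrix

concat_uv, concat_vu = concat_linear_score(w, h_u, h_v), concat_linear_score(w, h_v, h_u)
dot_uv, dot_vu = dot_score(h_u, h_v), dot_score(h_v, h_u)
bil_uv, bil_vu = bilinear_score(W, h_u, h_v), bilinear_score(W, h_v, h_u)
```

Swapping u and v changes the concatenation and bilinear scores but leaves the dot product untouched, which is the symmetry trade-off described above.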
For graph tasks (e.g., "is this molecule toxic?"), you need to pool all node embeddings into one graph embedding. Three simple strategies, plus a hierarchical alternative:
| Strategy | Formula | When to use |
|---|---|---|
| Mean pooling | $h_G = \frac{1}{\lvert V \rvert} \sum_v h_v$ | When every node contributes equally. Good for homogeneous graphs. |
| Max pooling | $h_G = \max_v h_v$ | When the most extreme node matters (e.g., detecting if ANY toxic group exists). Ignores most nodes. |
| Sum pooling | $h_G = \sum_v h_v$ | When graph size matters: the sum grows with $\lvert V \rvert$. Good for counting tasks. |
| DiffPool (hierarchical) | clusters = softmax($W h_v$), repeat | When graph has community structure. Most expressive but slow. |
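
The three simple pooling rows can be contrasted on a toy example. This sketch uses two graphs with the same embedding distribution but different sizes; the embeddings and function names are made up for illustration.

```python
def mean_pool(H):
    """Element-wise mean over all node embeddings."""
    d = len(H[0])
    return [sum(h[j] for h in H) / len(H) for j in range(d)]

def max_pool(H):
    """Element-wise maximum over all node embeddings."""
    d = len(H[0])
    return [max(h[j] for h in H) for j in range(d)]

def sum_pool(H):
    """Element-wise sum over all node embeddings."""
    d = len(H[0])
    return [sum(h[j] for h in H) for j in range(d)]

small = [[1.0, -1.0], [3.0, 1.0]]   # graph with 2 nodes
large = small * 2                   # 4 nodes, same two embeddings repeated

pooled = {
    "mean": (mean_pool(small), mean_pool(large)),
    "max": (max_pool(small), max_pool(large)),
    "sum": (sum_pool(small), sum_pool(large)),
}
```

Mean and max give identical graph embeddings for both graphs; only sum tells the 2-node graph apart from the 4-node one.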
A small graph with 6 nodes. Each node has a 2D embedding (visualized as position). Toggle task level and see how the prediction head combines embeddings. For edge: toggle concat vs dot-product. For graph: toggle mean vs max.
Once you have predictions ŷ and ground truth labels y, you need a loss function to measure the error and drive backpropagation. The right loss depends on your task type: classification, regression, or link prediction.
For node or graph classification with c classes, use cross-entropy. The prediction ŷ is a probability vector over c classes (after softmax). The label y is a one-hot vector (all zeros except class k, which is 1).

$$\mathrm{CE}(y, \hat{y}) = -\sum_{i=1}^{c} y_i \log \hat{y}_i$$

Since y is one-hot, only one term survives the sum: $-\log \hat{y}_k$, where k is the true class. Cross-entropy punishes confidently wrong predictions very harshly (the loss grows without bound as $\hat{y}_k \to 0$) and rewards high confidence on the correct class.
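A short numeric sketch of how the surviving term behaves. The probability vectors are made-up examples with the true class at index 0.

```python
import math

def cross_entropy(y_hat, true_class):
    """Cross-entropy for a one-hot label: only the true-class term survives."""
    return -math.log(y_hat[true_class])

confident_right = cross_entropy([0.90, 0.05, 0.05], 0)   # about 0.105
unsure = cross_entropy([0.40, 0.30, 0.30], 0)            # about 0.916
confident_wrong = cross_entropy([0.01, 0.98, 0.01], 0)   # about 4.605
```

Moving from an unsure prediction to a confidently wrong one multiplies the loss several times over, which is the harsh penalty described above.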
For predicting continuous values (e.g., molecular energy, traffic flow), use MSE:

$$\mathrm{MSE}(y, \hat{y}) = \frac{1}{N} \sum_{i=1}^{N} \left(y_i - \hat{y}_i\right)^2$$
MSE penalizes large errors quadratically. If you suspect outlier labels in your dataset, use MAE (mean absolute error) instead — it's more robust to extreme values. The GNN architecture doesn't change; only the loss and final activation change (no softmax for regression — just a linear output).
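The outlier sensitivity can be seen on a toy example. In this sketch three predictions are exact and one target is an assumed outlier label; the numbers are made up.

```python
def mse(targets, preds):
    """Mean squared error."""
    return sum((t - p) ** 2 for t, p in zip(targets, preds)) / len(targets)

def mae(targets, preds):
    """Mean absolute error."""
    return sum(abs(t - p) for t, p in zip(targets, preds)) / len(targets)

targets = [1.0, 2.0, 3.0, 100.0]   # the last label is an outlier
preds = [1.0, 2.0, 3.0, 4.0]

mse_value = mse(targets, preds)    # the single error of 96 is squared
mae_value = mae(targets, preds)    # the same error enters only linearly
```

A single bad label drives MSE to 2304 but MAE to only 24, so gradients under MSE are dominated by that one example.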
For link prediction, you have positive edges (edges that actually exist in the graph) and negative edges — pairs of nodes that are NOT connected. You want connected nodes to have high dot-product similarity and non-connected nodes to have low similarity.
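With dot-product scores and a sigmoid $\sigma$, the binary cross-entropy loss over positive edges $E$ and sampled negative pairs $E^-$ is:

$$\mathcal{L} = -\sum_{(u,v) \in E} \log \sigma\left(h_u^\top h_v\right) - \sum_{(u,v) \in E^-} \log\left(1 - \sigma\left(h_u^\top h_v\right)\right)$$

The first sum pushes positive edge scores up. The second pushes negative edge scores down. Because $\lvert E \rvert \ll \lvert \text{non-edges} \rvert$, you typically sample the same number of negative edges as positive edges (called negative sampling).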
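A minimal sketch of uniform negative sampling for an undirected graph, assuming rejection sampling over random node pairs. The function name and toy edge list are illustrative.

```python
import random

def sample_negative_edges(num_nodes, pos_edges, num_samples, rng):
    """Draw distinct node pairs that are not edges (undirected, no self-loops)."""
    existing = {frozenset(e) for e in pos_edges}
    max_negatives = num_nodes * (num_nodes - 1) // 2 - len(existing)
    if num_samples > max_negatives:
        raise ValueError("not enough non-edges to sample from")

    negatives = set()
    while len(negatives) < num_samples:
        u, v = rng.randrange(num_nodes), rng.randrange(num_nodes)
        pair = frozenset((u, v))
        if u != v and pair not in existing:
            negatives.add(pair)          # the set drops duplicate draws
    return [tuple(sorted(p)) for p in negatives]

pos_edges = [(0, 1), (1, 2), (2, 3), (3, 4), (4, 5)]
rng = random.Random(0)
# One negative per positive edge, as described above.
neg_edges = sample_negative_edges(6, pos_edges, len(pos_edges), rng)
```

Rejection sampling works well on sparse graphs because almost every random pair is a non-edge; on dense graphs it would waste many draws.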
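```python
# Three loss functions for three GNN task types
import torch
import torch.nn as nn

# Toy tensors so the snippet runs; real values come from the model and dataset
N, C, E, d = 8, 3, 5, 16
logits, labels = torch.randn(N, C), torch.randint(0, C, (N,))
predictions, targets = torch.randn(N, 1), torch.randn(N)
h = torch.randn(N, d)  # node embeddings
src, dst = torch.randint(0, N, (E,)), torch.randint(0, N, (E,))
neg_src, neg_dst = torch.randint(0, N, (E,)), torch.randint(0, N, (E,))

# 1. Node/graph classification
ce_loss = nn.CrossEntropyLoss()
loss = ce_loss(logits, labels)  # logits: [N, C], labels: [N] (int)

# 2. Regression
mse_loss = nn.MSELoss()
loss = mse_loss(predictions.squeeze(-1), targets)  # [N], [N]

# 3. Link prediction (BCE with negative sampling)
bce_loss = nn.BCEWithLogitsLoss()
# pos_scores: dot products for real edges
# neg_scores: dot products for sampled negative edges
pos_scores = (h[src] * h[dst]).sum(dim=-1)          # [E]
neg_scores = (h[neg_src] * h[neg_dst]).sum(dim=-1)  # [E]
scores = torch.cat([pos_scores, neg_scores])
labels_bin = torch.cat([torch.ones_like(pos_scores), torch.zeros_like(neg_scores)])
loss = bce_loss(scores, labels_bin)
```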
Where do training labels come from? In some domains, getting labeled data is expensive — labeling a protein's function requires wet-lab experiments. In other domains, the graph structure itself is a label generator. This distinction drives whether you use supervised or self-supervised training.
In supervised training, you have explicit task labels provided externally: node class labels (e.g., paper topic in a citation network), graph property labels (e.g., molecule toxicity measured in a lab), or edge labels (e.g., interaction type). You compute a loss directly against these labels.
The GNN is trained end-to-end to minimize task loss. The embeddings it produces are optimized for the specific labeled task. This is the most direct path to good task performance — if you have labels. The problem: labels are often scarce, expensive, or impossible to scale.
Self-supervised methods treat the graph structure itself as the source of supervision signals. No external labels needed. The two main variants:
Link prediction as pretext task: Mask some edges. Train the GNN to predict which edges were removed. Loss = BCE on the masked (real) edges vs sampled non-edges. After training, the embeddings encode "which nodes tend to connect", which is useful for recommendation and social network tasks.
Node clustering pretext task: Assign initial cluster IDs using a simple algorithm (e.g., spectral clustering or Louvain). Train the GNN to predict which cluster a node belongs to. Embeddings learn to encode community structure — useful for social networks and biological networks.
Often you have a large graph where only a small fraction of nodes are labeled. Semi-supervised training uses both: a self-supervised loss on all nodes (no labels needed) plus a supervised loss on the labeled subset. The self-supervised term regularizes the model so it doesn't overfit to the few labeled examples.
$$\mathcal{L} = \mathcal{L}_{\text{sup}} + \lambda \, \mathcal{L}_{\text{self}}$$

The hyperparameter λ controls the balance. Even a small weight such as λ = 0.1 can noticeably improve performance when labels are scarce, because the self-supervised signal forces the model to use graph structure rather than memorize the few labeled nodes.
```python
# Semi-supervised GNN training
# GNN, data, train_mask, src/dst, and neg_src/neg_dst are assumed to be defined elsewhere
import torch
import torch.nn as nn

ce_loss = nn.CrossEntropyLoss()
bce = nn.BCEWithLogitsLoss()
model = GNN(in_dim=64, hidden=128, out_dim=7)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

for epoch in range(200):
    model.train()
    optimizer.zero_grad()
    # Output is [N, 7]: class logits, reused here as embeddings for link scores
    embeddings = model(data.x, data.edge_index)

    # Supervised: only on labeled nodes
    sup_loss = ce_loss(embeddings[train_mask], data.y[train_mask])

    # Self-supervised: link prediction on all nodes
    pos_s = (embeddings[src] * embeddings[dst]).sum(-1)
    neg_s = (embeddings[neg_src] * embeddings[neg_dst]).sum(-1)
    self_loss = bce(torch.cat([pos_s, neg_s]),
                    torch.cat([torch.ones_like(pos_s), torch.zeros_like(neg_s)]))

    loss = sup_loss + 0.1 * self_loss  # lambda = 0.1
    loss.backward()
    optimizer.step()
```
In standard deep learning (images, text), train/val/test splitting is straightforward: shuffle examples, take 80/10/10. But graphs have dependencies between examples — nodes are connected. Randomly splitting nodes or edges can leak information across splits. Graph splitting requires careful thought.
The first question is not how to split, but what you're splitting. In a transductive setting, you have one graph. You train on some of its nodes/edges, validate on others, and test on others — but the GNN sees the ENTIRE graph structure (all edges) during training. Only labels are withheld, not graph topology.
In an inductive setting, you have multiple graphs. You train on some graphs, validate on others, and test on held-out graphs the model has never seen. The model must generalize to entirely new graph structures, not just new labels on a known graph.
For transductive node classification: split LABELS, not the graph. Assign each node to train/val/test. During training, message passing uses all edges (including those touching val/test nodes). Only the loss is computed on train-masked nodes. This is the standard split for Cora, Citeseer, and PubMed benchmarks.
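A sketch of a transductive label split, assuming a random assignment of node IDs to three disjoint sets. The function name and sizes are illustrative.

```python
import random

def split_node_masks(num_nodes, n_train, n_val, rng):
    """Randomly assign each node to exactly one of train / val / test."""
    order = list(range(num_nodes))
    rng.shuffle(order)
    return {
        "train": set(order[:n_train]),
        "val": set(order[n_train:n_train + n_val]),
        "test": set(order[n_train + n_val:]),
    }

rng = random.Random(0)
masks = split_node_masks(num_nodes=10, n_train=6, n_val=2, rng=rng)

# The edge list is NOT split: message passing still runs over the full graph,
# and only the loss is restricted to nodes in masks["train"].
```

The split touches labels only, so val and test nodes still send and receive messages during training.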
Link prediction is the trickiest split. You must split EDGES into train/val/test. But removing edges changes the graph structure used for message passing. The standard approach: split edges into two sets — "message passing edges" (used for GNN computation) and "supervision edges" (used only for the loss).
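One way to sketch the edge split is to shuffle the edge list and carve it into four disjoint groups. The group names, sizes, and the function `split_edges` are assumptions for illustration, not a fixed convention.

```python
import random

def split_edges(edges, n_sup, n_val, n_test, rng):
    """Partition edges into message-passing, train-supervision, val, and test sets."""
    order = list(edges)
    rng.shuffle(order)
    test = order[:n_test]
    val = order[n_test:n_test + n_val]
    sup = order[n_test + n_val:n_test + n_val + n_sup]
    message = order[n_test + n_val + n_sup:]   # the only edges the GNN sees in training
    return {"message": message, "train_supervision": sup, "val": val, "test": test}

edges = [(i, i + 1) for i in range(10)]        # toy path graph with 10 edges
rng = random.Random(0)
parts = split_edges(edges, n_sup=2, n_val=2, n_test=2, rng=rng)
```

During training, the GNN propagates over `parts["message"]` and is scored on `parts["train_supervision"]`; keeping those two groups disjoint prevents the model from being asked to predict an edge it has already seen.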
For datasets with multiple graphs (e.g., 1000 molecules), split is simple: randomly assign entire graphs to train/val/test. No information leaks across graphs — they're independent. This is the fully inductive setting. Make sure to split at the graph level, not at the node level, or you'll have partial graphs in test.
Toggle between the two splitting strategies on a single graph. Warm = train nodes. Teal = val nodes. Blue = test nodes. Green edges = used for message passing. Red dashed edges = supervision-only (hidden from GNN topology).
Every GNN project is a sequence of design decisions, and the choices interact. Add a virtual node and you might need fewer layers. Use neighbor sampling and you must adjust your batch size accordingly. This chapter traces the complete pipeline and its decision points.
You, Ying, and Leskovec (NeurIPS 2020) defined a design space of 315,000 GNN configurations and evaluated controlled random samples from it, systematically varying: aggregation (mean/sum/max), activation (ReLU/PReLU), normalization (none/batch/layer), skip connections (none/standard/dense), and the number of layers K. Their findings:
No universal winner. The optimal design on molecular tasks is completely different from the optimal design on social network tasks. Neither GCN nor GAT is uniformly best. Sum aggregation ranked best on average in their study, but not on every task.
Simple designs often win. Many fancy GNN variants were beaten by a well-tuned simple GCN with the right normalization and skip connections. "Design space" thinking helps you find the right simple design for your problem.
```python
# Complete minimal GNN training loop
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import SAGEConv, global_mean_pool

class GNNPipeline(nn.Module):
    def __init__(self, in_dim, hidden, out_dim):
        super().__init__()
        # Feature pre-processing layer (optional but often helpful)
        self.pre = nn.Linear(in_dim, hidden)
        # K=2 GNN layers (SAGEConv = GraphSAGE mean aggregation)
        self.conv1 = SAGEConv(hidden, hidden)
        self.conv2 = SAGEConv(hidden, hidden)
        self.bn1 = nn.BatchNorm1d(hidden)
        self.bn2 = nn.BatchNorm1d(hidden)
        # Prediction head (graph-level: pool then linear)
        self.head = nn.Linear(hidden, out_dim)

    def forward(self, x, edge_index, batch):
        x = F.relu(self.pre(x))                          # [N, hidden]
        x = F.relu(self.bn1(self.conv1(x, edge_index)))  # 1-hop
        x = F.relu(self.bn2(self.conv2(x, edge_index)))  # 2-hop
        x = global_mean_pool(x, batch)                   # [B, hidden], graph-level
        return self.head(x)                              # [B, out_dim], predictions

# Training (train_loader is assumed to be a PyG DataLoader over graphs)
model = GNNPipeline(in_dim=9, hidden=256, out_dim=2)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

for batch in train_loader:
    optimizer.zero_grad()
    logits = model(batch.x, batch.edge_index, batch.batch)
    loss = F.cross_entropy(logits, batch.y)
    loss.backward()
    optimizer.step()
```
Every concept in this lecture connects outward — to other GNN ideas, to other fields, and to open problems. Follow the thread that interests you most.
Lecture 3 introduced message passing and GCN. Lecture 4 showed that all GNNs are MSG+AGG choices. This lecture showed that before you even choose your MSG+AGG, you must decide what graph to feed it — and that choice (augmentation) matters as much as the architecture.
This lecture took the practical view: given augmentation and training choices, build a working GNN. Lecture 6 asks the theoretical question: what can and cannot a GNN compute? The answer involves the Weisfeiler-Lehman graph isomorphism test: standard message-passing GNNs are provably at most as expressive as the 1-WL test. Some structural features (like cycle membership) cannot be captured by standard message passing at all, which is why feature augmentation (adding cycle lengths explicitly) can help so dramatically.
Feature augmentation in GNNs mirrors data augmentation in computer vision. Random crops and flips make images more diverse — neighbor sampling makes graph processing more scalable. Adding structural features to nodes is like adding edge detectors as input channels to a CNN. The goal is the same: make the input representation richer and more informative for the task.
PinSage (Lecture 1) now makes complete sense in this framework: it uses (1) neighbor sampling for scalability, (2) importance-based sampling instead of uniform, (3) inductive setting (new pins can be embedded on demand), and (4) a custom triplet loss (a form of self-supervised link prediction). Most of the scaling and training techniques in this lecture appear in PinSage, which makes it a canonical example of a scalable real-world GNN.
The current frontier: (1) graph transformers that attend globally without being bottlenecked by graph structure, (2) better negative sampling strategies for link prediction (hard negatives), (3) self-supervised pretraining at scale (pretrain on large graphs, fine-tune with few labels), and (4) equivariant GNNs that respect 3D geometry for molecular design.
| Technique | Solves | Cost | Use when |
|---|---|---|---|
| Constant features | No features | Minimal | Structure-only tasks |
| Structural features | No features, inductive | Preprocessing time | Inductive setting |
| One-hot node ID | No features, transductive | Memory O(N²) | Small graphs, transductive only |
| Virtual node | Sparse graph, long-range | +1 node, small | Graph-level tasks |
| Neighbor sampling | Dense/large graph | Gradient noise | Millions of nodes |
| Semi-supervised loss | Few labels | Extra loss term | Label-scarce settings |
| Inductive split | Generalization test | Need multiple graphs | Production GNN systems |
CS224W Lec 3: GNN Basics — message passing from scratch
CS224W Lec 4: GNN Design Space — MSG+AGG framework, GCN vs SAGE vs GAT
CS224W Lec 2: Node Embeddings — DeepWalk and Node2Vec as self-supervised GNN precursors