Why graphs? Three levels of prediction. From AlphaFold to Pinterest. The ML pipeline for graph data.
Consider a molecule. You could describe it as a list of atoms: carbon, carbon, hydrogen, oxygen. But that list tells you almost nothing useful. What makes a molecule a drug, or a poison, or an antibiotic, is not which atoms are present — it's how they're connected. Change one bond, and you get a completely different compound with completely different properties.
This is the central insight behind graph machine learning: relational structure — who connects to whom — carries information that can't be captured by treating entities as an independent list. Images have pixel grids. Text has sequences. But molecules, social networks, proteins, and knowledge bases are fundamentally graphs: collections of nodes (entities) connected by edges (relationships).
Graphs are everywhere. Here's a sampling of the domains covered in CS224W, Stanford's course on machine learning with graphs:
| Domain | Nodes | Edges | Example Task |
|---|---|---|---|
| Social networks | Users | Friendships | Detect bots |
| Biology | Proteins | Interactions | Predict function |
| Chemistry | Atoms | Bonds | Predict drug activity |
| Knowledge bases | Entities | Relations | Answer questions |
| Road networks | Intersections | Roads | Estimate travel time |
| Citation networks | Papers | Citations | Classify research area |
| Recommenders | Users + Items | Interactions | Suggest next item |
Standard ML tools — CNNs, RNNs, transformers — are designed for regular structure: grids, sequences, fixed-size vectors. Graphs are irregular: each node has a different number of neighbors, and there's no natural ordering of nodes. We need new tools built from the ground up for this structure.
Before we can do machine learning on graphs, we need a precise language for describing them. A graph G = (V, E) has two ingredients: a set V of vertices (also called nodes) and a set E of edges (also called links). An edge (i, j) says that node i and node j are related.
Graphs come in several flavors. An undirected graph treats edges as symmetric: if i is connected to j, then j is connected to i (friendship). A directed graph treats edges as one-way arrows: i follows j on Twitter, but j may not follow back. Edges can also carry weights — road distance, interaction strength, confidence score.
The most compact way to represent a graph is an adjacency matrix A. For a graph with n nodes, A is an n×n matrix where A[i,j] = 1 if edge (i,j) exists, 0 otherwise. For undirected graphs, A is symmetric: A[i,j] = A[j,i]. For weighted graphs, A[i,j] holds the weight instead of 1.
The degree k_v of a node v is the number of edges connected to it. For an undirected graph, k_v = ∑_j A[v,j], the sum of row v of the adjacency matrix. (For a directed graph, the row sum counts outgoing edges and the column sum counts incoming ones.) A node with many edges is a hub; one with few is a peripheral node.
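To make the adjacency matrix and degree concrete, here is a minimal pure-Python sketch. It uses a dense list-of-lists, which is fine for toy graphs; real libraries store graphs sparsely. The function names and the example edge list are illustrative and do not come from any library.

```python
def adjacency_matrix(n, edges, directed=False):
    """Return an n x n list-of-lists A with A[i][j] = 1 if edge (i, j) exists."""
    A = [[0] * n for _ in range(n)]
    for i, j in edges:
        A[i][j] = 1
        if not directed:
            A[j][i] = 1  # undirected: record the edge in both directions
    return A


def degrees(A):
    """Row sums of A: the degree k_v for undirected graphs, the out-degree for directed ones."""
    return [sum(row) for row in A]


# A small made-up graph: node 0 touches every other node.
edges = [(0, 1), (0, 2), (0, 3), (2, 3)]

A = adjacency_matrix(4, edges)
print(degrees(A))       # [3, 1, 2, 2]

A_dir = adjacency_matrix(4, edges, directed=True)
print(degrees(A_dir))   # [3, 0, 1, 0]
```

In the undirected case A is symmetric. In the directed case A[0][1] = 1 while A[1][0] = 0.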
In Erdős-Rényi random graphs, where every edge appears independently with the same probability, most nodes have degrees close to the average. But in real networks (the Web, Twitter, protein interactions) the degree distribution P(k) often follows a power law: P(k) ∝ k^(−α). Most nodes have very few connections, and a tiny number of hubs have enormous connectivity. This "scale-free" property has profound implications for how information spreads, how robust a network is, and which nodes to prioritize in learning.
A bipartite graph has two types of nodes, with edges only allowed between types (not within). Users and movies (Netflix), drugs and proteins (pharmacology), papers and authors (citations). Bipartite graphs are natural for recommendation problems and for representing typed knowledge.
Machine learning on graphs isn't one problem — it's three, operating at different scales. You can make predictions about individual nodes, about individual edges, or about whole graphs. Each level has its own applications, its own loss functions, and its own challenges.
The three levels also require different amounts of information. Node-level tasks only need local structure (who are my neighbors?). Edge-level tasks need pairs of nodes. Graph-level tasks need a global view of the entire graph. GNNs are designed to efficiently compute all three by passing messages through the graph.
Given a graph, label each node. Examples: classify each user in a social network as bot or human; classify each paper in a citation network by research area; predict which amino acids in a protein sequence have which structural role. The graph structure provides context — a user's neighbors reveal a lot about the user.
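The sketch below shows how much signal neighbors can carry. It is a simple baseline, not a GNN: it labels a node by majority vote over its already-labeled neighbors. It assumes homophily (connected nodes tend to share labels), and the graph and bot/human labels are made up.

```python
from collections import Counter


def neighbor_majority_vote(adj, labels, node):
    """Guess a label for `node` from the labels of its already-labeled neighbors.

    adj: dict mapping each node to a list of its neighbors.
    labels: dict mapping labeled nodes to their known labels.
    Returns None when no neighbor is labeled; ties go to the label counted first.
    """
    votes = Counter(labels[u] for u in adj[node] if u in labels)
    if not votes:
        return None
    return votes.most_common(1)[0][0]


# Hypothetical toy social network: user "u" has three labeled friends.
adj = {"u": ["a", "b", "c"], "a": ["u"], "b": ["u"], "c": ["u"]}
labels = {"a": "bot", "b": "bot", "c": "human"}
print(neighbor_majority_vote(adj, labels, "u"))  # bot
```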
Given a graph, predict which edges should exist but are missing (or which will exist in the future). Examples: recommend items to users (the "link" between user and item); predict drug-drug interactions; complete a knowledge graph. Formulated as: given nodes i and j, score the likelihood that edge (i,j) exists.
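One common way to score a candidate edge is to take the sigmoid of the dot product of the two nodes' embeddings. This sketch assumes each node already has a learned embedding vector. The embeddings below are hypothetical numbers, not outputs of a trained model.

```python
import math


def link_score(z_i, z_j):
    """Likelihood-style score for edge (i, j): sigmoid of the dot product z_i . z_j."""
    dot = sum(a * b for a, b in zip(z_i, z_j))
    return 1.0 / (1.0 + math.exp(-dot))


# Hypothetical 3-dimensional embeddings for one user and two candidate items.
z = {
    "user": [1.0, 0.5, 0.0],
    "item_a": [0.9, 0.4, 0.1],
    "item_b": [-0.8, 0.0, 0.7],
}
ranked = sorted(["item_a", "item_b"], key=lambda c: link_score(z["user"], z[c]), reverse=True)
print(ranked)  # ['item_a', 'item_b']
```

Recommending then amounts to ranking all candidate items by this score.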
Given a graph (or a set of graphs), classify the whole graph as one entity. Examples: classify a molecule as a drug candidate or not; classify a program's execution graph as buggy or correct; classify a network as a brain scan from a healthy or diseased patient. The graph itself is one datapoint.
The most dramatic example of node-level graph ML is AlphaFold 2. Proteins are chains of amino acids, and for 50 years one of biology's central challenges was: given the sequence, what is the 3D shape? Shape determines function, and misfolded proteins are implicated in diseases from Alzheimer's to Parkinson's.
AlphaFold solves this as a node-level prediction problem. Nodes are amino acids. Edges connect amino acids that are spatially close in 3D (or nearby in the sequence). The task: predict the (x, y, z) coordinate of each node. That's node regression — each node gets a continuous output instead of a discrete label.
Input: a sequence of amino acids (encoded as integer IDs) + a multiple sequence alignment (how this protein compares to evolutionary relatives). Output: (x, y, z) for every amino acid, plus a per-residue confidence score (pLDDT).
| Component | Type | Shape |
|---|---|---|
| Amino acid sequence | Node features | [L] integers (L = sequence length) |
| MSA profile | Node features | [L, 22] (22 amino acid types) |
| Proximity edges | Edges | Sparse [L, L] adjacency |
| Predicted coordinates | Node output | [L, 3] floats (x, y, z) |
| pLDDT confidence | Node output | [L] floats in [0, 100] |
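The proximity edges in the table can be sketched as a distance-cutoff graph. This assumes the 3D coordinates are already known, for example from an experimental structure used as training data. The coordinates and cutoffs below are hypothetical values in ångströms.

```python
import math


def proximity_edges(coords, cutoff):
    """Undirected edges (i, j), i < j, between residues closer than `cutoff`."""
    edges = []
    for i in range(len(coords)):
        for j in range(i + 1, len(coords)):
            if math.dist(coords[i], coords[j]) < cutoff:
                edges.append((i, j))
    return edges


# Hypothetical residue positions (angstroms) for a 4-residue fragment.
coords = [(0.0, 0.0, 0.0), (3.8, 0.0, 0.0), (7.6, 0.0, 0.0), (3.8, 6.0, 0.0)]
print(proximity_edges(coords, cutoff=5.0))       # [(0, 1), (1, 2)]
print(len(proximity_edges(coords, cutoff=8.0)))  # 6
```

Raising the cutoff adds longer-range contacts, so the graph gets denser. The double loop checks all L² pairs, and a spatial index would be the usual choice for long proteins.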
AlphaFold 2 (2021) achieved a median GDT score above 90 at CASP14, an accuracy competitive with experimental methods for most targets. Predicted structures have since been released for over 200 million proteins, covering nearly every catalogued protein sequence. This is what node-level graph ML looks like at its most impactful.
You're on Pinterest looking at a wedding cake. The system immediately suggests 12 more wedding cakes, 3 florists, and a calligraphy invite designer. How? Not because those items look similar in pixel space, but because millions of users pinned them to the same boards. The connections reveal similarity.
This is PinSage (Ying et al., KDD 2018), a graph-based recommender system running at Pinterest scale: 3 billion nodes and 18 billion edges. The graph is bipartite, connecting pins (images) to the boards they are saved to, so two pins on the same board are implicitly linked. The task: given a query pin, find the pins whose embeddings z_i lie closest to the query pin's embedding.
The embeddings z_i are learned by a GNN that combines each pin's own content features (image and text) with information aggregated from its neighborhood in the graph. The graph structure is what teaches the model which pins users consider "similar."
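The implicit pin-to-pin link through shared boards can be made explicit with a one-mode projection of the bipartite graph. Each pin pair is weighted by the number of boards the two pins share. This illustrates the idea only and is not PinSage's implementation. Its cost also grows quadratically with board size. The pins and boards are made up.

```python
from collections import defaultdict
from itertools import combinations


def project_pins(pin_board_edges):
    """Project a pin-board bipartite graph onto pins.
    Two pins are linked with weight = number of boards they share."""
    pins_on_board = defaultdict(set)
    for pin, board in pin_board_edges:
        pins_on_board[board].add(pin)
    weights = defaultdict(int)
    for pins in pins_on_board.values():
        # Sort so each unordered pair gets one canonical key (p, q) with p < q.
        for p, q in combinations(sorted(pins), 2):
            weights[(p, q)] += 1
    return dict(weights)


# Hypothetical (pin, board) saves.
edges = [
    ("cake1", "wedding"), ("cake2", "wedding"), ("florist", "wedding"),
    ("cake1", "desserts"), ("cake2", "desserts"),
]
print(project_pins(edges))
```

Here the two cakes share two boards, so their link is stronger than either cake's link to the florist.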
A second edge-level problem: given two drugs, predict which adverse side effects occur when they are taken together (Zitnik et al., 2018). The graph has two types of nodes, drugs and proteins, and three types of edges: protein-protein interactions, drug-protein target interactions (from pharmacology databases), and drug-drug edges labeled with the side effects known to arise when the pair is combined. The task is to predict new drug-drug edges and label them with specific side effects.
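A single drug pair can receive a different score for each side effect. The sketch below shows this with a simplified relation-specific decoder: a diagonal bilinear (DistMult-style) score passed through a sigmoid. This is a swapped-in stand-in, not the actual decoder from Zitnik et al. All embeddings and relation vectors are hypothetical.

```python
import math


def side_effect_score(z_i, z_j, r_diag):
    """sigmoid(sum_k z_i[k] * r_diag[k] * z_j[k]): a diagonal bilinear (DistMult-style)
    score for 'drugs i and j, taken together, cause side effect r'."""
    s = sum(a * r * b for a, r, b in zip(z_i, r_diag, z_j))
    return 1.0 / (1.0 + math.exp(-s))


# Hypothetical learned embeddings for two drugs and two side-effect relations.
z_a, z_b = [1.0, 0.0], [1.0, 1.0]
relations = {"nausea": [2.0, 0.0], "rash": [-2.0, 0.0]}

for name, r in relations.items():
    print(name, round(side_effect_score(z_a, z_b, r), 3))
```

Each side effect gets its own parameter vector, so one drug pair can score high for one side effect and low for another. The diagonal form is symmetric in i and j, which fits drug combinations.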
In 2020, researchers at MIT and the Broad Institute trained a graph neural network on roughly 2,500 molecules labeled by whether they inhibit the growth of E. coli. They then asked: of the roughly 6,000 compounds in a drug-repurposing library, which ones might work against drug-resistant bacteria? The model predicted halicin. Lab tests confirmed broad-spectrum activity, including against M. tuberculosis and pan-resistant Acinetobacter baumannii. (Stokes et al., Cell, 2020.)
This is graph-level prediction. Each molecule is one graph. Nodes are atoms; edges are chemical bonds. The model reads the whole graph and outputs a single number: antibiotic activity. The challenge is converting a variable-size graph into a fixed-size vector — a problem called graph pooling.
The simplest approach is mean pooling. Compute a feature vector h_v for each node (via GNN message passing), then average these vectors across all nodes. The resulting vector h_G = mean({h_v : v ∈ V}) is a fixed-size representation of the whole graph, regardless of how many nodes it has.
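Mean pooling takes only a few lines of code. This sketch assumes the same convention as the `batch` vector passed to PyG's `global_mean_pool` in the classifier below: `batch[v]` gives the index of the graph that node v belongs to, so several graphs can be pooled at once. The function name and data are illustrative.

```python
def mean_pool(h, batch, num_graphs):
    """Average node vectors per graph; batch[v] is the index of the graph node v belongs to."""
    dim = len(h[0])
    sums = [[0.0] * dim for _ in range(num_graphs)]
    counts = [0] * num_graphs
    for vec, g in zip(h, batch):
        counts[g] += 1
        for k, x in enumerate(vec):
            sums[g][k] += x
    # Assumes every graph has at least one node (otherwise counts[g] would be 0).
    return [[s / counts[g] for s in sums[g]] for g in range(num_graphs)]


# Two graphs in one batch: nodes 0-2 belong to graph 0, nodes 3-4 to graph 1.
h = [[1.0, 0.0], [3.0, 2.0], [5.0, 5.0], [0.0, 1.0], [2.0, 3.0]]
batch = [0, 0, 0, 1, 1]
print(mean_pool(h, batch, num_graphs=2))
```

Reordering the nodes within a graph leaves its pooled vector unchanged, which is exactly the permutation invariance a graph-level readout needs.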
```python
# Graph-level prediction with PyTorch Geometric
import torch
from torch_geometric.nn import GCNConv, global_mean_pool
from torch.nn import Linear


class MoleculeClassifier(torch.nn.Module):
    def __init__(self, num_features, hidden_dim, num_classes):
        super().__init__()
        self.conv1 = GCNConv(num_features, hidden_dim)
        self.conv2 = GCNConv(hidden_dim, hidden_dim)
        self.linear = Linear(hidden_dim, num_classes)

    def forward(self, x, edge_index, batch):
        # x: [num_nodes, num_features], atom features
        # edge_index: [2, num_edges], bond connections
        # batch: [num_nodes], which graph each node belongs to
        h = self.conv1(x, edge_index).relu()  # [num_nodes, hidden]
        h = self.conv2(h, edge_index).relu()  # [num_nodes, hidden]
        # Global mean pool: average over all nodes in each graph
        h_G = global_mean_pool(h, batch)  # [num_graphs, hidden]
        return self.linear(h_G)  # [num_graphs, num_classes]
```
Before GNNs learned features automatically, researchers hand-engineered them. Understanding these hand-crafted features builds intuition for what GNNs later learn to compute automatically. Each feature captures a different aspect of a node's role in the graph.
The simplest feature is the degree k_v, the number of edges incident to v. High degree means hub. But degree alone is limited: a node with 5 connections in a tight clique plays a very different role than a node with 5 connections bridging separate communities, even though their degrees are identical.
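The clustering coefficient C_v measures how interconnected v's neighbors are: what fraction of all possible edges among v's neighbors actually exist? If e_v is the number of edges among the k_v neighbors, then C_v = e_v / (k_v(k_v − 1)/2), because k_v neighbors admit k_v(k_v − 1)/2 possible edges.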
C_v = 1: all neighbors are connected to each other (a tight clique). C_v = 0: no two neighbors are connected, so v may act as a bridge between otherwise separate nodes. In social networks, the clustering coefficient helps separate community members (high C_v) from connectors between communities (low C_v).
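Here is a direct translation of C_v = e_v / (k_v(k_v − 1)/2) into Python. It assumes the graph is stored as a dict of neighbor sets and uses the common convention that C_v = 0 when v has fewer than two neighbors. The example graphs are made up.

```python
from itertools import combinations


def clustering_coefficient(adj, v):
    """Fraction of neighbor pairs of v that are themselves connected."""
    nbrs = adj[v]
    k = len(nbrs)
    if k < 2:
        return 0.0  # no neighbor pairs to check
    e_v = sum(1 for a, b in combinations(nbrs, 2) if b in adj[a])
    return e_v / (k * (k - 1) / 2)


# Node 0's neighbors 1, 2, 3 share two of the three possible edges (1-2 and 2-3).
adj = {0: {1, 2, 3}, 1: {0, 2}, 2: {0, 1, 3}, 3: {0, 2}}
# A star: the center's neighbors are not connected to each other.
star = {0: {1, 2, 3}, 1: {0}, 2: {0}, 3: {0}}

print(clustering_coefficient(adj, 0))   # 2 of 3 possible edges
print(clustering_coefficient(star, 0))  # 0.0
```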
Centrality answers: "how important is this node?" Four definitions, each capturing something different:
| Measure | Definition | High value means |
|---|---|---|
| Degree centrality | k_v / (n−1) | Many direct connections |
| Betweenness | Fraction of shortest paths passing through v | Bridge between groups |
| Closeness | 1 / avg. distance to all other nodes | Close to everyone |
| Eigenvector | Importance from important neighbors | Connected to hubs |
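Degree and closeness centrality need only breadth-first search on an unweighted graph, as in the sketch below. One assumption to flag: for disconnected graphs it averages distances over reachable nodes only, which is one of several conventions. Betweenness requires counting all shortest paths (for example with Brandes' algorithm) and is omitted here for brevity.

```python
from collections import deque


def bfs_distances(adj, source):
    """Shortest-path hop counts from `source` to every reachable node."""
    dist = {source: 0}
    queue = deque([source])
    while queue:
        u = queue.popleft()
        for w in adj[u]:
            if w not in dist:
                dist[w] = dist[u] + 1
                queue.append(w)
    return dist


def degree_centrality(adj, v):
    """k_v / (n - 1)."""
    return len(adj[v]) / (len(adj) - 1)


def closeness_centrality(adj, v):
    """1 / average distance from v to the other nodes it can reach (0 if none)."""
    dist = bfs_distances(adj, v)
    others = [d for u, d in dist.items() if u != v]
    return len(others) / sum(others) if others else 0.0


# A 5-node path 0-1-2-3-4: the middle node is closest to everyone.
path = {0: [1], 1: [0, 2], 2: [1, 3], 3: [2, 4], 4: [3]}
for v in path:
    print(v, degree_centrality(path, v), round(closeness_centrality(path, v), 3))
```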
Node features describe individual nodes. But to compare entire graphs ("is this molecular graph similar to that one?") we need features that describe the whole structure. The key challenge: relabeling a graph's nodes (permuting their order) does not change the graph, so graph features must be permutation-invariant, giving the same output for every node ordering.
A graphlet is a small connected subgraph. On 3 nodes there are 2 graphlets: a path and a triangle. On 4 nodes there are 6. The graphlet degree vector (GDV) of a node counts how many times the node appears in each graphlet and distinguishes the node's position (orbit) within it, for example the middle versus an end of a path. This gives a richer fingerprint than simple degree.
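Here is a small sketch of a graphlet degree vector restricted to graphlets of at most 3 nodes. These graphlets have four orbits: an edge endpoint, the end of a path, the middle of a path, and a triangle corner. The full GDV in the literature uses graphlets of up to 5 nodes (73 orbits). This truncated version and its function name are illustrative.

```python
from itertools import combinations


def gdv_up_to_3(adj, v):
    """Orbit counts for v: [edge, end of induced 3-path, middle of induced 3-path, triangle].
    adj: dict node -> set of neighbors."""
    nbrs = adj[v]
    # Triangle: a pair of v's neighbors that are linked to each other.
    triangles = sum(1 for a, b in combinations(nbrs, 2) if b in adj[a])
    # Middle of a path: a pair of v's neighbors that are NOT linked.
    path_middle = len(nbrs) * (len(nbrs) - 1) // 2 - triangles
    # End of a path v-u-w: w is a neighbor of u but not v itself nor adjacent to v.
    path_end = sum(1 for u in nbrs for w in adj[u] if w != v and w not in nbrs)
    return [len(nbrs), path_end, path_middle, triangles]


# Triangle 0-1-2 with a pendant node 3 attached to node 2.
adj = {0: {1, 2}, 1: {0, 2}, 2: {0, 1, 3}, 3: {2}}
for v in adj:
    print(v, gdv_up_to_3(adj, v))
```

Nodes 0 and 3 differ only slightly in degree, but their vectors differ sharply: node 0 sits in a triangle, while node 3 only ever appears at the end of paths.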
At the whole-graph level, the graphlet kernel counts how many times each graphlet type appears in the entire graph, producing a histogram. (Here graphlets need not be connected: there are 4 on 3 nodes and 11 on 4 nodes.) Two molecules with similar graphlet histograms have similar local structural patterns, and thus plausibly similar properties.
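For 3-node graphlets, the number of edges (0 to 3) identifies the graphlet up to isomorphism, so a graphlet histogram can be computed by brute force over all node triples. Normalizing each histogram and taking a dot product is one common kernel choice. The O(n³) enumeration in this sketch is only practical for small graphs, and the function names are illustrative.

```python
from itertools import combinations


def graphlet_histogram_3(nodes, edges):
    """Count all 3-node induced subgraphs by edge count:
    index 0 = empty, 1 = single edge, 2 = path, 3 = triangle."""
    edge_set = {frozenset(e) for e in edges}
    hist = [0, 0, 0, 0]
    for a, b, c in combinations(nodes, 3):
        m = sum(frozenset(p) in edge_set for p in ((a, b), (a, c), (b, c)))
        hist[m] += 1
    return hist


def graphlet_kernel(h1, h2):
    """Dot product of the two histograms after normalizing each to sum to 1."""
    n1, n2 = sum(h1), sum(h2)
    return sum((x / n1) * (y / n2) for x, y in zip(h1, h2))


nodes = [0, 1, 2, 3]
triangle_plus_tail = [(0, 1), (0, 2), (1, 2), (2, 3)]
path = [(0, 1), (1, 2), (2, 3)]
h_tri = graphlet_histogram_3(nodes, triangle_plus_tail)  # [0, 1, 2, 1]
h_path = graphlet_histogram_3(nodes, path)               # [0, 2, 2, 0]
print(h_tri, h_path, graphlet_kernel(h_tri, h_path))
```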
The Weisfeiler-Lehman (WL) kernel is an iterative algorithm. In each of K rounds, it hashes every node's label together with the multiset of its neighbors' labels into a new label. It then compares histograms of the resulting labels between graphs. It's fast, and, crucially, message-passing GNNs can distinguish at most the graphs that the WL test distinguishes, with architectures such as GIN matching it. Understanding WL means understanding what GNNs can and cannot distinguish.
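Here is a compact sketch of WL color refinement for unlabeled graphs, where every node starts with the same label. Instead of a hash function, it uses the nested tuple (own label, sorted neighbor labels) as the new label. This keeps labels comparable across graphs without a shared lookup table, whereas practical implementations compress labels instead. The classic example of two triangles versus a 6-cycle is a pair that 1-WL cannot tell apart.

```python
from collections import Counter


def wl_histogram(adj, iterations, initial_label="C"):
    """Color refinement: each round replaces a node's label with
    (own label, sorted tuple of neighbor labels). Returns a Counter over
    the labels from all rounds, including the initial one."""
    labels = {v: initial_label for v in adj}
    hist = Counter(labels.values())
    for _ in range(iterations):
        labels = {v: (labels[v], tuple(sorted(labels[u] for u in adj[v]))) for v in adj}
        hist.update(labels.values())
    return hist


def wl_kernel(h1, h2):
    """Dot product of two label histograms (missing labels count as 0)."""
    return sum(count * h2[label] for label, count in h1.items())


two_triangles = {0: [1, 2], 1: [0, 2], 2: [0, 1], 3: [4, 5], 4: [3, 5], 5: [3, 4]}
hexagon = {i: [(i - 1) % 6, (i + 1) % 6] for i in range(6)}
path3 = {0: [1], 1: [0, 2], 2: [1]}
triangle = {0: [1, 2], 1: [0, 2], 2: [0, 1]}

print(wl_histogram(two_triangles, 3) == wl_histogram(hexagon, 3))  # True
print(wl_histogram(path3, 2) == wl_histogram(triangle, 2))         # False
```

Every node in both the two triangles and the 6-cycle has exactly two neighbors, so refinement never separates them. This is the same blind spot that message-passing GNNs inherit.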
The simplest graph-level feature: count how many nodes have each degree and store the counts as a histogram vector. Two random Erdős-Rényi graphs of the same size and density have similar histograms, concentrated around the average degree, while a social network's histogram has a long power-law tail. The degree histogram captures this basic structural signature cheaply.
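The sketch below compares the degree histograms of two synthetic graphs with about the same average degree. One is an Erdős-Rényi graph. The other is a tree grown by preferential attachment, a Barabási-Albert-style process with one edge per new node. Both generators are simplified versions written for this example; in practice a library such as NetworkX provides them.

```python
import random


def degree_histogram(adj):
    """hist[k] = number of nodes with degree k."""
    hist = [0] * (max(len(nbrs) for nbrs in adj.values()) + 1)
    for nbrs in adj.values():
        hist[len(nbrs)] += 1
    return hist


def erdos_renyi(n, p, seed=0):
    """Each of the n(n-1)/2 possible edges appears independently with probability p."""
    rng = random.Random(seed)
    adj = {v: set() for v in range(n)}
    for i in range(n):
        for j in range(i + 1, n):
            if rng.random() < p:
                adj[i].add(j)
                adj[j].add(i)
    return adj


def preferential_attachment_tree(n, seed=0):
    """Each new node links to one existing node chosen with probability proportional
    to its degree (sampling uniformly from a list of edge endpoints does exactly that)."""
    rng = random.Random(seed)
    adj = {0: {1}, 1: {0}}
    endpoints = [0, 1]  # each node appears once per incident edge
    for v in range(2, n):
        u = rng.choice(endpoints)
        adj[v] = {u}
        adj[u].add(v)
        endpoints += [u, v]
    return adj


n = 200
hist_er = degree_histogram(erdos_renyi(n, p=2 / n))          # average degree about 2
hist_pa = degree_histogram(preferential_attachment_tree(n))  # average degree just under 2
print("ER:", hist_er)
print("PA:", hist_pa)
```

With comparable average degree, the preferential-attachment histogram typically stretches much further to the right, reflecting a few high-degree hubs.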
Now that we understand the three task levels and the kinds of features graphs have, let's look at the two paradigms for machine learning on graphs. They represent the "before" and "after" of the graph ML revolution.
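Before GNNs, the workflow was: hand-design features for each node, edge, or graph; compute them for all training examples; train a standard model such as a random forest, SVM, or feed-forward network on those features; then compute the same features for new data and apply the model.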
The problem: feature engineering is manual, domain-specific, and expensive. Features that work for social networks fail for molecules. Features that work for node classification fail for link prediction. You need a human expert for every new domain.
The key insight of modern graph ML: learn the features automatically from the graph structure. A GNN replaces the hand-design step entirely.
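At each layer l, a GNN updates each node's representation by aggregating the representations of its neighbors N(v) and combining the result with the node's own current representation:

$$h_v^{(l+1)} = \text{UPDATE}\left(h_v^{(l)},\ \text{AGGREGATE}\left(\{h_u^{(l)} : u \in N(v)\}\right)\right)$$

This single formula, with different choices of AGGREGATE (mean, max, sum, attention) and UPDATE, gives rise to GCN, GraphSAGE, GAT, and GIN, the four foundational GNN architectures that Lectures 6–8 of CS224W cover in depth.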
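Here is a minimal sketch of one message-passing layer. AGGREGATE is the elementwise mean over neighbors. UPDATE adds a linearly transformed self term and neighbor term, then applies a ReLU, roughly in the style of GraphSAGE. The weight matrices are hypothetical fixed values; in a real GNN they are learned.

```python
def matvec(W, x):
    """Multiply matrix W (list of rows) by vector x."""
    return [sum(w * xi for w, xi in zip(row, x)) for row in W]


def relu(x):
    return [max(0.0, xi) for xi in x]


def message_passing_layer(adj, h, W_self, W_neigh):
    """h_v' = ReLU(W_self h_v + W_neigh * mean({h_u : u in N(v)})).

    adj: dict node -> list of neighbors; h: dict node -> feature vector.
    """
    dim = len(next(iter(h.values())))
    new_h = {}
    for v, nbrs in adj.items():
        if nbrs:
            # AGGREGATE: elementwise mean of the neighbors' current vectors.
            agg = [sum(h[u][k] for u in nbrs) / len(nbrs) for k in range(dim)]
        else:
            agg = [0.0] * dim  # an isolated node receives no messages
        # UPDATE: combine the node's own vector with the aggregated message.
        combined = [a + b for a, b in zip(matvec(W_self, h[v]), matvec(W_neigh, agg))]
        new_h[v] = relu(combined)
    return new_h


# Path 0-1-2 with 2-dimensional features and hypothetical identity weights.
adj = {0: [1], 1: [0, 2], 2: [1]}
h = {0: [1.0, 0.0], 1: [0.0, 1.0], 2: [1.0, 1.0]}
identity = [[1.0, 0.0], [0.0, 1.0]]
print(message_passing_layer(adj, h, W_self=identity, W_neigh=identity))
```

Swapping the mean for a sum or max changes the aggregator without touching the rest of the layer. Stacking K such layers lets each node see its K-hop neighborhood.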
Lecture 1 of CS224W is an invitation. You've seen the three task levels, the canonical applications, and the conceptual shift from hand-designed features to learned GNN representations. Every subsequent lecture in CS224W builds directly on these foundations.
| Lectures | Topic | Key Methods |
|---|---|---|
| 2–3 | Node Embeddings | DeepWalk, node2vec, LINE |
| 4–5 | Link Analysis | PageRank, HITS, SimRank |
| 6–8 | GNN Foundations | GCN, GraphSAGE, GAT, GIN |
| 9–10 | GNN Theory | Expressive power, WL test, over-smoothing |
| 11–12 | Knowledge Graphs | TransE, RotatE, KGNN |
| 13–14 | Scalable GNNs | GraphSAINT, cluster-GCN, neighbor sampling |
| 15–16 | Graph Transformers | GPS, Graphormer |
| 17–19 | Generative Models | GraphRNN, GDSS, graph diffusion |
PyTorch Geometric (PyG) is the standard library for GNN research. It provides efficient sparse graph operations, pre-built GNN layers (GCNConv, SAGEConv, GATConv, GINConv), and datasets (Cora, CiteSeer, OGB). Install with:
```bash
# Install PyTorch Geometric (after installing PyTorch)
pip install torch_geometric
# The Open Graph Benchmark: standard evaluation datasets
pip install ogb
```

Then, in Python, load the Cora citation network (a classic node classification benchmark):

```python
from torch_geometric.datasets import Planetoid

dataset = Planetoid(root='/tmp/Cora', name='Cora')
data = dataset[0]
# data.x: [2708, 1433], node features (bag-of-words)
# data.edge_index: [2, 10556], graph connectivity (COO format)
# data.y: [2708], node labels (7 classes, one per research area)
```
After Lecture 1, you can:
"The history of graph theory is the history of people discovering that everything is connected to everything else — and then trying to figure out what to do about it."
— paraphrased from Jure Leskovec, CS224W Lecture 1