CS224W Lecture 18

Deep Generative Models for Graphs

You can't just enumerate all valid drug molecules; there are about 10^60 of them. Generative models learn the distribution of real graphs, then sample new ones with desired properties: potent, soluble, synthesizable.

Prerequisites: GNN basics + VAE/RNN intuition. That's it.
10 Chapters · 6+ Simulations · 0 Assumed Knowledge

Chapter 0: Why Generate Graphs?

Bringing one drug to market costs roughly $2.6 billion and takes about 12 years. The bottleneck isn't synthesis or trials; it's finding the right molecule to test. Chemical space contains an estimated 10^60 drug-like molecules. We can't enumerate them. We can't randomly sample them either: only about 1 in 10^14 random molecules is drug-like at all. We need to generate molecules with desired properties on demand.

This is the core problem: given a distribution over "good" graphs (molecules, social networks, protein interaction networks), learn a model that can sample new graphs from that distribution — and ideally, sample graphs with a specific target property (high binding affinity, low toxicity, synthesizable).

Why Graphs Are Hard to Generate

Generating images is hard but well-studied: pixels form a 2D grid, output is a fixed-size matrix, order is canonical (row by row). Generating graphs is harder on every axis:

Variable size: A molecule might have 5 atoms or 50. A social network might have 10 nodes or 10 million. Unlike images (always 64×64), graphs have no fixed output size.
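No canonical order: For a graph with n nodes, there are n! possible orderings of those nodes, and different orderings can give different adjacency matrices for the same graph. A molecule is the same molecule regardless of which atom you number "1." The model should therefore treat all orderings of a graph consistently (permutation invariance), or commit to one fixed ordering scheme.
Structural validity: Any symmetric 0/1 matrix describes some graph, but not every graph is a valid member of the target domain. For molecules, validity is strict: every atom must satisfy its valence rules (neutral carbon forms four bonds counting implicit hydrogens, neutral nitrogen typically three). The model must learn these constraints.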
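To make the ordering problem concrete, here is a small sketch (the toy path graph and the helper name `adjacency` are illustrative assumptions, not part of the lecture). It enumerates every node ordering of one fixed graph and counts how many distinct adjacency matrices result.

```python
from itertools import permutations

def adjacency(edges, order):
    """Adjacency matrix (as a tuple of tuples) of an undirected graph under a node ordering."""
    index = {node: i for i, node in enumerate(order)}
    n = len(order)
    A = [[0] * n for _ in range(n)]
    for u, v in edges:
        A[index[u]][index[v]] = 1
        A[index[v]][index[u]] = 1
    return tuple(tuple(row) for row in A)

# Toy graph: the path a - b - c - d
nodes = ["a", "b", "c", "d"]
edges = [("a", "b"), ("b", "c"), ("c", "d")]

orderings = list(permutations(nodes))
matrices = {adjacency(edges, order) for order in orderings}
print(len(orderings))  # 24 orderings (4!)
print(len(matrices))   # 12 distinct matrices, all describing the same graph
```

The path has one nontrivial symmetry (reversal), so the 24 orderings collapse to 12 matrices. A generator that works on raw adjacency matrices has to cope with all of them representing one graph.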
The Challenge: Valid Molecule Space is Tiny

Visualizing the chemical space. Random sampling almost never hits valid drug-like molecules. Generative models learn to sample from the valid, drug-like region directly.

Why can't we just randomly sample adjacency matrices to generate new molecules?

Chapter 1: Sequential Generation

The most natural approach: build the graph one piece at a time. Add a node. Add its edges to existing nodes. Add another node. Repeat until done. This converts graph generation into a sequence generation problem — and we have powerful models (RNNs, Transformers) for sequences.

The key design choice is the generation order: in what order do we add nodes and edges? Unlike images (natural row-by-row order), graphs have no inherent order. We must choose one — and the choice affects how complex the dependencies become.

Three Generation Strategies

Node by node + all edges
Add node v. Then decide, for each existing node u, whether edge (u,v) exists. Complexity: O(n) decisions per node, O(n²) total. Full graph connectivity visible at each step.
↓ or
Edge by edge (line graph)
At each step, add one edge (and its endpoints if new). Complexity: O(m) steps for a graph with m edges. Better for sparse graphs where m << n².
↓ or
Substructure by substructure
Add entire rings or functional groups (benzene ring = 6 nodes + 6 edges at once). Leverages domain knowledge about valid substructures. The basis of Junction Tree VAE.
The dependency problem: When adding edge (u, v), the probability depends on the types of nodes u and v, their existing connections, and the overall graph structure so far. This is a complicated conditional distribution — the whole reason simple rules don't work.
Sequential graph generation (pseudocode)
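import random

def generate_graph(model):
    # `model` and `Graph` are placeholders for a learned generator and a graph container
    G = Graph()                # start with an empty graph
    h = model.initial_state()  # RNN hidden state

    while not model.stop(h):
        # Sample the type of the next node
        node_type = model.sample_node_type(h)
        existing = list(G.nodes)  # snapshot first, so v is never paired with itself
        v = G.add_node(node_type)

        # For each existing node, decide whether edge (u, v) exists
        for u in existing:
            p_edge = model.edge_prob(h, u, v)
            if random.random() < p_edge:
                G.add_edge(u, v)

        h = model.update(h, G)  # fold the new node and its edges into the state

    return G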
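A runnable toy version of this loop can make the quadratic cost visible. The sketch below is only an illustration: `RandomModel` is a hypothetical stand-in that stops at a fixed size and uses a constant edge probability, where a learned model would condition on the graph built so far.

```python
import random

class RandomModel:
    """Hypothetical stand-in for a learned generator: fixed size, constant edge probability."""
    def __init__(self, num_nodes, p_edge, seed=0):
        self.num_nodes = num_nodes
        self.p_edge = p_edge
        self.rng = random.Random(seed)

    def stop(self, nodes):
        return len(nodes) >= self.num_nodes

    def edge_prob(self, u, v):
        return self.p_edge  # a real model would look at the partial graph here

def generate(model):
    nodes, edges, decisions = [], [], 0
    while not model.stop(nodes):
        v = len(nodes)
        for u in nodes:  # one binary decision per existing node
            decisions += 1
            if model.rng.random() < model.edge_prob(u, v):
                edges.append((u, v))
        nodes.append(v)
    return nodes, edges, decisions

nodes, edges, decisions = generate(RandomModel(num_nodes=10, p_edge=0.3))
print(len(nodes), decisions)  # 10 nodes, 45 edge decisions = 10 * 9 / 2
```

Adding the k-th node costs k−1 decisions, so n nodes cost n(n−1)/2 in total, regardless of how many edges are actually sampled.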
Generation Order Comparison

Watch how different strategies add nodes and edges. Notice how node-by-node requires quadratic edge decisions while edge-by-edge is linear in edges.

In node-by-node sequential graph generation, why does adding the k-th node require O(k) decisions?

Chapter 2: GraphRNN

You et al. (2018) built the first scalable deep graph generator. The insight: use a node-level RNN to generate one node at a time, and for each new node, use an edge-level RNN to generate its connections to all previous nodes. Two RNNs — one for the node sequence, one for each node's edge row — nested inside each other.

The Two-RNN Architecture

Node RNN (outer): Maintains a hidden state h that summarizes the graph generated so far. At each step, it outputs: (1) whether to stop, (2) an initial state for the edge RNN. Its hidden state is updated after each new node is added.
Edge RNN (inner): Given the node RNN's state, generates a binary sequence: edge to node 1? edge to node 2? ... edge to node (k−1)? One binary output per existing node, in reverse BFS order (closest nodes first). Stops after M steps (bandwidth limit).
The BFS ordering trick: GraphRNN doesn't use arbitrary node ordering; it uses BFS order. This has a critical advantage: in BFS order, each new node can only connect back to nodes that are still on the BFS frontier, so it is enough to look at the last M nodes, where M bounds the frontier size. Edges that reach further back essentially never occur. This reduces the edge RNN's sequence length from O(n) to O(M), a huge win for large graphs.
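The effect of BFS ordering on the look-back window can be checked on a toy graph. The following sketch assumes a 40-node ladder graph and two hypothetical helpers, `bfs_order` and `max_lookback`; it compares how far back an edge can reach under a BFS ordering versus a shuffled one.

```python
import random
from collections import deque

def bfs_order(adj, start):
    """Nodes in breadth-first order from `start` (neighbors visited in sorted order)."""
    order, seen, queue = [], {start}, deque([start])
    while queue:
        u = queue.popleft()
        order.append(u)
        for w in sorted(adj[u]):
            if w not in seen:
                seen.add(w)
                queue.append(w)
    return order

def max_lookback(adj, order):
    """Largest position gap between the two endpoints of any edge."""
    pos = {node: i for i, node in enumerate(order)}
    return max(abs(pos[u] - pos[v]) for u in adj for v in adj[u])

# Toy graph: a ladder with 20 rungs (40 nodes)
n = 20
adj = {i: set() for i in range(2 * n)}

def add_edge(u, v):
    adj[u].add(v)
    adj[v].add(u)

for i in range(n):
    add_edge(i, n + i)              # rung
    if i + 1 < n:
        add_edge(i, i + 1)          # top rail
        add_edge(n + i, n + i + 1)  # bottom rail

bfs = bfs_order(adj, 0)
shuffled = list(adj)
random.Random(0).shuffle(shuffled)
print(max_lookback(adj, bfs))       # 3: every edge stays within a tiny window
print(max_lookback(adj, shuffled))  # much larger under an arbitrary ordering
```

With the BFS ordering, an edge RNN would need only 3 steps per node on this graph instead of up to 39.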

Training: Teacher Forcing

During training, the model sees real graphs. At each step, the ground-truth adjacency row is fed as the "observed sequence" for the edge RNN. The loss is binary cross-entropy between predicted edge probabilities and ground-truth edges. At generation time, the model samples from its own predictions — a process called autoregressive sampling.

$$L = -\sum_{t=1}^{n} \sum_{i=1}^{M} \left[ a_{t,i} \log p_{t,i} + (1 - a_{t,i}) \log(1 - p_{t,i}) \right]$$

Where $a_{t,i} \in \{0,1\}$ is the ground-truth edge between node $t$ and node $t-i$, and $p_{t,i}$ is the predicted probability. This is exactly binary cross-entropy summed over all (node, position) pairs.
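As a minimal sketch of this loss, the snippet below sums binary cross-entropy over a few toy adjacency rows. The function name `edge_bce`, the clipping constant, and the example numbers are assumptions for illustration.

```python
import math

def edge_bce(targets, probs, eps=1e-12):
    """Binary cross-entropy summed over all (node, position) pairs."""
    total = 0.0
    for a_row, p_row in zip(targets, probs):
        for a, p in zip(a_row, p_row):
            p = min(max(p, eps), 1.0 - eps)  # guard against log(0)
            total -= a * math.log(p) + (1 - a) * math.log(1.0 - p)
    return total

# Three nodes, look-back window M = 2: ground-truth edge bits per node
targets = [[1, 0], [1, 1], [0, 1]]
confident = [[0.9, 0.1], [0.8, 0.9], [0.2, 0.7]]
coin_flip = [[0.5, 0.5]] * 3

print(edge_bce(targets, confident))
print(edge_bce(targets, coin_flip))  # 6 * ln 2, the loss of an uninformed model
```

A model that always predicts 0.5 pays ln 2 per decision, which gives a useful baseline when reading training curves.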

SHOWCASE: GraphRNN — Watch a Graph Grow

Click "Add Node" to add one node at a time. For each new node, the edge RNN decides which previous nodes to connect. Watch the graph build up step by step. Use the density slider to control connectivity.

GraphRNN uses BFS ordering of nodes. What is the key computational advantage this provides over arbitrary node ordering?

Chapter 3: Evaluation

How do you know if your generated graphs are any good? You can't compare them graph-by-graph — the generated graphs are new, not reconstructions of training graphs. You need to compare distributions of graphs. This is the graph evaluation problem, and it's surprisingly hard.

Graph Statistics

The standard approach: compute a set of graph-level statistics on a large sample of generated graphs, and compare their distributions to the same statistics on real graphs. If the distributions match across all of these statistics, that is strong evidence (though not proof) that the generative model has learned the real distribution.

Which statistics matter? Degree distribution (how many nodes have degree k?), clustering coefficient (how many triangles?), orbit counts (subgraph frequency), path length distribution. Each captures a different structural property. A good generative model matches ALL of them.
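Two of these statistics are easy to compute by hand. The sketch below uses a hypothetical four-node graph stored as an adjacency dictionary; `degree_histogram` and `average_clustering` are illustrative helpers, not a specific library's API.

```python
from collections import Counter
from itertools import combinations

def degree_histogram(adj):
    """hist[k] = number of nodes with degree k."""
    counts = Counter(len(nbrs) for nbrs in adj.values())
    return [counts.get(k, 0) for k in range(max(counts) + 1)]

def average_clustering(adj):
    """Mean over nodes of (linked neighbor pairs) / (possible neighbor pairs)."""
    coeffs = []
    for nbrs in adj.values():
        k = len(nbrs)
        if k < 2:
            coeffs.append(0.0)
            continue
        links = sum(1 for v, w in combinations(nbrs, 2) if w in adj[v])
        coeffs.append(2.0 * links / (k * (k - 1)))
    return sum(coeffs) / len(coeffs)

# Triangle 0-1-2 with a pendant node 3 attached to node 2
adj = {0: {1, 2}, 1: {0, 2}, 2: {0, 1, 3}, 3: {2}}
print(degree_histogram(adj))    # [0, 1, 2, 1]
print(average_clustering(adj))  # (1 + 1 + 1/3 + 0) / 4
```

Each graph in a sample yields one such histogram or scalar; the evaluation then compares the collection from generated graphs with the collection from real graphs.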

Maximum Mean Discrepancy (MMD)

MMD measures the distance between two distributions by comparing their mean embeddings in a kernel feature space. For graph evaluation, you compute a graph statistic (e.g., the degree distribution) for each generated graph, embed these statistics in a kernel space, and compute MMD against the real graph statistics.

$$\mathrm{MMD}^2(P, Q) = \mathbb{E}_{x,x' \sim P}[k(x,x')] - 2\,\mathbb{E}_{x \sim P,\, y \sim Q}[k(x,y)] + \mathbb{E}_{y,y' \sim Q}[k(y,y')]$$

Where k is a kernel function (Gaussian RBF is common) and P, Q are the real/generated distributions over graph statistics. For a characteristic kernel such as the Gaussian RBF, MMD = 0 exactly when the two distributions are identical; a larger MMD means a larger distributional gap.
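The MMD formula can be estimated directly from two samples. This is a sketch under stated assumptions: it uses the simple biased estimator with a Gaussian RBF kernel on normalized degree histograms, and the histograms themselves are made-up numbers. Published evaluations often use other kernels, so treat the kernel choice as illustrative.

```python
import math

def rbf(x, y, sigma=1.0):
    """Gaussian RBF kernel between two equal-length vectors."""
    sq_dist = sum((a - b) ** 2 for a, b in zip(x, y))
    return math.exp(-sq_dist / (2.0 * sigma ** 2))

def mmd2(xs, ys, kernel=rbf):
    """Biased (V-statistic) estimate of squared MMD between two samples."""
    k_xx = sum(kernel(a, b) for a in xs for b in xs) / len(xs) ** 2
    k_yy = sum(kernel(a, b) for a in ys for b in ys) / len(ys) ** 2
    k_xy = sum(kernel(a, b) for a in xs for b in ys) / (len(xs) * len(ys))
    return k_xx - 2.0 * k_xy + k_yy

# One normalized degree histogram per graph (hypothetical values)
real = [[0.1, 0.6, 0.3], [0.2, 0.5, 0.3], [0.1, 0.5, 0.4]]
close = [[0.1, 0.5, 0.4], [0.2, 0.6, 0.2], [0.15, 0.55, 0.3]]
far = [[0.8, 0.1, 0.1], [0.7, 0.2, 0.1], [0.9, 0.1, 0.0]]

print(mmd2(real, close))  # small: the samples look alike
print(mmd2(real, far))    # larger: the degree distributions differ
```

Comparing a sample with itself gives a value of zero up to floating-point error, which is a quick sanity check for any MMD implementation.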

FCD: Fréchet ChemNet Distance

For molecules specifically, FCD (Fréchet ChemNet Distance) embeds molecules using a pretrained chemical neural network (ChemNet), then computes the Fréchet distance between the real and generated molecule embeddings. This compares the two sets in a learned chemical feature space, so it reflects chemical semantics and not just structural properties.

| Metric | What it measures | Range | Better = |
|---|---|---|---|
| MMD (degree) | Degree distribution match | [0, ∞) | Lower |
| MMD (clustering) | Triangle density match | [0, ∞) | Lower |
| Validity | Fraction of valid molecules | [0, 1] | Higher |
| Uniqueness | Fraction of non-duplicate valid molecules | [0, 1] | Higher |
| Novelty | Fraction not in training set | [0, 1] | Higher |
| FCD | Chemical distribution distance | [0, ∞) | Lower |
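The three fraction-based metrics in the table can be sketched in a few lines. The assumptions here: molecules are given as canonical strings, so equal strings mean equal molecules, and `toy_is_valid` is a hypothetical placeholder for a real chemistry validity check. Denominator conventions vary between papers; this version uses valid molecules for uniqueness and unique valid molecules for novelty.

```python
def generation_metrics(generated, training_set, is_valid):
    """Validity, uniqueness, and novelty of a batch of generated molecules."""
    valid = [m for m in generated if is_valid(m)]
    unique = set(valid)
    novel = unique - set(training_set)
    return {
        "validity": len(valid) / len(generated) if generated else 0.0,
        "uniqueness": len(unique) / len(valid) if valid else 0.0,
        "novelty": len(novel) / len(unique) if unique else 0.0,
    }

def toy_is_valid(smiles):
    # Placeholder check: rejects the one string below that has a five-bonded carbon
    return smiles != "C(C)(C)(C)(C)C"

generated = ["CCO", "CCO", "c1ccccc1", "CC(=O)O", "C(C)(C)(C)(C)C", "CCN"]
training = {"CCO", "CCN"}

metrics = generation_metrics(generated, training, toy_is_valid)
print(metrics)  # validity 5/6, uniqueness 4/5, novelty 2/4
```

Reporting all three together matters: a model can reach perfect validity by emitting one training molecule over and over, which uniqueness and novelty would expose.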
Degree Distribution: Real vs. Generated

Comparing degree distributions of real graphs (warm) vs. generated graphs (teal). A good generator should overlay almost perfectly. Click to see different model quality levels.

A graph generative model scores 0.98 validity (almost all generated molecules are chemically valid) but a high FCD score (large distance from real molecule distribution). What does this tell us?

Chapter 4: Goal-Directed Generation — GCPN

GraphRNN learns to generate graphs that look like training graphs. But drug discovery needs more: generate a molecule with specific target properties — high QED (drug-likeness score), low toxicity, high binding affinity. This is goal-directed generation, and it requires optimization, not just sampling.

GCPN (Graph Convolutional Policy Network, You et al. 2018) frames molecule generation as an MDP and uses RL to optimize molecular properties while maintaining chemical validity.

GCPN as a Markov Decision Process

State s_t
The current partial molecule as a graph. A GCN runs on this graph to produce node embeddings, a neural representation of the molecule's structure so far.
Action a_t
Choose: (1) which scaffold node to attach to, (2) what atom/bond type to add, (3) whether to stop. This is a combinatorial action space; GCPN samples from a learned policy over this space.
Reward r_t
At each step: small penalty for valence violations (intermediate reward). At the final step: large reward for the target property (QED, logP, binding affinity) plus an adversarial reward for "looks like real molecules."
↓ PPO updates the GCN policy
The two reward types: (1) Property reward: the computed QED/logP/SA score of the finished molecule. (2) Adversarial reward: the output of a discriminator trained to distinguish GCPN molecules from real training molecules. This pushes GCPN toward molecules that are realistic, rather than high-scoring but chemically implausible.
GCPN reward function (simplified)
def compute_reward(mol, final):
    # has_valence_violation, penalized_logp, discriminator and is_valid are
    # placeholder helpers; the weights below are illustrative, not tuned values
    r = 0.0

    # Step-level: penalize valence violations
    if has_valence_violation(mol):
        r -= 0.1

    if final:
        # Molecular property reward (e.g., penalized logP)
        r += 1.0 * penalized_logp(mol)

        # Adversarial reward (discriminator output)
        r += 0.5 * discriminator.score(mol)

        # Validity bonus: reward finished molecules that pass the validity check
        if is_valid(mol):
            r += 0.5

    return r
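These rewards feed a PPO update of the policy. As a rough sketch of what that update optimizes, here is the standard PPO clipped surrogate objective written out for a handful of actions. The function name and the numbers are illustrative assumptions, and a real implementation would compute advantages from the rewards and differentiate through the policy's log-probabilities.

```python
import math

def ppo_clip_objective(new_logps, old_logps, advantages, clip_eps=0.2):
    """Mean clipped surrogate objective (to be maximized) over sampled actions."""
    total = 0.0
    for new_lp, old_lp, adv in zip(new_logps, old_logps, advantages):
        ratio = math.exp(new_lp - old_lp)  # pi_new(a|s) / pi_old(a|s)
        clipped = min(max(ratio, 1.0 - clip_eps), 1.0 + clip_eps)
        total += min(ratio * adv, clipped * adv)
    return total / len(advantages)

# Two actions that the old policy chose with probability 0.5 each
old_logps = [math.log(0.5), math.log(0.5)]
new_logps = [math.log(0.75), math.log(0.25)]
advantages = [1.0, -1.0]  # first action helped, second hurt

value = ppo_clip_objective(new_logps, old_logps, advantages)
print(value)  # (1.2 - 0.8) / 2, since both ratios hit the clip range
```

The clipping limits how far a single update can move the policy, which helps keep molecule generation stable while the reward is being optimized.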
GCPN: Property Score vs. Steps

Watch how the molecule's penalized logP score improves as GCPN trains over episodes. Early molecules are random; later ones are specifically optimized. The adversarial reward prevents runaway optimization.

GCPN uses two reward signals: a property reward and an adversarial reward. Why is the adversarial reward necessary?

Chapter 5: Junction Tree VAE

GraphRNN and GCPN generate atoms one at a time. But chemists don't think about molecules atom by atom — they think about scaffolds and functional groups. Benzene rings. Carboxyl groups. Amide bonds. These substructures appear repeatedly across drug-like molecules. Junction Tree VAE (JT-VAE, Jin et al. 2018) generates molecules at this higher level of abstraction.

The Hierarchical Representation

Every molecule can be decomposed into a junction tree — a tree where each node is a ring or bond (a "motif"), and edges represent how motifs are fused. The molecule is assembled by combining motifs according to this tree structure.

The brilliant simplification: Trees have no cycles, so they're much easier to decode. A junction tree over n motifs has only n−1 edges, and decoding a tree is O(n). Contrast this with generating all O(n²) possible edges in an atom-level model.

The Three-Phase Generation Process

Phase 1: Encode
Run a GNN on the full atom-level molecule graph. Run a separate GNN on the junction tree. Combine both representations into a single latent vector z in a continuous VAE latent space.
Phase 2: Decode Tree
From z, generate the junction tree autoregressively: root motif first, then children, in depth-first order. Each node in the tree is one of ~780 vocabulary motifs (common rings/bonds extracted from training data).
Phase 3: Assemble Graph
Given the junction tree, decide how to fuse adjacent motifs (which atoms overlap at the junction). Run a GNN to score candidate atom-level assemblies and pick the valid one. This step adds atom-level precision back in.
$$L_{\text{JT-VAE}} = -\mathbb{E}_{q(z|x)}\left[\log p(T \mid z) + \log p(G \mid T, z)\right] + \mathrm{KL}\left[q(z|x) \,\|\, p(z)\right]$$

Where T is the junction tree, G is the full molecule, and z is the shared latent vector. The first term rewards accurate tree reconstruction, the second rewards accurate graph assembly, and the KL term regularizes the latent space to be smooth (so we can interpolate between molecules).
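The pieces of this objective can be sketched numerically. The snippet assumes a diagonal Gaussian encoder and a standard normal prior, which gives the KL term in closed form, and it treats the two log-likelihoods as numbers supplied by hypothetical tree and graph decoders. It follows the single shared latent vector described in the text.

```python
import math
import random

def kl_diag_gaussian(mu, log_var):
    """KL[ N(mu, diag(exp(log_var))) || N(0, I) ] in closed form."""
    return 0.5 * sum(math.exp(lv) + m * m - 1.0 - lv for m, lv in zip(mu, log_var))

def reparameterize(mu, log_var, rng):
    """Sample z = mu + sigma * eps with eps ~ N(0, I)."""
    return [m + math.exp(0.5 * lv) * rng.gauss(0.0, 1.0) for m, lv in zip(mu, log_var)]

def jtvae_loss(tree_log_lik, graph_log_lik, mu, log_var):
    """Negative reconstruction log-likelihoods plus the KL regularizer."""
    return -(tree_log_lik + graph_log_lik) + kl_diag_gaussian(mu, log_var)

mu, log_var = [1.0, 0.0], [0.0, 0.0]
z = reparameterize(mu, log_var, random.Random(0))
loss = jtvae_loss(tree_log_lik=-2.0, graph_log_lik=-1.0, mu=mu, log_var=log_var)
print(len(z), loss)  # 2-dimensional latent; loss = 3.0 + 0.5
```

When the encoder output equals the prior (zero mean, unit variance) the KL term vanishes, so the regularizer only penalizes how far each molecule's posterior drifts from the shared prior.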

Junction Tree Decomposition

Click a molecule to see how it decomposes into motifs (rings and bonds) forming a junction tree. The tree structure is much simpler than the full atom graph.

Why does JT-VAE first generate a junction tree rather than directly generating the atom-level molecular graph?

Chapter 6: Diffusion on Graphs

Score-based diffusion models have come to dominate image generation (DDPM, Stable Diffusion). The same idea applies to graphs: add noise to a real graph until it becomes a random graph, then learn to reverse the process, starting from random noise and gradually denoising into a valid graph.

What Does "Noise" Mean for a Graph?

For images, noise means adding Gaussian noise to pixel values. For graphs, there are two noise models:

Edge noise: Randomly flip edges (add/remove). A fully noised graph is the Erdős–Rényi random graph G(n, 0.5) — each edge exists with probability 0.5. Denoising = learning which edges to keep.
Node feature noise: Add Gaussian noise to continuous node features (atom positions in 3D molecular conformers), or randomly resample discrete node types. Denoising = recovering the correct atom types and positions.
The forward process (noise addition): $q(G_t \mid G_{t-1})$ is a Markov chain that adds noise at each step. For edges, it's a Bernoulli noise process. For node features, it's Gaussian. After $T$ steps, $G_T$ is approximately pure noise.
$$q(G_t \mid G_0) = \prod_{(i,j)} q(A_{t,ij} \mid A_{0,ij})$$

The reverse process $p_\theta(G_{t-1} \mid G_t)$ is learned by a GNN that sees the noisy graph $G_t$ and predicts either the denoised graph $G_0$ (x-prediction) or the noise $\varepsilon$ (noise-prediction). Training minimizes the expected denoising error over all noise levels.
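For the Bernoulli edge-noise case, the forward marginal has a simple closed form, sketched below. The assumption is a constant per-step flip probability `beta`, which is a simplification; real models use a schedule. After t independent flips with probability beta each, an edge differs from its original value with probability (1 − (1 − 2·beta)^t) / 2, which approaches 0.5.

```python
import random

def cumulative_flip_prob(beta, t):
    """Probability that an edge differs from its clean value after t noising steps."""
    return 0.5 * (1.0 - (1.0 - 2.0 * beta) ** t)

def noise_edges(adj, beta, t, rng):
    """Sample G_t from G_0 in one shot by flipping each node pair independently."""
    n = len(adj)
    p = cumulative_flip_prob(beta, t)
    noisy = [row[:] for row in adj]
    for i in range(n):
        for j in range(i + 1, n):
            if rng.random() < p:
                noisy[i][j] = noisy[j][i] = 1 - adj[i][j]
    return noisy

# Clean graph: a 4-cycle
clean = [[0, 1, 0, 1],
         [1, 0, 1, 0],
         [0, 1, 0, 1],
         [1, 0, 1, 0]]

rng = random.Random(0)
print(cumulative_flip_prob(0.1, 1))    # one step: equal to beta
print(cumulative_flip_prob(0.1, 200))  # many steps: essentially 0.5
slightly_noisy = noise_edges(clean, beta=0.1, t=1, rng=rng)
fully_noisy = noise_edges(clean, beta=0.1, t=200, rng=rng)
```

At a flip probability of 0.5 every node pair is an independent coin toss, so the fully noised graph is a sample from G(n, 0.5) no matter which graph the process started from.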

GDSS and DiGress: Two Key Models

GDSS (Jo et al. 2022) uses continuous stochastic differential equations for both node features and adjacency simultaneously. DiGress (Vignac et al. 2022) uses discrete diffusion: categorical node/edge type distributions that transition via a Markov matrix at each step. DiGress reported state-of-the-art results on several molecular generation benchmarks at the time of publication.

DiGress: discrete diffusion forward process
import torch

# Q_t: transition matrix for discrete edge types at timestep t
# Q_t[i, j] = probability of type i transitioning to type j
def get_Qt(t, num_edge_types, T=1000):
    alpha_t = 1.0 - t / T  # linear noise schedule, kept simple for illustration
    # Uniform transitions: keep the current type with weight alpha_t,
    # otherwise resample a type uniformly at random
    uniform = torch.ones(num_edge_types, num_edge_types) / num_edge_types
    return alpha_t * torch.eye(num_edge_types) + (1 - alpha_t) * uniform

# Marginal at time t: E_t = E_0 @ Qt_bar, where Qt_bar = Q_1 @ Q_2 @ ... @ Q_t
# Denoising: the GNN predicts E_0 given the noisy graph E_t
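The cumulative matrix mentioned in the comments can be checked without any tensor library. This sketch re-implements the uniform transition in plain Python (the helper names `uniform_Qt`, `matmul`, and `Qt_bar` are illustrative) and multiplies the per-step matrices together, using the same simplified linear schedule as the listing above.

```python
def uniform_Qt(t, num_types, T):
    """Per-step transition matrix: stay with weight alpha_t, else resample uniformly."""
    alpha = 1.0 - t / T
    return [[alpha * (i == j) + (1.0 - alpha) / num_types for j in range(num_types)]
            for i in range(num_types)]

def matmul(A, B):
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*B)] for row in A]

def Qt_bar(t, num_types, T):
    """Cumulative transition Q_1 @ Q_2 @ ... @ Q_t."""
    Q = [[float(i == j) for j in range(num_types)] for i in range(num_types)]
    for s in range(1, t + 1):
        Q = matmul(Q, uniform_Qt(s, num_types, T))
    return Q

K, T = 4, 10
early = Qt_bar(1, K, T)
final = Qt_bar(T, K, T)
print([round(x, 3) for x in early[0]])  # still concentrated on the original type
print([round(x, 3) for x in final[0]])  # uniform over the 4 types
```

Every row of the cumulative matrix stays a probability distribution, and by the last step it no longer depends on the starting type, which is the discrete analogue of "pure noise".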
Diffusion: Noising and Denoising a Graph

Drag the slider to see a graph at different noise levels. Left = clean molecular graph. Right = fully noised (Erdős–Rényi). Generation runs right to left.

In graph diffusion models, what is the "reverse process" that the neural network learns?

Chapter 7: Molecular Generation Benchmarks

Molecular generation has standardized benchmarks that measure whether generated molecules are actually useful for drug discovery — not just structurally valid, but chemically desirable. The two canonical property scores are QED and penalized logP.

QED: Quantitative Estimate of Drug-Likeness

QED scores a molecule from 0 to 1 based on how closely it matches properties of known oral drugs: molecular weight (200-500 Da), lipophilicity (logP 0-5), polar surface area, number of hydrogen bond donors/acceptors, and absence of known problematic groups. QED = 1 means the molecule looks exactly like known drugs.

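$$\mathrm{QED}(\text{mol}) = \exp\left( \frac{\sum_{i=1}^{n} w_i \log d_i(\text{mol})}{\sum_{i=1}^{n} w_i} \right)$$

Where $d_i(\text{mol})$ is the desirability score for property $i$ (using predefined bell curves matching known drug distributions), $w_i$ are fixed weights, and $n$ is the number of properties. Normalizing by the sum of the weights makes this a weighted geometric mean of the desirability scores.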
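A weighted geometric mean is easy to compute in log space. The sketch below uses made-up desirability scores and weights purely for illustration; the real QED uses fitted desirability curves and published weights, which are not reproduced here.

```python
import math

def weighted_geometric_mean(desirabilities, weights):
    """exp( sum(w_i * log d_i) / sum(w_i) ) for desirabilities in (0, 1]."""
    log_sum = sum(w * math.log(d) for d, w in zip(desirabilities, weights))
    return math.exp(log_sum / sum(weights))

# Hypothetical desirabilities for three properties, the last one weighted double
desirabilities = [0.9, 0.8, 0.5]
weights = [1.0, 1.0, 2.0]

score = weighted_geometric_mean(desirabilities, weights)
print(score)  # (0.9 * 0.8 * 0.5**2) ** (1/4)
```

Because the mean is geometric, one poor property pulls the score down sharply, and a desirability near zero drives the whole score toward zero.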

Penalized logP

Raw logP (octanol-water partition coefficient) measures lipophilicity — how much a drug can cross cell membranes. But high-logP molecules tend to be too greasy, poorly soluble, and toxic. Penalized logP subtracts two penalty terms:

plogP(mol) = logP(mol) − SA(mol) − ringPenalty(mol)

Where SA is the synthetic accessibility score (1-10, lower is easier to synthesize), and ringPenalty penalizes macrocycles (large rings that are hard to synthesize). Goal-directed generation optimizes for high plogP, but early methods found pathological solutions (giant rings, unusual atoms) that were structurally valid but not synthesizable.
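A short worked example shows how the penalties act. The sketch assumes a common convention for the ring term, namely the size of the largest ring minus 6 (floored at zero), and it takes logP and SA as precomputed numbers; all input values are hypothetical. Benchmark implementations often also normalize each term, which is omitted here.

```python
def ring_penalty(ring_sizes, max_free_size=6):
    """Penalty for the largest ring beyond `max_free_size` atoms (0 if none)."""
    largest = max(ring_sizes, default=0)
    return max(0, largest - max_free_size)

def penalized_logp(logp, sa_score, ring_sizes):
    """plogP = logP - SA - ringPenalty, from precomputed property values."""
    return logp - sa_score - ring_penalty(ring_sizes)

# Hypothetical drug-like molecule: moderate logP, easy synthesis, small rings
print(penalized_logp(logp=2.5, sa_score=2.0, ring_sizes=[6, 5]))  # 0.5

# Hypothetical macrocycle: high raw logP, hard synthesis, one 14-membered ring
print(penalized_logp(logp=6.0, sa_score=7.5, ring_sizes=[14]))    # -9.5
```

The second molecule has the higher raw logP yet scores far lower, which is the behavior the two penalty terms are meant to enforce.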

| Model | QED (top-3 avg) | plogP (top-3 avg) | Validity | Key innovation |
|---|---|---|---|---|
| GraphRNN | 0.89 | 3.1 | 98.1% | Two-RNN autoregressive |
| GCPN | 0.948 | 7.98 | 100% | RL-guided with GCN policy |
| JT-VAE | 0.925 | 5.30 | 100% | Hierarchical motif generation |
| DiGress | 0.952 | 9.30 | 100% | Discrete graph diffusion |
| GDSS | 0.923 | 7.62 | 95.7% | SDE-based continuous diffusion |
Property Score Distribution: Real vs. Generated

Comparing plogP distributions of real drug molecules (warm) vs. GCPN-generated (teal). GCPN successfully shifts the distribution toward higher plogP. But does it go too far?

A goal-directed generator achieves plogP = 15 (very high) but SA score of 9.5 (nearly impossible to synthesize). What does this tell us about the generator?

Chapter 8: Applications

Graph generation is not just a toy problem. Three domains have demonstrated production-scale impact: drug discovery, material design, and protein engineering.

Drug Discovery Pipeline

Target identification
Identify a protein involved in disease (e.g., EGFR kinase in lung cancer). Get its 3D structure from PDB.
Structure-based generation
Use a 3D graph generative model (e.g., TargetDiff, Pocket2Mol) conditioned on the protein pocket. Generate molecules that fit the binding site's 3D shape and electrostatics.
Property filtering
Filter by QED > 0.6, SA < 4, molecular weight 200-500 Da. Keep only molecules that pass all filters. Typically reduces candidates 100×.
Virtual screening
Use a docking program (AutoDock Vina, Glide) to score binding affinity. Top candidates go to wet lab synthesis and testing.
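The property-filtering step of this pipeline amounts to a few threshold checks. The sketch below applies the thresholds quoted above (QED > 0.6, SA < 4, molecular weight 200-500 Da) to a hypothetical list of candidates; the `Candidate` class and all property values are invented for illustration, and in practice they would come from a cheminformatics toolkit.

```python
from dataclasses import dataclass

@dataclass
class Candidate:
    name: str
    qed: float
    sa: float          # synthetic accessibility score
    mol_weight: float  # in daltons

def passes_filters(c):
    """Thresholds from the pipeline description: QED > 0.6, SA < 4, 200-500 Da."""
    return c.qed > 0.6 and c.sa < 4.0 and 200.0 <= c.mol_weight <= 500.0

candidates = [
    Candidate("mol_a", qed=0.82, sa=2.9, mol_weight=341.0),
    Candidate("mol_b", qed=0.45, sa=2.1, mol_weight=298.0),  # fails QED
    Candidate("mol_c", qed=0.71, sa=6.3, mol_weight=412.0),  # fails SA
    Candidate("mol_d", qed=0.77, sa=3.2, mol_weight=655.0),  # fails weight
]

kept = [c.name for c in candidates if passes_filters(c)]
print(kept)  # ['mol_a']
```

Cheap filters like these run before docking because docking is far more expensive per molecule.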
Real-world result: Insilico Medicine used generative models to design a novel TNIK inhibitor (INS018_055) for idiopathic pulmonary fibrosis. The company reported going from target discovery to a preclinical candidate in about 18 months, and the molecule entered Phase II clinical trials in 2023. Traditional programs often take several years to reach the same preclinical stage.

Beyond Drug Discovery

Material design: Generate crystal structures (graphs of atoms + unit cell) with target bandgap, conductivity, or thermal stability. CDVAE and DiffCSP use graph diffusion for crystal structure generation. Applications: battery materials, solar cells, superconductors.
Protein design: Proteins are graphs of amino acids connected by peptide bonds, plus non-covalent contacts in 3D space. ProteinMPNN and RFdiffusion generate protein sequences/structures with target function. This is the hottest area in biotech circa 2024.
A structure-based drug design model generates molecules conditioned on the protein binding pocket. What information from the protein does the model need?

Chapter 9: Connections

Graph generation sits at the intersection of graph learning, generative modeling, and chemistry. Understanding where it connects to adjacent topics reveals the full scope of what you've learned.

| Topic | Connection to Graph Generation | Learn more |
|---|---|---|
| GNN Basics | GCPN and JT-VAE use message-passing networks (GCN/GGNN-style) as the neural backbone for encoding partial or complete graphs; GraphRNN instead summarizes the graph in RNN hidden states. Message passing = reading the current graph state. | Lec 3: GNN |
| Graph Theory | The WL isomorphism test bounds what GNNs can distinguish, which is relevant to whether the model can differentiate valid from invalid subgraphs during generation. | Lec 6: Theory |
| RL Algorithms | GCPN uses PPO for goal-directed generation. The molecule-building MDP is a classic RL setting with sparse reward at episode end. | RL Algorithms |
| Agents + Graphs | GCPN's policy is essentially a graph-grounded agent: at each step, it reads the current graph state, reasons about which action improves properties, and executes. | Lec 17: Agents |
| VAE/VQVAE | JT-VAE directly extends the VAE framework to graph-structured latent spaces. The reparameterization trick and ELBO apply identically. | VAE/VQVAE |
| Diffusion Models | DiGress and GDSS extend DDPM's forward-reverse process to discrete/continuous graph spaces. The score function is now a GNN, not a U-Net. | Diffusion |

"Making drugs is hard. Making drug molecules is easy — once you have a good generator. The hard part is learning what 'good' means."