Generate molecules that are always chemically valid by first building a tree of known-valid substructures (rings, functional groups), then assembling them together.
Picture a molecular generation model that works on SMILES strings — textual representations like "C1CCCCC1" for cyclohexane. You train an RNN or VAE to generate these strings token by token. What's the problem?
Generating "C1CCCCC" and then forgetting to close the ring (no final "1") produces an invalid SMILES. Generating "C(C)(C)(C)(C)C" gives a carbon with five bonds, which is impossible. For models operating on SMILES strings, a large fraction of generated strings are chemically invalid: CharVAE, a SMILES-based VAE, has a validity rate of only ~35% on decoded samples.
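To make these two failure modes concrete, here is a toy checker for a tiny, hypothetical SMILES subset: only aliphatic carbon `C`, single-digit ring closures, and parentheses for branches, all with single bonds. It is an illustration under those assumptions, not a real SMILES parser; an actual pipeline would rely on a cheminformatics toolkit such as RDKit.

```python
def toy_smiles_valid(smiles):
    """Check ring closure and carbon valence for a toy SMILES subset (hypothetical helper).

    Supported tokens: 'C' (aliphatic carbon), digits 1-9 (ring closures),
    '(' and ')' (branches). Every bond is treated as a single bond.
    """
    degree = []          # number of bonds on each atom so far
    prev = None          # index of the atom the next atom will bond to
    branch_stack = []    # atoms to return to when a branch closes
    open_rings = {}      # ring-closure digit -> atom that opened it

    for ch in smiles:
        if ch == "C":
            degree.append(0)
            idx = len(degree) - 1
            if prev is not None:          # chain bond to the previous atom
                degree[prev] += 1
                degree[idx] += 1
            prev = idx
        elif ch == "(":
            if prev is None:
                return False              # branch with no atom to hang it on
            branch_stack.append(prev)
        elif ch == ")":
            if not branch_stack:
                return False              # unbalanced parenthesis
            prev = branch_stack.pop()
        elif ch.isdigit() and ch != "0":
            if prev is None:
                return False
            if ch in open_rings:          # second occurrence closes the ring
                j = open_rings.pop(ch)
                degree[j] += 1
                degree[prev] += 1
            else:                         # first occurrence opens the ring
                open_rings[ch] = prev
        else:
            raise ValueError(f"unsupported token {ch!r} in toy subset")

    # Valid only if every ring was closed, every branch ended,
    # and no carbon exceeds its valence of 4.
    return not open_rings and not branch_stack and all(d <= 4 for d in degree)


for s in ["C1CCCCC1", "C1CCCCC", "C(C)(C)(C)(C)C", "CC(C)(C)C"]:
    print(s, toy_smiles_valid(s))
```

The first and last strings (cyclohexane and neopentane) pass; the unclosed ring and the five-bond carbon fail. A token-by-token generator has to avoid both mistakes implicitly, which is exactly what it struggles with.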
Jin et al.'s insight: instead of generating atoms one at a time (which can create invalid combinations), generate molecules from a vocabulary of known-valid substructures. Every piece in the vocabulary is a valid chemical fragment, such as a benzene ring, a piperidine ring, or a single bond between two atoms. As long as the pieces are also joined only in chemically valid ways, the assembled molecule is valid too.
This is how experienced medicinal chemists think: not "let me add atoms," but "let me combine a pyridine ring with a morpholine and a short alkyl chain." The junction tree VAE formalizes this hierarchical thinking.
Figure: building a molecule atom by atom. Green means the current partial molecule is valid; red means it is already invalid. A single small mistake is enough to make the whole molecule invalid.
The key operation that converts a molecular graph into a tree is called junction tree decomposition. It's borrowed from graph theory but applied to chemistry.
Any molecule can be broken into a set of overlapping subgraphs called clusters. JT-VAE uses two kinds of cluster: (1) a single bond that does not belong to any ring, and (2) a simple ring.
Every atom appears in at least one cluster, and every bond is covered: a non-ring bond forms its own cluster, and a ring bond belongs to its ring (a bond shared by two fused rings belongs to both ring clusters).
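A minimal sketch of the clustering step on a bare atom graph (atoms numbered 0..n-1, bonds given as index pairs). A bond is treated as non-ring when no other path connects its two atoms; otherwise the shortest cycle through that bond is taken as a ring cluster. This shortest-cycle rule is an assumption standing in for proper ring perception (for example RDKit's SSSR routines), and the sketch skips any special handling for bridged ring systems. The function names are hypothetical.

```python
from collections import deque


def shortest_path_avoiding(adj, src, dst, banned):
    """BFS path from src to dst that never uses the bond `banned`; None if no such path."""
    prev = {src: None}
    queue = deque([src])
    while queue:
        u = queue.popleft()
        if u == dst:
            path = []
            while u is not None:      # walk back from dst to src
                path.append(u)
                u = prev[u]
            return path
        for w in adj[u]:
            if w in prev or frozenset((u, w)) == banned:
                continue
            prev[w] = u
            queue.append(w)
    return None


def find_clusters(n_atoms, bonds):
    """Split a molecular graph into non-ring bond clusters and ring clusters (atom-id sets)."""
    adj = {i: set() for i in range(n_atoms)}
    for a, b in bonds:
        adj[a].add(b)
        adj[b].add(a)
    non_ring, rings = [], set()
    for a, b in bonds:
        bond = frozenset((a, b))
        cycle = shortest_path_avoiding(adj, a, b, bond)
        if cycle is None:
            non_ring.append(bond)         # no alternative route: the bond is in no ring
        else:
            rings.add(frozenset(cycle))   # shortest ring through this bond
    return non_ring + sorted(rings, key=sorted)


# Methylcyclohexane: ring atoms 0-5, methyl carbon 6 attached to atom 0.
print(find_clusters(7, [(0, 1), (1, 2), (2, 3), (3, 4), (4, 5), (5, 0), (0, 6)]))

# Naphthalene skeleton: two six-membered rings fused along the 4-9 bond.
naphthalene = [(0, 1), (1, 2), (2, 3), (3, 4), (4, 9), (9, 0),
               (4, 5), (5, 6), (6, 7), (7, 8), (8, 9)]
print(find_clusters(10, naphthalene))
```

For naphthalene the result is two ring clusters, and the fusion bond 4-9 lies in both of them.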
Now build a cluster graph: each node represents one cluster, and two cluster-nodes are connected whenever the corresponding clusters share at least one atom. Because molecular graphs are close to trees themselves (most drug-like molecules have at most a few rings), this cluster graph is usually already a tree or close to one. If it does contain cycles, they are removed by keeping a maximum spanning tree of the cluster graph, with each edge weighted by the number of atoms the two clusters share.
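The cluster-graph and spanning-tree step can be sketched with Kruskal's algorithm run heaviest-edge-first. The edge weight (number of shared atoms) follows the description above; the function name is hypothetical, and the sketch omits any extra handling a full implementation may apply when three or more clusters meet at a single atom.

```python
from itertools import combinations


def junction_tree(clusters):
    """Return tree edges (i, j) over cluster indices: a maximum spanning tree of the cluster graph."""
    # Cluster graph: connect clusters that share atoms, weighted by how many they share.
    edges = []
    for i, j in combinations(range(len(clusters)), 2):
        shared = len(clusters[i] & clusters[j])
        if shared > 0:
            edges.append((shared, i, j))
    edges.sort(key=lambda e: -e[0])       # heaviest first -> maximum spanning tree

    parent = list(range(len(clusters)))   # union-find forest

    def find(x):
        while parent[x] != x:
            parent[x] = parent[parent[x]]  # path halving
            x = parent[x]
        return x

    tree = []
    for _, i, j in edges:
        ri, rj = find(i), find(j)
        if ri != rj:                      # skip edges that would close a cycle
            parent[ri] = rj
            tree.append((i, j))
    return tree


# Isobutane-like centre: three bond clusters meet at atom 0, so the cluster graph is a triangle.
star = [frozenset({0, 1}), frozenset({0, 2}), frozenset({0, 3})]
print(junction_tree(star))
```

The triangle in the cluster graph is broken by dropping one edge, leaving a tree with two edges over the three clusters.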
Across the ZINC dataset (~250,000 drug-like molecules), there are only ~780 distinct cluster types. This is the key compression: from an enormous atom-level space to a compact library of meaningful building blocks. At generation time, the decoder only needs to choose among these ~780 vocabulary items rather than from the unbounded space of atom and bond combinations.
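A sketch of vocabulary construction, assuming each molecule has already been decomposed and every cluster converted to a canonical SMILES string (in practice via a toolkit such as RDKit; here the strings are written by hand). `build_vocab` and `encode_clusters` are hypothetical names. The out-of-vocabulary error illustrates the limitation discussed further on: a cluster never seen in training has no vocabulary id.

```python
from collections import Counter


def build_vocab(dataset_clusters):
    """Collect the distinct cluster SMILES across a dataset and assign each an integer id."""
    counts = Counter(smi for mol in dataset_clusters for smi in mol)
    vocab = sorted(counts)                          # deterministic ordering
    index = {smi: i for i, smi in enumerate(vocab)}
    return vocab, index, counts


def encode_clusters(mol_clusters, index):
    """Map one molecule's clusters to vocabulary ids; fail loudly on unseen clusters."""
    missing = [smi for smi in mol_clusters if smi not in index]
    if missing:
        raise KeyError(f"clusters not in vocabulary: {missing}")
    return [index[smi] for smi in mol_clusters]


# Toy "dataset": each molecule is already decomposed into canonical cluster SMILES (assumed).
dataset = [["c1ccccc1", "CC"], ["C1CCNCC1", "CC", "CO"], ["c1ccccc1", "CN"]]
vocab, index, counts = build_vocab(dataset)
print(vocab)
print(encode_clusters(["CC", "c1ccccc1"], index))
```

Repeated clusters (benzene and the C-C bond here) collapse to a single vocabulary entry, which is where the compression comes from on a real dataset.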
Figure: schematic of how a molecule (atoms + bonds) is converted into a junction tree (cluster nodes + connections), with each cluster node linked to its atoms in the molecule.
Now we can see the full two-level representation that JT-VAE maintains for every molecule: the atom-level graph (the actual molecule) and the cluster-level tree (the junction tree).
Figure: left, the molecule graph (atoms = nodes, bonds = edges); right, the junction tree (each tree node = a cluster), built up one cluster at a time.
For encoding, both representations are processed simultaneously — the atom graph and the junction tree are both read to produce the latent vector z. For decoding, the junction tree is generated first (which clusters to use and how to connect them), then the atom-level graph is assembled from the specified clusters (which atoms within each cluster to share with neighboring clusters).
The encoder takes a molecule and produces a latent vector z. It processes both the atom-level graph and the junction tree, then combines them.
A standard message-passing GNN runs over the molecular graph for $T = 3$ steps. At each step, every atom aggregates messages from its neighbors. The final atom representations are mean-pooled into a single graph embedding $h_G$.
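A plain-Python sketch of this encoder, assuming a simplified node-level update $h_v \leftarrow \tanh(W_{self} h_v + W_{nei} \sum_{u \in N(v)} h_u)$ with random weights. This is a generic message-passing variant, not necessarily the exact formulation of the original encoder; it shows the $T$ rounds of neighbor aggregation followed by mean pooling. All names are hypothetical.

```python
import math
import random


def matvec(W, x):
    return [sum(w * xi for w, xi in zip(row, x)) for row in W]


def graph_embedding(features, adj, W_self, W_nei, T=3):
    """Run T rounds of message passing over atoms, then mean-pool into h_G."""
    h = [list(f) for f in features]
    for _ in range(T):
        new_h = []
        for v in range(len(h)):
            # Sum of neighbour states: the "messages" atom v receives this round.
            agg = [0.0] * len(h[v])
            for u in adj[v]:
                agg = [a + b for a, b in zip(agg, h[u])]
            pre = [s + n for s, n in zip(matvec(W_self, h[v]), matvec(W_nei, agg))]
            new_h.append([math.tanh(p) for p in pre])
        h = new_h  # synchronous update: every atom used the previous round's states
    n = len(h)
    return [sum(hv[k] for hv in h) / n for k in range(len(h[0]))]


rng = random.Random(0)
d = 4
W_self = [[rng.uniform(-0.5, 0.5) for _ in range(d)] for _ in range(d)]
W_nei = [[rng.uniform(-0.5, 0.5) for _ in range(d)] for _ in range(d)]
# Ethanol heavy atoms C-C-O with toy one-hot features [is_C, is_O, 0, 0].
features = [[1, 0, 0, 0], [1, 0, 0, 0], [0, 1, 0, 0]]
adj = {0: [1], 1: [0, 2], 2: [1]}
h_G = graph_embedding(features, adj, W_self, W_nei)
print(h_G)
```

Mean pooling makes $h_G$ independent of how the atoms are numbered: relabeling the atoms permutes the per-atom states but leaves their average unchanged.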
A separate message-passing procedure runs over the junction tree in two phases: bottom-up messages from the leaves to the root, then top-down messages from the root back to the leaves. Each cluster node $i$ gets a representation $s_i$, and the root cluster's final state becomes the tree embedding $h_T$.
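A sketch of the two-phase schedule on a small rooted tree. The message $m_{i \to j} = \tanh(W x_i + U \sum_{k \ne j} m_{k \to i})$ is a simplified stand-in (an assumption) for the gated, GRU-style messages of the original tree encoder. What the sketch does show is the ordering: every upward message is computed before any downward one, so after both passes each node has received information from every other node. Names are hypothetical.

```python
import math
import random


def matvec(W, x):
    return [sum(w * xi for w, xi in zip(row, x)) for row in W]


def tree_encode(x, children, root, W, U):
    """Two-phase message passing on a rooted tree; returns (node states, h_T = root state)."""
    d = len(W)
    parent = {root: None}
    order = [root]
    for node in order:                        # BFS: parents appear before their children
        for c in children.get(node, []):
            parent[c] = node
            order.append(c)

    def neighbors(i):
        nbrs = list(children.get(i, []))
        if parent[i] is not None:
            nbrs.append(parent[i])
        return nbrs

    msgs = {}

    def message(i, j):
        """m_{i->j}: combine x_i with all messages into i except the one coming from j."""
        incoming = [0.0] * d
        for k in neighbors(i):
            if k != j:
                incoming = [a + b for a, b in zip(incoming, msgs[(k, i)])]
        return [math.tanh(a + b) for a, b in zip(matvec(W, x[i]), matvec(U, incoming))]

    for node in reversed(order):              # phase 1: leaves -> root
        if parent[node] is not None:
            msgs[(node, parent[node])] = message(node, parent[node])
    for node in order:                        # phase 2: root -> leaves
        for c in children.get(node, []):
            msgs[(node, c)] = message(node, c)

    states = {i: message(i, None) for i in order}  # j=None: use every incoming message
    return states, states[root]


rng = random.Random(0)
f, d = 2, 3
W = [[rng.uniform(-1, 1) for _ in range(f)] for _ in range(d)]
U = [[rng.uniform(-1, 1) for _ in range(d)] for _ in range(d)]
x = {0: [1.0, 0.0], 1: [0.0, 1.0], 2: [0.0, 1.0], 3: [1.0, 1.0]}
children = {0: [1, 2], 1: [3]}                # root 0 with children 1, 2; node 1 has child 3
states, h_T = tree_encode(x, children, 0, W, U)
print(h_T)
```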
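The final latent vector z has one half per embedding: the tree half $z_T$ is sampled from a Gaussian parameterized by $h_T$, the graph half $z_G$ from a Gaussian parameterized by $h_G$, and the two halves are concatenated:

$$z_T \sim \mathcal{N}\big(\mu_T(h_T),\ \operatorname{diag}\sigma_T^2(h_T)\big), \qquad z_G \sim \mathcal{N}\big(\mu_G(h_G),\ \operatorname{diag}\sigma_G^2(h_G)\big), \qquad z = [z_T; z_G]$$

where $\mu(\cdot)$ and $\sigma^2(\cdot)$ are learned linear projections. In the original setup z is 56-dimensional (28 dimensions per half): a continuous representation of the molecule that is smooth enough for Bayesian optimization or latent space interpolation.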
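Whatever the exact wiring of the mean and variance heads, sampling uses the standard VAE reparameterization trick, and training adds a KL penalty toward $\mathcal{N}(0, I)$. A plain-Python sketch, assuming the variance head outputs a log-variance (a common convention that keeps the variance positive; it is an assumption here). Function names are hypothetical.

```python
import math
import random


def reparameterize(mu, logvar, rng):
    """Sample z = mu + sigma * eps with eps ~ N(0, I) and sigma = exp(logvar / 2)."""
    return [m + math.exp(0.5 * lv) * rng.gauss(0.0, 1.0) for m, lv in zip(mu, logvar)]


def kl_to_standard_normal(mu, logvar):
    """KL( N(mu, diag(exp(logvar))) || N(0, I) ), the VAE regularizer on the latent."""
    return 0.5 * sum(math.exp(lv) + m * m - 1.0 - lv for m, lv in zip(mu, logvar))


rng = random.Random(0)
mu, logvar = [0.3, -1.2], [-1.0, 0.5]   # stand-ins for the projection outputs
z = reparameterize(mu, logvar, rng)
print(z, kl_to_standard_normal(mu, logvar))
```

Writing the sample as a deterministic function of (mu, logvar) plus independent noise is what lets gradients flow back into the encoder during training.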
Decoding is the hard part. Given z, we need to reconstruct a molecule. JT-VAE decodes in two phases: first the tree, then the graph.
A tree-structured RNN decodes the junction tree from z. Starting from the root, it predicts: (1) the cluster label (which of the ~780 vocabulary substructures is at this tree node?), and (2) whether to grow more children (is this cluster connected to more clusters?). It recurses depth-first until it generates a stop signal at every leaf.
Each prediction is a softmax over the vocabulary (for labels) or a binary decision (for topology). The decoder is conditioned on z throughout — z is concatenated to the GRU hidden state at every step.
Now the tree is fixed: we know which clusters are present and how they connect. The graph assembly step fills in the "attachment chemistry": for each edge in the junction tree (connecting cluster $C_i$ to cluster $C_j$), we must decide which atoms $C_i$ and $C_j$ share, for example which ring atom a side-chain bond attaches through.
This is done greedily: enumerate all valid attachment configurations between $C_i$ and $C_j$, score each with a GNN that reads the current partial graph and z, and pick the highest-scoring configuration. Because the number of attachment options per edge is small (bounded by the cluster sizes), this is tractable.
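```python
# Simplified JT-VAE decoding pseudocode.
# tree_decoder, graph_decoder, JunctionTree, init_from_root, enumerate_attachments
# and apply_attachment are placeholders for the learned models and chemistry helpers.
def decode(z, tree_decoder, graph_decoder, vocab, max_nodes=50):
    # Phase 1: decode the junction tree depth-first, starting from the root.
    root_label = tree_decoder.predict_label(z)
    tree = JunctionTree(root=vocab[root_label])

    def expand(node):
        h = tree_decoder.hidden_state(z, tree, node)
        # Keep adding children while the topology head says "expand" (with a size cap).
        while len(tree) < max_nodes and tree_decoder.predict_continue(h):
            child_label = tree_decoder.predict_label(h)
            child = tree.add_child(node, vocab[child_label])
            expand(child)  # finish the child's subtree before the next sibling
            h = tree_decoder.hidden_state(z, tree, node)  # refresh: the tree has grown

    expand(tree.root)

    # Phase 2: assemble the atom-level graph, one tree edge at a time.
    mol = init_from_root(tree.root)
    for Ci, Cj in tree.edges_bfs():
        configs = enumerate_attachments(mol, Ci, Cj)  # only valence-valid options
        scores = graph_decoder.score_all(mol, configs, z)
        best = configs[scores.argmax()]
        mol = apply_attachment(mol, best)
    return mol  # valid by construction: every cluster and attachment was pre-validated
```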
JT-VAE claims 100% validity on generated molecules. Let's see why this follows from the construction itself rather than being only an empirical observation.
Step 1 — Every cluster in the vocabulary is valid. The vocabulary is built by extracting clusters from actual molecules in the training set. Every vocabulary item is a chemically valid fragment (correct valence, valid ring system). This is not learned — it's guaranteed by construction from real chemistry.
Step 2 — Tree decoding only picks from the vocabulary. The tree decoder's label prediction is a softmax over the ~780 vocabulary items. It cannot generate a cluster type that isn't in the vocabulary, so every cluster in the generated tree is valid.
Step 3 — Graph assembly only considers valid attachments. When attaching cluster $C_i$ to cluster $C_j$, the decoder enumerates only the chemically valid ways for the two clusters to share atoms (a single atom, or a bond in the case of fused rings) and discards any option that would violate valence. It then picks one of these pre-validated options, so it cannot construct an invalid attachment.
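The valence filter in Step 3 can be sketched for the simplest case: a bond cluster that must share one atom with a ring cluster. Only ring atoms with spare valence are offered to the scorer, so even a scorer that "wants" an invalid atom cannot pick it. `MAX_VALENCE`, the bond-order bookkeeping, and the toy score function are assumptions for illustration; the real decoder scores candidates with a GNN and also handles ring-to-ring attachments.

```python
MAX_VALENCE = {"C": 4, "N": 3, "O": 2}  # neutral atoms only (simplifying assumption)


def enumerate_attachments(cluster_atoms, symbols, used_valence, new_bond_order=1):
    """Atoms of the cluster that can take one more bond of the given order.

    used_valence counts bond orders to heavy atoms; implicit hydrogens are
    assumed to be replaceable (hypothetical helper).
    """
    return [a for a in cluster_atoms
            if used_valence[a] + new_bond_order <= MAX_VALENCE[symbols[a]]]


def greedy_attach(cluster_atoms, symbols, used_valence, score):
    """Pick the highest-scoring valid attachment atom, or None if no atom is valid."""
    options = enumerate_attachments(cluster_atoms, symbols, used_valence)
    if not options:
        return None  # a full decoder would have to try another configuration here
    return max(options, key=score)


# Cyclohexanone ring: atoms 0-5 are carbons; atom 2 carries the C=O, so its valence is full.
symbols = {i: "C" for i in range(6)}
used = {0: 2, 1: 2, 2: 4, 3: 2, 4: 2, 5: 2}
print(enumerate_attachments(range(6), symbols, used))
# This toy scorer prefers atom 2, but atom 2 is never offered; ties go to the first option.
print(greedy_attach(range(6), symbols, used, score=lambda a: -abs(a - 2)))
```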
Figure: validity rates across generation approaches. JT-VAE reaches 100% by guaranteeing validity at each hierarchical level.
Does every valid molecule have a junction tree decomposition? Yes — any molecular graph can be decomposed into simple bonds and rings, which are exactly the cluster types JT-VAE uses. The only limitation is the vocabulary: if a ring structure appears in a test molecule but not in any training molecule, it won't be in the vocabulary and can't be generated. In practice, ZINC's ~780 cluster vocabulary covers the vast majority of drug-like structures.
On the ZINC dataset (250k molecules), JT-VAE compares with a SMILES-based VAE (CVAE) and a grammar-based VAE (GVAE) as follows:
| Metric | CVAE (SMILES) | GVAE (grammar) | JT-VAE |
|---|---|---|---|
| Reconstruction accuracy | 44.6% | 53.7% | 76.7% |
| Validity (random sample) | 0.7% | 7.2% | 100.0% |
| Uniqueness (of valid) | 100% | 100% | 99.9% |
| Novelty (not in train) | 90.0% | 61.0% | 99.9% |
The continuous, smooth latent space enables Bayesian optimization (BO): fit a Gaussian process (GP) to observed (z, property) pairs, use an acquisition function computed from the GP to select promising z values, decode them to molecules, evaluate them, and iterate. This classical optimization approach works well when z is well structured, meaning that nearby latent points decode to molecules with similar properties.
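A self-contained one-dimensional sketch of this loop: a tiny zero-mean GP with an RBF kernel, an upper-confidence-bound (UCB) acquisition chosen for brevity (not necessarily the acquisition used in the original experiments), and a stand-in objective in place of "decode z, then score the molecule". Every function name here is hypothetical. Real latent spaces are ~56-dimensional and typically call for a scalable (for example sparse) GP rather than the dense solve below.

```python
import math


def rbf(a, b, length=1.0):
    """Squared-exponential kernel between two points given as tuples."""
    return math.exp(-sum((x - y) ** 2 for x, y in zip(a, b)) / (2 * length ** 2))


def solve(A, b):
    """Solve A x = b by Gaussian elimination with partial pivoting (small systems only)."""
    n = len(A)
    M = [list(row) + [bi] for row, bi in zip(A, b)]
    for col in range(n):
        piv = max(range(col, n), key=lambda r: abs(M[r][col]))
        M[col], M[piv] = M[piv], M[col]
        for r in range(col + 1, n):
            f = M[r][col] / M[col][col]
            for c in range(col, n + 1):
                M[r][c] -= f * M[col][c]
    x = [0.0] * n
    for r in range(n - 1, -1, -1):
        x[r] = (M[r][n] - sum(M[r][c] * x[c] for c in range(r + 1, n))) / M[r][r]
    return x


def gp_posterior(X, y, x_star, noise=1e-4):
    """Posterior mean and variance of a zero-mean GP at x_star."""
    K = [[rbf(a, b) + (noise if i == j else 0.0) for j, b in enumerate(X)]
         for i, a in enumerate(X)]
    k_star = [rbf(a, x_star) for a in X]
    mean = sum(k * al for k, al in zip(k_star, solve(K, y)))
    v = solve(K, k_star)
    var = rbf(x_star, x_star) - sum(k * vi for k, vi in zip(k_star, v))
    return mean, max(var, 0.0)


def ucb(c, X, y, beta):
    m, v = gp_posterior(X, y, c)
    return m + beta * math.sqrt(v)  # optimism: favour high mean or high uncertainty


def bayes_opt(objective, candidates, init, n_iter=8, beta=2.0):
    """Fit GP, evaluate the unexplored candidate with the highest UCB, repeat."""
    X = list(init)
    y = [objective(x) for x in X]
    for _ in range(n_iter):
        remaining = [c for c in candidates if c not in X]
        if not remaining:
            break
        nxt = max(remaining, key=lambda c: ucb(c, X, y, beta))
        X.append(nxt)
        y.append(objective(nxt))
    best = max(range(len(y)), key=y.__getitem__)
    return X[best], y[best]


# Stand-in objective on a 1-D latent grid; in JT-VAE this would be property(decode(z)).
grid = [(i * 0.25,) for i in range(13)]
best_z, best_val = bayes_opt(lambda z: -(z[0] - 1.5) ** 2, grid, init=[(0.0,), (3.0,)])
print(best_z, best_val)
```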
| Method | Penalized logP (top-3) | QED (top-3) |
|---|---|---|
| CVAE + BO | −3.01, −3.32, −3.49 | 0.734, 0.712, 0.706 |
| GVAE + BO | −3.63, −3.91, −3.98 | 0.752, 0.738, 0.730 |
| JT-VAE + BO | 5.30, 4.93, 4.49 | 0.925, 0.911, 0.896 |
Vocabulary bound: Can only generate molecules whose substructures appear in training vocabulary. Novel ring systems not in ZINC are unreachable.
3D ignored: Like GCPN, JT-VAE works on 2D topology. Stereochemistry (chirality, E/Z isomers) and 3D binding geometry are not modeled.
Slow decoding: The two-phase sequential decoding (tree first, then graph assembly with enumeration) is slower than a single forward pass. Tree RNN is O(N) steps, graph assembly involves per-edge enumeration.
Fixed vocabulary assumption: If you want to generate macrocycles or highly unusual heterocycles, you need a training set that contains them — the vocabulary-based approach constrains the generation space.
| Granularity | Method | Validity | Flexibility |
|---|---|---|---|
| Character (token) | CharVAE (SMILES chars) | ~35% | Highest |
| Grammar rule | GVAE (context-free grammar) | ~7% | High |
| Atom | GraphRNN, GCPN, MolGAN | ~65–100% | Medium |
| Substructure | JT-VAE | 100% | Vocab-bounded |
| Fragment (larger) | BRICS + ML, Fragment-VAE | ~100% | Lower |
"The key to reliable generation is to work in a space where every point decodes to something valid. For molecules, that means working with known-valid substructures, not raw atoms."
— JT-VAE design philosophy