A beautifully simple idea: score a knowledge graph triple by summing the element-wise product of head, relation, and tail vectors. Symmetric by construction, surprisingly powerful in practice.
Freebase contains 1.2 billion facts. Yet it's missing 93% of the birthplaces and 78% of the nationalities it should have. Knowledge graphs are enormous — and full of holes.
The task is knowledge graph completion: given a set of known triples (head, relation, tail), such as (Marie Curie, born-in, Warsaw), predict the missing ones. Can we infer (Marie Curie, won-prize, Nobel Prize) or (Marie Curie, nationality, Polish) even though neither triple is explicitly in the graph?
A knowledge graph is a collection of facts expressed as triples: (h, r, t) where h is the head entity, r is the relation, and t is the tail entity. The graph might have millions of entities and hundreds of relation types. Completing it by hand is impossible.
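As a minimal sketch of how such triples are usually prepared for an embedding model, the snippet below maps entity and relation names to integer ids. The toy triples and the helper `build_vocab` are hypothetical, chosen only for illustration.

```python
# Hypothetical toy triples; names are illustrative, not taken from a real dataset.
triples = [
    ("Marie Curie", "born-in", "Warsaw"),
    ("Warsaw", "capital-of", "Poland"),
    ("Marie Curie", "field", "Physics"),
]


def build_vocab(items):
    """Assign each distinct item an integer id in first-seen order."""
    vocab = {}
    for item in items:
        vocab.setdefault(item, len(vocab))
    return vocab


entity_to_id = build_vocab(e for h, _, t in triples for e in (h, t))
relation_to_id = build_vocab(r for _, r, _ in triples)

# Each triple becomes a row of three integer ids: (head, relation, tail).
indexed = [
    (entity_to_id[h], relation_to_id[r], entity_to_id[t]) for h, r, t in triples
]
```

The integer ids are what an embedding table is indexed by, one row per entity and one per relation.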
A tiny fragment of a real knowledge graph. Missing edges are marked with "?" — these are what we need to predict.
Early approaches used symbolic reasoning — if A is-parent-of B and B is-parent-of C, then A is-grandparent-of C. But these hand-crafted rules don't generalize and can't handle uncertainty. We need something that learns the patterns from data.
The core idea of embedding-based KG completion is simple: map every entity and relation to a low-dimensional vector, then define a scoring function that assigns a high score to true triples and a low score to false ones.
The most general bilinear scoring function for a triple (h, r, t) is:

$$f_r(h, t) = h^\top W_r \, t$$
Here h and t are embedding vectors for the head and tail entities, and W_r is a matrix specific to relation r. This is the full bilinear model, "bilinear" because the score is linear in h and in t separately.
The problem with the full bilinear model is scale. If entity embeddings have dimension d, then W_r has d² parameters, and there is one such matrix per relation. With thousands of relations, this explodes. DistMult's insight is to restrict W_r to be a diagonal matrix.
When W_r = diag(r) for a vector r, the bilinear product simplifies dramatically:

$$f_r(h, t) = h^\top \operatorname{diag}(r) \, t = \sum_{i=1}^{d} h_i \, r_i \, t_i$$
This is the DistMult scoring function. Instead of d² parameters per relation, we now have just d: a d-fold reduction, or 100x for an embedding size of d = 100.
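The diagonal restriction can be checked numerically. The sketch below assumes NumPy and an embedding size of d = 100 (an illustrative choice, not a value from the paper). It compares the full bilinear form with a diagonal W_r against the element-wise sum, then counts parameters per relation.

```python
import numpy as np

rng = np.random.default_rng(0)
d = 100  # assumed embedding dimension, for illustration only
h, r, t = rng.normal(size=(3, d))

W_r = np.diag(r)                     # diagonal relation matrix, shape (d, d)
bilinear_score = h @ W_r @ t         # full bilinear form h^T W_r t
distmult_score = np.sum(h * r * t)   # same quantity without building W_r

full_params = d * d                  # per-relation parameters, full bilinear model
diag_params = d                      # per-relation parameters, DistMult
reduction = full_params // diag_params
```

The two scores agree up to floating-point rounding, and the reduction factor equals d.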
The DistMult scoring function is deceptively simple:

$$f_r(h, t) = \langle h, r, t \rangle = \sum_{i=1}^{d} h_i \, r_i \, t_i$$
Think of it as a three-way dot product: for each dimension i, multiply the head, relation, and tail values together, then sum. This is also called a Hadamard product followed by summation.
Intuitively: if h_i and t_i are both large (say, both entities are strongly "European"), and r_i is also large for dimension i (dimension i is relevant to this relation), the contribution to the score is high. The relation vector r acts as a selector, emphasizing the dimensions that matter for this relation type.
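A small worked example makes the per-dimension contributions concrete. The 4-dimensional vectors below are made-up values, and NumPy is assumed.

```python
import numpy as np

# Made-up 4D embeddings, for illustration only.
h = np.array([1.0, 0.5, -1.0, 2.0])
r = np.array([2.0, 0.0, 1.0, 0.5])
t = np.array([1.5, 3.0, 1.0, -1.0])

contributions = h * r * t    # per-dimension terms: [3.0, 0.0, -1.0, -1.0]
score = contributions.sum()  # 1.0
```

Dimension 1 contributes nothing because r is zero there, even though t is large in that dimension: the relation vector has switched that dimension off.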
Interactive demo: 4D embeddings for head (h), relation (r), and tail (t), showing how each dimension contributes to the final score 〈h, r, t〉.
For a query like (Marie Curie, won-prize, ?), we compute f_r(h, t) for every possible entity t, sort by score, and report the top k. If the correct answer (Nobel Prize) appears in the top 10, the query counts as a hit for Hits@10.
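Scoring every candidate tail reduces to a single matrix-vector product, because h * r can be computed once and then dotted with each entity embedding. The sketch below assumes NumPy. The function name `top_k_tails` and the tiny embedding table are hypothetical.

```python
import numpy as np


def top_k_tails(entity_emb, h_idx, r_vec, k):
    """Return the indices of the k highest-scoring tails for (h, r, ?), plus all scores."""
    # (h * r) is a d-vector; one matrix-vector product scores every candidate tail.
    scores = entity_emb @ (entity_emb[h_idx] * r_vec)
    return np.argsort(-scores)[:k], scores


# Hypothetical table of 4 entities with 2D embeddings.
entity_emb = np.array([
    [1.0, 0.0],   # entity 0 (the query head)
    [0.9, 0.1],   # entity 1
    [0.0, 1.0],   # entity 2
    [-1.0, 0.0],  # entity 3
])
r_vec = np.array([1.0, 0.5])

top, scores = top_k_tails(entity_emb, h_idx=0, r_vec=r_vec, k=2)
```

This sketch ranks every entity, including the head itself. Whether to exclude the head or already-known triples from the ranking is an evaluation choice it does not make.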
Here's a key property hidden inside the DistMult formula. What happens if we swap h and t?

$$f_r(t, h) = \sum_{i=1}^{d} t_i \, r_i \, h_i = \sum_{i=1}^{d} h_i \, r_i \, t_i = f_r(h, t)$$
Swapping head and tail leaves the score unchanged. DistMult is symmetric by construction: for any relation r, f_r(h, t) = f_r(t, h). Always.
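The symmetry is easy to confirm numerically. The sketch below assumes NumPy and uses random 8-dimensional vectors; `distmult_score` is an illustrative helper name.

```python
import numpy as np


def distmult_score(h, r, t):
    """Three-way dot product: sum_i h_i * r_i * t_i."""
    return float(np.sum(h * r * t))


rng = np.random.default_rng(42)
h, r, t = rng.normal(size=(3, 8))

forward = distmult_score(h, r, t)   # score of (h, r, t)
backward = distmult_score(t, r, h)  # score of (t, r, h)
```

The two values agree up to floating-point rounding whatever r is, so no choice of relation vector lets the model prefer one direction over the other.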
Relations in knowledge graphs come in three flavors:
Most real-world relations are not symmetric. This is the fundamental limitation of DistMult — and it's precisely what ComplEx (the next paper) is designed to fix.
DistMult is trained with negative sampling. For each true triple (h, r, t), we generate corrupted triples by replacing either h or t with a random entity. We then maximize the score of true triples and minimize the score of corrupted ones.
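The corruption step might look like the following sketch, which uses only the standard library. The function name `corrupt` and the 50/50 choice between head and tail are assumptions made for illustration.

```python
import random


def corrupt(triple, n_entities, rng):
    """Replace the head or the tail (chosen at random) with a different entity id."""
    h, r, t = triple
    if rng.random() < 0.5:
        new_h = rng.randrange(n_entities)
        while new_h == h:  # resample until the head actually changes
            new_h = rng.randrange(n_entities)
        return (new_h, r, t)
    new_t = rng.randrange(n_entities)
    while new_t == t:      # resample until the tail actually changes
        new_t = rng.randrange(n_entities)
    return (h, r, new_t)


rng = random.Random(0)
positive = (0, 1, 2)  # (head id, relation id, tail id)
negatives = [corrupt(positive, n_entities=10, rng=rng) for _ in range(5)]
```

This sketch does not check whether a corrupted triple happens to be a true fact elsewhere in the graph. Filtering out such false negatives is a common refinement.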
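The loss function here is a logistic loss over true and corrupted triples (a margin-based pairwise ranking loss and a softmax cross-entropy over all candidate entities are common alternatives):

$$\mathcal{L} = -\log \sigma\big(f_r(h, t)\big) - \sum_{j=1}^{k} \log \sigma\big(-f_r(h'_j, t)\big)$$

where the (h'_j, t) are k negative samples with random head entities (corrupted tails are handled the same way) and σ is the sigmoid function. The model learns to assign high scores to true triples and near-zero (post-sigmoid) scores to false ones.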
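Embeddings are L2-regularized to prevent them from growing without bound (large embeddings could always produce high scores). The loss becomes:

$$\mathcal{L}_{\text{reg}} = \mathcal{L} + \lambda \Big( \sum_{e} \lVert \mathbf{e} \rVert_2^2 + \sum_{r} \lVert \mathbf{r} \rVert_2^2 \Big)$$

where 𝓛 is the unregularized loss, λ is the regularization weight, and the sums run over all entity and relation embeddings.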
```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class DistMult(nn.Module):
    def __init__(self, n_entities, n_relations, dim):
        super().__init__()
        self.entity_emb = nn.Embedding(n_entities, dim)
        self.relation_emb = nn.Embedding(n_relations, dim)
        nn.init.xavier_uniform_(self.entity_emb.weight)
        nn.init.xavier_uniform_(self.relation_emb.weight)

    def score(self, h_idx, r_idx, t_idx):
        h = self.entity_emb(h_idx)      # (batch, dim)
        r = self.relation_emb(r_idx)    # (batch, dim)
        t = self.entity_emb(t_idx)      # (batch, dim)
        return (h * r * t).sum(dim=-1)  # (batch,)

    def forward(self, pos_h, pos_r, pos_t, neg_h, neg_r, neg_t):
        pos_score = self.score(pos_h, pos_r, pos_t)
        neg_score = self.score(neg_h, neg_r, neg_t)
        # logsigmoid is numerically stable; log(sigmoid(x)) underflows to -inf
        # for very negative x.
        pos_loss = -F.logsigmoid(pos_score).mean()   # push true triples up
        neg_loss = -F.logsigmoid(-neg_score).mean()  # push corrupted triples down
        return pos_loss + neg_loss
```
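The module above does not include the L2 penalty described earlier. As a rough sketch of how the regularized objective fits together, the NumPy snippet below adds a penalty weighted by `lam` to the logistic loss for one positive triple and its negatives. The function names, the symbol `lam`, and the example values are assumptions for illustration.

```python
import numpy as np


def log_sigmoid(x):
    """Numerically stable log(sigmoid(x)) = -log(1 + exp(-x))."""
    return -np.logaddexp(0.0, -x)


def regularized_loss(pos_score, neg_scores, params, lam):
    """Logistic loss on one positive and its negatives, plus an L2 penalty."""
    data_term = -log_sigmoid(pos_score) - np.sum(log_sigmoid(-neg_scores))
    l2_term = lam * sum(np.sum(p ** 2) for p in params)
    return data_term + l2_term


# With all scores at 0, each of the three terms contributes log(2).
loss = regularized_loss(
    pos_score=0.0,
    neg_scores=np.array([0.0, 0.0]),
    params=[np.ones(4)],  # stand-in for the embeddings involved
    lam=0.1,
)
```

In PyTorch, a comparable effect is often obtained by passing `weight_decay` to the optimizer, which penalizes all parameters rather than only the embeddings used in a batch.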
Yang et al. evaluated DistMult on two standard benchmarks: WordNet (WN18), a lexical database with 18 relation types such as hypernym, hyponym, and member-of; and Freebase (FB15k), a curated subset with 1,345 relation types across 14,951 entities.
The primary metric is Hits@10: for each test triple, rank all possible tail entities by score. What fraction of the time does the true tail appear in the top 10? Higher is better.
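The metric can be sketched in a few lines of NumPy. The function name `hits_at_k` and the small score matrix are hypothetical. The rank of the true tail is taken as one plus the number of candidates that score strictly higher.

```python
import numpy as np


def hits_at_k(all_scores, true_tails, k=10):
    """Fraction of queries whose true tail is among the k highest-scoring candidates."""
    true_scores = all_scores[np.arange(len(true_tails)), true_tails]
    # rank = 1 + number of candidates scoring strictly higher than the true tail
    ranks = 1 + np.sum(all_scores > true_scores[:, None], axis=1)
    return float(np.mean(ranks <= k)), ranks


# Hypothetical scores: 3 test queries (rows) over 4 candidate entities (columns).
all_scores = np.array([
    [0.1, 0.9, 0.3, 0.5],
    [0.8, 0.2, 0.7, 0.6],
    [0.4, 0.3, 0.2, 0.1],
])
true_tails = np.array([1, 3, 3])

hits, ranks = hits_at_k(all_scores, true_tails, k=2)
```

Here the true tails rank 1st, 3rd, and 4th, so only one of the three queries counts as a hit at k = 2.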
| Model | WN18 Hits@10 | FB15k Hits@10 | Parameters |
|---|---|---|---|
| Unstructured | 38.2% | 6.3% | O(n·d) |
| TransE | 89.2% | 47.1% | O((n+m)·d) |
| RESCAL | 52.8% | 44.1% | O(n·d + m·d²) |
| DistMult | 94.2% | 57.7% | O((n+m)·d) |
On both benchmarks DistMult beats TransE: 94.2% vs. 89.2% on WN18 and 57.7% vs. 47.1% on FB15k. But there's a catch: FB15k contains many inverse relations (e.g., both "capital-of" and "is-capital") that artificially inflate scores. Later work (Toutanova and Chen, 2015) showed that DistMult's advantage comes partly from exploiting these inverse pairs.
TransE and DistMult represent two fundamentally different philosophies for modeling knowledge graph relations, and understanding the difference is crucial for knowing when to use each.
TransE (Bordes et al., 2013) models relations as translations in embedding space. A true triple (h, r, t) should satisfy h + r ≈ t. The scoring function is the negative distance: f_r(h, t) = −‖h + r − t‖.
| Property | TransE | DistMult |
|---|---|---|
| Scoring | −‖h + r − t‖ | 〈h, r, t〉 |
| Relation params | d per relation | d per relation |
| Symmetric relations | Struggles | Natural fit |
| Antisymmetric relations | Handles well | Cannot model |
| Geometric intuition | Translation | Feature selection |
| WN18 Hits@10 | 89.2% | 94.2% |
| FB15k Hits@10 | 47.1% | 57.7% |
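The symmetry rows of the table can be illustrated with a direct comparison. The sketch below assumes NumPy and uses made-up 2D vectors chosen so that h + r equals t exactly. The helper names are illustrative.

```python
import numpy as np


def transe_score(h, r, t):
    """Negative L2 distance between h + r and t."""
    return -float(np.linalg.norm(h + r - t))


def distmult_score(h, r, t):
    """Three-way dot product."""
    return float(np.sum(h * r * t))


# Made-up vectors with h + r == t, so TransE rates (h, r, t) as a perfect fit.
h = np.array([1.0, 2.0])
r = np.array([0.5, -1.0])
t = np.array([1.5, 1.0])

transe_fwd = transe_score(h, r, t)      # 0: h + r lands exactly on t
transe_bwd = transe_score(t, r, h)      # -sqrt(5): the reversed triple is penalized
distmult_fwd = distmult_score(h, r, t)  # -1.25
distmult_bwd = distmult_score(t, r, h)  # -1.25, identical by symmetry
```

TransE separates the two directions, which suits antisymmetric relations. DistMult cannot, whatever the vectors are.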
This complementarity is not accidental — it points toward a deeper truth. You need a model expressive enough to handle both symmetric and antisymmetric relations. That's exactly what ComplEx provides, by extending DistMult to complex vector spaces.
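To see how complex values break the symmetry, here is a minimal sketch of a ComplEx-style score: the real part of the sum of h_i · r_i · conj(t_i). The vectors are made-up values and NumPy is assumed. This illustrates the scoring idea only, not a training setup.

```python
import numpy as np


def complex_score(h, r, t):
    """ComplEx-style score: Re(sum_i h_i * r_i * conj(t_i))."""
    return float(np.real(np.sum(h * r * np.conj(t))))


# Made-up complex embeddings; the imaginary part of r is what breaks symmetry.
h = np.array([1 + 1j, 2 + 0j])
r = np.array([0 + 1j, 1 + 0j])
t = np.array([1 - 1j, 0 + 1j])

fwd = complex_score(h, r, t)  # -2.0
bwd = complex_score(t, r, h)  # 2.0, so swapping head and tail changes the score

# With purely real vectors the conjugate does nothing and the score is DistMult's.
h_real = np.array([1.0, 2.0])
r_real = np.array([0.5, -1.0])
t_real = np.array([1.5, 1.0])
real_case = complex_score(h_real, r_real, t_real)  # -1.25
```

A purely real relation vector gives back symmetric behaviour, so the symmetric case remains available inside the larger model.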
DistMult is a foundational paper in knowledge graph embeddings — simple enough to analyze mathematically, strong enough to anchor years of subsequent research. Here's where it fits in the broader landscape.
| Model | Key Idea | DistMult Relation |
|---|---|---|
| RESCAL | Full bilinear W_r matrix | DistMult = diagonal RESCAL |
| TransE | Translation h + r ≈ t | Different family entirely |
| ComplEx | Complex-valued DistMult | Fixes symmetry limitation |
| RotatE | Rotation in complex space | Inspired by ComplEx/DistMult |
| TuckER | Tucker decomposition | Generalization of DistMult |
| R-GCN | Graph convolution + DistMult decoder | Uses DistMult as scoring head |
Further reading: