Seq2seq models use the chain rule to decompose joint probabilities. But the chain rule works in any order — and some orders are dramatically better than others. What do you do when your data has no natural order at all?
You have five numbers: {3, 1, 4, 1, 5}. Your job is to sort them. Simple for a human, but how do you feed them to a neural network?
You could line them up left to right: 3, 1, 4, 1, 5. Feed them into an LSTM encoder, then decode the sorted output: 1, 1, 3, 4, 5. This is the sequence-to-sequence approach, and it works. But there is something deeply wrong with it.
The input is a set. The numbers {3, 1, 4, 1, 5} and {5, 4, 3, 1, 1} and {1, 3, 1, 5, 4} are all the same input. The answer should be identical regardless of how you arrange them. But an LSTM reads left to right. It sees 3-then-1-then-4 as a fundamentally different sequence from 5-then-4-then-3. You are forcing an order onto data that has none.
The problem cuts both ways. Sometimes the input is a set (like numbers to sort, or objects detected in an image). Sometimes the output is a set (like predicting which objects are in a scene, where the order of your predictions shouldn't matter). Sometimes both are sets.
Vinyals, Bengio and Kudlur tackle both sides. For input sets, they propose the Read-Process-and-Write architecture, which uses attention to build an order-invariant encoding. For output sets, they propose searching over orderings during training, letting the model discover the best order on its own.
Seq2seq models work by decomposing the output probability using the chain rule:

$$P(Y \mid X) = \prod_{t=1}^{|Y|} P(y_t \mid y_1, \dots, y_{t-1}, X)$$
This is mathematically exact. No approximations, no independence assumptions. You just predict one token at a time, each conditioned on everything before it. An LSTM implements this naturally: at each step, its hidden state summarizes all previous outputs, and a softmax predicts the next one.
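But here is the catch. The chain rule works in any order. For three random variables, all of these are equally valid:

$$
\begin{aligned}
P(a, b, c) &= P(a)\,P(b \mid a)\,P(c \mid a, b) \\
&= P(b)\,P(c \mid b)\,P(a \mid b, c) \\
&= P(c)\,P(a \mid c)\,P(b \mid a, c)
\end{aligned}
$$

The remaining three of the 3! = 6 orderings work the same way.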
In theory, a sufficiently powerful model should learn the correct joint distribution regardless of which factorization you choose. In practice, it does not. The order you pick determines which conditional distributions the model must learn. Some conditionals are easy. Others are brutally hard.
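The exactness claim is easy to check numerically. The sketch below (plain Python; the random joint distribution and the helper names `make_joint`, `marginal`, and `chain_rule` are illustrative choices, not from the paper) builds an arbitrary joint over three binary variables and multiplies out the conditionals in all 3! orders. Every order recovers the same joint probability; what changes is only which conditionals a model would have to learn.

```python
import itertools
import random

def make_joint(seed=0):
    """Random joint distribution over three binary variables (a, b, c)."""
    rng = random.Random(seed)
    weights = {xs: rng.random() + 0.01 for xs in itertools.product([0, 1], repeat=3)}
    total = sum(weights.values())
    return {xs: w / total for xs, w in weights.items()}

def marginal(joint, idx, vals):
    """Probability that the variables at positions `idx` take the values `vals`."""
    return sum(p for xs, p in joint.items()
               if all(xs[i] == v for i, v in zip(idx, vals)))

def chain_rule(joint, x, order):
    """Multiply P(x_k | variables earlier in `order`) along the whole order."""
    prob = 1.0
    for step in range(len(order)):
        seen, prev = order[:step + 1], order[:step]
        num = marginal(joint, seen, [x[i] for i in seen])
        den = marginal(joint, prev, [x[i] for i in prev])  # empty prefix sums to 1
        prob *= num / den
    return prob

joint = make_joint()
x = (1, 0, 1)
for order in itertools.permutations(range(3)):
    print(order, chain_rule(joint, x, order), joint[x])
```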
Think about language. English sentences flow left to right, and each word is heavily predicted by the words before it: "The cat sat on the ___" strongly constrains the next word. Reading right to left turns out to be nearly as easy, because every word still sits next to the words it depends on. But shuffle the words, as in "sat the The on cat", and the conditioning context no longer lines up with the dependencies of the language. Each prediction now has to be made from context that runs against the grain.
For sequences like English, there is a natural order. But for sets, there is no natural order. Every order you pick is arbitrary. And as we will see, some arbitrary choices are catastrophically worse than others.
Before proposing a solution, the paper first demonstrates the problem with hard evidence. Consider three cases where changing only the input order made a measurable difference:
Machine translation. Sutskever et al. (2014) found that reversing the input English sentence before translating to French improved BLEU score by 5 points. Just reading "cats like I" instead of "I like cats" — same words, reversed order — produced dramatically better translations. Why? Because reversing puts the first English words closest to the first French words in the decoder, reducing the effective distance the LSTM must bridge.
Constituency parsing. When parsing English sentences into syntax trees, reversing the input sentence improved F1 by 0.5% absolute. Again, just flipping the reading order.
Convex hull computation. When computing the convex hull of a set of 2D points, sorting the input points by angle before feeding them to the model increased accuracy by up to 10% absolute. Once the points are sorted by angle, the rest of the hull computation takes only O(n) time (as in the Graham scan, where the O(n log n) sort is the expensive step), so the network is left with a much easier problem.
For naturally sequential data like language, we at least have a reasonable default order (left to right). But what about sets? If you have 15 numbers to sort, which of the 15! ≈ 1.3 trillion possible input orderings should you pick? The answer is: none of them. You need an architecture that does not care about order at all.
A sequential encoder (like an LSTM) produces a different internal representation for every ordering of the same set. This means the decoder sees a different "summary" of the input depending on an arbitrary choice that has nothing to do with the task. The model must waste capacity learning to be invariant to these spurious differences.
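A toy version of this effect: a scalar tanh RNN with arbitrary, untrained weights reads the same multiset in every possible order. It is a sketch, not an LSTM, and `rnn_encode`, `w_in`, and `w_rec` are hypothetical names for illustration, but it shows the point: the final state depends on the order.

```python
import itertools
import math

def rnn_encode(sequence, w_in=0.3, w_rec=0.9):
    """Toy sequential encoder: h_t = tanh(w_rec * h_{t-1} + w_in * x_t).

    The scalar weights are arbitrary illustrative values, not trained ones.
    """
    h = 0.0
    for x in sequence:
        h = math.tanh(w_rec * h + w_in * x)
    return h

# 60 distinct orderings of the multiset {3, 1, 4, 1, 5} (the two 1s are interchangeable).
orderings = set(itertools.permutations([3, 1, 4, 1, 5]))
states = {p: rnn_encode(p) for p in orderings}

print(len(orderings))
print(rnn_encode([3, 1, 4, 1, 5]), rnn_encode([5, 4, 3, 1, 1]))  # same set, different states
print(min(states.values()), max(states.values()))
```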
How do we build an encoder that genuinely does not care about input order? The simplest approach is the bag of words: just add up the embeddings of all input elements. Adding is commutative, so {A, B, C} and {C, A, B} produce the same sum. Problem solved?
Not quite. Addition throws away structural information, so different sets can collide: with a linear embedding $e(x) = x\,w$, the sets {3, 7} and {5, 5} both sum to $10w$ and become indistinguishable. Worse, the representation has a fixed dimensionality regardless of how many items are in the set. You are trying to cram an arbitrarily large set into a single fixed-size vector, so for large sets, information is inevitably lost.
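A minimal sketch of the bag-of-words encoder, assuming a linear embedding $e(x) = x\,w$ (`embed`, `sum_pool`, and the vector `w` are illustrative, with integer entries so the comparisons are exact). Reordering the input leaves the encoding unchanged, but two different sets can still collide.

```python
def embed(x, w):
    """Hypothetical linear embedding e(x) = x * w."""
    return [x * wi for wi in w]

def sum_pool(items, w):
    """Bag-of-words set encoding: elementwise sum of the item embeddings."""
    total = [0] * len(w)
    for x in items:
        for i, v in enumerate(embed(x, w)):
            total[i] += v
    return total

w = [1, -2, 3]
print(sum_pool([3, 1, 4, 1, 5], w) == sum_pool([5, 4, 3, 1, 1], w))  # True: order is ignored
print(sum_pool([3, 7], w) == sum_pool([5, 5], w))                    # True: different sets collide
```

A nonlinear embedding can avoid this particular collision, but the fixed output size remains.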
The solution is content-based attention. Instead of crushing all inputs into a single vector, we keep them all around as a memory bank. When the model needs information, it queries this memory and retrieves a weighted combination of the stored items. The key property: the attention mechanism computes a weighted sum over memory slots, and a weighted sum is invariant to the order of the slots.
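Here is how it works. We have memory vectors $m_1, \dots, m_n$ (one per input element) and a query vector $q$. The attention mechanism computes:

$$
\begin{aligned}
e_i &= f(m_i, q) \\
a_i &= \frac{\exp(e_i)}{\sum_{j=1}^{n} \exp(e_j)} \\
r &= \sum_{i=1}^{n} a_i \, m_i
\end{aligned}
$$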
Here $f$ is a scoring function (e.g., a dot product). The readout $r$ is a weighted average of the memories, since the weights $a_i$ sum to 1. If you permute the memories, say by swapping $m_i$ and $m_j$, the weights $a_i$ and $a_j$ swap with them, and the sum $r$ stays the same. Order invariance is guaranteed by the mathematics of weighted summation.
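The invariance argument can be checked directly. This sketch assumes a dot-product scoring function $f$ and small hand-picked memory vectors; `attend` is an illustrative helper, not code from the paper.

```python
import math

def attend(memories, query):
    """Content-based attention: e_i = m_i . q, a = softmax(e), r = sum_i a_i m_i."""
    scores = [sum(m_d * q_d for m_d, q_d in zip(m, query)) for m in memories]
    top = max(scores)
    exps = [math.exp(s - top) for s in scores]  # shift by the max for numerical stability
    total = sum(exps)
    weights = [e / total for e in exps]
    dim = len(memories[0])
    readout = [sum(a * m[d] for a, m in zip(weights, memories)) for d in range(dim)]
    return weights, readout

memories = [[0.1, 0.9], [0.8, 0.2], [0.5, 0.5]]
query = [1.0, -1.0]
w_orig, r_orig = attend(memories, query)
w_perm, r_perm = attend([memories[2], memories[0], memories[1]], query)

print(w_orig, w_perm)  # the same weights, permuted along with the memories
print(r_orig, r_perm)  # identical readouts, up to floating-point rounding
```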
But a single attention readout may not capture everything. What if the model needs to reason about relationships between items? That requires multiple rounds of attention — reading the memory repeatedly, each time with a different query that has been updated by what was read previously.
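The paper's main architectural contribution is the Read-Process-and-Write model. It has three stages:

1. Read: each input element $x_i$ is embedded independently by a small neural network into a memory vector $m_i$.
2. Process: an LSTM runs $T$ steps of attention over the memories, producing an order-invariant summary of the set.
3. Write: a pointer network produces the output by pointing at input elements.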
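The Process block is the clever part. It is an LSTM that takes no inputs and produces no outputs. It just thinks. At each of its $T$ processing steps $t$, it:

1. updates its state from the previous step, $q_t = \mathrm{LSTM}(q^*_{t-1})$;
2. uses $q_t$ as the query for attention over the memories, producing a readout $r_t$;
3. concatenates the two into $q^*_t = [q_t, r_t]$, which feeds the next step.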
After $T$ steps, the final state $q^*_T$ is a rich, order-invariant summary of the entire input set. The number of processing steps $T$ is a hyperparameter: more steps let the model perform more complex reasoning about relationships within the set.
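A compact sketch of the Process loop, under two loudly stated assumptions: a single tanh layer with a hypothetical weight matrix `W` stands in for the LSTM, and the initial $q^*_0$ is a zero vector. Because the memories enter only through the attention readout, shuffling them leaves $q^*_T$ unchanged.

```python
import math
import random

def attend(memories, query):
    """Dot-product attention readout r = sum_i softmax(m_i . q)_i * m_i."""
    scores = [sum(a * b for a, b in zip(m, query)) for m in memories]
    top = max(scores)
    exps = [math.exp(s - top) for s in scores]
    total = sum(exps)
    dim = len(memories[0])
    return [sum(e / total * m[d] for e, m in zip(exps, memories)) for d in range(dim)]

def process(memories, W, steps):
    """Process block sketch: q_t = tanh(W q*_{t-1}), r_t = attend(q_t), q*_t = [q_t, r_t].

    The tanh layer stands in for the paper's LSTM; q*_0 = 0 is an assumption.
    """
    dim = len(memories[0])
    q_star = [0.0] * (2 * dim)
    for _ in range(steps):
        q = [math.tanh(sum(w * x for w, x in zip(row, q_star))) for row in W]
        q_star = q + attend(memories, q)  # list concatenation: [q_t, r_t]
    return q_star

rng = random.Random(0)
dim = 3
W = [[rng.gauss(0, 1) for _ in range(2 * dim)] for _ in range(dim)]  # dim x (2 * dim)
memories = [[rng.gauss(0, 1) for _ in range(dim)] for _ in range(5)]
shuffled = memories[:]
rng.shuffle(shuffled)

a = process(memories, W, steps=5)
b = process(shuffled, W, steps=5)
print(max(abs(x - y) for x, y in zip(a, b)))  # ~0 (floating-point rounding only)
```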
The Write block uses a pointer network: instead of outputting tokens from a fixed vocabulary, it points at specific input elements. For sorting, this means it points at the input numbers in sorted order. The write block can also use an additional attention step (called a glimpse) before each pointer output, which the paper found significantly improves performance.
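To make "pointing" concrete: the output at each step is a softmax over input positions rather than over a vocabulary. In this sketch the scores come from a hypothetical hand-set function (`prefer_small`) instead of learned attention between the decoder state and the memories, and already-chosen positions are masked out, an assumption that fits permutation outputs like sorting. No glimpse step is included.

```python
import math

def pointer_distribution(scores, used):
    """Softmax over input positions; positions already pointed at get probability 0."""
    masked = [-math.inf if i in used else s for i, s in enumerate(scores)]
    top = max(masked)
    exps = [0.0 if s == -math.inf else math.exp(s - top) for s in masked]
    total = sum(exps)
    return [e / total for e in exps]

def greedy_pointer_decode(inputs, score_fn):
    """Emit a permutation of input indices, one pointer step at a time."""
    emitted = []
    for _ in inputs:
        probs = pointer_distribution(score_fn(inputs, emitted), set(emitted))
        emitted.append(max(range(len(probs)), key=probs.__getitem__))
    return emitted

def prefer_small(inputs, emitted):
    """Hypothetical hand-set scorer: smaller values get higher scores.

    In a trained Ptr-Net these scores would come from attention between the
    decoder state (which depends on `emitted`) and the encoded inputs.
    """
    return [-10.0 * x for x in inputs]

values = [0.3, 0.1, 0.4, 0.15, 0.5]
order = greedy_pointer_decode(values, prefer_small)
print(order)                       # [1, 3, 0, 2, 4]
print([values[i] for i in order])  # [0.1, 0.15, 0.3, 0.4, 0.5]
```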
You can think of the entire architecture as a special case of a Neural Turing Machine or Memory Network, but specifically designed to guarantee permutation invariance of the input encoding.
The paper tests the Read-Process-Write architecture on a clean synthetic task: sorting N random numbers between 0 and 1. This is a pure set-to-sequence problem. The input is an unordered set. The output is a specific sequence (the sorted order).
They compare two approaches:
| | Ptr-Net (baseline) | Read-Process-Write |
|---|---|---|
| Encoder | LSTM reads numbers sequentially | Embed each number, then Process block with attention |
| Decoder | Pointer network | Pointer network (same) |
| Order invariant? | No | Yes |
The results tell a clear story:
| N | Ptr-Net | P=0 steps | P=1 step | P=5 steps | P=10 steps |
|---|---|---|---|---|---|
| 5 | 90% | 84% | 92% | 94% | 94% |
| 10 | 28% | 30% | 44% | 57% | 50% |
| 15 | 4% | 2% | 5% | 4% | 10% |
(All results with glimpses enabled. N = number of values to sort; P = number of Process steps, called $T$ above. Accuracy = fraction of sequences sorted perfectly.)
Notice how rapidly accuracy drops with N. Sorting 5 numbers is nearly solved (94%). Sorting 15 is still very hard (10%). This is a combinatorial problem: the number of possible orderings grows as N!, and the model must find the single correct one. Still, for every N the best Read-Process-Write configuration beats the sequential baseline, although not every setting does: with P=0 it falls behind at N=5 and N=15.
The glimpse mechanism also matters enormously. Without glimpses, the best N=10 accuracy is only 19%. With glimpses, it reaches 57%. The glimpse is an extra attention step that lets the Write block look at the memory between each pointer output, providing fine-grained, context-dependent information.
We have handled input sets. Now let us flip the problem. What happens when the output is a set?
The chain rule forces you to produce outputs one at a time, in some order. But if the output is a set, every ordering is equally valid. The output set {1, 3, 5} could be produced as 1→3→5, 5→3→1, or 3→1→5; all represent the same answer. Which factorization should the model learn?
The paper demonstrates the impact with two experiments:
Language modeling. They train LSTMs on Penn Treebank text in three orderings:
| Order | Example | Perplexity |
|---|---|---|
| Natural | "This is a sentence ." | 86 |
| Reversed | ". sentence a is This" | 86 |
| 3-word scramble | "a is This <pad> . sentence" | 96 |
Natural and reversed perform identically — both preserve local word dependencies (just mirrored). But the 3-word scramble, which breaks n-gram structure, costs 10 perplexity points. The model's capacity is wasted trying to learn scrambled conditionals.
Constituency parsing. A parse tree can be linearized in two ways: depth-first traversal or breadth-first traversal. Same tree, same information, different orderings. Depth-first achieves 89.5% F1. Breadth-first drops to 81.5% — an 8-point gap from ordering alone.
For combinatorial problems like sorting, the situation is even more dramatic. If you treat the output indices as a set and train with random orderings, the model must place equal probability on all N! valid orderings for every input. This is catastrophically inefficient: for N=5, there are 120 valid outputs for each input. The model's probability mass is spread thin, and convergence is painfully slow or impossible.
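The cost of spreading probability mass is easy to quantify. If the target ordering is drawn uniformly from the N! orderings, even a perfect model can assign the observed one at most 1/N!, so its log-likelihood per example is capped at -log N!: about -4.8 nats for N = 5. A small arithmetic sketch (`best_case_log_likelihood` is an illustrative name):

```python
import math

def best_case_log_likelihood(n):
    """Upper bound on log P(target ordering) when the training target is a
    uniformly random ordering of an n-element set: log(1 / n!)."""
    return -math.log(math.factorial(n))

for n in (5, 10, 15):
    print(n, math.factorial(n), round(best_case_log_likelihood(n), 2))
```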
If the output order matters but we do not know the best one, can we let the model find it?
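The paper proposes a surprisingly simple idea. Instead of maximizing log-probability under a fixed ordering, maximize over all possible orderings $\pi$ of each target set:

$$\theta^* = \arg\max_{\theta} \sum_{(X, Y)} \max_{\pi} \log P\big(\pi(Y) \mid X; \theta\big)$$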
For each training example, find the ordering π that gives the highest probability under the current model, and train on that ordering. The model simultaneously learns the parameters and discovers the best order.
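A deliberately tiny illustration of the training rule, under strong simplifying assumptions: the "model" is just a softmax over the 3! orderings of a three-element set (no inputs, no LSTM), and `train`, `pick_most_probable`, and `pick_uniformly` are hypothetical names. Training on the currently most probable ordering makes that ordering steadily more probable, while training on a fresh random ordering each step keeps the mass spread out.

```python
import itertools
import math
import random

def softmax(logits):
    top = max(logits)
    exps = [math.exp(z - top) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

def pick_most_probable(probs, rng):
    """Max-over-orderings rule: train on the ordering the model currently likes best."""
    return max(range(len(probs)), key=probs.__getitem__)

def pick_uniformly(probs, rng):
    """Baseline: train on a uniformly random ordering."""
    return rng.randrange(len(probs))

def train(num_orderings, choose, steps=200, lr=0.5, seed=0):
    """Gradient ascent on log p(chosen ordering) for a softmax over orderings."""
    rng = random.Random(seed)
    logits = [rng.uniform(-0.01, 0.01) for _ in range(num_orderings)]
    for _ in range(steps):
        probs = softmax(logits)
        k = choose(probs, rng)
        # Gradient of log softmax(logits)[k] w.r.t. the logits is onehot(k) - probs.
        logits = [z + lr * ((1.0 if i == k else 0.0) - p)
                  for i, (z, p) in enumerate(zip(logits, probs))]
    return softmax(logits)

orderings = list(itertools.permutations("abc"))  # the 3! = 6 orderings of {a, b, c}
greedy = train(len(orderings), pick_most_probable)
uniform = train(len(orderings), pick_uniformly)
print("max-over-orderings:", max(greedy))
print("random orderings:  ", max(uniform))
```

In this toy every ordering is equally easy to model, so which one wins is decided by the random initialization alone. In a real model, orderings differ in how easy their conditionals are to learn, which is presumably what tips the 5-gram experiment toward natural word order.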
But there are N! possible orderings. You cannot try them all. The paper addresses this with two tricks:
They test this on 5-gram language modeling. Each 5-gram "This is a five gram" is converted to a set of (word, position) tuples: {(This,1), (is,2), (a,3), (five,4), (gram,5)}. The model must produce these tuples, but can choose any ordering.
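The set-of-tuples trick loses nothing: because each word carries its position, any ordering of the tuples can be turned back into the sentence. A quick check, with `to_tuples` and `from_tuples` as illustrative helper names:

```python
import itertools

def to_tuples(words):
    """Represent a sentence as an unordered set of (word, position) pairs."""
    return {(w, i) for i, w in enumerate(words, start=1)}

def from_tuples(pairs):
    """Recover the sentence from the pairs, in whatever order they arrive."""
    return [w for w, _ in sorted(pairs, key=lambda pair: pair[1])]

words = "This is a five gram".split()
pairs = to_tuples(words)
orderings = list(itertools.permutations(pairs))
print(len(orderings))                                   # 120 = 5! possible output orderings
print(all(from_tuples(o) == words for o in orderings))  # True
```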
| Setup | Orderings considered | Perplexity |
|---|---|---|
| Natural order (1,2,3,4,5) | 1 | 225 |
| Scrambled (5,1,3,4,2) | 1 | 280 |
| Easy search (2 options) | 2 | 225 |
| Full search (5! options) | 120 | 225 |
This is remarkable. The model is given a set with no ordering information. Through training dynamics alone, it discovers that left-to-right (or right-to-left) English word order minimizes perplexity. The chain rule's factorization order is not just a modeling choice — the model can learn the best one.
The paper validates its ideas on standard benchmarks, not just synthetic tasks.
Language modeling on Penn Treebank. The key finding is not about achieving state-of-the-art perplexity, but about demonstrating the ordering effect at scale. Natural order and reversed order match at 86 perplexity. The 3-word scramble degrades to 96. Even the mighty LSTM cannot fully compensate for a bad ordering — and training perplexity is also 10 points higher, confirming that the model struggles to learn the scrambled conditionals, not just to generalize them.
Constituency parsing. The depth-first vs breadth-first comparison (89.5% vs 81.5% F1) shows that output ordering is not a minor detail. The trees, the data, and the model are the same in both cases; only the linearization order changes, and that alone costs 8 points of F1.
Graphical model estimation. The paper generates star-shaped graphical models: one "head" variable connected to several "leaf" variables. The leaves are conditionally independent given the head. They train LSTMs to model the joint probability in two orderings: head-first vs head-last.
This is perhaps the deepest insight of the paper. The best ordering for the chain rule is the one that aligns with the causal structure of the data. When causes come before effects in the factorization, each conditional is simple. When effects come before causes, the conditionals become complex marginalizations that are hard for a finite-capacity model to learn.
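A small numerical illustration of this point, with made-up numbers (an assumption for illustration, not the paper's setup): a binary head H with three binary leaves, each leaf agreeing with H with probability 0.9. In the head-first order, a leaf's conditional depends on H alone. In the head-last order, the same leaf's conditional depends on every earlier leaf, because the model has to marginalize over the unseen head. `star_joint` and `conditional` are illustrative helpers.

```python
import itertools

P_HEAD = 0.5               # P(H = 1); illustrative value
P_LEAF = {1: 0.9, 0: 0.1}  # P(L_i = 1 | H = h); illustrative values
N_LEAVES = 3

def star_joint():
    """Enumerate P(H, L1, ..., Ln); index 0 is the head, 1..n are the leaves."""
    table = {}
    for h, *leaves in itertools.product([0, 1], repeat=1 + N_LEAVES):
        p = P_HEAD if h else 1 - P_HEAD
        for leaf in leaves:
            p *= P_LEAF[h] if leaf else 1 - P_LEAF[h]
        table[(h, *leaves)] = p
    return table

def conditional(table, target, value, given):
    """P(X_target = value | X_i = v for every (i, v) in `given`)."""
    def matches(xs):
        return all(xs[i] == v for i, v in given.items())
    den = sum(p for xs, p in table.items() if matches(xs))
    num = sum(p for xs, p in table.items() if matches(xs) and xs[target] == value)
    return num / den

t = star_joint()
# Head-first: P(L3 | H, L1, L2) ignores the other leaves once H is known (about 0.9 both times).
print(conditional(t, 3, 1, {0: 1, 1: 1, 2: 1}), conditional(t, 3, 1, {0: 1, 1: 0, 2: 1}))
# Head-last: P(L3 | L1, L2) changes with every earlier leaf (about 0.89 versus 0.5).
print(conditional(t, 3, 1, {1: 1, 2: 1}), conditional(t, 3, 1, {1: 0, 2: 1}))
```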
Pointer Networks (Vinyals et al., 2015). The Write block of Read-Process-Write uses pointer networks as its output mechanism. Instead of selecting from a fixed vocabulary, it points at input elements. This is essential for combinatorial problems where the output is a permutation of the input.
Neural Turing Machines (Graves et al., 2014) and Memory Networks (Weston et al., 2015). The Read-Process-Write architecture can be viewed as a special case of these external-memory architectures. The memory bank stores input embeddings, and the Process block reads from it using content-based addressing. The key specialization is the focus on permutation invariance.
Set functions and Deep Sets. This paper is an early step toward the later "Deep Sets" framework (Zaheer et al., 2017), which formalized permutation-invariant functions. The insight that set encodings must be order-invariant was foundational. Deep Sets proved that, under some conditions (for example, elements drawn from a countable domain), any permutation-invariant function can be decomposed as $\rho\left(\sum_i \phi(x_i)\right)$. The Read-Process-Write results suggest that additional rounds of attention help in practice: in the sorting experiments, more Process steps generally improved accuracy.
Transformers (Vaswani et al., 2017). The Process block, with its repeated rounds of attention over a set of memory vectors, anticipates a core idea of Transformers. In a Transformer, every layer performs attention over all positions. Positional encodings are added to inject order; without them, self-attention is permutation-equivariant (shuffling the inputs just shuffles the outputs), so a symmetric pooling of its outputs, such as a mean, is a set function. The connection is direct: both architectures process sets through iterative attention.
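A quick check of the equivariance claim, using single-head dot-product self-attention with identity query/key/value projections and no positional encodings. This is a simplification: real Transformer layers use learned per-position linear projections, which do not change the argument. `self_attention` is an illustrative helper.

```python
import math

def self_attention(xs):
    """Dot-product self-attention with identity Q/K/V projections, no positional encodings."""
    dim = len(xs[0])
    scale = math.sqrt(dim)
    outputs = []
    for q in xs:
        scores = [sum(a * b for a, b in zip(q, k)) / scale for k in xs]
        top = max(scores)
        exps = [math.exp(s - top) for s in scores]
        total = sum(exps)
        outputs.append([sum(e / total * v[d] for e, v in zip(exps, xs))
                        for d in range(dim)])
    return outputs

xs = [[1.0, 0.0], [0.0, 1.0], [0.5, -0.5]]
perm = [2, 0, 1]
out = self_attention(xs)
permuted_outputs = [out[i] for i in perm]                     # attend, then permute
outputs_of_permuted = self_attention([xs[i] for i in perm])   # permute, then attend
diff = max(abs(x - y)
           for ra, rb in zip(permuted_outputs, outputs_of_permuted)
           for x, y in zip(ra, rb))
print(diff)  # ~0, up to floating-point rounding
```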
Paper impact. Published at ICLR 2016, this paper influenced the development of set-based neural architectures, non-autoregressive decoding strategies, and the understanding of inductive biases in sequence models. Every time you add positional encodings to a Transformer, you are acknowledging the insight from this paper: without explicit order information, attention treats its inputs as a set.