Microsoft’s bid to break the “impossible triangle” — one retention mechanism you can run three different ways: parallel to train, recurrent to deploy, chunkwise to scale.
When you design a sequence model, you want three things at once: training parallelism, so you can use GPUs efficiently and scale; cheap inference, ideally constant cost per generated token with no growing memory; and strong performance, meaning quality that matches the best models. For years, the brutal rule was: pick two. RetNet (Retentive Network), from Microsoft in 2023, set out to grab all three at once and break what the paper called the “impossible triangle.”
Each known approach sat on one edge of the triangle, covering two corners but not the third. RetNet's claim was to reach the center — all three corners at once — through one mechanism called retention, which can be computed in three mathematically equivalent ways, each optimal for a different corner.
This is RetNet's signature idea, and what makes it distinct from its cousins. It's not that RetNet picks a clever middle point — it's that retention is a computation you can literally evaluate three different ways that all produce the identical result. Train with the parallel form (GPU-efficient). Deploy with the recurrent form (constant memory, constant per-token cost). Process very long sequences with the chunkwise form (parallel within chunks, recurrent across them). You switch forms to fit the job, never retraining. The rest of this lesson builds each form and shows they're the same thing.
The widget shows the three corners. Click each architecture to see which corners it covers — transformers and RNNs each light up two, leaving one dark. Then click RetNet and watch all three light up. That “all three” is the goal; the three computation forms are how it's reached.
Three corners: parallel training, cheap inference, strong performance. Click each model to see which corners it covers. RetNet aims for all three via its three computation forms.
The heart of RetNet is retention. The quickest way to understand it: take attention, remove the softmax, and add a fixed decay that makes a token pay less attention to the past the further back it is. Recent tokens count fully; distant tokens fade. That single change — softmax out, distance-decay in — is what makes retention expressible in those three forms.
In retention, how much token n attends to an earlier token m depends on two things: their query-key similarity (as in attention), and a decay factor that shrinks with the gap between them. The decay is a fixed number, call it gamma, between 0 and 1, raised to the power of the distance. So a token one step back is weighted by gamma; two steps back by gamma-squared; k steps back by gamma-to-the-k. The weighting falls off geometrically with distance — an exponential forgetting built right into the mechanism, not learned per token.
Take gamma = 0.9. How much does the current token retain from tokens at various distances back?
| distance back | weight (0.9 ^ distance) |
|---|---|
| 0 (current) | 0.9^0 = 1.00 |
| 1 | 0.9^1 = 0.90 |
| 5 | 0.9^5 = 0.59 |
| 10 | 0.9^10 = 0.35 |
| 50 | 0.9^50 ≈ 0.005 |
By 50 steps back, the weight has collapsed to half a percent — effectively forgotten. Lower gamma (say 0.5) forgets much faster; higher gamma (0.99) remembers far longer. This single number sets the model's “memory horizon.” And the genius is that this geometric decay is exactly the form that turns into a simple multiply-by-gamma in the recurrent state — remember the past, scaled down by gamma, then add the new token. The decay you see here is the recurrence, in disguise.
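To make the table concrete, here is a small Python sketch that recomputes the weights and estimates a memory horizon. The function names and the 1% cutoff are illustrative assumptions, not part of RetNet itself.

```python
def retention_weight(gamma: float, distance: int) -> float:
    """Weight given to a token `distance` steps back: gamma ** distance."""
    return gamma ** distance


def memory_horizon(gamma: float, threshold: float = 0.01) -> int:
    """Smallest distance at which the weight drops below `threshold`.

    The 1% default is an arbitrary cutoff chosen for illustration.
    """
    distance = 0
    while retention_weight(gamma, distance) >= threshold:
        distance += 1
    return distance


if __name__ == "__main__":
    # Reproduce the rows of the table for gamma = 0.9.
    for d in (0, 1, 5, 10, 50):
        print(d, round(retention_weight(0.9, d), 3))
    # Compare horizons for a fast, medium, and slow decay.
    for g in (0.5, 0.9, 0.99):
        print(g, memory_horizon(g))
```

Running it shows the same pattern as the table: the lower the gamma, the sooner the weight falls below the cutoff.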
The widget plots how much the current token retains from each past token, under the decay. Drag gamma: a low value makes a sharp, recency-focused curve (short memory); a high value makes a long, slow tail (long memory). This is the shape of retention — and, as later chapters show, the shape of the recurrent state's forgetting.
Retention weight (gamma^distance) for tokens further back. Drag gamma: low = short memory (recency), high = long memory. This decay is what makes retention expressible as a recurrence.
The first of retention's three forms is the parallel form, and it's the one you use for training. It looks almost exactly like attention — compute all the pairwise scores at once as a matrix — which means it's just as GPU-friendly and parallelizable. The only difference from attention is what fills the matrix.
In attention, you compute the query-key score matrix and apply softmax. In retention's parallel form, you compute the query-key score matrix and instead multiply it, elementwise, by a decay mask — a matrix that encodes gamma-to-the-distance for every pair of positions, and is zero for future positions (so a token can't see ahead). Entry (n, m) of the mask is gamma raised to the power (n minus m) when m is in the past, and 0 otherwise. So the mask is lower-triangular, and within the triangle it decays as you move away from the diagonal.
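Written as a formula, the description above might look like the following. The symbols Q, K, V (query, key, and value matrices) and D (the decay mask) are notation introduced here for illustration; the sketch leaves out any scaling or positional terms a full implementation may add.

```latex
\begin{aligned}
\mathrm{Retention}(X) &= \left(Q K^{\top} \odot D\right) V \\
D_{nm} &=
\begin{cases}
\gamma^{\,n-m}, & n \ge m \\
0, & n < m
\end{cases}
\end{aligned}
```

Here the elementwise product with D plays the role that softmax plays in attention: it decides how much each past position counts.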
With gamma = 0.9 and 4 tokens, the decay mask (rows = current token, columns = attended token) looks like this — 1 on the diagonal, decaying down-left, zero above:
| | t0 | t1 | t2 | t3 |
|---|---|---|---|---|
| t0 | 1.00 | 0 | 0 | 0 |
| t1 | 0.90 | 1.00 | 0 | 0 |
| t2 | 0.81 | 0.90 | 1.00 | 0 |
| t3 | 0.73 | 0.81 | 0.90 | 1.00 |
Read row t3: it retains its own token fully (1.00), the previous one at 0.90, two back at 0.81, three back at 0.73 — each step further is multiplied by another 0.9. The zeros above the diagonal enforce causality. Multiply this mask elementwise into the query-key scores, then multiply by the values, and you have the retention output for every token — computed all at once, in parallel, exactly like attention. The cost is the same as attention (it builds the full matrix), but that's fine for training, where parallelism is what you want.
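The mask-then-multiply recipe can be sketched in plain Python as below. This is a minimal illustration on nested lists, not an efficient or official implementation; `decay_mask` and `parallel_retention` are hypothetical names, and scaling and normalization are omitted.

```python
def decay_mask(n_tokens, gamma):
    """Lower-triangular mask: gamma ** (n - m) for the past, 0 for the future."""
    return [
        [gamma ** (n - m) if m <= n else 0.0 for m in range(n_tokens)]
        for n in range(n_tokens)
    ]


def dot(a, b):
    """Dot product of two equal-length vectors."""
    return sum(x * y for x, y in zip(a, b))


def parallel_retention(Q, K, V, gamma):
    """Retention for every token at once: (Q K^T * mask) V.

    Q, K, V are lists of per-token vectors (one row per token).
    """
    n = len(Q)
    mask = decay_mask(n, gamma)
    outputs = []
    for i in range(n):
        # Query-key scores for row i, scaled elementwise by the decay mask.
        weights = [dot(Q[i], K[j]) * mask[i][j] for j in range(n)]
        # Weighted sum of the value vectors.
        outputs.append(
            [sum(weights[j] * V[j][c] for j in range(n)) for c in range(len(V[0]))]
        )
    return outputs


if __name__ == "__main__":
    for row in decay_mask(4, 0.9):
        print([round(x, 2) for x in row])
```

Printing `decay_mask(4, 0.9)` reproduces the table above, with row t3 reading 0.73, 0.81, 0.9, 1.0.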
The widget shows the decay mask as a grid — bright on the diagonal, fading down-left, black above (the future). Drag gamma and watch the triangle's fade change: low gamma makes a tight bright band near the diagonal (short memory); high gamma spreads brightness deep into the past (long memory). This single matrix, multiplied into the scores, is the entire parallel form.
Each cell (row n, col m) = gamma^(n-m) for the past, 0 for the future. Bright = strong retention. Drag gamma to reshape the fade. Multiply this into query-key scores — that's retention, in parallel.
Now the second form, and the one that wins the “cheap inference” corner of the triangle. The exact same retention computation can be rewritten as a recurrent state update — constant cost per token, constant memory, no growing cache. This is what you use at inference time, when generating one token at a time.
Here's the rewrite, and it falls right out of the per-distance decay from Chapter 1. Keep a running state — a small matrix. At each step: first multiply the state by gamma (this is the decay — everything already stored fades by one step's worth), then add the new token's contribution (its key times its value). To produce the output, multiply the current token's query against the state. That's it: decay, add, read. Two lines, executed once per token.
Start with an empty state. Token 1 arrives: state becomes k₁v₁ (nothing to decay yet). Token 2: first decay — state becomes gamma times k₁v₁ — then add k₂v₂, so state = gamma·k₁v₁ + k₂v₂. Token 3: decay again (everything ×gamma) and add k₃v₃, giving gamma²·k₁v₁ + gamma·k₂v₂ + k₃v₃.
Look at token 1's contribution in the final state: it has been multiplied by gamma twice (once at step 2, once at step 3), giving gamma-squared, which is exactly gamma-to-the-distance (it's 2 steps back). The recurrence reproduces the decay mask's weights exactly. And critically: the state is a fixed size, the work per token is constant, and you never look back at past tokens. You generate the millionth token at the same cost as the first.
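The decay, add, read loop can be sketched in a few lines of plain Python. This is an illustrative sketch under simplifying assumptions (no scaling, no normalization); `recurrent_retention` and `outer` are hypothetical names.

```python
def outer(k, v):
    """Outer product k^T v: a (len(k) x len(v)) matrix."""
    return [[ki * vj for vj in v] for ki in k]


def recurrent_retention(Q, K, V, gamma):
    """Process tokens one at a time with a fixed-size state.

    Returns the per-token outputs and the final state.
    """
    d_k, d_v = len(K[0]), len(V[0])
    state = [[0.0] * d_v for _ in range(d_k)]  # starts empty
    outputs = []
    for q, k, v in zip(Q, K, V):
        kv = outer(k, v)
        # Decay everything stored so far by gamma, then add the new token.
        state = [
            [gamma * state[i][j] + kv[i][j] for j in range(d_v)]
            for i in range(d_k)
        ]
        # Read: multiply the current query against the state.
        outputs.append(
            [sum(q[i] * state[i][j] for i in range(d_k)) for j in range(d_v)]
        )
    return outputs, state


if __name__ == "__main__":
    # Scalar keys 2, 3, 5 with values of 1 and gamma = 0.5:
    # final state = 0.25*2 + 0.5*3 + 5 = 7.0
    outs, final_state = recurrent_retention(
        [[1.0]] * 3, [[2.0], [3.0], [5.0]], [[1.0]] * 3, 0.5
    )
    print(outs, final_state)
```

With scalar keys 2, 3, 5 and gamma = 0.5, the final state is 0.25·2 + 0.5·3 + 5 = 7, matching the gamma²·k₁v₁ + gamma·k₂v₂ + k₃v₃ pattern worked out above.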
Step through tokens and watch the recurrent retention state. Each step, the whole state fades by gamma (watch existing entries dim) and the new token's contribution is added (a fresh bright entry). The state never grows. Adjust gamma to see fast vs slow forgetting in the running state — the same gamma that shaped the decay mask now controls how quickly the state forgets.
Step through tokens. Each step: the whole state ×gamma (everything fades), then add the new key×value. Output = query × state. Fixed size, constant cost per token — the cheap-inference corner.
This is the form that makes RetNet genuinely practical, and the one its cousins only gesture at. The parallel form is great for training but quadratic. The recurrent form is great for inference but sequential (slow on GPUs because each step waits for the last). For long sequences during training, neither is ideal — quadratic blows up, sequential wastes the GPU. The chunkwise form is the brilliant compromise: parallel inside chunks, recurrent across them.
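Split the long sequence into chunks of, say, 512 tokens each. Then:
- Within each chunk, compute retention in the parallel form, using the decay mask over just that chunk's tokens.
- Across chunks, carry a recurrent state forward: each chunk reads the decayed state that summarizes all earlier chunks, then updates it for the next chunk.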
So each chunk does two things: it computes the parallel retention within itself, and it adds in the contribution from the recurrent state passed from all previous chunks (decayed appropriately). The output is correct — identical to running the whole sequence in the parallel or recurrent form — but the cost is dramatically better: roughly linear in sequence length, while keeping most of the GPU parallelism. It's the third equivalent form, tuned for the long-sequence regime.
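The two contributions can be sketched in plain Python as follows. This is a minimal illustration, not the paper's implementation: the names are hypothetical, scaling and normalization are omitted, and the inner loop over positions is written sequentially for clarity even though each position in a chunk is independent and would be computed in parallel on a GPU. A direct-sum reference is included so the outputs can be compared.

```python
def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def direct_retention(Q, K, V, gamma):
    """Reference: sum over all past tokens with gamma ** distance weights."""
    d_v = len(V[0])
    return [
        [sum(dot(Q[n], K[m]) * gamma ** (n - m) * V[m][c] for m in range(n + 1))
         for c in range(d_v)]
        for n in range(len(Q))
    ]


def chunkwise_retention(Q, K, V, gamma, chunk_size):
    d_k, d_v = len(K[0]), len(V[0])
    state = [[0.0] * d_v for _ in range(d_k)]  # summary of earlier chunks
    outputs = []
    for start in range(0, len(Q), chunk_size):
        q_c = Q[start:start + chunk_size]
        k_c = K[start:start + chunk_size]
        v_c = V[start:start + chunk_size]
        size = len(q_c)
        for p in range(size):
            # Inner-chunk part: parallel-form retention within the chunk.
            inner = [
                sum(dot(q_c[p], k_c[j]) * gamma ** (p - j) * v_c[j][c]
                    for j in range(p + 1))
                for c in range(d_v)
            ]
            # Cross-chunk part: read the carried state, decayed by the
            # distance (p + 1) to the last token of the previous chunk.
            cross = [
                gamma ** (p + 1) * sum(q_c[p][i] * state[i][c] for i in range(d_k))
                for c in range(d_v)
            ]
            outputs.append([a + b for a, b in zip(inner, cross)])
        # Update the state: decay it across the whole chunk, then add each
        # token's key-value product decayed by its distance to the chunk end.
        state = [
            [gamma ** size * state[i][c]
             + sum(gamma ** (size - 1 - j) * k_c[j][i] * v_c[j][c]
                   for j in range(size))
             for c in range(d_v)]
            for i in range(d_k)
        ]
    return outputs
```

Setting `chunk_size` to 1 reduces this to the recurrent form, and setting it to the sequence length reduces it to the parallel form; any value in between should give the same outputs up to floating-point rounding.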
Without chunkwise, the “three forms” story would be incomplete: you'd train with an expensive quadratic form and only save at inference. Chunkwise is what lets RetNet train efficiently on long sequences — the regime where the quadratic transformer hurts most. This is the same trick the whole linear-attention family eventually adopted (the Linear Attention lesson mentioned it), but RetNet's clean decay structure makes the chunkwise form especially natural and exact. It's the practical heart of the architecture.
The widget shows a sequence split into chunks. Within each chunk, tokens are processed in parallel (a small bright block); between chunks, a state arrow carries the summary forward. Drag the chunk size: at maximum it's one big parallel block (pure parallel); at minimum it's a long chain of single tokens (pure recurrent). Watch the cost readout interpolate between quadratic and linear as you slide.
The sequence in chunks. Each chunk computes in parallel (block) and passes a recurrent state to the next (arrow). Drag chunk size between pure-parallel (one big chunk) and pure-recurrent (size 1).
A single decay gamma forces one memory horizon: either you forget fast (good for local patterns, blind to long range) or slow (good for long range, fuzzy on local detail). Real language needs both at once. RetNet's answer is multi-scale retention: like multi-head attention, it has many parallel heads — but each head uses a different gamma. Some heads forget quickly; some remember for a long time. Together they cover a spectrum of timescales.
The idea maps cleanly onto multi-head attention's structure, which is why it's a drop-in. In standard multi-head attention, each head learns to attend to different content. In multi-scale retention, each head additionally has a fixed, distinct decay rate, giving it a different time horizon. Head 1 might use gamma near 0.5 (sharp, recency-focused, captures the last few tokens); head 8 might use gamma near 0.99 (long memory, tracks information from hundreds of tokens back, since 0.99^100 ≈ 0.37). The decay rates are spread across the heads on a fixed schedule. The model can then route different kinds of dependencies to the head with the matching timescale.
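A sketch of one such fixed schedule follows. The formula gamma_h = 1 - 2^(-5 - h) is assumed here to be the schedule reported for RetNet; verify the constants against the paper before relying on them. The `effective_horizon` helper is a hypothetical name that uses the geometric-series sum 1 / (1 - gamma) as a rough measure of how many tokens a head keeps in view.

```python
def head_gammas(num_heads):
    """One fixed decay per head (assumed schedule: 1 - 2 ** (-5 - h))."""
    return [1.0 - 2.0 ** (-5 - h) for h in range(num_heads)]


def effective_horizon(gamma):
    """Sum of gamma ** k over all k >= 0, i.e. 1 / (1 - gamma)."""
    return 1.0 / (1.0 - gamma)


if __name__ == "__main__":
    for h, g in enumerate(head_gammas(8)):
        print(f"head {h}: gamma={g:.6f}, horizon~{effective_horizon(g):.0f} tokens")
```

Under this assumed schedule the horizons double from head to head, so a handful of heads spans short-range and long-range context without any learned decay parameters.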
On top of multi-scale, RetNet wraps the heads with two refinements borrowed from the modern transformer toolkit: a group normalization on each head's output (stabilizing the unnormalized, softmax-free retention values), and a swish gate — a learned gate that modulates the combined output, adding expressiveness much like a gated linear unit. So the full operation is “gated multi-scale multi-head retention,” but the core is just: several retention heads, each with its own decay clock, normalized and gated.
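Put together, the gated multi-scale operation can be written roughly as below. The projection matrices W_G (gate) and W_O (output) and the name MSR are notation assumed here for illustration; treat the exact placement of the normalization and gate as a sketch to check against the paper.

```latex
\begin{aligned}
\mathrm{head}_i &= \mathrm{Retention}(X, \gamma_i), \qquad i = 1, \ldots, h \\
Y &= \mathrm{GroupNorm}_h\!\big(\mathrm{Concat}(\mathrm{head}_1, \ldots, \mathrm{head}_h)\big) \\
\mathrm{MSR}(X) &= \big(\mathrm{swish}(X W_G) \odot Y\big)\, W_O
\end{aligned}
```

Each head runs retention with its own fixed decay, the group normalization rescales each head's output separately, and the swish gate modulates the result before the final projection.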
The widget overlays the decay curves of several heads, each with a different gamma. The fast-decay heads (steep curves) capture nearby tokens; the slow-decay heads (long tails) reach far back. Together they tile the timescales. Adjust the number of heads to see the spectrum get denser — more heads, more memory horizons covered.
Several retention heads, each a different gamma. Steep curves = short memory (local); long tails = long memory (global). Together they cover many timescales at once. Adjust the head count.
This is the payoff that defines RetNet. The same retention — the same weights, the same input — computed three ways: parallel, recurrent, chunkwise. The simulator runs all three on the same little sequence and shows you two things: that they produce the identical output (the equivalence that makes the whole idea work), and that each one has a different cost profile suited to a different corner of the impossible triangle.
Toggle between the forms and watch the output vector stay the same while the cost and parallelism change. This is the “one mechanism, three forms” idea made concrete: pick the form that fits the job, never retrain.
Toggle the form. The output (top) is identical across all three — that's the equivalence. The cost, memory, and parallelism (bars) differ — that's why each form wins a different corner of the triangle.
No quiz — the simulator is the test. If you can explain why the output is identical while the cost differs across the three forms, you understand RetNet's core contribution.
We have retention (the three forms) and multi-scale heads. How is the full model assembled? Exactly as you'd now expect: RetNet reuses the transformer's block structure, swapping the attention sublayer for a retention sublayer. If you know a transformer block, you already know a RetNet block — with one component replaced.
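A RetNet block has two sublayers, each wrapped in a residual connection and normalization (pre-norm, like modern transformers):
- A gated multi-scale retention sublayer, which takes the place of attention and mixes information across tokens.
- A feed-forward network sublayer, unchanged from the transformer, which processes each token independently.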
Stack many of these blocks, add token and positional handling, and you have a RetNet model — structurally a transformer with retention in place of attention. This is the same “swap the token mixer, keep the scaffolding” pattern shared by RWKV, Mamba, and xLSTM. The residual + norm backbone that makes deep transformers trainable carries over unchanged, so RetNet inherits all that hard-won training stability.
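The pre-norm, two-sublayer skeleton can be sketched in plain Python as below. The sublayers are passed in as placeholder callables, so `retention_sublayer` and `ffn_sublayer` are hypothetical stand-ins rather than real RetNet components, and the layer norm here has no learned scale or bias.

```python
import math


def layer_norm(vec, eps=1e-5):
    """Normalize one token vector to zero mean and unit variance."""
    mean = sum(vec) / len(vec)
    var = sum((v - mean) ** 2 for v in vec) / len(vec)
    return [(v - mean) / math.sqrt(var + eps) for v in vec]


def add(a, b):
    """Elementwise sum of two sequences of token vectors (the residual)."""
    return [[x + y for x, y in zip(ra, rb)] for ra, rb in zip(a, b)]


def retnet_block(tokens, retention_sublayer, ffn_sublayer):
    """One block: norm -> retention -> residual, then norm -> FFN -> residual.

    retention_sublayer maps a whole sequence to a sequence (token mixing);
    ffn_sublayer maps a single token vector to a vector.
    """
    normed = [layer_norm(t) for t in tokens]
    tokens = add(tokens, retention_sublayer(normed))
    normed = [layer_norm(t) for t in tokens]
    tokens = add(tokens, [ffn_sublayer(t) for t in normed])
    return tokens
```

Swapping `retention_sublayer` for an attention function would turn the same skeleton back into a transformer block, which is the sense in which only the token mixer changes.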
The widget shows a RetNet block. Notice the familiar transformer skeleton — two residual sublayers with normalization — with gated multi-scale retention where attention would be. Compare it mentally to a transformer block: everything is the same except the token-mixing sublayer.
Two residual sublayers (norm + residual), transformer-style: gated multi-scale retention replaces attention; the feed-forward network is unchanged. Click the retention sublayer to expand its internals.
RetNet shares the 2023–2024 recurrent-revival DNA with RWKV, Mamba, and xLSTM, all of them fixed-state, linear-cost, constant-memory-inference models. Having built RetNet in detail, you can now see exactly what it shares with its cousins and what sets it apart. The honest summary: same family, distinctive framing.
Like the whole family, RetNet drops softmax, carries a fixed-size recurrent state, and can train in parallel while inferring cheaply. Its retention state — accumulate key×value with a decay, query to retrieve — is the same matrix-valued associative memory as linear attention's KV state, the mLSTM's matrix memory, and (with different parameterization) Mamba's SSM state. The grand convergence from the xLSTM lesson includes RetNet: yet another lineage arriving at data-dependent... well, almost — which brings us to the difference.
| Model | Decay/forget | Named forms |
|---|---|---|
| RetNet | fixed per-head decay (multi-scale) | parallel, recurrent, chunkwise (explicit) |
| RWKV | learned per-channel decay (data-dep in v6) | parallel / recurrent |
| Mamba | input-dependent (selective) | parallel scan / recurrent |
| xLSTM | exponential gating (data-dep) | parallel (mLSTM) / recurrent |
| Transformer | none (full attention) | parallel only (quadratic) |
Is fixed decay a weakness? It's a tradeoff. Data-dependent forgetting (Mamba, RWKV-6) can adapt what to remember based on content, which helps on tasks needing precise, selective recall. RetNet's fixed multi-scale decay is less adaptive but simpler, more stable, and keeps the three-form equivalence exact and efficient. In practice, the family has largely moved toward data-dependent decay (it tends to win on quality), but RetNet's clean three-form framework — especially the chunkwise insight — influenced everyone. It's a landmark for how to think about these models, even where its specific fixed-decay choice was superseded.
Select a model to see its decay type and which computation forms it offers. RetNet stands out for explicitly providing all three named forms; the others share the parallel/recurrent duality but treat chunkwise as an implementation detail. On the decay axis, RetNet's “fixed” sits apart from the family's drift toward “data-dependent.”
Select a model. See its decay type (fixed vs data-dependent) and which computation forms it exposes. RetNet is the one that makes all three forms explicit.
You now understand RetNet completely: the impossible triangle it targets, retention as softmax-free attention with a per-distance decay, the parallel form for training, the recurrent form for inference, the chunkwise form for long sequences, multi-scale heads for many timescales, the transformer-style block, and where it sits in the family. The thread: one retention mechanism, three equivalent computation forms, each winning a different corner of the parallel/cheap/strong triangle.
“The same truth can be told three ways — choose the telling that fits the moment.”