Google DeepMind’s recipe for transformer-quality language models that run like RNNs — a gated linear recurrence for cheap global memory, local attention for sharp recent recall, and constant-size state at inference.
You have a trained transformer and you want it to read a long document — a whole book, a codebase, an hour of transcribed speech. Two costs blow up on you, and both come from the same place: attention.
First, compute. To produce one token, attention compares it against every previous token. For a sequence of length n, that is n comparisons for the last token, and the work to process the whole sequence scales like n². Double the context, quadruple the cost.
Second — and this is the one that actually kills you at deployment — memory. To avoid recomputing attention from scratch at every step, transformers cache the key and value vectors of every token they have seen. This KV cache grows linearly forever. At 100,000 tokens you are storing 100,000 keys and values per layer, per head. The cache, not the math, is what makes long-context generation expensive and slow.
Now picture the opposite extreme: a plain recurrent network (RNN). It reads one token, updates a single fixed-size hidden state, and moves on. To generate token number 100,000 it needs exactly the same memory as token number 3 — one state vector. Constant memory. Constant cost per token. Beautiful for inference.
So why did everyone abandon RNNs for transformers? Because classic RNNs are slow to train (each step waits for the previous one — no parallelism across the sequence) and they forget (gradients vanish over long ranges). Transformers train in parallel and remember everything. We traded inference efficiency for trainability.
Slide the sequence length. The transformer’s KV cache (orange) grows with every token; the recurrent state (teal) is flat — the same size whether you’ve read 10 tokens or 10,000.
Two names, one idea. Hawk is the pure-recurrence model: blocks built around a new recurrent unit called the RG-LRU. Griffin is the hybrid: mostly RG-LRU blocks, with a local attention layer mixed in every few blocks. By the end of this lesson you will understand every piece — the linear recurrence, why it can be trained in parallel, the gates that make it selective, and why a little local attention is the perfect complement.
We want RNN-style inference: one state, updated per token. The simplest possible update that still carries memory is a linear recurrence:

$$h_t = a\,h_{t-1} + b\,x_t$$
Read it out loud: the new state $h_t$ is a fraction a of the old state plus a fraction b of the new input $x_t$. The number a is the recurrence weight: how much of the past you keep. The number b is the input weight: how much of the present you let in. In the real models h and x are vectors and a, b act element-by-element (a diagonal recurrence), but a single number teaches the whole idea.
The word that matters is linear. There is no nonlinear function wrapped around the state: no tanh, no sigmoid squashing $h_{t-1}$ before it is reused. That one restriction is the secret to everything that follows. We will cash it in next chapter to train this thing in parallel. For now, let us see what it remembers.
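Substitute the recurrence into itself. Start from $h_0 = b\,x_0$:

$$h_t = b\,x_t + a\,b\,x_{t-1} + a^2 b\,x_{t-2} + \dots + a^t b\,x_0 = b\sum_{k=0}^{t} a^k x_{t-k}$$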
The state is a weighted sum of all past inputs, where an input from k steps ago is weighted by $a^k$. Because a is between 0 and 1, older inputs fade geometrically. This is an exponential moving average with decay rate a (up to the constant scale b). The value of a is a memory dial: near 1, the past lingers for hundreds of steps; near 0, the state is basically just the latest input.
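As a quick check of the unrolling, the sketch below compares the step-by-step loop with the closed-form weighted sum. The function names `run_loop` and `unrolled_state` and the sample inputs are illustrative choices, not part of the original lesson.

```python
def run_loop(x, a, b):
    """Apply h_t = a*h_{t-1} + b*x_t step by step, starting from h = 0."""
    h = 0.0
    states = []
    for xt in x:
        h = a * h + b * xt
        states.append(h)
    return states


def unrolled_state(x, a, b, t):
    """Closed form: h_t = b * sum over k of a**k * x[t-k], for k = 0..t."""
    return b * sum((a ** k) * x[t - k] for k in range(t + 1))


x = [1.0, 2.0, -1.0, 0.5, 3.0]  # arbitrary sample inputs
a, b = 0.8, 1.0
loop_states = run_loop(x, a, b)
closed_states = [unrolled_state(x, a, b, t) for t in range(len(x))]
for t, (h_loop, h_closed) in enumerate(zip(loop_states, closed_states)):
    print(t, round(h_loop, 6), round(h_closed, 6))
```

The two columns should agree at every step, which is just the statement that the recurrence and the geometric weighted sum describe the same state.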
Let a = 0.8, b = 1, and feed the inputs x = [5, 0, 0, 0] (one spike, then silence). Watch the state decay:
| t | $x_t$ | computation | $h_t$ |
|---|---|---|---|
| 0 | 5 | 1·5 | 5.00 |
| 1 | 0 | 0.8·5.00 + 1·0 | 4.00 |
| 2 | 0 | 0.8·4.00 | 3.20 |
| 3 | 0 | 0.8·3.20 | 2.56 |
Each step keeps 80% of the previous value: 5 → 4 → 3.2 → 2.56, exactly $5 \cdot 0.8^t$. The single input from t=0 is still echoing three steps later. That echo is the memory. Make a bigger and the echo lasts longer; make it smaller and it dies faster.
We feed a single spike at the start, then silence. The bars show the state decaying as $a^t$. Drag a: near 1 the memory rings for a long time; near 0 it vanishes immediately.
python: linear recurrence, the recurrent (inference) form

    def linear_recurrence(x, a, b):
        """Run h_t = a*h_{t-1} + b*x_t over x (a list of inputs)."""
        h = 0.0
        out = []
        for xt in x:
            h = a * h + b * xt  # one multiply-add, O(1) memory
            out.append(h)
        return out

    # linear_recurrence([5, 0, 0, 0], a=0.8, b=1.0)
    # -> [5.0, 4.0, 3.2, 2.56] (up to floating-point rounding)
This loop is exactly how Hawk generates text: constant state h, one cheap step per token. But a for loop over the sequence is death for training on a GPU: every step waits for the one before it. Next chapter, the linearity pays off and we delete the loop.
Here is the problem. The recurrence loop is sequential: $h_5$ needs $h_4$ needs $h_3$… On a GPU with thousands of cores, that loop uses one of them and the rest sit idle. Transformers don’t have this problem: every token’s attention is computed at once. To compete, the recurrence must also be computed all at once. Linearity is what lets us.
Think of each step as carrying a little package (a, b·x): a decay factor and an additive contribution. Applying step t after step s means: first scale the running state by $a_s$ and add $b_s x_s$, then scale by $a_t$ and add $b_t x_t$. Two such steps compose into one equivalent step:

$$(a_s,\, c_s) \bullet (a_t,\, c_t) = (a_s a_t,\; a_t c_s + c_t)$$
where c = b·x. The key fact: this combine operator is associative. Grouping doesn’t matter — (A•B)•C equals A•(B•C). And anything associative can be computed by a parallel scan (also called a prefix-sum): a tournament-style tree that produces all the running states in log n parallel rounds instead of n sequential steps.
So the same math has two faces. Training: lay out the whole sequence, run the parallel scan, get all ht at once — GPU-friendly, like attention. Inference: throw away the scan and just do the one-step loop — constant memory, like an RNN. One computation, two execution modes. (If that rings a bell, it is the same duality RetNet and Mamba exploit.)
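The two execution modes can be compared directly. The sketch below is a minimal illustration, assuming a simple Hillis-Steele style scan in NumPy; the names `combine`, `sequential_states`, and `scan_states` are made up for this example, and this is not the kernel DeepMind uses.

```python
import numpy as np


def combine(first, second):
    """Compose two steps (a, c): apply `first`, then `second`.

    h -> a1*h + c1 -> a2*(a1*h + c1) + c2 = (a1*a2)*h + (a2*c1 + c2)
    """
    a1, c1 = first
    a2, c2 = second
    return a1 * a2, a2 * c1 + c2


def sequential_states(a, c):
    """Reference loop: h_t = a_t*h_{t-1} + c_t, starting from h = 0."""
    h = 0.0
    out = []
    for at, ct in zip(a, c):
        h = at * h + ct
        out.append(h)
    return np.array(out)


def scan_states(a, c):
    """Inclusive scan with `combine`, in about log2(n) rounds.

    In each round every position is combined with the partial result
    `shift` places to its left. Positions within a round are independent,
    so parallel hardware could process them all at once.
    """
    a = np.asarray(a, dtype=float).copy()
    c = np.asarray(c, dtype=float).copy()
    shift, rounds = 1, 0
    while shift < len(a):
        new_a, new_c = combine((a[:-shift], c[:-shift]), (a[shift:], c[shift:]))
        a[shift:], c[shift:] = new_a, new_c
        shift *= 2
        rounds += 1
    # With a zero initial state, the accumulated additive part is h_t.
    return c, rounds


rng = np.random.default_rng(0)
a = rng.uniform(0.5, 0.99, size=8)   # per-step decays
c = rng.normal(size=8)               # per-step contributions b*x
states, rounds = scan_states(a, c)
print(np.allclose(states, sequential_states(a, c)), rounds)  # prints: True 3
```

This variant does more total arithmetic than the loop (roughly n·log n combines instead of n) in exchange for fewer sequential rounds; work-efficient scans exist but are longer to write down.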
If you wrapped a tanh around $h_{t-1}$, the combine step would no longer be a clean “scale-and-add,” the operator would lose associativity, and the parallel scan would be impossible. You would be stuck with the sequential loop, back to a slow-training RNN. The nonlinearities in Griffin live outside the recurrence (in gates and MLPs), never inside it.
Press Step to advance. The sequential loop (top) does one token per tick, n ticks in total. The parallel scan (bottom) combines pairs in a tree and finishes in about $\log_2 n$ ticks. Watch the tree light up far faster.
For an 8-token sequence the loop needs 8 ticks; the scan finishes in 3 (because 2³ = 8). For 2048 tokens: 2048 versus 11. That is why a linear recurrence can train at transformer-like speed. DeepMind's implementation uses a custom TPU kernel for the recurrence, because the memory-movement pattern, not the arithmetic, is the bottleneck; the tree is the idea that makes parallel evaluation possible in the first place.
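The recurrence multiplies the state by a every step. Unroll it and a contribution from k steps back is scaled by $a^k$. That single fact decides whether the model lives or dies: if a is below 1 the contribution fades and the state stays bounded; at exactly 1 it never fades (a knife-edge); above 1 it grows without bound and the state explodes.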
We must guarantee the recurrence weight stays in the safe zone, not hope the optimizer keeps it there. The trick is to never learn a directly. Instead learn an unconstrained number Λ and pass it through a squashing function:

$$a = \sigma(\Lambda) = \frac{1}{1 + e^{-\Lambda}}$$
Now no matter what value the optimizer pushes Λ to — minus a million or plus a million — a is mathematically pinned inside (0, 1). The model is stable by construction. (Griffin actually parameterizes it in log-space for numerical comfort, but the guarantee is this sigmoid.)
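A small numerical sketch of that guarantee follows. The helper `sigmoid` and the sample values of Λ are illustrative choices. One caveat worth stating: the open interval (0, 1) is a statement about exact arithmetic, and in floating point a very large |Λ| rounds a to exactly 0.0 or 1.0.

```python
import math


def sigmoid(z):
    """Numerically stable logistic function (avoids overflow in exp)."""
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    ez = math.exp(z)
    return ez / (1.0 + ez)


for lam in (-10.0, -1.0, 0.0, 1.0, 5.0, 10.0):
    a = sigmoid(lam)
    # a**100 is the weight left on an input after 100 steps.
    print(f"Lambda={lam:6.1f}  a={a:.6f}  a**100={a ** 100:.3e}")
```

Whatever Λ is chosen, the weight after 100 steps never exceeds 1, so the state cannot blow up through the recurrence alone.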
Earlier linear-recurrence models (S4, the LRU) used complex numbers for a, so the state could rotate as well as decay — useful for modeling oscillations. DeepMind found that for language a plain real diagonal recurrence works just as well and is simpler and faster: no complex arithmetic, just a per-channel decay. “Diagonal” means each channel of the state has its own independent a — no mixing between channels inside the recurrence (mixing happens in the linear projections around it). That keeps the scan a cheap element-wise operation.
State magnitude over time after a unit input. Below 1: decays (stable, teal). At 1: holds (knife-edge, yellow). Above 1: explodes (red — note the axis rescales). The sigmoid parameterization makes the red zone unreachable.
A fixed decay a treats every token the same. But language isn’t uniform: the name of the protagonist on page 1 should persist; the word “the” should not. We want the model to decide, per token, how much to remember and how much to let in. That decision-making is the RG-LRU — the Real-Gated Linear Recurrent Unit, the heart of Hawk and Griffin.
Both gates are simple: take the current input $x_t$, run a linear layer, squash with a sigmoid (so each is a number in (0,1)):

$$r_t = \sigma(W_a x_t), \qquad i_t = \sigma(W_x x_t)$$
The input gate $i_t$ filters the incoming token, like deciding how loudly this word gets to speak into the state. The recurrence gate $r_t$ is cleverer: it modulates the decay itself. The effective decay becomes

$$a_t = a^{\,c\,r_t}$$

where c is a fixed positive constant.
Stare at this. When the gate $r_t \to 0$, the exponent → 0, so $a_t \to a^0 = 1$: the state is preserved almost perfectly (“this token says: hold everything, don’t decay”). When $r_t \to 1$, the exponent is at its largest, so $a_t$ drops toward $a^8$ (with the constant c = 8, a strong decay): “flush the old state, the topic just changed.” The gate slides the decay between “remember everything” and “forget aggressively,” per token, per channel.
The full state update is

$$h_t = a_t\,h_{t-1} + \sqrt{1 - a_t^2}\,\,(i_t\,x_t)$$

The first term is the gated memory. The second term is the gated input, but notice the input weight is $\sqrt{1 - a_t^2}$, not a free parameter. This is a variance-preserving normalizer. If you keep a fraction $a_t$ of the old state, this exact input weight ensures the state’s variance stays roughly constant instead of blowing up or shrinking as gates change. When $a_t$ is near 1 (holding memory), $\sqrt{1 - a_t^2}$ is near 0 (let almost nothing new in); when $a_t$ is near 0 (forgetting), it is near 1 (let the new token dominate). The two terms trade off automatically.
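The variance claim can be checked empirically. The sketch below is a simulation under simplifying assumptions that are not from the source: the decay is held fixed at one value, and the gated input is drawn as unit-variance noise independent of the state. The function name `state_variance` is illustrative.

```python
import numpy as np

rng = np.random.default_rng(0)
n_channels, n_steps = 10_000, 200


def state_variance(a_t, normalized):
    """Variance of the state across channels after n_steps updates."""
    h = rng.normal(size=n_channels)  # start at unit variance
    in_weight = np.sqrt(1.0 - a_t ** 2) if normalized else 1.0
    for _ in range(n_steps):
        x = rng.normal(size=n_channels)  # stands in for the gated input i_t * x_t
        h = a_t * h + in_weight * x
    return h.var()


for a_t in (0.5, 0.9, 0.99):
    with_norm = state_variance(a_t, normalized=True)
    without = state_variance(a_t, normalized=False)
    print(f"a_t={a_t}: variance with sqrt(1-a^2) = {with_norm:.2f}, with weight 1 = {without:.2f}")
```

With the normalizer, the variance satisfies a²·1 + (1 − a²)·1 = 1 at every step, so it stays near 1 for any decay. With a plain input weight of 1 it drifts toward 1/(1 − a²), which grows quickly as the decay approaches 1.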
Base decay a = 0.9, constant c = 8, current state $h_{t-1}$ = 2.0, input $x_t$ = 1.0. Compare a “keep” token vs. a “flush” token.
| case | $r_t$ | $a_t = 0.9^{8 r_t}$ | $\sqrt{1 - a_t^2}$ | $i_t$ | $h_t$ |
|---|---|---|---|---|---|
| keep memory | 0.1 | $0.9^{0.8} = 0.919$ | 0.394 | 0.5 | 0.919·2 + 0.394·0.5·1 = 2.04 |
| flush memory | 0.9 | $0.9^{7.2} = 0.468$ | 0.884 | 0.5 | 0.468·2 + 0.884·0.5·1 = 1.38 |
Same input, same base decay — but the “keep” token barely moves the state (2.0 → 2.04, the past dominates), while the “flush” token pulls it down hard (2.0 → 1.38, the old memory is being discarded). The model learned when to do which, from data. That selectivity — the same idea Mamba calls “input-dependent dynamics” — is what lets a fixed-size state punch above its weight.
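The worked example can be reproduced in a few lines. This is a scalar sketch with the gate values supplied by hand rather than computed from learned weights; the function name `rg_lru_scalar_step` is illustrative.

```python
import math


def rg_lru_scalar_step(h_prev, x, a, r, i, c=8.0):
    """One scalar RG-LRU update with the gate values r and i given directly."""
    a_t = a ** (c * r)                      # gated decay
    in_weight = math.sqrt(1.0 - a_t ** 2)   # variance-preserving input weight
    h = a_t * h_prev + in_weight * (i * x)
    return a_t, in_weight, h


results = {}
for label, r in (("keep", 0.1), ("flush", 0.9)):
    a_t, w, h = rg_lru_scalar_step(h_prev=2.0, x=1.0, a=0.9, r=r, i=0.5)
    results[label] = (a_t, w, h)
    print(f"{label:5s} r={r}  a_t={a_t:.3f}  input weight={w:.3f}  h={h:.2f}")
```

The final states round to 2.04 for the “keep” token and 1.38 for the “flush” token, matching the last column of the table.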
A sequence of input spikes flows in. Set the base decay and the recurrence gate. Low gate = the state integrates and holds (long memory); high gate = each new token wipes the slate. Watch the teal state line respond.
python: one RG-LRU step (per-channel, vectorized with numpy)

    import numpy as np

    def sigmoid(z):
        return 1.0 / (1.0 + np.exp(-z))

    def rg_lru_step(h, x, Lam, Wa, Wx, c=8.0):
        a = sigmoid(Lam)       # base decay per channel, in (0,1)
        r = sigmoid(Wa @ x)    # recurrence gate, in (0,1)
        i = sigmoid(Wx @ x)    # input gate, in (0,1)
        a_t = a ** (c * r)     # data-dependent decay
        h = a_t * h + np.sqrt(1 - a_t**2) * (i * x)
        return h               # new state, same shape as h
The RG-LRU is the engine; the recurrent block is the car built around it. Hawk and Griffin stack many copies of one residual block. Let us trace a real tensor through it so you could implement it from memory. Say the model width is D = 1024 and we process a sequence of T = 2048 tokens, so the input to the block is shaped [T, D] = [2048, 1024].
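Like a transformer, the block is two residual sub-layers: a temporal-mixing sub-layer (this is where the recurrence lives) and a channel-mixing MLP. Each is preceded by RMSNorm and added back to the residual stream:

$$y = x + \mathrm{TemporalMix}(\mathrm{RMSNorm}(x)), \qquad \mathrm{out} = y + \mathrm{MLP}(\mathrm{RMSNorm}(y))$$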
This is where Griffin’s design echoes Mamba: the normed input splits into two branches that are multiplied back together at the end — a gating structure.
The recurrent branch is a linear projection, a short causal convolution over time, and then the RG-LRU. The gate branch is a linear projection followed by a GeLU. This branch carries no recurrence; it is a learned, element-wise multiplicative mask applied at the end.

Output shape: [2048, 1024], same as the input, so blocks stack cleanly. Note where the nonlinearities are: the GeLU, the conv, the MLP all sit outside the recurrence. The RG-LRU itself stays strictly linear in its state, which is exactly what keeps the parallel scan from Chapter 2 valid. That placement is not an accident: it is the whole design constraint.
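To make the shapes concrete, here is a toy NumPy sketch of the temporal-mixing sub-layer only (the MLP sub-layer is left out). Several things are assumptions made for illustration and are not taken from this lesson: the parameter names in `p`, the recurrent width `R`, a convolution width of 4, random untrained weights, RMSNorm without a learned scale, and the tanh approximation of GeLU.

```python
import numpy as np


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


def gelu(z):
    # tanh approximation of GeLU
    return 0.5 * z * (1.0 + np.tanh(np.sqrt(2.0 / np.pi) * (z + 0.044715 * z**3)))


def rms_norm(x, eps=1e-6):
    # RMSNorm without a learned scale (a simplification)
    return x / np.sqrt(np.mean(x**2, axis=-1, keepdims=True) + eps)


def causal_conv1d(x, kernel):
    """Per-channel causal convolution over time. x: [T, R], kernel: [K, R]."""
    K, T = kernel.shape[0], x.shape[0]
    padded = np.concatenate([np.zeros((K - 1, x.shape[1])), x], axis=0)
    # Output at time t mixes inputs t-K+1 .. t only, never the future.
    return sum(kernel[k] * padded[k:k + T] for k in range(K))


def rg_lru(x, Lam, Wa, Wx, c=8.0):
    """Sequential RG-LRU over a sequence. x: [T, R] -> [T, R]."""
    a = sigmoid(Lam)
    h = np.zeros(x.shape[1])
    out = np.empty_like(x)
    for t in range(x.shape[0]):
        r = sigmoid(Wa @ x[t])      # recurrence gate
        i = sigmoid(Wx @ x[t])      # input gate
        a_t = a ** (c * r)
        h = a_t * h + np.sqrt(1.0 - a_t**2) * (i * x[t])
        out[t] = h
    return out


def temporal_mixing(x, p):
    """Residual temporal-mixing sub-layer. x: [T, D] -> [T, D]."""
    u = rms_norm(x)                                   # [T, D]
    rec = causal_conv1d(u @ p["W_rec"], p["conv"])    # [T, R]
    rec = rg_lru(rec, p["Lam"], p["Wa"], p["Wx"])     # [T, R]
    gate = gelu(u @ p["W_gate"])                      # [T, R], no recurrence
    return x + (rec * gate) @ p["W_out"]              # back to [T, D]


rng = np.random.default_rng(0)
T, D, R = 16, 8, 12  # toy sizes; the text's example uses T=2048, D=1024
p = {
    "W_rec": rng.normal(size=(D, R)) / np.sqrt(D),
    "W_gate": rng.normal(size=(D, R)) / np.sqrt(D),
    "W_out": rng.normal(size=(R, D)) / np.sqrt(R),
    "conv": rng.normal(size=(4, R)) / 2.0,
    "Lam": rng.normal(size=R),
    "Wa": rng.normal(size=(R, R)) / np.sqrt(R),
    "Wx": rng.normal(size=(R, R)) / np.sqrt(R),
}
x = rng.normal(size=(T, D))
y = temporal_mixing(x, p)
print(x.shape, "->", y.shape)
```

The output has the same shape as the input, and every nonlinearity (sigmoid gates, GeLU, normalization) acts on inputs rather than on the carried state, so the state update itself stays linear.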
The RG-LRU gives us cheap, smooth, long-range memory — but it is lossy. A fixed-size state can only summarize the past; it cannot reach back and read an exact earlier token verbatim. For some jobs you need exactly that: “copy the variable name I defined two lines ago,” “match this closing bracket to that opening one.” That is attention’s superpower — precise, content-based recall. Griffin keeps a little of it.
The expensive part of full attention is that every token looks at all previous tokens. Local (sliding-window) attention restricts each token to the most recent W tokens — a fixed window, say W = 1024. Token at position t attends only to positions t−W+1 … t.
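The window rule above translates directly into an attention mask. This is a minimal NumPy sketch; the function name `sliding_window_mask` and the toy sizes are illustrative.

```python
import numpy as np


def sliding_window_mask(n, window):
    """mask[t, s] is True when query position t may attend to key position s.

    Allowed keys are s = t-window+1 .. t (causal, at most `window` wide).
    """
    t = np.arange(n)[:, None]
    s = np.arange(n)[None, :]
    return (s <= t) & (s > t - window)


mask = sliding_window_mask(n=8, window=3)
print(mask.astype(int))
print(mask.sum(axis=1))  # keys visible per query: [1 2 3 3 3 3 3 3]
```

Each row never has more than `window` ones, which is why the per-token cost stops growing once the sequence is longer than the window.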
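Two consequences, both wonderful: the attention work per token is bounded by W instead of growing with the sequence, and the KV cache is capped at W entries per layer instead of growing without bound.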
Doesn’t a fixed window throw away long-range context? On its own, yes. A pure local-attention model literally cannot see past W tokens. This is exactly why Griffin is a hybrid. The recurrence branch carries information from arbitrarily far back (compressed into the state), while local attention handles precise recent lookups. Neither alone is enough; together they cover both regimes: long-range summary from the RG-LRU, short-range detail from local attention. The window doesn’t need to be long because the recurrence already owns the long range.
Each row is a token attending to earlier tokens. Full causal attention (left of the toggle) fills the whole triangle — the cache grows with the sequence. Sliding-window attention keeps only a band of width W — constant cache. Drag W and the query position.
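At inference time the band corresponds to a cache that evicts its oldest entry once it holds W items. The sketch below is one way to express that, assuming a single attention head with no positional encoding; the class name `RollingKVCache` is hypothetical.

```python
from collections import deque

import numpy as np


class RollingKVCache:
    """Keeps only the most recent `window` keys and values."""

    def __init__(self, window):
        # deque(maxlen=...) drops the oldest item automatically when full.
        self.keys = deque(maxlen=window)
        self.values = deque(maxlen=window)

    def append(self, k, v):
        self.keys.append(k)
        self.values.append(v)

    def attend(self, q):
        """Softmax attention of query q over the cached window."""
        K = np.stack(self.keys)      # [<=window, d]
        V = np.stack(self.values)    # [<=window, d]
        scores = K @ q / np.sqrt(q.shape[0])
        weights = np.exp(scores - scores.max())
        weights /= weights.sum()
        return weights @ V


cache = RollingKVCache(window=4)
rng = np.random.default_rng(0)
for _ in range(1000):
    cache.append(rng.normal(size=8), rng.normal(size=8))
out = cache.attend(rng.normal(size=8))
print(len(cache.keys), out.shape)  # prints: 4 (8,)
```

After 1000 tokens the cache still holds 4 entries, so memory and per-token attention work are both independent of how long the sequence has run.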
Now assemble it. Griffin interleaves the two block types we just built: mostly recurrent blocks (RG-LRU), with a local-attention block mixed in periodically — the paper’s recipe is roughly two recurrent blocks for every one local-attention block. Hawk is the special case of all recurrent blocks, no attention at all.
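The interleaving can be written as a small layer-pattern generator. This is a sketch under an assumption: the lesson gives only the rough ratio, so placing the attention block last in each group of three is an illustrative choice, and both function names are hypothetical.

```python
def griffin_layer_pattern(depth, recurrent_per_attention=2):
    """Layer types for a hybrid stack: N recurrent blocks, then one attention block."""
    group = recurrent_per_attention + 1
    return [
        "local_attention" if (layer + 1) % group == 0 else "recurrent"
        for layer in range(depth)
    ]


def hawk_layer_pattern(depth):
    """Hawk: recurrent blocks only."""
    return ["recurrent"] * depth


print(griffin_layer_pattern(6))
# ['recurrent', 'recurrent', 'local_attention', 'recurrent', 'recurrent', 'local_attention']
print(hawk_layer_pattern(3))
# ['recurrent', 'recurrent', 'recurrent']
```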
The simulator below is a small Griffin stack processing a sequence. A piece of information enters at one token and we watch how far it can travel through the network. In a recurrent layer it propagates through the state — reachable from any later position (long range, but compressed). In a local-attention layer it can only jump within the window (short range, but exact). Stacking them lets information route through whichever mechanism fits. Below the stack, a live tally compares the inference memory of this hybrid against a pure transformer of the same depth as the sequence grows.
Each row is a layer (teal = recurrent / RG-LRU, blue = local attention). Set the layer pattern, the attention window, and the sequence length. Click a token in the bottom row to inject information and press Propagate to watch which later tokens it can reach. The memory readout shows why the hybrid stays cheap.
Play with the ratio. Set it to 0:1 (all attention, window-limited) and information can only crawl forward one window per layer — and the memory readout climbs as the sequence grows. Set it to 4:0 (pure Hawk) and every later token is reachable through the recurrence, with flat memory — but there is no exact-copy path. The sweet spot in the middle gives you both: a token far in the past stays reachable via the recurrence, while the attention layers provide crisp local lookups, and inference memory stays nearly flat as the sequence grows.
Architecture choices only matter if they change the bill. Let us make the inference advantage concrete with the three cost axes that decide whether you can serve a model cheaply.
A transformer must keep the KV cache of every past token. For a sequence of length n, with L layers, the cache scales like n · L — unbounded. A Griffin recurrent layer keeps one fixed-size state regardless of n; a Griffin local-attention layer keeps a cache capped at the window W. So Griffin’s state is roughly (recurrent layers × D) + (attention layers × W) — a constant that does not grow with the sequence. At n = 100k this is the difference between gigabytes and megabytes.
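A back-of-the-envelope calculation makes the gap visible. Every size below is a hypothetical configuration chosen for illustration (24 layers, width 1024, a 16/8 recurrent-to-attention split, window 1024, 2 bytes per element), and the formulas ignore details such as multi-query attention and the small convolution state.

```python
def transformer_cache_elems(n_tokens, n_layers, d_model):
    """Keys plus values for every token in every layer."""
    return 2 * n_tokens * n_layers * d_model


def griffin_state_elems(n_tokens, n_recurrent, n_attention, d_state, d_model, window):
    """Fixed recurrent states plus window-capped KV caches."""
    recurrent = n_recurrent * d_state
    local = 2 * min(n_tokens, window) * n_attention * d_model
    return recurrent + local


BYTES_PER_ELEM = 2  # e.g. 16-bit floats (assumption)
for n in (1_000, 10_000, 100_000):
    t = transformer_cache_elems(n, n_layers=24, d_model=1024) * BYTES_PER_ELEM
    g = griffin_state_elems(n, n_recurrent=16, n_attention=8,
                            d_state=1024, d_model=1024, window=1024) * BYTES_PER_ELEM
    print(f"n={n:>7}: transformer {t / 1e6:9.1f} MB   hybrid {g / 1e6:6.1f} MB")
```

Under these assumed sizes the transformer cache at 100,000 tokens is on the order of 10 GB per sequence, while the hybrid stays in the tens of megabytes and stops growing once the sequence exceeds the window.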
Lower memory traffic means each generated token is produced faster, and crucially the per-token cost stays flat as the sequence grows, instead of creeping upward as the KV cache balloons. Long-context generation is where Griffin pulls clearly ahead of a same-size transformer.
Because the state is small, you can fit many more concurrent sequences in memory and batch them — so total tokens-per-second served goes up. For a serving system, throughput-per-dollar is often the number that matters, and constant-size state is what unlocks it.
Transformers tie meaning to absolute or relative positions via positional encodings; push them past their training length and those encodings hit values never seen, and quality falls off a cliff. Griffin’s recurrence has no length-dependent positional scheme — it just keeps applying the same per-step update — and its attention is local, so it never sees a relative distance bigger than W. Neither component cares how long the sequence is. That structural fact is why Griffin generalizes to far longer inputs than it trained on.
Relative inference-state size as the sequence grows. Transformer (orange): linear, unbounded. Hawk (teal): flat — pure recurrence. Griffin (purple): nearly flat — recurrence plus a capped local-attention cache. Drag the length and watch the gap open.
Everything in one place. If you can reconstruct this table and the equations beneath it, you understand Griffin and Hawk well enough to teach them.
| Model | Long-range mechanism | Recall mechanism | Inference state |
|---|---|---|---|
| Transformer | full attention | full attention | KV cache — grows with n |
| Hawk | RG-LRU recurrence | (none exact) | fixed state — constant |
| Griffin | RG-LRU recurrence | local attention | state + capped window — constant |
| Mamba | selective SSM | (none exact) | fixed state — constant |
| RetNet | retention (decay) | chunk attention | fixed state — constant |
| RWKV | WKV recurrence | token-shift mix | fixed state — constant |
The shared idea across this whole row of “efficient” models: replace the growing KV cache with a fixed-size recurrent state, trained in parallel via a linear-recurrence / scan trick. Griffin’s distinctive bets are (1) a real, gated diagonal recurrence (RG-LRU) and (2) explicitly keeping a slice of local attention for exact recall, rather than trying to make the recurrence do everything.
→ Linear Attention & RWKV — the recurrence=attention duality in depth
→ SSM & Mamba — the selective state-space cousin of RG-LRU
→ RetNet — one computation, three execution forms
→ xLSTM — another modernized recurrence with exponential gating
→ Attention Variants — MHA, MQA, sliding-window, FlashAttention
→ Jamba — hybrid Mamba + attention + MoE at production scale