
Griffin & Hawk

Google DeepMind’s recipe for transformer-quality language models that run like RNNs — a gated linear recurrence for cheap global memory, local attention for sharp recent recall, and constant-size state at inference.

Prerequisites: a recurrent state carries information forward token by token, and attention lets a token look back at earlier tokens. That’s it.

Chapter 0: The Inference Wall

You have a trained transformer and you want it to read a long document — a whole book, a codebase, an hour of transcribed speech. Two costs blow up on you, and both come from the same place: attention.

First, compute. To produce one token, attention compares it against every previous token. For a sequence of length n, that is n comparisons for the last token, and the work to process the whole sequence scales like n². Double the context, quadruple the cost.

Second, and this is the one that actually kills you at deployment: memory. To avoid recomputing attention from scratch at every step, transformers cache the key and value vectors of every token they have seen. This KV cache grows linearly with every token and is kept for as long as the sequence is being generated. At 100,000 tokens you are storing 100,000 keys and values per layer, per head. The cache, not the math, is what makes long-context generation expensive and slow.
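To put numbers on both costs, here is a small sketch. It counts query-key comparisons for causal attention and the size of a KV cache. The configuration (32 layers, 32 heads of width 128, 2 bytes per stored number as in bf16) is a hypothetical example chosen for illustration, not a specific model's published setup.

```python
def causal_attention_comparisons(n: int) -> int:
    """Query-key dot products for causal attention over n tokens.

    Token number t (1-indexed) compares against itself and the t - 1 tokens
    before it, so the total is 1 + 2 + ... + n = n(n + 1) / 2.
    """
    return n * (n + 1) // 2


def kv_cache_bytes(n: int, layers: int, kv_heads: int, head_dim: int,
                   bytes_per_number: int = 2) -> int:
    """Bytes needed to cache one key and one value vector per token, layer, and head."""
    return 2 * n * layers * kv_heads * head_dim * bytes_per_number


def recurrent_state_bytes(layers: int, state_width: int,
                          bytes_per_number: int = 2) -> int:
    """Bytes for one fixed-size state vector per layer, independent of sequence length."""
    return layers * state_width * bytes_per_number


# Hypothetical configuration: 32 layers, 32 heads of width 128 (model width 4096).
cfg = dict(layers=32, kv_heads=32, head_dim=128)

ratio = causal_attention_comparisons(2_000) / causal_attention_comparisons(1_000)
print(f"doubling the context multiplies attention work by {ratio:.3f}")   # close to 4
print(f"KV cache at 100,000 tokens: {kv_cache_bytes(100_000, **cfg) / 2**30:.1f} GiB")
print(f"recurrent state, any length: {recurrent_state_bytes(32, 4096) / 2**10:.0f} KiB")
```

The KV cache doubles every time the sequence doubles, while the recurrent state has no `n` in its formula at all.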

Now picture the opposite extreme: a plain recurrent network (RNN). It reads one token, updates a single fixed-size hidden state, and moves on. To generate token number 100,000 it needs exactly the same memory as token number 3 — one state vector. Constant memory. Constant cost per token. Beautiful for inference.

So why did everyone abandon RNNs for transformers? Because classic RNNs are slow to train (each step waits for the previous one — no parallelism across the sequence) and they forget (gradients vanish over long ranges). Transformers train in parallel and remember everything. We traded inference efficiency for trainability.

The trap we’re escaping: “Attention is the only thing that works, so we must pay its inference tax.” Griffin’s claim, from the 2024 DeepMind paper, is that a carefully gated linear recurrence (trained in parallel like attention, run recurrently like an RNN), combined with a thin slice of local attention, can match a strong transformer (Llama-2 scale) while training on fewer tokens and generating far faster on long sequences.

[Figure: Memory at inference, KV cache vs. recurrent state. The transformer’s KV cache grows with every token; the recurrent state is flat, the same size whether you’ve read 10 tokens or 10,000.]

Two names, one idea. Hawk is the pure-recurrence model: blocks built around a new recurrent unit called the RG-LRU. Griffin is the hybrid: mostly RG-LRU blocks, with a local attention layer mixed in every few blocks. By the end of this lesson you will understand every piece — the linear recurrence, why it can be trained in parallel, the gates that make it selective, and why a little local attention is the perfect complement.

At inference time, what mainly makes a transformer expensive on a very long sequence?

Chapter 1: The Linear Recurrence

We want RNN-style inference: one state, updated per token. The simplest possible update that still carries memory is a linear recurrence:

$$h_t = a\,h_{t-1} + b\,x_t$$

Read it out loud: the new state $h_t$ is a fraction a of the old state plus a fraction b of the new input $x_t$. The number a is the recurrence weight: how much of the past you keep. The number b is the input weight: how much of the present you let in. In the real models h and x are vectors and a, b act element-by-element (a diagonal recurrence), but a single number teaches the whole idea.

The word that matters is linear. There is no nonlinear function wrapped around the state: no tanh, no sigmoid squashing $h_{t-1}$ before it is reused. That one restriction is the secret to everything that follows. We will cash it in during the next chapter to train this thing in parallel. For now, let us see what it remembers.

Unrolling: what does the state actually hold?

Substitute the recurrence into itself. Start from $h_0 = b\,x_0$ (the state before the first token is zero) and keep expanding:

$$h_t = b\,x_t + a\,b\,x_{t-1} + a^2 b\,x_{t-2} + a^3 b\,x_{t-3} + \cdots = \sum_{k=0}^{t} a^k\, b\, x_{t-k}$$

The state is a weighted sum of all past inputs, where an input from k steps ago is weighted by $a^k$ relative to the newest input. As long as a is between 0 and 1, older inputs fade geometrically. This is an exponential moving average with decay rate a (a normalized one when b = 1 − a, so the weights sum to 1 in the long run). The value of a is a memory dial: near 1, the past lingers for hundreds of steps; near 0, the state is basically just the latest input.
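A quick numerical check of the unrolling: the sketch below (plain Python, with made-up inputs and weights) computes every state once with the step-by-step recurrence and once with the closed-form weighted sum $h_t = \sum_{k=0}^{t} a^k\, b\, x_{t-k}$, and confirms they agree.

```python
def states_by_loop(x, a, b):
    """Run h_t = a * h_{t-1} + b * x_t starting from a zero state; return every h_t."""
    h, out = 0.0, []
    for xt in x:
        h = a * h + b * xt
        out.append(h)
    return out


def state_by_unrolled_sum(x, a, b, t):
    """Closed form: h_t = sum over k = 0..t of a**k * b * x[t - k]."""
    return sum((a ** k) * b * x[t - k] for k in range(t + 1))


x = [1.0, -2.0, 0.5, 3.0, 0.0, 1.5]     # arbitrary example inputs
a, b = 0.7, 0.3                          # example decay and input weight

loop = states_by_loop(x, a, b)
unrolled = [state_by_unrolled_sum(x, a, b, t) for t in range(len(x))]
for t, (h1, h2) in enumerate(zip(loop, unrolled)):
    print(t, round(h1, 6), round(h2, 6))  # the two columns match
```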

Worked example by hand

Let a = 0.8, b = 1, and feed the inputs x = [5, 0, 0, 0] (one spike, then silence). Watch the state decay:

| t | $x_t$ | computation | $h_t$ |
|---|---|---|---|
| 0 | 5 | 1·5 | 5.00 |
| 1 | 0 | 0.8·5.00 + 1·0 | 4.00 |
| 2 | 0 | 0.8·4.00 | 3.20 |
| 3 | 0 | 0.8·3.20 | 2.56 |

Each step keeps 80% of the previous value: 5 → 4 → 3.2 → 2.56, exactly $5 \cdot 0.8^t$. The single input from t = 0 is still echoing three steps later. That echo is the memory. Make a bigger and the echo lasts longer; make it smaller and it dies faster.
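One way to put a number on “how long the echo lasts” is the half-life: the number of steps k at which $a^k$ has fallen to one half, $k = \ln(1/2) / \ln a$. The helper below is an illustration of that formula, not something taken from the Griffin paper.

```python
import math


def memory_half_life(a: float) -> float:
    """Steps until an input's weight a**k has fallen to one half, for 0 < a < 1."""
    if not 0.0 < a < 1.0:
        raise ValueError("decay a must lie strictly between 0 and 1")
    return math.log(0.5) / math.log(a)


for a in (0.5, 0.8, 0.99, 0.999):
    print(f"a = {a}: half-life of about {memory_half_life(a):.1f} steps")
```

With a = 0.8, as in the table, the echo halves roughly every 3 steps; at a = 0.999 it takes roughly 700 steps, which is what “the past lingers for hundreds of steps” means in practice.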

Why this is the right primitive: An exponential moving average is the cheapest possible long-range memory — one multiply and one add per step, one number of state. The entire Griffin recurrence is this, made vector-valued and made data-dependent (the dials a and b will learn to change per token). Everything else is plumbing.

[Figure: The memory dial, impulse response of a linear recurrence. A single spike enters at the start, followed by silence, and the state decays as $a^t$: with a near 1 the memory rings for a long time; near 0 it vanishes immediately.]
python: linear recurrence, the recurrent (inference) form
```python
def linear_recurrence(x, a, b):      # x: list of inputs
    h = 0.0
    out = []
    for xt in x:
        h = a * h + b * xt           # one multiply-add, O(1) memory
        out.append(h)
    return out

# linear_recurrence([5, 0, 0, 0], a=0.8, b=1.0)
# -> approximately [5.0, 4.0, 3.2, 2.56] (floating-point rounding may show in the last digits)
```

This loop is exactly how Hawk generates text: constant state h, one cheap step per token. But a for loop over the sequence is death for training on a GPU — every step waits for the one before it. Next chapter, the linearity pays off and we delete the loop.

In $h_t = a\,h_{t-1} + b\,x_t$, with what weight does an input that arrived k steps ago contribute to the current state?

Chapter 2: The Parallel Scan

Here is the problem. The recurrence loop is sequential: $h_5$ needs $h_4$, which needs $h_3$, and so on. On a GPU with thousands of cores, that loop uses one of them and the rest sit idle. Transformers don’t have this problem: every token’s attention is computed at once. To compete, the recurrence must also be computed all at once. Linearity is what lets us.

The trick: combining is associative

Think of each step as carrying a little package $(a, b\,x)$: a decay factor and an additive contribution. Applying step t after step s means: first scale the running state by $a_s$ and add $b_s x_s$, then scale by $a_t$ and add $b_t x_t$. Two such steps compose into one equivalent step:

$$(a_2, c_2) \bullet (a_1, c_1) = (a_2 a_1,\; a_2 c_1 + c_2)$$

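where c = b·x and step 1 is applied before step 2. To see why, feed a state h through both steps: $a_2(a_1 h + c_1) + c_2 = a_2 a_1\,h + (a_2 c_1 + c_2)$, which is exactly the single step on the right. The key fact: this combine operator is associative. Grouping doesn’t matter: (A•B)•C equals A•(B•C). And anything associative can be computed by a parallel scan (also called a prefix sum): a tournament-style tree that produces all the running states in about $\log_2 n$ parallel rounds instead of n sequential steps.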
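Here is the combine operator as a small Python sketch. It uses exact fractions so the associativity check is not blurred by floating-point rounding, and it folds two packages together to recover a state from the Chapter 1 worked example. The name `combine` is just illustrative.

```python
from fractions import Fraction as F


def combine(later, earlier):
    """Compose two steps (a, c), applying `earlier` first and then `later`.

    h -> a1*h + c1 -> a2*(a1*h + c1) + c2 = (a2*a1)*h + (a2*c1 + c2)
    """
    a2, c2 = later
    a1, c1 = earlier
    return (a2 * a1, a2 * c1 + c2)


A = (F(4, 5), F(5))      # step 0: a = 0.8, c = b * x = 1 * 5
B = (F(4, 5), F(0))      # step 1: a = 0.8, c = 0
C = (F(1, 2), F(3))      # a different decay and input, to make the check general

# Associativity: grouping does not matter.
assert combine(combine(C, B), A) == combine(C, combine(B, A))

# Folding the first two packages and applying the result to h = 0 gives h_1.
a_total, c_total = combine(B, A)
print(a_total * 0 + c_total)   # 4, matching h_1 = 4.00 in the Chapter 1 table
```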

So the same math has two faces. Training: lay out the whole sequence, run the parallel scan, get all ht at once — GPU-friendly, like attention. Inference: throw away the scan and just do the one-step loop — constant memory, like an RNN. One computation, two execution modes. (If that rings a bell, it is the same duality RetNet and Mamba exploit.)

Why linearity is non-negotiable: If we wrapped a tanh around $h_{t-1}$, the combine step would no longer be a clean “scale-and-add,” the operator would lose associativity, and the parallel scan would be impossible. You would be stuck with the sequential loop, back to a slow-training RNN. The nonlinearities in Griffin live outside the recurrence (in gates and MLPs), never inside it.

[Figure: Parallel scan vs. sequential loop. The sequential loop handles one token per tick, n ticks in total; the parallel scan combines pairs in a tree and finishes in about $\log_2 n$ ticks.]

For an 8-token sequence the loop needs 8 ticks; the scan finishes in 3 (because 2³ = 8). For 2048 tokens: 2048 versus 11. That is why a linear recurrence can train at transformer speed. In practice DeepMind wrote a custom TPU kernel for this recurrence, because the memory-movement pattern, not the arithmetic, is the bottleneck.
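The tree can be sketched in plain Python. This is a Hillis–Steele style inclusive scan, one of several parallel-scan variants, chosen here only because it is short; it is not a description of DeepMind's kernel. Within a round, every combine reads only the previous round's values, so on parallel hardware a whole round could run at once; the Python loop merely simulates that. The function returns the prefixes and the number of rounds, which comes out to ⌈log₂ n⌉.

```python
import math
import random


def combine(later, earlier):
    """Compose two steps (a, c): apply `earlier`, then `later`."""
    a2, c2 = later
    a1, c1 = earlier
    return (a2 * a1, a2 * c1 + c2)


def parallel_scan(steps):
    """Inclusive scan over (a_t, c_t) packages; returns (prefixes, rounds).

    After the round with offset d, position i holds the composition of steps
    i - 2d + 1 .. i (clipped at 0). Offsets double, so ceil(log2 n) rounds.
    """
    out = list(steps)
    n, offset, rounds = len(out), 1, 0
    while offset < n:
        prev = out[:]                       # every combine in this round reads old values
        for i in range(offset, n):          # these combines are independent of each other
            out[i] = combine(prev[i], prev[i - offset])
        offset *= 2
        rounds += 1
    return out, rounds


def sequential_states(steps):
    """The plain loop: h_t = a_t * h_{t-1} + c_t, starting from h = 0."""
    h, out = 0.0, []
    for a, c in steps:
        h = a * h + c
        out.append(h)
    return out


x = [5.0, 0.0, 0.0, 0.0, 1.0, 0.0, 2.0, 0.0]
steps = [(0.8, 1.0 * xt) for xt in x]           # a = 0.8, b = 1, as in Chapter 1
prefixes, rounds = parallel_scan(steps)
scan_states = [c for _, c in prefixes]          # applying a prefix to h = 0 leaves its c
print(rounds)                                   # 3 rounds for 8 tokens
print(all(math.isclose(s, l) for s, l in zip(scan_states, sequential_states(steps))))

rng = random.Random(0)
long_steps = [(rng.uniform(0.5, 1.0), rng.gauss(0.0, 1.0)) for _ in range(2048)]
print(parallel_scan(long_steps)[1])             # 11 rounds for 2048 tokens
```

The rounds are not free: this variant performs about n·log₂ n combines instead of n, spread over only log₂ n rounds. Work-efficient scans such as Blelloch's bring the total work back to O(n) at the cost of roughly twice as many rounds.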

Why can a linear recurrence be computed with a parallel scan but a tanh-RNN cannot?

Chapter 3: Stability (keeping the state from exploding)

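The recurrence multiplies the state by a every step. Unroll it and a contribution from k steps back is scaled by $a^k$. That single fact decides whether the model lives or dies. If a is below 1, old contributions shrink geometrically and, for bounded inputs, the state stays bounded (stable). If a is exactly 1, every contribution persists at full strength forever, so the state never forgets (a knife-edge). If a is above 1, old contributions grow geometrically and the state explodes.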
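The three regimes are easy to see numerically. This sketch tracks the weight $a^k$ that a single unit input still carries after k steps; the particular values of a are arbitrary examples.

```python
def impulse_weight(a: float, k: int) -> float:
    """Weight that an input from k steps ago carries in h_t = a * h_{t-1} + x_t."""
    return a ** k


for a in (0.9, 1.0, 1.1):
    weights = [impulse_weight(a, k) for k in (0, 10, 50, 100)]
    print(a, [f"{w:.3g}" for w in weights])
# 0.9 shrinks toward zero, 1.0 stays at exactly 1, 1.1 grows past 10,000 by k = 100
```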

The fix: never let a leave (0, 1)

We must guarantee the recurrence weight stays in the safe zone — not hope the optimizer keeps it there. The trick is to never learn a directly. Instead learn an unconstrained number Λ and pass it through a squashing function:

$$a = \sigma(\Lambda) = \frac{1}{1 + e^{-\Lambda}} \in (0, 1)$$
where σ is the sigmoid, whose output always lies strictly between 0 and 1.

Now no matter what value the optimizer pushes Λ to — minus a million or plus a million — a is mathematically pinned inside (0, 1). The model is stable by construction. (Griffin actually parameterizes it in log-space for numerical comfort, but the guarantee is this sigmoid.)
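A sketch of the stable-by-construction parameterization in NumPy. It computes a from an unconstrained Λ through the log-sigmoid, $\log a = -\text{softplus}(-\Lambda)$, which is one reasonable way to work in log space; treat it as an assumption, not a transcription of Griffin's code. In exact arithmetic a is always strictly inside (0, 1); in float64 it can still round to exactly 0 or 1 when |Λ| is very large, so the demo keeps Λ in a moderate range.

```python
import numpy as np


def decay_from_lambda(lam):
    """Base decay a = sigmoid(lam), computed through log a = -softplus(-lam).

    np.logaddexp(0, z) = log(1 + exp(z)) = softplus(z) without overflow,
    so log_a is finite and negative for moderate lam.
    """
    log_a = -np.logaddexp(0.0, -lam)
    return np.exp(log_a), log_a


lam = np.linspace(-20.0, 20.0, 81)        # whatever values the optimizer might reach
a, log_a = decay_from_lambda(lam)
print(a.min(), a.max())                   # both strictly between 0 and 1
print(np.allclose(a, 1.0 / (1.0 + np.exp(-lam))))   # same as the direct sigmoid

# A state driven by bounded inputs stays bounded for every such a:
# |h| <= 1 + a + a**2 + ... = 1 / (1 - a) per channel.
rng = np.random.default_rng(0)
h = np.zeros_like(lam)
for _ in range(1_000):
    h = a * h + rng.uniform(-1.0, 1.0, size=lam.shape)
print(np.all(np.abs(h) <= 1.0 / (1.0 - a) + 1e-9))
```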

Why real and diagonal?

Earlier linear-recurrence models (S4, the LRU) used complex numbers for a, so the state could rotate as well as decay — useful for modeling oscillations. DeepMind found that for language a plain real diagonal recurrence works just as well and is simpler and faster: no complex arithmetic, just a per-channel decay. “Diagonal” means each channel of the state has its own independent a — no mixing between channels inside the recurrence (mixing happens in the linear projections around it). That keeps the scan a cheap element-wise operation.
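“Diagonal” is easy to check in code. In this NumPy sketch (channel count and decays are arbitrary), multiplying the state element-wise by a vector of per-channel decays gives the same result as multiplying by the diagonal matrix built from them, and a spike in one channel never leaks into another.

```python
import numpy as np

a = np.array([0.5, 0.9, 0.99])            # one real decay per channel
b = 1.0
x = np.zeros((10, 3))
x[0] = [1.0, 1.0, 0.0]                    # spike in channels 0 and 1 only

h_elementwise = np.zeros(3)
h_matrix = np.zeros(3)
A = np.diag(a)                            # the same recurrence written as a matrix
for xt in x:
    h_elementwise = a * h_elementwise + b * xt   # O(channels) work per step
    h_matrix = A @ h_matrix + b * xt             # O(channels**2) work per step

print(h_elementwise)                      # channel 2 stays exactly 0.0
print(np.allclose(h_elementwise, h_matrix))
```

The element-wise form is what keeps the scan cheap: it never materializes the matrix, whose off-diagonal entries are all zero anyway.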

Common misconception: “A bigger recurrence weight is always better — more memory!” Not quite. Push a too close to 1 on every channel and the state becomes a sluggish average that can never forget irrelevant junk; push it too low and the model is amnesiac. The real win in the next chapter is making a data-dependent: high when the token is worth remembering, low when it should be flushed. A fixed dial can’t do that.

[Figure: Stable, knife-edge, or explosive. State magnitude over time after a unit input: for a below 1 it decays (stable), at a = 1 it holds (knife-edge), and above 1 it explodes. The sigmoid parameterization makes the explosive zone unreachable.]

Griffin learns an unconstrained Λ and sets a = σ(Λ) instead of learning a directly. What is the point of doing this?

Chapter 4: RG-LRU — gating the recurrence

A fixed decay a treats every token the same. But language isn’t uniform: the name of the protagonist on page 1 should persist; the word “the” should not. We want the model to decide, per token, how much to remember and how much to let in. That decision-making is the RG-LRU — the Real-Gated Linear Recurrent Unit, the heart of Hawk and Griffin.

Two gates

Both gates are simple: take the current input $x_t$, run a linear layer, and squash with a sigmoid (so each gate is a number in (0, 1)):

$$r_t = \sigma(W_a x_t) \quad \text{(recurrence gate)}, \qquad i_t = \sigma(W_x x_t) \quad \text{(input gate)}$$

The input gate $i_t$ filters the incoming token, deciding how loudly this word gets to speak into the state. The recurrence gate $r_t$ is cleverer: it modulates the decay itself. The effective decay becomes

$$a_t = a^{c\,r_t} \qquad (c = 8 \text{ is a fixed constant; } a = \sigma(\Lambda) \text{ is the learned base decay})$$

Stare at this. When the gate $r_t \to 0$, the exponent $c\,r_t \to 0$, so $a_t \to a^0 = 1$: the state is preserved almost perfectly (“this token says: hold everything, don’t decay”). When $r_t \to 1$, the exponent is at its largest, so $a_t$ drops toward $a^8$, a strong decay (“flush the old state, the topic just changed”). The gate slides the decay between “remember everything” and “forget aggressively,” per token, per channel.

The full RG-LRU update

$$h_t = a_t \cdot h_{t-1} + \sqrt{1 - a_t^2} \cdot (i_t \cdot x_t)$$

The first term is the gated memory. The second term is the gated input, but notice that the input weight is $\sqrt{1 - a_t^2}$, not a free parameter. This is a variance-preserving normalizer. Suppose the decay holds steady at a and the gated inputs $i_t x_t$ are uncorrelated with unit variance: the state’s long-run variance v then satisfies $v = a^2 v + (1 - a^2)$, so v = 1, whereas an input weight of 1 would give $v = 1/(1 - a^2)$, which blows up as a approaches 1. So the state’s scale stays roughly constant as the gates change. When $a_t$ is near 1 (holding memory), $\sqrt{1 - a_t^2}$ is near 0 (let almost nothing new in); when $a_t$ is near 0 (forgetting), it is near 1 (let the new token dominate). The two terms trade off automatically.
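The variance claim can be checked with a short simulation. This NumPy sketch assumes i.i.d. unit-variance inputs and a fixed decay a = 0.9 (real gates vary per token, so this only captures the steady-state intuition). It compares the normalized input weight $\sqrt{1 - a^2}$ with a plain weight of 1.

```python
import numpy as np


def empirical_state_variance(a, input_weight, T=200_000, burn_in=1_000, seed=0):
    """Simulate h_t = a * h_{t-1} + w * x_t with x_t ~ N(0, 1) and return Var(h)."""
    rng = np.random.default_rng(seed)
    x = rng.standard_normal(T)
    h, hs = 0.0, np.empty(T)
    for t in range(T):
        h = a * h + input_weight * x[t]
        hs[t] = h
    return hs[burn_in:].var()              # skip the start-up transient


a = 0.9
normalized = empirical_state_variance(a, np.sqrt(1 - a ** 2))
plain = empirical_state_variance(a, 1.0)
print(f"weight sqrt(1 - a^2): Var(h) = {normalized:.3f}  (theory: 1)")
print(f"weight 1:             Var(h) = {plain:.3f}  (theory: {1 / (1 - a ** 2):.3f})")
```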

Worked example by hand

Base decay a = 0.9, constant c = 8, current state $h_{t-1} = 2.0$, input $x_t = 1.0$, and input gate $i_t = 0.5$ in both cases. Compare a “keep” token with a “flush” token.

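| case | $r_t$ | $a_t = 0.9^{8 r_t}$ | $\sqrt{1 - a_t^2}$ | $i_t$ | $h_t$ |
|---|---|---|---|---|---|
| keep memory | 0.1 | $0.9^{0.8} = 0.919$ | 0.394 | 0.5 | 0.919·2 + 0.394·0.5·1 = 2.04 |
| flush memory | 0.9 | $0.9^{7.2} = 0.468$ | 0.884 | 0.5 | 0.468·2 + 0.884·0.5·1 = 1.38 |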

Same input, same base decay, but the “keep” token barely moves the state (2.0 → 2.04, the past dominates), while the “flush” token pulls it down hard (2.0 → 1.38, the old memory is being discarded). The model learns when to do which, from data. That selectivity (the same idea behind Mamba’s “selection mechanism,” where the dynamics depend on the input) is what lets a fixed-size state punch above its weight.
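The table can be reproduced in a few lines of plain Python. The function below is a scalar version of the RG-LRU update written for this example, with the gate values passed in directly; a = 0.9, c = 8, $h_{t-1} = 2$, $x_t = 1$ and $i_t = 0.5$ are the values from the table.

```python
import math


def rg_lru_scalar_step(h_prev, x, a, r, i, c=8.0):
    """One scalar RG-LRU update given gate values r (recurrence) and i (input)."""
    a_t = a ** (c * r)                              # data-dependent decay
    return a_t * h_prev + math.sqrt(1.0 - a_t ** 2) * (i * x)


keep = rg_lru_scalar_step(2.0, 1.0, a=0.9, r=0.1, i=0.5)
flush = rg_lru_scalar_step(2.0, 1.0, a=0.9, r=0.9, i=0.5)
print(f"keep:  {keep:.3f}")    # about 2.035, shown as 2.04 in the table
print(f"flush: {flush:.3f}")   # about 1.378, shown as 1.38 in the table
```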

[Figure: Gated memory, keep vs. flush, token by token. A sequence of input spikes flows in; with a low recurrence gate the state integrates and holds (long memory), while with a high gate each new token wipes the slate.]
python: one RG-LRU step (per-channel, vectorized with numpy)
```python
import numpy as np


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


def rg_lru_step(h, x, Lam, Wa, Wx, c=8.0):
    # h, x, Lam: [D_rnn] vectors; Wa, Wx: [D_rnn, D_rnn] gate weight matrices
    a    = sigmoid(Lam)              # base decay per channel, in (0,1)
    r    = sigmoid(Wa @ x)           # recurrence gate, in (0,1)
    i    = sigmoid(Wx @ x)           # input gate, in (0,1)
    a_t  = a ** (c * r)              # data-dependent decay
    h    = a_t * h + np.sqrt(1 - a_t**2) * (i * x)
    return h                         # new state, same shape as h
```

When the recurrence gate $r_t$ is near 0, the effective decay $a_t = a^{8 r_t}$ approaches 1. What does the model do with this token?

Chapter 5: The Recurrent Block (data flow, end to end)

The RG-LRU is the engine; the recurrent block is the car built around it. Hawk and Griffin stack many copies of one residual block. Let us trace a real tensor through it so you could implement it from memory. Say the model width is D = 1024 and we process a sequence of T = 2048 tokens, so the input to the block is shaped [T, D] = [2048, 1024].

The outer residual wrapper

Like a transformer, the block is two residual sub-layers: a temporal-mixing sub-layer (this is where the recurrence lives) and a channel-mixing MLP. Each is preceded by RMSNorm and added back to the residual stream:

residual in [T,D]
→ RMSNorm → recurrent core → + residual
residual [T,D]
→ RMSNorm → gated MLP (GeGLU) → + residual
↓ out [T,D]
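The outer wrapper can be sketched in NumPy as follows. Assumptions: the temporal-mixing sub-layer is passed in as a function (so any mixer fits), the GeLU uses the common tanh approximation, and the parameter names, widths, and random initialization are hypothetical.

```python
import numpy as np


def rms_norm(x, scale, eps=1e-6):
    """Divide each row by its root-mean-square, then apply a learned per-channel scale."""
    return x / np.sqrt(np.mean(x ** 2, axis=-1, keepdims=True) + eps) * scale


def gelu(z):
    """Tanh approximation of GeLU."""
    return 0.5 * z * (1.0 + np.tanh(np.sqrt(2.0 / np.pi) * (z + 0.044715 * z ** 3)))


def geglu_mlp(x, W_gate, W_up, W_down):
    """Gated MLP: GeLU(x W_gate) * (x W_up), projected back to model width."""
    return (gelu(x @ W_gate) * (x @ W_up)) @ W_down


def residual_block(x, temporal_mix, p):
    """x: [T, D]. Two pre-norm residual sub-layers; output shape equals input shape."""
    x = x + temporal_mix(rms_norm(x, p["norm1"]))                       # temporal mixing
    x = x + geglu_mlp(rms_norm(x, p["norm2"]),
                      p["W_gate"], p["W_up"], p["W_down"])              # channel mixing
    return x


T, D, D_mlp = 16, 8, 24
rng = np.random.default_rng(0)
p = {
    "norm1": np.ones(D), "norm2": np.ones(D),
    "W_gate": rng.normal(0, D ** -0.5, (D, D_mlp)),
    "W_up": rng.normal(0, D ** -0.5, (D, D_mlp)),
    "W_down": rng.normal(0, D_mlp ** -0.5, (D_mlp, D)),
}
x = rng.normal(size=(T, D))
y = residual_block(x, temporal_mix=lambda u: 0.5 * u, p=p)   # stand-in mixer for the demo
print(y.shape)   # (16, 8): same shape as the input, so blocks stack
```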

Inside the recurrent core (the temporal-mixing sub-layer)

This is where Griffin’s design echoes Mamba: the normed input splits into two branches that are multiplied back together at the end — a gating structure.

  1. Split into two branches with two linear layers, each [D] → [D_rnn]. (For simplicity, take D_rnn = D here, so each branch maps [2048, 1024] → [2048, 1024].)
  2. Gate branch: linear → GeLU. This branch carries no recurrence; it is a learned, element-wise multiplicative mask applied at the end.
  3. Recurrence branch: linear → a small causal 1-D convolution (kernel width 4, depthwise) → RG-LRU. The tiny conv mixes each channel over the last ~4 tokens before the recurrence — cheap, sharp, short-range pattern detection that complements the recurrence’s smooth long-range memory.
  4. Merge: multiply the two branches element-wise (the GeLU branch gates the RG-LRU branch), then apply a final linear layer [D_rnn] → [D] back to model width.
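Putting the four steps together, here is a minimal NumPy sketch of the recurrent core. To keep it short it makes several simplifying assumptions: dense gate matrices with no biases (the paper's gates may be structured differently), the tanh approximation of GeLU, a plain Python loop for the RG-LRU instead of a parallel scan, and tiny hypothetical widths with random weights. All parameter names are made up for the example.

```python
import numpy as np


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


def gelu(z):
    return 0.5 * z * (1.0 + np.tanh(np.sqrt(2.0 / np.pi) * (z + 0.044715 * z ** 3)))


def causal_depthwise_conv(u, kernel):
    """u: [T, C], kernel: [K, C]. out[t] = sum_k kernel[k] * u[t - k], zero before t = 0."""
    T, K = u.shape[0], kernel.shape[0]
    padded = np.concatenate([np.zeros((K - 1, u.shape[1])), u], axis=0)
    out = np.zeros_like(u)
    for k in range(K):
        out += kernel[k] * padded[K - 1 - k : K - 1 - k + T]   # rows hold u[t - k]
    return out


def rg_lru(u, p, c=8.0):
    """u: [T, C]. Step-by-step RG-LRU; the only thing carried across time is h: [C]."""
    log_a = -np.logaddexp(0.0, -p["Lam"])                # log of a = sigmoid(Lam)
    h = np.zeros(u.shape[1])
    out = np.empty_like(u)
    for t, x in enumerate(u):
        r = sigmoid(x @ p["W_a"])                        # recurrence gate
        i = sigmoid(x @ p["W_x"])                        # input gate
        a_t = np.exp(c * r * log_a)                      # a ** (c * r), via log space
        h = a_t * h + np.sqrt(1.0 - a_t ** 2) * (i * x)
        out[t] = h
    return out


def recurrent_core(x, p):
    """x: [T, D], already RMS-normed. Returns [T, D]."""
    gate = gelu(x @ p["W_gate"])                        # steps 1-2: gate branch
    u = x @ p["W_in"]                                   # step 1: recurrence branch
    u = causal_depthwise_conv(u, p["conv"])             # step 3: width-4 causal conv
    u = rg_lru(u, p)                                    # step 3: RG-LRU
    return (gate * u) @ p["W_out"]                      # step 4: merge, back to D


T, D, D_rnn = 16, 8, 8                                  # tiny hypothetical sizes
rng = np.random.default_rng(0)
p = {
    "W_gate": rng.normal(0, D ** -0.5, (D, D_rnn)),
    "W_in": rng.normal(0, D ** -0.5, (D, D_rnn)),
    "conv": rng.normal(0, 0.5, (4, D_rnn)),
    "Lam": rng.normal(2.0, 1.0, D_rnn),
    "W_a": rng.normal(0, D_rnn ** -0.5, (D_rnn, D_rnn)),
    "W_x": rng.normal(0, D_rnn ** -0.5, (D_rnn, D_rnn)),
    "W_out": rng.normal(0, D_rnn ** -0.5, (D_rnn, D)),
}
x = rng.normal(size=(T, D))
y = recurrent_core(x, p)
print(y.shape)                                          # (16, 8)
```

Because every stage is causal, changing the last token cannot change any earlier output, which is what makes the same weights usable for token-by-token generation.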

Output shape: [2048, 1024], same as the input, so blocks stack cleanly. Note where the extra machinery sits: the nonlinearities (the GeLU gate and the MLP) and the short convolution all act outside the recurrence. The RG-LRU itself stays strictly linear in its state, which is exactly what keeps the parallel scan from Chapter 2 valid. That placement is not an accident; it is the whole design constraint.

Concept → realization: during training the recurrence branch runs as a parallel scan over all 2048 positions at once. At inference the very same weights run as the one-step loop, and the only things carried between tokens are the RG-LRU state vector (size D_rnn) and the conv’s 3-token window. No growing cache. That is the entire payoff of Chapters 1–4 made concrete.
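The conv’s “3-token window” is worth seeing concretely. In this sketch (NumPy plus `collections.deque`, with made-up sizes), a width-4 causal depthwise convolution run one token at a time needs to remember only the previous 3 inputs, and it reproduces the full-sequence result exactly. The class name is hypothetical.

```python
from collections import deque

import numpy as np


def causal_conv_full(u, kernel):
    """Training style: whole sequence at once. out[t] = sum_k kernel[k] * u[t - k]."""
    T, K = u.shape[0], kernel.shape[0]
    padded = np.concatenate([np.zeros((K - 1, u.shape[1])), u], axis=0)
    return sum(kernel[k] * padded[K - 1 - k : K - 1 - k + T] for k in range(K))


class StreamingCausalConv:
    """Inference style: one token per call, keeping only the last K - 1 inputs."""

    def __init__(self, kernel):
        self.kernel = kernel                                   # [K, C]
        K, C = kernel.shape
        self.buffer = deque([np.zeros(C)] * (K - 1), maxlen=K - 1)

    def step(self, u_t):
        window = list(self.buffer) + [u_t]                     # oldest ... newest
        y = sum(self.kernel[k] * window[-1 - k] for k in range(len(self.kernel)))
        self.buffer.append(u_t)                                # oldest entry drops out
        return y


rng = np.random.default_rng(0)
kernel = rng.normal(size=(4, 5))        # width 4, 5 channels
u = rng.normal(size=(12, 5))            # 12 tokens

conv = StreamingCausalConv(kernel)
streamed = np.stack([conv.step(u_t) for u_t in u])
print(np.allclose(streamed, causal_conv_full(u, kernel)))   # True
print(len(conv.buffer))                                     # 3 vectors carried between tokens
```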


Why are the GeLU, the 1-D convolution, and the MLP all placed outside the RG-LRU recurrence rather than inside it?

Chapter 6: Local Attention (the sharp complement)

The RG-LRU gives us cheap, smooth, long-range memory — but it is lossy. A fixed-size state can only summarize the past; it cannot reach back and read an exact earlier token verbatim. For some jobs you need exactly that: “copy the variable name I defined two lines ago,” “match this closing bracket to that opening one.” That is attention’s superpower — precise, content-based recall. Griffin keeps a little of it.

Sliding-window attention

The expensive part of full attention is that every token looks at all previous tokens. Local (sliding-window) attention restricts each token to the most recent W tokens — a fixed window, say W = 1024. Token at position t attends only to positions t−W+1 … t.
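The window rule translates directly into a boolean mask. This NumPy sketch (the function name is illustrative) marks which key positions s each query position t may attend to: exactly those with t − W + 1 ≤ s ≤ t.

```python
import numpy as np


def sliding_window_mask(T, W):
    """mask[t, s] is True when query t may attend to key s, i.e. t - W < s <= t."""
    t = np.arange(T)[:, None]     # query positions, as a column
    s = np.arange(T)[None, :]     # key positions, as a row
    return (s <= t) & (s > t - W)


print(sliding_window_mask(6, 3).astype(int))
```

Each row has at most W entries set; with W ≥ T the mask is the ordinary causal (lower-triangular) one, which is full attention again.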

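Two consequences, both wonderful. First, each token now does at most W comparisons instead of up to n, so the attention work grows linearly with sequence length rather than quadratically. Second, the KV cache only ever needs the last W tokens, so its size is capped at W no matter how long the sequence gets.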
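Both consequences show up in a decoding loop. The sketch below is a single attention head in NumPy with a hypothetical `LocalAttentionCache` class: keys and values live in a `deque` with `maxlen=W`, so the oldest entry is evicted automatically, the cache never holds more than W tokens, and each step does at most W dot products however long the sequence gets. Multi-query sharing and batching are left out.

```python
from collections import deque

import numpy as np


class LocalAttentionCache:
    """Single-head sliding-window attention for token-by-token decoding."""

    def __init__(self, W):
        self.keys = deque(maxlen=W)      # appending beyond W drops the oldest key
        self.values = deque(maxlen=W)

    def step(self, q, k, v):
        """Add this token's key/value, then attend over at most the last W tokens."""
        self.keys.append(k)
        self.values.append(v)
        K = np.stack(list(self.keys))                # [<= W, d]
        V = np.stack(list(self.values))              # [<= W, d]
        scores = K @ q / np.sqrt(q.shape[0])         # at most W dot products per token
        weights = np.exp(scores - scores.max())      # numerically stable softmax
        weights /= weights.sum()
        return weights @ V


W, d, T = 4, 8, 50
rng = np.random.default_rng(0)
Q, Kall, Vall = (rng.normal(size=(T, d)) for _ in range(3))

cache = LocalAttentionCache(W)
outputs = [cache.step(Q[t], Kall[t], Vall[t]) for t in range(T)]
print(len(cache.keys))       # 4: the cache size is capped at W
```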

But doesn’t a window throw away long-range info?

On its own, yes. A pure local-attention model literally cannot see past W tokens. This is exactly why Griffin is a hybrid. The recurrence branch carries information from arbitrarily far back (compressed into the state), while local attention handles precise recent lookups. Neither alone is enough; together they cover both regimes — long-range summary from the RG-LRU, short-range detail from local attention. The window doesn’t need to be long because the recurrence already owns the long range.

Common misconception: “Local attention is just a cheaper approximation of full attention, so it must be strictly worse.” In a hybrid it isn’t an approximation at all — it is a division of labor. Asking full attention to also do the long-range job is what was expensive; once the recurrence handles that, the window can be small and the model loses nothing. Griffin also uses MQA (multi-query attention) in these layers to shrink the cache further.

[Figure: The sliding window. Each row is a token attending to earlier tokens. Full causal attention fills the whole triangle, so the cache grows with the sequence; sliding-window attention keeps only a band of width W, so the cache stays constant.]
In Griffin, why can the local-attention window be relatively small without crippling long-range understanding?

Chapter 7: Griffin — the hybrid stack (showcase)

Now assemble it. Griffin interleaves the two block types we just built: mostly recurrent blocks (RG-LRU), with a local-attention block mixed in periodically — the paper’s recipe is roughly two recurrent blocks for every one local-attention block. Hawk is the special case of all recurrent blocks, no attention at all.

The simulator below is a small Griffin stack processing a sequence. A piece of information enters at one token and we watch how far it can travel through the network. In a recurrent layer it propagates through the state — reachable from any later position (long range, but compressed). In a local-attention layer it can only jump within the window (short range, but exact). Stacking them lets information route through whichever mechanism fits. Below the stack, a live tally compares the inference memory of this hybrid against a pure transformer of the same depth as the sequence grows.
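The information-flow rule the simulator illustrates can be written as a small function. It rests on a simplification assumed for illustration: a recurrent layer lets information at position s reach every later position, a local-attention layer lets it reach positions s through s + W − 1, and the residual stream keeps it where it already is. Both helper names are hypothetical.

```python
def layer_pattern(n_layers, rec_per_attn):
    """rec_per_attn 'rec' layers then one 'attn', repeated; None means all recurrent (Hawk)."""
    if rec_per_attn is None:
        return ["rec"] * n_layers
    unit = ["rec"] * rec_per_attn + ["attn"]
    return [unit[i % len(unit)] for i in range(n_layers)]


def reachable(source, seq_len, layers, W):
    """Positions that can carry information injected at `source` after all layers."""
    reached = {source}
    for kind in layers:
        if kind == "rec":
            # the recurrent state carries it to every later position (compressed)
            reached = set(range(min(reached), seq_len))
        else:
            # local attention: position t reads positions t - W + 1 .. t (exactly)
            reached = {t for t in range(seq_len) if any(t - W < s <= t for s in reached)}
    return reached


seq_len, W, depth = 32, 3, 6
for name, pattern in [("all attention", layer_pattern(depth, 0)),
                      ("Griffin 2:1", layer_pattern(depth, 2)),
                      ("Hawk", layer_pattern(depth, None))]:
    furthest = max(reachable(0, seq_len, pattern, W))
    print(f"{name:13s} -> furthest reachable position: {furthest}")
```

With only attention, the reach grows by W − 1 positions per layer; a single recurrent layer anywhere in the stack extends it to the end of the sequence.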

[Figure: Griffin stack, information flow and inference memory. Each row is a layer (recurrent RG-LRU or local attention); information injected at one token propagates to the later tokens it can reach, and a readout compares the hybrid’s inference memory with that of a pure transformer as the sequence grows.]

Play with the ratio. Set it to 0:1 (all attention, window-limited) and information can only crawl forward one window per layer — and the memory readout climbs as the sequence grows. Set it to 4:0 (pure Hawk) and every later token is reachable through the recurrence, with flat memory — but there is no exact-copy path. The sweet spot in the middle gives you both: a token far in the past stays reachable via the recurrence, while the attention layers provide crisp local lookups, and inference memory stays nearly flat as the sequence grows.

What Griffin actually delivered (2024): at 7B and 14B parameters Griffin matched Llama-2’s quality while training on roughly 6× fewer tokens; at inference it reached higher throughput and lower latency on long sequences; and it extrapolated — trained on 2,048-token sequences, it kept improving on inputs far longer than it ever saw in training, something transformers struggle to do without special tricks.

Chapter 8: Inference Economics & Extrapolation

Architecture choices only matter if they change the bill. Let us make the inference advantage concrete with the three cost axes that decide whether you can serve a model cheaply.

1. Memory per generated token

A transformer must keep the KV cache of every past token. For a sequence of length n with L layers, the cache scales like n · L: unbounded. A Griffin recurrent layer keeps one fixed-size state regardless of n; a Griffin local-attention layer keeps a cache capped at the window W. So Griffin’s state is roughly (recurrent layers × state width) + (attention layers × W × key/value size per token), a constant that does not grow with the sequence. At n = 100k this can be the difference between gigabytes and megabytes.
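The accounting can be made concrete by counting cached numbers per sequence. Every configuration value below (layer counts, widths, the 2:1 split, multi-query attention with a single key/value head, W = 1024, and a width-4 conv buffer in each recurrent layer) is a hypothetical choice for illustration, not a published model's configuration.

```python
def transformer_state(n, layers=24, kv_heads=16, head_dim=128):
    """Full attention: one key and one value per token, per layer, per KV head."""
    return layers * n * 2 * kv_heads * head_dim


def hawk_state(n, layers=24, d_rnn=2048, conv_width=4):
    """Pure recurrence: RG-LRU state h plus the last conv_width - 1 conv inputs per layer.

    n is deliberately unused: the state does not depend on sequence length.
    """
    return layers * (d_rnn + (conv_width - 1) * d_rnn)


def griffin_state(n, rec_layers=16, attn_layers=8, d_rnn=2048, conv_width=4,
                  head_dim=128, W=1024):
    """Recurrent layers as in Hawk; local MQA layers cache at most W tokens (1 KV head)."""
    recurrent = rec_layers * (d_rnn + (conv_width - 1) * d_rnn)
    local_attention = attn_layers * min(n, W) * 2 * head_dim
    return recurrent + local_attention


for n in (1_024, 8_192, 100_000):
    print(f"n={n:>7,}: transformer {transformer_state(n):>13,}  "
          f"hawk {hawk_state(n):>9,}  griffin {griffin_state(n):>9,}")
```

Multiply by the bytes per number (2 for bf16) to get memory; the point is the shape of the growth, not these particular totals.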

2. Latency per token

Lower memory traffic means each generated token is produced faster, and crucially the per-token cost stays flat as the sequence grows, instead of creeping upward as the KV cache balloons. Long-context generation is where Griffin pulls clearly ahead of a same-size transformer.

3. Throughput

Because the state is small, you can fit many more concurrent sequences in memory and batch them — so total tokens-per-second served goes up. For a serving system, throughput-per-dollar is often the number that matters, and constant-size state is what unlocks it.

Length extrapolation — a bonus from the structure

Transformers tie meaning to absolute or relative positions via positional encodings; push them past their training length and those encodings hit values never seen, and quality falls off a cliff. Griffin’s recurrence has no length-dependent positional scheme — it just keeps applying the same per-step update — and its attention is local, so it never sees a relative distance bigger than W. Neither component cares how long the sequence is. That structural fact is why Griffin generalizes to far longer inputs than it trained on.

[Figure: Inference memory vs. sequence length. Relative inference-state size as the sequence grows: the transformer’s is linear and unbounded, Hawk’s is flat (pure recurrence), and Griffin’s is nearly flat (recurrence plus a capped local-attention cache).]

Griffin extrapolates to sequences much longer than it trained on, where vanilla transformers usually degrade. What is the structural reason?

Chapter 9: Cheat Sheet & Connections

Everything in one place. If you can reconstruct this table and the equations beneath it, you understand Griffin and Hawk well enough to teach them.

The whole pipeline in one breath

linear recurrence: $h_t = a\,h_{t-1} + b\,x_t$, an exponential moving average; an input k steps back carries weight $a^k$
↓ make it trainable in parallel
parallel scan: associative scale-and-add gives all states in about $\log_2 n$ rounds (training); the one-step loop runs at inference
↓ make it stable
$a = \sigma(\Lambda)$: real, diagonal, pinned in (0, 1), stable by construction
↓ make it selective
RG-LRU: $a_t = a^{8 r_t}$; $h_t = a_t h_{t-1} + \sqrt{1 - a_t^2}\,(i_t x_t)$
↓ wrap and complement
Griffin block: conv + RG-LRU branch × GeLU branch, RMSNorm + MLP residuals; interleave local attention

The family tree

| Model | Long-range mechanism | Recall mechanism | Inference state |
|---|---|---|---|
| Transformer | full attention | full attention | KV cache, grows with n |
| Hawk | RG-LRU recurrence | (none exact) | fixed state, constant |
| Griffin | RG-LRU recurrence | local attention | state + capped window, constant |
| Mamba | selective SSM | (none exact) | fixed state, constant |
| RetNet | retention (decay) | chunk attention | fixed state, constant |
| RWKV | WKV recurrence | token-shift mix | fixed state, constant |

The shared idea across this whole family of “efficient” models: replace the growing KV cache with a fixed-size recurrent state, trained in parallel via a linear-recurrence / scan trick. Griffin’s distinctive bets are (1) a real, gated diagonal recurrence (RG-LRU) and (2) explicitly keeping a slice of local attention for exact recall, rather than trying to make the recurrence do everything.

When to reach for Griffin/Hawk

Keep exploring

Linear Attention & RWKV — the recurrence=attention duality in depth
SSM & Mamba — the selective state-space cousin of RG-LRU
RetNet — one computation, three execution forms
xLSTM — another modernized recurrence with exponential gating
Attention Variants — MHA, MQA, sliding-window, FlashAttention
Jamba — hybrid Mamba + attention + MoE at production scale

“What I cannot create, I do not understand.” You just rebuilt Griffin from a single multiply-add: a linear recurrence, made parallel by a scan, stable by a sigmoid, selective by two gates, and complete by a window of attention. That is the whole bird.