Beyond attention: state space models process sequences in O(N) instead of O(N²). From S4 to Mamba — the architectures that scale to millions of tokens.
You have a sequence of 100,000 tokens — a textbook, an hour of audio, a genome. You want a model that understands it end to end. The problem? Transformers compute a score between every pair of tokens. That's N² operations. At 100K tokens, that's 10 billion pairwise interactions. Your GPU runs out of memory long before it runs out of things to learn.
This isn't a minor inconvenience. It's a fundamental scaling wall. Audio runs at 16,000 samples per second — a single second is already 256 million pairs. A math textbook has hundreds of thousands of words with dependencies spanning chapters. Attention simply cannot reach these scales.
So the question driving this lecture is: can we design architectures that scale sub-quadratically in sequence length, without sacrificing quality? The answer is yes — and it comes from an unexpected place: control theory.
Before we build new architectures, we need to understand the three fundamental ways to mix information across a sequence. Every sequence model is built from some combination of these primitives: convolutions, recurrences, and attention.
A convolution slides a kernel (a small weight vector) across the input and computes a weighted sum at each position. The output at position n is the dot product of the kernel with a window of the input centered on (or ending at) n; in the causal case:
$$y[n] = \sum_{k=0}^{K-1} h[k]\, x[n-k]$$
where h is the kernel and x is the input sequence. The kernel weights are fixed: the same weights are applied at every position, so the mixing pattern is input-independent.
Properties: parallelizable during training (all positions can be computed at once, e.g. via FFT), but the kernel has finite reach. A kernel of size K sees only the current input and the K − 1 inputs before it.
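As a minimal NumPy sketch of the causal case, the loop below applies one fixed kernel at every position. `causal_conv` and the toy values are illustrative; a real implementation would call a library routine or use the FFT rather than a double loop.

```python
import numpy as np

def causal_conv(x, h):
    """Direct causal convolution: y[n] = sum_{k=0}^{K-1} h[k] * x[n-k].

    Inputs before position 0 count as zero, so output n depends only on
    x[n], x[n-1], ..., x[n-K+1]: a fixed, finite window.
    """
    N, K = len(x), len(h)
    y = np.zeros(N)
    for n in range(N):
        for k in range(K):
            if n - k >= 0:
                y[n] += h[k] * x[n - k]   # the same weights h at every position
    return y

x = np.array([1.0, 0.0, 0.0, 0.0, 0.0])   # unit impulse at position 0
h = np.array([0.5, 0.3, 0.2])             # kernel of size K = 3
y = causal_conv(x, h)
# y holds 0.5, 0.3, 0.2, 0.0, 0.0: the impulse is forgotten after K steps
```

Feeding in a unit impulse returns the kernel itself followed by zeros, which is the finite reach in action.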
A recurrence processes the sequence one token at a time, maintaining a hidden state that summarizes everything seen so far:
$$h_k = f(h_{k-1}, x_k)$$
Properties: the hidden state has potentially infinite context (it can carry information from arbitrarily far back) and inference is O(1) per step (just update the state), but training is sequential: you can't compute $h_k$ without $h_{k-1}$. The exception is a linear f, in which case the recurrence can be parallelized.
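The recurrence can be sketched the same way. The helper `run_recurrence` and both update rules are hypothetical toy examples with hand-picked weights; the point is that each step consumes the previous state, and that a single impulse keeps echoing through the state.

```python
import numpy as np

def run_recurrence(xs, f, h0):
    """Apply h_k = f(h_{k-1}, x_k) one step at a time, returning every h_k.

    Only the current state is carried between steps, so memory per step
    stays constant however long the sequence is.
    """
    h = h0
    states = []
    for x_k in xs:
        h = f(h, x_k)        # step k cannot start until step k-1 is done
        states.append(h)
    return np.array(states)

def tanh_rnn(h, x):
    """Nonlinear update with illustrative weights (not from a trained model)."""
    return np.tanh(0.9 * h + 0.5 * x)

def linear_rnn(h, x):
    """Linear update: the case that admits a parallel algorithm."""
    return 0.9 * h + 0.5 * x

xs = np.array([1.0, 0.0, 0.0, 0.0])     # a single impulse, then silence
hs = run_recurrence(xs, linear_rnn, 0.0)
# hs holds 0.5, 0.45, 0.405, 0.3645: the impulse fades but never reaches zero
hs_tanh = run_recurrence(xs, tanh_rnn, 0.0)
```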
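Attention computes a weighted combination of all values, where the weights depend on pairwise similarity between queries and keys:
$$\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left(\frac{QK^\top}{\sqrt{d}}\right)V$$
where d is the query/key dimension.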
Properties: fully parallelizable, input-dependent mixing (the mixing matrix changes with every input), but O(N²) in both compute and memory.
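A minimal sketch of single-head scaled dot-product attention, assuming random projection matrices in place of learned ones and omitting the causal mask for brevity. The `[N, N]` weight matrix is recomputed from the input on every call, which is what makes the mixing input-dependent and also what makes the cost quadratic.

```python
import numpy as np

def softmax(z, axis=-1):
    z = z - z.max(axis=axis, keepdims=True)   # subtract the max for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def attention(Q, K, V):
    """softmax(Q K^T / sqrt(d)) V for a single head, no mask."""
    d = Q.shape[-1]
    scores = Q @ K.T / np.sqrt(d)       # [N, N]: one score per (query, key) pair
    weights = softmax(scores, axis=-1)  # each row sums to 1
    return weights @ V, weights

rng = np.random.default_rng(0)
N, d = 5, 8
X = rng.standard_normal((N, d))                              # 5 token vectors
Wq, Wk, Wv = (rng.standard_normal((d, d)) for _ in range(3))  # hypothetical projections
out, weights = attention(X @ Wq, X @ Wk, X @ Wv)
print(weights.shape)  # (5, 5)
```

Doubling N quadruples the size of `weights`, which is the O(N²) memory cost in concrete form.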
| Primitive | Training | Inference | Context | Mixing |
|---|---|---|---|---|
| Convolution | Parallel O(N log N) | O(K) per step | Finite (kernel size) | Input-independent |
| Recurrence | Sequential O(N) | O(1) per step | Infinite (in theory) | Input-independent* |
| Attention | Parallel O(N²) | O(N) per step | Full sequence | Input-dependent |
*Linear recurrences are input-independent. Mamba makes them input-dependent — we'll get there in Chapter 6.
A state space model (SSM) is a system borrowed from control theory. It describes how a hidden state evolves over time in response to an input signal, and how we observe the output. The continuous-time form, $x'(t) = Ax(t) + Bu(t)$ with output $y(t) = Cx(t) + Du(t)$, has four matrices:
| Symbol | Shape | What it is |
|---|---|---|
| x(t) | [N, 1] | Hidden state — the model's internal memory |
| u(t) | [1, 1] | Input signal at time t (one channel) |
| y(t) | [1, 1] | Output signal at time t |
| A | [N, N] | State transition — how the state evolves on its own |
| B | [N, 1] | Input projection — how input enters the state |
| C | [1, N] | Output projection — how we read the state |
| D | [1, 1] | Skip connection (often set to 0) |
Think of it like this: you're in a dark room with a complex system of springs and masses (the state x). You push it with an input u. The springs respond according to their physics (matrix A). You can only observe the system through a small window (matrix C). The state captures everything happening inside, but you only see a projection.
In practice, a sequence model has D input channels (e.g., D=768 for a typical model dimension). Each channel gets its own independent SSM with its own A, B, C matrices. So a single SSM layer is actually D parallel SSMs, each with hidden state size N. The total hidden state is [D, N].
Picture a 4-dimensional hidden state responding to an input pulse: each component of x(t) integrates the input while the pulse is on and decays after it ends. That integrate-and-decay behavior is how the state compresses history.
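To make the [D, N] shapes and the integrate-then-decay behavior concrete, here is a toy simulation. It uses forward Euler with a small step as a stand-in integrator (proper discretization comes next) and assumes a diagonal A per channel; all values are made up for illustration.

```python
import numpy as np

D, N = 2, 4                # 2 channels, 4-dim state per channel (toy sizes)
dt, steps = 0.01, 500      # Euler step size and count: simulates t in [0, 5)

# Assumption: each channel has a diagonal A with negative entries, stored as a
# [D, N] array, so every state component decays on its own. The text allows a
# full N x N matrix per channel.
A = -np.array([[0.5, 1.0, 2.0, 4.0],
               [0.5, 1.0, 2.0, 4.0]])
B = np.ones((D, N))
C = np.full((D, N), 0.25)

x = np.zeros((D, N))       # total hidden state: shape [D, N]
history = []
for i in range(steps):
    u = np.ones(D) if i * dt < 0.5 else np.zeros(D)   # unit pulse for t < 0.5
    # Forward-Euler step of x'(t) = A x(t) + B u(t), all channels at once
    x = x + dt * (A * x + B * u[:, None])
    history.append(x[0].copy())                       # record channel 0

history = np.array(history)   # [steps, N]
y = (C * x).sum(axis=1)       # y(t) = C x(t) at the final time, one value per channel
peak, final = history.max(axis=0), history[-1]
# Every component rises during the pulse and decays afterwards; the component
# with the most negative A entry (-4.0) forgets fastest.
```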
The continuous SSM is beautiful in theory, but computers work with discrete sequences. We need to convert $x'(t) = Ax(t) + Bu(t)$ into a discrete recurrence $x_k = \bar{A}x_{k-1} + \bar{B}u_k$, where $\bar{A}$ and $\bar{B}$ are the discrete-time counterparts of A and B. This process is called discretization.
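The simplest approach, known as zero-order hold (ZOH), assumes the input u(t) is constant between sample points: over the interval from $t_{k-1}$ to $t_k = t_{k-1} + \Delta$, the input is held at $u_k$. Under this assumption, the continuous ODE has an exact solution:
$$x(t_k) = e^{A\Delta}x(t_{k-1}) + \left(\int_0^{\Delta} e^{A\tau}\,d\tau\right)Bu_k = e^{A\Delta}x(t_{k-1}) + (e^{A\Delta} - I)A^{-1}Bu_k$$
So $\bar{A} = e^{A\Delta}$ and $\bar{B} = (e^{A\Delta} - I)A^{-1}B$ (for invertible A). The step size Δ controls the temporal resolution: small Δ means fine-grained sampling, large Δ means coarse sampling.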
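For a diagonal A (an assumption here; the ZOH formulas hold for any invertible A), the matrix exponential and inverse act elementwise, so ZOH reduces to a few vector operations. `zoh_diagonal` is a hypothetical helper name.

```python
import numpy as np

def zoh_diagonal(a, b, delta):
    """Zero-order-hold discretization for a diagonal A.

    a: diagonal entries of A (all nonzero), b: entries of B, delta: step size.
        A_bar = exp(delta * a)
        B_bar = (exp(delta * a) - 1) / a * b
    """
    a_bar = np.exp(delta * a)
    b_bar = (a_bar - 1.0) / a * b
    return a_bar, b_bar

# Scalar system A = -1, B = 1 sampled with delta = 0.5
a_bar, b_bar = zoh_diagonal(np.array([-1.0]), np.array([1.0]), 0.5)
# a_bar = exp(-0.5) ≈ 0.6065, b_bar = 1 - exp(-0.5) ≈ 0.3935
```

Starting from x = 0 with a constant input of 1, the exact continuous solution after Δ = 0.5 is 1 − e^{−0.5}, which is exactly `b_bar`: ZOH is exact whenever the input really is piecewise constant.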
Computing matrix exponentials is expensive. The bilinear transform (also called Tustin's method or the trapezoidal rule) gives us an algebraic approximation that is both cheap and numerically stable.
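The idea: approximate the derivative x'(t) using the trapezoidal rule, averaging the slope at the start and end of each interval:
$$\frac{x_k - x_{k-1}}{\Delta} \approx \frac{x'(t_{k-1}) + x'(t_k)}{2}$$
Substituting $x' = Ax + Bu$ (with the input held at $u_k$ over the interval) and solving for $x_k$:
$$x_k = \bar{A}x_{k-1} + \bar{B}u_k, \qquad \bar{A} = \left(I - \tfrac{\Delta}{2}A\right)^{-1}\left(I + \tfrac{\Delta}{2}A\right), \qquad \bar{B} = \left(I - \tfrac{\Delta}{2}A\right)^{-1}\Delta B$$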
Let A = −1, B = 1, Δ = 0.5. Then:
```text
# Bilinear transform (scalar case)
A = -1, B = 1, delta = 0.5

A_bar = (1 - delta*A/2)**(-1) * (1 + delta*A/2)
      = (1 + 0.25)**(-1) * (1 - 0.25)
      = (0.8) * (0.75)
      = 0.6

B_bar = (1 - delta*A/2)**(-1) * delta * B
      = 0.8 * 0.5 * 1
      = 0.4

# Recurrence: x_k = 0.6 * x_{k-1} + 0.4 * u_k
# |A_bar| = 0.6 < 1, so the system is stable
```
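The same bilinear formulas work for a full N×N matrix A. In this sketch, `bilinear` is a hypothetical helper; it solves linear systems with `np.linalg.solve` instead of forming the inverse explicitly, and reproduces the scalar numbers above when called with 1×1 matrices.

```python
import numpy as np

def bilinear(A, B, delta):
    """Bilinear (Tustin) discretization of x'(t) = A x(t) + B u(t):
        A_bar = (I - delta/2 A)^{-1} (I + delta/2 A)
        B_bar = (I - delta/2 A)^{-1} delta B
    """
    I = np.eye(A.shape[0])
    left = I - (delta / 2.0) * A
    A_bar = np.linalg.solve(left, I + (delta / 2.0) * A)
    B_bar = np.linalg.solve(left, delta * B)
    return A_bar, B_bar

# The scalar example, written as 1x1 matrices
A_bar, B_bar = bilinear(np.array([[-1.0]]), np.array([[1.0]]), 0.5)
# A_bar ≈ [[0.6]], B_bar ≈ [[0.4]]

# Run x_k = A_bar x_{k-1} + B_bar u_k with a constant input u_k = 1
x = np.zeros((1, 1))
for _ in range(50):
    x = A_bar @ x + B_bar * 1.0
# x approaches B_bar / (1 - A_bar) = 0.4 / 0.4 = 1.0, which matches the
# continuous steady state x = -B/A = 1
```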
Here's a deceptively hard question: how should you initialize the matrix A? Random initialization won't work — the hidden state would either explode or forget everything instantly. We need A to do something principled: compress the input history into the state in an optimal way.
This is the problem that the HiPPO framework (High-Order Polynomial Projection Operators) solves. The insight is beautiful: choose A so that the hidden state x(t) stores the coefficients of an optimal polynomial approximation of the input history.
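The most important instance is HiPPO-LegS (Legendre polynomials with a scaled measure), which approximates the entire history from time 0 to t with uniform weight; the sliding-window variant is HiPPO-LegT. The resulting matrix has a specific structure:
$$A_{nk} = -\begin{cases} \sqrt{2n+1}\,\sqrt{2k+1} & \text{if } n > k \\ n+1 & \text{if } n = k \\ 0 & \text{if } n < k \end{cases}$$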
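This produces a lower-triangular matrix whose entries grow in a structured way. The key properties: because it is lower-triangular, its eigenvalues are its diagonal entries −1 through −N, all negative, so the continuous-time system is stable; and the off-diagonal entries grow like $\sqrt{(2n+1)(2k+1)}$, coupling each state component to every lower-index one.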
For N=4, the HiPPO-LegS matrix is:
```python
import numpy as np

def hippo_legs(N):
    """HiPPO-LegS state matrix A (N x N, lower-triangular)."""
    A = np.zeros((N, N))
    for n in range(N):
        for k in range(N):
            if n > k:
                A[n, k] = -np.sqrt(2*n + 1) * np.sqrt(2*k + 1)
            elif n == k:
                A[n, k] = -(n + 1)
    return A

A = hippo_legs(4)
# A = [[-1.0,   0.0,   0.0,   0.0],
#      [-1.73, -2.0,   0.0,   0.0],
#      [-2.24, -3.87, -3.0,   0.0],
#      [-2.65, -4.58, -5.92, -4.0]]
```
The 4 state components of a HiPPO-initialized SSM responding to a step input. Lower-index components (warm) react fast and decay fast. Higher-index components (purple) react slowly but retain long-range information.
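A small numerical check of the HiPPO-LegS matrix, under two assumptions not stated in this text: the input vector is B_n = √(2n+1), as in the HiPPO and S4 papers, and the matrix is discretized with the bilinear transform at Δ = 0.1. Driving it with a constant input, the state settles at [1, 0, 0, 0], consistent with the polynomial-projection view: a constant history is a degree-0 polynomial, so only the first Legendre coefficient is nonzero.

```python
import numpy as np

def hippo_legs(N):
    """HiPPO-LegS state matrix, same convention as the snippet above."""
    A = np.zeros((N, N))
    for n in range(N):
        for k in range(N):
            if n > k:
                A[n, k] = -np.sqrt(2 * n + 1) * np.sqrt(2 * k + 1)
            elif n == k:
                A[n, k] = -(n + 1)
    return A

N, delta = 4, 0.1
A = hippo_legs(N)
B = np.sqrt(2 * np.arange(N) + 1.0)   # assumed LegS input vector: B_n = sqrt(2n + 1)

# Bilinear discretization
I = np.eye(N)
left = I - (delta / 2) * A
A_bar = np.linalg.solve(left, I + (delta / 2) * A)
B_bar = np.linalg.solve(left, delta * B)

x = np.zeros(N)
for _ in range(300):                  # constant input: u_k = 1 at every step
    x = A_bar @ x + B_bar * 1.0
# x converges to approximately [1, 0, 0, 0]
```

The fixed point can also be checked by hand: the bilinear recurrence has the same steady state as the continuous system, Ax = −B, and substituting x = [1, 0, 0, 0] satisfies every row of that equation.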
We now have all the ingredients: a continuous SSM (Ch 2), discretization (Ch 3), and HiPPO initialization (Ch 4). The S4 architecture (Gu et al., 2022) combines them with one crucial trick: it shows that the same SSM can be computed as either a recurrence or a convolution, and you can switch between them.
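After discretization, our SSM is a linear recurrence:
$$x_k = \bar{A}x_{k-1} + \bar{B}u_k, \qquad y_k = Cx_k$$
Let's unroll this recurrence starting from a zero initial state, $x_{-1} = 0$:
$$x_0 = \bar{B}u_0, \qquad x_1 = \bar{A}\bar{B}u_0 + \bar{B}u_1, \qquad x_2 = \bar{A}^2\bar{B}u_0 + \bar{A}\bar{B}u_1 + \bar{B}u_2$$
$$y_k = C\bar{A}^k\bar{B}u_0 + C\bar{A}^{k-1}\bar{B}u_1 + \cdots + C\bar{B}u_k$$
See the pattern? Each output $y_k$ is a weighted sum of all past inputs, where input $u_{k-j}$ gets weight $C\bar{A}^j\bar{B}$. That's a convolution! For a sequence of length L, the kernel is:
$$\bar{K} = (C\bar{B},\ C\bar{A}\bar{B},\ \ldots,\ C\bar{A}^{L-1}\bar{B})$$
And the output is simply $y = \bar{K} * u$ (convolution).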
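To check the duality numerically, this sketch builds the kernel K̄ by repeated matrix-vector products (the naive O(LN²) method) and compares a causal convolution against the step-by-step recurrence. The system matrices are random toy values, and the helper names are hypothetical.

```python
import numpy as np

def ssm_kernel(A_bar, B_bar, C, L):
    """K_bar = (C B_bar, C A_bar B_bar, ..., C A_bar^{L-1} B_bar), computed naively."""
    K = np.zeros(L)
    v = B_bar.copy()              # v holds A_bar^j B_bar
    for j in range(L):
        K[j] = C @ v
        v = A_bar @ v
    return K

def run_recurrent(A_bar, B_bar, C, u):
    """x_k = A_bar x_{k-1} + B_bar u_k, y_k = C x_k, starting from a zero state."""
    x = np.zeros(A_bar.shape[0])
    ys = []
    for u_k in u:
        x = A_bar @ x + B_bar * u_k
        ys.append(C @ x)
    return np.array(ys)

rng = np.random.default_rng(0)
N, L = 4, 16
A_bar = 0.5 * np.eye(N) + 0.1 * rng.standard_normal((N, N))   # arbitrary toy dynamics
B_bar = rng.standard_normal(N)
C = rng.standard_normal(N)
u = rng.standard_normal(L)

K = ssm_kernel(A_bar, B_bar, C, L)
y_conv = np.convolve(u, K)[:L]      # causal convolution y = K_bar * u
y_rec = run_recurrent(A_bar, B_bar, C, u)
print(np.allclose(y_conv, y_rec))   # True
```

Training can use the convolutional form (all outputs at once), while inference can use the recurrent form (constant memory per step); both produce the same outputs.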
Naive convolution of two length-N sequences is O(N²). But the convolution theorem states that convolution in the time domain equals pointwise multiplication in the frequency domain:
$$\bar{K} * u = \mathcal{F}^{-1}\big(\mathcal{F}(\bar{K}) \cdot \mathcal{F}(u)\big)$$
The Fast Fourier Transform (FFT) computes a length-N discrete Fourier transform in O(N log N). So the full pipeline is: FFT the kernel (O(N log N)), FFT the input (O(N log N)), multiply pointwise (O(N)), and inverse FFT (O(N log N)). Total: O(N log N) instead of O(N²).
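A sketch of the FFT pipeline using NumPy's real-input FFT routines. One detail the summary skips: the FFT computes a circular convolution, so both sequences are zero-padded to length 2L (L being the sequence length, the N of the paragraph above) to recover the ordinary causal result. The decaying kernel is a toy stand-in for an SSM kernel.

```python
import numpy as np

def fft_causal_conv(u, K):
    """Causal convolution via the convolution theorem.

    Zero-padding to 2L makes the FFT's circular convolution equal the
    ordinary (linear) one on the first L outputs.
    """
    L = len(u)
    n = 2 * L
    U = np.fft.rfft(u, n=n)
    Kf = np.fft.rfft(K, n=n)
    y = np.fft.irfft(U * Kf, n=n)   # pointwise product in the frequency domain
    return y[:L]

rng = np.random.default_rng(1)
L = 1024
u = rng.standard_normal(L)
K = 0.99 ** np.arange(L)            # toy decaying kernel
y_fft = fft_causal_conv(u, K)
y_direct = np.convolve(u, K)[:L]    # O(L^2) reference
print(np.allclose(y_fft, y_direct))  # True
```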
There's one more obstacle: computing the kernel $\bar{K} = (C\bar{B}, C\bar{A}\bar{B}, \ldots, C\bar{A}^{L-1}\bar{B})$ requires L matrix-vector products, which is O(LN²) for a dense N×N matrix A (here L is the sequence length and N the state size). S4's key structural insight: decompose A into a diagonal matrix plus a low-rank correction (DPLR). With this structure, S4 evaluates the kernel in the frequency domain, using the Woodbury identity to reduce the work to Cauchy kernel computations that cost roughly O(N + L) up to log factors. Later work (S4D) showed that simply using a diagonal A works almost as well, making the kernel computation a straightforward O(NL).
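With a diagonal Ā, each kernel entry is $\sum_n c_n \bar{a}_n^{\,j} \bar{b}_n$, so the whole kernel is a Vandermonde matrix of powers times a vector, an O(NL) computation. This sketch uses real-valued diagonal entries for simplicity (S4D itself works with complex-valued diagonals) and checks the result against the dense formula; the function name is hypothetical.

```python
import numpy as np

def diag_ssm_kernel(a_bar, b_bar, c, L):
    """Kernel of an SSM whose A_bar is diagonal (given as the vector a_bar):
        K[j] = sum_n c[n] * a_bar[n]**j * b_bar[n]
    """
    powers = a_bar[None, :] ** np.arange(L)[:, None]   # [L, N] Vandermonde matrix
    return powers @ (c * b_bar)                         # [L], O(N L) work

# Compare with the dense formula C A_bar^j B_bar, using A_bar = diag(a_bar)
rng = np.random.default_rng(2)
N, L = 8, 32
a_bar = rng.uniform(0.1, 0.95, N)
b_bar = rng.standard_normal(N)
c = rng.standard_normal(N)
dense = np.array([c @ np.linalg.matrix_power(np.diag(a_bar), j) @ b_bar
                  for j in range(L)])
print(np.allclose(diag_ssm_kernel(a_bar, b_bar, c, L), dense))  # True
```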
S4 was a breakthrough for long-range tasks like audio and signal processing. But when researchers tried it on language modeling, something frustrating happened: it consistently underperformed Transformers. The perplexity gap was significant — about 1-3 points worse than attention-based models of the same size.
Why? The answer lies in a task called recall.
Consider this sequence: "A 3 B 7 C 2 ... What is A?" To answer, the model needs to look back, find "A 3", and return "3". This requires the mixing matrix (how tokens influence each other) to be input-dependent. The model needs to route "A?" back to wherever "A 3" appeared.
In a Transformer, the attention matrix is input-dependent — queries and keys are computed from the input, so the model can learn to attend to the right location. In an SSM, the matrices A, B, C are fixed — the same mixing pattern applies regardless of input content. The SSM literally cannot vary its routing based on what it sees.
Mamba (Gu & Dao, 2023) keeps the linear recurrence structure but makes three of its parameters (B, C, and the step size Δ) functions of the input:
| Parameter | S4 (LTI) | Mamba (Selective) | Effect |
|---|---|---|---|
| A | Fixed (learned) | Fixed (learned) | Base dynamics unchanged |
| B | Fixed (learned) | Input-dependent | Model chooses what enters the state |
| C | Fixed (learned) | Input-dependent | Model chooses what to read from state |
| Δ | Fixed (learned) | Input-dependent | Model chooses its own timescale per token |
The step size Δ is particularly powerful. Remember: Δ controls how much of the continuous dynamics to apply. A large Δ means the state changes a lot — the model is "paying attention" to this token, absorbing it into memory. A small Δ means the state barely changes — the model is "ignoring" this token, letting it pass through.
This is exactly the selection mechanism: when the model sees an important token (like "A 3" in a recall task), it can increase Δ to write it into the state. When it sees filler tokens, it can decrease Δ to preserve the existing state.
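A toy sketch of the selection mechanism, deliberately simplified: only Δ depends on the input (B and C stay fixed here), a hand-set `importance` score stands in for the learned projection that would compute Δ from each token, and A is diagonal and discretized with zero-order hold. All names and values are hypothetical.

```python
import numpy as np

def softplus(z):
    """Smooth, always-positive map used here to turn a score into a step size."""
    return np.log1p(np.exp(z))

def selective_ssm(u, importance, a, b, c, w=8.0, bias=-6.0):
    """Toy single-channel SSM in which only the step size depends on the input.

    delta_k = softplus(w * importance_k + bias): near 0 for filler tokens,
    large for important ones. A is diagonal (vector a, negative entries).
    """
    x = np.zeros_like(a)
    ys, deltas = [], []
    for u_k, s_k in zip(u, importance):
        delta = softplus(w * s_k + bias)
        a_bar = np.exp(delta * a)        # small delta -> a_bar near 1: keep the state
        b_bar = (a_bar - 1.0) / a * b    # small delta -> b_bar near 0: ignore the input
        x = a_bar * x + b_bar * u_k
        ys.append(c @ x)
        deltas.append(delta)
    return np.array(ys), np.array(deltas)

a = -np.ones(4)                # toy diagonal A
b = np.ones(4)
c = np.full(4, 0.25)

# A "key" token carrying the value 3.0, then 20 filler tokens with random values
rng = np.random.default_rng(0)
u = np.concatenate([[3.0], rng.standard_normal(20)])
importance = np.concatenate([[1.0], np.zeros(20)])

ys, deltas = selective_ssm(u, importance, a, b, c)
# deltas[0] is about 2.13 (key written in), deltas[1:] about 0.0025 (filler skipped);
# ys[-1] stays close to ys[0], which is about 2.64

# Same model with a fixed step size (w=0, bias=2.0 gives delta ~2.13 for every
# token): the key is overwritten by the filler that follows it.
ys_fixed, _ = selective_ssm(u, importance, a, b, c, w=0.0, bias=2.0)
```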
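Because B, C, and Δ now change from token to token, the system is no longer time-invariant: there is no single kernel to convolve with, so the FFT trick no longer applies. How, then, does Mamba train efficiently? It uses a parallel scan (also called a prefix sum). The key insight: a linear recurrence $x_k = a_k x_{k-1} + b_k$ can be computed in O(log N) parallel steps using a tree reduction, even though it looks sequential.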
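The scan idea can be sketched for a scalar state. Each step is the affine map $x \mapsto a_k x + b_k$, and composing two such maps gives another one; that composition is associative, so prefixes can be combined in a tree of $\lceil \log_2 L \rceil$ rounds. This is the simple Hillis-Steele scan, which does O(L log L) total work; work-efficient variants such as Blelloch's reduce that to O(L). A selective SSM with a diagonal $\bar{A}_k$ would run one such scan per state dimension. Function names are hypothetical.

```python
import numpy as np

def sequential_scan(a, b):
    """Reference: x_k = a_k x_{k-1} + b_k with x_{-1} = 0, one step at a time."""
    x, out = 0.0, []
    for a_k, b_k in zip(a, b):
        x = a_k * x + b_k
        out.append(x)
    return np.array(out)

def parallel_scan(a, b):
    """Same recurrence in ceil(log2 L) rounds (Hillis-Steele inclusive scan).

    Applying (a1, b1) and then (a2, b2) equals the single map
    (a1 * a2, a2 * b1 + b2). Each round is one vectorized operation over
    all positions, i.e. it could run fully in parallel.
    """
    a = np.asarray(a, dtype=float)
    b = np.asarray(b, dtype=float)
    L, shift = len(a), 1
    while shift < L:
        # Prefix ending at k - shift; positions k < shift get the identity map (1, 0)
        a_prev = np.concatenate([np.ones(shift), a[:-shift]])
        b_prev = np.concatenate([np.zeros(shift), b[:-shift]])
        a, b = a_prev * a, a * b_prev + b     # compose: prefix first, then current
        shift *= 2
    return b   # with x_{-1} = 0, each prefix's offset is exactly x_k

rng = np.random.default_rng(0)
a = rng.uniform(0.5, 1.0, 1000)   # e.g. per-step decay factors
b = rng.standard_normal(1000)     # e.g. per-step inputs
print(np.allclose(parallel_scan(a, b), sequential_scan(a, b)))   # True
```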
Mamba's implementation is hardware-aware: it fuses the discretization, scan, and multiplication into a single CUDA kernel that minimizes HBM reads and writes, the same philosophy as FlashAttention applied to recurrences. The result, as reported in the Mamba paper: roughly 5x higher inference throughput than a Transformer of similar size.
We now have a rich landscape of sequence architectures. Each makes different tradeoffs between quality, training efficiency, inference efficiency, and the ability to handle long-range dependencies. Let's put them side by side.
| Architecture | Training Cost | Inference (per token) | Long-Range | Recall | Language Quality |
|---|---|---|---|---|---|
| Transformer | O(N²D) | O(ND) + KV cache | Full (N²) | Excellent | Best |
| S4 (LTI SSM) | O(N log N · D) | O(D) | Excellent | Poor | Gap vs Transformer |
| Mamba | O(ND) | O(D) | Excellent | Good | Near-Transformer |
| Hybrids (Jamba, Zamba) | Mixed | Mixed | Good | Excellent | Best of both |
At 360M parameters trained on 10B tokens of the Pile, the perplexity comparison looks like this:
| Model | Type | Perplexity |
|---|---|---|
| Llama-style Transformer | Attention | 8.39 |
| GPT-2-style Transformer | Attention | 8.97 |
| LTI SSM | SSM (no gating) | 13.13 |
| H3 (Gated SSM) | SSM + gating | 10.60 |
| Hyena (Gated SSM) | SSM + gating | 10.11 |
| RWKV-V5 | SSM + gating | 9.79 |
The gap is narrowing, but attention-based models still lead on language. The story is different on audio (S4 beats Transformers) and very long sequences (where Transformers simply can't run).
We've traveled from the quadratic bottleneck of attention, through the primitives of sequence modeling, into the elegant mathematics of state space models, and emerged with architectures that process sequences in linear time. Here's the cheat sheet.
| Concept | One-Line Summary |
|---|---|
| State Space Model | x'(t) = Ax(t) + Bu(t); compress history into fixed-size state |
| Discretization | Convert continuous ODE to discrete recurrence via bilinear transform |
| HiPPO | Initialize A to store optimal polynomial approximation of history |
| Conv-Recurrence Duality | Unrolled linear recurrence = convolution; FFT makes it O(N log N) |
| S4 | HiPPO + DPLR structure + dual mode (conv for training, recurrence for inference) |
| Recall Problem | LTI models can't do content-based lookup; need input-dependent mixing |
| Mamba | Input-dependent B, C, Δ; hardware-aware parallel scan; selective forgetting |
| Hybrids | Mix SSM + attention layers for the best quality-efficiency tradeoff |
The field is moving fast. Mamba-2 introduces the State Space Duality framework (connecting SSMs to structured attention). Griffin and Hawk explore gated linear recurrences. Based and Gated Linear Attention close the gap from the attention side. The boundary between "SSM" and "attention" is blurring.