Transformers have a quadratic bottleneck. State space models process sequences in linear time — and Mamba makes them input-dependent. Here's the full story.
Attention is powerful but expensive. For a sequence of length n, self-attention computes a score between every pair of tokens: that's n² operations. At 1,000 tokens, that's 1 million pairs. At 100,000 tokens, it's 10 billion. This quadratic cost is the fundamental bottleneck of Transformers.
What if we could process sequences without ever computing pairwise interactions? What if each token were processed in constant time, regardless of sequence length? This is the promise of state space models (SSMs), a family of architectures borrowed from control theory that compress the entire history into a fixed-size hidden state.
The simplest alternative to attention is a linear recurrence. At each step, we maintain a hidden state h and update it:

$$h_t = A\,h_{t-1} + B\,x_t, \qquad y_t = C\,h_t$$

The matrix A controls how the previous state decays and mixes (it's the "memory" operator). The matrix B controls how much of the new input is written into the state. The output $y_t = C h_t$ reads a projection of the state. Processing each token costs constant time: one matrix-vector product for A·h, one for B·x, one for C·h. No matter how long the sequence, each new token costs the same.
If the state has dimension N and the input has dimension D, the per-step cost is O(N² + ND) for the matrix multiplications. Compare this to attention's O(L · D) per token at generation (where L is the sequence length so far). For long sequences, the SSM wins because N is fixed (typically 16–64), while L keeps growing.
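To make the constant per-token cost concrete, here is a minimal NumPy sketch of the recurrence. It assumes a scalar input $x_t$ and uses placeholder values for A, B, C (a simple decaying A and random B, C, not learned weights). However many tokens go in, the state h keeps the same shape.

```python
import numpy as np

def linear_recurrence(A, B, C, xs):
    """Run h_t = A h_{t-1} + B x_t, y_t = C h_t over a sequence of scalar inputs."""
    h = np.zeros(A.shape[0])   # fixed-size hidden state, h_{-1} = 0
    ys = []
    for x_t in xs:
        h = A @ h + B * x_t    # O(N^2) for A @ h, O(N) for the input write
        ys.append(C @ h)       # O(N) readout
    return np.array(ys), h

rng = np.random.default_rng(0)
N = 16
A = 0.9 * np.eye(N)            # placeholder "memory" operator: uniform decay
B = rng.standard_normal(N)
C = rng.standard_normal(N)
xs = rng.standard_normal(1000)
ys, h_final = linear_recurrence(A, B, C, xs)
print(ys.shape, h_final.shape)  # (1000,) (16,)
```

The loop body does the same amount of work at token 1 and token 1000; only the number of iterations grows with the sequence.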
| Approach | Cost/Token | Total (n tokens) | Can Select? |
|---|---|---|---|
| Self-Attention | O(n) | O(n²) | Yes (dynamic) |
| Linear Recurrence | O(1) | O(n) | No (fixed A, B) |
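A state space model (SSM) comes from control theory. It starts in continuous time:

$$h'(t) = A\,h(t) + B\,x(t), \qquad y(t) = C\,h(t)$$

This is a differential equation: h'(t) is the derivative of the hidden state. To use it on discrete sequences (like text), we discretize it using a step size Δ. The most common method (zero-order hold) gives:

$$\bar{A} = \exp(\Delta A), \qquad \bar{B} = (\Delta A)^{-1}\left(\exp(\Delta A) - I\right)\cdot \Delta B$$

$$h_k = \bar{A}\,h_{k-1} + \bar{B}\,x_k, \qquad y_k = C\,h_k$$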
The recurrence $h_k = \bar{A} h_{k-1} + \bar{B} x_k$ processes each token in constant time: a fixed number of matrix-vector products, regardless of how long the sequence is. For a sequence of length L, the total cost is O(L). Attention, by contrast, computes pairwise scores between all L tokens: O(L²). The difference is enormous at scale.
Unrolling the recurrence reveals something beautiful. Substituting repeatedly (with $h_{-1} = 0$):

$$y_k = C\bar{A}^{k}\bar{B}\,x_0 + C\bar{A}^{k-1}\bar{B}\,x_1 + \dots + C\bar{A}\bar{B}\,x_{k-1} + C\bar{B}\,x_k = \sum_{i=0}^{k} C\bar{A}^{i}\bar{B}\,x_{k-i}$$

This is a convolution: $y = K * x$ where the kernel $K_i = C\bar{A}^i\bar{B}$. When A, B, C (and Δ) are fixed, this kernel is precomputable, and you convolve it with the input using the FFT in O(L log L). Training uses this parallel path; inference uses the sequential recurrence for constant memory.
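The equivalence of the two views is easy to check numerically. This sketch assumes a single-input, single-output SSM whose $\bar{A}$, $\bar{B}$, C are already discretized (random values chosen for illustration). It builds $K_i = C\bar{A}^i\bar{B}$ by repeated multiplication and convolves with `numpy.fft`, zero-padding to length 2L so the FFT's circular convolution matches the causal linear one.

```python
import numpy as np

def ssm_recurrent(A_bar, B_bar, C, x):
    """Sequential mode: h_k = A_bar h_{k-1} + B_bar x_k, y_k = C h_k, with h_{-1} = 0."""
    h = np.zeros(A_bar.shape[0])
    ys = []
    for x_k in x:
        h = A_bar @ h + B_bar * x_k
        ys.append(C @ h)
    return np.array(ys)

def ssm_kernel(A_bar, B_bar, C, L):
    """K_i = C A_bar^i B_bar for i = 0..L-1."""
    K = np.empty(L)
    v = B_bar.copy()
    for i in range(L):
        K[i] = C @ v
        v = A_bar @ v
    return K

def ssm_convolutional(A_bar, B_bar, C, x):
    """Parallel mode: causal convolution y = K * x via FFT."""
    L = len(x)
    K = ssm_kernel(A_bar, B_bar, C, L)
    n = 2 * L  # zero-pad so circular convolution equals linear convolution
    y = np.fft.irfft(np.fft.rfft(x, n) * np.fft.rfft(K, n), n)
    return y[:L]

rng = np.random.default_rng(0)
N, L = 8, 64
A_bar = 0.8 * np.eye(N) + 0.02 * rng.standard_normal((N, N))
B_bar = rng.standard_normal(N)
C = rng.standard_normal(N)
x = rng.standard_normal(L)
print(np.allclose(ssm_recurrent(A_bar, B_bar, C, x), ssm_convolutional(A_bar, B_bar, C, x)))  # True
```

Building the kernel this naive way still costs O(L·N²); making that step cheap is exactly what S4's structured A is for.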
Given the continuous SSM h'(t) = A h(t) + B x(t), assume that x(t) is constant over each interval [kΔ, (k+1)Δ] (the zero-order hold assumption).
Your task: Derive the discrete matrices $\bar{A} = \exp(\Delta A)$ and $\bar{B} = (\exp(\Delta A) - I)\,A^{-1}B$ from first principles. Why does $\bar{B}$ have that form and not simply ΔB?
Full derivation:
Start with h'(t) = Ah(t) + Bx(t). Over the interval [kΔ, (k+1)Δ], assume $x(t) = x_k$ (constant). The general solution at time (k+1)Δ given $h(k\Delta) = h_k$ is:

$$h_{k+1} = e^{A\Delta} h_k + \left[\int_0^{\Delta} e^{A\tau}\,d\tau\right] B\,x_k$$

Evaluating the integral (for invertible A): $\int_0^{\Delta} e^{A\tau}\,d\tau = A^{-1}\left(e^{A\Delta} - I\right)$

Therefore: $\bar{A} = \exp(\Delta A)$ and $\bar{B} = A^{-1}\left(\exp(\Delta A) - I\right)B = (\Delta A)^{-1}(\bar{A} - I)\cdot\Delta B$

The key insight: $\bar{B} \ne \Delta B$ because the input isn't applied instantaneously; it's integrated over the full time step while the state is simultaneously evolving under A. The factor $A^{-1}(e^{A\Delta} - I)$ captures this interaction. For small Δ it approaches $\Delta I$ (recovering the Euler approximation $\bar{B} \approx \Delta B$), but for larger Δ the correction matters significantly.
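A quick numerical check of the derivation, as a sketch using `scipy.linalg.expm` and a small 2×2 example chosen for illustration. It compares the closed form against the standard block-matrix trick, where the exponential of [[A, B], [0, 0]]·Δ contains $\int_0^\Delta e^{A\tau}d\tau\,B$ in its top-right block (this also works when A is singular), and confirms the small-Δ Euler limit.

```python
import numpy as np
from scipy.linalg import expm

def discretize_zoh(A, B, dt):
    """ZOH: A_bar = exp(dt A), B_bar = A^{-1}(exp(dt A) - I) B (needs invertible A)."""
    A_bar = expm(dt * A)
    B_bar = np.linalg.solve(A, (A_bar - np.eye(A.shape[0])) @ B)
    return A_bar, B_bar

def discretize_zoh_augmented(A, B, dt):
    """Same quantities from exp(dt * [[A, B], [0, 0]])."""
    N, M = B.shape
    block = np.zeros((N + M, N + M))
    block[:N, :N] = A
    block[:N, N:] = B
    E = expm(dt * block)
    return E[:N, :N], E[:N, N:]

A = np.array([[-1.0, 0.5], [0.0, -2.0]])
B = np.array([[1.0], [0.5]])

A_bar, B_bar = discretize_zoh(A, B, dt=0.1)
_, B_bar_aug = discretize_zoh_augmented(A, B, dt=0.1)
print(np.allclose(B_bar, B_bar_aug))               # True

_, B_small = discretize_zoh(A, B, dt=1e-4)         # small step: ZOH ~ Euler
print(np.allclose(B_small, 1e-4 * B, rtol=1e-3))   # True

_, B_large = discretize_zoh(A, B, dt=1.0)          # large step: correction is big
print(np.allclose(B_large, 1.0 * B, rtol=0.1))     # False
```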
An SSM and a Kalman filter are both linear state space models with identical structure: a transition matrix evolves a hidden state, and an observation matrix reads it out. The Kalman filter adds noise covariance tracking (Q, R) and optimal gain computation (K). An SSM in a neural network learns its matrices from data instead of deriving them from physics. But the core operation (compress history into a fixed-size state, update linearly) is identical.
If the SSM learned Q and R matrices and computed a Kalman gain, what would that correspond to in the neural network? (Answer: an attention-like weighting mechanism over the state update.)
When A, B, C are fixed (not input-dependent), the recurrence $h_k = \bar{A}h_{k-1} + \bar{B}x_k$ can be unrolled into a convolution $y = K * x$ where $K_i = C\bar{A}^i\bar{B}$. This kernel K is precomputable because it doesn't depend on the input.
Training uses convolution mode (O(n log n) via FFT) because GPUs excel at parallel computation — processing all tokens simultaneously is vastly faster than a sequential loop.
Inference uses recurrence mode because we generate one token at a time anyway. The recurrence gives O(1) per token with O(N) state memory — no need to recompute over the entire history. The KV-cache-free nature of SSMs comes directly from this recurrence view.
This dual view breaks the moment B, C, and Δ (and therefore $\bar{A}$ and $\bar{B}$) become input-dependent, as in Mamba, because you can no longer precompute K. That's why Mamba needs the parallel scan instead of FFT convolution.
The breakthrough of S4 (Structured State Spaces for Sequence Modeling) was solving the long-range memory problem. Previous SSMs forgot quickly — information from 1000 steps ago had decayed to nearly zero. S4 fixed this with two key ideas:
1. HiPPO initialization. Instead of random initialization, the A matrix is set to the HiPPO (High-Order Polynomial Projection Operator) matrix. This special matrix is designed to optimally compress the history of a continuous signal into a fixed-size state. Each element of $h_t$ stores a coefficient of a polynomial approximation of the entire input history.
2. Structured parameterization. Computing $\bar{A} = \exp(\Delta A)$ naively for an N×N matrix costs O(N³). S4 decomposes A into diagonal plus low-rank form, reducing this to O(N). This makes training practical for state sizes of 64 or more.
S4 decomposes A into Diagonal Plus Low-Rank (DPLR) form: A = Λ − PQ*, where Λ is diagonal and P, Q are low-rank matrices. This is critical because computing exp(ΔA) for a general N×N matrix costs O(N³). With DPLR structure, it reduces to O(N). The convolution kernel K can then be computed via a Cauchy kernel formula, making the entire training step O(L log L) via FFT.
In subsequent work, S4D (diagonal state spaces) simplified this further: just use a diagonal A matrix, initialized from the HiPPO eigenvalues. This drops the low-rank components entirely, making implementation much simpler while retaining most of the performance. Mamba uses diagonal A by default.
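With a diagonal A, every matrix function in the discretization becomes elementwise, which is where the O(N) cost comes from. This sketch assumes the real-valued "S4D-Real" initialization $A_n = -(n+1)$ (our understanding is that Mamba's reference code defaults to this form) and checks the elementwise result against the dense O(N³) matrix exponential.

```python
import numpy as np
from scipy.linalg import expm

def s4d_real_init(d_state):
    """Assumed S4D-Real-style diagonal: A_n = -(n + 1) for n = 0..d_state-1."""
    return -np.arange(1, d_state + 1, dtype=float)

def discretize_diagonal(a, B, dt):
    """ZOH for diagonal A: O(N) elementwise ops instead of an O(N^3) expm."""
    a_bar = np.exp(dt * a)            # diagonal of exp(dt A)
    b_bar = (a_bar - 1.0) / a * B     # diagonal form of A^{-1}(exp(dt A) - I) B
    return a_bar, b_bar

d_state, dt = 16, 0.05
a = s4d_real_init(d_state)
B = np.ones(d_state)
a_bar, b_bar = discretize_diagonal(a, B, dt)

A_dense = np.diag(a)
print(np.allclose(expm(dt * A_dense), np.diag(a_bar)))  # True
```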
The HiPPO matrix specifically captures a polynomial projection: each element of the state $h_t$ stores a Legendre polynomial coefficient of the input history. Element $h_t[k]$ approximates the k-th Legendre coefficient of x(s) over [0, t]. This means the state literally stores a polynomial approximation of everything it has seen: a lossy but principled compression of the entire history into a fixed-size vector.
The HiPPO framework asks: "What is the optimal way to compress the history of a signal x(s) for s ∈ [0, t] into a fixed-size state vector $c(t) \in \mathbb{R}^N$?"
Your task: Show why projecting onto Legendre polynomials gives the HiPPO-LegS dynamics $c'(t) = -\frac{1}{t}A\,c(t) + \frac{1}{t}B\,x(t)$ with $A_{nk} = (2n+1)^{1/2}(2k+1)^{1/2}$ for n > k. (S4 writes the same matrix with the minus sign folded in, i.e. $A_{nk} = -(2n+1)^{1/2}(2k+1)^{1/2}$.) What makes this a "good" compression?
Full derivation:
1. Define the online polynomial approximation: $c_n(t) = \int_0^t w_n(s, t)\,x(s)\,ds$ where $w_n(s, t) = \frac{(2n+1)^{1/2}}{t}\,P_n\!\left(\frac{2s}{t} - 1\right)$ is the normalized Legendre weight rescaled to [0, t].
2. Differentiate with respect to t using the Leibniz rule and the chain rule on the rescaled argument 2s/t − 1:
$$c_n'(t) = \frac{(2n+1)^{1/2}}{t}\,x(t)\,P_n(1) - \sum_{k=0}^{n}\frac{A_{nk}}{t}\,c_k(t)$$
3. Using $P_n(1) = 1$ and the Legendre derivative identities, the coupling matrix works out to $A_{nk} = (2n+1)^{1/2}(2k+1)^{1/2}$ for n > k, $A_{nn} = n + 1$, and $A_{nk} = 0$ for n < k.
4. The system is $c'(t) = -\frac{1}{t}A\,c(t) + \frac{1}{t}B\,x(t)$ where $B_n = (2n+1)^{1/2}$.
The key insight: This specific A matrix is the linear system that keeps the state equal to the optimal polynomial approximation (least squares, weighted uniformly over [0, t]) of the entire history as new data arrives. A randomly initialized A has no such guarantee; it just forgets exponentially. HiPPO's A is structured so that information doesn't decay exponentially; it gets redistributed across polynomial modes. This is why S4 can model dependencies over thousands of steps.
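A small sketch that builds the HiPPO-LegS matrices exactly as defined in step 3 and 4 above and runs them online. The update rule is one simple choice we assume for illustration: forward Euler with t = k, i.e. $c_{k+1} = c_k - \frac{1}{k}(A c_k - B x_k)$. For a constant input of 1, the optimal polynomial fit is the constant itself, so only the degree-0 coefficient should survive.

```python
import numpy as np

def hippo_legs(N):
    """HiPPO-LegS in the convention c'(t) = -(1/t) A c(t) + (1/t) B x(t)."""
    q = np.sqrt(2 * np.arange(N) + 1.0)   # (2n+1)^{1/2}
    A = np.tril(np.outer(q, q), k=-1)     # A_nk = q_n q_k for n > k, 0 for n < k
    A += np.diag(np.arange(N) + 1.0)      # A_nn = n + 1
    B = q.copy()                          # B_n = (2n+1)^{1/2}
    return A, B

def legs_online(x, A, B):
    """Forward-Euler with t = k (an illustrative discretization choice)."""
    c = np.zeros(A.shape[0])
    for k, x_k in enumerate(x, start=1):
        c = c - (A @ c - B * x_k) / k
    return c

A, B = hippo_legs(8)
c = legs_online(np.ones(200), A, B)
print(np.round(c, 6))  # degree-0 coefficient ~1, all others ~0
```

The constant input works out exactly because the first column of A equals B, so $c = e_0$ is a fixed point of the dynamics.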
S4 has a critical limitation: A, B, C are the same for every input token. The model can't decide to "pay more attention" to important tokens or "forget" irrelevant ones. Mamba (Gu & Dao, 2023) fixes this by making B, C, and Δ functions of the input.
This is what "selective" means in "selective scan." When the model sees an important token, it can increase Δ (larger step = more input mixed in) and adjust B to write strongly to the state. For irrelevant tokens, it can shrink Δ and effectively skip them.
Think of $\Delta_t$ as a gate width. When Δ is large, $\exp(\Delta A)$ (with A negative) shrinks the old state more and $\bar{B}$ grows, so the new input floods in. When Δ is small, the old state is preserved and the input barely registers. Intuitively, the model can learn to open the gate wide for important tokens (names, keywords, punctuation) and keep it narrow for filler words.
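The gate-width intuition in numbers, as a sketch with a scalar state, A = −1 and B = 1 (values picked for illustration), using the zero-order-hold formulas from earlier. A tiny Δ keeps the old state almost untouched; a large Δ overwrites it with the new input.

```python
import math

def zoh_scalar(delta, a=-1.0, b=1.0):
    """ZOH discretization of a scalar SSM with a < 0."""
    a_bar = math.exp(delta * a)
    b_bar = (a_bar - 1.0) / a * b
    return a_bar, b_bar

h_prev, x_t = 1.0, 5.0
for delta in (0.01, 10.0):
    a_bar, b_bar = zoh_scalar(delta)
    h_t = a_bar * h_prev + b_bar * x_t
    print(f"delta={delta:>5}: a_bar={a_bar:.4f}, b_bar={b_bar:.4f}, h_t={h_t:.4f}")
```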
$B_t$ controls what is written to the state, $C_t$ controls what is read out, and $\Delta_t$ controls how much the state changes. All three are computed from the current input $x_t$ by linear projections (followed by a softplus for $\Delta_t$), making the entire SSM dynamics content-dependent. This is what S4 couldn't do.
| Model | SSM parameters | Convolution? | Content-aware? |
|---|---|---|---|
| S4 | A, B, C, Δ fixed (same for all inputs) | Yes (parallel training) | No |
| Mamba | B, C, Δ input-dependent (A fixed) | No (must use scan) | Yes |
The parameter generation is lightweight. For each token $x_t \in \mathbb{R}^{d_{inner}}$: $B_t = W_B x_t$ where $W_B$ has shape [d_state, d_inner], producing $B_t \in \mathbb{R}^{d_{state}}$. Similarly for $C_t$. For $\Delta_t$: a linear projection to $\mathbb{R}^{d_{inner}}$ followed by softplus ensures Δ > 0. The total parameter overhead for selectivity is minimal: just three small linear layers per block.
The convolution kernel $K_i = C\bar{A}^i\bar{B}$ can be precomputed ONLY because C, $\bar{A}$, $\bar{B}$ are the same for every token. The kernel depends only on the lag i (how far back we're looking), not on which token is at position i.
When $B_t = \text{Linear}(x_t)$, $C_t = \text{Linear}(x_t)$, and $\Delta_t = \text{softplus}(\text{Linear}(x_t))$, the effective "kernel" at position k looking back j steps becomes $C_k \left(\prod_{i=k-j+1}^{k} \bar{A}_i\right) \bar{B}_{k-j}$. This product depends on every input in the range, so it's no longer shift-invariant. You can't factor it into a single convolution.
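One way to see the loss of shift-invariance: write the whole sequence map as a matrix M with $y = Mx$, filled in with the formula above. A convolution corresponds to a Toeplitz M (constant along every diagonal). This sketch assumes a scalar state and random per-token values standing in for the input-dependent $\bar{A}_t$, $\bar{B}_t$, $C_t$.

```python
import numpy as np

def effective_kernel_matrix(a_bar, b_bar, c):
    """M[k, k-j] = c_k * (a_bar_{k-j+1} ... a_bar_k) * b_bar_{k-j}, so y = M @ x."""
    L = len(a_bar)
    M = np.zeros((L, L))
    for k in range(L):
        for j in range(k + 1):                          # j = lag
            decay = np.prod(a_bar[k - j + 1 : k + 1])   # empty product (j = 0) is 1
            M[k, k - j] = c[k] * decay * b_bar[k - j]
    return M

def is_toeplitz(M):
    """Constant diagonals <=> the map is a (causal) convolution."""
    return np.allclose(M[1:, 1:], M[:-1, :-1])

L = 6
rng = np.random.default_rng(0)
fixed = effective_kernel_matrix(np.full(L, 0.9), np.full(L, 0.5), np.full(L, 2.0))
varying = effective_kernel_matrix(rng.uniform(0.1, 0.99, L),
                                  rng.uniform(0.1, 1.0, L),
                                  rng.uniform(0.5, 2.0, L))
print(is_toeplitz(fixed), is_toeplitz(varying))  # True False
```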
This is the fundamental cost of selectivity: you gain content-awareness but lose the O(n log n) FFT training path. The parallel scan (next chapter) is the solution — it's not as fast as FFT, but it's O(n) work in O(log n) parallel depth, which is practical on GPUs.
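```python
import torch
import torch.nn.functional as F

def selective_ssm_step(x_t, h_prev, A_log, W_B, W_C, W_delta, bias_delta):
    """One recurrent step of a selective SSM (single sequence, no batch dim).

    Shapes: x_t [d_inner], h_prev [d_inner, d_state], A_log [d_inner, d_state],
    W_B and W_C [d_state, d_inner], W_delta [d_inner, d_inner], bias_delta [d_inner].
    """
    # 1. Generate input-dependent parameters
    B_t = W_B @ x_t                                    # [d_state]
    C_t = W_C @ x_t                                    # [d_state]
    delta_t = F.softplus(W_delta @ x_t + bias_delta)   # [d_inner], > 0

    # 2. Discretize. A is stored in log-space; negating exp() keeps every entry < 0,
    #    so A_bar = exp(delta * A) lies in (0, 1) and the state stays bounded.
    A = -torch.exp(A_log)                              # [d_inner, d_state], negative reals
    A_bar = torch.exp(delta_t.unsqueeze(1) * A)        # [d_inner, d_state]
    B_bar = delta_t.unsqueeze(1) * B_t.unsqueeze(0)    # [d_inner, d_state], simplified delta * B

    # 3. Update state
    h_t = A_bar * h_prev + B_bar * x_t.unsqueeze(1)    # [d_inner, d_state]

    # 4. Read output
    y_t = (h_t * C_t.unsqueeze(0)).sum(dim=1)          # [d_inner]
    return y_t, h_t
```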
Making B, C, Δ input-dependent means we can't use convolution for training. The naive approach — a sequential for-loop — would be unbearably slow on GPUs. Mamba uses a parallel scan algorithm that computes the full recurrence in O(log n) parallel steps instead of O(n) sequential ones.
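The key insight: a linear recurrence $h_t = a_t h_{t-1} + b_t$ can be expressed with an associative combine operator. As with addition, we can regroup the computation freely (but, unlike addition, we cannot reorder it). Instead of going left to right, we combine neighboring pairs in a binary tree, and each level halves the number of partial results.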
Mamba doesn't just use a parallel scan — it fuses the entire SSM computation (discretization, scan, output projection) into a single GPU kernel. The key bottleneck on modern GPUs isn't compute but memory bandwidth. Reading and writing intermediate tensors of shape [B, L, d_inner, d_state] to global memory (HBM) would dominate runtime.
Instead, Mamba's kernel keeps the state (d_state=16 floats per channel) entirely in SRAM (the GPU's fast on-chip scratchpad). The associative scan runs in registers, and only the final output [B, L, d_inner] is written back to HBM. This reduces memory I/O by a factor proportional to d_state, making the selective SSM nearly as fast as a simple matrix multiply.
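A back-of-the-envelope sketch of that I/O saving for one layer and one forward pass. It assumes fp16 (2 bytes per element), batch size 1, and a 2048-token sequence (illustrative values); d_inner and d_state match Mamba-2.8B.

```python
def tensor_bytes(shape, bytes_per_elem=2):
    """Size in bytes of a dense tensor with the given shape."""
    n = 1
    for dim in shape:
        n *= dim
    return n * bytes_per_elem

batch, seq_len, d_inner, d_state = 1, 2048, 5120, 16
materialized_states = tensor_bytes((batch, seq_len, d_inner, d_state))  # every h_t written to HBM
output_only = tensor_bytes((batch, seq_len, d_inner))                   # only y written back
print(materialized_states / 2**20, "MiB vs", output_only / 2**20, "MiB")  # 320.0 MiB vs 20.0 MiB
print("ratio:", materialized_states // output_only)                       # ratio: 16
```

The ratio is exactly d_state, which is the "factor proportional to d_state" mentioned above.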
The linear recurrence $h_t = a_t h_{t-1} + b_t$ can be written as a binary operation on tuples: $(a_2, b_2) \bullet (a_1, b_1) = (a_2 a_1,\; a_2 b_1 + b_2)$.
Your task: Prove this operation is associative: $((a_3, b_3) \bullet (a_2, b_2)) \bullet (a_1, b_1) = (a_3, b_3) \bullet ((a_2, b_2) \bullet (a_1, b_1))$. Then explain why associativity enables O(log n) parallel depth.
Full proof:
Define •: $(a_2, b_2) \bullet (a_1, b_1) = (a_2 a_1,\; a_2 b_1 + b_2)$
Left: $((a_3,b_3)\bullet(a_2,b_2))\bullet(a_1,b_1) = (a_3a_2,\; a_3b_2+b_3)\bullet(a_1,b_1) = (a_3a_2a_1,\; a_3a_2b_1 + a_3b_2 + b_3)$
Right: $(a_3,b_3)\bullet((a_2,b_2)\bullet(a_1,b_1)) = (a_3,b_3)\bullet(a_2a_1,\; a_2b_1+b_2) = (a_3a_2a_1,\; a_3(a_2b_1+b_2)+b_3) = (a_3a_2a_1,\; a_3a_2b_1 + a_3b_2 + b_3)$
Left = Right. QED. The operation is associative (but NOT commutative — order matters!).
Why this enables the parallel scan: computing every $h_t$, not just the last one, means computing all prefixes of $[(a_1,b_1), \dots, (a_n,b_n)]$ under •, which is a prefix scan rather than a single reduce. With associativity, we can use the Blelloch scan algorithm: an up-sweep (reduce phase) followed by a down-sweep (propagate phase), each taking $\log_2 n$ steps. Total work is O(n), parallel depth is O(log n).
The key insight: This tuple representation encodes the entire linear recurrence. The first element tracks the cumulative "decay" (how much of the initial state survives), and the second element tracks the cumulative "input contribution" (everything that's been added along the way). It's a matrix-free way to represent $h_t = \left(\prod_{i=1}^{t} a_i\right) h_0 + \text{(accumulated inputs)}$.
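A runnable sketch of a parallel prefix scan with the • operator. For brevity it uses the Hillis-Steele scheme rather than Blelloch: it has the same O(log n) depth but does O(n log n) work instead of O(n). Each `while` iteration is one round that a GPU would execute in parallel; NumPy's vectorized slicing stands in for that here, and the result is checked against the plain sequential loop.

```python
import math
import numpy as np

def combine(later, earlier):
    """(a2, b2) • (a1, b1) = (a2 a1, a2 b1 + b2): apply `earlier` first, then `later`."""
    a2, b2 = later
    a1, b1 = earlier
    return a2 * a1, a2 * b1 + b2

def sequential_scan(a, b):
    """Reference: h_t = a_t h_{t-1} + b_t with h_{-1} = 0."""
    h, out = 0.0, []
    for a_t, b_t in zip(a, b):
        h = a_t * h + b_t
        out.append(h)
    return np.array(out)

def hillis_steele_scan(a, b):
    """Inclusive scan under •; returns all h_t and the number of parallel rounds."""
    a, b = np.array(a, dtype=float), np.array(b, dtype=float)
    rounds, offset = 0, 1
    while offset < len(a):
        # Element i absorbs the segment ending at i - offset (computed from the old arrays).
        a_new, b_new = combine((a[offset:], b[offset:]), (a[:-offset], b[:-offset]))
        a[offset:], b[offset:] = a_new, b_new
        offset *= 2
        rounds += 1
    return b, rounds

rng = np.random.default_rng(0)
n = 13
a = rng.uniform(0.5, 1.0, n)
b = rng.standard_normal(n)
h_parallel, rounds = hillis_steele_scan(a, b)
print(np.allclose(h_parallel, sequential_scan(a, b)), rounds, math.ceil(math.log2(n)))  # True 4 4
```

Because • is not commutative, `combine` always takes the later segment first; swapping the arguments would silently compute a different recurrence.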
A Mamba block looks quite different from a Transformer block. It has no attention mechanism at all. Instead, it uses a combination of linear projections, 1D convolution, and the selective SSM:
For Mamba-2.8B: d_model=2560, d_state=16, d_conv=4, expand_factor=2, so inner_dim=5120. Here's the exact data flow through one block:
| Stage | Shape | Operation |
|---|---|---|
| Input | [B, L, 2560] | — |
| Linear expand | [B, L, 10240] | Project to 2 × inner_dim, split into two branches |
| Branch A: Conv1D + SiLU | [B, L, 5120] | Causal conv, kernel=4, groups=5120 (depthwise), then SiLU |
| Generate B, C, Δ | B, C: [B, L, 16]; Δ: [B, L, 5120] | Linear projections from Branch A (softplus on Δ) |
| SSM scan | [B, L, 5120] | Parallel scan with d_state=16 per channel |
| Gate merge | [B, L, 5120] | SSM output ⊙ SiLU(Branch B), elementwise |
| Output projection | [B, L, 2560] | Linear back to d_model |
The SSM state has shape [B, 5120, 16]: each of the 5120 channels maintains its own 16-dimensional state vector. During inference, this state, plus a small rolling buffer of the last d_conv − 1 = 3 inputs for the causal convolution, is all you carry forward. No KV cache needed.
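To trace the table's shapes in code, here is a deliberately simplified PyTorch module. `MambaBlockSketch` is a hypothetical name, not the reference implementation, and it makes several assumptions: separate linear layers for B, C, Δ; a plain Python loop instead of the fused parallel scan; no per-channel skip term; and tiny default dimensions. Causality of the depthwise convolution comes from padding by d_conv − 1 and keeping only the first L outputs.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class MambaBlockSketch(nn.Module):
    """Hypothetical, simplified Mamba-style block for shape tracing only."""

    def __init__(self, d_model=64, d_state=16, d_conv=4, expand=2):
        super().__init__()
        d_inner = expand * d_model
        self.in_proj = nn.Linear(d_model, 2 * d_inner)                  # -> two branches
        self.conv = nn.Conv1d(d_inner, d_inner, d_conv, groups=d_inner, padding=d_conv - 1)
        self.to_B = nn.Linear(d_inner, d_state)
        self.to_C = nn.Linear(d_inner, d_state)
        self.to_delta = nn.Linear(d_inner, d_inner)
        A = torch.arange(1, d_state + 1, dtype=torch.float32).repeat(d_inner, 1)
        self.A_log = nn.Parameter(torch.log(A))                         # A = -exp(A_log) = -(n + 1)
        self.out_proj = nn.Linear(d_inner, d_model)

    def forward(self, u):                                               # u: [B, L, d_model]
        L = u.shape[1]
        x, z = self.in_proj(u).chunk(2, dim=-1)                         # each [B, L, d_inner]
        x = self.conv(x.transpose(1, 2))[..., :L].transpose(1, 2)       # causal depthwise conv
        x = F.silu(x)
        B_t, C_t = self.to_B(x), self.to_C(x)                           # [B, L, d_state]
        delta = F.softplus(self.to_delta(x))                            # [B, L, d_inner]
        A = -torch.exp(self.A_log)                                      # [d_inner, d_state]
        A_bar = torch.exp(delta.unsqueeze(-1) * A)                      # [B, L, d_inner, d_state]
        Bx = delta.unsqueeze(-1) * B_t.unsqueeze(2) * x.unsqueeze(-1)   # [B, L, d_inner, d_state]
        h = u.new_zeros((u.shape[0], A.shape[0], A.shape[1]))           # [B, d_inner, d_state]
        ys = []
        for t in range(L):                                              # stand-in for the parallel scan
            h = A_bar[:, t] * h + Bx[:, t]
            ys.append((h * C_t[:, t].unsqueeze(1)).sum(-1))             # [B, d_inner]
        y = torch.stack(ys, dim=1) * F.silu(z)                          # gate with Branch B
        return self.out_proj(y)                                         # [B, L, d_model]

torch.manual_seed(0)
block = MambaBlockSketch()
out = block(torch.randn(2, 10, 64))
print(out.shape)  # torch.Size([2, 10, 64])
```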
Pure Mamba models are fast but sometimes struggle with tasks requiring precise retrieval of specific tokens from the past (like "what was the 7th word?"). Pure Transformers are great at retrieval but expensive. The natural solution: combine both.
| Model | Architecture | Key Idea |
|---|---|---|
| Jamba (AI21) | Mamba + Attention + MoE | Alternate Mamba and attention layers. MoE for capacity. |
| Zamba (Zyphra) | Mamba + shared Attention | One shared attention layer interleaved every few Mamba blocks. |
| Griffin (DeepMind) | Gated linear recurrence + local attention | Use recurrence for global, windowed attention for local. |
| Mamba-2 | Structured state space duality | Shows that structured SSMs and a form of masked attention are dual views of the same operation. |
Jamba (AI21, 2024) uses a concrete pattern: for every 7 Mamba layers, insert 1 attention layer. This means only ~12% of layers are attention, but that's enough for strong retrieval performance. Combined with Mixture-of-Experts (MoE) in every other layer, Jamba-1.5 fits 256K context in a single 80GB GPU, something no pure Transformer of equivalent quality can do.
Zamba (Zyphra) goes further: it uses a single shared attention layer whose weights are reused every N Mamba blocks. This keeps the parameter cost of attention to roughly one layer's worth while still providing retrieval capability where needed.
Mamba models are trained with the exact same objective as Transformers: next-token prediction with cross-entropy loss. The training data, tokenizer, and optimizer are all the same. The difference is entirely in the architecture.
During training, the entire sequence is available. The selective SSM uses the parallel scan to process all tokens simultaneously. This is analogous to how Transformers compute all attention weights at once.
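During generation, we switch to sequential recurrence. At each step:
1. Compute $B_t$, $C_t$, and $\Delta_t$ from the new token $x_t$, then discretize to get $\bar{A}_t$ and $\bar{B}_t$.
2. Update the state: $h_t = \bar{A}_t h_{t-1} + \bar{B}_t x_t$.
3. Read the output: $y_t = C_t h_t$.
4. Keep only $h_t$ for the next step; its size never changes.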
Consider a 2.8B model generating a 100K-token sequence:
| | Transformer (2.8B) | Mamba (2.8B) |
|---|---|---|
| KV cache at 1K tokens | ~160 MB | 0 (no KV cache) |
| KV cache at 100K tokens | ~16 GB | 0 |
| SSM state (all layers) | n/a | ~10 MB (fixed) |
| Time per token at 100K | Grows (must attend to all KV) | Constant |
The Mamba state is tiny: d_inner × d_state × n_layers × 2 bytes = 5120 × 16 × 64 × 2 ≈ 10 MB per sequence (fp16, using Mamba-2.8B's 64 layers). This is why Mamba shines for long-context applications like book-length analysis, genomics, and continuous audio processing.
Neither architecture is strictly better. Each has fundamental strengths and weaknesses rooted in their core mechanisms.
| Dimension | Transformer | SSM / Mamba |
|---|---|---|
| Training cost | O(n²) per layer | O(n) per layer |
| Inference memory | KV cache grows with n | Fixed-size state |
| Inference time/token | Grows with n (attend to all KV) | Constant |
| Exact retrieval | Excellent (direct access) | Weaker (compressed state) |
| Long-range context | Struggles past training length | Naturally extends |
| In-context learning | Strong (induction heads) | Emerging but less understood |
| Ecosystem maturity | Very mature (5+ years) | Rapidly growing |
| Hardware optimization | Flash Attention, GQA | Parallel scan, SRAM-aware |
One of the simplest ways to expose SSM vs Transformer differences: the copying task. Present the model with "Copy this: ABCDE. Output: " and check if it reproduces ABCDE exactly. Transformers trivially solve this with induction heads: they can attend directly to the tokens to copy. SSMs struggle because the information must pass through the compressed state bottleneck. A fixed-size state cannot store an arbitrarily long string verbatim, so as the string to copy grows, something must eventually be lost; attention can always look back at the original tokens.
This is precisely why hybrid models exist: the attention layers handle exact recall, and the SSM layers handle everything else (fluency, long-range context, style). In Jamba's design, the few attention layers disproportionately handle retrieval-like subtasks while the many Mamba layers handle the bulk of language modeling.
The memory math:
Pure Transformer (7B, 32 layers, GQA with 8 KV heads, d_head=128): KV cache = 32 × 2 × 1M × 128 × 8 × 2 bytes = ~128 GB. Doesn't fit on one A100. Even with 4-bit KV quantization, that's 32 GB for cache alone; after ~14 GB of fp16 weights, that leaves ~34 GB for activations. Marginal at best.
Pure Mamba (7B, 48 layers, d_inner=5120, d_state=16): State = 48 × 5120 × 16 × 2 bytes = ~7.5 MB. Trivially fits. But RAG retrieval from compressed state is unreliable.
What a Jamba/Zamba-style hybrid can do: use ~12% attention layers (4 out of 32). If those layers use a limited window (4096-8192 tokens), they handle precise retrieval within their window while the Mamba layers propagate long-range context. For 1M tokens with 4 attention layers at a 4096 window: KV cache = 4 × 2 × 4096 × 128 × 8 × 2 bytes = ~64 MB. Adding the state of the remaining 28 Mamba layers (28 × 5120 × 16 × 2 bytes ≈ 4.4 MB) gives ~68 MB in total. Fits trivially with room to spare.
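The memory math above can be reproduced with two small helpers (hypothetical names). They assume fp16 storage (2 bytes per element) and read "1M tokens" as 2^20 and MB/GB as binary units, which is what makes the figures come out round.

```python
def kv_cache_bytes(n_layers, n_kv_heads, d_head, n_tokens, bytes_per_elem=2):
    """Keys + values for every cached token in every attention layer."""
    return n_layers * 2 * n_tokens * d_head * n_kv_heads * bytes_per_elem

def ssm_state_bytes(n_layers, d_inner, d_state, bytes_per_elem=2):
    """Fixed recurrent state: independent of the number of tokens."""
    return n_layers * d_inner * d_state * bytes_per_elem

MiB, GiB = 2**20, 2**30
full_attention = kv_cache_bytes(32, 8, 128, n_tokens=2**20)  # 1M tokens, all 32 layers
pure_mamba = ssm_state_bytes(48, 5120, 16)
windowed_hybrid = kv_cache_bytes(4, 8, 128, n_tokens=4096)    # 4 attention layers, 4096 window
print(full_attention / GiB, pure_mamba / MiB, windowed_hybrid / MiB)  # 128.0 7.5 64.0
```

Passing `bytes_per_elem=0.5` to `kv_cache_bytes` gives the 4-bit case: 32 GiB for the full-attention cache.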
For RAG specifically: chunk the retrieved documents so the most relevant passages fall within the attention window of the final few layers. This makes exact recall far more reliable where it matters.
This is the compression vs direct access tradeoff that appears everywhere in CS. A Bloom filter (fixed size, lossy: it can report false positives) vs a sorted array (O(log n) lookup, preserves all info). A JPEG (compact, lossy) vs a bitmap (scales with resolution, lossless). In sequence modeling, a recurrent state is the "JPEG of the past," capturing the gist but losing details. Attention is the "bitmap," keeping everything at a cost that scales with size.
Where else in ML do you see this exact tradeoff between a fixed-size bottleneck and full storage? (Hint: think about VAE latents, pooling layers, and knowledge distillation.)
You now understand both the Transformer and its most promising challenger. From continuous-time state spaces to selective scans, from HiPPO to hardware-aware algorithms — this is the frontier of sequence modeling.