The Complete Beginner's Path

Understand SSMs & Mamba

Transformers have a quadratic bottleneck. State space models process sequences in linear time — and Mamba makes them input-dependent. Here's the full story.

Prerequisites: Basic linear algebra + Intuition for sequences. Familiarity with Transformers helps but isn't required.
10 Chapters · 10+ Simulations · 0 Assumed ML Knowledge

Chapter 0: Sequences Without Attention

Attention is powerful but expensive. For a sequence of length n, self-attention computes a score between every pair of tokens: that's n² operations. At 1,000 tokens, that's 1 million pairs. At 100,000 tokens, it's 10 billion. This quadratic cost is the fundamental bottleneck of Transformers.
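To make the arithmetic concrete, here is a minimal sketch that counts the pairwise scores attention computes against the per-token updates of a linear-time model. The function names are illustrative and do not come from any library.

```python
def attention_pair_count(n: int) -> int:
    """Number of query-key scores for a length-n sequence (one per ordered pair)."""
    return n * n


def linear_step_count(n: int) -> int:
    """Number of state updates for a model that touches each token once."""
    return n


for n in (1_000, 100_000):
    pairs = attention_pair_count(n)
    steps = linear_step_count(n)
    print(f"n={n:>7,}: pairs={pairs:,}  steps={steps:,}  ratio={pairs // steps:,}x")
```

The ratio between the two counts is n itself, which is why the gap widens as sequences get longer.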

What if we could process sequences without ever computing pairwise interactions? What if each token was processed in constant time, regardless of sequence length? This is the promise of state space models (SSMs) — a family of architectures borrowed from control theory that compress the entire history into a fixed-size hidden state.

Interactive: Attention Cost vs Linear Cost

Drag the slider to see how compute scales with sequence length. The gap grows fast.

The core question: Attention lets every token see every other token. That's why it works so well. But do we really need O(n²) interactions? Maybe we can compress the past into a fixed-size state and update it one token at a time — like a rolling summary.
Check: Why is attention expensive for long sequences?

Chapter 1: Linear Recurrences

The simplest alternative to attention is a linear recurrence. At each step, we maintain a hidden state h and update it:

h_t = A · h_{t−1} + B · x_t

The matrix A controls how the previous state decays and mixes (it's the "memory" operator). The matrix B controls how much of the new input is written into the state. The output y_t = C · h_t reads a projection of the state. Processing each token costs constant time: one matrix-vector product for A·h, one for B·x, one for C·h. No matter how long the sequence, each new token costs the same.

If the state has dimension N and the input has dimension D, the per-step cost is O(N² + ND) for the matrix multiplications. Compare this to attention's O(L · D) per token at generation (where L is the sequence length so far). For long sequences, the SSM wins because N is fixed (typically 16–64), while L keeps growing.
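The recurrence is short enough to write out directly. The sketch below uses a scalar state (N = 1) purely for readability, so A, B, and C become plain numbers; the function name is made up for this example.

```python
def linear_recurrence(a: float, b: float, c: float, xs: list[float]) -> list[float]:
    """Run h_t = a*h_{t-1} + b*x_t, y_t = c*h_t with a scalar state, starting from h = 0."""
    h = 0.0
    ys = []
    for x in xs:
        h = a * h + b * x   # constant work per token, independent of position
        ys.append(c * h)
    return ys


# An impulse at step 0 is written with strength b, then decays by a factor of a per step.
print(linear_recurrence(a=0.9, b=0.5, c=1.0, xs=[1.0, 0.0, 0.0, 0.0]))
```

With a matrix-valued state the loop body becomes two matrix-vector products, but the structure (one fixed-cost update per token) stays the same.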

Interactive: Linear Recurrence

Watch the hidden state h evolve as input tokens arrive. A controls decay (memory); B controls input mixing.

The problem: A and B are fixed for every input. The same transformation is applied regardless of whether the current token is important or irrelevant. An SSM with fixed parameters is like reading a book with the same level of attention on every word — you can't speed-read the boring parts or slow down for key details.
| Approach | Cost/Token | Total (n tokens) | Can Select? |
| --- | --- | --- | --- |
| Self-Attention | O(n) | O(n²) | Yes (dynamic) |
| Linear Recurrence | O(1) | O(n) | No (fixed A, B) |
Check: What is the main limitation of a linear recurrence with fixed A and B?

Chapter 2: The State Space Model

A state space model (SSM) comes from control theory. It starts in continuous time:

h'(t) = A h(t) + B x(t)      y(t) = C h(t) + D x(t)

This is a differential equation: h' is the derivative of the hidden state. To use it on discrete sequences (like text), we discretize it using a step size Δ. The most common method (zero-order hold) gives:

Ā = exp(ΔA)     B̄ = (ΔA)^{−1}(Ā − I) · ΔB

Continuous: h'(t) = A h(t) + B x(t) (differential equation)
↓ discretize with step Δ
Discrete: h_t = Ā h_{t−1} + B̄ x_t (recurrence)
↓ unroll
Convolution: y = K * x where K_i = C Ā^i B̄ (parallel!)
Interactive: Discretization

A continuous signal (teal) is sampled at discrete steps. Smaller Δ = more faithful representation but more steps.


Why O(L), Not O(L²)?

The recurrence h_k = Ā h_{k−1} + B̄ x_k processes each token in constant time: one matrix-vector multiply, regardless of how long the sequence is. For a sequence of length L, the total cost is O(L). Attention, by contrast, computes pairwise scores between all L tokens: O(L²). The difference is enormous at scale.

Unrolling the recurrence reveals something beautiful. Substituting repeatedly (starting from a zero state):

y_k = C · (Ā^k B̄ x_0 + Ā^{k−1} B̄ x_1 + … + B̄ x_k) = ∑_{j=0}^{k} (C Ā^{k−j} B̄) · x_j

This is a convolution: y = K * x where the kernel K_i = C Ā^i B̄. When A, B, C are fixed, this kernel is precomputable. You convolve it with the input using FFT in O(L log L). Training uses this parallel path; inference uses the sequential recurrence for constant memory.

The dual view: An SSM can be computed as either a recurrence (sequential, O(n)) or a convolution (parallel, O(n log n) via FFT). Training uses convolution mode for parallelism. Inference uses recurrence mode for constant memory per token.
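The equivalence of the two modes can be checked numerically. This is a minimal sketch with a scalar state; the convolution is written as a direct sum rather than an FFT so that it needs no dependencies, and all function names are illustrative.

```python
def ssm_recurrent(a_bar, b_bar, c, xs):
    """Sequential mode: h_k = a_bar*h_{k-1} + b_bar*x_k, y_k = c*h_k (h starts at 0)."""
    h, ys = 0.0, []
    for x in xs:
        h = a_bar * h + b_bar * x
        ys.append(c * h)
    return ys


def ssm_kernel(a_bar, b_bar, c, length):
    """Kernel K_i = c * a_bar**i * b_bar for lags i = 0 .. length-1."""
    return [c * (a_bar ** i) * b_bar for i in range(length)]


def ssm_convolutional(a_bar, b_bar, c, xs):
    """Convolution mode: y_k = sum_j K_{k-j} * x_j, as a direct O(L^2) sum.

    A practical implementation would use an FFT here instead of the double loop.
    """
    K = ssm_kernel(a_bar, b_bar, c, len(xs))
    return [sum(K[k - j] * xs[j] for j in range(k + 1)) for k in range(len(xs))]


xs = [1.0, -2.0, 0.5, 3.0]
rec = ssm_recurrent(0.8, 0.3, 1.5, xs)
conv = ssm_convolutional(0.8, 0.3, 1.5, xs)
print("modes agree:", all(abs(r - v) < 1e-12 for r, v in zip(rec, conv)))
```

The kernel depends only on the fixed parameters, never on `xs`, which is what makes it precomputable.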
Check: Why start with a continuous-time formulation?
🔨 Derivation: Zero-Order Hold Discretization

Given the continuous SSM: h'(t) = A h(t) + B x(t), and the assumption that x(t) is constant over each interval [kΔ, (k+1)Δ] (the zero-order hold assumption).

Your task: Derive the discrete matrices Ā = exp(ΔA) and B̄ = (exp(ΔA) − I) A^{−1} B from first principles. Why does B̄ have that form and not simply ΔB?

The general solution to a linear ODE h'(t) = A h(t) + B u (with u constant) is h(t) = exp(At) h(0) + ∫_0^t exp(A(t−s)) B u ds. This is the variation-of-constants formula from ODE theory.
When u is constant over [0, Δ], the integral becomes ∫_0^Δ exp(A(Δ−s)) ds · B · u. Substitute τ = Δ−s to get ∫_0^Δ exp(Aτ) dτ · B · u.
∫_0^Δ exp(Aτ) dτ = A^{−1}(exp(AΔ) − I). This is the matrix analog of ∫_0^Δ e^{aτ} dτ = (e^{aΔ} − 1)/a for scalars.

Full derivation:

Start with h'(t) = A h(t) + B x(t). Over the interval [kΔ, (k+1)Δ], assume x(t) = x_k (constant). The general solution at time (k+1)Δ given h(kΔ) = h_k is:

h_{k+1} = exp(AΔ) h_k + [∫_0^Δ exp(Aτ) dτ] B x_k

Evaluating the integral (for invertible A): ∫_0^Δ exp(Aτ) dτ = A^{−1}(exp(AΔ) − I)

Therefore: Ā = exp(ΔA) and B̄ = A^{−1}(exp(ΔA) − I) B = (ΔA)^{−1}(Ā − I) · ΔB

The key insight: B̄ ≠ ΔB because the input isn't applied instantaneously — it's integrated over the full time step while the state is simultaneously evolving under A. The factor A−1(exp(AΔ) − I) captures this interaction. For small Δ, it approaches ΔI (recovering the Euler approximation), but for larger Δ the correction matters significantly.
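A quick numerical check of the derivation, as a sketch for the scalar case only (a, b, and x are plain numbers and a ≠ 0). It compares one zero-order-hold step against a brute-force integration of the ODE and shows how far the naive choice Δ·b is from the ZOH value. The helper names and the test values are arbitrary.

```python
import math


def zoh_discretize(a: float, b: float, delta: float) -> tuple[float, float]:
    """Scalar zero-order hold: a_bar = exp(delta*a), b_bar = (a_bar - 1)/a * b (requires a != 0)."""
    a_bar = math.exp(delta * a)
    b_bar = (a_bar - 1.0) / a * b
    return a_bar, b_bar


def integrate_ode(a: float, b: float, x: float, h0: float, delta: float, steps: int = 100_000) -> float:
    """Reference solution: integrate h' = a*h + b*x over [0, delta] with many tiny Euler steps."""
    h, dt = h0, delta / steps
    for _ in range(steps):
        h += dt * (a * h + b * x)
    return h


a, b, x, h0, delta = -2.0, 1.0, 3.0, 0.7, 0.5
a_bar, b_bar = zoh_discretize(a, b, delta)
exact_step = a_bar * h0 + b_bar * x            # one ZOH step with x held constant
reference = integrate_ode(a, b, x, h0, delta)  # fine-grained numerical solution

print(f"ZOH step: {exact_step:.5f}   reference: {reference:.5f}")
print(f"b_bar (ZOH): {b_bar:.5f}   delta*b (Euler): {delta * b:.5f}")
```

For this fairly large step the ZOH and Euler values of b̄ differ noticeably; shrinking `delta` brings them together.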

🔗 Pattern Recognition
SSM = Kalman Filter Without the Noise Model
This Lesson (SSM):
h_t = Ā h_{t−1} + B̄ x_t
y_t = C h_t
Kalman Filter:
x̂_{t|t−1} = F x̂_{t−1} + B u_t
ŷ_t = H x̂_{t|t−1}
→ Kalman Filter lesson

Both are linear state space models with identical structure: a transition matrix evolves a hidden state, and an observation matrix reads it out. The Kalman filter adds noise covariance tracking (Q, R) and optimal gain computation (K). An SSM in a neural network learns its matrices from data instead of deriving them from physics. But the core operation — compress history into a fixed-size state, update linearly — is identical.

If the SSM learned Q and R matrices and computed a Kalman gain, what would that correspond to in the neural network? (Answer: an attention-like weighting mechanism over the state update.)

Checkpoint — Before you move on
Explain in your own words: why does the discrete SSM have TWO computation paths (recurrence and convolution), and why do we use different paths for training vs inference?
Model Answer

When A, B, C are fixed (not input-dependent), the recurrence h_k = Ā h_{k−1} + B̄ x_k can be unrolled into a convolution y = K * x where K_i = C Ā^i B̄. This kernel K is precomputable because it doesn't depend on the input.

Training uses convolution mode (O(n log n) via FFT) because GPUs excel at parallel computation — processing all tokens simultaneously is vastly faster than a sequential loop.

Inference uses recurrence mode because we generate one token at a time anyway. The recurrence gives O(1) per token with O(N) state memory — no need to recompute over the entire history. The KV-cache-free nature of SSMs comes directly from this recurrence view.

This dual view breaks the moment B, C, and Δ become input-dependent (Mamba): Ā and B̄ then change at every step, so you can no longer precompute K. That's why Mamba needs the parallel scan instead of FFT convolution.

Chapter 3: The S4 Trick

The breakthrough of S4 (Structured State Spaces for Sequence Modeling) was solving the long-range memory problem. Previous SSMs forgot quickly — information from 1000 steps ago had decayed to nearly zero. S4 fixed this with two key ideas:

1. HiPPO Initialization

Instead of random initialization, the A matrix is set to the HiPPO (High-Order Polynomial Projection Operator) matrix. This special matrix is designed to optimally compress the history of a continuous signal into a fixed-size state. Each element of ht stores a coefficient of a polynomial approximation of the entire input history.

A_{nk} = −(2n+1)^{1/2}(2k+1)^{1/2}   if n > k,   −(n+1)   if n = k,   0   if n < k

2. Diagonal + Low-Rank Structure

Computing Ā = exp(ΔA) naively for an N×N matrix costs O(N³). S4 decomposes A into diagonal + low-rank form, reducing this to O(N). This makes training practical for state sizes of 64 or more.

Interactive: Memory Decay — Random vs HiPPO

A signal arrives at step 0. Watch how much the hidden state remembers it over time. HiPPO retains far more information.

What S4 Actually Computes

S4 decomposes A into Diagonal Plus Low-Rank (DPLR) form: A = Λ − PQ*, where Λ is diagonal and P, Q are low-rank matrices. This is critical because computing exp(ΔA) for a general N×N matrix costs O(N³). With DPLR structure, it reduces to O(N). The convolution kernel K can then be computed via a Cauchy kernel formula, making the entire training step O(L log L) via FFT.

In subsequent work, S4D (diagonal state spaces) simplified this further: just use a diagonal A matrix, initialized from the HiPPO eigenvalues. This drops the low-rank components entirely, making implementation much simpler while retaining most of the performance. Mamba uses diagonal A by default.

The HiPPO matrix specifically captures a polynomial projection: each element of the state ht stores a Legendre polynomial coefficient of the input history. Element ht[k] approximates the k-th Legendre coefficient of x(s) over [0, t]. This means the state literally stores a polynomial approximation of everything it has seen — a lossless-ish compression of the entire history into a fixed-size vector.
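The HiPPO-LegS matrix given earlier in this chapter is easy to construct by hand. The sketch below builds it as a nested list in the sign convention h' = A h (negative entries); the function name `hippo_legs` is chosen for this example.

```python
import math


def hippo_legs(N: int) -> list[list[float]]:
    """Build the N x N HiPPO-LegS matrix in the sign convention h' = A h.

    A[n][k] = -sqrt(2n+1) * sqrt(2k+1)  if n > k
            = -(n + 1)                  if n == k
            = 0                         if n < k   (lower-triangular)
    """
    A = [[0.0] * N for _ in range(N)]
    for n in range(N):
        for k in range(n + 1):
            if n == k:
                A[n][k] = -(n + 1.0)
            else:
                A[n][k] = -math.sqrt(2 * n + 1) * math.sqrt(2 * k + 1)
    return A


for row in hippo_legs(3):
    print([round(v, 3) for v in row])
```

The diagonal entries −(n+1) are also the eigenvalues of this triangular matrix, so higher-order coefficients decay faster.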

Why it matters: Before S4, SSMs couldn't match Transformers on tasks requiring long-range dependencies (e.g., understanding context 4000 tokens ago). S4 was the first SSM to achieve competitive performance on the Long Range Arena benchmark, matching or beating Transformers.
Check: What problem does HiPPO initialization solve?
🔨 Derivation: Why HiPPO Uses Legendre Polynomials

The HiPPO framework asks: "What is the optimal way to compress the history of a signal x(s) for s ∈ [0, t] into a fixed-size state vector c(t) ∈ RN?"

Your task: Show why projecting onto Legendre polynomials gives the HiPPO-LegS matrix A_{nk} = −(2n+1)^{1/2}(2k+1)^{1/2} for n > k. What makes this a "good" compression?

We want c_n(t) to be the n-th coefficient of the best L² approximation of x(s) on [0, t] using an orthogonal polynomial basis. Legendre polynomials P_n are orthogonal on [−1, 1], so we rescale to [0, t] and normalize the basis: c_n(t) = (2n+1)^{1/2}/t · ∫_0^t P_n(2s/t − 1) x(s) ds.
Taking d/dt of c_n(t) involves the Leibniz rule (the integral's upper bound depends on t) AND the fact that the basis functions themselves stretch as t grows (because we're projecting onto [0, t], not a fixed interval). This produces a coupling between all coefficients c_k with k ≤ n, yielding the matrix ODE c'(t) = −(1/t) A c(t) + (1/t) B x(t).
The derivative of P_n can be expressed as a linear combination of P_k for k < n (a classical identity for Legendre polynomials). This means updating coefficient n only requires knowing coefficients k ≤ n, hence A is lower-triangular. The specific entries come from the Legendre recurrence relation coefficients scaled by the normalization factors (2n+1)^{1/2}.

Full derivation:

1. Define the online polynomial approximation: c_n(t) = ∫_0^t w_n(s, t) x(s) ds where w_n(s, t) = (2n+1)^{1/2}/t · P_n(2s/t − 1) is the normalized Legendre weight rescaled to [0, t].

2. Differentiate with respect to t using the Leibniz rule and the chain rule on the rescaled argument 2s/t − 1:

c'_n(t) = (2n+1)^{1/2}/t · x(t) · P_n(1) − ∑_{k=0}^{n} (A_{nk}/t) · c_k(t)

3. Using P_n(1) = 1 and the Legendre derivative identities, the coupling matrix works out to A_{nk} = (2n+1)^{1/2}(2k+1)^{1/2} for n > k, and A_{nn} = n + 1.

4. The system is c'(t) = −(1/t) A c(t) + (1/t) B x(t) where B_n = (2n+1)^{1/2}. Note the sign convention: here the minus sign sits in front of A, so its entries are positive. Writing the dynamics as h' = A h, as S4 does, absorbs that minus sign into the matrix and gives the negative entries quoted in the task statement.

The key insight: This specific A matrix is the unique linear system that maintains the optimal polynomial approximation of the entire history as new data arrives. A randomly initialized A has no such guarantee — it just exponentially forgets. HiPPO's A is structured so that information doesn't decay; it gets redistributed across polynomial modes. This is why S4 can model dependencies over thousands of steps.

Chapter 4: Selective Scan (Mamba)

S4 has a critical limitation: A, B, C are the same for every input token. The model can't decide to "pay more attention" to important tokens or "forget" irrelevant ones. Mamba (Gu & Dao, 2023) fixes this by making B, C, and Δ functions of the input.

B_t = Linear(x_t)    C_t = Linear(x_t)    Δ_t = softplus(Linear(x_t))

This is the selective in "selective scan." When the model sees an important token, it can increase Δ (larger step = more input mixed in) and adjust B to write strongly to the state. For irrelevant tokens, it can shrink Δ and effectively skip them.

Interactive: Selective Gates

Watch how Mamba's input-dependent gates open (green) for important tokens and close (dim) for irrelevant ones. Click tokens to toggle importance.

What "Selective" Actually Means

Think of Δ_t as a gate width. When Δ is large, exp(ΔA) shrinks the old state more and B̄ grows, so the new input floods in. When Δ is small, the old state is preserved and the input barely registers. The model learns to open the gate wide for important tokens (names, keywords, punctuation) and keep it narrow for filler words.

B_t controls what is written to the state, C_t controls what is read out, and Δ_t controls how much the state changes. All three are computed from linear projections of the current input x_t (Δ_t additionally passes through softplus), making the entire SSM dynamics content-dependent. This is what S4 couldn't do.
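The gate behaviour can be seen in a one-channel toy example. This sketch assumes a scalar A = −1 and B = 1 and uses the exact ZOH formulas; the `pre_activation` values stand in for the output of the linear projection that produces Δ and are invented for illustration.

```python
import math


def softplus(z: float) -> float:
    """softplus(z) = log(1 + e^z), always positive."""
    return math.log1p(math.exp(z))


def gate(delta: float, a: float = -1.0, b: float = 1.0) -> tuple[float, float]:
    """Return (a_bar, b_bar) for a scalar channel using exact ZOH; a must be negative."""
    a_bar = math.exp(delta * a)        # fraction of the old state that survives
    b_bar = (a_bar - 1.0) / a * b      # how strongly the new input is written
    return a_bar, b_bar


for pre_activation in (-4.0, 0.0, 4.0):
    delta = softplus(pre_activation)
    a_bar, b_bar = gate(delta)
    print(f"delta={delta:.3f}  keep_old={a_bar:.3f}  write_new={b_bar:.3f}")
```

A small Δ leaves `keep_old` near 1 and `write_new` near 0 (the token is skipped); a large Δ does the opposite and overwrites the state with the new input.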

The breakthrough: Making parameters input-dependent breaks the convolution view — you can no longer precompute a single kernel K. But it makes the model vastly more expressive. Mamba showed that with careful GPU implementation, the recurrence can be nearly as fast as the convolution.
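| Model | A, B, C | Convolution? | Content-aware? |
| --- | --- | --- | --- |
| S4 | Fixed (same for all inputs) | Yes (parallel training) | No |
| Mamba | Input-dependent (B, C, and Δ; A itself stays fixed) | No (must use scan) | Yes |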

The parameter generation is lightweight. For each token x_t in R^{d_inner}: B_t = W_B x_t where W_B is [d_state, d_inner], producing B_t in R^{d_state}. Similarly for C_t. For Δ_t: a linear projection to R^{d_inner} followed by softplus ensures Δ > 0. The total parameter overhead for selectivity is minimal: just three small linear layers per block.

Check: What is the key innovation of Mamba over S4?
Checkpoint — Before you move on
Making B, C, Δ input-dependent breaks the convolution view. Explain why: what specific property of the kernel K_i = C Ā^i B̄ is violated when these matrices change at every timestep?
Model Answer

The convolution kernel K_i = C Ā^i B̄ can be precomputed ONLY because C, Ā, B̄ are the same for every token. The kernel depends only on the lag i (how far back we're looking), not on which token is at position i.

When B_t = Linear(x_t), C_t = Linear(x_t), and Δ_t = softplus(Linear(x_t)), the effective "kernel" at position k looking back j steps becomes C_k (∏_{i=k−j+1}^{k} Ā_i) · B̄_{k−j}. This product depends on every input in the range, so it is no longer shift-invariant. You can't factor it into a single convolution.

This is the fundamental cost of selectivity: you gain content-awareness but lose the O(n log n) FFT training path. The parallel scan (next chapter) is the solution — it's not as fast as FFT, but it's O(n) work in O(log n) parallel depth, which is practical on GPUs.
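The loss of shift-invariance is easy to demonstrate with scalars. In this sketch, `effective_weight` evaluates the unrolled product for a given output position `k` and lag; the parameter lists are made-up numbers standing in for values that a real model would compute from its inputs.

```python
def effective_weight(a_bars, b_bars, cs, k, lag):
    """Weight of x_{k-lag} in y_k for h_t = a_bars[t]*h_{t-1} + b_bars[t]*x_t, y_t = cs[t]*h_t.

    Unrolling gives c_k * (a_bar_{k-lag+1} * ... * a_bar_k) * b_bar_{k-lag}.
    """
    src = k - lag
    decay = 1.0
    for i in range(src + 1, k + 1):
        decay *= a_bars[i]
    return cs[k] * decay * b_bars[src]


# Fixed parameters (S4-style): the weight depends only on the lag.
fixed = dict(a_bars=[0.9] * 4, b_bars=[0.5] * 4, cs=[1.0] * 4)
print(effective_weight(**fixed, k=1, lag=1), effective_weight(**fixed, k=3, lag=1))

# Per-step parameters (Mamba-style): same lag, different weights at different positions.
varying = dict(a_bars=[0.9, 0.2, 0.9, 0.99], b_bars=[0.5, 0.1, 0.8, 0.3], cs=[1.0] * 4)
print(effective_weight(**varying, k=1, lag=1), effective_weight(**varying, k=3, lag=1))
```

With fixed parameters both printed weights match, so a single kernel indexed by lag describes the whole sequence. With per-step parameters they differ, and no such kernel exists.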

💻 Build It: Implement the Selective SSM Step
Implement one step of Mamba's selective SSM: given the current input xt, produce the output yt and update the hidden state. Include the discretization (ZOH), the input-dependent parameter generation, and the state update.
Signature:

    def selective_ssm_step(x_t, h_prev, A, W_B, W_C, W_delta, bias_delta):
        """
        One step of Mamba's selective SSM.

        Args:
            x_t:        input vector [d_inner]
            h_prev:     previous hidden state [d_inner, d_state]
            A:          diagonal state matrix [d_inner, d_state] (log-space)
            W_B:        projection for B [d_state, d_inner]
            W_C:        projection for C [d_state, d_inner]
            W_delta:    projection for delta [d_inner, d_inner]
            bias_delta: bias for delta [d_inner]

        Returns:
            y_t: output [d_inner]
            h_t: new hidden state [d_inner, d_state]
        """
Test case
With d_inner=4, d_state=2, x_t=ones(4), h_prev=zeros(4,2), A=-ones(4,2):
delta_t should be positive (softplus output), B_bar should scale with delta,
and y_t should have shape [4] with non-zero values (since B_t writes x_t into fresh state).
In practice, Mamba uses a first-order approximation for B̄: B̄_t ≈ Δ_t · B_t (element-wise scaling). The exact ZOH form A^{−1}(e^{ΔA} − I)B is cheap to evaluate for diagonal A, but the simplified form is what's actually implemented in the CUDA kernel. For Ā: A is stored in log-space and negated so that the state decays, giving A_real = −exp(A_log) and Ā = exp(Δ · A_real).
```python
import torch
import torch.nn.functional as F

def selective_ssm_step(x_t, h_prev, A, W_B, W_C, W_delta, bias_delta):
    # 1. Generate input-dependent parameters
    B_t = W_B @ x_t                         # [d_state]
    C_t = W_C @ x_t                         # [d_state]
    delta_t = F.softplus(W_delta @ x_t + bias_delta)  # [d_inner], > 0

    # 2. Discretize (A is in log-space, shape [d_inner, d_state])
    A_real = -torch.exp(A)                  # negative reals, so the state decays
    A_bar = torch.exp(delta_t.unsqueeze(1) * A_real)  # [d_inner, d_state], values in (0, 1)
    B_bar = delta_t.unsqueeze(1) * B_t.unsqueeze(0)   # [d_inner, d_state], first-order form

    # 3. Update state (channel i writes its own input x_t[i] into its state row)
    h_t = A_bar * h_prev + B_bar * x_t.unsqueeze(1)  # [d_inner, d_state]

    # 4. Read output
    y_t = (h_t * C_t.unsqueeze(0)).sum(dim=1)  # [d_inner]

    return y_t, h_t
```
Bonus challenge: Extend this to process an entire sequence of L tokens. Can you see why a naive for-loop would be slow on GPU? How would you batch the parameter generation (steps 1-2) across all tokens in parallel, then use a scan for step 3?

Chapter 5: Hardware-Aware Scan

Making B, C, Δ input-dependent means we can't use convolution for training. The naive approach — a sequential for-loop — would be unbearably slow on GPUs. Mamba uses a parallel scan algorithm that computes the full recurrence in O(log n) parallel steps instead of O(n) sequential ones.

The Parallel Scan

The key insight: the update h_t = a_t h_{t−1} + b_t can be expressed with an associative operator on (a, b) pairs. As with addition, we can regroup the computation. Instead of left-to-right, we use a binary tree:

Step 1: Compute pairs: (h_0, h_1), (h_2, h_3), (h_4, h_5), ...
Step 2: Combine pairs: (h_{0..1}, h_{2..3}), (h_{4..5}, h_{6..7}), ...
Step 3: Combine again: (h_{0..3}, h_{4..7}), ...
↓ log₂(n) steps total
Done: All h_0..h_n computed in parallel
Interactive: Parallel Scan Tree

Click "Step" to advance through the parallel scan. Each level halves the remaining work.


The Fused CUDA Kernel

Mamba doesn't just use a parallel scan — it fuses the entire SSM computation (discretization, scan, output projection) into a single GPU kernel. The key bottleneck on modern GPUs isn't compute but memory bandwidth. Reading and writing intermediate tensors of shape [B, L, d_inner, d_state] to global memory (HBM) would dominate runtime.

Instead, Mamba's kernel keeps the state (d_state=16 floats per channel) entirely in SRAM (the GPU's fast on-chip scratchpad). The associative scan runs in registers, and only the final output [B, L, d_inner] is written back to HBM. This reduces memory I/O by a factor proportional to d_state, making the selective SSM nearly as fast as a simple matrix multiply.
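A back-of-the-envelope calculation shows where the "factor proportional to d_state" comes from. The batch size and sequence length below are arbitrary illustrative choices, and the helper `tensor_bytes` is defined only for this sketch.

```python
def tensor_bytes(*shape: int, bytes_per_elem: int = 2) -> int:
    """Size in bytes of a dense tensor with the given shape."""
    total = bytes_per_elem
    for dim in shape:
        total *= dim
    return total


# Illustrative sizes only: batch 8, 2048 tokens, d_inner = 5120, d_state = 16, 16-bit values.
B, L, d_inner, d_state = 8, 2048, 5120, 16
expanded = tensor_bytes(B, L, d_inner, d_state)   # per-step states, if they were written to HBM
output = tensor_bytes(B, L, d_inner)              # what the fused kernel writes back

print(f"materialized states: {expanded / 2**30:.2f} GiB")
print(f"scan output:         {output / 2**30:.3f} GiB")
print(f"ratio:               {expanded // output}x (equals d_state)")
```

Every intermediate of shape [B, L, d_inner, d_state] that stays on-chip saves that much traffic in each direction.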

GPU optimization: Keeping the state in SRAM (fast on-chip memory) and avoiding writes of intermediate results to slower HBM is the same philosophy as Flash Attention: restructure the algorithm to match the GPU memory hierarchy.
Check: How does the parallel scan reduce O(n) sequential steps?
🔨 Derivation: Proving the Scan is Associative

The linear recurrence h_t = a_t · h_{t−1} + b_t can be written as a binary operation on tuples: (a_2, b_2) • (a_1, b_1) = (a_2 · a_1, a_2 · b_1 + b_2).

Your task: Prove this operation is associative: ((a_3, b_3) • (a_2, b_2)) • (a_1, b_1) = (a_3, b_3) • ((a_2, b_2) • (a_1, b_1)). Then explain why associativity enables O(log n) parallel depth.

Left side: first compute (a_3, b_3) • (a_2, b_2) = (a_3 a_2, a_3 b_2 + b_3). Then apply • (a_1, b_1): = (a_3 a_2 · a_1, a_3 a_2 · b_1 + a_3 b_2 + b_3).
Right side: first compute (a_2, b_2) • (a_1, b_1) = (a_2 a_1, a_2 b_1 + b_2). Then apply (a_3, b_3) •: = (a_3 · a_2 a_1, a_3 (a_2 b_1 + b_2) + b_3) = (a_3 a_2 a_1, a_3 a_2 b_1 + a_3 b_2 + b_3). Same as left side!
If the operation is associative, computing the "prefix sum" (cumulative application from left to right) can be restructured into a balanced binary tree. Instead of ((a•b)•c)•d (3 sequential steps), we compute (a•b) and (c•d) in parallel, then combine. For n elements: log₂(n) parallel steps instead of n−1 sequential steps.

Full proof:

Define •: (a_2, b_2) • (a_1, b_1) = (a_2 a_1, a_2 b_1 + b_2)

Left: ((a_3, b_3) • (a_2, b_2)) • (a_1, b_1) = (a_3 a_2, a_3 b_2 + b_3) • (a_1, b_1) = (a_3 a_2 a_1, a_3 a_2 b_1 + a_3 b_2 + b_3)

Right: (a_3, b_3) • ((a_2, b_2) • (a_1, b_1)) = (a_3, b_3) • (a_2 a_1, a_2 b_1 + b_2) = (a_3 · a_2 a_1, a_3 (a_2 b_1 + b_2) + b_3) = (a_3 a_2 a_1, a_3 a_2 b_1 + a_3 b_2 + b_3)

Left = Right. QED. The operation is associative (but NOT commutative — order matters!).

Why this enables the parallel scan: The prefix computation hn = reduce(•, [(a1,b1), ..., (an,bn)]) is equivalent to computing ALL intermediate prefixes. With associativity, we can use the Blelloch scan algorithm: an up-sweep (reduce phase) followed by a down-sweep (propagate phase), each taking log2(n) steps. Total work is O(n), parallel depth is O(log n).

The key insight: This tuple representation encodes the entire linear recurrence. The first element tracks the cumulative "decay" (how much of the initial state survives), and the second element tracks the cumulative "input contribution" (everything that's been added along the way). It's a matrix-free way to represent ht = (∏ ai)h0 + accumulated inputs.
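The tuple operator can be exercised in a few lines. The sketch below compares a plain left-to-right loop with a recursive-doubling scan in the Hillis-Steele style. That variant is swapped in because it is shorter to write than the Blelloch up-sweep/down-sweep; it has the same O(log n) depth but does O(n log n) work rather than O(n). Everything runs sequentially here, and the comments mark which loop a GPU would parallelize.

```python
def combine(later, earlier):
    """(a2, b2) • (a1, b1) = (a2*a1, a2*b1 + b2): apply `earlier` first, then `later`."""
    a2, b2 = later
    a1, b1 = earlier
    return (a2 * a1, a2 * b1 + b2)


def sequential_scan(pairs, h0=0.0):
    """Reference: h_t = a_t*h_{t-1} + b_t computed left to right."""
    h, out = h0, []
    for a, b in pairs:
        h = a * h + b
        out.append(h)
    return out


def doubling_scan(pairs, h0=0.0):
    """Inclusive scan by recursive doubling (Hillis-Steele style).

    Each pass of the while-loop is independent across i, so it could run in parallel;
    there are ceil(log2(n)) passes in total.
    """
    elems = list(pairs)
    stride = 1
    while stride < len(elems):
        elems = [
            combine(elems[i], elems[i - stride]) if i >= stride else elems[i]
            for i in range(len(elems))
        ]
        stride *= 2
    # elems[t] now holds the composed update for steps 0..t; apply it to h0.
    return [a * h0 + b for a, b in elems]


pairs = [(0.9, 1.0), (0.5, -2.0), (0.8, 0.3), (1.1, 0.0), (0.7, 4.0)]
print(sequential_scan(pairs, h0=2.0))
print(doubling_scan(pairs, h0=2.0))
```

Both calls print the same sequence of states up to floating-point rounding.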

Chapter 6: The Mamba Block

A Mamba block looks quite different from a Transformer block. It has no attention mechanism at all. Instead, it uses a combination of linear projections, 1D convolution, and the selective SSM:

Input: x ∈ R^{n×d}
Linear Projection: Expand d → 2·d_inner (split into two branches)
↓ Branch A      ↓ Branch B
Conv1D (Branch A): Short causal convolution (kernel size ~4)
Selective SSM: Input-dependent scan with generated B, C, Δ
↓ ⊗ SiLU(Branch B)
Gated Merge: Element-wise multiply with SiLU-activated Branch B
Linear Projection: Project back d_inner → d
Interactive: Mamba Block Data Flow

Click "Step" to trace data through the Mamba block. Watch the two branches split and merge.


Concrete Shapes (Mamba-2.8B)

For Mamba-2.8B: d_model=2560, d_state=16, d_conv=4, expand_factor=2, so inner_dim=5120. Here's the exact data flow through one block:

| Stage | Shape | Operation |
| --- | --- | --- |
| Input | [B, L, 2560] | |
| Linear expand | [B, L, 10240] | Project to 2 × inner_dim, split into two branches |
| Branch A: Conv1D | [B, L, 5120] | Causal conv, kernel=4, groups=5120 (depthwise) |
| Generate B, C, Δ | B, C: [B, L, 16] each; Δ: [B, L, 5120] | Linear projections from Branch A (softplus on Δ) |
| SSM scan | [B, L, 5120] | Parallel scan with d_state=16 per channel |
| Gate merge | [B, L, 5120] | SSM output ⊗ SiLU(Branch B) |
| Output projection | [B, L, 2560] | Linear back to d_model |

The SSM state has shape [B, 5120, 16] — each of the 5120 channels maintains its own 16-dimensional state vector. During inference, this state is all you carry forward. No KV cache needed.
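The shape bookkeeping can be captured in a small helper. This is a sketch only: `MambaConfig` and `block_shapes` are hypothetical names invented here, not the API of any Mamba implementation, and the defaults simply mirror the Mamba-2.8B numbers quoted above.

```python
from dataclasses import dataclass


@dataclass
class MambaConfig:
    """Hypothetical config holder; defaults follow the Mamba-2.8B figures in this chapter."""
    d_model: int = 2560
    d_state: int = 16
    d_conv: int = 4
    expand: int = 2

    @property
    def d_inner(self) -> int:
        return self.expand * self.d_model


def block_shapes(cfg: MambaConfig, batch: int, seq_len: int) -> dict[str, tuple[int, ...]]:
    """Tensor shapes at the main stages of one block."""
    return {
        "input": (batch, seq_len, cfg.d_model),
        "linear_expand": (batch, seq_len, 2 * cfg.d_inner),
        "conv1d_branch_a": (batch, seq_len, cfg.d_inner),
        "B_t": (batch, seq_len, cfg.d_state),
        "C_t": (batch, seq_len, cfg.d_state),
        "ssm_scan": (batch, seq_len, cfg.d_inner),
        "gate_merge": (batch, seq_len, cfg.d_inner),
        "output_projection": (batch, seq_len, cfg.d_model),
        "inference_state": (batch, cfg.d_inner, cfg.d_state),
    }


for name, shape in block_shapes(MambaConfig(), batch=1, seq_len=8).items():
    print(f"{name:<18} {shape}")
```

Only `inference_state` lacks a sequence dimension, which is the reason the memory carried between tokens does not grow.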

Why Conv1D? The short convolution provides local context mixing — it lets the model see a few neighboring tokens before the SSM processes the global sequence. It's like a tiny receptive field that helps the SSM know "what's nearby" before deciding how to update the state.
Check: What role does the SiLU-gated branch play?
💥 Break-It Lab: What Dies When You Remove Mamba Components?
A working Mamba block processes a sequence of tokens through selective SSM with HiPPO initialization. The canvas shows how well the model retains and selectively processes information. Toggle components off to see what breaks.
Remove Selectivity (revert to LTI)
Failure mode: Without input-dependent B, C, Δ, the model treats every token identically. The "important" tokens (green) get the same state update as filler tokens. The model can no longer selectively retain relevant information — it either remembers everything (state overflows) or forgets everything (information washes out). On the copying task, perplexity increases by 2-4x because the model can't "open the gate" for tokens it needs to recall.
Remove HiPPO Init (random A)
Failure mode: Random initialization of A causes exponential decay of past information. After ~50 tokens, the state has effectively forgotten everything. The model becomes a "short-range" processor, unable to maintain context beyond a small window. Long-range benchmarks (Path-X, Long Range Arena) drop from 90%+ to near-random. The polynomial compression structure is lost.
Remove Gated Branch (no SiLU gate)
Failure mode: Without the multiplicative gate from Branch B, the SSM output passes directly to the output projection. The model loses its ability to suppress or amplify SSM outputs based on local context. Training becomes unstable (gradients are less well-conditioned), and the model underperforms by ~1 perplexity point. The gate acts as an information bottleneck that regularizes the SSM signal.

Chapter 7: Hybrid Models

Pure Mamba models are fast but sometimes struggle with tasks requiring precise retrieval of specific tokens from the past (like "what was the 7th word?"). Pure Transformers are great at retrieval but expensive. The natural solution: combine both.

Key Hybrid Architectures

| Model | Architecture | Key Idea |
| --- | --- | --- |
| Jamba (AI21) | Mamba + Attention + MoE | Alternate Mamba and attention layers. MoE for capacity. |
| Zamba (Zyphra) | Mamba + shared Attention | One shared attention layer interleaved every few Mamba blocks. |
| Griffin (DeepMind) | Gated linear recurrence + local attention | Use recurrence for global, windowed attention for local. |
| Mamba-2 | Structured state space duality | Show SSMs and attention are dual views of the same operation. |
Interactive: Hybrid Architecture Viewer

Select an architecture to see how Mamba and attention layers are interleaved.

How Jamba Interleaves

Jamba (AI21, 2024) uses a concrete pattern: for every 7 Mamba layers, insert 1 attention layer. This means only ~12% of layers are attention, but that's enough for strong retrieval performance. Jamba also replaces some of its MLP layers with Mixture-of-Experts (MoE) layers to add capacity without a matching increase in active parameters. AI21 reports a 256K-token context window and a model that fits on a single 80GB GPU, a combination that is hard to reach with a pure Transformer of comparable quality.

Zamba (Zyphra) goes further: it uses a single shared attention layer whose weights are reused every N Mamba blocks. This dramatically reduces the parameter count from attention (which is the most parameter-heavy component) while still providing the retrieval capability where needed.
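The 1-in-8 interleaving can be written as a simple layer schedule. This is a sketch of the ratio only: where the attention layer sits inside each group of eight (`attention_index`) is an assumption made for illustration, and the function is not taken from any released model code.

```python
def hybrid_layer_pattern(n_layers: int, group_size: int = 8, attention_index: int = 4) -> list[str]:
    """Label each layer 'mamba' or 'attention', with one attention layer per group.

    group_size=8 gives a 1:7 attention-to-Mamba ratio. The slot used for attention
    inside each group (attention_index) is an illustrative assumption.
    """
    return [
        "attention" if i % group_size == attention_index else "mamba"
        for i in range(n_layers)
    ]


layers = hybrid_layer_pattern(32)
n_attn = layers.count("attention")
print("".join("A" if kind == "attention" else "M" for kind in layers))
print(f"{n_attn}/{len(layers)} attention layers = {n_attn / len(layers):.1%}")
```

A Zamba-style variant would keep the same schedule but point every "attention" slot at one shared set of weights.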

Why hybrids work: Mamba excels at long-range, smooth information flow (like understanding overall topic and style). Attention excels at precise, content-based retrieval (like copying a name from 5000 tokens ago). Combining both gives you the best of both worlds with a fraction of the attention cost.
Check: What weakness of pure Mamba models do hybrid architectures address?

Chapter 8: Training & Inference

Mamba models are trained with the exact same objective as Transformers: next-token prediction with cross-entropy loss. The training data, tokenizer, and optimizer are all the same. The difference is entirely in the architecture.

Training: Parallel Mode

During training, the entire sequence is available. The selective SSM uses the parallel scan to process all tokens simultaneously. This is analogous to how Transformers compute all attention weights at once.

Inference: Recurrence Mode

During generation, we switch to sequential recurrence. At each step:

Get token: Receive x_t from the previous prediction
Update state: h_t = Ā_t h_{t−1} + B̄_t x_t (constant time!)
Output: y_t = C_t h_t, used to predict the next token
Interactive: Inference Memory Comparison

Compare memory usage during generation. Transformer's KV cache grows linearly; Mamba's state is constant.


Memory: A Concrete Example

Consider a 2.8B model generating a 100K-token sequence:

| | Transformer (2.8B) | Mamba (2.8B) |
| --- | --- | --- |
| KV cache at 1K tokens | ~160 MB | 0 (no KV cache) |
| KV cache at 100K tokens | ~16 GB | 0 |
| SSM state (all layers) | n/a | ~7.5 MB (fixed) |
| Time per token at 100K | Grows (must attend to all KV) | Constant |

The Mamba state is tiny: d_inner × d_state × n_layers × 2 bytes = 5120 × 16 × 48 × 2 ≈ 7.5 MB per sequence. This is why Mamba shines for long-context applications like book-length analysis, genomics, and continuous audio processing.
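The state-size formula is simple enough to evaluate directly. The sketch below uses the same numbers as the formula above (including the assumed 48 layers and 2 bytes per value) and ignores the small convolution buffer each layer also keeps.

```python
def ssm_state_bytes(d_inner: int, d_state: int, n_layers: int, bytes_per_elem: int = 2) -> int:
    """Bytes of recurrent SSM state carried between tokens (conv buffer not included)."""
    return d_inner * d_state * n_layers * bytes_per_elem


state = ssm_state_bytes(d_inner=5120, d_state=16, n_layers=48)
for tokens in (1_000, 100_000, 1_000_000):
    # The sequence length never enters the formula, so the size is identical every time.
    print(f"after {tokens:>9,} tokens: {state:,} bytes = {state / 2**20:.1f} MiB")
```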

Constant memory per token: Unlike the Transformer's KV cache, which grows with every generated token, the SSM hidden state is always the same size. At 100K tokens, the KV cache is already in the tens of gigabytes (~16 GB in the example above, more for larger models or without grouped-query attention). Mamba needs the same fixed state it had at token 1.
Check: How does Mamba's inference memory scale with sequence length?

Chapter 9: SSM vs Transformer Tradeoffs

Neither architecture is strictly better. Each has fundamental strengths and weaknesses rooted in their core mechanisms.

| Dimension | Transformer | SSM / Mamba |
| --- | --- | --- |
| Training cost | O(n²) per layer | O(n) per layer |
| Inference memory | KV cache grows with n | Fixed-size state |
| Inference time/token | Grows with n (attend to all KV) | Constant |
| Exact retrieval | Excellent (direct access) | Weaker (compressed state) |
| Long-range context | Struggles past training length | Naturally extends |
| In-context learning | Strong (induction heads) | Emerging but less understood |
| Ecosystem maturity | Very mature (5+ years) | Rapidly growing |
| Hardware optimization | Flash Attention, GQA | Parallel scan, SRAM-aware |
Interactive: Task Suitability Comparison

Select a task to see which architecture is better suited and why.

The Copying Test

One of the simplest ways to expose SSM vs Transformer differences: the copying task. Present the model with "Copy this: ABCDE. Output: " and check if it reproduces ABCDE exactly. Transformers trivially solve this with induction heads — they can attend directly to the tokens to copy. SSMs struggle because the information must pass through the compressed state bottleneck. A 16-dimensional state vector can't perfectly store 5 arbitrary tokens.

This is precisely why hybrid models exist: the attention layers handle exact recall, and the SSM layers handle everything else (fluency, long-range context, style). In Jamba's design, the few attention layers disproportionately handle retrieval-like subtasks while the many Mamba layers handle the bulk of language modeling.

When to use which: For very long contexts (100K+ tokens, audio, genomics), SSMs shine. For retrieval-heavy tasks (RAG, QA over documents), Transformers are more reliable. For production chat models, hybrids are increasingly the best choice.
⚔ Adversarial: Mamba has no attention matrix. How does it handle the copying task?
You give Mamba the prompt: "Remember: X7Q2M. Now repeat: ". The model must output "X7Q2M" exactly. But the hidden state is only 16 dimensions per channel. There's no attention matrix to "look back" at the original tokens.
🏗 Design Challenge: You're the Architect: 1M Token Context at 7B Scale
Your team needs to ship a 7B-parameter model that handles 1 million token context at inference on a single A100 (80GB). The model must support both long-document summarization (smooth context) AND retrieval-augmented generation (exact recall from injected documents). Design the architecture.
GPU Memory: 80 GB (A100)
Model Size: 7B parameters (~14 GB in FP16)
Context Length: 1,000,000 tokens
Latency Target: < 50ms per token at generation
Key Tasks: Summarization + RAG retrieval
1. Pure Transformer with 1M context: what's the KV cache size? Can it fit? (Hint: compute n_layers × 2 × seq_len × d_head × n_heads × 2 bytes)
2. Pure Mamba: what's the state size? Will retrieval work for RAG?
3. Hybrid: how many attention layers do you need? Where do you place them? What attention window do you use?
4. What's your prefill strategy? (Chunked? Which layers get full context vs windowed?)

The memory math:

Pure Transformer (7B, 32 layers, GQA with 8 KV heads, d_head=128): KV cache = 32 × 2 × 1M × 128 × 8 × 2 bytes = ~128 GB. Doesn't fit on one A100. Even with 4-bit KV quantization, that's 32 GB for cache alone; after the ~14 GB of FP16 weights, about 34 GB remains for activations and workspace. Marginal at best.

Pure Mamba (7B, 48 layers, d_inner=5120, d_state=16): State = 48 × 5120 × 16 × 2 bytes = ~7.5 MB. Trivially fits. But RAG retrieval from compressed state is unreliable.

A hybrid in the style of Jamba/Zamba: use ~12% attention layers (4 out of 32) and, as a design choice for this challenge, give them a limited window (4096-8192 tokens). The attention layers handle precise retrieval within their window, while Mamba layers propagate long-range context. For 1M tokens with 4 attention layers at a 4096-token window: KV cache = 4 × 2 × 4096 × 128 × 8 × 2 bytes = ~64 MB. Adding the ~7.5 MB Mamba state gives ~72 MB in total. Fits trivially with room to spare.

For RAG specifically: chunk the retrieved documents so the most relevant passages fall within the attention window of the final few layers. This makes exact recall far more reliable where it matters.
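The memory math above follows from one formula, sketched here as a helper with a made-up name. It takes 1M as exactly 1,000,000 tokens, so the full-attention figure comes out near 131 GB; the ~128 GB quoted above corresponds to rounding 1M to 2^20 and counting in GiB.

```python
def kv_cache_bytes(n_layers: int, seq_len: int, d_head: int, n_kv_heads: int, bytes_per_elem: int = 2) -> int:
    """KV cache size: keys and values (the factor 2) for every layer, position, and KV head."""
    return n_layers * 2 * seq_len * d_head * n_kv_heads * bytes_per_elem


GB = 10**9
full = kv_cache_bytes(n_layers=32, seq_len=1_000_000, d_head=128, n_kv_heads=8)
windowed = kv_cache_bytes(n_layers=4, seq_len=4096, d_head=128, n_kv_heads=8)

print(f"pure Transformer, 1M tokens: {full / GB:.1f} GB")
print(f"4 windowed attention layers: {windowed / 2**20:.0f} MiB")
print(f"full cache + 14 GB weights fits in 80 GB? {full + 14 * GB <= 80 * GB}")
```

Changing `n_layers` and `seq_len` in the second call is a quick way to explore other hybrid configurations against the 80 GB budget.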

🔗 Pattern Recognition
Attention and Recurrence: The Same Tradeoff Everywhere
SSM / Mamba: Compress history into fixed state. O(1) per token, but lossy. Can't retrieve arbitrary past tokens.
Transformer: Store all past tokens in KV cache. O(n) per token, but lossless. Direct access to any past token.
→ Transformer lesson

This is the compression vs direct access tradeoff that appears everywhere in CS. A hash table (O(1) lookup, lossy compression of key space) vs a sorted array (O(log n) lookup, preserves all info). A JPEG (fixed-size, lossy) vs a bitmap (scales with resolution, lossless). In sequence modeling: a recurrent state is the "JPEG of the past" — it captures the gist but loses details. Attention is the "bitmap" — it keeps everything but the cost scales with size.

Where else in ML do you see this exact tradeoff between a fixed-size bottleneck and full storage? (Hint: think about VAE latents, pooling layers, and knowledge distillation.)

"The next generation of sequence models will likely combine the best of both worlds."
— The emerging consensus

You now understand both the Transformer and its most promising challenger. From continuous-time state spaces to selective scans, from HiPPO to hardware-aware algorithms — this is the frontier of sequence modeling.