CS 229s — Systems for Machine Learning

Efficient Architecture Design

Beyond attention: state space models process sequences in O(N) instead of O(N²). From S4 to Mamba — the architectures that scale to millions of tokens.

Prerequisites: Transformers + Linear algebra. That's it.
9 Chapters · 8+ Simulations · 0 Assumed Systems Knowledge

Chapter 0: Beyond Attention

You have a sequence of 100,000 tokens — a textbook, an hour of audio, a genome. You want a model that understands it end to end. The problem? Transformers compute a score between every pair of tokens. That's N² operations. At 100K tokens, that's 10 billion pairwise interactions. Your GPU runs out of memory long before it runs out of things to learn.

This isn't a minor inconvenience. It's a fundamental scaling wall. Audio runs at 16,000 samples per second — a single second is already 256 million pairs. A math textbook has hundreds of thousands of words with dependencies spanning chapters. Attention simply cannot reach these scales.
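To make the arithmetic above concrete, here is a minimal sketch that counts pairwise scores and estimates the memory for one materialized score matrix. The 2-byte (fp16) entry size and the single-head, single-layer setting are assumptions for illustration, not figures from the lecture.

```python
def attention_pairs(n_tokens: int) -> int:
    """Number of query-key scores in one full N x N attention matrix."""
    return n_tokens * n_tokens


def score_matrix_gib(n_tokens: int, bytes_per_score: int = 2) -> float:
    """Memory for one materialized score matrix (assumed fp16), in GiB."""
    return attention_pairs(n_tokens) * bytes_per_score / 2**30


for n in (4_000, 16_000, 100_000):
    print(f"N={n:>7,}: {attention_pairs(n):>15,} pairs, "
          f"{score_matrix_gib(n):8.2f} GiB per head per layer")
```

Fused kernels such as FlashAttention avoid storing the full matrix, but the number of pairwise scores that must be computed stays quadratic.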

So the question driving this lecture is: can we design architectures that scale sub-quadratically in sequence length, without sacrificing quality? The answer is yes — and it comes from an unexpected place: control theory.

The core tension: Attention is O(N²) but extremely expressive — every token can attend to every other token. The alternative architectures we'll study are O(N) or O(N log N), but they achieve this by compressing history into a fixed-size state. The art is in making that compression smart enough.
Compute Cost: Attention vs Linear

Simulation: as the sequence length N increases, the red (quadratic) curve explodes while the teal (linear) one barely rises.
Check: Why can't we simply use efficient attention variants (sparse attention, linear attention) to solve the long-sequence problem?

Chapter 1: Sequence Modeling Primitives

Before we build new architectures, we need to understand the three fundamental ways to mix information across a sequence. Every sequence model is built from some combination of these primitives: convolutions, recurrences, and attention.

Convolutions

A convolution slides a kernel (a small weight vector) across the input and computes weighted sums at each position. The output at position n is a dot product of the kernel with a window of the input centered (or ending) at n.

y[n] = ∑_k h[k] · x[n − k]

Where h is the kernel and x is the input sequence. The kernel weights are fixed — the same weights are applied at every position. This means the mixing pattern is input-independent.

Properties: parallelizable during training (all positions computed at once via FFT), but the kernel has finite reach. A kernel of size K can only look K steps back.
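The convolution formula can be sketched directly as a double loop. This is a minimal illustration that assumes a causal kernel and treats inputs before position 0 as zero; the function name `causal_conv` is hypothetical.

```python
import numpy as np


def causal_conv(x: np.ndarray, h: np.ndarray) -> np.ndarray:
    """y[n] = sum_k h[k] * x[n - k], treating x[m] = 0 for m < 0."""
    y = np.zeros(len(x))
    for n in range(len(x)):
        for k in range(len(h)):
            if n - k >= 0:
                y[n] += h[k] * x[n - k]
    return y


x = np.array([1.0, 2.0, 3.0, 4.0, 5.0])
h = np.array([0.5, 0.3, 0.2])          # kernel of size K = 3
y = causal_conv(x, h)

# np.convolve computes the same sum; truncate it to the input length.
assert np.allclose(y, np.convolve(x, h)[: len(x)])
```

The same weights `h` are applied at every position, which is what "input-independent mixing" means here.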

Recurrences

A recurrence processes the sequence one token at a time, maintaining a hidden state that summarizes everything seen so far.

h_k = f(h_{k−1}, x_k)      y_k = g(h_k)

Properties: the hidden state has potentially infinite context (in principle it can carry information from arbitrarily far back), inference is O(1) per step (just update the state), but training is sequential: you can't compute h_k without h_{k−1}. The exception is when f is linear, in which case the recurrence can be parallelized.

Attention

Attention computes a weighted combination of all values, where the weights depend on pairwise similarity between queries and keys.

Attention(Q, K, V) = softmax(QKᵀ / √d) · V

Properties: fully parallelizable, input-dependent mixing (the mixing matrix changes with every input), but O(N²) in both compute and memory.
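A minimal NumPy sketch of the attention formula, assuming a single head and no causal mask. The shapes and random inputs are illustrative only.

```python
import numpy as np


def attention(Q, K, V):
    """softmax(Q K^T / sqrt(d)) V, with a numerically stable softmax."""
    d = Q.shape[-1]
    scores = Q @ K.T / np.sqrt(d)                        # [N, N] pairwise scores
    scores = scores - scores.max(axis=-1, keepdims=True)
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)       # each row sums to 1
    return weights @ V, weights


rng = np.random.default_rng(0)
N, d = 5, 8
Q, K, V = (rng.standard_normal((N, d)) for _ in range(3))
out, W = attention(Q, K, V)
```

The `[N, N]` matrix `W` is the mixing matrix. It is recomputed from the input each time, and it is also where the quadratic cost comes from.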

Primitive | Training | Inference | Context | Mixing
Convolution | Parallel O(N log N) | O(N) per sequence | Finite (kernel size) | Input-independent
Recurrence | Sequential O(N) | O(1) per step | Infinite (in theory) | Input-independent*
Attention | Parallel O(N²) | O(N) per step | Full sequence | Input-dependent

*Linear recurrences are input-independent. Mamba makes them input-dependent — we'll get there in Chapter 6.

Key insight: The dream architecture would combine the best of all three: parallel training (convolution), O(1) inference (recurrence), and input-dependent mixing (attention). The S4-to-Mamba line of work gets remarkably close.
Primitive Properties

Click each primitive to see how it processes a 5-token sequence. Watch the mixing pattern — which tokens influence each output.

Check: Which primitive gives O(1) memory and compute per step during inference (generation)?

Chapter 2: State Space Models

A state space model (SSM) is a system borrowed from control theory. It describes how a hidden state evolves over time in response to an input signal, and how we observe the output. The continuous-time form has four matrices:

x'(t) = Ax(t) + Bu(t)      y(t) = Cx(t) + Du(t)
Symbol | Shape | What it is
x(t) | [N, 1] | Hidden state: the model's internal memory
u(t) | [1, 1] | Input signal at time t (one channel)
y(t) | [1, 1] | Output signal at time t
A | [N, N] | State transition: how the state evolves on its own
B | [N, 1] | Input projection: how input enters the state
C | [1, N] | Output projection: how we read the state
D | [1, 1] | Skip connection (often set to 0)

Think of it like this: you're in a dark room with a complex system of springs and masses (the state x). You push it with an input u. The springs respond according to their physics (matrix A). You can only observe the system through a small window (matrix C). The state captures everything happening inside, but you only see a projection.

Why SSMs for sequences? The state x(t) is a fixed-size summary of the entire input history. No matter how long the sequence, the state dimension N stays the same. This is how we escape the O(N²) curse — we never compute pairwise interactions. We just maintain and update the state.
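To see a fixed-size state respond to an input, here is a rough sketch that integrates x'(t) = Ax(t) + Bu(t) with small forward-Euler steps. The 2-state damped-spring matrices and the step size are hypothetical values chosen for illustration, and D is taken to be 0.

```python
import numpy as np


def simulate_ssm(A, B, C, u, dt):
    """Integrate x' = Ax + Bu, y = Cx with forward Euler (D = 0)."""
    x = np.zeros(A.shape[0])
    ys = []
    for u_t in u:
        x = x + dt * (A @ x + B * u_t)   # one small Euler step
        ys.append(C @ x)
    return np.array(ys), x


# Hypothetical 2-state system: a damped spring (position, velocity).
A = np.array([[0.0, 1.0], [-4.0, -0.5]])
B = np.array([0.0, 1.0])                 # the input pushes on the velocity
C = np.array([1.0, 0.0])                 # we only observe the position
dt = 0.01
u = np.zeros(2000)
u[:10] = 1.0                             # brief push, then nothing
y, x_final = simulate_ssm(A, B, C, u, dt)
```

However long `u` is, the state stays a length-2 vector. The observed output rings after the push and then decays.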

Multi-Channel Processing

In practice, a sequence model has D input channels (e.g., D=768 for a typical model dimension). Each channel gets its own independent SSM with its own A, B, C matrices. So a single SSM layer is actually D parallel SSMs, each with hidden state size N. The total hidden state is [D, N].

Input u(t)
Shape: [batch, seq_len, D]
↓ split into D channels
D Independent SSMs
Each: x'(t) = Ax(t) + Bu(t), state [N,1]
↓ recombine
Output y(t)
Shape: [batch, seq_len, D]
SSM State Evolution

A 4-dimensional hidden state responds to an input pulse. Each colored line is one component of x(t). The state integrates and decays — that's how it compresses history.

Check: In a continuous SSM, what does the matrix A control?

Chapter 3: Discretization

The continuous SSM is beautiful in theory, but computers work with discrete sequences. We need to convert x'(t) = Ax(t) + Bu(t) into a discrete recurrence x_k = Ā x_{k−1} + B̄ u_k, where Ā and B̄ are the discrete counterparts of A and B. This process is called discretization.

Step 1: Zero-Order Hold (ZOH)

The simplest approach: assume the input u(t) is constant between sample points. Over the interval of length Δ that ends at sample k, the input is held at u_k. Under this assumption, the continuous ODE has an exact solution:

x_k = e^{ΔA} x_{k−1} + (e^{ΔA} − I) A^{−1} B u_k

So Ā = e^{ΔA} and B̄ = (e^{ΔA} − I) A^{−1} B. The step size Δ controls the temporal resolution: small Δ = fine-grained sampling, large Δ = coarse sampling.

Step 2: The Bilinear Transform (Tustin's Method)

Computing matrix exponentials is expensive. The bilinear transform (also called Tustin's method or the trapezoidal rule) gives us an algebraic approximation that is both cheap and numerically stable.

The idea: approximate the derivative x'(t) using the trapezoidal rule — the average of the slope at the start and end of each interval:

x_k − x_{k−1} ≈ (Δ/2)(x'_k + x'_{k−1})

Substituting x' = Ax + Bu and solving for x_k:

Start
x_k − x_{k−1} = (Δ/2)(A x_k + B u_k + A x_{k−1} + B u_{k−1})
↓ Collect x_k terms on left
Rearrange
(I − ΔA/2) x_k = (I + ΔA/2) x_{k−1} + (Δ/2) B (u_k + u_{k−1})
↓ Approximate u_{k−1} ≈ u_k, so (Δ/2) B (u_k + u_{k−1}) ≈ ΔB u_k, then multiply both sides by (I − ΔA/2)^{−1}
Result
Ā = (I − ΔA/2)^{−1} (I + ΔA/2)    B̄ = (I − ΔA/2)^{−1} ΔB
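Why bilinear over ZOH? It avoids the matrix exponential while still preserving stability: if A has eigenvalues in the left half-plane (stable continuous system), then the discrete matrix Ā = (I − ΔA/2)^{−1}(I + ΔA/2) has eigenvalues inside the unit circle (stable discrete system). ZOH shares this stability property, so the practical difference is cost. Stability is crucial: an unstable discretization would cause the hidden state to explode.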

Worked Example: Scalar Case

Let A = −1, B = 1, Δ = 0.5. Then:

derivation
# Bilinear transform (scalar case)
A = -1, B = 1, delta = 0.5

A_bar = (1 - delta*A/2)**(-1) * (1 + delta*A/2)
      = (1 + 0.25)**(-1) * (1 - 0.25)
      = (0.8) * (0.75)
      = 0.6

B_bar = (1 - delta*A/2)**(-1) * delta * B
      = 0.8 * 0.5 * 1
      = 0.4

# Recurrence: x_k = 0.6 * x_{k-1} + 0.4 * u_k
# |A_bar| = 0.6 < 1, so the system is stable
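The worked example can be checked numerically. The sketch below reproduces the bilinear result and, for comparison, computes the zero-order-hold values for the same scalar system. The function names are hypothetical.

```python
import math


def bilinear(a: float, b: float, delta: float):
    """Scalar bilinear (Tustin) discretization."""
    inv = 1.0 / (1.0 - delta * a / 2.0)
    return inv * (1.0 + delta * a / 2.0), inv * delta * b


def zoh(a: float, b: float, delta: float):
    """Scalar zero-order-hold discretization (requires a != 0)."""
    a_bar = math.exp(delta * a)
    return a_bar, (a_bar - 1.0) / a * b


a_bil, b_bil = bilinear(-1.0, 1.0, 0.5)
a_zoh, b_zoh = zoh(-1.0, 1.0, 0.5)
print("bilinear:", a_bil, b_bil)
print("ZOH:     ", a_zoh, b_zoh)
```

Both methods give a discrete coefficient strictly between 0 and 1 for this stable system. The two values differ slightly because the bilinear transform approximates the exponential.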
Discretization Explorer

Simulation: the discrete system (stepped) approximates the continuous one (smooth) for an adjustable step size Δ and continuous parameter A. Small Δ = better approximation. Defaults: Δ = 0.50, A = −1.0.
Check: What does the step size Δ control in discretization?

Chapter 4: The HiPPO Framework

Here's a deceptively hard question: how should you initialize the matrix A? Random initialization won't work — the hidden state would either explode or forget everything instantly. We need A to do something principled: compress the input history into the state in an optimal way.

This is the problem that the HiPPO framework (High-Order Polynomial Projection Operators) solves. The insight is beautiful: choose A so that the hidden state x(t) stores the coefficients of an optimal polynomial approximation of the input history.

The analogy: Imagine you're listening to a lecture and can only keep N numbers in your head. How do you choose those N numbers to best reconstruct everything you've heard? HiPPO says: store the coefficients of the best-fit polynomial of degree N − 1, which has exactly N coefficients. It is like a Fourier transform of the past, but using Legendre polynomials instead of sines and cosines.

The HiPPO-LegS Matrix

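The most important instance is HiPPO-LegS (scaled Legendre), which approximates the entire history from time 0 to t under a uniform measure that is rescaled as t grows. (The sliding-window variant is HiPPO-LegT.) The resulting matrix has a specific structure:

A_{nk} = −(2n + 1)^{1/2} (2k + 1)^{1/2}   if n > k
A_{nk} = −(n + 1)                          if n = k
A_{nk} = 0                                 if n < k

This produces a lower-triangular matrix with entries that grow in a structured way. One key property can be read directly off the formula: because the matrix is triangular, its eigenvalues are the diagonal entries −1, −2, ..., −N, all negative, so the continuous system is stable.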

Worked Example: 4-Dimensional HiPPO

For N=4, the HiPPO-LegS matrix is:

python
import numpy as np

def hippo_legs(N):
    A = np.zeros((N, N))
    for n in range(N):
        for k in range(N):
            if n > k:
                A[n, k] = -np.sqrt(2*n+1) * np.sqrt(2*k+1)
            elif n == k:
                A[n, k] = -(n + 1)
    return A

A = hippo_legs(4)
# A = [[-1.0,   0.0,   0.0,   0.0],
#      [-1.73, -2.0,   0.0,   0.0],
#      [-2.24, -3.87, -3.0,   0.0],
#      [-2.65, -4.58, -5.92, -4.0]]
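As a sanity check on the matrix structure, the following sketch rebuilds the same matrix in vectorized form and confirms that its eigenvalues are the diagonal entries. This is an illustrative check under the formula given above, not part of the original lesson code.

```python
import numpy as np


def hippo_legs(N: int) -> np.ndarray:
    """Vectorized HiPPO-LegS matrix, using the same formula as the loop version."""
    n = np.arange(N)
    scale = np.sqrt(2 * n + 1)
    A = -np.tril(np.outer(scale, scale), k=-1)   # entries with n > k
    A -= np.diag(n + 1.0)                        # diagonal entries -(n + 1)
    return A


A = hippo_legs(4)
# A triangular matrix has its diagonal entries as eigenvalues.
eigs = np.sort(np.linalg.eigvals(A).real)
print(np.round(A, 2))
print(eigs)
```

All eigenvalues are negative, so the continuous system x' = Ax decays rather than explodes.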
HiPPO Memory: Multi-Timescale State

The 4 state components of a HiPPO-initialized SSM responding to a step input. Lower-index components (warm) react fast and decay fast. Higher-index components (purple) react slowly but retain long-range information.

Why HiPPO matters for S4: Without HiPPO initialization, SSMs could only remember ~100 steps. With HiPPO, they can model dependencies spanning 16,000+ steps. This is what enabled S4 to solve the Path-X benchmark (sequences of length 16,384) where all previous methods failed.
Check: What does the HiPPO initialization achieve?

Chapter 5: S4 — Structured State Spaces

We now have all the ingredients: a continuous SSM (Ch 2), discretization (Ch 3), and HiPPO initialization (Ch 4). The S4 architecture (Gu et al., 2022) combines them with one crucial trick: it shows that the same SSM can be computed as either a recurrence or a convolution, and you can switch between them.

The Duality: Recurrence ↔ Convolution

After discretization, our SSM is a linear recurrence:

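x_k = Ā x_{k−1} + B̄ u_k      y_k = C x_k

Here Ā and B̄ are the discretized matrices from Chapter 3. Let's unroll this recurrence starting from a zero initial state, x_{−1} = 0:

k=0
x_0 = B̄ u_0 → y_0 = C B̄ u_0
k=1
x_1 = Ā B̄ u_0 + B̄ u_1 → y_1 = C Ā B̄ u_0 + C B̄ u_1
k=2
y_2 = C Ā² B̄ u_0 + C Ā B̄ u_1 + C B̄ u_2

See the pattern? Each output y_k is a weighted sum of all past inputs, with weights C Ā^j B̄. That's a convolution! The kernel, for a sequence of length L, is:

K = (C B̄, C Ā B̄, C Ā² B̄, ..., C Ā^{L−1} B̄)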

And the output is simply y = K * u (convolution).
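The equivalence can be checked numerically. The sketch below runs a small SSM both ways and compares the outputs. The random diagonal discrete matrix and the sizes are assumptions made only to keep the example short.

```python
import numpy as np

rng = np.random.default_rng(0)
N, L = 4, 16                                  # state size, sequence length
A_bar = np.diag(rng.uniform(0.1, 0.9, N))     # assumed stable discrete matrix
B_bar = rng.standard_normal(N)
C = rng.standard_normal(N)
u = rng.standard_normal(L)

# Mode 1: recurrence, one state update per token.
x = np.zeros(N)
y_rec = np.zeros(L)
for k in range(L):
    x = A_bar @ x + B_bar * u[k]
    y_rec[k] = C @ x

# Mode 2: build the kernel K_j = C A_bar^j B_bar, then convolve with the input.
K = np.array([C @ np.linalg.matrix_power(A_bar, j) @ B_bar for j in range(L)])
y_conv = np.convolve(u, K)[:L]

print(np.max(np.abs(y_rec - y_conv)))         # agreement up to rounding error
```

Building `K` with explicit matrix powers is the slow, naive route. It is used here only to make the correspondence with the unrolled recurrence obvious.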

The dual modes:
Training: Precompute the kernel K, then convolve with the input using FFT. This is O(N log N) in the sequence length N and fully parallelizable.
Inference: Run the recurrence step by step. Each new token costs O(1) in the sequence length: just update the state x_k = Ā x_{k−1} + B̄ u_k with the discretized matrices. No need to reprocess the history.

FFT Convolution: Why It's Fast

Naive convolution of two length-N sequences is O(N²). But the convolution theorem states that convolution in the time domain equals pointwise multiplication in the frequency domain (after zero-padding, so that the FFT's circular convolution matches the linear one):

y = K * u  ⇔  Y = FFT(K) ⊙ FFT(u)

The Fast Fourier Transform computes FFT in O(N log N). So the full pipeline is: FFT the kernel (N log N), FFT the input (N log N), pointwise multiply (N), inverse FFT (N log N). Total: O(N log N) instead of O(N²).
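The pipeline above can be sketched with NumPy's real FFT. The zero-padding to length 2L is the detail that is easy to miss: without it the FFT computes a circular convolution. The function name `fft_causal_conv` is hypothetical.

```python
import numpy as np


def fft_causal_conv(K: np.ndarray, u: np.ndarray) -> np.ndarray:
    """Causal convolution of two length-L sequences in O(L log L)."""
    L = len(u)
    n_fft = 2 * L                      # zero-pad to avoid circular wrap-around
    Y = np.fft.rfft(K, n=n_fft) * np.fft.rfft(u, n=n_fft)
    return np.fft.irfft(Y, n=n_fft)[:L]


rng = np.random.default_rng(1)
K = rng.standard_normal(256)
u = rng.standard_normal(256)
y_fft = fft_causal_conv(K, u)
y_direct = np.convolve(u, K)[:256]     # O(L^2) reference
```

The two results agree up to floating-point error, while the FFT route scales as O(L log L).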

The Diagonal-Plus-Low-Rank Trick

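There's one more obstacle: computing the kernel K = (C B̄, C Ā B̄, ..., C Ā^{L−1} B̄), built from the discretized matrices Ā and B̄, naively requires L successive multiplications by Ā, which is O(LN²) for a dense N×N matrix (here N is the state size and L the sequence length). S4's key structural insight: decompose A into a diagonal matrix plus a low-rank correction (DPLR). With this structure the kernel can be computed in roughly O(N + L) operations, up to logarithmic factors, by evaluating its generating function at the roots of unity and reducing the work to Cauchy kernel computations. Later work (S4D) showed that simply using a diagonal A works almost as well, making the kernel computation trivially O(NL).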
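For the diagonal case, the kernel reduces to a weighted sum of powers, which is a Vandermonde-style product. The sketch below assumes a real-valued diagonal for simplicity (S4D itself uses complex diagonal entries) and compares the result against the naive matrix-power construction.

```python
import numpy as np


def diagonal_kernel(a_bar, B_bar, C, L):
    """K_j = sum_n C_n * a_bar_n**j * B_bar_n for j = 0..L-1, in O(N L)."""
    powers = a_bar[:, None] ** np.arange(L)[None, :]   # [N, L] Vandermonde matrix
    return (C * B_bar) @ powers


rng = np.random.default_rng(2)
N, L = 4, 32
a_bar = rng.uniform(0.1, 0.9, N)     # diagonal of an assumed stable discrete matrix
B_bar = rng.standard_normal(N)
C = rng.standard_normal(N)

K_fast = diagonal_kernel(a_bar, B_bar, C, L)
K_slow = np.array([C @ np.linalg.matrix_power(np.diag(a_bar), j) @ B_bar
                   for j in range(L)])
```

With a diagonal matrix, each state dimension contributes an independent geometric sequence, so no matrix-matrix products are needed.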

S4 Dual Mode: Convolution vs Recurrence

Simulation: two computation modes. Convolution (training): all outputs computed in parallel via FFT. Recurrence (inference): outputs computed sequentially by updating the state. Same result, different algorithm.
Check: Why can an SSM be computed as a convolution during training?

Chapter 6: Mamba — Selective SSMs

S4 was a breakthrough for long-range tasks like audio and signal processing. But when researchers tried it on language modeling, something frustrating happened: it consistently underperformed Transformers. The perplexity gap was significant — about 1-3 points worse than attention-based models of the same size.

Why? The answer lies in a task called recall.

The Recall Problem

Consider this sequence: "A 3 B 7 C 2 ... What is A?" To answer, the model needs to look back, find "A 3", and return "3". This requires the mixing matrix (how tokens influence each other) to be input-dependent. The model needs to route "A?" back to wherever "A 3" appeared.

In a Transformer, the attention matrix is input-dependent — queries and keys are computed from the input, so the model can learn to attend to the right location. In an SSM, the matrices A, B, C are fixed — the same mixing pattern applies regardless of input content. The SSM literally cannot vary its routing based on what it sees.

The fundamental limitation of LTI models: A Linear Time-Invariant system applies the same transformation to every input. It's like a filter that amplifies certain frequencies regardless of what's playing. Great for audio processing. Terrible for "look up the value associated with key A" — a task where the correct response depends entirely on the input content.

Mamba's Solution: Make B, C, and Δ Input-Dependent

Mamba (Gu & Dao, 2023) keeps the linear recurrence structure but makes three critical parameters functions of the current input u_k:

B_k = Linear(u_k)    C_k = Linear(u_k)    Δ_k = softplus(Linear(u_k))
Parameter | S4 (LTI) | Mamba (Selective) | Effect
A | Fixed (learned) | Fixed (learned) | Base dynamics unchanged
B | Fixed (learned) | Input-dependent | Model chooses what enters the state
C | Fixed (learned) | Input-dependent | Model chooses what to read from state
Δ | Fixed (learned) | Input-dependent | Model chooses its own timescale per token

The Δ Knob: Selective Forgetting

The step size Δ is particularly powerful. Remember: Δ controls how much of the continuous dynamics to apply. A large Δ means the state changes a lot — the model is "paying attention" to this token, absorbing it into memory. A small Δ means the state barely changes — the model is "ignoring" this token, letting it pass through.

This is exactly the selection mechanism: when the model sees an important token (like "A 3" in a recall task), it can increase Δ to write it into the state. When it sees filler tokens, it can decrease Δ to preserve the existing state.
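A scalar sketch makes the effect of Δ visible. It assumes a one-dimensional state with a = −1 and b = 1 and uses the zero-order-hold formulas from Chapter 3. This is an illustration of the mechanism, not Mamba's actual parameterization.

```python
import math


def selective_step(x_prev, u, delta, a=-1.0, b=1.0):
    """One scalar SSM step with a per-token step size (ZOH discretization)."""
    a_bar = math.exp(delta * a)          # how much of the old state survives
    b_bar = (a_bar - 1.0) / a * b        # how much of the input is written
    return a_bar * x_prev + b_bar * u


x = 1.0                                             # existing memory
ignored = selective_step(x, u=5.0, delta=0.001)     # small delta: state barely moves
absorbed = selective_step(x, u=5.0, delta=5.0)      # large delta: state jumps toward u
print(ignored, absorbed)
```

With a small Δ the old state is almost fully preserved, and with a large Δ it is almost fully overwritten by the new input.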

The tradeoff Mamba breaks: S4 achieved efficient computation (O(N log N) training, O(1) per step at inference) by being linear and time-invariant, but this forced the mixing to be input-independent, killing recall. Mamba keeps the O(N) recurrence but makes it input-dependent, recovering much of the expressiveness of attention without the quadratic cost. The price: you lose the convolution view (input-dependent parameters mean a different kernel for every input), so you can't use FFT. Mamba compensates with a hardware-aware parallel scan.

Hardware-Aware Parallel Scan

Without the FFT trick, how does Mamba train efficiently? It uses a parallel scan (also called prefix sum). The key insight: a linear recurrence x_k = a_k x_{k−1} + b_k can be computed in O(log N) parallel steps using a tree reduction, even though it looks sequential. This works because composing two steps of the recurrence is an associative operation.
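The idea can be sketched in NumPy with a simple doubling scan. Each round combines every position with the partial result a fixed distance earlier, so the loop runs about log2(N) times and each round is a vectorized (parallelizable) operation. This is a minimal illustration that does O(N log N) total work. It is not Mamba's fused CUDA kernel, and the function name is hypothetical.

```python
import numpy as np


def scan_linear_recurrence(a: np.ndarray, b: np.ndarray) -> np.ndarray:
    """Solve x_k = a_k * x_{k-1} + b_k (with x_{-1} = 0) in ceil(log2 N) rounds."""
    a, b = a.astype(float).copy(), b.astype(float).copy()
    shift = 1
    while shift < len(a):
        # Partner of position i is the partial result at i - shift.
        # Positions with no partner use the identity element (a=1, b=0).
        a_prev = np.concatenate([np.ones(shift), a[:-shift]])
        b_prev = np.concatenate([np.zeros(shift), b[:-shift]])
        # Compose "earlier then current": (a*a_prev, a*b_prev + b).
        a, b = a * a_prev, a * b_prev + b
        shift *= 2
    return b


rng = np.random.default_rng(3)
a = rng.uniform(0.5, 1.0, 10)
b = rng.standard_normal(10)

# Sequential reference for comparison.
x_seq, x = [], 0.0
for a_k, b_k in zip(a, b):
    x = a_k * x + b_k
    x_seq.append(x)

x_par = scan_linear_recurrence(a, b)
```

The pair `(a, b)` represents the affine map x ↦ a·x + b, and composing two such maps gives another map of the same form, which is what makes the scan possible.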

Mamba's implementation is hardware-aware: it fuses the discretization, scan, and multiplication into a single CUDA kernel that minimizes HBM reads/writes, the same philosophy as FlashAttention but applied to recurrences. The headline result reported in the Mamba paper is up to 5x higher inference throughput than Transformers of similar size.

Selective Scan: What Mamba Remembers

Simulation: a sequence of tokens enters Mamba. Important tokens (colored) get large Δ (tall bars) and are written into the state. Filler tokens get small Δ (short bars) and are mostly ignored. A selectivity setting (default 0.70) controls how strongly this affects recall.
Check: Why did S4 struggle on language modeling compared to Transformers?

Chapter 7: Architecture Comparisons

We now have a rich landscape of sequence architectures. Each makes different tradeoffs between quality, training efficiency, inference efficiency, and the ability to handle long-range dependencies. Let's put them side by side.

The Quality-Efficiency Frontier

Architecture | Training Cost | Inference (per token) | Long-Range | Recall | Language Quality
Transformer | O(N²D) | O(ND) + KV cache | Full (N²) | Excellent | Best
S4 (LTI SSM) | O(N log N · D) | O(D) | Excellent | Poor | Gap vs Transformer
Mamba | O(ND) | O(D) | Excellent | Good | Near-Transformer
Hybrids (Jamba, Zamba) | Mixed | Mixed | Good | Excellent | Best of both

When to Use What

Use Transformers when: You need the best absolute quality on language, your sequences are under 8K tokens, and you can afford the O(N²) cost. Attention is still the gold standard for tasks requiring complex reasoning and recall.
Use SSMs/Mamba when: Your sequences are very long (audio, genomics, time series), inference latency matters (O(1) per step vs O(N) for KV cache), or your deployment is memory-constrained (no growing KV cache).
Use Hybrids when: You want the best of both worlds. A few attention layers for recall-heavy positions, SSM layers for the rest. Jamba (AI21) alternates Mamba and attention layers. This is the current frontier for efficient LLMs.
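The memory argument in the guidance above can be made concrete with a back-of-the-envelope sketch. The model sizes below (24 layers, D = 768, SSM state size N = 16, fp16 values) are hypothetical, and the formulas ignore refinements such as grouped-query attention or the extra convolution state and expansion factor that real SSM blocks carry.

```python
def kv_cache_bytes(n_tokens, n_layers, d_model, bytes_per_value=2):
    """Keys + values cached for every past token in every layer."""
    return 2 * n_tokens * n_layers * d_model * bytes_per_value


def ssm_state_bytes(n_layers, d_model, state_size, bytes_per_value=2):
    """Fixed [D, N] state per layer, independent of sequence length."""
    return n_layers * d_model * state_size * bytes_per_value


# Hypothetical model: 24 layers, D = 768, SSM state size N = 16, fp16 values.
for n in (1_000, 100_000):
    kv = kv_cache_bytes(n, 24, 768)
    ssm = ssm_state_bytes(24, 768, 16)
    print(f"{n:>7,} tokens: KV cache {kv / 2**20:9.1f} MiB, "
          f"SSM state {ssm / 2**20:.2f} MiB")
```

The KV cache grows linearly with the number of tokens generated, while the SSM state stays the same size throughout.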

The Scaling Picture

At 360M parameters trained on 10B tokens of the Pile, the perplexity comparison looks like this:

Model | Type | Perplexity
Llama-style Transformer | Attention | 8.39
GPT-2-style Transformer | Attention | 8.97
LTI SSM | SSM (no gating) | 13.13
H3 (Gated SSM) | SSM + gating | 10.60
Hyena (Gated SSM) | SSM + gating | 10.11
RWKV-V5 | SSM + gating | 9.79

The gap is narrowing, but attention-based models still lead on language. The story is different on audio (S4 beats Transformers) and very long sequences (where Transformers simply can't run).

Quality-Efficiency Tradeoff

Each dot is an architecture. X-axis: inference cost. Y-axis: quality (inverse perplexity, higher = better). The ideal architecture would be in the top-left corner — high quality, low cost.

Check: What is the main advantage of hybrid architectures (e.g., Jamba) over pure SSM models?

Chapter 8: Connections

We've traveled from the quadratic bottleneck of attention, through the primitives of sequence modeling, into the elegant mathematics of state space models, and emerged with architectures that process sequences in linear time. Here's the cheat sheet.

The big picture in one sentence: SSMs compress sequence history into a fixed-size state using linear recurrences. S4 made them practical with HiPPO initialization and convolution-recurrence duality. Mamba made them expressive by letting the model choose what to remember.
Concept | One-Line Summary
State Space Model | x'(t) = Ax(t) + Bu(t); compress history into fixed-size state
Discretization | Convert continuous ODE to discrete recurrence via bilinear transform
HiPPO | Initialize A to store optimal polynomial approximation of history
Conv-Recurrence Duality | Unrolled linear recurrence = convolution; FFT makes it O(N log N)
S4 | HiPPO + DPLR structure + dual mode (conv for training, recurrence for inference)
Recall Problem | LTI models can't do content-based lookup; need input-dependent mixing
Mamba | Input-dependent B, C, Δ; hardware-aware parallel scan; selective forgetting
Hybrids | Mix SSM + attention layers for the best quality-efficiency tradeoff

Key Takeaways from CS 229s

  1. Efficiency is not just about FLOPs. Convolutions are O(N log N) vs attention's O(N²), but attention solves recall more efficiently. Think about efficiency per task, not just per sequence length.
  2. Error analysis drives architecture design. The discovery that SSMs fail at recall came from carefully comparing prediction errors between attention and SSM models. This led directly to Mamba.
  3. New primitives need new implementations. Just as FlashAttention made attention hardware-efficient, Mamba needed a custom CUDA kernel for its parallel scan. The algorithm-hardware co-design is essential.

What's Next

The field is moving fast. Mamba-2 introduces the State Space Duality framework (connecting SSMs to structured attention). Griffin and Hawk explore gated linear recurrences. Based and Gated Linear Attention close the gap from the attention side. The boundary between "SSM" and "attention" is blurring.

Closing thought: "Over the past 6 years, we've learned how to train Transformers — what width, depth, hyperparameters work well. We need the same legwork for the alternate models." — CS 229s Lecture 09. The architectures are here. The systems engineering is just beginning.