Behrouz, Razaviyayn, Zhong, Mirrokni — Google, NeurIPS 2025

Nested Learning

The Illusion of Deep Learning Architecture — a neural network + its optimizer is not a flat stack of layers. It's a system of nested, multi-level optimization problems, each with its own context flow and update frequency.

Prerequisites: Transformers + Gradient descent + Attention mechanisms
10 Chapters · 5+ Simulations

Chapter 0: The Problem

You've just trained a 70-billion-parameter language model. It took months of GPU time, petabytes of data, and millions of dollars. The result is extraordinary — it can reason, code, write poetry. But there's a problem so fundamental it's almost embarrassing.

The model can't learn anything new.

After pre-training, the weights are frozen. The model is like a patient with anterograde amnesia — it remembers everything from training (distant past encoded in MLP weights), and it can process what's in its current context window (the immediate present via attention), but it cannot form new long-term memories. It is continuously experiencing the present as if it were always new.

The two-timescale trap: A Transformer has exactly two timescales of memory. MLP weights operate at frequency 0 — updated only during pre-training, then frozen forever. Attention operates at frequency ∞ — recomputed from scratch for every single token, with no persistence. There is nothing in between. No medium-term memory. No gradual consolidation. Just two extremes.

Compare this to the human brain. Neuroscience has identified at least four distinct neural oscillation bands, each operating at a different timescale:

The brain doesn't have two timescales. It has a spectrum. Fast oscillations for immediate reactions, slow oscillations for deep consolidation, and everything in between for the gradual work of learning. This is neuroplasticity — the brain's ability to continuously rewire itself at multiple speeds simultaneously.

Static LLM vs. Brain Multi-Timescale

Left: a Transformer with only two extremes (frozen MLP, ephemeral attention). Right: the brain's spectrum of oscillation frequencies. The gap in the middle is what Nested Learning aims to fill.

The paper's thesis: this isn't just a missing feature. It's a fundamental architectural limitation. And the reason we haven't fixed it is that we've been thinking about neural networks wrong. We treat them as flat stacks of layers — input flows in, passes through layers, output comes out. But the paper argues that the real structure is nested: layers within layers, optimizers within optimizers, each operating at its own frequency. Once you see the nesting, the path to multi-timescale memory becomes obvious.

Why can't a standard Transformer form new long-term memories after pre-training?

Chapter 1: Everything is Associative Memory

Before we can understand nested learning, we need one powerful unifying idea: learning is memory acquisition, and memory is a neural update caused by input. That's it. Every time a neural network changes in response to data — whether during training, inference, or anything in between — it's forming a memory.

What is an associative memory?

An associative memory is a function that maps keys to values. You give it a key, it returns the associated value. Think of a dictionary, a hash table, or — most relevantly — a weight matrix in a neural network.

Formally, the optimal associative memory M* solves:

M* = argmin_M L̃(M(K); V)

where K is a set of keys, V is a set of values, and L̃ is a loss measuring how well M maps each key to its corresponding value.
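
As a minimal sketch of this definition, assume M is a linear map and L̃ is squared error. Both are assumptions made for illustration; the text leaves them general. Under those assumptions the optimal memory has a closed-form least-squares solution, and the helper name fit_linear_memory is hypothetical.

```python
import numpy as np

def fit_linear_memory(K, V):
    """Return M minimizing ||M @ K - V||_F^2.

    K: (d_k, n) matrix whose columns are keys.
    V: (d_v, n) matrix whose columns are the matching values.
    """
    # M @ K = V  <=>  K.T @ M.T = V.T, solved in the least-squares sense.
    M_T, *_ = np.linalg.lstsq(K.T, V.T, rcond=None)
    return M_T.T

rng = np.random.default_rng(0)
K = rng.normal(size=(8, 5))   # 5 keys of dimension 8
V = rng.normal(size=(3, 5))   # 5 values of dimension 3
M = fit_linear_memory(K, V)
recalled = M @ K              # query the memory with the stored keys
```

Because there are fewer keys than key dimensions here, the keys are linearly independent and every stored value is recalled exactly. With more keys than dimensions, recall would only be approximate.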

Training IS associative memory

Here's the key insight. Consider training a simple linear layer W with SGD on a single data point x with loss L:

W_{t+1} = W_t − η ∇_y L(W_t; x_{t+1}) ⊗ x_{t+1}

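Look at what this says. The update to W is an outer product: the gradient of the loss with respect to the layer's output (the "surprise signal", which measures how wrong the prediction was) combined with the input through a tensor product. The weight matrix W is literally storing key-value pairs, where the key is the input x_{t+1} and the value is its surprise signal ∇_y L(W_t; x_{t+1}).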

Gradient = surprise: The gradient ∇L is not just a mathematical convenience for optimization. It's a signal that encodes how surprised the network was by this input. A large gradient means high surprise (the prediction was very wrong). A zero gradient means no surprise (the prediction was perfect). Training a neural network is building an associative memory that maps inputs to their surprise signals.
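
The outer-product structure can be checked numerically. The sketch below assumes a linear layer y = W x and a squared-error loss against a target vector; the loss, the target, and the function name sgd_step_linear are assumptions chosen for illustration.

```python
import numpy as np

def sgd_step_linear(W, x, target, lr):
    """One SGD step for y = W @ x with loss 0.5 * ||y - target||^2."""
    y = W @ x
    surprise = y - target            # gradient of the loss w.r.t. the output y
    grad_W = np.outer(surprise, x)   # chain rule: dL/dW = surprise (outer) x
    return W - lr * grad_W, surprise

rng = np.random.default_rng(0)
W = rng.normal(size=(3, 4))
x = rng.normal(size=4)
target = rng.normal(size=3)
W_new, surprise = sgd_step_linear(W, x, target, lr=0.1)

# The change in W is a rank-1 matrix: value (surprise) times key (input).
delta = W_new - W
```

A perfect prediction gives a zero surprise vector, so the step writes nothing into W.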

The proximal view

There's an equivalent way to write the SGD update that makes the associative memory structure even more explicit. Instead of the gradient step, we can write it as an optimization problem:

W_{t+1} = argmin_W ⟨W x_{t+1}, u_{t+1}⟩ + (1/(2η)) ||W − W_t||²

where u_{t+1} = ∇_y L(W_t; x_{t+1}) is the surprise signal. The first term says "move the output W x against the surprise u", that is, in the direction that reduces the loss. The second term says "don't change W too much from its current value." This is a proximal operator, and it reveals that every SGD step is solving a tiny optimization problem: find the memory update that best balances learning the new input against preserving existing knowledge.
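
The equivalence between the proximal problem and the plain SGD step can be verified directly. This sketch uses a random vector as a stand-in for the surprise signal u, and the function names are hypothetical. It checks that the gradient of the proximal objective vanishes at the SGD iterate.

```python
import numpy as np

def proximal_objective(W, W_t, x, u, lr):
    """<W x, u> + (1 / (2 * lr)) * ||W - W_t||_F^2."""
    return (W @ x) @ u + np.sum((W - W_t) ** 2) / (2 * lr)

def proximal_gradient(W, W_t, x, u, lr):
    """Gradient of the objective above with respect to W."""
    return np.outer(u, x) + (W - W_t) / lr

rng = np.random.default_rng(1)
W_t = rng.normal(size=(3, 4))
x = rng.normal(size=4)
u = rng.normal(size=3)     # stands in for the surprise signal
lr = 0.05

W_sgd = W_t - lr * np.outer(u, x)                  # ordinary SGD step
grad_at_sgd = proximal_gradient(W_sgd, W_t, x, u, lr)

# The objective is strictly convex, so a zero gradient marks its unique minimizer.
perturbed = W_sgd + 1e-2 * rng.normal(size=W_t.shape)
f_sgd = proximal_objective(W_sgd, W_t, x, u, lr)
f_perturbed = proximal_objective(perturbed, W_t, x, u, lr)
```

Setting the gradient u xᵀ + (W − W_t)/η to zero and solving for W gives W = W_t − η u xᵀ, which is the SGD update.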

Training as Surprise-Memory Mapping

Each input x produces a gradient (surprise signal). The weight matrix W stores these key-value associations. Click inputs to see their surprise signals and how W accumulates memories.

In the Nested Learning framework, what does the gradient ∇L represent beyond a direction for optimization?

Chapter 2: Optimizers as Associative Memories

If training with SGD is associative memory, what about fancier optimizers? This is where the "nested" part begins.

SGD: 1-level memory

Plain SGD maps data directly to surprise signals and stores them in W. One level of memory. One optimization problem. Simple.

W_{t+1} = W_t − η ∇L(W_t; x_{t+1})

SGD with momentum: 2-level nested memory

Now add momentum. The standard update is:

m_{t+1} = β m_t + η ∇L   (inner level: accumulate gradients)
W_{t+1} = W_t − m_{t+1}   (outer level: apply accumulated update)

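This is two nested associative memories: an inner one, the momentum buffer m, which accumulates the stream of gradients, and an outer one, the weights W, which apply the accumulated update.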

The momentum buffer is itself an associative memory with its own parameters (m), its own update rule (exponential moving average), and its own "context" (the stream of gradients). The model weights W sit at a lower level, receiving processed information from above.
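
A short sketch makes the two memories explicit. The constant gradient is an assumption used only so that the buffer's contents can be compared against a closed form; in real training the gradient changes at every step.

```python
import numpy as np

def momentum_step(W, m, grad, lr, beta):
    """One step of SGD with momentum, written as two memories."""
    m_new = beta * m + lr * grad   # inner memory: compresses the gradient stream
    W_new = W - m_new              # outer memory: absorbs the compressed signal
    return W_new, m_new

W = np.zeros(2)
m = np.zeros(2)
g = np.array([1.0, -2.0])          # a constant gradient, for illustration only
lr, beta, steps = 0.1, 0.9, 50
for _ in range(steps):
    W, m = momentum_step(W, m, g, lr, beta)

# With a constant gradient the buffer follows a geometric series.
expected_m = lr * g * (1 - beta**steps) / (1 - beta)
```

The buffer m never sees the data directly. Its only input is the sequence of gradients, which is what "its own context" means here.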

Adam: optimal element-wise memory

Adam goes further. It maintains both first moments (mean of gradients) and second moments (mean of squared gradients):

m_{t+1} = β_1 m_t + (1 − β_1) ∇L   (first moment)
v_{t+1} = β_2 v_t + (1 − β_2) (∇L)²   (second moment)
W_{t+1} = W_t − η m_{t+1} / (√(v_{t+1}) + ε)

The paper shows that Adam's update rule is actually the optimal associative memory for element-wise L2 regression on the gradient stream. The second moment v acts as a normalizer — it's a memory of how volatile each parameter's gradient has been, used to scale the update appropriately.
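
The normalizing role of v can be seen in a single step. The sketch below follows the three equations above as written, so it omits the bias-correction terms found in standard Adam implementations. The hyperparameter defaults are common choices, not values taken from the paper.

```python
import numpy as np

def adam_step(W, m, v, grad, lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8):
    """One Adam step as written above (bias correction omitted, as in the text)."""
    m_new = beta1 * m + (1 - beta1) * grad        # first moment
    v_new = beta2 * v + (1 - beta2) * grad**2     # second moment
    W_new = W - lr * m_new / (np.sqrt(v_new) + eps)
    return W_new, m_new, v_new

W = np.zeros(3)
m = np.zeros(3)
v = np.zeros(3)
grad = np.array([0.01, 1.0, -100.0])   # gradients on very different scales
W1, m1, v1 = adam_step(W, m, v, grad)
step = W1 - W
```

Although the three gradients differ by four orders of magnitude, the three step sizes come out nearly identical, because each coordinate is divided by its own gradient scale.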

The deep insight: Optimizers and architectures are the same thing at different levels of the nesting hierarchy. An attention head computes over a sequence of tokens. A momentum buffer computes over a sequence of gradients. Both are associative memories. Both map keys to values. The only difference is what they operate on and when they update. We've been treating "architecture" and "optimizer" as separate concepts, but they're instances of the same abstraction.
Why is SGD with momentum a "2-level nested memory" rather than a 1-level memory like plain SGD?

Chapter 3: Architectures as Nested Optimization

This is the core of the paper — the "aha" moment. We'll show that a Transformer block isn't just a stack of layers. It's a system of nested optimization problems operating at different frequencies.

Linear attention as gradient descent

Consider linear attention (no softmax). At each step t, the model maintains a memory matrix Mt and updates it:

M_{t+1} = M_t + v_{t+1} k_{t+1}^T

The output at query time is:

o_{t+1} = M_{t+1} q_{t+1}

Now here's the key observation. The memory update M_{t+1} = M_t + v k^T is exactly a single step of gradient descent on the objective:

L̃(M) = −⟨Mk, v⟩

This is a dot-product objective: find M such that Mk is aligned with v. The gradient is −v k^T, and subtracting it (with step size 1) gives exactly M + v k^T.

Linear attention = inner optimization: Every token processed by linear attention is a gradient step on an inner optimization problem. The "architecture" is performing optimization at inference time. The keys and values that define this inner problem come from projections W_k, W_v, W_q that were learned during pre-training, which is the outer optimization. Two levels of optimization, nested.
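
The claim can be tested by writing linear attention two ways: as a recurrence that takes one gradient step per token, and as the familiar masked matrix product without softmax. The random keys, values, and queries below stand in for the outputs of the learned projections, and the function names are hypothetical.

```python
import numpy as np

def linear_attention_recurrent(keys, values, queries):
    """Causal linear attention via the memory recurrence M <- M + v k^T."""
    d_v, d_k = values.shape[1], keys.shape[1]
    M = np.zeros((d_v, d_k))
    outputs = []
    for k, v, q in zip(keys, values, queries):
        grad = -np.outer(v, k)     # gradient of -<M k, v> with respect to M
        M = M - 1.0 * grad         # gradient descent with step size 1
        outputs.append(M @ q)
    return np.stack(outputs)

def linear_attention_parallel(keys, values, queries):
    """Same computation as masked (Q K^T) V without softmax."""
    scores = np.tril(queries @ keys.T)   # causal mask: token t sees tokens <= t
    return scores @ values

rng = np.random.default_rng(0)
T, d_k, d_v = 6, 4, 3
keys = rng.normal(size=(T, d_k))
values = rng.normal(size=(T, d_v))
queries = rng.normal(size=(T, d_k))
out_rec = linear_attention_recurrent(keys, values, queries)
out_par = linear_attention_parallel(keys, values, queries)
```

Both forms produce the same outputs, so the attention computation and the inner gradient descent are the same procedure written in two notations.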

Decomposing a Transformer block

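A standard Transformer block has two components: an attention sublayer and an MLP sublayer. In the nested learning view, the MLP is the slow level (frequency 0 once pre-training ends) and attention is the fast level (frequency ∞, recomputed for every token).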

Nested Optimization Levels in a Transformer

Interactive decomposition showing how MLP, attention, and the optimizer form nested levels at different frequencies. Click each level to see its role, update frequency, and what it stores. The gap in the middle is the "missing spectrum."

The critical observation: Transformers use only two extreme frequencies. Frequency 0 (MLP — never updates) and frequency ∞ (attention — always recomputes). The brain uses a continuous spectrum. This two-frequency limitation is why Transformers can't form medium-term memories — there's no component that updates at an intermediate rate.

The illusion of depth: We call a 96-layer Transformer "deep." But in the nested learning view, its depth is only 2 — two levels of optimization (training-time for MLP, inference-time for attention). The 96 layers are repeated applications of the same two levels. True depth would mean more levels of nesting, each at a different update frequency. The "depth" we see is width in disguise.
In the nested learning framework, why is a 96-layer Transformer considered to have "depth 2" rather than "depth 96"?

Chapter 4: Update Frequency

The nested learning framework introduces a precise concept for what we've been gesturing at: the update frequency f_A of a component A.

Defining update frequency

f_A measures how often a component's parameters change per unit of input processing. Components are ordered into levels by their frequency: higher frequency means lower level (faster, more reactive), lower frequency means higher level (slower, more persistent).
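
One way to make the definition concrete is to count updates per token. In this sketch the Component class, the period field, and the specific periods are all hypothetical; the paper does not prescribe them. The point is only that frequency is updates divided by tokens processed, and that sorting by frequency yields the level ordering.

```python
from dataclasses import dataclass

@dataclass
class Component:
    name: str
    period: int          # hypothetical: update once every `period` tokens
    updates: int = 0

    def maybe_update(self, step: int) -> None:
        if step % self.period == 0:
            self.updates += 1

def update_frequencies(components, num_tokens):
    """Process num_tokens tokens and return updates per token for each component."""
    for step in range(1, num_tokens + 1):
        for c in components:
            c.maybe_update(step)
    return {c.name: c.updates / num_tokens for c in components}

components = [
    Component("fast", period=1),
    Component("medium", period=16),
    Component("slow", period=256),
]
freqs = update_frequencies(components, num_tokens=1024)
levels = sorted(freqs, key=freqs.get, reverse=True)  # fastest (lowest level) first
```

A frozen component corresponds to a period longer than the whole input stream, which gives a frequency of 0.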

The brain analogy

Map this to neural oscillations:

Frequency Spectrum: Brain vs. Transformer

The brain covers a continuous spectrum of frequencies. Transformers occupy only the two extremes. Nested Learning fills the gap with intermediate-frequency components.

The paper's prescription is clear: to build models that learn like brains, we need components at intermediate frequencies. Not just frozen weights and ephemeral attention, but parameters that update at medium rates — fast enough to adapt to new information, slow enough to retain it beyond a single context window.

Why the spectrum matters: A system with only extreme frequencies has a critical failure mode. Knowledge either persists forever (MLP weights) or vanishes instantly (attention). There's no mechanism for gradual consolidation — the process by which the brain converts working memory to long-term memory over hours and days. The missing middle frequencies are exactly what enables learning from experience.
What does "update frequency" measure in the nested learning framework?

Chapter 5: Knowledge Transfer Between Levels

In a nested system, components at different levels need to communicate. How does knowledge flow between a slow, persistent memory and a fast, ephemeral one? The paper identifies three mechanisms.

1. Parametric direct connection

A lower-frequency memory M^{(0)} can be directly conditioned on parameters from a higher-frequency level:

M^{(0)}(·; Θ^{(1)})

The slow memory's behavior is shaped by the fast memory's parameters. Example: the MLP weights in a Transformer were shaped by the optimizer (a higher-frequency process) during training. The MLP's "behavior" (what it computes) is a function of the entire training history, compressed into Θ.

2. Non-parametric conditioning

Softmax attention is a non-parametric solution — it doesn't have persistent parameters of its own, but its output is conditioned on the current context (keys and values from the input sequence). The "context" serves as a compressed representation from a different level.

3. Context generation

The architecture generates gradients for the optimizer. Think about what backpropagation does: the forward pass (architecture) produces predictions, the loss function computes error, and the backward pass generates gradients. These gradients are the context that the optimizer operates on. The architecture doesn't just process data — it generates the information that drives its own learning.

Inter-level flow determines capability: The types of connections between levels determine what the model can and cannot learn. A system where the fast level can write to the slow level (like the optimizer writing to weights) can consolidate short-term patterns into long-term knowledge. A system where the slow level can only read from the fast level (like a frozen MLP reading attention's output) cannot. The directionality of knowledge flow is everything.

This framing explains a subtle point about why some architectures generalize better than others. Architectures with richer inter-level connections can transfer knowledge between timescales — fast-adapting components can discover patterns that slowly get consolidated into persistent parameters. Architectures with poor inter-level flow (like standard inference-time Transformers) are stuck with whatever was learned during training.

What role does backpropagation play in the nested learning framework?

Chapter 6: Revisiting ICL & Pre-training

The nested learning framework lets us reinterpret three fundamental concepts in deep learning. What seemed like separate phenomena turn out to be instances of the same structure.

In-context learning is nested levels

In-context learning (ICL) — the ability of LLMs to solve new tasks from examples in the prompt — has been mysterious. Why can a model that was trained on next-token prediction suddenly do few-shot classification?

The nested learning answer: ICL is a consequence of having multiple nested optimization levels. The examples in the prompt serve as "training data" for the attention mechanism's inner optimization. Attention is performing gradient descent on the dot-product objective, and the in-context examples provide the keys and values for this inner optimization. ICL is not a property of attention specifically — it's a property of any multi-level system where the inner level can adapt at inference time.

ICL demystified: A model has in-context learning capability if and only if it has at least two nested optimization levels: one that sets the stage (pre-trained weights) and one that adapts within the context (attention or equivalent). The "magic" of ICL is just the inner optimizer doing its job.
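
A toy version of this inner optimizer shows the mechanism. The sketch assumes linear attention and a hypothetical prompt whose demonstration keys are orthonormal. That assumption makes retrieval exact; with overlapping keys the stored pairs would interfere with one another.

```python
import numpy as np

def in_context_memory(keys, values):
    """Build the linear-attention memory from prompt examples only."""
    M = np.zeros((values.shape[1], keys.shape[1]))
    for k, v in zip(keys, values):
        M += np.outer(v, k)     # one inner gradient step per example
    return M

# Hypothetical prompt: three (key, value) demonstrations with orthonormal keys.
keys = np.eye(4)[:3]                        # shape (3, 4)
values = np.array([[1.0, 0.0],
                   [0.0, 1.0],
                   [2.0, 2.0]])
M = in_context_memory(keys, values)

query = keys[2]                             # ask about the third demonstration
answer = M @ query
```

No pre-trained weight changes anywhere in this computation. All of the "learning" lives in M, which is rebuilt from the prompt and discarded afterwards.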

Pre-training IS in-context learning

Here's the mind-bending part. Pre-training looks different from ICL — one happens during training, the other during inference. But in the nested learning view, they're the same process at different levels:

The only difference is the timescale. Pre-training's "context window" is the entire training set processed over weeks. ICL's "context window" is the prompt processed in milliseconds. Same mechanism, different frequency.

Continual learning = multi-frequency contexts

Continual learning — learning from a sequence of tasks without forgetting previous ones — becomes natural in the nested view. Different levels compress different timescales of context:

Catastrophic forgetting happens because standard models lack medium-frequency levels. Everything is either permanent (MLP weights, can't adapt) or ephemeral (attention, can't remember). With a full spectrum of frequencies, knowledge can cascade from fast to slow levels, achieving continual learning naturally.

The unified view: Pre-training, in-context learning, and continual learning are not three separate capabilities. They are the same process — nested optimization — operating at different timescales. The only difference is the update frequency of the level doing the learning. This is the paper's deepest insight.
How does the nested learning framework explain the relationship between pre-training and in-context learning?

Chapter 7: Continuum Memory System

Traditional cognitive science divides memory into two discrete buckets: short-term memory (small capacity, fast access, quickly forgotten) and long-term memory (large capacity, slower access, persistent). This maps directly onto the Transformer: attention is short-term, MLP weights are long-term.

But the paper argues this binary is wrong — for both brains and models.

From two buckets to a spectrum

Instead of two discrete memory types, the paper proposes a continuum memory system: a spectrum of memory components, each with a different update frequency.

Continuum Memory vs. Discrete Buckets

Left: the traditional two-bucket model (short-term / long-term). Right: the continuum memory system with a spectrum of frequencies. Toggle to see how information flows differently in each model.

The memory loop

A continuum memory system has a remarkable property: knowledge can be partially recovered when forgotten. Here's how:

  1. A fast-frequency component learns a pattern from input and begins to forget it
  2. Before it fully decays, a medium-frequency component has partially absorbed the pattern
  3. When the fast component encounters similar input later, the medium component's retained signal reinforces the fast component's new learning
  4. Meanwhile, the slowest components gradually accumulate the most persistent patterns

This creates a memory loop: information cascades from fast to slow, and slow memories provide a "prior" that helps fast memories form more efficiently on related inputs. In a two-bucket system, there's no cascade — information either stays in the fast bucket (and vanishes) or makes it to the slow bucket (and persists). The middle of the spectrum is what makes the cascade possible.

Think of it this way: Imagine you hear a new word once. Your fastest "memory" (attention-like) registers it momentarily. In a Transformer, that's it — the word vanishes when the context ends. But in a continuum system, a medium-speed memory has been slightly nudged by this word. The next time you hear it, that slight nudge means you process it faster. And a slower memory gets nudged in turn. After a few exposures, the word has cascaded down to persistent memory. No single exposure was enough — it was the gradual cascade across the frequency spectrum that did it.
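
The cascade can be illustrated with a toy chain of memories. This is not the paper's architecture: each memory here is just an exponential moving average that tracks the level before it, and the three rates are hypothetical values chosen to separate the timescales.

```python
import numpy as np

def cascade_step(memories, x, rates):
    """Update a chain of memories; each level tracks the level before it."""
    source = x
    new = []
    for mem, rate in zip(memories, rates):
        mem = (1 - rate) * mem + rate * source   # larger rate = faster level
        new.append(mem)
        source = mem                             # next level reads this one
    return new

rates = [0.5, 0.05, 0.005]          # hypothetical fast, medium, slow rates
memories = [np.zeros(2) for _ in rates]
pattern = np.array([1.0, -1.0])

for _ in range(20):                 # the pattern is present in the input
    memories = cascade_step(memories, pattern, rates)
for _ in range(20):                 # the pattern disappears
    memories = cascade_step(memories, np.zeros(2), rates)

strength = [float(np.linalg.norm(m)) for m in memories]
```

After the pattern disappears, the fast memory has essentially forgotten it, while the medium and slow memories still hold a trace that later exposures can build on.
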
What advantage does a continuum memory system have over the two-bucket (short-term/long-term) model?

Chapter 8: The Self-Modifying Sequence Model

The paper's culminating proposal: combine the continuum memory system with self-referential updates. The model doesn't just have multiple frequency levels — it learns its own update algorithm.

NSAM: Neural Self-referential Associative Memory

The Neural Self-referential Associative Memory (NSAM) is a concrete architecture that implements the nested learning principles:

The key innovation: the model's update algorithm is not handcrafted (like SGD or Adam). It's learned. The outer optimization (pre-training) discovers update rules for the inner levels that are tailored to the model's specific task and data distribution.

What this enables

The combination of continuum memory + learned updates produces several capabilities that standard Transformers lack:

Self-modification: In a standard Transformer, the architecture is fixed — the same operations happen for every input. In NSAM, the model modifies its own processing based on what it's seen. This is not the same as fine-tuning (which requires external gradient computation). The model's own forward pass updates its internal state. It's a self-modifying program — a concept that goes back to von Neumann and Schmidhuber's early work on self-referential machines.

Early results

The paper reports promising results on benchmarks designed to test exactly these capabilities. NSAM-based models show improvements over standard Transformers on tasks requiring memory beyond the context window, tasks with distributional shift, and tasks requiring rapid adaptation. The results are preliminary but directionally strong — the framework correctly predicts which architectures will succeed on which tasks.

What makes NSAM different from a standard Transformer with a learned optimizer (like learned learning rates)?

Chapter 9: Connections

The nested learning framework doesn't exist in isolation. It unifies and extends several important lines of research.

Meta-learning (MAML)

Model-Agnostic Meta-Learning (Finn et al., 2017) is a 2-level nested system: the outer loop learns an initialization θ*, the inner loop fine-tunes on a specific task. In the NL view, MAML is the simplest non-trivial nested system — exactly two levels. NL generalizes MAML to arbitrary numbers of levels, and identifies that the "meta" and "base" levels are just different points on the frequency spectrum.
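
The two-level structure is easy to see on a toy problem. The sketch below assumes scalar tasks with loss 0.5 (θ − c)² and a single inner gradient step. The task centers and step sizes are hypothetical, and the outer gradient is derived by hand for this specific loss rather than by automatic differentiation.

```python
import numpy as np

def inner_adapt(theta, center, inner_lr):
    """Inner loop: one gradient step on the task loss 0.5 * (theta - center)^2."""
    return theta - inner_lr * (theta - center)

def outer_gradient(theta, center, inner_lr):
    """Gradient of the post-adaptation loss with respect to the initialization."""
    adapted = inner_adapt(theta, center, inner_lr)
    # d(adapted)/d(theta) = 1 - inner_lr, by differentiating through the inner step.
    return (adapted - center) * (1 - inner_lr)

task_centers = np.array([-1.0, 0.5, 2.0])   # hypothetical toy tasks
theta, inner_lr, outer_lr = 5.0, 0.3, 0.5
for _ in range(200):                         # outer loop: slow level
    grads = [outer_gradient(theta, c, inner_lr) for c in task_centers]
    theta -= outer_lr * np.mean(grads)
```

The outer loop changes θ once per sweep over the tasks, while the inner step runs once per task. Those are two update frequencies, which is what makes this a two-level nested system.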

Hypernetworks

Hypernetworks (Ha et al., 2016) use one network to generate the weights of another. This is a special case of inter-level knowledge transfer: the hypernetwork operates at a higher frequency (it can change its output) and the generated network operates at a lower frequency (its weights are determined by the hypernetwork). NL reveals that hypernetworks are doing inter-level parametric conditioning.

Fast Weight Programmers

Schmidhuber's Fast Weight Programmers (1992) are an early version of multi-level memory. A "slow" network generates weights for a "fast" network that processes the immediate input. This is a 2-level nested system with explicit frequency separation. The paper acknowledges this as a key predecessor.

Loop Transformers

Loop Transformers share parameters across layers and iterate, creating a system where each "loop" is an optimization step. In the NL view, looping adds a level of nesting: the shared parameters are the slow level, and each loop iteration is a fast-level optimization step. More loops = more inner optimization steps.

Titans

Titans (Behrouz et al., 2025) — modern RNN-like architectures with persistent memory — are explicitly designed as multi-level nested systems. They have a fast attention-like component, a medium-frequency recurrent memory, and slow persistent parameters. Titans are a direct implementation of the NL framework.

Continual Learning Survey

The continual learning literature identifies five method families (regularization, replay, architecture, representation, optimization). NL provides a unifying framework: each method family corresponds to a different strategy for managing inter-level knowledge flow. Regularization constrains how much slow levels change. Replay replays context from previous tasks. Architecture-based methods explicitly add levels. The NL lens reveals these as variations on the same theme.

Cheat sheet — key definitions:
  • Associative memory: A function mapping keys → values, optimized to minimize reconstruction loss
  • Update frequency (f_A): How often a component's parameters change per unit of input
  • Nested system: Multiple associative memories at different frequencies, with inter-level knowledge transfer
  • NSAM: Neural Self-referential Associative Memory — architecture with learnable multi-frequency update rules
  • Context: The information one level passes to another (e.g., gradients from architecture to optimizer)
  • Level: A component's position in the frequency hierarchy (slow = high level, fast = low level)
How does nested learning relate to meta-learning (MAML)?