Based on Thinking Machines Lab · Jeremy Bernstein · Sep 2025

Modular manifolds.

How to keep neural network weights healthy by constraining them to curved surfaces — deriving manifold optimization from first principles, building Manifold Muon, and composing modular learning rate budgets across entire architectures.

SOURCE: Thinking Machines Blog · DEPTH: first-principles derivation · MATH: SVD, Lagrange duality, Stiefel geometry

00 Concept constellation

Every concept in this lesson and how they connect — the territory before the map.

This lesson unpacks a single blog post into 18 interconnected concepts across four topic clusters: Geometry, Optimization, Theory, and Systems.

00 Concept index

Every concept you’ll encounter, sorted by cluster.

Geometry

  • Manifold: A curved surface that is locally flat. The constraint surface for weights. Ch 02.
  • Tangent Space: Local flat approximation to a manifold at a point. Ch 02.
  • Retraction Map: Snaps a tangent step back onto the manifold. Ch 02.
  • Hypersphere: The simplest manifold: ||w||=1. A circle in 2D. Ch 02.
  • Stiefel Manifold: Matrices with orthonormal columns: W^T W = I. Ch 04.

Optimization

  • Gradient Descent: The baseline: step opposite the gradient in Euclidean space. Ch 03.
  • Muon: Matrix-aware optimizer using spectral norm. Ch 05.
  • Manifold Muon: Muon constrained to the Stiefel manifold via dual ascent. Ch 05.
  • Dual Ascent: Optimizing the Lagrange dual to find the constrained update. Ch 05.
  • Learning Rate: Step size on the manifold, reinterpreted as a distance budget. Ch 03.

Theory

  • SVD: Singular Value Decomposition: how matrices stretch space. Ch 04.
  • Spectral Norm: Largest singular value: max stretching of a matrix. Ch 04.
  • Nuclear Norm: Sum of singular values: the dual norm to spectral. Ch 05.
  • Lipschitz Sensitivity: How much output changes per unit input change. Ch 07.
  • eNTK: Empirical neural tangent kernel: local linearization. Ch 07.

Systems

  • Module Abstraction: Triple: (forward, manifold, norm). The building block. Ch 06.
  • Composition Rules: How to combine modules: product manifold + max norm. Ch 06.
  • Budget Scaling: Per-layer learning rate from architecture-aware norms. Ch 06.

00 Reading guide

This lesson is structured in layers. You can read linearly or skip around.

  • Chapters 01–02: The why and the geometry — what goes wrong without constraints, and the mathematical framework for fixing it.
  • Chapters 03–05: The algorithms — from the optimizer table to Manifold Muon, derived step by step.
  • Chapters 06–07: The system — composing modules into architectures, and connections to the broader literature.

If you already know Riemannian optimization basics, skip to Chapter 04. If you only care about the practical algorithm, jump to Chapter 05.

01 The health problem

Tensors grow or shrink during training. Left unchecked, this kills learning.

You’re training a deep network. After a thousand steps, you check the weight matrices in different layers. Some have grown enormous — their entries are 100x what they started at. Others have collapsed toward zero. Neither is good.

Large weights cause exploding gradients: the loss landscape becomes a narrow ravine, and SGD bounces between the walls instead of descending. Small weights cause vanishing gradients: the signal attenuates as it passes through near-zero multipliers, and early layers stop learning.

This isn’t a theoretical concern. It’s one of the most common failure modes of deep network training, and the reason we have an entire zoo of tricks to manage it.

The core insight: we have normalization for activations (LayerNorm, BatchNorm) and normalization for gradients (Adam, Muon). But we don’t have a principled way to keep the weights themselves well-behaved. Manifold constraints fill this gap.

01 The normalization landscape

Let’s map what exists today and identify the gap:

| What’s normalized | Method | How it works |
| --- | --- | --- |
| Activations | LayerNorm, BatchNorm, RMSNorm | Rescale hidden states to unit variance at each layer |
| Gradients (scale) | Adam | Divide gradient by running estimate of its second moment |
| Gradients (direction) | Muon, LION | Use only the sign or matrix sign of the gradient |
| Weights | ??? | Nothing principled. Weight decay is a hack. Spectral norm regularization is expensive. |

Weight decay ($w \leftarrow (1-\lambda)w$) does shrink weights, but it’s not really normalization. It pulls every weight toward zero with equal force, regardless of the structure of the matrix. A matrix with perfectly balanced singular values gets the same treatment as one that’s wildly imbalanced.
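To make the point concrete, here is a minimal NumPy sketch of what a pure weight-decay step does to the singular values. The matrix shape, the seed, and the decay coefficient `lam` are arbitrary choices for illustration, and the sketch isolates the decay term (no gradient step).

```python
import numpy as np

rng = np.random.default_rng(0)
W = rng.standard_normal((6, 4))
lam = 0.01  # hypothetical decay coefficient, chosen only for illustration

sigma_before = np.linalg.svd(W, compute_uv=False)
sigma_after = np.linalg.svd((1 - lam) * W, compute_uv=False)

# Every singular value shrinks by the same factor (1 - lam) ...
print(sigma_after / sigma_before)

# ... so the imbalance sigma_max / sigma_min is left untouched.
imbalance_before = sigma_before[0] / sigma_before[-1]
imbalance_after = sigma_after[0] / sigma_after[-1]
print(imbalance_before, imbalance_after)
```

The decay rescales the whole spectrum uniformly, so a badly conditioned matrix stays exactly as badly conditioned after the step.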

What we want is something more geometric: constrain the weight matrix to live on a surface where “healthy” is guaranteed by construction.

01 Singular value drift

Visualizing what goes wrong when weights are unconstrained.

Every matrix $W \in \mathbb{R}^{m \times n}$ has a singular value decomposition $W = U \Sigma V^T$, where $\Sigma$ is a diagonal matrix of non-negative values $\sigma_1 \geq \sigma_2 \geq \ldots \geq \sigma_{\min(m,n)}$. These singular values tell you how much the matrix stretches space in each direction.

A “healthy” weight matrix has singular values that are all close to 1 — it transforms inputs without dramatically amplifying or suppressing any direction. An “unhealthy” matrix has some singular values much larger or smaller than 1.

[Animation legend: singular values (unconstrained), singular values (manifold-constrained), σ = 1 target]

The animation shows 8 singular values of a weight matrix during training. Without constraints (amber), they drift apart wildly — some explode, some vanish. With a manifold constraint (green), they stay pinned to 1. The constraint doesn’t just regularize — it eliminates the drift entirely.

01 The manifold idea

Here’s the key realization: if we want all singular values to equal 1, we’re asking for $W^T W = I$. That’s not just a constraint — it defines a specific geometric object called the Stiefel manifold. And there’s a well-developed mathematical theory for how to optimize on manifolds.

The plan:

  1. Define the manifold

    Choose a constraint surface that encodes “health.” For vectors: the hypersphere $\|w\| = 1$. For matrices: the Stiefel manifold $W^TW = I$.

    Geometry — what surface to live on
  2. Compute in the tangent space

    Project the gradient onto the tangent plane — the set of directions that are compatible with staying on the manifold.

    Calculus — how to move without leaving
  3. Retract back onto the manifold

    After taking a step in the tangent space, snap back onto the manifold using a retraction map.

    Numerics — staying exactly on the surface

The beauty of this framework is that different manifold choices give different optimizers. The hypersphere gives you hyperspherical descent. The Stiefel manifold with spectral norm distance gives you Manifold Muon. And the modular abstraction lets you compose these choices across an entire architecture.

Think of it like a ball rolling on a curved surface. The ball (weight matrix) is always on the surface (manifold). Gravity (gradient) tries to pull it off. But we constrain its movement to the surface (tangent projection + retraction). It can only slide along the manifold, never leave it.

02 What is a manifold?

A curved surface that is locally flat — the constraint surface for our weights.

Forget the formal topology definition. For our purposes, a manifold is just a surface embedded in a higher-dimensional space that looks flat if you zoom in enough. The Earth’s surface is a 2D manifold embedded in 3D space. A circle is a 1D manifold embedded in 2D space.

The key property: at any point on the manifold, you can draw a flat plane that just barely touches the surface. This is the tangent space — the set of all directions you can move while (approximately) staying on the manifold.

Why manifolds for optimization? Because a constraint like $\|w\| = 1$ defines a manifold (the hypersphere). If we constrain our weights to live on this surface, we need to know how to take gradient steps along the surface rather than through ambient space.

02 Tangent spaces

The tangent space at a point $w$ on a manifold is the vector space of all velocities that keep you on the manifold to first order. Mathematically, if the manifold is defined by a constraint $c(w) = 0$, then the tangent space at $w$ is:

Tangent space definition $$T_w\mathcal{M} = \{a \in \mathbb{R}^n \mid \nabla c(w)^T a = 0\}$$
  • $T_w\mathcal{M}$ is the tangent space at point $w$
  • $a$ is a tangent vector (a feasible direction)
  • $\nabla c(w)$ is the gradient of the constraint at $w$
  • The condition says: tangent vectors are perpendicular to the constraint gradient

This makes geometric sense. The constraint gradient $\nabla c(w)$ points “away from” the manifold (it’s the direction of steepest increase of the constraint function). Vectors perpendicular to it are the ones that stay on the surface.
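A small numerical check can make "to first order" tangible. The sketch below uses the unit-circle constraint $c(w) = w^Tw - 1$ as an assumed example; the specific point and directions are arbitrary illustrative choices.

```python
import numpy as np

def constraint(w):
    """Unit-circle constraint c(w) = w^T w - 1; zero exactly on the manifold."""
    return w @ w - 1.0

w = np.array([0.6, 0.8])         # a point on the unit circle
tangent = np.array([-0.8, 0.6])  # perpendicular to grad c(w) = 2w
normal = np.array([0.6, 0.8])    # parallel to grad c(w)

for t in (1e-1, 1e-2, 1e-3):
    # Along the tangent the violation is t^2; along the normal it is 2t + t^2.
    print(t, constraint(w + t * tangent), constraint(w + t * normal))
```

Moving along the tangent violates the constraint only at second order in the step size, whereas moving along the constraint gradient violates it at first order.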

02 The hypersphere

The simplest manifold: all vectors of unit length.

Let’s work through the simplest case. The unit hypersphere in $\mathbb{R}^n$ is defined by the constraint:

Hypersphere constraint $$c(w) = \|w\|_2^2 - 1 = w^Tw - 1 = 0$$

The gradient of this constraint is $\nabla c(w) = 2w$. So the tangent space at $w$ is all vectors $a$ satisfying:

Hypersphere tangency condition $$w^T a = 0$$
  • The tangent vectors are precisely those perpendicular to $w$ itself
  • In 2D: if $w$ points to a spot on the circle, the tangent is the line touching the circle at that point

[Interactive figure: hyperspherical descent on a 2D circle. Legend: current point w, gradient g, tangent projection, retracted point.]

The visualization shows the full manifold optimization procedure on a 2D circle. The red arrow is the raw gradient (pointing toward the loss minimum). The blue arrow is its projection onto the tangent line. The green point is where we end up after retraction.

02 Deriving the optimal tangent update

Using Lagrange multipliers to find the best direction on the manifold.

We want to find the tangent vector $a$ that decreases the loss most, subject to two constraints: (1) it must be tangent to the manifold ($w^Ta = 0$), and (2) it must have bounded norm ($\|a\| \leq \eta$). This is a constrained optimization problem.

Setup: Given gradient $g = \nabla L(w)$, find the direction in the tangent space that decreases the loss most per unit distance traveled on the manifold.

The linearized change in loss from taking step $a$ is $g^T a$ (first-order Taylor). We want to make this as negative as possible, which is steepest descent. The problem is:

Constrained descent problem $$\min_a\ g^T a \quad \text{s.t.} \quad w^T a = 0,\quad \|a\|_2 \leq \eta$$

We can solve this by first projecting $g$ onto the tangent space, then normalizing. The tangent projection of $g$ is:

Tangent projection $$g_\perp = g - (w^T g)\, w = g - w w^T g$$
  • $w w^T g$ is the component of $g$ along $w$ (the normal direction)
  • $g_\perp$ is the component of $g$ in the tangent space
  • This is just the projection matrix $(I - ww^T)$ applied to $g$

The optimal tangent update is then simply this projected gradient, normalized to have length $\eta$:

Optimal update $$a_{\text{opt}} = -\eta \cdot \frac{g_\perp}{\|g_\perp\|_2} = -\eta \cdot \frac{g - ww^Tg}{\|g - ww^Tg\|_2}$$
Worked example: Let $w = (1, 0)$ (pointing right on the unit circle), $g = (0.3, -0.8)$, $\eta = 0.2$.

Tangent projection: $g_\perp = g - (w^Tg)w = (0.3, -0.8) - 0.3 \cdot (1,0) = (0, -0.8)$
Normalized: $g_\perp / \|g_\perp\| = (0, -1)$
Optimal step: $a_{\text{opt}} = -0.2 \cdot (0, -1) = (0, 0.2)$

The gradient had a component along $w$ (trying to leave the circle) which we removed. The remaining tangent component points downward, so we step upward.

02 Retraction

After taking a step $a$ in the tangent space, we’re no longer exactly on the manifold (we’re slightly off the circle/sphere). A retraction maps us back. For the hypersphere, the simplest retraction is normalization:

Retraction on the hypersphere $$w_{\text{new}} = \frac{w + a}{\|w + a\|_2}$$

But there’s a more elegant form. Since $w$ and the tangent $a$ are orthogonal (by the tangency condition $w^Ta = 0$), we can use the Pythagorean theorem:

Pythagorean retraction $$\|w + a\|_2 = \sqrt{\|w\|^2 + \|a\|^2} = \sqrt{1 + \eta^2}$$
  • Since $\|w\| = 1$ (on the manifold) and $\|a\| = \eta$ (our step size)
  • The orthogonality gives us this clean closed form

So the full update becomes:

Complete hyperspherical descent step $$w \leftarrow \frac{1}{\sqrt{1 + \eta^2}} \left[ w - \eta \cdot \frac{g_\perp}{\|g_\perp\|_2} \right]$$

This is the complete hyperspherical descent algorithm. Three lines of code:

Python
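import torch

def hyperspherical_step(w, g, eta):
    """One descent step on the unit hypersphere; w must have unit norm."""
    # 1. Project gradient onto tangent space (remove the component along w)
    g_perp = g - w * (w @ g)
    # 2. Scale the unit tangent direction to length eta
    #    (assumes g_perp is nonzero, i.e. w is not already a critical point)
    a = -eta * g_perp / g_perp.norm()
    # 3. Retract: w and a are orthogonal, so ||w+a|| = sqrt(1 + eta^2)
    w_new = (w + a) / (1 + eta**2) ** 0.5
    return w_new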
Notice what this achieves: regardless of the gradient magnitude, the weight vector always stays on the unit sphere. There is no way for the weights to explode or vanish — the constraint is enforced by construction, not by a penalty term that might be overwhelmed.
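As a quick check, the sketch below reruns the worked example from earlier in this chapter ($w = (1, 0)$, $g = (0.3, -0.8)$, $\eta = 0.2$). It is a NumPy re-implementation of the same three steps, written so it runs without PyTorch; the factor of 1000 applied to the gradient is an arbitrary choice to illustrate scale invariance.

```python
import numpy as np

def hyperspherical_step(w, g, eta):
    """Project, scale to length eta, retract (NumPy version of the step above)."""
    g_perp = g - w * (w @ g)
    a = -eta * g_perp / np.linalg.norm(g_perp)
    return (w + a) / np.sqrt(1.0 + eta**2)

w = np.array([1.0, 0.0])
g = np.array([0.3, -0.8])

w_new = hyperspherical_step(w, g, eta=0.2)
print(w_new)                  # approximately [0.9806, 0.1961]
print(np.linalg.norm(w_new))  # 1.0 up to rounding

# Scaling the gradient by 1000 changes nothing: only its direction matters.
w_big = hyperspherical_step(w, 1000.0 * g, eta=0.2)
print(np.allclose(w_new, w_big))
```

The result is $(1, 0.2)/\sqrt{1.04}$, which matches the hand computation: the tangent step $(0, 0.2)$ followed by division by $\sqrt{1 + \eta^2}$.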

03 The three-step template

Every manifold optimizer follows the same recipe.

In Chapter 02 we derived the full procedure for the hypersphere. But the same three-step template works for any manifold with any norm:

  1. Tangent: project the gradient

    Given gradient $G$, project it onto the tangent space of the manifold at the current point $W$. This removes the component that would move you off the manifold.

    Constrain direction to be feasible
  2. Scale: choose step size in the chosen norm

    Normalize the tangent direction according to the chosen distance metric (Euclidean, spectral, infinity, ...). The norm determines what “unit distance” means on this manifold.

    Constrain magnitude to the budget η
  3. Retract: snap back to the manifold

    After stepping, project back onto the constraint surface. Different retractions have different cost/accuracy tradeoffs.

    Stay on the manifold

The insight is that by varying the manifold choice and the norm choice independently, you get a table of optimizers — and many famous optimizers turn out to be special cases.
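One way to picture the template is as a function with three pluggable pieces. The sketch below is illustrative only: the function and helper names (`manifold_step`, `no_projection`, and so on) are hypothetical, and it covers only the simple cases where "project, then normalize" is valid.

```python
import numpy as np

def manifold_step(w, g, eta, project, direction, retract):
    """Generic step: tangent-project, scale to the budget eta, retract."""
    g_tan = project(w, g)        # 1. Tangent
    a = -eta * direction(g_tan)  # 2. Scale in the chosen norm
    return retract(w + a)        # 3. Retract

def no_projection(w, g):
    return g                     # flat space: every direction is feasible

def no_retraction(w):
    return w                     # flat space: nothing to snap back to

def sphere_projection(w, g):
    return g - w * (w @ g)       # remove the component along w

def unit_l2(v):
    return v / np.linalg.norm(v)

w = np.array([1.0, 0.0])
g = np.array([0.3, -0.8])
eta = 0.2

w_sgd = manifold_step(w, g, eta, no_projection, unit_l2, no_retraction)
w_sign = manifold_step(w, g, eta, no_projection, np.sign, no_retraction)
w_sphere = manifold_step(w, g, eta, sphere_projection, unit_l2, unit_l2)
print(w_sgd, w_sign, w_sphere)
```

Swapping `unit_l2` for `np.sign` changes the norm from $\ell_2$ to $\ell_\infty$; swapping the projection and retraction changes the manifold. The Stiefel manifold with spectral norm does not factor this cleanly, which is why Chapter 05 needs a dual problem.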

03 The optimizer table

Different manifolds and norms give different optimizers.

| Manifold | Norm (distance) | Optimizer | Key operation |
| --- | --- | --- | --- |
| Euclidean | $\ell_2$ | SGD | $w \leftarrow w - \eta \cdot g/\lVert g \rVert$ |
| Euclidean | $\ell_\infty$ | SignSGD | $w \leftarrow w - \eta \cdot \text{sign}(g)$ |
| Hypersphere | $\ell_2$ | Hyperspherical descent | Project + normalize + retract on sphere |
| Matrix (flat) | Spectral | Muon | $W \leftarrow W - \eta \cdot \text{msign}(G)$ |
| Stiefel | Spectral | Manifold Muon | Tangent project + dual ascent + retract |
| Stiefel | $\ell_2$ (Frobenius) | Cayley SGD | Project + Cayley retraction |

The first two rows are the optimizers everyone knows. The fifth row, Manifold Muon, is the new one from this blog post. The progression is clear: we go from flat space (Euclidean) to curved space (hypersphere, Stiefel), and from simple norms ($\ell_2$, $\ell_\infty$) to matrix norms (spectral).

Let’s make the connection to Muon explicit. On flat Euclidean space (no manifold constraint), the spectral norm ball update is:

Muon as flat-manifold spectral descent $$W \leftarrow W - \eta \cdot \text{msign}(G) \quad \text{where msign}(G) = U V^T\ (G = U\Sigma V^T)$$
  • $\text{msign}(G)$ “snaps” all singular values to 1 — pure rotation, no stretching
  • This is the matrix version of $\text{sign}(g)$ for vectors
  • It maximizes $\text{tr}(G^T A)$ subject to $\|A\|_{\text{spectral}} \leq 1$
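The properties in this list are easy to check numerically. The sketch below computes $\text{msign}$ through an explicit SVD in NumPy, which is a simplification chosen for clarity: practical Muon implementations approximate it with a Newton-Schulz iteration and also apply momentum, both omitted here. Shapes, seed, and `eta` are arbitrary.

```python
import numpy as np

def msign(G):
    """Matrix sign via SVD: keep the singular vectors, set singular values to 1."""
    U, _, Vt = np.linalg.svd(G, full_matrices=False)
    return U @ Vt

rng = np.random.default_rng(0)
G = rng.standard_normal((5, 3))
W = rng.standard_normal((5, 3))
eta = 0.1

A = msign(G)
print(np.linalg.svd(A, compute_uv=False))  # all ones

# tr(G^T msign(G)) equals the nuclear norm of G (sum of its singular values)
alignment = np.trace(G.T @ A)
nuclear = np.linalg.svd(G, compute_uv=False).sum()
print(alignment, nuclear)

# One flat-space spectral descent step: the update has spectral norm exactly eta
W_new = W - eta * A
print(np.linalg.norm(W_new - W, 2))
```

However large or small the raw gradient is, the step always has spectral norm $\eta$, just as a SignSGD step always has $\ell_\infty$ norm $\eta$.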

Manifold Muon adds the Stiefel constraint: we still use spectral distance, but the step must also be tangent to $W^TW = I$. This is harder — it requires solving a dual problem — but the payoff is that the weights stay on the Stiefel manifold forever.

03 Optimizer trajectories

Seeing the difference on a 2D loss landscape.

[Figure: optimizer trajectories on a 2D loss landscape. Legend: unconstrained SGD, hyperspherical descent, Muon (spectral), unit circle constraint.]

The visualization shows three optimizers minimizing $f(w) = (w_1 - 0.7)^2 + 3(w_2 + 0.3)^2$ (an elliptical bowl). SGD (gray) wanders freely toward the minimum. Hyperspherical descent (amber) stays on the unit circle and converges to the lowest-loss point on the circle, which for this stretched bowl is not simply the point nearest the unconstrained minimum. Muon (blue) takes larger spectral-norm-bounded steps but can leave the circle.

The key difference: the constrained optimizer (amber) never leaves the manifold. It can’t overshoot the manifold, it can’t get lost in flat space. Its trajectory is restricted to a well-defined surface.

03 Reinterpreting learning rate

In standard SGD, the learning rate $\eta$ is a poorly understood hyperparameter. Too large and you diverge. Too small and you converge slowly. The “right” value depends on the loss landscape curvature, which changes during training.

In manifold optimization, $\eta$ has a clean geometric meaning: it’s the maximum distance you can travel on the manifold in one step. The distance is measured in whatever norm you chose (Euclidean, spectral, infinity).

Learning rate as distance budget $$\eta = \max_{\text{step}} d_{\text{manifold}}(W_t, W_{t+1})$$
  • On the hypersphere with $\ell_2$ distance: $\eta$ bounds the arc length traveled
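  • On the Stiefel manifold with spectral distance: $\eta$ bounds the spectral norm of the update, i.e., the most the step can change the layer’s output for any unit-norm input
  • This interpretation is intended to make $\eta$ comparable, and therefore budgetable, across layers and architectures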
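For the hypersphere, the budget interpretation can be checked directly. In the sketch below (NumPy, with an arbitrary point and gradient), the tangent step has length $\eta$ and the arc actually traveled after retraction is $\arctan(\eta)$, which never exceeds $\eta$. The helper name `arc_traveled` is hypothetical.

```python
import numpy as np

def arc_traveled(w, g, eta):
    """Geodesic distance on the unit sphere covered by one hyperspherical step."""
    g_perp = g - w * (w @ g)
    a = -eta * g_perp / np.linalg.norm(g_perp)
    w_new = (w + a) / np.sqrt(1.0 + eta**2)
    return np.arccos(np.clip(w @ w_new, -1.0, 1.0))

w = np.array([1.0, 0.0])
g = np.array([0.3, -0.8])

for eta in (0.05, 0.2, 1.0):
    # arc length equals arctan(eta), which is bounded above by eta
    print(eta, arc_traveled(w, g, eta), np.arctan(eta))
```

For small $\eta$ the arc length and $\eta$ nearly coincide; for large $\eta$ the retraction shortens the move, so $\eta$ acts as an upper bound rather than an exact distance.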

This is the key idea behind “modular manifolds” (Chapter 06): if learning rate is distance, then you can budget it across layers based on architecture-aware sensitivity analysis, rather than tuning it empirically.

04 Why matrices need manifolds

Vectors live in simple space. Matrices are maps between spaces — they need richer geometry.

A weight vector $w \in \mathbb{R}^n$ maps a scalar to $n$ numbers. The hypersphere $\|w\| = 1$ is a fine constraint: it says “the vector has unit energy.”

But in neural networks, most weights are matrices. A matrix $W \in \mathbb{R}^{m \times n}$ is a linear map from $\mathbb{R}^n$ to $\mathbb{R}^m$. It takes an input vector and transforms it into an output vector. The “health” of this map isn’t captured by a single number like $\|W\|_F$. You need to understand how it stretches each direction.

This is where SVD enters.

04 SVD: matrices as stretching

The Singular Value Decomposition writes any matrix as: rotate, stretch, rotate.

SVD decomposition $$W = U \Sigma V^T$$
  • $U \in \mathbb{R}^{m \times m}$ — rotation in output space (orthogonal)
  • $\Sigma \in \mathbb{R}^{m \times n}$ — diagonal of singular values $\sigma_1 \geq \sigma_2 \geq \ldots$
  • $V \in \mathbb{R}^{n \times n}$ — rotation in input space (orthogonal)
  • Geometric meaning: first rotate input ($V^T$), then stretch each axis ($\Sigma$), then rotate output ($U$)

The singular values $\sigma_i$ tell you the stretching factors. If $\sigma_1 = 10$ and $\sigma_n = 0.01$, the matrix amplifies one direction by 10x and squashes another by 100x. That’s unhealthy — it means the layer’s sensitivity varies by 1000x depending on the input direction.

The spectral norm is just the largest singular value: $\|W\|_\sigma = \sigma_1$. It measures the maximum stretching.

A “healthy” matrix has all singular values close to 1. The extreme case: all singular values exactly 1. That’s an orthogonal matrix (or more precisely, a matrix with orthonormal columns).
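The rotate-stretch-rotate picture can be built by hand. The sketch below (NumPy; the seed and the chosen singular values are illustrative) assembles a matrix from two random rotations and a prescribed stretch, then recovers the stretch factors from the SVD.

```python
import numpy as np

rng = np.random.default_rng(0)

# Random orthogonal factors from QR, plus a deliberately unhealthy stretch.
U, _ = np.linalg.qr(rng.standard_normal((3, 3)))
V, _ = np.linalg.qr(rng.standard_normal((3, 3)))
sigma = np.array([10.0, 1.0, 0.01])
W = U @ np.diag(sigma) @ V.T

s = np.linalg.svd(W, compute_uv=False)
print(s)             # recovers [10, 1, 0.01]
print(s[0])          # spectral norm: the largest stretch
print(s[0] / s[-1])  # spread between most and least amplified direction: 1000

# The first and last columns of V are the most and least amplified unit inputs.
most = np.linalg.norm(W @ V[:, 0])
least = np.linalg.norm(W @ V[:, 2])
print(most, least)
```

Two unit-length inputs produce outputs whose lengths differ by a factor of 1000, which is the sensitivity spread the text describes.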

04 The Stiefel manifold

The Stiefel manifold $\text{St}(m, n)$ is the set of all $m \times n$ matrices with orthonormal columns:

Stiefel manifold definition $$\text{St}(m, n) = \{W \in \mathbb{R}^{m \times n} \mid W^T W = I_n\}$$
  • $m \geq n$ (more rows than columns, or square)
  • $W^TW = I_n$ means columns are orthonormal
  • All singular values are exactly 1
  • In SVD terms: $W = UV^T$ (no stretching — $\Sigma = I$)

Special cases:

  • $\text{St}(n, 1)$: unit vectors in $\mathbb{R}^n$ = the hypersphere $S^{n-1}$. Our Chapter 02 example.
  • $\text{St}(n, n)$: $n \times n$ orthogonal matrices = the orthogonal group $O(n)$.
  • $\text{St}(m, n)$ with $m > n$: “tall thin” matrices with orthonormal columns. This is the typical case for neural network weight matrices.

Why Stiefel for neural networks? A linear layer $y = Wx$ on the Stiefel manifold preserves the norm of any input: $\|Wx\|^2 = x^T W^T W x = x^T I x = \|x\|^2$. This means the layer neither amplifies nor suppresses signals in the forward pass, so forward activations through this layer can neither explode nor vanish, by construction.
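The norm-preservation identity is simple to verify. The sketch below draws a point on $\text{St}(8, 3)$ by taking the Q factor of a random Gaussian matrix; this construction, the shapes, and the seed are assumptions made for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)
m, n = 8, 3

# Reduced QR of a random m x n matrix: Q has n orthonormal columns.
W, _ = np.linalg.qr(rng.standard_normal((m, n)))

x = rng.standard_normal(n)
gram_error = np.abs(W.T @ W - np.eye(n)).max()
norm_in = np.linalg.norm(x)
norm_out = np.linalg.norm(W @ x)

print(gram_error)                          # ~1e-16: W^T W = I
print(norm_in, norm_out)                   # equal up to rounding
print(np.linalg.svd(W, compute_uv=False))  # all singular values are 1
```

The same check works for any input `x`, since the identity $\|Wx\| = \|x\|$ depends only on $W^TW = I$.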

04 The tangency condition

To optimize on the Stiefel manifold, we need the tangent space. The constraint is $c(W) = W^TW - I = 0$ (a matrix equation). Let’s differentiate.

If $W(t)$ is a curve on the manifold with velocity $A = \dot{W}$, then differentiating $W^TW = I$ gives:

Stiefel tangency condition $$\frac{d}{dt}(W^TW) = A^T W + W^T A = 0$$
  • $A \in \mathbb{R}^{m \times n}$ is a tangent vector at $W$
  • The condition says $A^TW$ must be skew-symmetric: $(A^TW)^T = -(A^TW)$
  • This generalizes the vector condition $w^Ta = 0$ to matrices
Derivation: Start from $W^TW = I$. Let $W(t) = W + tA + O(t^2)$. Then:
$(W+tA)^T(W+tA) = W^TW + t(A^TW + W^TA) + O(t^2) = I$
For this to equal $I$ to first order: $A^TW + W^TA = 0$.
So the tangent space is: $T_W\text{St}(m,n) = \{A \mid A^TW + W^TA = 0\}$, i.e., $A^TW$ is skew-symmetric.

What does this mean geometrically? If we write $A = WS + W_\perp K$ where $S$ is $n \times n$ skew-symmetric and $W_\perp$ is the orthogonal complement of $W$’s column space, then the tangency condition is automatically satisfied. The $WS$ part rotates within the span of $W$’s columns (like rotating an orthogonal frame), and the $W_\perp K$ part moves into new directions.
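The decomposition $A = WS + W_\perp K$ can be tested numerically. In this sketch (NumPy, arbitrary shapes and seed), $W$ and $W_\perp$ are taken from the columns of a single random orthogonal matrix, and $S$ and $K$ are random; all of these are illustrative choices.

```python
import numpy as np

rng = np.random.default_rng(0)
m, n = 6, 3

# Split one m x m orthogonal matrix into W (first n columns) and W_perp (the rest).
Q, _ = np.linalg.qr(rng.standard_normal((m, m)))
W, W_perp = Q[:, :n], Q[:, n:]

B = rng.standard_normal((n, n))
S = B - B.T                          # skew-symmetric: S^T = -S
K = rng.standard_normal((m - n, n))  # unconstrained

A = W @ S + W_perp @ K               # candidate tangent vector at W

tangency = A.T @ W + W.T @ A
print(np.abs(tangency).max())        # ~1e-16: tangency condition holds

# Stepping along A leaves the manifold only at second order in t.
t = 1e-3
drift = np.abs((W + t * A).T @ (W + t * A) - np.eye(n)).max()
print(drift)                         # on the order of t^2
```

Because $W^TW_\perp = 0$, the product $W^TA$ reduces to $S$, and skew-symmetry of $S$ is exactly what makes $A^TW + W^TA$ vanish.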

04 Visualizing Stiefel geometry

What “unit stretching” means for input-output maps.

Input unit circle Output (transformed) Output unit circle (Stiefel target)

The visualization shows a 2D matrix acting on the unit circle. When $\sigma_1 = \sigma_2 = 1$ (Stiefel), the circle maps to itself — pure rotation. When singular values differ, the circle becomes an ellipse. The “Snap to Stiefel” button resets both singular values to 1, showing the manifold constraint in action.

In a real neural network with $W \in \mathbb{R}^{4096 \times 4096}$, this same principle applies in 4096 dimensions: the Stiefel constraint ensures that no input direction is amplified or suppressed. Every direction gets the same treatment.

05 The optimization problem

Combining the Stiefel constraint with spectral-norm distance — the hardest case in our table.

We want to find the tangent vector $A$ at $W$ on the Stiefel manifold that maximizes the linearized loss decrease, subject to spectral-norm distance budget $\eta$:

Manifold Muon primal problem $$\min_A\ \text{tr}(G^T A) \quad \text{s.t.} \quad \|A\|_\sigma \leq \eta, \quad A^TW + W^TA = 0$$
  • $G = \nabla_W L$ is the gradient of the loss w.r.t. $W$
  • $\|A\|_\sigma \leq \eta$ bounds the spectral norm of the step (max singular value)
  • $A^TW + W^TA = 0$ is the Stiefel tangency constraint
  • $\text{tr}(G^TA) = \langle G, A \rangle_F$ is the linearized change in loss (more negative means more descent)

This is harder than the hypersphere case because the spectral norm constraint and the tangency constraint interact in a non-trivial way. We can’t just “project then normalize” like before — the spectral norm ball isn’t rotationally symmetric in the tangent space.

The solution: Lagrange duality.

05 Dual formulation

Using Lagrange multipliers to turn a constrained problem into an unconstrained one.

Introduce a Lagrange multiplier $\Lambda \in \mathbb{R}^{n \times n}$ for the tangency constraint. The Lagrangian is:

Lagrangian $$\mathcal{L}(A, \Lambda) = \text{tr}(G^T A) + \text{tr}(\Lambda^T(A^TW + W^TA))$$

Rearranging the constraint term using $\text{tr}(X) = \text{tr}(X^T)$ and the cyclic property of the trace: $\text{tr}(\Lambda^T A^T W) + \text{tr}(\Lambda^T W^T A) = \text{tr}((W\Lambda^T)^T A) + \text{tr}((W\Lambda)^T A) = \text{tr}((W(\Lambda + \Lambda^T))^T A)$.

So the Lagrangian becomes:

Simplified Lagrangian $$\mathcal{L}(A, \Lambda) = \text{tr}\left((G + W(\Lambda + \Lambda^T))^T A\right)$$

Now we minimize over $A$ subject only to $\|A\|_\sigma \leq \eta$. The minimum of $\text{tr}(M^T A)$ over $\|A\|_\sigma \leq \eta$ is $-\eta \|M\|_*$ (where $\|\cdot\|_*$ is the nuclear norm — sum of singular values). This is because spectral and nuclear norms are dual to each other.

Key duality fact: For any matrix $M$,
$\min_{\|A\|_\sigma \leq \eta} \text{tr}(M^T A) = -\eta \|M\|_* = -\eta \sum_i \sigma_i(M)$

The minimizer is $A^* = -\eta \cdot \text{msign}(M)$ where $\text{msign}(M) = UV^T$ from the SVD $M = U\Sigma V^T$.
This is the matrix analog of the vector fact $\min_{\|a\|_\infty \leq \eta} g^Ta = -\eta\|g\|_1$ with minimizer $a^* = -\eta \cdot \text{sign}(g)$.
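Both duality statements can be sanity-checked by brute force. The sketch below (NumPy; shapes, seed, and the 1000 random trials are arbitrary) confirms that the claimed minimizer attains $-\eta\|M\|_*$ and that randomly sampled feasible matrices do no better.

```python
import numpy as np

rng = np.random.default_rng(0)
M = rng.standard_normal((4, 3))
eta = 0.1

U, s, Vt = np.linalg.svd(M, full_matrices=False)
A_star = -eta * U @ Vt
best = np.trace(M.T @ A_star)
nuclear = s.sum()
print(best, -eta * nuclear)  # equal up to rounding

# Random matrices rescaled onto the boundary of the spectral-norm ball
trials = []
for _ in range(1000):
    A = rng.standard_normal((4, 3))
    A *= eta / np.linalg.norm(A, 2)
    trials.append(np.trace(M.T @ A))
print(min(trials) >= best)   # no sampled point beats the claimed minimizer

# Vector analog: sign(g) against the l1 norm
g = rng.standard_normal(5)
a_star = -eta * np.sign(g)
print(g @ a_star, -eta * np.abs(g).sum())
```

Random sampling is not a proof, but it is a cheap way to catch a sign or scaling mistake in the formula.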

Plugging in, the dual function is:

Dual function $$d(\Lambda) = -\eta \|G + W(\Lambda + \Lambda^T)\|_*$$
  • $\|M\|_* = \sum_i \sigma_i(M)$ is the nuclear norm (sum of all singular values)
  • We maximize $d(\Lambda)$ over $\Lambda$ (dual ascent)
  • At the optimum, $A^* = -\eta \cdot \text{msign}(G + W(\Lambda^* + {\Lambda^*}^T))$

The dual problem is: find $\Lambda$ that maximizes $-\eta\|G + W(\Lambda + \Lambda^T)\|_*$. Since we’re maximizing a negative quantity, this means minimizing the nuclear norm of $G + W(\Lambda + \Lambda^T)$ over $\Lambda$.

05 Dual ascent algorithm

We solve the dual problem by gradient ascent on $\Lambda$. The derivative of the nuclear norm $\|M\|_*$ with respect to $M$ is $\text{msign}(M)$ (a subgradient where $M$ is rank-deficient), so the chain rule through $M = G + W(\Lambda + \Lambda^T)$ gives:

Dual gradient $$\nabla_\Lambda d = -\eta \left( W^T \Theta + \Theta^T W \right), \quad \Theta = \text{msign}(G + W(\Lambda + \Lambda^T))$$

With the candidate step $A = -\eta\,\Theta$, this is exactly $W^TA + A^TW$: the dual gradient is the tangency violation of the current candidate. The blog post advocates a simple scheme for using it: ascend the dual directly with a fixed step size. The full algorithm:

Python — Manifold Muon
import torch

def manifold_muon_step(W, G, eta, dual_steps=5, dual_lr=0.1):
    """One step of Manifold Muon on the Stiefel manifold (W is m x n, W^T W = I)."""
    m, n = W.shape
    Lambda = torch.zeros(n, n, device=W.device, dtype=W.dtype)

    # Dual ascent: find the Lagrange multiplier
    for _ in range(dual_steps):
        M = G + W @ (Lambda + Lambda.T)
        U, _, Vt = torch.linalg.svd(M, full_matrices=False)
        A_candidate = -eta * U @ Vt  # = -eta * msign(M)
        # Dual gradient = tangency violation: A^T W + W^T A should be 0
        violation = A_candidate.T @ W + W.T @ A_candidate
        Lambda = Lambda + dual_lr * violation

    # Final update (tangent only approximately unless the dual ascent converged)
    M = G + W @ (Lambda + Lambda.T)
    U, _, Vt = torch.linalg.svd(M, full_matrices=False)
    A = -eta * U @ Vt

    # Retraction: polar retraction (snap to nearest Stiefel point)
    W_new = W + A
    U2, _, Vt2 = torch.linalg.svd(W_new, full_matrices=False)
    W_new = U2 @ Vt2  # Polar retraction: W_new = UV^T

    return W_new

Worked example (2x2): Let $W = I_2$ (identity, on Stiefel), $G = \begin{pmatrix} 0.5 & 0.3 \\ -0.2 & 0.8 \end{pmatrix}$, $\eta = 0.1$.

Iteration 1: $\Lambda = 0$, so $M = G$. Compute SVD of $G$: $\sigma_1 \approx 0.85, \sigma_2 \approx 0.54$.
$A = -0.1 \cdot U V^T$ (msign of $G$). Check tangency: $A^TW + W^TA = A^T + A \neq 0$ in general.
The violation $A^T + A$ is a symmetric matrix (here approximately $-0.19\,I$). Use it to update $\Lambda$.

As the dual iterations proceed, $\Lambda$ moves toward a value where $A^TW + W^TA \approx 0$, i.e., the update is tangent to Stiefel. How many iterations this takes depends on $\eta$ and the dual step size.

Retract: $W + A$ is approximately orthogonal; polar retraction snaps it exactly onto Stiefel.
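The worked example is small enough to run to convergence. The sketch below repeats the dual ascent loop in NumPy on the same $W$, $G$, and $\eta$; the choice of 200 dual iterations is an assumption of this sketch, deliberately generous so that convergence is visible.

```python
import numpy as np

def msign(M):
    U, _, Vt = np.linalg.svd(M, full_matrices=False)
    return U @ Vt

W = np.eye(2)
G = np.array([[0.5, 0.3], [-0.2, 0.8]])
eta, dual_lr = 0.1, 0.1

Lambda = np.zeros((2, 2))
violations = []
for _ in range(200):  # generous iteration count, chosen for this sketch
    A = -eta * msign(G + W @ (Lambda + Lambda.T))
    violation = A.T @ W + W.T @ A  # symmetric by construction
    violations.append(np.linalg.norm(violation))
    Lambda = Lambda + dual_lr * violation

A = -eta * msign(G + W @ (Lambda + Lambda.T))
print(violations[0], violations[-1])  # starts near 0.26, ends near 0
print(A)                              # approaches [[0, -0.1], [0.1, 0]]
```

With $W = I$ the tangent space consists of skew-symmetric matrices, so the converged step is $\eta$ times a 90-degree rotation: skew-symmetric, with spectral norm exactly $\eta$.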

05 The matrix sign function

The matrix sign function $\text{msign}(M)$ appears repeatedly. Let’s understand it geometrically.

Given SVD $M = U\Sigma V^T$, the matrix sign is $\text{msign}(M) = UV^T$. It “snaps” all singular values to 1, keeping the rotational structure but removing the stretching. Equivalently:

Matrix sign function $$\text{msign}(M) = \arg\max_{\|A\|_\sigma \leq 1} \text{tr}(M^T A)$$
  • It’s the unit spectral-norm matrix most aligned with $M$
  • Analogy: $\text{sign}(g)$ is the unit $\ell_\infty$-norm vector most aligned with $g$
  • Geometrically: keep the singular vectors of $M$ and set every nonzero singular value to exactly 1

Why is this the right “normalization” for matrix gradients? Because the spectral norm and nuclear norm form a dual pair. Just as sign($g$) maximizes $g^Ta$ over $\|a\|_\infty \leq 1$, msign($M$) maximizes $\text{tr}(M^TA)$ over $\|A\|_\sigma \leq 1$. Its negative is the steepest descent direction in spectral-norm distance.

05 Interactive dual ascent

Watch the dual variable converge to enforce the tangency constraint.

[Interactive figure: dual ascent. Left panel: nuclear norm (dual objective) and tangency violation ||A^TW + W^TA||. Right panel: update direction (msign).]

The visualization shows the dual ascent process for a random 3x3 gradient. The left panel plots the nuclear norm (being minimized) and tangency violation (being driven to zero) over dual iterations. The right panel shows the update direction evolving as $\Lambda$ adjusts. After convergence, the tangency violation is near zero and the update is valid on the Stiefel manifold.

05 Experimental results

The blog post reports CIFAR-10 experiments comparing optimizers. The key finding: Manifold Muon achieves the same final accuracy as Muon but with stable singular values throughout training. The weights never drift.

| Optimizer | CIFAR-10 Acc | σmax drift | Constraint |
| --- | --- | --- | --- |
| SGD + weight decay | 93.2% | Large (3-15x) | None (penalty only) |
| Adam | 94.1% | Moderate (2-5x) | None |
| Muon | 94.5% | Moderate (1.5-3x) | None (flat spectral) |
| Manifold Muon | 94.4% | Zero (all σ=1) | Stiefel (exact) |

The accuracy is competitive, but the real win is in training stability. With Manifold Muon, you never need to worry about learning rate warmup (the weights can’t explode on the first few steps), and you never need weight decay (the weights can’t grow). The manifold constraint absorbs the role of several ad-hoc tricks.

The trade-off is compute: with the default `dual_steps=5`, the code above performs five SVDs of an $m \times n$ matrix in the dual loop, one more for the final update, and one for the retraction, seven in total per step. For large matrices this is expensive. The modular manifold framework (Chapter 06) addresses this by being smart about which layers get manifold constraints.

06 The module abstraction

A module is a triple: (forward function, manifold, norm). Everything else follows.

The blog post’s key contribution isn’t just Manifold Muon — it’s the abstraction layer that lets you compose manifold-constrained layers into full architectures. A module is defined by three things:

Component 01: Forward

The computation: $y = f(w, x)$. For a linear layer, $f(W, x) = Wx$. For a norm layer, $f(\gamma, x) = \gamma \cdot x / \|x\|$.

Component 02: Manifold

The constraint surface for the parameters. Stiefel for matrices, hypersphere for vectors, Euclidean (unconstrained) for scalars.

Component 03: Norm

The distance metric that defines “one step.” Spectral for matrices, $\ell_2$ for vectors, absolute value for scalars.

Given these three pieces, the optimizer is fully determined. You don’t choose an optimizer separately — it falls out of the module specification. This is the manifold analog of how autograd derives backward passes from the forward computation.

06 StiefelLinear example

The simplest manifold module: a linear layer constrained to Stiefel.

Python — StiefelLinear Module
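import torch

# Hypothetical minimal stand-ins so this sketch runs on its own;
# a real framework would supply these.
class ManifoldModule:
    pass

class StiefelManifold:
    def __init__(self, shape):
        self.shape = shape

class SpectralNorm:
    pass

class StiefelLinear(ManifoldModule):
    """Linear layer on the Stiefel manifold (needs out_features >= in_features)."""

    def __init__(self, in_features, out_features):
        # Initialize on the manifold: orthonormalize a random Gaussian matrix
        W = torch.randn(out_features, in_features)
        U, _, Vt = torch.linalg.svd(W, full_matrices=False)
        self.W = U @ Vt  # On Stiefel from the start: W^T W = I

    def forward(self, x):
        return self.W @ x

    def manifold(self):
        return StiefelManifold(self.W.shape)

    def norm(self):
        return SpectralNorm()

    # The optimizer step is derived from manifold + norm:
    # dual ascent for the tangent spectral-norm step (Chapter 05) + polar_retract()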

The key insight: the module doesn’t know about Manifold Muon or dual ascent. It only specifies what it is (forward, manifold, norm), and the framework derives the optimizer from that. If you change the norm from spectral to Frobenius, you get a different optimizer: Riemannian gradient descent on Stiefel, the family that Cayley SGD belongs to. If you change the manifold from Stiefel to Euclidean, you get plain Muon.
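The schematic comment at the end of the module mentions a tangent projection and a polar retraction. The sketch below shows what those two pieces could look like in NumPy. It is a simplified projected step, not Manifold Muon's dual ascent, and the function names are illustrative assumptions.

```python
import numpy as np


def random_stiefel(m, n, rng):
    """Random m x n matrix with orthonormal columns (requires m >= n)."""
    U, _, Vt = np.linalg.svd(rng.standard_normal((m, n)), full_matrices=False)
    return U @ Vt


def tangent_project(G, W):
    """Project G onto the tangent space {A : W^T A + A^T W = 0} at W."""
    WtG = W.T @ G
    return G - W @ (WtG + WtG.T) / 2


def polar_retract(X):
    """Map a full-rank X to its polar factor, which has orthonormal columns."""
    U, _, Vt = np.linalg.svd(X, full_matrices=False)
    return U @ Vt


rng = np.random.default_rng(0)
W = random_stiefel(6, 4, rng)
G = rng.standard_normal((6, 4))      # stand-in for a loss gradient

A = tangent_project(G, W)
A = 0.1 * A / np.linalg.norm(A, 2)   # rescale to spectral norm eta = 0.1
W_new = polar_retract(W - A)

tangency_error = np.linalg.norm(W.T @ A + A.T @ W)
constraint_error = np.linalg.norm(W_new.T @ W_new - np.eye(4))
print(tangency_error, constraint_error)
```

Both printed errors should sit at floating-point round-off: the projected update is tangent at `W`, and the retracted point satisfies $W^TW = I$ again.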

06 Composition rules

How to combine modules into networks: the product manifold and max norm.

A neural network has multiple parameter groups: $(W_1, W_2, \ldots, W_L)$ for an $L$-layer network. Each lives on its own manifold. The full parameter space is the product manifold:

Product manifold $$\mathcal{M} = \mathcal{M}_1 \times \mathcal{M}_2 \times \ldots \times \mathcal{M}_L$$
  • Each factor $\mathcal{M}_i$ is the manifold for layer $i$ (Stiefel, hypersphere, Euclidean, ...)
  • A point on the product manifold is $(W_1, W_2, \ldots, W_L)$ with each $W_i \in \mathcal{M}_i$
  • Tangent vectors are tuples: $(A_1, A_2, \ldots, A_L)$ with each $A_i \in T_{W_i}\mathcal{M}_i$

But how do we define a single learning rate $\eta$ for the whole network? We need a norm on the product space. The blog post advocates the scaled max norm:

Module composition norm $$\|(A_1, A_2, \ldots, A_L)\| = \max_i\left(s_i \cdot \|A_i\|_i\right)$$
  • $(A_1, \ldots, A_L)$ is the tuple of per-layer updates (tangent vectors) whose size the norm measures
  • $\|A_i\|_i$ is the norm specific to module $i$ (spectral, $\ell_2$, etc.)
  • $s_i$ is a sensitivity scalar that reweights each layer’s contribution
  • The max means: no single layer can move more than its budget, $\|A_i\|_i \leq \eta / s_i$

The sensitivity scalars $s_i$ encode how much the network’s output changes when layer $i$ changes. A layer whose perturbations are amplified on the way to the output has high sensitivity (small changes have big effects). A layer whose perturbations are attenuated by the layers that follow has low sensitivity. Which case applies depends on the architecture, not only on depth.

The max norm is crucial. If we used a sum norm, a single “rogue” layer could consume the entire learning budget by taking a huge step while others stay still. The max norm prevents this: it imposes a per-layer budget that depends on that layer’s sensitivity.
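A small numerical sketch of the scaled max norm, assuming NumPy and three toy updates (a matrix, a vector and a scalar). The helper names are hypothetical. Rescaling each update so that $s_i \cdot \|A_i\|_i = \eta$ shows that, under a max norm, every layer can use its own budget $\eta / s_i$ at the same time.

```python
import numpy as np


def module_norm(update):
    """Per-module norm: spectral for matrices, l2 for vectors, abs for scalars."""
    update = np.asarray(update, dtype=float)
    if update.ndim == 2:
        return np.linalg.norm(update, 2)
    return np.linalg.norm(update.ravel())


def scaled_max_norm(updates, sensitivities):
    """max_i s_i * ||A_i||_i over all modules."""
    return max(s * module_norm(a) for a, s in zip(updates, sensitivities))


def scale_to_budget(updates, sensitivities, eta):
    """Rescale each update so that s_i * ||A_i||_i == eta for every module."""
    return [np.asarray(a, dtype=float) * eta / (s * module_norm(a))
            for a, s in zip(updates, sensitivities)]


updates = [np.diag([3.0, 1.0]), np.array([3.0, 4.0]), 2.0]
sensitivities = [1.0, 2.0, 0.5]

before = scaled_max_norm(updates, sensitivities)
scaled = scale_to_budget(updates, sensitivities, eta=0.1)
after = scaled_max_norm(scaled, sensitivities)
per_layer = [module_norm(a) for a in scaled]
print(before, after, per_layer)
```

With a sum norm in place of the max, the same global budget would have to be split between the three layers, which is the "rogue layer" failure mode described above.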

06 Product manifolds

The simplest example of a product manifold: a cylinder = line $\times$ circle. One parameter lives on a line (Euclidean), another lives on a circle (hypersphere). The product is a cylinder.

[Interactive figure: circle manifold (layer 1), line manifold (layer 2), product point on the cylinder]

The visualization shows a 2-module network. Layer 1 has a unit vector constrained to the circle (the hypersphere in 2D). Layer 2 is an unconstrained scalar (Euclidean). The product manifold is a cylinder. The learning rate budget is shared: an update $(a_1, a_2)$ must satisfy $\max(s_1 \cdot \|a_1\|, s_2 \cdot |a_2|) \leq \eta$. Adjusting the sliders shows how the per-layer budgets interact.
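The cylinder example can be sketched in a few lines, assuming NumPy. A point is a pair (unit vector on the circle, scalar on the line), and each factor takes a step of size $\eta / s_i$ in its own geometry. `cylinder_step` and its argument names are hypothetical.

```python
import numpy as np


def cylinder_step(point, grads, eta, s_circle=1.0, s_line=1.0):
    """One step on circle x line with a budget of eta / s_i per factor."""
    w, t = point       # w: unit 2-vector on the circle, t: scalar on the line
    g_w, g_t = grads

    # Circle factor: tangent projection, step of length eta / s_circle, renormalize
    a_w = g_w - np.dot(w, g_w) * w
    n_w = np.linalg.norm(a_w)
    if n_w > 0:
        w = w - (eta / s_circle) * a_w / n_w
        w = w / np.linalg.norm(w)

    # Line factor: flat space, so the step needs no projection or retraction
    t = t - (eta / s_line) * np.sign(g_t)
    return w, t


point = (np.array([0.0, 1.0]), 2.0)
grads = (np.array([1.0, 3.0]), -0.5)
w_new, t_new = cylinder_step(point, grads, eta=0.1, s_circle=2.0, s_line=1.0)
print(w_new, t_new)
```

The circle coordinate stays on the circle after the step, while the line coordinate moves freely; the two factors never interact except through the shared $\eta$.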

06 Budget scaling

How to set the sensitivity scalars from architecture analysis.

The sensitivity scalars $s_i$ should reflect how much the network output changes when layer $i$’s parameters change by one unit (in that layer’s norm). This is the Lipschitz sensitivity:

Sensitivity scalars from Lipschitz analysis $$\|f(x; W_i + A_i) - f(x; W_i)\| \lesssim s_i \cdot \|A_i\|_i \quad \text{for small } A_i \text{ and typical inputs } x$$
  • For a depth-$L$ network of Stiefel layers: $s_i \approx 1$ for all $i$ (each layer contributes equally)
  • For a ResNet with skip connections: sensitivity depends on branch structure
  • For a Transformer: attention layers have different sensitivity than FFN layers

The practical consequence: given a global learning rate $\eta$, each layer gets an effective learning rate $\eta_i = \eta / s_i$. Layers with high sensitivity get smaller steps. Layers with low sensitivity get larger steps. This is derived from the architecture, not tuned.

Python — Budget Scaling
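def compute_budgets(model, global_eta):
    """Derive per-layer learning rates from architecture.

    Assumes `model.named_modules()` yields (name, module) pairs, as in
    torch.nn.Module, and that each ManifoldModule provides a
    `lipschitz_sensitivity(model)` method returning s_i > 0. Both
    ManifoldModule and that method are hypothetical interfaces.
    """
    budgets = {}
    for name, module in model.named_modules():
        if isinstance(module, ManifoldModule):
            # Sensitivity s_i from a Lipschitz analysis of the forward pass
            s_i = module.lipschitz_sensitivity(model)
            if s_i <= 0:
                raise ValueError(f"sensitivity for {name!r} must be positive")
            # Max-norm constraint s_i * ||A_i|| <= eta gives ||A_i|| <= eta / s_i
            budgets[name] = global_eta / s_i
    return budgets

# Example for a 12-layer Stiefel transformer:
# All layers get eta_i = eta (uniform sensitivity)
# But attention QK layers might get eta/sqrt(d_head) due to softmax sensitivity

Concrete example: Consider a 3-layer network where layer 1 has sensitivity $s_1 = 1$, layer 2 (attention) has $s_2 = 4$ (softmax amplification), and layer 3 has $s_3 = 2$ (close to output).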

With global $\eta = 0.1$:
Layer 1 budget: $0.1 / 1 = 0.1$ (full budget)
Layer 2 budget: $0.1 / 4 = 0.025$ (small steps — attention is sensitive)
Layer 3 budget: $0.1 / 2 = 0.05$ (moderate — near output)

No per-layer LR tuning needed. The architecture determines the budget.
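The arithmetic of the concrete example can be reproduced directly. This is a plain-Python sketch; the layer names and the `per_layer_budgets` helper are illustrative, and the sensitivities are the made-up values from the example rather than measured ones.

```python
def per_layer_budgets(sensitivities, global_eta):
    """Map each layer name to its step-size budget eta / s_i."""
    return {name: global_eta / s for name, s in sensitivities.items()}


sensitivities = {"layer1": 1.0, "layer2_attention": 4.0, "layer3": 2.0}
budgets = per_layer_budgets(sensitivities, global_eta=0.1)

for name, eta_i in budgets.items():
    # Every layer saturates the same scaled budget: s_i * eta_i == global eta
    print(f"{name}: eta_i = {eta_i:.3f}, s_i * eta_i = {sensitivities[name] * eta_i:.3f}")
```

Each layer ends up with a different raw step size, but the product $s_i \cdot \eta_i$ is the same for all of them, which is what the max norm enforces.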

07 Connection to LoRA

Both are about weight parameterization. The manifold view unifies them.

LoRA (Low-Rank Adaptation) parameterizes weight updates as $\Delta W = BA$ where $B \in \mathbb{R}^{m \times r}$ and $A \in \mathbb{R}^{r \times n}$ with $r \ll \min(m,n)$. This implicitly constrains the update to a low-rank manifold.

The manifold view makes this explicit:

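  • LoRA manifold: $\{W_0 + BA \mid B \in \mathbb{R}^{m \times r}, A \in \mathbb{R}^{r \times n}\}$ is the set of matrices that differ from $W_0$ by a matrix of rank at most $r$. It is not a linear subspace, since the sum of two rank-$r$ updates can have rank up to $2r$
  • Stiefel manifold: $\{W \mid W^TW = I\}$ is the set of matrices with orthonormal columns (orthogonal matrices in the square case)
  • Combined: You could do LoRA on the Stiefel manifold: $W = W_0 \cdot \text{exp}(BA - A^TB^T)$ with $B \in \mathbb{R}^{n \times r}$ and $A \in \mathbb{R}^{r \times n}$, so that the exponent is an $n \times n$ skew-symmetric matrix. Its exponential is orthogonal, which ensures you stay on Stiefel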
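A quick numerical check of the two constructions in the list above, assuming NumPy. The shapes and scales are arbitrary illustrative choices, and `skew_expm` is a hypothetical helper that exponentiates a skew-symmetric matrix through a Hermitian eigendecomposition so that no extra dependency is needed.

```python
import numpy as np


def skew_expm(Omega):
    """Matrix exponential of a real skew-symmetric matrix."""
    # 1j * Omega is Hermitian, so eigh gives real eigenvalues and a unitary basis
    lam, V = np.linalg.eigh(1j * Omega)
    # Omega = V diag(-1j * lam) V^H, hence exp(Omega) = V diag(exp(-1j * lam)) V^H
    return ((V * np.exp(-1j * lam)) @ V.conj().T).real


rng = np.random.default_rng(0)
m, n, r = 8, 5, 2

# Plain LoRA: the update B @ A has rank at most r
B = rng.standard_normal((m, r))
A = rng.standard_normal((r, n))
lora_rank = np.linalg.matrix_rank(B @ A)

# Orthogonal variant: low-rank factors generate an n x n skew-symmetric matrix
U, _, Vt = np.linalg.svd(rng.standard_normal((m, n)), full_matrices=False)
W0 = U @ Vt                                # starting point on Stiefel
B_sq = 0.1 * rng.standard_normal((n, r))
A_sq = 0.1 * rng.standard_normal((r, n))
Omega = B_sq @ A_sq - (B_sq @ A_sq).T      # skew-symmetric, rank at most 2r
W = W0 @ skew_expm(Omega)

stiefel_error = np.linalg.norm(W.T @ W - np.eye(n))
print(lora_rank, stiefel_error)
```

Note that right-multiplying by an orthogonal matrix only rotates within the column space of $W_0$, so this particular parameterization does not reach every point on Stiefel when $m > n$.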

The Thinking Machines “LoRA Without Regret” blog post studies when LoRA can match full fine-tuning. Combining LoRA-style efficient fine-tuning with manifold constraints, as in the last bullet, is a natural direction to explore rather than something that post develops.

07 Connection to Muon

Muon (the original optimizer by Jordan et al.) is exactly “Manifold Muon without the manifold.” It uses the same spectral-norm distance metric to normalize gradient steps, but without constraining the weights to Stiefel:

| Property | Muon | Manifold Muon |
| --- | --- | --- |
| Manifold | Euclidean (unconstrained) | Stiefel ($W^TW = I$) |
| Step norm | Spectral ($\lVert A \rVert_\sigma \leq \eta$) | Spectral ($\lVert A \rVert_\sigma \leq \eta$) |
| Update formula | $W \leftarrow W - \eta \cdot \text{msign}(G)$ | $W \leftarrow \text{retract}(W - \eta \cdot \text{msign}(G + W\Lambda))$ |
| Weight health | Not guaranteed | Guaranteed ($\sigma_i = 1\ \forall i$) |
| Cost per step | 1 SVD | ~6 SVDs (dual ascent + retraction) |

The relationship is clean: Manifold Muon = Muon + tangency correction + retraction. The $W\Lambda$ term added to the gradient is exactly the correction needed to make the Muon step tangent to Stiefel. Set $\Lambda = 0$ and drop the retraction, and you recover plain Muon.
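The plain Muon column of the table is easy to check numerically. The sketch below, assuming NumPy, computes `msign` with an SVD and confirms that the step has spectral norm exactly $\eta$, while the weights drift off Stiefel because nothing corrects for tangency. It does not implement the dual ascent for $\Lambda$.

```python
import numpy as np


def msign(G):
    """Matrix sign U V^T: sets every singular value to 1 (assumes full-rank G)."""
    U, _, Vt = np.linalg.svd(G, full_matrices=False)
    return U @ Vt


rng = np.random.default_rng(0)
m, n, eta = 6, 4, 0.1

W = msign(rng.standard_normal((m, n)))   # start on Stiefel: all singular values are 1
G = rng.standard_normal((m, n))          # stand-in for a loss gradient

step = eta * msign(G)
W_muon = W - step                        # plain Muon: no tangency correction, no retraction

step_size = np.linalg.norm(step, 2)                  # spectral norm of the update
sv_after = np.linalg.svd(W_muon, compute_uv=False)   # singular values after the step
drift = np.max(np.abs(sv_after - 1.0))
print(step_size, drift)
```

The drift is bounded by $\eta$ for a single step, but it accumulates over training, which is the "weight health not guaranteed" row of the table.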

07 Lipschitz-constrained learning

There’s a deep connection between manifold constraints and Lipschitz bounds. A function $f: \mathbb{R}^n \to \mathbb{R}^m$ is $L$-Lipschitz if $\|f(x) - f(y)\| \leq L \|x - y\|$ for all $x, y$.

For a linear layer $f(x) = Wx$, the Lipschitz constant is exactly $\|W\|_\sigma$ (the spectral norm). Constraining $W$ to Stiefel means $\|W\|_\sigma = 1$, so the layer is 1-Lipschitz.

For a deep network $f = f_L \circ f_{L-1} \circ \ldots \circ f_1$, the overall Lipschitz constant is bounded by the product: $L_{\text{total}} \leq \prod_i L_i$. If every linear layer is 1-Lipschitz (Stiefel) and every activation in between is 1-Lipschitz too (ReLU, for example), the whole network is 1-Lipschitz. This gives:

  • Training stability: Gradient norms are bounded — no explosion
  • Robustness: Small input perturbations cause small output changes
  • Generalization: Lipschitz-bounded networks have tighter generalization bounds

The manifold view makes Lipschitz constraints easy to state. Instead of adding spectral norm penalties or projecting after each step as a separate mechanism, you simply declare the manifold. The optimizer then maintains the constraint at every step, with the retraction built into the update; the price is the extra SVDs counted above.
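The 1-Lipschitz claim can be checked empirically. The sketch below, assuming NumPy, builds a small network of Stiefel layers with ReLU between them and measures the ratio $\|f(x) - f(y)\| / \|x - y\|$ on random input pairs. The layer sizes are arbitrary, and random sampling only gives a lower estimate of the true Lipschitz constant.

```python
import numpy as np


def random_stiefel(m, n, rng):
    """Random m x n matrix with orthonormal columns (requires m >= n)."""
    U, _, Vt = np.linalg.svd(rng.standard_normal((m, n)), full_matrices=False)
    return U @ Vt


def network(x, weights):
    """Stiefel linear layers with ReLU in between (ReLU is itself 1-Lipschitz)."""
    for W in weights[:-1]:
        x = np.maximum(W @ x, 0.0)
    return weights[-1] @ x


rng = np.random.default_rng(0)
weights = [random_stiefel(8, 4, rng), random_stiefel(16, 8, rng), random_stiefel(16, 16, rng)]

ratios = []
for _ in range(1000):
    x, y = rng.standard_normal(4), rng.standard_normal(4)
    ratios.append(np.linalg.norm(network(x, weights) - network(y, weights))
                  / np.linalg.norm(x - y))
worst_ratio = max(ratios)
spectral_norms = [np.linalg.norm(W, 2) for W in weights]
print(worst_ratio, spectral_norms)
```

Every spectral norm is 1 up to round-off, and no sampled ratio exceeds 1, as the product bound predicts.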

07 eNTK and local linearization

The empirical Neural Tangent Kernel (eNTK) is built from the Jacobian of the network output w.r.t. its parameters, $J(x) = \partial f(x) / \partial \theta$. The kernel itself is $K(x, x') = J(x) J(x')^T$.

On a manifold, the relevant object is the projected eNTK: the Jacobian restricted to the tangent space. This tells you how much the output changes for tangent directions only (feasible parameter updates). A plausible reading is that this projected eNTK, not the full eNTK, is the object that governs learning dynamics and generalization in manifold-constrained networks.

Connection to budget scaling: the sensitivity scalar $s_i$ for layer $i$ is essentially the operator norm of the projected Jacobian block for that layer. Layers where the projected Jacobian is large have high sensitivity and get smaller budgets.
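A toy version of the projected eNTK, assuming NumPy and a deliberately small model $f(w, x) = \tanh(w \cdot x)$ with $w$ on the unit sphere. The model and the variable names are illustrative assumptions; the point is only to show the full kernel, the tangent-projected kernel, and the operator norm used as a sensitivity proxy.

```python
import numpy as np


def jacobian(w, X):
    """Rows are d f(w, x_i) / d w for the toy model f(w, x) = tanh(w . x)."""
    pre = X @ w
    return (1.0 - np.tanh(pre) ** 2)[:, None] * X


rng = np.random.default_rng(0)
d, N = 5, 3
w = rng.standard_normal(d)
w /= np.linalg.norm(w)                 # parameters live on the unit sphere
X = rng.standard_normal((N, d))        # a small batch of inputs

J = jacobian(w, X)                     # shape (N, d)
P = np.eye(d) - np.outer(w, w)         # projector onto the tangent space at w

K_full = J @ J.T                       # eNTK: K(x, x') = J(x) J(x')^T
K_proj = J @ P @ J.T                   # same kernel restricted to tangent directions

# Sensitivity proxy: largest output response to a unit-norm tangent update
sensitivity = np.linalg.norm(J @ P, 2)
print(K_full.shape, K_proj.shape, sensitivity)
```

The difference `K_full - K_proj` is positive semidefinite: projecting can only remove the response to the radial direction, which is not a feasible update on the sphere.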

07 Open directions

01
Open Problem

MoE Layers

Mixture-of-Expert layers have multiple weight matrices per layer. How should their manifold constraints and budgets interact? The gating mechanism adds non-trivial coupling.

02
Open Problem

Low-Precision Training

SVD is expensive and ill-suited to FP16/BF16 arithmetic. Can we approximate the retraction cheaply? Newton-Schulz iterations (used in Muon) are one candidate.

03
Open Problem

Architecture Co-Design

If the optimizer is derived from the architecture, can we design architectures that are optimal for manifold training? What does a “manifold-native” transformer look like?

04
Open Problem

Adaptive Manifolds

The manifold is fixed during training. Could we learn the constraint surface itself? Start on Stiefel, relax toward Euclidean as training progresses?
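For the low-precision item above, here is a sketch of the classic cubic Newton-Schulz iteration, which approximates the polar factor $UV^T$ using only matrix multiplications. It assumes NumPy and runs in float64. It is the textbook iteration, not the tuned polynomial used in Muon implementations, and the step count is an arbitrary illustrative choice.

```python
import numpy as np


def newton_schulz_polar(X, steps=40):
    """Approximate the polar factor U V^T of a full-rank X without an SVD."""
    # Scale so every singular value lies in (0, 1], inside the convergence region
    Y = X / np.linalg.norm(X)          # Frobenius norm upper-bounds the spectral norm
    for _ in range(steps):
        # Cubic update: maps each singular value s to 1.5 * s - 0.5 * s**3
        Y = 1.5 * Y - 0.5 * Y @ (Y.T @ Y)
    return Y


rng = np.random.default_rng(0)
X = rng.standard_normal((6, 4))

U, _, Vt = np.linalg.svd(X, full_matrices=False)
exact = U @ Vt
approx = newton_schulz_polar(X)
error = np.linalg.norm(approx - exact)
print(error)
```

Small singular values grow by roughly a factor of 1.5 per step before converging quickly near 1, so ill-conditioned inputs need more iterations. Whether this stays accurate enough in FP16/BF16 is exactly the open question.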

07 References

  1. Bernstein, J. “Modular Manifolds.” Thinking Machines Lab Blog, Sep 2025.
  2. Schulman, J. and Thinking Machines Lab. “LoRA Without Regret.” Thinking Machines Lab Blog, 2025.
  3. Jordan, K. et al. “Muon: An optimizer for hidden layers in neural networks.” 2024.
  4. Absil, P.-A., Mahony, R., Sepulchre, R. Optimization Algorithms on Matrix Manifolds. Princeton University Press, 2008.
  5. Edelman, A., Arias, T., Smith, S. “The Geometry of Algorithms with Orthogonality Constraints.” SIAM J. Matrix Anal. Appl., 1998.
  6. Li, H., Xu, Z., Taylor, G., Studer, C., Goldstein, T. “Visualizing the Loss Landscape of Neural Nets.” NeurIPS, 2018. arXiv:1712.09913
  7. Miyato, T., Kataoka, T., Koyama, M., Yoshida, Y. “Spectral Normalization for Generative Adversarial Networks.” ICLR, 2018. arXiv:1802.05957
  8. Jacot, A., Gabriel, F., Hongler, C. “Neural Tangent Kernel: Convergence and Generalization in Neural Networks.” NeurIPS, 2018. arXiv:1806.07572
  9. Hu, E.J. et al. “LoRA: Low-Rank Adaptation of Large Language Models.” ICLR, 2022. arXiv:2106.09685
  10. Krishnan, V., Al Makdah, A. A., Pasqualetti, F. “Lipschitz Bounds and Provably Robust Training by Laplacian Smoothing.” NeurIPS, 2020.