Leverage the sequential nature of language to disentangle semantic from syntactic features in a self-supervised manner — one contrastive loss term changes everything.
You train a Sparse Autoencoder on the internal representations of GPT-2, hoping to discover what the model "knows." You feed in thousands of tokens, learn 16,000 latent features, and excitedly look at what fires. Feature 11795: "the phrase 'The' at the start of sentences." Feature 3042: "sentence endings or periods." Feature 8901: "code syntax endings."
These are real features from publicly available SAEs on Neuronpedia. They are real, interpretable, and completely useless for understanding what the model is thinking.
You wanted to find concepts like "discussion of plant biology" or "scientific explanation" or "legal reasoning." Instead you got a catalog of punctuation patterns and capitalization rules. The SAE recovered syntax — the mechanical grammar of language — but missed semantics — the meaning.
This isn't just an aesthetic complaint. Features that fire on "the word 'the'" are useless for three critical downstream tasks:
The simulation below shows what happens when you decompose a sequence with a standard SAE vs. what we'd want to see. The standard SAE's top features are noisy and token-specific. The ideal features would be smooth and topic-aligned.
Top: a standard SAE's feature activations over a 3-topic sequence, showing noisy per-token spikes. Bottom: ideal semantic features, smooth and topic-aligned.
The question this paper asks: Can we modify the SAE training objective — with one additional loss term — to recover smooth, semantic features instead of noisy syntactic ones?
The answer, remarkably, is yes. And the modification is almost trivially simple.
Read this sentence: "Photosynthesis is the process by which plants convert sunlight into energy."
Now ask yourself two questions about each token:
This isn't a new idea in linguistics. Chomsky (1965) distinguished syntactic and semantic structure. Griffiths et al. (2004) formally argued that semantics exhibit long-range dependencies while syntax depends on short-range interactions. What's new is using this insight as a training objective for SAEs.
Think of an LLM's hidden representation at each token as a radio signal. It carries two overlapping broadcasts simultaneously:
A standard SAE mixes them together. It has no reason to separate them. T-SAEs add one constraint: some features must look like the FM station — smooth and consistent over time. The rest automatically become the AM modulation.
Feature activations differ in temporal scale. Semantic features change slowly because they track the topic; syntactic features fluctuate rapidly because they track per-token grammar. The insight: we can use this difference as a training signal.
Several prior approaches tried to improve SAEs: Matryoshka SAEs (Bussmann et al., 2025) learn hierarchical features but don't enforce temporal structure. Transcoders (Paulo et al., 2025) learn causal features but still treat tokens as i.i.d. BatchTopK (Bussmann et al., 2024) improves sparsity but doesn't address the semantic/syntactic split.
T-SAEs are the first to inject a structural prior from linguistics into SAE training: the temporal consistency of meaning. The modification is a single contrastive loss term added to the standard objective. Everything else — the encoder, decoder, sparsity — stays the same.
Before building the T-SAE architecture, the authors formalize why temporal consistency should work. They propose a simple model of how language is generated, then show what it implies for dictionary learning.
Imagine a speaker producing a sequence of tokens τ1, ..., τT. At each timestep t, the speaker's word choice is controlled by several factors:
The speaker produces each token as a function of context and both latent types:
where τt-1 = (τ1, ..., τt-1) is all previous tokens. Think of ht as "what to say" and lt as "how to say this particular word."
The first key assumption formalizes what we just described intuitively:
This doesn't mean ht is exactly the same at every position — it means it changes slowly. A paragraph about biology has the same topic at token 1 and token 50. The topic might shift at the paragraph boundary (say, to history), but within a passage it's approximately constant.
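To make the first assumption concrete, here is a toy simulation of the generative story. It is not the paper's model. A topic latent $h_t$ can only change at passage boundaries, and a syntactic latent $l_t$ is redrawn at every token. The topic and role inventories, the 50-token passage length and the function names are all hypothetical.

```python
import random


def generate_latents(n_tokens, passage_len=50, seed=0):
    """Toy stand-in for the generative story (hypothetical, not the paper's model):
    h_t is a topic held fixed within each passage, l_t is a syntactic role redrawn every token."""
    rng = random.Random(seed)
    topics = ["biology", "history", "law", "math"]
    roles = ["DET", "NOUN", "VERB", "ADP"]
    h, l = [], []
    topic = rng.choice(topics)
    for t in range(n_tokens):
        if t > 0 and t % passage_len == 0:  # topic can only shift at a passage boundary
            topic = rng.choice(topics)
        h.append(topic)
        l.append(rng.choice(roles))  # syntax varies freely from token to token
    return h, l


def change_rate(seq):
    """Fraction of adjacent positions (t-1, t) where the latent value changes."""
    return sum(a != b for a, b in zip(seq, seq[1:])) / (len(seq) - 1)


h, l = generate_latents(200)
print(change_rate(h), change_rate(l))
```

Counting changes between adjacent positions shows the asymmetry the T-SAE exploits. The topic changes at most a few times in 200 tokens, while the syntactic role changes at most positions.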
The second assumption says that the LLM's internal representation xt encodes both ht and lt, and they're hierarchical:
Why is this hierarchical? Because the high-level features do most of the reconstruction work. If you know the topic is "plant biology," you can already predict a lot about the activations. The low-level features clean up the residual — they capture the stuff that the topic alone can't explain (like whether this specific token is a noun or a verb).
If these assumptions hold (and the experimental evidence strongly supports them), then the optimal SAE training strategy has two parts:
This is exactly the Matryoshka-style hierarchy (Bussmann et al., 2025), but with a crucial addition: the temporal contrastive loss on the high-level split. Without it, both splits learn the same kind of features. With it, they specialize.
The high-level features (h) reconstruct most of the signal (shown as a smooth approximation). The low-level features (l) capture the fast-changing residual. Together they perfectly reconstruct xt.
Consider a 768-dimensional activation vector $x_t$ from layer 8 of Pythia-160m. The T-SAE has m = 16,384 features, split 20/80: h = 3,277 high-level features and 13,107 low-level features.
At token t = "Photosynthesis" and token t+1 = "is":
T-SAEs modify the standard SAE architecture in exactly two ways: (1) partition the feature space into high-level and low-level splits, and (2) add a contrastive loss on the high-level split. Let's walk through every component.
The encoder takes a model activation $x_t \in \mathbb{R}^d$ (e.g., d = 768 for Pythia-160m, d = 2304 for Gemma2-2b) and produces a sparse feature vector $f(x_t) \in \mathbb{R}^m$ (e.g., m = 16,384 features):

$$f(x_t) = \sigma(W_{\text{enc}}\, x_t + b_{\text{enc}})$$

where $W_{\text{enc}} \in \mathbb{R}^{m \times d}$ is the encoder weight matrix, $b_{\text{enc}} \in \mathbb{R}^{m}$ is the encoder bias, and σ is the activation function. The paper uses a BatchTopK activation (k = 20). Across a batch of B tokens, only the B·k largest activations are kept, so on average 20 of the 16,384 features are non-zero per token. This enforces strict sparsity.
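The m features are partitioned into two groups: the high-level features $f_{0:h}(x_t)$ (the first h features) and the low-level features $f_{h:m}(x_t)$ (the remaining m − h features).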
For a 16k SAE: h = 3,277 high-level features, 13,107 low-level features. The split ratio (20/80) is a hyperparameter — ablations in Section 4.6 of the paper show that 10/90 and 20/80 both work well, while 50/50 hurts syntax recovery.
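The decoder reconstructs $x_t$ from the sparse features:

$$\hat{x}_t = W_{\text{dec}}\, f(x_t) + b_{\text{dec}}$$

where $W_{\text{dec}} \in \mathbb{R}^{d \times m}$ is the decoder matrix and $b_{\text{dec}} \in \mathbb{R}^{d}$ is the decoder bias. The decoder is split to match the feature partition. $W_{\text{dec}}^{0:h} \in \mathbb{R}^{d \times h}$ multiplies the high-level features, and $W_{\text{dec}}^{h:m} \in \mathbb{R}^{d \times (m-h)}$ multiplies the low-level features.
Their concatenation equals the full decoder: $W_{\text{dec}} = [\,W_{\text{dec}}^{0:h} \mid W_{\text{dec}}^{h:m}\,]$.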
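Here is a quick numerical check of the decoder split. Toy sizes and random weights stand in for a trained decoder. Slicing $W_{\text{dec}}$ column-wise into the high-level and low-level blocks and summing their contributions reproduces the full reconstruction. Dropping the low-level block gives the high-level-only reconstruction used in the high-level loss.

```python
import torch

torch.manual_seed(0)
d, m, h = 8, 20, 4              # toy sizes (the paper: d = 768 or 2304, m = 16,384, h = 0.2m)
W_dec = torch.randn(d, m)       # random stand-in for a trained decoder
b_dec = torch.randn(d)
f = torch.relu(torch.randn(m))  # stand-in sparse-ish feature vector

W_high, W_low = W_dec[:, :h], W_dec[:, h:]  # W_dec^{0:h} and W_dec^{h:m}
assert torch.equal(torch.cat([W_high, W_low], dim=1), W_dec)

x_hat = W_dec @ f + b_dec                             # full reconstruction
x_hat_split = W_high @ f[:h] + W_low @ f[h:] + b_dec  # sum of the two blocks
x_hat_high = W_high @ f[:h] + b_dec                   # high-level-only reconstruction
print(torch.allclose(x_hat, x_hat_split, atol=1e-5))
```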
The complete data flow from input activation to reconstruction. High-level features (orange) are encouraged to be temporally consistent via the contrastive loss. Low-level features (teal) capture the residual.
| Component | Shape | Value |
|---|---|---|
| Input activation $x_t$ | $\mathbb{R}^{d}$ | d = 768 (Pythia) or 2304 (Gemma2-2b) |
| Encoder $W_{\text{enc}}$ | $\mathbb{R}^{m \times d}$ | m = 16,384 features |
| Feature vector $f(x_t)$ | $\mathbb{R}^{m}$ | k = 20 nonzero entries per token on average (BatchTopK) |
| High-level split | $\mathbb{R}^{h}$ | h = 0.2m ≈ 3,277 features |
| Low-level split | $\mathbb{R}^{m-h}$ | m − h = 13,107 features |
| Decoder $W_{\text{dec}}$ | $\mathbb{R}^{d \times m}$ | Same d, m as encoder |
| Contrastive weight α | scalar | α = 1.0 |
| Model layer | — | Layer 8 (Pythia), Layer 12 (Gemma) |
```python
import torch
import torch.nn as nn


class TemporalSAE(nn.Module):
    def __init__(self, d_model, n_features, n_high, k=20):
        super().__init__()
        self.n_high = n_high  # h: number of high-level features
        self.k = k            # BatchTopK: average number of active features per token
        self.encoder = nn.Linear(d_model, n_features)  # W_enc, b_enc
        self.decoder = nn.Linear(n_features, d_model)  # W_dec, b_dec

    def encode(self, x):
        # x: (B, d_model) -> f: (B, n_features)
        pre_act = torch.relu(self.encoder(x))  # (B, m), non-negative
        # BatchTopK: keep the k * B largest activations across the whole batch,
        # so each token has k active features on average rather than exactly k.
        # (At inference, BatchTopK SAEs typically switch to a fixed threshold; omitted here.)
        flat = pre_act.flatten()
        topk_vals, topk_idx = torch.topk(flat, self.k * x.shape[0])
        f = torch.zeros_like(flat).scatter(0, topk_idx, topk_vals)
        return f.view_as(pre_act)

    def decode_high(self, f):
        # Reconstruct using only the high-level features f[:, :h]
        f_high = f.clone()
        f_high[:, self.n_high:] = 0  # zero out low-level features
        return self.decoder(f_high)

    def forward(self, x):
        f = self.encode(x)                # (B, m), sparse
        x_hat = self.decoder(f)           # full reconstruction
        x_hat_high = self.decode_high(f)  # high-level-only reconstruction
        z = f[:, :self.n_high]            # high-level features for the contrastive loss
        return x_hat, x_hat_high, z
```
This is the heart of the paper. Everything else — the feature split, the Matryoshka reconstruction, the architecture — is borrowed from prior work. The contrastive loss is what makes T-SAEs work.
Suppose you just train with the Matryoshka reconstruction loss (LH + LL) without any temporal constraint. What happens? Both feature groups learn the same kind of features. The high-level split has no reason to prefer smooth, semantic features over noisy, syntactic ones. The split is meaningless.
We need a loss term that says: "the high-level features for token t should be similar to the high-level features for token t−1." But we also need to prevent collapse — we don't want ALL tokens in the batch to have the same high-level features (that would just be a constant bias term).
This is exactly the setup for contrastive learning: pull positive pairs together, push negative pairs apart.
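Let $z_t = f_{0:h}(x_t)$ be the high-level features of token t, and let $s(x, y)$ be the cosine similarity between x and y. We define positive pairs as $(z_t^{(i)}, z_{t-1}^{(i)})$, adjacent tokens from the same sequence i. Negative pairs are $(z_t^{(i)}, z_{t-1}^{(j)})$ with $j \neq i$, tokens from different sequences in the batch.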
The contrastive loss is a symmetric InfoNCE-style objective:
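The exact published formula is not reproduced here. The version below is written to match the `contrastive_loss` code in this section, which uses cosine similarities directly as logits (no temperature). For a batch of N sequences, the symmetric objective reads:

$$
\mathcal{L}_{\text{contr}} = -\frac{1}{2N}\sum_{i=1}^{N}\left[\log\frac{\exp s\big(z_t^{(i)}, z_{t-1}^{(i)}\big)}{\sum_{j=1}^{N}\exp s\big(z_t^{(i)}, z_{t-1}^{(j)}\big)} + \log\frac{\exp s\big(z_t^{(i)}, z_{t-1}^{(i)}\big)}{\sum_{j=1}^{N}\exp s\big(z_t^{(j)}, z_{t-1}^{(i)}\big)}\right]
$$

The first term asks each $z_t^{(i)}$ to pick out its own $z_{t-1}^{(i)}$ among all previous-token codes in the batch. The second term does the reverse.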
Let's break this down piece by piece:
Let's trace through a concrete batch with N = 4 sequences:
| Sequence i | Token t content | Token t-1 content | s(zt(i), zt-1(i)) |
|---|---|---|---|
| 1 | "plants" (biology) | "convert" (biology) | 0.9 |
| 2 | "war" (history) | "Napoleon" (history) | 0.85 |
| 3 | "integral" (math) | "derivative" (math) | 0.88 |
| 4 | "plaintiff" (law) | "verdict" (law) | 0.82 |
Cross-sequence similarities (negatives) should be low:
For sequence 1, the first term becomes:
If the model improves the positive similarity to 0.95 while keeping negatives at 0.1:
The loss decreases. The gradient pushes the encoder to make same-sequence high-level features more similar and cross-sequence features less similar.
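The arithmetic for sequence 1 can be reproduced in a few lines. One assumption is needed: all three cross-sequence negatives are taken to be 0.1, the value the text holds them at in the improved case. As in the `contrastive_loss` code, there is no temperature. The helper name is hypothetical.

```python
import math


def infonce_row_term(pos_sim, neg_sims):
    """-log of the softmax probability of the positive pair (cosine sims as logits, no temperature)."""
    logits = [pos_sim] + list(neg_sims)
    denom = sum(math.exp(s) for s in logits)
    return -math.log(math.exp(pos_sim) / denom)


# Sequence 1: positive similarity 0.9; three negatives assumed to be 0.1 each.
before = infonce_row_term(0.9, [0.1, 0.1, 0.1])
after = infonce_row_term(0.95, [0.1, 0.1, 0.1])
print(f"{before:.3f} -> {after:.3f}")  # 0.854 -> 0.825
```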
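Combining everything, the total T-SAE loss is:

$$\mathcal{L} = \mathcal{L}_H + \mathcal{L}_L + \alpha\, \mathcal{L}_{\text{contr}}$$

where $\mathcal{L}_H$ is the reconstruction loss using only the high-level features and $\mathcal{L}_L$ is the reconstruction loss using all features. $\mathcal{L}_{\text{contr}}$ is the contrastive loss above, and α = 1.0 weights the contrastive term.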
In practice, activations are loaded as pairs (xt, xt-1) — adjacent tokens from the same sequence. The pairs are shuffled across the batch so that negative pairs come from different sequences. This means each training batch is 2× the normal size (each sample is a pair), which reduces the effective batch size by half for the same memory budget.
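Here is a minimal sketch of the pair loading. It assumes activations have been cached as one (n_seqs, seq_len, d) tensor; that layout and the name `make_adjacent_pairs` are hypothetical. Each pair is drawn from a distinct sequence, so every off-diagonal (negative) pair in the batch comes from a different sequence.

```python
import torch


def make_adjacent_pairs(acts, n_pairs, generator=None):
    """Sample (x_t, x_{t-1}) pairs of adjacent tokens, one pair per sequence.

    acts: (n_seqs, seq_len, d) cached activations (hypothetical storage layout).
    Requires n_pairs <= n_seqs so that all pairs come from distinct sequences.
    """
    n_seqs, seq_len, _ = acts.shape
    seq_idx = torch.randperm(n_seqs, generator=generator)[:n_pairs]  # distinct sequences
    t = torch.randint(1, seq_len, (n_pairs,), generator=generator)   # t >= 1 so t-1 exists
    return acts[seq_idx, t], acts[seq_idx, t - 1]
```

Feeding `x_t` and `x_prev` from one call into `tsae_loss` gives a batch where row i of the similarity matrix has exactly one positive, on the diagonal.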
```python
import torch
import torch.nn.functional as F


def contrastive_loss(z_t, z_prev):
    # z_t, z_prev: (N, h) high-level features for adjacent tokens
    # Normalize to unit length; eps guards against all-zero sparse feature vectors
    z_t_norm = F.normalize(z_t, dim=-1, eps=1e-8)
    z_prev_norm = F.normalize(z_prev, dim=-1, eps=1e-8)
    sim = z_t_norm @ z_prev_norm.T  # (N, N) pairwise cosine similarities
    # Positive pairs are on the diagonal: sim[i, i]
    labels = torch.arange(z_t.size(0), device=z_t.device)
    # Row-wise: for each z_t^(i), classify the correct z_prev^(i)
    loss_row = F.cross_entropy(sim, labels)
    # Column-wise: for each z_prev^(i), classify the correct z_t^(i)
    loss_col = F.cross_entropy(sim.T, labels)
    return (loss_row + loss_col) / 2


def tsae_loss(x_t, x_prev, model, alpha=1.0):
    x_hat_t, x_hat_high_t, z_t = model(x_t)
    _, _, z_prev = model(x_prev)  # only the high-level features of x_{t-1} are needed
    L_H = (x_t - x_hat_high_t).pow(2).sum(-1).mean()  # high-level reconstruction
    L_L = (x_t - x_hat_t).pow(2).sum(-1).mean()       # full reconstruction
    L_matr = L_H + L_L
    L_contr = contrastive_loss(z_t, z_prev)
    return L_matr + alpha * L_contr
```
The whole point of T-SAEs is to separate semantic from syntactic features. But does it actually work? And how do we even measure "disentanglement"?
The paper uses linear probes — simple logistic regression classifiers trained on top of SAE features — to measure what information each feature split encodes.
The setup: take MMLU questions (multiple-choice academic questions from 57 subjects), encode the last 20 tokens of each question through the LLM, extract SAE features, and train probes to predict:
And that's exactly what happens. Using Gemma2-2b SAEs on MMLU:
| SAE Type | Semantics (Acc) | Context (Acc) | Syntax (Acc) |
|---|---|---|---|
| T-SAE (all features) | 0.91 | 0.95 | 0.81 |
| Matryoshka SAE | 0.82 | 0.87 | 0.83 |
| BatchTopK SAE | 0.80 | 0.85 | 0.82 |
| Baseline (model activations) | 0.84 | 0.89 | 0.79 |
The T-SAE beats the other SAEs on semantics by 9+ percentage points and on context by 8+ points. Against probes on the raw model activations, its margins are 7 and 6 points respectively. It's slightly worse than the other SAEs on syntax, but that's expected and actually desirable, because the high-level features are now capturing semantics instead of spending capacity on syntax.
The paper's Figure 2 provides a stunning visual confirmation. When you plot T-SAE high-level feature activations in 2D (via t-SNE) and color by question category, you see clean clusters for each subject. Matryoshka SAE features, plotted the same way, show no clear clustering.
Even more telling: when you color by syntax (part-of-speech), T-SAE low-level features cluster cleanly by POS tag, while T-SAE high-level features don't. The split is working.
Each dot represents a token's SAE features projected to 2D. Orange = T-SAE high-level, Teal = T-SAE low-level, Gray = Baseline SAE.
The paper goes further: it probes each split separately. For T-SAEs:
For Matryoshka SAEs, both splits perform similarly on all three tasks — no specialization. The 20/80 partition is arbitrary without the contrastive loss to drive differentiation.
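The per-split probing can be sketched as a logistic-regression probe in plain PyTorch, trained separately on each feature split. The features below are synthetic and hypothetical: a "high" split that carries a binary label and a "low" split of pure noise. The resulting numbers therefore say nothing about the paper's results; the sketch only shows the mechanics of probing one split at a time. The helper name is hypothetical.

```python
import torch


def probe_accuracy(feats, labels, n_train, steps=300, lr=0.1):
    """Train a linear logistic-regression probe on the first n_train rows;
    return accuracy on the held-out remainder."""
    w = torch.zeros(feats.shape[1], requires_grad=True)
    b = torch.zeros(1, requires_grad=True)
    opt = torch.optim.Adam([w, b], lr=lr)
    X_tr, y_tr = feats[:n_train], labels[:n_train].float()
    for _ in range(steps):
        opt.zero_grad()
        loss = torch.nn.functional.binary_cross_entropy_with_logits(X_tr @ w + b, y_tr)
        loss.backward()
        opt.step()
    with torch.no_grad():
        preds = (feats[n_train:] @ w + b) > 0
        return (preds == labels[n_train:].bool()).float().mean().item()


# Hypothetical data: the "high" split encodes the label, the "low" split is noise.
torch.manual_seed(0)
n, h, l = 400, 8, 8
labels = torch.randint(0, 2, (n,))
high = torch.randn(n, h) + 2.0 * (labels.float().unsqueeze(1) * 2 - 1)
low = torch.randn(n, l)
acc_high = probe_accuracy(high, labels, n_train=300)
acc_low = probe_accuracy(low, labels, n_train=300)
print(f"high split: {acc_high:.2f}, low split: {acc_low:.2f}")
```

With a trained T-SAE, the two synthetic tensors would be replaced by `f[:, :h]` and `f[:, h:]` and by real topic, context or part-of-speech labels.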
A natural worry: does adding the contrastive loss hurt reconstruction quality? If T-SAEs disentangle features but can't reconstruct the input, they're useless. The paper evaluates five standard SAE metrics.
| Metric | What it measures | Higher or lower is better? |
|---|---|---|
| FVE (Fraction Variance Explained) | 1 − Var(x − x̂) / Var(x). How much of the input's variance the SAE captures. | Higher ↑ |
| Cosine Similarity | cos(x, x̂). Are the reconstructions pointing in the right direction? | Higher ↑ |
| Fraction Alive | What fraction of the 16k features activate at least once on the test data? | Higher ↑ |
| Smoothness | Average max absolute change in active feature activations, normalized by the change in the model's activations. Lower = smoother. | Lower ↓ |
| AutoInterp Score | Can an LLM (Llama3.3-70B) generate a correct feature explanation? SAEBench evaluation. | Higher ↑ |
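The first three metrics are simple to compute. The sketch below is one way to do it in PyTorch. It assumes Var(·) means the total variance summed over the d dimensions; the paper's exact convention is not stated here. `reconstruction_metrics` is a hypothetical helper.

```python
import torch


def reconstruction_metrics(x, x_hat, f):
    """x, x_hat: (N, d) activations and reconstructions; f: (N, m) SAE features.

    Assumption: Var(.) is the total variance (per-dimension variance, summed over d).
    """
    var_resid = (x - x_hat).var(dim=0, unbiased=False).sum()
    var_x = x.var(dim=0, unbiased=False).sum()
    fve = 1 - var_resid / var_x                                        # fraction of variance explained
    cos = torch.nn.functional.cosine_similarity(x, x_hat, dim=-1).mean()  # mean per-token cosine
    frac_alive = (f != 0).any(dim=0).float().mean()                     # features firing at least once
    return fve.item(), cos.item(), frac_alive.item()
```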
| Model | SAE | FVE ↑ | CosSim ↑ | Alive ↑ | Smooth (High) ↓ | Smooth (Low) ↓ | AutoInterp ↑ |
|---|---|---|---|---|---|---|---|
| Pythia-160m | Temporal SAE | 0.94 | 0.93 | 0.87 | 0.09 | 0.17 | 0.81 ± 0.17 |
| | Matryoshka SAE | 0.95 | 0.94 | 0.89 | 0.12 | 0.13 | 0.83 ± 0.16 |
| | BatchTopK SAE | 0.95 | 0.94 | 0.84 | 0.13 | — | 0.85 ± 0.15 |
| Gemma2-2b | Temporal SAE | 0.75 | 0.88 | 0.78 | 0.10 | 0.15 | 0.83 ± 0.15 |
| | Matryoshka SAE | 0.75 | 0.89 | 0.76 | 0.14 | 0.12 | 0.83 ± 0.16 |
| | BatchTopK SAE | 0.76 | 0.89 | 0.66 | 0.13 | — | 0.83 ± 0.16 |
The smoothness metric deserves careful explanation. For a sequence of length T and a set of active features, we compute:
This is the maximum absolute change in feature activation, normalized by how much the underlying model activation changed. A smooth feature might have Δs = 0.05 (it barely changes even when the model activation changes a lot). A noisy syntactic feature might have Δs = 0.5 (it spikes wildly).
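The exact formula is not reproduced here, so the sketch below is one hypothetical reading of the description. For every feature that fires somewhere in the sequence, take the largest per-step change in its activation. Divide it by the L2 norm of the corresponding change in the model activation, then average over those features. The per-step normalization and the function name `smoothness` are assumptions.

```python
import torch


def smoothness(f, x, eps=1e-8):
    """Hypothetical smoothness score (lower = smoother).

    f: (T, m) feature activations over a sequence; x: (T, d) model activations.
    """
    df = (f[1:] - f[:-1]).abs()                       # (T-1, m) per-step feature change
    dx = (x[1:] - x[:-1]).norm(dim=-1, keepdim=True)  # (T-1, 1) per-step activation change
    ratio = df / (dx + eps)                           # normalized change
    active = (f != 0).any(dim=0)                      # features active at least once
    return ratio.max(dim=0).values[active].mean().item()
```

Under this reading, a feature that stays constant scores 0, and a feature that flips on and off every token scores the size of the flip divided by the typical step in $x_t$.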
T-SAE high-level features have the lowest smoothness score of any method. Their low-level features are appropriately less smooth, confirming that the split works as intended.
Feature activations over a sequence for three SAE types. T-SAE high-level features (orange, solid) are visibly smoother than Matryoshka (gray, dashed) or BatchTopK (gray, dotted). T-SAE low-level features (teal) are appropriately less smooth.
An interesting detail concerns fraction alive. T-SAEs keep more features alive than BatchTopK SAEs: 0.87 vs. 0.84 on Pythia and 0.78 vs. 0.66 on Gemma2-2b. On Gemma2-2b they also edge out Matryoshka SAEs (0.78 vs. 0.76), though Matryoshka is higher on Pythia (0.89). More of the 16k features activate at least once, which suggests that the contrastive loss may encourage the SAE to use more of its feature capacity. Dead features (features that never activate) are wasted parameters, so having fewer of them is a benefit.
Academic metrics are nice, but do T-SAEs actually help with real problems? The paper presents two compelling case studies: understanding alignment datasets and steering model behavior.
The HH-RLHF dataset (Bai et al., 2022) is Anthropic's human preference dataset used to train safety-focused models. It contains pairs of completions — one "chosen" (preferred by human raters) and one "rejected." The paper asks: what features differentiate chosen from rejected completions?
Method: For each completion pair, compute the difference in mean T-SAE feature activations (rejected − chosen). Features with the largest average difference are the ones that distinguish unsafe from safe content.
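Here is a minimal sketch of this ranking. It assumes each completion's T-SAE features are available as a (tokens × features) tensor and that the two lists are aligned by preference pair. `rank_features_by_diff` is a hypothetical helper name.

```python
import torch


def rank_features_by_diff(f_rejected, f_chosen, top_n=5):
    """Rank features by the average of (mean rejected activation - mean chosen activation).

    f_rejected, f_chosen: lists of (T_i, m) feature tensors, index i = same preference pair.
    """
    diffs = torch.stack([r.mean(dim=0) - c.mean(dim=0)
                         for r, c in zip(f_rejected, f_chosen)])  # (n_pairs, m)
    avg_diff = diffs.mean(dim=0)                                    # (m,)
    top = torch.topk(avg_diff, top_n)
    return top.indices, top.values
```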
What Matryoshka SAEs find for the same analysis: "specific bicycle components," "terms related to data management," "references to ecosystem dynamics and environmental conditions." These are random noise features that happen to correlate with rejected completions for spurious reasons (like response length).
Here's a subtle but critical finding. Some T-SAE features that show high activation differences are actually spuriously correlated with response length, not with actual safety content. The paper identifies these by computing the Pearson correlation between feature activation difference and response length difference.
| Feature | Avg Diff (rejected − chosen) | Corr with length | Type |
|---|---|---|---|
| transition words and phrases | 0.063 | 0.52 | Length-related |
| legal and formal language | 0.058 | 0.38 | Length-related |
| the word "the" | 0.047 | 0.31 | Length-related |
| crime and malicious activities | 0.060 | 0.12 | Semantically relevant |
| violent or aggressive behavior | 0.044 | 0.05 | Semantically relevant |
| negative comments and insults | 0.043 | -0.08 | Semantically relevant |
The semantically relevant features have low correlation with length: they genuinely capture unsafe content, not just the fact that rejected responses tend to be longer. The length-related features are spurious confounders. T-SAEs recover both, but critically, the semantically relevant ones are clearly identified and separable.
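The length check is a plain Pearson correlation, computed across preference pairs, between a feature's activation difference and the response-length difference. A minimal sketch follows; the helper name is hypothetical.

```python
import torch


def length_correlation(feat_diff, len_diff):
    """Pearson correlation between a feature's per-pair activation difference
    (rejected - chosen) and the per-pair response-length difference."""
    a = feat_diff - feat_diff.mean()
    b = len_diff - len_diff.mean()
    return ((a * b).sum() / (a.norm() * b.norm())).item()
```

In the table above, "transition words and phrases" (0.52) would stand out as a likely length confounder, while "violent or aggressive behavior" (0.05) would not.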
Can T-SAE features be used to steer model generation? The paper intervenes on features during inference by adding α · di to the model's residual stream (where di is the decoder column for feature i and α is the intervention strength).
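The intervention itself is activation addition. The minimal PyTorch sketch below uses a forward hook. A single `nn.Linear` layer stands in for the point where the model's residual stream is modified, and a random matrix stands in for the trained decoder. The sizes, the feature index and the helper name `add_steering_hook` are hypothetical, and `strength` plays the role of α.

```python
import torch
import torch.nn as nn


def add_steering_hook(module, direction, strength):
    """Add strength * direction to `module`'s output on every forward pass.

    Assumes the module returns a plain tensor whose last dimension matches
    `direction`; real transformer blocks often return tuples, which need unpacking.
    """
    def hook(mod, inputs, output):
        # Returning a value from a forward hook replaces the module's output
        return output + strength * direction
    return module.register_forward_hook(hook)


torch.manual_seed(0)
d, m, i = 4, 16, 3
layer = nn.Linear(d, d)     # toy stand-in for a residual-stream write
W_dec = torch.randn(d, m)   # toy stand-in for the T-SAE decoder
d_i = W_dec[:, i]           # decoder column for feature i

x = torch.randn(2, d)
base = layer(x)
handle = add_steering_hook(layer, d_i, strength=5.0)
steered = layer(x)
handle.remove()             # detach the hook to restore normal behavior
```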
The key finding: steering with high-level (semantic) features is dramatically more effective than steering with low-level (syntactic) features.
Simulated steering with high-level (semantic) vs. low-level (syntactic) features at varying intervention strengths. High-level steering changes the topic while maintaining coherence. Low-level steering degrades to repetition.
The paper ablates three key design choices in the T-SAE training pipeline. Each ablation answers a specific question about why the method works.
How much of the feature space should be "high-level"? The paper tests 10/90, 20/80, and 50/50 splits.
| Split (High/Low) | FVE | Smoothness (High) | Semantics | Context | Syntax |
|---|---|---|---|---|---|
| 10/90 | 0.94 | 0.08 | 0.88 | 0.92 | 0.82 |
| 20/80 (default) | 0.94 | 0.09 | 0.91 | 0.95 | 0.81 |
| 50/50 | 0.94 | 0.11 | 0.90 | 0.94 | 0.73 |
Instead of contrasting with the immediately previous token (t−1), what if we contrast with a random token from the past context? The paper samples the contrastive partner $x_{t-r}$ uniformly from the preceding tokens within a window of r < 25 positions, rather than always using $x_{t-1}$.
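Here is a sketch of the random-partner sampling. It rests on two assumptions rather than details from the paper: offsets start at r = 1, and positions near the start of a sequence only draw from tokens that actually exist. The function name is hypothetical.

```python
import torch


def sample_past_partner(t, max_offset=24, generator=None):
    """For token positions t (LongTensor, all >= 1), return partner positions t - r
    with r drawn uniformly from 1..min(max_offset, t)."""
    max_r = torch.clamp(t, max=max_offset)       # cannot look back past position 0
    u = torch.rand(t.shape, generator=generator)
    r = 1 + (u * max_r).long()                   # uniform integer in 1..max_r
    return t - r
```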
| Contrastive Partner | Semantics | Context | Syntax |
|---|---|---|---|
| Adjacent (t-1) [default] | 0.91 | 0.95 | 0.81 |
| Random past (r < 25) | 0.89 | 1.06 | 0.71 |
Random contrasting boosts context accuracy substantially (+11 points), likely because it encourages longer-range consistency. But it hurts syntax accuracy (−10 points), plausibly because the high-level features become "greedier" and absorb more of the information, leaving less for the low-level features to capture.
What if you remove the contrastive loss entirely and just use the Matryoshka reconstruction objective with the 20/80 split?
| Configuration | Δ Semantics | Δ Context | Δ Syntax |
|---|---|---|---|
| No contrastive (α = 0) | −0.07 | −0.10 | +0.01 |
| Naive L2 smoothness (not contrastive) | −0.02 | +0.02 | +0.07 |
Without the contrastive loss, semantics drops by 7 points and context by 10 points. The Matryoshka split alone is not enough — the contrastive loss is essential for driving specialization.
The "naive L2 smoothness" ablation replaces the contrastive loss with a simple per-sample L2 penalty, $\ell = \alpha \lVert z_t - z_{t-1} \rVert_2^2$. This enforces smoothness but has no contrastive negative pairs. It performs worse on semantics and much better on syntax. Without negatives nothing prevents collapse: the high-level features of all tokens can converge to similar values and lose their discriminative power.
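The collapse argument can be checked directly. If every token in a batch gets the same high-level code, the naive L2 penalty is exactly zero. The symmetric InfoNCE loss, written the same way as `contrastive_loss` earlier in this section, instead sits at log N, because the positive pair is indistinguishable from the negatives. The helper names are hypothetical.

```python
import math

import torch
import torch.nn.functional as F


def naive_l2(z_t, z_prev, alpha=1.0):
    # alpha * ||z_t - z_{t-1}||_2^2, averaged over the batch
    return alpha * (z_t - z_prev).pow(2).sum(-1).mean()


def symmetric_infonce(z_t, z_prev):
    # Cosine similarities as logits, cross-entropy in both directions
    sim = F.normalize(z_t, dim=-1) @ F.normalize(z_prev, dim=-1).T
    labels = torch.arange(z_t.size(0))
    return (F.cross_entropy(sim, labels) + F.cross_entropy(sim.T, labels)) / 2


# Collapsed solution: every token in the batch gets the same high-level code.
N, h = 4, 8
collapsed = torch.ones(N, h)
print(naive_l2(collapsed, collapsed).item())  # zero: the L2 penalty is fully satisfied
print(symmetric_infonce(collapsed, collapsed).item(), math.log(N))  # equal: no discrimination
```

Distinct codes per sequence (for example one-hot vectors) drive the InfoNCE loss below log N, so only the contrastive objective rewards keeping sequences apart.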
T-SAEs sit at the intersection of three research areas: mechanistic interpretability, contrastive representation learning, and computational linguistics. Let's map the connections and limitations.
| Symbol | Meaning | Typical Value |
|---|---|---|
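| $x_t$ | LLM activation at token t | $\mathbb{R}^{768}$ or $\mathbb{R}^{2304}$ |
| $f(x_t)$ | Sparse feature vector | $\mathbb{R}^{16384}$, about 20 nonzero per token (BatchTopK) |
| $f_{0:h}(x_t) = z_t$ | High-level (semantic) features | First 20% of features |
| $f_{h:m}(x_t)$ | Low-level (syntactic) features | Remaining 80% of features |
| $\mathcal{L}_H$ | High-level reconstruction loss | $\lVert x - (W_{\text{dec}}^{0:h} f_{0:h} + b_{\text{dec}}) \rVert_2^2$ |
| $\mathcal{L}_L$ | Full reconstruction loss | $\lVert x - (W_{\text{dec}} f + b_{\text{dec}}) \rVert_2^2$ |
| $\mathcal{L}_{\text{contr}}$ | Symmetric contrastive loss (InfoNCE) | On $z_t, z_{t-1}$ across batch |
| α | Contrastive weight | 1.0 |
| $\mathcal{L}$ | Total loss | $\mathcal{L}_H + \mathcal{L}_L + \alpha \mathcal{L}_{\text{contr}}$ |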
| s(x, y) | Cosine similarity | [−1, 1] |
| Δs | Smoothness metric | Lower = smoother features |
| Method | Key Idea | vs. T-SAEs |
|---|---|---|
| Standard SAE | Sparse reconstruction | No temporal structure, features are noisy and syntactic |
| Matryoshka SAE | Hierarchical feature splits | Same split but no contrastive loss → no specialization |
| BatchTopK SAE | Fixed-k sparsity | Better sparsity control but still i.i.d. tokens |
| Transcoders | Causal features via MLP replacement | Different architecture; T-SAEs modify the loss, not the model |
| CPC / InfoNCE | Contrastive predictive coding | T-SAEs apply the same principle to SAE feature learning |
| Griffiths et al. (2004) | HMM + LDA for syntax/semantics | T-SAEs are the neural, unsupervised version of this idea |
"What I cannot create, I do not understand." — Richard Feynman. T-SAEs let us decompose what LLMs create into parts we can understand: meaning and grammar, separately.