
JEPA

Yann LeCun’s Joint Embedding Predictive Architecture — stop predicting pixels, start predicting meaning. Learn from unlabeled data by predicting the representation of what you can’t see, in the abstract space where the unpredictable details have already been thrown away.

Prerequisites: knowing that an encoder turns an input into a vector of features, and that training minimizes a loss by gradient descent. That's it.

Chapter 0: The Blur Problem

You want a model to learn about the world without labels, just by looking at millions of images. A natural idea: hide part of each image and train the model to fill in the missing pixels. If it can complete the picture, surely it understood it. This is how masked autoencoders (MAE) and most generative pre-training work. It works… but it fights a losing battle, and seeing why is the doorway to JEPA.

Here is the problem. Cover the lower half of a photo of a dog. What is behind the mask? Maybe the dog is sitting. Maybe lying down. Maybe there is grass, maybe a wooden floor, maybe a shadow falling left or right. There are thousands of completions that are all perfectly plausible. The true pixels are just one sample from a huge distribution.

Now ask the model to output one image and score it by pixel error against the truth. To minimize average error across all those possibilities, the safest bet is to predict the average of every plausible completion. And the average of many sharp, different pictures is a blur. The model is mathematically pushed toward mush — not because it failed to understand, but because pixel-error rewards hedging.
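A tiny numerical check of that claim, as a sketch on made-up data: treat each plausible completion as a random sharp black-and-white patch, then compare the average pixel error of committing to one sharp completion against predicting the pixel-wise mean. The sizes, the binary pixels, and the random seed are all arbitrary assumptions.

```python
import numpy as np

def expected_mse(prediction, completions):
    """Average pixel MSE of one prediction, scored against every plausible completion."""
    return float(np.mean((completions - prediction) ** 2))

rng = np.random.default_rng(0)
# Toy assumption: 50 equally plausible "sharp" completions of a 16x16 masked region,
# with every pixel either 0 (black) or 1 (white).
completions = rng.integers(0, 2, size=(50, 16, 16)).astype(float)

sharp_guess = completions[0]            # commit to one crisp completion
blur_guess = completions.mean(axis=0)   # hedge: the pixel-wise average of all of them

print("sharp guess error:", expected_mse(sharp_guess, completions))
print("blurry mean error:", expected_mse(blur_guess, completions))
# The mean has the lowest average error of ANY single prediction, yet its pixels are
# gray values strictly between 0 and 1: a smear, not a plausible image.
print("blur pixel range:", blur_guess.min(), blur_guess.max())
```

The reason is algebraic: average MSE equals the variance of the completions plus the squared distance of the prediction from their mean, so the mean is the unique minimizer, whatever it looks like.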

The trap: “If it can reconstruct the pixels, it understood the image.” But exact pixels are full of detail that is fundamentally unpredictable — the precise texture of fur, the exact blade of grass, sensor noise. Forcing the model to predict that detail wastes its capacity on noise and punishes it for the world’s inherent uncertainty. JEPA’s move: don’t predict the pixels at all.
Why pixel prediction blurs

The masked region has many valid completions (each a sharp pattern). Slide to average more of them — exactly what minimizing pixel error pushes the model to do. Watch a crisp answer dissolve into the safe, useless blur.


The deep lesson: a good objective should let the model say “there is a dog leg here, I don’t know the exact pixels” — and be rewarded for the first part without being punished for not nailing the second. That means scoring the prediction not on pixels, but on abstract features that ignore unpredictable detail. That is the whole idea of JEPA, and the next nine chapters build it from scratch.


Chapter 1: Three Paths to Learn Without Labels

Self-supervised learning — learning from raw data with no human labels — comes in three broad flavors. JEPA is the third, and you can only appreciate it by seeing the first two and what each gets wrong.

Path 1 — Generative / reconstructive

Corrupt the input (mask it, add noise), then rebuild the original input. Examples: autoencoders, MAE, the “predict the next pixel” family. The decoder lives in input space. As Chapter 0 showed, that forces the model to spend capacity on unpredictable detail and rewards blur. It does work — MAE learns useful features — but it is paying a tax for modeling noise.
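A minimal sketch of the Path 1 objective, loosely in the style of MAE, which computes its reconstruction loss only on the masked patches. The function name and array shapes are illustrative assumptions; in a real model a decoder would produce `pred_patches`. Notice that the loss lives entirely in pixel space, which is exactly where Chapter 0's blur tax is charged.

```python
import numpy as np

def masked_pixel_loss(pred_patches, true_patches, mask):
    """Pixel-space MSE, averaged only over the patches that were hidden.

    pred_patches, true_patches: [N, P] arrays (N patches of P pixels each)
    mask: [N] boolean array, True where the patch was masked out
    """
    per_patch = ((pred_patches - true_patches) ** 2).mean(axis=1)  # [N] error per patch
    return float(per_patch[mask].mean())

true = np.zeros((4, 2))
pred = np.array([[1.0, 1.0], [0.0, 0.0], [2.0, 2.0], [5.0, 5.0]])
mask = np.array([True, True, False, False])   # only the first two patches were hidden
print(masked_pixel_loss(pred, true, mask))    # 0.5: errors on visible patches are ignored
```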

Path 2 — Joint-embedding / contrastive

Take two augmented views of the same image, encode both, and pull their embeddings together, while pushing apart embeddings of different images (the “negatives”). Examples: SimCLR and MoCo, which rely on negatives, and the closely related BYOL and DINO, which replace negatives with other tricks. This learns in feature space (good!), but it has its own costs: it needs carefully hand-designed augmentations to define “same,” and it needs either many negative examples or clever tricks to keep the model from cheating by ignoring the input. It learns what is invariant, but throws away where and how things relate.
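To make "pull together, push apart" concrete, here is a simplified, one-directional InfoNCE-style loss in NumPy. It is a sketch, not SimCLR's exact NT-Xent (which also draws negatives from the same view); the temperature of 0.1 is an assumed value. Row `i` of `z_a` and `z_b` are two views of image `i`; every other row of `z_b` acts as a negative.

```python
import numpy as np

def info_nce(z_a, z_b, temperature=0.1):
    """Simplified contrastive loss for a batch of paired views.

    z_a[i], z_b[i]: embeddings of two augmented views of image i (the positive pair);
    z_b[j] for j != i are the negatives for z_a[i].
    """
    z_a = z_a / np.linalg.norm(z_a, axis=1, keepdims=True)   # compare by cosine similarity
    z_b = z_b / np.linalg.norm(z_b, axis=1, keepdims=True)
    logits = z_a @ z_b.T / temperature                        # [B, B] similarity matrix
    logits = logits - logits.max(axis=1, keepdims=True)       # numerical stability
    log_probs = logits - np.log(np.exp(logits).sum(axis=1, keepdims=True))
    return float(-np.mean(np.diag(log_probs)))                # positives sit on the diagonal

distinct = np.eye(8)                    # 8 images with clearly different embeddings
collapsed = np.ones((8, 8))             # every image mapped to the same vector
print(info_nce(distinct, distinct))     # near 0: each view matches only its own partner
print(info_nce(collapsed, collapsed))   # log(8) ≈ 2.08: collapse is penalized by negatives
```

The second line is the point: when all embeddings are identical, the positive cannot be told apart from the negatives, so the loss is stuck at log(batch size). That is how negatives block collapse.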

Path 3 — JEPA (joint-embedding predictive)

The synthesis. Like Path 2, work in feature space, not pixel space. But like Path 1, the signal is prediction: from the features of one part of the input, predict the features of another part. Encode a context region, encode a target region, and train a small predictor to map context-features → target-features. No pixels to rebuild (so no blur tax). No negatives or augmentations required (so no hand-tuned invariances). You predict meaning, conditioned on where you’re predicting it.

The one-sentence essence: Generative methods predict the input from the input. Contrastive methods compare two embeddings. JEPA predicts one embedding from another — keeping the prediction signal of generative methods but moving it into the abstract space of joint-embedding methods, where the unpredictable stuff has already been discarded.
Three paradigms, side by side

Click each path to see its data flow and where the loss is computed. Notice where each one lives: pixel space (orange, blurs) vs feature space (teal, abstract).


Chapter 2: Predict Meaning, Not Pixels

Here is the central trick, and it is worth slowing down for. Both the context and the target pass through an encoder that turns raw input into an abstract feature vector. The encoder is free to throw away whatever is not worth keeping — exact textures, noise, the precise shade of a pixel. What survives is the gist: “there is fur here, a leg shape there, grass below.” We then predict in that space.

Why latent prediction dodges the blur

Recall the masked dog. In pixel space there were thousands of valid completions, and their average was mush. But ask: what is the feature vector of all those completions? They mostly agree — they all encode “dog leg on grass.” The detail that differed (exact blades of grass) was discarded by the encoder. So in feature space the target is nearly a single point, and predicting a single point is easy and sharp. The unpredictability didn’t vanish — it was filtered out before the loss ever saw it.

Worked example by hand

Suppose a masked region has two equally valid pixel completions, A and B:

| | pixels | pixel-MSE if you predict the average [0.5, 0.5] | encoder feature |
|---|---|---|---|
| completion A | [0.2, 0.9] | (0.3² + 0.4²)/2 = 0.125 | “dog-leg” = [1.0] |
| completion B | [0.8, 0.1] | (0.3² + 0.4²)/2 = 0.125 | “dog-leg” = [1.0] |

In pixel space, no matter what single image you output, your average error is at least 0.125, forever: that is the irreducible cost of facing two sharp answers. Committing to A or B is worse (an average of 0.25, because you are badly wrong half the time), which is why the model hedges to the blur. But the encoder maps both completions to the same feature [1.0]. So if you predict the feature [1.0], your loss is zero, and you were never asked which exact pixels appear. The objective only graded you on the part that was actually knowable. That is the entire advantage, in one table.
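The table's arithmetic can be checked in a few lines. The `encoder` below is a hypothetical stand-in that has already learned to map both completions to the “dog-leg” feature [1.0]; the pixel values come straight from the table. The check also compares hedging with committing to one completion.

```python
import numpy as np

A = np.array([0.2, 0.9])   # completion A (pixels)
B = np.array([0.8, 0.1])   # completion B (pixels)

def mse(p, q):
    return float(np.mean((p - q) ** 2))

def expected_pixel_loss(pred):
    """Average pixel MSE when A and B are equally likely."""
    return 0.5 * mse(pred, A) + 0.5 * mse(pred, B)

def encoder(pixels):
    # Hypothetical trained encoder: discards the detail that differs, keeps "dog-leg".
    return np.array([1.0])

blur_loss = expected_pixel_loss((A + B) / 2)   # hedge to the average [0.5, 0.5]
commit_loss = expected_pixel_loss(A)           # bet everything on completion A
latent_loss = 0.5 * mse(np.array([1.0]), encoder(A)) + 0.5 * mse(np.array([1.0]), encoder(B))

print(blur_loss, commit_loss, latent_loss)     # ≈ 0.125, ≈ 0.25, 0.0
```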

Pixel error vs. latent error under uncertainty

Add “world uncertainty” — how much the unpredictable detail varies between valid completions. Pixel-space error (orange) climbs with it; latent-space error (teal) stays low, because the encoder discarded that detail before the loss.

The key reframing: The job of the loss is not “reproduce the data.” It is “capture what is predictable about the data.” By predicting encoder outputs instead of inputs, JEPA lets the encoder define what counts as predictable — and only that gets graded.

But this freedom hides a danger. If the encoder can choose what to keep, what stops it from keeping nothing — mapping every input to the same boring vector, making prediction trivially perfect? That catastrophe has a name, and it is the subject of the next chapter.


Chapter 3: Collapse — the trivial cheat

We are training an encoder and a predictor to make predicted features match target features. The loss is the distance between them. Let us think like a lazy optimizer: what is the easiest way to drive that distance to zero?

Make the encoder ignore its input and output a constant vector — say, all zeros — for every image. Then the target feature is always zero. The predictor learns to output zero. The distance is exactly zero. Loss minimized, training “converged.” And the representations are completely useless — they contain no information about the input at all. This is representational collapse, and it is the central hazard of every joint-embedding method.

if encoder(anything) = c, then target = c, predictor outputs c, loss = 0 — perfect and worthless
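The cheat takes only a few lines to reproduce. This is a toy sketch with made-up data: `collapsed_encoder` is a hypothetical encoder that has found the constant solution, and "variance" is measured as the average per-dimension variance of the representations across images.

```python
import numpy as np

rng = np.random.default_rng(0)
images = rng.normal(size=(32, 10))        # 32 toy "images", 10 raw features each

def collapsed_encoder(x, dim=4):
    """The lazy optimizer's favorite: ignore the input, always output zeros."""
    return np.zeros((len(x), dim))

z_target = collapsed_encoder(images)      # target features: identical for every image
prediction = np.zeros_like(z_target)      # the predictor learns to output zeros too
loss = float(np.mean((prediction - z_target) ** 2))
variance = float(z_target.var(axis=0).mean())   # how much representations differ across images

print(loss, variance)   # 0.0 0.0: perfect loss, zero information

# For contrast, even a random (untrained) linear encoder keeps images distinguishable.
W = rng.normal(size=(10, 4))
print(float((images @ W).var(axis=0).mean()) > 0)   # True
```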

This is why you cannot naively train both encoders by gradient descent on the same prediction loss. Gradient descent is delighted to find the constant solution — it is a deep, wide basin of zero loss. Contrastive methods avoid it with negatives (“push different images apart, so you can’t map everything to one point”). JEPA avoids it differently, with an elegant asymmetry we’ll meet next chapter.

Common misconception: “Lower loss is always better.” In self-supervised learning, zero loss can be the worst outcome. A collapsed model has perfect loss and zero usefulness. The real goal is high information content in the representations, with low prediction error as a means, not an end. Watch the loss and the representation variance.
Watch it collapse

Each dot is one image’s representation. Press Train (naive) — with no protection, gradient descent on the prediction loss pulls every representation to a single point. The loss hits zero; the information is gone.


Chapter 4: The Fix — a target that won’t chase

The collapse happened because the target encoder was free to move toward the easy constant. The fix is to break the symmetry between the two encoders. JEPA uses two ideas together, borrowed from BYOL and MoCo.

Idea 1 — stop-gradient on the target

When we compute the loss, we treat the target features as a fixed goal: no gradient flows back into the target encoder through the loss. The predictor and context encoder must move to hit the target; the target does not get to move toward them. This alone removes the “both sides race to zero” dynamic.

Idea 2 — the target encoder is an EMA of the context encoder

So how does the target encoder ever learn? Not by gradients, but by slowly copying the context encoder. After each step we nudge the target weights a tiny bit toward the context weights — an exponential moving average (EMA):

$$\theta_{\text{target}} \leftarrow m\,\theta_{\text{target}} + (1 - m)\,\theta_{\text{context}} \qquad (m \approx 0.996,\ \text{very close to } 1)$$

With m near 1, the target encoder is a slow, lagging echo of the context encoder. It is always a moving target that the context encoder is chasing — but a target that moves gently and is always grounded in a slightly older, stable version of the network. This lag is exactly what prevents the runaway collapse: the system can never instantly agree on the constant solution, because the target is always behind.

Worked example: how slow is the EMA?

Let m = 0.996, a target weight currently at 0.50, and the context weight has jumped to 0.90. The update: 0.996×0.50 + 0.004×0.90 = 0.498 + 0.0036 = 0.5016. The target barely moved: 0.0016 out of a gap of 0.40, which is 0.4% of the way (exactly 1 − m) per step. It takes hundreds of steps to catch up; even if the context weight stayed put, closing just half the gap would take about 170 steps. That sluggishness is the feature, not a bug: a stable anchor the fast network learns against.
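The same arithmetic as a sketch in plain Python, treating a single weight as a float and (as a simplifying assumption) holding the context weight fixed. Each step shrinks the remaining gap by a factor of m, so closing a fraction f of it takes log(1 − f) / log(m) steps.

```python
import math

def ema_update(target, context, m=0.996):
    """One EMA step: move the target (1 - m) of the way toward the context."""
    return m * target + (1 - m) * context

print(round(ema_update(0.50, 0.90), 4))   # 0.5016

def steps_to_close(fraction, m=0.996):
    """Steps until the target has closed `fraction` of a fixed gap."""
    return math.ceil(math.log(1 - fraction) / math.log(m))

print(steps_to_close(0.5), steps_to_close(0.95))   # 173 748
```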

EMA momentum: stability vs. collapse

Representation variance over training (high = informative, zero = collapsed). At low momentum the target chases the context and everything collapses (red). At high momentum (near 1) the lag holds the system stable and representations stay rich (teal).

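python: the JEPA training step (the whole idea in one short function)
```python
import torch

def jepa_step(context_encoder, target_encoder, predictor, opt,  # opt holds context + predictor params only
              context_patches, all_patches, tgt_idx, target_positions, m=0.996):
    z_ctx = context_encoder(context_patches)              # [B, N_ctx, D], trainable
    with torch.no_grad():                                 # stop-gradient on the target
        z_tgt = target_encoder(all_patches)[:, tgt_idx]   # [B, N_tgt, D], the fixed goal
    pred = predictor(z_ctx, target_positions)             # [B, N_tgt, D]
    loss = ((pred - z_tgt) ** 2).mean()                   # L2 in feature space
    opt.zero_grad()                                       # clear gradients from the last step
    loss.backward()
    opt.step()                                            # updates context encoder + predictor only
    with torch.no_grad():                                 # EMA: no gradients here
        for p, pt in zip(context_encoder.parameters(), target_encoder.parameters()):
            pt.mul_(m).add_(p, alpha=1 - m)               # pt = m * pt + (1 - m) * p
    return loss.item()
```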

Chapter 5: The Predictor & Mask Tokens

We keep saying “predict the target features from the context features.” But there is a missing piece: the predictor must know which target it is predicting. The features of a dog’s leg and the features of the sky above it are different — the predictor needs to be told where it is aiming. That is the job of mask tokens.

How the predictor is conditioned

The predictor (a small transformer) takes two things as input:

  1. The context features — the encoder’s output for the visible patches, each carrying its own positional embedding so the predictor knows where the evidence sits.
  2. One mask token per target location. A mask token is a single learned vector (the same one reused everywhere) plus a positional embedding for the specific patch it stands in for. The learned vector says “something goes here, predict it”; the positional embedding says “here, at row 7, column 3.”

The predictor attends from the mask tokens to the context features and outputs a predicted feature vector for each target location. So the same context can be queried for many different targets just by feeding different positional embeddings — “given what I can see, what’s the feature at this spot? and this one?” The positional embedding is the “question” and the predicted feature is the “answer.”
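A minimal PyTorch sketch of that conditioning. `TinyPredictor` is hypothetical, not the official I-JEPA code: it skips details such as projecting from the encoder width to a narrower predictor width, and its sizes (an 8×8 grid of 64 patches, width 32, two layers) are illustrative assumptions. What it does show is the mechanism described above: one shared learned mask token, plus a positional embedding per queried location, processed together with the context tokens.

```python
import torch
import torch.nn as nn

class TinyPredictor(nn.Module):
    """Hypothetical minimal JEPA-style predictor with a shared mask token."""

    def __init__(self, num_patches=64, dim=32, depth=2, heads=4):
        super().__init__()
        self.mask_token = nn.Parameter(torch.zeros(1, 1, dim))        # ONE learned vector, reused
        self.pos_embed = nn.Parameter(torch.randn(1, num_patches, dim) * 0.02)
        layer = nn.TransformerEncoderLayer(dim, heads, dim_feedforward=2 * dim, batch_first=True)
        self.blocks = nn.TransformerEncoder(layer, depth)
        self.head = nn.Linear(dim, dim)

    def forward(self, z_ctx, ctx_idx, tgt_idx):
        # z_ctx: [B, N_ctx, D]; ctx_idx, tgt_idx: 1-D LongTensors of patch positions
        B = z_ctx.shape[0]
        ctx = z_ctx + self.pos_embed[:, ctx_idx]                      # where the evidence sits
        queries = self.mask_token.expand(B, len(tgt_idx), -1) + self.pos_embed[:, tgt_idx]
        x = self.blocks(torch.cat([ctx, queries], dim=1))             # queries attend to context
        return self.head(x[:, ctx.shape[1]:])                         # [B, N_tgt, D] predictions

torch.manual_seed(0)
predictor = TinyPredictor()
z_ctx = torch.randn(2, 40, 32)             # stand-in context features for 40 visible patches
ctx_idx = torch.arange(40)
tgt_idx = torch.tensor([50, 51, 58, 59])   # the "questions": which positions to predict
out = predictor(z_ctx, ctx_idx, tgt_idx)
print(out.shape)                           # torch.Size([2, 4, 32])
```

Querying the same context at a different position only changes `tgt_idx`; the mask token itself stays the same vector everywhere.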

Concept → realization: the mask token is how an architecture asks a conditional question. Without the positional embedding, the predictor would have to output one blurry “average target” (blur is back!). With it, the prediction is sharp and location-specific — the architecture-level reason JEPA escapes the very averaging problem from Chapter 0.

A crucial design choice: the predictor is kept small and narrow relative to the encoders. We want the heavy lifting — understanding the image — to happen in the encoder, so that the encoder’s features are rich and reusable downstream. If the predictor were huge, it could memorize shortcuts and let the encoder stay lazy. A small predictor forces the encoder to produce features that are genuinely predictable from one another.

Querying the predictor by position

The teal patches are the visible context. Click any gray patch to place a mask token there — the predictor uses the context plus that location’s positional embedding to predict its feature (shown as the arrow of attention from context to the queried spot).


Chapter 6: Masking — what to hide matters enormously

JEPA’s power depends heavily on how you choose the context and target. Hide the wrong thing and the task becomes either trivial or impossible. I-JEPA (the image version, 2023) introduced a specific multi-block masking strategy that is worth understanding, because it encodes a belief about what makes a good learning signal.

The recipe

Why blocks, not random patches?

If you masked single scattered patches (like some pixel methods), each hole could be filled by copying a neighbor — low-level texture interpolation, no understanding needed. By masking large contiguous blocks, the nearest evidence is far away, so the only way to predict the target features is to reason about the whole object: “the context shows a dog’s head and front legs, so the hidden block is probably the body.” The masking strategy is how you dial in the difficulty to force semantic learning.
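A sketch of a multi-block masking sampler on a patch grid. The defaults approximate the settings described for I-JEPA (a few target blocks, each covering roughly 15 to 20% of the image with a varied aspect ratio, and one large square-ish context block covering 85 to 100%), but treat every number here, including the 14×14 grid, as an assumption. The one essential step is the last loop: target blocks are carved out of the context so the model cannot copy them.

```python
import numpy as np

def sample_block(rng, grid, scale, aspect):
    """Random rectangle of patches covering roughly `scale` of a grid x grid image."""
    area = rng.uniform(*scale) * grid * grid
    ratio = rng.uniform(*aspect)                      # height / width
    h = min(max(int(round(np.sqrt(area * ratio))), 1), grid)
    w = min(max(int(round(np.sqrt(area / ratio))), 1), grid)
    top = rng.integers(0, grid - h + 1)
    left = rng.integers(0, grid - w + 1)
    mask = np.zeros((grid, grid), dtype=bool)
    mask[top:top + h, left:left + w] = True
    return mask

def multi_block_masks(grid=14, n_targets=4, seed=0, target_scale=(0.15, 0.2),
                      target_aspect=(0.75, 1.5), context_scale=(0.85, 1.0)):
    rng = np.random.default_rng(seed)
    targets = [sample_block(rng, grid, target_scale, target_aspect) for _ in range(n_targets)]
    context = sample_block(rng, grid, context_scale, (1.0, 1.0))
    for t in targets:
        context &= ~t            # carve each target out of the context: no overlap, no copying
    return context, targets

context, targets = multi_block_masks()
union = np.any(targets, axis=0)
for r in range(context.shape[0]):   # T = target, C = context, . = unused
    print("".join("T" if union[r, c] else "C" if context[r, c] else "." for c in range(context.shape[1])))
```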

Common misconception: “More masking is always harder and therefore better.” Not quite — mask too much and there is no context left to predict from (impossible); mask too little and prediction is trivial copying (no learning). The art is choosing block sizes that sit in the sweet spot where prediction is hard but achievable. That sweet spot is what makes the encoder work.
Multi-block masking visualizer

Teal = context (visible), orange = target blocks to predict, dark = unused. Drag the sliders to resize the blocks. Notice the target blocks are carved out of the context — no overlap, so the model can’t cheat by copying.


Chapter 7: The Full I-JEPA Loop (showcase)

Now assemble every piece into the complete training loop and run it. An image is split into patches. Context patches go through the context encoder. All patches go through the target encoder (the EMA copy, stop-gradient), and we keep the features at the target locations. The predictor, fed context features plus mask tokens, predicts those target features. The L2 distance is the loss — gradients update the context encoder and predictor; the target encoder is nudged by EMA.

The simulator below runs this loop on toy data and tracks the two numbers that matter: the prediction loss (should fall) and the representation variance (must stay high — if it crashes to zero, the model collapsed). Use it to break JEPA and see the failure modes for yourself.
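The second number is easy to track outside the simulator too. Below is a sketch of one common collapse diagnostic (not something the I-JEPA recipe itself prescribes): the average per-dimension standard deviation of the features across inputs. The collapse threshold is an arbitrary illustrative value; in a real loop you might call this on `z_tgt` every few hundred steps.

```python
import torch

def representation_std(z):
    """Mean per-dimension standard deviation of features across inputs.

    z: [B, D] or [B, N, D]. A value near 0 means every input is mapped to
    (nearly) the same vector, i.e. the collapse from Chapter 3.
    """
    z = z.reshape(-1, z.shape[-1])        # pool the batch and patch axes together
    return z.std(dim=0).mean().item()

def looks_collapsed(z, threshold=1e-3):
    # The threshold is an illustrative assumption, not a published value.
    return representation_std(z) < threshold

torch.manual_seed(0)
healthy = torch.randn(64, 16, 32)           # features that vary from input to input
collapsed = torch.full((64, 16, 32), 0.5)   # every input mapped to the same vector
print(representation_std(healthy), looks_collapsed(healthy))      # about 1.0, False
print(representation_std(collapsed), looks_collapsed(collapsed))  # 0.0, True
```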

I-JEPA training, live — and how to break it

Press Train. Loss (orange) should fall while representation variance (teal) stays healthy. Now turn stop-gradient OFF or drag EMA momentum toward 0 and train again — watch the variance collapse to zero (the model cheated). The two safeguards from Chapter 4 are the only things standing between you and a useless model.


Play until you have seen both outcomes: a healthy run (loss down, variance stays high) and a collapse (loss down, variance to zero). That contrast is the lesson: in self-supervised learning, a falling loss alone tells you nothing; you must watch what the representations are doing. The genius of I-JEPA is that, with the asymmetry in place, the easy way to lower the loss is to actually understand the image.

What I-JEPA delivered (2023): strong image representations with no hand-crafted augmentations (unlike contrastive methods), more semantic off-the-shelf features than pixel reconstruction (MAE), and markedly better compute efficiency: the paper reports pretraining a ViT-Huge/14 on ImageNet with 16 A100 GPUs in under 72 hours. V-JEPA (2024) extended the same recipe to video, predicting masked spatiotemporal regions and learning motion and physics from raw footage.

Chapter 8: JEPA as a World Model

Zoom all the way out. JEPA is not just an image trick — for Yann LeCun it is a blueprint for how a machine should model the world at all. His argument: an intelligent agent needs a world model that can predict the consequences of actions, so it can plan. And that model should predict in abstract state space, not pixels — for exactly the Chapter 0 reason.

Why pixel world-models fail at planning

Suppose you want to plan by imagining the future: “if I push this object, what happens next?” A pixel-predicting video model has to render every future frame in full detail. The future is uncertain, so — just like the masked dog — it blurs. Worse, planning means rolling the prediction forward many steps, and each blurry step feeds the next, so errors compound into useless fog within a few frames. You cannot plan in a hallucinated blur.

A JEPA world model predicts the next abstract state from the current abstract state and an action. Because the state already dropped unpredictable detail, the prediction stays sharp, and rolling it forward stays stable — you can imagine many steps ahead in concept space and pick the action sequence that reaches your goal. That is the substrate LeCun argues planning agents need.
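A minimal sketch of what "imagine many steps ahead and pick the best action sequence" can mean in code. The planner is deliberately simple random shooting, a swapped-in technique rather than a method from the JEPA papers, and `toy_predictor` is a hypothetical hand-written stand-in for a trained action-conditioned predictor. The key property is that the rollout never leaves latent space: no frame is ever rendered.

```python
import numpy as np

def plan_in_latent_space(s0, goal, predictor, horizon=8, n_candidates=256,
                         action_dim=2, seed=0):
    """Random-shooting planner: imagine many action sequences in latent space
    and return the one whose final predicted state lands closest to the goal.

    s0, goal: latent state vectors (e.g. encoder outputs); predictor(s, a) -> next s.
    """
    rng = np.random.default_rng(seed)
    candidates = rng.uniform(-1, 1, size=(n_candidates, horizon, action_dim))
    best_cost, best_plan = np.inf, None
    for actions in candidates:
        s = s0
        for a in actions:                   # roll forward entirely in latent space
            s = predictor(s, a)
        cost = float(np.sum((s - goal) ** 2))
        if cost < best_cost:
            best_cost, best_plan = cost, actions
    return best_plan, best_cost

def toy_predictor(s, a):
    # Hypothetical stand-in for a trained action-conditioned JEPA predictor.
    return s + 0.1 * a

s0 = np.zeros(2)
goal = np.array([0.5, -0.3])
plan, cost = plan_in_latent_space(s0, goal, toy_predictor)
print(plan.shape, cost)   # an (8, 2) action sequence and its distance to the goal
```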

The bigger family

Planning rollout: pixel space vs. latent space

Roll a prediction forward step by step. The pixel model (orange) accumulates blur and error each step until it is fog. The latent JEPA model (teal) stays sharp, because uncertainty was dropped before rollout. Drag the horizon and watch the gap explode.


Chapter 9: Cheat Sheet & Connections

The whole architecture in one breath, then the family comparison.

problem
predicting pixels rewards blur — the world’s detail is unpredictable
↓ predict in feature space
context & target encoders
the context encoder reads the visible context; the target encoder reads the full image, and we keep its features at the target blocks; both discard the unpredictable
↓ predict one from the other
predictor + mask tokens
context features + positional mask tokens → predicted target features; small, to keep encoders strong
↓ loss
L2 in feature space
distance(pred, target); target uses stop-gradient
↓ avoid collapse
EMA target encoder
target = slow lagging copy of context encoder → stable, no constant cheat

The self-supervised family

| Method | Predicts | Space | Anti-collapse | Needs augmentations? |
|---|---|---|---|---|
| MAE / autoencoder | masked pixels | pixel | n/a (reconstruction target is the input) | no |
| SimCLR / contrastive | same vs different | feature | negatives | yes (heavy) |
| BYOL / DINO | view agreement | feature | EMA target + stop-grad | yes |
| I-JEPA | target-block features | feature | EMA target + stop-grad | no |
| V-JEPA | spacetime features | feature | EMA target + stop-grad | no |

JEPA’s distinctive position: it predicts in feature space (so no blur), uses prediction rather than comparison (so it captures spatial relationships, not just invariance), and needs neither negatives nor hand-crafted augmentations (the masking is the task). The EMA-target trick it shares with BYOL/DINO is the load-bearing piece that keeps it from collapsing.

When to reach for JEPA

Keep exploring

Contrastive Learning — InfoNCE, BYOL, DINO, and the collapse problem in depth
CLIP — joint-embedding across image and text
Vision Transformer — the patch encoder JEPA is built on
World Models — learning to imagine and plan
Diffusion — the generative counterpart that does model pixels

“What I cannot create, I do not understand.” You just rebuilt JEPA from one observation: the world’s detail is unpredictable, so predict its meaning instead. Encode both sides, predict one embedding from another, condition on position, and hold the target steady with a slow EMA so that training does not collapse. That is the architecture LeCun bets the future of machine intelligence on.