The Complete Beginner's Path

Understand VAE / VQ-VAE

The compression engines behind image generation. Learn how neural nets learn to encode, quantize, and reconstruct the visual world.

Prerequisites: Neural network basics + Probability intuition. That's it.
10 chapters, 8+ interactives, 0 assumed knowledge

Chapter 0: Why Latent Spaces?

A 256×256 colour image has 196,608 numbers. But most of those numbers are redundant: smooth patches, repeated textures, predictable edges. The real information in an image — its content, composition, style — can be captured by far fewer numbers. A latent space is where those compressed essentials live.

The encoder-decoder idea is simple: an encoder squeezes the input into a small latent vector z, and a decoder tries to reconstruct the original from z alone. If the reconstruction is good, z must contain the essence of the data.

Input x
256×256×3 = 196,608 dims
↓ Encoder
Latent z
64 dims (compressed essence)
↓ Decoder
Reconstruction x̂
196,608 dims (recovered)
The core idea: High-dimensional data lives on a low-dimensional manifold. An image of a face can be described by a few dozen numbers: skin tone, eye shape, hair length, pose. The latent space learns to discover these factors automatically.

What does this look like in a real system? Stable Diffusion's VAE takes a 512×512 RGB image, a tensor of shape [1, 3, 512, 512], and compresses it to a latent of shape [1, 4, 64, 64]. That's 786,432 input values squeezed into 16,384 latent values: a 48× compression. Every 512×512 image Stable Diffusion generates starts life as a tiny 64×64×4 grid.

SD Input
[1, 3, 512, 512] = 786,432 values
↓ VAE Encoder (conv layers)
SD Latent
[1, 4, 64, 64] = 16,384 values
↓ Diffusion happens here!
SD Output
[1, 3, 512, 512] via VAE Decoder
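The compression arithmetic above is easy to check by counting values. A minimal sketch in plain Python; the helper name `compression_ratio` is our own, not a library function.

```python
from math import prod

def compression_ratio(input_shape, latent_shape):
    """Number of input values divided by number of latent values."""
    return prod(input_shape) / prod(latent_shape)

# Chapter 0 toy example: a 256x256 RGB image squeezed into a 64-dim vector
print(compression_ratio((256, 256, 3), (64,)))              # 3072.0
# Stable Diffusion's VAE at 512x512
print(compression_ratio((1, 3, 512, 512), (1, 4, 64, 64)))  # 48.0
```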
Interactive: Compression Ratio. A slider sets the latent dimension (64 by default), and a bar shows how large the latent is as a fraction of the original input.
Real-world scale: Stable Diffusion's encoder has ~34M parameters. The decoder has ~50M. Together they're ~84M params — tiny compared to the 860M-parameter U-Net that does diffusion. The VAE is cheap to train, cheap to run, and yet it's the foundation everything else stands on. If the VAE is bad, everything downstream is bad.
Check: Why do we compress data into a latent space?

Chapter 1: Autoencoders — The Bottleneck

A plain autoencoder trains two neural networks end-to-end: the encoder f and decoder g. The loss is simply the reconstruction error: how different is g(f(x)) from x? The bottleneck — a narrow latent layer — forces compression.

It works. But there's a problem: the latent space is messy. Points are scattered unpredictably. If you pick a random point in latent space and decode it, you get garbage. The autoencoder memorizes efficient codes but doesn't organize them.

L = ||x − Decoder(Encoder(x))||²

Concretely, the encoder is a stack of convolutional layers that progressively downsample. An image [batch, 3, 256, 256] passes through conv → downsample → conv → downsample until the spatial dimensions shrink and the channel count grows: [batch, 256, 32, 32]. This map is then flattened (or reduced to [batch, 256] by global average pooling), and a final linear layer maps it to the latent vector [batch, latent_dim]. The decoder mirrors this process: upsample → conv → upsample until you're back to the original resolution.

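python
# Minimal encoder architecture
import torch.nn as nn

class Encoder(nn.Module):
    def __init__(self, latent_dim=64):
        super().__init__()
        # Each conv (kernel 4, stride 2, padding 1) halves H and W
        self.conv1 = nn.Sequential(nn.Conv2d(3, 64, 4, stride=2, padding=1), nn.ReLU())
        self.conv2 = nn.Sequential(nn.Conv2d(64, 128, 4, stride=2, padding=1), nn.ReLU())
        self.conv3 = nn.Sequential(nn.Conv2d(128, 256, 4, stride=2, padding=1), nn.ReLU())
        self.fc = nn.Linear(256 * 32 * 32, latent_dim)
    def forward(self, x):        # x: [B, 3, 256, 256]
        x = self.conv1(x)         # [B, 64, 128, 128]
        x = self.conv2(x)         # [B, 128, 64, 64]
        x = self.conv3(x)         # [B, 256, 32, 32]
        x = x.flatten(1)          # [B, 256*32*32]
        return self.fc(x)         # [B, latent_dim]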
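The decoder mirrors the encoder above. A minimal sketch, assuming a 64-dim latent and stride-2 transposed convolutions, which fold each "upsample → conv" step into one layer (nearest-neighbor upsampling followed by a regular conv is a common alternative). The class name, layer widths, and the final sigmoid (which assumes pixels scaled to [0, 1]) are illustrative choices, not a reference design.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class Decoder(nn.Module):
    def __init__(self, latent_dim: int = 64):
        super().__init__()
        self.fc = nn.Linear(latent_dim, 256 * 32 * 32)
        # Each stride-2 transposed conv doubles H and W: 32 -> 64 -> 128 -> 256
        self.up = nn.Sequential(
            nn.ConvTranspose2d(256, 128, 4, stride=2, padding=1), nn.ReLU(),
            nn.ConvTranspose2d(128, 64, 4, stride=2, padding=1), nn.ReLU(),
            nn.ConvTranspose2d(64, 3, 4, stride=2, padding=1),
        )

    def forward(self, z):                      # z: [B, latent_dim]
        h = self.fc(z).view(-1, 256, 32, 32)   # [B, 256, 32, 32]
        return torch.sigmoid(self.up(h))       # [B, 3, 256, 256], values in (0, 1)

def reconstruction_loss(x, x_hat):
    # ||x - x_hat||^2 summed over pixels, averaged over the batch
    return F.mse_loss(x_hat, x, reduction="sum") / x.shape[0]
```

Feeding it a batch of latents of shape [B, 64] returns images of shape [B, 3, 256, 256]; `reconstruction_loss` is the autoencoder objective from this chapter, averaged per example.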
Interactive: Latent Space, Organized vs Messy. Left: a plain autoencoder's latent space, with clusters and gaps. Right: a VAE's latent space, smooth and continuous.

Key problem: Autoencoders are good at compression but bad at generation. You can't sample new images because you don't know which latent codes are "valid." The VAE fixes this by forcing structure on the latent space.

To make this concrete: train an autoencoder on face images. Encode two faces and get z1 and z2. Decode the midpoint (z1 + z2) / 2 — you might get a coherent blend, or you might get static noise. The autoencoder never promised that midpoints would be meaningful. It only learned to reconstruct, not to organize. The latent space is full of "dead zones" where decoded output is garbage.

The VAE's fix is elegant: force every encoded point to look like it came from a standard Gaussian. Then sampling is trivial — just draw z ~ N(0,1) and decode. But this constraint comes at a cost: the encoder can't use arbitrary regions of latent space anymore, which limits reconstruction fidelity. This tension defines the entire field.
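Both experiments from this chapter, decoding a midpoint and sampling from the prior, take only a few lines. A sketch assuming some `decoder` that maps [B, latent_dim] latents to outputs; the stand-in decoder below is an untrained placeholder so the snippet runs, and its outputs are meaningless until you substitute a trained model.

```python
import torch
import torch.nn as nn

latent_dim = 64
decoder = nn.Sequential(nn.Linear(latent_dim, 3 * 32 * 32), nn.Sigmoid())  # placeholder, untrained

def interpolate(z1, z2, steps=5):
    # Points on the straight line between two codes: t=0 gives z1, t=1 gives z2, t=0.5 the midpoint
    ts = torch.linspace(0, 1, steps).unsqueeze(1)   # [steps, 1]
    return (1 - ts) * z1 + ts * z2                  # [steps, latent_dim]

@torch.no_grad()
def sample_from_prior(n):
    # Only meaningful if training pushed encodings toward N(0, I), as a VAE's KL term does
    z = torch.randn(n, latent_dim)                  # z ~ N(0, I)
    return decoder(z)

z1, z2 = torch.randn(latent_dim), torch.randn(latent_dim)
path = interpolate(z1, z2)        # decode each row to see a blend (or a dead zone)
samples = sample_from_prior(8)    # [8, 3*32*32] with the placeholder decoder
```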

Check: What's the main limitation of a plain autoencoder?

Chapter 2: The Variational Trick

The key insight of the VAE: instead of encoding x to a single point z, encode it to a distribution, specifically a Gaussian with mean μ and variance σ². Then sample z from that distribution. Because the decoder must reconstruct x from any z drawn near μ, nearby codes are pushed to decode to similar images, which makes the latent space smoother.

In code, the encoder doesn't output one vector. It outputs two: μ (the mean) and log σ² (the log-variance). Why log-variance instead of variance? Because variance must be positive, and log-variance is unconstrained — the network can output any real number. To get the standard deviation: std = exp(0.5 * log_var).

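python
# Encoder outputs TWO heads, not one
import torch
import torch.nn as nn
class VAEEncoder(nn.Module):
    def __init__(self, backbone: nn.Module, hidden: int, latent_dim: int):
        super().__init__()
        self.backbone = backbone                        # e.g. convs + flatten: [B, hidden]
        self.fc_mu = nn.Linear(hidden, latent_dim)      # head 1: mean
        self.fc_logvar = nn.Linear(hidden, latent_dim)  # head 2: log-variance

    def forward(self, x):
        h = self.backbone(x)                            # [B, hidden]
        return self.fc_mu(h), self.fc_logvar(h)         # mu, log_var: [B, latent_dim]

# Reparameterization trick
def reparameterize(mu, log_var):
    std = torch.exp(0.5 * log_var)   # [B, latent_dim]
    eps = torch.randn_like(std)      # [B, latent_dim] ~ N(0,1)
    return mu + std * eps            # [B, latent_dim], differentiable!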

But wait: sampling is not differentiable. How do we backpropagate through a random number? The reparameterization trick: instead of z ~ N(μ, σ²), write z = μ + σ · ε where ε ~ N(0,1). Now the randomness (ε) is external, and gradients flow through μ and σ.

z = μ + σ · ε      ε ~ N(0, 1)

Why does this matter so much? If we sampled z directly from N(μ, σ²), the sampling operation has no gradient with respect to μ or σ. Backprop hits a wall. But z = μ + σ · ε is a deterministic function of μ, σ, and the random noise ε. Gradients flow through μ and σ just fine — ε is treated as a constant input. The stochasticity is still there (each forward pass samples a different ε), but the computation graph is smooth.

Encoder Output
μ = 2.3, log_var = −1.39 → σ = 0.5
↓ sample ε ~ N(0,1)
Reparameterize
z = 2.3 + 0.5 × ε
Decoder
Reconstruct from z
Interactive: Reparameterization Trick. Sliders set μ and σ; each frame samples a new ε and shows the resulting z. The orange curve is the distribution N(μ, σ²); teal dots are samples.
Why it matters: The reparameterization trick is what makes VAEs trainable. Without it, you can't backpropagate through the sampling step. It's one of the cleverest tricks in modern deep learning.

Let's trace the gradient path to make this visceral. During backprop, we need ∂L/∂μ and ∂L/∂log_var. With reparameterization: ∂z/∂μ = 1 and ∂z/∂σ = ε. Both are simple, well-defined numbers. Without reparameterization, z is drawn from N(μ, σ²) and ∂z/∂μ doesn't exist — sampling is a black box. The trick turns a stochastic bottleneck into a deterministic computation with a random input.

Without trick:
x → Encoder → μ, σ → SAMPLE z → Decoder
Gradients stop at the sampling step.
With trick:
x → Encoder → μ, σ → z = μ + σ·ε → Decoder
Gradients flow through μ and σ.
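Autograd confirms both derivatives. A small PyTorch sketch with scalar μ and σ and a dummy loss L = z, so ∂L/∂z = 1 and the reported gradients are exactly ∂z/∂μ and ∂z/∂σ. The last lines use `torch.distributions.Normal`, whose `sample()` returns a value cut off from the graph while `rsample()` applies this same reparameterization.

```python
import torch
from torch.distributions import Normal

mu = torch.tensor(2.3, requires_grad=True)
sigma = torch.tensor(0.5, requires_grad=True)

eps = torch.randn(())          # external noise: a plain input, no gradient of its own
z = mu + sigma * eps           # reparameterized sample
z.backward()                   # dummy loss L = z, so dL/dz = 1

print(mu.grad)                 # dz/dmu = 1
print(sigma.grad, eps)         # dz/dsigma = eps (the same number twice)

d = Normal(mu, sigma)
print(d.sample().requires_grad)    # False: no path back to mu or sigma
print(d.rsample().requires_grad)   # True: reparameterized sample
```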
Check: What does the reparameterization trick achieve?
🔨 Derivation: Why z = μ + σε is differentiable but sampling from N(μ, σ²) is not

We need ∂L/∂μ and ∂L/∂σ to train the encoder. With z sampled from N(μ, σ²), the computation graph has a stochastic node. With z = μ + σ·ε, it doesn't.

Your task: Show that ∂z/∂μ and ∂z/∂σ are well-defined under the reparameterization, and explain why the "naive" sampling has no gradient.

Treat ε as a constant (it's sampled once and fixed for this forward pass). Then z is just a linear function of μ: ∂z/∂μ = 1. Simple as that.
Again, ε is a constant. z = μ + σ·ε, so ∂z/∂σ = ε. The gradient with respect to sigma is just the noise sample itself.
A random number generator is a black-box function. The operation "draw from distribution" doesn't have a derivative because it's not a deterministic function of μ and σ in the computation graph. The stochastic node blocks backpropagation entirely.

Full derivation:

Let L be the loss computed from the decoder output. We need ∂L/∂μ via the chain rule:

∂L/∂μ = (∂L/∂z) · (∂z/∂μ)

With reparameterization z = μ + σ·ε:

∂z/∂μ = 1     ∂z/∂σ = ε

So: ∂L/∂μ = ∂L/∂z   and   ∂L/∂σ = ε · ∂L/∂z

Without reparameterization, z = sample(N(μ, σ²)). The "sample" operation is discontinuous — tiny changes in μ can produce arbitrarily different z values depending on internal RNG state. There is no function z(μ) to differentiate. The gradient ∂z/∂μ is undefined.

The key insight: The trick moves the source of randomness outside the computation graph. ε enters as a "data input," not as an operation. The graph becomes: ε (input) → multiply by σ → add μ → z. Every arrow is differentiable.

Chapter 3: ELBO — The Training Objective

We want to maximize the likelihood of our data: p(x). But computing p(x) directly requires integrating over all possible latent codes z, which is intractable. The integral p(x) = ∫ p(x|z) p(z) dz sums over every possible latent code. With a 64-dimensional z, that's integration over a 64-dimensional space — computationally hopeless.

Instead, we optimize a lower bound on log p(x): the Evidence Lower Bound (ELBO). The gap between log p(x) and the ELBO is exactly KL(q(z|x) || p(z|x)) — the KL divergence between our approximate posterior and the true posterior. Since KL is always ≥ 0, the ELBO is always ≤ log p(x). Maximizing the ELBO simultaneously tightens the bound and improves the model.

log p(x) ≥ ELBO = E_{q(z|x)}[log p(x|z)] − KL(q(z|x) || p(z))

The ELBO has two terms: a reconstruction term (how well can we decode z back to x?) and a KL divergence term (how close is our learned distribution q(z|x) to the prior p(z) = N(0,1)?). The reconstruction term wants good decoding; the KL term wants an organized latent space.

Reconstruction term: "Decode z and get back x." Encourages faithful reconstruction. With a Gaussian decoder p(x|z) it reduces to MSE (up to constants); with a Bernoulli decoder, to BCE.
KL term: "Keep q(z|x) close to N(0,1)." Prevents the encoder from cheating by using a tiny region of latent space.

Let's make this concrete with numbers. Suppose for one image, the encoder outputs μ = 3.0, σ = 0.2 for some latent dimension. The KL term for that dimension is ½(3.0² + 0.04 − log 0.04 − 1) = ½(9 + 0.04 + 3.22 − 1) = 5.63. That's a big penalty — the KL is screaming "your mean is way off from zero!" If instead μ = 0.1, σ = 0.9, the KL is ½(0.01 + 0.81 + 0.21 − 1) = 0.015. Almost free. The reconstruction term pushes back: "but I need μ = 3.0 to encode this cat!" The balance between these two forces IS the VAE.
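Both KL values can be reproduced with the per-dimension closed form that Chapter 4 states, ½(μ² + σ² − log σ² − 1). A plain-Python sketch; `kl_to_standard_normal` is a name we chose for illustration.

```python
import math

def kl_to_standard_normal(mu, sigma):
    """KL(N(mu, sigma^2) || N(0, 1)) for one latent dimension, in nats."""
    var = sigma ** 2
    return 0.5 * (mu ** 2 + var - math.log(var) - 1)

print(round(kl_to_standard_normal(3.0, 0.2), 2))   # 5.63
print(round(kl_to_standard_normal(0.1, 0.9), 3))   # 0.015
print(kl_to_standard_normal(0.0, 1.0))             # 0.0: matching the prior is free
```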

Interactive: ELBO Decomposition. Sliders weight the reconstruction and KL terms. Low KL means an organized latent space but blurrier outputs; heavy weight on reconstruction means sharp outputs but a messy latent space.
Intuition: The ELBO is a tug-of-war. The reconstruction loss pulls the encoder to memorize every detail. The KL loss pulls it toward a smooth, standard Gaussian. The balance between them determines the character of the latent space.
Check: What are the two components of the ELBO?
🔨 Derivation: Derive the ELBO from log p(x)

We want to maximize log p(x) but can't compute it directly. We introduce an approximate posterior q(z|x) and derive a tractable lower bound.

Your task: Starting from log p(x) = log ∫ p(x,z) dz, introduce q(z|x) and derive: log p(x) ≥ E_q[log p(x|z)] − KL(q(z|x) || p(z)).

Write log p(x) = log ∫ p(x,z) · q(z|x)/q(z|x) dz = log E_{q(z|x)}[p(x,z)/q(z|x)]. Now you have an expectation under q.
Since log is concave, Jensen's inequality gives: log E_q[p(x,z)/q(z|x)] ≥ E_q[log(p(x,z)/q(z|x))] = E_q[log p(x,z) − log q(z|x)]. This is the ELBO.
Factor p(x,z) = p(x|z)p(z): E_q[log p(x|z) + log p(z) − log q(z|x)] = E_q[log p(x|z)] − E_q[log q(z|x) − log p(z)] = E_q[log p(x|z)] − KL(q||p). That last term is the KL divergence by definition.

Full derivation:

Step 1: log p(x) = log ∫ p(x,z) dz = log ∫ p(x,z) · q(z|x)/q(z|x) dz

Step 2: = log E_{q(z|x)}[p(x,z) / q(z|x)]

Step 3 (Jensen's): ≥ E_{q(z|x)}[log(p(x,z) / q(z|x))]

Step 4: = E_q[log p(x|z) + log p(z) − log q(z|x)]

Step 5: = E_q[log p(x|z)] − E_q[log q(z|x) − log p(z)]

Step 6: = E_q[log p(x|z)] − KL(q(z|x) || p(z))

The gap between log p(x) and the ELBO is exactly KL(q(z|x) || p(z|x)) — the divergence between our approximate posterior and the TRUE posterior. When q perfectly approximates p(z|x), the bound is tight.

The key insight: We can't compute p(z|x) directly, because by Bayes' rule p(z|x) = p(x|z)p(z)/p(x), and p(x) is exactly the intractable quantity. So we optimize a lower bound instead. Maximizing the ELBO simultaneously improves the model (makes the bound higher) and tightens the approximation (makes q closer to the true posterior).
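A toy model where log p(x) has a closed form makes the bound tangible. This sketch assumes p(z) = N(0, 1) and p(x|z) = N(z, 1), so that p(x) = N(0, 2) and the true posterior is p(z|x) = N(x/2, 1/2); q(z|x) = N(m, s2), and both ELBO terms are computed analytically. The model is our illustrative choice, not one from the text.

```python
import math

def log_px(x):
    # Exact marginal likelihood: p(x) = N(0, 2)
    return -0.5 * math.log(2 * math.pi * 2) - x ** 2 / 4

def elbo(x, m, s2):
    # E_q[log p(x|z)] for p(x|z) = N(z, 1) and q = N(m, s2), in closed form
    recon = -0.5 * math.log(2 * math.pi) - 0.5 * ((x - m) ** 2 + s2)
    # KL(N(m, s2) || N(0, 1))
    kl = 0.5 * (m ** 2 + s2 - math.log(s2) - 1)
    return recon - kl

x = 1.5
print(log_px(x))              # exact log-likelihood
print(elbo(x, 0.0, 1.0))      # q = prior: strictly below log p(x)
print(elbo(x, x / 2, 0.5))    # q = true posterior: the bound is tight
```

Any other choice of (m, s2) lands below log p(x), and the shortfall is exactly KL(q(z|x) || p(z|x)).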

🔗 Pattern Recognition
ELBO as Rate-Distortion
This Lesson (VAE)
ELBO = E[log p(x|z)] − KL(q(z|x) || p(z))
= Reconstruction − Regularization
Information Theory
Rate-Distortion: minimize D (distortion) subject to R ≤ R_max (rate)
Same tradeoff: fidelity vs compression → Diffusion lesson

The KL term IS the rate: it measures how many bits the encoder uses to communicate z. The reconstruction term IS the distortion: how much quality we lose. The β in β-VAE literally traces the rate-distortion curve — higher β means lower rate (more compression) at the cost of higher distortion (blurrier images). This is exactly the same tradeoff that JPEG, MP3, and every lossy codec faces.

Where else have you seen "compress information through a bottleneck then reconstruct"? (Hint: attention mechanisms compress T tokens through a T×T matrix, forcing the model to select which information survives.)

Checkpoint — Before you move on
Explain in your own words: why can't we just maximize log p(x) directly, and how does the ELBO let us train anyway? What role does q(z|x) play?
Model Answer

Computing log p(x) requires integrating p(x|z)p(z) over ALL possible z values — a high-dimensional integral with no closed form. We can't even evaluate p(x), let alone optimize it.

The ELBO sidesteps this by introducing q(z|x): a learned approximation to the true posterior p(z|x). Instead of integrating over all z, we sample z from q(z|x) and evaluate log p(x|z). This is tractable — it's just running the decoder on one sampled z and comparing to x.

The bound is useful because: (1) log p(x) can never fall below the ELBO, so raising the ELBO raises a guaranteed floor under log p(x), and (2) the gap between them shrinks as q gets better. So one loss function simultaneously trains the encoder (make q better), the decoder (make p(x|z) better), and regularizes the latent space (keep q close to the prior).

Chapter 4: Training a VAE

In practice, the KL term has a closed-form solution for Gaussians. For each latent dimension j, the KL divergence is: ½(μ_j² + σ_j² − log σ_j² − 1). This is cheap to compute and differentiable. The total KL sums over all latent dimensions.

The big practical knob is β: a weight on the KL term. With β=1, you get the standard VAE (ELBO). With β>1, you get β-VAE, which encourages more disentangled latent factors at the cost of blurrier reconstructions. With β<1, you get sharper images but a messier latent space.

Why do VAEs produce blurry images? The reconstruction loss is usually MSE: ||x − x̂||². MSE penalizes pixel-level differences, so the decoder hedges its bets — when uncertain, it outputs the average, which is blurry. Modern VAEs fix this with perceptual loss (compare features from a pre-trained VGG network instead of raw pixels) or adversarial loss (add a discriminator that penalizes blurriness). Stable Diffusion's VAE uses both, which is why its reconstructions are sharp.

L = ||x − x̂||² + β · ∑_j ½(μ_j² + σ_j² − log σ_j² − 1)

Let's trace through a concrete example. Suppose latent_dim = 4 and the encoder outputs:

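| Dim j | μ_j | σ_j² | KL_j | Interpretation |
| --- | --- | --- | --- | --- |
| 0 | 0.1 | 0.9 | 0.008 | Near standard normal: almost free |
| 1 | 2.5 | 0.1 | 3.83 | Mean far from zero: big penalty |
| 2 | 0.0 | 0.01 | 1.81 | Variance too small: "cheating" by being too certain |
| 3 | 0.3 | 1.2 | 0.05 | Slightly off: small penalty |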
Total KL ≈ 5.70. With β = 1 and a reconstruction loss of, say, 150, the KL is only ~4% of the total, so the model mostly focuses on reconstruction. With β = 10, the weighted KL becomes ≈ 57.0, about 28% of the loss (57 / 207): the model is now pushed much harder toward N(0,1).
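The KL column and the β comparison can be recomputed directly from the closed form in the loss. A PyTorch sketch; the reconstruction loss of 150 is the illustrative value from the text, and the printed values are approximate.

```python
import torch

mu = torch.tensor([0.1, 2.5, 0.0, 0.3])
var = torch.tensor([0.9, 0.1, 0.01, 1.2])      # sigma_j^2 from the table

kl_per_dim = 0.5 * (mu ** 2 + var - torch.log(var) - 1)
kl_total = kl_per_dim.sum()
print(kl_per_dim)     # approx. [0.008, 3.826, 1.808, 0.054]
print(kl_total)       # approx. 5.70

recon = 150.0
for beta in (1.0, 10.0):
    kl_term = beta * kl_total.item()
    print(beta, kl_term / (recon + kl_term))   # share of total loss: approx. 0.037, then 0.275
```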

Interactive: Beta Slider, Sharpness vs Structure. Low β: sharp but unstructured. High β: blurry but well-organized latent space. Bars show how reconstruction quality and latent organization respond.
| β value | Reconstruction | Latent structure | Use case |
| --- | --- | --- | --- |
| β < 1 | Sharp | Messy | When quality matters most |
| β = 1 | Balanced | Good | Standard VAE (ELBO) |
| β > 1 | Blurry | Disentangled | β-VAE for interpretable factors |

Here's the full training loop in code. Notice how each loss component maps to a specific part of the model:

python
import torch
import torch.nn.functional as F

def vae_loss(encoder, decoder, x, beta=1.0):
    # Forward pass
    mu, log_var = encoder(x)                 # two heads
    std = torch.exp(0.5 * log_var)
    z = mu + std * torch.randn_like(std)     # reparam trick
    x_hat = decoder(z)                       # reconstruction

    # Reconstruction loss (trains the decoder, and the encoder through z)
    recon_loss = F.mse_loss(x_hat, x, reduction='sum')

    # KL loss (closed form; reaches only the encoder's mu and log_var heads)
    kl_loss = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())

    # Total
    return recon_loss + beta * kl_loss
KL annealing: A common trick: start with β=0 (pure autoencoder) and slowly raise it during training. This lets the model learn useful codes before the KL term collapses them.
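A minimal sketch of such a schedule, assuming a linear ramp; the 10,000-step warmup is an arbitrary placeholder, and other shapes (for example, cyclical schedules) are also used.

```python
def kl_weight(step: int, warmup_steps: int = 10_000, beta_max: float = 1.0) -> float:
    """Linear KL annealing: beta rises from 0 to beta_max over warmup_steps, then stays flat."""
    return beta_max * min(1.0, step / warmup_steps)

# In a training loop (hypothetical names): loss = recon_loss + kl_weight(step) * kl_loss
print([kl_weight(s) for s in (0, 2_500, 5_000, 10_000, 20_000)])   # [0.0, 0.25, 0.5, 1.0, 1.0]
```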
Check: What does increasing β do?
💥 Break-It Lab: What Dies When You Remove VAE Loss Components?
A working VAE balances reconstruction loss and KL divergence. Each scenario below removes or distorts one component, and the training dynamics break in a different way.
Remove KL term (β = 0)
Failure mode: Without KL regularization, the encoder uses tiny non-overlapping regions of latent space. Reconstruction is perfect, but sampling from N(0,1) produces garbage. The latent space is a Swiss cheese of valid codes surrounded by dead zones. This is just an autoencoder — generation is impossible.
Remove reconstruction loss
Failure mode: The KL term alone is minimized when q(z|x) = N(0,1) for ALL inputs. The encoder outputs μ=0, σ=1 regardless of input: it ignores x entirely. The decoder gets no training signal at all (nothing compares its output to x), so it just emits whatever its random initialization produces. This is posterior collapse: the latent code carries zero information.
Set β = 100 (extreme)
Failure mode: Extreme KL pressure forces all latent distributions toward N(0,1). The encoder barely differentiates between inputs. Images are blurry and nearly identical. The model is over-regularized — it compressed away all the useful information to satisfy the KL constraint.
⚔ Adversarial: Your VAE generates blurry images. The KL term is near zero. Diagnosis?
You've trained a VAE for 100 epochs. The reconstruction loss plateaued at a high value. The KL divergence collapsed to ~0.01 per dimension early in training and stayed there. Generated samples (z ~ N(0,1), decode) are blurry and look like averaged versions of training images.

Chapter 5: VQ-VAE — Discrete Codes

What if latent codes were discrete instead of continuous? VQ-VAE replaces the Gaussian latent space with a codebook: a table of K learned vectors, each of dimension D. In code, the codebook is just an nn.Embedding(K, D) — a matrix of shape [K, D]. Typical values: K = 8192 entries, D = 256 dimensions.

The encoder outputs a continuous feature map [batch, D, H, W] — say [1, 256, 32, 32]. For each of the 32×32 = 1,024 spatial locations, we find the nearest codebook entry and replace it. That's 1,024 independent nearest-neighbor lookups. The decoder only ever sees codebook entries, not the raw encoder output.

Encoder
x → z_e (continuous)
↓ nearest-neighbor lookup
Codebook
z_q = e_k, where k = argmin_j ||z_e − e_j||²
Decoder
z_q → x̂
Interactive: Vector Quantization. Blue dots are codebook entries; dragging the orange point (the encoder output) snaps it to its nearest entry, and a green line shows the assignment.

L = ||x − x̂||² + ||sg[z_e] − e||² + β ||z_e − sg[e]||²

The loss has three parts, and each applies to different parameters:

Reconstruction, ||x − x̂||²: the decoded image matches the input. Updates encoder + decoder weights.
Codebook, ||sg[z_e] − e||²: moves entries toward encoder outputs. Updates codebook entries only (sg freezes the encoder).
Commitment, β||z_e − sg[e]||²: keeps the encoder near its chosen entries. Updates the encoder only (sg freezes the codebook).
Straight-through estimator: Nearest-neighbor lookup is not differentiable — argmin has zero gradient almost everywhere. The solution: on the forward pass, use the discrete codebook vector. On the backward pass, pretend the quantization didn't happen and copy gradients from decoder input straight to encoder output. In code: z_q = z_e + (z_q - z_e).detach(). This one line makes VQ-VAE trainable.
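python
import torch
import torch.nn as nn

# VQ-VAE quantization (per spatial location)
def quantize(z_e: torch.Tensor, codebook: nn.Embedding):
    B, D, H, W = z_e.shape                            # e.g. [B, 256, 32, 32]
    z_flat = z_e.permute(0, 2, 3, 1).reshape(-1, D)   # [B*H*W, D]
    dists = torch.cdist(z_flat, codebook.weight)      # [B*H*W, K]
    indices = dists.argmin(dim=1)                     # [B*H*W]
    z_q = codebook(indices)                           # [B*H*W, D]
    # Undo the flattening via [B, H, W, D] (a plain reshape_as(z_e) would scramble channels)
    z_q = z_q.view(B, H, W, D).permute(0, 3, 1, 2)    # [B, D, H, W]
    # Straight-through: forward=discrete, backward=continuous
    z_q_st = z_e + (z_q - z_e).detach()
    return z_q_st, z_q, indices.view(B, H, W)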
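The three loss terms translate almost literally into code, with `.detach()` playing the role of sg[·]. A sketch, assuming `z_q` is the looked-up codebook output before the straight-through step (not `z_q_st`) and `x_hat` is the decoder applied to `z_q_st`; it uses mean rather than summed squared errors, which only rescales the terms, and β = 0.25 is the commitment weight discussed below.

```python
import torch
import torch.nn.functional as F

def vq_vae_loss(x, x_hat, z_e, z_q, beta=0.25):
    """Three-term VQ-VAE loss; .detach() acts as the stop-gradient sg[.]."""
    recon = F.mse_loss(x_hat, x)                 # ||x - x_hat||^2: decoder, and encoder via straight-through
    codebook = F.mse_loss(z_q, z_e.detach())     # ||sg[z_e] - e||^2: gradient reaches the codebook only
    commitment = F.mse_loss(z_e, z_q.detach())   # ||z_e - sg[e]||^2: gradient reaches the encoder only
    return recon + codebook + beta * commitment
```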
Check: In VQ-VAE, what replaces the Gaussian latent space?
🔨 Derivation: Why VQ-VAE needs THREE loss terms (and what sg[] does)

The VQ-VAE loss is: L = ||x − x̂||² + ||sg[z_e] − e||² + β||z_e − sg[e]||². The sg[] (stop-gradient) operator is critical: it determines WHICH parameters each term updates.

Your task: Explain why we can't just use L = ||x − x̂||² + ||z_e − e||². What goes wrong without the stop-gradients? Why do we need the commitment loss as a separate term?

Without sg, this single term pushes BOTH the encoder output z_e toward e AND the codebook entry e toward z_e. Both move simultaneously. The problem: they can drift together in any direction, wandering away from useful representations. There's no anchor.
||sg[z_e] − e||² only has gradients w.r.t. e (codebook). It says: "move codebook entries toward where the encoder is pointing." β||z_e − sg[e]||² only has gradients w.r.t. z_e (encoder). It says: "keep encoder outputs close to their assigned codebook entries." Each has a clear job.
The encoder already gets gradients from the reconstruction loss (via straight-through). The commitment loss is auxiliary — it just prevents the encoder from moving too fast for the codebook to follow. A small β (0.25) is enough gentle pressure without conflicting with reconstruction gradients.

The three-term decomposition:

Term 1: ||x − x̂||² — Reconstruction. Gradients flow through straight-through estimator to encoder, and directly to decoder. Updates: encoder + decoder.

Term 2: ||sg[z_e] − e||² — Codebook learning. sg[z_e] means "treat z_e as a constant." Only e receives gradients. This moves codebook entries toward the encoder outputs they're matched to. In practice, it is often replaced by an EMA update (more stable). Updates: codebook only.

Term 3: β||z_e − sg[e]||² — Commitment. sg[e] means "treat codebook entries as constants." Only z_e receives gradients. This prevents the encoder from producing outputs that jump erratically between codebook entries. Updates: encoder only.

Without separation: a naive ||z_e − e||² causes co-adaptation — encoder and codebook chase each other with no stable equilibrium. The stop-gradients create an alternating optimization: codebook adapts to encoder (term 2), encoder stays near codebook (term 3), and reconstruction drives the overall representation quality (term 1).

The key insight: Stop-gradient is a coordination mechanism. It prevents the "two magnets pulling each other" instability by saying: "only ONE thing moves at a time in each loss term." This is why VQ-VAE training is stable despite the fundamentally non-differentiable quantization step.

💻 Build It: Implement the VQ Straight-Through Estimator
The encoder outputs continuous vectors. The codebook quantizes them. But argmin has zero gradient. Implement the straight-through trick that makes this trainable: forward pass uses the discrete codebook vector, backward pass copies gradients directly to the encoder output as if quantization didn't happen.
Signature:
python
def vq_straight_through(z_e: Tensor, codebook: Tensor) -> Tuple[Tensor, Tensor]:
    """
    Args:
        z_e: encoder output [B, D] (continuous)
        codebook: codebook embeddings [K, D]
    Returns:
        z_q_st: quantized with straight-through gradient [B, D]
        indices: codebook indices chosen [B]
    """
Test case
z_e = torch.tensor([[1.0, 0.0], [0.0, 1.0]], requires_grad=True)
codebook = torch.tensor([[0.9, 0.1], [0.1, 0.9], [-0.5, -0.5]])
z_q_st, indices = vq_straight_through(z_e, codebook)
# indices should be [0, 1] (nearest entries)
# z_q_st should be [[0.9, 0.1], [0.1, 0.9]] (codebook values)
# BUT: z_q_st.requires_grad should be True (gradients flow to z_e)
.detach() creates a tensor with the same values but cut off from the computation graph. So (z_q - z_e).detach() has the value of the quantization error but zero gradient. Adding it to z_e means: forward value = z_e + (z_q - z_e) = z_q, but backward gradient flows only through z_e (since the detached part contributes zero gradient).
python
import torch

def vq_straight_through(z_e, codebook):
    # Step 1: compute distances [B, K]
    dists = torch.cdist(z_e, codebook)  # Euclidean distance

    # Step 2: find nearest codebook entry
    indices = dists.argmin(dim=1)  # [B]

    # Step 3: look up codebook vectors
    z_q = codebook[indices]  # [B, D]

    # Step 4: straight-through estimator
    # Forward: value is z_q (discrete codebook vector)
    # Backward: gradient goes to z_e (as if no quantization)
    z_q_st = z_e + (z_q - z_e).detach()

    return z_q_st, indices
Bonus challenge: Extend this to handle spatial feature maps [B, D, H, W] where quantization happens independently at each spatial location. You'll need to reshape to [B*H*W, D], quantize, then reshape back.

Chapter 6: Codebook Learning

The codebook is only useful if all entries are active. A common failure: codebook collapse — the encoder only uses a handful of entries while the rest go "dead." This wastes representational capacity.

Exponential Moving Average (EMA) updates are a popular fix. Instead of updating codebook entries with gradient descent, track the running average of all encoder outputs assigned to each entry. This is faster and more stable.

e_k ← γ · e_k + (1 − γ) · mean(assigned z_e)

In practice, γ = 0.99 is a common choice, and the rule above is a simplified, per-entry form of the update. For a codebook of K = 8192, D = 256, an efficient implementation processes all assigned vectors per batch in one vectorized operation; the per-entry loop below is written for readability. If entry k was assigned to 50 encoder outputs this batch, its new value is 0.99 × old + 0.01 × mean(those 50 vectors). Entries with no assigned vectors don't update at all, and this is exactly how they die.

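python
import torch

# EMA codebook update (simplified per-entry form)
@torch.no_grad()
def ema_update(codebook, z_flat, indices, gamma=0.99):
    # codebook: [K, D]; z_flat: [N, D] encoder outputs; indices: [N] chosen entries
    for k in range(codebook.shape[0]):
        mask = (indices == k)            # which encoder outputs chose this entry
        if mask.sum() == 0:              # dead code! no one chose it
            continue
        avg = z_flat[mask].mean(dim=0)   # mean of assigned vectors
        codebook[k] = gamma * codebook[k] + (1 - gamma) * avg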
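A vectorized sketch of a common count-and-sum variant of the EMA update: it keeps running averages of each entry's assignment count and of the sum of its assigned vectors, and sets each entry to their ratio. The function name, the Laplace smoothing of the counts, and `eps` are our choices for illustration.

```python
import torch
import torch.nn.functional as F

@torch.no_grad()
def ema_codebook_update(cluster_size, embed_sum, z_flat, indices, gamma=0.99, eps=1e-5):
    """cluster_size: [K] running counts; embed_sum: [K, D] running sums (both updated in place)."""
    K = cluster_size.shape[0]
    onehot = F.one_hot(indices, K).to(z_flat.dtype)                    # [N, K] assignment matrix
    cluster_size.mul_(gamma).add_(onehot.sum(0), alpha=1 - gamma)      # EMA of counts
    embed_sum.mul_(gamma).add_(onehot.t() @ z_flat, alpha=1 - gamma)   # EMA of vector sums
    # Laplace smoothing keeps the division safe for entries whose count is near zero
    n = cluster_size.sum()
    smoothed = (cluster_size + eps) / (n + K * eps) * n
    return embed_sum / smoothed.unsqueeze(1)                           # new codebook [K, D]
```

One would initialize `cluster_size` (for example to ones) and `embed_sum` (to a copy of the initial codebook) once, call this every batch, and copy the result into the codebook. A dead entry's count and sum decay together, so its vector barely moves: the same death mechanism as in the per-entry loop.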
Interactive: Codebook Utilization. Each bar is a codebook entry, with height equal to its usage count; red entries are dead (unused). The animation shows dead code revival redistributing them.

| Strategy | How it works |
| --- | --- |
| EMA updates | Running average of assigned vectors; no gradient needed for codebook |
| Dead code revival | Replace unused entries with randomly sampled encoder outputs |
| Codebook reset | Periodically re-initialize low-usage entries from data (k-means style) |
| Larger codebook | More entries = finer granularity, but harder to keep all alive |
Rule of thumb: Codebook utilization above 90% is healthy. Below 50% means half your representational capacity is wasted. Monitor this during training.
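Utilization is straightforward to log from the chosen indices. A sketch; the perplexity (exp of the entropy of code usage) is an extra metric we add because it is often logged alongside utilization, and it reads as the "effective number" of codes in use.

```python
import torch

def codebook_stats(indices: torch.Tensor, K: int):
    counts = torch.bincount(indices.flatten(), minlength=K).float()   # usage per entry
    utilization = (counts > 0).float().mean().item()                 # fraction of entries used
    p = counts[counts > 0] / counts.sum()                            # usage distribution (nonzero part)
    perplexity = torch.exp(-(p * p.log()).sum()).item()              # effective number of codes
    return utilization, perplexity

print(codebook_stats(torch.arange(8), 16))   # 8 of 16 entries used evenly: utilization 0.5, perplexity ~8
```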

Dead code revival works by tracking usage counts per codebook entry. When an entry's count drops below a threshold (say, 1 assignment in the last 100 batches), it's declared dead. The revival strategy: replace the dead entry with the encoder output of a randomly chosen training example. This teleports the dead entry into a region of latent space that the encoder actually uses, giving it a fresh chance to attract assignments. After revival, the EMA update takes over and refines its position.
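A minimal sketch of the revival step. It assumes you already track a per-entry `usage` tensor (for instance, assignment counts over the last 100 batches) and lets the current batch's flattened encoder outputs stand in for "a randomly chosen training example"; the function name and threshold are illustrative.

```python
import torch

@torch.no_grad()
def revive_dead_codes(codebook, usage, z_flat, threshold=1.0):
    """codebook: [K, D]; usage: [K] recent assignment counts; z_flat: [N, D] encoder outputs."""
    dead = (usage < threshold).nonzero(as_tuple=True)[0]   # indices of dead entries
    if dead.numel() > 0:
        # Teleport each dead entry onto a randomly chosen encoder output
        picks = torch.randint(0, z_flat.shape[0], (dead.numel(),))
        codebook[dead] = z_flat[picks]
    return dead
```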

Check: What is "codebook collapse"?
⚔ Adversarial: Your VQ-VAE uses only 12 of 8192 codebook entries after training. Reconstruction is decent. What happened?
You trained a VQ-VAE with K=8192, D=256 on ImageNet. After 50 epochs, the reconstruction loss is reasonable (PSNR ~28 dB), but monitoring reveals only 12 codebook entries ever get selected. The other 8180 are never assigned to any encoder output.

Chapter 7: FSQ — Finite Scalar Quantization

VQ-VAE's codebook is elegant but fragile: you need to manage dead codes, tune commitment loss, and balance EMA rates. FSQ takes a radically simpler approach: instead of learning a codebook, just round each scalar to a small set of levels.

If each of d dimensions has L levels, you get L^d possible codes: an implicit codebook. With d=6 and L=5, that's 5^6 = 15,625 codes. No codebook parameters. No collapse. No commitment loss. Just rounding.

ẑ_i = round(((L_i − 1)/2) · tanh(z_i))

The implementation is almost embarrassingly simple. The encoder outputs a d-dimensional vector. Apply tanh to bound it to (−1, 1). Scale by (L − 1)/2, so that for odd L the rounded values cover exactly L integers (even L needs an extra half-level offset, which this simple version omits). Round to the nearest integer. That's quantization. The straight-through estimator handles gradients: forward uses the rounded value, backward pretends the rounding didn't happen.

python
import torch

# FSQ quantization in a few lines (this simple form assumes an odd number of levels L)
def fsq(z: torch.Tensor, L: int = 5) -> torch.Tensor:
    # z: encoder output [B, d], continuous
    z_bounded = torch.tanh(z)                      # [B, d] in (-1, 1)
    z_scaled = z_bounded * (L - 1) / 2             # in (-(L-1)/2, (L-1)/2)
    z_q = z_scaled.round()                         # snap to the L integer levels
    return z_scaled + (z_q - z_scaled).detach()    # straight-through
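The implicit codebook becomes concrete once each rounded vector is mapped to a single integer. A sketch, assuming odd L so the levels are {−(L−1)/2, …, (L−1)/2}: shift each dimension to {0, …, L−1} and read the vector as a base-L number. The function name is ours.

```python
import torch

def fsq_indices(z_q: torch.Tensor, L: int) -> torch.Tensor:
    """Map rounded FSQ vectors [..., d] (odd L) to one integer index each, in [0, L^d)."""
    digits = (z_q + (L - 1) / 2).round().long()        # shift each dimension to {0, ..., L-1}
    d = z_q.shape[-1]
    place = torch.tensor([L ** i for i in range(d)])   # place values: 1, L, L^2, ...
    return (digits * place).sum(dim=-1)

levels = torch.arange(5, dtype=torch.float32) - 2      # the 5 levels for L = 5
grid = torch.cartesian_prod(levels, levels, levels, levels)   # every code for d = 4
print(fsq_indices(grid, 5).unique().numel())           # 625 = 5^4 distinct indices
```

These integers are what a downstream transformer would model as tokens.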
Interactive: FSQ Scalar Rounding. The continuous encoder output (left axis) is rounded to discrete levels (right axis). Sliders set the levels L and the dimensions d; more levels give a finer representation. For example, L = 5 and d = 4 give an implicit codebook of 5^4 = 625 codes.
| Property | VQ-VAE | FSQ |
| --- | --- | --- |
| Codebook type | Explicit: nn.Embedding(K, D) | Implicit: L^d codes from rounding |
| Extra parameters | K × D floats (e.g. 8192 × 256 ≈ 2M) | Zero |
| Collapse risk | High: needs EMA + revival | None: all codes are equally accessible |
| Gradient trick | Straight-through + commitment loss | Straight-through on rounding |
| Reconstruction quality | Slightly better (learned codes adapt) | Comparable in practice |
| Code | ~50 lines for quantizer | ~5 lines |
VQ-VAE: Learned codebook. Flexible but fragile. Needs EMA, dead code revival, commitment loss.
FSQ: Implicit codebook via rounding. Simple, stable, no collapse. Slightly less flexible.
Check: How does FSQ avoid codebook collapse?

Chapter 8: Image / Video Tokenizers

Modern generative models don't operate on raw pixels. They first tokenize images into compact representations using a VAE or VQ-VAE, then model the distribution of those representations using a transformer or diffusion model. This is the architecture behind Stable Diffusion, DALL-E, and MAGVIT.

The tokenizer compresses a 256×256 image to, say, a 32×32 grid of codebook indices. That's 1,024 tokens instead of 196,608 pixel values — a 192× compression. A transformer can then model these tokens autoregressively, just like words.
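Counting bits instead of values tells a slightly different story, since each token is an index rather than a byte. A back-of-envelope sketch, assuming 8-bit pixel channels and a codebook of K = 8192 entries (13 bits per index).

```python
import math

pixels = 256 * 256 * 3               # values in the image
pixel_bits = pixels * 8              # assuming 8-bit channels
tokens = 32 * 32                     # grid of codebook indices
K = 8192                             # assumed codebook size
token_bits = tokens * math.log2(K)   # 13 bits per index

print(pixels / tokens)               # 192.0 (values per token)
print(pixel_bits / token_bits)       # about 118 (ratio in bits)
```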

Each system uses a slightly different flavor. Stable Diffusion uses a continuous latent space (KL-regularized, not VQ), so its "tokens" are 4-channel feature maps at 64×64 for a 512×512 input. DALL-E 1 used a discrete VAE (dVAE) with an 8,192-entry codebook, trained with a Gumbel-softmax relaxation rather than VQ-VAE's straight-through estimator: each of its 32×32 spatial positions becomes a single integer index, giving 1,024 tokens that a GPT-style transformer can generate one by one. MAGVIT-2 pushes the codebook size to 2^18 = 262,144 entries using lookup-free quantization, reporting strong reconstruction quality.

Image 256×256
196,608 pixel values
↓ VQ-VAE Encoder
Token Grid 32×32
1,024 codebook indices
↓ Transformer / Diffusion
Generate New Tokens
Model the token distribution
↓ VQ-VAE Decoder
New Image
Decode tokens back to pixels
System | Tokenizer | Generator | Year
DALL-E 1 | dVAE (discrete VAE) | Autoregressive transformer | 2021
Stable Diffusion | KL-regularized AE | Latent diffusion model | 2022
MAGVIT-2 | LFQ (lookup-free quantization) | Masked transformer | 2023
Cosmos | Causal video tokenizer (continuous and discrete variants) | Autoregressive + diffusion | 2024
Key insight: Tokenizer quality is the ceiling for generation quality. If the tokenizer can't reconstruct fine details, no amount of transformer magic can bring them back. This is why teams invest heavily in tokenizer design.

Think of the tokenizer as a lossy camera: everything downstream sees the world through this lens. If the tokenizer drops fine text, subtle textures, or thin structures, no generative model can hallucinate them back faithfully. This is why MAGVIT-2 pushed codebook size to 262K entries — more codes = finer-grained visual vocabulary = less information lost at the tokenization boundary.

For video, the challenge compounds: you need temporal consistency between frames. A causal video tokenizer such as Cosmos's processes frames in sequence, so the latent representation of frame t depends only on frames 0…t and never on future frames. Tokenizing each frame independently ignores this temporal structure, and the decoded frames tend to flicker. The spatial compression might be 8× per axis (256→32), and the temporal compression 4× (32 frames → 8 latent frames), giving a combined compression of 8×8×4 = 256× in the number of spatiotemporal positions.
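Causality in time is commonly enforced by padding convolutions only on the past side. The sketch below assumes PyTorch and uses a hypothetical `CausalConv3d` module; it is not Cosmos's actual code. The output at frame t sees only input frames up to t. The usage below checks this by editing future frames and confirming that earlier outputs do not change.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class CausalConv3d(nn.Module):
    """3D convolution whose output at frame t depends only on input frames 0..t."""

    def __init__(self, in_ch: int, out_ch: int, k_t: int = 3, k_s: int = 3):
        super().__init__()
        self.pad_t = k_t - 1   # all temporal padding goes on the past side
        self.pad_s = k_s // 2  # ordinary "same" padding in space
        self.conv = nn.Conv3d(in_ch, out_ch, kernel_size=(k_t, k_s, k_s))

    def forward(self, x: torch.Tensor) -> torch.Tensor:  # x: [B, C, T, H, W]
        # F.pad pads from the last dim backwards: (W_l, W_r, H_top, H_bottom, T_front, T_back)
        x = F.pad(x, (self.pad_s, self.pad_s, self.pad_s, self.pad_s, self.pad_t, 0))
        return self.conv(x)


torch.manual_seed(0)
layer = CausalConv3d(3, 8)
video = torch.randn(1, 3, 8, 16, 16)              # 8 frames of 16x16 RGB
edited = video.clone()
edited[:, :, 5:] = torch.randn(1, 3, 3, 16, 16)   # change only frames 5..7

with torch.no_grad():
    a, b = layer(video), layer(edited)
print(torch.allclose(a[:, :, :5], b[:, :, :5]))   # True: frames 0..4 are unaffected
```

Stacking such layers, with temporal striding in some of them, keeps the whole encoder causal while it compresses time.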

🏗 Design Challenge. You're the Architect: Design an Image Tokenizer for Latent Diffusion
You're building the tokenizer for a Stable-Diffusion-class system. The diffusion model will operate in YOUR latent space. Your choices directly determine the quality ceiling, inference speed, and GPU memory cost of the entire generation pipeline. The target: 256×256 input images, consumer GPU inference (8GB VRAM), high-fidelity reconstruction (PSNR > 30 dB).
Input resolution: 256×256×3 RGB
Target VRAM (inference): ≤ 8 GB total (tokenizer + diffusion model)
Diffusion model: ~400M-parameter U-Net, runs 50 steps in latent space
Reconstruction target: PSNR > 30 dB, no visible block artifacts
Tokenizer budget: ≤ 100M parameters for encoder + decoder
1. Spatial downsampling factor: 4×, 8×, or 16×? (256→64, 256→32, or 256→16?) Each extra 2× of downsampling cuts the number of latent positions, and roughly the diffusion FLOPs, by 4×, but loses spatial detail.
2. Latent channels: 4, 8, or 16? More channels = more information per spatial location, but the diffusion model sees a larger "image" (channels act like extra pixels).
3. Continuous (KL-regularized) or discrete (VQ) latent space? The diffusion model needs to generate latents — which is easier to diffuse into?
4. If discrete: codebook size K? If continuous: how strong should KL regularization be?
5. Loss function: MSE only, or add perceptual + adversarial? What's the tradeoff in training cost vs. quality?
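The arithmetic behind questions 1 and 2 is easy to tabulate. This small plain-Python sketch assumes a convolutional autoencoder that downsamples both axes by the same factor; `latent_stats` is a hypothetical helper. For each option it prints the latent shape, the value count, and the compression ratio. It also prints the number of latent positions relative to 8× downsampling, as a rough proxy for per-step diffusion cost.

```python
def latent_stats(height: int, width: int, factor: int, channels: int, in_ch: int = 3):
    """Latent shape, latent value count, and compression ratio for a conv autoencoder."""
    lh, lw = height // factor, width // factor
    n_latent = channels * lh * lw
    ratio = (in_ch * height * width) / n_latent
    return (channels, lh, lw), n_latent, ratio


base_positions = (256 // 8) ** 2  # positions at 8x downsampling, used as the cost reference
for factor in (4, 8, 16):
    for channels in (4, 8, 16):
        shape, n, ratio = latent_stats(256, 256, factor, channels)
        rel_positions = (256 // factor) ** 2 / base_positions
        print(f"{factor:>2}x down, {channels:>2} ch: latent {shape}, "
              f"{n:>6} values, {ratio:5.1f}x compression, "
              f"{rel_positions:.2f}x positions vs 8x")
```

Note that 8× with 4 channels and 16× with 16 channels both give 4,096 values (48×). They differ in how that budget is split between spatial resolution and per-position detail.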

Stable Diffusion 1.x/2.x tokenizer choices:

1. Spatial: 8× downsampling (256→32 or 512→64). Why: 4× keeps too much spatial redundancy (the diffusion model sees 4× more positions than at 8×, so each step is roughly 4× slower). 16× loses fine detail (faces become blobs). 8× is the sweet spot: enough detail for sharp generation without bankrupting the diffusion budget.

2. Channels: 4. This gives [B, 4, 32, 32] = 4,096 latent values per image from 196,608 input values (256×256×3). Each latent position summarizes an 8×8 pixel patch, i.e. 8×8×3 = 192 RGB values, with just 4 numbers: an aggressive 48× compression.

3. Continuous (KL-regularized). Diffusion naturally handles continuous Gaussians. A discrete codebook would require dequantization noise or a separate discrete diffusion formulation. Continuous is simpler and works better for the Gaussian diffusion framework.

4. KL weight: ~1e-6 (nearly zero). Just enough to keep latents vaguely Gaussian-distributed so the diffusion model's N(0,1) noise schedule makes sense. The focus is reconstruction quality, not latent regularity.

5. L1 + perceptual (VGG) + adversarial (PatchGAN). MSE alone = blurry. Perceptual loss preserves structure. Adversarial loss adds sharpness. Each training step costs more (extra discriminator and VGG passes), but the quality gain is dramatic.

Design lesson: The tokenizer is optimized for the DOWNSTREAM task. Stable Diffusion's VAE is barely a VAE — the KL is negligible. It's really a perceptual autoencoder with just enough Gaussian structure for diffusion to work.

🔗 Pattern Recognition
Latent Space as Shared Infrastructure
This Lesson (VAE/VQ-VAE)
Encoder compresses pixels → latent z
Decoder reconstructs pixels from z
Train once, share with downstream models
Diffusion Models
Diffusion operates IN the VAE's latent space
Noise → denoise in z-space, not pixel-space
~48× fewer values per denoising step

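The VAE creates a "compression layer" that multiple generative models can share. Stable Diffusion 1.x and 2.x share the same VAE latent space. SDXL keeps the same 4-channel architecture with retrained weights. Across these releases, the generation backbone changed far more than the autoencoder did. Newer systems such as SD3 and Flux train their own 16-channel autoencoders, but each is still trained once and reused across many generator checkpoints. This is amortized infrastructure: train the VAE once (~84M params, cheap), then reuse it downstream. The VAE latent space is the "pixel format" of modern generation.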

Can you think of other systems where an expensive compression step is done once and then multiple downstream tasks operate in the compressed space? (Hint: text embeddings from a frozen encoder, used by retrieval + generation + classification.)

Chapter 9: VAEs in the Wild

VAEs and VQ-VAEs aren't just academic exercises. They're load-bearing infrastructure in the biggest generative AI systems. Stable Diffusion's latent space? A VAE. DALL-E 1's image tokens? A discrete VAE. Most video models? Some flavor of spatiotemporal autoencoder, either a continuous video VAE or a temporal VQ-style tokenizer.

Diffusion models operate in VAE latent space. The VAE compresses 512×512×3 to 64×64×4, making diffusion computationally tractable.
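Flow matching models (Stable Diffusion 3, Flux) also operate in a VAE latent space, though each trains its own 16-channel autoencoder instead of reusing SD 1.x's 4-channel one. Same tokenizer idea, different generative backbone.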

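Why does this 48× compression matter so much? Diffusion models run the U-Net (or transformer) tens to hundreds of times per image. On raw 512×512×3 pixels (786K values), each pass would cover 64× more spatial positions than the 64×64×4 latent (16K values). Convolution cost grows roughly in proportion to the number of positions, and self-attention cost, which scales quadratically with positions, grows far more. At generation time, diffusion runs ~50 steps in the tiny latent space and the VAE decoder runs once at the end; the encoder is only needed for training or when starting from an existing image. It's the VAE that makes latent diffusion feasible on consumer GPUs.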
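The control flow of latent-diffusion sampling can be sketched with toy stand-ins, assuming PyTorch. The modules here are hypothetical, untrained placeholders chosen only to reproduce SD's tensor shapes (8× downsampling, 4 channels). The update rule is likewise a placeholder, not a real sampler. The point is to show where the expensive loop runs.

```python
import torch
import torch.nn as nn

# Hypothetical, untrained stand-ins that only mimic SD's shapes.
decoder = nn.ConvTranspose2d(4, 3, kernel_size=8, stride=8)  # [B,4,64,64] -> [B,3,512,512]
denoiser = nn.Conv2d(4, 4, kernel_size=3, padding=1)         # placeholder for the U-Net / DiT


@torch.no_grad()
def generate(num_steps: int = 50) -> torch.Tensor:
    z = torch.randn(1, 4, 64, 64)          # start from noise in latent space (16,384 values)
    for _ in range(num_steps):             # every denoising step runs on the small latent
        z = z - 0.02 * denoiser(z)         # placeholder update, not a real noise schedule
    return decoder(z)                      # decode to pixels exactly once


image = generate()
print(image.shape)  # torch.Size([1, 3, 512, 512])
```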

Stable Diffusion's VAE is trained separately from the diffusion model. It uses a KL penalty (like a standard VAE) but tuned very lightly — the KL weight is small (around 1e-6), so the latent space is barely regularized. The focus is reconstruction quality: the VAE must encode and decode images with minimal perceptual loss, because every artifact in the VAE gets inherited by every generated image.
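The sketch below shows how weak a KL weight of ~1e-6 is. It assumes PyTorch and a diagonal-Gaussian encoder posterior; the helper names are hypothetical. Take a posterior far from N(0, 1), with every mean at 2 and every variance at exp(-6), on a [1, 4, 64, 64] latent. Its unweighted KL is over 73,000, yet after the 1e-6 weight it adds only about 0.074 to the loss.

```python
import torch


def sample_posterior(mu: torch.Tensor, log_var: torch.Tensor) -> torch.Tensor:
    """Reparameterization trick: z = mu + sigma * eps, eps ~ N(0, I)."""
    return mu + torch.exp(0.5 * log_var) * torch.randn_like(mu)


def kl_to_standard_normal(mu: torch.Tensor, log_var: torch.Tensor) -> torch.Tensor:
    """Closed-form KL(N(mu, sigma^2) || N(0, I)), summed over latent dims, per sample."""
    return 0.5 * (mu.pow(2) + log_var.exp() - 1.0 - log_var).flatten(1).sum(dim=1)


# A posterior far from N(0, 1): every mean is 2, every variance is exp(-6).
mu = torch.full((1, 4, 64, 64), 2.0)
log_var = torch.full((1, 4, 64, 64), -6.0)
z = sample_posterior(mu, log_var)          # latent fed to the decoder during training
kl = kl_to_standard_normal(mu, log_var)
print(kl.item())          # roughly 73,748
print(1e-6 * kl.item())   # roughly 0.074
```

With so little pressure toward the prior, the encoder is free to optimize almost purely for reconstruction, which is the design intent described above.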

The training recipe for Stable Diffusion's VAE includes three loss components: (1) L1 pixel reconstruction, (2) perceptual loss using VGG features (compares high-level structure, not pixel values), and (3) adversarial loss from a PatchGAN discriminator (penalizes blurriness locally). The KL term is almost an afterthought — just enough to keep the latent distribution vaguely Gaussian so the diffusion model has a reasonable starting distribution.

python
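# Stable Diffusion VAE loss (simplified). perceptual_loss, adversarial_loss and
# kl_loss are assumed helpers: VGG/LPIPS features, PatchGAN generator term, Gaussian KL.
import torch.nn.functional as F

loss = (1.0  * F.l1_loss(x_hat, x)          # pixel reconstruction
      + 1.0  * perceptual_loss(x, x_hat)    # VGG feature matching
      + 0.5  * adversarial_loss(x_hat)      # PatchGAN discriminator
      + 1e-6 * kl_loss(mu, log_var))        # tiny KL regularization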
VAE Family Tree

How the ideas connect. The original autoencoder begat a family of models that now power many of today's major generative AI systems.

Application | VAE Variant | Role
Stable Diffusion | KL-VAE | Compress images to/from latent space
DALL-E 1 | dVAE | Convert images to discrete tokens
Sora | Spatial-temporal VAE | Tokenize video frames + motion
AudioLM | SoundStream (residual-VQ codec) | Tokenize audio waveforms
Drug discovery | Molecular VAE | Smooth latent space for molecule optimization
"The art of compression is the art of understanding."
— paraphrase of Kolmogorov

You now understand latent spaces, variational inference, vector quantization, and how they power modern AI. Every generated image you see started as a latent code.