The compression engines behind image generation. Learn how neural networks encode, quantize, and reconstruct the visual world.
A 256×256 colour image has 196,608 numbers. But most of those numbers are redundant: smooth patches, repeated textures, predictable edges. The real information in an image — its content, composition, style — can be captured by far fewer numbers. A latent space is where those compressed essentials live.
The encoder-decoder idea is simple: an encoder squeezes the input into a small latent vector z, and a decoder tries to reconstruct the original from z alone. If the reconstruction is good, z must contain the essence of the data.
What does this look like in a real system? Stable Diffusion's VAE takes a 512×512 RGB image (a tensor of shape [1, 3, 512, 512]) and compresses it to a latent of shape [1, 4, 64, 64]. That's 786,432 input values squeezed into 16,384 latent values: a 48× compression. Every 512×512 image generated by Stable Diffusion started life as a tiny 64×64×4 grid.
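The shape bookkeeping can be checked with a toy stand-in. The sketch below is NOT Stable Diffusion's actual VAE (which uses residual blocks and attention); it is a hypothetical three-layer conv stack that halves the resolution three times (an 8× downsampling factor) and outputs 4 channels, just to confirm the [1, 4, 64, 64] latent shape and the 48× ratio.

```python
import torch
import torch.nn as nn

# Hypothetical toy stand-in for an f=8 encoder (NOT Stable Diffusion's real VAE)
toy_encoder = nn.Sequential(
    nn.Conv2d(3, 32, 3, stride=2, padding=1), nn.SiLU(),   # 512 -> 256
    nn.Conv2d(32, 64, 3, stride=2, padding=1), nn.SiLU(),  # 256 -> 128
    nn.Conv2d(64, 4, 3, stride=2, padding=1),              # 128 -> 64, 4 latent channels
)

x = torch.randn(1, 3, 512, 512)   # stand-in for a 512x512 RGB image
with torch.no_grad():
    z = toy_encoder(x)

print(z.shape)                # torch.Size([1, 4, 64, 64])
print(x.numel() / z.numel())  # 48.0
```

Only the downsampling factor and channel count determine the ratio; the real VAE's extra depth changes quality, not the shape arithmetic.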
A plain autoencoder trains two neural networks end-to-end: the encoder f and decoder g. The loss is simply the reconstruction error: how different is g(f(x)) from x? The bottleneck — a narrow latent layer — forces compression.
It works. But there's a problem: the latent space is messy. Points are scattered unpredictably. If you pick a random point in latent space and decode it, you get garbage. The autoencoder memorizes efficient codes but doesn't organize them.
Concretely, the encoder is a stack of convolutional layers that progressively downsample. An image [batch, 3, 256, 256] passes through conv → downsample → conv → downsample until the spatial dimensions shrink and the channel count grows: [batch, 256, 32, 32]. A final linear layer (or global average pooling) flattens this to the latent vector [batch, latent_dim]. The decoder mirrors this process: upsample → conv → upsample until you're back to the original resolution.
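```python
import torch.nn as nn

# Minimal encoder architecture (channel counts and kernel sizes are illustrative)
class Encoder(nn.Module):
    def __init__(self, latent_dim: int = 128):
        super().__init__()
        # Each stride-2 conv halves the spatial size and grows the channel count
        self.conv1 = nn.Sequential(nn.Conv2d(3, 64, 4, stride=2, padding=1), nn.ReLU())
        self.conv2 = nn.Sequential(nn.Conv2d(64, 128, 4, stride=2, padding=1), nn.ReLU())
        self.conv3 = nn.Sequential(nn.Conv2d(128, 256, 4, stride=2, padding=1), nn.ReLU())
        self.fc = nn.Linear(256 * 32 * 32, latent_dim)

    def forward(self, x):   # x: [B, 3, 256, 256]
        x = self.conv1(x)   # [B, 64, 128, 128]
        x = self.conv2(x)   # [B, 128, 64, 64]
        x = self.conv3(x)   # [B, 256, 32, 32]
        x = x.flatten(1)    # [B, 256*32*32]
        return self.fc(x)   # [B, latent_dim]
```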
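The decoder described above can be sketched the same way. This is a minimal mirror of the encoder, assuming transposed convolutions for upsampling and a sigmoid output for images scaled to [0, 1]; the class name Decoder, the helpers autoencoder_loss and dummy_encoder, and all layer sizes are illustrative choices, not a specific published architecture.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class Decoder(nn.Module):
    """Hypothetical mirror of the encoder: [B, latent_dim] -> [B, 3, 256, 256]."""
    def __init__(self, latent_dim: int = 128):
        super().__init__()
        self.fc = nn.Linear(latent_dim, 256 * 32 * 32)
        self.up = nn.Sequential(
            # Each stride-2 transposed conv doubles the spatial size
            nn.ConvTranspose2d(256, 128, 4, stride=2, padding=1), nn.ReLU(),  # -> [B, 128, 64, 64]
            nn.ConvTranspose2d(128, 64, 4, stride=2, padding=1), nn.ReLU(),   # -> [B, 64, 128, 128]
            nn.ConvTranspose2d(64, 3, 4, stride=2, padding=1),                # -> [B, 3, 256, 256]
        )

    def forward(self, z):
        x = self.fc(z).view(-1, 256, 32, 32)   # un-flatten to a feature map
        return torch.sigmoid(self.up(x))       # assumes pixel values in [0, 1]

def autoencoder_loss(encoder, decoder, x):
    """Plain autoencoder objective: reconstruction error between g(f(x)) and x."""
    x_hat = decoder(encoder(x))
    return F.mse_loss(x_hat, x)

def dummy_encoder(t):
    return torch.zeros(t.shape[0], 16)   # stand-in; a trained Encoder would go here

decoder = Decoder(latent_dim=16)
x = torch.rand(2, 3, 256, 256)           # fake batch of images in [0, 1]
x_hat = decoder(dummy_encoder(x))
loss = autoencoder_loss(dummy_encoder, decoder, x)
```

Nothing in this loss asks the latent codes to be organized; it only scores reconstructions, which is exactly the weakness discussed next.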
To make this concrete: train an autoencoder on face images. Encode two faces and get z1 and z2. Decode the midpoint (z1 + z2) / 2 — you might get a coherent blend, or you might get static noise. The autoencoder never promised that midpoints would be meaningful. It only learned to reconstruct, not to organize. The latent space is full of "dead zones" where decoded output is garbage.
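The midpoint experiment generalizes to a straight-line walk between two codes. A minimal sketch, assuming z1 and z2 are latent vectors from some trained encoder (random stand-ins here; the helper name is hypothetical). Each row of the result would then be decoded to see whether the path stays on meaningful images or crosses a dead zone.

```python
import torch

def interpolate_latents(z1: torch.Tensor, z2: torch.Tensor, steps: int = 5) -> torch.Tensor:
    """Evenly spaced points on the straight line from z1 to z2 (inclusive), one per row."""
    t = torch.linspace(0, 1, steps).unsqueeze(1)   # [steps, 1]
    return (1 - t) * z1 + t * z2                    # [steps, latent_dim]

# Hypothetical usage: z1, z2 would come from encoding two faces
z1, z2 = torch.randn(64), torch.randn(64)
path = interpolate_latents(z1, z2, steps=3)   # rows: z1, (z1 + z2) / 2, z2
# decoded = decoder(path)  # decoder: whatever trained model you are probing
```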
The VAE's fix is elegant: force every encoded point to look like it came from a standard Gaussian. Then sampling is trivial — just draw z ~ N(0,1) and decode. But this constraint comes at a cost: the encoder can't use arbitrary regions of latent space anymore, which limits reconstruction fidelity. This tension defines the entire field.
The key insight of the VAE: instead of encoding x to a single point z, encode it to a distribution — specifically, a Gaussian with mean μ and variance σ². Then sample z from that distribution. This forces the latent space to be smooth.
In code, the encoder doesn't output one vector. It outputs two: μ (the mean) and log σ² (the log-variance). Why log-variance instead of variance? Because variance must be positive, and log-variance is unconstrained — the network can output any real number. To get the standard deviation: std = exp(0.5 * log_var).
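```python
import torch
import torch.nn as nn

class GaussianHead(nn.Module):
    """Illustrative module: the encoder outputs TWO heads, not one. h: [B, hidden] flattened backbone features."""
    def __init__(self, hidden: int, latent_dim: int):
        super().__init__()
        self.fc_mu = nn.Linear(hidden, latent_dim)
        self.fc_logvar = nn.Linear(hidden, latent_dim)

    def forward(self, h):
        mu = self.fc_mu(h)               # [B, latent_dim]
        log_var = self.fc_logvar(h)      # [B, latent_dim]
        # Reparameterization trick
        std = torch.exp(0.5 * log_var)   # [B, latent_dim]
        eps = torch.randn_like(std)      # [B, latent_dim] ~ N(0, 1)
        z = mu + std * eps               # [B, latent_dim], differentiable w.r.t. mu and log_var
        return z, mu, log_var
```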
But wait: sampling is not differentiable. How do we backpropagate through a random number? The reparameterization trick: instead of z ~ N(μ, σ²), write z = μ + σ · ε where ε ~ N(0,1). Now the randomness (ε) is external, and gradients flow through μ and σ.
Why does this matter so much? If we sampled z directly from N(μ, σ²), the sampling operation has no gradient with respect to μ or σ. Backprop hits a wall. But z = μ + σ · ε is a deterministic function of μ, σ, and the random noise ε. Gradients flow through μ and σ just fine — ε is treated as a constant input. The stochasticity is still there (each forward pass samples a different ε), but the computation graph is smooth.
Let's trace the gradient path to make this visceral. During backprop, we need ∂L/∂μ and ∂L/∂log_var. With reparameterization, ∂z/∂μ = 1 and ∂z/∂σ = ε; since σ = exp(½ · log_var), ∂σ/∂log_var = σ/2, so ∂z/∂log_var = ε · σ/2. All of these are simple, well-defined numbers. Without reparameterization, z is drawn from N(μ, σ²) and ∂z/∂μ doesn't exist: sampling is a black box. The trick turns a stochastic bottleneck into a deterministic computation with a random input.
We need ∂L/∂μ and ∂L/∂σ to train the encoder. With z sampled from N(μ, σ²), the computation graph has a stochastic node. With z = μ + σ·ε, it doesn't.
Your task: Show that ∂z/∂μ and ∂z/∂σ are well-defined under the reparameterization, and explain why the "naive" sampling has no gradient.
Full derivation:
Let L be the loss computed from the decoder output. We need ∂L/∂μ via the chain rule:
∂L/∂μ = (∂L/∂z) · (∂z/∂μ)
With reparameterization z = μ + σ·ε:
∂z/∂μ = 1, ∂z/∂σ = ε
So: ∂L/∂μ = ∂L/∂z and ∂L/∂σ = ε · ∂L/∂z
Without reparameterization, z = sample(N(μ, σ²)). To autograd, the "sample" operation is a black box: it returns a fresh random draw rather than a deterministic function of μ, so there is no function z(μ) to differentiate. The gradient ∂z/∂μ is undefined.
The key insight: The trick moves the source of randomness outside the computation graph. ε enters as a "data input," not as an operation. The graph becomes: ε (input) → multiply by σ → add μ → z. Every arrow is differentiable.
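The derivation can be checked numerically with PyTorch autograd. A minimal sketch with scalar μ and log σ² (the values 0.5 and −1.0 are arbitrary): the reparameterized z gives ∂z/∂μ = 1 and ∂z/∂log_var = ε · σ/2, while torch.distributions.Normal shows both sides of the argument, since its sample() returns a tensor cut off from μ and σ and its rsample() uses the reparameterized form.

```python
import torch
from torch.distributions import Normal

torch.manual_seed(0)
mu = torch.tensor(0.5, requires_grad=True)
log_var = torch.tensor(-1.0, requires_grad=True)

# Reparameterized sample: eps is an external input, treated as a constant by autograd
std = torch.exp(0.5 * log_var)
eps = torch.randn(())
z = mu + std * eps
z.backward()

expected_dlogvar = (eps * std / 2).item()   # dz/dsigma * dsigma/dlog_var = eps * sigma / 2
print(mu.grad.item(), log_var.grad.item(), expected_dlogvar)

# Naive sampling vs. PyTorch's built-in reparameterized sampling
q = Normal(mu, torch.exp(0.5 * log_var))
naive = q.sample()      # drawn without a computation graph
reparam = q.rsample()   # computed as loc + eps * scale
print(naive.requires_grad, reparam.requires_grad)   # False True
```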
We want to maximize the likelihood of our data: p(x). But computing p(x) directly requires integrating over all possible latent codes z, which is intractable. The integral p(x) = ∫ p(x|z) p(z) dz sums over every possible latent code. With a 64-dimensional z, that's integration over a 64-dimensional space — computationally hopeless.
Instead, we optimize a lower bound on log p(x): the Evidence Lower Bound (ELBO). The gap between log p(x) and the ELBO is exactly KL(q(z|x) || p(z|x)) — the KL divergence between our approximate posterior and the true posterior. Since KL is always ≥ 0, the ELBO is always ≤ log p(x). Maximizing the ELBO simultaneously tightens the bound and improves the model.
The ELBO has two terms: a reconstruction term (how well can we decode z back to x?) and a KL divergence term (how close is our learned distribution q(z|x) to the prior p(z) = N(0,1)?). The reconstruction term wants good decoding; the KL term wants an organized latent space.
Let's make this concrete with numbers. Suppose for one image, the encoder outputs μ = 3.0, σ = 0.2 for some latent dimension. The KL term for that dimension is ½(3.0² + 0.04 − log 0.04 − 1) = ½(9 + 0.04 + 3.22 − 1) = 5.63. That's a big penalty — the KL is screaming "your mean is way off from zero!" If instead μ = 0.1, σ = 0.9, the KL is ½(0.01 + 0.81 + 0.21 − 1) = 0.015. Almost free. The reconstruction term pushes back: "but I need μ = 3.0 to encode this cat!" The balance between these two forces IS the VAE.
We want to maximize log p(x) but can't compute it directly. We introduce an approximate posterior q(z|x) and derive a tractable lower bound.
Your task: Starting from log p(x) = log ∫ p(x,z) dz, introduce q(z|x) and derive: log p(x) ≥ E_q[log p(x|z)] − KL(q(z|x) || p(z)).
Full derivation:
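$$
\begin{aligned}
\log p(x) &= \log \int p(x,z)\,dz = \log \int p(x,z)\,\frac{q(z\mid x)}{q(z\mid x)}\,dz && \text{(Step 1)}\\
&= \log \mathbb{E}_{q(z\mid x)}\!\left[\frac{p(x,z)}{q(z\mid x)}\right] && \text{(Step 2)}\\
&\ge \mathbb{E}_{q(z\mid x)}\!\left[\log \frac{p(x,z)}{q(z\mid x)}\right] && \text{(Step 3, Jensen's inequality: log is concave)}\\
&= \mathbb{E}_{q}\big[\log p(x\mid z) + \log p(z) - \log q(z\mid x)\big] && \text{(Step 4)}\\
&= \mathbb{E}_{q}\big[\log p(x\mid z)\big] - \mathbb{E}_{q}\big[\log q(z\mid x) - \log p(z)\big] && \text{(Step 5)}\\
&= \mathbb{E}_{q}\big[\log p(x\mid z)\big] - \mathrm{KL}\big(q(z\mid x)\,\|\,p(z)\big) && \text{(Step 6)}
\end{aligned}
$$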
The gap between log p(x) and the ELBO is exactly KL(q(z|x) || p(z|x)) — the divergence between our approximate posterior and the TRUE posterior. When q perfectly approximates p(z|x), the bound is tight.
The key insight: We can't compute p(z|x) directly (by Bayes' rule it equals p(x|z)p(z)/p(x), and p(x) is exactly the intractable quantity), so we optimize a lower bound instead. Because log p(x) = ELBO + KL(q(z|x) || p(z|x)), maximizing the ELBO either raises log p(x) (a better model) or shrinks the gap (q moves closer to the true posterior), and in practice does both.
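A toy model where everything is tractable makes the bound concrete. Assume, purely for illustration, a 1-D prior p(z) = N(0, 1) and likelihood p(x|z) = N(z, 1); then the evidence is p(x) = N(0, 2) and the true posterior is p(z|x) = N(x/2, 1/2). The sketch computes the ELBO for a Gaussian q(z|x) = N(m, s²) both in closed form and by Monte Carlo sampling; all function names are hypothetical.

```python
import math
import torch

# Toy 1-D model (an illustrative assumption):
#   prior p(z) = N(0, 1), likelihood p(x|z) = N(z, 1)
#   => evidence p(x) = N(0, 2), true posterior p(z|x) = N(x/2, 1/2)

def log_normal(x, mean, var):
    """log N(x; mean, var) for plain Python floats."""
    return -0.5 * (math.log(2 * math.pi * var) + (x - mean) ** 2 / var)

def elbo_closed_form(x, m, s2):
    """ELBO for q(z|x) = N(m, s2): E_q[log p(x|z)] - KL(q(z|x) || p(z))."""
    recon = -0.5 * math.log(2 * math.pi) - 0.5 * ((x - m) ** 2 + s2)   # uses E_q[(x - z)^2] = (x - m)^2 + s2
    kl = 0.5 * (m ** 2 + s2 - math.log(s2) - 1)
    return recon - kl

def elbo_monte_carlo(x, m, s2, n=200_000, seed=0):
    """Same bound, estimated by sampling z ~ q and averaging log p(x|z) + log p(z) - log q(z|x)."""
    g = torch.Generator().manual_seed(seed)
    z = m + math.sqrt(s2) * torch.randn(n, generator=g, dtype=torch.float64)
    log_p_x_given_z = -0.5 * (math.log(2 * math.pi) + (x - z) ** 2)
    log_p_z = -0.5 * (math.log(2 * math.pi) + z ** 2)
    log_q = -0.5 * (math.log(2 * math.pi * s2) + (z - m) ** 2 / s2)
    return (log_p_x_given_z + log_p_z - log_q).mean().item()

x = 1.0
print("log p(x)           :", log_normal(x, 0.0, 2.0))
print("ELBO, q = posterior:", elbo_closed_form(x, 0.5, 0.5))   # tight: matches log p(x)
print("ELBO, q = prior    :", elbo_closed_form(x, 0.0, 1.0))   # strictly lower
print("MC ELBO, q = prior :", elbo_monte_carlo(x, 0.0, 1.0))
```

With q set to the true posterior, every sampled value of log p(x, z) − log q(z|x) equals log p(x) exactly, so the bound is tight. With q set to the prior, the gap is KL(q || p(z|x)) ≈ 0.40 nats at x = 1.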
The KL term plays the role of the rate: averaged over the data, it upper-bounds how much information the encoder transmits about x through z. The reconstruction term plays the role of the distortion: how much quality we lose. Sweeping β in β-VAE traces out a rate-distortion curve: higher β means lower rate (more compression) at the cost of higher distortion (blurrier images). This is the same kind of tradeoff that JPEG, MP3, and every lossy codec faces.
Where else have you seen "compress information through a bottleneck then reconstruct"? (Hint: attention mechanisms compress T tokens through a T×T matrix, forcing the model to select which information survives.)
Computing log p(x) requires integrating p(x|z)p(z) over ALL possible z values — a high-dimensional integral with no closed form. We can't even evaluate p(x), let alone optimize it.
The ELBO sidesteps this by introducing q(z|x): a learned approximation to the true posterior p(z|x). Instead of integrating over all z, we sample z from q(z|x) and evaluate log p(x|z). This is tractable — it's just running the decoder on one sampled z and comparing to x.
The bound is useful because: (1) log p(x) can never fall below the ELBO, so pushing the ELBO up tends to pull log p(x) up with it, and (2) the gap between them shrinks as q gets better. So one loss function simultaneously trains the encoder (make q better), the decoder (make p(x|z) better), and regularizes the latent space (keep q close to the prior).
In practice, the KL term has a closed-form solution for Gaussians. For each latent dimension j, the KL divergence is: ½(μ_j² + σ_j² − log σ_j² − 1). This is cheap to compute and differentiable. The total KL sums over all latent dimensions.
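One way to trust the closed form is to compare it against a Monte Carlo estimate of KL = E_q[log q(z) − log p(z)]. This sketch (hypothetical helper names, one latent dimension, float64 for stable averages) also reproduces the two single-dimension numbers worked out earlier: about 5.63 for μ = 3.0, σ² = 0.04 and about 0.015 for μ = 0.1, σ² = 0.81.

```python
import math
import torch

def kl_closed_form(mu, var):
    """Per-dimension KL(N(mu, var) || N(0, 1)) = 0.5 * (mu^2 + var - log var - 1)."""
    return 0.5 * (mu ** 2 + var - torch.log(var) - 1)

def kl_monte_carlo(mu, var, n=400_000, seed=0):
    """Estimate E_q[log q(z) - log p(z)] by sampling z ~ q = N(mu, var)."""
    g = torch.Generator().manual_seed(seed)
    z = mu + var.sqrt() * torch.randn(n, generator=g, dtype=torch.float64)
    log_q = -0.5 * (torch.log(2 * math.pi * var) + (z - mu) ** 2 / var)
    log_p = -0.5 * (math.log(2 * math.pi) + z ** 2)
    return (log_q - log_p).mean()

results = {}
for mu_val, var_val in [(3.0, 0.04), (0.1, 0.81)]:
    mu = torch.tensor(mu_val, dtype=torch.float64)
    var = torch.tensor(var_val, dtype=torch.float64)
    results[(mu_val, var_val)] = (kl_closed_form(mu, var).item(), kl_monte_carlo(mu, var).item())
    print(mu_val, var_val, results[(mu_val, var_val)])   # (closed form, Monte Carlo)
```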
The big practical knob is β: a weight on the KL term. With β=1, you get the standard VAE (ELBO). With β>1, you get β-VAE, which encourages more disentanglement at the cost of blurrier reconstructions. With β<1, you get sharper images but a messier latent space.
Why do VAEs produce blurry images? The reconstruction loss is usually MSE: ||x − x̂||². MSE penalizes pixel-level differences, so the decoder hedges its bets — when uncertain, it outputs the average, which is blurry. Modern VAEs fix this with perceptual loss (compare features from a pre-trained VGG network instead of raw pixels) or adversarial loss (add a discriminator that penalizes blurriness). Stable Diffusion's VAE uses both, which is why its reconstructions are sharp.
Let's trace through a concrete example. Suppose latent_dim = 4 and the encoder outputs:
| Dim j | μ_j | σ_j² | KL_j | Interpretation |
|---|---|---|---|---|
| 0 | 0.1 | 0.9 | 0.008 | Near standard normal: almost free |
| 1 | 2.5 | 0.1 | 3.83 | Mean far from zero: big penalty |
| 2 | 0.0 | 0.01 | 1.81 | Variance too small: "cheating" by being too certain |
| 3 | 0.3 | 1.2 | 0.05 | Slightly off: small penalty |
Total KL ≈ 5.70. With β = 1 and a reconstruction loss of, say, 150, the KL is under 4% of the total, so the model mostly focuses on reconstruction. With β = 10, the KL term becomes about 57, roughly 28% of the loss, and the model is pushed much harder toward N(0,1).
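The table's arithmetic is easy to reproduce. This sketch recomputes the closed-form KL for the four (μ_j, σ_j²) pairs in the table and then weights the total by β, using the same illustrative reconstruction loss of 150.

```python
import torch

# (mu_j, sigma_j^2) pairs from the four-dimensional example
mu = torch.tensor([0.1, 2.5, 0.0, 0.3])
var = torch.tensor([0.9, 0.1, 0.01, 1.2])

kl_per_dim = 0.5 * (mu ** 2 + var - torch.log(var) - 1)   # closed-form KL per dimension
total_kl = kl_per_dim.sum().item()

recon_loss = 150.0   # the illustrative reconstruction loss from the text
shares = {}
for beta in (1.0, 10.0):
    kl_term = beta * total_kl
    shares[beta] = kl_term / (recon_loss + kl_term)   # fraction of the total loss
    print(f"beta={beta}: KL term = {kl_term:.2f}, share of total = {shares[beta]:.1%}")

print([round(v, 3) for v in kl_per_dim.tolist()], round(total_kl, 2))
```

The per-dimension values come out at roughly 0.008, 3.83, 1.81, and 0.05, totaling about 5.70; at β = 10 the KL term is a bit over a quarter of the loss.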
| β value | Reconstruction | Latent structure | Use case |
|---|---|---|---|
| β < 1 | Sharp | Messy | When quality matters most |
| β = 1 | Balanced | Good | Standard VAE (ELBO) |
| β > 1 | Blurry | Disentangled | β-VAE for interpretable factors |
Here's the loss for one training step in code. Notice how each loss component maps to a specific part of the model:
```python
import torch
import torch.nn.functional as F

def vae_loss(encoder, decoder, x, beta=1.0):
    # Forward pass (encoder returns the mu and log_var heads)
    mu, log_var = encoder(x)                  # two heads
    std = torch.exp(0.5 * log_var)
    z = mu + std * torch.randn_like(std)      # reparameterization trick
    x_hat = decoder(z)                        # reconstruction

    # Reconstruction loss (trains the decoder, and the encoder through z)
    recon_loss = F.mse_loss(x_hat, x, reduction='sum')
    # KL loss (applies to the encoder's mu and log_var heads)
    kl_loss = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())

    # Total loss
    return recon_loss + beta * kl_loss
```
What if latent codes were discrete instead of continuous? VQ-VAE replaces the Gaussian latent space with a codebook: a table of K learned vectors, each of dimension D. In code, the codebook is just an nn.Embedding(K, D) — a matrix of shape [K, D]. Typical values: K = 8192 entries, D = 256 dimensions.
The encoder outputs a continuous feature map [batch, D, H, W], say [1, 256, 32, 32]. For each of the 32×32 = 1,024 spatial locations, we find the nearest codebook entry and substitute it for the encoder's vector. That's 1,024 independent nearest-neighbor lookups. The decoder only ever sees codebook entries, not the raw encoder output.
The loss has three parts, and each applies to different parameters:
| Loss term | What it does | Updates what |
|---|---|---|
| ‖x − x̂‖² | Reconstruction: decoded image matches input | Encoder + Decoder weights |
| ‖sg[z_e] − e‖² | Codebook: move entries toward encoder outputs | Codebook entries only (sg freezes encoder) |
| β‖z_e − sg[e]‖² | Commitment: keep encoder near chosen entries | Encoder only (sg freezes codebook) |
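Straight-through estimator: z_q = z_e + (z_q - z_e).detach(). This one line makes VQ-VAE trainable.
```python
import torch
import torch.nn as nn

def quantize(z_e: torch.Tensor, codebook: nn.Embedding):
    """VQ-VAE quantization (per spatial location). z_e: [B, D, H, W], e.g. [B, D, 32, 32]."""
    B, D, H, W = z_e.shape
    z_flat = z_e.permute(0, 2, 3, 1).reshape(-1, D)     # [B*H*W, D]
    dists = torch.cdist(z_flat, codebook.weight)        # [B*H*W, K]
    indices = dists.argmin(dim=1)                       # [B*H*W]
    z_q = codebook(indices)                             # [B*H*W, D]
    # Undo the flattening in the same (B, H, W, D) order; reshape_as(z_e) would scramble it
    z_q = z_q.reshape(B, H, W, D).permute(0, 3, 1, 2)   # [B, D, H, W]
    # Straight-through: forward = discrete, backward = continuous
    z_q_st = z_e + (z_q - z_e).detach()
    return z_q_st, indices
```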
The VQ-VAE loss is: L = ||x − x̂||² + ||sg[z_e] − e||² + β||z_e − sg[e]||². The sg[] (stop-gradient) operator is critical: it determines WHICH parameters each term updates.
Your task: Explain why we can't just use L = ||x − x̂||² + ||z_e − e||². What goes wrong without the stop-gradients? Why do we need the commitment loss as a separate term?
The three-term decomposition:
Term 1: ||x − x̂||² (reconstruction). Gradients flow directly to the decoder and, through the straight-through estimator, to the encoder. Updates: encoder + decoder.
Term 2: ||sg[z_e] − e||² (codebook learning). sg[z_e] means "treat z_e as a constant," so only e receives gradients. This moves codebook entries toward the encoder outputs they're matched to. In practice, this term is often replaced by an EMA update, which tends to be more stable. Updates: codebook only.
Term 3: β||z_e − sg[e]||² (commitment). sg[e] means "treat codebook entries as constants," so only z_e receives gradients. This keeps the encoder committed to its chosen entries instead of jumping erratically between them or drifting away from the codebook. Updates: encoder only.
Without separation: a single ||z_e − e||² term pulls encoder and codebook toward each other with equal strength, so they can co-adapt and chase each other instead of settling. The stop-gradients split that pull into two directions with separate weights (β is usually well below 1; the original VQ-VAE paper used 0.25), creating an alternating optimization: the codebook adapts to the encoder (term 2), the encoder stays near the codebook (term 3), and reconstruction drives the overall representation quality (term 1).
The key insight: Stop-gradient is a coordination mechanism. It prevents the "two magnets pulling each other" instability by saying: "only ONE thing moves at a time in each loss term." This is why VQ-VAE training is stable despite the fundamentally non-differentiable quantization step.
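The stop-gradient bookkeeping can be verified directly. This sketch (hypothetical helper vq_losses, flat [N, D] encoder outputs, β = 0.25 as an assumed default) implements sg[·] with .detach() and checks which tensor each term actually sends gradients to.

```python
import torch

def vq_losses(z_e: torch.Tensor, codebook: torch.Tensor, beta: float = 0.25):
    """Codebook and commitment losses with explicit stop-gradients (sg[.] is .detach()).

    z_e: [N, D] encoder outputs; codebook: [K, D] learnable entries. Names are hypothetical.
    """
    indices = torch.cdist(z_e, codebook).argmin(dim=1)                      # nearest entry per row
    e = codebook[indices]                                                    # [N, D]
    codebook_loss = ((z_e.detach() - e) ** 2).sum(dim=1).mean()            # ||sg[z_e] - e||^2: only e moves
    commitment_loss = beta * ((z_e - e.detach()) ** 2).sum(dim=1).mean()   # beta ||z_e - sg[e]||^2: only z_e moves
    z_q_st = z_e + (e - z_e).detach()                                        # straight-through input for the decoder
    return codebook_loss, commitment_loss, z_q_st, indices

torch.manual_seed(0)
z_e = torch.randn(6, 3, requires_grad=True)
codebook = torch.randn(4, 3, requires_grad=True)
cb_loss, commit_loss, z_q_st, idx = vq_losses(z_e, codebook)

cb_loss.backward()
print(z_e.grad is None, codebook.grad is not None)   # True True: codebook term only moves the codebook
codebook.grad = None
commit_loss.backward()
print(codebook.grad is None, z_e.grad is not None)   # True True: commitment term only moves the encoder
```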
```python
import torch

def vq_straight_through(z_e, codebook):
    """z_e: [B, D] encoder outputs; codebook: [K, D] tensor of entries."""
    # Step 1: compute distances [B, K]
    dists = torch.cdist(z_e, codebook)   # Euclidean distance
    # Step 2: find nearest codebook entry
    indices = dists.argmin(dim=1)        # [B]
    # Step 3: look up codebook vectors
    z_q = codebook[indices]              # [B, D]
    # Step 4: straight-through estimator
    # Forward: value is z_q (discrete codebook vector)
    # Backward: gradient goes to z_e (as if no quantization)
    z_q_st = z_e + (z_q - z_e).detach()
    return z_q_st, indices
```
The codebook is only useful if all entries are active. A common failure: codebook collapse — the encoder only uses a handful of entries while the rest go "dead." This wastes representational capacity.
Exponential Moving Average (EMA) updates are a popular fix. Instead of updating codebook entries with gradient descent, track the running average of all encoder outputs assigned to each entry. This tends to be faster and more stable.
In practice, a decay rate of γ = 0.99 works well. For a codebook of K = 8192, D = 256, the EMA update can process all assigned vectors in a batch with one vectorized operation; the loop below is written for readability. In this simplified form, if entry k was assigned to 50 encoder outputs this batch, its new value is 0.99 × old + 0.01 × mean(those 50 vectors). Entries assigned to zero vectors don't update at all: this is exactly how they die.
```python
import torch

@torch.no_grad()
def ema_update(codebook, z_flat, indices, decay=0.99):
    # Simplified EMA codebook update, written as a loop for readability
    # codebook: [K, D]; z_flat: [N, D] encoder outputs; indices: [N] assigned entries
    for k in range(codebook.shape[0]):
        mask = (indices == k)            # which encoder outputs chose this entry
        if mask.sum() == 0:              # dead code! no one chose it
            continue
        avg = z_flat[mask].mean(dim=0)   # mean of assigned vectors
        codebook[k] = decay * codebook[k] + (1 - decay) * avg
```
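A vectorized version of the same simplified rule (0.99 × old + 0.01 × batch mean) is sketched below, assuming indices holds each encoder output's assigned entry. torch.bincount counts assignments per entry and index_add_ sums the assigned vectors, so the per-entry mean needs no Python loop. The original VQ-VAE paper's EMA variant instead keeps separate running averages of counts and sums; this sketch sticks to the simpler per-batch-mean rule, and the function name is hypothetical.

```python
import torch

@torch.no_grad()
def ema_update_vectorized(codebook, z_flat, indices, decay=0.99):
    """codebook: [K, D] (updated in place); z_flat: [N, D]; indices: [N]. Returns per-entry counts."""
    K, D = codebook.shape
    counts = torch.bincount(indices, minlength=K).to(z_flat.dtype)      # [K] assignments per entry
    sums = torch.zeros(K, D, dtype=z_flat.dtype, device=z_flat.device)
    sums.index_add_(0, indices, z_flat)                                 # [K, D] sum of assigned vectors
    used = counts > 0                                                   # dead entries stay put
    means = sums[used] / counts[used].unsqueeze(1)                      # mean of assigned vectors
    codebook[used] = decay * codebook[used] + (1 - decay) * means
    return counts

codebook = torch.zeros(3, 2)
z_flat = torch.tensor([[1.0, 1.0], [3.0, 3.0], [2.0, 0.0]])
indices = torch.tensor([0, 0, 2])   # entry 1 gets no assignments this batch
counts = ema_update_vectorized(codebook, z_flat, indices)
print(counts.tolist(), codebook.tolist())
```

The returned counts are also exactly what a dead-code detector needs to track usage over time.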
| Strategy | How it works |
|---|---|
| EMA updates | Running average of assigned vectors; no gradient needed for codebook |
| Dead code revival | Replace unused entries with randomly sampled encoder outputs |
| Codebook reset | Periodically re-initialize low-usage entries from data (k-means style) |
| Larger codebook | More entries = finer granularity, but harder to keep all alive |
Dead code revival works by tracking usage counts per codebook entry. When an entry's count drops below a threshold (say, 1 assignment in the last 100 batches), it's declared dead. The revival strategy: replace the dead entry with the encoder output of a randomly chosen training example. This teleports the dead entry into a region of latent space that the encoder actually uses, giving it a fresh chance to attract assignments. After revival, the EMA update takes over and refines its position.
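The revival step can be sketched in a few lines. Assumptions: usage holds each entry's assignment count over a recent window (such as the last 100 batches), the threshold of 1 follows the example above, and replacements are drawn from the current batch of encoder outputs rather than from a separate pass over the training set; the function name is hypothetical.

```python
import torch

@torch.no_grad()
def revive_dead_codes(codebook, usage, z_flat, threshold=1.0, generator=None):
    """Replace rarely used entries with randomly chosen encoder outputs.

    codebook: [K, D] (updated in place); usage: [K] recent assignment counts;
    z_flat: [N, D] encoder outputs from the current batch. Returns the dead-entry mask.
    """
    dead = usage < threshold                     # [K] bool
    n_dead = int(dead.sum())
    if n_dead > 0:
        # Teleport each dead entry onto a point the encoder actually produces
        picks = torch.randint(0, z_flat.shape[0], (n_dead,), generator=generator)
        codebook[dead] = z_flat[picks]
    return dead

codebook = torch.tensor([[0.0, 0.0], [9.0, 9.0], [1.0, 1.0], [9.0, 9.0]])
usage = torch.tensor([5.0, 0.0, 3.0, 0.0])      # entries 1 and 3 went unused in the window
z_flat = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
dead = revive_dead_codes(codebook, usage, z_flat, generator=torch.Generator().manual_seed(0))
print(dead.tolist(), codebook.tolist())
```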
VQ-VAE's codebook is elegant but fragile: you need to manage dead codes, tune commitment loss, and balance EMA rates. FSQ takes a radically simpler approach: instead of learning a codebook, just round each scalar to a small set of levels.
If each of d dimensions has L levels, you get Lᵈ possible codes: an implicit codebook. With d = 6 and L = 5, that's 5⁶ = 15,625 codes. No codebook parameters. No collapse. No commitment loss. Just rounding.
The implementation is almost embarrassingly simple. The encoder outputs a d-dimensional vector. Apply tanh to bound it to (−1, 1). Scale by (L − 1)/2. Round to the nearest integer, which leaves L possible values per dimension (this works cleanly for odd L; even L needs a half-step offset). That's quantization. The straight-through estimator handles gradients: forward uses the rounded value, backward pretends the rounding didn't happen.
```python
import torch

def fsq_quantize(z: torch.Tensor, L: int = 5) -> torch.Tensor:
    # FSQ: a few lines of code for quantization (z: [B, d] continuous encoder output; L odd)
    z_bounded = torch.tanh(z)                      # [B, d] in (-1, 1)
    z_scaled = z_bounded * (L - 1) / 2             # scale to (-(L-1)/2, (L-1)/2)
    z_q = z_scaled.round()                         # snap to L integer levels
    return z_scaled + (z_q - z_scaled).detach()    # straight-through
```
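Because the codebook is implicit, each quantized vector can be turned into a single token id by reading its d levels as digits in base L. This mixed-radix mapping is a sketch assuming one odd L shared by all dimensions, so the levels are the integers from −(L−1)/2 to (L−1)/2; the function name is hypothetical, and FSQ setups can also use a different number of levels per dimension.

```python
import torch

def fsq_codes_to_index(z_q: torch.Tensor, L: int = 5) -> torch.Tensor:
    """Map FSQ codes [..., d] with integer levels in [-(L-1)/2, (L-1)/2] to ids in [0, L**d)."""
    d = z_q.shape[-1]
    digits = (z_q.round() + (L - 1) // 2).long()   # shift each level to 0..L-1
    powers = L ** torch.arange(d)                   # [1, L, L**2, ...]
    return (digits * powers).sum(dim=-1)

levels = torch.arange(5, dtype=torch.float32) - 2   # [-2, -1, 0, 1, 2]
all_codes = torch.cartesian_prod(levels, levels)    # every code for d = 2: [25, 2]
ids = fsq_codes_to_index(all_codes, L=5)
print(ids.unique().numel())                         # one distinct id per code
```

For d = 6 and L = 5 the ids cover 0 through 15,624, which is the whole implicit vocabulary with no lookup table.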
| Property | VQ-VAE | FSQ |
|---|---|---|
| Codebook type | Explicit: nn.Embedding(K, D) | Implicit: Lᵈ codes from rounding |
| Extra parameters | K × D floats (e.g. 8192 × 256 ≈ 2.1M) | Zero |
| Collapse risk | High: needs EMA + revival | Low: no learned entries that can die |
| Gradient trick | Straight-through + commitment loss | Straight-through on rounding |
| Reconstruction quality | Slightly better (learned codes adapt) | Comparable in practice |
| Code | ~50 lines for quantizer | ~5 lines |
Modern generative models don't operate on raw pixels. They first tokenize images into compact representations using a VAE or VQ-VAE, then model the distribution of those representations using a transformer or diffusion model. This is the architecture behind Stable Diffusion, DALL-E, and MAGVIT.
The tokenizer compresses a 256×256 image to, say, a 32×32 grid of codebook indices. That's 1,024 tokens instead of 196,608 pixel values — a 192× compression. A transformer can then model these tokens autoregressively, just like words.
Each system uses a slightly different flavor. Stable Diffusion uses a continuous latent space (KL-regularized, not VQ), so the "tokens" are 4-channel feature maps at 64×64 for 512×512 images. DALL-E 1 used a discrete VAE (dVAE) with an 8,192-entry codebook: each of its 32×32 spatial positions becomes a single integer index, giving 1,024 tokens that a GPT-style transformer can generate one by one. MAGVIT-2 pushes the vocabulary to 2¹⁸ = 262,144 codes using lookup-free quantization, achieving very high reconstruction quality.
| System | Tokenizer | Generator | Year |
|---|---|---|---|
| DALL-E 1 | dVAE (discrete VAE) | Autoregressive transformer | 2021 |
| Stable Diffusion | KL-regularized AE | Latent diffusion model | 2022 |
| MAGVIT-2 | LFQ (lookup-free quantization) | Masked transformer | 2023 |
| Cosmos | Causal video tokenizer (continuous or FSQ-based discrete) | Autoregressive + diffusion | 2024 |
Think of the tokenizer as a lossy camera: everything downstream sees the world through this lens. If the tokenizer drops fine text, subtle textures, or thin structures, no generative model can hallucinate them back faithfully. This is why MAGVIT-2 pushed codebook size to 262K entries — more codes = finer-grained visual vocabulary = less information lost at the tokenization boundary.
For video, the challenge compounds: you need temporal consistency between frames. A causal video tokenizer like Cosmos's processes frames in sequence, so that the latent representation of frame t depends only on frames 0 through t, never on future frames. Tokenizing frames independently, by contrast, produces flickering artifacts when decoded. The spatial compression might be 8× per axis (256→32), and the temporal compression 4× (32 frames → 8 latent frames), giving a combined compression of 8×8×4 = 256×.
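Causality in a video tokenizer is commonly enforced by padding temporal convolutions only on the past side. This is a minimal 1-D sketch over the time axis (a real video tokenizer would use 3-D convolutions over time, height, and width; the class name is hypothetical): left-padding by kernel_size − 1 keeps the output length at T while ensuring output frame t never sees input frames after t.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class CausalTemporalConv(nn.Module):
    """Hypothetical 1-D causal convolution over time.

    Input [B, C, T]. Left-padding by kernel_size - 1 keeps the length at T, and output
    frame t only depends on input frames t - kernel_size + 1 through t (never the future).
    """
    def __init__(self, channels: int, kernel_size: int = 3):
        super().__init__()
        self.pad = kernel_size - 1
        self.conv = nn.Conv1d(channels, channels, kernel_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = F.pad(x, (self.pad, 0))   # pad only on the left, i.e. the past
        return self.conv(x)

torch.manual_seed(0)
layer = CausalTemporalConv(channels=2, kernel_size=3)
x = torch.randn(1, 2, 8)                  # [B, C, T] with T = 8 frames
x_future_changed = x.clone()
x_future_changed[..., 5:] += 10.0         # perturb frames 5..7 only
with torch.no_grad():
    y, y_changed = layer(x), layer(x_future_changed)
print(torch.allclose(y[..., :5], y_changed[..., :5]))   # True: frames 0..4 are unaffected
```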
Stable Diffusion 1.x/2.x tokenizer choices:
1. Spatial: 8× downsampling (256→32 or 512→64). Why: 4× leaves 4× more latent positions (and more spatial redundancy) for the diffusion model to process at every step, so each diffusion step costs at least about 4× as much. 16× loses fine detail (faces become blobs). 8× is the sweet spot: enough detail for sharp generation without bankrupting the diffusion budget.
2. Channels: 4. This gives [B, 4, 32, 32] = 4,096 latent values from 196,608 input values (256×256×3). Each latent position summarizes an 8×8 pixel patch, i.e. 3 × 64 = 192 input values, with just 4 numbers: an aggressive 48× compression.
3. Continuous (KL-regularized). Diffusion naturally handles continuous Gaussians. A discrete codebook would require dequantization noise or a separate discrete diffusion formulation. Continuous is simpler and works better for the Gaussian diffusion framework.
4. KL weight: ~1e-6 (nearly zero). Just enough to keep latents vaguely Gaussian-distributed so the diffusion model's N(0,1) noise schedule makes sense. The focus is reconstruction quality, not latent regularity.
5. L1 + perceptual (VGG) + adversarial (PatchGAN). MSE alone = blurry. Perceptual loss preserves structure. Adversarial loss adds sharpness. Training is noticeably more expensive (extra discriminator and VGG forward passes), but the quality gain is dramatic.
Design lesson: The tokenizer is optimized for the DOWNSTREAM task. Stable Diffusion's VAE is barely a VAE — the KL is negligible. It's really a perceptual autoencoder with just enough Gaussian structure for diffusion to work.
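The trade-off behind the downsampling factor in item 1 is mostly arithmetic. This sketch holds the channel count at 4 for every factor, which is a simplifying assumption (autoencoders with stronger downsampling often use more latent channels); the helper name is hypothetical. Going from 8× to 4× quadruples the number of latent positions the diffusion model must process, while 16× cuts them by another 4×.

```python
def latent_stats(height, width, f=8, latent_channels=4, in_channels=3):
    """Latent grid size, number of latent values, and value-count compression ratio."""
    h, w = height // f, width // f
    latent_values = latent_channels * h * w
    return h, w, latent_values, (in_channels * height * width) / latent_values

for f in (4, 8, 16):
    h, w, n, ratio = latent_stats(512, 512, f=f)
    print(f"f={f:>2}: latent 4x{h}x{w}, {h * w} positions, {n} values, {ratio:.0f}x fewer values")
```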
The VAE creates a "compression layer" that multiple generative models can share. Stable Diffusion 1.x and 2.x share the same VAE latent space, and SDXL keeps the same architecture and 4-channel latent format with a retrained VAE; newer models such as Flux move to a 16-channel VAE. Either way, the VAE is trained once while the generation backbone (U-Net or DiT) is iterated on. This is amortized infrastructure: train the VAE once (~84M params, cheap), then reuse it across models and fine-tunes. The VAE latent space is the "pixel format" of modern generation.
Can you think of other systems where an expensive compression step is done once and then multiple downstream tasks operate in the compressed space? (Hint: text embeddings from a frozen encoder, used by retrieval + generation + classification.)
VAEs and VQ-VAEs aren't just academic exercises: they're load-bearing infrastructure in the biggest generative AI systems. Stable Diffusion's latent space? A VAE. DALL-E 1's image tokens? A discrete VAE. Most video models? Some flavor of spatio-temporal autoencoder, continuous or quantized.
Why does this 48× compression matter so much? Diffusion models run the U-Net (or transformer) tens to hundreds of times per image. On raw pixels at 512×512×3 = 786K dimensions, each pass would process 64× more spatial positions than at 64×64×4 = 16K dimensions, and attention layers, whose cost grows quadratically with the number of positions, would be hit far harder. When editing an existing image, the VAE encoder runs once (text-to-image skips it entirely), diffusion runs ~50 steps in the tiny latent space, then the VAE decoder runs once. It's the VAE that makes latent diffusion feasible on consumer GPUs.
Stable Diffusion's VAE is trained separately from the diffusion model. It uses a KL penalty (like a standard VAE) but tuned very lightly: the KL weight is small (around 1e-6), so the latent space is barely regularized. The focus is reconstruction quality: the VAE must encode and decode images with minimal perceptual degradation, because every artifact in the VAE gets inherited by every generated image.
The training recipe for Stable Diffusion's VAE includes three loss components: (1) L1 pixel reconstruction, (2) perceptual loss using VGG features (compares high-level structure, not pixel values), and (3) adversarial loss from a PatchGAN discriminator (penalizes blurriness locally). The KL term is almost an afterthought — just enough to keep the latent distribution vaguely Gaussian so the diffusion model has a reasonable starting distribution.
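```python
import torch.nn.functional as F

def sd_vae_loss(x, x_hat, mu, log_var, perceptual_loss, adversarial_loss):
    # Stable Diffusion VAE loss (simplified; weights illustrative). perceptual_loss and
    # adversarial_loss are caller-supplied callables (e.g. a VGG-feature loss and a PatchGAN critic)
    kl = 0.5 * (mu.pow(2) + log_var.exp() - log_var - 1).sum()
    return (1.0 * F.l1_loss(x_hat, x)            # pixel reconstruction
            + 1.0 * perceptual_loss(x, x_hat)    # VGG feature matching
            + 0.5 * adversarial_loss(x_hat)      # PatchGAN discriminator
            + 1e-6 * kl)                         # tiny KL regularization
```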
How the ideas connect. The original autoencoder begat a family of models that now power every major generative AI system.
| Application | VAE Variant | Role |
|---|---|---|
| Stable Diffusion | KL-VAE | Compress images to/from latent space |
| DALL-E 1 | dVAE | Convert images to discrete tokens |
| Sora | Spatial-temporal VAE | Tokenize video frames + motion |
| AudioLM | SoundStream (residual-VQ codec) | Tokenize audio waveforms |
| Drug discovery | Molecular VAE | Smooth latent space for molecule optimization |
You now understand latent spaces, variational inference, vector quantization, and how they power modern AI. Most generated images you see started as a latent code.