From toy 2D distributions to Stable Diffusion 3 and Meta Movie Gen — the engineering that makes it real.
In previous chapters, we trained flow matching models on simple 2D distributions — rings, checkerboards, mixtures of Gaussians. The neural network was a small MLP with a few hundred parameters. Now we want to generate 1024×1024 photorealistic images and 30-second videos. This requires answering three engineering questions:
Question 1: Architecture. An image $x \in \mathbb{R}^{3\times 1024\times 1024}$ has ~3 million dimensions. A generic MLP cannot process this efficiently. We need architectures that respect the spatial structure of images: transformers and U-Nets.
Question 2: Conditioning. The network $u_t^\theta(x\mid y)$ must digest three very different inputs: a noisy image x, a scalar time t, and a conditioning variable y that could be a class label, text string, or another image. Each needs a tailored embedding strategy.
Question 3: Resolution. Training directly in the 3-million-dimensional pixel space is prohibitively expensive. The key insight is that natural images lie on a much lower-dimensional manifold. We can compress images to a latent space using an autoencoder, train the flow model there, and decode back.
The scalar time t ∈ [0,1] seems trivial to handle — just concatenate it to the input. But in practice, this works poorly for large models. The model needs to behave very differently at t=0 (input is mostly noise) versus t=1 (input is mostly data). A single scalar doesn't give the network enough "room" to represent these different behaviors.
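The solution is Fourier features: embed t into a high-dimensional vector using sinusoidal functions. This is the same idea as positional encoding in transformers:

$$\mathrm{TimeEmb}(t) = \sqrt{\tfrac{2}{d}}\,\big[\cos(2\pi w_1 t), \dots, \cos(2\pi w_{d/2} t),\ \sin(2\pi w_1 t), \dots, \sin(2\pi w_{d/2} t)\big] \in \mathbb{R}^d,$$

where the frequencies $w_i$ are logarithmically spaced between $w_{\min}$ and $w_{\max}$:

$$w_i = w_{\min}\left(\frac{w_{\max}}{w_{\min}}\right)^{\frac{i-1}{d/2-1}}, \qquad i = 1, \dots, d/2.$$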
Let's see how this works in practice:
```python
import math
import torch

def time_embedding(t, d=256, w_min=1.0, w_max=1000.0):
    """Fourier features for scalar time t of shape (B,). Returns (B, d), one unit-norm row per sample."""
    half_d = d // 2
    # Log-spaced frequencies w_min, ..., w_max (created on the same device as t)
    i = torch.arange(half_d, dtype=torch.float32, device=t.device)
    freqs = w_min * (w_max / w_min) ** (i / (half_d - 1))
    # [cos, sin] embedding
    args = 2 * math.pi * freqs * t.unsqueeze(-1)  # (B, d/2)
    emb = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
    return emb * (2.0 / d) ** 0.5  # normalize to unit norm
```
As t varies, the low-frequency dimensions change slowly while the high-frequency dimensions oscillate rapidly, so the network can read both coarse and fine timing information from this embedding.
Properties of the Fourier embedding. The embedding has several useful properties:
• Unit norm: $\|\mathrm{TimeEmb}(t)\| = 1$ for all t, because $\sum_{i=1}^{d/2}\big(\cos^2(2\pi w_i t) + \sin^2(2\pi w_i t)\big) = d/2$ and we scale by $\sqrt{2/d}$.
• Smooth in t: Small changes in t produce small changes in the embedding (Lipschitz continuous).
• Distinguishable: Different values of t produce different embeddings — the mapping is injective for practical frequency ranges.
• Multi-scale: The log-spaced frequencies cover both coarse timing (is it early or late?) and fine timing (exactly which step?).
After the Fourier embedding, the time vector is typically passed through a small MLP to produce the final conditioning signal:
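```python
import torch.nn as nn

# Time conditioning pipeline (d = hidden size of the DiT)
d = 1024
mlp = nn.Sequential(nn.Linear(256, d), nn.SiLU(), nn.Linear(d, d))  # learned transformation
to_gamma_beta = nn.Linear(d, 2 * d)  # AdaLN projection (one per DiT block)

t_fourier = time_embedding(t, d=256)  # (B, 256) Fourier features
t_emb = mlp(t_fourier)                # (B, d)
# t_emb is then used for AdaLN in each DiT block
gamma, beta = to_gamma_beta(t_emb).chunk(2, dim=-1)  # (B, d), (B, d): scale and shift
```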
The conditioning variable y can be a class label, a text prompt, or even another image. Each type requires a different embedding strategy to convert raw input into vectors the model can process.
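When y is a discrete class label (e.g. "cat"=0, "dog"=1, ..., "car"=999), the simplest approach is a learned embedding table. Each of the N+1 classes (the N real classes plus an extra "null" label ∅ used for classifier-free guidance) gets its own learnable vector in $\mathbb{R}^d$:

```python
import torch
import torch.nn as nn

num_classes, d = 1000, 1024  # e.g. ImageNet classes, hidden size
# Class embedding: a lookup table with one extra row for the null label (index num_classes)
class_embed = nn.Embedding(num_classes + 1, d)  # (N+1, d)
y = torch.tensor([3, 207, num_classes])          # (B,) integer labels; the last one is "null"
y_emb = class_embed(y)                           # (B, d)
```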
These embedding vectors are trained jointly with the rest of the model — they are part of the parameters θ.
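Because the table reserves an extra row for the null label, classifier-free guidance training can be implemented by randomly replacing real labels with that null index. The sketch below assumes the null label is the last index (`num_classes`) and uses a 10% drop probability; `drop_labels` is a hypothetical helper, not a library function.

```python
import torch


def drop_labels(y: torch.Tensor, num_classes: int, p_drop: float = 0.1) -> torch.Tensor:
    """Replace each label by the null index `num_classes` with probability p_drop.

    Assumes the embedding table has num_classes + 1 rows, the last one reserved
    for the null label used by classifier-free guidance.
    """
    drop = torch.rand(y.shape, device=y.device) < p_drop
    null = torch.full_like(y, num_classes)
    return torch.where(drop, null, y)


# Example: a batch of ImageNet-style labels, roughly 10% of which become the null index 1000
y = torch.randint(0, 1000, (8,))
y_train = drop_labels(y, num_classes=1000)
```

At sampling time, the unconditional prediction is obtained by passing the null index for every sample in the batch.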
When y is a text prompt like "a corgi wearing sunglasses on a beach," embedding is much harder. We rely on pretrained frozen models that have already learned to understand language.
The most common choice is CLIP (Contrastive Language-Image Pretraining). CLIP was trained on hundreds of millions of text-image pairs to learn a shared embedding space where images and their descriptions are close together. For our purposes, it maps a prompt to a single vector:

$$\tilde y = \mathrm{CLIP}(y) \in \mathbb{R}^{d_{\mathrm{CLIP}}}$$

This gives a single vector summarizing the entire prompt. But sometimes we want more granularity: the model should attend to specific words. For this, we use a pretrained transformer (like T5) that produces a sequence of embeddings:

$$\tilde y = \mathrm{T5}(y) \in \mathbb{R}^{S\times k},$$
where S is the number of text tokens and k is the embedding dimension. This gives the model a per-word handle on the prompt content.
| Embedding type | Input | Output shape | Used by |
|---|---|---|---|
| Learned table | Class label (integer) | (B, d) | Class-conditional DiT |
| CLIP | Text string | (B, $d_{\mathrm{CLIP}}$) | Stable Diffusion, FLUX |
| T5 encoder | Text string | (B, S, k) | Stable Diffusion 3 |
| Multiple encoders | Text string | Concat of above | SD3 (3 encoders), Movie Gen (3) |
Transformers process sequences of tokens. But an image is a 2D grid of pixels. How do we bridge the gap? The answer, borrowed from Vision Transformers (ViT), is patchification: chop the image into non-overlapping patches and treat each patch as a token.
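Given an image $x \in \mathbb{R}^{C\times H\times W}$ and a patch size P that divides H and W, we cut the image into $N = (H/P)\cdot(W/P)$ non-overlapping patches, flatten each patch into a vector of length $C\cdot P^2$, and project it linearly to the hidden dimension d. The result is a token sequence in $\mathbb{R}^{N\times d}$.
Worked example. Consider a 256×256 RGB image with patch size P=16: there are $N = (256/16)^2 = 16^2 = 256$ patches, each flattened to $3\cdot 16^2 = 768$ numbers.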
After patchification, the image is just a sequence of 256 tokens — exactly what a transformer expects. The attention mechanism then processes these tokens, allowing each patch to attend to all other patches.
Why patches and not individual pixels? A 256×256 image has 65,536 pixels. Attention has $O(N^2)$ complexity, so attending over all pixels would require $65{,}536^2 \approx 4.3$ billion pairwise attention scores per layer (per head), which is prohibitively expensive. With P=16 patches, we get N=256 tokens and only $256^2 = 65{,}536$ scores: a 65,536× reduction.
The tradeoff. Smaller patches capture finer detail but create longer sequences (more expensive). Larger patches are more efficient but lose spatial resolution within each patch. Common choices range from P=8 (high quality, expensive) to P=32 (efficient, less detail).
| Image size | Patch size P | Tokens N | Per-patch dim $C\cdot P^2$ |
|---|---|---|---|
| 256×256 | 8 | 1024 | 192 |
| 256×256 | 16 | 256 | 768 |
| 512×512 | 16 | 1024 | 768 |
| 1024×1024 | 16 | 4096 | 768 |
| 1024×1024 | 32 | 1024 | 3072 |
```python
import torch.nn as nn

class Patchify(nn.Module):
    def __init__(self, in_channels=3, patch_size=16, hidden_dim=768):
        super().__init__()
        self.P = patch_size
        # Linear projection of flattened patches
        self.proj = nn.Linear(in_channels * patch_size**2, hidden_dim)

    def forward(self, x):
        B, C, H, W = x.shape
        P = self.P
        # Reshape: (B, C, H, W) -> (B, N, C*P*P)
        x = x.unfold(2, P, P).unfold(3, P, P)  # (B, C, H/P, W/P, P, P)
        x = x.permute(0, 2, 3, 1, 4, 5)        # (B, H/P, W/P, C, P, P): keep each patch's values together
        x = x.reshape(B, -1, C * P * P)        # (B, N, C*P^2)
        return self.proj(x)                    # (B, N, d)
```
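At the end of the transformer, we need to convert the token sequence back to an image. This depatchification applies a linear projection from d to $C\cdot P^2$ per token, then reshapes back to $C\times H\times W$:

$$\mathbb{R}^{N\times d} \xrightarrow{\ \text{linear}\ } \mathbb{R}^{N\times CP^2} \xrightarrow{\ \text{reshape}\ } \mathbb{R}^{C\times H\times W}$$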
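A minimal sketch of depatchification, assuming patches are ordered row-major over the patch grid and each patch is flattened in (C, P, P) order. `patchify`, `unpatchify`, and `Depatchify` are illustrative names, not from a specific library. The key requirement is that `unpatchify` exactly undoes `patchify`, so every token lands back in the image region it came from.

```python
import torch
import torch.nn as nn


def patchify(x: torch.Tensor, P: int) -> torch.Tensor:
    """(B, C, H, W) -> (B, N, C*P*P), patches in row-major order."""
    B, C, H, W = x.shape
    x = x.reshape(B, C, H // P, P, W // P, P)  # split H and W into (grid index, offset in patch)
    x = x.permute(0, 2, 4, 1, 3, 5)            # (B, H/P, W/P, C, P, P)
    return x.reshape(B, (H // P) * (W // P), C * P * P)


def unpatchify(tokens: torch.Tensor, C: int, H: int, W: int, P: int) -> torch.Tensor:
    """Inverse of patchify: (B, N, C*P*P) -> (B, C, H, W)."""
    B = tokens.shape[0]
    x = tokens.reshape(B, H // P, W // P, C, P, P)
    x = x.permute(0, 3, 1, 4, 2, 5)            # (B, C, H/P, P, W/P, P)
    return x.reshape(B, C, H, W)


class Depatchify(nn.Module):
    """Final layer: linear projection d -> C*P^2 per token, then reshape to an image."""

    def __init__(self, out_channels=3, patch_size=16, hidden_dim=768):
        super().__init__()
        self.C, self.P = out_channels, patch_size
        self.proj = nn.Linear(hidden_dim, out_channels * patch_size**2)

    def forward(self, tokens, H, W):
        return unpatchify(self.proj(tokens), self.C, H, W, self.P)
```

For example, `Depatchify(3, 16, 768)(tokens, 256, 256)` maps a (B, 256, 768) token sequence to a (B, 3, 256, 256) velocity prediction.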
The Diffusion Transformer (DiT) processes the patch tokens using a stack of L transformer layers. Each layer, called a DiT Block, applies three operations: self-attention among patches, cross-attention to the prompt, and time conditioning via adaptive normalization.
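At a high level, a DiT patchifies the noisy input into N tokens (adding positional embeddings), passes them through L DiT blocks that each read the time embedding t̃ and the prompt embedding ỹ, and finally depatchifies the tokens back into a $C\times H\times W$ velocity prediction $u_t^\theta(x\mid y)$.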
Inside each DiT Block, the three operations happen sequentially:
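Standard multi-head attention where queries, keys, and values all come from the patch tokens x (shown here for a single head):

$$\mathrm{SelfAttn}(x) = \mathrm{softmax}\!\left(\frac{(xW_Q)(xW_K)^\top}{\sqrt{d_h}}\right) xW_V$$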
Each patch can "see" all other patches, allowing the model to reason about global image structure (e.g. "the sky is blue, so this patch is probably clouds").
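Queries come from patches x, but keys and values come from the prompt embedding ỹ:

$$\mathrm{CrossAttn}(x, \tilde y) = \mathrm{softmax}\!\left(\frac{(xW_Q)(\tilde yW_K)^\top}{\sqrt{d_h}}\right) \tilde yW_V$$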
This is how the model "reads" the prompt. Each image patch can attend to specific words in the text, learning associations like "the word 'red' should increase red pixel values in this region."
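Adaptive Layer Normalization (AdaLN) is how time information enters each layer. The time embedding t̃ produces per-channel scale and shift parameters through an MLP, and these modulate a LayerNorm:

$$(\gamma, \beta) = \mathrm{MLP}(\tilde t), \qquad \mathrm{AdaLN}(x;\tilde t) = \gamma \odot \mathrm{LayerNorm}(x) + \beta$$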
This modulates every layer's behavior based on the current timestep. At early times (noisy input), the model needs to denoise aggressively. At late times (nearly clean input), it needs to make fine corrections. AdaLN lets each layer adapt its behavior automatically.
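The three operations can be combined into one block as in the minimal PyTorch sketch below. It assumes a pre-norm residual layout, AdaLN of the form γ ⊙ LayerNorm(x) + β with separate (γ, β) per sub-layer, a 4× GELU MLP, and a time embedding already projected to dimension d; both attention types use `torch.nn.MultiheadAttention`. This is a simplification: the DiT paper's adaLN-Zero variant also predicts gating coefficients and zero-initializes them, which this sketch omits.

```python
import torch
import torch.nn as nn


class AdaLN(nn.Module):
    """LayerNorm whose per-channel scale and shift are predicted from the time embedding."""

    def __init__(self, d):
        super().__init__()
        self.norm = nn.LayerNorm(d, elementwise_affine=False)
        self.to_gamma_beta = nn.Linear(d, 2 * d)

    def forward(self, x, t_emb):
        # x: (B, N, d), t_emb: (B, d) -> gamma, beta: (B, 1, d), broadcast over tokens
        gamma, beta = self.to_gamma_beta(t_emb).unsqueeze(1).chunk(2, dim=-1)
        return gamma * self.norm(x) + beta


class DiTBlock(nn.Module):
    def __init__(self, d=1024, n_heads=16, mlp_ratio=4):
        super().__init__()
        self.ada1, self.ada2, self.ada3 = AdaLN(d), AdaLN(d), AdaLN(d)
        self.self_attn = nn.MultiheadAttention(d, n_heads, batch_first=True)
        self.cross_attn = nn.MultiheadAttention(d, n_heads, batch_first=True)
        self.mlp = nn.Sequential(nn.Linear(d, mlp_ratio * d), nn.GELU(), nn.Linear(mlp_ratio * d, d))

    def forward(self, x, y_emb, t_emb):
        # x: (B, N, d) patch tokens, y_emb: (B, S, d) prompt tokens, t_emb: (B, d)
        h = self.ada1(x, t_emb)
        x = x + self.self_attn(h, h, h, need_weights=False)[0]           # patches attend to patches
        h = self.ada2(x, t_emb)
        x = x + self.cross_attn(h, y_emb, y_emb, need_weights=False)[0]  # patches attend to the prompt
        x = x + self.mlp(self.ada3(x, t_emb))                            # per-token MLP
        return x                                                         # (B, N, d)
```

With d=1024, n_heads=16, N=256, and S=77, this block reproduces the shapes in the dimension-tracking table below.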
Worked example: dimension tracking through a DiT Block. Consider a model with hidden dimension d=1024, 16 attention heads ($d_h = 64$), N=256 patch tokens, and S=77 text tokens:
| Operation | Input | Output | Computation |
|---|---|---|---|
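| Self-attention Q,K,V projection | $x \in \mathbb{R}^{256\times1024}$ | $Q,K,V \in \mathbb{R}^{16\times256\times64}$ | Linear projections per head |
| Self-attention scores | $Q,K \in \mathbb{R}^{256\times64}$ | $A \in \mathbb{R}^{256\times256}$ | $QK^\top/8$ per head ($\sqrt{64}=8$) |
| Self-attention output | $A \in \mathbb{R}^{256\times256}$, $V \in \mathbb{R}^{256\times64}$ | $\mathbb{R}^{256\times64}$ per head → $\mathbb{R}^{256\times1024}$ | AV, concat, project |
| Cross-attention Q | $x \in \mathbb{R}^{256\times1024}$ | $Q \in \mathbb{R}^{256\times64}$ per head | Queries from patches |
| Cross-attention K,V | $\tilde y \in \mathbb{R}^{77\times1024}$ | $K,V \in \mathbb{R}^{77\times64}$ per head | Keys/values from text |
| Cross-attention output | $\mathbb{R}^{256\times77}$ scores per head | $\mathbb{R}^{256\times1024}$ | Attend to text per patch |
| MLP | $\mathbb{R}^{256\times1024}$ | $\mathbb{R}^{256\times1024}$ | Two linear layers with GELU |
| AdaLN (applied at each sub-layer) | $\tilde t \in \mathbb{R}^{1024}$ | $(\gamma,\beta) \in \mathbb{R}^{1024}$ each | MLP on time embedding |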
The token stream keeps the shape $\mathbb{R}^{N\times d} = \mathbb{R}^{256\times1024}$ between operations throughout the block. Cross-attention is the only operation where queries and keys come from sequences of different lengths (256 patch queries against 77 text keys/values), but its output is still mapped back to $\mathbb{R}^{256\times1024}$.
Before DiTs, the dominant architecture for diffusion models was the U-Net. Originally designed for biomedical image segmentation (Ronneberger et al., 2015), the U-Net was adapted for diffusion models by Ho et al. (2020). Its key property is that both its input and output have the shape of images, which is exactly what we need for predicting a velocity field $u_t^\theta(x\mid y) \in \mathbb{R}^{C\times H\times W}$.
The name "U-Net" comes from the U-shaped data path when the architecture is drawn as a diagram. A U-Net consists of three parts: encoders that go down (reducing spatial resolution), a midcoder that processes features at the lowest resolution, and decoders that go back up (restoring spatial resolution). Lateral skip connections pass the features of each encoder level directly to the decoder level at the same resolution, so fine spatial detail lost during downsampling can be recovered.
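The U shape can be illustrated with a minimal two-level sketch. It assumes average pooling for downsampling, nearest-neighbor upsampling, and skip connections implemented by channel concatenation; `TinyUNet` is an illustrative name. It omits the time and prompt conditioning a diffusion U-Net would also need. The point is only that the output has the same shape as the input and that each decoder level receives the matching encoder features.

```python
import torch
import torch.nn as nn


def conv_block(c_in, c_out):
    return nn.Sequential(nn.Conv2d(c_in, c_out, 3, padding=1), nn.GroupNorm(8, c_out), nn.SiLU())


class TinyUNet(nn.Module):
    """Two-level U-Net: maps (B, C, H, W) -> (B, C, H, W); H and W must be divisible by 4."""

    def __init__(self, channels=3, base=32):
        super().__init__()
        self.down = nn.AvgPool2d(2)
        self.up = nn.Upsample(scale_factor=2, mode="nearest")
        self.enc1 = conv_block(channels, base)       # encoder, full resolution
        self.enc2 = conv_block(base, 2 * base)       # encoder, 1/2 resolution
        self.mid = conv_block(2 * base, 2 * base)    # midcoder, 1/4 resolution
        self.dec2 = conv_block(4 * base, base)       # decoder, input = upsampled mid + skip from enc2
        self.dec1 = conv_block(2 * base, base)       # decoder, input = upsampled dec2 + skip from enc1
        self.out = nn.Conv2d(base, channels, 1)

    def forward(self, x):
        s1 = self.enc1(x)                                    # (B, base, H, W)
        s2 = self.enc2(self.down(s1))                        # (B, 2*base, H/2, W/2)
        m = self.mid(self.down(s2))                          # (B, 2*base, H/4, W/4)
        h = self.dec2(torch.cat([self.up(m), s2], dim=1))    # (B, base, H/2, W/2)
        h = self.dec1(torch.cat([self.up(h), s1], dim=1))    # (B, base, H, W)
        return self.out(h)                                   # (B, C, H, W)
```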
U-Net vs. DiT:
| Feature | U-Net | DiT |
|---|---|---|
| Core operation | Convolutions (local) | Attention (global) |
| Multi-resolution | Built-in (encoder/decoder) | Single resolution + patches |
| Scaling | Harder to scale | Scales like language models |
| Text conditioning | Cross-attention at select layers | Cross-attention at every layer |
| Used by | DDPM, early SD | SD3, FLUX, Movie Gen, VEO-3 |
The field has largely shifted from U-Nets to DiTs because transformers scale more predictably — doubling parameters consistently improves quality, just as in language modeling.
Historical context. The original DDPM paper (Ho et al., 2020) and Stable Diffusion 1–2 used U-Nets. The DiT paper (Peebles & Xie, 2023) showed that transformers could match and then exceed U-Net performance when scaled up. Since then, the field has converged on DiTs: Stable Diffusion 3, FLUX, Sora, VEO-3, and Movie Gen all use transformer-based architectures.
Why U-Nets are still relevant. U-Nets remain valuable for smaller-scale applications, edge deployment, and any setting where the built-in multi-resolution processing (without quadratic attention cost) is advantageous. Many open-source models still use U-Net architectures effectively.
A 1024×1024 RGB image has d = 3×1024×1024 ≈ 3 million dimensions. For our flow model, both the input and output have this shape. Training directly in this space is prohibitively expensive. But do we need all those dimensions?
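Natural images occupy a tiny fraction of all possible pixel arrangements. Most random pixel configurations look like static noise. The actual "manifold of real images" is much lower-dimensional. An autoencoder exploits this by learning a compression:

$$\text{Encoder } \mu_\phi: \mathbb{R}^d \to \mathbb{R}^k, \qquad \text{Decoder } \mu_\theta: \mathbb{R}^k \to \mathbb{R}^d,$$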
where k ≪ d. For images, a typical compression is from 3×1024×1024 to 4×128×128 — a 48× reduction in dimensionality.
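The standard training objective minimizes reconstruction error:

$$\mathcal{L}_{\text{AE}}(\phi, \theta) = \mathbb{E}_{x\sim p_{\text{data}}}\big[\|\mu_\theta(\mu_\phi(x)) - x\|^2\big]$$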
In plain English: encode the image, decode it back, and minimize the squared difference. A good autoencoder produces reconstructions that are nearly indistinguishable from the originals.
Concrete example. A 1024×1024 RGB image has d = 3×1024×1024 = 3,145,728 dimensions. A typical latent has k = 4×128×128 = 65,536 dimensions, a 48× compression. The flow model now operates on 65K-dimensional inputs instead of 3M-dimensional ones, which substantially cuts activation memory and compute per step.
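```python
import torch.nn as nn

# Standard autoencoder architecture (simplified); compression 3*1024*1024 / (4*128*128) = 48x
class Encoder(nn.Module):
    # (B, 3, 1024, 1024) -> (B, 4, 128, 128) latent: three stride-2 convs, 1024 -> 512 -> 256 -> 128
    def __init__(self, hidden=64, latent_channels=4):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(3, hidden, 3, stride=2, padding=1), nn.SiLU(),
            nn.Conv2d(hidden, hidden, 3, stride=2, padding=1), nn.SiLU(),
            nn.Conv2d(hidden, latent_channels, 3, stride=2, padding=1))
    def forward(self, x):
        return self.net(x)

class Decoder(nn.Module):
    # (B, 4, 128, 128) latent -> (B, 3, 1024, 1024) reconstructed image: conv layers + 2x upsampling
    def __init__(self, hidden=64, latent_channels=4):
        super().__init__()
        self.net = nn.Sequential(
            nn.Upsample(scale_factor=2), nn.Conv2d(latent_channels, hidden, 3, padding=1), nn.SiLU(),
            nn.Upsample(scale_factor=2), nn.Conv2d(hidden, hidden, 3, padding=1), nn.SiLU(),
            nn.Upsample(scale_factor=2), nn.Conv2d(hidden, 3, 3, padding=1))
    def forward(self, z):
        return self.net(z)
```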
The latent dimension k controls how much information is preserved: too small and reconstructions become blurry; too large and there is little compression.
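A plain autoencoder places no constraint on how its latents are distributed, so they can end up scattered irregularly, which makes them hard for a generative model to learn. A Variational Autoencoder (VAE) solves this "bad latent distribution" problem by making the encoder and decoder stochastic and adding a regularization term that pushes the latent distribution toward a Gaussian.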
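Instead of deterministic mappings, the VAE defines a stochastic encoder and decoder:

$$q_\phi(z\mid x) = \mathcal{N}\big(z;\ \mu_\phi(x),\ \mathrm{diag}(\sigma_\phi^2(x))\big), \qquad p_\theta(x\mid z) = \mathcal{N}\big(x;\ \mu_\theta(z),\ \sigma_\theta^2(z)\, I_d\big)$$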
The encoder doesn't output a single point — it outputs a distribution (mean and variance). To encode, we sample z from this distribution. This stochasticity is crucial: it forces the encoder to spread its representations smoothly, rather than cramming all information into isolated points.
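The VAE loss combines two terms, with the second weighted by β: $\mathcal{L}_{\text{VAE}} = \mathcal{L}_{\text{recon}} + \beta\,\mathcal{L}_{\text{KL}}$.

Term 1: Reconstruction loss. Make sure encode-then-decode gives back the original image:

$$\mathcal{L}_{\text{recon}} = \mathbb{E}_{x\sim p_{\text{data}},\ z\sim q_\phi(\cdot\mid x)}\big[-\log p_\theta(x\mid z)\big]$$

For Gaussian decoders with fixed variance σ², this simplifies to mean squared error (up to constants):

$$\mathcal{L}_{\text{recon}} = \frac{1}{2\sigma^2}\,\mathbb{E}\big[\|x - \mu_\theta(z)\|^2\big] + \text{const}$$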
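Term 2: KL regularization. Push the encoder distribution toward a standard Gaussian prior $p_{\text{prior}} = \mathcal{N}(0, I_k)$:

$$\mathcal{L}_{\text{KL}} = \mathbb{E}_{x\sim p_{\text{data}}}\Big[D_{\mathrm{KL}}\big(q_\phi(\cdot\mid x)\,\|\,p_{\text{prior}}\big)\Big]$$

For a Gaussian encoder with diagonal covariance, the KL divergence to any diagonal Gaussian $p = \mathcal{N}(\mu_p, \mathrm{diag}(\sigma_p^2))$ has a closed form:

$$D_{\mathrm{KL}}(q\,\|\,p) = \frac12\sum_{j=1}^{k}\left[\frac{\sigma_j^2 + (\mu_j - \mu_{p,j})^2}{\sigma_{p,j}^2} - 1 + \log\frac{\sigma_{p,j}^2}{\sigma_j^2}\right]$$

When $p = \mathcal{N}(0, I)$, this becomes:

$$D_{\mathrm{KL}}\big(q_\phi(\cdot\mid x)\,\|\,\mathcal{N}(0, I_k)\big) = \frac12\sum_{j=1}^{k}\big(\sigma_j^2 + \mu_j^2 - \log\sigma_j^2 - 1\big)$$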
There's a subtle problem: the loss involves sampling $z \sim q_\phi(\cdot\mid x)$, whose distribution depends on φ. We can't backpropagate through a random sampling operation. The trick: instead of sampling z directly, sample noise $\varepsilon \sim \mathcal{N}(0, I_k)$ and compute

$$z = \mu_\phi(x) + \sigma_\phi(x) \odot \varepsilon.$$
Now the randomness is in ε (which doesn't depend on φ), and z is a differentiable function of φ. Gradients flow through μ and σ as usual.
```python
import torch

# VAE Training Step
def vae_step(x, encoder, decoder, beta=0.01):
    # Encode: get mean and log-variance
    mu, log_var = encoder(x)                  # (B, k), (B, k)
    # Reparameterize: sample z without blocking gradients
    std = torch.exp(0.5 * log_var)            # sigma = exp(0.5 * log(sigma^2))
    eps = torch.randn_like(std)               # noise from N(0, I)
    z = mu + std * eps                        # z ~ q_phi(z|x)
    # Decode
    x_hat = decoder(z)                        # (B, C, H, W)
    # Reconstruction loss: sum over pixels, mean over the batch
    recon = ((x - x_hat) ** 2).sum(dim=[1, 2, 3]).mean()
    # KL loss: closed form for a Gaussian encoder vs N(0, I)
    kl = -0.5 * (log_var - mu**2 - log_var.exp() + 1).sum(dim=1).mean()
    return recon + beta * kl
```
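The KL line in `vae_step` relies on the closed form ½Σ(σ² + μ² − log σ² − 1). One way to check it is to compare against PyTorch's `torch.distributions.kl_divergence`, which has a registered rule for pairs of `Normal` distributions. The snippet below is a sanity check on random inputs, not part of a training loop; `kl_to_standard_normal` is an illustrative helper name.

```python
import torch
from torch.distributions import Normal, kl_divergence


def kl_to_standard_normal(mu: torch.Tensor, log_var: torch.Tensor) -> torch.Tensor:
    """Closed-form KL(N(mu, diag(exp(log_var))) || N(0, I)), summed over the last dim."""
    return 0.5 * (log_var.exp() + mu**2 - log_var - 1).sum(dim=-1)


torch.manual_seed(0)
mu, log_var = torch.randn(4, 8), torch.randn(4, 8)
closed_form = kl_to_standard_normal(mu, log_var)
# Reference: per-dimension KL between univariate Normals, summed over the latent dimension
reference = kl_divergence(
    Normal(mu, (0.5 * log_var).exp()),
    Normal(torch.zeros_like(mu), torch.ones_like(mu)),
).sum(dim=-1)
```

Note that `Normal` takes a standard deviation, hence the `(0.5 * log_var).exp()` conversion from log-variance.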
Several practical tricks matter when training VAEs:

1. KL warm-up. Starting with β=0 and gradually increasing it over the first few epochs helps prevent posterior collapse: a pathological state where the encoder ignores x and outputs $q_\phi(z\mid x) \approx \mathcal{N}(0, I)$ for all x. When this happens, the latent space carries no information about the data, and the decoder cannot reconstruct the specific input.
2. Fixed decoder variance. Learning $\sigma_\theta^2(z)$ is numerically delicate. Most practical implementations fix it to a constant, making the reconstruction loss proportional to simple MSE.
3. Perceptual losses. Pixel-wise MSE produces overly smooth, blurry reconstructions because it penalizes all pixel errors equally. Modern VAEs add perceptual losses: compare features extracted by a pretrained network (like VGG) rather than raw pixels. This produces sharper, more visually appealing reconstructions.
4. Adversarial training. Some VAEs add a discriminator that tries to distinguish real images from reconstructions (VAE-GAN style). This further improves sharpness but introduces optimization instability.
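A linear warm-up schedule for β is one simple way to implement the KL warm-up tip. The 10,000-step ramp and the final β=0.01 (matching the default in `vae_step`) are illustrative assumptions, and `kl_weight` is a hypothetical helper.

```python
def kl_weight(step: int, beta_max: float = 0.01, warmup_steps: int = 10_000) -> float:
    """Linear KL warm-up: beta grows from 0 to beta_max over the first warmup_steps steps."""
    if warmup_steps <= 0:
        return beta_max
    return beta_max * min(1.0, step / warmup_steps)


# Example schedule at a few training steps
schedule = [kl_weight(s) for s in (0, 2_500, 5_000, 10_000, 50_000)]
```

In a training loop, the value for the current step would be passed as the `beta` argument of `vae_step`.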
| Loss component | Effect | Too much | Too little |
|---|---|---|---|
| Reconstruction | Faithful decoding | Blurry (MSE dominates) | Decoded images don't match originals |
| KL divergence | Gaussian latent space | Posterior collapse | Irregular latent space |
| Perceptual | Sharp details | Hallucinated textures | Blurry outputs |
| Adversarial | Photorealism | Training instability | Smooth, unrealistic outputs |
Compared with a standard autoencoder, whose latents can be irregularly scattered, a VAE's KL term pushes latents toward a smooth Gaussian, making them easier for a generative model to learn.
Now we can put everything together. Latent diffusion (or latent flow matching) is the recipe used by most state-of-the-art image and video generators: train a VAE to compress images into a well-behaved latent space, then train a flow/diffusion model in that latent space.
Step-by-step: what happens at inference time.
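```python
import torch

# Latent diffusion inference (complete). dit_model and vae_decoder are the trained, frozen
# networks; text_emb / null_emb are the prompt and null-prompt embeddings.
@torch.no_grad()
def generate(dit_model, vae_decoder, text_emb, null_emb, B=1, n_steps=50, w=4.0):  # w: illustrative CFG scale
    # 1. Start with noise in LATENT space, NOT pixel space!
    z = torch.randn(B, 4, 128, 128)
    # 2. Simulate the ODE from t=0 (noise) to t=1 (data) in latent space with CFG
    ts = torch.linspace(0, 1, n_steps + 1)
    for i in range(n_steps):
        t, dt = ts[i].expand(B), ts[i + 1] - ts[i]
        # CFG: evaluate twice
        u_uncond = dit_model(z, t, null_emb)   # (B, 4, 128, 128)
        u_cond = dit_model(z, t, text_emb)     # (B, 4, 128, 128)
        u_cfg = (1 - w) * u_uncond + w * u_cond
        z = z + dt * u_cfg                     # Euler step in latent space
    # 3. Decode latent to pixel space with the frozen VAE decoder
    # (we use the decoder MEAN, not a random sample)
    return vae_decoder(z)                      # (B, 3, 1024, 1024): final image!
```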
Stable Diffusion 3 extends the basic DiT with a key innovation: the image patches and text tokens are processed jointly in the same attention layers. This is called the multi-modal DiT (MM-DiT).
In a standard DiT, cross-attention lets image patches attend to text tokens. In the MM-DiT, text tokens also attend to image patches (bidirectional cross-attention). This allows richer information flow between modalities.
The MM-DiT concatenates the image patch sequence ($\mathbb{R}^{N\times d}$) and text token sequence ($\mathbb{R}^{S\times d}$) into a single sequence of length N+S, applies standard self-attention over the combined sequence, then splits the output back into image and text parts. Each part has its own AdaLN parameters conditioned on t.
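```python
import torch
import torch.nn as nn


def adaln(x, t_emb, proj, norm):
    # AdaLN: per-channel scale and shift predicted from the time embedding
    gamma, beta = proj(t_emb).unsqueeze(1).chunk(2, dim=-1)  # (B, 1, d) each
    return gamma * norm(x) + beta


# MM-DiT: joint image-text attention (simplified: SD3 also uses separate Q/K/V
# projection weights per modality, whereas this sketch shares one attention module)
class MMDiTBlock(nn.Module):
    def __init__(self, d, n_heads):
        super().__init__()
        self.norm = nn.LayerNorm(d, elementwise_affine=False)
        # Separate AdaLN parameters per modality and per sub-layer
        self.ada = nn.ModuleDict({k: nn.Linear(d, 2 * d) for k in ["img", "txt", "img2", "txt2"]})
        self.attn = nn.MultiheadAttention(d, n_heads, batch_first=True)
        self.mlp_img = nn.Sequential(nn.Linear(d, 4 * d), nn.GELU(), nn.Linear(4 * d, d))
        self.mlp_txt = nn.Sequential(nn.Linear(d, 4 * d), nn.GELU(), nn.Linear(4 * d, d))

    def forward(self, x_img, x_txt, t_emb):
        # x_img: (B, N, d) image patches; x_txt: (B, S, d) text tokens; t_emb: (B, d)
        N = x_img.shape[1]
        h = torch.cat([adaln(x_img, t_emb, self.ada["img"], self.norm),
                       adaln(x_txt, t_emb, self.ada["txt"], self.norm)], dim=1)  # (B, N+S, d)
        h = self.attn(h, h, h, need_weights=False)[0]        # joint self-attention
        x_img, x_txt = x_img + h[:, :N], x_txt + h[:, N:]    # split back + residual
        # Separate MLP per modality
        x_img = x_img + self.mlp_img(adaln(x_img, t_emb, self.ada["img2"], self.norm))
        x_txt = x_txt + self.mlp_txt(adaln(x_txt, t_emb, self.ada["txt2"], self.norm))
        return x_img, x_txt
```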
SD3 parameters breakdown:
| Component | Parameters | Notes |
|---|---|---|
| MM-DiT backbone | ~8B | 38 layers, hidden size 2432, 38 heads |
| CLIP ViT-L text encoder | ~0.12B | Frozen, provides global embedding |
| OpenCLIP ViT-bigG text encoder | ~0.7B | Frozen, provides global embedding |
| T5-XXL text encoder | ~4.7B | Frozen, provides per-token embeddings |
| VAE (encoder+decoder) | ~0.1B | Frozen, 8× spatial compression, 16 latent channels |
Note that the total system has ~14B parameters, but only ~8B are trained (the DiT). The text encoders and VAE are pretrained and frozen.
Meta's Movie Gen Video extends image generation to video. The key challenge: a video is a 4D tensor T×C×H×W, and even in latent space, the sequence lengths are enormous. Here are the design decisions and why they were made:
1. Temporal Autoencoder (TAE). Movie Gen's temporal autoencoder compresses each of H, W, and T by 8×. A 10-second, 24fps video has T=240 frames; after the TAE it has 240/8 = 30 latent frames, each 128×128 for 1024×1024 input frames. The patchified sequence length is $30\cdot(128/P)^2$ tokens: with P=2 that is $30\cdot 64^2 = 122{,}880$ tokens, far longer than the 4,096 tokens of a single 1024×1024 image latent.
2. Temporal tiling. To handle long videos without running out of memory, Movie Gen chops the video into overlapping temporal chunks, encodes each chunk separately with the TAE, and stitches the latents together. This allows processing videos of arbitrary length at a fixed memory cost.
3. Space-time patchification. Unlike image DiTs that patchify only in (H,W), Movie Gen patchifies in (T,H,W). Each patch covers multiple frames, allowing the model to capture temporal structure within patches and learn motion patterns through attention across patches.
4. Three text encoders. UL2 (general semantics), ByT5 (character-level detail for text rendering), and MetaCLIP (visual-semantic alignment). The diversity of encoders ensures that different aspects of the text prompt are captured faithfully.
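```python
import torch

# Movie Gen-style temporal tiling for long videos (simplified sketch).
# `tae.encode` stands in for a temporal autoencoder mapping (T_chunk, C, H, W)
# -> (T_chunk // t_compress, C_lat, H', W'); overlap must be a positive multiple of t_compress.
def encode_long_video(video, tae, chunk_size=32, overlap=8, t_compress=8):
    # video: (T, C, H, W), can be hundreds of frames
    T = video.shape[0]
    lat_overlap = overlap // t_compress               # overlap measured in latent frames
    out = None
    for start in range(0, T, chunk_size - overlap):
        latent = tae.encode(video[start:start + chunk_size])   # (T', C_lat, H', W')
        if out is None:
            out = latent
        else:
            # Stitch with a linear cross-fade over the overlapping latent frames
            w = torch.linspace(0, 1, lat_overlap + 2)[1:-1].view(-1, 1, 1, 1)
            blended = (1 - w) * out[-lat_overlap:] + w * latent[:lat_overlap]
            out = torch.cat([out[:-lat_overlap], blended, latent[lat_overlap:]], dim=0)
        if start + chunk_size >= T:                   # this chunk reached the end of the video
            break
    return out
```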
A key finding from the DiT paper and subsequent work is that diffusion transformers follow scaling laws similar to those discovered in language modeling (Kaplan et al., 2020). Specifically:
• Loss scales as a power law with the number of parameters, training compute, and dataset size.
• Larger models are more sample-efficient: they achieve the same loss with fewer training steps.
• FID improves monotonically with model size, at least up to the sizes tested (8B+ parameters).
This is why the field has shifted from U-Nets to DiTs: transformers have a proven recipe for scaling (just make them bigger), while U-Nets hit diminishing returns more quickly.
| Model | Architecture | Parameters | Year |
|---|---|---|---|
| DDPM | U-Net | ~100M | 2020 |
| Guided Diffusion | U-Net | ~500M | 2021 |
| Stable Diffusion 1.5 | U-Net | ~860M | 2022 |
| DiT-XL | DiT | ~675M | 2023 |
| Stable Diffusion 3 | MM-DiT | ~8B | 2024 |
| FLUX | DiT variant | ~12B | 2024 |
| Movie Gen Video | DiT | ~30B | 2024 |
| VEO-3 | DiT variant | Not disclosed | 2025 |
The trend is clear: models are getting bigger, and bigger models produce better results. The transition from U-Nets (at most a few billion parameters) to DiTs (up to ~30B parameters) enabled a qualitative leap in generation quality.
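Let's trace every step of generating an image with Stable Diffusion 3, from prompt to pixels: the prompt is encoded by three frozen text encoders, Gaussian noise is sampled in latent space, the MM-DiT integrates the ODE over 50 steps with classifier-free guidance, and the frozen VAE decoder maps the final latent to pixels.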
Memory and compute breakdown (approximate, for a single 1024×1024 image):
| Stage | FLOPs | Memory | Wall time |
|---|---|---|---|
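| Text encoding (3 models) | ~1 TFLOP (dominated by T5-XXL) | ~10 GB | ~0.1s |
| DiT (50 steps × 2 for CFG) | ~5 PFLOPs (~50 TFLOPs per pass × 100 passes) | ~20 GB | ~5s |
| VAE decode | a few TFLOPs | ~2 GB | ~0.05s |
| Total | ~5 PFLOPs | ~32 GB peak | ~5s |
The DiT dominates: 100 forward passes (50 steps × 2 evaluations per step for CFG) through an 8B parameter transformer. A transformer forward pass costs roughly 2 × (parameters used per token) × (tokens) FLOPs, so each pass over ~4,000 tokens costs tens of TFLOPs, and 100 passes add up to petaFLOPs. This is why optimization techniques like flash attention, mixed precision, and model parallelism are essential for practical deployment.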
Deploying these massive models at scale requires several critical optimizations:
1. Flash Attention. Standard attention requires $O(N^2)$ memory to materialize the attention matrix. Flash Attention computes exact attention with $O(N)$ memory using tiling, which is essential for long sequences (e.g., video with thousands of tokens).
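2. Mixed precision. BFloat16 halves memory relative to Float32 and substantially increases throughput on modern accelerators. DiT training typically runs the forward and backward passes in BF16 while keeping Float32 master weights for the optimizer updates.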
3. Model parallelism. For 10B+ parameter models, tensor parallelism splits attention heads across GPUs, pipeline parallelism splits layers.
4. CFG batching. Stack the conditioned and unconditioned inputs into a single batch of size 2B and run one forward pass instead of two sequential passes (~40% faster).
```python
import torch

# Batched CFG inference (efficient)
def cfg_step(model, z, t, text_emb, null_emb, w):
    z_double = torch.cat([z, z], dim=0)                # (2B, ...)
    t_double = torch.cat([t, t], dim=0)                # (2B,)
    y_double = torch.cat([text_emb, null_emb], dim=0)  # conditional half first, then null half
    u_double = model(z_double, t_double, y_double)     # ONE forward pass
    u_cond, u_uncond = u_double.chunk(2)               # split back in the same order
    return (1 - w) * u_uncond + w * u_cond
```
5. Step distillation. Multi-step sampling (50 steps) can be distilled into 4-step or 1-step models via consistency distillation, trading quality for 10-50× speedup.
6. Quantization. INT8/INT4 reduces model size 2-4×. Text encoders and VAE are especially amenable since they are frozen.
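Let's trace tensor shapes through the Stable Diffusion 3 pipeline for generating a 1024×1024 image. For readability, the table uses a hidden size of 1024 and 77 text tokens; the released 8B model is wider and uses a longer text sequence, so treat those entries as illustrative: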
| Stage | Operation | Shape |
|---|---|---|
| 1. Noise | $z_0 \sim \mathcal{N}(0, I)$ | (1, 16, 128, 128) |
| 2. Time embed | TimeEmb(t) | (1, 1024) |
| 3a. CLIP-L | CLIP(text) | (1, 768) → project to (1, 1024) |
| 3b. CLIP-bigG | CLIP(text) | (1, 1280) → project to (1, 1024) |
| 3c. T5-XXL | T5(text) | (1, 77, 4096) → project to (1, 77, 1024) |
| 4. Patchify latent | P=2, $N=(128/2)^2=4096$ | (1, 4096, 1024) |
| 5. Joint sequence | Cat patches + text | (1, 4096+77, 1024) = (1, 4173, 1024) |
| 6. MM-DiT ×38 layers | Self-attention + AdaLN | (1, 4173, 1024) → (1, 4173, 1024) |
| 7. Split | Extract image tokens | (1, 4096, 1024) |
| 8. Depatchify | Linear + reshape | (1, 16, 128, 128) |
| 9. VAE decode | Decoder network | (1, 3, 1024, 1024) — final image! |
The bottleneck is step 6: self-attention over 4,173 tokens produces $4173^2 \approx 17.4$M attention entries per head, per layer, per step. Multiplied across all heads, 38 layers, 50 steps, and 2× for CFG, the total attention computation is enormous, which is why Flash Attention is essential.
An important and often overlooked point: the diffusion model can never generate details finer than what the VAE decoder can reconstruct. If the VAE introduces blurriness or artifacts in its reconstructions, those same artifacts will appear in all generated images.
This is why the first Stable Diffusion papers invested heavily in autoencoder quality. The VAE reconstruction must be near-perfect at the target resolution. Typical metrics:
| Metric | Good VAE | Poor VAE |
|---|---|---|
| PSNR (pixel quality) | >30 dB | <25 dB |
| SSIM (structural similarity) | >0.95 | <0.85 |
| LPIPS (perceptual distance) | <0.05 | >0.15 |
| Latent dim ratio | 48×–64× | >128× (too compressed) |
If LPIPS is high (perceptual distance is large), the diffusion model will produce images that look "off" regardless of how well the diffusion model is trained. The autoencoder quality sets an upper bound on overall generation quality.
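PSNR, the first metric in the table, is simple enough to compute directly. A minimal sketch, assuming images scaled to [0, 1] (so the peak value is 1); SSIM and LPIPS need more machinery (LPIPS relies on a pretrained network) and are usually taken from existing packages.

```python
import torch


def psnr(x: torch.Tensor, x_hat: torch.Tensor, max_val: float = 1.0) -> torch.Tensor:
    """Peak signal-to-noise ratio in dB, 10 * log10(max_val^2 / MSE), one value per image."""
    mse = ((x - x_hat) ** 2).flatten(start_dim=1).mean(dim=1)
    return 10 * torch.log10(max_val**2 / mse)


# Example: a uniform reconstruction error of 0.01 gives MSE = 1e-4, i.e. about 40 dB
x = torch.rand(2, 3, 64, 64)
scores = psnr(x, x + 0.01)
```

By this measure, the ">30 dB" threshold in the table corresponds to a root-mean-square error below about 0.032 on a [0, 1] scale.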
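Let's compute the KL divergence for a concrete example. Suppose for a single image x, the encoder outputs a 3-dimensional latent with log-variances $\log\sigma^2 = [-0.5, 0.2, -1.0]$ and means whose squares are $\mu^2 = [0.25, 0.09, 0.64]$ (in particular $\mu_3 = 0.8$).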
So $\sigma^2 = [e^{-0.5}, e^{0.2}, e^{-1.0}] = [0.607, 1.221, 0.368]$.
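The KL divergence to $\mathcal{N}(0, I)$ is:

$$D_{\mathrm{KL}} = \frac12\sum_{j=1}^{3}\big(\sigma_j^2 + \mu_j^2 - \log\sigma_j^2 - 1\big)$$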
Computing each term:
| j | $\sigma_j^2$ | $\mu_j^2$ | $\log\sigma_j^2$ | Term |
|---|---|---|---|---|
| 1 | 0.607 | 0.250 | −0.500 | (0.607+0.250+0.500−1)/2 = 0.179 |
| 2 | 1.221 | 0.090 | 0.200 | (1.221+0.090−0.200−1)/2 = 0.056 |
| 3 | 0.368 | 0.640 | −1.000 | (0.368+0.640+1.000−1)/2 = 0.504 |
Position 3 contributes the most because its mean (0.8) is far from zero and its variance (0.368) is far from 1. The KL loss will push the encoder to bring this dimension closer to N(0,1).
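The per-dimension terms and their sum can be checked in a few lines of plain Python. Only μ² enters the formula, so the signs of μ₁ and μ₂ do not matter.

```python
import math

log_var = [-0.5, 0.2, -1.0]   # log sigma_j^2 from the example
mu_sq = [0.25, 0.09, 0.64]    # mu_j^2 (the signs of mu_j do not affect the KL)

# Per-dimension term (sigma^2 + mu^2 - log sigma^2 - 1) / 2, then the total
terms = [(math.exp(lv) + m2 - lv - 1) / 2 for lv, m2 in zip(log_var, mu_sq)]
total = sum(terms)
print([round(v, 3) for v in terms], round(total, 3))  # → [0.178, 0.056, 0.504] 0.738
```

With unrounded σ² the first term is 0.178 (the table's 0.179 comes from rounding σ₁² to 0.607 first), and the total KL is about 0.738, with dimension 3 contributing roughly two thirds of it.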
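Reparameterization in action. To sample z from this encoder distribution, draw $\varepsilon \sim \mathcal{N}(0, I_3)$ and set

$$z = \mu + \sigma \odot \varepsilon, \qquad \sigma = [e^{-0.25}, e^{0.1}, e^{-0.5}] = [0.779, 1.105, 0.607].$$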
This z is then passed to the decoder for reconstruction. The gradient flows through the deterministic operations (μ + σ⊙ε) to update the encoder parameters, while the randomness comes only from ε (independent of parameters).
The complete training pipeline for a latent diffusion model has two distinct phases:
Phase 1: Train the VAE (weeks of compute).
• Dataset: millions of images at target resolution
• Objective: reconstruction + KL regularization + perceptual loss
• Architecture: convolutional encoder and decoder (ResNet blocks)
• Duration: typically 1-2 weeks on 64+ GPUs
• Result: frozen encoder and decoder that compress 1024×1024 images to 128×128 latents
• Validation: check PSNR/SSIM/LPIPS on held-out images; inspect reconstructions visually
Phase 2: Train the DiT (months of compute).
• Dataset: billions of text-image pairs (LAION, internal datasets)
• Preprocessing: encode all images to latents using frozen VAE; compute text embeddings using frozen encoders
• Objective: conditional flow matching loss on latents
• Architecture: DiT/MM-DiT (billions of parameters)
• CFG: 10% label dropping rate
• Duration: months on thousands of GPUs
• Result: a model that can generate latents from noise, conditioned on text
Key insight about the two-phase approach: The VAE and DiT are completely decoupled. The VAE defines the "language" of the latent space; the DiT learns to "speak" in that language. A better VAE (sharper reconstructions, smoother latent space) directly translates to better generation quality, independent of the DiT.
Why not train end-to-end? Training the VAE and DiT jointly (end-to-end) is possible in principle but impractical for several reasons:
• Different training timescales. The VAE converges in weeks; the DiT needs months. Joint training would require carefully balancing two very different optimization dynamics.
• Memory. Backpropagating through the VAE at full pixel resolution for every training image would add large activation memory on top of an already huge DiT.
• Modularity. Decoupling allows upgrading the DiT without retraining the VAE, and vice versa. Many research groups share pretrained VAEs.
• Precomputation. With a frozen VAE, all training images can be pre-encoded to latents once and stored on disk, eliminating the VAE encoder from the training loop entirely.
Practical tip: latent caching. Encoding the entire training dataset (say, 5 billion images) through the VAE takes significant time but only needs to happen once. The latent tensors (4×128×128 = 65K float16 values = ~130KB per image) are much smaller than the original images (~500KB–5MB), making storage and loading faster during DiT training.
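To make the storage claim concrete, here is the arithmetic for a cache of 4×128×128 float16 latents, using the 5-billion-image figure above. It ignores file-format overhead and any further compression.

```python
# Back-of-envelope storage for a latent cache (float16 = 2 bytes per value)
channels, height, width = 4, 128, 128
bytes_per_latent = channels * height * width * 2       # 131,072 bytes, about 131 KB
num_images = 5_000_000_000                              # "say, 5 billion images"
total_tb = bytes_per_latent * num_images / 1e12         # decimal terabytes
print(f"{bytes_per_latent:,} bytes per latent, ~{total_tb:,.0f} TB total")
```

The cache still runs to hundreds of terabytes at this scale, but it is several times smaller than the source images and removes the VAE encoder from every training step.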
After patchification, the transformer receives a sequence of N tokens, but it has no information about where each patch came from in the original image. A patch from the top-left corner should be treated differently from one in the bottom-right. We add positional embeddings to each patch token.
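The most common approach is learned absolute positional embeddings: a learnable matrix $E_{\text{pos}} \in \mathbb{R}^{N\times d}$, with one row per patch position, added to the patch embeddings:

$$x_i \leftarrow x_i + E_{\text{pos},i}, \qquad i = 1, \dots, N$$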
Alternative approaches include:
| Method | Description | Advantage |
|---|---|---|
| Learned absolute | One learnable vector per position | Simple, effective |
| Sinusoidal 2D | Sine/cosine at (row, col) frequencies | No learned params, generalizes to new resolutions |
| RoPE (Rotary) | Rotate query/key vectors by position | Better length generalization |
| ALiBi | Linear bias in attention scores | Simple, no extra parameters |
For video generation, 3D positional embeddings encode the (time, row, col) position of each space-time patch, allowing the model to understand both spatial layout and temporal ordering.
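The sinusoidal 2D option in the table can be built by applying the standard 1D transformer encoding separately to row and column indices and concatenating the results. A minimal sketch, assuming d is divisible by 4 and patches are ordered row-major as in patchification; the base of 10000 follows the original transformer encoding, and the function names are illustrative.

```python
import torch


def sincos_1d(pos: torch.Tensor, dim: int) -> torch.Tensor:
    """Standard sinusoidal encoding of integer positions: (L,) -> (L, dim)."""
    omega = 1.0 / 10000 ** (torch.arange(dim // 2, dtype=torch.float32) / (dim // 2))
    args = pos.float()[:, None] * omega[None, :]            # (L, dim/2)
    return torch.cat([torch.sin(args), torch.cos(args)], dim=1)


def sincos_2d(grid_h: int, grid_w: int, d: int) -> torch.Tensor:
    """Fixed 2D embedding for a grid_h x grid_w patch grid: (grid_h * grid_w, d).

    The first d/2 channels encode the row index, the last d/2 the column index.
    """
    rows, cols = torch.meshgrid(torch.arange(grid_h), torch.arange(grid_w), indexing="ij")
    emb_row = sincos_1d(rows.flatten(), d // 2)
    emb_col = sincos_1d(cols.flatten(), d // 2)
    return torch.cat([emb_row, emb_col], dim=1)


# 16 x 16 patch grid (256x256 image, P=16), hidden size 768
pe = sincos_2d(16, 16, 768)   # (256, 768); add to (B, 256, 768) tokens by broadcasting
```

Because the embedding is a fixed function of (row, col), it can be evaluated for a larger grid at a new resolution without retraining anything, which is the advantage listed in the table. A 3D variant for video would split the channels three ways across (time, row, col).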
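The choice of noise schedule $(\alpha_t, \beta_t)$ affects training stability and sample quality. For latent diffusion, the straight-line schedule used in flow matching is the simplest and most common:

$$\alpha_t = t,\quad \beta_t = 1 - t,\qquad z_t = \alpha_t z + \beta_t \epsilon = t\,z + (1-t)\,\epsilon,\quad \epsilon \sim \mathcal{N}(0, I)$$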
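This schedule has the elegant property that the target velocity $z - \epsilon$ depends only on the data latent $z$ and the noise sample $\epsilon$, not on $t$. This simplifies both implementation and analysis. Stable Diffusion 3 uses this straight-line path (its paper parameterizes time in the opposite direction, from data to noise, which is equivalent).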
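Differentiating the straight-line path with respect to $t$ makes the time-independence explicit:

$$z_t = t\,z + (1-t)\,\epsilon \quad\Longrightarrow\quad \frac{d z_t}{dt} = z - \epsilon,$$

which is exactly the `u_target = z - eps` regression target used in the training code of this chapter.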
Some models use other schedules (e.g., the cosine schedule from improved DDPM, or shifted schedules) that spend more time at intermediate noise levels. The optimal schedule depends on the data domain and is often determined empirically.
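Changing how often each $t$ is visited during training has an effect similar to changing the schedule, while keeping the straight-line path itself. The sketch below swaps in a logit-normal density for $t$ (t = sigmoid(m + s·n), n ~ N(0, 1)) instead of U[0, 1]. The values m = 0 and s = 1 are illustrative assumptions, and this is one option alongside cosine or shifted schedules rather than a description of any particular model's recipe.

```python
import numpy as np


def sample_t_uniform(n, rng):
    """Baseline: t ~ U[0, 1]."""
    return rng.uniform(0.0, 1.0, size=n)


def sample_t_logit_normal(n, rng, m=0.0, s=1.0):
    """t = sigmoid(m + s * n), n ~ N(0, 1): mass concentrates around mid-range t."""
    x = rng.normal(m, s, size=n)
    return 1.0 / (1.0 + np.exp(-x))


def frac_mid(t, lo=0.25, hi=0.75):
    """Fraction of sampled times in the intermediate range (lo, hi)."""
    return float(np.mean((t > lo) & (t < hi)))


rng = np.random.default_rng(0)
t_uni = sample_t_uniform(100_000, rng)
t_ln = sample_t_logit_normal(100_000, rng)
print(f"t in (0.25, 0.75): uniform {frac_mid(t_uni):.2f}, logit-normal {frac_mid(t_ln):.2f}")

# In a training loop, a draw like these would replace t = torch.rand(B).
```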
The core operation in both DiTs and U-Nets (in their deeper, lower-resolution layers) is attention. Let's trace the computation for a single attention head in a DiT block:
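Given: input sequence $x \in \mathbb{R}^{N\times d}$ and projection matrices $W_Q, W_K, W_V \in \mathbb{R}^{d\times d_h}$, a single head computes

$$Q = xW_Q,\quad K = xW_K,\quad V = xW_V,\qquad A = \mathrm{softmax}\!\left(\frac{QK^\top}{\sqrt{d_h}}\right) \in \mathbb{R}^{N\times N},\qquad \mathrm{head}(x) = AV \in \mathbb{R}^{N\times d_h},$$

with the softmax applied row-wise.
The attention matrix $A$ is $N\times N$, where $A_{ij}$ tells us how much token $i$ should attend to token $j$. For self-attention, high $A_{ij}$ means "patch $i$ looks at patch $j$ for guidance on what velocity to predict."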
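For cross-attention (patches attending to text), the queries come from image patches but keys and values come from the text tokens $y \in \mathbb{R}^{S\times d_y}$, with $W_K, W_V \in \mathbb{R}^{d_y\times d_h}$:

$$Q = xW_Q,\quad K = yW_K,\quad V = yW_V,\qquad A = \mathrm{softmax}\!\left(\frac{QK^\top}{\sqrt{d_h}}\right) \in \mathbb{R}^{N\times S}$$

Here $A_{ij}$ tells us how much image patch $i$ attends to text token $j$. If patch $i$ attends strongly to the token "red", the value vector of "red" dominates that patch's output, steering its predicted velocity toward red content. In architectures with cross-attention layers, this is the mechanism by which text prompts guide image generation.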
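Both variants are the same computation; only the source of keys and values changes. A minimal NumPy sketch, with random matrices standing in for learned weights and arbitrary small sizes (N = 16 patches, S = 5 text tokens):

```python
import numpy as np


def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)  # subtract max for numerical stability
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)


def attention_head(x_q, x_kv, W_Q, W_K, W_V):
    """One scaled dot-product attention head.

    x_q:  (N, d)    tokens producing queries (image patches)
    x_kv: (M, d_kv) tokens producing keys and values: the patches themselves
                    for self-attention (M = N), text tokens for cross-attention (M = S)
    Returns the head output (N, d_h) and the attention matrix A (N, M).
    """
    Q, K, V = x_q @ W_Q, x_kv @ W_K, x_kv @ W_V
    d_h = Q.shape[-1]
    A = softmax(Q @ K.T / np.sqrt(d_h), axis=-1)  # row i: how patch i spreads its attention
    return A @ V, A


rng = np.random.default_rng(0)
N, S, d, d_text, d_h = 16, 5, 32, 24, 8
patches = rng.standard_normal((N, d))
text = rng.standard_normal((S, d_text))
W_Q = rng.standard_normal((d, d_h)) / np.sqrt(d)

# Self-attention: keys/values from the patches -> A is N x N
out_self, A_self = attention_head(patches, patches, W_Q,
                                  rng.standard_normal((d, d_h)) / np.sqrt(d),
                                  rng.standard_normal((d, d_h)) / np.sqrt(d))
# Cross-attention: keys/values from the text tokens -> A is N x S
out_cross, A_cross = attention_head(patches, text, W_Q,
                                    rng.standard_normal((d_text, d_h)) / np.sqrt(d_text),
                                    rng.standard_normal((d_text, d_h)) / np.sqrt(d_text))
print(A_self.shape, A_cross.shape)  # (16, 16) (16, 5)
```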
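Multi-head attention. In practice, $h$ parallel attention heads are used, each with dimension $d_h = d/h$. The outputs are concatenated and projected back to $\mathbb{R}^d$:

$$\mathrm{MHA}(x) = \mathrm{Concat}(\mathrm{head}_1, \dots, \mathrm{head}_h)\,W_O,\qquad W_O \in \mathbb{R}^{d\times d}$$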
Multiple heads allow the model to attend to different aspects simultaneously: one head might focus on texture, another on color, another on spatial relationships to the prompt words.
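Splitting into heads is just a reshape of the projected features. A minimal NumPy sketch of multi-head self-attention, assuming full-width $d\times d$ projection matrices whose columns are divided evenly among the $h$ heads (a common implementation choice); random matrices stand in for learned weights:

```python
import numpy as np


def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)  # subtract max for numerical stability
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)


def multi_head_self_attention(x, W_Q, W_K, W_V, W_O, h):
    """x: (N, d); W_Q, W_K, W_V, W_O: (d, d). Uses h heads of width d_h = d // h."""
    N, d = x.shape
    d_h = d // h

    def split_heads(m):
        # (N, d) -> (h, N, d_h): head j owns columns j*d_h ... (j+1)*d_h - 1
        return m.reshape(N, h, d_h).transpose(1, 0, 2)

    Q, K, V = split_heads(x @ W_Q), split_heads(x @ W_K), split_heads(x @ W_V)
    A = softmax(Q @ K.transpose(0, 2, 1) / np.sqrt(d_h), axis=-1)  # (h, N, N)
    heads = A @ V                                                   # (h, N, d_h)
    concat = heads.transpose(1, 0, 2).reshape(N, d)                 # Concat(head_1..head_h)
    return concat @ W_O                                             # project back to R^d


rng = np.random.default_rng(0)
d, h = 64, 8
x = rng.standard_normal((256, d))                                   # 256 patch tokens
W_Q, W_K, W_V, W_O = (rng.standard_normal((d, d)) / np.sqrt(d) for _ in range(4))
y = multi_head_self_attention(x, W_Q, W_K, W_V, W_O, h)             # (256, 64)
```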
Computational cost of attention. For $N$ patch tokens and $S$ text tokens, the approximate costs per layer are as follows; the attention rows count a single head of width $d_h$:
| Operation | FLOPs | Memory |
|---|---|---|
| Self-attention (patches) | $O(N^2 d_h)$ | $O(N^2)$, or $O(N)$ with FlashAttention |
| Cross-attention (patches ↔ text) | $O(N S d_h)$ | $O(N S)$, or $O(N)$ with FlashAttention |
| MLP (all $N$ tokens) | $O(N d\, d_{ff})$ | $O(N d_{ff})$ |
| AdaLN projections | $O(d^2)$ | $O(d)$ |
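For a model at SD3 scale, with $N = 4096$ patch tokens and $S = 77$ text tokens, self-attention ($4096^2 \approx 16.8$M attention scores per head) dominates cross-attention ($4096 \times 77 \approx 315$K) by roughly 50× ($4096/77 \approx 53$). This is why reducing $N$ (through larger patches or lower-resolution latents) has such a dramatic effect on inference speed.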
The resolution-compute tradeoff. Generating a 2048×2048 image (4× the pixels of a 1024×1024 image, 2× per side) requires 4× more latent tokens. Since attention is $O(N^2)$, the attention compute increases by 16×, while per-token costs such as the MLP grow only 4×. This quadratic scaling is the primary reason why very high-resolution generation remains challenging.
Strategies to mitigate this include:
• Generate at lower resolution, then upscale with a super-resolution model
• Tile-based generation with overlapping image tiles stitched together
• Windowed attention that restricts attention to local neighborhoods (linear cost)
• Hierarchical generation that first generates a low-res image, then refines regions
Each approach has tradeoffs between quality, coherence, and speed. The optimal strategy depends on the target resolution and the available compute budget.
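The token counts and ratios quoted above, plus the effect of windowed attention, can be checked with a little arithmetic. This sketch assumes an 8× VAE downsampling factor and 2×2 patches (which reproduce N = 4096 at 1024×1024) and counts attention-matrix entries; multiplying by $d_h$ gives FLOPs per head.

```python
def latent_tokens(image_size, vae_downsample=8, patch_size=2):
    """Patch tokens for a square image: pixels -> VAE latent -> P x P patches."""
    side = image_size // vae_downsample // patch_size
    return side * side


def attention_entries(n_tokens, window=None):
    """Size of the attention matrix per head per layer.

    Global attention compares every token with every other token: N^2.
    Windowed attention over non-overlapping w x w token windows compares each
    token only with the w^2 tokens in its window: N * w^2, linear in N.
    """
    if window is None:
        return n_tokens ** 2
    return n_tokens * window ** 2


n_1k, n_2k = latent_tokens(1024), latent_tokens(2048)
print(n_1k, n_2k)                                          # 4096 16384
print(attention_entries(n_2k) / attention_entries(n_1k))   # 16.0 (global attention)
print(attention_entries(n_2k, window=16) / attention_entries(n_1k, window=16))  # 4.0 (windowed)
print(attention_entries(n_1k) / (n_1k * 77))               # ~53: self- vs cross-attention, S = 77
```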
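```python
import torch
import torch.nn.functional as F

# Sketch only: VAE, DiT, text_encoder, perceptual_loss, null_embed,
# image_dataset and pair_dataset are placeholders for your own components.
# vae.encode is assumed to return (mu, logvar); learning rates are illustrative.


def kl_to_standard_normal(mu, logvar):
    # KL(N(mu, sigma^2) || N(0, I)), summed over latent dims, averaged over batch
    return 0.5 * (logvar.exp() + mu**2 - logvar - 1).flatten(1).sum(1).mean()


# Phase 1: VAE training (simplified)
vae = VAE(in_channels=3, latent_channels=4)
vae_opt = torch.optim.AdamW(vae.parameters(), lr=1e-4)
for x in image_dataset:
    mu, logvar = vae.encode(x)
    z = mu + torch.exp(0.5 * logvar) * torch.randn_like(mu)  # reparameterize
    x_hat = vae.decode(z)
    loss = F.mse_loss(x_hat, x) + 0.01 * kl_to_standard_normal(mu, logvar)
    loss = loss + perceptual_loss(x, x_hat)  # e.g. VGG features
    vae_opt.zero_grad(); loss.backward(); vae_opt.step()
torch.save(vae.state_dict(), "vae.pt")

# Phase 2: DiT training (freeze VAE, train DiT)
vae.eval().requires_grad_(False)  # frozen!
dit = DiT(layers=38, hidden=4096, heads=64)
dit_opt = torch.optim.AdamW(dit.parameters(), lr=1e-4)
for x, text in pair_dataset:
    with torch.no_grad():
        z, _ = vae.encode(x)        # frozen VAE: use the posterior mean
        y = text_encoder(text)      # frozen CLIP/T5, (B, S, d_text)
    B = z.shape[0]
    t = torch.rand(B, 1, 1, 1)      # one time per sample, broadcast over (C, H, W)
    eps = torch.randn_like(z)
    z_t = t * z + (1 - t) * eps
    u_target = z - eps
    # CFG label dropping: swap ~10% of prompts for the null embedding
    drop = torch.rand(B) < 0.1
    y[drop] = null_embed
    u_pred = dit(z_t, t.flatten(), y)
    loss = F.mse_loss(u_pred, u_target)
    dit_opt.zero_grad(); loss.backward(); dit_opt.step()
```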
Stable Diffusion 3 is one of the most influential image generators. Here's how it instantiates the latent diffusion recipe:
| Component | SD3 Choice |
|---|---|
| Probability path | Conditional flow matching (straight line) |
| Architecture | MM-DiT (multi-modal DiT) |
| Text encoders | 3 total: 2× CLIP + T5-XXL (frozen) |
| Autoencoder | Pretrained VAE (8× spatial downsampling; 16×128×128 latents for 1024×1024 images) |
| Guidance | Classifier-free guidance, w ∈ [2.0, 5.0] |
| Sampling | 50 Euler steps |
| Parameters | Up to 8 billion (largest variant) |
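The Guidance and Sampling rows combine into a short loop. The sketch below integrates the learned ODE with fixed-step Euler from t = 0 (noise) to t = 1 (data) and, assuming the usual classifier-free guidance convention, mixes velocities as u = (1 − w)·u(z | ∅) + w·u(z | y). `velocity` is any callable; the toy field in the usage lines, which transports everything to one target point per condition, is a hypothetical stand-in for a trained DiT.

```python
import numpy as np


def sample_euler_cfg(velocity, y, null_y, shape, w=4.0, n_steps=50, seed=0):
    """Fixed-step Euler integration of dz/dt = u(z, t) from t = 0 (noise) to t = 1.

    velocity(z, t, cond) is any callable returning an array shaped like z
    (a trained DiT in practice). Guidance mixes the two predictions as
        u = (1 - w) * u(z, t | null_y) + w * u(z, t | y).
    """
    rng = np.random.default_rng(seed)
    z = rng.standard_normal(shape)         # z_0 ~ N(0, I)
    h = 1.0 / n_steps
    for k in range(n_steps):
        t = k * h
        u = (1.0 - w) * velocity(z, t, null_y) + w * velocity(z, t, y)
        z = z + h * u                      # one Euler step
    return z                               # a latent model would now decode z with the VAE


# Hypothetical stand-in for a trained model: the exact straight-line velocity
# toward a single target point per condition, (target - z) / (1 - t).
targets = {"cat": np.array([2.0, -1.0]), None: np.array([0.0, 0.0])}
toy_velocity = lambda z, t, cond: (targets[cond] - z) / (1.0 - t)

z_plain = sample_euler_cfg(toy_velocity, "cat", None, shape=(2,), w=1.0)   # lands on [2, -1]
z_guided = sample_euler_cfg(toy_velocity, "cat", None, shape=(2,), w=4.0)  # pushed past it, to [8, -4]
```

With w = 1 the sampler follows the conditional field exactly; larger w extrapolates away from the unconditional prediction, which is what strengthens prompt adherence.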
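The MM-DiT (multi-modal DiT) extends the standard DiT by treating text as a second token stream rather than a cross-attention input. Image patch tokens and text tokens keep separate weights, but are concatenated into one sequence for a joint attention operation, so each modality can attend to the other. The text token sequence comes from the CLIP and T5 encoders, while pooled CLIP embeddings are combined with the time embedding as a global conditioning vector. This allows the model to use coarse global summaries (CLIP) alongside fine-grained per-word detail (T5).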
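The joint attention idea can be sketched as follows. This is a simplified illustration, not SD3's implementation: it assumes a single head, text tokens already projected to the same width d as the image tokens, and it omits AdaLN modulation, the output projections, and the MLPs. Random matrices stand in for learned weights.

```python
import numpy as np


def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)  # subtract max for numerical stability
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)


def joint_attention(img, txt, W_img, W_txt):
    """Simplified single-head joint attention over two token streams.

    img: (N, d) image tokens; txt: (S, d) text tokens (assumed already at width d).
    W_img, W_txt: dicts of separate 'Q', 'K', 'V' matrices of shape (d, d_h),
    one set per modality. The streams are projected with their own weights,
    concatenated, attended jointly, and split back apart.
    """
    Q = np.concatenate([img @ W_img["Q"], txt @ W_txt["Q"]])   # (N + S, d_h)
    K = np.concatenate([img @ W_img["K"], txt @ W_txt["K"]])
    V = np.concatenate([img @ W_img["V"], txt @ W_txt["V"]])
    A = softmax(Q @ K.T / np.sqrt(Q.shape[-1]))                # (N + S, N + S)
    out = A @ V
    n = img.shape[0]
    return out[:n], out[n:], A


rng = np.random.default_rng(0)
N, S, d, d_h = 16, 6, 32, 8
make_weights = lambda: {k: rng.standard_normal((d, d_h)) / np.sqrt(d) for k in "QKV"}
img_out, txt_out, A = joint_attention(rng.standard_normal((N, d)),
                                      rng.standard_normal((S, d)),
                                      make_weights(), make_weights())
# A[:N, N:] is image-to-text attention (what a cross-attention layer computes);
# A[N:, :N] is text-to-image attention, which cross-attention does not have.
```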
Video generation adds a temporal dimension: data lives in $\mathbb{R}^{T\times C\times H\times W}$, where $T$ is the number of frames. Movie Gen adapts the image pipeline:
| Component | Movie Gen Choice |
|---|---|
| Probability path | Straight-line flow matching ($\alpha_t = t$, $\beta_t = 1 - t$) |
| Autoencoder | Temporal AE: T×3×H×W → T'×C×H'×W' (8× downsample per axis) |
| Architecture | DiT with space+time patchification |
| Text encoders | 3: UL2 + ByT5 + MetaCLIP |
| Parameters | 30 billion |
The key challenge for video is the temporal autoencoder (TAE). It must compress both spatially and temporally, reducing memory so the DiT can process reasonable-length sequences. A temporal tiling procedure chops long videos into overlapping chunks, encodes each separately, then stitches the latents together.
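A minimal sketch of temporal tiling, assuming a linear cross-fade over the overlapping latent frames. The tile length, overlap, and blending weights here are illustrative assumptions rather than Movie Gen's published settings, and `encode` stands in for the temporal autoencoder's encoder.

```python
import numpy as np


def tile_starts(n, tile, stride):
    """Start indices of length-`tile` windows (step `stride`) covering [0, n)."""
    if n <= tile:
        return [0]
    starts = list(range(0, n - tile + 1, stride))
    if starts[-1] + tile < n:              # add a final window so the tail is covered
        starts.append(n - tile)
    return starts


def tiled_encode(frames, encode, tile=32, overlap=8, compress=8):
    """Encode a long video in overlapping temporal tiles and stitch the latents.

    frames: (T, ...) array. encode: maps (tile, ...) -> (tile // compress, ...).
    Latent frames produced by more than one tile are averaged with linear
    cross-fade weights. Assumes T >= tile and that T, tile and overlap are
    multiples of `compress`.
    """
    T = frames.shape[0]
    lt, lo = tile // compress, overlap // compress        # sizes in latent frames
    ramp = np.ones(lt)
    if lo > 0:
        fade = np.linspace(0.0, 1.0, lo + 2)[1:-1]        # weights strictly inside (0, 1)
        ramp[:lo], ramp[-lo:] = fade, fade[::-1]
    out, wsum = None, np.zeros(T // compress)
    for s in tile_starts(T, tile, tile - overlap):
        z = encode(frames[s:s + tile])                    # encode one chunk separately
        if out is None:
            out = np.zeros((T // compress,) + z.shape[1:])
        a = s // compress
        out[a:a + lt] += ramp.reshape(-1, *([1] * (z.ndim - 1))) * z
        wsum[a:a + lt] += ramp
    return out / wsum.reshape(-1, *([1] * (out.ndim - 1)))


# Hypothetical stand-in encoder: average every 8 frames (8x temporal compression)
pool8 = lambda x: x.reshape(x.shape[0] // 8, 8, *x.shape[1:]).mean(axis=1)
video = np.random.default_rng(0).standard_normal((120, 3, 4, 4))   # (T, C, H, W)
latents = tiled_encode(video, pool8)                               # (15, 3, 4, 4)
```

Because this toy encoder only looks at 8 frames at a time, the stitched latents match encoding the whole clip at once; a real temporal autoencoder has a wider receptive field, which is why overlap and blending are used to hide seams between tiles.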
Why video is so much harder. A single 1024×1024 image has ~3M values (about 1M pixels × 3 color channels). A 10-second video at 24fps has 240 frames × ~3.1M ≈ 755M values. Even in latent space with 8× temporal compression and 8×8 spatial compression, the latent has 30×128×128 ≈ 500K positions per latent channel, i.e. several million values once the channels are counted. With 2×2 spatial patches, for example, that is still 30×64×64 ≈ 123K tokens, so the DiT must process sequences of tens of thousands of tokens or more, making efficient attention mechanisms critical.
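The sizes above can be reproduced with a few lines of arithmetic. The 16 latent channels and the 1×2×2 (time × height × width) patch size are illustrative assumptions; the duration, frame rate, resolution, and compression factors are the ones in the text.

```python
def video_sizes(seconds=10, fps=24, height=1024, width=1024, channels=3,
                t_down=8, s_down=8, latent_channels=16, patch_hw=2):
    """Raw and latent sizes of a video clip, plus its DiT token count."""
    frames = seconds * fps
    pixel_values = frames * channels * height * width
    lt, lh, lw = frames // t_down, height // s_down, width // s_down
    latent_positions = lt * lh * lw                         # per latent channel
    latent_values = latent_positions * latent_channels
    tokens = lt * (lh // patch_hw) * (lw // patch_hw)       # 1 x 2 x 2 space-time patches
    return frames, pixel_values, latent_positions, latent_values, tokens


frames, pixels, positions, latent_vals, tokens = video_sizes()
print(frames, pixels)            # 240 754974720
print(positions, latent_vals)    # 491520 7864320
print(tokens)                    # 122880
```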
Text encoders in Movie Gen. The three encoders capture different linguistic aspects:
| Encoder | Strength | Example use |
|---|---|---|
| UL2 | Semantic reasoning, logical structure | "A person juggling while riding a bicycle" |
| ByT5 | Character-level detail | "A sign that says HELLO" (renders text correctly) |
| MetaCLIP | Visual-semantic alignment | Overall style and scene composition |
Using multiple encoders simultaneously allows the model to handle both high-level semantics and low-level detail in text prompts — a major improvement over single-encoder systems.
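A simple way to feed several encoders to one DiT is to project each encoder's token sequence to a shared width and concatenate the sequences, so attention can draw on all of them. This is a hedged sketch of that idea, not Movie Gen's exact recipe: the projection matrices would be learned in practice, and the sequence lengths and widths below are hypothetical.

```python
import numpy as np


def combine_text_encoders(encoder_outputs, projections):
    """Project each encoder's tokens to a shared width, then concatenate.

    encoder_outputs: list of (S_k, d_k) arrays, one per text encoder.
    projections:     list of (d_k, d_model) matrices (learned in practice).
    Returns one (sum_k S_k, d_model) context sequence for the DiT to attend to.
    """
    return np.concatenate([e @ W for e, W in zip(encoder_outputs, projections)], axis=0)


rng = np.random.default_rng(0)
shapes = [(128, 4096), (64, 1472), (77, 1024)]   # hypothetical (tokens, width) per encoder
outputs = [rng.standard_normal(s) for s in shapes]
d_model = 1024
projections = [rng.standard_normal((d_k, d_model)) / np.sqrt(d_k) for _, d_k in shapes]
context = combine_text_encoders(outputs, projections)   # (269, 1024)
```

Concatenating along the sequence axis keeps every encoder's per-token detail available to attention; pooling each encoder into a single vector would be cheaper but would discard word-level information.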
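```python
import torch

# Sketch only: vae_encoder, vae_decoder, recon, kl, beta, dit, text_encoder,
# null_embed, optimizer and the datasets are placeholders for your own components.

# Step 1: Train VAE (separately, before diffusion training)
for x in image_dataset:
    mu, logvar = vae_encoder(x)                   # each (B, 4, 128, 128)
    eps = torch.randn_like(mu)
    z = mu + torch.exp(0.5 * logvar) * eps        # reparameterize
    x_hat = vae_decoder(z)                        # (B, 3, 1024, 1024)
    loss = recon(x, x_hat) + beta * kl(mu, logvar)
    # ... optimize, then freeze VAE

# Step 2: Precompute latents once (posterior mean, no gradients)
with torch.no_grad():
    latent_dataset = [vae_encoder(x)[0] for x in image_dataset]

# Step 3: Train DiT on latents
for z, y in latent_dataloader:                    # z: (B, 4, 128, 128), y: text
    B = z.shape[0]
    t = torch.rand(B)
    t_b = t.view(B, 1, 1, 1)                      # broadcast over (C, H, W)
    eps = torch.randn_like(z)
    z_t = t_b * z + (1 - t_b) * eps               # noisy latent
    u_target = z - eps                            # target velocity
    # CFG label dropping
    mask = torch.rand(B) < 0.1
    with torch.no_grad():
        y_embed = text_encoder(y)                 # frozen CLIP/T5
    y_embed[mask] = null_embed
    # DiT forward pass
    u_pred = dit(z_t, t, y_embed)                 # (B, 4, 128, 128)
    loss = ((u_pred - u_target) ** 2).mean()
    optimizer.zero_grad(); loss.backward(); optimizer.step()
```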
The full pipeline: encode a 2D data point into latent space, add noise, denoise with a learned vector field, decode back.
This chapter covered the engineering stack that transforms the elegant theory of flow matching into real-world image and video generators. Let's take stock.
| Component | Purpose | Key Idea |
|---|---|---|
| Fourier time embedding | Give t to the network | Sinusoidal features at log-spaced frequencies |
| Prompt embedding | Give y to the network | Frozen pretrained encoders (CLIP, T5) |
| Patchification | Convert images to token sequences | Non-overlapping P×P patches + linear projection |
| DiT blocks | Process patch tokens | Self-attention + cross-attention + AdaLN |
| U-Net | Alternative to DiT | Convolutional encoder-decoder with skip connections |
| VAE | Compress to latent space | Stochastic encoder + KL regularization toward N(0,I) |
| Latent diffusion | Train flow model in latent space | 48× dimensionality reduction (3×1024×1024 → 4×128×128), semantic focus |
What lies ahead. Chapter 7 extends flow matching beyond continuous data entirely. Instead of images (vectors in $\mathbb{R}^d$), we'll model discrete sequences like text. The principles are the same (interpolate between noise and data, train on conditional targets), but the mathematical machinery changes from ODEs to continuous-time Markov chains.
Practical takeaways:
• DiTs scale better than U-Nets — doubling parameters consistently improves quality, mirroring the scaling laws of language models.
• Text encoding matters enormously: adding the T5-XXL encoder to SD3's two CLIP encoders markedly improved prompt following, especially for complex prompts and for rendering text inside images.
• The autoencoder quality is a ceiling — the diffusion model can never generate details finer than what the decoder can reconstruct.
• Videos are just images + time — the same tools (DiT, VAE, CFG) extend to video by adding temporal dimensions to patches and autoencoders.
For reference, the essential formulas from this chapter:
| Component | Formula | Purpose |
|---|---|---|
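| Fourier embedding | $\mathrm{TimeEmb}(t) = \sqrt{2/d}\,[\cos(2\pi w_i t),\ \sin(2\pi w_i t)]$ | Embed scalar $t$ into $\mathbb{R}^d$ |
| Patchification | $\mathrm{PatchEmb}(x) = \mathrm{Patchify}(x)\,W \in \mathbb{R}^{N\times d}$ | Convert image to token sequence |
| Depatchification | $u = \mathrm{Depatchify}(\tilde{x}_L \tilde{W}) \in \mathbb{R}^{C\times H\times W}$ | Convert tokens back to image |
| AdaLN | $\mathrm{AdaNorm}(x) = (1+\gamma)\odot\mathrm{Norm}(x) + \beta$ | Time conditioning |
| AE reconstruction | $\mathcal{L}_{\text{Recon}} = \mathbb{E}\,\lVert \mu_\theta(\mu_\phi(x)) - x \rVert^2$ | Train autoencoder |
| VAE KL | $D_{\text{KL}} = \tfrac{1}{2}\sum \left[\sigma^2 + \mu^2 - \log\sigma^2 - 1\right]$ | Regularize latent toward $\mathcal{N}(0, I)$ |
| VAE total | $\mathcal{L}_{\text{VAE}} = \mathcal{L}_{\text{Recon}} + \beta\,\mathcal{L}_{\text{KL}}$ | Joint reconstruction + regularization |
| Reparameterization | $z = \mu_\phi(x) + \sigma_\phi(x)\odot\epsilon,\ \epsilon\sim\mathcal{N}(0, I)$ | Enable gradients through sampling |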
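The AdaLN, KL, and reparameterization rows translate directly into NumPy. This is a sketch, assuming Norm is a LayerNorm without learned affine parameters and that γ and β have already been produced from the time embedding (the small network that produces them is omitted).

```python
import numpy as np


def ada_norm(x, gamma, beta, eps=1e-6):
    """AdaNorm(x) = (1 + gamma) * Norm(x) + beta.

    Norm is a LayerNorm over the last axis without its own affine parameters;
    gamma and beta are per-channel vectors produced from the time embedding.
    """
    mu = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (1.0 + gamma) * (x - mu) / np.sqrt(var + eps) + beta


def vae_kl(mu, logvar):
    """D_KL(N(mu, sigma^2) || N(0, I)) = 1/2 * sum(sigma^2 + mu^2 - log sigma^2 - 1)."""
    return 0.5 * np.sum(np.exp(logvar) + mu**2 - logvar - 1.0)


def reparameterize(mu, logvar, rng):
    """z = mu + sigma * eps with eps ~ N(0, I) and sigma = exp(logvar / 2)."""
    return mu + np.exp(0.5 * logvar) * rng.standard_normal(mu.shape)


rng = np.random.default_rng(0)
x = rng.standard_normal((4096, 768))                        # N tokens of width d
h = ada_norm(x, gamma=np.zeros(768), beta=np.zeros(768))    # zero modulation: plain LayerNorm
print(vae_kl(np.zeros(10), np.zeros(10)))                   # 0.0: posterior equals the prior
z = reparameterize(np.zeros(4), np.zeros(4), rng)           # a draw from N(0, I)
```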
The complete stack, from theory to production:
"In latent space, the generative model can focus on what matters — objects, colors, composition — rather than reproducing every imperceptible pixel detail." — Rombach et al., 2022