Esser, Kulal, Blattmann, Entezari, et al. (Stability AI) — 2024

Stable Diffusion 3

Scaling rectified flow transformers for high-resolution image synthesis — straight-line trajectories from noise to data, a multimodal DiT with separate text/image streams joined by attention, and predictable scaling laws.

Prerequisites: Diffusion models + Transformers + Latent representations
10 chapters, 5+ simulations

Chapter 0: The Problem

By early 2024, diffusion models dominated image generation. SDXL, DALL-E 3, and Midjourney could produce stunning images from text. But three stubborn problems remained:

  1. Slow sampling. Standard diffusion models like DDPM define curved paths from noise to data. These curved trajectories need many neural network evaluations (often 50+) to follow accurately. Each evaluation is a full forward pass through a billion-parameter model — expensive.
  2. Weak text understanding. U-Net-based models feed text through cross-attention, treating it as a static conditioning signal. The text representation never gets to "talk back" to the image representation. Result: poor typography, missed objects, wrong spatial arrangements.
  3. Unclear scaling. Unlike language models, where doubling parameters predictably improves loss, diffusion models with U-Net backbones had no clean scaling laws. You couldn't predict whether making the model bigger would actually help.
The goal: Build a text-to-image system with (1) straight-line trajectories that need fewer sampling steps, (2) an architecture where text and image representations deeply interact, and (3) predictable scaling behavior. SD3 delivers all three by combining rectified flow with a novel Multimodal DiT (MMDiT) architecture.

Chapter 1: Rectified Flow

All diffusion-style models share the same core idea: define a forward process that turns data into noise, then learn to reverse it. The question is what path the forward process takes.

The forward process

Given a data sample x0 and noise ε ~ N(0, I), define the noisy sample at time t as:

$$z_t = (1 - t)\,x_0 + t\,\varepsilon$$

At t = 0, z0 = x0 (pure data). At t = 1, z1 = ε (pure noise). For any t in between, zt is a linear interpolation between data and noise. This is the defining property of rectified flow — the path from data to noise is a straight line.

Why straight lines matter

Compare this to DDPM, which uses a variance-preserving schedule: zt = αt x0 + σt ε, with αt² + σt² = 1. The coefficient pair (αt, σt) moves along a quarter circle rather than a straight line. As a result, the DDPM path from data to noise is a curved arc, and signal and noise are not mixed linearly.

Why does this matter? Because at inference time, you need to follow these paths backwards. A curved path requires many small steps to trace accurately (like driving on a winding mountain road). A straight path can theoretically be traversed in a single step (like a highway).

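The key insight: Rectified flows connect each data sample to its noise sample along a straight line, and the velocity along that line is constant: v = ε − x0. A neural network learns to predict this velocity. Many (x0, ε) pairs pass through the same point zt, so what the network actually learns is their average velocity. Sampled trajectories are therefore only approximately straight. They are still close enough to straight that even a crude Euler solver with few steps tracks them well. Fewer steps = faster sampling.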
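To see why a straight path is cheap to follow, here is a minimal NumPy sketch of an Euler sampler. It runs on a deliberately trivial toy: a "dataset" containing a single point x0. In that special case the exact velocity field is known in closed form, (zt − x0)/t, so no network is needed. The names `euler_sample` and `exact_v` are illustrative, not taken from SD3's code. With a real dataset the learned field averages over many pairs and trajectories bend, so one step is no longer exact.

```python
import numpy as np

def euler_sample(velocity, z1, n_steps):
    """Integrate dz/dt = v(z, t) backwards from t = 1 (noise) to t = 0 (data)."""
    ts = np.linspace(1.0, 0.0, n_steps + 1)
    z = z1
    for t, t_next in zip(ts[:-1], ts[1:]):
        dt = t - t_next                  # positive step size
        z = z - dt * velocity(z, t)      # z_{t - dt} = z_t - dt * v(z_t, t)
    return z

# Toy "dataset" with a single point x0. On the straight path
# z_t = (1 - t) * x0 + t * eps the velocity is eps - x0, which equals (z_t - x0) / t.
rng = np.random.default_rng(0)
x0 = rng.normal(size=4)
eps = rng.normal(size=4)
exact_v = lambda z, t: (z - x0) / t      # t > 0 at every evaluation point

one_step = euler_sample(exact_v, eps, n_steps=1)     # a single Euler step
five_steps = euler_sample(exact_v, eps, n_steps=5)
```

Both calls land on x0, up to floating-point error. When the path is exactly straight, the number of steps does not matter.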

The training objective

The network vΘ learns the velocity field by minimizing the conditional flow matching loss:

$$\mathcal{L}_{\mathrm{CFM}} = \mathbb{E}_{t,\,x_0,\,\varepsilon}\,\big\| v_\Theta(z_t, t) - (\varepsilon - x_0) \big\|^2$$

The target velocity ε − x0 is constant for a given (x0, ε) pair: it does not depend on t. Along any single path, the network therefore regresses onto one fixed vector.
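A minimal NumPy sketch of a Monte Carlo estimate of this loss for a batch of flattened samples. `cfm_loss` and the toy predictors are hypothetical. t is drawn uniformly here, whereas SD3 samples it from a logit-normal distribution (Chapter 4). With all-zero toy data, zt = t ε, so the exact velocity is zt / t, which gives a known optimum to check against.

```python
import numpy as np

def cfm_loss(v_model, x0, rng):
    """Monte Carlo estimate of L_CFM for a batch x0 of shape (B, D)."""
    b = x0.shape[0]
    t = rng.uniform(size=(b, 1))       # uniform here; SD3 samples t from a logit-normal
    eps = rng.normal(size=x0.shape)    # eps ~ N(0, I)
    z_t = (1 - t) * x0 + t * eps       # straight-line interpolation
    target = eps - x0                  # constant velocity, independent of t
    return np.mean(np.sum((v_model(z_t, t) - target) ** 2, axis=1))

rng = np.random.default_rng(0)
x0 = np.zeros((10_000, 8))             # toy data: every sample is the origin
loss_zero = cfm_loss(lambda z, t: np.zeros_like(z), x0, rng)
# For this toy data z_t = t * eps, so the exact velocity is z_t / t.
loss_exact = cfm_loss(lambda z, t: z / t, x0, rng)
```

The zero predictor pays E‖ε‖², which is about 8 (the dimension). The exact field drives the loss to essentially zero.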

Rectified Flow vs DDPM Trajectories

Rectified flow (teal) takes a straight path from data to noise. DDPM (orange) follows a curved variance-preserving path, which needs more solver steps to trace accurately.

Chapter 2: The MMDiT Architecture

Previous text-to-image models (Stable Diffusion 1/2, SDXL) use a U-Net backbone with cross-attention: the image features attend to text features, but the text features are never updated by the network. They never get to see what the image is doing. This is a one-way street.

SD3 replaces the U-Net with a Multimodal Diffusion Transformer (MMDiT), which treats text and image as two parallel streams that interact through joint attention.

Two streams, one attention

The MMDiT block has two independent pathways:

  1. Image stream: Processes patch embeddings of the noisy latent zt
  2. Text stream: Processes token embeddings from the text encoders

Each stream has its own LayerNorm, linear projections for Q/K/V, and MLP. But for the attention operation, the two sequences are concatenated and attention is computed jointly. This means every image patch can attend to every text token, and every text token can attend to every image patch.
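A minimal single-layer NumPy sketch of joint attention. It uses hypothetical random weights and omits LayerNorm, modulation, QK-normalization, and the MLP. Each stream projects with its own weights, the two sequences are concatenated for one multi-head attention over all tokens, and the result is split back into the streams. The final lines check the property that distinguishes this from cross-attention: perturbing only the image tokens changes the text stream's output.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)  # numerical stability
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def init_stream(rng, hidden):
    # Hypothetical per-stream weights: a fused Q/K/V projection and an output projection.
    return {"qkv": rng.normal(0, hidden ** -0.5, (hidden, 3 * hidden)),
            "out": rng.normal(0, hidden ** -0.5, (hidden, hidden))}

def joint_attention(img, txt, w_img, w_txt, n_heads):
    """img: (N_img, h) image tokens, txt: (N_txt, h) text tokens."""
    n_img, hidden = img.shape
    head_dim = hidden // n_heads
    # 1. Each stream projects with its OWN weights, then the sequences are concatenated.
    qkv = np.concatenate([img @ w_img["qkv"], txt @ w_txt["qkv"]], axis=0)
    n = qkv.shape[0]
    q, k, v = (a.reshape(n, n_heads, head_dim).transpose(1, 0, 2)
               for a in np.split(qkv, 3, axis=-1))
    # 2. One attention over all N_img + N_txt tokens: every token sees every token.
    weights = softmax(q @ k.transpose(0, 2, 1) / np.sqrt(head_dim))  # (heads, N, N)
    out = (weights @ v).transpose(1, 0, 2).reshape(n, hidden)
    # 3. Split back into the two streams; each has its own output projection.
    return out[:n_img] @ w_img["out"], out[n_img:] @ w_txt["out"]

rng = np.random.default_rng(0)
hidden, heads = 64, 4
w_img, w_txt = init_stream(rng, hidden), init_stream(rng, hidden)
img = rng.normal(size=(16, hidden))   # e.g. 16 image patches
txt = rng.normal(size=(5, hidden))    # e.g. 5 text tokens
out_img, out_txt = joint_attention(img, txt, w_img, w_txt, heads)
# Perturb only the image tokens: the TEXT outputs change too.
out_img2, out_txt2 = joint_attention(img + 1.0, txt, w_img, w_txt, heads)
```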

Why this matters: In cross-attention (SDXL), text tokens are read-only — the image attends to text, but text never updates based on the image. In MMDiT's joint attention, both modalities evolve together. The text representation can adapt based on what the image looks like at each denoising step. This bidirectional flow is why SD3 excels at typography, spatial reasoning, and compositional prompts.

Modulation via timestep

Both streams are modulated by the diffusion timestep t and a pooled text embedding y (from CLIP). These modulation parameters (α, β, γ) scale and shift the hidden states and gate each residual branch, as in DiT's adaptive layer norm (adaLN). Each stream gets its own set of modulation parameters, so the same timestep signal can affect text and image processing differently.
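A minimal NumPy sketch of DiT-style adaLN modulation for one sublayer of one stream. It assumes a single linear layer (after SiLU) maps the conditioning vector to shift, scale, and gate. A full block would compute one such set for attention and another for the MLP. All weights are random or zero placeholders, and the function names are illustrative.

```python
import numpy as np

def layer_norm(x, eps=1e-6):
    # LayerNorm without its own affine parameters; adaLN supplies scale/shift instead.
    mu = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mu) / np.sqrt(var + eps)

def silu(x):
    return x / (1.0 + np.exp(-x))

def adaln_params(cond, w, b):
    """cond: (h,) conditioning vector (timestep embedding + projected pooled y).
    w: (h, 3h), b: (3h,) belong to ONE stream, so each stream gets its own values."""
    shift, scale, gate = np.split(silu(cond) @ w + b, 3)
    return shift, scale, gate

def modulated_residual(x, cond, w, b, sublayer):
    # x + gate * sublayer(LN(x) * (1 + scale) + shift), the DiT-style adaLN pattern.
    shift, scale, gate = adaln_params(cond, w, b)
    return x + gate * sublayer(layer_norm(x) * (1 + scale) + shift)

rng = np.random.default_rng(0)
h = 8
x = rng.normal(size=(4, h))            # 4 tokens in one stream
cond = rng.normal(size=h)              # shared conditioning signal
zero_w, zero_b = np.zeros((h, 3 * h)), np.zeros(3 * h)

# DiT's adaLN-Zero initialization: zero weights -> gate = 0 -> the block is the identity.
y_init = modulated_residual(x, cond, zero_w, zero_b, np.tanh)

# Separate per-stream weights turn the same cond into different modulation.
w_img, w_txt = 0.1 * rng.normal(size=(h, 3 * h)), 0.1 * rng.normal(size=(h, 3 * h))
shift_img, scale_img, _ = adaln_params(cond, w_img, zero_b)
shift_txt, scale_txt, _ = adaln_params(cond, w_txt, zero_b)
```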

Scaling the model

The model size is controlled by a single parameter: depth d (number of MMDiT blocks). The hidden dimension is 64d, the MLP expands to 4 × 64d, and there are d attention heads. This gives a clean scaling axis: d=15 is ~500M params, d=24 is ~2B, d=38 is ~8B.
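As a sanity check on those sizes, here is a rough parameter count derived from d alone. It assumes DiT-style blocks in which each stream has a fused Q/K/V projection, an output projection, a 4× MLP, and an adaLN layer mapping the h-dim conditioning vector to six modulation vectors. It ignores biases, norms, embedders, and the final layer, so it is an estimate, not SD3's exact count.

```python
def approx_mmdit_params(depth: int) -> int:
    """Rough MMDiT weight count from depth d (biases, norms, embedders ignored)."""
    h = 64 * depth                      # hidden size = 64 * d
    per_stream = (
        3 * h * h                       # fused Q/K/V projection
        + h * h                         # attention output projection
        + 2 * (4 * h) * h               # MLP: h -> 4h -> h
        + 6 * h * h                     # adaLN: conditioning (h) -> 6 modulation vectors
    )
    return depth * 2 * per_stream       # two streams (image + text) per block

for d in (15, 24, 38):
    print(f"d={d}: ~{approx_mmdit_params(d) / 1e9:.2f}B")
```

This gives about 0.50B, 2.04B, and 8.09B, in line with the ~500M / ~2B / ~8B figures. The count grows as d³ because both the width and the number of blocks scale with d.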

MMDiT Block Architecture

Two parallel streams (image in teal, text in orange) with separate weights for LayerNorm, Q/K/V projections, and MLPs. They merge only for the joint attention computation, then split back into their own streams.

The full forward pass with shapes (1024×1024 generation, d=38)

Noise zt: 128 × 128 × 16 latent (from the 16-channel VAE).
Patchify (2×2): 128×128×16 → 64×64 = 4096 image tokens, each of dim 16×4 = 64 → projected to hidden dim 64×38 = 2432.
Text context: 154 text tokens in the paper's setup (77 from the two CLIP encoders + 77 from T5; more if T5 is given a longer context), each projected to dim 2432 → text stream.
38 × MMDiT block: joint attention over the concatenation [4096 image; 154 text] = 4250 tokens. Each block: LN + QKV (separate weights per stream) → joint attention → split → residual → LN + MLP (4×2432 = 9728 hidden) → residual. 38 attention heads.
Unpatchify: 4096 tokens × 2432 → linear to 64 → reshape to 128×128×16 → velocity prediction vθ.
Euler step: $z_{t-\Delta t} = z_t - \Delta t \cdot v_\theta(z_t, t)$, repeated for N steps.
VAE decoder: 128×128×16 → 1024×1024×3 RGB image.
Engineering decisions: Why MMDiT joint attention instead of cross-attention? Cross-attention keeps text fixed: text tokens never update based on what the image looks like. In compositional prompts ("red ball LEFT of blue cube"), the model needs to correlate spatial image regions with specific text tokens, which benefits from bidirectional flow. Joint attention enables this at the cost of a longer sequence (4096 + 154 = 4250 tokens instead of 4096). That is about 4% more tokens and roughly 8% more attention FLOPs, since attention cost grows with the square of the sequence length. The payoff shows up in compositional accuracy: the full SD3 system reaches GenEval 0.74 vs SDXL's 0.55. That gap also reflects the T5 encoder, improved captions, scale, and DPO, not joint attention alone.

QK-Normalization

When finetuning at high resolutions, the attention logits can grow uncontrollably. The softmax then puts almost all attention weight on a single token (attention-entropy collapse). SD3 applies RMSNorm to Q and K before computing attention. This stabilizes training and allows efficient bf16 mixed-precision training even at 1024×1024.
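A minimal NumPy sketch of QK-normalization for a single head. It assumes RMSNorm over the head dimension with a per-channel gain, fixed to 1 here; whether and how the gain is learned in SD3 is an assumption. The example scales Q and K up 100× to mimic runaway activations and compares raw and normalized logits.

```python
import numpy as np

def rms_norm(x, gain, eps=1e-6):
    # RMSNorm over the last (head) dimension with a per-channel gain.
    return x / np.sqrt(np.mean(x ** 2, axis=-1, keepdims=True) + eps) * gain

def attention_logits(q, k, qk_norm=False):
    """q, k: (N, head_dim) for one head. Returns the (N, N) pre-softmax logits."""
    if qk_norm:
        gain = np.ones(q.shape[-1])    # separate learnable gains for Q and K in practice
        q, k = rms_norm(q, gain), rms_norm(k, gain)
    return q @ k.T / np.sqrt(q.shape[-1])

rng = np.random.default_rng(0)
q, k = rng.normal(size=(32, 16)), rng.normal(size=(32, 16))
# Simulate activations whose scale has drifted upward during training.
raw = attention_logits(100 * q, 100 * k)
normed = attention_logits(100 * q, 100 * k, qk_norm=True)
```

Without normalization the logits reach the thousands, so the softmax is effectively one-hot. With QK-norm they are bounded by √head_dim = 4 here and no longer depend on the overall scale of Q and K.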


Chapter 3: Text Conditioning

SD3 doesn't rely on a single text encoder. It uses three pretrained text models, each capturing different aspects of the prompt:

CLIP-L/14

OpenAI's CLIP with a ViT-L/14 image encoder. Produces a 77-token sequence embedding and a pooled 768-dim vector. Good at visual-semantic alignment — it knows what words "look like" because it was trained on image-text pairs.

CLIP-G/14

A larger CLIP variant (ViT-bigG/14). Also produces 77 tokens + a pooled vector. The larger model captures more nuanced visual concepts.

T5-XXL

Google's T5-XXL text-to-text transformer; SD3 uses only its encoder, about 4.7B parameters. Unlike CLIP (which is trained on image-text pairs), T5 is a pure language model. It understands complex sentence structure, negation, spatial relationships, and counting, precisely the things CLIP struggles with.

How they combine

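The two CLIP models provide the pooled text embedding y. Their pooled vectors are concatenated (768 + 1280 = 2048 dims) and, combined with the timestep embedding, drive the modulation in every MMDiT block. The sequence outputs are combined differently. The two CLIP sequences are concatenated along the channel axis (77 tokens × 2048), zero-padded to T5's 4096 channels, and then concatenated along the sequence axis with the T5 tokens. In the paper this yields a 154 × 4096 context sequence c, which forms the text stream in MMDiT's joint attention.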

Exact data flow with shapes

Text prompt: "A golden retriever on a red couch" → tokenize.
CLIP-L/14: 77 tokens × 768 dim + pooled vector 768-dim. ~124M params, frozen.
CLIP-G/14: 77 tokens × 1280 dim + pooled vector 1280-dim. ~695M params (text tower only), frozen.
T5-XXL: 77 tokens × 4096 dim in the paper (common inference code allows up to 256 tokens). ~4.7B params (encoder only), frozen.
Pooled embedding y: concat CLIP-L pool (768) + CLIP-G pool (1280) = 2048 dims → MLP to the MMDiT hidden dim, added to the timestep embedding → modulation for all MMDiT blocks.
Context sequence c: CLIP-L (77×768) and CLIP-G (77×1280) concatenated channel-wise → 77×2048 → zero-padded to 77×4096 → concatenated along the sequence axis with T5 (77×4096) → 154×4096 → projected to the MMDiT hidden dim → text stream.
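A minimal NumPy sketch of how the encoder outputs are combined, following the shapes reported in the paper (T5 truncated to 77 tokens). Zero arrays stand in for real encoder outputs. `build_text_conditioning` is a hypothetical name, and the projections into MMDiT's hidden size (including the MLP on y) are omitted.

```python
import numpy as np

def build_text_conditioning(clip_l_seq, clip_g_seq, t5_seq, clip_l_pool, clip_g_pool):
    """Combine encoder outputs into (context c, pooled y). Batch dimension omitted."""
    # CLIP sequences are joined along the CHANNEL axis: (77, 768) + (77, 1280) -> (77, 2048).
    clip_seq = np.concatenate([clip_l_seq, clip_g_seq], axis=-1)
    # Zero-pad the channels to T5's width (4096) ...
    pad = t5_seq.shape[-1] - clip_seq.shape[-1]
    clip_seq = np.pad(clip_seq, ((0, 0), (0, pad)))
    # ... then join with the T5 tokens along the SEQUENCE axis.
    context = np.concatenate([clip_seq, t5_seq], axis=0)
    # Pooled vector y: (768,) + (1280,) -> (2048,).
    pooled = np.concatenate([clip_l_pool, clip_g_pool])
    return context, pooled

# Zero arrays stand in for real encoder outputs.
clip_l, clip_g = np.zeros((77, 768)), np.zeros((77, 1280))
pool_l, pool_g = np.zeros(768), np.zeros(1280)
ctx, y = build_text_conditioning(clip_l, clip_g, np.zeros((77, 4096)), pool_l, pool_g)
ctx_long, _ = build_text_conditioning(clip_l, clip_g, np.zeros((256, 4096)), pool_l, pool_g)
```

A 77-token T5 output gives a 154 × 4096 context, matching the paper. A 256-token T5 output gives 333 × 4096.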
Why three encoders? CLIP encoders are excellent at visual-semantic grounding ("what does a cat look like?") but weak at language understanding ("the red ball is to the left of the blue cube"). T5 is the opposite — strong language understanding but no visual grounding. Together, they cover both bases. The paper shows that removing T5 significantly degrades performance on compositional prompts, while removing CLIP hurts visual quality.
What degrades when you drop encoders: Removing T5-XXL drops GenEval from 0.74 to ~0.62, and the model loses counting, spatial reasoning, and complex attribute binding. Removing both CLIPs while keeping T5 hurts visual quality (blurry textures, poor color fidelity). You need the full trio for best results. At inference, the text encoders add ~5.5B frozen parameters to memory, but they only run once per prompt (not per denoising step).
Improved captions: The training data uses a 50/50 mix of original human-written captions and synthetic captions generated by CogVLM. Human captions tend to be sparse ("a dog"), while synthetic captions describe the full scene ("a golden retriever sitting on a red couch in a sunlit living room"). This mix dramatically improves the model's ability to follow detailed prompts on GenEval benchmarks.

Chapter 4: Training

Two key training innovations separate SD3 from vanilla rectified flow: logit-normal timestep sampling and resolution-dependent shifting.

Logit-normal timestep sampling

In standard rectified flow, timesteps t are sampled uniformly from [0, 1]. But not all timesteps are equally informative. At t ≈ 0, zt is almost pure data, so the only unknown in the target ε − x0 is the noise, and the best guess for it is its mean (zero). At t ≈ 1, zt is almost pure noise, so the best guess for x0 is simply the dataset mean. Both ends are easy; the hardest (most informative) predictions are at intermediate timesteps.

SD3 samples timesteps from a logit-normal distribution:

$$\pi_{\mathrm{ln}}(t; m, s) = \frac{1}{s\sqrt{2\pi}}\,\frac{1}{t(1-t)}\,\exp\!\left(-\frac{(\operatorname{logit}(t) - m)^2}{2s^2}\right)$$

where logit(t) = log(t / (1−t)). With m = 0 and s = 1, this puts most sampling weight on intermediate timesteps while still covering the endpoints with some probability.
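A minimal NumPy sketch of logit-normal sampling. It relies on the fact that if u ~ N(m, s²), then sigmoid(u) has exactly the density above. The function names are illustrative.

```python
import numpy as np

def sample_logit_normal(n, m=0.0, s=1.0, rng=None):
    """Draw n timesteps in (0, 1) such that logit(t) ~ N(m, s^2)."""
    rng = np.random.default_rng() if rng is None else rng
    u = rng.normal(loc=m, scale=s, size=n)
    return 1.0 / (1.0 + np.exp(-u))          # sigmoid is the inverse of logit

def logit_normal_pdf(t, m=0.0, s=1.0):
    """Density pi_ln(t; m, s)."""
    logit = np.log(t / (1.0 - t))
    return np.exp(-(logit - m) ** 2 / (2 * s ** 2)) / (s * np.sqrt(2 * np.pi) * t * (1.0 - t))

rng = np.random.default_rng(0)
t = sample_logit_normal(100_000, rng=rng)
middle_share = np.mean((t > 0.25) & (t < 0.75))   # uniform sampling would give 0.5
```

With m = 0 and s = 1, about 73% of draws land in (0.25, 0.75), versus 50% under uniform sampling.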

Timestep Sampling Distributions

Uniform sampling (flat) vs logit-normal (peaked in the middle). The logit-normal biases training toward the most informative intermediate timesteps.

Resolution-dependent timestep shifting

Higher resolution images have more pixels, so they need more noise to destroy the signal. Think of it this way: a 1024×1024 image has 16× more pixels than 256×256. If you average the noisy pixels (as a rough estimate of the clean image), the law of large numbers gives you a much better estimate with more pixels. So the same amount of noise is "less noisy" at higher resolutions.

SD3 addresses this with a resolution-dependent shift. Given a model pretrained at resolution n and finetuned at resolution m, the timestep mapping is:

$$t_m = \frac{\sqrt{m/n}\; t_n}{1 + \left(\sqrt{m/n} - 1\right) t_n}$$
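The shift formula can be recovered from the constant-image intuition above, sketched here under a simplifying assumption. Take an image whose n pixels all share one value c, noised as z_t^(i) = (1 − t) c + t ε_i. Averaging the pixels and dividing by (1 − t) estimates c with a standard deviation that grows with t. Requiring the same uncertainty at a new pixel count m gives the mapping:

```latex
\begin{aligned}
\hat{c} &= \frac{1}{(1-t)\,n}\sum_{i=1}^{n} z_t^{(i)} = c + \frac{t}{1-t}\cdot\frac{1}{n}\sum_{i=1}^{n}\varepsilon_i,
\qquad \sigma(t, n) = \operatorname{std}(\hat{c}) = \frac{t}{1-t}\cdot\frac{1}{\sqrt{n}} \\[4pt]
\sigma(t_m, m) = \sigma(t_n, n)
&\iff \frac{t_m}{1-t_m} = \sqrt{\frac{m}{n}}\;\frac{t_n}{1-t_n} \\[4pt]
&\iff t_m = \frac{\sqrt{m/n}\; t_n}{1 + \left(\sqrt{m/n} - 1\right) t_n}
\end{aligned}
```

For m > n the factor √(m/n) exceeds 1, so every intermediate timestep moves toward t = 1 (more noise).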

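For 1024×1024 finetuning from 256×256, plugging the pixel counts into the formula gives √(m/n) = √16 = 4. In practice the paper treats the shift α as a tunable value and uses α = 3.0. Any α > 1 maps each timestep to a larger one, so the schedule adds more noise at each timestep, compensating for the higher resolution.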
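A minimal NumPy sketch of the shift as a function of a generic α. It assumes the shift is applied to a uniform grid of sampling timesteps, which is one common way to use it. `shift_timestep` and `shifted_schedule` are illustrative names.

```python
import numpy as np

def shift_timestep(t, alpha):
    """Map t -> alpha * t / (1 + (alpha - 1) * t); endpoints 0 and 1 stay fixed."""
    return alpha * t / (1.0 + (alpha - 1.0) * t)

def shifted_schedule(n_steps, alpha=3.0):
    # Uniform grid from t = 1 (noise) to t = 0 (data), then warped toward high noise.
    return shift_timestep(np.linspace(1.0, 0.0, n_steps + 1), alpha)

schedule = shifted_schedule(4)
```

With α = 3.0, the 4-step grid [1, 0.75, 0.5, 0.25, 0] becomes [1.0, 0.9, 0.75, 0.5, 0.0]. Three of the four steps now sit above t = 0.5.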

Empirical validation: The paper runs a human preference study over different shift values for 1024² sampling. Shifts above 1.5 are strongly preferred, with diminishing differences among higher values. α = 3.0 is used for all final models.

Frozen vs. trained components

| Component | Parameters | Status |
|---|---|---|
| CLIP-L/14 (text encoder) | ~124M | Frozen (pretrained by OpenAI) |
| CLIP-G/14 (text encoder) | ~695M | Frozen (pretrained, OpenCLIP) |
| T5-XXL encoder | 4.7B | Frozen (pretrained by Google) |
| VAE (16-ch) | ~80M | Pretrained separately, frozen during MMDiT training |
| MMDiT (d=38) | ~8B | Trained from scratch |

Total system: ~13.6B parameters, counting only the ~695M text tower of CLIP-G. Of those, only the 8B MMDiT is trained for the diffusion task. The ~5.6B in text encoders and VAE are pretrained and frozen: they provide the "language" and "pixel" expertise, while MMDiT learns to compose them.

Why rectified flow instead of DDPM?

Three concrete reasons:


Chapter 5: Scaling Laws

One of SD3's most important contributions is demonstrating that MMDiT follows predictable scaling trends, just like language models.

Validation loss scales with compute

The paper trains MMDiT models at depths d = 15, 18, 21, 24, 30, and 38 (roughly 500M to 8B parameters) for 500k steps each on 256² images. The validation loss decreases smoothly and predictably as model size increases. Crucially, the curves show no sign of saturating, suggesting that even larger models would continue to improve.

Lower loss = better images

This sounds obvious but isn't guaranteed. Many generative models show poor correlation between training loss and sample quality. SD3 demonstrates a strong, monotonic relationship: lower validation loss consistently correlates with better CLIP scores, better FID, and better human preference ratings.

Why scaling laws matter: If you can predict performance from compute budget, you can make rational decisions about model size before spending millions on training. This is the same insight that powered GPT-3/4's development — now it applies to image generation. The clean scaling axis of MMDiT (just change depth d) is what makes this possible, unlike U-Nets where scaling is ad hoc.

Larger models are more sample-efficient

The 8B model (d=38) loses only 2.71% CLIP score when reducing from 50 to 5 sampling steps. The 500M model (d=15) loses 4.30%. Larger models learn straighter trajectories, so they tolerate aggressive step reduction better. Bigger models are therefore not only better but also need relatively fewer steps to reach a given quality, although each of their steps costs more compute.

Scaling: Model Size vs Performance

Validation loss decreases smoothly with model depth. Larger models also retain more performance when using fewer sampling steps.

Chapter 6: Results

SD3's 8B model achieves state-of-the-art performance across multiple benchmarks, outperforming both open-source and closed-source alternatives.

GenEval benchmark

GenEval tests compositional text-to-image generation: can the model render the right number of objects, in the right colors, in the right positions? SD3 (d=38, 1024², DPO-aligned) scores 0.74 overall, compared to DALL-E 3's 0.67 and SDXL's 0.55.

The breakdown is striking:

Human preference evaluation

In head-to-head human evaluations on PartiPrompts, SD3 wins against SDXL Turbo, SDXL, PixArt-α, and DALL-E 3 across visual quality, prompt following, and typography. The typography advantage is particularly large. It is a direct benefit of MMDiT's joint attention, which lets the model "see" text tokens while generating image patches.

GenEval Benchmark Comparison

Overall GenEval scores measuring compositional text-to-image generation. SD3 (8B + DPO) outperforms all competitors including DALL-E 3.

What degrades and when

Concrete inference numbers (8B model, 1024×1024): Text encoding runs once: CLIP-L (~0.1s) + CLIP-G (~0.3s) + T5-XXL (~0.8s) = ~1.2s. Then come 50 Euler steps. Each step is one MMDiT forward pass over ~4,250 tokens (4,096 image + 154 text) through 38 blocks: roughly 30 TFLOPs per step, or ~1,500 TFLOPs for 50 steps. Classifier-free guidance (CFG) adds an unconditional pass at every step, doubling this to ~3,000 TFLOPs. At the A100's dense bf16 peak of 312 TFLOP/s, that is about 9.5 seconds of pure compute, so real wall-clock time at 50 steps is longer. Using fewer steps reduces it proportionally. VAE decode: ~0.2s.
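The arithmetic behind these figures, as a rough sketch. It counts only matrix multiplications (2 FLOPs per multiply-add) and assumes 4096 image tokens plus the paper's 154 text tokens. It ignores norms, modulation, activations, and the patch/unpatch layers. The A100 figure is its dense bf16 tensor-core peak; real utilization is lower.

```python
def mmdit_forward_flops(depth=38, n_img=4096, n_txt=154):
    """Rough FLOPs for one MMDiT forward pass, counting matmuls only."""
    h = 64 * depth                          # hidden size
    n = n_img + n_txt                       # joint sequence length
    # Each token passes through its own stream's Q/K/V + output projection (4h^2)
    # and MLP (8h^2) weights: 2 FLOPs per multiply-add.
    linear = 2 * 12 * h * h * n * depth
    # Q K^T and (attention weights) V each cost 2 * n^2 * h FLOPs per block.
    attention = 2 * 2 * n * n * h * depth
    return linear + attention

per_step = mmdit_forward_flops()
total = per_step * 50 * 2                    # 50 Euler steps, 2 passes each for CFG
A100_BF16_PEAK = 312e12                      # FLOP/s, dense bf16 tensor cores
print(f"{per_step / 1e12:.1f} TFLOPs per step; "
      f"at least {total / A100_BF16_PEAK:.1f} s of compute on one A100")
```

This prints about 29.6 TFLOPs per step and a compute floor of about 9.5 s. Roughly three quarters of the per-step cost is in the linear layers and the rest in attention.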
Typography breakthrough: SD3 was among the first openly available models to render text reliably in images. Previous models struggled because text tokens and image patches lived in separate worlds (cross-attention). MMDiT's joint attention lets the model correlate the tokens that spell out the requested text with specific spatial patches, enabling coherent letter rendering.

Chapter 7: Comparison with SDXL & DALL-E 3

Understanding what changed from SDXL to SD3 reveals why each design choice matters.

SDXL vs SD3: Architecture

SDXL
U-Net backbone • Cross-attention to text (one-way) • CLIP-L + CLIP-G • 4-channel latent • ε-prediction • DDPM schedule
SD3
MMDiT backbone • Joint attention (two-way) • CLIP-L + CLIP-G + T5-XXL • 16-channel latent • Velocity prediction (ε − x0) • Rectified flow

DALL-E 3 vs SD3

The two-stage design is DALL-E 2's unCLIP architecture, not DALL-E 3's: a prior model generates CLIP image embeddings from text, and a diffusion decoder then generates images from those embeddings. The DALL-E 3 report focuses on caption quality rather than architecture. SD3 needs no prior at all: text encoders feed directly into MMDiT.

DALL-E 3's key innovation was training on highly descriptive synthetic captions. SD3 adopts this insight (50/50 original + CogVLM synthetic captions) while also advancing the architecture and noise formulation.

Step efficiency comparison

This is where rectified flow shines. SD3's performance degrades gracefully with fewer steps, while DDPM-based models collapse. At 5 steps, rectified flow formulations still produce coherent images, whereas traditional formulations produce blurry messes.

Path length comparison: The paper measures the total path length of learned trajectories. SD3's 8B model achieves a path length of 185.96, compared to 191.13 for the 500M model. Lower path length = straighter trajectory = fewer steps needed. The path length decreases with model size, confirming that larger models learn more efficient transport paths.

Complete training and inference budget

| Metric | SD3-2B (d=24) | SD3-8B (d=38) |
|---|---|---|
| MMDiT params | ~2B | ~8B |
| Total system params | ~7.6B | ~13.6B |
| Training resolution | 256², then 1024² finetune | 256², then 1024² finetune |
| Training steps | 500K+ | 500K+ |
| VRAM for inference (bf16) | ~18 GB | ~30 GB |
| Steps for good quality | 28-50 Euler | 20-50 Euler |
| Inference time (A100) | ~5-7s | ~10-14s |

At inference the text encoders take a large share of VRAM. T5-XXL alone needs ~9.4 GB in bf16, more than the 2B MMDiT (~4 GB), though less than the 8B MMDiT (~16 GB). For memory-constrained deployment, T5 can be dropped (with the compositional quality degradation noted above) to save nearly 10 GB, or quantized to int8 (~4.7 GB).
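A back-of-the-envelope sketch of weight memory that excludes activations, runtime overhead, and the VAE. It assumes ~124M and ~695M parameters for the two CLIP text encoders. `weight_memory_gb` is an illustrative helper, not a library function.

```python
def weight_memory_gb(n_params, bytes_per_param=2):
    """Memory for weights alone: bf16 = 2 bytes/param, int8 = 1 byte/param."""
    return n_params * bytes_per_param / 1e9

components = {
    "T5-XXL encoder": 4.7e9,
    "CLIP-L + CLIP-G text encoders": 0.124e9 + 0.695e9,
    "MMDiT 2B": 2e9,
    "MMDiT 8B": 8e9,
}
for name, n_params in components.items():
    print(f"{name}: {weight_memory_gb(n_params):.1f} GB bf16, "
          f"{weight_memory_gb(n_params, 1):.1f} GB int8")
```

In the 2B configuration, T5 is the single largest component. In the 8B configuration, the MMDiT is.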


Chapter 8: The Autoencoder

Like its predecessors, SD3 operates in a compressed latent space, not in pixel space. A pretrained autoencoder maps images to and from this latent representation.

Why latent diffusion?

A 1024×1024 RGB image has ~1M pixels and ~3.1M values. Running a transformer with attention over a million pixel tokens is computationally infeasible. The autoencoder compresses the image by a factor of 8× in each spatial dimension: 1024×1024×3 → 128×128×16. That is 262,144 values instead of 3,145,728, about 92% fewer (a 12× reduction). The diffusion model works in this compact space, then the decoder reconstructs the final image.

16-channel latent space

This is a key upgrade from SDXL (4 channels) and previous Stable Diffusion versions. More channels = richer latent representation = higher reconstruction quality. The paper shows this systematically:

| Latent channels | FID ↓ | SSIM ↑ | PSNR (dB) ↑ |
|---|---|---|---|
| 4 | 2.41 | 0.75 | 25.12 |
| 8 | 1.56 | 0.79 | 26.40 |
| 16 | 1.06 | 0.86 | 28.62 |

The 16-channel autoencoder reduces reconstruction FID by 56% compared to 4-channel. This is crucial because the autoencoder's reconstruction quality is an upper bound on the final image quality — the diffusion model can never generate images better than the autoencoder can reconstruct.

Patching

The 128×128×16 latent is further divided into 2×2 patches, yielding 64×64 = 4096 tokens, each of dimension 16×4 = 64. These tokens form the image sequence in MMDiT. This patch size matches DiT's design and keeps the sequence length manageable for attention.
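A minimal NumPy sketch of 2×2 patchify/unpatchify, assuming a channels-last (H, W, C) layout. The order of the p·p·C values inside each token is an assumption; real implementations may arrange them differently. In the model, a linear layer then maps each 64-dim token to the hidden size.

```python
import numpy as np

def patchify(latent, p=2):
    """(H, W, C) latent -> (H/p * W/p, p*p*C) tokens, patches in row-major order."""
    h, w, c = latent.shape
    x = latent.reshape(h // p, p, w // p, p, c).transpose(0, 2, 1, 3, 4)
    return x.reshape((h // p) * (w // p), p * p * c)

def unpatchify(tokens, h, w, c, p=2):
    """Exact inverse of patchify: tokens -> (H, W, C) latent."""
    x = tokens.reshape(h // p, w // p, p, p, c).transpose(0, 2, 1, 3, 4)
    return x.reshape(h, w, c)

rng = np.random.default_rng(0)
latent = rng.normal(size=(128, 128, 16))   # stand-in for a 1024x1024 image's latent
tokens = patchify(latent)                  # (4096, 64)
restored = unpatchify(tokens, 128, 128, 16)
```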

The latent resolution stack: 1024×1024 pixels → 128×128×16 latent (autoencoder, 8× downsample) → 64×64 patches of dim 64 (2×2 patching) → 4096 tokens in MMDiT. The model never "sees" pixels — it only works with these 4096 abstract tokens.

VAE training (separate stage)

The 16-channel autoencoder is trained independently before any diffusion training begins. It uses a combination of:

Once trained, the VAE is frozen for all subsequent diffusion training. This two-stage approach is inherited from the original Latent Diffusion paper (Rombach et al., 2022). The VAE has ~80M parameters — small compared to the 8B MMDiT it supports.

Why the VAE is a bottleneck: The autoencoder's reconstruction quality is an absolute ceiling on final image quality. No matter how perfect the MMDiT becomes, it cannot generate details the VAE cannot reconstruct. This is why upgrading from 4 to 16 channels matters so much — it lifts the ceiling. Fine details like thin lines, small text, and fabric textures that were permanently lost in a 4-channel latent are now preserved.


Chapter 9: Connections

SD3 sits at the intersection of several research threads. Understanding these connections clarifies both where SD3 came from and where the field is going.

DiT (Peebles & Xie, 2023)

MMDiT is a direct descendant of DiT, which first showed that replacing U-Net with a transformer backbone for diffusion models works and scales. DiT used class-conditional generation with adaLN modulation. MMDiT extends this to text-to-image by adding the text stream and joint attention. In fact, DiT with concatenated text+image tokens is a special case of MMDiT with shared weights.

DDPM & Diffusion Models (Ho et al., 2020)

DDPM established the modern diffusion framework: define a forward noising process, train a network to reverse it. SD3's rectified flow is a different choice of forward process (straight lines instead of curved variance-preserving paths) but the same fundamental idea: learn to undo the noise.

Flow Matching (Lipman et al., 2023)

Rectified flow is a special case of the flow matching framework. Flow matching provides the theoretical foundation: you can train velocity fields by matching conditional vector fields along specified probability paths. Rectified flow simply chooses the straight-line path zt = (1−t)x0 + tε. SD3's contribution is showing this works at scale with the right timestep sampling.

Flux (Black Forest Labs, 2024)

Several SD3 authors later founded Black Forest Labs and created Flux, which builds on the same MMDiT foundation but introduces a hybrid architecture. Early layers use MMDiT-style joint attention, while later layers use single-stream attention (text and image tokens are fully merged). Flux also drops CLIP-G, conditioning on T5-XXL plus CLIP-L (which supplies the pooled embedding), and simplifies the modulation mechanism.

Latent Diffusion (Rombach et al., 2022)

SD3 inherits the core LDM insight: train the diffusion model in a compressed latent space, not in pixel space. The key upgrade is moving from 4 to 16 latent channels and using a transformer backbone instead of a U-Net.

What each upgrade contributed

| Transition | What changed | What it improved |
|---|---|---|
| DDPM → LDM | Pixel space → latent space | 9x compute reduction, same quality |
| LDM → DiT | U-Net → transformer backbone | Clean scaling laws, SOTA FID |
| DiT → SD3 | Class-cond → text-cond (MMDiT), DDPM → rectified flow, 4ch → 16ch VAE | Text generation, 4x fewer steps, sharper images |
| SD3 → Flux | Dual-stream → hybrid single/dual, simplified modulation | Further efficiency gains, production deployment |
The evolution: DDPM (2020) → LDM/SD1 (2022, latent space + U-Net) → DiT (2023, transformer backbone) → SD3 (2024, MMDiT + rectified flow + 16ch latent) → Flux (2024, hybrid single/multi-stream). Each step replaced one or a few components while keeping the rest, and each step brought clear, measurable improvements.