Peebles & Xie — UC Berkeley / NYU, 2023

Scalable Diffusion Models with Transformers

Replace the U-Net backbone in diffusion models with a Vision Transformer. Patchify latent representations, condition with adaptive layer norm, and watch FID drop as you scale compute — just like language models.

Prerequisites: Diffusion models (DDPM) + Vision Transformers + Latent Diffusion

Chapter 0: The Problem

By 2022, diffusion models had become the dominant paradigm for image generation. DALL-E 2, Imagen, Stable Diffusion — all of them produced stunning images. And all of them used the same backbone: a U-Net.

The U-Net was inherited from early pixel-level models (PixelCNN++) and adapted by Ho et al. for DDPM. It's convolutional, with ResNet blocks at multiple resolutions and self-attention sprinkled in at lower resolutions. Dhariwal and Nichol (ADM) ablated some architectural choices — channel counts, normalization layers — but the high-level design remained essentially unchanged from 2020 to 2022.

Meanwhile, a different story was playing out everywhere else in deep learning. Transformers had taken over NLP, vision (ViT), reinforcement learning, and even protein folding. The reason? Transformers scale. Double the compute, get predictably better performance. Clean scaling laws. No architectural ceilings.

But diffusion models were still built on U-Nets. Few had systematically asked: what happens if you replace the U-Net with a transformer?

The core tension: U-Nets work well for diffusion, but they are a bespoke convolutional architecture with domain-specific inductive biases (skip connections across resolutions, multi-scale processing). Transformers have proven more scalable across many other domains. Can we bring the scaling properties of transformers to diffusion models?

There were practical concerns too. U-Nets are awkward to scale — you can add channels or attention heads, but there's no clean "just make it bigger" knob like there is for transformers (where you simply increase depth and width). The architecture is also hard to unify across modalities, making it difficult to share training recipes or transfer insights from language modeling.

Why is the U-Net backbone potentially limiting for diffusion model scaling?

Chapter 1: The Key Insight

Peebles and Xie's insight is beautifully simple: use a standard Vision Transformer as the denoising backbone in a latent diffusion model.

The recipe:

  1. Take a noised latent representation (from a pretrained VAE)
  2. Patchify it into a sequence of tokens, just like ViT does with images
  3. Add positional embeddings
  4. Process through a stack of transformer blocks
  5. Decode back to predict noise and covariance

That's it. No multi-scale processing. No skip connections across resolutions. No convolutional layers (except inside the frozen VAE). Just a plain transformer operating on a flat sequence of patches.

Parameter budget breakdown (DiT-XL/2)

Component | Parameters | Notes
Patch embedding | ~20K | Linear: (p×p×4) → 1152, i.e. 16 × 1152 weights + 1152 biases
Positional embedding | ~295K values | 256 positions × 1152 dims (fixed sine-cosine, not learned)
28 DiT blocks | ~668M | Per block: MHSA (4 × 1152²) + FFN (2 × 1152 × 4608) + adaLN MLP
adaLN conditioning MLP | ~8M per block | Maps the t+c embedding to 6 modulation vectors per block (6 × 1152²); already included in the block total
Final linear decode | ~37K | 1152 → p×p×2C = 1152 → 32
Label embedding | ~1.2M | 1000 classes × 1152 dim lookup table
Timestep MLP | ~2.7M | Sinusoidal embed → 2-layer MLP → 1152 dim
Total DiT-XL/2 | 675M | All trained from scratch

The overwhelming majority (99%+) of parameters sit in the 28 transformer blocks. The conditioning, embedding, and decode layers are negligible. This is what makes DiT so clean: scale the blocks and everything else is noise.
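The block-dominated budget can be checked with a few lines of arithmetic. The sketch below assumes a layout that the table only implies: biased linear layers, an MLP ratio of 4, layer norms without learnable parameters, and an adaLN projection from d to 6d inside each block. The function name is illustrative, not taken from any codebase.

```python
def dit_block_params(d, mlp_ratio=4):
    """Approximate parameter count of one DiT block (assumed layout)."""
    attn = 4 * d * d + 4 * d              # Q, K, V and output projections, with biases
    hidden = mlp_ratio * d
    ffn = 2 * d * hidden + hidden + d     # two linear layers, with biases
    adaln = d * 6 * d + 6 * d             # conditioning vector -> 6 modulation vectors
    return attn + ffn + adaln


d, n_blocks, total = 1152, 28, 675e6      # DiT-XL/2 figures quoted in the table
block_total = n_blocks * dit_block_params(d)
print(f"per block: {dit_block_params(d) / 1e6:.1f}M")
print(f"all blocks: {block_total / 1e6:.0f}M ({block_total / total:.1%} of 675M)")
```

Under these assumptions the blocks come out within about a million parameters of the table's figure, and above 99% of the total.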

Why this is surprising: The U-Net's multi-resolution structure (downsampling, processing, upsampling with skip connections) seemed essential. It lets the network reason at multiple scales simultaneously. Removing all of that and replacing it with a flat sequence of patches processed at a single resolution shouldn't work as well. But it does. Self-attention evidently suffices for multi-scale reasoning without the architectural scaffolding.

The paper calls this architecture DiT: Diffusion Transformer. The key finding is that DiTs exhibit the same clean scaling behavior that makes transformers so powerful in language: more Gflops = lower FID, with strong correlation (-0.93). This means you can improve image quality simply by making the model bigger, with no architectural changes needed.

The best model, DiT-XL/2, achieves a state-of-the-art FID of 2.27 on class-conditional ImageNet 256x256, outperforming all prior diffusion models while using fewer Gflops than pixel-space U-Net models like ADM.

What is the central finding of the DiT paper?

Chapter 2: Latent Diffusion Setup

DiT doesn't operate on raw pixels. It uses the Latent Diffusion Model (LDM) framework from Rombach et al. — the same framework behind Stable Diffusion. This is a two-stage approach:

Stage 1: Learn a VAE

A variational autoencoder is trained to compress images into a compact latent space. For a 256x256x3 RGB image, the VAE encoder E produces a latent z = E(x) with shape 32x32x4 — an 8x spatial downsampling with 4 channels. DiT uses the off-the-shelf VAE from Stable Diffusion, frozen during DiT training.

Stage 2: Diffuse in latent space

The forward diffusion process adds Gaussian noise to z:

zt = √(ᾱt) · z0 + √(1 − ᾱt) · ε,   ε ~ N(0, I)

The DiT model learns to reverse this process — given a noisy latent zt and timestep t, predict the noise ε. After denoising, the clean latent z0 is decoded back to an image via the VAE decoder: x = D(z).
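The forward process above is a one-line formula once the cumulative schedule ᾱt is known. The sketch below is a minimal illustration in plain Python. The linear β schedule (1e-4 to 2e-2 over 1000 steps) and the helper names are assumptions made for the example, and the latent is a flat list of random numbers rather than a real VAE encoding.

```python
import math
import random


def linear_alpha_bar(num_steps=1000, beta_start=1e-4, beta_end=2e-2):
    """Cumulative products of (1 - beta_t) for a linear beta schedule (assumed values)."""
    alpha_bar, running = [], 1.0
    for i in range(num_steps):
        beta = beta_start + (beta_end - beta_start) * i / (num_steps - 1)
        running *= 1.0 - beta
        alpha_bar.append(running)
    return alpha_bar


def q_sample(z0, t, alpha_bar, rng):
    """Draw z_t = sqrt(abar_t) * z0 + sqrt(1 - abar_t) * eps for a flat list z0."""
    eps = [rng.gauss(0.0, 1.0) for _ in z0]
    a, b = math.sqrt(alpha_bar[t]), math.sqrt(1.0 - alpha_bar[t])
    zt = [a * z + b * e for z, e in zip(z0, eps)]
    return zt, eps


rng = random.Random(0)
alpha_bar = linear_alpha_bar()
z0 = [rng.gauss(0.0, 1.0) for _ in range(32 * 32 * 4)]   # stand-in for a flattened latent
zt, eps = q_sample(z0, t=500, alpha_bar=alpha_bar, rng=rng)
print(f"signal weight at t=500: {math.sqrt(alpha_bar[500]):.3f}")
```

The returned eps is exactly the regression target the denoiser is trained to predict from zt and t.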

Why latent space? Training diffusion directly on 256x256x3 pixels is expensive. ADM, a pixel-space U-Net model, requires 1120 Gflops per forward pass. By compressing to 32x32x4 latents first, DiT-XL/2 achieves better results with only 118.6 Gflops — a 9.4x reduction in compute. The VAE does the heavy lifting of learning pixel-level details; the diffusion model only needs to learn the high-level structure.

The complete data flow with exact shapes

Let's trace a single image through the full DiT pipeline, tracking every tensor shape:

  1. Input image: 256 × 256 × 3 (RGB pixels)
  2. VAE Encoder E(x): 256 × 256 × 3 → 32 × 32 × 4 latent (8× spatial downsample, 4 channels)
  3. Forward diffusion: 32 × 32 × 4 → 32 × 32 × 4 noised latent zt
  4. Patchify (p=2): 32 × 32 × 4 → 256 tokens of dim d=1152 (each token = 2×2×4 = 16 values, linearly projected to d)
  5. + Positional embedding: 256 × 1152 (sine-cosine embeddings added)
  6. 28 × DiT Block (adaLN-Zero): 256 × 1152 → 256 × 1152 (self-attention over all 256 tokens per block)
  7. Final layer norm + Linear decode: 256 × 1152 → 256 tokens × (2×2×8) = 32 × 32 × 8 (noise + diagonal covariance)
  8. Split output: 32 × 32 × 4 predicted noise ε + 32 × 32 × 4 predicted variance Σ
  9. VAE Decoder D(z0): 32 × 32 × 4 clean latent → 256 × 256 × 3 output image

Frozen vs. trained: The VAE (both encoder and decoder) is pretrained from Stable Diffusion and completely frozen during DiT training. Only the DiT transformer weights are trained. This separation is critical: it means DiT never needs to learn pixel-level details, only the high-level latent structure. The VAE has ~80M parameters; DiT-XL has 675M parameters, all trained from scratch on ImageNet.

What is the shape of the latent representation z for a 256x256x3 input image in DiT?

Chapter 3: DiT Architecture

Here's the complete forward pass of DiT, step by step. This is the core of the paper.

Step 1: Patchify

The noised latent zt (shape 32x32x4) is divided into non-overlapping patches of size p x p. Each patch is linearly embedded into a d-dimensional token vector. For patch size p, the number of tokens is:

T = (I / p)²

Where I = 32 (the spatial dimension of the latent). With p=2, you get T = 256 tokens. With p=8, just T = 16 tokens. Halving p quadruples T and thus at least quadruples the Gflops: the linear layers scale linearly with T, and self-attention adds a term that grows quadratically with T.
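The token count and the patch layout can be sketched in plain Python. This is an illustration only: the latent is a nested list, the ordering of values inside a token (row-major within the patch, channels last) is an assumption, and the learned linear projection from 16 raw values to d dimensions is left out.

```python
def num_tokens(latent_size, patch_size):
    """T = (I / p)^2 for a square latent of side I."""
    assert latent_size % patch_size == 0
    return (latent_size // patch_size) ** 2


def patchify(latent, p):
    """Split an H x W x C nested list into (H/p)*(W/p) tokens of p*p*C values each."""
    height, width = len(latent), len(latent[0])
    tokens = []
    for top in range(0, height, p):
        for left in range(0, width, p):
            token = []
            for dy in range(p):
                for dx in range(p):
                    token.extend(latent[top + dy][left + dx])
            tokens.append(token)
    return tokens


# Toy 32 x 32 x 4 latent whose entries encode their own (row, col, channel) index.
latent = [[[(r * 32 + c) * 4 + ch for ch in range(4)] for c in range(32)] for r in range(32)]
tokens = patchify(latent, p=2)
print(len(tokens), len(tokens[0]))                  # token count, raw values per token
print({p: num_tokens(32, p) for p in (8, 4, 2)})    # tokens for each patch size
```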

Step 2: Positional embeddings

Standard frequency-based (sine-cosine) positional embeddings, a common ViT variant, are added to all tokens. They are fixed rather than learned. Nothing fancy.
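A 2D sine-cosine embedding can be built from two 1D embeddings, one for the row and one for the column of each token. The sketch below is one plausible construction, not necessarily the exact channel ordering used by DiT: the split (row half, then column half), the sine-before-cosine order, and the base of 10000 are assumptions.

```python
import math


def sincos_1d(pos, dim, base=10000.0):
    """Sine-cosine embedding of one scalar position; `dim` must be even."""
    half = dim // 2
    freqs = [1.0 / base ** (i / half) for i in range(half)]
    return [math.sin(pos * f) for f in freqs] + [math.cos(pos * f) for f in freqs]


def sincos_2d(grid_size, dim):
    """One embedding per grid cell: half the channels encode the row, half the column."""
    assert dim % 4 == 0
    return [sincos_1d(r, dim // 2) + sincos_1d(c, dim // 2)
            for r in range(grid_size) for c in range(grid_size)]


pos_embed = sincos_2d(grid_size=16, dim=1152)   # 16 x 16 grid of tokens for p=2
print(len(pos_embed), len(pos_embed[0]))
```

Because the values depend only on the grid position, the table is computed once and added to the token sequence; nothing in it is trained.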

Step 3: N transformer blocks

The token sequence passes through N DiT blocks. Each block contains multi-head self-attention and a pointwise feedforward network (MLP), with layer normalization. The key modification is how conditioning information (timestep t and class label c) enters the block — via adaptive layer normalization with zero initialization (adaLN-Zero). We'll detail this in the next chapter.

Step 4: Decode

After the final block, a final adaptive layer norm is applied, followed by a linear projection that maps each d-dimensional token to a p x p x 2C tensor (predicting both noise and diagonal covariance). The decoded tokens are rearranged back into spatial layout to produce the final noise prediction εθ and covariance Σθ.
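The rearrangement in the decode step is the inverse of patchify, followed by a channel split. The sketch below uses nested lists and made-up token values. The value ordering inside each token (row-major within the patch, channels last) and the convention that the first C channels hold the noise are assumptions for illustration.

```python
def unpatchify(tokens, grid, p, channels):
    """Rearrange grid*grid tokens of p*p*channels values into an H x W x channels nested list."""
    size = grid * p
    out = [[None] * size for _ in range(size)]
    for idx, token in enumerate(tokens):
        top, left = (idx // grid) * p, (idx % grid) * p
        for dy in range(p):
            for dx in range(p):
                start = (dy * p + dx) * channels
                out[top + dy][left + dx] = token[start:start + channels]
    return out


def split_noise_and_variance(decoded, c):
    """Split each 2C-channel position into C noise channels and C variance channels."""
    noise = [[px[:c] for px in row] for row in decoded]
    var = [[px[c:] for px in row] for row in decoded]
    return noise, var


# Stand-in for the linear decoder's output: 256 tokens of 2*2*8 = 32 values each.
tokens = [[float(t * 32 + k) for k in range(32)] for t in range(256)]
decoded = unpatchify(tokens, grid=16, p=2, channels=8)
noise, var = split_noise_and_variance(decoded, c=4)
print(len(decoded), len(decoded[0]), len(decoded[0][0]))   # height, width, 2C
print(len(noise[0][0]), len(var[0][0]))                    # C, C
```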

The design philosophy: DiT is intentionally as close to a standard ViT as possible. No multi-scale processing, no skip connections between blocks at different depths, no convolutional layers. This faithfulness to the vanilla transformer architecture is what gives DiT its scaling properties — you can simply increase N (depth) and d (width) following standard ViT configs (S, B, L, XL).

Engineering decisions: why these choices?

Why patchify instead of pixel-level tokens? A 32×32 latent has 1024 spatial positions. Self-attention cost grows as O(T²) in the number of tokens T, so 1024 tokens would cost 16× the attention compute of 256 tokens (p=2), on top of 4× the compute in every linear layer. Patchification is a compute-quality tradeoff. With p=2, you get 256 tokens, which is manageable for attention while still preserving fine-grained spatial information within each patch via the linear projection.

Why predict both noise AND variance? Standard DDPM predicts only noise ε and uses a fixed variance schedule. But Nichol & Dhariwal (2021) showed that learning the variance improves log-likelihood and sample quality, especially with fewer sampling steps. DiT outputs 2C = 8 channels: 4 for noise prediction, 4 for an interpolation parameter v that blends between the fixed DDPM upper and lower variance bounds. The variance prediction adds almost no parameters (it only doubles the width of the output projection) but meaningfully helps quality.
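The blend between the two variance bounds is done in log space, following the Nichol & Dhariwal parameterization Σ = exp(v · log βt + (1 − v) · log β̃t). The sketch below assumes v has already been mapped into [0, 1], and the two bound values are illustrative numbers rather than entries from a real schedule.

```python
import math


def interpolated_variance(v, beta_t, beta_tilde_t):
    """Blend the two variance bounds in log space: v=1 gives beta_t, v=0 gives beta_tilde_t."""
    log_var = v * math.log(beta_t) + (1.0 - v) * math.log(beta_tilde_t)
    return math.exp(log_var)


beta_t, beta_tilde_t = 0.01, 0.008        # illustrative upper and lower bounds for one timestep
for v in (0.0, 0.5, 1.0):
    print(v, interpolated_variance(v, beta_t, beta_tilde_t))
```

Interpolating logs rather than raw values keeps the predicted variance positive and makes the midpoint the geometric mean of the two bounds.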

Why adaLN-Zero instead of cross-attention for conditioning? Cross-attention adds 15% Gflops overhead (extra QKV projections + attention computation) for a length-2 conditioning sequence. adaLN-Zero adds negligible Gflops overhead (an MLP that regresses six dimension-wise modulation vectors per block: γ1, β1, α1, γ2, β2, α2) yet achieves lower FID. The insight: conditioning on timestep and class doesn't need per-token flexibility (every patch shares the same noise level and the same class), so a global modulation suffices.

Model configurations

Model | Layers N | Hidden dim d | Heads | Gflops (p=4)
DiT-S | 12 | 384 | 6 | 1.4
DiT-B | 12 | 768 | 12 | 5.6
DiT-L | 24 | 1024 | 16 | 19.7
DiT-XL | 28 | 1152 | 16 | 29.1

With patch size p=2, DiT-XL/2 reaches 118.6 Gflops. The naming convention is DiT-{size}/{patch_size}.
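The Gflops column can be roughly reproduced from N, d and the token count alone. The estimate below is a back-of-the-envelope sketch: it counts one operation per multiply-accumulate, covers only the attention and FFN matrix multiplies, and ignores embeddings, adaLN and the decoder. How the reported Gflops were actually measured is an assumption here.

```python
def approx_gflops(layers, d, latent_size=32, patch_size=2, mlp_ratio=4):
    """Rough multiply-accumulate count (in units of 1e9) for the transformer blocks only."""
    tokens = (latent_size // patch_size) ** 2
    linear = tokens * (4 * d * d + 2 * d * mlp_ratio * d)   # QKV/out projections + FFN
    attention = 2 * tokens * tokens * d                     # QK^T scores and weighted sum of V
    return layers * (linear + attention) / 1e9


configs = {"S": (12, 384), "B": (12, 768), "L": (24, 1024), "XL": (28, 1152)}
for name, (layers, d) in configs.items():
    print(f"DiT-{name}/4: {approx_gflops(layers, d, patch_size=4):.1f}")
print(f"DiT-XL/2: {approx_gflops(28, 1152, patch_size=2):.1f}")
```

The estimates land within a few percent of the table, and going from p=4 to p=2 multiplies the cost by slightly more than 4 because the attention term grows quadratically with the token count.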

What does decreasing the patch size p do to the number of tokens and compute?

Chapter 4: Conditioning Mechanisms

A diffusion model needs to know two things beyond the noised input: the timestep t (how noisy is this?) and the class label c (what should this image be?). The question is: how do you inject this conditioning information into each transformer block?

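DiT explores four approaches. All of them first embed t and c into vectors: the timestep through sinusoidal features followed by an MLP, the class label through a learned embedding table. The adaLN variants sum the two embeddings, while the token-based variants keep them as two separate tokens. The difference is how this conditioning enters each transformer block.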

1. In-Context Conditioning

Simply append the conditioning embeddings as two extra tokens in the input sequence. The transformer processes them alongside image tokens with no architectural changes. After the final block, remove the extra tokens. Gflops overhead: negligible.

2. Cross-Attention

Add a cross-attention layer after self-attention in each block. The image tokens attend to the conditioning tokens (a length-2 sequence of t and c embeddings). This is similar to how the original transformer decoder attends to encoder outputs, and how LDM conditions on text. Gflops overhead: ~15%.

3. Adaptive Layer Norm (adaLN)

Replace the standard learnable scale (γ) and shift (β) parameters in layer norm with ones regressed from the conditioning vector. Instead of learning fixed normalization parameters, the model computes them as a function of t + c. This applies the same transformation to all tokens (unlike cross-attention, which can apply different weights per token). Gflops overhead: minimal.

4. adaLN-Zero (the winner)

Same as adaLN, but with a critical addition: regress additional scaling parameters α that are applied immediately before each residual connection. The MLP is initialized to output zero for every α, which means each DiT block starts as the identity function and passes its input straight through. The network gradually learns to "turn on" each block during training.

h = x + α1 · MHSA(adaLN(γ1, β1, x))
out = h + α2 · FFN(adaLN(γ2, β2, h))

Where γ, β, α are all regressed from the conditioning embedding by an MLP inside each block.

Why zero initialization matters: This is inspired by a trick from ResNet training. Goyal et al. found that zero-initializing the final batch norm scale in each residual block accelerates training. U-Net diffusion models use a similar trick (zero-init final conv). For DiT, zero-initializing α makes every block the identity at initialization, so the whole stack starts as a no-op. Gradients flow cleanly through the residual connections from the start, giving the network a stable foundation to learn from.
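The identity-at-initialization property is easy to verify on a toy block. The sketch below applies the two block equations to a single token vector, with hypothetical stand-ins for the attention and feedforward sublayers (real attention mixes information across tokens). Writing adaLN as γ · LayerNorm(x) + β is also an assumption about the exact form of the modulation.

```python
import math


def layer_norm(x, eps=1e-6):
    """Normalize one token vector to zero mean and unit variance (no learned affine)."""
    mean = sum(x) / len(x)
    var = sum((v - mean) ** 2 for v in x) / len(x)
    return [(v - mean) / math.sqrt(var + eps) for v in x]


def ada_ln(gamma, beta, x):
    """Layer norm whose scale and shift come from the conditioning vector."""
    return [g * v + b for g, v, b in zip(gamma, layer_norm(x), beta)]


def adaln_zero_block(x, mod, attn, ffn):
    """h = x + a1 * attn(adaLN(x)); out = h + a2 * ffn(adaLN(h)), for one token."""
    g1, b1, a1, g2, b2, a2 = mod
    h = [xi + a * yi for xi, a, yi in zip(x, a1, attn(ada_ln(g1, b1, x)))]
    return [hi + a * yi for hi, a, yi in zip(h, a2, ffn(ada_ln(g2, b2, h)))]


d = 8
x = [0.5 * i - 1.0 for i in range(d)]
toy_attn = lambda v: [2.0 * u for u in v]      # hypothetical stand-in for MHSA
toy_ffn = lambda v: [u + 1.0 for u in v]       # hypothetical stand-in for the FFN
ones, zeros = [1.0] * d, [0.0] * d

at_init = adaln_zero_block(x, (ones, zeros, zeros, ones, zeros, zeros), toy_attn, toy_ffn)
print(at_init == x)                             # with alpha = 0 the block returns its input
```

Whatever the sublayers compute, multiplying their output by α = 0 removes it from the residual sum, so only the skip path contributes at the start of training.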

The verdict

At 400K training steps, adaLN-Zero achieves roughly half the FID of in-context conditioning. Cross-attention is better than in-context but worse than adaLN-Zero, despite costing 15% more Gflops. Vanilla adaLN (without zero init) is also worse than adaLN-Zero, confirming that the zero initialization matters. All subsequent DiT models use adaLN-Zero.

Why does adaLN-Zero outperform vanilla adaLN?

Chapter 5: Scaling Laws

The most important result in the paper isn't the final FID number — it's the scaling behavior. Peebles and Xie train 12 DiT models spanning 4 model sizes (S, B, L, XL) and 3 patch sizes (8, 4, 2), then plot FID against Gflops.

Finding 1: More Gflops = Lower FID

Across all 12 models, there is a -0.93 correlation between model Gflops and FID-50K at 400K training steps. This is remarkably clean — almost a straight line on a log-log plot. You can predict a DiT model's quality from its Gflops alone.

Finding 2: Gflops matter more than parameters

Here's a subtle but crucial point. When you decrease the patch size (say from p=4 to p=2), you quadruple the number of tokens and thus the Gflops. But the parameter count barely changes — the transformer weights are the same, you're just processing more tokens. Yet FID improves substantially. This means compute (Gflops), not parameter count, is the true driver of quality.

Models with similar Gflops achieve similar FID regardless of how they get there (bigger model vs. smaller patches). For example, DiT-S/2 and DiT-B/4 have similar Gflops and similar FID.

Finding 3: Larger models are more compute-efficient

When you plot FID against total training compute (Gflops × batch size × steps × 3, where the factor of 3 approximates a backward pass costing about twice the forward pass), larger models reach any given FID threshold with less total compute. A small model trained for a long time is eventually overtaken by a large model trained for fewer steps. This echoes scaling studies of language models, where compute is often better spent on a larger model trained for fewer steps.
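The training-compute estimate is a single product. The sketch below evaluates it for DiT-XL/2 at the 400K-step budget used in the ablations; the function name and the choice of batch size 256 for this example are illustrative.

```python
def training_flops(model_gflops, batch_size, steps, backward_factor=3):
    """Estimated training compute: Gflops x batch size x steps x 3."""
    return model_gflops * 1e9 * batch_size * steps * backward_factor


# DiT-XL/2 (118.6 Gflops per forward pass), batch size 256, 400K steps.
print(f"{training_flops(118.6, 256, 400_000):.2e} FLOPs")
```

Plotting FID against this quantity, rather than against steps, is what lets models of different sizes be compared on equal footing.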

Training efficiency across model sizes

A key subtlety: larger DiT models are more compute-efficient at reaching a given quality level. Consider what it takes to reach FID 50: DiT-XL/2 uses more compute per step but reaches the target in fewer steps, with similar total compute. And unlike DiT-S, the XL model keeps improving beyond FID 50; it hasn't saturated. This mirrors a familiar finding from language-model scaling studies: within a fixed compute budget, a large model trained for fewer steps can beat a small model trained for many steps.

This is the paper's legacy: Before DiT, it wasn't clear that diffusion models could exhibit clean scaling laws. U-Nets don't have an obvious "make it bigger" axis, and their performance doesn't correlate as cleanly with compute. DiT shows that once you adopt the transformer architecture, diffusion models inherit the same predictable scaling that has driven progress in language modeling. This is what made DiT so influential: not just the FID number, but the promise of a clear path to ever-better generative models.

What is the correlation between DiT model Gflops and FID-50K?

Chapter 6: Results

After the scaling analysis, Peebles and Xie train their best model — DiT-XL/2 — for 7 million steps (up from the 400K used in ablations). The results speak for themselves.

256x256 ImageNet (class-conditional)

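With guidance (classifier-free guidance at cfg scale = 1.50 for LDM and DiT; the ADM-G rows use classifier guidance):

Model | FID ↓ | sFID ↓ | IS ↑ | Precision | Recall
ADM-G | 4.59 | 5.25 | 186.7 | 0.82 | 0.52
ADM-G + ADM-U | 3.94 | 6.14 | 215.8 | 0.83 | 0.53
LDM-4-G (cfg=1.50) | 3.60 | - | 247.7 | 0.87 | 0.48
DiT-XL/2-G (cfg=1.50) | 2.27 | 4.60 | 278.2 | 0.83 | 0.57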

DiT-XL/2 achieves 2.27 FID, beating the previous best of 3.60 from LDM-4 by a large margin. It also achieves the highest Inception Score (278.2) and a strong balance between Precision (0.83) and Recall (0.57).

512x512 ImageNet

DiT-XL/2 also sets a new state-of-the-art at 512x512 resolution with an FID of 3.04, outperforming ADM-G + ADM-U (3.85 FID) while being substantially more compute-efficient.

Compute efficiency

DiT-XL/2 uses 118.6 Gflops per forward pass. Compare this to the 1120 Gflops of ADM, a pixel-space U-Net. DiT achieves better FID than the baselines above at comparable or lower compute cost.

What degrades and when

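The scaling analysis reveals a clear degradation pattern: FID gets worse as Gflops fall, whether the cut comes from a smaller model (XL down to S) or a larger patch size (p=2 up to p=8).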

Concrete training numbers

Training budget: DiT-XL/2 was trained for 7M steps at batch size 256 on ImageNet (1.28M images, 1000 classes). That's ~1.79 billion images seen. At 118.6 Gflops per forward pass × 3 (one forward pass plus a backward pass that costs roughly twice as much), total training compute is approximately 118.6 × 10⁹ × 3 × 256 × 7 × 10⁶ ≈ 6.4 × 10²⁰ FLOPs. For reference, training GPT-3 took ~3.1 × 10²³ FLOPs, so DiT-XL/2 is roughly 500x cheaper. Inference: 250 sampling steps × 118.6 Gflops × 2 (CFG) = ~59,300 Gflops per image, or ~30 seconds on a single A100 GPU.

What FID does DiT-XL/2 achieve on class-conditional ImageNet 256x256 with classifier-free guidance?

Chapter 7: CFG and Sampling

DiT's strong results depend on classifier-free guidance (CFG), a technique that dramatically improves sample quality at the cost of some diversity. Let's understand how it works.

The intuition

During sampling, you want images x where the class probability p(c|x) is high — if you asked for "golden retriever," the image should clearly look like a golden retriever, not an ambiguous blob. By Bayes' rule:

∇x log p(c|x) ∝ ∇x log p(x|c) − ∇x log p(x)

The log p(c) term from Bayes' rule does not depend on x, so its gradient vanishes. What remains, the class-conditional score minus the unconditional score, points toward images that strongly belong to class c.

The CFG formula

Classifier-free guidance modifies the noise prediction during sampling:

ε̂θ(zt, c) = εθ(zt, ∅) + s · (εθ(zt, c) − εθ(zt, ∅))

Where s > 1 is the guidance scale (s = 1 recovers standard sampling). Each sampling step requires two forward passes: one conditioned on c, one unconditional (with a learned null embedding ∅). The unconditional pass is enabled by randomly dropping the class label during training (replacing it with ∅).
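The guidance formula is a linear extrapolation between the two predictions. The sketch below applies it to short toy vectors in place of real 32 × 32 × 4 noise predictions; the function name and the numbers are illustrative.

```python
def cfg_noise(eps_cond, eps_uncond, scale):
    """Classifier-free guidance: eps_uncond + scale * (eps_cond - eps_uncond)."""
    return [u + scale * (c - u) for c, u in zip(eps_cond, eps_uncond)]


eps_cond = [0.2, -0.4, 1.0]      # toy conditional prediction, eps(z_t, c)
eps_uncond = [0.0, -0.1, 0.5]    # toy unconditional prediction, eps(z_t, null)
guided = cfg_noise(eps_cond, eps_uncond, scale=1.5)
print(guided)
```

With scale = 1 the result equals the conditional prediction, and larger scales push further along the direction from the unconditional prediction to the conditional one.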

The precision-recall tradeoff

Higher guidance scale s pushes the model toward higher-fidelity but lower-diversity samples. For DiT-XL/2, CFG at s = 1.5 reduces FID by over 4x but slightly reduces diversity (recall drops from 0.67 to 0.57). This tradeoff is well-known and consistent across diffusion model architectures.

Training for CFG: During training, the class label c is randomly replaced with a learned null embedding ∅ 10% of the time. This teaches the model to produce both conditional and unconditional noise predictions, enabling CFG at inference time with no architectural changes. DiT uses 250 DDPM sampling steps with a standard linear noise schedule.
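Label dropout during training can be sketched as a per-example coin flip. In the sketch below the null class is represented by the extra index `num_classes`, which is an assumption about how the null embedding is stored, and the helper name is hypothetical.

```python
import random


def drop_labels(labels, num_classes, drop_prob=0.1, rng=random):
    """Replace each label with the null index `num_classes` with probability `drop_prob`."""
    return [num_classes if rng.random() < drop_prob else y for y in labels]


rng = random.Random(0)
labels = [rng.randrange(1000) for _ in range(10_000)]
dropped = drop_labels(labels, num_classes=1000, drop_prob=0.1, rng=rng)
frac = sum(y == 1000 for y in dropped) / len(dropped)
print(f"fraction replaced by the null class: {frac:.3f}")
```

Because the same network sees both labeled and null-labeled examples, a single set of weights serves for the conditional and the unconditional pass at sampling time.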

Inference cost breakdown

With CFG, every sampling step requires two forward passes (one conditioned, one unconditional). The full cost for DiT-XL/2: 250 steps × 2 passes × 118.6 Gflops ≈ 59,300 Gflops per image.

Compare to ADM (pixel-space U-Net): 250 steps × 2 × 1120 = 560,000 Gflops — nearly 10x more expensive for worse results. This is the fundamental efficiency win of latent diffusion: the 9.4x compute reduction from working in latent space compounds across all 250 steps.
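The per-image sampling cost is a simple product, sketched below with the figures quoted in the text. The function name is illustrative, and applying the same two-pass count to ADM is only for comparison.

```python
def sampling_gflops(steps, gflops_per_pass, passes_per_step=2):
    """Total forward-pass compute for one image; CFG needs two passes per step."""
    return steps * passes_per_step * gflops_per_pass


dit = sampling_gflops(250, 118.6)     # latent-space DiT-XL/2
adm = sampling_gflops(250, 1120)      # pixel-space ADM, same step count for comparison
print(dit, adm, round(adm / dit, 1))
```

Since the step count and the pass count are identical on both sides, the ratio is just the ratio of per-pass Gflops.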

Why does classifier-free guidance require two forward passes per sampling step?

Chapter 8: Why Transformers Win

DiT's success isn't just about one paper's results. It's about why transformers are fundamentally better suited as a backbone for scaling generative models. Let's unpack the structural advantages.

1. Clean scaling axes

Transformers have two orthogonal axes for scaling: depth (number of layers N) and width (hidden dimension d). Doubling either increases Gflops predictably. U-Nets have channels, resolution levels, attention layers at specific resolutions — but these interact in complex, non-linear ways. There's no clean "make it 2x bigger" knob.

2. Hardware efficiency

Modern GPUs and TPUs are optimized for the dense matrix multiplications that dominate transformer computation (attention QKV projections, FFN layers). U-Nets mix convolutions at various spatial resolutions with attention at lower resolutions — this heterogeneous compute profile is harder to optimize and often leaves hardware underutilized.

3. Architecture unification

If your image generator, text model, and video model all use transformers, you can share training recipes, optimization tricks, and infrastructure. The community's collective knowledge about transformer training (learning rate schedules, initialization, regularization, mixed precision) transfers directly. With U-Nets, every insight had to be re-discovered within the diffusion community.

4. Flexibility for conditioning

Transformers naturally handle variable-length sequences and cross-attention over conditioning tokens. This makes it straightforward to condition on text embeddings, multiple images, or any other modality. U-Nets require bespoke injection points for each conditioning type.

The big picture: DiT showed that the U-Net's inductive biases (multi-scale processing, skip connections) are not necessary for high-quality diffusion. The transformer's ability to learn arbitrary interactions between patches via self-attention is sufficient. And unlike the U-Net, the transformer comes with a proven playbook for scaling that has been validated across language, vision, and many other domains.

What key inductive bias of U-Nets did DiT show is NOT necessary for high-quality diffusion?

Chapter 9: Connections

DiT sits at a pivotal point in the evolution of generative models. Here's how it connects to the broader landscape.

Predecessors

Contemporaries and successors

DiT's lasting impact: Many major image and video generation systems released after DiT have adopted a transformer-based backbone. The paper's contribution wasn't just a better FID number; it changed the default choice of backbone. It showed the diffusion community that a U-Net is not required, and that the path to better generative models can follow the same scaling playbook that transformed NLP. Today, the "diffusion transformer" is a standard architecture for visual generation.

Which of these systems directly builds on DiT's transformer-based diffusion architecture?