Replace the U-Net in diffusion models with a Vision Transformer. The result: clean scaling laws, state-of-the-art image quality, and the backbone of Sora, SD3, and FLUX.
By 2022, diffusion models had become the gold standard for image generation. DALL-E 2, Imagen, Stable Diffusion — all of them produced breathtaking images. And all of them shared one thing in common beneath the surface: a U-Net backbone.
The U-Net was borrowed from medical image segmentation, adapted for diffusion by Ho et al. in DDPM (2020), and refined by Dhariwal & Nichol (ADM, 2021). For two years, every major diffusion model was essentially "same U-Net, more tricks." Nobody seriously questioned the backbone.
Meanwhile, in every other corner of deep learning, transformers were taking over. NLP. Vision. Protein folding. Reinforcement learning. The reason? Transformers scale predictably. Double the parameters, get a quantifiable performance boost. The relationship between compute and quality follows clean power laws you can plot on a graph and extrapolate from.
In the DiT paper (2023), Peebles and Xie asked whether diffusion could get the same treatment by swapping the U-Net for a transformer. There was good reason to be skeptical. The U-Net's multi-resolution design (downsampling, processing at multiple scales, upsampling with skip connections) seemed essential to generating coherent images. Removing all of that and replacing it with a flat sequence of patches processed at one resolution felt like it shouldn't work.
But it did. Spectacularly. And the paper's most important contribution wasn't even the final FID score — it was showing that diffusion models could exhibit the same clean scaling laws that had driven transformers to dominate every other domain.
The U-Net has many interconnected, bespoke components at different spatial resolutions. The DiT is a flat, repeating structure — clean and scalable.
To understand why DiT is a breakthrough, you first need to understand exactly how the U-Net worked in diffusion models — and precisely where it falls short.
The U-Net is an encoder-decoder architecture. The encoder repeatedly downsamples the feature map (halving spatial dimensions each level) while increasing channels. The decoder does the reverse — upsampling back to the original resolution. At each level, skip connections pass the encoder's features directly to the corresponding decoder layer, preserving fine-grained spatial detail.
In diffusion, the U-Net takes a noisy image [B, C, H, W] and a timestep embedding, and outputs a predicted noise map of the same shape. The timestep is injected via addition into each ResNet block. Cross-attention layers at low-resolution levels handle text conditioning in models like Stable Diffusion.
Here's where the U-Net runs into trouble. To make a U-Net bigger, you have several options: increase channel counts, add more resolution levels, add more ResNet blocks per level, or add attention at more resolutions. But none of these is a clean, predictable scaling axis.
Increase channel counts? Compute grows as the square of channels. Add resolution levels? The model topology changes, requiring architectural redesign. Add attention everywhere? It becomes a different architecture. The ADM paper (Dhariwal & Nichol, 2021) found that more channels usually helps, more attention sometimes helps — but there was no clean formula predicting exactly how much.
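The quadratic cost of widening can be checked with simple arithmetic. The sketch below counts multiply-accumulates for one dense 2D convolution. The helper name `conv_macs` and the 3×3 kernel are illustrative assumptions, not taken from any particular U-Net.

```python
def conv_macs(height, width, c_in, c_out, kernel=3):
    """Multiply-accumulates for one dense 2D convolution (stride 1, same padding)."""
    return height * width * kernel * kernel * c_in * c_out

# Doubling both input and output channels at a fixed resolution
base = conv_macs(32, 32, 256, 256)
wide = conv_macs(32, 32, 512, 512)
ratio = wide / base
print(ratio)
```

Doubling the width multiplies the cost of that layer by four, and the multiplier differs at every resolution level because each level has its own height, width, and channel count.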
There's also a less-obvious problem. When the entire community uses one backbone, all collective knowledge is tied to that architecture. Training tricks discovered for U-Nets don't transfer to language models. Insights from ViT pretraining don't apply. The community has to rediscover everything from scratch within the diffusion U-Net world.
The encoder-decoder structure with skip connections at multiple resolutions. Each level has a different spatial resolution and channel count.
DiT doesn't operate on raw pixels. It builds on the Latent Diffusion framework — the same one behind Stable Diffusion. First, a pretrained VAE compresses the image into a compact latent space. Then diffusion happens entirely in that smaller space.
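A Variational Autoencoder (frozen during DiT training) maps a 256×256×3 RGB image to a 32×32×4 latent z_0. That's an 8× spatial downsampling. The forward diffusion process adds Gaussian noise to this latent:
z_t = √(ᾱ_t) · z_0 + √(1 − ᾱ_t) · ε,  with ε ~ N(0, I)
Here ᾱ_t is the cumulative product of (1 − β_s) over the noise schedule up to step t, so larger t means less signal and more noise.
DiT's job is to predict the noise ε given z_t and the timestep t. After denoising, the clean latent z_0 is decoded back to pixels by the VAE decoder.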
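A minimal NumPy sketch of the forward noising step may make this concrete. It assumes a linear β schedule from 1e-4 to 2e-2 over 1000 steps, a common DDPM default; the names `make_alpha_bar` and `q_sample` are illustrative, and a random array stands in for a real VAE latent.

```python
import numpy as np

def make_alpha_bar(num_steps=1000, beta_start=1e-4, beta_end=2e-2):
    """Cumulative signal level for an assumed linear beta schedule."""
    betas = np.linspace(beta_start, beta_end, num_steps)
    return np.cumprod(1.0 - betas)

def q_sample(z0, t, alpha_bar, rng):
    """Noise a clean latent z0 to timestep t; also return the noise used."""
    eps = rng.standard_normal(z0.shape)
    z_t = np.sqrt(alpha_bar[t]) * z0 + np.sqrt(1.0 - alpha_bar[t]) * eps
    return z_t, eps

rng = np.random.default_rng(0)
alpha_bar = make_alpha_bar()
z0 = rng.standard_normal((4, 32, 32))      # stand-in for a 32x32x4 VAE latent
z_t, eps = q_sample(z0, t=523, alpha_bar=alpha_bar, rng=rng)
print(z_t.shape, float(alpha_bar[523]))
```

The returned `eps` is exactly the target the network is trained to predict from `z_t` and `t`.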
Now comes the key idea borrowed from Vision Transformers (ViT). The 32×32×4 latent is divided into non-overlapping patches of size p×p. Each patch is flattened and linearly projected to a d-dimensional embedding. This gives a sequence of T tokens that the transformer can process.
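Let's work through the exact numbers. With patch size p=2, the 32×32 latent becomes a 16×16 grid of patches, so T = (32/2)² = 256 tokens, each holding 2×2×4 = 16 values. Larger patches trade tokens for values per patch: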
| Patch size p | Tokens T | Values per patch | FID-50K (DiT-XL, 400K steps) |
|---|---|---|---|
| 2 | 256 | 16 | 19.47 (best) |
| 4 | 64 | 64 | 43.01 |
| 8 | 16 | 256 | 106.41 |
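The token counts in the table can be reproduced with a few array reshapes. This is a sketch under assumptions: the function name `patchify` and the ordering of values inside each patch are illustrative, and the linear projection to d dimensions that follows patchification is omitted.

```python
import numpy as np

def patchify(latent, p):
    """[C, H, W] -> [T, p*p*C], with T = (H/p) * (W/p) non-overlapping patches."""
    c, h, w = latent.shape
    assert h % p == 0 and w % p == 0, "patch size must divide the latent size"
    x = latent.reshape(c, h // p, p, w // p, p)
    x = x.transpose(1, 3, 2, 4, 0)          # [H/p, W/p, p, p, C]
    return x.reshape((h // p) * (w // p), p * p * c)

latent = np.zeros((4, 32, 32))
shapes = {p: patchify(latent, p).shape for p in (2, 4, 8)}
print(shapes)
```

Each row of the result is one token before projection, so the first dimension is T and the second is the number of values per patch.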
The core building block of DiT is a modified transformer layer. It looks almost identical to a standard ViT block — multi-head self-attention followed by a feed-forward network, with layer normalization — except for one critical change: how conditioning information enters.
In a standard transformer, the layer norm has fixed, learned scale (γ) and shift (β) parameters. In DiT, these parameters are computed dynamically from the conditioning vector (timestep + class label) at every forward pass. This is called adaptive layer normalization (adaLN).
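Given input x (the token sequence) and conditioning vector c:
h = x + α₁ · MHSA(adaLN(γ₁, β₁, x))
out = h + α₂ · FFN(adaLN(γ₂, β₂, h))
Where adaLN(γ, β, x) = γ · LayerNorm(x) + β. The parameters γ₁, β₁, α₁, γ₂, β₂, α₂ are all regressed from the conditioning vector c, and each one is a d-dimensional vector applied channel-wise to every token. Each block gets its own projection of the shared conditioning vector.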
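The block structure can be sketched in a few lines of NumPy. This is an illustration under assumptions, not the reference implementation: `dit_block`, `w_mod`, and `b_mod` are hypothetical names, the modulation is a single linear map from c to six d-dimensional vectors, and the two sub-layers are passed in as placeholder callables rather than real attention and FFN layers.

```python
import numpy as np

def layer_norm(x, eps=1e-6):
    """Normalize each token over its feature dimension (no learned affine)."""
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)

def adaln(gamma, beta, x):
    """adaLN(gamma, beta, x) = gamma * LayerNorm(x) + beta, applied channel-wise."""
    return gamma * layer_norm(x) + beta

def dit_block(x, c, w_mod, b_mod, attn, ffn):
    """x: [T, d] tokens, c: [d] conditioning, w_mod: [d, 6d], b_mod: [6d]."""
    g1, b1, a1, g2, b2, a2 = np.split(c @ w_mod + b_mod, 6)
    h = x + a1 * attn(adaln(g1, b1, x))      # gated attention branch
    return h + a2 * ffn(adaln(g2, b2, h))    # gated feed-forward branch

rng = np.random.default_rng(0)
T, d = 4, 8
x = rng.standard_normal((T, d))
c = rng.standard_normal(d)
w_mod = rng.standard_normal((d, 6 * d)) * 0.1
b_mod = np.zeros(6 * d)
w_attn = rng.standard_normal((d, d)) * 0.1
w_ffn = rng.standard_normal((d, d)) * 0.1
attn = lambda h: h @ w_attn                  # placeholder, not real attention
ffn = lambda h: np.tanh(h @ w_ffn)           # placeholder feed-forward
out = dit_block(x, c, w_mod, b_mod, attn, ffn)
print(out.shape)
```

Because the sub-layers are arguments, the same function shows how conditioning enters regardless of what the attention and FFN layers actually compute.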
DiT is intentionally minimal. Compared to a standard ViT block, it removes nothing structural; it only modifies the layer norm to be adaptive. There are no downsampling or upsampling stages and no skip connections between resolution levels.
All 28 blocks (in DiT-XL) process the same 256-token sequence at the same 1152-dimensional hidden size. The same-resolution, same-width design is what gives DiT its clean scaling behavior.
The data flow through a single DiT block. γ, β, α are all computed from the conditioning vector — not learned as fixed parameters.
The paper ablates four different ways to inject conditioning into transformer blocks. The winner — adaLN-Zero — achieves roughly half the FID of the simplest approach while adding negligible compute. Understanding why it wins requires understanding the zero initialization trick.
1. In-Context: Append the timestep and class embeddings as two extra tokens. Process them alongside image tokens. Overhead: ~0%. Simple, but the conditioning has limited influence — it's just two tokens in a sea of 256.
2. Cross-Attention: Add a cross-attention layer after each self-attention layer, so image tokens attend to the conditioning tokens. Overhead: ~15% extra Gflops. This mirrors the encoder-decoder attention of the original Transformer and is the mechanism Stable Diffusion uses for text conditioning.
3. adaLN: Replace fixed layer norm parameters with ones computed from the conditioning vector. Overhead: ~0% extra Gflops (the projection runs once per image, not once per token). Better than in-context and cross-attention, because conditioning modulates every single token's normalization.
4. adaLN-Zero: Same as adaLN, plus: regress additional scaling gates α₁, α₂ that multiply the residual branches. Initialize these gates to zero. Overhead: ~0%. Best by a large margin.
Think about what happens at the start of training. Random weights produce random outputs. If the MHSA output is random noise, adding it back via the residual connection corrupts x. This makes the first few training steps chaotic — gradients are large and unstable.
With α initialized to zero: output = x + 0 × MHSA(x) = x. The block is exactly the identity function. Gradients flow cleanly through the residual path from the very first step. The network has a stable foundation to learn from, and α grows gradually as training teaches the block what to contribute.
Let's trace an example token x = [1.2, -0.3, 0.8] through a DiT block at initialization:
| Step | Operation | Value |
|---|---|---|
| 1 | LayerNorm(x) | [1.00, -1.37, 0.37] (normalized) |
| 2 | γ₁ · LN(x) + β₁ (γ=1, β=0 at init) | [1.00, -1.37, 0.37] |
| 3 | MHSA output (random at init) | [0.47, -1.23, 0.91] |
| 4 | α₁ = 0 (zero init) → 0 × MHSA | [0, 0, 0] |
| 5 | Residual: x + 0 = x | [1.2, -0.3, 0.8] (unchanged!) |
| 6 | Same for FFN branch with α₂ = 0 | [1.2, -0.3, 0.8] (still unchanged) |
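The trace can be recomputed numerically. The sketch below treats the gate as a scalar and reuses the table's example token and stand-in MHSA output; the second branch vector is an arbitrary placeholder introduced here for illustration.

```python
import numpy as np

x = np.array([1.2, -0.3, 0.8])                 # example token from the trace
ln = (x - x.mean()) / x.std()                  # LayerNorm over the three features
modulated = 1.0 * ln + 0.0                     # gamma = 1, beta = 0 at init

mhsa_out = np.array([0.47, -1.23, 0.91])       # stand-in for a random MHSA output
ffn_out = np.array([-0.25, 0.60, 0.10])        # placeholder FFN output (assumed)
alpha1 = alpha2 = 0.0                          # zero-initialized gates

after_attn = x + alpha1 * mhsa_out
after_ffn = after_attn + alpha2 * ffn_out
print(np.round(ln, 2), after_ffn)
```

Setting `alpha1` or `alpha2` to a small non-zero value shows how the branch output starts leaking into the residual stream once training moves the gates away from zero.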
After a few training steps, α₁ and α₂ become small non-zero values (say 0.01, 0.02) and the block begins to learn. Later in training, say by 100K steps, the gates might reach around 0.1-0.3, at which point the blocks are contributing meaningfully. These magnitudes are illustrative rather than measured.
FID-50K at 400K training steps. Lower is better. All models are DiT-XL/2.
Every DiT block needs to know two things: how noisy is this latent? (the timestep t) and what class should this image be? (the class label c). These aren't fed as pixel values — they go through a carefully designed embedding pipeline before influencing each block's normalization parameters.
The timestep t is a single integer (say, 523 out of 1000). To make it useful for a neural network, DiT embeds it using sinusoidal frequency embeddings, the same technique as the positional encoding in the original Transformer paper. Different frequencies capture different timescales: the embedding holds sin(t · ω_i) and cos(t · ω_i) for 128 frequencies ω_i spaced geometrically from 1 down to about 1/10000.
This 256-dimensional sinusoidal embedding is then passed through a 2-layer MLP (with SiLU activation) to produce a d-dimensional vector: embed_t: int → ℝᵈ.
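The timestep pipeline can be sketched as follows. The frequency spacing, the cosine-then-sine ordering, and the `max_period` of 10000 are assumptions in the style of common diffusion codebases; the MLP weights are random placeholders and d is shrunk to 16 to keep the example small.

```python
import numpy as np

def timestep_embedding(t, dim=256, max_period=10000.0):
    """Map integer timesteps [B] to sinusoidal features [B, dim]."""
    half = dim // 2
    freqs = np.exp(-np.log(max_period) * np.arange(half) / half)
    args = np.asarray(t, dtype=np.float64)[..., None] * freqs
    return np.concatenate([np.cos(args), np.sin(args)], axis=-1)

def silu(x):
    return x / (1.0 + np.exp(-x))

def embed_t(t, w1, b1, w2, b2):
    """Sinusoidal features followed by a 2-layer MLP with SiLU."""
    h = silu(timestep_embedding(t) @ w1 + b1)
    return h @ w2 + b2

rng = np.random.default_rng(0)
d = 16                                         # toy hidden size (DiT-XL uses 1152)
w1, b1 = rng.standard_normal((256, d)) * 0.02, np.zeros(d)
w2, b2 = rng.standard_normal((d, d)) * 0.02, np.zeros(d)
emb = timestep_embedding(np.array([0, 523, 999]))
t_vec = embed_t(np.array([0, 523, 999]), w1, b1, w2, b2)
print(emb.shape, t_vec.shape)
```

Nearby timesteps get similar embeddings at low frequencies and distinct ones at high frequencies, which gives the MLP both coarse and fine information about the noise level.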
The class label c (an integer 0–999 for ImageNet) is embedded via a standard learned embedding table: a matrix of shape [1000, d]. Each class gets its own d-dimensional vector, learned from scratch during training. embed_c: int → ℝᵈ.
For unconditional generation (needed for classifier-free guidance), there's one extra learned embedding: the null embedding ∅, used when the class label is dropped during training (10% of the time).
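Label dropout with a null embedding can be sketched like this. The table layout (null class stored as an extra last row), the names `drop_labels` and `embed_classes`, and the toy hidden size are assumptions made for illustration.

```python
import numpy as np

NUM_CLASSES = 1000
NULL_CLASS = NUM_CLASSES                  # assumed index of the extra "null" row

def drop_labels(labels, drop_prob, rng):
    """Randomly replace labels with the null class, for classifier-free guidance."""
    drop = rng.random(labels.shape) < drop_prob
    return np.where(drop, NULL_CLASS, labels)

def embed_classes(labels, table):
    """Look up one learned row per label: [B] -> [B, d]."""
    return table[labels]

rng = np.random.default_rng(0)
d = 16                                    # toy hidden size
table = rng.standard_normal((NUM_CLASSES + 1, d)) * 0.02
labels = rng.integers(0, NUM_CLASSES, size=10_000)
train_labels = drop_labels(labels, drop_prob=0.1, rng=rng)
c_emb = embed_classes(train_labels, table)
dropped_fraction = float(np.mean(train_labels == NULL_CLASS))
print(c_emb.shape, dropped_fraction)
```

At sampling time the same null row is what the model receives for the unconditional prediction.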
The two embeddings are simply added: c = embed_t(t) + embed_c(class label).
This combined conditioning vector c (shape [d]) is shared across all 28 DiT blocks. Each block has its own linear projection that maps c to six d-dimensional vectors: γ₁, β₁, α₁, γ₂, β₂, α₂. The projection runs once per image rather than once per token, so it adds almost no compute, and it's all the architectural machinery needed to make conditioning work.
| Component | Shape | Parameters |
|---|---|---|
| Sinusoidal embed (fixed) | n/a | 0 (no learning) |
| Timestep MLP (2-layer) | 256→d→d | ~1.6M |
| Class embed table | (1000+1)×d | ~1.2M |
| Per-block projection (×28) | d→6d | ~223M total |
| Total conditioning | | ~226M / 675M total |
In parameter terms the conditioning path is not small: with d = 1152, each d→6d projection holds about 8M weights, and across 28 blocks that is roughly a third of DiT-XL's 675M parameters. In compute terms it is close to free, because those projections process one vector per image while attention and the FFN process all 256 tokens. This is the point: conditioning is an afterthought in terms of compute. The transformer does the heavy lifting.
Trace the timestep and class label from raw integers to the γ, β, α parameters that modulate each DiT block.
Now we can trace a complete forward pass through DiT, from a noisy latent z_t all the way to a predicted noise map. Let's walk through every operation with exact tensor shapes for DiT-XL/2 (patch size p=2, hidden dim d=1152, 28 blocks).
The final linear layer maps each 1152-dim token to a 32-dim vector. Why 32? Each token corresponds to a 2×2 patch with 4 latent channels. The output predicts both noise and variance for that patch: 2×2×4 channels for ε + 2×2×4 channels for Σ = 2×2×8 = 32 values per token.
Unpatchifying is the exact reverse of patchifying: rearrange the [256, 32] output back to [32, 32, 8] spatial format, then split the last dimension into ε̂ [32, 32, 4] and Σ̂ [32, 32, 4].
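A NumPy sketch of unpatchifying and splitting follows. It uses a channel-first layout ([8, 32, 32] rather than [32, 32, 8]) and assumes each token stores its patch in (row, column, channel) order; the function name `unpatchify` and that ordering are illustrative choices.

```python
import numpy as np

def unpatchify(tokens, p, c_out):
    """[T, p*p*c_out] -> [c_out, H, W] for a square grid of T patches."""
    t = tokens.shape[0]
    g = int(round(t ** 0.5))                  # patches per side
    assert g * g == t, "token count must be a perfect square"
    x = tokens.reshape(g, g, p, p, c_out)     # [grid_row, grid_col, p_row, p_col, C]
    x = x.transpose(4, 0, 2, 1, 3)            # [C, grid_row, p_row, grid_col, p_col]
    return x.reshape(c_out, g * p, g * p)

rng = np.random.default_rng(0)
tokens = rng.standard_normal((256, 32))       # final linear layer output, one sample
out = unpatchify(tokens, p=2, c_out=8)        # [8, 32, 32]
eps_hat, sigma_hat = out[:4], out[4:]         # noise and variance predictions
print(out.shape, eps_hat.shape, sigma_hat.shape)
```

Token 17, for example, sits at grid row 1, column 1, so its 32 values fill pixels (2..3, 2..3) across all 8 output channels.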
Complete DiT forward pass (simplified), in Python:

```python
def forward(self, z_t, t, c):
    # z_t: [B, 4, 32, 32]   t: [B]   c: [B] (class labels)
    B = z_t.shape[0]

    # 1. Patchify: [B, 4, 32, 32] → [B, 256, 1152]
    x = self.patch_embed(z_t)                # linear: (p*p*4) → d
    x = x + self.pos_embed                   # add fixed sine-cosine embed

    # 2. Conditioning vector: [B, d]
    cond = self.t_embed(t) + self.c_embed(c)

    # 3. 28 DiT blocks
    for block in self.blocks:
        x = block(x, cond)                   # adaLN-Zero inside

    # 4. Final norm + linear decode
    x = self.final_norm(x, cond)             # one more adaLN
    x = self.final_linear(x)                 # [B, 256, 32]

    # 5. Unpatchify: [B, 256, 32] → [B, 8, 32, 32]
    x = self.unpatchify(x)

    # 6. Split noise + variance
    eps_hat, sigma_hat = x.chunk(2, dim=1)
    return eps_hat, sigma_hat                # each [B, 4, 32, 32]
```
Every stage of the forward pass, with exact shapes for DiT-XL/2 (p=2, d=1152, N=28).
The most important result in the DiT paper isn't the final FID score. It's what happens when you plot FID against Gflops across 12 different model variants.
Peebles and Xie train every combination of 4 model sizes (S, B, L, XL) × 3 patch sizes (2, 4, 8), all for 400K training steps on ImageNet 256×256. Then they ask: what predicts quality?
Across all 12 models, there is a -0.93 correlation between model Gflops and FID-50K. On a log-log plot, the 12 points fall almost on a straight line. This is remarkable. It means forward-pass compute alone explains most of the variation in a DiT model's image quality, whether that compute comes from more parameters or from more tokens.
This is the transformer's greatest strength: predictable scaling. No comparably clean relationship had been demonstrated for diffusion U-Nets. For DiT, it's practically a law.
Here's a subtle but crucial point. When you decrease patch size from p=4 to p=2, you quadruple T (tokens) and thus at least quadruple the Gflops: the linear layers scale with T and self-attention scales with T². But parameter count barely changes. The transformer weights are identical; you're just processing more tokens per forward pass.
Yet FID improves substantially. This means compute (Gflops) per image, not total parameter count, is the true driver of quality. DiT-S/2 (33M params, 6.06 Gflops) achieves better FID than DiT-L/8 (458M params, 5.01 Gflops), despite having 14× fewer parameters, because it uses more compute per forward pass.
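A back-of-envelope estimate shows why shrinking the patch size raises compute without adding parameters. The sketch assumes the reported Gflops count multiply-accumulates, uses an MLP expansion ratio of 4, and ignores embeddings, layer norms, and conditioning projections; `transformer_gmacs` and `num_tokens` are hypothetical helper names.

```python
def num_tokens(latent_size=32, p=2):
    """Tokens produced by p x p patches on a square latent."""
    return (latent_size // p) ** 2

def transformer_gmacs(layers, d, tokens, mlp_ratio=4):
    """Approximate multiply-accumulates (in billions) for one forward pass."""
    linear = (4 + 2 * mlp_ratio) * tokens * d * d   # QKV, output proj, two MLP layers
    attention = 2 * tokens * tokens * d             # QK^T scores and weighted sum of V
    return layers * (linear + attention) / 1e9

xl2 = transformer_gmacs(layers=28, d=1152, tokens=num_tokens(p=2))
xl4 = transformer_gmacs(layers=28, d=1152, tokens=num_tokens(p=4))
s2 = transformer_gmacs(layers=12, d=384, tokens=num_tokens(p=2))
print(round(xl2, 1), round(xl4, 1), round(s2, 1), round(xl2 / xl4, 2))
```

The XL/2 estimate lands within about 1% of the 118.6 Gflops figure, and the p=2 to p=4 ratio comes out slightly above 4 because of the quadratic attention term.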
When you plot FID against total training compute (Gflops × steps), larger models reach any FID threshold faster. A large model trained for fewer steps beats a small model trained longer. This matters for production: if you have a fixed compute budget, train the largest model you can afford, not a small model for longer.
| Model | Layers | Hidden d | Heads | Params | Gflops (p=2) | FID (400K steps) |
|---|---|---|---|---|---|---|
| DiT-S/2 | 12 | 384 | 6 | 33M | 6.06 | ~68 |
| DiT-B/2 | 12 | 768 | 12 | 130M | 23.01 | ~43 |
| DiT-L/2 | 24 | 1024 | 16 | 458M | 80.71 | ~23 |
| DiT-XL/2 | 28 | 1152 | 16 | 675M | 118.64 | ~19 → 2.27* |
*DiT-XL/2 achieves FID 9.62 without CFG, FID 2.27 with CFG (guidance scale 1.50) after 7M training steps.
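Classifier-free guidance combines two noise predictions per step, one with the class label and one with the null label. The sketch below shows the standard combination rule; the function name `cfg_combine` is illustrative and random arrays stand in for the model's two outputs.

```python
import numpy as np

def cfg_combine(eps_cond, eps_uncond, scale):
    """Push the prediction away from the unconditional one by `scale`."""
    return eps_uncond + scale * (eps_cond - eps_uncond)

rng = np.random.default_rng(0)
eps_cond = rng.standard_normal((4, 32, 32))     # prediction given the class label
eps_uncond = rng.standard_normal((4, 32, 32))   # prediction given the null label
guided = cfg_combine(eps_cond, eps_uncond, scale=1.5)
print(guided.shape)
```

A scale of 1.0 recovers the plain conditional prediction, and values above 1.0 trade sample diversity for stronger class adherence.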
All 12 model variants at 400K steps. Each point is a (size, patch) combination. The trend is nearly linear on this log-log plot — correlation -0.93.
Here's the payoff. Watch DiT denoise a latent from pure noise to structured signal, step by step. Choose a model size, a class label, and a guidance scale — then step through the denoising process and see how the image emerges.
The visualization shows the 16×16 patch grid (for p=2, a 32×32 latent becomes 256 patches arranged in a 16×16 grid). Color encodes signal-to-noise ratio: cool blue = mostly noise, warm orange = mostly signal. As t decreases from 250→0, patches progressively "lock in" their values.
DiT (2023) didn't just introduce a new architecture. It changed the entire trajectory of visual generation. Here's how the ideas evolved — and where they landed.
| Model | Kept from DiT | Changed | Impact |
|---|---|---|---|
| SD3/MMDiT | adaLN-Zero blocks, patchification, latent space | Separate weights for text and image streams, joined in one attention operation over the concatenated tokens | Better text rendering, stronger prompt following |
| FLUX | MMDiT architecture, latent space | Flow matching loss (replaces DDPM), distillation for fast sampling | Few-step (about 4) inference in the distilled variant |
| Sora | DiT blocks, patchification | 3D spatiotemporal patches, variable resolution & duration, massive scale | Coherent multi-second video generation |
| SiT | DiT blocks | Stochastic interpolants instead of DDPM | Better theoretical understanding of the diffusion objective |
The U-Net's core inductive bias was: you need to process images at multiple spatial scales simultaneously, with skip connections to preserve detail. DiT showed that this bias is not required. Self-attention over a flat patch sequence is sufficient, and the model can learn multi-scale reasoning without the architectural scaffolding.
This matters because it means the transformer's core strength — attending to arbitrary pairs of tokens — is all you need. There's no special sauce in the U-Net that couldn't be learned from data.
"The transformer didn't just replace the U-Net — it opened diffusion models to the same scaling playbook that transformed language modeling. We no longer have to ask 'will making it bigger help?' The answer is yes, quantifiably, predictably, always."
— Inspired by Peebles & Xie, 2023