Deep Learning Architecture

Diffusion Transformer (DiT)

Replace the U-Net in diffusion models with a Vision Transformer. The result: clean scaling laws, state-of-the-art image quality, and the backbone of Sora, SD3, and FLUX.

Prerequisites: Diffusion models (see our Diffusion lesson) + basic transformer intuition. No measure theory required.

Chapter 0: Why Replace the U-Net?

By 2022, diffusion models had become the gold standard for image generation. DALL-E 2, Imagen, Stable Diffusion — all of them produced breathtaking images. And all of them shared one thing in common beneath the surface: a U-Net backbone.

The U-Net was borrowed from medical image segmentation, adapted for diffusion by Ho et al. in DDPM (2020), and refined by Dhariwal & Nichol (ADM, 2021). For two years, every major diffusion model was essentially "same U-Net, more tricks." Nobody seriously questioned the backbone.

Meanwhile, in every other corner of deep learning, transformers were taking over. NLP. Vision. Protein folding. Reinforcement learning. The reason? Transformers scale predictably. Double the parameters, get a quantifiable performance boost. The relationship between compute and quality follows clean power laws you can plot on a graph and extrapolate from.

The question few were asking: what if we replaced the U-Net in a diffusion model with a standard Vision Transformer? Would it work? Would it scale? Peebles and Xie (2023) asked this question systematically, and the answer changed the entire field.

There was good reason to be skeptical. The U-Net's multi-resolution design — downsampling, processing at multiple scales, upsampling with skip connections — seemed essential to generating coherent images. Removing all of that and replacing it with a flat sequence of patches processed at one resolution felt like it shouldn't work.

But it did. Spectacularly. And the paper's most important contribution wasn't even the final FID score — it was showing that diffusion models could exhibit the same clean scaling laws that had driven transformers to dominate every other domain.

Architecture Complexity: U-Net vs Transformer

The U-Net has many interconnected, bespoke components at different spatial resolutions. The DiT is a flat, repeating structure — clean and scalable.


Chapter 1: The U-Net Problem

To understand why DiT is a breakthrough, you first need to understand exactly how the U-Net worked in diffusion models — and precisely where it falls short.

The U-Net is an encoder-decoder architecture. The encoder repeatedly downsamples the feature map (halving spatial dimensions each level) while increasing channels. The decoder does the reverse — upsampling back to the original resolution. At each level, skip connections pass the encoder's features directly to the corresponding decoder layer, preserving fine-grained spatial detail.

In diffusion, the U-Net takes a noisy image [B, C, H, W] and a timestep embedding, and outputs a predicted noise map of the same shape. The timestep embedding is injected into every ResNet block. DDPM adds it to the block's activations. ADM instead uses it to predict a scale and shift for group normalization. In models like Stable Diffusion, text conditioning enters through cross-attention layers placed at several resolution levels of the U-Net.

Why skip connections? After heavy downsampling, fine spatial detail is lost in the bottleneck. Skip connections act as information highways — the decoder gets access to both the low-level features (from encoder) and the high-level compressed features (from the bottleneck). This is why U-Nets produce crisp, spatially coherent outputs.

The Scaling Problem

Here's where the U-Net runs into trouble. To make a U-Net bigger, you have several options: increase channel counts, add more resolution levels, add more ResNet blocks per level, or add attention at more resolutions. But none of these is a clean, predictable scaling axis.

Increase channel counts? Compute grows as the square of channels. Add resolution levels? The model topology changes, requiring architectural redesign. Add attention everywhere? It becomes a different architecture. The ADM paper (Dhariwal & Nichol, 2021) found that more channels usually helps, more attention sometimes helps — but there was no clean formula predicting exactly how much.

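The transformer scaling formula: for a standard transformer with N layers, hidden dimension d and sequence length T, the costs are roughly as follows.

- Params ≈ 12·N·d².
- Forward cost ≈ 12·N·d²·T multiply-adds in the linear layers, plus ≈ 2·N·T²·d in attention.

DiT's adaLN modulation adds about 6·d² more parameters per block. DiT's reported Gflops appear to count one multiply-add as one operation; with that convention the formula reproduces the reported 118.6 Gflops for DiT-XL/2.

Want 4× more compute? Double d. This simple math doesn't exist for U-Nets: their performance doesn't correlate neatly with any single parameter.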
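To check these formulas against DiT's own numbers, here is a small estimator. It is a sketch that rests on three assumptions:

- the MLP ratio is 4;
- one multiply-add counts as one operation;
- the embedding, output and adaLN layers are ignored.

`transformer_gmacs` and `dit_tokens` are hypothetical helper names.

```python
def transformer_gmacs(n_layers: int, d: int, tokens: int) -> float:
    """Approximate forward cost in billions of multiply-adds.

    Linear layers: QKV + output projection (4*d*d) and a 4x-wide MLP (8*d*d)
    cost 12*d*d multiply-adds per token per layer.
    Attention: Q @ K^T and (weights @ V) each cost tokens*tokens*d per layer.
    """
    linear = 12 * n_layers * d * d * tokens
    attention = 2 * n_layers * tokens * tokens * d
    return (linear + attention) / 1e9


def dit_tokens(latent_size: int = 32, patch: int = 2) -> int:
    """Token count T = (I / p)^2 for an I x I latent."""
    return (latent_size // patch) ** 2


# (layers, hidden size) for the four DiT model sizes
CONFIGS = {"S": (12, 384), "B": (12, 768), "L": (24, 1024), "XL": (28, 1152)}

if __name__ == "__main__":
    for name, (n, d) in CONFIGS.items():
        for p in (2, 4, 8):
            t = dit_tokens(32, p)
            print(f"DiT-{name}/{p}: T={t:3d}  ~{transformer_gmacs(n, d, t):7.2f} G")
```

For DiT-XL/2 this gives about 118.4 G, close to the reported 118.6 Gflops. The remaining gap is roughly the per-block adaLN modulation, about 0.2 G, which this estimate leaves out.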

The Architecture Monoculture Risk

There's also a less obvious problem. When the entire community uses one backbone, its collective knowledge is tied to that architecture. Training tricks discovered for U-Nets don't transfer to language models, and insights from ViT pretraining don't carry over. Much has to be rediscovered from scratch within the diffusion U-Net world.

U-Net Architecture in Diffusion

The encoder-decoder structure with skip connections at multiple resolutions. Each level has a different spatial resolution and channel count.


Chapter 2: Patchify the Latent

DiT doesn't operate on raw pixels. It builds on the Latent Diffusion framework — the same one behind Stable Diffusion. First, a pretrained VAE compresses the image into a compact latent space. Then diffusion happens entirely in that smaller space.

Step 1: VAE Compression

A Variational Autoencoder (frozen during DiT training) maps a 256×256×3 RGB image to a 32×32×4 latent. That's an 8× spatial downsampling. The forward diffusion process adds Gaussian noise to this latent:

$$ z_t = \sqrt{\bar\alpha_t}\, z_0 + \sqrt{1-\bar\alpha_t}\,\varepsilon, \qquad \varepsilon \sim \mathcal{N}(0, I) $$

DiT's job is to predict the noise ε given z_t, the timestep t and, for class-conditional DiT, the class label. After denoising, the clean latent z_0 is decoded back to pixels by the VAE decoder.
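Here is a minimal sketch of the forward process. It assumes the linear β schedule from DDPM, with β running from 1e-4 to 0.02 over 1000 steps, which is also DiT's default. It treats the latent as a flat list of floats. `alpha_bar` and `q_sample` are hypothetical helper names, not DiT's API.

```python
import math
import random


def alpha_bar(num_steps=1000, beta_start=1e-4, beta_end=0.02):
    """Cumulative products abar_t = prod_{s<=t} (1 - beta_s) for a linear beta schedule."""
    out, prod = [], 1.0
    for i in range(num_steps):
        beta = beta_start + (beta_end - beta_start) * i / (num_steps - 1)
        prod *= 1.0 - beta
        out.append(prod)
    return out


def q_sample(z0, t, abar, rng=random):
    """Noise a clean latent z0 (flat list) to step t.

    Returns (z_t, eps); eps is the regression target for the denoiser.
    """
    eps = [rng.gauss(0.0, 1.0) for _ in z0]
    a, b = math.sqrt(abar[t]), math.sqrt(1.0 - abar[t])
    return [a * z + b * e for z, e in zip(z0, eps)], eps


if __name__ == "__main__":
    abar = alpha_bar()
    z0 = [0.5] * (4 * 32 * 32)  # stand-in for a flattened 4x32x32 VAE latent
    zt, eps = q_sample(z0, t=523, abar=abar, rng=random.Random(0))
    print(f"abar[523] = {abar[523]:.4f}, first noisy value = {zt[0]:.4f}")
```

During training, the denoiser sees (z_t, t, label) and is regressed onto the returned eps.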

Why latent space? Pixel-space diffusion on 256×256 images is massively expensive: ADM uses 1120 Gflops per forward pass. Operating on 32×32×4 latents instead cuts the spatial area by 64×. DiT-XL/2 achieves better quality than ADM at only 118.6 Gflops. The VAE handles pixel-level detail; DiT only needs to learn high-level structure.

Step 2: Patchify

Now comes the key idea borrowed from Vision Transformers (ViT). The 32×32×4 latent is divided into non-overlapping patches of size p×p. Each patch is flattened and linearly projected to a d-dimensional embedding. This gives a sequence of T tokens that the transformer can process.

$$ T = \left(\frac{I}{p}\right)^2, \qquad I = 32 \text{ (latent spatial size)} $$

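Let's work through the exact numbers. With patch size p=2, T = (32/2)² = 16² = 256 tokens. Each patch holds 2×2×4 = 16 values before it is projected to d dimensions.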

Patch size is a compute-quality tradeoff. Halving p quadruples T (since T = (I/p)²). Self-attention costs O(T²), so halving p increases attention compute by 16×. The linear layers grow only 4×, and they dominate DiT's total cost at these sequence lengths, so total Gflops roughly quadruple. DiT uses p=2 for its best results, giving 256 tokens: manageable, but rich in spatial detail. With p=8 you get only 16 tokens: fast, but with much worse FID.
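Patch size p | Tokens T | Values per patch | FID-50K (DiT-XL, 400K steps)
2 | 256 | 16 | 19.5 (best)
4 | 64 | 64 | 43.0
8 | 16 | 256 | 106.4
Trained for 7M steps instead of 400K, DiT-XL/2 reaches FID 9.62 without guidance.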
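The patchify step can be sketched without any framework. The sketch makes two assumptions:

- the latent is stored channel-first as nested lists `[C][H][W]`;
- each patch is flattened in (row, column, channel) order.

That ordering is an illustrative choice; real implementations, such as a strided convolution, may lay values out differently before the linear projection to d, which is omitted here. `patchify` is a hypothetical name.

```python
def patchify(latent, p):
    """Split a [C][H][W] latent into (H/p)*(W/p) tokens of p*p*C values each.

    Tokens are emitted in raster order (left to right, top to bottom).
    """
    c, h, w = len(latent), len(latent[0]), len(latent[0][0])
    assert h % p == 0 and w % p == 0, "latent size must be divisible by p"
    tokens = []
    for i in range(0, h, p):
        for j in range(0, w, p):
            tokens.append([latent[ch][i + di][j + dj]
                           for di in range(p) for dj in range(p) for ch in range(c)])
    return tokens


if __name__ == "__main__":
    # A 4 x 32 x 32 latent filled with its flat index so values stay distinguishable.
    lat = [[[ch * 1024 + y * 32 + x for x in range(32)] for y in range(32)] for ch in range(4)]
    for p in (2, 4, 8):
        toks = patchify(lat, p)
        print(f"p={p}: T={len(toks)} tokens, {len(toks[0])} values per token")
```

The token counts match T = (I/p)². Note that the number of values per token grows as p², so a larger patch trades fewer tokens for a wider projection.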
Interactive: Patchify a Latent

How the patch size p changes the token count and the granularity of patch boundaries.

Chapter 3: The DiT Block

The core building block of DiT is a modified transformer layer. It looks almost identical to a standard ViT block — multi-head self-attention followed by a feed-forward network, with layer normalization — except for one critical change: how conditioning information enters.

In a standard transformer, the layer norm has fixed, learned scale (γ) and shift (β) parameters. In DiT, these parameters are computed dynamically from the conditioning vector (timestep + class label) at every forward pass. This is called adaptive layer normalization (adaLN).

The Block Equations

Given input x (the token sequence) and conditioning vector c:

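$$ h = x + \alpha_1 \cdot \mathrm{MHSA}\big(\mathrm{adaLN}(\gamma_1, \beta_1, x)\big) $$
$$ \mathrm{out} = h + \alpha_2 \cdot \mathrm{FFN}\big(\mathrm{adaLN}(\gamma_2, \beta_2, h)\big) $$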

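Where adaLN(γ, β, x) = γ · LayerNorm(x) + β. The six modulation parameters γ₁, β₁, α₁, γ₂, β₂, α₂ are all regressed from the conditioning vector c. Each block has its own small regression layer; in the public DiT code this is a SiLU followed by a linear layer from d to 6·d dimensions. Each of the six is therefore a d-dimensional vector, applied channel-wise to every token. The code also parameterizes the scale as γ = 1 + (predicted scale).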
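Here is a framework-free sketch of one adaLN-Zero block, following the structure of the public DiT code as we understand it:

- a SiLU followed by a per-block linear layer maps c to six d-dimensional vectors;
- those six are a shift, a scale and a gate for each branch;
- γ = 1 + scale.

Attention and the FFN are passed in as arbitrary callables, so the sketch does not implement MHSA itself. All helper names are hypothetical.

```python
import math


def layer_norm(v, eps=1e-6):
    """LayerNorm without learned affine parameters (adaLN supplies them)."""
    mean = sum(v) / len(v)
    var = sum((x - mean) ** 2 for x in v) / len(v)
    return [(x - mean) / math.sqrt(var + eps) for x in v]


def silu(x):
    return x / (1.0 + math.exp(-x))


def linear(weight, bias, v):
    """weight: [out][in], bias: [out], v: [in] -> [out]."""
    return [sum(w * x for w, x in zip(row, v)) + b for row, b in zip(weight, bias)]


def zero_linear(d_in, d_out):
    """adaLN-Zero: the modulation layer starts with all-zero weights and bias."""
    return [[0.0] * d_in for _ in range(d_out)], [0.0] * d_out


def dit_block(tokens, c, mod, attn, ffn):
    """One adaLN-Zero block.

    tokens: list of T vectors of length d; c: conditioning vector of length d.
    mod: (weight, bias) of this block's modulation layer, d -> 6d.
    attn: callable mapping a list of T vectors to a list of T vectors (e.g. MHSA).
    ffn: callable mapping one d-vector to one d-vector.
    """
    d = len(c)
    m = linear(*mod, [silu(x) for x in c])  # 6d numbers, computed once per sample
    shift1, scale1, gate1, shift2, scale2, gate2 = (m[k * d:(k + 1) * d] for k in range(6))

    def modulate(v, shift, scale):
        # gamma = 1 + scale, beta = shift: a zero output gives a plain LayerNorm
        return [x * (1.0 + s) + b for x, s, b in zip(layer_norm(v), scale, shift)]

    a = attn([modulate(x, shift1, scale1) for x in tokens])
    h = [[x + g * y for x, g, y in zip(xv, gate1, yv)] for xv, yv in zip(tokens, a)]
    out = []
    for hv in h:
        f = ffn(modulate(hv, shift2, scale2))
        out.append([x + g * y for x, g, y in zip(hv, gate2, f)])
    return out
```

With the modulation layer zero-initialized, the block returns its input unchanged, whatever the sublayers compute. This is the identity-at-initialization property of adaLN-Zero.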

The α gates matter enormously. These scaling factors sit right before each residual connection. At initialization, they are set to zero. This means each DiT block starts as the identity function: output = x + 0·MHSA(x) = x. The network starts from a stable baseline and gradually "turns on" each block. We'll dig into why this is so powerful in Chapter 4.

What's Removed vs Standard ViT

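DiT is intentionally minimal. Compared to a standard ViT block, it removes nothing structural; it only makes the layer norm adaptive and adds the α gates. What is missing is everything U-Net-specific. There are no convolutions inside the blocks, no downsampling or upsampling stages, and no long-range skip connections between encoder and decoder levels.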

All 28 blocks (in DiT-XL/2) process the same 256-token sequence at the same 1152-dimensional hidden size. This same-resolution, same-width design is what gives DiT its clean scaling behavior.

DiT Block Internals

The data flow through a single DiT block. γ, β, α are all computed from the conditioning vector — not learned as fixed parameters.


Chapter 4: adaLN-Zero Deep Dive

The paper ablates four different ways to inject conditioning into transformer blocks. The winner — adaLN-Zero — achieves roughly half the FID of the simplest approach while adding negligible compute. Understanding why it wins requires understanding the zero initialization trick.

The Four Conditioning Strategies

1. In-Context: Append the timestep and class embeddings as two extra tokens. Process them alongside image tokens. Overhead: ~0%. Simple, but the conditioning has limited influence — it's just two tokens in a sea of 256.

2. Cross-Attention: Add a cross-attention layer after each self-attention layer, in which image tokens attend to the conditioning tokens. Overhead: ~15% extra Gflops. This mirrors the encoder-decoder attention of the original Transformer and the cross-attention that latent diffusion models such as Stable Diffusion use for conditioning.

3. adaLN: Replace fixed layer norm parameters with ones computed from the conditioning vector. Overhead: ~0% (the MLP is tiny). Better than in-context and cross-attention, because conditioning influences every single token's normalization.

4. adaLN-Zero: Same as adaLN, plus: regress additional scaling gates α₁, α₂ that multiply the residual branches. Initialize these gates to zero. Overhead: ~0%. Best by a large margin.

Why Zero Initialization?

Think about what happens at the start of training. Random weights produce random outputs. If the MHSA output is random noise, adding it back via the residual connection corrupts x. This makes the first few training steps chaotic — gradients are large and unstable.

With α initialized to zero: output = x + 0 × MHSA(x) = x. The block is exactly the identity function. Gradients flow cleanly through the residual path from the very first step. The network has a stable foundation to learn from, and α grows gradually as training teaches the block what to contribute.

This is borrowed from ResNets. Goyal et al. (2017) found that zero-initializing the final BatchNorm in each residual block accelerates large-batch training. U-Net diffusion models use a similar trick: zero-initialize the final convolution in each ResNet block. DiT brings this same insight to transformers. Small initialization trick, large training stability gain.

Worked Numerical Example

Let's trace an example token x = [1.2, -0.3, 0.8] through a DiT block at initialization:

Step | Operation | Value
1 | LayerNorm(x) | [1.00, -1.37, 0.37] (zero mean, unit variance)
2 | γ₁ · LN(x) + β₁ (γ = 1, β = 0 at init) | [1.00, -1.37, 0.37]
3 | MHSA output (random at init; illustrative values) | [0.47, -1.23, 0.91]
4 | α₁ = 0 (zero init) → 0 × MHSA | [0, 0, 0]
5 | Residual: x + 0 = x | [1.2, -0.3, 0.8] (unchanged!)
6 | Same for FFN branch with α₂ = 0 | [1.2, -0.3, 0.8] (still unchanged)

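After a few training steps, the modulation layer's weights move away from zero. α₁ and α₂ then take small non-zero values (say, around 0.01), and the block begins to learn. As training continues, the gates grow and the blocks contribute meaningfully. Because the gates are computed from c, their size varies with the timestep, the class and the block depth.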
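The first rows of the worked example can be reproduced in a few lines. This minimal sketch assumes LayerNorm without learned affine parameters and reuses the illustrative MHSA output from the table. The variable names are hypothetical.

```python
import math


def layer_norm(v, eps=1e-6):
    """LayerNorm over one token, without learned scale/shift."""
    mean = sum(v) / len(v)
    var = sum((x - mean) ** 2 for x in v) / len(v)  # biased variance, as LayerNorm uses
    return [(x - mean) / math.sqrt(var + eps) for x in v]


x = [1.2, -0.3, 0.8]
ln = layer_norm(x)                  # steps 1-2: gamma = 1, beta = 0 leave this unchanged
mhsa_out = [0.47, -1.23, 0.91]      # step 3: illustrative branch output from the table
alpha1 = 0.0                        # step 4: zero-initialized gate
h = [xi + alpha1 * yi for xi, yi in zip(x, mhsa_out)]  # step 5: residual

print([round(v, 2) for v in ln])    # [1.0, -1.37, 0.37]
print(h)                            # [1.2, -0.3, 0.8]
```

The normalized values sum to zero, which is a quick sanity check on any hand-computed LayerNorm.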

The conditioning comparison is run with DiT-XL/2 for 400K training steps on ImageNet 256×256. Architecture, training budget and dataset are the same throughout; only the conditioning mechanism changes. adaLN-Zero reaches FID ≈ 19.5, beating cross-attention, plain adaLN and in-context conditioning. Its FID is nearly half that of the in-context model.
Conditioning Strategy Comparison

FID-50K at 400K training steps. Lower is better. All models are DiT-XL/2.


Chapter 5: Conditioning

Every DiT block needs to know two things: how noisy is this latent? (the timestep t) and what class should this image be? (the class label c). These aren't fed as pixel values — they go through a carefully designed embedding pipeline before influencing each block's normalization parameters.

Timestep Embedding

The timestep t is a single integer (say, 523 out of 1000). To make it useful for a neural network, DiT embeds it using sinusoidal frequency embeddings — the same technique from the original Transformer paper. Different frequencies capture different timescales:

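$$ \mathrm{emb}(t)_{2k} = \sin\!\left(\frac{t}{10000^{2k/D}}\right), \qquad \mathrm{emb}(t)_{2k+1} = \cos\!\left(\frac{t}{10000^{2k/D}}\right), \qquad D = 256 $$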

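This 256-dimensional sinusoidal embedding is then passed through a 2-layer MLP (with SiLU activation) to produce a d-dimensional vector: embed_t: int → ℝᵈ. The public DiT code concatenates the cosine and sine halves rather than interleaving them; the information is the same.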
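Here is a sketch of the fixed sinusoidal part. It assumes the concatenated layout (cosine half first) used by the public DiT code and an even embedding width. The learned 2-layer MLP that follows is omitted, and `timestep_embedding` is a hypothetical name.

```python
import math


def timestep_embedding(t: float, dim: int = 256, max_period: float = 10000.0):
    """Sinusoidal embedding of a scalar timestep.

    Frequencies are 1 / max_period**(k / half) for k = 0..half-1, i.e.
    10000^(-2k/D) with D = dim. The cosine half comes first, then the sine half.
    """
    half = dim // 2
    freqs = [math.exp(-math.log(max_period) * k / half) for k in range(half)]
    args = [t * f for f in freqs]
    return [math.cos(a) for a in args] + [math.sin(a) for a in args]


if __name__ == "__main__":
    e = timestep_embedding(523)
    print(len(e), [round(v, 3) for v in e[:3]], [round(v, 3) for v in e[128:131]])
```

Low-index entries oscillate quickly with t, and high-index entries change slowly. This spread of frequencies lets the MLP distinguish both nearby and distant timesteps.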

Class Label Embedding

The class label c (an integer 0–999 for ImageNet) is embedded via a standard learned embedding table: a matrix of shape [1000, d]. Each class gets its own d-dimensional vector, learned from scratch during training: embed_c: int → ℝᵈ.

For unconditional generation (needed for classifier-free guidance), there's one extra learned embedding: the null embedding ∅, used when the class label is dropped during training (10% of the time).

Combining Timestep and Class

The two embeddings are simply added:

c = embed_t(t) + embed_c(label)    [shape: d]

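This combined conditioning vector c (shape [d]) is shared across all 28 DiT blocks. Each block has its own linear projection that maps c to 6·d numbers. These are split into six d-dimensional vectors: γ₁, β₁, α₁, γ₂, β₂, α₂. Each is applied channel-wise, so every feature channel gets its own scale, shift and gate. The per-block layer is cheap to run because it acts once on c, not on every token. It is all the architectural machinery needed to make conditioning work.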

Why addition, not concatenation? Concatenating would double the conditioning vector size, requiring larger projection layers. Addition works well in practice: the per-block modulation layers see the summed vector and can learn to draw on both the timestep and the class information it carries. The simplicity here is intentional; every DiT design decision leans toward "as vanilla as possible."
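Here is a sketch of how label dropout and the sum fit together. It assumes the null label is stored as an extra row (index 1000) of the class table, which is how the public DiT code stores it. The timestep embedding is taken as already computed. `drop_labels` and `conditioning_vector` are hypothetical names.

```python
import random

NUM_CLASSES = 1000
NULL_CLASS = NUM_CLASSES  # extra table row reserved for the "null" label


def drop_labels(labels, drop_prob=0.1, rng=random):
    """During training, replace each label by the null label with probability drop_prob."""
    return [NULL_CLASS if rng.random() < drop_prob else y for y in labels]


def conditioning_vector(t_emb, class_table, label):
    """c = embed_t(t) + embed_c(label): element-wise sum of two d-dimensional vectors."""
    return [a + b for a, b in zip(t_emb, class_table[label])]


if __name__ == "__main__":
    rng = random.Random(0)
    d = 4
    table = [[rng.gauss(0.0, 0.02) for _ in range(d)] for _ in range(NUM_CLASSES + 1)]
    t_emb = [0.1, -0.2, 0.3, 0.0]        # pretend output of the timestep MLP
    labels = drop_labels([207, 360, 387, 974], rng=rng)
    print([conditioning_vector(t_emb, table, y) for y in labels])
```

At sampling time, the same null row supplies the unconditional prediction that classifier-free guidance needs.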

Parameter Count for Conditioning

Component | Shape | Parameters
Sinusoidal embed (fixed) | 256 | 0 (no learning)
Timestep MLP (2-layer) | 256→d→d | ~1.6M
Class embed table (incl. null label) | 1001×d | ~1.2M
Per-block modulation (×28) | d→6d | ~223M total
Final-layer modulation | d→2d | ~2.7M
Total conditioning | | ~229M / 675M total

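Conditioning is cheap in compute but not in parameters. About a third of DiT-XL/2's parameters sit in the per-block modulation layers. Because those layers act once on the single vector c rather than on each of the 256 tokens, they cost only about 0.2 Gflops of the 118.6 total. The attention and FFN layers still do essentially all of the per-token heavy lifting.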
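These parameter counts can be checked by recomputing them from the layer shapes. This sketch assumes the shapes of the public DiT code:

- a 256→d→d timestep MLP;
- a [1001, d] class table;
- a d→6d modulation layer in each block;
- a d→2d modulation layer in the final layer.

All layers include biases. The function names are hypothetical.

```python
def linear_params(d_in: int, d_out: int) -> int:
    """Weights plus biases of a dense layer."""
    return d_in * d_out + d_out


def dit_conditioning_params(d=1152, n_blocks=28, freq_dim=256, num_classes=1000):
    """Parameter counts of DiT's conditioning-related layers."""
    return {
        "timestep_mlp": linear_params(freq_dim, d) + linear_params(d, d),
        "class_table": (num_classes + 1) * d,                 # +1 row for the null label
        "per_block_modulation": n_blocks * linear_params(d, 6 * d),  # shift/scale/gate x 2
        "final_modulation": linear_params(d, 2 * d),          # shift/scale before output
    }


if __name__ == "__main__":
    parts = dit_conditioning_params()
    total = sum(parts.values())
    for name, n in parts.items():
        print(f"{name:22s} {n / 1e6:7.2f}M")
    print(f"total                  {total / 1e6:7.2f}M  (~{total / 675e6:.0%} of 675M)")
```

The per-block modulation layers dominate, at almost 8M parameters each for d = 1152.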

Conditioning Flow: t + c → adaLN Parameters

Trace the timestep and class label from raw integers to the γ, β, α parameters that modulate each DiT block.


Chapter 6: The Full Pipeline

Now we can trace a complete forward pass through DiT — from a noisy latent zt all the way to a predicted noise map. Let's walk through every operation with exact tensor shapes for DiT-XL/2 (patch size p=2, hidden dim d=1152, 28 blocks).

1. Input: noised latent z_t, shape [B, 4, 32, 32] (4 channels, 32×32 spatial).
2. Patchify + linear projection: [B, 4, 32, 32] → [B, 256, 1152], i.e. 256 tokens of dim 1152.
3. Add positional embedding: [B, 256, 1152] + fixed sine-cosine embedding [256, 1152] → [B, 256, 1152].
4. Conditioning, embed_t(t) + embed_c(c): [B] int + [B] int → [B, 1152] conditioning vector.
5. 28 × DiT block (adaLN-Zero): [B, 256, 1152] → [B, 256, 1152], with self-attention over all 256 tokens in every block.
6. Final LayerNorm + adaLN: [B, 256, 1152], with one more shift and scale computed from c (no gate).
7. Linear decode: [B, 256, 1152] → [B, 256, 32], where 32 = p×p×2C = 2×2×8.
8. Unpatchify: [B, 256, 32] → [B, 8, 32, 32], reshaping tokens back to the spatial layout.
9. Split output: [B, 8, 32, 32] → ε̂ [B, 4, 32, 32] + Σ̂ [B, 4, 32, 32].

Why output 8 channels? Standard DDPM predicts only the noise ε (4 channels here). DiT also predicts the diagonal covariance Σ (4 more channels). Nichol & Dhariwal (2021) showed that learning the variance improves log-likelihood and allows sampling with far fewer steps at little cost in quality. The variance is parameterized as an interpolation, in log space, between the DDPM upper and lower bounds β_t and β̃_t. It costs almost nothing: the final projection grows from 16 to 32 outputs per token, about 18K extra parameters out of 675M.
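Here is a sketch of that interpolation. It assumes the "learned range" parameterization of Nichol & Dhariwal (2021): each variance channel holds a value v in [-1, 1], and v is mapped to a log-variance between log β̃_t and log β_t. The linear β schedule and the names `linear_betas` and `learned_log_variance` are illustrative assumptions.

```python
import math


def linear_betas(num_steps=1000, beta_start=1e-4, beta_end=0.02):
    return [beta_start + (beta_end - beta_start) * i / (num_steps - 1) for i in range(num_steps)]


def learned_log_variance(v, t, betas):
    """Log-variance for one element, interpolated between beta_tilde_t and beta_t.

    v is the raw network output for the variance channel, expected in [-1, 1]:
    v = -1 gives the posterior variance beta_tilde_t (lower bound) and
    v = +1 gives beta_t (upper bound). Requires t >= 1.
    """
    assert t >= 1, "beta_tilde_0 is zero, so the log-variance is undefined at t = 0"
    abar_prev, abar = 1.0, 1.0
    for b in betas[: t + 1]:
        abar_prev, abar = abar, abar * (1.0 - b)
    beta_tilde = betas[t] * (1.0 - abar_prev) / (1.0 - abar)
    frac = (v + 1.0) / 2.0
    return frac * math.log(betas[t]) + (1.0 - frac) * math.log(beta_tilde)


if __name__ == "__main__":
    betas = linear_betas()
    for v in (-1.0, 0.0, 1.0):
        print(v, math.exp(learned_log_variance(v, 500, betas)))
```

Interpolating in log space keeps the variance positive and between the two bounds for any v in range.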

The Decode Step in Detail

The final linear layer maps each 1152-dim token to a 32-dim vector. Why 32? Each token corresponds to a 2×2 patch with 4 latent channels, and the output predicts both noise and variance for that patch. That is 2×2×4 = 16 values for ε plus 2×2×4 = 16 values for Σ, i.e. 2×2×8 = 32 values per token.

Unpatchifying is the exact reverse of patchifying. Rearrange the [256, 32] output back into an 8-channel 32×32 map ([B, 8, 32, 32] in channel-first layout), then split the channels into ε̂ and Σ̂, 4 channels each.
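Here is a framework-free sketch of unpatchify plus the channel split. It assumes that each token stores its 2×2 patch in (row, column, channel) order and that tokens arrive in raster order; both are illustrative choices. `unpatchify` and `split_eps_sigma` are hypothetical names.

```python
def unpatchify(tokens, p, c, grid):
    """Inverse of patchify: tokens [grid*grid][p*p*c] -> latent [c][grid*p][grid*p]."""
    size = grid * p
    out = [[[0.0] * size for _ in range(size)] for _ in range(c)]
    for n, tok in enumerate(tokens):
        gi, gj = divmod(n, grid)          # position of this patch in the grid
        k = 0
        for di in range(p):
            for dj in range(p):
                for ch in range(c):
                    out[ch][gi * p + di][gj * p + dj] = tok[k]
                    k += 1
    return out


def split_eps_sigma(latent):
    """Split an 8-channel output into noise (channels 0-3) and variance (channels 4-7)."""
    half = len(latent) // 2
    return latent[:half], latent[half:]


if __name__ == "__main__":
    tokens = [[0.0] * 32 for _ in range(256)]   # [256, 32] decoder output for one image
    eps_hat, sigma_hat = split_eps_sigma(unpatchify(tokens, p=2, c=8, grid=16))
    print(len(eps_hat), len(sigma_hat), len(eps_hat[0]), len(eps_hat[0][0]))
```

Whatever layout is chosen, patchify and unpatchify must agree on it; otherwise predicted values land in the wrong pixels.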

python — complete DiT forward pass (simplified)
def forward(self, z_t, t, c):
    # z_t: [B, 4, 32, 32]  t: [B]  c: [B] (class labels)
    B = z_t.shape[0]

    # 1. Patchify: [B, 4, 32, 32] → [B, 256, 1152]
    x = self.patch_embed(z_t)          # linear: (p*p*4) → d
    x = x + self.pos_embed             # add fixed sine-cosine embed

    # 2. Conditioning vector: [B, d]
    cond = self.t_embed(t) + self.c_embed(c)

    # 3. 28 DiT blocks
    for block in self.blocks:
        x = block(x, cond)             # adaLN-Zero inside

    # 4. Final norm + linear decode
    x = self.final_norm(x, cond)       # one more adaLN
    x = self.final_linear(x)           # [B, 256, 32]

    # 5. Unpatchify: [B, 256, 32] → [B, 8, 32, 32]
    x = self.unpatchify(x)

    # 6. Split noise + variance
    eps_hat, sigma_hat = x.chunk(2, dim=1)
    return eps_hat, sigma_hat          # each [B, 4, 32, 32]
Full DiT Pipeline with Tensor Shapes

Every stage of the forward pass, with exact shapes for DiT-XL/2 (p=2, d=1152, N=28).


Chapter 7: Scaling Laws

The most important result in the DiT paper isn't the final FID score. It's what happens when you plot FID against Gflops across 12 different model variants.

Peebles and Xie train every combination of 4 model sizes (S, B, L, XL) × 3 patch sizes (2, 4, 8), all for 400K training steps on ImageNet 256×256. Then they ask: what predicts quality?

Finding 1: Gflops → FID, with correlation -0.93

Across all 12 models, the correlation between model Gflops and FID-50K is -0.93. On a log-log plot, the 12 points fall close to a straight line. This is remarkable. Within this model family and training setup, a DiT model's image quality can be predicted largely from its compute cost per forward pass.

This is the transformer's greatest strength: predictable scaling. No comparably clean relationship had been established for diffusion U-Nets. For transformers, it is close to a law.

Finding 2: Gflops > Parameters

Here's a subtle but crucial point. When you decrease patch size from p=4 to p=2, you quadruple T (tokens) and roughly quadruple the Gflops. But parameter count barely changes. The transformer weights are identical; only the small patch-embedding and output layers change size. You're just processing more tokens per forward pass.

Yet FID improves substantially. This means compute (Gflops) per image, not total parameter count, is the better predictor of quality. DiT-S/2 (33M params, ~6.1 Gflops) reaches FID ≈ 68 at 400K steps. That is clearly better than DiT-L/8 (458M params, ~5.0 Gflops), despite DiT-S/2 having about 14× fewer parameters. At similar compute per forward pass, the model with smaller patches wins.

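A loose parallel with Chinchilla. The Chinchilla paper (Hoffmann et al., 2022) showed that, for a fixed training budget, parameters and training tokens should be scaled roughly equally. Many large language models were therefore undertrained, and a smaller model trained on more data could beat a larger one. For DiT, the analogous lesson is that parameter count alone is a poor guide. How much compute the model spends per image, which the patch size controls, matters as much as how many weights it has.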

Finding 3: Larger Models are More Efficient

When you plot FID against total training compute (roughly model Gflops × batch size × training steps), larger models reach a given FID with less total compute. Past a point, a large model trained for fewer steps beats a small model trained longer. This matters in practice: with a fixed compute budget, it is usually better to train a larger model for fewer steps than a small model for longer.

Model Configurations

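Model | Layers | Hidden d | Heads | Params | Gflops (p=2) | FID (400K steps)
DiT-S/2 | 12 | 384 | 6 | 33M | 6.1 | ~68
DiT-B/2 | 12 | 768 | 12 | 130M | 23.0 | ~43
DiT-L/2 | 24 | 1024 | 16 | 458M | 80.7 | ~23
DiT-XL/2 | 28 | 1152 | 16 | 675M | 118.6 | ~19.5 (2.27* after 7M steps, with CFG)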

*DiT-XL/2 achieves FID 9.62 without CFG, FID 2.27 with CFG (guidance scale 1.50) after 7M training steps.

DiT Scaling: Gflops vs FID (log-log)

All 12 model variants at 400K steps. Each point is a (size, patch) combination. The trend is nearly linear on this log-log plot — correlation -0.93.


Chapter 8: Showcase — The DiT Pipeline in Motion

Here's the payoff. Watch DiT denoise a latent from pure noise to structured signal, step by step. Choose a model size, a class label, and a guidance scale — then step through the denoising process and see how the image emerges.

The visualization shows the 16×16 patch grid (for p=2, a 32×32 latent becomes 256 patches arranged in a 16×16 grid). Color encodes signal-to-noise ratio: cool blue = mostly noise, warm orange = mostly signal. As t decreases from 250→0, patches progressively "lock in" their values.
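The signal-to-noise ratio that the visualization colour-codes can be computed directly. This sketch assumes DiT's default linear β schedule over 1000 training steps; the 250 sampling steps are a respaced subset of these. It computes SNR(t) = ᾱ_t / (1 − ᾱ_t). It also shows the standard classifier-free guidance combination, ε̂ = ε̂(z_t, ∅) + s·(ε̂(z_t, c) − ε̂(z_t, ∅)), where s is the guidance scale. `snr_schedule` and `cfg_combine` are hypothetical names.

```python
def snr_schedule(num_steps=1000, beta_start=1e-4, beta_end=0.02):
    """SNR(t) = abar_t / (1 - abar_t) for a linear beta schedule."""
    out, prod = [], 1.0
    for i in range(num_steps):
        prod *= 1.0 - (beta_start + (beta_end - beta_start) * i / (num_steps - 1))
        out.append(prod / (1.0 - prod))
    return out


def cfg_combine(eps_uncond, eps_cond, scale):
    """Classifier-free guidance: eps_uncond + s * (eps_cond - eps_uncond)."""
    return [u + scale * (c - u) for u, c in zip(eps_uncond, eps_cond)]


if __name__ == "__main__":
    snr = snr_schedule()
    for t in (999, 750, 500, 250, 0):
        print(f"t={t:3d}  SNR={snr[t]:.4g}")
```

With s = 1, guidance reduces to the plain conditional prediction. Larger s pushes the prediction further from the unconditional one.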

Interactive DiT Denoising Simulator

Step through the reverse diffusion process. Watch signal emerge from noise as t decreases.


Chapter 9: Connections

DiT (2023) didn't just introduce a new architecture. It changed the entire trajectory of visual generation. Here's how the ideas evolved — and where they landed.

The DiT Family Tree

DDPM (2020)
Ho et al. — U-Net denoiser, 1000-step DDPM, noise prediction. The foundation DiT replaces.
Latent Diffusion / SD (2022)
Rombach et al. — VAE compression, diffusion in latent space. DiT keeps this framework, swaps U-Net for transformer.
DiT (2023)
Peebles & Xie — Replace U-Net with ViT. adaLN-Zero conditioning. FID 2.27 on ImageNet 256×256. Clean scaling laws.
SD3 / MMDiT (2024)
Esser et al. — Separate transformer streams for text and image tokens that interact via bi-directional attention. Text-to-image at scale.
FLUX (2024)
Black Forest Labs — DiT backbone + flow matching (replaces DDPM). Fewer steps needed. State-of-the-art text-to-image.
Sora (2024)
OpenAI — DiT applied to video. Space-time patches (3D patchification). Variable-length sequences. Validates DiT at massive scale.

What Each Successor Changed

Model | Kept from DiT | Changed | Impact
SD3/MMDiT | adaLN-Zero-style modulation, patchification, latent space | Separate weights for text and image tokens, joined in a single attention operation over both; rectified-flow training objective | Better text rendering, stronger prompt following
FLUX | MMDiT-style blocks, latent space | Flow matching objective (replaces DDPM), distillation for fast sampling | 1–4 step sampling in the distilled FLUX.1 [schnell] model
Sora | DiT blocks, patchification | 3D spatiotemporal patches, variable resolution & duration, massive scale | Coherent multi-second video generation
SiT | DiT blocks | Stochastic interpolants instead of DDPM | Better theoretical understanding of the diffusion objective

Key Insight: What DiT Proved Was Unnecessary

The U-Net's core inductive bias was that you need to process images at multiple spatial scales simultaneously, with skip connections to preserve detail. DiT showed that, at least for latent-space ImageNet generation, this scaffolding is not necessary. Self-attention over a flat patch sequence is sufficient; it can learn whatever multi-scale reasoning it needs from data.

This matters because it means the transformer's core strength, attending to arbitrary pairs of tokens, is enough. The U-Net's hand-designed structure turned out to be something the model could learn from data rather than something it had to be given.

DiT's lasting legacy: most major image and video generation systems released after 2023, including SD3, FLUX and Sora, use a transformer backbone. The paper's true contribution wasn't a better FID number. It was bringing diffusion models into the transformer scaling playbook that has driven progress in every other domain. The next advances in image generation are likely to build on DiT-family architectures scaled to more compute.


"The transformer didn't just replace the U-Net — it opened diffusion models to the same scaling playbook that transformed language modeling. We no longer have to ask 'will making it bigger help?' The answer is yes, quantifiably, predictably, always."
— Inspired by Peebles & Xie, 2023