The simple, scalable training algorithm behind Stable Diffusion 3, FLUX, and Meta's Movie Gen.
In the last chapter, we learned how to sample from a flow model: initialize with noise $X_0 \sim \mathcal{N}(0, I)$, simulate the ODE $\frac{dX_t}{dt} = u_t^\theta(X_t)$, and use the endpoint $X_1$ as our sample. But we skipped the most important question: how do we train the neural network $u_t^\theta$ so that $X_1$ actually looks like data?
The naive answer — "simulate the ODE during training and backpropagate through it" — is expensive and unstable. Each training step would require simulating an entire ODE trajectory (50+ neural network evaluations), then backpropagating through all those steps. This is prohibitively costly for large models.
Flow matching is the elegant alternative. It turns training into a simple regression problem, with no ODE simulation during training at all. The key insight: we can construct a target vector field that we know analytically and train the neural network to match it. The training loop is just: sample a data point, a time, and some noise; compute a mean-squared error; and update the weights. That is it.
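The roadmap for this chapter:
1. Define a probability path that interpolates between noise and data.
2. Derive the conditional and marginal vector fields that follow this path.
3. Replace the intractable flow matching loss with the tractable conditional flow matching loss.
4. Put everything together into a training algorithm, then sample with the trained model.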
The first step of flow matching is to define a probability path — a gradual interpolation between noise and data. Think of it this way: if you had a movie that showed an image forming from static, each frame would show the distribution at a different stage of "formation." A probability path formalizes this idea.
At time t = 0, we have pure noise $p_{\text{init}} = \mathcal{N}(0, I)$. At time t = 1, we want data $p_{\text{data}}$. A probability path $p_t$ smoothly transitions between the two:

$$p_0 = p_{\text{init}}, \qquad p_1 = p_{\text{data}}.$$
We start with a simpler object: a conditional probability path $p_t(x|z)$. For a single data point z, this interpolates between noise and z:

$$p_0(\cdot|z) = p_{\text{init}}, \qquad p_1(\cdot|z) = \delta_z.$$
Here δz is the Dirac delta at z — the simplest possible distribution that always returns z. So pt(x|z) starts as a Gaussian (noise) and ends concentrated at the single data point z.
The marginal probability path $p_t(x)$ is defined by the two-step sampling procedure:

$$z \sim p_{\text{data}}, \quad x \sim p_t(\cdot|z) \quad \Longrightarrow \quad x \sim p_t.$$
Formally, the marginal density is the integral:

$$p_t(x) = \int p_t(x|z)\, p_{\text{data}}(z)\, dz.$$
We can sample from pt (draw z from data, then draw x from the conditional path). But we generally cannot evaluate pt(x) because the integral is intractable. This distinction will be crucial later.
Worked example — sampling from the marginal path in 1D. Suppose pdata puts equal mass at z = −3 and z = +3, and we use a Gaussian conditional path with αt = t, βt = 1−t.
To sample from pt at t = 0.5:
1. Sample z ~ pdata: get z = 3 (with probability 0.5).
2. Sample ε ~ N(0,1): get ε = 0.4.
3. Compute x = 0.5 · 3 + 0.5 · 0.4 = 1.5 + 0.2 = 1.7.
If we repeat this many times, we get an equal mixture of two Gaussians with means $0.5 \cdot (\mp 3) = \mp 1.5$ and variance $0.5^2 = 0.25$, namely $\mathcal{N}(-1.5, 0.25)$ and $\mathcal{N}(1.5, 0.25)$. The two modes are emerging from the noise but are still blurred.
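The two-step procedure is easy to check numerically. The sketch below draws many samples at t = 0.5 for the ±3 example. It then compares their mean and variance with the mixture's values: mean 0 and variance 0.25 + 1.5² = 2.5, which is the within-mode variance plus the spread between the modes. `sample_marginal_1d` is a hypothetical helper, not from the text.

```python
import torch

def sample_marginal_1d(n, t, modes=(-3.0, 3.0), seed=0):
    """Draw n samples x ~ p_t with the two-step procedure:
    z ~ p_data (uniform over `modes`), eps ~ N(0, 1), x = t*z + (1-t)*eps (CondOT)."""
    g = torch.Generator().manual_seed(seed)
    modes = torch.tensor(modes)
    idx = torch.randint(len(modes), (n,), generator=g)  # pick a data point uniformly
    z = modes[idx]
    eps = torch.randn(n, generator=g)                   # Gaussian noise
    return t * z + (1 - t) * eps

x = sample_marginal_1d(400_000, t=0.5)
print(x.mean().item(), x.var().item())  # close to 0 and 2.5
```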
The most important probability path is the Gaussian probability path. It underlies Stable Diffusion 3, FLUX, and most other state-of-the-art models. Given a data point z, the conditional path is:

$$p_t(\cdot|z) = \mathcal{N}(\alpha_t z, \beta_t^2 I),$$
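where $\alpha_t$ and $\beta_t$ are noise schedulers: smooth, monotonic functions ($\alpha_t$ increasing, $\beta_t$ decreasing) satisfying:

$$\alpha_0 = \beta_1 = 0, \qquad \alpha_1 = \beta_0 = 1.$$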
Verification: At t = 0: $p_0(\cdot|z) = \mathcal{N}(0 \cdot z, 1^2 I) = \mathcal{N}(0, I) = p_{\text{init}}$. Check. At t = 1: $p_1(\cdot|z) = \mathcal{N}(1 \cdot z, 0^2 I) = \delta_z$. Check. (A Gaussian with zero variance is a Dirac delta.)
To sample from this path at time t:

$$\epsilon \sim \mathcal{N}(0, I), \qquad x = \alpha_t z + \beta_t \epsilon.$$
Why is this a Gaussian? The sample $x = \alpha_t z + \beta_t \epsilon$ is a deterministic function of the data z and the noise $\epsilon \sim \mathcal{N}(0, I)$. For fixed z, x is a shifted and scaled Gaussian: $x \sim \mathcal{N}(\alpha_t z, \beta_t^2 I)$. The mean $\alpha_t z$ is a scaled version of the data, and the variance $\beta_t^2$ decreases with time.
Visualization in 2D. At t = 0: pt(·|z) = N(0, I) is a unit Gaussian centered at the origin. At t = 0.5 (for CondOT): pt(·|z) = N(0.5z, 0.25I) — the Gaussian has shrunk (variance 0.25) and moved halfway toward z. At t = 0.9: pt(·|z) = N(0.9z, 0.01I) — a very tight Gaussian nearly on top of z. At t = 1: it collapses to δz.
```python
import torch

# Sampling from the Gaussian probability path
def sample_path(z, t, alpha_fn, beta_fn):
    """Sample x ~ p_t(·|z) for the Gaussian path."""
    alpha = alpha_fn(t)
    beta = beta_fn(t)
    eps = torch.randn_like(z)
    x = alpha * z + beta * eps
    return x, eps

# CondOT: alpha(t) = t, beta(t) = 1 - t
z = torch.tensor([3.0, -1.0])
x, eps = sample_path(z, t=0.6, alpha_fn=lambda t: t, beta_fn=lambda t: 1 - t)
# x = 0.6 * [3, -1] + 0.4 * eps
```
Worked example. Let z = 3.0 (a 1D data point), αt = t, βt = 1 − t, and ε = 0.7 (a noise sample):
| t | αt = t | βt = 1−t | x = αtz + βtε = t·3 + (1−t)·0.7 |
|---|---|---|---|
| 0.0 | 0.0 | 1.0 | 0 + 0.7 = 0.700 |
| 0.25 | 0.25 | 0.75 | 0.75 + 0.525 = 1.275 |
| 0.50 | 0.50 | 0.50 | 1.5 + 0.35 = 1.850 |
| 0.75 | 0.75 | 0.25 | 2.25 + 0.175 = 2.425 |
| 1.0 | 1.0 | 0.0 | 3.0 + 0 = 3.000 |
The sample smoothly transitions from ε = 0.7 (noise) to z = 3.0 (data).
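A probability path tells us what distributions we want at each time t. But we need a vector field whose ODE actually produces those distributions. Given a conditional probability path $p_t(\cdot|z) = \mathcal{N}(\alpha_t z, \beta_t^2 I)$, we need to find $u_t^{\text{target}}(x|z)$ such that:

$$X_0 \sim p_{\text{init}}, \quad \frac{dX_t}{dt} = u_t^{\text{target}}(X_t|z) \quad \Longrightarrow \quad X_t \sim p_t(\cdot|z) \ \text{for all } 0 \le t \le 1.$$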
This is one of the most important derivations in the course. Let's do it step by step, showing every algebraic move.
Step 1: Define the conditional flow. We want a flow $\psi_t(x|z)$ such that if $X_0 \sim \mathcal{N}(0, I)$, then $X_t = \psi_t(X_0|z) \sim \mathcal{N}(\alpha_t z, \beta_t^2 I)$. The simplest choice is an affine map:

$$\psi_t(x|z) = \alpha_t z + \beta_t x.$$
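Step 2: Verify it produces the right distributions. If $X_0 \sim \mathcal{N}(0, I)$, then $X_t = \psi_t(X_0|z) = \alpha_t z + \beta_t X_0$. Since $X_0$ is Gaussian and we are applying an affine transformation, $X_t$ is also Gaussian, with

$$\mathbb{E}[X_t] = \alpha_t z + \beta_t\, \mathbb{E}[X_0] = \alpha_t z, \qquad \operatorname{Cov}[X_t] = \beta_t^2 \operatorname{Cov}[X_0] = \beta_t^2 I.$$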
So $X_t \sim \mathcal{N}(\alpha_t z, \beta_t^2 I) = p_t(\cdot|z)$. Check. The flow produces the correct conditional probability path.
Step 3: Extract the vector field. By the definition of a flow, $\frac{d}{dt}\psi_t(x|z) = u_t(\psi_t(x|z)|z)$. We differentiate $\psi_t$ with respect to t:

$$\frac{d}{dt}\psi_t(x|z) = \dot\alpha_t z + \dot\beta_t x.$$
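Step 4: Re-express in terms of the output. The vector field u must be a function of $\psi_t$ (the current position), not x (the initial position). Since $\psi_t = \alpha_t z + \beta_t x$, we solve for $x = (\psi_t - \alpha_t z)/\beta_t$ and substitute:

$$u_t(\psi_t|z) = \dot\alpha_t z + \dot\beta_t\,\frac{\psi_t - \alpha_t z}{\beta_t} = \left(\dot\alpha_t - \frac{\dot\beta_t}{\beta_t}\alpha_t\right) z + \frac{\dot\beta_t}{\beta_t}\,\psi_t.$$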
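Renaming $\psi_t$ back to x (so the vector field is written as a function of the current position x), the conditional Gaussian vector field is:

$$u_t^{\text{target}}(x|z) = \left(\dot\alpha_t - \frac{\dot\beta_t}{\beta_t}\alpha_t\right) z + \frac{\dot\beta_t}{\beta_t}\, x.$$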
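The general formula can be checked against the definition of the flow by finite differences. The sketch below compares a central difference of $\psi_t(x_0|z)$ in t with $u_t^{\text{target}}(\psi_t(x_0|z)|z)$. The helper names `cond_velocity` and `flow` are illustrative. The cosine schedule $\alpha_t = \sin(\pi t/2)$, $\beta_t = \cos(\pi t/2)$ is used only as a nonlinear test case.

```python
import math
import torch

def cond_velocity(x, z, t, alpha, beta, dalpha, dbeta):
    """Conditional Gaussian field u_t(x|z) = (a' - (b'/b) a) z + (b'/b) x."""
    a, b, da, db = alpha(t), beta(t), dalpha(t), dbeta(t)
    return (da - db / b * a) * z + (db / b) * x

def flow(x0, z, t, alpha, beta):
    """Conditional flow psi_t(x0|z) = alpha_t z + beta_t x0."""
    return alpha(t) * z + beta(t) * x0

# Nonlinear test schedule (cosine) and its time derivatives.
alpha = lambda t: math.sin(math.pi * t / 2)
beta = lambda t: math.cos(math.pi * t / 2)
dalpha = lambda t: math.pi / 2 * math.cos(math.pi * t / 2)
dbeta = lambda t: -math.pi / 2 * math.sin(math.pi * t / 2)

z = torch.tensor([3.0, -1.0], dtype=torch.float64)
x0 = torch.tensor([0.5, 0.8], dtype=torch.float64)
t, h = 0.4, 1e-6

# Central finite difference of the trajectory psi_t(x0|z) in t ...
fd = (flow(x0, z, t + h, alpha, beta) - flow(x0, z, t - h, alpha, beta)) / (2 * h)
# ... should match the vector field evaluated at the current position psi_t(x0|z).
u = cond_velocity(flow(x0, z, t, alpha, beta), z, t, alpha, beta, dalpha, dbeta)
print(torch.max(torch.abs(fd - u)).item())  # tiny (finite-difference error only)
```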
Understanding each term. The vector field has two parts:
1. The z-dependent term (α̇t − αtβ̇t/βt) z — this pulls the particle toward the data point z. The coefficient depends on the time derivatives of the schedulers.
2. The x-dependent term (β̇t/βt) x — this adjusts the velocity based on the current position. For CondOT, β̇t/βt = −1/(1−t), which is always negative, meaning it pushes toward the origin.
Combined, the two terms produce a velocity that moves x along a trajectory from the initial noise toward the target data point z.
Step 5: Specialize to CondOT (αt = t, βt = 1−t). Let's compute each derivative:
α̇t = d(t)/dt = 1. β̇t = d(1−t)/dt = −1.
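Now substitute into the general formula. The z coefficient is $\dot\alpha_t - \frac{\dot\beta_t}{\beta_t}\alpha_t = 1 + \frac{t}{1-t} = \frac{1}{1-t}$ and the x coefficient is $\frac{\dot\beta_t}{\beta_t} = -\frac{1}{1-t}$, so

$$u_t^{\text{target}}(x|z) = \frac{z - x}{1 - t}.$$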
Step 6: The elegant shortcut. There is a much simpler way to see the answer. Since $x = \alpha_t z + \beta_t\epsilon$, the trajectory of a specific sample is $x(t) = tz + (1-t)\epsilon$ for CondOT. Its time derivative is:

$$\frac{dx}{dt} = z - \epsilon.$$
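But $dx/dt$ is the velocity of the trajectory, which is exactly the conditional vector field evaluated at $x(t)$. Therefore:

$$u_t^{\text{target}}(x|z) = z - \epsilon, \qquad \text{where } \epsilon = \frac{x - tz}{1 - t}.$$

Substituting this ε gives $(z - x)/(1 - t)$, the same result as plugging the CondOT schedulers into the general formula.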
Numerical example. z = 3.0, ε = 0.7, so the target velocity is z − ε = 3.0 − 0.7 = 2.3 at every time t. The velocity is constant along the straight line from ε to z.
Verification by trajectory. The trajectory under this velocity is x(t) = x(0) + t · (z − ε). Since x(0) = ε = 0.7:
| t | x(t) = 0.7 + t · 2.3 | Should equal tz + (1−t)ε | Match? |
|---|---|---|---|
| 0.0 | 0.7 + 0 = 0.700 | 0 + 0.7 = 0.700 | Yes |
| 0.25 | 0.7 + 0.575 = 1.275 | 0.75 + 0.525 = 1.275 | Yes |
| 0.50 | 0.7 + 1.15 = 1.850 | 1.5 + 0.35 = 1.850 | Yes |
| 1.0 | 0.7 + 2.3 = 3.000 | 3.0 + 0 = 3.000 | Yes |
The constant velocity z − ε traces a straight line from ε to z. This is the optimal transport solution — the shortest path between the two points.
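The same check can be run by integrating the x-dependent form $u_t^{\text{target}}(x|z) = (z - x)/(1 - t)$ (the general formula with $\alpha_t = t$, $\beta_t = 1 - t$) with Euler steps. The field equals z − ε everywhere on the line, so each Euler step lands exactly on the straight path and reproduces the table above. This is a minimal plain-Python sketch, and the helper name `condot_velocity` is illustrative.

```python
def condot_velocity(x, z, t):
    """CondOT conditional field written in terms of the current position x: (z - x) / (1 - t)."""
    return (z - x) / (1 - t)

z, eps, n_steps = 3.0, 0.7, 4
h = 1.0 / n_steps
x = eps                      # start at the noise sample, x(0) = eps
traj = [x]
for k in range(n_steps):
    t = k * h                # t = 0, 0.25, 0.5, 0.75 (the field is never evaluated at t = 1)
    x = x + h * condot_velocity(x, z, t)
    traj.append(x)
print([round(v, 3) for v in traj])  # → [0.7, 1.275, 1.85, 2.425, 3.0]
```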
2D numerical example. z = (3.0, −1.0), ε = (0.5, 0.8). Target velocity: z − ε = (2.5, −1.8). At t = 0.4: x = 0.4(3.0, −1.0) + 0.6(0.5, 0.8) = (1.2, −0.4) + (0.3, 0.48) = (1.5, 0.08). The trajectory is a straight line in 2D from (0.5, 0.8) to (3.0, −1.0).
```python
import torch

# Computing the conditional vector field for CondOT
def condot_target(z, eps):
    """Target velocity for the CondOT path."""
    return z - eps  # That's it. The target is z - eps.

# Example
z = torch.tensor([3.0, -1.0])
eps = torch.tensor([0.5, 0.8])
target = condot_target(z, eps)  # tensor([2.5, -1.8])
```
The conditional vector field $u_t^{\text{target}}(x|z)$ moves everything toward a single data point z. That is useless by itself: we do not want every sample to collapse to the same z. We need a vector field that generates all of $p_{\text{data}}$.
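The marginal vector field is the average of conditional vector fields, weighted by how likely each data point z is to have produced the current noisy sample x:

$$u_t^{\text{target}}(x) = \int u_t^{\text{target}}(x|z)\, \frac{p_t(x|z)\, p_{\text{data}}(z)}{p_t(x)}\, dz.$$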
The weight $p_t(x|z)\,p_{\text{data}}(z)/p_t(x)$ is a posterior: given the noisy observation x at time t, how likely is it that the clean data was z? This is Bayes' rule in action.
Worked example — marginal field with 2 data points. Suppose pdata puts equal mass on z1 = −3 and z2 = +3 (1D). At t = 0.5, consider a noisy observation x = 1.0. Which data point produced it?
Under the conditional path (CondOT), x = tz + (1−t)ε, so at t = 0.5:
If z = z1 = −3: x = 0.5(−3) + 0.5ε, so ε = (x − 0.5(−3))/0.5 = (1 + 1.5)/0.5 = 5.0.
If z = z2 = +3: x = 0.5(3) + 0.5ε, so ε = (x − 1.5)/0.5 = −1.0.
Since ε ~ N(0, 1), ε = −1.0 is much more likely than ε = 5.0. So the posterior strongly favors z2 = +3. The marginal velocity at x = 1.0 is dominated by the conditional velocity pointing toward z2.
Conditional velocities: u(x|z1) = z1 − ε1 = −3 − 5 = −8, u(x|z2) = z2 − ε2 = 3 − (−1) = 4.
Posterior weights: $w_1 \propto \mathcal{N}(5; 0, 1) \approx 1.5 \times 10^{-6}$, $w_2 \propto \mathcal{N}(-1; 0, 1) \approx 0.24$.
Marginal velocity: $u(x) \approx \frac{w_1 \cdot (-8) + w_2 \cdot 4}{w_1 + w_2} \approx 4.0$. The field points toward +3.
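The posterior-weighted average can be sketched in a few lines of Python. This sketch assumes 1D data on a finite set of points and the CondOT path. In that setting $p_t(x|z) = \mathcal{N}(\epsilon; 0, 1)/(1 - t)$ with $\epsilon = (x - tz)/(1 - t)$. `marginal_velocity` is a hypothetical helper name.

```python
import math

def normal_pdf(e):
    """Standard normal density N(e; 0, 1)."""
    return math.exp(-0.5 * e * e) / math.sqrt(2 * math.pi)

def marginal_velocity(x, t, zs, probs):
    """Posterior-weighted average of CondOT conditional velocities at (x, t), 1D data."""
    num, den = 0.0, 0.0
    for z, p in zip(zs, probs):
        eps = (x - t * z) / (1 - t)          # noise that would map z to x at time t
        w = p * normal_pdf(eps) / (1 - t)    # p_data(z) * p_t(x|z)
        num += w * (z - eps)                 # conditional velocity z - eps
        den += w
    return num / den

u = marginal_velocity(1.0, 0.5, zs=[-3.0, 3.0], probs=[0.5, 0.5])
print(u)  # very close to 4.0: the posterior puts almost all weight on z = +3
```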
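Theorem 9 (Marginalization Trick): The marginal vector field $u_t^{\text{target}}(x)$ generates the marginal probability path $p_t$:

$$X_0 \sim p_{\text{init}}, \quad \frac{dX_t}{dt} = u_t^{\text{target}}(X_t) \quad \Longrightarrow \quad X_t \sim p_t \ \text{for all } 0 \le t \le 1.$$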
This is the key theorem. If we could compute $u_t^{\text{target}}(x)$, we would be done: just simulate this ODE and get perfect samples. The problem is that we cannot compute $u_t^{\text{target}}(x)$, because the integral involves $p_t(x)$ in the denominator, which is intractable.
The continuity equation. The mathematical tool that proves Theorem 9 is the continuity equation. For an ODE $dX_t/dt = u_t(X_t)$, the density $p_t(x)$ of $X_t$ evolves according to:

$$\partial_t p_t(x) = -\operatorname{div}(p_t\, u_t)(x).$$
This equation from physics says: the rate of change of probability at x equals the net inflow of probability mass. The divergence measures outflow; the negative divergence measures inflow. Multiplying by pt accounts for the total mass being transported.
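Proof idea for Theorem 9. We need to show that $u_t^{\text{target}}(x) = \int u_t^{\text{target}}(x|z)\, p_t(z|x)\, dz$ satisfies the continuity equation for $p_t$, where $p_t(z|x) = p_t(x|z)\,p_{\text{data}}(z)/p_t(x)$ is the posterior. The proof directly computes:
1. Start with $\partial_t p_t(x) = \int \partial_t p_t(x|z)\, p_{\text{data}}(z)\, dz$ (definition of the marginal).
2. Each conditional term satisfies its own continuity equation: $\partial_t p_t(x|z) = -\operatorname{div}\big(p_t(\cdot|z)\, u_t^{\text{target}}(\cdot|z)\big)(x)$.
3. Integrate: $\partial_t p_t(x) = -\operatorname{div}\left(\int p_t(x|z)\, u_t^{\text{target}}(x|z)\, p_{\text{data}}(z)\, dz\right)$.
4. Multiply and divide by $p_t(x)$: $\partial_t p_t(x) = -\operatorname{div}\left(p_t(x) \int u_t^{\text{target}}(x|z)\, p_t(z|x)\, dz\right) = -\operatorname{div}(p_t\, u_t^{\text{target}})(x)$.
This matches the continuity equation for $u_t^{\text{target}}$, proving the theorem.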
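The continuity equation for the marginal field can also be checked numerically. The sketch below assumes 1D data with equal mass at ±3 and the CondOT path, so that both $p_t$ and $u_t^{\text{target}}$ have closed forms. It evaluates $\partial_t p_t + \partial_x(p_t u_t^{\text{target}})$ with central finite differences on a grid. The residual should vanish up to discretization error. All function names are illustrative.

```python
import math

ZS, PROBS = (-3.0, 3.0), (0.5, 0.5)   # toy data distribution

def npdf(x, mean, std):
    """Gaussian density N(x; mean, std^2)."""
    return math.exp(-0.5 * ((x - mean) / std) ** 2) / (std * math.sqrt(2 * math.pi))

def p_t(x, t):
    """Marginal density: mixture of N(t z, (1 - t)^2) over the data points."""
    return sum(p * npdf(x, t * z, 1 - t) for z, p in zip(ZS, PROBS))

def u_target(x, t):
    """Marginal CondOT field: posterior-weighted average of (z - x) / (1 - t)."""
    ws = [p * npdf(x, t * z, 1 - t) for z, p in zip(ZS, PROBS)]
    us = [(z - x) / (1 - t) for z in ZS]
    return sum(w * u for w, u in zip(ws, us)) / sum(ws)

def continuity_residual(x, t, h=1e-5):
    """d p_t/dt + d(p_t u_t)/dx via central differences; ~0 if the field generates p_t."""
    dp_dt = (p_t(x, t + h) - p_t(x, t - h)) / (2 * h)
    flux = lambda y: p_t(y, t) * u_target(y, t)
    dflux_dx = (flux(x + h) - flux(x - h)) / (2 * h)
    return dp_dt + dflux_dx

worst = max(abs(continuity_residual(i / 10, t)) for i in range(-40, 41) for t in (0.2, 0.5, 0.8))
print(worst)  # small: only finite-difference error remains
```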
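We want $u_t^\theta \approx u_t^{\text{target}}$ (the marginal vector field). The natural loss would be the flow matching loss:

$$\mathcal{L}_{\text{FM}}(\theta) = \mathbb{E}_{t \sim \text{Unif}[0,1],\, x \sim p_t}\left[\lVert u_t^\theta(x) - u_t^{\text{target}}(x) \rVert^2\right].$$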
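But we cannot compute $u_t^{\text{target}}(x)$! Instead, define the conditional flow matching loss:

$$\mathcal{L}_{\text{CFM}}(\theta) = \mathbb{E}_{t \sim \text{Unif}[0,1],\, z \sim p_{\text{data}},\, x \sim p_t(\cdot|z)}\left[\lVert u_t^\theta(x) - u_t^{\text{target}}(x|z) \rVert^2\right].$$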
The difference: we regress against the conditional vector field $u_t^{\text{target}}(x|z)$ instead of the marginal $u_t^{\text{target}}(x)$. We can compute this because we know the conditional vector field analytically.
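Proof sketch. Expand $\lVert a - b \rVert^2 = \lVert a \rVert^2 - 2a^\top b + \lVert b \rVert^2$ for both losses. The $\lVert a \rVert^2$ terms are the same (both sample $x \sim p_t$). The cross terms are equal because:

$$\mathbb{E}_{x \sim p_t}\left[u_t^\theta(x)^\top u_t^{\text{target}}(x)\right] = \mathbb{E}_{z \sim p_{\text{data}},\, x \sim p_t(\cdot|z)}\left[u_t^\theta(x)^\top u_t^{\text{target}}(x|z)\right].$$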
This is the crucial step: the marginal vector field in the cross term can be replaced by the conditional one by expanding the definition and using linearity of integration. The $\lVert b \rVert^2$ terms differ, but both are constant in θ.
Why this is profound: Training is simulation-free. We never simulate any ODE during training. Each training step is O(1) neural network evaluations (not O(n) like ODE simulation). This makes flow matching scalable to billion-parameter models.
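Full proof of Theorem 12. We need to show $\mathcal{L}_{\text{FM}}(\theta) = \mathcal{L}_{\text{CFM}}(\theta) + C$, where C does not depend on θ. Expand the squared norm:

$$\lVert a - b \rVert^2 = \lVert a \rVert^2 - 2a^\top b + \lVert b \rVert^2.$$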
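Apply to $\mathcal{L}_{\text{FM}}$ (expectations over $t \sim \text{Unif}[0,1]$ and $x \sim p_t$):

$$\mathcal{L}_{\text{FM}}(\theta) = \mathbb{E}\left[\lVert u_t^\theta(x) \rVert^2\right] - 2\,\mathbb{E}\left[u_t^\theta(x)^\top u_t^{\text{target}}(x)\right] + \mathbb{E}\left[\lVert u_t^{\text{target}}(x) \rVert^2\right].$$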
The last term is constant in θ (call it $C_1$). The first term, $\mathbb{E}[\lVert u_t^\theta(x) \rVert^2]$ with $x \sim p_t$, equals the same expectation with $z \sim p_{\text{data}}$, $x \sim p_t(\cdot|z)$, because sampling $x \sim p_t$ is equivalent to the two-step procedure.
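The crucial step is the cross term. Using the definition of $u_t^{\text{target}}(x)$ as a posterior-weighted integral:

$$\begin{aligned}
\mathbb{E}_{x \sim p_t}\left[u_t^\theta(x)^\top u_t^{\text{target}}(x)\right]
&= \int p_t(x)\, u_t^\theta(x)^\top \int u_t^{\text{target}}(x|z)\, \frac{p_t(x|z)\, p_{\text{data}}(z)}{p_t(x)}\, dz\, dx \\
&= \iint u_t^\theta(x)^\top u_t^{\text{target}}(x|z)\, p_t(x|z)\, p_{\text{data}}(z)\, dz\, dx \\
&= \mathbb{E}_{z \sim p_{\text{data}},\, x \sim p_t(\cdot|z)}\left[u_t^\theta(x)^\top u_t^{\text{target}}(x|z)\right].
\end{aligned}$$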
This is the key identity. It says: averaging uθ against the marginal velocity is the same as averaging against the conditional velocity, because the marginal velocity is a posterior-weighted average of conditional velocities. When we take the expectation over x ~ pt, the posterior weights integrate out.
Substituting back and adding and subtracting $\lVert u_t^{\text{target}}(x|z) \rVert^2$ inside the expectation, we get $\mathcal{L}_{\text{FM}} = \mathcal{L}_{\text{CFM}} + C_1 + C_2$, where $C_2$ is also independent of θ. QED.
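Theorem 12 can be checked by Monte Carlo in a 1D two-point example (data at ±2, t fixed at 0.5), where the marginal field is computable exactly. The sketch assumes a hypothetical one-parameter model $u^\theta(x) = \theta x$. For every θ, the gap between the two losses should be the same positive constant, up to sampling noise. That constant is the average posterior variance of the conditional target.

```python
import torch

torch.manual_seed(0)
n, t = 200_000, 0.5
zs = torch.tensor([-2.0, 2.0])

z = zs[torch.randint(2, (n,))]            # z ~ p_data (equal mass on -2 and +2)
eps = torch.randn(n)                      # eps ~ N(0, 1)
x = t * z + (1 - t) * eps                 # x ~ p_t(.|z), hence x ~ p_t
u_cond = z - eps                          # conditional target u_t(x|z)

# Marginal target: posterior-weighted average of (z_k - x) / (1 - t).
# With equal priors, posterior weights are softmax(-eps_k^2 / 2) over the data points.
eps_k = (x[:, None] - t * zs[None, :]) / (1 - t)          # [n, 2]
w = torch.softmax(-0.5 * eps_k**2, dim=1)                 # posterior over z_k
u_marg = (w * (zs[None, :] - x[:, None]) / (1 - t)).sum(dim=1)

def losses(theta):
    """L_FM and L_CFM for the hypothetical linear model u_theta(x) = theta * x."""
    pred = theta * x
    l_fm = ((pred - u_marg) ** 2).mean().item()
    l_cfm = ((pred - u_cond) ** 2).mean().item()
    return l_fm, l_cfm

gaps = []
for theta in (-2.0, 0.0, 1.0, 2.0):
    l_fm, l_cfm = losses(theta)
    gaps.append(l_cfm - l_fm)
    print(f"theta={theta:+.1f}  L_FM={l_fm:.3f}  L_CFM={l_cfm:.3f}  gap={l_cfm - l_fm:.3f}")
```

Both losses change with θ in the same way, so their gradients with respect to θ agree.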
Worked example — verifying the theorem for d = 1. Suppose pdata puts mass 0.5 on z = −2 and 0.5 on z = +2. CondOT path. Consider fixed t = 0.5, x = 0.
The marginal velocity utarget(x = 0) at t = 0.5:
For z = −2: ε = (x − tz)/(1−t) = (0 − (−1))/0.5 = 2.0. Conditional velocity: z − ε = −2 − 2 = −4. Weight: N(ε=2; 0, 1) = 0.054.
For z = +2: ε = (0 − 1)/0.5 = −2.0. Conditional velocity: z − ε = 2 − (−2) = 4. Weight: N(ε=−2; 0, 1) = 0.054.
Both weights are equal! So utarget(0) = 0.5(−4) + 0.5(4) = 0. The velocity at the midpoint is zero because both data points pull equally.
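If we move to x = 0.5, the picture changes. For z = +2, ε = (0.5 − 1)/0.5 = −1.0, with weight N(−1; 0, 1) = 0.242 and conditional velocity 2 − (−1) = 3. For z = −2, ε = (0.5 + 1)/0.5 = 3.0, with weight N(3; 0, 1) = 0.0044 and conditional velocity −2 − 3 = −5. Now z = +2 dominates, and $u^{\text{target}}(0.5) \approx (0.242 \cdot 3 - 0.0044 \cdot 5)/0.2464 \approx 2.86$, close to the conditional velocity toward +2. The field points toward +2. This is how the marginal field "decides" which mode to send each particle to.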
An analogy to make this concrete. Imagine training a weather prediction model. At location x = "40°N, 74°W" at time t = "January," the target wind velocity varies across years. Some years it blows east, some west, some north. If you train the model on individual year samples, the model learns the average wind at that location and time. This average is the climatological wind — the marginal velocity in our analogy. Individual years are like individual data points z. The model cannot memorize a specific year; it learns the average over all years. This is exactly how flow matching works: individual training samples give conditional targets, but the model learns the marginal (average) velocity.
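Convergence of SGD for flow matching. Each mini-batch provides a noisy estimate of $\nabla_\theta \mathcal{L}_{\text{CFM}}$. By Theorem 12, this is also an unbiased estimate of $\nabla_\theta \mathcal{L}_{\text{FM}}$. Standard stochastic-optimization results therefore apply exactly as they would to $\mathcal{L}_{\text{FM}}$ itself. With a suitably decreasing learning rate, SGD converges to a stationary point; for a non-convex neural network, this is not necessarily a global minimizer. In practice, Adam with a short warmup and either a constant or cosine-decayed learning rate works well.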
Training vs. sampling complexity. A comparison that highlights why flow matching is practical:
| Operation | NFE per step | Backprop? | Typical iterations |
|---|---|---|---|
| Flow matching training step | 1 | Yes | 100K–1M |
| Sampling (Euler, 50 steps) | 50 | No | 1 per image |
| Naive "simulate + backprop" training | 50 | Yes (through all 50) | Prohibitive |
Flow matching training is 50x cheaper per step than the naive approach, with no memory overhead from storing intermediate ODE states for backpropagation.
Memory usage comparison. The naive "simulate + backprop" approach stores the activations of every step for backpropagation. Assume roughly 2 GB of activation memory per step for a DiT-XL model; the actual figure depends on batch size and resolution. Then 50 Euler steps need about 50 × 2 GB = 100 GB, more than the 80 GB of a single A100 or H100. Flow matching needs about 1 × 2 GB = 2 GB, a 50× memory reduction.
Techniques like checkpointing (recomputing activations instead of storing them) can reduce the naive approach's memory to O(1), but at the cost of 2-3x more computation. Flow matching avoids this tradeoff entirely.
EMA (exponential moving average) of weights. One important practical detail: during training, we maintain an EMA of the model weights, typically with a decay rate such as 0.9999. At inference, we use the EMA weights instead of the raw weights. This smooths out training noise and usually improves sample quality; it is standard practice for large diffusion and flow models.
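```python
import torch

# EMA of weights (essential for quality)
class EMA:
    def __init__(self, model, decay=0.9999):
        self.decay = decay
        # Detached copies, so the EMA never becomes part of the autograd graph
        self.ema_params = {n: p.detach().clone() for n, p in model.named_parameters()}

    @torch.no_grad()
    def update(self, model):
        # ema <- decay * ema + (1 - decay) * current weights
        for n, p in model.named_parameters():
            self.ema_params[n].mul_(self.decay).add_(p.detach(), alpha=1 - self.decay)

    @torch.no_grad()
    def copy_to(self, model):
        # Load the averaged weights into the model for inference
        for n, p in model.named_parameters():
            p.copy_(self.ema_params[n])

# Usage: after each optimizer step, call ema.update(model)
# At inference: ema.copy_to(model)
```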
```python
import torch

# Conditional Flow Matching Loss (one training step)
def cfm_loss(model, z_batch):
    """z_batch: [B, d], a batch of data samples."""
    B, d = z_batch.shape
    t = torch.rand(B, 1)                  # t ~ Unif[0, 1]
    eps = torch.randn_like(z_batch)       # eps ~ N(0, I)
    x = t * z_batch + (1 - t) * eps       # x ~ p_t(·|z)
    target = z_batch - eps                # u_target = z - eps
    pred = model(x, t)                    # u_theta(x, t)
    return (pred - target).pow(2).mean()  # MSE loss
```
Let us instantiate everything for the most popular choice: the Conditional Optimal Transport (CondOT) path with αt = t and βt = 1 − t. This gives:
| Quantity | General Gaussian | CondOT (α=t, β=1−t) |
|---|---|---|
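| Conditional path | $p_t(\cdot \mid z) = \mathcal{N}(\alpha_t z, \beta_t^2 I)$ | $\mathcal{N}(tz, (1-t)^2 I)$ |
| Sample from path | $x = \alpha_t z + \beta_t \epsilon$ | $x = tz + (1-t)\epsilon$ |
| Derivatives | $\dot\alpha_t, \dot\beta_t$ | $\dot\alpha_t = 1,\ \dot\beta_t = -1$ |
| Conditional velocity | $(\dot\alpha_t - \alpha_t\dot\beta_t/\beta_t)\,z + (\dot\beta_t/\beta_t)\,x$ | $z - \epsilon$ |
| CFM loss | $\lVert u_t^\theta(\alpha_t z + \beta_t\epsilon) - (\dot\alpha_t z + \dot\beta_t\epsilon) \rVert^2$ | $\lVert u_t^\theta(tz + (1-t)\epsilon) - (z - \epsilon) \rVert^2$ |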
Worked example — one training step by hand. Suppose our dataset is {2.0, 5.0, −1.0}. One training step:
1. Sample data: z = 5.0 (randomly chosen from dataset).
2. Sample time: t = 0.3 (from Unif[0,1]).
3. Sample noise: ε = −0.8 (from N(0,1)).
4. Compute noisy input: x = 0.3 · 5.0 + 0.7 · (−0.8) = 1.5 − 0.56 = 0.94.
5. Compute target velocity: z − ε = 5.0 − (−0.8) = 5.8.
6. Neural network predicts: $u_{0.3}^\theta(0.94) = 4.2$ (hypothetical).
7. Loss: $(4.2 - 5.8)^2 = (-1.6)^2 = 2.56$.
8. Backpropagate and update θ to reduce this loss.
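The same arithmetic in plain Python, with the network output fixed to the hypothetical value 4.2 from step 6:

```python
z, t, eps = 5.0, 0.3, -0.8       # sampled data point, time, and noise
x = t * z + (1 - t) * eps        # noisy input on the CondOT path
target = z - eps                 # CondOT target velocity
pred = 4.2                       # hypothetical network output u_theta(x, t)
loss = (pred - target) ** 2      # squared error for this single sample
print(round(x, 2), round(target, 2), round(loss, 2))  # → 0.94 5.8 2.56
```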
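The general Gaussian CFM loss with arbitrary schedulers is:

$$\mathcal{L}_{\text{CFM}}(\theta) = \mathbb{E}_{t \sim \text{Unif}[0,1],\, z \sim p_{\text{data}},\, \epsilon \sim \mathcal{N}(0, I)}\left[\lVert u_t^\theta(\alpha_t z + \beta_t\epsilon) - (\dot\alpha_t z + \dot\beta_t\epsilon) \rVert^2\right].$$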
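The general loss can be sketched in PyTorch by letting a schedule function return both $\alpha_t, \beta_t$ and their time derivatives. The names `condot_schedule`, `cosine_schedule`, `gaussian_cfm_loss`, and `TinyVelocityNet` are illustrative, not from the text. The last lines check that the target $\dot\alpha_t z + \dot\beta_t\epsilon$ equals the conditional vector field formula evaluated at $x = \alpha_t z + \beta_t\epsilon$.

```python
import math
import torch
import torch.nn as nn

def condot_schedule(t):
    """(alpha_t, beta_t, d alpha/dt, d beta/dt) for alpha = t, beta = 1 - t."""
    return t, 1 - t, torch.ones_like(t), -torch.ones_like(t)

def cosine_schedule(t):
    """(alpha_t, beta_t, d alpha/dt, d beta/dt) for alpha = sin(pi t/2), beta = cos(pi t/2)."""
    c = math.pi / 2
    return torch.sin(c * t), torch.cos(c * t), c * torch.cos(c * t), -c * torch.sin(c * t)

def gaussian_cfm_loss(model, z, schedule):
    """Monte Carlo estimate of the general Gaussian CFM loss on a batch z: [B, d]."""
    t = torch.rand(z.shape[0], 1)              # t ~ Unif[0, 1]
    eps = torch.randn_like(z)                  # eps ~ N(0, I)
    alpha, beta, dalpha, dbeta = schedule(t)
    x = alpha * z + beta * eps                 # x ~ p_t(.|z)
    target = dalpha * z + dbeta * eps          # conditional velocity at x
    return (model(x, t) - target).pow(2).mean()

class TinyVelocityNet(nn.Module):
    """Hypothetical minimal velocity model u_theta(x, t)."""
    def __init__(self, d=2, hidden=64):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(d + 1, hidden), nn.SiLU(), nn.Linear(hidden, d))

    def forward(self, x, t):
        return self.net(torch.cat([x, t], dim=-1))

torch.manual_seed(0)
model = TinyVelocityNet()
z = torch.randn(128, 2)
for schedule in (condot_schedule, cosine_schedule):
    loss = gaussian_cfm_loss(model, z, schedule)
    loss.backward()                            # gradients flow as in any regression loss
    print(schedule.__name__, loss.item())

# Sanity check: alpha' z + beta' eps equals the conditional field formula at x = alpha z + beta eps.
t = torch.full((128, 1), 0.3)
eps = torch.randn_like(z)
a, b, da, db = cosine_schedule(t)
x = a * z + b * eps
direct = da * z + db * eps
formula = (da - db / b * a) * z + (db / b) * x
```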
Alternative noise schedules. While CondOT (αt = t, βt = 1 − t) creates straight-line paths, other schedules are possible:
| Name | αt | βt | Path shape |
|---|---|---|---|
| CondOT (linear) | t | 1 − t | Straight lines from ε to z |
| Cosine schedule | sin(πt/2) | cos(πt/2) | Curved paths, slow start & end |
| VP (DDPM-style) | √(ᾱt) | √(1 − ᾱt) | Variance-preserving: αt² + βt² = 1 |
In practice, the CondOT schedule works extremely well. Its conditional paths are straight lines traversed at constant velocity, so the regression target changes slowly over time, which is easier for a neural network to learn. The cosine schedule was popular before flow matching, in DDPM-era models. It changes the noise level slowly near both ends of the process and, unlike the earlier linear DDPM schedule, does not spend a large fraction of timesteps on almost pure noise.
Signal-to-noise ratio (SNR). The noise schedule determines how much "signal" (data) vs. "noise" is present at each time t. The SNR is defined as:

$$\text{SNR}(t) = \frac{\alpha_t^2}{\beta_t^2}.$$
| t (CondOT) | αt = t | βt = 1−t | SNR = t²/(1−t)² | Interpretation |
|---|---|---|---|---|
| 0.1 | 0.1 | 0.9 | 0.012 | Almost pure noise |
| 0.3 | 0.3 | 0.7 | 0.184 | Noise dominates |
| 0.5 | 0.5 | 0.5 | 1.000 | Equal signal and noise |
| 0.7 | 0.7 | 0.3 | 5.444 | Signal dominates |
| 0.9 | 0.9 | 0.1 | 81.0 | Almost pure data |
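The SNR column can be reproduced directly. For CondOT, $\log \text{SNR} = 2 \log\frac{t}{1-t} = 2\,\text{logit}(t)$. This is a useful link to time sampling: drawing $t = \text{sigmoid}(n)$ with $n \sim \mathcal{N}(0, 1)$ is the same as drawing $\log \text{SNR}$ from $\mathcal{N}(0, 4)$. A minimal sketch:

```python
import math

ts = [0.1, 0.3, 0.5, 0.7, 0.9]
snr = [(t / (1 - t)) ** 2 for t in ts]        # alpha_t^2 / beta_t^2 with alpha = t, beta = 1 - t
log_snr = [math.log(s) for s in snr]
logit = [math.log(t / (1 - t)) for t in ts]   # logit(t)

for t, s, ls, lg in zip(ts, snr, log_snr, logit):
    print(f"t={t:.1f}  SNR={s:8.3f}  log SNR={ls:+.3f}  2*logit(t)={2 * lg:+.3f}")
```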
The network faces very different tasks at different t values. At low t (high noise), it must infer the general class or mode of the data. At high t (low noise), it must predict fine details like texture and edges. Modern training strategies sometimes sample t non-uniformly to give more weight to informative noise levels (like the logit-normal sampling in Stable Diffusion 3).
Time sampling strategies. Uniform sampling t ~ Unif[0,1] weights all noise levels equally. But not all noise levels are equally important for image quality:
| Strategy | Distribution of t (or loss weight) | Effect |
|---|---|---|
| Uniform | Unif[0, 1] | Equal weight to all noise levels |
| Logit-normal | t = sigmoid(Normal(0, 1)) | More weight near t = 0.5 (medium noise) |
| Arccos ("cosine") | t = 1 − (2/π)·arccos(u), u ~ Unif[0, 1] | Density (π/2)·cos(πt/2): more weight near t = 0 (high noise), little near t = 1 |
| Min-SNR weighting | Weight loss by min(SNR, γ) | Caps contribution of high-SNR timesteps |
Stable Diffusion 3 uses logit-normal time sampling, which concentrates training on medium noise levels where the neural network has the most to learn. Very high noise (t near 0) is easy (just predict the data mean) and very low noise (t near 1) is easy (just predict the identity). The interesting work happens in between.
```python
# Time sampling strategies
import math
import torch

B = 256  # batch size

# Uniform
t_uniform = torch.rand(B)

# Logit-normal (Stable Diffusion 3)
t_logit_normal = torch.sigmoid(torch.randn(B))

# Arccos-based "cosine" sampler: density (pi/2) cos(pi t / 2), skewed toward t = 0
u = torch.rand(B)
t_cosine = 1 - 2 * torch.acos(u) / math.pi

# All three are valid; logit-normal is the most popular choice now
```
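A quick empirical check shows where each sampler puts its mass. The analytic values follow from the density $(\pi/2)\cos(\pi t/2)$ of the arccos-based sampler, which gives $\mathbb{E}[t] = 1 - 2/\pi \approx 0.36$ and $P(t < 0.5) = \sin(\pi/4) \approx 0.71$. The sample size is an arbitrary choice for this sketch.

```python
import math
import torch

torch.manual_seed(0)
n = 1_000_000
u = torch.rand(n)
t_cosine = 1 - 2 * torch.acos(u) / math.pi       # arccos-based sampler
t_logit_normal = torch.sigmoid(torch.randn(n))   # logit-normal sampler

print(t_cosine.mean().item(), 1 - 2 / math.pi)                        # empirical vs analytic mean
print((t_cosine < 0.5).float().mean().item(), math.sin(math.pi / 4))  # mass in the noisier half
print(t_logit_normal.mean().item())                                   # symmetric around 0.5
```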
```python
# Different noise schedules
import math
import torch

def condot_schedule(t):
    return t, 1 - t  # alpha, beta

def cosine_schedule(t):
    alpha = torch.sin(math.pi * t / 2)
    beta = torch.cos(math.pi * t / 2)
    return alpha, beta

# CondOT: x = t*z + (1-t)*eps,  target = z - eps
# Cosine: x = sin(πt/2)*z + cos(πt/2)*eps
#         target = (π/2)*cos(πt/2)*z - (π/2)*sin(πt/2)*eps
```
Let us write out the complete training procedure and work through it in detail.
```python
# Algorithm 3: Complete Flow Matching Training Loop
import torch
import torch.nn as nn

class VelocityNet(nn.Module):
    def __init__(self, d=2, hidden=256):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(d + 1, hidden),  # input: [x, t]
            nn.SiLU(),
            nn.Linear(hidden, hidden),
            nn.SiLU(),
            nn.Linear(hidden, d),      # output: velocity
        )

    def forward(self, x, t):
        return self.net(torch.cat([x, t], dim=-1))

def sample_data(batch_size):
    # Hypothetical toy dataset standing in for a real data loader:
    # 2D points from a mixture of four Gaussians
    centers = torch.tensor([[2.0, 2.0], [-2.0, 2.0], [2.0, -2.0], [-2.0, -2.0]])
    idx = torch.randint(len(centers), (batch_size,))
    return centers[idx] + 0.3 * torch.randn(batch_size, 2)

# Training loop
model = VelocityNet(d=2)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
for step in range(10000):
    z = sample_data(batch_size=256)  # [256, 2]
    t = torch.rand(256, 1)           # [256, 1]
    eps = torch.randn_like(z)        # [256, 2]
    x = t * z + (1 - t) * eps        # noisy sample
    target = z - eps                 # target velocity
    loss = (model(x, t) - target).pow(2).mean()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
```
What does the neural network see during training? Let's trace through 5 training steps in detail:
| # | z (data) | t | ε (noise) | x = tz+(1−t)ε | Target = z−ε |
|---|---|---|---|---|---|
| 1 | 2.0 | 0.71 | −0.3 | 1.33 | 2.3 |
| 2 | −1.0 | 0.12 | 1.5 | 1.20 | −2.5 |
| 3 | 5.0 | 0.85 | 0.2 | 4.28 | 4.8 |
| 4 | 2.0 | 0.33 | −1.1 | −0.08 | 3.1 |
| 5 | −1.0 | 0.55 | 0.8 | −0.19 | −1.8 |
Notice: the network receives input (x, t) and must predict the target. The same x value can have different targets depending on what z and ε were used — but on average, the network learns the correct marginal velocity. This is the magic of Theorem 12: training against individual conditional targets implicitly trains against the (intractable) average.
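The claim that the network "learns the average" can be checked numerically in the two-point setting from Theorem 12's worked example (data at ±2, t = 0.5). The sketch collects the conditional targets z − ε of all training pairs whose noisy input lands near x = 0.5. Their average, which is what an MSE-trained model would output there, matches the analytic marginal velocity. The bin half-width and sample count are arbitrary choices for illustration.

```python
import torch

torch.manual_seed(0)
n, t, x0, half_width = 2_000_000, 0.5, 0.5, 0.02
zs = torch.tensor([-2.0, 2.0])

z = zs[torch.randint(2, (n,))]
eps = torch.randn(n)
x = t * z + (1 - t) * eps
target = z - eps                              # conditional CondOT targets

in_bin = (x - x0).abs() < half_width          # training pairs whose x lands near x0
empirical = target[in_bin].mean().item()      # the MSE-optimal prediction at x0

# Analytic marginal velocity at x0: posterior-weighted average of (z_k - x0) / (1 - t).
eps_k = (x0 - t * zs) / (1 - t)
w = torch.softmax(-0.5 * eps_k**2, dim=0)
analytic = (w * (zs - x0) / (1 - t)).sum().item()
print(int(in_bin.sum()), empirical, analytic)
```

Inside the bin, the individual targets cluster around two very different values (about 3 and −5). Only their average is stable.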
What does the loss curve look like? A typical flow matching training run on CIFAR-10 (32×32 images):
| Training steps | MSE Loss | FID Score | Quality |
|---|---|---|---|
| 1K | ~2.0 | >200 | Random noise |
| 10K | ~0.8 | ~100 | Blurry blobs with color |
| 50K | ~0.4 | ~30 | Recognizable objects |
| 200K | ~0.25 | ~10 | Good quality |
| 500K | ~0.20 | ~5 | State of the art |
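The FID (Fréchet Inception Distance) measures how similar the generated distribution is to the real data distribution; lower is better. The smoothed loss keeps decreasing but levels off well above zero. By Theorem 12, the CFM loss equals the flow matching loss plus a θ-independent constant, so its absolute value says little about sample quality. FID can plateau or even increase slightly late in training, for example due to overfitting.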
Debugging flow matching training. Common issues and diagnostics:
Mini-batch size matters. Larger batches give lower-variance gradient estimates, leading to faster convergence. For 2D toy experiments, B = 256 is plenty. For image generation, B = 128-512 per GPU with gradient accumulation. Stable Diffusion 3 was trained with an effective batch size of 2048.
Learning rate schedule. Most flow matching models use a linear warmup (1K-10K steps) followed by cosine decay. The peak learning rate is typically 1e-4 to 3e-4 for Adam. Using AdamW with weight decay 0.01 helps regularization.
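Linear warmup followed by cosine decay can be expressed with PyTorch's `LambdaLR`, which multiplies the base learning rate by a user-supplied function of the step count. This sketch assumes 1K warmup steps and 500K total steps, illustrative values within the ranges above. `warmup_cosine` is a hypothetical helper, and the `nn.Linear` model is only a placeholder.

```python
import math
import torch

def warmup_cosine(step, warmup_steps=1_000, total_steps=500_000):
    """LR multiplier: linear warmup to 1, then cosine decay to 0 at total_steps."""
    if step < warmup_steps:
        return (step + 1) / warmup_steps
    progress = min(1.0, (step - warmup_steps) / max(1, total_steps - warmup_steps))
    return 0.5 * (1 + math.cos(math.pi * progress))

model = torch.nn.Linear(2, 2)                      # placeholder model
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4, weight_decay=0.01)
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=warmup_cosine)
# In the training loop: optimizer.step(); scheduler.step()  (once per training step)
```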
```python
# Complete training loop with all practical details
import torch
from torch.optim.lr_scheduler import CosineAnnealingLR

# DiT and dataloader are placeholders assumed to be defined elsewhere:
# DiT is a diffusion-transformer velocity model, dataloader an iterator over latent batches.
model = DiT(in_channels=4, hidden=1024, depth=24)
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4, weight_decay=0.01)
scheduler = CosineAnnealingLR(optimizer, T_max=500000)

for step in range(500000):
    z = next(dataloader)                    # [B, 4, 64, 64] latent images
    B = z.shape[0]
    # Logit-normal time sampling (often better than uniform)
    u = torch.randn(B)
    t = torch.sigmoid(u).view(B, 1, 1, 1)   # [B, 1, 1, 1], broadcasts over C, H, W
    eps = torch.randn_like(z)
    x = t * z + (1 - t) * eps
    target = z - eps
    loss = (model(x, t.view(B)) - target).pow(2).mean()  # model assumed to take t as [B]
    optimizer.zero_grad()
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
    optimizer.step()
    scheduler.step()
```
Tensor shapes in practice. For a batch of B = 256 images of size 64 × 64 × 3:
| Tensor | Shape | Description |
|---|---|---|
| z | [256, 3, 64, 64] | Batch of data samples |
| t | [256, 1, 1, 1] | Random times (broadcast over spatial dims) |
| ε | [256, 3, 64, 64] | Gaussian noise |
| x = tz + (1−t)ε | [256, 3, 64, 64] | Noisy images |
| target = z − ε | [256, 3, 64, 64] | Velocity targets |
| pred = uθ(x, t) | [256, 3, 64, 64] | Network prediction |
| loss | scalar | MSE over all elements |
After training, we sample by simulating the ODE with the learned vector field. This is the full pipeline:
```python
import torch

# Complete Flow Matching Pipeline: Train + Sample
# After training (Algorithm 3), sample via Algorithm 1 (Euler method):
def generate_samples(model, n_samples=1000, n_steps=50, d=2):
    x = torch.randn(n_samples, d)             # pure noise, X_0 ~ N(0, I)
    h = 1.0 / n_steps
    with torch.no_grad():
        for i in range(n_steps):
            t = torch.full((n_samples, 1), i * h)
            x = x + h * model(x, t)           # Euler step
    return x                                  # generated samples!
```
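The sampling loop can be exercised without a trained network by plugging in the exact marginal field for a toy distribution. The sketch assumes 1D data with equal mass at ±2 and the CondOT path. `ideal_field` and `euler_sample` are hypothetical names, and `ideal_field` stands in for a trained model. Euler never evaluates t = 1, so the 1/(1 − t) factor stays finite.

```python
import torch

ZS = torch.tensor([-2.0, 2.0])  # toy data: equal mass at -2 and +2 (1D)

def ideal_field(x, t):
    """Exact marginal CondOT field for the toy data. x, t: [n, 1] -> [n, 1]."""
    eps_k = (x - t * ZS) / (1 - t)                 # [n, 2] implied noise for each data point
    w = torch.softmax(-0.5 * eps_k**2, dim=1)      # posterior over data points (equal priors)
    return (w * (ZS - x) / (1 - t)).sum(dim=1, keepdim=True)

def euler_sample(field, n_samples=20_000, n_steps=50):
    x = torch.randn(n_samples, 1)                  # X_0 ~ N(0, 1)
    h = 1.0 / n_steps
    for i in range(n_steps):
        t = torch.full((n_samples, 1), i * h)      # t < 1, so 1 - t never vanishes
        x = x + h * field(x, t)
    return x

torch.manual_seed(0)
samples = euler_sample(ideal_field)
print((samples > 0).float().mean().item(), samples.abs().mean().item())  # about 0.5 and 2.0
```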
What happens at each sampling step, in detail. Suppose we have trained a model on 2D "two moons" data and sample with 20 Euler steps. The values below are illustrative, not from an actual run:
| Step | t | Xt | uθ(Xt, t) | What the network "thinks" |
|---|---|---|---|---|
| 0 | 0.00 | (0.3, −0.5) | (0.8, 1.2) | Noise — uncertain, average direction |
| 5 | 0.25 | (0.5, 0.1) | (1.1, 0.9) | Starting to "choose" the upper moon |
| 10 | 0.50 | (0.8, 0.5) | (0.7, 0.3) | Committed to upper moon, refining |
| 15 | 0.75 | (1.0, 0.7) | (0.2, 0.1) | Almost at target, slowing down |
| 20 | 1.00 | (1.05, 0.72) | — | Final sample on the upper moon |
In this illustration, the velocity starts large (the particle needs to move far) and shrinks as the particle approaches the data manifold. This is not a general property of the CondOT path. The conditional velocity z − ε is constant along each straight line. The marginal velocity, by contrast, changes over time as the posterior over z concentrates, so both its magnitude and its direction can vary along a trajectory.
Why CondOT paths produce straight trajectories. For the CondOT path, the conditional trajectory is x(t) = tz + (1−t)ε — a straight line from ε to z. The marginal trajectories are not exactly straight (because the posterior-weighted average of straight lines is generally curved), but they are approximately straight, especially when the modes are well-separated. Straight trajectories mean the velocity field changes slowly over time, which has two benefits: (1) the neural network finds it easier to learn a slowly-varying function, and (2) the Euler method is more accurate for slowly-varying ODEs, so fewer steps are needed.
This is part of why flow matching with CondOT typically needs only 20-50 Euler steps, while DDPM-style models historically needed around 1000 steps. Straighter trajectories are not the only factor, though. DDPM's 1000 steps also reflect its stochastic ancestral sampler, and deterministic ODE samplers such as DDIM already reduced the count to roughly 50 steps with the same curved VP schedule.
Trajectory straightness metric. Lipman et al. (2022) defined a straightness metric for flow models:

$$S = \frac{\lVert X_1 - X_0 \rVert}{\int_0^1 \lVert \dot X_t \rVert\, dt}.$$
For perfectly straight trajectories, S = 1 (the path length equals the displacement). For curved trajectories, S < 1 (the path is longer than needed). CondOT paths achieve S ≈ 0.99, while VP (DDPM-style) paths achieve S ≈ 0.7. The higher the straightness, the fewer Euler steps needed for accurate sampling.
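Treating S as the ratio of displacement to path length, it can be computed for a discretized trajectory. The sketch compares the CondOT trajectory with the cosine-schedule trajectory for the same (ε, z) pair. The pair is chosen orthogonal with equal norm, so the cosine path is a quarter circle with S = 2√2/π ≈ 0.90. These are single conditional trajectories. The S values quoted above for learned marginal flows are not reproduced here. `straightness` is an illustrative helper name.

```python
import math
import torch

def straightness(traj):
    """Displacement divided by path length for a discretized trajectory of shape [n_points, d]."""
    displacement = (traj[-1] - traj[0]).norm()
    path_length = (traj[1:] - traj[:-1]).norm(dim=1).sum()
    return (displacement / path_length).item()

ts = torch.linspace(0, 1, 1001)[:, None]   # [1001, 1] time grid
z = torch.tensor([1.0, 0.0])               # data point
eps = torch.tensor([0.0, 1.0])             # noise sample, orthogonal to z with equal norm

condot = ts * z + (1 - ts) * eps                                                 # straight line
cosine = torch.sin(math.pi * ts / 2) * z + torch.cos(math.pi * ts / 2) * eps     # quarter circle

s_condot, s_cosine = straightness(condot), straightness(cosine)
print(s_condot, s_cosine, 2 * math.sqrt(2) / math.pi)
```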
What the trained model looks like inside. After training, the neural network has learned a complex, time-dependent vector field. At t = 0 (pure noise), the field is approximately constant — pushing all particles in roughly the same direction (toward the center of the data distribution). As t increases, the field becomes more complex and spatially varying, with different regions pointing toward different modes. By t = 0.9, the field has fine-grained structure that places each particle precisely on the data manifold.
This progression from simple to complex is another reason why flow matching works so well: the neural network can dedicate most of its capacity to the hard part (fine-grained details at high t) rather than the easy part (rough direction at low t).
Extensions and variations of flow matching. The basic algorithm we described can be extended in several important ways:
1. Stochastic interpolants (Albergo & Vanden-Eijnden, 2023). A generalization where the interpolation between noise and data can have additional stochastic terms. This includes flow matching and score matching as special cases.
2. Rectified flows (Liu et al., 2022). After training, the trajectories may not be perfectly straight. Rectified flows "straighten" the trajectories by retraining: use the trained model to generate (noise, data) pairs, then retrain on these pairs. This iterative process produces straighter paths, enabling accurate generation in fewer steps.
3. Discrete flow matching. Flow matching extended to discrete data (like text tokens). Instead of continuous interpolation, discrete flow matching uses Markov chains that gradually convert a uniform distribution into the data distribution. This is covered in Chapter 7 of the book.
4. Equivariant flow matching. For data with symmetries (e.g., molecules that should be invariant to rotation), the vector field can be designed to respect these symmetries. This dramatically improves sample efficiency for scientific applications.
5. Consistency models (Song et al., 2023). Train a model that directly maps any point on the ODE trajectory to the endpoint (clean data). This enables one-step generation at the cost of additional training complexity. Consistency models can be trained from scratch or distilled from a pre-trained flow matching model.
6. Flow matching for video. The extension to video is straightforward: $z \in \mathbb{R}^{T \times H \times W \times 3}$ is a sequence of frames. The neural network is a 3D U-Net or video DiT that processes both spatial and temporal dimensions. The training loss and sampling procedure are identical, just with higher-dimensional z. Meta's Movie Gen is trained this way, and OpenAI describes Sora as a diffusion transformer operating on spacetime patches.
7. Conditional optimal transport. The "OT" in CondOT refers to optimal transport theory. Each conditional path is a straight line from ε to z traversed at constant speed. That is the optimal transport (Wasserstein-2) interpolation between the noise distribution and a point mass at z. However, this does not make the marginal flow the optimal transport plan between the full noise and data distributions, because noise and data are paired independently during training. Still, the marginal trajectories inherit much of the straightness of the conditional ones, which is why CondOT paths need relatively few simulation steps.
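The scaling behavior of flow matching. How does sample quality scale with model size and training? The cleanest controlled study is the DiT family on class-conditional ImageNet 256×256 (Peebles & Xie, 2023). DiT itself was trained with a standard diffusion (noise-prediction) objective. SiT (Ma et al., 2024) trains the same backbones with flow-matching-style interpolant objectives and reports the same trend, with lower FID at every model size.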
| Model | Parameters | Training setup | FID-50K, ImageNet 256×256 (lower is better) |
|---|---|---|---|
| DiT-S/2 | 33M | 400K steps, no guidance | 68.4 |
| DiT-B/2 | 130M | 400K steps, no guidance | 43.5 |
| DiT-L/2 | 458M | 400K steps, no guidance | 23.3 |
| DiT-XL/2 | 675M | 400K steps, no guidance | 19.5 |
| DiT-XL/2 | 675M | 7M steps, classifier-free guidance (scale 1.5) | 2.27 |
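The trend is clean: at a fixed number of training steps, FID improves steadily with model size, and longer training plus classifier-free guidance improves it much further. Esser et al. (2024) report a similar pattern for text-to-image rectified flow transformers. Their validation loss follows predictable scaling trends in model size and training steps, and lower validation loss tracks better image quality. This is analogous to the scaling laws observed in large language models. It suggests that scaling flow matching models further will continue to improve quality, with no clear sign of saturation so far.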
DiT (Diffusion Transformer) architecture. The dominant neural network architecture for flow matching is the DiT, which replaces the traditional U-Net with a Vision Transformer. Key design choices:
1. Patchification: The input latent (e.g., 64×64×4) is split into 2×2 patches, producing a sequence of N = (64/2)² = 1024 tokens. Each patch (2×2×4 = 16 values) is linearly projected to a D-dimensional token.
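2. Time conditioning: The time t is embedded with a sinusoidal encoding followed by a small MLP. The resulting embedding regresses the scale and shift parameters of the layer norms in each block (adaptive layer norm, adaLN). DiT's adaLN-Zero variant also regresses a per-block gate that is initialized to zero, so each block starts out as the identity.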
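3. Class or text conditioning: The original DiT is class-conditional; it adds a class embedding to the time embedding that drives adaLN. Text-to-image variants inject prompt embeddings from CLIP/T5 encoders in one of two ways. Some use cross-attention in each transformer block. Others, like SD3's MMDiT, concatenate text and image tokens and apply joint attention.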
4. Output: The transformer outputs velocity predictions for each patch, which are unpatchified to the original spatial resolution.
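The patchification arithmetic and the sinusoidal time encoding described in the list above can be sketched in a few lines of PyTorch. This is an illustrative sketch, not DiT's actual code. The function names are hypothetical. The frequency base of 10000 follows the usual Transformer-style encoding. Real DiT code additionally applies a learned linear projection to each flattened patch and an MLP to the time encoding.

```python
import math
import torch

def patchify(latent, p=2):
    """[B, C, H, W] -> [B, (H/p)*(W/p), p*p*C]: a sequence of flattened p x p patches."""
    B, C, H, W = latent.shape
    x = latent.reshape(B, C, H // p, p, W // p, p)
    x = x.permute(0, 2, 4, 3, 5, 1)                        # [B, H/p, W/p, p, p, C]
    return x.reshape(B, (H // p) * (W // p), p * p * C)

def unpatchify(tokens, C, H, W, p=2):
    """Inverse of patchify: [B, N, p*p*C] -> [B, C, H, W]."""
    B = tokens.shape[0]
    x = tokens.reshape(B, H // p, W // p, p, p, C)
    x = x.permute(0, 5, 1, 3, 2, 4)                        # [B, C, H/p, p, W/p, p]
    return x.reshape(B, C, H, W)

def timestep_embedding(t, dim, max_period=10_000):
    """Sinusoidal encoding of t: [B] -> [B, dim] (dim assumed even)."""
    half = dim // 2
    freqs = torch.exp(-math.log(max_period) * torch.arange(half, dtype=torch.float32) / half)
    args = t[:, None].float() * freqs[None, :]
    return torch.cat([torch.cos(args), torch.sin(args)], dim=-1)

latent = torch.randn(2, 4, 64, 64)
tokens = patchify(latent)                                  # [2, 1024, 16]
t_emb = timestep_embedding(torch.rand(2), dim=256)         # [2, 256]
```

Many flow matching implementations scale $t \in [0, 1]$ up (for example, by 1000) before encoding it, so that the sinusoids vary more across the interval. That scaling is omitted here.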
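```python
import torch
import torch.nn as nn

# Simplified DiT block (real DiT blocks modulate each sublayer separately
# and use adaLN-Zero gating)
class DiTBlock(nn.Module):
    def __init__(self, d, n_heads):
        super().__init__()
        self.norm = nn.LayerNorm(d, elementwise_affine=False)
        self.adaLN = nn.Sequential(nn.SiLU(), nn.Linear(d, 2 * d))  # t_emb -> (scale, shift)
        self.self_attn = nn.MultiheadAttention(d, n_heads, batch_first=True)
        self.cross_attn = nn.MultiheadAttention(d, n_heads, batch_first=True)
        self.ffn = nn.Sequential(nn.Linear(d, 4 * d), nn.GELU(), nn.Linear(4 * d, d))

    def forward(self, x, t_emb, text_emb):
        # x: [B, N, D] (sequence of patch tokens)
        # t_emb: [B, D] (time embedding)
        # text_emb: [B, L, D] (text token embeddings)
        # AdaLN: modulate layer norm with time; unsqueeze(1) broadcasts over the N tokens
        scale, shift = self.adaLN(t_emb).unsqueeze(1).chunk(2, dim=-1)
        h = self.norm(x) * (1 + scale) + shift
        # Self-attention + cross-attention with text + MLP, each with a residual connection
        x = x + self.self_attn(h, h, h, need_weights=False)[0]
        x = x + self.cross_attn(x, text_emb, text_emb, need_weights=False)[0]
        x = x + self.ffn(x)
        return x
```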
In practice, DiTs have scaled more predictably than U-Nets. Peebles & Xie (2023) found that FID improves consistently as network compute (Gflops) increases. Transformers also benefit from mature hardware optimizations (FlashAttention, etc.). Stable Diffusion 3 and FLUX both use DiT-style architectures.
Multi-resolution training. Modern models are trained at multiple resolutions to improve efficiency and quality. The approach: start training at low resolution (e.g., 256×256), then fine-tune at higher resolution (512×512, 1024×1024). This works because the vector field learned at low resolution transfers well to higher resolution — the overall structure of the flow is similar, and only fine details change.
```python
# Multi-resolution training schedule (simplified)
# load_batch, cfm_loss, model and optimizer are assumed to be defined elsewhere
# Phase 1: Low resolution (fast, learns global structure)
for step in range(200_000):
    z = load_batch(resolution=256)     # latent: [B, 4, 32, 32]
    loss = cfm_loss(model, z)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

# Phase 2: High resolution (slower, learns fine details)
for step in range(100_000):
    z = load_batch(resolution=1024)    # latent: [B, 4, 128, 128]
    loss = cfm_loss(model, z)          # same loss function!
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
```
Data augmentation. During training, many models apply random augmentations (crop, flip, color jitter) to the data. These augmentations increase the effective dataset size and improve generalization. However, care must be taken: augmentations that change the semantics should not be used. For example, horizontally flipping an image that contains text produces mirror-image writing. Common augmentations for image models are random horizontal flip, random crop, and multi-aspect-ratio training.
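A minimal sketch of a per-sample random horizontal flip for a batch of images follows. It takes a hypothetical `allow_flip` mask so that samples whose meaning depends on orientation (for example, images containing text) are never flipped. How that mask is produced (from captions, OCR, or metadata) is an assumption left open here.

```python
import torch

def random_hflip(images, allow_flip, p=0.5, generator=None):
    """Flip each image in [B, C, H, W] left-right with probability p, unless disallowed.

    allow_flip: [B] bool mask (hypothetical; e.g. False for images that contain text).
    """
    coin = torch.rand(images.shape[0], generator=generator) < p
    do_flip = (coin & allow_flip)[:, None, None, None]     # broadcast over C, H, W
    return torch.where(do_flip, images.flip(-1), images)   # flip(-1) mirrors the width axis

images = torch.arange(2 * 1 * 2 * 3, dtype=torch.float32).reshape(2, 1, 2, 3)
out = random_hflip(images, allow_flip=torch.tensor([True, False]), p=1.0)
```

With `p=1.0`, every allowed image is flipped, which makes the effect of the mask easy to inspect.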
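The flow matching loss landscape. GAN training is a min-max game between two networks and is notoriously prone to oscillation, non-convergence, and mode collapse. Flow matching instead minimizes a single regression objective. The squared error is convex in the network's output (though not in its parameters), there is no adversary, gradients are well-behaved, and training rarely diverges. There are two practical caveats. First, the per-batch loss is noisy, because it depends on the randomly sampled t and ε. Second, it does not go to zero, because the conditional target varies among the data points that could have produced the same x. Expect the loss curve to plateau at a positive value rather than fall steadily toward zero. This stability is one of the main practical advantages of flow matching over GANs.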
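Because flow matching regresses onto a noisy target, the CFM loss has a positive floor. Its minimum is the average conditional variance of the target $z - \varepsilon$ given $(x, t)$, which is positive whenever several data points could have produced the same $x$. For one-dimensional Gaussian data $z \sim \mathcal{N}(m, s^2)$, everything is jointly Gaussian, so this floor has a closed form. After simplification, the residual variance is $s^2 / \big(t^2 s^2 + (1-t)^2\big)$. The sketch evaluates the unsimplified expression numerically. The Gaussian toy data and $t \sim \mathrm{Unif}[0, 1]$ are assumptions of the example.

```python
import math
import torch

def cfm_loss_floor(s, n_t=10_000):
    """Lowest achievable CondOT CFM loss for 1D data z ~ N(m, s^2) (m drops out).

    For fixed t, x = t*z + (1 - t)*eps and the target z - eps are jointly Gaussian,
    so the best predictor of the target from x is linear and leaves residual
    variance Var(target) - Cov(target, x)^2 / Var(x). Averaging that over
    t ~ Unif[0, 1] gives a floor that no model, however large, can go below.
    """
    t = (torch.arange(n_t, dtype=torch.float64) + 0.5) / n_t   # midpoint grid on [0, 1]
    var_target = s**2 + 1                                      # Var(z - eps)
    cov = t * s**2 - (1 - t)                                   # Cov(z - eps, x)
    var_x = t**2 * s**2 + (1 - t) ** 2                         # Var(x)
    return (var_target - cov**2 / var_x).mean().item()

# Point-mass data (s = 0): x determines eps, so the floor is 0.
# Unit-variance data (s = 1): the floor is pi/2, far above zero.
print(cfm_loss_floor(0.0), cfm_loss_floor(1.0), math.pi / 2)
```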
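Distributed training. Large flow matching models are trained on clusters of hundreds to thousands of GPUs using data parallelism. Each GPU processes a different shard of the mini-batch, and gradients are averaged across GPUs before every update. Because training is simulation-free, each example costs a single forward and backward pass with no ODE solve. The examples in a mini-batch are also independent, so the work splits cleanly across devices, and the only required communication is the gradient averaging. Successive optimizer steps still depend on each other, as in any gradient-based training. Autoregressive language models also process all token positions in parallel during training, thanks to teacher forcing; their sequential bottleneck is at generation time.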
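Averaging per-GPU gradients reproduces full-batch training, and this can be checked on a single machine. The sketch simulates 4 workers by splitting one batch into equal shards and computing each shard's flow matching gradient separately. It then averages those gradients, which is the job an all-reduce does in real data-parallel training (for example, in PyTorch's `DistributedDataParallel`), and compares the result with the full-batch gradient. The tiny model and the shard count are arbitrary. Equal shard sizes are assumed, since that is what makes the plain average exact.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Sequential(nn.Linear(3, 32), nn.SiLU(), nn.Linear(32, 2))

# One global batch of data, times and noise, sampled once so both paths see the same inputs
z = torch.randn(64, 2)
t = torch.rand(64, 1)
eps = torch.randn(64, 2)

def cfm_loss(z, t, eps):
    x = t * z + (1 - t) * eps                               # CondOT path
    return (model(torch.cat([x, t], dim=1)) - (z - eps)).pow(2).mean()

def grads(loss):
    return torch.autograd.grad(loss, list(model.parameters()))

full = grads(cfm_loss(z, t, eps))                           # single-device full-batch gradient

n_workers = 4                                               # simulated GPUs, equal shard sizes
shard_grads = [grads(cfm_loss(zs, ts, es))
               for zs, ts, es in zip(z.chunk(n_workers), t.chunk(n_workers), eps.chunk(n_workers))]
averaged = [torch.stack(g).mean(dim=0) for g in zip(*shard_grads)]   # what all-reduce computes
```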
Open research questions as of 2026:
1. Can we achieve one-step generation at the quality of 50-step models? Consistency distillation is getting close.
2. How to best handle conditional generation for complex, multi-modal conditions (text + image + layout)?
3. Can flow matching scale to extremely long videos (minutes, not seconds)?
4. What are the fundamental limits of flow matching? Is there a quality ceiling, or will it keep improving with scale?
5. How to best combine flow matching with discrete diffusion for joint text-image models?
6. Can we develop better noise schedules that are provably optimal for a given dataset?
7. How to handle variable-length data (different image resolutions, different video lengths) efficiently?
Recommended reading. The original flow matching papers are accessible and well-written. In order of importance:
| Paper | Authors | Key contribution |
|---|---|---|
| Flow Matching for Generative Modeling | Lipman et al., 2022 | The original flow matching paper |
| Flow Straight and Fast: Learning to Generate and Transfer Data with Rectified Flow | Liu et al., 2022 | Straightening trajectories by retraining |
| Building Normalizing Flows with Stochastic Interpolants | Albergo & Vanden-Eijnden, 2023 | Generalized framework (stochastic interpolants) |
| Scalable Diffusion Models with Transformers (DiT) | Peebles & Xie, 2023 | Transformer architecture for diffusion |
| Scaling Rectified Flow Transformers for High-Resolution Image Synthesis | Esser et al., 2024 | SD3 architecture and training details |
Implementing flow matching from scratch. The complete implementation of a flow matching model on a 2D toy dataset requires only about 50 lines of PyTorch code (as shown in Chapter 8). For images, the main additions are: (1) a DiT or U-Net architecture instead of an MLP, (2) a VAE for latent encoding, and (3) text conditioning. The core training loop — sample z, sample t, sample ε, compute loss, update weights — remains identical regardless of the data type or model size.
Key equations to remember from this chapter:
These seven equations are the complete mathematical foundation of flow matching. Everything else — noise schedules, architectures, guidance, distillation — is built on top of this foundation. If you understand why each equation has the form it does, you understand flow matching.
The flow matching research ecosystem. Flow matching has spawned a vibrant research community. Key research directions include:
1. Optimal transport connections: Making the link between CondOT paths and Wasserstein distances rigorous and exploiting it for better schedules.
2. Faster samplers: Reducing the number of function evaluations (NFE) needed from 50 to 1-4 via distillation, consistency models, and analytic solvers.
3. Better architectures: Moving from U-Nets to DiTs to MMDiTs (SD3's multi-modal DiT) to even more efficient designs.
4. Controllable generation: Going beyond text conditioning to spatial control (ControlNet), style transfer, and compositional generation.
5. Scientific applications: Protein design, drug discovery, weather prediction, materials science — anywhere that requires sampling from complex distributions over structured data.
6. Video and 3D: Scaling flow matching to higher-dimensional data types with temporal and spatial consistency constraints.
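Consistency models appear in the list above as one route to faster samplers. The sketch below shows one consistency-distillation loss term, written in this chapter's time convention (t = 0 noise, t = 1 data). All of the following are illustrative assumptions:

- a pre-trained velocity field `teacher(x, t)`, here replaced by a toy stand-in;
- a student parameterized as $f(x, t) = x + (1 - t)\,F(x, t)$, so that the boundary condition $f(x, 1) = x$ holds exactly;
- a frozen target copy `f_ema`, whose exponential-moving-average updates are not shown.

Song et al. (2023) use the opposite time direction and different scalings, so treat this as a sketch of the idea rather than their recipe.

```python
import copy
import torch
import torch.nn as nn

class ConsistencyModel(nn.Module):
    """f(x, t) = x + (1 - t) * F(x, t), which enforces f(x, 1) = x exactly."""
    def __init__(self, dim, hidden=64):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(dim + 1, hidden), nn.SiLU(), nn.Linear(hidden, dim))

    def forward(self, x, t):
        return x + (1 - t) * self.net(torch.cat([x, t], dim=1))

def consistency_distillation_loss(f, f_ema, teacher, z, dt=0.02):
    """One consistency-distillation loss term (t = 0 noise, t = 1 data).

    z: [B, d] data batch; teacher(x, t) is a pre-trained velocity field.
    """
    t = torch.rand(z.shape[0], 1) * (1 - dt)           # keep t + dt <= 1
    eps = torch.randn_like(z)
    x_t = t * z + (1 - t) * eps                         # a point on the CondOT path
    with torch.no_grad():
        x_next = x_t + dt * teacher(x_t, t)             # one Euler step along the teacher ODE
        target = f_ema(x_next, t + dt)                  # target network's endpoint estimate
    # Self-consistency: both points on the trajectory should map to the same endpoint
    return (f(x_t, t) - target).pow(2).mean()

torch.manual_seed(0)
teacher = lambda x, t: -x                              # toy stand-in for a trained velocity field
f = ConsistencyModel(dim=2)
f_ema = copy.deepcopy(f).requires_grad_(False)         # target copy (EMA updates not shown)
loss = consistency_distillation_loss(f, f_ema, teacher, z=torch.randn(128, 2))
loss.backward()

# One-step generation: map noise at t = 0 straight to a sample at t = 1
with torch.no_grad():
    samples = f(torch.randn(16, 2), torch.zeros(16, 1))
```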
The field moves fast. By the time you read this, some of the "open questions" listed above may have been answered. But the mathematical foundations (probability paths, conditional vector fields, the marginalization trick, simulation-free training) are far more durable. They are likely to remain the bedrock of generative modeling regardless of which specific architectures or applications dominate in the future.
```python
# Rectified flow: straightening trajectories
# train_flow_matching, generate, train_flow_matching_paired, dataset and d
# are placeholders for the reader's own training, sampling, and data code.
import torch

# Step 1: Train initial model (standard flow matching)
model_v1 = train_flow_matching(dataset)

# Step 2: Generate (noise, data) pairs by simulating the trained model's ODE
x0_samples = torch.randn(10000, d)             # noise
x1_samples = generate(model_v1, x0_samples)     # generated data

# Step 3: Retrain on the paired data (noise, generated_data).
# The trajectories become straighter because each noise sample is matched
# with the data point its ODE trajectory actually reached, instead of
# independently sampled (z, eps) pairs
model_v2 = train_flow_matching_paired(x0_samples, x1_samples)
# model_v2 needs fewer steps for the same quality
```
Number of function evaluations (NFE). Each Euler step requires one neural network forward pass. Consider 20 steps on a DiT-XL/2 model (675M parameters, generating 512×512 images). Suppose each forward pass takes about 40 ms on an A100 GPU; the exact figure depends on batch size, precision, and implementation. The total generation time is then 20 × 40 ms = 0.8 seconds per image. Heun's method evaluates the network twice per step, so 10 Heun steps also cost 20 NFE, often with better quality than 20 Euler steps.
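The NFE accounting is easiest to see in code. Below is a minimal sketch of Euler and Heun (explicit trapezoidal) samplers for $dX/dt = u(X, t)$ from $t = 0$ to $t = 1$, with the velocity function wrapped in a counter. Heun's method evaluates the field twice per step: once at the start of the step and once at the Euler-predicted end point. That is where "10 Heun steps = 20 NFE" comes from. Applying the correction on every step, including the last, is a choice of this sketch; some samplers skip it on the final step. The field $dX/dt = X$ is a stand-in with the known solution $X_1 = e\,X_0$, not a trained model.

```python
import math

def euler_sample(u, x, n_steps):
    h = 1.0 / n_steps
    for i in range(n_steps):
        x = x + h * u(x, i * h)           # 1 NFE per step
    return x

def heun_sample(u, x, n_steps):
    h = 1.0 / n_steps
    for i in range(n_steps):
        t = i * h
        d1 = u(x, t)                      # slope at the start of the step (1 NFE)
        x_pred = x + h * d1               # Euler predictor
        d2 = u(x_pred, t + h)             # slope at the predicted end point (1 NFE)
        x = x + h * 0.5 * (d1 + d2)       # average the two slopes
    return x

class CountingField:
    """Wraps a velocity field and counts function evaluations (NFE)."""
    def __init__(self, fn):
        self.fn, self.nfe = fn, 0

    def __call__(self, x, t):
        self.nfe += 1
        return self.fn(x, t)

# Toy field with known solution: dx/dt = x, so x(1) = e * x(0)
u_euler, u_heun = CountingField(lambda x, t: x), CountingField(lambda x, t: x)
x_euler = euler_sample(u_euler, 1.0, n_steps=20)
x_heun = heun_sample(u_heun, 1.0, n_steps=10)
print(f"Euler, {u_euler.nfe} NFE: error {abs(x_euler - math.e):.4f}")
print(f"Heun,  {u_heun.nfe} NFE: error {abs(x_heun - math.e):.4f}")
```

At equal cost (20 NFE each), the second-order Heun update lands much closer to the exact endpoint on this smooth toy problem.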
```python
# End-to-end flow matching: train and sample on Two Moons
import torch
import torch.nn as nn

# 1. Dataset: two interleaving half-circles plus small Gaussian noise
def sample_moons(n):
    t = torch.rand(n) * torch.pi
    upper = torch.stack([torch.cos(t), torch.sin(t)], 1)
    lower = torch.stack([1 - torch.cos(t), 1 - torch.sin(t) - 0.5], 1)
    # The condition needs shape [n, 1] to broadcast against the [n, 2] points
    x = torch.where(torch.rand(n, 1) < 0.5, upper, lower)
    return x + 0.05 * torch.randn_like(x)

# 2. Model: MLP that takes (x, t) and returns a 2D velocity
model = nn.Sequential(
    nn.Linear(3, 128), nn.SiLU(),
    nn.Linear(128, 128), nn.SiLU(),
    nn.Linear(128, 2))

# 3. Train (Algorithm 3): CondOT path x = t*z + (1-t)*eps, target velocity z - eps
opt = torch.optim.Adam(model.parameters(), 1e-3)
for step in range(5000):
    z = sample_moons(256)
    t = torch.rand(256, 1)
    eps = torch.randn_like(z)
    x = t * z + (1 - t) * eps
    loss = (model(torch.cat([x, t], 1)) - (z - eps)).pow(2).mean()
    opt.zero_grad(); loss.backward(); opt.step()

# 4. Sample (Algorithm 1): 50 Euler steps from t = 0 to t = 1
x = torch.randn(500, 2)
with torch.no_grad():
    for i in range(50):
        t = torch.full((500, 1), i / 50)
        x = x + (1 / 50) * model(torch.cat([x, t], 1))
# x now contains 500 approximate samples from Two Moons
```
Flow matching underlies many of today's leading image and video generators. Let's summarize the complete picture.
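| Component | Formula | Status |
|---|---|---|
| Conditional path | $p_t(\cdot \mid z) = \mathcal{N}(\alpha_t z, \beta_t^2 I)$ | Chosen by us |
| Marginal path | $p_t(x) = \int p_t(x \mid z)\, p_{\text{data}}(z)\, dz$ | Exists but intractable |
| Conditional vector field | $u_t^{\text{target}}(x \mid z) = \dot\alpha_t z + \dot\beta_t \varepsilon$, where $x = \alpha_t z + \beta_t \varepsilon$ | Known analytically |
| Marginal vector field | $u_t^{\text{target}}(x) = \int u_t^{\text{target}}(x \mid z)\, \frac{p_t(x \mid z)\, p_{\text{data}}(z)}{p_t(x)}\, dz$ | Exists but intractable |
| CFM loss | $\mathcal{L}_{\text{CFM}}(\theta) = \mathbb{E}_{t, z, x}\big[\lVert u_t^\theta(x) - u_t^{\text{target}}(x \mid z)\rVert^2\big]$ | Tractable; differs from $\mathcal{L}_{\text{FM}}$ only by a constant, so it has the same gradients |
| Sampling | Euler ODE simulation | $n$ neural network evaluations for $n$ steps |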
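The table writes the conditional vector field in terms of the noise $\varepsilon$. Eliminating $\varepsilon$ with the path equation gives an equivalent form in terms of $x$. That form divides by $\beta_t$, which shows why some parameterizations blow up at the endpoints (for CondOT, as $t \to 1$):

$$
\begin{aligned}
x = \alpha_t z + \beta_t \varepsilon
\;&\Longrightarrow\; \varepsilon = \frac{x - \alpha_t z}{\beta_t}, \\
u_t^{\text{target}}(x \mid z) = \dot\alpha_t z + \dot\beta_t \varepsilon
&= \Big(\dot\alpha_t - \frac{\dot\beta_t}{\beta_t}\,\alpha_t\Big) z + \frac{\dot\beta_t}{\beta_t}\, x .
\end{aligned}
$$

For the CondOT path, $\alpha_t = t$ and $\beta_t = 1 - t$, so $\dot\alpha_t = 1$, $\dot\beta_t = -1$, and

$$
u_t^{\text{target}}(x \mid z) = \Big(1 + \frac{t}{1-t}\Big) z - \frac{x}{1-t} = \frac{z - x}{1 - t} = z - \varepsilon .
$$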
Common pitfalls and practical tips:
| Pitfall | Solution |
|---|---|
| $t = 0$ or $t = 1$ causes division by zero | Clamp $t$ to $[\delta, 1-\delta]$ with a small $\delta$, e.g. $10^{-5}$ |
| Loss spikes for extreme t values | Use logit-normal time sampling instead of Unif[0,1] |
| Generated samples are blurry | Increase Euler steps, or use Heun's method |
| Model memorizes training data | Dataset too small; increase diversity |
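| Sampling trajectories far from straight | Verify the CondOT path implementation (alpha=t, beta=1-t). Some curvature remains even when it is correct, since only the conditional paths are straight; rectified flow can straighten trajectories further |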
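Logit-normal time sampling can be sketched as follows: draw a Gaussian and squash it through a sigmoid, so that $t$ concentrates around intermediate values instead of being uniform on $[0, 1]$. The location 0 and scale 1 below are an assumed default; this is one of the settings studied by Esser et al. (2024). The clamp keeps $t$ a small distance ($10^{-5}$) away from 0 and 1, in line with the first row of the table.

```python
import torch

def sample_t_logit_normal(batch_size, loc=0.0, scale=1.0, clamp=1e-5, generator=None):
    """t = sigmoid(n) with n ~ N(loc, scale^2); the density peaks at intermediate t."""
    n = loc + scale * torch.randn(batch_size, 1, generator=generator)
    return torch.sigmoid(n).clamp(clamp, 1 - clamp)

g = torch.Generator().manual_seed(0)
t_ln = sample_t_logit_normal(100_000, generator=g)
t_uni = torch.rand(100_000, 1, generator=g)
# Fraction of samples in the middle half of [0, 1]
frac_mid_ln = ((t_ln > 0.25) & (t_ln < 0.75)).float().mean().item()
frac_mid_uni = ((t_uni > 0.25) & (t_uni < 0.75)).float().mean().item()
```

Compared with uniform sampling, more training signal goes to intermediate times and less to the extremes.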
Models trained with flow matching, or with a closely related diffusion objective where noted, include:
| Model | Domain | Organization |
|---|---|---|
| Stable Diffusion 3 | Text-to-image | Stability AI |
| FLUX | Text-to-image | Black Forest Labs |
| Movie Gen Video | Text-to-video | Meta |
| AlphaFold 3 (diffusion module, not flow matching) | Biomolecular structure | Google DeepMind and Isomorphic Labs |