Last updated: June 11, 2025
This series of tutorials is based on the Flow Matching Guide and Code
- arXiv: 2412.06264
- Thank you, META
Flow Matching Problem
Instead of learning the likelihood of the target as in (normalizing) flow models, we directly learn the ground-truth (GT) velocity field $u_t$; that is, we find $u_t^\theta$ generating $p_t$, with $p_0 = p$ and $p_1 = q$. In the Flow Matching loss, we directly minimize the discrepancy between the learned velocity field and the GT velocity field:

$$\mathcal{L}(\theta) = \mathbb{E}_{X_t \sim p_t}\, D\big(u_t(X_t), u_t^\theta(X_t)\big)$$

where $D$ is a dissimilarity measure between vectors, such as the squared $\ell_2$-norm $D(u, v) = \|u - v\|^2$.
Data Dependencies
We know that Flow Matching actually learns a mapping from a source distribution to a target distribution. The source and the target may be independent, or they may originate from a joint distribution

$$(X_0, X_1) \sim \pi_{0,1}(X_0, X_1)$$

known as a coupling. If they are independent, we write $\pi_{0,1}(X_0, X_1) = p(X_0)\, q(X_1)$. For example, in diffusion models we generate images from standard Gaussian noise, so source and target are independent. But if we want to generate high-resolution images from their low-resolution counterparts (not via conditioning), then the two are coupled.
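To make the two cases concrete, here is a minimal NumPy sketch (my own, not from the guide's codebase) of sampling from an independent coupling versus a dependent coupling built from paired data; the arrays `targets` and `pairs` are toy placeholders.

```python
import numpy as np

rng = np.random.default_rng(0)

def sample_independent_coupling(targets, batch_size):
    """Independent coupling pi_{0,1}(x0, x1) = p(x0) q(x1):
    the source X0 is standard Gaussian noise, drawn independently of X1."""
    x1 = targets[rng.integers(len(targets), size=batch_size)]
    x0 = rng.standard_normal(x1.shape)
    return x0, x1

def sample_dependent_coupling(pairs, batch_size):
    """Dependent coupling: (X0, X1) are drawn jointly, e.g. a low-resolution
    image together with its high-resolution counterpart."""
    idx = rng.integers(len(pairs), size=batch_size)
    return pairs[idx, 0], pairs[idx, 1]

# toy placeholder data in R^2: 1000 target samples, and 1000 (x0, x1) pairs
targets = rng.normal(loc=3.0, scale=0.5, size=(1000, 2))
pairs = np.stack([targets + rng.normal(scale=0.2, size=targets.shape), targets], axis=1)

x0_ind, x1_ind = sample_independent_coupling(targets, batch_size=8)
x0_dep, x1_dep = sample_dependent_coupling(pairs, batch_size=8)
```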
Designing Probability Path and the GT Velocity Field
In the FM problem we want to learn the GT velocity field, but we do not have access to a tractable one. Here we start by building a conditional probability path, to show that the problem simplifies through a conditioning strategy.
Consider a probability path $p_{t|1}$ conditioned on a single target sample $X_1 = x_1$; the marginal probability path is then

$$p_t(x) = \int p_{t|1}(x \mid x_1)\, q(x_1)\, \mathrm{d}x_1 \tag{1}$$

Remember that we must satisfy $p_0 = p$ and $p_1 = q$, which means we start at the source distribution and end at the target distribution. These boundary conditions can be enforced by requiring the conditional probability paths to satisfy

$$p_{0|1}(x \mid x_1) = \pi_{0|1}(x \mid x_1), \qquad p_{1|1}(x \mid x_1) = \delta_{x_1}(x)$$

where $\pi_{0|1}(x_0 \mid x_1) = \pi_{0,1}(x_0, x_1) / q(x_1)$ and $\delta_{x_1}$ is the delta measure centered at $x_1$.
proof. From Equation (1), let t=0
$$p_0(x) = \int p_{0|1}(x \mid x_1)\, q(x_1)\, \mathrm{d}x_1 = \int \pi_{0|1}(x \mid x_1)\, q(x_1)\, \mathrm{d}x_1 = \pi_0(x)$$

which is just our source distribution; similarly, let $t = 1$:

$$p_1(x) = \int p_{1|1}(x \mid x_1)\, q(x_1)\, \mathrm{d}x_1 = \int \delta_{x_1}(x)\, q(x_1)\, \mathrm{d}x_1 = q(x)$$
Q.E.D.
If the source and the target are independent, then $p_{0|1}(x \mid x_1) = p(x)$. Since the delta measure $\delta_{x_1}(x)$ does not have a density, we should read the second boundary condition as

$$\int p_{t|1}(x \mid y)\, f(y)\, \mathrm{d}y \xrightarrow{t \to 1} f(x)$$

for continuous functions $f$. An example of a conditional path satisfying the boundary conditions is

$$\mathcal{N}\big(\cdot \mid t x_1, (1 - t)^2 I\big) \xrightarrow{t \to 1} \delta_{x_1}(\cdot)$$

As you can see, at $t = 0$ we have a standard Gaussian distribution and at $t = 1$ we obtain the target sample.
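Here is a small sketch of sampling from this conditional Gaussian path for a toy 2-D target sample, using the reparameterization $X_t = t x_1 + (1 - t) X_0$ with $X_0 \sim \mathcal{N}(0, I)$; the specific numbers are placeholders.

```python
import numpy as np

rng = np.random.default_rng(0)

def sample_conditional_path(x1, t, n_samples):
    """Sample X_t ~ N(t * x1, (1 - t)^2 I) via reparameterization:
    X_t = t * x1 + (1 - t) * X_0 with X_0 ~ N(0, I)."""
    x0 = rng.standard_normal((n_samples, x1.shape[-1]))
    return t * x1 + (1.0 - t) * x0

x1 = np.array([2.0, -1.0])          # a single target sample in R^2
for t in [0.0, 0.5, 0.99]:
    xt = sample_conditional_path(x1, t, n_samples=5000)
    # the mean approaches x1 and the spread shrinks like (1 - t) as t -> 1
    print(t, xt.mean(axis=0).round(2), xt.std(axis=0).round(2))
```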
We now derive the GT velocity field from the conditional probability path.
The Continuity Equation tells us that, if the conditional velocity field generates the conditional probability path, then

$$\frac{\mathrm{d}}{\mathrm{d}t}\, p_{t|1}(x \mid x_1) = -\nabla \cdot \big(p_{t|1}(x \mid x_1)\, u_t(x \mid x_1)\big)$$

Then, from Equation (1),

$$\frac{\mathrm{d}}{\mathrm{d}t}\, p_t(x) = \int q(x_1)\, \frac{\mathrm{d}}{\mathrm{d}t}\, p_{t|1}(x \mid x_1)\, \mathrm{d}x_1 = -\nabla \cdot \int q(x_1)\, p_{t|1}(x \mid x_1)\, u_t(x \mid x_1)\, \mathrm{d}x_1$$

Now, if the marginal velocity field generates the marginal probability path, the above equation must also be a Continuity Equation. Comparing it with the Continuity Equation, we can say

$$p_t(x)\, u_t(x) = \int q(x_1)\, p_{t|1}(x \mid x_1)\, u_t(x \mid x_1)\, \mathrm{d}x_1$$

Dividing both sides by $p_t(x)$ yields

$$u_t(x) = \int u_t(x \mid x_1)\, \frac{p_{t|1}(x \mid x_1)\, q(x_1)}{p_t(x)}\, \mathrm{d}x_1 \tag{2}$$
Recall the Bayes rule

$$p(x \mid y) = \frac{p(x, y)}{p(y)} = \frac{p(y \mid x)\, p(x)}{p(y)}$$

with which we can transform Equation (2) into

$$u_t(x) = \int u_t(x \mid x_1)\, p_{1|t}(x_1 \mid x)\, \mathrm{d}x_1 \tag{3}$$

So the GT velocity field is a weighted average of the conditional velocities $u_t(x \mid x_1)$, with the weights $p_{1|t}(x_1 \mid x)$ representing the posterior probability of the target sample $x_1$ given the current sample $x$. We can also write Equation (3) in expectation form:

$$u_t(x) = \mathbb{E}\big[u_t(X_t \mid X_1) \mid X_t = x\big]$$
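As a toy illustration of Equation (3), the sketch below (my own construction, not from the guide) uses a discrete 1-D target distribution together with the Gaussian conditional path from above, whose conditional velocity is $u_t(x \mid x_1) = (x_1 - x)/(1 - t)$ (this follows from Equation (8) with the linear flow derived later), and computes the marginal velocity as a posterior-weighted average of conditional velocities.

```python
import numpy as np

# discrete toy target q: three 1-D data points with equal probability
x1_points = np.array([-2.0, 0.5, 3.0])
q_weights = np.ones(3) / 3.0

def conditional_velocity(x, x1, t):
    # u_t(x | x1) = (x1 - x) / (1 - t) for the path N(t * x1, (1 - t)^2 I)
    return (x1 - x) / (1.0 - t)

def marginal_velocity(x, t):
    """Equation (3): u_t(x) = sum_i u_t(x | x1_i) * p_{1|t}(x1_i | x),
    with posterior weights p_{1|t}(x1 | x) proportional to p_{t|1}(x | x1) q(x1)."""
    # Gaussian likelihoods p_{t|1}(x | x1_i); the shared normalizer cancels in the posterior
    likelihood = np.exp(-0.5 * ((x - t * x1_points) / (1.0 - t)) ** 2)
    posterior = likelihood * q_weights
    posterior /= posterior.sum()
    return np.sum(conditional_velocity(x, x1_points, t) * posterior)

print(marginal_velocity(x=0.3, t=0.5))   # a weighted average of the three conditional velocities
```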
General Conditioning
So far we have conditioned on $X_1$, a sample from the target distribution; however, we can actually condition on an arbitrary $Z \in \mathbb{R}^m$ with PDF $p_Z$, and the marginal probability path becomes

$$p_t(x) = \int p_{t|Z}(x \mid z)\, p_Z(z)\, \mathrm{d}z$$

which is generated by the following marginal velocity field:

$$u_t(x) = \int u_t(x \mid z)\, p_{Z|t}(z \mid x)\, \mathrm{d}z = \mathbb{E}\big[u_t(X_t \mid Z) \mid X_t = x\big] \tag{4}$$
We have an important theorem called the marginalization trick, but before we introduce it, we need some assumptions.
Assumption.
- $p_{t|Z}(x \mid z)$ is $C^1([0,1) \times \mathbb{R}^d)$ and $u_t(x \mid z)$ is $C^1([0,1) \times \mathbb{R}^d, \mathbb{R}^d)$ as a function of $(t, x)$
- $p_Z$ has bounded support, i.e. $p_Z(z) = 0$ outside some bounded set in $\mathbb{R}^m$
- $p_t(x) > 0$ for all $x \in \mathbb{R}^d$ and $t \in [0,1)$
We have the following theorem.
Theorem. Under the above assumptions, if $u_t(x \mid z)$ is conditionally integrable and generates the conditional probability path $p_t(\cdot \mid z)$, then the marginal velocity field $u_t$ generates the marginal probability path $p_t$, for all $t \in [0,1)$,
where conditionally integrable means

$$\int_0^1 \int \int \|u_t(x \mid z)\|\, p_{t|Z}(x \mid z)\, p_Z(z)\, \mathrm{d}z\, \mathrm{d}x\, \mathrm{d}t < \infty$$
proof. To prove that the marginal velocity field generates the marginal probability path, we can utilize the Continuity Equation
$$
\begin{aligned}
\frac{\mathrm{d}}{\mathrm{d}t}\, p_t(x)
&= \frac{\mathrm{d}}{\mathrm{d}t} \int p_{t|Z}(x \mid z)\, p_Z(z)\, \mathrm{d}z
= \int \frac{\mathrm{d}}{\mathrm{d}t}\, p_{t|Z}(x \mid z)\, p_Z(z)\, \mathrm{d}z \\
&= \int -\nabla \cdot \big(u_t(x \mid z)\, p_{t|Z}(x \mid z)\big)\, p_Z(z)\, \mathrm{d}z
= -\nabla \cdot \int u_t(x \mid z)\, p_{t|Z}(x \mid z)\, p_Z(z)\, \mathrm{d}z \\
&= -\nabla \cdot \int u_t(x \mid z)\, \frac{p_{t|Z}(x \mid z)\, p_Z(z)}{p_t(x)}\, p_t(x)\, \mathrm{d}z
= -\nabla \cdot \int u_t(x \mid z)\, p_{Z|t}(z \mid x)\, p_t(x)\, \mathrm{d}z \\
&= -\nabla \cdot \big(u_t(x)\, p_t(x)\big)
\end{aligned}
$$

We can further show that $u_t$ is integrable:

$$
\begin{aligned}
\int_0^1 \int \|u_t(x)\|\, p_t(x)\, \mathrm{d}x\, \mathrm{d}t
&= \int_0^1 \int \left\| \int u_t(x \mid z)\, p_{Z|t}(z \mid x)\, \mathrm{d}z \right\|\, p_t(x)\, \mathrm{d}x\, \mathrm{d}t \\
&\le \int_0^1 \int \int \|u_t(x \mid z)\|\, p_{Z|t}(z \mid x)\, p_t(x)\, \mathrm{d}z\, \mathrm{d}x\, \mathrm{d}t \\
&= \int_0^1 \int \int \|u_t(x \mid z)\|\, \frac{p_{t|Z}(x \mid z)\, p_Z(z)}{p_t(x)}\, p_t(x)\, \mathrm{d}z\, \mathrm{d}x\, \mathrm{d}t \\
&= \int_0^1 \int \int \|u_t(x \mid z)\|\, p_{t|Z}(x \mid z)\, p_Z(z)\, \mathrm{d}z\, \mathrm{d}x\, \mathrm{d}t < \infty
\end{aligned}
$$
Q.E.D.
Note that the argument may look circular: we originally obtained $u_t$ from the Continuity Equation, and here we prove that it satisfies the Continuity Equation. In fact, the marginal velocity field $u_t$ is simply defined by the construction in Equation (3), and the theorem shows that this construction does satisfy the Continuity Equation.
So far we have designed the probability path as a conditional probability path and the velocity field as a conditional velocity field, and we have established the relation between the conditional path/field and the marginal path/field.
However, we still do not have a tractable GT velocity field to learn, so let us continue by taking a closer look at the loss function.
Flow Matching Loss
We want a tractable loss function, but the GT velocity field $u_t$ from Equation (3) is still intractable because we have to marginalize over the entire training set. We now introduce a family of loss functions known as Bregman divergences, and we will show that using them provides unbiased gradients for $u_t^\theta(x)$ to learn $u_t(x)$ by regressing onto the conditional velocity field $u_t(x \mid z)$; hence we do not need the marginal field $u_t(x)$ anymore.
The Bregman divergence induced by a strictly convex, differentiable function $\Phi$ is defined as

$$D(u, v) := \Phi(u) - \big[\Phi(v) + \langle u - v, \nabla \Phi(v) \rangle\big]$$

The squared Euclidean distance $D(u, v) = \|u - v\|^2$ is in fact a Bregman divergence: let $\Phi(u) = \|u\|^2$, then

$$
\begin{aligned}
D(u, v) &= \|u\|^2 - \big[\|v\|^2 + \langle u - v, 2v \rangle\big] \\
&= \|u\|^2 - \|v\|^2 - 2\langle u, v \rangle + 2\|v\|^2 \\
&= u^\top u - 2 u^\top v + v^\top v \\
&= (u - v)^\top (u - v) = \|u - v\|^2
\end{aligned}
$$
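A quick numerical sanity check of this identity (a minimal NumPy sketch):

```python
import numpy as np

def bregman(u, v, phi, grad_phi):
    """D(u, v) = Phi(u) - [Phi(v) + <u - v, grad Phi(v)>]."""
    return phi(u) - (phi(v) + np.dot(u - v, grad_phi(v)))

phi = lambda u: np.dot(u, u)   # Phi(u) = ||u||^2
grad_phi = lambda v: 2.0 * v   # grad Phi(v) = 2 v

rng = np.random.default_rng(0)
u, v = rng.standard_normal(5), rng.standard_normal(5)

# the Bregman divergence of Phi(u) = ||u||^2 is the squared Euclidean distance
assert np.allclose(bregman(u, v, phi, grad_phi), np.sum((u - v) ** 2))
```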
A key property of Bregman divergences is that their gradient with respect to the second argument is invariant under affine combinations of the first argument:

$$\nabla_v D(a u_1 + b u_2, v) = a\, \nabla_v D(u_1, v) + b\, \nabla_v D(u_2, v), \qquad a + b = 1$$

This property allows us to swap expectations and gradients as follows:

$$\nabla_v D\big(\mathbb{E}[Y], v\big) = \mathbb{E}\big[\nabla_v D(Y, v)\big] \tag{5}$$

similar to the affine invariance above, where the linear operation can be pulled out of the gradient.
We now restate our two objectives. The first is to learn the marginal velocity field:

$$\mathcal{L}_{\mathrm{FM}}(\theta) = \mathbb{E}_{t, X_t \sim p_t}\, D\big(u_t(X_t), u_t^\theta(X_t)\big)$$

The second is to learn the conditional velocity field:

$$\mathcal{L}_{\mathrm{CFM}}(\theta) = \mathbb{E}_{t, Z, X_t \sim p_{t|Z}(\cdot \mid Z)}\, D\big(u_t(X_t \mid Z), u_t^\theta(X_t)\big)$$
We will show the following theorem
Theorem
$$\nabla_\theta \mathcal{L}_{\mathrm{FM}}(\theta) = \nabla_\theta \mathcal{L}_{\mathrm{CFM}}(\theta)$$

So the minimizer of the conditional Flow Matching loss is the marginal velocity field $u_t$, because the two losses have the same gradients.
proof.
$$
\begin{aligned}
\nabla_\theta \mathcal{L}_{\mathrm{FM}}(\theta)
&= \nabla_\theta\, \mathbb{E}_{t, X_t \sim p_t}\, D\big(u_t(X_t), u_t^\theta(X_t)\big) \\
&= \mathbb{E}_{t, X_t \sim p_t}\, \nabla_\theta D\big(u_t(X_t), u_t^\theta(X_t)\big) \\
&= \mathbb{E}_{t, X_t \sim p_t}\, \nabla_v D\big(u_t(X_t), u_t^\theta(X_t)\big)\, \nabla_\theta u_t^\theta(X_t) \\
&\overset{(4)}{=} \mathbb{E}_{t, X_t \sim p_t}\, \nabla_v D\Big(\mathbb{E}_{Z \sim p_{Z|t}(\cdot \mid X_t)}\big[u_t(X_t \mid Z)\big],\, u_t^\theta(X_t)\Big)\, \nabla_\theta u_t^\theta(X_t) \\
&\overset{(5)}{=} \mathbb{E}_{t, X_t \sim p_t}\, \mathbb{E}_{Z \sim p_{Z|t}(\cdot \mid X_t)}\, \nabla_v D\big(u_t(X_t \mid Z), u_t^\theta(X_t)\big)\, \nabla_\theta u_t^\theta(X_t) \\
&= \mathbb{E}_{t, X_t \sim p_t}\, \mathbb{E}_{Z \sim p_{Z|t}(\cdot \mid X_t)}\, \nabla_\theta D\big(u_t(X_t \mid Z), u_t^\theta(X_t)\big) \\
&= \nabla_\theta\, \mathbb{E}_{t, Z \sim p_Z, X_t \sim p_{t|Z}(\cdot \mid Z)}\, D\big(u_t(X_t \mid Z), u_t^\theta(X_t)\big) \\
&= \nabla_\theta \mathcal{L}_{\mathrm{CFM}}(\theta)
\end{aligned}
$$
Q.E.D.
The above theorem actually has a more general form.
Theorem. Let $X \in \mathcal{S}_X$ and $Y \in \mathcal{S}_Y$ be random variables over state spaces $\mathcal{S}_X$ and $\mathcal{S}_Y$, and let $g^\theta(x) : \mathbb{R}^p \times \mathcal{S}_X \to \mathbb{R}^n$ be a function, where $\theta \in \mathbb{R}^p$ are the learnable parameters.
Let $D_x(u, v)$, $x \in \mathcal{S}_X$, be a Bregman divergence over a convex set $\Omega \subset \mathbb{R}^n$ that contains the image of $g^\theta(x)$. Then

$$\nabla_\theta\, \mathbb{E}_{X, Y}\, D_X\big(Y, g^\theta(X)\big) = \nabla_\theta\, \mathbb{E}_X\, D_X\big(\mathbb{E}[Y \mid X], g^\theta(X)\big)$$

and the global minimizer $g^\theta(x)$ of the above satisfies

$$g^\theta(x) = \mathbb{E}[Y \mid X = x]$$
I don't fully understand the subscript $x$ in $D_x$ here; presumably it means we may choose a divergence that depends on $x$ and varies with $x \sim X$.
proof.
$$
\begin{aligned}
\nabla_\theta\, \mathbb{E}_{X, Y}\, D_X\big(Y, g^\theta(X)\big)
&= \mathbb{E}_{X, Y}\, \nabla_v D_X\big(Y, g^\theta(X)\big)\, \nabla_\theta g^\theta(X) \\
&= \mathbb{E}_X\Big[\mathbb{E}\big[\nabla_v D_X\big(Y, g^\theta(X)\big)\, \nabla_\theta g^\theta(X) \mid X\big]\Big] \\
&\overset{(5)}{=} \mathbb{E}_X\Big[\nabla_v D_X\big(\mathbb{E}[Y \mid X], g^\theta(X)\big)\, \nabla_\theta g^\theta(X)\Big] \\
&= \mathbb{E}_X\Big[\nabla_\theta D_X\big(\mathbb{E}[Y \mid X], g^\theta(X)\big)\Big] \\
&= \nabla_\theta\, \mathbb{E}_X\, D_X\big(\mathbb{E}[Y \mid X], g^\theta(X)\big)
\end{aligned}
$$
Therefore, choosing $g^\theta(x) = \mathbb{E}[Y \mid X = x]$ attains the global minimum.
Q.E.D.
Solving Conditional Generation
We have now designed the conditional probability path, the conditional velocity field, the marginal probability path, and the marginal velocity field, and we have proved that we can learn the conditional velocity field instead of the intractable marginal one. We now consider how to realize these objects.
Here, we design the conditional path and velocity field via conditional flows: we define a flow model $X_{t|1}$ satisfying the boundary conditions. Once we have defined the flow, we can recover the velocity field via the flow ODE.
If we define

$$X_{t|1} = \psi_t(X_0 \mid x_1), \qquad X_0 \sim \pi_{0|1}(\cdot \mid x_1)$$

where $\psi : [0,1) \times \mathbb{R}^d \times \mathbb{R}^d \to \mathbb{R}^d$ is a conditional flow satisfying

$$\psi_t(x \mid x_1) = \begin{cases} x & t = 0 \\ x_1 & t = 1 \end{cases} \tag{6}$$

we then obtain the conditional probability path via the push-forward formula

$$p_{t|1}(x \mid x_1) := \big[\psi_t(\cdot \mid x_1)_\sharp\, \pi_{0|1}(\cdot \mid x_1)\big](x) \tag{7}$$

which means we push forward the source distribution through the conditional flow we defined. This probability path satisfies the boundary conditions: at $t = 0$ the flow is the identity map, so we recover the source distribution directly; at $t = 1$ the flow is a constant map, so all mass concentrates on the target sample.
Recall the flow ODE

$$\frac{\mathrm{d}}{\mathrm{d}t}\, \psi_t(x) = u_t\big(\psi_t(x)\big)$$

Let $x = \psi_t^{-1}(x')$; then

$$\frac{\mathrm{d}}{\mathrm{d}t}\, \psi_t\big(\psi_t^{-1}(x')\big) = u_t(x')$$

which gives us a way to extract the velocity field. Denoting $\dot\psi_t = \frac{\mathrm{d}}{\mathrm{d}t}\psi_t$, we can write the conditional version as

$$u_t(x \mid x_1) = \dot\psi_t\big(\psi_t^{-1}(x \mid x_1) \mid x_1\big) \tag{8}$$
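Below is a minimal sketch of Equation (8), using the linear conditional flow $\psi_t(x \mid x_1) = t x_1 + (1 - t) x$ (derived formally in the OT section below) as a concrete example, with $\dot\psi_t$ approximated by a finite difference in $t$:

```python
import numpy as np

def psi(t, x, x1):
    """Conditional flow psi_t(x | x1) = t * x1 + (1 - t) * x (linear example)."""
    return t * x1 + (1.0 - t) * x

def psi_inv(t, x, x1):
    """Inverse in the spatial argument: solve x = psi_t(x0 | x1) for x0."""
    return (x - t * x1) / (1.0 - t)

def velocity_from_flow(t, x, x1, eps=1e-5):
    """Equation (8): u_t(x | x1) = psi_dot_t(psi_t^{-1}(x | x1) | x1),
    with psi_dot_t approximated by a central finite difference in t."""
    x0 = psi_inv(t, x, x1)
    return (psi(t + eps, x0, x1) - psi(t - eps, x0, x1)) / (2.0 * eps)

x1, x, t = np.array([2.0, -1.0]), np.array([0.3, 0.7]), 0.4
u_numeric = velocity_from_flow(t, x, x1)
u_closed = (x1 - x) / (1.0 - t)   # closed form for the linear flow
assert np.allclose(u_numeric, u_closed, atol=1e-6)
```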
We have now shown that the conditional flow gives us both the conditional probability path and the conditional velocity field, so we only need to build a conditional flow satisfying Equation (6).
Recalibrate our Flow Matching Loss
Recall that we defined our loss as learning the conditional velocity field:

$$\mathcal{L}_{\mathrm{CFM}}(\theta) = \mathbb{E}_{t, X_1, X_t \sim p_{t|1}(\cdot \mid X_1)}\, D\big(u_t(X_t \mid X_1), u_t^\theta(X_t)\big)$$

From Equation (8) we have

$$u_t(X_t \mid X_1) = \dot\psi_t\big(\psi_t^{-1}(X_t \mid X_1) \mid X_1\big)$$

By the definition of the flow we have

$$X_t = \psi_t(X_0 \mid X_1)$$

then

$$u_t(X_t \mid X_1) = \dot\psi_t\big(\psi_t^{-1}(X_t \mid X_1) \mid X_1\big) = \dot\psi_t\big(\psi_t^{-1}(\psi_t(X_0 \mid X_1) \mid X_1) \mid X_1\big) = \dot\psi_t(X_0 \mid X_1) \tag{9}$$

Hence our loss function becomes

$$\mathcal{L}_{\mathrm{CFM}}(\theta) = \mathbb{E}_{t, (X_0, X_1) \sim \pi_{0,1}}\, D\big(\dot\psi_t(X_0 \mid X_1), u_t^\theta(X_t)\big), \qquad X_t = \psi_t(X_0 \mid X_1)$$

with the minimizer

$$u_t^\theta(x) = \mathbb{E}\big[\dot\psi_t(X_0 \mid X_1) \mid X_t = x\big] \tag{10}$$
Recalibrate Marginalization Trick
We must restate the marginalization trick for conditional flows. Assume $u_t(x \mid x_1)$ is conditionally integrable, which here means

$$\mathbb{E}_{t, (X_0, X_1) \sim \pi_{0,1}}\, \|\dot\psi_t(X_0 \mid X_1)\| < \infty$$

We then have the following corollary to the marginalization trick.
Corollary. Assume that $q$ has bounded support, that $\pi_{0|1}(\cdot \mid x_1)$ is $C^1(\mathbb{R}^d)$ and strictly positive for some $x_1$ with $q(x_1) > 0$, and that $\psi_t(x \mid x_1)$ is a conditional flow satisfying the boundary conditions and the integrability condition above.
Then $p_{t|1}(x \mid x_1)$ and $u_t(x \mid x_1)$, defined in (7) and (8), define a marginal velocity field $u_t(x)$ generating the marginal probability path $p_t(x)$ interpolating $p$ and $q$.
From the marginalization trick stated earlier, if the conditional field is conditionally integrable and generates the conditional probability path, then the marginal field generates the marginal path. We already know that the conditional field generates the conditional path (they are linked by the flow), so we only need to prove that the conditional field is integrable.
proof.
$$
\begin{aligned}
\int_0^1 \int \int \|u_t(x \mid x_1)\|\, p_{t|1}(x \mid x_1)\, q(x_1)\, \mathrm{d}x_1\, \mathrm{d}x\, \mathrm{d}t
&= \mathbb{E}_{t, X_1 \sim q, X_t \sim p_{t|1}(\cdot \mid X_1)}\, \|u_t(X_t \mid X_1)\| \\
&\overset{(9)}{=} \mathbb{E}_{t, (X_0, X_1) \sim \pi_{0,1}}\, \|\dot\psi_t(X_0 \mid X_1)\| < \infty
\end{aligned}
$$
This proves that the conditional field is conditionally integrable, so the marginal field $u_t$ generates the marginal path $p_t$.
Q.E.D.
Optimal Transport and Linear Conditional Flow
We have shown in the previous section that the conditional path and field can be designed via a conditional flow. One last question remains: how do we find a useful conditional flow?
Here, we introduce the flow derived from the dynamic Optimal Transport (OT) problem:

$$(p_t^*, u_t^*) = \underset{p_t, u_t}{\arg\min} \int_0^1 \int \|u_t(x)\|^2\, p_t(x)\, \mathrm{d}x\, \mathrm{d}t$$

subject to

$$p_0 = p, \qquad p_1 = q, \qquad \frac{\mathrm{d}}{\mathrm{d}t}\, p_t + \nabla \cdot (p_t u_t) = 0$$
Together with the flow ODE, solving the OT problem gives us a flow of the form

$$\psi_t^*(x) = t\, \phi(x) + (1 - t)\, x$$

called the OT displacement interpolant, where $\phi$ denotes the (static) OT map from $p$ to $q$.
This form also solves the Flow Matching problem:

$$X_t = \psi_t^*(X_0) \sim p_t^*, \qquad X_0 \sim p$$

We can see that the OT formulation promotes straight sample trajectories:

$$X_t = \psi_t^*(X_0) = t\, \phi(X_0) + (1 - t)\, X_0 = X_0 + t\,\big(\phi(X_0) - X_0\big)$$

so the sample at time $t$ is just $X_0$ displaced by the constant velocity $\phi(X_0) - X_0$ times $t$. Such straight trajectories are also much friendlier to ODE solvers.
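To make "friendly to ODE solvers" concrete, here is a minimal PyTorch sketch of sampling with a plain Euler integrator; `velocity_model` is a hypothetical trained network $u_t^\theta$ taking `(x, t)`:

```python
import torch

@torch.no_grad()
def euler_sample(velocity_model, x0, n_steps=100):
    """Integrate dX_t/dt = u_t^theta(X_t) from t = 0 to t = 1 with forward Euler.
    The straighter the learned field, the fewer steps are needed."""
    x, dt = x0, 1.0 / n_steps
    for i in range(n_steps):
        t = torch.full((x.shape[0], 1), i * dt)
        x = x + dt * velocity_model(x, t)
    return x

# usage sketch: x0 ~ p = N(0, I); velocity_model is a trained u_t^theta(x, t)
# x1_hat = euler_sample(velocity_model, torch.randn(64, 2), n_steps=8)
```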
To find a specific $\phi$, and hence a specific $\psi_t(x \mid x_1)$, let us bound the OT objective using the conditional velocity field in Equation (10):

$$
\begin{aligned}
\int_0^1 \int \|u_t(x)\|^2\, p_t(x)\, \mathrm{d}x\, \mathrm{d}t
&= \int_0^1 \mathbb{E}_{X_t \sim p_t}\, \|u_t(X_t)\|^2\, \mathrm{d}t \\
&\overset{(10)}{=} \int_0^1 \mathbb{E}_{X_t \sim p_t}\, \big\|\mathbb{E}\big[\dot\psi_t(X_0 \mid X_1) \mid X_t\big]\big\|^2\, \mathrm{d}t \\
&\le \int_0^1 \mathbb{E}_{X_t \sim p_t}\, \mathbb{E}\big[\|\dot\psi_t(X_0 \mid X_1)\|^2 \mid X_t\big]\, \mathrm{d}t \\
&= \mathbb{E}_{(X_0, X_1) \sim \pi_{0,1}} \int_0^1 \|\dot\psi_t(X_0 \mid X_1)\|^2\, \mathrm{d}t
\end{aligned}
$$

where the inequality is Jensen's inequality applied to $\|\cdot\|^2$. The bound is an expectation over $(X_0, X_1)$, so we can minimize it individually for each $x$ and $x_1$. Letting $\gamma_t = \psi_t(x \mid x_1)$, we focus on the inner integral:

$$\min_\gamma \int_0^1 \|\dot\gamma_t\|^2\, \mathrm{d}t \quad \text{s.t.} \quad \gamma_0 = x, \ \gamma_1 = x_1$$

This is a variational problem and can be solved with the Euler-Lagrange equations; we directly state the result. The conditional flow arising from the OT problem is

$$\psi_t(x \mid x_1) = t\, x_1 + (1 - t)\, x \tag{11}$$

which is the minimizer of the above variational problem.
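As a quick numerical illustration (my own sketch, not part of the derivation), we can discretize the kinetic energy $\int_0^1 \|\dot\gamma_t\|^2\, \mathrm{d}t$ and verify that the straight line beats a perturbed path with the same endpoints:

```python
import numpy as np

def kinetic_energy(path, dt):
    """Discretize int_0^1 ||gamma_dot_t||^2 dt with finite differences
    along a sampled path of shape (T, d)."""
    velocities = np.diff(path, axis=0) / dt
    return float(np.sum(velocities ** 2) * dt)

x, x1 = np.array([0.0, 0.0]), np.array([2.0, 1.0])
ts = np.linspace(0.0, 1.0, 201)[:, None]
dt = float(ts[1, 0] - ts[0, 0])

straight = ts * x1 + (1 - ts) * x                                      # gamma_t = t x1 + (1 - t) x
wiggly = straight + 0.3 * np.sin(np.pi * ts) * np.array([1.0, -1.0])   # same endpoints, curved

# the straight line attains the smaller (discretized) kinetic energy
assert kinetic_energy(straight, dt) < kinetic_energy(wiggly, dt)
```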
A couple of remarks. First, the linear conditional flow in (11) minimizes an upper bound on the kinetic energy among all conditional flows.
Second, note that Equation (11) now gives us

$$X_t = \psi_t(X_0 \mid X_1) = t\, X_1 + (1 - t)\, X_0$$

and rearranging yields

$$X_0 = \frac{X_t - t\, X_1}{1 - t}$$

If the target $q$ consists of a single data point, $q = \delta_{x_1}$, the above equation becomes

$$X_0 = \frac{X_t - t\, x_1}{1 - t}$$

which means that if you know $X_t$, then $X_0$ is deterministic, because there is no randomness in $x_1$ for a Dirac measure. We then find

$$\mathbb{E}\big[\dot\psi_t(X_0 \mid x_1) \mid X_t\big] = \dot\psi_t(X_0 \mid x_1) = x_1 - X_0$$

Since $X_t$ is given, $x_1$ is fixed by the Dirac target, and $X_0$ is deterministic given $X_t$, we can drop the expectation.
Theorem. If $q = \delta_{x_1}$, then the dynamic OT problem has an analytic solution given by the OT displacement interpolant in (11).
We can use this linear conditional flow to construct our conditional velocity field and the conditional probability path, and hence our Flow Matching loss. Refer back to 01-Overview of Flow Matching.
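Putting the pieces together, here is a minimal, self-contained PyTorch training sketch of the CFM loss with the linear conditional flow (11) and an independent coupling; the small MLP and the toy Gaussian target are placeholders, not the guide's reference implementation.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# toy target q: a Gaussian blob around (3, 3); replace with real data samples
def sample_target(batch_size):
    return torch.randn(batch_size, 2) * 0.5 + 3.0

# u_t^theta(x): a small placeholder MLP taking (x, t) -> velocity in R^2
model = nn.Sequential(nn.Linear(3, 64), nn.SiLU(),
                      nn.Linear(64, 64), nn.SiLU(),
                      nn.Linear(64, 2))
opt = torch.optim.Adam(model.parameters(), lr=1e-3)

for step in range(2000):
    x1 = sample_target(256)                 # X_1 ~ q
    x0 = torch.randn_like(x1)               # X_0 ~ p = N(0, I), independent coupling
    t = torch.rand(x1.shape[0], 1)          # t ~ U[0, 1]
    xt = t * x1 + (1 - t) * x0              # X_t = psi_t(X_0 | X_1), Equation (11)
    target = x1 - x0                        # psi_dot_t(X_0 | X_1) for the linear flow
    pred = model(torch.cat([xt, t], dim=1))          # u_t^theta(X_t)
    loss = ((pred - target) ** 2).sum(dim=1).mean()  # squared-Euclidean Bregman divergence
    opt.zero_grad(); loss.backward(); opt.step()

# samples can then be drawn with the Euler integrator sketched earlier
```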