Preface: In the previous two articles we defined the forward processes of the flow matching model and the diffusion model, and constructed their training objectives. Starting with this article, we derive the practical loss functions from the Gaussian probability path, and then train an unconditional flow model and an unconditional diffusion model.

Training a Generative Model (Part 1)

Flow model

Recall the definition of the flow matching model:
$$
\begin{align}
X_0 \sim p_{init} , \quad dX_t = u^{\theta}_t(X_t)dt
\end{align}
$$
Intuitively, we would like to define the loss function:

$$ \begin{align} \mathcal{L}_{FM}(\theta) &= \mathbb{E}_{t\sim Unif, x\sim p_t}\left[ \| u^{\theta}_t(x) - u^{target}_t(x)\|^2\right]\\ &= \mathbb{E}_{t\sim Unif,z \sim p_{data}, x\sim p_t(\cdot | z)}\left[ \| u^{\theta}_t(x) - u^{target}_t(x)\|^2\right] \end{align} $$

where $p_t(x) = \int p_t(x|z)p_{data}(z)\mathrm{d}z$. This loss does several things:

  • 1. The time step t is drawn uniformly from [0, 1].
  • 2. We sample z from $p_{data}$, add noise to obtain x, and compute $u^{\theta}_t(x)$.
  • 3. We compute the squared distance between $u^{\theta}_t(x)$ and $u^{target}_t(x)$.

As we noted previously, however, we cannot compute $u^{target}_t(x)$ directly, because:
$$ \begin{align} u^{target}_t(x) = \int u^{target}_t(x|z) \frac{p_t(x|z)p_{data}(z)}{p_t(x)} \mathrm{d}z \end{align} $$

and we have no access to $p_t(x)$. We therefore turn to the conditional probability path and define the conditional flow matching loss:

$$ \begin{align} \mathcal{L}_{CFM}(\theta) = \mathbb{E}_{t\sim Unif,z \sim p_{data}, x\sim p_t(\cdot | z)}\left[ \| u^{\theta}_t(x) - u^{target}_t(x|z)\|^2\right] \end{align} $$
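As an aside (our own illustration, not part of the original derivation): the marginal field $u^{target}_t(x)$ is only intractable because $p_{data}$ is unknown. For a toy data distribution made of a few points it can be evaluated exactly as the posterior-weighted average in the formula above. A minimal sketch, assuming the linear Gaussian path $x_t = tz + (1-t)\epsilon$ introduced later in this article, for which $u^{target}_t(x|z) = (z - x)/(1-t)$:

```python
import numpy as np

def u_conditional(x, z, t):
    # Conditional vector field for the path alpha_t = t, beta_t = 1 - t:
    # u(x|z) = (z - x) / (1 - t)
    return (z - x) / (1.0 - t)

def u_marginal(x, t, zs, probs):
    # p_t(x|z) = N(t*z, (1-t)^2), so the posterior weights are
    # w(z) ∝ p_data(z) * N(x; t*z, (1-t)^2)
    beta = 1.0 - t
    w = probs * np.exp(-0.5 * ((x - t * zs) / beta) ** 2)
    w = w / w.sum()
    # u_target(x) = sum_z w(z) * u_target(x|z)
    return np.sum(w * u_conditional(x, zs, t))

zs = np.array([-1.0, 2.0])        # toy two-point data distribution
probs = np.array([0.5, 0.5])
x, t = 0.3, 0.6
print(u_marginal(x, t, zs, probs))
```

With a single data point the posterior weight is 1 and the marginal field reduces to the conditional one; with several points it is a convex combination of the conditional fields.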

Although $\mathcal{L}_{CFM}(\theta)$ can be computed, is minimizing it actually useful? After all, our ultimate goal is to minimize $\mathcal{L}_{FM}(\theta)$. The answer is that the two are equivalent; we state the conclusion first:

$$ \begin{align} \mathcal{L}_{FM}(\theta) = \mathcal{L}_{CFM}(\theta) + C \end{align} $$

The proof is as follows:


$$ \begin{align} \mathcal{L}_{FM}(\theta) &= \mathbb{E}_{t\sim Unif,x \sim p_t}\left[\|u^{\theta}_t(x) - u^{target}_t(x)\|^2 \right]\\ &= \mathbb{E}_{t\sim Unif,x \sim p_t}\left[\| u^{\theta}_t(x)\|^2 - 2u^{\theta}_t(x)^Tu^{target}_t(x) + \|u^{target}_t(x)\|^2\right]\\ &=\mathbb{E}_{t\sim Unif,x \sim p_t}\left[\| u^{\theta}_t(x)\|^2\right] - 2\mathbb{E}_{t\sim Unif,x \sim p_t}\left[u^{\theta}_t(x)^Tu^{target}_t(x)\right] + \underbrace{\mathbb{E}_{t\sim Unif,x \sim p_t}\left[\|u^{target}_t(x)\|^2\right]}_\text{C1}\\ &= \mathbb{E}_{t\sim Unif,z\sim p_{data},x \sim p_t(\cdot|z)}\left[\| u^{\theta}_t(x)\|^2\right] - 2\mathbb{E}_{t\sim Unif,x \sim p_t}\left[u^{\theta}_t(x)^Tu^{target}_t(x)\right] + C1 \end{align} $$

Here $\mathbb{E}_{t\sim Unif,x \sim p_t}\left[\|u^{target}_t(x)\|^2\right]$ is determined entirely by the target vector field of the data distribution and does not depend on the neural network we train, so we can collapse it into a constant C1. That leaves one term to deal with, $\mathbb{E}_{t\sim Unif,x \sim p_t}\left[u^{\theta}_t(x)^Tu^{target}_t(x)\right]$:

$$ \begin{align} \mathbb{E}_{t\sim Unif,x \sim p_t}\left[u^{\theta}_t(x)^Tu^{target}_t(x)\right] &= \int^1_0\int p_t(x)u^{\theta}_t(x)^Tu^{target}_t(x)\mathrm{d}x\mathrm{d}t\\ &= \int^1_0\int p_t(x)u^{\theta}_t(x)^T\left[\int u^{target}_t(x|z) \frac{p_t(x|z)p_{data}(z)}{p_t(x)} \mathrm{d}z\right]\mathrm{d}x\mathrm{d}t\\ &= \int^1_0\int\int u^{\theta}_t(x)^Tu^{target}_t(x|z)p_t(x|z)p_{data}(z)\mathrm{d}z\mathrm{d}x\mathrm{d}t\\ &= \mathbb{E}_{t\sim Unif,z\sim p_{data},x \sim p_t(\cdot|z)}\left[u^{\theta}_t(x)^Tu^{target}_t(x|z)\right] \end{align} $$

Substituting this term back into $\mathcal{L}_{FM}(\theta)$ and rearranging:

$$ \begin{align} \mathcal{L}_{FM}(\theta) &= \mathbb{E}_{t\sim Unif,z\sim p_{data},x \sim p_t(\cdot|z)}\left[\| u^{\theta}_t(x)\|^2\right] - 2\mathbb{E}_{t\sim Unif,z\sim p_{data},x \sim p_t(\cdot|z)}\left[u^{\theta}_t(x)^Tu^{target}_t(x|z)\right] + C1\\ &= \mathbb{E}_{t\sim Unif,z\sim p_{data},x \sim p_t(\cdot|z)}\left[\| u^{\theta}_t(x)\|^2 -2u^{\theta}_t(x)^Tu^{target}_t(x|z)\right] + C1\\ &= \mathbb{E}_{t\sim Unif,z\sim p_{data},x \sim p_t(\cdot|z)}\left[\| u^{\theta}_t(x)\|^2 -2u^{\theta}_t(x)^Tu^{target}_t(x|z) + \|u^{target}_t(x|z)\|^2 - \|u^{target}_t(x|z)\|^2\right] + C1\\ &= \mathbb{E}_{t\sim Unif,z\sim p_{data},x \sim p_t(\cdot|z)}\left[\| u^{\theta}_t(x) - u^{target}_t(x|z)\|^2\right] - \underbrace{\mathbb{E}_{t\sim Unif,z\sim p_{data},x \sim p_t(\cdot|z)}\left[ \|u^{target}_t(x|z)\|^2\right]}_\text{C2} + C1\\ &= \mathcal{L}_{CFM}(\theta) + C1 - C2 \end{align} $$

Likewise, $\mathbb{E}_{t\sim Unif,z\sim p_{data},x \sim p_t(\cdot|z)}\left[ \|u^{target}_t(x|z)\|^2\right]$ does not depend on the neural network: for a fixed data distribution the conditional vector field is fixed, so we collapse it into a constant C2. Both constants vanish when we take gradients with respect to $\theta$, so minimizing $\mathcal{L}_{CFM}$ also minimizes $\mathcal{L}_{FM}$.
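The identity can also be checked numerically. The sketch below is our own illustration: a 1D two-point data distribution with the linear path $\alpha_t = t$, $\beta_t = 1-t$ (truncating t at 0.95 to avoid the singularity at t = 1, which does not affect the identity) and a toy linear model $u^{\theta}_t(x) = \theta x$. For this toy $p_{data}$ the marginal target is computable, so we can estimate both losses on shared samples and verify that their difference is roughly independent of $\theta$:

```python
import numpy as np

rng = np.random.default_rng(0)
zs = np.array([-1.0, 1.0])             # p_data: two equally likely points
N = 200_000

# Shared samples: t ~ Unif, z ~ p_data, x = t*z + (1-t)*eps
t = rng.uniform(0.0, 0.95, N)
z = rng.choice(zs, N)
eps = rng.standard_normal(N)
x = t * z + (1.0 - t) * eps

u_cond = z - eps                       # conditional target (= (z - x)/(1 - t))

# Marginal target: posterior-weighted average over the two atoms
beta = 1.0 - t
w = np.exp(-0.5 * ((x[:, None] - t[:, None] * zs[None, :]) / beta[:, None]) ** 2)
w = w / w.sum(axis=1, keepdims=True)
u_marg = (w * ((zs[None, :] - x[:, None]) / beta[:, None])).sum(axis=1)

def losses(theta):
    u_pred = theta * x                 # toy model u_theta(x) = theta * x
    l_fm = np.mean((u_pred - u_marg) ** 2)
    l_cfm = np.mean((u_pred - u_cond) ** 2)
    return l_fm, l_cfm

diffs = [losses(th)[0] - losses(th)[1] for th in (-2.0, 0.0, 1.0, 3.0)]
print(diffs)   # approximately the same value for every theta
```

The difference is negative, matching C1 - C2 = -E[Var(u^target(x|z) | x, t)] ≤ 0 by the tower property, and it stays (almost) constant as θ varies.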


OK, we have successfully obtained a loss function that we can actually compute. Now let us apply it to the Gaussian path and work out the concrete form of the expression.
Let $\epsilon \sim \mathcal{N}(0,I_d)$, so that $x_t = \alpha_t z + \beta_t \epsilon \sim \mathcal{N}(\alpha_t z , \beta^2_t I_d) = p_t(\cdot|z)$. From the training target we constructed earlier, $u^{target}_t(x|z) = (\dot{\alpha}_t - \frac{\dot{\beta}_t}{\beta_t}\alpha_t)z + \frac{\dot{\beta}_t}{\beta_t}x$. Substituting into the loss function:

$$ \begin{align} \mathcal{L}_{CFM}(\theta) &= \mathbb{E}_{t\sim Unif,z\sim p_{data},x \sim \mathcal{N}(\alpha_t z ,\beta^2_t I_d)}\left[\|u^{\theta}_t(x) - (\dot{\alpha}_t - \frac{\dot{\beta}_t}{\beta_t}\alpha_t)z - \frac{\dot{\beta}_t}{\beta_t}x\|^2\right]\\ &= \mathbb{E}_{t\sim Unif,z\sim p_{data},\epsilon \sim \mathcal{N}(0, I_d)}\left[\|u^{\theta}_t(\alpha_t z + \beta_t \epsilon) - (\dot{\alpha}_t z + \dot{\beta}_t \epsilon)\|^2\right] \end{align} $$
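As a quick sanity check (our own addition), the conditional field $u^{target}_t(x|z) = (\dot{\alpha}_t - \frac{\dot{\beta}_t}{\beta_t}\alpha_t)z + \frac{\dot{\beta}_t}{\beta_t}x$ should equal the time derivative of the trajectory $x_t = \alpha_t z + \beta_t \epsilon$. The sketch below verifies this by finite differences for a hypothetical nonlinear schedule $\alpha_t = t^2$, $\beta_t = 1 - t^2$:

```python
alpha = lambda t: t ** 2
beta = lambda t: 1.0 - t ** 2
dalpha = lambda t: 2.0 * t
dbeta = lambda t: -2.0 * t

def u_target(x, z, t):
    # u(x|z) = (alpha'_t - (beta'_t / beta_t) * alpha_t) * z + (beta'_t / beta_t) * x
    r = dbeta(t) / beta(t)
    return (dalpha(t) - r * alpha(t)) * z + r * x

z, eps, t, h = 1.3, -0.7, 0.5, 1e-6
x_t = alpha(t) * z + beta(t) * eps

# Central finite difference of x_t along the trajectory (z, eps held fixed)
dx_dt = (alpha(t + h) * z + beta(t + h) * eps
         - (alpha(t - h) * z + beta(t - h) * eps)) / (2 * h)
print(dx_dt, u_target(x_t, z, t))
```

Both values agree (they equal $\dot{\alpha}_t z + \dot{\beta}_t \epsilon$), which is exactly the substitution used in the second line of the loss above.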

For the hyperparameters $\alpha_t$ and $\beta_t$, we only need to satisfy:

  • $\alpha_0 = 0$, $\alpha_1 = 1$
  • $\beta_0 = 1$, $\beta_1 = 0$

A simple choice satisfying these is:

  • $\alpha_t = t$
  • $\beta_t = 1 - t$

Substituting into the loss function:
$$ \begin{align} \mathcal{L}_{CFM}(\theta) &= \mathbb{E}_{t\sim Unif,z\sim p_{data},\epsilon \sim \mathcal{N}(0, I_d)}\left[\|u^{\theta}_t(tz + (1-t)\epsilon)-(z-\epsilon)\|^2\right] \end{align} $$

This is a remarkably clean form. To interpret it: we sample a data point z from the target distribution, corrupt it with noise, feed the result into the neural network, and train the network to output the difference between the sample and the added noise, $z - \epsilon$.
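To make this recipe concrete before the full MNIST example, here is a minimal sketch (our own toy setup, not from the original) of training this loss on 2D points, with a small MLP standing in for $u^{\theta}$ that takes the concatenation of x and t as input:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# u_theta(x, t): a small MLP on the concatenated input (x, t)
model = nn.Sequential(nn.Linear(3, 64), nn.SiLU(), nn.Linear(64, 2))
opt = torch.optim.Adam(model.parameters(), lr=1e-2)

def cfm_loss(z):
    # z: batch of data points, shape (B, 2)
    B = z.shape[0]
    t = torch.rand(B, 1)                 # t ~ Unif[0, 1]
    eps = torch.randn_like(z)            # eps ~ N(0, I)
    x_t = t * z + (1 - t) * eps          # x_t = alpha_t z + beta_t eps
    target = z - eps                     # u_target = z - eps
    pred = model(torch.cat([x_t, t], dim=1))
    return ((pred - target) ** 2).sum(dim=1).mean()

def avg_loss(z):
    with torch.no_grad():
        return sum(cfm_loss(z).item() for _ in range(20)) / 20

z = torch.randn(256, 2) * 0.1 + torch.tensor([2.0, 2.0])  # toy data cluster
init_loss = avg_loss(z)
for _ in range(300):
    opt.zero_grad()
    loss = cfm_loss(z)
    loss.backward()
    opt.step()
final_loss = avg_loss(z)
print(init_loss, final_loss)
```

After training, samples would be drawn by Euler-integrating $dx = u^{\theta}_t(x)\,dt$ from $x_0 \sim \mathcal{N}(0, I)$, which is exactly what the MNIST code below does.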


OK, that is a lot of math. Before charging ahead into SDEs, it will be more encouraging to first train a simple flow model, so let's build one now.

import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from torchvision import datasets, transforms
from torchvision.utils import save_image, make_grid
from tqdm import tqdm
import os
from PIL import Image


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# --- Time embedding module ---
class SinusoidalPositionEmbeddings(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim
        self.half_dim = dim // 2
        self.emb_scale = np.log(10000) / (self.half_dim - 1)

    def forward(self, time):
        device = time.device
        embeddings = torch.exp(torch.arange(self.half_dim, device=device) * -self.emb_scale)
        embeddings = time * embeddings.unsqueeze(0)
        embeddings = torch.cat((embeddings.sin(), embeddings.cos()), dim=-1)
        return embeddings

# --- Convolutional block ---
class Block(nn.Module):
    def __init__(self, in_ch, out_ch, time_emb_dim, up=False):
        super().__init__()
        self.time_mlp = nn.Linear(time_emb_dim, out_ch)
        if up:
            self.conv1 = nn.Conv2d(2 * in_ch, out_ch, 3, padding=1)
            self.transform = nn.ConvTranspose2d(out_ch, out_ch, 4, 2, 1)
        else:
            self.conv1 = nn.Conv2d(in_ch, out_ch, 3, padding=1)
            self.transform = nn.Conv2d(out_ch, out_ch, 4, 2, 1)
        self.conv2 = nn.Conv2d(out_ch, out_ch, 3, padding=1)
        self.bnorm1 = nn.BatchNorm2d(out_ch)
        self.bnorm2 = nn.BatchNorm2d(out_ch)
        self.relu = nn.SiLU()

    def forward(self, x, t):
        h = self.bnorm1(self.relu(self.conv1(x)))
        time_emb = self.relu(self.time_mlp(t))
        time_emb = time_emb[(...,) + (None,) * 2]  # broadcast over H, W
        h = h + time_emb
        h = self.bnorm2(self.relu(self.conv2(h)))
        return self.transform(h)

# --- U-Net ---
class SimpleUNet(nn.Module):
    def __init__(self):
        super().__init__()
        image_channels = 1
        down_channels = (32, 64, 128)
        up_channels = (128, 64, 32)
        out_dim = 1
        time_emb_dim = 32

        self.time_mlp = nn.Sequential(
            SinusoidalPositionEmbeddings(time_emb_dim),
            nn.Linear(time_emb_dim, time_emb_dim),
            nn.SiLU()
        )

        self.conv0 = nn.Conv2d(image_channels, down_channels[0], 3, padding=1)

        self.downs = nn.ModuleList([
            Block(down_channels[i], down_channels[i + 1], time_emb_dim)
            for i in range(len(down_channels) - 1)
        ])

        self.ups = nn.ModuleList([
            Block(up_channels[i], up_channels[i + 1], time_emb_dim, up=True)
            for i in range(len(up_channels) - 1)
        ])

        self.output = nn.Conv2d(up_channels[-1], out_dim, 1)

    def forward(self, x, t):
        t = self.time_mlp(t)
        x = self.conv0(x)
        residuals = []
        for down in self.downs:
            x = down(x, t)
            residuals.append(x)
        for up in self.ups:
            residual = residuals.pop()
            if x.shape != residual.shape:
                x = transforms.Resize(residual.shape[2:])(x)
            x = torch.cat((x, residual), dim=1)
            x = up(x, t)
        return self.output(x)


def compute_loss(model, z):
batch_size = z.shape[0]
epsilon = torch.randn_like(z).to(device)
t = torch.rand(batch_size, 1).to(device)
t_img = t.view(batch_size, 1, 1, 1)
x_t = t_img * z + (1 - t_img) * epsilon # x_t = t * z + (1 - t) * epsilon
u_target = z - epsilon # loss = ||u - (z - epsilon)||^2
u_pred = model(x_t, t)
loss = nn.functional.mse_loss(u_pred, u_target)
return loss

@torch.no_grad()
def simulate(model, n_samples=16, n_steps=100):
    model.eval()

    x = torch.randn(n_samples, 1, 28, 28).to(device)

    trajectory = []

    dt = 1.0 / n_steps

    for i in tqdm(range(n_steps), desc="Sampling"):
        current_x = x.clone().detach().cpu()
        current_x = current_x / 2 + 0.5
        current_x = current_x.clamp(0, 1)
        trajectory.append(current_x)

        t_value = i / n_steps
        t = torch.full((n_samples, 1), t_value).to(device)  # t = i / n_steps

        v = model(x, t)
        x = x + v * dt  # Euler ODE update: x_{t+dt} = x_t + v * dt

    final_x = x.clone().detach().cpu()
    final_x = final_x / 2 + 0.5
    final_x = final_x.clamp(0, 1)
    trajectory.append(final_x)

    return trajectory

# --- util function for gif ---
def save_trajectory_gif(trajectory, filename="flow_process.gif"):
    print(f"Saving GIF to {filename}...")

    frames = []
    for x_batch in trajectory:
        grid = make_grid(x_batch, nrow=4, padding=2, pad_value=1)
        ndarr = grid.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to('cpu', torch.uint8).numpy()
        im = Image.fromarray(ndarr)
        frames.append(im)

    frames[0].save(
        filename,
        save_all=True,
        append_images=frames[1:],
        optimize=False,
        duration=50,
        loop=0
    )
    print("Done!")


def train_loop(model, train_loader, optimizer, num_epochs):
model.train()

for epoch in range(num_epochs):
total_loss = 0
pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}")

for batch_idx, (data, _) in enumerate(pbar):
z = data.to(device)
optimizer.zero_grad()
loss = compute_loss(model, z)
loss.backward()
optimizer.step()
total_loss += loss.item()
pbar.set_postfix({"Loss": f"{loss.item():.4f}"})

torch.save(model.state_dict(), "flow_unet_mnist.pth")

def prepare_data():
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.5,), (0.5,))
])
train_data = datasets.MNIST(root="./data", train=True, download=True, transform=transform)
train_loader = torch.utils.data.DataLoader(train_data, batch_size=128, shuffle=True)
return train_loader

if __name__ == "__main__":
    train_loader = prepare_data()
    model = SimpleUNet().to(device)

    if os.path.exists("flow_unet_mnist.pth"):
        try:
            model.load_state_dict(torch.load("flow_unet_mnist.pth", map_location=device))
            print("Loaded existing model weights.")
        except Exception:
            print("Weight shape mismatch, starting from scratch.")

    optimizer = optim.Adam(model.parameters(), lr=1e-3)

    train_loop(model, train_loader, optimizer, num_epochs=5)

    trajectory = simulate(model, n_samples=16, n_steps=100)

    save_trajectory_gif(trajectory, "mnist_generation.gif")

    final_image = trajectory[-1]
    save_image(final_image, "unet_final_result.png", nrow=4)

The image below shows the result after a quick 5-epoch training run:
flow_model_result

A few points worth noting:

  • We must explicitly feed the time step t to the model; our $u^{target}_t(x)$ reflects this, since it is a function of both x and t. Writing it as $u^{target}(x,t)$ makes this more explicit.
  • An MLP cannot represent image data effectively (in our experiments an MLP not only converged slowly but also produced poor samples), so we use a simple U-Net instead.