Generative Models Made Simple | Rectified Flow Basics | Code
I recently read the SD3 paper, which uses Rectified Flow. I hadn't encountered it before, and after looking into it I found it an interesting technique worth recommending. This post records the basic concepts of Rectified Flow along with a code implementation; for the full theory and deeper insights, please read the original papers. The author also appears to be on Zhihu and can be reached directly via the links in the references.
Definition: the task is to transform one distribution \pi_0 into another distribution \pi_1. Different choices of \pi_0 and \pi_1 give different problems, e.g. mapping a noise distribution to a distribution of real images, or mapping one class of images to another.
For example, here we represent the two distributions with two different Gaussian mixture models and draw 10,000 sample points from each:
import torch
import numpy as np
import torch.nn as nn
from torch.distributions import Categorical
from torch.distributions.multivariate_normal import MultivariateNormal
from torch.distributions.mixture_same_family import MixtureSameFamily
import matplotlib.pyplot as plt
D = 10.       # scale of the mixture means
M = D + 5     # plot axis limit
VAR = 0.3     # variance of each Gaussian component
DOT_SIZE = 4
COMP = 3      # number of mixture components
initial_mix = Categorical(torch.tensor([1/COMP for i in range(COMP)]))
initial_comp = MultivariateNormal(torch.tensor([[D * np.sqrt(3) / 2., D / 2.], [-D * np.sqrt(3) / 2., D / 2.], [0.0, - D * np.sqrt(3) / 2.]]).float(), VAR * torch.stack([torch.eye(2) for i in range(COMP)]))
initial_model = MixtureSameFamily(initial_mix, initial_comp)
samples_0 = initial_model.sample([10000])
target_mix = Categorical(torch.tensor([1/COMP for i in range(COMP)]))
target_comp = MultivariateNormal(torch.tensor([[D * np.sqrt(3) / 2., - D / 2.], [-D * np.sqrt(3) / 2., - D / 2.], [0.0, D * np.sqrt(3) / 2.]]).float(), VAR * torch.stack([torch.eye(2) for i in range(COMP)]))
target_model = MixtureSameFamily(target_mix, target_comp)
samples_1 = target_model.sample([10000])
print('Shape of the samples:', samples_0.shape, samples_1.shape)
plt.figure(figsize=(4,4))
plt.xlim(-M,M)
plt.ylim(-M,M)
plt.title(r'Samples from $\pi_0$ and $\pi_1$')
plt.scatter(samples_0[:, 0].cpu().numpy(), samples_0[:, 1].cpu().numpy(), alpha=0.1, label=r'$\pi_0$')
plt.scatter(samples_1[:, 0].cpu().numpy(), samples_1[:, 1].cpu().numpy(), alpha=0.1, label=r'$\pi_1$')
plt.legend()
plt.tight_layout()
We obtain:
Given the sample points from these two distributions, we want to learn a model (a mapping) that moves points from \pi_0 to \pi_1. Written as an ordinary differential equation, this is a flow model:
\frac{d}{dt}Z_t=v(Z_t,t),\quad Z_0\sim\pi_0,\quad\forall t\in[0,1]. \\
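Numerically, an ODE like this is integrated with a discrete scheme; the Euler method used later in sample_ode takes steps of the form
Z_{t+\Delta t}=Z_t+v(Z_t,t)\,\Delta t,\quad\Delta t=1/N. \\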
The shortest path between two points is a straight line, so it would be ideal if each point of \pi_0 moved to \pi_1 along the line connecting the pair. Define the interpolation between the two points as
X_t = t X_1 + (1-t)X_0 \\
Differentiating this interpolation with respect to t, the flow above is then equivalent to the "straightened" flow model (Rectified Flow):
\frac{d}{dt}X_t=X_1-X_0,\quad\forall t\in[0,1]. \\
That is, we learn a v that matches X_1-X_0 as closely as possible:
\min_v\int_0^1\mathbb{E}_{X_0\sim\pi_0,X_1\sim\pi_1}\left[||(X_1-X_0)-v(X_t,t)||^2\right]dt,\quad\mathrm{where}\quad X_t=tX_1+(1-t)X_0. \\
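In practice, the integral over t is estimated by Monte Carlo: draw a batch of B pairs and one uniform t per pair (exactly what get_train_tuple does below), which gives the empirical objective
\hat{L}(v)=\frac1B\sum_{i=1}^B||(X_1^{(i)}-X_0^{(i)})-v(X_t^{(i)},t_i)||^2,\quad t_i\sim\mathrm{Uniform}[0,1]. \\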
We can model this v as a three-layer neural network:
class MLP(nn.Module):
def __init__(self, input_dim=2, hidden_num=100):
super().__init__()
self.fc1 = nn.Linear(input_dim+1, hidden_num, bias=True)
self.fc2 = nn.Linear(hidden_num, hidden_num, bias=True)
self.fc3 = nn.Linear(hidden_num, input_dim, bias=True)
        self.act = torch.tanh  # tanh activation
def forward(self, x_input, t):
inputs = torch.cat([x_input, t], dim=1)
x = self.fc1(inputs)
x = self.act(x)
x = self.fc2(x)
x = self.act(x)
x = self.fc3(x)
return x
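As a quick sanity check (an illustrative snippet of my own, not from the original post): the network takes a batch of 2-D points together with a per-sample time and returns a 2-D velocity of the same batch size.
v_net = MLP(input_dim=2, hidden_num=100)  # hypothetical instance, for shape checking only
x = torch.randn(8, 2)                     # a batch of 8 two-dimensional points
t = torch.rand(8, 1)                      # one time value per sample
print(v_net(x, t).shape)                  # torch.Size([8, 2])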
Define the RectifiedFlow class, which builds training tuples and samples from the ODE:
class RectifiedFlow():
def __init__(self, model=None, num_steps=1000):
self.model = model
self.N = num_steps
    def get_train_tuple(self, z0=None, z1=None):
        t = torch.rand((z1.shape[0], 1))   # t ~ Uniform[0, 1], one per sample
        z_t = t * z1 + (1. - t) * z0       # linear interpolation X_t
        target = z1 - z0                   # regression target X_1 - X_0
        return z_t, t, target
@torch.no_grad()
def sample_ode(self, z0=None, N=None):
### NOTE: Use Euler method to sample from the learned flow
if N is None:
N = self.N
dt = 1./N
traj = [] # to store the trajectory
z = z0.detach().clone()
batchsize = z.shape[0]
traj.append(z.detach().clone())
for i in range(N):
t = torch.ones((batchsize,1)) * i / N
pred = self.model(z, t)
z = z.detach().clone() + pred * dt
traj.append(z.detach().clone())
return traj
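For illustration (a hedged sketch of my own, reusing the names defined above): sample_ode returns a list of N+1 snapshots of the batch along the Euler trajectory, so even with an untrained model we can check the shapes.
rf = RectifiedFlow(model=MLP(input_dim=2, hidden_num=100), num_steps=100)  # untrained, for shape checking only
z0 = initial_model.sample([5])    # 5 starting points drawn from pi_0
traj = rf.sample_ode(z0=z0)       # snapshots at t = 0, 1/N, ..., 1
print(len(traj), traj[-1].shape)  # 101 torch.Size([5, 2])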
The loss function and training loop:
def train_rectified_flow(rectified_flow, optimizer, pairs, batchsize, inner_iters):
loss_curve = []
for i in range(inner_iters+1):
optimizer.zero_grad()
indices = torch.randperm(len(pairs))[:batchsize]
batch = pairs[indices]
z0 = batch[:, 0].detach().clone()
z1 = batch[:, 1].detach().clone()
z_t, t, target = rectified_flow.get_train_tuple(z0=z0, z1=z1)
pred = rectified_flow.model(z_t, t)
loss = (target - pred).view(pred.shape[0], -1).abs().pow(2).sum(dim=1)
loss = loss.mean()
loss.backward()
optimizer.step()
loss_curve.append(np.log(loss.item())) ## to store the loss curve
return rectified_flow, loss_curve
For 1-Rectified Flow, the training pairs are formed by randomly matching samples drawn from the two distributions:
x_0 = samples_0.detach().clone()[torch.randperm(len(samples_0))]
x_1 = samples_1.detach().clone()[torch.randperm(len(samples_1))]
x_pairs = torch.stack([x_0, x_1], dim=1)
print(x_pairs.shape)
Run the training and plot the loss curve:
iterations = 10000
batchsize = 2048
input_dim = 2
rectified_flow_1 = RectifiedFlow(model=MLP(input_dim, hidden_num=100), num_steps=100)
optimizer = torch.optim.Adam(rectified_flow_1.model.parameters(), lr=5e-3)
rectified_flow_1, loss_curve = train_rectified_flow(rectified_flow_1, optimizer, x_pairs, batchsize, iterations)
plt.plot(np.linspace(0, iterations, iterations+1), loss_curve[:(iterations+1)])
plt.title('Training Loss Curve')
We obtain:
Notice that because the training pairs are formed by randomly matching points from the two distributions, the straight interpolation lines of different pairs cross each other; the learned velocity field averages over these crossings, so a point traveling from one distribution to the other may "take a detour".
Therefore, 2-Rectified Flow instead forms its training pairs from points sampled with the 1-Rectified Flow model, so that each starting point is paired with the endpoint the learned ODE transports it to:
z10 = samples_0.detach().clone()
traj = rectified_flow_1.sample_ode(z0=z10.detach().clone(), N=100)
z11 = traj[-1].detach().clone()
z_pairs = torch.stack([z10, z11], dim=1)
print(z_pairs.shape)
Likewise, train once more:
reflow_iterations = 50000
import copy
rectified_flow_2 = RectifiedFlow(model=MLP(input_dim, hidden_num=100), num_steps=100)
rectified_flow_2.model = copy.deepcopy(rectified_flow_1.model)  # warm-start from the 1-Rectified Flow weights for faster training
optimizer = torch.optim.Adam(rectified_flow_2.model.parameters(), lr=5e-3)
rectified_flow_2, loss_curve = train_rectified_flow(rectified_flow_2, optimizer, z_pairs, batchsize, reflow_iterations)
plt.plot(np.linspace(0, reflow_iterations, reflow_iterations+1), loss_curve[:(reflow_iterations+1)])
We obtain:
We can see that points from one distribution now travel to the other almost in straight lines.
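A practical payoff of the straightened trajectories, sketched below as an assumption worth verifying on your own run: since one Euler step integrates a constant-velocity path exactly, a nearly straight flow should need very few sampling steps.
z0 = samples_0.detach().clone()
endpoint_100 = rectified_flow_2.sample_ode(z0=z0, N=100)[-1]  # fine-grained Euler endpoint
endpoint_1 = rectified_flow_2.sample_ode(z0=z0, N=1)[-1]      # single-step endpoint
# if the learned flow is nearly straight, this gap should be small
print('mean endpoint gap:', (endpoint_100 - endpoint_1).norm(dim=1).mean().item())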
References
Nice! Very clear and concise; I learned a lot.
The author doesn't seem to have included the code for visualizing the trajectories, so let me (superfluously) add one:
import random

plt.figure()
mycolor = ['b', 'r']
color_candidate = ['g', 'c', 'm', 'y', 'k']
sample_num = 20   # plot 20 of the author's 10000 trajectories
this_step = 5     # subsample every 5th point; all 101 points per line is too cluttered
for i in range(sample_num):
    pointindex = list(range(0, (len(traj) - 1), this_step)) + [len(traj) - 1]
    thiscolor = random.sample(color_candidate, 1)[0]
    for k in range(len(pointindex) - 1):
        # draw this short segment in the randomly chosen color
        plt.plot([traj[pointindex[k]][i, 0], traj[pointindex[k + 1]][i, 0]],
                 [traj[pointindex[k]][i, 1], traj[pointindex[k + 1]][i, 1]], c=thiscolor)
        if k == 0:  # start point: larger marker, fixed blue
            plt.scatter([traj[pointindex[k]][i, 0]], [traj[pointindex[k]][i, 1]], s=30, c=mycolor[0])
        elif k + 1 == len(pointindex) - 1:  # end point: larger marker, fixed red
            plt.scatter([traj[pointindex[k + 1]][i, 0]], [traj[pointindex[k + 1]][i, 1]], s=30, c=mycolor[1])
        else:  # intermediate points: small markers in the segment color
            plt.scatter([traj[pointindex[k]][i, 0]], [traj[pointindex[k]][i, 1]], s=10, c=thiscolor)
            plt.scatter([traj[pointindex[k + 1]][i, 0]], [traj[pointindex[k + 1]][i, 1]], s=10, c=thiscolor)
plt.show()
One more note: this plotting code can be run as soon as the traj variable has been computed.
Why do we form pairs from data generated by the first rectified flow and then run a second round of rectified-flow training?
Why does 1-Rectified Flow pass through the center and "take a detour"?
Very clear explanation and code! (I'd suggest adding the plotting code.)