这是用户在 2024-9-25 13:21 为 https://zhuanlan.zhihu.com/p/687740527 保存的双语快照页面,由 沉浸式翻译 提供双语支持。了解如何保存?
生成模型大道至简|Rectified Flow基础概念|代码

生成模型大道至简|Rectified Flow基础概念|代码

99 人赞同了该文章
发布于 2024-03-18 22:47・IP 属地广东

最近看了下SD3的论文,里面用到了Rectified Flow,之前没有接触过,了解了下是项有意思的技术。值得推荐,这里记录下Rectified Flow的基础概念和代码实现。详细的原理和深入​理解建议查阅论文。作者应该也在知乎,可以通过参考资料的链接直达。

定义:一个分布π0\pi_0转化成另一个分布π1\pi_1的问题,具体π0\pi_0π1\pi_1选择不同就可以式不同的问题,比如从噪声分布到真实图片分布,或者从某类图像分布到另一类图像的分布。

比如这里我们用两个不同的高斯混合模型来表示两个分布,并分别采样10000个样本点

import torch
import numpy as np
import torch.nn as nn
from torch.distributions import Normal, Categorical
from torch.distributions.multivariate_normal import MultivariateNormal
from torch.distributions.mixture_same_family import MixtureSameFamily
import matplotlib.pyplot as plt
import torch.nn.functional as F

D = 10.
M = D+5
VAR = 0.3
DOT_SIZE = 4
COMP = 3

initial_mix = Categorical(torch.tensor([1/COMP for i in range(COMP)]))
initial_comp = MultivariateNormal(torch.tensor([[D * np.sqrt(3) / 2., D / 2.], [-D * np.sqrt(3) / 2., D / 2.], [0.0, - D * np.sqrt(3) / 2.]]).float(), VAR * torch.stack([torch.eye(2) for i in range(COMP)]))
initial_model = MixtureSameFamily(initial_mix, initial_comp)
samples_0 = initial_model.sample([10000])

target_mix = Categorical(torch.tensor([1/COMP for i in range(COMP)]))
target_comp = MultivariateNormal(torch.tensor([[D * np.sqrt(3) / 2., - D / 2.], [-D * np.sqrt(3) / 2., - D / 2.], [0.0, D * np.sqrt(3) / 2.]]).float(), VAR * torch.stack([torch.eye(2) for i in range(COMP)]))
target_model = MixtureSameFamily(target_mix, target_comp)
samples_1 = target_model.sample([10000])
print('Shape of the samples:', samples_0.shape, samples_1.shape)

plt.figure(figsize=(4,4))
plt.xlim(-M,M)
plt.ylim(-M,M)
plt.title(r'Samples from $\pi_0$ and $\pi_1$')
plt.scatter(samples_0[:, 0].cpu().numpy(), samples_0[:, 1].cpu().numpy(), alpha=0.1, label=r'$\pi_0$')
plt.scatter(samples_1[:, 0].cpu().numpy(), samples_1[:, 1].cpu().numpy(), alpha=0.1, label=r'$\pi_1$')
plt.legend()

plt.tight_layout()

可以得到

基于上面两个分布的样本点,我们想要学习一个模型(映射关系),能够将π0\pi_0上的点映射(运动)到π1\pi_1,用常微分方程表示就是流模型(flow):

ddtZt=v(Zt,t),Z0π0,t[0,1].\frac d{dt}Z_t=v(Z_t,t)\text{,}Z_0\sim\pi_0\text{,}\forall t\in[0,1]. \\

两点之间直线最短,如果分布π0\pi_0上的点通过两点之间的连线运动到π1\pi_1,那不就很完美。可以定义两个点之间的插值如下

X_t = t X_1 + (1-t)X_0 \\

前面提到的flow就等价于“拉直的流模型(Rectified Flow)”

\frac d{dt}X_t=X_1-X_0,\quad\forall t\in[0,1]. \\

也就是学习一个v,使得它尽可能接近X_1-X_0

\min_v\int_0^1\mathbb{E}_{X_0\sim\pi_0,X_1\sim\pi_1}\left[||(X_1-X_0)-v(X_t,t)||^2\right]dt,\quad\mathrm{where}\quad X_t=tX_1+(1-t)X_0. \\

我们可以假设这个v是一个三层的神经网络

class MLP(nn.Module):
    def __init__(self, input_dim=2, hidden_num=100):
        super().__init__()
        self.fc1 = nn.Linear(input_dim+1, hidden_num, bias=True)
        self.fc2 = nn.Linear(hidden_num, hidden_num, bias=True)
        self.fc3 = nn.Linear(hidden_num, input_dim, bias=True)
        self.act = lambda x: torch.tanh(x)
    
    def forward(self, x_input, t):
        inputs = torch.cat([x_input, t], dim=1)
        x = self.fc1(inputs)
        x = self.act(x)
        x = self.fc2(x)
        x = self.act(x)
        x = self.fc3(x)

        return x

定义RectifiedFlow类,包含获得训练样本和从常微分方程采样

class RectifiedFlow():
  def __init__(self, model=None, num_steps=1000):
    self.model = model
    self.N = num_steps
  
  def get_train_tuple(self, z0=None, z1=None):
    t = torch.rand((z1.shape[0], 1))
    z_t =  t * z1 + (1.-t) * z0
    target = z1 - z0 
        
    return z_t, t, target

  @torch.no_grad()
  def sample_ode(self, z0=None, N=None):
    ### NOTE: Use Euler method to sample from the learned flow
    if N is None:
      N = self.N    
    dt = 1./N
    traj = [] # to store the trajectory
    z = z0.detach().clone()
    batchsize = z.shape[0]
    
    traj.append(z.detach().clone())
    for i in range(N):
      t = torch.ones((batchsize,1)) * i / N
      pred = self.model(z, t)
      z = z.detach().clone() + pred * dt
      
      traj.append(z.detach().clone())

    return traj

损失函数的定义

def train_rectified_flow(rectified_flow, optimizer, pairs, batchsize, inner_iters):
  loss_curve = []
  for i in range(inner_iters+1):
    optimizer.zero_grad()
    indices = torch.randperm(len(pairs))[:batchsize]
    batch = pairs[indices]
    z0 = batch[:, 0].detach().clone()
    z1 = batch[:, 1].detach().clone()
    z_t, t, target = rectified_flow.get_train_tuple(z0=z0, z1=z1)

    pred = rectified_flow.model(z_t, t)
    loss = (target - pred).view(pred.shape[0], -1).abs().pow(2).sum(dim=1)
    loss = loss.mean()
    loss.backward()
    
    optimizer.step()
    loss_curve.append(np.log(loss.item())) ## to store the loss curve

  return rectified_flow, loss_curve

1-Rectified Flow的训练样本对是随机从两个分布进行选取的

x_0 = samples_0.detach().clone()[torch.randperm(len(samples_0))]
x_1 = samples_1.detach().clone()[torch.randperm(len(samples_1))]
x_pairs = torch.stack([x_0, x_1], dim=1)
print(x_pairs.shape)

训练后

iterations = 10000
batchsize = 2048
input_dim = 2

rectified_flow_1 = RectifiedFlow(model=MLP(input_dim, hidden_num=100), num_steps=100)
optimizer = torch.optim.Adam(rectified_flow_1.model.parameters(), lr=5e-3)

rectified_flow_1, loss_curve = train_rectified_flow(rectified_flow_1, optimizer, x_pairs, batchsize, iterations)
plt.plot(np.linspace(0, iterations, iterations+1), loss_curve[:(iterations+1)])
plt.title('Training Loss Curve')

可以得到

可以发现由于两个样本的点是随机选择来组成训练样本对的,可能导致从一个分布的点走到另一分布有点“走弯路”。

因此,2-Rectified Flow则是用1-Rectified Flow的模型的采样点来组成样本对

z10 = samples_0.detach().clone()
traj = rectified_flow_1.sample_ode(z0=z10.detach().clone(), N=100)
z11 = traj[-1].detach().clone()
z_pairs = torch.stack([z10, z11], dim=1)
print(z_pairs.shape)

同样再训练一下

reflow_iterations = 50000

rectified_flow_2 = RectifiedFlow(model=MLP(input_dim, hidden_num=100), num_steps=100)
import copy 
rectified_flow_2.net = copy.deepcopy(rectified_flow_1) # we fine-tune the model from 1-Rectified Flow for faster training.
optimizer = torch.optim.Adam(rectified_flow_2.model.parameters(), lr=5e-3)

rectified_flow_2, loss_curve = train_rectified_flow(rectified_flow_2, optimizer, z_pairs, batchsize, reflow_iterations)
plt.plot(np.linspace(0, reflow_iterations, reflow_iterations+1), loss_curve[:(reflow_iterations+1)])

可以得到

可以看到现在一个分布的点可以几乎以“直线的方式”走到另一个分布。

参考资料

发布于 2024-03-18 22:47・IP 属地广东
理性发言,友善互动

14 条评论
默认
最新
栈满阈越

赞!非常清晰和凝练[赞同]学习了

08-09 · 北京
王康康

作者好像没有放路径图可视化的代码,我画蛇添足补充一个[酷][酷]
plt.figure()
mycolor = ['b', 'r']
color_candidate = ['g', 'c', 'm', 'y', 'k']

sample_num = 20
this_step = 5
for i in range(sample_num): #选择了20个样本,作者总共设置了1万个样本
pointindex = list(range(0, (len(traj) - 1), this_step)) + [len(traj) - 1] #作者设置的每条线101个点,但画多了看不清,少画些吧
thiscolor = random.sample(color_candidate, 1)[0]
for k in range(len(pointindex) - 1):
plt.plot([traj[pointindex[k]][i, 0], traj[pointindex[k + 1]][i, 0]], [traj[pointindex[k]][i, 1], traj[pointindex[k + 1]][i, 1]], c=thiscolor) #画这一小段的线,使用本次随机到的颜色
if(k == 0): #如果是起点,点就画大些,使用指定颜色蓝色
plt.scatter([traj[pointindex[k]][i, 0]], [traj[pointindex[k]][i, 1]], s = 30, c=mycolor[0])
elif(k + 1 == len(pointindex) - 1): #如果是终点,点就画大些,使用指定颜色红色
plt.scatter([traj[pointindex[k + 1]][i, 0]], [traj[pointindex[k + 1]][i, 1]], s = 30, c=mycolor[1])
else: #那就剩中间点了,用小尺寸,使用本次随机到的颜色
plt.scatter([traj[pointindex[k]][i, 0]], [traj[pointindex[k]][i, 1]], s = 10, c=thiscolor)
plt.scatter([traj[pointindex[k + 1]][i, 0]], [traj[pointindex[k + 1]][i, 1]], s = 10, c=thiscolor)
plt.show()

09-09 · 中国香港
王康康

代码tab的参考截图:

09-09 · 中国香港
王康康

再补充一句,这个代码是在得到traj变量之后就可以画了

09-09 · 中国香港
Beyond

为什么要基于第一次rectflow生成的数据组成pair,再进行二次rectflow训练呢?

07-01 · 北京
Beyond
文生图训练的时候,假设图文本身就是对齐的pair,这样的话应该不需要二次rectflow了吧?
07-01 · 北京
养生的控制人
可以不用的
07-01 · 广东
Justin Jiang

请问1-Rectified Flow为什么会经过中心”走弯路“呀?

05-21 · 美国
养生的控制人
因为训练pair没有进行匹配,随便选点的
05-21 · 广东
Justin Jiang

非常清晰的讲解和代码!(建议把plotting补上)

05-21 · 美国
理性发言,友善互动

文章被以下专栏收录

想来知乎工作?请发送邮件到 jobs@zhihu.com