
Diffusion Models: DDPM

Contents
Diffusion Model Principles
Diffusion Process
Reverse Process
Optimization Objective
Model Design
Code Implementation
Summary
"What I cannot create, I do not understand." -- Richard Feynman

One of the hottest research directions recently is text-to-image generation with AI. Since OpenAI proposed the text-to-image model DALL-E in 2021, more and more large companies have entered this field; Google, for example, released Imagen and Parti this year. Several mainstream text-to-image models such as DALL·E 2, Stable Diffusion, and Imagen adopt the diffusion model as their image generation backbone, which has sparked a wave of research on diffusion models. Compared with GANs, diffusion models train more stably and can generate more diverse samples, and OpenAI's paper Diffusion Models Beat GANs on Image Synthesis showed that diffusion models can surpass GANs. In short, a diffusion model contains two processes: a forward diffusion process and a reverse generation process. The forward diffusion process gradually adds Gaussian noise to an image until it becomes pure random noise, while the reverse generation process is a denoising process: starting from random noise, we gradually remove noise until an image is generated; this is the part we need to solve for or train. A comparison between diffusion models and other mainstream generative models is shown in the figure below.

Most diffusion models used today originate from the 2020 work DDPM: Denoising Diffusion Probabilistic Models. DDPM simplifies earlier diffusion models (see Deep Unsupervised Learning using Nonequilibrium Thermodynamics) and models them via variational inference, mainly because a diffusion model is itself a latent variable model. Compared with latent variable models such as VAE, the latent variables of a diffusion model have the same dimensionality as the original data, and the inference process (i.e., the diffusion process) is usually fixed. This article introduces the principles of diffusion models in detail based on DDPM and provides a concrete code implementation and analysis.

Diffusion Model Principles

A diffusion model involves two processes: a forward process and a reverse process, where the forward process is also called the diffusion process, as shown in the figure below. Both are Markov chains; the reverse process is parameterized and can be used to generate data, and we will model and solve for it via variational inference.

Diffusion Process

The diffusion process gradually adds Gaussian noise to the data until it becomes pure random noise. For the original data $\mathbf{x}_0 \sim q(\mathbf{x}_0)$, each of the $T$ diffusion steps adds Gaussian noise to the data $\mathbf{x}_{t-1}$ obtained at the previous step as follows:

$$q(\mathbf{x}_t \vert \mathbf{x}_{t-1}) = \mathcal{N}(\mathbf{x}_t; \sqrt{1 - \beta_t}\, \mathbf{x}_{t-1}, \beta_t\mathbf{I})$$

Here $\{\beta_t\}_{t=1}^T$ is the variance used at each step, with values between 0 and 1. For diffusion models, the per-step variance setting is usually called the variance schedule or noise schedule; typically, later steps use larger variances, i.e. $\beta_1 < \beta_2 < \dots < \beta_T$. Under a well-designed variance schedule, if the number of diffusion steps $T$ is large enough, the final $\mathbf{x}_T$ completely loses the original data and becomes pure random noise. Each diffusion step produces a noisy sample $\mathbf{x}_t$, so the whole diffusion process forms a Markov chain:

$$q(\mathbf{x}_{1:T} \vert \mathbf{x}_0) = \prod^T_{t=1} q(\mathbf{x}_t \vert \mathbf{x}_{t-1})$$

Also note that the diffusion process is usually fixed, i.e. it uses a predefined variance schedule; DDPM, for instance, uses a linear variance schedule.

An important property of the diffusion process is that we can sample $\mathbf{x}_t$ at any step $t$ directly from the original data $\mathbf{x}_0$: $\mathbf{x}_t \sim q(\mathbf{x}_t \vert \mathbf{x}_0)$. Define $\alpha_t = 1 - \beta_t$ and $\bar{\alpha}_t = \prod_{i=1}^t \alpha_i$; then, using the reparameterization trick (as in VAE), we have:

$$\begin{aligned} \mathbf{x}_t &= \sqrt{\alpha_t}\mathbf{x}_{t-1} + \sqrt{1 - \alpha_t}\,\boldsymbol{\epsilon}_{t-1} & \text{where } \boldsymbol{\epsilon}_{t-1}, \boldsymbol{\epsilon}_{t-2}, \dots \sim \mathcal{N}(\mathbf{0}, \mathbf{I}) \\ &= \sqrt{\alpha_t}(\sqrt{\alpha_{t-1}}\mathbf{x}_{t-2} + \sqrt{1 - \alpha_{t-1}}\,\boldsymbol{\epsilon}_{t-2}) + \sqrt{1 - \alpha_t}\,\boldsymbol{\epsilon}_{t-1} \\ &= \sqrt{\alpha_t \alpha_{t-1}}\, \mathbf{x}_{t-2} + \sqrt{\big(\sqrt{\alpha_t-\alpha_t \alpha_{t-1}}\big)^2+\big(\sqrt{1-\alpha_t}\big)^2}\; \bar{\boldsymbol{\epsilon}}_{t-2} & \text{where } \bar{\boldsymbol{\epsilon}}_{t-2} \text{ merges two Gaussians (*)} \\ &= \sqrt{\alpha_t \alpha_{t-1}}\, \mathbf{x}_{t-2} + \sqrt{1 - \alpha_t \alpha_{t-1}}\, \bar{\boldsymbol{\epsilon}}_{t-2} \\ &= \dots \\ &= \sqrt{\bar{\alpha}_t}\mathbf{x}_0 + \sqrt{1 - \bar{\alpha}_t}\,\boldsymbol{\epsilon} \end{aligned}$$

The derivation step (*) uses the fact that the sum of two zero-mean Gaussians with different variances, $\mathcal{N}(\mathbf{0}, \sigma_1^2\mathbf{I})$ and $\mathcal{N}(\mathbf{0}, \sigma_2^2\mathbf{I})$, is again a Gaussian $\mathcal{N}(\mathbf{0}, (\sigma_1^2 + \sigma_2^2)\mathbf{I})$. Undoing the reparameterization, we obtain:

$$q(\mathbf{x}_t \vert \mathbf{x}_0) = \mathcal{N}(\mathbf{x}_t; \sqrt{\bar{\alpha}_t}\, \mathbf{x}_0, (1 - \bar{\alpha}_t)\mathbf{I})$$

This property of the diffusion process is important. First, $\mathbf{x}_t$ can be seen as a linear combination of the original data $\mathbf{x}_0$ and random noise $\boldsymbol{\epsilon}$, where $\sqrt{\bar\alpha_t}$ and $\sqrt{1 - \bar{\alpha}_t}$ are the combination coefficients and their squares sum to 1; they are also called the signal_rate and noise_rate (see keras.io/examples/gener and Variational Diffusion Models). Furthermore, we can define the noise schedule in terms of $\bar\alpha_t$ instead of $\beta_t$ (see the cosine schedule designed in Improved Denoising Diffusion Probabilistic Models), which is more direct: for example, simply setting $\bar\alpha_T$ to a value close to 0 guarantees that the final $\mathbf{x}_T$ is approximately pure random noise. Second, this property will be used in the modeling and analysis below.
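As a minimal illustration of this closed form (a sketch assuming a linear $\beta_t$ schedule and a dummy image tensor; the helper name q_sample_closed_form is ours), we can jump from $\mathbf{x}_0$ to any step $t$ in one shot:

import torch

T = 1000
betas = torch.linspace(1e-4, 0.02, T)            # linear variance schedule
alphas_bar = torch.cumprod(1.0 - betas, dim=0)   # \bar{\alpha}_t

def q_sample_closed_form(x0, t, noise):
    # x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps
    a_bar = alphas_bar[t]
    return a_bar.sqrt() * x0 + (1.0 - a_bar).sqrt() * noise

x0 = torch.randn(1, 3, 32, 32)                   # a stand-in for a normalized image
noise = torch.randn_like(x0)
x_500 = q_sample_closed_form(x0, 500, noise)     # noisy sample at step t = 500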

Reverse Process

The diffusion process turns data into noise, so the reverse process is a denoising process. If we knew the true distribution $q(\mathbf{x}_{t-1} \vert \mathbf{x}_t)$ of each reverse step, then starting from random noise $\mathbf{x}_T \sim \mathcal{N}(\mathbf{0}, \mathbf{I})$ and gradually denoising would produce a real sample; the reverse process is therefore the data generation process.

Estimating the distribution $q(\mathbf{x}_{t-1} \vert \mathbf{x}_t)$ would require the entire training set, so we use neural networks to approximate it. We also define the reverse process as a Markov chain, but one composed of Gaussian distributions parameterized by neural networks:

$$p_\theta(\mathbf{x}_{0:T}) = p(\mathbf{x}_T) \prod^T_{t=1} p_\theta(\mathbf{x}_{t-1} \vert \mathbf{x}_t), \qquad p_\theta(\mathbf{x}_{t-1} \vert \mathbf{x}_t) = \mathcal{N}(\mathbf{x}_{t-1}; \boldsymbol{\mu}_\theta(\mathbf{x}_t, t), \boldsymbol{\Sigma}_\theta(\mathbf{x}_t, t))$$

Here $p(\mathbf{x}_T)= \mathcal{N}(\mathbf{x}_T;\mathbf{0}, \mathbf{I})$, and $p_\theta(\mathbf{x}_{t-1} \vert \mathbf{x}_t)$ is a parameterized Gaussian whose mean and variance are given by the trained networks $\boldsymbol{\mu}_\theta(\mathbf{x}_t, t)$ and $\boldsymbol{\Sigma}_\theta(\mathbf{x}_t, t)$. What a diffusion model needs to learn is exactly these networks, because they constitute the final generative model.

Although the distribution $q(\mathbf{x}_{t-1} \vert \mathbf{x}_t)$ is intractable, the posterior additionally conditioned on $\mathbf{x}_0$, $q(\mathbf{x}_{t-1} \vert \mathbf{x}_t, \mathbf{x}_0)$, is tractable:

$$q(\mathbf{x}_{t-1} \vert \mathbf{x}_t, \mathbf{x}_0) = \mathcal{N}(\mathbf{x}_{t-1}; \tilde{\boldsymbol{\mu}}(\mathbf{x}_t, \mathbf{x}_0), \tilde{\beta}_t \mathbf{I})$$

Let us derive this distribution concretely. First, by Bayes' rule we have:

$$q(\mathbf{x}_{t-1} \vert \mathbf{x}_t, \mathbf{x}_0) = q(\mathbf{x}_t \vert \mathbf{x}_{t-1}, \mathbf{x}_0) \frac{ q(\mathbf{x}_{t-1} \vert \mathbf{x}_0) }{ q(\mathbf{x}_t \vert \mathbf{x}_0) }$$

Due to the Markov property of the diffusion process, $q(\mathbf{x}_t \vert \mathbf{x}_{t-1}, \mathbf{x}_0)=q(\mathbf{x}_t \vert \mathbf{x}_{t-1})=\mathcal{N}(\mathbf{x}_t; \sqrt{1 - \beta_t}\, \mathbf{x}_{t-1}, \beta_t\mathbf{I})$ (the condition $\mathbf{x}_0$ is redundant here), and from the diffusion-process property derived above:

$$q(\mathbf{x}_{t-1} \vert \mathbf{x}_0)=\mathcal{N}(\mathbf{x}_{t-1}; \sqrt{\bar{\alpha}_{t-1}}\, \mathbf{x}_0, (1 - \bar{\alpha}_{t-1})\mathbf{I}), \qquad q(\mathbf{x}_t \vert \mathbf{x}_0) = \mathcal{N}(\mathbf{x}_t; \sqrt{\bar{\alpha}_t}\, \mathbf{x}_0, (1 - \bar{\alpha}_t)\mathbf{I})$$

Therefore, we have:

$$\begin{aligned} q(\mathbf{x}_{t-1} \vert \mathbf{x}_t, \mathbf{x}_0) &= q(\mathbf{x}_t \vert \mathbf{x}_{t-1}, \mathbf{x}_0) \frac{ q(\mathbf{x}_{t-1} \vert \mathbf{x}_0) }{ q(\mathbf{x}_t \vert \mathbf{x}_0) } \\ &\propto \exp \Big(-\frac{1}{2} \big(\frac{(\mathbf{x}_t - \sqrt{\alpha_t} \mathbf{x}_{t-1})^2}{\beta_t} + \frac{(\mathbf{x}_{t-1} - \sqrt{\bar{\alpha}_{t-1}} \mathbf{x}_0)^2}{1-\bar{\alpha}_{t-1}} - \frac{(\mathbf{x}_t - \sqrt{\bar{\alpha}_t} \mathbf{x}_0)^2}{1-\bar{\alpha}_t} \big) \Big) \\ &= \exp \Big(-\frac{1}{2} \big(\frac{\mathbf{x}_t^2 - 2\sqrt{\alpha_t} \mathbf{x}_t \mathbf{x}_{t-1} + \alpha_t \mathbf{x}_{t-1}^2 }{\beta_t} + \frac{ \mathbf{x}_{t-1}^2 - 2 \sqrt{\bar{\alpha}_{t-1}} \mathbf{x}_0 \mathbf{x}_{t-1} + \bar{\alpha}_{t-1} \mathbf{x}_0^2 }{1-\bar{\alpha}_{t-1}} - \frac{(\mathbf{x}_t - \sqrt{\bar{\alpha}_t} \mathbf{x}_0)^2}{1-\bar{\alpha}_t} \big) \Big) \\ &= \exp\Big( -\frac{1}{2} \big( \big(\frac{\alpha_t}{\beta_t} + \frac{1}{1 - \bar{\alpha}_{t-1}}\big) \mathbf{x}_{t-1}^2 - \big(\frac{2\sqrt{\alpha_t}}{\beta_t} \mathbf{x}_t + \frac{2\sqrt{\bar{\alpha}_{t-1}}}{1 - \bar{\alpha}_{t-1}} \mathbf{x}_0\big) \mathbf{x}_{t-1} + C(\mathbf{x}_t, \mathbf{x}_0) \big) \Big) \end{aligned}$$

Here $C(\mathbf{x}_t, \mathbf{x}_0)$ is a term independent of $\mathbf{x}_{t-1}$, so it is omitted. Using the definition of the Gaussian probability density function and completing the square in the result above, we obtain the mean and variance of the posterior $q(\mathbf{x}_{t-1} \vert \mathbf{x}_t, \mathbf{x}_0)$:

$$\begin{aligned} \tilde{\beta}_t &= 1\Big/\Big(\frac{\alpha_t}{\beta_t} + \frac{1}{1 - \bar{\alpha}_{t-1}}\Big) = 1\Big/\Big(\frac{\alpha_t - \bar{\alpha}_t + \beta_t}{\beta_t(1 - \bar{\alpha}_{t-1})}\Big) = \frac{1 - \bar{\alpha}_{t-1}}{1 - \bar{\alpha}_t} \cdot \beta_t \\ \tilde{\boldsymbol{\mu}}_t (\mathbf{x}_t, \mathbf{x}_0) &= \Big(\frac{\sqrt{\alpha_t}}{\beta_t} \mathbf{x}_t + \frac{\sqrt{\bar{\alpha}_{t-1}}}{1 - \bar{\alpha}_{t-1}} \mathbf{x}_0\Big)\Big/\Big(\frac{\alpha_t}{\beta_t} + \frac{1}{1 - \bar{\alpha}_{t-1}}\Big) \\ &= \Big(\frac{\sqrt{\alpha_t}}{\beta_t} \mathbf{x}_t + \frac{\sqrt{\bar{\alpha}_{t-1}}}{1 - \bar{\alpha}_{t-1}} \mathbf{x}_0\Big) \cdot \frac{1 - \bar{\alpha}_{t-1}}{1 - \bar{\alpha}_t} \cdot \beta_t \\ &= \frac{\sqrt{\alpha_t}(1 - \bar{\alpha}_{t-1})}{1 - \bar{\alpha}_t} \mathbf{x}_t + \frac{\sqrt{\bar{\alpha}_{t-1}}\beta_t}{1 - \bar{\alpha}_t} \mathbf{x}_0 \end{aligned}$$

The variance is a constant (since the diffusion process parameters are fixed), while the mean is a function of $\mathbf{x}_0$ and $\mathbf{x}_t$. This distribution will be used to derive the optimization objective of the diffusion model.

Optimization Objective

Having introduced the diffusion and reverse processes, let us now look at diffusion models from another angle: if we treat the intermediate variables as latent variables, a diffusion model is a latent variable model with $T$ latent variables, and it can be viewed as a special hierarchical VAE (see Understanding Diffusion Models: A Unified Perspective).

Compared with VAE, the latent variables of a diffusion model have the same dimensionality as the original data, and the encoder (i.e., the diffusion process) is fixed. Since a diffusion model is a latent variable model, we can use variational inference to obtain the variational lower bound (VLB, also known as the ELBO) as the objective to maximize:

$$\begin{aligned} \log p_\theta(\mathbf{x}_0) &=\log\int p_\theta(\mathbf{x}_{0:T})\, d\mathbf{x}_{1:T}\\ &=\log\int \frac{p_\theta(\mathbf{x}_{0:T})\, q(\mathbf{x}_{1:T}\vert \mathbf{x}_{0})}{q(\mathbf{x}_{1:T}\vert \mathbf{x}_{0})}\, d\mathbf{x}_{1:T}\\ &\geq \mathbb{E}_{q(\mathbf{x}_{1:T}\vert \mathbf{x}_{0})}\Big[\log \frac{p_\theta(\mathbf{x}_{0:T})}{q(\mathbf{x}_{1:T}\vert \mathbf{x}_{0})}\Big] \end{aligned}$$

The last step uses Jensen's inequality (for a derivation that does not rely on this inequality, see the blog What are Diffusion Models?). For network training, the training objective is the negative of the VLB:

$$L=-L_{\text{VLB}}=\mathbb{E}_{q(\mathbf{x}_{1:T}\vert \mathbf{x}_{0})}\Big[-\log \frac{p_\theta(\mathbf{x}_{0:T})}{q(\mathbf{x}_{1:T}\vert \mathbf{x}_{0})}\Big]=\mathbb{E}_{q(\mathbf{x}_{1:T}\vert \mathbf{x}_{0})}\Big[\log \frac{q(\mathbf{x}_{1:T}\vert \mathbf{x}_{0})}{p_\theta(\mathbf{x}_{0:T})}\Big]$$

Decomposing the training objective further, we obtain:

$$\begin{aligned} L &= \mathbb{E}_{q(\mathbf{x}_{1:T}\vert \mathbf{x}_{0})} \Big[ \log\frac{q(\mathbf{x}_{1:T}\vert\mathbf{x}_0)}{p_\theta(\mathbf{x}_{0:T})} \Big] \\ &= \mathbb{E}_{q(\mathbf{x}_{1:T}\vert \mathbf{x}_{0})} \Big[ \log\frac{\prod_{t=1}^T q(\mathbf{x}_t\vert\mathbf{x}_{t-1})}{ p_\theta(\mathbf{x}_T) \prod_{t=1}^T p_\theta(\mathbf{x}_{t-1} \vert\mathbf{x}_t) } \Big] \\ &= \mathbb{E}_{q(\mathbf{x}_{1:T}\vert \mathbf{x}_{0})} \Big[ -\log p_\theta(\mathbf{x}_T) + \sum_{t=1}^T \log \frac{q(\mathbf{x}_t\vert\mathbf{x}_{t-1})}{p_\theta(\mathbf{x}_{t-1} \vert\mathbf{x}_t)} \Big] \\ &= \mathbb{E}_{q(\mathbf{x}_{1:T}\vert \mathbf{x}_{0})} \Big[ -\log p_\theta(\mathbf{x}_T) + \sum_{t=2}^T \log \frac{q(\mathbf{x}_t\vert\mathbf{x}_{t-1})}{p_\theta(\mathbf{x}_{t-1} \vert\mathbf{x}_t)} + \log\frac{q(\mathbf{x}_1 \vert \mathbf{x}_0)}{p_\theta(\mathbf{x}_0 \vert \mathbf{x}_1)} \Big] \\ &= \mathbb{E}_{q(\mathbf{x}_{1:T}\vert \mathbf{x}_{0})} \Big[ -\log p_\theta(\mathbf{x}_T) + \sum_{t=2}^T \log \frac{q(\mathbf{x}_t\vert\mathbf{x}_{t-1}, \mathbf{x}_{0})}{p_\theta(\mathbf{x}_{t-1} \vert\mathbf{x}_t)} + \log\frac{q(\mathbf{x}_1 \vert \mathbf{x}_0)}{p_\theta(\mathbf{x}_0 \vert \mathbf{x}_1)} \Big] & \text{use } q(\mathbf{x}_t \vert \mathbf{x}_{t-1}, \mathbf{x}_0)=q(\mathbf{x}_t \vert \mathbf{x}_{t-1}) \\ &= \mathbb{E}_{q(\mathbf{x}_{1:T}\vert \mathbf{x}_{0})} \Big[ -\log p_\theta(\mathbf{x}_T) + \sum_{t=2}^T \log \Big( \frac{q(\mathbf{x}_{t-1} \vert \mathbf{x}_t, \mathbf{x}_0)}{p_\theta(\mathbf{x}_{t-1} \vert\mathbf{x}_t)}\cdot \frac{q(\mathbf{x}_t \vert \mathbf{x}_0)}{q(\mathbf{x}_{t-1}\vert\mathbf{x}_0)} \Big) + \log \frac{q(\mathbf{x}_1 \vert \mathbf{x}_0)}{p_\theta(\mathbf{x}_0 \vert \mathbf{x}_1)} \Big] & \text{use Bayes' rule} \\ &= \mathbb{E}_{q(\mathbf{x}_{1:T}\vert \mathbf{x}_{0})} \Big[ -\log p_\theta(\mathbf{x}_T) + \sum_{t=2}^T \log \frac{q(\mathbf{x}_{t-1} \vert \mathbf{x}_t, \mathbf{x}_0)}{p_\theta(\mathbf{x}_{t-1} \vert\mathbf{x}_t)} + \sum_{t=2}^T \log \frac{q(\mathbf{x}_t \vert \mathbf{x}_0)}{q(\mathbf{x}_{t-1} \vert \mathbf{x}_0)} + \log\frac{q(\mathbf{x}_1 \vert \mathbf{x}_0)}{p_\theta(\mathbf{x}_0 \vert \mathbf{x}_1)} \Big] \\ &= \mathbb{E}_{q(\mathbf{x}_{1:T}\vert \mathbf{x}_{0})} \Big[ -\log p_\theta(\mathbf{x}_T) + \sum_{t=2}^T \log \frac{q(\mathbf{x}_{t-1} \vert \mathbf{x}_t, \mathbf{x}_0)}{p_\theta(\mathbf{x}_{t-1} \vert\mathbf{x}_t)} + \log\frac{q(\mathbf{x}_T \vert \mathbf{x}_0)}{q(\mathbf{x}_1 \vert \mathbf{x}_0)} + \log \frac{q(\mathbf{x}_1 \vert \mathbf{x}_0)}{p_\theta(\mathbf{x}_0 \vert \mathbf{x}_1)} \Big] \\ &= \mathbb{E}_{q(\mathbf{x}_{1:T}\vert \mathbf{x}_{0})} \Big[ \log\frac{q(\mathbf{x}_T \vert \mathbf{x}_0)}{p_\theta(\mathbf{x}_T)} + \sum_{t=2}^T \log \frac{q(\mathbf{x}_{t-1} \vert \mathbf{x}_t, \mathbf{x}_0)}{p_\theta(\mathbf{x}_{t-1} \vert\mathbf{x}_t)} - \log p_\theta(\mathbf{x}_0 \vert \mathbf{x}_1) \Big] \\ &= \mathbb{E}_{q(\mathbf{x}_{T}\vert \mathbf{x}_{0})}\Big[\log\frac{q(\mathbf{x}_T \vert \mathbf{x}_0)}{p_\theta(\mathbf{x}_T)}\Big]+\sum_{t=2}^T \mathbb{E}_{q(\mathbf{x}_{t}, \mathbf{x}_{t-1}\vert \mathbf{x}_{0})}\Big[\log \frac{q(\mathbf{x}_{t-1} \vert \mathbf{x}_t, \mathbf{x}_0)}{p_\theta(\mathbf{x}_{t-1} \vert\mathbf{x}_t)}\Big] - \mathbb{E}_{q(\mathbf{x}_{1}\vert \mathbf{x}_{0})}\Big[\log p_\theta(\mathbf{x}_0 \vert \mathbf{x}_1)\Big] \\ &= \underbrace{D_\text{KL}(q(\mathbf{x}_T \vert \mathbf{x}_0) \parallel p_\theta(\mathbf{x}_T))}_{L_T} + \sum_{t=2}^T \underbrace{\mathbb{E}_{q(\mathbf{x}_{t}\vert \mathbf{x}_{0})}\Big[D_\text{KL}(q(\mathbf{x}_{t-1} \vert \mathbf{x}_t, \mathbf{x}_0) \parallel p_\theta(\mathbf{x}_{t-1} \vert\mathbf{x}_t))\Big]}_{L_{t-1}} -\underbrace{\mathbb{E}_{q(\mathbf{x}_{1}\vert \mathbf{x}_{0})}\log p_\theta(\mathbf{x}_0 \vert \mathbf{x}_1)}_{L_0} \end{aligned}$$

In the last step, the joint expectation $\mathbb{E}_{q(\mathbf{x}_t, \mathbf{x}_{t-1}\vert \mathbf{x}_0)}$ is written as $\mathbb{E}_{q(\mathbf{x}_t\vert \mathbf{x}_0)}\mathbb{E}_{q(\mathbf{x}_{t-1}\vert \mathbf{x}_t, \mathbf{x}_0)}$, and the inner expectation of the log-ratio is exactly a KL divergence.

The final optimization objective therefore contains $T+1$ terms. $L_0$ can be viewed as a reconstruction of the original data, optimizing a negative log-likelihood; it can be computed by building a discretized decoder from the estimated $\mathcal{N}(\mathbf{x}_0; \boldsymbol{\mu}_\theta(\mathbf{x}_1, 1), \boldsymbol{\Sigma}_\theta(\mathbf{x}_1, 1))$ (see Section 3.3 of the DDPM paper):

$$p_{\theta}(\mathbf{x}_0\vert\mathbf{x}_1)=\prod^D_{i=1}\int ^{\delta_+(x_0^i)}_{\delta_-(x_0^i)}\mathcal{N}(x; \mu^i_\theta(\mathbf{x}_1, 1), \Sigma^i_\theta(\mathbf{x}_1, 1))\,dx$$
$$\delta_+(x)= \begin{cases} \infty & \text{ if } x=1 \\ x+\frac{1}{255} & \text{ if } x <1 \end{cases} \qquad \delta_-(x)= \begin{cases} -\infty & \text{ if } x=-1 \\ x-\frac{1}{255} & \text{ if } x >-1 \end{cases}$$

In DDPM, the original image pixel values are normalized from [0, 255] to [-1, 1]. Since pixel values are discrete, the spacing between adjacent values is 2/255, and we can compute the probability mass (difference of CDFs) of the Gaussian falling in an interval of width 2/255 centered at the ground-truth value; see github.com/hojonathanho for the concrete implementation (the simplified training objective used later does not compute this log-likelihood, though).
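As a rough sketch of this discretized decoder (not the exact reference implementation; the helper names and the erf-based CDF below are our own choices), the log-probability of each pixel is the Gaussian mass on its 1/255-wide bin:

def approx_std_normal_cdf(x):
    # standard normal CDF via the error function
    return 0.5 * (1.0 + torch.erf(x / 2.0 ** 0.5))

def discretized_gaussian_log_likelihood(x, mean, log_scale):
    # x: ground-truth image scaled to [-1, 1]; mean / log_scale: predicted Gaussian parameters
    centered = x - mean
    inv_std = torch.exp(-log_scale)
    cdf_plus = approx_std_normal_cdf(inv_std * (centered + 1.0 / 255.0))
    cdf_min = approx_std_normal_cdf(inv_std * (centered - 1.0 / 255.0))
    # edge bins: x = -1 integrates from -inf, x = 1 integrates to +inf
    log_probs = torch.where(
        x < -0.999, torch.log(cdf_plus.clamp(min=1e-12)),
        torch.where(x > 0.999, torch.log((1.0 - cdf_min).clamp(min=1e-12)),
                    torch.log((cdf_plus - cdf_min).clamp(min=1e-12))))
    return log_probs  # sum over pixels to get log p(x_0 | x_1)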

$L_T$ is the KL divergence between the distribution of the final noisy sample and the prior; it has no trainable parameters and is approximately 0, because the prior is $p(\mathbf{x}_T)=\mathcal{N}(\mathbf{0}, \mathbf{I})$ and the random noise $q(\mathbf{x}_T\vert \mathbf{x}_0)$ obtained at the end of the diffusion process is also approximately $\mathcal{N}(\mathbf{0}, \mathbf{I})$. $L_{t-1}$ is the KL divergence between the estimated distribution $p_\theta(\mathbf{x}_{t-1} \vert\mathbf{x}_t)$ and the true posterior $q(\mathbf{x}_{t-1} \vert \mathbf{x}_t, \mathbf{x}_0)$; here we want the estimated denoising process to match the denoising process that depends on the real data.

The reason we earlier defined $p_\theta(\mathbf{x}_{t-1} \vert\mathbf{x}_t)$ as a Gaussian $\mathcal{N}(\mathbf{x}_{t-1}; \boldsymbol{\mu}_\theta(\mathbf{x}_t, t), \boldsymbol{\Sigma}_\theta(\mathbf{x}_t, t))$ parameterized by a network is that the posterior $q(\mathbf{x}_{t-1} \vert \mathbf{x}_t, \mathbf{x}_0)$ it has to match is also Gaussian. For both training objectives $L_0$ and $L_{t-1}$, what we want are the trained networks $\boldsymbol{\mu}_\theta(\mathbf{x}_t, t)$ and $\boldsymbol{\Sigma}_\theta(\mathbf{x}_t, t)$ (for $L_0$, $t=1$). DDPM further simplifies $p_\theta(\mathbf{x}_{t-1} \vert\mathbf{x}_t)$ by using a fixed variance $\boldsymbol{\Sigma}_\theta(\mathbf{x}_t, t)=\sigma_t^2\mathbf{I}$, where $\sigma_t^2$ can be set to $\beta_t$ or $\tilde{\beta}_t$ (these are in fact two extremes, an upper and a lower bound; a learned variance is also possible, see Improved Denoising Diffusion Probabilistic Models and Analytic-DPM: an Analytic Estimate of the Optimal Reverse Variance in Diffusion Probabilistic Models). Assuming $\sigma_t^2=\tilde{\beta}_t$, we have:

$$q(\mathbf{x}_{t-1} \vert \mathbf{x}_t, \mathbf{x}_0)=\mathcal{N}(\mathbf{x}_{t-1}; \tilde{\boldsymbol{\mu}} (\mathbf{x}_t, \mathbf{x}_0), \sigma_t^2 \mathbf{I}), \qquad p_\theta(\mathbf{x}_{t-1} \vert \mathbf{x}_t) = \mathcal{N}(\mathbf{x}_{t-1}; \boldsymbol{\mu}_\theta(\mathbf{x}_t, t), \sigma_t^2 \mathbf{I})$$

For two Gaussian distributions, the KL divergence is given by (for the detailed derivation see the companion post 生成模型之VAE):

$$\text{KL}(p_1\,\|\,p_2) = \frac{1}{2}\Big(\mathrm{tr}(\boldsymbol{\Sigma}_2^{-1}\boldsymbol{\Sigma}_1)+(\boldsymbol{\mu}_2-\boldsymbol{\mu}_1)^{\top}\boldsymbol{\Sigma}_2^{-1}(\boldsymbol{\mu}_2-\boldsymbol{\mu}_1)-n+\log\frac{\det(\boldsymbol{\Sigma}_2)}{\det(\boldsymbol{\Sigma}_1)}\Big)$$

Then we have:

$$\begin{aligned} D_\text{KL}(q(\mathbf{x}_{t-1} \vert \mathbf{x}_t, \mathbf{x}_0)\parallel p_\theta(\mathbf{x}_{t-1} \vert \mathbf{x}_t)) &=D_\text{KL}(\mathcal{N}(\mathbf{x}_{t-1}; \tilde{\boldsymbol{\mu}}(\mathbf{x}_t, \mathbf{x}_0), \sigma_t^2 \mathbf{I}) \parallel \mathcal{N}(\mathbf{x}_{t-1}; \boldsymbol{\mu}_\theta(\mathbf{x}_t, t), \sigma_t^2 \mathbf{I})) \\ &=\frac{1}{2}\Big(n+\frac{1}{\sigma_t^2}\|\tilde{\boldsymbol{\mu}}_t(\mathbf{x}_t, \mathbf{x}_0) - \boldsymbol{\mu}_\theta(\mathbf{x}_t, t)\|^2 -n+\log 1\Big) \\ &=\frac{1}{2\sigma_t^2}\|\tilde{\boldsymbol{\mu}}_t(\mathbf{x}_t, \mathbf{x}_0) - \boldsymbol{\mu}_\theta(\mathbf{x}_t, t)\|^2 \end{aligned}$$

So the objective $L_{t-1}$ becomes:

$$L_{t-1}=\mathbb{E}_{q(\mathbf{x}_{t}\vert \mathbf{x}_{0})}\Big[ \frac{1}{2\sigma_t^2}\|\tilde{\boldsymbol{\mu}}_t(\mathbf{x}_t, \mathbf{x}_0) - \boldsymbol{\mu}_\theta(\mathbf{x}_t, t)\|^2\Big]$$

From this formula, we want the mean $\boldsymbol{\mu}_\theta(\mathbf{x}_t, t)$ learned by the network to match the posterior mean $\tilde{\boldsymbol{\mu}}(\mathbf{x}_t, \mathbf{x}_0)$. However, DDPM found that predicting the mean is not the best choice. From the diffusion-process property derived earlier, we have:

$$\mathbf{x}_t(\mathbf{x}_0,\boldsymbol{\epsilon})=\sqrt{\bar{\alpha}_t}\mathbf{x}_0 + \sqrt{1 - \bar{\alpha}_t}\,\boldsymbol{\epsilon}, \qquad \text{where } \boldsymbol{\epsilon}\sim \mathcal{N}(\mathbf{0}, \mathbf{I})$$

Substituting this into the objective above (note that the loss now also takes the expectation over $\mathbf{x}_0$), we obtain:

$$\begin{aligned} L_{t-1}&=\mathbb{E}_{\mathbf{x}_{0}}\Big(\mathbb{E}_{q(\mathbf{x}_{t}\vert \mathbf{x}_{0})}\Big[ \frac{1}{2\sigma_t^2}\|\tilde{\boldsymbol{\mu}}_t(\mathbf{x}_t, \mathbf{x}_0) - \boldsymbol{\mu}_\theta(\mathbf{x}_t, t)\|^2\Big]\Big) \\ &=\mathbb{E}_{\mathbf{x}_{0},\boldsymbol{\epsilon}\sim \mathcal{N}(\mathbf{0}, \mathbf{I})}\Big[ \frac{1}{2\sigma_t^2}\Big\|\tilde{\boldsymbol{\mu}}_t\Big(\mathbf{x}_t(\mathbf{x}_0,\boldsymbol{\epsilon}), \frac{1}{\sqrt{\bar \alpha_t}} \big(\mathbf{x}_t(\mathbf{x}_0,\boldsymbol{\epsilon}) - \sqrt{1 - \bar{\alpha}_t}\, \boldsymbol{\epsilon} \big)\Big) - \boldsymbol{\mu}_\theta(\mathbf{x}_t(\mathbf{x}_0,\boldsymbol{\epsilon}), t)\Big\|^2\Big] \\ &=\mathbb{E}_{\mathbf{x}_{0},\boldsymbol{\epsilon}\sim \mathcal{N}(\mathbf{0}, \mathbf{I})}\Big[ \frac{1}{2\sigma_t^2}\Big\|\Big(\frac{\sqrt{\alpha_t}(1 - \bar{\alpha}_{t-1})}{1 - \bar{\alpha}_t} \mathbf{x}_t(\mathbf{x}_0,\boldsymbol{\epsilon}) + \frac{\sqrt{\bar{\alpha}_{t-1}}\beta_t}{1 - \bar{\alpha}_t} \frac{1}{\sqrt{\bar \alpha_t}} \big(\mathbf{x}_t(\mathbf{x}_0,\boldsymbol{\epsilon}) - \sqrt{1 - \bar{\alpha}_t}\, \boldsymbol{\epsilon} \big) \Big) - \boldsymbol{\mu}_\theta(\mathbf{x}_t(\mathbf{x}_0,\boldsymbol{\epsilon}), t)\Big\|^2\Big] \\ &=\mathbb{E}_{\mathbf{x}_{0},\boldsymbol{\epsilon}\sim \mathcal{N}(\mathbf{0}, \mathbf{I})}\Big[ \frac{1}{2\sigma_t^2}\Big\|\frac{1}{\sqrt{\alpha_t}}\Big( \mathbf{x}_t(\mathbf{x}_0,\boldsymbol{\epsilon}) - \frac{\beta_t}{\sqrt{1 - \bar{\alpha}_t}}\boldsymbol{\epsilon}\Big) - \boldsymbol{\mu}_\theta(\mathbf{x}_t(\mathbf{x}_0,\boldsymbol{\epsilon}), t)\Big\|^2\Big] \end{aligned}$$

Furthermore, we also reparameterize $\boldsymbol{\mu}_\theta(\mathbf{x}_t(\mathbf{x}_0,\boldsymbol{\epsilon}), t)$ as:

$$\boldsymbol{\mu}_\theta(\mathbf{x}_t(\mathbf{x}_0,\boldsymbol{\epsilon}), t)=\frac{1}{\sqrt{\alpha_t}}\Big( \mathbf{x}_t(\mathbf{x}_0,\boldsymbol{\epsilon}) - \frac{\beta_t}{\sqrt{1 - \bar{\alpha}_t}}\boldsymbol{\epsilon}_\theta\big(\mathbf{x}_t(\mathbf{x}_0,\boldsymbol{\epsilon}), t\big)\Big)$$

Here $\boldsymbol{\epsilon}_\theta$ is a neural-network-based function approximator, which means we switch from predicting the mean to predicting the noise $\boldsymbol{\epsilon}$. Substituting this into the objective gives:

$$\begin{aligned} L_{t-1}&=\mathbb{E}_{\mathbf{x}_{0},\boldsymbol{\epsilon}\sim \mathcal{N}(\mathbf{0}, \mathbf{I})}\Big[ \frac{1}{2\sigma_t^2}\Big\|\frac{1}{\sqrt{\alpha_t}}\Big( \mathbf{x}_t(\mathbf{x}_0,\boldsymbol{\epsilon}) - \frac{\beta_t}{\sqrt{1 - \bar{\alpha}_t}}\boldsymbol{\epsilon}\Big) - \boldsymbol{\mu}_\theta(\mathbf{x}_t(\mathbf{x}_0,\boldsymbol{\epsilon}), t)\Big\|^2\Big] \\ &= \mathbb{E}_{\mathbf{x}_{0},\boldsymbol{\epsilon}\sim \mathcal{N}(\mathbf{0}, \mathbf{I})}\Big[ \frac{\beta_t^2}{2\sigma_t^2\alpha_t(1-\bar{\alpha}_t)}\big\| \boldsymbol{\epsilon}- \boldsymbol{\epsilon}_\theta\big(\mathbf{x}_t(\mathbf{x}_0,\boldsymbol{\epsilon}), t\big)\big\|^2\Big]\\ &=\mathbb{E}_{\mathbf{x}_{0},\boldsymbol{\epsilon}\sim \mathcal{N}(\mathbf{0}, \mathbf{I})}\Big[ \frac{\beta_t^2}{2\sigma_t^2\alpha_t(1-\bar{\alpha}_t)}\big\| \boldsymbol{\epsilon}- \boldsymbol{\epsilon}_\theta\big(\sqrt{\bar{\alpha}_t}\mathbf{x}_0 + \sqrt{1 - \bar{\alpha}_t}\,\boldsymbol{\epsilon}, t\big)\big\|^2\Big] \end{aligned}$$

DDPM further simplifies this objective by dropping the weighting coefficient:

$$L_{t-1}^{\text{simple}}=\mathbb{E}_{\mathbf{x}_{0},\boldsymbol{\epsilon}\sim \mathcal{N}(\mathbf{0}, \mathbf{I})}\Big[ \big\| \boldsymbol{\epsilon}- \boldsymbol{\epsilon}_\theta\big(\sqrt{\bar{\alpha}_t}\mathbf{x}_0 + \sqrt{1 - \bar{\alpha}_t}\,\boldsymbol{\epsilon}, t\big)\big\|^2\Big]$$

Here $t$ ranges over [1, T] (as noted above, $t=1$ corresponds to $L_0$). Since the weighting coefficients for different $t$ are dropped, this simplified objective is in fact a reweighted version of the VLB objective. DDPM's ablation results show that predicting the noise works better than predicting the mean, and that the simplified objective works better than the full VLB objective.

Although the derivation behind diffusion models is fairly involved, the final optimization objective is very simple: make the noise predicted by the network match the true noise. The DDPM training procedure is also very simple, as shown in the figure below: randomly pick a training sample -> uniformly sample a timestep t from 1 to T -> sample random noise and compute the noisy data at step t (red box) -> feed it into the network to predict the noise -> compute the L2 loss between the sampled noise and the predicted noise -> compute gradients and update the network.

Once training is complete, sampling is equally simple, as shown above: we start from random noise, use the trained network to predict the noise, compute the mean of the conditional distribution (red box), then add the standard deviation times a fresh random noise to that mean, repeating until t=0 to obtain a new sample (no noise is added at the last step). The actual code differs slightly from this procedure (see github.com/hojonathanho): it first reconstructs $\mathbf{x}_0$ from the predicted noise and clips it to [-1, 1] (the range the original data was normalized to), and only then computes the mean. My personal understanding is that this acts as a constraint: since the model predicts the noise, we also want the original data reconstructed from that predicted noise to stay within the valid range.
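Putting the pieces together, each step of the vanilla sampler (ignoring the $\mathbf{x}_0$ clipping trick) can be written as

$$\mathbf{x}_{t-1} = \frac{1}{\sqrt{\alpha_t}}\Big(\mathbf{x}_t - \frac{\beta_t}{\sqrt{1-\bar{\alpha}_t}}\,\boldsymbol{\epsilon}_\theta(\mathbf{x}_t, t)\Big) + \sigma_t \mathbf{z}, \qquad \mathbf{z}\sim\mathcal{N}(\mathbf{0},\mathbf{I}) \text{ for } t>1, \quad \mathbf{z}=\mathbf{0} \text{ for } t=1$$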

Model Design

We have covered the principle and the optimization objective of diffusion models; the core task is therefore training the noise prediction model. Since the noise has the same dimensionality as the original data, we can use an autoencoder-style architecture as the noise prediction model. DDPM uses a U-Net built from residual blocks and attention blocks, as shown in the figure below.

U-Net is an encoder-decoder architecture. The encoder is divided into several stages, each containing a downsampling module that reduces the spatial size of the features (H and W); the decoder does the opposite and gradually restores the features compressed by the encoder. The decoder also uses skip connections, concatenating the same-resolution features produced by the encoder, which helps optimization. In DDPM's U-Net, each stage contains 2 residual blocks, and some stages additionally include self-attention modules to increase the network's global modeling capability. Moreover, a diffusion model in principle needs $T$ noise prediction models; in practice, we add a time embedding (similar to the position embedding in Transformers) to encode the timestep into the network, so that only one shared U-Net needs to be trained. Concretely, DDPM injects the time embedding into every residual block, as shown in the figure above.

Code Implementation

Finally, we give a concrete PyTorch implementation of DDPM, drawing mainly on three existing code implementations.

First, the time embedding: we use the sinusoidal position embedding designed in Attention Is All You Need, except that here it encodes the timestep:

import math
from abc import abstractmethod

import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm

# use sinusoidal position embedding to encode time step (https://arxiv.org/abs/1706.03762)
def timestep_embedding(timesteps, dim, max_period=10000):
    """
    Create sinusoidal timestep embeddings.
    :param timesteps: a 1-D Tensor of N indices, one per batch element.
                      These may be fractional.
    :param dim: the dimension of the output.
    :param max_period: controls the minimum frequency of the embeddings.
    :return: an [N x dim] Tensor of positional embeddings.
    """
    half = dim // 2
    freqs = torch.exp(
        -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
    ).to(device=timesteps.device)
    args = timesteps[:, None].float() * freqs[None]
    embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
    if dim % 2:
        embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
    return embedding

Since only the residual blocks take the time embedding, we define some helper modules to handle it automatically, as shown below:

# define TimestepEmbedSequential to support `time_emb` as extra input
class TimestepBlock(nn.Module):
    """
    Any module where forward() takes timestep embeddings as a second argument.
    """

    @abstractmethod
    def forward(self, x, emb):
        """
        Apply the module to `x` given `emb` timestep embeddings.
        """


class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
    """
    A sequential module that passes timestep embeddings to the children that
    support it as an extra input.
    """

    def forward(self, x, emb):
        for layer in self:
            if isinstance(layer, TimestepBlock):
                x = layer(x, emb)
            else:
                x = layer(x)
        return x

The U-Net here uses GroupNorm for normalization, so we also define a simple norm layer helper for convenience:

# use GN for norm layer
def norm_layer(channels):
    return nn.GroupNorm(32, channels)

The core module of the U-Net is the residual block, which contains two convolution layers and a shortcut, and also takes the time embedding. An extra linear layer projects the time embedding to the same dimension as the features, and the time embedding is added after the first convolution:

# Residual block
class ResidualBlock(TimestepBlock):
    def __init__(self, in_channels, out_channels, time_channels, dropout):
        super().__init__()
        self.conv1 = nn.Sequential(
            norm_layer(in_channels),
            nn.SiLU(),
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        )
        
        # projection for time step embedding
        self.time_emb = nn.Sequential(
            nn.SiLU(),
            nn.Linear(time_channels, out_channels)
        )
        
        self.conv2 = nn.Sequential(
            norm_layer(out_channels),
            nn.SiLU(),
            nn.Dropout(p=dropout),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
        )

        if in_channels != out_channels:
            self.shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1)
        else:
            self.shortcut = nn.Identity()


    def forward(self, x, t):
        """
        `x` has shape `[batch_size, in_dim, height, width]`
        `t` has shape `[batch_size, time_dim]`
        """
        h = self.conv1(x)
        # Add time step embeddings
        h += self.time_emb(t)[:, :, None, None]
        h = self.conv2(h)
        return h + self.shortcut(x)

Attention is added after some residual blocks; the attention here is the same as the self-attention in Transformers:

# Attention block with shortcut
class AttentionBlock(nn.Module):
    def __init__(self, channels, num_heads=1):
        super().__init__()
        self.num_heads = num_heads
        assert channels % num_heads == 0
        
        self.norm = norm_layer(channels)
        self.qkv = nn.Conv2d(channels, channels * 3, kernel_size=1, bias=False)
        self.proj = nn.Conv2d(channels, channels, kernel_size=1)

    def forward(self, x):
        B, C, H, W = x.shape
        qkv = self.qkv(self.norm(x))
        q, k, v = qkv.reshape(B*self.num_heads, -1, H*W).chunk(3, dim=1)
        scale = 1. / math.sqrt(math.sqrt(C // self.num_heads))
        attn = torch.einsum("bct,bcs->bts", q * scale, k * scale)
        attn = attn.softmax(dim=-1)
        h = torch.einsum("bts,bcs->bct", attn, v)
        h = h.reshape(B, -1, H, W)
        h = self.proj(h)
        return h + x

The upsampling module can be implemented with interpolation (optionally followed by a conv), and the downsampling module with a stride-2 conv or pooling:

# upsample
class Upsample(nn.Module):
    def __init__(self, channels, use_conv):
        super().__init__()
        self.use_conv = use_conv
        if use_conv:
            self.conv = nn.Conv2d(channels, channels, kernel_size=3, padding=1)

    def forward(self, x):
        x = F.interpolate(x, scale_factor=2, mode="nearest")
        if self.use_conv:
            x = self.conv(x)
        return x

# downsample
class Downsample(nn.Module):
    def __init__(self, channels, use_conv):
        super().__init__()
        self.use_conv = use_conv
        if use_conv:
            self.op = nn.Conv2d(channels, channels, kernel_size=3, stride=2, padding=1)
        else:
            self.op = nn.AvgPool2d(kernel_size=2, stride=2)

    def forward(self, x):
        return self.op(x)

With all the U-Net components implemented above, we can now assemble the full U-Net:

# The full UNet model with attention and timestep embedding
class UNetModel(nn.Module):
    def __init__(
        self,
        in_channels=3,
        model_channels=128,
        out_channels=3,
        num_res_blocks=2,
        attention_resolutions=(8, 16),
        dropout=0,
        channel_mult=(1, 2, 2, 2),
        conv_resample=True,
        num_heads=4
    ):
        super().__init__()

        self.in_channels = in_channels
        self.model_channels = model_channels
        self.out_channels = out_channels
        self.num_res_blocks = num_res_blocks
        self.attention_resolutions = attention_resolutions
        self.dropout = dropout
        self.channel_mult = channel_mult
        self.conv_resample = conv_resample
        self.num_heads = num_heads
        
        # time embedding
        time_embed_dim = model_channels * 4
        self.time_embed = nn.Sequential(
            nn.Linear(model_channels, time_embed_dim),
            nn.SiLU(),
            nn.Linear(time_embed_dim, time_embed_dim),
        )
        
        # down blocks
        self.down_blocks = nn.ModuleList([
            TimestepEmbedSequential(nn.Conv2d(in_channels, model_channels, kernel_size=3, padding=1))
        ])
        down_block_chans = [model_channels]
        ch = model_channels
        ds = 1
        for level, mult in enumerate(channel_mult):
            for _ in range(num_res_blocks):
                layers = [
                    ResidualBlock(ch, mult * model_channels, time_embed_dim, dropout)
                ]
                ch = mult * model_channels
                if ds in attention_resolutions:
                    layers.append(AttentionBlock(ch, num_heads=num_heads))
                self.down_blocks.append(TimestepEmbedSequential(*layers))
                down_block_chans.append(ch)
            if level != len(channel_mult) - 1: # don't use downsample for the last stage
                self.down_blocks.append(TimestepEmbedSequential(Downsample(ch, conv_resample)))
                down_block_chans.append(ch)
                ds *= 2
        
        # middle block
        self.middle_block = TimestepEmbedSequential(
            ResidualBlock(ch, ch, time_embed_dim, dropout),
            AttentionBlock(ch, num_heads=num_heads),
            ResidualBlock(ch, ch, time_embed_dim, dropout)
        )
        
        # up blocks
        self.up_blocks = nn.ModuleList([])
        for level, mult in list(enumerate(channel_mult))[::-1]:
            for i in range(num_res_blocks + 1):
                layers = [
                    ResidualBlock(
                        ch + down_block_chans.pop(),
                        model_channels * mult,
                        time_embed_dim,
                        dropout
                    )
                ]
                ch = model_channels * mult
                if ds in attention_resolutions:
                    layers.append(AttentionBlock(ch, num_heads=num_heads))
                if level and i == num_res_blocks:
                    layers.append(Upsample(ch, conv_resample))
                    ds //= 2
                self.up_blocks.append(TimestepEmbedSequential(*layers))

        self.out = nn.Sequential(
            norm_layer(ch),
            nn.SiLU(),
            nn.Conv2d(model_channels, out_channels, kernel_size=3, padding=1),
        )

    def forward(self, x, timesteps):
        """
        Apply the model to an input batch.
        :param x: an [N x C x H x W] Tensor of inputs.
        :param timesteps: a 1-D batch of timesteps.
        :return: an [N x C x ...] Tensor of outputs.
        """
        hs = []
        # time step embedding
        emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))
        
        # down stage
        h = x
        for module in self.down_blocks:
            h = module(h, emb)
            hs.append(h)
        # middle stage
        h = self.middle_block(h, emb)
        # up stage
        for module in self.up_blocks:
            cat_in = torch.cat([h, hs.pop()], dim=1)
            h = module(cat_in, emb)
        return self.out(h)
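A quick shape sanity check for the model (a minimal sketch; the hyperparameters below are arbitrary small values chosen for illustration, not the ones used in the experiments):

model = UNetModel(
    in_channels=1,
    model_channels=96,
    out_channels=1,
    channel_mult=(1, 2, 2),
    attention_resolutions=[]
)
x = torch.randn(4, 1, 28, 28)            # a batch of 28x28 grayscale images
t = torch.randint(0, 500, (4,)).long()   # one timestep per sample
print(model(x, t).shape)                 # torch.Size([4, 1, 28, 28]), same shape as the input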

For the diffusion process, the main hyperparameters are the number of timesteps and the noise schedule. DDPM uses a linear noise schedule over [0.0001, 0.02], with 1000 diffusion steps by default:

# beta schedule
def linear_beta_schedule(timesteps):
    scale = 1000 / timesteps
    beta_start = scale * 0.0001
    beta_end = scale * 0.02
    return torch.linspace(beta_start, beta_end, timesteps, dtype=torch.float64)
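The GaussianDiffusion class below also accepts a 'cosine' schedule; a sketch of cosine_beta_schedule following the form proposed in Improved Denoising Diffusion Probabilistic Models (the offset s=0.008 and the 0.999 clipping come from that paper) could look like:

# cosine schedule (https://arxiv.org/abs/2102.09672): define alpha_bar first, then recover betas
def cosine_beta_schedule(timesteps, s=0.008):
    steps = timesteps + 1
    x = torch.linspace(0, timesteps, steps, dtype=torch.float64)
    alphas_cumprod = torch.cos(((x / timesteps) + s) / (1 + s) * math.pi * 0.5) ** 2
    alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
    betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
    return torch.clip(betas, 0, 0.999)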

Next we define the diffusion model class. It precomputes the coefficients required by the chosen noise schedule and implements the diffusion (forward) process and the generation (reverse) process:

class GaussianDiffusion:
    def __init__(
        self,
        timesteps=1000,
        beta_schedule='linear'
    ):
        self.timesteps = timesteps
        
        if beta_schedule == 'linear':
            betas = linear_beta_schedule(timesteps)
        elif beta_schedule == 'cosine':
            betas = cosine_beta_schedule(timesteps)
        else:
            raise ValueError(f'unknown beta schedule {beta_schedule}')
        self.betas = betas
            
        self.alphas = 1. - self.betas
        self.alphas_cumprod = torch.cumprod(self.alphas, axis=0)
        self.alphas_cumprod_prev = F.pad(self.alphas_cumprod[:-1], (1, 0), value=1.)
        
        # calculations for diffusion q(x_t | x_{t-1}) and others
        self.sqrt_alphas_cumprod = torch.sqrt(self.alphas_cumprod)
        self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - self.alphas_cumprod)
        self.log_one_minus_alphas_cumprod = torch.log(1.0 - self.alphas_cumprod)
        self.sqrt_recip_alphas_cumprod = torch.sqrt(1.0 / self.alphas_cumprod)
        self.sqrt_recipm1_alphas_cumprod = torch.sqrt(1.0 / self.alphas_cumprod - 1)
        
        # calculations for posterior q(x_{t-1} | x_t, x_0)
        self.posterior_variance = (
            self.betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
        )
        # below: log calculation clipped because the posterior variance is 0 at the beginning
        # of the diffusion chain
        self.posterior_log_variance_clipped = torch.log(self.posterior_variance.clamp(min =1e-20))
        
        self.posterior_mean_coef1 = (
            self.betas * torch.sqrt(self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
        )
        self.posterior_mean_coef2 = (
            (1.0 - self.alphas_cumprod_prev)
            * torch.sqrt(self.alphas)
            / (1.0 - self.alphas_cumprod)
        )
    
    # get the param of given timestep t
    def _extract(self, a, t, x_shape):
        batch_size = t.shape[0]
        out = a.to(t.device).gather(0, t).float()
        out = out.reshape(batch_size, *((1,) * (len(x_shape) - 1)))
        return out
    
    # forward diffusion (using the nice property): q(x_t | x_0)
    def q_sample(self, x_start, t, noise=None):
        if noise is None:
            noise = torch.randn_like(x_start)

        sqrt_alphas_cumprod_t = self._extract(self.sqrt_alphas_cumprod, t, x_start.shape)
        sqrt_one_minus_alphas_cumprod_t = self._extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape)

        return sqrt_alphas_cumprod_t * x_start + sqrt_one_minus_alphas_cumprod_t * noise
    
    # Get the mean and variance of q(x_t | x_0).
    def q_mean_variance(self, x_start, t):
        mean = self._extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
        variance = self._extract(1.0 - self.alphas_cumprod, t, x_start.shape)
        log_variance = self._extract(self.log_one_minus_alphas_cumprod, t, x_start.shape)
        return mean, variance, log_variance
    
    # Compute the mean and variance of the diffusion posterior: q(x_{t-1} | x_t, x_0)
    def q_posterior_mean_variance(self, x_start, x_t, t):
        posterior_mean = (
            self._extract(self.posterior_mean_coef1, t, x_t.shape) * x_start
            + self._extract(self.posterior_mean_coef2, t, x_t.shape) * x_t
        )
        posterior_variance = self._extract(self.posterior_variance, t, x_t.shape)
        posterior_log_variance_clipped = self._extract(self.posterior_log_variance_clipped, t, x_t.shape)
        return posterior_mean, posterior_variance, posterior_log_variance_clipped
    
    # compute x_0 from x_t and pred noise: the reverse of `q_sample`
    def predict_start_from_noise(self, x_t, t, noise):
        return (
            self._extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t -
            self._extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise
        )
    
    # compute predicted mean and variance of p(x_{t-1} | x_t)
    def p_mean_variance(self, model, x_t, t, clip_denoised=True):
        # predict noise using model
        pred_noise = model(x_t, t)
        # get the predicted x_0: different from the algorithm2 in the paper
        x_recon = self.predict_start_from_noise(x_t, t, pred_noise)
        if clip_denoised:
            x_recon = torch.clamp(x_recon, min=-1., max=1.)
        model_mean, posterior_variance, posterior_log_variance = \
                    self.q_posterior_mean_variance(x_recon, x_t, t)
        return model_mean, posterior_variance, posterior_log_variance
        
    # denoise_step: sample x_{t-1} from x_t and pred_noise
    @torch.no_grad()
    def p_sample(self, model, x_t, t, clip_denoised=True):
        # predict mean and variance
        model_mean, _, model_log_variance = self.p_mean_variance(model, x_t, t,
                                                    clip_denoised=clip_denoised)
        noise = torch.randn_like(x_t)
        # no noise when t == 0
        nonzero_mask = ((t != 0).float().view(-1, *([1] * (len(x_t.shape) - 1))))
        # compute x_{t-1}
        pred_img = model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
        return pred_img
    
    # denoise: reverse diffusion
    @torch.no_grad()
    def p_sample_loop(self, model, shape):
        batch_size = shape[0]
        device = next(model.parameters()).device
        # start from pure noise (for each example in the batch)
        img = torch.randn(shape, device=device)
        imgs = []
        for i in tqdm(reversed(range(0, self.timesteps)), desc='sampling loop time step', total=self.timesteps):
            img = self.p_sample(model, img, torch.full((batch_size,), i, device=device, dtype=torch.long))
            imgs.append(img.cpu().numpy())
        return imgs
    
    # sample new images
    @torch.no_grad()
    def sample(self, model, image_size, batch_size=8, channels=3):
        return self.p_sample_loop(model, shape=(batch_size, channels, image_size, image_size))
    
    # compute train losses
    def train_losses(self, model, x_start, t):
        # generate random noise
        noise = torch.randn_like(x_start)
        # get x_t
        x_noisy = self.q_sample(x_start, t, noise=noise)
        predicted_noise = model(x_noisy, t)
        loss = F.mse_loss(noise, predicted_noise)
        return loss

The main functions are summarized as follows:

  • q_sample: the diffusion process from $\mathbf{x}_0$ to $\mathbf{x}_t$;
  • q_posterior_mean_variance: computes the mean and variance of the posterior distribution;
  • predict_start_from_noise: the inverse of q_sample, reconstructing $\mathbf{x}_0$ from the predicted noise;
  • p_mean_variance: computes the mean and variance of $p_\theta(\mathbf{x}_{t-1} \vert \mathbf{x}_t)$ from the predicted noise;
  • p_sample: a single denoising step;
  • p_sample_loop: the whole denoising process, i.e. the generation process.

The training process of the diffusion model is very simple, as shown below:

# train
epochs = 10

for epoch in range(epochs):
    for step, (images, labels) in enumerate(train_loader):
        optimizer.zero_grad()
        
        batch_size = images.shape[0]
        images = images.to(device)
        
        # sample t uniformally for every example in the batch
        t = torch.randint(0, timesteps, (batch_size,), device=device).long()
        
        loss = gaussian_diffusion.train_losses(model, images, t)
        
        if step % 200 == 0:
            print("Loss:", loss.item())
            
        loss.backward()
        optimizer.step()
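The loop above assumes the model, optimizer, data loader, and diffusion helper already exist; a minimal setup sketch for an MNIST experiment (the batch size, timestep count, and learning rate here are illustrative choices, not necessarily those of the linked demo) might be:

import torchvision
from torchvision import transforms

device = "cuda" if torch.cuda.is_available() else "cpu"
timesteps = 500

# MNIST scaled to [-1, 1], matching the normalization assumed by the diffusion model
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
dataset = torchvision.datasets.MNIST(root="./data", train=True, download=True, transform=transform)
train_loader = torch.utils.data.DataLoader(dataset, batch_size=64, shuffle=True)

model = UNetModel(in_channels=1, model_channels=96, out_channels=1,
                  channel_mult=(1, 2, 2), attention_resolutions=[]).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=5e-4)
gaussian_diffusion = GaussianDiffusion(timesteps=timesteps)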

Here we use the MNIST dataset to build a simple demo (mnist-demo); below are some generated samples.

Sampling from the generation process, the following shows how a handwritten-digit image is generated from pure random noise:
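With a trained model, generation just calls the sampling loop (a minimal sketch; 16 samples of 28x28 grayscale images is an arbitrary choice):

generated = gaussian_diffusion.sample(model, image_size=28, batch_size=16, channels=1)
# `generated` is a list with one entry per timestep; the last entry holds the final images
final_imgs = generated[-1]   # numpy array of shape [16, 1, 28, 28], values roughly in [-1, 1]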

A CIFAR10 demo is also provided: ddpm_cifar10. It was only trained for 200 epochs, so the generated images are just starting to take shape.


Summary

Compared with VAE and GAN, the theory behind diffusion models is more involved, yet the optimization objective and the concrete implementation are not complicated, which makes one marvel: a pile of intricate derivations that ends in a simple conclusion. To understand diffusion models deeply, DDPM is only the starting point; there is plenty of follow-up work, such as DDIM, which accelerates sampling, and improved versions of DDPM (DDPM+, ADM).

This article draws especially on the blog post What are Diffusion Models? by an OpenAI researcher (some formulas here are adapted from it) and the paper Understanding Diffusion Models: A Unified Perspective by a Google researcher.

Note: my knowledge is limited; if there are any errors, discussion and corrections are welcome.
