DDPM (Denoising Diffusion Probabilistic Model)
Published:
Paper Reading: Denoising diffusion probabilistic models.
本文主要介绍 Denoising Diffusion Probabilistic Model (DDPM), 以图片生成为例。
给定一张图片 x0, 采用一个加噪的过程将其变为Guassian分布:
p(x1:T∣x0)=T∏t=1p(xt∣xt−1),xt∣xt−1∼N(√1−βtxt−1,βtI).其中 βt 为噪声的调度。定义 αt=1−βt 以及 ˉαt=∏ss=1αs.
经过如下推导
xT∣x0∼√αTxT−1+√1−αTϵT,ϵT∼N(0,I)∼√αT(√αT−1xT−2+√1−αT−1ϵT−1)+√1−αTϵT,ϵT−1,ϵT∼N(0,I)∼√αTαT−1xT−2+√1−αTαT−1ϵT−1,ϵT−1∼N(0,I)∼⋯∼√ˉαTx0+√1−ˉαTϵ0,ϵ0∼N(0,I)∼N(√ˉαTx0,(1−ˉαT)I).令 ˉαT→0 我们知道 xT∣x0∼N(0,I). 得到了一个纯噪声。
利用Bayes公式,可以求解逆向过程的条件分布
p(xt−1∣xt,x0)=p(xt∣xt−1,x0)p(xt∣x0)p(xt−1∣x0)=exp(−(xt−√αtxt−1)22(1−αt))exp(−(xt−√ˉαtx0)22(1−ˉαt))exp(−(xt−1−√ˉαt−1x0)22(1−ˉαt−1))∝exp(−12((αt1−αt+11−ˉαt−1)x2t−1−2(√αt1−αtxt+√ˉαt1−ˉαt)xt−1))∝exp(−(xt−1−μ(xt,x0,t))22σ2t).其中均值和方差项可以通过配方法得到,形如
μ(xt,x0,t)=√αt(1−ˉαt−1)1−ˉαtxt+√ˉαt−1(1−αt)1−ˉαtx0,σ2t=(1−αt)(1−ˉαt−1)1−ˉαt.我们希望训练一个参数为 θ 的模型预测逆向过程,如此便可以从噪声中恢复数据。
由于方差可以根据时间 t 计算,自然地可以令模型预测均值,也即给定时间步 t 求解
minEx0,xt‖μ(xt,x0,t)−μθ(xt,x0,t)‖2原文 [1] 采用最大化变分下界的方式推导上述式子,但可以从预测均值的角度直观理解。
但上面的问题需要模型输入图片 x0。采用重参数化,改为预测噪声,可以使得模型仅需要输入 xt,t. 注意到
xt∼√ˉαtx0+√1−ˉαtϵ,ϵ∼N(0,I).可以将损失函数等价地写成
minExt,ϵ‖√αt(1−ˉαt−1)1−ˉαtxt+√ˉαt−1(1−αt)(1−ˉαt)√ˉαt(xt−√1−ˉαtϵ)−μθ(xt,x0,t)‖2=minExt,ϵ‖1√αt(xt−1−αt√1−ˉαtϵ)−μθ(xt,x0,t)‖2.令模型的预测为
μθ(xt,x0,t)=1√αt(xt−1−αt√1−ˉαtϵθ(xt,t)).此时模型已经不需要 x0 作为输入值,忽略前面的系数,上面的式子等价于
minExt,ϵ‖ϵ−ϵθ(xt,t)‖2=minEx0,ϵ‖ϵ−ϵθ(√ˉαtx0+√1−ˉαtϵ,t)‖2.这就得到了DDPM的训练过程,也即采样一个图片数据 x0 和标准高斯随机变量 ϵ, 然后用梯度下降对上面的损失函数求解最优的 θ.
得到了训练完的模型后,需要生成数据仅需要首先采样 xT∼N(0,I) 然后使用网络进行如下的逆向过程即可
xt−1=1√αt(xt−1−αt√1−ˉαtϵθ(xt,t))+σtz,z∼N(0,I).迭代 T 步就可以得到最终的数据 x0.
Reference
[1] Ho, Jonathan, Ajay Jain, and Pieter Abbeel. Denoising diffusion probabilistic models. In NeurIPS, 2020.