Reading Notes on "Understanding Diffusion Models: A Unified Perspective" (Part 2)

Link to the original paper

Second Equivalent Interpretation: Learning the Diffusion Noise Parameters

In Notes 1, learning to predict $x_0$ from $x_t$ was derived by matching $\mu_{\theta}$ to the form of $\mu_q$.

Here, likewise, we set $\mu_{\theta}$ to have the same form as $\mu_q$.

Specifically, we eliminate $x_0$ from $\mu_q$, turning it into a function of $x_t$ and $\epsilon$. We then construct $\mu_{\theta}$ to share that form, so that the network learns to predict the noise.

Basic procedure

  1. Make the noising/denoising schedule learnable through parameters (the number of steps cannot change, but the spacing between steps can).
  2. Via reparameterization, express both mean parameters as functions of $x_t$ and the noise $\epsilon$.

First, consider learning the noise schedule. If we learned $\hat{\alpha}_{\boldsymbol{\eta}}(t)$ directly, inference would require evaluating this neural network many times per sample, which is inefficient.

Rewriting the KL divergence yields a different optimization form:

$$
\begin{aligned}
&D_{\mathrm{KL}}(q(\boldsymbol{x}_{t-1}|\boldsymbol{x}_t,\boldsymbol{x}_0)\parallel p_{\boldsymbol{\theta}}(\boldsymbol{x}_{t-1}|\boldsymbol{x}_t)) \\
&=\frac{1}{2\sigma_{q}^{2}(t)}\frac{\bar{\alpha}_{t-1}(1-\alpha_{t})^{2}}{(1-\bar{\alpha}_{t})^{2}}\left[\left\|\hat{\boldsymbol{x}}_{\boldsymbol{\theta}}(\boldsymbol{x}_{t},t)-\boldsymbol{x}_{0}\right\|_{2}^{2}\right] \\
&=\frac{1}{2\frac{(1-\alpha_{t})(1-\bar{\alpha}_{t-1})}{1-\bar{\alpha}_{t}}}\frac{\bar{\alpha}_{t-1}(1-\alpha_{t})^{2}}{(1-\bar{\alpha}_{t})^{2}}\left[\left\|\hat{\boldsymbol{x}}_{\boldsymbol{\theta}}(\boldsymbol{x}_{t},t)-\boldsymbol{x}_{0}\right\|_{2}^{2}\right] && \text{substitute the true }\sigma_q^2(t)\\
&=\frac{1}{2}\frac{1-\bar{\alpha}_t}{(1-\alpha_t)(1-\bar{\alpha}_{t-1})}\frac{\bar{\alpha}_{t-1}(1-\alpha_t)^2}{(1-\bar{\alpha}_t)^2}\left[\left\|\hat{\boldsymbol{x}}_{\boldsymbol{\theta}}(\boldsymbol{x}_t,t)-\boldsymbol{x}_0\right\|_2^2\right] \\
&=\frac{1}{2}\frac{\bar{\alpha}_{t-1}(1-\alpha_{t})}{(1-\bar{\alpha}_{t-1})(1-\bar{\alpha}_{t})}\left[\left\|\hat{\boldsymbol{x}}_{\boldsymbol{\theta}}(\boldsymbol{x}_{t},t)-\boldsymbol{x}_{0}\right\|_{2}^{2}\right] \\
&=\frac{1}{2}\frac{\bar{\alpha}_{t-1}-\bar{\alpha}_{t}}{(1-\bar{\alpha}_{t-1})(1-\bar{\alpha}_{t})}\left[\left\|\hat{\boldsymbol{x}}_{\boldsymbol{\theta}}(\boldsymbol{x}_{t},t)-\boldsymbol{x}_{0}\right\|_{2}^{2}\right] \\
&=\frac{1}{2}\frac{\bar{\alpha}_{t-1}-\bar{\alpha}_{t-1}\bar{\alpha}_{t}+\bar{\alpha}_{t-1}\bar{\alpha}_{t}-\bar{\alpha}_{t}}{(1-\bar{\alpha}_{t-1})(1-\bar{\alpha}_{t})}\left[\left\|\hat{\boldsymbol{x}}_{\boldsymbol{\theta}}(\boldsymbol{x}_{t},t)-\boldsymbol{x}_{0}\right\|_{2}^{2}\right] && \text{add and subtract a term}\\
&=\frac{1}{2}\frac{\bar{\alpha}_{t-1}(1-\bar{\alpha}_{t})-\bar{\alpha}_{t}(1-\bar{\alpha}_{t-1})}{(1-\bar{\alpha}_{t-1})(1-\bar{\alpha}_{t})}\left[\left\|\hat{\boldsymbol{x}}_{\boldsymbol{\theta}}(\boldsymbol{x}_{t},t)-\boldsymbol{x}_{0}\right\|_{2}^{2}\right] \\
&=\frac{1}{2}\left(\frac{\bar{\alpha}_{t-1}(1-\bar{\alpha}_t)}{(1-\bar{\alpha}_{t-1})(1-\bar{\alpha}_t)}-\frac{\bar{\alpha}_t(1-\bar{\alpha}_{t-1})}{(1-\bar{\alpha}_{t-1})(1-\bar{\alpha}_t)}\right)\left[\left\|\hat{\boldsymbol{x}}_{\boldsymbol{\theta}}(\boldsymbol{x}_t,t)-\boldsymbol{x}_0\right\|_2^2\right] \\
&=\frac{1}{2}\left(\frac{\bar{\alpha}_{t-1}}{1-\bar{\alpha}_{t-1}}-\frac{\bar{\alpha}_{t}}{1-\bar{\alpha}_{t}}\right)\left[\left\|\hat{\boldsymbol{x}}_{\boldsymbol{\theta}}(\boldsymbol{x}_{t},t)-\boldsymbol{x}_{0}\right\|_{2}^{2}\right]
\end{aligned}
$$

Now only terms of the form $\bar{\alpha}$ appear in the KL divergence, so this quantity can be modeled directly with a neural network, reducing computation.

Recalling the definition of the signal-to-noise ratio (SNR):

$$
\mathrm{SNR}(t)=\frac{\bar{\alpha}_t}{1-\bar{\alpha}_t}
$$

As $t$ approaches the final timestep $T$, $\bar{\alpha}_t$ goes to 0 and $1-\bar{\alpha}_t$ goes to 1, so the data becomes a standard Gaussian with zero mean and unit variance.

The KL divergence above can therefore be written as:

$$
\frac{1}{2}\left(\mathrm{SNR}(t-1)-\mathrm{SNR}(t)\right)\left[\left\|\hat{\boldsymbol{x}}_{\boldsymbol{\theta}}(\boldsymbol{x}_t,t)-\boldsymbol{x}_0\right\|_2^2\right]
$$
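As a quick illustration, these per-timestep weights can be computed directly from $\bar{\alpha}_t$. Below is a small Python sketch of my own (a fixed linear $\beta$ schedule stands in for a learned $\bar{\alpha}$):

```python
# Sketch: per-timestep weights (SNR(t-1) - SNR(t)) / 2 for the x0-prediction
# objective, using a fixed linear beta schedule in place of a learned alpha_bar.
import torch

T = 1000
betas = torch.linspace(1e-4, 0.02, T)
alpha_bar = torch.cumprod(1.0 - betas, dim=0)   # \bar{\alpha}_t, index 0 <-> t = 1
snr = alpha_bar / (1.0 - alpha_bar)             # SNR(t), strictly decreasing
weights = 0.5 * (snr[:-1] - snr[1:])            # 0.5 * (SNR(t-1) - SNR(t))
assert (weights > 0).all()                      # decreasing SNR => positive weights
```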

Specifically, the SNR should be a positive, monotonically decreasing function of $t$. This can be achieved by modeling a monotonically increasing function $\omega_{\boldsymbol{\eta}}(t)$ (a network whose weights are constrained to be non-negative mimics a monotonic function: as the input grows, the output can never shrink):

$$
\mathrm{SNR}(t)=\exp(-\omega_{\boldsymbol{\eta}}(t))
$$

From this we obtain:

$$
\begin{aligned}
&\frac{\bar{\alpha}_{t}}{1-\bar{\alpha}_{t}}=\exp(-\omega_{\boldsymbol{\eta}}(t))\\
&\therefore\bar{\alpha}_{t}=\mathrm{sigmoid}(-\omega_{\boldsymbol{\eta}}(t))\\
&\therefore 1-\bar{\alpha}_{t}=\mathrm{sigmoid}(\omega_{\boldsymbol{\eta}}(t))
\end{aligned}
$$

These expressions can then be used in the reparameterization from $x_0$ to $x_t$.
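To make this concrete, here is a minimal sketch assuming PyTorch; the class name MonotonicSchedule and the layer sizes are hypothetical, not from the paper. It keeps $\omega_{\boldsymbol{\eta}}(t)$ monotonic by forcing the weights to be non-negative, and uses $\bar{\alpha}_t=\mathrm{sigmoid}(-\omega_{\boldsymbol{\eta}}(t))$ for the $x_0\to x_t$ reparameterization:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class MonotonicSchedule(nn.Module):
    """Learnable noise schedule: t -> omega_eta(t), monotone non-decreasing."""
    def __init__(self, hidden: int = 32):
        super().__init__()
        self.l1 = nn.Linear(1, hidden)
        self.l2 = nn.Linear(hidden, 1)

    def omega(self, t: torch.Tensor) -> torch.Tensor:
        # Softplus forces the weights to be non-negative, so composing with the
        # monotone sigmoid activation keeps the whole map non-decreasing in t.
        h = F.linear(t, F.softplus(self.l1.weight), self.l1.bias)
        h = torch.sigmoid(h)
        return F.linear(h, F.softplus(self.l2.weight), self.l2.bias)

    def alpha_bar(self, t: torch.Tensor) -> torch.Tensor:
        # alpha_bar(t) = sigmoid(-omega(t)) lies in (0, 1) and decreases in t.
        return torch.sigmoid(-self.omega(t))

sched = MonotonicSchedule()
t = torch.linspace(0, 1, 5).unsqueeze(-1)    # normalized timesteps
ab = sched.alpha_bar(t)
x0 = torch.randn(5, 3)                       # toy "clean" samples
eps = torch.randn_like(x0)
xt = ab.sqrt() * x0 + (1 - ab).sqrt() * eps  # reparameterized forward sample
```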

Procedure for learning the noise

We now express both $\mu_q$ and $\mu_{\theta}$ as combinations of $x_t$ and $\epsilon$:

$$
\begin{aligned}
\boldsymbol{\mu}_q(\boldsymbol{x}_t,\boldsymbol{x}_0)&=\frac{\sqrt{\alpha_{t}}(1-\bar{\alpha}_{t-1})\boldsymbol{x}_{t}+\sqrt{\bar{\alpha}_{t-1}}(1-\alpha_{t})\boldsymbol{x}_{0}}{1-\bar{\alpha}_{t}} \\
&=\frac{\sqrt{\alpha_t}(1-\bar{\alpha}_{t-1})\boldsymbol{x}_t+\sqrt{\bar{\alpha}_{t-1}}(1-\alpha_t)\frac{\boldsymbol{x}_t-\sqrt{1-\bar{\alpha}_t}\boldsymbol{\epsilon}_0}{\sqrt{\bar{\alpha}_t}}}{1-\bar{\alpha}_t} && \text{invert the reparameterization}\\
&=\frac{\sqrt{\alpha_t}(1-\bar{\alpha}_{t-1})\boldsymbol{x}_t+(1-\alpha_t)\frac{\boldsymbol{x}_t-\sqrt{1-\bar{\alpha}_t}\boldsymbol{\epsilon}_0}{\sqrt{\alpha_t}}}{1-\bar{\alpha}_t} \\
&=\frac{\sqrt{\alpha_t}(1-\bar{\alpha}_{t-1})\boldsymbol{x}_t}{1-\bar{\alpha}_t}+\frac{(1-\alpha_t)\boldsymbol{x}_t}{(1-\bar{\alpha}_t)\sqrt{\alpha_t}}-\frac{(1-\alpha_t)\sqrt{1-\bar{\alpha}_t}\boldsymbol{\epsilon}_0}{(1-\bar{\alpha}_t)\sqrt{\alpha_t}} \\
&=\left(\frac{\sqrt{\alpha_t}(1-\bar{\alpha}_{t-1})}{1-\bar{\alpha}_t}+\frac{1-\alpha_t}{(1-\bar{\alpha}_t)\sqrt{\alpha_t}}\right)\boldsymbol{x}_t-\frac{(1-\alpha_t)\sqrt{1-\bar{\alpha}_t}}{(1-\bar{\alpha}_t)\sqrt{\alpha_t}}\boldsymbol{\epsilon}_0 \\
&=\left(\frac{\alpha_t(1-\bar\alpha_{t-1})}{(1-\bar\alpha_t)\sqrt{\alpha_t}}+\frac{1-\alpha_t}{(1-\bar\alpha_t)\sqrt{\alpha_t}}\right)\boldsymbol{x}_t-\frac{1-\alpha_t}{\sqrt{1-\bar\alpha_t}\sqrt{\alpha_t}}\boldsymbol{\epsilon}_0 \\
&=\frac{\alpha_t-\bar{\alpha}_t+1-\alpha_t}{(1-\bar{\alpha}_t)\sqrt{\alpha_t}}\boldsymbol{x}_t-\frac{1-\alpha_t}{\sqrt{1-\bar{\alpha}_t}\sqrt{\alpha_t}}\boldsymbol{\epsilon}_0 \\
&=\frac{1-\bar{\alpha}_t}{(1-\bar{\alpha}_t)\sqrt{\alpha_t}}\boldsymbol{x}_t-\frac{1-\alpha_t}{\sqrt{1-\bar{\alpha}_t}\sqrt{\alpha_t}}\boldsymbol{\epsilon}_0 \\
&=\frac{1}{\sqrt{\alpha_t}}\boldsymbol{x}_t-\frac{1-\alpha_t}{\sqrt{1-\bar{\alpha}_t}\sqrt{\alpha_t}}\boldsymbol{\epsilon}_0
\end{aligned}
$$

$\mu_{\theta}$ is the mean of $p_{\theta}(x_{t-1}|x_t)$, i.e., a function of $x_t$ and $t$. Mimicking the expression above, we set it to be a function of $x_t$ and $\epsilon$. Since $x_t$ is known but $\epsilon$ is unknown, we model $\epsilon$ with a neural network, obtaining:

$$
\boldsymbol{\mu}_{\boldsymbol{\theta}}(\boldsymbol{x}_t,t)=\frac{1}{\sqrt{\alpha_t}}\boldsymbol{x}_t-\frac{1-\alpha_t}{\sqrt{1-\bar{\alpha}_t}\sqrt{\alpha_t}}\hat{\boldsymbol{\epsilon}}_{\boldsymbol{\theta}}(\boldsymbol{x}_t,t)
$$

The KL divergence can then be written as:

$$
\begin{aligned}
&\arg\min_{\boldsymbol{\theta}}D_{\mathrm{KL}}(q(\boldsymbol{x}_{t-1}|\boldsymbol{x}_{t},\boldsymbol{x}_{0})\parallel p_{\boldsymbol{\theta}}(\boldsymbol{x}_{t-1}|\boldsymbol{x}_{t})) \\
&=\arg\min_{\boldsymbol{\theta}}D_{\mathrm{KL}}(\mathcal{N}\left(\boldsymbol{x}_{t-1};\boldsymbol{\mu}_{q},\boldsymbol{\Sigma}_{q}(t)\right)\parallel\mathcal{N}\left(\boldsymbol{x}_{t-1};\boldsymbol{\mu}_{\boldsymbol{\theta}},\boldsymbol{\Sigma}_{q}(t)\right)) \\
&=\arg\min_{\boldsymbol{\theta}}\frac{1}{2\sigma_q^2(t)}\left[\left\|\boldsymbol{\mu}_{\boldsymbol{\theta}}-\boldsymbol{\mu}_q\right\|_2^2\right]\\
&=\arg\min_{\boldsymbol{\theta}}\frac{1}{2\sigma_{q}^{2}(t)}\left[\left\|\frac{1}{\sqrt{\alpha_{t}}}\boldsymbol{x}_{t}-\frac{1-\alpha_{t}}{\sqrt{1-\bar{\alpha}_{t}}\sqrt{\alpha_{t}}}\hat{\boldsymbol{\epsilon}}_{\boldsymbol{\theta}}(\boldsymbol{x}_{t},t)-\frac{1}{\sqrt{\alpha_{t}}}\boldsymbol{x}_{t}+\frac{1-\alpha_{t}}{\sqrt{1-\bar{\alpha}_{t}}\sqrt{\alpha_{t}}}\boldsymbol{\epsilon}_{0}\right\|_{2}^{2}\right] \\
&=\arg\min_{\boldsymbol{\theta}}\frac{1}{2\sigma_{q}^{2}(t)}\left[\left\|\frac{1-\alpha_{t}}{\sqrt{1-\bar{\alpha}_{t}}\sqrt{\alpha_{t}}}\boldsymbol{\epsilon}_{0}-\frac{1-\alpha_{t}}{\sqrt{1-\bar{\alpha}_{t}}\sqrt{\alpha_{t}}}\hat{\boldsymbol{\epsilon}}_{\boldsymbol{\theta}}(\boldsymbol{x}_{t},t)\right\|_{2}^{2}\right] \\
&=\arg\min_{\boldsymbol{\theta}}\frac{1}{2\sigma_{q}^{2}(t)}\left[\left\|\frac{1-\alpha_{t}}{\sqrt{1-\bar{\alpha}_{t}}\sqrt{\alpha_{t}}}(\boldsymbol{\epsilon}_{0}-\hat{\boldsymbol{\epsilon}}_{\boldsymbol{\theta}}(\boldsymbol{x}_{t},t))\right\|_{2}^{2}\right] \\
&=\arg\min_{\boldsymbol{\theta}}\frac{1}{2\sigma_q^2(t)}\frac{(1-\alpha_t)^2}{(1-\bar{\alpha}_t)\alpha_t}\left[\left\|\boldsymbol{\epsilon}_0-\hat{\boldsymbol{\epsilon}}_{\boldsymbol{\theta}}(\boldsymbol{x}_t,t)\right\|_2^2\right]
\end{aligned}
$$

Note that when modeling $\epsilon$ with a neural network, its magnitude should be controlled to stay within $(0,1)$. According to the original paper, predicting $x_0$ and predicting $\epsilon$ are equivalent, but predicting the noise tends to perform better in practice.
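For concreteness, here is a minimal sketch (assuming PyTorch) of one noise-prediction training step; `model` is a hypothetical network mapping $(x_t,t)$ to a tensor shaped like $x_t$, and the fixed linear schedule is illustrative only:

```python
import torch

T = 1000
betas = torch.linspace(1e-4, 0.02, T)        # standard DDPM-style linear schedule
alphas = 1.0 - betas
alpha_bar = torch.cumprod(alphas, dim=0)     # \bar{\alpha}_t

def loss_step(model, x0: torch.Tensor) -> torch.Tensor:
    t = torch.randint(0, T, (x0.shape[0],))
    ab = alpha_bar[t].view(-1, *([1] * (x0.dim() - 1)))
    eps = torch.randn_like(x0)
    xt = ab.sqrt() * x0 + (1 - ab).sqrt() * eps   # sample q(x_t | x_0)
    eps_hat = model(xt, t)                        # \hat{\epsilon}_\theta(x_t, t)
    # Simple unweighted objective; the full weight
    # (1/(2 sigma_q^2)) * (1-alpha_t)^2 / ((1-alpha_bar_t) alpha_t)
    # is usually dropped in practice, following DDPM.
    return torch.mean((eps - eps_hat) ** 2)
```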

Third Equivalent Interpretation: Learning a Score Function Model

What does Tweedie's formula say (for $\boldsymbol{z}\sim\mathcal{N}(\boldsymbol{z};\boldsymbol{\mu}_z,\boldsymbol{\Sigma}_z)$)?

$$
\mathbb{E}\left[\boldsymbol{\mu}_z|\boldsymbol{z}\right]=\boldsymbol{z}+\boldsymbol{\Sigma}_z\nabla_{\boldsymbol{z}}\log p(\boldsymbol{z})
$$
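A quick one-dimensional sanity check (a toy setup of my own, not from the paper): with prior $\mu\sim\mathcal{N}(0,1)$ and observation $z|\mu\sim\mathcal{N}(\mu,s^2)$, the marginal is $p(z)=\mathcal{N}(0,1+s^2)$, and Tweedie's formula recovers the classical posterior mean $z/(1+s^2)$:

```python
# Numerical check of Tweedie's formula in 1-D.
# p(z) = N(0, 1 + s2), so d/dz log p(z) = -z / (1 + s2), and
# E[mu | z] = z + s2 * (-z / (1 + s2)) = z / (1 + s2).
s2, z = 0.5, 1.3
score = -z / (1.0 + s2)            # d/dz log p(z)
tweedie = z + s2 * score           # Tweedie estimate of E[mu | z]
posterior_mean = z / (1.0 + s2)    # closed-form Gaussian posterior mean
assert abs(tweedie - posterior_mean) < 1e-12
```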

The expression for $\mu_q$ contains $x_0$. The mean $\mu_{x_t}$ of the distribution $q(x_t|x_0)$ can be expressed in terms of $x_0$ as $\sqrt{\bar{\alpha}_t}\boldsymbol{x}_0$:

$$
q(\boldsymbol{x}_t|\boldsymbol{x}_0)=\mathcal{N}(\boldsymbol{x}_t;\sqrt{\bar{\alpha}_t}\boldsymbol{x}_0,(1-\bar{\alpha}_t)\mathbf{I})
$$

Tweedie's formula requires an observed variable, and that is fine here, because $x_t$ is known in $q(x_{t-1}|x_t,x_0)$. To summarize: given an observed sample of $x_t$, we estimate and substitute the $x_0$ appearing in the expression for $\mu_q$. By the Gaussian assumption, $\mu_{x_t}=\sqrt{\bar{\alpha}_t}\boldsymbol{x}_0$ holds exactly, with no estimation involved; the estimation enters through the $\boldsymbol{z}+\boldsymbol{\Sigma}_z\nabla_{\boldsymbol{z}}\log p(\boldsymbol{z})$ step.

$$
\mathbb{E}\left[\boldsymbol{\mu}_{x_t}|\boldsymbol{x}_t\right]=\sqrt{\bar{\alpha}_t}\boldsymbol{x}_0=\boldsymbol{x}_t+(1-\bar{\alpha}_t)\nabla_{\boldsymbol{x}_t}\log p(\boldsymbol{x}_t)
$$

$$
\therefore\boldsymbol{x}_0=\frac{\boldsymbol{x}_t+(1-\bar{\alpha}_t)\nabla\log p(\boldsymbol{x}_t)}{\sqrt{\bar{\alpha}_t}}
$$

Using Tweedie's formula, we obtain an estimate of the mean $\mu_{x_t}$ and use it to replace the $x_0$ in $\mu_q$:

$$
\begin{aligned}
\boldsymbol{\mu}_q(\boldsymbol{x}_t,\boldsymbol{x}_0)&=\frac{\sqrt{\alpha_t}(1-\bar{\alpha}_{t-1})\boldsymbol{x}_t+\sqrt{\bar{\alpha}_{t-1}}(1-\alpha_t)\boldsymbol{x}_0}{1-\bar{\alpha}_t} \\
&=\frac{\sqrt{\alpha_t}(1-\bar{\alpha}_{t-1})\boldsymbol{x}_t+\sqrt{\bar{\alpha}_{t-1}}(1-\alpha_t)\frac{\boldsymbol{x}_t+(1-\bar{\alpha}_t)\nabla\log p(\boldsymbol{x}_t)}{\sqrt{\bar{\alpha}_t}}}{1-\bar{\alpha}_t} \\
&=\frac{\sqrt{\alpha_t}(1-\bar{\alpha}_{t-1})\boldsymbol{x}_t+(1-\alpha_t)\frac{\boldsymbol{x}_t+(1-\bar{\alpha}_t)\nabla\log p(\boldsymbol{x}_t)}{\sqrt{\alpha_t}}}{1-\bar{\alpha}_t} && \text{part of the denominator cancels}\\
&=\frac{\sqrt{\alpha_t}(1-\bar{\alpha}_{t-1})\boldsymbol{x}_t}{1-\bar{\alpha}_t}+\frac{(1-\alpha_t)\boldsymbol{x}_t}{(1-\bar{\alpha}_t)\sqrt{\alpha_t}}+\frac{(1-\alpha_t)(1-\bar{\alpha}_t)\nabla\log p(\boldsymbol{x}_t)}{(1-\bar{\alpha}_t)\sqrt{\alpha_t}} \\
&=\left(\frac{\sqrt{\alpha_t}(1-\bar{\alpha}_{t-1})}{1-\bar{\alpha}_t}+\frac{1-\alpha_t}{(1-\bar{\alpha}_t)\sqrt{\alpha_t}}\right)\boldsymbol{x}_t+\frac{1-\alpha_t}{\sqrt{\alpha_t}}\nabla\log p(\boldsymbol{x}_t) \\
&=\left(\frac{\alpha_t(1-\bar{\alpha}_{t-1})}{(1-\bar{\alpha}_t)\sqrt{\alpha_t}}+\frac{1-\alpha_t}{(1-\bar{\alpha}_t)\sqrt{\alpha_t}}\right)\boldsymbol{x}_t+\frac{1-\alpha_t}{\sqrt{\alpha_t}}\nabla\log p(\boldsymbol{x}_t) \\
&=\frac{\alpha_t-\bar{\alpha}_t+1-\alpha_t}{(1-\bar{\alpha}_t)\sqrt{\alpha_t}}\boldsymbol{x}_t+\frac{1-\alpha_t}{\sqrt{\alpha_t}}\nabla\log p(\boldsymbol{x}_t) \\
&=\frac{1-\bar{\alpha}_t}{(1-\bar{\alpha}_t)\sqrt{\alpha_t}}\boldsymbol{x}_t+\frac{1-\alpha_t}{\sqrt{\alpha_t}}\nabla\log p(\boldsymbol{x}_t) \\
&=\frac{1}{\sqrt{\alpha_t}}\boldsymbol{x}_t+\frac{1-\alpha_t}{\sqrt{\alpha_t}}\nabla\log p(\boldsymbol{x}_t)
\end{aligned}
$$

Similarly, modeling the score function with a neural network $s_{\theta}(x_t,t)$, we can set $\mu_{\theta}$ to:

$$
\boldsymbol{\mu}_{\boldsymbol{\theta}}(\boldsymbol{x}_t,t)=\frac{1}{\sqrt{\alpha_t}}\boldsymbol{x}_t+\frac{1-\alpha_t}{\sqrt{\alpha_t}}\boldsymbol{s}_{\boldsymbol{\theta}}(\boldsymbol{x}_t,t)
$$

The KL optimization can therefore be expressed as:

$$
\begin{aligned}
&\arg\min_{\boldsymbol{\theta}}D_{\mathrm{KL}}(q(\boldsymbol{x}_{t-1}|\boldsymbol{x}_{t},\boldsymbol{x}_{0})\parallel p_{\boldsymbol{\theta}}(\boldsymbol{x}_{t-1}|\boldsymbol{x}_{t})) \\
&=\arg\min_{\boldsymbol{\theta}}D_{\mathrm{KL}}(\mathcal{N}\left(\boldsymbol{x}_{t-1};\boldsymbol{\mu}_q,\boldsymbol{\Sigma}_q(t)\right)\parallel\mathcal{N}\left(\boldsymbol{x}_{t-1};\boldsymbol{\mu}_{\boldsymbol{\theta}},\boldsymbol{\Sigma}_q(t)\right)) \\
&=\arg\min_{\boldsymbol{\theta}}\frac{1}{2\sigma_q^2(t)}\left[\left\|\boldsymbol{\mu}_{\boldsymbol{\theta}}-\boldsymbol{\mu}_q\right\|_2^2\right]\\
&=\arg\min_{\boldsymbol{\theta}}\frac{1}{2\sigma_q^2(t)}\left[\left\|\frac{1}{\sqrt{\alpha_t}}\boldsymbol{x}_t+\frac{1-\alpha_t}{\sqrt{\alpha_t}}\boldsymbol{s}_{\boldsymbol{\theta}}(\boldsymbol{x}_t,t)-\frac{1}{\sqrt{\alpha_t}}\boldsymbol{x}_t-\frac{1-\alpha_t}{\sqrt{\alpha_t}}\nabla\log p(\boldsymbol{x}_t)\right\|_2^2\right] \\
&=\arg\min_{\boldsymbol{\theta}}\frac{1}{2\sigma_q^2(t)}\left[\left\|\frac{1-\alpha_t}{\sqrt{\alpha_t}}\boldsymbol{s}_{\boldsymbol{\theta}}(\boldsymbol{x}_t,t)-\frac{1-\alpha_t}{\sqrt{\alpha_t}}\nabla\log p(\boldsymbol{x}_t)\right\|_2^2\right] \\
&=\arg\min_{\boldsymbol{\theta}}\frac{1}{2\sigma_q^2(t)}\left[\left\|\frac{1-\alpha_t}{\sqrt{\alpha_t}}(\boldsymbol{s}_{\boldsymbol{\theta}}(\boldsymbol{x}_t,t)-\nabla\log p(\boldsymbol{x}_t))\right\|_2^2\right] \\
&=\arg\min_{\boldsymbol{\theta}}\frac{1}{2\sigma_q^2(t)}\frac{(1-\alpha_t)^2}{\alpha_t}\left[\left\|\boldsymbol{s}_{\boldsymbol{\theta}}(\boldsymbol{x}_t,t)-\nabla\log p(\boldsymbol{x}_t)\right\|_2^2\right]
\end{aligned}
$$

Moreover, as Post1 below shows, predicting the score function and predicting the noise are in fact interchangeable.

Conditional Generation

Basic approach

For conditional generation, the object we need to model is $p(x|y)$. First, the conditioning information $y$ can be injected at every timestep:

$$
p(\boldsymbol{x}_{0:T}|y)=p(\boldsymbol{x}_T)\prod_{t=1}^Tp_{\boldsymbol{\theta}}(\boldsymbol{x}_{t-1}|\boldsymbol{x}_t,y)
$$

Here $y$ can be a text embedding, among other things. For the three equivalent interpretations, we can then use a neural network to model $\hat{\boldsymbol{x}}_{\boldsymbol{\theta}}(\boldsymbol{x}_t,t,y)\approx\boldsymbol{x}_0$, $\hat{\boldsymbol{\epsilon}}_{\boldsymbol{\theta}}(\boldsymbol{x}_t,t,y)\approx\boldsymbol{\epsilon}_0$, or $\boldsymbol{s}_{\boldsymbol{\theta}}(\boldsymbol{x}_t,t,y)\approx\nabla\log p(\boldsymbol{x}_t|y)$, respectively.

However, a model trained this way tends to ignore the provided conditioning information. The paper therefore also introduces classifier-based and classifier-free guidance.

Classifier Guidance

In the score-based interpretation, the mean $\mu_q$ of $q(x_{t-1}|x_t,x_0)$ contains the term $\nabla\log p(\boldsymbol{x}_t)$. With conditioning information, this becomes $\nabla\log p(\boldsymbol{x}_t|y)$:

$$
\begin{aligned}
\nabla\log p(\boldsymbol{x}_t|y)&=\nabla\log\left(\frac{p(\boldsymbol{x}_t)p(y|\boldsymbol{x}_t)}{p(y)}\right) \\
&=\nabla\log p(\boldsymbol{x}_t)+\nabla\log p(y|\boldsymbol{x}_t)-\nabla\log p(y) \\
&=\underbrace{\nabla\log p(\boldsymbol{x}_t)}_{\text{unconditional score}}+\underbrace{\nabla\log p(y|\boldsymbol{x}_t)}_{\text{adversarial gradient}}
\end{aligned}
$$

Since every $\nabla$ here is a gradient with respect to $\boldsymbol{x}_t$, the last term is zero. In practice, the adversarial-gradient term is computed with a noise-tolerant classifier (its inputs are noisy latents, which is why an off-the-shelf pretrained classifier is hard to reuse; the classifier must be trained alongside the diffusion model). A scaling factor $\gamma$ is introduced to control the strength of this term:

$$
\nabla\log p(\boldsymbol{x}_t|y)=\nabla\log p(\boldsymbol{x}_t)+\gamma\nabla\log p(y|\boldsymbol{x}_t)
$$

The parameter $\gamma$ here is a hyperparameter controlling the strength of the conditioning information. At $\gamma=0$ the conditioning is ignored entirely; as $\gamma$ grows, sampling diversity suffers, but the samples reproduce the conditioning information more faithfully (even at the level of the latent variables).
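As a sketch of how this is used at sampling time (assuming PyTorch; `score_model` and `classifier` are hypothetical callables, not from the paper):

```python
import torch

def guided_score(score_model, classifier, x_t, t, y, gamma: float):
    # Classifier guidance: unconditional score plus gamma-scaled gradient of
    # log p(y | x_t) with respect to x_t (the "adversarial gradient").
    with torch.enable_grad():
        x_in = x_t.detach().requires_grad_(True)
        log_probs = torch.log_softmax(classifier(x_in, t), dim=-1)
        sel = log_probs[torch.arange(x_in.shape[0]), y]   # log p(y | x_t) per sample
        grad = torch.autograd.grad(sel.sum(), x_in)[0]
    return score_model(x_t, t) + gamma * grad
```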

Classifier-Free Guidance

To avoid training a classifier, however, the conditional score can be rearranged:

$$
\nabla\log p(y|\boldsymbol{x}_t)=\nabla\log p(\boldsymbol{x}_t|y)-\nabla\log p(\boldsymbol{x}_t)
$$

Substituting back into the Classifier Guidance expression:

$$
\begin{aligned}
\nabla\log p(\boldsymbol{x}_t|y)&=\nabla\log p(\boldsymbol{x}_t)+\gamma\left(\nabla\log p(\boldsymbol{x}_t|y)-\nabla\log p(\boldsymbol{x}_t)\right) \\
&=\nabla\log p(\boldsymbol{x}_t)+\gamma\nabla\log p(\boldsymbol{x}_t|y)-\gamma\nabla\log p(\boldsymbol{x}_t) \\
&=\underbrace{\gamma\nabla\log p(\boldsymbol{x}_t|y)}_{\text{conditional score}}+\underbrace{(1-\gamma)\nabla\log p(\boldsymbol{x}_t)}_{\text{unconditional score}}
\end{aligned}
$$

The first term corresponds to the model from the basic conditional approach under the score-based interpretation; the second is the ordinary unconditional score. As before, at $\gamma=0$ the model ignores the conditioning entirely, and at $\gamma=1$ it amounts to an ordinary conditional diffusion model. When $\gamma>1$, the conditional score is not only prioritized, the model is also pushed away from the unconditional behavior (again trading diversity for closer adherence to the conditioning information).

In practice, however, training two diffusion models simultaneously is computationally expensive. Instead, for the unconditional score, the conditioning information is replaced by a fixed value (e.g., all zeros); equivalently, the conditional model randomly drops its conditioning during training. This approach strengthens conditional control while requiring only a single diffusion model.
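A minimal sketch of both pieces (assuming PyTorch; `model(x_t, t, y)` is a hypothetical conditional network, and `y=None` stands in for the fixed null condition):

```python
import torch

def cfg_score(model, x_t, t, y, gamma: float):
    # gamma * conditional + (1 - gamma) * unconditional, as derived above.
    cond = model(x_t, t, y)        # conditional score for condition y
    uncond = model(x_t, t, None)   # unconditional score via the null condition
    return gamma * cond + (1.0 - gamma) * uncond

def maybe_drop_condition(y, p_drop: float = 0.1):
    # Training-time trick: randomly replace the condition with the null token,
    # so a single network learns both the conditional and unconditional scores.
    return None if torch.rand(()).item() < p_drop else y
```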

Post1:

$$
\begin{aligned}
\boldsymbol{x}_0=\frac{\boldsymbol{x}_t+(1-\bar{\alpha}_t)\nabla\log p(\boldsymbol{x}_t)}{\sqrt{\bar{\alpha}_t}}&=\frac{\boldsymbol{x}_t-\sqrt{1-\bar{\alpha}_t}\boldsymbol{\epsilon}_0}{\sqrt{\bar{\alpha}_t}} \\
\therefore(1-\bar{\alpha}_{t})\nabla\log p(\boldsymbol{x}_{t})&=-\sqrt{1-\bar{\alpha}_t}\boldsymbol{\epsilon}_0 \\
\nabla\log p(\boldsymbol{x}_t)&=-\frac{1}{\sqrt{1-\bar{\alpha}_{t}}}\boldsymbol{\epsilon}_{0}
\end{aligned}
$$
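In code, the conversion implied by Post1 is a one-liner (sketch; `eps_hat` is the output of a trained noise predictor):

```python
def score_from_eps(eps_hat, alpha_bar_t):
    # \nabla log p(x_t) = -eps / sqrt(1 - alpha_bar_t), per the derivation above.
    return -eps_hat / (1.0 - alpha_bar_t) ** 0.5
```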

Blogs for further reference

