《Understanding Diffusion Models: A Unified Perspective》阅读笔记2
原文链接
第二种等价解释:学习扩散噪声参数
笔记1中学习从xt预测x0,是通过将μθ凑成μq的形式推导出来的
在这里,也是一样,设置μθ和μq拥有一样的形式:
具体地,在这里,去除μq中的x0,使其变成关于xt和ϵ的函数。并且,凑出μθ与前者有相同的形式,进一步学习预测噪声。
基本流程
- 将加噪去噪时间表变成通过参数变成可以学习的(步数没法变,但是每步之间的距离可以)
- 通过重参数化,将均值参数均表示成xt和噪声ϵ的函数
首先,解决加噪时间表的学习问题,如果直接学习α^η(t),会在推理时对一个样本多次使用该NN,效率低。
将KL散度进行变形得到不同的优化形式:
DKL(q(xt−1∣xt,x0)∥pθ(xt−1∣xt))=2σq2(t)1(1−αˉt)2αˉt−1(1−αt)2[∥x^θ(xt,t)−x0∥22]=21−αˉt(1−αt)(1−αˉt−1)1(1−αˉt)2αˉt−1(1−αt)2[∥x^θ(xt,t)−x0∥22]=21(1−αt)(1−αˉt−1)1−αˉt(1−αˉt)2αˉt−1(1−αt)2[∥x^θ(xt,t)−x0∥22]=21(1−αˉt−1)(1−αˉt)αˉt−1(1−αt)[∥x^θ(xt,t)−x0∥22]=21(1−αˉt−1)(1−αˉt)αˉt−1−αˉt[∥x^θ(xt,t)−x0∥22]=21(1−αˉt−1)(1−αˉt)αˉt−1−αˉt−1αˉt+αˉt−1αˉt−αˉt[∥x^θ(xt,t)−x0∥22]=21(1−αˉt−1)(1−αˉt)αˉt−1(1−αˉt)−αˉt(1−αˉt−1)[∥x^θ(xt,t)−x0∥22]=21((1−αˉt−1)(1−αˉt)αˉt−1(1−αˉt)−(1−αˉt−1)(1−αˉt)αˉt(1−αˉt−1))[∥x^θ(xt,t)−x0∥22]=21(1−αˉt−1αˉt−1−1−αˉtαˉt)[∥x^θ(xt,t)−x0∥22]替换真实的σ加一项减一项
这样KL散度里就只有αˉ形式的存在,可以直接用NN建模该部分,减少计算量。
参考信噪比的定义有:
SNR(t)=1−αˉtαˉt
这样当逐渐到T时间步时,αtˉ逐渐到0,1−αtˉ,数据完全成为0均值1方差的标准高斯分布。
因此,上述KL散度可以写成:
21(SNR(t−1)−SNR(t))[∥x^θ(xt,t)−x0∥22]
具体来说,SNR应是一个关于t单调递减的范围从0到1的函数,其可以通过建模一个关于t单调递增的函数ω来实现(ω神经元强制非负模拟单调函数,输入增大,输出一定是不会变小的):
SNR(t)=exp(−ωη(t))
这样,可以得到:
1−αˉtαˉt=exp(−ωη(t))∴αˉt=sigmoid(−ωη(t))∴1−αˉt=sigmoid(ωη(t))
以上这些,可以用于从x0到xt的重参数化过程。
学习噪声扩散的流程
现在开始,分别将μq和μθ表示为xt+ϵ的形式:
μq(xt,x0)=1−αˉtαt(1−αˉt−1)xt+αˉt−1(1−αt)x0=1−αˉtαt(1−αˉt−1)xt+αˉt−1(1−αt)αˉtxt−1−αˉtϵ0=1−αˉtαt(1−αˉt−1)xt+(1−αt)αtxt−1−αˉtϵ0=1−αˉtαt(1−αˉt−1)xt+(1−αˉt)αt(1−αt)xt−(1−αˉt)αt(1−αt)1−αˉtϵ0=(1−αˉtαt(1−αˉt−1)+(1−αˉt)αt1−αt)xt−(1−αˉt)αt(1−αt)1−αˉtϵ0=((1−αˉt)αtαt(1−αˉt−1)+(1−αˉt)αt1−αt)xt−1−αˉtαt1−αtϵ0=(1−αˉt)αtαt−αˉt+1−αtxt−1−αˉtαt1−αtϵ0=(1−αˉt)αt1−αˉtxt−1−αˉtαt1−αtϵ0=αt1xt−1−αˉtαt1−αtϵ0重参数化的逆过程
而μθ是pθ(xt−1∣xt)的均值,即xt和t的函数。可以模仿前者,类似的设置为关于xt和ϵ的函数。由于xt已知,ϵ为止,可以将ϵ用神经网络建模,得到:
μθ(xt,t)=αt1xt−1−αˉtαt1−αtϵ^θ(xt,t)
这里KL散度可以写为:
argθminDKL(q(xt−1∣xt,x0)∥pθ(xt−1∣xt))=argθminDKL(N(xt−1;μq,Σq(t))∥N(xt−1;μθ,Σq(t)))=argθmin2σq2(t)1[μθ−μq22]=argθmin2σq2(t)1[αt1xt−1−αˉtαt1−αtϵ^θ(xt,t)−αt1xt+1−αˉtαt1−αtϵ022]=argθmin2σq2(t)1[1−αˉtαt1−αtϵ0−1−αˉtαt1−αtϵ^θ(xt,t)22]=argθmin2σq2(t)1[1−αˉtαt1−αt(ϵ0−ϵ^θ(xt,t))22]=argθmin2σq2(t)1(1−αˉt)αt(1−αt)2[∥ϵ0−ϵ^θ(xt,t)∥22]
注意,在用神经网络建模ϵ时大小应控制在(0,1)的范围内。据原文,尽管预测x0和ϵ是等价的。但是,预测噪声在实际操作上效果会更好。
第三种等价解释:学习分数函数模型
特威迪公式做了什么(当z∼N(z;μz,Σz))?
E[μz∣z]=z+Σz∇zlogp(z)
关于μq的表达式中有x0,在概率分布q(xt∣x0)的均值μxt可以用关于x0来表达:αˉx0。
q(xt∣x0)=N(xt;αˉtx0,(1−αˉt)I)
特威迪公式里需要有一个观测变量xt没关系,因为在q(xt−1∣xt,x0)里xt是已知的,总结起来:相当于是在xt有一个已观测样本时,对μq表达式中的x0做了一个估计和替换。其中,由于假设,μxt=αˉx0是完全准确的,没有估计过程。=z+Σz∇zlogp(z)过程是存在估计的。
E[μxt∣xt]=αˉx0=xt+(1−αˉt)∇xtlogp(xt)
∴x0=αˉtxt+(1−αˉt)∇logp(xt)
我们通过特威迪公式得到均值μxt的估计用于替换掉μq中的x0。
μq(xt,x0)=1−αˉtαt(1−αˉt−1)xt+αˉt−1(1−αt)x0=1−αˉtαt(1−αˉt−1)xt+αˉt−1(1−αt)αˉtxt+(1−αˉt)∇logp(xt)=1−αˉtαt(1−αˉt−1)xt+(1−αt)αtxt+(1−αˉt)∇logp(xt)=1−αˉtαt(1−αˉt−1)xt+(1−αˉt)αt(1−αt)xt+(1−αˉt)αt(1−αt)(1−αˉt)∇logp(xt)=(1−αˉtαt(1−αˉt−1)+(1−αˉt)αt1−αt)xt+αt1−αt∇logp(xt)=((1−αˉt)αtαt(1−αˉt−1)+(1−αˉt)αt1−αt)xt+αt1−αt∇logp(xt)=(1−αˉt)αtαt−αˉt+1−αtxt+αt1−αt∇logp(xt)=(1−αˉt)αt1−αˉtxt+αt1−αt∇logp(xt)=αt1xt+αt1−αt∇logp(xt)右上角一项的分母被消掉一部分
同样地,我们可以把μθ设置为(用神经网络建模分数函数sθ(xt,t)):
μθ(xt,t)=αt1xt+αt1−αtsθ(xt,t)
因此,KL散度的优化过程可以表示为:
θargminDKL(q(xt−1∣xt,x0)∥pθ(xt−1∣xt))=argθminDKL(N(xt−1;μq,Σq(t))∥N(xt−1;μθ,Σq(t)))=argθmin2σq2(t)1[μθ−μq22]=argθmin2σq2(t)1[αt1xt+αt1−αtsθ(xt,t)−αt1xt−αt1−αt∇logp(xt)22]=argθmin2σq2(t)1[αt1−αtsθ(xt,t)−αt1−αt∇logp(xt)22]=argθmin2σq2(t)1[αt1−αt(sθ(xt,t)−∇logp(xt))22]=θargmin2σq2(t)1αt(1−αt)2[∥sθ(xt,t)−∇logp(xt)∥22]
并且,参看Post1,预测分数函数和预测噪声实际上是可以相互替换的。
条件控制生成
基本方法
对于有条件控制的生成任务,我们需要建模的对象为p(x∣y)。首先,可以将条件信息y加在每一个时间步上,即:
p(x0:T∣y)=p(xT)t=1∏Tpθ(xt−1∣xt,y)
其中,y可以是文本的嵌入编码或者等等。最终,对于三种等价的解释,分别可以用神经网络建模x^θ(xt,t,y)≈x0,ϵ^θ(xt,t,y)≈ϵ0,或者sθ(xt,t,y)≈∇logp(xt∣y)。
但是,这种方法的结果是模型往往会忽略掉提供的条件信息。原文又介绍了基于Classifier和Classifier free的模型。
Classifier Guidance
在基于分数函数的解释中,q(xt−1∣xt,x0)的均值μq中包含∇logp(xt)项。现在,在有条件信息的情况下,变成∇logp(xt∣y)。
∇logp(xt∣y)=∇log(p(y)p(xt)p(y∣xt))=∇logp(xt)+∇logp(y∣xt)−∇logp(y)=unconditional score∇logp(xt)+adversarial gradient∇logp(y∣xt)
因为这里的∇都是对x求导,因此最后一项为0。对抗梯度项,在实现时使用一个“可以忍受”噪声的分类器(因为输入的都是加噪的隐变量,这也是导致其很难用已经训练好的分类器作为预训练模型的原因,需要与扩散模型一块训练)去计算对抗梯度。
∇logp(xt∣y)=∇logp(xt)+γ∇logp(y∣xt)
这里的参数γ是用于控制条件信息强度的超参数。当其为0时会完全忽略条件信息。当其逐渐变大,采样多样性会有损失,但是根据条件信息采样的数据会更容易重现(甚至对于隐变量)。
Classifier free Guidance
但是,为了避免训练一个分类器,可以把条件分数函数进行重组:
∇logp(y∣xt)=∇logp(xt∣y)−∇logp(xt)
带回到Classifier Guidance方法的表达式中:
∇logp(xt∣y)=∇logp(xt)+γ(∇logp(xt∣y)−∇logp(xt))=∇logp(xt)+γ∇logp(xt∣y)−γ∇logp(xt)=conditional scoreγ∇logp(xt∣y)+unconditional score(1−γ)∇logp(xt)
第一项相当于是训练基本控制生成方法中基于分数函数的解释的模型。第二项则为普通的非条件分数函数。类似地,当γ为0时,模型会完全忽略条件信息。等于1的时候相当于学习了一个普通的条件控制扩散模型。而当大于1时,不仅强化了条件控制分数函数的优先级,并且相当于还要避免学习到普通扩散模型的状态(也是会减少多样性,但是会匹配条件信息)。
但是,在具体实现的时候,同时训练两个扩散模型也是计算成本很高的。因此,对于非条件分数函数,会使用一个固定值代替条件信息(全0等)。这相当于是条件分数函数随机丢弃了条件信息。这种方法提高了对条件控制的程度,并且仅需要训练一个扩散模型。
Post1:
x0=αˉtxt+(1−αˉt)∇logp(xt)∴(1−αˉt)∇logp(xt)∇logp(xt)=αˉtxt−1−αˉtϵ0=−1−αˉtϵ0=−1−αˉt1ϵ0
后面可以参考的博客