VAE 模型
《Auto-Encoding Variational Bayes》
整体思路:
- 建模生成模型p(x)和p(x∣z)用于可控的生成
- 最大似然估计:
- 找到最优参数(模型)使得p(x∣θ)最大
- 全概率公式拆分边缘概率,引入隐变量z
- 以上积分变换为期望形式
- 不能使用采样的方式估计以上期望:大部分从p(z)中采样出的z对生成x无帮助
引入后验分布,转换为后验分布上的期望
- 琴生不等式,将对数操作提出来
- 转换为重建损失和真实后验分布与参数化的后验分布的KL散度的形式
- KL散度的解析解(p(z)的先验假设为0,1正态分布)
- 输出假设为相互独立的高斯分布,使用神经网络进行建模
- 重参数化技巧,使得z对编码器的输出μ和σ2可导
过程推导
优化参数或者模型,使得生成训练集的概率最大
θ∗=argθmaxp(x(1);θ)⋯p(x(N);θ)=argθmaxlog(p(x(1);θ)⋯p(x(N);θ))=argθmaxi=1∑nlogp(x(i);θ).
引入隐变量z,用于可控生成
p(x;θ)=∫zp(x,z;θ)dz=∫zp(z)p(x∣z;θ)dz.
转换为期望形式,但是由于直接从假设的p(z)(也就是正态分布)中采样并做蒙特卡洛估计会因为大部分z对于生成x毫无帮助而产生估计错误。
p(x;θ)=∫zp(z)p(x∣z;θ)dz=Ez∼p(z)[p(x∣z;θ)]≈m1i=1∑mp(x∣z(i);θ),z(i)∼p(z).
这里转而使用重要性采样技巧,引入后验分布,这样z会从对生成x更为重要的地方采样出来。
p(x;θ)=∫zp(z)p(x∣z;θ)dz=∫zp(z∣x)p(z∣x)p(z)p(x∣z;θ)dz=Ez∼p(z∣x)[p(z∣x)p(z)p(x∣z;θ)].
这里引入的后验分布也无法准确解析表示,需要使用神经网络ϕ去做从p(z∣x;ϕ)到p(z∣x)拟合。
这两个分布的差距足够小的时候才能用蒙特卡罗估计去计算期望,否则误差将会很大。
目前为止
主要目标:优化θ使得对数似然最大。
引申目标:优化ϕ使得变分后验和真实后验尽可能相似(也就是为什么要引入ELBO:最大化证据下界等价于最小化这两个分布的差距)。以此为基础,使用蒙特卡罗估计去优化θ。
使用琴生不等式 and log函数是凹函数 将对数函数提到期望内部。
logp(x;θ)=logEz∼q(z∣x;ϕ)[q(z∣x;ϕ)p(z)p(x∣z;θ)]≥Ez∼q(z∣x;ϕ)[logq(z∣x;ϕ)p(z)p(x∣z;θ)]=ELBO(x;θ,ϕ).
这里ELBO可以拆成两个部分,分别是
- 通过x确定的z的分布,再从该分布采样重建x的损失(使用在z∼p(z∣x)上的期望表示)
- 近似的后验分布与真实的p(z∣x)的KL散度(用下面分布近似上面的分布需要多少字节)。也就是对数似然减去近似后验的KL散度。
ELBO(x;θ,ϕ)=Ez∼q(z∣x;ϕ)[logq(z∣x;ϕ)p(x,z;θ)]=Ez∼q(z∣x;ϕ)[logq(z∣x;ϕ)p(x;θ)p(z∣x)]=Ez∼q(z∣x;ϕ)[logp(x;θ)]+Ez∼q(z∣x;ϕ)[logq(z∣x;ϕ)p(z∣x)]=logp(x;θ)−Ez∼q(z∣x;ϕ)[logp(z∣x)q(z∣x;ϕ)]=logp(x;θ)−KL[q(z∣x;ϕ)∥p(z∣x)],
通过拆联合概率的不同方法,ELBO拆成两个可计算的损失函数。
ELBO(x;θ,ϕ)=Ez∼q(z∣x;ϕ)[logq(z∣x;ϕ)p(z)p(x∣z;θ)]=Ez∼q(z∣x;ϕ)[p(x∣z;θ)]+Ez∼q(z∣x;ϕ)[logq(z∣x;ϕ)p(z)]=Ez∼q(z∣x;ϕ)[p(x∣z;θ)]−Ez∼q(z∣x;ϕ)[logp(z)q(z∣x;ϕ)]=Ez∼q(z∣x;ϕ)[p(x∣z;θ)]−KL[q(z∣x;ϕ)∥p(z)].
这里引入两个假设:
- p(z)的先验为0,1的正态分布,可以得到后者的解析解(同时作为正则化项,避免拟合的后验分布方差接近0,而退化成自编码器)。
KL[N(z;μ,σ2)∥N(z;0,1)]=∫zN(z;μ,σ2)logN(z;0,1)N(z;μ,σ2)dz=∫zN(z;μ,σ2)log2π1exp(−2z2)2πσ1exp(−2σ2(z−μ)2)dz=∫zN(z;μ,σ2)(−logσ−2σ2(z−μ)2+2z2)dz=21∫zN(z;μ,σ2)(−2logσ−σ2(z−μ)2+z2)dz=21(−logσ2∫zN(z;μ,σ2)dz−σ21∫zN(z;μ,σ2)(z−μ)2dz+∫zN(z;μ,σ2)z2dz)
其中 对1 , (z−μ)2,z2在正态上积分为1,方差σ2和二阶矩σ2+μ2
- p(x∣z;θ)针对于不同的任务,建模成不同的分布,再构建重建损失的具体形式。
例如,当p(x∣z;θ)建模为方差都相同的多元独立高斯分布时。
p(x∣z;θ):=N(x;μ,σ2I)=i=1∏DN(xi;μi,σi2)=(i=1∏D2πσi1)exp(i=1∑D−2σi2(xi−μi)2),
这样,对于最大化p(x∣z;θ)可以使用损失函数∣∣x−x^(z;θ)∣∣2来优化。
如果,建模成多维度二元分类问题,类似上式的方法,可以推导出二元交叉熵损失。
最后,整体的损失函数可以表达为:
Loss=n1i=1∑n[lnp(xi∣zi)+21i=1∑d(−lnσ2+σ2+μ2−1)]
重要性采样保证了对于每一个样本x都只采样一个z也能很好地估计期望。
网络结构设计


重参数化技巧
从编码器预测出的μ和σ2到采样出的z的具体值之间。因为使用了采样,导致过程并不可微分。
也就是∂μ∂z 和 ∂σ2∂z无法求出。
这里从N(ϵ;0,I)采样ϵ,并使得
z=μ+σ⊙ϵ
这样,在求∂μ∂z和∂σ2∂z,就可以分别得到1和ϵ,避免了出现由于采样而出现的不可微分的问题。
应用
当把隐变量建模为-1~1 * -1~1的网格时,并且学习到隐变量带来的知识时,可以可视化并且可控采样不同类型的数据
