VAE学习笔记

VAE 模型

《Auto-Encoding Variational Bayes》

整体思路:

  • 建模生成模型p(x)p(x)p(xz)p(x|z)用于可控的生成
  • 最大似然估计:
  1. 找到最优参数(模型)使得p(xθ)p(\mathbf{x}| \theta)最大
  2. 全概率公式拆分边缘概率,引入隐变量zz
  3. 以上积分变换为期望形式
  4. 不能使用采样的方式估计以上期望:大部分从p(z)p(z)中采样出的zz对生成xx无帮助

引入后验分布,转换为后验分布上的期望

  • 参数化后验分布,转换为期望形式
  • ELBO引入
  • 琴生不等式,将对数操作提出来
  • 转换为重建损失和真实后验分布与参数化的后验分布的KL散度的形式
  • KL散度的解析解(p(z)p(z)的先验假设为0,1正态分布)
  • 输出假设为相互独立的高斯分布,使用神经网络进行建模
  • 重参数化技巧,使得zz对编码器的输出μ\muσ2{\sigma}^2可导

过程推导

优化参数或者模型,使得生成训练集的概率最大

θ=argmaxθp(x(1);θ)p(x(N);θ)=argmaxθlog(p(x(1);θ)p(x(N);θ))=argmaxθi=1nlogp(x(i);θ).\begin{aligned} \theta^{*}& =\arg\max_{\theta}p(\boldsymbol{x}^{(1)};\boldsymbol{\theta})\cdots p(\boldsymbol{x}^{(N)};\boldsymbol{\theta}) \\ &=\arg\max_{\theta}\log\left(p(\boldsymbol{x}^{(1)};\boldsymbol{\theta})\cdots p(\boldsymbol{x}^{(N)};\boldsymbol{\theta})\right) \\ &=\arg\max_{\boldsymbol{\theta}}\sum_{i=1}^n\log p(\boldsymbol{x}^{(i)};\boldsymbol{\theta}). \end{aligned}

引入隐变量zz,用于可控生成

p(x;θ)=zp(x,z;θ)dz=zp(z)p(xz;θ)dz.\begin{aligned} p(\boldsymbol x;\boldsymbol\theta)& =\int_zp(\boldsymbol{x},\boldsymbol{z};\boldsymbol{\theta})\mathrm{d}\boldsymbol{z} \\ &=\int_zp(\boldsymbol{z})p(\boldsymbol{x}|\boldsymbol{z};\boldsymbol{\theta})\mathrm{d}\boldsymbol{z}. \end{aligned}

转换为期望形式,但是由于直接从假设的p(z)p(z)(也就是正态分布)中采样并做蒙特卡洛估计会因为大部分zz对于生成xx毫无帮助而产生估计错误。

p(x;θ)=zp(z)p(xz;θ)dz=Ezp(z)[p(xz;θ)]1mi=1mp(xz(i);θ),z(i)p(z).\begin{aligned} p(\boldsymbol{x};\boldsymbol\theta)& =\int_zp(\boldsymbol{z})p(\boldsymbol{x}|\boldsymbol{z};\boldsymbol{\theta})\mathrm{d}\boldsymbol{z} \\ &=\mathbb{E}_{z\sim p(z)}[p(\boldsymbol{x}|\boldsymbol{z};\boldsymbol{\theta})] \\ &\approx\frac1m\sum_{i=1}^mp(\boldsymbol{x}|\boldsymbol{z}^{(i)};\boldsymbol{\theta}),\quad\boldsymbol{z}^{(i)}\sim p(\boldsymbol{z}). \end{aligned}

这里转而使用重要性采样技巧,引入后验分布,这样zz会从对生成xx更为重要的地方采样出来。

p(x;θ)=zp(z)p(xz;θ)dz=zp(zx)p(zx)p(z)p(xz;θ)dz=Ezp(zx)[p(z)p(xz;θ)p(zx)].\begin{aligned} p(\boldsymbol{x};\boldsymbol\theta)& =\int_zp(\boldsymbol{z})p(\boldsymbol{x}|\boldsymbol{z};\boldsymbol{\theta})\mathrm{d}\boldsymbol{z} \\ &=\int_z\frac{p(\boldsymbol{z}|\boldsymbol{x})}{p(\boldsymbol{z}|\boldsymbol{x})}p(\boldsymbol{z})p(\boldsymbol{x}|\boldsymbol{z};\boldsymbol{\theta})\mathrm{d}\boldsymbol{z} \\ &=\mathbb{E}_{\boldsymbol{z}\sim p(\boldsymbol{z}|\boldsymbol{x})}\left[\frac{p(\boldsymbol{z})p(\boldsymbol{x}|\boldsymbol{z};\boldsymbol{\theta})}{p(\boldsymbol{z}|\boldsymbol{x})}\right]. \end{aligned}

这里引入的后验分布也无法准确解析表示,需要使用神经网络ϕ\phi去做从p(zx;ϕ)p(z|x;\phi)p(zx)p(z|x)拟合。

这两个分布的差距足够小的时候才能用蒙特卡罗估计去计算期望,否则误差将会很大。

目前为止

主要目标:优化θ\theta使得对数似然最大。
引申目标:优化ϕ\phi使得变分后验和真实后验尽可能相似(也就是为什么要引入ELBO:最大化证据下界等价于最小化这两个分布的差距)。以此为基础,使用蒙特卡罗估计去优化θ\theta

使用琴生不等式 and log函数是凹函数 将对数函数提到期望内部。

logp(x;θ)=logEzq(zx;ϕ)[p(z)p(xz;θ)q(zx;ϕ)]Ezq(zx;ϕ)[logp(z)p(xz;θ)q(zx;ϕ)]=ELBO(x;θ,ϕ).\begin{aligned} \log p(\boldsymbol{x};\boldsymbol{\theta})& =\log\mathbb{E}_{\boldsymbol{z}\sim q(\boldsymbol{z}|\boldsymbol{x};\boldsymbol{\phi})}\left[\frac{p(\boldsymbol{z})p(\boldsymbol{x}|\boldsymbol{z};\boldsymbol{\theta})}{q(\boldsymbol{z}|\boldsymbol{x};\boldsymbol{\phi})}\right] \\ &\geq\mathbb{E}_{\boldsymbol{z}\sim q(\boldsymbol{z}|\boldsymbol{x};\boldsymbol{\phi})}\left[\log\frac{p(\boldsymbol{z})p(\boldsymbol{x}|\boldsymbol{z};\boldsymbol{\theta})}{q(\boldsymbol{z}|\boldsymbol{x};\boldsymbol{\phi})}\right] \\ &=\mathrm{ELBO}(\boldsymbol{x};\boldsymbol{\theta},\boldsymbol{\phi}). \end{aligned}

这里ELBO可以拆成两个部分,分别是

  1. 通过xx确定的zz的分布,再从该分布采样重建xx的损失(使用在zp(zx)z \sim p(z|x)上的期望表示)
  2. 近似的后验分布与真实的p(zx)p(z|x)KL散度(用下面分布近似上面的分布需要多少字节)。也就是对数似然减去近似后验的KL散度。

ELBO(x;θ,ϕ)=Ezq(zx;ϕ)[logp(x,z;θ)q(zx;ϕ)]=Ezq(zx;ϕ)[logp(x;θ)p(zx)q(zx;ϕ)]=Ezq(zx;ϕ)[logp(x;θ)]+Ezq(zx;ϕ)[logp(zx)q(zx;ϕ)]=logp(x;θ)Ezq(zx;ϕ)[logq(zx;ϕ)p(zx)]=logp(x;θ)KL[q(zx;ϕ)p(zx)],\begin{aligned} \mathrm{ELBO}(\boldsymbol{x};\boldsymbol{\theta},\boldsymbol{\phi})& =\mathbb{E}_{\boldsymbol{z}\sim q(\boldsymbol{z}|\boldsymbol{x};\boldsymbol{\phi})}\left[\log\frac{p(\boldsymbol{x},\boldsymbol{z};\boldsymbol{\theta})}{q(\boldsymbol{z}|\boldsymbol{x};\boldsymbol{\phi})}\right] \\ &=\mathbb{E}_{\boldsymbol{z}\sim q(\boldsymbol{z}|\boldsymbol{x};\boldsymbol{\phi})}\left[\log\frac{p(\boldsymbol{x};\boldsymbol{\theta})p(\boldsymbol{z}|\boldsymbol{x})}{q(\boldsymbol{z}|\boldsymbol{x};\boldsymbol{\phi})}\right] \\ &=\mathbb{E}_{\boldsymbol{z}\sim q(\boldsymbol{z}|\boldsymbol{x};\boldsymbol{\phi})}[\log p(\boldsymbol{x};\boldsymbol{\theta})]+\mathbb{E}_{\boldsymbol{z}\sim q(\boldsymbol{z}|\boldsymbol{x};\boldsymbol{\phi})}\left[\log\frac{p(\boldsymbol{z}|\boldsymbol{x})}{q(\boldsymbol{z}|\boldsymbol{x};\boldsymbol{\phi})}\right] \\ &=\log p(\boldsymbol{x};\boldsymbol{\theta})-\mathbb{E}_{\boldsymbol{z}\sim q(\boldsymbol{z}|\boldsymbol{x};\boldsymbol{\phi})}\left[\log\frac{q(\boldsymbol{z}|\boldsymbol{x};\boldsymbol{\phi})}{p(\boldsymbol{z}|\boldsymbol{x})}\right] \\ &=\log p(\boldsymbol{x};\boldsymbol{\theta})-\mathrm{KL}[q(\boldsymbol{z}|\boldsymbol{x};\boldsymbol{\phi})\parallel p(\boldsymbol{z}|\boldsymbol{x})], \end{aligned}

通过拆联合概率的不同方法,ELBO拆成两个可计算的损失函数。

ELBO(x;θ,ϕ)=Ezq(zx;ϕ)[logp(z)p(xz;θ)q(zx;ϕ)]=Ezq(zx;ϕ)[p(xz;θ)]+Ezq(zx;ϕ)[logp(z)q(zx;ϕ)]=Ezq(zx;ϕ)[p(xz;θ)]Ezq(zx;ϕ)[logq(zx;ϕ)p(z)]=Ezq(zx;ϕ)[p(xz;θ)]KL[q(zx;ϕ)p(z)].\begin{aligned} \mathrm{ELBO}(\boldsymbol{x};\boldsymbol{\theta},\boldsymbol{\phi})& =\mathbb{E}_{\boldsymbol{z}\sim q(\boldsymbol{z}|\boldsymbol{x};\boldsymbol{\phi})}\left[\log\frac{p(\boldsymbol{z})p(\boldsymbol{x}|\boldsymbol{z};\boldsymbol{\theta})}{q(\boldsymbol{z}|\boldsymbol{x};\boldsymbol{\phi})}\right] \\ &=\mathbb{E}_{\boldsymbol{z}\sim q(\boldsymbol{z}|\boldsymbol{x};\boldsymbol{\phi})}[p(\boldsymbol{x}|\boldsymbol{z};\boldsymbol{\theta})]+\mathbb{E}_{\boldsymbol{z}\sim q(\boldsymbol{z}|\boldsymbol{x};\boldsymbol{\phi})}\left[\log\frac{p(\boldsymbol{z})}{q(\boldsymbol{z}|\boldsymbol{x};\boldsymbol{\phi})}\right] \\ &=\mathbb{E}_{\boldsymbol{z}\sim q(\boldsymbol{z}|\boldsymbol{x};\boldsymbol{\phi})}[p(\boldsymbol{x}|\boldsymbol{z};\boldsymbol{\theta})]-\mathbb{E}_{\boldsymbol{z}\sim q(\boldsymbol{z}|\boldsymbol{x};\boldsymbol{\phi})}\left[\log\frac{q(\boldsymbol{z}|\boldsymbol{x};\boldsymbol{\phi})}{p(\boldsymbol{z})}\right] \\ &=\mathbb{E}_{\boldsymbol{z}\sim q(\boldsymbol{z}|\boldsymbol{x};\boldsymbol{\phi})}[p(\boldsymbol{x}|\boldsymbol{z};\boldsymbol{\theta})]-\mathrm{KL}[q(\boldsymbol{z}|\boldsymbol{x};\boldsymbol{\phi})\parallel p(\boldsymbol{z})]. \end{aligned}

这里引入两个假设:

  1. p(z)p(z)的先验为0,1的正态分布,可以得到后者的解析解(同时作为正则化项,避免拟合的后验分布方差接近0,而退化成自编码器)。

KL[N(z;μ,σ2)N(z;0,1)]=zN(z;μ,σ2)logN(z;μ,σ2)N(z;0,1)dz=zN(z;μ,σ2)log12πσexp((zμ)22σ2)12πexp(z22)dz=zN(z;μ,σ2)(logσ(zμ)22σ2+z22)dz=12zN(z;μ,σ2)(2logσ(zμ)2σ2+z2)dz=12(logσ2zN(z;μ,σ2)dz1σ2zN(z;μ,σ2)(zμ)2dz+zN(z;μ,σ2)z2dz)\begin{aligned} &\mathrm{KL}[N(z;\mu,\sigma^{2})\parallel N(z;0,1)] \\ &=\int_{z}N(z;\mu,\sigma^{2})\operatorname{log}\frac{N(z;\mu,\sigma^{2})}{N(z;0,1)}\mathrm{d}z \\ &=\int_{z}N(z;\mu,\sigma^{2})\log\frac{\frac{1}{\sqrt{2\pi}\sigma}\mathrm{exp}\Big(-\frac{(z-\mu)^{2}}{2\sigma^{2}}\Big)}{\frac{1}{\sqrt{2\pi}}\mathrm{exp}\Big(-\frac{z^{2}}{2}\Big)}\mathrm{d}z \\ &=\int_zN(z;\mu,\sigma^2)\left(-\log\sigma-\frac{(z-\mu)^2}{2\sigma^2}+\frac{z^2}2\right)\mathrm{d}z \\ &=\frac12\int_zN(z;\mu,\sigma^2)\left(-2\log\sigma-\frac{(z-\mu)^2}{\sigma^2}+z^2\right)\mathrm{d}z \\ &=\frac12\left(-\log\sigma^2\int_zN(z;\mu,\sigma^2)\mathrm{d}z-\frac1{\sigma^2}\int_zN(z;\mu,\sigma^2)(z-\mu)^2\mathrm{d}z+\int_zN(z;\mu,\sigma^2)z^2dz\right) \end{aligned}

其中 对1 , (zμ)2(z-\mu)^2,z2z^2在正态上积分为1,方差σ2\sigma^2和二阶矩σ2+μ2\sigma^2+\mu^2

  1. p(xz;θ)p(x|z;\theta)针对于不同的任务,建模成不同的分布,再构建重建损失的具体形式。

例如,当p(xz;θ)p(x|z;\theta)建模为方差都相同的多元独立高斯分布时。

p(xz;θ):=N(x;μ,σ2I)=i=1DN(xi;μi,σi2)=(i=1D12πσi)exp(i=1D(xiμi)22σi2),\begin{aligned} p(\boldsymbol{x}|\boldsymbol z;\boldsymbol\theta)& :=N(\boldsymbol{x};\boldsymbol{\mu},\boldsymbol{\sigma}^2\boldsymbol{I}) \\ &=\prod_{i=1}^DN(x_i;\mu_i,\sigma_i^2) \\ &=\left(\prod_{i=1}^D\frac1{\sqrt{2\pi}\sigma_i}\right)\exp\left(\sum_{i=1}^D-\frac{(x_i-\mu_i)^2}{2\sigma_i^2}\right), \end{aligned}

这样,对于最大化p(xz;θ)p(x|z;\theta)可以使用损失函数xx^(z;θ)2||x-\hat{x}(z;\theta)||^2来优化。
如果,建模成多维度二元分类问题,类似上式的方法,可以推导出二元交叉熵损失。

最后,整体的损失函数可以表达为:

Loss=1ni=1n[lnp(xizi)+12i=1d(lnσ2+σ2+μ21)]Loss=\frac1n\sum_{i=1}^n[\ln p(x_i|z_i)+\frac12\sum_{i=1}^d(-\ln\sigma^2+\sigma^2+\mu^2-1)]

重要性采样保证了对于每一个样本xx都只采样一个zz也能很好地估计期望。

网络结构设计

结构图1

结构图2

重参数化技巧

从编码器预测出的μ\muσ2\sigma^2到采样出的zz的具体值之间。因为使用了采样,导致过程并不可微分。
也就是zμ 和 zσ2\frac{\partial z}{\partial\mu}\text{ 和 }\frac{\partial z}{\partial\sigma^{2}}无法求出。

这里从N(ϵ;0,I)N(\boldsymbol{\epsilon};\boldsymbol{0},\boldsymbol{I})采样ϵ\boldsymbol{\epsilon},并使得

z=μ+σϵz = \boldsymbol{\mu}+\sigma\odot\boldsymbol{\epsilon}

这样,在求zμ\frac{\partial z}{\partial\mu}zσ2\frac{\partial z}{\partial\sigma^{2}},就可以分别得到1和ϵ\epsilon,避免了出现由于采样而出现的不可微分的问题。

应用

当把隐变量建模为-1~1 * -1~1的网格时,并且学习到隐变量带来的知识时,可以可视化并且可控采样不同类型的数据

隐变量可视化


VAE学习笔记
https://fengxiang777.github.io/2024/07/30/VAE学习笔记/
作者
FengXiang777
发布于
2024年7月30日
许可协议