好久没看的VAE又不太记得了,重新梳理一下思路,在S同学的指导下,又有了一些新的理解。之前写过一篇关于VAE的入门教程,但是感觉还不够简练,删掉重新写一个哈哈哈。这篇文章主要是从一个熟悉machine learning但是对于VAE一点都不懂的视角,进行写作的,并不涉及很多复杂的理论知识,辅助理解为主。另外,上次那篇文章是先给结论,在慢慢讲细节。这篇文章将会循序渐进逐渐推导出VAE的各种Trick存在的必要性。
VAE 入门
我们首先要明确一点,VAE是一个生成式的模型,什么是生成式的模型?简单来说,就是可以用来生成数据的模型。怎么样才能生成数据呢?就是我们是知道数据的分布的?有了这个分布之后,我们就可以从这个分布中采样,获得新的数据。
这个思路好像很简单啊,但是问题是这个是怎么得到的。有很多方法啊,其中包含这样两大类:1. 基于概率的,如MCMC, Variational Inference等。以及2. 基于机器学习的。前面提到过,本文主要面向的是对于机器学习比较熟悉的人,所以这里就对概率方法不多说了,主要讲一下机器学习的思路。
其实机器学习来解决这种问题的思路是很清晰的,大多数的机器学习问题都有这样一个思路。我们想要优化某个目标O,我们先对这个问题建个模型,模型可以表示为某个数学表达式 ,其中是参数。我们用数据去训练这个模型,然后根据目标O,去调整我们的参数,我们希望训练结束的时候,能够找到一组最优的。对应到我们这个生成式的问题,我们希望能够生成一个新的数据,那么我们构造一个模型 ,我们的目标呢就是这个生成的数据越真越好,就是在众多的中,我们希望能够找到一个最好的,能够让这个数据存在的概率越大越好。有了目标,我们就能计算出损失函数,然后就是利用梯度下降,逐步调整参数,最终找到最优。
前文这个过程好像很熟悉,但是存在几个问题:
- 我们建模其实是根据我们的assumption来的,我们的模型结构,初始参数设置都是根据我们的assumption来的, 但是我们的assumption有可能是错的,而且很有可能是错的。因此引入过多或者过强的assumption都会导致我们的模型效果很差。
- 因为我们的assumption和真实数据分布存在偏差,相应的,我们在优化的过程中,很容易陷入到局部最优中。
- 如果我们直接采用建模的方式来解决生成式问题,那么我们通常需要构造一个相对复杂的模型,或者说参数很多的模型,来获得较大的Capacity。这样就导致我们优化的过程非常的耗时(Computationally Expensive)。
上述三个问题的存在,让我们对直接建模这个思路产生了动摇,至少直接建模并不适用于所有的场景。所以我们在直接建模的基础上做出修改,引入了隐变量的概念。我们认为数据的生成是受到隐变量的影响的。比如手写数字生成的任务,我们在生成数字的时候,会首先考虑,我们要生成的是数字几啊,因为我们只有10种数字可以生成,这个数字几就是我们的隐变量。有了这个隐变量,我们就不再是漫天生成数字了,我们只有10个方向去生成,这大大的缩小了我们的生成空间,降低了计算量。
引入隐变量,用数学公式可以表示为:
模型的学习过程也因为隐变量的引入发生了改变。我们最终的目标还是要计算,我们对隐变量的概率建个模,参数是。我们要找到能让值最大的参数。但是现在是要把隐变量的所有可能取值都找到,然后求一个上面这样的积分,来确定的值,从而通过比较不同的对应的的值来确定哪个才是最合适的。是不是觉得天衣无缝?对于手写数字来说,我们的隐变量的取值只有10个,所以这个积分就退化成了只有10项的求和。但是对于很多其他问题,这个隐变量的取值就有可能变得非常多,又不太好做了。所以我们的做法是不去计算积分了,我们做了一步近似。我们就采样一个隐变量,我们希望挑出来的参数能够在这一个隐变量上表现好就行了。What?这样的近似是不是差的太多?这样真的呆胶布吗?答案是肯定的。我们虽然用单个变量代替了积分,或者说,代替了期望值,但是我们机器学习的过程是在不断的迭代的随机过程(Stochastic process)。简单来说,就是我们会找一个又一个的sample,重复的进行优化,理论上讲依然能够得到最优解(可以参考机器学习的学习理论)。总而言之,就是这样用单样本代替期望是可行的。
这个过程是不是听起来又是很合理?但是存在一个问题,我们说想要去采样隐变量,但是从什么分布里采样?从?可以吗?可以。但是我们再思考一个问题,真的所有的隐变量都是平等的吗?回到手写数字的例子,如果我们现在要生成的是数字7,那么隐变量如果是0,8这种带圆圈的概率是不是不大。假如我们采样到了一个隐变量代表的是数字0,那么这一次采样是不是相当于浪费了?你本来就不怎么能指导我做这一次生成呀。所以,为了减少这样的无效采样,从而进一步的降低计算量,我们并不是从中采样的,而是从中采样的。
好了,现在又有了一个新问题,这个我们知道嘛?答案是不知道,不知道怎么办?不知道那就去求?用什么样的方法去求?用机器学习的去求,和前面对建模一样,我们在这里对建模为,然后求个,再梯度下降去优化它。常见的做法是把建模为一个正态分布:
讲到这里VAE的主体框架已经出来了:
我们用Q采样出来一个隐变量,然后我们根据这个隐变量,利用生成新的图片。我们从这个过程中,计算损失值,通过梯度下降的方式,不断优化和的参数,从而我们能够生成越来越好的图片
这个框架到目前为止已经可以说是相对很完整了,里面呢有两个函数需要去优化:和。这两个函数我们都用神经网络去建模,但是我们依然需要做一件事,就是去定义一个损失函数。我们这样思考一下,我们定义损失函数,是为了能够让和更好。我们首先来考虑,如何让变得更好?我们回忆一下,是我们定义出来用来估计分布的,最好的当然就是能够跟一毛一样啦。那么我们很自然的就想要把目标函数,或者说损失值定义成这两个分布之间的差距了。而计算分布差距,最常用的metric之一就是KL 散度。所以,我们想要让这个公式最小化:
有人可能要说啦:这,怎么最小化?我们不知道这个是啥我们才估计的呀。没错,不过我们可以试着把这个公式变个形式,试试看:
$$
\begin{aligned}
KL(Q(z|X);||;P(z|X)) & =E(logQ(z|X)-logP(z|X))\
&=E(logQ(z|X))-E(logP(z|X))\
&=E(logQ(z|X))-E(log(\frac{P(X|z)P(z)}{P(X)}))\
&=E(logQ(z|X))-E(logP(X|z))-E(logP(z))+E(logP(X))\
&=E(logQ(z|X))-E(logP(X|z))-E(logP(z))+logP(X)\
&=KL(logQ(z|X)||P(z))-E(logP(X|z))+logP(X)
\end{aligned}
logP(X)-KL(Q(z|X);||;P(z|X))=E(logP(X|z))-KL(logQ(z|X)||P(z))
$$
上面这个公式的转换过程中,没有用到什么很特别的技巧,主要就是贝叶斯公式套进去了一下,我就不多说了。重点来看一下最终的公式形式,我们发现公式左边,恰好是我们想要优化的目标,当我们让左边最大化的时候,不仅和越来越接近,我们的也越来越大,也就是我们的函数的参数也越来越好。一石二鸟!本来我们只是在考虑的问题,现在连带着把的问题也解决了。我们看一下右边是我们能计算的东西吗?答案是肯定的,第一项是我们建模的函数,直接可以得到结果。第二项中的是我们建模的函数,呢是隐变量的分布,我们可以把这个分布定义为一个标准正态分布,因为从这个标准正态分布,理论上讲我们可以映射到任意的输出空间上(当然标准正态分布也只是一个选项,很多别的分布都是可以的)。所以公式右边的两项都是可以求的,我们就可以把这个作为我们的优化目标。至此,我们的计算过程算是完善了,可以用下面这样一张图表示
我们首先有一个输入,对应到我们的例子里就是一张手写数字图片。我们将这个输入到自己定义的函数中,因为是个正态分布函数,所以我们的做法是用两个神经网络(管他们叫encoder)分别去计算期望和方差 。有了这个分布之后,我们就可以采样出来一个隐变量的样本,然后用这个样本在通过神经网络 (管他叫decoder)去生成新的数据样本。在这个计算过程中,我们在encoder里计算了一个损失函数,在decoder里计算了一个损失函数(在数据为正态分布的时候等价于)。
上面描述的过程更完整了,不过还有一个问题,就是中间采样的那一步。我们知道,神经网络的学习依赖于梯度下降,梯度下降就要求整个损失函数的梯度链条是存在的,或者说参数是可导的。我们看decoder的这个损失函数,在计算,除了涉及分布以外,还依赖于隐变量,隐变量又是从分布中采样出来的, 所以我们在对decoder的损失函数进行梯度下降的时候,是要对的参数也梯度下降的。但是因为中间这一步采样,我们的梯度断掉了。采样还怎么知道是什么梯度?所以这里用到了一个小trick: Reparamterization。简单来说就是我们不再是采样了,而是看做按照分布的期望和方差,加上一些小噪音,生成出来的样本。也就是说:
其中这个是一个随机噪声,我们可以从一个标准正态分布中采样得到。
这种做法非常 好理解对吧,我们的每个样本都可以看作是这个样本服从分布的期望,根据方差进行波动的结果。通过这种变换,原来断掉的梯度链恢复啦,我们的梯度下降终于能够进行下去了。修正后的计算过程如下图。
以上就是VAE的主要内容,这里额外说明一点:从decoder的损失函数看,我们希望Q在估计隐变量的分布,而隐变量的分布就是一个标准正态分布,所以我们在实际生成的过程中,不需要用到encoder,只需要从标准正态分布里随便采样一个隐变量就能进行生成了。
与自编码模型Autoencoder比较
很多人可能很熟悉自编码模型,自编码模型英文叫Autoencoder。而我们这个VAE呢,叫Variational Autoencoder,听起来好像关系很大,但是其实关系真的不是很大。只不过我们这个VAE呢也像Autoencoder一样有一个encoder,一个decoder。但是Autoencoder对于隐变量没有什么限制,它的过程就很简单,就是从输入计算一个隐变量,然后再把映射到一个新的,损失函数只有一个,就是比较和计算一个重构损失。这样做呢并没有很好的利用隐变量。但是它最大的缺点是生成出来的东西是和输入的高度相关的,并不能生成出来什么很新奇的玩意,所以Autoencoder一般只能用来做做降噪什么的,并不能真正用来做生成。但是VAE就不一样了,前面我们讲过,VAE在生成阶段是完全抛开了encoder的,隐变量是从标准正态分布里随随便便采样出来的,这就摆脱了对输入的依赖,想怎么生成就怎么生成。
代码实现
"""
@Time : 26/01/2023 @Software: PyCharm @File : model.py
"""
import torchvision
import matplotlib.pyplot as plt
import torch.nn as nn
import torch
import torch.nn.functional as F
class Encoder(nn.Module):
def __init__(self, hidden_dim=2):
super(Encoder, self).__init__()
self.linear1 = nn.Linear(28 * 28, 512)
self.linear2 = nn.Linear(512, hidden_dim)
def forward(self, x):
"""x:[N,1,28,28]"""
x = torch.flatten(x, start_dim=1) # [N,764]
x = self.linear1(x) # [N, 512]
x = F.relu(x)
return self.linear2(x) # [N,2]
class Decoder(nn.Module):
def __init__(self, hidden_dim):
super(Decoder, self).__init__()
self.linear1 = nn.Linear(hidden_dim, 512)
self.linear2 = nn.Linear(512, 28 * 28)
def forward(self, x):
"""x:[N,2]"""
hidden = self.linear1(x) # [N, 512]
hidden = torch.relu(hidden)
hidden = self.linear2(hidden) # [N,764]
hidden = torch.sigmoid(hidden)
return torch.reshape(hidden, (-1, 1, 28, 28))
class AutoEncoder(nn.Module):
def __init__(self, hidden_dim=2):
super(AutoEncoder, self).__init__()
self.name = "ae"
self.encoder = Encoder(hidden_dim)
self.decoder = Decoder(hidden_dim)
def forward(self, x):
return self.decoder(self.encoder(x))
class VAEEncoder(nn.Module):
def __init__(self, hidden_dim=2):
super(VAEEncoder, self).__init__()
self.linear1 = nn.Linear(28 * 28, 512)
self.linear2 = nn.Linear(512, hidden_dim)
self.linear3 = nn.Linear(512, hidden_dim)
self.noise_dist = torch.distributions.Normal(0, 1)
self.kl = 0
def forward(self, x):
x = torch.flatten(x, start_dim=1)
x = self.linear1(x)
x = torch.relu(x)
mu = self.linear2(x)
sigma = torch.exp(self.linear3(x))
hidden = mu + self.noise_dist.sample(mu.shape) * sigma
self.kl = (sigma ** 2 + mu ** 2 - torch.log(sigma) - 1 / 2).sum()
return hidden
class VAE(nn.Module):
def __init__(self, hidden_dim=2):
super(VAE, self).__init__()
self.name = "vae"
self.encoder = VAEEncoder(hidden_dim=hidden_dim)
self.decoder = Decoder(hidden_dim)
self.kl = 0
def forward(self, x):
hidden = self.encoder(x)
self.kl = self.encoder.kl
return self.decoder(hidden)
if __name__ == '__main__':
dataset = torchvision.datasets.MNIST("data", transform=torchvision.transforms.ToTensor(), download=True)
print(dataset[0][0].shape)
核心的VAE 模型代码实现我贴在这里,其余代码已经上传到Github。