VAE教程

好久没看的VAE又不太记得了,重新梳理一下思路,在S同学的指导下,又有了一些新的理解。之前写过一篇关于VAE的入门教程,但是感觉还不够简练,删掉重新写一个哈哈哈。这篇文章主要是从一个熟悉machine learning但是对于VAE一点都不懂的视角,进行写作的,并不涉及很多复杂的理论知识,辅助理解为主。另外,上次那篇文章是先给结论,在慢慢讲细节。这篇文章将会循序渐进逐渐推导出VAE的各种Trick存在的必要性。

VAE 入门

我们首先要明确一点,VAE是一个生成式的模型,什么是生成式的模型?简单来说,就是可以用来生成数据的模型。怎么样才能生成数据呢?就是我们是知道数据的分布P(X)的?有了这个分布之后,我们就可以从这个分布中采样,获得新的数据。

这个思路好像很简单啊,但是问题是这个P(X)是怎么得到的。有很多方法啊,其中包含这样两大类:1. 基于概率的,如MCMC, Variational Inference等。以及2. 基于机器学习的。前面提到过,本文主要面向的是对于机器学习比较熟悉的人,所以这里就对概率方法不多说了,主要讲一下机器学习的思路。

其实机器学习来解决这种问题的思路是很清晰的,大多数的机器学习问题都有这样一个思路。我们想要优化某个目标O,我们先对这个问题建个模型,模型可以表示为某个数学表达式f(x;\theta) ,其中\theta是参数。我们用数据去训练这个模型,然后根据目标O,去调整我们的参数\theta,我们希望训练结束的时候,能够找到一组最优的\theta。对应到我们这个生成式的问题,我们希望能够生成一个新的数据X,那么我们构造一个模型f(x;\theta) ,我们的目标呢就是这个生成的数据越真越好,就是在众多的\theta中,我们希望能够找到一个最好的\theta,能够让这个数据存在的概率P(X;\theta)越大越好。有了目标,我们就能计算出损失函数L,然后就是利用梯度下降,逐步调整参数,最终找到最优。

前文这个过程好像很熟悉,但是存在几个问题:

  1. 我们建模f(x;\theta)其实是根据我们的assumption来的,我们的模型结构,初始参数设置都是根据我们的assumption来的, 但是我们的assumption有可能是错的,而且很有可能是错的。因此引入过多或者过强的assumption都会导致我们的模型效果很差。
  2. 因为我们的assumption和真实数据分布存在偏差,相应的,我们在优化的过程中,很容易陷入到局部最优中。
  3. 如果我们直接采用建模的方式来解决生成式问题,那么我们通常需要构造一个相对复杂的模型,或者说参数很多的模型,来获得较大的Capacity。这样就导致我们优化的过程非常的耗时(Computationally Expensive)。

上述三个问题的存在,让我们对直接建模这个思路产生了动摇,至少直接建模并不适用于所有的场景。所以我们在直接建模的基础上做出修改,引入了隐变量的概念。我们认为数据的生成是受到隐变量z的影响的。比如手写数字生成的任务,我们在生成数字的时候,会首先考虑,我们要生成的是数字几啊,因为我们只有10种数字可以生成,这个数字几就是我们的隐变量z。有了这个隐变量,我们就不再是漫天生成数字了,我们只有10个方向去生成,这大大的缩小了我们的生成空间,降低了计算量。

Pasted image 20230127200838.png

引入隐变量,用数学公式可以表示为:
P(X)=\int P(X|z)P(z)dz=E_z(X|z)
模型的学习过程也因为隐变量的引入发生了改变。我们最终的目标还是要计算P(X),我们对隐变量的概率建个模P(X|z)=f(X,z;\theta),参数是\theta。我们要找到能让P(X)值最大的参数\theta。但是现在是要把隐变量的所有可能取值都找到,然后求一个上面这样的积分,来确定P(X)的值,从而通过比较不同的\theta对应的P(X)的值来确定哪个\theta才是最合适的。是不是觉得天衣无缝?对于手写数字来说,我们的隐变量的取值只有10个,所以这个积分就退化成了只有10项的求和。但是对于很多其他问题,这个隐变量的取值就有可能变得非常多,又不太好做了。所以我们的做法是不去计算积分了,我们做了一步近似。我们就采样一个隐变量,我们希望挑出来的参数\theta能够在这一个隐变量上表现好就行了。What?这样的近似是不是差的太多?这样真的呆胶布吗?答案是肯定的。我们虽然用单个变量代替了积分,或者说,代替了期望值,但是我们机器学习的过程是在不断的迭代的随机过程(Stochastic process)。简单来说,就是我们会找一个又一个的sample,重复的进行优化,理论上讲依然能够得到最优解(可以参考机器学习的学习理论)。总而言之,就是这样用单样本代替期望是可行的。

这个过程是不是听起来又是很合理?但是存在一个问题,我们说想要去采样隐变量z,但是从什么分布里采样?从P(z)?可以吗?可以。但是我们再思考一个问题,真的所有的隐变量都是平等的吗?回到手写数字的例子,如果我们现在要生成的是数字7,那么隐变量如果是0,8这种带圆圈的概率是不是不大。假如我们采样到了一个隐变量代表的是数字0,那么这一次采样是不是相当于浪费了?你本来就不怎么能指导我做这一次生成呀。所以,为了减少这样的无效采样,从而进一步的降低计算量,我们并不是从P(X)中采样的,而是从P(z|X)中采样的。
z=sample(P(z|X))
好了,现在又有了一个新问题,这个P(z|X)我们知道嘛?答案是不知道,不知道怎么办?不知道那就去求?用什么样的方法去求?用机器学习的去求,和前面对P(X|z)建模一样,我们在这里对P(z|X)建模为Q(z|X),然后求个Loss,再梯度下降去优化它。常见的做法是把Q建模为一个正态分布:
Q(z|X)=N(\mu,\sigma^2I)
讲到这里VAE的主体框架已经出来了:

我们用Q采样出来一个隐变量z,然后我们根据这个隐变量,利用P(X|z)生成新的图片\hat{x}。我们从这个过程中,计算损失值L,通过梯度下降的方式,不断优化QP(X|z)的参数,从而我们能够生成越来越好的图片

这个框架到目前为止已经可以说是相对很完整了,里面呢有两个函数需要去优化:QP(X|z)。这两个函数我们都用神经网络去建模,但是我们依然需要做一件事,就是去定义一个损失函数L。我们这样思考一下,我们定义损失函数,是为了能够让QP(X|z)更好。我们首先来考虑Q,如何让Q变得更好?我们回忆一下,Q是我们定义出来用来估计分布P(z|X)的,最好的Q当然就是能够跟P(z|X)一毛一样啦。那么我们很自然的就想要把目标函数,或者说损失值定义成这两个分布之间的差距了。而计算分布差距,最常用的metric之一就是KL 散度。所以,我们想要让这个公式最小化:
KL(Q(z|X)\;||\;P(z|X))
有人可能要说啦:这,怎么最小化?我们不知道这个P(z|X)是啥我们才估计的呀。没错,不过我们可以试着把这个公式变个形式,试试看:
$$
\begin{aligned}
KL(Q(z|X);||;P(z|X)) & =E(logQ(z|X)-logP(z|X))\
&=E(logQ(z|X))-E(logP(z|X))\
&=E(logQ(z|X))-E(log(\frac{P(X|z)P(z)}{P(X)}))\
&=E(logQ(z|X))-E(logP(X|z))-E(logP(z))+E(logP(X))\
&=E(logQ(z|X))-E(logP(X|z))-E(logP(z))+logP(X)\
&=KL(logQ(z|X)||P(z))-E(logP(X|z))+logP(X)

\end{aligned}
交换一下公式的项,我们得到:
logP(X)-KL(Q(z|X);||;P(z|X))=E(logP(X|z))-KL(logQ(z|X)||P(z))
$$
上面这个公式的转换过程中,没有用到什么很特别的技巧,主要就是贝叶斯公式套进去了一下,我就不多说了。重点来看一下最终的公式形式,我们发现公式左边,恰好是我们想要优化的目标,当我们让左边最大化的时候,不仅QP(z|X)越来越接近,我们的P(X)也越来越大,也就是我们的P(X|z)函数的参数也越来越好。一石二鸟!本来我们只是在考虑Q的问题,现在连带着把P(X|z)的问题也解决了。我们看一下右边是我们能计算的东西吗?答案是肯定的,第一项E(logP(X|z))是我们建模的函数,直接可以得到结果。第二项KL(logQ(z|X)||P(z))中的Q是我们建模的函数,P(z)呢是隐变量的分布,我们可以把这个分布定义为一个标准正态分布N(0,I),因为从这个标准正态分布,理论上讲我们可以映射到任意的输出空间上(当然标准正态分布也只是一个选项,很多别的分布都是可以的)。所以公式右边的两项都是可以求的,我们就可以把这个作为我们的优化目标。至此,我们的计算过程算是完善了,可以用下面这样一张图表示

Pasted image 20230127204653.png

我们首先有一个输入X,对应到我们的例子里就是一张手写数字图片。我们将这个X输入到自己定义的函数Q中,因为Q是个正态分布函数,所以我们的做法是用两个神经网络(管他们叫encoder)分别去计算期望\mu和方差\sigma^2 。有了这个分布之后,我们就可以采样出来一个隐变量的样本z,然后用这个样本在通过神经网络P(X|z) (管他叫decoder)去生成新的数据样本f(z)。在这个计算过程中,我们在encoder里计算了一个损失函数KL(Q||P(z)),在decoder里计算了一个损失函数||x-f(z)||^2(在数据为正态分布的时候等价于-logP(X|z))。

上面描述的过程更完整了,不过还有一个问题,就是中间采样的那一步。我们知道,神经网络的学习依赖于梯度下降,梯度下降就要求整个损失函数的梯度链条是存在的,或者说参数是可导的。我们看decoder的这个损失函数,在计算||x-f(z)||^2,除了涉及分布P(X|z)以外,还依赖于隐变量z,隐变量又是从分布Q中采样出来的, 所以我们在对decoder的损失函数进行梯度下降的时候,是要对Q的参数也梯度下降的。但是因为中间这一步采样,我们的梯度断掉了。采样还怎么知道是什么梯度?所以这里用到了一个小trick: Reparamterization。简单来说就是我们不再是采样了,而是看做按照分布的期望和方差,加上一些小噪音,生成出来的样本。也就是说:
z=\mu+e\sigma
其中这个e是一个随机噪声,我们可以从一个标准正态分布中采样得到。
e\sim N(0,1)
这种做法非常 好理解对吧,我们的每个样本都可以看作是这个样本服从分布的期望,根据方差进行波动的结果。通过这种变换,原来断掉的梯度链恢复啦,我们的梯度下降终于能够进行下去了。修正后的计算过程如下图。

Pasted image 20230128180833.png

以上就是VAE的主要内容,这里额外说明一点:从decoder的损失函数看,我们希望Q在估计隐变量的分布,而隐变量的分布就是一个标准正态分布,所以我们在实际生成的过程中,不需要用到encoder,只需要从标准正态分布里随便采样一个隐变量z就能进行生成了。

与自编码模型Autoencoder比较

很多人可能很熟悉自编码模型,自编码模型英文叫Autoencoder。而我们这个VAE呢,叫Variational Autoencoder,听起来好像关系很大,但是其实关系真的不是很大。只不过我们这个VAE呢也像Autoencoder一样有一个encoder,一个decoder。但是Autoencoder对于隐变量没有什么限制,它的过程就很简单,就是从输入x计算一个隐变量z,然后再把z映射到一个新的\hat{x},损失函数只有一个,就是比较x\hat{x}计算一个重构损失。这样做呢并没有很好的利用隐变量。但是它最大的缺点是生成出来的东西是和输入的x高度相关的,并不能生成出来什么很新奇的玩意,所以Autoencoder一般只能用来做做降噪什么的,并不能真正用来做生成。但是VAE就不一样了,前面我们讲过,VAE在生成阶段是完全抛开了encoder的,隐变量是从标准正态分布里随随便便采样出来的,这就摆脱了对输入的依赖,想怎么生成就怎么生成。

代码实现

"""  
    @Time    : 26/01/2023    @Software: PyCharm    @File    : model.py  
"""  
import torchvision  
import matplotlib.pyplot as plt  
import torch.nn as nn  
import torch  
import torch.nn.functional as F  
  
  
class Encoder(nn.Module):  
    def __init__(self, hidden_dim=2):  
        super(Encoder, self).__init__()  
        self.linear1 = nn.Linear(28 * 28, 512)  
        self.linear2 = nn.Linear(512, hidden_dim)  
  
    def forward(self, x):  
        """x:[N,1,28,28]"""  
        x = torch.flatten(x, start_dim=1)  # [N,764]  
        x = self.linear1(x)  # [N, 512]  
        x = F.relu(x)  
        return self.linear2(x)  # [N,2]  
  
  
class Decoder(nn.Module):  
    def __init__(self, hidden_dim):  
        super(Decoder, self).__init__()  
        self.linear1 = nn.Linear(hidden_dim, 512)  
        self.linear2 = nn.Linear(512, 28 * 28)  
  
    def forward(self, x):  
        """x:[N,2]"""  
        hidden = self.linear1(x)  # [N, 512]  
        hidden = torch.relu(hidden)  
        hidden = self.linear2(hidden)  # [N,764]  
        hidden = torch.sigmoid(hidden)  
        return torch.reshape(hidden, (-1, 1, 28, 28))  
  
  
class AutoEncoder(nn.Module):  
    def __init__(self, hidden_dim=2):  
        super(AutoEncoder, self).__init__()  
        self.name = "ae"  
        self.encoder = Encoder(hidden_dim)  
        self.decoder = Decoder(hidden_dim)  
  
    def forward(self, x):  
        return self.decoder(self.encoder(x))  
  
  
class VAEEncoder(nn.Module):  
    def __init__(self, hidden_dim=2):  
        super(VAEEncoder, self).__init__()  
        self.linear1 = nn.Linear(28 * 28, 512)  
        self.linear2 = nn.Linear(512, hidden_dim)  
        self.linear3 = nn.Linear(512, hidden_dim)  
        self.noise_dist = torch.distributions.Normal(0, 1)  
        self.kl = 0  
  
    def forward(self, x):  
        x = torch.flatten(x, start_dim=1)  
        x = self.linear1(x)  
        x = torch.relu(x)  
        mu = self.linear2(x)  
        sigma = torch.exp(self.linear3(x))  
        hidden = mu + self.noise_dist.sample(mu.shape) * sigma  
        self.kl = (sigma ** 2 + mu ** 2 - torch.log(sigma) - 1 / 2).sum()  
        return hidden  
  
  
class VAE(nn.Module):  
    def __init__(self, hidden_dim=2):  
        super(VAE, self).__init__()  
        self.name = "vae"  
        self.encoder = VAEEncoder(hidden_dim=hidden_dim)  
        self.decoder = Decoder(hidden_dim)  
        self.kl = 0  
  
    def forward(self, x):  
        hidden = self.encoder(x)  
        self.kl = self.encoder.kl  
        return self.decoder(hidden)  
  
  
if __name__ == '__main__':  
    dataset = torchvision.datasets.MNIST("data", transform=torchvision.transforms.ToTensor(), download=True)  
    print(dataset[0][0].shape)

核心的VAE 模型代码实现我贴在这里,其余代码已经上传到Github

©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 203,547评论 6 477
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 85,399评论 2 381
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 150,428评论 0 337
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 54,599评论 1 274
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 63,612评论 5 365
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 48,577评论 1 281
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 37,941评论 3 395
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 36,603评论 0 258
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 40,852评论 1 297
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 35,605评论 2 321
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 37,693评论 1 329
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 33,375评论 4 318
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 38,955评论 3 307
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 29,936评论 0 19
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 31,172评论 1 259
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 43,970评论 2 349
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 42,414评论 2 342

推荐阅读更多精彩内容