GAN原理
生成对抗式网络(GAN)作为一种优秀的生成式模型,在图像生成方面有很多应用,主要有两大特点:
- 不依赖任何先验假设。传统的许多方法会假设数据服从某一分布,然后使用极大似然去估计数据分布。
-
生成real-like样本的方式非常简单。GAN生成real-like样本的方式通过生成器的前向传播。
GAN的主要灵感来自于博弈论中的零和博弈的思想,应用到深度学习的神经网络中来就是,通过生成网络G和判别网络D不断博弈,以便使生成器G学习到数据的分布。生成器G和判别器D的主要功能是: - G是生成式的网络,接收一个随机的噪声,生成图像;
- D是一个判别网络,判别一张图片是不是真实的。
训练过程中,生成器的目标是尽可能生成真实的图片使判别器认为是真的,而判别器的目标是尽可能的辨别出生成器生成的假图像和真实的图像。这样G和D就构成了一个博弈的过程,最终会达到一个平衡,这个平衡成为纳什平衡。
GAN的目标函数定义为:
这个目标函数可以分为两个部分来理解:
判别器的优化通过实现,为判别器的目标函数,其第一项表示对于从真实数据分布中采用的样本 ,其被判别器判定为真实样本概率的数学期望。对于真实数据分布中采样的样本,其预测为正样本的概率当然是越接近1越好。因此希望最大化这一项。第二项表示:对于从噪声分布当中采样得到的样本经过生成器生成之后得到的生成图片,然后送入判别器,其预测概率的负对数的期望,这个值自然是越大越好,这个值越大, 越接近0,也就代表判别器越好。
生成器的优化通过实现。注意,生成器的目标不是,即生成器不是最小化判别器的目标函数,生成器最小化的是判别器目标函数的最大值,判别器目标函数的最大值代表的是真实数据分布与生成数据分布的JS散度,JS散度可以度量分布的相似性,两个分布越接近,JS散度越小。
在实际训练中,生成器和判别器采取较训练,即先训练D,然后训练G,不断的往复。
GAN有两个特点:
- 相比较传统的模型,存在两个不同的网络,训练方式为对抗式训练。
- G的梯度信息来自于判别器D,而不是样本数据。
GAN的优点:
- 相比于其他的生成模型,GAN只利用了反向传播,不需要复杂的马尔科夫链。而且可以生成更清晰真实的样本。
- GAN可以用到很多场景上,比如图片风格迁移、超分辨率、图像补全等等。
GAN的缺点:
- 训练需要达到纳什平衡,现在还没有一个很好的达到纳什平衡的方法。
- GAN不适合处理离散形式的数据。
- GAN存在训练不稳定、梯度消失、模式崩溃等问题。原因:GAN是采用对抗的训练方式,G的损失来自于D,因此G训练的效果的好坏与否是由D决定的。如果某一次G生成的样本可能并不是真是的,但是D还是给出了正确的评价,或者说G生成的结果中的某些特征得到了D的认可,那么G就认为自己生成的结果是正确的,这样子自我欺骗下去就会导致最终生成的额结果会丢失一些信息,特征不全。
GAN的改进
DCGAN
DCGAN是GAN早期效果最好的GAN。
DCGAN主要贡献:引入了卷积并给了比较优雅的网络结构,DCGAN的生成器何判别器几乎是对称的。
DCGAN使用CNN结构来稳定GAN的训练,并使用了一些trick:
- Batch Normalization
- 使用转置卷积进行上采样
- 使用leaky relu作为激活函数。
WGAN
使用Wasserstein 距离代替KL散度,作为loss函数。
WGAN也可以用最优传书理论来解释,WGAN的生成器等价于求解最优传输映射,判别器等价于计算Wasserstein 距离,即最优的传输代价。代码实现部分具有如下:
- 判别器左右一层去掉sigmoid
- 生成器和判别器loss不取log
- 每个更新判别器的参数之后把他们的绝对值截断到不超过一个固定的常数c。
LSGAN
LSGAN的loss如下:
其中a=1, b=0.
LSGAN的有点:
- 稳定训练:解决了传统GAN训练过程中的梯度饱和问题
- 改善生成质量:通过惩罚原理判别器决策边界的生成样本来实现。
GAN训练过程中的问题
理论上存在的问题
经典的GAN判别器的两种loss,分别为:
第一个loss:当判别器达到最优的时候,等价于最小化生成分布于真实分布之间的JS散度,由于随机生成分布很难与真实分布有不可忽略的重叠以及JS散度的突变特性,使得生成器面临梯度消失的问题。
第二个loss:在最优判别器下,等价于既要最小化生成分布与真实分布之间的KL散度,又要最大化JS散度,相互矛盾,使得梯度不稳定,而且KL散度的不对成型使得生成器宁可丧失多样性也不远丧失准确性,导致模式崩溃的现象。
在实践中存在的问题
GAN在实践中存在两个问题:
- GAN在实际过程中,很难达到理论上的纳什平衡
- GAN的优化目标是一个极小极大问题,即也就是说,优化生成器时,最小化。由于学习过程是迭代优化的,要保证最大化,需要迭代非常多次,导致训练的时间很长。
极小极大值问题不等于极大极小值问题。生成器先生成一些样本,然后判别器给出错误的判断结果并惩罚生成器,于是生成器调整生成的样本的概率分布。这样往往会导致生成器变得很懒,只生成一些简单的重复的样本,缺乏多样性。
一些训练技巧
稳定训练的技巧
特征匹配:使用判别器某一层的特征作为输出,来计算真实图片和生成图片之间的距离,作为loss。
标签平滑:把标签1替换成0.8-1.0的随机数。
PatchGAN:PatchGAN相当于对图像的每一个小的patch就行判别,这样可以使生成器生成更加锐利清晰的边缘。
具体做法:输入256256的图像到判别器,输出44的置信图。置信图中的每一个像素代表当前patch是真实图像的置信度,当前图像patch的大小就是感受野的大小,最后将所有patch的loss平均作为最终的loss。
模式崩溃的解决方法
目标函数:更新k次生成器,每次的参考loss是迭代之后的判别器的loss(这个loss只更新生成器,不更新判别器)。这样使生成器考虑到了后面k次判别器的变化情况,避免在不同的mode之间切换。
网络结构:采用多个生成器,一个判别器以保障样本生成的多样性,结构如下:
GAN的应用
GAN及其改进算法大都是围绕GAN两大常见问题改进的:模式崩溃和训练不稳定。
由于GAN在生成样本过程中不需要显式建模任何数据分布就可以生成real-like的样本,所以GAN在图像、文本,语音等诸多领域都有广泛的应用。
未来的方向
GAN的训练崩溃和模式崩溃等问题有待进一步研究。
参考文献
[1]万字综述之生成对抗网络(GAN)
[2]WGAN介绍
[3]令人拍案叫绝的Wasserstein GAN