这篇paper主要介绍了GAN在文本生成上的应用。GAN在2014年被提出之后,在图像生成领域取得了广泛的研究应用。然后在文本领域却一直没有很惊艳的效果。主要在于文本数据是离散数据,而GAN在应用于离散数据时存在以下几个问题:
- GAN的生成器梯度来源于判别器对于正负样本的判别。然而,对于文本生成问题,RNN输出的是一个概率序列,然后取argmax。这会导致生成器Loss不可导。还可以站在另一个角度理解,由于是argmax,所以参数更新一点点并不会改变argmax的结果,这也使得GAN不适合离散数据。
- GAN只能评估整个序列的loss,但是无法评估半句话,或者是当前生成单词对后续结果好坏的影响。
- 如果不加argmax,那么由于生成器生成的都是浮点数值,而ground truth都是one-hot encoding,那么判别器只要判别生成的结果是不是0/1序列组成的就可以了。这容易导致训练崩溃。
背景知识
什么是策略梯度下降?
对于监督学习算法,我们通常会使用梯度下降来进行优化。梯度下降算法计算损失函数的梯度需要首先计算标签与预测结果的loss。可是对于强化学习问题,根本不存在标签。我们只有通过不断的试错来实现优化。在强化学习中,每一个action都会有一个reward。整个算法就是需要最大化期望reward。对于最大化reward推导出来的公式,就是策略梯度下降算法。
什么是蒙特卡洛采样?
可以根据已知概率分布的函数产生服从此分布的样本X。具体原来可以查看30分钟了解蒙特卡洛采样。
SeqGAN原理
整个算法在GAN的框架下,结合强化学习来做文本生成。 模型示意图如下:
在文本生成任务,seqGAN相比较于普通GAN区别在以下几点:
- 生成器不取argmax。
- 每生成一个单词,则根据当前的词语序列进行蒙特卡洛采样生成完成的句子。然后将句子送入判别器计算reward。
- 根据得到的reward进行策略梯度下降优化模型。