写在前面:
本文不涉及原理知识,将了解到GAN的基本结构,GAN的训练过程,Pytorch的使用,初学望指正
1、应用
GAN本身是一种生成式模型,所以在数据生成上用的是最普遍的,最常见的是图片生成,常用的有DCGAN WGAN,BEGAN。目前比较有意思的应用就是GAN用在图像风格迁移,图像降噪修复,图像超分辨率了,都有比较好的结果。目前也有研究者将GAN用在对抗性攻击上,具体就是训练GAN生成对抗文本,有针对或者无针对的欺骗分类器或者检测系统等等,但是目前没有见到很典范的文章。好吧,笔者有一个项目和对抗性攻击有关,所以要学习一下GAN。
2、简介
GANs组成:生成器和判别器。结构如图1所示
针对问题: 给定一批样本,训练一个系统能够生成类似的新样本
核心思想:博弈论中的纳什均衡,
判别器D 的目的是判断数据来自生成器还是训练集,
生成器G 的目的是学习真实数据的分布,使得生成的数据更接近真实数据,
两者不断学习优化最后得到纳什平衡点。
D( x) 表示真实数据的概率分布,
G( z) 表示输入噪声z 产生的生成数据的概率分布
训练目标:G( Z)在判别器上的分布D( G( Z) ) 更接近真实数据在判别器上的分布D( X)
3、例子、拟合高斯分布数据
接下来就来实现我们的例子把,目标是把标准正态分布的数据,通过训练的GAN网络之后,得到的数据x_fake能尽量拟合均值为3方差为1的高斯分布N(3,1)的数据。
3.1首先定义GAN的生成器G
可以看出生成器其实就是简单的全连接网络,当然CNN,RNN等网络都是适合GAN的,根据需要选择。
# Generator 【 N(0,1)->N(3,1)】
class Generator(nn.Module):
def __init__(self):
super(Generator,self).__init__()
self.net = nn.Sequential(
#输入 z:[batch,1]
nn.Linear(data_dim,h_dim),
nn.LeakyReLU(True),
nn.Linear(h_dim,data_dim)
)
def forward(self,z):
output = self.net(z)
return output
3.2定义GAN的判别器D
可以看出判别器其实也是简单的全连接网络,当然CNN,RNN等网络都是适合GAN的,根据需要选择。
# Generator 【 N(0,1)->N(3,1)】
# Discriminator 【data -> pred】
class Discriminator(nn.Module):
def __init__(self):
super(Discriminator,self).__init__()
self.net=nn.Sequential(
nn.Linear(data_dim, h_dim),
nn.LeakyReLU(True),
nn.Linear(h_dim, 1),
nn.Sigmoid()
)
def forward(self,x):
output = self.net(x)
return output
3.3定义真实数据来源函数
def real_data_generator():
while True:
data = np.random.randn(batchsz,1)
data = data + 3
yield data
3.4 定义主函数,包括了训练过程
在这里想说的是对于判别器和生成器的训练是分开的,训练判别器的时候固定生成器,训练生成器的时候固定判别器,如此循环。本例子中先训练三次判别器,接着训练一次生成器。
def main():
torch.manual_seed(66)
np.random.seed(66)
data_iter = real_data_generator()
G = Generator().to(device)
D = Discriminator().to(device)
optim_G = optim.Adam(G.parameters(), lr=5e-4, betas=(0.5, 0.9))
optim_D = optim.Adam(D.parameters(), lr=5e-4, betas=(0.5, 0.9))
for epoch in range(epoches):
# 1、train Discrimator firstly
for _ in range(3):
# 1.1 获取真实数据
x = next(data_iter)
x_real = torch.from_numpy(x).float().to(device) # numpy -> tensor
# 对真实数据判别
pred_real = D(x_real)
# max pred_real
loss_real = -pred_real.mean()
# 1.2 获取随机数据
z = torch.randn(batchsz, data_dim).to(device)
# 生成假数据
x_fake = G(z).detach() # tf.stop_gradient()
# 对假数据判别
pred_fake = D(x_fake)
loss_fake = pred_fake.mean()
# 计算损失函数
loss_D = loss_real + loss_fake
# optimize
optim_D.zero_grad()
loss_D.backward()
optim_D.step()
# 2、train Generator
# 2.1 获取随机数据
z = torch.randn(batchsz, data_dim).to(device)
# 2.2 Generator生成假数据
xf = G(z)
# 2.3 Discrimator对假数据进行判别
predf = D(xf)
# 2.4 得到损失函数
loss_G = -predf.mean()
# optimize
optim_G.zero_grad()
loss_G.backward()
optim_G.step()
if epoch % 10 == 0:
print("轮:",epoch," ",loss_D.item(), loss_G.item())
drow_image(G,epoch)
为了便于理解具体训练过程,图2 、图3展示了判别器和生成器训练时的数据流向,具体就不展开了,参考注释。
3.5 画图函数
画图函数敬上
# normfun正态分布,概率密度函数
def normfun(x,mu, sigma):
pdf = np.exp(-((x - mu)**2) / (2* sigma**2)) / (sigma * np.sqrt(2*np.pi))
return pdf
def drow_image(G,epoch):
# 1 画出想要拟合的分布
data=np.random.randn(batchsz, 1)+3
x = np.arange(0,6, 0.2)
y = normfun(x, 3, 1)
plt.plot(x, y, 'r', linewidth=3)
plt.hist( data, bins=50, color='grey', alpha=0.5, rwidth=0.9, density=True)
# 2 画出目前生成器生成的分布
x = torch.randn(batchsz, data_dim).to(device)
data=G(x).cpu().detach().numpy()
mean = data.mean()
std = data.std()
print(mean, std)
x = np.arange(np.floor(data.min())-5, np.ceil(data.max())+5, 0.2)
y = normfun(x, mean, std)
plt.plot(x, y, 'g', linewidth=3)
plt.hist(data, bins=50, color='b', alpha=0.5, rwidth=0.9, density=True)
# plt 的设置
title = 'epoch' + epoch.__str__()
plt.title(title)
plt.xlabel('value')
plt.ylabel('Probability')
title ="./res/"+title+ ".png"
plt.savefig(title)
plt.show()
然后调用main()函数就好了
if __name__ == '__main__':
main()
3.6 效果
红色是目标分布,蓝色是生成分布,还是有一定效果的额。
4 、小结
感受到是在调参了,请教我如何学习生成(xie)对抗(lun)网络(wen)。