GAN 拟合高斯分布数据Pytorch实现

写在前面:

本文不涉及原理知识,将了解到GAN的基本结构,GAN的训练过程,Pytorch的使用,初学望指正

1、应用

GAN本身是一种生成式模型,所以在数据生成上用的是最普遍的,最常见的是图片生成,常用的有DCGAN WGAN,BEGAN。目前比较有意思的应用就是GAN用在图像风格迁移,图像降噪修复,图像超分辨率了,都有比较好的结果。目前也有研究者将GAN用在对抗性攻击上,具体就是训练GAN生成对抗文本,有针对或者无针对的欺骗分类器或者检测系统等等,但是目前没有见到很典范的文章。好吧,笔者有一个项目和对抗性攻击有关,所以要学习一下GAN。

2、简介

GANs组成:生成器和判别器。结构如图1所示
针对问题: 给定一批样本,训练一个系统能够生成类似的新样本
核心思想:博弈论中的纳什均衡,
判别器D 的目的是判断数据来自生成器还是训练集,
生成器G 的目的是学习真实数据的分布,使得生成的数据更接近真实数据,
两者不断学习优化最后得到纳什平衡点。

D( x) 表示真实数据的概率分布,
G( z) 表示输入噪声z 产生的生成数据的概率分布
训练目标:G( Z)在判别器上的分布D( G( Z) ) 更接近真实数据在判别器上的分布D( X)

图1、GAN的结构

3、例子、拟合高斯分布数据

接下来就来实现我们的例子把,目标是把标准正态分布的数据,通过训练的GAN网络之后,得到的数据x_fake能尽量拟合均值为3方差为1的高斯分布N(3,1)的数据。

3.1首先定义GAN的生成器G

可以看出生成器其实就是简单的全连接网络,当然CNN,RNN等网络都是适合GAN的,根据需要选择。

# Generator  【 N(0,1)->N(3,1)】
class Generator(nn.Module):
    def __init__(self):
        super(Generator,self).__init__()
        self.net = nn.Sequential(
            #输入 z:[batch,1]
            nn.Linear(data_dim,h_dim),
            nn.LeakyReLU(True),
            nn.Linear(h_dim,data_dim)
        )
    def forward(self,z):
        output = self.net(z)
        return output

3.2定义GAN的判别器D

可以看出判别器其实也是简单的全连接网络,当然CNN,RNN等网络都是适合GAN的,根据需要选择。

# Generator  【 N(0,1)->N(3,1)】
#  Discriminator  【data -> pred】
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator,self).__init__()
        self.net=nn.Sequential(
            nn.Linear(data_dim, h_dim),
            nn.LeakyReLU(True),
            nn.Linear(h_dim, 1),
            nn.Sigmoid()
        )
    def forward(self,x):
        output = self.net(x)
        return output

3.3定义真实数据来源函数

def real_data_generator():
    while True:
        data = np.random.randn(batchsz,1)
        data = data + 3
        yield data

3.4 定义主函数,包括了训练过程

在这里想说的是对于判别器和生成器的训练是分开的,训练判别器的时候固定生成器,训练生成器的时候固定判别器,如此循环。本例子中先训练三次判别器,接着训练一次生成器。

def main():
    torch.manual_seed(66)
    np.random.seed(66)
    data_iter = real_data_generator()
    G = Generator().to(device)
    D = Discriminator().to(device)
    optim_G = optim.Adam(G.parameters(), lr=5e-4, betas=(0.5, 0.9))
    optim_D = optim.Adam(D.parameters(), lr=5e-4, betas=(0.5, 0.9))

    for epoch in range(epoches):
        # 1、train Discrimator firstly
        for _ in range(3):
            # 1.1 获取真实数据
            x = next(data_iter)
            x_real = torch.from_numpy(x).float().to(device)  # numpy -> tensor
            # 对真实数据判别
            pred_real = D(x_real)
            # max pred_real
            loss_real = -pred_real.mean()
            # 1.2 获取随机数据
            z = torch.randn(batchsz, data_dim).to(device)
            #  生成假数据
            x_fake = G(z).detach()  # tf.stop_gradient()
            # 对假数据判别
            pred_fake = D(x_fake)
            loss_fake = pred_fake.mean()
            # 计算损失函数
            loss_D = loss_real + loss_fake
            # optimize
            optim_D.zero_grad()
            loss_D.backward()
            optim_D.step()

        # 2、train Generator
        # 2.1 获取随机数据
        z = torch.randn(batchsz, data_dim).to(device)
        # 2.2 Generator生成假数据
        xf = G(z)
        # 2.3 Discrimator对假数据进行判别
        predf = D(xf)
        # 2.4 得到损失函数
        loss_G = -predf.mean()
        # optimize
        optim_G.zero_grad()
        loss_G.backward()
        optim_G.step()

        if epoch % 10 == 0:
            print("轮:",epoch,"   ",loss_D.item(), loss_G.item())
            drow_image(G,epoch)

为了便于理解具体训练过程,图2 、图3展示了判别器和生成器训练时的数据流向,具体就不展开了,参考注释。


图2、判别器训练数据流向
图3、生成器训练数据流向

3.5 画图函数

画图函数敬上

# normfun正态分布,概率密度函数
def normfun(x,mu, sigma):
    pdf = np.exp(-((x - mu)**2) / (2* sigma**2)) / (sigma * np.sqrt(2*np.pi))
    return pdf

def drow_image(G,epoch):
    # 1 画出想要拟合的分布
    data=np.random.randn(batchsz, 1)+3
    x = np.arange(0,6, 0.2)
    y = normfun(x, 3, 1)
    plt.plot(x, y, 'r', linewidth=3)
    plt.hist( data, bins=50, color='grey', alpha=0.5, rwidth=0.9, density=True)


    # 2 画出目前生成器生成的分布
    x = torch.randn(batchsz, data_dim).to(device)
    data=G(x).cpu().detach().numpy()
    mean = data.mean()
    std = data.std()
    print(mean, std)
    x = np.arange(np.floor(data.min())-5, np.ceil(data.max())+5, 0.2)
    y = normfun(x, mean, std)
    plt.plot(x, y, 'g', linewidth=3)
    plt.hist(data, bins=50, color='b', alpha=0.5, rwidth=0.9, density=True)

    # plt 的设置
    title = 'epoch' + epoch.__str__()
    plt.title(title)
    plt.xlabel('value')
    plt.ylabel('Probability')
    title ="./res/"+title+ ".png"
    plt.savefig(title)
    plt.show()

然后调用main()函数就好了

if __name__ == '__main__':
    main()

3.6 效果

红色是目标分布,蓝色是生成分布,还是有一定效果的额。


epoch0
epoch0

4 、小结

感受到是在调参了,请教我如何学习生成(xie)对抗(lun)网络(wen)。

最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 202,607评论 5 476
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 85,047评论 2 379
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 149,496评论 0 335
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 54,405评论 1 273
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 63,400评论 5 364
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 48,479评论 1 281
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 37,883评论 3 395
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 36,535评论 0 256
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 40,743评论 1 295
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 35,544评论 2 319
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 37,612评论 1 329
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 33,309评论 4 318
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 38,881评论 3 306
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 29,891评论 0 19
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 31,136评论 1 259
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 42,783评论 2 349
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 42,316评论 2 342

推荐阅读更多精彩内容

  • GAN,全称为Generative Adversarial Nets,直译为生成式对抗网络。它一方面将产生式模型拉...
    MiracleJQ阅读 3,353评论 0 14
  • 在GAN的相关研究如火如荼甚至可以说是泛滥的今天,一篇新鲜出炉的arXiv论文《Wasserstein GAN》却...
    MiracleJQ阅读 2,212评论 0 8
  • 生成对抗网络基本概念 要理解生成对抗模型(GAN),首先要了解生成对抗模型可以拆分为两个模块:一个是判别模型,另一...
    变身的大恶魔阅读 5,388评论 0 7
  • 摘要:在深度学习之前已经有很多生成模型,但苦于生成模型难以描述难以建模,科研人员遇到了很多挑战,而深度学习的出现帮...
    肆虐的悲傷阅读 11,242评论 1 21
  • GAN原理 生成对抗式网络(GAN)作为一种优秀的生成式模型,在图像生成方面有很多应用,主要有两大特点: 不依赖任...
    LuDon阅读 7,144评论 0 4