利用GAN生成CIFAR10图片

要点

  • GAN (Generative Adversarial Networks,生成对抗网络)是一种可以训练生成模型(generative model)的架构
  • GAN由判别器(discriminator)和生成器(generator)两部分组成,判别器用于识别生成器结果的真实性,即生成的图片是真实的还是计算机生成的,而生成器则根据判别器的识别结果努力生成虚假(fake)但看似真实(plausible)的图片以欺骗判别器。在训练过程中,生成器会根据判别器的性能不断更新自己的模型权重,判别器和生成器不断地进行对抗,犹如一个动态的博弈过程,并最终趋于平衡,即判别器无法正确识别图片的真假,生成器生成的图片接近于真实的图片
  • 模型架构上,GAN是生成器和判别器的堆叠,生成器以隐空间(latent space)中的随机点为输入,输出结果为图片样本,而判别器以生成器输出的图片样本为输入,其输出结果是图片的真假,判别器的输出结果将用于更新生成器的模型权重
  • 当判别器被看作一个独立模型时,可单独对其进行训练,因为判别器只关心图片样本的真假。当判别器和生成器被看作一个整体时,判别器的各层在训练过程中需保持冻结,以避免被虚假样本过度训练。以上两点看似矛盾,实际上可通过tf.keras API巧妙地实现:一个模型可被训练还是已被冻结,这个属性只有模型被编译后才能影响模型,具体细节见代码部分
  • 当判别器和生成器被看作一个整体时,生成器输出的图片样本都要标记为“真”即class = 1。这样做的原因是,当判别器认为图片样本为“假”(即class = 0)或者为“真”的概率较低时,后向传播过程会视其为巨大误差,并据此更新生成器的模型权重以纠正这一误差,也就是让生成器更好地生成虚假样本
  • 判别器没有pooling layer,而是采用2x2的stride,其效果和pooling layer类似
  • 隐空间是一个向量空间(呈高斯分布),其本身没有意义,但是生成器可以赋予隐空间意义。经过训练后,隐向量空间可看作是生成图片的压缩表示
  • deconvolution有两种实现形式,第一种是先上采样再卷积(UpSampling2D→Conv2D),第二种是直接采用Conv2DTranspose,本文代码部分选择第二种形式
  • 将Conv2DTranspose层中的stride配置为2x2,可使输入的feature map的面积增大4倍,同时,将kernel size的大小设置为stride的倍数(比如4x4)还可避免出现checkerboard pattern
  • LeakyReLU的slope建议设为0.2
  • CIFAR-10数据集有60000张32x32彩色图片,包含10个分类,如青蛙、鸟、猫、船、飞机等等,由CIFAR(Canadian Institute For Advanced Research)开发,图片尺寸较小,主要用于计算机视觉研究

代码部分

# load required libraries

import tensorflow as tf
import matplotlib.pyplot as plt
from keras.utils.vis_utils import plot_model
import numpy as np
tf.__version__

'2.0.0'

# load CIFAR10 datasets

(x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()
# print the shapes of training and test data

x_train.shape, y_train.shape, x_test.shape, y_test.shape

((50000, 32, 32, 3), (50000, 1), (10000, 32, 32, 3), (10000, 1))

# plot training data

fig = plt.gcf()
fig.set_size_inches(10,10)
for i in range(49):
    plt.subplot(7,7,1+i)
    plt.imshow(x_train[i])
image.png
# define the standalone discriminator model

def discriminator_model():
    model = tf.keras.models.Sequential([
        tf.keras.layers.Conv2D(filters = 64, kernel_size = (3,3), padding = 'same', input_shape = (32,32,3)),
        tf.keras.layers.LeakyReLU(alpha = 0.2),
        tf.keras.layers.Conv2D(filters = 128, kernel_size = (3,3), padding = 'same', strides = (2,2)),
        tf.keras.layers.LeakyReLU(alpha = 0.2),
        tf.keras.layers.Conv2D(filters = 128, kernel_size = (3,3), padding = 'same', strides = (2,2)),
        tf.keras.layers.LeakyReLU(alpha = 0.2),
        tf.keras.layers.Conv2D(filters = 256, kernel_size = (3,3), padding = 'same', strides = (2,2)),
        tf.keras.layers.LeakyReLU(alpha = 0.2),
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dropout(0.4),
        tf.keras.layers.Dense(units = 1, activation = 'sigmoid')
    ])
    model.compile(loss = 'binary_crossentropy', 
              optimizer = tf.keras.optimizers.Adam(lr = 0.0002, beta_1 = 0.5), metrics = ['accuracy'])
    
    return model

提示:判别器没有pooling layer,而是采用2*2的stride,其效果和pooling layer类似

# show the summary and graph of the discriminator model 

model = discriminator_model()

model.summary()
image.png
# convert unsigned int to float32

x_train = x_train.astype('float32')
x_train = (x_train - 127.5)/127.5

提示:生成器以tanh为激活函数,其生成的像素值范围为[-1,1],因此,真实图片的像素值范围也应从[0,255]标准化为[-1,1]

# generate points in latent space as the inputs of the generator

def generate_latent_points(latent_dim,n_samples):
    x_input = np.random.randn(latent_dim * n_samples)
    x_input = x_input.reshape(n_samples,latent_dim)
    return x_input    
# randomly select n real samples

def generate_real_samples(dataset, n_samples):
    # define random instances
    ix = np.random.randint(0, dataset.shape[0], n_samples)
    # retrieve selected images
    x = dataset[ix]
    # generate class label (label = 1)
    y = np.ones((n_samples,1))
    return x,y

# generate n fake samples with class label

def generate_fake_samples(g_model, latent_dim, n_samples):
    # generate points in latent space
    x_input = generate_latent_points(latent_dim, n_samples)
    # predict outputs
    x = g_model.predict(x_input)
    # generate class label (label = 0)
    y = np.zeros((n_samples,1))
    return x,y
# define the standalone generator model

def generator_model(latent_dim):
    n_nodes = 256*4*4
    model = tf.keras.models.Sequential([
        tf.keras.layers.Dense(units = n_nodes, input_dim = latent_dim),
        tf.keras.layers.LeakyReLU(alpha = 0.2),
        tf.keras.layers.Reshape((4,4,256)),
        # upsample to 8*8
        tf.keras.layers.Conv2DTranspose(filters = 128, kernel_size = (4,4), padding = 'same', strides = (2,2)),
        tf.keras.layers.LeakyReLU(alpha = 0.2),
        # upsample to 16*16
        tf.keras.layers.Conv2DTranspose(filters = 128, kernel_size = (4,4), padding = 'same', strides = (2,2)),
        tf.keras.layers.LeakyReLU(alpha = 0.2),
        # upsample to 32*32
        tf.keras.layers.Conv2DTranspose(filters = 128, kernel_size = (4,4), padding = 'same', strides = (2,2)),
        tf.keras.layers.LeakyReLU(alpha = 0.2),
        # output layer
        tf.keras.layers.Conv2D(filters = 3, kernel_size = (3,3), activation = 'tanh', padding = 'same')      
    ])
    return model
# show the summary and graph of the generator model

model = generator_model(100)

model.summary()
image.png
# define gan model (only generator model can be updated)

def gan_model(g_model, d_model):
    # freeze discriminator model
    d_model.trainable = False
    
    model = tf.keras.models.Sequential([
        g_model,
        d_model
    ])
    
    model.compile(loss = 'binary_crossentropy', 
              optimizer = tf.keras.optimizers.Adam(lr = 0.0002, beta_1 = 0.5))
    
    return model
# show the summary and graph of the gan model

latent_dim = 100

g_model = generator_model(latent_dim)

d_model = discriminator_model()

gan_model = gan_model(g_model,d_model)

gan_model.summary()
image.png
# show and save the plots of generated images

def save_plot(examples, epoch, n = 7):
    # scale from [-1,1] to [0,1]
    examples = (examples + 1)/2.0
    # make plot
    for i in range(n*n):
        plt.subplot(n,n,i+1)
        plt.imshow(examples[i])
    
    # save plots
    filename = 'generated_plot_e%03d.png' % (epoch+1)
    plt.savefig(filename)

# evaluate discriminator model performance, display generated images, save generator model

def summarize_performance(epoch, g_model, d_model, dataset, latent_dim, n_samples = 150):
    # prepare real samples
    x_real, y_real = generate_real_samples(dataset, n_samples)
    # evaluate discriminator on real samples
    _, acc_real = d_model.evaluate(x_real, y_real, verbose = 0)
    # prepare fake samples
    x_fake, y_fake = generate_fake_samples(g_model, latent_dim, n_samples)
    # evaluate discriminator on fake samples
    _, acc_fake = d_model.evaluate(x_fake, y_fake, verbose = 0)
    # display discriminator performance
    print('>Accuracy real: %.0f%%, fake: %.0f%%' % (acc_real*100, acc_fake*100))
    
    # show and save the plots of generated images
    save_plot(x_fake, epoch)
    
    # save generator model tile file
    filename = 'generator_model_%03d.h5' % (epoch+1)
    g_model.save(filename)
# train gan model

def train_gan(g_model, d_model, gan_model, dataset, latent_dim, n_epochs = 20, n_batch = 128):
    bat_per_epoch = int(dataset.shape[0] / n_batch)
    half_batch = int(n_batch / 2)
    # manually enumerate epochs
    for i in range(n_epochs):
        # enumerate batches
        for j in range(bat_per_epoch):
            # randomly select n real samples
            x_real, y_real = generate_real_samples(dataset, half_batch)
            # update standalone discriminator model
            d_loss1, _ = d_model.train_on_batch(x_real, y_real)
            # generate fake samples
            x_fake, y_fake = generate_fake_samples(g_model, latent_dim, half_batch)
            # update standalone discriminator model again
            d_loss2, _ = d_model.train_on_batch(x_fake, y_fake)
            # generate points in latent space as the inputs of generator model
            x_gan = generate_latent_points(latent_dim, n_batch)
            # generate class label for fake samples (label = 1)
            y_gan = np.ones((n_batch,1))
            # update the generator model with discriminator model errors
            g_loss = gan_model.train_on_batch(x_gan, y_gan)
            # display the loss
            print('>%d, %d/%d, d1=%.3f, d2=%.3f g=%.3f' % (i+1, j+1, bat_per_epoch, d_loss1, d_loss2, g_loss))
        
        # evaluate model performance every 5 epochs  
        if (i + 1)%5 == 0:
            summarize_performance(i, g_model, d_model, dataset, latent_dim)

提示:GAN的目标是让生成器生成“看似真实”的图片,然而这些图片的质量高低无法通过客观的误差指标来体现,只能由程序员进行人工判读。换言之,即程序员不检查图片的质量,就不知道什么时候该停止训练。例如,某一个epoch结束后,生成器输出的图片质量很高,此时若不停止训练,之后生成的图片质量会发生波动(GAN的对抗性导致每一个batch后生成器都会发生变化),也可能提升,也可能降低。因此,在实际训练过程中,程序员要周期性地评估判别器分辨真假图片的能力(即分类精度),也要周期性地生成图片并进行人工判读,还要周期性地保存生成器模型

# train gan model

train_gan(g_model, d_model, gan_model, x_train, latent_dim)

1, 1/390, d1=0.376, d2=0.280 g=1.740
1, 2/390, d1=0.351, d2=0.322 g=1.679
1, 3/390, d1=0.274, d2=0.299 g=1.866
1, 4/390, d1=0.301, d2=0.272 g=2.027
1, 5/390, d1=0.257, d2=0.230 g=2.256
1, 6/390, d1=0.204, d2=0.186 g=2.558
……
……
20, 387/390, d1=0.724, d2=0.636 g=0.865
20, 388/390, d1=0.665, d2=0.623 g=0.837
20, 389/390, d1=0.678, d2=0.717 g=0.867
20, 390/390, d1=0.718, d2=0.606 g=0.960
Accuracy real: 51%, fake: 89%

image.png

提示:本代码以20个epochs为示例,每5个epochs评估一次模型性能,20个epochs共评估模型性能4次,生成图片4副,保存模型4个。接下来,就可以用性能最好的生成器生成图片了。

# generate images with final generator model

model = tf.keras.models.load_model('generator_model_020.h5') # load model saved after 20 epochs

latent_points = generate_latent_points(100,100) # generate points in latent space

X = model.predict(latent_points) # generate images

X = (X + 1)/2.0 # scale the range from [-1,1] to [0,1]
# plot the images

fig = plt.gcf()
fig.set_size_inches(20,20)
for i in range(100):
    plt.subplot(10,10,1+i)
    plt.imshow(X[i])
image.png
最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 203,179评论 5 476
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 85,229评论 2 380
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 150,032评论 0 336
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 54,533评论 1 273
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 63,531评论 5 365
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 48,539评论 1 281
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 37,916评论 3 395
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 36,574评论 0 256
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 40,813评论 1 296
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 35,568评论 2 320
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 37,654评论 1 329
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 33,354评论 4 318
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 38,937评论 3 307
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 29,918评论 0 19
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 31,152评论 1 259
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 42,852评论 2 349
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 42,378评论 2 342

推荐阅读更多精彩内容