TensorFlow 使用预训练模型 ResNet-50(续)

        上一篇文章 TensorFlow 使用预训练模型 ResNet-50 介绍了使用 tf.contrib.slim 模块来简单导入 TensorFlow 预训练模型参数,进而使用 slim.learning.train 函数来 fine tuning 模型。这一篇文章,在预告的多任务多标签之前,再插入一篇简单的文章,延续 TensorFlow 导入预训练模型并精调神经网络参数的这个主题。这篇文章使用的方法和上一篇的明显不同,不过方法依旧非常简单,只需要使用类 tf.train.Saver 及其成员函数 .restore 即可。

模型定义、训练及预训练参数导入

        首先,我们来梳理一下要使用预训练模型来微调神经网络需要做的事情:1.定义神经网络结构;2.导入预训练模型参数;3.读取数据进行训练;4.使用 Tensorboard 可视化训练过程(此处略,留到以后单独讲)。清楚了以上步骤之后,我们来看如下全部代码:

# -*- coding: utf-8 -*-
"""
Created on Tue May  8 13:58:54 2018

@author: shirhe-lyh
"""

import numpy as np
import os
import tensorflow as tf

from tensorflow.contrib.slim import nets

slim = tf.contrib.slim


def get_next_batch(batch_size=64, ...):
       """Get a batch set of training data.
       
       Args:
           batch_size: An integer representing the batch size.
           ...: Additional arguments.

       Returns:
           images: A 4-D numpy array with shape [batch_size, height, width, 
               num_channels] representing a batch of images.
           labels: A 1-D numpy array with shape [batch_size] representing
               the groundtruth labels of the corresponding images.
       """
       ...  # Get images and the corresponding groundtruth labels.
       return images, labels


if __name__ == '__main__':
    # Specify which gpu to be used
    os.environ["CUDA_VISIBLE_DEVICES"] = "0"
    
    batch_size = 64
    num_classes = 5
    num_steps = 10000
    resnet_model_path = '···/resnet_v1_50.ckpt'  # Path to the pretrained model
    model_save_path = '···/model'  # Path to the model.ckpt-(num_steps) will be saved
    ...  # Any other constant variables

    inputs = tf.placeholder(tf.float32, shape=[None, 224, 224, 3], name='inputs')
    labels = tf.placeholder(tf.int32, shape=[None], name='labels')
    is_training = tf.placeholder(tf.bool, name='is_training')
    
    with slim.arg_scope(nets.resnet_v1.resnet_arg_scope()):
        net, endpoints = nets.resnet_v1.resnet_v1_50(inputs, num_classes=None,
                                                     is_training=is_training)
        
    with tf.variable_scope('Logits'):
        net = tf.squeeze(net, axis=[1, 2])
        net = slim.dropout(net, keep_prob=0.5, scope='scope')
        logits = slim.fully_connected(net, num_outputs=num_classes,
                                      activation_fn=None, scope='fc')
        
    checkpoint_exclude_scopes = 'Logits'
    exclusions = None
    if checkpoint_exclude_scopes:
        exclusions = [
            scope.strip() for scope in checkpoint_exclude_scopes.split(',')]
    variables_to_restore = []
    for var in slim.get_model_variables():
        excluded = False
        for exclusion in exclusions:
            if var.op.name.startswith(exclusion):
                excluded = True
        if not excluded:
            variables_to_restore.append(var)
            
    losses = tf.nn.sparse_softmax_cross_entropy_with_logits(
        labels=labels, logits=logits)
    loss = tf.reduce_mean(losses)
    
    logits = tf.nn.softmax(logits)
    classes = tf.argmax(logits, axis=1, name='classes')
    accuracy = tf.reduce_mean(tf.cast(
        tf.equal(tf.cast(classes, dtype=tf.int32), labels), dtype=tf.float32))
    
    optimizer = tf.train.AdamOptimizer(learning_rate=0.0001)
    train_step = optimizer.minimize(loss)
    
    init = tf.global_variables_initializer()
    
    saver_restore = tf.train.Saver(var_list=variables_to_restore)
    saver = tf.train.Saver(tf.global_variables())
    
    config = tf.ConfigProto(allow_soft_placement = True) 
    config.gpu_options.per_process_gpu_memory_fraction = 0.95
    with tf.Session(config=config) as sess:
        sess.run(init)
        
        # Load the pretrained checkpoint file xxx.ckpt
        saver_restore.restore(sess, resnet_model_path)
        
        for i in range(num_steps):
            images, groundtruth_lists = get_next_batch(batch_size, ...)
            train_dict = {inputs: images, 
                          labels: groundtruth_lists,
                          is_training: True}

            sess.run(train_step, feed_dict=train_dict)
            
            loss_, acc_ = sess.run([loss, accuracy], feed_dict=train_dict)
            
            train_text = 'Step: {}, Loss: {:.4f}, Accuracy: {:.4f}'.format(
                i+1, loss_, acc_)
            print(train_text)
                
            if (i+1) % 1000 == 0:
                saver.save(sess, model_save_path, global_step=i+1)
                print('save mode to {}'.format(model_save_path))

        main 函数的第一行 os.environ["CUDA_VISIBLE_DEVICES"] = "0" 指定系统(服务器或电脑等)中哪些 GPU 是对 TensorFlow 可见的,这里说 0 号 GPU 是可见的。如果你的系统只有一个 GPU,那么这行是多余的(请注释掉);如果你的系统有多个 GPU,那么你可以通过 , 分隔的方式指定同时使用多个 GPU,如 os.environ["CUDA_VISIBLE_DEVICES"] = "0,1" 表示使用 0 号和 1 号 GPU。接下来,定义了几个变量,顾名思义,分别指定了批量大小、要分类的类别数目、迭代的总次数、预训练模型 ResNet-50 的存储路径和训练后模型要保存到的路径(这里最后的 model 不是指文件夹,而是说最后保存的模型要命名为 model-xxx.ckpt,xxx 表示模型保存时对应的训练次数。如果你想保存为其它格式的名字,请任取)。

        再接着往下看,定义了 3 个占位符,分别是:一个批量的图像数据入口、对应的类标号入口和是否是将模型用于训练还是推断的界定布尔值。再接着便到了模型定义阶段。第一个 with 语句,直接使用 slim 模块封装好的 ResNet-50 模型,注意要将最后的输出层禁用,即要将参数 num_classes 设置为 None。第二个 with 语句,是自定义的输出层,指定模型最后输出 num_classes 个值。这里,为了与前一个 with 语句相区别,以便以后导入预训练模型参数,一定要指定变量的命名空间 tf.variable_scope(),其中的名字可自取,只需要保证不与 ResNet-50 定义中的变量命名空间重名即可。这个 with 语句有三句话,第一句是去掉多余的维度:假设输入到 ResNet-50 的数据形状为 [None, 224, 224, 3],则上一个 with 语句的输出形状为 [None, 1, 1, 2048],这句话的作用就是将数据的形状变为 [None, 2048],去掉其中形状为 1 的第 1,2 个索引维度;第二句话 net = slim.dropout() 使用了 CNN 常用的正则化手段 dropout;第三句话使用一个全连接层指定输出大小。

        到此,模型定义就完成了。接下来,我们导入预训练参数。我们前面定义的神经网络包含两个阶段:1.使用 ResNet-50 来提取图像特征阶段;2.使用自定义的输出层来得到分类结果,它们分别对应前面的两个 with 语句。我们要导入的 ResNet-50 预训练参数只对应第一个阶段的模型参数,因此导入时就需要把第二阶段的参数排除掉。所以,接下来的一串代码

checkpoint_exclude_scopes = 'Logits'
exclusions = None
if checkpoint_exclude_scopes:
    exclusions = [
        scope.strip() for scope in checkpoint_exclude_scopes.split(',')]
variables_to_restore = []
for var in slim.get_model_variables():
    excluded = False
    for exclusion in exclusions:
        if var.op.name.startswith(exclusion):
            excluded = True
    if not excluded:
        variables_to_restore.append(var)

的作用就是将所有需要恢复的变量取出来,所有预训练模型中没有的变量排除掉。代码首先指定要排除的变量的命名空间为 Logits,对应上面的第二个 with 语句下的所有变量。然后通过 slim.get_model_variables 函数得到所有的模型变量,对这些变量进行遍历,如果变量不在 Logits 这个命名空间中就记录到要从预训练模型恢复的变量列表 variables_to_restore 中,否则就排除掉。

        当我们记录下所有要恢复的变量之后,就可以做其它任何事情了,如接下来的代码定义了损失函数 loss、准确率 accuracy、优化器 optimizer 等。然后,定义了两个 tf.train.Saver 类的对象:

saver_restore = tf.train.Saver(var_list=variables_to_restore)
saver = tf.train.Saver(tf.global_variables())

其中,第一个对象 saver_restore 用于恢复预训练模型中的参数,var_list 指定要从预训练模型中恢复的变量列表;第二个对象 saver 用来保存我们微调过后的训练模型,因此要保留所有的模型变量,从而 var_list=tf.global_variables()。接下来的 config 语句是说只使用每个 GPU 的 95% 的显存(config.gpu_options.per_process_gpu_memory_fraction = 0.95
),还允许 TensorFlow 将计算放到 CPU 上(allow_soft_placement=True),如果你不想这么设置直接将有关 config 的所有语句删除即可。

        再然后,通过 with tf.Session() as sess 生成一个会话,然后将所有变量作为结点 sess.run(init) 写入 TensorFlow graph。之后,又到了一个重要阶段,即导入预训练模型参数阶段。导入预训练模型参数只需要一句话就可以完成:

saver_restore.restore(sess, resnet_model_path)

当然,我们前面还是做了一些必要的准确: saver_restore = tf.train.Saver(var_list=variables_to_restore)。最后,就是对模型使用自己的数据进行微调以及模型保存等,略过。关于模型保存,请访问 TensorFlow 模型保存与恢复

说明】:使用这篇文章的 train.py 训练的模型,在进行模型预测时,仍然要将 is_training 设置为 True,否则测试准确率会差很多。这是 TensorFlow batch_norm 层的问题导致的。建议使用 《TensorFlow 使用预训练模型 ResNet-50》 这篇文章的代码,那里没有这个问题。

再次预告:下一篇文章将要介绍如何用 TensorFlow 来训练多任务多标签模型,敬请期待!

最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 202,980评论 5 476
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 85,178评论 2 380
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 149,868评论 0 336
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 54,498评论 1 273
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 63,492评论 5 364
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 48,521评论 1 281
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 37,910评论 3 395
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 36,569评论 0 256
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 40,793评论 1 296
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 35,559评论 2 319
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 37,639评论 1 329
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 33,342评论 4 318
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 38,931评论 3 307
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 29,904评论 0 19
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 31,144评论 1 259
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 42,833评论 2 349
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 42,350评论 2 342

推荐阅读更多精彩内容