通过迁移训练来定制 TensorFlow 模型

在我上一篇 构建一个基于 TensorFlow 的 Android 应用 的文章最后提到:我们可以通过对现有模型进行迁移训练(retrain)来定制我们自己的模型。

下面就通过对现有的 Google Inception-V3 模型进行 retrain ,对 5 种花朵样本数据的进行训练,来完成一个可以识别五种花朵的模型,并将新训练的模型迁移到 Android 端平台。

相关代码可查看:GitHub 项目地址

安装 TensorFlow (Mac 为例)

其他平台可以直接参考官网说明:Installing TensorFlow

首先检查系统是否安装了 Python

要安装 TensorFlow ,你的系统必须依据安装了以下任一 Python 版本:

  • Python 2.7
  • Python 3.3+

查看 Python 版本的命令:

# Python 2
$ python --version
# Python 3
$ python3 --version

如果你的系统还没有安装符合以上版本的 Python,现在安装。

检查 Pip 是否安装

Pip 是 Python 的安装和包管理工具,要使用本地 pip 安装 TensorFlow,系统上必须安装下面的任一版本的 pip :

  • pip for Python 2.7
  • pip3 for Python 3.n.

pip 或者 pip3 可能在你安装 Python 的时候已经安装了,执行以下任一命令确认系统上是否安装了 pip 或 pip3

$ pip -V  # for Python 2.7
$ pip3 -V # for Python 3.n 

建议使用 pip 或者 pip3 为 8.1 或者更新的版本安装 TensorFlow,如果没有安装,执行以下任一命令安装或更新:

$ sudo easy_install --upgrade pip
$ sudo easy_install --upgrade six

通过 pip 安装 TensorFlow

# Python 2
$ pip install tensorflow
# Python 3
$ pip3 install tensorflow 

通过官方样例测试 TensorFlow 是否正常安装

进入 Python 环境后输入以下代码,当出现 “Hello, TensorFlow!” 时表明已经安装成功,可正常使用 TensorFlow 了。

$ python
...
>>> import tensorflow as tf
>>> hello = tf.constant('Hello, TensorFlow!')
>>> sess = tf.Session()
>>> print(sess.run(hello))
Hello, TensorFlow!
image.png

准备训练样本

前面说到我们要训练花朵的识别,这里我们直接找 Google 提供的一个训练样本。我们为为模型迁移训练专门新建一个文件夹用于存放。

下载并解压得到训练样本

$ cd TensorFlowRetrainInceptionV3
$ curl -O http://download.tensorflow.org/example_images/flower_photos.tgz
$ tar xzf flower_photos.tgz

打开训练样本文件夹 flower_photos ,里面有 5 种类别的花:daisy(雏菊), dandelion(蒲公英), roses(玫瑰), sunflowers(向日葵) , tulips(郁金香),每个类别的大概有 600-700 张训练样本图片。

可以根据自身情况,减少训练样本数量,减少训练时间。

开始训练

下载 retrain 脚本
该脚本会自动下载 google Inception v3 模型相关文件,retrain.py 是 Google 提供的迁移训练脚本。

$ cd TensorFlowRetrainInceptionV3
$ curl -O https://raw.githubusercontent.com/tensorflow/tensorflow/r1.1/tensorflow/examples/image_retraining/retrain.py

启动 TensorBoard
TensorBoard 是为 TensorFlow 训练效果提供可视化的工具,具体效果如下图所示:

$ cd TensorFlowRetrainInceptionV3
$ tensorboard --logdir training_summaries &

启动 TensorBoard 会占用系统 6006 端口 ,再启动一个新的 TensorBoard 之前,必须要 kill 已在运行的 TensorBoard 任务。

 $ pkill -f "tensorboard

启动训练脚本

在运行 retrain.py 脚本时,需要配置一些运行命令参数,比如指定模型输入输出相关名称和其他训练要求的配置。

$ cd TensorFlowRetrainInceptionV3
$ python3 retrain.py \
  --bottleneck_dir=bottlenecks \
  --how_many_training_steps=500 \
  --model_dir=inception \
  --summaries_dir=training_summaries/basic \
  --output_graph=retrained_graph.pb \
  --output_labels=retrained_labels.txt \
  --image_dir=flower_photos

如果不添加--how_many_training_steps=500配置,默认值为4000,会相当耗时,建议测试阶段可以减少这个值。

启动浏览器查看 TensorBoard

等待当前目录下的 bottlenecks 文件夹中的文件生成结束后,可以启动浏览器,在地址栏中输入 localhost:6006 来查看训练进度。

等到训练完成后,我们将得到新生成的 retrained_labels.txtretrained_graph.pb 这两个模型相关文件。

测试重新训练后的模型

同样的,我们先下载测试模型的脚本 label_image.py,测试重新训练后的模型的识别准确率。

$ cd TensorFlowRetrainInceptionV3
$ curl -L https://goo.gl/3lTKZs > label_image.py
$ python3 label_image.py flower_photos/daisy/488202750_c420cbce61.jpg

经过简单的实际测试,对已有样本数据的识别准确率基本在 90% 以上,可以知道重新训练后模型满足使用要求,下面就按照前面的 Android 应用集成 TensorFlow 教程,将新的模型导入到上面的项目中。

将新训练的 TensorFlow 模型移植到 Android 中

下图是完成迁移训练后的新模型文件,新打包出来的 GraphDef 文件(PB文件)达到了87.5 MB 。考虑到我们要将这个模型移植到 Android 端去加载,这不仅会对应用的运行内存造成巨大压力,而且会导致安装包增大很多,对于一个简单的花朵识别应用来说,现在模型文件有些大了。因此,我们要考虑对模型文件进行优化,压缩它的体积。

优化模型文件

如前面所说,重新训练后的模型移植到 Android 平台前需要对模型文件进行优化才行,下面我们就来看看官方推荐的几种方法。

Optimize for inference

通过调用 optimize_for_inference 脚本,会自动删除模型中输入层和输出层之间所有不需要的节点。
同时该脚本还做了一些其他优化以提高运行速度。例如它把显式批处理标准化运算跟卷积权重进行了合并,从而降低了计算量。

1. 用 bazel 工具构建 optimize_for_inference 脚本文件

# 在 tensorflow 项目的根目录(WORKSPACE 文件所在)执行下面的 build 命令
bazel build tensorflow/python/tools:optimize_for_inference

build 完成后脚本文件路径:tensorflow/python/tools/optimize_for_inference.py

如果还没安装 bazel ,建议先看看前一篇文章

2. 调用 optimize_for_inference.py 脚本进行优化

调用脚本时,我们需要提供几个命令参数,比如输入的 PB 文件路径,输出的 PB 文件路径,输入节点名以及输出节点名等。

python3 -m tensorflow.python.tools.optimize_for_inference \
  --input=retrained_graph.pb \
  --output=optimized_graph.pb \
  --input_names="Cast" \
  --output_names="final_result"

查看脚本执行完成后输出的 optimized_graph.pb 文件

可以看到,经过 optimize_for_inference 优化过后的模型依然是非常的大的。经过这一次的优化,文件只是变小了一些,但还不足以我们放到手机端去运行,所以我们要进一步的压缩模型,同时还要保证准确率。

Quantize the network weights

Android 项目中,通常我们把模型 PB 文件放在 assets 文件夹中加载,其实不管是直接打包进 APP 还是进入 APP 后再进行下载,模型文件占用太大的问题还是没得到解决。我们知道 Android 的 APK 文件在构建过程中会进行 zip 压缩。那有没有一种行之有效的方法在不过多的降低精确度的情况下压缩更大的空间呢?

Google 就提供了这么一个脚本,经过这个脚本优化后的模型 PB 文件大小不会改变,但会有更多的可利用的重复性,所以在打包构建APK 包时对 PB 文件进行 zip 压缩后,最终按照中的 PB 文件会缩小大约 3~4 倍的大小。

1. 用 bazel 工具构建 quantize_graph 脚本

# 在 tensorflow 项目的根目录(WORKSPACE 文件所在)执行下面的 build 命令
$ bazel build tensorflow/tools/quantization:quantize_graph.py

build 完成后脚本文件路径:tensorflow/tools/quantization/quantize_graph.py

2. 调用 quantize_graph 脚本进行优化
将生成的 quantize_graph.py 文件拷贝到 retrain 文件夹下,在目录下执行脚本。

输入的参数依然是:输入的 PB 文件路径,输出的 PB 文件路径,输出节点名,这里还有个特别的参数 mode ,这个参数是告诉脚本我们选择哪种压缩方式,这里我们选择了对权重进行四舍五入。

python3 -m quantize_graph \
  --input=optimized_graph.pb \
  --output=rounded_graph.pb \
  --output_node_names=final_result \
  --mode=weights_rounded

可以看到最终的输出文件 rounded_graph.pb 大小并没有改变,下面我们就将优化后的迁移训练模型文件重新导入到我们原来的 Android 项目中。

把新训练的模型导入到 Android 中

同样的,我们把新训练的模型 pb 文件和 labels 文件复制到 assets 文件夹下。

因为新训练的模型,输入和输出层名称也发生的改变,这里要修改之前 TensorFlowImageClassifier.create 方法传入的参数。

   /**
     * retrained inception-v3 model, flower classifier
     */
    private static final int INPUT_SIZE = 299;
    private static final int IMAGE_MEAN = 128;
    private static final float IMAGE_STD = 1f;
    private static final String INPUT_NAME = "Mul";
    private static final String OUTPUT_NAME = "final_result";
    private static final String MODEL_FILE = "file:///android_asset/model/rounded_graph.pb";
    private static final String LABEL_FILE = "file:///android_asset/model/retrained_labels.txt";

最终打包出来的 APK 文件,可以看到压缩后的 pb 文件只有 22 MB

参考

教程:在 Mac OS X 上安装 TensorFlow
当Android开发者遇见TensorFlow
Retrain a tensorflow model based on Inception v3
TensorFlow Mobile模型压缩

©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 199,711评论 5 468
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 83,932评论 2 376
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 146,770评论 0 330
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 53,799评论 1 271
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 62,697评论 5 359
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 48,069评论 1 276
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 37,535评论 3 390
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 36,200评论 0 254
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 40,353评论 1 294
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 35,290评论 2 317
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 37,331评论 1 329
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 33,020评论 3 315
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 38,610评论 3 303
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 29,694评论 0 19
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 30,927评论 1 255
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 42,330评论 2 346
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 41,904评论 2 341

推荐阅读更多精彩内容

  • 1. 介绍 首先让我们来看看TensorFlow! 但是在我们开始之前,我们先来看看Python API中的Ten...
    JasonJe阅读 11,715评论 1 32
  • Android 自定义View的各种姿势1 Activity的显示之ViewRootImpl详解 Activity...
    passiontim阅读 171,270评论 25 707
  • 2016.12.12开始尝试古筝,长这么大第一次系统地学习一门乐器。2017年,你好!我会好好成长的!
    四只阅读 205评论 1 0
  • 第一次在那家小店就餐后,我们打了85分。 那天的服务生是一个腼腆的男孩。“您点两菜一汤就可以了,再点吃不完,浪费了...
    老丁子阅读 251评论 2 1
  • 离北京最近的沙漠是库布齐,也是中国第七大沙漠,“库布齐”是蒙古语,意思是弓上的弦,这个弓是黄河,库布齐像弦一样就挂...
    兰苑1972阅读 372评论 1 1