在安装完TF之后关注物体检测API有一段时间了,但是因为官方的例子一直没有跑通,所以就搁置了。但是苦于目前找不到使用TF训练个人数据集的成熟实例,所以在转了一圈之后发现这个API其实是绕不开的。
写这篇文章的目的不是逐步记录从数据标定到最后导出模型并测试检测结果的每一个步骤。而是简单记录一下使用这个API的直观感受,来说说这个API究竟适合什么样的项目,工作量有多大,有没有必要入坑。如果想上手这个API,谷歌上有几篇不错的教程,之后会推荐。
Object Detection API 的个人印象
当你从GITHUB下载API的文件时会发现,它是TF research文件夹下的一个项目,和正式发布的TF有一点点区别。官方其实不会特别积极的维护这个API,解决你的配置问题,因为这只是他们的研究项目。
总之,给人一种你爱用不用的感觉。
几个月前Object Detection 是直接放在model下面的,现在路径改变后,示例没法下载所需的模型(提供链接是不能直接访问的,所以我也找不到他们把文件放哪了)。之前把这个issue报上去也没能解决,只能摊手。所以说,示例要配合网络链接跑,这其实是挺坑的一件事。链接一变动,以后会用的人就更少了。
但是,明明那么不好用,为何还要用?
因为里面的模型比较牛。
使用API可以直接训练model zoo里的5个模型。如果你不打算发明新的深度神经网络构架,使用之前比较经典的模型就可以解决大部分的物体识别问题。那么使用这些模型文件配合相应的配置文件,就可以准备开工了。
安装API
如果你已经安装了GPU模式的TF,安装物体检测API应该不是难事。但比较讨厌的一点是,安装并没有放到根目录上,这意味着你每次开机甚至是跑完一个脚本,又要重新添加路径,需要退到object detection文件夹的路径并从新把它添加到PYTHONPATH里。
什么样的项目适合使用Object Detection API
首先,你要保证你的数据集包含上千甚至上万的图片。如果只是几百张的量,可能深度学习并不适合你。
此外,你的数据集不是直接把图片丢进去就可以开始训练,需要使用专门软件做标定。之前也见过用图片名称做标注的例子,但是那样的数据格式并不适合model zoo 里面的模型。他们只能喂进去 PASCAL VOC标准格式的文件。
如果你在训练时只给图片加物体列表,而没有在图片中标定区域,不可能达到上图的训练效果。机器学习很流行的一句话是“garbage in,garbage out”,放在我们这里只谈训练文件格式的选择,通常训练文件放进去是什么格式,训练完成后输出也是什么格式。
所以说,用之前,先准备好至少上千张图片,外加PASCAL VOC格式的标定。至于标签包含多少类别,你可以自由决定,但我个人感觉多多益善,不然深度网络就大材小用了。推荐常用图片标注软件labelImage。标注完成后每一个图片都会对应一个同名xml,xml内包含物体标签名称和所在区域方框的对角线两个顶点坐标。
此外,图片大小也需要注意。谷歌搜图时选择中等大小(300-500像素)的图片比较保险。不然训练时遇到大图,显卡会崩溃。
VOC PASCAL 转TFRecord
TFRecord是TF专用的数据集格式,但是目前官方没有做标注软件,所以还要很麻烦地把VOC格式从XML通过CSV转成TFRecord。谷歌上的教程已经介绍了一些方法,这里就不多说了。教程链接
配置网络并开始训练
到这一步,你不光要准备好模型文件,还要准备配置文件,并在其中设置训练数据的路径,learning rate, batch size等参数。如果还没到做优化,仅仅是想拿程序跑跑看,需要关注一下batch size。虽然这个参数不会特别影响最终的训练效果,但是不同的硬件有必要设置一下每批次训练图片的数量,不然很容易出现显存不足的状况,在训练过程中崩溃 (ResourceExhaustedError ...OOM一类的错误)。如果是个人电脑,训练十来个小时比较常见。开始的前一个小时一定要看看有没有问题,如果没问题才能安心做别的事。
我现在的显卡有3G,batch size 设 25 和 15 时都崩溃过,设成5之后才平安训练到结束(虽说这个设置有点保守)。
训练结束后导出模型并测试
训练结束后,你可以在tensorboard和控制台看到你的网络已经收敛(loss数值缓慢下降),证明网络学到了东西。训练好的模型暂时保存在你指定的文件夹里,但是这些文件不能直接拿来用,需要API里面的export_inference_graph.py导出特定的文件,默认的名字是frozen_inference_graph.pb
现在,你在本地拥有了可以用于测试的模型,可以修改API内的object_detection_tutorial,调用导出的模型做测试。把你想要测试的图片放入test_images看测试的结果。
TF版本问题
如果你还没有安装但是想试试看API,那么请直接安装TF1.4或更新的版本,官方说API只支持1.4之后的版本。
之前的版本也不是不能用,要看运气和你选择的模型。我使用ssd_mobilenet_v1_pets在导出训练好的模型时遇到了小问题,但是通过注释掉出问题的三行解决了问题。 (见issue) 不过用模型embedded_ssd_mobilenet_v1这样做就不行了。
推荐教程
本着不推荐做基准数据(例如MNIST,CIFAR..)的原则,这里只推荐那些用原创数据集做分类的教程。
识别小浣熊
比较早的一篇物体检测API教程,VOC格式转TFRecord是从这篇文章开始的。
识别芝士通心粉
从安装API讲到训练与测试,有视频,是手把手教的风格,对新手帮助很大。但是他做训练数据集只提供一个类别,不是很推荐。