TuriCreate + CoreML实践

机器学习?深度学习?作为一个已经毕业几年,已经把大多数相关知识都还给老师的人来说,要实际应用这些知识到工作中,还是相当有难度的。但是!我们毕竟是iOS开发啊,有Apple罩着我们,Apple在2017年的WWDC上就给了我们一个机器学习框架Core ML,且还在不断更新完善中,今年已经升级到Core ML 2了。

coreML

本文就将详细介绍如何训练集成从而识别物品,且实际落地到产品之中。也就是是主要演示图像分类的问题。

[toc]

第一部分:训练模型

选什么框架?

首先是选什么框架来来训练模型呢? 通用主流的有Keras、Caffe、TensorFlow等。既然是要集成到iOS App中,显然首选的是Apple自家的框架,Turi Create

XCode 10上,有新的Create ML框架,可在playground上可视化训练结果。由于本人主力机器还是XCode 9, <del>对swift不熟悉</del> , 因此不采用,且Turi Create更灵活。大家也可以试试这个新框架。

如果是考虑多个平台,或者对其他训练模型更有经验,可以选择其他的,Apple也提供了工具来转换成iOS所需的mlmodel, 也就是coremltools,可转的为:

  • Keras (1.2.2, 2.0.4+)
  • Xgboost (0.7+)
  • scikit-learn (0.17+)
  • libSVM

如果是TensorFlow则需要使用 tfcoreml 来转换。

Apple对于开发者来说,还是很周到,友好的。

选什么素材?

训练素材方面,Turi Create的要求并不严格,.JPEG和PNG都支持,不过为了识别的更准确些,建议素材的选择上要多一些跟实际应用场景接近的图片。本文的例子将是识别手边的实物,分别是路飞手办、索隆手办、WallE、严选运动水杯以及普通的杯子。所以在训练素材方面,walleE、杯子会全部使用Google图片,路飞、索隆会使用Google的图片以及部分iPhone实拍图片(为了试验贴近实际使用场景),严选运动水杯则全部采用线上的商品评价图来进行训练。

这里是作为例子来演示各种素材来源的效果,杯子和运动水杯之间也有交集并不是严格的不同分类,仅作效果展示,如果实际应用,建议还是更好的分类,使用更合理的训练集。

以下是部分WallE的训练图:

wallE训练图
wallE训练图
wallE训练图
wallE训练图
wallE训练图
wallE训练图

本文中使用的训练集数量为5个分类各20张,数量上是偏少的。训练素材上,需放在各自文件夹内,并将整个文件夹命名为mydataset,如图放置即可:

mydataset示意图

这里大家会发现还有一个other的文件夹,具体原因后文中会再介绍。

怎么训练?

环境搭建

因为TuriCreate是个python库,所以首先需要安装的是python, 这个在Mac OS上已经自带了,以下的安装和各类操作,都将以Mac上使用为例,Windows稍微修改下即可。然后需要安装turicreate

pip install -U turicreate

即可。TuriCreate依赖很多其他库,依赖库又依赖其他的库,所以直接安装可能会与当前的各种已安装包出现冲突,错误。这里使用viertualenv来安装会方便很多,创建一个虚拟环境,可以直接安装,详细操作可见TuriCreate文档里的Installation部分。
安装后如图:

viertualenv安装示意图

然后在训练集目录下新建一个turi.py 的文件,用IDE打开,然后就可以开动了!

数据导入

turi.py

#0 导入需要的库
import turicreate as tc
import os

#1. 导入之前放的图片文件夹mydataset
data = tc.image_analysis.load_images('mydataset', with_path=True)

#2. 用python的lambda来提取标签(文件夹名字)
data['name'] = data['path'].apply(lambda path: os.path.basename(os.path.dirname(path)))

#3. 保存下sframe
data.save('mydataset.sframe')

#4. turicreate的可视化
data.explore()
data.show()

这里展示的为图像分类问题,因此#1使用的是image_analysisSFrame是一种可伸缩表格数据,详细定义可可参考github - SFrame

运行后可能会报warning ,这报错也就是之前提到的,TuriCreate支持的是JPEG和PNG格式,其他不支持。

Unsupported image format. Supported formats are JPEG and PNG     file: /Users/Chen/Downloads/testml/mydataset/.DS_Store

然后由#4步骤中的操作,会将数据进行可视化.

有时候explore会卡在loading过程中,可参见此issue, show()可正常使用。

模型训练

在同一个目录中,可新建一个turi_train.py,并上一个过程中保存的mydataset.sframe来进行训练。具体代码如下:
turi_train.py

import turicreate as tc

# 1. 导入之前生成的sframe文件
data = tc.SFrame('mydataset.sframe')

# 2. 分0.8的数据作为训练样本,0.2的作为测试样本
train_data, test_data = data.random_split(0.8)

# 3. 生成模型
model = tc.image_classifier.create(train_data,  target='name')

# 4. 测试样本数据
predictions = model.predict(test_data)

# 5. 计算样本数据的精确性
metrics = model.evaluate(test_data)
print(metrics['accuracy'])

# 6. 保存model
model.save('mydataset.model')

# 7. 导出成mlmodel,以便coreML使用
model.export_coreml('mydataset.mlmodel')

其中#2的操作是为了验证我们训练出来的模型的可信度,以80%的数据来训练,剩下的作为验证样本来看看训练的数据质量怎么样。

会输出类似这样的log:

训练结果的log

需要注意的是,每次训练的结果并不相同,尤其是验证的数据,可能为100%,也可能为80%,跟随机到的结果有关,仅供参考。

最后一行的0.95为样本测试正确率。可以通过打印predictionstest_data['name']来进行对比。

['cup', 'cup', 'cup', 'cup', 'cup', 'luffy', 'luffy', 'other', 'other', 'other', 'other', 'sportcup', 'sportcup', 'sportcup', 'sportcup', 'sportcup', 'sportcup', 'walle', 'walle', 'walle', 'walle', 'walle', 'walle', 'walle', 'walle', 'zoro', 'zoro', 'zoro', 'zoro', 'zoro', 'zoro', 'zoro', 'zoro']

['cup', 'cup', 'cup', 'cup', 'cup', 'luffy', 'luffy', 'other', 'other', 'other', 'other', 'sportcup', 'sportcup', 'sportcup', 'sportcup', 'sportcup', 'sportcup', 'walle', 'walle', 'walle', 'walle', 'walle', 'walle', 'luffy', 'luffy', 'zoro', 'luffy', 'luffy', 'luffy', 'zoro', 'luffy', 'zoro', 'zoro']

这个跟上一个0.95的结果非同一次训练

两者一对比,即可找到错误的判断为哪些,为了更加直观的看到具体是哪些图片出现了误识别,可打印:
print test_data[test_data['name'] != predictions]['path']
然后再使用turicreate.img来预览图片就可以自动打开错误的图片,来看是哪几种图片有问题,以及猜测可能原因。代码如下:

    print test_data[test_data['name'] != predictions]['path']
    wrongResults = test_data[test_data['name'] != predictions]['path']
    for index in range(len(wrongResults)):
        img = tc.Image(test_data[test_data['name'] != predictions]['path'][index])
        img.show()

查询后发现本应该属于WallE的图:


WallE测试图

被误识别成了luffy-路飞。可能是跟这张训练样本混淆了(以及另外3张路飞的图,也都有大面积蓝色背景):

路飞训练图

可以看出这张图的元素是比较复杂的,样本可能还需要更多些。

第二部分:模型集成(iOS)

模型导入

真正到XCode是集成使用的,只有turi_train.py中导出的mydataset.mlmodel(94.2MB)。拖动文件到项目中即可:

XCode导入

这一步会自动生成OC的Mydataset.hMydataset.m文件。

如果需要生成swift格式的文件,则可在 项目 - BuildSetting - 点击All - 搜索 coreml,会看到选项 CoreML Model Class Generation Language,修改即可。

模型使用

在你想要使用的源文件中,先引入库:

#import <CoreML/CoreML.h>
#import <Vision/Vision.h>
#import "Mydataset.h"

关键类为VNCoreMLModelVNCoreMLRequestVNImageRequestHandler。核心在于用VNCoreMLModel来构建VNCoreMLRequest,方法为:
- (instancetype)initWithModel:(VNCoreMLModel *)model completionHandler:(VNRequestCompletionHandler)completionHandler
然后在completeHandler里获取request分类的结果数组,根据结果里匹配率最高的来当作识别结果。具体代码如下:

Mydataset *resnetModel = [[Mydataset alloc] init];
VNCoreMLModel *vnCoreModel = [VNCoreMLModel modelForMLModel:resnetModel.model error:nil];

VNCoreMLRequest *vnCoreMlRequest = [[VNCoreMLRequest alloc] initWithModel:vnCoreModel completionHandler:^(VNRequest * _Nonnull request, NSError * _Nullable error) {
    CGFloat confidence = 0.0f;
    VNClassificationObservation *tempClassification = nil;
    for (VNClassificationObservation *classification in request.results) {
        if (classification.confidence > confidence) {
        //取最大概率的模型
            confidence = classification.confidence;
            tempClassification = classification;
        }
    }
    NSLog(@"识别结果:%@,匹配率:%@",tempClassification.identifier,@(tempClassification.confidence));
  }];

//这里的image为输入图片
VNImageRequestHandler *vnImageRequestHandler = [[VNImageRequestHandler alloc] initWithCGImage:image.CGImage options:nil];

NSError *error = nil;
[vnImageRequestHandler performRequests:@[vnCoreMlRequest] error:&error];

if (error) {
    NSLog(@"%@",error.localizedDescription);
}

这部分代码对于iOS开发来说,还是比较简单易懂的。基本的识别到这一步也就可以了。很神奇,只要这么一点点的代码和一个导入的模型,我们就完成了机器学习的工作,Apple对于开发者实在是友好。

使用效果

实时性:

到上一步为止,基本的步骤已经完成了 ,但既然开头说了,要达到能实际使用的程度,那就需要再进行一些完善了。先得试试实时性怎么样,因此采用摄像头实时获取图片,实时辨别来试试效果。这部分和机器学习关系不大,因此略过详细过程,核心方法为- (void)captureOutput:(AVCaptureOutput *)output didOutputSampleBuffer:(CMSampleBufferRef)sampleBuffer fromConnection:(AVCaptureConnection *)connection

使用的queue要注意下,更新的时候也要注意切回主线程

在这个回调里,先把图片摆正,这样能提高准确率。[connection setVideoOrientation:AVCaptureVideoOrientationPortrait];然后根据sampleBuffer获取图片即可。

实际效果下图:

效果演示

图不动的话,戳这里

这里在末尾也演示了前面提到的other这个分类,其实这个分类存在的目的,就是为了增加分类器的健壮性,可参加这个issue。如果只训练两个分类A和B,那么分类器A和B的概率相加为1,假设新物体非常不像A,那么有可能显示的B的概率为1,造成误判,other这个分类的意义就在于摊平这里的概率,当然对于other里训练的图片选择,感觉是个大学问,目前我只是随意的放了些非目标分类的图片。

准确性

做了那么多的步骤,直接看图就知道效果了:

效果演示1
效果演示2

图不动的话,戳这里

总体来说,目标物对的稍微准点,95+%的识别率还是有的,超过了我的预期,可应用到实际中。

包的大小

前文提到过,导出的包为94.2MB,这对于一个iOS App来说,实在是有点太大了。贴心的Apple当然也给了解决方法,那就是替换卷积神经网络CNN,CNN的主要目的是 提取图片的特征值。替换的地方在turi_train.py的第三步:

# 3. 生成模型
model = tc.image_classifier.create(train_data,  target='name')

这里还有一个参数,model,改成

model = tc.image_classifier.create(train_data, target='name',model='squeezenet_v1.1')

也就是把model的CNN指定为squeezenet_v1.1(默认的为resnet-50)。当然这里还可以设置其他的参数,比如最大迭代次数等。这样导出的mlmodel一下子就变成了5MB左右,小了非常的多!当然,这也牺牲了一定的精度。具体对比,Apple已经列了对比:

CNN模型对比

而Apple官网提供的“从1000种类别的对象中检测出图像中的主体”训练集当中,从大到小依次为

  • VGG16: 553.5MB
  • ResNet50: 102.6MB
  • Inception v3: 94.7MB
  • MobileNet: 17.1MB
  • SqueezeNet: 5MB

至于如何在精度和包大小取舍就看自己的选择了。

在线下载、更新包

一个包离线打在项目里,既更新不了,又导致每个用户的包都变大了,这显然不是一个好的实践。Apple提供了一个新的API,+ (NSURL *)compileModelAtURL:(NSURL *)modelURL error:(NSError * _Nullable *)error;,使用方法也很简单,下载数据,放到沙盒里,然后compile即可。需要注意的是,这个方法较为耗时,不要放在主线程。

这样包大小的问题也算一定程度上解决了。

存在的问题

读文章最怕介绍的都是各种优点的文章,显然,作为这么个工具,还是需要提出我在这整个过程中遇到的问题:

  • 训练模型需较多,识别出的内容仅在训练分类中,如果不是,会出现误识别,比如训练中有“杯子”这个种类,如果有个电器长得跟“杯子”很像,那这个电器就会被识别为“杯子”,属于误识别。 这个通过增加训练种类能一定程度上解决。
  • 由于是摄像头实时取,实时识别,识别结果会存在一定程度的抖动。 这个通过设置阈值等可以解决
  • 实时获取判断,机器发热较为严重,没有做过具体的性能检测。 这个可以定时获取或者继续优化代码来调优。
  • 试验中涉及到的种类较少,实际应用到需求里所需的种类后,训练情况和效果未知。

第三部分:What's More?

目标跟踪

目标跟踪

Turi Create这个工具能做到的远不止图像分类,还有目标追踪,推荐系统,相似图片,文字识别等等。其中目标跟踪跟本实践较为接近,这个可以继续叠加训练数据的维度来实现。需要增加的工作为,需要标记每一张训练图的目标物方框坐标,数据格式为:

[{'coordinates': {'height': 104, 'width': 110, 'x': 115, 'y': 216}, 'label': 'ball'}, {'coordinates': {'height': 106, 'width': 110, 'x': 188, 'y': 254}, 'label': 'ball'}, {'coordinates': {'height': 164, 'width': 131, 'x': 374, 'y': 169}, 'label': 'cup'}]
其他步骤跟前文提到的基本一致。具体可以大家自己尝试。介绍在官网github上。

Android使用

本文中通过TuriCreate生成的数据为mlmodel,仅供iOS使用,可通过开源工具MMdnn来转换为Caffe, Keras, MXNet, Tensorflow, CNTK, PyTorch Onnx这些模型,从而供其他方来使用。

@张云龙:

也可以只用训练素材图片,然后用 tensorflow-for-poets-2 来训练,得到 retrained_graph.pb 和 retrained_labels.txt 集成到Android中。

执行脚本

export ARCHITECTURE="Mobilenet_0.75_224"
export CUSTOM_TRAIN_PICS="trainpics"
python -m scripts.retrain \
      --bottleneck_dir=tf_files/bottlenecks \
      --how_many_training_steps=500 \
      --model_dir=tf_files/models/ \
      --summaries_dir=tf_files/training_summaries/"${ARCHITECTURE}" \
      --output_graph=tf_files/retrained_graph.pb \
      --output_labels=tf_files/retrained_labels.txt \
      --architecture="${ARCHITECTURE}" \
      --image_dir=tf_files/"${CUSTOM_TRAIN_PICS}"

算法学习

本文是一篇应用型的文章,基本没有介绍真正的机器学习的知识。这部分还是很有必要深入了解下的,这两个感觉介绍的不错,可推荐:

谷歌 - 机器学习速成课程

知乎专栏 - 卷积神经网络(CNN)入门讲解

第四部分:参考文章

最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 201,924评论 5 474
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 84,781评论 2 378
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 148,813评论 0 335
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 54,264评论 1 272
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 63,273评论 5 363
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 48,383评论 1 281
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 37,800评论 3 393
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 36,482评论 0 256
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 40,673评论 1 295
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 35,497评论 2 318
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 37,545评论 1 329
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 33,240评论 4 318
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 38,802评论 3 304
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 29,866评论 0 19
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 31,101评论 1 258
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 42,673评论 2 348
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 42,245评论 2 341

推荐阅读更多精彩内容

  • 1、通过CocoaPods安装项目名称项目信息 AFNetworking网络请求组件 FMDB本地数据库组件 SD...
    X先生_未知数的X阅读 15,967评论 3 119
  • 发现 关注 消息 iOS 第三方库、插件、知名博客总结 作者大灰狼的小绵羊哥哥关注 2017.06.26 09:4...
    肇东周阅读 11,994评论 4 61
  • 6月2号,周五晚上,加班后回到小区已经8点半了,感觉很累。 车子还没有停好,接到儿子的电话,问我他的玩具拿到了没有...
    monicaqiqi阅读 369评论 0 0
  • 你接触什么样的世界,决定着你成为什么样的人。 我用了七年的时间尝试跨过这道坎,始终还是不能逾越,唯有考研是我一雪前...
    哈米先生阅读 110评论 0 0
  • 如果*爱 (2009年情人节) 如果你是小溪 那么我就是溪里的鱼儿 热情地为你的快乐而快乐 如果你是大山 那么我...
    沉睡的古莲阅读 420评论 0 0