Tensorflow MobileNet移植到Android

最近看到一个巨牛的人工智能教程,分享一下给大家。教程不仅是零基础,通俗易懂,而且非常风趣幽默,像看小说一样!觉得太牛了,所以分享给大家。平时碎片时间可以当小说看,【点这里可以去膜拜一下大神的“小说”】

1 CKPT模型转换pb文件

使用上一篇博客《MobileNet V1官方预训练模型的使用》中下载的MobileNet V1官方预训练的模型《MobileNet_v1_1.0_192》。虽然打包下载的文件中包含已经转换过的pb文件,但是官方提供的pb模型输出是1001类别对应的概率,我们需要的是概率最大的3类。可在原始网络中使用函数tf.nn.top_k获取概率最大的3类,将函数tf.nn.top_k作为网络中的一个计算节点。模型转换代码如下所示。

import tensorflow as tf
from mobilenet_v1 import mobilenet_v1,mobilenet_v1_arg_scope 
import numpy as np
slim = tf.contrib.slim
CKPT = 'mobilenet_v1_1.0_192.ckpt'  

def build_model(inputs):   
    with slim.arg_scope(mobilenet_v1_arg_scope(is_training=False)):
        logits, end_points = mobilenet_v1(inputs, is_training=False, depth_multiplier=1.0, num_classes=1001)
    scores = end_points['Predictions']
    print(scores)
    #取概率最大的5个类别及其对应概率
    output = tf.nn.top_k(scores, k=3, sorted=True)
    #indices为类别索引,values为概率值
    return output.indices,output.values

def load_model(sess):
    loader = tf.train.Saver()
    loader.restore(sess,CKPT)
   
inputs=tf.placeholder(dtype=tf.float32,shape=(1,192,192,3),name='input')
classes_tf,scores_tf = build_model(inputs)  
classes = tf.identity(classes_tf, name='classes')
scores = tf.identity(scores_tf, name='scores') 
with tf.Session() as sess:
    load_model(sess)
    graph = tf.get_default_graph()
    output_graph_def = tf.graph_util.convert_variables_to_constants(
        sess, graph.as_graph_def(), [classes.op.name,scores.op.name]) 
    tf.train.write_graph(output_graph_def, 'model', 'mobilenet_v1_1.0_192.pb', as_text=False)

上面代码中,单一的所有类别概率经过计算节点tf.nn.top_k后分为两个输出:概率最大的3个类别classes,概率最大的3个类别的概率scores。执行上面代码后,在目录“model”中得到文件mobilenet_v1_1.0_192.pb

2 移植到Android中

2.1 AndroidStudio中使用Tensorflow Mobile

首先,AndroidStudio版本必须是3.0及以上。创建Android Project后,在Module:appbuild.gradle文件中的dependencies中加入如下:

 compile 'org.tensorflow:tensorflow-android:+'

2.2 Tensorflow Mobile接口

使用Tensorflow Mobile库中模型调用封装类org.tensorflow.contrib.android.TensorFlowInferenceInterface完成模型的调用,主要使用的如下函数。

 public TensorFlowInferenceInterface(AssetManager assetManager, String model){...}
 public void feed(String inputName, float[] src, long... dims) {...}
 public void run(String[] outputNames) {...}
 public void fetch(String outputName, int[] dst) {...}

其中,构造函数中的参数model表示目录“assets”中模型名称。feed函数中参数inputName表示输入节点的名称,即对应模型转换时指定输入节点的名称“input”,参数src表示输入数据数组,变长参数dims表示输入的维度,如传入1,192,192,3则表示输入数据的Shape=[1,192,192,3]。函数run的参数outputNames表示执行从输入节点到outputNames中节点的所有路径。函数fetch中参数outputName表示输出节点的名称,将指定的输出节点的数据拷贝到dst中。

2.3 Bitmap对象转float[]

注意到,在2.1小节中函数feed传入到输入节点的数据对象是float[]。因此有必要将Bitmap转为float[]对象,示例代码如下所示。

    //读取Bitmap像素值,并放入到浮点数数组中。归一化到[-1,1]
    private float[] getFloatImage(Bitmap bitmap){
        Bitmap bm = getResizedBitmap(bitmap,inputWH,inputWH);
        bm.getPixels(inputIntData, 0, bm.getWidth(), 0, 0, bm.getWidth(), bm.getHeight());
        for (int i = 0; i < inputIntData.length; ++i) {
            final int val = inputIntData[i];
            inputFloatData[i * 3 + 0] =(float) (((val >> 16) & 0xFF)/255.0-0.5)*2;
            inputFloatData[i * 3 + 1] = (float)(((val >> 8) & 0xFF)/255.0-0.5)*2;
            inputFloatData[i * 3 + 2] = (float)(( val & 0xFF)/255.0-0.5)*2 ;
        }
        return inputFloatData;
    }

由于MobileNet V1预训练的模型输入数据归一化到[-1,1],因此在函数getFloatImage中转换数据的同时将数据归一化到[-1,1]

2.4 封装模型调用

为了便于调用,将与模型相关的调用函数封装到类TFModelUtils中,通过TFModelUtilsrun函数完成模型的调用,示例代码如下所示。

package com.huachao.mn_v1_192; 
import android.content.res.AssetManager;
import android.graphics.Bitmap;
import android.graphics.Matrix;
import android.util.Log; 
import org.tensorflow.contrib.android.TensorFlowInferenceInterface; 
import java.io.BufferedReader;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.util.HashMap;
import java.util.Map; 
public class TFModelUtils {
    private TensorFlowInferenceInterface inferenceInterface;
    private int[] inputIntData ;
    private float[] inputFloatData ;
    private int inputWH;
    private String inputName;
    private String[] outputNames; 
    private Map<Integer,String> dict;
    public TFModelUtils(AssetManager assetMngr,int inputWH,String inputName,String[]outputNames,String modelName){
        this.inputWH=inputWH;
        this.inputName=inputName;
        this.outputNames=outputNames;
        this.inputIntData=new int[inputWH*inputWH];
        this.inputFloatData = new float[inputWH*inputWH*3];
        //从assets目录加载模型
        inferenceInterface= new TensorFlowInferenceInterface(assetMngr, modelName);
        this.loadLabel(assetMngr);
    }
    public Map<String,Object> run(Bitmap bitmap){ 
        float[] inputData = getFloatImage(bitmap);
        //将输入数据复制到TensorFlow中,指定输入Shape=[1,INPUT_WH,INPUT_WH,3]
        inferenceInterface.feed(inputName, inputData, 1, inputWH, inputWH, 3); 
        // 执行模型
        inferenceInterface.run( outputNames ); 
        //将输出Tensor对象复制到指定数组中
        int[] classes=new int[3];
        float[] scores=new float[3];
        inferenceInterface.fetch(outputNames[0], classes);
        inferenceInterface.fetch(outputNames[1], scores);
        Map<String,Object> results=new HashMap<>();
        results.put("scores",scores);
        String[] classesLabel = new String[3];
        for(int i =0;i<3;i++){
            int idx=classes[i];
            classesLabel[i]=dict.get(idx);
//            System.out.printf("classes:"+dict.get(idx)+",scores:"+scores[i]+"\n");
        }
        results.put("classes",classesLabel);
        return results;


    }
    //读取Bitmap像素值,并放入到浮点数数组中。归一化到[-1,1]
    private float[] getFloatImage(Bitmap bitmap){
        Bitmap bm = getResizedBitmap(bitmap,inputWH,inputWH);
        bm.getPixels(inputIntData, 0, bm.getWidth(), 0, 0, bm.getWidth(), bm.getHeight());
        for (int i = 0; i < inputIntData.length; ++i) {
            final int val = inputIntData[i];
            inputFloatData[i * 3 + 0] =(float) (((val >> 16) & 0xFF)/255.0-0.5)*2;
            inputFloatData[i * 3 + 1] = (float)(((val >> 8) & 0xFF)/255.0-0.5)*2;
            inputFloatData[i * 3 + 2] = (float)(( val & 0xFF)/255.0-0.5)*2 ;
        }
        return inputFloatData;
    }
    //对图像做Resize
    public Bitmap getResizedBitmap(Bitmap bm, int newWidth, int newHeight) {
        int width = bm.getWidth();
        int height = bm.getHeight();
        float scaleWidth = ((float) newWidth) / width;
        float scaleHeight = ((float) newHeight) / height;
        Matrix matrix = new Matrix();
        matrix.postScale(scaleWidth, scaleHeight);
        Bitmap resizedBitmap = Bitmap.createBitmap( bm, 0, 0, width, height, matrix, false);
        bm.recycle();
        return resizedBitmap;
    }

    private void loadLabel( AssetManager assetManager ) {
        dict=new HashMap<>();
        try {
            InputStream stream = assetManager.open("label.txt");
            InputStreamReader isr=new InputStreamReader(stream);
            BufferedReader br=new BufferedReader(isr);
            String line;
            while((line=br.readLine())!=null){
                line=line.trim();
                String[] arr = line.split(",");
                if(arr.length!=2)
                    continue;
                int key=Integer.parseInt(arr[0]);
                String value = arr[1];
                dict.put(key,value);
            }
        }catch (Exception e){
            e.printStackTrace();
            Log.e("ERROR",e.getMessage());
        }
    }
}

3 模型测试

测试模型
最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 203,271评论 5 476
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 85,275评论 2 380
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 150,151评论 0 336
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 54,550评论 1 273
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 63,553评论 5 365
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 48,559评论 1 281
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 37,924评论 3 395
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 36,580评论 0 257
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 40,826评论 1 297
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 35,578评论 2 320
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 37,661评论 1 329
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 33,363评论 4 318
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 38,940评论 3 307
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 29,926评论 0 19
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 31,156评论 1 259
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 42,872评论 2 349
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 42,391评论 2 342

推荐阅读更多精彩内容