Java调用TensorFlow库预测图片质量

概述

本文参考官方的java使用TensorFlow库的例子,将深度学习模型保存成pb文件,在java环境中加载模型并做预测。

环境安装

安装pip
yum -y install epel-release
yum install python-pip
pip install --upgrade pip
安装TensorFlow、Keras、numpy
pip install tensorflow  //安装的是最新的tensorflow2.1版本
pip install keras
pip install numpy

Maven配置

在pom.xml中增加如下配置,加载java的tensorflow库

<dependency>
    <groupId>org.tensorflow</groupId>
    <artifactId>tensorflow</artifactId>
    <version>1.15.0</version>
</dependency>

加载模型

InputStream inputStream = ImageRecognize.class.getResourceAsStream(MODEL_PATH);
Graph graph = new Graph();
graph.importGraphDef(IOUtils.toByteArray(inputStream));
Session session = new Session(graph);

其中graph和session都是线程安全的,可以使用单例,不需要每次请求都重新加载模型和Session。

图片预处理

我们使用的是Xception模型,它要求输入的图片大小是480 * 480 * 3,并且需要对图片做预处理,每个rgb值归一化到[-1,1]区间。下面我们介绍两种预处理的方式:

BufferedImage预处理
BufferedImage bufferedImage = new BufferedImage(480, 480, BufferedImage.TYPE_INT_RGB);
Graphics graphics = bufferedImage.getGraphics();

InputStream in = new ByteArrayInputStream(imageData);
Image srcImage = ImageIO.read(in);
graphics.drawImage(srcImage, 0, 0, 480, 480, null); //将图片大小转换为480*480

int w = bufferedImage.getWidth();
int h = bufferedImage.getHeight();
float[][][][] imgTensor = new float[1][h][w][3];
for (int i = 0; i < h; i++) {
     for (int j = 0; j < w; j++) {
              int pixel = bufferedImage.getRGB(j, i); // 下面三行代码将一个数字转换为RGB数字,同时归一化到[-1,1]区间
              imgTensor[0][i][j][0] = (float) ((pixel & 0xff0000) >> 16) / 127.5f - 1;
              imgTensor[0][i][j][1] = (float) ((pixel & 0xff00) >> 8) / 127.5f - 1;
              imgTensor[0][i][j][2] = (float) ((pixel & 0xff)) / 127.5f - 1;
      }
 }
return Tensors.create(imgTensor);
TensorFlow预处理

TensorFlow的预处理参考了LabelImage.java调用方式,它是使用TensorFlow Graph的一些预定义好的Operator来对图片做预处理。

  private Tensor<Float> getImageTensor(byte[] imageBytes){
      Graph g = new Graph();
      GraphBuilder b = new GraphBuilder(g);

      final int H = IMAGE_HEIGTH;
      final int W = IMAGE_WIDTH;
      final float mean = 1f;
      final float scale = 127.5f;

      final Output<String> input = b.constant("input", imageBytes);
      final Output<Float> output =
              b.sub(
                      b.div(
                              b.resizeBilinear(
                                      b.expandDims(
                                              b.cast(b.decodeJpeg(input, 3), Float.class), //解析jpeg文件
                                              b.constant("make_batch", 0) //扩展成4维Tensor
                                      ),
                                      b.constant("size", new int[]{H, W}) //resize图片成[H,W]大小
                              ),
                              b.constant("scale", scale) //每个值除以127.5f
                      ),
                      b.constant("mean", mean) //归一化到[-1,1]区间
              );
      try (Session s = new Session(g)) {
        // Generally, there may be multiple output tensors, all of them must be closed to prevent resource leaks.
        return s.runner().fetch(output.op().name()).run().get(0).expect(Float.class);
      }
  }

  static class GraphBuilder {
    GraphBuilder(Graph g) {
      this.g = g;
    }

    Output<Float> div(Output<Float> x, Output<Float> y) {
      return binaryOp("Div", x, y);
    }

    <T> Output<T> sub(Output<T> x, Output<T> y) {
      return binaryOp("Sub", x, y);
    }

    <T> Output<Float> resizeBilinear(Output<T> images, Output<Integer> size) {
      return binaryOp3("ResizeBilinear", images, size);
    }

    <T> Output<T> expandDims(Output<T> input, Output<Integer> dim) {
      return binaryOp3("ExpandDims", input, dim);
    }

    <T, U> Output<U> cast(Output<T> value, Class<U> type) {
      DataType dtype = DataType.fromClass(type);
      return g.opBuilder("Cast", "Cast")
              .addInput(value)
              .setAttr("DstT", dtype)
              .build()
              .<U>output(0);
    }

    Output<UInt8> decodeJpeg(Output<String> contents, long channels) {
      return g.opBuilder("DecodeJpeg", "DecodeJpeg")
              .addInput(contents)
              .setAttr("channels", channels)
              .build()
              .<UInt8>output(0);
    }

    <T> Output<T> constant(String name, Object value, Class<T> type) {
      try (Tensor<T> t = Tensor.<T>create(value, type)) {
        return g.opBuilder("Const", name)
                .setAttr("dtype", DataType.fromClass(type))
                .setAttr("value", t)
                .build()
                .<T>output(0);
      }
    }
    Output<String> constant(String name, byte[] value) {
      return this.constant(name, value, String.class);
    }

    Output<Integer> constant(String name, int value) {
      return this.constant(name, value, Integer.class);
    }

    Output<Integer> constant(String name, int[] value) {
      return this.constant(name, value, Integer.class);
    }

    Output<Float> constant(String name, float value) {
      return this.constant(name, value, Float.class);
    }

    private <T> Output<T> binaryOp(String type, Output<T> in1, Output<T> in2) {
      return g.opBuilder(type, type).addInput(in1).addInput(in2).build().<T>output(0);
    }

    private <T, U, V> Output<T> binaryOp3(String type, Output<U> in1, Output<V> in2) {
      return g.opBuilder(type, type).addInput(in1).addInput(in2).build().<T>output(0);
    }
    private Graph g;
  }

上面两种方式都做了尝试,我们是原始图片大小为640 * 640,机器是2.4GHz的CPU机器(没有用GPU),第一种预处理方法平均耗时在200ms左右,第二种预处理方法耗时为30ms左右,主要原因为TensorFlow内部对矩阵运算会做优化,而第一种方法手写的循环效率不高。后续会再测一下在GPU环境下的耗时。

模型预测

我们的xception模型中,输入节点的名字为input_1,输出节点的名字为output,对应着代码里的名字,需要完全一致。

    float result = -1;
    input = getImageTensor1(imageData);
    if ( input == null ) {
      return result;
    }

    List<Tensor<?>> results = session.runner().feed("input_1", input).fetch("output").run();
    if (results.size() > 0 && results.get(0).shape().length == 2) {
      long[] rshape = results.get(0).shape();
      int rs = (int) rshape[0];
      int rt = (int) rshape[1];
      float realResult[][] = new float[rs][rt];

      results.get(0).copyTo(realResult);
      for (int i = 0; i < rs; i++) {
        for (int j = 0; j < rt; j++) {
          result = realResult[i][j];
          break;
        }
      }
    }

线上部署

线上使用时候,有一个线程不断的从HDFS中检查并读取最新的模型。一旦模型有更新,则加载新模型替换旧模型。

最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 194,088评论 5 459
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 81,715评论 2 371
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 141,361评论 0 319
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 52,099评论 1 263
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 60,987评论 4 355
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 46,063评论 1 272
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 36,486评论 3 381
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 35,175评论 0 253
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 39,440评论 1 290
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 34,518评论 2 309
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 36,305评论 1 326
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 32,190评论 3 312
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 37,550评论 3 298
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 28,880评论 0 17
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 30,152评论 1 250
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 41,451评论 2 341
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 40,637评论 2 335