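This example shows how to run object detection with TensorFlow 2.16.1 (CPU build) and walks through the general steps for detecting objects with a pretrained model. The YOLOv8 model is fairly large, and training an object-detection model requires annotating your own images, so instead of training we use a preset pretrained model for simple object detection.
For more details, see the Object Detection guide on keras.io.
The model in this example is not the only option: other pretrained Keras models can be used with the same three steps: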
- Load the pretrained model
- Run model.predict
- Visualize the results
The setup steps are as follows:
# Pull the TensorFlow 2.16.1 CPU image (Jupyter variant). Note: of the versions tried, only this one worked with the keras.io guide.
docker pull tensorflow/tensorflow:2.16.1-jupyter
# Install opencv-python inside the container
pip install opencv-python
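# Inside the container, install the GL library that opencv-python needs
# (the opencv-python-headless package avoids this dependency)
apt-get update && apt-get install -y libgl1-mesa-glx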
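Before opening the notebook, it can help to confirm that the packages import cleanly inside the container. The following is a minimal sketch using only the standard library; the helper name `check_modules` is hypothetical, not part of any library. It reports import failures instead of raising, so a missing GL library shows up as an error message for `cv2` rather than a crash.

```python
import importlib


def check_modules(names):
    """Try to import each module; return {name: (ok, error_message)}."""
    results = {}
    for name in names:
        try:
            importlib.import_module(name)
            results[name] = (True, "")
        except Exception as exc:
            # e.g. "ImportError: libGL.so.1: cannot open shared object file"
            # when cv2 is installed but the GL library is missing.
            results[name] = (False, f"{type(exc).__name__}: {exc}")
    return results
```

For example, calling `check_modules(["cv2", "keras", "keras_cv"])` in a notebook cell should return `(True, "")` for each entry once the installs above have succeeded.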
The complete Jupyter notebook is available at:
https://gitlab.com/zhuge20100104/cpp_practice/-/blob/master/simple_learn/deep_learning/14_tensorflow_object_detection/14.%20Tensorflow%20Object%20Detection.ipynb?ref_type=heads
The code in the notebook is as follows:
# Object detection demo
# Welcome to the object detection inference walkthrough! This notebook walks you step by step through
# the process of using a pre-trained model to detect objects in an image.
# Install necessary dependencies
!python3 -m pip install --upgrade pip
!pip install --upgrade keras-cv
!pip install --upgrade keras # Upgrade to Keras 3.
import os
# Set the backend before importing Keras. The tensorflow/tensorflow:2.16.1 image does not
# ship JAX, so use the TensorFlow backend here ('jax' only works if JAX is installed).
os.environ['KERAS_BACKEND'] = 'tensorflow'
import numpy as np
import keras
import keras_cv
from keras_cv import visualization
# For details, see: https://keras.io/guides/keras_cv/object_detection_keras_cv/
# Let's get started by constructing a YOLOV8Detector pretrained on the pascalvoc dataset.
pretrained_model = keras_cv.models.YOLOV8Detector.from_preset(
"yolo_v8_m_pascalvoc", bounding_box_format="xywh"
)
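`bounding_box_format="xywh"` tells KerasCV that boxes are `[x, y, width, height]`, where `(x, y)` is the top-left corner (as opposed to `"center_xywh"`, where it is the box center, or `"xyxy"`, which stores two corners). The plain-Python helpers below are illustrative only and their names are made up; inside a pipeline, `keras_cv.bounding_box.convert_format` does the same kind of conversion on tensors.

```python
def xywh_to_xyxy(box):
    """[x_min, y_min, width, height] -> [x_min, y_min, x_max, y_max]."""
    x, y, w, h = box
    return [x, y, x + w, y + h]


def xyxy_to_xywh(box):
    """[x_min, y_min, x_max, y_max] -> [x_min, y_min, width, height]."""
    x1, y1, x2, y2 = box
    return [x1, y1, x2 - x1, y2 - y1]


def xywh_to_center_xywh(box):
    """[x_min, y_min, width, height] -> [x_center, y_center, width, height]."""
    x, y, w, h = box
    return [x + w / 2, y + h / 2, w, h]


print(xywh_to_xyxy([100, 50, 200, 120]))         # → [100, 50, 300, 170]
print(xywh_to_center_xywh([100, 50, 200, 120]))  # → [200.0, 110.0, 200, 120]
```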
file_path = './images/dog_pic.jpeg'
image = keras.utils.load_img(file_path)
image = np.array(image)
visualization.plot_image_gallery(
np.array([image]),
value_range=(0, 255),
rows=1,
cols=1,
scale=5,);
# Resize the image to the model-compatible input size (640x640), padding instead of stretching to keep the aspect ratio
inference_resizing = keras_cv.layers.Resizing(
640, 640, pad_to_aspect_ratio=True, bounding_box_format='xywh'
)
# This can be used as our inference preprocessing pipeline:
image_batch = inference_resizing([image])
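With `pad_to_aspect_ratio=True`, the image is scaled by a single factor so that it fits inside 640x640, and the remaining area is padded rather than the image being stretched. The sketch below computes that factor and maps a predicted box back to original-image pixels. It assumes the padding is added at the bottom and right (which appears to be how KerasCV's `Resizing` pads; verify against your version), so only the scale has to be undone. The helper names are hypothetical, and KerasCV's exact rounding may differ by a pixel.

```python
def letterbox_params(height, width, target_h=640, target_w=640):
    """Return (scale, new_h, new_w, pad_bottom, pad_right) for an aspect-preserving resize + pad."""
    scale = min(target_h / height, target_w / width)
    new_h = round(height * scale)
    new_w = round(width * scale)
    return scale, new_h, new_w, target_h - new_h, target_w - new_w


def box_to_original(box_xywh, scale):
    """Map an xywh box from the padded 640x640 frame back to original-image pixels.

    Assumes padding was added only at the bottom/right, so no offset has to be removed.
    """
    return [v / scale for v in box_xywh]


# A 1080x1920 (height x width) photo is scaled by 1/3 and padded at the bottom.
scale, new_h, new_w, pad_b, pad_r = letterbox_params(1080, 1920)
print(new_h, new_w, pad_b, pad_r)  # → 360 640 280 0
print(box_to_original([100, 50, 200, 120], scale))
```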
# keras_cv.visualization.plot_bounding_box_gallery() supports a class_mapping parameter to
# highlight what class each box was assigned to. Let's assemble a class mapping now.
class_ids = [
"Aeroplane",
"Bicycle",
"Bird",
"Boat",
"Bottle",
"Bus",
"Car",
"Cat",
"Chair",
"Cow",
"Dining Table",
"Dog",
"Horse",
"Motorbike",
"Person",
"Potted Plant",
"Sheep",
"Sofa",
"Train",
"Tvmonitor",
"Total",
]
class_mapping = dict(zip(range(len(class_ids)), class_ids))
# Just like any other keras.Model you can predict bounding boxes using the model.predict() API.
y_pred = pretrained_model.predict(image_batch)
visualization.plot_bounding_box_gallery(image_batch,
value_range=(0, 255),
rows=1,
cols=1,
y_pred=y_pred,
scale=5,
font_scale=0.7,
bounding_box_format='xywh',
class_mapping=class_mapping
);
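`model.predict` returns a dictionary rather than a single array. Assuming the structure produced by KerasCV's `NonMaxSuppression` decoder (keys `"boxes"`, `"confidence"`, `"classes"` and `"num_detections"`, with unused slots padded), the hypothetical helper below lists one image's detections as text, which is handy for checking results without a plot.

```python
def describe_detections(y_pred, class_mapping, image_index=0):
    """Turn one image's detections into (label, confidence, box) tuples.

    Assumes y_pred holds "boxes", "confidence", "classes" and "num_detections";
    entries past num_detections are treated as padding and skipped.
    """
    n = int(y_pred["num_detections"][image_index])
    results = []
    for i in range(n):
        class_id = int(y_pred["classes"][image_index][i])
        label = class_mapping.get(class_id, f"unknown({class_id})")
        score = float(y_pred["confidence"][image_index][i])
        box = [float(v) for v in y_pred["boxes"][image_index][i]]
        results.append((label, score, box))
    return results
```

In the notebook, `describe_detections(y_pred, class_mapping)` should return the same detections that the plot above draws; the helper works on NumPy arrays as well as nested lists.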
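# Extension: prediction is done above; now experiment with the NonMaxSuppression parameters.
# The following NonMaxSuppression layer is equivalent to disabling the operation.
# Why "disabling"? A box is suppressed only when its IoU with a higher-scoring box reaches
# iou_threshold; IoU never exceeds 1, so with iou_threshold=1 at most exact duplicates could be suppressed.
# confidence_threshold=0.0 means no box is discarded for low confidence, however low it is.
# As a result, the plot is cluttered with a large number of boxes.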
prediction_decoder = keras_cv.layers.NonMaxSuppression(
bounding_box_format='xywh',
from_logits=True,
iou_threshold=1,
confidence_threshold=0.0,
)
pretrained_model = keras_cv.models.YOLOV8Detector.from_preset(
"yolo_v8_m_pascalvoc",
bounding_box_format='xywh',
prediction_decoder=prediction_decoder,
)
y_pred = pretrained_model.predict(image_batch)
visualization.plot_bounding_box_gallery(
image_batch,
value_range=(0, 255),
rows=1,
cols=1,
y_pred=y_pred,
scale=5,
font_scale=0.7,
bounding_box_format="xywh",
class_mapping=class_mapping,
);
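NMS compares boxes by intersection over union (IoU): the overlap area divided by the combined area. IoU lies in [0, 1] and equals 1 only for identical boxes, which is why `iou_threshold=1` effectively turns suppression off. A minimal plain-Python IoU for `xywh` boxes, for illustration only (`iou_xywh` is a made-up name):

```python
def iou_xywh(a, b):
    """Intersection over union of two [x_min, y_min, width, height] boxes."""
    ax1, ay1, aw, ah = a
    bx1, by1, bw, bh = b
    # Width and height of the overlapping rectangle (0 if the boxes do not touch).
    ix = max(0.0, min(ax1 + aw, bx1 + bw) - max(ax1, bx1))
    iy = max(0.0, min(ay1 + ah, by1 + bh) - max(ay1, by1))
    inter = ix * iy
    union = aw * ah + bw * bh - inter
    return inter / union if union > 0 else 0.0


print(iou_xywh([0, 0, 10, 10], [0, 0, 10, 10]))    # → 1.0 (identical boxes)
print(iou_xywh([0, 0, 10, 10], [5, 0, 10, 10]))    # → 0.3333333333333333
print(iou_xywh([0, 0, 10, 10], [20, 20, 5, 5]))    # → 0.0 (no overlap)
```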
# Next, let's re-configure keras_cv.layers.NonMaxSuppression for our use case!
# In this case, we will tune the iou_threshold to 0.2, and the confidence_threshold to 0.7.
prediction_decoder = keras_cv.layers.NonMaxSuppression(
bounding_box_format='xywh',
from_logits=True,
iou_threshold=0.2,
confidence_threshold=0.7,
)
pretrained_model = keras_cv.models.YOLOV8Detector.from_preset(
"yolo_v8_m_pascalvoc",
bounding_box_format='xywh',
prediction_decoder=prediction_decoder,
)
y_pred = pretrained_model.predict(image_batch)
visualization.plot_bounding_box_gallery(
image_batch,
value_range=(0, 255),
rows=1,
cols=1,
y_pred=y_pred,
scale=5,
font_scale=0.7,
bounding_box_format="xywh",
class_mapping=class_mapping,
);
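To see why the two configurations behave so differently, here is a simplified greedy NMS in plain Python: drop boxes below `confidence_threshold`, then walk the rest from highest to lowest score and discard any box whose IoU with an already-kept box exceeds `iou_threshold`. This is a sketch of the standard greedy algorithm, not KerasCV's vectorized implementation: it is class-agnostic, ignores any cap on the number of detections, and assumes the scores are already probabilities (with `from_logits=True`, KerasCV applies a sigmoid to the raw scores first). The boxes and scores are made-up sample values.

```python
def iou_xywh(a, b):
    """Intersection over union of two [x_min, y_min, width, height] boxes."""
    ix = max(0.0, min(a[0] + a[2], b[0] + b[2]) - max(a[0], b[0]))
    iy = max(0.0, min(a[1] + a[3], b[1] + b[3]) - max(a[1], b[1]))
    inter = ix * iy
    union = a[2] * a[3] + b[2] * b[3] - inter
    return inter / union if union > 0 else 0.0


def greedy_nms(boxes, scores, iou_threshold, confidence_threshold):
    """Return indices of kept boxes, highest score first."""
    # 1. Confidence filter: discard boxes scoring below the threshold.
    candidates = [i for i, s in enumerate(scores) if s >= confidence_threshold]
    # 2. Visit the remaining boxes from highest to lowest score.
    candidates.sort(key=lambda i: scores[i], reverse=True)
    kept = []
    for i in candidates:
        # 3. Keep a box only if it does not overlap any kept box too much.
        if all(iou_xywh(boxes[i], boxes[k]) <= iou_threshold for k in kept):
            kept.append(i)
    return kept


boxes = [[100, 100, 200, 150], [105, 102, 198, 150], [400, 300, 80, 80], [10, 10, 30, 30]]
scores = [0.92, 0.85, 0.75, 0.30]  # box 1 is a near-duplicate of box 0
print(greedy_nms(boxes, scores, iou_threshold=1.0, confidence_threshold=0.0))  # → [0, 1, 2, 3]
print(greedy_nms(boxes, scores, iou_threshold=0.2, confidence_threshold=0.7))  # → [0, 2]
```

With `iou_threshold=1` and `confidence_threshold=0.0` every box survives, including the near-duplicate, which matches the cluttered first plot; with 0.2 and 0.7 the duplicate and the low-confidence box are removed.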
The object detection results produced by the program are shown below: