参考:《深度学习图像识别技术--基于TensorFlow Object Detection API 和 OpenVINO》
上节,详细分析了在自定义模型函数中,创建神经网络的三个步骤:创建输入层、创建隐藏层、创建输出层。本节主要介绍自定义模型函数的最后一步:编写实现预测、评估和训练的分支代码
回忆一下:《TensorFlow入门16: 创建自定义的Estimator 2》
1,Model_fn的返回值是: tf.estimator.EstimatorSpec。
2,Estimator对象的三个方法train、evaluate、predict都会调用model_fn给Estimator传参数。
3,当Estimator对象调用 train、evaluate 或 predict 方法时,Estimator 对象会在调用模型函数前,将 mode 参数设置为对应的值:ModeKeys.TRAIN、ModeKeys.EVAL、ModeKeys.PREDICT。
由此,model_fn函数创建好神经网络后,检测mode值,根据不同的mode,实现对应的代码,并返回: tf.estimator.EstimatorSpec,具体的实现,参考下图:
完成model_fn函数编写后,回到main函数,可以发现,只有创建classifier对象的代码,略有不同,其余代码一模一样,如下图所示