TensorFlow导出SavedModel格式模型

滴滴云技术支持发表于:2019年06月19日 16:05:46更新于:2019年06月19日 16:35:34

滴滴云弹性推理服务支持TensorFlow SavedModel格式的模型部署成在线推理服务,本文介绍如何将TensorFlow模型导出为SavedModel格式。

SavedModel格式简介

SavedModel格式的模型是Tensorflow官方推荐的导出模型格式,模型目录结构如下所示:

assets/
variables/
    variables.data-00000-of-00001
    variables.index
saved_model.pb|saved_model.pbtxt

其中:

  • assets是一个可选目录,用于存放预测时的辅助文档信息;

  • variables存放tf.train.Saver时保存的变量信息;

  • saved_model.pb或saved_model.pbtxt存放MetaGraphDef,存储训练预测模型的程序逻辑和SignatureDef用于标记预测时的输入和输出。

滴滴云弹性推理服务建议将这些文件存放在一个按照数字命名的目录下,以一个导出的SavedModel模型Inceptionv3为例,目录结构如下:

inception/
 1/
   saved_model.pb
   variables/
      variables.data-00000-of-00001
     variables.index

导出SavedModel格式模型

使用Tensorflow导出SavedModel格式请参考Saving and Restoring,Tensorflow导出SavedModel格式模型的完整代码下载saved_model.tar.gz。对于简单模型,用户可快速导出savedmodel,如下所示:

tf.saved_model.simple_save(
  session,
  "./savedmodel/",
  inputs={"image": x},   ## x是模型的输入变量
  outputs={"scores": y}  ## y是模型的输出
)

在请求推理服务时,请求中需指定模型的signature_name,使用simple_save()方法导出的模型中,signature_name默认为:serving_default

若模型较复杂,可使用手工的方式来导出saved_model,示例如下所示:

print 'Exporting trained model to', export_path
  builder = tf.saved_model.builder.SavedModelBuilder(export_path)
  tensor_info_x = tf.saved_model.utils.build_tensor_info(x)
  tensor_info_y = tf.saved_model.utils.build_tensor_info(y)
  prediction_signature = (
      tf.saved_model.signature_def_utils.build_signature_def(
          inputs={'images': tensor_info_x},
          outputs={'scores': tensor_info_y},
          method_name=tf.saved_model.signature_constants.PREDICT_METHOD_NAME))
  legacy_init_op = tf.group(tf.tables_initializer(), name='legacy_init_op')
  builder.add_meta_graph_and_variables(
      sess, [tf.saved_model.tag_constants.SERVING],
      signature_def_map={
          'predict_images':
              prediction_signature,
      },
      legacy_init_op=legacy_init_op)
  builder.save()
  print 'Done exporting!'

其中:

  • export_path为导出模型的路径;

  • prediction_signature是模型为输入和输出构建出的SignatureDef,具体可参考SignatureDef,在上例中signature_name为predict_images

  • builder.add_meta_graph_and_variables方法描述了导出模型的参数,特别注意tf.saved_model.tag_constants.SERVING**。

  • 更多的信息请参见:TensorFlow SavedModel