TensorFlow 将Keras和Checkpoint格式转换为SavedModel格式

滴滴云技术支持发表于:2019年06月19日 16:10:59

滴滴云弹性推理服务支持TensorFlow SavedModel格式的模型部署成在线服务,本文介绍如何将Keras模型格式和Checkpoint模型格式导出为SavedModel格式。

SavedModel格式简介

SavedModel格式的模型是Tensorflow官方推荐的导出模型格式,模型目录结构如下所示:

assets/
variables/
    variables.data-00000-of-00001
    variables.index
saved_model.pb|saved_model.pbtxt

其中:

  • assets是一个可选目录,用于存放预测时的辅助文档信息;

  • variables存放tf.train.Saver时保存的变量信息;

  • saved_model.pb或saved_model.pbtxt存放MetaGraphDef,存储训练预测模型的程序逻辑和SignatureDef用于标记预测时的输入和输出。

滴滴云弹性推理服务建议将这些文件存放在一个按照数字命名的目录下,以一个导出的SavedModel模型Inceptionv3为例,目录结构如下:

inception/
 1/
   saved_model.pb
   variables/
      variables.data-00000-of-00001
     variables.index

Keras模型转换成Savedmodel

对于使用keras的model.save()方法来将keras模型导出成h5格式的情况,将h5格式的模型转换成Savedmodel调用load_model()方法将h5模型加载,再导出成Savedmodel格式,代码片段示例如下所示:

import tensorflow as tf
with tf.device("/cpu:0"):
    model = tf.keras.models.load_model('./mnist.h5')
    tf.saved_model.simple_save(
      tf.keras.backend.get_session(),
      "./h5_savedmodel/",
      inputs={"image": model.input},
      outputs={"scores": model.output}
)

Checkpoint转换成Savedmodel

训练过程中使用tf.train.Saver()保存的模型格式为checkpoint格式,同样需要转换成Savedmodel才可进行在线推理,可以saver.restore()方法将checkpoint加载成tf session,再用上述方法转换成saved_model即可,示例如下所示:


import tensorflow as tf
# variable define ...
saver = tf.train.Saver()
with tf.Session() as sess:
  # Initialize v1 since the saver will not.
    saver.restore(sess, "./lr_model/model.ckpt")
    tensor_info_x = tf.saved_model.utils.build_tensor_info(x)
    tensor_info_y = tf.saved_model.utils.build_tensor_info(y)
    tf.saved_model.simple_save(
      sess,
      "./savedmodel/",
      inputs={"image": tensor_info_x},
      outputs={"scores": tensor_info_y}
)