1. Saving the model
In TensorFlow 1.7, a trained model can be saved as follows:
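```python
import tensorflow as tf

with tf.Graph().as_default():
    with tf.Session() as sess:
        # Other logic omitted (building the model, training, ...)

        # Saving. The dict keys below are signature keys, not op names.
        # The restore code fetches tensors by op name, so create the
        # placeholders with matching name= arguments.
        inputs = {
            "batch_size_placeholder": batch_size_placeholder,
            "features_placeholder": features_placeholder,
            "labels_placeholder": labels_placeholder,
        }
        outputs = {"prediction": model_output}
        # simple_save tags the graph with tag_constants.SERVING by default.
        tf.saved_model.simple_save(
            sess, 'path/to/your/location/', inputs, outputs
        )
```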
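path/to/your/location/: the directory the model is saved to. simple_save refuses to write into an existing, non-empty directory, so point it at a new path.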
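Because the restore step looks tensors up by op name (for example `'features_placeholder:0'`), those ops must be created with explicit names. The following is a minimal, self-contained sketch of the saving side under that assumption. The toy linear model, the single `features_placeholder` input and the `export_model` helper are hypothetical. It imports `tensorflow.compat.v1` so it also runs on TF 1.13+ and 2.x. On TF 1.7 itself, use `import tensorflow as tf` instead.

```python
import os
import tempfile

# TF 1.x graph/session API; on TensorFlow 1.7 use `import tensorflow as tf`.
import tensorflow.compat.v1 as tf


def export_model(export_dir):
    """Build a hypothetical toy linear model and save it with simple_save."""
    with tf.Graph().as_default():
        # Explicit op name, so the tensor can later be fetched as
        # 'features_placeholder:0'.
        features = tf.placeholder(
            tf.float32, shape=[None, 3], name='features_placeholder')
        w = tf.Variable([[1.0], [2.0], [3.0]], name='w')
        b = tf.Variable([0.5], name='b')
        # Name the output op too, instead of relying on an auto-generated
        # name such as 'dense/BiasAdd'.
        prediction = tf.add(tf.matmul(features, w), b, name='prediction')
        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())
            tf.saved_model.simple_save(
                sess,
                export_dir,
                inputs={'features_placeholder': features},
                outputs={'prediction': prediction},
            )


# simple_save will not write into an existing non-empty directory,
# so use a path that does not exist yet.
export_dir = os.path.join(tempfile.mkdtemp(), 'model')
export_model(export_dir)
print(sorted(os.listdir(export_dir)))
```

The export directory should then contain `saved_model.pb` and a `variables/` subdirectory.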
2. Restoring the model
In TensorFlow 1.7, a saved model can be restored as follows:
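```python
import tensorflow as tf
from tensorflow.python.saved_model import tag_constants

restored_graph = tf.Graph()
with restored_graph.as_default():
    with tf.Session() as sess:
        tf.saved_model.loader.load(
            sess,
            [tag_constants.SERVING],
            'path/to/your/location/',
        )
        # These names assume the placeholders were created with matching
        # name= arguments; ':0' selects the op's first output tensor.
        batch_size_placeholder = restored_graph.get_tensor_by_name('batch_size_placeholder:0')
        features_placeholder = restored_graph.get_tensor_by_name('features_placeholder:0')
        labels_placeholder = restored_graph.get_tensor_by_name('labels_placeholder:0')
        # Output tensor of the model; 'dense/BiasAdd:0' fits a final
        # tf.layers.dense layer without activation. Adjust to your graph.
        prediction = restored_graph.get_tensor_by_name('dense/BiasAdd:0')

        # some_value, some_other_value, another_value: your input data.
        sess.run(prediction, feed_dict={
            batch_size_placeholder: some_value,
            features_placeholder: some_other_value,
            labels_placeholder: another_value,
        })
```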
path/to/your/location/: the directory of the saved model to restore.
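Hardcoding tensor names such as `'dense/BiasAdd:0'` is fragile. Since `simple_save` stores its inputs and outputs in a SignatureDef under the default serving key, the tensor names can instead be read from the MetaGraphDef that `loader.load` returns. The following is a self-contained sketch of that round trip. It first saves a hypothetical toy linear model (one input keyed `features_placeholder`, one output keyed `prediction`) so it can run on its own. `export_model` and `load_and_predict` are illustrative helper names, not TensorFlow APIs. It imports `tensorflow.compat.v1` (TF 1.13+ or 2.x). On TF 1.7, use `import tensorflow as tf` instead.

```python
import os
import tempfile

import numpy as np
# TF 1.x graph/session API; on TensorFlow 1.7 use `import tensorflow as tf`.
import tensorflow.compat.v1 as tf


def export_model(export_dir):
    """Save a hypothetical toy model: prediction = features @ w + b."""
    with tf.Graph().as_default():
        features = tf.placeholder(tf.float32, shape=[None, 3], name='features_placeholder')
        w = tf.Variable([[1.0], [2.0], [3.0]], name='w')
        b = tf.Variable([0.5], name='b')
        prediction = tf.add(tf.matmul(features, w), b)
        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())
            tf.saved_model.simple_save(
                sess, export_dir,
                inputs={'features_placeholder': features},
                outputs={'prediction': prediction})


def load_and_predict(export_dir, batch):
    """Restore the model and look tensors up through its signature."""
    with tf.Graph().as_default() as graph:
        with tf.Session(graph=graph) as sess:
            meta_graph_def = tf.saved_model.loader.load(
                sess, [tf.saved_model.tag_constants.SERVING], export_dir)
            # simple_save registers its signature under 'serving_default'.
            signature = meta_graph_def.signature_def[
                tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY]
            # TensorInfo.name holds the real tensor name, whatever the op was called.
            features = graph.get_tensor_by_name(
                signature.inputs['features_placeholder'].name)
            prediction = graph.get_tensor_by_name(
                signature.outputs['prediction'].name)
            return sess.run(prediction, feed_dict={features: batch})


export_dir = os.path.join(tempfile.mkdtemp(), 'model')
export_model(export_dir)
batch = np.array([[1.0, 1.0, 1.0], [0.0, 0.0, 0.0]], dtype=np.float32)
result = load_and_predict(export_dir, batch)
print(result)
```

With `w = [1, 2, 3]` and `b = 0.5`, the batch above should give `[[6.5], [0.5]]`. Note that the output op here is deliberately left unnamed; the signature still resolves it.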
Documentation: https://www.tensorflow.org/programmers_guide/saved_model