1. Saving the model
In TensorFlow 1.7, save the trained model as follows:
import tensorflow as tf

with tf.Graph().as_default():
    with tf.Session() as sess:
        # Other logic (model definition, training, etc.) omitted
        # Saving: the dict keys become the input/output names of the saved signature
        inputs = {
            "batch_size_placeholder": batch_size_placeholder,
            "features_placeholder": features_placeholder,
            "labels_placeholder": labels_placeholder,
        }
        outputs = {"prediction": model_output}
        tf.saved_model.simple_save(
            sess, 'path/to/your/location/', inputs, outputs
        )
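path/to/your/location/: the directory the model is saved to. Use a new directory, because simple_save refuses to overwrite an existing export directory.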
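For a quick end-to-end check, here is a minimal, self-contained sketch of the saving step. The linear model, the `save_toy_model` helper and the temporary export directory are hypothetical stand-ins for the omitted training code. So that it also runs on newer releases, it imports `tensorflow.compat.v1` when available and falls back to plain `tensorflow` on older 1.x versions such as 1.7.

```python
import os
import tempfile

try:
    # TF >= 1.13 and TF 2.x expose the TF1-style graph/session API here
    import tensorflow.compat.v1 as tf
except ImportError:
    # Older 1.x releases such as 1.7 only have the top-level API
    import tensorflow as tf


def save_toy_model(export_dir):
    """Build a hypothetical linear model y = x @ w + b and export it with simple_save."""
    with tf.Graph().as_default():
        # name= fixes the graph tensor name to 'features_placeholder:0'
        features = tf.placeholder(tf.float32, shape=[None, 3], name="features_placeholder")
        w = tf.Variable(tf.ones([3, 1]), name="w")
        b = tf.Variable(tf.zeros([1]), name="b")
        prediction = tf.add(tf.matmul(features, w), b, name="prediction")

        with tf.Session() as sess:
            # Stand-in for training: just initialize the variables
            sess.run(tf.global_variables_initializer())
            # The dict keys become the input/output names of the saved signature
            tf.saved_model.simple_save(
                sess,
                export_dir,
                inputs={"features_placeholder": features},
                outputs={"prediction": prediction},
            )


if __name__ == "__main__":
    # Use a path that does not exist yet; simple_save creates it
    export_dir = os.path.join(tempfile.mkdtemp(), "toy_model")
    save_toy_model(export_dir)
    print(sorted(os.listdir(export_dir)))
```

After it runs, the export directory should contain `saved_model.pb` (the graph plus its signature) and a `variables/` subdirectory holding the variable values.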
2. Restoring the model
In TensorFlow 1.7, restore the trained model as follows:
import tensorflow as tf
from tensorflow.python.saved_model import tag_constants

restored_graph = tf.Graph()
with restored_graph.as_default():
    with tf.Session(graph=restored_graph) as sess:
        tf.saved_model.loader.load(
            sess,
            [tag_constants.SERVING],  # simple_save tags the model with SERVING
            'path/to/your/location/',
        )
        # Tensors are looked up by their graph names (the name= given when each op was created),
        # not by the keys of the inputs/outputs dicts used when saving
        batch_size_placeholder = restored_graph.get_tensor_by_name('batch_size_placeholder:0')
        features_placeholder = restored_graph.get_tensor_by_name('features_placeholder:0')
        labels_placeholder = restored_graph.get_tensor_by_name('labels_placeholder:0')
        # 'dense/BiasAdd:0' is the output tensor of this particular model; substitute your own
        prediction = restored_graph.get_tensor_by_name('dense/BiasAdd:0')
        result = sess.run(prediction, feed_dict={
            batch_size_placeholder: some_value,
            features_placeholder: some_other_value,
            labels_placeholder: another_value,
        })
path/to/your/location/: the directory of the model to restore.
Documentation: https://www.tensorflow.org/programmers_guide/saved_model