环境：
Python：3.5.2
TensorFlow：1.11.0
Ubuntu：16.04
保存模型（GitHub 代码）：
saved_model_dir = './model'
# NOTE: SavedModelBuilder refuses to write into an export directory that
# already exists, so remove or rename ./model before re-running this script.
builder = tf.saved_model.builder.SavedModelBuilder(saved_model_dir)

# Signature inputs: input_x and keep_prob are what the restore code feeds to
# run 'prediction'; input_y is exported as well so it can be looked up by key.
inputs = {
    'input_x': tf.saved_model.utils.build_tensor_info(xs),
    'input_y': tf.saved_model.utils.build_tensor_info(ys),
    'keep_prob': tf.saved_model.utils.build_tensor_info(keep_prob),
}
# 'prediction' is the inference output; the restore code fetches it by this
# key and runs it to produce predictions.
outputs = {'prediction': tf.saved_model.utils.build_tensor_info(prediction)}
# The third argument of build_signature_def is the *method name* (the kind of
# computation: predict / classify / regress), not the signature's name. The
# signature is named by the key given to add_meta_graph_and_variables below,
# so use the standard predict method name here instead of an ad-hoc string.
signature = tf.saved_model.signature_def_utils.build_signature_def(
    inputs=inputs,
    outputs=outputs,
    method_name=tf.saved_model.signature_constants.PREDICT_METHOD_NAME)

with tf.Session() as sess:
    sess.run(init)
    for i in range(1000):
        batch_xs, batch_ys = mnist.train.next_batch(100)
        sess.run(train_step, feed_dict={xs: batch_xs, ys: batch_ys, keep_prob: 0.5})
        if i % 50 == 0:
            print(compute_accuracy(sess, prediction,
                                   mnist.test.images[:1000], mnist.test.labels[:1000]))
    # Export once, after training, while the session holding the trained
    # variables is still open. 'model_final' is the tag the restore code
    # passes to tf.saved_model.loader.load; 'test_signature' is the
    # signature key it looks up.
    builder.add_meta_graph_and_variables(sess, ['model_final'],
                                         {'test_signature': signature})
# Write the SavedModel protocol buffer to disk.
builder.save()
恢复模型（GitHub 代码）：
saved_model_dir = './model'
# These keys must match the ones used when the model was exported.
signature_key = 'test_signature'
input_key_x = 'input_x'
input_key_y = 'input_y'
input_key_keep_prob = 'keep_prob'
output_key_prediction = 'prediction'


def _tensor_name(tensor_info_map, key):
    """Return the tensor name stored under ``key`` in a signature map.

    Raises:
        KeyError: if ``key`` is not present in ``tensor_info_map``.
    """
    # Indexing a protobuf message map with a missing key silently inserts an
    # empty entry (name == ''), which would only fail later, and unclearly,
    # in get_tensor_by_name. Fail here with the available keys instead.
    if key not in tensor_info_map:
        raise KeyError('%r not found in signature; available keys: %s'
                       % (key, sorted(tensor_info_map.keys())))
    return tensor_info_map[key].name


with tf.Session() as sess:
    # Load the graph and trained variables saved under the 'model_final' tag.
    meta_graph_def = tf.saved_model.loader.load(sess, ['model_final'], saved_model_dir)
    # Map from signature key to SignatureDef.
    signature = meta_graph_def.signature_def
    if signature_key not in signature:
        raise KeyError('Signature %r not found in %s; available: %s'
                       % (signature_key, saved_model_dir, sorted(signature.keys())))
    sig_def = signature[signature_key]
    # Resolve the concrete tensor names behind each input/output key.
    x_tensor_name = _tensor_name(sig_def.inputs, input_key_x)
    y_tensor_name = _tensor_name(sig_def.inputs, input_key_y)
    keep_prob_tensor_name = _tensor_name(sig_def.inputs, input_key_keep_prob)
    prediction_tensor_name = _tensor_name(sig_def.outputs, output_key_prediction)
    # Fetch the tensors from the restored graph for feeding and inference.
    input_x = sess.graph.get_tensor_by_name(x_tensor_name)
    input_y = sess.graph.get_tensor_by_name(y_tensor_name)
    keep_prob = sess.graph.get_tensor_by_name(keep_prob_tensor_name)
    prediction = sess.graph.get_tensor_by_name(prediction_tensor_name)
使用恢复的模型预测结果：
# NOTE: this snippet continues inside the `with tf.Session() as sess:` block of
# the restore step. It uses sess, input_x, keep_prob and prediction from there.
# `index` (which test image to classify) must be defined beforehand; it is not
# set in this snippet. TODO confirm where it comes from.
import numpy as np

# Test a single example.
x = mnist.test.images[index].reshape(1, 784)
y = mnist.test.labels[index].reshape(1, 10)  # one-hot label, shape (1, 10)
print(y)
# Training feeds keep_prob=0.5; inference feeds 1.0 (presumably dropout off).
pred_y = sess.run(prediction, feed_dict={input_x: x, keep_prob: 1.0})
print(pred_y)
# y and pred_y are already NumPy arrays, so compare them with NumPy. Calling
# sess.run(tf.argmax(...)) / tf.equal here would add new ops to the graph on
# every call.
actual_class = np.argmax(y, 1)
predicted_class = np.argmax(pred_y, 1)
print("Actual class: ", str(actual_class),
      ", predict class ", str(predicted_class),
      ", predict ", str(np.equal(actual_class, predicted_class)))

# Evaluate on the first 1000 test images.
# NOTE(review): this passes input_x and keep_prob, unlike the 4-argument call
# in the training snippet; compute_accuracy is defined elsewhere, so confirm
# that its signature matches this call.
print(compute_accuracy(sess, prediction, input_x, keep_prob,
                       mnist.test.images[:1000], mnist.test.labels[:1000]))