# Saving a model writes several files into the target directory:
#   - model.ckpt*      : the variable values (NOTE: with the V2 checkpoint format these
#                        are usually split into model.ckpt.index and
#                        model.ckpt.data-00000-of-00001 — depends on TF version)
#   - model.ckpt.meta  : the graph structure
#   - checkpoint       : a list of all model checkpoints saved in this directory
import tensorflow as tf

# Both variables use the same op-level seed, so they are initialised to identical
# values (which is why the restored sum later is exactly twice a1).
a1 = tf.Variable(tf.truncated_normal(shape=[2], seed=2), name='a1')
a2 = tf.Variable(tf.truncated_normal(shape=[2], seed=2), name='a2')
result = a1 + a2

# Save the model.
saver = tf.train.Saver()
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    # Saver.save may refuse to write when the parent directory is missing
    # ("Parent directory of ... doesn't exist, can't save."), so create it up front.
    # MakeDirs succeeds if the directory already exists.
    tf.gfile.MakeDirs('saveckpt')
    saver.save(sess, 'saveckpt/model.ckpt')
import tensorflow as tf

# Start from an empty graph. If this snippet runs in the same process after another
# one, variables named 'a1'/'a2' already exist and the new ones would be uniquified
# to 'a1_1'/'a2_1', which no longer match the checkpoint keys and break restore().
tf.reset_default_graph()

# Rebuild the same graph in code. The variable names must match the saved ones.
# The initial values do not matter: restore() overwrites them, so no initializer
# is run below.
a1 = tf.Variable(tf.truncated_normal(shape=[2]), name='a1')
a2 = tf.Variable(tf.truncated_normal(shape=[2]), name='a2')
result = a1 + a2

# Load the model.
saver = tf.train.Saver()
with tf.Session() as sess:
    saver.restore(sess, 'saveckpt/model.ckpt')
    print(sess.run(result))
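# 结果 (Result):
# INFO:tensorflow:Restoring parameters from saveckpt/model.ckpt
# [-1.71622169 -0.39324597]
# a1 and a2 were created with the same seed (seed=2), so they hold identical values
# and this sum is exactly twice the value of a1.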
# Load the persisted graph directly from the .meta file instead of rebuilding it in code.
import tensorflow as tf

# Use a clean graph so the imported nodes keep their saved names ('a1', 'a2', ...)
# even if other graph-building code ran earlier in this process.
tf.reset_default_graph()
# import_meta_graph adds the saved graph to the default graph and returns a Saver for it.
saver = tf.train.import_meta_graph('saveckpt/model.ckpt.meta')
# No Python variable `a1` exists here because the graph was not built in this script,
# so look the variable's output tensor up by its graph name.
a1 = tf.get_default_graph().get_tensor_by_name('a1:0')
with tf.Session() as sess:
    saver.restore(sess, 'saveckpt/model.ckpt')
    print(sess.run(a1))
# 结果 (Result):
# INFO:tensorflow:Restoring parameters from saveckpt/model.ckpt
# [-0.85811085 -0.19662298]
# Variable renaming: load the checkpoint entries saved as 'a1'/'a2' into
# variables that are named 'v1'/'v2' in this graph.
import tensorflow as tf

# Start from a clean graph so nodes left over from earlier code in the same
# process do not linger alongside the variables below.
tf.reset_default_graph()
# The initial values are placeholders; restore() overwrites them.
v1 = tf.Variable([1.0, 2.1], name='v1')
v2 = tf.Variable([2.0, 3.0], name='v2')
# var_list maps checkpoint keys (the names used when saving) to the variables to fill.
saver = tf.train.Saver(var_list={'a1': v1, 'a2': v2})
with tf.Session() as sess:
    saver.restore(sess, 'saveckpt/model.ckpt')
    # Prints the value saved under 'a1', not the initial [1.0, 2.1].
    print(sess.run(v1))
# 结果 (Result):
# INFO:tensorflow:Restoring parameters from saveckpt/model.ckpt
# [-0.85811085 -0.19662298]
# v1 now holds the value that was saved under the name 'a1'.