Tensorflow实战google深度学习框架代码学习八(模型的保存和使用)

#保存模型一共会有3个文件,ckpt保存变量的值,meta保存的是图的结构,checkpoint保存此文件中所有模型的列表
import tensorflow as tf 
a1 = tf.Variable(tf.truncated_normal(shape=[2],seed=2),name='a1')
a2 = tf.Variable(tf.truncated_normal(shape=[2],seed=2),name='a2')
result = a1 +a2

#保存模型
saver = tf.train.Saver()
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    saver.save(sess,'saveckpt/model.ckpt')
import tensorflow as tf
a1 = tf.Variable(tf.truncated_normal(shape=[2]),name='a1')
a2 = tf.Variable(tf.truncated_normal(shape=[2]),name='a2')
result = a1 +a2

#加载模型
saver = tf.train.Saver()
with tf.Session() as sess:
    saver.restore(sess,'saveckpt/model.ckpt')
    print(sess.run(result))

结果:

INFO:tensorflow:Restoring parameters from saveckpt/model.ckpt
[-1.71622169 -0.39324597]
#直接加载持久化的图
import tensorflow as tf
saver = tf.train.import_meta_graph('saveckpt/model.ckpt.meta')
with tf.Session() as sess:
    saver.restore(sess,'saveckpt/model.ckpt')
    print(sess.run(a1))

结果:

INFO:tensorflow:Restoring parameters from saveckpt/model.ckpt
[-0.85811085 -0.19662298]
#变量重命名
import tensorflow as tf
v1 = tf.Variable([1.0,2.1],name='v1')
v2 = tf.Variable([2.0,3.0],name='v2')
saver = tf.train.Saver(var_list={'a1':v1,'a2':v2})
with tf.Session() as sess:
    saver.restore(sess,'saveckpt/model.ckpt')
    print(sess.run(v1))

结果:

INFO:tensorflow:Restoring parameters from saveckpt/model.ckpt
[-0.85811085 -0.19662298]

    原文作者:Carson
    原文地址: https://zhuanlan.zhihu.com/p/37473083
    本文转自网络文章,转载此文章仅为分享知识,如有侵权,请联系博主进行删除。
点赞