Tensorflow学习(一):模型保存与恢复

最近在学习Tensorflow,想将所学通过知乎文章的形式进行总结,以飨读者。

首先明确一点,tensorflow保存的是什么?

模型保存后产生四个文件,分别是:

|--models
|    |--checkpoint
|    |--.meta
|    |--.data
|    |--.index

其中.meta保存的是图的结构,checkpoint文件是个文本文件,里面记录了保存的最新的checkpoint文件以及其它checkpoint文件列表,.data和.index保存的是变量值。

即模型保存的是图的结构和变量值。

一 实例

以下是使用tensorflow实现简单的线性模型:

#生成样本数据
x = np.random.randn(10000,1)
y = 0.03*x+0.8

#定义模型参数
Weights = tf.Variable(tf.random_normal([1],seed=1,stddev=1),name='Weights')
bias = tf.Variable(tf.random_normal([1],seed=1,stddev=1),name='bias')


xx = tf.placeholder(tf.float32,shape=(None,1),name='xx')
yy = tf.placeholder(tf.float32,shape=(None,1),name='yy')

#线性模型
y_predict = tf.add(Weights*xx,bias,name='preds')

#损失函数
loss = tf.reduce_mean(tf.square(yy-y_predict))

#优化方法
optimizer = tf.train.GradientDescentOptimizer(0.01).minimize(loss)

#批训练模型
batchsize = 20
samplesize = 100
start = 0
end = 0
with tf.Session() as sess:
    init_var = tf.global_variables_initializer()
    sess.run(init_var)
    print('before training, variable is %s,%s'%(sess.run(Weights),sess.run(bias)))
    for i in (range(5000)):
# start = (i*batchsize)%100
        if end == samplesize:
            start = 0
        end = np.minimum(start+batchsize,samplesize)
# try:
# end = np.min(start+batchsize,samplesize)
# except:
# print(end)
        sess.run(optimizer,feed_dict={xx:x[start:end],yy:y[start:end]})
        if (i+1)%1000 == 0:
            print(sess.run(loss,feed_dict={xx:x[start:end],yy:y[start:end]}))
        start += batchsize
    print('after training, variable is %s,%s'%(sess.run(Weights),sess.run(bias)))

二 模型保存

通过以下程序可实现保存:

saver = tf.train.Saver()
saver.save(session,dir[,global_step])

save中第一个参数是session,第二个参数是模型保存的位置,第三个参数申明模型每迭代多少步保存一次。

保存一中的模型,并设置每1000步保存一次:

#生成样本数据
x = np.random.randn(10000,1)
y = 0.03*x+0.8

#定义模型参数
Weights = tf.Variable(tf.random_normal([1],seed=1,stddev=1),name='Weights')
bias = tf.Variable(tf.random_normal([1],seed=1,stddev=1),name='bias')


xx = tf.placeholder(tf.float32,shape=(None,1),name='xx')
yy = tf.placeholder(tf.float32,shape=(None,1),name='yy')

#线性模型
y_predict = tf.add(Weights*xx,bias,name='preds')

#损失函数
loss = tf.reduce_mean(tf.square(yy-y_predict))

#优化方法
optimizer = tf.train.GradientDescentOptimizer(0.01).minimize(loss)

#批训练模型
batchsize = 20
samplesize = 100
start = 0
end = 0
with tf.Session() as sess:
    init_var = tf.global_variables_initializer()
    sess.run(init_var)
    saver = tf.train.Saver()

    print('before training, variable is %s,%s'%(sess.run(Weights),sess.run(bias)))
    for i in (range(5000)):
# start = (i*batchsize)%100
        if end == samplesize:
            start = 0
        end = np.minimum(start+batchsize,samplesize)
# try:
# end = np.min(start+batchsize,samplesize)
# except:
# print(end)
        sess.run(optimizer,feed_dict={xx:x[start:end],yy:y[start:end]})

        #实现每1000步保存一次模型
        if (i+1)%1000 == 0:
            saver.save(sess,'models\ckp',1000)
            print(sess.run(loss,feed_dict={xx:x[start:end],yy:y[start:end]}))
        start += batchsize
    print('after training, variable is %s,%s'%(sess.run(Weights),sess.run(bias)))

以下代码实现了每1000步保存一次模型

if (i+1)%1000 == 0:
    saver.save(sess,'models\ckp',1000)

之所以这样做,是为了防止意外情况下(比如训练时突然断电)下次训练需要从头开始训练。

保存的目录结构如下

|--models
|    |--checkpoint
|    |--ckp-1000.meta
|    |--ckp-1000.data-00000-of-00001
|    |--ckp-1000.index

三 模型恢复

首先加载保存的meta文件

saver = tf.train.import_meta_graph(file_name)

恢复参数,依赖于session,dir表示模型保存的目录路径,此时所有张量的值都在session中

saver.restore(session,tf.train.latest_checkpoint(dir))

获取恢复的参数,varname表示恢复的参数名,因此建议所有的参数都加上name属性

graph = sess.graph #sess所打开的图,所有的结构都在这个图上
graph.get_tensor_by_name(varname)

以下给出回归模型的恢复,并利用训练好的模型进行预测:

with tf.Session() as sess:
    saver = tf.train.import_meta_graph('models\ckp-1000.meta')
    saver.restore(sess,tf.train.latest_checkpoint('models'))
    graph = tf.get_default_graph()

    #恢复传入值
    xx = graph.get_tensor_by_name('xx:0')


    #计算利用训练好的模型参数计算预测值
    preds = graph.get_tensor_by_name('preds:0')
    print('predict values:%s' % sess.run(preds,feed_dict={xx:x}))

    原文作者:coder Amazing
    原文地址: https://zhuanlan.zhihu.com/p/33108465
    本文转自网络文章,转载此文章仅为分享知识,如有侵权,请联系博主进行删除。
点赞