本文分为如下四个部分:
Ⅰ.Tensorflow模型结构
Ⅱ.如何保存Tenflow模型
Ⅲ.如何读取模型并训练数据
Ⅳ.如何拓展模型与修正
一、Tensorflow模型结构
如果我们要存储一个完整的神经网络模型,我们需要存储graph和模型参数两个部分。
依此,我们存储的模型结构也主要分为两类文件。
1.图文件(Meta graph)
Google研发的二进制存储文件(protocol buffer),用于存储完整的图。
model.ckpt.meta
2.数据文件(存储latest Checkpoint)
在0.11版本之前,以.ckpt后缀存储模型的值。之后的版本中,以如下文件形式存储。
checkpoint #c最近存储记录
mymodel.data-00000-of-00001
mymodel.index
二、如何存储tensorflow模型
存储模型主要用到tf.train.Saver()类,其可以:
1全部参数或部分参数存储
2按照迭代次数或指定时间存储
3指定存储最近文件的数量
1.1最基础的存储代码如下:
import tensorflow as tf
#代码
saver=tf.train.Saver()
with Session() as sess:
saver.tf.train.Saver()
sess.run(tf.global_variables_initializer())
saver.save(sess,savepath+"my_model")
1.2如果只存储一部分参数:
saver=tf.train.Saver([name1,name2])
2.1如果按照迭代次数存储:
saver.save(sess,savepath+"my_model",global_step=1000)
#对应存储文件名字会自动加上-1000后缀如:
#my_model-1000.meta
#my_model-1000.index
#my_model-1000.data-00000-of-00001
#第一次存储是从第1000次开始的
2.2如果按照现实时间存储:
saver=tf.train.Saver(keep_checkpoint_every_n_hours=2)
#每两小时存储一次
3.1如果存储最近n次的模型:
saver=tf.train.Saver(max_to_keep=4)
#记录最近4次的模型
三、如何读取模型并训练数据
从第一部分得知,tensorflow模型分为graph和参数值两部分。对应存储文件也分为两类。依此模型的恢复也对应分为两个部分。
1.载入图
new_saver=tf.train.import_meta_graph("my_model.meta")
2.载入参数值
#载入最近一次存储模型
new_saver.restore(sess,tf.train.latest_checkpoint('./'))
#载入指定模型,注意这里没有后缀00000-of-00001
new_saver.restore(sess,"my_model")
3.获取placeholder和operations
其中值为定义图中的name,需要预先设置好并加:0
#将要学习的新数据
x = graph.get_tensor_by_name("x:0")
y_ = graph.get_tensor_by_name("y_:0")
keep_prob=graph.get_tensor_by_name("keepprob:0")
lout = graph.get_tensor_by_name("lout:0")
feed_dict ={x:data,y_:label}
4.根据需要进行继续投喂或验证
#loss = tf.sqrt(tf.reduce_mean(tf.square((lout-y_))))
#train=tf.train.GradientDescentOptimizer(lr).minimize(loss)
#sess.run(train,feed_dict={x:data[0],y_:data[1],keep_prob:0.8})
四、如何拓展与使用模型
1拓展图
lout = graph.get_tensor_by_name("lout:0")
nlout = tf.add(lout,2)
......
......
2使用部分图
......
......
fc3= graph.get_tensor_by_name('fc3:0')
fc3 = tf.stop_gradient(fc3) # 阻挡节点的梯度
fc3_shape= fc3.get_shape().as_list()
......
......
自此,对于tensorflow模型的读取和恢复就到这里了。
这个文章是基于某英文教程的精简和索引,旨在快速找到所需代码。
参考链接:
http://cv-tricks.com/tensorflow-tutorial/save-restore-tensorflow-models-quick-complete-tutorial/