如何快速保存和读取Tensorflow模型

本文分为如下四个部分:

Ⅰ.Tensorflow模型结构

Ⅱ.如何保存Tenflow模型

Ⅲ.如何读取模型并训练数据

Ⅳ.如何拓展模型与修正

一、Tensorflow模型结构

如果我们要存储一个完整的神经网络模型,我们需要存储graph和模型参数两个部分。

依此,我们存储的模型结构也主要分为两类文件。

1.图文件(Meta graph)

Google研发的二进制存储文件(protocol buffer),用于存储完整的图。

model.ckpt.meta

2.数据文件(存储latest Checkpoint)

在0.11版本之前,以.ckpt后缀存储模型的值。之后的版本中,以如下文件形式存储。

checkpoint				#c最近存储记录
mymodel.data-00000-of-00001		
mymodel.index

二、如何存储tensorflow模型

存储模型主要用到tf.train.Saver()类,其可以:

1全部参数或部分参数存储

2按照迭代次数或指定时间存储

3指定存储最近文件的数量

1.1最基础的存储代码如下:

import tensorflow as tf
#代码
saver=tf.train.Saver()
with Session() as sess:
    saver.tf.train.Saver()
    sess.run(tf.global_variables_initializer())
    saver.save(sess,savepath+"my_model")

1.2如果只存储一部分参数:

saver=tf.train.Saver([name1,name2])

2.1如果按照迭代次数存储:

saver.save(sess,savepath+"my_model",global_step=1000)
#对应存储文件名字会自动加上-1000后缀如:
#my_model-1000.meta
#my_model-1000.index
#my_model-1000.data-00000-of-00001
#第一次存储是从第1000次开始的

2.2如果按照现实时间存储:

saver=tf.train.Saver(keep_checkpoint_every_n_hours=2)  
#每两小时存储一次

3.1如果存储最近n次的模型:

saver=tf.train.Saver(max_to_keep=4)  
#记录最近4次的模型

三、如何读取模型并训练数据

从第一部分得知,tensorflow模型分为graph和参数值两部分。对应存储文件也分为两类。依此模型的恢复也对应分为两个部分。

1.载入图

new_saver=tf.train.import_meta_graph("my_model.meta")

2.载入参数值

#载入最近一次存储模型
new_saver.restore(sess,tf.train.latest_checkpoint('./'))
#载入指定模型,注意这里没有后缀00000-of-00001
new_saver.restore(sess,"my_model")

3.获取placeholder和operations

其中值为定义图中的name,需要预先设置好并加:0

#将要学习的新数据
x = graph.get_tensor_by_name("x:0")
y_ = graph.get_tensor_by_name("y_:0")
keep_prob=graph.get_tensor_by_name("keepprob:0")
lout = graph.get_tensor_by_name("lout:0")
feed_dict ={x:data,y_:label}

4.根据需要进行继续投喂或验证

#loss = tf.sqrt(tf.reduce_mean(tf.square((lout-y_))))
#train=tf.train.GradientDescentOptimizer(lr).minimize(loss)
#sess.run(train,feed_dict={x:data[0],y_:data[1],keep_prob:0.8})  

四、如何拓展与使用模型

1拓展图

lout = graph.get_tensor_by_name("lout:0")
nlout = tf.add(lout,2)
......
......

2使用部分图

......
......
fc3= graph.get_tensor_by_name('fc3:0')
fc3 = tf.stop_gradient(fc3) # 阻挡节点的梯度
fc3_shape= fc3.get_shape().as_list()   
......
......

自此,对于tensorflow模型的读取和恢复就到这里了。

这个文章是基于某英文教程的精简和索引,旨在快速找到所需代码。

参考链接:

http://cv-tricks.com/tensorflow-tutorial/save-restore-tensorflow-models-quick-complete-tutorial/

    原文作者:瓦娜与欧洛因
    原文地址: https://zhuanlan.zhihu.com/p/36180893
    本文转自网络文章,转载此文章仅为分享知识,如有侵权,请联系博主进行删除。
点赞