一、python存储模型的方法
好了,进入正题,在python中如何存储tensorflow模型。
tf.saved_model.builder(推荐)
tf.saved_model是tensorflow官网推荐的一个保存模型的方法,只要你输入保存模型的路径,就可以使用。基本使用方式如下:
import tensorflow as tf
input=…
export_dir=…
…
build net…
…
#指定存储路径
builder =tf.saved_model.builder.SavedModelBuilder(export_dir)
with tf.Session() as sess:
#下段话只能调用一次
builder.add_meta_graph_and_variables(sess,[‘custom’])
builder.save()
其中,export_dir必须指定为一个不存在的路径,否则会报错。上面一段代码中,我们建立了一个名叫’custom’的网络,并将其保存在export_dir中,文件结构如下:
|-saved_model.pb-|
|-variables-|
|-variables.data-00000-of-00001-|
|-variables.index-|
其实和下一个方法tf.save存出来的文件差不多。pb文件中是网络结构信息,index文件中是参数值。
tf.train.saver
tf.train.saver是1.3版本之前主要的模型存储方式,在新版本中也兼容,但已经不是最推荐的方式了。它的使用方式也很简单:
import tensorflow as tf
input=…
model_path=…
…
build net
…
saver = tf.train.Saver()
with tf.Session() as sess:
saver.save(sess,model_path)
tf.save.saver产生的文件结构如下:
|-saved_model.meta-|
|-saved_model.data-00000-of-00001-|
|-saved_model.index-|
1
2
3
meta文件中存储网络结构,index文件中存储参数信息。
tf.saved_model.builder和tf.train.saver方法比较
tf.saved_model.builder方法:
优点:
1.只需要指定一个存储路径。存储、读取都很方便。
2.可以存多段网络,参数可以复用。比如现在有一个GAN网络模型,用tf.saved_model.builder指定相应tag以后,可以同时存生成网络、鉴别网络和整个网络。之后读取时,只要读需要的那一部分即可,大大加快读取速度。提升内存利用率。
3.在tensorflow推荐的estimate(一种更高级的机器学习API,以后填坑)流程中,扮演主要的模型存储方法。
4.便于分布式读取及使用。
缺点:
1.只能保存一次参数
2.对于一个目录,只能导出一个模型。(但可以改变目录名)
3.不灵活。
4.速度慢。
tf.train.saver方法:
优点:
1.灵活。可以指定保存模型的名称、后缀、多长时间保存一次、最多保存多少个模型等等。
2.应用范围广。如果你使用tf.contrib.Slim库(类似tensorlayer的一种高级库)训练模型,那么只能用此方法保存模型。
3.速度快。
缺点:
1.保存多个模型比较复杂。
二、python读取模型的方法
tensorflow读取模型的方法也很简单。我对应的介绍一下。
tf.saved_model.loader
如果你使用tf.saved_model.builder存储模型的话,那么可以使用tf.saved_model.loader读取模型。只输入一个模型存储的路径即可。简单的例子:
export_dir = …
…
build net…
…
with tf.Session(graph=tf.Graph()) as sess:
tf.saved_model.loader.load(sess, [‘custom’], export_dir)
…
可以看到,该方式读取模型非常简单,只需要模型路径和网络标签即可,函数内部会自动加载网络模型和恢复参数。
tf.train.saver.restore
该方法需要先恢复网络结构(如果你有了定义网络的py文件,可以跳过此步,等价的),再读取参数。简单的例子:
model_path=…
#恢复网络结构
saver = tf.train.import_meta_graph(model_path + ‘.meta’)
with tf.Session() as sess:
#读取参数
saver.restore(sess, model_path)
graph = sess.graph
input = graph.get_tensor_by_name(‘input:0’)
…
prediction…
…
python的模型存取方式就介绍到这里,更多有关tf.train.save和tf.saved_model的区别请点这里。
- 什么是TensorFlow模型?
训练了一个神经网络之后,我们希望保存它以便将来使用。那么什么是TensorFlow模型?Tensorflow模型主要包含我们所培训的网络参数的网络设计或图形和值。因此,Tensorflow模型有两个主要的文件:
a) Meta graph:
这是一个协议缓冲区,它保存了完整的Tensorflow图形;即所有变量、操作、集合等。该文件以.meta作为扩展名。
b) Checkpoint file:
这是一个二进制文件,它包含了所有的权重、偏差、梯度和其他所有变量的值。这个文件有一个扩展名.ckpt。然而,Tensorflow从0.11版本中改变了这一点。现在,我们有两个文件,而不是单个.ckpt文件:
mymodel.data-00000-of-00001
mymodel.index
.data文件是包含我们训练变量的文件,我们待会将会使用它。
与此同时,Tensorflow也有一个名为checkpoint的文件,它只保存的最新保存的checkpoint文件的记录。
因此,为了总结,对于大于0.10的版本,Tensorflow模型如下:
在0.11之前的Tensorflow模型仅包含三个文件:
inception_v1.meta
inception_v1.ckpt
checkpoint
现在我们已经知道了Tensorflow模型的样子.