tensorflow,存储读取数据结构剖析与合并多个graph,看不懂你掐死我

我又一次开始了“看不懂你掐死我系列”。标题名称是仿照知乎的一篇介绍傅里叶变换的文章起的。当时看完了觉得还真看懂了。可是关上网页再自己想的时候,就有想掐死博主的冲动~~ 为了致敬,这里贴出原文章,大家共勉。

《tensorflow,存储读取数据结构剖析与合并多个graph,看不懂你掐死我》 网上下载的图片不认识,如有冒犯请联系我换图片

抄袭标题:看不懂傅里叶变换就掐死他

这段时间做训练的时候需要分步训练不同的网络结构,最后把所有训练好的graph合并成一个大graph,前后接起来并且重新定义输入和输出再继续训练,这样分步先训练小网络再合成大网络的话效果会好一点,收敛的也会快一些。那么有个问题,怎么把训练好的好几个graph恢复训练参数再合并到一起呢?Tensorflow到底能不能这么做?如果能,那应该怎么做?在读了这篇这篇知乎,和搜了无数个stackoverflow上的例子之后,终于有了答案。

要知道我们需要把每个pretrain的网络的结构和参数全都读进去,再把它们合并在一起。先不说合并的事,读取参数和结构就是个问题。比如下边这几个stackoverflow的帖子。12345。他们都用了不同的读取方法。但是到底读取的是什么?有没有达到我们预期的目的却不清楚。所以我意识到先要把tensorflow的内部结构搞清楚,看看存有什么东西,再看看存储和读取的方式。先来看结构。

Tensorflow的内部结构:

上面的这篇知乎都说的挺清楚的,我就捡这最重要的总结一下。

我们都知道tensorflow里有graph,graph的节点就是运算operation。这个用tensorboard可视化可以看到。比如下面这就是个简单的graph。

《tensorflow,存储读取数据结构剖析与合并多个graph,看不懂你掐死我》 graph示例

这个graph在tensorflow里实际的存储方式是被序列化以后,以Protocol Buffer的形式存储的。这里有中文的对protobuf的介绍,是google开发的。

graph序列化的protobuf叫做graphDef,就是define graph的意思,一个graph的定义。这个graphDef可以用tf.train.write_graph()/tf.Import_graph_def()来写入和导出。上面stackoverflow里就有人用这个方法。然而graphDef里面其实是没有存储变量的,但是可以存常量,就是constant。可以用一种叫freeze_graph的工具把变量替换成常量,这里有官方的介绍。一般来说没有必要这么做,因为既然存了网络,肯定有变量的信息,虽然不在graphDef里面,但是肯定在别的地方。其实它存在collectionDef里。还有一些其他的Def,所以干脆归纳一下:

MetaGraphMetaInfoDef 这个是存metadata的,像版本信息啊,用户信息啥的

                    – GraphDef 上面说的就是这个GraphDef

                    – SaverDef 这个就是tf.train.Saver的saver

                    – CollectionDef

这些Def的数据都存在一个叫MetaGraph的文件里。这个MetaGraph有官方介绍

最后面的collectionDef就是各种集合。每个集合里都是1对多的key/value pairs。你也可以把你想要的变量存进某个即合理,用tf.add_to_collection(collection_name,变量)就行。然后再用tf.get_collection()取出来。比如我有loss和train_op,就可以:

tf.add_to_collection(“training_collection”,loss)

tf.add_to_collection(“training_collection”,train_op)

然后再用

Train_collect = tf.get_collection(“training_collection”)  #得到一个python list

list里面就是你之前存的东西。所以collection我的理解就是为了方便管理变量用的。

metagraph可以用export_meta_graph/Import_meta_graph来导入导出。

这里注意了,如果你用tf.Import_graph_def()导入graphDef的话,导入的东西一般是不能训练的。但是用Import_meta_graph来导入metagraph之后,就是导入了一个完整的结构,这时候是可以训练的

虽然能训练,metagraph里也有变量,但是都是起始值。也就是说我们之前训练的参数是没有导入的。这里训练等于是从头训练。实际的训练参数没有存在metagraph里,而是在data文件里。这个下面会提到。

说完了tensorflow的结构,再说说存储的方式。看完这节,你应该完全知道什么api是用来读什么的了。

存储与读取:

上面那篇中文知乎恰好总结了这些。一般存读有3个API:

tf.train.Saver()/saver.restore()

export_meta_graph/Import_meta_graph

tf.train.write_graph()/tf.Import_graph_def()

后两个上一节都见过了。现在说说第一个。

我平时常用的只有第一个tf.train.Saver()和saver.restore()。我也看到很多代码里这么写。但是有一点很坑爹的是tf.train.saver.save() 什么都保存。但是在恢复图时,tf.train.saver.restore() 只恢复 Variable,如果要从MetaGraph恢复图,需要使用 import_meta_graph。看明白了吗?saver.save()和saver.restore()保存和读取的东西不!一!样!也就是说如果我想重组graph,要么用Import_meta_graph来导入graph,之后再saver.restore();要么就从新建立graph,把tensor传入结构的过程再写一遍,然后再saver.restore()。不然连变量名都找不到肯定会报错。

说道存储,我们必须得看看存储文件的格式。如果你用saver.save()保存的话(好像也只有这一种方法),打开你的保存文件夹,你会看到4种后缀名的文件(events开头的不算,那是tf.summary生成给tensorboard用的),分别是:

checkpoint – 就是一个账本文件,可以使用高级帮助程序来加载不同的时间保存的chkp文件。没什么用

.meta – 保存压缩后的Metagraph的protobufs,其实就是Metagraph。

.index – 包含一个不可变的键值表,用于链接序列化的张量名称以及在chkp.data文件中查找其数据的位置,也没存什么实际东西

.data – 这个里面才是存了训练后的参数。通常比.meta要大。有的时候有多个data文件用于共享或创建多个训练的时间戳。

其中.data文件的名字一般都是这种格式的:

<prefix>-<global_step>.data-<shard_index>-of-<number_of_shards>.

比如:

《tensorflow,存储读取数据结构剖析与合并多个graph,看不懂你掐死我》 存储名的例子

所以saver.restore()的时候其实是restore的.data文件。当然在restore之前可以用tf.train.latest_checkpoint()来得到最后一次存储点。还有一点是在saver.save()和restore的时候,那个文件对象是xxx.ckpt。但实际上在存储文件夹里你找不到xxx.ckpt文件。这个也是正常的。官方文档有说.ckpt文件其实是隐性的的。所以除非你文件名字输入错了,不然不用担心读错文件。

下面结合我的实例再看看怎么合并graph。

实例:

先稍微介绍一下网络的结构。我有四个网络结构。其中3个网络是平行的,这里就叫p1,p2和p3吧。最后一个网络是微调用的,就叫m吧。这个m会得到3个网络的输出,合并在一起作为m的输入,输入到m,最后得到最终结果。为了方便理解我画了个图。

《tensorflow,存储读取数据结构剖析与合并多个graph,看不懂你掐死我》 总体网络结构示意图

如果直接训练这么大的网络,收敛起来一定很费劲,有可能某一个网络落到一个local minimum就出不去了。所以我们把p1,p2,p3拿出来单独训练,每次只训练一个。

我分别用数据训练这3个网络。这个训练阶段算是pretrain。待到三个网络都稳定的时候,我把它们的输出结果加在一起,输入到第四个网络里训练整个网络。

官方文件称feed_dicts是效率最低的方法,所以我们改用的tfrecord和dataset api来读取文件。如果你不清楚这是啥,可以参看我们办公室博导的简书,这家伙可厉害了~

现在有两个问题,1是用Import_meta_graph导入metagraph的方法没法合并graph,因为我写的数据导入之后拿不出来(或者说我不知道怎么拿出来,可能有api可以取出来)。p1,p2,p3的输出数据是要手动连接的。import_graph_def()也可以设置input,output mapping,但是我这里没有tf.placeholder。我必须拿到一个从p1,p2,p3合成出来的tensor,再塞到m里去。所以我选择了用重建graph的方法。用

traindata, label = data_iterator(tfrecord_path).get_next() 

得到数据,再把traindata分别放入p1,p2,p3的架构中:

out_p1 = networkp1(trandata_p1)

网络结构有了,再restore参数:

full_path = tf.train.latest_checkpoint(model_ckp)

saver.restore(sess, full_path)

p2和p3也这么做。

三个全恢复了会得到三个output,再合并

m_data = out_p1 + out_p2 + out_p3

再输入m中:

output_m = networkm(m_data)

之后再做loss,bp,summary啥的,就可以训练了。

《tensorflow,存储读取数据结构剖析与合并多个graph,看不懂你掐死我》

需要注意的是,别恢复错了graph。不要建3个session下分别用3个graph恢复,因为那样到

m_data = out_p1 + out_p2 + out_p3 #如果三个out是不同的graph,这里会报错

这一步会报错。说不同的graph出来的结果是不能相互运算的。大家必须是在同一个graph里才行。所以要建一个session,在这个session下挨个恢复:

with tf.session as sess: # 下面每个restore里不要单建 with tf.graph():… 

    # restore p1

    # restore p2

    # retore p3

    # ….

等于是把大家依次放进default graph里。再填上最后的m就ok了。

references:

https://blog.metaflow.fr/tensorflow-saving-restoring-and-mixing-multiple-models-c4c94d5d7125

https://zhuanlan.zhihu.com/p/31308381

https://www.tensorflow.org/api_guides/python/meta_graph#What_s_in_a_MetaGraph

https://www.jianshu.com/p/0f9f2bb962f4

stackoverflow:

https://stackoverflow.com/questions/41990014/load-multiple-models-in-tensorflow

https://stackoverflow.com/questions/45093688/how-to-understand-sess-as-default-and-sess-graph-as-default

https://stackoverflow.com/questions/49864234/tensorflow-restoring-variables-from-two-checkpoints-after-combining-two-graphs

https://stackoverflow.com/questions/49490262/combining-graphs-is-there-a-tensorflow-import-graph-def-equivalent-for-c

https://stackoverflow.com/questions/41607144/loading-two-models-from-saver-in-the-same-tensorflow-session

    原文作者:木木爱吃糖醋鱼
    原文地址: https://www.jianshu.com/p/ca637520002f
    本文转自网络文章,转载此文章仅为分享知识,如有侵权,请联系博主进行删除。
点赞