TensorFlow在save/restore模型时有官方的函数,具体可参考tensorflow保存和恢复模型的两种方法介绍,但是我最近碰到的问题显然超出了这个官方函数的范畴,先描述下这个问题吧:
模型可以简单描述为Input->A->B->Output 1,这是个基本模型。某一天我心血来潮,想加一个分支,另外一个分支是这样的:Input->A->C->Output 2,因为C路径非常短,而且根据大多数人的通常做法,A里面的参数都是可以共用的吗,所以我想我可不可以这样搞呢:把Input->A->B->Output 1只有1个分支的模型参数restore到{Input->A->B->Output 1, Input->A->C->Output 2}有2个分支的模型里。
结果是Error:
NotFoundError (see above for traceback): Key */* not found in checkpoint
一种解决方法可以参考tensorflow 加载部分变量,得写变量名,鉴于我的变量名实在是太多了,而且每次都要人工去写显得好low啊,我觉得抛弃这种方法了。
我想这种应用太常见了,TensorFlow官方应该会考虑到这点吧,所以就找吧。
首先找了到官方文档python/tf/train/Saver#__init__,找了一圈是没有的,直接找restore的函数吧,python/tf/train/Saver#restore,里面的参数少的有种想打人的冲动:
restore(
sess,
save_path
)
这不符合我的习惯,我想我一定是哪个参数没看到,还是直接进源代码找吧,里面肯定有,源代码在python/training/saver.py,restore()函数从1243行开始,函数很短,恢复参数的代码为第1271~1276行:
try:
if context.executing_eagerly():
self._build_eager(save_path, build_save=False, build_restore=True)
else:
sess.run(self.saver_def.restore_op_name, {self.saver_def.filename_tensor_name: save_path})
然后发现官方真是想的很周全,有三种NotFoundError的情况:
except errors.NotFoundError as err:
# There are three common conditions that might cause this error:
# 0. The file is missing. We ignore here, as this is checked above.
# 1. This is an object-based checkpoint trying name-based loading.
# 2. The graph has been altered and a variable or other name is missing.
第2种情况就是我所面临的情况了:
except errors.NotFoundError:
# 2. This is not an object-based checkpoint, which likely means there
# is a graph mismatch. Re-raise the original error with
# a helpful message (b/110263146)
raise _wrap_restore_error_with_msg(
err, "a Variable name or other graph key that is missing")
结论就是依然会报错,所以我要吐槽TensorFlow的一个点就是知道了还不给我解决。。。
所以官方的这个就得抛弃了,这时候我想起了我以前写的那篇文章[代码分享]将ckpt模型文件转成npy模型文件,思路就是:先读取ckpt文件至一个map中,然后逐渐将map中的节点的值assign给模型,代码写好了是这样:
# restore the network using another method
# first change the checkpoint to variable map, then assign single variable to the graph one by one
ckpt = tf.train.get_checkpoint_state(os.path.join(self.cfg.OUTPUT_DIR, self.element))
from tensorflow.python import pywrap_tensorflow
reader = pywrap_tensorflow.NewCheckpointReader(ckpt.model_checkpoint_path)
var_to_shape_map = reader.get_variable_to_shape_map()
for key in var_to_shape_map:
try:
with tf.variable_scope(key.split("/")[0], reuse=tf.AUTO_REUSE):
var = tf.get_variable(key.split("/")[1])
self.sess.run(var.assign(reader.get_tensor(key)))
_print ('assign pretrain model to ' + key)
except ValueError as e:
_print (e)
_print ('ignore ' + key)
最后能够完美执行了,结果和官方函数tf.saver.restore()差不多,差的一点在于这样的报错:
Shape of a new variable (*/biaes) must be fully defined, but instead was <unknown>.
原因在于我写这个偏置的时候维度并没有给定,官方函数tf.saver.restore()恢复的时候会把计算图也恢复进来(tf保存的几个文件的意思可以参考TensorFlow学习笔记(8)–网络模型的保存和读取 – 对角巷 – CSDN博客),而我只是把参数assign了一遍,一个小小的遗憾吧,在我的代码里下降了0.2个百分点。
【已完结】