TensorFlow加载部分模型的一个技巧

TensorFlow在save/restore模型时有官方的函数,具体可参考tensorflow保存和恢复模型的两种方法介绍,但是我最近碰到的问题显然超出了这个官方函数的范畴,先描述下这个问题吧:

模型可以简单描述为Input->A->B->Output 1,这是个基本模型。某一天我心血来潮,想加一个分支,另外一个分支是这样的:Input->A->C->Output 2,因为C路径非常短,而且根据大多数人的通常做法,A里面的参数都是可以共用的吗,所以我想我可不可以这样搞呢:把Input->A->B->Output 1只有1个分支的模型参数restore到{Input->A->B->Output 1, Input->A->C->Output 2}有2个分支的模型里。

结果是Error:

NotFoundError (see above for traceback): Key */* not found in checkpoint

一种解决方法可以参考tensorflow 加载部分变量,得写变量名,鉴于我的变量名实在是太多了,而且每次都要人工去写显得好low啊,我觉得抛弃这种方法了。

我想这种应用太常见了,TensorFlow官方应该会考虑到这点吧,所以就找吧。

首先找了到官方文档python/tf/train/Saver#__init__,找了一圈是没有的,直接找restore的函数吧,python/tf/train/Saver#restore,里面的参数少的有种想打人的冲动:

restore(
    sess,
    save_path
)

这不符合我的习惯,我想我一定是哪个参数没看到,还是直接进源代码找吧,里面肯定有,源代码在python/training/saver.py,restore()函数从1243行开始,函数很短,恢复参数的代码为第1271~1276行:

try:
      if context.executing_eagerly():
        self._build_eager(save_path, build_save=False, build_restore=True)
      else:
        sess.run(self.saver_def.restore_op_name, {self.saver_def.filename_tensor_name: save_path})

然后发现官方真是想的很周全,有三种NotFoundError的情况:

except errors.NotFoundError as err:
      # There are three common conditions that might cause this error:
      # 0. The file is missing. We ignore here, as this is checked above.
      # 1. This is an object-based checkpoint trying name-based loading.
      # 2. The graph has been altered and a variable or other name is missing.

第2种情况就是我所面临的情况了:

except errors.NotFoundError:
        # 2. This is not an object-based checkpoint, which likely means there
        # is a graph mismatch. Re-raise the original error with
        # a helpful message (b/110263146)
        raise _wrap_restore_error_with_msg(
            err, "a Variable name or other graph key that is missing")

结论就是依然会报错,所以我要吐槽TensorFlow的一个点就是知道了还不给我解决。。。

所以官方的这个就得抛弃了,这时候我想起了我以前写的那篇文章[代码分享]将ckpt模型文件转成npy模型文件,思路就是:先读取ckpt文件至一个map中,然后逐渐将map中的节点的值assign给模型,代码写好了是这样:

# restore the network using another method
# first change the checkpoint to variable map, then assign single variable to the graph one by one
ckpt = tf.train.get_checkpoint_state(os.path.join(self.cfg.OUTPUT_DIR, self.element))
from tensorflow.python import pywrap_tensorflow
reader = pywrap_tensorflow.NewCheckpointReader(ckpt.model_checkpoint_path)
var_to_shape_map = reader.get_variable_to_shape_map()
for key in var_to_shape_map:
	try:
		with tf.variable_scope(key.split("/")[0], reuse=tf.AUTO_REUSE):
			var = tf.get_variable(key.split("/")[1])
			self.sess.run(var.assign(reader.get_tensor(key)))
			_print ('assign pretrain model to ' + key)
	except ValueError as e:
		_print (e)
		_print ('ignore ' + key)

最后能够完美执行了,结果和官方函数tf.saver.restore()差不多,差的一点在于这样的报错:

Shape of a new variable (*/biaes) must be fully defined, but instead was <unknown>.

原因在于我写这个偏置的时候维度并没有给定,官方函数tf.saver.restore()恢复的时候会把计算图也恢复进来(tf保存的几个文件的意思可以参考TensorFlow学习笔记(8)–网络模型的保存和读取 – 对角巷 – CSDN博客),而我只是把参数assign了一遍,一个小小的遗憾吧,在我的代码里下降了0.2个百分点。

【已完结】

    原文作者:Michael
    原文地址: https://zhuanlan.zhihu.com/p/51906234
    本文转自网络文章,转载此文章仅为分享知识,如有侵权,请联系博主进行删除。
点赞