TensorFlow有哪些有意思的接口设计

简介

TensorFlow是目前无可争议的、最流行的深度学习框架,其实有很多优良的接口设计,例如让用户无须感知底层实现的Swig Python API,还有可以用Python或者C++动态拓展的custom op(很方便集成到Serving),以及集成Model Signature的SavedModel模型格式。

当然TensorFlow也不乏让萌新或编程老人都感到疑惑的设计,例如多GPU训练需要用户实现tower结构和指定device,分布式训练需要自己写PS虽然几乎只有join一行代码(PS默认训练完不退出需要自己用TensorFlow Queue来解决),还有推荐使用的Dataset接口必须加上while(True) + except(OutOfRangeError)来配合使用,同时还有Keras、Estimator、Learn、Slim等各种类同的High-level API…… 当然这些“不合理”的接口设计是有存在的理由,作为最底层的建模语言暴露with device和PS/worker接口让开发者有可能实现更复杂的模型并行、数据并行、in-graph、between-graph结构。

最近一个月我们发现几个同样“有意思”的接口设计,在Github issue上有更多讨论,供大家学(tu)习(cao)。

1. FLAGS使用时解析只打印默认值

写过TensorFlow脚本的一般都用过tf.app.flags,通过TensorFlow的接口来定义脚本的命令行参数以及默认值,用法类似Python官方的argparse和configparser等。这几乎是所有官方TensorFlow代码的“最佳实践”,示例代码也很简单。

import tensorflow as tf

flags = tf.app.flags
flags.DEFINE_integer("image_width", 224, "Width of the image")
flags.DEFINE_integer("image_height", 224, "Height of the image")
flags.DEFINE_integer("channels", 3, "Channel of the image")
FLAGS = flags.FLAGS

# FLAGS.image_width

一般这样使用FLAGS是没有问题,也不必关心底层是基于argparse还是其他库实现,而当我们尝试运行时打印一下参数值时触发了一个“Bug”,就是一开始打印FLAGS内容只会取得默认值但使用一次参数后再次打印就会出现覆盖后的值,详见Issue The default values of tf.app.flags are printed event though passed parameters at the first time · Issue #20680 · tensorflow/tensorflow

《TensorFlow有哪些有意思的接口设计》
《TensorFlow有哪些有意思的接口设计》

触发这个问题有两个条件,一个是代码严谨的我会在脚本启动时打印覆盖或者默认的参数,第二个是tf.app.flags基于absl.flags实现的约定有点搓。首先介绍前者,为了避免命令行调用时传入错误的参数我们建议运行时检查,不同模型超参不同一般可以通过FLAGS对象获得所有key-value对,这里获取获取value的唯一方法就是FLAGS.__flags[key].value,代码如下。

FLAGS = flags.FLAGS
parameter_value_map = {}
for key in FLAGS.__flags.keys():
  parameter_value_map[key] = FLAGS.__flags[key].value
print("Parameters: {}".format(parameter_value_map))
# Parameters: {'channels': 3, 'image_height': 224, 'image_width': 224}

这里获得的value都是在Python代码定义的默认值,不管在命令行是否传入参数,显然达不到检查的效果,而我们只要随意调用FLAGS.channels、FLAGS.image_heightFLAGS.image_width任意一行代码,value就会更新为覆盖值,这就是“使用时解析”。这时我们去看FLAGS的代码实现,可以发现TensorFlow使用Google工程师的另一个开源项目abseil-py的参数解析代码 https://github.com/abseil/abseil-py/blob/master/absl/flags/_flag.py ,在主动调用这个类对象时(也就是__call__()函数)可以解析sys.argv的参数。因此无论是TensorFlow还是abseil都是通过读取Python的sys.argv来获得命令行参数的,而且有一个显式的parse过程,TensorFlow对abseil还有个wrapper封装,原生的在不parse前获取value会抛异常,而封装后在获取value时如果没有parse就做一次parse,代码就是最好的解析 https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/platform/flags.py#L79

《TensorFlow有哪些有意思的接口设计》
《TensorFlow有哪些有意思的接口设计》

绝大部分人在使用tf.app.flags时都不会去看源代码,也不知道TensorFlow每次在读参数值时都会检查参数是否被parse,这样的接口设计和实现让普通用户无须感知什么时候parse参数直接用就可以拿到最新值,却在性能上有所损失,可能留了一个坑让你在未调用任何值之前通过FLAGS.__flags[key].value会拿到默认值。目前Issue在讨论Pull-request未有实现思路。

2. 多HashTable从Checkpoint恢复被覆盖

对于TensorFlow的Checkpoint可以保存模型参数的作用很多人都知道,除了可以保存神经网络的矩阵权重外,还可以保存key-value对的HashTable,TensorFlow提供了tf.contrib.lookup大量op来实现。例如我们尝试将一个训练样本的可读label(字符串型)与训练label(整数型)以MutableHashTable的方式保存到Checkpoint和模型中,为了实现双向转化一般会有string-to-int和int-to-string两个哈希表,我们组同事就发现了多HashTable恢复被覆盖的Issue https://github.com/tensorflow/tensorflow/issues/19528

《TensorFlow有哪些有意思的接口设计》
《TensorFlow有哪些有意思的接口设计》

用户在定义两个MutableHashTable后,如果不指定name,可以直接把变量导出到Checkpoint中,然后写代码从Checkpoint中恢复,示例如下。

keys = tf.placeholder(dtype=tf.string, shape=[None])
values = tf.placeholder(dtype=tf.int64, shape=[None])
table1 = tf.contrib.lookup.MutableHashTable(tf.string, tf.int64, -1)
table2 = tf.contrib.lookup.MutableHashTable(tf.string, tf.int64, -1)
insert_table1 = table1.insert(keys, values)
insert_table2 = table2.insert(keys, values)
saver = tf.train.Saver()
with tf.Session() as sess:
    sess.run(insert_table1, feed_dict={keys: ["a"], values: [1]})
    sess.run(insert_table2, feed_dict={keys: ["b"], values: [2]})
    print "table1:", sess.run(table1.export())
    print "table2:", sess.run(table2.export())
    saver.save(sess, "checkpoint/test")

整个过程不会报任何错误,但从Checkpoint恢复的结果看到,tabel1的值是空的,而table2可以正常恢复。这又是一次入坑TensorFlow源代码的好机会,我们去看tf.contrib.lookup.MutableHashTable的Python代码实现,这个类里面是有name属性的,而且name会提供一个默认值“MutableHashTable”,如果后续导出Checkpoint会基于传入或者默认的这个name来标识 https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/lookup/lookup_ops.py#L289

《TensorFlow有哪些有意思的接口设计》
《TensorFlow有哪些有意思的接口设计》

因此前面其实强调过了,如果定义了多个HashTable而不指定name的话就会出现被覆盖的非预期行为,如果用户习惯好为每个op定义unique name就不会触发这个异常了。这个场景我是觉得非常熟悉的,一般TensorFlow用户写脚本通常不会为每个tf.add()、tf.muliple()定义op name的,尤其我们用了TensorFlow给我们写好的Python操作符重载,但TensorFlow的Graph和Checkpoint是非常依赖op name的,因此TensorFlow的Python API会为用户的每一个op设置一个unique name,即使你定义的时候传入一个有冲突的name也会通过后缀的方式改为unique name。这个逻辑是在TensorFlow Python端实现的,也就是说如果直接调C++ API或者其他语言API需要注意下。

同样的逻辑在 https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/ops/control_flow_ops.py#L1680 也有体现,但所有的Table目前都只提供一个固定的默认值,如果多个Table同时导出而不指定name会因为被覆盖问题引发更多逻辑的错误。目前Issue在讨论Pull-request已发(通过类似Variable的分配unique op name方案能否merge需要讨论) https://github.com/tensorflow/tensorflow/pull/20657

3. Dataset保存Epoch进度与Shuffle冲突

Dataset是目前主推的数据读取接口,支持丰富的功能,包括指定epoch number、batch size、是否shufle、是否cache等诸多高级功能,而且可以配合自定义map函数实现TFRecords、CSV等不同数据格式文件的解析读取,为了避免filename过多存入Graph过大还提供了placeholer动态加载文件名等功能。说了这么多,有一个问题是用户会在TensorFlow脚本中加入“断点延跑”的功能,一般会把训练进度以及参数存储Checkpoint中,而Dataset是知道用户训练的epoch index以及batch index的,在Dataset设计之初也考虑到这点实现SaveInternal接口 https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/kernels/data/tensor_slice_dataset_op.cc#L136

问题是我们在使用时触发了这个Issue,会直接导致模型尝试加入Dataset的参数后保存Checkpoint失败,详见Issue Fail to save the checkpoint with dataset Iterator saveable when using shuffle · Issue #18583 · tensorflow/tensorflow

《TensorFlow有哪些有意思的接口设计》
《TensorFlow有哪些有意思的接口设计》

在我们排查后发现,基于placeholer的TensorSliceDataset确实可以把进度保存到Checkpoint中,但如果在dataset中加入shuffle op,导出失败。发现导因就很容易理解了,用户如果加入了shuffle,训练的数据在一个buffer size中打乱,Dataset可以保证一个epoch内所有数据都被使用到了,但如果只是把进度保存到Checkpoint里,下次运行重新shuffle可能会遗漏一部分数据。

这个问题也不是一定无解,一种方案是按现在的API实现,如果shuffle就无法把Dataset的iterator变量保存到Checkpoint,用户可以提前准备乱序的数据集不使用shuffle来规避。如果非要在API支持我也有两种思路,一是把Dataset iterator不保证每个epoch可以处理所有数据,这样恢复Checkpoint只需要重新做shuffle而不管数据是否被处理过,另一种是把shuffle结果也保存到Checkpoint中,恢复时可以感知上次shuffle结果就可以正常restore了,当然这两种方案即使实现出来也会有用户来写文章来喷的。目前Issue在讨论Pull-request未有思路。

4. 其他

除了最近我们提的这些Issue,还有一些接口设计也是让人觉得“有意思”。例如以前我们在社区提过,分布式TensorFlow的Parameter server能否在训练结束后自动退出,社区的回复是不能,但你可以实现成这样。怎么实现呢?一般我们实现分布式TensorFlow的PS,包括官方文档介绍也是会调用server.join(),一看名字就知道这个函数是阻塞的,因此即使所有worker训练完都结束后PS也会继续阻塞不会退出。那么社区有人思路新奇,阻塞的函数不只是join(),我们可以用TensorFlow API提供的queue,PS可以从queue中获取足够数据的东西(这个数量就是Worker的个数),然后worker在训练结束后在相同队列塞数据,这样PS会一直在队列阻塞,直到所有worker都结束了才从队列中拿到足够多的东西然后退出Python进程。代码是这样的感兴趣可以看看 https://github.com/tobegit3hub/distributed_tensorflow/blob/master/auto_stop_ps/task.py

# If is PS
queue = worker_done_queues[task_index]
dequeue_op = queue.dequeue()

for i in range(master_worker_number):
  sess.run(dequeue_op)
  logging.info("{} workers are already done".format(i + 1))


# If is Worker
enqueue_ops = []
for queue in worker_done_queues:
  enqueue_op = queue.enqueue(1)
  enqueue_ops.append(enqueue_op)

for enqueue_op in enqueue_ops:
  sess.run(enqueue_op)

分布式TensorFlow接口还有一个值得吐槽的,就是可以使用种类繁多的Session封装,例如Supervisor和MonitoredTrainingSession。使用Supervisor在网络抖动的情况下可能导致与PS或者Worker失败而无法使用Session,而使用MonitoredTrainingSession获得的Session对象,实际上因为类型不匹配不能直接用官方的saved_model_builder来导出SavedModel模型。脑洞新奇的网友想到了在MonitoredTrainingSession内把模型参数导出到本地Checkpoint,然后new tf.Session()来加载Checkpoint导出模型。当然我们用的是另一种思路解决这个问题,在MonitoredTrainingSession加入了一个SavedModelHook,注意是加在chief_only_hooks而不是普通hooks里(避免所有Worker都去Export一遍),然后继承了SessionRunHook API在end()的时候获取Session对象来保存模型,所以当你看到下面的代码也不用经验,这里面凝聚了TensorFlow API设计师的智慧。

class SavedModelHook(tf.train.SessionRunHook):
  def end(self, session):
    saved_model(session, model_path, FLAGS.model_version,
                model_signature, legacy_init_op)

chief_hooks = [SavedModelHook()]

with tf.train.MonitoredTrainingSession(
    master=server.target,
    is_chief=(task_type == "master"),
    chief_only_hooks=chief_hooks) as mon_sess:

总结

前面对TensorFlow的部分API设计有一些调侃,当然作为开发者,我们也深知里面涉及的天生缺陷以及实现难度。TensorFlow作为一个“全能”的深度学习框架,除了基本的autograd功能,还实现了一个DAG管理和control flow,里面还有通用的queue队列来实现filename管理和shuffle data管理(包括让用户基于queue实现auto ps exit),和通用的Filesystem API支持local、hdfs、s3等各种文件系统,还有能在CPU、GPU、TPU统一接口的大量operator、layer抽象,而且上述的这些都需要在op上实现,才能保证在一个C++ tensorflow::ClientSession中运行。

因此所有的接口设计很难在易用性和拓展性都做到完美,例如前面介绍的FLAGS在__getattr__都检查和parse一遍的约定就可能导致其他API无法感知,而很多接口在拓展性强的时候使用起来也稍微比较复杂,例如Distributed TensorFlow确实很复杂以后可以单独展开介绍。这时我们可以多给TensorFlow社区提Issue和Pull-request,学习TensorFlow源代码的过程就是最好的使用TensorFlow过程,期待更多TensorFlow Contributor加入(帮我们解决上面几个问题)。

    原文作者:tobe
    原文地址: https://zhuanlan.zhihu.com/p/39625368
    本文转自网络文章,转载此文章仅为分享知识,如有侵权,请联系博主进行删除。
点赞