TensorFlow 的 Graph 计算流程控制

前两节我们了解到了 Graph 用来构建计算图,那么如果我们要构建比较复杂的计算图,需要使用一些什么操作呢?

TensorFlow 提供了几种操作和类用来控制计算图的执行流程计算条件依赖,我们来具体看看都有哪些方法以及它们的具体含义。

主要有如下计算图流控制方法,我们一个个讲解:

tf.identity

它返回一个和输入的 tensor 大小和数值都一样的 tensor ,类似于 y=x 操作,我们通常可以查到以下使用示例:

import tensorflow as tf

g = tf.Graph()
with g.as_default():
    x = tf.Variable(1.0, name='x')
    x_plus_1 = tf.assign_add(x, 1, name='x_plus')

    with tf.control_dependencies([x_plus_1]):
        y = x
        z=tf.identity(x,name='z_added')

    init = tf.global_variables_initializer()

    with tf.Session() as sess:
        sess.run(init)
        for i in range(5):
            print(sess.run(z))
            # 输出 2,3,4,5,6

        # 如果改为输出 print(sess.run(y)) ,则结果为 1,1,1,1,1

但是所有的博客资料都没有详细说明为什么会这样,以及在 Graph 的构建中,这个方法应该怎么使用。

起初一看,这个方法好像并没有什么用,只是把输入原样复制了一遍,但是实际上,tf.identity在计算图内部创建了两个节点,send / recv节点,用来发送和接受两个变量,如果两个变量在不同的设备上,比如 CPU 和 GPU,那么将会复制变量,如果在一个设备上,将会只是一个引用。

之所以会出现上面代码的情况,就是因为 y = x 并没有在计算图中占有一席之地,所以每次sess.run(y) 的时候都没有进行它的上一步 tf.control_dependencies 的操作,而 z 的计算则不同,它是计算图内部的节点,所以每次sess.run(z) 的时候都会进行 tf.control_dependencies 的操作,所以它输出的值每次都会更新; 并且,我们可以想见,如果我们同时输出 y 和 z,那么 y 的值也会同步更新为2,3,4,5,6。

我们打印该计算图内部所有的 OP,可以看到也是没有 y 这样一个计算节点存在的:

[<tf.Operation 'x/initial_value' type=Const>, <tf.Operation 'x' type=VariableV2>, <tf.Operation 'x/Assign' type=Assign>, <tf.Operation 'x/read' type=Identity>, <tf.Operation 'x_plus/value' type=Const>, <tf.Operation 'x_plus' type=AssignAdd>, <tf.Operation 'z_added' type=Identity>, <tf.Operation 'init' type=NoOp>]

那么在什么时候使用这个方法呢?

它是通过在计算图内部创建 send / recv节点来引用或复制变量的,最主要的用途就是更好的控制在不同设备间传递变量的值

另外,它还有一种常见的用途,就是用来作为一个虚拟节点来控制流程操作,比如我们希望强制先执行loss_averages_op或updata_op,然后更新相关变量。这可以实现为:

with tf.control_dependencies[loss_averages_op]):
  total_loss = tf.identitytotal_loss

或者:

with tf.control_dependencies[updata_op]):
  train_tensor = tf.identitytotal_loss,name='train_op'

在这里,tf.identity除了在执行 loss_averages_op之后标记total_loss张量被执行之外没有做任何有用的事情。

tf.tuple

这个方法创建一个 tensors 的元祖,但是只有在所有的 tensor 的值都已经计算出来的情况下才会返回该元祖。

这个方法创建一个 tensor 的元祖,但是只有在所有的 tensor 的值都已经计算出来的情况下才会返回该元祖。

它实际上相当于一个并行计算的 Join 流程控制机制, 每个 tensor 都可以并行的独立计算,但是返回的 tuple 只有在所有的并行计算都完成的情况下才会返回。

它最常用的用法是在某个函数中需要返回多个 tensor 变量值,就可以使用如下类似语句:

return tf.tuple([x, y, z]

tf.group

创建一个聚合多个 OP 的 OP ; 当这个 OP 完成后,它输入的所有 OP 也都完成了,相当于一个控制多个 OP 计算进度的功能,这个函数没有返回值

mul = tf.multiply(w, 2)  
add = tf.add(w, 2)  
group = tf.group(mul, add)  
print sess.run(group)  

或者:

update_ops = []
update_ops.append(grad_updates)
update_op = tf.group(*update_ops)

简而言之,这是一个聚合多个操作的操作,从示例中可以看出,它的主要用途是更加简便的完成多个操作。

tf.no_op

什么也不做,只是作为一个控制边界的占位符

这个有点难以理解了,为什么要有这样第一个东西,我们通过该方法的常用示例来看:

with tf.control_dependencies([a, b]):
    c = tf.no_op() 

可以看到,它通常是控制依赖一起使用,表示在 a,b 执行完之后什么都不做,起到一个边界的作用,其实,上面说的 group 函数和它作用一样,这个示例可以用以下代码替换:

c=tf.group(a, b)

tf.count_up_to

一个计数器,根据计数条件是否满足控制流程;它有两个主要参数,ref,limit,表示每次都在 ref 的基础上递增,直到等于 limit。

import tensorflow as tf

g = tf.Graph()
with g.as_default():
    x = tf.Variable(1, name='x',dtype=tf.int32)
    li = tf.count_up_to(x,4)
    init = tf.global_variables_initializer()
    with tf.Session() as sess:
        sess.run(init)
        for i in range(5):
            print(sess.run(x))
            print(sess.run(li))
#1
#1
#2
#2
#3
#3
#4
# 然后就报 OutOfRangeError 错误

需要注意的是,x 的值也是同时在递增的,并且 li 的值是 x 的前一次的值。

tf.cond

这是一个根据条件进行流程控制的函数,它有三个参数,pred, true_fn,false_fn ,它的主要作用是在 pred 为真的时候返回 true_fn 函数的结果,为假的时候返回 false_fn 函数的结果,我们通过一个示例来看:

scale = tf.cond(tf.greater(height, width),
                lambda: x / y,
                lambda: y / x )

如果 height 更大,就执行 x / y,反之则是 y / x .

tf.case

这是一个多分枝流程控制函数,它有两个主要参数,pred_fn_pairs, 表示一个由布尔值和可调用的函数组成的 pair 构建出来的列表,default 是表示布尔值为false 时执行的函数。

def f(k): 
    return tf.constant(100)*k

def g(): 
    return tf.constant(2333)

min_index = tf.constant(2)

Case_0 = (tf.equal(min_index,0), lambda: f(0))
Case_1 = (tf.equal(min_index,1), lambda: f(1))
Case_2 = (tf.equal(min_index,2), lambda: f(2))
Case_3 = (tf.equal(min_index,3), lambda: f(3))

Case_List = [Case_0, Case_1, Case_2, Case_3]
result = tf.case(pred_fn_pairs=Case_List, default=g)

with tf.Session() as sess:
    print sess.run(result)

输出 200,如果把 min_index 改为 0-3 以外的数值,输出2333

pred_fn_pairs 也可以写成字典的形式:

pred_fn_pairs={ tf.less(x, y):f(1), tf.less(z, y):f(2) } 

tf.while_loop

这是一个类似于 while 循环的函数,它有三个主要参数,cond, 循环条件,body,循环语句,loop_vars,循环控制变量。

i = tf.constant(0)
c = lambda i: tf.less(i, 10)
b = lambda i: tf.add(i, 1)
r = tf.while_loop(cond=c, body=b, loop_vars=[i])

with tf.Session() as sess:
    print sess.run(r)
# 输出 10

import collections
Pair = collections.namedtuple('Pair', 'j, k')
print Pair.k
ijk_0 = (tf.constant(0), Pair(tf.constant(1), tf.constant(2)))
#print ijk_0
c = lambda i, p: i < 2
b = lambda i, p: (i + 1, Pair((p.j + p.k), (p.j - p.k)))
ijk_final = tf.while_loop(cond=c, body=b, loop_vars=ijk_0)

with tf.Session() as sess:
    print sess.run(ijk_final)
# 输出 (2, Pair(j=2, k=4))

该函数会重复循环体 body 知道条件 cond 为真,最后输出的是循环控制变量 loop_vars 最后的值,如果是多个值输出为列表。

TIP:任意以 tensor 作为输入的函数都可以使用 tf.convert_to_tensor 转换不是 tensor 的输入作为输入参数; 它可以使得 numpy arrays, Python lists, scalars 都可以作为输入。

参考资料:Control Flow | TensorFlow

    原文作者:lonlon ago
    原文地址: https://zhuanlan.zhihu.com/p/32540546
    本文转自网络文章,转载此文章仅为分享知识,如有侵权,请联系博主进行删除。
点赞