python – tensorflow在使用tf.cond()时要求输入不必要的占位符

请考虑以下包含tensorflow tf.cond()的代码段.

    import tensorflow as tf
    import numpy as np

    bb = tf.placeholder(tf.bool)
    xx = tf.placeholder(tf.float32, name='xx')
    yy = tf.placeholder(tf.float32, name='yy')

    zz = tf.cond(bb, lambda: xx + yy, lambda: 100 + yy)

    with tf.Session() as sess:
            dict1 = {bb:False, yy:np.array([1., 3, 4]), xx:np.array([5., 6, 7])}
            print(sess.run(zz, feed_dict=dict1)) # works fine without errors

            dict2 = {bb:False, yy:np.array([1., 3, 4])}
            print(sess.run(zz, feed_dict=dict2)) # get an InvalidArgumentError asking to
                                                 # provide an input for xx

在这两种情况下,bb都是False,理论上zz的评估不依赖于xx,但仍然是tensorflow需要xx的输入.尽管它可以作为虚拟阵列提供,但它必须与yy的形状匹配,并且不像dict2那样干净.

任何人都可以建议如何评估zz(使用tf.cond()或任何其他方法)而不提供xx的值?

最佳答案 你可以将xx定义为tf.Variable,给它一个默认值(只要xx没有用另一个值输入,就会使用它).有几点需要注意:

>虽然xx不是占位符 – 您仍然可以通过feed_dict将值添加到其中来对待它.
>使用validate_shape = False,以便您可以将任何形状提供给xx.
>使用trainable = False以使xx不被优化(否则,优化器可能会将其默认值更改为Nan,这可能会导致问题).
>不要忘记使用例如tf.global_variables_initializer()来初始化xx的值.

这是代码:

import tensorflow as tf
import numpy as np

bb = tf.placeholder(tf.bool)
xx = tf.Variable(initial_value=0.0,validate_shape=False,trainable=False,name='xx')
yy = tf.placeholder(tf.float32, name='yy')

zz = tf.cond(bb, lambda: xx + yy, lambda: 100 + yy)

with tf.Session() as sess:
   sess.run(tf.global_variables_initializer())
   dict1 = {bb:False, yy:np.array([1., 3, 4]), xx:np.array([5., 6, 7])}
   print(sess.run(zz, feed_dict=dict1))
   dict2 = {bb:False, yy:np.array([1., 3, 4])}
   print(sess.run(zz, feed_dict=dict2))
点赞