关于TensorFlow动态设置trainable的问题

这个问题是在项目中遇到的一个问题,即“如何在训练过程中动态的控制哪些参数更新,哪些不更新”。我们知道tensorflow的计算图是静态的,tensorflow中定义的tf.Variable时,可以通过trainable属性控制这个变量是否可以被优化器更新。但是,tf.Variable的trainable属性是只读的,我们无法动态更改这个只读属性。在定义tf.Variable时,如果在定义变量时制定了trainable=True,那么只要这个变量被初始化后,这个trainable就没法更改了,即使使用tf.placeholder(tf.bool)在训练时给这个变量传递一个参数试图改变该变量的trainable属性也是不可以的,会报错。

那么如何在训练时动态的选择需要更新的和不需要更新的参数呢?我在这里提供一个思路,这个思路也是在stackoverflow上看到的(链接);另外我还看到一个办法,但是我没有尝试,如果有人有兴趣可以试一下是否可行(链接)。

tensorflow将可以更新的参数存在TRAINABLE_VARIABLES中,所以我们只要定义两个不同的优化器就可以了。每个优化器指定当前更新哪些参数,这样我们就可以交替更新参数了。

P.S. 这篇博客也提供了部分思路,但是我觉得他的实现方式较为复杂。

下面我给出实现代码(核心代码):

x = tf.placeholder(shape=[None,5],dtype=tf.float32,name='x')
y = tf.placeholder(shape=[None,1],dtype=tf.float32,name='y')
with tf.variable_scope('z'):
    z = tf.Variable(tf.zeros(shape=[3,1],dtype=tf.float32),name='z',trainable=True) # 定义中间隐参数z

def dnn(x,z):
    with tf.variable_scope('parameter'):
        w1 = tf.Variable(tf.truncated_normal([5, 3]), dtype=tf.float32, trainable=True, name='w1')
        b1 = tf.Variable(tf.constant(0.001,shape=[3],dtype=tf.float32),trainable=True,name='b1')
        w2 = tf.Variable(tf.truncated_normal([3, 1]), dtype=tf.float32, trainable=True, name='w2')
        b2 = tf.Variable(tf.constant(0.001, shape=[1], dtype=tf.float32), trainable=True, name='b2')
    d1 = tf.add(tf.matmul(x,w1),b1)
    d2 = tf.add(tf.matmul(d1,w2),b2)
    output = tf.add(d2,z)
    return output

y_ = dnn(x,z)
loss = tf.abs(y-y_)
optimzer = tf.train.AdamOptimizer(0.001)
trainable_var_p = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,'parameter')
trainable_var_z = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,'z')
train_op_p = optimzer.minimize(loss,var_list=trainable_var_p)
train_op_z = optimzer.minimize(loss,var_list=trainable_var_z)

with tf.Session() as sess:
    sess.run(init)
    for i in range(100):
        print('第',i,'轮')
        if i%10<5 :
            print("z不更新")
            sess.run(train_op_p,feed_dict={x:batch_x,y:batch_y})
        else:
            print("z更新")
            sess.run(train_op_z, feed_dict={x: batch_x, y: batch_y})
           

    原文作者:镜镜詅痴
    原文地址: https://zhuanlan.zhihu.com/p/54547244
    本文转自网络文章,转载此文章仅为分享知识,如有侵权,请联系博主进行删除。
点赞