python – 如何使用TensorFlow Graphkeys获取所有权重?

Tensorflow定义了预设的集合,如下所示:
https://www.tensorflow.org/versions/r0.12/api_docs/python/framework/graph_collections

我目前正在使用tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)来获取所有变量[*的名字;如果他们不是那么他们就不会被展示,即使他们存在].

类似地,我期望tf.get_collection(tf.GraphKeys.WEIGHTS)输出权重列表,但它是一个空数组.这也适用于GraphKeys.BIASES和.ACTIVATIONS.

这里发生了什么?

在我看来,这里似乎有两种可能性.首先,它们实际上从未自动定义,只是推荐的集合名称.其次,我的网络非常破碎,但似乎并非如此.

有人有这方面的经验吗?

最佳答案 默认情况下,所有变量都绑定到tf.GraphKeys.GLOBAL_VARIABLES集合.方便的方法是将每个权重设置为集合tf.GraphKeys.WEIGHTS,如下所示:

In [2]: w = tf.Variable([1,2,3], collections=[tf.GraphKeys.WEIGHTS], dtype=tf.float32)
In [3]: w2 = tf.Variable([11,22,32], collections=[tf.GraphKeys.WEIGHTS], dtype=tf.float32)

然后你可以通过以下方式获取它们:

tf.get_collection_ref(tf.GraphKeys.WEIGHTS)

这是权重:

[<tf.Variable 'Variable:0' shape=(3,) dtype=float32_ref>,
 <tf.Variable 'Variable_1:0' shape=(3,) dtype=float32_ref>]
点赞