tensorflow batch_normalization的正确使用姿势

BN在如今的CNN结果中已经普遍应用,在tensorflow中可以通过tf.layers.batch_normalization()这个op来使用BN。该op隐藏了对BN的mean var alpha beta参数的显示申明,因此在训练和部署测试中需要特征注意正确使用BN的姿势。

正确使用BN训练

注意把tf.layers.batch_normalization(x, training=is_training,name=scope)输入参数的training=True。另外需要在来训练中添加update_ops以便在每一次训练完后及时更新BN的参数。

update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)  
with tf.control_dependencies(update_ops): #保证train_op在update_ops执行之后再执行。 
   train_op = optimizer.minimize(loss) 

正确保存带BN的模型

保存模型的时候不能只保存trainable_variables,因为BN的参数不属于trainable_variables。为了方便,可以用tf.global_variables()。使用姿势如下

saver = tf.train.Saver(var_list=tf.global_variables())
savepath = saver.save(sess, 'here_is_your_personal_model_path’)

正确读取带BN的模型

与保存类似,读的时候变量也需要为global_variables。如下:

saver = tf.train.Saver()
or saver = tf.train.Saver(tf.global_variables())
saver.restore(sess, 'here_is_your_personal_model_path')

PS:测试的时候还需要把tf.layers.batch_normalization(x, training=is_training,name=scope) 这里的training设为False

Reference:
https://stackoverflow.com/questions/48260394/whats-the-differences-between-tf-graphkeys-trainable-variables-and-tf-graphkeys

点赞