『TensorFlow』slim模块常用API

辅助函数

slim.arg_scope()

slim.arg_scope可以定义一些函数的默认参数值,在scope内,我们重复用到这些函数时可以不用把所有参数都写一遍,注意它没有tf.variable_scope()划分图结构的功能,

with slim.arg_scope([slim.conv2d, slim.fully_connected], 
                    trainable=True,
                    activation_fn=tf.nn.relu, 
                    weights_initializer=tf.truncated_normal_initializer(stddev=0.01), 
                    weights_regularizer=slim.l2_regularizer(0.0001)):
    with slim.arg_scope([slim.conv2d], 
                        kernel_size=[3, 3], 
                        padding='SAME',
                        normalizer_fn=slim.batch_norm):
        net = slim.conv2d(net, 64, scope='conv1'))
        net = slim.conv2d(net, 128, scope='conv2'))
        net = slim.conv2d(net, 256, [5, 5], scope='conv3'))

slim.arg_scope的用法基本都体现在上面了。一个slim.arg_scope内可以用list来同时定义多个函数的默认参数(前提是这些函数都有这些参数),另外,slim.arg_scope也允许相互嵌套。在其中调用的函数,可以不用重复写一些参数(例如kernel_size=[3, 3]),但也允许覆盖(例如最后一行,卷积核大小为[5,5])。
另外,还可以把这么多scope封装成函数:

def new_arg_sc():
    with slim.arg_scope([slim.conv2d, slim.fully_connected], 
                        trainable=True,
                        activation_fn=tf.nn.relu, 
                        weights_initializer=tf.truncated_normal_initializer(stddev=0.01), 
                        weights_regularizer=slim.l2_regularizer(0.0001)):
        with slim.arg_scope([slim.conv2d], 
                            kernel_size=[3, 3], 
                            padding='SAME',
                            normalizer_fn=slim.batch_norm) as sc:
            return sc

def main():
    ......
    with slim.arg_scope(new_arg_sc()):
        ......

 

slim.utils.collect_named_outputs()

将变量取个别名,并收集到collection中

net = slim.utils.collect_named_outputs(outputs_collections,sc.name,net)

 参数意义如下,

return:这个方法会返回本次添加的tensor对象,
参数二:意义是为tensor添加一个别名,并收集进collections中
查看源码可见实现如下

if collections:
append_tensor_alias(outputs,alias)
ops.add_to_collections(collections,outputs)
return outputs

据说本方法位置已经被转移到这里了,
from tensorflow.contrib.layers.python.layers import utils
utils.collect_named_outputs()

slim.utils.convert_collection_to_dict()

#集合转换为字典,{节点名:输出张量值}
end_points = slim.utils.convert_collection_to_dict(end_points_collection)
 
# 收集 & 释放 集合值
tf.add_to_collection("loss",mse_loss)
tf.add_n(tf.get_collection("loss"))

 

层函数

batch_norm处理

slim.batch_norm()函数,以及slim的各个层函数的normalizer_fn=slim.batch_norm调用都会用到,

其参数很多,需要以字典的形式传入,

batch_norm_params = {  # 定义batch normalization(标准化)的参数字典
        'is_training': is_training,
        # 是否是在训练模式,如果是在训练阶段,将会使用指数衰减函数(衰减系数为指定的decay),
        # 对moving_mean和moving_variance进行统计特性的动量更新,也就是进行使用指数衰减函数对均值和方
        # 差进行更新,而如果是在测试阶段,均值和方差就是固定不变的,是在训练阶段就求好的,在训练阶段,
        # 每个批的均值和方差的更新是加上了一个指数衰减函数,而最后求得的整个训练样本的均值和方差就是所
        # 有批的均值的均值,和所有批的方差的无偏估计

        'zero_debias_moving_mean': True,
        # 如果为True,将会创建一个新的变量对 'moving_mean/biased' and 'moving_mean/local_step',
        # 默认设置为False,将其设为True可以增加稳定性

        'decay': batch_norm_decay,             # Decay for the moving averages.
        # 该参数能够衡量使用指数衰减函数更新均值方差时,更新的速度,取值通常在0.999-0.99-0.9之间,值
        # 越小,代表更新速度越快,而值太大的话,有可能会导致均值方差更新太慢,而最后变成一个常量1,而
        # 这个值会导致模型性能较低很多.另外,如果出现过拟合时,也可以考虑增加均值和方差的更新速度,也
        # 就是减小decay

        'epsilon': batch_norm_epsilon,         # 就是在归一化时,除以方差时,防止方差为0而加上的一个数
        'scale': batch_norm_scale,
        'updates_collections': tf.GraphKeys.UPDATE_OPS,     
        # force in-place updates of mean and variance estimates
        # 该参数有一个默认值,ops.GraphKeys.UPDATE_OPS,当取默认值时,slim会在当前批训练完成后再更新均
        # 值和方差,这样会存在一个问题,就是当前批数据使用的均值和方差总是慢一拍,最后导致训练出来的模
        # 型性能较差。所以,一般需要将该值设为None,这样slim进行批处理时,会对均值和方差进行即时更新,
        # 批处理使用的就是最新的均值和方差。
        #
        # 另外,不论是即使更新还是一步训练后再对所有均值方差一起更新,对测试数据是没有影响的,即测试数
        # 据使用的都是保存的模型中的均值方差数据,但是如果你在训练中需要测试,而忘了将is_training这个值
        # 改成false,那么这批测试数据将会综合当前批数据的均值方差和训练数据的均值方差。而这样做应该是不
        # 正确的。
    }

在以其他层参数的形式调用时如下,

normalizer_fn=slim.batch_norm,                               # 标准化器设置为BN
normalizer_params=batch_norm_params

注意一但使用batch_norm层,在训练节点定义时需要添加一些语句,slim.batch_norm里有moving_mean和moving_variance两个量,分别表示每个批次的均值和方差。在训练时还好理解,但在测试时,moving_mean和moving_variance的含义变了,在训练时,

update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)  
    with tf.control_dependencies(update_ops):  
        train_step = tf.train.GradientDescentOptimizer(0.01).minimize(total_loss)  
# 注意并tf本体的batch_normal操作也需要这步操作
# 其中,tf.control_dependencies(update_ops)表示with段中的操作是在update_ops操作执行之后 再执行的

tf.contrib.slim.conv2d()

convolution(inputs,
num_outputs,
kernel_size,
stride=1,
padding='SAME',
data_format=None,
rate=1,
activation_fn=nn.relu,
normalizer_fn=None,
normalizer_params=None,
weights_initializer=initializers.xavier_initializer(),
weights_regularizer=None,
biases_initializer=init_ops.zeros_initializer(),
biases_regularizer=None,
reuse=None,
variables_collections=None,
outputs_collections=None,
trainable=True,
scope=None)

inputs                        是指需要做卷积的输入图像
num_outputs             指定卷积核的个数(就是filter的个数)
kernel_size               用于指定卷积核的维度(卷积核的宽度,卷积核的高度)
stride                         为卷积时在图像每一维的步长
padding                     为padding的方式选择,VALID或者SAME
data_format              是用于指定输入的input的格式
rate                           这个参数不是太理解,而且tf.nn.conv2d中也没有,对于使用atrous convolution的膨胀率(不是太懂这个atrous convolution)
activation_fn             用于激活函数的指定,默认的为ReLU函数
normalizer_fn           用于指定正则化函数
normalizer_params  用于指定正则化函数的参数
weights_initializer     用于指定权重的初始化程序
weights_regularizer  为权重可选的正则化程序
biases_initializer       用于指定biase的初始化程序
biases_regularizer    biases可选的正则化程序
reuse                        指定是否共享层或者和变量
variable_collections  指定所有变量的集合列表或者字典
outputs_collections   指定输出被添加的集合
trainable                    卷积层的参数是否可被训练
scope                        共享变量所指的variable_scope

slim.conv2d是基于tf.conv2d的进一步封装,省去了很多参数,一般调用方法如下:

net = slim.conv2d(inputs, 256, [3, 3], stride=1, scope='conv1_1')

slim.max_pool2d

这个函数更简单了,用法如下:

net = slim.max_pool2d(net, [2, 2], scope='pool1')

slim.fully_connected

slim.fully_connected(x, 128, scope='fc1')

前两个参数分别为网络输入、输出的神经元数量。

    原文作者:tensorflow
    原文地址: https://www.cnblogs.com/hellcat/p/8058092.html
    本文转自网络文章,转载此文章仅为分享知识,如有侵权,请联系博主进行删除。
点赞