tensorflow – tf-slim批量规范:训练/推理模式之间的不同行为

我正在尝试基于流行的
slim implementation mobilenet_v2训练一个张量流模型,并观察行为我无法解释相关(我认为)批量规范化.

问题摘要

推理模式中的模型性能最初得到改善,但在很长一段时间后开始产生微不足道的推论(所有近零).在训练模式下运行时,即使在评估数据集上也能保持良好的性能.评估性能受到批量标准化衰减/动量率的影响……不知何故.

下面有更广泛的实现细节,但我可能会失去你们大多数人的文字墙,所以这里有一些图片让你感兴趣.

下面的曲线来自一个模型,我在训练时调整了bn_decay参数.

0-370k:bn_decay = 0.997(默认)

370k-670k:bn_decay = 0.9

670k:bn_decay = 0.5

《tensorflow – tf-slim批量规范:训练/推理模式之间的不同行为》
(橙色)训练(训练模式)和(蓝色)评估(推理模式)的损失.低是好的.

《tensorflow – tf-slim批量规范:训练/推理模式之间的不同行为》
推理模式下评估数据集模型的评价指标.高是好的.

我试图制作一个最小的例子来证明这个问题 – 在MNIST上进行分类 – 但是失败了(即分类效果很好,我遇到的问题没有表现出来).我为无法进一步减少事情而道歉.

实施细节

我的问题是2D姿态估计,目标是以联合位置为中心的高斯.它基本上与语义分段相同,除了使用softmax_cross_entropy_with_logits(标签,logits),我使用tf.losses.l2_loss(sigmoid(logits) – gaussian(label_2d_points))(我使用术语“logits”来描述未激活的输出我的学习模型,虽然这可能不是最好的术语).

推理模型

在对输入进行预处理之后,我的logits函数是对基本mobilenet_v2的作用域调用,后跟一个未激活的卷积层,以使过滤器的数量合适.

from slim.nets.mobilenet import mobilenet_v2

def get_logtis(image):
    with mobilenet_v2.training_scope(
            is_training=is_training, bn_decay=bn_decay):
        base, _ = mobilenet_v2.mobilenet(image, base_only=True)
    logits = tf.layers.conv2d(base, n_joints, 1, 1)
    return logits

训练操作

我已经尝试过tf.contrib.slim.learning.create_train_op以及自定义培训操作:

def get_train_op(optimizer, loss):
    global_step = tf.train.get_or_create_global_step()
    opt_op = optimizer.minimize(loss, global_step)
    update_ops = set(tf.get_collection(tf.GraphKeys.UPDATE_OPS))
    update_ops.add(opt_op)
    return tf.group(*update_ops)

我正在使用学习率= 1e-3的tf.train.AdamOptimizer.

训练循环

我正在使用tf.estimator.Estimator API进行培训/评估.

行为

培训最初进展顺利,预计业绩将大幅增加.这符合我的期望,因为最后一层经过快速培训,可以解释预训练基础模型输出的高级特征.

然而,经过很长一段时间(使用batch_size 8 60k步,GTX-1070上约8小时),我的模型在推理模式下运行时开始输出接近零的值(~1e-11),即is_training = False.当在*训练模式中运行时,完全相同的模型继续改进,ieis_training = True`,即使在估值集上也是如此.我已经在视觉上证实了这一点.

经过一些实验后,我将bn_decay(批量归一化衰减/动量率)从默认的0.997改为0.9,步长为〜370k(也尝试了0.99,但这并没有太大差别)并且观察到准确性的不确定性提高.在推理模式中对推理的目视检查显示在预期位置的推断值~1e-1中的清晰峰值,与来自训练模式的峰值的位置一致(尽管值低得多).这就是为什么准确度会显着提高,但损失 – 虽然更具波动性 – 并没有太大改善.

经过更多的训练后,这些影响逐渐消失,并恢复到所有零推理.

我在步骤~670k时进一步将bn_decay降为0.5.这导致了损失和准确性的改善.我可能要等到明天才能看到长期影响.

下面给出了损失和评估指标图.请注意,评估指标基于logits的argmax,而high是好的.损失基于实际值,低值是好的. Orange在训练集上使用is_training = True,而蓝色在评估集上使用is_training = False.大约8的损失与所有零输出一致.

其他说明

>我还试验了关闭辍学(即总是用is_training = False运行辍学层),并且没有观察到差异.
>我已经尝试了从1.7到1.10的所有版本的tensorflow.没有不同.
>我从一开始就使用bn_decay = 0.99训练了来自预训练检查点的模型.与使用默认bn_decay相同的行为.
>批量大小为16的其他实验导致定性相同的行为(尽管由于内存限制我无法同时评估和训练,因此定量分析批量大小为8).
>我使用相同的损失训练了不同的模型并使用tf.layers API并从头开始训练.他们工作得很好.
>从头开始训练(而不是使用预训练检查点)会导致类似的行为,但需要更长的时间.

总结/我的想法:

>我相信这不是过度拟合/数据集问题.当使用is_training = True运行时,该模型对评估集进行合理的推断,无论是峰值位置还是幅度.
>我相信这不是运行update ops的问题.我之前没有使用过slim,但除了使用arg_scope之外,它与我广泛使用的tf.layers API看起来并没有太大区别.我还可以检查移动平均值,并观察它们随着训练的进展而变化.
> Chaning bn_decay值暂时显着影响结果.我接受0.5的值是荒谬的低,但我的想法已经不多了.
>我尝试将动态= 0.997的tf.layers.conv2d的slim.layers.conv2d图层换掉(即动量与默认衰减值一致)并且行为相同.
>使用预训练权重和Estimator框架的最小示例用于MNIST的分类而无需修改bn_decay参数.

我已经查看了tensorflow和模型github存储库的问题,但是除了this之外没有找到太多.我正在尝试更低的学习率和更简单的优化器(MomentumOptimizer),但更多的是因为我正在运行出于想法而不是因为我认为这就是问题所在.

可能的解释

>我的最佳解释是我的模型参数快速循环,使得移动统计数据无法跟上批次统计信息.我从来没有听说过这种行为,也没有解释为什么模型会在更长的时间后恢复到不良行为,但这是我的最佳解释.
>移动平均代码中可能存在错误,但在其他所有情况下它都完美适用,包括简单的分类任务.在我能够制作一个更简单的例子之前,我不想提出问题.

无论如何,我的想法已经用完了,调试周期很长,而且我已经花了太多时间在这上面.很高兴提供更多细节或按需进行实验.也很高兴发布更多代码,虽然我担心会吓跑更多的人.

提前致谢.

最佳答案 使用Adam将学习率降低到1e-4并使用Momentum优化器(使用learning_rate = 1e-3和动量= 0.9)解决了这个问题.我还发现
this post表明该问题跨越多个框架,并且由于优化器和批处理规范化之间的交互,因此是某些网络的未记录病理.由于学习速度太高,我不认为优化器未能找到合适的最小值的简单情况(否则训练模式下的性能会很差).

我希望能够帮助其他人遇到同样的问题,但我还有很长的路要走.我很高兴听到其他解释.

点赞