本文目录
0. 说明
- 由于是商业项目,一些数据不方便公开,但不影响文中列出的结果&结论,请谅解;
- 若发现问题,或者有任何问题想交流,欢迎留言哦~
1. 问题背景
最近在做一个文本分类任务,遇到一个现象:在扩大了数据集之后,模型的分类准确率反而下降了,比较严重一个模型是下降了7个百分点左右。
这很违反直觉,因为在保证数据质量的前提下,正常来说,数据集变大,模型应该能对噪声越不敏感,泛化性能更佳。
2. 原因分析
2.1 有没有可能是因为新数据的标签不够准确?
这是我第一个冒出来的念头。但是,其一,数据扩增的方法,是依赖于原始数据中存在的某中唯一性映射关系,对某些缺少标签的数据进行自动标注;其二,人工观察新增的case,与原始数据在标签准确度上并无明显区别。那可能是什么原因呢?
2.2 是不是因为数据分布发生了变化?
为了验证这个想法,我统计了原始数据A、新增数据B、以及合并后总体数据C的标签分布,然后用余弦相似度计算两两分布之间的相似程度。由于项目需要,同时存在多个数据集(基于此训练了多个分类器),因此我可以对比多个数据集之间的情况。
结果表明,分类效果下降越厉害的数据集,其数据扩增前后的分布的相似程度也越低(即原始数据A和合并后总体数据C间的余弦相似度越低),并且总体数据分布C总是比原始数据分布A更加均衡。
这似乎说明了数据分布的变化,是导致模型准确率下降的主要原因。但是,数据分布变化为什么会导致模型性能下降?数据量变大不能抵消这种下降吗?
3. 举个栗子
我们来玩一个游戏:假设有一个不透明的箱子,其中装有红绿蓝三种球,红 绿 蓝 球数分别为 98:1:1,这三种球唯一去区别是重量区间不一样。在开始游戏之前,给了几分钟时间,可以随意拿球掂量体验一下。现在游戏开始,让你从中摸一个球,并且猜是红绿蓝哪个颜色,你会怎么猜?我想,只要不傻,直接盲猜红色,至少也有98%的准确率。似乎,你已经摸透了三种球在手中的分量,已经找到了规律。
实际上,三种球的真正规律可能是,红色球重量克数在 [100, 105) 区间,绿色球和蓝色球的重量克数分别在[95, 100), [90, 95) 区间,但是这个规律你并不知道。
后来游戏升级了,我们往箱子中加入了 2 个红色球,以及绿色、蓝色球各 99 个,这样,箱子中就有了红绿蓝三种球各 100 个了。这时候,又给你几分钟时间,可以拿些新加入的球体验一下,然后继续游戏:让你从中摸一个球,并且猜是红绿蓝哪个颜色,你会怎么猜?
也许,你已经发现了规律并且总是能区分出红色球和蓝色球之间较大的重量差距,但是在红色球和绿色球之间(100克上下)、以及绿色球和蓝色球之间(95克上下),似乎有点难区分。因此,假设你摸到一个球之后,总是能排除掉它是红色球(或蓝色球)的可能,然后在剩下的两个颜色里随机选一个,这样,你的准确率大概就是 50%。
这个游戏只是一个理想化的假设,计算出来的准确率是在假设下得到的,但这不妨碍我们理解。正常情况下,红绿蓝三种球各 100 个时的游戏难度是更高的,没有经过足够训练的玩家,猜测准确率应该会下降。
4. (强行)理论解释
相信我,以下的解释可能都是错的。。。
有一天,我失眠了。然后无聊之中我就在想,数据分布变化如果会导致模型性能下降,那么背后的原理是什么?可否量化地进行解释?
突然间,脑海里蹦出一个念头,信息熵?!一切似乎都说得通了。
4.1 一种解释
最直接的解释是,模型进行预测,需要解决一些不确定性,而这个不确定性,很大一部分就体现在模型的输出空间 Y 的分布上(另一部分体现在输入空间 X)。这种不确定性,可以用信息熵来量化。
根据信息熵的含义,我们知道,输出空间 Y 的分布的信息熵越大,则为了消除这个不确定性,模型需要学到更多的信息/规律,也就是该分类任务的固有难度越大。
因此,问题变成了:如果能确认数据集扩大后,输出 Y 的分布的熵增加了,那么就可以说分类难度增加了,算是在某种程度上解释了为何模型性能会下降。
直接上结果:经统计发现,分类准确率下降较严重的数据集,在数据扩增后,输出 Y 的分布的熵确实增加幅度较大。
对上文游戏中的例子也进行一下计算吧!
初始游戏:红 绿 蓝 球数分别为 98:1:1,此时的熵为: ( − 0.98 l o g 2 0.98 ) + ( − 0.01 l o g 2 0.01 ) × 2 = 0.1614 (-0.98 \rm{log}_2 {0.98}) + (-0.01 \rm{log}_2 {0.01} ) \times 2 = 0.1614 (−0.98log20.98)+(−0.01log20.01)×2=0.1614 (对数底为2,单位为比特)
升级游戏:红 绿 蓝 球数比值为 1:1:1,此时的熵为: ( − 1 3 l o g 2 1 3 ) × 3 = − l o g 2 1 3 = 1.585 (- \frac13 \rm{log}_2 {\frac13}) \times 3 = – \rm{log}_2 {\frac13} = 1.585 (−31log231)×3=−log231=1.585 (对数底为2,单位为比特)
可以看出二者的熵具有较大的差距,后者为前者的近十倍。
4.2 另一种解释
以下尝试进行更深度的解释,但可能也都是错的。。。
机器学习模型学习的是什么?理论上,目标是让模型学到从输入X到输出Y的映射,可以是条件概率 P(Y|X)、联合概率 P(X, Y),或者 Y = f ( X ) Y=f(X) Y=f(X)。
然而,这或许只是我们的一厢情愿,模型或许主要学到了 P(Y),然后在数据不均衡的情况下,显示出了具有欺骗性的准确率。
为了便于公式拆解,我们假设模型学习的目标是联合概率分布 P(X, Y),假设使用的统计学习方法是朴素贝叶斯法。
那么,模型主要学习两点:①先验分布 P(Y);②条件概率 P(X|Y)。
然后,预测的方法是: y = a r g m a x c k P ( Y = c k ) P ( X ∣ Y = c k ) y = \rm{arg} \; \rm{max}_{c_k} P(Y=c_k) P(X|Y=c_k) y=argmaxckP(Y=ck)P(X∣Y=ck) (严格的写法中,这里的条件概率需要用条件独立性假设展开为连乘的形式,但这里为了便于理解,就不引入太多细节了)
当数据量比较少的时候,我们学到的 P(X|Y) 可能不太准确,而且,倘若这时数据分布较为不均衡,那么模型可能主要就学到了 P(Y),可能其中的某个类别概率很大。在这种情况下, P(X|Y) 学得差,而 P(Y) 学得好,P(Y) 将会主导模型最后的预测结果。就像我们一开始举例的游戏中那样,直接猜红色也能有 98% 准确率。但,这或许并不是我们想要的结果。
后续的数据量增加,数据分布更加平衡,让 P(Y) 不再“误导”分类结果,虽然模型准确率大概率会下降,但实际上学到了更多东西(P(X|Y))。
5. 结论
在数据量较少、且样本分布不均衡的情况下,数据集扩大使得样本标签更加均衡后(熵增),分类难度增加,可能会导致模型性能下降。
当然,如果数据量增大的特别多,比如扩大到原来的十倍、百倍,那么,大量的数据或许可以弥补这一下降。但一切都不能一概而论,毕竟,影响最终效果的因素很多,比如输入 X 本身的可区分性(问题本身的固有难度)、噪声/错误标签的多少、模型的学习能力、优化算法是否足够有效等等。