问题
实际的分类任务中大多存在样本不平衡(长尾数据,long-tailed distribution)问题,因为一些类别的数据本身就具有稀缺性。尤其是细粒度分类数据集,基本上都是不平衡的,而原生的(vanilla)深度学习模型和损失函数都是假设所有类别的样本数量是平衡的,这样才能使得模型在每个类别上的表现相同,否则多样本类别(head-class/many-shot)在学习中会占据主导地位(对loss的贡献更多),从而使得模型在多样本类别上表现很好,但是在稀缺样本类别(tail-class/few-shot)的表现很差。
之前已经有学者给出了多种缓解该问题的方法,例如样本上采样、设计适应非平衡数据的loss函数、设计更加复杂的训练流程等等。虽然这些方法都能带来一些效果的提升,但是他们有一个共性,即模型的特征抽取权重(backbone部分)和分类器权重是同时联合(joint)学习的。然而,这样的联合方式并不能告诉我们,使用的方法到底提高的是模型的特征表达能力,还是模型的分类器性能,或者二者兼有。Facebook这篇论文将分类任务学习过程解耦为特征学习(representation learning)和分类器学习(classifier learning)两部分,然后得到三个重要的发现:
- 解耦合训练方式比联合训练方式在非平衡数据集上的表现更好;
- 使用不平衡的数据集(原数据集)也可以学习到更好的特征提取器;
- 使用合适的策略重新调整(或重新训练)分类器,结合1中学习的特征提取器可以有效缓解数据不平衡带来的稀缺样本类别识别错误问题。
方法
作者将传统的联合训练方式(单阶段学习)拆分成两阶段学习方式,特征学习阶段和分类器学习阶段。对于每个阶段都使用多种策略进行训练或调整模型权重,作者将这些前后阶段的不同策略进行两两结合得到多种两阶段训练策略组合,在多个数据集上比较这些策略组合的表现,并且与联合方式训练的模型进行比较,且经过实验得出分别在特征学习阶段和分类器学习阶段表现比较好的策略。
特征学习
作者使用了四种策略进行特征学习,都是对训练样本进行重采样,实例平衡采样(instance-balanced sampling)、类别平衡采样(class-balanced sampling)、平方根采样(square-root sampling)和逐步平衡采样(progressively-balanced sampling)。其中前三种采样方式都可以用以下公式概括:
公式中,为类别数量,为训练集中第个类别的样本数量,或称为该类的基数(cardinality),指从第个类别的数据中采样数据的概率。
- 实例平衡采样(instance-balanced sampling,IB),就是最传统的不使用任何上采样或下采样方法的方法,也就是公式中的情况,每个类别的数据被采样到的概率(即训练中一个batch随机抽取的概率)与该类别的数据量有关,;
- 类别平衡采样(class-balanced sampling,CB),就是上面公式中的情况,这时所有类别的数据被采样到的概率相同,实现该采样方式的一般方法是直接上采样稀缺类别样本直到其数量与多样本类别样本数量相同,;
- 平方根采样(square-root sampling),公式中的情况,它是类别平衡采样的变种;
- 逐步平衡采样(Progressively-balanced sampling,PB),这是作者在前人的一些工作基础上提出的新式采样方法,灵感来自Cui等人和Cao等人提出的混合采样方法,例如将实例平衡采样方法和类别平衡采样方法结合,在训练的前期使用实例平衡采样方法,一些周期后期使用类别平衡采样方法,但是这种方法会引入采样方法切换的周期超参。基于此,作者提出逐步平衡采样,随着训练迭代周期数的增加,采访方法会逐渐切换到类别平衡采样方法,公式描述如下:
论文中用于特征学习训练的策略都是采样方法,并不涉及loss-reweighting等其他策略。
分类器学习
将特征学习从分类学习中分离出来后,需要修正分类器的决策边界,使得模型在稀缺类别上的表现更好。作者采用了多种前人提出的方法,包括重训练方法和其他无参数且无需重训练方法,作者也提出了两种分类权重修正的方法。具体方法如下:
- 分类器重训练(Classifier Re-training,cRT),对分类器权重进行重新随机初始化并使用类别平衡采样方法进行训练,训练时保持特征抽取部分的权重固定。
- 最邻近类别平均分类器(Nearest Class Mean classifier,NCM),先计算训练集上每个类的平均特征表示,然后利用标准化该平均特征,在此基础上每个类别特征的余弦相似度或欧氏距离以进行最邻近搜索(Snell等2017;Guerriero等2018;Rebuffi等2017)。
- 标准化分类器(),该方法的提出基于作者的实验观察:经过使用实例平衡采样方式联合训练的分类器权重的范数(为负责计算与类别相关的输出,后用于计算类别的概率)与其类别的基数相关,然而,经过使用类别平衡采样方式微调训练后,所有类别的权重范数趋于类似。受到以上观察的启发,作者提出了方法,直接对权重的范数进行标准化修改,从而修正分类器的决策边界。调整公式为,其中,为用来控制标准化"温度"(力度)的超参,表示标准化,当时公式退化为标准化,时不进行缩放。作者经验上选择的位于之间,这样权重就可以被平滑地修正。最后,预测时使用代替参与推理计算。
- 可学习权重缩放(Learnable weight scaling,LWS),虽然方法的超参可以通过交叉验证获得,但总归是麻烦的,作者提出的LWS方法就是尝试将该超参调整为可学习参数。其实,对方法的另一种解释是其实质上是对权重向量进行方向不变的尺度缩放,因此可将的计算公式改写为,其中,这样便可以将缩放因子视为可学习参数。类似cRT方法,学习参数时将特征权重和分类器权重进行固定,并使用类别平衡采样方法进行训练。
实验结果
作者在实验中使用Places-LT、ImageNet-LT和iNaturalist 2018三个不平衡数据集,并使用传统accuracy作为评价指标,为了能够明显体现不同方法在稀缺类别上的效果提升,作者使用的测试集是样本平衡的,而且将测试集上的accuracy指标按训练数据集上的类别样本数量等级划分为三类,Many、Medium和Few三类,分别表示样本数量占比较多的类别集合(样本数大于100)、样本数量占比中等的类别数量(样本数介于20到100)和样本数量占比较少的类别集合(样本数小于20)。
联合训练时采样策略同样重要
从下面的试验结果可以看出,对于联合训练方式,合理的采样策略也是很重要的。
解耦合训练的必要性
Table 1实验数据说明了对于不平衡数据集,使用实例平衡采样+cRT组合或实例平衡采样+组合的解耦合训练最终的表现要优于联合训练方式:
不同方式组合的两阶段训练方法效果
Figure 1给出了不同组合的两阶段训练方法以及联合训练方法在IamgeNet-LT数据集上的表现对比,从图中效果看出合理选择策略的两阶段方法最终表现要明显优于联合训练方式的。
从图中还可以看出使用实例平衡采样方式可以学习到更好的特征。另外,对于分类器学习来说,决策边界调整策略的使用对于最终的模型表现也起到了决定性的作用,其中cRT和方法的表现更加突出。
和LWS方法的有效性
Figure 2实验说明经过和LWS方法修正的权重具有与cRT方法类似的权重范数,图右部分给出了分类器的accuracy指标与超参的关系。
与其他SOTA方法的比较
下面给出了作者提出的解耦合方法与其他SOTA方法在三个数据集上的表现对比,可以看出解耦合方法都是最优的。
从该实验结果仍然可以看到,使用实例平衡采样+cRT组合或实例平衡采样+组合或实例平衡采样+LWS组合策略的两阶段训练方法都能获得很好的效果。
评价
作者是在样本平衡的测试集上测试效果的,所以效果很明显,但是在实际中我们的测试集和训练集是同分布的,即对于数量同样不平衡的测试集效果提升可能并不明显。另外,对比联合训练方式,几乎每个解耦合方法在Many类别上的accuracy表现都会掉点,如果不能从Medium和Few类别的表现提升上补偿Many的掉点的话,则最终的总体accuracy表现是比联合训练方式差的。
参考
原论文链接