在情感分类任务中,数据集的标签分布往往是极度不平衡的。以我目前手上的这个二分类任务来说,正例样本14.4万个:负例样本166.1万 = 1 :11.5。很显然这是一个极度不平衡的数据集,假设我把样本全部预测为负,那准确率也高达92%,但这么做没有意义。
那么我们如何处理这个不平衡数据集呢?
因为我用的是神经网络,我不希望减少训练样本,因此我不会采用下采样的方式。有三个方向可以尝试:
- 使用自定义的loss函数
- 设置class weight
- 设置sample weight
这里,我将尝试三种不同的loss函数并进行对比。
(一)三种损失函数
令表示样本的真实标签,则。令表示sigmoid输出的预测类别为1的概率,显然。下面我们给出三种损失函数的定义。
1. Binary crossentropy
2. 修正的Binary crossentropy
这个损失函数来自苏神的两篇文章:
【1】文本情感分类(四):更好的损失函数
【2】何恺明大神的「Focal Loss」,如何更好地理解?
引入单位跃阶函数
取定阈值(为可调超参数,原则上大于0.5均可),则:
这里我稍微修改了一点点,以使得损失函数更加对称。
这个损失函数跟Binary Crossentropy比起来,就是多了这个调节因子,我们来分析一下这个公式:
- 当正样本的预测概率大于m时,根据的定义,这一项的损失就变为了0;当正样本的预测概率小于m时,保持这一项损失不变;
- 当负样本的预测概率小于1-m时,这一项的损失也变为了0;当负样本的预测概率大于1-m时,保持这一项损失不变。
也就是说,这个损失函数将焦点放在了分类错误的样本上面,希望能够把更多的样本正确分类。
3. Focal Loss
来自论文Focal Loss for Dense Object Detection
其中为权重因子,为调节参数。
我们来分析一下这个函数,考虑的情形:
- 当正样本的预测概率接近1时(我们希望的),则接近0,则就变得很小很小。也就是说,当某个样本分类合理时,函数会对其损失进行打折(down weighting),则打折的幅度依赖于参数。
- 当正样本的预测概率接近0时(我们不希望的),则接近1,因此会变小一点点,但跟上面的情况比起来,其实相当于是放大了,因为大小是相对的。
对于负样本,同理可分析,此处略过。
再来考虑这个参数,它其实是一个权重的调节因子,用于平衡正负样本的损失贡献。但由于的存在,我们很难从实际数据中得到指导来设置这个参数,更多可能要去尝试和调参。一般情况下,先令。
Focal Loss函数对容易分类的样本进行down weighting,聚焦于难分类的样本上。跟苏神的那个思路类似,却更加高明。
那么在类别不均衡的分类任务中,这个损失函数到底怎么起作用呢?
我们知道, 而我的任务中负样本占了绝大多数,对模型来说,它们绝大部分是很好分类的样本,因此它们的损失贡献会大打折扣,从而使模型聚焦在难分类的样本上面,包括绝大部分的正样本。同时,这个参数也能起到平衡正负样本损失的作用。
(二) 损失函数代码
1. 修正的Binary crossentropy(keras版本)
import keras.backend as K
margin = 0.8
theta = lambda t: (K.sign(t)+1.)/2.
def variant_crossentropy_loss(y_true, y_pred):
return - theta(margin - y_pred) * y_true * K.log(y_pred + 1e-9)
- theta(y_pred - 1 + m) * (1 - y_true) * K.log(1 - y_pred + 1e-9))
2. Focal Loss(tensorflow版本)
由于网上的代码都是多分类的(基于softmax输出的),这里我写了一个二分类的(基于sigmoid输出),同时我还加了一个rescale的flag来控制损失函数的量级,单任务学习中,这个flag按照默认的False即可。
import tensorflow as tf
def variant_focal_loss(gamma=2., alpha=0.5, rescale = False):
gamma = float(gamma)
alpha = float(alpha)
def focal_loss_fixed(y_true, y_pred):
"""
Focal loss for bianry-classification
FL(p_t)=-rescaled_factor*alpha_t*(1-p_t)^{gamma}log(p_t)
Notice:
y_pred is probability after sigmoid
Arguments:
y_true {tensor} -- groud truth label, shape of [batch_size, 1]
y_pred {tensor} -- predicted label, shape of [batch_size, 1]
Keyword Arguments:
gamma {float} -- (default: {2.0})
alpha {float} -- (default: {0.5})
Returns:
[tensor] -- loss.
"""
epsilon = 1.e-9
y_true = tf.convert_to_tensor(y_true, tf.float32)
y_pred = tf.convert_to_tensor(y_pred, tf.float32)
model_out = tf.clip_by_value(y_pred, epsilon, 1.-epsilon) # to advoid numeric underflow
# compute cross entropy ce = ce_0 + ce_1 = - (1-y)*log(1-y_hat) - y*log(y_hat)
ce_0 = tf.multiply(tf.subtract(1., y_true), -tf.log(tf.subtract(1., model_out)))
ce_1 = tf.multiply(y_true, -tf.log(model_out))
# compute focal loss fl = fl_0 + fl_1
# obviously fl < ce because of the down-weighting, we can fix it by rescaling
# fl_0 = -(1-y_true)*(1-alpha)*((y_hat)^gamma)*log(1-y_hat) = (1-alpha)*((y_hat)^gamma)*ce_0
fl_0 = tf.multiply(tf.pow(model_out, gamma), ce_0)
fl_0 = tf.multiply(1.-alpha, fl_0)
# fl_1= -y_true*alpha*((1-y_hat)^gamma)*log(y_hat) = alpha*((1-y_hat)^gamma*ce_1
fl_1 = tf.multiply(tf.pow(tf.subtract(1., model_out), gamma), ce_1)
fl_1 = tf.multiply(alpha, fl_1)
fl = tf.add(fl_0, fl_1)
f1_avg = tf.reduce_mean(fl)
if rescale:
# rescale f1 to keep the quantity as ce
ce = tf.add(ce_0, ce_1)
ce_avg = tf.reduce_mean(ce)
rescaled_factor = tf.divide(ce_avg, f1_avg + epsilon)
f1_avg = tf.multiply(rescaled_factor, f1_avg)
return f1_avg
return focal_loss_fixed
(三)结果对比
我采用了双向GRU模型,在保持模型以及数据不变的情况下,仅改变损失函数以对比不同损失函数在我任务上的表现,由于Local CV是我关注的最终metric,我使用Local CV作为early stopping的依据,若两个epoch后Local CV没有得到提升,则模型停止训练,并取Local CV最高的模型作为预测模型。
以下是各个损失函数在验证集上表现,我取了三个维度:
- Accuracy(阈值为0.5)
- AUC(Area Under the Curve)
- Local CV(本质是多个AUC的加权平均)
根据以上数据,我们可以得到如下结论:
- Focal Loss在我的任务上获得了最大的Local CV,比带class weight的Binary Crossentropy损失高出3个千分点。Focal Loss真是一个优秀的损失函数!
- Variant Crossentropy这个损失函数获得了最高的Accuracy,但是AUC和Local CV都很低,显然不适合我手中的任务。根据Variant Crossentropy的公式,其实也可推断,这个损失函数是在优化Accuracy。对于关注正确率的任务,这个损失函数应该是不错的选择。
参考资料:
【1】 非平衡数据集 focal loss 多类分类
【2】Focal Loss for Dense Object Detection
【3】文本情感分类(四):更好的损失函数
【4】何恺明大神的「Focal Loss」,如何更好地理解?