如何处理不平衡数据集的分类任务

在情感分类任务中,数据集的标签分布往往是极度不平衡的。以我目前手上的这个二分类任务来说,正例样本14.4万个:负例样本166.1万 = 1 :11.5。很显然这是一个极度不平衡的数据集,假设我把样本全部预测为负,那准确率也高达92%,但这么做没有意义。

那么我们如何处理这个不平衡数据集呢?

因为我用的是神经网络,我不希望减少训练样本,因此我不会采用下采样的方式。有三个方向可以尝试:

  • 使用自定义的loss函数
  • 设置class weight
  • 设置sample weight

这里,我将尝试三种不同的loss函数并进行对比。

(一)三种损失函数

y表示样本的真实标签,则y\in \{0,1\}。令\hat y表示sigmoid输出的预测类别为1的概率,显然\hat y \in [0,1]。下面我们给出三种损失函数的定义。

1. Binary crossentropy

L_{ce} = - ylog(\hat y) - (1-y)log(1-\hat y)

2. 修正的Binary crossentropy

这个损失函数来自苏神的两篇文章:
【1】文本情感分类(四):更好的损失函数
【2】何恺明大神的「Focal Loss」,如何更好地理解?

引入单位跃阶函数
\theta(x) = \begin{cases} 1, &x>0 \\ \frac 12, &x=0 \\ 0, &x<0 \end{cases}
取定阈值m=0.6(m为可调超参数,原则上大于0.5均可),则:
L_{vce} = -\theta(m-\hat y)ylog(\hat y) - \theta(\hat y -1+m)(1-y)log(1-\hat y)

这里我稍微修改了一点点,以使得损失函数更加对称。

这个损失函数跟Binary Crossentropy比起来,就是多了\theta这个调节因子,我们来分析一下这个公式:

  • 当正样本的预测概率大于m时,根据\theta的定义,这一项的损失就变为了0;当正样本的预测概率小于m时,保持这一项损失不变;
  • 当负样本的预测概率小于1-m时,这一项的损失也变为了0;当负样本的预测概率大于1-m时,保持这一项损失不变。

也就是说,这个损失函数将焦点放在了分类错误的样本上面,希望能够把更多的样本正确分类。

3. Focal Loss

来自论文Focal Loss for Dense Object Detection

L_{fl} = - \alpha(1-\hat y)^\gamma ylog(\hat y) - (1-\alpha){\hat y}^\gamma(1-y)log(1-\hat y)
其中\alpha为权重因子,\gamma为调节参数。

我们来分析一下这个函数,考虑\alpha=0.5, \gamma=2的情形:

  • 当正样本的预测概率\hat y接近1时(我们希望的),则1-\hat y接近0,则(1-\hat y)^\gamma就变得很小很小。也就是说,当某个样本分类合理时,函数会对其损失进行打折(down weighting),则打折的幅度依赖于参数\gamma
  • 当正样本的预测概率\hat y接近0时(我们不希望的),则1-\hat y接近1,因此(1-\hat y)^\gamma会变小一点点,但跟上面的情况比起来,其实相当于是放大了,因为大小是相对的。

对于负样本,同理可分析,此处略过。

再来考虑\alpha这个参数,它其实是一个权重的调节因子,用于平衡正负样本的损失贡献。但由于\gamma的存在,我们很难从实际数据中得到指导来设置这个参数,更多可能要去尝试和调参。一般情况下,先令\alpha = 0.5

Focal Loss函数对容易分类的样本进行down weighting,聚焦于难分类的样本上。跟苏神的那个思路类似,却更加高明。

那么在类别不均衡的分类任务中,这个损失函数到底怎么起作用呢?

我们知道Loss = \frac 1n \sum_{y} L_{fl}, 而我的任务中负样本占了绝大多数,对模型来说,它们绝大部分是很好分类的样本,因此它们的损失贡献会大打折扣,从而使模型聚焦在难分类的样本上面,包括绝大部分的正样本。同时,\alpha这个参数也能起到平衡正负样本损失的作用。

(二) 损失函数代码

1. 修正的Binary crossentropy(keras版本)

import keras.backend as K

margin = 0.8
theta = lambda t: (K.sign(t)+1.)/2.

def variant_crossentropy_loss(y_true, y_pred):
    return  - theta(margin - y_pred) * y_true * K.log(y_pred + 1e-9)
            - theta(y_pred - 1 + m) * (1 - y_true) * K.log(1 - y_pred + 1e-9))

2. Focal Loss(tensorflow版本)

由于网上的代码都是多分类的(基于softmax输出的),这里我写了一个二分类的(基于sigmoid输出),同时我还加了一个rescale的flag来控制损失函数的量级,单任务学习中,这个flag按照默认的False即可。

import tensorflow as tf

def variant_focal_loss(gamma=2., alpha=0.5, rescale = False):

    gamma = float(gamma)
    alpha = float(alpha)

    def focal_loss_fixed(y_true, y_pred):
        """
        Focal loss for bianry-classification
        FL(p_t)=-rescaled_factor*alpha_t*(1-p_t)^{gamma}log(p_t)
        
        Notice: 
        y_pred is probability after sigmoid

        Arguments:
            y_true {tensor} -- groud truth label, shape of [batch_size, 1]
            y_pred {tensor} -- predicted label, shape of [batch_size, 1]

        Keyword Arguments:
            gamma {float} -- (default: {2.0})  
            alpha {float} -- (default: {0.5})

        Returns:
            [tensor] -- loss.
        """
        epsilon = 1.e-9  
        y_true = tf.convert_to_tensor(y_true, tf.float32)
        y_pred = tf.convert_to_tensor(y_pred, tf.float32)
        model_out = tf.clip_by_value(y_pred, epsilon, 1.-epsilon)  # to advoid numeric underflow
        
        # compute cross entropy ce = ce_0 + ce_1 = - (1-y)*log(1-y_hat) - y*log(y_hat)
        ce_0 = tf.multiply(tf.subtract(1., y_true), -tf.log(tf.subtract(1., model_out)))
        ce_1 = tf.multiply(y_true, -tf.log(model_out))

        # compute focal loss fl = fl_0 + fl_1
        # obviously fl < ce because of the down-weighting, we can fix it by rescaling
        # fl_0 = -(1-y_true)*(1-alpha)*((y_hat)^gamma)*log(1-y_hat) = (1-alpha)*((y_hat)^gamma)*ce_0
        fl_0 = tf.multiply(tf.pow(model_out, gamma), ce_0)
        fl_0 = tf.multiply(1.-alpha, fl_0)
        # fl_1= -y_true*alpha*((1-y_hat)^gamma)*log(y_hat) = alpha*((1-y_hat)^gamma*ce_1
        fl_1 = tf.multiply(tf.pow(tf.subtract(1., model_out), gamma), ce_1)
        fl_1 = tf.multiply(alpha, fl_1)
        fl = tf.add(fl_0, fl_1)
        f1_avg = tf.reduce_mean(fl)
        
        if rescale:
            # rescale f1 to keep the quantity as ce
            ce = tf.add(ce_0, ce_1)
            ce_avg = tf.reduce_mean(ce)
            rescaled_factor = tf.divide(ce_avg, f1_avg + epsilon)
            f1_avg = tf.multiply(rescaled_factor, f1_avg)
        
        return f1_avg
    
    return focal_loss_fixed

(三)结果对比

我采用了双向GRU模型,在保持模型以及数据不变的情况下,仅改变损失函数以对比不同损失函数在我任务上的表现,由于Local CV是我关注的最终metric,我使用Local CV作为early stopping的依据,若两个epoch后Local CV没有得到提升,则模型停止训练,并取Local CV最高的模型作为预测模型。

以下是各个损失函数在验证集上表现,我取了三个维度:

  • Accuracy(阈值为0.5)
  • AUC(Area Under the Curve)
  • Local CV(本质是多个AUC的加权平均)

根据以上数据,我们可以得到如下结论:

  1. Focal Loss在我的任务上获得了最大的Local CV,比带class weight的Binary Crossentropy损失高出3个千分点。Focal Loss真是一个优秀的损失函数!
  2. Variant Crossentropy这个损失函数获得了最高的Accuracy,但是AUC和Local CV都很低,显然不适合我手中的任务。根据Variant Crossentropy的公式,其实也可推断,这个损失函数是在优化Accuracy。对于关注正确率的任务,这个损失函数应该是不错的选择。

参考资料:

【1】 非平衡数据集 focal loss 多类分类
【2】Focal Loss for Dense Object Detection
【3】文本情感分类(四):更好的损失函数
【4】何恺明大神的「Focal Loss」,如何更好地理解?

最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 203,456评论 5 477
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 85,370评论 2 381
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 150,337评论 0 337
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 54,583评论 1 273
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 63,596评论 5 365
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 48,572评论 1 281
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 37,936评论 3 395
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 36,595评论 0 258
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 40,850评论 1 297
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 35,601评论 2 321
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 37,685评论 1 329
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 33,371评论 4 318
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 38,951评论 3 307
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 29,934评论 0 19
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 31,167评论 1 259
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 43,636评论 2 349
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 42,411评论 2 342