针对文本数据增强的方法【有代码】

论文标题:SSMix: Saliency-Based Span Mixup for Text Classification

论文链接:https://arxiv.org/pdf/2106.08062.pdf

不需要翻墙论文链接:文献全文 - 学术范 (xueshufan.com)

论文代码:https://github.com/clovaai/ssmix

论文作者:{soyoungyoon etc.}

论文摘要

数据增强已证明对各种计算机视觉任务是有效的。尽管文本取得了巨大的成功,但由于文本由可变长度的离散标记组成,因此将混合应用于NLP任务一直存在障碍。在这项工作中,我们提出了SSMix,一种新的混合方法,其中操作是对输入文本执行的,而不是像以前的方法那样对隐藏向量执行的。SSMix通过基于跨度的混合,综合一个句子,同时保留两个原始文本的位置,并依赖于显著性信息保留更多与预测相关的标记。通过大量的实验,我们实证验证了我们的方法在广泛的文本分类基准上优于隐藏级混合方法,包括文本隐含、情感分类和问题类型分类。

数据增强的效果已经在各种计算机视觉任务中被证实是有效的。尽管数据增强非常有效,由于文本是由变长的离散字符组成的,所以将mixup应用与NLP任务一直存在障碍。在本篇论文,作者提出了SSMix算法,一种针对输入文本增强的mixup算法,而非之前针对隐藏向量的方法。SSMix通过跨度混合( span-based mixing)在保留原始两个文本的条件下合成一个句子,同时保留两个原始文本的位置,并依赖于显著性信息保留更多与预测相关的标记。通过大量的实验,论文验证了该算法在广泛的文本分类基准上优于隐藏级混合方法,包括文本推断、情感分类和问题类型分类任务。

算法简介

由于数据收集与标志的昂贵成本,数据增强在自然语言处理(NLP)中越来越重要。其中一些已往研究包括基于简单的规则和模型来生成类似的文本。比如通过标准方法或先进的训练方法与原始样本联合进行训练,也有基于混淆(mixup)插值文本和标签进行增强。

Mixup及其变体训练算法成为计算机视觉中常用的正则化方法,用来提高神经网络的泛化能力。混合方法分为输入级混合和隐藏级混合( hidden-level mixup),两者取决于混合操作的位置。输入级混合是一种比隐藏级混合更普遍的方法,因为它的简单性和能够捕获局部性,从而具有更好的准确性。

由于文本数据的离散性和可变的序列长度,在NLP中应用mixup比在计算机视觉中更具有挑战性和难度。因此,之前大多数关于文本混合的尝试将mixup应用于嵌入向量,如嵌入或中间表示。然而根据计算机视觉的增强直观感受,输入级混合一般比隐藏级混合有优势。这一动机鼓励作者对探究文本数据的输入级混淆方法。

在这项工作中,作者提出了SSMix(图1),一种新的输入级结合跨度(Span)的显著性混合数据增强法算法。首先,作者通过用另一个文本中的跨度替换连续的标记来进行混淆,这一灵感来自CutMixarXiv,在混合文本中保留两个源文本的位置。其次,选择一个要替换的跨度,并基于显著性信息进行替换,以使混合文本包含与输出预测更相关的标记,这在语义上可能很重要。文本的输入级方法不同于隐级混合方法,当当前的隐级混合方法线性插值原始隐向量,我们的方法在输入级上混合文本字符,产生非线性输出。同时,利用显著性值从每个句子中选择跨度,并离散地定义跨度的长度和混合比,这是与隐藏级别混合增强区别的地方。


SSMix已经通过大量的文本分类基准实验被证明是有效的。特别强调的是,论文证明了输入级混合方法一般要优于隐层混合方法。论文还展示了在进行文本混合增强的同时,在跨度水平上使用显著性信息和限制标记选择的重要性。

SSMix算法

SSMix基本原理为:给定两个文本和,通过将文本的片段替换为来自另一文本的显著信息片段生成得到新的文本。同时,对于新文本,基于两个文本标签和重新为新文本设置一个新的标签。最后可以使用这个生成的增强虚拟样本(,)来进行训练模型。

Saliency:显著性信息

Saliency衡量了文本数据的每个字符对最终结果预测的影响。在以往研究中基于梯度的方法被广泛用于显著性计算,文本同样计算了分类损失相对于输入嵌入的梯度,并使用其大小作为显著性:。文中应用l2范数来获得一个梯度向量的大小,代表着每个字符的类似于PuzzleMix的显著性。

Mixing Text:文本合成

之前提到过,Mixing Text主要是是指两个文本序列和如何合成新的文本。大致思路是根据梯度显著性计算方法得到两个文本中每个字符的显著性分数,然后在文本中选取一个显著性最低的片段,长度为,在文本中选取一个显著性最低的片段,长度为。长度设置为==,其中为mixup比例参数。最后生成新文本w为,其中和为原始文本中替换片段的左右的两部分。

Sample span length:相等片段长度

本文将原始()的长度和替换()跨度设置为相同的,主要原因是使用不同长度的span(片段)将导致冗余和语义不明确的mixup 转换。另外,计算不同长度的span之间的mixup 比列也过于复杂。在以往研究中也采用了这种相同大小的替换策略。在替换span长度相同的情况下,论文的SSMix算法能够使显著性的效果最大化。由于SSMix不限制字符的位置,可以同时选择最显著的span和被替换的最不显著片段。如图片1中,in this在文本中是不显著的,transcedent love在文本中是最显著的,那么可以用transcedent love替换in this。

Mixing Text:标签合成

算法1 展示了如何利用原始样本对来计算增广样本的混合损失。公式中计算了增强输出logit相对于每个样本的原始目标标签的交叉熵损失,并通过加权和进行组合,因此SSMix算法与数据集标签个数是不相关的,在任何数据集上,输出标签比例是通过两个原始标签的线性组合来计算。

Paired sentence tasks:句子对任务

实验设置

实验数据集

论文实验数据集有文本分类和句子对分类任务:


对比实验

论文将SSMix与三个基线进行了比较:(1) standard training without mixup,(2)EmbedMixMix(3)TMix。


与基线和消融研究的实验结果进行了比较。所有的准确率值都是使用不同种子的5次运行的平均精度(%)。MNLI表示MNLI-不匹配的开发集的准确性。论文报告了GLUE的验证精度,TREC的测试精度,以及ANLI的有效(上)/测试(较低)精度,可以看出SSMix在大部分数据集效果要优于其他混合增强算法。


论文总结

与隐层混合方法相比,SSMix在具有足够数据量的数据集上充分证明了其有效性。由于SSMix是一个离散的组合,而不是两个数据样本的线性组合,它在一个合成空间上创建数据样本的范围大于隐藏级别的混合。论文假设,大量的数据有助于更好地在合成空间中进行表示。

SSMix对于多个类标签数据集(TREC、ANLI、MNLI、QNLI)尤其有效。因此,在没有混合的训练条件下,SSMix在TREC-fine(47个标签)上的精度增益远高于TRECcrare(6个标签),+分别为3.56和+为0.52。具有多个总类标签的数据集增加了在混合源的随机抽样中被选择交叉标签的可能性,所以可以认为在这些多标签分类数据集中的混合性能会显著提高

在成对句子任务上具有显著优势,如文本隐含或相似性分类。现有的方法(隐藏层混合)在隐藏层上应用混合,而不考虑特殊的标记,即[SEP]、[CLS]。这些方法可能会丢失关于句子开头的信息或句子对的适当分离。相比之下,SSMix在应用混合时可以考虑单个字符的特性。 -SSMix 及其变体的消融研究结果表明,随着对片段约束和显著性信息的增加,性能有所提高。在混合操作中添加片段约束受益于更好的可定位能力,并且大多数显著的片段与相应的标签有更多的关系,而丢弃最小显著的片段,这些片段相对于原始标签在语义上不重要。其中,引入显著性信息对精度的贡献相对高于片段约束。

代码实现

import copy

import random

import torch

import torch.nn.functional as F

from .saliency import get_saliency

class SSMix:

    def __init__(self, args):

        self.args = args

    def __call__(self, input1, input2, target1, target2, length1, length2, max_len):

        batch_size = len(length1)

        if self.args.ss_no_saliency:

            if self.args.ss_no_span:

                inputs_aug, ratio = self.ssmix_nosal_nospan(input1, input2, length1, length2, max_len)

            else:

                inputs_aug, ratio = self.ssmix_nosal(input1, input2, length1, length2, max_len)

        else:

            assert not self.args.ss_no_span

            input2_saliency, input2_emb, _ = get_saliency(self.args, input2, target2)

            inputs_aug, ratio = self.ssmix(batch_size, input1, input2,

                                          length1, length2, input2_saliency, target1, max_len)

        return inputs_aug, ratio

    def ssmix(self, batch_size, input1, input2, length1, length2, saliency2, target1, max_len):

        inputs_aug = copy.deepcopy(input1)

        for i in range(batch_size):  # cut off length bigger than max_len ( nli task )

            if length1[i].item() > max_len:

                length1[i] = max_len

                for key in inputs_aug.keys():

                    inputs_aug[key][i][max_len:] = 0

                inputs_aug['input_ids'][i][max_len - 1] = 102

        saliency1, _, _ = get_saliency(self.args, inputs_aug, target1)

        ratio = torch.ones((batch_size,), device=self.args.device)

        for i in range(batch_size):

            l1, l2 = length1[i].item(), length2[i].item()

            limit_len = min(l1, max_len) - 2  # mixup except [CLS] and [SEP]

            mix_size = max(int(limit_len * (self.args.ss_winsize / 100.)), 1)

            if l2 < mix_size:

                ratio[i] = 1

                continue

            saliency1_nopad = saliency1[i, :l1].unsqueeze(0).unsqueeze(0)

            saliency2_nopad = saliency2[i, :l2].unsqueeze(0).unsqueeze(0)

            saliency1_pool = F.avg_pool1d(saliency1_nopad, mix_size, stride=1).squeeze(0).squeeze(0)

            saliency2_pool = F.avg_pool1d(saliency2_nopad, mix_size, stride=1).squeeze(0).squeeze(0)

            # should not select first and last

            saliency1_pool[0], saliency1_pool[-1] = 100, 100

            saliency2_pool[0], saliency2_pool[-1] = -100, -100

            input1_idx = torch.argmin(saliency1_pool)

            input2_idx = torch.argmax(saliency2_pool)

            inputs_aug['input_ids'][i, input1_idx:input1_idx + mix_size] = \

                input2['input_ids'][i, input2_idx:input2_idx + mix_size]

            ratio[i] = 1 - (mix_size / (l1 - 2))

        return inputs_aug, ratio

    def ssmix_nosal(self, input1, input2, length1, length2, max_len):

        inputs_aug = copy.deepcopy(input1)

        ratio = torch.ones((len(length1),), device=self.args.device)

        for idx in range(len(length1)):

            if length1[idx].item() > max_len:

                for key in inputs_aug.keys():

                    inputs_aug[key][idx][max_len:] = 0

                inputs_aug['input_ids'][idx][max_len - 1] = 102  # artificially add EOS token.

            l1, l2 = min(length1[idx].item(), max_len), length2[idx].item()

            if self.args.ss_winsize == -1:

                window_size = random.randrange(0, l1)  # random sampling of window_size

            else:

                # remove EOS & SOS when calculating ratio & window size.

                window_size = int((l1 - 2) *

                                  self.args.ss_winsize / 100.) or 1

            if l2 <= window_size:

                ratio[idx] = 1

                continue

            start_idx = random.randrange(0, l1 - window_size)  # random sampling of starting point

            if (l2 - window_size) < start_idx:  # not enough text for reference.

                ratio[idx] = 1

                continue

            else:

                ref_start_idx = start_idx

            mix_percent = float(window_size) / (l1 - 2)

            for key in input1.keys():

                inputs_aug[key][idx, start_idx:start_idx + window_size] = \

                    input2[key][idx, ref_start_idx:ref_start_idx + window_size]

            ratio[idx] = 1 - mix_percent

        return inputs_aug, ratio

    def ssmix_nosal_nospan(self, input1, input2, length1, length2, max_len):

        batch_size, n_token = input1['input_ids'].shape

        inputs_aug = copy.deepcopy(input1)

        len1 = length1.clone().detach()

        ratio = torch.ones((batch_size,), device=self.args.device)

        for i in range(batch_size): # force augmented output length to be no more than max_len

            if len1[i].item() > max_len:

                len1[i] = max_len

                for key in inputs_aug.keys():

                    inputs_aug[key][i][max_len:] = 0

                inputs_aug['input_ids'][i][max_len - 1] = 102

            mix_len = int((len1[i] - 2) * (self.args.ss_winsize / 100.)) or 1

            if (length2[i] - 2) < mix_len:

                mix_len = length2[i] - 2

            flip_idx = random.sample(range(1, min(len1[i] - 1, length2[i] - 1)), mix_len)

            inputs_aug['input_ids'][i][flip_idx] = input2['input_ids'][i][flip_idx]

            ratio[i] = 1 - (mix_len / (len1[i].item() - 2))

        return inputs_aug, ratio

©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 203,098评论 5 476
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 85,213评论 2 380
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 149,960评论 0 336
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 54,519评论 1 273
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 63,512评论 5 364
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 48,533评论 1 281
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 37,914评论 3 395
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 36,574评论 0 256
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 40,804评论 1 296
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 35,563评论 2 319
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 37,644评论 1 329
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 33,350评论 4 318
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 38,933评论 3 307
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 29,908评论 0 19
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 31,146评论 1 259
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 42,847评论 2 349
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 42,361评论 2 342

推荐阅读更多精彩内容