这篇文章的目的是为了解决transformer 处理长序列任务遇到的计算复杂度较高的问题。为了解决这个问题,许多工作聚焦于探索更有效的注意力机制,比如linear attention,但这类方法往往存在着以下三个缺陷:
- inferior quality. linear attention 相对于vanilla attention 往往会带来明显的指标掉点。
- overhead in practice. efficient attention往往是进行了复杂的layout 变换,这种操作计算复杂度体现不出来,但在实际应用中往往带来较大的时间开销。
- inefficient auto-regressive training. 这一块主要是针对时序上使用的efficient attention,比如RNN-style 的序列状态更新,因为时序的存在导致并行度降低,训练上速度较慢。
所以本文不是直接去拟合MHSA,而是设计了一种新的quad attention结构GAU,这种结构对于attention不是特别敏感,从而能够进一步使用线性时间去近似拟合。最终这种线性拟合的结构被命名为 FLASH (Fast Linear Attention With a Single Head.)
下面来分别看一下GAU和FLASH的设计。
1. GAU (Gated Attention Unit)
GAU 是从GLU(Gated Linear Unit) 发展而来的,GLU的定义如下
这个式子可以理解成对relu的输出变量加了一个平方的非线性变换,当然这里只是类比,因为 是放在外面的,如果是 就更明显了,只是relu放在里面和外面还是不完全等价的,相当于更复杂的非线性变换。GLU的解释是v对u做了一次gate。
GAU将上式中的V替换成了quad attention的输出,见下图
这里需要注意的几点:
- GAU中的attention部分 使用的, 而不是softmax, 这个操作和我们上面解释GLU差不多; 作者验证在自然语言上同样有效,但CV上还不确定。
- GAU中的QKV都是single head的, Q、K的生成方式不同,采用类似于BN中可学习参数的那种形式的scale_offset 算子
-
GAU的参数量相对于MHSA要小近似一半,因此同样的参数量和速度下,可以多堆叠一倍的GAU。
下面两张表分别对比了结构中一些操作的影响,gating这个文章中没细说,加上gating操作,参数量反而降低了,推测应该是一种过滤的操作?如果是GLU中的gate操作,参数量应该增大吧。
2. Fast Linear Attention with GAU
对GAU进行线性逼近,其出发点基于两点观测:
- GAU的门限机制允许一个更弱的,比如单头,无softmax (见Table1和Table2的对比)注意力机制实现相似的性能;
- 相同的计算代价下,MHSA+MLP 差不多等于2个GAU,而注意力逼近一般需要更多的layer去捕获更全面的依赖,所以GAU是一个好的选择。
目前存在的线性复杂度的attention可以大致划分为两类: Partial Attention和Linear Attention。
Partial Attention,包括划分window,local+sparse, axial, hash, clustering等,这个方法虽然表现不如full attention,但理论上确实能够对long sequence有较好的速度表现。但问题关键在于实际应用中往往需要gather,scatter,slice和cat等layout操作,这类操作并行度较差,对硬件不友好,导致实际场景中速度要慢很多。
Linear Attention,主要是去掉Softmax操作,从而能够利用矩阵乘法结合律先计算K^TV, 将计算复杂度降下来,同时对于NLP中的时序任务,随着序列逐步增长 是个累加的过程,每次只需要计算当前时刻的加上历史累积的。相比于quad attention每次都需要计算全部 计算量显然小很多。但是,对于长序列而言,虽然每次计算量小很多,但是这种序列计算就带来RNN类型存在的先天时序性,没法并行操作,只有执行完t-1步才能执行第t步,所以计算复杂度小但是计算时间反而有可能更长。对于移动芯片来说,存储器较小会使这种情况加剧。
3.1 Mixed Chunk Attention
本文提出的线性逼近结构。首先将序列划分为G个不重叠的chunk,每个chunk生成对应的, 由通过per-dim的scaling和offset 生成 , 每个chunk会同时参与local attention 和 global attention, local attention的话
而global attention则划分为non-causal和 causal,即是否时序上算, causal 会带来训练时间的大幅增加。
那么最终在当前帧的输出是把local 和 global的attention 叠加在一起放到GAU中,
伪代码如下:
这种操作其实是quad attention和linear attention的这种方案。
一些讨论:
- chunk的划分能加快auto-regressive training 的过程
- overlapping local attention能够改善质量,但是会存在memory re-formatting operations导致实际运行速度差很多。另外作者认为optimal partial attention 是任务相关的,但non-overlapping 是通用的。
- 和combiner相比,combiner也划分chunk,但每个chunk内使用的quad local attention,这样FLASH在chunk内就能够允许更长的chunk。
3. 实验
实验在NLP上做的,这里只看下结果。
时间对比:
4. 结论
这篇文章先是提出一种GAU,然后GAU对attention的依赖较小,进而可以把GAU中的attention替换成linear attention。这个倒是可以在CV上尝试,GAU可以多堆一倍。另一点mixed-chunk 在CV上感觉用处不大,倒是可以在track任务上使用,track的memory bank上更新query。