字节跳动AI Lab最新的工作,研究了如何把ViT做的更深的问题。之前的ViT采用的都是固定的12层,所不同的是head的数目和embedding的通道数,由此构成small、medium以及large版本。但是关于ViT层数的问题,则很少有探讨。
本文首先将ViT由12层加深到32层,发现随着网络层数的加深,性能并没有获得相应幅度的上升,甚至还有下降:
类似的网络退化问题在cnn中也有出现,cnn中通过加入residual connection来解决网络退化。本文中则观察发现ViT的退化是由于attention collapse,即随着层数的加深,attention map的变化越来越小,相应的feature map也开始停止学习,因而随着层数加深性能增长逐渐停滞。
作者的解决方案其实很简单,观察发现虽然同一个token在不同层的attention map差别小,但是同一层不同head之间差距还是很大,因此作者额外加入一个HxH大小的矩阵,来对attention map进行re-attention,利用其他head的attention map来对当前attention map进行增强。
反应在公式上的变化就是在与V相乘之前先与theta矩阵相乘,看起来很简洁了。消融实验结果显示,在32层ViT上,本文所提的解决方案(DeepViT)可以比vanilla ViT有1.6的提升,在深层ViT上的提升还是比较显著的。
作者解释re-attention起作用的原因在于在不同head之间进行了interaction,而不是单纯的解决over smoothing,因而可以编码更丰富的信息,当然通过可视化结果也可以看到,加入re-attention之后,block的相似性也显著降低了。