目前transformer从语言到视觉任务的挑战主要是由于这两个领域间的差异:
- 1、尺度变化大
- 2、高分辨率的输入
为了解决以上两点,我们提出了层级Transformer,通过滑动窗口提取特征的方式将使得self.attention
的计算量降低为和图像尺寸的线性相关。
简介
我们观察到将语言领域迁移到视觉领域的主要问题可以被总结为两种:
- 1、不同于word token,它的尺度是固定的,但是视觉领域的尺度变化非常剧烈
- 2、相对于上下文中的words,图片有着更高分辨率的像素,计算量会随着图片的尺寸成平方倍的增长。
结构
以上是论文中结构图,每一个
stage feature map
的尺寸都会减半。易知主要分为四个模块:
- Patch Partition
- Linear Embedding
-
Swin Transformer Block(主要模块):
- W-MSA:
regular window partition
和mutil-head self attention
- SW-MSA:
shift window partition
和mutil-head self attention
- W-MSA:
- Patch Merging
1、Patch Partition 和 Linear Embedding
在源码实现中两个模块合二为一,称为PatchEmbedding
。输入图片尺寸为 的RGB图片,将4x4x3
视为一个patch,用一个linear embedding 层将patch转换为任意dimension(通道)的feature。源码中使用4x4的stride=4的conv实现。->
class PatchEmbed(nn.Module):
r""" Image to Patch Embedding
Args:
img_size (int): Image size. Default: 224.
patch_size (int): Patch token size. Default: 4.
in_chans (int): Number of input image channels. Default: 3.
embed_dim (int): Number of linear projection output channels. Default: 96.
norm_layer (nn.Module, optional): Normalization layer. Default: None
"""
def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
super().__init__()
img_size = to_2tuple(img_size)
patch_size = to_2tuple(patch_size)
patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
self.img_size = img_size
self.patch_size = patch_size
self.patches_resolution = patches_resolution
self.num_patches = patches_resolution[0] * patches_resolution[1]
self.in_chans = in_chans
self.embed_dim = embed_dim
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
if norm_layer is not None:
self.norm = norm_layer(embed_dim)
else:
self.norm = None
def forward(self, x):
B, C, H, W = x.shape
# FIXME look at relaxing size constraints
assert H == self.img_size[0] and W == self.img_size[1], \
f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
x = self.proj(x).flatten(2).transpose(1, 2) # B Ph*Pw C
if self.norm is not None:
x = self.norm(x)
return x
2、Swin Transformer Block
这是这篇论文的核心模块。
- 如何解决计算量随着输入尺寸的增大成平方倍的增长? 抛弃传统的transformer基于全局来计算注意力的方法,将输入划分为不同的窗口,分别对每个窗口(window)施加注意力。
- 仅仅对窗口(window)单独施加注意力,如何解决窗口(window)之间的信息流动?交替使用
W-MSA
和SW-MSA
模块,因此SwinTransformerBlock
必须是偶数。如下图所示:
整体流程如下:- 先对特征图进行LayerNorm
- 通过self.shift_size决定是否需要对特征图进行shift
- 然后将特征图切成一个个窗口
- 计算Attention,通过self.attn_mask来区分Window Attention还是Shift Window Attention
- 将各个窗口合并回来
- 如果之前有做shift操作,此时进行reverse shift,把之前的shift操作恢复
- 做dropout和残差连接
- 再通过一层LayerNorm+全连接层,以及dropout和残差连接
2.1、window partition
window partition
分为regular window partition
和shift window partition
,对应于W-MSA
和SW-MSA
。通过窗口划分,将输入的feature map
转换为num_windows*B, window_size, window_size, C
,其中 num_windows = H*W / window_size / window_size
。然后resize 到 num_windows*B, window_size*window_size, C
进行attention。源码如下:
def window_partition(x, window_size):
"""
Args:
x: (B, H, W, C)
window_size (int): window size
Returns:
windows: (num_windows*B, window_size, window_size, C)
"""
B, H, W, C = x.shape
x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
return windows
-
Layer1
是regular window partition
,窗口的大小是4x4,将图片分成了4个窗口。 -
Layer2
是shift window partition
,为了保证不同窗口的信息流动,起始点从(windows_size//2, windows_size//2)开始进行划分,将图片分成了9个窗口。可以看到移位后的窗口包含了原本相邻窗口的元素。但是同时也引入了新的问题,窗口大小不一致的问题,有2x2、2x4、4x2、4x4,最简单的方法就是统一padding到4x4,但是窗口数量由4增加至9,计算量变大了2.25倍。因此作者提出了cycle shift
去解决这个问题。
以下的示例图片来自于:https://mp.weixin.qq.com/s/8x1pgRLWaMkFSjT7zjhTgQ
首先对窗口进行shift window partition
,得到左图部分。不进行padding
,而是采用滚动的方式调整窗口,源码中用torch.roll()
函数实现,得到了右图。这时候得到了和regular window partition
一样的4个2x2大小的window
,不同的是,在一个2x2的windows
区域内是不连续的(index
不一样)。
我们希望在计算Attention的时候,让具有相同index
进行计算,而忽略不同index QK计算结果。因此我们为其添加上mask。源码计算mask
实现如下:
H, W = self.input_resolution
img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1
h_slices = (slice(0, -self.window_size),
slice(-self.window_size, -self.shift_size),
slice(-self.shift_size, None))
w_slices = (slice(0, -self.window_size),
slice(-self.window_size, -self.shift_size),
slice(-self.shift_size, None))
cnt = 0
for h in h_slices:
for w in w_slices:
img_mask[:, h, w, :] = cnt
cnt += 1
mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1
mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
2.2、W-MSA
由regular window partition
模块 和 mutil-head self attention
模块组成。
W-MSA相比于直接使用MSA主要是为了降低计算量。传统的transformer都是基于全局来计算注意力,因此计算复杂度非常高。但是swin transformer通过对每个窗口施加注意力,从而减少了计算量。attention的主要计算过程如下:
假设每一个window
的区块大小为,输入的尺寸为,以下为原始的和的计算复杂度:
- 对于:对输入的
feature map
做全局attention,、、的计算量分别是,和的计算量分别是,的计算量是。 - 对于:在
windows
内的大小的区域内做attention,feature map
会被划分为个windows
,每个windows
的尺寸为。、、的计算量分别是,和的计算量的分别是,的计算量是。因此和输入尺寸成线性关系。
2.3、SW-MSA
虽然降低了计算量,但是由于将attention限制在window
内,因此不重合的window
缺乏联系,限制了模型的性能。因此提出了模块。在MSA
前面加上一个cycle shift window partition
3、Patch Merging
swin transformer中没有使用pooling
进行下采样,而是使用了和yolov5中的focus
层进行feature map
的下采样。 -> ,在使用一个全连接层->,在一个stage中将feature map的高宽减半,通道数翻倍。
class PatchMerging(nn.Module):
r""" Patch Merging Layer.
Args:
input_resolution (tuple[int]): Resolution of input feature.
dim (int): Number of input channels.
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
"""
def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
super().__init__()
self.input_resolution = input_resolution
self.dim = dim
self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
self.norm = norm_layer(4 * dim)
def forward(self, x):
"""
x: B, H*W, C
"""
H, W = self.input_resolution
B, L, C = x.shape
assert L == H * W, "input feature has wrong size"
assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even."
x = x.view(B, H, W, C)
x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C
x = self.norm(x)
x = self.reduction(x)
return x
不同尺寸的网络结构
基准模型结构命名为Swin-B
,模型大小和计算复杂度和ViT-B
/DeiT-B
相近。同时我们也提出了Swin-T
,Swin-S
和 Swin-L
,分别对应0.25×
, 0.5×
和 2×
倍的模型尺寸和计算复杂度。Swin-T
和 Swin-S
的计算复杂度分别和ResNet-50
、ResNet-101
相近。默认设置为7。代表第一层隐藏层的数量。
- Swin-T: C = 96, layer numbers = {2, 2, 6, 2}
- Swin-S: C = 96, layer numbers ={2, 2, 18, 2}
- Swin-B: C = 128, layer numbers ={2, 2, 18, 2}
- Swin-L: C = 192, layer numbers ={2, 2, 18, 2}
不同数据集的实验结果
1、ImageNet
2、COCO Object Detection
- 在不同的模型上使用swin transformer 作为特征提取网络
- 在cascade mask rcnn上使用swin transformer 作为backbone
-
直接对比其他目标检测网络