论文名称:Cross-Image Pixel Contrasting for Semantic Segmentation
这是一种将对比学习运用到全监督语义分割里的方法,主要起到辅助训练的作用,在实际推理部署的时候,原本用来对比学习的分支是去除的。
主要解决的问题:
- 模型训练的时候只考虑当前一张图像的内容,无法站在整个数据集的内容上考虑问题。
创新点:
- 基于像素到像素的对比(pixel to pixel)+像素到区域的对比(pixel to region),设计了更加高效的memory bank。
- 设计了一种更加合理的难样本采样策略,Segmentation-Aware Hard Anchor Sampling。
整体结构
通过上图可以看到,相比常规的全监督语义分割结构,本文的方法只是额外增加了一条用于对比计算的辅助分支,该分支在实际推理部署的时候是去除的,所以对于语义分割模型本身是不增加推理负担的。
Memory Bank
这个东西就是一个数据池,保存了历史数据用于对比计算,这里面保存的都是经过了模型特征提取后的D维特征,本文D=256。
Pixel to Pixel
这个是针对整个数据集的图像来操作的,就是对所有类设置一个专属的队列(整个数据集有多少个类就有多少个队列,比如COCO数据集有80类,那么就有80个队列),训练的时候从每个mini-batch中的每个类选取V个D维像素加入到对应类的队列T里,T是远大于V的。一旦T被装满了,那么就去旧留新。通过这种方式Memory Bank能动态存储绝大部分图像的内容特征。
num_pixel = idxs.shape[0]
perm = torch.randperm(num_pixel) #随机选择一定像素
K = min(num_pixel, self.pixel_update_freq) #跟预设值比较,减少代码出错的操作
feat = this_feat[:, perm[:K]]
feat = torch.transpose(feat, 0, 1)
ptr = int(pixel_queue_ptr[lb])
if ptr + K >= self.memory_size: #队列满了则去旧留新
pixel_queue[lb, -K:, :] = nn.functional.normalize(feat, p=2, dim=1)
pixel_queue_ptr[lb] = 0
else:
pixel_queue[lb, ptr:ptr + K, :] = nn.functional.normalize(feat, p=2, dim=1)
pixel_queue_ptr[lb] = (pixel_queue_ptr[lb] + 1) % self.memory_size
上面的源码可以大致看到,首先会随机选择一定数量的像素加入到队列,如果队列满了则旧的数据会被新的数据代替。
Pixel to Region
就是将区域的一大块特征用一个像素点的特征去表示,主要是用来弥补pixel to pixel 采样不充分的问题,这个方法是针对一张图像的操作,将多个一张图像的特征拼接到一起训练就能获取全局信息。怎么操作的呢?比如一张图像上有3个地方是猫的区域,首先对这个3块区域在XY坐标上进行求平均,最后变为3个D维的像素特征(D,1,1),然后再对这个3个像素点在对应通道维度上求平均,最后当前图像的3只猫就被一个D维的像素点表示。
# segment enqueue and dequeue
feat = torch.mean(this_feat[:, idxs], dim=1).squeeze(1)
ptr = int(segment_queue_ptr[lb])
segment_queue[lb, ptr, :] = nn.functional.normalize(feat.view(-1), p=2, dim=0)
segment_queue_ptr[lb] = (segment_queue_ptr[lb] + 1) % self.memory_size
这个方法是训练前期开辟一大块内存,然后每一个mini-batch都会加入一定的特征进去,越到训练后期特征越多,等一轮训练完成后就清空,再重新开始。
困难样本采样策略Segmentation-Aware Hard Anchor Sampling
这个采样策略其实很简单,相比现有的难负挖掘采样策略,它是随机采一半困难样本,剩下的一半就随机采样,这剩下的一本里面应该既有困难样本也有简单样本,这样做的目的是防止全部使用困难样本训练导致过拟合。举个例子,当前是一个类别为猫的像素特征,首先会从memory bank中选择512个D维的像素点,这些像素点属于狗、羊等其他跟猫特征接近的动物,或者是猫的但经常分类分错的像素点,再随机从memory bank中选择一些像素点放一起,源码中是总共选择1024个点用于跟当前的限度点进行损失计算。怎么确定是困难样本还是简单样本呢?就是通过模型的mask图的像素值跟label值对不对的上,mask值跟label值匹配就是困难样本。