Unet网络是一种图像语义分割网络。图像语义分割网络让计算机根据图像的语义来进行分割,例如在输入下图时,计算机能够输出指定分割的图片。
原图中,物体被分为三类,1.背景, 2.人, 3.自行车
语义分割的用处很多,比如上图中的卫星图分割:通过多轮迭代,Prediction(预测结果)逐渐与Ground Truth(真实标注)一致。
Unet网络结构如下,整个网络形如字母U。简单来说,整个网络分为两个部分:左边部分负责特征提取,随着网络层加深,网络的channel逐渐变大,"图片"逐渐变小;右边部分负责特征的还原。整个网络实际上就是一个编码-解码器。需要注意的是,整个网络最出彩的地方是灰色箭头的部分(即跳跃连接,skip connection)。在编码的过程中,部分信息丢失了(Maxpooling和Conv2D);在解码时,网络加入与之对应的编码层信息。从图上来看,就是右边每一层网络都加入了一部分"白"色的"图片"(特征)。
那么这里就有个问题:为什么要这么复杂地做一个编码-解码器?如上图所示,简单的多层卷积不也可以完成图像语义分割吗?
原因在于:如果不做下采样,只靠加大卷积核(或堆叠更多卷积层)来获得足够大的感受野,参数量和计算量会成倍增长。一是运算效率会大大下降,二是不利于收敛。这里强烈推荐看一篇文章《"看懂"卷积神经网络》(Visualizing and Understanding Convolutional Networks)。
这里讲一下Unet的工作原理。假设我们有一张图片,如左图所示,我们会根据实际需要,将需要识别的区域转化为特定的"编码"作为类标签。
实际上每一类需要识别的物体对应一个channel:有多少类需要识别的物体,就有多少个输出channel。最后再把它们叠加起来,就是我们最终想要的分割结果。
下面拿一个简单的示例代码来说明Unet的工作原理(源代码在Github上),并做一些解释性说明。
1.首先引入必要包
%matplotlib inline
%load_ext autoreload
%autoreload 2
import os, sys
import random
import copy
import itertools
import time
from functools import reduce
from collections import defaultdict
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.optim import lr_scheduler
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, datasets, models
from torchsummary import summary
2.生成模拟数据,这一部分不用太纠结代码,复制粘贴就可以
def generate_random_data(height, width, count):
    """Generate ``count`` synthetic (image, masks) samples.

    Returns:
        X: uint8 array of shape (count, height, width, 3). Each single-channel
           0/1 input image is scaled to 0/255 and replicated into 3 channels.
        Y: array stacking the per-sample target masks returned by
           ``generate_img_and_mask``.
    """
    samples = [generate_img_and_mask(height, width) for _ in range(count)]
    images, masks = zip(*samples)
    # Input images are 0/1 floats of shape (1, H, W); scale them to 0/255.
    X = np.asarray(images) * 255
    # Replicate the single channel three times, then move channels last:
    # (count, 3, H, W) -> (count, H, W, 3).
    X = np.transpose(np.repeat(X, 3, axis=1), (0, 2, 3, 1)).astype(np.uint8)
    Y = np.asarray(masks)
    return X, Y
def generate_img_and_mask(height, width):
    """Draw six random shapes into one binary image and build per-shape masks.

    Returns:
        image: float32 array of shape (1, height, width), 1.0 where any shape
            was drawn.
        masks: float32 array of shape (6, height, width), one mask per shape,
            in this order: filled square, filled circle, triangle, ring circle,
            mesh square (drawn as its filled bounding square), plus.
    """
    shape = (height, width)
    # The locations are drawn in this fixed order, which determines how the
    # shared `random` stream is consumed.
    tri_loc = get_random_location(*shape)
    ring_loc = get_random_location(*shape, zoom=0.7)
    disc_loc = get_random_location(*shape, zoom=0.5)
    mesh_loc = get_random_location(*shape)
    square_loc = get_random_location(*shape, zoom=0.8)
    plus_loc = get_random_location(*shape, zoom=1.2)

    def blank():
        return np.zeros(shape, dtype=bool)

    # Input image: every shape composited onto one canvas.
    image = blank()
    image = add_triangle(image, *tri_loc)
    image = add_circle(image, *ring_loc)
    image = add_circle(image, *disc_loc, fill=True)
    image = add_mesh_square(image, *mesh_loc)
    image = add_filled_square(image, *square_loc)
    image = add_plus(image, *plus_loc)
    image = image.reshape(1, height, width).astype(np.float32)

    # Targets: one mask per shape. The mesh square's target is its filled
    # bounding square rather than the dotted mesh pattern itself.
    mask_list = [
        add_filled_square(blank(), *square_loc),
        add_circle(blank(), *disc_loc, fill=True),
        add_triangle(blank(), *tri_loc),
        add_circle(blank(), *ring_loc),
        add_filled_square(blank(), *mesh_loc),
        add_plus(blank(), *plus_loc),
    ]
    masks = np.asarray(mask_list).astype(np.float32)
    return image, masks
def add_square(arr, x, y, size):
    """Draw the outline of a square centred at (x, y) into ``arr`` in place.

    With ``s = int(size / 2)``, the outline covers rows ``x - s`` and ``x + s``
    and columns ``y - s`` and ``y + s``. Each edge runs the full inclusive span
    between them, so all four corners are set.

    Returns:
        The same (mutated) ``arr``.
    """
    s = int(size / 2)
    top, bottom, left, right = x - s, x + s, y - s, y + s
    # Slice ends are exclusive, so +1 is needed to include the far edge.
    # Without it the bottom-right corner (x + s, y + s) was never drawn.
    arr[top, left:right + 1] = True
    arr[bottom, left:right + 1] = True
    arr[top:bottom + 1, left] = True
    arr[top:bottom + 1, right] = True
    return arr
def add_filled_square(arr, x, y, size):
    """Return ``arr`` OR-ed with a filled square centred at (x, y).

    With ``half = int(size / 2)``, the square covers rows strictly between
    ``x - half`` and ``x + half`` and columns strictly between ``y - half`` and
    ``y + half``. ``arr`` itself is not modified.
    """
    half = int(size / 2)
    rows, cols = np.mgrid[:arr.shape[0], :arr.shape[1]]
    inside_rows = (rows > x - half) & (rows < x + half)
    inside_cols = (cols > y - half) & (cols < y + half)
    return np.logical_or(arr, inside_rows & inside_cols)
def logical_and(arrays):
    """Return the element-wise AND of every array in ``arrays``.

    The fold starts from an all-True bool array shaped like ``arrays[0]``, so
    the result is boolean. An empty ``arrays`` raises IndexError.
    """
    all_true = np.ones(arrays[0].shape, dtype=bool)
    return reduce(np.logical_and, arrays, all_true)
def add_mesh_square(arr, x, y, size):
    """Return ``arr`` OR-ed with a dotted "mesh" square centred at (x, y).

    With ``half = int(size / 2)``, the mesh occupies the open square
    ``x - half < row < x + half`` and ``y - half < col < y + half``. Inside it,
    only pixels whose row and column indices are both odd are set.
    ``arr`` itself is not modified.
    """
    half = int(size / 2)
    rows, cols = np.mgrid[:arr.shape[0], :arr.shape[1]]
    row_hits = (rows > x - half) & (rows < x + half) & (rows % 2 == 1)
    col_hits = (cols > y - half) & (cols < y + half) & (cols % 2 == 1)
    return np.logical_or(arr, row_hits & col_hits)
def add_triangle(arr, x, y, size):
    """OR a lower-triangular block of side ``size`` into ``arr`` in place.

    The block's top-left corner is at ``(x - s, y - s)`` with ``s = int(size / 2)``.
    The triangle covers the block's diagonal and everything below it.

    Returns:
        The same (mutated) ``arr``.
    """
    s = int(size / 2)
    triangle = np.tril(np.ones((size, size), dtype=bool))
    # Basic slicing gives a view, so writing into `region` updates `arr`.
    region = arr[x - s:x - s + size, y - s:y - s + size]
    # Only set the triangle's pixels. A plain `region[...] = triangle` also
    # wrote False over the upper triangle, erasing anything already drawn
    # there; the other shape helpers OR their pixels in.
    region[triangle] = True
    return arr
def add_circle(arr, x, y, size, fill=False):
    """Return ``arr`` OR-ed with a circle of radius ``size`` centred at (x, y).

    With ``fill=True``, every pixel at distance < ``size`` is set (a disc).
    Otherwise only distances in ``[0.7 * size, size)`` are set (a ring).
    ``arr`` itself is not modified.
    """
    rows, cols = np.mgrid[:arr.shape[0], :arr.shape[1]]
    dist = np.sqrt((rows - x) ** 2 + (cols - y) ** 2)
    shape_mask = dist < size
    if not fill:
        shape_mask = shape_mask & (dist >= size * 0.7)
    return np.logical_or(arr, shape_mask)
def add_plus(arr, x, y, size):
    """Draw a '+' near (x, y) into ``arr`` in place and return ``arr``.

    Both bars are 2 pixels thick (indices ``x - 1, x`` and ``y - 1, y``). Each
    bar spans the half-open range ``[c - half, c + half)``, where
    ``half = int(size / 2)``.
    """
    half = int(size / 2)
    horizontal_bar = (slice(x - 1, x + 1), slice(y - half, y + half))
    vertical_bar = (slice(x - half, x + half), slice(y - 1, y + 1))
    for bar in (horizontal_bar, vertical_bar):
        arr[bar] = True
    return arr
def get_random_location(width, height, zoom=1.0):
    """Draw a random shape placement and return it as ``(x, y, size)``.

    ``x`` lies in the 10%-90% band of the first argument and ``y`` in the
    10%-90% band of the second. ``size`` is 6%-12% of the smaller of the two,
    scaled by ``zoom``. Exactly three ``random.uniform`` draws are consumed,
    in x, y, size order.
    """
    centre_x = int(width * random.uniform(0.1, 0.9))
    centre_y = int(height * random.uniform(0.1, 0.9))
    shorter_side = min(width, height)
    extent = int(shorter_side * random.uniform(0.06, 0.12) * zoom)
    return centre_x, centre_y, extent
def plot_img_array(img_array, ncol=3):
    """Show the images in ``img_array`` on a grid with ``ncol`` columns.

    Images fill the grid row by row. The row count is rounded up, so a final
    partial row is still drawn. An empty ``img_array`` draws nothing.
    """
    n_images = len(img_array)
    if n_images == 0:
        return
    nrow = -(-n_images // ncol)  # ceiling division
    # squeeze=False keeps `plots` 2-D even when nrow == 1 or ncol == 1.
    # Otherwise plt.subplots returns a 1-D array and plots[r, c] raises
    # IndexError, e.g. when a single sample is shown side by side.
    _, plots = plt.subplots(nrow, ncol, sharex='all', sharey='all',
                            figsize=(ncol * 4, nrow * 4), squeeze=False)
    for i, img in enumerate(img_array):
        plots[i // ncol, i % ncol].imshow(img)
def plot_side_by_side(img_arrays):
    """Plot several equally long image lists as columns of one grid.

    Row i shows the i-th image of every list in ``img_arrays``, left to right.
    """
    # zip(*img_arrays) groups the i-th image of each list. Concatenating those
    # tuples interleaves the lists row by row.
    interleaved = reduce(lambda acc, group: acc + group, zip(*img_arrays))
    column_count = len(img_arrays)
    plot_img_array(np.array(interleaved), ncol=column_count)
def plot_errors(results_dict, title):
    """Plot one dice-coefficient curve per entry of ``results_dict``.

    Curves are drawn in sorted key order, cycling through the '+', 'x' and 'o'
    markers. The legend is anchored at (1, 0) relative to the axes.
    """
    marker_cycle = itertools.cycle(('+', 'x', 'o'))
    plt.title(f'{title}')
    for (label, series), marker in zip(sorted(results_dict.items()), marker_cycle):
        plt.plot(series, marker=marker, label=label)
    plt.ylabel('dice_coef')
    plt.xlabel('epoch')
    plt.legend(loc=3, bbox_to_anchor=(1, 0))
    plt.show()
def masks_to_colorimg(masks):
    """Render a stack of binary masks as one RGB image for display.

    Args:
        masks: array of shape (C, H, W). A pixel is "on" in channel c when
            ``masks[c, y, x] > 0.5``. C must equal the number of palette
            colours below (6).

    Returns:
        uint8 array of shape (H, W, 3). Pixels covered by no mask are white.
        Pixels covered by one or more masks get the mean of those masks'
        colours, truncated to integers.
    """
    colors = np.asarray([(201, 58, 64), (242, 207, 1), (0, 152, 75),
                         (101, 172, 228), (56, 34, 132), (160, 194, 56)])
    colorimg = np.ones((masks.shape[1], masks.shape[2], 3), dtype=np.float32) * 255

    # Vectorised replacement for the former per-pixel Python double loop.
    selected = np.asarray(masks) > 0.5            # (C, H, W) bool
    counts = selected.sum(axis=0)                 # number of masks on at each pixel
    color_sums = np.einsum('chw,ck->hwk', selected.astype(np.int64), colors)
    covered = counts > 0
    # Same arithmetic as np.mean over the selected colours: integer sum / count.
    colorimg[covered] = color_sums[covered] / counts[covered][:, None]
    return colorimg.astype(np.uint8)
3.看一下输入数据和类标签数据
# Generate one 192x192 sample: the input image plus its target masks.
input_images, target_masks = generate_random_data(192, 192, count=1)
print(f'输入数据维度:{input_images.shape}')
print(f'输出数据维度:{target_masks.shape}')
# The images are already uint8; the cast makes the plotting dtype explicit.
input_images_rgb = list(map(lambda img: img.astype(np.uint8), input_images))
# Collapse each multi-channel mask stack into one coloured RGB image for display.
target_masks_rgb = list(map(masks_to_colorimg, target_masks))
# Inputs in the left column, coloured targets in the right column.
plot_side_by_side([input_images_rgb, target_masks_rgb])
['out']:输入数据维度:(1, 192, 192, 3)
['out']:输出数据维度:(1, 6, 192, 192)
训练数据是一张(192,192,3)的RGB图片(最后一维的3为RGB通道);类标签数据是一组(6,192,192)的灰度图片,每个需要识别的图形对应一张灰度图片,一共6个图形。(打印出的维度最前面的1是样本数,即count=1。)
左图为输入数据;右图是把6张类标签灰度图分别着色为RGB后叠加的效果图(模型只需预测这6张灰度图即可)。
4.数据生成器
# A minimal PyTorch Dataset serving synthetic samples generated up front.
class SimDataset(Dataset):
    """In-memory dataset of ``count`` synthetic 192x192 samples.

    Each item is ``[image, mask]``. ``image`` is passed through ``transform``
    when one is set; ``mask`` is always returned unchanged.
    """

    def __init__(self, count, transform=None):
        # count: number of samples generated once here and held in memory.
        # transform: optional callable applied to the input image only.
        self.input_images, self.target_masks = generate_random_data(192, 192, count=count)
        self.transform = transform

    def __len__(self):
        return len(self.input_images)

    def __getitem__(self, idx):
        image, mask = self.input_images[idx], self.target_masks[idx]
        return [self.transform(image) if self.transform else image, mask]
# The same transform is used for train and val in this example.
# Per the torchvision docs, ToTensor converts an (H, W, C) uint8 array into a
# (C, H, W) float tensor scaled to [0, 1].
trans = transforms.Compose([
    transforms.ToTensor(),
])

# 2000 synthetic samples for training, 200 for validation.
train_set = SimDataset(2000, transform=trans)
val_set = SimDataset(200, transform=trans)

batch_size = 25

dataloaders = {
    'train': DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=0),
    # Validation metrics are aggregated over the whole set, so shuffling adds
    # nothing. A fixed order keeps validation passes reproducible.
    'val': DataLoader(val_set, batch_size=batch_size, shuffle=False, num_workers=0),
}
Unet网络
# U-Net building block: two (3x3 conv + ReLU) layers.
# The original U-Net paper uses padding=0, so each conv shrinks the feature map
# by 2 pixels (572x572 -> 570x570 -> 568x568). This implementation uses
# padding=1, so the spatial size is preserved.
def double_conv(in_channels, out_channels):
    """Return ``Conv3x3(in -> out) -> ReLU -> Conv3x3(out -> out) -> ReLU``."""
    layers = []
    for conv_in in (in_channels, out_channels):
        layers.append(nn.Conv2d(conv_in, out_channels, 3, padding=1))
        layers.append(nn.ReLU(inplace=True))
    return nn.Sequential(*layers)
5.定义网络
# Each encoder stage applies double_conv (channels 64 -> 128 -> 256 -> 512),
# and max-pooling between stages halves the spatial size.
class Unet(nn.Module):
    """U-Net with three pooling stages and three up-sampling stages.

    Args:
        n_class: number of output channels, one per class to segment.

    ``forward`` takes a batch with 3 input channels and returns ``n_class``
    channels at the input's spatial size. No activation is applied to the
    output. Input height and width no longer need to be multiples of 8.
    """

    def __init__(self, n_class):
        super().__init__()
        self.dconv_down1 = double_conv(3, 64)
        self.dconv_down2 = double_conv(64, 128)
        self.dconv_down3 = double_conv(128, 256)
        self.dconv_down4 = double_conv(256, 512)
        self.maxpool = nn.MaxPool2d(2)
        # Bilinear x2 up-sampling module, kept for backward compatibility (it
        # has no parameters). forward() instead up-samples to the exact size
        # of the matching encoder map, which gives the same result when the
        # size is exactly doubled.
        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        # Decoder input channels = up-sampled channels + skip-connection channels.
        self.dconv_up3 = double_conv(256 + 512, 256)
        self.dconv_up2 = double_conv(128 + 256, 128)
        self.dconv_up1 = double_conv(128 + 64, 64)
        # 1x1 conv producing one output channel per class.
        self.conv_last = nn.Conv2d(64, n_class, 1)

    @staticmethod
    def _up_and_concat(x, skip):
        """Bilinearly resize ``x`` to ``skip``'s spatial size and concat along channels."""
        # Resizing to the skip tensor's size (rather than a fixed x2) keeps
        # torch.cat valid when pooling floored an odd height or width.
        x = F.interpolate(x, size=skip.shape[-2:], mode='bilinear', align_corners=True)
        return torch.cat([x, skip], dim=1)

    def forward(self, x):
        # Encoder: keep each stage's output for the skip connections.
        conv1 = self.dconv_down1(x)
        conv2 = self.dconv_down2(self.maxpool(conv1))
        conv3 = self.dconv_down3(self.maxpool(conv2))
        x = self.dconv_down4(self.maxpool(conv3))  # bottleneck
        # Decoder: concatenating the matching encoder map increases the channel
        # count and supplies features that plain up-sampling cannot recover.
        x = self.dconv_up3(self._up_and_concat(x, conv3))  # 512 + 256 channels in
        x = self.dconv_up2(self._up_and_concat(x, conv2))  # 256 + 128 channels in
        x = self.dconv_up1(self._up_and_concat(x, conv1))  # 128 + 64 channels in
        return self.conv_last(x)
# Print a layer-by-layer summary of the network for a 3x224x224 input.
# torchsummary's `device` argument defaults to "cuda". On a machine with a GPU
# it then feeds a CUDA input tensor, which fails against this CPU-resident
# model. Pass the model's actual device explicitly.
model = Unet(6)
summary(model, input_size=(3, 224, 224), device='cpu')
6.损失函数
def dice_loss(pred, target, smooth=1.):
    """Soft Dice loss averaged over every (sample, channel) pair.

    Sums run over dims 2 and 3 (the spatial dims for (N, C, H, W) inputs). This
    gives, per sample and channel::

        1 - (2 * sum(pred * target) + smooth) / (sum(pred) + sum(target) + smooth)

    ``smooth`` keeps the ratio defined when both sums are zero.
    """
    def spatial_sum(t):
        # Sum dim 2 twice: after the first reduction the old dim 3 becomes dim 2.
        return t.sum(dim=2).sum(dim=2)

    pred, target = pred.contiguous(), target.contiguous()
    overlap = spatial_sum(pred * target)
    denominator = spatial_sum(pred) + spatial_sum(target) + smooth
    per_channel = 1 - (2. * overlap + smooth) / denominator
    return per_channel.mean()
# Combined loss: weighted sum of BCE (on logits) and soft Dice (on probabilities).
def calc_loss(pred, target, metrics, bce_weight=0.5):
    """Return ``bce_weight * BCE + (1 - bce_weight) * Dice`` for one batch.

    Args:
        pred: raw model outputs (logits). BCE is computed directly from them
            with ``binary_cross_entropy_with_logits``; Dice uses ``sigmoid(pred)``.
        target: tensor with the same shape as ``pred``.
        metrics: mapping updated in place. The batch's 'bce', 'dice' and
            'loss' values, each multiplied by the batch size
            ``target.size(0)``, are added to it.
        bce_weight: weight of the BCE term; the Dice term gets ``1 - bce_weight``.

    Returns:
        The combined loss tensor, still attached to the autograd graph.
    """
    bce = F.binary_cross_entropy_with_logits(pred, target)
    # torch.sigmoid replaces the deprecated F.sigmoid.
    dice = dice_loss(torch.sigmoid(pred), target)
    loss = bce * bce_weight + dice * (1 - bce_weight)

    n = target.size(0)
    # .item() detaches and returns a Python float, so the running sums are
    # accumulated in double precision rather than float32.
    for name, value in (('bce', bce), ('dice', dice), ('loss', loss)):
        metrics[name] += value.item() * n
    return loss
def print_metrics(metrics, epoch_samples, phase):
    """Print every metric divided by ``epoch_samples`` on one line, prefixed by ``phase``."""
    parts = ", ".join(
        f"{name}: {total / epoch_samples:4f}" for name, total in metrics.items()
    )
    print(f"{phase}: {parts}")
def _run_phase(model, optimizer, phase):
    """Run one pass over ``dataloaders[phase]`` and return the mean loss per sample.

    In 'train', gradients are computed and ``optimizer`` steps once per batch.
    In 'val', the model is in eval mode and autograd is disabled. Batches are
    moved to the module-level ``device``.
    """
    if phase == 'train':
        model.train()  # Set model to training mode
    else:
        model.eval()   # Set model to evaluate mode

    metrics = defaultdict(float)
    epoch_samples = 0
    for inputs, labels in dataloaders[phase]:
        inputs = inputs.to(device)
        labels = labels.to(device)

        # zero the parameter gradients
        optimizer.zero_grad()

        with torch.set_grad_enabled(phase == 'train'):
            outputs = model(inputs)
            loss = calc_loss(outputs, labels, metrics)
            if phase == 'train':
                loss.backward()
                optimizer.step()

        epoch_samples += inputs.size(0)

    print_metrics(metrics, epoch_samples, phase)
    return metrics['loss'] / epoch_samples


def train_model(model, optimizer, scheduler, num_epochs=25):
    """Train ``model`` and return it loaded with its best validation weights.

    Each epoch runs a 'train' pass and then a 'val' pass over the
    module-level ``dataloaders``. Per-phase averages are printed. The weights
    from the epoch with the lowest mean validation loss are restored before
    returning.

    Args:
        model: network to train. It is modified in place and also returned.
        optimizer: optimizer over ``model``'s parameters.
        scheduler: learning-rate scheduler, stepped once per epoch after the
            training pass.
        num_epochs: number of epochs to run.
    """
    best_model_wts = copy.deepcopy(model.state_dict())
    best_loss = 1e10

    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)
        since = time.time()

        for phase in ['train', 'val']:
            epoch_loss = _run_phase(model, optimizer, phase)
            if phase == 'train':
                # Since PyTorch 1.1 the scheduler must step after the
                # optimizer. Stepping it first (as before) skipped the initial
                # LR value, so the StepLR decay came one epoch early.
                scheduler.step()
            elif epoch_loss < best_loss:
                print("saving best model")
                best_loss = epoch_loss
                best_model_wts = copy.deepcopy(model.state_dict())

        time_elapsed = time.time() - since
        print('{:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))

    print('Best val loss: {:4f}'.format(best_loss))

    # load best model weights
    model.load_state_dict(best_model_wts)
    return model
7.训练模型
# Use the first GPU when one is available, otherwise the CPU.
# NOTE: `device` is read as a module-level global inside train_model.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# One output channel per target mask (6 shape classes).
num_class = 6
model = Unet(num_class).to(device)
# Adam with lr=1e-4. StepLR multiplies the lr by 0.1 every 25 scheduler steps.
optimizer_ft = optim.Adam(model.parameters(), lr=1e-4)
exp_lr_scheduler = lr_scheduler.StepLR(optimizer_ft, step_size=25, gamma=0.1)
# Train for 40 epochs. train_model returns the model loaded with the weights
# that had the lowest validation loss.
model = train_model(model, optimizer_ft, exp_lr_scheduler, num_epochs=40)