Unet图像分割

Unet网络是一种图像语义分割网络,图像语义分割网络让计算机根据图像的语义来进行分割,例如让计算机在输入下面下图,能够输出指定分割的图片。
基本图片分割

原图中,物体被分为三类,1.背景, 2.人, 3.自行车

地理信息

语义分割的用处很多,比如说上图中分割卫星图,通过多伦迭代,Prediction逐渐与Grond Truth一致。

Unet网络结构

Unet网络结构如下,整个网络形如字母U。简单的来说,整个网络分为两个部分,左边部分负责特征提取,随着网络层加深,网络的channel逐渐变大,"图片"逐渐变小。右边的网络负责特征的还原,整个网络实际上就是一个编码-解码器。需要注意的是,整个网络最出彩的地方是灰色箭头的部分。在编码的过程中,部分信息丢失了(Maxpooling和Conv2D)。在解码时,加入与之对应的编码层信息。从图上来看的话就是右边每一层网络都加入了一部分"白"色的"图片"(特征)。

全连接语义分割

那么这里就有个问题,为什么要这么复杂的做一个编码-解码器?上图的一个简单的多层卷积就可以完成图像语义分割。


编码解码器

原因就在于随着卷积核的越大,伴随着参数就会成倍增长,一是运算效率会大大下降,其次不利于收敛。这里强烈推荐看一篇文章“看懂”卷积神经网(Visualizing and Understanding Convolutional Networks)

工作原理1

这里讲一下,Unet工作原理,假设我们有一张图片,如左图所示,我们会根据实际需要将需要识别的区域转化为特定的"编码"作为类标签。


工作原理2

工作原理3

实际上每个需要识别的物体需要一个channel,有多少个需要识别的物体,就有多少个输出channel,最后再做一个叠加就是最终我们想分割的结果。

下面哪一个简单的实例代码来说明Unet的工作原理,源代码Github在这里,下面我做一些解释性说明

1.首先引入必要包
%matplotlib inline
%load_ext autoreload
%autoreload 2
import os, sys
import random
import copy
import itertools
import time
from functools import reduce
from collections import defaultdict
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.optim import lr_scheduler
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, datasets, models
from torchsummary import summary
2.生成模拟数据,这一部分不用太纠结代码,复制粘贴就可以
def generate_random_data(height, width, count):
    x, y = zip(*[generate_img_and_mask(height, width) for i in range(0, count)])
    X = np.asarray(x) * 255
    X = X.repeat(3, axis=1).transpose([0, 2, 3, 1]).astype(np.uint8)
    Y = np.asarray(y)
    return X, Y

def generate_img_and_mask(height, width):
    shape = (height, width)
    triangle_location = get_random_location(*shape)
    circle_location1 = get_random_location(*shape, zoom=0.7)
    circle_location2 = get_random_location(*shape, zoom=0.5)
    mesh_location = get_random_location(*shape)
    square_location = get_random_location(*shape, zoom=0.8)
    plus_location = get_random_location(*shape, zoom=1.2)

    # Create input image
    arr = np.zeros(shape, dtype=bool)
    arr = add_triangle(arr, *triangle_location)
    arr = add_circle(arr, *circle_location1)
    arr = add_circle(arr, *circle_location2, fill=True)
    arr = add_mesh_square(arr, *mesh_location)
    arr = add_filled_square(arr, *square_location)
    arr = add_plus(arr, *plus_location)
    arr = np.reshape(arr, (1, height, width)).astype(np.float32)

    # Create target masks
    masks = np.asarray([
        add_filled_square(np.zeros(shape, dtype=bool), *square_location),
        add_circle(np.zeros(shape, dtype=bool), *circle_location2, fill=True),
        add_triangle(np.zeros(shape, dtype=bool), *triangle_location),
        add_circle(np.zeros(shape, dtype=bool), *circle_location1),
         add_filled_square(np.zeros(shape, dtype=bool), *mesh_location),
        # add_mesh_square(np.zeros(shape, dtype=bool), *mesh_location),
        add_plus(np.zeros(shape, dtype=bool), *plus_location)
    ]).astype(np.float32)
    return arr, masks

def add_square(arr, x, y, size):
    s = int(size / 2)
    arr[x-s,y-s:y+s] = True
    arr[x+s,y-s:y+s] = True
    arr[x-s:x+s,y-s] = True
    arr[x-s:x+s,y+s] = True
    return arr

def add_filled_square(arr, x, y, size):
    s = int(size / 2)
    xx, yy = np.mgrid[:arr.shape[0], :arr.shape[1]]
    return np.logical_or(arr, logical_and([xx > x - s, xx < x + s, yy > y - s, yy < y + s]))

def logical_and(arrays):
    new_array = np.ones(arrays[0].shape, dtype=bool)
    for a in arrays:
        new_array = np.logical_and(new_array, a)
    return new_array

def add_mesh_square(arr, x, y, size):
    s = int(size / 2)
    xx, yy = np.mgrid[:arr.shape[0], :arr.shape[1]]
    return np.logical_or(arr, logical_and([xx > x - s, xx < x + s, xx % 2 == 1, yy > y - s, yy < y + s, yy % 2 == 1]))

def add_triangle(arr, x, y, size):
    s = int(size / 2)
    triangle = np.tril(np.ones((size, size), dtype=bool))
    arr[x-s:x-s+triangle.shape[0],y-s:y-s+triangle.shape[1]] = triangle
    return arr

def add_circle(arr, x, y, size, fill=False):
    xx, yy = np.mgrid[:arr.shape[0], :arr.shape[1]]
    circle = np.sqrt((xx - x) ** 2 + (yy - y) ** 2)
    new_arr = np.logical_or(arr, np.logical_and(circle < size, circle >= size * 0.7 if not fill else True))
    return new_arr

def add_plus(arr, x, y, size):
    s = int(size / 2)
    arr[x-1:x+1,y-s:y+s] = True
    arr[x-s:x+s,y-1:y+1] = True
    return arr

def get_random_location(width, height, zoom=1.0):
    x = int(width * random.uniform(0.1, 0.9))
    y = int(height * random.uniform(0.1, 0.9))
    size = int(min(width, height) * random.uniform(0.06, 0.12) * zoom)
    return (x, y, size)

def plot_img_array(img_array, ncol=3):
    nrow = len(img_array) // ncol
    f, plots = plt.subplots(nrow, ncol, sharex='all', sharey='all', figsize=(ncol * 4, nrow * 4))
    for i in range(len(img_array)):
        plots[i // ncol, i % ncol]
        plots[i // ncol, i % ncol].imshow(img_array[i])

def plot_side_by_side(img_arrays):
    flatten_list = reduce(lambda x,y: x+y, zip(*img_arrays))
    plot_img_array(np.array(flatten_list), ncol=len(img_arrays))

def plot_errors(results_dict, title):
    markers = itertools.cycle(('+', 'x', 'o'))
    plt.title('{}'.format(title))
    for label, result in sorted(results_dict.items()):
        plt.plot(result, marker=next(markers), label=label)
        plt.ylabel('dice_coef')
        plt.xlabel('epoch')
        plt.legend(loc=3, bbox_to_anchor=(1, 0))
    plt.show()

def masks_to_colorimg(masks):
    colors = np.asarray([(201, 58, 64), (242, 207, 1), (0, 152, 75), (101, 172, 228),(56, 34, 132), (160, 194, 56)])
    colorimg = np.ones((masks.shape[1], masks.shape[2], 3), dtype=np.float32) * 255
    channels, height, width = masks.shape
    for y in range(height):
        for x in range(width):
            selected_colors = colors[masks[:,y,x] > 0.5]
            if len(selected_colors) > 0:
                colorimg[y,x,:] = np.mean(selected_colors, axis=0)
    return colorimg.astype(np.uint8)
3.看一下输入数据和类标签数据
# 生成图片与类标签(192*192, 3张)
input_images, target_masks = generate_random_data(192, 192, count=1)
print(f'输入数据维度:{input_images.shape}')
print(f'输出数据维度:{target_masks.shape}')
# 修改数据类型,方便画图
input_images_rgb = [x.astype(np.uint8) for x in input_images]
# 将灰度图片(channel=1)变为RGB图片(channel=3)
target_masks_rgb = [masks_to_colorimg(x) for x in target_masks]
# 显示模拟图片
plot_side_by_side([input_images_rgb, target_masks_rgb])

['out']:输入数据维度:(1, 192, 192, 3)
['out']:输出数据维度:(1, 6, 192, 192)

训练数据一个(192,192,3(RGB通道))的RGB图片, 类标签数据是一组灰度图片(6,192,192),每个需要识别的图形是一个灰度图片一共6个图形。
模拟数据

左图为输入数据,右图中将类标签灰度图片加了RBG通道,然后6张图叠加的效果图(我们只需预测6张灰度图即可)。

4.数据生成器
# 一个简单的pytorch数据迭代器
class SimDataset(Dataset):
    def __init__(self, count, transform=None):
        # count:每次需要生成的数据量
        # transform指定数据转化器
        self.input_images, self.target_masks = generate_random_data(192, 192, count=count)        
        self.transform = transform

    def __len__(self):
        return len(self.input_images)
    
    def __getitem__(self, idx):
        image = self.input_images[idx]
        mask = self.target_masks[idx]
        if self.transform:
            image = self.transform(image)
        return [image, mask]
# use same transform for train/val for this example
trans = transforms.Compose([
    transforms.ToTensor(),
])
# 这里生成2000组模拟数据作为训练集, 200组模拟数据作为测试集
train_set = SimDataset(2000, transform = trans)
val_set = SimDataset(200, transform = trans)
batch_size = 25
dataloaders = {
    'train': DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=0),
    'val': DataLoader(val_set, batch_size=batch_size, shuffle=True, num_workers=0)
}

Unet网络

Unet编码层
# Unet编码层, 如上图所示,包含两个(卷积+Relu)
# 原始Unet网络中padding=0(填充),所以"图片"会变小
# 572*572--->570*570--->568*568
def double_conv(in_channels, out_channels):
  return nn.Sequential(
      nn.Conv2d(in_channels, out_channels, 3, padding=1),
      nn.ReLU(inplace=True),
      nn.Conv2d(out_channels, out_channels, 3, padding=1),
      nn.ReLU(inplace=True)
  )
Unet编码层2

Unet解码层1
5.定义网络
# Unet经过一次double_conv通道数加倍(变厚),然后使用Maxpool, "图片"维度/2(变小)
class Unet(nn.Module):
  def __init__(self, n_class):
    super().__init__()
    self.dconv_down1 = double_conv(3, 64)
    self.dconv_down2 = double_conv(64, 128)
    self.dconv_down3 = double_conv(128, 256)
    self.dconv_down4 = double_conv(256, 512)
    self.maxpool = nn.MaxPool2d(2)
    self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True) # 这里使用双线性插值
    self.dconv_up3 = double_conv(256 + 512, 256)
    self.dconv_up2 = double_conv(128 + 256, 128)
    self.dconv_up1 = double_conv(128 + 64, 64)
    self.conv_last = nn.Conv2d(64, n_class, 1) # 最后一层, 需要识别多少种目标,则输出多少个channel(n_class)

  def forward(self, x):
    conv1 = self.dconv_down1(x)
    x = self.maxpool(conv1) # 对应上图Unet编码层2
    conv2 = self.dconv_down2(x)
    x = self.maxpool(conv2)
    conv3 = self.dconv_down3(x)
    x = self.maxpool(conv3)
    x = self.dconv_down4(x) #到底了
    x = self.upsample(x) # 双线性插值,还原"图片"
    # 解码数据与对应编码数据concat使channel数增加, 弥补了单纯上采样导致的信息还原不足
    # 这一步很关键(也就是图Unet解码层1中数据变"厚")
    x = torch.cat([x, conv3], dim=1) 
    x = self.dconv_up3(x)
    x = self.upsample(x)        
    x = torch.cat([x, conv2], dim=1) # 256+128
    x = self.dconv_up2(x)# 
    x = self.upsample(x)        
    x = torch.cat([x, conv1], dim=1)
    x = self.dconv_up1(x)
    out = self.conv_last(x)
    return out
# 这里打印一下网络结构
model = Unet(6)
summary(model, input_size=(3, 224, 224))
数值化网络结构
6.损失函数
def dice_loss(pred, target, smooth = 1.):
    pred = pred.contiguous()
    target = target.contiguous()    
    intersection = (pred * target).sum(dim=2).sum(dim=2)
    loss = (1 - ((2. * intersection + smooth) / (pred.sum(dim=2).sum(dim=2) + target.sum(dim=2).sum(dim=2) + smooth)))
    return loss.mean()
# 这里使用两种损失函数加权
def calc_loss(pred, target, metrics, bce_weight=0.5):
    bce = F.binary_cross_entropy_with_logits(pred, target) 
    pred = F.sigmoid(pred)
    dice = dice_loss(pred, target)
    loss = bce * bce_weight + dice * (1 - bce_weight)
    metrics['bce'] += bce.data.cpu().numpy() * target.size(0)
    metrics['dice'] += dice.data.cpu().numpy() * target.size(0)
    metrics['loss'] += loss.data.cpu().numpy() * target.size(0)
    return loss

def print_metrics(metrics, epoch_samples, phase):    
    outputs = []
    for k in metrics.keys():
        outputs.append("{}: {:4f}".format(k, metrics[k] / epoch_samples))
    print("{}: {}".format(phase, ", ".join(outputs))) 

def train_model(model, optimizer, scheduler, num_epochs=25):
  best_model_wts = copy.deepcopy(model.state_dict())
  best_loss = 1e10
  for epoch in range(num_epochs):
    print('Epoch {}/{}'.format(epoch, num_epochs - 1))
    print('-'*10)
    since = time.time()
    for phase in ['train', 'val']:
      if phase == 'train':
        scheduler.step()
        model.train()  # Set model to training mode
      else:
        model.eval()   # Set model to evaluate mode
      metrics = defaultdict(float)
      epoch_samples = 0

      for inputs, labels in dataloaders[phase]:
        inputs = inputs.to(device)
        labels = labels.to(device)
        # zero the parameter gradients
        optimizer.zero_grad()
        with torch.set_grad_enabled(phase == 'train'):
          outputs = model(inputs)
          loss = calc_loss(outputs, labels, metrics)
          if phase == 'train':
            loss.backward()
            optimizer.step()
        epoch_samples += inputs.size(0)
      print_metrics(metrics, epoch_samples, phase)
      epoch_loss = metrics['loss'] / epoch_samples

      if phase == 'val' and epoch_loss < best_loss:
        print("saving best model")
        best_loss = epoch_loss
        best_model_wts = copy.deepcopy(model.state_dict())
    time_elapsed = time.time() - since
    print('{:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
  print('Best val loss: {:4f}'.format(best_loss))
  # load best model weights
  model.load_state_dict(best_model_wts)
  return model
7.训练模型
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
num_class = 6
model = Unet(num_class).to(device)
optimizer_ft = optim.Adam(model.parameters(), lr=1e-4)
exp_lr_scheduler = lr_scheduler.StepLR(optimizer_ft, step_size=25, gamma=0.1)
model = train_model(model, optimizer_ft, exp_lr_scheduler, num_epochs=40)
训练结果
©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 202,802评论 5 476
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 85,109评论 2 379
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 149,683评论 0 335
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 54,458评论 1 273
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 63,452评论 5 364
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 48,505评论 1 281
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 37,901评论 3 395
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 36,550评论 0 256
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 40,763评论 1 296
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 35,556评论 2 319
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 37,629评论 1 329
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 33,330评论 4 318
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 38,898评论 3 307
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 29,897评论 0 19
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 31,140评论 1 259
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 42,807评论 2 349
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 42,339评论 2 342

推荐阅读更多精彩内容

  • 主要内容包括: 1、基于boosting级联学习的遥感目标检测 Adaboost算法、前向分步算法、提升树、梯度提...
    大概是只翻车鱼的莫方阅读 2,914评论 0 2
  • 这些年计算机视觉识别和搜索这个领域非常热闹,后期出现了很多的创业公司,大公司也在这方面也花了很多力气在做。做视觉搜...
    方弟阅读 6,436评论 6 24
  • 月轮国,地处西域,崇尚佛法,共辖七府四十二县一百六十八镇。府县镇各级都设有寺庙。佛在月轮国为尊,连带着和尚的地位也...
    三小山闲话阅读 627评论 0 2
  • 昨天晚上十点下班,挺累的,只因昨天周三,店里搞活动,人比较多。 天气逐渐转冷,骑车走在下班的路上,虽是十点,在上海...
    爱佳蓉阅读 118评论 0 0
  • 自己从建库开始完成按key1键实现红绿蓝灯转换,按key2蜂鸣器响,松手停,用中断形式,其中包括中断子函数的编写
    李欣l阅读 204评论 0 0