手写数字的识别堪称学习神经网络的“Hello World!”,而LeNet-5网络是一种高效的CNN,CNN又是神经网络在图像处理领域的经典应用之一。对于初学者,hyperparameter的设定还是参考成熟的结构,查阅相关论文比较好。
概述
先把一张非常经典的网络图搬上来:
下图摘自吴恩达老师在deeplearning.ai 的讲义,之后代码中网络参数设定将以此为参考:
可见,网络主要由三大部分组成:
· Convolution(卷积)
· Pooling(池化)
· Fully connected(全连接)
整体架构为 :
conv - pool - conv - pool - FC - FC - softmax
参数解释
相关数学基础及推导不再赘述。
layer 1
首先是CNN的输入图像,假设为3通道,长宽均为32的彩色图像,里面包含数字信息“7”。
我们使用kernel_size为5,stride为1的6个filter进行第一次卷积操作,由于使用Valid convolutions,故padding为0。由数学关系易推导得,图像(nH,nW,nC)变为(28,28,6)。
随后进入池化,使用kernel_size为2,stride为2的filter进行Max-Pooling。由数学关系易推导得,图像(nH,nW,nC)变为(14,14,6)。
layer 1 结束。
layer 2
将layer 1输出作为layer 2的输入。
同样是一次卷积一次池化。首先,我们采用kernel_size为5,stride为1的16个filter进行Valid convolutions,然后使用kernel_size为5,stride为2的filter进行Max-Pooling。输出的图像为(5,5,16)。
Fully connected
将layer 2输出展平成dim = 1后作为FC的输入,在pytorch中可使用view()方法,然后经过两个FC后采用softmax函数进行激活获得输出。
神经网络部分代码实现(基于pytorch)
代码使用了MNIST数据集中的train_data进行训练,采用交叉熵(CrossEntropyLoss)计算loss,通过反向传递更新网络参数(使用Adam更新方法),训练方式为miniBatch,其中BATCH_SIZE = 64,LR = 0.001(Adam推荐值),网络搭建采用torch.nn.Sequential进行快速搭建。
由于MNIST中均为灰度图像,故将上述分析中input图像的channels修改为1,考虑到图像实际尺寸,我将参数进行了适当修改。
下图摘自CSDN。
class CNN(torch.nn.Module):
def __init__(self):
super(CNN,self).__init__()
self.conv1 = torch.nn.Sequential(
torch.nn.Conv2d(
in_channels = 1,
out_channels = 6,
kernel_size = 1,
stride = 1,
padding = 0,
),
torch.nn.MaxPool2d(
kernel_size = 2,
stride = 2)
)
self.conv2 = torch.nn.Sequential(
torch.nn.Conv2d(
in_channels = 6,
out_channels = 16,
kernel_size = 5,
stride = 1,
padding = 0,
),
torch.nn.MaxPool2d(
kernel_size = 2,
stride = 2)
)
self.fc1 = torch.nn.Linear(5*5*16,120)
self.fc2 = torch.nn.Linear(120,84)
self.fc3 = torch.nn.Linear(84,10)
def forward(self,x):
x = self.conv1(x)
x = self.conv2(x)
x = x.view(x.size(0),-1)
x = self.fc1(x)
x = self.fc2(x)
output = self.fc3(x)
return F.softmax(output, dim=1)
cnn = CNN()
print(cnn)
optimizer = torch.optim.Adam(cnn.parameters(),lr = LR)
loss_func = torch.nn.CrossEntropyLoss()
运行结果(无CUDA加速)
手写数字识别是卷积神经网络最简单的应用,除此之外,CNN主要在Image Classification,Object detection,Neural Style Transfer等方面广有应用。下一步,我将结合OpenCV,详解CNN在图像分割中的应用。