问题起源
深度学习普遍认为发端于2006年,根据Bengio的定义,深层网络由多层自适应非线性单元组成——即多层非线性模块的级联,所有层次上都包含可训练的参数,在工程实际操作中,深层神经网络通常是五层及以上,包含数百万个可学习的自由参数的庞然大物。理论上,网络模型无论深浅与否,都能通过函数逼近数据的内在关系和本质特征,但在解决真实世界的复杂问题时,需要指数增长的计算单元,浅层网络往往出现函数表达能力不足,而深层网络则可能仅仅需要较少的计算单元。
不过网络并不是像理论上那样越深越好,除了显而易见的因为层数过多而导致浪费性质的占用显存和“吃”计算力的问题,还会出现以下三种问题。
- 过拟合 (over fit)
- 梯度弥散 (vanishing gradient problem)
- 网络退化 (degenerate)
其中,问题一、二并不是本文所讲的残差学习主要要解决的问题,所以就不多赘述,只讲述网络退化的问题。其现象如下图所示,是随着网络层数的增多,整体模型的表达能力增强,但是训练精度反而变差,并且因为训练精度本身也下降的缘故,故而可以排除是过拟合的原因,而确定是网络退化。
When deeper networks are able to start converging, a degradation problem has been exposed: with the network depth increasing, accuracy gets saturated which might be unsurprising and then degrades rapidly. Unexpectedly, such degradation is not caused by overfitting, and adding more layers to a suitably deep model leads to higher training error,as reported in and thoroughly verified by our experiments.
但是,很可惜的是,业界对于网络退化的原因及其标准情况依然没有定论,甚至说出现了随着网络变深而效果变差的问题的时候,也有可能无法分辨出是梯度弥散还是网络退化的问题。读者如果有兴趣,可以自行去寻找网络退化方面的研究论文,各家的观点虽然都不尽相同,但我们还是可以发现不少有用的信息。
残差学习
而对于上述问题,Kaiming He大神提出了一种简洁而不失优雅的残差学习的方法。多的不谈,我们直接甩出模型结构来讲解残差学习的思想。
首先,只看图的左半边,也就是橘红色的部分。左侧与普通网络连接方式的区别一目了然——在顺次直连而下的基础上加入了每隔两层的跨接桥(其实官方的叫法并非如此,然而这么叫它显得更加直观)。不过纯凭看图的感觉毕竟流于表面,用数学说话才是严谨的态度。
对于一个神经网络而言,我们需要用反向传播来更新参数,就像这样:
此时,第二个式子所得的结果就是我们常说的梯度。
而当如下图网络越来越深的时候:
......
这时候再通过算偏导求梯度,就会是这样:
其实数列的每一项都很小,再依此相乘就会越来越小,最后趋近于0,举个简单的例子就是0.9虽然很接近于1,但当有n个0.9相乘时(n趋近于无限大),最后的结果就会无限趋近于0。
而当有了“跨接桥”之后,我们再算偏导的时候就会变成这样:
说白了就是1.01的n次方依然大于1。
最后,我们可以发现对于相同的数据集来讲,残差网络比同等深度的其他网络表现出了更好的性能。
不过,这是大神的测试结果,没有什么说服力,而我在自己的项目里做了一组关于有无残差学习的对比,下面是数据图(项目是和图像增强有关,所以用PSNR作为评判标准):
最后可见,Loss的下降趋势,残差学习的方法明显更加平稳,而最后结果Loss和PSNR虽然差距目测不大,但最后的图片视觉效果却千差万别。
下一节我们会讲模型结构图的右半边——同样是Kaiming He大神的Skip Connection策略。