批归一化是现在越来越多的神经网络采用的方法,其具有加快训练速度、防止过拟合等优点,尤其在深度神经网络中效果非常好。现将BN的学习整理一篇文章备忘。
1.为什么要采用BN?
随着神经网络的层数加深,研究者发现神经网络训练起来越困难,收敛越慢。BN就是为解决这一问题提出的。
首先明确神经网络之所以可以先训练,再预测并取得较好效果的前提假设:
我们假设神经网络的训练数据和测试数据是独立同分布的
在神经网络的训练过程中,如果输入数据的分布不断变化,神经网络将很难稳定的学习规律,这也是one example SGD训练收敛慢的原因(随机得到的数据前后之间差距可能会很大)。而网络的每一层通过对输入数据进行线性和非线性变换都会改变数据的分布,随着网络层数的加深,每层接收到的数据分布都不一样,这还怎么学习规律呀,这就使得深层网络训练困难。
BN的启发来源是:之前的研究表明如果在图像处理中对输入图像进行白化操作的话(所谓白化,就是对输入数据分布变换到0均值,单位方差的正态分布)那么神经网络会较快收敛。神经网络有很多隐藏层,图像只是第一层的输入数据,对于每一个隐藏层来说,都有一个输入数据,即前一层的输出。BN将每一层的输入都进行了类似于图像白化的操作,将每层的数据都控制在稳定的分布内,并取得了很好的效果。
2.怎么进行BN?
BN算法是专门针对mini-batch SGD进行优化的,mini-batch SGD一次性输入batchsize个数据进行训练,相比one example SGD,mini-batch SGD梯度更新方向更准确,毕竟多个数据的分布和规律更接近整体数据的分布和规律,类似于多次测量取平均值减小误差的思想,所以收敛速度更快。
3.BN的本质思想
BN究竟对数据的分布做了什么处理,我们来看下面的示意图:
在概率论中我们都学过,数据减去均值除以方差后,将变成均值为0,方差为1的标准正态分布。如果数据分布在激活函数(图中假设为sigmoid)梯度比较小的范围,在深层神经网络训练中将很容易出现梯度消失的现象,这也是深度网络难训练的原因。通过规范化处理后的数据分布在0附近,图中为激活函数梯度最大值附近,较大的梯度在训练中收敛速度自然快。
但是,关键问题出现了,分布在0附近的数据,sigmoid近似线性函数,这就失去了非线性激活操作的意义,这种情况下,神经网络将降低拟合性能,如何解决这一问题呢?作者对规范化后的(0,1)正态分布数据x又进行了scale和shift操作:y = scale * x + shift,即对(0,1)正态分布的数据进行了均值平移和方差变换,使数据从线性区域向非线性区域移动一定的范围,使数据在较大梯度和非线性变换之间找到一个平衡点,在保持较大梯度加快训练速度的同时又不失线性变换提高表征能力。这两个参数需要在训练中由神经网络自己学习,即公式中的γ和β。如果原始数据的分布就很合适,那么即使经过BN,数据也可以回到原始分布状态,这种情况下就相当于恒等变换了,当然这是特殊情况。
4.预测时的BN实现:
在训练时,BN的操作步骤如第一张图所示,那么在预测时,每次只输入一张图的情况下,无法进行均值和方差的计算,此时该怎么实现BN呢?
正是因为训练数据和测试数据是独立同分布的,所以我们可用训练时的所有均值和方差来对测试数据的均值和方差进行无偏估计 。本来mini-batch SGD就是在整体数据量大无法一次性操作的情况下,把数据切割成几部分,用部分近似整体的解决方案。在训练时,将每一个mini-batch的均值和方差记录下,估计出整体的均值和方差如下:
最后一步的计算就是:
这样,就完成了预测时的BN计算。
5.BN计算细节:
首先要明确一点:BN是沿着batch方向计算的,每个神经元都会有一组()在非线性激活前面约束数据。
(1)全连接层的BN计算方法:
假设batch_size=m, 输入的每一个样本有d维,记为
下标表示batch,上标表示一个样本中第几个维度,即第几个神经元。
那么BN计算如下:
中间的过程省略了,其核心思想就是BN是对每一个batch的某一固定维度规范化的,一个样本中有d维,就会求出d组(),即每一个神经元都有一组()。
(2)卷积层的BN计算方法:
在卷积层中,数据在某个卷积层中的维度是[batch, w, h, c],其中batch表示batch_size,w是feature map的宽,h是feature map的高,c表示channels。在沿着batch的方向,每个channel的feature map就相当于一个神经元,经过BN后会得到c组()。此时的BN算法可表示如下:
图中的, 表示一个batch中每个样本相同位置的feature map中的元素,共有batchwh个元素。计算结束后会有c对()。
6.BN的反向传播
7.BN究竟用在什么位置?
原论文中BN操作是放在线性计算后,非线性激活前,即:
其中g()表示激活函数。
这里建议参考一下ResNet_v1和ResNet_v2的用法:
最后一点还需要注意的是,在使用BN后,神经网络的线性计算(WX + b)中的偏置b将不起作用,因为在(WX + b)求均值后b作为常数均值还是b,在规范化的过程中原数据要减去均值,所以b在这两步计算中完全抵消了。但由于BN的算法中有一个偏置项β,它完全可以代替b的作用,所以有BN的计算中可不用b。