简介
在前两篇文章手写一个全连接神经网络用于MNIST数据集 和 全连接神经网络之反向传播算法原理推导中,我们讨论了全连接神经网络是如何应用梯度下降算法来学习权重和偏置,以及反向传播算法的工作原理。在些例子中,我们都使用了均方差损失函数,因为它的形式非常直观,直接刻画了期望输出与真实值之间的差值。但是它也有一些弊端,本文将重点讨论均方差损失函数存在的问题,并引出交叉熵损失函数。
均方差损失函数
先回忆一下均方差损失函数的定义:
这里的代表的是网络中所有权重的集合,是所有的偏置,是训练输入数据的个数,是输入对应的标签,是表示当输入为时神经网络输出的向量,求和则是在总的训练输入上进行的。
再回忆一下,随机梯度下降算法中权重和偏置更新的方程,如下:
通过上式可以知道,权重的更新跟损失函数的偏导数 和 有关系。在学习率一定的情况下,偏导数越大,权重和偏置更新的越快。反之,越慢。
再根据全连接神经网络之反向传播算法原理推导提及到的反向传播的四个方程,如下:
我们可以看到偏导数 和 又和误差有关,而误差又与激活函数的导数有关系。费了这么多口舌,其实是想说明权重和偏置的更新最终会和激活函数的导数有关系。而这会导致一个什么问题呢?
先看下激活函数的图像:
对Sigmoid函数求导得:
的图像如下:
可以看到,当输入较大或者较小的时候,的值趋向于1或者0,此时的值就非常小,接近于0。这会导致梯度消失的问题,下面以一个小例子来说明。
令,以下图的神经网络结构为例,它每一层只包含一个神经元:
我们可以求得对偏置的偏导数如下:
根据Sigmoid导数的图像可知,的最大值为0.25,而一般随机初始化的权重和偏置一般都小于1,随着层数的增加,上式中的链式求导连乘也会越来越深,则的值将会越来越小,导致越靠近输入层的神经元的权重和偏置越难得到更新,即便网络训练完了,这些神经元的权重和偏置还是跟初始值差不多,那么此时的深度神经网络就只相当于后几层的浅层学习网络结构了。
当然解决这个问题的办法有很多,常见的就是换成Relu激活函数,但这不是本文讨论的重点,本文将从损失函数的角度来解决这个问题。
交叉熵损失函数
如何来解决上一节提出问题呢?先给出结论,我们可以使用交叉熵损失函数来代替均方差损失函数,定义如下:其中是训练样本的总数,求和是在所有的训练输入上进行的,是对应的输出,是神经元的输出,,其中。
不同于均方差损失函数的定义,我们可以一眼看出来它就是衡量了期望输出与实际输出之间的误差,交叉熵损失函数貌似不太能看出来它是如何衡量期望输出和实际输出之间的误差的。下面详细解释一下,主要有2点:
- 它是非负的,可以看到求和符号中的每一项都是负数,累加起来之后仍然是负数,最后再添加一个负号,所以最后的结果是非负的。
- 如果或,那么趋近于0,如果或,那么会变得很大。即,神经元输出与真实值相近,交叉熵趋于0,反之,趋于无穷大,这很符合直觉。
那它是如何解决因梯度消失而导致的参数学习缓慢的问题呢?我们来求一下关于权重的偏导数:
注意上式证明的过程中,用到了(1)式的结论。
可以看到,在使用交叉熵损失函数之后,计算对权重的偏导数的时候,最终的式子里面没有再出现,并且可以看出,偏导数受到的影响,即真实输出与实际输出的差值。当差别越大时,偏导数越大,权重学习速度就会越快。避免了均方差损失函数中因为导致的学习缓慢。
同理,关于偏置的偏导数如下:
交叉熵损失函数与均方差损失函数之间关系
上一节直接给出了交叉熵的定义,而并没有描述它是如何得出的,它究竟表示什么,最开始的研究者是如何想到这个概念的呢?
同样采用均方差损失函数,定义如下:
其中是神经元的输出,是目标输出。分别对和求偏导得:
既然我们发现了权重学习缓慢的原因是因为的值过小,那么我们何不直接去掉上面两个偏导数中包含的呢,那么可以得到:
又, 由(1)式可知,将其带入(2)式得:
又我们令,故结合起来可得:
其中是一个常量。上式是一个单独的训练样本对损失函数的贡献,为了得到整个损失函数,需要对全部训练样本进行平均,得到;
这个形式就跟我们上一节给出的交叉熵损失函数定义很相似了。可以看到,交叉熵损失函数不是凭空得来的,而是以一种自然而然的方式计算出来的。
交叉熵函数的正向推导
上一节给出了交叉熵损失函数与均方差损失函数的关系。通过改进均方差损失函数的偏导数,然后通过不定积分反向推导出新的损失函数,即交叉熵损失函数。
下面从概率的角度来推导一下交叉熵损失函数。
我们的识别手写数字的神经网络的输出层会通过激活函数,函数函数的输出在0~1之间,这个输出可以认为是一个概率值,即代表了预测当前神经元为真的可能性。值越大,代表当前神经元输出为真的可能性越大。比如第1个神经元输出的值最大,则代表神经网络认为这个数字是“0”。
即对于输出层的某个神经元,我们用来代表该神经元预测为真的概率:
式(11)只不过是聚合了式(9)和(10),本质上是一样的。我们希望的概率越大越好,为了计算方便,我们对式(11)引入函数,这并不会改变其单调性,有:
我们的目的是希望越大越好,先在上式前添加一个负号,变成,则现在的目的就变成了希望它越小越好,这么做更加符合直觉。接着我们可以定义损失函数:
由于这只是针对单个训练样本的损失函数,对于全体个训练样本而言,我们可以累加起来求平均,得到如下形式:
这与上面提及的交叉熵损失函数是一致的。