博客
关于我
强烈建议你试试无所不能的chatGPT,快点击我
深度学习原理之理解反向传播[1] --- 简单全连接网络(BP)
阅读量:6899 次
发布时间:2019-06-27

本文共 2442 字,大约阅读时间需要 8 分钟。

写在前面

一开始看的是cs231n的min-char-rnn.py,无奈怎么都看不懂反向传播那一块的内容。后也是不断搜索,发现那不是普通的bp,也不是正规的bptt,所以要理解karpathy的代码我还需要更深地理解反向传播,从头开始。

最早看了这么一篇 ,还颇以为详尽明白,但是套到min-char-rnn里一点也讲不通,后来明白了,这篇博文讲的是bptt的一部分,甚至还不全,其参考的是 但却缺斤少两,不过后一篇在推导上也是很有不足,不谈了,大家不必去看这两篇,甚至和也不必先看(这个都是乱花眼的公式,虽然其表达的意思很准确,但是过于抽象,不利于我们从实现代码的细节角度上理解),经过我的实践,我认为我们应该先从讲的最好最清楚的BP开始,从开始,如果你听从我的建议,你应该会少走很多的弯路。

so,let's begin.

正文

整体介绍一下

 是 Neural Networks and Deep Learning 这本书的第二章,大家要是不习惯英文也可以去找翻译版。下面我不完全按照书里来,但是我会解释的很清楚的。

首先要反向传播,我们需要有向前传播。

向前传播很简单,画个图,图上每个点都需要做个线性计算,一层一层传递,这就是前向传播。

图片描述(图一)

layer1当成输入层,layer2是hidden_layer(隐含层),layer3是输出层,这个要记住了,其实深度网络有一点很难理解的就是图的抽象思维。上面这张图如果是在rnn中,变成了下面这样。

图片描述(图二)

图中的s看上去是一个点,其实里面就是上上图拆解开来的全连接,因为我们是矩阵与向量操作!!所以这点先明白清楚。

那么回到图一。

w是一个权值矩阵,同样的还需要b偏置向量,再看下图。

图片描述(图三)

a是输出层的每个神经元的输出。

图片描述

这就是输出,l代表的是第几层,j,k都是来代表第几个神经元。这里用的是sigmoid函数,也就是括号外的sigma符号,现在我们可以暂时不管这个函数是什么,因为我们还可以用relus,tanh等等,其实sigmoid现在已经很少用了,但是为了读者不那么麻烦去查,我还是把这个公式贴出来。

图片描述

好了,这个就是一个简单的全连接的网络的公式解释。为了我们后面更方便地使用公式,我们在这里许澄清一点,以上的都是element-wise也就是单元素操作,如下。

图片描述

我们可以把公式美化成下面这个版本。

图片描述

在这里我需要提一点,书中讲到一个quirk就是w的角标j,k的设置为什么是这样的,因为这样我们不用把上面这个公式中的w转置了,但其实现实代码中还是转置了,所以当你们看到现实代码中有转置就不用那么疑惑了。

一般来讲,我们往下走之前还需定义几个公式。

$$z^l≡w^la^{l−1}+b^l$$

$$z^l_j=\sum{kw^l_{jk}a^{l−1}_k+b^l_j}$$
$$a^l=σ(z^l)$$

以上就是简单的前向传播。但是我们有个最重要的损失函数才是我们需要拟合的函数,这个函数就是利用输出层的输出作为损失函数的输入。下面我们介绍损失函数。

介绍下损失函数

损失函数是干什么用的?我们就是利用损失函数求∂C/∂w 和 ∂C/∂b进行反向传播来更新w矩阵和b向量。我们可以先定义一个最简单的损失函数如下(类似于方差)

图片描述

图片描述

好清楚明白的,我就不过多废话了。

四个最最关键的公式

现在开始是核心部分,BP的核心。

先介绍一个中间量,残差。残差不等于误差。

$$δ^l_j$$

代表着第l层第j个神经元的残差。

图片描述

如上图,一个形象可爱的说法是每个神经元的地方都有一个恶魔,这个恶魔会给输入带来一点偏差,

$$Δz^l_j$$
这样会造成输出
$$σ(z^l_j+Δz^l_j)$$
总的损失会改变
$$\frac{∂C}{∂z^l_j}Δz^l_j$$

现在呢,这个恶魔成为了一个好的恶魔,能够帮助我们减少损失。那么怎么去减少损失呢?下面就用数学来说明。

图片描述

这代表着第l层的第j个神经元所获得的残差。

我们的模型比较简单,我们需要的bp公式只有四个,我先不讨论怎么证明这四个公式的正确性,我们先来看看这四个公式,证明留在后面。

图片描述

bp1在成为上图中的bp1之前不长这样,

图片描述

这是个十分自然的表达式,左边的偏导代表第j个神经元的cost的改变速率,这么来理解,如果很小的话,说明这个点的误差不大。右边的表示sigma函数在其输入z上的改变速率。

为了方便点,假设我们的C是

$$C=\frac{1}{2}∑_j(y_j−a^L_j)^2$$

那求导后就很简单了。

$$\frac{∂C}{∂a^L_j}=(a^L_j−y_j)$$

由于这个初始的bp1是分量计算的方式,我们可以重写成

图片描述

$$∇_aC=(a^L−y)$$

这个公式计算的是最后一层反向传播的公式,因为只有最后一层才有loss_function,而中间的隐含层的反向传播利用的公式则是下面的bp2,这个很重要,我一开始都没理解,最后一层和中间的隐含层是不一样的。

图片描述

理解这个公式如同理解第一个公式一样,只是这里利用了前面的残差作为输入。

后两个就不讲了,只要有了这四个公式作为工具,我们就可以轻松反向传播。

证明

从bp1开始,首先我们先定义什么是残差,

图片描述

根据链式法则我们能得到

图片描述

这里我们需要结合一下图来解释更清楚

图片描述

比如这张图,L=3,也就是我们在输出层,z是这层的神经元的输入,所以当j!=k时,偏导是为0的,因此,我们可以简化公式

图片描述

下面证明bp2,因为这个是隐藏层的传播,我在之前已经有提醒过了,所以这里的神经元的输出a其实是下一个连接到的神经元的输入z,不过是层数加1。所以,

图片描述

图片描述

图片描述

图片描述

以上证明了bp2。

bp3证明其实就是

图片描述再乘偏z偏b

偏z偏b等于1,所以证得。bp4证略。

在下一篇[深度学习原理之理解反向传播[2] --- 简单全连接网络(BP)]()我会来介绍代码,不仅仅是这个简单版的代码,主要讲讲min-char-rnn.py的代码,很多的疑问。

转载地址:http://mgpdl.baihongyu.com/

你可能感兴趣的文章
Linuxドライバ_LDD3メモ_ハードウェアとの通信
查看>>
数学之美系列四 -- 怎样度量信息?
查看>>
用Access+SharePoint 来收集数据
查看>>
Nginx 的 Location 配置指令块
查看>>
Spark小课堂Week5 Scala初探
查看>>
go练习1-翻转字符串
查看>>
java第一天学习笔记
查看>>
GPS定位为什么要转换处理?高德地图和百度地图坐标处理有什么不一样?
查看>>
冲刺博客 五
查看>>
poj 2389 大整数乘法
查看>>
JSON.stringify JSON.parse
查看>>
java中二进制的程序表示_Java程序检查两个数字的二进制表示形式是否为字谜
查看>>
java web maven 框架_maven web框架搭建
查看>>
java实现数据排序_分析Java程序员如何实现数据排序
查看>>
java libraries在哪_java.library.path在哪? | 学步园
查看>>
java数据结构循环链表_JAVA 数据结构链表操作循环链表
查看>>
php如何连接access,PHP如何连接Access数据库_PHP教程
查看>>
通过php使用cmd命令,window系统下使用cmd执行php命令
查看>>
项目重构经验php转java,这几年从 PHP 转到 Java 的有成功案例吗?
查看>>
java中多个条件模糊查询,带条件的查询—模糊查询
查看>>