循环神经网络RNN

1. 概述

循环神经网络(Recurrent Neural Networks, RNN)主要用于时序数据,最常见的时序数据如文章,视频等,tt时刻的数据与t1t-1时刻的数据存在内在的联系。RNN模型能够对这样的时序数据建模。

2. 算法原理

RNN模型的基本结构如下所示(图片来自参考文献):

在这里插入图片描述

如上图所示,循环神经网络通过使用自带反馈的神经元,能够处理任意长度的时序数据,对此结构按照时间展开的形式如下所示(图片来自参考文献):

在这里插入图片描述

2.1. RNN的结构

上图中给出了RNN的内部结构,RNN根据输入输出主要可以分为以下三种:

  • 多输入单输出,如文本的分类问题;
  • 单输入多输出,如描述图像;
  • 多输入多输出,又分为等长或者不等长两种情况,等长如机器作诗,不等长如seq2seq模型;

这里以多输入单输出的情况为例,多输入单输出的具体结构如下所示:

在这里插入图片描述

2.2. RNN的计算过程

假设对于一个长度为TT的序列{x1,x2,,xT}\left \{ x_1,x_2,\cdots ,x_T \right \},其中xi=(xi,1,xi,2,,xi,n)x_i=\left ( x_{i,1},x_{i,2},\cdots ,x_{i,n} \right )是一个nn维的向量,假设RNN的输入xx的维度为n×1n\times 1,隐含层状态hth_t的维度为H×1H\times 1,RNN的状态更新公式为:

ht=f(Uht1+Wxt+b)h_t=f\left ( Uh_{t-1}+Wx_t+b \right )

通常h0h_0会设置为全00的向量。模型中的参数UU的维度为H×HH\times HWW的维度为H×nH\times nbb的维度为H×1H\times 1,对于具体的分类问题,其输出为:

y^=softmax(Woht+bo)\hat{y}=softmax(W_oh_t + b_o)

假设对于分类问题有cc个类别,则参数WoW_o的维度为c×Hc\times Hbob_o的维度为c×1c\times 1。最终的损失函数为:

J(U,W,b,Wo,bo)=1mi=1mL(y(i),y^(i))J\left ( U,W,b,W_o,b_o \right )=\frac{1}{m}\sum _{i=1}^mL\left ( y^{(i)},\hat{y}^{(i)} \right )

其中

L(y(i),y^(i))=j=1cyj(i)logy^j(i)L\left ( y^{(i)},\hat{y}^{(i)} \right )=-\sum_{j=1}^{c}y_j^{(i)}log\: \hat{y}_j^{(i)}

2.3. RNN中参数的求解

对于RNN模型,通常使用BPTT(BackPropagation Through Time)的训练方式,BPTT也是重复的使用链式法则,对于RNN而言,损失函数不仅依赖于当前时刻的输出层,也依赖于下一时刻。为了简单起见,以一个样本为例,此时的损失函数可以记为L(y,y^)L\left ( y,\hat{y} \right ),模型的参数为U,W,b,Wo,boU,W,b,W_o,b_o,具体的求解过程如下所示:

首先对y^\hat{y}重新定义,样本属于第(i)(i)个类别的预测值为:

y^(i)=eWoiht+boil=1ceWolht+bol\hat{y}_{(i)}=\frac{e^{W_{oi}h_t+b_{oi}}}{\sum _{l=1}^{c}e^{W_{ol}h_t+b_{ol}}}

LWoi\frac{\partial L}{\partial W_{oi}}Lboi\frac{\partial L}{\partial b_{oi}}分别为:

LWoi=(y(i)y^(i))ht\frac{\partial L}{\partial W_{oi}}=-\left ( y_{(i)}-\hat{y}_{(i)} \right )h_t

Lboi=(y(i)y^(i))\frac{\partial L}{\partial b_{oi}}=-\left ( y_{(i)}-\hat{y}_{(i)} \right )

假设ff为tanh,而tanh(a)tanh(a)的导数为1tanh(a)21-tanh(a)^2,以LU\frac{\partial L}{\partial U}为例:

LU=Ly^y^hthtU+Ly^y^hththt1ht1U++Ly^y^hththt1h1h0\frac{\partial L}{\partial U}=\frac{\partial L}{\partial \hat{y}}\cdot \frac{\partial \hat{y}}{\partial h_t}\cdot \frac{\partial h_t}{\partial U}+\frac{\partial L}{\partial \hat{y}}\cdot \frac{\partial \hat{y}}{\partial h_t}\cdot \frac{\partial h_t}{\partial h_{t-1}}\cdot \frac{\partial h_{t-1}}{\partial U}+\cdots +\frac{\partial L}{\partial \hat{y}}\cdot \frac{\partial \hat{y}}{\partial h_t}\cdot \frac{\partial h_t}{\partial h_{t-1}}\cdots \frac{\partial h_1}{\partial h_0}

htht1=[1tanh(ht)2]U\frac{\partial h_t}{\partial h_{t-1}}=\left [ 1-tanh\left ( h_t \right )^2 \right ]\cdot U,这是个小于1的数,从上面的公式我们发现,时序数据越长,后面的梯度就趋于0。

2.4. RNN存在的问题

从上述的BPTT过程来看,RNN存在长期依赖的问题,由于反向传播的过程中存在梯度消失或者爆炸的问题,简单的RNN很难建模长距离的依赖关系。

参考文献

[1] Understanding LSTM Networks