本文介绍了一种文本生成的方法,关键技术为生成对抗网络(GAN)和强化学习(RL)。
内容主要包括方法的数学原理、实现思路和具体实现细节。
2019.3.3 初稿 初定内容大纲
2019.3.11 增加 循环神经网络(Recurrent Neural Network RNN) 内容
2019.3.17 增加 长短时记忆模型(Long Short-Term Memory LSTM) 内容
关键技术
我们要使用的是LSTM,但是LSTM是一个RNN模型的改进,其目的是为了适应原始RNN梯度消失、文本相关性等缺点,所以首先我们要知道什么是RNN。
循环神经网络(Recurrent Neural Network RNN)
【Karpathy】The Unreasonable Effectiveness of Recurrent Neural Networks
概述
- 序列模型
- 能够体现序列元素之间的相关性
- 相关性的体现,使其被运用于文本生成、机器翻译、语音识别等
RNN与普通神经网络(Vanilla Neural Networks)不同的是,它不是单一的input-hidden layer-output的形式,而是输出、隐层、输出都可实现序列化,如下图
蓝色方框是output,红色为input,绿色则是中间层,我们可以把绿色中间层理解成input的中间状态(state)。
(1)普通神经网络只是one-to-one的关系,而RNN则有其他四种形式:(2)one-to-many可以理解成由图片作为输入,输出对图片的文字说明(e.g. 图片描述、诗歌),其中每一个蓝色方框都是一个词;(3)many-to-one可以理解为对一句话做情感分析;(4)many-to-many的第一种形式则可以理解为机器翻译;(5)many-to-many的第二种形式则可以理解为视频分类(对视频的每一帧进行分类或者预测)。
我们通过上图,可以看到RNN的另一个特点:每一个中间层,或者说每一个 input 状态都与前一个 input 状态相关,这就体现了相邻 input 之间的相关性。
比如说我输入一句话作为我的 input ,”猫吃鱼“,其中 input1 ”猫“, input2 ”吃“, input3 ”鱼“,”猫“的中间状态由前两个 input 的中间状态决定,即,”吃“”鱼“这两个 input 和”猫“有关,这符合我们的逻辑,猫不就是吃鱼的么。
上下文构成了联系,而不是被单一孤立开来,RNN实际上实现了部分我们对智能体的期待,所以Karpathy也说,RNN是最令他激动的模型。
模型
我们用x<i>来表示第 i 个输入值,y<i>来表示第 i 个输出值,举个例子,我有一句话:
My name is ClowNaxcvd Shi
那么每一个单词分别对应一个输入值,即,My 对应 x<1>,name 对应 x<2>……
我想要判断这句话中有哪些单词表示名字,那么x<t>对应的输出值就为 y<t> 。如果是名字,y<t>就等于1,否则就等于0。
因为计算机只认数字,其他的东西一律看不懂,所以还涉及到将单词向量化的问题。简单来说,就是将文字数字化,所以我们输入的x<t>实际上应该是一组数字,不是单词。
我们首先要将单词变成数字,这一点知道就好,为了方便理解,我们就当计算机识字好了。
对于一般神经网络来说,如果我们要做名字的识别,我们一般会这样做:
将 x<1> 传入神经网络,输出 y<1>,
图中 a<1> 表示中间状态,一般会用激活函数,将其变成y<1>。
同理,将 x<2> 传入神经网络,输出 y<2>,
这样做的问题很明显,我们不能将前一个输入和后一个输入联系起来,每一个单词都是孤立的。但是我们可以看到,ClowNaxcvd 和 Shi ,我们不会将 Shi 理解成一个东西,因为它跟在 ClowNaxcvd 后面,符合姓名的逻辑。
于是我们将这两者结合:
x<1> 的中间状态 a<1> 传给 x<2>,共同计算出它的中间状态 a<2>,然后再通过a<2>计算出 y<2>。
我们将其推广到一般情况,如下图。其中我们看到一开始有一个 a<0>,是因为a<t> 是由a<t−1>、x<t>决定的,而 a<1>之前已经没有输入值和中间值了,所以我们额外加一个 a<0> ,这样 a<1> 就合乎逻辑了。
逻辑讲清楚了,接下来我们看该怎么计算出 a<t> y<t>。
前向传播
我们从左向右看,一般开始先输入 𝑎<0>,它是一个零向量。接着就是前向传播过程,先计算激活值 𝑎<1>,然后再计算 y<1>。
𝑎<1>=g1(𝑊𝑎aa<0>+𝑊a𝑥x<1>+𝑏a)
^𝑦<1>=g2(𝑊yaa<1>+𝑏y)
我将用这样的符号约定来表示这些矩阵下标,举个例子, Wax第二个下标x意味着 Wax要乘以某个x类型的量,然后第一个下标a表示它是用来计算某个a类型的变量。同样的,可以看出这里的 Wya乘上了某个a类型的量,用来计算出某个y类型的量。
循环神经网络用的激活函数经常是 tanh,不过有时候也会用 ReLU,但是 tanh 是更通常的选择,我们有其他方法来避免梯度消失问题,我们将在之后进行讲述。选用哪个激活函数是取决于你的输出𝑦,如果它是一个二分问题,那么我猜你会用 sigmoid 函数作为激活函数, 如果是𝑘类别分类问题的话,那么可以选用 softmax 作为激活函数。不过这里激活函数的类型取决于你有什么样类型的输出y,对于命名实体识别来说𝑦只可能是 0 或者 1,那我猜这里第二个激活函数 g 可以是 sigmoid 激活函数。
更一般的情况下,在 t 时刻,
a<t>=g1(𝑊aaa<t−1>+𝑊axx<t>+ba)
ˆy<t>=g2(𝑊yaa<t>+by)
下图提供了更直观的过程。其中 h 就是 y ,只是不同人的表示方法有所差异而已,这点不用细究了。
RNN虽然理论上可以很漂亮的解决序列数据的训练,但是它也像DNN一样有梯度消失时的问题,当序列很长的时候问题尤其严重。
同时,随着序列加长,模型很难捕捉前面的输入与后面的输入之间的关系,举个例子,我有两句话
The person, who standing there and……, is lively.
The people, who standing there and……, are lively.
除了主语的单复数之外,其他的成分一模一样,而且在主谓之间,有一句可以无限长的从句,这种情况下,模型在判断后面到底是is还是are的时候,就非常困难。一方面存在梯度爆炸或者梯度消失问题,另一方面,主谓之间的依赖性(就是 a<i> )会随着从句的边长而被稀释。
因此,上面的RNN模型一般不能直接用于应用领域。在语音识别,手写书别以及机器翻译等NLP领域实际应用比较广泛的是基于RNN模型的一个特例LSTM,下面我们就来讨论LSTM模型。
长短时记忆模型(Long Short-Term Memory LSTM)
【Colah】Understanding LSTM Networks
【李宏毅】机器学习-Recurrent Neural Network(2019)
概述
原始RNN模型只有一个计算流程,x<t>→a<t>→y<t>,如下图所示。
而LSTM则增加了三个计算流程,如下图。
这几个流程是什么意思呢?
LSTM的本质,其实就是对a<t>的一系列改进,也就是对输入x<t>状态的改进。
更形象地说,还是那两句话:
The person, who standing there and……, is lively.
The people, who standing there and……, are lively.
我的目的就是,即使主谓之间隔着这么远的距离,我仍然能够让计算机知道,后面的谓语应该是单数还是复数,LSTM实际上就是通过增强了a<t>对a<t+i>的影响,来增加x<t>对x<t+i>的影响。
这种影响的增强,就是通过的三个门来实现的。
- 输入门
保留多少 当前时刻输入 到 当前时刻单元状态。
- 遗忘门
保留多少 前一时刻单元状态 到 当前时刻单元状态。
- 输出门
保留多少 当前时刻单元状态 留给 下一时刻的单元状态。
下面我们一步步详述。
门机制
遗忘门
首先,我需要利用上一个输入量得到的输出为参考量,去判断上一个单元状态对现在的输入量,有多少联系。
这个ft是一个尺度,在0-1之间,用来度量上一个单元状态对此刻单元状态的约束。
输入门
然后,我要判断,当前输入量有没有价值,是否需要把它存到我的单元状态里面。
再然后,我综合通过遗忘门和输入门的值,得到了当前单元状态量。这个单元状态量,结合了前一个输出量与现在输入量的综合关系。
输出门
最后,我需要一个尺度,来判断我算出来单元状态量有多少意义,这个尺度就是输出门。
总结以上叙述,整个流程可以归纳为:
- 通过输入门,得到当前时刻输入值的对当前单元状态的价值;
- 通过遗忘门,得到前一时刻单元状态对当前单元状态的价值;
- 累加1、2两步的值,得到的当前单元状态,表示的意义为:当前输入值和上一输入值对当前状态的意义;
- 通过输出门,得到当前状态值对输出结果的影响。
整个过程可以跟着上图过一遍。
生成对抗网络(Generative Adversarial Networks)
强化学习方法:策略梯度(Policy Gradient)
该方法可参考我以前的学习笔记,【强化学习】Policy Gradient
写得还算详细,故此不再赘述。
实现思路/逻辑
LTSM
GAN+PG/GAN+RL
实现细节/代码
LTSM
GAN+PG/GAN+RL
Reference
【Karpathy】The Unreasonable Effectiveness of Recurrent Neural Networks
【Colah】Understanding LSTM Networks
【李宏毅】机器学习-Recurrent Neural Network(2019)
- 本文作者: YA
- 本文链接: http://www.yuuuuang.com/2019/03/11/GAN-RL-文本生成的方法/
- 版权声明: 本博客所有文章除特别声明外,均采用 MIT 许可协议。转载请注明出处!