SeqGAN: Sequence Generative Adversarial Nets with Policy Gradient

SeqGAN: Sequence Generative Adversarial Nets with Policy Gradient

Lantao Yu, Weinan Zhang, Jun Wang, Yong Yu

Intro

RNN经常用于sequence generation,而且通常是最大化log predictive likelihood。但是后来Bengio在2015年的paper中提出了MLE会带来exposure bias问题:模型迭代的生成序列,产生的下一个token可能从来没有在词库中出现过。为了解决这个问题,Bengio提出了使用scheduled sampling,生成模型在产生写下一个token的时候,输入的是合成的数据(观测到的token)作为prefix。但后来Huszar的paper证明了这个方法并没有从根本上解决问题。另外一个可能的解决方案是在整个生成的序列上面构造loss function,而不是每一次转换。

GAN模型也可以用来处理上面问题。但是将GAN用于seq generation有两个问题 1. GAN是为了产生连续的数据,对于discrete还有问题 2. GAN只有为整个序列产生score/loss,对于partially generated seq没有

通过这个办法解决:把sequence generation当作sequential decision making过程,生成模型就被当作RL中的agent,state是已经生成的token,action是下一个选什么token赖声川。为了给reward,我们使用discriminator来评估生成的序列,而用评估的反馈来指导generative model。为了解决输出是discrete时候,无法使得gradient传递给生成器,使用stochastic parameterized policy,因为policy空间是可导的(但其实这里有一个问题,在David 2014年的一篇paper中,已经证明其实deterministic off-policy也可以是可导的,或许将来可以尝试)。

Sequence Generative Adversarial Nets

训练一个generative model $$G\theta$$,策略是随机的。判别模型discriminative model $$D\phi$$,是用来判断多大概率这个序列是真实数据还是人工生成的。同时使用policy gradient和蒙特卡洛搜索来更新$$G\theta$$。其中MC收到的reward期望是来自于$$D\phi$$判断为真是序列的likelihood。

SeqGAN via Policy Gradient

生成模型的目的是为了返回从开始状态能生成最大reward期望的序列

$$J(\theta) = \mathbb{E}[RT|s_0, \theta] = \sum{y1} G\theta(y1|s_0) \cdot Q{D\phi}^{G\theta} (s_0, y_1)$$

一个问题是如何policy evaluation。考虑使用$$D\phi$$给出的概率值作为reward。但是判别器只给finished sequence提供reward value。既然我们关心的是长期的reward,那么每一个timestep,不应该直关心前面已经生成序列的fitness,同时也要关心最终的outcome。因此,为了评估中间状态的action-state value,在MC搜索上再应用roll-out policy $$G\beta$$,来抽样后面的tokens。如果rollout是non-uniform,就可以有$$\epsilon-greedy$$。,不过这里没有提具体哪种rollout。

Algorithm. 1描述了完整的流程。pre-train可以有效提训练效率。

The Generative Model For Sequences

使用LSTM

The Discriminative Model for Sequences

DNN,CNN,RCNN是常用的序列分类算法。这里使用CNN。

Synthetic Data Experiments

测试是使用了一个随机的LSTM作为true model,来产生real data distribution。

Evaluation Metric

使用人工合成的LSTM好处就是能够准确的评估性能。(evaluation在generation model中也是一个困难点)MLE是为了最小化真实数据和生成数据的cross-entropy,但是更有效的评估方法是让人工观察(这个会不会有些主观)。

Real-world Scenario

在generated texts和human-created texts进行评估时,用BLUE scores。

被证明SeqGAN比MLE有效。

Appendix

正好看这篇paper的weekday,Paper Weekly请来了作者进行介绍。附上slide

文章结尾说它这是使用GAN生成离散序列最早的应用,但我记得2017年年初的时候就有人用GAN model生成SMILES序列,回头再比较一下。

results matching ""

    No results matching ""