Meta-Learning with Memory-Augmented Neural Networks
Meta-Learning with Memory-Augmented Neural Networks
Adam Santoro, Sergey Bartunov, Matthew Botvinick, Daan Wierstra, Timothy Lillicrap
Google, SBOS
Intro
包含memory的neural network可以用来提供一个有效的meta-learning方法。这里就提出了使用memory-augmented neural network(MANN)。
2 Meta-Learning Task Methodology
对于meta-learning,我们学习一个参数能够减小expected cost,在某个数据的分布 $$p(D)$$ 上
$$\theta^* = argmin\theta \, \mathbb{E}{D \sim p(D)} [\mathcal{L}(D; \theta)]$$
在time step t的输入是$$xt, y{t-1}$$,也就是上一步的输入。这里的模型是为了预测 $$p(yt | x_t, D{1:t-1}; \theta)$$(保证了模型能够将有用信息存储在memory中)。
3 Memory-Augmented Model
3.1 Neural Turing Machine
Neural Turing Machine是一个differentiable implementation of MANN。使用一个controller来进行memory存储。controller也就是一个前向网络或者是LSTM。controller会和外部的一个memory module进行交互(每一个time-step),而memory module能够存储重要的长期和短期记忆。
在表现形式上controller就是能够将$$x_t$$作为输入,输出一个key $$k_t$$。然后这个key会要么存储到meomry matrix $$M_t$$中的某一行,要么会被用来retrieve a particular memory i from a row, i.e. $$M_t(i)$$。
当retrieve memory(read)的时候,使用基于cosine similarity将模型取出一个read-weight vector。
$$ K(k_t, M_t(i)) = \frac{k_t \cdot M_t(i)}{| k_t | \cdot | M_t(i) |}\ w_t^r(i) = softmax(K(k_t, M_t(i)))
$$
然后一个memory $$r_t$$ is retrieved using this read-weight vector
$$ r_t = \sum_i w_t^r(i)M(i)
$$
这个memory在controller中会被当作classifier的input,或者下一层controller的input。
3.2 Least Recently Used Access
这个NTM的方法对于sequence-based task十分有效,但对于独立于序列的聚合信息不是特别最优。就此提出了Least Recently Used Access(LRUA)。大致想法就是将内容和访问次数联系起来,新的信息有两种加入memory的方式
- 放到least used location,保存recently encoded information
- 放到most recently / last used location,同时将memory内容更新,比如newer,more relevant info
这两者的区别是根据这两者关系的两种理解:previous read weights and weights scaled according to usage weights $$w_t^u$$。这个usage weight在每一个time-step都会根据前一步的usage weight和当前一步的read和write weights进行更新:
$$ wt^u = \gamma w{t-1}^u + w_t^r + w_t^w
$$
$$\gamma$$是decay parameter,$$w_t^r$$的计算前已经提到。而least-used weight $$w_t^{lu}$$ 可以通过$$w_t^u$$计算。首先引入概念$$m(v,n)$$表示vector $$v$$中第n小的元素。那么$$w_t^{ul}$$就是
$$ w_t^{ul} = \begin{cases} 0, \text{ if } w_t^u(i) > m(w_t^u, n)\ 1, \text{ otherwise} \end{cases}
$$
其中 $$n$$ 是read head的数目。然后为了更新write weights
$$ wt^w = \sigma(\alpha) w{t-1}^r + (1-\sigma(\alpha)) w_{t-1}^{lu}
$$
最开始的初始化中,memory里面都是空的,least used memory location 是0. 然后更新到memory就用了这个write weights:
$$ Mt(i) = M{t-1}(i) + w_t^w(i) k_t, \forall i
$$
这样这个memory就被写到了zeroed memory slot或者是一个least used memory被擦掉替代。
Appendix
智能单元 最前沿:百家争鸣的Meta Learning/Learning to learn
我们可以看到,网络的输入把上一次的y label也作为输入,并且添加了external memory存储上一次的x输入,这使得下一次输入后进行反向传播时,可以让y label和x建立联系,使得之后的x能够通过外部记忆获取相关图像进行比对来实现更好的预测。