Meta-Learning with Memory-Augmented Neural Networks

Meta-Learning with Memory-Augmented Neural Networks

Adam Santoro, Sergey Bartunov, Matthew Botvinick, Daan Wierstra, Timothy Lillicrap

Google, SBOS

Intro

包含memory的neural network可以用来提供一个有效的meta-learning方法。这里就提出了使用memory-augmented neural network(MANN)。

2 Meta-Learning Task Methodology

对于meta-learning,我们学习一个参数能够减小expected cost,在某个数据的分布 $$p(D)$$ 上

$$\theta^* = argmin\theta \, \mathbb{E}{D \sim p(D)} [\mathcal{L}(D; \theta)]$$

在time step t的输入是$$xt, y{t-1}$$,也就是上一步的输入。这里的模型是为了预测 $$p(yt | x_t, D{1:t-1}; \theta)$$(保证了模型能够将有用信息存储在memory中)。

3 Memory-Augmented Model

3.1 Neural Turing Machine

Neural Turing Machine是一个differentiable implementation of MANN。使用一个controller来进行memory存储。controller也就是一个前向网络或者是LSTM。controller会和外部的一个memory module进行交互(每一个time-step),而memory module能够存储重要的长期和短期记忆。

在表现形式上controller就是能够将$$x_t$$作为输入,输出一个key $$k_t$$。然后这个key会要么存储到meomry matrix $$M_t$$中的某一行,要么会被用来retrieve a particular memory i from a row, i.e. $$M_t(i)$$。

当retrieve memory(read)的时候,使用基于cosine similarity将模型取出一个read-weight vector。

$$ K(k_t, M_t(i)) = \frac{k_t \cdot M_t(i)}{| k_t | \cdot | M_t(i) |}\ w_t^r(i) = softmax(K(k_t, M_t(i)))

$$

然后一个memory $$r_t$$ is retrieved using this read-weight vector

$$ r_t = \sum_i w_t^r(i)M(i)

$$

这个memory在controller中会被当作classifier的input,或者下一层controller的input。

3.2 Least Recently Used Access

这个NTM的方法对于sequence-based task十分有效,但对于独立于序列的聚合信息不是特别最优。就此提出了Least Recently Used Access(LRUA)。大致想法就是将内容和访问次数联系起来,新的信息有两种加入memory的方式

  1. 放到least used location,保存recently encoded information
  2. 放到most recently / last used location,同时将memory内容更新,比如newer,more relevant info

这两者的区别是根据这两者关系的两种理解:previous read weights and weights scaled according to usage weights $$w_t^u$$。这个usage weight在每一个time-step都会根据前一步的usage weight和当前一步的read和write weights进行更新:

$$ wt^u = \gamma w{t-1}^u + w_t^r + w_t^w

$$

$$\gamma$$是decay parameter,$$w_t^r$$的计算前已经提到。而least-used weight $$w_t^{lu}$$ 可以通过$$w_t^u$$计算。首先引入概念$$m(v,n)$$表示vector $$v$$中第n小的元素。那么$$w_t^{ul}$$就是

$$ w_t^{ul} = \begin{cases} 0, \text{ if } w_t^u(i) > m(w_t^u, n)\ 1, \text{ otherwise} \end{cases}

$$

其中 $$n$$ 是read head的数目。然后为了更新write weights

$$ wt^w = \sigma(\alpha) w{t-1}^r + (1-\sigma(\alpha)) w_{t-1}^{lu}

$$

最开始的初始化中,memory里面都是空的,least used memory location 是0. 然后更新到memory就用了这个write weights:

$$ Mt(i) = M{t-1}(i) + w_t^w(i) k_t, \forall i

$$

这样这个memory就被写到了zeroed memory slot或者是一个least used memory被擦掉替代。

Appendix

智能单元 最前沿:百家争鸣的Meta Learning/Learning to learn

我们可以看到,网络的输入把上一次的y label也作为输入,并且添加了external memory存储上一次的x输入,这使得下一次输入后进行反向传播时,可以让y label和x建立联系,使得之后的x能够通过外部记忆获取相关图像进行比对来实现更好的预测。

results matching ""

    No results matching ""