Meta Networks


Tsendsuren Munkhdalai, Hong Yu

Slow weights: the model's weights updated in the conventional way with SGD.

Fast weights: weights of the current model predicted by another model.

Meta information: the loss and its corresponding gradients.
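
A minimal PyTorch sketch of this distinction (layer sizes and the fast-weight generator `fast_gen` are illustrative, not the paper's architecture):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

base = nn.Linear(4, 2)              # slow weights W, trained with SGD
fast_gen = nn.Linear(2 * 4, 2 * 4)  # hypothetical net that predicts fast weights

x, y = torch.randn(3, 4), torch.tensor([0, 1, 0])
loss = F.cross_entropy(base(x), y)

# meta information: the loss and its gradient w.r.t. the slow weights
grad = torch.autograd.grad(loss, base.weight)[0]

# fast weights: predicted by another model from the meta information
W_fast = fast_gen(grad.flatten()).view_as(base.weight)

# slow weights: conventional SGD update
with torch.no_grad():
    base.weight -= 0.1 * grad
```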

3 Meta Networks

There are two main learning modules: the base learner and the meta learner.

For a sequence of tasks, each task has a support set (of size N) and a training set (of size L). Training consists of three procedures: acquisition of meta information, generation of fast weights, and optimization of slow weights.
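
A rough sketch of the per-task data layout, assuming an N-way one-shot setup (names and sizes are illustrative):

```python
import torch

def sample_task(n_support=5, n_train=10, dim=64):
    """One task: a support set of size N and a training set of size L."""
    support_x = torch.randn(n_support, dim)
    support_y = torch.arange(n_support)               # one example per class (one-shot)
    train_x = torch.randn(n_train, dim)
    train_y = torch.randint(0, n_support, (n_train,))
    return (support_x, support_y), (train_x, train_y)

# a sequence of tasks; for each task the training loop would
# 1) acquire meta information on the support set,
# 2) generate fast weights from it, and
# 3) optimize the slow weights on the training set
tasks = [sample_task() for _ in range(8)]
```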

At test time, another sequence of tasks with unseen classes is sampled, and the model classifies the test examples based on their support set. The same form of support set is available at both training and test time.

3.1 Meta Learner

  • dynamic representation learning function $$u$$: builds input embeddings using task-level fast weights
  • fast weight generation functions $$m$$ and $$d$$: process meta information and produce example-level and task-level fast weights

Function $$m$$ learns the mapping from the loss gradients $$\{\nabla_i\}_{i=1}^N$$, derived from the base learner $$b$$, to the fast weights $$\{W_i^*\}_{i=1}^N$$:

$$
W_i^* = m(Z, \nabla_i)
$$

where $$m$$ is a neural network with parameters $$Z$$. The generated fast weights are then stored in a memory $$M = \{W_i^*\}_{i=1}^N$$.
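
A minimal sketch of $$m$$, assuming it is a small MLP applied to flattened per-example gradients (the paper's actual parameterization may differ):

```python
import torch
import torch.nn as nn

grad_dim, n_support = 32, 5                 # illustrative sizes
m = nn.Sequential(nn.Linear(grad_dim, 64), nn.ReLU(), nn.Linear(64, grad_dim))

grads = torch.randn(n_support, grad_dim)    # {nabla_i}_{i=1}^N, flattened
M = m(grads)                                # memory of example-level fast weights {W_i^*}
```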

The representation learning function $$u$$ is a neural net parameterized by slow weights $$Q$$ and task-level fast weights $$Q^*$$. It uses a representation loss as its representation learning objective and takes the resulting gradients as meta information. The fast weights $$Q^*$$ are generated as follows:

$$
\ell_i = loss_{embed}(u(Q, x_i'), y_i') \\
\nabla_i = \nabla_Q \ell_i \\
Q^* = d(G, \{\nabla_i\}_{i=1}^N)
$$

$$d$$ is a neural net (an LSTM) parameterized by $$G$$. (An MLP was also considered as an alternative, but experiments showed it converges poorly.)
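
A sketch of $$d$$ as an LSTM that consumes the N gradients as a sequence and uses its final output as the task-level fast weights $$Q^*$$ (dimensions are illustrative):

```python
import torch
import torch.nn as nn

grad_dim, n_support = 32, 5
d = nn.LSTM(input_size=grad_dim, hidden_size=grad_dim, batch_first=True)

grads = torch.randn(1, n_support, grad_dim)  # {nabla_i}_{i=1}^N as one sequence
out, _ = d(grads)
Q_star = out[0, -1]                          # final step output used as Q*
```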

Once the fast weights are available, the task-dependent input representations $$\{r_i'\}_{i=1}^N$$ are computed as:

$$
r_i' = u(Q, Q^*, x_i')
$$

Here $$Q$$ and $$Q^*$$ are integrated via the layer augmentation described in Section 3.3.

$$loss_{embed}$$ does not have to be the same as the main task loss $$loss_{task}$$, but it should at least capture a representation learning objective. Cross-entropy loss is used for the one-shot learning case. When more than one loss term is available, a contrastive loss can be used instead; for the T gradients, the loss is then:

$$
\ell_i = loss_{embed}(u(Q, x_{1,i}'), u(Q, x_{2,i}'), l_i) \\
l_i = \begin{cases} 1, & y_{1,i}' = y_{2,i}' \\ 0, & \text{otherwise} \end{cases}
$$
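
A minimal margin-based contrastive loss over embedding pairs with the labels $$l_i$$ above; the margin form is an assumption, and the paper's exact formulation may differ:

```python
import torch
import torch.nn.functional as F

def contrastive_loss(z1, z2, l, margin=1.0):
    """l = 1 for same-class pairs, 0 otherwise."""
    dist = F.pairwise_distance(z1, z2)
    return (l * dist.pow(2) + (1 - l) * F.relu(margin - dist).pow(2)).mean()

z1, z2 = torch.randn(4, 16), torch.randn(4, 16)  # u(Q, x_1'), u(Q, x_2') for T = 4 pairs
l = torch.tensor([1.0, 0.0, 1.0, 0.0])
loss = contrastive_loss(z1, z2, l)
```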

Once the parameters are stored in the memory matrix $$M$$ and the memory index $$R$$ (the task-dependent support representations $$r_i'$$) is built, the meta learner parameterizes the base learner with the fast weights $$W_i^*$$:

$$
a_i = attention(R, r_i) \\
W_i^* = norm(a_i)^T M
$$
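
A sketch of the memory read, assuming cosine similarity for $$attention$$ and softmax for $$norm$$ (plausible choices; check the paper for the exact operators):

```python
import torch
import torch.nn.functional as F

N, rep_dim, w_dim = 5, 16, 32
R = torch.randn(N, rep_dim)   # memory index: support representations {r_i'}
M = torch.randn(N, w_dim)     # memory: example-level fast weights {W_i^*}
r = torch.randn(rep_dim)      # representation of a new input x_i

a = F.cosine_similarity(R, r.unsqueeze(0), dim=1)  # a_i = attention(R, r_i)
W_star = F.softmax(a, dim=0) @ M                   # W_i^* = norm(a_i)^T M
```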

3.2 Base Learner

The base learner $$b$$ learns the main task via the task loss $$loss_{task}$$. Here, however, $$b$$ is parameterized by slow weights $$W$$ and example-level fast weights $$W^*$$.

The base learner provides feedback to the meta learner in the form of meta information computed on the support set. This meta information is derived from the base learner as:

$$
\ell_i = loss_{task}(b(W, x_i'), y_i') \\
\nabla_i = \nabla_W \ell_i
$$

Here $$\ell_i$$ is computed on the support examples.
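
A sketch of how this meta information is collected on the support set, with a linear classifier standing in for $$b$$ (names are illustrative):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

b = nn.Linear(16, 5)                                   # base learner with slow weights W
support_x, support_y = torch.randn(5, 16), torch.arange(5)

meta_grads = []
for x_i, y_i in zip(support_x, support_y):
    loss_i = F.cross_entropy(b(x_i).unsqueeze(0), y_i.unsqueeze(0))  # loss_task on one support example
    grad_i = torch.autograd.grad(loss_i, b.weight)[0]                # nabla_i = grad of loss_i w.r.t. W
    meta_grads.append(grad_i)
```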

Given the fast weights $$W_i^*$$, the prediction is:

$$
P(\hat{y} \mid x_i, W, W_i^*) = b(W, W_i^*, x_i)
$$

3.3 Layer Augmentation

As shown in Figure 2, each augmented layer applies both the slow weights and the fast weights to its input, and the two resulting activations are aggregated by element-wise addition.
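
A sketch of one augmented layer, assuming ReLU nonlinearities on both paths: the input is transformed by the slow weights and the fast weights separately, and the two activations are summed element-wise.

```python
import torch
import torch.nn.functional as F

in_dim, out_dim = 32, 64
W_slow = torch.randn(out_dim, in_dim)  # slow weights (trained with SGD)
W_fast = torch.randn(out_dim, in_dim)  # fast weights (generated by the meta learner)
x = torch.randn(in_dim)

h = F.relu(W_slow @ x) + F.relu(W_fast @ x)  # element-wise addition of the two paths
```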

Appendix

contrastive loss
