Meta Networks
Tsendsuren Munkhdalai, Hong Yu
- slow weights: updated on the model by conventional SGD
- fast weights: predicted for the current model by another model (the meta learner)
- meta information: the loss and its corresponding gradients
3 Meta Networks
There are two main learning modules: the base learner and the meta learner.
For a sequence of tasks, each task comes with a support set (of size N) and a training set (of size L). Training consists of three procedures: acquisition of meta information, generation of fast weights, and optimization of slow weights.
At test time, another sequence of tasks with unseen classes is sampled, and the model classifies test examples based on the support set. The same form of support set is available during both training and testing.
3.1 Meta Learner
- dynamic representation learning function $$u$$: constructs input embeddings using task-level fast weights
- fast weight generation functions $$m$$ and $$d$$: process meta information and produce example-level and task-level fast weights
The function $$m$$ learns the mapping from the loss gradients $$\{\nabla_i\}_{i=1}^N$$, derived from the base learner $$b$$, to the fast weights $$\{W_i^*\}_{i=1}^N$$:
$$ W_i^* = m(Z, \nabla_i)
$$
where $$m$$ is a neural network with parameters $$Z$$. The generated fast weights are then stored in a memory $$M = \{W_i^*\}_{i=1}^N$$.
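The mapping above can be sketched in numpy. This is a minimal illustration, not the authors' architecture: the layer sizes, the two-layer tanh MLP standing in for $$m$$, and the flattened gradient vectors are all assumptions for the sake of the example.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical sizes: D-dimensional gradients mapped to D-dimensional fast weights.
D, H, N = 8, 16, 5

# Z: parameters of the meta network m (a tiny two-layer MLP here).
Z = {
    "W1": rng.normal(scale=0.1, size=(H, D)),
    "W2": rng.normal(scale=0.1, size=(D, H)),
}

def m(Z, grad):
    """Map one loss gradient from the base learner to example-level fast weights."""
    h = np.tanh(Z["W1"] @ grad)
    return Z["W2"] @ h

# One gradient per support example -> one fast-weight entry each,
# stored row-wise in the memory M = {W_i^*}.
grads = [rng.normal(size=D) for _ in range(N)]
M = np.stack([m(Z, g) for g in grads])
print(M.shape)  # (5, 8): N fast-weight entries in memory
```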
The representation learning function $$u$$ is a neural net parameterized by slow weights $$Q$$ and task-level fast weights $$Q^*$$. It uses a representation loss as its learning objective and takes the resulting gradients as meta information. The fast weights $$Q^*$$ are generated as:
$$ \ell_i = loss_{embed}(u(Q, x_i'), y_i') \\ \nabla_i = \nabla_Q \ell_i \\ Q^* = d(G, \{\nabla_i\}_{i=1}^N)
$$
where $$d$$ is a neural net (an LSTM) parameterized by $$G$$. (An MLP was also considered as a replacement, but experiments showed that it converges poorly.)
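A sketch of $$d$$, assuming a bare-bones recurrent cell in place of the paper's LSTM (a deliberate simplification), processing the N gradients as a sequence and emitting a single task-level fast-weight vector:

```python
import numpy as np

rng = np.random.default_rng(1)
D, H, N = 8, 16, 5   # hypothetical sizes, for illustration only

# G: parameters of d.  The paper uses an LSTM; a plain tanh recurrence
# stands in for it here.
G = {
    "Wx": rng.normal(scale=0.1, size=(H, D)),
    "Wh": rng.normal(scale=0.1, size=(H, H)),
    "Wo": rng.normal(scale=0.1, size=(D, H)),
}

def d(G, grads):
    """Summarize N representation-loss gradients into task-level fast weights Q*."""
    h = np.zeros(H)
    for g in grads:                      # process the gradient sequence recurrently
        h = np.tanh(G["Wx"] @ g + G["Wh"] @ h)
    return G["Wo"] @ h                   # one task-level fast-weight vector

grads = [rng.normal(size=D) for _ in range(N)]
Q_star = d(G, grads)
print(Q_star.shape)  # (8,)
```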
With the fast weights in hand, the task-dependent input representations $$\{r_i'\}_{i=1}^N$$ are computed as:
$$ r_i' = u(Q, Q^*, x_i')
$$
where $$Q$$ and $$Q^*$$ are integrated through the layer augmentation described in Section 3.3.
$$loss_{embed}$$ need not be identical to the main task loss $$loss_{task}$$, but it must at least be able to capture a representation learning objective. We use the cross-entropy loss in the one-shot learning case. When more than one example per class is available, a contrastive loss can be used instead; sampling pairs of examples then yields T gradients, with the loss
$$ \ell_i = loss_{embed}(u(Q, x_{1,i}'), u(Q, x_{2,i}'), l_i) \\ l_i = \begin{cases} 1, & y_{1,i}' = y_{2,i}' \\ 0, & \text{otherwise} \end{cases}
$$
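A toy version of this pairwise objective, assuming a margin-based contrastive loss (the margin value and the linear embedding standing in for $$u$$ are illustrative choices, not from the paper):

```python
import numpy as np

def embed(Q, x):
    """Toy embedding u(Q, x): a linear map standing in for the embedding network."""
    return Q @ x

def loss_embed(r1, r2, label, margin=1.0):
    """Contrastive loss on a pair of embeddings: pull the pair together when
    label == 1 (same class), push it apart up to a margin when label == 0."""
    dist = np.linalg.norm(r1 - r2)
    if label == 1:
        return 0.5 * dist ** 2
    return 0.5 * max(0.0, margin - dist) ** 2

rng = np.random.default_rng(2)
Q = rng.normal(scale=0.1, size=(4, 8))
x1, x2 = rng.normal(size=8), rng.normal(size=8)
y1, y2 = 0, 0                       # same class -> label l_i = 1
label = 1 if y1 == y2 else 0
ell = loss_embed(embed(Q, x1), embed(Q, x2), label)
print(ell >= 0.0)  # True: the pair loss is non-negative
```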
Once the parameters are stored in the memory matrix $$M$$ and the memory index $$R$$ is built, the meta learner parameterizes the base learner with the fast weights $$W_i^*$$:
$$ a_i = attention(R, r_i) \\ W_i^* = norm(a_i)^T M
$$
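The memory read above can be sketched as follows. Here $$attention$$ is taken to be cosine similarity and $$norm$$ a softmax (a common choice for this kind of soft read); all sizes are hypothetical.

```python
import numpy as np

def attention(R, r):
    """Cosine similarity between a query embedding r and each memory index row."""
    R_unit = R / np.linalg.norm(R, axis=1, keepdims=True)
    return R_unit @ (r / np.linalg.norm(r))

def norm(a):
    """Softmax normalization of the attention scores."""
    e = np.exp(a - a.max())
    return e / e.sum()

rng = np.random.default_rng(3)
N, D_r, D_w = 5, 4, 8
R = rng.normal(size=(N, D_r))   # memory index: support-set representations
M = rng.normal(size=(N, D_w))   # memory: stored fast weights {W_i^*}
r_i = rng.normal(size=D_r)      # representation of the current input

a_i = attention(R, r_i)
W_star = norm(a_i) @ M          # soft read: convex combination of stored fast weights
print(W_star.shape)  # (8,)
```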
3.2 Base Learner
The base learner $$b$$ learns the main task via the loss $$loss_{task}$$. Here, however, $$b$$ is parameterized by slow weights $$W$$ and example-level fast weights $$W^*$$.
The base learner gives the meta learner feedback in the form of meta information, computed on the support set. This meta information is derived from the base learner as:
$$ \ell_i = loss_{task}(b(W, x_i'), y_i') \\ \nabla_i = \nabla_W \ell_i
$$
where $$\ell_i$$ is computed on the support examples.
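Concretely, with a linear softmax classifier standing in for $$b$$ (an illustrative assumption), the meta information $$\nabla_W \ell_i$$ on one support example can be computed in closed form:

```python
import numpy as np

rng = np.random.default_rng(5)
D, C = 8, 3                              # hypothetical input dim / class count

W = rng.normal(scale=0.1, size=(C, D))   # slow weights of the base learner b
x, y = rng.normal(size=D), 1             # one support example (x_i', y_i')

def softmax(z):
    e = np.exp(z - z.max())
    return e / e.sum()

# loss_task: cross-entropy of the base learner's prediction on the support example
p = softmax(W @ x)
ell = -np.log(p[y])

# Gradient of the cross-entropy w.r.t. W (closed form for softmax + CE):
# this grad is the meta information handed to the meta learner.
onehot = np.eye(C)[y]
grad = np.outer(p - onehot, x)           # shape (C, D), same as W
print(grad.shape)  # (3, 8)
```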
Once the fast weights $$W_i^*$$ are available, the prediction is
$$ P(\hat y \mid x_i, W, W_i^*) = b(W, W_i^*, x_i)
$$
3.3 Layer Augmentation
As shown in Figure 2, the fast weights and slow weights are combined by element-wise addition.
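A single layer-augmented layer can be sketched like this: the slow and fast weights each transform the input, and the two activations are aggregated element-wise. The ReLU nonlinearity and the layer sizes are assumptions for the example, not details from the notes.

```python
import numpy as np

def augmented_layer(W, W_star, x):
    """One layer-augmented layer: slow-weight and fast-weight paths each
    transform the input, and the results are summed element-wise."""
    slow = np.maximum(0.0, W @ x)        # slow-weight path (ReLU assumed here)
    fast = np.maximum(0.0, W_star @ x)   # fast-weight path
    return slow + fast                   # element-wise aggregation

rng = np.random.default_rng(4)
W = rng.normal(scale=0.1, size=(6, 8))       # slow weights
W_star = rng.normal(scale=0.1, size=(6, 8))  # fast weights (same shape)
x = rng.normal(size=8)

out = augmented_layer(W, W_star, x)
print(out.shape)  # (6,)
```

With `W_star` set to zeros the layer degrades to an ordinary slow-weight layer, which is why the augmentation can be dropped in at any depth of the network.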