Matching Networks for One Shot Learning
Matching Networks for One Shot Learning
2 Model
Given a small support set $$S$$, we want to find a model $$C_S$$ (classifier) s.t. $$S \to C_S(\cdot)$$.
2.1 Model Architecture
Borrow ideas from NN augmented by 'memory': seq2seq with attention, memory networks, pointer networks. 在这些模型模型中,a neural attention mechanism, often differentiable, is able to access a memory matrix to solve the tasks at hand. 他们主要是在预测$$P(B|A)$$,其中A和/或B是sequence,或者在我们这里是set。
我们这里将 one-shot learning 问题转化为set-to-set框架。关键的一点事Matching Networks在训练的时候,能够产生sensible test labels for unobserved classes without any changes to the network。更具体地说,我们希望能够从一个小的support set (k examples,input-label pairs $$S={ (xi, y_i)}{i=1}^k$$ )映射到一个classifier $$C_S(\hat x)$$。给定test sample $$\hat x$$,来预测test label $$\hat y$$。
$$ P(\hat y|\hat x, S) = \sum_{i=1}^k a(\hat x, x_i) y_i
$$
a是attention mechanism,下面会提到。
- 当a is a kernel on $$X \times X$$,然后上述的公式就像是ernel density estimator。
- 当attention mechanism中,如果距离$$x$$最远的b个$$x_i$$的metric为0,而其余的metric都是constant,那么这就等价于k-b nearest neighbor。
- 另外一个角度就是当a像一个attention mechanism,而$$x_i, y_i$$像是键值对,那么这就像一个hashtable。
2.1.1 The Attention Kernel
最简单的a是在cosine distance上使用softmax
$$ a(\hat x, xi) = \frac{\exp{[c(f(\hat x), g(x_i))]}}{\sum{j=1}^k \exp{[c(f(\hat x), g(x_j))]}}
$$
而其中的f和g都是neural networks(大部分情况下二者相同)。
这个问题和metric learning相关。这一类loss可以用Neighbourhood Component Analysis, triplet loss,或者large margin nearest neighbor来解决。
但是这里研究的问题还是multi-way one-shot classification。而且loss会比较简单,能够以一种end-to-end方式解决。
2.1.2 Full Context Embedding
这篇paper的最大贡献就是提出了一个新的framework来做one-shot learning。与metric learning类似,embedding function f and g act as a lift to feature space X to achieve maximum accuracy。
我们提出了将set中的element embed时,考虑到full set S。也就是$$g(x_i, S)$$。因此,作为整个support set S上的函数,g能够帮助embed $$x_i$$。我们用bi-directional LSTM 来encode $$x_i$$ in the context of $$S$$。
第二个问题是将f依赖于$$\hat x$$,而S可以通过LSTM fixed,而这个LSTM有一个read-attention over the whole set S, whose inputs are equal to $$f'(\hat x)$$。
$$ \hat hk, c_k = \text{LSTM}(f'(\hat x), [h{k-1}, r{k-1}], c{k-1})\ hk = \hat h_k + f'(\hat x)\ r{k} = \sum{i=1}^{|S|} a(h{k-1}, g(xi)) g(x_i)\ a(h{k-1}, g(xi)) = \exp { h{k-1}^T g(xi) } / \sum{j=1}^{|S|} \exp { h_{k-1}^T g(x_j) }
$$
Appendix
attention-LSTM最后倒数第二个公式可能有点typo