Matching Networks for One Shot Learning

Matching Networks for One Shot Learning

2 Model

Given a small support set $$S$$, we want to find a model $$C_S$$ (classifier) s.t. $$S \to C_S(\cdot)$$.

2.1 Model Architecture

Borrow ideas from NN augmented by 'memory': seq2seq with attention, memory networks, pointer networks. 在这些模型模型中,a neural attention mechanism, often differentiable, is able to access a memory matrix to solve the tasks at hand. 他们主要是在预测$$P(B|A)$$,其中A和/或B是sequence,或者在我们这里是set。

我们这里将 one-shot learning 问题转化为set-to-set框架。关键的一点事Matching Networks在训练的时候,能够产生sensible test labels for unobserved classes without any changes to the network。更具体地说,我们希望能够从一个小的support set (k examples,input-label pairs $$S={ (xi, y_i)}{i=1}^k$$ )映射到一个classifier $$C_S(\hat x)$$。给定test sample $$\hat x$$,来预测test label $$\hat y$$。

$$ P(\hat y|\hat x, S) = \sum_{i=1}^k a(\hat x, x_i) y_i

$$

a是attention mechanism,下面会提到。

  • 当a is a kernel on $$X \times X$$,然后上述的公式就像是ernel density estimator。
  • 当attention mechanism中,如果距离$$x$$最远的b个$$x_i$$的metric为0,而其余的metric都是constant,那么这就等价于k-b nearest neighbor。
  • 另外一个角度就是当a像一个attention mechanism,而$$x_i, y_i$$像是键值对,那么这就像一个hashtable。

2.1.1 The Attention Kernel

最简单的a是在cosine distance上使用softmax

$$ a(\hat x, xi) = \frac{\exp{[c(f(\hat x), g(x_i))]}}{\sum{j=1}^k \exp{[c(f(\hat x), g(x_j))]}}

$$

而其中的f和g都是neural networks(大部分情况下二者相同)。

这个问题和metric learning相关。这一类loss可以用Neighbourhood Component Analysis, triplet loss,或者large margin nearest neighbor来解决。

但是这里研究的问题还是multi-way one-shot classification。而且loss会比较简单,能够以一种end-to-end方式解决。

2.1.2 Full Context Embedding

这篇paper的最大贡献就是提出了一个新的framework来做one-shot learning。与metric learning类似,embedding function f and g act as a lift to feature space X to achieve maximum accuracy。

我们提出了将set中的element embed时,考虑到full set S。也就是$$g(x_i, S)$$。因此,作为整个support set S上的函数,g能够帮助embed $$x_i$$。我们用bi-directional LSTM 来encode $$x_i$$ in the context of $$S$$。

第二个问题是将f依赖于$$\hat x$$,而S可以通过LSTM fixed,而这个LSTM有一个read-attention over the whole set S, whose inputs are equal to $$f'(\hat x)$$。

$$ \hat hk, c_k = \text{LSTM}(f'(\hat x), [h{k-1}, r{k-1}], c{k-1})\ hk = \hat h_k + f'(\hat x)\ r{k} = \sum{i=1}^{|S|} a(h{k-1}, g(xi)) g(x_i)\ a(h{k-1}, g(xi)) = \exp { h{k-1}^T g(xi) } / \sum{j=1}^{|S|} \exp { h_{k-1}^T g(x_j) }

$$

Appendix

attention-LSTM最后倒数第二个公式可能有点typo

results matching ""

    No results matching ""