Prototypical Networks for Few-shot Learning

Prototypical Networks for Few-shot Learning

2 Prototypical Networks

2.1 Notation

support set $$S = {(xi, y_i)}{i=1}^N$$

$$x_i \in \mathbb{R}^{D}$$

2.2 Model

Prototypical Networks 对每一个class 计算M-dimensional representation $$ck \in \mathbb{R}^{M}$$,也叫做prototypical,通过一个embedding function $$f\phi: \mathbb{R}^{D} \to \mathbb{R}^{M}$$。每一个prototypical is the mean vector of the embedded support points belongs to its class:

$$ ck = \frac{1}{|S_k|} \sum{(xi,y_i) \in S_k} f\phi(x_i)

$$

通过一个distance function $$d$$,Prototypical Networks能够产生distribution over classes for a query point $$x$$:

$$ p(y=k|x) = \frac{\exp(-d(f\phi(x), c_k))} {\sum{k'}\exp(-d(f\phi(x), c{k'}))}

$$

learning process就是最小化log likelihood。training episodes包括从training set中选择subset of classes,然后每一个class选择subset of examples作为support set以及另外的subset of examples作为query set。

算法流程见Algorithm 1

2.3 Prototypical Networks as Mixture Density Estimation

对于一类特定的distance function,叫做 regular Bregman divergences, the Prototypical Networks algorithm is equivalent to performing mixture density estimation on the support set with an exponential family density.

$$ d\phi (z, z') = \phi (z) - \phi(z') - (z-z')^T \nabla{\phi} (z')

$$

Prototype computation can be viewed in terms of hard clustering on the support set, with one cluster per class and each support point assigned to its corresponding class cluster.

In this case, the Bregman divergences can achieve minimal distance to the assigned cluster mean.

2.4 Reinterpretation as a Linear Model

When using the Euclidean distance, then it's equal to the linear model with particular parameterization.

2.5 Comparison to Matching Networks

one-shot learning是一样的

few-shot learning会不一样

Matching Networks produce a weighted nearest neighbor classifier given the support set, while Prototypical Networks produce a linear classifier when squared Euclidean distance is used.

Matching Network 用了很多很fancy的技术,但是我们下面展示了哪怕只用很简单的模型,也能得到不错的结果。

2.6 Design Choices

Distance metric L2 distance被发现是非常有效的。然后L2 distance比cosine distance的优异之处在于,cosine distance并不属于Bregman divergence。

Episode composition 从Matching Network开始提出来的support set。然后有了一些小的发现,比如最好保持相同的shot数目。

2.7 Zero-Shot Learning

zero-shot learning 就是对于每一个class,不用support set来训练center points,而是给定class meta-data vector $$v_k$$。可以是提前给定,也可以是通过raw text学习到。这个时候class embedding和query embedding来源不同,发现可以将embedding g强制为unit length,而query embedding没有限制。

Appendix

results matching ""

    No results matching ""