Few-Shot Learning with Graph Neural Networks

Victor Garcia, Joan Bruna

3 Problem Setup

We consider input-output pairs $$(T_i, Y_i)_i$$ drawn i.i.d. from a distribution $$P$$ of partially-labeled image collections:

$$ T = \big\{\ \{ (x_1, l_1), \dots, (x_s, l_s) \},\ \{ \tilde x_1, \dots, \tilde x_r \},\ \{ \bar x_1, \dots, \bar x_t \} \ \big\},\quad l_i \in \{1, \dots, K\},\quad x_i, \tilde x_j, \bar x_j \sim P_l $$

$$ Y = (y_1, \dots, y_t) \in \{1, \dots, K\}^t $$

for arbitrary values of $$s,r,t,K$$.

  • $$s$$ is the number of labeled samples
  • $$r$$ is the number of unlabeled samples
  • $$t$$ is the number of samples to classify
  • $$K$$ is the number of classes

We will focus on $$t=1$$, i.e., we will classify one sample per task.

Few-Shot Learning $$r=0,\ t=1,\ s=qK$$ (q-shot, K-way)

Semi-Supervised Learning $$r>0, t=1$$

Active Learning the learner has the ability to request labels from the sub-collection $$\tilde x_1, ..., \tilde x_r$$.
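
As a concrete illustration, here is a minimal Python sketch of how one few-shot task $$T$$ could be assembled in the $$r=0,\ t=1,\ s=qK$$ regime. The function names (`sample_task`, `sample_images`) and the dictionary layout are hypothetical, not taken from the paper.

```python
# Minimal sketch of one few-shot task T (r = 0, t = 1, s = q*K), assuming a
# hypothetical sample_images(cls, n) that returns n images of class cls.
import random

def sample_task(classes, q, K, sample_images):
    """Build one q-shot, K-way task with a single query sample (t = 1)."""
    chosen = random.sample(classes, K)              # the K classes for this task
    support = []                                    # the s = q*K labeled pairs (x_i, l_i)
    for label, cls in enumerate(chosen):
        for x in sample_images(cls, q):
            support.append((x, label))
    query_cls = random.randrange(K)                 # class of the sample to classify
    x_bar = sample_images(chosen[query_cls], 1)[0]  # the single query sample \bar{x}_1
    return {"support": support, "unlabeled": [], "query": [x_bar]}, query_cls
```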

4 Model

4.1 Set and Graph Input Representations

The training data contains both labeled and unlabeled images. This can be formalized as posterior inference over a graphical model determined by the input and the output.

Following several recent works, the posterior inference is carried out with the message passing of a GNN. In this setting, the similarity can be learned in a discriminative fashion with a parametric similarity model such as a Siamese network.

4.2 Graph Neural Networks

$$ x_l^{(k+1)} = G(x^{(k)}) = \rho\Big( \sum_{B \in \mathcal{A}} B\, x^{(k)}\, \theta_{B,l}^{(k)} \Big), \quad l = d_1, \dots, d_{k+1} $$
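
A minimal PyTorch sketch of this node-update rule follows, assuming the operator family $$\mathcal{A}$$ contains just the learned adjacency and the identity; the class name, shapes, and the choice of leaky ReLU for $$\rho$$ are assumptions for illustration, not the authors' code.

```python
# Sketch of one graph-convolution layer: for each operator B in A, apply a
# learned linear map theta_B to B x^{(k)}, sum, then apply a non-linearity rho.
import torch
import torch.nn as nn

class GconvLayer(nn.Module):
    def __init__(self, d_in, d_out, n_ops=2):
        super().__init__()
        # one linear map theta_B per operator B in the family A
        self.thetas = nn.ModuleList([nn.Linear(d_in, d_out, bias=False) for _ in range(n_ops)])
        self.rho = nn.LeakyReLU()

    def forward(self, x, adj):
        # x: (batch, n_nodes, d_in); adj: (batch, n_nodes, n_nodes) learned adjacency
        ops = [adj, torch.eye(x.size(1), device=x.device).expand_as(adj)]  # A = {A^{(k)}, I}
        out = sum(theta(torch.bmm(B, x)) for B, theta in zip(ops, self.thetas))
        return self.rho(out)
```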

The edge features (the adjacency) are also learned from the current node representations:

$$ A_{i,j}^{(k)} = \psi_{\tilde\theta}\big(x_i^{(k)}, x_j^{(k)}\big) = \mathrm{MLP}_{\tilde\theta}\big(\mathrm{abs}(x_i^{(k)} - x_j^{(k)})\big) $$
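
The sketch below shows one way this edge-feature MLP could be implemented in PyTorch; the hidden size and the row-wise softmax normalization of the resulting adjacency are assumptions for illustration, not details stated in this note.

```python
# Sketch of the learned adjacency: an MLP on the absolute difference of node
# features, followed by a row-wise softmax so each row sums to one.
import torch
import torch.nn as nn

class EdgeMLP(nn.Module):
    def __init__(self, d, hidden=64):
        super().__init__()
        self.mlp = nn.Sequential(nn.Linear(d, hidden), nn.ReLU(), nn.Linear(hidden, 1))

    def forward(self, x):
        # x: (batch, n_nodes, d) -> pairwise |x_i - x_j| of shape (batch, n, n, d)
        diff = (x.unsqueeze(2) - x.unsqueeze(1)).abs()
        scores = self.mlp(diff).squeeze(-1)          # psi(x_i, x_j) for every pair
        return torch.softmax(scores, dim=-1)         # A^{(k)}_{i,j}, normalized per row
```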

Construction of Initial Node Features The initial node feature of each sample is the image representation learned by a CNN, concatenated with the one-hot encoding of its label.
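
A minimal sketch of this construction is shown below, assuming some feature extractor `cnn` and a uniform label vector for the unlabeled query sample (an assumption, since the note only specifies the labeled case).

```python
# Initial node features: CNN embedding of each image concatenated with the
# one-hot label; the query gets a uniform label vector since its label is unknown.
import torch

def initial_node_features(cnn, support, query, K):
    nodes = []
    for x, label in support:                       # labeled samples: embedding + one-hot
        one_hot = torch.zeros(K)
        one_hot[label] = 1.0
        nodes.append(torch.cat([cnn(x), one_hot]))
    for x in query:                                # query: embedding + uniform label prior
        nodes.append(torch.cat([cnn(x), torch.full((K,), 1.0 / K)]))
    return torch.stack(nodes)                      # (n_nodes, d + K)
```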

4.3 Relationship with Existing Models

Siamese Network can be understood as a single message-passing iteration of our model.

Prototypical Network considers the center point of the cluster corresponding to each class. This can also be realized in our model through a specific construction of the adjacency matrix.

Matching Network There are two main differences: 1. The attention mechanism of the Matching Network plays a role similar to the adjacency-matrix learning here, but the attention mechanism operates on the same, fixed node embedding throughout, whereas here the learned adjacencies are stacked across layers. In other words, the input to the attention-LSTM is only the node embedding and does not include the label information (the one-hot encoding). 2. The Matching Network processes label and image information separately and only combines them with a kernel at the very end; the drawback is that the model cannot exploit interactions between label and image information at intermediate stages.

Appendix
