Few-Shot Learning with Graph Neural Networks
Victor Garcia, Joan Bruna
3 Problem Setup
input-output pairs $$(T_i, Y_i)_i$$ drawn i.i.d. from a distribution $$P$$ over partially-labeled image collections:
$$ T = \big\{\ \{(x_1, l_1), \dots, (x_s, l_s)\},\ \{\tilde x_1, \dots, \tilde x_r\},\ \{\bar x_1, \dots, \bar x_t\}\ \big\},\quad l_i \in \{1, \dots, K\},\quad x_i, \tilde x_j, \bar x_j \sim P_l $$
$$ Y = (y_1, \dots, y_t) \in \{1, \dots, K\}^t $$
for arbitrary values of $$s,r,t,K$$.
- $$s$$ is the number of labeled samples
- $$r$$ is the number of unlabeled samples
- $$t$$ is the number of samples to classify (the targets whose labels $$y_j$$ we predict)
- $$K$$ is the number of classes
We will focus on $$t=1$$, i.e., we will classify one sample per task.
Few-Shot Learning: $$r=0, t=1, s=qK$$, i.e., $$q$$ labeled samples per class ($$q$$-shot, $$K$$-way classification).
Semi-Supervised Learning: $$r>0, t=1$$; the learner can use the unlabeled samples $$\tilde x_1, ..., \tilde x_r$$ to improve the prediction.
Active Learning: $$r>0, t=1$$; the learner additionally has the ability to request labels from the sub-collection $$\tilde x_1, ..., \tilde x_r$$.
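To make the setup concrete, here is a minimal sketch of how one such task $$T$$ could be sampled; the `dataset` format, the function name, and the default values are illustrative assumptions, not part of the paper:

```python
import random

def sample_task(dataset, K=5, q=1, r=0, t=1):
    """Sample one episode T: s = q*K labeled pairs, r unlabeled images,
    and t target images whose labels Y must be predicted.
    `dataset` is assumed to map class id -> list of images (hypothetical format)."""
    classes = random.sample(sorted(dataset.keys()), K)
    labeled, pool = [], []
    for label, c in enumerate(classes):
        imgs = random.sample(dataset[c], q + r + t)     # enough images per episode class
        labeled += [(x, label) for x in imgs[:q]]       # (x_i, l_i): the s labeled pairs
        pool += [(x, label) for x in imgs[q:]]          # candidates for unlabeled/target sets
    random.shuffle(pool)
    unlabeled = [x for x, _ in pool[:r]]                # \tilde{x}_1 .. \tilde{x}_r (labels withheld)
    targets, Y = zip(*pool[r:r + t])                    # \bar{x}_1 .. \bar{x}_t and their labels Y
    return labeled, unlabeled, list(targets), list(Y)
```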
4 Model
4.1 Set and Graph Input Representations
The training data contains both labeled and unlabeled images. This can be formalized as posterior inference over a graphical model determined by the input and output.
Following several recent works, the posterior inference is constrained to be expressed as message passing on a GNN. In this setting, similarity can be learned in a discriminative fashion with a parametric similarity model, such as a Siamese network.
4.2 Graph Neural Networks
$$ x_l^{(k+1)} = G(x^{(k)}) = \rho\Big( \sum_{B \in \mathcal{A}} B\, x^{(k)} \theta_{B,l}^{(k)} \Big), \quad l = 1, \dots, d_{k+1}, $$
where $$\mathcal{A}$$ is the family of graph operators (e.g., the learned adjacency and the identity) and $$d_{k+1}$$ is the number of output feature dimensions.
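A minimal PyTorch sketch of this layer, assuming the operator family $$\mathcal{A}$$ consists of the learned adjacency and the identity; the leaky-ReLU nonlinearity and the module layout are illustrative choices:

```python
import torch
import torch.nn as nn

class GraphConvLayer(nn.Module):
    """x^{(k+1)} = rho( sum_{B in A} B x^{(k)} theta_B ), with A = {adjacency, identity}."""
    def __init__(self, in_dim, out_dim):
        super().__init__()
        # one linear map theta_B per operator B in the family A
        self.theta_adj = nn.Linear(in_dim, out_dim, bias=False)
        self.theta_id = nn.Linear(in_dim, out_dim, bias=False)
        self.rho = nn.LeakyReLU()

    def forward(self, x, adj):
        # x: (batch, n_nodes, in_dim); adj: (batch, n_nodes, n_nodes) learned adjacency
        out = self.theta_adj(torch.bmm(adj, x)) + self.theta_id(x)  # identity operator: B x = x
        return self.rho(out)
```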
The edge features (the adjacency matrix) are also learned from the current node representations:
$$ A_{i,j}^{(k)} = \psi_{\tilde\theta}\big(x_i^{(k)}, x_j^{(k)}\big) = \mathrm{MLP}_{\tilde\theta}\big(\mathrm{abs}(x_i^{(k)} - x_j^{(k)})\big) $$
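A sketch of this edge-feature module in PyTorch; the hidden width and the row-wise softmax normalization (so each node aggregates a weighted average of its neighbors) are assumptions made for illustration:

```python
import torch
import torch.nn as nn

class EdgeMLP(nn.Module):
    """A_{i,j} = MLP(abs(x_i - x_j)), followed by a row-wise softmax normalization."""
    def __init__(self, node_dim, hidden=64):
        super().__init__()
        self.mlp = nn.Sequential(nn.Linear(node_dim, hidden), nn.ReLU(), nn.Linear(hidden, 1))

    def forward(self, x):
        # x: (batch, n, d) node features -> pairwise absolute differences (batch, n, n, d)
        diff = torch.abs(x.unsqueeze(2) - x.unsqueeze(1))
        scores = self.mlp(diff).squeeze(-1)       # (batch, n, n) edge scores psi(x_i, x_j)
        return torch.softmax(scores, dim=-1)      # each row sums to 1, so messages are averaged
```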
Construction of Initial Node Features The initial node feature is the image representation learned by a CNN, concatenated with the one-hot encoding of its label (for unlabeled and target images, the label part is filled with a uniform distribution over the $$K$$ classes).
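A sketch of this construction, assuming `cnn` maps a stack of image tensors to a $$(n, d)$$ embedding matrix; the function name and data layout are assumptions:

```python
import torch

def initial_node_features(cnn, labeled, unlabeled, targets, K):
    """x_i^{(0)} = (cnn(x_i), label field): one-hot for labeled images,
    uniform 1/K for images whose label is unknown."""
    imgs = [x for x, _ in labeled] + list(unlabeled) + list(targets)
    emb = cnn(torch.stack(imgs))                                    # (n, d) image representations
    field = torch.full((len(imgs), K), 1.0 / K, device=emb.device)  # uniform for unknown labels
    for i, (_, l) in enumerate(labeled):
        field[i] = 0.0
        field[i, l] = 1.0                                           # one-hot encoding of l_i
    return torch.cat([emb, field], dim=-1)                          # (n, d + K) initial features
```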
4.3 Relationship with Existing Models
Siamese Network Can be interpreted as a single message-passing iteration of this model.
Prototypical Network Considers the cluster center point of each class. This can also be realized in this model by a suitable construction of the adjacency matrix.
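For instance, a fixed adjacency that averages the nodes of each labeled class reproduces class prototypes; a small sketch follows (the helper name and the node-index format are assumptions):

```python
import torch

def class_averaging_adjacency(labeled_indices, n_nodes):
    """Fixed adjacency whose rows average node features within each labeled class,
    so the aggregated features play the role of Prototypical Networks' class centers.
    `labeled_indices` is assumed to be a list of (node_index, class_id) pairs."""
    adj = torch.zeros(n_nodes, n_nodes)
    by_class = {}
    for idx, c in labeled_indices:
        by_class.setdefault(c, []).append(idx)
    for members in by_class.values():
        weight = 1.0 / len(members)
        for i in members:
            for j in members:
                adj[i, j] = weight   # node i receives the mean of its class members
    return adj
```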
Matching Network There are two main differences: 1. The attention mechanism of Matching Networks is analogous to the adjacency-matrix learning here, but the attention mechanism operates on a fixed node embedding throughout, whereas this model stacks adjacency learning and node updates layer by layer. In other words, the input to the attention-LSTM is the node embedding alone, without the label information (the one-hot encoding). 2. Matching Networks process label and image information separately and only combine them at the end through a kernel; the drawback is that the intermediate stages of the model cannot exploit interactions between label and image information.