Semi-supervised Learning with Deep Generative Models

Semi-supervised Learning with Deep Generative Models

Diederik P. Kingma, Danilo J. Rezende, Shakir Mohamed, Max Welling

Intro

semi-supervised learning的两个需要解决的问题 1. 如何利用数据的属性来增强decision boundary 2. 最终利用unlabeled数据能够比仅仅使用labelled数据的准确性更高

这篇paper,我们使用probabilistic model for inductive and transductive semi-supervised learning来回答这个问题。通过使用一个data density的explicit model,构造在一个deep generative model和variational inference上。

对于generative model来说,semi-supervised learning就是一个对于classification问题中特殊的missing data imputation task。

Deep Generative Models for Semi-supervised Learning

Latent-feature discriminative model (M1): 一个常见的做法是构造一个模型只提供了feature embedding,然后再给予embedding/latent representation训练一个单独的classification模型。

与以往的linear embedding或者auto-encoder不同,我们构造了一个deep generative model,能够提供更加robust的latent features。使用的deep generative model是

$$ p(z) = \mathcal{N}(z|0, I), p_\theta(x|z) = f(x;a,\theta)

$$

其中$$f(x;z,\theta)$$是一个恰当的likelihood,比如Gaussian或者Bernoulli distribution,其概率是由 non-linear transformation (参数为$$\theta$$) 和 latent variable ($$z$$)构成的。

然后在posterior distribution $$p(z|x)$$上approximate samples,并且将这些approximate samples用来训练分类器,来预测label $$y$$,比如(transductive) SVM或者multinomial regression。这个low dimensional embedding应该更容易进行分割,因为我们利用了另外一个独立的latent Gaussian posteriors,其参数是通过基于数据的non-linear transformation组成。

Generative semi-supervised model (M2): 我们提出一个probabilistic model,来描述数据如何从latent class variable $$y$$ 和 continuous latent variable $$z$$。数据是通过如下generative process描述

$$ p(y) = Cat(y|\pi), p(z) = \mathcal{N}(0, I), p_\theta(x|y,z) = f(x; y, z, \theta)

$$

所以其joint distribution是

$$ p(x,y,z) = p(y) \cdot p(z) \cdot p(x|y,z)

$$

其中$$Cat(y|\pi)$$是multinomial distribution,如果class label未知, 那么$$y$$被当作latent variables,$$z$$被当作是额外的latent variable。这些variables是marginally independent,并且允许我们(比如digit writing情况)将class specification和writing style区分开。我们使用deep neural networks来当作non-linear function $$f$$。因为大部分 label $$y$$ 并没有观测到,在inference过程中,我们选择Integrate over the class of any unlabelled data(也就是求期望),因此classification问题变成了inference问题。对于missing label的prediction问题可以通过inferred posterior distribution $$p_\theta(y|x)$$得到。

Stacked generative semi-supervised model (M1+M2): 我们可以将两个方法结合起来,首先是使用M1中的generative model来学习new latent representation $$z_1$$;然后学习generative semi-supervised model M2,使用$$z_1$$的embedding 而不是raw data $$x$$。最后是一个deep generative model with two layers of stochastic variables:

$$ p\theta(x,y,z_1,z_2) = p(y) p(z_2) p\theta (z1|y, z_2) p\theta(x|z_1)

$$

其中prior $$p(y)$$和$$p(z2)$$等于 $$y, z$$ (?)。两个$$p\theta$$是两个deep neural networks。

如果纯粹从graphical model的角度出发,即为

$$ y \to z_1\ z_2 \to z_1\ z_1 \to x

$$

Scalable Variational Inference

Lower Bound Objective

我们用inference principle,用一个fixed-form distribution $$q\phi(z|x)$$ 来估计真实的posterior distribution $$p\theta(z|x)$$(因为其本身的计算是intractable)。

我们将approximate posterior distribution $$q_\phi$$当作是一个inference或者recognition model。

  • M1: $$q\phi (z|x) = \mathcal{N}(z| \mu\phi(x), diag(\sigma^2_\phi (x) ))$$
  • M2: $$q\phi (z|y,x) = \mathcal{N}(z | \mu\phi(x), diag(\sigma\phi^2 (x) ) ); q\phi (y|x) = cat(y| \pi_\phi(x))$$

Latent Feature Discriminative Model Objective

对于每一个点,marginal likelihood的variational bound是

$$ \log p\theta(x) \ge \mathbb{E}{q{\phi (z|x)}} [\log p\theta(x|z)] - KL[q\phi(z|x) || p\theta (z)]

$$

inference network $$q_\phi(z|x)$$训练的时候,是transductive方法,同时使用labelled和unlabelled data。这个估计的posterior用来当feature extractor(在labelled data set上),和features used for classifier。

Generative Semi-supervised Model Objective

有两个情况要考虑。

  1. 每一个数据点的label都观测到,并且variational bound是上面bound的拓展形式

$$ \log p\theta (x,y) \ge \mathbb{E}{q{\phi}(z|x,y)} [\log p\theta(x|y,z) + \log p\theta(y) + \log p(z) - \log q\phi(z|x,y)] = -\mathcal{L}(x)

$$

  1. 对于label missing的情况,会被当作一个latent variable,而我们对着干遍历会进行posterior inference,最终的bound是

$$ \log p\theta(x) \ge \ \mathbb{E}{q{\phi}(y,z|x)} [\log p\theta(x|y,z) + \log p\theta(y) + \log p(z) - \log q\phi(y,z|x)]\ = \sumy q\phi(y|x) (-\mathcal{L}(x,y)) + \mathcal{H}(q_\phi(y|x)) \ = \mathcal{U}(x)

$$

因此整个数据集上的marginal likelihood的bound是

$$ \mathcal{J} = \sum{(x,y) \sim \tilde p_l} \mathcal{L}(x) + \sum{x\sim \tilde p_u} \mathcal{U}(x)

$$

Appendix

有两类semi-supervised learning

  • Transductive learning: 在训练过程中,已知testing data(unlabelled data)是transductive learing
  • Inductive learning: 在训练过程中,并不知道testing data ,训练好模型后去解决未知的testing data 是inductive learing

results matching ""

    No results matching ""