The Implicit Bias of Gradient Descent on Separable Data

The Implicit Bias of Gradient Descent on Separable Data

Daniel Soudry, etc

Abstract

我们证明了在unregularized logistic regression上使用gradient descent,在linearly separable dataset上,会converge到和 max-margin 一个方向的solution。这个结果也能够泛化到其他单调递减的loss function。另外我们也展示了这个convergence很慢,大约是loss convergence的log。这也能够解释了在当training error变成0之后,继续进行训练的好处;哪怕之后training loss会非常的小,而且validation loss增加。我们的方法也能够用于理解在更加复杂的模型内核其他优化方法的implicit regularization。

Intro

predictor的norm会增长到无限,但是我们更关心的是direction of prediction,也就是$$w(t) / | w(t) |$$。

这篇paper,我们展示了哪怕没有任何明显的regularization,在线性可分的数据集上,当使用gradient descent最小化logistic regression,我们有$$w(t) / |w(t)|$$ converge到$$L_2$$ maximum margin separator,也就是hard margin SVM。

Main Results

loss定义为$$L(w) = \sum_{n=1}^N \ell(y_n w^T x_n)$$

Theorem 3: 给定假设1,2,3,从任意的出发点开始训练,使用gradient descent,在t时刻的parameter等于max margin solution乘以$$\log(t)$$加上关于$$t$$的一个函数。 $$w(t) = \hat w \log(t) + \rho(t)$$, where $$| \rho(t)| = O(\log \log (t))$$. 因此 $$\frac{w(t)}{|w(t)|} \to \frac{\hat w}{|\hat w|}$$. 并且对于大多数的数据集,residual $$\rho(t)$$ is bounded.

Appendix

A. Proof of Main Results

在下面证明中,对于任意的结果$$w(t)$$,我们定义

$$r(t) = w(t) - \hat w \log(t) - \tilde w$$

其中$$\hat w$$是L2 max margin vector

$$\hat w = \arg \min |w|^2$$ s.t. $$\forall v: w^T x_n \ge 1$$

而$$\dot w$$ 则满足

$$\forall n \in S: \eta \exp(-x_n^T \tilde w) = \alpha_n$$

其中$$X_S \in \mathbb{R}^{d \times |S|}$$表示的是support vectors,每一列对应一个support vector。

在Lemma 8 (Appendix F)中,我们证明了每一个数据集的$$\alpha$$都是唯一确定的,并且又不超过d个support vectors,且$$\alpha_n \ne 0$$。

另外,我们使用

$$ \theta = \min_{n \notin S} x_n^T \hat w > 1

$$

和通过$$C_i, \epsilon_i, t_i$$ 不同与t无关的常数。最后,我们使用matrix $$P_1 \in \mathbb{R}^{d \times d}$$作为orthogonal projection matrix,project到subspace spanned by support vectors。这个subspace即为$$X_S$$。

A.1 Simple Proof of Theorem 3 For A Special Case

这里我们先考虑一个特殊情况 $$\ell(u) = e^{-u}$$,然后使用continuous time limit(也就是t作为连续数值上的导数),有

$$\dot w = - \nabla \mathcal{L}(w(t))$$

定义 $$r(t) = w(t) - \log(t) \hat w - \tilde w$$

我们目标是为了展示$$|r(t)|$$是bounded,因此$$\rho(t) = r(t) + \tilde w$$也是bounded。

$$\dot r(t) = \dot w(t) - \frac{1}{t} \hat w = -\nabla(w(t)) - \frac{1}{t} \hat w$$

因此有

$$ \begin{align} \frac{1}{2} \frac{d}{dt} | r(t)^2| & = \dot r(t)^T r(t)\ & = \sum{n=1}^N -\nabla \mathcal{L}(w(t))^T r(t) - \frac{1}{t} \hat w^T r(t) \ & = \sum{n=1}^{N} \exp(-xn^T w(t)) x_n^T r(t) - \frac{1}{t} \hat w^T r(t)\ & = \sum{n=1}^{N} \exp(-xn^T (-r(t) - \log(t) \hat w - \tilde w)) x_n^T r(t) - \frac{1}{t} \hat w^T r(t)\ & = \bigg[ \sum{n \in S} \exp(-\log(t) \hat w^T xn - \tilde w^T x_n - x_n^T r(t)) x_n^T r(t) - \frac{1}{t} \hat w^T r(t) \bigg]\ & + \bigg[ \sum{n \notin S} \exp(-\log(t) \hat w^T x_n - \tilde w^T x_n - x_n^T r(t)) x_n^T r(t) \bigg]\ \end{align}

$$

最后一个步骤是将其分解为support vectors和non support vectors。(如何得到第2个括号?)

对于support vectors,$$\hat w^T xn = 1$$,并且根据定义有$$\sum{n \in S} \exp(- \tilde w^T x_n) x_n = \hat w$$。这个理解为decision boundary $$\hat w$$是由support vectors来确定的(只不过从实际层面比较难以确定哪些是support vectors)。因此第一个括号可以写成

$$ \begin{align} & \sum{n \in S} \exp(-\log(t) \hat w^T x_n - \tilde w^T x_n - x_n^T r(t)) x_n^T r(t) - \frac{1}{t} \hat w^T r(t)\ = & \sum{n \in S} \exp(-\log(t) \cdot 1 - \tilde w^T xn - x_n^T r(t)) x_n^T r(t) - \frac{1}{t} \hat w^T r(t)\ = & \frac{1}{t} \sum{n \in S} \exp(- \tilde w^T xn - x_n^T r(t)) x_n^T r(t) - \frac{1}{t} \hat w^T r(t)\ = & \frac{1}{t} \bigg[ \sum{n \in S} \exp(- \tilde w^T xn - x_n^T r(t)) x_n^T r(t) - \exp (-\tilde w^Tx_n) x_n^T r(t) \bigg]\ = & \frac{1}{t} \bigg[ \sum{n \in S} \big( \exp(- \tilde w^T xn - x_n^T r(t)) - \exp (-\tilde w^Tx_n) \big) x_n^T r(t) \bigg]\ = & \frac{1}{t} \bigg[ \sum{n \in S} \big( \exp(- \tilde w^T x_n) \big) \cdot \big( \exp (x_n^T r(t)) - 1\big) x_n^T r(t) \bigg]\ \le & 0 \end{align}

$$

最后一步是因为$$z (e^-z -1) \le 0$$。

另外,因为根据$$\theta$$定义,所有非support vector的点到support vector的最短距离为$$\theta$$,因此$$\exp(\hat w^T x_n) \ge \exp (\theta)$$, so$$\exp(-\log(t) \hat w^T x_n) \le \exp(-\theta \log(t)) = \frac{1}{t^\theta}$$。并且有$$\exp(-z) z \le 1$$,所以第二个括号有:

$$ \begin{align} & \sum{n \notin S} \exp(-\log(t) \hat w^T x_n - \tilde w^T x_n - x_n^T r(t)) x_n^T r(t)\ = & \sum{n \notin S} \big( \exp(-\log(t) \hat w^T xn) \cdot \exp(- \tilde w^T x_n) \cdot \exp(- x_n^T r(t)) \cdot x_n^T r(t) \big)\ & \le \frac{1}{t^\theta} \sum{n \notin S} \exp(- \tilde w^T x_n) \end{align}

$$

因此,将这两个不等式代回到式子,有

$$ | r(t) |^2 - | r(t1) |^2 \le C \int{t_1}^t \frac{d t}{t^\theta} \le C' < \infty

$$

因为$$\theta > 1$$。

A.2 General Case

下面考虑更一般的情况,discrete time,exponentially-tailed function。

首先需要两个辅助的lemma

Lemma: Let $$\mathcal{L}(w)$$ be $$\beta$$-smooth non-negative objective. If $$\eta < 2 \beta^-1$$, then for any $$w(0)$$, with the GD sequence

$$ w(t+1) = w(t) - \eta \nabla \mathcal{L}(w(t))

$$

we have that $$\sum{u=0}^\infty | \nabla \mathcal{L}(w(u)) |^2 < \infty$$, therefore $$\lim{t \to \infty} | \nabla \mathcal{L}(w(t)) |^2 = 0$$.

Lemma: We have

Summary

results matching ""

    No results matching ""