Exploring Generalization in Deep Learning

Exploring Generalization in Deep Learning

Behnam Neyshabur, Srinadh Bhojanapalli, David McAllester, Nathan Srebro

TTIC

Intro

有很多讨论network complexity measure的角度,比如measure,sharpness,robustness。如何确定一个合适的complexity measure来解释deep learning的generalization。

  • 对于随机label可能在training上完美拟合。而相比于随即label,如果实用真实label,expect network能够有lower complexity。
  • 增加hidden units,导致参数的增加,会引起generalization error的减少,哪怕training error并不会减少的情况下。
  • 当使用相同架构进行训练的时候,使用相同的数据集,和不同的optimization方法,尽管两个都达到0 training error,还是有一个会表现得更好。我们会期待complexity measure和generalization ability之间的correlation。

Generalization and Capacity Control in Deep Learning

考虑模型的statistical capacity,关于example的数量来保证generalization。比如test error十分接近training error。

complexity measure是hypothesis space笛卡儿积于training set,映射到实数域,$$\mathcal{M}:{ \mathcal{H}, \mathcal{S} } \to \mathcal{R}^+$$。所以对于任意$$\alpha$$,通过capacity得到一个restricted class $$\mathcal{H}_{\mathcal{M}, \alpha} = { h: h \in \mathcal{H}, \mathcal{M}(h) \le \alpha }$$。

一个关于好的假设,有low complexity,而倾向于有low complexity的假设对于学习是足够的,哪怕整个hypothesis space的capacity很高。而如果我们依靠$$\mathcal{M}$$来判断generalization,那么我们倾向于使用的假设h会有比较小的$$\mathcal{M}(h)$$。

考虑好几种不同的complexity measures,每一个measure,都会先考虑是否对于generalization充分,然后分析capacity of $$\mathcal{H}_{\mathcal{M}, \alpha}$$。

Network Size

对于任何模型,如果参数有有限的精度,那么它的capacity与parameter个数线性相关。feedforward networks的VC dimension可以被参数维度bound住。

VC-dim = $$\tilde O (d * dim(w))$$

所以在over-parameterized的设定下,依赖于整体参数个数的measure会非常weak。这也与这个观察一致:给定随机的label,始终能0 training-error fit。另外,依靠参数个数的measure bound没有办法解释随着hidden units个数增加,generalization性能变差。

Norms and Margins

通过类似l-2 norm的regularization,能够使得linear predictor的capacity与number of parameters独立。类似也有对feedforward network建立的基于norm的complexity measure。比如通过每一层layer的l-1 norm的bound,capacity正比于$$\prod{i=1}^{d} |W_i|{1,\infty}^2$$,其中$$|Wi|{1,\infty}$$是第i层layer的hidden unit的l-1norm的最大值。generalization bound with capacity正比于$$\prod{i=1}^d | W_i |_2^2 (\sum{j=1}^d ( |W_j_1| / |W_j|_2)^{2/3} )$$。

然后谈到了scaling问题。在考虑0-1 loss的时候,通过scaling,总归能给得到非常小的parameter norm。因此应该使用scale sensitive loss,比如cross entropy loss。

在比较不同模型,以cross entropy loss作为优化目标的时候,有需要注意的地方,尤其是当training error为0.当training error走向0,为了使得cross entropy loss也变成0,network的output就会走向无穷大,因此norm也会趋向无穷大。也就是minimize cross entropy loss会导致norm趋向于无穷。但是norm的数值很多时候是预示optimization能够progress多久,使用比较严格的stopping criteria会导致norm更大。换句话说,比较不同的model会发现,使用不同的优化函数没有意义,因为最终norm都会走向无穷大。

如果要有意义地比较两个network的norm,就需要明确地带入output scaling。在training error为0的情况下,一种考虑的方法就是预测的margin。margin指的是对于一个特定的data point,预测正确label和其他label的差值,也就是

$$fw(x) [y{true}] - max{y\ne y{true}} f_w(x)[y]$$

为了在整个training set上measure scale,一个简单的方法是考虑hard margin,是所有training point的最小margin。但这对于extreme point和dataset size比较敏感。我们考虑更加robust,允许一小部分data points可以violate margin。对于一个training set和很小的数值$$\epsilon > 0$$,定义margin $$\gamma_{margin}$$为有$$\lceil \epsilon m \rceil$$个data都小于这个margin $$\gamma$$的最小值。(也就是最小的$$\epsilon$$data,再取其中margin的最大值)根据经验发现$$\epsilon$$变化并不会影响台独哦,比如从0.001到0.1。

bound涉及到了path-norm,得follow一下前面的工作。

Sharpness

sharpness是最近新提出的概念,对应于parameter space上进行adversarial pertubation的robustness:

$$\zeta (w) = \frac{max{|\nu| \le \alpha(|w|+1)} \hat L(f{w+\nu}) - \hat L(fw)}{1+\hat L(f_w)} \simeq max{|\nu| \le \alpha(|w|+1)} \hat L(f_{w+\nu}) - \hat L(f_w)$$

能够取得这样的近似是因为training error $$\hat L(f_w)$$通常会比较小,所以可以从分母中去掉。

这种定义sharpness的方法并不能够抓住generalization behavior。为了验证这一点,首先我们检查在使用真实label和随机label的情况下,sharpness能否预测generalization behavior。尽管在bigger network上,sharpness能够很好地预测generalization behavior;但是在size更小的network上,相比于使用true label,使用随机label训练的模型有更小的sharpness。(我们期望是true label训练的模型sharpness一直更大)加上scaleness,因此只考虑sharpness不足以控制network capacity。

我们提出将sharpness放到PAC-Bayesian framework考虑。这里会发现sharpness只能控制1-2个相关项,并且必须和其他measure比如norm进行平衡。一起考虑,sharpness和norm提供了capacity control,并且能够解释很多现象。

公式4表明了bound会受两个因素控制,一个是expected sharpness,另一个与prior的KL divergence。

Empirical Investigation

实验结果很有意思。deep network的capacity一般都很大,所以可以很容易找到不同的local minima,而且满足training error为0。

Appendix

这篇paper……嗯,需要follow up一下context。

results matching ""

    No results matching ""