Backward Feature Correction: How Deep Learning Performs Deep Learning

  • 2020-01-13 17:28:29
  • Zeyuan Allen-Zhu, Yuanzhi Li
  • 96

Abstract

How does a 110-layer ResNet learn a high-complexity classifier using relatively few training examples and short training time? We present a theory towards explaining this in terms of $\textit{hierarchical learning}$. By hierarchical learning we mean that the learner represents a complicated target function by decomposing it into a sequence of simpler functions, which reduces sample and time complexity. This paper formally analyzes how multi-layer neural networks can perform such hierarchical learning efficiently and automatically, simply by applying stochastic gradient descent (SGD). On the conceptual side, we present, to the best of our knowledge, the FIRST theoretical result indicating how very deep neural networks can still be sample and time efficient on certain hierarchical learning tasks for which NO KNOWN non-hierarchical algorithm (such as kernel methods, linear regression over feature mappings, tensor decomposition, or sparse coding) is efficient. We establish a new principle called "backward feature correction", which we believe is the key to understanding hierarchical learning in multi-layer neural networks.

On the technical side, we show that for regression, and even for binary classification, for every input dimension $d > 0$ there is a concept class consisting of degree-$\omega(1)$ multivariate polynomials such that, using $\omega(1)$-layer neural networks as learners, SGD can learn any target function from this class to any $\frac{1}{\mathsf{poly}(d)}$ error in $\mathsf{poly}(d)$ time using $\mathsf{poly}(d)$ samples, by learning to represent it as a composition of $\omega(1)$ layers of quadratic functions. In contrast, we present lower bounds stating that several non-hierarchical learners, including any kernel method and neural tangent kernels, must suffer $d^{\omega(1)}$ sample or time complexity to learn functions in this concept class, even to $d^{-0.01}$ error.
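To make "a composition of $\omega(1)$ layers of quadratic functions" concrete, here is a minimal Python sketch. It assumes each layer is a random linear map followed by elementwise squaring, with no bias terms; this is an illustrative choice, not the paper's exact concept class, and the helper names (`make_weights`, `quadratic_layer`, `hierarchical_target`, `estimate_degree`) are hypothetical. Because each layer squares its input, composing $L$ layers gives a homogeneous polynomial of degree $2^L$. So $\omega(1)$ layers yield degree $\omega(1)$, matching the abstract.

```python
import math
import random


def random_matrix(rows, cols, rng):
    # Gaussian entries with variance 1/cols keep activations roughly O(1).
    std = 1.0 / math.sqrt(cols)
    return [[rng.gauss(0.0, std) for _ in range(cols)] for _ in range(rows)]


def make_weights(d, width, depth, seed=0):
    # Hypothetical helper: one weight matrix per quadratic layer.
    rng = random.Random(seed)
    dims = [d] + [width] * depth
    return [random_matrix(dims[i + 1], dims[i], rng) for i in range(depth)]


def quadratic_layer(W, h):
    # Linear map followed by elementwise squaring: every output coordinate
    # is a quadratic form in the layer input h.
    return [sum(w * v for w, v in zip(row, h)) ** 2 for row in W]


def hierarchical_target(weights, x):
    # Compose the quadratic layers and sum the final coordinates.
    h = list(x)
    for W in weights:
        h = quadratic_layer(W, h)
    return sum(h)


def estimate_degree(weights, x, t=2.0):
    # With no bias terms the composition is homogeneous of degree 2**depth,
    # so f(t*x) = t**(2**depth) * f(x); recover the exponent from the ratio.
    fx = hierarchical_target(weights, x)
    if fx == 0.0:
        raise ValueError("f(x) is zero; pick a different x")
    ftx = hierarchical_target(weights, [t * v for v in x])
    return math.log(ftx / fx) / math.log(t)


if __name__ == "__main__":
    weights = make_weights(d=5, width=4, depth=3, seed=1)
    x = [0.3, -0.7, 1.1, 0.2, -0.4]
    print(round(estimate_degree(weights, x), 6))  # depth 3 -> degree 2**3 = 8
```

Each layer here is only a quadratic function of the previous layer's output. Even so, the end-to-end map has degree $2^L$. This gap between simple per-layer pieces and a high-degree overall function is what the hierarchical view is meant to exploit.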
