Transformers to SSMs: Distilling Quadratic Knowledge to Subquadratic Models

  • 2024-08-19 18:48:11
  • Aviv Bick, Kevin Y. Li, Eric P. Xing, J. Zico Kolter, Albert Gu

Abstract

Transformer architectures have become a dominant paradigm for domains like language modeling but suffer in many inference settings due to their quadratic-time self-attention. Recently proposed subquadratic architectures, such as Mamba, have shown promise, but have been pretrained with substantially fewer computational resources than the strongest Transformer models. In this work, we present a method that is able to distill a pretrained Transformer architecture into alternative architectures such as state space models (SSMs). The key idea behind our approach is that we can view both Transformers and SSMs as applying different forms of mixing matrices over the token sequences. We can thus progressively distill the Transformer architecture by matching different degrees of granularity in the SSM: first matching the mixing matrices themselves, then the hidden units at each block, and finally the end-to-end predictions. Our method, called MOHAWK, is able to distill a Mamba-2 variant based on the Phi-1.5 architecture (Phi-Mamba) using only 3B tokens and a hybrid version (Hybrid Phi-Mamba) using 5B tokens. Despite using less than 1% of the training data typically used to train models from scratch, Phi-Mamba boasts substantially stronger performance compared to all past open-source non-Transformer models. MOHAWK allows models like SSMs to leverage computational resources invested in training Transformer-based architectures, highlighting a new avenue for building such models.
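To make the "mixing matrix" view concrete, here is a minimal pure-Python sketch, assuming a single head and one scalar channel per token. Causal softmax attention materializes M[i][j] = softmax over j <= i of (q_i · k_j / sqrt(d)), while a simplified Mamba-2-style SSM with one scalar decay a_t per step materializes M[i][j] = (a_{j+1} ... a_i)(c_i · b_j) for j <= i. The function names are hypothetical, and the Frobenius-norm gap is only an assumed stand-in for the first (matrix-matching) stage; this is not MOHAWK's actual implementation, and it omits details such as discretization steps and multiple heads.

```python
import math
import random


def causal_attention_matrix(q, k):
    """Causal softmax attention as an explicit n x n mixing matrix (zeros above the diagonal)."""
    n, d = len(q), len(q[0])
    scale = 1.0 / math.sqrt(d)
    m = [[0.0] * n for _ in range(n)]
    for i in range(n):
        # Scores only for positions j <= i (causal mask).
        scores = [sum(x * y for x, y in zip(q[i], k[j])) * scale for j in range(i + 1)]
        top = max(scores)  # subtract the max before exp for numerical stability
        exps = [math.exp(s - top) for s in scores]
        total = sum(exps)
        for j in range(i + 1):
            m[i][j] = exps[j] / total
    return m


def ssd_mixing_matrix(a, b, c):
    """Simplified scalar-decay SSM mixer: M[i][j] = (a[j+1] * ... * a[i]) * (c_i . b_j) for j <= i."""
    n = len(a)
    m = [[0.0] * n for _ in range(n)]
    for i in range(n):
        decay = 1.0  # empty product when j == i
        for j in range(i, -1, -1):
            m[i][j] = decay * sum(x * y for x, y in zip(c[i], b[j]))
            decay *= a[j]  # extend the product so it covers a[j] ... a[i] for the next j
    return m


def mix(m, x):
    """Apply a mixing matrix to a scalar sequence: y_i = sum_j M[i][j] * x_j."""
    return [sum(mij * xj for mij, xj in zip(row, x)) for row in m]


def frobenius_gap(m1, m2):
    """Frobenius norm of (m1 - m2); an assumed matrix-matching loss for illustration."""
    return math.sqrt(sum((u - v) ** 2 for r1, r2 in zip(m1, m2) for u, v in zip(r1, r2)))


if __name__ == "__main__":
    random.seed(0)
    n, d = 4, 2
    q = [[random.gauss(0, 1) for _ in range(d)] for _ in range(n)]
    k = [[random.gauss(0, 1) for _ in range(d)] for _ in range(n)]
    teacher = causal_attention_matrix(q, k)
    # Hypothetical student: reuse q/k as the SSM's c/b with a constant decay of 0.9.
    student = ssd_mixing_matrix([0.9] * n, k, q)
    print("matrix gap:", frobenius_gap(teacher, student))
```

Setting every a_t to 1 turns the SSM mixer into unnormalized causal linear attention (M[i][j] = c_i · b_j for j <= i), which is one way to see why the two model families can be compared matrix to matrix before matching hidden states and final predictions.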
