State-space models can learn in-context by gradient descent

  • 2025-02-18 18:55:39
  • Neeraj Mohan Sushma, Yudou Tian, Harshvardhan Mestha, Nicolo Colombo, David Kappel, Anand Subramoney

Abstract

Deep state-space models (Deep SSMs) are becoming popular as effective approaches to modeling sequence data. They have also been shown to be capable of in-context learning, much like transformers. However, a complete picture of how SSMs might be able to do in-context learning has been missing. In this study, we provide a direct and explicit construction to show that state-space models can perform gradient-based learning and use it for in-context learning in much the same way as transformers. Specifically, we prove that a single structured state-space model layer, augmented with multiplicative input and output gating, can reproduce the outputs of an implicit linear model with least-squares loss after one step of gradient descent. We then show a straightforward extension to multi-step linear and non-linear regression tasks. We validate our construction by training randomly initialized augmented SSMs on linear and non-linear regression tasks. The parameters obtained empirically through optimization match the ones predicted analytically by the theoretical construction. Overall, we elucidate the role of input- and output-gating in recurrent architectures as the key inductive biases for enabling the expressive power typical of foundation models. We also provide novel insights into the relationship between state-space models and linear self-attention, and their ability to learn in-context.
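To make the one-step claim concrete, here is a minimal sketch in Python. It is not the authors' construction or code. It rests on simplifying assumptions: scalar targets, zero initial weights, a fixed learning rate `lr`, and an identity state transition. The function names are hypothetical. The sketch checks that the prediction of a linear model after one gradient step on the least-squares loss L(w) = 1/(2N) Σ_i (w·x_i − y_i)² matches the output of a linear recurrence. In that recurrence, the input at each step is gated multiplicatively by y_t, and the final state is gated by the query x_query.

```python
import random

def gd_one_step_prediction(xs, ys, x_query, lr):
    """Explicit route (hypothetical helper): one gradient-descent step on
    L(w) = 1/(2N) * sum_i (w . x_i - y_i)^2 from w0 = 0, then predict on x_query."""
    n, d = len(xs), len(x_query)
    w0 = [0.0] * d
    grad = [0.0] * d
    for x, y in zip(xs, ys):
        residual = sum(wj * xj for wj, xj in zip(w0, x)) - y
        for j in range(d):
            grad[j] += residual * x[j] / n
    w1 = [w0[j] - lr * grad[j] for j in range(d)]
    return sum(wj * qj for wj, qj in zip(w1, x_query))

def gated_recurrence_prediction(xs, ys, x_query, lr):
    """Recurrent route (hypothetical helper): linear recurrence h_t = I h_{t-1} + u_t
    with multiplicative input gating u_t = y_t * x_t, and output gating by x_query."""
    n, d = len(xs), len(x_query)
    state = [0.0] * d
    for x, y in zip(xs, ys):
        gated_input = [y * xj for xj in x]                    # input gate: y_t scales x_t
        state = [s + u for s, u in zip(state, gated_input)]   # identity transition, no decay
    # Output gate: multiply the state elementwise by the query, sum, and scale by lr / N.
    return (lr / n) * sum(s * q for s, q in zip(state, x_query))

# Usage: a random linear-regression context followed by one query.
random.seed(0)
d, n, lr = 3, 8, 0.5
w_true = [random.gauss(0, 1) for _ in range(d)]
xs = [[random.gauss(0, 1) for _ in range(d)] for _ in range(n)]
ys = [sum(a * b for a, b in zip(w_true, x)) for x in xs]
x_query = [random.gauss(0, 1) for _ in range(d)]
explicit = gd_one_step_prediction(xs, ys, x_query, lr)
recurrent = gated_recurrence_prediction(xs, ys, x_query, lr)
print(abs(explicit - recurrent) < 1e-9)  # True
```

The two routes agree because, from w0 = 0, one step yields w1 = (lr/N) Σ_i y_i x_i. That sum of gated inputs is exactly what the recurrence accumulates in its state. The identity transition used here is a degenerate special case of a diagonal state-space layer, chosen only for readability. The paper's structured SSM parameterization and its multi-step and non-linear extensions are not reproduced here.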

 
