Parallelizing Linear Transformers with the Delta Rule over Sequence Length

  • 2024-06-10 18:24:42
  • Songlin Yang, Bailin Wang, Yu Zhang, Yikang Shen, Yoon Kim

Abstract

Transformers with linear attention (i.e., linear transformers) and state-space models have recently been suggested as a viable linear-time alternative to transformers with softmax attention. However, these models still underperform transformers especially on tasks that require in-context retrieval. While more expressive variants of linear transformers which replace the additive outer-product update in linear transformers with the delta rule have been found to be more effective at associative recall, existing algorithms for training such models do not parallelize over sequence length and are thus inefficient to train on modern hardware. This work describes a hardware-efficient algorithm for training linear transformers with the delta rule, which exploits a memory-efficient representation for computing products of Householder matrices. This algorithm allows us to scale up DeltaNet to standard language modeling settings. We train a 1.3B model for 100B tokens and find that it outperforms recent linear-time baselines such as Mamba and GLA in terms of perplexity and zero-shot performance on downstream tasks (including on tasks that focus on recall). We also experiment with two hybrid models which combine DeltaNet layers with (1) sliding-window attention layers every other layer or (2) two global attention layers, and find that these hybrid models outperform strong transformer baselines.
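The delta-rule update and the product of Householder-like matrices mentioned above can be illustrated with a short NumPy sketch. It assumes the commonly used DeltaNet recurrence S_t = S_{t-1}(I - β_t k_t k_t^T) + β_t v_t k_t^T with output o_t = S_t q_t, where β_t in (0, 1) is a per-token write strength. Plain linear attention would instead use the additive update S_t = S_{t-1} + v_t k_t^T. Each factor I - β_t k_t k_t^T is a generalized Householder matrix. The helper names `delta_rule_recurrent` and `wy_factors` are hypothetical. The WY-style factorization is our assumption about the kind of memory-efficient representation the abstract refers to: it stores the product of T such d_k x d_k matrices as T vectors w_t, so the product equals I - Σ_t w_t k_t^T, and it writes the state as Σ_t u_t k_t^T. The sketch still computes w_t and u_t one token at a time. It only checks the identities and does not reproduce the paper's parallel training algorithm.

```python
import numpy as np


def delta_rule_recurrent(q, k, v, beta):
    """Sequential delta-rule recurrence (assumed form):
    S_t = S_{t-1} (I - beta_t k_t k_t^T) + beta_t v_t k_t^T,  o_t = S_t q_t.
    q, k: (T, d_k); v: (T, d_v); beta: (T,). Returns outputs (T, d_v) and final state."""
    T, d_k = k.shape
    d_v = v.shape[1]
    S = np.zeros((d_v, d_k))
    out = np.zeros((T, d_v))
    for t in range(T):
        # Same update written as an error correction: S += beta_t (v_t - S k_t) k_t^T
        S = S + beta[t] * np.outer(v[t] - S @ k[t], k[t])
        out[t] = S @ q[t]
    return out, S


def wy_factors(k, v, beta):
    """Hypothetical helper building W (T, d_k) and U (T, d_v) such that
    prod_t (I - beta_t k_t k_t^T) = I - W^T K  and  S_T = U^T K,
    where row t of W is w_t and row t of U is u_t."""
    T, d_k = k.shape
    W = np.zeros((T, d_k))
    U = np.zeros((T, v.shape[1]))
    for t in range(T):
        kk = k[:t] @ k[t]  # inner products k_i^T k_t for all i < t
        W[t] = beta[t] * (k[t] - kk @ W[:t])
        U[t] = beta[t] * (v[t] - kk @ U[:t])
    return W, U


rng = np.random.default_rng(0)
T, d_k, d_v = 8, 4, 3
q = rng.standard_normal((T, d_k))
k = rng.standard_normal((T, d_k))
k /= np.linalg.norm(k, axis=1, keepdims=True)  # unit-norm keys (assumption)
v = rng.standard_normal((T, d_v))
beta = rng.uniform(0.1, 0.9, size=T)

out, S = delta_rule_recurrent(q, k, v, beta)
W, U = wy_factors(k, v, beta)

# Explicit left-to-right product of the T matrices (I - beta_t k_t k_t^T)
P = np.eye(d_k)
for t in range(T):
    P = P @ (np.eye(d_k) - beta[t] * np.outer(k[t], k[t]))

print(np.allclose(P, np.eye(d_k) - W.T @ k))  # True
print(np.allclose(S, U.T @ k))                # True
```

Both checks print True. The recurrences for w_t and u_t involve only the inner products k_i^T k_t. That suggests how the computation could be regrouped into matrix operations over blocks of tokens rather than materializing one d_k x d_k matrix per step.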
