Abstract
Transformers with linear attention allow for efficient parallel training but can simultaneously be formulated as an RNN with 2D (matrix-valued) hidden states, thus enjoying linear (with respect to output length) inference complexity. Recent works such as RetNet (Sun et al., 2023) and TransNormerLLM (Qin et al., 2023a) observe that adding a global decay term to the additive RNN update rule greatly improves performance, sometimes outperforming standard Transformers with softmax attention when trained at scale. In this work we show that adding a data-dependent gating mechanism further improves performance. We derive a parallel form of this gated linear attention layer that enables efficient training. However, a straightforward, numerically stable implementation of this parallel form requires generalized matrix multiplications in log-space, and thus cannot take advantage of tensor cores on modern GPUs, which are optimized for standard matrix multiplications. We develop a hardware-efficient version of the parallel form that can still make use of tensor cores through block-parallel computations over sequence chunks. Experiments on moderate-scale language modeling (340M-parameter models trained on 15B tokens, 1.3B-parameter models trained on 100B tokens) show that gated linear attention (GLA) Transformers perform competitively against a strong LLaMA-architecture Transformer baseline (Touvron et al., 2023) as well as Mamba (Gu & Dao, 2023), a recently introduced state-space model with a data-dependent state transition mechanism. For training speed, our Triton-based implementation performs comparably to CUDA-optimized FlashAttention-2 (Dao, 2023) under the regular 2048-token training length setting, while outperforming FlashAttention-2 when training on longer sequences (beyond 4096 tokens).
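To make the recurrent/parallel duality concrete, here is a minimal NumPy sketch (not the authors' implementation). It assumes a per-key-dimension gate alpha_t in (0, 1) that decays the rows of the matrix-valued state, S_t = diag(alpha_t) S_{t-1} + k_t^T v_t with output o_t = q_t S_t; the abstract only states that the gate is data-dependent, so this exact shape is an assumption. The function names `gla_recurrent` and `gla_parallel` are hypothetical. The parallel version uses cumulative gate products b_t = prod_{j<=t} alpha_j and divides by them, which is the numerically fragile form that motivates log-space computation; the chunkwise, tensor-core-friendly algorithm is not shown.

```python
import numpy as np

def gla_recurrent(q, k, v, alpha):
    """Recurrent form (assumed gating): S_t = diag(alpha_t) S_{t-1} + outer(k_t, v_t); o_t = q_t S_t.

    q, k, alpha: arrays of shape (T, d_k); v: shape (T, d_v); alpha entries in (0, 1).
    """
    T, d_k = q.shape
    d_v = v.shape[1]
    S = np.zeros((d_k, d_v))  # matrix-valued (2D) hidden state
    out = np.zeros((T, d_v))
    for t in range(T):
        # Decay each row of the state by its gate value, then add the new key-value outer product.
        S = alpha[t][:, None] * S + np.outer(k[t], v[t])
        out[t] = q[t] @ S
    return out

def gla_parallel(q, k, v, alpha):
    """Naive parallel form: o_t = sum_{s<=t} ((q_t * b_t) . (k_s / b_s)) v_s, with b_t = prod_{j<=t} alpha_j.

    Dividing by b_s can overflow/underflow on long sequences, hence the need for a
    log-space (or chunked) formulation in a numerically stable implementation.
    """
    b = np.cumprod(alpha, axis=0)          # cumulative gate products, shape (T, d_k)
    scores = (q * b) @ (k / b).T           # pairwise gated scores, shape (T, T)
    mask = np.tril(np.ones_like(scores))   # causal mask: keep only s <= t
    return (scores * mask) @ v

# Usage: both forms agree on a short random sequence.
rng = np.random.default_rng(0)
T, d_k, d_v = 8, 4, 3
q = rng.standard_normal((T, d_k))
k = rng.standard_normal((T, d_k))
v = rng.standard_normal((T, d_v))
# A sigmoid of random logits stands in for a learned, data-dependent gate (hypothetical).
alpha = 1.0 / (1.0 + np.exp(-rng.standard_normal((T, d_k))))
print(np.allclose(gla_recurrent(q, k, v, alpha), gla_parallel(q, k, v, alpha)))  # True
```

Setting every entry of alpha to the same constant recovers a global, data-independent decay of the kind the abstract attributes to RetNet and TransNormerLLM; making alpha depend on the input is what the gating adds.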