DISTFLASHATTN: Distributed Memory-efficient Attention for Long-context LLMs Training

  • 2024-03-31 22:11:08
  • Dacheng Li, Rulin Shao, Anze Xie, Eric P. Xing, Xuezhe Ma, Ion Stoica, Joseph E. Gonzalez, Hao Zhang
FlashAttention (Dao, 2023) effectively reduces the quadratic peak memory usage to linear when training transformer-based large language models (LLMs) on a single GPU. In this paper, we introduce DISTFLASHATTN, a distributed memory-efficient attention mechanism optimized for long-context LLM training. We propose three key techniques: token-level workload balancing, overlapping key-value communication, and a rematerialization-aware gradient checkpointing algorithm. We evaluate DISTFLASHATTN on Llama-7B and variants with sequence lengths from 32K to 512K. Compared to Ring Self-Attention, DISTFLASHATTN supports 8x longer sequences with a 4.45 - 5.64x speedup; compared to Megatron-LM with FlashAttention, it supports 2 - 8x longer sequences with a 1.24 - 2.01x speedup. It also achieves 1.67x and 1.26 - 1.88x speedups over the recent Ring Attention and DeepSpeed-Ulysses, respectively. Code is available at https://github.com/RulinShao/LightSeq.
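
The abstract does not spell out how token-level workload balancing works, so the following is only a rough illustration of why balancing matters under causal attention, not the paper's algorithm. It assumes the sequence is split evenly into one chunk per worker, that worker p must process key-value chunks 0 through p, and that every chunk pair costs the same (the cheaper masked diagonal block is ignored). The function names are hypothetical.

```python
from fractions import Fraction


def causal_block_counts(num_workers):
    """Blocks each worker computes: worker p attends to key-value chunks 0..p."""
    return [p + 1 for p in range(num_workers)]


def naive_idle_fraction(num_workers):
    """Fraction of compute slots wasted if all workers wait for the busiest one."""
    counts = causal_block_counts(num_workers)
    total_slots = num_workers * max(counts)  # every worker occupies max(counts) steps
    return Fraction(total_slots - sum(counts), total_slots)


def balanced_blocks_per_worker(num_workers):
    """Per-worker load if the same total work were spread perfectly evenly."""
    return Fraction(sum(causal_block_counts(num_workers)), num_workers)


if __name__ == "__main__":
    workers = 8
    print(causal_block_counts(workers))        # per-worker block counts
    print(naive_idle_fraction(workers))        # wasted share of slots
    print(balanced_blocks_per_worker(workers)) # ideal even share
```

Under these assumptions the idle share works out to (P - 1) / (2P) for P workers, which approaches one half as P grows; this is the kind of gap a balancing scheme aims to close.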


