DISTFLASHATTN: Distributed Memory-efficient Attention for Long-context LLMs Training

  • 2024-03-31 22:11:08
  • Dacheng Li, Rulin Shao, Anze Xie, Eric P. Xing, Xuezhe Ma, Ion Stoica, Joseph E. Gonzalez, Hao Zhang

Abstract

FlashAttention (Dao, 2023) effectively reduces the quadratic peak memory usage of attention to linear when training transformer-based large language models (LLMs) on a single GPU. In this paper, we introduce DISTFLASHATTN, a distributed memory-efficient attention mechanism optimized for long-context LLM training. We propose three key techniques: token-level workload balancing, overlapping key-value communication, and a rematerialization-aware gradient checkpointing algorithm. We evaluate DISTFLASHATTN on Llama-7B and variants with sequence lengths from 32K to 512K. Compared to Ring Self-Attention, DISTFLASHATTN supports 8x longer sequences with a 4.45-5.64x speedup; compared to Megatron-LM with FlashAttention, it supports 2-8x longer sequences with a 1.24-2.01x speedup. It also achieves 1.67x and 1.26-1.88x speedups over the recent Ring Attention and DeepSpeed-Ulysses, respectively. Code is available at https://github.com/RulinShao/LightSeq.
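The linear-memory property mentioned above comes from never materializing the full score matrix. Keys and values are consumed block by block, and the kernel keeps only a running maximum, a running softmax denominator, and a running weighted sum. The pure-Python sketch below illustrates this generic online-softmax technique for a single query. It is not DISTFLASHATTN's kernel or its communication schedule. The function names `naive_attention` and `chunked_attention` and the `chunk_size` parameter are hypothetical. In a distributed setting, each chunk could stand in for a key-value shard received from another worker, and the same rescaling rule would merge that shard's partial result into the local one.

```python
import math
from typing import List

Vector = List[float]


def _dot(a: Vector, b: Vector) -> float:
    return sum(x * y for x, y in zip(a, b))


def naive_attention(q: Vector, keys: List[Vector], values: List[Vector]) -> Vector:
    """Reference softmax attention for one query: holds all N scores at once."""
    scale = 1.0 / math.sqrt(len(q))
    scores = [_dot(q, k) * scale for k in keys]
    m = max(scores)  # subtract the max for numerical stability
    weights = [math.exp(s - m) for s in scores]
    total = sum(weights)
    d_v = len(values[0])
    return [sum(w * v[j] for w, v in zip(weights, values)) / total for j in range(d_v)]


def chunked_attention(q: Vector, keys: List[Vector], values: List[Vector],
                      chunk_size: int) -> Vector:
    """Online-softmax attention: holds at most `chunk_size` scores at a time.

    Hypothetical sketch; each chunk plays the role of one key-value block
    (or, in a distributed setting, one key-value shard from a peer worker).
    """
    scale = 1.0 / math.sqrt(len(q))
    d_v = len(values[0])
    running_max = -math.inf  # max score seen so far
    running_sum = 0.0        # softmax denominator, relative to running_max
    acc = [0.0] * d_v        # unnormalized weighted sum of values
    for start in range(0, len(keys), chunk_size):
        k_chunk = keys[start:start + chunk_size]
        v_chunk = values[start:start + chunk_size]
        scores = [_dot(q, k) * scale for k in k_chunk]
        new_max = max(running_max, max(scores))
        # Rescale earlier statistics to the new max; exp(-inf) == 0.0 on the first chunk.
        correction = math.exp(running_max - new_max)
        running_sum *= correction
        acc = [a * correction for a in acc]
        for s, v in zip(scores, v_chunk):
            w = math.exp(s - new_max)
            running_sum += w
            for j in range(d_v):
                acc[j] += w * v[j]
        running_max = new_max
    # Normalize once at the end.
    return [a / running_sum for a in acc]


if __name__ == "__main__":
    q = [0.5, -1.0]
    keys = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [-1.0, 2.0], [0.3, 0.3]]
    values = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0], [9.0, 10.0]]
    print(naive_attention(q, keys, values))
    print(chunked_attention(q, keys, values, chunk_size=2))
```

Both functions return the same output up to floating-point rounding. The chunked version only ever stores one block of scores, so for N queries the score storage grows with N times the block size instead of N squared. The cost is a rescaling step each time a larger maximum appears.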
