Abstract
FlashAttention (Dao, 2023) effectively reduces the quadratic peak memory usage to linear when training transformer-based large language models (LLMs) on a single GPU. In this paper, we introduce DISTFLASHATTN, a distributed memory-efficient attention mechanism optimized for long-context LLM training. We propose three key techniques: token-level workload balancing, overlapping key-value communication, and a rematerialization-aware gradient checkpointing algorithm. We evaluate DISTFLASHATTN on Llama-7B and variants with sequence lengths from 32K to 512K. Compared to Ring Self-Attention, DISTFLASHATTN supports 8x longer sequences with a 4.45-5.64x speedup; compared to Megatron-LM with FlashAttention, it supports 2-8x longer sequences with a 1.24-2.01x speedup. It also achieves 1.67x and 1.26-1.88x speedups over the recent Ring Attention and DeepSpeed-Ulysses, respectively. Code is available at https://github.com/RulinShao/LightSeq.
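The abstract names token-level workload balancing without describing it. The Python sketch below is not the paper's scheduling algorithm. It only illustrates the kind of imbalance such a technique targets. It assumes a causal mask and a sequence split into P contiguous chunks, one per worker, with worker p attending from its query chunk to key chunks 0..p. Under that layout, later workers do more blocks than earlier ones, and if every worker waits for the busiest one, a fraction (P-1)/(2P) of worker-steps is idle. The function names are hypothetical and chosen for illustration.

```python
from fractions import Fraction


def causal_block_counts(num_workers: int) -> list[int]:
    """Query-key chunk pairs each worker computes under a causal mask.

    Assumes (hypothetically) that the sequence is split into num_workers
    contiguous chunks and worker p owns query chunk p, so it attends to
    key chunks 0..p, i.e. p + 1 blocks.
    """
    return [p + 1 for p in range(num_workers)]


def naive_idle_fraction(num_workers: int) -> Fraction:
    """Fraction of worker-steps spent idle when no work is shared.

    Every worker must wait for the busiest one, so the schedule takes
    max(counts) steps while only sum(counts) blocks do useful work.
    """
    counts = causal_block_counts(num_workers)
    useful = sum(counts)
    steps = max(counts)
    return 1 - Fraction(useful, num_workers * steps)


def balanced_steps(num_workers: int) -> int:
    """Lower bound on steps if all blocks were spread evenly over workers."""
    total = sum(causal_block_counts(num_workers))
    return -(-total // num_workers)  # ceiling division


if __name__ == "__main__":
    for p in (2, 4, 8):
        print(p, causal_block_counts(p), naive_idle_fraction(p), balanced_steps(p))
```

For 8 workers the naive schedule takes 8 steps with 7/16 of worker-steps idle, whereas an even spread of the 36 blocks would need at least 5 steps. How the actual system redistributes work across workers is described in the paper and code, not here.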