Squeezed Attention: Accelerating Long Context Length LLM Inference

  • 2024-11-14 18:54:19
  • Coleman Hooper, Sehoon Kim, Hiva Mohammadzadeh, Monishwaran Maheswaran, June Paik, Michael W. Mahoney, Kurt Keutzer, Amir Gholami

Abstract

Emerging Large Language Model (LLM) applications require long input prompts to perform complex downstream tasks like document analysis and code generation. For these long context length applications, the length of the input prompt poses a significant challenge in terms of inference efficiency since the inference costs increase linearly with sequence length. However, for many of these applications, much of the context in the prompt is fixed across different user inputs, thereby providing the opportunity to perform offline optimizations to process user inputs quickly, as they are received. In this work, we propose Squeezed Attention as a mechanism to accelerate LLM applications where a large portion of the input prompt is fixed. We first leverage K-means clustering offline to group the keys for the fixed context based on semantic similarity and represent each cluster with a single centroid value. During inference, we compare query tokens from the user input with the centroids to predict which of the keys from the fixed context are semantically relevant and need to be loaded during inference. We then compute exact attention using only these important keys from the fixed context, thereby reducing bandwidth and computational costs. We also extend our method to use a hierarchical centroid lookup to identify important keys, which can reduce the complexity of attention from linear to logarithmic with respect to the context length. We implement optimized Triton kernels for centroid comparison and sparse FlashAttention with important keys, achieving more than 4x speedups during both the prefill and generation phases for long-context inference. Furthermore, we have extensively evaluated our method on various long-context benchmarks including LongBench, where it achieves a 3x reduction in KV cache budget without accuracy loss and up to an 8x reduction with <0.5 point accuracy gap for various models.
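
The single-level pipeline described in the abstract (offline clustering of the fixed-context keys, centroid comparison at inference time, then exact attention over the selected keys) can be illustrated with a minimal NumPy sketch. This is not the authors' implementation: the function names (kmeans, squeezed_attention), the Euclidean K-means, and the fixed top_clusters selection rule are all assumptions made for illustration, and the sketch handles one query vector for one attention head. The hierarchical lookup and the Triton kernels are not covered.

```python
import numpy as np


def softmax(x):
    """Numerically stable softmax over a 1-D array."""
    e = np.exp(x - x.max())
    return e / e.sum()


def kmeans(keys, num_clusters, num_iters=10, seed=0):
    """Offline step: plain Euclidean K-means over the fixed-context keys.

    Assumption: the abstract only says keys are grouped by semantic
    similarity; the distance metric used here is an illustrative choice.
    """
    rng = np.random.default_rng(seed)
    init = rng.choice(len(keys), size=num_clusters, replace=False)
    centroids = keys[init].copy()

    def assign(c):
        # Squared distance from every key to every centroid, shape (N, K).
        d = ((keys[:, None, :] - c[None, :, :]) ** 2).sum(axis=-1)
        return d.argmin(axis=1)

    for _ in range(num_iters):
        labels = assign(centroids)
        for c in range(num_clusters):
            members = keys[labels == c]
            if len(members):  # leave the centroid of an empty cluster as-is
                centroids[c] = members.mean(axis=0)
    return centroids, assign(centroids)


def squeezed_attention(query, keys, values, centroids, labels, top_clusters):
    """Online step: pick clusters by centroid score, then attend exactly
    over the keys that belong to the chosen clusters.

    Assumption: a fixed number of clusters is kept; this is a stand-in
    for whatever selection criterion the paper actually uses.
    """
    scores = centroids @ query
    # Never select a cluster that ended up with no members.
    counts = np.bincount(labels, minlength=len(centroids))
    scores = np.where(counts > 0, scores, -np.inf)
    chosen = np.argsort(scores)[-top_clusters:]
    keep = np.flatnonzero(np.isin(labels, chosen))

    scale = 1.0 / np.sqrt(query.shape[-1])
    weights = softmax(keys[keep] @ query * scale)
    return weights @ values[keep], keep


# Usage with random stand-in data (not real model activations).
rng = np.random.default_rng(0)
fixed_keys = rng.normal(size=(256, 16))
fixed_values = rng.normal(size=(256, 16))
user_query = rng.normal(size=16)

centroids, labels = kmeans(fixed_keys, num_clusters=8)
output, kept = squeezed_attention(
    user_query, fixed_keys, fixed_values, centroids, labels, top_clusters=2
)
print(f"attended over {len(kept)} of {len(fixed_keys)} keys")
```

In this sketch only the centroids are scored against the query before the selected keys and values are read, which is where the bandwidth saving described above would come from. Setting top_clusters equal to the number of clusters selects every key and reproduces ordinary dense attention.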

 
