Abstract
The attention mechanism forms the foundational building block of transformer language models. Recent approaches show that scaling these models achieves human-level performance. However, with increasing demands for scaling and constraints on hardware memory, the inference costs of these models remain high. To reduce inference time, Multi-Query Attention (MQA) and Grouped-Query Attention (GQA) were proposed in (Shazeer, 2019) and (Ainslie et al., 2023), respectively. In this paper, we propose a variation of Grouped-Query Attention, termed Weighted Grouped-Query Attention (WGQA). We introduce new learnable parameters for each key and value head in the T5 decoder attention blocks, enabling the model to take a weighted average of the key and value heads during finetuning. Our model achieves an average improvement of 0.53% over GQA, and its performance converges to that of traditional multi-head attention (MHA) with no additional overhead during inference. We find that introducing these parameters and finetuning them informs the model about the grouping mechanism during training, thereby enhancing performance. Additionally, we demonstrate scaling laws in our analysis by comparing results between the T5-small and T5-base architectures.
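The abstract does not specify the exact form of the per-head parameters, so the following is only a minimal NumPy sketch under stated assumptions: one learnable scalar per key (or value) head, and heads assigned to groups in contiguous blocks. It illustrates why a weighted average of heads can add no inference cost. Because projection is linear, a weighted sum of per-head keys equals a single projection with pre-merged weights, so after finetuning the weights can be folded into matrices of the same shape GQA uses. With uniform weights of `1 / group_size`, the merge reduces to the mean pooling that Ainslie et al. (2023) use to convert MHA checkpoints to GQA. The function `merge_kv_heads`, the shapes, and the random "learned" weights are hypothetical and for illustration only; this is not the authors' implementation.

```python
import numpy as np


def merge_kv_heads(w_heads, head_weights, num_groups):
    """Merge per-head K (or V) projections into grouped projections.

    Assumption (hypothetical, WGQA-style): one scalar weight per head.
    w_heads:      (H, d_model, d_head) per-head projections from an MHA model.
    head_weights: (H,) weight for each head.
    Returns:      (num_groups, d_model, d_head) merged projections.
    """
    num_heads = w_heads.shape[0]
    if num_heads % num_groups != 0:
        raise ValueError("number of heads must be divisible by num_groups")
    group_size = num_heads // num_groups
    # Scale each head's projection by its weight, then sum within each
    # contiguous block of group_size heads.
    scaled = head_weights[:, None, None] * w_heads
    return scaled.reshape(num_groups, group_size, *w_heads.shape[1:]).sum(axis=1)


rng = np.random.default_rng(0)
H, G, d_model, d_head, seq = 8, 2, 16, 4, 5
w_k = rng.standard_normal((H, d_model, d_head))  # per-head key projections
x = rng.standard_normal((seq, d_model))          # a toy input sequence

# GQA-style mean pooling is the special case of uniform weights.
uniform = np.full(H, 1.0 / (H // G))
gqa_k = merge_kv_heads(w_k, uniform, G)

# Stand-in for weights learned during finetuning (random, illustration only).
learned = rng.uniform(0.0, 1.0, H)
wgqa_k = merge_kv_heads(w_k, learned, G)

# Training-time view: project with every head, then take the weighted sum
# of the keys inside each group.
per_head_keys = np.einsum("sd,hde->hse", x, w_k)  # (H, seq, d_head)
weighted = (learned[:, None, None] * per_head_keys).reshape(
    G, H // G, seq, d_head
).sum(axis=1)

# Inference-time view: a single projection per group with the merged weights,
# which costs the same as GQA.
merged = np.einsum("sd,gde->gse", x, wgqa_k)      # (G, seq, d_head)

print(np.allclose(weighted, merged))
```

The final check compares the two views; by linearity they agree, which is the sense in which the extra parameters need not survive into inference. Value heads would be handled the same way with their own weights.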