ShiftAddViT: Mixture of Multiplication Primitives Towards Efficient Vision Transformer

  • 2024-07-25 18:19:31
  • Haoran You, Huihong Shi, Yipin Guo, Yingyan Celine Lin

Abstract

Vision Transformers (ViTs) have shown impressive performance and have become a unified backbone for multiple vision tasks. However, both the attention mechanism and the multi-layer perceptrons (MLPs) in ViTs are not sufficiently efficient due to dense multiplications, leading to costly training and inference. To this end, we propose to reparameterize pre-trained ViTs with a mixture of multiplication primitives, e.g., bitwise shifts and additions, towards a new type of multiplication-reduced model, dubbed $\textbf{ShiftAddViT}$, which aims to achieve end-to-end inference speedups on GPUs without requiring training from scratch. Specifically, all $\texttt{MatMuls}$ among queries, keys, and values are reparameterized using additive kernels, after mapping queries and keys to binary codes in Hamming space. The remaining MLPs or linear layers are then reparameterized with shift kernels. We utilize TVM to implement and optimize those customized kernels for practical hardware deployment on GPUs. We find that such a reparameterization of attention maintains model accuracy, while inevitably leading to accuracy drops when applied to MLPs. To marry the best of both worlds, we further propose a new mixture of experts (MoE) framework to reparameterize MLPs by taking multiplication or its primitives as experts, e.g., multiplication and shift, and designing a new latency-aware load-balancing loss. Such a loss helps train a generic router that assigns a dynamic number of input tokens to different experts according to their latency. Extensive experiments on various 2D/3D Transformer-based vision tasks consistently validate the effectiveness of our proposed ShiftAddViT, achieving up to $\mathbf{5.18\times}$ latency reductions on GPUs and $\mathbf{42.9\%}$ energy savings, while maintaining accuracy comparable to that of the original or efficient ViTs.
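The two reparameterizations described above can be illustrated with a toy, pure-Python sketch. It is not the authors' TVM kernels, and several details are assumptions made only for illustration: shift weights are obtained by rounding $\log_2|w|$ to an integer, queries and keys are binarized to $\{0, 1\}$ codes by their sign, and attention is computed in a linear-attention order $Q_b (K_b^\top V)$ so that both MatMuls touch only binary matrices. All function names (`to_power_of_two`, `shift_linear`, `binary_matmul`, `additive_linear_attention`) are hypothetical.

```python
import math


def to_power_of_two(w):
    """Hypothetical quantizer: approximate w by sign * 2**p, p = round(log2|w|)."""
    if w == 0:
        return 0, 0
    sign = 1 if w > 0 else -1
    return sign, round(math.log2(abs(w)))


def shift_linear(x, weights):
    """y[j] = sum_i x[i] * w[i][j], with every multiply replaced by a bit shift.

    x: integer activations; weights[i][j]: weight from input i to output j.
    """
    y = []
    for j in range(len(weights[0])):
        acc = 0
        for i, xi in enumerate(x):
            sign, p = to_power_of_two(weights[i][j])
            if sign == 0:
                continue
            term = xi << p if p >= 0 else xi >> -p  # shift instead of multiply
            acc = acc + term if sign > 0 else acc - term  # sign handled by add/sub
        y.append(acc)
    return y


def binarize(rows):
    """Map real-valued query/key rows to {0, 1} codes by sign (an assumption)."""
    return [[1 if v >= 0 else 0 for v in row] for row in rows]


def transpose(m):
    return [list(col) for col in zip(*m)]


def binary_matmul(bits, mat):
    """bits @ mat for a {0, 1} matrix: each output row sums the selected rows of mat."""
    out = []
    for row in bits:
        acc = [0] * len(mat[0])
        for b, m_row in zip(row, mat):
            if b:
                acc = [a + m for a, m in zip(acc, m_row)]  # additions only
        out.append(acc)
    return out


def additive_linear_attention(q, k, v):
    """Linear attention Qb (Kb^T V), normalized row-wise; MatMuls use additions only."""
    qb, kb = binarize(q), binarize(k)
    kb_t = transpose(kb)
    kv = binary_matmul(kb_t, v)                       # d x d_v
    num = binary_matmul(qb, kv)                       # n x d_v
    ones = [[1] for _ in v]
    den = binary_matmul(qb, binary_matmul(kb_t, ones))  # n x 1 normalizer
    # One division per output element remains for the normalization.
    return [[x / d[0] if d[0] else 0.0 for x in row] for row, d in zip(num, den)]
```

For example, `shift_linear([3, 5], [[2.0, 0.5], [1.0, -4.0]])` returns `[11, -19]` where exact multiplication gives `[11, -18.5]`; the truncation in `3 >> 1` is one plausible source of the accuracy loss that shift kernels can introduce in MLPs, whereas the binary attention path is exact with respect to its binarized codes.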

 
