CAGE: Curvature-Aware Gradient Estimation For Accurate Quantization-Aware Training

  • 2025-11-10 17:53:51
  • Soroush Tabesh, Mher Safaryan, Andrei Panferov, Alexandra Volkova, Dan Alistarh

Abstract

Despite significant work on low-bit quantization-aware training (QAT), there is still an accuracy gap between such techniques and native training. To address this, we introduce CAGE (Curvature-Aware Gradient Estimation), a new QAT method that augments the straight-through estimator (STE) gradient with a curvature-aware correction designed to counteract the loss increase induced by quantization. CAGE is derived from a multi-objective view of QAT that balances loss minimization with the quantization constraints, yielding a principled correction term that depends on local curvature information. On the theoretical side, we introduce the notion of Pareto-optimal solutions for quantized optimization, and establish that CAGE yields strong convergence guarantees in the smooth non-convex setting. In terms of implementation, our approach is optimizer-agnostic, but we provide a highly efficient implementation that leverages Adam statistics. CAGE significantly improves upon prior state-of-the-art methods in accuracy at similar computational cost: for QAT fine-tuning, it halves the compression accuracy loss relative to the prior best method, while for QAT pre-training of Llama models, its accuracy with 3-bit weights and activations (W3A3) matches the accuracy the prior best method achieves at 4 bits (W4A4). The official implementation is available at https://github.com/IST-DASLab/CAGE.
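The abstract describes the CAGE correction only at a high level, so the sketch below is a hypothetical illustration of the general idea under stated assumptions, not the authors' update rule. It assumes a round-to-nearest symmetric b-bit quantizer with a per-tensor max-abs scale, treats the STE gradient as the loss gradient evaluated at the quantized weights Q(w), and models the curvature-aware correction as lam * D * (w - Q(w)), where the diagonal curvature proxy D is assumed to be sqrt(v_hat) from Adam's bias-corrected second-moment estimate. The names `quantize_symmetric`, `cage_like_gradient`, `lam`, and `v_hat` are illustrative only; the official repository holds the actual implementation.

```python
import numpy as np


def quantize_symmetric(w: np.ndarray, bits: int) -> np.ndarray:
    """Round-to-nearest symmetric uniform quantizer (assumed, not from the paper).

    Uses a single per-tensor scale so that max|w| maps to the largest level.
    Levels are the integers -qmax..qmax times the scale (2*qmax + 1 levels).
    """
    if bits < 2:
        raise ValueError("bits must be >= 2 for a symmetric grid")
    qmax = 2 ** (bits - 1) - 1  # e.g. 3 for bits=3
    scale = float(np.max(np.abs(w))) / qmax
    if scale == 0.0:
        return w.copy()  # an all-zero tensor is already on the grid
    return np.clip(np.round(w / scale), -qmax, qmax) * scale


def cage_like_gradient(
    w: np.ndarray,
    grad_at_q: np.ndarray,
    v_hat: np.ndarray,
    bits: int,
    lam: float,
    eps: float = 1e-8,
) -> np.ndarray:
    """Hypothetical STE gradient plus a curvature-weighted pull toward Q(w).

    grad_at_q : loss gradient evaluated at the quantized weights Q(w); the STE
                passes it through unchanged as the gradient w.r.t. w.
    v_hat     : bias-corrected Adam second-moment estimate; sqrt(v_hat) is used
                here as an assumed diagonal curvature proxy D.
    lam       : hypothetical weight on the quantization-constraint objective.
    """
    q = quantize_symmetric(w, bits)
    ste_grad = grad_at_q                    # straight-through estimator
    curvature = np.sqrt(v_hat) + eps        # assumed diagonal curvature proxy
    # Gradient of (lam / 2) * sum(D * (w - q)**2) with q held fixed.
    correction = lam * curvature * (w - q)
    return ste_grad + correction
```

With `lam=0` the function reduces to plain STE. For `lam > 0`, a descent step along the returned gradient moves each weight toward its quantized value, more strongly where the assumed curvature proxy is large, which is one way to read "counteracting the loss increase induced by quantization".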

 
