InfAlign: Inference-aware language model alignment

  • 2024-12-27 18:45:36
  • Ananth Balashankar, Ziteng Sun, Jonathan Berant, Jacob Eisenstein, Michael Collins, Adrian Hutter, Jong Lee, Chirag Nagpal, Flavien Prost, Aradhana Sinha, Ananda Theertha Suresh, and Ahmad Beirami

Abstract

Language model alignment has become a critical step in training modern generative language models. The goal of alignment is to finetune a reference model such that the win rate of a sample from the aligned model over a sample from the reference model is high, subject to a KL divergence constraint. Today, we are increasingly using inference-time algorithms (e.g., Best-of-N, controlled decoding, tree search) to decode from language models rather than standard sampling. However, the alignment objective does not capture such inference-time decoding procedures. We show that the existing alignment framework is sub-optimal in view of such inference-time methods. We then modify the alignment objective and propose a framework for inference-aware alignment (IAPO). We prove that for any inference-time decoding algorithm, the optimal solution that optimizes the inference-time win rate of the aligned policy against the reference policy is the solution to the typical RLHF problem with a transformation of the reward. This motivates us to provide the KL-regularized calibrate-and-transform RL (CTRL) algorithm to solve this problem, which involves a reward calibration step and a KL-regularized reward maximization step with a transformation of the calibrated reward. We particularize our study to two important inference-time strategies: best-of-N sampling and best-of-N jailbreaking, where N responses are sampled from the model and the one with the highest or lowest reward is selected. We propose specific transformations for these strategies and demonstrate that our framework offers significant improvements over existing state-of-the-art methods for language model alignment. Empirically, we outperform baselines that are designed without taking inference-time decoding into consideration by 8-12% and 4-9% on inference-time win rates over the Anthropic helpfulness and harmlessness dialog benchmark datasets.
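
To make the two inference-time strategies named in the abstract concrete, here is a minimal Python sketch. It is an illustration only, not the authors' implementation. The function names (`best_of_n`, `worst_of_n`, `inference_time_win_rate`), the toy samplers, and the length-based reward are all hypothetical. The sketch also assumes that the same selection procedure is applied to both policies when estimating the win rate and that ties count as half a win; the abstract does not state either choice.

```python
import random
from typing import Callable, Sequence

Reward = Callable[[str], float]
Sampler = Callable[[random.Random], str]
Selector = Callable[[Sequence[str], Reward], str]


def best_of_n(responses: Sequence[str], reward: Reward) -> str:
    """Best-of-N sampling: keep the response with the highest reward."""
    return max(responses, key=reward)


def worst_of_n(responses: Sequence[str], reward: Reward) -> str:
    """Best-of-N jailbreaking: keep the response with the lowest reward."""
    return min(responses, key=reward)


def inference_time_win_rate(
    sample_aligned: Sampler,
    sample_reference: Sampler,
    reward: Reward,
    select: Selector,
    n: int,
    trials: int,
    seed: int = 0,
) -> float:
    """Monte Carlo estimate of how often the aligned policy beats the
    reference policy after both are decoded with the same procedure.

    Assumption (not from the abstract): ties count as half a win.
    """
    rng = random.Random(seed)
    wins = 0.0
    for _ in range(trials):
        # Draw N candidates from each policy, then apply the selector.
        aligned = select([sample_aligned(rng) for _ in range(n)], reward)
        reference = select([sample_reference(rng) for _ in range(n)], reward)
        r_aligned, r_reference = reward(aligned), reward(reference)
        if r_aligned > r_reference:
            wins += 1.0
        elif r_aligned == r_reference:
            wins += 0.5
    return wins / trials


if __name__ == "__main__":
    # Toy stand-ins: the "reward" is response length, and each policy
    # emits a string of random length. Neither comes from the paper.
    toy_reward = lambda text: float(len(text))
    toy_reference = lambda rng: "x" * rng.randint(1, 5)
    toy_aligned = lambda rng: "x" * rng.randint(2, 6)

    rate = inference_time_win_rate(
        toy_aligned, toy_reference, toy_reward, best_of_n, n=4, trials=1000
    )
    print(f"estimated best-of-4 win rate: {rate:.3f}")
```

Passing `worst_of_n` instead of `best_of_n` as `select` estimates the win rate under the jailbreaking strategy with the same harness.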

 
