Out-of-Distribution Generalization via Risk Extrapolation (REx)

  • 2021-02-25 17:53:07
  • David Krueger, Ethan Caballero, Joern-Henrik Jacobsen, Amy Zhang, Jonathan Binas, Dinghuai Zhang, Remi Le Priol, Aaron Courville

Abstract

Distributional shift is one of the major obstacles when transferring machine learning prediction systems from the lab to the real world. To tackle this problem, we assume that variation across training domains is representative of the variation we might encounter at test time, but also that shifts at test time may be more extreme in magnitude. In particular, we show that reducing differences in risk across training domains can reduce a model's sensitivity to a wide range of extreme distributional shifts, including the challenging setting where the input contains both causal and anti-causal elements. We motivate this approach, Risk Extrapolation (REx), as a form of robust optimization over a perturbation set of extrapolated domains (MM-REx), and propose a penalty on the variance of training risks (V-REx) as a simpler variant. We prove that variants of REx can recover the causal mechanisms of the targets, while also providing some robustness to changes in the input distribution ("covariate shift"). By appropriately trading off robustness to causally induced distributional shifts and covariate shift, REx is able to outperform alternative methods such as Invariant Risk Minimization in situations where these types of shift co-occur.
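
The variance penalty mentioned above can be illustrated with a minimal sketch. The abstract does not give the exact objective, so the form used here (sum of per-domain risks plus a weight `beta` times their population variance) is an assumption, and the names `v_rex_objective`, `domain_risks`, and `beta` are hypothetical. Real training would compute each risk from a model's loss on one training domain and minimize the objective by gradient descent.

```python
from typing import Sequence


def v_rex_objective(domain_risks: Sequence[float], beta: float) -> float:
    """Illustrative V-REx-style objective (assumed form, not taken from the abstract).

    domain_risks: average loss of one model on each training domain.
    beta: penalty weight; 0 leaves only the summed risk, while larger
          values push the per-domain risks toward equality.
    """
    m = len(domain_risks)
    if m < 2:
        raise ValueError("need at least two training domains to measure variance")
    mean_risk = sum(domain_risks) / m
    # Population variance of the risks across domains (assumption: divide by m).
    variance = sum((r - mean_risk) ** 2 for r in domain_risks) / m
    return sum(domain_risks) + beta * variance


# Two hypothetical models with the same total risk (4.0) over two domains.
uneven = v_rex_objective([1.0, 3.0], beta=2.0)  # risks differ across domains
even = v_rex_objective([2.0, 2.0], beta=2.0)    # risks are equal across domains
print(uneven, even)
```

With equal total risk, the model whose risks are spread unevenly across domains receives the larger objective value, which is the sense in which the penalty favors reducing differences in risk across training domains.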

 
