Abstract
We propose SWA-Gaussian (SWAG), a simple, scalable, and general purpose approach for uncertainty representation and calibration in deep learning. Stochastic Weight Averaging (SWA), which computes the first moment of stochastic gradient descent (SGD) iterates with a modified learning rate schedule, has recently been shown to improve generalization in deep learning. With SWAG, we fit a Gaussian using the SWA solution as the first moment and a low rank plus diagonal covariance also derived from the SGD iterates, forming an approximate posterior distribution over neural network weights; we then sample from this Gaussian distribution to perform Bayesian model averaging. We empirically find that SWAG approximates the shape of the true posterior, in accordance with results describing the stationary distribution of SGD iterates. Moreover, we demonstrate that SWAG performs well on a wide variety of computer vision tasks, including out of sample detection, calibration, and transfer learning, in comparison to many popular alternatives including MC dropout, KFAC Laplace, and temperature scaling.
A Simple Baseline for Bayesian Uncertainty in Deep Learning
Wesley Maddox*, Timur Garipov*, Pavel Izmailov*, Dmitry Vetrov, Andrew Gordon Wilson
Ultimately, machine learning models are used to make decisions. Representing uncertainty is crucial for decision making. For example, in medical diagnoses and autonomous vehicles we want to protect against rare but costly mistakes. Deep learning models typically lack a representation of uncertainty, and provide overconfident and miscalibrated predictions (e.g., Kendall & Gal, 2017; Guo et al., 2017).
Bayesian methods provide a natural probabilistic representation of uncertainty in deep learning (e.g., Blundell et al., 2015; Kingma et al., 2015b; Chen et al., 2014), and previously had been a gold standard for inference with neural networks (Neal, 1996). However, existing approaches are often highly sensitive to hyperparameter choices, and hard to scale to modern datasets and architectures, which limits their general applicability in modern deep learning.
In this paper we propose a different approach to Bayesian deep learning: we use the information contained in the SGD trajectory to efficiently approximate the posterior distribution over the weights of the neural network. We find that the Gaussian distribution fitted to the first two moments of SGD iterates, with a modified learning rate schedule, captures the local geometry of the posterior surprisingly well. Using this Gaussian distribution we are able to obtain convenient, efficient, accurate and well-calibrated predictions in a broad range of tasks in computer vision. In particular, our contributions are the following:

• We propose SWAG (SWA-Gaussian), a scalable approximate Bayesian inference technique for deep learning. SWAG builds on Stochastic Weight Averaging (Izmailov et al., 2018), which computes an average of SGD iterates with a high constant learning rate schedule, to provide improved generalization in deep learning. SWAG additionally computes a low-rank plus diagonal approximation to the covariance of the iterates, which is used together with the SWA mean to define a Gaussian posterior approximation over neural network weights.

• SWAG is motivated by the theoretical analysis of the stationary distribution of SGD iterates (e.g., Mandt et al., 2017; Chen et al., 2016), which suggests that the SGD trajectory contains useful information about the geometry of the posterior. In Section id1 we show that the assumptions of Mandt et al. (2017) do not hold for deep neural networks, due to non-convexity and over-parameterization. However, we find that in the subspace spanned by SGD iterates the shape of the posterior distribution is approximately Gaussian. Further, SWAG is able to capture the geometry of this posterior remarkably well.

• In a thorough empirical evaluation we show that SWAG can provide well-calibrated uncertainty estimates for neural networks across many settings in computer vision. In particular, SWAG achieves higher test likelihood compared to many state-of-the-art approaches, including MC-Dropout (Gal & Ghahramani, 2016), temperature scaling (Guo et al., 2017), KFAC-Laplace (Ritter et al., 2018b) and SWA (Izmailov et al., 2018) on CIFAR-10, CIFAR-100 and ImageNet, on a range of architectures. We also demonstrate the effectiveness of SWAG for out-of-domain detection and transfer learning.

• We release PyTorch code at github.com/wjmaddox/swa_gaussian
Bayesian approaches represent uncertainty by placing a distribution over model parameters, and then marginalizing these parameters to form a whole predictive distribution, in a procedure known as Bayesian model averaging. In the late 1990s, Bayesian methods were the state-of-the-art approach to learning with neural networks, through the seminal works of Neal (1996) and MacKay (1992a). However, modern neural networks often contain millions of parameters, the posterior over these parameters (and thus the loss surface) is highly non-convex, and mini-batch approaches are often needed to move to a space of good solutions (Keskar et al., 2017). For these reasons, Bayesian approaches have largely been intractable for modern neural networks. Here, we review several modern approaches to Bayesian deep learning.
Markov chain Monte Carlo (MCMC) was at one time a gold standard for inference with neural networks, through the Hamiltonian Monte Carlo (HMC) work of Neal (1996). However, HMC requires full gradients, which is computationally intractable for modern neural networks. To extend the HMC framework, stochastic gradient HMC (SGHMC) was introduced by Chen et al. (2014); it allows stochastic gradients to be used in Bayesian inference, which is crucial both for scalability and for exploring a space of solutions that provide good generalization. While SGHMC has been successfully applied in Bayesian deep learning (e.g., Saatci & Wilson, 2017), tuning SGHMC can be quite difficult, and many settings of SGHMC parameters are quite similar to standard SGD.
Graves (2011) suggested fitting a Gaussian variational posterior approximation over the weights of neural networks. This technique was generalized by Kingma & Welling (2013), who proposed the reparameterization trick for training deep latent variable models; multiple variational inference methods based on the reparameterization trick were proposed for DNNs (e.g., Kingma et al., 2015a; Blundell et al., 2015; Molchanov et al., 2017; Louizos & Welling, 2017). While variational methods achieve strong performance for moderately sized networks, they can be difficult to train on larger architectures such as deep residual networks (He et al., 2016), due to the looseness of the optimization objective (Blier & Ollivier, 2018).
An alternative line of work reinterprets noisy versions of optimization algorithms, for example noisy Adam (Khan et al., 2018) and noisy KFAC (Zhang et al., 2017), as approximate variational inference.
Gal & Ghahramani (2016) used a spike and slab variational distribution to view dropout at test time as approximate variational Bayesian inference. Concrete dropout (Gal et al., 2017) extends this idea to optimize the dropout probabilities as well. From a practical perspective, these approaches are quite appealing as they only require ensembling dropout predictions at test time, and they were successfully applied to several downstream tasks (Kendall & Gal, 2017; Mukhoti & Gal, 2018).
Laplace approximations assume a Gaussian posterior, $\mathcal{N}(\theta^{*},\mathcal{I}(\theta^{*})^{-1}),$ where $\theta^{*}$ is a MAP estimate and $\mathcal{I}(\theta^{*})^{-1}$ is the inverse of the Fisher information matrix (expected value of the Hessian evaluated at $\theta^{*}$). The Laplace approximation was notably used for Bayesian neural networks in MacKay (1992b), where a diagonal approximation to the inverse of the Hessian was utilized for computational reasons. More recently, Kirkpatrick et al. (2017) proposed to use diagonal Laplace approximations to overcome catastrophic forgetting in deep learning. Ritter et al. (2018b) proposed the use of either a diagonal or block Kronecker factored (KFAC) approximation to the Hessian matrix for Laplace approximations, and Ritter et al. (2018a) successfully applied the KFAC approach to online learning situations.
Alternatively, Lakshminarayanan et al. (2017) proposed using ensembles of several networks for enhanced calibration, and incorporated an adversarial loss function to be used when possible as well. Outside of probabilistic neural networks, Guo et al. (2017) proposed temperature scaling, a procedure which uses a validation set and a single hyperparameter to rescale the logits of DNN outputs for enhanced calibration. Kuleshov et al. (2018) propose calibrated regression using a similar rescaling technique.
In this section we propose SWA-Gaussian (SWAG) for Bayesian model averaging and uncertainty estimation. In Section id1, we review stochastic weight averaging (SWA) (Izmailov et al., 2018), which we view as estimating the mean of the stationary distribution of SGD iterates. We then propose SWA-Gaussian in Sections id1 and id1 to estimate the covariance of the stationary distribution, forming a Gaussian approximation to the posterior over weight parameters. With SWAG, uncertainty in weight space is captured with minimal modifications to the SWA training procedure. We then present further theoretical and empirical analysis for SWAG in Section id1.
The main idea of SWA (Izmailov et al., 2018) is to run SGD with a constant learning rate schedule starting from a pre-trained solution, and to average the weights of the models it traverses. Denoting the weights of the network obtained after epoch $i$ of SWA training by $\theta_{i}$, the SWA solution after $T$ epochs is given by $\theta_{\text{SWA}}=\frac{1}{T}\sum_{i=1}^{T}\theta_{i}.$ A high constant learning rate schedule ensures that SGD explores the set of possible solutions instead of simply converging to a single point in the weight space. Izmailov et al. (2018) argue that conventional SGD training converges to the boundary of the set of high-performing solutions; SWA, on the other hand, is able to find a more centered solution that is robust to the shift between train and test distributions, leading to improved generalization performance. SWA and related ideas have been successfully applied to a wide range of applications (see e.g. Athiwaratkun et al., 2019; Yazici et al., 2019; Nikishin et al., 2018).
A related but different procedure is Polyak-Ruppert averaging (Polyak & Juditsky, 1992; Ruppert, 1988) in stochastic convex optimization, which uses a learning rate decaying to zero. Burn-in periods and thinning are known to accelerate the convergence of Polyak-Ruppert averaging (Rakhlin et al., 2011), which provides some evidence for the effectiveness of SWA in practice.
We first consider a simple diagonal format for the covariance matrix. Diagonal covariance matrices are standard for variational inference and Laplace approximations in Bayesian deep learning (e.g., Blundell et al., 2015; Kirkpatrick et al., 2017). In order to fit a diagonal covariance approximation, we maintain a running average of the second uncentered moment for each weight, and then compute the covariance using the following standard identity at the end of training:
$$\overline{\theta^{2}}=\frac{1}{T}\sum_{i=1}^{T}\theta_{i}^{2},\quad \Sigma_{\text{diag}}=\text{diag}\left(\overline{\theta^{2}}-\theta_{\text{SWA}}^{2}\right),$$
where the squares in $\theta_{\text{SWA}}^{2}$ and $\theta_{i}^{2}$ are applied elementwise. The resulting approximate posterior distribution is then $\mathcal{N}(\theta_{\text{SWA}},\Sigma_{\text{diag}}).$ In our experiments, we term this method SWAG-Diagonal.
Constructing the SWAG-Diagonal posterior approximation requires storing two additional copies of DNN weights: $\theta_{\text{SWA}}$ and $\overline{\theta^{2}}$. Note that these models do not have to be stored on the GPU. The additional computational complexity of constructing SWAG-Diagonal compared to standard training is negligible, as it only requires updating the running averages of weights once per epoch.
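The running averages described above can be sketched in a few lines of NumPy. This is a minimal illustration rather than the released implementation: the class name `DiagonalMoments`, the flattened weight vectors, the synthetic stand-in iterates, and the small clamp `eps` that guards against negative variances from round-off are all assumptions made for the example.

```python
import numpy as np


class DiagonalMoments:
    """Running first and second uncentered moments of flattened weights."""

    def __init__(self, num_params):
        self.n = 0                            # number of iterates seen so far
        self.mean = np.zeros(num_params)      # running estimate of theta_SWA
        self.sq_mean = np.zeros(num_params)   # running estimate of mean(theta^2)

    def update(self, theta):
        # Incremental average: new = (n * old + x) / (n + 1).
        self.mean = (self.n * self.mean + theta) / (self.n + 1)
        self.sq_mean = (self.n * self.sq_mean + theta ** 2) / (self.n + 1)
        self.n += 1

    def variance(self, eps=1e-30):
        # Sigma_diag = mean(theta^2) - theta_SWA^2 (elementwise).
        # The clamp is an assumption added here to absorb round-off.
        return np.maximum(self.sq_mean - self.mean ** 2, eps)


# Synthetic stand-in for one weight vector collected per epoch.
rng = np.random.default_rng(0)
iterates = rng.normal(loc=1.0, scale=0.5, size=(50, 4))

moments = DiagonalMoments(num_params=4)
for theta in iterates:
    moments.update(theta)

theta_swa = moments.mean
sigma_diag = moments.variance()
```

Only two vectors of the same size as the model are kept, which matches the storage cost stated above.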
While the diagonal covariance approximation is standard in Bayesian deep learning, it can be too restrictive. We extend the idea of SWAG-Diagonal to use a more flexible low-rank plus diagonal form of the approximate covariance.
Note that the sample covariance matrix can be written as the sum of outer products, $\Sigma=\frac{1}{T-1}\sum_{i=1}^{T}(\theta_{i}-\theta_{\text{SWA}})(\theta_{i}-\theta_{\text{SWA}})^{\top}$, and is of rank $T$. As we do not have access to the value of $\theta_{\text{SWA}}$ during training, we approximate the sample covariance with $\Sigma\approx\frac{1}{T-1}\sum_{i=1}^{T}(\theta_{i}-\overline{\theta}_{i})(\theta_{i}-\overline{\theta}_{i})^{\top}=\frac{1}{T-1}DD^{\top},$ where $D$ is the deviation matrix comprised of columns $D_{i}=(\theta_{i}-\overline{\theta}_{i}),$ and $\overline{\theta}_{i}$ is the running estimate of the parameters' mean obtained from the first $i$ samples. In practice, to limit the rank of the estimated covariance matrix, we only use the last $K$ of the $D_{i}$ vectors, corresponding to the last $K$ epochs of training. Here $K$ is the rank of the resulting approximation and is a hyperparameter of the method. We define $\widehat{D}$ to be the matrix with columns equal to $D_{i}$ for $i=T-K+1,\dots,T$.
We then combine the resulting low-rank approximation $\Sigma_{\text{low-rank}}=\frac{1}{K-1}\cdot\widehat{D}\widehat{D}^{\top}$ with the diagonal approximation $\Sigma_{\text{diag}}$ of section id1. The resulting approximate posterior distribution is a Gaussian with the SWA mean $\theta_{\text{SWA}}$ and summed covariance: $\mathcal{N}(\theta_{\text{SWA}},\frac{1}{2}\cdot(\Sigma_{\text{diag}}+\Sigma_{\text{low-rank}}))$. (Footnote 1: We use one half as the scale here because both the diagonal and low rank terms include the variance of the weights. We tested several other scales in Appendix id1.) In our experiments, we term this method SWAG. Computing this approximate posterior distribution requires storing $K$ vectors $D_{i}$ of the same size as the model as well as the vectors $\theta_{\text{SWA}}$ and $\overline{\theta^{2}}$. These models do not have to be stored on a GPU.
Related methods for estimating the covariance of SGD iterates were considered in Mandt et al. (2017) and Chen et al. (2016), but store full-rank covariance $\Sigma$ and thus scale quadratically in the number of parameters, which is prohibitively expensive for deep learning applications. We additionally note that using the deviation matrix for online covariance matrix estimation comes from viewing the online updates used in Dasgupta & Hsu (2007) in matrix fashion.
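As a rough sketch of the bookkeeping for the low-rank term, the snippet below keeps only the last $K$ deviations from the running mean. The function name `collect_deviations`, the use of a bounded `deque`, and the random stand-in iterates are assumptions for illustration; the released code may organize this differently.

```python
from collections import deque

import numpy as np


def collect_deviations(iterates, K):
    """Return the running mean and the (d, K) matrix of the last K deviations."""
    mean = np.zeros_like(iterates[0], dtype=float)
    deviations = deque(maxlen=K)  # older columns are dropped automatically
    for i, theta in enumerate(iterates, start=1):
        mean = mean + (theta - mean) / i   # running mean of the first i iterates
        deviations.append(theta - mean)    # D_i = theta_i - running mean
    D_hat = np.stack(deviations, axis=1)
    return mean, D_hat


# Synthetic stand-in: 12 "epochs" of a 5-parameter model, rank K = 4.
rng = np.random.default_rng(0)
iterates = rng.normal(size=(12, 5))
K = 4

theta_swa, D_hat = collect_deviations(iterates, K)
sigma_low_rank = D_hat @ D_hat.T / (K - 1)  # formed explicitly only for this toy size
```

For a real network the $d \times d$ matrix would never be materialized; only the $d \times K$ matrix $\widehat{D}$ is stored.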
We sample from SWAG-Diagonal by taking
$$\tilde{\theta}=\theta_{\text{SWA}}+\Sigma_{\text{diag}}^{\frac{1}{2}}z_{1},\quad z_{1}\sim\mathcal{N}(0,I_{d}),$$
where $d$ is the number of parameters in the network. Note that ${\mathrm{\Sigma}}_{\text{diag}}$ is diagonal, and the product ${\mathrm{\Sigma}}_{\text{diag}}^{\frac{1}{2}}{z}_{1}$ can be computed in $\mathcal{O}(d)$ time. Similarly, to sample from SWAG we use the following identity
$$\tilde{\theta}=\theta_{\text{SWA}}+\frac{1}{\sqrt{2}}\cdot\Sigma_{\text{diag}}^{\frac{1}{2}}z_{1}+\frac{1}{\sqrt{2(K-1)}}\widehat{D}z_{2},\quad z_{1}\sim\mathcal{N}(0,I_{d}),\ z_{2}\sim\mathcal{N}(0,I_{K}).$$  (1)
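Equation 1 can be checked numerically on a toy problem. The sketch below, with the hypothetical helper `sample_swag` and made-up values for $\theta_{\text{SWA}}$, $\Sigma_{\text{diag}}$ and $\widehat{D}$, draws many samples and compares their empirical covariance with $\frac{1}{2}(\Sigma_{\text{diag}}+\frac{1}{K-1}\widehat{D}\widehat{D}^{\top})$.

```python
import numpy as np


def sample_swag(theta_swa, sigma_diag, D_hat, rng):
    """Draw one weight sample following Equation 1."""
    d, K = D_hat.shape
    z1 = rng.standard_normal(d)
    z2 = rng.standard_normal(K)
    diag_term = np.sqrt(sigma_diag) * z1 / np.sqrt(2.0)     # O(d) work
    low_rank_term = D_hat @ z2 / np.sqrt(2.0 * (K - 1))     # O(dK) work
    return theta_swa + diag_term + low_rank_term


# Toy values, chosen only for illustration.
rng = np.random.default_rng(0)
d, K = 6, 3
theta_swa = np.zeros(d)
sigma_diag = np.full(d, 0.04)
D_hat = rng.normal(scale=0.1, size=(d, K))

samples = np.stack(
    [sample_swag(theta_swa, sigma_diag, D_hat, rng) for _ in range(20000)]
)
empirical_cov = np.cov(samples, rowvar=False)
target_cov = 0.5 * (np.diag(sigma_diag) + D_hat @ D_hat.T / (K - 1))
```

Because $z_1$ and $z_2$ are independent, the two noise terms add their covariances, which is why the sample never requires forming or factorizing a $d \times d$ matrix.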
In both cases, Bayesian predictions on the testing data, ${y}^{*},$ using the approximate posterior, $q$, can now be made using Bayesian model averaging:
$$p(y^{*}\mid\text{Data})=\mathbb{E}_{p(\theta\mid\text{Data})}\left(p(y^{*}\mid\theta)\right)\approx\frac{1}{S}\sum_{i=1}^{S}p(y^{*}\mid\tilde{\theta}_{i}),\quad\tilde{\theta}_{i}\sim q(\theta),$$
where $q$ can be either the SWAG or SWAG-Diagonal approximate posterior distribution, and $S$ is the number of samples. The full Bayesian model averaging procedure for SWAG is described in Algorithm 1. We note that like Izmailov et al. (2018) (SWA) we update the batch normalization statistics after sampling weights for models that use batch normalization (Ioffe & Szegedy, 2015); an investigation into the necessity of this update is in Appendix id1.
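The model average itself is just a mean of predictive probabilities over sampled weights. The following sketch uses a toy linear softmax classifier; `linear_logits`, `bayesian_model_average` and the isotropic perturbations standing in for draws from $q$ are assumptions for the example, and the batch normalization update mentioned above is omitted because the toy model has no such layers.

```python
import numpy as np


def softmax(logits):
    shifted = logits - logits.max(axis=-1, keepdims=True)  # numerical stability
    exp = np.exp(shifted)
    return exp / exp.sum(axis=-1, keepdims=True)


def bayesian_model_average(predict_logits, weight_samples, x):
    """Average class probabilities (not logits) over S weight samples."""
    probs = [softmax(predict_logits(theta, x)) for theta in weight_samples]
    return np.mean(probs, axis=0)


def linear_logits(theta, x):
    # Toy model: theta is a (features, classes) matrix.
    return x @ theta


rng = np.random.default_rng(0)
theta_mean = rng.normal(size=(5, 3))
# Stand-in for S = 30 draws from the approximate posterior q.
weight_samples = [theta_mean + 0.1 * rng.standard_normal((5, 3)) for _ in range(30)]
x = rng.normal(size=(8, 5))

bma_probs = bayesian_model_average(linear_logits, weight_samples, x)
```

Averaging probabilities rather than logits is what makes this a Monte Carlo estimate of the predictive distribution above.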
Maximum a-posteriori (MAP) optimization is a procedure whereby one maximizes the (log) posterior with respect to parameters $\theta$:
$$\log p(\theta\mid\mathcal{D})=\log p(\mathcal{D}\mid\theta)+\log p(\theta).$$  (2)
Here, the prior $p(\theta)$ is viewed as a regularizer in optimization. However, MAP is not Bayesian inference, since one only considers a single setting of the parameters $\widehat{\theta}_{\text{MAP}}=\text{argmax}_{\theta}\,p(\theta\mid\mathcal{D})$ in making predictions, forming $p(y_{*}\mid\widehat{\theta}_{\text{MAP}},x_{*})$, where $x_{*}$ and $y_{*}$ are test inputs and outputs.
A Bayesian procedure instead marginalizes the posterior distribution over $\theta $, in a Bayesian model average, for the unconditional predictive distribution:
$$p(y_{*}\mid\mathcal{D},x_{*})=\int p(y_{*}\mid\theta,x_{*})\,p(\theta\mid\mathcal{D})\,d\theta.$$  (3)
In practice, this integral is computed through a Monte Carlo sampling procedure:
$$p(y_{*}\mid\mathcal{D},x_{*})\approx\frac{1}{T}\sum_{t=1}^{T}p(y_{*}\mid\theta_{t},x_{*}),\quad\theta_{t}\sim p(\theta\mid\mathcal{D}).$$  (4)
We emphasize that we are performing full Bayesian inference, rather than MAP optimization. We develop a Gaussian approximation to the posterior from SGD iterates, $p(\theta\mid\mathcal{D})\approx\mathcal{N}(\theta;\mu,\Sigma)$, and then sample from this posterior distribution to perform a Bayesian model average. In our procedure, optimization with different regularizers, to characterize the Gaussian posterior approximation, corresponds to fully Bayesian inference with different priors $p(\theta)$.
Typically, weight decay is used to regularize DNNs, corresponding to explicit L2 regularization when SGD without momentum is used to train the model. When SGD is used with momentum, as is typically the case, implicit regularization still occurs, producing a vague prior on the weights of the DNN in our procedure. This regularizer can be given an explicit Gaussian-like form (see Proposition 3 of Loshchilov & Hutter (2019)), corresponding to a prior distribution on the weights. Thus, SWAG is an approximate Bayesian inference algorithm in our experiments (see Section id1) and can be applied to most DNNs without any modifications of the training procedure (as long as SGD is used with weight decay or explicit L2 regularization). Alternative regularization techniques could also be used, producing different priors on the weights.
Standard training of deep neural networks (DNNs) proceeds by applying stochastic gradient descent on the model weights $\theta $ with the following update rule:
$$\Delta\theta_{t}=\eta\left(\frac{1}{B}\sum_{i=1}^{B}\nabla_{\theta}\log p(y_{i}\mid f_{\theta}(x_{i}))+\frac{\nabla_{\theta}\log p(\theta)}{N}\right),$$
where the learning rate is $\eta$, the $i$-th input (e.g. image) and label are $\{x_{i},y_{i}\}$, the size of the whole training set is $N$, the size of the batch is $B$, and the DNN, $f$, has weight parameters $\theta$. (Footnote 2: We ignore momentum for simplicity in this update; however the following theoretical results can be extended to include momentum.) The loss function is a negative log likelihood $-\sum_{i}\log p(y_{i}\mid f_{\theta}(x_{i})),$ combined with a regularizer $-\log p(\theta)$. This type of maximum likelihood training does not represent uncertainty in the predictions or parameters $\theta$.
Under conditions of decaying learning rates, smoothness of gradients, and the existence of a full rank stationary distribution, martingale-based analyses of stochastic gradient descent (e.g., Asmussen & Glynn, 2007, Chapter 8) show that SGD has a Gaussian limiting distribution. That is, in the infinite time step limit, $t^{1/2}(\theta_{t}-\theta^{*})\to\mathcal{N}(0,R),$ where $R$ is some covariance matrix and $\theta^{*}$ is a stationary point or minimum. Mandt et al. (2017) and Chen et al. (2016) are additional examples of this style of analysis. We focus on the assumptions of Mandt et al. (2017) in the next section. We note that these are essentially the same conditions as for the Bernstein-von Mises Theorem (e.g., Vaart, 1998, Chapter 10), which shows that the posterior is also asymptotically Gaussian.
In this section, we investigate the results of Mandt et al. (2017) in the context of deep learning. Mandt et al. (2017) use the following assumptions:

1. Gradient noise at each point $\theta$ is $\mathcal{N}(0,C)$.

2. $C$ is independent of $\theta$ and full rank.

3. The learning rates, $\eta$, are small enough that we can approximate the dynamics of SGD by a continuous-time dynamic described by the corresponding stochastic differential equation.

4. In the stationary distribution, the loss is approximately quadratic near the optima, i.e. approximately $(\theta-\theta^{*})^{\top}\mathbb{H}(\theta^{*})(\theta-\theta^{*}),$ where $\mathbb{H}(\theta^{*})$ is the Hessian at the optimum; further, the Hessian is assumed to be positive definite.
Assumption 1 is motivated by the central limit theorem, and Assumption 3 is necessary for the analysis in Mandt et al. (2017). Assumptions 2 and 4 may or may not hold for deep neural networks (as well as other models). Under these assumptions, Theorem 1 of Mandt et al. (2017) derives the optimal constant learning rate that minimizes the KL-divergence between the SGD stationary distribution and the posterior. (Footnote 3: An optimal diagonal preconditioner is also derived; our empirical work applies to that setting as well. A similar analysis with momentum holds as well, adding in only the momentum coefficient.) The optimal rate is:
$${\eta}^{*}=2\frac{B}{N}\frac{d}{tr(C)},$$  (5) 
where $N$ is the size of the dataset, $d$ is the dimension of the model, $B$ is the minibatch size and $C$ is the gradient noise covariance.
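To make the scaling in Equation 5 concrete, here is a small helper that evaluates it. The function name `optimal_constant_lr` and the input values are hypothetical and chosen only so the arithmetic is easy to follow; they are not measurements from the experiments.

```python
def optimal_constant_lr(batch_size, dataset_size, num_params, trace_noise_cov):
    """Equation 5: eta* = 2 * (B / N) * (d / tr(C))."""
    return 2.0 * (batch_size / dataset_size) * (num_params / trace_noise_cov)


# Hypothetical inputs: B = 100, N = 1000, d = 10, tr(C) = 2.
eta_star = optimal_constant_lr(100, 1000, 10, 2.0)

# Halving tr(C) doubles the predicted rate; a smaller effective d lowers it.
eta_small_noise = optimal_constant_lr(100, 1000, 10, 1.0)
eta_low_rank = optimal_constant_lr(100, 1000, 5, 2.0)
```

The rate grows linearly in $d$ and inversely in $tr(C)$, so a very large parameter count combined with a small total gradient noise variance pushes the predicted value far above learning rates used in practice.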
We computed Equation 5 over the course of training for two neural networks in Figure A.3(a), finding that the predicted optimal learning rate was orders of magnitude larger than what would be used in practice to explore the loss surface in a reasonable time (about $4\times {10}^{5}$ compared to $0.1$). See Appendix id1 for further details.
We now focus on seeing how Assumptions 2 and 4 fail for DNNs; this will give further insight into what portions of the theory do hold, and may give insights into a corrected version of the optimal learning rate.
In Figure A.3(b), the trace of the gradient noise covariance and thus the optimal learning rates are nearly constant; however, the total variance is much too small to induce effective learning rates, probably due to over-parameterization effects inducing non-full-rank gradient covariances as was found in Chaudhari & Soatto (2018). We note that this experiment is not sufficient to be fully confident that $C$ is independent of the parameterization near the local optima, but rather that $tr(C)$ is close to constant; further experiments in this vein are necessary to test if the diagonals of $C$ are constant. The result that $tr(C)$ is close to constant suggests that a constant learning rate could be used for sampling in a stationary phase of training. The dimensionality parameter in Equation 5 could be modified to use the number of effective parameters or the rank of the gradient noise to reduce the optimal learning rate to a feasible number.
To test Assumption 4, we used a GPU-enabled Lanczos method from GPyTorch (Gardner et al., 2018) and used restarting to compute the minimum eigenvalue of the Hessian of the train loss of a pre-trained PreResNet-164 on CIFAR-100. We found that even at the end of training, the minimum eigenvalue was $-272$ (the maximum eigenvalue was $3580$ for comparison), indicating that the Hessian is not positive definite. This result harmonizes with other work analyzing the spectra of the Hessian for DNN training (Li et al., 2018; Sagun et al., 2018). Further, Garipov et al. (2018) and Draxler et al. (2018) argue that the loss surfaces of DNNs have directions along which the loss is completely flat, suggesting that the loss is nowhere near a positive-definite quadratic form.
As we have seen in Section id1, the theoretical results of Mandt et al. (2017) may not directly apply to DNNs. This is slightly different from our intuition about the problem; we expect the SGD trajectory to contain information about the shape of the loss surface (and thus also the posterior distribution), because the loss surface informs how we step from one point to the next. In this section we empirically verify that SWAG can capture the local geometry of the posterior distribution.
In order to analyze the quality of the SWAG approximation, we study the posterior density along the directions corresponding to the eigenvectors of the SWAG covariance matrix for PreResNet-164 on CIFAR-100. In order to find these eigenvectors we use randomized SVD (Halko et al., 2011). (Footnote 4: From sklearn.decomposition.TruncatedSVD.) In the left panel of Figure 1 we visualize the $\ell_{2}$-regularized cross-entropy loss $L(\cdot)$ (equivalent to the joint density of the weights and the loss with a Gaussian prior) as a function of distance $t$ from the SWA solution $\theta_{\text{SWA}}$ along the $i$-th eigenvector $v_{i}$ of the SWAG covariance:
$$\varphi (t)=L\left({\theta}_{\text{SWA}}+t\cdot \frac{{v}_{i}}{\parallel {v}_{i}\parallel}\right).$$ 
As we can see in the left panel of Figure 1, there is a clear correlation between the variance of the SWAG approximation and the width of the posterior along the directions $v_{i}$. The SGD iterates indeed contain useful information about the shape of the posterior distribution, and SWAG is able to capture this information. We repeated the same experiment for SWAG-Diagonal, finding that there was almost no variance in these eigen-directions. Next, we plot the posterior density surface in the 2-dimensional plane in the weight space spanning the two top eigenvectors $v_{1}$ and $v_{2}$ of the SWAG covariance:
$$\psi ({t}_{1},{t}_{2})=L\left({\theta}_{\text{SWA}}+{t}_{1}\cdot \frac{{v}_{1}}{\parallel {v}_{1}\parallel}+{t}_{2}\cdot \frac{{v}_{2}}{\parallel {v}_{2}\parallel}\right).$$ 
We plot the results in the middle panel of Figure 1. Again, we can see that SWAG is able to capture the geometry of the posterior. The contours of constant posterior density appear to be remarkably well aligned with the eigenvalues of the SWAG covariance. We also present the analogous plot for the third and fourth top eigenvectors in the right panel of Figure 1. In Appendix id1, we additionally present similar results for PreResNet-164 on CIFAR-10 and VGG-16 on CIFAR-100.
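The one-dimensional slice $\varphi(t)$ is straightforward to compute once a loss function and a direction are available. The sketch below is a toy version under several assumptions: the loss is a hand-picked quadratic (`toy_loss`), the deviation matrix is random, and the top direction is obtained with an exact `np.linalg.svd` rather than the randomized SVD used in the experiments.

```python
import numpy as np


def loss_along_direction(loss_fn, theta_swa, v, ts):
    """Evaluate phi(t) = L(theta_SWA + t * v / ||v||) for each t in ts."""
    unit = v / np.linalg.norm(v)
    return np.array([loss_fn(theta_swa + t * unit) for t in ts])


# Toy quadratic loss with very different curvature along each axis.
A = np.diag([10.0, 1.0, 0.1])


def toy_loss(theta):
    return 0.5 * theta @ A @ theta


rng = np.random.default_rng(0)
theta_swa = np.zeros(3)
D_hat = rng.normal(size=(3, 2))  # stand-in deviation matrix, rank K = 2

# Left singular vectors of D_hat are eigenvectors of D_hat @ D_hat.T.
U, singular_values, _ = np.linalg.svd(D_hat, full_matrices=False)
v1 = U[:, 0]

ts = np.linspace(-1.0, 1.0, 21)
phi = loss_along_direction(toy_loss, theta_swa, v1, ts)
```

The two-dimensional surface $\psi(t_1, t_2)$ follows the same pattern with a grid over two unit directions instead of a single one.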
As we can see, SWAG is able to capture the geometry of the posterior in the subspace spanned by SGD iterates. However, the dimensionality of this subspace is very low compared to the dimensionality of the weight space, and we cannot guarantee that SWAG variance estimates are adequate along all directions in weight space. In particular, we would expect SWAG to under-estimate the variances along random directions, as the SGD trajectory is in a low-dimensional subspace of the weight space, and a random vector has a close-to-zero projection on this subspace with high probability.
We conduct a thorough empirical evaluation of SWAG, comparing to a range of high performing baselines, including MC dropout (Gal & Ghahramani, 2016), temperature scaling (Guo et al., 2017), and Laplace approximations (Ritter et al., 2018b). In Section id1 we evaluate SWAG predictions and uncertainty estimates on image classification tasks. We also evaluate SWAG for transfer learning and out-of-domain data detection. We present experiments investigating the effect of hyperparameter choices and practical limitations in SWAG in Appendix id1.
In this section we evaluate the quality of uncertainty estimates as well as predictive accuracy for SWAG and SWAG-Diagonal and baselines on CIFAR-10, CIFAR-100 and ImageNet ILSVRC-2012 (Russakovsky et al., 2015).
For all methods we analyze test negative log-likelihood, which reflects both the accuracy and the quality of predictive uncertainty. Following Guo et al. (2017) we also consider a variant of reliability diagrams to evaluate the calibration of uncertainty estimates (see Figure 3) and to show the difference between a method's confidence in its predictions and its accuracy. To produce this plot for a given method we split the test data into $20$ bins uniformly based on the confidence of a method (maximum predicted probability). We then evaluate the accuracy and mean confidence of the method on the images from each bin, and plot the difference between confidence and accuracy. For a well-calibrated model, this difference should be close to zero for each bin. We found that this procedure gives a more effective visualization of the actual confidence distribution of DNN predictions than the standard reliability diagrams used in Guo et al. (2017) and Niculescu-Mizil & Caruana (2005).
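A minimal sketch of this reliability computation is given below. It assumes that "uniformly" means bins holding (nearly) equal numbers of test points after sorting by confidence; that reading, the function name `confidence_minus_accuracy`, and the synthetic predictions are assumptions of the example. The labels are sampled from the predicted probabilities themselves, so the toy model is calibrated by construction.

```python
import numpy as np


def confidence_minus_accuracy(probs, labels, num_bins=20):
    """Per-bin mean confidence and (confidence - accuracy) gap."""
    confidence = probs.max(axis=1)                       # max predicted probability
    correct = (probs.argmax(axis=1) == labels).astype(float)
    order = np.argsort(confidence)                       # sort points by confidence
    conf_bins = np.array_split(confidence[order], num_bins)
    corr_bins = np.array_split(correct[order], num_bins)
    mean_conf = np.array([c.mean() for c in conf_bins])
    gap = np.array([c.mean() - a.mean() for c, a in zip(conf_bins, corr_bins)])
    return mean_conf, gap


# Synthetic, calibrated-by-construction predictions for 2000 points, 10 classes.
rng = np.random.default_rng(0)
logits = rng.normal(scale=2.0, size=(2000, 10))
probs = np.exp(logits - logits.max(axis=1, keepdims=True))
probs /= probs.sum(axis=1, keepdims=True)

# Sample each label from its own predicted distribution (inverse-CDF sampling).
u = rng.random(2000)
labels = np.minimum((probs.cumsum(axis=1) < u[:, None]).sum(axis=1), 9)

mean_conf, gap = confidence_minus_accuracy(probs, labels)
```

Plotting `gap` against `mean_conf` gives the kind of curve described above; positive values indicate over-confidence in that bin and negative values indicate under-confidence.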
We provide tables containing the test accuracy, negative log likelihood and expected calibration error for all methods and datasets in Appendix id1.
On CIFAR datasets we run experiments with VGG-16, PreResNet-164 and WideResNet28x10 networks. In order to compare SWAG with existing alternatives we report the results for standard SGD and SWA (Izmailov et al., 2018) solutions (single models), MC-Dropout (Gal & Ghahramani, 2016), temperature scaling (Guo et al., 2017) applied to SWA and SGD solutions, and KFAC-Laplace (Ritter et al., 2018b) methods. For all the methods we use our implementations in PyTorch (see Appendix id1). We train all networks for $300$ epochs, starting to collect models for SWA and SWAG approximations once per epoch after epoch $160$. For SWAG, KFAC-Laplace, and Dropout we use 30 samples at test time.
On ImageNet we report our results for SWAG, SWAGDiagonal, SWA and SGD. We run experiments with DenseNet161 (Huang et al., 2017) and ResNet152 (He et al., 2016). For each model we start from a pretrained model available in the torchvision package, and run SGD with a constant learning rate for $10$ epochs. We collect models for the SWAG versions and SWA $4$ times per epoch. For SWAG we use $30$ samples from the posterior over network weights at test time, and use a randomly sampled $10\%$ of the training data to update batch normalization statistics for each of the samples. For SGD with temperature scaling, we use the results reported in Guo et al. (2017).
We provide further details on the architectures and hyperparameter choices in Appendix id1.
We visualize the negative log-likelihood for all methods and datasets in Figure 2. On all considered tasks SWAG and SWAG diagonal perform comparably to or better than all the considered alternatives, with SWAG being best overall. We note that the combination of SWA and temperature scaling presents a competitive baseline. However, unlike SWAG it requires a validation set to tune the temperature; further, temperature scaling is not effective when the test data distribution differs from the training distribution. We consider this setting in Sections id1 and id1.
Next, we analyze the calibration of uncertainty estimates provided by different methods. In Figure 3 we present reliability plots for WideResNet on CIFAR100, DenseNet161 and ResNet152 on ImageNet. The reliability diagrams for all other datasets and architectures are presented in Appendix id1. As we can see, SWAG and SWAGDiagonal both achieve good calibration across the board. The low-rank plus diagonal version of SWAG is generally better calibrated than SWAGDiagonal. We also present the expected calibration error for each of the methods, architectures and datasets in Table A.2.
Finally, in Table A.4 we present the predictive accuracy for all of the methods. SWAG versions generally perform comparably to SWA in terms of predictive accuracy and outperform the other methods that do not use SWA.
Method  JSDistance 

SWAG  3.31 
SWAGDiag  2.27 
MC Dropout  3.04 
SWA  1.68 
SGD (Baseline)  3.14 
SGD + Temp. Scaling  2.98 
We evaluate the uncertainties provided by SWAG, SWAGDiagonal and baselines in the transfer learning setting. We use the models trained on CIFAR10 and evaluate them on STL10 (Coates et al., 2011). STL10 has a set of classes similar to that of CIFAR10, but the image distribution is different, so adapting the model from CIFAR10 to STL10 is a commonly used benchmark in transfer learning.
The test negative log-likelihood for all methods is given in Figure 2, and Figure 3 contains the reliability diagram for WideResNet28x10. As we can see, in the transfer learning setting SWAG significantly outperforms the existing alternatives in terms of both likelihood and calibration. In particular, SWA with temperature scaling does not match the performance of SWAG, since the data distribution used to tune the temperature is different from the test distribution. See Appendix id1 for detailed results for all architectures.
Next, we evaluate the SWAG variants along with the baselines on out-of-sample data detection. To do so we train a WideResNet as described in section id1 on the data from five classes of the CIFAR10 dataset, and then analyze the resulting predictions on the full test set. We expect the predicted class probabilities on objects that belong to classes that were not present in the training data to have high entropy, reflecting the model’s high uncertainty in its predictions, and considerably lower entropy on the images that are similar to those on which the network was trained.
To make this comparison quantitative, we computed the symmetrized KL divergence between the binned in-sample and out-of-sample distributions in Table 1, finding that SWAG and Dropout perform best on this measure. We plot the histograms of predictive entropies on the in-domain (classes that were trained on) and out-of-domain (classes that were not trained on) data in Figure A.9 for a qualitative comparison.
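For concreteness, the entropy of a predicted class distribution can be computed as in the short sketch below. The function name `predictive_entropy` and the small constant `eps` (added only to avoid taking the logarithm of zero) are illustrative assumptions, not part of the experimental code.

```python
import numpy as np

def predictive_entropy(probs, eps=1e-12):
    """Entropy (in nats) of each row of class probabilities."""
    probs = np.asarray(probs, dtype=float)
    # eps is an illustrative guard against log(0) for one-hot predictions
    return -(probs * np.log(probs + eps)).sum(axis=-1)
```

A uniform prediction over the classes attains the maximum entropy (the logarithm of the number of classes), while a confident one-hot prediction has entropy close to zero.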
In this paper we developed SWAGaussian (SWAG) for approximate Bayesian inference in deep learning. There has been a great desire to apply Bayesian methods in deep learning due to their theoretical properties and past success with small neural networks. We view SWAG as a step towards practical, scalable, and accurate Bayesian deep learning for large modern neural networks.
A key geometric observation in this paper is that the posterior distribution over neural network parameters is close to Gaussian in the subspace spanned by the trajectory of SGD. Further, this Gaussian distribution is remarkably wellaligned with the principal components (eigenvectors of sample covariance matrix) of the set of SGD iterates. The subspace spanned by SGD iterates is of interest for several reasons. First, our work shows that Bayesian model averaging within this subspace can improve predictions over SGD or SWA solutions. Second, GurAri et al. (2019) argue that the SGD trajectory lies in the subspace spanned by the eigenvectors of the Hessian corresponding to the few top eigenvalues. This result suggests that we would expect the subspace containing SGD trajectory to correspond to directions of rapid change in the predictions of the model.
Acknowledgements WM, PI, and AGW were supported by an Amazon Research Award, Facebook Research, and NSF IIS1563887. WM was additionally supported by an NSF Graduate Research Fellowship under Grant No. DGE1650441. We would like to thank Jacob Gardner for helpful discussions.
References
 Asmussen & Glynn (2007) Asmussen, S. and Glynn, P. W. Stochastic simulation: algorithms and analysis. Number 57 in Stochastic modelling and applied probability. Springer, New York, 2007. ISBN 9780387306797 9780387690339. OCLC: ocn123113652.
 Athiwaratkun et al. (2019) Athiwaratkun, B., Finzi, M., Izmailov, P., and Wilson, A. G. There are many consistent explanations for unlabeled data: why you should average. In International Conference on Learning Representations, 2019. URL http://arxiv.org/abs/1806.05594. arXiv: 1806.05594.
 Blier & Ollivier (2018) Blier, L. and Ollivier, Y. The Description Length of Deep Learning models. In Advances in Neural Information Processing Systems, pp. 11, 2018.
 Blundell et al. (2015) Blundell, C., Cornebise, J., Kavukcuoglu, K., and Wierstra, D. Weight Uncertainty in Neural Networks. In International Conference on Machine Learning, 2015. URL http://arxiv.org/abs/1505.05424. arXiv: 1505.05424.
 Chaudhari & Soatto (2018) Chaudhari, P. and Soatto, S. Stochastic gradient descent performs variational inference, converges to limit cycles for deep networks. In International Conference on Learning Representations, 2018. URL http://arxiv.org/abs/1710.11029. arXiv: 1710.11029.
 Chen et al. (2014) Chen, T., Fox, E. B., and Guestrin, C. Stochastic Gradient Hamiltonian Monte Carlo. In International Conference on Machine Learning, 2014. URL http://arxiv.org/abs/1402.4102. arXiv: 1402.4102.
 Chen et al. (2016) Chen, X., Lee, J. D., Tong, X. T., and Zhang, Y. Statistical Inference for Model Parameters in Stochastic Gradient Descent. arXiv: 1610.08637, October 2016. URL http://arxiv.org/abs/1610.08637.
 Coates et al. (2011) Coates, A., Ng, A., and Lee, H. An Analysis of SingleLayer Networks in Unsupervised Feature Learning. In Proceedings of the Fourteenth International Conference on Artificial Intelligence and Statistics, pp. 215–223, June 2011. URL http://proceedings.mlr.press/v15/coates11a.html.
 Dasgupta & Hsu (2007) Dasgupta, S. and Hsu, D. OnLine Estimation with the Multivariate Gaussian Distribution. In Bshouty, N. H. and Gentile, C. (eds.), Twentieth Annual Conference on Learning Theory., volume 4539, pp. 278–292, Berlin, Heidelberg, 2007. Springer Berlin Heidelberg. ISBN 9783540729259. doi: 10.1007/9783540729273_21. URL http://link.springer.com/10.1007/9783540729273_21.
 Draxler et al. (2018) Draxler, F., Veschgini, K., Salmhofer, M., and Hamprecht, F. A. Essentially No Barriers in Neural Network Energy Landscape. In International Conference on Machine Learning, pp. 10, 2018.
 Gal & Ghahramani (2016) Gal, Y. and Ghahramani, Z. Dropout as a Bayesian Approximation. In International Conference on Machine Learning, 2016. URL http://proceedings.mlr.press/v48/gal16.pdf.
 Gal et al. (2017) Gal, Y., Hron, J., and Kendall, A. Concrete Dropout. In Advances in Neural Information Processing Systems, 2017. URL http://arxiv.org/abs/1705.07832. arXiv: 1705.07832.
 Gardner et al. (2018) Gardner, J., Pleiss, G., Weinberger, K. Q., Bindel, D., and Wilson, A. G. Gpytorch: Blackbox matrixmatrix gaussian process inference with gpu acceleration. In Advances in Neural Information Processing Systems, pp. 7587–7597, 2018.
 Garipov et al. (2018) Garipov, T., Izmailov, P., Podoprikhin, D., Vetrov, D., and Wilson, A. G. Loss Surfaces, Mode Connectivity, and Fast Ensembling of DNNs. In Advances in Neural Information Processing Systems, 2018. URL http://arxiv.org/abs/1802.10026. arXiv: 1802.10026.
 Graves (2011) Graves, A. Practical variational inference for neural networks. In Advances in neural information processing systems, pp. 2348–2356, 2011.
 Guo et al. (2017) Guo, C., Pleiss, G., Sun, Y., and Weinberger, K. Q. On Calibration of Modern Neural Networks. In International Conference on Machine Learning, June 2017. URL http://arxiv.org/abs/1706.04599. arXiv: 1706.04599.
 GurAri et al. (2019) GurAri, G., Roberts, D. A., and Dyer, E. Gradient descent happens in a tiny subspace, 2019. URL https://openreview.net/forum?id=ByeTHsAqtX.
 Halko et al. (2011) Halko, N., Martinsson, P.G., and Tropp, J. A. Finding structure with randomness: Probabilistic algorithms for constructing approximate matrix decompositions. SIAM review, 53(2):217–288, 2011.
 He et al. (2016) He, K., Zhang, X., Ren, S., and Sun, J. Deep Residual Learning for Image Recognition. In CVPR, 2016. URL http://arxiv.org/abs/1512.03385. arXiv: 1512.03385.
 Huang et al. (2017) Huang, G., Liu, Z., van der Maaten, L., and Weinberger, K. Q. Densely Connected Convolutional Networks. In CVPR, 2017. URL http://arxiv.org/abs/1608.06993. arXiv: 1608.06993.
 Ioffe & Szegedy (2015) Ioffe, S. and Szegedy, C. Batch normalization: Accelerating deep network training by reducing internal covariate shift. arXiv preprint arXiv:1502.03167, 2015.
 Izmailov et al. (2018) Izmailov, P., Podoprikhin, D., Garipov, T., Vetrov, D., and Wilson, A. G. Averaging Weights Leads to Wider Optima and Better Generalization. In Uncertainty in Artificial Intelligence, 2018. URL http://arxiv.org/abs/1803.05407. arXiv: 1803.05407.
 Kendall & Gal (2017) Kendall, A. and Gal, Y. What Uncertainties Do We Need in Bayesian Deep Learning for Computer Vision? In Advances in Neural Information Processing Systems, Long Beach, 2017. URL https://arxiv.org/pdf/1703.04977.pdf.
 Keskar et al. (2017) Keskar, N. S., Mudigere, D., Nocedal, J., Smelyanskiy, M., and Tang, P. T. P. On LargeBatch Training for Deep Learning: Generalization Gap and Sharp Minima. In International Conference on Learning Representations, 2017. URL http://arxiv.org/abs/1609.04836. arXiv: 1609.04836.
 Khan et al. (2018) Khan, M. E., Nielsen, D., Tangkaratt, V., Lin, W., Gal, Y., and Srivastava, A. Fast and Scalable Bayesian Deep Learning by WeightPerturbation in Adam. In International Conference on Machine Learning, 2018. URL http://arxiv.org/abs/1806.04854. arXiv: 1806.04854.
 Kingma & Welling (2013) Kingma, D. P. and Welling, M. Autoencoding variational bayes. In International Conference on Learning Representations, 2013.
 Kingma et al. (2015a) Kingma, D. P., Salimans, T., and Welling, M. Variational dropout and the local reparameterization trick. In Advances in Neural Information Processing Systems, pp. 2575–2583, 2015a.
 Kingma et al. (2015b) Kingma, D. P., Salimans, T., and Welling, M. Variational Dropout and the Local Reparameterization Trick. arXiv:1506.02557 [cs, stat], June 2015b. URL http://arxiv.org/abs/1506.02557. arXiv: 1506.02557.
 Kirkpatrick et al. (2017) Kirkpatrick, J., Pascanu, R., Rabinowitz, N., Veness, J., Desjardins, G., Rusu, A. A., Milan, K., Quan, J., Ramalho, T., GrabskaBarwinska, A., et al. Overcoming catastrophic forgetting in neural networks. Proceedings of the national academy of sciences, pp. 201611835, 2017.
 Kuleshov et al. (2018) Kuleshov, V., Fenner, N., and Ermon, S. Accurate Uncertainties for Deep Learning Using Calibrated Regression. In International Conference on Machine Learning, pp. 9, 2018.
 Lakshminarayanan et al. (2017) Lakshminarayanan, B., Pritzel, A., and Blundell, C. Simple and Scalable Predictive Uncertainty Estimation using Deep Ensembles. In Advances in Neural Information Processing Systems, 2017.
 Li et al. (2018) Li, H., Xu, Z., Taylor, G., Studer, C., and Goldstein, T. Visualizing the Loss Landscape of Neural Nets. In Advances in Neural Information Processing Systems, 2018. URL http://arxiv.org/abs/1712.09913. arXiv: 1712.09913.
 Loshchilov & Hutter (2019) Loshchilov, I. and Hutter, F. Decoupled Weight Decay Regularization. In International Conference on Learning Representations, 2019. URL http://arxiv.org/abs/1711.05101. arXiv: 1711.05101.
 Louizos & Welling (2017) Louizos, C. and Welling, M. Multiplicative normalizing flows for variational bayesian neural networks. In International Conference on Machine Learning, 2017.
 MacKay (1992a) MacKay, D. J. C. Bayesian Interpolation. Neural Computation, 1992a. URL https://www.mitpressjournals.org/doi/pdf/10.1162/neco.1992.4.3.415.
 MacKay (1992b) MacKay, D. J. C. A Practical Bayesian Framework for Backpropagation Networks. Neural Computation, 4(3):448–472, May 1992b. ISSN 08997667, 1530888X. doi: 10.1162/neco.1992.4.3.448. URL http://www.mitpressjournals.org/doi/10.1162/neco.1992.4.3.448.
 MacKay (2003) MacKay, D. J. C. Information theory, inference, and learning algorithms. Cambridge University Press, Cambridge, UK ; New York, 2003. ISBN 9780521642989.
 Maddox et al. (2018) Maddox, W., Garipov, T., Izmailov, P., Vetrov, D., and Wilson, A. G. Fast uncertainty estimates and Bayesian model averaging of DNNs. Uncertainty in Deep Learning Workshop at UAI, 2018.
 Mandt et al. (2017) Mandt, S., Hoffman, M. D., and Blei, D. M. Stochastic Gradient Descent as Approximate Bayesian Inference. JMLR, 18:1–35, 2017.
 Molchanov et al. (2017) Molchanov, D., Ashukha, A., and Vetrov, D. Variational dropout sparsifies deep neural networks. arXiv preprint arXiv:1701.05369, 2017.
 Mukhoti & Gal (2018) Mukhoti, J. and Gal, Y. Evaluating Bayesian Deep Learning Methods for Semantic Segmentation, November 2018. URL https://arxiv.org/abs/1811.12709v1.
 Naeini et al. (2015) Naeini, M. P., Cooper, G. F., and Hauskrecht, M. Obtaining well calibrated probabilities using bayesian binning. In AAAI, pp. 2901–2907, 2015.
 Neal (1996) Neal, R. M. Bayesian Learning for Neural Networks, volume 118 of Lecture Notes in Statistics. Springer New York, New York, NY, 1996. ISBN 9780387947242 9781461207450. URL http://link.springer.com/10.1007/9781461207450.
 NiculescuMizil & Caruana (2005) NiculescuMizil, A. and Caruana, R. Predicting good probabilities with supervised learning. In International Conference on Machine Learning, pp. 625–632, Bonn, Germany, 2005. ACM Press. ISBN 9781595931801. doi: 10.1145/1102351.1102430. URL http://portal.acm.org/citation.cfm?doid=1102351.1102430.
 Nikishin et al. (2018) Nikishin, E., Izmailov, P., Athiwaratkun, B., Podoprikhin, D., Garipov, T., Shvechikov, P., Vetrov, D., and Wilson, A. G. Improving Stability in Deep Reinforcement Learning with Weight Averaging. In Uncertainty in Artificial Intelligence Workshop on Uncertainty in Deep Learning, pp. 5, 2018.
 Polyak & Juditsky (1992) Polyak, B. T. and Juditsky, A. B. Acceleration of Stochastic Approximation by Averaging. SIAM Journal on Control and Optimization, 30(4):838–855, July 1992. ISSN 03630129, 10957138. doi: 10.1137/0330046. URL http://epubs.siam.org/doi/10.1137/0330046.
 Rakhlin et al. (2011) Rakhlin, A., Shamir, O., and Sridharan, K. Making Gradient Descent Optimal for Strongly Convex Stochastic Optimization. In International Conference on Machine Learning, September 2011. URL http://arxiv.org/abs/1109.5647. arXiv: 1109.5647.
 Ritter et al. (2018a) Ritter, H., Botev, A., and Barber, D. Online Structured Laplace Approximations For Overcoming Catastrophic Forgetting. In Advances in Neural Information Processing Systems, 2018a. URL http://arxiv.org/abs/1805.07810. arXiv: 1805.07810.
 Ritter et al. (2018b) Ritter, H., Botev, A., and Barber, D. A Scalable Laplace Approximation for Neural Networks. In International Conference on Learning Representations, 2018b.
 Ruppert (1988) Ruppert, D. Efficient Estimators from a Slowly Convergent RobbinsMunro Process. Technical Report 781, Cornell University, School of Operations Report and Industrial Engineering, 1988. URL https://ecommons.cornell.edu/bitstream/handle/1813/8664/TR000781.pdf?sequence=1&isAllowed=y.
 Russakovsky et al. (2015) Russakovsky, O., Deng, J., Su, H., Krause, J., Satheesh, S., Ma, S., Huang, Z., Karpathy, A., Khosla, A., Bernstein, M., Berg, A. C., and FeiFei, L. ImageNet Large Scale Visual Recognition Challenge. IJCV, 115(3):211–252, 2015. URL http://arxiv.org/abs/1409.0575. arXiv: 1409.0575.
 Saatci & Wilson (2017) Saatci, Y. and Wilson, A. G. Bayesian gan. In Guyon, I., Luxburg, U. V., Bengio, S., Wallach, H., Fergus, R., Vishwanathan, S., and Garnett, R. (eds.), Advances in Neural Information Processing Systems 30, pp. 3622–3631. Curran Associates, Inc., 2017. URL http://papers.nips.cc/paper/6953bayesiangan.pdf.
 Sagun et al. (2018) Sagun, L., Evci, U., Guney, V. U., Dauphin, Y., and Bottou, L. Empirical Analysis of the Hessian of OverParametrized Neural Networks. In International Conference on Learning Representations Workshop Track, 2018. URL http://arxiv.org/abs/1706.04454. arXiv: 1706.04454.
 Vaart (1998) Vaart, A. W. v. d. Asymptotic Statistics. Cambridge Series in Statistical and Probabilistic Mathematics. Cambridge University Press, Cambridge, 1998. ISBN 9780521784504. doi: 10.1017/CBO9780511802256. URL https://www.cambridge.org/core/books/asymptoticstatistics/A3C7DAD3F7E66A1FA60E9C8FE132EE1D.
 Yazici et al. (2019) Yazici, Y., Foo, C.S., Winkler, S., Yap, K.H., Piliouras, G., and Chandrasekhar, V. The Unusual Effectiveness of Averaging in GAN Training. In International Conference on Learning Representations, 2019. URL http://arxiv.org/abs/1806.04498. arXiv: 1806.04498.
 Zagoruyko & Komodakis (2016) Zagoruyko, S. and Komodakis, N. Wide Residual Networks. In BMVC, 2016. URL http://arxiv.org/abs/1605.07146. arXiv: 1605.07146.
 Zhang et al. (2017) Zhang, G., Sun, S., Duvenaud, D., and Grosse, R. Noisy Natural Gradient as Variational Inference. arXiv:1712.02390 [cs, stat], December 2017. URL http://arxiv.org/abs/1712.02390. arXiv: 1712.02390.
To estimate $tr(C)$ from the gradient noise we need to divide the estimated variance by the batch size (as $V(\widehat{g}(\theta ))=BC(\theta )$), for a correct version of Equation 5. From Assumption 1 and Equation 6 of Mandt et al. (2017), we see that
$$\widehat{g}(\theta )\approx g(\theta )+\frac{1}{\sqrt{B}}\nabla g(\theta ),\quad \nabla g(\theta )\sim N(0,C(\theta )),$$ 
where $B$ is the batch size. Thus, collecting the variance of $\widehat{g}(\theta )$ (the variance of the stochastic gradients) will give estimates that are upscaled by a factor of $B$:
$$\eta \approx 2\frac{{B}^{2}}{N}\frac{d}{tr(V(\widehat{g}(\theta )))}.$$ 
To include momentum, we can repeat the analysis in Sections 4.1 and 4.3 of Mandt et al. (2017), finding that this also involves scaling the optimal learning rate, but by a factor of $\mu ,$ the momentum term. This gives the final optimal learning rate equation as
$$\eta \approx 2\frac{\mu {B}^{2}}{N}\frac{d}{tr(V(\widehat{g}(\theta )))}.$$  (6) 
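As a worked illustration, Equation 6 can be evaluated directly once the trace of the stochastic-gradient variance has been estimated. In this sketch the function and argument names are hypothetical, and reading $d$ as the number of parameters and $N$ as the number of training points is an assumption based on the notation of Mandt et al. (2017).

```python
def optimal_learning_rate(trace_grad_var, n_params, batch_size, n_train, momentum=1.0):
    """Evaluate eta ~= 2 * mu * B^2 / N * d / tr(V(g_hat)) from Equation 6.

    trace_grad_var: estimated tr(V(g_hat)), the summed variance of the
                    stochastic gradients (assumed to be precomputed)
    n_params:       d, taken here to be the parameter dimension (assumption)
    batch_size:     B
    n_train:        N, taken here to be the training set size (assumption)
    momentum:       mu; 1.0 recovers the expression without momentum
    """
    return 2.0 * momentum * batch_size ** 2 / n_train * n_params / trace_grad_var
```

For example, `optimal_learning_rate(100.0, n_params=10, batch_size=10, n_train=1000, momentum=0.5)` evaluates to 0.01 for these toy values, which are chosen only to make the arithmetic easy to follow.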
In Figure 3(b), we computed $tr(C)$ for VGG16 and PreResNet164 on CIFAR100 beginning from the start of training (referred to as from scratch), as well as the start of the SWAG procedure (referred to in the legend as SWA). We see that $tr(C)$ is never quite constant when trained from scratch, while for a period of constant learning rate near the end of training, referred to as the stationary phase, $tr(C)$ is essentially constant throughout. This discrepancy is likely due to large gradients at the very beginning of training, indicating that the stationary distribution has not been reached yet.
Next, in Figure 3(a), we used the computed $tr(C)$ estimate for all four models and Equation 6 to compute the optimal learning rate under the assumptions of Mandt et al. (2017), finding that these learning rates are not constant for the estimates beginning at the start of training and that they are always far too large (3,000 at the minimum compared to a standard learning rate of 0.1).
In Figure 5 we present plots analogous to those in Section id1 for PreResNet110 and VGG16 on CIFAR10 and CIFAR100. For all dataset-architecture pairs we see that SWAG is able to capture the geometry of the posterior in the subspace spanned by the SGD trajectory.
In this section, we discuss the hyperparameters in SWAG, as well as some current theoretical limitations.
We now evaluate the effect of the covariance matrix rank on the SWAG approximation. To do so, we trained a PreResNet56 on CIFAR100 with SWAG beginning from epoch 161, and evaluated 30 sample Bayesian model averages obtained at different epochs; the accuracy plot from this experiment is shown in Figure 6 (a). The rank of each model after epoch 161 is simply $\min(\text{epoch}-161,140),$ and we can see that the 30 samples from even a low rank approximation reach the same predictive accuracy as the SWA model. Interestingly, both SWAG and SWA outperform ensembles of an SGD run and ensembles of the SGD models in the SWA run.
In most situations where SWAG will be used, no closed form expression for the integral $\int f(y)q(\theta \mid y)\mathit{d}\theta$ will exist. Thus, Monte Carlo approximations will be used; Monte Carlo integration converges at a rate of $1/\sqrt{K},$ where $K$ is the number of samples used, but practically good results may be found with very few samples (e.g. Chapter 29 of MacKay (2003)).
To test how many samples are needed for good predictive accuracy in a Bayesian model averaging task, we used a rank 20 approximation for SWAG and then tested the NLL on the test set as a function of the number of samples for WideResNet28x10 (Zagoruyko & Komodakis, 2016) on CIFAR100.
The results from this experiment are shown in Figure 6 (b, c), where it is possible to see that about 3 samples will match the SWA result for NLL, with about 30 samples necessary for stable accuracy (about the same as SWA for this network). In most of our experiments, we used 30 samples for consistency. In practice, we suggest tuning this number by looking at a validation set as well as the computational resources available and comparing to the free SWA predictions that come with SWAG.
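The Monte Carlo model average and the NLL used to study its dependence on the number of samples can be sketched as follows. The function names are hypothetical, and the per-sample class probabilities are assumed to have been computed already by running the network once per sampled weight vector.

```python
import numpy as np

def bayesian_model_average(sample_probs):
    """Average class probabilities over K weight samples.

    sample_probs: array of shape (K, n_points, n_classes).
    """
    return np.asarray(sample_probs, dtype=float).mean(axis=0)

def negative_log_likelihood(probs, labels, eps=1e-12):
    """Mean negative log-probability assigned to the true labels."""
    probs = np.asarray(probs, dtype=float)
    labels = np.asarray(labels)
    # eps is an illustrative guard against log(0)
    return float(-np.mean(np.log(probs[np.arange(len(labels)), labels] + eps)))

def nll_versus_num_samples(sample_probs, labels):
    """NLL of the model average built from the first k samples, k = 1..K."""
    sample_probs = np.asarray(sample_probs, dtype=float)
    return [negative_log_likelihood(bayesian_model_average(sample_probs[:k]), labels)
            for k in range(1, sample_probs.shape[0] + 1)]
```

Evaluating `nll_versus_num_samples` on held-out data is one way to choose the number of samples in line with the suggestion above.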
First, we note that the covariance, $\mathrm{\Sigma},$ estimated using SWAG, is a function of the learning rate (and momentum) for SGD. While the theoretical work of Mandt et al. (2017) suggests that it is possible to optimally set the learning rate, our experiments in Section id1 show that currently the assumptions of the theory do not match the empirical reality in deep learning. In the linear setting as in Mandt et al. (2017), the learning rate controls the scale of the asymptotic covariance matrix. If the optimal learning rate (Equation 5) is used in this setting, the covariance matches the true posterior. To attempt to disassociate the learning rate from the covariance in practice, we rescale the covariance matrix when sampling by a constant factor for a WideResNet on CIFAR100 shown in Figure 6 (d).
Over several replications, we found that a scale of 0.5 worked best, which is expected because the low rank plus diagonal covariance incorporates the variance twice (once for the diagonal component and once from the low rank component).
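A minimal sketch of sampling with a rescaled covariance is given below. It assumes a covariance of the form $\text{scale}\cdot (\mathrm{diag}(\sigma^{2})+D{D}^{\top}/(K-1))$, where the columns of $D$ are deviations of collected iterates from their mean; the function name, argument names and flat-vector layout are illustrative assumptions rather than the authors' implementation.

```python
import numpy as np

def sample_weights(mean, diag_var, deviations, scale=0.5, rng=None):
    """Draw one sample from N(mean, scale * (diag(diag_var) + D D^T / (K - 1))).

    mean:       (d,) averaged weights
    diag_var:   (d,) diagonal variance estimate
    deviations: (d, K) deviation matrix D with K >= 2 columns (assumed layout)
    scale:      constant factor applied to the whole covariance
    """
    rng = np.random.default_rng() if rng is None else rng
    d, k = deviations.shape
    z1 = rng.standard_normal(d)                      # noise for the diagonal part
    z2 = rng.standard_normal(k)                      # noise for the low-rank part
    diag_part = np.sqrt(diag_var) * z1
    low_rank_part = deviations @ z2 / np.sqrt(k - 1)
    return mean + np.sqrt(scale) * (diag_part + low_rank_part)
```

With `scale=0.5` the diagonal and low-rank components each contribute half of their variance, which matches the intuition in the text that an unscaled sum would count the variance twice; `scale=0.0` returns the mean itself.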
One possible slowdown of SWAG at inference time is in the usage of updated batch norm parameters. Following Izmailov et al. (2018), we found that in order for the averaging and sampling to work well, it was necessary to update the batch norm parameters of networks after sampling a new model. This is shown in Figure 7 for a WideResNet on CIFAR100 for two independently trained models.
From our experimental findings, we see that given an equal amount of training time, SWAG typically outperforms other methods for uncertainty calibration. SWAG additionally does not require a validation set like temperature scaling and Platt scaling (e.g. Guo et al. (2017); Kuleshov et al. (2018)). SWAG also appears to have a distinct advantage over temperature scaling, and other popular alternatives, when the target data are from a different distribution than the training data, as shown by our transfer learning experiments.
Deep ensembles (Lakshminarayanan et al., 2017) require several times longer training for equal calibration, but often perform somewhat better due to incorporating several independent training runs. Thus SWAG will be particularly valuable when training time is limited, but inference time may not be. One possible application is thus in medical applications when image sizes (for semantic segmentation) are large, but predictions can be parallelized and may not have to be instantaneous.
In this section we describe all of the architectures and hyperparameters we use in Sections id1, id1.
On ImageNet we use architecture implementations and pretrained weights from https://github.com/pytorch/vision/tree/master/torchvision. For the experiments on CIFAR datasets we adapted the following implementations:
 •
 • PreactivationResNet$164$: https://github.com/bearpaw/pytorchclassification/blob/master/models/cifar/preresnet.py
 •
For all datasets and architectures we use the same piecewise constant learning rate schedule and weight decay as in Izmailov et al. (2018), except we train PreResNet for $300$ epochs and start averaging after epoch $160$ in SWAG and SWA. For all of the methods we are using our own implementations in PyTorch. We describe the hyperparameters for all experiments for each model:
We use the same hyperparameters as Izmailov et al. (2018) on CIFAR datasets. On ImageNet we used a constant learning rate of ${10}^{-3}$ instead of the cyclical schedule, and averaged $4$ models per epoch. We adapt the code from https://github.com/timgaripov/swa for our implementation of SWA.
In all experiments we use rank $K=20$ and use $30$ weight samples for Bayesian model averaging. We reuse all the other hyperparameters from SWA.
For our implementation we adapt the code for KFAC Fisher approximation from https://github.com/Thrandis/EKFACpytorch and implement our own code for sampling. Following Ritter et al. (2018b), we tune the scale of the approximation on the validation set for every model and dataset.
In order to implement MCdropout we add dropout layers before each weight layer and sample $30$ different dropout masks for Bayesian model averaging at inference time. To choose the dropout rate, we ran the models with dropout rates in the set $\{0.1,0.05,0.01\}$ and chose the one that performed best on validation data. For both VGG16 and WideResNet28x10 we found that a dropout rate of $0.05$ worked best and used it in all experiments. On PreResNet164 we couldn’t achieve reasonable performance with any of the three dropout rates, as has also been reported in the work of He et al. (2016). We report the results for MCDropout in combination with both SWA (SWADrop) and SGD (SGDDrop) training.
For SWA and SGD solutions we picked the optimal temperature by minimizing negative loglikelihood on validation data, adapting the code from https://github.com/gpleiss/temperature_scaling.
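The idea of picking a temperature by minimizing validation NLL can be illustrated with the sketch below. It uses a plain grid search over candidate temperatures, which is a simplification chosen for self-containedness and is not the optimizer used in the referenced code; all function names and the grid range are assumptions.

```python
import numpy as np

def softmax(logits):
    """Row-wise softmax with the usual max-subtraction for stability."""
    z = logits - logits.max(axis=1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)

def validation_nll(logits, labels, temperature=1.0):
    """Mean NLL of the labels after dividing the logits by the temperature."""
    logits = np.asarray(logits, dtype=float)
    labels = np.asarray(labels)
    probs = softmax(logits / temperature)
    return float(-np.mean(np.log(probs[np.arange(len(labels)), labels] + 1e-12)))

def fit_temperature(val_logits, val_labels, grid=None):
    """Return the grid temperature with the lowest validation NLL."""
    grid = np.linspace(0.5, 5.0, 91) if grid is None else np.asarray(grid)
    losses = [validation_nll(val_logits, val_labels, t) for t in grid]
    return float(grid[int(np.argmin(losses))])
```

The fitted temperature is then applied to the test logits before the softmax. Because it is tuned on validation data, it inherits the validation distribution, which is the limitation discussed for the transfer setting.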
On CIFAR datasets for tuning hyperparameters we used the last $5000$ training data points as a validation set. On ImageNet we used $5000$ of the test data points for validation. On the transfer task for CIFAR10 to STL10, we report accuracy on all 10 STL10 classes even though frogs are not a part of the STL10 test set (and monkeys are not a part of the CIFAR10 training set).
We provide test accuracies and negative log-likelihoods (NLL) for all methods and datasets in Tables 3 and 4 respectively. We observe that SWAG is competitive with SWA, SWA with temperature scaling and SWADropout in terms of test accuracy, and typically outperforms all the baselines in terms of NLL. SWAGDiagonal is generally inferior to SWAG for log-likelihood, but outperforms SWA.
In Table 2 we additionally report expected calibration error (ECE, Naeini et al., 2015), a metric of calibration of the predictive uncertainties. To compute ECE for a given model we split the test points into $20$ bins based on the confidence of the model, compute the absolute value of the difference between the average confidence and the accuracy within each bin, and average the obtained values over all bins. Please refer to Naeini et al. (2015) and Guo et al. (2017) for more details. We observe that SWAG is competitive with temperature scaling for ECE. Again, SWAGDiagonal achieves better calibration than SWA, but using the low-rank plus diagonal covariance approximation in SWAG leads to substantially improved performance.
We provide the additional reliability diagrams for all methods and datasets in Figure 8. SWAG consistently improves calibration over SWA, and performs on par with or better than temperature scaling. In transfer learning temperature scaling fails to achieve good calibration, while SWAG still provides a significant improvement over SWA.
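The ECE computation described above can be sketched in a few lines. The function name is hypothetical, and the bins are assumed to contain (nearly) equal numbers of test points ordered by confidence, in which case the plain average over bins coincides with a count-weighted average.

```python
import numpy as np

def expected_calibration_error(probs, labels, n_bins=20):
    """Average over bins of |mean confidence - accuracy|.

    Assumption: equal-count bins over test points sorted by confidence.
    """
    probs = np.asarray(probs, dtype=float)
    labels = np.asarray(labels)
    confidence = probs.max(axis=1)
    correct = (probs.argmax(axis=1) == labels).astype(float)
    order = np.argsort(confidence)
    gaps = [abs(confidence[idx].mean() - correct[idx].mean())
            for idx in np.array_split(order, n_bins) if idx.size > 0]
    return float(np.mean(gaps))
```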
Table 1 shows the computed symmetrized, discretized KL distance between in and out of sample distributions for the CIFAR5 out of sample image detection class. We used the same bins as in Figure 9 to discretize the entropy distributions, then smoothed these bins by a factor of 1e-7 before calculating $KL(\text{IN}\,\|\,\text{OUT})+KL(\text{OUT}\,\|\,\text{IN})$ using the scipy.stats.entropy function. We can see even qualitatively that the distributions are more distinct for SWAG and SWAGDiagonal than for the other methods, particularly temperature scaling.
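A sketch of this computation is given below. It is written with NumPy only and normalizes the histograms before computing each KL term, which is the same quantity that `scipy.stats.entropy(p, q)` returns for two count vectors. Interpreting the smoothing as adding a small constant (taken here as 1e-7) to every bin count is an assumption, as are the function and argument names.

```python
import numpy as np

def symmetrized_kl(in_entropies, out_entropies, bin_edges, smoothing=1e-7):
    """KL(IN || OUT) + KL(OUT || IN) between binned entropy distributions."""
    p, _ = np.histogram(in_entropies, bins=bin_edges)
    q, _ = np.histogram(out_entropies, bins=bin_edges)
    # Assumed smoothing: add a small constant to every bin so no bin is empty
    p = p.astype(float) + smoothing
    q = q.astype(float) + smoothing
    p /= p.sum()                                   # normalize to probabilities
    q /= q.sum()
    kl_pq = np.sum(p * np.log(p / q))
    kl_qp = np.sum(q * np.log(q / p))
    return float(kl_pq + kl_qp)
```

Larger values indicate that the in-domain and out-of-domain entropy histograms are easier to tell apart. With very small smoothing, bins that are empty for one distribution but populated for the other dominate the result.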
Dataset  Model  SGD  SWA  SWAGDiag  SWAG  KFACLaplace  SWADropout  SWATemp 

CIFAR10  VGG16  $0.0483\pm 0.0022$  $0.0408\pm 0.0019$  $0.0267\pm 0.0025$  $0.0158\pm 0.0030$  $\mathbf{0.0094}\pm 0.0005$  $0.0284\pm 0.0036$  $0.0366\pm 0.0063$ 
CIFAR10  PreResNet164  $0.0255\pm 0.0009$  $0.0203\pm 0.0010$  $0.0082\pm 0.0008$  $\mathbf{0.0053}\pm 0.0004$  $0.0092\pm 0.0018$  $0.0162\pm 0.0000$  $0.0172\pm 0.0010$ 
CIFAR10  WideResNet28x10  $0.0166\pm 0.0007$  $0.0087\pm 0.0002$  $\mathbf{0.0047}\pm 0.0013$  $0.0088\pm 0.0006$  $0.0060\pm 0.0003$  $0.0094\pm 0.0014$  $0.0080\pm 0.0007$ 
CIFAR100  VGG16  $0.1870\pm 0.0014$  $0.1514\pm 0.0032$  $0.0819\pm 0.0021$  $0.0395\pm 0.0061$  $0.0778\pm 0.0054$  $0.1108\pm 0.0181$  $\mathbf{0.0291}\pm 0.0097$ 
CIFAR100  PreResNet164  $0.1012\pm 0.0009$  $0.0700\pm 0.0056$  $0.0239\pm 0.0047$  $0.0587\pm 0.0048$  $\mathbf{0.0158}\pm 0.0014$  $0.0175\pm 0.0037$  
CIFAR100  WideResNet28x10  $0.0479\pm 0.0010$  $0.0684\pm 0.0022$  $0.0322\pm 0.0018$  $\mathbf{0.0113}\pm 0.0020$  $0.0379\pm 0.0047$  $0.0574\pm 0.0028$  $0.0220\pm 0.0007$ 
ImageNet  DenseNet161  $0.0545\pm 0.0000$  $0.0509\pm 0.0000$  $0.0459\pm 0.0000$  $0.0204\pm 0.0000$  $\mathbf{0.0190}\pm 0.0000$  
ImageNet  ResNet152  $0.0478\pm 0.0000$  $0.0605\pm 0.0000$  $0.0566\pm 0.0000$  $0.0279\pm 0.0000$  $\mathbf{0.0183}\pm 0.0000$  
CIFAR10 $\to $ STL10  VGG16  $0.2149\pm 0.0027$  $0.2082\pm 0.0056$  $0.1719\pm 0.0075$  $\mathbf{0.1463}\pm 0.0075$  $0.1803\pm 0.0024$  $0.2089\pm 0.0055$  
CIFAR10 $\to $ STL10  PreResNet164  $0.1758\pm 0.0000$  $0.1739\pm 0.0000$  $0.1312\pm 0.0000$  $\mathbf{0.1110}\pm 0.0000$  $0.1646\pm 0.0000$  
CIFAR10 $\to $ STL10  WideResNet28x10  $0.1561\pm 0.0000$  $0.1413\pm 0.0000$  $0.1241\pm 0.0000$  $\mathbf{0.1017}\pm 0.0000$  $0.1421\pm 0.0000$  $0.1371\pm 0.0000$ 
Dataset  Model  SGD  SWA  SWAGDiag  SWAG  KFACLaplace  SWADropout  SWATemp 

CIFAR10  VGG16  $0.3285\pm 0.0139$  $0.2621\pm 0.0104$  $0.2200\pm 0.0078$  $\mathbf{0.2016}\pm 0.0031$  $0.2252\pm 0.0032$  $0.2328\pm 0.0049$  $0.2481\pm 0.0245$ 
CIFAR10  PreResNet164  $0.1814\pm 0.0025$  $0.1450\pm 0.0042$  $0.1251\pm 0.0029$  $\mathbf{0.1232}\pm 0.0022$  $0.1471\pm 0.0012$  $0.1270\pm 0.0000$  $0.1347\pm 0.0038$ 
CIFAR10  WideResNet28x10  $0.1294\pm 0.0022$  $0.1075\pm 0.0004$  $0.1077\pm 0.0009$  $0.1122\pm 0.0009$  $0.1210\pm 0.0020$  $0.1094\pm 0.0021$  $\mathbf{0.1064}\pm 0.0004$ 
CIFAR100  VGG16  $1.7308\pm 0.0137$  $1.2780\pm 0.0051$  $1.0163\pm 0.0032$  $\mathbf{0.9480}\pm 0.0038$  $1.1915\pm 0.0199$  $1.1872\pm 0.0524$  $1.0386\pm 0.0126$ 
CIFAR100  PreResNet164  $0.9465\pm 0.0191$  $0.7370\pm 0.0265$  $0.6837\pm 0.0186$  $0.7081\pm 0.0162$  $0.7881\pm 0.0025$  $\mathbf{0.6770}\pm 0.0191$  
CIFAR100  WideResNet28x10  $0.7958\pm 0.0089$  $0.6684\pm 0.0034$  $0.6150\pm 0.0029$  $\mathbf{0.6078}\pm 0.0006$  $0.7692\pm 0.0092$  $0.6500\pm 0.0049$  $0.6134\pm 0.0023$ 
ImageNet  DenseNet161  $0.9094\pm 0.0000$  $0.8655\pm 0.0000$  $0.8559\pm 0.0000$  $\mathbf{0.8303}\pm 0.0000$  $0.8359\pm 0.0000$  
ImageNet  ResNet152  $0.8716\pm 0.0000$  $0.8682\pm 0.0000$  $0.8584\pm 0.0000$  $\mathbf{0.8205}\pm 0.0000$  $0.8226\pm 0.0000$  
CIFAR10 $\to $ STL10  VGG16  $1.6528\pm 0.0390$  $1.3993\pm 0.0502$  $1.2258\pm 0.0446$  $\mathbf{1.1402}\pm 0.0342$  $1.3133\pm 0.0000$  $1.4082\pm 0.0506$  
CIFAR10 $\to $ STL10  PreResNet164  $1.4790\pm 0.0000$  $1.3552\pm 0.0000$  $1.0700\pm 0.0000$  $\mathbf{0.9706}\pm 0.0000$  $1.2228\pm 0.0000$  
CIFAR10 $\to $ STL10  WideResNet28x10  $1.1308\pm 0.0000$  $1.0047\pm 0.0000$  $0.9340\pm 0.0000$  $\mathbf{0.8710}\pm 0.0000$  $0.9914\pm 0.0000$  $0.9706\pm 0.0000$ 
Dataset  Model  SGD  SWA  SWAGDiag  SWAG  KFACLaplace  SWADropout  SWATemp 

CIFAR10  VGG16  $93.17\pm 0.14$  $93.61\pm 0.11$  $\mathbf{93.66}\pm 0.15$  $93.60\pm 0.10$  $92.65\pm 0.20$  $93.23\pm 0.36$  $93.61\pm 0.11$ 
CIFAR10  PreResNet164  $95.49\pm 0.06$  $96.09\pm 0.08$  $96.03\pm 0.10$  $96.03\pm 0.02$  $95.49\pm 0.06$  $\mathbf{96.18}\pm 0.00$  $96.09\pm 0.08$ 
CIFAR10  WideResNet28x10  $96.41\pm 0.10$  $\mathbf{96.46}\pm 0.04$  $96.41\pm 0.05$  $96.32\pm 0.08$  $96.17\pm 0.00$  $96.39\pm 0.09$  $96.46\pm 0.04$ 
CIFAR100  VGG16  $73.15\pm 0.11$  $74.30\pm 0.22$  $74.68\pm 0.22$  $\mathbf{74.77}\pm 0.09$  $72.38\pm 0.23$  $72.50\pm 0.54$  $74.30\pm 0.22$ 
CIFAR100  PreResNet164  $78.50\pm 0.32$  $\mathbf{80.19}\pm 0.52$  $80.18\pm 0.50$  $79.90\pm 0.50$  $78.51\pm 0.05$  $80.19\pm 0.52$  
CIFAR100  WideResNet28x10  $80.76\pm 0.29$  $82.40\pm 0.16$  $\mathbf{82.40}\pm 0.09$  $82.23\pm 0.19$  $80.94\pm 0.41$  $82.30\pm 0.19$  $82.40\pm 0.16$ 
ImageNet  DenseNet161  $77.79\pm 0.00$  $\mathbf{78.60}\pm 0.00$  $78.59\pm 0.00$  $78.59\pm 0.00$  $78.60\pm 0.00$  
ImageNet  ResNet152  $78.39\pm 0.00$  $78.92\pm 0.00$  $78.96\pm 0.00$  $\mathbf{79.08}\pm 0.00$  $78.92\pm 0.00$  
CIFAR10 $\to $ STL10  VGG16  $\mathbf{72.42}\pm 0.07$  $71.92\pm 0.01$  $72.09\pm 0.04$  $72.19\pm 0.06$  $71.45\pm 0.11$  $71.92\pm 0.01$  
CIFAR10 $\to $ STL10  PreResNet164  $75.56\pm 0.00$  $\mathbf{76.02}\pm 0.00$  $75.95\pm 0.00$  $75.88\pm 0.00$  $76.02\pm 0.00$  
CIFAR10 $\to $ STL10  WideResNet28x10  $76.75\pm 0.00$  $\mathbf{77.50}\pm 0.00$  $77.26\pm 0.00$  $77.09\pm 0.00$  $76.91\pm 0.00$  $77.50\pm 0.00$ 