Amortized Population Gibbs Samplers with Neural Sufficient Statistics

  • 2019-11-04 18:10:11
  • Hao Wu, Heiko Zimmermann, Eli Sennesh, Tuan Anh Le, Jan-Willem van de Meent

Abstract

We develop amortized population Gibbs (APG) samplers, a new class of autoencoding variational methods for deep probabilistic models. APG samplers construct high-dimensional proposals by iterating over updates to lower-dimensional blocks of variables. Each conditional update is a neural proposal, which we train by minimizing the inclusive KL divergence relative to the conditional posterior. To appropriately account for the size of the input data, we develop a new parameterization in terms of neural sufficient statistics, resulting in quasi-conjugate variational approximations. Experiments demonstrate that learned proposals converge to the known analytical conditional posterior in conjugate models, and that APG samplers can learn inference networks for highly-structured deep generative models when the conditional posteriors are intractable. Here APG samplers offer a path toward scaling up stochastic variational methods to models in which standard autoencoding architectures fail to produce accurate samples.

 




 


Hao Wu (Northeastern University), Heiko Zimmermann (Northeastern University), Eli Sennesh (Northeastern University)

Tuan Anh Le (Massachusetts Institute of Technology), Jan-Willem van de Meent (Northeastern University)

1 Introduction

Deep probabilistic programming libraries such as Edward [tran2016edward], Pyro [bingham2018pyro], and Probabilistic Torch [siddharth2017learning] extend deep learning frameworks with functionality for deep probabilistic models that combine a generative model with an inference model that approximates the Bayesian posterior. Both models are parameterized using neural networks, which are trained using stochastic gradient descent by optimizing a lower or upper bound on the log marginal likelihood. Training an inference network to perform amortized inference can be equivalently understood as a form of variational inference or adaptive importance sampling.

At present, deep probabilistic models most commonly have the form of standard variational autoencoders (VAEs) [kingma2013auto-encoding, rezende2014stochastic]. In these architectures, the generative model combines an unstructured prior (e.g. a spherical Gaussian) with a likelihood that is parameterized by an expressive neural network, often referred to as a decoder. The inference network, known as an encoder, maps input data (e.g. an image or sentence) onto an embedding vector, also known as the latent code.

Deep probabilistic programming aims to enable more general designs that incorporate structured priors for tasks such as multiple object detection [eslami2016attend], language modeling [esmaeili2019structured], or object tracking [kosiorek2018sequential]. In these domains, a prior can incorporate useful inductive biases, such as the requirement that object trajectories are smooth. These biases in turn can help guide a model to uncover patterns in the data in an unsupervised manner, and aid generalization in complex domains where the training data may not contain exemplars for all possible combinations of latent features.

However, training structured models also poses challenges that are not encountered in unstructured problems. To optimize a lower or an upper bound, we need to approximate the gradient of an expectation with a Monte Carlo estimate (see [mohamed2019monte] for a recent review). Standard VAEs rely on reparameterized estimators that can often approximate the gradient with a single sample. Unfortunately, these estimators can have a high variance in models where latent variables are high-dimensional and/or strongly correlated. Owing to these limitations, models that are trained using standard VAE objectives often consider relatively small-scale problems, such as tracking 2 objects over the course of 10 frames [kosiorek2018sequential], or assigning 10 sentences in a review to distinct aspects [esmaeili2019structured].

In this paper, we develop methods for amortized inference that are designed to scale to structured models with 100s of latent variables. We are particularly interested in the frequently arising cases of models that are characterized by a combination of local variables, such as the time-dependent position of an object, and global variables, such as the shape of the object. In this type of model, it is often the case that knowledge of the local variables can help us make predictions about global variables and vice versa: if we know the shape of an object, then it should be easier to identify its location in an image. Conversely, if we know the position of an object in each frame, then we can more readily infer its shape.

The methods that we develop in this paper are similar in spirit to work by Johnson et al. [johnson2016composing], who developed methods for conjugate-exponential models with a neural likelihood. In this setting, we can perform inference using variational expectation maximization (EM) algorithms [beal2003variational, bishop2006pattern, wainwright2008graphical] that exploit conjugacy and conditional independence to derive closed-form updates to blocks of variables. The advantage of these approaches is that they are highly computationally efficient; variational EM can often converge in a small number of iterations and easily scales to a much larger number of variables. Unfortunately, variational EM is also model-specific, difficult to implement, and only applicable to a restricted class of conjugate-exponential models.

To overcome the limitations imposed by conjugate-exponential family models, we here develop a more general approach. Rather than requiring exact EM updates, we develop an importance sampling method that employs conditional proposals to iterate between updates to blocks of variables. To train these proposals, we define a variational method that minimizes the inclusive KL divergence between the proposal update and the exact conditional posterior. We refer to the resulting class of methods as amortized Gibbs samplers, since the proposals approximate Gibbs updates.

The variational objective that we derive is not computable, since the exact Gibbs updates are in general intractable. However, we can nonetheless derive a Monte Carlo estimator for its gradient. Building on a recent body of work that employs importance samplers to train variational distributions [burda2016importance, le2018auto-encoding, maddison2017filtering, naesseth2018variational], we develop a sequential Monte Carlo sampler [delmoral2006sequential] that combines approximate Gibbs updates with resampling steps in order to construct high-quality proposals, which serve both to compute gradient estimates at training time and to perform inference at test time. We demonstrate correctness of the proposed sampler by proving that samples are properly weighted relative to the generative model [naesseth2015nested].

One of the challenges in designing networks that parameterize conditional proposals is that network outputs need to appropriately account for the amount of data on which we are conditioning: the conditional posterior on the mean of a cluster with a large number of points is more tightly peaked than that of a cluster with a small number of points. To address this difficulty, we propose a class of networks that we refer to as neural sufficient statistics, which define parameterizations of proposals in a manner that is additive in the local variables, much like the sufficient statistics in conjugate-exponential families.

Our experiments show that learned proposals converge to the true conditional posteriors in Gaussian mixture models, where the Gibbs updates can be computed in closed form. Moreover, we establish that amortized Gibbs methods can serve as a basis for scalable inference in structured deep generative models, including mixtures with neural likelihoods and unsupervised tracking models. Both of these tasks are representative of the current state of the art in unsupervised approaches for learning structured deep generative models.

2 Amortized Population Gibbs Samplers

We are interested in the task of jointly training a generative model $p_\theta(x, z)$ by maximizing its marginal likelihood $p_\theta(x)$ and learning an inference model $q_\phi(z \mid x)$ that approximates the posterior $p_\theta(z \mid x)$. Like most amortized inference approaches, we assume that we can sample from a (possibly implicit) distribution $\hat{p}(x)$ that either takes the form of an empirical distribution over training data or a data simulator.

As a means of generating high-quality samples in an incremental manner, we develop methods that are inspired by expectation maximization and classic Gibbs sampling strategies, which perform iterative updates to blocks of variables. Concretely, we will assume that the latent variables in the generative model decompose into blocks $z = \{z_1, \dots, z_B\}$ and train proposals $q_\phi(z_b \mid x, z_{-b})$ that update the variables in each block $z_b$ conditioned on the variables in the remaining blocks $z_{-b} = z \setminus \{z_b\}$.

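Starting with an initial sample $z^1 \sim q_\phi(z^1 \mid x)$ from a standard encoder, we will generate a sequence of samples $\{z^1, \dots, z^K\}$ by performing conditional updates to each block $z_b$, which we refer to as a sweep:

$$q_\phi(z^k \mid x, z^{k-1}) = \prod_{b=1}^{B} q_\phi\big(z^k_b \mid x, z^k_{<b}, z^{k-1}_{>b}\big), \tag{1}$$

where $z_{<b} = \{z_i \mid i < b\}$ and $z_{>b} = \{z_i \mid i > b\}$. Repeatedly applying sweep updates then yields a proposal

$$q_\phi(z^1, \dots, z^K \mid x) = q_\phi(z^1 \mid x) \prod_{k=2}^{K} q_\phi(z^k \mid x, z^{k-1}).$$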
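As an illustration of the sweep in Equation 1, here is a minimal Python sketch for a hypothetical two-block model, a standard bivariate Gaussian with correlation `rho`, in which the block proposals are taken to be the exact conditionals (the limit that learned proposals are trained toward). The names `gibbs_sweep` and `run_chain` are illustrative only and do not come from the paper.

```python
import math
import random

def gibbs_sweep(z, rho, rng):
    """One sweep (Equation 1) over B = 2 blocks of a toy bivariate Gaussian.

    Block 1 is conditioned on the previous value of block 2 (z_{>1}^{k-1});
    block 2 is then conditioned on the new value of block 1 (z_{<2}^{k}).
    Exact conditionals z_i | z_j ~ N(rho * z_j, 1 - rho^2) stand in for q_phi.
    """
    sd = math.sqrt(1.0 - rho ** 2)
    z1 = rng.gauss(rho * z[1], sd)
    z2 = rng.gauss(rho * z1, sd)
    return (z1, z2)

def run_chain(num_sweeps, rho, seed=0):
    """Return [z^1, ..., z^K], starting from a deliberately poor z^1."""
    rng = random.Random(seed)
    z = (3.0, -3.0)
    chain = [z]
    for _ in range(num_sweeps - 1):
        z = gibbs_sweep(z, rho, rng)
        chain.append(z)
    return chain

chain = run_chain(20000, rho=0.8)
kept = chain[1000:]  # discard early sweeps before averaging
print("E[z1]    ~", sum(z[0] for z in kept) / len(kept))
print("E[z1 z2] ~", sum(z[0] * z[1] for z in kept) / len(kept))
```

Because each sweep leaves the target invariant, adding more sweeps can only move the chain toward its stationary distribution; nothing in the sweep itself depends on a fixed $K$.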

We want to train proposals that improve the quality of each sample $z^k$ relative to that of the preceding sample $z^{k-1}$. There are two possible strategies for accomplishing this. One strategy is to define an objective that minimizes the discrepancy between the marginal $q_\phi(z^K \mid x)$ for the final sample and the posterior $p_\theta(z^K \mid x)$. This corresponds to learning a sweep update $q_\phi(z^k \mid x, z^{k-1})$ that transforms the initial proposal to the posterior in exactly $K$ sweeps. An example of this type of approach, albeit one that does not employ block updates, is the recent work on annealing variational objectives [huang2018improving].

In this paper, we will pursue a different approach. Instead of transforming the initial proposal in exactly K steps, we learn a sweep update that leaves the target density invariant

$$p_\theta(z^k \mid x) = \int p_\theta(z^{k-1} \mid x)\, q_\phi(z^k \mid x, z^{k-1})\, dz^{k-1}. \tag{2}$$

When this condition is met, the proposal $q_\phi(z^1, \dots, z^K \mid x)$ is a Markov chain whose stationary distribution is the posterior. This means a sweep update learned at training time can be applied at test time to iteratively improve sample quality, without requiring a pre-specified number of updates $K$.

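In addition, suppose we require that each single block update $q_\phi(z_b \mid x, z_{-b})$ also leaves the target density invariant,

$$\begin{aligned} p_\theta(z_b', z_{-b} \mid x) &= \int p_\theta(z_b, z_{-b} \mid x)\, q_\phi(z_b' \mid x, z_{-b})\, dz_b \\ &= p_\theta(z_{-b} \mid x)\, q_\phi(z_b' \mid x, z_{-b}). \end{aligned} \tag{3}$$

Since the left-hand side factorizes as $p_\theta(z_{-b} \mid x)\, p_\theta(z_b' \mid x, z_{-b})$, a block update that satisfies this condition must equal the exact conditional posterior, $q_\phi(z_b \mid x, z_{-b}) = p_\theta(z_b \mid x, z_{-b})$. In other words, when the condition in Equation 3 is met, the proposal $q_\phi(z^1, \dots, z^K \mid x)$ is a Gibbs sampler.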
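The condition in Equation 3 can be checked directly on a small discrete example. The sketch below uses a hypothetical joint over two binary blocks and compares the exact conditional with a block kernel that ignores $z_{-b}$; only the exact conditional leaves the joint invariant. The joint table and function names are made up for illustration.

```python
from itertools import product

# Hypothetical joint p(z_b, z_{-b} | x) over two binary blocks, keyed by (z_b, z_{-b}).
P = {(0, 0): 0.1, (0, 1): 0.3, (1, 0): 0.4, (1, 1): 0.2}

def exact_conditional(zb_new, z_rest):
    """q(z_b' | x, z_{-b}) = p(z_b' | x, z_{-b}), the exact Gibbs conditional."""
    marginal = sum(P[(v, z_rest)] for v in (0, 1))
    return P[(zb_new, z_rest)] / marginal

def ignores_context(zb_new, z_rest):
    """A normalized but wrong block kernel: uniform over z_b', ignoring z_{-b}."""
    return 0.5

def apply_block_kernel(kernel):
    """Right-hand side of Equation 3: sum over z_b of p(z_b, z_{-b}) q(z_b' | z_{-b})."""
    return {
        (zb_new, z_rest): sum(P[(zb, z_rest)] * kernel(zb_new, z_rest) for zb in (0, 1))
        for zb_new, z_rest in product((0, 1), repeat=2)
    }

def is_invariant(kernel, tol=1e-12):
    """True if one application of the block kernel reproduces the joint exactly."""
    updated = apply_block_kernel(kernel)
    return all(abs(updated[key] - P[key]) < tol for key in P)

print(is_invariant(exact_conditional), is_invariant(ignores_context))
```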

2.1 Variational Objective

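To learn each of the block proposals $q_\phi(z_b \mid x, z_{-b})$, we will minimize the inclusive KL divergence

$$\mathcal{K}_b(\phi) = \mathbb{E}_{\hat{p}(x)\, p_\theta(z_{-b} \mid x)}\Big[\mathrm{KL}\big(p_\theta(z_b \mid x, z_{-b}) \,\Vert\, q_\phi(z_b \mid x, z_{-b})\big)\Big]. \tag{4}$$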

Unfortunately, this objective is intractable, since we are not able to evaluate the density of the true marginal $p_\theta(z_{-b} \mid x)$, nor that of the conditional $p_\theta(z_b \mid z_{-b}, x)$. As we will discuss in Section 5, this has implications for the evaluation of learned proposals, since we cannot compute a lower or upper bound on the log marginal likelihood as in other variational methods. However, it is nonetheless possible to approximate the gradient of the objective, since the term $\log p_\theta(z_b \mid x, z_{-b})$ inside the KL divergence does not depend on $\phi$:

$$-\nabla_\phi \mathcal{K}_b(\phi) = \mathbb{E}_{\hat{p}(x)\, p_\theta(z_b, z_{-b} \mid x)}\big[\nabla_\phi \log q_\phi(z_b \mid x, z_{-b})\big].$$

We can estimate this gradient using any Monte Carlo method that generates samples $z \sim p_\theta(z \mid x)$ from the posterior. In the next section, we will use the learned proposals to define an importance sampler, which we then use to compute a self-normalized estimator of the gradient from weighted samples $\{(w^l, z^l)\}_{l=1}^{L}$,

$$-\nabla_\phi \mathcal{K}_b(\phi) \simeq \sum_{l=1}^{L} \frac{w^l}{\sum_{l'} w^{l'}}\, \nabla_\phi \log q_\phi(z_b^l \mid x, z_{-b}^l). \tag{5}$$
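Equation 5 only needs per-sample weights and per-sample score terms. Below is a minimal Python sketch of the self-normalized estimator, computed from log weights for numerical stability. The toy target $\mathcal{N}(2, 1)$ and the Gaussian proposal $\mathcal{N}(m, s^2)$ with a single parameter `m` are assumptions chosen so that the exact value of the gradient, $(2 - m)/s^2$, is known; in an APG sampler the weights would instead come from the sampler described in the next section.

```python
import math
import random

def self_normalized_estimate(log_weights, values):
    """Return sum_l (w^l / sum_l' w^l') * values[l], given log w^l.

    Subtracting the maximum log weight before exponentiating avoids overflow
    and cancels in the ratio, so the estimate is unchanged.
    """
    m = max(log_weights)
    w = [math.exp(lw - m) for lw in log_weights]
    return sum(wi * vi for wi, vi in zip(w, values)) / sum(w)

def log_normal_pdf(z, mean, sd):
    return -0.5 * ((z - mean) / sd) ** 2 - math.log(sd) - 0.5 * math.log(2.0 * math.pi)

def toy_gradient_estimate(m=0.0, s=2.0, num_samples=20000, seed=0):
    """Estimate E_p[d/dm log q_m(z)] for p = N(2, 1) and q_m = N(m, s^2) (toy setting).

    Samples come from q_m and are reweighted toward p, whose normalizing
    constant is deliberately dropped; the exact value is (2 - m) / s**2.
    """
    rng = random.Random(seed)
    zs = [rng.gauss(m, s) for _ in range(num_samples)]
    log_w = [-0.5 * (z - 2.0) ** 2 - log_normal_pdf(z, m, s) for z in zs]
    scores = [(z - m) / s ** 2 for z in zs]  # d/dm log N(z; m, s^2)
    return self_normalized_estimate(log_w, scores)

print(toy_gradient_estimate())
```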

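In problems where we would like to learn a deep generative model $p_\theta(x, z)$, we can apply a similar self-normalized gradient estimator of the form

$$\begin{aligned} \nabla_\theta \log p_\theta(x) &= \mathbb{E}_{p_\theta(z \mid x)}\big[\nabla_\theta \log p_\theta(x, z)\big] \\ &\simeq \sum_{l=1}^{L} \frac{w^l}{\sum_{l'} w^{l'}}\, \nabla_\theta \log p_\theta(x, z^l). \end{aligned} \tag{6}$$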

This identity holds due to the standard property $\mathbb{E}_{p_\theta(z \mid x)}[\nabla_\theta \log p_\theta(z \mid x)] = 0$ (see Appendix A for details).
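Spelled out, the identity follows from the factorization $p_\theta(x, z) = p_\theta(x)\, p_\theta(z \mid x)$, assuming that differentiation and integration can be interchanged:

$$
\begin{aligned}
\mathbb{E}_{p_\theta(z \mid x)}\big[\nabla_\theta \log p_\theta(x, z)\big]
&= \nabla_\theta \log p_\theta(x) + \mathbb{E}_{p_\theta(z \mid x)}\big[\nabla_\theta \log p_\theta(z \mid x)\big], \\
\mathbb{E}_{p_\theta(z \mid x)}\big[\nabla_\theta \log p_\theta(z \mid x)\big]
&= \int \nabla_\theta p_\theta(z \mid x)\, dz
= \nabla_\theta \int p_\theta(z \mid x)\, dz
= \nabla_\theta 1 = 0.
\end{aligned}
$$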

The estimator in Equation 5 is similar to the self-normalized estimator in reweighted wake-sleep methods [bornschein2014reweighted], which also minimizes an inclusive KL divergence. This estimator has a number of advantages over the estimator that is commonly used to train standard VAEs, which minimize an exclusive KL divergence [le2019revisiting]. Standard VAE objectives rely on reparameterization to compute gradient estimates. For discrete variables, reparameterization is not possible. This means that we need to compute likelihood-ratio estimators (also known as REINFORCE-style estimators [williams1992simple]), which can have very high variance. A range of approaches for variance reduction have been put forward, including continuous relaxations that are amenable to reparameterization [maddison2017concrete, jang2017categorical], credit assignment techniques (see [weber2019credit] for a review), and other control variates [mnih2016variational, tucker2017rebar, grathwohl2018backpropagation].

The estimator in Equation 5 sidesteps the need for these variance reduction techniques. To compute this gradient, we only require that the proposal density is differentiable, whereas reparameterized estimators require that the sample itself is differentiable. This is a milder condition that holds for most distributions of interest, including those over discrete variables. Moreover, since this estimator minimizes the inclusive KL divergence, and not the exclusive KL divergence, there is a smaller risk of learning a proposal that collapses to a single mode of a multi-modal posterior [le2019revisiting].

2.2 Generating High Quality Samples

Approximating the gradient presents a chicken-and-egg problem; we need samples from the posterior to compute a Monte Carlo estimate of the gradient, but generating these samples is precisely what we are hoping to use learned proposals for in the first place. Moreover, self-normalized importance samplers are consistent, but they are not unbiased. In the early stages of training, we will have poor quality proposals, which means that the bias of the gradient estimators in Equations 5 and 6 can be high.

Standard reweighted wake-sleep methods generate proposals from an encoder $z \sim q_\phi(z \mid x)$ and compute weights $w = p_\theta(x, z) / q_\phi(z \mid x)$. A well-known limitation of this type of naive importance sampling strategy is that the computed weights will have a very high variance in models with high-dimensional and/or correlated latent variables, which in turn implies a high bias of the estimator. There is a very broad class of importance sampling strategies that can be employed to reduce the variance of importance weights. If we replace the naive importance sampler in reweighted wake-sleep with a more sophisticated sampling strategy, then this improves both the quality of gradient estimates during training and the quality of inference results at test time.
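To see how naive importance weights degrade with dimension, the sketch below uses a toy setting assumed purely for illustration (not an experiment from this paper): a proposal $\mathcal{N}(0, I)$ for a target $\mathcal{N}(0.5 \cdot \mathbf{1}, I)$ in $d$ dimensions. It reports the standard normalized effective sample size $(\sum_l w^l)^2 / (L \sum_l (w^l)^2)$; for this pair of Gaussians its large-$L$ value is $e^{-d/4}$, so it collapses quickly as $d$ grows.

```python
import math
import random

def normalized_ess(log_w):
    """ESS / L = (sum w)^2 / (L * sum w^2), computed stably from log weights."""
    m = max(log_w)
    w = [math.exp(lw - m) for lw in log_w]
    return sum(w) ** 2 / (len(w) * sum(wi * wi for wi in w))

def naive_is_ess(dim, num_samples=5000, shift=0.5, seed=0):
    """Normalized ESS of naive importance sampling in a toy Gaussian setting."""
    rng = random.Random(seed)
    log_w = []
    for _ in range(num_samples):
        z = [rng.gauss(0.0, 1.0) for _ in range(dim)]  # z ~ q = N(0, I)
        # log p(z) - log q(z) for p = N(shift * 1, I): sum_d (shift * z_d - shift^2 / 2)
        log_w.append(sum(shift * zd - 0.5 * shift ** 2 for zd in z))
    return normalized_ess(log_w)

for d in (1, 5, 20):
    print(d, round(naive_is_ess(d), 3))
```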

To improve upon standard reweighted wake-sleep methods, we will use the learned proposals to define a sequential Monte Carlo (SMC) sampler [delmoral2006sequential]. SMC methods [doucet2001sequential] combine two basic ideas. The first is sequential importance sampling, which decomposes a proposal for a sequence of variables into a sequence of conditional proposals. The second is resampling, which selects partial proposals with probability proportional to their weights in order to improve the overall sample quality. Most commonly, SMC methods are used in the context of state space models to generate proposals for a sequence of variables by proposing one variable at a time. SMC samplers (see the SMC sampler algorithm below) are a subclass of SMC methods that interleave resampling with the application of a transition kernel, which is sometimes also referred to as resample-move SMC.

The distinction between SMC methods for state space models and SMC samplers is subtle but important. Whereas the former generate proposals for a sequence of variables $z_{1:t}$ by proposing $z_t \sim q(z_t \mid z_{1:t-1})$ to extend the sample space at each iteration, SMC samplers can be understood as an importance sampling analogue to Markov chain Monte Carlo (MCMC) methods [brooks2011handbook], which construct a Markov chain $z^{1:k}$ by generating a proposal $z^k \sim q(z^k \mid z^{k-1})$ from a transition kernel at each iteration.

Sequential Importance Sampling. To understand how approximate Gibbs proposals can be used in an SMC sampler, we will first explain how they can be used to define a sequential importance sampler (SIS), which decomposes the importance weight into a sequence of incremental weights. In general, SIS considers a sequence of unnormalized target densities $\gamma_1(z^1), \gamma_2(z^{1:2}), \dots, \gamma_K(z^{1:K})$. If we now consider an initial proposal $q_1(z^1)$, along with a sequence of conditional proposals $q_k(z^k \mid z^{1:k-1})$, then we can recursively construct a sequence of weights $w^k = v^k w^{k-1}$ by assuming $w^1 = \gamma_1(z^1) / q_1(z^1)$ and defining the incremental weight

$$v^k = \frac{\gamma_k(z^{1:k})}{\gamma_{k-1}(z^{1:k-1})\, q_k(z^k \mid z^{1:k-1})}.$$

This construction ensures that, at step $k$ in the sequence, we have a weight $w^k$ relative to the intermediate density $\gamma_k(z^{1:k})$ of the form (see Appendix B)

$$w^k = \frac{\gamma_k(z^{1:k})}{q_1(z^1) \prod_{k'=2}^{k} q_{k'}(z^{k'} \mid z^{1:k'-1})}.$$
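The closed form follows by unrolling the recursion $w^k = v^k w^{k-1}$, since the ratios of consecutive intermediate densities telescope:

$$
w^k = \frac{\gamma_1(z^1)}{q_1(z^1)} \prod_{k'=2}^{k} \frac{\gamma_{k'}(z^{1:k'})}{\gamma_{k'-1}(z^{1:k'-1})\, q_{k'}(z^{k'} \mid z^{1:k'-1})}
= \frac{\gamma_k(z^{1:k})}{q_1(z^1) \prod_{k'=2}^{k} q_{k'}(z^{k'} \mid z^{1:k'-1})}.
$$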
Algorithm: SMC sampler
for l = 1 to L:
    $z^{1,l} \sim q_1(\cdot)$    (propose)
    $w^{1,l} = \gamma_1(z^{1,l}) / q_1(z^{1,l})$    (weigh)
for k = 2 to K:
    $z^{k-1,1:L}, w^{k-1,1:L} = \mathrm{resample}(z^{k-1,1:L}, w^{k-1,1:L})$
    for l = 1 to L:
        $z^{k,l} \sim q_k(\cdot \mid z^{k-1,l})$    (propose)
        $w^{k,l} = \frac{\gamma_k(z^{k,l})\, r_{k-1}(z^{k-1,l} \mid z^{k,l})}{\gamma_{k-1}(z^{k-1,l})\, q_k(z^{k,l} \mid z^{k-1,l})}\, w^{k-1,l}$    (weigh)

We will now consider a specific sequence of intermediate densities that are defined using a reverse kernel $r(z^{k-1} \mid z^k)$:

$$\gamma_k(z^{1:k}) = p_\theta(x, z^k) \prod_{k'=2}^{k} r(z^{k'-1} \mid z^{k'}).$$

This defines a density on an extended space such that

$$\gamma_k(z^k) = \int \gamma_k(z^{1:k})\, dz^{1:k-1} = p_\theta(x, z^k).$$

This means that at each step $k$, we can treat the preceding samples $z^{1:k-1}$ as auxiliary variables; if we generate a proposal $z^{1:k}$ and simply disregard $z^{1:k-1}$, then the pair $(w^k, z^k)$ is a valid importance sample relative to $p_\theta(z^k \mid x)$. If we additionally condition proposals on $x$, the incremental weight for this particular choice of target densities is

$$v^k = \frac{p_\theta(x, z^k)\, r(z^{k-1} \mid z^k)}{p_\theta(x, z^{k-1})\, q(z^k \mid z^{k-1})}. \tag{7}$$

This construction defines a valid importance sampler for any choice of proposal kernel $q(z^k \mid z^{k-1})$ and reverse kernel $r(z^{k-1} \mid z^k)$. For a given choice of proposal, the optimal reverse kernel is

$$r(z^{k-1} \mid z^k) = \frac{p_\theta(x, z^{k-1})\, q(z^k \mid z^{k-1})}{p_\theta(x, z^k)}.$$

For this choice of kernel, the incremental weights are 1, which minimizes the variance of the weights $w^k$.
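A small discrete example makes Equation 7 and the optimal reverse kernel concrete. In this hypothetical sketch, the forward kernel is an independence kernel that proposes directly from the posterior, so it leaves the target invariant; under that condition the optimal $r$ is a normalized density and every incremental weight equals 1. A uniform forward and reverse kernel, which is valid but suboptimal, is included for contrast. The three-state target is made up for illustration.

```python
# Hypothetical unnormalized target p_theta(x, z) over z in {0, 1, 2}, with x held fixed.
GAMMA = [2.0, 1.0, 5.0]
POST = [g / sum(GAMMA) for g in GAMMA]  # p_theta(z | x)
STATES = range(len(GAMMA))

def independence_kernel():
    """q[z_old][z_new] = p_theta(z_new | x); this kernel leaves the posterior invariant."""
    return [[POST[zn] for zn in STATES] for _ in STATES]

def optimal_reverse(q):
    """r[z_new][z_old] = p(x, z_old) q(z_new | z_old) / p(x, z_new)."""
    return [[GAMMA[zo] * q[zo][zn] / GAMMA[zn] for zo in STATES] for zn in STATES]

def incremental_weight(q, r, z_old, z_new):
    """Equation 7: v = p(x, z_new) r(z_old | z_new) / (p(x, z_old) q(z_new | z_old))."""
    return GAMMA[z_new] * r[z_new][z_old] / (GAMMA[z_old] * q[z_old][z_new])

q = independence_kernel()
r = optimal_reverse(q)
print([round(sum(row), 12) for row in r])  # row sums of r(. | z_new)
uniform = [[1.0 / 3] * 3 for _ in STATES]
print(incremental_weight(uniform, uniform, 0, 2))  # suboptimal pair, so v differs from 1
```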

We will now use the approximate Gibbs kernel from Equation 1 as both the forward and the reverse kernel

$$q(z^k \mid z^{k-1}) = q_\phi(z^k \mid x, z^{k-1}), \qquad r(z^{k-1} \mid z^k) = q_\phi(z^{k-1} \mid x, z^k). \tag{8}$$

When the approximate Gibbs kernel converges to the actual Gibbs kernel, this choice becomes optimal, since the kernel will satisfy detailed balance in this limit:

$$p_\theta(x, z^k)\, q_\phi(z^{k-1} \mid x, z^k) = p_\theta(x, z^{k-1})\, q_\phi(z^k \mid x, z^{k-1}).$$

Resampling. In general, the weights $w^k$ in the sequential importance sampling scheme defined above will have a high variance. The weights $w^1$ are just standard importance sampling weights, which themselves will have a high variance when $z$ is high-dimensional or when there are correlations between variables. Moreover, we are now sampling these same variables $k$ times. When the approximate Gibbs kernel converges to the true kernel, this will not increase the variance of the weights (since $v^k = 1$ in this limit), but during training the variance of the weights $w^k$ will increase with $k$, since we are now jointly sampling an entire Markov chain.

To overcome this problem, SMC samplers interleave application of the transition kernel with a resampling step. This step generates a new set of samples by selecting current samples with replacement, with probability proportional to their weight. Concretely, suppose that we have a set of incoming samples $\{(w^{k,l}, z^{k,l})\}_{l=1}^{L}$. The resampling procedure (see the resample algorithm below) selects an index $a$ with probability $P(a = l) = w^{k,l} / \sum_{l'} w^{k,l'}$ and returns an outgoing sample $z'^{k,l} = z^{k,a}$ whose weight $w'^{k,l} = \frac{1}{L} \sum_{l'} w^{k,l'}$ is equal to the average weight. This reduces the variance of the importance weights at the expense of also reducing the diversity of the sample set; high-weight samples are selected frequently, whereas low-weight samples are selected infrequently or not at all.

Algorithm: resample
Input: $z^{1:L}, w^{1:L}$
Output: $z'^{1:L}, w'^{1:L}$
for i = 1 to L:
    $a^i \sim \mathrm{Discrete}\big(\{w^l / \sum_{l'=1}^{L} w^{l'}\}_{l=1}^{L}\big)$    (index selection)
    set $z'^i = z^{a^i}$
    set $w'^i = \frac{1}{L} \sum_{l=1}^{L} w^l$    (re-weigh)
return $z'^{1:L}, w'^{1:L}$
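The resampling step described above can be sketched in a few lines of Python, using the standard library's `random.Random.choices` to draw ancestor indices in proportion to the weights (multinomial resampling). This is a minimal sketch; the function name `resample` mirrors the pseudocode but is otherwise illustrative.

```python
import random

def resample(particles, weights, rng):
    """Multinomial resampling of L weighted particles.

    Ancestor indices are drawn with probability proportional to the incoming
    weights, and every outgoing particle receives the average incoming weight,
    so the total weight is preserved.
    """
    L = len(particles)
    ancestors = rng.choices(range(L), weights=weights, k=L)  # index selection
    average = sum(weights) / L                               # re-weigh
    return [particles[a] for a in ancestors], [average] * L

rng = random.Random(0)
print(resample(["a", "b", "c"], [1.0, 2.0, 3.0], rng))
```

Keeping the average weight, rather than resetting every weight to 1, preserves the sum of the weights, which matters if the weights are also used to estimate the marginal likelihood.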

When we perform resampling after each sweep, we reduce the variance of importance weights to an extent. However, we will likely still have high-variance weights, since each sample from the approximate Gibbs kernel still constitutes a high-dimensional proposal over all variables in the model. To further reduce the variance, we will employ resampling after each block update, rather than after each sweep. Because the incoming weights are now equal at each block update, we can compute gradient estimates using incremental weights $v$ of the form

$$v = \frac{p_\theta(x, z_b', z_{-b})\, q_\phi(z_b \mid x, z_{-b})}{p_\theta(x, z_b, z_{-b})\, q_\phi(z_b' \mid x, z_{-b})}, \tag{9}$$

where $z_b' \sim q_\phi(\cdot \mid x, z_{-b})$ is the newly proposed value of block $b$ and $z_b$ is its previous value.

These incremental weights will have a much lower variance than the incremental weights for a full sweep, since we are now able to decompose a sampling problem for all the variables in a model into sampling problems for individual blocks. In models with many latent variables, such as the ones that we will consider in our experiments, this has the potential to greatly increase the tractability of the gradient estimation problem.
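For a concrete check of Equation 9, the sketch below computes the log incremental weight of a single block update in a hypothetical two-block model (a standard bivariate Gaussian with correlation 0.8) whose block proposal is $\mathcal{N}(s\,\rho\,z_2, 1 - \rho^2)$ for a made-up shrinkage factor $s$. With $s = 1$ the proposal is the exact conditional and the weight is 1 (log weight 0); with $s \neq 1$ it is not.

```python
import math

RHO = 0.8  # hypothetical correlation of a standard bivariate Gaussian target

def log_normal_pdf(z, mean, sd):
    return -0.5 * ((z - mean) / sd) ** 2 - math.log(sd) - 0.5 * math.log(2.0 * math.pi)

def log_joint(z1, z2):
    """log p(x, z_1, z_2) up to an additive constant (x is fixed and omitted)."""
    return -(z1 ** 2 - 2.0 * RHO * z1 * z2 + z2 ** 2) / (2.0 * (1.0 - RHO ** 2))

def log_q_block(z1, z2, shrink):
    """Hypothetical block proposal q(z_1 | x, z_2) = N(shrink * RHO * z_2, 1 - RHO^2)."""
    return log_normal_pdf(z1, shrink * RHO * z2, math.sqrt(1.0 - RHO ** 2))

def log_block_weight(z1_old, z1_new, z2, shrink=1.0):
    """log v from Equation 9 for an update of block z_1 with z_{-b} = z_2 held fixed.

    The new value enters the target term in the numerator, and the proposal
    density evaluated at the old value plays the role of the reverse kernel.
    """
    return (log_joint(z1_new, z2) + log_q_block(z1_old, z2, shrink)
            - log_joint(z1_old, z2) - log_q_block(z1_new, z2, shrink))

print(log_block_weight(0.3, -1.2, 0.7), log_block_weight(0.3, -1.2, 0.7, shrink=0.5))
```

Because only one block changes, the weight involves a one-dimensional proposal density ratio rather than a density over all latent variables, which is where the variance reduction comes from.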

We refer to this implementation of an SMC sampler as an amortized population Gibbs (APG) sampler, and summarize all the steps of the computation in the Amortized Population Gibbs Sampling algorithm below. In Appendix D, we prove that this algorithm is correct using an argument based on proper weighting. More informally, this property holds because this sampler is a specific instance of an SMC sampler.

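Algorithm: Amortized Population Gibbs Sampling
$g_\phi = 0,\ g_\theta = 0$    (initialize gradients to 0)
for l = 1 to L:    (initial proposal)
    $z^{1,l} \sim q_\phi(\cdot \mid x)$    (propose)
    $w^{1,l} = p_\theta(x, z^{1,l}) / q_\phi(z^{1,l} \mid x)$    (weigh)
$g_\phi = g_\phi + \sum_{l=1}^{L} \frac{w^{1,l}}{\sum_{l'} w^{1,l'}} \nabla_\phi \log q_\phi(z^{1,l} \mid x)$
$g_\theta = g_\theta + \sum_{l=1}^{L} \frac{w^{1,l}}{\sum_{l'} w^{1,l'}} \nabla_\theta \log p_\theta(x, z^{1,l})$
for k = 2 to K:    (Gibbs sweeps)
    $\tilde{z}^{1:L}, \tilde{w}^{1:L} = z^{k-1,1:L}, w^{k-1,1:L}$
    for b = 1 to B:    (block updates)
        $\tilde{z}^{1:L}, \tilde{w}^{1:L} = \mathrm{resample}(\tilde{z}^{1:L}, \tilde{w}^{1:L})$
        for l = 1 to L:
            $\tilde{z}'^{\,l}_b \sim q_\phi(\cdot \mid x, \tilde{z}^l_{-b})$    (propose)
            $\tilde{w}^l = \frac{p_\theta(x, \tilde{z}'^{\,l}_b, \tilde{z}^l_{-b})\, q_\phi(\tilde{z}^l_b \mid x, \tilde{z}^l_{-b})}{p_\theta(x, \tilde{z}^l_b, \tilde{z}^l_{-b})\, q_\phi(\tilde{z}'^{\,l}_b \mid x, \tilde{z}^l_{-b})}\, \tilde{w}^l$    (weigh)
            $\tilde{z}^l_b = \tilde{z}'^{\,l}_b$    (reassign)
        $g_\phi = g_\phi + \sum_{l=1}^{L} \frac{\tilde{w}^l}{\sum_{l'} \tilde{w}^{l'}} \nabla_\phi \log q_\phi(\tilde{z}^l_b \mid x, \tilde{z}^l_{-b})$
        $g_\theta = g_\theta + \sum_{l=1}^{L} \frac{\tilde{w}^l}{\sum_{l'} \tilde{w}^{l'}} \nabla_\theta \log p_\theta(x, \tilde{z}^l)$
    $z^{k,1:L}, w^{k,1:L} = \tilde{z}^{1:L}, \tilde{w}^{1:L}$
return $g_\phi, g_\theta$    (output: gradient)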
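The sampling part of the algorithm (everything except the gradient accumulation into $g_\phi$ and $g_\theta$) can be sketched in plain Python. This is an illustrative sketch under strong assumptions: the target is a hypothetical standard bivariate Gaussian with correlation 0.8, the initial proposal is an independent $\mathcal{N}(0, 2^2)$ per coordinate, and the block proposals are hand-specified Gaussians with a made-up `shrink` factor rather than learned networks. Weights are kept in log space, and particles are resampled before every block update.

```python
import math
import random

RHO = 0.8                       # hypothetical target: standard bivariate normal
SD = math.sqrt(1.0 - RHO ** 2)  # conditional std of one block given the other

def log_normal_pdf(z, mean, sd):
    return -0.5 * ((z - mean) / sd) ** 2 - math.log(sd) - 0.5 * math.log(2.0 * math.pi)

def log_joint(z):
    """log p(x, z) up to an additive constant (x is fixed and omitted)."""
    z1, z2 = z
    return -(z1 ** 2 - 2.0 * RHO * z1 * z2 + z2 ** 2) / (2.0 * (1.0 - RHO ** 2))

def resample(particles, log_w, rng):
    """Multinomial resampling; every outgoing particle gets the average weight."""
    m = max(log_w)
    w = [math.exp(lw - m) for lw in log_w]
    idx = rng.choices(range(len(particles)), weights=w, k=len(particles))
    log_avg = m + math.log(sum(w) / len(w))
    return [list(particles[i]) for i in idx], [log_avg] * len(particles)

def apg_sample(num_particles=2000, num_sweeps=5, shrink=0.9, seed=0):
    """Return particles and log weights after K sweeps of block updates."""
    rng = random.Random(seed)
    particles, log_w = [], []
    for _ in range(num_particles):  # initial proposal q(z | x), then weigh
        z = [rng.gauss(0.0, 2.0), rng.gauss(0.0, 2.0)]
        particles.append(z)
        log_w.append(log_joint(z) - log_normal_pdf(z[0], 0.0, 2.0)
                     - log_normal_pdf(z[1], 0.0, 2.0))
    for _ in range(num_sweeps - 1):  # Gibbs sweeps k = 2, ..., K
        for b in (0, 1):             # block updates
            particles, log_w = resample(particles, log_w, rng)
            for l, z in enumerate(particles):
                mu = shrink * RHO * z[1 - b]  # hypothetical q(z_b | x, z_{-b})
                z_new = list(z)
                z_new[b] = rng.gauss(mu, SD)  # propose z_b'
                # Equation 9 in log space; z_{-b} is unchanged, so mu is shared
                log_w[l] += (log_joint(z_new) + log_normal_pdf(z[b], mu, SD)
                             - log_joint(z) - log_normal_pdf(z_new[b], mu, SD))
                particles[l] = z_new          # reassign
    return particles, log_w

def weighted_mean(values, log_w):
    """Self-normalized estimate sum_l (w^l / sum_l' w^l') * values[l]."""
    m = max(log_w)
    w = [math.exp(lw - m) for lw in log_w]
    return sum(wi * v for wi, v in zip(w, values)) / sum(w)

particles, log_w = apg_sample()
print("E[z1 z2] ~", weighted_mean([z[0] * z[1] for z in particles], log_w))
```

Replacing the hand-specified block means with trained networks, and adding the weighted score terms after each block update, would be needed to recover the full training procedure; that extension is not shown here.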

3 Neural Sufficient Statistics

Gibbs sampling strategies that sample from exact conditionals rely on conjugacy relationships. Typically, we assume a prior and likelihood that can both be expressed as exponential families

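$$\begin{aligned} p(x \mid z) &= h(x) \exp\{\eta(z)^\top T(x) - \log A(\eta(z))\}, \\ p(z) &= h(z) \exp\{\lambda^\top T(z) - \log A(\lambda)\}. \end{aligned}$$

In these densities $h(\cdot)$ is a base measure, $T(\cdot)$ is a vector of sufficient statistics, and $\log A(\cdot)$ is a log normalizer. The two densities are jointly conjugate when

$$T(z) = \big(\eta(z), -\log A(\eta(z))\big).$$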

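In this case, writing $\lambda = (\lambda_1, \lambda_2)$ to match the two components of $T(z)$, the posterior distribution lies in the same exponential family as the prior

$$p(z \mid x) \propto h(z) \exp\big\{(\lambda_1 + T(x))^\top \eta(z) - (\lambda_2 + 1) \log A(\eta(z))\big\}.$$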
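As a standard worked instance of this kind of conjugacy (a textbook example, not one taken from the paper), consider a Gaussian likelihood with known variance and a Gaussian prior on its mean. In natural parameters $(\mu/\sigma^2, -1/(2\sigma^2))$, the posterior is obtained by adding one statistic per observation to the prior's natural parameters, which is the additive structure that the neural sufficient statistics below imitate.

```python
def to_natural(mean, var):
    """N(mean, var) -> natural parameters (mean / var, -1 / (2 var))."""
    return mean / var, -0.5 / var

def from_natural(eta1, eta2):
    """Natural parameters -> (mean, var); requires eta2 < 0."""
    var = -0.5 / eta2
    return eta1 * var, var

def conjugate_posterior(prior_mean, prior_var, xs, noise_var):
    """Posterior over z for x_n ~ N(z, noise_var) and z ~ N(prior_mean, prior_var).

    Each observation contributes the statistic (x_n / noise_var, -1 / (2 noise_var)),
    and the posterior natural parameters are the prior ones plus their sum.
    """
    eta1, eta2 = to_natural(prior_mean, prior_var)
    for x in xs:
        eta1 += x / noise_var
        eta2 += -0.5 / noise_var
    return from_natural(eta1, eta2)

print(conjugate_posterior(0.0, 4.0, [1.0, 2.0, 3.0], noise_var=1.0))
```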

Typically, the prior $p(z \mid \lambda)$ and likelihood $p(x \mid z)$ are not jointly conjugate, but it is possible to identify conjugacy relationships at the level of individual blocks of variables,

$$p(z_b \mid z_{-b}, x) \propto h(z_b) \exp\big\{(\lambda_{b,1} + T(x, z_{-b}))^\top \eta(z_b) - (\lambda_{b,2} + 1) \log A(\eta(z_b))\big\}.$$

In the more general setting we consider here, these conjugacy relationships will typically not hold. However, we can still take inspiration from them to design variational distributions that make use of conditional independencies in a model. We will assume that each of the approximate Gibbs updates $q_\phi(z_b \mid x, z_{-b})$ is an exponential family distribution whose parameters are computed from a vector of prior parameters $\lambda$ and a vector of neural sufficient statistics $T_\phi(x, z_{-b})$:

$$q_\phi(z_b \mid x, z_{-b}) = p\big(z_b \mid \lambda + T_\phi(x, z_{-b})\big). \tag{10}$$

This parameterization has a number of desirable properties. Exponential families are the maximum-entropy distributions that match the moments defined by the sufficient statistics (see e.g. [wainwright2008graphical]), which is helpful when minimizing the inclusive KL divergence. In exponential families it is also more straightforward to control the entropy of the variational distribution. In particular, we can initialize $T_\phi(x, z_{-b})$ to output values close to zero in order to ensure that we initially propose from the prior, and/or regularize $T_\phi(x, z_{-b})$ to help avoid local optima.

A particularly useful case arises in models where the data $x = \{x_1, \dots, x_N\}$ are independent conditioned on $z$. In these models it is often possible to partition the latent variables $z = \{z^g, z^l\}$ into global variables $z^g$ and local variables $z^l$. The dimensionality of global variables is typically constant, whereas local variables $z^l = \{z^l_1, \dots, z^l_N\}$ have a dimensionality that grows with the number of data points $N$. For models with this structure, the local variables are typically conditionally independent, $z^l_n \perp z^l_{-n} \mid x, z^g$, which means that we can parameterize the sufficient statistics as

$$\tilde{\lambda}^g = \lambda^g + \sum_{n=1}^{N} T^g_\phi(x_n, z^l_n), \qquad \tilde{\lambda}^l_n = \lambda^l_n + T^l_\phi(x_n, z^g).$$

The advantage of this parameterization is that it allows us to train approximate Gibbs updates for global variables in a manner that scales dynamically with the size of the dataset, and appropriately adjusts the posterior variance according to the amount of available data.
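The additive parameterization can be sketched for the global cluster means of a hypothetical one-dimensional mixture. Here a fixed function with a single scale `phi` stands in for the network $T^g_\phi$ (a real implementation would learn a neural network), and each cluster's natural parameters are the prior ones plus a sum of per-point contributions. Because the sum runs over the points assigned to a cluster, the proposal variance shrinks as the cluster gains points, and setting `phi = 0` (statistics equal to zero) recovers the prior.

```python
def stats_global(x_n, c_n, num_clusters, phi):
    """Hypothetical stand-in for T^g_phi(x_n, z^l_n) in a 1-D mixture.

    A real model would compute this with a neural network. Here phi is a single
    nonnegative scale, and the output adds a Gaussian natural-parameter increment
    (phi * x_n, -phi / 2) only to the cluster c_n that point n is assigned to.
    """
    stats = [[0.0, 0.0] for _ in range(num_clusters)]
    stats[c_n] = [phi * x_n, -0.5 * phi]
    return stats

def global_proposal_params(prior, xs, cs, phi):
    """lambda~^g = lambda^g + sum_n T^g_phi(x_n, z^l_n); returns (mean, var) per cluster."""
    lam = [list(p) for p in prior]
    for x_n, c_n in zip(xs, cs):
        for k, (d1, d2) in enumerate(stats_global(x_n, c_n, len(prior), phi)):
            lam[k][0] += d1
            lam[k][1] += d2
    return [(eta1 * (-0.5 / eta2), -0.5 / eta2) for eta1, eta2 in lam]

prior = [(0.0, -0.05), (0.0, -0.05)]  # N(0, 10) prior on each cluster mean, natural form
xs = [1.0] + [2.0] * 50               # cluster 0 gets one point, cluster 1 gets fifty
cs = [0] + [1] * 50
print(global_proposal_params(prior, xs, cs, phi=1.0))
```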

Figure 1: Samples from the GMM and the DGMM. (a) GMM: the left column shows 5 test datasets with different numbers of data points; the subsequent columns show inference results by RWS, followed by results after 4, 8, and 12 APG updates. (b) DGMM: the left column shows 5 test datasets with different numbers of data points; the subsequent columns show inference results by RWS, followed by results after 3 and 6 APG updates. The right column shows reconstructions from the learned generative model.

4 Related Work

Our work fits into a line of recent methods for deep generative modeling that seek to improve inference quality, either by introducing auxiliary variables [maaloe2016auxiliary, ranganath2016hierarchical], or by performing iterative updates [marino2018iterative]. Our specific approach to learning block proposals is related to a number of methods that, in some way or other, combine transition kernels with variational inference. Work by Hoffman [hoffman2017learning] applies Hamiltonian Monte Carlo to samples that are generated from the encoder, which serves to improve the gradient estimate w.r.t. θ (Equation 6), while learning the inference network using a standard reparameterized ELBO objective. Li et al. [li2017approximate] similarly use MCMC to improve the quality of samples that are generated by an encoder, but additionally use these samples to train the encoder by minimizing the inclusive KL divergence relative to the filtering distribution of the Markov chain. As in our work, the filtering distribution after multiple MCMC steps is intractable. Li et al. therefore use an adversarial objective to minimize the inclusive KL. Neither of these two lines of work considers a block decomposition of the latent variable space, and neither learns transition kernels.

Work by Salimans et al. [salimans2015markov] uses transition kernels in variational inference. The authors use an importance weight to define a (stochastic) lower bound, which is defined using a forward and reverse kernel in the same manner as in Equation 7. Huang et al. [huang2018improving] extend the work by Salimans et al. by learning a sequence of transition kernels that performs annealing from the initial encoder to the posterior. Since both of these methods minimize an exclusive KL, rather than an inclusive KL, gradient estimates must be computed using reparameterization, which means that these methods are not applicable to models that contain discrete variables. Moreover, these methods perform a joint update on all variables at each iteration, and do not consider the task of learning conditional proposals as we do here.

Work by Wang et al. [wang2018meta] develops a meta-learning approach to learning Gibbs block conditionals. This work assumes a setup in which it is possible to sample $x, z \sim p(x, z)$ from the true generative model $p(x, z)$, which means gradients can be estimated using sleep-phase Monte Carlo estimators. This circumvents the need for self-normalized estimators of the form in Equation 5, which are necessary when we additionally wish to learn the generative model. As in our work, the approach by Wang et al. minimizes the inclusive KL, but uses the learned conditionals to directly define an (approximate) MCMC sampler, rather than using them as proposals in an SMC sampler. This work also has a somewhat different focus from ours, in that it primarily seeks to learn block conditionals that have the potential to generalize to previously unseen graphical models.

5 Experiments

We evaluate APG methods on three tasks. We begin by considering a Gaussian mixture model (GMM) as an exemplar of a model in the conjugate-exponential family. Here we verify that the learned block updates converge to the analytical conditional posteriors, as predicted by our analysis in Section 2. We next consider a deep generative mixture model (DGMM) that incorporates a neural likelihood to parameterize ring-shaped clusters. We show that we can train both the generative model and the inference model end-to-end using APG methods, and that inference scales to datasets containing up to 600 points. For both models, we quantify performance in terms of the effective sample size (ESS) and the relative magnitude of the log joint. In our third experiment, we consider an unsupervised model for videos of multiple bouncing MNIST digits. We extend the task proposed by Srivastava et al. [srivastava2015unsupervised] to up to 5 individual digits, and learn both a deep generative model for videos and an inference model that performs tracking.

Figure 2: APG sampler performance as a function of the number of sweeps K for a constant sample budget K·L = 1000. (a) $\log p_\theta(x, z)$; (b) ESS/L.
Table 1: APG performance in the GMM and DGMM. The first column in each table shows the change in the log joint, i.e. the difference between the log joint under the baseline and under each other model. The ESS/L metric is computed w.r.t. different variable blocks. For the GMM, we additionally report the inclusive KL (Equation 4) for each block.

(a) GMM
| Method | Δ log pθ(x,z) | ESS/L {τ,μ,c} | ESS/L {τ,μ} | ESS/L {c} | KL(pθ‖qϕ) {τ,μ} | KL(pθ‖qϕ) {c} |
| MLP-RWS | | 0.001 | | | | |
| LSTM-RWS | 202.2 | 0.104 | | | | |
| APG (K=5) | 198.5 | 0.261 | 0.980 | 0.631 | 0.005 | 0.005 |
| APG (K=10) | 211.9 | 0.398 | 0.981 | 0.760 | 0.004 | 0.004 |
| APG (K=15) | 215.2 | 0.416 | 0.983 | 0.780 | 0.003 | 0.004 |

(b) DGMM
| Method | Δ log pθ(x,z) | ESS/L {μ,c,α} | ESS/L {μ} | ESS/L {c,α} |
| | 2538 | 0.001 | | |
| | | 0.001 | | |
| | 6201 | 0.002 | 0.013 | 0.422 |
| | 6293 | 0.002 | 0.019 | 0.454 |
| | 6310 | 0.003 | 0.025 | 0.488 |

Results on each of these tasks constitute a significant advance relative to the state of the art. Standard VAEs perform poorly at Gaussian mixture modeling tasks, and to our knowledge there are no existing methods that scale to a problem of the complexity of the DGMM for rings. In the context of the unsupervised tracking model, APG easily scales beyond previously reported results for a specialized recurrent architecture [kosiorek2018sequential]. APG is not only able to scale to more complex models in these settings, but also provides a general framework for performing inference in models with global and local variables, which can be adapted to a variety of model classes with comparative ease.

5.1 Gaussian Mixture Model

To evaluate whether APG samplers can learn the exact Gibbs updates in conditionally conjugate models, we consider a Gaussian mixture model

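$$\mu_i, \tau_i \sim \mathrm{NormGamma}(\mu_0, \nu_0, \alpha_0, \beta_0), \quad i = 1, 2, \ldots, I,$$
$$c_n \sim \mathrm{Cat}(\pi), \quad x_n \mid c_n = i \sim \mathrm{Norm}(\mu_i, 1/\tau_i), \quad n = 1, 2, \ldots, N.$$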

In this model, the global variables $z_g = \{\mu_{1:I}, \tau_{1:I}\}$ are the mean and precision of each mixture component, whereas the local variables are the cluster assignments $z_l = \{c_{1:N}\}$. Conditioned on the cluster assignments, the Gaussian likelihood $p(x_{1:N} \mid c_{1:N}, \mu_{1:I}, \tau_{1:I})$ is conjugate to the normal-gamma prior $p(\mu_{1:I}, \tau_{1:I})$, with sufficient statistics

$$T(x_n, c_n) = \left\{ \mathbb{I}[c_n = i],\ \mathbb{I}[c_n = i]\, x_n,\ \mathbb{I}[c_n = i]\, x_n^2 \;\middle|\; i = 1, 2, \ldots, I \right\},$$

where $\mathbb{I}[c_n = i]$ is an indicator function that evaluates to 1 if the equality holds, and 0 otherwise.
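For scalar observations, these statistics determine the analytical Gibbs update for the global block. The NumPy sketch below is an illustration, not the authors' implementation. It assumes the common Normal-Gamma parameterization $\mu_i \mid \tau_i \sim \mathrm{Norm}(\mu_0, 1/(\nu_0 \tau_i))$ with $\tau_i \sim \mathrm{Gamma}(\alpha_0, \text{rate } \beta_0)$, and all function names are hypothetical. For multivariate data with a diagonal precision and independent per-dimension priors, the same update can be applied dimension by dimension.

```python
import numpy as np

def sufficient_statistics(x, c, num_clusters):
    """Per-cluster sums of I[c_n = i] * (1, x_n, x_n^2) for scalar observations x."""
    onehot = np.eye(num_clusters)[c]   # (N, I) matrix of indicators I[c_n = i]
    n = onehot.sum(axis=0)             # number of points in each cluster
    s1 = onehot.T @ x                  # sum of x_n within each cluster
    s2 = onehot.T @ (x ** 2)           # sum of x_n^2 within each cluster
    return n, s1, s2

def normal_gamma_posterior(n, s1, s2, mu0, nu0, alpha0, beta0):
    """Conjugate update, assuming mu | tau ~ Normal(mu0, 1 / (nu0 * tau))
    and tau ~ Gamma(alpha0, rate=beta0). Empty clusters keep the prior."""
    nu_n = nu0 + n
    mu_n = (nu0 * mu0 + s1) / nu_n
    alpha_n = alpha0 + 0.5 * n
    # Equivalent to beta0 + 0.5 * sum (x - xbar)^2 + nu0 * n * (xbar - mu0)^2 / (2 * nu_n),
    # but written so that it is also valid when n = 0.
    beta_n = beta0 + 0.5 * (s2 + nu0 * mu0 ** 2 - nu_n * mu_n ** 2)
    return mu_n, nu_n, alpha_n, beta_n

def sample_mu_tau(mu_n, nu_n, alpha_n, beta_n, rng):
    """Draw (mu_i, tau_i) for every cluster from the conditional posterior."""
    tau = rng.gamma(shape=alpha_n, scale=1.0 / beta_n)   # NumPy uses scale = 1 / rate
    mu = rng.normal(mu_n, 1.0 / np.sqrt(nu_n * tau))
    return mu, tau
```

For example, `normal_gamma_posterior(*sufficient_statistics(x, c, 3), 0.0, 0.3, 2.0, 2.0)` uses the hyperparameters of the training setup described in this section. The posterior depends on the data only through per-cluster counts and sums, which is the structure that point-wise neural sufficient statistics imitate.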

We employ a variational distribution that updates the global variables via $q_\phi(\mu, \tau \mid x, c)$ and the local variables via $q_\phi(c \mid x, \mu, \tau)$, using point-wise neural sufficient statistics modeled after those in the analytical updates (see Appendix E for architecture details).
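One way to picture a point-wise neural sufficient statistic is sketched below. It assumes a small tanh MLP with random (untrained) weights. The layer sizes, the softplus link functions and every name are hypothetical choices, not the architecture in Appendix E. Each point contributes a feature vector $t(x_n)$, and the statistics for cluster $i$ are $T_i = \sum_n \mathbb{I}[c_n = i]\, t(x_n)$. This mirrors the counts and sums of the analytical update: $T_i$ is invariant to the order of the data, and it grows with the number of points assigned to cluster $i$.

```python
import numpy as np

def softplus(a):
    return np.logaddexp(0.0, a)   # log(1 + exp(a)), computed stably

def init_params(rng, data_dim=2, hidden=32, stat_dim=8):
    """Hypothetical random weights; a trained proposal would learn these."""
    return {
        "W1": rng.normal(0.0, 0.5, (data_dim, hidden)), "b1": np.zeros(hidden),
        "W2": rng.normal(0.0, 0.5, (hidden, stat_dim)), "b2": np.zeros(stat_dim),
        "W3": rng.normal(0.0, 0.5, (stat_dim, 4 * data_dim)), "b3": np.zeros(4 * data_dim),
    }

def neural_sufficient_statistics(x, c, params, num_clusters):
    """Point-wise features t(x_n), summed within each cluster: T_i = sum_n I[c_n = i] t(x_n)."""
    h = np.tanh(x @ params["W1"] + params["b1"])   # (N, hidden)
    t = h @ params["W2"] + params["b2"]            # (N, stat_dim), one statistic vector per point
    onehot = np.eye(num_clusters)[c]               # (N, I)
    return onehot.T @ t                            # (I, stat_dim)

def normal_gamma_proposal_params(T, params):
    """Map per-cluster statistics to Normal-Gamma parameters, one set per data dimension."""
    out = (T @ params["W3"] + params["b3"]).reshape(T.shape[0], -1, 4)   # (I, D, 4)
    mu = out[..., 0]
    nu = softplus(out[..., 1])            # positive pseudo-count
    alpha = 1.0 + softplus(out[..., 2])   # shape kept at or above 1 (an arbitrary choice)
    beta = softplus(out[..., 3])          # positive rate
    return mu, nu, alpha, beta
```

Because the aggregation is a sum rather than, for example, a recurrent network, the same parameters apply unchanged to datasets of any size $N$.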

We train our models on 20,000 datasets with $I = 3$ clusters and $N = 60$ data points, with fixed hyperparameters ($\mu_0 = 0$, $\nu_0 = 0.3$, $\alpha_0 = 2$, $\beta_0 = 2$). We use 20 GMM datasets per batch, $K = 10$ sweeps, $L = 10$ particles, and Adam (lr $= 10^{-4}$, $\beta_1 = 0.9$, $\beta_2 = 0.99$) for 200,000 iterations.

We compare the APG sampler to samples from a standard encoder with MLP and LSTM architectures, which is trained using reweighted wake-sleep (RWS). Both architectures are parameterized using the same neural sufficient statistics as the APG sampler.

Figure 1(a) shows sequences of single samples from the variational distribution, where the first sample is drawn using RWS. Even when using a parameterization that employs neural sufficient statistics, the RWS encoder fails to propose reasonable clusters, whereas the APG sampler typically converges within 12 iterations across a range of dataset sizes.

We would also like to quantify how similar the learned proposals $q_\phi(z_b \mid x, z_{-b})$ are to the conditional posteriors $p_\theta(z_b \mid x, z_{-b})$. For the GMM, where the exact conditional posterior is tractable, we verify the convergence of the learned proposals by computing the inclusive KL divergence $\mathcal{K}_b(\phi)$ defined in Equation 4 (see Table 1(a)). The learned proposals for both blocks, $\{\tau, \mu\}$ and $\{c\}$, converge to the true conditional posteriors.
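For the $\{c\}$ block, the exact conditional posterior factorizes over data points. For fixed values of $\mu$ and $\tau$, the inner KL term in Equation 4 can therefore be evaluated in closed form. The sketch below assumes scalar observations and known mixture weights $\pi$, and it averages the per-point KL over points (a convention chosen for illustration). The outer expectation over $z_{-b}$ and datasets $x$ in Equation 4 would still be estimated by averaging over samples. All names are hypothetical.

```python
import numpy as np

def log_normal(x, mean, precision):
    return 0.5 * (np.log(precision) - np.log(2 * np.pi)) - 0.5 * precision * (x - mean) ** 2

def log_softmax(a, axis=-1):
    a_max = a.max(axis=axis, keepdims=True)
    return a - a_max - np.log(np.exp(a - a_max).sum(axis=axis, keepdims=True))

def exact_assignment_log_posterior(x, mu, tau, pi):
    """log p(c_n = i | x_n, mu, tau) for scalar x_n, returned as an (N, I) array."""
    logits = np.log(pi)[None, :] + log_normal(x[:, None], mu[None, :], tau[None, :])
    return log_softmax(logits, axis=1)

def inclusive_kl(log_p, log_q):
    """Average over points of KL(p || q) = sum_i p_i (log p_i - log q_i)."""
    return np.mean(np.sum(np.exp(log_p) * (log_p - log_q), axis=1))
```

Passing the log-probabilities produced by a learned proposal $q_\phi(c \mid x, \mu, \tau)$ as `log_q` gives a per-sample estimate of $\mathcal{K}_c(\phi)$ under these assumptions.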

5.2 Deep Generative Mixture Model

We next consider the task of training a deep generative model $p_\theta(x, z)$ jointly with the APG sampler. Our dataset consists of ring-shaped clusters. The true generative model (which we assume is unknown) takes the form

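$$\mu_i \sim \mathrm{Norm}(0, \sigma_0^2 I), \quad i = 1, 2, \ldots, I,$$
$$c_n \sim \mathrm{Disc}(\pi), \quad \alpha_n \sim \mathrm{Unif}[0, 2\pi],$$
$$x_n \mid c_n = i \sim \mathrm{Norm}\big(g_\theta(\alpha_n) + \mu_i,\ \Sigma_\epsilon\big).$$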

Here $\mu_i$ is the center of the $i$-th ring. Given a cluster assignment $c_n$ and an angle $\alpha_n$, we define a position on a ring, from which we sample a data point $x_n$ with 2D Gaussian noise.
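The generative process above can be simulated with a few lines of NumPy. In this hypothetical sketch, three simplifications are made:

- The learned decoder $g_\theta(\alpha)$ is replaced by a fixed circle of radius 2.
- The mixture weights $\pi$ are uniform.
- $\Sigma_\epsilon$ is taken to be $0.2 \cdot I$.

The other defaults ($N = 200$, $I = 4$, $\sigma_0 = 3.5$) follow the training setup in this section. The function name is illustrative.

```python
import numpy as np

def sample_rings(rng, num_points=200, num_clusters=4, sigma0=3.5, noise_var=0.2, radius=2.0):
    """Sample one DGMM-style dataset; g(alpha) is a fixed circle here, not a learned network."""
    mu = rng.normal(0.0, sigma0, size=(num_clusters, 2))       # ring centers mu_i
    c = rng.integers(0, num_clusters, size=num_points)         # uniform pi (an assumption)
    alpha = rng.uniform(0.0, 2 * np.pi, size=num_points)       # angle of each point on its ring
    g = radius * np.stack([np.cos(alpha), np.sin(alpha)], axis=1)
    noise = rng.normal(0.0, np.sqrt(noise_var), size=(num_points, 2))   # Sigma_eps = noise_var * I
    x = g + mu[c] + noise
    return x, c, alpha, mu
```

In the actual model, $g_\theta$ is learned jointly with the APG sampler, so the ring shape itself is inferred from data rather than fixed as above.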

Figure 3: Inferred digit trajectories and reconstructions for (top) D=3 (middle) D=4 and (bottom) D=5 digits for T=15 for a model trained on D=3 and T=10.

We train our model on 20,000 datasets with $N = 200$ data points and $I = 4$ clusters, with fixed hyperparameters ($\sigma_0 = 3.5$, $\Sigma_\epsilon = 0.2$). We use 20 datasets per batch, $K = 10$ sweeps, $L = 10$ particles, and Adam (lr $= 10^{-4}$, $\beta_1 = 0.9$, $\beta_2 = 0.99$) for 200,000 iterations (see Appendix E for architecture details).

Once again, we compare the APG sampler with encoders trained using RWS. Figure 1(b) shows individual samples analogous to those in Figure 1(a). The APG sampler scales across a wide range of numbers of variables, whereas a standard encoder trained using RWS fails to produce reasonable proposals.

5.3 Sample Quality Evaluation

In both mixture models, we compute the log joint $\log p_\theta(x, z)$ as a function of the sweep iteration to measure convergence, and the effective sample size to assess proposal quality (see Table 1):

$$\frac{\mathrm{ESS}}{L} = \frac{\left(\sum_{l=1}^{L} w^{k,l}\right)^2}{L \sum_{l=1}^{L} \left(w^{k,l}\right)^2}. \tag{11}$$
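Equation 11 can be computed stably from log weights, as in the sketch below (the function name is hypothetical). Subtracting the maximum log weight rescales all weights by a common factor, which cancels between the numerator and the denominator.

```python
import numpy as np

def ess_fraction(log_w):
    """ESS / L from Equation 11, computed from L log importance weights."""
    log_w = np.asarray(log_w, dtype=float)
    w = np.exp(log_w - np.max(log_w))   # common rescaling cancels in the ratio
    return w.sum() ** 2 / (w.size * np.sum(w ** 2))

print(ess_fraction(np.zeros(10)))                        # equal weights: 1.0
print(ess_fraction([0.0, -np.inf, -np.inf, -np.inf]))    # one surviving particle: 0.25
```

ESS/L ranges from $1/L$ (one particle carries all the weight) to 1 (all weights are equal).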

Log joint $\log p_\theta(x, z)$. Because the marginal $q_\phi(z^k \mid x)$ is intractable, it is difficult to compute a lower or upper bound at each sweep. Instead, we compute the log joint on each test dataset for the APG sampler with different numbers of sweeps and for the RWS baselines. We then report the average difference relative to the baseline, which measures the improvement achieved by the APG sampler. In both models, the APG sampler attains a higher log joint than the encoder trained with RWS.

ESS. One advantage of the APG sampler is that it decomposes a high-dimensional sampling problem into a sequence of lower-dimensional sampling problems. To show this, we compute the ESS under two strategies: (1) resampling only after a full sweep, and (2) resampling after each block update. The more granular strategy significantly improves the ESS in both models.

5.4 Fixed Computation Budget Analysis

As a means of comparing the performance of APG samplers for varying numbers of sweeps $K$, we perform an experiment in which the computation budget is fixed at $K L = 1000$ samples. Figure 2 shows $\log p_\theta(x, z)$ and ESS/$L$. The shaded area denotes the standard deviation over 10 runs, each of which comprises 5 randomly chosen datasets. In general, it is more effective to perform more APG sweeps $K$ with a smaller number of particles $L$ than to increase the number of particles.

5.5 Time Series Model – Bouncing MNIST

Finally, we apply the APG sampler to a time series model that is trained on short timescales, and evaluate its performance on longer timescales and larger numbers of latent variables. The data $x_{1:T}$ is a sequence of images of $D$ moving MNIST digits. Our generative model consists of global variables $z^{\text{what}}_{1:D}$, which are the latent variables of the digits, and local variables $z^{\text{where}}_{1:D,1:T}$, which correspond to the digit trajectories. The deep generative model is a state space model that factorizes across digits and takes the form

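$$z^{\text{what}}_d \sim \mathrm{Norm}(0, I), \quad z^{\text{where}}_{d,1} \sim \mathrm{Norm}(0, I),$$
$$z^{\text{where}}_{d,t} \sim \mathrm{Norm}\big(z^{\text{where}}_{d,t-1}, \sigma_0^2 I\big),$$
$$x_t \sim \mathrm{Bern}\Big(\sigma\Big(\sum_d \mathrm{ST}\big(\mu_\theta(z^{\text{what}}_d), z^{\text{where}}_{d,t}\big)\Big)\Big).$$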

Here, ST is a spatial transformer [jaderberg2015spatial] that maps the output of a feedforward decoder $\mu_\theta$, which produces logits for a 28×28 MNIST image, onto a 96×96 canvas based on the location variable $z^{\text{where}}_{d,t}$. The function $\sigma(\cdot)$ denotes the logistic sigmoid.
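The latent prior above, together with the $T+1$-block schedule used by the amortized Gibbs updates, can be sketched as follows. This is a hypothetical illustration with these assumptions:

- $\sigma_0 = 0.1$ is a placeholder value.
- The 10-dimensional $z^{\text{what}}$ follows the input size that appears to be listed for the decoder in Appendix E.
- The decoder $\mu_\theta$ and the spatial transformer are omitted.

```python
import numpy as np

def log_normal(x, mean, std):
    return -0.5 * np.log(2 * np.pi) - np.log(std) - 0.5 * ((x - mean) / std) ** 2

def sample_latent_prior(rng, num_digits, num_steps, what_dim=10, sigma0=0.1):
    """Draw z_what with shape (D, what_dim) and z_where with shape (D, T, 2) from the prior."""
    z_what = rng.normal(size=(num_digits, what_dim))
    steps = sigma0 * rng.normal(size=(num_digits, num_steps, 2))
    steps[:, 0] = rng.normal(size=(num_digits, 2))      # z_{d,1}^where ~ Normal(0, I)
    z_where = np.cumsum(steps, axis=1)                   # Gaussian random walk over time
    return z_what, z_where

def log_prior_where(z_where, sigma0=0.1):
    """log p(z_where): standard normal at t = 1, random-walk transitions afterwards."""
    lp = log_normal(z_where[:, 0], 0.0, 1.0).sum()
    lp += log_normal(z_where[:, 1:], z_where[:, :-1], sigma0).sum()
    return lp

def gibbs_blocks(num_steps):
    """Block schedule: one block for all z_what, then one block per time step (T + 1 total)."""
    return ["what"] + [("where", t) for t in range(1, num_steps + 1)]
```

Updating $z^{\text{where}}_{1:D,t}$ one time step at a time, rather than as a single local block, lets the sampler resample between time steps.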

Our amortized Gibbs updates employ $T+1$ blocks $(z^{\text{what}}_{1:D}, z^{\text{where}}_{1:D,1}, z^{\text{where}}_{1:D,2}, \ldots, z^{\text{where}}_{1:D,T})$. Empirically, this works better than splitting the latent variables into a single global and a single local block, since resampling at each time step $t$ helps disentangle the digit locations when digits overlap.

We train our model on 60,000 bouncing MNIST sequences, each of which contains $D = 3$ digits and $T = 10$ frames. We use 10 sequences per batch, $K = 5$ sweeps, $L = 10$ particles, and Adam (lr $= 10^{-4}$, $\beta_1 = 0.9$, $\beta_2 = 0.99$) for 200,000 iterations (see Appendix E for architecture details).

To show that the APG sampler scales to larger numbers of variables, we test the model on datasets with $T \in \{20, 100\}$ time steps and $D \in \{3, 4, 5\}$ digits. Figure 3 shows the inference and reconstruction using single samples from the variational distribution. Plots are truncated to the first 15 time steps due to limited space; see Appendix F for more examples with all time steps. Qualitatively, the digit trajectories $z^{\text{where}}_{1:D,1:T}$ and latent variables $z^{\text{what}}_{1:D}$ are inferred well. In Figure 4, we show the mean squared error between the video and its reconstruction for different $T$ and $D$. The results confirm that performance improves with an increasing number of Gibbs sweeps $K$. In certain cases, a larger number of time points $T$ in fact improves convergence as a function of the number of sweeps $K$.

Figure 4: Mean squared error between video frames and reconstructions as a function of the number of APG sweeps.

6 Conclusion

One of the challenges in amortized inference for deep generative models is learning high-quality proposals for models with a structured prior over a high-dimensional set of latent variables. These priors arise naturally when, rather than encoding a single data point (e.g. an image), we wish to encode a dataset (e.g. a sequence of images). Even for apparently simple problems, such as inferring the cluster parameters and assignments in a mixture model, standard encoders often fail to produce good samples. One of the reasons for this is that it is fundamentally difficult to jointly generate proposals for a high-dimensional set of latent variables.

APG samplers are very general, and offer a path towards the development of deep generative models that incorporate structured priors to provide meaningful inductive biases in settings where we have little or no supervision. These methods have particular strengths in problems with global variables, but more generally make it possible to design amortized approaches that exploit conditional independence. Moreover, our parameterization in terms of neural sufficient statistics makes it comparatively easy to design models that scale to much larger numbers of latent variables and thus generalize to datasets that vary in size.

An immediate line of future work is to compare the approach in this paper, which learns kernels that leave the target density invariant, with approaches that perform annealing. In those approaches, the learned kernels are asymmetric in the sense that they gradually transform the initial encoder distribution into the target density.

7 Acknowledgements

This work was supported by the Intel Corporation, NSF award 1835309, the DARPA LwLL program, and startup funds from Northeastern University. Tuan Anh Le was supported by AFOSR award FA9550-18-S-0003.


Appendix A Gradient of the generative model

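This is a known (although not obvious) identity. Briefly, we can express the expected gradient of the log joint as

$$\mathbb{E}_{p_\theta(z \mid x)}\left[\nabla_\theta \log p_\theta(x, z)\right] = \mathbb{E}_{p_\theta(z \mid x)}\left[\nabla_\theta \log p_\theta(x) + \nabla_\theta \log p_\theta(z \mid x)\right] = \mathbb{E}_{p_\theta(z \mid x)}\left[\nabla_\theta \log p_\theta(x)\right] = \nabla_\theta \log p_\theta(x).$$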

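Here we make use of a standard identity that is also used in, e.g., likelihood-ratio estimators:

$$\mathbb{E}_{p_\theta(z \mid x)}\left[\nabla_\theta \log p_\theta(z \mid x)\right] = \int p_\theta(z \mid x)\, \nabla_\theta \log p_\theta(z \mid x)\, dz = \int \nabla_\theta p_\theta(z \mid x)\, dz = \nabla_\theta \int p_\theta(z \mid x)\, dz = \nabla_\theta 1 = 0.$$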

Therefore, we have the following equality:

$$\nabla_\theta \log p_\theta(x) = \mathbb{E}_{p_\theta(z \mid x)}\left[\nabla_\theta \log p_\theta(x, z)\right],$$

which is Equation  6. As a result, we can then use self-normalized importance sampling to approximate 𝔼pθ(z|x)[θlogpθ(x,z)].
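The identity can be checked numerically on a toy model where $\nabla_\theta \log p_\theta(x)$ is known. In the sketch below (our own example, not from the paper), $z \sim \mathrm{Norm}(\theta, 1)$ and $x \mid z \sim \mathrm{Norm}(z, 1)$. Then $p_\theta(x) = \mathrm{Norm}(x; \theta, 2)$ and $\nabla_\theta \log p_\theta(x) = (x - \theta)/2$, while $\nabla_\theta \log p_\theta(x, z) = z - \theta$. The proposal $\mathrm{Norm}(0, 3^2)$ and the function name are arbitrary choices.

```python
import numpy as np

def log_normal(x, mean, var):
    return -0.5 * np.log(2 * np.pi * var) - 0.5 * (x - mean) ** 2 / var

def snis_grad_log_marginal(x, theta, num_samples, rng, proposal_std=3.0):
    """Self-normalized estimate of grad_theta log p(x) = E_{p(z|x)}[grad_theta log p(x, z)]
    for the toy model z ~ Normal(theta, 1), x | z ~ Normal(z, 1)."""
    z = rng.normal(0.0, proposal_std, size=num_samples)     # z ~ q(z) = Normal(0, proposal_std^2)
    log_w = (log_normal(z, theta, 1.0) + log_normal(x, z, 1.0)
             - log_normal(z, 0.0, proposal_std ** 2))       # log p(x, z) - log q(z)
    w_bar = np.exp(log_w - log_w.max())
    w_bar /= w_bar.sum()                                    # normalized importance weights
    return np.sum(w_bar * (z - theta))                      # grad_theta log p(x, z) = z - theta

est = snis_grad_log_marginal(1.5, 0.3, 200_000, np.random.default_rng(0))
# The exact value is (1.5 - 0.3) / 2 = 0.6; the estimate is close but slightly biased for finite L.
```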

Appendix B Importance weights in sequential importance sampling

At step $k = 1$, we use a standard importance sampler, so the following is a valid importance weight:

$$w_1 = \frac{\gamma_1(z_1)}{q_1(z_1)}.$$

For steps $k \geq 2$, we prove that the importance weight relative to the intermediate densities has the form

$$w_k = \frac{\gamma_k(z_{1:k})}{q_1(z_1) \prod_{k'=2}^{k} q_{k'}(z_{k'} \mid z_{1:k'-1})}. \tag{12}$$

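At step $k = 2$, the importance weight is defined as

$$w_2 = v_2 w_1 = \frac{\gamma_2(z_{1:2})}{\gamma_1(z_1)\, q_2(z_2 \mid z_1)} \cdot \frac{\gamma_1(z_1)}{q_1(z_1)} = \frac{\gamma_2(z_{1:2})}{q_1(z_1)\, q_2(z_2 \mid z_1)},$$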

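which is exactly that form. We now prove the form of the weights at later steps by induction. At step $k \geq 2$, assume the weight has the form in Equation 12, i.e.

$$w_k = \frac{\gamma_k(z_{1:k})}{q_1(z_1) \prod_{k'=2}^{k} q_{k'}(z_{k'} \mid z_{1:k'-1})}.$$

Then at step $k+1$, the importance weight is the product of the incremental weight and the incoming weight:

$$w_{k+1} = v_{k+1} w_k = \frac{\gamma_{k+1}(z_{1:k+1})}{\gamma_k(z_{1:k})\, q_{k+1}(z_{k+1} \mid z_{1:k})} \cdot \frac{\gamma_k(z_{1:k})}{q_1(z_1) \prod_{k'=2}^{k} q_{k'}(z_{k'} \mid z_{1:k'-1})} = \frac{\gamma_{k+1}(z_{1:k+1})}{q_1(z_1) \prod_{k'=2}^{k+1} q_{k'}(z_{k'} \mid z_{1:k'-1})}.$$

Thus the importance weight $w_k$ has the form of Equation 12 at every step $k \geq 2$ of sequential importance sampling.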
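The telescoping argument can also be checked numerically. The sketch below makes these hypothetical choices:

- The unnormalized targets $\gamma_k(z_{1:k})$ form an AR(1) chain scaled by a constant.
- The proposals $q_k = \mathrm{Norm}(0, 2^2)$ ignore the history.

It accumulates the incremental weights $v_k = \gamma_k(z_{1:k}) / (\gamma_{k-1}(z_{1:k-1})\, q_k(z_k \mid z_{1:k-1}))$ in log space and compares the result with Equation 12.

```python
import numpy as np

def log_normal(x, mean, std):
    return -0.5 * np.log(2 * np.pi) - np.log(std) - 0.5 * ((x - mean) / std) ** 2

def log_gamma(z):
    """Hypothetical unnormalized target over z_{1:k}: an AR(1) chain scaled by 3."""
    prev = np.concatenate([[0.0], z[:-1]])
    return np.log(3.0) + log_normal(z, 0.5 * prev, 1.0).sum()

def log_q(z_k):
    return log_normal(z_k, 0.0, 2.0)   # proposal q_k ignores z_{1:k-1} in this toy

rng = np.random.default_rng(0)
K = 5
z = rng.normal(0.0, 2.0, size=K)       # z_k ~ q_k; drawing all at once is valid since q_k ignores history
log_w = log_gamma(z[:1]) - log_q(z[0])                 # step k = 1: w_1 = gamma_1 / q_1
for k in range(2, K + 1):
    log_v = log_gamma(z[:k]) - log_gamma(z[:k - 1]) - log_q(z[k - 1])   # incremental weight v_k
    log_w = log_w + log_v                              # w_k = v_k * w_{k-1}
closed_form = log_gamma(z) - log_q(z).sum()            # Equation 12 at k = K
```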

Appendix C Derivation of Posterior Invariance

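Individual block updates leave the posterior invariant. To see this, consider proposing the variables $z_b^k$ from a block kernel $\kappa(z_b^k \mid x, z^{k-1}) = p_\theta(z_b^k \mid x, z_{-b}^{k-1})$, which depends only on the remaining blocks. We then marginalize over the corresponding variables $z_b^{k-1}$ from the previous step:

$$\begin{aligned}
\int dz_b^{k-1}\, p_\theta(z^{k-1} \mid x)\, \kappa(z_b^k \mid x, z^{k-1})
&= \int dz_b^{k-1}\, p_\theta(z_b^{k-1}, z_{-b}^{k-1} \mid x)\, p_\theta(z_b^k \mid x, z_{-b}^{k-1}) \\
&= p_\theta(z_{-b}^{k-1} \mid x)\, p_\theta(z_b^k \mid x, z_{-b}^{k-1}) \\
&= p_\theta(z_b^k, z_{-b}^{k-1} \mid x).
\end{aligned}$$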
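For a discrete toy joint, the invariance can be verified exactly. The sketch below is a hypothetical example in which $z_b$ takes 3 values and $z_{-b}$ takes 2 values. It applies the exact block conditional to a state drawn from the joint and confirms that the joint is unchanged. An arbitrary (here uniform) kernel generally does not preserve it.

```python
import numpy as np

rng = np.random.default_rng(1)
joint = rng.random((3, 2))          # hypothetical posterior table p(z_b, z_{-b} | x)
joint /= joint.sum()

# Exact block kernel: kappa(z_b' | z_{-b}) = p(z_b' | z_{-b}); each column sums to one.
exact_kernel = joint / joint.sum(axis=0, keepdims=True)

def apply_block_kernel(joint, kernel):
    """Distribution of (z_b', z_{-b}) after one block update:
    sum over the old z_b of joint[z_b, z_{-b}] * kernel[z_b', z_{-b}]."""
    return np.einsum("ab,cb->cb", joint, kernel)

new_joint = apply_block_kernel(joint, exact_kernel)                   # equals joint
uniform_joint = apply_block_kernel(joint, np.full((3, 2), 1.0 / 3))   # generally differs
```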

Appendix D Proof of the amortized population Gibbs algorithm

Here, we provide an alternative proof of correctness of the APG algorithm given in Algorithm 2.2, based on the construction of proper weights [naesseth2015nested], which was introduced after SMC samplers [delmoral2006sequential]. We first introduce proper weights, then present several operations that preserve proper weighting, and finally apply these properties to prove the correctness of APG.

D.1 Proper weights

Definition 1 (Proper weights).

Given an unnormalized density $\tilde{p}(z)$ with normalizing constant $Z_p := \int \tilde{p}(z)\, dz$ and normalized density $p := \tilde{p} / Z_p$, the random variables $z, w \sim P(z, w)$ are properly weighted with respect to $\tilde{p}(z)$ if and only if, for any measurable function $f$,

$$\mathbb{E}_{P(z,w)}[w f(z)] = Z_p\, \mathbb{E}_{p(z)}[f(z)]. \tag{13}$$

We will also denote this as

$$z, w \overset{\text{p.w.}}{\sim} \tilde{p}.$$
Using proper weights.

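Given independent samples $z^l, w^l \sim P$, we can estimate $Z_p$ by setting $f \equiv 1$:

$$Z_p \approx \frac{1}{L} \sum_{l=1}^{L} w^l.$$

This estimator is unbiased because it is a Monte Carlo estimator of the left-hand side of (13). We can also estimate $\mathbb{E}_{p(z)}[f(z)]$ as

$$\mathbb{E}_{p(z)}[f(z)] \approx \frac{\frac{1}{L} \sum_{l=1}^{L} w^l f(z^l)}{\frac{1}{L} \sum_{l=1}^{L} w^l}.$$

While the numerator and the denominator are unbiased estimators of $Z_p\, \mathbb{E}_{p(z)}[f(z)]$ and $Z_p$ respectively, their ratio is biased. We often write this estimator as

$$\mathbb{E}_{p(z)}[f(z)] \approx \sum_{l=1}^{L} \bar{w}^l f(z^l), \tag{14}$$

where $\bar{w}^l := w^l / \sum_{l'=1}^{L} w^{l'}$ is the normalized weight.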
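Both estimators can be illustrated with ordinary importance sampling, which produces properly weighted samples via $w^l = \tilde{p}(z^l) / q(z^l)$. In this hypothetical example, $\tilde{p}(z) = 3\,\mathrm{Norm}(z; 1, 1)$, so $Z_p = 3$ and $\mathbb{E}_{p(z)}[z] = 1$. The proposal is $q = \mathrm{Norm}(0, 2^2)$.

```python
import numpy as np

def log_normal(x, mean, std):
    return -0.5 * np.log(2 * np.pi) - np.log(std) - 0.5 * ((x - mean) / std) ** 2

rng = np.random.default_rng(0)
L = 200_000
z = rng.normal(0.0, 2.0, size=L)                                        # z^l ~ q = Normal(0, 2^2)
log_w = np.log(3.0) + log_normal(z, 1.0, 1.0) - log_normal(z, 0.0, 2.0)  # log p~(z) - log q(z)
w = np.exp(log_w)

Z_hat = w.mean()              # unbiased estimate of Z_p (true value 3)
w_bar = w / w.sum()           # normalized weights, as in Equation 14
mean_hat = np.sum(w_bar * z)  # self-normalized (slightly biased) estimate of E_p[z] (true value 1)
```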

D.2 Operations that preserve proper weights

Proposition 1 (Nested importance sampling).

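Adapted from [naesseth2015nested, Algorithm 1]. Given unnormalized densities $\tilde{q}(z), \tilde{p}(z)$ with normalizing constants $Z_q, Z_p$ and normalized densities $q(z), p(z)$, if

$$z, w \overset{\text{p.w.}}{\sim} \tilde{q}, \tag{15}$$

then

$$z,\ w\, \frac{\tilde{p}(z)}{\tilde{q}(z)} \overset{\text{p.w.}}{\sim} \tilde{p}.$$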
Proof.

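First define the distribution of $z, w$ as $Q$. For measurable $f(z)$,

$$\mathbb{E}_{Q(z,w)}\left[w\, \frac{\tilde{p}(z)}{\tilde{q}(z)}\, f(z)\right] = Z_q\, \mathbb{E}_{q(z)}\left[\frac{\tilde{p}(z) f(z)}{\tilde{q}(z)}\right] = Z_q \int q(z)\, \frac{\tilde{p}(z) f(z)}{\tilde{q}(z)}\, dz = \int \tilde{p}(z) f(z)\, dz = Z_p\, \mathbb{E}_{p(z)}[f(z)].$$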

Proposition 2 (Resampling).

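Adapted from [naesseth2015nested, Section 3.1]. Given an unnormalized density $\tilde{p}(z)$ (normalizing constant $Z_p$, normalized density $p(z)$), if we have a set of properly weighted samples

$$z^l, w^l \overset{\text{p.w.}}{\sim} \tilde{p}, \quad l = 1, \ldots, L, \tag{16}$$

then the resampling operation preserves proper weighting, i.e.

$$z'^l, w'^l \overset{\text{p.w.}}{\sim} \tilde{p}, \quad l = 1, \ldots, L,$$

where $z'^l = z^a$ with probability $P(a = i) = w^i / \sum_{l=1}^{L} w^l$, and $w'^l := \frac{1}{L} \sum_{l=1}^{L} w^l$.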

Proof.

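Define the distribution of $z^l, w^l$ as $\hat{P}$. We show that for any $f$, $\mathbb{E}[f(z^a)\, w'^l] = Z_p\, \mathbb{E}_{p(z)}[f(z)]$:

$$\begin{aligned}
\mathbb{E}_{\left(\prod_{l=1}^{L} \hat{P}(z^l, w^l)\right) p(a \mid w^{1:L})}\left[f(z^a)\, w'^l\right]
&= \mathbb{E}_{\prod_{l=1}^{L} \hat{P}(z^l, w^l)}\left[\sum_{i=1}^{L} f(z^i)\, w'^l\, P(a = i)\right] \\
&= \mathbb{E}_{\prod_{l=1}^{L} \hat{P}(z^l, w^l)}\left[\sum_{i=1}^{L} f(z^i)\, w'^l\, \frac{w^i}{\sum_{l=1}^{L} w^l}\right] \\
&= \mathbb{E}_{\prod_{l=1}^{L} \hat{P}(z^l, w^l)}\left[\frac{1}{L} \sum_{i=1}^{L} f(z^i)\, w^i\right] \\
&= \frac{1}{L} \sum_{i=1}^{L} \mathbb{E}_{\hat{P}(z^i, w^i)}\left[f(z^i)\, w^i\right] = \frac{1}{L} \sum_{i=1}^{L} Z_p\, \mathbb{E}_{p(z)}[f(z)] = Z_p\, \mathbb{E}_{p(z)}[f(z)].
\end{aligned}$$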

Therefore, the resampling will return a new set of samples that are still properly weighted relative to the target distribution in the APG sampler (Algorithm 2.2).
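The following is a minimal sketch of the resampling operation in Proposition 2. It assumes multinomial resampling with independently drawn ancestor indices, and the function name is hypothetical. Ancestors are drawn in proportion to the weights, and every new weight is set to the average incoming weight, so the average weight is unchanged.

```python
import numpy as np

def resample(z, w, rng):
    """Multinomial resampling: z'^l = z^{a^l} with P(a^l = i) proportional to w^i,
    and every new weight w'^l equal to the average incoming weight."""
    L = len(w)
    ancestors = rng.choice(L, size=L, p=w / w.sum())
    return z[ancestors], np.full(L, w.mean())
```

Averaged over repeated resampling, $\frac{1}{L}\sum_l w'^l f(z'^l)$ matches $\frac{1}{L}\sum_l w^l f(z^l)$, which is the property used in the proof.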

Proposition 3 (Move).

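Given an unnormalized density $\tilde{p}(z)$ (normalizing constant $Z_p$, normalized density $p(z)$) and normalized conditional densities $q(z' \mid z)$ and $r(z \mid z')$, proper weighting is preserved when we apply the transition kernel to a properly weighted sample. That is, if we have

$$z^l, w^l \overset{\text{p.w.}}{\sim} \tilde{p}, \tag{17}$$
$$z'^l \sim q(z'^l \mid z^l), \tag{18}$$
$$w'^l = \frac{\tilde{p}(z'^l)\, r(z^l \mid z'^l)}{\tilde{p}(z^l)\, q(z'^l \mid z^l)}\, w^l, \quad l = 1, \ldots, L, \tag{19}$$

then we have

$$z'^l, w'^l \overset{\text{p.w.}}{\sim} \tilde{p}, \quad l = 1, \ldots, L. \tag{20}$$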
Proof.

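First, we simplify the notation by dropping the superscript $l$ without loss of generality. Define the distribution of $z, w$ as $\hat{P}$. Then, due to (17), for any measurable $f(z)$ we have

$$\mathbb{E}_{\hat{P}(z,w)}[w f(z)] = Z_p\, \mathbb{E}_{p(z)}[f(z)].$$

To prove (20), we show that $\mathbb{E}_{\hat{P}(z,w)\, q(z' \mid z)}[w' f(z')] = Z_p\, \mathbb{E}_{p(z')}[f(z')]$ for any $f$ as follows:

$$\begin{aligned}
\mathbb{E}_{\hat{P}(z,w)\, q(z' \mid z)}[w' f(z')]
&= \mathbb{E}_{\hat{P}(z,w)\, q(z' \mid z)}\left[\frac{\tilde{p}(z')\, r(z \mid z')}{\tilde{p}(z)\, q(z' \mid z)}\, w f(z')\right] \\
&= \int \hat{P}(z,w)\, q(z' \mid z)\, \frac{\tilde{p}(z')\, r(z \mid z')}{\tilde{p}(z)\, q(z' \mid z)}\, w f(z')\, dz\, dw\, dz' \\
&= \int \hat{P}(z,w)\, \frac{\tilde{p}(z')\, r(z \mid z')}{\tilde{p}(z)}\, w f(z')\, dz\, dw\, dz' \\
&= \int \tilde{p}(z') f(z') \left(\int \hat{P}(z,w)\, w\, \frac{r(z \mid z')}{\tilde{p}(z)}\, dz\, dw\right) dz' \\
&= \int \tilde{p}(z') f(z')\, Z_p\, \mathbb{E}_{p(z)}\left[\frac{r(z \mid z')}{\tilde{p}(z)}\right] dz', \qquad (21)
\end{aligned}$$

where the last step applies (17) to the function $z \mapsto r(z \mid z') / \tilde{p}(z)$. Using the fact that $\mathbb{E}_{p(z)}\left[r(z \mid z') / \tilde{p}(z)\right] = \int p(z)\, r(z \mid z') / \tilde{p}(z)\, dz = \int r(z \mid z')\, dz \,/\, Z_p = 1/Z_p$, Equation 21 simplifies to

$$\int \tilde{p}(z') f(z')\, dz' = Z_p\, \mathbb{E}_{p(z')}[f(z')].$$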
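Proposition 3 can be checked by Monte Carlo. The hypothetical sketch below uses these choices:

- The target is $\tilde{p}(z) = 2\,\mathrm{Norm}(z; 0, 1)$, so $Z_p = 2$.
- Exact draws from $p$, each with weight $w = Z_p$, serve as properly weighted inputs.
- The forward kernel is $q(z' \mid z) = \mathrm{Norm}(z'; z/2, 1)$.
- The reverse kernel is $r(z \mid z') = \mathrm{Norm}(z; z'/2, 1)$.

Any normalized $r$ is valid; the choice affects only the variance of the new weights.

```python
import numpy as np

def log_normal(x, mean, std):
    return -0.5 * np.log(2 * np.pi) - np.log(std) - 0.5 * ((x - mean) / std) ** 2

def move(z, w, rng, log_p_tilde, sample_q, log_q, log_r):
    """Propose z' ~ q(z' | z) and reweight as in Equation 19."""
    z_new = sample_q(z, rng)
    log_ratio = log_p_tilde(z_new) + log_r(z, z_new) - log_p_tilde(z) - log_q(z_new, z)
    return z_new, w * np.exp(log_ratio)

def log_p_tilde(z):                  # p~(z) = 2 Normal(z; 0, 1), so Z_p = 2
    return np.log(2.0) + log_normal(z, 0.0, 1.0)

def sample_q(z, rng):                # forward kernel q(z' | z) = Normal(z'; z / 2, 1)
    return rng.normal(0.5 * z, 1.0)

def log_q(z_new, z):
    return log_normal(z_new, 0.5 * z, 1.0)

def log_r(z, z_new):                 # reverse kernel r(z | z') = Normal(z; z' / 2, 1)
    return log_normal(z, 0.5 * z_new, 1.0)

rng = np.random.default_rng(0)
L = 200_000
z = rng.normal(size=L)               # exact draws from p ...
w = np.full(L, 2.0)                  # ... with weight Z_p are properly weighted w.r.t. p~
z_new, w_new = move(z, w, rng, log_p_tilde, sample_q, log_q, log_r)
# Proper weighting predicts mean(w_new) ~ Z_p = 2 and mean(w_new * z_new**2) ~ Z_p * E_p[z^2] = 2.
```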

D.3 Correctness of APG Sampler

We give the proof for two steps of the APG sampler (Algorithm 2.2). That is, we prove correctness when we initialize samples at step $k = 1$ (line 2.2 - line 2.2) and then perform one Gibbs sweep at step $k = 2$ (line 2.2 - line 2.2). By induction, correctness also holds when we perform additional Gibbs sweeps.

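Step $k = 1$. We initialize the proposal $z \sim q_\phi(z \mid x)$ (line 2.2) and train this encoder using the wake-$\phi$ phase objective of standard reweighted wake-sleep [le2019revisiting], $\mathbb{E}_{p(x)}\left[\mathrm{KL}\big(p_\theta(z \mid x) \,\|\, q_\phi(z \mid x)\big)\right]$. We then estimate its gradient w.r.t. the parameters $\phi$ (line 2.2) as

$$g_\phi := -\nabla_\phi\, \mathbb{E}_{p(x)}\left[\mathrm{KL}\big(p_\theta(z \mid x) \,\|\, q_\phi(z \mid x)\big)\right] \tag{22}$$
$$= \mathbb{E}_{p(x)}\left[\mathbb{E}_{p_\theta(z \mid x)}\left[\nabla_\phi \log q_\phi(z \mid x)\right]\right]. \tag{23}$$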

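Step $k = 2$. After one full sweep, we have the following objective:

$$\mathbb{E}_{p(x)}\left[\sum_{b=1}^{B} \mathbb{E}_{p_\theta(z_{-b} \mid x)}\left[\mathrm{KL}\big(p_\theta(z_b \mid z_{-b}, x) \,\|\, q_\phi(z_b \mid x, z_{-b})\big)\right]\right].$$

We will prove that we correctly estimate the following gradient w.r.t. the parameters $\phi$ at each block update (line 2.2):

$$g_\phi^b := -\nabla_\phi\, \mathbb{E}_{p(x)}\left[\mathbb{E}_{p_\theta(z_{-b} \mid x)}\left[\mathrm{KL}\big(p_\theta(z_b \mid z_{-b}, x) \,\|\, q_\phi(z_b \mid z_{-b}, x)\big)\right]\right] \tag{24}$$
$$= \mathbb{E}_{p(x)}\left[\mathbb{E}_{p_\theta(z_{1:B} \mid x)}\left[\nabla_\phi \log q_\phi(z_b \mid z_{-b}, x)\right]\right], \quad b = 1, \ldots, B. \tag{25}$$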

It suffices to show that, at each step, the samples are properly weighted,

$$z_{1:B}^l, w^l \overset{\text{p.w.}}{\sim} p_\theta(z_{1:B}, x), \quad l = 1, \ldots, L, \tag{26}$$

since Equation 14 then guarantees the validity of both gradient estimates (line 2.2 and line 2.2).

At step k=1, samples are properly weighted because zl and wl are proposed using importance sampling (line 2.2) where qϕ(z|x) is the proposal density and pθ(zl,x) is the unnormalized target density. The resampling step (line 2.2) will preserve the proper weighting because of Proposition 2.

To prove that the Gibbs sweep (line 2.2 - line 2.2) in the APG sampler also preserves proper weighting, we show that each block update satisfies the three conditions of Proposition 3 (Equations 17, 18 and 19). From this, we can conclude that the samples are still properly weighted after each block update. Without loss of generality, we drop all $l$ superscripts in the rest of the proof. Before we start any block update (before line 2.2), we already know that the samples are properly weighted, i.e.

$$z, w \overset{\text{p.w.}}{\sim} p_\theta(z, x), \tag{27}$$

which corresponds to Equation 17. Next, we define a conditional distribution $q(z' \mid z) := q_\phi(z'_b \mid x, z_{-b})\, \delta_{z_{-b}}(z'_{-b})$, from which we propose a new sample

$$z' \sim q_\phi(z'_b \mid x, z_{-b})\, \delta_{z_{-b}}(z'_{-b}), \tag{28}$$

where the density of $z'_{-b}$ is a delta mass on $z_{-b}$, defined as $\delta_{z_{-b}}(z'_{-b}) = 1$ if $z'_{-b} = z_{-b}$ and 0 otherwise. This sampling step is equivalent to first sampling $z'_b \sim q_\phi(\cdot \mid x, z_{-b})$ (line 2.2) and then setting $z'_{-b} = z_{-b}$, which is exactly what the APG sampler does procedurally. This condition corresponds to Equation 18.

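Finally, we define the weight $w'$:

$$w' = \frac{p_\theta(x, z'_b, z'_{-b})\, r(z_b \mid x, z'_{-b})\, \delta_{z'_{-b}}(z_{-b})}{p_\theta(x, z_b, z_{-b})\, q_\phi(z'_b \mid x, z_{-b})\, \delta_{z_{-b}}(z'_{-b})}\, w, \tag{29}$$

Here $p_\theta(x, z_b, z_{-b})$ and $r(z_b \mid x, z'_{-b})\, \delta_{z'_{-b}}(z_{-b})$ are treated as densities (unnormalized and normalized, respectively) of $z_{1:B}$. Likewise, $p_\theta(x, z'_b, z'_{-b})$ and $q_\phi(z'_b \mid x, z_{-b})\, \delta_{z_{-b}}(z'_{-b})$ are treated as densities of $z'_{1:B}$. Since both delta mass densities evaluate to one, this weight equals the weight computed after each block update (line 2.2). This condition corresponds to Equation 19.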

Now we can apply the conclusion (20) of Proposition 3 and claim

$$z'_{1:B}, w' \overset{\text{p.w.}}{\sim} p_\theta(z_{1:B}, x),$$

since $z'_{-b} = z_{-b}$, and $z_b$ is set to $z'_b$ by the re-assignment (line 2.2). Because proper weighting is preserved both at the initial step $k = 1$ and during the Gibbs sweep $k = 2$, both gradient estimates (line 2.2 and line 2.2) are correct.
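The argument above can be exercised end to end on a toy discrete model with two blocks. The sketch below is a hypothetical illustration, not the authors' implementation, and makes these choices:

- $\gamma(z_1, z_2)$ stands in for $p_\theta(x, z_1, z_2)$ at a fixed $x$.
- The block proposals are deliberately imperfect mixtures of the exact conditionals and a uniform distribution.
- The reverse kernel is the same block proposal, so Equation 29 reduces to the weight computed after each block update.
- We resample after every block update.

Because proper weighting is preserved, the final average weight remains an unbiased estimate of $Z$, and the resampled particles approximate the posterior.

```python
import numpy as np

rng = np.random.default_rng(0)
K1, K2, L = 4, 3, 200_000
gamma = rng.random((K1, K2)) + 0.1                   # hypothetical unnormalized p_theta(x, z1, z2)
Z = gamma.sum()                                      # exact normalizer, used only for checking
exact_1 = gamma / gamma.sum(axis=0, keepdims=True)   # p(z1 | z2): each column sums to one
exact_2 = gamma / gamma.sum(axis=1, keepdims=True)   # p(z2 | z1): each row sums to one
q_1 = 0.5 * exact_1 + 0.5 / K1                       # imperfect block proposal q(z1 | z2)
q_2 = 0.5 * exact_2 + 0.5 / K2                       # imperfect block proposal q(z2 | z1)

def sample_rows(probs, rng):
    """Draw one category per row of an (L, K) probability matrix by inverse CDF."""
    u = rng.random((probs.shape[0], 1))
    idx = (u > np.cumsum(probs, axis=1)).sum(axis=1)
    return np.minimum(idx, probs.shape[1] - 1)       # guard against round-off at the top

def resample(z1, z2, w, rng):
    a = rng.choice(len(w), size=len(w), p=w / w.sum())
    return z1[a], z2[a], np.full(len(w), w.mean())

# Initialization (k = 1): importance sampling with a uniform proposal over all states.
z1 = rng.integers(0, K1, size=L)
z2 = rng.integers(0, K2, size=L)
w = gamma[z1, z2] * (K1 * K2)
z1, z2, w = resample(z1, z2, w, rng)

for sweep in range(3):
    # Block z1: propose z1' ~ q(z1' | z2) and reweight with reverse kernel r = q.
    z1_new = sample_rows(q_1[:, z2].T, rng)
    w = w * gamma[z1_new, z2] * q_1[z1, z2] / (gamma[z1, z2] * q_1[z1_new, z2])
    z1 = z1_new
    z1, z2, w = resample(z1, z2, w, rng)
    # Block z2: propose z2' ~ q(z2' | z1) and reweight the same way.
    z2_new = sample_rows(q_2[z1, :], rng)
    w = w * gamma[z1, z2_new] * q_2[z1, z2] / (gamma[z1, z2] * q_2[z1, z2_new])
    z2 = z2_new
    z1, z2, w = resample(z1, z2, w, rng)

Z_hat = w.mean()                                   # unbiased estimate of Z
marginal_z1 = np.bincount(z1, minlength=K1) / L    # weights are equal after resampling
```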

Appendix E Architecture of the Amortized Population Gibbs samplers

GMM

Layer qϕ(μ,τ|x)
Input Concat[xn2]
1

FC 2

FC 3 Softmax

Layer qϕ(μ,τ|x,c)
Input Concat[xn2,cn3]
1

FC 2

FC 3 Softmax

Layer qϕ(c|x,μ,τ)
Input Concat[xn2,μi2]
1

FC 32 Tanh

2 FC 1, Intermediate Variable oi
3 Concat[oi], Softmax (cn)

DGMM

Layer qϕ(μ|x)
Input xn2
1

FC 32 Tanh

FC 32 Tanh

2

FC 16 Tanh, vn

FC 4 Softmax, γn3

3 Tn:=γnvn3×16
4 Concat[nNTn[i],μ02,Diag(σ02I)2],i=1,2,3,4
5 FC 2×32 Tanh
6 FC 2×8 (μ1:I)
Layer qϕ(μ|z,c)
Input Concat[xn2,cn3]
1

FC 32 Tanh

FC 32 Tanh

2

FC 16, vn

FC 4 Softmax, γn3

3 Tn:=γnvn3×16
4 Concat[nNTn[i],μ02,Diag(σ02I)2],i=1,2,3,4
5 FC 2×32 Tanh
6 FC 2×2 (μi)
Layer qϕ(c|z,μ)
Input Concat[xn2,μi2]
1 FC 32 Tanh
2

FC 1, Intermediate Variable oi

3

Concat[oi], Softmax (cn)

Layer qϕ(α|x,z,μ)
Input xn-μi2|zn=i
1 FC 32 Tanh
2

FC 1 Tanh

Layer pθ(x|μ,c,α)
Input Concat[αn,cn]5
1 FC 32 Tanh
2 FC 2 Tanh (μn, fixed σϵ)

Bouncing MNIST

Layer pθ(x|zwhat,zwhere)
Input ziwhat10
1 FC 200 ReLU
2 FC 400 ReLU
3 digit di784
4 ST(di,zi,twhere)9276,i=1,i,t=1,T
Layer qϕ(zwhat|zwhere)
Input xt9276,zi,twhere2,i=1,I,t=1,T
1 ST(xt, zi,twhere) 784,i=1,..,I,t=1,,T
2 FC 400 ReLU
3 FC 200 ReLU
4 zi,twhat10,i=1,..,I,t=1,,T
5 Mean(z1,twhat, 1:T) 10,i=1,,I
Layer qϕ(zwhere|zwhat)
Input xt9276, zi,twhat,i=1,..,t=1,..,T
1 Conv2d(xt, zi,twhat)4638, i=1,..,t=1,..,T
2 FC 400 Tanh
3 2× FC 200 Tanh
4 2×2 Tanh

Appendix F More Qualitative Results of Bouncing MNIST

The following figures show full reconstructions on test sets with $T = 100$ time steps and $D = 3, 4, 5$ digits. In each figure, the 1st, 3rd, 5th, 7th and 9th rows show the inference results, while the other rows show the reconstructions of the row above.

Figure 5: Full reconstruction for a video where T=100,D=3.
Figure 6: Full reconstruction for a video where T=100,D=4.
Figure 7: Full reconstruction for a video where T=100,D=5.
Figure 8: Full reconstruction for a video where T=100,D=5.