Conjugate-Computation Variational Inference (CVI)

December 12, 2023 · Kalvik · 17 min read

Problem

In Bayesian learning, we are often interested in computing the posterior distribution $p(\mathbf{z}\mid\mathbf{y})$ given a set of data $\mathbf{y}$:

$$p(\mathbf{z}\mid\mathbf{y}) = \frac{p(\mathbf{y}, \mathbf{z})}{\int p(\mathbf{y}, \mathbf{z})\, d\mathbf{z}} = \frac{p(\mathbf{y}\mid\mathbf{z})\,p(\mathbf{z})}{p(\mathbf{y})}$$

However, the marginal $p(\mathbf{y})$ is often computationally intractable due to the integral $\int p(\mathbf{y}, \mathbf{z})\, d\mathbf{z}$. As a result, we need to resort to approximate inference methods such as variational inference to compute the posterior. Variational inference using stochastic gradient descent is a well-established approach (refer to my tutorial on variational Gaussian approximation for more details).

But if we make a few minor assumptions about how the inference problem is set up, we can leverage the conjugacy of the distributions involved in computing the posterior to obtain a computationally efficient variant of variational inference called Conjugate-Computation Variational Inference (CVI) (Khan and Lin, 2017). This article explains the CVI approach and its derivation.

Definition of Conjugacy

Suppose $\mathcal{F}$ is the class of data distributions $p(\mathbf{y}\mid\mathbf{z})$ parameterized by $\mathbf{z}$, and $\mathcal{P}$ is the class of prior distributions for $\mathbf{z}$. Then, the class $\mathcal{P}$ is conjugate for $\mathcal{F}$ if:

$$p(\mathbf{z}\mid\mathbf{y}) \in \mathcal{P}, \quad \forall\, p(\cdot\mid\mathbf{z}) \in \mathcal{F} \text{ and } p(\cdot) \in \mathcal{P}$$

In other words, for a given likelihood $p(\mathbf{y}\mid\mathbf{z})$, if the posterior $p(\mathbf{z}\mid\mathbf{y})$ belongs to the same class of distributions $\mathcal{P}$ as the prior $p(\mathbf{z})$, then the class $\mathcal{P}$ is said to be conjugate for the likelihood's class of distributions $\mathcal{F}$.
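For example, the Beta family is conjugate for the Bernoulli likelihood: given a $\text{Beta}(\alpha, \beta)$ prior on $z$ and an observation $y \in \{0, 1\}$,

$$p(z) = \text{Beta}(z; \alpha, \beta), \quad p(y \mid z) = z^{y}(1-z)^{1-y} \implies p(z \mid y) = \text{Beta}(z; \alpha + y, \beta + 1 - y)$$

so the posterior stays in the same class $\mathcal{P}$ as the prior.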

Conjugate Model

When the probabilistic graphical model, i.e., the joint distribution $p(\mathbf{y}, \mathbf{z})$, can be decomposed into a prior $p(\mathbf{z})$ that is conjugate to the likelihood $p(\mathbf{y}\mid\mathbf{z})$, the posterior $p(\mathbf{z}\mid\mathbf{y})$ is available in closed form. Indeed, the posterior $p(\mathbf{z}\mid\mathbf{y})$ can be computed using simple operations referred to as conjugate computations.

Consider a prior and likelihood in the exponential family with the following natural parametrization:

$$\begin{aligned} p(\mathbf{z}) &= h(\mathbf{z}) \exp[\langle \phi(\mathbf{z}), \lambda_\text{prior} \rangle - A(\lambda_\text{prior})] \\ p(\mathbf{y\mid z}) &= \exp[\langle \phi(\mathbf{z}), \bar{\lambda} \rangle - \bar{A}(\bar{\lambda})] \end{aligned}$$

Here, $h(\mathbf{z})$ is a base measure, $\phi$ are the sufficient statistics, and the functions $A$ and $\bar{A}$ are the log-partition functions. In such cases, the posterior $p(\mathbf{z}\mid\mathbf{y})$ takes the same exponential form as the prior $p(\mathbf{z})$, with natural parameters equal to the sum of the prior's and the likelihood's natural parameters:

$$p(\mathbf{z \mid y}) \propto h(\mathbf{z}) \exp[\langle \phi(\mathbf{z}), \lambda_\text{prior}+\bar{\lambda} \rangle]$$
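Here is a quick sketch of this conjugate computation in code, for inferring a Gaussian mean with known observation noise (a hypothetical textbook example; all numbers are arbitrary):

```python
import numpy as np

# Conjugate computation for a Gaussian mean z with known noise variance v.
# Prior N(m0, v0), likelihood N(y_i | z, v). In natural parameters
# (lam1, lam2) = (mean/var, -1/(2 var)), the posterior's natural parameters
# are simply the sum of the prior's and each observation's contributions.
m0, v0, v = 0.0, 4.0, 1.0
y = np.array([1.2, 0.8, 1.5])

lam1_prior, lam2_prior = m0 / v0, -0.5 / v0
lam1_lik, lam2_lik = y.sum() / v, -0.5 * len(y) / v  # likelihood contributions

lam1, lam2 = lam1_prior + lam1_lik, lam2_prior + lam2_lik
v_post = -0.5 / lam2          # back to the (mean, variance) parametrization
m_post = lam1 * v_post
print(m_post, v_post)         # matches the standard Gaussian posterior update
```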

Non-Conjugate Model

When the joint distribution $p(\mathbf{y}, \mathbf{z})$ cannot be decomposed into conjugate terms, computing the posterior becomes challenging. This happens, for example, when the prior is Gaussian and the likelihood is non-Gaussian, such as a logistic likelihood used to model categorical variables. One approach to obtaining the posterior in such cases is to use variational inference.

Variational Inference using Stochastic Gradient Methods

In variational inference (VI), we approximate the posterior $p(\mathbf{z}\mid\mathbf{y})$ by minimizing the KL divergence between the posterior and a variational distribution $q(\mathbf{z}; \lambda)$ parametrized by the natural parameters $\lambda$ (detailed derivation shown in my tutorial on Variational Gaussian Approximation):

$$\underbrace{\text{KL}(q(\mathbf{z}) \,||\, p(\mathbf{z} \mid \mathbf{y}))}_\text{Minimize} = \mathbb{E}_{q} [\ln q(\mathbf{z})] - \mathbb{E}_{q} [\ln p(\mathbf{z}, \mathbf{y})] + \ln p(\mathbf{y})$$

But the above requires us to compute the intractable marginal $p(\mathbf{y})$. Instead, we maximize the following quantity, called the Evidence Lower Bound (ELBO), which does not involve the marginal $p(\mathbf{y})$ and equals the negative of the above KL divergence up to an additive constant:

$$\underbrace{\text{ELBO}}_\text{Maximize} := \mathcal{L}(q(\mathbf{z})) = \mathbb{E}_{q} [\ln p(\mathbf{z}, \mathbf{y})] - \mathbb{E}_{q} [\ln q(\mathbf{z})]$$

We can optimize the ELBO with respect to the natural parameters $\lambda$ of the variational distribution $q(\mathbf{z})$ using stochastic gradient methods, such as the following stochastic gradient descent (ascent, since we are maximizing the ELBO) update:

$$\lambda_{t+1} = \lambda_{t} + \rho_t \left[ \hat{\nabla}_\lambda \mathcal{L}(\lambda_t) \right]$$

Here, $\hat{\nabla}_\lambda \mathcal{L}(\lambda_t)$ denotes the stochastic gradient of the ELBO $\mathcal{L}(q(\mathbf{z}))$ with respect to the natural parameters $\lambda$, evaluated at $\lambda = \lambda_t$.
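To make this concrete, here is a minimal sketch of such a stochastic VI loop for a 1-D Gaussian $q(z) = \mathcal{N}(m, s^2)$ and a toy non-conjugate model of my own choosing (a standard-normal prior with two logistic-likelihood observations). Note that it uses the $(m, \log s)$ parameterization with the reparameterization trick rather than the natural parameters, which foreshadows the parameterization issues discussed below:

```python
import numpy as np

rng = np.random.default_rng(0)

# Gradient of the log joint for the toy model (hypothetical example):
# Gaussian prior N(0, 1) on z and two Bernoulli observations y = 1 with a
# logistic likelihood sigma(z), so d/dz log p(z, y) = -z + 2 * (1 - sigma(z)).
def grad_log_joint(z):
    return -z + 2.0 * (1.0 - 1.0 / (1.0 + np.exp(-z)))

m, log_s = 0.0, 0.0   # variational parameters of q(z) = N(m, exp(log_s)^2)
rho = 0.05            # step size
for _ in range(5000):
    s = np.exp(log_s)
    eps = rng.normal()
    z = m + s * eps                      # reparameterization trick
    g = grad_log_joint(z)
    m += rho * g                         # dELBO/dm (single-sample estimate)
    log_s += rho * (g * eps * s + 1.0)   # dELBO/dlog_s; the +1 is the entropy term

print(f"q(z) ~= N({m:.3f}, {np.exp(log_s)**2:.3f})")
```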

Limitations of Stochastic Gradient Methods

Stochastic gradient methods can be applied to a wide variety of inference problems and exhibit good scalability. However, a naive application of these methods could have the following limitations:

  • The efficiency and convergence rate may depend on the parameterization used for the variational distribution $q(\mathbf{z})$. For more details, refer to my tutorial on Variational Gaussian Approximation.

  • The parameters $\lambda$ of the variational distribution $q(\mathbf{z})$ exist in a Riemannian space where the steepest descent direction is not always aligned with the gradient direction. This poses a challenge because conventional gradient methods operate in Euclidean spaces. This issue becomes apparent in the following alternate formulation of stochastic gradient descent (ascent):

    $$\lambda_{t+1} = \text{arg} \max_{\lambda \in \Omega} \langle \lambda, \hat{\nabla}_\lambda \mathcal{L}(\lambda_t) \rangle - \frac{1}{2\rho}||\lambda-\lambda_t||^2$$

    where $\Omega$ is the set of valid natural parameters, $\rho$ is the step size, and $||\cdot||$ is the Euclidean norm. The norm term ensures that the new $\lambda_{t+1}$ stays close to the previous $\lambda_t$. However, since the $\lambda$ parameters do not live in a Euclidean space, the Euclidean norm can slow the convergence of stochastic gradient descent. This issue is well illustrated by the following figure from Khan and Nielsen, 2018:

    [Figure from Khan and Nielsen, 2018: two pairs of Gaussians whose parameters are equally far apart in Euclidean distance, but whose distributions overlap very differently.]

    We observe that both pairs are the same Euclidean distance apart, even though the left figure shows no overlap between the two distributions, while the right figure exhibits significant overlap.

  • Consider the case when the joint distribution $p(\mathbf{y}, \mathbf{z})$ can be decomposed into a set of conjugate and non-conjugate terms. For instance, when using VI for Gaussian processes, it is common to fix the variational distribution $q(\mathbf{z})$ to be a Gaussian and the prior $p(\mathbf{z})$ to also be a Gaussian, even if the likelihood is non-Gaussian. In such cases, even though the exact posterior might not be Gaussian, we enforce conjugacy by fixing the prior and variational distributions. Consequently, the joint distribution $p(\mathbf{y}, \mathbf{z})$ decomposes into a non-conjugate component (the likelihood $p(\mathbf{y}\mid\mathbf{z})$) and a conjugate component (the prior $p(\mathbf{z})$). When addressing such problems, it is possible to derive closed-form expressions for the updates from the conjugate terms in the ELBO and avoid any stochastic approximations.

Conjugate-Computation Variational Inference (CVI)

Conjugate-Computation Variational Inference addresses the aforementioned limitations and makes the following two assumptions:

Assumption 1

The variational distribution $q(\mathbf{z}; \lambda)$ is a minimal exponential-family distribution:

$$q(\mathbf{z}; \lambda) = h(\mathbf{z}) \exp \left\{ \langle \phi(\mathbf{z}), \lambda \rangle - A(\lambda) \right\}$$

with $\lambda$ as the natural parameters. Refer to my tutorial on natural parameters for more details. The minimal representation implies a one-to-one mapping between the natural parameters $\lambda$ and the mean parameters $\eta := \mathbb{E}_q[\phi(\mathbf{z})]$. Indeed, the ability to switch between the natural parametrization and the mean parametrization plays a critical role in deriving the CVI method.
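For instance, a univariate Gaussian $q(z) = \mathcal{N}(z; \mu, \sigma^2)$ is a minimal exponential-family distribution with

$$\phi(z) = (z, z^2), \quad \lambda = \left(\frac{\mu}{\sigma^2},\, -\frac{1}{2\sigma^2}\right), \quad \eta = \mathbb{E}_q[\phi(z)] = (\mu,\, \mu^2 + \sigma^2)$$

and we can map back and forth between $\lambda$ and $\eta$ exactly.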

Assumption 2

The joint distribution $p(\mathbf{y}, \mathbf{z})$ can be decomposed into a non-conjugate term $\tilde{p}_{nc}$ and a conjugate term $\tilde{p}_c$. The conjugate term $\tilde{p}_c$, parameterized with $\lambda_c$, has the same form as the variational distribution $q(\mathbf{z})$:

$$\begin{aligned} p(\mathbf{y}, \mathbf{z}) & \propto \tilde{p}_{nc}(\mathbf{y, z})\,\tilde{p}_c(\mathbf{y, z}) \\ \tilde{p}_c(\mathbf{y, z}) & \propto h(\mathbf{z}) \exp \langle \phi(\mathbf{z}), \lambda_c \rangle \end{aligned}$$

Note that these assumptions are not particularly restrictive. Indeed, the exponential family is a rich class of distributions, which includes the Gaussian distribution. Also, it is common to assume a Gaussian prior $\tilde{p}_c$ when using VI (even for a non-Gaussian likelihood $\tilde{p}_{nc}$), which, together with Assumption 1, satisfies Assumption 2. However, if there is no conjugate term in the joint distribution $p(\mathbf{y}, \mathbf{z})$, CVI might not offer any advantage over stochastic gradient-based VI.
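As a concrete example (with notation of my own choosing), Bayesian logistic regression with labels $y_i \in \{-1, +1\}$, inputs $\mathbf{x}_i$, and a Gaussian prior satisfies Assumption 2 when $q(\mathbf{z})$ is Gaussian:

$$p(\mathbf{y}, \mathbf{z}) = \underbrace{\prod_{i=1}^N \sigma(y_i\, \mathbf{x}_i^\top \mathbf{z})}_{\tilde{p}_{nc}(\mathbf{y, z})}\; \underbrace{\mathcal{N}(\mathbf{z}; \mathbf{0}, \boldsymbol{\Sigma}_\text{prior})}_{\tilde{p}_c(\mathbf{y, z})}$$

where $\sigma(\cdot)$ is the logistic function. The Gaussian prior is the conjugate term, and the logistic likelihood is the non-conjugate term.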

Method

The derivation of the CVI method begins with the following mirror-descent algorithm, which replaces the Euclidean norm in the stochastic gradient descent (ascent) algorithm with a Bregman divergence. The algorithm operates on the mean parametrization $\eta$ of the variational distribution $q(\mathbf{z})$ to optimize the ELBO:

$$\eta_{t+1} = \text{arg} \max_{\eta \in \mathcal{M}} \langle \eta, \hat{\nabla}_\eta \tilde{\mathcal{L}}(\eta_t) \rangle - \frac{1}{\beta_t} \mathbb{B}_{A^*}(\eta||\eta_t)$$

Here, $\tilde{\mathcal{L}}$ is the ELBO as a function of the mean parameters $\eta$ of the variational distribution $q(\mathbf{z})$ instead of the natural parameters $\lambda$, $A^{*}(\eta)$ is the convex conjugate of the log-partition function $A(\lambda)$, $\mathbb{B}_{A^*}$ is the Bregman divergence induced by $A^*$ over $\mathcal{M}$, and $\beta_t > 0$ is the step size.

The exact form of the Bregman divergence depends on the space in which the variables live. In our case, for exponential-family distributions, the Bregman divergence is equal to the KL divergence (refer to Amari, 2016, Chapter 1 for the detailed derivation), resulting in the following updates:

$$\eta_{t+1} = \text{arg} \max_{\eta \in \mathcal{M}} \langle \eta, \hat{\nabla}_\eta \tilde{\mathcal{L}}(\eta_t) \rangle - \frac{1}{\beta_t} \text{KL}(q(\mathbf{z};\eta)||q_t(\mathbf{z};\eta_t))$$
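To see why the two are equal, here is a one-line sketch using the dual relations $A^*(\eta) = \langle \lambda, \eta \rangle - A(\lambda)$ and $\nabla A^*(\eta_t) = \lambda_t$:

$$\mathbb{B}_{A^*}(\eta||\eta_t) = A^*(\eta) - A^*(\eta_t) - \langle \eta - \eta_t, \nabla A^*(\eta_t) \rangle = \langle \eta, \lambda - \lambda_t \rangle - A(\lambda) + A(\lambda_t) = \text{KL}(q(\mathbf{z};\eta)||q_t(\mathbf{z};\eta_t))$$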

Using the KL term to regularize our updates instead of the Euclidean norm makes more sense, as the KL divergence better captures the distance between two distributions (note that the KL divergence is asymmetric). The above updates are also equivalent to natural gradient descent (ascent) in the natural parameter space $\Omega$, where the gradient is scaled by the inverse of the Fisher information matrix $\mathbf{F}(\lambda)$ to align the update with the steepest descent direction when operating in Riemannian spaces, as we do here:

$$\begin{aligned} \lambda_{t+1} &= \text{arg} \max_{\lambda \in \Omega} \langle \lambda, \hat{\nabla}_\lambda \mathcal{L}(\lambda_t) \rangle - \frac{1}{2\rho_t} (\lambda - \lambda_t)^\top \mathbf{F}(\lambda_t) (\lambda - \lambda_t) \\ \implies & \lambda_{t+1} = \lambda_t + \rho_t \underbrace{[\mathbf{F}(\lambda_t)]^{-1} \hat{\nabla}_\lambda \mathcal{L}(\lambda_t)}_\text{natural gradient $\tilde{\nabla}_\lambda$} \end{aligned}$$

Here, the Fisher information matrix for the exponential family is given by $\mathbf{F}(\lambda) = \mathbb{E}_{q_\lambda}[\nabla_\lambda \log q_\lambda(\mathbf{z})\, \nabla_\lambda \log q_\lambda(\mathbf{z})^\top]$.

However, Khan and Lin, 2017 showed that the natural gradients of the terms in the ELBO $\mathcal{L}$ that are conjugate to each other ($\tilde{p}_c(\mathbf{y, z})$ and $q(\mathbf{z})$) have a simple closed-form solution that does not require explicit computation and inversion of the Fisher information matrix. The following is the derivation of this closed-form solution.

From Assumption 2, the ELBO $\mathcal{L}$ can be written as follows:

$$\tilde{\mathcal{L}}(\eta) = \mathcal{L}(\lambda) = \mathbb{E}_{q} [\ln \tilde{p}_{nc}(\mathbf{z}, \mathbf{y})] + \mathbb{E}_{q} [\ln \tilde{p}_{c}(\mathbf{z}, \mathbf{y})] - \mathbb{E}_{q} [\ln q(\mathbf{z})]$$

Now consider the dot product $\langle \eta, \hat{\nabla}_\eta \tilde{\mathcal{L}}(\eta_t) \rangle$ in the mirror-descent update equation. If we ignore the non-conjugate term $\tilde{p}_{nc}(\mathbf{y, z})$ of the ELBO $\mathcal{L}$ for the time being and compute the dot product, we get the following:

$$\begin{aligned} \langle \eta, \hat{\nabla}_\eta (\mathbb{E}_{q} [\ln \tilde{p}_{c}(\mathbf{z}, \mathbf{y})] &- \mathbb{E}_{q} [\ln q(\mathbf{z})]) \rangle = \left\langle \eta, \hat{\nabla}_\eta \mathbb{E}_{q} \left[\ln \frac{\tilde{p}_{c}(\mathbf{z}, \mathbf{y})}{q(\mathbf{z})}\right] \right\rangle \\ &= \left\langle \eta, \hat{\nabla}_\eta \mathbb{E}_{q} \left[ \langle \phi(\mathbf{z}), \lambda_c-\lambda \rangle + A(\lambda) \right] \right\rangle \\ &= \left\langle \eta, \hat{\nabla}_\eta \left[ \langle \eta, \lambda_c-\lambda \rangle + A(\lambda) \right] \right\rangle \\ \hat{\nabla}_\eta \left[ \langle \eta, \lambda_c-\lambda \rangle + A(\lambda) \right] &= \hat{\nabla}_\eta \langle \eta, \lambda_c-\lambda \rangle + \hat{\nabla}_\eta A(\lambda) \\ &= \hat{\nabla}_\eta \eta (\lambda_c-\lambda) + \eta \hat{\nabla}_\eta (\lambda_c-\lambda) + \hat{\nabla}_\eta A(\lambda) \\ &= \lambda_c-\lambda + \eta \hat{\nabla}_\eta \lambda_c - \eta \hat{\nabla}_\eta \lambda + \hat{\nabla}_\eta A(\lambda) \\ &= \lambda_c-\lambda + 0 - \eta \mathbf{F}^{-1}_{\lambda} \hat{\nabla}_\lambda \lambda + \mathbf{F}^{-1}_{\lambda} \hat{\nabla}_\lambda A(\lambda) \\ &= \lambda_c-\lambda - \eta \mathbf{F}^{-1}_{\lambda} + \mathbf{F}^{-1}_{\lambda} \eta \\ &= \lambda_c-\lambda \end{aligned}$$
Click for relevant identities

  • Expectation of the sufficient statistics

    $$\mathbb{E}_{q} [\phi(\mathbf{z})] = \eta$$

  • Relation between $\eta$ and $\lambda$

    $$\eta = \nabla_\lambda A(\lambda)$$

  • Natural gradients (refer to my tutorial on natural parameters for more details)

    $$\begin{aligned} \frac{\partial \mathbf{f}(\lambda)}{\partial \eta} &= \frac{\partial \lambda}{\partial \eta} \frac{\partial \mathbf{f}(\lambda)}{\partial \lambda} \\ &= \left[ \frac{\partial \eta}{\partial \lambda} \right]^{-1} \frac{\partial \mathbf{f}(\lambda)}{\partial \lambda} \\ &= \left[ \frac{\partial}{\partial \lambda} \frac{\partial A(\lambda)}{\partial \lambda} \right]^{-1} \frac{\partial \mathbf{f}(\lambda)}{\partial \lambda} \\ &= \underbrace{\left[ \frac{\partial^2 A(\lambda)}{\partial \lambda^2} \right]^{-1}}_{\mathbf{F}_\lambda^{-1}} \frac{\partial \mathbf{f}(\lambda)}{\partial \lambda} \end{aligned}$$

    It is well known that the second derivative of $A(\lambda)$ is equal to the Fisher information matrix $\mathbf{F}$ for exponential-family distributions.

  • Derivative of a dot product

    $$\nabla_{x} \langle \mathbf{a, b} \rangle = (\nabla_{x} \mathbf{a}) \mathbf{b} + \mathbf{a} (\nabla_{x}\mathbf{b})$$

As we can see, the gradient contribution of the conjugate terms ($\tilde{p}_c(\mathbf{y, z})$ and $q(\mathbf{z})$) in the ELBO $\mathcal{L}$ has a closed form, which is simply $\lambda_c-\lambda$. Moreover, note that the mirror-descent algorithm was set up to operate on the mean parameters $\eta$ of the variational distribution $q(\mathbf{z})$ and used stochastic gradients. However, the above derivation shows that these gradients are equal to the natural gradients of the conjugate terms with respect to the natural parameters $\lambda$ of the variational distribution $q(\mathbf{z})$. Indeed, the following theorem from Khan and Nielsen, 2018 states:

Theorem 1. For an exponential-family distribution in the minimal representation, the natural gradient with respect to $\lambda$ is equal to the gradient with respect to $\eta$, and vice versa, i.e.,

$$\tilde{\nabla}_\lambda \mathcal{L}(\lambda) = \nabla_\eta \tilde{\mathcal{L}}(\eta) \quad \text{and} \quad \tilde{\nabla}_\eta \tilde{\mathcal{L}}(\eta) = \nabla_\lambda \mathcal{L}(\lambda)$$
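Here is a small numerical sanity check of Theorem 1 that I put together for a 1-D Gaussian (a hypothetical example, not from the paper), with all derivatives taken by finite differences:

```python
import numpy as np

# Check Theorem 1 for a 1-D Gaussian: natural params lam = (mu/s2, -1/(2*s2)),
# mean params eta = grad A(lam) = (mu, mu^2 + s2), Fisher F = hess A(lam).
# Test function: f = E_q[z^3] = mu^3 + 3*mu*s2, written in both coordinates.

def A(lam):  # log-partition function (up to an additive constant)
    return -lam[0]**2 / (4 * lam[1]) - 0.5 * np.log(-2 * lam[1])

def f_lam(lam):
    mu, s2 = -lam[0] / (2 * lam[1]), -1 / (2 * lam[1])
    return mu**3 + 3 * mu * s2

def f_eta(eta):
    mu, s2 = eta[0], eta[1] - eta[0]**2
    return mu**3 + 3 * mu * s2

def grad(f, x, h=1e-5):  # central finite differences
    return np.array([(f(x + h * e) - f(x - h * e)) / (2 * h) for e in np.eye(len(x))])

lam = np.array([1.0, -0.25])                 # mu = 2, s2 = 2
eta = grad(A, lam)                           # mean parameters
F = np.array([grad(lambda l: grad(A, l)[i], lam, h=1e-4) for i in range(2)])

print(np.linalg.solve(F, grad(f_lam, lam)))  # natural gradient w.r.t. lam
print(grad(f_eta, eta))                      # plain gradient w.r.t. eta (matches)
```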

Based on the above derivations, we can write the natural gradient of the full ELBO $\mathcal{L}$, including the non-conjugate term, as follows:

$$\begin{aligned} \tilde{\nabla}_\lambda \mathcal{L}(\lambda) &= \tilde{\nabla}_\lambda \mathbb{E}_{q} [\ln \tilde{p}_{nc}(\mathbf{z}, \mathbf{y})] + \tilde{\nabla}_\lambda \mathbb{E}_{q} [\ln \tilde{p}_{c}(\mathbf{z}, \mathbf{y})] - \tilde{\nabla}_\lambda \mathbb{E}_{q} [\ln q(\mathbf{z})] \\ &= \tilde{\nabla}_\lambda \mathbb{E}_{q} [\ln \tilde{p}_{nc}(\mathbf{z}, \mathbf{y})] + \lambda_{c} - \lambda \\ &= \hat{\nabla}_\eta \mathbb{E}_{q} [\ln \tilde{p}_{nc}(\mathbf{z}, \mathbf{y})] + \lambda_{c} - \lambda \end{aligned}$$

Note that in the last step, for the non-conjugate term $\tilde{p}_{nc}(\mathbf{z}, \mathbf{y})$, we switched from the natural gradient with respect to the natural parameters $\lambda$ to the stochastic gradient with respect to the mean parameters $\eta$ (using Theorem 1). Therefore, even for the non-conjugate term, the approach does not require explicit computation and inversion of the Fisher information matrix. However, there is no closed-form expression for the gradient of the non-conjugate term's expectation $\mathbb{E}_{q} [\ln \tilde{p}_{nc}(\mathbf{z}, \mathbf{y})]$; as such, one has to resort to Monte Carlo methods to compute it. Nonetheless, compared to the vanilla variational inference method, which computes stochastic gradients even for the conjugate terms, CVI is computationally faster.

But we are not done yet! The above only showed how to compute the gradients, which is only one part of the mirror-descent algorithm. The following details the rest of the derivation of the mirror-descent algorithm's update step using the above gradients. First, consider the dot product between the mean parameter $\eta$ and the derivative of the conjugate terms of the ELBO $\mathcal{L}$:

$$\begin{aligned} \left\langle \eta, \hat{\nabla}_\eta \mathbb{E}_{q} \left[\ln \frac{\tilde{p}_{c}(\mathbf{z}, \mathbf{y})}{q(\mathbf{z})}\right]\biggr\rvert_{\eta=\eta_t} \right\rangle &= \left\langle \eta, \lambda_{c} - \lambda_t \right\rangle \\ &= \mathbb{E}_q \left[\langle \phi(\mathbf{z}), \lambda_{c} - \lambda_t \rangle + A(\lambda_t)\right] + c \\ &= \mathbb{E}_q \left[\log \left(\frac{\tilde{p}_c(\mathbf{z}, \mathbf{y})}{q_t(\mathbf{z})}\right)\right] + c \\ &= \mathbb{E}_q \left[\log \left(\frac{\tilde{p}_c(\mathbf{z}, \mathbf{y})}{q(\mathbf{z})}\right)\right] + \underbrace{\mathbb{E}_q \left[\log \left(\frac{q(\mathbf{z})}{q_t(\mathbf{z})}\right)\right]}_\text{KL divergence} + c \end{aligned}$$

Here, $q_t$ is the variational distribution at mirror-descent iteration $t$ with fixed parameters $\eta_t$, and $q$ is the variational distribution currently being optimized. Introducing $q_t$ into the above formulation allows us to derive the update steps of the mirror-descent algorithm. Indeed, now consider the above dot product along with the non-conjugate term $\tilde{p}_{nc}(\mathbf{z}, \mathbf{y})$:

$$\begin{aligned} \langle \eta, \hat{\nabla}_\eta \tilde{\mathcal{L}}(\eta_t) \rangle &= \left\langle \eta, \hat{\nabla}_\eta \mathbb{E}_{q} [\ln \tilde{p}_{nc}(\mathbf{z}, \mathbf{y})] + \hat{\nabla}_\eta \mathbb{E}_{q} [\ln \tilde{p}_{c}(\mathbf{z}, \mathbf{y})] - \hat{\nabla}_\eta \mathbb{E}_{q} [\ln q(\mathbf{z})] \right\rangle \\ &= \left\langle \eta, \hat{\nabla}_\eta \mathbb{E}_{q} [\ln \tilde{p}_{nc}(\mathbf{z}, \mathbf{y})] \right\rangle + \left\langle \eta, \hat{\nabla}_\eta \mathbb{E}_{q} [\ln \tilde{p}_{c}(\mathbf{z}, \mathbf{y})] - \hat{\nabla}_\eta \mathbb{E}_{q} [\ln q(\mathbf{z})] \right\rangle \\ &= \left\langle \eta, \hat{\nabla}_\eta \mathbb{E}_{q} [\ln \tilde{p}_{nc}(\mathbf{z}, \mathbf{y})] \right\rangle + \mathbb{E}_q \left[\log \left(\frac{\tilde{p}_c(\mathbf{z}, \mathbf{y})}{q(\mathbf{z})}\right)\right] + \mathbb{E}_q \left[\log \left(\frac{q(\mathbf{z})}{q_t(\mathbf{z})}\right)\right] + c \end{aligned}$$

Plugging the above into the mirror-descent update equation gives us the following:

    $$\begin{aligned} \implies \eta_{t+1} &= \text{arg} \max_{\eta \in \mathcal{M}} \langle \eta, \hat{\nabla}_\eta \tilde{\mathcal{L}}(\eta_t) \rangle - \frac{1}{\beta_t} \text{KL}(q(\mathbf{z};\eta)||q_t(\mathbf{z};\eta_t)) \\ &= \text{arg} \max_{\eta \in \mathcal{M}} \left\langle \eta, \hat{\nabla}_\eta \mathbb{E}_{q} [\ln \tilde{p}_{nc}(\mathbf{z}, \mathbf{y})] \right\rangle + \mathbb{E}_q \left[\log \left(\frac{\tilde{p}_c(\mathbf{z}, \mathbf{y})}{q(\mathbf{z})}\right)\right] \\ & \qquad\qquad + \mathbb{E}_q \left[\log \left(\frac{q(\mathbf{z})}{q_t(\mathbf{z})}\right)\right] - \frac{1}{\beta_t} \text{KL}(q(\mathbf{z};\eta)||q_t(\mathbf{z};\eta_t)) \end{aligned}$$
Click to see the full derivation

Dr. Khan and Dr. Lin did an excellent job showing the derivation in their paper's appendix (Khan and Lin, 2017), but I found some steps confusing. Therefore, I added a few extra intermediate steps to make them clear. Also, I omitted the $\text{arg}\max$ for brevity.

$$\begin{aligned} &= \left\langle \eta, \hat{\nabla}_\eta \mathbb{E}_{q} [\ln \tilde{p}_{nc}(\mathbf{z}, \mathbf{y})] \right\rangle + \mathbb{E}_q \left[\log \left(\frac{\tilde{p}_c(\mathbf{z}, \mathbf{y})}{q(\mathbf{z})}\right)\right] + \mathbb{E}_q \left[\log \left(\frac{q(\mathbf{z})}{q_t(\mathbf{z})}\right)\right] - \frac{1}{\beta_t} \mathbb{E}_q \left[\log \left(\frac{q(\mathbf{z})}{q_t(\mathbf{z})}\right)\right] \\ &= \left\langle \eta, \hat{\nabla}_\eta \mathbb{E}_{q} [\ln \tilde{p}_{nc}(\mathbf{z}, \mathbf{y})] \right\rangle + \mathbb{E}_q \left[\log \left(\frac{\tilde{p}_c(\mathbf{z}, \mathbf{y})}{q(\mathbf{z})}\right)\right] - \frac{1-\beta_t}{\beta_t} \mathbb{E}_q \left[\log \left(\frac{q(\mathbf{z})}{q_t(\mathbf{z})}\right)\right] \\ &= \mathbb{E}_q \left\{ \left\langle \phi(\mathbf{z}), \hat{\nabla}_\eta \mathbb{E}_{q} [\ln \tilde{p}_{nc}(\mathbf{z}, \mathbf{y})] \right\rangle + \log \left(\frac{\tilde{p}_c(\mathbf{z}, \mathbf{y})}{q(\mathbf{z})}\right) - \frac{1-\beta_t}{\beta_t} \log \left(\frac{q(\mathbf{z})}{q_t(\mathbf{z})}\right) \right\} \\ &= \mathbb{E}_q \log \exp \left\{ \left\langle \phi(\mathbf{z}), \hat{\nabla}_\eta \mathbb{E}_{q} [\ln \tilde{p}_{nc}(\mathbf{z}, \mathbf{y})] \right\rangle + \log \left(\frac{\tilde{p}_c(\mathbf{z}, \mathbf{y})}{q(\mathbf{z})}\right) - \frac{1-\beta_t}{\beta_t} \log \left(\frac{q(\mathbf{z})}{q_t(\mathbf{z})}\right) \right\} \\ &= \mathbb{E}_q \left[ \log \frac{\exp \left\{ \left\langle \phi(\mathbf{z}), \hat{\nabla}_\eta \mathbb{E}_{q} [\ln \tilde{p}_{nc}(\mathbf{z}, \mathbf{y})] \right\rangle \right\} \exp \left\{\log \left(\frac{\tilde{p}_c(\mathbf{z}, \mathbf{y})}{q(\mathbf{z})}\right) \right\}}{\exp \left\{ \frac{1-\beta_t}{\beta_t} \log \left(\frac{q(\mathbf{z})}{q_t(\mathbf{z})}\right) \right\}} \right] \\ &= \mathbb{E}_q \left[ \log \frac{\exp \left\{ \left\langle \phi(\mathbf{z}), \hat{\nabla}_\eta \mathbb{E}_{q} [\ln \tilde{p}_{nc}(\mathbf{z}, \mathbf{y})] \right\rangle \right\} \left(\frac{\tilde{p}_c(\mathbf{z}, \mathbf{y})}{q(\mathbf{z})}\right) }{ \left(\frac{q(\mathbf{z})}{q_t(\mathbf{z})}\right)^{\frac{1-\beta_t}{\beta_t}} } \right] \\ &= \mathbb{E}_q \left[ \log \frac{\exp \left\{ \left\langle \phi(\mathbf{z}), \hat{\nabla}_\eta \mathbb{E}_{q} [\ln \tilde{p}_{nc}(\mathbf{z}, \mathbf{y})] \right\rangle \right\} \tilde{p}_c(\mathbf{z}, \mathbf{y})\, q_t(\mathbf{z})^{\frac{1-\beta_t}{\beta_t}} }{ q(\mathbf{z})^{\frac{1-\beta_t}{\beta_t}} q(\mathbf{z}) } \right] \\ &= \mathbb{E}_q \left[ \log \frac{\exp \left\{ \left\langle \phi(\mathbf{z}), \hat{\nabla}_\eta \mathbb{E}_{q} [\ln \tilde{p}_{nc}(\mathbf{z}, \mathbf{y})] \right\rangle \right\} \tilde{p}_c(\mathbf{z}, \mathbf{y})\, q_t(\mathbf{z})^{\frac{1-\beta_t}{\beta_t}} }{ q(\mathbf{z})^{1+\frac{1-\beta_t}{\beta_t}} } \right] \\ &= \mathbb{E}_q \left[ \log \frac{\exp \left\{ \left\langle \phi(\mathbf{z}), \hat{\nabla}_\eta \mathbb{E}_{q} [\ln \tilde{p}_{nc}(\mathbf{z}, \mathbf{y})] \right\rangle \right\} \tilde{p}_c(\mathbf{z}, \mathbf{y})\, q_t(\mathbf{z})^{\frac{1-\beta_t}{\beta_t}} }{ q(\mathbf{z})^{\frac{1}{\beta_t}} } \right] \\ &= \frac{1}{\beta_t} \mathbb{E}_q \left[ \log \frac{ \left( \exp \left\{ \left\langle \phi(\mathbf{z}), \hat{\nabla}_\eta \mathbb{E}_{q} [\ln \tilde{p}_{nc}(\mathbf{z}, \mathbf{y})] \right\rangle \right\} \tilde{p}_c(\mathbf{z}, \mathbf{y}) \right)^{\beta_t} q_t(\mathbf{z})^{(1-\beta_t)} }{ q(\mathbf{z}) } \right] \end{aligned}$$
$$\begin{aligned} &= \text{arg} \max_{\eta \in \mathcal{M}} -\frac{1}{\beta_t} \text{KL} \left[ \left( \exp \left\{ \left\langle \phi(\mathbf{z}), \hat{\nabla}_\eta \mathbb{E}_{q} [\ln \tilde{p}_{nc}(\mathbf{z}, \mathbf{y})] \right\rangle \right\} \tilde{p}_c(\mathbf{z}, \mathbf{y}) \right)^{\beta_t} q_t(\mathbf{z})^{(1-\beta_t)} \,||\, q(\mathbf{z}) \right] \\ &= \text{arg} \min_{\eta \in \mathcal{M}} \frac{1}{\beta_t} \text{KL} \left[ \underbrace{ \left( \exp \left\{ \left\langle \phi(\mathbf{z}), \hat{\nabla}_\eta \mathbb{E}_{q} [\ln \tilde{p}_{nc}(\mathbf{z}, \mathbf{y})] \right\rangle \right\} \tilde{p}_c(\mathbf{z}, \mathbf{y}) \right)^{\beta_t} q_t(\mathbf{z})^{(1-\beta_t)}}_\text{unnormalized exponential-family distribution} \,||\, q(\mathbf{z}) \right] \end{aligned}$$

The above shows that the optimal variational distribution $q(\mathbf{z})$ at each mirror-descent step is obtained by minimizing this KL divergence. Moreover, the updates can be written as a recursion. Let $\mathbf{g}_t := \hat{\nabla}_\eta \mathbb{E}_{q} [\ln \tilde{p}_{nc}(\mathbf{z}, \mathbf{y})]\rvert_{\eta=\eta_t}$ and $q_1(\mathbf{z}) \propto \tilde{p}_{c}(\mathbf{z}, \mathbf{y})$; then we obtain the following iterates:

$$\begin{aligned} q_1(\mathbf{z}) & \propto \tilde{p}_{c}(\mathbf{z}, \mathbf{y})^{\beta_0}\, \tilde{p}_{c}(\mathbf{z}, \mathbf{y})^{(1-\beta_0)} = \tilde{p}_{c}(\mathbf{z}, \mathbf{y}) \\ \\ q_2(\mathbf{z}) & \propto \exp \langle \phi(\mathbf{z}), \beta_1 \mathbf{g}_1 \rangle\, \tilde{p}_{c}(\mathbf{z}, \mathbf{y})^{\beta_1} q_1(\mathbf{z})^{(1-\beta_1)} \\ &= \exp \langle \phi(\mathbf{z}), \beta_1 \mathbf{g}_1 \rangle\, \tilde{p}_{c}(\mathbf{z}, \mathbf{y})^{\beta_1} \tilde{p}_{c}(\mathbf{z}, \mathbf{y})^{(1-\beta_1)} \\ &= \exp \langle \phi(\mathbf{z}), \beta_1 \mathbf{g}_1 \rangle\, \tilde{p}_{c}(\mathbf{z}, \mathbf{y}) \\ &= \exp \langle \phi(\mathbf{z}), \tilde{\lambda}_1 \rangle\, \tilde{p}_{c}(\mathbf{z}, \mathbf{y}) \\ \\ q_3(\mathbf{z}) & \propto \exp \langle \phi(\mathbf{z}), \beta_2 \mathbf{g}_2 \rangle\, \tilde{p}_{c}(\mathbf{z}, \mathbf{y})^{\beta_2} q_2(\mathbf{z})^{(1-\beta_2)} \\ &= \exp \langle \phi(\mathbf{z}), \beta_2 \mathbf{g}_2 + (1-\beta_2)\tilde{\lambda}_1 \rangle\, \tilde{p}_{c}(\mathbf{z}, \mathbf{y}) \\ &= \exp \langle \phi(\mathbf{z}), \tilde{\lambda}_2 \rangle\, \tilde{p}_{c}(\mathbf{z}, \mathbf{y}) \end{aligned}$$

The above iterates can be written succinctly as the following recursion by initializing $\tilde{\lambda}_0 = 0$ (Theorem 1/Algorithm 1 in Khan and Lin, 2017):

$$\begin{aligned} \textbf{for } & t=1, 2, 3, \cdots \textbf{ do} \\ &\tilde{\lambda}_t = (1-\beta_t)\tilde{\lambda}_{t-1} + \beta_t \hat{\nabla}_\eta \mathbb{E}_{q} [\ln \tilde{p}_{nc}(\mathbf{z}, \mathbf{y})]\rvert_{\eta=\eta_t} \\ &\lambda_{t+1} = \tilde{\lambda}_t + \lambda_\text{prior} \\ \textbf{end for} & \end{aligned}$$

The above follows from the property that the natural parameters $\lambda$ of a product of $n$ exponential-family distributions with parameters $\lambda_1, \lambda_2, \cdots, \lambda_n$ are simply the sum of those parameters (i.e., $\lambda = \sum_{i=1}^{n} \lambda_{i}$). However, in the above recursion, the non-conjugate term $\tilde{p}_{nc}(\mathbf{z}, \mathbf{y})$ is approximated by an exponential-family distribution parametrized by $\tilde{\lambda}_t$. Therefore, we need to run the recursion until convergence instead of performing one-step inference, as we would in a fully conjugate model.
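To make the recursion concrete, here is a minimal, self-contained sketch of the algorithm for Bayesian logistic regression with a Gaussian prior (the conjugate term) and a Gaussian $q(\mathbf{z})$. The data, step size, and Monte Carlo sample count are arbitrary choices of mine, and I use the Bonnet and Price identities to estimate the gradients with respect to the mean parameters; treat it as an illustration of the update structure rather than a reference implementation:

```python
import numpy as np

rng = np.random.default_rng(0)

def sigmoid(a):
    return 1.0 / (1.0 + np.exp(-a))

# Synthetic logistic-regression data with labels y in {-1, +1}
N, D = 200, 3
X = rng.normal(size=(N, D))
w_true = np.array([1.5, -2.0, 0.5])
y = np.where(rng.uniform(size=N) < sigmoid(X @ w_true), 1.0, -1.0)

# Conjugate term: Gaussian prior N(0, s2 I) in natural parameters
# lam1 = Sigma^{-1} mu and lam2 = -0.5 * Sigma^{-1}
s2 = 10.0
lam1_prior = np.zeros(D)
lam2_prior = -0.5 * np.eye(D) / s2

def mc_grads(mu, Sigma, n_samples=20):
    """Monte Carlo estimates of the gradients of E_q[ln p_nc(y | z)] w.r.t.
    (mu, Sigma), via the Bonnet and Price identities:
    grad_mu E_q[f] = E_q[grad_z f] and grad_Sigma E_q[f] = 0.5 E_q[hess_z f]."""
    L = np.linalg.cholesky(Sigma)
    g_mu, g_Sigma = np.zeros(D), np.zeros((D, D))
    for _ in range(n_samples):
        z = mu + L @ rng.normal(size=D)
        a = X @ z
        g_mu += X.T @ (y * sigmoid(-y * a))  # grad_z of sum_i log sigmoid(y_i a_i)
        w = sigmoid(a) * (1.0 - sigmoid(a))
        g_Sigma += -0.5 * (X.T * w) @ X      # 0.5 * hess_z (negative definite)
    return g_mu / n_samples, g_Sigma / n_samples

# CVI recursion (Algorithm 1 of Khan and Lin, 2017) with a fixed step size
beta = 0.1
lam1_tilde, lam2_tilde = np.zeros(D), np.zeros((D, D))
mu, Sigma = np.zeros(D), s2 * np.eye(D)      # q_1 is the prior
for _ in range(200):
    g_mu, g_Sigma = mc_grads(mu, Sigma)
    # Convert to gradients w.r.t. the mean parameters (mu, Sigma + mu mu^T)
    g_eta1 = g_mu - 2.0 * g_Sigma @ mu
    g_eta2 = g_Sigma
    lam1_tilde = (1 - beta) * lam1_tilde + beta * g_eta1
    lam2_tilde = (1 - beta) * lam2_tilde + beta * g_eta2
    # Conjugate computation: add the prior's natural parameters
    lam1, lam2 = lam1_prior + lam1_tilde, lam2_prior + lam2_tilde
    Sigma = np.linalg.inv(-2.0 * lam2)
    mu = Sigma @ lam1

print("approximate posterior mean:", mu)
```

Each iteration costs one stochastic gradient of the non-conjugate term plus a conjugate computation (adding $\lambda_\text{prior}$), mirroring the recursion above.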

Conclusion

Stochastic gradient-based VI is a versatile approach applicable to a wide range of probabilistic models. However, assuming that the variational distribution q(z)q(\mathbf{z}) follows an exponential-family distribution, and that the probabilistic model can be decomposed into conjugate and non-conjugate terms, leads to the far more computationally efficient CVI algorithm. The CVI algorithm approximates the non-conjugate terms of the probabilistic model with an exponential-family distribution and leverages the properties of mapping exponential-family distributions from the mean parametrization η\eta to the natural parametrization λ\lambda (and vice versa) to obtain computationally efficient natural-gradient descent updates.

Furthermore, I have only scratched the surface of the advantages of CVI. Indeed, the computational cost of CVI is invariant to the parametrization of the variational distribution, unlike stochastic gradient-based VI. It is modular when considering mean-field approximations, as we can use message-passing. We can also use doubly stochastic updates, and it enables structured inference in deep models. These advantages can be fully appreciated if you take the time to read Dr. Khan’s papers and work out the derivations and examples.


References

This article was based on what I learned from the following sources: