
Variational Inference: A Review for Statisticians

General Problem Setting:

Consider a joint density of latent variables \(\mathbf{z} = z_{1:m}\) and observations \(\mathbf{x} = x_{1:n}\):

\[p(\mathbf{z}, \mathbf{x}) = p(\mathbf{z}) p(\mathbf{x}|\mathbf{z})\]

The main idea behind variational inference is to use optimization to approximate the conditional density of the latent variables given the observations:

\[p(\mathbf{z} | \mathbf{x}) = \frac{p(\mathbf{z}) p(\mathbf{x}|\mathbf{z})}{p(\mathbf{x})}\]

This conditional can be used to produce point or interval estimates of the latent variables, form predictive densities of new data \(p(x^* | \mathbf{x})\) and more.
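As a concrete instance of the latter (a standard identity, not spelled out in these notes), the posterior predictive density averages the likelihood of a new point over the posterior, assuming \(x^*\) and \(\mathbf{x}\) are conditionally independent given \(\mathbf{z}\):

\[p(x^* | \mathbf{x}) = \int p(x^* | \mathbf{z}) \, p(\mathbf{z} | \mathbf{x}) \, d\mathbf{z}\]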

The quantity:

\[p(\mathbf{x}) = \int_\mathbf{z} p(\mathbf{z}, \mathbf{x}) d\mathbf{z}\]

is called the evidence. For many models, this evidence integral is unavailable in closed form or requires exponential time to compute.
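To make the difficulty concrete (this example is not worked out in these notes; the symbols \(K\), \(\mu_{1:K}\), and \(c_{1:n}\) are introduced here only for illustration), consider a Bayesian mixture of \(K\) Gaussians with component means \(\mu_{1:K}\) and per-observation cluster assignments \(c_{1:n}\). The evidence is

\[p(\mathbf{x}) = \int p(\mu_{1:K}) \prod_{i=1}^{n} \left( \sum_{c_i = 1}^{K} p(c_i) \, p(x_i | c_i, \mu_{1:K}) \right) d\mu_{1:K},\]

and expanding the product of sums couples the integral over \(\mu_{1:K}\) with all \(K^n\) assignment configurations, so the evidence has no general closed form and is exponentially expensive to evaluate term by term.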

Evidence Lower Bound

In variational inference, we specify a family \(\Theta\) of densities over the latent variables. Each \(q(\mathbf{z}) \in \Theta\) is a candidate approximation to the exact conditional. Our goal is to find the best candidate: the one closest in KL divergence to the exact conditional. Inference now amounts to solving the following optimization problem:

\[q^*(\mathbf{z}) = \arg\min_{q(\mathbf{z}) \in \Theta} KL(q(\mathbf{z}) \;||\; p(\mathbf{z} | \mathbf{x}))\]

where the KL divergence expands as

\[\begin{aligned} KL(q(\mathbf{z}) \;||\; p(\mathbf{z} | \mathbf{x})) &= \int_{\mathbf{z}} q(\mathbf{z})\log\frac{q(\mathbf{z})}{p(\mathbf{z}| \mathbf{x})} \, d\mathbf{z} \\ &= E_{\mathbf{z} \sim q(\mathbf{z})}[\log q(\mathbf{z})] - E_{\mathbf{z} \sim q(\mathbf{z})}[\log p(\mathbf{z} | \mathbf{x})] \\ &= E_{\mathbf{z} \sim q(\mathbf{z})}[\log q(\mathbf{z})] - E_{\mathbf{z} \sim q(\mathbf{z})}[\log p(\mathbf{z}, \mathbf{x})] + \log p(\mathbf{x}) \end{aligned}\]

However, this objective is not computable because it requires the evidence \(\log p(\mathbf{x})\). Since we cannot compute the KL divergence directly, we optimize an alternative objective that is equivalent to it up to an added constant:

\[ELBO(q) = E_{\mathbf{z} \sim q(\mathbf{z})} [\log p(\mathbf{z}, \mathbf{x})] - E_{\mathbf{z} \sim q(\mathbf{z})}[\log q(\mathbf{z})]\]

This function is called the evidence lower bound (ELBO). Comparing with the KL expansion above, \(ELBO(q) = \log p(\mathbf{x}) - KL(q(\mathbf{z}) \;||\; p(\mathbf{z} | \mathbf{x}))\); that is, the ELBO is the negative KL divergence plus \(\log p(\mathbf{x})\), which is a constant with respect to \(q(\mathbf{z})\). Maximizing the ELBO is therefore equivalent to minimizing the KL divergence.
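As a minimal sketch of how the ELBO is estimated in practice (the model, variational family, and all names below are illustrative assumptions, not taken from the review), one can draw samples from \(q\) and average \(\log p(\mathbf{z}, \mathbf{x}) - \log q(\mathbf{z})\):

```python
import numpy as np
from scipy import stats

rng = np.random.default_rng(0)

# Illustrative model (an assumption for this sketch):
#   z ~ N(0, 1),   x_i | z ~ N(z, 1),   i = 1..20
x = rng.normal(loc=1.5, scale=1.0, size=20)  # "observed" data

def log_joint(z, x):
    """log p(z, x) = log p(z) + sum_i log p(x_i | z), vectorized over samples z."""
    log_prior = stats.norm.logpdf(z, loc=0.0, scale=1.0)
    log_lik = stats.norm.logpdf(x[:, None], loc=z, scale=1.0).sum(axis=0)
    return log_prior + log_lik

def elbo_estimate(m, s, n_samples=5_000):
    """Monte Carlo estimate of ELBO(q) for the variational family q(z) = N(m, s^2)."""
    z = rng.normal(loc=m, scale=s, size=n_samples)
    log_q = stats.norm.logpdf(z, loc=m, scale=s)
    return np.mean(log_joint(z, x) - log_q)

print(elbo_estimate(m=1.4, s=0.25))   # a q near the posterior: larger ELBO
print(elbo_estimate(m=-2.0, s=0.25))  # a q far from the posterior: much smaller ELBO
```

Searching over \((m, s)\) for the largest value of this estimate is exactly the optimization that variational inference performs.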


Property 1: ELBO is the Balance Between Likelihood and Prior

Continuing from the equation for \(ELBO(q)\) above, and writing \(E[\cdot] := E_{\mathbf{z} \sim q(\mathbf{z})}[\cdot]\), we can rewrite it as:

\[ELBO(q) = E[\log p(\mathbf{z})] - E[\log q(\mathbf{z})] + E[\log p(\mathbf{x} | \mathbf{z})] = -KL(q(\mathbf{z}) \;||\; p(\mathbf{z})) + E[\log p(\mathbf{x} | \mathbf{z})]\]

To find a \(q(\mathbf{z})\) that maximizes the ELBO, we therefore have to balance two goals: place the mass of \(q\) where the expected log likelihood of the data \(E[\log p(\mathbf{x} | \mathbf{z})]\) is high, while keeping \(q\) close to the prior in KL divergence (so the prior acts as a regularizer).
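To see this balance numerically, both terms are available in closed form for the same illustrative Gaussian model used in the sketch above (again an assumption for illustration, not the review's example): the KL between the Gaussian \(q\) and the standard normal prior, and the expected log likelihood under \(q\).

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.normal(loc=1.5, scale=1.0, size=20)  # same illustrative model: z ~ N(0,1), x_i | z ~ N(z,1)

def kl_q_prior(m, s):
    """KL( N(m, s^2) || N(0, 1) ) in closed form."""
    return np.log(1.0 / s) + (s**2 + m**2) / 2.0 - 0.5

def expected_log_lik(m, s):
    """E_q[ sum_i log N(x_i; z, 1) ] in closed form for q(z) = N(m, s^2)."""
    return np.sum(-0.5 * np.log(2.0 * np.pi) - 0.5 * ((x - m) ** 2 + s**2))

for m in [0.0, 1.5, 3.0]:
    s = 0.25
    print(f"m={m:3.1f}  E[log p(x|z)]={expected_log_lik(m, s):8.2f}  "
          f"KL(q||prior)={kl_q_prior(m, s):5.2f}  "
          f"ELBO={expected_log_lik(m, s) - kl_q_prior(m, s):8.2f}")
```

Moving \(q\) toward the data raises the expected log likelihood but also the KL penalty to the prior; for fixed \(s\), the ELBO is largest at the exact posterior mean, which sits between the prior mean and the sample mean.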

Property 2: ELBO is the Lower Bound of the Log Evidence

ELBO lower bounds the log evidence:

\[\log p(\mathbf{x}) = KL(q(\mathbf{z}) \;||\; p(\mathbf{z} | \mathbf{x})) + ELBO(q) \geq ELBO(q), \quad \quad \forall q \in \Theta\]

This follows directly from rearranging the KL expansion above, because the KL divergence is non-negative and equals zero exactly when \(q(\mathbf{z}) = p(\mathbf{z} | \mathbf{x})\).
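A quick numerical check of the bound, using the same illustrative conjugate Gaussian model as in the sketches above (chosen precisely because its exact log evidence is computable): the ELBO never exceeds \(\log p(\mathbf{x})\) and touches it when \(q\) is the exact posterior.

```python
import numpy as np
from scipy import stats

rng = np.random.default_rng(0)
n = 20
x = rng.normal(loc=1.5, scale=1.0, size=n)  # illustrative model: z ~ N(0,1), x_i | z ~ N(z,1)

# Exact log evidence: marginally x ~ N(0, I + 1 1^T) under this model.
log_evidence = stats.multivariate_normal.logpdf(
    x, mean=np.zeros(n), cov=np.eye(n) + np.ones((n, n)))

def elbo(m, s):
    """Closed-form ELBO for q(z) = N(m, s^2) under the model above."""
    exp_log_lik = np.sum(-0.5 * np.log(2.0 * np.pi) - 0.5 * ((x - m) ** 2 + s**2))
    kl_to_prior = np.log(1.0 / s) + (s**2 + m**2) / 2.0 - 0.5
    return exp_log_lik - kl_to_prior

# Exact posterior is N( sum(x)/(n+1), 1/(n+1) ); the ELBO attains log p(x) there.
m_post, s_post = x.sum() / (n + 1), np.sqrt(1.0 / (n + 1))
print(log_evidence)            # log p(x)
print(elbo(m_post, s_post))    # equals log p(x) up to floating point
print(elbo(0.0, 1.0))          # any other q gives a strictly smaller value
```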

REF

Blei, D. M., Kucukelbir, A., and McAuliffe, J. D. (2017). Variational Inference: A Review for Statisticians. https://arxiv.org/abs/1601.00670