Variational Inference from the Ground Up

Here at Civis Analytics we love to learn new and interesting things. These kinds of things have a funny habit of turning out to be useful, even if we don't end up putting them directly into production for our clients or in Civis Platform.

In this spirit, I recently went down a rabbit hole on variational inference, or VI. This subject encompasses a set of interrelated techniques that try to find approximations to probability density functions (PDFs).

To see why this might be useful, let's take a step back to reexamine the basics of Bayesian inference. In Bayesian inference, one usually wants to compute integrals over the posterior PDF of the parameters given the data, $p(z|D)$. A common example is computing the posterior mean of some model parameter while also marginalizing over other parameters, possibly hyperparameters. The posterior is

$$p(z|D) \propto p(D|z)p(z)$$

where $p(z)$ is the prior and $p(D|z)$ is the likelihood. Typically, you use some sort of sampling method (e.g., Markov Chain Monte Carlo or MCMC) to draw samples from $p(D|z)p(z)$ and then use those draws to compute things.
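For example, the posterior mean of a single parameter $z_{j}$ is

$$\langle z_{j}\rangle = \int z_{j}\,p(z|D)\,dz\ ,$$

which marginalizes over all of the other components of $z$; the Monte Carlo draws estimate this integral as a simple average of the sampled $z_{j}$ values.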

The idea of using variational inference for Bayesian inference can be summarized as follows: instead of producing Monte Carlo samples directly from the posterior, find a PDF which approximates the posterior. We can then use this approximation in place of the samples for further analysis. This PDF is called the variational approximation. Numerically, this procedure can be very advantageous for a few reasons. First, it introduces a natural way to compute approximations of integrals over the posterior with varying degrees of fidelity. We can use a coarser approximation when fully sampling the posterior would be infeasible and switch to a higher-accuracy one when computational resources allow. Second, it allows us to cast a Bayesian inference problem as an optimization problem (as opposed to a sampling problem). This fact allows us to take advantage of the wide variety of tools designed for fitting neural networks (e.g., automatic differentiation, numerical libraries which work on GPUs, etc.). Finally, one of the best features of VI is that we can do it with small batches of data, which allows the method to scale to extremely large datasets.

The rest of this blog post answers a few key questions about VI:

  1. What do we mean by saying "$q(z)$ approximates $p(z)$"?
  2. How do we efficiently compute $q(z)$?
  3. Where can I find code to do this for me?

Here and below, I will denote the variational approximation as $q(z)$ and the original PDF as $p(z)$. I will go ahead and suppress any conditional dependencies of $p(z)$ since any PDF will work in these methods.

In [1]:
# packages!
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import copy
import time
import scipy.special
import scipy.optimize
import scipy.stats

%matplotlib notebook
sns.set()

Dirty VI: Eww!

So we need to find a PDF to approximate $p(z)$. First, however, we should probably describe in what sense we mean that $q(z)$ should approximate $p(z)$. A well-formulated answer to this question will be given below. For now, we are going to explore some simple and crude things.

Note: I don't recommend using any of these methods. I am using them here for pedagogical purposes.

Dirty VI I: Regression!

Here is a hack that works at least in one dimension. What if we build a set of test points $\{z_{0}, z_{1}, ..., z_{n}\}$ and then minimize the MSE between $\log p(z)$ and $\log q(z)$ at these points? Sure, why not?

Assuming $\phi$ is a set of parameters that characterizes $q(z)$, this procedure looks like

$${\hat \phi} = \underset{\phi}{{\rm arg\,min}} \frac{\sum_{i} q(z_{i};\phi)\left[\log q(z_{i};\phi) - \log p(z_{i})\right]^{2}}{\sum_{i} q(z_{i};\phi)}$$

Here I decided to weight points that have more probability in the approximation $q(z)$, the idea being that points near the bulk of $q(z)$ are more important to get right. OK. How do we pick $z_{i}$? Well, let's use the known distribution of $q(z;\phi)$ at each step of the optimization to select a grid of points which sample the regions of highest probability density. Then we will minimize the loss above. Cool!

To make things a bit more specific, let's approximate a Student's t distribution with $\nu=3$ by a Gaussian distribution with some mean and variance. (I will explain this choice below, but it suffices to say that Gaussian distributions are useful approximations because they have an analytic differential entropy.) Now some code!

In [2]:
NU = 3
NU_NORM = scipy.special.gammaln((NU + 1)/2) - scipy.special.gammaln(NU/2) - 0.5 * np.log(NU*np.pi)
def logp(z):
    """log of students-t dist."""
    return -0.5 * (NU+1) * np.log(1 + z*z/NU) + NU_NORM
    
def logq(z, mu, lnsigma):
    """log of Gaussian parameterized by mean and log(sigma)"""
    sigma = np.exp(lnsigma)
    return -0.5 * ((z - mu) / sigma) ** 2 - lnsigma - 0.5*np.log(2.0 * np.pi)

def regression_vi(logp, n, mu_start, lnsigma_start, atol=1e-6):
    """use an optimizer for simple 1D VI"""
    phi_start = np.array([mu_start, lnsigma_start])
    
    # Objective function. Computes sum above on a grid.
    def obj(phi):
        _sigma = np.exp(phi[1])  # get sigma
        
        # This is the grid, factor of 10 is a random choice.
        z = np.linspace(phi[0] - 10.0*_sigma, phi[0] + 10.0*_sigma, n)

        # Build weights and differences.
        logqz = logq(z, phi[0], phi[1])
        w = np.exp(logqz)
        diff = logqz - logp(z)
        return np.sum(diff * diff * w) / np.sum(w)

    # Run the optimizer.
    opts = {'disp': True, 'maxiter': 5000, 'maxfev': 5000,
            'fatol': atol, 'xatol': 1e-8}
    phi_hat = scipy.optimize.minimize(obj, phi_start,
                                      method='Nelder-Mead',
                                      options=opts)
    print(phi_hat)
    return phi_hat['x'], phi_hat

phi_hat, res = regression_vi(logp, 100, 100.0, -100.0)
Optimization terminated successfully.
         Current function value: 0.034185
         Iterations: 107
         Function evaluations: 206
 final_simplex: (array([[ -3.31217268e-09,   1.39418201e-01],
       [  3.00982760e-09,   1.39418195e-01],
       [  5.89924420e-09,   1.39418199e-01]]), array([ 0.03418518,  0.03418518,  0.03418518]))
           fun: 0.03418517538400144
       message: 'Optimization terminated successfully.'
          nfev: 206
           nit: 107
        status: 0
       success: True
             x: array([ -3.31217268e-09,   1.39418201e-01])

We have used a very dirty method here, chucking our loss into an optimizer which requires no derivatives. Now let's make some plots!

In [3]:
z = np.linspace(-5.0, 5.0, 1000)
pz = np.exp(logp(z))
qz = np.exp(logq(z, phi_hat[0], phi_hat[1]))

plt.figure()
plt.plot(z, pz, label='p(z)')
plt.plot(z, qz, label='q(z)')
plt.xlabel('z')
plt.ylabel('PDF')
plt.legend();

Well that's nice, but eww. Unfortunately, this hack may not scale to many dimensions due to the grid, and also, eww.

Dirty VI II: The Laplace Approximation

A big thank you to Bill Lattner for pointing out the Laplace approximation!

Here is another quick and dirty way to build a variational approximation. Let's expand $\log p(z)$ in a Taylor series about its maximum. I am going to assume there is only a single local and global maximum. In one dimension we have

$$\log p(z) \approx \log p(z_{max}) + \left.\frac{d\log p(z)}{dz}\right|_{z=z_{max}}(z - z_{max}) + \frac{1}{2}\left.\frac{d^{2}\log p(z)}{dz^{2}}\right|_{z=z_{max}}(z - z_{max})^{2} +\ ...$$

At a maximum, $z_{max}$, the first order derivative is zero by definition. Thus we can approximate $p(z)$ up to terms second order in $z$ via

$$\log q(z) = \frac{1}{2}\left.\frac{d^{2}\log p(z)}{dz^{2}}\right|_{z=z_{max}}(z - z_{max})^{2} + c$$

where $c$ is a constant we will compute to normalize the approximation. The above expression for $q(z)$ is a Gaussian with variance

$$\sigma^{2} = -\left(\left.\frac{d^{2}\log p(z)}{dz^{2}}\right|_{z=z_{max}}\right)^{-1}$$

indicating that

$$c = -\frac{1}{2}\log(2\pi\sigma^{2}).$$

In higher dimensions, one would simply invert the negative of the Hessian matrix to get the approximate covariance matrix.
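To make that concrete, here is a minimal sketch of a multivariate Laplace approximation, reusing the imports from the first cell. The 2D target and the finite-difference helper below are purely illustrative (they are not part of the example we have been working with):

def log_p_2d(z):
    """illustrative 2D target: a zero-mean Gaussian with a known precision matrix"""
    prec = np.array([[2.0, 0.5], [0.5, 1.0]])
    return -0.5 * z @ prec @ z

def hessian_fd(f, z, h=1e-5):
    """central-difference Hessian of the scalar function f at the point z"""
    d = len(z)
    eye = np.eye(d)
    hess = np.zeros((d, d))
    for i in range(d):
        for j in range(d):
            hi, hj = h * eye[i], h * eye[j]
            hess[i, j] = (f(z + hi + hj) - f(z + hi - hj)
                          - f(z - hi + hj) + f(z - hi - hj)) / (4.0 * h * h)
    return hess

# Find the mode, then invert the negative Hessian to get the covariance.
res = scipy.optimize.minimize(lambda z: -log_p_2d(z), np.zeros(2), method='Nelder-Mead')
z_max = res['x']
cov = np.linalg.inv(-hessian_fd(log_p_2d, z_max))
# For this made-up target, cov should come out close to inv(prec).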

Back in one dimension, let's redo our Student's t example above with this approximation.

In [4]:
def second_deriv(logp, z, h):
    """compute second deriv with finite diff stencil at second order
    
    See https://en.wikipedia.org/wiki/Finite_difference.
    """
    fzph = logp(z + h)
    fzmh = logp(z - h)
    fz = logp(z)
    return (fzph - 2*fz + fzmh) / h**2

def laplace_apprx(logp, z_start):
    """compute a laplace approx to a 1d pdf
    
    returns mean, lnsigma of the approximate Gaussian
    """
    # First find max.
    z_max = scipy.optimize.minimize(lambda z: -logp(z), z_start,
                                    method='Nelder-Mead',
                                    options={'disp': True})
    print(z_max)
    z_max = z_max['x'][0]

    # Now do finite diff. 
    # Set the step size to the larger of |z_max| * 1e-6 or 1e-6.
    eps = 1e-6
    h = max(np.abs(z_max * eps), eps)
    
    # Get hessian and compute log(sigma).
    hess = second_deriv(logp, z_max, h)
    lnsigma = 0.5 * np.log(-1.0 / hess)
    return z_max, lnsigma

mu, lnsigma = laplace_apprx(logp, 100.0)
print("z_max  :", mu)
print("lnsigma:", lnsigma)
Optimization terminated successfully.
         Current function value: 1.000889
         Iterations: 24
         Function evaluations: 48
 final_simplex: (array([[  0.00000000e+00],
       [ -7.62939453e-05]]), array([ 1.00088885,  1.00088885]))
           fun: 1.0008888496235098
       message: 'Optimization terminated successfully.'
          nfev: 48
           nit: 24
        status: 0
       success: True
             x: array([ 0.])
z_max  : 0.0
lnsigma: -0.143774459782
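As a sanity check, for the Student's t density above we can compute this derivative analytically:

$$\left.\frac{d^{2}\log p(z)}{dz^{2}}\right|_{z=0} = -\frac{\nu+1}{\nu} = -\frac{4}{3}\ ,$$

so $\sigma^{2} = 3/4$ and $\log\sigma = \frac{1}{2}\log(3/4) \approx -0.1438$, in good agreement with the finite-difference result above.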
In [5]:
z = np.linspace(-5.0, 5.0, 1000)
pz = np.exp(logp(z))
qz = np.exp(logq(z, mu, lnsigma))

plt.figure()
plt.plot(z, pz, label='p(z)')
plt.plot(z, qz, label='q(z)')
plt.xlabel('z')
plt.ylabel('PDF')
plt.legend();

Hmmmm. Well by eye at least, this looks worse. The curvature of the peak does seem to match better, as one might expect. One way forward would be to add more terms in the Taylor series above, but this clearly has its limits.

So far, hacking at this at random is not leading us to very principled approximations. Luckily, a lot of work has been done on this subject and there are well-defined ways to do VI! Let's discuss one of the most common ones below.

VI via the KL Divergence

Before we talk more about VI, let's introduce the Kullback–Leibler divergence, or KL divergence. This quantity is defined for PDFs $q(z)$ and $p(z)$ as

$$D_{\rm KL}\big(Q||P\big) = \int q(z) \log\frac{q(z)}{p(z)}dz\ .$$

The most interesting property of the KL divergence is known as Gibbs' inequality,

$$D_{\rm KL}\big(Q||P\big) \ge 0$$

with equality only when $p(z) = q(z)$. Note also that the KL divergence is not symmetric in its arguments. It is a way to measure the difference between two PDFs.
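As a quick numerical sanity check, we can estimate the KL divergence on a grid between a standard normal $q$ and the Student's t $p$ from earlier, reusing the logp and logq functions defined above:

# Crude grid estimate of D_KL(Q||P) for q = N(0, 1) and p = Student's t.
z = np.linspace(-30.0, 30.0, 200001)
dz = z[1] - z[0]
logqz = logq(z, 0.0, 0.0)  # mu = 0, log(sigma) = 0
qz = np.exp(logqz)
kl = np.sum(qz * (logqz - logp(z))) * dz
print(kl)  # positive, as Gibbs' inequality promises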

We can now do a standard set of manipulations on the KL divergence to derive a quantity known as the evidence lower bound or ELBO. We start by reintroducing the data $D$ and noting that since $p(z,D)=p(z|D)p(D)$

$$\log p(z|D) = \log p(z,D) - \log p(D)\ .$$

Then we get

$$D_{\rm KL}\big(Q||P\big) = \int q(z)\log q(z)dz + \int q(z) \left[\log p(D) - \log p(D|z)p(z)\right] dz$$

After moving terms around a bit and noting that $p(D)$ is constant with respect to $z$ and any parameters $\phi$ in $q(z)$, we get

$$\begin{eqnarray} \log p(D) &=& D_{\rm KL}\big(Q||P\big) + \int q(z) \log p(D|z)p(z) dz - \int q(z)\log q(z)dz \\ &\equiv& D_{\rm KL}\big(Q||P\big) + {\rm ELBO}(Q) \end{eqnarray}$$

where

$${\rm ELBO}(Q) = \int q(z) \log p(D|z)p(z) dz - \int q(z)\log q(z)dz$$

and we use the fact that

$$\int q(z) \log p(D)dz = \log p(D)\int q(z) dz = \log p(D)$$

(since $q(z)$ is a properly normalized PDF). The second term in the ELBO

$$h(Z) = - \int q(z)\log q(z)dz$$

is known as the differential entropy. It turns out that distributions which allow you to compute the differential entropy analytically make for particularly easy and numerically efficient VI algorithms. This fact motivates the use of Gaussian distributions, in one or more dimensions, in many VI methods.
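For a one-dimensional Gaussian with standard deviation $\sigma$, for example, the differential entropy is simply

$$h(Z) = \frac{1}{2}\log\left(2\pi e\sigma^{2}\right) = \log\sigma + \frac{1}{2}\log(2\pi) + \frac{1}{2}\ ,$$

which is exactly the expression used in the code below.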

Finally, since the evidence, $p(D)$, is constant with respect to $q(z)$ and its parameters, maximizing the ELBO is equivalent to minimizing the KL divergence.

Note that if you are simply trying to approximate some PDF, then you can still use the KL divergence directly by minimizing it. Let's use the KL divergence to redo our Student's t example.

In [6]:
def kl_vi(logp, n, mu_start, lnsigma_start):
    """vi with KL divergence"""
    phi_start = np.array([mu_start, lnsigma_start])
    
    # Objective function. Computes the KL div of q and p.
    def obj(phi):
        # This term is -\int q*log(q).
        # Also known as the differential entropy.
        # For a Gaussian, it can be computed exactly. 
        # See wikipedia or something.
        entropy = phi[1] + 0.5*np.log(2.0 * np.pi) + 0.5

        # This is the grid, factor of 20 is a random choice.
        _sigma = np.exp(phi[1])  # get sigma        
        z = np.linspace(phi[0] - 20.0*_sigma, phi[0] + 20.0*_sigma, n)
        dz = z[1] - z[0]  # factor needed for numerical integral
        
        # This term is \int q*log(p)
        logqz = logq(z, phi[0], phi[1])
        qz = np.exp(logqz)

        return -entropy - np.sum(qz * logp(z) * dz)

    # Run the optimizer.
    phi_hat = scipy.optimize.minimize(obj, phi_start,
                                      method='Nelder-Mead',
                                      options={'disp': True})
    print(phi_hat)
    return phi_hat['x'], phi_hat

phi_hat, res = kl_vi(logp, 10000, 1.0, 0.0)
Optimization terminated successfully.
         Current function value: 0.040695
         Iterations: 56
         Function evaluations: 108
 final_simplex: (array([[ -4.26771590e-06,   2.31270951e-01],
       [  5.19733032e-05,   2.31261077e-01],
       [  5.67321110e-05,   2.31314504e-01]]), array([ 0.04069546,  0.04069546,  0.04069546]))
           fun: 0.040695455334933017
       message: 'Optimization terminated successfully.'
          nfev: 108
           nit: 56
        status: 0
       success: True
             x: array([ -4.26771590e-06,   2.31270951e-01])
In [7]:
z = np.linspace(-5.0, 5.0, 1000)
pz = np.exp(logp(z))
qz = np.exp(logq(z, phi_hat[0], phi_hat[1]))

plt.figure()
plt.plot(z, pz, label='p(z)')
plt.plot(z, qz, label='q(z)')
plt.xlabel('z')
plt.ylabel('PDF')
plt.legend();

Cool! I cannot say this approximation looks as good as the first one based on the MSE, but VI done with the KL divergence is what you will find both in the literature and in packages like Stan. So this is what we are going to stick with for now.

Traditional, Bespoke VI: Yikes.

If you read textbooks or the interwebs, you will find lots of information on VI in terms of its application to specific models and the calculus of variations. While this sort of technique can be very useful in many cases, it cannot be applied to an arbitrary PDF. Here is how this typically goes.

In mean-field variational inference, one assumes that the variational approximation is a separable distribution in the parameters (i.e., $q(z)=\prod_{i}q_{i}(z_{i})$). Then one tries to directly minimize the KL divergence

$$\int q(z) \log\frac{q(z)}{p(z)}dz$$

with this separable distribution. The problem here is to find functions $\{q_{i}(z_{i})\}$ such that this integral is minimized. The calculus of variations is the branch of mathematics that deals with problems of this nature. In the terminology of this branch of mathematics, the KL divergence is viewed as a functional of $q(z)$. Typically, once you find the optimal functions, there will be a coupled set of equations to be solved. Sometimes these equations admit an analytic solution. Other times, one resorts to iterative schemes, which can be shown to converge.

Even the simplest examples from this field generate pages of mathematical analysis. Instead of doing this in detail, I am going to set up one of the simplest problems to give you a flavor of how this goes. This example is from Chapter 33 of Information Theory, Inference, and Learning Algorithms by David MacKay.

Suppose we are modeling $N$ observations $x_{i}$ from a Gaussian with an unknown mean, $\mu$, and variance, $\sigma^{2}$. The posterior in this case is

$$p(\mu, \sigma|\{x_{i}\}) \propto p(\mu, \sigma)\prod_{i} N(x_{i}|\mu, \sigma)\ .$$

We are going to use flat priors on $\mu$ and $\log\sigma$. (These choices are motivated in MacKay. The prior on $\sigma$ is called a Jeffreys prior and is invariant under reparameterization.)
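Concretely, a flat prior on $\log\sigma$ corresponds to

$$p(\sigma) \propto \frac{1}{\sigma}\ ,$$

so the joint prior is $p(\mu, \sigma) \propto 1/\sigma$.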

Then the optimization problem we have is to minimize the functional

$$\int q_{\mu}(\mu)q_{\sigma}(\sigma) \log\frac{q_{\mu}(\mu)q_{\sigma}(\sigma)}{p(\mu, \sigma|\{x_{i}\})}d\mu d\sigma$$

with respect to $q_{\mu}$ and $q_{\sigma}$.

So then you go to work and find that (see MacKay!)

$$\begin{eqnarray} q_{\mu}(\mu) & = & N(\mu| {\bar x}, 1/\sqrt{N{\bar \beta}})\\ {\bar \beta} & = & \int q_{\sigma}(\sigma)\sigma^{-2}d\sigma\\ q_{\beta}(\beta) & = & \Gamma(\beta|b', c')\\ \beta & = & 1/\sigma^{2}\\ \frac{1}{b'} & = & \frac{1}{2}\left({\bar \beta}^{-1} + S\right)\\ c' & = & N/2\\ {\bar x} & = & \frac{1}{N}\sum_i x_{i}\\ S & = & \sum_{i} \left(x_{i} - {\bar x}\right)^{2} \end{eqnarray}$$

Yuck. Now you have to solve this joint set of equations for the parameters that define $q_{\mu}$ and $q_{\sigma}$. (It turns out in this simple case, one can show that $1/{\bar \beta} = S/(N-1)$.)
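Once the updates have been derived, the iterative scheme itself is only a few lines of code. Here is a rough sketch on synthetic data of my own making, assuming MacKay's parameterization of the Gamma distribution, in which the mean of $\Gamma(\beta|b', c')$ is $b'c'$:

# Minimal sketch of the coupled mean-field updates on made-up data.
rng = np.random.RandomState(42)
x = rng.normal(loc=2.0, scale=3.0, size=100)
N = len(x)
xbar = np.mean(x)
S = np.sum((x - xbar) ** 2)

beta_bar = 1.0  # initial guess for the posterior mean of beta = 1/sigma^2
for _ in range(100):
    c_prime = N / 2.0
    b_prime = 1.0 / (0.5 * (1.0 / beta_bar + S))
    beta_bar = b_prime * c_prime  # mean of the Gamma factor q_beta

print(beta_bar, (N - 1) / S)  # the iteration lands on the analytic fixed point
# q_mu is then a Gaussian with mean xbar and standard deviation 1/sqrt(N * beta_bar).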

I think it is clear that for data science work, bespoke VI methods, while useful for some specific models, may not be general or practical enough for everyday use with arbitrary PDFs.

Automated Variational Inference and Reparameterization

So we would like a more automatic form of variational inference. One method, which you will find in Stan, is called automatic differentiation variational inference or ADVI. The paper introducing this technique by Kucukelbir et al. is definitely worth reading. Before we discuss ADVI in its full glory, it is important to work through an example of how a simple form of VI can fail and what we might do about it.

In order to be specific, we are going to use this example PDF

$$\log p(z) = 10^{3}\log z + \log(1-z) - c$$

where $c$ is a constant to normalize the PDF. (In this case $c=\log {\rm B}(10^{3}+1,\, 2)$, where ${\rm B}$ is the Beta function.)

A Naive Approach

Let's start with directly trying to do VI with this PDF. Note that when doing VI, the variational approximation should have the same domain of support as the target PDF $p(z)$. To do this with a Gaussian, we are going to renormalize the Gaussian to have a unit integral on the domain $[0,1]$ and set it to zero outside of this domain.

In [8]:
def logq_unit(z, mu, lnsigma):
    """log of Gaussian parameterized by mean and log(sigma)
    has unit integral over 0,1 
    and value zero outside of 0,1
    """
    val = np.zeros_like(z)
    msk = (z >= 1.0) | (z <= 0.0)
    val[msk] = -np.inf
    if np.any(~msk):
        sigma = np.exp(lnsigma)
        a, b = (0.0 - mu) / sigma, (1.0 - mu) / sigma
        val[~msk] = scipy.stats.truncnorm.logpdf(z[~msk], a=a, b=b, loc=mu, scale=sigma)
    
    return val

def logp_hard(z, a=1e3, b=1):
    val = np.zeros_like(z)
    msk = (z >= 1.0) | (z <= 0.0)
    val[msk] = -np.inf
    if np.any(~msk):
        val[~msk] = a * np.log(z) + b * np.log(1.0 - z) - scipy.special.betaln(a + 1.0, b + 1.0)
    return val

def kl_vi_unit(logp, n, mu_start, lnsigma_start, eps=1e-8):
    """vi with KL divergence over unit integral"""
    phi_start = np.array([mu_start, lnsigma_start])
    
    # Objective function. Computes the KL div of q and p.
    def obj(phi):
        # This term is -\int q*log(q).
        sigma = np.exp(phi[1])
        a, b = (0.0 - phi[0]) / sigma, (1.0 - phi[0]) / sigma
        entropy = scipy.stats.truncnorm.entropy(a=a, b=b, loc=phi[0], scale=sigma)

        # This is the grid; it covers the full support (0, 1).
        z = np.linspace(eps, 1.0 - eps, n)
        dz = z[1] - z[0]  # factor needed for numerical integral
        
        # This term is \int q*log(p)
        logqz = logq_unit(z, phi[0], phi[1])
        qz = np.exp(logqz)

        return -entropy - np.sum(qz * logp(z) * dz)

    # Run the optimizer.
    phi_hat = scipy.optimize.minimize(obj, phi_start,
                                      method='Nelder-Mead',
                                      options={'disp': True, 'maxfev': 10000})
    print(phi_hat)
    return phi_hat['x'], phi_hat

phi_hat, res = kl_vi_unit(logp_hard, 10000, 0.0, 0.0)
/Users/mbecker/miniconda3/envs/civis/lib/python3.6/site-packages/scipy/stats/_continuous_distns.py:4846: RuntimeWarning: divide by zero encountered in log
  self._logdelta = np.log(self._delta)
/Users/mbecker/miniconda3/envs/civis/lib/python3.6/site-packages/scipy/stats/_continuous_distns.py:4850: RuntimeWarning: invalid value encountered in double_scalars
  return _norm_pdf(x) / self._delta
Optimization terminated successfully.
         Current function value: 0.145730
         Iterations: 231
         Function evaluations: 427
 final_simplex: (array([[ 0.9981397 , -6.63765288],
       [ 0.99813979, -6.63771762],
       [ 0.99813975, -6.63760831]]), array([ 0.14572997,  0.14572997,  0.14572997]))
           fun: 0.14572997442662317
       message: 'Optimization terminated successfully.'
          nfev: 427
           nit: 231
        status: 0
       success: True
             x: array([ 0.9981397 , -6.63765288])
In [9]:
z = np.linspace(0.5, 0.999999, 100000)
pz = np.exp(logp_hard(z))
qz = np.exp(logq_unit(z, phi_hat[0], phi_hat[1]))
dz_dlogitz = z * (1.0 - z)

plt.figure()
plt.plot(scipy.special.logit(z), pz * dz_dlogitz, label='p(logit(z))')
plt.plot(scipy.special.logit(z), qz * dz_dlogitz, label='q(logit(z))')
plt.xlabel('logit(z)')
plt.ylabel('PDF')
plt.legend();