Variational Inference for Bayesian Probit Regression


Variational inference has become one of the most important approximate inference techniques for Bayesian statistics, but it has taken me a long time to wrap my head around the central ideas (and I’m still learning). Since I’ve found that going through examples is the most efficient way to learn, I thought I would go through a single example in this post, performing variational inference on Bayesian probit regression.

I’m going to assume the reader is somewhat familiar with the basic ideas behind variational inference. If you’ve never seen variational inference before, I strongly recommend this tutorial by David Blei, Alp Kucukelbir, and Jon McAuliffe. These course notes from David Blei are also very handy.

Variational Inference: A (Very) Brief Overview

Bayesian statistics often requires computing the conditional density $p(z \mid x)$ of latent variables $z$ given observed variables $x$. Since this distribution is typically intractable, variational inference learns an approximate distribution $q(z)$ that is meant to be “close” to $p(z \mid x)$, using Kullback-Leibler divergence as the measure of closeness.

Thus, there are two steps. The first is providing a form for the variational distribution $q(z)$. The most frequently used form comes from the mean-field variational family, where $q$ factors into independent distributions, each governed by its own set of parameters: $q(z) = \prod_j q_j(z_j)$. Once we have specified the factorization of the distribution, we are still required to figure out the optimal form of each factor, both in terms of its family and its parameters (although these can be considered the same thing). Thus, the second step is optimizing $q$.

It turns out the optimal form of each factor is straightforward: $q_j^*(z_j) \propto \exp\left\{\mathbb{E}_{-j}\left[\log p(z_j \mid z_{-j}, x)\right]\right\}$, where $\mathbb{E}_{-j}$ refers to the expectation taken over every factor except the one for variable $z_j$. To minimize the KL divergence, we cycle through the latent factors and update each one (with respect to the current parameters of the others) according to the equation above. If these results are unfamiliar, definitely check out the tutorial I mentioned earlier.
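
It helps to keep in mind why cycling through these updates is the right thing to do: the KL divergence to the true posterior differs from the evidence lower bound (ELBO) only by the log evidence, which does not depend on $q$:

$$\mathrm{KL}\big(q(z)\,\|\,p(z \mid x)\big) = \log p(x) - \Big(\mathbb{E}_q\big[\log p(x, z)\big] - \mathbb{E}_q\big[\log q(z)\big]\Big).$$

So minimizing the KL divergence is equivalent to maximizing the ELBO (the term in parentheses), and each coordinate update above increases the ELBO.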

Variational Inference for Bayesian Probit Regression

Consider a probit regression problem, where we have data $x_1, \dots, x_n$ and binary outcomes $y_1, \dots, y_n \in \{0, 1\}$. In probit regression, we assume $P(y_i = 1 \mid x_i) = \Phi(a + b x_i)$, where $a$ and $b$ are unknown and random, with a uniform prior, and $\Phi$ is the standard normal CDF. To simplify things, we can introduce latent variables $z_i \sim N(a + b x_i, 1)$ so that $y_i = 1$ if $z_i > 0$ and $y_i = 0$ if $z_i \le 0$.
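
To make the code later in the post concrete, here is a small simulated data set to run everything on; the sample size, the “true” values of $a$ and $b$, and the seed are all arbitrary choices for illustration.

# Simulate a small probit data set (all values here are arbitrary)
set.seed(42)
n = 100
x = rnorm(n)                             # covariate
a_true = -1; b_true = 2                  # illustrative "true" parameters
z_true = rnorm(n, a_true + b_true*x, 1)  # latent Gaussian variables
y = as.numeric(z_true > 0)               # observed binary outcomes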

The first step is writing down the log posterior density up to a constant. It is straightforward to see

$$\log p(z, a, b \mid x, y) = \text{const} - \frac{1}{2}\sum_{i=1}^{n}\left(z_i - a - b x_i\right)^2 + \sum_{i=1}^{n}\log\left[\mathbb{1}(z_i > 0)^{y_i}\,\mathbb{1}(z_i \le 0)^{1 - y_i}\right].$$

The next step is defining our variational distribution $q(z, a, b)$. We will provide one factor for each $z_i$, along with independent factors for $a$ and $b$. Therefore, $q$ consists of $n + 2$ independent factors:

$$q(z, a, b) = q(a)\, q(b) \prod_{i=1}^{n} q(z_i).$$

To learn the optimal form of each factor, we use the rule described above. That is, consider a single $z_i$. The optimal distribution is therefore $q^*(z_i) \propto \exp\left\{\mathbb{E}_{-z_i}\left[\log p(z, a, b \mid x, y)\right]\right\}$. Writing this out, we see

$$\log q^*(z_i) = \text{const} - \frac{1}{2}\left(z_i - \mathbb{E}[a] - \mathbb{E}[b]\, x_i\right)^2 + \log\left[\mathbb{1}(z_i > 0)^{y_i}\,\mathbb{1}(z_i \le 0)^{1 - y_i}\right].$$

Thus, after exponentiating, we have that the ideal form is a truncated normal distribution. That is, $q^*(z_i) = N^{+}\!\left(\mathbb{E}[a] + \mathbb{E}[b]\, x_i,\, 1\right)$ if $y_i = 1$ and $q^*(z_i) = N^{-}\!\left(\mathbb{E}[a] + \mathbb{E}[b]\, x_i,\, 1\right)$ if $y_i = 0$, where $N^{+}$ and $N^{-}$ are normal distributions truncated to be positive and negative, respectively.
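
The only fact about these truncated normals that the updates below rely on is their mean, which has a standard closed form. As a quick sanity check, here is the positive case compared against numerical integration, using an arbitrary value of the location parameter:

# Mean of N(mu, 1) truncated to be positive: closed form vs. numerical integration
mu = 0.7                                 # arbitrary location for the check
closed_form = mu + dnorm(-mu)/(1 - pnorm(-mu))
numerical = integrate(function(z) z*dnorm(z, mean=mu), 0, Inf)$value /
  (1 - pnorm(0, mean=mu))
c(closed_form, numerical)                # the two values should agree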

Similarly, for $a$, we have $q^*(a) \propto \exp\left\{\mathbb{E}_{-a}\left[\log p(z, a, b \mid x, y)\right]\right\}$. Removing terms that do not depend on $a$ and completing the square, we have the optimal form as $q^*(a) = N\!\left(\frac{1}{n}\sum_{i=1}^{n}\left(\mathbb{E}[z_i] - \mathbb{E}[b]\, x_i\right),\; \frac{1}{n}\right)$.
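
To spell out that completing-the-square step, the expected log posterior, viewed as a function of $a$ alone, is

$$-\frac{n}{2} a^2 + a \sum_{i=1}^{n}\left(\mathbb{E}[z_i] - \mathbb{E}[b]\, x_i\right) + \text{const} = -\frac{n}{2}\left(a - \frac{1}{n}\sum_{i=1}^{n}\left(\mathbb{E}[z_i] - \mathbb{E}[b]\, x_i\right)\right)^2 + \text{const},$$

which, exponentiated, is exactly the normal density above.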

Finally, for $b$, we have $q^*(b) \propto \exp\left\{\mathbb{E}_{-b}\left[\log p(z, a, b \mid x, y)\right]\right\}$. Again removing the terms that do not depend on $b$ and completing the square, we have the following optimal form:

$$q^*(b) = N\!\left(\frac{\sum_{i=1}^{n} x_i\left(\mathbb{E}[z_i] - \mathbb{E}[a]\right)}{\sum_{i=1}^{n} x_i^2},\; \frac{1}{\sum_{i=1}^{n} x_i^2}\right).$$

Now that we know the form of all the factors, it’s time to optimize. To do this, we cycle through the factors, setting each variational mean to the mean of its optimal distribution given the current values of the others; these means are all the other updates require. The updates can take the following form in R:

# Mean of q*(z_j): a normal truncated to be positive (y[j] == 1) or
# negative (y[j] == 0). Uses the globals x and y.
update_M_zj = function(M_a,M_b,j) {
  mu = M_a + M_b*x[j]
  if (y[j] == 1) {
    return(mu + dnorm(-1*mu)/(1-pnorm(-1*mu)))
  } else {
    return(mu - dnorm(-1*mu)/(pnorm(-1*mu)))
  }
}
# Mean of q*(a): average of the residuals after removing the slope term
update_M_a = function(M_z,M_b) {
  return(sum(M_z-M_b*x)/n)
}
# Mean of q*(b): regression-style update given the current M_z and M_a
update_M_b = function(M_z,M_a) {
  return(sum(x*(M_z-M_a))/sum(x^2))
}

Therefore, a single updating step would look like

# One coordinate ascent sweep: update every M_z[i], then M_a and M_b,
# and record the current estimates of a and b for this iteration
for (i in 1:n) {
  M_z[i] = update_M_zj(M_a,M_b,i)
}
M_a = update_M_a(M_z,M_b)
M_b = update_M_b(M_z,M_a)
as[iteration] = M_a
bs[iteration] = M_b
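
Putting everything together, a full run just repeats this sweep until the means settle down. Here is a minimal sketch, assuming the simulated x and y from earlier, with zero initializations and an arbitrary fixed number of sweeps in place of a proper convergence check:

# Full coordinate ascent loop (sketch): fixed number of sweeps, zero starts
n_iter = 50                  # number of sweeps; chosen arbitrarily
M_z = rep(0, n)              # variational means for the latent z_i
M_a = 0                      # variational mean for a
M_b = 0                      # variational mean for b
as = numeric(n_iter)         # path of M_a across sweeps
bs = numeric(n_iter)         # path of M_b across sweeps
for (iteration in 1:n_iter) {
  for (i in 1:n) {
    M_z[i] = update_M_zj(M_a,M_b,i)
  }
  M_a = update_M_a(M_z,M_b)
  M_b = update_M_b(M_z,M_a)
  as[iteration] = M_a
  bs[iteration] = M_b
}
c(M_a, M_b)                  # final variational means for a and b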

Again, variational inference is an incredibly powerful tool, and I cannot overstate how helpful the links I posted above are in understanding all of this. Hopefully this tutorial clears up some of the confusion about variational inference.

Keyon Vafa

Statistics PhD student focusing on machine learning
