Variational Bayesian Sparse Linear Mixed Model
Introduction
BSLMM (Zhou, Carbonetto, and Stephens 2013) is the model: \( \newcommand\Normal{\mathcal{N}} \DeclareMathOperator\Cov{Cov} \DeclareMathOperator\det{det} \DeclareMathOperator\tr{tr} \DeclareMathOperator\elbo{ELBO} \newcommand\mk{\mathbf{K}} \newcommand\mi{\mathbf{I}} \newcommand\mh{\mathbf{H}} \newcommand\ml{\mathbf{L}} \newcommand\mx{\mathbf{X}} \newcommand\vb{\mathbf{b}} \newcommand\ve{\mathbf{e}} \newcommand\vu{\mathbf{u}} \newcommand\vy{\mathbf{y}} \)
\begin{align*} \vy &= \mx \vb + \vu + \ve\\ b_j &\sim \pi_0 \Normal(0, \sigma_b^2 \sigma^2) + (1 - \pi_0) \delta_0(\cdot)\\ \vu &\sim \Normal(\mathbf{0}, \sigma^2_u\sigma^2 \mk)\\ \ve &\sim \Normal(\mathbf{0}, \sigma^2 \mi) \end{align*}where \(\mk = \mx \mx'\) and \(\pi_0\) is the prior probability that \(b_j\) is nonzero. Zhou et al. use a Metropolis-Hastings algorithm (implemented in GEMMA) to estimate the posterior distributions of the proportion of variance explained, PVE \(V[\mx\vb + \vu] / V[\vy]\), and the proportion of genetic variance explained by the sparse effects, PGE \(V[\mx\vb] / V[\mx\vb + \vu]\). This approach is computationally expensive, because the Markov chain requires many iterations to mix, and it requires access to the full data \(\mx\) and \(\vy\). We have implemented a VBEM/importance sampling algorithm for this model, available in the Python package vbslmm, which relies only on the sufficient statistics \(\mx'\mx, \mx'\vy, \vy'\vy, n\), where \(n\) is the number of individuals. Here, we outline the method and evaluate it on simulations.
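To make the generative model concrete, the following sketch draws a synthetic data set from BSLMM with NumPy and reduces it to the sufficient statistics \(\mx'\mx, \mx'\vy, \vy'\vy, n\). The function names and the parameter `p_nonzero` (the prior probability that \(b_j \neq 0\)) are hypothetical and are not part of the vbslmm API; standard normal entries stand in for real genotypes. The polygenic term is drawn as \(\vu = \mx\mathbf{a}\) with \(\mathbf{a} \sim \Normal(\mathbf{0}, \sigma^2_u\sigma^2\mi)\), which has covariance \(\sigma^2_u\sigma^2\mk\) without ever forming \(\mk\).

```python
import numpy as np


def simulate_bslmm(n, p, p_nonzero, sigma2_b, sigma2_u, sigma2, seed=0):
    """Draw (X, y, b, u) from the BSLMM generative model (illustrative sketch)."""
    rng = np.random.default_rng(seed)
    # Stand-in genotypes: i.i.d. standard normal entries, shape (n, p)
    X = rng.standard_normal((n, p))
    # Spike-and-slab sparse effects: nonzero with probability p_nonzero,
    # nonzero values drawn from N(0, sigma2_b * sigma2)
    nonzero = rng.random(p) < p_nonzero
    b = np.where(nonzero, rng.normal(0.0, np.sqrt(sigma2_b * sigma2), size=p), 0.0)
    # Polygenic effect u = X a with a ~ N(0, sigma2_u * sigma2 * I), so that
    # Cov(u) = sigma2_u * sigma2 * X X' = sigma2_u * sigma2 * K
    u = X @ rng.normal(0.0, np.sqrt(sigma2_u * sigma2), size=p)
    # Residual noise e ~ N(0, sigma2 * I)
    e = rng.normal(0.0, np.sqrt(sigma2), size=n)
    y = X @ b + u + e
    return X, y, b, u


def sufficient_statistics(X, y):
    """Return (X'X, X'y, y'y, n), the inputs a summary-statistic method works from."""
    return X.T @ X, X.T @ y, float(y @ y), X.shape[0]


X, y, b, u = simulate_bslmm(n=500, p=100, p_nonzero=0.05, sigma2_b=1.0, sigma2_u=0.01, sigma2=1.0)
XtX, Xty, yty, n = sufficient_statistics(X, y)
# Realized PVE in this sample: V[Xb + u] / V[y]
pve = np.var(X @ b + u) / np.var(y)
print(XtX.shape, Xty.shape, n, round(pve, 3))
```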
Setup
%matplotlib inline
%config InlineBackend.figure_formats = set(['retina'])
import matplotlib.pyplot as plt
plt.rcParams['figure.facecolor'] = 'w'
plt.rcParams['font.family'] = 'Nimbus Sans'
Methods
Full data VBEM
Integrating over \(\vu\) under BSLMM,
\[ \vy \mid \mx, \vb, \mh, \sigma^2 \sim \Normal(\mx\vb, \sigma^2 \mh) \]
where \(\mh = \sigma^2_u \mk + \mi\). Let \(\ml\) be a matrix such that \(\mh = \ml\ml'\) (for example, the Cholesky factor of \(\mh\)). Then,
\[ \ml^{-1}\vy \sim \Normal(\ml^{-1}\mx\vb, \sigma^2 \mi) \]
and
\[ f(\vy) = f(\ml^{-1}\vy) \lvert\det(\ml^{-1})\rvert = f(\ml^{-1}\vy) \det(\mh)^{-1/2} \]
where \(f(\cdot)\) denotes the density function. This transformation reduces the problem to Bayesian variable selection regression (BVSR; Guan and Stephens 2011) on \(\ml^{-1}\mx\) and \(\ml^{-1}\vy\), which can be solved by varbvs (Carbonetto and Stephens 2012; Carbonetto, Zhou, and Stephens 2017).
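A minimal NumPy sketch of the transformation, assuming \(\ml\) is the lower Cholesky factor of \(\mh\): it forms the whitened statistics \(\tilde{\mx}'\tilde{\mx}\), \(\tilde{\mx}'\tilde{\vy}\), \(\tilde{\vy}'\tilde{\vy}\) (with \(\tilde{\mx} = \ml^{-1}\mx\), \(\tilde{\vy} = \ml^{-1}\vy\)) together with \(\ln\det(\mh)\), then checks the change-of-variables identity by comparing the whitened log density with the direct log density of \(\vy \sim \Normal(\mx\vb, \sigma^2\mh)\). The helper name `whitened_statistics` is hypothetical, and the data are random.

```python
import numpy as np


def whitened_statistics(X, y, sigma2_u):
    """Return (Xt'Xt, Xt'yt, yt'yt, ln det H) with Xt = L^{-1} X and yt = L^{-1} y.

    Hypothetical helper: H = sigma2_u * X X' + I and L is its lower Cholesky
    factor (H = L L'). Forming H needs the full n x p matrix X.
    """
    n = X.shape[0]
    H = sigma2_u * (X @ X.T) + np.eye(n)
    L = np.linalg.cholesky(H)
    Xt = np.linalg.solve(L, X)
    yt = np.linalg.solve(L, y)
    # det(H) = prod(diag(L))^2
    logdet_H = 2.0 * np.sum(np.log(np.diag(L)))
    return Xt.T @ Xt, Xt.T @ yt, float(yt @ yt), logdet_H


rng = np.random.default_rng(1)
n, p, sigma2, sigma2_u = 50, 20, 0.8, 0.05
X = rng.standard_normal((n, p))
b = rng.normal(size=p)
y = rng.normal(size=n)
XtX_w, Xty_w, yty_w, logdet_H = whitened_statistics(X, y, sigma2_u)

# Whitened log density from the whitened statistics only:
# ln N(yt; Xt b, sigma2 I) + ln|det(L^{-1})|, where ln|det(L^{-1})| = -0.5 ln det H
rss_w = yty_w - 2 * b @ Xty_w + b @ XtX_w @ b
log_f_whitened = -0.5 * n * np.log(2 * np.pi * sigma2) - 0.5 * rss_w / sigma2 - 0.5 * logdet_H

# Direct log density of y ~ N(Xb, sigma2 H)
H = sigma2_u * (X @ X.T) + np.eye(n)
r = y - X @ b
log_f_direct = (-0.5 * n * np.log(2 * np.pi * sigma2) - 0.5 * np.linalg.slogdet(H)[1]
                - 0.5 * r @ np.linalg.solve(H, r) / sigma2)
```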
To approximate the posterior distribution \(p(\vb \mid \mx, \vy, \cdot)\), we use variational inference to find the distribution \(q(\vb \mid \cdot)\) with minimum KL divergence to the true posterior (Blei et al. 2017) within the fully factorized (mean-field) family \(q(\vb \mid \cdot) = \prod_j q(b_j)\), with each factor parameterized as
\[ q(b_j) = \alpha_j \Normal(\mu_j, s_j^2) + (1 - \alpha_j) \delta_0(\cdot) \]where \(\alpha_j\) is the variational posterior probability that \(b_j \neq 0\).
Minimizing the KL divergence directly is intractable, because it depends on the model evidence, but it is equivalent to maximizing the evidence lower bound (ELBO) on the log model evidence
\begin{multline*} \elbo(\tilde{\mx}, \tilde{\vy}, \cdot) = -\frac{n}{2} \ln(2 \pi \sigma^2) - \frac{1}{2 \sigma^2} \Vert \tilde{\vy} - \tilde{\mx} E[\vb] \Vert^2 - \frac{1}{2 \sigma^2} \sum_j (\tilde{\mx}'\tilde{\mx})_{jj} V[b_j]\\ + \sum_j \frac{\alpha_j}{2} \left[1 + \ln\left(\frac{s_j^2}{\sigma_b^2 \sigma^2}\right) - \frac{s_j^2 + \mu_j^2}{\sigma_b^2\sigma^2}\right] - \sum_j \left[\alpha_j \ln\left(\frac{\alpha_j}{\pi_0}\right) + (1 - \alpha_j) \ln\left(\frac{1 - \alpha_j}{1 - \pi_0}\right)\right] \end{multline*}where \(\tilde{\mx} = \ml^{-1}\mx\) and \(\tilde{\vy} = \ml^{-1}\vy\). We can use coordinate ascent to maximize the ELBO (Carbonetto and Stephens 2012)
\begin{align*} s_j^2 &:= \frac{\sigma^2}{(\tilde{\mx}'\tilde{\mx})_{jj} + 1 / \sigma_b^2}\\ \mu_j &:= \frac{s_j^2}{\sigma^2}\left((\tilde{\mx}'\tilde{\vy})_j - \sum_{t \neq j} (\tilde{\mx}'\tilde{\mx})_{jt} \alpha_t \mu_t\right)\\ \frac{\alpha_j}{1 - \alpha_j} &:= \left(\frac{\pi_0}{1 - \pi_0}\right)\left(\frac{s_j}{\sigma_b \sigma}\right)\exp\left(\frac{\mu_j^2}{2 s_j^2}\right) \end{align*}The solution to the transformed problem is also a solution to the original problem, because
\[ \elbo(\mx, \vy, \cdot) = -\frac{1}{2}\ln\det(\mh) + \elbo(\tilde{\mx}, \tilde{\vy}, \cdot) \]
and the first term does not depend on the variational parameters \(\alpha_j, \mu_j, s_j^2\) or on \(\sigma^2, \sigma_b^2, \pi_0\).
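The ELBO and the coordinate updates can be written purely in terms of \(\tilde{\mx}'\tilde{\mx}\), \(\tilde{\mx}'\tilde{\vy}\), \(\tilde{\vy}'\tilde{\vy}\) and \(n\), using \(\Vert\tilde{\vy} - \tilde{\mx}E[\vb]\Vert^2 = \tilde{\vy}'\tilde{\vy} - 2E[\vb]'\tilde{\mx}'\tilde{\vy} + E[\vb]'\tilde{\mx}'\tilde{\mx}E[\vb]\). The sketch below is a plain NumPy translation, not the vbslmm implementation, and the function names are hypothetical. It follows the convention of the coordinate updates, in which \(\alpha_j\) is the probability under \(q\) that \(b_j \neq 0\) and \(\pi_0\) is the prior probability of the same event, so \(E[b_j] = \alpha_j\mu_j\) and \(V[b_j] = \alpha_j(s_j^2 + \mu_j^2) - \alpha_j^2\mu_j^2\). Because each update maximizes the ELBO over one factor \(q(b_j)\), the recorded ELBO should never decrease across sweeps.

```python
import numpy as np


def elbo(XtX, Xty, yty, n, alpha, mu, s2, sigma2, sigma2_b, pi0):
    """ELBO of the (whitened) BVSR problem, computed from sufficient statistics."""
    Eb = alpha * mu                          # E[b_j]
    Vb = alpha * (s2 + mu**2) - Eb**2        # V[b_j]
    # ||y - X E[b]||^2 expanded in terms of X'X, X'y, y'y
    rss = yty - 2.0 * Eb @ Xty + Eb @ XtX @ Eb
    out = -0.5 * n * np.log(2 * np.pi * sigma2) - 0.5 * (rss + np.diag(XtX) @ Vb) / sigma2
    v0 = sigma2_b * sigma2                   # prior variance of a nonzero b_j
    out += np.sum(0.5 * alpha * (1 + np.log(s2 / v0) - (s2 + mu**2) / v0))
    # Minus KL(Bernoulli(alpha_j) || Bernoulli(pi0)); clip so log(0) never occurs
    a = np.clip(alpha, 1e-12, 1 - 1e-12)
    out -= np.sum(a * np.log(a / pi0) + (1 - a) * np.log((1 - a) / (1 - pi0)))
    return out


def cavi_sweep(XtX, Xty, alpha, mu, sigma2, sigma2_b, pi0):
    """One pass of the coordinate updates over j = 1, ..., p (modifies alpha, mu in place)."""
    d = np.diag(XtX)
    s2 = sigma2 / (d + 1.0 / sigma2_b)       # constant while hyperparameters are fixed
    Xr = XtX @ (alpha * mu)                  # running sum_t (X'X)_{jt} alpha_t mu_t
    logit_pi0 = np.log(pi0 / (1 - pi0))
    for j in range(len(d)):
        old = alpha[j] * mu[j]
        # (X'y)_j - sum_{t != j} (X'X)_{jt} alpha_t mu_t
        mu[j] = s2[j] / sigma2 * (Xty[j] - (Xr[j] - d[j] * old))
        logit = logit_pi0 + 0.5 * np.log(s2[j] / (sigma2_b * sigma2)) + mu[j] ** 2 / (2 * s2[j])
        alpha[j] = 1.0 / (1.0 + np.exp(-logit))
        Xr += XtX[:, j] * (alpha[j] * mu[j] - old)
    return alpha, mu, s2


# Toy run with sigma2_u = 0, so H = I and the whitened statistics are the raw ones
rng = np.random.default_rng(3)
n, p = 200, 30
X = rng.standard_normal((n, p))
b_true = np.zeros(p)
b_true[:3] = [1.0, -0.8, 0.5]
y = X @ b_true + rng.standard_normal(n)
XtX, Xty, yty = X.T @ X, X.T @ y, float(y @ y)

alpha, mu = np.full(p, 0.1), np.zeros(p)
sigma2, sigma2_b, pi0 = 1.0, 1.0, 0.1
trace = []
for _ in range(20):
    alpha, mu, s2 = cavi_sweep(XtX, Xty, alpha, mu, sigma2, sigma2_b, pi0)
    trace.append(elbo(XtX, Xty, yty, n, alpha, mu, s2, sigma2, sigma2_b, pi0))
```

Caching \(\sum_t (\tilde{\mx}'\tilde{\mx})_{jt}\alpha_t\mu_t\) keeps each coordinate update at \(O(p)\) instead of recomputing the full sum.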
We can further use VBEM (Beal 2003) to optimize the hyperparameters
\begin{align*} \sigma^2 &:= \frac{\Vert \tilde{\vy} - \tilde{\mx}E[\vb] \Vert^2 + \sum_j (\tilde{\mx}'\tilde{\mx})_{jj} V[b_j] + \frac{1}{\sigma^2_b} \sum_j \alpha_j (s_j^2 + \mu_j^2)}{n + \sum_j \alpha_j}\\ \sigma_b^2 &:= \frac{\sum_j \alpha_j (s_j^2 + \mu_j^2)}{\sigma^2 \sum_j \alpha_j}\\ \ln\left(\frac{\pi_0}{1 - \pi_0}\right) &:= \ln\left(\frac{\sum_j \alpha_j}{\sum_j (1 - \alpha_j)}\right) \end{align*}The derivative with respect to \(\sigma^2_u\) is also available, and we can use line search to numerically update it.
\[ \frac{\partial\elbo}{\partial \sigma^2_u} = \frac{1}{2 \sigma^2} (\vy - \mx E[\vb])' \mh^{-1}\mk\mh^{-1} (\vy - \mx E[\vb]) + \frac{1}{2\sigma^2} \sum_j (\mx'\mh^{-1}\mk\mh^{-1}\mx)_{jj} V[b_j] - \frac{1}{2}\tr(\mh^{-1}\mk) \]
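The closed-form hyperparameter updates and the \(\sigma^2_u\) derivative translate directly into code. In this sketch, `m_step`, `elbo_sigma2_u_terms`, and `grad_sigma2_u` are hypothetical names: `m_step` works from (whitened) sufficient statistics, while the derivative uses the full data, since it involves \(\mx\), \(\vy\), and \(\mk\). The mean-field \(q\) enters the derivative only through the diagonal variances \(V[b_j]\). Comparing the analytic derivative with a central finite difference of the \(\sigma^2_u\)-dependent ELBO terms is a cheap check before plugging it into a line search.

```python
import numpy as np


def m_step(XtX, Xty, yty, n, alpha, mu, s2, sigma2, sigma2_b):
    """VBEM updates for sigma2, sigma2_b, pi0 from (whitened) sufficient statistics."""
    Eb = alpha * mu
    Vb = alpha * (s2 + mu**2) - Eb**2
    rss = yty - 2.0 * Eb @ Xty + Eb @ XtX @ Eb
    m2 = np.sum(alpha * (s2 + mu**2))       # sum_j alpha_j (s_j^2 + mu_j^2)
    sigma2_new = (rss + np.diag(XtX) @ Vb + m2 / sigma2_b) / (n + alpha.sum())
    # Sequential choice: sigma2_b is updated using the new sigma2
    sigma2_b_new = m2 / (sigma2_new * alpha.sum())
    pi0_new = alpha.mean()                   # logit(pi0) = ln(sum alpha / sum (1 - alpha))
    return sigma2_new, sigma2_b_new, pi0_new


def elbo_sigma2_u_terms(X, y, Eb, Vb, sigma2, sigma2_u):
    """Full-data ELBO terms that depend on sigma2_u, with q(b) held fixed."""
    n = X.shape[0]
    H = sigma2_u * (X @ X.T) + np.eye(n)
    r = y - X @ Eb
    diag_XHX = np.sum(X * np.linalg.solve(H, X), axis=0)  # diag(X' H^{-1} X)
    quad = r @ np.linalg.solve(H, r) + diag_XHX @ Vb
    return -0.5 * np.linalg.slogdet(H)[1] - 0.5 * quad / sigma2


def grad_sigma2_u(X, y, Eb, Vb, sigma2, sigma2_u):
    """Analytic derivative of the full-data ELBO with respect to sigma2_u."""
    n = X.shape[0]
    K = X @ X.T
    Hinv = np.linalg.inv(sigma2_u * K + np.eye(n))
    M = Hinv @ K @ Hinv                      # d(H^{-1}) / d(sigma2_u) = -M
    r = y - X @ Eb
    diag_XMX = np.sum(X * (M @ X), axis=0)  # diag(X' H^{-1} K H^{-1} X)
    return 0.5 * (r @ M @ r + diag_XMX @ Vb) / sigma2 - 0.5 * np.trace(Hinv @ K)


rng = np.random.default_rng(4)
X = rng.standard_normal((30, 8))
y = rng.standard_normal(30)
alpha, mu, s2 = rng.random(8), rng.normal(size=8), np.full(8, 0.05)
sigma2, sigma2_b, sigma2_u = 1.3, 0.5, 0.2
Eb = alpha * mu
Vb = alpha * (s2 + mu**2) - Eb**2

# For brevity, m_step gets raw statistics (the whitened ones when sigma2_u = 0)
new_sigma2, new_sigma2_b, new_pi0 = m_step(X.T @ X, X.T @ y, float(y @ y), 30, alpha, mu, s2, sigma2, sigma2_b)
g = grad_sigma2_u(X, y, Eb, Vb, sigma2, sigma2_u)
```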
Alternatively, we can run VBEM at each value on a grid of \(\sigma^2_u\) and use the exponentiated ELBO values, \(\exp(\elbo)\), as unnormalized importance weights (Carbonetto and Stephens 2012).
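Given the ELBO reached by VBEM at each grid value, normalized weights can be computed stably by subtracting the maximum log weight before exponentiating (the log-sum-exp trick). This sketch assumes a uniform prior over the grid by default, with an optional log prior added to the log weights; `importance_weights` is a hypothetical helper, and the ELBO values below are illustrative numbers, not results.

```python
import numpy as np


def importance_weights(elbos, log_prior=None):
    """Normalized weights proportional to exp(ELBO) * prior over grid points."""
    logw = np.asarray(elbos, dtype=float)
    if log_prior is not None:
        logw = logw + np.asarray(log_prior, dtype=float)
    logw = logw - logw.max()  # shift so the largest weight is exp(0) = 1; no overflow
    w = np.exp(logw)
    return w / w.sum()


# Illustrative numbers only: ELBO after VBEM at each grid value of sigma2_u
grid = np.array([0.001, 0.01, 0.1, 1.0])
elbos = np.array([-1520.3, -1512.8, -1514.1, -1530.6])
w = importance_weights(elbos)
# Weighted average of sigma2_u under the grid approximation
print(w.round(3), float(w @ grid))
```

The same weights can average any per-grid-point quantity, such as a posterior mean of PVE computed within each VBEM run.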
Sufficient statistics VBEM
The algorithm described above requires access to the full data \(\mx, \vy\) to compute quantities such as \((\ml^{-1}\mx)'(\ml^{-1}\mx)\). We now derive an algorithm that depends only on the sufficient statistics \(\mx'\mx, \mx'\vy, \vy'\vy, n\).
We have
\begin{equation*} \left[\begin{array}{c} \vy \\ \vu \end{array}\right] \mid \mx, \vb, \sigma^2, \sigma^2_u \sim \Normal \left( \left[\begin{array}{c} \mx\vb \\ \mathbf{0} \end{array}\right], \sigma^2 \left[\begin{array}{cc} \mh & \sigma^2_u \mk \\ \sigma^2_u \mk & \sigma^2_u \mk \end{array}\right] \right) \end{equation*}because \(\vu\) and \(\ve\) are independent of each other and of \(\vb\). Therefore,
\[ \vu \mid \mx, \vy, \vb, \sigma^2, \sigma^2_u \sim \Normal\left(\sigma^2_u \mk\mh^{-1}(\vy - \mx\vb), \sigma^2 (\mi - \mh^{-1})\right), \]
and the expected log joint is
\begin{multline*} E[\ln p(\vy \mid \mx, \vb, \vu, \cdot) + \ln p(\vu \mid \cdot)] = -\frac{n}{2}\ln(2\pi\sigma^2) - \frac{1}{2}\ln\det(2\pi\sigma^2_u\sigma^2 \mk) - \frac{1}{2\sigma^2} \Vert \vy - \mx E[\vb] - E[\vu] \Vert^2\\ - \frac{1}{2\sigma^2}\left[\sum_j (\mx'\mx)_{jj} V[b_j] + \sum_i V[u_i] + 2\sum_i \Cov[(\mx\vb)_i, u_i]\right] - \frac{1}{2 \sigma^2_u \sigma^2} E[\vu' \mk^{-1} \vu], \end{multline*}where expectations are taken with respect to \(q(\vb \mid \cdot) p(\vu \mid \mx, \vy, \vb, \cdot)\), and the ELBO is
\begin{multline*} \elbo(\cdot) = E[\ln p(\vy \mid \mx, \vb, \vu, \cdot) + \ln p(\vu \mid \cdot) - \ln p(\vu \mid \mx, \vy, \vb, \cdot)]\\ + \sum_j \frac{\alpha_j}{2} \left[1 + \ln\left(\frac{s_j^2}{\sigma_b^2 \sigma^2}\right) - \frac{s_j^2 + \mu_j^2}{\sigma_b^2\sigma^2}\right] - \sum_j \left[\alpha_j \ln\left(\frac{\alpha_j}{\pi_0}\right) + (1 - \alpha_j) \ln\left(\frac{1 - \alpha_j}{1 - \pi_0}\right)\right] \end{multline*}
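The conditional distribution of \(\vu\) used in these expectations follows from the standard Gaussian conditioning formula applied to the joint covariance of \((\vy, \vu)\) given \(\vb\). Using \(\sigma^2_u\mk = \mh - \mi\), the conditional mean is \(\sigma^2_u\mk\mh^{-1}(\vy - \mx\vb)\) and the conditional covariance simplifies to \(\sigma^2(\mi - \mh^{-1})\). The sketch below evaluates the generic formula directly with NumPy and compares it with these closed forms; the sizes and parameter values are arbitrary.

```python
import numpy as np

rng = np.random.default_rng(5)
n, p, sigma2, sigma2_u = 25, 10, 0.7, 0.3
X = rng.standard_normal((n, p))
b = rng.normal(size=p)
y = rng.normal(size=n)
K = X @ X.T
H = sigma2_u * K + np.eye(n)

# Blocks of the joint covariance of (y, u) given b
S_yy = sigma2 * H
S_uy = sigma2 * sigma2_u * K
S_uu = sigma2 * sigma2_u * K

# Generic Gaussian conditioning: E[u | y, b] and Cov[u | y, b]
r = y - X @ b
mean_generic = S_uy @ np.linalg.solve(S_yy, r)
cov_generic = S_uu - S_uy @ np.linalg.solve(S_yy, S_uy.T)

# Closed forms, using sigma2_u * K = H - I
Hinv = np.linalg.inv(H)
mean_closed = sigma2_u * K @ Hinv @ r
cov_closed = sigma2 * (np.eye(n) - Hinv)
```

Note that the conditional covariance does not depend on \(\vb\), so only the conditional mean couples \(\vu\) to \(q(\vb \mid \cdot)\).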