Multivariate Normal Reparameterization in NumPyro

Posted on: March 19, 2025

Reparameterization is a core technique in variational inference and in HMC-based samplers for Bayesian models. Here, we revisit the multivariate normal (MVN) distribution and show how to rewrite it as a deterministic transformation of standard normal noise. This separates the randomness from the location (mean) and scale (covariance) parameters, which can improve sampling efficiency.

Background

A multivariate normal distribution with mean vector \( \boldsymbol{\mu} \in \mathbb{R}^d \) and covariance matrix \( \boldsymbol{\Sigma} \in \mathbb{R}^{d \times d} \) is written as

\[ \mathbf{y} \sim \mathcal{N}(\boldsymbol{\mu}, \boldsymbol{\Sigma}) \]

This can be reparameterized as \[ \mathbf{y} = \boldsymbol{\mu} + L \mathbf{z}, \quad \text{where } \mathbf{z} \sim \mathcal{N}(\mathbf{0}, I) \text{ and } LL^\top = \boldsymbol{\Sigma}. \] Here, \( L \) is the lower-triangular Cholesky factor of the covariance \( \boldsymbol{\Sigma} \). It exists and is unique (taking positive diagonal entries) whenever \( \boldsymbol{\Sigma} \) is symmetric positive-definite.
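To see why this works, note that an affine transformation of a Gaussian vector is again Gaussian, so it is enough to check that the first two moments match. Using \( \mathbb{E}[\mathbf{z}] = \mathbf{0} \) and \( \mathbb{E}[\mathbf{z}\mathbf{z}^\top] = I \):

\[ \begin{aligned} \mathbb{E}[\mathbf{y}] &= \boldsymbol{\mu} + L\,\mathbb{E}[\mathbf{z}] = \boldsymbol{\mu}, \\ \operatorname{Cov}(\mathbf{y}) &= \mathbb{E}\big[(L\mathbf{z})(L\mathbf{z})^\top\big] = L\,\mathbb{E}[\mathbf{z}\mathbf{z}^\top]\,L^\top = L L^\top = \boldsymbol{\Sigma}. \end{aligned} \]

Any square matrix with \( LL^\top = \boldsymbol{\Sigma} \) would give the same distribution; the Cholesky factor is simply a cheap and numerically stable choice.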

Implementation in NumPyro

Here is a minimal working example that samples directly from the MVN with JAX and NumPyro:

import jax.numpy as jnp
from jax import random
from numpyro.distributions import MultivariateNormal

# Setup
key = random.PRNGKey(0)
num_samples = 1000
dim = 3

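# Generate a random mean and covariance, using a fresh key for each draw
key, k_mu, k_A, k_sample = random.split(key, 4)
mu = random.normal(k_mu, (dim,))
A = random.normal(k_A, (dim, dim))
cov = A @ A.T + 1e-3 * jnp.eye(dim)  # A A^T is only positive semi-definite; the jitter makes it positive-definite

# Sample directly from the MVN
samples = MultivariateNormal(loc=mu, covariance_matrix=cov).sample(k_sample, (num_samples,))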

Reparameterized Sampling

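To sample using the reparameterized form \( \mathbf{y} = \boldsymbol{\mu} + L\mathbf{z} \), compute the Cholesky factor of \( \boldsymbol{\Sigma} \) and draw standard normal noise. Because the draws are stored as rows, the per-draw product \( L\mathbf{z}_i \) becomes the row \( \mathbf{z}_i^\top L^\top \), so the whole batch is computed as z @ L.T: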

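# Lower-triangular Cholesky factor, satisfying L @ L.T == cov (up to rounding)
L = jnp.linalg.cholesky(cov)

# Standard normal noise, one draw per row
key, subkey = random.split(key)
z = random.normal(subkey, (num_samples, dim))

# y_i = mu + L z_i for every row i, computed as a single matrix product
samples_reparam = mu + z @ L.T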

Visual Validation

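We can check that the two methods sample from the same distribution by comparing pairwise plots (code omitted for brevity). The individual draws differ, since they use different random keys, but their joint distributions should agree.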

Figure 1. Pairwise plots of direct vs. reparameterized samples across dimensions show substantial overlap.

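This reparameterization, often called the non-centered parameterization, is useful in hierarchical Bayesian models. When the data only weakly constrain a group-level variable, the centered form ties that variable tightly to its location and scale hyperparameters, producing funnel-shaped posteriors that HMC explores poorly. Sampling \( \mathbf{z} \) instead and computing \( \mathbf{y} \) deterministically removes this coupling from the prior. When the data are strongly informative, the centered form can sample better, so the choice is problem-dependent.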