Reparametrizing a Multivariate Normal
Posted on: March 19, 2025
A multivariate normal \(y\in \mathbb{R}^d\) with mean \( \mu \in \mathbb{R}^d \) and covariance matrix \( \Sigma \in \mathbb{R}^{d \times d} \) is written as
\begin{align} y &\sim \mathcal{N}(\mu, \Sigma) \label{eq-1} \end{align}
Here, assuming the covariance matrix is symmetric positive definite, there exists a unique Cholesky factorization \(\Sigma = L L^T\), where \(L \in \mathbb{R}^{d\times d}\) is a lower triangular matrix with positive diagonal entries.
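As a small illustration (the numbers here are chosen for this example only and are not used elsewhere in the post), take \(d = 2\). Multiplying out the factor recovers the covariance:
\begin{align} \Sigma = \begin{pmatrix} 4 & 2 \\ 2 & 3 \end{pmatrix}, \qquad L = \begin{pmatrix} 2 & 0 \\ 1 & \sqrt{2} \end{pmatrix}, \qquad L L^T = \begin{pmatrix} 2 \cdot 2 & 2 \cdot 1 \\ 1 \cdot 2 & 1 \cdot 1 + \sqrt{2} \cdot \sqrt{2} \end{pmatrix} = \begin{pmatrix} 4 & 2 \\ 2 & 3 \end{pmatrix} \end{align}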
Now, consider
\begin{align} z &\sim \mathcal{N}(0, I_d) \\[6pt] Lz &\sim \mathcal{N}\!\left(0, L I_d L^T\right) \\[6pt] &= \mathcal{N}\!\left(0, L L^T\right) \\[6pt] &= \mathcal{N}\!\left(0, \Sigma\right) \end{align}
where \(I_d\) is the \(d \times d\) identity matrix. Adding the constant vector \(\mu\) shifts the mean and leaves the covariance unchanged, so
\begin{align} \mu + Lz &\sim \mathcal{N}\!\left(\mu, \Sigma\right) \label{eq-2} \end{align}
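Hence, Eq. \eqref{eq-2} provides a reparametrization of Eq. \eqref{eq-1}: drawing \(y\) is equivalent to drawing \(z\) from a standard normal and applying the deterministic map \(z \mapsto \mu + Lz\).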
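For completeness, the mean and covariance of \(\mu + Lz\) can be checked directly. This uses only linearity of expectation and \(\operatorname{Cov}(z) = I_d\); an affine map of a Gaussian is again Gaussian, so these two moments determine the distribution:
\begin{align} \mathbb{E}[\mu + Lz] &= \mu + L\,\mathbb{E}[z] = \mu \\[6pt] \operatorname{Cov}(\mu + Lz) &= L\,\operatorname{Cov}(z)\,L^T = L I_d L^T = \Sigma \end{align}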
Implementation in NumPyro
Here's a minimal working example in JAX and NumPyro. We first draw a random mean and covariance, then sample from the multivariate normal directly:
import jax.numpy as jnp
from jax import random
from numpyro import distributions as dist
# Setup
key = random.PRNGKey(0)
num_samples = 10_000 # N = 10_000
dim = 3 # d = 3
# Generate mean
key, sub = random.split(key)
mu = random.normal(sub, (dim,))
# Generate covariance
key, sub = random.split(key)
A = random.normal(sub, (dim, dim))
cov = A @ A.T + 1e-3 * jnp.eye(dim) # symmetric positive definite; A @ A.T alone is only positive semidefinite, so add a small jitter
# Generate samples directly
key, sub = random.split(key)
direct = dist.MultivariateNormal(mu, cov).sample(sub, (num_samples,))
Reparametrized Sampling
To sample using the reparametrized form in Eq. \eqref{eq-2}, we compute the Cholesky decomposition of \( \Sigma\), draw \(z\) from the standard normal distribution, and apply the map \(\mu + Lz\) to each draw:
# Lower Cholesky factor
L = jnp.linalg.cholesky(cov) # (d, d)
# Generate samples with reparametrization
key, sub = random.split(key)
z = random.normal(sub, (num_samples, dim)) # (N, d)
reparam = mu + jnp.einsum("ij,nj->ni", L, z) # (N, d); row n is mu + L @ z[n]
Visual Validation
We can confirm that the two methods produce samples from the same distribution by comparing pairwise plots (code omitted).
I have found this reparametrization to be useful when fitting hierarchical models.