Reparametrizing a Multivariate Normal

Posted on: March 19, 2025

A multivariate normal random vector \(y\in \mathbb{R}^d\) with mean \( \mu \in \mathbb{R}^d \) and covariance matrix \( \Sigma \in \mathbb{R}^{d \times d} \) is written as

\begin{align} y &\sim \mathcal{N}(\mu, \Sigma) \label{eq-1} \end{align}

If the covariance matrix is symmetric positive definite, it has a unique Cholesky factorization \(\Sigma = L L^T\), where \(L \in \mathbb{R}^{d\times d}\) is a lower triangular matrix with positive diagonal entries.
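A quick numerical sanity check of these properties, sketched under our own assumption of an illustrative matrix built as \(AA^T\) plus a small diagonal jitter (not from the original setup):

```python
import jax.numpy as jnp
from jax import random

# Illustrative SPD matrix (assumption): A A^T is PSD; the jitter makes it safely PD
key = random.PRNGKey(1)
A = random.normal(key, (3, 3))
Sigma = A @ A.T + 1e-3 * jnp.eye(3)

# jnp.linalg.cholesky returns the lower-triangular factor L with Sigma = L @ L.T
L = jnp.linalg.cholesky(Sigma)

is_lower = bool(jnp.allclose(L, jnp.tril(L)))            # nothing above the diagonal
has_pos_diag = bool(jnp.all(jnp.diag(L) > 0))            # positive diagonal entries
reconstructs = bool(jnp.allclose(L @ L.T, Sigma, rtol=1e-4, atol=1e-4))

print(is_lower, has_pos_diag, reconstructs)
```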

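Now, recall that an affine transformation of a Gaussian vector is again Gaussian: if \(x \sim \mathcal{N}(m, S)\), then \(Ax + b \sim \mathcal{N}(Am + b, A S A^T)\). Taking \(A = L\) and \(b = 0\) for a standard normal \(z\) gives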

\begin{align} z &\sim \mathcal{N}(0, I_d) \\[6pt] Lz &\sim \mathcal{N}\!\left(0, L I_d L^T\right) \\[6pt] &= \mathcal{N}\!\left(0, L L^T\right) \\[6pt] &= \mathcal{N}\!\left(0, \Sigma\right) \end{align}

where \(I_d\) is the \(d \times d\) identity matrix. Adding the constant vector \(\mu\) shifts the mean without changing the covariance, so

\begin{align} \mu + Lz &\sim \mathcal{N}\!\left(\mu, \Sigma\right) \label{eq-2} \end{align}

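Hence, Eq. \eqref{eq-2} provides a reparametrization of Eq. \eqref{eq-1}: rather than sampling \(y\) directly, we can sample \(z \sim \mathcal{N}(0, I_d)\) and set \(y = \mu + Lz\).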
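The mean and covariance in Eq. \eqref{eq-2} can also be checked directly with linearity of expectation, using \(\mathbb{E}[z] = 0\) and \(\mathbb{E}[zz^T] = I_d\):

\begin{align} \mathbb{E}[\mu + Lz] &= \mu + L\,\mathbb{E}[z] = \mu \\[6pt] \operatorname{Cov}(\mu + Lz) &= \mathbb{E}\!\left[(Lz)(Lz)^T\right] = L\,\mathbb{E}\!\left[z z^T\right] L^T = L I_d L^T = \Sigma \end{align}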

Implementation in NumPyro

Here is a minimal working example in JAX and NumPyro that sets up \(\mu\) and \(\Sigma\) and draws samples directly from Eq. \eqref{eq-1}:

import jax.numpy as jnp
from jax import random
from numpyro import distributions as dist

# Setup
key = random.PRNGKey(0)
num_samples = 10_000    # N = 10_000
dim = 3                 # d = 3

# Generate mean
key, sub = random.split(key)
mu = random.normal(sub, (dim,))

# Generate covariance
key, sub = random.split(key)
A = random.normal(sub, (dim, dim))
cov = A @ A.T + 1e-3 * jnp.eye(dim)   # A @ A.T is only guaranteed PSD; the small jitter makes it positive definite

# Generate samples directly from N(mu, cov)
key, sub = random.split(key)
direct = dist.MultivariateNormal(mu, covariance_matrix=cov).sample(sub, (num_samples,))   # (N, d)

Reparametrized Sampling

To sample using the reparametrized form in Eq. \eqref{eq-2}, we compute the lower Cholesky factor \(L\) of \( \Sigma\), draw \(z\) from the standard normal, and form \(\mu + Lz\):

# Lower Cholesky factor: cov = L @ L.T
L = jnp.linalg.cholesky(cov)    # (d, d)

# Generate samples with reparametrization
key, sub = random.split(key)
z = random.normal(sub, (num_samples, dim))   # (N, d), rows are z ~ N(0, I_d)
reparam = mu + jnp.einsum("ij,nj->ni", L, z)   # (N, d); row n is mu + L @ z[n]
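The `einsum` line applies \(L\) to every row of `z`, i.e. it computes \(\mu + L z_n\) for each sample \(n\). A small sketch (sizes and keys are arbitrary choices for illustration) showing that it agrees with a plain matrix product and with an explicit per-sample matrix-vector product via `vmap`:

```python
import jax.numpy as jnp
from jax import random, vmap

# Illustrative inputs (assumption): small sizes so the comparison is cheap
key = random.PRNGKey(0)
k1, k2, k3 = random.split(key, 3)
dim, n = 3, 5
mu = random.normal(k1, (dim,))
A = random.normal(k2, (dim, dim))
L = jnp.linalg.cholesky(A @ A.T + 1e-3 * jnp.eye(dim))
z = random.normal(k3, (n, dim))

# Three equivalent ways to compute mu + L @ z[n] for every row n
via_einsum = mu + jnp.einsum("ij,nj->ni", L, z)   # as in the post
via_matmul = mu + z @ L.T                           # (n, d) @ (d, d)
via_vmap = vmap(lambda zn: mu + L @ zn)(z)          # explicit matvec per sample

print(via_einsum.shape)  # (5, 3)
```

Of the three, `z @ L.T` is arguably the most compact; the `einsum` form makes the index bookkeeping explicit.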

Visual Validation

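As a visual check, we can compare pairwise plots of the two sample sets (plotting code omitted). Because the two methods use different random keys, the individual samples differ; what should match is their distribution.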

Figure 1. Pairwise plots of direct vs. reparametrized samples across dimensions show large overlap between the two sets.
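The pairwise plots are a qualitative check. A rough quantitative complement compares the empirical mean and covariance of reparametrized samples against \(\mu\) and \(\Sigma\). This is a minimal sketch under our own choices (a fresh PRNG key, \(N = 200{,}000\), a small diagonal jitter on \(\Sigma\), and errors normalized by the largest entry of \(\Sigma\)), none of which come from the original code:

```python
import jax.numpy as jnp
from jax import random

key = random.PRNGKey(0)
k_mu, k_A, k_z = random.split(key, 3)
dim, num_samples = 3, 200_000   # larger N than in the post, to tighten the estimates

# Target parameters (same construction as above, plus an assumed jitter)
mu = random.normal(k_mu, (dim,))
A = random.normal(k_A, (dim, dim))
cov = A @ A.T + 1e-3 * jnp.eye(dim)

# Reparametrized samples: mu + L z with z ~ N(0, I_d)
L = jnp.linalg.cholesky(cov)
z = random.normal(k_z, (num_samples, dim))
reparam = mu + z @ L.T

# Empirical moments; rowvar=False treats columns as variables
emp_mean = reparam.mean(axis=0)
emp_cov = jnp.cov(reparam, rowvar=False)

# Scale-free errors: the largest |entry| of a PSD matrix sits on its diagonal
scale = jnp.max(jnp.abs(cov))
mean_err = float(jnp.max(jnp.abs(emp_mean - mu)) / jnp.sqrt(scale))
cov_err = float(jnp.max(jnp.abs(emp_cov - cov)) / scale)
print(mean_err, cov_err)
```

Both errors should shrink roughly like \(1/\sqrt{N}\); the same check can be applied to the `direct` samples.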

I have found this reparametrization to be useful when fitting hierarchical models.