Skip to content

Commit

Permalink
Add draft of SkewMultivariateNormal
Browse files Browse the repository at this point in the history
  • Loading branch information
colehaus committed Sep 1, 2022
1 parent 00aa9ef commit 1809d35
Show file tree
Hide file tree
Showing 5 changed files with 322 additions and 0 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ build
*.pyo
/build
/dist
/.hypothesis

# IDE
.idea
Expand Down
2 changes: 2 additions & 0 deletions numpyro/distributions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
Pareto,
RelaxedBernoulli,
RelaxedBernoulliLogits,
SkewMultivariateNormal,
SoftLaplace,
StudentT,
Uniform,
Expand Down Expand Up @@ -158,6 +159,7 @@
"MultivariateStudentT",
"LowRankMultivariateNormal",
"Normal",
"SkewMultivariateNormal",
"NegativeBinomialProbs",
"NegativeBinomialLogits",
"NegativeBinomial2",
Expand Down
136 changes: 136 additions & 0 deletions numpyro/distributions/continuous.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,10 @@
# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
# POSSIBILITY OF SUCH DAMAGE.

from typing import Union, cast

import numpy as np
from numpy.typing import NDArray

from jax import lax
from jax.experimental.sparse import BCOO
Expand Down Expand Up @@ -1731,6 +1734,139 @@ def variance(self):
return jnp.broadcast_to(self.scale**2, self.batch_shape)


def skew_delta(skewers_: NDArray[float], cov_: NDArray[float]):
    """Return ``cov_ @ skewers_ / sqrt(1 + skewers_^T cov_ skewers_)``.

    Leading ("...") dimensions of both arguments are treated as batch
    dimensions by the einsum contractions below.

    :param skewers_: skewness vector(s); the last axis is the event axis.
    :param cov_: square matrix/matrices; the last two axes are contracted
        against ``skewers_``.
    :return: array whose last axis has the same length as that of ``skewers_``.
    """
    # Matrix-vector product cov_ @ skewers_ over the trailing axes.
    cov_times_skew = jnp.einsum("...ij,...j->...i", cov_, skewers_)
    # Quadratic form skewers_^T cov_ skewers_ (one scalar per batch element).
    quad_form = jnp.einsum("...j,...jk,...k->...", skewers_, cov_, skewers_)
    # Re-add a trailing axis so the scalar normalizer divides every component.
    normalizer = jnp.sqrt(1 + quad_form[..., jnp.newaxis])
    return cov_times_skew / normalizer


# Regularized Multivariate Regression Models with Skew-t Error Distributions
# https://epublications.marquette.edu/cgi/viewcontent.cgi?article=1225&context=mscs_fac
class SkewMultivariateNormal(Distribution):
    """Draft multivariate skew-normal distribution.

    The log density implemented by :meth:`log_prob` is::

        log(2) + MVN(loc, scale_tril).log_prob(x)
               + log Phi(sum_k skewers[k] * (x[k] - loc[k]) / std_devs[k])

    where ``std_devs`` is the square root of the diagonal of
    ``scale_tril @ scale_tril.T`` and ``Phi`` is the standard normal CDF.
    With ``skewers == 0`` the last term is ``log(1/2)`` and the density is that
    of ``MultivariateNormal(loc, scale_tril=scale_tril)``.

    :param loc: location vector (a scalar is promoted to shape ``(1,)``).
    :param scale_tril: lower-triangular factor of the (unskewed) covariance.
    :param skewers: skewness vector, applied to the standardized residual.
    :param validate_args: forwarded to ``Distribution.__init__``.
    """

    arg_constraints = {
        "loc": constraints.real_vector,
        "scale_tril": constraints.lower_cholesky,
        "skewers": constraints.real_vector,
    }
    support = constraints.real_vector
    reparametrized_params = ["loc", "scale_tril", "skewers"]
    # Kept for backward compatibility; log_prob now uses log_ndtr directly.
    uv_norm = Normal(0.0, 1.0)

    @staticmethod
    def mk_big_mv_norm(
        loc: NDArray[float], skewers: NDArray[float], scale_tril: NDArray[float]
    ):
        """Build the zero-mean (d + 1)-dimensional normal used for sampling.

        Its covariance is the block matrix ``[[1, delta^T], [delta, cov]]``
        with ``cov = scale_tril @ scale_tril.T`` and
        ``delta = skew_delta(skewers, cov)``.

        :param loc: unused for the values; kept for interface compatibility.
        :param skewers: skewness vector(s) in the scale expected by
            ``skew_delta`` (``__init__`` passes ``skewers / std_devs``).
        :param scale_tril: lower-triangular factor(s) of ``cov``.
        :return: a ``MultivariateNormal`` over the augmented vector.
        """
        cov = jnp.einsum("...ij,...hj->...ih", scale_tril, scale_tril)
        delta_ = skew_delta(skewers, cov)
        # jnp.block concatenates without broadcasting, so bring every block to
        # one common batch shape before assembling the augmented covariance.
        dim = jnp.shape(cov)[-1]
        batch_shape = lax.broadcast_shapes(
            jnp.shape(delta_)[:-1], jnp.shape(cov)[:-2]
        )
        delta_ = jnp.broadcast_to(delta_, batch_shape + (dim,))
        cov = jnp.broadcast_to(cov, batch_shape + (dim, dim))
        cov_star = jnp.block(
            [
                [
                    jnp.ones(batch_shape + (1, 1)),
                    jnp.expand_dims(delta_, axis=-2),
                ],
                [jnp.expand_dims(delta_, axis=-1), cov],
            ]
        )

        # Size the mean from the covariance (not from ``loc``) so that a
        # promoted scalar ``loc`` cannot produce a mismatching dimension.
        return MultivariateNormal(
            loc=jnp.zeros(dim + 1), scale_tril=jnp.linalg.cholesky(cov_star)
        )

    def __init__(
        self,
        loc: Union[NDArray[float], float],
        scale_tril: NDArray[float],
        skewers: NDArray[float],
        validate_args: Union[bool, None] = None,
    ):
        if jnp.ndim(loc) == 0:
            (loc_,) = promote_shapes(loc, shape=(1,))
        else:
            loc_ = cast(NDArray[float], loc)
        batch_shape = lax.broadcast_shapes(
            jnp.shape(loc_)[:-1], jnp.shape(scale_tril)[:-2], jnp.shape(skewers)[:-1]
        )
        (self.loc,) = promote_shapes(loc_, shape=batch_shape + jnp.shape(loc_)[-1:])
        (self.skewers,) = promote_shapes(
            skewers, shape=batch_shape + jnp.shape(skewers)[-1:]
        )
        (self.scale_tril,) = promote_shapes(
            scale_tril, shape=batch_shape + jnp.shape(scale_tril)[-2:]
        )
        cov_batch = jnp.einsum("...ij,...hj->...ih", self.scale_tril, self.scale_tril)
        # Marginal standard deviations (sqrt of the covariance diagonal).
        self._std_devs = jnp.sqrt(jnp.diagonal(cov_batch, axis1=-2, axis2=-1))

        # Used for sampling
        self._big_mv_norm = self.mk_big_mv_norm(
            # The blog post just uses unstandardized skewers here but that leads to
            # a discrepancy between sampling and log_prob
            loc=self.loc,
            # Pass the promoted parameters so batch ranks line up.
            skewers=self.skewers / self._std_devs,
            scale_tril=self.scale_tril,
        )
        # Used for log_prob
        self._mv_norm = MultivariateNormal(loc_, scale_tril=scale_tril)

        # Offset of the mean from ``loc`` induced by the skew.
        skew_mean = jnp.sqrt(2 / jnp.pi) * skew_delta(
            self.skewers / self._std_devs, cov_batch
        )
        self._mean = self.loc + skew_mean
        # The paper just uses `mean` here but that's definitely not right because
        # it potentially leads to covariance matrices which are not positive semi
        # definite; the offset ``skew_mean`` (which excludes ``loc``) is used instead.
        self._covariance = cov_batch - jnp.einsum(
            "...i,...j->...ij", skew_mean, skew_mean
        )

        event_shape = jnp.shape(self.scale_tril)[-1:]
        super().__init__(
            batch_shape=batch_shape,
            event_shape=event_shape,
            validate_args=validate_args,
        )

    @validate_sample
    def log_prob(self, value: NDArray[float]) -> NDArray[float]:
        """Return the log density at ``value`` (last axis is the event axis)."""
        # Local import: log_ndtr evaluates log(Phi(z)) without first forming
        # Phi(z), which underflows to 0 (giving -inf) for strongly negative z.
        from jax.scipy.special import log_ndtr

        skew_arg = jnp.einsum(
            "...k,...k->...",
            (value - self.loc) / self._std_devs,
            self.skewers,
        )
        return jnp.log(2) + self._mv_norm.log_prob(value) + log_ndtr(skew_arg)

    @staticmethod
    def infer_shapes(
        loc: NDArray[float], scale_tril: NDArray[float], skewers: NDArray[float]
    ):
        """Infer ``(batch_shape, event_shape)`` from the parameter shape tuples."""
        event_shape = scale_tril[-1:]
        batch_shape = lax.broadcast_shapes(loc[:-1], scale_tril[:-2], skewers[:-1])
        return batch_shape, event_shape

    # https://gregorygundersen.com/blog/2020/12/29/multivariate-skew-normal/
    def sample(
        self, key: random.PRNGKey, sample_shape: tuple[int, ...] = ()
    ) -> NDArray[float]:
        """Draw samples by sign-flipping draws from the augmented normal.

        The first coordinate of each augmented draw decides whether the
        remaining coordinates are kept or negated; ``loc`` is added afterwards.
        """
        assert is_prng_key(key)
        x = self._big_mv_norm.sample(key, sample_shape=sample_shape)
        sign_bit, samples = x[..., 0, jnp.newaxis], x[..., 1:]
        return jnp.where(sign_bit <= 0, -1 * samples, samples) + self.loc

    @property
    def mean(self):
        return jnp.broadcast_to(self._mean, self.shape())

    @property
    def covariance_matrix(self):
        # Broadcast to the full batch shape, mirroring what ``mean`` does.
        return jnp.broadcast_to(
            self._covariance, self.batch_shape + self.event_shape * 2
        )


class Pareto(TransformedDistribution):
arg_constraints = {"scale": constraints.positive, "alpha": constraints.positive}
reparametrized_params = ["scale", "alpha"]
Expand Down
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@
"test": [
"black[jupyter]>=21.8b0",
"flake8",
"hypothesis[numpy]",
"isort>=5.0",
"pytest>=4.1",
"pyro-api>=0.1.1",
Expand Down
182 changes: 182 additions & 0 deletions test/test_distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,15 @@
import inspect
import math
import os
from typing import cast

from hypothesis import given, note, settings
import hypothesis.extra.numpy as hnp
import hypothesis.strategies as st
from hypothesis.strategies import DrawFn, SearchStrategy
import numpy as np
from numpy.testing import assert_allclose, assert_array_equal
from numpy.typing import NDArray
import pytest
import scipy
import scipy.stats as osp
Expand Down Expand Up @@ -534,6 +540,12 @@ def get_sp_dist(jax_dist):
T(dist.Normal, 0.0, 1.0),
T(dist.Normal, 1.0, np.array([1.0, 2.0])),
T(dist.Normal, np.array([0.0, 1.0]), np.array([[1.0], [2.0]])),
T(
dist.SkewMultivariateNormal,
np.array([2.0, 0.0]),
np.array([[1.0, 0.0], [0.5, 1.0]]),
np.array([0.0, 0.0]),
),
T(dist.Pareto, 1.0, 2.0),
T(dist.Pareto, np.array([1.0, 0.5]), np.array([0.3, 2.0])),
T(dist.Pareto, np.array([[1.0], [3.0]]), np.array([1.0, 0.5])),
Expand Down Expand Up @@ -1502,6 +1514,10 @@ def test_mean_var(jax_dist, sp_dist, params):
dist.TwoSidedTruncatedDistribution,
):
pytest.skip("Truncated distributions do not has mean/var implemented")
if jax_dist is dist.SkewMultivariateNormal:
pytest.skip(
"We check SkewMultivariateNormal against MultivariateNormal elsewhere"
)
if jax_dist is dist.ProjectedNormal:
pytest.skip("Mean is defined in submanifold")

Expand Down Expand Up @@ -2570,3 +2586,169 @@ def sample_binomial_withp0(key):
return dist.Binomial(total_count=n, probs=0).sample(key)

jax.vmap(sample_binomial_withp0)(random.split(random.PRNGKey(0), 1))


def locs(size: int) -> SearchStrategy[NDArray[float]]:
    """Hypothesis strategy for location arrays of shape ``size``.

    Elements are finite floats drawn from the closed interval [-1, 1].
    """
    bounded_floats = st.floats(
        min_value=-1, max_value=1, allow_nan=False, allow_infinity=False
    )
    strategy = hnp.arrays(
        dtype=np.dtype("float"),
        shape=size,
        elements=bounded_floats,
    )
    # The cast only narrows the static type; it is a no-op at runtime.
    return cast(SearchStrategy[NDArray[float]], strategy)


def skews(size: int) -> SearchStrategy[NDArray[float]]:
    """Hypothesis strategy for skewness arrays of shape ``size``.

    Elements are finite floats drawn from the closed interval [-4, 4].
    """
    bounded_floats = st.floats(
        min_value=-4, max_value=4, allow_nan=False, allow_infinity=False
    )
    strategy = hnp.arrays(
        dtype=np.dtype("float"),
        shape=size,
        elements=bounded_floats,
    )
    # The cast only narrows the static type; it is a no-op at runtime.
    return cast(SearchStrategy[NDArray[float]], strategy)


def variances(size: int) -> SearchStrategy[NDArray[float]]:
    """Hypothesis strategy for variance arrays of shape ``size``.

    Elements are finite floats in the half-open interval (0.1, 3].
    """
    # Variances that are too small make it impossible to test t against normal
    positive_floats = st.floats(
        min_value=0.1,
        max_value=3,
        allow_nan=False,
        allow_infinity=False,
        exclude_min=True,
    )
    strategy = hnp.arrays(
        dtype=np.dtype("float"),
        shape=size,
        elements=positive_floats,
    )
    # The cast only narrows the static type; it is a no-op at runtime.
    return cast(SearchStrategy[NDArray[float]], strategy)


def corr_vech_to_matrix(vech: NDArray[float]):
    """Expand a strictly-lower-triangular vector into a unit-diagonal matrix.

    :param vech: flat array holding the entries below the diagonal, in the
        order produced by ``np.tril_indices(n, k=-1)``.
    :return: an ``n x n`` float array with ``vech`` below the diagonal, ones
        on the diagonal and zeros above it.
    """
    # vech has n * (n - 1) / 2 entries, so 8 * size + 1 == (2n - 1) ** 2.
    dim = (1 + math.isqrt(1 + 8 * vech.size)) // 2
    # Start from the identity: the diagonal is already 1, the rest is 0.
    matrix = np.eye(dim)
    below_rows, below_cols = np.tril_indices(dim, k=-1)
    matrix[below_rows, below_cols] = vech
    return matrix


def correlation_chols(size: int) -> SearchStrategy[NDArray[float]]:
    """Hypothesis strategy for ``size x size`` unit-diagonal lower-triangular matrices.

    The ``size * (size - 1) // 2`` entries below the diagonal are drawn from
    [-0.99, 0.99] and laid out by ``corr_vech_to_matrix``.
    """
    # Floating point issues mean we sometimes get arrays which aren't positive semi-definite
    # if we allow correlations of exactly 1 and -1
    off_diagonal = st.floats(
        min_value=-0.99, max_value=0.99, allow_nan=False, allow_infinity=False
    )
    n_below_diagonal = size * (size - 1) // 2
    vech_strategy = hnp.arrays(
        dtype=np.dtype("float"),
        shape=n_below_diagonal,
        elements=off_diagonal,
    )
    return vech_strategy.map(
        corr_vech_to_matrix  # type: ignore
    )


@st.composite
def loc_and_scale(draw: DrawFn):
    """Draw a ``(loc, scale)`` pair for a two-dimensional distribution.

    ``scale`` is the drawn unit-diagonal lower-triangular matrix with row ``i``
    multiplied by ``sqrt(var[i])``.
    """
    # Would need to generalize meshgrid to relax this restriction
    size = 2
    # Draw order (matrix, variances, location) is kept as-is on purpose.
    unit_lower = draw(correlation_chols(size))
    var = draw(variances(size))
    loc = draw(locs(size))
    row_scales = jnp.sqrt(var)[..., None]
    return (loc, row_scales * unit_lower)


@st.composite
def loc_and_scale_and_skewers(draw: DrawFn):
    """Draw a ``(loc, scale, skewers)`` triple for a two-dimensional distribution.

    ``scale`` is the drawn unit-diagonal lower-triangular matrix with row ``i``
    multiplied by ``sqrt(var[i])``.
    """
    # Would need to generalize meshgrid to relax this restriction
    size = 2
    # Draw order (matrix, variances, location, skewers) is kept as-is on purpose.
    unit_lower = draw(correlation_chols(size))
    var = draw(variances(size))
    loc = draw(locs(size))
    scale = jnp.sqrt(var)[..., None] * unit_lower
    skewers = draw(skews(size))
    return (loc, scale, skewers)


# Evaluation grids shared by the skew-normal tests below:
# a 100 x 100 grid over [-3, 3]^2 and a coarser 50 x 50 grid over [-6, 6]^2.
_narrow_axis = np.linspace(-3, 3, 100)
X, Y = np.meshgrid(_narrow_axis, _narrow_axis)
# Stack the coordinates on a trailing axis so grid[i, j] is one (x, y) point.
grid = np.stack((X, Y), axis=-1)
_wide_axis = np.linspace(-6, 6, 50)
X_wide, Y_wide = np.meshgrid(_wide_axis, _wide_axis)
grid_wide = np.stack((X_wide, Y_wide), axis=-1)


@settings(deadline=None)
@given(loc_and_scale())
def test_skew_normal_log_prob_generalizes_normal(
    loc_scale_tril: tuple[NDArray[float], NDArray[float]]
):
    """With zero skewers, log_prob must match MultivariateNormal on the grid."""
    loc, scale_tril = loc_scale_tril
    zero_skewers = np.zeros(scale_tril.shape[-1])
    smvn = dist.SkewMultivariateNormal(
        loc=loc, scale_tril=scale_tril, skewers=zero_skewers
    )
    mvn = dist.MultivariateNormal(loc=loc, scale_tril=scale_tril)
    expected = mvn.log_prob(grid)
    actual = smvn.log_prob(grid)
    assert_allclose(expected, actual, atol=1e-6)


@settings(deadline=None)
@given(loc_and_scale())
def test_skew_normal_moments_generalize_normal(
    loc_scale_tril: tuple[NDArray[float], NDArray[float]]
):
    """With zero skewers, mean and covariance must match MultivariateNormal."""
    loc, scale_tril = loc_scale_tril
    zero_skewers = np.zeros(scale_tril.shape[-1])
    smvn = dist.SkewMultivariateNormal(
        loc=loc, scale_tril=scale_tril, skewers=zero_skewers
    )
    mvn = dist.MultivariateNormal(loc=loc, scale_tril=scale_tril)
    # Check the first moment, then the second, each against the plain normal.
    for moment in ("mean", "covariance_matrix"):
        assert_allclose(getattr(mvn, moment), getattr(smvn, moment), atol=1e-30)


@settings(deadline=None, max_examples=10)
@given(loc_and_scale_and_skewers())
def test_skew_normal_log_prob_vs_samples(
    loc_scale_tril_skewers: tuple[NDArray[float], NDArray[float], NDArray[float]]
):
    """The density from log_prob should agree with a KDE of drawn samples.

    Both surfaces are min-max rescaled to [0, 1] on the wide grid before they
    are compared.
    """

    def _min_max_rescale(values):
        lowest = values.min()
        return (values - lowest) / (values.max() - lowest)

    loc, scale_tril, skewers = loc_scale_tril_skewers
    note(f"Covariance: {scale_tril @ scale_tril.T}")
    smvn = dist.SkewMultivariateNormal(loc=loc, scale_tril=scale_tril, skewers=skewers)
    samples = smvn.sample(random.PRNGKey(0), sample_shape=(50_000,))
    # gaussian_kde needs a different format
    grid_ = np.vstack([X_wide.ravel(), Y_wide.ravel()])
    density = jnp.exp(smvn.log_prob(grid_.T))
    kde = osp.gaussian_kde(samples.T, bw_method="scott")
    kde_density = kde(grid_)

    assert_allclose(
        _min_max_rescale(density), _min_max_rescale(kde_density), atol=0.07
    )


def split_cov(cov: NDArray[float]) -> tuple[NDArray[float], NDArray[float]]:
    """Split a covariance matrix into standard deviations and correlations.

    :param cov: square covariance matrix.
    :return: ``(std_devs, corr_vech)`` where ``std_devs`` is the square root of
        the diagonal and ``corr_vech`` holds the strictly-lower-triangular
        entries of the correlation matrix, in ``np.tril_indices`` order.
    """
    std_devs = np.sqrt(np.diag(cov))
    # D^-1 @ cov @ D^-1 rescales the covariance into a correlation matrix.
    inv_scale = np.diag(1 / std_devs)
    corr = inv_scale @ cov @ inv_scale
    below_diagonal = np.tril_indices(len(std_devs), k=-1)
    corr_vech = corr[below_diagonal]
    return (std_devs, corr_vech)


@settings(deadline=None)
@given(loc_and_scale_and_skewers())
def test_skew_normal_moments_vs_samples(
    loc_scale_tril_skewers: tuple[NDArray[float], NDArray[float], NDArray[float]]
):
    """Analytic mean and covariance should agree with large-sample estimates."""
    loc, scale_tril, skewers = loc_scale_tril_skewers
    note(f"Covariance: {scale_tril @ scale_tril.T}")
    smvn = dist.SkewMultivariateNormal(loc=loc, scale_tril=scale_tril, skewers=skewers)
    samples = smvn.sample(random.PRNGKey(0), sample_shape=(500_000,))

    # First moment.
    sample_mean = np.mean(samples, axis=0)
    assert_allclose(sample_mean, smvn.mean, rtol=0.005, atol=0.001)

    # Second moment, compared as standard deviations and correlations separately.
    sample_cov = np.cov(samples.T)
    std_devs_sample, corr_sample = split_cov(sample_cov)
    std_devs_dist, corr_dist = split_cov(smvn.covariance_matrix)
    assert_allclose(std_devs_sample, std_devs_dist, rtol=0.003)
    note(f"Sample corr: {corr_sample}, Distribution corr: {corr_dist}")
    assert_allclose(corr_sample, corr_dist, atol=0.006)

0 comments on commit 1809d35

Please sign in to comment.