Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add draft of SkewMultivariateNormal #1471

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,14 @@ numpyro.egg-info
__pycache__/
.ipynb_checkpoints/
build
.venv/

# built / compiled
*.pyc
*.pyo
/build
/dist
/.hypothesis

# IDE
.idea
Expand Down
2 changes: 2 additions & 0 deletions numpyro/distributions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
Pareto,
RelaxedBernoulli,
RelaxedBernoulliLogits,
SkewMultivariateNormal,
SoftLaplace,
StudentT,
Uniform,
Expand Down Expand Up @@ -158,6 +159,7 @@
"MultivariateStudentT",
"LowRankMultivariateNormal",
"Normal",
"SkewMultivariateNormal",
"NegativeBinomialProbs",
"NegativeBinomialLogits",
"NegativeBinomial2",
Expand Down
136 changes: 136 additions & 0 deletions numpyro/distributions/continuous.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,10 @@
# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
# POSSIBILITY OF SUCH DAMAGE.

from typing import Union, cast

import numpy as np
from numpy.typing import NDArray

from jax import lax
from jax.experimental.sparse import BCOO
Expand Down Expand Up @@ -1731,6 +1734,139 @@ def variance(self):
return jnp.broadcast_to(self.scale**2, self.batch_shape)


def skew_delta(skewers_: NDArray[float], cov_: NDArray[float]):
return (jnp.einsum("...ij,...j->...i", cov_, skewers_)) / jnp.sqrt(
1
+ jnp.einsum("...j,...jk,...k->...", skewers_, cov_, skewers_)[..., jnp.newaxis]
)


# Regularized Multivariate Regression Models with Skew-t Error Distributions
# https://epublications.marquette.edu/cgi/viewcontent.cgi?article=1225&context=mscs_fac
class SkewMultivariateNormal(Distribution):
arg_constraints = {
"loc": constraints.real_vector,
"scale_tril": constraints.lower_cholesky,
"skewers": constraints.real_vector,
}
support = constraints.real_vector
reparametrized_params = ["loc", "scale_tril", "skewers"]
uv_norm = Normal(0.0, 1.0)

@staticmethod
def mk_big_mv_norm(
loc: NDArray[float], skewers: NDArray[float], scale_tril: NDArray[float]
):
cov = jnp.einsum("...ij,...hj->...ih", scale_tril, scale_tril)
delta_ = skew_delta(skewers, cov)
cov_star = jnp.block(
[
[
jnp.ones(skewers.shape[:-1] + (1, 1)),
jnp.expand_dims(delta_, axis=-2),
],
[jnp.expand_dims(delta_, axis=-1), cov],
]
)

return MultivariateNormal(
loc=jnp.zeros(loc.shape[-1] + 1), scale_tril=jnp.linalg.cholesky(cov_star)
)

def __init__(
self,
loc: Union[NDArray[float], float],
scale_tril: NDArray[float],
skewers: NDArray[float],
validate_args: None = None,
):
if jnp.ndim(loc) == 0:
(loc_,) = promote_shapes(loc, shape=(1,))
else:
loc_ = cast(NDArray[float], loc)
batch_shape = lax.broadcast_shapes(
jnp.shape(loc_)[:-1], jnp.shape(scale_tril)[:-2], jnp.shape(skewers)[:-1]
)
(self.loc,) = promote_shapes(loc_, shape=batch_shape + loc_.shape[-1:])
(self.skewers,) = promote_shapes(
skewers, shape=batch_shape + skewers.shape[-1:]
)
(self.scale_tril,) = promote_shapes(
scale_tril, shape=batch_shape + scale_tril.shape[-2:]
)
cov_batch = jnp.einsum("...ij,...hj->...ih", self.scale_tril, self.scale_tril)
self._std_devs = jnp.sqrt(jnp.diagonal(cov_batch, axis1=-2, axis2=-1))

# Used for sampling
self._big_mv_norm = self.mk_big_mv_norm(
# The blog post just uses unstandardized skewers here but that leads to
# a discrepancy between sampling and log_prob
loc=self.loc,
skewers=skewers / self._std_devs,
scale_tril=scale_tril,
)
# Used for log_prob
self._mv_norm = MultivariateNormal(loc_, scale_tril=scale_tril)

skew_mean = jnp.sqrt(2 / jnp.pi) * skew_delta(
self.skewers / self._std_devs, cov_batch
)
self._mean = self.loc + skew_mean
# The paper just uses `mean` here but that's definitely not right because
# it potentially leads to covariance matrices which are not positive semi definite
self._covariance = cov_batch - jnp.einsum(
"...i,...j->...ij", skew_mean, skew_mean
)

event_shape = jnp.shape(self.scale_tril)[-1:]
super().__init__(
batch_shape=batch_shape,
event_shape=event_shape,
validate_args=validate_args,
)

@validate_sample
def log_prob(self, value: NDArray[float]) -> NDArray[float]:
return (
jnp.log(2)
+ self._mv_norm.log_prob(value)
+ jnp.log(
self.uv_norm.cdf(
jnp.einsum(
"...k,...k->...",
(value - self.loc) / self._std_devs,
self.skewers,
)
)
)
)

@staticmethod
def infer_shapes(
loc: NDArray[float], scale_tril: NDArray[float], skewers: NDArray[float]
):
event_shape = (scale_tril[-1],)
batch_shape = lax.broadcast_shapes(loc[:-1], scale_tril[:-2], skewers[:-1])
return batch_shape, event_shape

# https://gregorygundersen.com/blog/2020/12/29/multivariate-skew-normal/
def sample(
self, key: random.PRNGKey, sample_shape: tuple[int, ...] = ()
) -> NDArray[float]:
assert is_prng_key(key)
x = self._big_mv_norm.sample(key, sample_shape=sample_shape)
sign_bit, samples = x[..., 0, jnp.newaxis], x[..., 1:]
return jnp.where(sign_bit <= 0, -1 * samples, samples) + self.loc

@property
def mean(self):
return jnp.broadcast_to(self._mean, self.shape())

@property
def covariance_matrix(self):
return self._covariance


class Pareto(TransformedDistribution):
arg_constraints = {"scale": constraints.positive, "alpha": constraints.positive}
reparametrized_params = ["scale", "alpha"]
Expand Down
2 changes: 1 addition & 1 deletion scripts/update_headers.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import sys

root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
blacklist = ["/build/", "/dist/", "/pyro_api.egg"]
blacklist = ["/build/", "/dist/", "/pyro_api.egg", "/.venv/"]
file_types = [("*.py", "# {}"), ("*.cpp", "// {}")]

parser = argparse.ArgumentParser()
Expand Down
3 changes: 2 additions & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[flake8]
max-line-length = 120
exclude = docs/src, build, dist, .ipynb_checkpoints
exclude = docs/src, build, dist, .ipynb_checkpoints, .venv
ignore = W503,E203
per-file-ignores =
numpyro/contrib/tfp/distributions.py:F811
Expand All @@ -17,6 +17,7 @@ force_sort_within_sections = true
combine_as_imports = true
multi_line_output = 3
skip=docs
extend_skip = .venv

[tool:pytest]
filterwarnings = error
Expand Down
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@
"test": [
"black[jupyter]>=21.8b0",
"flake8",
"hypothesis[numpy]",
"isort>=5.0",
"pytest>=4.1",
"pyro-api>=0.1.1",
Expand Down
Loading