SkewMultivariateNormal and SkewMultivariateStudentT #1452

Open
colehaus opened this issue Jul 17, 2022 · 4 comments
Labels: enhancement (New feature or request)

@colehaus (Contributor)

I have implemented versions of both of these:

from __future__ import annotations

from typing import Union, cast

import jax
from jax import lax
from jax import random
import jax.numpy as jnp
from jax.random import PRNGKey
from jax.scipy.linalg import cho_solve
from jax.scipy.special import gammaln
from numpy.typing import NDArray
from numpyro.distributions import (
    Chi2,
    Distribution,
    MultivariateNormal,
    MultivariateStudentT,
    Normal,
    constraints,
)
from numpyro.distributions.util import is_prng_key, promote_shapes, validate_sample

def delta(skewers_: NDArray[float], cov_: NDArray[float]):
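    # Computes delta = (Sigma @ skewers) / sqrt(1 + skewers^T @ Sigma @ skewers),
    # batched over any leading dimensions.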
    return (jnp.einsum("...ij,...j->...i", cov_, skewers_)) / jnp.sqrt(
        1 + jnp.einsum("...j,...jk,...k->...", skewers_, cov_, skewers_)[..., jnp.newaxis]
    )

# Efficient computation of the distribution functions of Student's t, chi-squared and F to moderate accuracy
# https://sci-hub.se/10.1080/00949658208810542
# We have to use this approximation because `betainc` doesn't have gradients defined,
# which means we can't use the official `StudentT.cdf`.
@jax.jit
def t_cdf_approx(df: Union[NDArray[float], float], t: Union[NDArray[float], float]):
    a = df - 1 / 2
    b = 48 * a**2
    # Add epsilon to avoid undefined gradient at 0
    z = jnp.sqrt(a * jnp.log(1 + t**2 / df) + 1e-24)
    u = (
        z
        + (z**3 + 3 * z) / b
        - (4 * z**7 + 33 * z**5 + 240 * z**3 + 855 * z) / (10 * b * (b + 0.8 * z**4 + 100))
    )
    return Normal(loc=0, scale=1).cdf(u * jnp.sign(t))
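# Illustrative sanity check (assumes scipy is available; it is not needed by the code itself):
#   >>> import scipy.stats
#   >>> float(t_cdf_approx(10.0, 1.5))  # ~0.918, close to scipy.stats.t.cdf(1.5, df=10)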

# Regularized Multivariate Regression Models with Skew-t Error Distributions
# https://epublications.marquette.edu/cgi/viewcontent.cgi?article=1225&context=mscs_fac
class SkewMultivariateNormal(Distribution):  # type: ignore # pylint: disable=too-many-instance-attributes
    arg_constraints = {
        "loc": constraints.real_vector,
        "scale_tril": constraints.lower_cholesky,
        "skewers": constraints.real_vector,
    }
    support = constraints.real_vector
    reparametrized_params = ["loc", "scale_tril", "skewers"]
    uv_norm = Normal(0.0, 1.0)

    @staticmethod
    def mk_big_mv_norm(loc: NDArray[float], skewers: NDArray[float], scale_tril: NDArray[float]):
        cov = jnp.einsum("...ij,...hj->...ih", scale_tril, scale_tril)
        delta_ = delta(skewers, cov)
        cov_star = jnp.block(
            [
                [jnp.ones(skewers.shape[:-1] + (1, 1)), jnp.expand_dims(delta_, axis=-2)],
                [jnp.expand_dims(delta_, axis=-1), cov],
            ]
        )

        return MultivariateNormal(loc=jnp.zeros(loc.shape[-1] + 1), scale_tril=jnp.linalg.cholesky(cov_star))

    def __init__(
        self,
        scale_tril: NDArray[float],
        skewers: NDArray[float],
        loc: Union[NDArray[float], float] = 0,
        validate_args: None = None,
    ):
        if jnp.ndim(loc) == 0:
            (loc_,) = promote_shapes(loc, shape=(1,))
        else:
            loc_ = cast(NDArray[float], loc)
        batch_shape = lax.broadcast_shapes(
            jnp.shape(loc_)[:-1], jnp.shape(scale_tril)[:-2], jnp.shape(skewers)[:-1]
        )
        (self.loc,) = promote_shapes(loc_, shape=batch_shape + loc_.shape[-1:])
        (self.skewers,) = promote_shapes(skewers, shape=batch_shape + skewers.shape[-1:])
        (self.scale_tril,) = promote_shapes(scale_tril, shape=batch_shape + scale_tril.shape[-2:])
        cov_batch = jnp.einsum("...ij,...hj->...ih", self.scale_tril, self.scale_tril)
        self._std_devs = jnp.sqrt(jnp.diagonal(cov_batch, axis1=-2, axis2=-1))

        # Used for sampling
        self._big_mv_norm = self.mk_big_mv_norm(
            # The blog post just uses unstandardized skewers here but that leads to
            # a discrepancy between sampling and log_prob
            loc=self.loc,
            skewers=skewers / self._std_devs,
            scale_tril=scale_tril,
        )
        # Used for log_prob
        self._mv_norm = MultivariateNormal(loc_, scale_tril=scale_tril)

        skew_mean = jnp.sqrt(2 / jnp.pi) * delta(self.skewers / self._std_devs, cov_batch)
        self._mean = self.loc + skew_mean
        # The paper just uses `mean` here, but that's definitely not right because
        # it potentially leads to covariance matrices which are not positive semi-definite.
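        # With standardized skewers this is Cov = Sigma - (2 / pi) * delta delta^T,
        # since skew_mean = sqrt(2 / pi) * delta.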
        self._covariance = cov_batch - jnp.einsum("...i,...j->...ij", skew_mean, skew_mean)

        event_shape = jnp.shape(self.scale_tril)[-1:]
        super().__init__(
            batch_shape=batch_shape,
            event_shape=event_shape,
            validate_args=validate_args,
        )

    @validate_sample
    def log_prob(self, value: NDArray[float]) -> NDArray[float]:
        return (
            jnp.log(2)
            + self._mv_norm.log_prob(value)
            + jnp.log(
                self.uv_norm.cdf(jnp.einsum("...k,...k->...", (value - self.loc) / self._std_devs, self.skewers))
            )
        )

    @staticmethod
    def infer_shapes(loc: NDArray[float], scale_tril: NDArray[float], skewers: NDArray[float]):
        event_shape = (scale_tril[-1],)
        batch_shape = lax.broadcast_shapes(loc[:-1], scale_tril[:-2], skewers[:-1])
        return batch_shape, event_shape

    # https://gregorygundersen.com/blog/2020/12/29/multivariate-skew-normal/
    def sample(self, key: PRNGKey, sample_shape: tuple[int, ...] = ()) -> NDArray[float]:
        assert is_prng_key(key)
        x = self._big_mv_norm.sample(key, sample_shape=sample_shape)
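        # The first coordinate of the augmented normal acts as a latent sign:
        # flipping the remaining coordinates whenever it is negative is what
        # induces the skew.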
        sign_bit, samples = x[..., 0, jnp.newaxis], x[..., 1:]
        return jnp.where(sign_bit <= 0, -1 * samples, samples) + self.loc

    @property
    def mean(self):
        return jnp.broadcast_to(self._mean, self.shape())

    @property
    def covariance_matrix(self):
        return self._covariance

# https://epublications.marquette.edu/cgi/viewcontent.cgi?article=1225&context=mscs_fac
class SkewMultivariateStudentT(Distribution):  # type: ignore # pylint: disable=too-many-instance-attributes
    arg_constraints = {
        "df": constraints.positive,
        "loc": constraints.real_vector,
        "scale_tril": constraints.lower_cholesky,
        "skewers": constraints.real_vector,
    }
    support = constraints.real_vector
    reparametrized_params = ["df", "loc", "scale_tril", "skewers"]

    def __init__(  # pylint: disable=too-many-arguments
        self,
        df: float,
        scale_tril: NDArray[float],
        skewers: NDArray[float],
        loc: Union[NDArray[float], float] = 0,
        validate_args: None = None,
    ):
        if jnp.ndim(loc) == 0:
            (loc_,) = promote_shapes(loc, shape=(1,))
        else:
            loc_ = cast(NDArray[float], loc)
        batch_shape = lax.broadcast_shapes(
            jnp.shape(df), jnp.shape(loc_)[:-1], jnp.shape(scale_tril)[:-2], jnp.shape(skewers)[:-1]
        )
        (self.df,) = promote_shapes(df, shape=batch_shape)
        (self.loc,) = promote_shapes(loc_, shape=batch_shape + loc_.shape[-1:])
        (self.skewers,) = promote_shapes(skewers, shape=batch_shape + skewers.shape[-1:])
        (self.scale_tril,) = promote_shapes(scale_tril, shape=batch_shape + scale_tril.shape[-2:])

        self._width = scale_tril.shape[-1]

        # For log_prob
        self._mv_t = MultivariateStudentT(df=df, scale_tril=scale_tril, loc=loc)
        eye = jnp.broadcast_to(jnp.eye(self._width), shape=batch_shape + scale_tril.shape[-2:])
        prec_scale_tril = jnp.linalg.cholesky(cho_solve((self.scale_tril, True), eye))
        self.prec = jnp.einsum("...ij,...hj->...ih", prec_scale_tril, prec_scale_tril)
        cov_batch = jnp.einsum("...ij,...hj->...ih", self.scale_tril, self.scale_tril)
        self._std_devs = jnp.sqrt(jnp.diagonal(cov_batch, axis1=-2, axis2=-1))

        # For sample
        self._mv_skew_norm = SkewMultivariateNormal(
            scale_tril=scale_tril, loc=jnp.zeros(self._width), skewers=skewers
        )
        self._chi2 = Chi2(self.df)

        # Mean
        b = jnp.sqrt(self.df / jnp.pi) * jnp.exp(gammaln((self.df - 1) / 2) - gammaln(self.df / 2))
        skew_mean = b[..., jnp.newaxis] * delta(self.skewers / self._std_devs, cov_batch)
        self._mean = self.loc + skew_mean
        # The paper says we should multiply by the std devs, but that produces results that
        # disagree with `sample` and with `SkewMultivariateNormal`.
        # It also says we should use `_mean` instead of `skew_mean`, but that allows for
        # covariance matrices which are not positive semi-definite.
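        # i.e. Cov = df / (df - 2) * Sigma - skew_mean skew_mean^T, with skew_mean = b * delta.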
        self._covariance = jnp.array((self.df / (self.df - 2)))[
            ..., jnp.newaxis, jnp.newaxis
        ] * cov_batch - jnp.einsum("...i,...j->...ij", skew_mean, skew_mean)

        event_shape = jnp.shape(self.scale_tril)[-1:]
        super().__init__(
            batch_shape=batch_shape,
            event_shape=event_shape,
            validate_args=validate_args,
        )

    @validate_sample
    def log_prob(self, value: NDArray[float]) -> NDArray[float]:
        distance = value - self.loc
        Qy = jnp.einsum("...j,...jk,...k->...", distance, self.prec, distance)
        # Have to use approximation because `betainc` doesn't have grads defined.
        # Which means we can't use the official `StudentT.cdf`
        skew = t_cdf_approx(
            self.df + self._width,
            jnp.einsum(
                "...k,...k->...",
                self.skewers,
                jnp.einsum(
                    "...i,...->...i", distance / self._std_devs, jnp.sqrt((self.df + self._width) / (Qy + self.df))
                ),
            ),
        )
        return jnp.log(2) + self._mv_t.log_prob(value) + jnp.log(skew)

    @staticmethod
    def infer_shapes(df: float, loc: NDArray[float], scale_tril: NDArray[float], skewers: NDArray[float]):
        event_shape = (scale_tril[-1],)
        batch_shape = lax.broadcast_shapes(df, loc[:-1], scale_tril[:-2], skewers[:-1])
        return batch_shape, event_shape

    def sample(self, key: PRNGKey, sample_shape: tuple[int, ...] = ()) -> NDArray[float]:
        assert is_prng_key(key)
        key_normal, key_chi2 = random.split(key)
        normal = self._mv_skew_norm.sample(key_normal, sample_shape=sample_shape)
        chi = self._chi2.sample(key_chi2, sample_shape)
        return self.loc + jnp.einsum("...i,...->...i", normal, jnp.sqrt(self.df / chi))

    @property
    def mean(self):
        return jnp.broadcast_to(self._mean, self.shape())

    @property
    def covariance_matrix(self):
        return self._covariance

(I also have some code testing them.)
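
A minimal usage sketch (parameter values are illustrative):

import jax.numpy as jnp
from jax import random

skew_norm = SkewMultivariateNormal(
    loc=jnp.zeros(2),
    scale_tril=jnp.linalg.cholesky(jnp.array([[1.0, 0.3], [0.3, 1.0]])),
    skewers=jnp.array([4.0, -1.0]),
)
samples = skew_norm.sample(random.PRNGKey(0), sample_shape=(1000,))
log_probs = skew_norm.log_prob(samples)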

  1. Is there interest in upstreaming these?
  2. Are there obvious simplifications?
  3. SkewMultivariateStudentT is notably slower than MultivariateStudentT in some circumstances. Are there any obvious performance improvements available?
@martinjankowiak (Collaborator)

@colehaus yes, i'm sure a PR would be welcome.

Would it help to use tfp.math.betainc?

import tensorflow_probability.substrates.jax as tfp

@colehaus (Contributor, Author) commented Jul 17, 2022

Unless I'm misunderstanding you: there's a comment in the source describing the problem, which is that betainc doesn't have all grads defined: tensorflow/probability#655 (comment).

@fehiepsi (Member) commented Jul 17, 2022

FYI, since the last release, tfp.math.betainc has grads w.r.t. all parameters. I would suggest having 3 PRs for:

  • StudentT.cdf (which locally imports tensorflow_probability.substrates.jax as tfp; see the sketch below)
  • SkewMVN
  • SkewMVT
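
A sketch of what that CDF could look like (minimal and untested; `student_t_cdf` is a hypothetical name, using the standard regularized-incomplete-beta identity):

import jax.numpy as jnp
import tensorflow_probability.substrates.jax as tfp

def student_t_cdf(df, t):
    # For t >= 0: P(T <= t) = 1 - 0.5 * I_x(df / 2, 1 / 2), where
    # x = df / (df + t**2) and I_x is the regularized incomplete beta
    # function; the t < 0 case follows by symmetry.
    x = df / (df + t**2)
    tail = 0.5 * tfp.math.betainc(0.5 * df, 0.5, x)
    return jnp.where(t >= 0, 1.0 - tail, tail)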

fehiepsi added the enhancement label Jul 17, 2022
@colehaus (Contributor, Author)

Ah, that's good news! And that sounds like a reasonable plan. I probably won't be able to think about doing it for a few weeks though.
