Skip to content

Commit

Permalink
Avoid spurious deprecation warning signature/extended_signature in Cu…
Browse files Browse the repository at this point in the history
…stomDist

Also allow multivariate CustomDist to be created when signature suffices to infer core shape.
  • Loading branch information
ricardoV94 committed Jun 26, 2024
1 parent 29eef08 commit d4e5db1
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 53 deletions.
31 changes: 12 additions & 19 deletions pymc/distributions/custom.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
from pytensor.tensor.random.op import RandomVariable
from pytensor.tensor.random.type import RandomGeneratorType, RandomType
from pytensor.tensor.random.utils import normalize_size_param
from pytensor.tensor.utils import _parse_gufunc_signature, safe_signature
from pytensor.tensor.utils import safe_signature

from pymc.distributions.distribution import (
Distribution,
Expand Down Expand Up @@ -108,19 +108,9 @@ def dist(
class_name: str = "CustomDist",
**kwargs,
):
if ndim_supp is None or ndims_params is None:
if signature is None:
ndim_supp = 0
ndims_params = [0] * len(dist_params)
else:
inputs, outputs = _parse_gufunc_signature(signature)
ndim_supp = max(len(out) for out in outputs)
ndims_params = [len(inp) for inp in inputs]

if ndim_supp > 0:
raise NotImplementedError(
"CustomDist with ndim_supp > 0 and without a `dist` function are not supported."
)
if ndim_supp is None and signature is None:
# Assume a scalar distribution
signature = safe_signature([0] * len(dist_params), [0])

dist_params = [as_tensor_variable(param) for param in dist_params]

Expand Down Expand Up @@ -148,6 +138,7 @@ def dist(
support_point=support_point,
ndim_supp=ndim_supp,
ndims_params=ndims_params,
signature=signature,
dtype=dtype,
class_name=class_name,
**kwargs,
Expand All @@ -161,8 +152,9 @@ def rv_op(
logcdf: Callable | None,
random: Callable | None,
support_point: Callable | None,
ndim_supp: int,
ndims_params: Sequence[int],
signature: str | None,
ndim_supp: int | None,
ndims_params: Sequence[int] | None,
dtype: str,
class_name: str,
**kwargs,
Expand All @@ -175,6 +167,7 @@ def rv_op(
inplace=False,
ndim_supp=ndim_supp,
ndims_params=ndims_params,
signature=signature,
dtype=dtype,
_print_name=(class_name, f"\\operatorname{{{class_name}}}"),
# Specific to CustomDist
Expand Down Expand Up @@ -344,7 +337,7 @@ def change_custom_dist_size(op, rv, new_size, expand):
new_rv_op = rv_type(
inputs=[*dummy_params, *rngs],
outputs=[dummy_rv, *rngs_updates],
signature=signature,
extended_signature=extended_signature,
)
new_rv = new_rv_op(new_size, *dist_params, *rngs)

Expand All @@ -357,13 +350,13 @@ def change_custom_dist_size(op, rv, new_size, expand):

inputs = [*dummy_params, *rngs]
outputs = [dummy_rv, *rngs_updates]
signature = cls._infer_final_signature(
extended_signature = cls._infer_final_signature(
signature, n_inputs=len(inputs), n_outputs=len(outputs), n_rngs=len(rngs)
)
rv_op = rv_type(
inputs=inputs,
outputs=outputs,
signature=signature,
extended_signature=extended_signature,
)
return rv_op(size, *dist_params, *rngs)

Expand Down
57 changes: 23 additions & 34 deletions tests/distributions/test_custom.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,9 @@
from pymc.step_methods import Metropolis
from pymc.testing import assert_support_point_is_expected

# Raise for any warnings in this file
pytestmark = pytest.mark.filterwarnings("error")


class TestCustomDist:
@pytest.mark.parametrize("size", [(), (3,), (3, 2)], ids=str)
Expand Down Expand Up @@ -105,24 +108,24 @@ def test_custom_dist_without_random(self):
with pytest.raises(NotImplementedError):
sample_posterior_predictive(idata, model=model)

@pytest.mark.xfail(
NotImplementedError,
reason="Support shape of multivariate CustomDist cannot be inferred. See https://github.com/pymc-devs/pytensor/pull/388",
)
@pytest.mark.parametrize("size", [(), (3,), (3, 2)], ids=str)
def test_custom_dist_with_random_multivariate(self, size):
def random(mu, rng, size):
return rng.multivariate_normal(
mean=mu.ravel(),
cov=np.eye(mu.shape[-1]),
size=size,
)

supp_shape = 5
with Model() as model:
mu = Normal("mu", 0, 1, size=supp_shape)
obs = CustomDist(
"custom_dist",
mu,
random=lambda mu, rng=None, size=None: rng.multivariate_normal(
mean=mu, cov=np.eye(len(mu)), size=size
),
random=random,
observed=np.random.randn(100, *size, supp_shape),
ndims_params=[1],
ndim_supp=1,
signature="(n)->(n)",
)

assert isinstance(obs.owner.op, CustomDistRV)
Expand Down Expand Up @@ -156,20 +159,16 @@ def test_custom_dist_old_api_error(self):
):
CustomDist("a", lambda x: x)

@pytest.mark.xfail(
NotImplementedError,
reason="Support shape of multivariate CustomDist cannot be inferred. See https://github.com/pymc-devs/pytensor/pull/388",
)
@pytest.mark.parametrize("size", [None, (), (2,)], ids=str)
def test_custom_dist_multivariate_logp(self, size):
supp_shape = 5
with Model() as model:

def logp(value, mu):
return MvNormal.logp(value, mu, pt.eye(mu.shape[0]))
return MvNormal.logp(value, mu, pt.eye(mu.shape[-1]))

mu = Normal("mu", size=supp_shape)
a = CustomDist("a", mu, logp=logp, ndims_params=[1], ndim_supp=1, size=size)
a = CustomDist("a", mu, logp=logp, signature="(n)->(n)", size=size)

assert isinstance(a.owner.op, CustomDistRV)
mu_test_value = npr.normal(loc=0, scale=1, size=supp_shape).astype(pytensor.config.floatX)
Expand Down Expand Up @@ -219,10 +218,6 @@ def density_support_point(rv, size, mu):
assert evaled_support_point.shape == to_tuple(size)
assert np.all(evaled_support_point == mu_val)

@pytest.mark.xfail(
NotImplementedError,
reason="Support shape of multivariate CustomDist cannot be inferred. See https://github.com/pymc-devs/pytensor/pull/388",
)
@pytest.mark.parametrize("size", [(), (2,), (3, 2)], ids=str)
def test_custom_dist_custom_support_point_multivariate(self, size):
def density_support_point(rv, size, mu):
Expand All @@ -235,19 +230,14 @@ def density_support_point(rv, size, mu):
"a",
mu,
support_point=density_support_point,
ndims_params=[1],
ndim_supp=1,
signature="(n)->(n)",
size=size,
)
assert isinstance(a.owner.op, CustomDistRV)
evaled_support_point = support_point(a).eval({mu: mu_val})
assert evaled_support_point.shape == (*to_tuple(size), 5)
assert np.all(evaled_support_point == mu_val)

@pytest.mark.xfail(
NotImplementedError,
reason="Support shape of multivariate CustomDist cannot be inferred. See https://github.com/pymc-devs/pytensor/pull/388",
)
@pytest.mark.parametrize(
"with_random, size",
[
Expand All @@ -267,21 +257,14 @@ def _random(mu, rng=None, size=None):
else:
random = None

mu_val = np.random.normal(loc=2, scale=1, size=5).astype(pytensor.config.floatX)
with Model():
mu = Normal("mu", size=5)
a = CustomDist("a", mu, random=random, ndims_params=[1], ndim_supp=1, size=size)
a = CustomDist("a", mu, random=random, signature="(n)->(n)", size=size)
assert isinstance(a.owner.op, CustomDistRV)
if with_random:
evaled_support_point = support_point(a).eval({mu: mu_val})
evaled_support_point = support_point(a).eval()
assert evaled_support_point.shape == (*to_tuple(size), 5)
assert np.all(evaled_support_point == 0)
else:
with pytest.raises(
TypeError,
match="Cannot safely infer the size of a multivariate random variable's support_point.",
):
evaled_support_point = support_point(a).eval({mu: mu_val})

def test_dist(self):
mu = 1
Expand All @@ -300,6 +283,12 @@ def test_dist(self):
x_logp = logp(x, test_value)
assert np.allclose(x_logp.eval(), st.norm(1).logpdf(test_value))

def test_multivariate_insufficient_signature(self):
with pytest.raises(
NotImplementedError, match="signature is not sufficient to infer the support shape"
):
CustomDist.dist(signature="(n)->(m)")


class TestCustomSymbolicDist:
def test_basic(self):
Expand Down

0 comments on commit d4e5db1

Please sign in to comment.