Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Avoid spurious deprecation warning in CustomDist #7391

Merged
merged 1 commit into from
Jun 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 12 additions & 19 deletions pymc/distributions/custom.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
from pytensor.tensor.random.op import RandomVariable
from pytensor.tensor.random.type import RandomGeneratorType, RandomType
from pytensor.tensor.random.utils import normalize_size_param
from pytensor.tensor.utils import _parse_gufunc_signature, safe_signature
from pytensor.tensor.utils import safe_signature

from pymc.distributions.distribution import (
Distribution,
Expand Down Expand Up @@ -108,19 +108,9 @@ def dist(
class_name: str = "CustomDist",
**kwargs,
):
if ndim_supp is None or ndims_params is None:
if signature is None:
ndim_supp = 0
ndims_params = [0] * len(dist_params)
else:
inputs, outputs = _parse_gufunc_signature(signature)
ndim_supp = max(len(out) for out in outputs)
ndims_params = [len(inp) for inp in inputs]

if ndim_supp > 0:
raise NotImplementedError(
"CustomDist with ndim_supp > 0 and without a `dist` function are not supported."
)
if ndim_supp is None and signature is None:
# Assume a scalar distribution
signature = safe_signature([0] * len(dist_params), [0])

dist_params = [as_tensor_variable(param) for param in dist_params]

Expand Down Expand Up @@ -148,6 +138,7 @@ def dist(
support_point=support_point,
ndim_supp=ndim_supp,
ndims_params=ndims_params,
signature=signature,
dtype=dtype,
class_name=class_name,
**kwargs,
Expand All @@ -161,8 +152,9 @@ def rv_op(
logcdf: Callable | None,
random: Callable | None,
support_point: Callable | None,
ndim_supp: int,
ndims_params: Sequence[int],
signature: str | None,
ndim_supp: int | None,
ndims_params: Sequence[int] | None,
dtype: str,
class_name: str,
**kwargs,
Expand All @@ -175,6 +167,7 @@ def rv_op(
inplace=False,
ndim_supp=ndim_supp,
ndims_params=ndims_params,
signature=signature,
dtype=dtype,
_print_name=(class_name, f"\\operatorname{{{class_name}}}"),
# Specific to CustomDist
Expand Down Expand Up @@ -344,7 +337,7 @@ def change_custom_dist_size(op, rv, new_size, expand):
new_rv_op = rv_type(
inputs=[*dummy_params, *rngs],
outputs=[dummy_rv, *rngs_updates],
signature=signature,
extended_signature=extended_signature,
)
new_rv = new_rv_op(new_size, *dist_params, *rngs)

Expand All @@ -357,13 +350,13 @@ def change_custom_dist_size(op, rv, new_size, expand):

inputs = [*dummy_params, *rngs]
outputs = [dummy_rv, *rngs_updates]
signature = cls._infer_final_signature(
extended_signature = cls._infer_final_signature(
signature, n_inputs=len(inputs), n_outputs=len(outputs), n_rngs=len(rngs)
)
rv_op = rv_type(
inputs=inputs,
outputs=outputs,
signature=signature,
extended_signature=extended_signature,
)
return rv_op(size, *dist_params, *rngs)

Expand Down
57 changes: 23 additions & 34 deletions tests/distributions/test_custom.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,9 @@
from pymc.step_methods import Metropolis
from pymc.testing import assert_support_point_is_expected

# Raise for any warnings in this file
pytestmark = pytest.mark.filterwarnings("error")


class TestCustomDist:
@pytest.mark.parametrize("size", [(), (3,), (3, 2)], ids=str)
Expand Down Expand Up @@ -105,24 +108,24 @@ def test_custom_dist_without_random(self):
with pytest.raises(NotImplementedError):
sample_posterior_predictive(idata, model=model)

@pytest.mark.xfail(
NotImplementedError,
reason="Support shape of multivariate CustomDist cannot be inferred. See https://github.com/pymc-devs/pytensor/pull/388",
)
@pytest.mark.parametrize("size", [(), (3,), (3, 2)], ids=str)
def test_custom_dist_with_random_multivariate(self, size):
def random(mu, rng, size):
return rng.multivariate_normal(
mean=mu.ravel(),
cov=np.eye(mu.shape[-1]),
size=size,
)

supp_shape = 5
with Model() as model:
mu = Normal("mu", 0, 1, size=supp_shape)
obs = CustomDist(
"custom_dist",
mu,
random=lambda mu, rng=None, size=None: rng.multivariate_normal(
mean=mu, cov=np.eye(len(mu)), size=size
),
random=random,
observed=np.random.randn(100, *size, supp_shape),
ndims_params=[1],
ndim_supp=1,
signature="(n)->(n)",
)

assert isinstance(obs.owner.op, CustomDistRV)
Expand Down Expand Up @@ -156,20 +159,16 @@ def test_custom_dist_old_api_error(self):
):
CustomDist("a", lambda x: x)

@pytest.mark.xfail(
NotImplementedError,
reason="Support shape of multivariate CustomDist cannot be inferred. See https://github.com/pymc-devs/pytensor/pull/388",
)
@pytest.mark.parametrize("size", [None, (), (2,)], ids=str)
def test_custom_dist_multivariate_logp(self, size):
supp_shape = 5
with Model() as model:

def logp(value, mu):
return MvNormal.logp(value, mu, pt.eye(mu.shape[0]))
return MvNormal.logp(value, mu, pt.eye(mu.shape[-1]))

mu = Normal("mu", size=supp_shape)
a = CustomDist("a", mu, logp=logp, ndims_params=[1], ndim_supp=1, size=size)
a = CustomDist("a", mu, logp=logp, signature="(n)->(n)", size=size)

assert isinstance(a.owner.op, CustomDistRV)
mu_test_value = npr.normal(loc=0, scale=1, size=supp_shape).astype(pytensor.config.floatX)
Expand Down Expand Up @@ -219,10 +218,6 @@ def density_support_point(rv, size, mu):
assert evaled_support_point.shape == to_tuple(size)
assert np.all(evaled_support_point == mu_val)

@pytest.mark.xfail(
NotImplementedError,
reason="Support shape of multivariate CustomDist cannot be inferred. See https://github.com/pymc-devs/pytensor/pull/388",
)
@pytest.mark.parametrize("size", [(), (2,), (3, 2)], ids=str)
def test_custom_dist_custom_support_point_multivariate(self, size):
def density_support_point(rv, size, mu):
Expand All @@ -235,19 +230,14 @@ def density_support_point(rv, size, mu):
"a",
mu,
support_point=density_support_point,
ndims_params=[1],
ndim_supp=1,
signature="(n)->(n)",
size=size,
)
assert isinstance(a.owner.op, CustomDistRV)
evaled_support_point = support_point(a).eval({mu: mu_val})
assert evaled_support_point.shape == (*to_tuple(size), 5)
assert np.all(evaled_support_point == mu_val)

@pytest.mark.xfail(
NotImplementedError,
reason="Support shape of multivariate CustomDist cannot be inferred. See https://github.com/pymc-devs/pytensor/pull/388",
)
@pytest.mark.parametrize(
"with_random, size",
[
Expand All @@ -267,21 +257,14 @@ def _random(mu, rng=None, size=None):
else:
random = None

mu_val = np.random.normal(loc=2, scale=1, size=5).astype(pytensor.config.floatX)
with Model():
mu = Normal("mu", size=5)
a = CustomDist("a", mu, random=random, ndims_params=[1], ndim_supp=1, size=size)
a = CustomDist("a", mu, random=random, signature="(n)->(n)", size=size)
assert isinstance(a.owner.op, CustomDistRV)
if with_random:
evaled_support_point = support_point(a).eval({mu: mu_val})
evaled_support_point = support_point(a).eval()
assert evaled_support_point.shape == (*to_tuple(size), 5)
assert np.all(evaled_support_point == 0)
else:
with pytest.raises(
TypeError,
match="Cannot safely infer the size of a multivariate random variable's support_point.",
):
evaled_support_point = support_point(a).eval({mu: mu_val})

def test_dist(self):
mu = 1
Expand All @@ -300,6 +283,12 @@ def test_dist(self):
x_logp = logp(x, test_value)
assert np.allclose(x_logp.eval(), st.norm(1).logpdf(test_value))

def test_multivariate_insufficient_signature(self):
with pytest.raises(
NotImplementedError, match="signature is not sufficient to infer the support shape"
):
CustomDist.dist(signature="(n)->(m)")


class TestCustomSymbolicDist:
def test_basic(self):
Expand Down
Loading