diff --git a/pymc/distributions/custom.py b/pymc/distributions/custom.py index 3f8882bc163..9b8098ab99e 100644 --- a/pymc/distributions/custom.py +++ b/pymc/distributions/custom.py @@ -27,7 +27,7 @@ from pytensor.tensor.random.op import RandomVariable from pytensor.tensor.random.type import RandomGeneratorType, RandomType from pytensor.tensor.random.utils import normalize_size_param -from pytensor.tensor.utils import _parse_gufunc_signature, safe_signature +from pytensor.tensor.utils import safe_signature from pymc.distributions.distribution import ( Distribution, @@ -108,19 +108,9 @@ def dist( class_name: str = "CustomDist", **kwargs, ): - if ndim_supp is None or ndims_params is None: - if signature is None: - ndim_supp = 0 - ndims_params = [0] * len(dist_params) - else: - inputs, outputs = _parse_gufunc_signature(signature) - ndim_supp = max(len(out) for out in outputs) - ndims_params = [len(inp) for inp in inputs] - - if ndim_supp > 0: - raise NotImplementedError( - "CustomDist with ndim_supp > 0 and without a `dist` function are not supported." - ) + if ndim_supp is None and signature is None: + # Assume a scalar distribution + signature = safe_signature([0] * len(dist_params), [0]) dist_params = [as_tensor_variable(param) for param in dist_params] @@ -148,6 +138,7 @@ def dist( support_point=support_point, ndim_supp=ndim_supp, ndims_params=ndims_params, + signature=signature, dtype=dtype, class_name=class_name, **kwargs, @@ -161,8 +152,9 @@ def rv_op( logcdf: Callable | None, random: Callable | None, support_point: Callable | None, - ndim_supp: int, - ndims_params: Sequence[int], + signature: str | None, + ndim_supp: int | None, + ndims_params: Sequence[int] | None, dtype: str, class_name: str, **kwargs, @@ -175,6 +167,7 @@ def rv_op( inplace=False, ndim_supp=ndim_supp, ndims_params=ndims_params, + signature=signature, dtype=dtype, _print_name=(class_name, f"\\operatorname{{{class_name}}}"), # Specific to CustomDist @@ -344,7 +337,7 @@ def change_custom_dist_size(op, rv, new_size, expand): new_rv_op = rv_type( inputs=[*dummy_params, *rngs], outputs=[dummy_rv, *rngs_updates], - signature=signature, + extended_signature=extended_signature, ) new_rv = new_rv_op(new_size, *dist_params, *rngs) @@ -357,13 +350,13 @@ def change_custom_dist_size(op, rv, new_size, expand): inputs = [*dummy_params, *rngs] outputs = [dummy_rv, *rngs_updates] - signature = cls._infer_final_signature( + extended_signature = cls._infer_final_signature( signature, n_inputs=len(inputs), n_outputs=len(outputs), n_rngs=len(rngs) ) rv_op = rv_type( inputs=inputs, outputs=outputs, - signature=signature, + extended_signature=extended_signature, ) return rv_op(size, *dist_params, *rngs) diff --git a/tests/distributions/test_custom.py b/tests/distributions/test_custom.py index 96ba9685e12..45a902b8e54 100644 --- a/tests/distributions/test_custom.py +++ b/tests/distributions/test_custom.py @@ -53,6 +53,9 @@ from pymc.step_methods import Metropolis from pymc.testing import assert_support_point_is_expected +# Raise for any warnings in this file +pytestmark = pytest.mark.filterwarnings("error") + class TestCustomDist: @pytest.mark.parametrize("size", [(), (3,), (3, 2)], ids=str) @@ -105,24 +108,24 @@ def test_custom_dist_without_random(self): with pytest.raises(NotImplementedError): sample_posterior_predictive(idata, model=model) - @pytest.mark.xfail( - NotImplementedError, - reason="Support shape of multivariate CustomDist cannot be inferred. See https://github.com/pymc-devs/pytensor/pull/388", - ) @pytest.mark.parametrize("size", [(), (3,), (3, 2)], ids=str) def test_custom_dist_with_random_multivariate(self, size): + def random(mu, rng, size): + return rng.multivariate_normal( + mean=mu.ravel(), + cov=np.eye(mu.shape[-1]), + size=size, + ) + supp_shape = 5 with Model() as model: mu = Normal("mu", 0, 1, size=supp_shape) obs = CustomDist( "custom_dist", mu, - random=lambda mu, rng=None, size=None: rng.multivariate_normal( - mean=mu, cov=np.eye(len(mu)), size=size - ), + random=random, observed=np.random.randn(100, *size, supp_shape), - ndims_params=[1], - ndim_supp=1, + signature="(n)->(n)", ) assert isinstance(obs.owner.op, CustomDistRV) @@ -156,20 +159,16 @@ def test_custom_dist_old_api_error(self): ): CustomDist("a", lambda x: x) - @pytest.mark.xfail( - NotImplementedError, - reason="Support shape of multivariate CustomDist cannot be inferred. See https://github.com/pymc-devs/pytensor/pull/388", - ) @pytest.mark.parametrize("size", [None, (), (2,)], ids=str) def test_custom_dist_multivariate_logp(self, size): supp_shape = 5 with Model() as model: def logp(value, mu): - return MvNormal.logp(value, mu, pt.eye(mu.shape[0])) + return MvNormal.logp(value, mu, pt.eye(mu.shape[-1])) mu = Normal("mu", size=supp_shape) - a = CustomDist("a", mu, logp=logp, ndims_params=[1], ndim_supp=1, size=size) + a = CustomDist("a", mu, logp=logp, signature="(n)->(n)", size=size) assert isinstance(a.owner.op, CustomDistRV) mu_test_value = npr.normal(loc=0, scale=1, size=supp_shape).astype(pytensor.config.floatX) @@ -219,10 +218,6 @@ def density_support_point(rv, size, mu): assert evaled_support_point.shape == to_tuple(size) assert np.all(evaled_support_point == mu_val) - @pytest.mark.xfail( - NotImplementedError, - reason="Support shape of multivariate CustomDist cannot be inferred. See https://github.com/pymc-devs/pytensor/pull/388", - ) @pytest.mark.parametrize("size", [(), (2,), (3, 2)], ids=str) def test_custom_dist_custom_support_point_multivariate(self, size): def density_support_point(rv, size, mu): @@ -235,8 +230,7 @@ def density_support_point(rv, size, mu): "a", mu, support_point=density_support_point, - ndims_params=[1], - ndim_supp=1, + signature="(n)->(n)", size=size, ) assert isinstance(a.owner.op, CustomDistRV) @@ -244,10 +238,6 @@ def density_support_point(rv, size, mu): assert evaled_support_point.shape == (*to_tuple(size), 5) assert np.all(evaled_support_point == mu_val) - @pytest.mark.xfail( - NotImplementedError, - reason="Support shape of multivariate CustomDist cannot be inferred. See https://github.com/pymc-devs/pytensor/pull/388", - ) @pytest.mark.parametrize( "with_random, size", [ @@ -267,21 +257,14 @@ def _random(mu, rng=None, size=None): else: random = None - mu_val = np.random.normal(loc=2, scale=1, size=5).astype(pytensor.config.floatX) with Model(): mu = Normal("mu", size=5) - a = CustomDist("a", mu, random=random, ndims_params=[1], ndim_supp=1, size=size) + a = CustomDist("a", mu, random=random, signature="(n)->(n)", size=size) assert isinstance(a.owner.op, CustomDistRV) if with_random: - evaled_support_point = support_point(a).eval({mu: mu_val}) + evaled_support_point = support_point(a).eval() assert evaled_support_point.shape == (*to_tuple(size), 5) assert np.all(evaled_support_point == 0) - else: - with pytest.raises( - TypeError, - match="Cannot safely infer the size of a multivariate random variable's support_point.", - ): - evaled_support_point = support_point(a).eval({mu: mu_val}) def test_dist(self): mu = 1 @@ -300,6 +283,12 @@ def test_dist(self): x_logp = logp(x, test_value) assert np.allclose(x_logp.eval(), st.norm(1).logpdf(test_value)) + def test_multivariate_insufficient_signature(self): + with pytest.raises( + NotImplementedError, match="signature is not sufficient to infer the support shape" + ): + CustomDist.dist(signature="(n)->(m)") + class TestCustomSymbolicDist: def test_basic(self):