From 72cb829686bfaf96b07e1a706d0dc383bfa1d0f2 Mon Sep 17 00:00:00 2001 From: Bill Engels Date: Tue, 25 Jul 2023 11:49:00 -0700 Subject: [PATCH] replace default_supp_shape_from_params with supp_shape_from_ref_param_shape --- pymc/distributions/multivariate.py | 21 +++++++++++++++------ 1 file changed, 15 insertions(+), 6 deletions(-) diff --git a/pymc/distributions/multivariate.py b/pymc/distributions/multivariate.py index 476fa14c2d..516e6b7100 100644 --- a/pymc/distributions/multivariate.py +++ b/pymc/distributions/multivariate.py @@ -32,8 +32,11 @@ from pytensor.tensor import TensorConstant, gammaln, sigmoid from pytensor.tensor.nlinalg import det, eigh, matrix_inverse, trace from pytensor.tensor.random.basic import dirichlet, multinomial, multivariate_normal -from pytensor.tensor.random.op import RandomVariable, default_supp_shape_from_params -from pytensor.tensor.random.utils import broadcast_params +from pytensor.tensor.random.op import RandomVariable +from pytensor.tensor.random.utils import ( + broadcast_params, + supp_shape_from_ref_param_shape, +) from pytensor.tensor.slinalg import Cholesky, SolveTriangular from pytensor.tensor.type import TensorType from scipy import linalg, stats @@ -321,8 +324,11 @@ def __call__(self, nu, mu=None, cov=None, size=None, **kwargs): return super().__call__(nu, mu, cov, size=size, **kwargs) def _supp_shape_from_params(self, dist_params, rep_param_idx=1, param_shapes=None): - return default_supp_shape_from_params( - self.ndim_supp, dist_params, rep_param_idx, param_shapes + return supp_shape_from_ref_param_shape( + ndim_supp=self.ndim_supp, + dist_params=dist_params, + param_shapes=param_shapes, + ref_param_idx=rep_param_idx, ) @classmethod @@ -612,8 +618,11 @@ class DirichletMultinomialRV(RandomVariable): _print_name = ("DirichletMN", "\\operatorname{DirichletMN}") def _supp_shape_from_params(self, dist_params, rep_param_idx=1, param_shapes=None): - return default_supp_shape_from_params( - self.ndim_supp, dist_params, rep_param_idx, param_shapes + return supp_shape_from_ref_param_shape( + ndim_supp=self.ndim_supp, + dist_params=dist_params, + param_shapes=param_shapes, + ref_param_idx=rep_param_idx, ) @classmethod