Skip to content

Commit

Permalink
replace default_supp_shape_from_params with supp_shape_from_ref_param…
Browse files Browse the repository at this point in the history
…_shape
  • Loading branch information
bwengals committed Jul 25, 2023
1 parent cd1d354 commit 72cb829
Showing 1 changed file with 15 additions and 6 deletions.
21 changes: 15 additions & 6 deletions pymc/distributions/multivariate.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,11 @@
from pytensor.tensor import TensorConstant, gammaln, sigmoid
from pytensor.tensor.nlinalg import det, eigh, matrix_inverse, trace
from pytensor.tensor.random.basic import dirichlet, multinomial, multivariate_normal
from pytensor.tensor.random.op import RandomVariable, default_supp_shape_from_params
from pytensor.tensor.random.utils import broadcast_params
from pytensor.tensor.random.op import RandomVariable
from pytensor.tensor.random.utils import (
broadcast_params,
supp_shape_from_ref_param_shape,
)
from pytensor.tensor.slinalg import Cholesky, SolveTriangular
from pytensor.tensor.type import TensorType
from scipy import linalg, stats
Expand Down Expand Up @@ -321,8 +324,11 @@ def __call__(self, nu, mu=None, cov=None, size=None, **kwargs):
return super().__call__(nu, mu, cov, size=size, **kwargs)

def _supp_shape_from_params(self, dist_params, rep_param_idx=1, param_shapes=None):
return default_supp_shape_from_params(
self.ndim_supp, dist_params, rep_param_idx, param_shapes
return supp_shape_from_ref_param_shape(
ndim_supp=self.ndim_supp,
dist_params=dist_params,
param_shapes=param_shapes,
ref_param_idx=rep_param_idx,
)

@classmethod
Expand Down Expand Up @@ -612,8 +618,11 @@ class DirichletMultinomialRV(RandomVariable):
_print_name = ("DirichletMN", "\\operatorname{DirichletMN}")

def _supp_shape_from_params(self, dist_params, rep_param_idx=1, param_shapes=None):
return default_supp_shape_from_params(
self.ndim_supp, dist_params, rep_param_idx, param_shapes
return supp_shape_from_ref_param_shape(
ndim_supp=self.ndim_supp,
dist_params=dist_params,
param_shapes=param_shapes,
ref_param_idx=rep_param_idx,
)

@classmethod
Expand Down

0 comments on commit 72cb829

Please sign in to comment.