Skip to content

Commit

Permalink
Remove unused batch_dim logic in SequenceMvNormal
Browse files Browse the repository at this point in the history
  • Loading branch information
jessegrabowski committed Jun 29, 2024
1 parent 69b5b17 commit 48bf59f
Showing 1 changed file with 0 additions and 16 deletions.
16 changes: 0 additions & 16 deletions pymc_experimental/statespace/filters/distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,9 @@
from pymc.distributions.dist_math import check_parameters
from pymc.distributions.distribution import Continuous, SymbolicRandomVariable
from pymc.distributions.multivariate import MvNormal
from pymc.distributions.shape_utils import get_support_shape_1d, rv_size_is_none
from pymc.logprob.abstract import _logprob
from pytensor.graph.basic import Node
from pytensor.tensor.random.basic import MvNormalRV
from pytensor.tensor.random.utils import normalize_size_param

floatX = pytensor.config.floatX
COV_ZERO_TOL = 0
Expand Down Expand Up @@ -372,20 +370,6 @@ def dist(cls, mus, covs, logp, **kwargs):

@classmethod
def rv_op(cls, mus, covs, logp, size=None):

# TODO: None of this does anything -- what am I doing wrong?
size = normalize_size_param(size)
if rv_size_is_none(size):
# In this case the size of the init_dist depends on the parameters shape
# The last dimension of rho and init_dist does not matter
batch_size = pt.broadcast_shape(
tuple(mus.shape)[:-2],
tuple(covs.shape)[:-3],
arrays_are_shapes=True,
)
else:
batch_size = size

# Batch dimensions (if any) will be on the far left, but scan requires time to be there instead
if mus.ndim > 2:
mus = pt.moveaxis(mus, -2, 0)
Expand Down

0 comments on commit 48bf59f

Please sign in to comment.