From 060fcf0fa1305e87b71e79b83c669b2a0995a44b Mon Sep 17 00:00:00 2001
From: Ricardo Vieira
Date: Fri, 9 Feb 2024 11:38:01 +0100
Subject: [PATCH] Fix MarginalModel with Data containers

---
 pymc_experimental/model/marginal_model.py     | 10 ++++---
 .../tests/model/test_marginal_model.py        | 26 +++++++++++++++++++
 2 files changed, 32 insertions(+), 4 deletions(-)

diff --git a/pymc_experimental/model/marginal_model.py b/pymc_experimental/model/marginal_model.py
index 1f8e4531..068d0c61 100644
--- a/pymc_experimental/model/marginal_model.py
+++ b/pymc_experimental/model/marginal_model.py
@@ -212,8 +212,10 @@ def logp(self, vars=None, **kwargs):
         return m._logp(vars=vars, **kwargs)
 
     def clone(self):
-        m = MarginalModel()
-        vars = self.basic_RVs + self.potentials + self.deterministics + self.marginalized_rvs
+        m = MarginalModel(coords=self.coords)
+        model_vars = self.basic_RVs + self.potentials + self.deterministics + self.marginalized_rvs
+        data_vars = [var for name, var in self.named_vars.items() if var not in model_vars]
+        vars = model_vars + data_vars
         cloned_vars = clone_replace(vars)
         vars_to_clone = {var: cloned_var for var, cloned_var in zip(vars, cloned_vars)}
         m.vars_to_clone = vars_to_clone
@@ -598,7 +600,7 @@ def replace_finite_discrete_marginal_subgraph(fgraph, rv_to_marginalize, all_rvs
     # can ultimately be generated that is proportional to the support domain and not
     # to the variables dimensions
     # We don't need to worry about this if the RV is scalar.
-    if np.prod(constant_fold(tuple(rv_to_marginalize.shape))) > 1:
+    if np.prod(constant_fold(tuple(rv_to_marginalize.shape), raise_not_constant=False)) != 1:
         if not is_elemwise_subgraph(rv_to_marginalize, dependent_rvs_input_rvs, dependent_rvs):
             raise NotImplementedError(
                 "The subgraph between a marginalized RV and its dependents includes non Elemwise operations. "
@@ -682,7 +684,7 @@ def finite_discrete_marginal_rv_logp(op, values, *inputs, **kwargs):
     # batched dimensions of the marginalized RV
 
     # PyMC does not allow RVs in the logp graph, even if we are just using the shape
-    marginalized_rv_shape = constant_fold(tuple(marginalized_rv.shape))
+    marginalized_rv_shape = constant_fold(tuple(marginalized_rv.shape), raise_not_constant=False)
     marginalized_rv_domain = get_domain_of_finite_discrete_rv(marginalized_rv)
     marginalized_rv_domain_tensor = pt.moveaxis(
         pt.full(
diff --git a/pymc_experimental/tests/model/test_marginal_model.py b/pymc_experimental/tests/model/test_marginal_model.py
index b667635c..7533702e 100644
--- a/pymc_experimental/tests/model/test_marginal_model.py
+++ b/pymc_experimental/tests/model/test_marginal_model.py
@@ -598,3 +598,29 @@ def test_is_conditional_dependent_static_shape():
     x2 = pt.matrix("x2", shape=(9, 5))
     y2 = pt.random.normal(size=pt.shape(x2))
     assert not is_conditional_dependent(y2, x2, [x2, y2])
+
+
+def test_data_container():
+    """Test that MarginalModel can handle Data containers."""
+    with MarginalModel(coords_mutable={"obs": [0]}) as marginal_m:
+        x = pm.MutableData("x", 2.5)
+        idx = pm.Bernoulli("idx", p=0.7, dims="obs")
+        y = pm.Normal("y", idx * x, dims="obs")
+
+    marginal_m.marginalize([idx])
+
+    logp_fn = marginal_m.compile_logp()
+
+    with pm.Model(coords_mutable={"obs": [0]}) as m_ref:
+        x = pm.MutableData("x", 2.5)
+        y = pm.NormalMixture("y", w=[0.3, 0.7], mu=[0, x], dims="obs")
+
+    ref_logp_fn = m_ref.compile_logp()
+
+    for i, x_val in enumerate((-1.5, 2.5, 3.5), start=1):
+        for m in (marginal_m, m_ref):
+            m.set_dim("obs", new_length=i, coord_values=tuple(range(i)))
+            pm.set_data({"x": x_val}, model=m)
+
+        ip = marginal_m.initial_point()
+        np.testing.assert_allclose(logp_fn(ip), ref_logp_fn(ip))