From fde07260241c5005cbc67c60fad2e3c558c0b78c Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Tue, 16 Apr 2024 11:02:33 +0200 Subject: [PATCH] Allow creating MarginalModel from existing Model --- docs/api_reference.rst | 1 + pymc_experimental/model/marginal_model.py | 92 ++++++++++++++----- .../tests/model/test_marginal_model.py | 33 +++++++ pymc_experimental/tests/utils.py | 31 +++++++ 4 files changed, 134 insertions(+), 23 deletions(-) create mode 100644 pymc_experimental/tests/utils.py diff --git a/docs/api_reference.rst b/docs/api_reference.rst index dfc439685..b3455458a 100644 --- a/docs/api_reference.rst +++ b/docs/api_reference.rst @@ -10,6 +10,7 @@ methods in the current release of PyMC experimental. as_model MarginalModel + marginalize model_builder.ModelBuilder Inference diff --git a/pymc_experimental/model/marginal_model.py b/pymc_experimental/model/marginal_model.py index 8847f5eb3..cd26d3109 100644 --- a/pymc_experimental/model/marginal_model.py +++ b/pymc_experimental/model/marginal_model.py @@ -1,5 +1,5 @@ import warnings -from typing import Sequence +from typing import Sequence, Union import numpy as np import pymc @@ -26,10 +26,12 @@ from pytensor.tensor.shape import Shape from pytensor.tensor.special import log_softmax -__all__ = ["MarginalModel"] +__all__ = ["MarginalModel", "marginalize"] from pymc_experimental.distributions import DiscreteMarkovChain +ModelRVs = TensorVariable | Sequence[TensorVariable] | str | Sequence[str] + class MarginalModel(Model): """Subclass of PyMC Model that implements functionality for automatic @@ -208,35 +210,50 @@ def logp(self, vars=None, **kwargs): vars = [m[var.name] for var in vars] return m._logp(vars=vars, **kwargs) - def clone(self): - m = MarginalModel(coords=self.coords) - model_vars = self.basic_RVs + self.potentials + self.deterministics + self.marginalized_rvs - data_vars = [var for name, var in self.named_vars.items() if var not in model_vars] + @staticmethod + def _clone(model: Union[Model, 
"MarginalModel"]): + new_model = MarginalModel(coords=model.coords) + if isinstance(model, MarginalModel): + marginalized_rvs = model.marginalized_rvs + marginalized_named_vars_to_dims = model._marginalized_named_vars_to_dims + else: + marginalized_rvs = [] + marginalized_named_vars_to_dims = {} + + model_vars = model.basic_RVs + model.potentials + model.deterministics + marginalized_rvs + data_vars = [var for name, var in model.named_vars.items() if var not in model_vars] vars = model_vars + data_vars cloned_vars = clone_replace(vars) vars_to_clone = {var: cloned_var for var, cloned_var in zip(vars, cloned_vars)} - m.vars_to_clone = vars_to_clone - - m.named_vars = treedict({name: vars_to_clone[var] for name, var in self.named_vars.items()}) - m.named_vars_to_dims = self.named_vars_to_dims - m.values_to_rvs = {i: vars_to_clone[rv] for i, rv in self.values_to_rvs.items()} - m.rvs_to_values = {vars_to_clone[rv]: i for rv, i in self.rvs_to_values.items()} - m.rvs_to_transforms = {vars_to_clone[rv]: i for rv, i in self.rvs_to_transforms.items()} - m.rvs_to_initial_values = { - vars_to_clone[rv]: i for rv, i in self.rvs_to_initial_values.items() + new_model.vars_to_clone = vars_to_clone + + new_model.named_vars = treedict( + {name: vars_to_clone[var] for name, var in model.named_vars.items()} + ) + new_model.named_vars_to_dims = model.named_vars_to_dims + new_model.values_to_rvs = {i: vars_to_clone[rv] for i, rv in model.values_to_rvs.items()} + new_model.rvs_to_values = {vars_to_clone[rv]: i for rv, i in model.rvs_to_values.items()} + new_model.rvs_to_transforms = { + vars_to_clone[rv]: i for rv, i in model.rvs_to_transforms.items() + } + new_model.rvs_to_initial_values = { + vars_to_clone[rv]: i for rv, i in model.rvs_to_initial_values.items() } - m.free_RVs = [vars_to_clone[rv] for rv in self.free_RVs] - m.observed_RVs = [vars_to_clone[rv] for rv in self.observed_RVs] - m.potentials = [vars_to_clone[pot] for pot in self.potentials] - m.deterministics = 
[vars_to_clone[det] for det in self.deterministics] + new_model.free_RVs = [vars_to_clone[rv] for rv in model.free_RVs] + new_model.observed_RVs = [vars_to_clone[rv] for rv in model.observed_RVs] + new_model.potentials = [vars_to_clone[pot] for pot in model.potentials] + new_model.deterministics = [vars_to_clone[det] for det in model.deterministics] - m.marginalized_rvs = [vars_to_clone[rv] for rv in self.marginalized_rvs] - m._marginalized_named_vars_to_dims = self._marginalized_named_vars_to_dims - return m + new_model.marginalized_rvs = [vars_to_clone[rv] for rv in marginalized_rvs] + new_model._marginalized_named_vars_to_dims = marginalized_named_vars_to_dims + return new_model + + def clone(self): + return self._clone(self) def marginalize( self, - rvs_to_marginalize: TensorVariable | Sequence[TensorVariable] | str | Sequence[str], + rvs_to_marginalize: ModelRVs, ): if not isinstance(rvs_to_marginalize, Sequence): rvs_to_marginalize = (rvs_to_marginalize,) @@ -491,6 +508,35 @@ def transform_input(inputs): return rv_dataset +def marginalize(model: Model, rvs_to_marginalize: ModelRVs) -> MarginalModel: + """Marginalize a subset of variables in a PyMC model. + + This creates a `MarginalModel` from an existing `Model`, with the specified + variables marginalized. + + See documentation for `MarginalModel` for more information. + + Parameters + ---------- + model : Model + PyMC model to marginalize. Original variables will be cloned. + rvs_to_marginalize : Sequence[TensorVariable] + Variables to marginalize in the returned model. + + Returns + ------- + marginal_model: MarginalModel + Marginal model with the specified variables marginalized. 
+ """ + if not isinstance(rvs_to_marginalize, tuple | list): + rvs_to_marginalize = (rvs_to_marginalize,) + rvs_to_marginalize = [rv if isinstance(rv, str) else rv.name for rv in rvs_to_marginalize] + + marginal_model = MarginalModel._clone(model) + marginal_model.marginalize(rvs_to_marginalize) + return marginal_model + + class MarginalRV(SymbolicRandomVariable): """Base class for Marginalized RVs""" diff --git a/pymc_experimental/tests/model/test_marginal_model.py b/pymc_experimental/tests/model/test_marginal_model.py index 610f3b47b..fc05d4edc 100644 --- a/pymc_experimental/tests/model/test_marginal_model.py +++ b/pymc_experimental/tests/model/test_marginal_model.py @@ -10,15 +10,18 @@ from pymc import ImputationWarning, inputvars from pymc.distributions import transforms from pymc.logprob.abstract import _logprob +from pymc.model.fgraph import fgraph_from_model from pymc.util import UNSET from scipy.special import log_softmax, logsumexp from scipy.stats import halfnorm, norm +from utils import equal_computations_up_to_root from pymc_experimental.distributions import DiscreteMarkovChain from pymc_experimental.model.marginal_model import ( FiniteDiscreteMarginalRV, MarginalModel, is_conditional_dependent, + marginalize, ) @@ -776,3 +779,33 @@ def test_mutable_indexing_jax_backend(): pm.LogNormal("y", mu=cat_effect[cat_effect_idx], sigma=1 + is_outlier, observed=data) model.marginalize(["is_outlier"]) get_jaxified_logp(model) + + +def test_marginal_model_func(): + def create_model(model_class): + with model_class(coords={"trial": range(10)}) as m: + idx = pm.Bernoulli("idx", p=0.5, dims="trial") + mu = pt.where(idx, 1, -1) + sigma = pm.HalfNormal("sigma") + y = pm.Normal("y", mu=mu, sigma=sigma, dims="trial", observed=[1] * 10) + return m + + marginal_m = marginalize(create_model(pm.Model), ["idx"]) + assert isinstance(marginal_m, MarginalModel) + + reference_m = create_model(MarginalModel) + reference_m.marginalize(["idx"]) + + # Check forward graph 
representation is the same + marginal_fgraph, _ = fgraph_from_model(marginal_m) + reference_fgraph, _ = fgraph_from_model(reference_m) + assert equal_computations_up_to_root(marginal_fgraph.outputs, reference_fgraph.outputs) + + # Check logp graph is the same + # This fails because OpFromGraphs comparison is broken + # assert equal_computations_up_to_root([marginal_m.logp()], [reference_m.logp()]) + ip = marginal_m.initial_point() + np.testing.assert_allclose( + marginal_m.compile_logp()(ip), + reference_m.compile_logp()(ip), + ) diff --git a/pymc_experimental/tests/utils.py b/pymc_experimental/tests/utils.py new file mode 100644 index 000000000..2d934bf6c --- /dev/null +++ b/pymc_experimental/tests/utils.py @@ -0,0 +1,31 @@ +from typing import Sequence + +from pytensor.compile import SharedVariable +from pytensor.graph import Constant, graph_inputs +from pytensor.graph.basic import Variable, equal_computations +from pytensor.tensor.random.type import RandomType + + +def equal_computations_up_to_root( + xs: Sequence[Variable], ys: Sequence[Variable], ignore_rng_values=True +) -> bool: + # Check if graphs are equivalent even if root variables have distinct identities + + x_graph_inputs = [var for var in graph_inputs(xs) if not isinstance(var, Constant)] + y_graph_inputs = [var for var in graph_inputs(ys) if not isinstance(var, Constant)] + if len(x_graph_inputs) != len(y_graph_inputs): + return False + for x, y in zip(x_graph_inputs, y_graph_inputs): + if x.type != y.type: + return False + if x.name != y.name: + return False + if isinstance(x, SharedVariable): + if not isinstance(y, SharedVariable): + return False + if isinstance(x.type, RandomType) and ignore_rng_values: + continue + if not x.type.values_eq(x.get_value(), y.get_value()): + return False + + return equal_computations(xs, ys, in_xs=x_graph_inputs, in_ys=y_graph_inputs)