From 47c4b489dbfda419c258701562a795d1d6409184 Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Thu, 2 Nov 2023 16:52:29 +0100 Subject: [PATCH] Marginalize DiscreteMarkovChain Co-authored-by: Jesse Grabowski <48652735+jessegrabowski@users.noreply.github.com> --- pymc_experimental/model/marginal_model.py | 115 +++++++++++++++--- .../tests/model/test_marginal_model.py | 85 +++++++++++++ 2 files changed, 185 insertions(+), 15 deletions(-) diff --git a/pymc_experimental/model/marginal_model.py b/pymc_experimental/model/marginal_model.py index a250c47a..7ee047b3 100644 --- a/pymc_experimental/model/marginal_model.py +++ b/pymc_experimental/model/marginal_model.py @@ -10,21 +10,16 @@ from pymc.distributions.discrete import Bernoulli, Categorical, DiscreteUniform from pymc.distributions.transforms import Chain from pymc.logprob.abstract import _logprob -from pymc.logprob.basic import conditional_logp +from pymc.logprob.basic import conditional_logp, logp from pymc.logprob.transforms import IntervalTransform from pymc.model import Model from pymc.pytensorf import compile_pymc, constant_fold, inputvars from pymc.util import _get_seeds_per_chain, dataset_to_point_list, treedict -from pytensor import Mode +from pytensor import Mode, scan from pytensor.compile import SharedVariable from pytensor.compile.builders import OpFromGraph -from pytensor.graph import ( - Constant, - FunctionGraph, - ancestors, - clone_replace, - vectorize_graph, -) +from pytensor.graph import Constant, FunctionGraph, ancestors, clone_replace +from pytensor.graph.replace import vectorize_graph from pytensor.scan import map as scan_map from pytensor.tensor import TensorType, TensorVariable from pytensor.tensor.elemwise import Elemwise @@ -33,6 +28,8 @@ __all__ = ["MarginalModel"] +from pymc_experimental.distributions import DiscreteMarkovChain + class MarginalModel(Model): """Subclass of PyMC Model that implements functionality for automatic @@ -247,16 +244,25 @@ def marginalize( self[var] if isinstance(var, str) else var for var in rvs_to_marginalize ] - supported_dists = (Bernoulli, Categorical, DiscreteUniform) for rv_to_marginalize in rvs_to_marginalize: if rv_to_marginalize not in self.free_RVs: raise ValueError( f"Marginalized RV {rv_to_marginalize} is not a free RV in the model" ) - if not isinstance(rv_to_marginalize.owner.op, supported_dists): + + rv_op = rv_to_marginalize.owner.op + if isinstance(rv_op, DiscreteMarkovChain): + if rv_op.n_lags > 1: + raise NotImplementedError( + "Marginalization for DiscreteMarkovChain with n_lags > 1 is not supported" + ) + if rv_to_marginalize.owner.inputs[0].type.ndim > 2: + raise NotImplementedError( + "Marginalization for DiscreteMarkovChain with non-matrix transition probability is not supported" + ) + elif not isinstance(rv_op, (Bernoulli, Categorical, DiscreteUniform)): raise NotImplementedError( - f"RV with distribution {rv_to_marginalize.owner.op} cannot be marginalized. " - f"Supported distribution include {supported_dists}" + f"Marginalization of RV with distribution {rv_to_marginalize.owner.op} is not supported" ) if rv_to_marginalize.name in self.named_vars_to_dims: @@ -492,6 +498,10 @@ class FiniteDiscreteMarginalRV(MarginalRV): """Base class for Finite Discrete Marginalized RVs""" +class DiscreteMarginalMarkovChainRV(MarginalRV): + """Base class for Discrete Marginal Markov Chain RVs""" + + def static_shape_ancestors(vars): """Identify ancestors Shape Ops of static shapes (therefore constant in a valid graph).""" return [ @@ -620,11 +630,17 @@ def replace_finite_discrete_marginal_subgraph(fgraph, rv_to_marginalize, all_rvs replace_inputs.update({input_rv: input_rv.type() for input_rv in input_rvs}) cloned_outputs = clone_replace(outputs, replace=replace_inputs) - marginalization_op = FiniteDiscreteMarginalRV( + if isinstance(rv_to_marginalize.owner.op, DiscreteMarkovChain): + marginalize_constructor = DiscreteMarginalMarkovChainRV + else: + marginalize_constructor = FiniteDiscreteMarginalRV + + marginalization_op = marginalize_constructor( inputs=list(replace_inputs.values()), outputs=cloned_outputs, ndim_supp=ndim_supp, ) + marginalized_rvs = marginalization_op(*replace_inputs.keys()) fgraph.replace_all(tuple(zip(rvs_to_marginalize, marginalized_rvs))) return rvs_to_marginalize, marginalized_rvs @@ -640,6 +656,9 @@ def get_domain_of_finite_discrete_rv(rv: TensorVariable) -> Tuple[int, ...]: elif isinstance(op, DiscreteUniform): lower, upper = constant_fold(rv.owner.inputs[3:]) return tuple(range(lower, upper + 1)) + elif isinstance(op, DiscreteMarkovChain): + P = rv.owner.inputs[0] + return tuple(range(pt.get_vector_length(P[-1]))) raise NotImplementedError(f"Cannot compute domain for op {op}") @@ -647,7 +666,7 @@ def get_domain_of_finite_discrete_rv(rv: TensorVariable) -> Tuple[int, ...]: def _add_reduce_batch_dependent_logps( marginalized_type: TensorType, dependent_logps: Sequence[TensorVariable] ): - """Add the logps of dependent RVs while reducing extra batch dims as assessed from the `marginalized_type`.""" + """Add the logps of dependent RVs while reducing extra batch dims relative to `marginalized_type`.""" mbcast = marginalized_type.broadcastable reduced_logps = [] @@ -730,3 +749,69 @@ def logp_fn(marginalized_rv_const, *non_sequences): # We have to add dummy logps for the remaining value variables, otherwise PyMC will raise return joint_logps, *(pt.constant(0),) * (len(values) - 1) + + +@_logprob.register(DiscreteMarginalMarkovChainRV) +def marginal_hmm_logp(op, values, *inputs, **kwargs): + + marginalized_rvs_node = op.make_node(*inputs) + inner_rvs = clone_replace( + op.inner_outputs, + replace={u: v for u, v in zip(op.inner_inputs, marginalized_rvs_node.inputs)}, + ) + + chain_rv, *dependent_rvs = inner_rvs + P, n_steps_, init_dist_, rng = chain_rv.owner.inputs + domain = pt.arange(P.shape[-1], dtype="int32") + + # Construct logp in two steps + # Step 1: Compute the probability of the data ("emissions") under every possible state (vec_logp_emission) + + # First we need to vectorize the conditional logp graph of the data, in case there are batch dimensions floating + # around. To do this, we need to break the dependency between chain and the init_dist_ random variable. Otherwise, + # PyMC will detect a random variable in the logp graph (init_dist_), that isn't relevant at this step. + chain_value = chain_rv.clone() + dependent_rvs = clone_replace(dependent_rvs, {chain_rv: chain_value}) + logp_emissions_dict = conditional_logp(dict(zip(dependent_rvs, values))) + + # Reduce and add the batch dims beyond the chain dimension + reduced_logp_emissions = _add_reduce_batch_dependent_logps( + chain_rv.type, logp_emissions_dict.values() + ) + + # Add a batch dimension for the domain of the chain + chain_shape = constant_fold(tuple(chain_rv.shape)) + batch_chain_value = pt.moveaxis(pt.full((*chain_shape, domain.size), domain), -1, 0) + batch_logp_emissions = vectorize_graph(reduced_logp_emissions, {chain_value: batch_chain_value}) + + # Step 2: Compute the transition probabilities + # This is the "forward algorithm", alpha_t = p(y | s_t) * sum_{s_{t-1}}(p(s_t | s_{t-1}) * alpha_{t-1}) + # We do it entirely in logs, though. + + # To compute the prior probabilities of each state, we evaluate the logp of the domain (all possible states) under + # the initial distribution. This is robust to everything the user can throw at it. + batch_logp_init_dist = pt.vectorize(lambda x: logp(init_dist_, x), "()->()")( + batch_chain_value[..., 0] + ) + log_alpha_init = batch_logp_init_dist + batch_logp_emissions[..., 0] + + def step_alpha(logp_emission, log_alpha, log_P): + step_log_prob = pt.logsumexp(log_alpha[:, None] + log_P, axis=0) + return logp_emission + step_log_prob + + P_bcast_dims = (len(chain_shape) - 1) - (P.type.ndim - 2) + log_P = pt.shape_padright(pt.log(P), P_bcast_dims) + log_alpha_seq, _ = scan( + step_alpha, + non_sequences=[log_P], + outputs_info=[log_alpha_init], + # Scan needs the time dimension first, and we already consumed the 1st logp computing the initial value + sequences=pt.moveaxis(batch_logp_emissions[..., 1:], -1, 0), + ) + # Final logp is just the sum of the last scan state + joint_logp = pt.logsumexp(log_alpha_seq[-1], axis=0) + + # If there are multiple emission streams, we have to add dummy logps for the remaining value variables. The first + # return is the joint probability of everything together, but PyMC still expects one logp for each one. + dummy_logps = (pt.constant(0),) * (len(values) - 1) + return joint_logp, *dummy_logps diff --git a/pymc_experimental/tests/model/test_marginal_model.py b/pymc_experimental/tests/model/test_marginal_model.py index b664feca..c0e1bd90 100644 --- a/pymc_experimental/tests/model/test_marginal_model.py +++ b/pymc_experimental/tests/model/test_marginal_model.py @@ -14,6 +14,7 @@ from scipy.special import log_softmax, logsumexp from scipy.stats import halfnorm, norm +from pymc_experimental.distributions import DiscreteMarkovChain from pymc_experimental.model.marginal_model import ( FiniteDiscreteMarginalRV, MarginalModel, @@ -673,3 +674,87 @@ def dist(idx, size): ): pt = {"norm": test_value} np.testing.assert_allclose(logp_fn(pt), ref_logp_fn(pt)) + + +@pytest.mark.parametrize("batch_chain", (False, True), ids=lambda x: f"batch_chain={x}") +@pytest.mark.parametrize("batch_emission", (False, True), ids=lambda x: f"batch_emission={x}") +def test_marginalized_hmm_normal_emission(batch_chain, batch_emission): + if batch_chain and not batch_emission: + pytest.skip("Redundant implicit combination") + + with MarginalModel() as m: + P = [[0, 1], [1, 0]] + init_dist = pm.Categorical.dist(p=[1, 0]) + chain = DiscreteMarkovChain( + "chain", P=P, init_dist=init_dist, steps=3, shape=(3, 4) if batch_chain else None + ) + emission = pm.Normal( + "emission", mu=chain * 2 - 1, sigma=1e-1, shape=(3, 4) if batch_emission else None + ) + + m.marginalize([chain]) + logp_fn = m.compile_logp() + + test_value = np.array([-1, 1, -1, 1]) + expected_logp = pm.logp(pm.Normal.dist(0, 1e-1), np.zeros_like(test_value)).sum().eval() + if batch_emission: + test_value = np.broadcast_to(test_value, (3, 4)) + expected_logp *= 3 + np.testing.assert_allclose(logp_fn({f"emission": test_value}), expected_logp) + + +@pytest.mark.parametrize( + "categorical_emission", + [ + False, + # Categorical has a core vector parameter, + # so it is not possible to build a graph that uses elemwise operations exclusively + pytest.param(True, marks=pytest.mark.xfail(raises=NotImplementedError)), + ], +) +def test_marginalized_hmm_categorical_emission(categorical_emission): + """Example adapted from https://www.youtube.com/watch?v=9-sPm4CfcD0""" + with MarginalModel() as m: + P = np.array([[0.5, 0.5], [0.3, 0.7]]) + init_dist = pm.Categorical.dist(p=[0.375, 0.625]) + chain = DiscreteMarkovChain("chain", P=P, init_dist=init_dist, steps=2) + if categorical_emission: + emission = pm.Categorical( + "emission", p=pt.where(pt.eq(chain, 0)[..., None], [0.8, 0.2], [0.4, 0.6]) + ) + else: + emission = pm.Bernoulli("emission", p=pt.where(pt.eq(chain, 0), 0.2, 0.6)) + m.marginalize([chain]) + + test_value = np.array([0, 0, 1]) + expected_logp = np.log(0.1344) # Shown at the 10m22s mark in the video + logp_fn = m.compile_logp() + np.testing.assert_allclose(logp_fn({f"emission": test_value}), expected_logp) + + +@pytest.mark.parametrize("batch_emission1", (False, True)) +@pytest.mark.parametrize("batch_emission2", (False, True)) +def test_marginalized_hmm_multiple_emissions(batch_emission1, batch_emission2): + emission1_shape = (2, 4) if batch_emission1 else (4,) + emission2_shape = (2, 4) if batch_emission2 else (4,) + with MarginalModel() as m: + P = [[0, 1], [1, 0]] + init_dist = pm.Categorical.dist(p=[1, 0]) + chain = DiscreteMarkovChain("chain", P=P, init_dist=init_dist, steps=3) + emission_1 = pm.Normal("emission_1", mu=chain * 2 - 1, sigma=1e-1, shape=emission1_shape) + emission_2 = pm.Normal( + "emission_2", mu=(1 - chain) * 2 - 1, sigma=1e-1, shape=emission2_shape + ) + + with pytest.warns(UserWarning, match="multiple dependent variables"): + m.marginalize([chain]) + + logp_fn = m.compile_logp() + + test_value = np.array([-1, 1, -1, 1]) + multiplier = 2 + batch_emission1 + batch_emission2 + expected_logp = norm.logpdf(np.zeros_like(test_value), 0, 1e-1).sum() * multiplier + test_value_emission1 = np.broadcast_to(test_value, emission1_shape) + test_value_emission2 = np.broadcast_to(-test_value, emission2_shape) + test_point = {"emission_1": test_value_emission1, "emission_2": test_value_emission2} + np.testing.assert_allclose(logp_fn(test_point), expected_logp)