diff --git a/conda-envs/environment-test.yml b/conda-envs/environment-test.yml index 360a8199..4deda063 100644 --- a/conda-envs/environment-test.yml +++ b/conda-envs/environment-test.yml @@ -10,6 +10,6 @@ dependencies: - xhistogram - statsmodels - pip: - - pymc>=5.16.1 # CI was failing to resolve + - pymc>=5.17.0 # CI was failing to resolve - blackjax - scikit-learn diff --git a/conda-envs/windows-environment-test.yml b/conda-envs/windows-environment-test.yml index 360a8199..4deda063 100644 --- a/conda-envs/windows-environment-test.yml +++ b/conda-envs/windows-environment-test.yml @@ -10,6 +10,6 @@ dependencies: - xhistogram - statsmodels - pip: - - pymc>=5.16.1 # CI was failing to resolve + - pymc>=5.17.0 # CI was failing to resolve - blackjax - scikit-learn diff --git a/pymc_experimental/model/marginal/distributions.py b/pymc_experimental/model/marginal/distributions.py index a3a2adbe..661665e9 100644 --- a/pymc_experimental/model/marginal/distributions.py +++ b/pymc_experimental/model/marginal/distributions.py @@ -1,30 +1,55 @@ from collections.abc import Sequence import numpy as np +import pytensor.tensor as pt -from pymc import Bernoulli, Categorical, DiscreteUniform, SymbolicRandomVariable, logp -from pymc.logprob import conditional_logp -from pymc.logprob.abstract import _logprob +from pymc.distributions import Bernoulli, Categorical, DiscreteUniform +from pymc.logprob.abstract import MeasurableOp, _logprob +from pymc.logprob.basic import conditional_logp, logp from pymc.pytensorf import constant_fold -from pytensor import Mode, clone_replace, graph_replace, scan -from pytensor import map as scan_map -from pytensor import tensor as pt -from pytensor.graph import vectorize_graph -from pytensor.tensor import TensorType, TensorVariable +from pytensor import Variable +from pytensor.compile.builders import OpFromGraph +from pytensor.compile.mode import Mode +from pytensor.graph import Op, vectorize_graph +from pytensor.graph.replace import clone_replace, graph_replace +from pytensor.scan import map as scan_map +from pytensor.scan import scan +from pytensor.tensor import TensorVariable from pymc_experimental.distributions import DiscreteMarkovChain -class MarginalRV(SymbolicRandomVariable): +class MarginalRV(OpFromGraph, MeasurableOp): """Base class for Marginalized RVs""" + def __init__(self, *args, dims_connections: tuple[tuple[int | None]], **kwargs) -> None: + self.dims_connections = dims_connections + super().__init__(*args, **kwargs) -class FiniteDiscreteMarginalRV(MarginalRV): - """Base class for Finite Discrete Marginalized RVs""" + @property + def support_axes(self) -> tuple[tuple[int]]: + """Dimensions of dependent RVs that belong to the core (non-batched) marginalized variable.""" + marginalized_ndim_supp = self.inner_outputs[0].owner.op.ndim_supp + support_axes_vars = [] + for dims_connection in self.dims_connections: + ndim = len(dims_connection) + marginalized_supp_axes = ndim - marginalized_ndim_supp + support_axes_vars.append( + tuple( + -i + for i, dim in enumerate(reversed(dims_connection), start=1) + if (dim is None or dim > marginalized_supp_axes) + ) + ) + return tuple(support_axes_vars) -class DiscreteMarginalMarkovChainRV(MarginalRV): - """Base class for Discrete Marginal Markov Chain RVs""" +class MarginalFiniteDiscreteRV(MarginalRV): + """Base class for Marginalized Finite Discrete RVs""" + + +class MarginalDiscreteMarkovChainRV(MarginalRV): + """Base class for Marginalized Discrete Markov Chain RVs""" def get_domain_of_finite_discrete_rv(rv: TensorVariable) -> 
tuple[int, ...]:
@@ -34,7 +59,8 @@ def get_domain_of_finite_discrete_rv(rv: TensorVariable) -> tuple[int, ...]:
         return (0, 1)
     elif isinstance(op, Categorical):
         [p_param] = dist_params
-        return tuple(range(pt.get_vector_length(p_param)))
+        [p_param_length] = constant_fold([p_param.shape[-1]])
+        return tuple(range(p_param_length))
     elif isinstance(op, DiscreteUniform):
         lower, upper = constant_fold(dist_params)
         return tuple(np.arange(lower, upper + 1))
@@ -45,31 +71,81 @@ def get_domain_of_finite_discrete_rv(rv: TensorVariable) -> tuple[int, ...]:
     raise NotImplementedError(f"Cannot compute domain for op {op}")
-def _add_reduce_batch_dependent_logps(
-    marginalized_type: TensorType, dependent_logps: Sequence[TensorVariable]
-):
-    """Add the logps of dependent RVs while reducing extra batch dims relative to `marginalized_type`."""
+def reduce_batch_dependent_logps(
+    dependent_dims_connections: Sequence[tuple[int | None, ...]],
+    dependent_ops: Sequence[Op],
+    dependent_logps: Sequence[TensorVariable],
+) -> TensorVariable:
+    """Combine the logps of dependent RVs and align them with the marginalized logp.
+
+    This requires reducing extra batch dims and transposing when they are not aligned.
+
+    idx = pm.Bernoulli("idx", shape=(3, 2))  # 0, 1
+    pm.Normal("dep1", mu=idx.T[..., None] * 2, shape=(3, 2, 5))
+    pm.Normal("dep2", mu=idx * 2, shape=(7, 2, 3))
+
+    marginalize(idx)
+
+    The marginalized op will have dims_connections = [(1, 0, None), (None, 0, 1)],
+    which tells us we need to reduce the last axis of dep1 logp and the first of dep2 logp,
+    as well as transpose the remaining axes of dep1 logp before adding the two element-wise.
+
+    """
+    from pymc_experimental.model.marginal.graph_analysis import get_support_axes
-    mbcast = marginalized_type.broadcastable
     reduced_logps = []
-    for dependent_logp in dependent_logps:
-        dbcast = dependent_logp.type.broadcastable
-        dim_diff = len(dbcast) - len(mbcast)
-        mbcast_aligned = (True,) * dim_diff + mbcast
-        vbcast_axis = [i for i, (m, v) in enumerate(zip(mbcast_aligned, dbcast)) if m and not v]
-        reduced_logps.append(dependent_logp.sum(vbcast_axis))
-    return pt.add(*reduced_logps)
+    for dependent_op, dependent_logp, dependent_dims_connection in zip(
+        dependent_ops, dependent_logps, dependent_dims_connections
+    ):
+        if dependent_logp.type.ndim > 0:
+            # Find which support axes implied by the MarginalRV need to be reduced
+            # Some may have already been reduced by the logp expression of the dependent RV (e.g., multivariate RVs)
+            dep_supp_axes = get_support_axes(dependent_op)[0]
+            # Dependent RV support axes are already collapsed in the logp, so we ignore them
+            supp_axes = [
+                -i
+                for i, dim in enumerate(reversed(dependent_dims_connection), start=1)
+                if (dim is None and -i not in dep_supp_axes)
+            ]
+            dependent_logp = dependent_logp.sum(supp_axes)
-@_logprob.register(FiniteDiscreteMarginalRV)
-def finite_discrete_marginal_rv_logp(op, values, *inputs, **kwargs):
-    # Clone the inner RV graph of the Marginalized RV
-    marginalized_rvs_node = op.make_node(*inputs)
-    marginalized_rv, *inner_rvs = clone_replace(
+            # Finally, we need to align the dependent logp batch dimensions with the marginalized logp
+            dims_alignment = [dim for dim in dependent_dims_connection if dim is not None]
+            dependent_logp = dependent_logp.transpose(*dims_alignment)
+
+        reduced_logps.append(dependent_logp)
+
+    reduced_logp = pt.add(*reduced_logps)
+    return reduced_logp
+
+
+def align_logp_dims(dims: tuple[int | None, ...], logp: TensorVariable) -> TensorVariable:
+    """Align the logp with
the order specified in dims.""" + dims_alignment = [dim for dim in dims if dim is not None] + return logp.transpose(*dims_alignment) + + +def inline_ofg_outputs(op: OpFromGraph, inputs: Sequence[Variable]) -> tuple[Variable]: + """Inline the inner graph (outputs) of an OpFromGraph Op. + + Whereas `OpFromGraph` "wraps" a graph inside a single Op, this function "unwraps" + the inner graph. + """ + return clone_replace( op.inner_outputs, - replace={u: v for u, v in zip(op.inner_inputs, marginalized_rvs_node.inputs)}, + replace=tuple(zip(op.inner_inputs, inputs)), ) + +DUMMY_ZERO = pt.constant(0, name="dummy_zero") + + +@_logprob.register(MarginalFiniteDiscreteRV) +def finite_discrete_marginal_rv_logp(op: MarginalFiniteDiscreteRV, values, *inputs, **kwargs): + # Clone the inner RV graph of the Marginalized RV + marginalized_rv, *inner_rvs = inline_ofg_outputs(op, inputs) + # Obtain the joint_logp graph of the inner RV graph inner_rv_values = dict(zip(inner_rvs, values)) marginalized_vv = marginalized_rv.clone() @@ -78,8 +154,10 @@ def finite_discrete_marginal_rv_logp(op, values, *inputs, **kwargs): # Reduce logp dimensions corresponding to broadcasted variables marginalized_logp = logps_dict.pop(marginalized_vv) - joint_logp = marginalized_logp + _add_reduce_batch_dependent_logps( - marginalized_rv.type, logps_dict.values() + joint_logp = marginalized_logp + reduce_batch_dependent_logps( + dependent_dims_connections=op.dims_connections, + dependent_ops=[inner_rv.owner.op for inner_rv in inner_rvs], + dependent_logps=[logps_dict[value] for value in values], ) # Compute the joint_logp for all possible n values of the marginalized RV. We assume @@ -116,21 +194,20 @@ def logp_fn(marginalized_rv_const, *non_sequences): mode=Mode().including("local_remove_check_parameter"), ) - joint_logps = pt.logsumexp(joint_logps, axis=0) + joint_logp = pt.logsumexp(joint_logps, axis=0) + + # Align logp with non-collapsed batch dimensions of first RV + joint_logp = align_logp_dims(dims=op.dims_connections[0], logp=joint_logp) # We have to add dummy logps for the remaining value variables, otherwise PyMC will raise - return joint_logps, *(pt.constant(0),) * (len(values) - 1) + dummy_logps = (DUMMY_ZERO,) * (len(values) - 1) + return joint_logp, *dummy_logps -@_logprob.register(DiscreteMarginalMarkovChainRV) +@_logprob.register(MarginalDiscreteMarkovChainRV) def marginal_hmm_logp(op, values, *inputs, **kwargs): - marginalized_rvs_node = op.make_node(*inputs) - inner_rvs = clone_replace( - op.inner_outputs, - replace={u: v for u, v in zip(op.inner_inputs, marginalized_rvs_node.inputs)}, - ) + chain_rv, *dependent_rvs = inline_ofg_outputs(op, inputs) - chain_rv, *dependent_rvs = inner_rvs P, n_steps_, init_dist_, rng = chain_rv.owner.inputs domain = pt.arange(P.shape[-1], dtype="int32") @@ -145,8 +222,10 @@ def marginal_hmm_logp(op, values, *inputs, **kwargs): logp_emissions_dict = conditional_logp(dict(zip(dependent_rvs, values))) # Reduce and add the batch dims beyond the chain dimension - reduced_logp_emissions = _add_reduce_batch_dependent_logps( - chain_rv.type, logp_emissions_dict.values() + reduced_logp_emissions = reduce_batch_dependent_logps( + dependent_dims_connections=op.dims_connections, + dependent_ops=[dependent_rv.owner.op for dependent_rv in dependent_rvs], + dependent_logps=[logp_emissions_dict[value] for value in values], ) # Add a batch dimension for the domain of the chain @@ -185,7 +264,13 @@ def step_alpha(logp_emission, log_alpha, log_P): # Final logp is just the sum of the last scan state 
joint_logp = pt.logsumexp(log_alpha_seq[-1], axis=0) + # Align logp with non-collapsed batch dimensions of first RV + remaining_dims_first_emission = list(op.dims_connections[0]) + # The last dim of chain_rv was removed when computing the logp + remaining_dims_first_emission.remove(chain_rv.type.ndim - 1) + joint_logp = align_logp_dims(remaining_dims_first_emission, joint_logp) + # If there are multiple emission streams, we have to add dummy logps for the remaining value variables. The first - # return is the joint probability of everything together, but PyMC still expects one logp for each one. - dummy_logps = (pt.constant(0),) * (len(values) - 1) + # return is the joint probability of everything together, but PyMC still expects one logp for each emission stream. + dummy_logps = (DUMMY_ZERO,) * (len(values) - 1) return joint_logp, *dummy_logps diff --git a/pymc_experimental/model/marginal/graph_analysis.py b/pymc_experimental/model/marginal/graph_analysis.py index 58c11a1e..62ac2abb 100644 --- a/pymc_experimental/model/marginal/graph_analysis.py +++ b/pymc_experimental/model/marginal/graph_analysis.py @@ -1,8 +1,22 @@ +import itertools + +from collections.abc import Sequence +from itertools import zip_longest + +from pymc import SymbolicRandomVariable from pytensor.compile import SharedVariable -from pytensor.graph import Constant, FunctionGraph, ancestors -from pytensor.tensor import TensorVariable -from pytensor.tensor.elemwise import DimShuffle, Elemwise +from pytensor.graph import Constant, Variable, ancestors +from pytensor.graph.basic import io_toposort +from pytensor.tensor import TensorType, TensorVariable +from pytensor.tensor.blockwise import Blockwise +from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise +from pytensor.tensor.random.op import RandomVariable +from pytensor.tensor.rewriting.subtensor import is_full_slice from pytensor.tensor.shape import Shape +from pytensor.tensor.subtensor import AdvancedSubtensor, Subtensor, get_idx_list +from pytensor.tensor.type_other import NoneTypeT + +from pymc_experimental.model.marginal.distributions import MarginalRV def static_shape_ancestors(vars): @@ -48,45 +62,311 @@ def find_conditional_dependent_rvs(dependable_rv, all_rvs): ] -def is_elemwise_subgraph(rv_to_marginalize, other_input_rvs, output_rvs): - # TODO: No need to consider apply nodes outside the subgraph... - fg = FunctionGraph(outputs=output_rvs, clone=False) - - non_elemwise_blockers = [ - o - for node in fg.apply_nodes - if not ( - isinstance(node.op, Elemwise) - # Allow expand_dims on the left - or ( - isinstance(node.op, DimShuffle) - and not node.op.drop - and node.op.shuffle == sorted(node.op.shuffle) - ) - ) - for o in node.outputs - ] - blocker_candidates = [rv_to_marginalize, *other_input_rvs, *non_elemwise_blockers] - blockers = [var for var in blocker_candidates if var not in output_rvs] +def get_support_axes(op) -> tuple[tuple[int, ...], ...]: + if isinstance(op, MarginalRV): + return op.support_axes + else: + # For vanilla RVs, the support axes are the last ndim_supp + return (tuple(range(-op.ndim_supp, 0)),) - truncated_inputs = [ - var - for var in ancestors(output_rvs, blockers=blockers) - if ( - var in blockers - or (var.owner is None and not isinstance(var, Constant | SharedVariable)) - ) + +def _advanced_indexing_axis_and_ndim(idxs) -> tuple[int, int]: + """Find the output axis and dimensionality of the advanced indexing group (i.e., array indexing). 
+ + There is a special case: when there are non-consecutive advanced indexing groups, the advanced indexing + group is always moved to the front. + + See: https://numpy.org/doc/stable/user/basics.indexing.html#combining-advanced-and-basic-indexing + """ + adv_group_axis = None + simple_group_after_adv = False + for axis, idx in enumerate(idxs): + if isinstance(idx.type, TensorType): + if simple_group_after_adv: + # Special non-consecutive case + adv_group_axis = 0 + break + elif adv_group_axis is None: + adv_group_axis = axis + elif adv_group_axis is not None: + # Special non-consecutive case + simple_group_after_adv = True + + adv_group_ndim = max(idx.type.ndim for idx in idxs if isinstance(idx.type, TensorType)) + return adv_group_axis, adv_group_ndim + + +DIMS = tuple[int | None, ...] +VAR_DIMS = dict[Variable, DIMS] + + +def _broadcast_dims( + inputs_dims: Sequence[DIMS], +) -> DIMS: + output_ndim = max((len(input_dim) for input_dim in inputs_dims), default=0) + + # Add missing dims + inputs_dims = [ + (None,) * (output_ndim - len(input_dim)) + input_dim for input_dim in inputs_dims ] - # Check that we reach the marginalized rv following a pure elemwise graph - if rv_to_marginalize not in truncated_inputs: - return False - - # Check that none of the truncated inputs depends on the marginalized_rv - other_truncated_inputs = [inp for inp in truncated_inputs if inp is not rv_to_marginalize] - # TODO: We don't need to go all the way to the root variables - if rv_to_marginalize in ancestors( - other_truncated_inputs, blockers=[rv_to_marginalize, *other_input_rvs] - ): - return False - return True + # Find which known dims show in the output, while checking no mixing + output_dims = [] + for inputs_dim in zip(*inputs_dims): + output_dim = None + for input_dim in inputs_dim: + if input_dim is None: + continue + if output_dim is not None and output_dim != input_dim: + raise ValueError("Different known dimensions mixed via broadcasting") + output_dim = input_dim + output_dims.append(output_dim) + + # Check for duplicates + known_dims = [dim for dim in output_dims if dim is not None] + if len(known_dims) > len(set(known_dims)): + raise ValueError("Same known dimension used in different axis after broadcasting") + + return tuple(output_dims) + + +def _subgraph_batch_dim_connection(var_dims: VAR_DIMS, input_vars, output_vars) -> VAR_DIMS: + for node in io_toposort(input_vars, output_vars): + inputs_dims = [ + var_dims.get(inp, ((None,) * inp.type.ndim) if hasattr(inp.type, "ndim") else ()) + for inp in node.inputs + ] + + if all(dim is None for input_dims in inputs_dims for dim in input_dims): + # None of the inputs are related to the batch_axes of the input_vars + continue + + elif isinstance(node.op, DimShuffle): + [input_dims] = inputs_dims + output_dims = tuple(None if i == "x" else input_dims[i] for i in node.op.new_order) + var_dims[node.outputs[0]] = output_dims + + elif isinstance(node.op, MarginalRV) or ( + isinstance(node.op, SymbolicRandomVariable) and node.op.extended_signature is None + ): + # MarginalRV and SymbolicRandomVariables without signature are a wild-card, + # so we need to introspect the inner graph. 
+ op = node.op + inner_inputs = op.inner_inputs + inner_outputs = op.inner_outputs + + inner_var_dims = _subgraph_batch_dim_connection( + dict(zip(inner_inputs, inputs_dims)), inner_inputs, inner_outputs + ) + + support_axes = iter(get_support_axes(op)) + if isinstance(op, MarginalRV): + # The first output is the marginalized variable for which we don't compute support axes + support_axes = itertools.chain(((),), support_axes) + for i, (out, inner_out) in enumerate(zip(node.outputs, inner_outputs)): + if not isinstance(out.type, TensorType): + continue + support_axes_out = next(support_axes) + + if inner_out in inner_var_dims: + out_dims = inner_var_dims[inner_out] + if any( + dim is not None for dim in (out_dims[axis] for axis in support_axes_out) + ): + raise ValueError(f"Known dim corresponds to core dimension of {node.op}") + var_dims[out] = out_dims + + elif isinstance(node.op, Elemwise | Blockwise | RandomVariable | SymbolicRandomVariable): + # NOTE: User-provided CustomDist may not respect core dimensions on the left. + + if isinstance(node.op, Elemwise): + op_batch_ndim = node.outputs[0].type.ndim + else: + op_batch_ndim = node.op.batch_ndim(node) + + if isinstance(node.op, SymbolicRandomVariable): + # SymbolicRandomVariable don't have explicit expand_dims unlike the other Ops considered in this + [_, _, param_idxs], _ = node.op.get_input_output_type_idxs( + node.op.extended_signature + ) + for param_idx, param_core_ndim in zip(param_idxs, node.op.ndims_params): + param_dims = inputs_dims[param_idx] + missing_ndim = op_batch_ndim - (len(param_dims) - param_core_ndim) + inputs_dims[param_idx] = (None,) * missing_ndim + param_dims + + if any( + dim is not None for input_dim in inputs_dims for dim in input_dim[op_batch_ndim:] + ): + raise ValueError( + f"Use of known dimensions as core dimensions of op {node.op} not supported." + ) + + batch_dims = _broadcast_dims( + tuple(input_dims[:op_batch_ndim] for input_dims in inputs_dims) + ) + for out in node.outputs: + if isinstance(out.type, TensorType): + core_ndim = out.type.ndim - op_batch_ndim + output_dims = batch_dims + (None,) * core_ndim + var_dims[out] = output_dims + + elif isinstance(node.op, CAReduce): + [input_dims] = inputs_dims + + axes = node.op.axis + if isinstance(axes, int): + axes = (axes,) + elif axes is None: + axes = tuple(range(node.inputs[0].type.ndim)) + + if any(input_dims[axis] for axis in axes): + raise ValueError( + f"Use of known dimensions as reduced dimensions of op {node.op} not supported." + ) + + output_dims = [dims for i, dims in enumerate(input_dims) if i not in axes] + var_dims[node.outputs[0]] = tuple(output_dims) + + elif isinstance(node.op, Subtensor): + value_dims, *keys_dims = inputs_dims + # Dims in basic indexing must belong to the value variable, since indexing keys are always scalar + assert not any(dim is None for dim in keys_dims) + keys = get_idx_list(node.inputs, node.op.idx_list) + + output_dims = [] + for value_dim, idx in zip_longest(value_dims, keys, fillvalue=slice(None)): + if idx == slice(None): + # Dim is kept + output_dims.append(value_dim) + elif value_dim is not None: + raise ValueError( + "Partial slicing or indexing of known dimensions not supported." + ) + elif isinstance(idx, slice): + # Unknown dimensions kept by partial slice. 
+                    output_dims.append(None)
+
+            var_dims[node.outputs[0]] = tuple(output_dims)
+
+        elif isinstance(node.op, AdvancedSubtensor):
+            # AdvancedSubtensor dimensions can show up in both the indexed variable and the indexing variables
+            value, *keys = node.inputs
+            value_dims, *keys_dims = inputs_dims
+
+            # Just to stay sane, we forbid any boolean indexing...
+            if any(isinstance(idx.type, TensorType) and idx.type.dtype == "bool" for idx in keys):
+                raise NotImplementedError(
+                    f"Array indexing with boolean variables in node {node} not supported."
+                )
+
+            if any(dim is not None for dim in value_dims) and any(
+                dim is not None for key_dims in keys_dims for dim in key_dims
+            ):
+                # Both the indexed variable and the indexing variables have known dimensions
+                # I am too lazy to think through these, so we raise for now.
+                raise NotImplementedError(
+                    f"Simultaneous use of known dimensions in indexed and indexing variables in node {node} not supported."
+                )
+
+            adv_group_axis, adv_group_ndim = _advanced_indexing_axis_and_ndim(keys)
+
+            if any(dim is not None for dim in value_dims):
+                # Indexed variable has known dimensions
+
+                if any(isinstance(idx.type, NoneTypeT) for idx in keys):
+                    # Corresponds to an expand_dims, for now not supported
+                    raise NotImplementedError(
+                        f"Advanced indexing in node {node} which introduces new axis is not supported."
+                    )
+
+                non_adv_dims = []
+                for value_dim, idx in zip_longest(value_dims, keys, fillvalue=slice(None)):
+                    if is_full_slice(idx):
+                        non_adv_dims.append(value_dim)
+                    elif value_dim is not None:
+                        # We are trying to partially slice or index a known dimension
+                        raise ValueError(
+                            "Partial slicing or advanced integer indexing of known dimensions not supported."
+                        )
+                    elif isinstance(idx, slice):
+                        # Unknown dimensions kept by partial slice.
+                        non_adv_dims.append(None)
+
+                # Insert unknown dimensions corresponding to advanced indexing
+                output_dims = tuple(
+                    non_adv_dims[:adv_group_axis]
+                    + [None] * adv_group_ndim
+                    + non_adv_dims[adv_group_axis:]
+                )
+
+            else:
+                # Indexing keys have known dimensions.
+                # Only array indices can have dimensions, the rest are just slices or newaxis
+
+                # Advanced indexing variables broadcast together, so we apply the same rules as in Elemwise
+                adv_dims = _broadcast_dims(keys_dims)
+
+                start_non_adv_dims = (None,) * adv_group_axis
+                end_non_adv_dims = (None,) * (
+                    node.outputs[0].type.ndim - adv_group_axis - adv_group_ndim
+                )
+                output_dims = start_non_adv_dims + adv_dims + end_non_adv_dims
+
+            var_dims[node.outputs[0]] = output_dims
+
+        else:
+            raise NotImplementedError(f"Marginalization through operation {node} not supported.")
+
+    return var_dims
+
+
+def subgraph_batch_dim_connection(input_var, output_vars) -> list[DIMS]:
+    """Identify how the batch dims of the input map to the batch dimensions of the output_vars.
+
+    Example:
+    -------
+    In the example below `idx` has two batch dimensions (indexed 0, 1 from left to right).
+    The two uncommented dependent variables each have 2 batch dimensions where each entry
+    results from a mapping of a single entry from one of these batch dimensions.
+
+    This mapping is transposed in the case of the first dependent variable, and shows up in
+    the same order for the second dependent variable. Each of the variables has a further
+    batch dimension encoded as `None`.
+
+    The commented out third dependent variable combines information from the batch dimensions
+    of `idx` via the `sum` operation. A `ValueError` would be raised if we requested the
+    connection of batch dims.
+
+    .. code-block:: python
+
+        import pymc as pm
+
+        idx = pm.Bernoulli.dist(shape=(3, 2))
+        dep1 = pm.Normal.dist(mu=idx.T[..., None] * 2, shape=(3, 2, 5))
+        dep2 = pm.Normal.dist(mu=idx * 2, shape=(7, 2, 3))
+        # dep3 = pm.Normal.dist(mu=idx.sum())  # Would raise if requested
+
+        print(subgraph_batch_dim_connection(idx, [dep1, dep2]))
+        # [(1, 0, None), (None, 0, 1)]
+
+    Returns:
+    -------
+    list of tuples
+        Each tuple corresponds to the batch dimensions of the output_rv in the order they are found in the output.
+        None is used to indicate a batch dimension that is not mapped from the input.
+
+    Raises:
+    ------
+    ValueError
+        If input batch dimensions are mixed in the graph leading to output_vars.
+
+    NotImplementedError
+        If a variable related to the marginalized batch_dims is used in an operation that is not yet supported.
+    """
+    var_dims = {input_var: tuple(range(input_var.type.ndim))}
+    var_dims = _subgraph_batch_dim_connection(var_dims, [input_var], output_vars)
+    ret = []
+    for output_var in output_vars:
+        output_dims = var_dims.get(output_var, (None,) * output_var.type.ndim)
+        assert len(output_dims) == output_var.type.ndim
+        ret.append(output_dims)
+    return ret
diff --git a/pymc_experimental/model/marginal/marginal_model.py b/pymc_experimental/model/marginal/marginal_model.py
index 94c577c4..b4700c3d 100644
--- a/pymc_experimental/model/marginal/marginal_model.py
+++ b/pymc_experimental/model/marginal/marginal_model.py
@@ -16,8 +16,7 @@
 from pymc.pytensorf import compile_pymc, constant_fold
 from pymc.util import RandomState, _get_seeds_per_chain, treedict
 from pytensor.compile import SharedVariable
-from pytensor.graph import FunctionGraph, clone_replace
-from pytensor.graph.basic import graph_inputs
+from pytensor.graph import FunctionGraph, clone_replace, graph_inputs
 from pytensor.graph.replace import vectorize_graph
 from pytensor.tensor import TensorVariable
 from pytensor.tensor.special import log_softmax
@@ -26,16 +25,16 @@
 from pymc_experimental.distributions import DiscreteMarkovChain
 from pymc_experimental.model.marginal.distributions import (
-    DiscreteMarginalMarkovChainRV,
-    FiniteDiscreteMarginalRV,
-    _add_reduce_batch_dependent_logps,
+    MarginalDiscreteMarkovChainRV,
+    MarginalFiniteDiscreteRV,
     get_domain_of_finite_discrete_rv,
+    reduce_batch_dependent_logps,
 )
 from pymc_experimental.model.marginal.graph_analysis import (
     find_conditional_dependent_rvs,
     find_conditional_input_rvs,
     is_conditional_dependent,
-    is_elemwise_subgraph,
+    subgraph_batch_dim_connection,
 )

 ModelRVs = TensorVariable | Sequence[TensorVariable] | str | Sequence[str]
@@ -424,17 +423,22 @@ def transform_input(inputs):
         m = self.clone()
         marginalized_rv = m.vars_to_clone[marginalized_rv]
         m.unmarginalize([marginalized_rv])
-        dependent_vars = find_conditional_dependent_rvs(marginalized_rv, m.basic_RVs)
-        joint_logps = m.logp(vars=[marginalized_rv, *dependent_vars], sum=False)
-
-        marginalized_value = m.rvs_to_values[marginalized_rv]
-        other_values = [v for v in m.value_vars if v is not marginalized_value]
+        dependent_rvs = find_conditional_dependent_rvs(marginalized_rv, m.basic_RVs)
+        logps = m.logp(vars=[marginalized_rv, *dependent_rvs], sum=False)

         # Handle batch dims for marginalized value and its dependent RVs
-        marginalized_logp, *dependent_logps = joint_logps
-        joint_logp = marginalized_logp + _add_reduce_batch_dependent_logps(
-            marginalized_rv.type, dependent_logps
+        dependent_rvs_dim_connections = subgraph_batch_dim_connection(
+            marginalized_rv, dependent_rvs
         )
+        marginalized_logp, *dependent_logps = 
logps + joint_logp = marginalized_logp + reduce_batch_dependent_logps( + dependent_rvs_dim_connections, + [dependent_var.owner.op for dependent_var in dependent_rvs], + dependent_logps, + ) + + marginalized_value = m.rvs_to_values[marginalized_rv] + other_values = [v for v in m.value_vars if v is not marginalized_value] rv_shape = constant_fold(tuple(marginalized_rv.shape), raise_not_constant=False) rv_domain = get_domain_of_finite_discrete_rv(marginalized_rv) @@ -448,37 +452,30 @@ def transform_input(inputs): 0, ) - joint_logps = vectorize_graph( + batched_joint_logp = vectorize_graph( joint_logp, replace={marginalized_value: rv_domain_tensor}, ) - joint_logps = pt.moveaxis(joint_logps, 0, -1) + batched_joint_logp = pt.moveaxis(batched_joint_logp, 0, -1) - rv_loglike_fn = None - joint_logps_norm = log_softmax(joint_logps, axis=-1) + joint_logp_norm = log_softmax(batched_joint_logp, axis=-1) if return_samples: - sample_rv_outs = pymc.Categorical.dist(logit_p=joint_logps) + rv_draws = pymc.Categorical.dist(logit_p=batched_joint_logp) if isinstance(marginalized_rv.owner.op, DiscreteUniform): - sample_rv_outs += rv_domain[0] - - rv_loglike_fn = compile_pymc( - inputs=other_values, - outputs=[joint_logps_norm, sample_rv_outs], - on_unused_input="ignore", - random_seed=seed, - ) + rv_draws += rv_domain[0] + outputs = [joint_logp_norm, rv_draws] else: - rv_loglike_fn = compile_pymc( - inputs=other_values, - outputs=joint_logps_norm, - on_unused_input="ignore", - random_seed=seed, - ) + outputs = joint_logp_norm + + rv_loglike_fn = compile_pymc( + inputs=other_values, + outputs=outputs, + on_unused_input="ignore", + random_seed=seed, + ) logvs = [rv_loglike_fn(**vs) for vs in posterior_pts] - logps = None - samples = None if return_samples: logps, samples = zip(*logvs) logps = np.array(logps) @@ -552,61 +549,47 @@ def collect_shared_vars(outputs, blockers): def replace_finite_discrete_marginal_subgraph(fgraph, rv_to_marginalize, all_rvs): - # TODO: This should eventually be integrated in a more general routine that can - # identify other types of supported marginalization, of which finite discrete - # RVs is just one - dependent_rvs = find_conditional_dependent_rvs(rv_to_marginalize, all_rvs) if not dependent_rvs: raise ValueError(f"No RVs depend on marginalized RV {rv_to_marginalize}") - ndim_supp = {rv.owner.op.ndim_supp for rv in dependent_rvs} - if len(ndim_supp) != 1: - raise NotImplementedError( - "Marginalization with dependent variables of different support dimensionality not implemented" - ) - [ndim_supp] = ndim_supp - if ndim_supp > 0: - raise NotImplementedError("Marginalization with dependent Multivariate RVs not implemented") - marginalized_rv_input_rvs = find_conditional_input_rvs([rv_to_marginalize], all_rvs) - dependent_rvs_input_rvs = [ + other_direct_rv_ancestors = [ rv for rv in find_conditional_input_rvs(dependent_rvs, all_rvs) if rv is not rv_to_marginalize ] - # If the marginalized RV has batched dimensions, check that graph between - # marginalized RV and dependent RVs is composed strictly of Elemwise Operations. - # This implies (?) that the dimensions are completely independent and a logp graph - # can ultimately be generated that is proportional to the support domain and not - # to the variables dimensions - # We don't need to worry about this if the RV is scalar. 
- if np.prod(constant_fold(tuple(rv_to_marginalize.shape), raise_not_constant=False)) != 1: - if not is_elemwise_subgraph(rv_to_marginalize, dependent_rvs_input_rvs, dependent_rvs): - raise NotImplementedError( - "The subgraph between a marginalized RV and its dependents includes non Elemwise operations. " - "This is currently not supported", - ) + # If the marginalized RV has multiple dimensions, check that graph between + # marginalized RV and dependent RVs does not mix information from batch dimensions + # (otherwise logp would require enumerating over all combinations of batch dimension values) + try: + dependent_rvs_dim_connections = subgraph_batch_dim_connection( + rv_to_marginalize, dependent_rvs + ) + except (ValueError, NotImplementedError) as e: + # For the perspective of the user this is a NotImplementedError + raise NotImplementedError( + "The graph between the marginalized and dependent RVs cannot be marginalized efficiently. " + "You can try splitting the marginalized RV into separate components and marginalizing them separately." + ) from e - input_rvs = list(set((*marginalized_rv_input_rvs, *dependent_rvs_input_rvs))) - rvs_to_marginalize = [rv_to_marginalize, *dependent_rvs] + input_rvs = list(set((*marginalized_rv_input_rvs, *other_direct_rv_ancestors))) + output_rvs = [rv_to_marginalize, *dependent_rvs] - outputs = rvs_to_marginalize # We are strict about shared variables in SymbolicRandomVariables - inputs = input_rvs + collect_shared_vars(rvs_to_marginalize, blockers=input_rvs) + inputs = input_rvs + collect_shared_vars(output_rvs, blockers=input_rvs) if isinstance(rv_to_marginalize.owner.op, DiscreteMarkovChain): - marginalize_constructor = DiscreteMarginalMarkovChainRV + marginalize_constructor = MarginalDiscreteMarkovChainRV else: - marginalize_constructor = FiniteDiscreteMarginalRV + marginalize_constructor = MarginalFiniteDiscreteRV marginalization_op = marginalize_constructor( inputs=inputs, - outputs=outputs, - ndim_supp=ndim_supp, + outputs=output_rvs, # TODO: Add RNG updates to outputs so this can be used in the generative graph + dims_connections=dependent_rvs_dim_connections, ) - - marginalized_rvs = marginalization_op(*inputs) - fgraph.replace_all(tuple(zip(rvs_to_marginalize, marginalized_rvs))) - return rvs_to_marginalize, marginalized_rvs + new_output_rvs = marginalization_op(*inputs) + fgraph.replace_all(tuple(zip(output_rvs, new_output_rvs))) + return output_rvs, new_output_rvs diff --git a/requirements.txt b/requirements.txt index a7141a82..b992ad37 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,2 +1,2 @@ -pymc>=5.16.1 +pymc>=5.17.0 scikit-learn diff --git a/tests/model/marginal/test_distributions.py b/tests/model/marginal/test_distributions.py index 7c0e0fd5..ecbc8817 100644 --- a/tests/model/marginal/test_distributions.py +++ b/tests/model/marginal/test_distributions.py @@ -8,7 +8,7 @@ from pymc_experimental import MarginalModel from pymc_experimental.distributions import DiscreteMarkovChain -from pymc_experimental.model.marginal.distributions import FiniteDiscreteMarginalRV +from pymc_experimental.model.marginal.distributions import MarginalFiniteDiscreteRV def test_marginalized_bernoulli_logp(): @@ -17,13 +17,10 @@ def test_marginalized_bernoulli_logp(): idx = pm.Bernoulli.dist(0.7, name="idx") y = pm.Normal.dist(mu=mu[idx], sigma=1.0, name="y") - marginal_rv_node = FiniteDiscreteMarginalRV( + marginal_rv_node = MarginalFiniteDiscreteRV( [mu], [idx, y], - ndim_supp=0, - n_updates=0, - # Ignore the fact we didn't specify shared RNG 
input/outputs for idx,y - strict=False, + dims_connections=(((),),), )(mu)[0].owner y_vv = y.clone() @@ -78,9 +75,7 @@ def test_marginalized_hmm_categorical_emission(categorical_emission): init_dist = pm.Categorical.dist(p=[0.375, 0.625]) chain = DiscreteMarkovChain("chain", P=P, init_dist=init_dist, steps=2) if categorical_emission: - emission = pm.Categorical( - "emission", p=pt.where(pt.eq(chain, 0)[..., None], [0.8, 0.2], [0.4, 0.6]) - ) + emission = pm.Categorical("emission", p=pt.constant([[0.8, 0.2], [0.4, 0.6]])[chain]) else: emission = pm.Bernoulli("emission", p=pt.where(pt.eq(chain, 0), 0.2, 0.6)) m.marginalize([chain]) @@ -91,29 +86,46 @@ def test_marginalized_hmm_categorical_emission(categorical_emission): np.testing.assert_allclose(logp_fn({"emission": test_value}), expected_logp) +@pytest.mark.parametrize("batch_chain", (False, True)) @pytest.mark.parametrize("batch_emission1", (False, True)) @pytest.mark.parametrize("batch_emission2", (False, True)) -def test_marginalized_hmm_multiple_emissions(batch_emission1, batch_emission2): - emission1_shape = (2, 4) if batch_emission1 else (4,) - emission2_shape = (2, 4) if batch_emission2 else (4,) +def test_marginalized_hmm_multiple_emissions(batch_chain, batch_emission1, batch_emission2): + chain_shape = (3, 1, 4) if batch_chain else (4,) + emission1_shape = ( + (2, *reversed(chain_shape)) if batch_emission1 else tuple(reversed(chain_shape)) + ) + emission2_shape = (*chain_shape, 2) if batch_emission2 else chain_shape with MarginalModel() as m: P = [[0, 1], [1, 0]] init_dist = pm.Categorical.dist(p=[1, 0]) - chain = DiscreteMarkovChain("chain", P=P, init_dist=init_dist, steps=3) - emission_1 = pm.Normal("emission_1", mu=chain * 2 - 1, sigma=1e-1, shape=emission1_shape) - emission_2 = pm.Normal( - "emission_2", mu=(1 - chain) * 2 - 1, sigma=1e-1, shape=emission2_shape + chain = DiscreteMarkovChain("chain", P=P, init_dist=init_dist, shape=chain_shape) + emission_1 = pm.Normal( + "emission_1", mu=(chain * 2 - 1).T, sigma=1e-1, shape=emission1_shape ) + emission2_mu = (1 - chain) * 2 - 1 + if batch_emission2: + emission2_mu = emission2_mu[..., None] + emission_2 = pm.Normal("emission_2", mu=emission2_mu, sigma=1e-1, shape=emission2_shape) + with pytest.warns(UserWarning, match="multiple dependent variables"): m.marginalize([chain]) - logp_fn = m.compile_logp() + logp_fn = m.compile_logp(sum=False) test_value = np.array([-1, 1, -1, 1]) multiplier = 2 + batch_emission1 + batch_emission2 + if batch_chain: + multiplier *= 3 expected_logp = norm.logpdf(np.zeros_like(test_value), 0, 1e-1).sum() * multiplier - test_value_emission1 = np.broadcast_to(test_value, emission1_shape) - test_value_emission2 = np.broadcast_to(-test_value, emission2_shape) + + test_value = np.broadcast_to(test_value, chain_shape) + test_value_emission1 = np.broadcast_to(test_value.T, emission1_shape) + if batch_emission2: + test_value_emission2 = np.broadcast_to(-test_value[..., None], emission2_shape) + else: + test_value_emission2 = np.broadcast_to(-test_value, emission2_shape) test_point = {"emission_1": test_value_emission1, "emission_2": test_value_emission2} - np.testing.assert_allclose(logp_fn(test_point), expected_logp) + res_logp, dummy_logp = logp_fn(test_point) + assert res_logp.shape == ((1, 3) if batch_chain else ()) + np.testing.assert_allclose(res_logp.sum(), expected_logp) diff --git a/tests/model/marginal/test_graph_analysis.py b/tests/model/marginal/test_graph_analysis.py index 58d65dbe..2382247b 100644 --- a/tests/model/marginal/test_graph_analysis.py 
+++ b/tests/model/marginal/test_graph_analysis.py @@ -1,6 +1,13 @@ -from pytensor import tensor as pt +import pytensor.tensor as pt +import pytest -from pymc_experimental.model.marginal.graph_analysis import is_conditional_dependent +from pymc.distributions import CustomDist +from pytensor.tensor.type_other import NoneTypeT + +from pymc_experimental.model.marginal.graph_analysis import ( + is_conditional_dependent, + subgraph_batch_dim_connection, +) def test_is_conditional_dependent_static_shape(): @@ -12,3 +19,164 @@ def test_is_conditional_dependent_static_shape(): x2 = pt.matrix("x2", shape=(9, 5)) y2 = pt.random.normal(size=pt.shape(x2)) assert not is_conditional_dependent(y2, x2, [x2, y2]) + + +class TestSubgraphBatchDimConnection: + def test_dimshuffle(self): + inp = pt.tensor(shape=(5, 1, 4, 3)) + out1 = pt.matrix_transpose(inp) + out2 = pt.expand_dims(inp, 1) + out3 = pt.squeeze(inp) + [dims1, dims2, dims3] = subgraph_batch_dim_connection(inp, [out1, out2, out3]) + assert dims1 == (0, 1, 3, 2) + assert dims2 == (0, None, 1, 2, 3) + assert dims3 == (0, 2, 3) + + def test_careduce(self): + inp = pt.tensor(shape=(4, 3, 2)) + + out = pt.sum(inp[:, None], axis=(1,)) + [dims] = subgraph_batch_dim_connection(inp, [out]) + assert dims == (0, 1, 2) + + invalid_out = pt.sum(inp, axis=(1,)) + with pytest.raises(ValueError, match="Use of known dimensions"): + subgraph_batch_dim_connection(inp, [invalid_out]) + + def test_subtensor(self): + inp = pt.tensor(shape=(4, 3, 2)) + + invalid_out = inp[0, :1] + with pytest.raises( + ValueError, + match="Partial slicing or indexing of known dimensions not supported", + ): + subgraph_batch_dim_connection(inp, [invalid_out]) + + # If we are selecting dummy / unknown dimensions that's fine + valid_out = pt.expand_dims(inp, (0, 1))[0, :1] + [dims] = subgraph_batch_dim_connection(inp, [valid_out]) + assert dims == (None, 0, 1, 2) + + def test_advanced_subtensor_value(self): + inp = pt.tensor(shape=(2, 4)) + intermediate_out = inp[:, None, :, None] + pt.zeros((2, 3, 4, 5)) + + # Index on an unlabled dim introduced by broadcasting with zeros + out = intermediate_out[:, [0, 0, 1, 2]] + [dims] = subgraph_batch_dim_connection(inp, [out]) + assert dims == (0, None, 1, None) + + # Indexing that introduces more dimensions + out = intermediate_out[:, [[0, 0], [1, 2]], :] + [dims] = subgraph_batch_dim_connection(inp, [out]) + assert dims == (0, None, None, 1, None) + + # Special case where advanced dims are moved to the front of the output + out = intermediate_out[:, [0, 0, 1, 2], :, 0] + [dims] = subgraph_batch_dim_connection(inp, [out]) + assert dims == (None, 0, 1) + + # Indexing on a labeled dim fails + out = intermediate_out[:, :, [0, 0, 1, 2]] + with pytest.raises(ValueError, match="Partial slicing or advanced integer indexing"): + subgraph_batch_dim_connection(inp, [out]) + + def test_advanced_subtensor_key(self): + inp = pt.tensor(shape=(5, 5), dtype=int) + base = pt.zeros((2, 3, 4)) + + out = base[inp] + [dims] = subgraph_batch_dim_connection(inp, [out]) + assert dims == (0, 1, None, None) + + out = base[:, :, inp] + [dims] = subgraph_batch_dim_connection(inp, [out]) + assert dims == ( + None, + None, + 0, + 1, + ) + + out = base[1:, 0, inp] + [dims] = subgraph_batch_dim_connection(inp, [out]) + assert dims == (None, 0, 1) + + # Special case where advanced dims are moved to the front of the output + out = base[0, :, inp] + [dims] = subgraph_batch_dim_connection(inp, [out]) + assert dims == (0, 1, None) + + # Mix keys dimensions + out = base[:, inp, inp.T] + 
with pytest.raises(ValueError, match="Different known dimensions mixed via broadcasting"): + subgraph_batch_dim_connection(inp, [out]) + + def test_elemwise(self): + inp = pt.tensor(shape=(5, 5)) + + out = inp + inp + [dims] = subgraph_batch_dim_connection(inp, [out]) + assert dims == (0, 1) + + out = inp + inp.T + with pytest.raises(ValueError, match="Different known dimensions mixed via broadcasting"): + subgraph_batch_dim_connection(inp, [out]) + + out = inp[None, :, None, :] + inp[:, None, :, None] + with pytest.raises( + ValueError, match="Same known dimension used in different axis after broadcasting" + ): + subgraph_batch_dim_connection(inp, [out]) + + def test_blockwise(self): + inp = pt.tensor(shape=(5, 4)) + + invalid_out = inp @ pt.ones((4, 3)) + with pytest.raises(ValueError, match="Use of known dimensions"): + subgraph_batch_dim_connection(inp, [invalid_out]) + + out = (inp[:, :, None, None] + pt.zeros((2, 3))) @ pt.ones((2, 3)) + [dims] = subgraph_batch_dim_connection(inp, [out]) + assert dims == (0, 1, None, None) + + def test_random_variable(self): + inp = pt.tensor(shape=(5, 4, 3)) + + out1 = pt.random.normal(loc=inp) + out2 = pt.random.categorical(p=inp[..., None]) + out3 = pt.random.multivariate_normal(mean=inp[..., None], cov=pt.eye(1)) + [dims1, dims2, dims3] = subgraph_batch_dim_connection(inp, [out1, out2, out3]) + assert dims1 == (0, 1, 2) + assert dims2 == (0, 1, 2) + assert dims3 == (0, 1, 2, None) + + invalid_out = pt.random.categorical(p=inp) + with pytest.raises(ValueError, match="Use of known dimensions"): + subgraph_batch_dim_connection(inp, [invalid_out]) + + invalid_out = pt.random.multivariate_normal(mean=inp, cov=pt.eye(3)) + with pytest.raises(ValueError, match="Use of known dimensions"): + subgraph_batch_dim_connection(inp, [invalid_out]) + + def test_symbolic_random_variable(self): + inp = pt.tensor(shape=(4, 3, 2)) + + # Test univariate + out = CustomDist.dist( + inp, + dist=lambda mu, size: pt.random.normal(loc=mu, size=size), + ) + [dims] = subgraph_batch_dim_connection(inp, [out]) + assert dims == (0, 1, 2) + + # Test multivariate + def dist(mu, size): + if isinstance(size.type, NoneTypeT): + size = mu.shape + return pt.random.normal(loc=mu[..., None], size=(*size, 2)) + + out = CustomDist.dist(inp, dist=dist, size=(4, 3, 2), signature="()->(2)") + [dims] = subgraph_batch_dim_connection(inp, [out]) + assert dims == (0, 1, 2, None) diff --git a/tests/model/marginal/test_marginal_model.py b/tests/model/marginal/test_marginal_model.py index a94499cf..c93cdb74 100644 --- a/tests/model/marginal/test_marginal_model.py +++ b/tests/model/marginal/test_marginal_model.py @@ -10,6 +10,7 @@ from arviz import InferenceData, dict_to_dataset from pymc.distributions import transforms +from pymc.distributions.transforms import ordered from pymc.model.fgraph import fgraph_from_model from pymc.pytensorf import inputvars from pymc.util import UNSET @@ -117,6 +118,36 @@ def test_one_to_many_marginalized_rvs(): np.testing.assert_array_almost_equal(logp_x_y, ref_logp_x_y) +def test_one_to_many_unaligned_marginalized_rvs(): + """Test that marginalization works when there is more than one dependent RV with batch dimensions that are not aligned""" + + def build_model(build_batched: bool): + with MarginalModel() as m: + if build_batched: + idx = pm.Bernoulli("idx", p=[0.75, 0.4], shape=(3, 2)) + else: + idxs = [pm.Bernoulli(f"idx_{i}", p=(0.75 if i % 2 == 0 else 0.4)) for i in range(6)] + idx = pt.stack(idxs, axis=0).reshape((3, 2)) + + x = pm.Normal("x", mu=idx.T[:, :, 
None], shape=(2, 3, 1)) + y = pm.Normal("y", mu=(idx * 2 - 1), shape=(1, 3, 2)) + + return m + + m = build_model(build_batched=True) + ref_m = build_model(build_batched=False) + + with pytest.warns(UserWarning, match="There are multiple dependent variables"): + m.marginalize(["idx"]) + ref_m.marginalize([f"idx_{i}" for i in range(6)]) + + test_point = m.initial_point() + np.testing.assert_allclose( + m.compile_logp()(test_point), + ref_m.compile_logp()(test_point), + ) + + def test_many_to_one_marginalized_rvs(): """Test when random variables depend on multiple marginalized variables""" with MarginalModel() as m: @@ -132,40 +163,127 @@ def test_many_to_one_marginalized_rvs(): np.testing.assert_allclose(np.exp(logp({"z": 2})), 0.1 * 0.3) -def test_nested_marginalized_rvs(): +@pytest.mark.parametrize("batched", (False, "left", "right")) +def test_nested_marginalized_rvs(batched): """Test that marginalization works when there are nested marginalized RVs""" - with MarginalModel() as m: - sigma = pm.HalfNormal("sigma") + def build_model(build_batched: bool) -> MarginalModel: + idx_shape = (3,) if build_batched else () + sub_idx_shape = (5,) if not build_batched else (5, 3) if batched == "left" else (3, 5) - idx = pm.Bernoulli("idx", p=0.75) - dep = pm.Normal("dep", mu=pt.switch(pt.eq(idx, 0), -1000.0, 1000.0), sigma=sigma) + with MarginalModel() as m: + sigma = pm.HalfNormal("sigma") - sub_idx = pm.Bernoulli("sub_idx", p=pt.switch(pt.eq(idx, 0), 0.15, 0.95), shape=(5,)) - sub_dep = pm.Normal("sub_dep", mu=dep + sub_idx * 100, sigma=sigma, shape=(5,)) + idx = pm.Bernoulli("idx", p=0.75, shape=idx_shape) + dep = pm.Normal("dep", mu=pt.switch(pt.eq(idx, 0), -1000.0, 1000.0), sigma=sigma) - ref_logp_fn = m.compile_logp(vars=[idx, dep, sub_idx, sub_dep]) + sub_idx_p = pt.switch(pt.eq(idx, 0), 0.15, 0.95) + if build_batched and batched == "right": + sub_idx_p = sub_idx_p[..., None] + dep = dep[..., None] + sub_idx = pm.Bernoulli("sub_idx", p=sub_idx_p, shape=sub_idx_shape) + sub_dep = pm.Normal("sub_dep", mu=dep + sub_idx * 100, sigma=sigma) - with pytest.warns(UserWarning, match="There are multiple dependent variables"): - m.marginalize([idx, sub_idx]) + return m - assert set(m.marginalized_rvs) == {idx, sub_idx} + m = build_model(build_batched=batched) + with pytest.warns(UserWarning, match="There are multiple dependent variables"): + m.marginalize(["idx", "sub_idx"]) + assert sorted(m.name for m in m.marginalized_rvs) == ["idx", "sub_idx"] # Test logp + ref_m = build_model(build_batched=False) + ref_logp_fn = ref_m.compile_logp( + vars=[ref_m["idx"], ref_m["dep"], ref_m["sub_idx"], ref_m["sub_dep"]] + ) + + test_point = ref_m.initial_point() + test_point["dep"] = np.full_like(test_point["dep"], 1000) + test_point["sub_dep"] = np.full_like(test_point["sub_dep"], 1000 + 100) + ref_logp = logsumexp( + [ + ref_logp_fn({**test_point, **{"idx": idx, "sub_idx": np.array(sub_idxs)}}) + for idx in (0, 1) + for sub_idxs in itertools.product((0, 1), repeat=5) + ] + ) + if batched: + ref_logp *= 3 + test_point = m.initial_point() - test_point["dep"] = 1000 - test_point["sub_dep"] = np.full((5,), 1000 + 100) - - ref_logp = [ - ref_logp_fn({**test_point, **{"idx": idx, "sub_idx": np.array(sub_idxs)}}) - for idx in (0, 1) - for sub_idxs in itertools.product((0, 1), repeat=5) - ] - logp = m.compile_logp(vars=[dep, sub_dep])(test_point) - - np.testing.assert_almost_equal( - logp, - logsumexp(ref_logp), + test_point["dep"] = np.full_like(test_point["dep"], 1000) + test_point["sub_dep"] = 
np.full_like(test_point["sub_dep"], 1000 + 100) + logp = m.compile_logp(vars=[m["dep"], m["sub_dep"]])(test_point) + + np.testing.assert_almost_equal(logp, ref_logp) + + +@pytest.mark.parametrize("advanced_indexing", (False, True)) +def test_marginalized_index_as_key(advanced_indexing): + """Test we can marginalize graphs where indexing is used as a mapping.""" + + w = [0.1, 0.3, 0.6] + mu = pt.as_tensor([-1, 0, 1]) + + if advanced_indexing: + y_val = pt.as_tensor([[-1, -1], [0, 1]]) + shape = (2, 2) + else: + y_val = -1 + shape = () + + with MarginalModel() as m: + x = pm.Categorical("x", p=w, shape=shape) + y = pm.Normal("y", mu[x].T, sigma=1, observed=y_val) + + m.marginalize(x) + + marginal_logp = m.compile_logp(sum=False)({})[0] + ref_logp = pm.logp(pm.NormalMixture.dist(w=w, mu=mu.T, sigma=1, shape=shape), y_val).eval() + + np.testing.assert_allclose(marginal_logp, ref_logp) + + +def test_marginalized_index_as_value_and_key(): + """Test we can marginalize graphs were marginalized_rv is indexed.""" + + def build_model(build_batched: bool) -> MarginalModel: + with MarginalModel() as m: + if build_batched: + latent_state = pm.Bernoulli("latent_state", p=0.3, size=(4,)) + else: + latent_state = pm.math.stack( + [pm.Bernoulli(f"latent_state_{i}", p=0.3) for i in range(4)] + ) + # latent state is used as the indexed variable + latent_intensities = pt.where(latent_state[:, None], [0.0, 1.0, 2.0], [0.0, 10.0, 20.0]) + picked_intensity = pm.Categorical("picked_intensity", p=[0.2, 0.2, 0.6]) + # picked intensity is used as the indexing variable + pm.Normal( + "intensity", + mu=latent_intensities[:, picked_intensity], + observed=[0.5, 1.5, 5.0, 15.0], + ) + return m + + # We compare with the equivalent but less efficient batched model + m = build_model(build_batched=True) + ref_m = build_model(build_batched=False) + + m.marginalize(["latent_state"]) + ref_m.marginalize([f"latent_state_{i}" for i in range(4)]) + test_point = {"picked_intensity": 1} + np.testing.assert_allclose( + m.compile_logp()(test_point), + ref_m.compile_logp()(test_point), + ) + + m.marginalize(["picked_intensity"]) + ref_m.marginalize(["picked_intensity"]) + test_point = {} + np.testing.assert_allclose( + m.compile_logp()(test_point), + ref_m.compile_logp()(test_point), ) @@ -229,6 +347,15 @@ def test_mixed_dims_via_support_dimension(self): with pytest.raises(NotImplementedError): m.marginalize(x) + def test_mixed_dims_via_nested_marginalization(self): + with MarginalModel() as m: + x = pm.Bernoulli("x", p=0.7, shape=(3,)) + y = pm.Bernoulli("y", p=0.7, shape=(2,)) + z = pm.Normal("z", mu=pt.add.outer(x, y), shape=(3, 2)) + + with pytest.raises(NotImplementedError): + m.marginalize([x, y]) + def test_marginalized_deterministic_and_potential(): rng = np.random.default_rng(299) @@ -531,6 +658,62 @@ def dist(idx, size): pt = {"norm": test_value} np.testing.assert_allclose(logp_fn(pt), ref_logp_fn(pt)) + def test_k_censored_clusters_model(self): + def build_model(build_batched: bool) -> MarginalModel: + data = np.array([[-1.0, -1.0], [0.0, 0.0], [1.0, 1.0]]) + nobs = data.shape[0] + n_clusters = 5 + coords = { + "cluster": range(n_clusters), + "ndim": ("x", "y"), + "obs": range(nobs), + } + with MarginalModel(coords=coords) as m: + if build_batched: + idx = pm.Categorical("idx", p=np.ones(n_clusters) / n_clusters, dims=["obs"]) + else: + idx = pm.math.stack( + [ + pm.Categorical(f"idx_{i}", p=np.ones(n_clusters) / n_clusters) + for i in range(nobs) + ] + ) + + mu_x = pm.Normal( + "mu_x", + dims=["cluster"], + 
transform=ordered, + initval=np.linspace(-1, 1, n_clusters), + ) + mu_y = pm.Normal("mu_y", dims=["cluster"]) + mu = pm.math.stack([mu_x, mu_y], axis=-1) # (cluster, ndim) + mu_indexed = mu[idx, :] + + sigma = pm.HalfNormal("sigma") + + y = pm.Censored( + "y", + dist=pm.Normal.dist(mu_indexed, sigma), + lower=-3, + upper=3, + observed=data, + dims=["obs", "ndim"], + ) + + return m + + m = build_model(build_batched=True) + ref_m = build_model(build_batched=False) + + m.marginalize([m["idx"]]) + ref_m.marginalize([n for n in ref_m.named_vars if n.startswith("idx_")]) + + test_point = m.initial_point() + np.testing.assert_almost_equal( + m.compile_logp()(test_point), + ref_m.compile_logp()(test_point), + ) + class TestRecoverMarginals: def test_basic(self): @@ -608,7 +791,7 @@ def test_batched(self): with MarginalModel() as m: sigma = pm.HalfNormal("sigma") idx = pm.Bernoulli("idx", p=0.7, shape=(3, 2)) - y = pm.Normal("y", mu=idx, sigma=sigma, shape=(3, 2)) + y = pm.Normal("y", mu=idx.T, sigma=sigma, shape=(2, 3)) m.marginalize([idx]) @@ -626,10 +809,9 @@ def test_batched(self): idata = m.recover_marginals(idata, return_samples=True) post = idata.posterior - assert "idx" in post - assert "lp_idx" in post - assert post.idx.shape == post.y.shape - assert post.lp_idx.shape == (*post.idx.shape, 2) + assert post["y"].shape == (1, 20, 2, 3) + assert post["idx"].shape == (1, 20, 3, 2) + assert post["lp_idx"].shape == (1, 20, 3, 2, 2) def test_nested(self): """Test that marginalization works when there are nested marginalized RVs"""
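Usage sketch (illustrative only, not part of the patch): the minimal example below, adapted from test_one_to_many_unaligned_marginalized_rvs above and assuming this branch of pymc_experimental is installed, shows the behaviour the changes enable, namely marginalizing a batched discrete RV whose dependent RVs have transposed or expanded batch dimensions. Variable names are illustrative.

    import pymc as pm

    from pymc_experimental import MarginalModel

    with MarginalModel() as m:
        # Batched Bernoulli RV that will be marginalized out
        idx = pm.Bernoulli("idx", p=[0.75, 0.4], shape=(3, 2))
        # Dependent RVs whose batch dimensions are transposed / expanded relative to idx;
        # before this change only plain Elemwise subgraphs between them were supported
        x = pm.Normal("x", mu=idx.T[:, :, None], shape=(2, 3, 1))
        y = pm.Normal("y", mu=idx * 2 - 1, shape=(1, 3, 2))

    # Emits a UserWarning about multiple dependent variables and replaces idx with a
    # MarginalFiniteDiscreteRV that records how its batch dims map to each dependent RV
    m.marginalize(["idx"])

    point = m.initial_point()
    print(m.compile_logp()(point))  # logp with idx summed out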