diff --git a/funsor/gaussian.py b/funsor/gaussian.py index 50d5dfc2..8874a0c6 100644 --- a/funsor/gaussian.py +++ b/funsor/gaussian.py @@ -965,6 +965,35 @@ def eager_reduce(self, op, reduced_vars): return Gaussian(white_vec, prec_sqrt, inputs) + elif op is ops.max: + # Marginalize out real variables, but keep mixtures lazy. + assert all(v in self.inputs for v in reduced_vars) + real_vars = frozenset( + k for k, d in self.inputs.items() if d.dtype == "real" + ) + reduced_reals = reduced_vars & real_vars + reduced_ints = reduced_vars - real_vars + if reduced_ints: + raise NotImplementedError("TODO argmax over Gaussian mixtures") + + inputs = OrderedDict( + (k, d) for k, d in self.inputs.items() if k not in reduced_reals + ) + int_inputs = OrderedDict( + (k, v) for k, v in inputs.items() if v.dtype != "real" + ) + if reduced_reals == real_vars: + if self.rank <= self.prec_sqrt.shape[-2]: + return 0.0 + # Otherwise compress. + white_vec, prec_sqrt, shift = _compress_rank( + self.white_vec, self.prec_sqrt + ) + return Tensor(shift, int_inputs) + + # FIXME + raise NotImplementedError("TODO partial max") + return None # defer to default implementation def _sample(self, sampled_vars, sample_inputs, rng_key): diff --git a/funsor/recipes.py b/funsor/recipes.py index 1e52b2ee..d93b81e5 100644 --- a/funsor/recipes.py +++ b/funsor/recipes.py @@ -82,3 +82,65 @@ def forward_filter_backward_rsample( assert set(log_prob.inputs) == set(sample_inputs) return samples, log_prob + + +def forward_max_backward_argmax( + factors: Dict[str, funsor.Funsor], + eliminate: FrozenSet[str], + plates: FrozenSet[str], +): + """ + A forward-max backward-argmax algorithm for maximum a posteriori (MAP) + estimation. The motivating use case is performing Gaussian + tensor variable elimination to compute modes of structured posteriors. + + :param dict factors: A dictionary mapping sample site name to a Funsor + factor created at that sample site. 
+ :param frozenset eliminate: A set of names of latent variables to maximize + over and plates to aggregate. + :param plates: A set of names of plates to aggregate. + :returns: A pair ``samples:Dict[str, Tensor], log_prob: Tensor`` of argmax + values and the log density evaluated at those values. Neither output + is batched. + :rtype: tuple + """ + assert isinstance(factors, dict) + assert all(isinstance(k, str) for k in factors) + assert all(isinstance(v, funsor.Funsor) for v in factors.values()) + assert isinstance(eliminate, frozenset) + assert all(isinstance(v, str) for v in eliminate) + assert isinstance(plates, frozenset) + assert all(isinstance(v, str) for v in plates) + + # Perform tensor variable elimination. + with funsor.interpretations.reflect: + log_Z = funsor.sum_product.sum_product( + funsor.ops.max, + funsor.ops.add, + list(factors.values()), + eliminate, + plates, + ) + log_Z = funsor.optimizer.apply_optimizer(log_Z) + with funsor.approximations.argmax_approximate: + log_Z, marginals = funsor.adjoint.forward_backward( + funsor.ops.max, funsor.ops.add, log_Z + ) + + # Extract sample tensors. + samples = {} + for name, factor in factors.items(): + if name in eliminate: + samples.update(funsor.montecarlo.extract_samples(marginals[factor])) + assert frozenset(samples) == eliminate - plates + + # Compute log density at the mode. 
+ log_prob = -log_Z + for f in factors.values(): + term = f(**samples) + plates = eliminate.intersection(term.inputs) + term = term.reduce(funsor.ops.add, plates) + log_prob += term + assert not log_prob.inputs + + return samples, log_prob diff --git a/test/test_recipes.py b/test/test_recipes.py index 8eb249a7..e9fae58b 100644 --- a/test/test_recipes.py +++ b/test/test_recipes.py @@ -10,7 +10,7 @@ import funsor.ops as ops from funsor.domains import Bint, Real, Reals from funsor.montecarlo import extract_samples -from funsor.recipes import forward_filter_backward_rsample +from funsor.recipes import forward_filter_backward_rsample, forward_max_backward_argmax from funsor.terms import Lambda, Variable from funsor.testing import assert_close, random_gaussian from funsor.util import get_backend @@ -38,6 +38,15 @@ def get_moments(samples): return moments +def subs_factors(factors, plates, subs): + result = 0.0 + for factor in factors.values(): + f = factor(**subs) + f = f.reduce(ops.add, plates.intersection(f.inputs)) + result += f + return result + + def check_ffbr(factors, eliminate, plates, actual_samples, actual_log_prob): """ This can be seen as performing naive tensor variable elimination by @@ -60,14 +69,8 @@ def check_ffbr(factors, eliminate, plates, actual_samples, actual_log_prob): domain = Reals[broken_shape + factor.inputs[name].shape] flat_vars[name] = Variable("flat_" + name, domain)[broken_plates[name]] - flat_factors = [] - for factor in factors.values(): - f = factor(**flat_vars) - f = f.reduce(ops.add, plates.intersection(f.inputs)) - flat_factors.append(f) - # Check log prob. 
- flat_joint = sum(flat_factors) + flat_joint = subs_factors(factors, plates, flat_vars) log_Z = flat_joint.reduce(ops.logaddexp) flat_samples = {} for k, v in actual_samples.items(): @@ -92,34 +95,51 @@ def check_ffbr(factors, eliminate, plates, actual_samples, actual_log_prob): assert_close(actual_moments, expected_moments, atol=0.02, rtol=None) -def test_ffbr_1(): +def check_fmba(factors, eliminate, plates, actual_samples, actual_log_prob): + pass # TODO + + +@pytest.mark.parametrize("backward", ["sample", "argmax"]) +def test_ffbr_1(backward): """ def model(data): a = pyro.sample("a", dist.Normal(0, 1)) pyro.sample("b", dist.Normal(a, 1), obs=data) """ num_samples = int(1e5) - factors = { "a": random_gaussian(OrderedDict({"a": Real})), "b": random_gaussian(OrderedDict({"a": Real})), } eliminate = frozenset(["a"]) plates = frozenset() - sample_inputs = OrderedDict(particle=Bint[num_samples]) - - rng_key = None if get_backend() != "jax" else np.array([0, 0], dtype=np.uint32) - actual_samples, actual_log_prob = forward_filter_backward_rsample( - factors, eliminate, plates, sample_inputs, rng_key - ) - assert set(actual_samples) == {"a"} - assert actual_samples["a"].output == Real - assert set(actual_samples["a"].inputs) == {"particle"} - check_ffbr(factors, eliminate, plates, actual_samples, actual_log_prob) - - -def test_ffbr_2(): + if backward == "sample": + sample_inputs = OrderedDict(particle=Bint[num_samples]) + rng_key = None if get_backend() != "jax" else np.array([0, 0], dtype=np.uint32) + actual_samples, actual_log_prob = forward_filter_backward_rsample( + factors, eliminate, plates, sample_inputs, rng_key + ) + assert set(actual_samples) == {"a"} + assert actual_samples["a"].output == Real + assert set(actual_samples["a"].inputs) == {"particle"} + check_ffbr(factors, eliminate, plates, actual_samples, actual_log_prob) + + elif backward == "argmax": + actual_samples, actual_log_prob = forward_max_backward_argmax( + factors, eliminate, plates + ) + assert 
set(actual_samples) == {"a"} + assert actual_samples["a"].output == Real + assert not actual_samples["a"].inputs + check_fmba(factors, eliminate, plates, actual_samples, actual_log_prob) + + else: + raise ValueError(backward) + + +@pytest.mark.parametrize("backward", ["sample", "argmax"]) +def test_ffbr_2(backward): """ def model(data): a = pyro.sample("a", dist.Normal(0, 1)) @@ -127,7 +147,6 @@ def model(data): pyro.sample("c", dist.Normal(a, b.exp()), obs=data) """ num_samples = int(1e5) - factors = { "a": random_gaussian(OrderedDict({"a": Real})), "b": random_gaussian(OrderedDict({"b": Real})), @@ -135,22 +154,37 @@ def model(data): } eliminate = frozenset(["a", "b"]) plates = frozenset() - sample_inputs = {"particle": Bint[num_samples]} - rng_key = None if get_backend() != "jax" else np.array([0, 0], dtype=np.uint32) - actual_samples, actual_log_prob = forward_filter_backward_rsample( - factors, eliminate, plates, sample_inputs, rng_key - ) - assert set(actual_samples) == {"a", "b"} - assert actual_samples["a"].output == Real - assert actual_samples["b"].output == Real - assert set(actual_samples["a"].inputs) == {"particle"} - assert set(actual_samples["b"].inputs) == {"particle"} - - check_ffbr(factors, eliminate, plates, actual_samples, actual_log_prob) - - -def test_ffbr_3(): + if backward == "sample": + sample_inputs = {"particle": Bint[num_samples]} + rng_key = None if get_backend() != "jax" else np.array([0, 0], dtype=np.uint32) + actual_samples, actual_log_prob = forward_filter_backward_rsample( + factors, eliminate, plates, sample_inputs, rng_key + ) + assert set(actual_samples) == {"a", "b"} + assert actual_samples["a"].output == Real + assert actual_samples["b"].output == Real + assert set(actual_samples["a"].inputs) == {"particle"} + assert set(actual_samples["b"].inputs) == {"particle"} + check_ffbr(factors, eliminate, plates, actual_samples, actual_log_prob) + + elif backward == "argmax": + actual_samples, actual_log_prob = 
forward_max_backward_argmax( + factors, eliminate, plates + ) + assert set(actual_samples) == {"a", "b"} + assert actual_samples["a"].output == Real + assert actual_samples["b"].output == Real + assert not actual_samples["a"].inputs + assert not actual_samples["b"].inputs + check_fmba(factors, eliminate, plates, actual_samples, actual_log_prob) + + else: + raise ValueError(backward) + + +@pytest.mark.parametrize("backward", ["sample", "argmax"]) +def test_ffbr_3(backward): """ def model(data): a = pyro.sample("a", dist.Normal(0, 1)) @@ -159,7 +193,6 @@ def model(data): pyro.sample("c", dist.Normal(a, b.exp()), obs=data) """ num_samples = int(1e5) - factors = { "a": random_gaussian(OrderedDict({"a": Real})), "b": random_gaussian(OrderedDict({"i": Bint[2], "b": Real})), @@ -169,17 +202,31 @@ def model(data): plates = frozenset(["i"]) sample_inputs = {"particle": Bint[num_samples]} - rng_key = None if get_backend() != "jax" else np.array([0, 0], dtype=np.uint32) - actual_samples, actual_log_prob = forward_filter_backward_rsample( - factors, eliminate, plates, sample_inputs, rng_key - ) - assert set(actual_samples) == {"a", "b"} - assert actual_samples["a"].output == Real - assert actual_samples["b"].output == Real - assert set(actual_samples["a"].inputs) == {"particle"} - assert set(actual_samples["b"].inputs) == {"particle", "i"} - - check_ffbr(factors, eliminate, plates, actual_samples, actual_log_prob) + if backward == "sample": + rng_key = None if get_backend() != "jax" else np.array([0, 0], dtype=np.uint32) + actual_samples, actual_log_prob = forward_filter_backward_rsample( + factors, eliminate, plates, sample_inputs, rng_key + ) + assert set(actual_samples) == {"a", "b"} + assert actual_samples["a"].output == Real + assert actual_samples["b"].output == Real + assert set(actual_samples["a"].inputs) == {"particle"} + assert set(actual_samples["b"].inputs) == {"particle", "i"} + check_ffbr(factors, eliminate, plates, actual_samples, actual_log_prob) + + elif 
backward == "argmax": + actual_samples, actual_log_prob = forward_max_backward_argmax( + factors, eliminate, plates + ) + assert set(actual_samples) == {"a", "b"} + assert actual_samples["a"].output == Real + assert actual_samples["b"].output == Real + assert set(actual_samples["a"].inputs) == set() + assert set(actual_samples["b"].inputs) == {"i"} + check_fmba(factors, eliminate, plates, actual_samples, actual_log_prob) + + else: + raise ValueError(backward) def test_ffbr_4():