Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement a forward_max_backward_argmax() recipe #576

Draft
wants to merge 3 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 29 additions & 0 deletions funsor/gaussian.py
Original file line number Diff line number Diff line change
Expand Up @@ -965,6 +965,35 @@ def eager_reduce(self, op, reduced_vars):

return Gaussian(white_vec, prec_sqrt, inputs)

elif op is ops.max:
# Marginalize out real variables, but keep mixtures lazy.
assert all(v in self.inputs for v in reduced_vars)
real_vars = frozenset(
k for k, d in self.inputs.items() if d.dtype == "real"
)
reduced_reals = reduced_vars & real_vars
reduced_ints = reduced_vars - real_vars
if reduced_ints:
raise NotImplementedError("TODO argmax over Gaussian mixtures")

inputs = OrderedDict(
(k, d) for k, d in self.inputs.items() if k not in reduced_reals
)
int_inputs = OrderedDict(
(k, v) for k, v in inputs.items() if v.dtype != "real"
)
if reduced_reals == real_vars:
if self.rank <= self.prec_sqrt.shape[-2]:
return 0.0
# Otherwise compress.
white_vec, prec_sqrt, shift = _compress_rank(
self.white_vec, self.prec_sqrt
)
return Tensor(shift, int_inputs)

# FIXME
raise NotImplementedError("TODO partial max")

return None # defer to default implementation

def _sample(self, sampled_vars, sample_inputs, rng_key):
Expand Down
62 changes: 62 additions & 0 deletions funsor/recipes.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,3 +82,65 @@ def forward_filter_backward_rsample(
assert set(log_prob.inputs) == set(sample_inputs)

return samples, log_prob


def forward_max_backward_argmax(
    factors: Dict[str, funsor.Funsor],
    eliminate: FrozenSet[str],
    plates: FrozenSet[str],
):
    """
    A forward-max backward-argmax algorithm for MAP inference via tensor
    variable elimination. The motivating use case is computing maximum a
    posteriori point estimates over Gaussian graphical models.

    :param dict factors: A dictionary mapping sample site name to a Funsor
        factor created at that sample site.
    :param frozenset eliminate: A set of names of latent variables to
        maximize over, together with names of plates to aggregate.
    :param frozenset plates: A set of names of plates to aggregate.
    :returns: A pair ``samples: Dict[str, Tensor], log_prob: Tensor`` of
        argmax point estimates for each eliminated latent variable and the
        log density evaluated at those points.
    :rtype: tuple
    """
    assert isinstance(factors, dict)
    assert all(isinstance(k, str) for k in factors)
    assert all(isinstance(v, funsor.Funsor) for v in factors.values())
    assert isinstance(eliminate, frozenset)
    assert all(isinstance(v, str) for v in eliminate)
    assert isinstance(plates, frozenset)
    assert all(isinstance(v, str) for v in plates)

    # Perform tensor variable elimination, maximizing (rather than
    # integrating) over eliminated latent variables.
    with funsor.interpretations.reflect:
        log_Z = funsor.sum_product.sum_product(
            funsor.ops.max,
            funsor.ops.add,
            list(factors.values()),
            eliminate,
            plates,
        )
        log_Z = funsor.optimizer.apply_optimizer(log_Z)
    with funsor.approximations.argmax_approximate:
        log_Z, marginals = funsor.adjoint.forward_backward(
            funsor.ops.max, funsor.ops.add, log_Z
        )

    # Extract argmax point estimates from the backward marginals.
    samples = {}
    for name, factor in factors.items():
        if name in eliminate:
            samples.update(funsor.montecarlo.extract_samples(marginals[factor]))
    assert frozenset(samples) == eliminate - plates

    # Compute log density at the argmax point. Use a local name for the
    # per-factor reduced dims so we don't shadow the ``plates`` argument.
    log_prob = -log_Z
    for f in factors.values():
        term = f(**samples)
        reduced_plates = eliminate.intersection(term.inputs)
        term = term.reduce(funsor.ops.add, reduced_plates)
        log_prob += term
    assert not log_prob.inputs

    return samples, log_prob
149 changes: 98 additions & 51 deletions test/test_recipes.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import funsor.ops as ops
from funsor.domains import Bint, Real, Reals
from funsor.montecarlo import extract_samples
from funsor.recipes import forward_filter_backward_rsample
from funsor.recipes import forward_filter_backward_rsample, forward_max_backward_argmax
from funsor.terms import Lambda, Variable
from funsor.testing import assert_close, random_gaussian
from funsor.util import get_backend
Expand Down Expand Up @@ -38,6 +38,15 @@ def get_moments(samples):
return moments


def subs_factors(factors, plates, subs):
    """Substitute ``subs`` into every factor and return the summed result,
    reducing each factor's plate dimensions via ``ops.add``."""
    total = 0.0
    for factor in factors.values():
        term = factor(**subs)
        total += term.reduce(ops.add, plates.intersection(term.inputs))
    return total


def check_ffbr(factors, eliminate, plates, actual_samples, actual_log_prob):
"""
This can be seen as performing naive tensor variable elimination by
Expand All @@ -60,14 +69,8 @@ def check_ffbr(factors, eliminate, plates, actual_samples, actual_log_prob):
domain = Reals[broken_shape + factor.inputs[name].shape]
flat_vars[name] = Variable("flat_" + name, domain)[broken_plates[name]]

flat_factors = []
for factor in factors.values():
f = factor(**flat_vars)
f = f.reduce(ops.add, plates.intersection(f.inputs))
flat_factors.append(f)

# Check log prob.
flat_joint = sum(flat_factors)
flat_joint = subs_factors(factors, plates, flat_vars)
log_Z = flat_joint.reduce(ops.logaddexp)
flat_samples = {}
for k, v in actual_samples.items():
Expand All @@ -92,65 +95,96 @@ def check_ffbr(factors, eliminate, plates, actual_samples, actual_log_prob):
assert_close(actual_moments, expected_moments, atol=0.02, rtol=None)


def test_ffbr_1():
def check_fmba(factors, eliminate, plates, actual_samples, actual_log_prob):
    """Verify forward_max_backward_argmax() outputs against a naive
    computation; mirrors check_ffbr's interface. Not yet implemented."""
    pass  # TODO


@pytest.mark.parametrize("backward", ["sample", "argmax"])
def test_ffbr_1(backward):
"""
def model(data):
a = pyro.sample("a", dist.Normal(0, 1))
pyro.sample("b", dist.Normal(a, 1), obs=data)
"""
num_samples = int(1e5)

factors = {
"a": random_gaussian(OrderedDict({"a": Real})),
"b": random_gaussian(OrderedDict({"a": Real})),
}
eliminate = frozenset(["a"])
plates = frozenset()
sample_inputs = OrderedDict(particle=Bint[num_samples])

rng_key = None if get_backend() != "jax" else np.array([0, 0], dtype=np.uint32)
actual_samples, actual_log_prob = forward_filter_backward_rsample(
factors, eliminate, plates, sample_inputs, rng_key
)
assert set(actual_samples) == {"a"}
assert actual_samples["a"].output == Real
assert set(actual_samples["a"].inputs) == {"particle"}

check_ffbr(factors, eliminate, plates, actual_samples, actual_log_prob)


def test_ffbr_2():
if backward == "sample":
sample_inputs = OrderedDict(particle=Bint[num_samples])
rng_key = None if get_backend() != "jax" else np.array([0, 0], dtype=np.uint32)
actual_samples, actual_log_prob = forward_filter_backward_rsample(
factors, eliminate, plates, sample_inputs, rng_key
)
assert set(actual_samples) == {"a"}
assert actual_samples["a"].output == Real
assert set(actual_samples["a"].inputs) == {"particle"}
check_ffbr(factors, eliminate, plates, actual_samples, actual_log_prob)

elif backward == "argmax":
actual_samples, actual_log_prob = forward_max_backward_argmax(
factors, eliminate, plates
)
assert set(actual_samples) == {"a"}
assert actual_samples["a"].output == Real
assert not actual_samples["a"].inputs
check_fmba(factors, eliminate, plates, actual_samples, actual_log_prob)

else:
raise ValueError(backward)


@pytest.mark.parametrize("backward", ["sample", "argmax"])
def test_ffbr_2(backward):
"""
def model(data):
a = pyro.sample("a", dist.Normal(0, 1))
b = pyro.sample("b", dist.Normal(0, 1))
pyro.sample("c", dist.Normal(a, b.exp()), obs=data)
"""
num_samples = int(1e5)

factors = {
"a": random_gaussian(OrderedDict({"a": Real})),
"b": random_gaussian(OrderedDict({"b": Real})),
"c": random_gaussian(OrderedDict({"a": Real, "b": Real})),
}
eliminate = frozenset(["a", "b"])
plates = frozenset()
sample_inputs = {"particle": Bint[num_samples]}

rng_key = None if get_backend() != "jax" else np.array([0, 0], dtype=np.uint32)
actual_samples, actual_log_prob = forward_filter_backward_rsample(
factors, eliminate, plates, sample_inputs, rng_key
)
assert set(actual_samples) == {"a", "b"}
assert actual_samples["a"].output == Real
assert actual_samples["b"].output == Real
assert set(actual_samples["a"].inputs) == {"particle"}
assert set(actual_samples["b"].inputs) == {"particle"}

check_ffbr(factors, eliminate, plates, actual_samples, actual_log_prob)


def test_ffbr_3():
if backward == "sample":
sample_inputs = {"particle": Bint[num_samples]}
rng_key = None if get_backend() != "jax" else np.array([0, 0], dtype=np.uint32)
actual_samples, actual_log_prob = forward_filter_backward_rsample(
factors, eliminate, plates, sample_inputs, rng_key
)
assert set(actual_samples) == {"a", "b"}
assert actual_samples["a"].output == Real
assert actual_samples["b"].output == Real
assert set(actual_samples["a"].inputs) == {"particle"}
assert set(actual_samples["b"].inputs) == {"particle"}
check_ffbr(factors, eliminate, plates, actual_samples, actual_log_prob)

elif backward == "argmax":
actual_samples, actual_log_prob = forward_max_backward_argmax(
factors, eliminate, plates
)
assert set(actual_samples) == {"a", "b"}
assert actual_samples["a"].output == Real
assert actual_samples["b"].output == Real
assert not actual_samples["a"].inputs
assert not actual_samples["b"].inputs
check_fmba(factors, eliminate, plates, actual_samples, actual_log_prob)

else:
raise ValueError(backward)


@pytest.mark.parametrize("backward", ["sample", "argmax"])
def test_ffbr_3(backward):
"""
def model(data):
a = pyro.sample("a", dist.Normal(0, 1))
Expand All @@ -159,7 +193,6 @@ def model(data):
pyro.sample("c", dist.Normal(a, b.exp()), obs=data)
"""
num_samples = int(1e5)

factors = {
"a": random_gaussian(OrderedDict({"a": Real})),
"b": random_gaussian(OrderedDict({"i": Bint[2], "b": Real})),
Expand All @@ -169,17 +202,31 @@ def model(data):
plates = frozenset(["i"])
sample_inputs = {"particle": Bint[num_samples]}

rng_key = None if get_backend() != "jax" else np.array([0, 0], dtype=np.uint32)
actual_samples, actual_log_prob = forward_filter_backward_rsample(
factors, eliminate, plates, sample_inputs, rng_key
)
assert set(actual_samples) == {"a", "b"}
assert actual_samples["a"].output == Real
assert actual_samples["b"].output == Real
assert set(actual_samples["a"].inputs) == {"particle"}
assert set(actual_samples["b"].inputs) == {"particle", "i"}

check_ffbr(factors, eliminate, plates, actual_samples, actual_log_prob)
if backward == "sample":
rng_key = None if get_backend() != "jax" else np.array([0, 0], dtype=np.uint32)
actual_samples, actual_log_prob = forward_filter_backward_rsample(
factors, eliminate, plates, sample_inputs, rng_key
)
assert set(actual_samples) == {"a", "b"}
assert actual_samples["a"].output == Real
assert actual_samples["b"].output == Real
assert set(actual_samples["a"].inputs) == {"particle"}
assert set(actual_samples["b"].inputs) == {"particle", "i"}
check_ffbr(factors, eliminate, plates, actual_samples, actual_log_prob)

elif backward == "argmax":
actual_samples, actual_log_prob = forward_max_backward_argmax(
factors, eliminate, plates
)
assert set(actual_samples) == {"a", "b"}
assert actual_samples["a"].output == Real
assert actual_samples["b"].output == Real
assert set(actual_samples["a"].inputs) == set()
assert set(actual_samples["b"].inputs) == {"i"}
check_fmba(factors, eliminate, plates, actual_samples, actual_log_prob)

else:
raise ValueError(backward)


def test_ffbr_4():
Expand Down