Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add negative-binomial conjugate rewrite #105

Merged
merged 1 commit into from
Feb 17, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
93 changes: 92 additions & 1 deletion aemcmc/conjugates.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from aesara.graph.rewriting.basic import in2out, node_rewriter
from aesara.graph.rewriting.db import LocalGroupDB
from aesara.graph.rewriting.unify import eval_if_etuple
from aesara.tensor.random.basic import BinomialRV, PoissonRV
from aesara.tensor.random.basic import BinomialRV, NegBinomialRV, PoissonRV
from etuples import etuple, etuplize
from kanren import eq, lall, run
from unification import var
Expand Down Expand Up @@ -181,10 +181,101 @@ def local_beta_binomial_posterior(fgraph, node):
return rv_var.owner.outputs


def beta_negative_binomial_conjugateo(observed_val, observed_rv_expr, posterior_expr):
r"""Produce a goal that represents the application of Bayes theorem
for a beta prior with a negative binomial observation model.

.. math::

\frac{
Y \sim \operatorname{NB}\left(k, p\right), \quad
p \sim \operatorname{Beta}\left(\alpha, \beta\right)
}{
\left(p \mid Y=y\right) \sim \operatorname{Beta}\left(\alpha + \sum^{n}_{i=1} y_i, \beta + k n\right)
}
where :math:`k` is the number of successes before experiment ended.


Parameters
----------
observed_val
The observed value.
observed_rv_expr
An expression that represents the observed variable.
posterior_exp
An expression that represents the posterior distribution of the latent
variable.

"""
# beta-negative_binomial observation model
alpha_lv, beta_lv = var(), var()
p_rng_lv = var()
p_size_lv = var()
p_type_idx_lv = var()
p_et = etuple(
etuplize(at.random.beta), p_rng_lv, p_size_lv, p_type_idx_lv, alpha_lv, beta_lv
)
n_lv = var() # success
Y_et = etuple(
etuplize(at.random.negative_binomial), var(), var(), var(), n_lv, p_et
)

new_alpha_et = etuple(etuplize(at.add), alpha_lv, observed_val)
new_beta_et = etuple(etuplize(at.add), beta_lv, n_lv)
p_posterior_et = etuple(
etuplize(at.random.beta),
new_alpha_et,
new_beta_et,
rng=p_rng_lv,
size=p_size_lv,
dtype=p_type_idx_lv,
)

return lall(
eq(observed_rv_expr, Y_et),
eq(posterior_expr, p_posterior_et),
)


@node_rewriter([NegBinomialRV])
def local_beta_negative_binomial_posterior(fgraph, node):
sampler_mappings = getattr(fgraph, "sampler_mappings", None)

rv_var = node.outputs[1]
key = ("local_beta_negative_binomial_posterior", rv_var)

if sampler_mappings is None or key in sampler_mappings.rvs_seen:
return None # pragma: no cover

q = var()

rv_et = etuplize(rv_var)

res = run(None, q, beta_negative_binomial_conjugateo(rv_var, rv_et, q))
res = next(res, None)

if res is None:
return None # pragma: no cover

beta_rv = rv_et[-1].evaled_obj
beta_posterior = eval_if_etuple(res)

sampler_mappings.rvs_to_samplers.setdefault(beta_rv, []).append(
("local_beta_negative_binomial_posterior", beta_posterior, None)
)
sampler_mappings.rvs_seen.add(key)

return rv_var.owner.outputs


conjugates_db = LocalGroupDB(apply_all_rewrites=True)
conjugates_db.name = "conjugates_db"
conjugates_db.register("beta_binomial", local_beta_binomial_posterior, "basic")
conjugates_db.register("gamma_poisson", local_gamma_poisson_posterior, "basic")
conjugates_db.register(
"negative_binomial", local_beta_negative_binomial_posterior, "basic"
)


sampler_finder_db.register(
"conjugates", in2out(conjugates_db.query("+basic"), name="gibbs"), "basic"
Expand Down
22 changes: 22 additions & 0 deletions tests/test_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,28 @@ def test_closed_form_posterior_gamma_poisson():
assert isinstance(p_posterior_step.owner.op, GammaRV)


def test_closed_form_posterior_beta_nbinom():
srng = RandomStream(0)

alpha_tt = at.scalar("alpha")
beta_tt = at.scalar("beta")

p_rv = srng.beta(alpha_tt, beta_tt, name="p")

n_tt = at.scalar("n")
Y_rv = srng.negative_binomial(n_tt, p_rv, name="Y")

y_vv = Y_rv.clone()
y_vv.name = "y"

sampler, initial_values = construct_sampler({Y_rv: y_vv}, srng)

p_posterior_step = sampler.sample_steps[p_rv]
assert len(sampler.parameters) == 0
assert len(sampler.stages) == 1
assert isinstance(p_posterior_step.owner.op, BetaRV)


@pytest.mark.parametrize("size", [1, (1,), (2, 3)])
def test_nuts_sampler_single_variable(size):
"""We make sure that the NUTS sampler compiles and updates the chains for
Expand Down
59 changes: 58 additions & 1 deletion tests/test_conjugates.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,11 @@
from kanren import run
from unification import var

from aemcmc.conjugates import beta_binomial_conjugateo, gamma_poisson_conjugateo
from aemcmc.conjugates import (
beta_binomial_conjugateo,
beta_negative_binomial_conjugateo,
gamma_poisson_conjugateo,
)


def test_gamma_poisson_conjugate_contract():
Expand Down Expand Up @@ -101,3 +105,56 @@ def test_beta_binomial_conjugate_expand():
expanded = eval_if_etuple(expanded_expr)

assert isinstance(expanded.owner.op, type(at.random.beta))


def test_beta_negative_binomial_conjugate_contract():
"""Produce the closed-form posterior for the binomial observation model with
a beta prior.

"""
srng = RandomStream(0)

alpha_tt = at.scalar("alpha")
beta_tt = at.scalar("beta")
p_rv = srng.beta(alpha_tt, beta_tt, name="p")

n_tt = at.iscalar("n")
Y_rv = srng.negative_binomial(n_tt, p_rv)
y_vv = Y_rv.clone()
y_vv.tag.name = "y"

q_lv = var()
(posterior_expr,) = run(
1, q_lv, beta_negative_binomial_conjugateo(y_vv, Y_rv, q_lv)
)
posterior = eval_if_etuple(posterior_expr)

assert isinstance(posterior.owner.op, type(at.random.beta))

# Build the sampling function and check the results on limiting cases.
sample_fn = aesara.function((alpha_tt, beta_tt, y_vv, n_tt), posterior)
assert sample_fn(1.0, 1.0, 1000, 0) == pytest.approx(
1.0, abs=0.01
) # only successes
assert sample_fn(1.0, 1.0, 0, 1000) == pytest.approx(0.0, abs=0.01) # no success


@pytest.mark.xfail(
reason="Op.__call__ does not dispatch to Op.make_node for some RandomVariable and etuple evaluation returns an error"
)
def test_beta_negative_binomial_conjugate_expand():
"""Expand a contracted beta-binomial observation model."""

srng = RandomStream(0)

alpha_tt = at.scalar("alpha")
beta_tt = at.scalar("beta")
y_vv = at.iscalar("y")
n_tt = at.iscalar("n")
Y_rv = srng.beta(alpha_tt + y_vv, beta_tt + n_tt)

e_lv = var()
(expanded_expr,) = run(1, e_lv, beta_negative_binomial_conjugateo(e_lv, y_vv, Y_rv))
expanded = eval_if_etuple(expanded_expr)

assert isinstance(expanded.owner.op, type(at.random.beta))