From fec6051ce6aaade1ecb9b2202fbe9bc938759078 Mon Sep 17 00:00:00 2001 From: Jing Xie Date: Tue, 14 Feb 2023 15:40:44 -0500 Subject: [PATCH] Add in NB beta conjugates and its tests --- aemcmc/conjugates.py | 94 +++++++++++++++++++++++++++++++++++++++- tests/test_conjugates.py | 59 ++++++++++++++++++++++++- 2 files changed, 151 insertions(+), 2 deletions(-) diff --git a/aemcmc/conjugates.py b/aemcmc/conjugates.py index 07eed5c..f09a901 100644 --- a/aemcmc/conjugates.py +++ b/aemcmc/conjugates.py @@ -2,7 +2,7 @@ from aesara.graph.rewriting.basic import in2out, node_rewriter from aesara.graph.rewriting.db import LocalGroupDB from aesara.graph.rewriting.unify import eval_if_etuple -from aesara.tensor.random.basic import BinomialRV, PoissonRV +from aesara.tensor.random.basic import BinomialRV, NegBinomialRV, PoissonRV from etuples import etuple, etuplize from kanren import eq, lall, run from unification import var @@ -181,10 +181,102 @@ def local_beta_binomial_posterior(fgraph, node): return rv_var.owner.outputs +def beta_negative_binomial_conjugateo(observed_val, observed_rv_expr, posterior_expr): + r"""Produce a goal that represents the application of Bayes theorem + for a beta prior with a negative binomial observation model. + + .. math:: + + \frac{ + Y \sim \operatorname{NB}\left(k, p\right), \quad + p \sim \operatorname{Beta}\left(\alpha, \beta\right) + }{ + \left(p|Y=y\right) \sim \operatorname{Beta}\left(\alpha+\sum^{n}_{i=1} y_i, \beta+kN\right) + } + , k is the number of successes before experiment ended + + + Parameters + ---------- + observed_val + The observed value. + observed_rv_expr + An expression that represents the observed variable. + posterior_exp + An expression that represents the posterior distribution of the latent + variable. + + """ + # beta-negative_binomial observation model + alpha_lv, beta_lv = var(), var() + p_rng_lv = var() + p_size_lv = var() + p_type_idx_lv = var() + p_et = etuple( + etuplize(at.random.beta), p_rng_lv, p_size_lv, p_type_idx_lv, alpha_lv, beta_lv + ) + n_lv = var() # success + Y_et = etuple( + etuplize(at.random.negative_binomial), var(), var(), var(), n_lv, p_et + ) + + new_alpha_et = etuple(etuplize(at.add), alpha_lv, observed_val) + new_beta_et = etuple(etuplize(at.add), beta_lv, n_lv) + p_posterior_et = etuple( + etuplize(at.random.beta), + new_alpha_et, + new_beta_et, + rng=p_rng_lv, + size=p_size_lv, + dtype=p_type_idx_lv, + ) + + return lall( + eq(observed_rv_expr, Y_et), + eq(posterior_expr, p_posterior_et), + ) + + +@node_rewriter([NegBinomialRV]) +def local_beta_negative_binomial_posterior(fgraph, node): + + sampler_mappings = getattr(fgraph, "sampler_mappings", None) + + rv_var = node.outputs[1] + key = ("local_beta_negative_binomial_posterior", rv_var) + + if sampler_mappings is None or key in sampler_mappings.rvs_seen: + return None # pragma: no cover + + q = var() + + rv_et = etuplize(rv_var) + + res = run(None, q, beta_negative_binomial_conjugateo(rv_var, rv_et, q)) + res = next(res, None) + + if res is None: + return None # pragma: no cover + + beta_rv = rv_et[-1].evaled_obj + beta_posterior = eval_if_etuple(res) + + sampler_mappings.rvs_to_samplers.setdefault(beta_rv, []).append( + ("local_beta_negative_binomial_posterior", beta_posterior, None) + ) + sampler_mappings.rvs_seen.add(key) + + return rv_var.owner.outputs + + conjugates_db = LocalGroupDB(apply_all_rewrites=True) conjugates_db.name = "conjugates_db" conjugates_db.register("beta_binomial", local_beta_binomial_posterior, "basic") conjugates_db.register("gamma_poisson", local_gamma_poisson_posterior, "basic") +conjugates_db.register( + "negative_binomial", local_beta_negative_binomial_posterior, "basic" +) + sampler_finder_db.register( "conjugates", in2out(conjugates_db.query("+basic"), name="gibbs"), "basic" diff --git a/tests/test_conjugates.py b/tests/test_conjugates.py index 46fa7cd..64a6d00 100644 --- a/tests/test_conjugates.py +++ b/tests/test_conjugates.py @@ -6,7 +6,11 @@ from kanren import run from unification import var -from aemcmc.conjugates import beta_binomial_conjugateo, gamma_poisson_conjugateo +from aemcmc.conjugates import ( + beta_binomial_conjugateo, + beta_negative_binomial_conjugateo, + gamma_poisson_conjugateo, +) def test_gamma_poisson_conjugate_contract(): @@ -101,3 +105,56 @@ def test_beta_binomial_conjugate_expand(): expanded = eval_if_etuple(expanded_expr) assert isinstance(expanded.owner.op, type(at.random.beta)) + + +def test_beta_negative_binomial_conjugate_contract(): + """Produce the closed-form posterior for the binomial observation model with + a beta prior. + + """ + srng = RandomStream(0) + + alpha_tt = at.scalar("alpha") + beta_tt = at.scalar("beta") + p_rv = srng.beta(alpha_tt, beta_tt, name="p") + + n_tt = at.iscalar("n") + Y_rv = srng.negative_binomial(n_tt, p_rv) + y_vv = Y_rv.clone() + y_vv.tag.name = "y" + + q_lv = var() + (posterior_expr,) = run( + 1, q_lv, beta_negative_binomial_conjugateo(y_vv, Y_rv, q_lv) + ) + posterior = eval_if_etuple(posterior_expr) + + assert isinstance(posterior.owner.op, type(at.random.beta)) + + # Build the sampling function and check the results on limiting cases. + sample_fn = aesara.function((alpha_tt, beta_tt, y_vv, n_tt), posterior) + assert sample_fn(1.0, 1.0, 1000, 0) == pytest.approx( + 1.0, abs=0.01 + ) # only successes + assert sample_fn(1.0, 1.0, 0, 1000) == pytest.approx(0.0, abs=0.01) # no success + + +@pytest.mark.xfail( + reason="Op.__call__ does not dispatch to Op.make_node for some RandomVariable and etuple evaluation returns an error" +) +def test_beta_negative_binomial_conjugate_expand(): + """Expand a contracted beta-binomial observation model.""" + + srng = RandomStream(0) + + alpha_tt = at.scalar("alpha") + beta_tt = at.scalar("beta") + y_vv = at.iscalar("y") + n_tt = at.iscalar("n") + Y_rv = srng.beta(alpha_tt + y_vv, beta_tt + n_tt) + + e_lv = var() + (expanded_expr,) = run(1, e_lv, beta_negative_binomial_conjugateo(e_lv, y_vv, Y_rv)) + expanded = eval_if_etuple(expanded_expr) + + assert isinstance(expanded.owner.op, type(at.random.beta))