From 089e8b3c210a05568f32605eb35aa39904b2b02f Mon Sep 17 00:00:00 2001 From: Jing Xie Date: Fri, 24 Feb 2023 10:55:32 -0500 Subject: [PATCH] Add Uniform pareto conjugates --- aemcmc/conjugates.py | 88 +++++++++++++++++++++++++++++++++++++++- tests/test_conjugates.py | 23 +++++++++++ 2 files changed, 110 insertions(+), 1 deletion(-) diff --git a/aemcmc/conjugates.py b/aemcmc/conjugates.py index d49d7dd..cfdb922 100644 --- a/aemcmc/conjugates.py +++ b/aemcmc/conjugates.py @@ -2,7 +2,7 @@ from aesara.graph.rewriting.basic import in2out, node_rewriter from aesara.graph.rewriting.db import LocalGroupDB from aesara.graph.rewriting.unify import eval_if_etuple -from aesara.tensor.random.basic import BinomialRV, NegBinomialRV, PoissonRV +from aesara.tensor.random.basic import BinomialRV, NegBinomialRV, PoissonRV, UniformRV from etuples import etuple, etuplize from kanren import eq, lall, run from unification import var @@ -268,6 +268,91 @@ def local_beta_negative_binomial_posterior(fgraph, node): return rv_var.owner.outputs +def uniform_pareto_conjugateo(observed_val, observed_rv_expr, posterior_expr): + r"""Produce a goal that represents the application of Bayes theorem + for a pareto prior with a uniform with 0 as the lower bound observation model. + + .. math:: + Y \sim \operatorname{Uniform}\left(0, \theta\right) + + + + Parameters + ---------- + observed_val + The observed value. + observed_rv_expr + An expression that represents the observed variable. + posterior_exp + An expression that represents the posterior distribution of the latent + variable. + + """ + # beta-negative_binomial observation model + x_lv, k_lv = var(), var() + theta_rng_lv = var() + theta_size_lv = var() + theta_type_idx_lv = var() + theta_et = etuple( + etuplize(at.random.pareto), + theta_rng_lv, + theta_size_lv, + theta_type_idx_lv, + k_lv, + x_lv, + ) + Y_et = etuple(etuplize(at.random.uniform), var(), var(), var(), 1, theta_et) + + # new_x_et = at.max(observed_val) + new_x_et = at.max(observed_val, x_lv) + new_k_et = etuple(etuplize(at.add), k_lv, 1) + + theta_posterior_et = etuple( + etuplize(at.random.pareto), + new_k_et, + new_x_et, + rng=theta_rng_lv, + size=theta_size_lv, + dtype=theta_type_idx_lv, + ) + + return lall( + eq(observed_rv_expr, Y_et), + eq(posterior_expr, theta_posterior_et), + ) + + +@node_rewriter([UniformRV]) +def local_uniform_pareto_posterior(fgraph, node): + sampler_mappings = getattr(fgraph, "sampler_mappings", None) + + rv_var = node.outputs[1] + key = ("local_beta_negative_binomial_posterior", rv_var) + + if sampler_mappings is None or key in sampler_mappings.rvs_seen: + return None # pragma: no cover + + q = var() + + rv_et = etuplize(rv_var) + + res = run(None, q, uniform_pareto_conjugateo(rv_var, rv_et, q)) + res = next(res, None) + + if res is None: + return None # pragma: no cover + + pareto_rv = rv_et[-1].evaled_obj + pareto_posterior = eval_if_etuple(res) + + sampler_mappings.rvs_to_samplers.setdefault(pareto_rv, []).append( + ("local_uniform_pareto_posterior", pareto_posterior, None) + ) + sampler_mappings.rvs_seen.add(key) + + return rv_var.owner.outputs + + conjugates_db = LocalGroupDB(apply_all_rewrites=True) conjugates_db.name = "conjugates_db" conjugates_db.register("beta_binomial", local_beta_binomial_posterior, "basic") @@ -275,6 +360,7 @@ def local_beta_negative_binomial_posterior(fgraph, node): conjugates_db.register( "negative_binomial", local_beta_negative_binomial_posterior, "basic" ) +conjugates_db.register("uniform", local_uniform_pareto_posterior, "basic") sampler_finder_db.register( diff --git a/tests/test_conjugates.py b/tests/test_conjugates.py index a86fa06..bfc7f4a 100644 --- a/tests/test_conjugates.py +++ b/tests/test_conjugates.py @@ -10,6 +10,7 @@ beta_binomial_conjugateo, beta_negative_binomial_conjugateo, gamma_poisson_conjugateo, + uniform_pareto_conjugateo, ) @@ -157,3 +158,25 @@ def test_beta_negative_binomial_conjugate_expand(): expanded = eval_if_etuple(expanded_expr) assert isinstance(expanded.owner.op, type(at.random.beta)) + + +def test_uniform_pareto_conjugate_contract(): + """Produce the closed-form posterior for the uniform observation model with + a pareto prior. + + """ + srng = RandomStream(0) + + xm_tt = at.scalar("xm") + k_tt = at.scalar("k") + theta_rv = srng.pareto(k_tt, xm_tt, name="theta") + + Y_rv = srng.uniform(0, theta_rv) + y_vv = Y_rv.clone() + y_vv.tag.name = "y" + + q_lv = var() + (posterior_expr,) = run(1, q_lv, uniform_pareto_conjugateo(y_vv, Y_rv, q_lv)) + posterior = eval_if_etuple(posterior_expr) + + assert isinstance(posterior.owner.op, type(at.random.pareto))