Skip to content

Commit

Permalink
test: add tests for autoreparam
Browse files Browse the repository at this point in the history
  • Loading branch information
ferrine committed Jul 19, 2024
1 parent 02d42d6 commit f4a9bc2
Showing 1 changed file with 22 additions and 17 deletions.
39 changes: 22 additions & 17 deletions pymc_experimental/tests/model/transforms/test_autoreparam.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ def model_c():
m = pm.Normal("m")
s = pm.LogNormal("s")
pm.Normal("g", m, s, shape=5)
pm.Exponential("e", scale=s, shape=7)
return mod


Expand All @@ -20,31 +21,34 @@ def model_nc():
m = pm.Normal("m")
s = pm.LogNormal("s")
pm.Deterministic("g", pm.Normal("z", shape=5) * s + m)
pm.Deterministic("e", pm.Exponential("z_e", 1, shape=7) * s)
return mod


def test_reparametrize_created(model_c: pm.Model):
model_reparam, vip = vip_reparametrize(model_c, ["g"])
assert "g" in vip.get_lambda()
assert "g::lam_logit__" in model_reparam.named_vars
assert "g::tau_" in model_reparam.named_vars
@pytest.mark.parameterize("var", ["g", "e"])
def test_reparametrize_created(model_c: pm.Model, var):
model_reparam, vip = vip_reparametrize(model_c, [var])
assert f"{var}" in vip.get_lambda()
assert f"{var}::lam_logit__" in model_reparam.named_vars
assert f"{var}::tau_" in model_reparam.named_vars
vip.set_all_lambda(1)
assert ~np.isfinite(model_reparam["g::lam_logit__"].get_value()).any()
assert ~np.isfinite(model_reparam[f"{var}::lam_logit__"].get_value()).any()


def test_random_draw(model_c: pm.Model, model_nc):
@pytest.mark.parameterize("var", ["g", "e"])
def test_random_draw(model_c: pm.Model, model_nc, var):
model_c = pm.do(model_c, {"m": 3, "s": 2})
model_nc = pm.do(model_nc, {"m": 3, "s": 2})
model_v, vip = vip_reparametrize(model_c, ["g"])
assert "g" in [v.name for v in model_v.deterministics]
c = pm.draw(model_c["g"], random_seed=42, draws=1000)
nc = pm.draw(model_nc["g"], random_seed=42, draws=1000)
model_v, vip = vip_reparametrize(model_c, [var])
assert var in [v.name for v in model_v.deterministics]
c = pm.draw(model_c[var], random_seed=42, draws=1000)
nc = pm.draw(model_nc[var], random_seed=42, draws=1000)
vip.set_all_lambda(1)
v_1 = pm.draw(model_v["g"], random_seed=42, draws=1000)
v_1 = pm.draw(model_v[var], random_seed=42, draws=1000)
vip.set_all_lambda(0)
v_0 = pm.draw(model_v["g"], random_seed=42, draws=1000)
v_0 = pm.draw(model_v[var], random_seed=42, draws=1000)
vip.set_all_lambda(0.5)
v_05 = pm.draw(model_v["g"], random_seed=42, draws=1000)
v_05 = pm.draw(model_v[var], random_seed=42, draws=1000)
np.testing.assert_allclose(c.mean(), nc.mean())
np.testing.assert_allclose(c.mean(), v_0.mean())
np.testing.assert_allclose(v_05.mean(), v_1.mean())
Expand All @@ -56,11 +60,12 @@ def test_random_draw(model_c: pm.Model, model_nc):
np.testing.assert_allclose(v_1.std(), nc.std())


def test_reparam_fit(model_c):
model_v, vip = vip_reparametrize(model_c, ["g"])
@pytest.mark.parameterize("var", ["g", "e"])
def test_reparam_fit(model_c, var):
model_v, vip = vip_reparametrize(model_c, [var])
with model_v:
vip.fit(random_seed=42)
np.testing.assert_allclose(vip.get_lambda()["g"], 0, atol=0.01)
np.testing.assert_allclose(vip.get_lambda()[var], 0, atol=0.01)


def test_multilevel():
Expand Down

0 comments on commit f4a9bc2

Please sign in to comment.