Skip to content

Commit

Permalink
Reintroduce dummy intermediate variables in implementation of Truncat…
Browse files Browse the repository at this point in the history
…edRV

Partially reverts 9d4a3d7 and 3888d53

The logprob derivation(s) in the icdf implementation of `Truncated` can duplicate nodes and cause spurious input variables to be marked as missing. We replace these by dummies so the graph above is hidden, and variables cannot be accidentally cloned/modified during logprob inference.
  • Loading branch information
ricardoV94 committed May 25, 2024
1 parent 43c5a8e commit 19be124
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 24 deletions.
54 changes: 30 additions & 24 deletions pymc/distributions/truncated.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,28 +106,34 @@ def rv_op(cls, dist, lower, upper, max_n_steps, *, size=None):
]
graph_inputs = [*rv_inputs, lower, upper]

rv = dist.owner.op.make_node(*rv_inputs).default_output()
# Variables with `_` suffix identify dummy inputs for the OpFromGraph
graph_inputs_ = [
inp.type() if not isinstance(inp.type, RandomType) else inp for inp in graph_inputs
]
*rv_inputs_, lower_, upper_ = graph_inputs_

rv_ = dist.owner.op.make_node(*rv_inputs_).default_output()

# Try to use inverted cdf sampling
# truncated_rv = icdf(rv, draw(uniform(cdf(lower), cdf(upper))))
try:
logcdf_lower, logcdf_upper = cls._create_logcdf_exprs(rv, rv, lower, upper)
logcdf_lower_, logcdf_upper_ = TruncatedRV._create_logcdf_exprs(
rv_, rv_, lower_, upper_
)
# We use the first RNG from the base RV, so we don't have to introduce a new one
# This is not problematic because the RNG won't be used in the RV logcdf graph
uniform_rng = next(inp for inp in rv_inputs if isinstance(inp.type, RandomType))
uniform_next_rng, uniform = pt.random.uniform(
pt.exp(logcdf_lower),
pt.exp(logcdf_upper),
rng=uniform_rng,
size=rv.shape,
uniform_rng_ = next(inp_ for inp_ in rv_inputs_ if isinstance(inp_.type, RandomType))
uniform_next_rng_, uniform_ = pt.random.uniform(
pt.exp(logcdf_lower_),
pt.exp(logcdf_upper_),
rng=uniform_rng_,
size=rv_.shape,
).owner.outputs
# So icdf does not see the random graph of uniform
uniform_type = uniform.type()
truncated_rv = graph_replace(icdf(rv, uniform_type), {uniform_type: uniform})
truncated_rv_ = icdf(rv_, uniform_, warn_rvs=False)
return TruncatedRV(
base_rv_op=dist.owner.op,
inputs=graph_inputs,
outputs=[truncated_rv, uniform_next_rng],
inputs=graph_inputs_,
outputs=[truncated_rv_, uniform_next_rng_],
ndim_supp=0,
max_n_steps=max_n_steps,
)(*graph_inputs)
Expand All @@ -154,25 +160,25 @@ def loop_fn(truncated_rv, reject_draws, lower, upper, *rv_inputs):

return (
(truncated_rv, reject_draws),
collect_default_updates(new_truncated_rv, inputs=rv_inputs),
collect_default_updates(new_truncated_rv),
until(~pt.any(reject_draws)),
)

(truncated_rv, reject_draws_), updates = scan(
(truncated_rv_, reject_draws_), updates = scan(
loop_fn,
outputs_info=[
pt.zeros_like(rv),
pt.ones_like(rv, dtype=bool),
pt.zeros_like(rv_),
pt.ones_like(rv_, dtype=bool),
],
non_sequences=[lower, upper, *rv_inputs],
non_sequences=[lower_, upper_, *rv_inputs_],
n_steps=max_n_steps,
strict=True,
)

truncated_rv = truncated_rv[-1]
convergence = ~pt.any(reject_draws_[-1])
truncated_rv = TruncationCheck(f"Truncation did not converge in {max_n_steps} steps")(
truncated_rv, convergence
truncated_rv_ = truncated_rv_[-1]
convergence_ = ~pt.any(reject_draws_[-1])
truncated_rv_ = TruncationCheck(f"Truncation did not converge in {max_n_steps} steps")(
truncated_rv_, convergence_
)

# Sort updates of each RNG so that they show in the same order as the input RNGs
Expand All @@ -184,8 +190,8 @@ def sort_updates(update):

return TruncatedRV(
base_rv_op=dist.owner.op,
inputs=graph_inputs,
outputs=[truncated_rv, *next_rngs],
inputs=graph_inputs_,
outputs=[truncated_rv_, *next_rngs],
ndim_supp=0,
max_n_steps=max_n_steps,
)(*graph_inputs)
Expand Down
14 changes: 14 additions & 0 deletions tests/distributions/test_truncated.py
Original file line number Diff line number Diff line change
Expand Up @@ -585,3 +585,17 @@ def test_truncated_identity_input(dist_op):

rv_out = Truncated.dist(dist=dist_op(mu_identity, 5), lower=0, upper=1)
assert np.ptp(draw(rv_out, draws=500)) < 1


@pytest.mark.parametrize("rv_op", [icdf_normal, rejection_normal])
def test_truncated_custom_dist_indexed_argument(rv_op):
# Regression test for https://github.com/pymc-devs/pymc/issues/7312

def dist(scale, size):
return pt.exp(rv_op(scale=scale, size=size))

scale = Exponential.dist(scale=[1, 2, 3])
latent = CustomDist.dist(scale[[0, 0, 1, 1, 2, 2]], dist=dist)
rv_out = Truncated.dist(latent, upper=7)

assert np.ptp(draw(rv_out, draws=100)) < 7

0 comments on commit 19be124

Please sign in to comment.