diff --git a/pymc/distributions/truncated.py b/pymc/distributions/truncated.py index 45f77574d6..1ef4a1da32 100644 --- a/pymc/distributions/truncated.py +++ b/pymc/distributions/truncated.py @@ -106,28 +106,34 @@ def rv_op(cls, dist, lower, upper, max_n_steps, *, size=None): ] graph_inputs = [*rv_inputs, lower, upper] - rv = dist.owner.op.make_node(*rv_inputs).default_output() + # Variables with `_` suffix identify dummy inputs for the OpFromGraph + graph_inputs_ = [ + inp.type() if not isinstance(inp.type, RandomType) else inp for inp in graph_inputs + ] + *rv_inputs_, lower_, upper_ = graph_inputs_ + + rv_ = dist.owner.op.make_node(*rv_inputs_).default_output() # Try to use inverted cdf sampling # truncated_rv = icdf(rv, draw(uniform(cdf(lower), cdf(upper)))) try: - logcdf_lower, logcdf_upper = cls._create_logcdf_exprs(rv, rv, lower, upper) + logcdf_lower_, logcdf_upper_ = TruncatedRV._create_logcdf_exprs( + rv_, rv_, lower_, upper_ + ) # We use the first RNG from the base RV, so we don't have to introduce a new one # This is not problematic because the RNG won't be used in the RV logcdf graph - uniform_rng = next(inp for inp in rv_inputs if isinstance(inp.type, RandomType)) - uniform_next_rng, uniform = pt.random.uniform( - pt.exp(logcdf_lower), - pt.exp(logcdf_upper), - rng=uniform_rng, - size=rv.shape, + uniform_rng_ = next(inp_ for inp_ in rv_inputs_ if isinstance(inp_.type, RandomType)) + uniform_next_rng_, uniform_ = pt.random.uniform( + pt.exp(logcdf_lower_), + pt.exp(logcdf_upper_), + rng=uniform_rng_, + size=rv_.shape, ).owner.outputs - # So icdf does not see the random graph of uniform - uniform_type = uniform.type() - truncated_rv = graph_replace(icdf(rv, uniform_type), {uniform_type: uniform}) + truncated_rv_ = icdf(rv_, uniform_, warn_rvs=False) return TruncatedRV( base_rv_op=dist.owner.op, - inputs=graph_inputs, - outputs=[truncated_rv, uniform_next_rng], + inputs=graph_inputs_, + outputs=[truncated_rv_, uniform_next_rng_], ndim_supp=0, max_n_steps=max_n_steps, )(*graph_inputs) @@ -154,25 +160,25 @@ def loop_fn(truncated_rv, reject_draws, lower, upper, *rv_inputs): return ( (truncated_rv, reject_draws), - collect_default_updates(new_truncated_rv, inputs=rv_inputs), + collect_default_updates(new_truncated_rv), until(~pt.any(reject_draws)), ) - (truncated_rv, reject_draws_), updates = scan( + (truncated_rv_, reject_draws_), updates = scan( loop_fn, outputs_info=[ - pt.zeros_like(rv), - pt.ones_like(rv, dtype=bool), + pt.zeros_like(rv_), + pt.ones_like(rv_, dtype=bool), ], - non_sequences=[lower, upper, *rv_inputs], + non_sequences=[lower_, upper_, *rv_inputs_], n_steps=max_n_steps, strict=True, ) - truncated_rv = truncated_rv[-1] - convergence = ~pt.any(reject_draws_[-1]) - truncated_rv = TruncationCheck(f"Truncation did not converge in {max_n_steps} steps")( - truncated_rv, convergence + truncated_rv_ = truncated_rv_[-1] + convergence_ = ~pt.any(reject_draws_[-1]) + truncated_rv_ = TruncationCheck(f"Truncation did not converge in {max_n_steps} steps")( + truncated_rv_, convergence_ ) # Sort updates of each RNG so that they show in the same order as the input RNGs @@ -184,8 +190,8 @@ def sort_updates(update): return TruncatedRV( base_rv_op=dist.owner.op, - inputs=graph_inputs, - outputs=[truncated_rv, *next_rngs], + inputs=graph_inputs_, + outputs=[truncated_rv_, *next_rngs], ndim_supp=0, max_n_steps=max_n_steps, )(*graph_inputs) diff --git a/tests/distributions/test_truncated.py b/tests/distributions/test_truncated.py index 4f748f0fbc..50ae98147f 100644 --- a/tests/distributions/test_truncated.py +++ b/tests/distributions/test_truncated.py @@ -585,3 +585,17 @@ def test_truncated_identity_input(dist_op): rv_out = Truncated.dist(dist=dist_op(mu_identity, 5), lower=0, upper=1) assert np.ptp(draw(rv_out, draws=500)) < 1 + + +@pytest.mark.parametrize("rv_op", [icdf_normal, rejection_normal]) +def test_truncated_custom_dist_indexed_argument(rv_op): + # Regression test for https://github.com/pymc-devs/pymc/issues/7312 + + def dist(scale, size): + return pt.exp(rv_op(scale=scale, size=size)) + + scale = Exponential.dist(scale=[1, 2, 3]) + latent = CustomDist.dist(scale[[0, 0, 1, 1, 2, 2]], dist=dist) + rv_out = Truncated.dist(latent, upper=7) + + assert np.ptp(draw(rv_out, draws=100)) < 7