Skip to content

Commit

Permalink
Fix truncated rejection sampling for scalar RVs
Browse files Browse the repository at this point in the history
  • Loading branch information
ricardoV94 committed Oct 24, 2023
1 parent efa0d34 commit 6b486b9
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 7 deletions.
12 changes: 8 additions & 4 deletions pymc/distributions/truncated.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,10 +211,14 @@ def rv_op(cls, dist, lower, upper, max_n_steps, size=None):
# Fallback to rejection sampling
def loop_fn(truncated_rv, reject_draws, lower, upper, rng, *rv_inputs):
next_rng, new_truncated_rv = dist.owner.op.make_node(rng, *rv_inputs).outputs
truncated_rv = pt.set_subtensor(
truncated_rv[reject_draws],
new_truncated_rv[reject_draws],
)
# Avoid scalar boolean indexing
if truncated_rv.type.ndim == 0:
truncated_rv = new_truncated_rv
else:
truncated_rv = pt.set_subtensor(
truncated_rv[reject_draws],
new_truncated_rv[reject_draws],
)
reject_draws = pt.or_((truncated_rv < lower), (truncated_rv > upper))

return (
Expand Down
7 changes: 4 additions & 3 deletions tests/distributions/test_truncated.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,11 +101,12 @@ def test_truncation_specialized_op(shape_info):

@pytest.mark.parametrize("lower, upper", [(-1, np.inf), (-1, 1.5), (-np.inf, 1.5)])
@pytest.mark.parametrize("op_type", ["icdf", "rejection"])
def test_truncation_continuous_random(op_type, lower, upper):
@pytest.mark.parametrize("scalar", [True, False])
def test_truncation_continuous_random(op_type, lower, upper, scalar):
loc = 0.15
scale = 10
normal_op = icdf_normal if op_type == "icdf" else rejection_normal
x = normal_op(loc, scale, name="x", size=100)
x = normal_op(loc, scale, name="x", size=() if scalar else (100,))

xt = Truncated.dist(x, lower=lower, upper=upper)
assert isinstance(xt.owner.op, TruncatedRV)
Expand Down Expand Up @@ -134,7 +135,7 @@ def test_truncation_continuous_random(op_type, lower, upper):
assert np.unique(xt_draws).size == xt_draws.size
else:
with pytest.raises(TruncationError, match="^Truncation did not converge"):
draw(xt)
draw(xt, draws=100 if scalar else 1)


@pytest.mark.parametrize("lower, upper", [(-1, np.inf), (-1, 1.5), (-np.inf, 1.5)])
Expand Down

0 comments on commit 6b486b9

Please sign in to comment.