diff --git a/pymc/distributions/truncated.py b/pymc/distributions/truncated.py index e21e0d4b3f..364af05584 100644 --- a/pymc/distributions/truncated.py +++ b/pymc/distributions/truncated.py @@ -211,10 +211,14 @@ def rv_op(cls, dist, lower, upper, max_n_steps, size=None): # Fallback to rejection sampling def loop_fn(truncated_rv, reject_draws, lower, upper, rng, *rv_inputs): next_rng, new_truncated_rv = dist.owner.op.make_node(rng, *rv_inputs).outputs - truncated_rv = pt.set_subtensor( - truncated_rv[reject_draws], - new_truncated_rv[reject_draws], - ) + # Avoid scalar boolean indexing + if truncated_rv.type.ndim == 0: + truncated_rv = new_truncated_rv + else: + truncated_rv = pt.set_subtensor( + truncated_rv[reject_draws], + new_truncated_rv[reject_draws], + ) reject_draws = pt.or_((truncated_rv < lower), (truncated_rv > upper)) return ( diff --git a/tests/distributions/test_truncated.py b/tests/distributions/test_truncated.py index f4ddfef1e9..d9d007c51f 100644 --- a/tests/distributions/test_truncated.py +++ b/tests/distributions/test_truncated.py @@ -101,11 +101,12 @@ def test_truncation_specialized_op(shape_info): @pytest.mark.parametrize("lower, upper", [(-1, np.inf), (-1, 1.5), (-np.inf, 1.5)]) @pytest.mark.parametrize("op_type", ["icdf", "rejection"]) -def test_truncation_continuous_random(op_type, lower, upper): +@pytest.mark.parametrize("scalar", [True, False]) +def test_truncation_continuous_random(op_type, lower, upper, scalar): loc = 0.15 scale = 10 normal_op = icdf_normal if op_type == "icdf" else rejection_normal - x = normal_op(loc, scale, name="x", size=100) + x = normal_op(loc, scale, name="x", size=() if scalar else (100,)) xt = Truncated.dist(x, lower=lower, upper=upper) assert isinstance(xt.owner.op, TruncatedRV) @@ -134,7 +135,7 @@ def test_truncation_continuous_random(op_type, lower, upper): assert np.unique(xt_draws).size == xt_draws.size else: with pytest.raises(TruncationError, match="^Truncation did not converge"): - draw(xt) + draw(xt, draws=100 if scalar else 1) @pytest.mark.parametrize("lower, upper", [(-1, np.inf), (-1, 1.5), (-np.inf, 1.5)])