Skip to content

Commit

Permalink
refactor mixture for speedup
Browse files Browse the repository at this point in the history
  • Loading branch information
peterhurford committed Jan 11, 2023
1 parent a7efbdb commit 283f305
Show file tree
Hide file tree
Showing 5 changed files with 49 additions and 26 deletions.
2 changes: 2 additions & 0 deletions CHANGES.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
## v0.21

* Mixture sampling is now 4-23x faster.
* You can now get the version of squigglepy via `sq.__version__`.
* Fixes a bug where the tqdm was displayed with the incorrect count when collecting cores during a multicore `sample`.

## v0.20

Expand Down
44 changes: 28 additions & 16 deletions squigglepy/samplers.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@

from scipy import stats

from .utils import (_process_weights_values, _is_dist, _simplify, _safe_len, _core_cuts,
_init_tqdm, _tick_tqdm, _flush_tqdm)
from .utils import (_process_weights_values, _is_dist, _simplify, _enlist, _safe_len,
_core_cuts, _init_tqdm, _tick_tqdm, _flush_tqdm)


_squigglepy_internal_sample_caches = {}
Expand Down Expand Up @@ -528,23 +528,35 @@ def mixture_sample(values, weights=None, relative_weights=None, samples=1, verbo
if len(values) == 1:
return sample(values[0], n=samples)

def _run_mixture(values, weights, pbar=None, tick=1):
r_ = uniform_sample(0, 1)
_tick_tqdm(pbar, tick)
for i, dist in enumerate(values):
weight = weights[i]
if r_ <= weight:
return sample(dist)
return sample(dist)
def _run_presample(dist, pbar):
if _is_dist(dist) and dist.type == 'mixture':
raise ValueError(('You cannot nest mixture distributions within ' +
'mixture distributions.'))
elif _is_dist(dist) and dist.type == 'discrete':
raise ValueError(('You cannot nest discrete distributions within ' +
'mixture distributions.'))
_tick_tqdm(pbar)
return _enlist(sample(dist, n=samples))

pbar = _init_tqdm(verbose=verbose, total=len(values))
values = [_run_presample(v, pbar) for v in values]
_flush_tqdm(pbar)

def _run_mixture(picker, i, pbar):
_tick_tqdm(pbar, _multicore_tqdm_cores)
for j, w in enumerate(weights):
if picker < w:
return values[j][i]
return values[-1][i]

weights = np.cumsum(weights)
picker = uniform_sample(0, 1, samples=samples)

tqdm_samples = samples if _multicore_tqdm_cores == 1 else _multicore_tqdm_n
pbar = _init_tqdm(verbose=verbose, total=tqdm_samples)
out = _simplify(np.array([_run_mixture(values=values,
weights=weights,
pbar=pbar,
tick=_multicore_tqdm_cores) for _ in range(samples)]))
out = _simplify([_run_mixture(p, i, pbar) for i, p in enumerate(_enlist(picker))])
_flush_tqdm(pbar)

return out


Expand Down Expand Up @@ -673,7 +685,7 @@ def multicore_sample(core, total_n=n, total_cores=cores, verbose=False):
if verbose:
print('Collecting data...')
samples = np.array([])
pbar = _init_tqdm(verbose=verbose, total=n)
pbar = _init_tqdm(verbose=verbose, total=cores)
for core in range(cores):
with open('test-core-{}.npy'.format(core), 'rb') as f:
samples = np.concatenate((samples, np.load(f, allow_pickle=True)), axis=None)
Expand Down Expand Up @@ -839,4 +851,4 @@ def run_dist(dist, pbar=None, tick=1):
print('...Cached')

# Return
return samples
return np.array(samples) if isinstance(samples, list) else samples
9 changes: 9 additions & 0 deletions squigglepy/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,15 @@ def _simplify(a):
return a


def _enlist(a):
if isinstance(a, list):
return a
elif _is_numpy(a):
return a.tolist()
else:
return [a]


def _safe_len(a):
if _is_numpy(a):
return a.size
Expand Down
2 changes: 1 addition & 1 deletion tests/integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -377,7 +377,7 @@ def move_days(days):
average = bayes.average(prior, evidence)
average_samples = sq.sample(average, n=K)
out = (np.mean(average_samples), np.std(average_samples))
if round(out[0], 2) != 2.75 and round(out[1], 2) != 0.94:
if round(out[0], 2) != 2.76 and round(out[1], 2) != 0.9:
print('ERROR 9')
import pdb
pdb.set_trace()
Expand Down
18 changes: 9 additions & 9 deletions tests/test_samplers.py
Original file line number Diff line number Diff line change
Expand Up @@ -372,13 +372,13 @@ def test_sample_discrete_indirect_mixture():
@patch.object(samplers, '_get_rng', Mock(return_value=FakeRNG()))
@patch.object(samplers, 'uniform_sample', Mock(return_value=0))
def test_mixture_sample(mocker):
assert all(mixture_sample([norm(1, 2), norm(3, 4)], [0.2, 0.8])[0] == (1.5, 0.3))
assert mixture_sample([norm(1, 2), norm(3, 4)], [0.2, 0.8]) == (1.5, 0.3)


@patch.object(samplers, '_get_rng', Mock(return_value=FakeRNG()))
@patch.object(samplers, 'uniform_sample', Mock(return_value=0))
def test_mixture_sample_alt_format(mocker):
assert all(mixture_sample([[0.2, norm(1, 2)], [0.8, norm(3, 4)]])[0] == (1.5, 0.3))
assert mixture_sample([[0.2, norm(1, 2)], [0.8, norm(3, 4)]]) == (1.5, 0.3)


@patch.object(samplers, 'normal_sample', Mock(return_value=100))
Expand All @@ -391,13 +391,13 @@ def test_mixture_sample_rclip_lclip(mocker):
@patch.object(samplers, '_get_rng', Mock(return_value=FakeRNG()))
@patch.object(samplers, 'uniform_sample', Mock(return_value=0))
def test_mixture_sample_no_weights(mocker):
assert all(mixture_sample([norm(1, 2), norm(3, 4)])[0] == (1.5, 0.3))
assert mixture_sample([norm(1, 2), norm(3, 4)]) == (1.5, 0.3)


@patch.object(samplers, '_get_rng', Mock(return_value=FakeRNG()))
@patch.object(samplers, 'uniform_sample', Mock(return_value=0))
def test_mixture_sample_different_distributions(mocker):
assert all(mixture_sample([lognorm(1, 2), norm(3, 4)])[0] == (0.35, 0.21))
assert mixture_sample([lognorm(1, 2), norm(3, 4)]) == (0.35, 0.21)


@patch.object(samplers, '_get_rng', Mock(return_value=FakeRNG()))
Expand All @@ -409,13 +409,13 @@ def test_mixture_sample_with_numbers(mocker):
@patch.object(samplers, '_get_rng', Mock(return_value=FakeRNG()))
@patch.object(samplers, 'uniform_sample', Mock(return_value=0))
def test_sample_mixture(mocker):
assert all(sample(mixture([norm(1, 2), norm(3, 4)], [0.2, 0.8]))[0] == (1.5, 0.3))
assert sample(mixture([norm(1, 2), norm(3, 4)], [0.2, 0.8])) == (1.5, 0.3)


@patch.object(samplers, '_get_rng', Mock(return_value=FakeRNG()))
@patch.object(samplers, 'uniform_sample', Mock(return_value=0))
def test_sample_mixture_alt_format(mocker):
assert all(sample(mixture([[0.2, norm(1, 2)], [0.8, norm(3, 4)]]))[0] == (1.5, 0.3))
assert sample(mixture([[0.2, norm(1, 2)], [0.8, norm(3, 4)]])) == (1.5, 0.3)


@patch.object(samplers, 'normal_sample', Mock(return_value=1))
Expand Down Expand Up @@ -495,13 +495,13 @@ def test_sample_mixture_competing_clip(mocker):
@patch.object(samplers, '_get_rng', Mock(return_value=FakeRNG()))
@patch.object(samplers, 'uniform_sample', Mock(return_value=0))
def test_sample_mixture_no_weights(mocker):
assert all(sample(mixture([norm(1, 2), norm(3, 4)]))[0] == (1.5, 0.3))
assert sample(mixture([norm(1, 2), norm(3, 4)])) == (1.5, 0.3)


@patch.object(samplers, '_get_rng', Mock(return_value=FakeRNG()))
@patch.object(samplers, 'uniform_sample', Mock(return_value=0))
def test_sample_mixture_different_distributions(mocker):
assert all(sample(mixture([lognorm(1, 2), norm(3, 4)]))[0] == (0.35, 0.21))
assert sample(mixture([lognorm(1, 2), norm(3, 4)])) == (0.35, 0.21)


@patch.object(samplers, '_get_rng', Mock(return_value=FakeRNG()))
Expand All @@ -516,7 +516,7 @@ def test_sample_mixture_can_be_discrete():
assert ~mixture([0, 1, 2]) == 0
assert ~mixture([[0.9, 'a'], [0.1, 'b']]) == 'a'
assert ~mixture({'a': 0.9, 'b': 0.1}) == 'a'
assert all((~mixture([norm(1, 2), norm(3, 4)]))[0] == (1.5, 0.3))
assert ~mixture([norm(1, 2), norm(3, 4)]) == (1.5, 0.3)


@patch.object(samplers, '_get_rng', Mock(return_value=FakeRNG()))
Expand Down

0 comments on commit 283f305

Please sign in to comment.