From c3aa9f362fce71bd9408a0a4b4432bb87b1a72bc Mon Sep 17 00:00:00 2001 From: Peter Hurford Date: Tue, 10 Jan 2023 19:54:13 -0500 Subject: [PATCH] Don't slow down low n mixture sample - use best of both worlds! --- squigglepy/samplers.py | 101 +++++++++++++++++++++++++++++------------ tests/integration.py | 10 ++-- 2 files changed, 76 insertions(+), 35 deletions(-) diff --git a/squigglepy/samplers.py b/squigglepy/samplers.py index e7c6d46..434ad95 100644 --- a/squigglepy/samplers.py +++ b/squigglepy/samplers.py @@ -480,6 +480,63 @@ def discrete_sample(items, samples=1, verbose=False, _multicore_tqdm_n=1, _multicore_tqdm_cores=_multicore_tqdm_cores) +def _mixture_sample_for_large_n(values, weights=None, relative_weights=None, + samples=1, verbose=False, _multicore_tqdm_n=1, + _multicore_tqdm_cores=1): + def _run_presample(dist, pbar): + if _is_dist(dist) and dist.type == 'mixture': + raise ValueError(('You cannot nest mixture distributions within ' + + 'mixture distributions.')) + elif _is_dist(dist) and dist.type == 'discrete': + raise ValueError(('You cannot nest discrete distributions within ' + + 'mixture distributions.')) + _tick_tqdm(pbar) + return _enlist(sample(dist, n=samples)) + + pbar = _init_tqdm(verbose=verbose, total=len(values)) + values = [_run_presample(v, pbar) for v in values] + _flush_tqdm(pbar) + + def _run_mixture(picker, i, pbar): + _tick_tqdm(pbar, _multicore_tqdm_cores) + for j, w in enumerate(weights): + if picker < w: + return values[j][i] + return values[-1][i] + + weights = np.cumsum(weights) + picker = uniform_sample(0, 1, samples=samples) + + tqdm_samples = samples if _multicore_tqdm_cores == 1 else _multicore_tqdm_n + pbar = _init_tqdm(verbose=verbose, total=tqdm_samples) + out = _simplify([_run_mixture(p, i, pbar) for i, p in enumerate(_enlist(picker))]) + _flush_tqdm(pbar) + + return out + + +def _mixture_sample_for_small_n(values, weights=None, relative_weights=None, samples=1, + verbose=False, _multicore_tqdm_n=1, _multicore_tqdm_cores=1): + def _run_mixture(values, weights, pbar=None, tick=1): + r_ = uniform_sample(0, 1) + _tick_tqdm(pbar, tick) + for i, dist in enumerate(values): + weight = weights[i] + if r_ <= weight: + return sample(dist) + return sample(dist) + + weights = np.cumsum(weights) + tqdm_samples = samples if _multicore_tqdm_cores == 1 else _multicore_tqdm_n + pbar = _init_tqdm(verbose=verbose, total=tqdm_samples) + out = _simplify([_run_mixture(values=values, + weights=weights, + pbar=pbar, + tick=_multicore_tqdm_cores) for _ in range(samples)]) + _flush_tqdm(pbar) + return out + + def mixture_sample(values, weights=None, relative_weights=None, samples=1, verbose=False, _multicore_tqdm_n=1, _multicore_tqdm_cores=1): """ @@ -528,36 +585,20 @@ def mixture_sample(values, weights=None, relative_weights=None, samples=1, verbo if len(values) == 1: return sample(values[0], n=samples) - def _run_presample(dist, pbar): - if _is_dist(dist) and dist.type == 'mixture': - raise ValueError(('You cannot nest mixture distributions within ' + - 'mixture distributions.')) - elif _is_dist(dist) and dist.type == 'discrete': - raise ValueError(('You cannot nest discrete distributions within ' + - 'mixture distributions.')) - _tick_tqdm(pbar) - return _enlist(sample(dist, n=samples)) - - pbar = _init_tqdm(verbose=verbose, total=len(values)) - values = [_run_presample(v, pbar) for v in values] - _flush_tqdm(pbar) - - def _run_mixture(picker, i, pbar): - _tick_tqdm(pbar, _multicore_tqdm_cores) - for j, w in enumerate(weights): - if picker < w: - return values[j][i] - return values[-1][i] - - weights = np.cumsum(weights) - picker = uniform_sample(0, 1, samples=samples) - - tqdm_samples = samples if _multicore_tqdm_cores == 1 else _multicore_tqdm_n - pbar = _init_tqdm(verbose=verbose, total=tqdm_samples) - out = _simplify([_run_mixture(p, i, pbar) for i, p in enumerate(_enlist(picker))]) - _flush_tqdm(pbar) - - return out + if samples > 100: + return _mixture_sample_for_large_n(values=values, + weights=weights, + samples=samples, + verbose=verbose, + _multicore_tqdm_n=_multicore_tqdm_n, + _multicore_tqdm_cores=_multicore_tqdm_cores) + else: + return _mixture_sample_for_small_n(values=values, + weights=weights, + samples=samples, + verbose=verbose, + _multicore_tqdm_n=_multicore_tqdm_n, + _multicore_tqdm_cores=_multicore_tqdm_cores) def sample(dist=None, n=1, lclip=None, rclip=None, memcache=False, reload_cache=False, diff --git a/tests/integration.py b/tests/integration.py index e531a9e..58ed1e1 100644 --- a/tests/integration.py +++ b/tests/integration.py @@ -381,7 +381,7 @@ def move_days(days): print('ERROR 9') import pdb pdb.set_trace() - _mark_time(start9, 0.019, 'Test 9 complete') + _mark_time(start9, 0.002, 'Test 9 complete') print('Test 10 (ALARM NET)...') @@ -513,7 +513,7 @@ def move_days(days): print('ERROR ON 18') import pdb pdb.set_trace() - _mark_time(start18, 18.9, 'Test 18 complete') + _mark_time(start18, 2.66, 'Test 18 complete') print('Test 19 (RCLIP FIDELITY, 1M SAMPLES)...') @@ -526,7 +526,7 @@ def move_days(days): print('ERROR ON 19') import pdb pdb.set_trace() - test_19_mark = _mark_time(start19, 18.9, 'Test 19 complete') + test_19_mark = _mark_time(start19, 2.78, 'Test 19 complete') print('Test 20 (MULTICORE SAMPLE, 10M SAMPLES)...') @@ -539,7 +539,7 @@ def move_days(days): print('ERROR ON 20') import pdb pdb.set_trace() - test_20_mark = _mark_time(start20, 108.2, 'Test 20 complete') + test_20_mark = _mark_time(start20, 7.43, 'Test 20 complete') print('1 core 10M RUNS expected {}sec'.format(round(test_19_mark['timing(sec)'] * 10, 1))) print('7 core 10M RUNS ideal {}sec'.format(round(test_19_mark['timing(sec)'] * 10 / 7, 1))) print('7 core 10M RUNS actual {}sec'.format(round(test_20_mark['timing(sec)'], 1))) @@ -570,5 +570,5 @@ def move_days(days): print('Squigglepy version is {}'.format(sq.__version__)) # END - _mark_time(start1, 218.4, 'Integration tests complete') + _mark_time(start1, 154.1, 'Integration tests complete') print('DONE! INTEGRATION TEST SUCCESS!')