Skip to content

Commit

Permalink
Don't slow down low n mixture sample - use best of both worlds!
Browse files Browse the repository at this point in the history
  • Loading branch information
peterhurford committed Jan 11, 2023
1 parent 283f305 commit c3aa9f3
Show file tree
Hide file tree
Showing 2 changed files with 76 additions and 35 deletions.
101 changes: 71 additions & 30 deletions squigglepy/samplers.py
Original file line number Diff line number Diff line change
Expand Up @@ -480,6 +480,63 @@ def discrete_sample(items, samples=1, verbose=False, _multicore_tqdm_n=1,
_multicore_tqdm_cores=_multicore_tqdm_cores)


def _mixture_sample_for_large_n(values, weights=None, relative_weights=None,
samples=1, verbose=False, _multicore_tqdm_n=1,
_multicore_tqdm_cores=1):
def _run_presample(dist, pbar):
if _is_dist(dist) and dist.type == 'mixture':
raise ValueError(('You cannot nest mixture distributions within ' +
'mixture distributions.'))
elif _is_dist(dist) and dist.type == 'discrete':
raise ValueError(('You cannot nest discrete distributions within ' +
'mixture distributions.'))
_tick_tqdm(pbar)
return _enlist(sample(dist, n=samples))

pbar = _init_tqdm(verbose=verbose, total=len(values))
values = [_run_presample(v, pbar) for v in values]
_flush_tqdm(pbar)

def _run_mixture(picker, i, pbar):
_tick_tqdm(pbar, _multicore_tqdm_cores)
for j, w in enumerate(weights):
if picker < w:
return values[j][i]
return values[-1][i]

weights = np.cumsum(weights)
picker = uniform_sample(0, 1, samples=samples)

tqdm_samples = samples if _multicore_tqdm_cores == 1 else _multicore_tqdm_n
pbar = _init_tqdm(verbose=verbose, total=tqdm_samples)
out = _simplify([_run_mixture(p, i, pbar) for i, p in enumerate(_enlist(picker))])
_flush_tqdm(pbar)

return out


def _mixture_sample_for_small_n(values, weights=None, relative_weights=None, samples=1,
verbose=False, _multicore_tqdm_n=1, _multicore_tqdm_cores=1):
def _run_mixture(values, weights, pbar=None, tick=1):
r_ = uniform_sample(0, 1)
_tick_tqdm(pbar, tick)
for i, dist in enumerate(values):
weight = weights[i]
if r_ <= weight:
return sample(dist)
return sample(dist)

weights = np.cumsum(weights)
tqdm_samples = samples if _multicore_tqdm_cores == 1 else _multicore_tqdm_n
pbar = _init_tqdm(verbose=verbose, total=tqdm_samples)
out = _simplify([_run_mixture(values=values,
weights=weights,
pbar=pbar,
tick=_multicore_tqdm_cores) for _ in range(samples)])
_flush_tqdm(pbar)
return out


def mixture_sample(values, weights=None, relative_weights=None, samples=1, verbose=False,
_multicore_tqdm_n=1, _multicore_tqdm_cores=1):
"""
Expand Down Expand Up @@ -528,36 +585,20 @@ def mixture_sample(values, weights=None, relative_weights=None, samples=1, verbo
if len(values) == 1:
return sample(values[0], n=samples)

def _run_presample(dist, pbar):
if _is_dist(dist) and dist.type == 'mixture':
raise ValueError(('You cannot nest mixture distributions within ' +
'mixture distributions.'))
elif _is_dist(dist) and dist.type == 'discrete':
raise ValueError(('You cannot nest discrete distributions within ' +
'mixture distributions.'))
_tick_tqdm(pbar)
return _enlist(sample(dist, n=samples))

pbar = _init_tqdm(verbose=verbose, total=len(values))
values = [_run_presample(v, pbar) for v in values]
_flush_tqdm(pbar)

def _run_mixture(picker, i, pbar):
_tick_tqdm(pbar, _multicore_tqdm_cores)
for j, w in enumerate(weights):
if picker < w:
return values[j][i]
return values[-1][i]

weights = np.cumsum(weights)
picker = uniform_sample(0, 1, samples=samples)

tqdm_samples = samples if _multicore_tqdm_cores == 1 else _multicore_tqdm_n
pbar = _init_tqdm(verbose=verbose, total=tqdm_samples)
out = _simplify([_run_mixture(p, i, pbar) for i, p in enumerate(_enlist(picker))])
_flush_tqdm(pbar)

return out
if samples > 100:
return _mixture_sample_for_large_n(values=values,
weights=weights,
samples=samples,
verbose=verbose,
_multicore_tqdm_n=_multicore_tqdm_n,
_multicore_tqdm_cores=_multicore_tqdm_cores)
else:
return _mixture_sample_for_small_n(values=values,
weights=weights,
samples=samples,
verbose=verbose,
_multicore_tqdm_n=_multicore_tqdm_n,
_multicore_tqdm_cores=_multicore_tqdm_cores)


def sample(dist=None, n=1, lclip=None, rclip=None, memcache=False, reload_cache=False,
Expand Down
10 changes: 5 additions & 5 deletions tests/integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -381,7 +381,7 @@ def move_days(days):
print('ERROR 9')
import pdb
pdb.set_trace()
_mark_time(start9, 0.019, 'Test 9 complete')
_mark_time(start9, 0.002, 'Test 9 complete')


print('Test 10 (ALARM NET)...')
Expand Down Expand Up @@ -513,7 +513,7 @@ def move_days(days):
print('ERROR ON 18')
import pdb
pdb.set_trace()
_mark_time(start18, 18.9, 'Test 18 complete')
_mark_time(start18, 2.66, 'Test 18 complete')


print('Test 19 (RCLIP FIDELITY, 1M SAMPLES)...')
Expand All @@ -526,7 +526,7 @@ def move_days(days):
print('ERROR ON 19')
import pdb
pdb.set_trace()
test_19_mark = _mark_time(start19, 18.9, 'Test 19 complete')
test_19_mark = _mark_time(start19, 2.78, 'Test 19 complete')


print('Test 20 (MULTICORE SAMPLE, 10M SAMPLES)...')
Expand All @@ -539,7 +539,7 @@ def move_days(days):
print('ERROR ON 20')
import pdb
pdb.set_trace()
test_20_mark = _mark_time(start20, 108.2, 'Test 20 complete')
test_20_mark = _mark_time(start20, 7.43, 'Test 20 complete')
print('1 core 10M RUNS expected {}sec'.format(round(test_19_mark['timing(sec)'] * 10, 1)))
print('7 core 10M RUNS ideal {}sec'.format(round(test_19_mark['timing(sec)'] * 10 / 7, 1)))
print('7 core 10M RUNS actual {}sec'.format(round(test_20_mark['timing(sec)'], 1)))
Expand Down Expand Up @@ -570,5 +570,5 @@ def move_days(days):
print('Squigglepy version is {}'.format(sq.__version__))

# END
_mark_time(start1, 218.4, 'Integration tests complete')
_mark_time(start1, 154.1, 'Integration tests complete')
print('DONE! INTEGRATION TEST SUCCESS!')

0 comments on commit c3aa9f3

Please sign in to comment.