Skip to content

Commit

Permalink
further refactor merging.
Browse files Browse the repository at this point in the history
Now I create a list of unique batch bounds. E.g., if one run used
[-inf, inf], [-5,5], [-4,4]
and another
[-inf, inf], [-5,5], [-3,2]
the combined run will use the unique bounds [-inf, inf], [-5, 5], [-4, 4], [-3, 2]

I also add a test that checks for the issue uncovered in #481
  • Loading branch information
segasai committed Aug 29, 2024
1 parent 1c785bd commit aea5c9c
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 14 deletions.
24 changes: 15 additions & 9 deletions py/dynesty/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1946,12 +1946,17 @@ def _merge_two(res1, res2, compute_aux=False):
combined_info[curk] = []

# Check if batch info is the same and modify counters accordingly.
if np.all(base_info['bounds'] == new_info['bounds']):
bounds = base_info['bounds']
boffset = 0
else:
bounds = np.concatenate((base_info['bounds'], new_info['bounds']))
boffset = len(base_info['bounds'])
ubounds = np.unique(np.concatenate(
(base_info['bounds'], new_info['bounds'])),
axis=0)
new_bound_map = {}
base_bound_map = {}
for i in range(len(new_info['bounds'])):
new_bound_map[i] = np.where(
np.all(new_info['bounds'][i] == ubounds, axis=1))[0][0]
for i in range(len(base_info['bounds'])):
base_bound_map[i] = np.where(
np.all(base_info['bounds'][i] == ubounds, axis=1))[0][0]

# Start our counters at the beginning of each set of dead points.
idx_base, idx_new = 0, 0
Expand Down Expand Up @@ -1999,13 +2004,14 @@ def _merge_two(res1, res2, compute_aux=False):
if logl_b <= logl_n:
add_idx = idx_base
from_run = base_info
from_map = base_bound_map
idx_base += 1
combined_info['batch'].append(from_run['batch'][add_idx])
else:
add_idx = idx_new
from_run = new_info
from_map = new_bound_map
idx_new += 1
combined_info['batch'].append(from_run['batch'][add_idx] + boffset)
combined_info['batch'].append(from_map[from_run['batch'][add_idx]])

for curk in ['id', 'u', 'v', 'logl', 'nc', 'it', 'blob']:
combined_info[curk].append(from_run[curk][add_idx])
Expand Down Expand Up @@ -2051,7 +2057,7 @@ def _merge_two(res1, res2, compute_aux=False):
samples=np.asarray(combined_info['v']),
logl=np.asarray(combined_info['logl']),
logvol=np.asarray(combined_info['logvol']),
batch_bounds=np.asarray(bounds),
batch_bounds=np.asarray(ubounds),
blob=np.asarray(combined_info['blob']))

for curk in ['id', 'it', 'n', 'u', 'batch']:
Expand Down
28 changes: 23 additions & 5 deletions tests/test_gau.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,13 +164,31 @@ def test_gaussian():
assert (np.abs(logz - g.logz_truth) < sig * results.logzerr[-1])
res_comb = dyfunc.merge_runs([result_list[0]])
res_comb = dyfunc.merge_runs(result_list)
assert (np.abs(res_comb['logz'][-1] - g.logz_truth) <
sig * results['logzerr'][-1])
assert (np.abs(res_comb['logz'][-1] - g.logz_truth)
< sig * results['logzerr'][-1])
# check summary
res = sampler.results
res.summary()


def test_merge():
    """Check merging of dynamic runs with different numbers of batches.

    Runs two dynamic samplers (maxbatch=1 and maxbatch=2) on the same
    Gaussian problem and merges their results; the test added for the
    issue uncovered in #481 — merge_runs must handle runs whose batch
    bound lists differ.
    """
    rstate = get_rstate()
    g = Gaussian()
    run_results = []
    # Same rstate is threaded through both samplers, matching the
    # original sequential setup/run order exactly.
    for maxbatch in (1, 2):
        cursamp = dynesty.DynamicNestedSampler(g.loglikelihood,
                                               g.prior_transform,
                                               g.ndim,
                                               nlive=nlive,
                                               rstate=rstate)
        cursamp.run_nested(print_progress=printing, maxbatch=maxbatch)
        run_results.append(cursamp.results)
    dyfunc.merge_runs(tuple(run_results))


def test_generator():
# Test that we can use the sampler as a generator
rstate = get_rstate()
Expand Down Expand Up @@ -239,9 +257,9 @@ def test_bounding_sample(bound, sample):
print(sampler.citations)


@pytest.mark.parametrize("bound,sample",
itertools.product(
['single', 'multi', 'balls', 'cubes'], ['unif']))
@pytest.mark.parametrize(
"bound,sample",
itertools.product(['single', 'multi', 'balls', 'cubes'], ['unif']))
def test_bounding_bootstrap(bound, sample):
# check various bounding methods with bootstrap

Expand Down

0 comments on commit aea5c9c

Please sign in to comment.