Skip to content

Commit

Permalink
Allow WESS and WEED to detect when a simulation has been truncated (w…
Browse files Browse the repository at this point in the history
…estpa#452)

* Modified 'wess' attr to a dataset, so additional iterations can be appended

* added a check to w_truncate for wess history, which also truncates reweighting events after the cutoff iter

* Walked back changes to w_truncate and added truncate logic to wess_driver.py and weed_driver.py

* reinstated last_reweighting attribute to ensure backwards compatibility

* removed reading in numpy array
  • Loading branch information
gma57 authored Oct 25, 2024
1 parent 1af3dcf commit d1a1010
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 0 deletions.
10 changes: 10 additions & 0 deletions src/westpa/westext/weed/weed_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,15 @@ def prepare_new_iteration(self):

with self.data_manager.lock:
weed_global_group = self.data_manager.we_h5file.require_group('weed')
reweighting_history_dataset = weed_global_group.require_dataset(
'reweighting_history', (1,), maxshape=(None,), dtype=int
)
last_reweighting = int(weed_global_group.attrs.get('last_reweighting', 0))
if last_reweighting > n_iter:
last_reweighting = n_iter - 1
reweighting_history = reweighting_history_dataset[:]
reweighting_history = reweighting_history[reweighting_history < n_iter]
reweighting_history_dataset.resize((reweighting_history.size), axis=0)

if n_iter - last_reweighting < self.reweight_period:
# Not time to reweight yet
Expand Down Expand Up @@ -172,6 +180,8 @@ def prepare_new_iteration(self):
for bin, newprob in zip(bins, binprobs):
bin.reweight(newprob)

reweighting_history_dataset.resize((reweighting_history_dataset.shape[0] + 1), axis=0)
reweighting_history_dataset[-1] = n_iter
weed_global_group.attrs['last_reweighting'] = n_iter

assert (
Expand Down
10 changes: 10 additions & 0 deletions src/westpa/westext/wess/wess_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,15 @@ def prepare_new_iteration(self):

with self.data_manager.lock:
wess_global_group = self.data_manager.we_h5file.require_group('wess')
reweighting_history_dataset = wess_global_group.require_dataset(
'reweighting_history', (1,), maxshape=(None,), dtype=int
)
last_reweighting = int(wess_global_group.attrs.get('last_reweighting', 0))
if last_reweighting > n_iter:
last_reweighting = n_iter - 1
reweighting_history = reweighting_history_dataset[:]
reweighting_history = reweighting_history[reweighting_history < n_iter]
reweighting_history_dataset.resize((reweighting_history.size), axis=0)

if n_iter - last_reweighting < self.reweight_period:
# Not time to reweight yet
Expand Down Expand Up @@ -197,6 +205,8 @@ def prepare_new_iteration(self):
if len(bin):
bin.reweight(newprob)

reweighting_history_dataset.resize((reweighting_history_dataset.shape[0] + 1), axis=0)
reweighting_history_dataset[-1] = n_iter
wess_global_group.attrs['last_reweighting'] = n_iter

assert (
Expand Down

0 comments on commit d1a1010

Please sign in to comment.