Skip to content

Commit

Permalink
Allow WESS plugin to use a different binning scheme than the one used…
Browse files Browse the repository at this point in the history
… by simulation (westpa#231)

* Added an optional parameter to wess driver plugin for assigning custom bins

* allow passing a direct binmapper object in WESSDriver

* allow load_plugins to take argments that contain plugin configs directly

* Allow specifying RecursiveBinMapper in west.cfg

* removed debug message

* a bug fix in wess_driver where binmapper and bins can be inconsistent

* fixed the target state binning in wess_driver if a different binmapper was given

* fixed the segment binning for custom binmapper

* resolving conflicts

Co-authored-by: She Zhang <[email protected]>
  • Loading branch information
SHZ66 and She Zhang authored Feb 4, 2022
1 parent 30d772d commit ad3ca16
Showing 1 changed file with 30 additions and 5 deletions.
35 changes: 30 additions & 5 deletions src/westpa/westext/wess/wess_driver.py
Original file line number Diff line number Diff line change
@@ -7,6 +7,7 @@
from westpa.core.yamlcfg import check_bool
from westpa.core.kinetics import RateAverager
from westpa.westext.wess.ProbAdjust import prob_adjust
from westpa.core._rc import bins_from_yaml_dict

EPS = np.finfo(np.float64).eps

@@ -72,6 +73,11 @@ def __init__(self, sim_manager, plugin_config):
self.rate_calc_queue_size = plugin_config.get('rate_calc_queue_size', 1)
self.rate_calc_n_blocks = plugin_config.get('rate_calc_n_blocks', 1)

bin_obj = plugin_config.get('bins', None)
if isinstance(bin_obj, dict):
bin_obj = bins_from_yaml_dict(bin_obj)
self.bin_mapper = bin_obj

if self.do_reweight:
sim_manager.register_callback(sim_manager.prepare_new_iteration, self.prepare_new_iteration, self.priority)

@@ -116,8 +122,30 @@ def prepare_new_iteration(self):
else:
log.debug('reweighting')

mapper = we_driver.bin_mapper
bins = we_driver.next_iter_binning
if self.bin_mapper is None:
mapper = we_driver.bin_mapper
bins = we_driver.next_iter_binning
target_regions = list(we_driver.target_states.keys())
westpa.rc.pstatus('\nReweighting using the simulation bin mapper:\n{}'.format(mapper))
else:
mapper = self.bin_mapper
bins = mapper.construct_bins()

segments = [s for s in we_driver.next_iter_segments]
pcoords = self.system.new_pcoord_array(len(segments))
for iseg, segment in enumerate(segments):
pcoords[iseg] = segment.pcoord[0]
assignments = mapper.assign(pcoords)
for (segment, assignment) in zip(segments, assignments):
bins[assignment].add(segment)

target_states = list(we_driver.target_states.values())
target_regions = []
for tstate in target_states:
tstate_assignment = mapper.assign([tstate.pcoord])[0]
target_regions.append(tstate_assignment)
westpa.rc.pstatus('\nReweighting using a different bin mapper than simulation:\n{}'.format(mapper))

n_bins = len(bins)
westpa.rc.pstatus('Averaging rates')
averager = self.get_rates(n_iter, mapper)
@@ -148,9 +176,6 @@ def prepare_new_iteration(self):
rij, oldindex = reduce_array(averager.average_rate)
uij = averager.stderr_rate[np.ix_(oldindex, oldindex)]

# target_regions = np.where(we_driver.target_state_mask)[0]
target_regions = list(we_driver.target_states.keys())

flat_target_regions = []
for target_region in target_regions:
if target_region in oldindex: # it is possible that the target region was removed (ie if no recycling has occurred)

0 comments on commit ad3ca16

Please sign in to comment.