diff --git a/bluepyemodel/access_point/local.py b/bluepyemodel/access_point/local.py index 0eaabb96..80541b58 100644 --- a/bluepyemodel/access_point/local.py +++ b/bluepyemodel/access_point/local.py @@ -666,6 +666,7 @@ def format_emodel_data(self, model_data): passedValidation=model_data.get("validated", None), seed=model_data.get("seed", None), emodel_metadata=emodel_metadata, + threshold_data=model_data.get("threshold_data", None), ) return emodel diff --git a/bluepyemodel/emodel_pipeline/emodel.py b/bluepyemodel/emodel_pipeline/emodel.py index b7a13710..40bb6945 100644 --- a/bluepyemodel/emodel_pipeline/emodel.py +++ b/bluepyemodel/emodel_pipeline/emodel.py @@ -122,6 +122,7 @@ def __init__( emodel_metadata=None, workflow_id=None, nexus_images=None, # pylint: disable=unused-argument + threshold_data=None, ): """Init @@ -138,6 +139,10 @@ def __init__( workflow_id (str): EModelWorkflow id on nexus. nexus_images (list): list of pdfs associated to the emodel. Not used, retained for legacy purposes only. + threshold_data (dict): contains rmp, Rin, holding current and threshold current values + to avoid re-computation of related protocols. Keys must match the 'output_key' + used in respective protocols. If None, threshold-related protocols will be + re-computed each time protocols are run. """ self.emodel_metadata = emodel_metadata @@ -175,6 +180,7 @@ def __init__( self.responses = {} self.evaluator = None + self.threshold_data = threshold_data if threshold_data is not None else {} def copy_pdf_dependencies_to_new_path(self, seed, overwrite=False): """Copy pdf dependencies to new path using allen notation""" @@ -212,4 +218,5 @@ def as_dict(self): "passedValidation": self.passed_validation, "nexus_images": pdf_dependencies, "seed": self.seed, + "threshold_data": self.threshold_data, } diff --git a/bluepyemodel/evaluation/evaluation.py b/bluepyemodel/evaluation/evaluation.py index 0f07c365..d04f5842 100644 --- a/bluepyemodel/evaluation/evaluation.py +++ b/bluepyemodel/evaluation/evaluation.py @@ -27,6 +27,7 @@ from bluepyemodel.access_point import get_access_point from bluepyemodel.access_point.local import LocalAccessPoint from bluepyemodel.evaluation.evaluator import create_evaluator +from bluepyemodel.evaluation.protocols import ProtocolRunner from bluepyemodel.model import model from bluepyemodel.tools.mechanisms import compile_mechs_in_emodel_dir from bluepyemodel.tools.mechanisms import delete_compiled_mechanisms @@ -120,7 +121,7 @@ def get_responses(to_run): Args: to_run (dict): of the form - to_run = {"evaluator": CellEvaluator, "parameters": Dict} + to_run = {"evaluator": CellEvaluator, "parameters": Dict, "threshold_data": Dict} """ eva = to_run["evaluator"] @@ -128,6 +129,10 @@ def get_responses(to_run): eva.cell_model.unfreeze(params) + for prot in eva.fitness_protocols.values(): + if to_run.get("threshold_data", {}) and isinstance(prot, ProtocolRunner): + prot.threshold_data = to_run["threshold_data"] + responses = eva.run_protocols(protocols=eva.fitness_protocols.values(), param_values=params) responses["evaluator"] = eva @@ -142,6 +147,7 @@ def compute_responses( preselect_for_validation=False, store_responses=False, load_from_local=False, + recompute_threshold_protocols=False, ): """Compute the responses of the emodel to the optimisation and validation protocols. @@ -157,6 +163,8 @@ def compute_responses( only select models that have not been through validation yet. store_responses (bool): whether to locally store the responses. load_from_local (bool): True to load responses from locally saved recordings. + recompute_threshold_protocols (bool): True to re-compute rmp, rin, holding current and + threshold current even when threshold output is available. Returns: emodels (list): list of emodels. """ @@ -183,6 +191,7 @@ def compute_responses( { "evaluator": copy.deepcopy(cell_evaluator), "parameters": mo.parameters, + "threshold_data": {} if recompute_threshold_protocols else mo.threshold_data, } ) diff --git a/bluepyemodel/evaluation/protocols.py b/bluepyemodel/evaluation/protocols.py index 1712602a..95eacfad 100644 --- a/bluepyemodel/evaluation/protocols.py +++ b/bluepyemodel/evaluation/protocols.py @@ -958,6 +958,7 @@ def __init__(self, protocols, name="ProtocolRunner"): self.protocols = protocols self.execution_order = self.compute_execution_order() + self.threshold_data = {} def _add_to_execution_order(self, protocol, execution_order, before_index=None): """Recursively adds protocols to the execution order while making sure that their @@ -996,19 +997,28 @@ def run(self, cell_model, param_values, sim=None, isolate=None, timeout=None): cell_model.freeze(param_values) for protocol_name in self.execution_order: - logger.debug("Computing protocol %s", protocol_name) - new_responses = self.protocols[protocol_name].run( - cell_model, - param_values={}, - sim=sim, - isolate=isolate, - timeout=timeout, - responses=responses, - ) + prot_output_key = getattr(self.protocols[protocol_name], "output_key", None) + if prot_output_key in self.threshold_data: + logger.debug( + "Skipping protocol %s, using saved value %s", + protocol_name, + self.threshold_data[prot_output_key], + ) + new_responses = {prot_output_key, self.threshold_data[prot_output_key]} + else: + logger.debug("Computing protocol %s", protocol_name) + new_responses = self.protocols[protocol_name].run( + cell_model, + param_values={}, + sim=sim, + isolate=isolate, + timeout=timeout, + responses=responses, + ) - if new_responses is None or any(v is None for v in new_responses.values()): - logger.debug("None in responses, exiting evaluation") - break + if new_responses is None or any(v is None for v in new_responses.values()): + logger.debug("None in responses, exiting evaluation") + break responses.update(new_responses) diff --git a/bluepyemodel/validation/validation.py b/bluepyemodel/validation/validation.py index 4babaa67..9bbd4ef9 100644 --- a/bluepyemodel/validation/validation.py +++ b/bluepyemodel/validation/validation.py @@ -130,6 +130,11 @@ def validate(access_point, mapper, preselect_for_validation=False): ) ) + # save threshold computation results + model.threshold_data = { + key: output for key, output in model.responses.items() if key[:4] == "bpo_" + } + access_point.store_or_update_emodel(model) return emodels