From e7fcf8de52aa4fbd908994ea1aa33eb639d728b0 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Aur=C3=A9lien?= <aurelien@MacBook-Pro-de-Aurelien.local>
Date: Fri, 10 Jan 2025 16:50:51 +0100
Subject: [PATCH] store and re-use threshold current data

---
 bluepyemodel/access_point/local.py     |  1 +
 bluepyemodel/emodel_pipeline/emodel.py |  7 ++++++
 bluepyemodel/evaluation/evaluation.py  | 11 ++++++++-
 bluepyemodel/evaluation/protocols.py   | 34 +++++++++++++++++---------
 bluepyemodel/validation/validation.py  |  5 ++++
 5 files changed, 45 insertions(+), 13 deletions(-)

diff --git a/bluepyemodel/access_point/local.py b/bluepyemodel/access_point/local.py
index 0eaabb96..80541b58 100644
--- a/bluepyemodel/access_point/local.py
+++ b/bluepyemodel/access_point/local.py
@@ -666,6 +666,7 @@ def format_emodel_data(self, model_data):
             passedValidation=model_data.get("validated", None),
             seed=model_data.get("seed", None),
             emodel_metadata=emodel_metadata,
+            threshold_data=model_data.get("threshold_data", None),
         )
 
         return emodel
diff --git a/bluepyemodel/emodel_pipeline/emodel.py b/bluepyemodel/emodel_pipeline/emodel.py
index b7a13710..40bb6945 100644
--- a/bluepyemodel/emodel_pipeline/emodel.py
+++ b/bluepyemodel/emodel_pipeline/emodel.py
@@ -122,6 +122,7 @@ def __init__(
         emodel_metadata=None,
         workflow_id=None,
         nexus_images=None,  # pylint: disable=unused-argument
+        threshold_data=None,
     ):
         """Init
 
@@ -138,6 +139,10 @@ def __init__(
             workflow_id (str): EModelWorkflow id on nexus.
             nexus_images (list): list of pdfs associated to the emodel.
                 Not used, retained for legacy purposes only.
+            threshold_data (dict): contains rmp, Rin, holding current and threshold current values
+                to avoid re-computation of related protocols. Keys must match the 'output_key'
+                used in respective protocols. If None, threshold-related protocols will be
+                re-computed each time protocols are run.
         """
 
         self.emodel_metadata = emodel_metadata
@@ -175,6 +180,7 @@ def __init__(
 
         self.responses = {}
         self.evaluator = None
+        self.threshold_data = threshold_data if threshold_data is not None else {}
 
     def copy_pdf_dependencies_to_new_path(self, seed, overwrite=False):
         """Copy pdf dependencies to new path using allen notation"""
@@ -212,4 +218,5 @@ def as_dict(self):
             "passedValidation": self.passed_validation,
             "nexus_images": pdf_dependencies,
             "seed": self.seed,
+            "threshold_data": self.threshold_data,
         }
diff --git a/bluepyemodel/evaluation/evaluation.py b/bluepyemodel/evaluation/evaluation.py
index 0f07c365..d04f5842 100644
--- a/bluepyemodel/evaluation/evaluation.py
+++ b/bluepyemodel/evaluation/evaluation.py
@@ -27,6 +27,7 @@
 from bluepyemodel.access_point import get_access_point
 from bluepyemodel.access_point.local import LocalAccessPoint
 from bluepyemodel.evaluation.evaluator import create_evaluator
+from bluepyemodel.evaluation.protocols import ProtocolRunner
 from bluepyemodel.model import model
 from bluepyemodel.tools.mechanisms import compile_mechs_in_emodel_dir
 from bluepyemodel.tools.mechanisms import delete_compiled_mechanisms
@@ -120,7 +121,7 @@ def get_responses(to_run):
 
     Args:
         to_run (dict): of the form
-            to_run = {"evaluator": CellEvaluator, "parameters": Dict}
+            to_run = {"evaluator": CellEvaluator, "parameters": Dict, "threshold_data": Dict}
     """
 
     eva = to_run["evaluator"]
@@ -128,6 +129,10 @@ def get_responses(to_run):
 
     eva.cell_model.unfreeze(params)
 
+    for prot in eva.fitness_protocols.values():
+        if to_run.get("threshold_data", {}) and isinstance(prot, ProtocolRunner):
+            prot.threshold_data = to_run["threshold_data"]
+
     responses = eva.run_protocols(protocols=eva.fitness_protocols.values(), param_values=params)
     responses["evaluator"] = eva
 
@@ -142,6 +147,7 @@ def compute_responses(
     preselect_for_validation=False,
     store_responses=False,
     load_from_local=False,
+    recompute_threshold_protocols=False,
 ):
     """Compute the responses of the emodel to the optimisation and validation protocols.
 
@@ -157,6 +163,8 @@ def compute_responses(
             only select models that have not been through validation yet.
         store_responses (bool): whether to locally store the responses.
         load_from_local (bool): True to load responses from locally saved recordings.
+        recompute_threshold_protocols (bool): True to re-compute rmp, rin, holding current and
+            threshold current even when threshold output is available.
     Returns:
         emodels (list): list of emodels.
     """
@@ -183,6 +191,7 @@ def compute_responses(
                 {
                     "evaluator": copy.deepcopy(cell_evaluator),
                     "parameters": mo.parameters,
+                    "threshold_data": {} if recompute_threshold_protocols else mo.threshold_data,
                 }
             )
 
diff --git a/bluepyemodel/evaluation/protocols.py b/bluepyemodel/evaluation/protocols.py
index 1712602a..23984cd7 100644
--- a/bluepyemodel/evaluation/protocols.py
+++ b/bluepyemodel/evaluation/protocols.py
@@ -958,6 +958,7 @@ def __init__(self, protocols, name="ProtocolRunner"):
 
         self.protocols = protocols
         self.execution_order = self.compute_execution_order()
+        self.threshold_data = {}
 
     def _add_to_execution_order(self, protocol, execution_order, before_index=None):
         """Recursively adds protocols to the execution order while making sure that their
@@ -996,19 +997,28 @@ def run(self, cell_model, param_values, sim=None, isolate=None, timeout=None):
         cell_model.freeze(param_values)
 
         for protocol_name in self.execution_order:
-            logger.debug("Computing protocol %s", protocol_name)
-            new_responses = self.protocols[protocol_name].run(
-                cell_model,
-                param_values={},
-                sim=sim,
-                isolate=isolate,
-                timeout=timeout,
-                responses=responses,
-            )
+            prot_output_key = getattr(self.protocols[protocol_name], "output_key", None)
+            if prot_output_key in self.threshold_data:
+                logger.debug(
+                    "Skipping protocol %s, using saved value %s",
+                    protocol_name,
+                    self.threshold_data[prot_output_key]
+                )
+                new_responses = {prot_output_key, self.threshold_data[prot_output_key]}
+            else:
+                logger.debug("Computing protocol %s", protocol_name)
+                new_responses = self.protocols[protocol_name].run(
+                    cell_model,
+                    param_values={},
+                    sim=sim,
+                    isolate=isolate,
+                    timeout=timeout,
+                    responses=responses,
+                )
 
-            if new_responses is None or any(v is None for v in new_responses.values()):
-                logger.debug("None in responses, exiting evaluation")
-                break
+                if new_responses is None or any(v is None for v in new_responses.values()):
+                    logger.debug("None in responses, exiting evaluation")
+                    break
 
             responses.update(new_responses)
 
diff --git a/bluepyemodel/validation/validation.py b/bluepyemodel/validation/validation.py
index 4babaa67..9bbd4ef9 100644
--- a/bluepyemodel/validation/validation.py
+++ b/bluepyemodel/validation/validation.py
@@ -130,6 +130,11 @@ def validate(access_point, mapper, preselect_for_validation=False):
             )
         )
 
+        # save threshold computation results
+        model.threshold_data = {
+            key: output for key, output in model.responses.items() if key[:4] == "bpo_"
+        }
+
         access_point.store_or_update_emodel(model)
 
     return emodels