Skip to content

Commit

Permalink
Merge pull request #181 from BlueBrain/store-threshold-data
Browse files Browse the repository at this point in the history
store and re-use threshold current data
  • Loading branch information
AurelienJaquier authored Jan 13, 2025
2 parents 626f780 + 098dd56 commit e127613
Show file tree
Hide file tree
Showing 5 changed files with 45 additions and 13 deletions.
1 change: 1 addition & 0 deletions bluepyemodel/access_point/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -666,6 +666,7 @@ def format_emodel_data(self, model_data):
passedValidation=model_data.get("validated", None),
seed=model_data.get("seed", None),
emodel_metadata=emodel_metadata,
threshold_data=model_data.get("threshold_data", None),
)

return emodel
Expand Down
7 changes: 7 additions & 0 deletions bluepyemodel/emodel_pipeline/emodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,7 @@ def __init__(
emodel_metadata=None,
workflow_id=None,
nexus_images=None, # pylint: disable=unused-argument
threshold_data=None,
):
"""Init
Expand All @@ -138,6 +139,10 @@ def __init__(
workflow_id (str): EModelWorkflow id on nexus.
nexus_images (list): list of pdfs associated to the emodel.
Not used, retained for legacy purposes only.
threshold_data (dict): contains rmp, Rin, holding current and threshold current values
to avoid re-computation of related protocols. Keys must match the 'output_key'
used in respective protocols. If None, threshold-related protocols will be
re-computed each time protocols are run.
"""

self.emodel_metadata = emodel_metadata
Expand Down Expand Up @@ -175,6 +180,7 @@ def __init__(

self.responses = {}
self.evaluator = None
self.threshold_data = threshold_data if threshold_data is not None else {}

def copy_pdf_dependencies_to_new_path(self, seed, overwrite=False):
"""Copy pdf dependencies to new path using allen notation"""
Expand Down Expand Up @@ -212,4 +218,5 @@ def as_dict(self):
"passedValidation": self.passed_validation,
"nexus_images": pdf_dependencies,
"seed": self.seed,
"threshold_data": self.threshold_data,
}
11 changes: 10 additions & 1 deletion bluepyemodel/evaluation/evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from bluepyemodel.access_point import get_access_point
from bluepyemodel.access_point.local import LocalAccessPoint
from bluepyemodel.evaluation.evaluator import create_evaluator
from bluepyemodel.evaluation.protocols import ProtocolRunner
from bluepyemodel.model import model
from bluepyemodel.tools.mechanisms import compile_mechs_in_emodel_dir
from bluepyemodel.tools.mechanisms import delete_compiled_mechanisms
Expand Down Expand Up @@ -120,14 +121,18 @@ def get_responses(to_run):
Args:
to_run (dict): of the form
to_run = {"evaluator": CellEvaluator, "parameters": Dict}
to_run = {"evaluator": CellEvaluator, "parameters": Dict, "threshold_data": Dict}
"""

eva = to_run["evaluator"]
params = to_run["parameters"]

eva.cell_model.unfreeze(params)

for prot in eva.fitness_protocols.values():
if to_run.get("threshold_data", {}) and isinstance(prot, ProtocolRunner):
prot.threshold_data = to_run["threshold_data"]

responses = eva.run_protocols(protocols=eva.fitness_protocols.values(), param_values=params)
responses["evaluator"] = eva

Expand All @@ -142,6 +147,7 @@ def compute_responses(
preselect_for_validation=False,
store_responses=False,
load_from_local=False,
recompute_threshold_protocols=False,
):
"""Compute the responses of the emodel to the optimisation and validation protocols.
Expand All @@ -157,6 +163,8 @@ def compute_responses(
only select models that have not been through validation yet.
store_responses (bool): whether to locally store the responses.
load_from_local (bool): True to load responses from locally saved recordings.
recompute_threshold_protocols (bool): True to re-compute rmp, rin, holding current and
threshold current even when threshold output is available.
Returns:
emodels (list): list of emodels.
"""
Expand All @@ -183,6 +191,7 @@ def compute_responses(
{
"evaluator": copy.deepcopy(cell_evaluator),
"parameters": mo.parameters,
"threshold_data": {} if recompute_threshold_protocols else mo.threshold_data,
}
)

Expand Down
34 changes: 22 additions & 12 deletions bluepyemodel/evaluation/protocols.py
Original file line number Diff line number Diff line change
Expand Up @@ -958,6 +958,7 @@ def __init__(self, protocols, name="ProtocolRunner"):

self.protocols = protocols
self.execution_order = self.compute_execution_order()
self.threshold_data = {}

def _add_to_execution_order(self, protocol, execution_order, before_index=None):
"""Recursively adds protocols to the execution order while making sure that their
Expand Down Expand Up @@ -996,19 +997,28 @@ def run(self, cell_model, param_values, sim=None, isolate=None, timeout=None):
cell_model.freeze(param_values)

for protocol_name in self.execution_order:
logger.debug("Computing protocol %s", protocol_name)
new_responses = self.protocols[protocol_name].run(
cell_model,
param_values={},
sim=sim,
isolate=isolate,
timeout=timeout,
responses=responses,
)
prot_output_key = getattr(self.protocols[protocol_name], "output_key", None)
if prot_output_key in self.threshold_data:
logger.debug(
"Skipping protocol %s, using saved value %s",
protocol_name,
self.threshold_data[prot_output_key],
)
new_responses = {prot_output_key, self.threshold_data[prot_output_key]}
else:
logger.debug("Computing protocol %s", protocol_name)
new_responses = self.protocols[protocol_name].run(
cell_model,
param_values={},
sim=sim,
isolate=isolate,
timeout=timeout,
responses=responses,
)

if new_responses is None or any(v is None for v in new_responses.values()):
logger.debug("None in responses, exiting evaluation")
break
if new_responses is None or any(v is None for v in new_responses.values()):
logger.debug("None in responses, exiting evaluation")
break

responses.update(new_responses)

Expand Down
5 changes: 5 additions & 0 deletions bluepyemodel/validation/validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,11 @@ def validate(access_point, mapper, preselect_for_validation=False):
)
)

# save threshold computation results
model.threshold_data = {
key: output for key, output in model.responses.items() if key[:4] == "bpo_"
}

access_point.store_or_update_emodel(model)

return emodels

0 comments on commit e127613

Please sign in to comment.