diff --git a/hnn_core/batch_simulate.py b/hnn_core/batch_simulate.py index c75386a11..73a6ca0e8 100644 --- a/hnn_core/batch_simulate.py +++ b/hnn_core/batch_simulate.py @@ -5,7 +5,6 @@ # Ryan Thorpe # Mainak Jas -import json import numpy as np import os from joblib import Parallel, delayed, parallel_config @@ -14,11 +13,10 @@ from .externals.mne import _validate_type, _check_option from .dipole import simulate_dipole from .network_models import jones_2009_model -from .hnn_io import dict_to_network class BatchSimulate(object): - def __init__(self, set_params, net=jones_2009_model(), net_json=None, + def __init__(self, set_params, net=jones_2009_model(), tstop=170, dt=0.025, n_trials=1, save_folder='./sim_results', batch_size=100, overwrite=True, save_outputs=False, save_dpl=True, @@ -41,9 +39,6 @@ def __init__(self, set_params, net=jones_2009_model(), net_json=None, The network model to use for simulations. Must be an instance of jones_2009_model, law_2021_model, or calcium_model. Default is jones_2009_model(). - net_json : str, optional - The path to a JSON file to create the network model. If provided, - this will override the `net` parameter. Default is None. tstop : float, optional The stop time for the simulation. Default is 170 ms. dt : float, optional @@ -125,8 +120,6 @@ def __init__(self, set_params, net=jones_2009_model(), net_json=None, _validate_type(save_currents, types=(bool,), item_name='save_currents') _validate_type(save_calcium, types=(bool,), item_name='save_calcium') _validate_type(clear_cache, types=(bool,), item_name='clear_cache') - _validate_type(net_json, types=('path-like', None), - item_name='net_json') if set_params is not None and not callable(set_params): raise TypeError("set_params must be a callable function") @@ -154,7 +147,6 @@ def __init__(self, set_params, net=jones_2009_model(), net_json=None, self.save_currents = save_currents self.save_calcium = save_calcium self.clear_cache = clear_cache - self.net_json = net_json def run(self, param_grid, return_output=True, combinations=True, n_jobs=1, backend='loky', @@ -296,14 +288,7 @@ def _run_single_sim(self, param_values): - `param_values`: The parameter values used for the simulation. """ - if isinstance(self.net_json, str): - with open(self.net_json, 'r') as file: - net_data = json.load(file) - net = dict_to_network(net_data) - else: - net = self.net - net = net.copy() - + net = self.net.copy() self.set_params(param_values, net) results = {'net': net, 'param_values': param_values} @@ -396,6 +381,14 @@ def _save(self, results, start_idx, end_idx): if getattr(self, f'save_{attr}') and attr in results[0]: save_data[attr] = [result[attr] for result in results] + metadata = { + 'batch_size': self.batch_size, + 'n_trials': self.n_trials, + 'tstop': self.tstop, + 'dt': self.dt + } + save_data['metadata'] = metadata + file_name = os.path.join(self.save_folder, f'sim_run_{start_idx}-{end_idx}.npz') if os.path.exists(file_name) and not self.overwrite: diff --git a/hnn_core/tests/test_batch_simulate.py b/hnn_core/tests/test_batch_simulate.py index 4604ade32..aa0266dcb 100644 --- a/hnn_core/tests/test_batch_simulate.py +++ b/hnn_core/tests/test_batch_simulate.py @@ -79,9 +79,6 @@ def test_parameter_validation(): with pytest.raises(TypeError, match="net must be"): BatchSimulate(net="invalid_network", set_params=lambda x: x) - with pytest.raises(TypeError, match="net_json must be"): - BatchSimulate(net_json=123, set_params=lambda x: x) - with pytest.raises(ValueError, match="'record_vsec' parameter"): BatchSimulate(set_params=lambda x: x, record_vsec="invalid") @@ -117,21 +114,6 @@ def test_run_single_sim(batch_simulate_instance): assert isinstance(result['net'], type(batch_simulate_instance.net)) -def test_net_json_loading(param_grid): - """Test loading the network from a JSON file.""" - json_path = assets_path / 'jones2009_3x3_drives.json' - - batch_simulate = BatchSimulate(net_json=str(json_path), - set_params=lambda x, y: x, - tstop=70) - - result = batch_simulate._run_single_sim(param_grid) - assert isinstance(result, dict) - assert 'net' in result - assert 'param_values' in result - assert 'dpl' in result - - def test_simulate_batch(batch_simulate_instance, param_grid): """Test simulating a batch of parameter sets.""" param_combinations = batch_simulate_instance._generate_param_combinations(