Skip to content

Commit

Permalink
Removed net_json param and update test
Browse files Browse the repository at this point in the history
Signed-off-by: samadpls <[email protected]>
  • Loading branch information
samadpls committed Aug 17, 2024
1 parent 55c5180 commit d6a023a
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 35 deletions.
27 changes: 10 additions & 17 deletions hnn_core/batch_simulate.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
# Ryan Thorpe <[email protected]>
# Mainak Jas <[email protected]>

import json
import numpy as np
import os
from joblib import Parallel, delayed, parallel_config
Expand All @@ -14,11 +13,10 @@
from .externals.mne import _validate_type, _check_option
from .dipole import simulate_dipole
from .network_models import jones_2009_model
from .hnn_io import dict_to_network


class BatchSimulate(object):
def __init__(self, set_params, net=jones_2009_model(), net_json=None,
def __init__(self, set_params, net=jones_2009_model(),
tstop=170, dt=0.025, n_trials=1,
save_folder='./sim_results', batch_size=100,
overwrite=True, save_outputs=False, save_dpl=True,
Expand All @@ -41,9 +39,6 @@ def __init__(self, set_params, net=jones_2009_model(), net_json=None,
The network model to use for simulations. Must be an instance of
jones_2009_model, law_2021_model, or calcium_model.
Default is jones_2009_model().
net_json : str, optional
The path to a JSON file to create the network model. If provided,
this will override the `net` parameter. Default is None.
tstop : float, optional
The stop time for the simulation. Default is 170 ms.
dt : float, optional
Expand Down Expand Up @@ -125,8 +120,6 @@ def __init__(self, set_params, net=jones_2009_model(), net_json=None,
_validate_type(save_currents, types=(bool,), item_name='save_currents')
_validate_type(save_calcium, types=(bool,), item_name='save_calcium')
_validate_type(clear_cache, types=(bool,), item_name='clear_cache')
_validate_type(net_json, types=('path-like', None),
item_name='net_json')

if set_params is not None and not callable(set_params):
raise TypeError("set_params must be a callable function")
Expand Down Expand Up @@ -154,7 +147,6 @@ def __init__(self, set_params, net=jones_2009_model(), net_json=None,
self.save_currents = save_currents
self.save_calcium = save_calcium
self.clear_cache = clear_cache
self.net_json = net_json

def run(self, param_grid, return_output=True,
combinations=True, n_jobs=1, backend='loky',
Expand Down Expand Up @@ -296,14 +288,7 @@ def _run_single_sim(self, param_values):
- `param_values`: The parameter values used for the simulation.
"""

if isinstance(self.net_json, str):
with open(self.net_json, 'r') as file:
net_data = json.load(file)
net = dict_to_network(net_data)
else:
net = self.net
net = net.copy()

net = self.net.copy()
self.set_params(param_values, net)

results = {'net': net, 'param_values': param_values}
Expand Down Expand Up @@ -396,6 +381,14 @@ def _save(self, results, start_idx, end_idx):
if getattr(self, f'save_{attr}') and attr in results[0]:
save_data[attr] = [result[attr] for result in results]

metadata = {
'batch_size': self.batch_size,
'n_trials': self.n_trials,
'tstop': self.tstop,
'dt': self.dt
}
save_data['metadata'] = metadata

file_name = os.path.join(self.save_folder,
f'sim_run_{start_idx}-{end_idx}.npz')
if os.path.exists(file_name) and not self.overwrite:
Expand Down
18 changes: 0 additions & 18 deletions hnn_core/tests/test_batch_simulate.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,9 +79,6 @@ def test_parameter_validation():
with pytest.raises(TypeError, match="net must be"):
BatchSimulate(net="invalid_network", set_params=lambda x: x)

with pytest.raises(TypeError, match="net_json must be"):
BatchSimulate(net_json=123, set_params=lambda x: x)

with pytest.raises(ValueError, match="'record_vsec' parameter"):
BatchSimulate(set_params=lambda x: x, record_vsec="invalid")

Expand Down Expand Up @@ -117,21 +114,6 @@ def test_run_single_sim(batch_simulate_instance):
assert isinstance(result['net'], type(batch_simulate_instance.net))


def test_net_json_loading(param_grid):
"""Test loading the network from a JSON file."""
json_path = assets_path / 'jones2009_3x3_drives.json'

batch_simulate = BatchSimulate(net_json=str(json_path),
set_params=lambda x, y: x,
tstop=70)

result = batch_simulate._run_single_sim(param_grid)
assert isinstance(result, dict)
assert 'net' in result
assert 'param_values' in result
assert 'dpl' in result


def test_simulate_batch(batch_simulate_instance, param_grid):
"""Test simulating a batch of parameter sets."""
param_combinations = batch_simulate_instance._generate_param_combinations(
Expand Down

0 comments on commit d6a023a

Please sign in to comment.