Skip to content

Commit

Permalink
test: added test class for CellResponse plotting methods
Browse files Browse the repository at this point in the history
So far only added test for the spikes raster plot.
  • Loading branch information
gtdang committed Oct 11, 2024
1 parent fb13b78 commit 9717ae7
Showing 1 changed file with 89 additions and 0 deletions.
89 changes: 89 additions & 0 deletions hnn_core/tests/test_viz.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,6 +247,95 @@ def test_dipole_visualization(setup_net):
plt.close('all')


class TestCellResponsePlotters:

@pytest.fixture(scope='class')
def class_setup_net(self):
hnn_core_root = op.dirname(hnn_core.__file__)
params_fname = op.join(hnn_core_root, 'param', 'default.json')
params = read_params(params_fname)
net = jones_2009_model(params, mesh_shape=(3, 3))

return net

@pytest.fixture(scope='class')
def base_simulation_spikes(self, class_setup_net):
net = class_setup_net
weights_ampa = {'L2_pyramidal': 0.1, 'L5_pyramidal': 1.}
syn_delays = {'L2_pyramidal': 0.1, 'L5_pyramidal': 1.}
net.add_bursty_drive(
'beta_prox', tstart=0., burst_rate=25, burst_std=5,
numspikes=1, spike_isi=0, n_drive_cells=11, location='proximal',
weights_ampa=weights_ampa, synaptic_delays=syn_delays,
event_seed=14)

net.add_bursty_drive(
'beta_dist', tstart=0., burst_rate=25, burst_std=5,
numspikes=1, spike_isi=0, n_drive_cells=11, location='distal',
weights_ampa=weights_ampa, synaptic_delays=syn_delays,
event_seed=14)
dpls = simulate_dipole(net, tstop=100., n_trials=2, record_vsec='all')

return net, dpls

def test_spikes_raster_trial_idx(self, base_simulation_spikes):
net, _ = base_simulation_spikes

# Bad index argument raises error
with pytest.raises(TypeError,
match="trial_idx must be an instance of"):
net.cell_response.plot_spikes_raster(trial_idx='blah', show=False)

# Test valid index arguments
for index_arg in (0, [0, 1]):
fig = net.cell_response.plot_spikes_raster(trial_idx=index_arg,
show=False)
# Check that collections contain data
assert (all([collection.get_positions() != [-1]
for collection in fig.axes[0].collections]),
"No data plotted in raster plot")

def test_spikes_raster_colors(self, base_simulation_spikes):
net, _ = base_simulation_spikes

def _get_line_hex_colors(fig):
colors = [matplotlib.colors.to_hex(line.get_color())
for line in fig.axes[0].legend_.get_lines()]
return colors

# Default colors should be the default color cycle
fig = net.cell_response.plot_spikes_raster(trial_idx=0, show=False)
colors = _get_line_hex_colors(fig)
default_colors = (plt.rcParams['axes.prop_cycle']
.by_key()['color'][0:len(colors)])
assert colors == default_colors

# Custom hex colors
custom_colors = ['#daf7a6', '#ffc300', '#ff5733', '#c70039']
fig = net.cell_response.plot_spikes_raster(trial_idx=0, show=False,
colors=custom_colors)
colors = _get_line_hex_colors(fig)
assert colors == custom_colors

# Custom named colors
custom_colors = ['skyblue', 'maroon', 'gold', 'hotpink']
color_map = matplotlib.colors.get_named_colors_mapping()
fig = net.cell_response.plot_spikes_raster(trial_idx=0, show=False,
colors=custom_colors)
colors = _get_line_hex_colors(fig)
assert colors == [color_map[color].lower() for color in custom_colors]

# Incorrect number of colors
too_few = ['r', 'g', 'b']
too_many = ['r', 'g', 'b', 'y', 'k']
for colors in [too_few, too_many]:
with pytest.raises(ValueError,
match='Number of colors must be equal to'):
net.cell_response.plot_spikes_raster(trial_idx=0,
show=False,
colors=colors)


def test_network_plotter_init(setup_net):
"""Test init keywords of NetworkPlotter class."""
net = setup_net
Expand Down

0 comments on commit 9717ae7

Please sign in to comment.