diff --git a/hnn_core/cell_response.py b/hnn_core/cell_response.py index 248eca559..d84cf8376 100644 --- a/hnn_core/cell_response.py +++ b/hnn_core/cell_response.py @@ -273,7 +273,8 @@ def mean_rates(self, tstart, tstop, gid_ranges, mean_type='all'): return spike_rates - def plot_spikes_raster(self, trial_idx=None, ax=None, show=True): + def plot_spikes_raster(self, trial_idx=None, ax=None, show=True, + colors=None): """Plot the aggregate spiking activity according to cell type. Parameters @@ -284,6 +285,8 @@ def plot_spikes_raster(self, trial_idx=None, ax=None, show=True): An axis object from matplotlib. If None, a new figure is created. show : bool If True, show the figure. + colors: list of str | None + Optional custom colors to plot. Default will use the color cycler. Returns ------- @@ -291,7 +294,8 @@ def plot_spikes_raster(self, trial_idx=None, ax=None, show=True): The matplotlib figure object. """ return plot_spikes_raster( - cell_response=self, trial_idx=trial_idx, ax=ax, show=show) + cell_response=self, trial_idx=trial_idx, ax=ax, show=show, + colors=colors) def plot_spikes_hist(self, trial_idx=None, ax=None, spike_types=None, color=None, show=True, **kwargs_hist): diff --git a/hnn_core/tests/test_viz.py b/hnn_core/tests/test_viz.py index 081f2416e..5ae229fd8 100644 --- a/hnn_core/tests/test_viz.py +++ b/hnn_core/tests/test_viz.py @@ -19,6 +19,13 @@ matplotlib.use('agg') +@pytest.fixture(autouse=True) +def cleanup_matplotlib(): + # Code runs after the test finishes + yield + plt.close('all') + + @pytest.fixture def setup_net(): hnn_core_root = op.dirname(hnn_core.__file__) @@ -240,6 +247,122 @@ def test_dipole_visualization(setup_net): plt.close('all') +class TestCellResponsePlotters: + """Tests plotting methods of the CellResponse class""" + @pytest.fixture(scope='class') + def class_setup_net(self): + hnn_core_root = op.dirname(hnn_core.__file__) + params_fname = op.join(hnn_core_root, 'param', 'default.json') + params = read_params(params_fname) + net = jones_2009_model(params, mesh_shape=(3, 3)) + + return net + + @pytest.fixture(scope='class') + def base_simulation_spikes(self, class_setup_net): + net = class_setup_net + weights_ampa = {'L2_pyramidal': 0.1, 'L5_pyramidal': 1.} + syn_delays = {'L2_pyramidal': 0.1, 'L5_pyramidal': 1.} + net.add_bursty_drive( + 'beta_prox', tstart=0., burst_rate=25, burst_std=5, + numspikes=1, spike_isi=0, n_drive_cells=11, location='proximal', + weights_ampa=weights_ampa, synaptic_delays=syn_delays, + event_seed=14) + + net.add_bursty_drive( + 'beta_dist', tstart=0., burst_rate=25, burst_std=5, + numspikes=1, spike_isi=0, n_drive_cells=11, location='distal', + weights_ampa=weights_ampa, synaptic_delays=syn_delays, + event_seed=14) + dpls = simulate_dipole(net, tstop=100., n_trials=2, record_vsec='all') + + return net, dpls + + def test_spikes_raster_trial_idx(self, base_simulation_spikes): + """Plotting with different index arguments""" + net, _ = base_simulation_spikes + + # Bad index argument raises error + with pytest.raises(TypeError, + match="trial_idx must be an instance of"): + net.cell_response.plot_spikes_raster(trial_idx='blah', show=False) + + # Test valid index arguments + for index_arg in (0, [0, 1]): + fig = net.cell_response.plot_spikes_raster(trial_idx=index_arg, + show=False) + # Check that collections contain data + assert all( + [collection.get_positions() != [-1] + for collection in fig.axes[0].collections] + ), "No data plotted in raster plot" + + def test_spikes_raster_colors(self, base_simulation_spikes): + """Plotting with different color arguments""" + net, _ = base_simulation_spikes + + def _get_line_hex_colors(fig): + colors = [matplotlib.colors.to_hex(line.get_color()) + for line in fig.axes[0].legend_.get_lines()] + labels = [text.get_text() + for text in fig.axes[0].legend_.get_texts()] + return colors, labels + + # Default colors should be the default color cycle + fig = net.cell_response.plot_spikes_raster(trial_idx=0, show=False) + colors, _ = _get_line_hex_colors(fig) + default_colors = (plt.rcParams['axes.prop_cycle'] + .by_key()['color'][0:len(colors)]) + assert colors == default_colors + + # Custom hex colors as list + custom_colors = ['#daf7a6', '#ffc300', '#ff5733', '#c70039'] + fig = net.cell_response.plot_spikes_raster(trial_idx=0, show=False, + colors=custom_colors) + colors, _ = _get_line_hex_colors(fig) + assert colors == custom_colors + + # Custom named colors as list + custom_colors = ['skyblue', 'maroon', 'gold', 'hotpink'] + color_map = matplotlib.colors.get_named_colors_mapping() + fig = net.cell_response.plot_spikes_raster(trial_idx=0, show=False, + colors=custom_colors) + colors, _ = _get_line_hex_colors(fig) + assert colors == [color_map[color].lower() for color in custom_colors] + + # Incorrect number of colors as list + too_few = ['r', 'g', 'b'] + too_many = ['r', 'g', 'b', 'y', 'k'] + for colors in [too_few, too_many]: + with pytest.raises(ValueError, + match='Number of colors must be equal to'): + net.cell_response.plot_spikes_raster(trial_idx=0, + show=False, + colors=colors) + + # Colors as dict mapping + dict_mapping = {'L2_basket': '#daf7a6', 'L2_pyramidal': '#ffc300', + 'L5_basket': '#ff5733', 'L5_pyramidal': '#c70039'} + fig = net.cell_response.plot_spikes_raster(trial_idx=0, show=False, + colors=dict_mapping) + colors, _ = _get_line_hex_colors(fig) + assert colors == list(dict_mapping.values()) + + # Change color of only one cell type + dict_mapping = {'L2_pyramidal': '#daf7a6'} + fig = net.cell_response.plot_spikes_raster(trial_idx=0, show=False, + colors=dict_mapping) + colors, cell_types = _get_line_hex_colors(fig) + assert colors[cell_types.index('L2_pyramidal')] == '#daf7a6' + + # Invalid key in dict mapping + dict_mapping = {'bad_cell_type': '#daf7a6'} + with pytest.raises(ValueError, + match='Invalid cell types provided.'): + net.cell_response.plot_spikes_raster(trial_idx=0, show=False, + colors=dict_mapping) + + def test_network_plotter_init(setup_net): """Test init keywords of NetworkPlotter class.""" net = setup_net diff --git a/hnn_core/viz.py b/hnn_core/viz.py index 45169e9f3..15401e673 100644 --- a/hnn_core/viz.py +++ b/hnn_core/viz.py @@ -506,7 +506,9 @@ def plot_spikes_hist(cell_response, trial_idx=None, ax=None, spike_types=None, return ax.get_figure() -def plot_spikes_raster(cell_response, trial_idx=None, ax=None, show=True): +def plot_spikes_raster(cell_response, trial_idx=None, ax=None, show=True, + cell_types=None, colors=None, + ): """Plot the aggregate spiking activity according to cell type. Parameters @@ -519,6 +521,10 @@ def plot_spikes_raster(cell_response, trial_idx=None, ax=None, show=True): An axis object from matplotlib. If None, a new figure is created. show : bool If True, show the figure. + cell_types: list of str + List of cell types to plot + colors: list of str | None + Optional custom colors to plot. Default will use the color cycler. Returns ------- @@ -531,10 +537,53 @@ def plot_spikes_raster(cell_response, trial_idx=None, ax=None, show=True): if trial_idx is None: trial_idx = list(range(n_trials)) + # Get spike types + spike_types_data = np.concatenate(np.array(cell_response.spike_types, + dtype=object)) + spike_types = np.unique(spike_types_data).tolist() + + # validate trial argument if isinstance(trial_idx, int): trial_idx = [trial_idx] _validate_type(trial_idx, list, 'trial_idx', 'int, list of int') + # validate cell types + default_cell_types = ['L2_basket', 'L2_pyramidal', + 'L5_basket', 'L5_pyramidal'] + if cell_types: + _validate_type(cell_types, list, 'cell_types', 'list of str') + if not set(cell_types).issubset(set(spike_types)): + raise ValueError("Invalid cell types provided. " + f"Must be of set {spike_types}. " + f"Got {cell_types}") + default_cell_types = cell_types + + # Set default colors + default_colors = (plt.rcParams['axes.prop_cycle'] + .by_key()['color'][:len(default_cell_types)]) + cell_colors = {cell: color + for cell, color in zip(default_cell_types, default_colors)} + + # validate colors argument + _validate_type(colors, (list, dict, None), 'color', 'list of str, or dict') + if colors: + if isinstance(colors, list): + if len(colors) != len(default_cell_types): + raise ValueError( + f"Number of colors must be equal to number of " + f"cell types. {len(colors)} colors provided " + f"for {len(default_cell_types)} cell types.") + cell_colors = {cell: color + for cell, color in zip(default_cell_types, colors)} + + if isinstance(colors, dict): + # Check valid cell types + if not set(colors.keys()).issubset(set(spike_types)): + raise ValueError("Invalid cell types provided. " + f"Must be of set {spike_types}. " + f"Got {colors.keys()}") + cell_colors.update(colors) + # Extract desired trials spike_times = np.concatenate( np.array(cell_response._spike_times, dtype=object)[trial_idx]) @@ -543,17 +592,14 @@ def plot_spikes_raster(cell_response, trial_idx=None, ax=None, show=True): spike_gids = np.concatenate( np.array(cell_response._spike_gids, dtype=object)[trial_idx]) - cell_types = ['L2_basket', 'L2_pyramidal', 'L5_basket', 'L5_pyramidal'] - cell_type_colors = {'L5_pyramidal': 'r', 'L5_basket': 'b', - 'L2_pyramidal': 'g', 'L2_basket': 'w'} - if ax is None: _, ax = plt.subplots(1, 1, constrained_layout=True) events = [] - for cell_type in cell_types: + for cell_type, color in cell_colors.items(): cell_type_gids = np.unique(spike_gids[spike_types == cell_type]) cell_type_times, cell_type_ypos = [], [] + for gid in cell_type_gids: gid_time = spike_times[spike_gids == gid] cell_type_times.append(gid_time) @@ -562,16 +608,15 @@ def plot_spikes_raster(cell_response, trial_idx=None, ax=None, show=True): if cell_type_times: events.append( ax.eventplot(cell_type_times, lineoffsets=cell_type_ypos, - color=cell_type_colors[cell_type], + color=color, label=cell_type, linelengths=5)) else: events.append( ax.eventplot([-1], lineoffsets=[-1], - color=cell_type_colors[cell_type], + color=color, label=cell_type, linelengths=5)) ax.legend(handles=[e[0] for e in events], loc=1) - ax.set_facecolor('k') ax.set_xlabel('Time (ms)') ax.get_yaxis().set_visible(False)