diff --git a/hnn_core/tests/test_viz.py b/hnn_core/tests/test_viz.py index 0a5057e0f..4b324217f 100644 --- a/hnn_core/tests/test_viz.py +++ b/hnn_core/tests/test_viz.py @@ -248,7 +248,7 @@ def test_dipole_visualization(setup_net): class TestCellResponsePlotters: - + """Tests plotting methods of the CellResponse class""" @pytest.fixture(scope='class') def class_setup_net(self): hnn_core_root = op.dirname(hnn_core.__file__) @@ -302,31 +302,33 @@ def test_spikes_raster_colors(self, base_simulation_spikes): def _get_line_hex_colors(fig): colors = [matplotlib.colors.to_hex(line.get_color()) for line in fig.axes[0].legend_.get_lines()] - return colors + labels = [text.get_text() + for text in fig.axes[0].legend_.get_texts()] + return colors, labels # Default colors should be the default color cycle fig = net.cell_response.plot_spikes_raster(trial_idx=0, show=False) - colors = _get_line_hex_colors(fig) + colors, _ = _get_line_hex_colors(fig) default_colors = (plt.rcParams['axes.prop_cycle'] .by_key()['color'][0:len(colors)]) assert colors == default_colors - # Custom hex colors + # Custom hex colors as list custom_colors = ['#daf7a6', '#ffc300', '#ff5733', '#c70039'] fig = net.cell_response.plot_spikes_raster(trial_idx=0, show=False, colors=custom_colors) - colors = _get_line_hex_colors(fig) + colors, _ = _get_line_hex_colors(fig) assert colors == custom_colors - # Custom named colors + # Custom named colors as list custom_colors = ['skyblue', 'maroon', 'gold', 'hotpink'] color_map = matplotlib.colors.get_named_colors_mapping() fig = net.cell_response.plot_spikes_raster(trial_idx=0, show=False, colors=custom_colors) - colors = _get_line_hex_colors(fig) + colors, _ = _get_line_hex_colors(fig) assert colors == [color_map[color].lower() for color in custom_colors] - # Incorrect number of colors + # Incorrect number of colors as list too_few = ['r', 'g', 'b'] too_many = ['r', 'g', 'b', 'y', 'k'] for colors in [too_few, too_many]: @@ -336,6 +338,29 @@ def _get_line_hex_colors(fig): show=False, colors=colors) + # Colors as dict mapping + dict_mapping = {'L2_basket': '#daf7a6', 'L2_pyramidal': '#ffc300', + 'L5_basket': '#ff5733', 'L5_pyramidal': '#c70039'} + fig = net.cell_response.plot_spikes_raster(trial_idx=0, show=False, + colors=dict_mapping) + colors, _ = _get_line_hex_colors(fig) + assert colors == list(dict_mapping.values()) + + # Change color of only one cell type + dict_mapping = {'L2_pyramidal': '#daf7a6'} + fig = net.cell_response.plot_spikes_raster(trial_idx=0, show=False, + colors=dict_mapping) + colors, cell_types = _get_line_hex_colors(fig) + assert colors[cell_types.index('L2_pyramidal')] == '#daf7a6' + + # Invalid key in dict mapping + dict_mapping = {'bad_cell_type': '#daf7a6'} + with pytest.raises(ValueError, + match='Invalid cell types provided.'): + net.cell_response.plot_spikes_raster(trial_idx=0, show=False, + colors=dict_mapping) + + def test_network_plotter_init(setup_net): """Test init keywords of NetworkPlotter class."""