Skip to content

Commit

Permalink
test: added tests for dictionary color mapping
Browse files Browse the repository at this point in the history
  • Loading branch information
gtdang committed Oct 21, 2024
1 parent 834a8b1 commit c97f520
Showing 1 changed file with 33 additions and 8 deletions.
41 changes: 33 additions & 8 deletions hnn_core/tests/test_viz.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,7 +248,7 @@ def test_dipole_visualization(setup_net):


class TestCellResponsePlotters:

"""Tests plotting methods of the CellResponse class"""
@pytest.fixture(scope='class')
def class_setup_net(self):
hnn_core_root = op.dirname(hnn_core.__file__)
Expand Down Expand Up @@ -302,31 +302,33 @@ def test_spikes_raster_colors(self, base_simulation_spikes):
def _get_line_hex_colors(fig):
colors = [matplotlib.colors.to_hex(line.get_color())
for line in fig.axes[0].legend_.get_lines()]
return colors
labels = [text.get_text()
for text in fig.axes[0].legend_.get_texts()]
return colors, labels

# Default colors should be the default color cycle
fig = net.cell_response.plot_spikes_raster(trial_idx=0, show=False)
colors = _get_line_hex_colors(fig)
colors, _ = _get_line_hex_colors(fig)
default_colors = (plt.rcParams['axes.prop_cycle']
.by_key()['color'][0:len(colors)])
assert colors == default_colors

# Custom hex colors
# Custom hex colors as list
custom_colors = ['#daf7a6', '#ffc300', '#ff5733', '#c70039']
fig = net.cell_response.plot_spikes_raster(trial_idx=0, show=False,
colors=custom_colors)
colors = _get_line_hex_colors(fig)
colors, _ = _get_line_hex_colors(fig)
assert colors == custom_colors

# Custom named colors
# Custom named colors as list
custom_colors = ['skyblue', 'maroon', 'gold', 'hotpink']
color_map = matplotlib.colors.get_named_colors_mapping()
fig = net.cell_response.plot_spikes_raster(trial_idx=0, show=False,
colors=custom_colors)
colors = _get_line_hex_colors(fig)
colors, _ = _get_line_hex_colors(fig)
assert colors == [color_map[color].lower() for color in custom_colors]

# Incorrect number of colors
# Incorrect number of colors as list
too_few = ['r', 'g', 'b']
too_many = ['r', 'g', 'b', 'y', 'k']
for colors in [too_few, too_many]:
Expand All @@ -336,6 +338,29 @@ def _get_line_hex_colors(fig):
show=False,
colors=colors)

# Colors as dict mapping
dict_mapping = {'L2_basket': '#daf7a6', 'L2_pyramidal': '#ffc300',
'L5_basket': '#ff5733', 'L5_pyramidal': '#c70039'}
fig = net.cell_response.plot_spikes_raster(trial_idx=0, show=False,
colors=dict_mapping)
colors, _ = _get_line_hex_colors(fig)
assert colors == list(dict_mapping.values())

# Change color of only one cell type
dict_mapping = {'L2_pyramidal': '#daf7a6'}
fig = net.cell_response.plot_spikes_raster(trial_idx=0, show=False,
colors=dict_mapping)
colors, cell_types = _get_line_hex_colors(fig)
assert colors[cell_types.index('L2_pyramidal')] == '#daf7a6'

# Invalid key in dict mapping
dict_mapping = {'bad_cell_type': '#daf7a6'}
with pytest.raises(ValueError,
match='Invalid cell types provided.'):
net.cell_response.plot_spikes_raster(trial_idx=0, show=False,
colors=dict_mapping)



def test_network_plotter_init(setup_net):
"""Test init keywords of NetworkPlotter class."""
Expand Down

0 comments on commit c97f520

Please sign in to comment.