Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[MRG] Spikes raster plot colors #895

Open
wants to merge 11 commits into
base: master
Choose a base branch
from
8 changes: 6 additions & 2 deletions hnn_core/cell_response.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,7 +273,8 @@ def mean_rates(self, tstart, tstop, gid_ranges, mean_type='all'):

return spike_rates

def plot_spikes_raster(self, trial_idx=None, ax=None, show=True):
def plot_spikes_raster(self, trial_idx=None, ax=None, show=True,
colors=None):
"""Plot the aggregate spiking activity according to cell type.

Parameters
Expand All @@ -284,14 +285,17 @@ def plot_spikes_raster(self, trial_idx=None, ax=None, show=True):
An axis object from matplotlib. If None, a new figure is created.
show : bool
If True, show the figure.
colors: list of str | None
Optional custom colors to plot. Default will use the color cycler.

Returns
-------
fig : instance of matplotlib Figure
The matplotlib figure object.
"""
return plot_spikes_raster(
cell_response=self, trial_idx=trial_idx, ax=ax, show=show)
cell_response=self, trial_idx=trial_idx, ax=ax, show=show,
colors=colors)

def plot_spikes_hist(self, trial_idx=None, ax=None, spike_types=None,
color=None, show=True, **kwargs_hist):
Expand Down
97 changes: 97 additions & 0 deletions hnn_core/tests/test_viz.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,13 @@
matplotlib.use('agg')


@pytest.fixture(autouse=True)
def cleanup_matplotlib():
# Code runs after the test finishes
yield
plt.close('all')
jasmainak marked this conversation as resolved.
Show resolved Hide resolved


@pytest.fixture
def setup_net():
hnn_core_root = op.dirname(hnn_core.__file__)
Expand Down Expand Up @@ -240,6 +247,96 @@ def test_dipole_visualization(setup_net):
plt.close('all')


class TestCellResponsePlotters:

@pytest.fixture(scope='class')
def class_setup_net(self):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

missing docstring

hnn_core_root = op.dirname(hnn_core.__file__)
params_fname = op.join(hnn_core_root, 'param', 'default.json')
params = read_params(params_fname)
net = jones_2009_model(params, mesh_shape=(3, 3))

return net

@pytest.fixture(scope='class')
def base_simulation_spikes(self, class_setup_net):
net = class_setup_net
weights_ampa = {'L2_pyramidal': 0.1, 'L5_pyramidal': 1.}
syn_delays = {'L2_pyramidal': 0.1, 'L5_pyramidal': 1.}
net.add_bursty_drive(
'beta_prox', tstart=0., burst_rate=25, burst_std=5,
numspikes=1, spike_isi=0, n_drive_cells=11, location='proximal',
weights_ampa=weights_ampa, synaptic_delays=syn_delays,
event_seed=14)

net.add_bursty_drive(
'beta_dist', tstart=0., burst_rate=25, burst_std=5,
numspikes=1, spike_isi=0, n_drive_cells=11, location='distal',
weights_ampa=weights_ampa, synaptic_delays=syn_delays,
event_seed=14)
dpls = simulate_dipole(net, tstop=100., n_trials=2, record_vsec='all')

return net, dpls

def test_spikes_raster_trial_idx(self, base_simulation_spikes):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't fully follow the categorization ... maybe docstrings will clarify

net, _ = base_simulation_spikes

# Bad index argument raises error
with pytest.raises(TypeError,
match="trial_idx must be an instance of"):
net.cell_response.plot_spikes_raster(trial_idx='blah', show=False)

# Test valid index arguments
for index_arg in (0, [0, 1]):
fig = net.cell_response.plot_spikes_raster(trial_idx=index_arg,
show=False)
# Check that collections contain data
assert all(
[collection.get_positions() != [-1]
for collection in fig.axes[0].collections]
), "No data plotted in raster plot"

def test_spikes_raster_colors(self, base_simulation_spikes):
net, _ = base_simulation_spikes

def _get_line_hex_colors(fig):
colors = [matplotlib.colors.to_hex(line.get_color())
for line in fig.axes[0].legend_.get_lines()]
return colors

# Default colors should be the default color cycle
fig = net.cell_response.plot_spikes_raster(trial_idx=0, show=False)
colors = _get_line_hex_colors(fig)
default_colors = (plt.rcParams['axes.prop_cycle']
.by_key()['color'][0:len(colors)])
assert colors == default_colors

# Custom hex colors
custom_colors = ['#daf7a6', '#ffc300', '#ff5733', '#c70039']
fig = net.cell_response.plot_spikes_raster(trial_idx=0, show=False,
colors=custom_colors)
colors = _get_line_hex_colors(fig)
assert colors == custom_colors

# Custom named colors
custom_colors = ['skyblue', 'maroon', 'gold', 'hotpink']
color_map = matplotlib.colors.get_named_colors_mapping()
fig = net.cell_response.plot_spikes_raster(trial_idx=0, show=False,
colors=custom_colors)
colors = _get_line_hex_colors(fig)
assert colors == [color_map[color].lower() for color in custom_colors]

# Incorrect number of colors
too_few = ['r', 'g', 'b']
too_many = ['r', 'g', 'b', 'y', 'k']
for colors in [too_few, too_many]:
with pytest.raises(ValueError,
match='Number of colors must be equal to'):
net.cell_response.plot_spikes_raster(trial_idx=0,
show=False,
colors=colors)


def test_network_plotter_init(setup_net):
"""Test init keywords of NetworkPlotter class."""
net = setup_net
Expand Down
35 changes: 27 additions & 8 deletions hnn_core/viz.py
Original file line number Diff line number Diff line change
Expand Up @@ -506,7 +506,11 @@ def plot_spikes_hist(cell_response, trial_idx=None, ax=None, spike_types=None,
return ax.get_figure()


def plot_spikes_raster(cell_response, trial_idx=None, ax=None, show=True):
def plot_spikes_raster(cell_response, trial_idx=None, ax=None, show=True,
cell_types=['L2_basket', 'L2_pyramidal',
'L5_basket', 'L5_pyramidal'],
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

bad idea to have a default list in a function. You will get funky effects in Python ... default should not be a mutable

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks, we're going to refactor this call to be aligned with the plot_spikes_hist implementation so that it can also take a dict of color assignments. I don't think we need to expose the cell_types as an argument... though I wish there was a way to get it dynamically from the network instead of hard-coding the types.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

see #916 ... it allows you to dynamically extract the cell types

colors=None,
):
"""Plot the aggregate spiking activity according to cell type.

Parameters
Expand All @@ -519,6 +523,10 @@ def plot_spikes_raster(cell_response, trial_idx=None, ax=None, show=True):
An axis object from matplotlib. If None, a new figure is created.
show : bool
If True, show the figure.
cell_types: list of str
List of cell types to plot
colors: list of str | None
Optional custom colors to plot. Default will use the color cycler.

Returns
-------
Expand All @@ -531,10 +539,19 @@ def plot_spikes_raster(cell_response, trial_idx=None, ax=None, show=True):
if trial_idx is None:
trial_idx = list(range(n_trials))

# validate trial argument
if isinstance(trial_idx, int):
trial_idx = [trial_idx]
_validate_type(trial_idx, list, 'trial_idx', 'int, list of int')

# validate colors argument
if colors:
_validate_type(colors, list, 'colors', 'list of str')
if len(colors) != len(cell_types):
raise ValueError(f"Number of colors must be equal to number of "
f"cell types. {len(colors)} colors provided "
f"for {len(cell_types)} cell types.")

# Extract desired trials
spike_times = np.concatenate(
np.array(cell_response._spike_times, dtype=object)[trial_idx])
Expand All @@ -543,17 +560,20 @@ def plot_spikes_raster(cell_response, trial_idx=None, ax=None, show=True):
spike_gids = np.concatenate(
np.array(cell_response._spike_gids, dtype=object)[trial_idx])

cell_types = ['L2_basket', 'L2_pyramidal', 'L5_basket', 'L5_pyramidal']
cell_type_colors = {'L5_pyramidal': 'r', 'L5_basket': 'b',
'L2_pyramidal': 'g', 'L2_basket': 'w'}

if ax is None:
_, ax = plt.subplots(1, 1, constrained_layout=True)

# Check if custom colors are provided, else use the default color cycle
if colors is None:
colors = plt.rcParams['axes.prop_cycle'].by_key()['color']
color_iter = iter(colors)

events = []
for cell_type in cell_types:
cell_type_gids = np.unique(spike_gids[spike_types == cell_type])
cell_type_times, cell_type_ypos = [], []
color = next(color_iter)

for gid in cell_type_gids:
gid_time = spike_times[spike_gids == gid]
cell_type_times.append(gid_time)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

while you are at it, I am wondering if this could be addressed as well. The following line:

cell_type_ypos.append(-gid)

causes cells that spike with neighboring gids to overlap. I have been staring at these raster plots recently and it's very hard to tell how many times the same cell spiked (important to understand the underlying dynamics). Adding a small offset between nearby cells should address that problem

Copy link
Collaborator Author

@gtdang gtdang Oct 23, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I took a look into this. The overlap is due to both the y-position as you've identified and the line lengths defined during the plot function call on line 568. The y-position is the center of each line and the length is how much it extends from that center point (+-2.5 each way for a value of 5).

hnn-core/hnn_core/viz.py

Lines 559 to 568 in 27c6fc1

for gid in cell_type_gids:
gid_time = spike_times[spike_gids == gid]
cell_type_times.append(gid_time)
cell_type_ypos.append(-gid)
if cell_type_times:
events.append(
ax.eventplot(cell_type_times, lineoffsets=cell_type_ypos,
color=cell_type_colors[cell_type],
label=cell_type, linelengths=5))

A simple solution is to change the line length to 1. However the lines will look more like dots with this change.
Screenshot 2024-10-23 at 10 06 41 AM

Another solution would be to analyze the cell times and gids, and apply a larger y-offset if they are within an X and Y bounding box of one another.

Let me know what you think.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm okay with the spikes looking like dots ... it's just a function of the number of cells in our network. Did a quick google image search of "spike raster plot" and the plots do look dotted when there are more neurons. I guess the y-offset = -gid is helpful since it allows you to identify the cell, so maybe best not to touch that. @ntolley any opinion here?

Expand All @@ -562,16 +582,15 @@ def plot_spikes_raster(cell_response, trial_idx=None, ax=None, show=True):
if cell_type_times:
events.append(
ax.eventplot(cell_type_times, lineoffsets=cell_type_ypos,
color=cell_type_colors[cell_type],
color=color,
label=cell_type, linelengths=5))
else:
events.append(
ax.eventplot([-1], lineoffsets=[-1],
color=cell_type_colors[cell_type],
color=color,
label=cell_type, linelengths=5))

ax.legend(handles=[e[0] for e in events], loc=1)
ax.set_facecolor('k')
ax.set_xlabel('Time (ms)')
ax.get_yaxis().set_visible(False)

Expand Down
Loading