Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[MRG] Spikes raster plot colors #895

Open
wants to merge 11 commits into
base: master
Choose a base branch
from

Conversation

gtdang
Copy link
Collaborator

@gtdang gtdang commented Sep 20, 2024

Changed the spike raster colors to have a white background and use the current color cycle's first 4 colors.

Screenshot 2024-09-20 at 4 12 30 PM

Question:

  • Should the API allow users to be able to specify colors?
    • With this new implementation the user could technically change the plot colors outside of the hnn-core API if they changed the Matplotlib color cycle and default background with the matplotlib API.

closes #888

@gtdang gtdang changed the title Spikes raster plot colors [WIP] Spikes raster plot colors Sep 20, 2024
@gtdang
Copy link
Collaborator Author

gtdang commented Sep 25, 2024

Hi @ntolley. I was looking into writing a test for making sure the colors are working expected for this plot. I was looking into the existing tests for the plotter in test_viz.py. It is shown below with the relevant tests for the plotter on lines 208-212. When I plot the spike event plot for the network specified I noticed that the plot was empty. Is this what we would expect for the network specified? The network only has 2 rhythmic drives added (lines 136-146).

def test_dipole_visualization(setup_net):
"""Test dipole visualisations."""
net = setup_net
# Test plotting of simulations with no spiking
dpls = simulate_dipole(net, tstop=100., n_trials=1)
net.cell_response.plot_spikes_raster()
net.cell_response.plot_spikes_hist()
weights_ampa = {'L2_pyramidal': 5.4e-5, 'L5_pyramidal': 5.4e-5}
syn_delays = {'L2_pyramidal': 0.1, 'L5_pyramidal': 1.}
net.add_bursty_drive(
'beta_prox', tstart=0., burst_rate=25, burst_std=5,
numspikes=1, spike_isi=0, n_drive_cells=11, location='proximal',
weights_ampa=weights_ampa, synaptic_delays=syn_delays,
event_seed=14)
net.add_bursty_drive(
'beta_dist', tstart=0., burst_rate=25, burst_std=5,
numspikes=1, spike_isi=0, n_drive_cells=11, location='distal',
weights_ampa=weights_ampa, synaptic_delays=syn_delays,
event_seed=14)
dpls = simulate_dipole(net, tstop=100., n_trials=2, record_vsec='all')
fig = dpls[0].plot() # plot the first dipole alone
axes = fig.get_axes()[0]
dpls[0].copy().smooth(window_len=10).plot(ax=axes) # add smoothed versions
dpls[0].copy().savgol_filter(h_freq=30).plot(ax=axes) # on top
# test decimation options
plot_dipole(dpls[0], decim=2, show=False)
for dec in [-1, [2, 2.]]:
with pytest.raises(ValueError,
match='each decimation factor must be a positive'):
plot_dipole(dpls[0], decim=dec, show=False)
# test plotting multiple dipoles as overlay
fig = plot_dipole(dpls, show=False)
# test plotting multiple dipoles with average
fig = plot_dipole(dpls, average=True, show=False)
plt.close('all')
# test plotting dipoles with multiple layers
fig, ax = plt.subplots()
fig = plot_dipole(dpls, show=False, ax=[ax], layer=['L2'])
fig = plot_dipole(dpls, show=False, layer=['L2', 'L5', 'agg'])
fig, axes = plt.subplots(nrows=3, ncols=1)
fig = plot_dipole(dpls, show=False, ax=axes, layer=['L2', 'L5', 'agg'])
fig, axes = plt.subplots(nrows=3, ncols=1)
fig = plot_dipole(dpls,
show=False,
ax=[axes[0], axes[1], axes[2]],
layer=['L2', 'L5', 'agg'])
plt.close('all')
with pytest.raises(AssertionError,
match="ax and layer should have the same size"):
fig, axes = plt.subplots(nrows=3, ncols=1)
fig = plot_dipole(dpls, show=False, ax=axes, layer=['L2', 'L5'])
# multiple TFRs get averaged
fig = plot_tfr_morlet(dpls, freqs=np.arange(23, 26, 1.), n_cycles=3,
show=False)
with pytest.raises(RuntimeError,
match="All dipoles must be scaled equally!"):
plot_dipole([dpls[0].copy().scale(10), dpls[1].copy().scale(20)])
with pytest.raises(RuntimeError,
match="All dipoles must be scaled equally!"):
plot_psd([dpls[0].copy().scale(10), dpls[1].copy().scale(20)])
with pytest.raises(RuntimeError,
match="All dipoles must be sampled equally!"):
dpl_sfreq = dpls[0].copy()
dpl_sfreq.sfreq /= 10
plot_psd([dpls[0], dpl_sfreq])
# pytest deprecation warning for tmin and tmax
with pytest.deprecated_call():
plot_dipole(dpls[0], show=False, tmin=10, tmax=100)
# test cell response plotting
with pytest.raises(TypeError, match="trial_idx must be an instance of"):
net.cell_response.plot_spikes_raster(trial_idx='blah', show=False)
net.cell_response.plot_spikes_raster(trial_idx=0, show=False)
fig = net.cell_response.plot_spikes_raster(trial_idx=[0, 1], show=False)
assert len(fig.axes[0].collections) > 0, "No data plotted in raster plot"
with pytest.raises(TypeError, match="trial_idx must be an instance of"):
net.cell_response.plot_spikes_hist(trial_idx='blah')
net.cell_response.plot_spikes_hist(trial_idx=0, show=False)
net.cell_response.plot_spikes_hist(trial_idx=[0, 1], show=False)
net.cell_response.plot_spikes_hist(color='r')
net.cell_response.plot_spikes_hist(color=['C0', 'C1'])
net.cell_response.plot_spikes_hist(color={'beta_prox': 'r',
'beta_dist': 'g'})
net.cell_response.plot_spikes_hist(
spike_types={'group1': ['beta_prox', 'beta_dist']},
color={'group1': 'r'})
net.cell_response.plot_spikes_hist(
spike_types={'group1': ['beta']}, color={'group1': 'r'})
with pytest.raises(TypeError, match="color must be an instance of"):
net.cell_response.plot_spikes_hist(color=123)
with pytest.raises(ValueError):
net.cell_response.plot_spikes_hist(color='z')
with pytest.raises(ValueError):
net.cell_response.plot_spikes_hist(color={'beta_prox': 'z',
'beta_dist': 'g'})
with pytest.raises(TypeError, match="Dictionary values of color must"):
net.cell_response.plot_spikes_hist(color={'beta_prox': 123,
'beta_dist': 'g'})
with pytest.raises(ValueError, match="'beta_dist' must be"):
net.cell_response.plot_spikes_hist(color={'beta_prox': 'r'})
plt.close('all')

Here's a code snip to recreate:

from hnn_core import simulate_dipole, jones_2009_model, read_params
from pathlib import Path

hnn_core_root = Path.cwd().parents[0]
params_fname = Path(hnn_core_root, 'param', 'default.json')
params = read_params(params_fname)
net = jones_2009_model(params, mesh_shape=(3, 3))

weights_ampa = {'L2_pyramidal': 5.4e-5, 'L5_pyramidal': 5.4e-5}
syn_delays = {'L2_pyramidal': 0.1, 'L5_pyramidal': 1.}

net.add_bursty_drive(
    'beta_prox', tstart=0., burst_rate=25, burst_std=5,
    numspikes=1, spike_isi=0, n_drive_cells=11, location='proximal',
    weights_ampa=weights_ampa, synaptic_delays=syn_delays,
    event_seed=14)

net.add_bursty_drive(
    'beta_dist', tstart=0., burst_rate=25, burst_std=5,
    numspikes=1, spike_isi=0, n_drive_cells=11, location='distal',
    weights_ampa=weights_ampa, synaptic_delays=syn_delays,
    event_seed=14)

dpls = simulate_dipole(net, tstop=100., n_trials=2, record_vsec='all')

fig1 = net.cell_response.plot_spikes_raster()
fig1.show()

@ntolley
Copy link
Contributor

ntolley commented Sep 26, 2024

@gtdang this test was written for the edge case with the drives are too weak to produce spiking activity in the network

There was an earlier bug where the plotting function would throw an error if not spikes occurred, so this test is to make sure that an empty plot is generated (the desired behavior in these simulations)

@ntolley
Copy link
Contributor

ntolley commented Sep 26, 2024

If you want spiking just change the weights to a bigger number like 0.1 or 1.0!

@gtdang gtdang marked this pull request as ready for review October 11, 2024 19:50
@gtdang gtdang changed the title [WIP] Spikes raster plot colors [MRG] Spikes raster plot colors Oct 11, 2024
@gtdang gtdang requested a review from asoplata October 15, 2024 19:16
hnn_core/viz.py Outdated
def plot_spikes_raster(cell_response, trial_idx=None, ax=None, show=True):
def plot_spikes_raster(cell_response, trial_idx=None, ax=None, show=True,
cell_types=['L2_basket', 'L2_pyramidal',
'L5_basket', 'L5_pyramidal'],
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

bad idea to have a default list in a function. You will get funky effects in Python ... default should not be a mutable

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks, we're going to refactor this call to be aligned with the plot_spikes_hist implementation so that it can also take a dict of color assignments. I don't think we need to expose the cell_types as an argument... though I wish there was a way to get it dynamically from the network instead of hard-coding the types.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

see #916 ... it allows you to dynamically extract the cell types

class TestCellResponsePlotters:

@pytest.fixture(scope='class')
def class_setup_net(self):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

missing docstring


return net, dpls

def test_spikes_raster_trial_idx(self, base_simulation_spikes):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't fully follow the categorization ... maybe docstrings will clarify

events = []
for cell_type in cell_types:
cell_type_gids = np.unique(spike_gids[spike_types == cell_type])
cell_type_times, cell_type_ypos = [], []
color = next(color_iter)

for gid in cell_type_gids:
gid_time = spike_times[spike_gids == gid]
cell_type_times.append(gid_time)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

while you are at it, I am wondering if this could be addressed as well. The following line:

cell_type_ypos.append(-gid)

causes cells that spike with neighboring gids to overlap. I have been staring at these raster plots recently and it's very hard to tell how many times the same cell spiked (important to understand the underlying dynamics). Adding a small offset between nearby cells should address that problem

Copy link
Collaborator Author

@gtdang gtdang Oct 23, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I took a look into this. The overlap is due to both the y-position as you've identified and the line lengths defined during the plot function call on line 568. The y-position is the center of each line and the length is how much it extends from that center point (+-2.5 each way for a value of 5).

hnn-core/hnn_core/viz.py

Lines 559 to 568 in 27c6fc1

for gid in cell_type_gids:
gid_time = spike_times[spike_gids == gid]
cell_type_times.append(gid_time)
cell_type_ypos.append(-gid)
if cell_type_times:
events.append(
ax.eventplot(cell_type_times, lineoffsets=cell_type_ypos,
color=cell_type_colors[cell_type],
label=cell_type, linelengths=5))

A simple solution is to change the line length to 1. However the lines will look more like dots with this change.
Screenshot 2024-10-23 at 10 06 41 AM

Another solution would be to analyze the cell times and gids, and apply a larger y-offset if they are within an X and Y bounding box of one another.

Let me know what you think.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm okay with the spikes looking like dots ... it's just a function of the number of cells in our network. Did a quick google image search of "spike raster plot" and the plots do look dotted when there are more neurons. I guess the y-offset = -gid is helpful since it allows you to identify the cell, so maybe best not to touch that. @ntolley any opinion here?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

API: Spiking plot
3 participants