Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[MRG] Make all plots with time end at tstop #683

Closed
wants to merge 5 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 3 additions & 8 deletions hnn_core/dipole.py
Original file line number Diff line number Diff line change
Expand Up @@ -467,16 +467,11 @@ def savgol_filter(self, h_freq):
self.sfreq)
return self

def plot(self, tmin=None, tmax=None, layer='agg', decim=None, ax=None,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is technically a backwards incompatible change. If some user has tmin and tmax in their script, it will stop working now. The correct fix is to add a deprecation cycle. But if we want to assume that it's unlikely many users have this in their scripts, we should mark it as a "Bugfix" in whats_new.rst ...

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I thought one reason to have tmin was for burn in period @rythorpe may have an opinion ...

but it's true we can probably just set plt.xlim((tmin, tmax)) to achieve the same

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

See my comment in the previous PR here. I don't think we need tmin at the plotting level, however, it would be useful at the simulate_dipole level for a burn-in period. That's a very different feature though and should probably be implemented in a separate PR.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree that we should do a deprecation cycle for these input arguments though.

color='k', show=True):
def plot(self, layer='agg', decim=None, ax=None, color='k', show=True):
"""Simple layer-specific plot function.

Parameters
----------
tmin : float or None
Start time of plot (in ms). If None, plot entire simulation.
tmax : float or None
End time of plot (in ms). If None, plot entire simulation.
layer : str
The layer to plot. Can be one of 'agg', 'L2', and 'L5'
decimate : int
Expand All @@ -493,8 +488,8 @@ def plot(self, tmin=None, tmax=None, layer='agg', decim=None, ax=None,
fig : instance of plt.fig
The matplotlib figure handle.
"""
return plot_dipole(self, tmin=tmin, tmax=tmax, ax=ax, layer=layer,
decim=decim, color=color, show=show)
return plot_dipole(self, ax=ax, layer=layer, decim=decim, color=color,
show=show)

def plot_psd(self, fmin=0, fmax=None, tmin=None, tmax=None, layer='agg',
color=None, label=None, ax=None, show=True):
Expand Down
14 changes: 5 additions & 9 deletions hnn_core/extracellular.py
Original file line number Diff line number Diff line change
Expand Up @@ -421,9 +421,9 @@ def smooth(self, window_len):

return self

def plot_lfp(self, *, trial_no=None, contact_no=None, tmin=None, tmax=None,
ax=None, decim=None, color='cividis', voltage_offset=50,
voltage_scalebar=200, show=True):
def plot_lfp(self, *, trial_no=None, contact_no=None, ax=None, decim=None,
color='cividis', voltage_offset=50, voltage_scalebar=200,
show=True):
"""Plot laminar local field potential time series.

One plot is created for each trial. Multiple trials can be overlaid
Expand All @@ -435,11 +435,6 @@ def plot_lfp(self, *, trial_no=None, contact_no=None, tmin=None, tmax=None,
Trial number(s) to plot
contact_no : int | list of int | slice
Electrode contact number(s) to plot
tmin : float | None
Start time of plot in milliseconds. If None, plot entire
simulation.
tmax : float | None
End time of plot in milliseconds. If None, plot entire simulation.
ax : instance of matplotlib figure | None
The matplotlib axis
decim : int | list of int | None (default)
Expand Down Expand Up @@ -486,7 +481,7 @@ def plot_lfp(self, *, trial_no=None, contact_no=None, tmin=None, tmax=None,

for trial_data in plot_data:
fig = plot_laminar_lfp(
self.times, trial_data, tmin=tmin, tmax=tmax, ax=ax,
self.times, trial_data, ax=ax,
decim=decim, color=color,
voltage_offset=voltage_offset,
voltage_scalebar=voltage_scalebar,
Expand Down Expand Up @@ -534,6 +529,7 @@ class _ExtracellularArrayBuilder(object):
The instance of :class:`hnn_core.extracellular.ExtracellularArray` to
build in NEURON-Python
"""

def __init__(self, array):
self.array = array
self.n_contacts = array.n_contacts
Expand Down
48 changes: 24 additions & 24 deletions hnn_core/viz.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,9 +78,9 @@ def plt_show(show=True, fig=None, **kwargs):
(fig or plt).show(**kwargs)


def plot_laminar_lfp(times, data, contact_labels, tmin=None, tmax=None,
ax=None, decim=None, color='cividis',
voltage_offset=50, voltage_scalebar=200, show=True):
def plot_laminar_lfp(times, data, contact_labels, ax=None, decim=None,
color='cividis', voltage_offset=50, voltage_scalebar=200,
show=True):
"""Plot laminar extracellular electrode array voltage time series.

Parameters
Expand All @@ -89,10 +89,6 @@ def plot_laminar_lfp(times, data, contact_labels, tmin=None, tmax=None,
Sampling times (in ms).
data : Two-dimensional Numpy array
The extracellular voltages as an (n_contacts, n_times) array.
tmin : float | None
Start time of plot in milliseconds. If None, plot entire simulation.
tmax : float | None
End time of plot in milliseconds. If None, plot entire simulation.
ax : instance of matplotlib figure | None
The matplotlib axis
decim : int | list of int | None (default)
Expand Down Expand Up @@ -168,11 +164,11 @@ def plot_laminar_lfp(times, data, contact_labels, tmin=None, tmax=None,
trace_offsets = np.arange(n_offsets)[:, np.newaxis] * voltage_offset

for contact_no, trace in enumerate(np.atleast_2d(data)):
plot_data, plot_times = _get_plot_data_trange(times, trace, tmin, tmax)
plot_data = trace
plot_times = times

if decim is not None:
plot_data, plot_times = _decimate_plot_data(decim, plot_data,
plot_times)
plot_data, plot_times = _decimate_plot_data(decim, trace, times)

if isinstance(color, np.ndarray):
col = color[contact_no]
Expand All @@ -182,6 +178,7 @@ def plot_laminar_lfp(times, data, contact_labels, tmin=None, tmax=None,
col = color
ax.plot(plot_times, plot_data + trace_offsets[contact_no],
label=f'C{contact_no}', color=col)
ax.set_xlim(right=times[-1])

if voltage_offset is not None:
ax.set_ylim(-voltage_offset, n_offsets * voltage_offset)
Expand Down Expand Up @@ -220,18 +217,14 @@ def plot_laminar_lfp(times, data, contact_labels, tmin=None, tmax=None,
return ax.get_figure()


def plot_dipole(dpl, tmin=None, tmax=None, ax=None, layer='agg', decim=None,
def plot_dipole(dpl, ax=None, layer='agg', decim=None,
color='k', label="average", average=False, show=True):
"""Simple layer-specific plot function.

Parameters
----------
dpl : instance of Dipole | list of Dipole instances
The Dipole object.
tmin : float or None
Start time of plot in milliseconds. If None, plot entire simulation.
tmax : float or None
End time of plot in milliseconds. If None, plot entire simulation.
ax : instance of matplotlib figure | None
The matplotlib axis
layer : str
Expand Down Expand Up @@ -288,9 +281,8 @@ def plot_dipole(dpl, tmin=None, tmax=None, ax=None, layer='agg', decim=None,
if layer in dpl_trial.data.keys():

# extract scaled data and times
data, times = _get_plot_data_trange(dpl_trial.times,
dpl_trial.data[layer],
tmin, tmax)
data = dpl_trial.data[layer]
times = dpl_trial.times
if decim is not None:
data, times = _decimate_plot_data(decim, data, times)
if idx == len(dpl) - 1 and average:
Expand All @@ -300,7 +292,7 @@ def plot_dipole(dpl, tmin=None, tmax=None, ax=None, layer='agg', decim=None,
alpha = 0.5 if average else 1.
ax.plot(times, data, color=_lighten_color(color, 0.5),
alpha=alpha, lw=1.)

ax.set_xlim(right=dpl_trial.times[-1])
if average:
ax.legend()

Expand Down Expand Up @@ -477,12 +469,14 @@ def plot_spikes_hist(cell_response, trial_idx=None, ax=None, spike_types=None,

ax.set_ylabel("Counts")
ax.legend()
ax.set_xlim(right=cell_response.times[-1])

plt_show(show)
return ax.get_figure()


def plot_spikes_raster(cell_response, trial_idx=None, ax=None, show=True):
def plot_spikes_raster(cell_response, trial_idx=None,
ax=None, show=True):
"""Plot the aggregate spiking activity according to cell type.

Parameters
Expand Down Expand Up @@ -538,6 +532,9 @@ def plot_spikes_raster(cell_response, trial_idx=None, ax=None, show=True):
cell_type_times, cell_type_ypos = [], []
for gid in cell_type_gids:
gid_time = spike_times[spike_gids == gid]
_, gid_time = _get_plot_data_trange(
gid_time, gid_time, cell_response.times[0],
cell_response.times[-1])
tianqi-cheng marked this conversation as resolved.
Show resolved Hide resolved
cell_type_times.append(gid_time)
cell_type_ypos.append(ypos)
ypos = ypos - 1
Expand All @@ -552,7 +549,7 @@ def plot_spikes_raster(cell_response, trial_idx=None, ax=None, show=True):
ax.set_facecolor('k')
ax.set_xlabel('Time (ms)')
ax.get_yaxis().set_visible(False)
ax.set_xlim(left=0)
ax.set_xlim(left=0, right=cell_response.times[-1])

plt_show(show)
return ax.get_figure()
Expand Down Expand Up @@ -1179,8 +1176,8 @@ def _onclick(event):
return ax.get_figure()


def plot_laminar_csd(times, data, contact_labels, ax=None, colorbar=True,
show=True):
def plot_laminar_csd(times, data, contact_labels,
ax=None, colorbar=True, show=True):
"""Plot laminar current source density (CSD) estimation from LFP array.

Parameters
Expand All @@ -1207,7 +1204,9 @@ def plot_laminar_csd(times, data, contact_labels, ax=None, colorbar=True,
import matplotlib.pyplot as plt
if ax is None:
_, ax = plt.subplots(1, 1, constrained_layout=True)

times, data = _get_plot_data_trange(times, data, times[0], times[-1])
_, contact_labels = _get_plot_data_trange(
times, contact_labels, times[0], times[-1])
im = ax.pcolormesh(times, contact_labels, np.array(data),
cmap="jet_r", shading='auto')
ax.set_title("CSD")
Expand All @@ -1218,6 +1217,7 @@ def plot_laminar_csd(times, data, contact_labels, ax=None, colorbar=True,

ax.set_xlabel('Time (ms)')
ax.set_ylabel('Electrode depth')
ax.set_xlim(right=times[-1])
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a nitpick, but I'd recommend defining tstop (or whatever you want to call it) before decimating so that you don't introduce a small error in the right x-limit.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi Ryan, it seems that there is no decim in this function. Do you want me to do this for other functions which contain decim?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah you're right. I didn't realize that _get_plot_data_trange only crops the time series. Since this is the case, however, I don't think we need the added lines on L1207-1209 as they crop the time series to its original length (i.e., they don't actually crop anything off).

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Re: adding the decimate feature to plot_laminar_csd, my vote would be to save it for a separate PR.

plt.tight_layout()
plt_show(show)

Expand Down
Loading