Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] Make all plots with time end at stop #752

Closed
wants to merge 6 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 22 additions & 15 deletions hnn_core/dipole.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@

from .viz import plot_dipole, plot_psd, plot_tfr_morlet


def simulate_dipole(net, tstop, dt=0.025, n_trials=None, record_vsec=False,
record_isec=False, postproc=False):
"""Simulate a dipole given the experiment parameters.
Expand Down Expand Up @@ -106,25 +105,20 @@ def simulate_dipole(net, tstop, dt=0.025, n_trials=None, record_vsec=False,
return dpls


def _read_dipole_txt(fname, extension='.txt'):
def _read_dipole_txt(fname):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@katduecker here is the problem! Since the time you had copied dipole.py to your desktop the master branch has changed. Therefore these modifications are undoing the latest commits to the master branch.

I would reccomend copying this function exactly as it's written on the current master branch and replacing it in your code. That way the only changes that show up in the diff are related to the tstop PR.

If done correctly you will no longer see _read_dipole_txt() show up in the files changed.

"""Read dipole values from a txt file and create a Dipole instance.

Parameters
----------
fname : str or io.StringIO
Full path to the input file (.txt or .csv) or
Content of file in memory as a StringIO
fname : str
Full path to the input file (.txt)

Returns
-------
dpl : Dipole
The instance of Dipole class
"""
if extension == '.csv':
# read from a csv file ignoring the headers
dpl_data = np.genfromtxt(fname, delimiter=',',
skip_header=1, dtype=float)
else:
dpl_data = np.loadtxt(fname, dtype=float)
dpl_data = np.loadtxt(fname, dtype=float)
ncols = dpl_data.shape[1]
if ncols not in (2, 4):
raise ValueError(
Expand Down Expand Up @@ -179,6 +173,10 @@ def read_dipole(fname):
The instance of Dipole class
"""

# For supporting tests in test_gui.py
if isinstance(fname, StringIO):
return _read_dipole_txt(fname)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

you will want to copy the read_dipole() function from the master branch and replace it in your code as well


fname = str(fname)
if not os.path.exists(fname):
raise FileNotFoundError('File not found at path %s.' % (fname,))
Expand Down Expand Up @@ -474,10 +472,12 @@ def plot(self, tmin=None, tmax=None, layer='agg', decim=None, ax=None,

Parameters
----------
tmin : float or None
Start time of plot (in ms). If None, plot entire simulation.
tmax : float or None
End time of plot (in ms). If None, plot entire simulation.
tmin : float | None [deprecated]
Start time of plot in milliseconds.
If None, plot entire simulation.
tmax : float | None [deprecated]
End time of plot in milliseconds.
If None, plot entire simulation.
layer : str
The layer to plot. Can be one of 'agg', 'L2', and 'L5'
decimate : int
Expand All @@ -489,11 +489,18 @@ def plot(self, tmin=None, tmax=None, layer='agg', decim=None, ax=None,
show : bool
If True, show the figure

(tmin and tmax are deprecated)
tmin : float or None
Start time of plot (in ms). If None, plot entire simulation.
tmax : float or None
End time of plot (in ms). If None, plot entire simulation.

Returns
-------
fig : instance of plt.fig
The matplotlib figure handle.
"""

return plot_dipole(self, tmin=tmin, tmax=tmax, ax=ax, layer=layer,
decim=decim, color=color, show=show)

Expand Down
14 changes: 8 additions & 6 deletions hnn_core/extracellular.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@

from .externals.mne import _validate_type, _check_option


def calculate_csd2d(lfp_data, delta=1):
"""Current source density (CSD) estimation

Expand Down Expand Up @@ -457,11 +456,12 @@ def plot_lfp(self, *, trial_no=None, contact_no=None, tmin=None, tmax=None,
Trial number(s) to plot
contact_no : int | list of int | slice
Electrode contact number(s) to plot
tmin : float | None
Start time of plot in milliseconds. If None, plot entire
simulation.
tmax : float | None
End time of plot in milliseconds. If None, plot entire simulation.
tmin : float | None [deprecated]
Start time of plot in milliseconds.
If None, plot entire simulation.
tmax : float | None [deprecated]
End time of plot in milliseconds.
If None, plot entire simulation.
ax : instance of matplotlib figure | None
The matplotlib axis
decim : int | list of int | None (default)
Expand Down Expand Up @@ -574,6 +574,7 @@ class _ExtracellularArrayBuilder(object):
The instance of :class:`hnn_core.extracellular.ExtracellularArray` to
build in NEURON-Python
"""

def __init__(self, array):
self.array = array
self.n_contacts = array.n_contacts
Expand Down Expand Up @@ -718,3 +719,4 @@ def _get_nrn_times(self):
return self._nrn_times.to_python()
else:
raise RuntimeError('Simulation not yet run!')

65 changes: 47 additions & 18 deletions hnn_core/viz.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,10 @@
import numpy as np
from itertools import cycle
import colorsys
import warnings

from .externals.mne import _validate_type


def _lighten_color(color, amount=0.5):
import matplotlib.colors as mc
try:
Expand All @@ -21,7 +21,7 @@ def _lighten_color(color, amount=0.5):
return colorsys.hls_to_rgb(c[0], 1 - amount * (1 - c[1]), c[2])


def _get_plot_data_trange(times, data, tmin, tmax):
def _get_plot_data_trange(times, data, tmin=None, tmax=None):
"""Get slices of times and data based on tmin and tmax"""
if isinstance(times, list):
times = np.array(times)
Expand Down Expand Up @@ -79,8 +79,8 @@ def plt_show(show=True, fig=None, **kwargs):


def plot_laminar_lfp(times, data, contact_labels, tmin=None, tmax=None,
ax=None, decim=None, color='cividis',
voltage_offset=50, voltage_scalebar=200, show=True):
ax=None, decim=None, color='cividis', voltage_offset=50,
voltage_scalebar=200, show=True):
"""Plot laminar extracellular electrode array voltage time series.

Parameters
Expand All @@ -89,10 +89,11 @@ def plot_laminar_lfp(times, data, contact_labels, tmin=None, tmax=None,
Sampling times (in ms).
data : Two-dimensional Numpy array
The extracellular voltages as an (n_contacts, n_times) array.
tmin : float | None
tmin : float | None [deprecated]
Start time of plot in milliseconds. If None, plot entire simulation.
tmax : float | None
tmax : float | None [deprecated]
End time of plot in milliseconds. If None, plot entire simulation.

ax : instance of matplotlib figure | None
The matplotlib axis
decim : int | list of int | None (default)
Expand Down Expand Up @@ -168,11 +169,11 @@ def plot_laminar_lfp(times, data, contact_labels, tmin=None, tmax=None,
trace_offsets = np.arange(n_offsets)[:, np.newaxis] * voltage_offset

for contact_no, trace in enumerate(np.atleast_2d(data)):
plot_data, plot_times = _get_plot_data_trange(times, trace, tmin, tmax)
plot_data = trace
plot_times = times

if decim is not None:
plot_data, plot_times = _decimate_plot_data(decim, plot_data,
plot_times)
plot_data, plot_times = _decimate_plot_data(decim, trace, times)

if isinstance(color, np.ndarray):
col = color[contact_no]
Expand All @@ -183,6 +184,14 @@ def plot_laminar_lfp(times, data, contact_labels, tmin=None, tmax=None,
ax.plot(plot_times, plot_data + trace_offsets[contact_no],
label=f'C{contact_no}', color=col)

# To be removed after deprecation cycle
if tmin is not None or tmax is not None:
ax.set_xlim(left=tmin, right=tmax)
warnings.warn('tmin and tmax are deprecated and will be removed in future releases of hnn-core.'
'By default, dipoles and laminar LFPs are now plotted from 0 to tstop.', DeprecationWarning)
else:
ax.set_xlim(right=times[-1])

if voltage_offset is not None:
ax.set_ylim(-voltage_offset, n_offsets * voltage_offset)
ylabel = 'Individual contact traces'
Expand Down Expand Up @@ -228,9 +237,9 @@ def plot_dipole(dpl, tmin=None, tmax=None, ax=None, layer='agg', decim=None,
----------
dpl : instance of Dipole | list of Dipole instances
The Dipole object.
tmin : float or None
tmin : float | None [deprecated]
Start time of plot in milliseconds. If None, plot entire simulation.
tmax : float or None
tmax : float | None [deprecated]
End time of plot in milliseconds. If None, plot entire simulation.
ax : instance of matplotlib figure | None
The matplotlib axis
Expand Down Expand Up @@ -288,9 +297,14 @@ def plot_dipole(dpl, tmin=None, tmax=None, ax=None, layer='agg', decim=None,
if layer in dpl_trial.data.keys():

# extract scaled data and times
data, times = _get_plot_data_trange(dpl_trial.times,
dpl_trial.data[layer],
tmin, tmax)
data = dpl_trial.data[layer]
times = dpl_trial.times

# to be removed after deprecation cycle
if tmin is not None:
data, times = _get_plot_data_trange(dpl_trial.times,
dpl_trial.data[layer])

if decim is not None:
data, times = _decimate_plot_data(decim, data, times)
if idx == len(dpl) - 1 and average:
Expand All @@ -300,7 +314,14 @@ def plot_dipole(dpl, tmin=None, tmax=None, ax=None, layer='agg', decim=None,
alpha = 0.5 if average else 1.
ax.plot(times, data, color=_lighten_color(color, 0.5),
alpha=alpha, lw=1.)
# To be removed after deprecation cycle
if tmin is not None or tmax is not None:
ax.set_xlim(left=tmin, right=tmax)
warnings.warn('tmin and tmax are deprecated and will be removed in future releases of hnn-core.'
'By default, dipoles and laminar LFPs are now plotted from 0 to tstop.', DeprecationWarning)

else:
ax.set_xlim(right=times[-1])
if average:
ax.legend()

Expand Down Expand Up @@ -330,6 +351,7 @@ def plot_spikes_hist(cell_response, trial_idx=None, ax=None, spike_types=None,
----------
cell_response : instance of CellResponse
The CellResponse object from net.cell_response
End time of plot in milliseconds. If None, plot entire simulation.
trial_idx : int | list of int | None
Index of trials to be plotted. If None, all trials plotted.
ax : instance of matplotlib axis | None
Expand Down Expand Up @@ -480,11 +502,14 @@ def plot_spikes_hist(cell_response, trial_idx=None, ax=None, spike_types=None,
ax.set_ylabel("Counts")
ax.legend()

ax.set_xlim(left=0, right=cell_response.times[-1])

plt_show(show)
return ax.get_figure()


def plot_spikes_raster(cell_response, trial_idx=None, ax=None, show=True):
def plot_spikes_raster(cell_response, trial_idx=None,
ax=None, show=True):
"""Plot the aggregate spiking activity according to cell type.

Parameters
Expand Down Expand Up @@ -540,6 +565,9 @@ def plot_spikes_raster(cell_response, trial_idx=None, ax=None, show=True):
cell_type_times, cell_type_ypos = [], []
for gid in cell_type_gids:
gid_time = spike_times[spike_gids == gid]
_, gid_time = _get_plot_data_trange(
gid_time, gid_time, cell_response.times[0],
cell_response.times[-1])
cell_type_times.append(gid_time)
cell_type_ypos.append(ypos)
ypos = ypos - 1
Expand All @@ -554,7 +582,7 @@ def plot_spikes_raster(cell_response, trial_idx=None, ax=None, show=True):
ax.set_facecolor('k')
ax.set_xlabel('Time (ms)')
ax.get_yaxis().set_visible(False)
ax.set_xlim(left=0)
ax.set_xlim(left=0, right=cell_response.times[-1])

plt_show(show)
return ax.get_figure()
Expand Down Expand Up @@ -1185,8 +1213,8 @@ def _onclick(event):
return ax.get_figure()


def plot_laminar_csd(times, data, contact_labels, ax=None, colorbar=True,
show=True):
def plot_laminar_csd(times, data, contact_labels,
ax=None, colorbar=True, show=True):
"""Plot laminar current source density (CSD) estimation from LFP array.

Parameters
Expand Down Expand Up @@ -1224,6 +1252,7 @@ def plot_laminar_csd(times, data, contact_labels, ax=None, colorbar=True,

ax.set_xlabel('Time (ms)')
ax.set_ylabel('Electrode depth')
ax.set_xlim(right=times[-1])
plt.tight_layout()
plt_show(show)

Expand Down
Loading