Skip to content

Commit

Permalink
add deprecation cycle tmin/tmax
Browse files Browse the repository at this point in the history
  • Loading branch information
katduecker committed May 10, 2024
1 parent fdc6335 commit 89042b7
Show file tree
Hide file tree
Showing 3 changed files with 78 additions and 31 deletions.
40 changes: 26 additions & 14 deletions hnn_core/dipole.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@

from .viz import plot_dipole, plot_psd, plot_tfr_morlet


def simulate_dipole(net, tstop, dt=0.025, n_trials=None, record_vsec=False,
record_isec=False, postproc=False):
"""Simulate a dipole given the experiment parameters.
Expand Down Expand Up @@ -106,25 +105,20 @@ def simulate_dipole(net, tstop, dt=0.025, n_trials=None, record_vsec=False,
return dpls


def _read_dipole_txt(fname, extension='.txt'):
def _read_dipole_txt(fname):
"""Read dipole values from a txt file and create a Dipole instance.
Parameters
----------
fname : str or io.StringIO
Full path to the input file (.txt or .csv) or
Content of file in memory as a StringIO
fname : str
Full path to the input file (.txt)
Returns
-------
dpl : Dipole
The instance of Dipole class
"""
if extension == '.csv':
# read from a csv file ignoring the headers
dpl_data = np.genfromtxt(fname, delimiter=',',
skip_header=1, dtype=float)
else:
dpl_data = np.loadtxt(fname, dtype=float)
dpl_data = np.loadtxt(fname, dtype=float)
ncols = dpl_data.shape[1]
if ncols not in (2, 4):
raise ValueError(
Expand Down Expand Up @@ -179,6 +173,10 @@ def read_dipole(fname):
The instance of Dipole class
"""

# For supporting tests in test_gui.py
if isinstance(fname, StringIO):
return _read_dipole_txt(fname)

fname = str(fname)
if not os.path.exists(fname):
raise FileNotFoundError('File not found at path %s.' % (fname,))
Expand Down Expand Up @@ -468,11 +466,18 @@ def savgol_filter(self, h_freq):
self.sfreq)
return self

def plot(self, layer='agg', decim=None, ax=None, color='k', show=True):
def plot(self, tmin=None, tmax=None, layer='agg', decim=None, ax=None,
color='k', show=True):
"""Simple layer-specific plot function.
Parameters
----------
tmin : float | None [deprecated]
Start time of plot in milliseconds.
If None, plot entire simulation.
tmax : float | None [deprecated]
End time of plot in milliseconds.
If None, plot entire simulation.
layer : str
The layer to plot. Can be one of 'agg', 'L2', and 'L5'
decimate : int
Expand All @@ -484,13 +489,20 @@ def plot(self, layer='agg', decim=None, ax=None, color='k', show=True):
show : bool
If True, show the figure
(tmin and tmax are deprecated)
tmin : float or None
Start time of plot (in ms). If None, plot entire simulation.
tmax : float or None
End time of plot (in ms). If None, plot entire simulation.
Returns
-------
fig : instance of plt.fig
The matplotlib figure handle.
"""
return plot_dipole(self, ax=ax, layer=layer, decim=decim, color=color,
show=show)

return plot_dipole(self, tmin=tmin, tmax=tmax, ax=ax, layer=layer,
decim=decim, color=color, show=show)

def plot_psd(self, fmin=0, fmax=None, tmin=None, tmax=None, layer='agg',
color=None, label=None, ax=None, show=True):
Expand Down
16 changes: 11 additions & 5 deletions hnn_core/extracellular.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@

from .externals.mne import _validate_type, _check_option


def calculate_csd2d(lfp_data, delta=1):
"""Current source density (CSD) estimation
Expand Down Expand Up @@ -443,9 +442,9 @@ def smooth(self, window_len):

return self

def plot_lfp(self, *, trial_no=None, contact_no=None, ax=None, decim=None,
color='cividis', voltage_offset=50, voltage_scalebar=200,
show=True):
def plot_lfp(self, *, trial_no=None, contact_no=None, tmin=None, tmax=None,
ax=None, decim=None, color='cividis', voltage_offset=50,
voltage_scalebar=200, show=True):
"""Plot laminar local field potential time series.
One plot is created for each trial. Multiple trials can be overlaid
Expand All @@ -457,6 +456,12 @@ def plot_lfp(self, *, trial_no=None, contact_no=None, ax=None, decim=None,
Trial number(s) to plot
contact_no : int | list of int | slice
Electrode contact number(s) to plot
tmin : float | None [deprecated]
Start time of plot in milliseconds.
If None, plot entire simulation.
tmax : float | None [deprecated]
End time of plot in milliseconds.
If None, plot entire simulation.
ax : instance of matplotlib figure | None
The matplotlib axis
decim : int | list of int | None (default)
Expand Down Expand Up @@ -503,7 +508,7 @@ def plot_lfp(self, *, trial_no=None, contact_no=None, ax=None, decim=None,

for trial_data in plot_data:
fig = plot_laminar_lfp(
self.times, trial_data, ax=ax,
self.times, trial_data, tmin=tmin, tmax=tmax, ax=ax,
decim=decim, color=color,
voltage_offset=voltage_offset,
voltage_scalebar=voltage_scalebar,
Expand Down Expand Up @@ -714,3 +719,4 @@ def _get_nrn_times(self):
return self._nrn_times.to_python()
else:
raise RuntimeError('Simulation not yet run!')

53 changes: 41 additions & 12 deletions hnn_core/viz.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,10 @@
import numpy as np
from itertools import cycle
import colorsys
import warnings

from .externals.mne import _validate_type


def _lighten_color(color, amount=0.5):
import matplotlib.colors as mc
try:
Expand All @@ -21,7 +21,7 @@ def _lighten_color(color, amount=0.5):
return colorsys.hls_to_rgb(c[0], 1 - amount * (1 - c[1]), c[2])


def _get_plot_data_trange(times, data, tmin, tmax):
def _get_plot_data_trange(times, data, tmin=None, tmax=None):
"""Get slices of times and data based on tmin and tmax"""
if isinstance(times, list):
times = np.array(times)
Expand Down Expand Up @@ -78,9 +78,9 @@ def plt_show(show=True, fig=None, **kwargs):
(fig or plt).show(**kwargs)


def plot_laminar_lfp(times, data, contact_labels, ax=None, decim=None,
color='cividis', voltage_offset=50, voltage_scalebar=200,
show=True):
def plot_laminar_lfp(times, data, contact_labels, tmin=None, tmax=None,
ax=None, decim=None, color='cividis', voltage_offset=50,
voltage_scalebar=200, show=True):
"""Plot laminar extracellular electrode array voltage time series.
Parameters
Expand All @@ -89,6 +89,11 @@ def plot_laminar_lfp(times, data, contact_labels, ax=None, decim=None,
Sampling times (in ms).
data : Two-dimensional Numpy array
The extracellular voltages as an (n_contacts, n_times) array.
tmin : float | None [deprecated]
Start time of plot in milliseconds. If None, plot entire simulation.
tmax : float | None [deprecated]
End time of plot in milliseconds. If None, plot entire simulation.
ax : instance of matplotlib figure | None
The matplotlib axis
decim : int | list of int | None (default)
Expand Down Expand Up @@ -178,7 +183,14 @@ def plot_laminar_lfp(times, data, contact_labels, ax=None, decim=None,
col = color
ax.plot(plot_times, plot_data + trace_offsets[contact_no],
label=f'C{contact_no}', color=col)
ax.set_xlim(right=times[-1])

# To be removed after deprecation cycle
if tmin is not None or tmax is not None:
ax.set_xlim(left=tmin, right=tmax)
warnings.warn('tmin and tmax are deprecated and will be removed in future releases of hnn-core.'
'By default, dipoles and laminar LFPs are now plotted from 0 to tstop.', DeprecationWarning)
else:
ax.set_xlim(right=times[-1])

if voltage_offset is not None:
ax.set_ylim(-voltage_offset, n_offsets * voltage_offset)
Expand Down Expand Up @@ -217,14 +229,18 @@ def plot_laminar_lfp(times, data, contact_labels, ax=None, decim=None,
return ax.get_figure()


def plot_dipole(dpl, ax=None, layer='agg', decim=None,
def plot_dipole(dpl, tmin=None, tmax=None, ax=None, layer='agg', decim=None,
color='k', label="average", average=False, show=True):
"""Simple layer-specific plot function.
Parameters
----------
dpl : instance of Dipole | list of Dipole instances
The Dipole object.
tmin : float | None [deprecated]
Start time of plot in milliseconds. If None, plot entire simulation.
tmax : float | None [deprecated]
End time of plot in milliseconds. If None, plot entire simulation.
ax : instance of matplotlib figure | None
The matplotlib axis
layer : str
Expand Down Expand Up @@ -283,6 +299,12 @@ def plot_dipole(dpl, ax=None, layer='agg', decim=None,
# extract scaled data and times
data = dpl_trial.data[layer]
times = dpl_trial.times

# to be removed after deprecation cycle
if tmin is not None:
data, times = _get_plot_data_trange(dpl_trial.times,
dpl_trial.data[layer])

if decim is not None:
data, times = _decimate_plot_data(decim, data, times)
if idx == len(dpl) - 1 and average:
Expand All @@ -292,7 +314,14 @@ def plot_dipole(dpl, ax=None, layer='agg', decim=None,
alpha = 0.5 if average else 1.
ax.plot(times, data, color=_lighten_color(color, 0.5),
alpha=alpha, lw=1.)
ax.set_xlim(right=dpl_trial.times[-1])
# To be removed after deprecation cycle
if tmin is not None or tmax is not None:
ax.set_xlim(left=tmin, right=tmax)
warnings.warn('tmin and tmax are deprecated and will be removed in future releases of hnn-core.'
'By default, dipoles and laminar LFPs are now plotted from 0 to tstop.', DeprecationWarning)

else:
ax.set_xlim(right=times[-1])
if average:
ax.legend()

Expand Down Expand Up @@ -322,6 +351,7 @@ def plot_spikes_hist(cell_response, trial_idx=None, ax=None, spike_types=None,
----------
cell_response : instance of CellResponse
The CellResponse object from net.cell_response
End time of plot in milliseconds. If None, plot entire simulation.
trial_idx : int | list of int | None
Index of trials to be plotted. If None, all trials plotted.
ax : instance of matplotlib axis | None
Expand Down Expand Up @@ -471,7 +501,8 @@ def plot_spikes_hist(cell_response, trial_idx=None, ax=None, spike_types=None,

ax.set_ylabel("Counts")
ax.legend()
ax.set_xlim(right=cell_response.times[-1])

ax.set_xlim(left=0, right=cell_response.times[-1])

plt_show(show)
return ax.get_figure()
Expand Down Expand Up @@ -1210,9 +1241,7 @@ def plot_laminar_csd(times, data, contact_labels,
import matplotlib.pyplot as plt
if ax is None:
_, ax = plt.subplots(1, 1, constrained_layout=True)
times, data = _get_plot_data_trange(times, data, times[0], times[-1])
_, contact_labels = _get_plot_data_trange(
times, contact_labels, times[0], times[-1])

im = ax.pcolormesh(times, contact_labels, np.array(data),
cmap="jet_r", shading='auto')
ax.set_title("CSD")
Expand Down

0 comments on commit 89042b7

Please sign in to comment.