WIP: Add notebook
CBroz1 committed Jan 10, 2025
1 parent a6f26f5 commit 9d11d80
Showing 5 changed files with 430 additions and 37 deletions.
352 changes: 352 additions & 0 deletions notebooks/_TEMP_Burst.ipynb

Large diffs are not rendered by default.

2 changes: 2 additions & 0 deletions src/spyglass/spikesorting/v0/spikesorting_burst.py
@@ -425,6 +425,8 @@ def _validate_pair(

def _validate_pairs(self, key, pairs):
query = self.BurstPairUnit & key
if isinstance(pairs, tuple) and len(pairs) == 2:
pairs = [pairs]
valid_pairs = []
for p in pairs:
if valid_pair := self._validate_pair(query, *p):
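For context, a minimal standalone sketch of what the new two-line guard accomplishes (names here are illustrative, not the table code): a lone `(unit1, unit2)` tuple is wrapped in a list so the validation loop sees a sequence of pairs either way.

```python
# Sketch of the single-pair normalization added above (illustrative only).
def normalize_pairs(pairs):
    """Accept one (a, b) tuple or an iterable of such tuples."""
    if isinstance(pairs, tuple) and len(pairs) == 2:
        pairs = [pairs]  # wrap the lone pair so iteration yields pairs
    return list(pairs)

assert normalize_pairs((1, 2)) == [(1, 2)]
assert normalize_pairs([(1, 2), (3, 4)]) == [(1, 2), (3, 4)]
```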
97 changes: 66 additions & 31 deletions src/spyglass/spikesorting/v1/burst_curation.py
@@ -12,7 +12,11 @@
)

from spyglass.decoding.utils import _get_peak_amplitude
from spyglass.spikesorting.v1.metric_curation import CurationV1, MetricCuration
from spyglass.spikesorting.v1.metric_curation import (
CurationV1,
MetricCuration,
MetricCurationSelection,
)
from spyglass.utils import logger

schema = dj.schema("burst_v1") # TODO: rename to spikesorting_burst_v1
@@ -124,13 +128,17 @@ def _null_insert(self, key, msg="No units found for") -> None:

def _curation_key(self, key):
"""Get the CurationV1 key for a given BurstPair key"""
return (
(self & key).proj() * MetricCurationSelection * CurationV1
).fetch1("curation_id", "sorting_id", as_dict=True)
ret = (
(BurstPairSelection & key)
* MetricCuration
* MetricCurationSelection
).fetch("curation_id", "sorting_id", as_dict=True)
if len(ret) != 1:
raise ValueError(f"Found {len(ret)} curation entries for {key}")
return ret[0]
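The rewrite swaps `fetch1` (which raises a generic DataJoint error on zero or many matches) for `fetch` plus an explicit count check with a message naming the offending key. A minimal sketch of that pattern, with `rows` standing in for the fetched result:

```python
# Sketch of the fetch-then-verify pattern used above (illustrative names).
def exactly_one(rows: list, key: dict) -> dict:
    """Return the single matching row, or fail with a descriptive error."""
    if len(rows) != 1:
        raise ValueError(f"Found {len(rows)} curation entries for {key}")
    return rows[0]

assert exactly_one([{"curation_id": 1}], key={"a": 1}) == {"curation_id": 1}
```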

@staticmethod
def _get_peak_amps1(
waves: WaveformExtractor, unit: int, timestamp_ind: int
self, waves: WaveformExtractor, unit: int, timestamp_ind: int
):
"""Get peak value for a unit at a given timestamp index"""
wave = _get_peak_amplitude(
@@ -139,8 +147,20 @@ def _get_peak_amps1(
peak_sign="neg",
estimate_peak_time=True,
)

# PROBLEM: example key showed timestamp_ind larger than wave length
timestamp_ind, wave = self._truncate_to_shortest(
"", timestamp_ind, wave
)
return wave[timestamp_ind]

def _truncate_to_shortest(self, msg="", *args):
"""Truncate all arrays to the shortest length"""
if msg and not all([len(a) == len(args[0]) for a in args]):
logger.warning(f"Truncating arrays to shortest length: {msg}")
min_len = min([len(a) for a in args])
return [a[:min_len] for a in args]
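A self-contained sketch of `_truncate_to_shortest` on mismatched inputs (invented lengths; `print` stands in for the spyglass `logger`):

```python
import numpy as np

def truncate_to_shortest(msg="", *args):
    """Trim every array to the length of the shortest one."""
    if msg and not all(len(a) == len(args[0]) for a in args):
        print(f"Truncating arrays to shortest length: {msg}")
    min_len = min(len(a) for a in args)
    return [a[:min_len] for a in args]

amps = np.arange(5)           # 5 peak amplitudes
times = np.arange(7) * 0.1    # 7 spike times; lengths disagree
amps, times = truncate_to_shortest("unit 3", amps, times)
assert len(amps) == len(times) == 5
```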

def get_peak_amps(
self, key: dict
) -> Tuple[Dict[int, np.ndarray], Dict[int, np.ndarray]]:
@@ -175,12 +195,15 @@ def get_peak_amps(

peak_amps, peak_timestamps = {}, {}
for unit_id in unit_ids:
timestamp = np.asarray(sortong["spike_times"][unit_id])
timestamp = np.asarray(sorting["spike_times"][unit_id])
timestamp_ind = np.argsort(timestamp)
peak_amps[unit_id] = self._get_peak_amps1(
waves, unit_id, timestamp_ind
upeak = self._get_peak_amps1(waves, unit_id, timestamp_ind)
utime = timestamp[timestamp_ind]
upeak, utime = self._truncate_to_shortest(
f"unit {unit_id}", upeak, utime
)
peak_timestamps[unit_id] = timestamp[timestamp_ind]
peak_amps[unit_id] = upeak
peak_timestamps[unit_id] = utime

self._peak_amp_cache[key_hash] = peak_amps, peak_timestamps

@@ -304,17 +327,14 @@ def make(self, key) -> None:
self.BurstPairUnit.insert(unit_pairs)

@staticmethod
def _plot_metrics(sg_query):
def _plot_metrics(sort_query):
"""parameters are 4 metrics to be plotted against each other.
Parameters
----------
wf_similarity : dict
waveform similarities
isi_violation : dict
isi violation
xcorrel_asymm : dict
spike cross correlogram asymmetry
sort_query : dj.QueryExpression
query to get the metrics for plotting, including wf_similarity
and xcorrel_asymm. One row per sorting_id.
Returns
-------
@@ -323,7 +343,7 @@ def _plot_metrics(sg_query):

fig, ax = plt.subplots(1, 1, figsize=(12, 5))

for color_ind, row in enumerate(sg_query):
for color_ind, row in enumerate(sort_query):
color = dict(color=f"C{color_ind}")
wf = row["wf_similarity"]
ca = row["xcorrel_asymm"]
@@ -337,25 +357,29 @@

return fig
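For reference, a minimal standalone version of the plotting pattern used here: one Matplotlib color-cycle entry (`C0`, `C1`, ...) per fetched row, scattering waveform similarity against cross-correlogram asymmetry. The data below are invented; only the field names come from the docstring above.

```python
import matplotlib.pyplot as plt

rows = [  # stand-in for the DataJoint query rows
    {"wf_similarity": [0.9, 0.7], "xcorrel_asymm": [0.1, 0.4]},
    {"wf_similarity": [0.5, 0.6], "xcorrel_asymm": [0.3, 0.2]},
]

fig, ax = plt.subplots(1, 1, figsize=(12, 5))
for color_ind, row in enumerate(rows):
    ax.scatter(
        row["wf_similarity"],
        row["xcorrel_asymm"],
        color=f"C{color_ind}",  # one color-cycle entry per row
        label=f"row {color_ind}",
    )
ax.set_xlabel("waveform similarity")
ax.set_ylabel("cross-correlogram asymmetry")
ax.legend()
```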

def _get_fig_by_sg_id(self, key, sort_group_ids=None):
query = self.BurstPairUnit & key
def _get_fig_by_sort_id(self, key, sorting_ids=None):
query = (
(self.BurstPairUnit & key)
* MetricCuration
* MetricCurationSelection
)

if isinstance(sort_group_ids, int):
sort_group_ids = [sort_group_ids]
if isinstance(sorting_ids, str):
sorting_ids = [sorting_ids]

if sort_group_ids:
query &= f'sort_group_id IN ({",".join(map(str, sort_group_ids))})'
if sorting_ids:
query = query.restrict_by_list("sorting_id", sorting_ids)
else:
sort_group_ids = np.unique(query.fetch("sort_group_id"))
sorting_ids = np.unique(query.fetch("sorting_id"))

fig = {}
for sort_group_id in sort_group_ids:
sg_query = query & {"sort_group_id": sort_group_id}
for sort_group_id in sorting_ids:
sg_query = query & {"sorting_id": sort_group_id}
fig[sort_group_id] = self._plot_metrics(sg_query)
return fig

def plot_by_sort_group_ids(self, key, sort_group_ids=None):
fig = self._get_fig_by_sg_id(key, sort_group_ids)
def plot_by_sorting_ids(self, key, sort_group_ids=None):
fig = self._get_fig_by_sort_id(key, sort_group_ids)
for sg_id, f in fig.items():
title = f"sort group {sg_id}"
managed_fig, _ = plt.subplots(
@@ -391,6 +415,8 @@ def _validate_pair(

def _validate_pairs(self, key, pairs):
query = self.BurstPairUnit & key
if isinstance(pairs, tuple) and len(pairs) == 2:
pairs = [pairs]
valid_pairs = []
for p in pairs:
if valid_pair := self._validate_pair(query, *p):
@@ -404,7 +430,7 @@ def investigate_pair_xcorrel(self, key, to_investigate_pairs):

col_num = int(np.ceil(len(used_pairs) / 2))

fig = self._get_fig_by_sg_id(key)
fig = self._get_fig_by_sort_id(key)

fig, axes = plt.subplots(
2,
@@ -508,6 +534,7 @@ def plot_1peak_over_time(
row_duration: int = 600,
show_plot: bool = False,
):

max_channel = np.argmax(-np.mean(voltages, 0))
time_since = timestamps - timestamps[0]
row_num = int(np.ceil(time_since[-1] / row_duration))
@@ -522,12 +549,20 @@
squeeze=False,
)

# PROBLEM: example key showed sub_ind larger than voltages
def select_voltages(voltages, sub_ind):
if len(sub_ind) > len(voltages):
sub_ind = sub_ind[: len(voltages)]
logger.warning("Timestamp index out of bounds, truncating")
return voltages[sub_ind, max_channel]
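A standalone sketch of the out-of-bounds guard (invented shapes). One subtlety the helper leaves to the caller: for the scatter's x and y to stay the same length, the truncated mask has to be applied to the timestamps as well.

```python
import numpy as np

def bounded_mask(mask, n_rows):
    """Trim a boolean mask so it can index an array with only n_rows rows."""
    if len(mask) > n_rows:
        mask = mask[:n_rows]  # drop mask entries with no matching sample
    return mask

time_since = np.linspace(0, 10, 8)   # 8 timestamps
voltages = np.random.randn(6, 4)     # only 6 voltage samples, 4 channels
max_channel = int(np.argmax(-voltages.mean(axis=0)))

mask = bounded_mask((time_since >= 0) & (time_since <= 5), len(voltages))
x = time_since[: len(voltages)][mask]  # same truncation on both axes
y = voltages[mask, max_channel]
assert len(x) == len(y)
```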

for ind in range(row_num):
t0 = ind * row_duration
t1 = t0 + row_duration
sub_ind = np.logical_and(time_since >= t0, time_since <= t1)
# PROBLEM: axes[2, 0] out of bounds for some pairs
axes[ind, 0].scatter(
time_since[sub_ind] - t0, voltages[sub_ind, max_channel]
time_since[sub_ind] - t0, select_voltages(voltages, sub_ind)
)

if not show_plot:
10 changes: 5 additions & 5 deletions src/spyglass/spikesorting/v1/curation.py
@@ -261,15 +261,15 @@ def get_merged_sorting(cls, key: dict) -> si.BaseSorting:
) as io:
nwbfile = io.read()
nwb_sorting = nwbfile.objects[curation_key["object_id"]]
merge_groups = nwb_sorting["merge_groups"][:]
merge_groups = nwb_sorting.get("merge_groups")

if merge_groups:
units_to_merge = _merge_dict_to_list(merge_groups)
if merge_groups: # bumped slice down to here for case w/o merge_groups
units_to_merge = _merge_dict_to_list(merge_groups[:])
return sc.MergeUnitsSorting(
parent_sorting=si_sorting, units_to_merge=units_to_merge
)
else:
return si_sorting

return si_sorting

@classmethod
def get_sort_group_info(cls, key: dict) -> dj.Table:
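A minimal sketch of why the slice moved in `get_merged_sorting` above: `.get` returns `None` when the column is absent, so the `[:]` read is deferred until the column is known to exist (a plain dict stands in for the NWB units container):

```python
# Dict stands in for the NWB container; real code reads nwbfile objects.
nwb_sorting = {"spike_times": [0.1, 0.2]}   # no "merge_groups" column here

merge_groups = nwb_sorting.get("merge_groups")  # None instead of KeyError
if merge_groups:
    units_to_merge = merge_groups[:]  # slice only when data is present
else:
    units_to_merge = None  # caller falls back to the unmerged sorting
```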
6 changes: 5 additions & 1 deletion src/spyglass/spikesorting/v1/metric_curation.py
@@ -1,5 +1,6 @@
import os
import uuid
from pathlib import Path
from time import time
from typing import Any, Dict, List, Union

@@ -305,7 +306,10 @@ def get_waveforms(self, key: dict, overwrite: bool = True):
recording = sp.whiten(recording, dtype=np.float64)

waveforms_dir = temp_dir + "/" + str(key["metric_curation_id"])
os.makedirs(waveforms_dir, exist_ok=True)
wf_dir_obj = Path(waveforms_dir)
wf_dir_obj.mkdir(parents=True, exist_ok=True)
if not any(wf_dir_obj.iterdir()): # if the directory is empty
overwrite = True

# Extract non-sparse waveforms by default
waveform_params.setdefault("sparse", False)
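A standalone sketch of the directory handling added to `get_waveforms` (the path below is illustrative): an empty waveforms directory means there is nothing cached to reuse, so extraction is forced regardless of the caller's `overwrite` flag.

```python
from pathlib import Path
import tempfile

wf_dir = Path(tempfile.gettempdir()) / "waveforms_demo"  # illustrative path
wf_dir.mkdir(parents=True, exist_ok=True)  # replaces os.makedirs

overwrite = False
if not any(wf_dir.iterdir()):  # directory is empty -> no cached waveforms
    overwrite = True
```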
