Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Nov 15, 2024
1 parent d3bee9a commit 0865dcd
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 16 deletions.
2 changes: 1 addition & 1 deletion src/spikeinterface/sorters/internal/spyking_circus2.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,7 +286,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
sparsity = compute_sparsity(templates, noise_levels, **params["sparsity"])
templates = templates.to_sparse(sparsity)
templates = remove_empty_templates(templates)

if params["debug"]:
templates.to_zarr(folder_path=clustering_folder / "templates")
sorting = sorting.save(folder=clustering_folder / "sorting")
Expand Down
23 changes: 15 additions & 8 deletions src/spikeinterface/sortingcomponents/clustering/circus.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ class CircusClustering:
"n_svd": [5, 2],
"ms_before": 0.5,
"ms_after": 0.5,
"noise_threshold" : 1,
"noise_threshold": 1,
"rank": 5,
"noise_levels": None,
"tmp_folder": None,
Expand Down Expand Up @@ -231,13 +231,20 @@ def main_function(cls, recording, peaks, params, job_kwargs=dict()):
params["noise_levels"] = get_noise_levels(recording, return_scaled=False, **job_kwargs)

templates_array, templates_array_std = estimate_templates_with_accumulator(
recording, spikes, unit_ids, nbefore, nafter, return_scaled=False, return_std=True, job_name=None, **job_kwargs
recording,
spikes,
unit_ids,
nbefore,
nafter,
return_scaled=False,
return_std=True,
job_name=None,
**job_kwargs,
)

peak_snrs = np.abs(templates_array[:, nbefore, :])/templates_array_std[:, nbefore, :]
valid_templates = np.linalg.norm(peak_snrs, axis=1)/np.linalg.norm(params["noise_levels"])
valid_templates = valid_templates > params["noise_threshold"]

peak_snrs = np.abs(templates_array[:, nbefore, :]) / templates_array_std[:, nbefore, :]
valid_templates = np.linalg.norm(peak_snrs, axis=1) / np.linalg.norm(params["noise_levels"])
valid_templates = valid_templates > params["noise_threshold"]

if d["rank"] is not None:
from spikeinterface.sortingcomponents.matching.circus import compress_templates
Expand All @@ -254,12 +261,12 @@ def main_function(cls, recording, peaks, params, job_kwargs=dict()):
probe=recording.get_probe(),
is_scaled=False,
)

sparsity = compute_sparsity(templates, noise_levels=params["noise_levels"], **params["sparsity"])
templates = templates.to_sparse(sparsity)
empty_templates = templates.sparsity_mask.sum(axis=1) == 0
templates = remove_empty_templates(templates)

mask = np.isin(peak_labels, np.where(empty_templates)[0])
peak_labels[mask] = -1

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ class RandomProjectionClustering:
"random_seed": 42,
"noise_levels": None,
"smoothing_kwargs": {"window_length_ms": 0.25},
"noise_threshold" : 1,
"noise_threshold": 1,
"tmp_folder": None,
"verbose": True,
}
Expand Down Expand Up @@ -134,11 +134,19 @@ def main_function(cls, recording, peaks, params, job_kwargs=dict()):
params["noise_levels"] = get_noise_levels(recording, return_scaled=False, **job_kwargs)

templates_array, templates_array_std = estimate_templates_with_accumulator(
recording, spikes, unit_ids, nbefore, nafter, return_scaled=False, return_std=True, job_name=None, **job_kwargs
recording,
spikes,
unit_ids,
nbefore,
nafter,
return_scaled=False,
return_std=True,
job_name=None,
**job_kwargs,
)
peak_snrs = np.abs(templates_array[:, nbefore, :])/templates_array_std[:, nbefore, :]
valid_templates = np.linalg.norm(peak_snrs, axis=1)/np.linalg.norm(params["noise_levels"])

peak_snrs = np.abs(templates_array[:, nbefore, :]) / templates_array_std[:, nbefore, :]
valid_templates = np.linalg.norm(peak_snrs, axis=1) / np.linalg.norm(params["noise_levels"])
valid_templates = valid_templates > params["noise_threshold"]

templates = Templates(
Expand All @@ -151,12 +159,12 @@ def main_function(cls, recording, peaks, params, job_kwargs=dict()):
probe=recording.get_probe(),
is_scaled=False,
)

sparsity = compute_sparsity(templates, noise_levels=params["noise_levels"], **params["sparsity"])
templates = templates.to_sparse(sparsity)
empty_templates = templates.sparsity_mask.sum(axis=1) == 0
templates = remove_empty_templates(templates)

mask = np.isin(peak_labels, np.where(empty_templates)[0])
peak_labels[mask] = -1

Expand Down

0 comments on commit 0865dcd

Please sign in to comment.