diff --git a/src/spikeinterface/sorters/internal/spyking_circus2.py b/src/spikeinterface/sorters/internal/spyking_circus2.py
index bea2c60770..beaa737dd4 100644
--- a/src/spikeinterface/sorters/internal/spyking_circus2.py
+++ b/src/spikeinterface/sorters/internal/spyking_circus2.py
@@ -286,7 +286,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
             sparsity = compute_sparsity(templates, noise_levels, **params["sparsity"])
             templates = templates.to_sparse(sparsity)
             templates = remove_empty_templates(templates)
-
+
             if params["debug"]:
                 templates.to_zarr(folder_path=clustering_folder / "templates")
                 sorting = sorting.save(folder=clustering_folder / "sorting")
diff --git a/src/spikeinterface/sortingcomponents/clustering/circus.py b/src/spikeinterface/sortingcomponents/clustering/circus.py
index 27b2c9493b..f1038c73bc 100644
--- a/src/spikeinterface/sortingcomponents/clustering/circus.py
+++ b/src/spikeinterface/sortingcomponents/clustering/circus.py
@@ -60,7 +60,7 @@ class CircusClustering:
         "n_svd": [5, 2],
         "ms_before": 0.5,
         "ms_after": 0.5,
-        "noise_threshold" : 1,
+        "noise_threshold": 1,
         "rank": 5,
         "noise_levels": None,
         "tmp_folder": None,
@@ -231,13 +231,20 @@ def main_function(cls, recording, peaks, params, job_kwargs=dict()):
             params["noise_levels"] = get_noise_levels(recording, return_scaled=False, **job_kwargs)
 
         templates_array, templates_array_std = estimate_templates_with_accumulator(
-            recording, spikes, unit_ids, nbefore, nafter, return_scaled=False, return_std=True, job_name=None, **job_kwargs
+            recording,
+            spikes,
+            unit_ids,
+            nbefore,
+            nafter,
+            return_scaled=False,
+            return_std=True,
+            job_name=None,
+            **job_kwargs,
         )
-
-        peak_snrs = np.abs(templates_array[:, nbefore, :])/templates_array_std[:, nbefore, :]
-        valid_templates = np.linalg.norm(peak_snrs, axis=1)/np.linalg.norm(params["noise_levels"])
-        valid_templates = valid_templates > params["noise_threshold"]
+        peak_snrs = np.abs(templates_array[:, nbefore, :]) / templates_array_std[:, nbefore, :]
+        valid_templates = np.linalg.norm(peak_snrs, axis=1) / np.linalg.norm(params["noise_levels"])
+        valid_templates = valid_templates > params["noise_threshold"]
 
         if d["rank"] is not None:
             from spikeinterface.sortingcomponents.matching.circus import compress_templates
 
@@ -254,12 +261,12 @@ def main_function(cls, recording, peaks, params, job_kwargs=dict()):
             probe=recording.get_probe(),
             is_scaled=False,
         )
-
+
         sparsity = compute_sparsity(templates, noise_levels=params["noise_levels"], **params["sparsity"])
         templates = templates.to_sparse(sparsity)
         empty_templates = templates.sparsity_mask.sum(axis=1) == 0
         templates = remove_empty_templates(templates)
-
+
         mask = np.isin(peak_labels, np.where(empty_templates)[0])
         peak_labels[mask] = -1
 
diff --git a/src/spikeinterface/sortingcomponents/clustering/random_projections.py b/src/spikeinterface/sortingcomponents/clustering/random_projections.py
index 34c6fd8b68..922387695e 100644
--- a/src/spikeinterface/sortingcomponents/clustering/random_projections.py
+++ b/src/spikeinterface/sortingcomponents/clustering/random_projections.py
@@ -53,7 +53,7 @@ class RandomProjectionClustering:
         "random_seed": 42,
         "noise_levels": None,
         "smoothing_kwargs": {"window_length_ms": 0.25},
-        "noise_threshold" : 1,
+        "noise_threshold": 1,
         "tmp_folder": None,
         "verbose": True,
     }
@@ -134,11 +134,19 @@ def main_function(cls, recording, peaks, params, job_kwargs=dict()):
             params["noise_levels"] = get_noise_levels(recording, return_scaled=False, **job_kwargs)
 
         templates_array, templates_array_std = estimate_templates_with_accumulator(
-            recording, spikes, unit_ids, nbefore, nafter, return_scaled=False, return_std=True, job_name=None, **job_kwargs
+            recording,
+            spikes,
+            unit_ids,
+            nbefore,
+            nafter,
+            return_scaled=False,
+            return_std=True,
+            job_name=None,
+            **job_kwargs,
         )
-
-        peak_snrs = np.abs(templates_array[:, nbefore, :])/templates_array_std[:, nbefore, :]
-        valid_templates = np.linalg.norm(peak_snrs, axis=1)/np.linalg.norm(params["noise_levels"])
+
+        peak_snrs = np.abs(templates_array[:, nbefore, :]) / templates_array_std[:, nbefore, :]
+        valid_templates = np.linalg.norm(peak_snrs, axis=1) / np.linalg.norm(params["noise_levels"])
         valid_templates = valid_templates > params["noise_threshold"]
 
         templates = Templates(
@@ -151,12 +159,12 @@ def main_function(cls, recording, peaks, params, job_kwargs=dict()):
             probe=recording.get_probe(),
             is_scaled=False,
         )
-
+
         sparsity = compute_sparsity(templates, noise_levels=params["noise_levels"], **params["sparsity"])
         templates = templates.to_sparse(sparsity)
         empty_templates = templates.sparsity_mask.sum(axis=1) == 0
         templates = remove_empty_templates(templates)
-
+
         mask = np.isin(peak_labels, np.where(empty_templates)[0])
         peak_labels[mask] = -1
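For reference, the clustering code reformatted above keeps only templates whose peak SNR, pooled across channels, exceeds the "noise_threshold" parameter. Below is a minimal standalone sketch of that check using NumPy only; the array shapes, the random data, and the threshold value are illustrative assumptions, not part of the diff.

import numpy as np

# Illustrative sketch only: shapes and values are made up for demonstration.
# templates_array / templates_array_std mimic the outputs of
# estimate_templates_with_accumulator(..., return_std=True):
# shape (num_units, num_samples, num_channels), with nbefore the peak sample.
rng = np.random.default_rng(42)
num_units, num_samples, num_channels = 4, 60, 8
nbefore = 30
noise_threshold = 1  # same default as the "noise_threshold" entry above

templates_array = rng.normal(size=(num_units, num_samples, num_channels))
templates_array_std = np.abs(rng.normal(size=(num_units, num_samples, num_channels))) + 1e-6
noise_levels = np.abs(rng.normal(size=num_channels)) + 1e-6

# Per-channel SNR at the peak sample, pooled per unit and compared to the
# pooled noise level; units failing the test would be dropped as noise.
peak_snrs = np.abs(templates_array[:, nbefore, :]) / templates_array_std[:, nbefore, :]
valid_templates = np.linalg.norm(peak_snrs, axis=1) / np.linalg.norm(noise_levels)
valid_templates = valid_templates > noise_threshold
print("valid templates:", valid_templates)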