diff --git a/src/spikeinterface/preprocessing/tests/test_whiten.py b/src/spikeinterface/preprocessing/tests/test_whiten.py index 182950a173..feb6135cfa 100644 --- a/src/spikeinterface/preprocessing/tests/test_whiten.py +++ b/src/spikeinterface/preprocessing/tests/test_whiten.py @@ -475,20 +475,17 @@ def test_passed_W_and_M(self): def test_whiten_general(self, create_cache_folder): """ -<<<<<<< HEAD -======= Perform some general tests on the whitening functionality. First, perform smoke test that `compute_whitening_matrix` is running, check recording output datatypes are as expected. Check that saving preseves datatype, `int_scale` is propagated, and regularisation reduces the norm. ->>>>>>> 19105e25a (Finalise tests, tidy up and add docs.) """ cache_folder = create_cache_folder rec = generate_recording(num_channels=4, seed=2205) - random_chunk_kwargs = {} + random_chunk_kwargs = {"seed": 2205} W1, M = compute_whitening_matrix(rec, "global", random_chunk_kwargs, apply_mean=False, radius_um=None) with pytest.raises(AssertionError): @@ -501,28 +498,23 @@ def test_whiten_general(self, create_cache_folder): rec2 = whiten(rec) rec2.save(verbose=False) - # test dtype - rec_int = scale(rec2, dtype="int16") - rec3 = whiten(rec_int, dtype="float16") - rec3 = rec3.save(folder=cache_folder / "rec1") - assert rec3.get_dtype() == "float16" + # test dtype + rec_int = scale(rec2, dtype="int16") + rec3 = whiten(rec_int, dtype="float16") + rec3 = rec3.save(folder=cache_folder / "rec1") + assert rec3.get_dtype() == "float16" - # test parallel - rec_par = rec3.save(folder=cache_folder / "rec_par", n_jobs=2) - np.testing.assert_array_equal(rec3.get_traces(segment_index=0), rec_par.get_traces(segment_index=0)) - - with pytest.raises(AssertionError): - rec4 = whiten(rec_int, dtype=None) # int_scale should be applied - rec4 = whiten(rec_int, dtype=None, int_scale=256) - assert rec4.get_dtype() == "int16" - assert rec4._kwargs["M"] is None - - # test regularization : norm should be smaller - W2, M = compute_whitening_matrix(rec, "global", random_chunk_kwargs, apply_mean=False, regularize=True) - assert np.linalg.norm(W1) > np.linalg.norm(W2) - - # test regularization : norm should be smaller - if HAS_SKLEARN: - W2, M = compute_whitening_matrix(rec, "global", random_chunk_kwargs, apply_mean=False, regularize=True) - assert np.linalg.norm(W1) > np.linalg.norm(W2) + # test parallel + rec_par = rec3.save(folder=cache_folder / "rec_par", n_jobs=2) + np.testing.assert_array_equal(rec3.get_traces(segment_index=0), rec_par.get_traces(segment_index=0)) + with pytest.raises(AssertionError): + rec4 = whiten(rec_int, dtype=None) # int_scale should be applied + rec4 = whiten(rec_int, dtype=None, int_scale=256) + assert rec4.get_dtype() == "int16" + assert rec4._kwargs["M"] is None + + # test regularization : norm should be smaller + if HAS_SKLEARN: + W2, M = compute_whitening_matrix(rec, "global", random_chunk_kwargs, apply_mean=False, regularize=True) + assert np.linalg.norm(W1) > np.linalg.norm(W2) diff --git a/src/spikeinterface/preprocessing/whiten.py b/src/spikeinterface/preprocessing/whiten.py index c28b9e67c7..a40f2ac5a5 100644 --- a/src/spikeinterface/preprocessing/whiten.py +++ b/src/spikeinterface/preprocessing/whiten.py @@ -76,9 +76,9 @@ def __init__( dtype_ = fix_dtype(recording, dtype) if dtype_.kind == "i": - assert int_scale is not None, ( - "For recording with dtype=int you must set the output dtype to float OR set a int_scale" - ) + assert ( + int_scale is not None + ), "For recording with dtype=int you must set the output dtype to float OR set a int_scale" if not apply_mean and regularize: raise ValueError("`apply_mean` must be `True` if regularising. `assume_centered` is fixed to `True`.") @@ -273,7 +273,7 @@ def compute_sklearn_covariance_matrix(data, regularize_kwargs): import sklearn.covariance if "assume_centered" in regularize_kwargs and not regularize_kwargs["assume_centered"]: - raise ValueError("Cannot use `assume_centered=False` for `regularize_kwargs`. " "Fixing to `True`.") + raise ValueError("Cannot use `assume_centered=False` for `regularize_kwargs`. Fixing to `True`.") method = regularize_kwargs.pop("method") regularize_kwargs["assume_centered"] = True