Skip to content

Commit

Permalink
hold
Browse files Browse the repository at this point in the history
  • Loading branch information
JoeZiminski committed Nov 12, 2024
1 parent e85492f commit 602e6ce
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 31 deletions.
46 changes: 19 additions & 27 deletions src/spikeinterface/preprocessing/tests/test_whiten.py
Original file line number Diff line number Diff line change
Expand Up @@ -475,20 +475,17 @@ def test_passed_W_and_M(self):

def test_whiten_general(self, create_cache_folder):
"""
<<<<<<< HEAD
=======
Perform some general tests on the whitening functionality.
First, perform smoke test that `compute_whitening_matrix` is running,
check recording output datatypes are as expected. Check that
saving preseves datatype, `int_scale` is propagated, and
regularisation reduces the norm.
>>>>>>> 19105e25a (Finalise tests, tidy up and add docs.)
"""
cache_folder = create_cache_folder
rec = generate_recording(num_channels=4, seed=2205)

random_chunk_kwargs = {}
random_chunk_kwargs = {"seed": 2205}
W1, M = compute_whitening_matrix(rec, "global", random_chunk_kwargs, apply_mean=False, radius_um=None)

with pytest.raises(AssertionError):
Expand All @@ -501,28 +498,23 @@ def test_whiten_general(self, create_cache_folder):
rec2 = whiten(rec)
rec2.save(verbose=False)

# test dtype
rec_int = scale(rec2, dtype="int16")
rec3 = whiten(rec_int, dtype="float16")
rec3 = rec3.save(folder=cache_folder / "rec1")
assert rec3.get_dtype() == "float16"
# test dtype
rec_int = scale(rec2, dtype="int16")
rec3 = whiten(rec_int, dtype="float16")
rec3 = rec3.save(folder=cache_folder / "rec1")
assert rec3.get_dtype() == "float16"

# test parallel
rec_par = rec3.save(folder=cache_folder / "rec_par", n_jobs=2)
np.testing.assert_array_equal(rec3.get_traces(segment_index=0), rec_par.get_traces(segment_index=0))

with pytest.raises(AssertionError):
rec4 = whiten(rec_int, dtype=None) # int_scale should be applied
rec4 = whiten(rec_int, dtype=None, int_scale=256)
assert rec4.get_dtype() == "int16"
assert rec4._kwargs["M"] is None

# test regularization : norm should be smaller
W2, M = compute_whitening_matrix(rec, "global", random_chunk_kwargs, apply_mean=False, regularize=True)
assert np.linalg.norm(W1) > np.linalg.norm(W2)

# test regularization : norm should be smaller
if HAS_SKLEARN:
W2, M = compute_whitening_matrix(rec, "global", random_chunk_kwargs, apply_mean=False, regularize=True)
assert np.linalg.norm(W1) > np.linalg.norm(W2)
# test parallel
rec_par = rec3.save(folder=cache_folder / "rec_par", n_jobs=2)
np.testing.assert_array_equal(rec3.get_traces(segment_index=0), rec_par.get_traces(segment_index=0))

with pytest.raises(AssertionError):
rec4 = whiten(rec_int, dtype=None) # int_scale should be applied
rec4 = whiten(rec_int, dtype=None, int_scale=256)
assert rec4.get_dtype() == "int16"
assert rec4._kwargs["M"] is None

# test regularization : norm should be smaller
if HAS_SKLEARN:
W2, M = compute_whitening_matrix(rec, "global", random_chunk_kwargs, apply_mean=False, regularize=True)
assert np.linalg.norm(W1) > np.linalg.norm(W2)
8 changes: 4 additions & 4 deletions src/spikeinterface/preprocessing/whiten.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,9 +76,9 @@ def __init__(
dtype_ = fix_dtype(recording, dtype)

if dtype_.kind == "i":
assert int_scale is not None, (
"For recording with dtype=int you must set the output dtype to float OR set a int_scale"
)
assert (
int_scale is not None
), "For recording with dtype=int you must set the output dtype to float OR set a int_scale"

if not apply_mean and regularize:
raise ValueError("`apply_mean` must be `True` if regularising. `assume_centered` is fixed to `True`.")
Expand Down Expand Up @@ -273,7 +273,7 @@ def compute_sklearn_covariance_matrix(data, regularize_kwargs):
import sklearn.covariance

if "assume_centered" in regularize_kwargs and not regularize_kwargs["assume_centered"]:
raise ValueError("Cannot use `assume_centered=False` for `regularize_kwargs`. " "Fixing to `True`.")
raise ValueError("Cannot use `assume_centered=False` for `regularize_kwargs`. Fixing to `True`.")

method = regularize_kwargs.pop("method")
regularize_kwargs["assume_centered"] = True
Expand Down

0 comments on commit 602e6ce

Please sign in to comment.