diff --git a/tests/scores/test_self_scores.py b/tests/scores/test_self_scores.py index 0f460f0..a73e3fb 100644 --- a/tests/scores/test_self_scores.py +++ b/tests/scores/test_self_scores.py @@ -13,10 +13,14 @@ from kronfluence.utils.dataset import DataLoaderKwargs from tests.utils import ( ATOL, + DEFAULT_FACTORS_NAME, + DEFAULT_SCORES_NAME, RTOL, check_tensor_dict_equivalence, + custom_factors_name, + custom_scores_name, prepare_model_and_analyzer, - prepare_test, DEFAULT_FACTORS_NAME, DEFAULT_SCORES_NAME, custom_scores_name, custom_factors_name, + prepare_test, ) @@ -432,7 +436,12 @@ def test_compute_self_scores_with_indices( assert self_scores[ALL_MODULE_NAME].size(0) == 48 -@pytest.mark.parametrize("test_name",["mlp",],) +@pytest.mark.parametrize( + "test_name", + [ + "mlp", + ], +) @pytest.mark.parametrize("train_size", [60]) @pytest.mark.parametrize("seed", [1]) def test_compute_self_scores_with_diagonal_pairwise_equivalence(