NOTE(review): this patch had been collapsed onto three lines; line breaks and
indentation were reconstructed from the hunk headers. Indentation of context
lines is a best guess -- use `git apply --ignore-whitespace` if a hunk is
rejected. The `index` blob hashes are those of the originally submitted patch.

Changes relative to the originally submitted patch:
  * _get_best_new_selection: the per-candidate Python dict comprehension
    (`i not in self.selected_idx_` for every index) is replaced by a
    vectorised np.setdiff1d / np.argmax lookup. Ties still resolve to the
    lowest unselected index.
  * New tests: use self.assertEqual (matching the surrounding unittest style)
    instead of bare assert, and reuse n_samples / n_features instead of
    repeating the literals 10 / 15.

diff --git a/src/skmatter/_selection.py b/src/skmatter/_selection.py
index 95e43ed15..cd5f86569 100644
--- a/src/skmatter/_selection.py
+++ b/src/skmatter/_selection.py
@@ -236,8 +236,6 @@ def fit(self, X, y=None, warm_start=False):
         n_to_select_from = X.shape[self._axis]
         self.n_samples_in_, self.n_features_in_ = X.shape
 
-        self.n_samples_in_, self.n_features_in_ = X.shape
-
         error_msg = (
             "n_to_select must be either None, an "
             f"integer in [1, n_{self.selection_type}s] "
@@ -439,7 +437,9 @@ def _continue_greedy_search(self, X, y, n_to_select):
     def _get_best_new_selection(self, scorer, X, y):
         scores = scorer(X, y)
 
-        max_score_idx = np.argmax(scores)
+        # Get the score argmax, but only for idxs not already selected
+        unselected = np.setdiff1d(np.arange(len(scores)), self.selected_idx_)
+        max_score_idx = unselected[np.argmax(np.asarray(scores)[unselected])]
         if self.score_threshold is not None:
             if self.first_score_ is None:
                 self.first_score_ = scores[max_score_idx]
diff --git a/tests/test_feature_simple_cur.py b/tests/test_feature_simple_cur.py
index 72554471d..4e49d643b 100644
--- a/tests/test_feature_simple_cur.py
+++ b/tests/test_feature_simple_cur.py
@@ -46,6 +46,23 @@ def test_non_it(self):
 
         self.assertTrue(np.allclose(selector.selected_idx_, ref_idx))
 
+    def test_unique_selected_idx_zero_score(self):
+        """
+        Tests that the selected idxs are unique, which may not be the
+        case when the score is numerically zero.
+        """
+        np.random.seed(0)
+        n_samples = 10
+        n_features = 15
+        X = np.random.rand(n_samples, n_features)
+        X[:, 3] = np.random.rand(n_samples) * 1e-13
+        X[:, 4] = np.random.rand(n_samples) * 1e-13
+        selector_problem = CUR(n_to_select=n_features).fit(X)
+        self.assertEqual(
+            len(selector_problem.selected_idx_),
+            len(set(selector_problem.selected_idx_)),
+        )
+
 
 if __name__ == "__main__":
     unittest.main(verbosity=2)
diff --git a/tests/test_feature_simple_fps.py b/tests/test_feature_simple_fps.py
index b29a2bc7b..8d7d40304 100644
--- a/tests/test_feature_simple_fps.py
+++ b/tests/test_feature_simple_fps.py
@@ -1,5 +1,6 @@
 import unittest
 
+import numpy as np
 from sklearn.datasets import load_diabetes as get_dataset
 from sklearn.utils.validation import NotFittedError
 
@@ -62,6 +63,23 @@ def test_get_distances(self):
             selector = FPS(n_to_select=7)
             _ = selector.get_select_distance()
 
+    def test_unique_selected_idx_zero_score(self):
+        """
+        Tests that the selected idxs are unique, which may not be the
+        case when the score is numerically zero.
+        """
+        np.random.seed(0)
+        n_samples = 10
+        n_features = 15
+        X = np.random.rand(n_samples, n_features)
+        X[:, 3] = np.random.rand(n_samples) * 1e-13
+        X[:, 4] = np.random.rand(n_samples) * 1e-13
+        selector_problem = FPS(n_to_select=n_features).fit(X)
+        self.assertEqual(
+            len(selector_problem.selected_idx_),
+            len(set(selector_problem.selected_idx_)),
+        )
+
 
 if __name__ == "__main__":
     unittest.main(verbosity=2)
diff --git a/tests/test_sample_simple_cur.py b/tests/test_sample_simple_cur.py
index b3a9437e1..ab45dc825 100644
--- a/tests/test_sample_simple_cur.py
+++ b/tests/test_sample_simple_cur.py
@@ -58,6 +58,24 @@ def test_non_it(self):
 
         self.assertTrue(np.allclose(selector.selected_idx_, ref_idx))
 
+    def test_unique_selected_idx_zero_score(self):
+        """
+        Tests that the selected idxs are unique, which may not be the
+        case when the score is numerically zero.
+        """
+        np.random.seed(0)
+        n_samples = 10
+        n_features = 15
+        X = np.random.rand(n_samples, n_features)
+        X[4, :] = np.random.rand(n_features) * 1e-13
+        X[5, :] = np.random.rand(n_features) * 1e-13
+        X[6, :] = np.random.rand(n_features) * 1e-13
+        selector_problem = CUR(n_to_select=n_samples).fit(X)
+        self.assertEqual(
+            len(selector_problem.selected_idx_),
+            len(set(selector_problem.selected_idx_)),
+        )
+
 
 if __name__ == "__main__":
     unittest.main(verbosity=2)
diff --git a/tests/test_sample_simple_fps.py b/tests/test_sample_simple_fps.py
index ca7ee4bee..afd4d2111 100644
--- a/tests/test_sample_simple_fps.py
+++ b/tests/test_sample_simple_fps.py
@@ -1,5 +1,6 @@
 import unittest
 
+import numpy as np
 from sklearn.datasets import load_diabetes as get_dataset
 from sklearn.utils.validation import NotFittedError
 
@@ -81,6 +82,24 @@ def test_threshold(self):
         self.assertEqual(len(selector.selected_idx_), 5)
         self.assertEqual(selector.selected_idx_.tolist(), self.idx[:5])
 
+    def test_unique_selected_idx_zero_score(self):
+        """
+        Tests that the selected idxs are unique, which may not be the
+        case when the score is numerically zero.
+        """
+        np.random.seed(0)
+        n_samples = 10
+        n_features = 15
+        X = np.random.rand(n_samples, n_features)
+        X[4, :] = np.random.rand(n_features) * 1e-13
+        X[5, :] = np.random.rand(n_features) * 1e-13
+        X[6, :] = np.random.rand(n_features) * 1e-13
+        selector_problem = FPS(n_to_select=n_samples).fit(X)
+        self.assertEqual(
+            len(selector_problem.selected_idx_),
+            len(set(selector_problem.selected_idx_)),
+        )
+
 
 if __name__ == "__main__":
     unittest.main(verbosity=2)