diff --git a/skore/src/skore/sklearn/_estimator/feature_importance_accessor.py b/skore/src/skore/sklearn/_estimator/feature_importance_accessor.py index d44a68eb1..6c7961738 100644 --- a/skore/src/skore/sklearn/_estimator/feature_importance_accessor.py +++ b/skore/src/skore/sklearn/_estimator/feature_importance_accessor.py @@ -554,14 +554,17 @@ def _feature_permutation( n_repeats = data.shape[1] # Get score name - scoring_name = None if scoring is None: if is_classifier(self._parent.estimator_): scoring_name = "accuracy" elif is_regressor(self._parent.estimator_): scoring_name = "r2" - if isinstance(scoring, str): - scoring_name = scoring + else: + # e.g. if scoring is a callable + scoring_name = None + + # no other cases to deal with explicitly, because + # scoring cannot possibly be a string at this stage if scoring_name is None: index = pd.Index(feature_names, name="Feature") diff --git a/skore/tests/unit/sklearn/estimator/feature_importance/test_permutation_importance.py b/skore/tests/unit/sklearn/estimator/feature_importance/test_permutation_importance.py index 3fc4085ae..bfe9e9004 100644 --- a/skore/tests/unit/sklearn/estimator/feature_importance/test_permutation_importance.py +++ b/skore/tests/unit/sklearn/estimator/feature_importance/test_permutation_importance.py @@ -5,7 +5,7 @@ import pytest from sklearn.datasets import make_regression from sklearn.exceptions import NotFittedError -from sklearn.linear_model import LinearRegression +from sklearn.linear_model import LinearRegression, LogisticRegression from sklearn.metrics import make_scorer, r2_score, root_mean_squared_error from sklearn.model_selection import train_test_split from sklearn.pipeline import make_pipeline @@ -70,7 +70,7 @@ def case_r2_numpy(): def case_train_numpy(): data = regression_data() - kwargs = {"data_source": "train", "random_state": 42} + kwargs = {"data_source": "train", "scoring": "r2", "random_state": 42} return data, kwargs, multiindex_numpy, repeat_columns @@ -289,6 +289,26 @@ 
def test_cache_scoring_is_callable(regression_data, scoring): pd.testing.assert_frame_equal(cached_result, result) +def test_classification(classification_data): + """If `scoring` is None, classifiers default to the accuracy metric.""" + + X, y = classification_data + report = EstimatorReport(LogisticRegression(), X_train=X, y_train=y) + + result = report.feature_importance.feature_permutation( + data_source="train", random_state=42 + ) + + pd.testing.assert_index_equal( + result.index, + pd.MultiIndex.from_product( + [["accuracy"], (f"Feature #{i}" for i in range(5))], + names=("Metric", "Feature"), + ), + ) + pd.testing.assert_index_equal(result.columns, repeat_columns) + + def test_not_fitted(regression_data): """If estimator is not fitted, raise""" X, y = regression_data