Skip to content

Commit

Permalink
increase coverage
Browse files Browse the repository at this point in the history
  • Loading branch information
auguste-probabl committed Mar 5, 2025
1 parent 1af7373 commit f9217ad
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -554,14 +554,17 @@ def _feature_permutation(
n_repeats = data.shape[1]

# Get score name
scoring_name = None
if scoring is None:
if is_classifier(self._parent.estimator_):
scoring_name = "accuracy"
elif is_regressor(self._parent.estimator_):
scoring_name = "r2"
if isinstance(scoring, str):
scoring_name = scoring
else:
# e.g. if scoring is a callable
scoring_name = None

# no other cases to deal with explicitly, because
# scoring cannot possibly be a string at this stage

if scoring_name is None:
index = pd.Index(feature_names, name="Feature")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import pytest
from sklearn.datasets import make_regression
from sklearn.exceptions import NotFittedError
from sklearn.linear_model import LinearRegression
from sklearn.linear_model import LinearRegression, LogisticRegression
from sklearn.metrics import make_scorer, r2_score, root_mean_squared_error
from sklearn.model_selection import train_test_split
from sklearn.pipeline import make_pipeline
Expand Down Expand Up @@ -70,7 +70,7 @@ def case_r2_numpy():
def case_train_numpy():
data = regression_data()

kwargs = {"data_source": "train", "random_state": 42}
kwargs = {"data_source": "train", "scoring": "r2", "random_state": 42}

return data, kwargs, multiindex_numpy, repeat_columns

Expand Down Expand Up @@ -289,6 +289,26 @@ def test_cache_scoring_is_callable(regression_data, scoring):
pd.testing.assert_frame_equal(cached_result, result)


def test_classification(classification_data):
"""If `scoring` is a callable then the result is cached properly."""

X, y = classification_data
report = EstimatorReport(LogisticRegression(), X_train=X, y_train=y)

result = report.feature_importance.feature_permutation(
data_source="train", random_state=42
)

pd.testing.assert_index_equal(
result.index,
pd.MultiIndex.from_product(
[["accuracy"], (f"Feature #{i}" for i in range(5))],
names=("Metric", "Feature"),
),
)
pd.testing.assert_index_equal(result.columns, repeat_columns)


def test_not_fitted(regression_data):
"""If estimator is not fitted, raise"""
X, y = regression_data
Expand Down

0 comments on commit f9217ad

Please sign in to comment.