diff --git a/skore/src/skore/sklearn/_estimator/feature_importance_accessor.py b/skore/src/skore/sklearn/_estimator/feature_importance_accessor.py index d779de198..0fa37ace0 100644 --- a/skore/src/skore/sklearn/_estimator/feature_importance_accessor.py +++ b/skore/src/skore/sklearn/_estimator/feature_importance_accessor.py @@ -27,6 +27,8 @@ # - a dictionary with metric names as keys and callables a values. Scoring = Union[str, Callable, Iterable[str], dict[str, Callable]] +Aggregation = Literal["mean", "std"] + class _FeatureImportanceAccessor(_BaseAccessor["EstimatorReport"], DirNamesMixin): """Accessor for feature importance related operations. @@ -111,6 +113,7 @@ def feature_permutation( data_source: DataSource = "test", X: Optional[ArrayLike] = None, y: Optional[ArrayLike] = None, + aggregate: Optional[Union[Aggregation, list[Aggregation]]] = None, scoring: Optional[Scoring] = None, n_repeats: int = 5, n_jobs: Optional[int] = None, @@ -147,6 +150,9 @@ def feature_permutation( New target on which to compute the metric. By default, we use the test target provided when creating the report. + aggregate : {"mean", "std"} or list of such str, default=None + Function(s) to aggregate scores across the repeats; if None, no aggregation. + scoring : str, callable, list, tuple, or dict, default=None The scorer to pass to :func:`~sklearn.inspection.permutation_importance`. @@ -211,12 +217,30 @@ def feature_permutation( Feature #7 0.023... 0.017... Feature #8 0.077... 0.077... Feature #9 0.011... 0.023... + >>> report.feature_importance.feature_permutation( + ... n_repeats=2, + ... aggregate=["mean", "std"], + ... random_state=0, + ... ) + mean std + Feature + Feature #0 0.001... 0.002... + Feature #1 0.009... 0.007... + Feature #2 0.128... 0.019... + Feature #3 0.074... 0.004... + Feature #4 0.000... 0.000... + Feature #5 -0.000... 0.002... + Feature #6 0.031... 0.002... + Feature #7 0.020... 0.004... + Feature #8 0.077... 0.000... + Feature #9 0.017... 0.008... 
""" return self._feature_permutation( data_source=data_source, data_source_hash=None, X=X, y=y, + aggregate=aggregate, scoring=scoring, n_repeats=n_repeats, n_jobs=n_jobs, @@ -231,6 +255,7 @@ def _feature_permutation( data_source_hash: Optional[int] = None, X: Optional[ArrayLike] = None, y: Optional[ArrayLike] = None, + aggregate: Optional[Union[Aggregation, list[Aggregation]]] = None, scoring: Optional[Scoring] = None, n_repeats: int = 5, n_jobs: Optional[int] = None, @@ -269,6 +294,9 @@ def _feature_permutation( else: cache_key_parts.append(scoring) + # aggregate is not included in the cache + # in order to trade off computation for storage + # order arguments by key to ensure cache works # n_jobs variable should not be in the cache kwargs = { @@ -334,6 +362,11 @@ def _feature_permutation( if cache_key is not None: self._parent._cache[cache_key] = score + if aggregate: + if isinstance(aggregate, str): + aggregate = [aggregate] + score = score.aggregate(func=aggregate, axis=1) + return score #################################################################################### diff --git a/skore/tests/unit/sklearn/estimator/feature_importance/test_permutation_importance.py b/skore/tests/unit/sklearn/estimator/feature_importance/test_permutation_importance.py index 8b84d71d3..39cb25f70 100644 --- a/skore/tests/unit/sklearn/estimator/feature_importance/test_permutation_importance.py +++ b/skore/tests/unit/sklearn/estimator/feature_importance/test_permutation_importance.py @@ -105,6 +105,19 @@ def case_X_y(): return data, kwargs, expected +def case_aggregate(): + data = regression_data() + + kwargs = {"data_source": "train", "aggregate": "mean", "random_state": 42} + + expected = pd.DataFrame( + data=np.zeros((3, 1)), + index=pd.Index((f"Feature #{i}" for i in range(3)), name="Feature"), + columns=pd.Index(["mean"]), + ) + return data, kwargs, expected + + def case_default_args_dataframe(): data = regression_data_dataframe() @@ -179,6 +192,7 @@ def 
case_several_scoring_dataframe(): case_r2_numpy, case_train_numpy, case_several_scoring_numpy, + case_aggregate, case_default_args_dataframe, case_r2_dataframe, case_train_dataframe,