Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(api): Allow to flatten index in reports #1300

Merged
merged 12 commits into from
Feb 20, 2025
21 changes: 16 additions & 5 deletions skore/src/skore/sklearn/_cross_validation/metrics_accessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
RocCurveDisplay,
)
from skore.utils._accessor import _check_supported_ml_task
from skore.utils._index import flatten_multi_index
from skore.utils._parallel import Parallel, delayed
from skore.utils._progress_bar import progress_decorator

Expand Down Expand Up @@ -46,9 +47,10 @@ def report_metrics(
data_source="test",
scoring=None,
scoring_names=None,
pos_label=None,
scoring_kwargs=None,
pos_label=None,
aggregate=None,
flat_index=False,
):
"""Report a set of metrics for our estimator.

Expand All @@ -73,15 +75,18 @@ def report_metrics(
Used to overwrite the default scoring names in the report. It should be of
the same length as the `scoring` parameter.

pos_label : int, float, bool or str, default=None
The positive class.

scoring_kwargs : dict, default=None
The keyword arguments to pass to the scoring functions.

pos_label : int, float, bool or str, default=None
The positive class.

aggregate : {"mean", "std"} or list of such str, default=None
Function to aggregate the scores across the cross-validation splits.

flat_index : bool, default=False
Whether to flatten the `MultiIndex` columns.

Returns
-------
pd.DataFrame
Expand All @@ -104,7 +109,7 @@ def report_metrics(
Precision (↗︎) 0.94... 0.024...
Recall (↗︎) 0.96... 0.027...
"""
return self._compute_metric_scores(
results = self._compute_metric_scores(
report_metric_name="report_metrics",
data_source=data_source,
aggregate=aggregate,
Expand All @@ -113,6 +118,12 @@ def report_metrics(
scoring_kwargs=scoring_kwargs,
scoring_names=scoring_names,
)
if flat_index:
if isinstance(results.columns, pd.MultiIndex):
results.columns = flatten_multi_index(results.columns)
if isinstance(results.index, pd.MultiIndex):
results.index = flatten_multi_index(results.index)
return results

@progress_decorator(description="Compute metric for each split")
def _compute_metric_scores(
Expand Down
19 changes: 15 additions & 4 deletions skore/src/skore/sklearn/_estimator/metrics_accessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
RocCurveDisplay,
)
from skore.utils._accessor import _check_supported_ml_task
from skore.utils._index import flatten_multi_index


class _MetricsAccessor(_BaseAccessor, DirNamesMixin):
Expand Down Expand Up @@ -48,8 +49,9 @@ def report_metrics(
y=None,
scoring=None,
scoring_names=None,
pos_label=None,
scoring_kwargs=None,
pos_label=None,
flat_index=False,
):
"""Report a set of metrics for our estimator.

Expand Down Expand Up @@ -83,11 +85,14 @@ def report_metrics(
Used to overwrite the default scoring names in the report. It should be of
the same length as the `scoring` parameter.

scoring_kwargs : dict, default=None
The keyword arguments to pass to the scoring functions.

pos_label : int, float, bool or str, default=None
The positive class.

scoring_kwargs : dict, default=None
The keyword arguments to pass to the scoring functions.
flat_index : bool, default=False
Whether to flatten the multiindex columns.

Returns
-------
Expand Down Expand Up @@ -326,7 +331,13 @@ def report_metrics(
names=name_index,
)

return pd.concat(scores, axis=0)
results = pd.concat(scores, axis=0)
if flat_index:
if isinstance(results.columns, pd.MultiIndex):
results.columns = flatten_multi_index(results.columns)
if isinstance(results.index, pd.MultiIndex):
results.index = flatten_multi_index(results.index)
return results

def _compute_metric_scores(
self,
Expand Down
32 changes: 32 additions & 0 deletions skore/src/skore/utils/_index.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
import pandas as pd


def flatten_multi_index(index: pd.MultiIndex) -> pd.Index:
"""Flatten a pandas MultiIndex into a single-level Index.

Flatten a pandas `MultiIndex` into a single-level Index by joining the levels
with underscores. Empty strings are skipped when joining.

Parameters
----------
index : pandas.MultiIndex
The `MultiIndex` to flatten.

Returns
-------
pandas.Index
A flattened `Index` with non-empty levels joined by underscores.

Examples
--------
>>> import pandas as pd
>>> mi = pd.MultiIndex.from_tuples(
... [('a', ''), ('b', '2')], names=['letter', 'number']
... )
>>> flatten_multi_index(mi)
Index(['a', 'b_2'], dtype='object')
"""
if not isinstance(index, pd.MultiIndex):
raise ValueError("`index` must be a MultiIndex.")

return pd.Index(["_".join(filter(bool, map(str, values))) for values in index])
25 changes: 25 additions & 0 deletions skore/tests/unit/sklearn/test_cross_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,31 @@ def test_cross_validation_report_pickle(tmp_path, binary_classification_data):
joblib.dump(report, tmp_path / "report.joblib")


def test_cross_validation_report_flat_index(binary_classification_data):
"""Check that the index is flattened when `flat_index` is True.

Since `pos_label` is None, then by default a MultiIndex would be returned.
Here, we force to have a single-index by passing `flat_index=True`.
"""
estimator, X, y = binary_classification_data
report = CrossValidationReport(estimator, X=X, y=y, cv_splitter=2)
result = report.metrics.report_metrics(flat_index=True)
assert result.shape == (6, 2)
assert isinstance(result.index, pd.Index)
assert result.index.tolist() == [
"Precision (↗︎)_0",
"Precision (↗︎)_1",
"Recall (↗︎)_0",
"Recall (↗︎)_1",
"ROC AUC (↗︎)",
"Brier score (↘︎)",
]
assert result.columns.tolist() == [
"RandomForestClassifier_Split #0",
"RandomForestClassifier_Split #1",
]


########################################################################################
# Check the plot methods
########################################################################################
Expand Down
22 changes: 22 additions & 0 deletions skore/tests/unit/sklearn/test_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -347,6 +347,28 @@ def test_estimator_report_pickle(tmp_path, binary_classification_data):
joblib.dump(report, tmp_path / "report.joblib")


def test_estimator_report_flat_index(binary_classification_data):
"""Check that the index is flattened when `flat_index` is True.

Since `pos_label` is None, then by default a MultiIndex would be returned.
Here, we force to have a single-index by passing `flat_index=True`.
"""
estimator, X_test, y_test = binary_classification_data
report = EstimatorReport(estimator, X_test=X_test, y_test=y_test)
result = report.metrics.report_metrics(flat_index=True)
assert result.shape == (6, 1)
assert isinstance(result.index, pd.Index)
assert result.index.tolist() == [
"Precision (↗︎)_0",
"Precision (↗︎)_1",
"Recall (↗︎)_0",
"Recall (↗︎)_1",
"ROC AUC (↗︎)",
"Brier score (↘︎)",
]
assert result.columns.tolist() == ["RandomForestClassifier"]


########################################################################################
# Check the plot methods
########################################################################################
Expand Down
45 changes: 45 additions & 0 deletions skore/tests/unit/utils/test_index.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
import pandas as pd
import pytest
from skore.utils._index import flatten_multi_index


@pytest.mark.parametrize(
"input_tuples, names, expected_values",
[
pytest.param(
[("a", 1), ("b", 2)], ["letter", "number"], ["a_1", "b_2"], id="basic"
),
pytest.param(
[("a", 1, "x"), ("b", 2, "y")],
["letter", "number", "symbol"],
["a_1_x", "b_2_y"],
id="multiple_levels",
),
pytest.param(
[("a", None), (None, 2)],
["letter", "number"],
["a_nan", "nan_2.0"],
id="none_values",
),
pytest.param(
[("a@b", "1#2"), ("c&d", "3$4")],
["letter", "number"],
["a@b_1#2", "c&d_3$4"],
id="special_chars",
),
pytest.param([], ["letter", "number"], [], id="empty"),
],
)
def test_flatten_multi_index(input_tuples, names, expected_values):
"""Test flatten_multi_index with various input cases."""
mi = pd.MultiIndex.from_tuples(input_tuples, names=names)
result = flatten_multi_index(mi)
expected = pd.Index(expected_values)
pd.testing.assert_index_equal(result, expected)


def test_flatten_multi_index_invalid_input():
"""Test that non-MultiIndex input raises ValueError."""
simple_index = pd.Index(["a", "b"])
with pytest.raises(ValueError, match="`index` must be a MultiIndex."):
flatten_multi_index(simple_index)