feat(api): Allow to flatten index in reports #1300

Merged
12 commits merged on Feb 20, 2025
22 changes: 17 additions & 5 deletions skore/src/skore/sklearn/_comparison/metrics_accessor.py
@@ -7,6 +7,7 @@
from skore.externals._pandas_accessors import DirNamesMixin
from skore.sklearn._base import _BaseAccessor
from skore.utils._accessor import _check_supported_ml_task
from skore.utils._index import flatten_multi_index
from skore.utils._progress_bar import progress_decorator


@@ -42,9 +43,10 @@ def report_metrics(
y=None,
scoring=None,
scoring_names=None,
pos_label=None,
scoring_kwargs=None,
pos_label=None,
indicator_favorability=False,
flat_index=False,
):
"""Report a set of metrics for the estimators.

@@ -78,16 +80,20 @@ def report_metrics(
Used to overwrite the default scoring names in the report. It should be of
the same length as the ``scoring`` parameter.

pos_label : int, float, bool or str, default=None
The positive class.

scoring_kwargs : dict, default=None
The keyword arguments to pass to the scoring functions.

pos_label : int, float, bool or str, default=None
The positive class.

indicator_favorability : bool, default=False
Whether or not to add an indicator of the favorability of the metric as
an extra column in the returned DataFrame.

flat_index : bool, default=False
Whether to flatten the `MultiIndex` columns. The flat index is always
lowercase, contains no spaces, and has the hash symbol removed to ease indexing.

Returns
-------
pd.DataFrame
@@ -129,7 +135,7 @@ def report_metrics(
Precision 0.96... 0.96...
Recall 0.97... 0.97...
"""
return self._compute_metric_scores(
results = self._compute_metric_scores(
report_metric_name="report_metrics",
data_source=data_source,
X=X,
@@ -140,6 +146,12 @@
scoring_names=scoring_names,
indicator_favorability=indicator_favorability,
)
if flat_index:
if isinstance(results.columns, pd.MultiIndex):
results.columns = flatten_multi_index(results.columns)
if isinstance(results.index, pd.MultiIndex):
results.index = flatten_multi_index(results.index)
return results

@progress_decorator(description="Compute metric for each split")
def _compute_metric_scores(
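For readers skimming the diff, here is a minimal usage sketch of the new parameter on a comparison report, mirroring the unit test added below. The dataset, estimators, and the top-level `skore` import path are illustrative assumptions, not part of this diff.

```python
# Illustrative sketch only: dataset, estimators and import paths are assumed.
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split

from skore import ComparisonReport, EstimatorReport  # assumed top-level exports

X, y = make_classification(random_state=0)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

report_1 = EstimatorReport(
    LogisticRegression(), X_train=X_train, y_train=y_train, X_test=X_test, y_test=y_test
)
report_2 = EstimatorReport(
    LogisticRegression(C=0.1), X_train=X_train, y_train=y_train, X_test=X_test, y_test=y_test
)
comparison = ComparisonReport({"report_1": report_1, "report_2": report_2})

# With flat_index=True, MultiIndex rows such as ("Precision", 1) are flattened
# to "precision_1"; single-level axes are left untouched.
df = comparison.metrics.report_metrics(flat_index=True)
print(df.index.tolist())  # e.g. ['precision_0', 'precision_1', 'recall_0', ...]
```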
22 changes: 17 additions & 5 deletions skore/src/skore/sklearn/_cross_validation/metrics_accessor.py
@@ -12,6 +12,7 @@
RocCurveDisplay,
)
from skore.utils._accessor import _check_supported_ml_task
from skore.utils._index import flatten_multi_index
from skore.utils._parallel import Parallel, delayed
from skore.utils._progress_bar import progress_decorator

@@ -46,9 +47,10 @@ def report_metrics(
data_source="test",
scoring=None,
scoring_names=None,
pos_label=None,
scoring_kwargs=None,
pos_label=None,
indicator_favorability=False,
flat_index=False,
aggregate=None,
):
"""Report a set of metrics for our estimator.
@@ -74,16 +76,20 @@
Used to overwrite the default scoring names in the report. It should be of
the same length as the `scoring` parameter.

pos_label : int, float, bool or str, default=None
The positive class.

scoring_kwargs : dict, default=None
The keyword arguments to pass to the scoring functions.

pos_label : int, float, bool or str, default=None
The positive class.

indicator_favorability : bool, default=False
Whether or not to add an indicator of the favorability of the metric as
an extra column in the returned DataFrame.

flat_index : bool, default=False
Whether to flatten the `MultiIndex` columns. The flat index is always
lowercase, contains no spaces, and has the hash symbol removed to ease indexing.

aggregate : {"mean", "std"} or list of such str, default=None
Function to aggregate the scores across the cross-validation splits.

@@ -112,7 +118,7 @@ def report_metrics(
Precision 0.94... 0.02... (↗︎)
Recall 0.96... 0.02... (↗︎)
"""
return self._compute_metric_scores(
results = self._compute_metric_scores(
report_metric_name="report_metrics",
data_source=data_source,
aggregate=aggregate,
@@ -122,6 +128,12 @@
scoring_names=scoring_names,
indicator_favorability=indicator_favorability,
)
if flat_index:
if isinstance(results.columns, pd.MultiIndex):
results.columns = flatten_multi_index(results.columns)
if isinstance(results.index, pd.MultiIndex):
results.index = flatten_multi_index(results.index)
return results

@progress_decorator(description="Compute metric for each split")
def _compute_metric_scores(
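Similarly, a hedged sketch for the cross-validation accessor, consistent with the new unit test further down; the dataset and estimator are placeholders and the import path is assumed.

```python
# Illustrative sketch only: dataset and estimator are placeholders.
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier

from skore import CrossValidationReport  # assumed top-level export

X, y = make_classification(random_state=0)
report = CrossValidationReport(
    RandomForestClassifier(random_state=0), X=X, y=y, cv_splitter=2
)

# Columns are normally a MultiIndex of (estimator name, split); with
# flat_index=True they flatten to e.g. "randomforestclassifier_split_0".
df = report.metrics.report_metrics(flat_index=True)
print(df.columns.tolist())
```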
22 changes: 17 additions & 5 deletions skore/src/skore/sklearn/_estimator/metrics_accessor.py
@@ -16,6 +16,7 @@
RocCurveDisplay,
)
from skore.utils._accessor import _check_supported_ml_task
from skore.utils._index import flatten_multi_index


class _MetricsAccessor(_BaseAccessor, DirNamesMixin):
@@ -48,9 +49,10 @@ def report_metrics(
y=None,
scoring=None,
scoring_names=None,
pos_label=None,
scoring_kwargs=None,
pos_label=None,
indicator_favorability=False,
flat_index=False,
):
"""Report a set of metrics for our estimator.

@@ -84,16 +86,20 @@ def report_metrics(
Used to overwrite the default scoring names in the report. It should be of
the same length as the `scoring` parameter.

pos_label : int, float, bool or str, default=None
The positive class.

scoring_kwargs : dict, default=None
The keyword arguments to pass to the scoring functions.

pos_label : int, float, bool or str, default=None
The positive class.

indicator_favorability : bool, default=False
Whether or not to add an indicator of the favorability of the metric as
an extra column in the returned DataFrame.

flat_index : bool, default=False
Whether to flatten the `MultiIndex` columns. The flat index is always
lowercase, contains no spaces, and has the hash symbol removed to ease indexing.

Returns
-------
pd.DataFrame
@@ -339,7 +345,13 @@ def report_metrics(
names=name_index,
)

return pd.concat(scores, axis=0)
results = pd.concat(scores, axis=0)
if flat_index:
if isinstance(results.columns, pd.MultiIndex):
results.columns = flatten_multi_index(results.columns)
if isinstance(results.index, pd.MultiIndex):
results.index = flatten_multi_index(results.index)
return results

def _compute_metric_scores(
self,
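The estimator report behaves the same way; a short sketch (placeholder dataset and estimator, assumed top-level import) shows that only `MultiIndex` axes are flattened, matching the test added below where the single-level column keeps its original casing.

```python
# Illustrative sketch only: dataset and estimator are placeholders.
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split

from skore import EstimatorReport  # assumed top-level export

X, y = make_classification(random_state=0)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)
report = EstimatorReport(
    RandomForestClassifier(random_state=0),
    X_train=X_train, y_train=y_train, X_test=X_test, y_test=y_test,
)

# Rows such as ("Precision", 0) become "precision_0", while the single-level
# column index ("RandomForestClassifier") is left untouched.
df = report.metrics.report_metrics(flat_index=True)
print(df.index.tolist())
```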
41 changes: 41 additions & 0 deletions skore/src/skore/utils/_index.py
@@ -0,0 +1,41 @@
import pandas as pd


def flatten_multi_index(index: pd.MultiIndex) -> pd.Index:
"""Flatten a pandas MultiIndex into a single-level Index.

Flatten a pandas `MultiIndex` into a single-level `Index` by joining the levels
with underscores. Empty levels are skipped when joining, spaces are replaced by
underscores, "#" characters are removed, and the result is lowercased.

Parameters
----------
index : pandas.MultiIndex
The `MultiIndex` to flatten.

Returns
-------
pandas.Index
A flattened `Index` with non-empty levels joined by underscores.

Examples
--------
>>> import pandas as pd
>>> mi = pd.MultiIndex.from_tuples(
... [('a', ''), ('b', '2')], names=['letter', 'number']
... )
>>> flatten_multi_index(mi)
Index(['a', 'b_2'], dtype='object')
"""
if not isinstance(index, pd.MultiIndex):
raise ValueError("`index` must be a MultiIndex.")

return pd.Index(
[
"_".join(filter(bool, map(str, values)))
.replace(" ", "_")
.replace("#", "")
.lower()
for values in index
]
)
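A quick, self-contained illustration of the helper's normalisation rules (underscore joining, empty-level skipping, space replacement, "#" removal, lowercasing), consistent with the implementation above and the parametrised tests added below:

```python
import pandas as pd

from skore.utils._index import flatten_multi_index

mi = pd.MultiIndex.from_tuples(
    [("Hello World", "A#1"), ("MiXeD", "cAsE"), ("solo", "")],
    names=["text", "more"],
)
# Levels are joined with "_", empty levels are dropped, spaces become "_",
# "#" is removed, and the result is lowercased.
print(flatten_multi_index(mi))
# Index(['hello_world_a1', 'mixed_case', 'solo'], dtype='object')
```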
28 changes: 28 additions & 0 deletions skore/tests/unit/sklearn/test_comparison.py
@@ -536,6 +536,34 @@ def test_comparison_report_custom_metric_X_y(binary_classification_model):
pd.testing.assert_frame_equal(result, expected)


def test_comparison_report_flat_index(binary_classification_model):
"""Check that the index is flattened when `flat_index` is True.

Since `pos_label` is None, a `MultiIndex` would be returned by default.
Here, we force a single-level index by passing `flat_index=True`.
"""
estimator, X_train, X_test, y_train, y_test = binary_classification_model
report_1 = EstimatorReport(
estimator, X_train=X_train, y_train=y_train, X_test=X_test, y_test=y_test
)
report_2 = EstimatorReport(
estimator, X_train=X_train, y_train=y_train, X_test=X_test, y_test=y_test
)
report = ComparisonReport({"report_1": report_1, "report_2": report_2})
result = report.metrics.report_metrics(flat_index=True)
assert result.shape == (6, 2)
assert isinstance(result.index, pd.Index)
assert result.index.tolist() == [
"precision_0",
"precision_1",
"recall_0",
"recall_1",
"roc_auc",
"brier_score",
]
assert result.columns.tolist() == ["report_1", "report_2"]


def test_estimator_report_report_metrics_indicator_favorability(
binary_classification_model,
):
25 changes: 25 additions & 0 deletions skore/tests/unit/sklearn/test_cross_validation.py
@@ -219,6 +219,31 @@ def test_cross_validation_report_pickle(tmp_path, binary_classification_data):
joblib.dump(report, tmp_path / "report.joblib")


def test_cross_validation_report_flat_index(binary_classification_data):
"""Check that the index is flattened when `flat_index` is True.

Since `pos_label` is None, a `MultiIndex` would be returned by default.
Here, we force a single-level index by passing `flat_index=True`.
"""
estimator, X, y = binary_classification_data
report = CrossValidationReport(estimator, X=X, y=y, cv_splitter=2)
result = report.metrics.report_metrics(flat_index=True)
assert result.shape == (6, 2)
assert isinstance(result.index, pd.Index)
assert result.index.tolist() == [
"precision_0",
"precision_1",
"recall_0",
"recall_1",
"roc_auc",
"brier_score",
]
assert result.columns.tolist() == [
"randomforestclassifier_split_0",
"randomforestclassifier_split_1",
]


########################################################################################
# Check the plot methods
########################################################################################
22 changes: 22 additions & 0 deletions skore/tests/unit/sklearn/test_estimator.py
@@ -350,6 +350,28 @@ def test_estimator_report_pickle(binary_classification_data):
joblib.dump(report, stream)


def test_estimator_report_flat_index(binary_classification_data):
"""Check that the index is flattened when `flat_index` is True.

Since `pos_label` is None, a `MultiIndex` would be returned by default.
Here, we force a single-level index by passing `flat_index=True`.
"""
estimator, X_test, y_test = binary_classification_data
report = EstimatorReport(estimator, X_test=X_test, y_test=y_test)
result = report.metrics.report_metrics(flat_index=True)
assert result.shape == (6, 1)
assert isinstance(result.index, pd.Index)
assert result.index.tolist() == [
"precision_0",
"precision_1",
"recall_0",
"recall_1",
"roc_auc",
"brier_score",
]
assert result.columns.tolist() == ["RandomForestClassifier"]


########################################################################################
# Check the plot methods
########################################################################################
63 changes: 63 additions & 0 deletions skore/tests/unit/utils/test_index.py
@@ -0,0 +1,63 @@
import pandas as pd
import pytest
from skore.utils._index import flatten_multi_index


@pytest.mark.parametrize(
"input_tuples, names, expected_values",
[
pytest.param(
[("A", 1), ("B", 2)], ["letter", "number"], ["a_1", "b_2"], id="basic"
),
pytest.param(
[("A", 1, "X"), ("B", 2, "Y")],
["letter", "number", "symbol"],
["a_1_x", "b_2_y"],
id="multiple_levels",
),
pytest.param(
[("A", None), (None, 2)],
["letter", "number"],
["a_nan", "nan_2.0"],
id="none_values",
),
pytest.param(
[("A@B", "1#2"), ("C&D", "3$4")],
["letter", "number"],
["a@b_12", "c&d_3$4"],
id="special_chars",
),
pytest.param([], ["letter", "number"], [], id="empty"),
pytest.param(
[("Hello World", "A B"), ("Space Test", "X Y")],
["text", "more"],
["hello_world_a_b", "space_test_x_y"],
id="spaces",
),
pytest.param(
[("A#B#C", "1#2#3"), ("X#Y", "5#6")],
["text", "numbers"],
["abc_123", "xy_56"],
id="hash_symbols",
),
pytest.param(
[("UPPER", "CASE"), ("MiXeD", "cAsE")],
["text", "type"],
["upper_case", "mixed_case"],
id="case_sensitivity",
),
],
)
def test_flatten_multi_index(input_tuples, names, expected_values):
"""Test flatten_multi_index with various input cases."""
mi = pd.MultiIndex.from_tuples(input_tuples, names=names)
result = flatten_multi_index(mi)
expected = pd.Index(expected_values)
pd.testing.assert_index_equal(result, expected)


def test_flatten_multi_index_invalid_input():
"""Test that non-MultiIndex input raises ValueError."""
simple_index = pd.Index(["a", "b"])
with pytest.raises(ValueError, match="`index` must be a MultiIndex."):
flatten_multi_index(simple_index)