Skip to content

Commit

Permalink
feat: Add a method .set_style to displays (#1336)
Browse files Browse the repository at this point in the history
closes #1273 

This PR is adding a method `.style` used to set the plotting style that
can be reused each time the `display.plot()`.
  • Loading branch information
glemaitre authored Feb 25, 2025
1 parent 0e4bc3b commit 0ffbe83
Show file tree
Hide file tree
Showing 8 changed files with 150 additions and 9 deletions.
10 changes: 9 additions & 1 deletion skore/src/skore/sklearn/_plot/precision_recall_curve.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from sklearn.metrics import average_precision_score, precision_recall_curve
from sklearn.preprocessing import LabelBinarizer

from skore.sklearn._plot.style import StyleDisplayMixin
from skore.sklearn._plot.utils import (
HelpDisplayMixin,
_ClassifierCurveDisplayMixin,
Expand All @@ -19,7 +20,9 @@
from skore.sklearn.types import MLTask


class PrecisionRecallCurveDisplay(HelpDisplayMixin, _ClassifierCurveDisplayMixin):
class PrecisionRecallCurveDisplay(
HelpDisplayMixin, _ClassifierCurveDisplayMixin, StyleDisplayMixin
):
"""Precision Recall visualization.
An instance of this class is should created by
Expand Down Expand Up @@ -106,6 +109,8 @@ class PrecisionRecallCurveDisplay(HelpDisplayMixin, _ClassifierCurveDisplayMixin
>>> display.plot(pr_curve_kwargs={"color": "tab:red"})
"""

_default_pr_curve_kwargs: Union[dict[str, Any], None] = None

def __init__(
self,
*,
Expand Down Expand Up @@ -187,6 +192,9 @@ def plot(
ax=ax, estimator_name=estimator_name
)

if pr_curve_kwargs is None:
pr_curve_kwargs = self._default_pr_curve_kwargs

self.lines_ = []
default_line_kwargs: dict[str, Any] = {}
if len(self.precision) == 1: # binary-classification
Expand Down
16 changes: 11 additions & 5 deletions skore/src/skore/sklearn/_plot/prediction_error.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from sklearn.utils.validation import _num_samples, check_array, check_random_state

from skore.externals._sklearn_compat import _safe_indexing
from skore.sklearn._plot.style import StyleDisplayMixin
from skore.sklearn._plot.utils import (
HelpDisplayMixin,
_despine_matplotlib_axis,
Expand All @@ -19,7 +20,7 @@
from skore.sklearn.types import MLTask


class PredictionErrorDisplay(HelpDisplayMixin):
class PredictionErrorDisplay(HelpDisplayMixin, StyleDisplayMixin):
"""Visualization of the prediction error of a regression model.
This tool can display "residuals vs predicted" or "actual vs predicted"
Expand Down Expand Up @@ -84,6 +85,9 @@ class PredictionErrorDisplay(HelpDisplayMixin):
>>> display.plot(kind="actual_vs_predicted")
"""

_default_scatter_kwargs: Union[dict[str, Any], None] = None
_default_line_kwargs: Union[dict[str, Any], None] = None

def __init__(
self,
*,
Expand Down Expand Up @@ -175,10 +179,12 @@ def plot(
else: # kind == "residual_vs_predicted"
xlabel, ylabel = "Predicted values", "Residuals (actual - predicted)"

if scatter_kwargs is None:
scatter_kwargs = {}
if line_kwargs is None:
line_kwargs = {}
scatter_kwargs = (
self._default_scatter_kwargs if scatter_kwargs is None else scatter_kwargs
) or {}
line_kwargs = (
self._default_line_kwargs if line_kwargs is None else line_kwargs
) or {}

if estimator_name is None:
estimator_name = self.estimator_name
Expand Down
15 changes: 13 additions & 2 deletions skore/src/skore/sklearn/_plot/roc_curve.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from sklearn.metrics import auc, roc_curve
from sklearn.preprocessing import LabelBinarizer

from skore.sklearn._plot.style import StyleDisplayMixin
from skore.sklearn._plot.utils import (
HelpDisplayMixin,
_ClassifierCurveDisplayMixin,
Expand All @@ -21,7 +22,9 @@
from skore.sklearn.types import MLTask


class RocCurveDisplay(HelpDisplayMixin, _ClassifierCurveDisplayMixin):
class RocCurveDisplay(
HelpDisplayMixin, _ClassifierCurveDisplayMixin, StyleDisplayMixin
):
"""ROC Curve visualization.
An instance of this class is should created by `EstimatorReport.metrics.roc()`.
Expand Down Expand Up @@ -109,6 +112,9 @@ class RocCurveDisplay(HelpDisplayMixin, _ClassifierCurveDisplayMixin):
>>> display.plot(roc_curve_kwargs={"color": "tab:red"})
"""

_default_roc_curve_kwargs: Union[dict[str, Any], None] = None
_default_chance_level_kwargs: Union[dict[str, Any], None] = None

def __init__(
self,
*,
Expand Down Expand Up @@ -188,7 +194,12 @@ def plot(
ax=ax, estimator_name=estimator_name
)

self.lines_ = []
if roc_curve_kwargs is None:
roc_curve_kwargs = self._default_roc_curve_kwargs
if chance_level_kwargs is None:
chance_level_kwargs = self._default_chance_level_kwargs

self.lines_: list[Line2D] = []
default_line_kwargs: dict[str, Any] = {}
if len(self.fpr) == 1: # binary-classification
assert (
Expand Down
51 changes: 51 additions & 0 deletions skore/src/skore/sklearn/_plot/style.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
from typing import Any


class StyleDisplayMixin:
"""Mixin to control the style plot of a display."""

@property
def _style_params(self) -> list[str]:
"""Get the list of available style parameters.
Returns
-------
list
List of style parameter names (without '_default_' prefix).
"""
prefix = "_default_"
suffix = "_kwargs"
return [
attr[len(prefix) :]
for attr in dir(self)
if attr.startswith(prefix) and attr.endswith(suffix)
]

def set_style(self, **kwargs: Any):
"""Set the style parameters for the display.
Parameters
----------
**kwargs : dict
Style parameters to set. Each parameter name should correspond to a
a style attribute passed to the plot method of the display.
Returns
-------
self : object
Returns the instance itself.
Raises
------
ValueError
If a style parameter is unknown.
"""
for param_name, param_value in kwargs.items():
default_attr = f"_default_{param_name}"
if not hasattr(self, default_attr):
raise ValueError(
f"Unknown style parameter: {param_name}. "
f"The parameter name should be one of {self._style_params}."
)
setattr(self, default_attr, param_value)
return self
14 changes: 13 additions & 1 deletion skore/tests/unit/sklearn/plot/test_precision_recall_curve.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,11 +252,23 @@ def test_precision_recall_curve_display_pr_curve_kwargs(
report = EstimatorReport(
estimator, X_train=X_train, y_train=y_train, X_test=X_test, y_test=y_test
)
display = report.metrics.precision_recall()
for pr_curve_kwargs in ({"color": "red"}, [{"color": "red"}]):
display = report.metrics.precision_recall()
display.plot(pr_curve_kwargs=pr_curve_kwargs)
assert display.lines_[0].get_color() == "red"

# check the `.style` display setter
display.plot() # default style
assert display.lines_[0].get_color() == "#1f77b4"
display.set_style(pr_curve_kwargs=pr_curve_kwargs)
display.plot()
assert display.lines_[0].get_color() == "red"
display.plot(pr_curve_kwargs=pr_curve_kwargs)
assert display.lines_[0].get_color() == "red"

# reset to default style since next call to `precision_recall` will use the
# cache
display.set_style(pr_curve_kwargs={"color": "#1f77b4"})

estimator, X_train, X_test, y_train, y_test = multiclass_classification_data
report = EstimatorReport(
Expand Down
21 changes: 21 additions & 0 deletions skore/tests/unit/sklearn/plot/test_prediction_error.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,27 @@ def test_prediction_error_display_kwargs(pyplot, regression_data):
np.testing.assert_allclose(display.scatter_.get_facecolor(), [[1, 0, 0, 0.3]])
assert display.line_.get_color() == "blue"

# check the `.style` display setter
display.plot() # default style
np.testing.assert_allclose(
display.scatter_.get_facecolor(),
[[0.121569, 0.466667, 0.705882, 0.3]],
rtol=1e-3,
)
assert display.line_.get_color() == "black"
display.set_style(scatter_kwargs={"color": "red"}, line_kwargs={"color": "blue"})
display.plot()
np.testing.assert_allclose(display.scatter_.get_facecolor(), [[1, 0, 0, 0.3]])
assert display.line_.get_color() == "blue"
# overwrite the style that was set above
display.plot(
scatter_kwargs={"color": "tab:orange"}, line_kwargs={"color": "tab:green"}
)
np.testing.assert_allclose(
display.scatter_.get_facecolor(), [[1.0, 0.498039, 0.054902, 0.3]], rtol=1e-3
)
assert display.line_.get_color() == "tab:green"

display.plot(despine=False)
assert display.ax_.spines["top"].get_visible()
assert display.ax_.spines["right"].get_visible()
Expand Down
17 changes: 17 additions & 0 deletions skore/tests/unit/sklearn/plot/test_roc_curve.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,23 @@ def test_roc_curve_display_roc_curve_kwargs_binary_classification(
assert display.lines_[0].get_color() == "red"
assert display.chance_level_.get_color() == "blue"

# check the `.style` display setter
display.plot() # default style
assert display.lines_[0].get_color() == "#1f77b4"
assert display.chance_level_.get_color() == "k"
display.set_style(
roc_curve_kwargs=roc_curve_kwargs, chance_level_kwargs={"color": "blue"}
)
display.plot()
assert display.lines_[0].get_color() == "red"
assert display.chance_level_.get_color() == "blue"
# overwrite the style that was set above
display.plot(
roc_curve_kwargs={"color": "#1f77b4"}, chance_level_kwargs={"color": "red"}
)
assert display.lines_[0].get_color() == "#1f77b4"
assert display.chance_level_.get_color() == "red"


def test_roc_curve_display_roc_curve_kwargs_multiclass_classification(
pyplot, multiclass_classification_data
Expand Down
15 changes: 15 additions & 0 deletions skore/tests/unit/sklearn/plot/test_style.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
import pytest
from skore.sklearn._plot.style import StyleDisplayMixin


class TestDisplay(StyleDisplayMixin):
_default_some_kwargs = None


def test_style_mixin():
display = TestDisplay()
display.set_style(some_kwargs=1)
assert display._default_some_kwargs == 1

with pytest.raises(ValueError, match="Unknown style parameter: unknown_param."):
display.set_style(unknown_param=1)

0 comments on commit 0ffbe83

Please sign in to comment.