Skip to content

Commit

Permalink
feat: Change repr method of EstimatorReport and `CrossValidationRep…
Browse files Browse the repository at this point in the history
…ort` (#1304)

This makes it more practical to print objects containing such reports,
e.g. `[report]`.

Closes #1293
  • Loading branch information
auguste-probabl authored Feb 10, 2025
1 parent c5d87b3 commit 789ee65
Show file tree
Hide file tree
Showing 4 changed files with 6 additions and 12 deletions.
6 changes: 2 additions & 4 deletions skore/src/skore/sklearn/_cross_validation/report.py
Original file line number Diff line number Diff line change
Expand Up @@ -323,7 +323,5 @@ def _get_help_legend(self):
)

def __repr__(self):
"""Return a string representation using rich."""
return self._rich_repr(
class_name="skore.CrossValidationReport", help_method_name="help()"
)
"""Return a string representation."""
return f"{self.__class__.__name__}(estimator={self.estimator_}, ...)"
6 changes: 2 additions & 4 deletions skore/src/skore/sklearn/_estimator/report.py
Original file line number Diff line number Diff line change
Expand Up @@ -326,7 +326,5 @@ def _get_help_legend(self):
)

def __repr__(self):
"""Return a string representation using rich."""
return self._rich_repr(
class_name="skore.EstimatorReport", help_method_name="help()"
)
"""Return a string representation."""
return f"{self.__class__.__name__}(estimator={self.estimator_}, ...)"
3 changes: 1 addition & 2 deletions skore/tests/unit/sklearn/test_cross_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,8 +169,7 @@ def test_cross_validation_report_repr(binary_classification_data):
report = CrossValidationReport(estimator, X, y)

repr_str = repr(report)
assert "skore.CrossValidationReport" in repr_str
assert "help()" in repr_str
assert "CrossValidationReport" in repr_str


@pytest.mark.parametrize(
Expand Down
3 changes: 1 addition & 2 deletions skore/tests/unit/sklearn/test_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,8 +291,7 @@ def test_estimator_report_repr(binary_classification_data):
report = EstimatorReport(estimator, X_test=X_test, y_test=y_test)

repr_str = repr(report)
assert "skore.EstimatorReport" in repr_str
assert "help()" in repr_str
assert "EstimatorReport" in repr_str


@pytest.mark.parametrize(
Expand Down

0 comments on commit 789ee65

Please sign in to comment.