From b198cdbdab9187af371e948763a76e4c66391bb9 Mon Sep 17 00:00:00 2001 From: Marie Date: Fri, 7 Mar 2025 09:11:46 +0100 Subject: [PATCH] fix test --- skore/tests/unit/sklearn/test_base.py | 35 ++++++++++++++++----------- 1 file changed, 21 insertions(+), 14 deletions(-) diff --git a/skore/tests/unit/sklearn/test_base.py b/skore/tests/unit/sklearn/test_base.py index 735278551..e19f91ff9 100644 --- a/skore/tests/unit/sklearn/test_base.py +++ b/skore/tests/unit/sklearn/test_base.py @@ -8,7 +8,7 @@ from sklearn.datasets import make_classification from sklearn.linear_model import LogisticRegression from sklearn.model_selection import train_test_split -from skore.sklearn._base import _BaseAccessor, _get_cached_response_values +from skore.sklearn._base import _BaseAccessor, _BaseReport, _get_cached_response_values class MockClassifier(ClassifierMixin, BaseEstimator): @@ -176,15 +176,16 @@ def test_get_cached_response_values_different_data_source_hash( ) -class MockReport: - def __init__(self, estimator, X_train=None, y_train=None, X_test=None, y_test=None): - """Mock a report with the minimal required attributes. +class MockReport(_BaseReport): + """Mock a report with the minimal required attributes. + + Attributes + ---------- + no_private : dummy object + The text to catch. + """ - Attributes - ---------- - no_private : dummy object - The text to catch. - """ + def __init__(self, estimator, X_train=None, y_train=None, X_test=None, y_test=None): self._estimator = estimator self._X_train = X_train self._y_train = y_train @@ -316,12 +317,18 @@ def test_base_accessor_get_X_y_and_data_source_hash(data_source): def test_base_accessor_get_attributes_description(): - _get_attributes_for_help = MockAccessor._get_attributes_for_help + X, y = make_classification(n_samples=10, n_classes=2, random_state=42) + X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=42) + estimator = LogisticRegression() + report = MockReport( + estimator, X_train=X_train, y_train=y_train, X_test=X_test, y_test=y_test + ) + + attributes = report._get_attributes_for_help() - # should only catch non private attributes - assert len(_get_attributes_for_help()) == 1 - assert MockAccessor._get_attribute_description("no_private") == "The text to catch" + assert len(attributes) == 7 + assert report._get_attribute_description("no_private") == "The text to catch" assert ( - MockAccessor._get_attribute_description("attr_without_description") + report._get_attribute_description("attr_without_description") == "No description available" )