diff --git a/skore/src/skore/sklearn/find_ml_task.py b/skore/src/skore/sklearn/find_ml_task.py index d311219f7..bc07b5b10 100644 --- a/skore/src/skore/sklearn/find_ml_task.py +++ b/skore/src/skore/sklearn/find_ml_task.py @@ -15,6 +15,17 @@ def _is_sequential(y) -> bool: return np.array_equal(y_values, sequential) +def _is_classification(y) -> bool: + """Determine if `y` is a target for a classification task. + + If `y` contains integers, sklearn's `type_of_target` considers + the task to be multiclass classification. + This function makes the analysis finer. + """ + y = y.flatten() + return _is_sequential(y) and 0 in y + + def _find_ml_task(y, estimator=None) -> MLTask: """Guess the ML task being addressed based on a target array and an estimator. @@ -53,6 +64,26 @@ def _find_ml_task(y, estimator=None) -> MLTask: # Discrete sequential values, not containing 0 >>> _find_ml_task(numpy.array([1, 3, 2])) 'regression' + + # 2 values, containing 0, in a 2d array + >>> _find_ml_task(numpy.array([[0, 1], [1, 1]])) + 'multioutput-binary-classification' + + # Discrete sequential values, containing 0, in a 2d array + >>> _find_ml_task(numpy.array([[0, 1, 2], [2, 1, 1]])) + 'multioutput-multiclass-classification' + + # Discrete values, not sequential, in a 2d array + >>> _find_ml_task(numpy.array([[1, 5], [5, 9]])) + 'multioutput-regression' + + # Discrete values, not sequential, containing 0, in a 2d array + >>> _find_ml_task(numpy.array([[0, 1, 5, 9], [1, 0, 1, 1]])) + 'multioutput-regression' + + # Discrete sequential values, not containing 0, in a 2d array + >>> _find_ml_task(numpy.array([[1, 3, 2], [2, 1, 1]])) + 'multioutput-regression' """ if estimator is not None: # checking the estimator is more robust and faster than checking the type of @@ -71,38 +102,34 @@ def _find_ml_task(y, estimator=None) -> MLTask: return "binary-classification" if estimator.classes_.size > 2: return "multiclass-classification" - else: # fallback on the target + else: + # fallback on the target if y is None: return "unknown" - target_type = type_of_target(y) - if target_type == "binary": - return "binary-classification" - if target_type == "multiclass": - # If y is a vector of integers, type_of_target considers - # the task to be multiclass-classification. - # We refine this analysis a bit here. - if _is_sequential(y) and 0 in y: - return "multiclass-classification" - return "regression" - return "unknown" - return "unknown" - else: - if y is None: - # NOTE: The task might not be clustering - return "clustering" - - target_type = type_of_target(y) - - if target_type == "continuous": - return "regression" - if target_type == "binary": - return "binary-classification" - if target_type == "multiclass": - # If y is a vector of integers, type_of_target considers - # the task to be multiclass-classification. - # We refine this analysis a bit here. - if _is_sequential(y) and 0 in y: - return "multiclass-classification" - return "regression" - return "unknown" + # fallback on the target + if y is None: + # NOTE: The task might not be clustering + return "clustering" + + target_type = type_of_target(y) + + if target_type == "continuous": + return "regression" + if target_type == "continuous-multioutput": + return "multioutput-regression" + if target_type == "binary": + return "binary-classification" + if target_type == "multiclass": + if _is_classification(y): + return "multiclass-classification" + return "regression" + if target_type == "multiclass-multioutput": + if _is_classification(y): + return "multioutput-multiclass-classification" + return "multioutput-regression" + if target_type == "multilabel-indicator": + if _is_classification(y): + return "multioutput-binary-classification" + return "multioutput-regression" + return "unknown" diff --git a/skore/src/skore/sklearn/types.py b/skore/src/skore/sklearn/types.py index 3ac85af88..7c58a50a4 100644 --- a/skore/src/skore/sklearn/types.py +++ b/skore/src/skore/sklearn/types.py @@ -4,8 +4,11 @@ MLTask = Literal[ "binary-classification", + "clustering", "multiclass-classification", + "multioutput-binary-classification", + "multioutput-multiclass-classification", + "multioutput-regression", "regression", - "clustering", "unknown", ] diff --git a/skore/tests/unit/sklearn/test_utils.py b/skore/tests/unit/sklearn/test_utils.py index 08675c0c5..205a1d8ae 100644 --- a/skore/tests/unit/sklearn/test_utils.py +++ b/skore/tests/unit/sklearn/test_utils.py @@ -28,7 +28,7 @@ ( *make_multilabel_classification(random_state=42), MultiOutputClassifier(LogisticRegression()), - "unknown", + "multioutput-binary-classification", ), ], ) @@ -46,18 +46,29 @@ def test_find_ml_task_with_estimator(X, y, estimator, expected_task, should_fit) [ (make_classification(random_state=42)[1], "binary-classification"), ( - make_classification(n_classes=3, n_clusters_per_class=1, random_state=42)[ - 1 - ], + make_classification( + n_classes=3, + n_clusters_per_class=1, + random_state=42, + )[1], "multiclass-classification", ), (make_regression(n_samples=100, random_state=42)[1], "regression"), (None, "clustering"), - (make_multilabel_classification(random_state=42)[1], "unknown"), + ( + make_multilabel_classification(random_state=42)[1], + "multioutput-binary-classification", + ), (numpy.array([1, 5, 9]), "regression"), (numpy.array([0, 1, 2]), "multiclass-classification"), (numpy.array([1, 2, 3]), "regression"), (numpy.array([0, 1, 5, 9]), "regression"), + # Non-integer target + (numpy.array([[0.5, 2]]), "multioutput-regression"), + # No 0 class + (numpy.array([[1, 2], [2, 1]]), "multioutput-regression"), + # No 2 class + (numpy.array([[0, 3], [1, 3]]), "multioutput-regression"), ], ) def test_find_ml_task_without_estimator(target, expected_task):