Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(_find_ml_task): Refine ML task detection logic #1305

Merged
merged 11 commits into from
Feb 11, 2025
91 changes: 59 additions & 32 deletions skore/src/skore/sklearn/find_ml_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,17 @@ def _is_sequential(y) -> bool:
return np.array_equal(y_values, sequential)


def _is_classification(y) -> bool:
"""Determine if `y` is a target for a classification task.

If `y` contains integers, sklearn's `type_of_target` considers
the task to be multiclass classification.
This function makes the analysis finer.
"""
y = y.flatten()
return _is_sequential(y) and 0 in y


def _find_ml_task(y, estimator=None) -> MLTask:
"""Guess the ML task being addressed based on a target array and an estimator.

Expand Down Expand Up @@ -53,6 +64,26 @@ def _find_ml_task(y, estimator=None) -> MLTask:
# Discrete sequential values, not containing 0
>>> _find_ml_task(numpy.array([1, 3, 2]))
'regression'

# 2 values, containing 0, in a 2d array
>>> _find_ml_task(numpy.array([[0, 1], [1, 1]]))
'multioutput-binary-classification'

# Discrete sequential values, containing 0, in a 2d array
>>> _find_ml_task(numpy.array([[0, 1, 2], [2, 1, 1]]))
'multioutput-multiclass-classification'

# Discrete values, not sequential, in a 2d array
>>> _find_ml_task(numpy.array([[1, 5], [5, 9]]))
'multioutput-regression'

# Discrete values, not sequential, containing 0, in a 2d array
>>> _find_ml_task(numpy.array([[0, 1, 5, 9], [1, 0, 1, 1]]))
'multioutput-regression'

# Discrete sequential values, not containing 0, in a 2d array
>>> _find_ml_task(numpy.array([[1, 3, 2], [2, 1, 1]]))
'multioutput-regression'
"""
if estimator is not None:
# checking the estimator is more robust and faster than checking the type of
Expand All @@ -71,38 +102,34 @@ def _find_ml_task(y, estimator=None) -> MLTask:
return "binary-classification"
if estimator.classes_.size > 2:
return "multiclass-classification"
else: # fallback on the target
else:
# fallback on the target
if y is None:
return "unknown"

target_type = type_of_target(y)
if target_type == "binary":
return "binary-classification"
if target_type == "multiclass":
# If y is a vector of integers, type_of_target considers
# the task to be multiclass-classification.
# We refine this analysis a bit here.
if _is_sequential(y) and 0 in y:
return "multiclass-classification"
return "regression"
return "unknown"
return "unknown"
else:
if y is None:
# NOTE: The task might not be clustering
return "clustering"

target_type = type_of_target(y)

if target_type == "continuous":
return "regression"
if target_type == "binary":
return "binary-classification"
if target_type == "multiclass":
# If y is a vector of integers, type_of_target considers
# the task to be multiclass-classification.
# We refine this analysis a bit here.
if _is_sequential(y) and 0 in y:
return "multiclass-classification"
return "regression"
return "unknown"
# fallback on the target
if y is None:
# NOTE: The task might not be clustering
return "clustering"

target_type = type_of_target(y)

if target_type == "continuous":
return "regression"
if target_type == "continuous-multioutput":
return "multioutput-regression"
if target_type == "binary":
return "binary-classification"
if target_type == "multiclass":
if _is_classification(y):
return "multiclass-classification"
return "regression"
if target_type == "multiclass-multioutput":
if _is_classification(y):
return "multioutput-multiclass-classification"
return "multioutput-regression"
if target_type == "multilabel-indicator":
if _is_classification(y):
return "multioutput-binary-classification"
return "multioutput-regression"
return "unknown"
5 changes: 4 additions & 1 deletion skore/src/skore/sklearn/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,11 @@

MLTask = Literal[
"binary-classification",
"clustering",
"multiclass-classification",
"multioutput-binary-classification",
"multioutput-multiclass-classification",
"multioutput-regression",
"regression",
"clustering",
"unknown",
]
21 changes: 16 additions & 5 deletions skore/tests/unit/sklearn/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
(
*make_multilabel_classification(random_state=42),
MultiOutputClassifier(LogisticRegression()),
"unknown",
"multioutput-binary-classification",
),
],
)
Expand All @@ -46,18 +46,29 @@ def test_find_ml_task_with_estimator(X, y, estimator, expected_task, should_fit)
[
(make_classification(random_state=42)[1], "binary-classification"),
(
make_classification(n_classes=3, n_clusters_per_class=1, random_state=42)[
1
],
make_classification(
n_classes=3,
n_clusters_per_class=1,
random_state=42,
)[1],
"multiclass-classification",
),
(make_regression(n_samples=100, random_state=42)[1], "regression"),
(None, "clustering"),
(make_multilabel_classification(random_state=42)[1], "unknown"),
(
make_multilabel_classification(random_state=42)[1],
"multioutput-binary-classification",
),
(numpy.array([1, 5, 9]), "regression"),
(numpy.array([0, 1, 2]), "multiclass-classification"),
(numpy.array([1, 2, 3]), "regression"),
(numpy.array([0, 1, 5, 9]), "regression"),
# Non-integer target
(numpy.array([[0.5, 2]]), "multioutput-regression"),
# No 0 class
(numpy.array([[1, 2], [2, 1]]), "multioutput-regression"),
# No 2 class
(numpy.array([[0, 3], [1, 3]]), "multioutput-regression"),
],
)
def test_find_ml_task_without_estimator(target, expected_task):
Expand Down