Skip to content

Commit

Permalink
Fall back on the target in more cases
Browse files (browse the repository at this point in the history)
  • Loading branch information
auguste-probabl committed Feb 11, 2025
1 parent 65feca4 commit ba2168c
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 40 deletions.
67 changes: 28 additions & 39 deletions skore/src/skore/sklearn/find_ml_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,45 +102,34 @@ def _find_ml_task(y, estimator=None) -> MLTask:
return "binary-classification"
if estimator.classes_.size > 2:
return "multiclass-classification"
else: # fallback on the target
else:
# fallback on the target
if y is None:
return "unknown"

target_type = type_of_target(y)
if target_type == "binary":
return "binary-classification"
if target_type == "multiclass":
# If y is a vector of integers, type_of_target considers
# the task to be multiclass-classification.
# We refine this analysis a bit here.
if _is_sequential(y) and 0 in y:
return "multiclass-classification"
return "regression"
return "unknown"
return "unknown"
else:
if y is None:
# NOTE: The task might not be clustering
return "clustering"

target_type = type_of_target(y)

if target_type == "continuous":
return "regression"
if target_type == "continuous-multioutput":
return "multioutput-regression"
if target_type == "binary":
return "binary-classification"
if target_type == "multiclass":
if _is_classification(y):
return "multiclass-classification"
return "regression"
if target_type == "multiclass-multioutput":
if _is_classification(y):
return "multioutput-multiclass-classification"
return "multioutput-regression"
if target_type == "multilabel-indicator":
if _is_classification(y):
return "multioutput-binary-classification"
return "multioutput-regression"
return "unknown"
# fallback on the target
if y is None:
# NOTE: The task might not be clustering
return "clustering"

target_type = type_of_target(y)

if target_type == "continuous":
return "regression"
if target_type == "continuous-multioutput":
return "multioutput-regression"
if target_type == "binary":
return "binary-classification"
if target_type == "multiclass":
if _is_classification(y):
return "multiclass-classification"
return "regression"
if target_type == "multiclass-multioutput":
if _is_classification(y):
return "multioutput-multiclass-classification"
return "multioutput-regression"
if target_type == "multilabel-indicator":
if _is_classification(y):
return "multioutput-binary-classification"
return "multioutput-regression"
return "unknown"
2 changes: 1 addition & 1 deletion skore/tests/unit/sklearn/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
(
*make_multilabel_classification(random_state=42),
MultiOutputClassifier(LogisticRegression()),
"unknown",
"multioutput-binary-classification",
),
],
)
Expand Down

0 comments on commit ba2168c

Please sign in to comment.