Skip to content

Commit

Permalink
Support more sklearn tags for testing. (dmlc#10230)
Browse files Browse the repository at this point in the history
  • Loading branch information
trivialfis authored Apr 28, 2024
1 parent f8c3d22 commit 837d44a
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 15 deletions.
16 changes: 15 additions & 1 deletion python-package/xgboost/sklearn.py
Original file line number Diff line number Diff line change
Expand Up @@ -782,7 +782,10 @@ def __init__(

def _more_tags(self) -> Dict[str, bool]:
"""Tags used for scikit-learn data validation."""
return {"allow_nan": True, "no_validation": True}
tags = {"allow_nan": True, "no_validation": True}
if hasattr(self, "kwargs") and self.kwargs.get("updater") == "shotgun":
tags["non_deterministic"] = True
return tags

def __sklearn_is_fitted__(self) -> bool:
return hasattr(self, "_Booster")
Expand Down Expand Up @@ -1439,6 +1442,11 @@ def __init__(
) -> None:
super().__init__(objective=objective, **kwargs)

def _more_tags(self) -> Dict[str, bool]:
tags = super()._more_tags()
tags["multilabel"] = True
return tags

@_deprecate_positional_args
def fit(
self,
Expand Down Expand Up @@ -1717,6 +1725,12 @@ def __init__(
) -> None:
super().__init__(objective=objective, **kwargs)

def _more_tags(self) -> Dict[str, bool]:
tags = super()._more_tags()
tags["multioutput"] = True
tags["multioutput_only"] = False
return tags


@xgboost_model_doc(
"scikit-learn API for XGBoost random forest regression.",
Expand Down
36 changes: 22 additions & 14 deletions tests/python/test_with_sklearn.py
Original file line number Diff line number Diff line change
Expand Up @@ -1300,20 +1300,12 @@ def test_estimator_reg(estimator, check):
):
estimator.fit(X, y)
return
if (
os.environ["PYTEST_CURRENT_TEST"].find("check_estimators_overwrite_params")
!= -1
):
# A hack to pass the scikit-learn parameter mutation tests. XGBoost regressor
# returns actual internal default values for parameters in `get_params`, but
# those are set as `None` in sklearn interface to avoid duplication. So we fit
# a dummy model and obtain the default parameters here for the mutation tests.
from sklearn.datasets import make_regression

X, y = make_regression(n_samples=2, n_features=1)
estimator.set_params(**xgb.XGBRegressor().fit(X, y).get_params())

check(estimator)
elif os.environ["PYTEST_CURRENT_TEST"].find("check_regressor_multioutput") != -1:
# sklearn requires float64
with pytest.raises(AssertionError, match="Got float32"):
check(estimator)
else:
check(estimator)


def test_categorical():
Expand Down Expand Up @@ -1475,3 +1467,19 @@ def test_fit_none() -> None:

with pytest.raises(ValueError, match="labels"):
xgb.XGBRegressor().fit(X, None)


def test_tags() -> None:
for reg in [xgb.XGBRegressor(), xgb.XGBRFRegressor()]:
tags = reg._more_tags()
assert "non_deterministic" not in tags
assert tags["multioutput"] is True
assert tags["multioutput_only"] is False

for clf in [xgb.XGBClassifier()]:
tags = clf._more_tags()
assert "multioutput" not in tags
assert tags["multilabel"] is True

tags = xgb.XGBRanker()._more_tags()
assert "multioutput" not in tags

0 comments on commit 837d44a

Please sign in to comment.