sklearn fix and tests
Dmitry Razdoburdin committed Sep 3, 2024
1 parent f52f11e commit 30a8a93
Showing 3 changed files with 43 additions and 4 deletions.
8 changes: 5 additions & 3 deletions python-package/xgboost/sklearn.py
@@ -67,8 +67,10 @@ def _check_rf_callback(
)


-def _can_use_qdm(tree_method: Optional[str]) -> bool:
-    return tree_method in ("hist", "gpu_hist", None, "auto")
+def _can_use_qdm(tree_method: Optional[str],
+                 device: Optional[str]) -> bool:
+    is_sycl = device is not None and device.startswith("sycl")
+    return tree_method in ("hist", "gpu_hist", None, "auto") and not is_sycl


class _SklObjWProto(Protocol): # pylint: disable=too-few-public-methods
@@ -1031,7 +1033,7 @@ def _duplicated(parameter: str) -> None:

    def _create_dmatrix(self, ref: Optional[DMatrix], **kwargs: Any) -> DMatrix:
        # Use `QuantileDMatrix` to save memory.
-        if _can_use_qdm(self.tree_method) and self.booster != "gblinear":
+        if _can_use_qdm(self.tree_method, self.device) and self.booster != "gblinear":
            try:
                return QuantileDMatrix(
                    **kwargs, ref=ref, nthread=self.n_jobs, max_bin=self.max_bin
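For orientation, here is a minimal standalone sketch of the behaviour the patched helper is aiming for. It is illustrative only: it mirrors the private _can_use_qdm above rather than importing it, and the sample arguments are assumptions, not part of the commit.

def _can_use_qdm_sketch(tree_method, device):
    # QuantileDMatrix is only an option for hist-style tree methods,
    # and never when a SYCL device is selected.
    is_sycl = device is not None and device.startswith("sycl")
    return tree_method in ("hist", "gpu_hist", None, "auto") and not is_sycl

assert _can_use_qdm_sketch("hist", None)            # default device keeps the QuantileDMatrix path
assert not _can_use_qdm_sketch("hist", "sycl")      # SYCL falls back to a regular DMatrix
assert not _can_use_qdm_sketch("hist", "sycl:gpu")  # any device string starting with "sycl"
assert not _can_use_qdm_sketch("exact", None)       # exact never used QuantileDMatrix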
37 changes: 37 additions & 0 deletions tests/python-sycl/test_sycl_with_sklearn.py
@@ -0,0 +1,37 @@
import xgboost as xgb
import pytest
import sys
import numpy as np

from xgboost import testing as tm

sys.path.append("tests/python")
import test_with_sklearn as twskl # noqa

pytestmark = pytest.mark.skipif(**tm.no_sklearn())

rng = np.random.RandomState(1994)


def test_sycl_binary_classification():
    from sklearn.datasets import load_digits
    from sklearn.model_selection import KFold

    digits = load_digits(n_class=2)
    y = digits["target"]
    X = digits["data"]
    kf = KFold(n_splits=2, shuffle=True, random_state=rng)
    for cls in (xgb.XGBClassifier, xgb.XGBRFClassifier):
        for train_index, test_index in kf.split(X, y):
            xgb_model = cls(random_state=42, device="sycl", n_estimators=4).fit(
                X[train_index], y[train_index]
            )
            preds = xgb_model.predict(X[test_index])
            labels = y[test_index]
            err = sum(
                1 for i in range(len(preds)) if int(preds[i] > 0.5) != labels[i]
            ) / float(len(preds))
            print(preds)
            print(labels)
            print(err)
            assert err < 0.1
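As a usage note (an assumption about the local setup, not something stated in the commit): on a machine with a SYCL device and scikit-learn installed, the new suite would typically be run from the repository root with

    pytest tests/python-sycl/test_sycl_with_sklearn.py -v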
