Skip to content

Commit

Permalink
Add sycl python tests to public CI (#19)
Browse files Browse the repository at this point in the history
* initial

* install dpcpp with conda

* add env file for sycl

* fixing misprint

* fixing env file name

* rename device sycl:gpu->sycl in predictor tests

* set env variables to use g++ for xgboost compilation in sycl tests

* modify tests

* make predictor tests training on cpu

* fixes for predictor tests

* add zero buffer check to predictor

* add zero buffer size check to updater

* add zero buffer check to SetIndexData()

* fix error

* add more checks for zero buffers

* add zero buffer check to PredTransform

---------

Co-authored-by: Dmitry Razdoburdin <>
  • Loading branch information
razdoburdin authored Nov 20, 2023
1 parent 6aa6a39 commit 5046112
Show file tree
Hide file tree
Showing 11 changed files with 90 additions and 17 deletions.
45 changes: 45 additions & 0 deletions .github/workflows/python_tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -256,6 +256,51 @@ jobs:
run: |
pytest -s -v -rxXs --durations=0 ./tests/test_distributed/test_with_spark
python-sycl-tests-on-ubuntu:
name: Test XGBoost Python package with SYCL on ${{ matrix.config.os }}
runs-on: ${{ matrix.config.os }}
timeout-minutes: 90
strategy:
matrix:
config:
- {os: ubuntu-latest, python-version: "3.8"}

steps:
- uses: actions/checkout@v2
with:
submodules: 'true'

- uses: mamba-org/provision-with-micromamba@f347426e5745fe3dfc13ec5baf20496990d0281f # v14
with:
cache-downloads: true
cache-env: true
environment-name: linux_sycl_test
environment-file: tests/ci_build/conda_env/linux_sycl_test.yml

- name: Display Conda env
run: |
conda info
conda list
- name: Build XGBoost on Ubuntu
run: |
mkdir build
cd build
export CXX=g++
export CC=gcc
cmake .. -DPLUGIN_SYCL=ON -DCMAKE_PREFIX_PATH=$CONDA_PREFIX
make -j$(nproc)
- name: Install Python package
run: |
cd python-package
python --version
pip install -v .
- name: Test Python package
run: |
pytest -s -v -rxXs --durations=0 ./tests/python-sycl/
python-system-installation-on-ubuntu:
name: Test XGBoost Python package System Installation on ${{ matrix.os }}
runs-on: ${{ matrix.os }}
Expand Down
1 change: 1 addition & 0 deletions plugin/sycl/common/hist_util.cc
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ void GHistIndexMatrix::SetIndexData(::sycl::queue qu,
size_t nbins,
size_t row_stride,
uint32_t* offsets) {
if (hit_count.size() == 0) return;
const xgboost::Entry *data_ptr = dmat_device.data.DataConst();
const bst_row_t *offset_vec = dmat_device.row_ptr.DataConst();
const size_t num_rows = dmat_device.row_ptr.Size() - 1;
Expand Down
7 changes: 4 additions & 3 deletions plugin/sycl/objective/multiclass_obj.cc
Original file line number Diff line number Diff line change
Expand Up @@ -102,9 +102,9 @@ class SoftmaxMultiClassObj : public ObjFunction {
const MetaInfo& info,
int iter,
HostDeviceVector<GradientPair>* out_gpair) override {
if (info.labels.Size() == 0) {
return;
}
if (preds.Size() == 0) return;
if (info.labels.Size() == 0) return;

CHECK(preds.Size() == (static_cast<size_t>(param_.num_class) * info.labels.Size()))
<< "SoftmaxMultiClassObj: label size and pred size does not match.\n"
<< "label.Size() * num_class: "
Expand Down Expand Up @@ -187,6 +187,7 @@ class SoftmaxMultiClassObj : public ObjFunction {


inline void Transform(HostDeviceVector<bst_float> *io_preds, bool prob) const {
if (io_preds->Size() == 0) return;
const int nclass = param_.num_class;
const auto ndata = static_cast<int64_t>(io_preds->Size() / nclass);
max_preds_.Resize(ndata);
Expand Down
5 changes: 2 additions & 3 deletions plugin/sycl/objective/regression_obj.cc
Original file line number Diff line number Diff line change
Expand Up @@ -53,9 +53,7 @@ class RegLossObj : public ObjFunction {
const MetaInfo &info,
int iter,
HostDeviceVector<GradientPair>* out_gpair) override {
if (info.labels.Size() == 0U) {
LOG(WARNING) << "Label set is empty.";
}
if (info.labels.Size() == 0) return;
CHECK_EQ(preds.Size(), info.labels.Size())
<< " " << "labels are not correctly provided"
<< "preds.size=" << preds.Size() << ", label.size=" << info.labels.Size() << ", "
Expand Down Expand Up @@ -125,6 +123,7 @@ class RegLossObj : public ObjFunction {

void PredTransform(HostDeviceVector<float> *io_preds) const override {
size_t const ndata = io_preds->Size();
if (ndata == 0) return;
::sycl::buffer<bst_float, 1> io_preds_buf(io_preds->HostPointer(), io_preds->Size());

qu_.submit([&](::sycl::handler& cgh) {
Expand Down
6 changes: 3 additions & 3 deletions plugin/sycl/predictor/predictor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -184,9 +184,9 @@ void DevicePredictInternal(::sycl::queue qu,
const gbm::GBTreeModel& model,
size_t tree_begin,
size_t tree_end) {
if (tree_end - tree_begin == 0) {
return;
}
if (tree_end - tree_begin == 0) return;
if (out_preds->HostVector().size() == 0) return;

DeviceModel device_model;
device_model.Init(qu, model, tree_begin, tree_end);

Expand Down
1 change: 1 addition & 0 deletions plugin/sycl/tree/updater_quantile_hist.cc
Original file line number Diff line number Diff line change
Expand Up @@ -627,6 +627,7 @@ bool QuantileHistMaker::Builder<GradientSumT>::UpdatePredictionCache(

const size_t stride = out_preds.Stride(0);
const int buffer_size = out_preds.Size()*stride - stride + 1;
if (buffer_size == 0) return true;
::sycl::buffer<float, 1> out_preds_buf(&out_preds(0), buffer_size);

size_t n_nodes = row_set_collection_.Size();
Expand Down
20 changes: 20 additions & 0 deletions tests/ci_build/conda_env/linux_sycl_test.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
# Conda environment for the SYCL Python test job; the workflow references this
# file as tests/ci_build/conda_env/linux_sycl_test.yml.
name: linux_sycl_test
channels:
- conda-forge
- intel
dependencies:
# Python version matches the "3.8" entry in the CI job matrix.
- python=3.8
# Build tooling: the CI job runs cmake with -DPLUGIN_SYCL=ON and then make.
- cmake
- c-compiler
- cxx-compiler
- pip
- wheel
# Python packages available to the test run.
- numpy
- scipy
- scikit-learn
- pandas
- hypothesis>=6.46
# Test runner and plugins; the CI job invokes pytest on ./tests/python-sycl/.
- pytest
- pytest-timeout
- pytest-cov
# DPC++ toolchain for the SYCL plugin build.
# NOTE(review): presumably resolved from the `intel` channel — confirm.
- dpcpp_linux-64
16 changes: 11 additions & 5 deletions tests/python-sycl/test_sycl_prediction.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def test_predict(self):
cpu_pred_test = bst.predict(dtest, output_margin=True)
cpu_pred_val = bst.predict(dval, output_margin=True)

bst.set_param({"device": "sycl:gpu"})
bst.set_param({"device": "sycl"})
sycl_pred_train = bst.predict(dtrain, output_margin=True)
sycl_pred_test = bst.predict(dtest, output_margin=True)
sycl_pred_val = bst.predict(dval, output_margin=True)
Expand Down Expand Up @@ -82,7 +82,7 @@ def test_multi_predict(self):
bst = xgb.train(params, dtrain)
cpu_predict = bst.predict(dtest)

bst.set_param({"device": "sycl:gpu"})
bst.set_param({"device": "sycl"})

predict0 = bst.predict(dtest)
predict1 = bst.predict(dtest)
Expand Down Expand Up @@ -110,7 +110,7 @@ def test_sklearn(self):
cpu_test_score = m.score(X_test, y_test)

# Now with sycl_predictor
params['device'] = 'sycl:gpu'
params['device'] = 'sycl'
m.set_params(**params)

# m = xgb.XGBRegressor(**params).fit(X_train, y_train)
Expand All @@ -125,11 +125,14 @@ def test_sklearn(self):
tm.make_dataset_strategy(), shap_parameter_strategy)
@settings(deadline=None)
def test_shap(self, num_rounds, dataset, param):
param.update({"device": "sycl:gpu"})
if dataset.name.endswith("-l1"): # not supported by the exact tree method
return
param.update({"tree_method": "hist", "device": "cpu"})
param = dataset.set_params(param)
dmat = dataset.get_dmat()
bst = xgb.train(param, dmat, num_rounds)
test_dmat = xgb.DMatrix(dataset.X, dataset.y, dataset.w, dataset.margin)
bst.set_param({"device": "sycl"})
shap = bst.predict(test_dmat, pred_contribs=True)
margin = bst.predict(test_dmat, output_margin=True)
assume(len(dataset.y) > 0)
Expand All @@ -139,11 +142,14 @@ def test_shap(self, num_rounds, dataset, param):
tm.make_dataset_strategy(), shap_parameter_strategy)
@settings(deadline=None, max_examples=20)
def test_shap_interactions(self, num_rounds, dataset, param):
param.update({"device": "sycl:gpu"})
if dataset.name.endswith("-l1"): # not supported by the exact tree method
return
param.update({"tree_method": "hist", "device": "cpu"})
param = dataset.set_params(param)
dmat = dataset.get_dmat()
bst = xgb.train(param, dmat, num_rounds)
test_dmat = xgb.DMatrix(dataset.X, dataset.y, dataset.w, dataset.margin)
bst.set_param({"device": "sycl"})
shap = bst.predict(test_dmat, pred_interactions=True)
margin = bst.predict(test_dmat, output_margin=True)
assume(len(dataset.y) > 0)
Expand Down
2 changes: 1 addition & 1 deletion tests/python-sycl/test_sycl_training_continuation.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ def run_training_continuation(self, use_json):
X = np.random.randn(kRows, kCols)
y = np.random.randn(kRows)
dtrain = xgb.DMatrix(X, y)
params = {'device': 'sycl:gpu', 'max_depth': '2',
params = {'device': 'sycl', 'max_depth': '2',
'gamma': '0.1', 'alpha': '0.01',
'enable_experimental_json_serialization': use_json}
bst_0 = xgb.train(params, dtrain, num_boost_round=64)
Expand Down
2 changes: 1 addition & 1 deletion tests/python-sycl/test_sycl_updaters.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ class TestSYCLUpdaters:
@settings(deadline=None)
def test_sycl_hist(self, param, num_rounds, dataset):
param['tree_method'] = 'hist'
param['device'] = 'sycl:gpu'
param['device'] = 'sycl'
param['verbosity'] = 0
param = dataset.set_params(param)
result = train_result(param, dataset.get_dmat(), num_rounds)
Expand Down
2 changes: 1 addition & 1 deletion tests/python-sycl/test_sycl_with_sklearn.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def test_sycl_binary_classification():
for cls in (xgb.XGBClassifier, xgb.XGBRFClassifier):
for train_index, test_index in kf.split(X, y):
xgb_model = cls(
random_state=42, device='sycl:gpu',
random_state=42, device='sycl',
n_estimators=4).fit(X[train_index], y[train_index])
preds = xgb_model.predict(X[test_index])
labels = y[test_index]
Expand Down

0 comments on commit 5046112

Please sign in to comment.