Skip to content

Commit

Permalink
linting
Browse files Browse the repository at this point in the history
  • Loading branch information
Dmitry Razdoburdin committed Jun 27, 2024
1 parent 1a2f4d1 commit ab47428
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 20 deletions.
1 change: 1 addition & 0 deletions tests/ci_build/conda_env/linux_sycl_test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -19,3 +19,4 @@ dependencies:
- pytest-cov
- dpcpp_linux-64
- onedpl-devel
- dask
42 changes: 22 additions & 20 deletions tests/python-sycl/test_sycl_simple_dask.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import dask.array as da
import dask.distributed


def train_result(client, param, dtrain, num_rounds):
result = dxgb.train(
client,
Expand All @@ -17,24 +18,25 @@ def train_result(client, param, dtrain, num_rounds):
)
return result


class TestSYCLDask:
# The simplest test verify only one node training.
def test_simple(self):
cluster = dask.distributed.LocalCluster(n_workers=1)
client = dask.distributed.Client(cluster)

param={}
param["tree_method"] = "hist"
param["device"] = "sycl"
param["verbosity"] = 0
param["objective"] = "reg:squarederror"

# X and y must be Dask dataframes or arrays
num_obs = 1e4
num_features = 20
X = da.random.random(size=(num_obs, num_features), chunks=(1000, num_features))
y = da.random.random(size=(num_obs, 1), chunks=(1000, 1))
dtrain = dxgb.DaskDMatrix(client, X, y)

result = train_result(client, param, dtrain, 10)
assert tm.non_increasing(result["history"]["train"]["rmse"])
# The simplest test verify only one node training.
def test_simple(self):
cluster = dask.distributed.LocalCluster(n_workers=1)
client = dask.distributed.Client(cluster)

param = {}
param["tree_method"] = "hist"
param["device"] = "sycl"
param["verbosity"] = 0
param["objective"] = "reg:squarederror"

# X and y must be Dask dataframes or arrays
num_obs = 1e4
num_features = 20
X = da.random.random(size=(num_obs, num_features), chunks=(1000, num_features))
y = da.random.random(size=(num_obs, 1), chunks=(1000, 1))
dtrain = dxgb.DaskDMatrix(client, X, y)

result = train_result(client, param, dtrain, 10)
assert tm.non_increasing(result["history"]["train"]["rmse"])

0 comments on commit ab47428

Please sign in to comment.