Skip to content

Commit

Permalink
Merge branch 'uxlfoundation:main' into dev/sklearnex_assert_all_finite
Browse files Browse the repository at this point in the history
  • Loading branch information
icfaust authored Dec 3, 2024
2 parents 1db7575 + 675a2da commit 8fca003
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 19 deletions.
3 changes: 3 additions & 0 deletions deselected_tests.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -301,6 +301,9 @@ deselected_tests:
# Fails in stock scikit-learn: checks that data is modified in-place when not strictly required
- linear_model/tests/test_base.py::test_inplace_data_preprocessing

# Failure occurs in python3.9 on windows CPU only - not easy to reproduce
- ensemble/tests/test_weight_boosting.py::test_estimator >= 1.4 win32

# --------------------------------------------------------
# No need to test daal4py patching
reduced_tests:
Expand Down
39 changes: 22 additions & 17 deletions sklearnex/ensemble/tests/test_forest.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,8 +63,8 @@ def test_sklearnex_import_rf_classifier(dataframe, queue, block, trees, rows, sc

@pytest.mark.parametrize("dataframe,queue", get_dataframes_and_queues())
def test_sklearnex_import_rf_regression(dataframe, queue):
if queue and queue.sycl_device.is_gpu:
pytest.skip("RF regressor predict for the GPU sycl_queue is buggy.")
if (not daal_check_version((2025, "P", 200))) and queue and queue.sycl_device.is_gpu:
pytest.skip("Skipping due to bug in histogram merges fixed in 2025.2.")
from sklearnex.ensemble import RandomForestRegressor

X, y = make_regression(n_features=4, n_informative=2, random_state=0, shuffle=False)
Expand All @@ -74,19 +74,20 @@ def test_sklearnex_import_rf_regression(dataframe, queue):
assert "sklearnex" in rf.__module__
pred = _as_numpy(rf.predict([[0, 0, 0, 0]]))

if queue is not None and queue.sycl_device.is_gpu:
assert_allclose([-0.011208], pred, atol=1e-2)
else:
if daal_check_version((2024, "P", 0)):
assert_allclose([-6.971], pred, atol=1e-2)
else:
assert_allclose([-6.839], pred, atol=1e-2)
# Check that the prediction is within a reasonable range.
# 'y' should be in the neighborhood of zero for x=0.
assert pred[0] >= -10
assert pred[0] <= 10

# Check that the trees aren't just empty nodes predicting the mean
for estimator in rf.estimators_:
assert estimator.tree_.children_left.shape[0] > 1


@pytest.mark.parametrize("dataframe,queue", get_dataframes_and_queues())
def test_sklearnex_import_et_classifier(dataframe, queue):
if queue and queue.sycl_device.is_gpu:
pytest.skip("ET classifier predict for the GPU sycl_queue is buggy.")
if (not daal_check_version((2025, "P", 200))) and queue and queue.sycl_device.is_gpu:
pytest.skip("Skipping due to bug in histogram merges fixed in 2025.2.")
from sklearnex.ensemble import ExtraTreesClassifier

X, y = make_classification(
Expand All @@ -108,8 +109,8 @@ def test_sklearnex_import_et_classifier(dataframe, queue):

@pytest.mark.parametrize("dataframe,queue", get_dataframes_and_queues())
def test_sklearnex_import_et_regression(dataframe, queue):
if queue and queue.sycl_device.is_gpu:
pytest.skip("ET regressor predict for the GPU sycl_queue is buggy.")
if (not daal_check_version((2025, "P", 200))) and queue and queue.sycl_device.is_gpu:
pytest.skip("Skipping due to bug in histogram merges fixed in 2025.2.")
from sklearnex.ensemble import ExtraTreesRegressor

X, y = make_regression(n_features=1, random_state=0, shuffle=False)
Expand All @@ -129,7 +130,11 @@ def test_sklearnex_import_et_regression(dataframe, queue):
)
)

if queue is not None and queue.sycl_device.is_gpu:
assert_allclose([1.909769], pred, atol=1e-2)
else:
assert_allclose([0.445], pred, atol=1e-2)
# Check that the prediction is within a reasonable range.
# 'y' should be in the neighborhood of zero for x=0.
assert pred[0] >= -10
assert pred[0] <= 10

# Check that the trees aren't just empty nodes predicting the mean
for estimator in rf.estimators_:
assert estimator.tree_.children_left.shape[0] > 1
4 changes: 2 additions & 2 deletions sklearnex/linear_model/tests/test_linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,10 +37,10 @@
def test_sklearnex_import_linear(
dataframe, queue, dtype, macro_block, overdetermined, multi_output
):
if (overdetermined or multi_output) and not daal_check_version((2025, "P", 1)):
if (not overdetermined or multi_output) and not daal_check_version((2025, "P", 1)):
pytest.skip("Functionality introduced in later versions")
if (
overdetermined
not overdetermined
and queue
and queue.sycl_device.is_gpu
and not daal_check_version((2025, "P", 200))
Expand Down

0 comments on commit 8fca003

Please sign in to comment.