Skip to content

Commit

Permalink
Update test_forest.py
Browse files Browse the repository at this point in the history
  • Loading branch information
icfaust authored Dec 13, 2024
1 parent 9d3f3a5 commit 1aad858
Showing 1 changed file with 4 additions and 1 deletion.
5 changes: 4 additions & 1 deletion sklearnex/ensemble/tests/test_forest.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,8 @@
@pytest.mark.parametrize("block, trees, rows, scale", hparam_values)
def test_sklearnex_import_rf_classifier(dataframe, queue, block, trees, rows, scale):
from sklearnex.ensemble import RandomForestClassifier

from sklearnex.utils.validation import validate_data

X, y = make_classification(
n_samples=1000,
n_features=4,
Expand All @@ -51,6 +52,8 @@ def test_sklearnex_import_rf_classifier(dataframe, queue, block, trees, rows, sc
X = _convert_to_dataframe(X, sycl_queue=queue, target_df=dataframe)
y = _convert_to_dataframe(y, sycl_queue=queue, target_df=dataframe)
rf = RandomForestClassifier(max_depth=2, random_state=0).fit(X, y)
# Test to see if this changes validation coverage
validate_data(rf, X, reset=False)
hparams = RandomForestClassifier.get_hyperparameters("infer")
if hparams and block is not None:
hparams.block_size = block
Expand Down

0 comments on commit 1aad858

Please sign in to comment.