From e881928fda205eb990db95c7c147228097971e8d Mon Sep 17 00:00:00 2001 From: Alexander Andreev Date: Tue, 10 Dec 2024 07:28:32 -0800 Subject: [PATCH 01/18] Apply scipy array API support --- .ci/scripts/run_sklearn_tests.py | 5 ++ .../sklearn/linear_model/tests/test_linear.py | 12 +++ daal4py/sklearn/metrics/_pairwise.py | 86 +++++++++++++++++-- onedal/svm/tests/test_svc.py | 11 +++ sklearnex/_config.py | 7 ++ .../tests/test_incremental_covariance.py | 11 +++ 6 files changed, 123 insertions(+), 9 deletions(-) diff --git a/.ci/scripts/run_sklearn_tests.py b/.ci/scripts/run_sklearn_tests.py index 3521fea859..a7f5b04b7f 100644 --- a/.ci/scripts/run_sklearn_tests.py +++ b/.ci/scripts/run_sklearn_tests.py @@ -25,6 +25,8 @@ import pytest import sklearn +from daal4py.sklearn._utils import sklearn_check_version + if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument( @@ -43,6 +45,9 @@ if os.environ["SELECTED_TESTS"] == "all": os.environ["SELECTED_TESTS"] = "" + if sklearn_check_version("1.6"): + os.environ["SCIPY_ARRAY_API"] = "1" + pytest_args = ( "--verbose --durations=100 --durations-min=0.01 " f"--rootdir={sklearn_file_dir} " diff --git a/daal4py/sklearn/linear_model/tests/test_linear.py b/daal4py/sklearn/linear_model/tests/test_linear.py index 57a11c6cdb..29137b475a 100644 --- a/daal4py/sklearn/linear_model/tests/test_linear.py +++ b/daal4py/sklearn/linear_model/tests/test_linear.py @@ -14,6 +14,18 @@ # limitations under the License. # ============================================================================== + +from os import environ + +from daal4py.sklearn._utils import sklearn_check_version + +# sklearn requires manual enabling of Scipy array API support +# if `array-api-compat` package is present in environment +# TODO: create generic approach to handle this for all tests +if sklearn_check_version("1.6"): + environ["SCIPY_ARRAY_API"] = "1" + + import numpy as np import pytest from sklearn.datasets import make_regression diff --git a/daal4py/sklearn/metrics/_pairwise.py b/daal4py/sklearn/metrics/_pairwise.py index 432c0d60a1..be5692757a 100755 --- a/daal4py/sklearn/metrics/_pairwise.py +++ b/daal4py/sklearn/metrics/_pairwise.py @@ -48,7 +48,12 @@ def _precompute_metric_params(*args, **kwrds): from .._utils import PatchingConditionsChain, getFPType, sklearn_check_version if sklearn_check_version("1.3"): - from sklearn.utils._param_validation import Integral, StrOptions, validate_params + from sklearn.utils._param_validation import ( + Hidden, + Integral, + StrOptions, + validate_params, + ) def _daal4py_cosine_distance_dense(X): @@ -65,7 +70,7 @@ def _daal4py_correlation_distance_dense(X): return res.correlationDistance -def pairwise_distances( +def _pairwise_distances( X, Y=None, metric="euclidean", *, n_jobs=None, force_all_finite=True, **kwds ): if metric not in _VALID_METRICS and not callable(metric) and metric != "precomputed": @@ -140,15 +145,78 @@ def pairwise_distances( return _parallel_pairwise(X, Y, func, n_jobs, **kwds) +# logic to deprecate `force_all_finite` from sklearn: +# it was renamed to `ensure_all_finite` since 1.6 and will be removed in 1.8 if sklearn_check_version("1.3"): + pairwise_distances_parameters = { + "X": ["array-like", "sparse matrix"], + "Y": ["array-like", "sparse matrix", None], + "metric": [StrOptions(set(_VALID_METRICS) | {"precomputed"}), callable], + "n_jobs": [Integral, None], + } + if sklearn_check_version("1.6"): + pairwise_distances_parameters["ensure_all_finite"] = [ + "boolean", + StrOptions({"allow-nan"}), + Hidden(None), + ] + if not sklearn_check_version("1.8"): + from sklearn.utils.deprecation import _deprecate_force_all_finite + + pairwise_distances_parameters["force_all_finite"] = [ + "boolean", + StrOptions({"allow-nan"}), + Hidden(StrOptions({"deprecated"})), + ] + + def pairwise_distances( + X, + Y=None, + metric="euclidean", + *, + n_jobs=None, + force_all_finite="deprecated", + ensure_all_finite=None, + **kwds, + ): + force_all_finite = _deprecate_force_all_finite( + force_all_finite, ensure_all_finite + ) + return _pairwise_distances( + X, Y, metric, n_jobs=n_jobs, force_all_finite=force_all_finite, **kwds + ) + + else: + + def pairwise_distances( + X, + Y=None, + metric="euclidean", + *, + n_jobs=None, + ensure_all_finite=None, + **kwds, + ): + return _pairwise_distances( + X, + Y, + metric, + n_jobs=n_jobs, + force_all_finite=ensure_all_finite, + **kwds, + ) + + else: + + def pairwise_distances( + X, Y=None, metric="euclidean", *, n_jobs=None, force_all_finite=True, **kwds + ): + return _pairwise_distances( + X, Y, metric, n_jobs=n_jobs, force_all_finite=force_all_finite, **kwds + ) + pairwise_distances = validate_params( - { - "X": ["array-like", "sparse matrix"], - "Y": ["array-like", "sparse matrix", None], - "metric": [StrOptions(set(_VALID_METRICS) | {"precomputed"}), callable], - "n_jobs": [Integral, None], - "force_all_finite": ["boolean", StrOptions({"allow-nan"})], - }, + pairwise_distances_parameters, prefer_skip_nested_validation=True, )(pairwise_distances) diff --git a/onedal/svm/tests/test_svc.py b/onedal/svm/tests/test_svc.py index 9f7eaa4810..e6534be3e0 100644 --- a/onedal/svm/tests/test_svc.py +++ b/onedal/svm/tests/test_svc.py @@ -14,6 +14,17 @@ # limitations under the License. # ============================================================================== +from os import environ + +from daal4py.sklearn._utils import sklearn_check_version + +# sklearn requires manual enabling of Scipy array API support +# if `array-api-compat` package is present in environment +# TODO: create generic approach to handle this for all tests +if sklearn_check_version("1.6"): + environ["SCIPY_ARRAY_API"] = "1" + + import numpy as np import pytest import sklearn.utils.estimator_checks diff --git a/sklearnex/_config.py b/sklearnex/_config.py index fafdde6e68..6589f77d85 100644 --- a/sklearnex/_config.py +++ b/sklearnex/_config.py @@ -15,10 +15,12 @@ # ============================================================================== from contextlib import contextmanager +from os import environ from sklearn import get_config as skl_get_config from sklearn import set_config as skl_set_config +from daal4py.sklearn._utils import sklearn_check_version from onedal._config import _get_config as onedal_get_config @@ -65,6 +67,11 @@ def set_config( config_context : Context manager for global configuration. get_config : Retrieve current values of the global configuration. """ + + array_api_dispatch = sklearn_configs.get("array_api_dispatch", False) + if array_api_dispatch and sklearn_check_version("1.6"): + environ["SCIPY_ARRAY_API"] = "1" + skl_set_config(**sklearn_configs) local_config = onedal_get_config(copy=False) diff --git a/sklearnex/covariance/tests/test_incremental_covariance.py b/sklearnex/covariance/tests/test_incremental_covariance.py index 68272ced9e..e42373cf84 100644 --- a/sklearnex/covariance/tests/test_incremental_covariance.py +++ b/sklearnex/covariance/tests/test_incremental_covariance.py @@ -14,6 +14,17 @@ # limitations under the License. # =============================================================================== +from os import environ + +from daal4py.sklearn._utils import sklearn_check_version + +# sklearn requires manual enabling of Scipy array API support +# if `array-api-compat` package is present in environment +# TODO: create generic approach to handle this for all tests +if sklearn_check_version("1.6"): + environ["SCIPY_ARRAY_API"] = "1" + + import numpy as np import pytest from numpy.linalg import slogdet From 0fa0184568da7e395296bb1fda58a765bf69167a Mon Sep 17 00:00:00 2001 From: Alexander Andreev Date: Tue, 10 Dec 2024 07:28:51 -0800 Subject: [PATCH 02/18] Deselect tests for unsupported skl1.6 features --- deselected_tests.yaml | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/deselected_tests.yaml b/deselected_tests.yaml index 57e36a9208..86e60136a5 100755 --- a/deselected_tests.yaml +++ b/deselected_tests.yaml @@ -25,6 +25,20 @@ # will exclude deselection in versions 0.18.1, and 0.18.2 only. deselected_tests: + # sklearn 1.6 unsupported features + - linear_model/tests/test_base.py::test_linear_regression_sample_weight_consistency[42-True-None-X_shape1] + - linear_model/tests/test_base.py::test_linear_regression_sample_weight_consistency[42-True-None-X_shape2] + - linear_model/tests/test_ridge.py::test_ridge_shapes_type + - linear_model/tests/test_ridge.py::test_ridge_cv_results_predictions[2-False-False] + - linear_model/tests/test_ridge.py::test_ridge_cv_results_predictions[2-False-True] + - neighbors/tests/test_neighbors.py::test_nan_euclidean_support[KNeighborsClassifier-params0] + - neighbors/tests/test_neighbors.py::test_nan_euclidean_support[KNeighborsRegressor-params1] + - neighbors/tests/test_neighbors.py::test_nan_euclidean_support[LocalOutlierFactor-params6] + - neighbors/tests/test_neighbors.py::test_neighbor_classifiers_loocv[ball_tree-nn_model0] + - neighbors/tests/test_neighbors.py::test_neighbor_classifiers_loocv[brute-nn_model0] + - neighbors/tests/test_neighbors.py::test_neighbor_classifiers_loocv[kd_tree-nn_model0] + - neighbors/tests/test_neighbors.py::test_neighbor_classifiers_loocv[auto-nn_model0] + # Array API support # sklearnex functional Array API support doesn't guaranty namespace consistency for the estimator's array attributes. - decomposition/tests/test_pca.py::test_pca_array_api_compliance[PCA(n_components=2,svd_solver='covariance_eigh')-check_array_api_input_and_values-array_api_strict-None-None] From 042e5f4046fa175821e036ec56f8f32508f4ae29 Mon Sep 17 00:00:00 2001 From: Alexander Andreev Date: Tue, 10 Dec 2024 07:32:41 -0800 Subject: [PATCH 03/18] Add sklearn 1.6 to CI matrix --- .ci/pipeline/ci.yml | 6 ++++++ requirements-test.txt | 2 +- 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/.ci/pipeline/ci.yml b/.ci/pipeline/ci.yml index 09499563d4..36877eec4a 100644 --- a/.ci/pipeline/ci.yml +++ b/.ci/pipeline/ci.yml @@ -134,6 +134,9 @@ jobs: Python3.13_Sklearn1.5: PYTHON_VERSION: '3.13' SKLEARN_VERSION: '1.5' + Python3.13_Sklearn1.6: + PYTHON_VERSION: '3.13' + SKLEARN_VERSION: '1.6' pool: vmImage: 'ubuntu-22.04' steps: @@ -161,6 +164,9 @@ jobs: Python3.13_Sklearn1.5: PYTHON_VERSION: '3.13' SKLEARN_VERSION: '1.5' + Python3.13_Sklearn1.6: + PYTHON_VERSION: '3.13' + SKLEARN_VERSION: '1.6' pool: vmImage: 'windows-2022' steps: diff --git a/requirements-test.txt b/requirements-test.txt index e59fdf0606..02a99a8406 100644 --- a/requirements-test.txt +++ b/requirements-test.txt @@ -6,7 +6,7 @@ numpy>=1.19.5 ; python_version <= '3.9' numpy>=1.21.6 ; python_version == '3.10' numpy>=1.23.5 ; python_version == '3.11' numpy>=2.0.0 ; python_version >= '3.12' -scikit-learn==1.5.2 +scikit-learn==1.6.0 pandas==2.1.3 ; python_version < '3.11' pandas==2.2.3 ; python_version >= '3.11' xgboost==2.1.3 From 1939710f2502335593262216e6a8304226cadcc4 Mon Sep 17 00:00:00 2001 From: Alexander Andreev Date: Tue, 10 Dec 2024 08:04:19 -0800 Subject: [PATCH 04/18] Fix pairwise_distances dispatching --- daal4py/sklearn/metrics/_pairwise.py | 11 +++-------- 1 file changed, 3 insertions(+), 8 deletions(-) diff --git a/daal4py/sklearn/metrics/_pairwise.py b/daal4py/sklearn/metrics/_pairwise.py index be5692757a..fa7acba90e 100755 --- a/daal4py/sklearn/metrics/_pairwise.py +++ b/daal4py/sklearn/metrics/_pairwise.py @@ -207,17 +207,12 @@ def pairwise_distances( ) else: - - def pairwise_distances( - X, Y=None, metric="euclidean", *, n_jobs=None, force_all_finite=True, **kwds - ): - return _pairwise_distances( - X, Y, metric, n_jobs=n_jobs, force_all_finite=force_all_finite, **kwds - ) + pairwise_distances = _pairwise_distances pairwise_distances = validate_params( pairwise_distances_parameters, prefer_skip_nested_validation=True, )(pairwise_distances) - +else: + pairwise_distances = _pairwise_distances pairwise_distances.__doc__ = pairwise_distances_original.__doc__ From f7048a2c728c8f112011579b3aa742db00aa412e Mon Sep 17 00:00:00 2001 From: Alexander Andreev Date: Tue, 10 Dec 2024 08:30:43 -0800 Subject: [PATCH 05/18] Fix forbidden usage of sklearn_check_version --- onedal/svm/tests/test_svc.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/onedal/svm/tests/test_svc.py b/onedal/svm/tests/test_svc.py index e6534be3e0..f81b60cb13 100644 --- a/onedal/svm/tests/test_svc.py +++ b/onedal/svm/tests/test_svc.py @@ -16,13 +16,10 @@ from os import environ -from daal4py.sklearn._utils import sklearn_check_version - # sklearn requires manual enabling of Scipy array API support # if `array-api-compat` package is present in environment # TODO: create generic approach to handle this for all tests -if sklearn_check_version("1.6"): - environ["SCIPY_ARRAY_API"] = "1" +environ["SCIPY_ARRAY_API"] = "1" import numpy as np From 19305bd19b2f8b6bbe8b4f0fc1f97eba40a590d1 Mon Sep 17 00:00:00 2001 From: Alexander Andreev Date: Tue, 10 Dec 2024 11:44:14 -0800 Subject: [PATCH 06/18] Fix for pairwise_distances params validation --- daal4py/sklearn/metrics/_pairwise.py | 47 ++++++++++++++-------------- 1 file changed, 24 insertions(+), 23 deletions(-) diff --git a/daal4py/sklearn/metrics/_pairwise.py b/daal4py/sklearn/metrics/_pairwise.py index fa7acba90e..54a0d714a8 100755 --- a/daal4py/sklearn/metrics/_pairwise.py +++ b/daal4py/sklearn/metrics/_pairwise.py @@ -153,21 +153,20 @@ def _pairwise_distances( "Y": ["array-like", "sparse matrix", None], "metric": [StrOptions(set(_VALID_METRICS) | {"precomputed"}), callable], "n_jobs": [Integral, None], - } - if sklearn_check_version("1.6"): - pairwise_distances_parameters["ensure_all_finite"] = [ + "force_all_finite": [ + "boolean", + StrOptions({"allow-nan"}), + Hidden(StrOptions({"deprecated"})), + ], + "ensure_all_finite": [ "boolean", StrOptions({"allow-nan"}), Hidden(None), - ] - if not sklearn_check_version("1.8"): - from sklearn.utils.deprecation import _deprecate_force_all_finite - - pairwise_distances_parameters["force_all_finite"] = [ - "boolean", - StrOptions({"allow-nan"}), - Hidden(StrOptions({"deprecated"})), - ] + ], + } + if sklearn_check_version("1.6"): + if sklearn_check_version("1.8"): + del pairwise_distances_parameters["force_all_finite"] def pairwise_distances( X, @@ -175,18 +174,20 @@ def pairwise_distances( metric="euclidean", *, n_jobs=None, - force_all_finite="deprecated", ensure_all_finite=None, **kwds, ): - force_all_finite = _deprecate_force_all_finite( - force_all_finite, ensure_all_finite - ) return _pairwise_distances( - X, Y, metric, n_jobs=n_jobs, force_all_finite=force_all_finite, **kwds + X, + Y, + metric, + n_jobs=n_jobs, + force_all_finite=ensure_all_finite, + **kwds, ) else: + from sklearn.utils.deprecation import _deprecate_force_all_finite def pairwise_distances( X, @@ -194,19 +195,19 @@ def pairwise_distances( metric="euclidean", *, n_jobs=None, + force_all_finite="deprecated", ensure_all_finite=None, **kwds, ): + force_all_finite = _deprecate_force_all_finite( + force_all_finite, ensure_all_finite + ) return _pairwise_distances( - X, - Y, - metric, - n_jobs=n_jobs, - force_all_finite=ensure_all_finite, - **kwds, + X, Y, metric, n_jobs=n_jobs, force_all_finite=force_all_finite, **kwds ) else: + del pairwise_distances_parameters["ensure_all_finite"] pairwise_distances = _pairwise_distances pairwise_distances = validate_params( From 693a4fb1c2aa2e629f75f73d6efd9e603be79fa9 Mon Sep 17 00:00:00 2001 From: Alexander Andreev Date: Wed, 11 Dec 2024 02:54:12 -0800 Subject: [PATCH 07/18] Fix for pairwise_distances params validation --- daal4py/sklearn/metrics/_pairwise.py | 19 ++++++++++++++++++- 1 file changed, 18 insertions(+), 1 deletion(-) diff --git a/daal4py/sklearn/metrics/_pairwise.py b/daal4py/sklearn/metrics/_pairwise.py index 54a0d714a8..2bf25b98e9 100755 --- a/daal4py/sklearn/metrics/_pairwise.py +++ b/daal4py/sklearn/metrics/_pairwise.py @@ -208,7 +208,24 @@ def pairwise_distances( else: del pairwise_distances_parameters["ensure_all_finite"] - pairwise_distances = _pairwise_distances + + def pairwise_distances( + X, + Y=None, + metric="euclidean", + *, + n_jobs=None, + force_all_finite=None, + **kwds, + ): + return _pairwise_distances( + X, + Y, + metric, + n_jobs=n_jobs, + force_all_finite=force_all_finite, + **kwds, + ) pairwise_distances = validate_params( pairwise_distances_parameters, From 3ebaca94d200d24704760e761a059f631286cc5c Mon Sep 17 00:00:00 2001 From: Alexander Andreev Date: Wed, 11 Dec 2024 03:28:34 -0800 Subject: [PATCH 08/18] Fix for pairwise_distances params validation --- daal4py/sklearn/metrics/_pairwise.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/daal4py/sklearn/metrics/_pairwise.py b/daal4py/sklearn/metrics/_pairwise.py index 2bf25b98e9..dba150c307 100755 --- a/daal4py/sklearn/metrics/_pairwise.py +++ b/daal4py/sklearn/metrics/_pairwise.py @@ -215,7 +215,7 @@ def pairwise_distances( metric="euclidean", *, n_jobs=None, - force_all_finite=None, + force_all_finite=True, **kwds, ): return _pairwise_distances( From cfcbc967a32a83a42807ee634c8d7e40749e5e8c Mon Sep 17 00:00:00 2001 From: Alexander Andreev Date: Wed, 11 Dec 2024 06:42:41 -0800 Subject: [PATCH 09/18] Add SCIPY_ARRAY_API to test_estimators --- tests/test_estimators.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/tests/test_estimators.py b/tests/test_estimators.py index 15e1923bcd..df05c49639 100644 --- a/tests/test_estimators.py +++ b/tests/test_estimators.py @@ -14,6 +14,14 @@ # limitations under the License. # ============================================================================== +from os import environ + +# sklearn requires manual enabling of Scipy array API support +# if `array-api-compat` package is present in environment +# TODO: create generic approach to handle this for all tests +environ["SCIPY_ARRAY_API"] = "1" + + import unittest import sklearn.utils.estimator_checks From c87c0c323868eca756d7ff274516d06a42d5e4d3 Mon Sep 17 00:00:00 2001 From: Alexander Andreev Date: Wed, 11 Dec 2024 06:43:06 -0800 Subject: [PATCH 10/18] Update input validation in AdaBoost and GBT d4p estimators --- .../sklearn/ensemble/AdaBoostClassifier.py | 12 +++-- daal4py/sklearn/ensemble/GBTDAAL.py | 45 +++++++++++++------ 2 files changed, 39 insertions(+), 18 deletions(-) diff --git a/daal4py/sklearn/ensemble/AdaBoostClassifier.py b/daal4py/sklearn/ensemble/AdaBoostClassifier.py index 1cc9bad41d..14b2a3854d 100644 --- a/daal4py/sklearn/ensemble/AdaBoostClassifier.py +++ b/daal4py/sklearn/ensemble/AdaBoostClassifier.py @@ -25,13 +25,19 @@ from sklearn.utils.validation import check_array, check_is_fitted, check_X_y import daal4py as d4p +from daal4py.sklearn._utils import sklearn_check_version from .._n_jobs_support import control_n_jobs from .._utils import getFPType +if sklearn_check_version("1.6"): + from sklearn.utils.validation import validate_data +else: + validate_data = BaseEstimator._validate_data + @control_n_jobs(decorated_methods=["fit", "predict"]) -class AdaBoostClassifier(BaseEstimator, ClassifierMixin): +class AdaBoostClassifier(ClassifierMixin, BaseEstimator): def __init__( self, split_criterion="gini", @@ -151,9 +157,7 @@ def predict(self, X): check_is_fitted(self) # Input validation - X = check_array(X, dtype=[np.single, np.double]) - if X.shape[1] != self.n_features_in_: - raise ValueError("Shape of input is different from what was seen in `fit`") + X = validate_data(self, X, dtype=[np.single, np.double], reset=False) # Trivial case if self.n_classes_ == 1: diff --git a/daal4py/sklearn/ensemble/GBTDAAL.py b/daal4py/sklearn/ensemble/GBTDAAL.py index b4de6ba9e3..5d8b564cb7 100644 --- a/daal4py/sklearn/ensemble/GBTDAAL.py +++ b/daal4py/sklearn/ensemble/GBTDAAL.py @@ -26,10 +26,16 @@ from sklearn.utils.validation import check_array, check_is_fitted, check_X_y import daal4py as d4p +from daal4py.sklearn._utils import sklearn_check_version from .._n_jobs_support import control_n_jobs from .._utils import getFPType +if sklearn_check_version("1.6"): + from sklearn.utils.validation import validate_data +else: + validate_data = BaseEstimator._validate_data + class GBTDAALBase(BaseEstimator, d4p.mb.GBTDAALBaseModel): def __init__( @@ -128,9 +134,14 @@ def _check_params(self): def _more_tags(self): return {"allow_nan": self.allow_nan_} + def __sklearn_tags__(self): + tags = super().__sklearn_tags__() + tags.input_tags.allow_nan = self.allow_nan_ + return tags + @control_n_jobs(decorated_methods=["fit", "predict"]) -class GBTDAALClassifier(GBTDAALBase, ClassifierMixin): +class GBTDAALClassifier(ClassifierMixin, GBTDAALBase): def fit(self, X, y): # Check the algorithm parameters self._check_params() @@ -196,15 +207,18 @@ def fit(self, X, y): def _predict( self, X, resultsToEvaluate, pred_contribs=False, pred_interactions=False ): - # Input validation - if not self.allow_nan_: - X = check_array(X, dtype=[np.single, np.double]) - else: - X = check_array(X, dtype=[np.single, np.double], force_all_finite="allow-nan") - # Check is fit had been called check_is_fitted(self, ["n_features_in_", "n_classes_"]) + # Input validation + X = validate_data( + self, + X, + dtype=[np.single, np.double], + force_all_finite="allow-nan" if self.allow_nan_ else True, + reset=False, + ) + # Trivial case if self.n_classes_ == 1: return np.full(X.shape[0], self.classes_[0]) @@ -251,7 +265,7 @@ def convert_model(model): @control_n_jobs(decorated_methods=["fit", "predict"]) -class GBTDAALRegressor(GBTDAALBase, RegressorMixin): +class GBTDAALRegressor(RegressorMixin, GBTDAALBase): def fit(self, X, y): # Check the algorithm parameters self._check_params() @@ -297,15 +311,18 @@ def fit(self, X, y): return self def predict(self, X, pred_contribs=False, pred_interactions=False): - # Input validation - if not self.allow_nan_: - X = check_array(X, dtype=[np.single, np.double]) - else: - X = check_array(X, dtype=[np.single, np.double], force_all_finite="allow-nan") - # Check is fit had been called check_is_fitted(self, ["n_features_in_"]) + # Input validation + X = validate_data( + self, + X, + dtype=[np.single, np.double], + force_all_finite="allow-nan" if self.allow_nan_ else True, + reset=False, + ) + fptype = getFPType(X) return self._predict_regression(X, fptype, pred_contribs, pred_interactions) From d2d78e4f555e941d9301b15ae91ecbb091f74eb5 Mon Sep 17 00:00:00 2001 From: Alexander Andreev Date: Wed, 11 Dec 2024 06:43:21 -0800 Subject: [PATCH 11/18] Pin sklearn 1.5 for py3.9 --- requirements-test.txt | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/requirements-test.txt b/requirements-test.txt index 02a99a8406..7a39fc7267 100644 --- a/requirements-test.txt +++ b/requirements-test.txt @@ -6,7 +6,8 @@ numpy>=1.19.5 ; python_version <= '3.9' numpy>=1.21.6 ; python_version == '3.10' numpy>=1.23.5 ; python_version == '3.11' numpy>=2.0.0 ; python_version >= '3.12' -scikit-learn==1.6.0 +scikit-learn==1.5.2 ; python_version <= '3.9' +scikit-learn==1.6.0 ; python_version >= '3.10' pandas==2.1.3 ; python_version < '3.11' pandas==2.2.3 ; python_version >= '3.11' xgboost==2.1.3 From a5f390c999b2c4cff7ae3c7b49afbcff450741c0 Mon Sep 17 00:00:00 2001 From: Alexander Andreev Date: Wed, 11 Dec 2024 10:06:42 -0800 Subject: [PATCH 12/18] Fix knn bf regr spmd example --- examples/sklearnex/knn_bf_regression_spmd.py | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/examples/sklearnex/knn_bf_regression_spmd.py b/examples/sklearnex/knn_bf_regression_spmd.py index 28ce112290..63381974a2 100644 --- a/examples/sklearnex/knn_bf_regression_spmd.py +++ b/examples/sklearnex/knn_bf_regression_spmd.py @@ -21,7 +21,13 @@ import numpy as np from mpi4py import MPI from numpy.testing import assert_allclose -from sklearn.metrics import mean_squared_error + +from daal4py.sklearn._utils import sklearn_check_version + +if sklearn_check_version("1.4"): + from sklearn.metrics import root_mean_squared_error +else: + from sklearn.metrics import mean_squared_error from sklearnex.spmd.neighbors import KNeighborsRegressor @@ -80,6 +86,11 @@ def generate_X_y(par, coef_seed, data_seed): ) print( "RMSE for entire rank {}: {}\n".format( - rank, mean_squared_error(y_test, dpt.to_numpy(y_predict), squared=False) + rank, + ( + root_mean_squared_error(y_test, dpt.to_numpy(y_predict)) + if sklearn_check_version("1.4") + else mean_squared_error(y_test, dpt.to_numpy(y_predict), squared=False) + ), ) ) From 68df3df5a546bb6475212162c1fcf5540716a8ad Mon Sep 17 00:00:00 2001 From: Alexander Andreev Date: Thu, 12 Dec 2024 04:18:38 -0800 Subject: [PATCH 13/18] Update python-sklearn CI matrix --- .ci/pipeline/ci.yml | 40 ++++++++++++++-------------------------- 1 file changed, 14 insertions(+), 26 deletions(-) diff --git a/.ci/pipeline/ci.yml b/.ci/pipeline/ci.yml index 36877eec4a..877bc1f0d4 100644 --- a/.ci/pipeline/ci.yml +++ b/.ci/pipeline/ci.yml @@ -116,23 +116,17 @@ jobs: timeoutInMinutes: 120 strategy: matrix: - Python3.9_Sklearn1.0: + Python3.9_Sklearn1.2: PYTHON_VERSION: '3.9' - SKLEARN_VERSION: '1.0' - Python3.9_Sklearn1.1: - PYTHON_VERSION: '3.9' - SKLEARN_VERSION: '1.1' - Python3.10_Sklearn1.2: - PYTHON_VERSION: '3.10' SKLEARN_VERSION: '1.2' - Python3.11_Sklearn1.3: - PYTHON_VERSION: '3.11' + Python3.10_Sklearn1.3: + PYTHON_VERSION: '3.10' SKLEARN_VERSION: '1.3' - Python3.12_Sklearn1.4: - PYTHON_VERSION: '3.12' + Python3.11_Sklearn1.4: + PYTHON_VERSION: '3.11' SKLEARN_VERSION: '1.4' - Python3.13_Sklearn1.5: - PYTHON_VERSION: '3.13' + Python3.12_Sklearn1.5: + PYTHON_VERSION: '3.12' SKLEARN_VERSION: '1.5' Python3.13_Sklearn1.6: PYTHON_VERSION: '3.13' @@ -146,23 +140,17 @@ jobs: timeoutInMinutes: 120 strategy: matrix: - Python3.9_Sklearn1.0: + Python3.9_Sklearn1.2: PYTHON_VERSION: '3.9' - SKLEARN_VERSION: '1.0' - Python3.9_Sklearn1.1: - PYTHON_VERSION: '3.9' - SKLEARN_VERSION: '1.1' - Python3.10_Sklearn1.2: - PYTHON_VERSION: '3.10' SKLEARN_VERSION: '1.2' - Python3.11_Sklearn1.3: - PYTHON_VERSION: '3.11' + Python3.10_Sklearn1.3: + PYTHON_VERSION: '3.10' SKLEARN_VERSION: '1.3' - Python3.12_Sklearn1.4: - PYTHON_VERSION: '3.12' + Python3.11_Sklearn1.4: + PYTHON_VERSION: '3.11' SKLEARN_VERSION: '1.4' - Python3.13_Sklearn1.5: - PYTHON_VERSION: '3.13' + Python3.12_Sklearn1.5: + PYTHON_VERSION: '3.12' SKLEARN_VERSION: '1.5' Python3.13_Sklearn1.6: PYTHON_VERSION: '3.13' From 6aa2969fdcbab618691102859d9dd1f632acea85 Mon Sep 17 00:00:00 2001 From: Alexander Andreev Date: Thu, 12 Dec 2024 04:20:42 -0800 Subject: [PATCH 14/18] Apply comments for AdaBoost and GBT estimators --- daal4py/sklearn/ensemble/AdaBoostClassifier.py | 4 ++-- daal4py/sklearn/ensemble/GBTDAAL.py | 17 +++++++++-------- 2 files changed, 11 insertions(+), 10 deletions(-) diff --git a/daal4py/sklearn/ensemble/AdaBoostClassifier.py b/daal4py/sklearn/ensemble/AdaBoostClassifier.py index 14b2a3854d..9ccb592cca 100644 --- a/daal4py/sklearn/ensemble/AdaBoostClassifier.py +++ b/daal4py/sklearn/ensemble/AdaBoostClassifier.py @@ -95,7 +95,7 @@ def fit(self, X, y): ) # Check that X and y have correct shape - X, y = check_X_y(X, y, y_numeric=False, dtype=[np.single, np.double]) + X, y = check_X_y(X, y, y_numeric=False, dtype=[np.float64, np.float32]) check_classification_targets(y) @@ -157,7 +157,7 @@ def predict(self, X): check_is_fitted(self) # Input validation - X = validate_data(self, X, dtype=[np.single, np.double], reset=False) + X = validate_data(self, X, dtype=[np.float64, np.float32], reset=False) # Trivial case if self.n_classes_ == 1: diff --git a/daal4py/sklearn/ensemble/GBTDAAL.py b/daal4py/sklearn/ensemble/GBTDAAL.py index 5d8b564cb7..b6999dda1a 100644 --- a/daal4py/sklearn/ensemble/GBTDAAL.py +++ b/daal4py/sklearn/ensemble/GBTDAAL.py @@ -134,10 +134,11 @@ def _check_params(self): def _more_tags(self): return {"allow_nan": self.allow_nan_} - def __sklearn_tags__(self): - tags = super().__sklearn_tags__() - tags.input_tags.allow_nan = self.allow_nan_ - return tags + if sklearn_check_version("1.6"): + def __sklearn_tags__(self): + tags = super().__sklearn_tags__() + tags.input_tags.allow_nan = self.allow_nan_ + return tags @control_n_jobs(decorated_methods=["fit", "predict"]) @@ -147,7 +148,7 @@ def fit(self, X, y): self._check_params() # Check that X and y have correct shape - X, y = check_X_y(X, y, y_numeric=False, dtype=[np.single, np.double]) + X, y = check_X_y(X, y, y_numeric=False, dtype=[np.float64, np.float32]) check_classification_targets(y) @@ -214,7 +215,7 @@ def _predict( X = validate_data( self, X, - dtype=[np.single, np.double], + dtype=[np.float64, np.float32], force_all_finite="allow-nan" if self.allow_nan_ else True, reset=False, ) @@ -271,7 +272,7 @@ def fit(self, X, y): self._check_params() # Check that X and y have correct shape - X, y = check_X_y(X, y, y_numeric=True, dtype=[np.single, np.double]) + X, y = check_X_y(X, y, y_numeric=True, dtype=[np.float64, np.float32]) # Convert to 2d array y_ = y.reshape((-1, 1)) @@ -318,7 +319,7 @@ def predict(self, X, pred_contribs=False, pred_interactions=False): X = validate_data( self, X, - dtype=[np.single, np.double], + dtype=[np.float64, np.float32], force_all_finite="allow-nan" if self.allow_nan_ else True, reset=False, ) From 119c94ef511de2bc10cebdba9e220c351305607a Mon Sep 17 00:00:00 2001 From: Alexander Andreev Date: Thu, 12 Dec 2024 06:01:20 -0800 Subject: [PATCH 15/18] Add sklearn 1.6 to README badge --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 114a943a4c..6a045cdaaf 100755 --- a/README.md +++ b/README.md @@ -30,7 +30,7 @@ [![PyPI Version](https://img.shields.io/pypi/v/scikit-learn-intelex)](https://pypi.org/project/scikit-learn-intelex/) [![Conda Version](https://img.shields.io/conda/vn/conda-forge/scikit-learn-intelex)](https://anaconda.org/conda-forge/scikit-learn-intelex) [![python version](https://img.shields.io/badge/python-3.9%20%7C%203.10%20%7C%203.11%20%7C%203.12-blue)](https://img.shields.io/badge/python-3.9%20%7C%203.10%20%7C%203.11%20%7C%203.12-blue) -[![scikit-learn supported versions](https://img.shields.io/badge/sklearn-1.0%20%7C%201.2%20%7C%201.3%20%7C%201.4%20%7C%201.5-blue)](https://img.shields.io/badge/sklearn-1.0%20%7C%201.2%20%7C%201.3%20%7C%201.4%20%7C%201.5-blue) +[![scikit-learn supported versions](https://img.shields.io/badge/sklearn-1.0%20%7C%201.2%20%7C%201.3%20%7C%201.4%20%7C%201.5%20%7C%201.6-blue)](https://img.shields.io/badge/sklearn-1.0%20%7C%201.2%20%7C%201.3%20%7C%201.4%20%7C%201.5%20%7C%201.6-blue) --- From 12874f680f98502000f9380622515e518af0e653 Mon Sep 17 00:00:00 2001 From: Alexander Andreev Date: Thu, 12 Dec 2024 06:01:51 -0800 Subject: [PATCH 16/18] Linting --- daal4py/sklearn/ensemble/GBTDAAL.py | 1 + 1 file changed, 1 insertion(+) diff --git a/daal4py/sklearn/ensemble/GBTDAAL.py b/daal4py/sklearn/ensemble/GBTDAAL.py index b6999dda1a..f8f7a48aaa 100644 --- a/daal4py/sklearn/ensemble/GBTDAAL.py +++ b/daal4py/sklearn/ensemble/GBTDAAL.py @@ -135,6 +135,7 @@ def _more_tags(self): return {"allow_nan": self.allow_nan_} if sklearn_check_version("1.6"): + def __sklearn_tags__(self): tags = super().__sklearn_tags__() tags.input_tags.allow_nan = self.allow_nan_ From 3229f80b834f730e19968dff952215a1856e78b8 Mon Sep 17 00:00:00 2001 From: Alexander Andreev Date: Thu, 12 Dec 2024 08:04:18 -0800 Subject: [PATCH 17/18] Update metric in knn bf regr spmd example --- examples/sklearnex/knn_bf_regression_spmd.py | 16 +++------------- 1 file changed, 3 insertions(+), 13 deletions(-) diff --git a/examples/sklearnex/knn_bf_regression_spmd.py b/examples/sklearnex/knn_bf_regression_spmd.py index 63381974a2..06e70ca013 100644 --- a/examples/sklearnex/knn_bf_regression_spmd.py +++ b/examples/sklearnex/knn_bf_regression_spmd.py @@ -21,13 +21,7 @@ import numpy as np from mpi4py import MPI from numpy.testing import assert_allclose - -from daal4py.sklearn._utils import sklearn_check_version - -if sklearn_check_version("1.4"): - from sklearn.metrics import root_mean_squared_error -else: - from sklearn.metrics import mean_squared_error +from sklearn.metrics import mean_squared_error from sklearnex.spmd.neighbors import KNeighborsRegressor @@ -85,12 +79,8 @@ def generate_X_y(par, coef_seed, data_seed): ) ) print( - "RMSE for entire rank {}: {}\n".format( + "MSE for entire rank {}: {}\n".format( rank, - ( - root_mean_squared_error(y_test, dpt.to_numpy(y_predict)) - if sklearn_check_version("1.4") - else mean_squared_error(y_test, dpt.to_numpy(y_predict), squared=False) - ), + mean_squared_error(y_test, dpt.to_numpy(y_predict)), ) ) From b2ed69cc02cfd86413c46886637f7885bd019fa7 Mon Sep 17 00:00:00 2001 From: Alexander Andreev Date: Thu, 12 Dec 2024 08:12:50 -0800 Subject: [PATCH 18/18] Update CI matrix --- .ci/pipeline/ci.yml | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/.ci/pipeline/ci.yml b/.ci/pipeline/ci.yml index 877bc1f0d4..2472054108 100644 --- a/.ci/pipeline/ci.yml +++ b/.ci/pipeline/ci.yml @@ -116,9 +116,9 @@ jobs: timeoutInMinutes: 120 strategy: matrix: - Python3.9_Sklearn1.2: + Python3.9_Sklearn1.0: PYTHON_VERSION: '3.9' - SKLEARN_VERSION: '1.2' + SKLEARN_VERSION: '1.0' Python3.10_Sklearn1.3: PYTHON_VERSION: '3.10' SKLEARN_VERSION: '1.3' @@ -140,9 +140,9 @@ jobs: timeoutInMinutes: 120 strategy: matrix: - Python3.9_Sklearn1.2: + Python3.9_Sklearn1.0: PYTHON_VERSION: '3.9' - SKLEARN_VERSION: '1.2' + SKLEARN_VERSION: '1.0' Python3.10_Sklearn1.3: PYTHON_VERSION: '3.10' SKLEARN_VERSION: '1.3'