diff --git a/deselected_tests.yaml b/deselected_tests.yaml
index f44ae218e5..279854a8a6 100755
--- a/deselected_tests.yaml
+++ b/deselected_tests.yaml
@@ -380,8 +380,11 @@ deselected_tests:
  - tests/test_common.py::test_estimators[IncrementalLinearRegression()-check_estimators_pickle(readonly_memmap=True)]
  - tests/test_common.py::test_estimators[IncrementalPCA()-check_estimators_pickle]
  - tests/test_common.py::test_estimators[IncrementalPCA()-check_estimators_pickle(readonly_memmap=True)]
+ - tests/test_common.py::test_estimators[IncrementalRidge()-check_estimators_pickle]
+ - tests/test_common.py::test_estimators[IncrementalRidge()-check_estimators_pickle(readonly_memmap=True)]
  # There are not enough data to run onedal backend
  - tests/test_common.py::test_estimators[IncrementalLinearRegression()-check_fit2d_1sample]
+ - tests/test_common.py::test_estimators[IncrementalRidge()-check_fit2d_1sample]

  # Deselection of LogisticRegression tests over accuracy comparisons with sample_weights
  # and without. Because scikit-learn-intelex does not support sample_weights, it's doing
diff --git a/onedal/linear_model/__init__.py b/onedal/linear_model/__init__.py
index 998e4a62d7..bdb0d0d6b3 100755
--- a/onedal/linear_model/__init__.py
+++ b/onedal/linear_model/__init__.py
@@ -14,12 +14,13 @@
 # limitations under the License.
 # ===============================================================================

-from .incremental_linear_model import IncrementalLinearRegression
+from .incremental_linear_model import IncrementalLinearRegression, IncrementalRidge
 from .linear_model import LinearRegression, Ridge
 from .logistic_regression import LogisticRegression

 __all__ = [
     "IncrementalLinearRegression",
+    "IncrementalRidge",
     "LinearRegression",
     "LogisticRegression",
     "Ridge",
diff --git a/onedal/linear_model/incremental_linear_model.py b/onedal/linear_model/incremental_linear_model.py
index b8b754e18f..43f9db4159 100644
--- a/onedal/linear_model/incremental_linear_model.py
+++ b/onedal/linear_model/incremental_linear_model.py
@@ -144,3 +144,113 @@ def finalize_fit(self, queue=None):
             self.intercept_ = self.intercept_[0]

         return self
+
+
+class IncrementalRidge(BaseLinearRegression):
+    """
+    Incremental Ridge Regression oneDAL implementation.
+
+    Parameters
+    ----------
+    alpha : float, default=1.0
+        Regularization strength; must be a positive float. Regularization
+        improves the conditioning of the problem and reduces the variance of
+        the estimates. Larger values specify stronger regularization.
+
+    fit_intercept : bool, default=True
+        Whether to calculate the intercept for this model. If set
+        to False, no intercept will be used in calculations
+        (i.e. data is expected to be centered).
+
+    copy_X : bool, default=False
+        If True, X will be copied; else, it may be overwritten.
+
+    algorithm : str, default="norm_eq"
+        Algorithm used for computation on the oneDAL side.
+    """
+
+    def __init__(self, alpha=1.0, fit_intercept=True, copy_X=False, algorithm="norm_eq"):
+        module = self._get_backend("linear_model", "regression")
+        super().__init__(
+            fit_intercept=fit_intercept, alpha=alpha, copy_X=copy_X, algorithm=algorithm
+        )
+        self._partial_result = module.partial_train_result()
+
+    def _reset(self):
+        module = self._get_backend("linear_model", "regression")
+        self._partial_result = module.partial_train_result()
+
+    def partial_fit(self, X, y, queue=None):
+        """
+        Computes partial training results for ridge regression
+        from data batch X and saves them to `_partial_result`.
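+
+        Conceptually, each call folds the batch's contribution into the
+        normal-equations terms accumulated by oneDAL, so that `finalize_fit`
+        can later solve for the coefficients without revisiting earlier
+        batches.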
+
+        Parameters
+        ----------
+        X : array-like of shape (n_samples, n_features)
+            Training data batch, where `n_samples` is the number of samples
+            in the batch, and `n_features` is the number of features.
+
+        y : array-like of shape (n_samples,) or (n_samples, n_targets)
+            Responses for training data; 2D in the case of multiple targets.
+
+        queue : dpctl.SyclQueue
+            If not None, use this queue for computations.
+
+        Returns
+        -------
+        self : object
+            Returns the instance itself.
+        """
+        module = self._get_backend("linear_model", "regression")
+
+        if not hasattr(self, "_queue"):
+            self._queue = queue
+        policy = self._get_policy(queue, X)
+
+        X, y = _convert_to_supported(policy, X, y)
+
+        if not hasattr(self, "_dtype"):
+            self._dtype = get_dtype(X)
+            self._params = self._get_onedal_params(self._dtype)
+
+        y = np.asarray(y).astype(dtype=self._dtype)
+
+        X, y = _check_X_y(X, y, dtype=[np.float64, np.float32], accept_2d_y=True)
+
+        self.n_features_in_ = _num_features(X, fallback_1d=True)
+        X_table, y_table = to_table(X, y)
+        self._partial_result = module.partial_train(
+            policy, self._params, self._partial_result, X_table, y_table
+        )
+        return self
+
+    def finalize_fit(self, queue=None):
+        """
+        Finalizes ridge regression computation and obtains coefficients
+        from the current `_partial_result`.
+
+        Parameters
+        ----------
+        queue : dpctl.SyclQueue
+            If not None, uses this queue for computations; otherwise the
+            queue from the last `partial_fit` call is used.
+
+        Returns
+        -------
+        self : object
+            Returns the instance itself.
+        """
+        module = self._get_backend("linear_model", "regression")
+        if queue is not None:
+            policy = self._get_policy(queue)
+        else:
+            policy = self._get_policy(self._queue)
+        result = module.finalize_train(policy, self._params, self._partial_result)
+
+        self._onedal_model = result.model
+
+        # The oneDAL model packs the intercept in column 0 and the
+        # coefficients in the remaining columns of a single table.
+        packed_coefficients = from_table(result.model.packed_coefficients)
+        self.coef_, self.intercept_ = (
+            packed_coefficients[:, 1:].squeeze(),
+            packed_coefficients[:, 0].squeeze(),
+        )
+
+        return self
diff --git a/onedal/linear_model/tests/test_incremental_ridge_regression.py b/onedal/linear_model/tests/test_incremental_ridge_regression.py
new file mode 100644
index 0000000000..471f46e4f6
--- /dev/null
+++ b/onedal/linear_model/tests/test_incremental_ridge_regression.py
@@ -0,0 +1,107 @@
+# ==============================================================================
+# Copyright 2024 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================== + +from daal4py.sklearn._utils import daal_check_version + +if daal_check_version((2024, "P", 600)): + import numpy as np + import pytest + from numpy.testing import assert_allclose, assert_array_equal + from sklearn.datasets import load_diabetes + from sklearn.metrics import mean_squared_error + from sklearn.model_selection import train_test_split + + from onedal.linear_model import IncrementalRidge + from onedal.tests.utils._device_selection import get_queues + + @pytest.mark.parametrize("queue", get_queues()) + @pytest.mark.parametrize("dtype", [np.float32, np.float64]) + def test_diabetes(queue, dtype): + X, y = load_diabetes(return_X_y=True) + X, y = X.astype(dtype), y.astype(dtype) + X_train, X_test, y_train, y_test = train_test_split( + X, y, train_size=0.8, random_state=777 + ) + X_train_split = np.array_split(X_train, 2) + y_train_split = np.array_split(y_train, 2) + model = IncrementalRidge(fit_intercept=True, alpha=0.1) + for i in range(2): + model.partial_fit(X_train_split[i], y_train_split[i], queue=queue) + model.finalize_fit() + y_pred = model.predict(X_test, queue=queue) + assert_allclose(mean_squared_error(y_test, y_pred), 2388.775, rtol=1e-5) + + @pytest.mark.parametrize("queue", get_queues()) + @pytest.mark.parametrize("dtype", [np.float32, np.float64]) + @pytest.mark.skip(reason="pickling not implemented for oneDAL entities") + def test_pickle(queue, dtype): + # TODO Implement pickling for oneDAL entities + X, y = load_diabetes(return_X_y=True) + X, y = X.astype(dtype), y.astype(dtype) + model = IncrementalRidge(fit_intercept=True, alpha=0.5) + model.partial_fit(X, y, queue=queue) + model.finalize_fit() + expected = model.predict(X, queue=queue) + + import pickle + + dump = pickle.dumps(model) + model2 = pickle.loads(dump) + + assert isinstance(model2, model.__class__) + result = model2.predict(X, queue=queue) + + assert_array_equal(expected, result) + + @pytest.mark.parametrize("queue", get_queues()) + @pytest.mark.parametrize("num_blocks", [1, 2, 10]) + @pytest.mark.parametrize("dtype", [np.float32, np.float64]) + def test_no_intercept_results(queue, num_blocks, dtype): + seed = 42 + n_features, n_targets = 19, 7 + n_train_samples, n_test_samples = 3500, 1999 + + gen = np.random.default_rng(seed) + + X = gen.random(size=(n_train_samples, n_features), dtype=dtype) + y = gen.random(size=(n_train_samples, n_targets), dtype=dtype) + X_split = np.array_split(X, num_blocks) + y_split = np.array_split(y, num_blocks) + alpha = 0.5 + + lambda_identity = alpha * np.eye(X.shape[1]) + inverse_term = np.linalg.inv(np.dot(X.T, X) + lambda_identity) + xt_y = np.dot(X.T, y) + coef = np.dot(inverse_term, xt_y) + + model = IncrementalRidge(fit_intercept=False, alpha=alpha) + for i in range(num_blocks): + model.partial_fit(X_split[i], y_split[i], queue=queue) + model.finalize_fit() + + if queue and queue.sycl_device.is_gpu: + tol = 5e-3 if model.coef_.dtype == np.float32 else 1e-5 + else: + tol = 2e-3 if model.coef_.dtype == np.float32 else 1e-5 + assert_allclose(coef, model.coef_.T, rtol=tol) + + Xt = gen.random(size=(n_test_samples, n_features), dtype=dtype) + gtr = Xt @ coef + + res = model.predict(Xt, queue=queue) + + tol = 2e-4 if res.dtype == np.float32 else 1e-7 + assert_allclose(gtr, res, rtol=tol) diff --git a/sklearnex/dispatcher.py b/sklearnex/dispatcher.py index 60c56b4564..44d9880793 100644 --- a/sklearnex/dispatcher.py +++ b/sklearnex/dispatcher.py @@ -155,6 +155,7 @@ def 
get_patch_map_core(preview=False):
     from .linear_model import (
         IncrementalLinearRegression as IncrementalLinearRegression_sklearnex,
     )
+    from .linear_model import IncrementalRidge as IncrementalRidge_sklearnex
     from .linear_model import Lasso as Lasso_sklearnex
     from .linear_model import LinearRegression as LinearRegression_sklearnex
     from .linear_model import LogisticRegression as LogisticRegression_sklearnex
@@ -412,6 +413,19 @@
         ]
     ]

+    if daal_check_version((2024, "P", 600)):
+        # IncrementalRidge
+        mapping["incrementalridge"] = [
+            [
+                (
+                    linear_model_module,
+                    "IncrementalRidge",
+                    IncrementalRidge_sklearnex,
+                ),
+                None,
+            ]
+        ]
+
     # Configs
     mapping["set_config"] = [
         [(base_module, "set_config", set_config_sklearnex), None]
diff --git a/sklearnex/linear_model/__init__.py b/sklearnex/linear_model/__init__.py
index 7c6ef5201b..2c9defc9e9 100755
--- a/sklearnex/linear_model/__init__.py
+++ b/sklearnex/linear_model/__init__.py
@@ -16,6 +16,7 @@
 from .coordinate_descent import ElasticNet, Lasso
 from .incremental_linear import IncrementalLinearRegression
+from .incremental_ridge import IncrementalRidge
 from .linear import LinearRegression
 from .logistic_regression import LogisticRegression
 from .ridge import Ridge
@@ -23,6 +24,7 @@
 __all__ = [
     "ElasticNet",
     "IncrementalLinearRegression",
+    "IncrementalRidge",
     "Lasso",
     "LinearRegression",
     "LogisticRegression",
diff --git a/sklearnex/linear_model/incremental_ridge.py b/sklearnex/linear_model/incremental_ridge.py
new file mode 100644
index 0000000000..99dc473456
--- /dev/null
+++ b/sklearnex/linear_model/incremental_ridge.py
@@ -0,0 +1,418 @@
+# ===============================================================================
+# Copyright 2024 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ===============================================================================
+
+import numbers
+import warnings
+
+import numpy as np
+from sklearn.base import BaseEstimator, MultiOutputMixin, RegressorMixin
+from sklearn.metrics import r2_score
+from sklearn.utils import gen_batches
+from sklearn.utils.validation import check_is_fitted, check_X_y
+
+from daal4py.sklearn._n_jobs_support import control_n_jobs
+from daal4py.sklearn.utils.validation import sklearn_check_version
+
+if sklearn_check_version("1.2"):
+    from sklearn.utils._param_validation import Interval
+
+from onedal.linear_model import IncrementalRidge as onedal_IncrementalRidge
+
+from .._device_offload import dispatch, wrap_output_data
+from .._utils import PatchingConditionsChain
+
+
+@control_n_jobs(
+    decorated_methods=["fit", "partial_fit", "predict", "_onedal_finalize_fit"]
+)
+class IncrementalRidge(MultiOutputMixin, RegressorMixin, BaseEstimator):
+    """
+    Incremental estimator for Ridge Regression.
+    Allows training a ridge regression model when the data is split into
+    batches.
+
+    Parameters
+    ----------
+    fit_intercept : bool, default=True
+        Whether to calculate the intercept for this model. If set
+        to False, no intercept will be used in calculations
+        (i.e. data is expected to be centered).
+
+    alpha : float, default=1.0
+        Regularization strength; must be a positive float. Regularization
+        improves the conditioning of the problem and reduces the variance of
+        the estimates. Larger values specify stronger regularization.
+
+    copy_X : bool, default=True
+        If True, X will be copied; else, it may be overwritten.
+
+    n_jobs : int, default=None
+        The number of jobs to use for the computation.
+
+    batch_size : int, default=None
+        The number of samples to use for each batch. Only used when calling
+        ``fit``. If ``batch_size`` is ``None``, then ``batch_size``
+        is inferred from the data and set to ``5 * n_features``, to provide a
+        balance between approximation accuracy and memory consumption.
+
+    Attributes
+    ----------
+    coef_ : array of shape (n_features,) or (n_targets, n_features)
+        Estimated coefficients for the ridge regression problem.
+        If multiple targets are passed during the fit (y 2D), this
+        is a 2D array of shape (n_targets, n_features), while if only
+        one target is passed, this is a 1D array of length n_features.
+
+    intercept_ : float or array of shape (n_targets,)
+        Independent term in the linear model. Set to 0.0 if
+        `fit_intercept = False`.
+
+    n_features_in_ : int
+        Number of features seen during :term:`fit`.
+
+    n_samples_seen_ : int
+        The number of samples processed by the estimator. Will be reset on
+        new calls to fit, but increments across ``partial_fit`` calls.
+        It must be at least `n_features_in_` (or `n_features_in_ + 1` when
+        `fit_intercept` is True) for regression coefficients to be computed.
+
+    batch_size_ : int
+        Inferred batch size from ``batch_size``.
+    """
+
+    _onedal_incremental_ridge = staticmethod(onedal_IncrementalRidge)
+
+    if sklearn_check_version("1.2"):
+        _parameter_constraints: dict = {
+            "fit_intercept": ["boolean"],
+            "alpha": [Interval(numbers.Real, 0, None, closed="left")],
+            "copy_X": ["boolean"],
+            "n_jobs": [Interval(numbers.Integral, -1, None, closed="left"), None],
+            "batch_size": [Interval(numbers.Integral, 1, None, closed="left"), None],
+        }
+
+    def __init__(
+        self, fit_intercept=True, alpha=1.0, copy_X=True, n_jobs=None, batch_size=None
+    ):
+        self.fit_intercept = fit_intercept
+        self.alpha = alpha
+        self.copy_X = copy_X
+        self.n_jobs = n_jobs
+        self.batch_size = batch_size
+
+    def _onedal_supported(self, method_name, *data):
+        patching_status = PatchingConditionsChain(
+            f"sklearn.linear_model.{self.__class__.__name__}.{method_name}"
+        )
+        return patching_status
+
+    _onedal_cpu_supported = _onedal_supported
+    _onedal_gpu_supported = _onedal_supported
+
+    def _onedal_predict(self, X, queue=None):
+        if sklearn_check_version("1.2"):
+            self._validate_params()
+
+        if sklearn_check_version("1.0"):
+            X = self._validate_data(X, accept_sparse=False, reset=False)
+
+        assert hasattr(self, "_onedal_estimator")
+        if self._need_to_finalize:
+            self._onedal_finalize_fit()
+        return self._onedal_estimator.predict(X, queue)
+
+    def _onedal_score(self, X, y, sample_weight=None, queue=None):
+        return r2_score(
+            y, self._onedal_predict(X, queue=queue), sample_weight=sample_weight
+        )
+
+    def _onedal_partial_fit(self, X, y, check_input=True, queue=None):
+        first_pass = not hasattr(self, "n_samples_seen_") or self.n_samples_seen_ == 0
+
+        if sklearn_check_version("1.2"):
+            self._validate_params()
+
+        if check_input:
+            if sklearn_check_version("1.0"):
+                X, y = self._validate_data(
+                    X,
+                    y,
+                    dtype=[np.float64, np.float32],
+                    reset=first_pass,
+                    copy=self.copy_X,
+                    multi_output=True,
+                    force_all_finite=False,
+                )
+            else:
+                check_X_y(X, y, multi_output=True, y_numeric=True)
+
+        if first_pass:
+            self.n_samples_seen_ = X.shape[0]
+            self.n_features_in_ = X.shape[1]
+        else:
+            self.n_samples_seen_ += X.shape[0]
+        onedal_params = {
+            "fit_intercept": self.fit_intercept,
+            "alpha": self.alpha,
+            "copy_X": self.copy_X,
+        }
+        if not hasattr(self, "_onedal_estimator"):
+            self._onedal_estimator = self._onedal_incremental_ridge(**onedal_params)
+        self._onedal_estimator.partial_fit(X, y, queue)
+        self._need_to_finalize = True
+
+    def _onedal_finalize_fit(self):
+        assert hasattr(self, "_onedal_estimator")
+        is_underdetermined = self.n_samples_seen_ < self.n_features_in_ + int(
+            self.fit_intercept
+        )
+        if is_underdetermined:
+            raise ValueError("Not enough samples to finalize")
+        self._onedal_estimator.finalize_fit()
+        self._save_attributes()
+        self._need_to_finalize = False
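+
+    # The full-batch path below reuses the incremental primitives: reset any
+    # prior partial state, stream the data through _onedal_partial_fit in
+    # gen_batches minibatches (batch_size_ defaults to 5 * n_features), then
+    # finalize once at the end.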
+    def _onedal_fit(self, X, y, queue=None):
+        if sklearn_check_version("1.2"):
+            self._validate_params()
+
+        if sklearn_check_version("1.0"):
+            X, y = self._validate_data(
+                X,
+                y,
+                dtype=[np.float64, np.float32],
+                copy=self.copy_X,
+                multi_output=True,
+                ensure_2d=True,
+            )
+        else:
+            check_X_y(X, y, multi_output=True, y_numeric=True)
+
+        n_samples, n_features = X.shape
+
+        is_underdetermined = n_samples < n_features + int(self.fit_intercept)
+        if is_underdetermined:
+            raise ValueError("Not enough samples to run oneDAL backend")
+
+        if self.batch_size is None:
+            self.batch_size_ = 5 * n_features
+        else:
+            self.batch_size_ = self.batch_size
+
+        self.n_samples_seen_ = 0
+        if hasattr(self, "_onedal_estimator"):
+            self._onedal_estimator._reset()
+
+        for batch in gen_batches(n_samples, self.batch_size_):
+            X_batch, y_batch = X[batch], y[batch]
+            self._onedal_partial_fit(X_batch, y_batch, check_input=False, queue=queue)
+
+        # finite check occurs on onedal side
+        self.n_features_in_ = n_features
+
+        if n_samples == 1:
+            warnings.warn(
+                "Only one sample available. You may want to reshape your data array"
+            )
+
+        self._onedal_finalize_fit()
+
+        return self
+
+    def partial_fit(self, X, y, check_input=True):
+        """
+        Incrementally fits the linear model with X and y. All of X and y is
+        processed as a single batch.
+
+        Parameters
+        ----------
+        X : array-like of shape (n_samples, n_features)
+            Training data, where `n_samples` is the number of samples and
+            `n_features` is the number of features.
+
+        y : array-like of shape (n_samples,) or (n_samples, n_targets)
+            Target values, where `n_samples` is the number of samples and
+            `n_targets` is the number of targets.
+
+        Returns
+        -------
+        self : object
+            Returns the instance itself.
+        """
+
+        dispatch(
+            self,
+            "partial_fit",
+            {
+                "onedal": self.__class__._onedal_partial_fit,
+                "sklearn": None,
+            },
+            X,
+            y,
+            check_input=check_input,
+        )
+        return self
+
+    def fit(self, X, y):
+        """
+        Fit the model with X and y, using minibatches of size batch_size.
+
+        Parameters
+        ----------
+        X : array-like of shape (n_samples, n_features)
+            Training data, where `n_samples` is the number of samples and
+            `n_features` is the number of features. `n_samples` must be at
+            least `n_features` (or `n_features + 1` when `fit_intercept`
+            is True) for the coefficients to be computed.
+
+        y : array-like of shape (n_samples,) or (n_samples, n_targets)
+            Target values, where `n_samples` is the number of samples and
+            `n_targets` is the number of targets.
+
+        Returns
+        -------
+        self : object
+            Returns the instance itself.
+        """
+
+        dispatch(
+            self,
+            "fit",
+            {
+                "onedal": self.__class__._onedal_fit,
+                "sklearn": None,
+            },
+            X,
+            y,
+        )
+        return self
+
+    @wrap_output_data
+    def predict(self, X, y=None):
+        """
+        Predict using the linear model.
+
+        Parameters
+        ----------
+        X : {array-like, sparse matrix} of shape (n_samples, n_features)
+            Samples.
+
+        Returns
+        -------
+        array, shape (n_samples,) or (n_samples, n_targets)
+            Returns predicted values.
+        """
+        check_is_fitted(
+            self,
+            msg=f"This {self.__class__.__name__} instance is not fitted yet. Call 'fit' with appropriate arguments before using this estimator.",
+        )
+
+        return dispatch(
+            self,
+            "predict",
+            {
+                "onedal": self.__class__._onedal_predict,
+                "sklearn": None,
+            },
+            X,
+        )
+
+    @wrap_output_data
+    def score(self, X, y, sample_weight=None):
+        """
+        Return the coefficient of determination R^2 of the prediction.
+
+        The coefficient R^2 is defined as (1 - u/v), where u is the residual
+        sum of squares ((y_true - y_pred) ** 2).sum() and v is the total sum
+        of squares ((y_true - y_true.mean()) ** 2).sum().
+        The best possible score is 1.0 and it can be negative (because the
+        model can be arbitrarily worse). A constant model that always
+        predicts the expected value of y, disregarding the input features,
+        would get an R^2 score of 0.0.
+
+        Parameters
+        ----------
+        X : {array-like, sparse matrix} of shape (n_samples, n_features)
+            Test samples.
+
+        y : array-like of shape (n_samples,) or (n_samples, n_targets)
+            True values for X.
+
+        sample_weight : array-like of shape (n_samples,), default=None
+            Sample weights.
+
+        Returns
+        -------
+        score : float
+            R^2 of self.predict(X) w.r.t. y.
+        """
+        check_is_fitted(
+            self,
+            msg=f"This {self.__class__.__name__} instance is not fitted yet. 
Call 'fit' with appropriate arguments before using this estimator.", + ) + + return dispatch( + self, + "score", + { + "onedal": self.__class__._onedal_score, + "sklearn": None, + }, + X, + y, + sample_weight=sample_weight, + ) + + @property + def coef_(self): + if hasattr(self, "_onedal_estimator") and self._need_to_finalize: + self._onedal_finalize_fit() + + return self._coef + + @coef_.setter + def coef_(self, value): + if hasattr(self, "_onedal_estimator"): + self._onedal_estimator.coef_ = value + # checking if the model is already fitted and if so, deleting the model + if hasattr(self._onedal_estimator, "_onedal_model"): + del self._onedal_estimator._onedal_model + self._coef = value + + @property + def intercept_(self): + if hasattr(self, "_onedal_estimator") and self._need_to_finalize: + self._onedal_finalize_fit() + + return self._intercept + + @intercept_.setter + def intercept_(self, value): + if hasattr(self, "_onedal_estimator"): + self._onedal_estimator.intercept_ = value + # checking if the model is already fitted and if so, deleting the model + if hasattr(self._onedal_estimator, "_onedal_model"): + del self._onedal_estimator._onedal_model + self._intercept = value + + def _save_attributes(self): + self.n_features_in_ = self._onedal_estimator.n_features_in_ + self._coef = self._onedal_estimator.coef_ + self._intercept = self._onedal_estimator.intercept_ diff --git a/sklearnex/linear_model/tests/test_incremental_ridge.py b/sklearnex/linear_model/tests/test_incremental_ridge.py new file mode 100644 index 0000000000..adcd5349ed --- /dev/null +++ b/sklearnex/linear_model/tests/test_incremental_ridge.py @@ -0,0 +1,153 @@ +# =============================================================================== +# Copyright 2024 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# =============================================================================== + +from daal4py.sklearn._utils import daal_check_version + +if daal_check_version((2024, "P", 600)): + import numpy as np + import pytest + from numpy.testing import assert_allclose + from sklearn.exceptions import NotFittedError + + from onedal.tests.utils._dataframes_support import ( + _as_numpy, + _convert_to_dataframe, + get_dataframes_and_queues, + ) + from sklearnex.linear_model import IncrementalRidge + + def _compute_ridge_coefficients(X, y, alpha, fit_intercept): + coefficients_manual, intercept_manual = None, None + if fit_intercept: + X_mean = np.mean(X, axis=0) + y_mean = np.mean(y) + X_centered = X - X_mean + y_centered = y - y_mean + + X_with_intercept = np.hstack([np.ones((X.shape[0], 1)), X_centered]) + lambda_identity = alpha * np.eye(X_with_intercept.shape[1]) + inverse_term = np.linalg.inv( + np.dot(X_with_intercept.T, X_with_intercept) + lambda_identity + ) + xt_y = np.dot(X_with_intercept.T, y_centered) + coefficients_manual = np.dot(inverse_term, xt_y) + + intercept_manual = y_mean - np.dot(X_mean, coefficients_manual[1:]) + coefficients_manual = coefficients_manual[1:] + else: + lambda_identity = alpha * np.eye(X.shape[1]) + inverse_term = np.linalg.inv(np.dot(X.T, X) + lambda_identity) + xt_y = np.dot(X.T, y) + coefficients_manual = np.dot(inverse_term, xt_y) + + return coefficients_manual, intercept_manual + + @pytest.mark.parametrize("dataframe,queue", get_dataframes_and_queues()) + @pytest.mark.parametrize("batch_size", [10, 100, 1000]) + @pytest.mark.parametrize("alpha", [0.1, 0.5, 1.0]) + @pytest.mark.parametrize("fit_intercept", [True, False]) + def test_inc_ridge_fit_coefficients( + dataframe, queue, alpha, batch_size, fit_intercept + ): + sample_size, feature_size = 1000, 50 + X = np.random.rand(sample_size, feature_size) + y = np.random.rand(sample_size) + X_c = _convert_to_dataframe(X, sycl_queue=queue, target_df=dataframe) + y_c = _convert_to_dataframe(y, sycl_queue=queue, target_df=dataframe) + + inc_ridge = IncrementalRidge( + fit_intercept=fit_intercept, alpha=alpha, batch_size=batch_size + ) + inc_ridge.fit(X_c, y_c) + + coefficients_manual, intercept_manual = _compute_ridge_coefficients( + X, y, alpha, fit_intercept + ) + if fit_intercept: + assert_allclose(inc_ridge.intercept_, intercept_manual, rtol=1e-6, atol=1e-6) + + assert_allclose(inc_ridge.coef_, coefficients_manual, rtol=1e-6, atol=1e-6) + + @pytest.mark.parametrize("dataframe,queue", get_dataframes_and_queues()) + @pytest.mark.parametrize("batch_size", [2, 5]) + @pytest.mark.parametrize("alpha", [0.1, 0.5, 1.0]) + def test_inc_ridge_partial_fit_coefficients(dataframe, queue, alpha, batch_size): + sample_size, feature_size = 1000, 50 + X = np.random.rand(sample_size, feature_size) + y = np.random.rand(sample_size) + X_split = np.array_split(X, batch_size) + y_split = np.array_split(y, batch_size) + + inc_ridge = IncrementalRidge(fit_intercept=False, alpha=alpha) + + for batch_index in range(len(X_split)): + X_c = _convert_to_dataframe( + X_split[batch_index], sycl_queue=queue, target_df=dataframe + ) + y_c = _convert_to_dataframe( + y_split[batch_index], sycl_queue=queue, target_df=dataframe + ) + inc_ridge.partial_fit(X_c, y_c) + + lambda_identity = alpha * np.eye(X.shape[1]) + inverse_term = np.linalg.inv(np.dot(X.T, X) + lambda_identity) + xt_y = np.dot(X.T, y) + coefficients_manual = np.dot(inverse_term, xt_y) + + assert_allclose(inc_ridge.coef_, coefficients_manual, rtol=1e-6, atol=1e-6) + + def 
test_inc_ridge_score_before_fit(): + X = np.array([[1, 1], [1, 2], [2, 2], [2, 3]]) + y = np.dot(X, np.array([1, 2])) + 3 + inc_ridge = IncrementalRidge(alpha=0.5) + with pytest.raises(NotFittedError): + inc_ridge.score(X, y) + + def test_inc_ridge_predict_before_fit(): + X = np.array([[1, 1], [1, 2], [2, 2], [2, 3]]) + inc_ridge = IncrementalRidge(alpha=0.5) + with pytest.raises(NotFittedError): + inc_ridge.predict(X) + + def test_inc_ridge_score_after_fit(): + X = np.array([[1, 1], [1, 2], [2, 2], [2, 3]]) + y = np.dot(X, np.array([1, 2])) + 3 + inc_ridge = IncrementalRidge(alpha=0.5) + inc_ridge.fit(X, y) + assert inc_ridge.score(X, y) >= 0.97 + + @pytest.mark.parametrize("dataframe,queue", get_dataframes_and_queues()) + @pytest.mark.parametrize("fit_intercept", [True, False]) + def test_inc_ridge_predict_after_fit(dataframe, queue, fit_intercept): + sample_size, feature_size = 1000, 50 + X = np.random.rand(sample_size, feature_size) + y = np.random.rand(sample_size) + X_c = _convert_to_dataframe(X, sycl_queue=queue, target_df=dataframe) + y_c = _convert_to_dataframe(y, sycl_queue=queue, target_df=dataframe) + + inc_ridge = IncrementalRidge(fit_intercept=fit_intercept, alpha=0.5) + inc_ridge.fit(X_c, y_c) + + y_pred = inc_ridge.predict(X_c) + + coefficients_manual, intercept_manual = _compute_ridge_coefficients( + X, y, 0.5, fit_intercept + ) + y_pred_manual = np.dot(X, coefficients_manual) + if fit_intercept: + y_pred_manual += intercept_manual + + assert_allclose(_as_numpy(y_pred), y_pred_manual, rtol=1e-6, atol=1e-6) diff --git a/sklearnex/tests/test_memory_usage.py b/sklearnex/tests/test_memory_usage.py index b072fd7814..778f99d268 100644 --- a/sklearnex/tests/test_memory_usage.py +++ b/sklearnex/tests/test_memory_usage.py @@ -54,6 +54,7 @@ "IncrementalEmpiricalCovariance", # dataframe_f issues "IncrementalLinearRegression", # TODO fix memory leak issue in private CI for data_shape = (1000, 100), data_transform_function = dataframe_f "IncrementalPCA", # TODO fix memory leak issue in private CI for data_shape = (1000, 100), data_transform_function = dataframe_f + "IncrementalRidge", # TODO fix memory leak issue in private CI for data_shape = (1000, 100), data_transform_function = dataframe_f "LogisticRegression(solver='newton-cg')", # memory leak fortran (1000, 100) )
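
For reviewers, a minimal usage sketch of the estimator this patch adds (requires a oneDAL 2024.6+ build, matching the daal_check_version gates above; the data shapes and hyperparameters are illustrative only):

    import numpy as np
    from sklearnex.linear_model import IncrementalRidge

    X = np.random.rand(1000, 50)
    y = np.random.rand(1000)

    # Stream the data in two batches; reading coef_ afterwards triggers the
    # lazy finalize step through the coef_ property.
    model = IncrementalRidge(fit_intercept=True, alpha=0.5)
    for X_batch, y_batch in zip(np.array_split(X, 2), np.array_split(y, 2)):
        model.partial_fit(X_batch, y_batch)
    print(model.coef_.shape, model.intercept_)

    # Alternatively, fit() batches internally; batch_size_ defaults to
    # 5 * n_features when batch_size is None.
    model = IncrementalRidge(alpha=0.5, batch_size=200).fit(X, y)
    print(model.score(X, y))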