ENH: Adding IncrementalRidge support into sklearnex #1957

Merged · 18 commits · Aug 27, 2024 · Changes from 17 commits
3 changes: 3 additions & 0 deletions deselected_tests.yaml
@@ -380,8 +380,11 @@ deselected_tests:
- tests/test_common.py::test_estimators[IncrementalLinearRegression()-check_estimators_pickle(readonly_memmap=True)]
- tests/test_common.py::test_estimators[IncrementalPCA()-check_estimators_pickle]
- tests/test_common.py::test_estimators[IncrementalPCA()-check_estimators_pickle(readonly_memmap=True)]
- tests/test_common.py::test_estimators[IncrementalRidge()-check_estimators_pickle]
- tests/test_common.py::test_estimators[IncrementalRidge()-check_estimators_pickle(readonly_memmap=True)]
# There is not enough data to run the onedal backend
- tests/test_common.py::test_estimators[IncrementalLinearRegression()-check_fit2d_1sample]
- tests/test_common.py::test_estimators[IncrementalRidge()-check_fit2d_1sample]

# Deselection of LogisticRegression tests over accuracy comparisons with sample_weights
# and without. Because scikit-learn-intelex does not support sample_weights, it's doing
3 changes: 2 additions & 1 deletion onedal/linear_model/__init__.py
@@ -14,12 +14,13 @@
# limitations under the License.
# ===============================================================================

from .incremental_linear_model import IncrementalLinearRegression
from .incremental_linear_model import IncrementalLinearRegression, IncrementalRidge
from .linear_model import LinearRegression, Ridge
from .logistic_regression import LogisticRegression

__all__ = [
"IncrementalLinearRegression",
"IncrementalRidge",
"LinearRegression",
"LogisticRegression",
"Ridge",
111 changes: 111 additions & 0 deletions onedal/linear_model/incremental_linear_model.py
@@ -144,3 +144,114 @@ def finalize_fit(self, queue=None):
self.intercept_ = self.intercept_[0]

return self


class IncrementalRidge(BaseLinearRegression):
"""
Incremental Ridge Regression oneDAL implementation.

Parameters
----------
alpha : float, default=1.0
Regularization strength; must be a positive float. Regularization
improves the conditioning of the problem and reduces the variance of
the estimates. Larger values specify stronger regularization.

fit_intercept : bool, default=True
Whether to calculate the intercept for this model. If set
to False, no intercept will be used in calculations
(i.e. data is expected to be centered).

copy_X : bool, default=False
If True, X will be copied; else, it may be overwritten.

algorithm : str, default="norm_eq"
Algorithm used for computation on the oneDAL side.
"""

def __init__(self, alpha=1.0, fit_intercept=True, copy_X=False, algorithm="norm_eq"):
module = self._get_backend("linear_model", "regression")
super().__init__(
fit_intercept=fit_intercept, alpha=alpha, copy_X=copy_X, algorithm=algorithm
)
self._partial_result = module.partial_train_result()

def _reset(self):
module = self._get_backend("linear_model", "regression")
self._partial_result = module.partial_train_result()

def partial_fit(self, X, y, queue=None):
"""
Computes partial data for ridge regression
from data batch X and saves it to `_partial_result`.
Parameters
----------
X : array-like of shape (n_samples, n_features)
Training data batch, where `n_samples` is the number of samples
in the batch, and `n_features` is the number of features.

y: array-like of shape (n_samples,) or (n_samples, n_targets) in
case of multiple targets
Responses for training data.

queue : dpctl.SyclQueue
If not None, use this queue for computations.
Returns
-------
self : object
Returns the instance itself.
"""
module = self._get_backend("linear_model", "regression")

if not hasattr(self, "_queue"):
self._queue = queue
policy = self._get_policy(queue, X)

X, y = _convert_to_supported(policy, X, y)

if not hasattr(self, "_dtype"):
self._dtype = get_dtype(X)
self._params = self._get_onedal_params(self._dtype)

y = np.asarray(y).astype(dtype=self._dtype)
self._y_ndim_1 = y.ndim == 1

X, y = _check_X_y(X, y, dtype=[np.float64, np.float32], accept_2d_y=True)

self.n_features_in_ = _num_features(X, fallback_1d=True)
X_table, y_table = to_table(X, y)
self._partial_result = module.partial_train(
policy, self._params, self._partial_result, X_table, y_table
)

return self

def finalize_fit(self, queue=None):
"""
Finalizes ridge regression computation and obtains coefficients
from the current `_partial_result`.

Parameters
----------
queue : dpctl.SyclQueue
If not None, use this queue for computations; otherwise the queue
from the last `partial_fit` call is used.

Returns
-------
self : object
Returns the instance itself.
"""
module = self._get_backend("linear_model", "regression")
if queue is not None:
policy = self._get_policy(queue)
else:
policy = self._get_policy(self._queue)
result = module.finalize_train(policy, self._params, self._partial_result)

self._onedal_model = result.model

packed_coefficients = from_table(result.model.packed_coefficients)
self.coef_, self.intercept_ = (
packed_coefficients[:, 1:].squeeze(),
packed_coefficients[:, 0].squeeze(),
)

return self
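
For orientation, a minimal usage sketch of the onedal-level estimator added above, mirroring the tests in the next file; the random data, batch count, and alpha value are illustrative assumptions:

import numpy as np

from onedal.linear_model import IncrementalRidge

rng = np.random.default_rng(0)
X = rng.random((100, 5))
y = rng.random(100)

model = IncrementalRidge(fit_intercept=True, alpha=0.5)
for X_batch, y_batch in zip(np.array_split(X, 4), np.array_split(y, 4)):
    # Each batch is folded into the accumulated oneDAL partial result.
    model.partial_fit(X_batch, y_batch)
# Coefficients are computed once, from the accumulated partial result.
model.finalize_fit()

y_pred = model.predict(X)
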
107 changes: 107 additions & 0 deletions onedal/linear_model/tests/test_incremental_ridge_regression.py
@@ -0,0 +1,107 @@
# ==============================================================================
# Copyright 2024 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

from daal4py.sklearn._utils import daal_check_version

if daal_check_version((2024, "P", 600)):
import numpy as np
import pytest
from numpy.testing import assert_allclose, assert_array_equal
from sklearn.datasets import load_diabetes
from sklearn.metrics import mean_squared_error
from sklearn.model_selection import train_test_split

from onedal.linear_model import IncrementalRidge
from onedal.tests.utils._device_selection import get_queues

@pytest.mark.parametrize("queue", get_queues())
@pytest.mark.parametrize("dtype", [np.float32, np.float64])
def test_diabetes(queue, dtype):
X, y = load_diabetes(return_X_y=True)
X, y = X.astype(dtype), y.astype(dtype)
X_train, X_test, y_train, y_test = train_test_split(
X, y, train_size=0.8, random_state=777
)
X_train_split = np.array_split(X_train, 2)
y_train_split = np.array_split(y_train, 2)
model = IncrementalRidge(fit_intercept=True, alpha=0.1)
for i in range(2):
model.partial_fit(X_train_split[i], y_train_split[i], queue=queue)
model.finalize_fit()
y_pred = model.predict(X_test, queue=queue)
assert_allclose(mean_squared_error(y_test, y_pred), 2388.775, rtol=1e-5)

@pytest.mark.parametrize("queue", get_queues())
@pytest.mark.parametrize("dtype", [np.float32, np.float64])
@pytest.mark.skip(reason="pickling not implemented for oneDAL entities")
def test_pickle(queue, dtype):
Contributor review comment: Just make sure there is a ticket or note in another ticket to remove this skip when serialization is implemented for IncrementalRidge.

# TODO Implement pickling for oneDAL entities
X, y = load_diabetes(return_X_y=True)
X, y = X.astype(dtype), y.astype(dtype)
model = IncrementalRidge(fit_intercept=True, alpha=0.5)
model.partial_fit(X, y, queue=queue)
model.finalize_fit()
expected = model.predict(X, queue=queue)

import pickle

dump = pickle.dumps(model)
model2 = pickle.loads(dump)

CodeFactor check warning (B301) on line 62 of onedal/linear_model/tests/test_incremental_ridge_regression.py: Pickle and modules that wrap it can be unsafe when used to deserialize untrusted data, possible security issue.

assert isinstance(model2, model.__class__)
result = model2.predict(X, queue=queue)

assert_array_equal(expected, result)

@pytest.mark.parametrize("queue", get_queues())
@pytest.mark.parametrize("num_blocks", [1, 2, 10])
@pytest.mark.parametrize("dtype", [np.float32, np.float64])
def test_no_intercept_results(queue, num_blocks, dtype):
seed = 42
n_features, n_targets = 19, 7
n_train_samples, n_test_samples = 3500, 1999

gen = np.random.default_rng(seed)

X = gen.random(size=(n_train_samples, n_features), dtype=dtype)
y = gen.random(size=(n_train_samples, n_targets), dtype=dtype)
X_split = np.array_split(X, num_blocks)
y_split = np.array_split(y, num_blocks)
alpha = 0.5

lambda_identity = alpha * np.eye(X.shape[1])
inverse_term = np.linalg.inv(np.dot(X.T, X) + lambda_identity)
xt_y = np.dot(X.T, y)
coef = np.dot(inverse_term, xt_y)

model = IncrementalRidge(fit_intercept=False, alpha=alpha)
for i in range(num_blocks):
model.partial_fit(X_split[i], y_split[i], queue=queue)
model.finalize_fit()

if queue and queue.sycl_device.is_gpu:
tol = 5e-3 if model.coef_.dtype == np.float32 else 1e-5
else:
tol = 2e-3 if model.coef_.dtype == np.float32 else 1e-5
assert_allclose(coef, model.coef_.T, rtol=tol)

Xt = gen.random(size=(n_test_samples, n_features), dtype=dtype)
gtr = Xt @ coef

res = model.predict(Xt, queue=queue)

tol = 2e-4 if res.dtype == np.float32 else 1e-7
assert_allclose(gtr, res, rtol=tol)
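
The NumPy reference in test_no_intercept_results is the closed-form ridge solution for the no-intercept, multi-target case; in the test's names (lambda_identity, inverse_term, xt_y) it computes

\hat{W} = \left(X^{\top} X + \alpha I\right)^{-1} X^{\top} Y

where X is the (n_train_samples, n_features) data matrix, Y the (n_train_samples, n_targets) response matrix, and I the n_features-dimensional identity; the incremental estimator must reproduce this solution regardless of how the rows are split into num_blocks batches.
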
14 changes: 14 additions & 0 deletions sklearnex/dispatcher.py
@@ -155,6 +155,7 @@ def get_patch_map_core(preview=False):
from .linear_model import (
IncrementalLinearRegression as IncrementalLinearRegression_sklearnex,
)
from .linear_model import IncrementalRidge as IncrementalRidge_sklearnex
from .linear_model import Lasso as Lasso_sklearnex
from .linear_model import LinearRegression as LinearRegression_sklearnex
from .linear_model import LogisticRegression as LogisticRegression_sklearnex
@@ -412,6 +413,19 @@ def get_patch_map_core(preview=False):
]
]

if daal_check_version((2024, "P", 600)):
# IncrementalRidge
mapping["incrementalridge"] = [
[
(
linear_model_module,
"IncrementalRidge",
IncrementalRidge_sklearnex,
),
None,
]
]

# Configs
mapping["set_config"] = [
[(base_module, "set_config", set_config_sklearnex), None]
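
A hedged sketch of what this registration enables: patch_sklearn() applies every entry in the patch map, and the mapping tuple above targets sklearn's linear_model module, so the accelerated estimator should become reachable there once the oneDAL >= 2024.6 gate passes. The injected attribute below is an assumption read off the mapping structure, not something stated in this diff:

from sklearnex import patch_sklearn

# Applies the patch map built in get_patch_map_core(), including the
# "incrementalridge" entry added above (when the oneDAL version gate passes).
patch_sklearn()

# Assumption: estimators registered against sklearn.linear_model in the
# mapping become importable from the patched module.
from sklearn.linear_model import IncrementalRidge
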
2 changes: 2 additions & 0 deletions sklearnex/linear_model/__init__.py
@@ -16,13 +16,15 @@

from .coordinate_descent import ElasticNet, Lasso
from .incremental_linear import IncrementalLinearRegression
from .incremental_ridge import IncrementalRidge
from .linear import LinearRegression
from .logistic_regression import LogisticRegression
from .ridge import Ridge

__all__ = [
"ElasticNet",
"IncrementalLinearRegression",
"IncrementalRidge",
"Lasso",
"LinearRegression",
"LogisticRegression",
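
Independent of patching, the export added here makes the estimator importable straight from sklearnex. A minimal sketch with made-up data, assuming the sklearnex wrapper finalizes lazily on predict, as the other sklearnex incremental estimators do:

import numpy as np

from sklearnex.linear_model import IncrementalRidge

X = np.array([[1.0, 1.0], [2.0, 1.0], [3.0, 2.0], [4.0, 3.0]])
y = np.array([2.0, 3.0, 5.0, 7.0])

est = IncrementalRidge(alpha=1.0)
est.partial_fit(X[:2], y[:2])  # first batch
est.partial_fit(X[2:], y[2:])  # second batch
print(est.predict(X))  # assumed to trigger finalization before predicting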