Skip to content

Commit

Permalink
[Enhancement] Upgrade IncrementalLinearRegression for underdetermined…
Browse files Browse the repository at this point in the history
… systems (#2175)

* Update incremental_linear.py

* Update incremental_linear.py

* Update incremental_linear.py

* formatting

* add import

* Update incremental_linear.py

* Update deselected_tests.yaml

* Update incremental_linear.py

* formatting

* Update incremental_linear.py
  • Loading branch information
icfaust authored Dec 13, 2024
1 parent c631127 commit c1229d9
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 10 deletions.
1 change: 0 additions & 1 deletion deselected_tests.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -299,7 +299,6 @@ deselected_tests:
- tests/test_common.py::test_estimators[IncrementalRidge()-check_estimators_pickle]
- tests/test_common.py::test_estimators[IncrementalRidge()-check_estimators_pickle(readonly_memmap=True)]
# There are not enough data to run onedal backend
- tests/test_common.py::test_estimators[IncrementalLinearRegression()-check_fit2d_1sample]
- tests/test_common.py::test_estimators[IncrementalRidge()-check_fit2d_1sample]

# Deselection of LogisticRegression tests over accuracy comparisons with sample_weights
Expand Down
24 changes: 15 additions & 9 deletions sklearnex/linear_model/incremental_linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from sklearn.utils.validation import check_is_fitted

from daal4py.sklearn._n_jobs_support import control_n_jobs
from daal4py.sklearn._utils import sklearn_check_version
from daal4py.sklearn._utils import daal_check_version, sklearn_check_version
from onedal.linear_model import (
IncrementalLinearRegression as onedal_IncrementalLinearRegression,
)
Expand Down Expand Up @@ -221,13 +221,21 @@ def _onedal_partial_fit(self, X, y, check_input=True, queue=None):
self._onedal_estimator.partial_fit(X, y, queue=queue)
self._need_to_finalize = True

if daal_check_version((2025, "P", 200)):

def _onedal_validate_underdetermined(self, n_samples, n_features):
pass

else:

def _onedal_validate_underdetermined(self, n_samples, n_features):
is_underdetermined = n_samples < n_features + int(self.fit_intercept)
if is_underdetermined:
raise ValueError("Not enough samples for oneDAL")

def _onedal_finalize_fit(self, queue=None):
assert hasattr(self, "_onedal_estimator")
is_underdetermined = self.n_samples_seen_ < self.n_features_in_ + int(
self.fit_intercept
)
if is_underdetermined:
raise ValueError("Not enough samples to finalize")
self._onedal_validate_underdetermined(self.n_samples_seen_, self.n_features_in_)
self._onedal_estimator.finalize_fit(queue=queue)
self._need_to_finalize = False

Expand Down Expand Up @@ -260,9 +268,7 @@ def _onedal_fit(self, X, y, queue=None):

n_samples, n_features = X.shape

is_underdetermined = n_samples < n_features + int(self.fit_intercept)
if is_underdetermined:
raise ValueError("Not enough samples to run oneDAL backend")
self._onedal_validate_underdetermined(n_samples, n_features)

if self.batch_size is None:
self.batch_size_ = 5 * n_features
Expand Down

0 comments on commit c1229d9

Please sign in to comment.