Commit: change incremental algo

icfaust committed Dec 3, 2024
1 parent 6579bbd commit 580e697
Showing 2 changed files with 20 additions and 36 deletions.
14 changes: 4 additions & 10 deletions onedal/covariance/incremental_covariance.py
@@ -13,12 +13,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ===============================================================================
-import numpy as np
 
-from daal4py.sklearn._utils import daal_check_version, get_dtype
+from daal4py.sklearn._utils import daal_check_version
 
 from ..datatypes import _convert_to_supported, from_table, to_table
-from ..utils import _check_array
 from .covariance import BaseEmpiricalCovariance


@@ -95,27 +93,23 @@ def partial_fit(self, X, y=None, queue=None):
         self : object
             Returns the instance itself.
         """
-        X = _check_array(X, dtype=[np.float64, np.float32], ensure_2d=True)
-
         self._queue = queue
 
         policy = self._get_policy(queue, X)
 
-        X = _convert_to_supported(policy, X)
-
+        X_table = to_table(_convert_to_supported(policy, X))
         if not hasattr(self, "_dtype"):
-            self._dtype = get_dtype(X)
+            self._dtype = X_table.dtype
 
         params = self._get_onedal_params(self._dtype)
-        table_X = to_table(X)
         self._partial_result = self._get_backend(
             "covariance",
             None,
             "partial_compute",
             policy,
             params,
             self._partial_result,
-            table_X,
+            X_table,
         )
         self._need_to_finalize = True

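Note on the hunk above: conversion and table creation are fused into to_table(_convert_to_supported(policy, X)), and the compute dtype is then read from the resulting table instead of calling get_dtype on the host array; input validation leaves the onedal layer entirely. A minimal sketch of the new flow, using hypothetical stand-ins for oneDAL's conversion helper and table type (the real ones live in onedal.datatypes):

import numpy as np

# Hypothetical stand-ins, only to illustrate the reordered flow.
def _convert_to_supported(policy, X):
    # e.g. cast to a dtype the target device supports
    return np.asarray(X, dtype=np.float64)

class _Table:
    def __init__(self, arr):
        self.dtype = arr.dtype  # dtype is now read from the table itself

def to_table(X):
    return _Table(X)

X = np.arange(6, dtype=np.float32).reshape(3, 2)
X_table = to_table(_convert_to_supported(None, X))
assert X_table.dtype == np.float64  # post-conversion dtype, as oneDAL will compute in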
42 changes: 16 additions & 26 deletions sklearnex/covariance/incremental_covariance.py
@@ -30,20 +30,14 @@
 from onedal.covariance import (
     IncrementalEmpiricalCovariance as onedal_IncrementalEmpiricalCovariance,
 )
+from onedal.utils._array_api import _is_numpy_namespace
 from sklearnex import config_context
 
 from .._device_offload import dispatch, wrap_output_data
 from .._utils import IntelEstimator, PatchingConditionsChain, register_hyperparameters
 from ..metrics import pairwise_distances
 from ..utils._array_api import get_namespace
 
 if sklearn_check_version("1.2"):
     from sklearn.utils._param_validation import Interval
 
-if sklearn_check_version("1.6"):
-    from sklearn.utils.validation import validate_data
-else:
-    validate_data = BaseEstimator._validate_data
+from ..utils.validation import validate_data


@control_n_jobs(decorated_methods=["partial_fit", "fit", "_onedal_finalize_fit"])
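The version-gated import of validate_data (a free function only since scikit-learn 1.6) is replaced by one import from sklearnex's own ..utils.validation. A minimal sketch of what such a compatibility shim presumably looks like; the actual module may differ:

from daal4py.sklearn._utils import sklearn_check_version

if sklearn_check_version("1.6"):
    # scikit-learn >= 1.6 exposes validate_data as a free function
    from sklearn.utils.validation import validate_data
else:
    from sklearn.base import BaseEstimator

    def validate_data(estimator, /, *args, **kwargs):
        # older releases only offer the estimator-method form
        return BaseEstimator._validate_data(estimator, *args, **kwargs)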
@@ -152,8 +146,9 @@ def _onedal_finalize_fit(self, queue=None):

         if not daal_check_version((2024, "P", 400)) and self.assume_centered:
             location = self._onedal_estimator.location_[None, :]
-            self._onedal_estimator.covariance_ += np.dot(location.T, location)
-            self._onedal_estimator.location_ = np.zeros_like(np.squeeze(location))
+            lp, _ = get_namespace(location)
+            self._onedal_estimator.covariance_ += lp.dot(location.T, location)
+            self._onedal_estimator.location_ = lp.zeros_like(lp.squeeze(location))
         if self.store_precision:
             self.precision_ = linalg.pinvh(
                 self._onedal_estimator.covariance_, check_finite=False
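For context, the assume_centered branch rebuilds the second moment about zero from the mean-centered covariance via the identity X.T @ X / n = Cov + mu.T @ mu; the change only swaps np.* calls for the lp namespace returned by get_namespace, so the same arithmetic runs on non-numpy arrays. A quick numpy check of that identity (illustration only):

import numpy as np

rng = np.random.default_rng(0)
X = rng.standard_normal((100, 3))

mu = X.mean(axis=0)[None, :]                  # row vector, like location_
cov_centered = (X - mu).T @ (X - mu) / len(X)
second_moment = X.T @ X / len(X)              # covariance "about zero"

# mirrors: covariance_ += lp.dot(location.T, location)
np.testing.assert_allclose(cov_centered + mu.T @ mu, second_moment)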
@@ -187,26 +182,24 @@ def _onedal_partial_fit(self, X, queue=None, check_input=True):

         first_pass = not hasattr(self, "n_samples_seen_") or self.n_samples_seen_ == 0
 
-        # finite check occurs on onedal side
         if check_input:
+            xp, _ = get_namespace(X)
             if sklearn_check_version("1.2"):
                 self._validate_params()
 
             if sklearn_check_version("1.0"):
                 X = validate_data(
                     self,
                     X,
-                    dtype=[np.float64, np.float32],
+                    dtype=[xp.float64, xp.float32],
                     reset=first_pass,
                     copy=self.copy,
-                    force_all_finite=False,
                 )
             else:
                 X = check_array(
                     X,
-                    dtype=[np.float64, np.float32],
+                    dtype=[xp.float64, xp.float32],
                     copy=self.copy,
-                    force_all_finite=False,
                 )
 
         onedal_params = {
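get_namespace returns the array-API namespace of its argument, so dtype=[xp.float64, xp.float32] spells the same validation request for numpy, dpnp, or dpctl inputs. A sketch with a numpy input (the only backend assumed available here):

import numpy as np
from sklearnex.utils._array_api import get_namespace

X = np.ones((4, 2), dtype=np.float32)
xp, _ = get_namespace(X)           # numpy-compatible namespace for a numpy array
print(xp.float32 == np.float32)    # True: dtypes come from X's own namespace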
@@ -239,16 +232,16 @@ def score(self, X_test, y=None):
             X = validate_data(
                 self,
                 X_test,
-                dtype=[np.float64, np.float32],
+                dtype=[xp.float64, xp.float32],
                 reset=False,
             )
         else:
             X = check_array(
                 X_test,
-                dtype=[np.float64, np.float32],
+                dtype=[xp.float64, xp.float32],
             )
 
-        if "numpy" not in xp.__name__:
+        if not _is_numpy_namespace(xp):
             location = xp.asarray(location, device=X_test.device)
         # depending on the sklearn version, check_array
         # and validate_data will return only numpy arrays
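Here _is_numpy_namespace replaces the ad-hoc "numpy" not in xp.__name__ string test, and for non-numpy inputs the location_ attribute is copied onto the same device as X_test before use. A sketch of that branch, assuming dpnp is installed (any array-API library whose arrays carry a .device attribute works the same way):

import numpy as np
import dpnp  # assumption: a SYCL-backed array library is available
from onedal.utils._array_api import _is_numpy_namespace
from sklearnex.utils._array_api import get_namespace

X_test = dpnp.ones((4, 2), dtype=dpnp.float32)  # data on a SYCL device
location = np.zeros((1, 2))                     # estimator attribute on the host

xp, _ = get_namespace(X_test)
if not _is_numpy_namespace(xp):
    # move the host-side attribute next to the data before computing with it
    location = xp.asarray(location, device=X_test.device)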
@@ -337,19 +330,16 @@ def _onedal_fit(self, X, queue=None):
         if sklearn_check_version("1.2"):
             self._validate_params()
 
-        # finite check occurs on onedal side
+        xp, _ = get_namespace(X)
         if sklearn_check_version("1.0"):
             X = validate_data(
                 self,
                 X,
-                dtype=[np.float64, np.float32],
+                dtype=[xp.float64, xp.float32],
                 copy=self.copy,
-                force_all_finite=False,
             )
         else:
-            X = check_array(
-                X, dtype=[np.float64, np.float32], copy=self.copy, force_all_finite=False
-            )
+            X = check_array(X, dtype=[xp.float64, xp.float32], copy=self.copy)
         self.n_features_in_ = X.shape[1]
 
         self.batch_size_ = self.batch_size if self.batch_size else 5 * self.n_features_in_
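As the trailing context line shows, fit falls back to batch_size_ = 5 * n_features_in_ (the same default sklearn's IncrementalPCA uses) and then feeds the data to the incremental backend chunk by chunk. A minimal sketch of that driver loop, with a generic partial_fit standing in for the estimator's own method:

def fit_in_batches(estimator, X, batch_size=None):
    # default mirrors the line above: five samples per feature
    n_samples, n_features = X.shape
    batch_size = batch_size or 5 * n_features
    for start in range(0, n_samples, batch_size):
        estimator.partial_fit(X[start : start + batch_size])
    return estimator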
@@ -378,8 +368,8 @@ def mahalanobis(self, X):
         # pairwise_distances will check n_features (via n_features matching with
         # self.location_), and will check for finiteness via check_array;
         # check_feature_names will match _validate_data functionally
-        location = self.location_[np.newaxis, :]
-        if "numpy" not in xp.__name__:
+        location = self.location_[None, :]
+        if not _is_numpy_namespace(xp):
             # Guarantee that inputs to pairwise_distances match in type and location
             location = xp.asarray(location, device=X.device)