Skip to content

Commit

Permalink
Change naming for base class reference
Browse files Browse the repository at this point in the history
  • Loading branch information
olegkkruglov committed Sep 3, 2024
1 parent 3384c6d commit 7625457
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 7 deletions.
10 changes: 5 additions & 5 deletions onedal/spmd/covariance/incremental_covariance.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,19 +19,19 @@
from daal4py.sklearn._utils import get_dtype

from ...covariance import (
IncrementalEmpiricalCovariance as IncrementalEmpiricalCovariance_base,
IncrementalEmpiricalCovariance as base_IncrementalEmpiricalCovariance,
)
from ...datatypes import _convert_to_supported, to_table
from ...utils import _check_array
from .._base import BaseEstimatorSPMD


class IncrementalEmpiricalCovariance(
BaseEstimatorSPMD, IncrementalEmpiricalCovariance_base
BaseEstimatorSPMD, base_IncrementalEmpiricalCovariance
):
def _reset(self):
self._partial_result = super(
IncrementalEmpiricalCovariance_base, self
base_IncrementalEmpiricalCovariance, self
)._get_backend("covariance", None, "partial_compute_result")

def partial_fit(self, X, y=None, queue=None):
Expand Down Expand Up @@ -60,7 +60,7 @@ def partial_fit(self, X, y=None, queue=None):

self._queue = queue

policy = super(IncrementalEmpiricalCovariance_base, self)._get_policy(queue, X)
policy = super(base_IncrementalEmpiricalCovariance, self)._get_policy(queue, X)

X = _convert_to_supported(policy, X)

Expand All @@ -70,7 +70,7 @@ def partial_fit(self, X, y=None, queue=None):
params = self._get_onedal_params(self._dtype)
table_X = to_table(X)
self._partial_result = super(
IncrementalEmpiricalCovariance_base, self
base_IncrementalEmpiricalCovariance, self
)._get_backend(
"covariance",
None,
Expand Down
4 changes: 2 additions & 2 deletions sklearnex/spmd/covariance/incremental_covariance.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,11 @@
)

from ...covariance import (
IncrementalEmpiricalCovariance as IncrementalEmpiricalCovariance_base,
IncrementalEmpiricalCovariance as base_IncrementalEmpiricalCovariance,
)


class IncrementalEmpiricalCovariance(IncrementalEmpiricalCovariance_base):
class IncrementalEmpiricalCovariance(base_IncrementalEmpiricalCovariance):
_onedal_incremental_covariance = staticmethod(
onedalSPMD_IncrementalEmpiricalCovariance
)

0 comments on commit 7625457

Please sign in to comment.