Skip to content

Commit

Permalink
Rename class reference
Browse files Browse the repository at this point in the history
  • Loading branch information
olegkkruglov committed Aug 30, 2024
1 parent 88307b3 commit b10b58a
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 11 deletions.
16 changes: 7 additions & 9 deletions onedal/spmd/basic_statistics/incremental_basic_statistics.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,17 +17,17 @@
from daal4py.sklearn._utils import get_dtype

from ...basic_statistics import (
IncrementalBasicStatistics as IncrementalBasicStatistics_nonSPMD,
IncrementalBasicStatistics as IncrementalBasicStatistics_base,
)
from ...datatypes import _convert_to_supported, to_table
from .._base import BaseEstimatorSPMD


class IncrementalBasicStatistics(BaseEstimatorSPMD, IncrementalBasicStatistics_nonSPMD):
class IncrementalBasicStatistics(BaseEstimatorSPMD, IncrementalBasicStatistics_base):
def _reset(self):
self._partial_result = super(
IncrementalBasicStatistics_nonSPMD, self
)._get_backend("basic_statistics", None, "partial_compute_result")
self._partial_result = super(IncrementalBasicStatistics_base, self)._get_backend(
"basic_statistics", None, "partial_compute_result"
)

def partial_fit(self, X, weights=None, queue=None):
"""
Expand All @@ -50,17 +50,15 @@ def partial_fit(self, X, weights=None, queue=None):
"""
if not hasattr(self, "_queue"):
self._queue = queue
policy = super(IncrementalBasicStatistics_nonSPMD, self)._get_policy(queue, X)
policy = super(IncrementalBasicStatistics_base, self)._get_policy(queue, X)
X, weights = _convert_to_supported(policy, X, weights)

if not hasattr(self, "_onedal_params"):
dtype = get_dtype(X)
self._onedal_params = self._get_onedal_params(False, dtype=dtype)

X_table, weights_table = to_table(X, weights)
self._partial_result = super(
IncrementalBasicStatistics_nonSPMD, self
)._get_backend(
self._partial_result = super(IncrementalBasicStatistics_base, self)._get_backend(
"basic_statistics",
None,
"partial_compute",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,11 @@
)

from ...basic_statistics import (
IncrementalBasicStatistics as IncrementalBasicStatistics_nonSPMD,
IncrementalBasicStatistics as IncrementalBasicStatistics_base,
)


class IncrementalBasicStatistics(IncrementalBasicStatistics_nonSPMD):
class IncrementalBasicStatistics(IncrementalBasicStatistics_base):
_onedal_incremental_basic_statistics = staticmethod(
onedalSPMD_IncrementalBasicStatistics
)

0 comments on commit b10b58a

Please sign in to comment.