Skip to content

Commit

Permalink
Fix lint
Browse files Browse the repository at this point in the history
  • Loading branch information
olegkkruglov committed Jul 16, 2024
1 parent ee4af1b commit 2a3fcd5
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 8 deletions.
6 changes: 4 additions & 2 deletions examples/sklearnex/incremental_covariance_spmd.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,13 @@

from sklearnex.spmd.covariance import IncrementalEmpiricalCovariance


def get_local_data(data, comm):
rank = comm.Get_rank()
num_ranks = comm.Get_size()
local_size = (data.shape[0] + num_ranks - 1) // num_ranks
return data[rank * local_size: (rank + 1) * local_size]
return data[rank * local_size : (rank + 1) * local_size]


q = dpctl.SyclQueue("gpu")
comm = MPI.COMM_WORLD
Expand All @@ -45,4 +47,4 @@ def get_local_data(data, comm):
dpt_X = dpt.asarray(X_split[i], usm_type="device", sycl_queue=q)
cov.partial_fit(dpt_X)

print(f"Computed covariance values on rank {comm.Get_rank()}:\n", cov.covariance_)
print(f"Computed covariance values on rank {comm.Get_rank()}:\n", cov.covariance_)
11 changes: 7 additions & 4 deletions onedal/spmd/covariance/incremental_covariance.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,17 +14,20 @@
# limitations under the License.
# ==============================================================================

from ...covariance import IncrementalEmpiricalCovariance as IncrementalEmpiricalCovariance_Batch

from ..._device_offload import support_usm_ndarray
from ...covariance import (
IncrementalEmpiricalCovariance as IncrementalEmpiricalCovariance_Batch,
)
from .._base import BaseEstimatorSPMD


class IncrementalEmpiricalCovariance(BaseEstimatorSPMD, IncrementalEmpiricalCovariance_Batch):
class IncrementalEmpiricalCovariance(
BaseEstimatorSPMD, IncrementalEmpiricalCovariance_Batch
):
@support_usm_ndarray()
def partial_fit(self, X, y=None, queue=None):
return super().partial_fit(X, queue=queue)

@support_usm_ndarray()
def finalize_fit(self, queue=None):
return super().finalize_fit(queue=queue)
return super().finalize_fit(queue=queue)
10 changes: 8 additions & 2 deletions sklearnex/spmd/covariance/incremental_covariance.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,14 @@
# limitations under the License.
# ==============================================================================

from ...covariance import IncrementalEmpiricalCovariance as IncrementalEmpiricalCovariance_batch
from onedal.spmd.covariance import IncrementalEmpiricalCovariance as onedal_IncrementalEmpiricalCovariance
from onedal.spmd.covariance import (
IncrementalEmpiricalCovariance as onedal_IncrementalEmpiricalCovariance,
)

from ...covariance import (
IncrementalEmpiricalCovariance as IncrementalEmpiricalCovariance_batch,
)


class IncrementalEmpiricalCovariance(IncrementalEmpiricalCovariance_batch):
_onedal_incremental_covariance = staticmethod(onedal_IncrementalEmpiricalCovariance)

0 comments on commit 2a3fcd5

Please sign in to comment.