Skip to content

Commit

Permalink
Merge branch 'intel:main' into enh/functional_array_api
Browse files Browse the repository at this point in the history
  • Loading branch information
samir-nasibli authored Sep 4, 2024
2 parents fe38790 + 138ed18 commit 3dd9521
Show file tree
Hide file tree
Showing 14 changed files with 398 additions and 55 deletions.
58 changes: 58 additions & 0 deletions examples/sklearnex/incremental_covariance_spmd.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
# ===============================================================================
# Copyright 2024 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===============================================================================

import dpctl
import dpctl.tensor as dpt
import numpy as np
from mpi4py import MPI

from sklearnex.spmd.covariance import IncrementalEmpiricalCovariance


def get_local_data(data, comm):
rank = comm.Get_rank()
num_ranks = comm.Get_size()
local_size = (data.shape[0] + num_ranks - 1) // num_ranks
return data[rank * local_size : (rank + 1) * local_size]


# We create SYCL queue and MPI communicator to perform computation on multiple GPUs

q = dpctl.SyclQueue("gpu")
comm = MPI.COMM_WORLD

num_batches = 2
seed = 77
num_samples, num_features = 3000, 3
drng = np.random.default_rng(seed)
X = drng.random(size=(num_samples, num_features))

# Local data are obtained for each GPU and splitted into batches

X_local = get_local_data(X, comm)
X_split = np.array_split(X_local, num_batches)

cov = IncrementalEmpiricalCovariance()

# Partial fit is called for each batch on each GPU

for i in range(num_batches):
dpt_X = dpt.asarray(X_split[i], usm_type="device", sycl_queue=q)
cov.partial_fit(dpt_X)

# Finalization of results is performed in a lazy way after requesting results like in non-SPMD incremental estimators.

print(f"Computed covariance values on rank {comm.Get_rank()}:\n", cov.covariance_)
21 changes: 0 additions & 21 deletions onedal/basic_statistics/basic_statistics.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,27 +96,6 @@ def fit(self, data, sample_weight=None, queue=None):

return self

def compute(self, data, weights=None, queue=None):
warnings.warn(
"Method `compute` was deprecated in version 2024.7 and will be "
"removed in 2025.0. Use `fit` instead."
)

is_csr = _is_csr(data)

if data is not None:
data = _check_array(data, ensure_2d=False)
if weights is not None:
weights = _check_array(weights, ensure_2d=False)

policy = self._get_policy(queue, data, weights)
data, weights = _convert_to_supported(policy, data, weights)
data_table, weights_table = to_table(data, weights)
dtype = data.dtype
res = self._compute_raw(data_table, weights_table, policy, dtype, is_csr)

return {k: from_table(v).ravel() for k, v in res.items()}

def _compute_raw(
self, data_table, weights_table, policy, dtype=np.float32, is_csr=False
):
Expand Down
16 changes: 0 additions & 16 deletions onedal/basic_statistics/tests/test_basic_statistics.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,19 +296,3 @@ def test_options_csr(queue, option, dtype):
tol = fp32tol if res.dtype == np.float32 else fp64tol

assert_allclose(gtr, res, rtol=tol)


def test_warning():
basicstat = BasicStatistics()
data = np.array([0, 1])

with pytest.warns(
UserWarning,
match="Method `compute` was deprecated in version 2024.7 and will be removed in 2025.0. Use `fit` instead.",
) as warn_record:
basicstat.compute(data)

if daal_check_version((2025, "P", 0)):
assert len(warn_record) == 0
else:
assert len(warn_record) == 1
4 changes: 3 additions & 1 deletion onedal/covariance/covariance.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -175,9 +175,11 @@ ONEDAL_PY_INIT_MODULE(covariance) {
using namespace dal::covariance;

auto sub = m.def_submodule("covariance");

#ifdef ONEDAL_DATA_PARALLEL_SPMD
ONEDAL_PY_INSTANTIATE(init_compute_ops, sub, policy_spmd, task::compute);
#else
ONEDAL_PY_INSTANTIATE(init_finalize_compute_ops, sub, policy_spmd, task::compute);
#else
ONEDAL_PY_INSTANTIATE(init_compute_ops, sub, policy_list, task::compute);
ONEDAL_PY_INSTANTIATE(init_partial_compute_ops, sub, policy_list, task::compute);
ONEDAL_PY_INSTANTIATE(init_finalize_compute_ops, sub, policy_list, task::compute);
Expand Down
21 changes: 13 additions & 8 deletions onedal/covariance/incremental_covariance.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,7 @@
# ===============================================================================
import numpy as np

from daal4py.sklearn._utils import daal_check_version, get_dtype, make2d
from onedal import _backend
from daal4py.sklearn._utils import daal_check_version, get_dtype

from ..datatypes import _convert_to_supported, from_table, to_table
from ..utils import _check_array
Expand Down Expand Up @@ -86,10 +85,11 @@ def partial_fit(self, X, y=None, queue=None):
"""
X = _check_array(X, dtype=[np.float64, np.float32], ensure_2d=True)

if not hasattr(self, "_policy"):
self._policy = self._get_policy(queue, X)
self._queue = queue

X = _convert_to_supported(self._policy, X)
policy = self._get_policy(queue, X)

X = _convert_to_supported(policy, X)

if not hasattr(self, "_dtype"):
self._dtype = get_dtype(X)
Expand All @@ -100,7 +100,7 @@ def partial_fit(self, X, y=None, queue=None):
"covariance",
None,
"partial_compute",
self._policy,
policy,
params,
self._partial_result,
table_X,
Expand All @@ -114,19 +114,24 @@ def finalize_fit(self, queue=None):
Parameters
----------
queue : dpctl.SyclQueue
Not used here, added for API conformance
If not None, use this queue for computations.
Returns
-------
self : object
Returns the instance itself.
"""
params = self._get_onedal_params(self._dtype)
if queue is not None:
policy = self._get_policy(queue)
else:
policy = self._get_policy(self._queue)

result = self._get_backend(
"covariance",
None,
"finalize_compute",
self._policy,
policy,
params,
self._partial_result,
)
Expand Down
3 changes: 2 additions & 1 deletion onedal/spmd/covariance/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,5 +15,6 @@
# ==============================================================================

from .covariance import EmpiricalCovariance
from .incremental_covariance import IncrementalEmpiricalCovariance

__all__ = ["EmpiricalCovariance"]
__all__ = ["EmpiricalCovariance", "IncrementalEmpiricalCovariance"]
82 changes: 82 additions & 0 deletions onedal/spmd/covariance/incremental_covariance.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
# ==============================================================================
# Copyright 2024 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

import numpy as np

from daal4py.sklearn._utils import get_dtype

from ...covariance import (
IncrementalEmpiricalCovariance as base_IncrementalEmpiricalCovariance,
)
from ...datatypes import _convert_to_supported, to_table
from ...utils import _check_array
from .._base import BaseEstimatorSPMD


class IncrementalEmpiricalCovariance(
BaseEstimatorSPMD, base_IncrementalEmpiricalCovariance
):
def _reset(self):
self._partial_result = super(
base_IncrementalEmpiricalCovariance, self
)._get_backend("covariance", None, "partial_compute_result")

def partial_fit(self, X, y=None, queue=None):
"""
Computes partial data for the covariance matrix
from data batch X and saves it to `_partial_result`.
Parameters
----------
X : array-like of shape (n_samples, n_features)
Training data batch, where `n_samples` is the number of samples
in the batch, and `n_features` is the number of features.
y : Ignored
Not used, present for API consistency by convention.
queue : dpctl.SyclQueue
If not None, use this queue for computations.
Returns
-------
self : object
Returns the instance itself.
"""
X = _check_array(X, dtype=[np.float64, np.float32], ensure_2d=True)

self._queue = queue

policy = super(base_IncrementalEmpiricalCovariance, self)._get_policy(queue, X)

X = _convert_to_supported(policy, X)

if not hasattr(self, "_dtype"):
self._dtype = get_dtype(X)

params = self._get_onedal_params(self._dtype)
table_X = to_table(X)
self._partial_result = super(
base_IncrementalEmpiricalCovariance, self
)._get_backend(
"covariance",
None,
"partial_compute",
policy,
params,
self._partial_result,
table_X,
)
3 changes: 0 additions & 3 deletions sklearnex/basic_statistics/basic_statistics.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,9 +107,6 @@ def _onedal_fit(self, X, sample_weight=None, queue=None):
self._onedal_estimator.fit(X, sample_weight, queue)
self._save_attributes()

def compute(self, data, weights=None, queue=None):
return self._onedal_estimator.compute(data, weights, queue)

def fit(self, X, y=None, *, sample_weight=None):
"""Compute statistics with X, using minibatches of size batch_size.
Expand Down
8 changes: 4 additions & 4 deletions sklearnex/covariance/incremental_covariance.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,9 +115,9 @@ def _onedal_supported(self, method_name, *data):
)
return patching_status

def _onedal_finalize_fit(self):
def _onedal_finalize_fit(self, queue=None):
assert hasattr(self, "_onedal_estimator")
self._onedal_estimator.finalize_fit()
self._onedal_estimator.finalize_fit(queue=queue)
self._need_to_finalize = False

if not daal_check_version((2024, "P", 400)) and self.assume_centered:
Expand Down Expand Up @@ -192,7 +192,7 @@ def _onedal_partial_fit(self, X, queue=None, check_input=True):
else:
self.n_samples_seen_ += X.shape[0]

self._onedal_estimator.partial_fit(X, queue)
self._onedal_estimator.partial_fit(X, queue=queue)
finally:
self._need_to_finalize = True

Expand Down Expand Up @@ -326,7 +326,7 @@ def _onedal_fit(self, X, queue=None):
X_batch = X[batch]
self._onedal_partial_fit(X_batch, queue=queue, check_input=False)

self._onedal_finalize_fit()
self._onedal_finalize_fit(queue=queue)

return self

Expand Down
11 changes: 11 additions & 0 deletions sklearnex/covariance/tests/test_incremental_covariance.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from sklearn.datasets import load_diabetes
from sklearn.decomposition import PCA

from daal4py.sklearn._utils import daal_check_version
from onedal.tests.utils._dataframes_support import (
_as_numpy,
_convert_to_dataframe,
Expand All @@ -37,6 +38,11 @@
@pytest.mark.parametrize("dtype", [np.float32, np.float64])
@pytest.mark.parametrize("assume_centered", [True, False])
def test_sklearnex_partial_fit_on_gold_data(dataframe, queue, dtype, assume_centered):
is_gpu = queue is not None and queue.sycl_device.is_gpu
if assume_centered and is_gpu and not daal_check_version((2025, "P", 0)):
pytest.skip(
"Due to a bug on oneDAL side, means are not set to zero when assume_centered=True"
)
from sklearnex.covariance import IncrementalEmpiricalCovariance

X = np.array([[0, 1], [0, 1]])
Expand Down Expand Up @@ -143,6 +149,11 @@ def test_sklearnex_partial_fit_on_random_data(
def test_sklearnex_fit_on_random_data(
dataframe, queue, num_batches, row_count, column_count, dtype, assume_centered
):
is_gpu = queue is not None and queue.sycl_device.is_gpu
if assume_centered and is_gpu and not daal_check_version((2025, "P", 0)):
pytest.skip(
"Due to a bug on oneDAL side, means are not set to zero when assume_centered=True"
)
from sklearnex.covariance import IncrementalEmpiricalCovariance

seed = 77
Expand Down
3 changes: 2 additions & 1 deletion sklearnex/spmd/covariance/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,5 +15,6 @@
# ==============================================================================

from .covariance import EmpiricalCovariance
from .incremental_covariance import IncrementalEmpiricalCovariance

__all__ = ["EmpiricalCovariance"]
__all__ = ["EmpiricalCovariance", "IncrementalEmpiricalCovariance"]
37 changes: 37 additions & 0 deletions sklearnex/spmd/covariance/incremental_covariance.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
# ==============================================================================
# Copyright 2024 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

from onedal.spmd.covariance import (
IncrementalEmpiricalCovariance as onedalSPMD_IncrementalEmpiricalCovariance,
)

from ...covariance import (
IncrementalEmpiricalCovariance as base_IncrementalEmpiricalCovariance,
)


class IncrementalEmpiricalCovariance(base_IncrementalEmpiricalCovariance):
"""
Incremental distributed estimator for covariance.
Allows to distributely compute empirical covariance estimated by maximum
likelihood method if data are splitted into batches.
API is the same as for `sklearnex.covariance.IncrementalEmpiricalCovariance`
"""

_onedal_incremental_covariance = staticmethod(
onedalSPMD_IncrementalEmpiricalCovariance
)
Loading

0 comments on commit 3dd9521

Please sign in to comment.