From 8a09904df707b2e686af0c62808d74a83d40296c Mon Sep 17 00:00:00 2001
From: olegkkruglov <102592747+olegkkruglov@users.noreply.github.com>
Date: Mon, 2 Dec 2024 13:51:03 +0100
Subject: [PATCH] ENH: Added serialization for IncrementalPCA (#1926)

---
 deselected_tests.yaml                        |  2 -
 onedal/decomposition/incremental_pca.py      | 65 ++++++++------
 onedal/decomposition/pca.cpp                 | 34 +++++++-
 .../tests/test_incremental_pca.py            | 87 +++++++++++++++++++
 onedal/spmd/decomposition/incremental_pca.py |  2 +
 .../preview/decomposition/incremental_pca.py | 13 ++-
 .../tests/test_incremental_pca.py            | 70 +++++++++++++++
 7 files changed, 242 insertions(+), 31 deletions(-)

diff --git a/deselected_tests.yaml b/deselected_tests.yaml
index 50e22c7041..9e731e914f 100755
--- a/deselected_tests.yaml
+++ b/deselected_tests.yaml
@@ -282,8 +282,6 @@ deselected_tests:
   # partial result serialization
   - tests/test_common.py::test_estimators[IncrementalLinearRegression()-check_estimators_pickle]
   - tests/test_common.py::test_estimators[IncrementalLinearRegression()-check_estimators_pickle(readonly_memmap=True)]
-  - tests/test_common.py::test_estimators[IncrementalPCA()-check_estimators_pickle]
-  - tests/test_common.py::test_estimators[IncrementalPCA()-check_estimators_pickle(readonly_memmap=True)]
   - tests/test_common.py::test_estimators[IncrementalRidge()-check_estimators_pickle]
   - tests/test_common.py::test_estimators[IncrementalRidge()-check_estimators_pickle(readonly_memmap=True)]
   # There are not enough data to run onedal backend
diff --git a/onedal/decomposition/incremental_pca.py b/onedal/decomposition/incremental_pca.py
index 7199c1e1c2..bd917db306 100644
--- a/onedal/decomposition/incremental_pca.py
+++ b/onedal/decomposition/incremental_pca.py
@@ -99,11 +99,22 @@ def __init__(
         self._reset()
 
     def _reset(self):
-        self._partial_result = self._get_backend(
-            "decomposition", "dim_reduction", "partial_train_result"
-        )
+        self._need_to_finalize = False
+        module = self._get_backend("decomposition", "dim_reduction")
         if hasattr(self, "components_"):
             del self.components_
+        self._partial_result = module.partial_train_result()
+
+    def __getstate__(self):
+        # Since finalize_fit can't be dispatched without a directly provided queue
+        # and the dispatching policy can't be serialized, the computation is finalized
+        # here and the policy is not saved in the serialized data.
+
+        self.finalize_fit()
+        data = self.__dict__.copy()
+        data.pop("_queue", None)
+
+        return data
 
     def partial_fit(self, X, queue):
         """Incremental fit with X. All of X is processed as a single batch.
@@ -160,6 +171,7 @@ def partial_fit(self, X, queue):
             self._partial_result,
             X_table,
         )
+        self._need_to_finalize = True
         return self
 
     def finalize_fit(self, queue=None):
@@ -177,28 +189,27 @@ def finalize_fit(self, queue=None):
         self : object
            Returns the instance itself.
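+
+        Note: when there are no new partial results pending
+        (``_need_to_finalize`` is False), the method returns immediately
+        without recomputation.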
""" - if queue is not None: - policy = self._get_policy(queue) - else: - policy = self._get_policy(self._queue) - result = self._get_backend( - "decomposition", - "dim_reduction", - "finalize_train", - policy, - self._params, - self._partial_result, - ) - self.mean_ = from_table(result.means).ravel() - self.var_ = from_table(result.variances).ravel() - self.components_ = from_table(result.eigenvectors) - self.singular_values_ = np.nan_to_num(from_table(result.singular_values).ravel()) - self.explained_variance_ = np.maximum(from_table(result.eigenvalues).ravel(), 0) - self.explained_variance_ratio_ = from_table( - result.explained_variances_ratio - ).ravel() - self.noise_variance_ = self._compute_noise_variance( - self.n_components_, min(self.n_samples_seen_, self.n_features_in_) - ) - + if self._need_to_finalize: + module = self._get_backend("decomposition", "dim_reduction") + if queue is not None: + policy = self._get_policy(queue) + else: + policy = self._get_policy(self._queue) + result = module.finalize_train(policy, self._params, self._partial_result) + self.mean_ = from_table(result.means).ravel() + self.var_ = from_table(result.variances).ravel() + self.components_ = from_table(result.eigenvectors) + self.singular_values_ = np.nan_to_num( + from_table(result.singular_values).ravel() + ) + self.explained_variance_ = np.maximum( + from_table(result.eigenvalues).ravel(), 0 + ) + self.explained_variance_ratio_ = from_table( + result.explained_variances_ratio + ).ravel() + self.noise_variance_ = self._compute_noise_variance( + self.n_components_, min(self.n_samples_seen_, self.n_features_in_) + ) + self._need_to_finalize = False return self diff --git a/onedal/decomposition/pca.cpp b/onedal/decomposition/pca.cpp index cc19ed396a..c69161ccbc 100644 --- a/onedal/decomposition/pca.cpp +++ b/onedal/decomposition/pca.cpp @@ -15,6 +15,8 @@ *******************************************************************************/ #include "oneapi/dal/algo/pca.hpp" #include "onedal/common.hpp" +#define NO_IMPORT_ARRAY // import_array called in table.cpp +#include "onedal/datatypes/data_conversion.hpp" namespace py = pybind11; @@ -123,7 +125,37 @@ void init_partial_train_result(py::module_& m) { .DEF_ONEDAL_PY_PROPERTY(partial_n_rows, result_t) .DEF_ONEDAL_PY_PROPERTY(partial_crossproduct, result_t) .DEF_ONEDAL_PY_PROPERTY(partial_sum, result_t) - .DEF_ONEDAL_PY_PROPERTY(auxiliary_table, result_t); + .DEF_ONEDAL_PY_PROPERTY(auxiliary_table, result_t) + .def_property_readonly("auxiliary_table_count", &result_t::get_auxiliary_table_count) + .def(py::pickle( + [](const result_t& res) { + py::list auxiliary; + int auxiliary_size = res.get_auxiliary_table_count(); + for (int i = 0; i < auxiliary_size; i++) { + auto aux_table = res.get_auxiliary_table(i); + auxiliary.append(py::cast(convert_to_pyobject(aux_table))); + } + return py::make_tuple( + py::cast(convert_to_pyobject(res.get_partial_n_rows())), + py::cast(convert_to_pyobject(res.get_partial_crossproduct())), + py::cast(convert_to_pyobject(res.get_partial_sum())), + auxiliary + ); + }, + [](py::tuple t) { + if (t.size() != 4) + throw std::runtime_error("Invalid state!"); + result_t res; + if (py::cast(t[0].attr("size")) != 0) res.set_partial_n_rows(convert_to_table(t[0].ptr())); + if (py::cast(t[1].attr("size")) != 0) res.set_partial_crossproduct(convert_to_table(t[1].ptr())); + if (py::cast(t[2].attr("size")) != 0) res.set_partial_sum(convert_to_table(t[2].ptr())); + py::list aux_list = t[3].cast(); + for (int i = 0; i < aux_list.size(); i++) { + 
+                    res.set_auxiliary_table(convert_to_table(aux_list[i].ptr()));
+                }
+                return res;
+            }
+        ));
 }
 
 template
diff --git a/onedal/decomposition/tests/test_incremental_pca.py b/onedal/decomposition/tests/test_incremental_pca.py
index f2054c210b..f427f74fd4 100644
--- a/onedal/decomposition/tests/test_incremental_pca.py
+++ b/onedal/decomposition/tests/test_incremental_pca.py
@@ -19,6 +19,7 @@
 from numpy.testing import assert_allclose
 
 from daal4py.sklearn._utils import daal_check_version
+from onedal.datatypes import from_table
 from onedal.decomposition import IncrementalPCA
 from onedal.tests.utils._device_selection import get_queues
 
@@ -196,3 +197,89 @@ def test_on_random_data(
         whiten and queue is not None and queue.sycl_device.device_type.name == "gpu"
     ):
         assert_allclose(transformed_data, expected_transformed_data, atol=tol)
+
+
+@pytest.mark.parametrize("queue", get_queues())
+@pytest.mark.parametrize("dtype", [np.float32, np.float64])
+def test_incremental_estimator_pickle(queue, dtype):
+    import pickle
+
+    from onedal.decomposition import IncrementalPCA
+
+    incpca = IncrementalPCA()
+
+    # Check that the estimator can be serialized without any data.
+    dump = pickle.dumps(incpca)
+    incpca_loaded = pickle.loads(dump)
+    seed = 77
+    gen = np.random.default_rng(seed)
+    X = gen.uniform(low=-0.3, high=+0.7, size=(10, 10))
+    X = X.astype(dtype)
+    X_split = np.array_split(X, 2)
+    incpca.partial_fit(X_split[0], queue=queue)
+    incpca_loaded.partial_fit(X_split[0], queue=queue)
+    assert incpca._need_to_finalize is True
+    assert incpca_loaded._need_to_finalize is True
+
+    # Check that the estimator can be serialized after a partial_fit call;
+    # finalize_fit runs during serialization so partial results are finalized correctly.
+    dump = pickle.dumps(incpca)
+    incpca_loaded = pickle.loads(dump)
+    assert incpca._need_to_finalize is False
+    assert incpca_loaded._need_to_finalize is False
+
+    partial_n_rows = from_table(incpca._partial_result.partial_n_rows)
+    partial_n_rows_loaded = from_table(incpca_loaded._partial_result.partial_n_rows)
+    assert_allclose(partial_n_rows, partial_n_rows_loaded)
+
+    partial_crossproduct = from_table(incpca._partial_result.partial_crossproduct)
+    partial_crossproduct_loaded = from_table(
+        incpca_loaded._partial_result.partial_crossproduct
+    )
+    assert_allclose(partial_crossproduct, partial_crossproduct_loaded)
+
+    partial_sum = from_table(incpca._partial_result.partial_sum)
+    partial_sum_loaded = from_table(incpca_loaded._partial_result.partial_sum)
+    assert_allclose(partial_sum, partial_sum_loaded)
+
+    auxiliary_table_count = incpca._partial_result.auxiliary_table_count
+    auxiliary_table_count_loaded = incpca_loaded._partial_result.auxiliary_table_count
+    assert_allclose(auxiliary_table_count, auxiliary_table_count_loaded)
+
+    for i in range(auxiliary_table_count):
+        aux_table = incpca._partial_result.get_auxiliary_table(i)
+        aux_table_loaded = incpca_loaded._partial_result.get_auxiliary_table(i)
+        assert_allclose(from_table(aux_table), from_table(aux_table_loaded))
+
+    incpca.partial_fit(X_split[1], queue=queue)
+    incpca_loaded.partial_fit(X_split[1], queue=queue)
+    assert incpca._need_to_finalize is True
+    assert incpca_loaded._need_to_finalize is True
+
+    dump = pickle.dumps(incpca_loaded)
+    incpca_loaded = pickle.loads(dump)
+
+    assert incpca._need_to_finalize is True
+    assert incpca_loaded._need_to_finalize is False
+
+    incpca.finalize_fit()
+    incpca_loaded.finalize_fit()
+
+    # Check that the finalized estimator can be serialized.
+    dump = pickle.dumps(incpca_loaded)
+    incpca_loaded = pickle.loads(dump)
+
+    assert_allclose(incpca.singular_values_, incpca_loaded.singular_values_, atol=1e-6)
+    assert_allclose(incpca.n_samples_seen_, incpca_loaded.n_samples_seen_, atol=1e-6)
+    assert_allclose(incpca.n_features_in_, incpca_loaded.n_features_in_, atol=1e-6)
+    assert_allclose(incpca.mean_, incpca_loaded.mean_, atol=1e-6)
+    assert_allclose(incpca.var_, incpca_loaded.var_, atol=1e-6)
+    assert_allclose(
+        incpca.explained_variance_, incpca_loaded.explained_variance_, atol=1e-6
+    )
+    assert_allclose(incpca.components_, incpca_loaded.components_, atol=1e-6)
+    assert_allclose(
+        incpca.explained_variance_ratio_,
+        incpca_loaded.explained_variance_ratio_,
+        atol=1e-6,
+    )
diff --git a/onedal/spmd/decomposition/incremental_pca.py b/onedal/spmd/decomposition/incremental_pca.py
index 6f82a1ac37..a77c4af9db 100644
--- a/onedal/spmd/decomposition/incremental_pca.py
+++ b/onedal/spmd/decomposition/incremental_pca.py
@@ -31,6 +31,7 @@ class IncrementalPCA(BaseEstimatorSPMD, base_IncrementalPCA):
     """
 
     def _reset(self):
+        self._need_to_finalize = False
         self._partial_result = super(base_IncrementalPCA, self)._get_backend(
             "decomposition", "dim_reduction", "partial_train_result"
         )
@@ -92,6 +93,7 @@ def partial_fit(self, X, y=None, queue=None):
             self._partial_result,
             X_table,
         )
+        self._need_to_finalize = True
         return self
 
     def _create_model(self):
diff --git a/sklearnex/preview/decomposition/incremental_pca.py b/sklearnex/preview/decomposition/incremental_pca.py
index aa8d7e78f1..fdf13e0817 100644
--- a/sklearnex/preview/decomposition/incremental_pca.py
+++ b/sklearnex/preview/decomposition/incremental_pca.py
@@ -226,7 +226,18 @@ def fit_transform(self, X, y=None, **fit_params):
             X,
         )
 
-    __doc__ = _sklearn_IncrementalPCA.__doc__
+    __doc__ = (
+        _sklearn_IncrementalPCA.__doc__
+        + """
+
+    Note
+    ----
+    Serializing instances of this class will trigger a forced finalization of calculations.
+    Since finalize_fit can't be dispatched without a directly provided queue
+    and the dispatching policy can't be serialized, the computation is finalized
+    during the serialization call and the policy is not saved in the serialized data.
+    """
+    )
     fit.__doc__ = _sklearn_IncrementalPCA.fit.__doc__
     fit_transform.__doc__ = _sklearn_IncrementalPCA.fit_transform.__doc__
    transform.__doc__ = _sklearn_IncrementalPCA.transform.__doc__
diff --git a/sklearnex/preview/decomposition/tests/test_incremental_pca.py b/sklearnex/preview/decomposition/tests/test_incremental_pca.py
index 67929bfac8..c4c47c8adb 100644
--- a/sklearnex/preview/decomposition/tests/test_incremental_pca.py
+++ b/sklearnex/preview/decomposition/tests/test_incremental_pca.py
@@ -264,3 +264,73 @@ def test_sklearnex_partial_fit_on_random_data(
         X_df = _convert_to_dataframe(X, sycl_queue=queue, target_df=dataframe)
         transformed_data = incpca.transform(X_df)
         check_pca(incpca, dtype, whiten, X, transformed_data)
+
+
+@pytest.mark.parametrize("dataframe,queue", get_dataframes_and_queues())
+@pytest.mark.parametrize("dtype", [np.float32, np.float64])
+def test_sklearnex_incremental_estimator_pickle(dataframe, queue, dtype):
+    import pickle
+
+    incpca = IncrementalPCA()
+
+    # Check that the estimator can be serialized without any data.
+    dump = pickle.dumps(incpca)
+    incpca_loaded = pickle.loads(dump)
+
+    seed = 77
+    gen = np.random.default_rng(seed)
+    X = gen.uniform(low=-0.3, high=+0.7, size=(10, 10))
+    X = X.astype(dtype)
+    X_split = np.array_split(X, 2)
+    X_split_df = _convert_to_dataframe(X_split[0], sycl_queue=queue, target_df=dataframe)
+    incpca.partial_fit(X_split_df)
+    incpca_loaded.partial_fit(X_split_df)
+    dump = pickle.dumps(incpca_loaded)
+    incpca_loaded = pickle.loads(dump)
+    assert incpca.batch_size == incpca_loaded.batch_size
+    assert incpca.n_features_in_ == incpca_loaded.n_features_in_
+    assert incpca.n_samples_seen_ == incpca_loaded.n_samples_seen_
+    if hasattr(incpca, "_parameter_constraints"):
+        assert incpca._parameter_constraints == incpca_loaded._parameter_constraints
+    assert incpca.n_jobs == incpca_loaded.n_jobs
+    X_split_df = _convert_to_dataframe(X_split[1], sycl_queue=queue, target_df=dataframe)
+    incpca.partial_fit(X_split_df)
+    incpca_loaded.partial_fit(X_split_df)
+
+    # Check that the estimator can be serialized after a partial_fit call.
+    dump = pickle.dumps(incpca)
+    incpca_loaded = pickle.loads(dump)
+
+    assert_allclose(incpca.singular_values_, incpca_loaded.singular_values_, atol=1e-6)
+    assert_allclose(incpca.n_samples_seen_, incpca_loaded.n_samples_seen_, atol=1e-6)
+    assert_allclose(incpca.n_features_in_, incpca_loaded.n_features_in_, atol=1e-6)
+    assert_allclose(incpca.mean_, incpca_loaded.mean_, atol=1e-6)
+    assert_allclose(incpca.var_, incpca_loaded.var_, atol=1e-6)
+    assert_allclose(
+        incpca.explained_variance_, incpca_loaded.explained_variance_, atol=1e-6
+    )
+    assert_allclose(incpca.components_, incpca_loaded.components_, atol=1e-6)
+    assert_allclose(
+        incpca.explained_variance_ratio_,
+        incpca_loaded.explained_variance_ratio_,
+        atol=1e-6,
+    )
+
+    # Check that the finalized estimator can be serialized.
+    dump = pickle.dumps(incpca_loaded)
+    incpca_loaded = pickle.loads(dump)
+
+    assert_allclose(incpca.singular_values_, incpca_loaded.singular_values_, atol=1e-6)
+    assert_allclose(incpca.n_samples_seen_, incpca_loaded.n_samples_seen_, atol=1e-6)
+    assert_allclose(incpca.n_features_in_, incpca_loaded.n_features_in_, atol=1e-6)
+    assert_allclose(incpca.mean_, incpca_loaded.mean_, atol=1e-6)
+    assert_allclose(incpca.var_, incpca_loaded.var_, atol=1e-6)
+    assert_allclose(
+        incpca.explained_variance_, incpca_loaded.explained_variance_, atol=1e-6
+    )
+    assert_allclose(incpca.components_, incpca_loaded.components_, atol=1e-6)
+    assert_allclose(
+        incpca.explained_variance_ratio_,
+        incpca_loaded.explained_variance_ratio_,
+        atol=1e-6,
+    )
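
Usage sketch (illustrative only, not part of the patch): a minimal round-trip of the
serialization added here, assuming the preview estimator is importable from
sklearnex.preview.decomposition and using hypothetical toy data. Pickling finalizes
any pending partial result, and incremental training can continue on the restored copy.

    import pickle

    import numpy as np

    from sklearnex.preview.decomposition import IncrementalPCA

    # Hypothetical toy data: 10 samples, 4 features.
    X = np.random.default_rng(0).uniform(size=(10, 4))

    incpca = IncrementalPCA(n_components=2)
    incpca.partial_fit(X[:5])

    # __getstate__ calls finalize_fit(), so the restored estimator already has
    # components_, and no SYCL queue or policy is stored in the pickled state.
    restored = pickle.loads(pickle.dumps(incpca))

    restored.partial_fit(X[5:])  # training continues after loading
    print(restored.transform(X).shape)  # (10, 2)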