ENH: Added serialization for IncrementalBasicStatistics (#2180)
olegkkruglov authored Dec 2, 2024
1 parent e634529 commit 8ee38f2
Showing 6 changed files with 216 additions and 17 deletions.
30 changes: 29 additions & 1 deletion onedal/basic_statistics/basic_statistics.cpp
@@ -19,6 +19,9 @@
#include "onedal/common.hpp"
#include "onedal/version.hpp"

#define NO_IMPORT_ARRAY // import_array called in table.cpp
#include "onedal/datatypes/data_conversion.hpp"

#include <string>
#include <regex>
#include <map>
@@ -204,7 +207,32 @@ void init_partial_compute_result(py::module_& m) {
.DEF_ONEDAL_PY_PROPERTY(partial_max, result_t)
.DEF_ONEDAL_PY_PROPERTY(partial_sum, result_t)
.DEF_ONEDAL_PY_PROPERTY(partial_sum_squares, result_t)
.DEF_ONEDAL_PY_PROPERTY(partial_sum_squares_centered, result_t);
.DEF_ONEDAL_PY_PROPERTY(partial_sum_squares_centered, result_t)
.def(py::pickle(
[](const result_t& res) {
return py::make_tuple(
py::cast<py::object>(convert_to_pyobject(res.get_partial_n_rows())),
py::cast<py::object>(convert_to_pyobject(res.get_partial_min())),
py::cast<py::object>(convert_to_pyobject(res.get_partial_max())),
py::cast<py::object>(convert_to_pyobject(res.get_partial_sum())),
py::cast<py::object>(convert_to_pyobject(res.get_partial_sum_squares())),
py::cast<py::object>(convert_to_pyobject(res.get_partial_sum_squares_centered()))
);
},
[](py::tuple t) {
if (t.size() != 6)
throw std::runtime_error("Invalid state!");
result_t res;
if (py::cast<int>(t[0].attr("size")) != 0) res.set_partial_n_rows(convert_to_table(t[0].ptr()));
if (py::cast<int>(t[1].attr("size")) != 0) res.set_partial_min(convert_to_table(t[1].ptr()));
if (py::cast<int>(t[2].attr("size")) != 0) res.set_partial_max(convert_to_table(t[2].ptr()));
if (py::cast<int>(t[3].attr("size")) != 0) res.set_partial_sum(convert_to_table(t[3].ptr()));
if (py::cast<int>(t[4].attr("size")) != 0) res.set_partial_sum_squares(convert_to_table(t[4].ptr()));
if (py::cast<int>(t[5].attr("size")) != 0) res.set_partial_sum_squares_centered(convert_to_table(t[5].ptr()));

return res;
}
));
}

ONEDAL_PY_DECLARE_INSTANTIATOR(init_compute_result);
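For readers less familiar with pybind11, py::pickle wires the two lambdas above into Python's __getstate__/__setstate__ protocol: the first packs the six partial tables into a tuple of arrays, the second validates the tuple and restores only the non-empty entries. Below is a minimal pure-Python sketch of the equivalent behavior, using a hypothetical PartialResult stand-in rather than the real oneDAL binding:

import pickle
import numpy as np

class PartialResult:
    """Hypothetical pure-Python stand-in for partial_compute_result."""

    FIELDS = (
        "partial_n_rows", "partial_min", "partial_max",
        "partial_sum", "partial_sum_squares", "partial_sum_squares_centered",
    )

    def __init__(self):
        # An empty array plays the role of an unset oneDAL table.
        for name in self.FIELDS:
            setattr(self, name, np.empty(0))

    def __getstate__(self):
        # First py::pickle lambda: pack all six fields into a tuple.
        return tuple(getattr(self, name) for name in self.FIELDS)

    def __setstate__(self, state):
        # Second lambda: reject malformed state, restore non-empty fields.
        if len(state) != 6:
            raise RuntimeError("Invalid state!")
        self.__init__()
        for name, value in zip(self.FIELDS, state):
            if value.size != 0:
                setattr(self, name, value)

res = PartialResult()
res.partial_sum = np.array([1.0, 2.0, 3.0])
restored = pickle.loads(pickle.dumps(res))
assert np.array_equal(restored.partial_sum, res.partial_sum)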
48 changes: 32 additions & 16 deletions onedal/basic_statistics/incremental_basic_statistics.py
@@ -70,10 +70,21 @@ def __init__(self, result_options="all"):
self._reset()

def _reset(self):
self._need_to_finalize = False
self._partial_result = self._get_backend(
"basic_statistics", None, "partial_compute_result"
)

def __getstate__(self):
# Since finalize_fit can't be dispatched without a directly provided queue
# and the dispatching policy can't be serialized, the computation is
# finalized here and the policy is not saved in the serialized data.
self.finalize_fit()
data = self.__dict__.copy()
data.pop("_queue", None)

return data

def partial_fit(self, X, weights=None, queue=None):
"""
Computes partial data for basic statistics
Expand Down Expand Up @@ -124,6 +135,9 @@ def partial_fit(self, X, weights=None, queue=None):
weights_table,
)

self._need_to_finalize = True
return self

def finalize_fit(self, queue=None):
"""
Finalizes basic statistics computation and obtains result
@@ -139,22 +153,24 @@
self : object
Returns the instance itself.
"""
if self._need_to_finalize:
if queue is not None:
policy = self._get_policy(queue)
else:
policy = self._get_policy(self._queue)

result = self._get_backend(
"basic_statistics",
None,
"finalize_compute",
policy,
self._onedal_params,
self._partial_result,
)
options = self._get_result_options(self.options).split("|")
for opt in options:
setattr(self, opt, from_table(getattr(result, opt)).ravel())

if queue is not None:
policy = self._get_policy(queue)
else:
policy = self._get_policy(self._queue)

result = self._get_backend(
"basic_statistics",
None,
"finalize_compute",
policy,
self._onedal_params,
self._partial_result,
)
options = self._get_result_options(self.options).split("|")
for opt in options:
setattr(self, opt, from_table(getattr(result, opt)).ravel())
self._need_to_finalize = False

return self
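
Taken together, the _need_to_finalize flag and the eager finalization in __getstate__ make a pickle round-trip safe at any point in the incremental loop. A minimal sketch of the intended usage, assuming scikit-learn-intelex is installed and running on the host with no SYCL queue (the mean attribute name follows the estimator's result options):

import pickle
import numpy as np
from onedal.basic_statistics import IncrementalBasicStatistics

X = np.random.default_rng(seed=77).uniform(size=(20, 5))
incbs = IncrementalBasicStatistics()
for batch in np.array_split(X, 4):
    incbs.partial_fit(batch)

# __getstate__ finalizes pending partial results before dumping, so the
# restored copy carries finished statistics and no queue or policy state.
incbs_loaded = pickle.loads(pickle.dumps(incbs))

incbs.finalize_fit()
np.testing.assert_allclose(incbs.mean, incbs_loaded.mean)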
88 changes: 88 additions & 0 deletions onedal/basic_statistics/tests/test_incremental_basic_statistics.py
@@ -20,6 +20,7 @@

from onedal.basic_statistics import IncrementalBasicStatistics
from onedal.basic_statistics.tests.utils import options_and_tests
from onedal.datatypes import from_table
from onedal.tests.utils._device_selection import get_queues


@@ -189,3 +190,90 @@ def test_all_option_on_random_data(
gtr = function(data)
tol = fp32tol if res.dtype == np.float32 else fp64tol
assert_allclose(gtr, res, atol=tol)


@pytest.mark.parametrize("queue", get_queues())
@pytest.mark.parametrize("dtype", [np.float32, np.float64])
def test_incremental_estimator_pickle(queue, dtype):
import pickle

from onedal.basic_statistics import IncrementalBasicStatistics

incbs = IncrementalBasicStatistics()

# Check that estimator can be serialized without any data.
dump = pickle.dumps(incbs)
incbs_loaded = pickle.loads(dump)
seed = 77
gen = np.random.default_rng(seed)
X = gen.uniform(low=-0.3, high=+0.7, size=(10, 10))
X = X.astype(dtype)
X_split = np.array_split(X, 2)
incbs.partial_fit(X_split[0], queue=queue)
incbs_loaded.partial_fit(X_split[0], queue=queue)

assert incbs._need_to_finalize is True
assert incbs_loaded._need_to_finalize is True

# Check that estimator can be serialized after partial_fit call.
dump = pickle.dumps(incbs)
incbs_loaded = pickle.loads(dump)
# Finalize is called during serialization, so both the original estimator
# and the loaded copy hold finalized partial results.
assert incbs._need_to_finalize is False
assert incbs_loaded._need_to_finalize is False

partial_n_rows = from_table(incbs._partial_result.partial_n_rows)
partial_n_rows_loaded = from_table(incbs_loaded._partial_result.partial_n_rows)
assert_allclose(partial_n_rows, partial_n_rows_loaded)

partial_min = from_table(incbs._partial_result.partial_min)
partial_min_loaded = from_table(incbs_loaded._partial_result.partial_min)
assert_allclose(partial_min, partial_min_loaded)

partial_max = from_table(incbs._partial_result.partial_max)
partial_max_loaded = from_table(incbs_loaded._partial_result.partial_max)
assert_allclose(partial_max, partial_max_loaded)

partial_sum = from_table(incbs._partial_result.partial_sum)
partial_sum_loaded = from_table(incbs_loaded._partial_result.partial_sum)
assert_allclose(partial_sum, partial_sum_loaded)

partial_sum_squares = from_table(incbs._partial_result.partial_sum_squares)
partial_sum_squares_loaded = from_table(
incbs_loaded._partial_result.partial_sum_squares
)
assert_allclose(partial_sum_squares, partial_sum_squares_loaded)

partial_sum_squares_centered = from_table(
incbs._partial_result.partial_sum_squares_centered
)
partial_sum_squares_centered_loaded = from_table(
incbs_loaded._partial_result.partial_sum_squares_centered
)
assert_allclose(partial_sum_squares_centered, partial_sum_squares_centered_loaded)

incbs.partial_fit(X_split[1], queue=queue)
incbs_loaded.partial_fit(X_split[1], queue=queue)
assert incbs._need_to_finalize is True
assert incbs_loaded._need_to_finalize is True

dump = pickle.dumps(incbs_loaded)
incbs_loaded = pickle.loads(dump)

assert incbs._need_to_finalize is True
assert incbs_loaded._need_to_finalize is False

incbs.finalize_fit()
incbs_loaded.finalize_fit()

# Check that finalized estimator can be serialized.
dump = pickle.dumps(incbs_loaded)
incbs_loaded = pickle.loads(dump)

for result_option in options_and_tests:
_, tols = options_and_tests[result_option]
fp32tol, fp64tol = tols
res = getattr(incbs, result_option)
res_loaded = getattr(incbs_loaded, result_option)
tol = fp32tol if res.dtype == np.float32 else fp64tol
assert_allclose(res, res_loaded, atol=tol)
4 changes: 4 additions & 0 deletions onedal/spmd/basic_statistics/incremental_basic_statistics.py
@@ -25,6 +25,7 @@

class IncrementalBasicStatistics(BaseEstimatorSPMD, base_IncrementalBasicStatistics):
def _reset(self):
self._need_to_finalize = False
self._partial_result = super(base_IncrementalBasicStatistics, self)._get_backend(
"basic_statistics", None, "partial_compute_result"
)
@@ -67,3 +68,6 @@ def partial_fit(self, X, weights=None, queue=None):
X_table,
weights_table,
)

self._need_to_finalize = True
return self
7 changes: 7 additions & 0 deletions sklearnex/basic_statistics/incremental_basic_statistics.py
@@ -103,6 +103,13 @@ class IncrementalBasicStatistics(IntelEstimator, BaseEstimator):
----
Attribute exists only if corresponding result option has been provided.
Note
----
Serializing instances of this class triggers a forced finalization of calculations.
Since finalize_fit can't be dispatched without a directly provided queue
and the dispatching policy can't be serialized, the computation is finalized
during the serialization call and the policy is not saved in the serialized data.
Note
----
Names of attributes without the trailing underscore are
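The new Note has a practical upside: checkpointing an estimator mid-stream is a plain pickle.dumps, at the cost of one forced finalization. A hedged sketch of that workflow (the mean_ attribute name is assumed from the estimator's documented result options):

import pickle
import numpy as np
from sklearnex.basic_statistics import IncrementalBasicStatistics

X = np.random.default_rng(seed=77).uniform(size=(100, 3))
incbs = IncrementalBasicStatistics()
for batch in np.array_split(X, 4):
    incbs.partial_fit(batch)

# Serialization forces finalization (see the Note above); the restored
# estimator is immediately queryable and can keep accepting batches.
restored = pickle.loads(pickle.dumps(incbs))
np.testing.assert_allclose(restored.mean_, X.mean(axis=0))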
56 changes: 56 additions & 0 deletions sklearnex/basic_statistics/tests/test_incremental_basic_statistics.py
@@ -397,3 +397,59 @@ def test_warning():
assert len(warn_record) == 0, i
else:
assert len(warn_record) == 1, i


@pytest.mark.parametrize("dataframe,queue", get_dataframes_and_queues())
@pytest.mark.parametrize("dtype", [np.float32, np.float64])
def test_sklearnex_incremental_estimator_pickle(dataframe, queue, dtype):
import pickle

from sklearnex.basic_statistics import IncrementalBasicStatistics

incbs = IncrementalBasicStatistics()

# Check that estimator can be serialized without any data.
dump = pickle.dumps(incbs)
incbs_loaded = pickle.loads(dump)
seed = 77
gen = np.random.default_rng(seed)
X = gen.uniform(low=-0.3, high=+0.7, size=(10, 10))
X = X.astype(dtype)
X_split = np.array_split(X, 2)
X_split_df = _convert_to_dataframe(X_split[0], sycl_queue=queue, target_df=dataframe)
incbs.partial_fit(X_split_df)
incbs_loaded.partial_fit(X_split_df)

# Check that estimator can be serialized after partial_fit call.
dump = pickle.dumps(incbs_loaded)
incbs_loaded = pickle.loads(dump)

X_split_df = _convert_to_dataframe(X_split[1], sycl_queue=queue, target_df=dataframe)
incbs.partial_fit(X_split_df)
incbs_loaded.partial_fit(X_split_df)
dump = pickle.dumps(incbs)
incbs_loaded = pickle.loads(dump)
assert incbs.batch_size == incbs_loaded.batch_size
assert incbs.n_features_in_ == incbs_loaded.n_features_in_
assert incbs.n_samples_seen_ == incbs_loaded.n_samples_seen_
if hasattr(incbs, "_parameter_constraints"):
assert incbs._parameter_constraints == incbs_loaded._parameter_constraints
assert incbs.n_jobs == incbs_loaded.n_jobs
for result_option in options_and_tests:
_, tols = options_and_tests[result_option]
fp32tol, fp64tol = tols
res = getattr(incbs, result_option)
res_loaded = getattr(incbs_loaded, result_option)
tol = fp32tol if res.dtype == np.float32 else fp64tol
assert_allclose(res, res_loaded, atol=tol)

# Check that finalized estimator can be serialized.
dump = pickle.dumps(incbs_loaded)
incbs_loaded = pickle.loads(dump)
for result_option in options_and_tests:
_, tols = options_and_tests[result_option]
fp32tol, fp64tol = tols
res = getattr(incbs, result_option)
res_loaded = getattr(incbs_loaded, result_option)
tol = fp32tol if res.dtype == np.float32 else fp64tol
assert_allclose(res, res_loaded, atol=tol)
