Skip to content

Commit

Permalink
add new assert_all_finite
Browse files Browse the repository at this point in the history
  • Loading branch information
icfaust committed Dec 3, 2024
1 parent 2ab20da commit 334c57b
Show file tree
Hide file tree
Showing 7 changed files with 447 additions and 130 deletions.
45 changes: 7 additions & 38 deletions sklearnex/tests/test_memory_usage.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,14 @@
get_dataframes_and_queues,
)
from onedal.tests.utils._device_selection import get_queues, is_dpctl_device_available
from onedal.utils._array_api import _get_sycl_namespace
from onedal.utils._dpep_helpers import dpctl_available, dpnp_available
from sklearnex import config_context
from sklearnex.tests.utils import PATCHED_FUNCTIONS, PATCHED_MODELS, SPECIAL_INSTANCES
from sklearnex.tests.utils import (
PATCHED_FUNCTIONS,
PATCHED_MODELS,
SPECIAL_INSTANCES,
DummyEstimator,
)
from sklearnex.utils._array_api import get_namespace

if dpctl_available:
Expand Down Expand Up @@ -131,41 +135,6 @@ def gen_functions(functions):
ORDER_DICT = {"F": np.asfortranarray, "C": np.ascontiguousarray}


if _is_dpc_backend:

from sklearn.utils.validation import check_is_fitted

from onedal.datatypes import from_table, to_table

class DummyEstimatorWithTableConversions(BaseEstimator):

def fit(self, X, y=None):
sua_iface, xp, _ = _get_sycl_namespace(X)
X_table = to_table(X)
y_table = to_table(y)
# The presence of the fitted attributes (ending with a trailing
# underscore) is required for the correct check. The cleanup of
# the memory will occur at the estimator instance deletion.
self.x_attr_ = from_table(
X_table, sua_iface=sua_iface, sycl_queue=X.sycl_queue, xp=xp
)
self.y_attr_ = from_table(
y_table, sua_iface=sua_iface, sycl_queue=X.sycl_queue, xp=xp
)
return self

def predict(self, X):
# Checks if the estimator is fitted by verifying the presence of
# fitted attributes (ending with a trailing underscore).
check_is_fitted(self)
sua_iface, xp, _ = _get_sycl_namespace(X)
X_table = to_table(X)
returned_X = from_table(
X_table, sua_iface=sua_iface, sycl_queue=X.sycl_queue, xp=xp
)
return returned_X


def gen_clsf_data(n_samples, n_features, dtype=None):
data, label = make_classification(
n_classes=2, n_samples=n_samples, n_features=n_features, random_state=777
Expand Down Expand Up @@ -369,7 +338,7 @@ def test_table_conversions_memory_leaks(dataframe, queue, order, data_shape, dty
pytest.skip("SYCL device memory leak check requires the level zero sysman")

_kfold_function_template(
DummyEstimatorWithTableConversions,
DummyEstimator,
dataframe,
data_shape,
queue,
Expand Down
2 changes: 2 additions & 0 deletions sklearnex/tests/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
SPECIAL_INSTANCES,
UNPATCHED_FUNCTIONS,
UNPATCHED_MODELS,
DummyEstimator,
_get_processor_info,
call_method,
gen_dataset,
Expand All @@ -39,6 +40,7 @@
"gen_models_info",
"gen_dataset",
"sklearn_clone_dict",
"DummyEstimator",
]

_IS_INTEL = "GenuineIntel" in _get_processor_info()
41 changes: 41 additions & 0 deletions sklearnex/tests/utils/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,11 @@
)
from sklearn.datasets import load_diabetes, load_iris
from sklearn.neighbors._base import KNeighborsMixin
from sklearn.utils.validation import check_is_fitted

from onedal.datatypes import from_table, to_table
from onedal.tests.utils._dataframes_support import _convert_to_dataframe
from onedal.utils._array_api import _get_sycl_namespace
from sklearnex import get_patch_map, patch_sklearn, sklearn_is_patched, unpatch_sklearn
from sklearnex.basic_statistics import BasicStatistics, IncrementalBasicStatistics
from sklearnex.linear_model import LogisticRegression
Expand Down Expand Up @@ -369,3 +372,41 @@ def _get_processor_info():
)

return proc


class DummyEstimator(BaseEstimator):

def fit(self, X, y=None):
sua_iface, xp, _ = _get_sycl_namespace(X)
X_table = to_table(X)
y_table = to_table(y)
# The presence of the fitted attributes (ending with a trailing
# underscore) is required for the correct check. The cleanup of
# the memory will occur at the estimator instance deletion.
if sua_iface:
self.x_attr_ = from_table(
X_table, sua_iface=sua_iface, sycl_queue=X.sycl_queue, xp=xp
)
self.y_attr_ = from_table(
y_table, sua_iface=sua_iface, sycl_queue=X.sycl_queue, xp=xp
)
else:
self.x_attr = from_table(X_table)
self.y_attr = from_table(y_table)

return self

def predict(self, X):
# Checks if the estimator is fitted by verifying the presence of
# fitted attributes (ending with a trailing underscore).
check_is_fitted(self)
sua_iface, xp, _ = _get_sycl_namespace(X)
X_table = to_table(X)
if sua_iface:
returned_X = from_table(
X_table, sua_iface=sua_iface, sycl_queue=X.sycl_queue, xp=xp
)
else:
returned_X = from_table(X_table)

return returned_X
4 changes: 2 additions & 2 deletions sklearnex/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,6 @@
# limitations under the License.
# ===============================================================================

from .validation import _assert_all_finite
from .validation import assert_all_finite

__all__ = ["_assert_all_finite"]
__all__ = ["assert_all_finite"]
89 changes: 0 additions & 89 deletions sklearnex/utils/tests/test_finite.py

This file was deleted.

Loading

0 comments on commit 334c57b

Please sign in to comment.