From 88307b38bc85a92e82c33e7c9fa5fe54cddf822b Mon Sep 17 00:00:00 2001 From: "Kruglov, Oleg" Date: Thu, 29 Aug 2024 15:15:00 -0700 Subject: [PATCH] Remove generate_weights --- .../test_incremental_basic_statistics_spmd.py | 3 +-- sklearnex/tests/_utils_spmd.py | 17 ++++++++--------- 2 files changed, 9 insertions(+), 11 deletions(-) diff --git a/sklearnex/spmd/basic_statistics/tests/test_incremental_basic_statistics_spmd.py b/sklearnex/spmd/basic_statistics/tests/test_incremental_basic_statistics_spmd.py index e190fb2e49..63060e4e9b 100644 --- a/sklearnex/spmd/basic_statistics/tests/test_incremental_basic_statistics_spmd.py +++ b/sklearnex/spmd/basic_statistics/tests/test_incremental_basic_statistics_spmd.py @@ -25,7 +25,6 @@ ) from sklearnex.tests._utils_spmd import ( _generate_statistic_data, - _generate_weights, _get_local_tensor, _mpi_libs_and_gpu_available, ) @@ -272,7 +271,7 @@ def test_incremental_basic_statistics_partial_fit_spmd_synthetic( if weighted: # Create weights array containing the weight for each sample in the data - weights = _generate_weights(n_samples, dtype=dtype) + weights = _generate_statistic_data(n_samples, dtype=dtype) local_weights = _get_local_tensor(weights) split_local_weights = np.array_split(local_weights, num_blocks) split_weights = np.array_split(weights, num_blocks) diff --git a/sklearnex/tests/_utils_spmd.py b/sklearnex/tests/_utils_spmd.py index 7fd6cb4e4e..4bdd4d4fd5 100644 --- a/sklearnex/tests/_utils_spmd.py +++ b/sklearnex/tests/_utils_spmd.py @@ -89,10 +89,16 @@ def _generate_classification_data( return X_train, X_test, y_train, y_test -def _generate_statistic_data(n_samples, n_features, dtype=np.float64, random_state=42): +def _generate_statistic_data( + n_samples, n_features=None, dtype=np.float64, random_state=42 +): # Generates statistical data gen = np.random.default_rng(random_state) - data = gen.uniform(low=-0.3, high=+0.7, size=(n_samples, n_features)).astype(dtype) + data = gen.uniform( + low=-0.3, + high=+0.7, + size=(n_samples, n_features) if n_features is not None else (n_samples,), + ).astype(dtype) return data @@ -111,13 +117,6 @@ def _generate_clustering_data( return X_train, X_test -def _generate_weights(n_samples, dtype=np.float64, random_state=42): - # Generates weights - gen = np.random.default_rng(random_state) - weights = gen.uniform(low=-0.3, high=+0.7, size=(n_samples)).astype(dtype) - return weights - - def _spmd_assert_allclose(spmd_result, batch_result, **kwargs): """Calls assert_allclose on spmd and batch results.