[experiment] ENH: using only raw inputs for onedal backend #2153
Draft
samir-nasibli wants to merge 65 commits into uxlfoundation:main from samir-nasibli:enh/raw_inputs
Changes from 20 commits
Commits
daed528 ENH: using only raw inputs for onedal backend (samir-nasibli)
1be2ffb minor fix (samir-nasibli)
a23b677 lin (samir-nasibli)
664e140 fix usw_raw_input True/False with dpctl tensor on device (ahuber21)
518dceb Add hacks to kmeans (ahuber21)
df9d930 Basic statistics online (samir-nasibli)
2954913 Merge branch 'enh/raw_inputs' of https://github.com/samir-nasibli/sci… (samir-nasibli)
3ef345c Covariance support (ethanglaser)
f1c9233 Merge branch 'enh/raw_inputs' of https://github.com/samir-nasibli/sci… (ethanglaser)
66d7b2d DBSCAN support (samir-nasibli)
c5d26a4 Merge branch 'enh/raw_inputs' of https://github.com/samir-nasibli/sci… (samir-nasibli)
1350c10 minor fix for dbscan (samir-nasibli)
8aaaa70 minor fix for DBSCAN (samir-nasibli)
f0d92ae Apply raw input for batch linear and logistic regression (Alexsandruss)
3b58beb Apply linters (Alexsandruss)
d7f2c3c fix for DBSCAN (samir-nasibli)
1aca420 support for Random Forest (samir-nasibli)
362930a PCA support (batch) (ethanglaser)
bc37391 Merge branch 'enh/raw_inputs' of https://github.com/samir-nasibli/sci… (ethanglaser)
102dcae minor fix for dbscan and rf (samir-nasibli)
6edab5b fully fixed DBSCAN (samir-nasibli)
e153a28 Add Incremental Linear Regression (Alexsandruss)
37d32c9 Linting (Alexsandruss)
71c5135 add modification to knn (ahuber21)
db9f021 minor update for RF (samir-nasibli)
bc353da fix for RandomForestClassifier (samir-nasibli)
e873205 minor for RF (samir-nasibli)
fe3222a Update online algos (olegkkruglov)
5b3ad17 Merge branch 'enh/raw_inputs' of https://github.com/samir-nasibli/sci… (samir-nasibli)
eaaab32 fix for RF regressor (samir-nasibli)
a7f0c2d fix workaround for knn (ahuber21)
d9a2966 kmeans predict support (ethanglaser)
3562c69 Merge remote-tracking branch 'origin/main' into enh/raw_inputs (ahuber21)
42c3614 fix merge errors (ahuber21)
53bcc7b fix some tests (ahuber21)
9964c5a fixup (ahuber21)
84afb62 undo more changes that broke tests (ahuber21)
cf5b736 format (ahuber21)
92393b9 restore original behavior when running without raw inputs (ahuber21)
13471e5 restore original behavior when running without raw inputs (ahuber21)
a8f3f19 align code (ahuber21)
2b07c00 restore original from_table (ahuber21)
6104736 add use_raw_input tests for incremental covariance (ahuber21)
df03233 Add basic statistics testing (ahuber21)
8a166b7 add incremental basic statistics (ahuber21)
fb5f5fa add dbscan (ahuber21)
7072041 Merge remote-tracking branch 'origin/main' into dev/ahuber/raw-inputs… (ahuber21)
91384ed add kmeans (ahuber21)
6dec57d add covariance (ahuber21)
529a7b8 align get_config() import and use_raw_input retrieval (ahuber21)
9f78cbd add incremental_pca (ahuber21)
658ccc1 add pca (ahuber21)
5e74a54 add incremental linear (ahuber21)
dfbf223 add linear_model (ahuber21)
c4094fb Merge branch 'dev/ahuber/raw-inputs-dispatching' into enh/raw_inputs (ahuber21)
bb5206f raw inputs updates for functional forest predict (ethanglaser)
8211a23 fixes for logreg predict_proba, knnreg, inc cov, inc pca (ethanglaser)
e3425bf dbscan + inc linreg changes (ethanglaser)
0630bc1 Merge 'upstream/main' into enh/raw_inputs (ethanglaser)
52ba18a black (ethanglaser)
90b7175 temporary for CI (ethanglaser)
f4d18cd isorted (ethanglaser)
d84a559 tuple indices safeguarding (ethanglaser)
2daeeb7 incremental bs fit fixes (ethanglaser)
fb3d0bc dbscan CI fixes (ethanglaser)
@@ -32,10 +32,12 @@
 from sklearn.metrics.pairwise import euclidean_distances
 from sklearn.utils import check_random_state
 
+from .._config import _get_config
 from ..common._base import BaseEstimator as onedal_BaseEstimator
 from ..common._mixin import ClusterMixin, TransformerMixin
 from ..datatypes import _convert_to_supported, from_table, to_table
 from ..utils import _check_array, _is_arraylike_not_scalar, _is_csr
+from ..utils._array_api import _get_sycl_namespace
 
 
 class _BaseKMeans(onedal_BaseEstimator, TransformerMixin, ClusterMixin, ABC):
@@ -80,7 +82,7 @@ def _get_kmeans_init(self, cluster_count, seed, algorithm):
     def _get_basic_statistics_backend(self, result_options):
         return BasicStatistics(result_options)
 
-    def _tolerance(self, X_table, rtol, is_csr, policy, dtype):
+    def _tolerance(self, X_table, rtol, is_csr, policy, dtype, sua_iface):
         """Compute absolute tolerance from the relative tolerance"""
         if rtol == 0.0:
             return rtol
@@ -94,7 +96,7 @@ def _tolerance(self, X_table, rtol, is_csr, policy, dtype):
         return mean_var * rtol
 
     def _check_params_vs_input(
-        self, X_table, is_csr, policy, default_n_init=10, dtype=np.float32
+        self, X_table, is_csr, policy, default_n_init=10, dtype=np.float32, sua_iface=None
     ):
         # n_clusters
         if X_table.shape[0] < self.n_clusters:
@@ -103,7 +105,7 @@ def _check_params_vs_input(
             )
 
         # tol
-        self._tol = self._tolerance(X_table, self.tol, is_csr, policy, dtype)
+        self._tol = self._tolerance(X_table, self.tol, is_csr, policy, dtype, sua_iface)
 
         # n-init
         # TODO(1.4): Remove
@@ -261,18 +263,33 @@ def _fit_backend(
         )
 
     def _fit(self, X, module, queue=None):
-        policy = self._get_policy(queue, X)
         is_csr = _is_csr(X)
-        X = _check_array(
-            X, dtype=[np.float64, np.float32], accept_sparse="csr", force_all_finite=False
-        )
+
+        use_raw_input = _get_config().get("use_raw_input") is True
+        if use_raw_input and _get_sycl_namespace(X)[0] is not None:
+            queue = X.sycl_queue
+
+        if not use_raw_input:
+            X = _check_array(
+                X,
+                dtype=[np.float64, np.float32],
+                accept_sparse="csr",
+                force_all_finite=False,
+            )
+
+        policy = self._get_policy(queue, X)
+
         X = _convert_to_supported(policy, X)
         dtype = get_dtype(X)
-        X_table = to_table(X)
+        sua_iface = _get_sycl_namespace(X)[0]
+        X_table = to_table(X, sua_iface=sua_iface)
 
-        self._check_params_vs_input(X_table, is_csr, policy, dtype=dtype)
+        self._check_params_vs_input(
+            X_table, is_csr, policy, dtype=dtype, sua_iface=sua_iface
+        )
 
-        params = self._get_onedal_params(is_csr, dtype)
+        # not used?
+        # params = self._get_onedal_params(is_csr, dtype)
 
         self.n_features_in_ = X_table.column_count
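The ordering that the `_fit` change introduces (read the `use_raw_input` flag, take the queue from the data when it lives on a device, and validate only when raw inputs are off) can be sketched in isolation. This is a minimal, standard-library-only illustration and not the project's code: `FakeDeviceArray`, `get_sycl_namespace`, `check_array`, `prepare_input`, and the plain `config` dict are hypothetical stand-ins for the dpctl-backed array, `_get_sycl_namespace`, `_check_array`, and `_get_config()` named in the diff.

```python
# Hypothetical stand-ins; none of these are the real onedal or dpctl objects.
class FakeDeviceArray:
    """Mimics an array that carries its own SYCL queue."""

    def __init__(self, data, sycl_queue):
        self.data = data
        self.sycl_queue = sycl_queue


def get_sycl_namespace(X):
    """Return (interface, namespace); interface is None for host data."""
    if hasattr(X, "sycl_queue"):
        return ("sua_iface-placeholder", None)
    return (None, None)


def check_array(X):
    """Stand-in for host-side validation: copies rows into a list of lists."""
    return [list(row) for row in X]


def prepare_input(X, config, queue=None):
    """Mirror the ordering in the diff: choose the queue first, then
    validate only when raw inputs are disabled."""
    use_raw_input = config.get("use_raw_input") is True
    if use_raw_input and get_sycl_namespace(X)[0] is not None:
        queue = X.sycl_queue  # take the queue from the data itself
    if not use_raw_input:
        X = check_array(X)  # validation path, skipped for raw inputs
    sua_iface = get_sycl_namespace(X)[0]
    return X, queue, sua_iface
```

Because the flag is compared with `is True`, only the boolean `True` enables the raw path; a truthy value such as the string `"yes"` still goes through validation.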
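The `_tolerance` helper touched by the diff turns a relative tolerance into an absolute one via `mean_var * rtol`, with an early return when `rtol` is zero. How `mean_var` is obtained is not visible in the hunks shown; the sketch below assumes it is the mean of the per-feature population variances, which is the convention scikit-learn's KMeans uses. The function name `absolute_tolerance` and the list-of-rows input are illustrative choices, not part of the PR.

```python
from statistics import fmean, pvariance


def absolute_tolerance(rows, rtol):
    """Scale a relative tolerance by the mean per-feature variance."""
    if rtol == 0.0:
        return rtol  # same early exit as in the diff
    columns = list(zip(*rows))  # transpose rows into feature columns
    # Assumption: population variance per feature, averaged over features.
    mean_var = fmean([pvariance(col) for col in columns])
    return mean_var * rtol


# Two samples, two features: variances are 1.0 and 0.0, so mean_var is 0.5.
tol = absolute_tolerance([[0.0, 10.0], [2.0, 10.0]], rtol=1e-4)
```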
Running the sklearnex example incremental_basic_statistics_dpctl.py leads to AttributeError: 'NoneType' object has no attribute 'ravel'.
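Python raises this particular error when `.ravel()` is called on a value that is still `None`. The following is a hedged sketch of a guard that fails earlier with a more descriptive message. `ravel_checked` and `ArrayStub` are hypothetical names that do not come from sklearnex, and the sketch does not identify which value was `None` in this PR.

```python
class ArrayStub:
    """Hypothetical stand-in for any array type exposing .ravel()."""

    def __init__(self, values):
        self.values = values

    def ravel(self):
        # Flatten one level of nesting into a single list.
        return [v for row in self.values for v in row]


def ravel_checked(value, name):
    """Call value.ravel(), but raise a descriptive error if value is None."""
    if value is None:
        raise ValueError(f"{name} is None; was it computed on this input path?")
    return value.ravel()
```

A check of this kind would not fix the underlying problem; it only names the value that was missing instead of surfacing a bare `AttributeError`.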