Skip to content

Commit

Permalink
fixup
Browse files Browse the repository at this point in the history
  • Loading branch information
ahuber21 committed Nov 18, 2024
1 parent a1b2d1d commit b8812fd
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 14 deletions.
12 changes: 6 additions & 6 deletions onedal/common/_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,22 +99,22 @@ def decorate_functions(
backend_method, backend_type=backend_type, name=formatted_name
)

# assign it to the class under the correct method name
destination = name_overrides.get(method, method)

# Log a message about what's happening
existing = getattr(cls, method, None)
is_abstract = getattr(existing, "__isabstractmethod__", False)
existing = getattr(cls, destination, None)
is_abstract = getattr(destination, "__isabstractmethod__", False)
if existing and not is_abstract:
method_type = getattr(existing, "backend_type", None)
logger.debug(
f"Replaced existing method '{method}'. Old backend {method_type}, new backend {backend_type}"
f"Replaced existing method '{destination}'. Old backend {method_type}, new backend {backend_type}"
)
else:
logger.debug(
f"Assigned method '{cls.__name__}.{method}' from {backend_type} backend"
)

# assign it to the class under the correct method name
destination = name_overrides.get(method, method)

setattr(cls, destination, wrapped_method)

if is_abstract:
Expand Down
16 changes: 8 additions & 8 deletions onedal/neighbors/neighbors.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,10 +81,6 @@ def train(self, *args): ...
@abstractmethod
def infer(self, *args): ...

# direct access to the backend model constructor
@abstractmethod
def model(self): ...

def _validate_data(
self, X, y=None, reset=True, validate_separately=None, **check_params
):
Expand Down Expand Up @@ -187,7 +183,6 @@ def _get_daal_params(self, data, n_neighbors=None):
return params


@bind_default_backend("neighbors.classification", ["train", "infer", "model"])
class NeighborsBase(NeighborsCommonBase, metaclass=ABCMeta):
def __init__(
self,
Expand Down Expand Up @@ -350,7 +345,7 @@ def _kneighbors(self, X=None, n_neighbors=None, return_distance=True, queue=None
type(self._onedal_model) is kdtree_knn_classification_model
or type(self._onedal_model) is bf_knn_classification_model
):
params = self._get_daal_params(X, n_neighbors=n_neighbors)
params = super()._get_daal_params(X, n_neighbors=n_neighbors)
prediction_results = self._onedal_predict(
self._onedal_model, X, params, queue=queue
)
Expand Down Expand Up @@ -412,6 +407,7 @@ def _kneighbors(self, X=None, n_neighbors=None, return_distance=True, queue=None
return neigh_ind


@bind_default_backend("neighbors.classification", ["train", "infer", "model"])
class KNeighborsClassifier(NeighborsBase, ClassifierMixin):
def __init__(
self,
Expand All @@ -434,8 +430,12 @@ def __init__(
)
self.weights = weights

# direct access to the backend model constructor
@abstractmethod
def model(self): ...

def _get_daal_params(self, data):
params = self._get_daal_params(data)
params = super()._get_daal_params(data)
params["resultsToEvaluate"] = "computeClassLabels"
params["resultsToCompute"] = ""
return params
Expand Down Expand Up @@ -590,7 +590,7 @@ def _get_onedal_params(self, X, y=None):
return params

def _get_daal_params(self, data):
params = self._get_daal_params(data)
params = super()._get_daal_params(data)
params["resultsToCompute"] = "computeIndicesOfNeighbors|computeDistances"
params["resultsToEvaluate"] = "none"
return params
Expand Down

0 comments on commit b8812fd

Please sign in to comment.