Skip to content

Commit

Permalink
Remove usage of NotSupportedByEngineError and let sklearn unit tests …
Browse files Browse the repository at this point in the history
…fallback on other engines
  • Loading branch information
fcharras committed Feb 14, 2023
1 parent 0b81200 commit 14cfec1
Showing 1 changed file with 13 additions and 28 deletions.
41 changes: 13 additions & 28 deletions sklearn_numba_dpex/kmeans/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
import sklearn
import sklearn.utils.validation as sklearn_validation
from sklearn.cluster._kmeans import KMeansCythonEngine
from sklearn.exceptions import NotSupportedByEngineError
from sklearn.utils import check_array, check_random_state
from sklearn.utils.validation import _is_arraylike_not_scalar

Expand Down Expand Up @@ -113,18 +112,11 @@ def __init__(self, estimator):
self._is_in_testing_mode = _is_in_testing_mode == "1"

def accepts(self, X, y, sample_weight):

if (algorithm := self.estimator.algorithm) not in ("lloyd", "auto", "full"):
if self._is_in_testing_mode:
raise NotSupportedByEngineError(
"The sklearn_nunmba_dpex engine for KMeans only support the Lloyd"
f" algorithm, {algorithm} is not supported."
)
else:
return False
if self.estimator.algorithm not in ("lloyd", "auto", "full"):
return False

if sp.issparse(X):
return self._is_in_testing_mode
return False

return True

Expand Down Expand Up @@ -308,23 +300,16 @@ def _validate_data(self, X, reset=True):
accepted_dtypes = [np.float32]

with _validate_with_array_api(device):
try:
X = self.estimator._validate_data(
X,
accept_sparse=False,
dtype=accepted_dtypes,
order=self.order,
copy=False,
reset=reset,
force_all_finite=True,
estimator=self.estimator,
)
return X
except TypeError as type_error:
if "A sparse matrix was passed, but dense data is required" in str(
type_error
):
raise NotSupportedByEngineError from type_error
return self.estimator._validate_data(
X,
accept_sparse=False,
dtype=accepted_dtypes,
order=self.order,
copy=False,
reset=reset,
force_all_finite=True,
estimator=self.estimator,
)

def _check_sample_weight(self, sample_weight, X):
"""Adapted from sklearn.utils.validation._check_sample_weight to be compatible
Expand Down

0 comments on commit 14cfec1

Please sign in to comment.