diff --git a/cupyx/scipy/spatial/distance.py b/cupyx/scipy/spatial/distance.py index 8d12090daf2..425f8b09274 100644 --- a/cupyx/scipy/spatial/distance.py +++ b/cupyx/scipy/spatial/distance.py @@ -7,7 +7,7 @@ cuvs_available = True except ImportError: try: - #cuVS distance primitives were previously in pylibraft + # cuVS distance primitives were previously in pylibraft from pylibraft.distance import pairwise_distance pylibraft_available = True except ImportError: @@ -110,12 +110,14 @@ def __init__(self, canonical_name=None, aka=None, _METRICS_NAMES = list(_METRICS.keys()) + def check_soft_dependencies(): if not cuvs_available: if not pylibraft_available: raise RuntimeError('cuVS >= 24.12 or pylibraft < ' '24.12 should be installed to use this feature') + def minkowski(u, v, p): """Compute the Minkowski distance between two 1-D arrays. @@ -496,7 +498,6 @@ def cdist(XA, XB, metric='euclidean', out=None, **kwargs): """ check_soft_dependencies() - if pylibraft_available or \ (cuvs_available and XA.dtype not in ['float32', 'float64']): XA = cupy.asarray(XA, dtype='float32')