Skip to content

Commit

Permalink
Updating pylibraft pairwise_distance to cuvs (while still being backwards compatible with pylibraft when appropriate)
Browse files Browse the repository at this point in the history
  • Loading branch information
cjnolet committed Dec 19, 2024
1 parent 087edbb commit 16c96fa
Showing 1 changed file with 48 additions and 36 deletions.
84 changes: 48 additions & 36 deletions cupyx/scipy/spatial/distance.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,16 @@
import cupy

cuvs_available = False
try:
from pylibraft.distance import pairwise_distance
pylibraft_available = True
except ModuleNotFoundError:
pylibraft_available = False
from cuvs.distance import pairwise_distance
cuvs_available = True
except ImportError:
try:
# cuVS distance primitives were previously in pylibraft
from pylibraft.distance import pairwise_distance
cuvs_available = True
except ImportError:
cuvs_available = False


def _convert_to_type(X, out_type):
Expand Down Expand Up @@ -116,8 +123,8 @@ def minkowski(u, v, p):
Returns:
minkowski (double): The Minkowski distance between vectors `u` and `v`.
"""
if not pylibraft_available:
raise RuntimeError('pylibraft is not installed')
if not cuvs_available:
raise RuntimeError('cuVS is not installed')
u = cupy.asarray(u)
v = cupy.asarray(v)
output_arr = cupy.zeros((1,), dtype=u.dtype)
Expand All @@ -141,8 +148,8 @@ def canberra(u, v):
Returns:
canberra (double): The Canberra distance between vectors `u` and `v`.
"""
if not pylibraft_available:
raise RuntimeError('pylibraft is not installed')
if not cuvs_available:
raise RuntimeError('cuVS is not installed')
u = cupy.asarray(u)
v = cupy.asarray(v)
output_arr = cupy.zeros((1,), dtype=u.dtype)
Expand All @@ -166,8 +173,8 @@ def chebyshev(u, v):
Returns:
chebyshev (double): The Chebyshev distance between vectors `u` and `v`.
"""
if not pylibraft_available:
raise RuntimeError('pylibraft is not installed')
if not cuvs_available:
raise RuntimeError('cuVS is not installed')
u = cupy.asarray(u)
v = cupy.asarray(v)
output_arr = cupy.zeros((1,), dtype=u.dtype)
Expand All @@ -192,8 +199,8 @@ def cityblock(u, v):
cityblock (double): The City Block distance between
vectors `u` and `v`.
"""
if not pylibraft_available:
raise RuntimeError('pylibraft is not installed')
if not cuvs_available:
raise RuntimeError('cuVS is not installed')
u = cupy.asarray(u)
v = cupy.asarray(v)
output_arr = cupy.zeros((1,), dtype=u.dtype)
Expand Down Expand Up @@ -222,8 +229,8 @@ def correlation(u, v):
correlation (double): The correlation distance between
vectors `u` and `v`.
"""
if not pylibraft_available:
raise RuntimeError('pylibraft is not installed')
if not cuvs_available:
raise RuntimeError('cuVS is not installed')
u = cupy.asarray(u)
v = cupy.asarray(v)
output_arr = cupy.zeros((1,), dtype=u.dtype)
Expand All @@ -249,8 +256,8 @@ def cosine(u, v):
Returns:
cosine (double): The Cosine distance between vectors `u` and `v`.
"""
if not pylibraft_available:
raise RuntimeError('pylibraft is not installed')
if not cuvs_available:
raise RuntimeError('cuVS is not installed')
u = cupy.asarray(u)
v = cupy.asarray(v)
output_arr = cupy.zeros((1,), dtype=u.dtype)
Expand Down Expand Up @@ -278,8 +285,8 @@ def hamming(u, v):
Returns:
hamming (double): The Hamming distance between vectors `u` and `v`.
"""
if not pylibraft_available:
raise RuntimeError('pylibraft is not installed')
if not cuvs_available:
raise RuntimeError('cuVS is not installed')
u = cupy.asarray(u)
v = cupy.asarray(v)
output_arr = cupy.zeros((1,), dtype=u.dtype)
Expand All @@ -303,8 +310,8 @@ def euclidean(u, v):
Returns:
euclidean (double): The Euclidean distance between vectors `u` and `v`.
"""
if not pylibraft_available:
raise RuntimeError('pylibraft is not installed')
if not cuvs_available:
raise RuntimeError('cuVS is not installed')
u = cupy.asarray(u)
v = cupy.asarray(v)
output_arr = cupy.zeros((1,), dtype=u.dtype)
Expand Down Expand Up @@ -332,8 +339,8 @@ def jensenshannon(u, v):
jensenshannon (double): The Jensen-Shannon distance between
vectors `u` and `v`.
"""
if not pylibraft_available:
raise RuntimeError('pylibraft is not installed')
if not cuvs_available:
raise RuntimeError('cuVS is not installed')
u = cupy.asarray(u)
v = cupy.asarray(v)
output_arr = cupy.zeros((1,), dtype=u.dtype)
Expand Down Expand Up @@ -361,8 +368,8 @@ def russellrao(u, v):
Returns:
russellrao (double): The Russell-Rao distance between vectors `u` and `v`.
"""
if not pylibraft_available:
raise RuntimeError('pylibraft is not installed')
if not cuvs_available:
raise RuntimeError('cuVS is not installed')
u = cupy.asarray(u)
v = cupy.asarray(v)
output_arr = cupy.zeros((1,), dtype=u.dtype)
Expand All @@ -387,8 +394,8 @@ def sqeuclidean(u, v):
sqeuclidean (double): The squared Euclidean distance between
vectors `u` and `v`.
"""
if not pylibraft_available:
raise RuntimeError('pylibraft is not installed.')
if not cuvs_available:
raise RuntimeError('cuVS is not installed.')
u = cupy.asarray(u)
v = cupy.asarray(v)
output_arr = cupy.zeros((1,), dtype=u.dtype)
Expand All @@ -414,8 +421,8 @@ def hellinger(u, v):
hellinger (double): The Hellinger distance between
vectors `u` and `v`.
"""
if not pylibraft_available:
raise RuntimeError('pylibraft is not installed')
if not cuvs_available:
raise RuntimeError('cuVS is not installed')
u = cupy.asarray(u)
v = cupy.asarray(v)
output_arr = cupy.zeros((1,), dtype=u.dtype)
Expand All @@ -440,8 +447,8 @@ def kl_divergence(u, v):
kl_divergence (double): The Kullback-Leibler divergence between
vectors `u` and `v`.
"""
if not pylibraft_available:
raise RuntimeError('pylibraft is not installed')
if not cuvs_available:
raise RuntimeError('cuVS is not installed')
u = cupy.asarray(u)
v = cupy.asarray(v)
output_arr = cupy.zeros((1,), dtype=u.dtype)
Expand Down Expand Up @@ -479,10 +486,14 @@ def cdist(XA, XB, metric='euclidean', out=None, **kwargs):
``dist(u=XA[i], v=XB[j])`` is computed and stored in the
:math:`ij` th entry.
"""
if not pylibraft_available:
raise RuntimeError('pylibraft is not installed')
XA = cupy.asarray(XA, dtype='float32')
XB = cupy.asarray(XB, dtype='float32')
if not cuvs_available:
raise RuntimeError('cuVS is not installed')

if XA.dtype not in ['float16', 'float32', 'float64']:
raise ValueError('XA must be float16, float32, or float64')

if XB.dtype not in ['float16', 'float32', 'float64']:
raise ValueError('XB must be float16, float32, or float64')

s = XA.shape
sB = XB.shape
Expand All @@ -501,8 +512,9 @@ def cdist(XA, XB, metric='euclidean', out=None, **kwargs):
p = kwargs["p"] if "p" in kwargs else 2.0

if out is not None:
if out.dtype != 'float32':
out = out.astype('float32', copy=False)
if out.dtype not in ['float16', 'float32', 'float64']:
raise ValueError('out must be float16, float32, or float64')

if out.shape != (mA, mB):
cupy.resize(out, (mA, mB))
out[:] = 0.0
Expand Down

0 comments on commit 16c96fa

Please sign in to comment.