Skip to content

Commit

Permalink
implement and use scipy-ckdtree as default (faster than kdtree)
Browse files Browse the repository at this point in the history
  • Loading branch information
naspert committed Mar 21, 2018
1 parent 4a4d597 commit 1309e92
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 9 deletions.
40 changes: 37 additions & 3 deletions pygsp/graphs/nngraphs/nngraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,11 @@
'manhattan': 1,
'max_dist': np.inf
},
'scipy-ckdtree': {
'euclidean': 2,
'manhattan': 1,
'max_dist': np.inf
},
'scipy-pdist' : {
'euclidean': 'euclidean',
'manhattan': 'cityblock',
Expand All @@ -44,6 +49,13 @@ def _knn_sp_kdtree(X, num_neighbors, dist_type, order=0):
p=_dist_translation['scipy-kdtree'][dist_type])
return NN, D

def _knn_sp_ckdtree(X, num_neighbors, dist_type, order=0):
kdt = sps.cKDTree(X)
D, NN = kdt.query(X, k=(num_neighbors + 1),
p=_dist_translation['scipy-ckdtree'][dist_type])
return NN, D


def _knn_flann(X, num_neighbors, dist_type, order):
# the combination FLANN + max_dist produces incorrect results
# do not allow it
Expand All @@ -66,9 +78,27 @@ def _knn_flann(X, num_neighbors, dist_type, order):
def _radius_sp_kdtree(X, epsilon, dist_type, order=0):
kdt = sps.KDTree(X)
D, NN = kdt.query(X, k=None, distance_upper_bound=epsilon,
p=_dist_translation['scipy-kdtree'][dist_type])
p=_dist_translation['scipy-kdtree'][dist_type])
return NN, D

def _radius_sp_ckdtree(X, epsilon, dist_type, order=0):
N, dim = np.shape(X)
kdt = sps.cKDTree(X)
nn = kdt.query_ball_point(X, r=epsilon,
p=_dist_translation['scipy-ckdtree'][dist_type])
D = []
NN = []
for k in range(N):
x = np.matlib.repmat(X[k, :], len(nn[k]), 1)
d = np.linalg.norm(x - X[nn[k], :],
ord=_dist_translation['scipy-ckdtree'][dist_type],
axis=1)
nidx = d.argsort()
NN.append(np.take(nn[k], nidx))
D.append(np.sort(d))
return NN, D


def _knn_sp_pdist(X, num_neighbors, dist_type, order):
pd = sps.distance.squareform(
sps.distance.pdist(X,
Expand Down Expand Up @@ -142,7 +172,8 @@ class NNGraph(Graph):
is 'knn').
backend : {'scipy-kdtree', 'scipy-pdist', 'flann'}
Type of the backend for graph construction.
- 'scipy-kdtree'(default) will use scipy.spatial.KDTree
- 'scipy-kdtree' will use scipy.spatial.KDTree
- 'scipy-ckdtree'(default) will use scipy.spatial.cKDTree
- 'scipy-pdist' will use scipy.spatial.distance.pdist (slowest but exact)
- 'flann' use Fast Library for Approximate Nearest Neighbors (FLANN)
center : bool, optional
Expand Down Expand Up @@ -183,7 +214,7 @@ class NNGraph(Graph):
"""

def __init__(self, Xin, NNtype='knn', backend='scipy-kdtree', center=True,
def __init__(self, Xin, NNtype='knn', backend='scipy-ckdtree', center=True,
rescale=True, k=10, sigma=0.1, epsilon=0.01, gtype=None,
plotting={}, symmetrize_type='average', dist_type='euclidean',
order=0, **kwargs):
Expand All @@ -197,15 +228,18 @@ def __init__(self, Xin, NNtype='knn', backend='scipy-kdtree', center=True,
self.sigma = sigma
self.epsilon = epsilon
_dist_translation['scipy-kdtree']['minkowski'] = order
_dist_translation['scipy-ckdtree']['minkowski'] = order

self._nn_functions = {
'knn': {
'scipy-kdtree': _knn_sp_kdtree,
'scipy-ckdtree': _knn_sp_ckdtree,
'scipy-pdist': _knn_sp_pdist,
'flann': _knn_flann
},
'radius': {
'scipy-kdtree': _radius_sp_kdtree,
'scipy-ckdtree': _radius_sp_ckdtree,
'scipy-pdist': _radius_sp_pdist,
'flann': _radius_flann
},
Expand Down
13 changes: 7 additions & 6 deletions pygsp/tests/test_graphs.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@ def test_set_coordinates(self):
def test_nngraph(self):
Xin = np.arange(90).reshape(30, 3)
dist_types = ['euclidean', 'manhattan', 'max_dist', 'minkowski']
backends = ['scipy-kdtree', 'scipy-pdist', 'flann']
backends = ['scipy-kdtree', 'scipy-ckdtree', 'scipy-pdist', 'flann']
order=3 # for minkowski, FLANN only accepts integer orders

for cur_backend in backends:
Expand All @@ -194,9 +194,10 @@ def test_nngraph(self):
NNtype='knn', backend=cur_backend,
dist_type=dist_type)
else:
graphs.NNGraph(Xin, NNtype='radius',
backend=cur_backend,
dist_type=dist_type, order=order)
if cur_backend != 'flann': #pyflann fails on radius query
graphs.NNGraph(Xin, NNtype='radius',
backend=cur_backend,
dist_type=dist_type, order=order)
graphs.NNGraph(Xin, NNtype='knn',
backend=cur_backend,
dist_type=dist_type, order=order)
Expand All @@ -208,9 +209,9 @@ def test_nngraph(self):
dist_type=dist_type)

def test_nngraph_consistency(self):
Xin = np.random.uniform(-5, 5, (60, 3))
Xin = np.arange(90).reshape(30, 3)
dist_types = ['euclidean', 'manhattan', 'max_dist', 'minkowski']
backends = ['scipy-kdtree', 'flann']
backends = ['scipy-kdtree', 'scipy-ckdtree', 'flann']
num_neighbors=4
epsilon=0.1

Expand Down

0 comments on commit 1309e92

Please sign in to comment.