Add neighbors algorithm based on NSW graphs #143

Open
wants to merge 1 commit into
base: main

LeoSvalov commented May 24, 2022

Good afternoon!

I would like to add an algorithm for approximate nearest neighbor search.

The method is based on navigable small world graphs (NSW graphs), which tend to perform better in high-dimensional data spaces [1] than the existing scikit-learn KDTree and BallTree methods, starting from data dimension D > 50.
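For readers unfamiliar with the idea, here is a minimal pure-Python sketch of an NSW-style index. It is a simplification of the approach in [1], not the code in this PR: the class name `TinyNSW` and the parameters `n_links` and `n_restarts` are hypothetical, and the search is a plain greedy walk with random restarts rather than the full multi-search procedure from the paper.

```python
import heapq
import math
import random


def euclidean(a, b):
    """Euclidean distance between two equal-length sequences."""
    return math.sqrt(sum((x - y) ** 2 for x, y in zip(a, b)))


class TinyNSW:
    """Toy NSW-style index (hypothetical; not the NSWGraph class from this PR)."""

    def __init__(self, n_links=5, n_restarts=3, seed=0):
        self.n_links = n_links        # links created for each inserted point
        self.n_restarts = n_restarts  # greedy walks per query
        self.rng = random.Random(seed)
        self.points = []
        self.edges = []               # edges[i] = set of neighbor ids of point i

    def _greedy_walk(self, q, start, seen):
        # Move to the neighbor closest to q until no neighbor improves.
        cur = start
        seen.setdefault(cur, euclidean(q, self.points[cur]))
        while True:
            for nb in self.edges[cur]:
                if nb not in seen:
                    seen[nb] = euclidean(q, self.points[nb])
            best = min(self.edges[cur], key=seen.__getitem__, default=cur)
            if seen[best] >= seen[cur]:
                return  # local minimum reached
            cur = best

    def query(self, q, k=1):
        """Return up to k (distance, index) pairs, closest first (approximate)."""
        if not self.points:
            return []
        seen = {}  # index -> distance for every point evaluated so far
        for _ in range(self.n_restarts):
            start = self.rng.randrange(len(self.points))
            self._greedy_walk(q, start, seen)
        return heapq.nsmallest(k, ((d, i) for i, d in seen.items()))

    def add(self, p):
        """Insert p and link it to its approximate nearest neighbors."""
        nearest = self.query(p, self.n_links)
        idx = len(self.points)
        self.points.append(p)
        self.edges.append(set())
        for _, nb in nearest:
            self.edges[idx].add(nb)
            self.edges[nb].add(idx)


# Build an index over random 16-dimensional points and query it.
rng = random.Random(42)
data = [[rng.random() for _ in range(16)] for _ in range(300)]
index = TinyNSW()
for p in data:
    index.add(p)
print(index.query(data[0], k=3))
```

Because points are inserted one at a time and linked to whatever neighbors the search finds, early links tend to span long distances and later ones short distances, which is what makes the graph navigable by a greedy walk.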

The API of the algorithm is very similar to the existing alternatives. In addition, NSWGraph can be used in the manner of a k-nearest-neighbors classifier, since it implements the base estimator paradigm (fit/predict).

Possible ways to use the method:

```python
from sklearn_extra.neighbors import NSWGraph
from sklearn.datasets import load_iris
import numpy as np
```

1. As an object to query the k nearest neighbors:

```python
rng = np.random.RandomState(10)
X = rng.random_sample((50, 128))
nswgraph = NSWGraph()
nswgraph.build(X)
X_val = rng.random_sample((5, 128))
dists, inds = nswgraph.query(X_val, k=3)
```

2. As a neighbors estimator that takes the target classes of the data into account:

```python
X, y = load_iris(return_X_y=True)
estimator = NSWGraph()
estimator.fit(X, y)
y_pred = estimator.predict(X)
```

References

[1] Malkov, Y., Ponomarenko, A., Logvinov, A., & Krylov, V. (2014).
Approximate nearest neighbor algorithm based on navigable small world graphs.
Information Systems, 45, 61-68.

ogrisel commented Jun 10, 2022

Thanks for the contribution! While I am not yet sure if it would meet a consensus of maintainers to be accepted in the scikit-learn code base, it would surely help to run some benchmarks.

If your PR can be shown to be approximately competitive in speed with alternative implementations, it would surely help convince maintainers that it is worth investing their time to review the PR and accept the long-term maintenance burden that comes with a new method.

Ideally the benchmarks could be based on this existing infrastructure:

In particular, I would be interested in a comparison with nswlib's implementation and with an alternative method not based on NSW graphs, such as https://github.com/lmcinnes/pynndescent.
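One way to frame such a comparison is to report query time together with recall against exact brute-force neighbors, since an approximate method can always be made faster by returning worse answers. A minimal sketch follows; the helper names `recall_at_k` and `timed` are hypothetical and belong to neither this PR nor any existing benchmark suite.

```python
import time


def recall_at_k(approx_ids, exact_ids):
    """Fraction of the exact nearest neighbors recovered, over all queries."""
    hits = sum(len(set(a) & set(e)) for a, e in zip(approx_ids, exact_ids))
    total = sum(len(e) for e in exact_ids)
    return hits / total


def timed(fn, *args, **kwargs):
    """Call fn and return (result, elapsed wall-clock seconds)."""
    start = time.perf_counter()
    result = fn(*args, **kwargs)
    return result, time.perf_counter() - start


# Two queries with k=3: the first approximate result misses one true neighbor.
exact = [[0, 1, 2], [3, 4, 5]]
approx = [[0, 2, 9], [3, 4, 5]]
recall, seconds = timed(recall_at_k, approx, exact)
print(f"recall@3 = {recall:.3f} (computed in {seconds:.6f} s)")
```

In an actual benchmark, `timed` would wrap each library's query call on the same data, and the recall would be computed from the neighbor indices each one returns.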

ogrisel commented Jun 13, 2022

I just realised that this is not the scikit-learn/scikit-learn repo but the scikit-learn-extra repo: I arrived at this PR from the scikit-learn/scikit-learn#23450 issue in the main scikit-learn issue tracker.

I think it would be great to have an implementation of NSW nearest neighbors in scikit-learn-extra. But before reviewing this PR, I would like to see some performance benchmark results as requested above.
