Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add GraphSearcher::resume #185

Merged
merged 5 commits into from
Jan 10, 2024
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion UPGRADING.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@
it easy to make apples-to-apples comparisons with Lucene HNSW graphs.) So,
if you were building a graph of M=16 with JVector2, you should build it with M=32
with JVector3.
- `NodeSimilarity.ReRanker` api has changed. The interface is no longer parameterized,
- `NodeSimilarity.ReRanker` renamed to `Reranker`
- `NodeSimilarity.Reranker` api has changed. The interface is no longer parameterized,
and the `similarityTo` method no longer takes a Map parameter (provided by `search` with
the full vectors associated with the nodes returned). This is because we discovered that
(in contrast with the original DiskANN design) it is more performant to read vectors lazily
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -247,7 +247,8 @@ private void reconnectOrphanedNodes() {
var notSelfBits = createNotSelfBits(node);
var value = v1.get().vectorValue(node);
NodeSimilarity.ExactScoreFunction scoreFunction = i1 -> scoreBetween(v2.get().vectorValue(i1), value);
var result = gs.get().searchInternal(scoreFunction, null, beamWidth, 0.0f, graph.entry(), notSelfBits).getNodes();
int ep = graph.entry();
var result = gs.get().searchInternal(scoreFunction, null, beamWidth, 0.0f, 0.0f, ep, notSelfBits).getNodes();
// connect this node to the closest neighbor that hasn't already been used as a connection target
// (since this edge is likely to be the "worst" one in that target's neighborhood, it's likely to be
// overwritten by the next node to need reconnection if we don't enforce uniqueness)
Expand Down Expand Up @@ -325,7 +326,7 @@ public long addGraphNode(int node, RandomAccessVectorValues<T> vectors) {

var bits = new ExcludingBits(node);
// find best "natural" candidates with a beam search
var result = gs.get().searchInternal(scoreFunction, null, beamWidth, 0.0f, ep, bits);
var result = gs.get().searchInternal(scoreFunction, null, beamWidth, 0.0f, 0.0f, ep, bits);

// Update neighbors with these candidates.
// The DiskANN paper calls for using the entire set of visited nodes along the search path as
Expand Down Expand Up @@ -395,7 +396,7 @@ public void improveConnections(int node) {
int ep = graph.entry();
NodeSimilarity.ExactScoreFunction scoreFunction = i -> scoreBetween(vc.get().vectorValue(i), value);
var bits = new ExcludingBits(node);
var result = gs.get().searchInternal(scoreFunction, null, beamWidth, 0.0f, ep, bits);
var result = gs.get().searchInternal(scoreFunction, null, beamWidth, 0.0f, 0.0f, ep, bits);
var natural = toScratchCandidates(result.getNodes(), result.getNodes().length, naturalScratchPooled.get());
updateNeighbors(graph.getNeighbors(node), natural, NodeArray.EMPTY);
}
Expand Down Expand Up @@ -494,7 +495,8 @@ private void addNNDescentConnections(int node) {
{
var value = v1.get().vectorValue(node);
NodeSimilarity.ExactScoreFunction scoreFunction = i -> scoreBetween(v2.get().vectorValue(i), value);
var result = gs.get().searchInternal(scoreFunction, null, beamWidth, 0.0f, graph.entry(), notSelfBits);
int ep = graph.entry();
var result = gs.get().searchInternal(scoreFunction, null, beamWidth, 0.0f, 0.0f, ep, notSelfBits);
var candidates = toScratchCandidates(result.getNodes(), result.getNodes().length, scratch.get());
// We use just the topK results as candidates, which is much less expensive than computing scores for
// the other visited nodes. See comments in addGraphNode.
Expand Down Expand Up @@ -538,7 +540,8 @@ private int approximateMedioid() {

// search for the node closest to the centroid
NodeSimilarity.ExactScoreFunction scoreFunction = i -> scoreBetween(vc.get().vectorValue(i), (T) centroid);
var result = gs.get().searchInternal(scoreFunction, null, beamWidth, 0.0f, graph.entry(), Bits.ALL);
int ep = graph.entry();
var result = gs.get().searchInternal(scoreFunction, null, beamWidth, 0.0f, 0.0f, ep, Bits.ALL);
return result.getNodes()[0].node;
}
}
Expand Down

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,6 @@

package io.github.jbellis.jvector.graph;

import java.util.Map;

/** Encapsulates comparing node distances. */
public interface NodeSimilarity {
/** for one-off comparisons between nodes */
Expand Down Expand Up @@ -61,7 +59,7 @@ default boolean isExact() {
float similarityTo(int node2);
}

interface ReRanker {
interface Reranker {
float similarityTo(int node2);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@
import io.github.jbellis.jvector.graph.NodeSimilarity;
import io.github.jbellis.jvector.graph.RandomAccessVectorValues;
import io.github.jbellis.jvector.graph.SearchResult;
import io.github.jbellis.jvector.pq.BinaryQuantization;
import io.github.jbellis.jvector.pq.CompressedVectors;
import io.github.jbellis.jvector.pq.ProductQuantization;
import io.github.jbellis.jvector.pq.VectorCompressor;
Expand Down Expand Up @@ -166,7 +165,7 @@ private static ResultSummary performQueries(DataSet ds, RandomAccessVectorValues
if (cv != null) {
var view = index.getView();
NodeSimilarity.ApproximateScoreFunction sf = cv.approximateScoreFunctionFor(queryVector, ds.similarityFunction);
NodeSimilarity.ReRanker rr = (j) -> ds.similarityFunction.compare(queryVector, exactVv.vectorValue(j));
NodeSimilarity.Reranker rr = (j) -> ds.similarityFunction.compare(queryVector, exactVv.vectorValue(j));
sr = new GraphSearcher.Builder<>(view)
.build()
.search(sf, rr, efSearch, Bits.ALL);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -267,7 +267,7 @@ String search(String input, SessionContext ctx) {
SearchResult r;
if (ctx.cv != null) {
NodeSimilarity.ApproximateScoreFunction sf = ctx.cv.approximateScoreFunctionFor(queryVector, ctx.similarityFunction);
NodeSimilarity.ReRanker rr = (j) -> ctx.similarityFunction.compare(queryVector, ctx.ravv.vectorValue(j));
NodeSimilarity.Reranker rr = (j) -> ctx.similarityFunction.compare(queryVector, ctx.ravv.vectorValue(j));
r = new GraphSearcher.Builder<>(ctx.index.getView()).build().search(sf, rr, searchEf, Bits.ALL);
} else {
r = GraphSearcher.search(queryVector, topK, ctx.ravv, VectorEncoding.FLOAT32, ctx.similarityFunction, ctx.index, Bits.ALL);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ private static void testRecallInternal(GraphIndex<float[]> graph, RandomAccessVe
}
else {
NodeSimilarity.ApproximateScoreFunction sf = compressedVectors.approximateScoreFunctionFor(queryVector, VectorSimilarityFunction.EUCLIDEAN);
NodeSimilarity.ReRanker rr = (j) -> VectorSimilarityFunction.EUCLIDEAN.compare(queryVector, ravv.vectorValue(j));
NodeSimilarity.Reranker rr = (j) -> VectorSimilarityFunction.EUCLIDEAN.compare(queryVector, ravv.vectorValue(j));
nn = searcher.search(sf, rr, 100, Bits.ALL).getNodes();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@
import org.junit.Test;

import java.util.*;
import java.util.stream.Collectors;
import java.util.stream.Stream;

import static org.junit.Assert.*;
import static org.junit.jupiter.api.Assertions.assertThrows;
Expand Down Expand Up @@ -383,6 +385,20 @@ public void testRandom() {
assertTrue("overlap=" + overlap, overlap > 0.9);
}

protected NodeSimilarity.ExactScoreFunction getScoreFunction(T query, RandomAccessVectorValues<T> vectors) {
    // Build an exact scorer that compares the query against a candidate node's raw vector,
    // dispatching on the test's configured vector encoding.
    return node -> {
        switch (getVectorEncoding()) {
            case BYTE:
                return similarityFunction.compare((byte[]) query, (byte[]) vectors.vectorValue(node));
            case FLOAT32:
                return similarityFunction.compare((float[]) query, (float[]) vectors.vectorValue(node));
            default:
                throw new RuntimeException("Unsupported vector encoding: " + getVectorEncoding());
        }
    };
}

private int computeOverlap(int[] a, int[] b) {
Arrays.sort(a);
Arrays.sort(b);
Expand Down Expand Up @@ -526,7 +542,7 @@ static byte[][] createRandomByteVectors(int size, int dimension, Random random)
* Generate a random bitset where before startIndex all bits are set, and after startIndex each
* entry has a 2/3 probability of being set.
*/
private static Bits createRandomAcceptOrds(int startIndex, int length) {
protected static Bits createRandomAcceptOrds(int startIndex, int length) {
FixedBitSet bits = new FixedBitSet(length);
// all bits are set before startIndex
for (int i = 0; i < startIndex; i++) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ public void testThreshold() throws IOException {
for (int i = 0; i < 10; i++) {
TestParams tp = createTestParams(vectors);
searcher = new GraphSearcher.Builder<>(onDiskGraph.getView()).build();
NodeSimilarity.ReRanker reranker = (j) -> VectorSimilarityFunction.EUCLIDEAN.compare(tp.q, ravv.vectorValue(j));
NodeSimilarity.Reranker reranker = (j) -> VectorSimilarityFunction.EUCLIDEAN.compare(tp.q, ravv.vectorValue(j));
var asf = cv.approximateScoreFunctionFor(tp.q, VectorSimilarityFunction.EUCLIDEAN);
var result = searcher.search(asf, reranker, vectors.length, tp.th, Bits.ALL);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,18 @@

import com.carrotsearch.randomizedtesting.RandomizedTest;
import io.github.jbellis.jvector.TestUtil;
import io.github.jbellis.jvector.util.Bits;
import io.github.jbellis.jvector.util.FixedBitSet;
import io.github.jbellis.jvector.vector.VectorEncoding;
import io.github.jbellis.jvector.vector.VectorSimilarityFunction;
import org.junit.Before;
import org.junit.Test;

import java.util.Arrays;
import java.util.Comparator;
import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.Stream;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
Expand Down Expand Up @@ -112,4 +118,41 @@ public void testSearchWithSkewedAcceptOrds() {
// are closest to the query vector: sum(500,509) = 5045
assertTrue("sum(result docs)=" + sum, sum < 5100);
}

@Test
// Verify that search(k1) followed by resume(k2) is equivalent to a single search(k1 + k2)
// over the same graph. Float-specific: random byte vectors are far more likely to produce
// tied similarity scores, and ties can leave the resumed search's candidates queue ordered
// differently than the original search left it (evictedResults may re-enter out of order),
// which would break the equivalence this test depends on.
public void testResume() {
    final int size = 1000;
    final int dim = 2;
    var vectors = vectorValues(size, dim);
    var builder = new GraphIndexBuilder<>(vectors, getVectorEncoding(), similarityFunction, 20, 30, 1.0f, 1.4f);
    var graph = builder.build();
    Bits acceptOrds = getRandom().nextBoolean() ? Bits.ALL : createRandomAcceptOrds(0, size);

    int firstTopK = 10;
    int secondTopK = 15;
    var query = randomVector(dim);
    var searcher = new GraphSearcher.Builder<>(graph.getView()).build();

    // initial search yields exactly the requested top K
    var first = searcher.search(getScoreFunction(query, vectors), null, firstTopK, acceptOrds);
    assertEquals(firstTopK, first.getNodes().length);

    // resuming pulls the next batch without restarting from the entry point
    var second = searcher.resume(secondTopK);
    assertEquals(secondTopK, second.getNodes().length);

    // a from-scratch search for the combined K must visit and return the same totals
    var combined = searcher.search(getScoreFunction(query, vectors), null, firstTopK + secondTopK, acceptOrds);
    assertEquals(combined.getVisitedCount(), first.getVisitedCount() + second.getVisitedCount());
    assertEquals(combined.getNodes().length, first.getNodes().length + second.getNodes().length);

    // merge the two result batches, order by descending score, and compare score-by-score
    var mergedByScore = Stream.concat(Arrays.stream(first.getNodes()), Arrays.stream(second.getNodes()))
            .sorted(Comparator.comparingDouble(ns -> -ns.score))
            .collect(Collectors.toList());
    var combinedNodes = List.of(combined.getNodes());
    for (int i = 0; i < combinedNodes.size(); i++) {
        assertEquals(combinedNodes.get(i).score, mergedByScore.get(i).score, 1E-6);
    }
}
}