Skip to content

Commit

Permalink
Reduce repetition in Bench. Factor out more concise methods.
Browse files Browse the repository at this point in the history
  • Loading branch information
jkni committed Mar 7, 2024
1 parent 913b9ff commit 99992bf
Show file tree
Hide file tree
Showing 7 changed files with 106 additions and 63 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,11 @@ public void close() throws IOException {
graph.close();
}

/**
 * Describes this index as {@code CachingADCGraphIndex(graph=...)}, where the
 * placeholder is filled with the string form of the wrapped {@code graph}.
 */
@Override
public String toString() {
    final String template = "CachingADCGraphIndex(graph=%s)";
    return String.format(template, graph);
}

public class CachedView implements ADCView, ApproximateScoreProvider {
private final ADCView view;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,11 @@ public void close() throws IOException {
graph.close();
}

/**
 * Describes this index as {@code CachingGraphIndex(graph=...)}, where the
 * placeholder is filled with the string form of the wrapped {@code graph}.
 */
@Override
public String toString() {
    final String template = "CachingGraphIndex(graph=%s)";
    return String.format(template, graph);
}

private class CachedView implements View {
private final View view;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -240,6 +240,11 @@ public void close() throws IOException {
readerSupplier.close();
}

/**
 * Describes this index by its {@code size} and {@code entryNode} fields, in the
 * form {@code OnDiskADCGraphIndex(size=N, entryPoint=M)}.
 */
@Override
public String toString() {
    // Kept on String.format: %d applies locale-aware digit formatting.
    final String template = "OnDiskADCGraphIndex(size=%d, entryPoint=%d)";
    return String.format(template, size, entryNode);
}

/**
* @param graph the graph to write
* @param vectors the vectors associated with each node
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,11 @@ public void close() throws IOException {
readerSupplier.close();
}

/**
 * Describes this index by its {@code size} and {@code entryNode} fields, in the
 * form {@code OnDiskGraphIndex(size=N, entryPoint=M)}.
 */
@Override
public String toString() {
    // Kept on String.format: %d applies locale-aware digit formatting.
    final String template = "OnDiskGraphIndex(size=%d, entryPoint=%d)";
    return String.format(template, size, entryNode);
}

/**
* @param graph the graph to write
* @param vectors the vectors associated with each node
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -133,4 +133,12 @@ public int hashCode() {
result = 31 * result + Arrays.deepHashCode(compressedVectors);
return result;
}

/**
 * Describes these vectors as {@code BQVectors{bq=..., count=N}}, where
 * {@code count} is the length of the {@code compressedVectors} array.
 */
@Override
public String toString() {
    int count = compressedVectors.length;
    return "BQVectors{bq=" + bq + ", count=" + count + "}";
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -154,4 +154,12 @@ public long ramBytesUsed() {
long compressedVectorSize = RamUsageEstimator.sizeOf(compressedVectors[0]);
return codebooksSize + (compressedVectorSize * compressedVectors.length);
}

/**
 * Describes these vectors as {@code PQVectors{pq=..., count=N}}, where
 * {@code count} is the length of the {@code compressedVectors} array.
 */
@Override
public String toString() {
    int count = compressedVectors.length;
    return "PQVectors{pq=" + pq + ", count=" + count + "}";
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -28,24 +28,22 @@
import io.github.jbellis.jvector.graph.GraphIndex;
import io.github.jbellis.jvector.graph.GraphIndexBuilder;
import io.github.jbellis.jvector.graph.GraphSearcher;
import io.github.jbellis.jvector.graph.ListRandomAccessVectorValues;
import io.github.jbellis.jvector.graph.NodeSimilarity;
import io.github.jbellis.jvector.graph.RandomAccessVectorValues;
import io.github.jbellis.jvector.graph.SearchResult;
import io.github.jbellis.jvector.pq.CompressedVectors;
import io.github.jbellis.jvector.pq.PQVectors;
import io.github.jbellis.jvector.pq.ProductQuantization;
import io.github.jbellis.jvector.pq.VectorCompressor;
import io.github.jbellis.jvector.util.Bits;
import io.github.jbellis.jvector.vector.VectorSimilarityFunction;
import io.github.jbellis.jvector.vector.VectorizationProvider;
import io.github.jbellis.jvector.vector.types.VectorTypeSupport;
import io.github.jbellis.jvector.vector.types.VectorFloat;

import java.io.BufferedOutputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.IdentityHashMap;
import java.util.List;
Expand All @@ -61,44 +59,39 @@
* Tests GraphIndexes against vectors from various datasets
*/
public class Bench {
private static final VectorTypeSupport vectorTypeSupport = VectorizationProvider.getInstance().getVectorTypeSupport();
private static void testRecall(int M,
private static void testIndexParams(int M,
int efConstruction,
List<Function<DataSet, VectorCompressor<?>>> compressionGrid,
List<Integer> efSearchOptions,
DataSet ds,
Path testDirectory) throws IOException
{
var floatVectors = new ListRandomAccessVectorValues(ds.baseVectors, ds.baseVectors.get(0).length());
var topK = ds.groundTruth.get(0).size();

var start = System.nanoTime();
var floatVectors = ds.getBaseRavv();
var builder = new GraphIndexBuilder(floatVectors, ds.similarityFunction, M, efConstruction, 1.2f, 1.2f);
var start = System.nanoTime();
var onHeapGraph = builder.build();
System.out.format("Build M=%d ef=%d in %.2fs with avg degree %.2f and %.2f short edges%n",
M, efConstruction, (System.nanoTime() - start) / 1_000_000_000.0, onHeapGraph.getAverageDegree(), onHeapGraph.getAverageShortEdges());

var graphPath = testDirectory.resolve("graph" + M + efConstruction + ds.name);
var fusedGraphPath = testDirectory.resolve("fusedgraph" + M + efConstruction + ds.name);
try {
try (var outputStream = new DataOutputStream(new BufferedOutputStream(Files.newOutputStream(graphPath)))) {
OnDiskGraphIndex.write(onHeapGraph, floatVectors, outputStream);
}

for (var cf : compressionGrid) {
var compressor = getCompressor(cf, ds);
CompressedVectors cv;
CompressedVectors cv = null;
var fusedCompatible = compressor instanceof ProductQuantization && ((ProductQuantization) compressor).getClusterCount() == 32;
if (compressor == null) {
cv = null;
try (var outputStream = new DataOutputStream(new BufferedOutputStream(Files.newOutputStream(graphPath)))) {
OnDiskGraphIndex.write(onHeapGraph, floatVectors, outputStream);
}
System.out.format("Uncompressed vectors%n");
} else {
start = System.nanoTime();
var quantizedVectors = compressor.encodeAll(ds.baseVectors);
cv = compressor.createCompressedVectors(quantizedVectors);
System.out.format("%s encoded %d vectors [%.2f MB] in %.2fs%n", compressor, ds.baseVectors.size(), (cv.ramBytesUsed() / 1024f / 1024f), (System.nanoTime() - start) / 1_000_000_000.0);
try (var outputStream = new DataOutputStream(new BufferedOutputStream(Files.newOutputStream(graphPath)))) {
OnDiskGraphIndex.write(onHeapGraph, floatVectors, outputStream);
}

if (fusedCompatible) {
try (var outputStream = new DataOutputStream(new BufferedOutputStream(Files.newOutputStream(fusedGraphPath)))) {
OnDiskADCGraphIndex.write(onHeapGraph, floatVectors, (PQVectors) cv, outputStream);
Expand All @@ -108,29 +101,17 @@ private static void testRecall(int M,

try (var onDiskGraph = new CachingGraphIndex(new OnDiskGraphIndex(ReaderSupplierFactory.open(graphPath), 0));
var onDiskFusedGraph = fusedCompatible ? new CachingADCGraphIndex(new OnDiskADCGraphIndex(ReaderSupplierFactory.open(fusedGraphPath), 0)) : null) {
int queryRuns = 2;
for (int overquery : efSearchOptions) {
if (compressor == null) {
// include both in-memory and on-disk search of uncompressed vectors
start = System.nanoTime();
var pqr = performQueries(ds, floatVectors, cv, onHeapGraph, topK, topK * overquery, queryRuns);
var recall = ((double) pqr.topKFound) / (queryRuns * ds.queryVectors.size() * topK);
System.out.format(" Query %s top %d/%d recall %.4f in %.2fs after %,d nodes visited%n",
"(memory)", topK, overquery, recall, (System.nanoTime() - start) / 1_000_000_000.0, pqr.nodesVisited);
}
if (fusedCompatible) {
// include both fused and regular graphs for PQ if clusters == 32
start = System.nanoTime();
var pqr = performQueries(ds, floatVectors, cv, onDiskFusedGraph, topK, topK * overquery, queryRuns);
var recall = ((double) pqr.topKFound) / (queryRuns * ds.queryVectors.size() * topK);
System.out.format(" Query %s top %d/%d recall %.4f in %.2fs after %,d nodes visited%n",
"(fused)", topK, overquery, recall, (System.nanoTime() - start) / 1_000_000_000.0, pqr.nodesVisited);
}
start = System.nanoTime();
var pqr = performQueries(ds, floatVectors, cv, onDiskGraph, topK, topK * overquery, queryRuns);
var recall = ((double) pqr.topKFound) / (queryRuns * ds.queryVectors.size() * topK);
System.out.format(" Query %stop %d/%d recall %.4f in %.2fs after %,d nodes visited%n",
compressor == null ? "(disk) " : "", topK, overquery, recall, (System.nanoTime() - start) / 1_000_000_000.0, pqr.nodesVisited);
List<GraphIndex> graphs = new ArrayList<>();
graphs.add(onDiskGraph);
if (onDiskFusedGraph != null) {
graphs.add(onDiskFusedGraph);
}
if (cv == null) {
graphs.add(onHeapGraph); // if we have no cv, compare on-heap/on-disk with exact searches
}
for (var g : graphs) {
var cs = new ConfiguredSystem(ds, g, cv);
testConfiguration(cs, efSearchOptions);
}
}
}
Expand All @@ -140,6 +121,39 @@ private static void testRecall(int M,
}
}

/**
 * Runs the query workload against one configured system for every overquery
 * factor and prints recall, elapsed time and nodes visited for each.
 *
 * @param cs              the dataset / index / compressed-vectors combination to exercise
 * @param efSearchOptions overquery multipliers; each search uses {@code topK * overquery} as efSearch
 */
private static void testConfiguration(ConfiguredSystem cs, List<Integer> efSearchOptions) {
    // Each configuration is queried this many times; recall is averaged over all runs.
    final int queryRuns = 2;
    // topK is taken from the size of the first ground-truth entry.
    final var topK = cs.ds.groundTruth.get(0).size();

    System.out.format("Using %s:%n", cs.index);
    for (int overquery : efSearchOptions) {
        var startNanos = System.nanoTime();
        var summary = performQueries(cs, topK, topK * overquery, queryRuns);

        int possibleHits = queryRuns * cs.ds.queryVectors.size() * topK;
        double recall = ((double) summary.topKFound) / possibleHits;
        double elapsedSeconds = (System.nanoTime() - startNanos) / 1_000_000_000.0;

        System.out.format(" Query top %d/%d recall %.4f in %.2fs after %,d nodes visited%n",
                          topK, overquery, recall, elapsedSeconds, summary.nodesVisited);
    }
}

/**
 * Bundles the three things a benchmark run needs together: the dataset, the
 * graph index under test, and the compressed vectors (which may be null, as
 * callers in this file check {@code cv != null} before approximate search).
 */
static class ConfiguredSystem {
    DataSet ds;
    GraphIndex index;
    CompressedVectors cv;

    ConfiguredSystem(DataSet ds, GraphIndex index, CompressedVectors cv) {
        this.ds = ds;
        this.index = index;
        this.cv = cv;
    }

    /**
     * Picks the approximate score function for a query.
     *
     * @param queryVector the query to score against
     * @param view        a view of {@link #index}; cast to
     *                    {@code CachingADCGraphIndex.CachedView} when the index is a
     *                    {@code CachingADCGraphIndex}
     * @return the score function supplied by the ADC view when the index is a
     *         {@code CachingADCGraphIndex}, otherwise the one supplied by {@link #cv}
     */
    public NodeSimilarity.ApproximateScoreFunction approximateScoreFunctionFor(VectorFloat<?> queryVector, GraphIndex.View view) {
        if (!(index instanceof CachingADCGraphIndex)) {
            // Any other index type: scoring comes from the compressed vectors.
            return cv.approximateScoreFunctionFor(queryVector, ds.similarityFunction);
        }
        var adcView = (CachingADCGraphIndex.CachedView) view;
        return adcView.approximateScoreFunctionFor(queryVector, ds.similarityFunction);
    }
}

// avoid recomputing the compressor repeatedly (this is a relatively small memory footprint)
private static final Map<Function<DataSet, VectorCompressor<?>>, VectorCompressor<?>> cachedCompressors = new IdentityHashMap<>();
private static VectorCompressor<?> getCompressor(Function<DataSet, VectorCompressor<?>> cf, DataSet ds) {
Expand Down Expand Up @@ -178,34 +192,29 @@ private static long topKCorrect(int topK, SearchResult.NodeScore[] nn, Set<Integ
return topKCorrect(topK, a, gt);
}

private static ResultSummary performQueries(DataSet ds, RandomAccessVectorValues exactVv, CompressedVectors cv, GraphIndex index, int topK, int efSearch, int queryRuns) {
private static ResultSummary performQueries(ConfiguredSystem cs, int topK, int efSearch, int queryRuns) {
assert efSearch >= topK;
LongAdder topKfound = new LongAdder();
LongAdder nodesVisited = new LongAdder();
for (int k = 0; k < queryRuns; k++) {
IntStream.range(0, ds.queryVectors.size()).parallel().forEach(i -> {
var queryVector = ds.queryVectors.get(i);
IntStream.range(0, cs.ds.queryVectors.size()).parallel().forEach(i -> {
var queryVector = cs.ds.queryVectors.get(i);
SearchResult sr;
if (cv != null) {
try (var view = index.getView()) {
NodeSimilarity.ApproximateScoreFunction sf;
if (index instanceof CachingADCGraphIndex) {
sf = ((CachingADCGraphIndex.CachedView) view).approximateScoreFunctionFor(queryVector, ds.similarityFunction);
} else {
sf = cv.approximateScoreFunctionFor(queryVector, ds.similarityFunction);
}
var rr = NodeSimilarity.Reranker.from(queryVector, ds.similarityFunction, view);
if (cs.cv != null) {
try (var view = cs.index.getView()) {
NodeSimilarity.ApproximateScoreFunction sf = cs.approximateScoreFunctionFor(queryVector, view);
var rr = NodeSimilarity.Reranker.from(queryVector, cs.ds.similarityFunction, view);
sr = new GraphSearcher.Builder(view)
.build()
.search(sf, rr, efSearch, Bits.ALL);
} catch (Exception e) {
throw new RuntimeException(e);
}
} else {
sr = GraphSearcher.search(queryVector, efSearch, exactVv, ds.similarityFunction, index, Bits.ALL);
sr = GraphSearcher.search(queryVector, efSearch, cs.ds.getBaseRavv(), cs.ds.similarityFunction, cs.index, Bits.ALL);
}

var gt = ds.groundTruth.get(i);
var gt = cs.ds.groundTruth.get(i);
var n = topKCorrect(topK, sr.getNodes(), gt);
topKfound.add(n);
nodesVisited.add(sr.getVisitedCount());
Expand Down Expand Up @@ -247,8 +256,7 @@ public static void main(String[] args) throws IOException {
for (var nwDatasetName : nwFiles) {
if (pattern.matcher(nwDatasetName).find()) {
var mfd = DownloadHelper.maybeDownloadFvecs(nwDatasetName);
gridSearch(mfd.load(), compressionGrid, mGrid, efConstructionGrid, efSearchGrid);
cachedCompressors.clear();
gridSearchIndexParams(mfd.load(), compressionGrid, mGrid, efConstructionGrid, efSearchGrid);
}
}

Expand All @@ -267,8 +275,7 @@ public static void main(String[] args) throws IOException {
for (var f : hdf5Files) {
if (pattern.matcher(f).find()) {
DownloadHelper.maybeDownloadHdf5(f);
gridSearch(Hdf5Loader.load(f), compressionGrid, mGrid, efConstructionGrid, efSearchGrid);
cachedCompressors.clear();
gridSearchIndexParams(Hdf5Loader.load(f), compressionGrid, mGrid, efConstructionGrid, efSearchGrid);
}
}

Expand All @@ -277,12 +284,11 @@ public static void main(String[] args) throws IOException {
compressionGrid = Arrays.asList(null,
ds -> ProductQuantization.compute(ds.getBaseRavv(), ds.getDimension(), true));
var grid2d = DataSetCreator.create2DGrid(4_000_000, 10_000, 100);
gridSearch(grid2d, compressionGrid, mGrid, efConstructionGrid, efSearchGrid);
cachedCompressors.clear();
gridSearchIndexParams(grid2d, compressionGrid, mGrid, efConstructionGrid, efSearchGrid);
}
}

private static void gridSearch(DataSet ds,
private static void gridSearchIndexParams(DataSet ds,
List<Function<DataSet, VectorCompressor<?>>> compressionGrid,
List<Integer> mGrid,
List<Integer> efConstructionGrid,
Expand All @@ -292,11 +298,12 @@ private static void gridSearch(DataSet ds,
try {
for (int M : mGrid) {
for (int efC : efConstructionGrid) {
testRecall(M, efC, compressionGrid, efSearchFactor, ds, testDirectory);
testIndexParams(M, efC, compressionGrid, efSearchFactor, ds, testDirectory);
}
}
} finally {
Files.delete(testDirectory);
cachedCompressors.clear();
}
}
}

0 comments on commit 99992bf

Please sign in to comment.