From 81eb0dfe244c0b031ab1124060070037146202bd Mon Sep 17 00:00:00 2001 From: Joel Knighton Date: Fri, 2 Feb 2024 11:21:09 -0600 Subject: [PATCH] Introduce agrona dependency for primitive data structures. Reduce boxing through these structures and some miscellaneous clean up. --- .../github/jbellis/jvector/disk/GraphCache.java | 13 ++++++------- .../jbellis/jvector/disk/OnDiskGraphIndex.java | 4 ++-- .../jvector/disk/SimpleMappedReaderSupplier.java | 2 +- .../jvector/graph/ConcurrentNeighborSet.java | 8 ++++---- .../jbellis/jvector/graph/GraphIndexBuilder.java | 16 ++++++++-------- .../github/jbellis/jvector/graph/NodeQueue.java | 4 ++-- .../jbellis/jvector/util/SparseFixedBitSet.java | 2 +- pom.xml | 5 +++++ 8 files changed, 29 insertions(+), 25 deletions(-) diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/disk/GraphCache.java b/jvector-base/src/main/java/io/github/jbellis/jvector/disk/GraphCache.java index 484c70e6e..2e059bdfa 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/disk/GraphCache.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/disk/GraphCache.java @@ -19,10 +19,9 @@ import io.github.jbellis.jvector.graph.GraphIndex; import io.github.jbellis.jvector.util.Accountable; import io.github.jbellis.jvector.util.RamUsageEstimator; +import org.agrona.collections.Int2ObjectHashMap; import java.io.IOException; -import java.util.HashMap; -import java.util.Map; public abstract class GraphCache implements Accountable { @@ -65,12 +64,12 @@ public long ramBytesUsed() private static final class HMGraphCache extends GraphCache { // Map is created on construction and never modified - private final Map cache; + private final Int2ObjectHashMap cache; private long ramBytesUsed = 0; public HMGraphCache(GraphIndex graph, int distance) { try (var view = graph.getView()) { - HashMap tmpCache = new HashMap<>(); + var tmpCache = new Int2ObjectHashMap(); cacheNeighborsOf(tmpCache, view, view.entryNode(), distance); // Assigning to a final value ensure it is safely published cache = tmpCache; @@ -79,17 +78,17 @@ public HMGraphCache(GraphIndex graph, int distance) { } } - private void cacheNeighborsOf(HashMap tmpCache, GraphIndex.View view, int ordinal, int distance) { + private void cacheNeighborsOf(Int2ObjectHashMap tmpCache, GraphIndex.View view, int ordinal, int distance) { // cache this node var it = view.getNeighborsIterator(ordinal); int[] neighbors = new int[it.size()]; int i = 0; while (it.hasNext()) { - neighbors[i++] = it.next(); + neighbors[i++] = it.nextInt(); } var node = new CachedNode(view.getVector(ordinal), neighbors); tmpCache.put(ordinal, node); - ramBytesUsed += RamUsageEstimator.HASHTABLE_RAM_BYTES_PER_ENTRY + RamUsageEstimator.sizeOf(node.vector) + RamUsageEstimator.sizeOf(node.neighbors); + ramBytesUsed += 4 + RamUsageEstimator.NUM_BYTES_OBJECT_REF + RamUsageEstimator.sizeOf(node.vector) + RamUsageEstimator.sizeOf(node.neighbors); // call recursively on neighbors if (distance > 0) { diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/disk/OnDiskGraphIndex.java b/jvector-base/src/main/java/io/github/jbellis/jvector/disk/OnDiskGraphIndex.java index 89cf0ae43..ad7913e2d 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/disk/OnDiskGraphIndex.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/disk/OnDiskGraphIndex.java @@ -22,13 +22,13 @@ import io.github.jbellis.jvector.graph.RandomAccessVectorValues; import io.github.jbellis.jvector.util.Accountable; import io.github.jbellis.jvector.util.Bits; +import org.agrona.collections.Int2IntHashMap; import java.io.DataOutput; import java.io.IOException; import java.io.UncheckedIOException; import java.util.ArrayList; import java.util.Comparator; -import java.util.HashMap; import java.util.Map; import java.util.stream.IntStream; @@ -63,7 +63,7 @@ public OnDiskGraphIndex(ReaderSupplier readerSupplier, long offset) */ public static Map getSequentialRenumbering(GraphIndex graph) { try (var view = graph.getView()) { - Map oldToNewMap = new HashMap<>(); + Int2IntHashMap oldToNewMap = new Int2IntHashMap(-1); int nextOrdinal = 0; for (int i = 0; i < view.getIdUpperBound(); i++) { if (graph.containsNode(i)) { diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/disk/SimpleMappedReaderSupplier.java b/jvector-base/src/main/java/io/github/jbellis/jvector/disk/SimpleMappedReaderSupplier.java index 7466d49d4..41c2fe608 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/disk/SimpleMappedReaderSupplier.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/disk/SimpleMappedReaderSupplier.java @@ -34,4 +34,4 @@ public RandomAccessReader get() { public void close() { smr.close(); } -}; +} diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/ConcurrentNeighborSet.java b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/ConcurrentNeighborSet.java index c4955b43f..200f98475 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/ConcurrentNeighborSet.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/ConcurrentNeighborSet.java @@ -20,11 +20,11 @@ import io.github.jbellis.jvector.util.Bits; import io.github.jbellis.jvector.util.DocIdSetIterator; import io.github.jbellis.jvector.util.FixedBitSet; +import org.agrona.collections.IntHashSet; -import java.util.HashSet; import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicReference; -import java.util.function.Function; +import java.util.function.IntFunction; import static java.lang.Math.min; @@ -93,7 +93,7 @@ public NodesIterator iterator() { * For every neighbor X that this node Y connects to, add a reciprocal link from X to Y. * If overflow is > 1.0, allow the number of neighbors to exceed maxConnections temporarily. */ - public void backlink(Function neighborhoodOf, float overflow) { + public void backlink(IntFunction neighborhoodOf, float overflow) { NodeArray neighbors = neighborsRef.get(); for (int i = 0; i < neighbors.size(); i++) { int nbr = neighbors.node[i]; @@ -281,7 +281,7 @@ static NodeArray mergeNeighbors(NodeArray a1, NodeArray a2) { // since nodes are only guaranteed to be sorted by score -- ties can appear in any node order -- // we need to remember all the nodes with the current score to avoid adding duplicates - var nodesWithLastScore = new HashSet<>(); + var nodesWithLastScore = new IntHashSet(); float lastAddedScore = Float.NaN; // loop through both source arrays, adding the highest score element to the merged array, diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/GraphIndexBuilder.java b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/GraphIndexBuilder.java index bc77284da..59bc8c645 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/GraphIndexBuilder.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/GraphIndexBuilder.java @@ -25,10 +25,10 @@ import io.github.jbellis.jvector.vector.VectorEncoding; import io.github.jbellis.jvector.vector.VectorSimilarityFunction; import io.github.jbellis.jvector.vector.VectorUtil; +import org.agrona.collections.IntArrayQueue; +import org.agrona.collections.IntHashSet; import java.io.IOException; -import java.util.ArrayDeque; -import java.util.HashSet; import java.util.Objects; import java.util.Random; import java.util.Set; @@ -242,7 +242,7 @@ private void reconnectOrphanedNodes() { var v1 = vectors.get(); var v2 = vectorsCopy.get()) { - var connectionTargets = new HashSet(); + var connectionTargets = new IntHashSet(); for (int node = 0; node < graph.getIdUpperBound(); node++) { if (!connectedNodes.get(node) && graph.containsNode(node)) { // search for the closest neighbors @@ -271,17 +271,17 @@ private void reconnectOrphanedNodes() { } private void findConnected(AtomicFixedBitSet connectedNodes, int start) { - var queue = new ArrayDeque(); + var queue = new IntArrayQueue(); queue.add(start); try (var view = graph.getView()) { while (!queue.isEmpty()) { // DFS should result in less contention across findConnected threads than BFS - int next = queue.pop(); + int next = queue.pollInt(); if (connectedNodes.getAndSet(next)) { continue; } for (var it = view.getNeighborsIterator(next); it.hasNext(); ) { - queue.add(it.nextInt()); + queue.addInt(it.nextInt()); } } } catch (Exception e) { @@ -433,7 +433,7 @@ private long removeDeletedNodes() { // remove deleted nodes from neighbor lists. If neighbor count drops below a minimum, // add random connections to preserve connectivity - var affectedLiveNodes = new HashSet(); + var affectedLiveNodes = new IntHashSet(); var R = new Random(); try (var v1 = vectors.get(); var v2 = vectorsCopy.get()) @@ -556,7 +556,7 @@ private void updateNeighbors(ConcurrentNeighborSet neighbors, NodeArray natural, neighbors.backlink(graph::getNeighbors, neighborOverflow); } - private NodeArray toScratchCandidates(SearchResult.NodeScore[] candidates, int count, NodeArray scratch) { + private static NodeArray toScratchCandidates(SearchResult.NodeScore[] candidates, int count, NodeArray scratch) { scratch.clear(); for (int i = 0; i < count; i++) { var candidate = candidates[i]; diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/NodeQueue.java b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/NodeQueue.java index 1e04d5cf8..bf4a6ad78 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/NodeQueue.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/NodeQueue.java @@ -136,9 +136,9 @@ public int[] nodesCopy() { public SearchResult.NodeScore[] nodesCopy(NodeSimilarity.ExactScoreFunction sf, float rerankFloor) { return IntStream.range(0, size()) - .mapToObj(i -> heap.get(i + 1)) + .mapToLong(i -> heap.get(i + 1)) .filter(m -> decodeScore(m) >= rerankFloor) - .map(m -> new SearchResult.NodeScore(decodeNodeId(m), sf.similarityTo(decodeNodeId(m)))) + .mapToObj(m -> new SearchResult.NodeScore(decodeNodeId(m), sf.similarityTo(decodeNodeId(m)))) .toArray(SearchResult.NodeScore[]::new); } diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/util/SparseFixedBitSet.java b/jvector-base/src/main/java/io/github/jbellis/jvector/util/SparseFixedBitSet.java index b00af2909..262bde566 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/util/SparseFixedBitSet.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/util/SparseFixedBitSet.java @@ -409,7 +409,7 @@ public int prevSetBit(int i) { } /** Return the long bits at the given i64 index. */ - private long longBits(long index, long[] bits, int i64) { + private static long longBits(long index, long[] bits, int i64) { if ((index & (1L << i64)) == 0) { return 0L; } else { diff --git a/pom.xml b/pom.xml index 5f4ba48ab..bf456688e 100644 --- a/pom.xml +++ b/pom.xml @@ -178,6 +178,11 @@ commons-math3 3.6.1 + + org.agrona + agrona + 1.20.0 +