Skip to content

Commit

Permalink
Introduce agrona dependency for primitive data structures. Reduce boxing through these structures and some miscellaneous clean up.
Browse files Browse the repository at this point in the history
  • Loading branch information
jkni committed Feb 2, 2024
1 parent 1e4d637 commit 81eb0df
Show file tree
Hide file tree
Showing 8 changed files with 29 additions and 25 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,9 @@
import io.github.jbellis.jvector.graph.GraphIndex;
import io.github.jbellis.jvector.util.Accountable;
import io.github.jbellis.jvector.util.RamUsageEstimator;
import org.agrona.collections.Int2ObjectHashMap;

import java.io.IOException;
import java.util.HashMap;
import java.util.Map;

public abstract class GraphCache implements Accountable
{
Expand Down Expand Up @@ -65,12 +64,12 @@ public long ramBytesUsed()
private static final class HMGraphCache extends GraphCache
{
// Map is created on construction and never modified
private final Map<Integer, CachedNode> cache;
private final Int2ObjectHashMap<CachedNode> cache;
private long ramBytesUsed = 0;

public HMGraphCache(GraphIndex<float[]> graph, int distance) {
try (var view = graph.getView()) {
HashMap<Integer, CachedNode> tmpCache = new HashMap<>();
var tmpCache = new Int2ObjectHashMap<CachedNode>();
cacheNeighborsOf(tmpCache, view, view.entryNode(), distance);
// Assigning to a final field ensures it is safely published
cache = tmpCache;
Expand All @@ -79,17 +78,17 @@ public HMGraphCache(GraphIndex<float[]> graph, int distance) {
}
}

private void cacheNeighborsOf(HashMap<Integer, CachedNode> tmpCache, GraphIndex.View<float[]> view, int ordinal, int distance) {
private void cacheNeighborsOf(Int2ObjectHashMap<CachedNode> tmpCache, GraphIndex.View<float[]> view, int ordinal, int distance) {
// cache this node
var it = view.getNeighborsIterator(ordinal);
int[] neighbors = new int[it.size()];
int i = 0;
while (it.hasNext()) {
neighbors[i++] = it.next();
neighbors[i++] = it.nextInt();
}
var node = new CachedNode(view.getVector(ordinal), neighbors);
tmpCache.put(ordinal, node);
ramBytesUsed += RamUsageEstimator.HASHTABLE_RAM_BYTES_PER_ENTRY + RamUsageEstimator.sizeOf(node.vector) + RamUsageEstimator.sizeOf(node.neighbors);
ramBytesUsed += 4 + RamUsageEstimator.NUM_BYTES_OBJECT_REF + RamUsageEstimator.sizeOf(node.vector) + RamUsageEstimator.sizeOf(node.neighbors);

// call recursively on neighbors
if (distance > 0) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,13 @@
import io.github.jbellis.jvector.graph.RandomAccessVectorValues;
import io.github.jbellis.jvector.util.Accountable;
import io.github.jbellis.jvector.util.Bits;
import org.agrona.collections.Int2IntHashMap;

import java.io.DataOutput;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.HashMap;
import java.util.Map;
import java.util.stream.IntStream;

Expand Down Expand Up @@ -63,7 +63,7 @@ public OnDiskGraphIndex(ReaderSupplier readerSupplier, long offset)
*/
public static <T> Map<Integer, Integer> getSequentialRenumbering(GraphIndex<T> graph) {
try (var view = graph.getView()) {
Map<Integer, Integer> oldToNewMap = new HashMap<>();
Int2IntHashMap oldToNewMap = new Int2IntHashMap(-1);
int nextOrdinal = 0;
for (int i = 0; i < view.getIdUpperBound(); i++) {
if (graph.containsNode(i)) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,4 +34,4 @@ public RandomAccessReader get() {
public void close() {
smr.close();
}
};
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,11 @@
import io.github.jbellis.jvector.util.Bits;
import io.github.jbellis.jvector.util.DocIdSetIterator;
import io.github.jbellis.jvector.util.FixedBitSet;
import org.agrona.collections.IntHashSet;

import java.util.HashSet;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Function;
import java.util.function.IntFunction;

import static java.lang.Math.min;

Expand Down Expand Up @@ -93,7 +93,7 @@ public NodesIterator iterator() {
* For every neighbor X that this node Y connects to, add a reciprocal link from X to Y.
* If overflow is > 1.0, allow the number of neighbors to exceed maxConnections temporarily.
*/
public void backlink(Function<Integer, ConcurrentNeighborSet> neighborhoodOf, float overflow) {
public void backlink(IntFunction<ConcurrentNeighborSet> neighborhoodOf, float overflow) {
NodeArray neighbors = neighborsRef.get();
for (int i = 0; i < neighbors.size(); i++) {
int nbr = neighbors.node[i];
Expand Down Expand Up @@ -281,7 +281,7 @@ static NodeArray mergeNeighbors(NodeArray a1, NodeArray a2) {

// since nodes are only guaranteed to be sorted by score -- ties can appear in any node order --
// we need to remember all the nodes with the current score to avoid adding duplicates
var nodesWithLastScore = new HashSet<>();
var nodesWithLastScore = new IntHashSet();
float lastAddedScore = Float.NaN;

// loop through both source arrays, adding the highest score element to the merged array,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,10 @@
import io.github.jbellis.jvector.vector.VectorEncoding;
import io.github.jbellis.jvector.vector.VectorSimilarityFunction;
import io.github.jbellis.jvector.vector.VectorUtil;
import org.agrona.collections.IntArrayQueue;
import org.agrona.collections.IntHashSet;

import java.io.IOException;
import java.util.ArrayDeque;
import java.util.HashSet;
import java.util.Objects;
import java.util.Random;
import java.util.Set;
Expand Down Expand Up @@ -242,7 +242,7 @@ private void reconnectOrphanedNodes() {
var v1 = vectors.get();
var v2 = vectorsCopy.get())
{
var connectionTargets = new HashSet<Integer>();
var connectionTargets = new IntHashSet();
for (int node = 0; node < graph.getIdUpperBound(); node++) {
if (!connectedNodes.get(node) && graph.containsNode(node)) {
// search for the closest neighbors
Expand Down Expand Up @@ -271,17 +271,17 @@ private void reconnectOrphanedNodes() {
}

private void findConnected(AtomicFixedBitSet connectedNodes, int start) {
var queue = new ArrayDeque<Integer>();
var queue = new IntArrayQueue();
queue.add(start);
try (var view = graph.getView()) {
while (!queue.isEmpty()) {
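// DFS should result in less contention across findConnected threads than BFS
// NOTE(review): elements are added at the tail and removed from the head (FIFO), so this traversal is actually breadth-first; a LIFO stack would be needed for DFS -- confirm intent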
int next = queue.pop();
int next = queue.pollInt();
if (connectedNodes.getAndSet(next)) {
continue;
}
for (var it = view.getNeighborsIterator(next); it.hasNext(); ) {
queue.add(it.nextInt());
queue.addInt(it.nextInt());
}
}
} catch (Exception e) {
Expand Down Expand Up @@ -433,7 +433,7 @@ private long removeDeletedNodes() {

// remove deleted nodes from neighbor lists. If neighbor count drops below a minimum,
// add random connections to preserve connectivity
var affectedLiveNodes = new HashSet<Integer>();
var affectedLiveNodes = new IntHashSet();
var R = new Random();
try (var v1 = vectors.get();
var v2 = vectorsCopy.get())
Expand Down Expand Up @@ -556,7 +556,7 @@ private void updateNeighbors(ConcurrentNeighborSet neighbors, NodeArray natural,
neighbors.backlink(graph::getNeighbors, neighborOverflow);
}

private NodeArray toScratchCandidates(SearchResult.NodeScore[] candidates, int count, NodeArray scratch) {
private static NodeArray toScratchCandidates(SearchResult.NodeScore[] candidates, int count, NodeArray scratch) {
scratch.clear();
for (int i = 0; i < count; i++) {
var candidate = candidates[i];
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -136,9 +136,9 @@ public int[] nodesCopy() {

public SearchResult.NodeScore[] nodesCopy(NodeSimilarity.ExactScoreFunction sf, float rerankFloor) {
return IntStream.range(0, size())
.mapToObj(i -> heap.get(i + 1))
.mapToLong(i -> heap.get(i + 1))
.filter(m -> decodeScore(m) >= rerankFloor)
.map(m -> new SearchResult.NodeScore(decodeNodeId(m), sf.similarityTo(decodeNodeId(m))))
.mapToObj(m -> new SearchResult.NodeScore(decodeNodeId(m), sf.similarityTo(decodeNodeId(m))))
.toArray(SearchResult.NodeScore[]::new);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -409,7 +409,7 @@ public int prevSetBit(int i) {
}

/** Return the long bits at the given <code>i64</code> index. */
private long longBits(long index, long[] bits, int i64) {
private static long longBits(long index, long[] bits, int i64) {
if ((index & (1L << i64)) == 0) {
return 0L;
} else {
Expand Down
5 changes: 5 additions & 0 deletions pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,11 @@
<artifactId>commons-math3</artifactId>
<version>3.6.1</version>
</dependency>
<dependency>
<groupId>org.agrona</groupId>
<artifactId>agrona</artifactId>
<version>1.20.0</version>
</dependency>
</dependencies>
<dependencyManagement>
<dependencies>
Expand Down

0 comments on commit 81eb0df

Please sign in to comment.