Skip to content

Commit

Permalink
optimize reconnectOrphanedNodes
Browse files Browse the repository at this point in the history
  • Loading branch information
jbellis committed Feb 8, 2024
1 parent 926fc7f commit f648b67
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 27 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ public void backlink(IntFunction<ConcurrentNeighborSet> neighborhoodOf, float ov
* for efficiency. This method is threadsafe, but if you call it concurrently with other inserts,
* the limit may end up being exceeded again.
*/
public void cleanup() {
public void enforceDegree() {
neighborsRef.getAndUpdate(this::removeAllNonDiverse);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
import java.util.Objects;
import java.util.Random;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentSkipListSet;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.ThreadLocalRandom;
Expand Down Expand Up @@ -213,7 +214,7 @@ public void cleanup() {
parallelExecutor.submit(() -> IntStream.range(0, graph.getIdUpperBound()).parallel().forEach(i -> {
var neighbors = graph.getNeighbors(i);
if (neighbors != null) {
neighbors.cleanup();
neighbors.enforceDegree();
}
})).join();

Expand All @@ -226,6 +227,7 @@ public void cleanup() {
}

private void reconnectOrphanedNodes() {
var searchPathNeighbors = new ConcurrentHashMap<Integer, NodeArray>();
// It's possible that reconnecting one node will result in disconnecting another, since we are maintaining
// the maxConnections invariant. In an extreme case, reconnecting node X disconnects Y, and reconnecting
// Y disconnects X again. So we do a best effort of 3 loops.
Expand All @@ -240,34 +242,61 @@ private void reconnectOrphanedNodes() {

// reconnect unreachable nodes
var nReconnected = new AtomicInteger();
var gs = graphSearcher.get();
var v1 = vectors.get();
var v2 = vectorsCopy.get();
var connectionTargets = new IntHashSet();
for (int node = 0; node < graph.getIdUpperBound(); node++) {
if (!connectedNodes.get(node) && graph.containsNode(node)) {
// search for the closest neighbors
var connectionTargets = ConcurrentHashMap.<Integer>newKeySet();
simdExecutor.submit(() -> IntStream.range(0, graph.getIdUpperBound()).parallel().forEach(node -> {
if (connectedNodes.get(node) || !graph.containsNode(node)) {
return;
}
nReconnected.incrementAndGet();

// first, attempt to connect one of our own neighbors to us
var neighbors = graph.getNeighbors(node).getCurrent();
if (connectToClosestNeighbor(node, neighbors, connectionTargets)) {
return;
}

// no unused candidate found -- search for more neighbors and try again
neighbors = searchPathNeighbors.get(node);
if (neighbors == null) {
var gs = graphSearcher.get();
var v1 = vectors.get();
var v2 = vectorsCopy.get();

var notSelfBits = createNotSelfBits(node);
var value = v1.vectorValue(node);
NodeSimilarity.ExactScoreFunction scoreFunction = i1 -> scoreBetween(v2.vectorValue(i1), value);
int ep = graph.entry();
var result = gs.searchInternal(scoreFunction, null, beamWidth, 0.0f, 0.0f, ep, notSelfBits).getNodes();
// connect this node to the closest neighbor that hasn't already been used as a connection target
// (since this edge is likely to be the "worst" one in that target's neighborhood, it's likely to be
// overwritten by the next node to need reconnection if we don't enforce uniqueness)
for (var ns : result) {
if (connectionTargets.add(ns.node)) {
graph.getNeighbors(ns.node).insertNotDiverse(node, ns.score, true);
break;
}
}
nReconnected.incrementAndGet();
var result = gs.searchInternal(scoreFunction, null, beamWidth, 0.0f, 0.0f, ep, notSelfBits);
neighbors = new NodeArray(result.getNodes().length);
toScratchCandidates(result.getNodes(), neighbors);
searchPathNeighbors.put(node, neighbors);
}
}
connectToClosestNeighbor(node, neighbors, connectionTargets);
}));
if (nReconnected.get() == 0) {
break;
}
System.out.println("Pass " + i + " reconnected " + nReconnected.get() + " nodes");
}
}

/**
* Connect `node` to the closest neighbor that is not already a connection target.
* @return true if such a neighbor was found.
*/
private boolean connectToClosestNeighbor(int node, NodeArray neighbors, Set<Integer> connectionTargets) {
// connect this node to the closest neighbor that hasn't already been used as a connection target
// (since this edge is likely to be the "worst" one in that target's neighborhood, it's likely to be
// overwritten by the next node to need reconnection if we don't choose a unique target)
for (int i = 0; i < neighbors.size; i++) {
var neighborNode = neighbors.node[i];
var neighborScore = neighbors.score[i];
if (connectionTargets.add(neighborNode)) {
graph.getNeighbors(neighborNode).insertNotDiverse(neighborNode, neighborScore, true);
return true;
}
}
return false;
}

private void findConnected(AtomicFixedBitSet connectedNodes, int start) {
Expand Down Expand Up @@ -340,7 +369,7 @@ public long addGraphNode(int node, RandomAccessVectorValues<T> vectors) {
// this means that considering additional nodes from the search path, that are by definition
// farther away than the ones in the topK, would not change the result.)
// TODO if we made NeighborArray an interface we could wrap the NodeScore[] directly instead of copying
var natural = toScratchCandidates(result.getNodes(), result.getNodes().length, naturalScratchPooled);
var natural = toScratchCandidates(result.getNodes(), naturalScratchPooled);
var concurrent = getConcurrentCandidates(node, inProgressBefore, concurrentScratchPooled, vectors, vc);
updateNeighbors(newNodeNeighbors, natural, concurrent);

Expand Down Expand Up @@ -399,7 +428,7 @@ public void improveConnections(int node) {
NodeSimilarity.ExactScoreFunction scoreFunction = i -> scoreBetween(vc.vectorValue(i), value);
var bits = new ExcludingBits(node);
var result = gs.searchInternal(scoreFunction, null, beamWidth, 0.0f, 0.0f, ep, bits);
var natural = toScratchCandidates(result.getNodes(), result.getNodes().length, naturalScratchPooled);
var natural = toScratchCandidates(result.getNodes(), naturalScratchPooled);
updateNeighbors(graph.getNeighbors(node), natural, NodeArray.EMPTY);
}

Expand Down Expand Up @@ -495,7 +524,7 @@ private void addNNDescentConnections(int node) {
NodeSimilarity.ExactScoreFunction scoreFunction = i -> scoreBetween(v2.vectorValue(i), value);
int ep = graph.entry();
var result = gs.searchInternal(scoreFunction, null, beamWidth, 0.0f, 0.0f, ep, notSelfBits);
var candidates = toScratchCandidates(result.getNodes(), result.getNodes().length, scratch);
var candidates = toScratchCandidates(result.getNodes(), scratch);
updateNeighbors(graph.getNeighbors(node), candidates, NodeArray.EMPTY);
}

Expand Down Expand Up @@ -541,10 +570,9 @@ private void updateNeighbors(ConcurrentNeighborSet neighbors, NodeArray natural,
neighbors.backlink(graph::getNeighbors, neighborOverflow);
}

private static NodeArray toScratchCandidates(SearchResult.NodeScore[] candidates, int count, NodeArray scratch) {
private static NodeArray toScratchCandidates(SearchResult.NodeScore[] candidates, NodeArray scratch) {
scratch.clear();
for (int i = 0; i < count; i++) {
var candidate = candidates[i];
for (var candidate : candidates) {
scratch.addInOrder(candidate.node, candidate.score);
}
return scratch;
Expand Down

0 comments on commit f648b67

Please sign in to comment.