forked from opensearch-project/k-NN
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Initial commit for segment-free-vector-search
Signed-off-by: Navneet Verma <[email protected]>
- Loading branch information
Showing
8 changed files
with
302 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
86 changes: 86 additions & 0 deletions
86
src/main/java/org/opensearch/knn/index/query/KNNWeightV2.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,86 @@ | ||
/* | ||
* Copyright OpenSearch Contributors | ||
* SPDX-License-Identifier: Apache-2.0 | ||
*/ | ||
|
||
package org.opensearch.knn.index.query; | ||
|
||
import org.apache.lucene.index.LeafReaderContext; | ||
import org.apache.lucene.index.SegmentReader; | ||
import org.apache.lucene.search.Explanation; | ||
import org.apache.lucene.search.Scorer; | ||
import org.apache.lucene.search.Weight; | ||
import org.opensearch.common.lucene.Lucene; | ||
import org.opensearch.knn.service.OSLuceneDocId; | ||
|
||
import java.io.IOException; | ||
import java.util.Arrays; | ||
import java.util.Collections; | ||
import java.util.HashMap; | ||
import java.util.List; | ||
import java.util.Map; | ||
|
||
public class KNNWeightV2 extends Weight { | ||
|
||
private final KNNQuery knnQuery; | ||
private final List<OSLuceneDocId> osLuceneDocIds; | ||
|
||
public KNNWeightV2(KNNQuery knnQuery, List<OSLuceneDocId> osLuceneDocIds) { | ||
super(knnQuery); | ||
this.knnQuery = knnQuery; | ||
this.osLuceneDocIds = osLuceneDocIds; | ||
} | ||
|
||
/** | ||
* An explanation of the score computation for the named document. | ||
* | ||
* @param context the readers context to create the {@link Explanation} for. | ||
* @param doc the document's id relative to the given context's reader | ||
* @return an Explanation for the score | ||
* @throws IOException if an {@link IOException} occurs | ||
*/ | ||
@Override | ||
public Explanation explain(LeafReaderContext context, int doc) throws IOException { | ||
return null; | ||
} | ||
|
||
/** | ||
* Returns a {@link Scorer} which can iterate in order over all matching documents and assign them | ||
* a score. | ||
* | ||
* <p><b>NOTE:</b> null can be returned if no documents will be scored by this query. | ||
* | ||
* <p><b>NOTE</b>: The returned {@link Scorer} does not have {@link LeafReader#getLiveDocs()} | ||
* applied, they need to be checked on top. | ||
* | ||
* @param context the {@link LeafReaderContext} for which to return the | ||
* {@link Scorer}. | ||
* @return a {@link Scorer} which scores documents in/out-of order. | ||
* @throws IOException if there is a low-level I/O error | ||
*/ | ||
@Override | ||
public Scorer scorer(LeafReaderContext context) throws IOException { | ||
SegmentReader segmentReader = Lucene.segmentReader(context.reader()); | ||
byte[] segmentId = segmentReader.getSegmentInfo().info.getId(); | ||
Map<Integer, Float> docIdsToScoreMap = new HashMap<>(); | ||
for (OSLuceneDocId osLuceneDocId : osLuceneDocIds) { | ||
if (Arrays.equals(osLuceneDocId.getSegmentId(), segmentId)) { | ||
docIdsToScoreMap.put(osLuceneDocId.getSegmentDocId(), osLuceneDocId.getScore()); | ||
} | ||
} | ||
if (docIdsToScoreMap.isEmpty()) { | ||
return KNNScorer.emptyScorer(this); | ||
} | ||
final int maxDoc = Collections.max(docIdsToScoreMap.keySet()) + 1; | ||
return new KNNScorer(this, ResultUtil.resultMapToDocIds(docIdsToScoreMap, maxDoc), docIdsToScoreMap, 1); | ||
} | ||
|
||
/** | ||
* @param ctx | ||
* @return {@code true} if the object can be cached against a given leaf | ||
*/ | ||
@Override | ||
public boolean isCacheable(LeafReaderContext ctx) { | ||
return true; | ||
} | ||
} |
33 changes: 33 additions & 0 deletions
33
src/main/java/org/opensearch/knn/service/OSLuceneDocId.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
/* | ||
* Copyright OpenSearch Contributors | ||
* SPDX-License-Identifier: Apache-2.0 | ||
*/ | ||
|
||
package org.opensearch.knn.service; | ||
|
||
import lombok.Builder; | ||
import lombok.EqualsAndHashCode; | ||
import lombok.ToString; | ||
import lombok.Value; | ||
|
||
@Value | ||
@Builder | ||
@EqualsAndHashCode | ||
@ToString | ||
public class OSLuceneDocId { | ||
@Builder.Default | ||
String opensearchIndexName = "my-index"; | ||
byte[] segmentId; | ||
int segmentDocId; | ||
@Builder.Default | ||
float score = 0; | ||
|
||
public OSLuceneDocId cloneWithScore(float score) { | ||
return OSLuceneDocId.builder() | ||
.score(score) | ||
.segmentDocId(segmentDocId) | ||
.segmentId(segmentId) | ||
.opensearchIndexName(opensearchIndexName) | ||
.build(); | ||
} | ||
} |
105 changes: 105 additions & 0 deletions
105
src/main/java/org/opensearch/knn/service/VectorEngineService.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,105 @@ | ||
/* | ||
* Copyright OpenSearch Contributors | ||
* SPDX-License-Identifier: Apache-2.0 | ||
*/ | ||
|
||
package org.opensearch.knn.service; | ||
|
||
import lombok.AccessLevel; | ||
import lombok.Builder; | ||
import lombok.NoArgsConstructor; | ||
import lombok.Value; | ||
import lombok.extern.log4j.Log4j2; | ||
import org.opensearch.knn.index.SpaceType; | ||
|
||
import java.util.Arrays; | ||
import java.util.Collections; | ||
import java.util.LinkedList; | ||
import java.util.List; | ||
import java.util.Map; | ||
import java.util.PriorityQueue; | ||
import java.util.concurrent.ConcurrentHashMap; | ||
import java.util.concurrent.atomic.AtomicInteger; | ||
import java.util.stream.Collectors; | ||
|
||
@NoArgsConstructor(access = AccessLevel.PRIVATE) | ||
@Log4j2 | ||
public class VectorEngineService { | ||
private final Map<OSLuceneDocId, Integer> luceneDocIdToVectorEngineDocId = new ConcurrentHashMap<>(); | ||
// Keeping the OSLuceneDocIds as list to ensure that if IndexReader is open we can reuse the docIds | ||
private final Map<Integer, List<OSLuceneDocId>> vectorEngineDocIdToLuceneDocId = new ConcurrentHashMap<>(); | ||
private final Map<Integer, float[]> vectorEngineDocIdToVector = new ConcurrentHashMap<>(); | ||
private final AtomicInteger currentVectorDocId = new AtomicInteger(0); | ||
|
||
private static VectorEngineService INSTANCE = null; | ||
|
||
public static VectorEngineService getInstance() { | ||
if (INSTANCE == null) { | ||
INSTANCE = new VectorEngineService(); | ||
} | ||
return INSTANCE; | ||
} | ||
|
||
public void ingestData(final OSLuceneDocId luceneDocId, float[] vector, final SpaceType spaceType) { | ||
log.debug("SpaceType during ingestion is : {}", spaceType); | ||
luceneDocIdToVectorEngineDocId.put(luceneDocId, currentVectorDocId.intValue()); | ||
int currentDocId = currentVectorDocId.intValue(); | ||
vectorEngineDocIdToLuceneDocId.getOrDefault(currentDocId, Collections.synchronizedList(new LinkedList<>())).add(luceneDocId); | ||
vectorEngineDocIdToVector.put(currentDocId, vector); | ||
currentVectorDocId.incrementAndGet(); | ||
} | ||
|
||
public List<OSLuceneDocId> search(int k, float[] queryVector, final SpaceType spaceType) { | ||
int finalk = Math.min(k, vectorEngineDocIdToVector.size()); | ||
PriorityQueue<VectorScoreDoc> scoreDocsQueue = new PriorityQueue<>((a, b) -> Float.compare(a.score, b.score)); | ||
for (int docId : vectorEngineDocIdToVector.keySet()) { | ||
float score = spaceType.getKnnVectorSimilarityFunction() | ||
.getVectorSimilarityFunction() | ||
.compare(queryVector, vectorEngineDocIdToVector.get(docId)); | ||
if (scoreDocsQueue.size() < finalk) { | ||
scoreDocsQueue.add(new VectorScoreDoc(score, docId)); | ||
} else { | ||
assert scoreDocsQueue.peek() != null; | ||
if (score > scoreDocsQueue.peek().score) { | ||
scoreDocsQueue.poll(); | ||
scoreDocsQueue.add(new VectorScoreDoc(score, docId)); | ||
} | ||
} | ||
} | ||
|
||
return scoreDocsQueue.parallelStream() | ||
.flatMap(s -> vectorEngineDocIdToLuceneDocId.get(s.getDocId()).parallelStream() | ||
.map(osLuceneDocId -> osLuceneDocId.cloneWithScore(s.getScore()))) | ||
.collect(Collectors.toList()); | ||
} | ||
|
||
public void removeOldSegmentKeys(final byte[] segmentId) { | ||
for(OSLuceneDocId osLuceneDocId : luceneDocIdToVectorEngineDocId.keySet()) { | ||
if(Arrays.equals(osLuceneDocId.getSegmentId(), segmentId)) { | ||
int vectorEngineDocId = luceneDocIdToVectorEngineDocId.remove(osLuceneDocId); | ||
|
||
List<OSLuceneDocId> osLuceneDocIds = vectorEngineDocIdToLuceneDocId.get(vectorEngineDocId); | ||
for(OSLuceneDocId docId : osLuceneDocIds) { | ||
luceneDocIdToVectorEngineDocId.remove(docId); | ||
} | ||
// if all the keys are removed | ||
if(luceneDocIdToVectorEngineDocId.isEmpty()) { | ||
// remove the node from vector search DS. | ||
vectorEngineDocIdToVector.remove(vectorEngineDocId); | ||
// remove the node from VectorEngine to Lucene DocId Map | ||
vectorEngineDocIdToLuceneDocId.remove(vectorEngineDocId); | ||
// remove the luceneDocId to Vector Engine DocId too. | ||
luceneDocIdToVectorEngineDocId.remove(osLuceneDocId); | ||
} | ||
} | ||
} | ||
} | ||
|
||
@Value | ||
@Builder | ||
private static class VectorScoreDoc { | ||
float score; | ||
int docId; | ||
} | ||
|
||
} |
Oops, something went wrong.