diff --git a/src/main/java/org/opensearch/knn/index/KNNSettings.java b/src/main/java/org/opensearch/knn/index/KNNSettings.java index b81a54124..f3f4d4455 100644 --- a/src/main/java/org/opensearch/knn/index/KNNSettings.java +++ b/src/main/java/org/opensearch/knn/index/KNNSettings.java @@ -90,6 +90,7 @@ public class KNNSettings { public static final String QUANTIZATION_STATE_CACHE_EXPIRY_TIME_MINUTES = "knn.quantization.cache.expiry.minutes"; public static final String KNN_FAISS_AVX512_DISABLED = "knn.faiss.avx512.disabled"; public static final String KNN_DISK_VECTOR_SHARD_LEVEL_RESCORING_DISABLED = "index.knn.disk.vector.shard_level_rescoring_disabled"; + public static final String USE_NEW_QUERY = "index.knn.use_new_query"; /** * Default setting values @@ -143,6 +144,8 @@ public class KNNSettings { Dynamic ); + public static final Setting USE_NEW_QUERY_SETTING = Setting.boolSetting(USE_NEW_QUERY, false, IndexScope, Dynamic); + // This setting controls how much memory should be used to transfer vectors from Java to JNI Layer. The default // 1% of the JVM heap public static final Setting KNN_VECTOR_STREAMING_MEMORY_LIMIT_PCT_SETTING = Setting.memorySizeSetting( @@ -499,6 +502,10 @@ private Setting getSetting(String key) { return KNN_DISK_VECTOR_SHARD_LEVEL_RESCORING_DISABLED_SETTING; } + if (USE_NEW_QUERY.equals(key)) { + return USE_NEW_QUERY_SETTING; + } + throw new IllegalArgumentException("Cannot find setting by key [" + key + "]"); } @@ -522,7 +529,8 @@ public List> getSettings() { KNN_FAISS_AVX512_DISABLED_SETTING, QUANTIZATION_STATE_CACHE_SIZE_LIMIT_SETTING, QUANTIZATION_STATE_CACHE_EXPIRY_TIME_MINUTES_SETTING, - KNN_DISK_VECTOR_SHARD_LEVEL_RESCORING_DISABLED_SETTING + KNN_DISK_VECTOR_SHARD_LEVEL_RESCORING_DISABLED_SETTING, + USE_NEW_QUERY_SETTING ); return Stream.concat(settings.stream(), Stream.concat(getFeatureFlags().stream(), dynamicCacheSettings.values().stream())) .collect(Collectors.toList()); @@ -577,6 +585,10 @@ public static Integer getFilteredExactSearchThreshold(final String indexName) { .getAsInt(ADVANCED_FILTERED_EXACT_SEARCH_THRESHOLD, ADVANCED_FILTERED_EXACT_SEARCH_THRESHOLD_DEFAULT_VALUE); } + public static boolean isUseNewQuery(final String indexName) { + return KNNSettings.state().clusterService.state().getMetadata().index(indexName).getSettings().getAsBoolean(USE_NEW_QUERY, false); + } + public static boolean isShardLevelRescoringEnabledForDiskBasedVector(String indexName) { return KNNSettings.state().clusterService.state() .getMetadata() diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriter.java b/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriter.java index 7c8636577..54e4b6330 100644 --- a/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriter.java +++ b/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriter.java @@ -32,6 +32,7 @@ import org.opensearch.knn.plugin.stats.KNNGraphValue; import org.opensearch.knn.quantization.models.quantizationParams.QuantizationParams; import org.opensearch.knn.quantization.models.quantizationState.QuantizationState; +import org.opensearch.knn.service.VectorEngineServiceWrapper; import java.io.IOException; import java.util.ArrayList; @@ -104,6 +105,7 @@ public void flush(int maxDoc, final Sorter.DocMap sortMap) throws IOException { field.getVectors() ); final QuantizationState quantizationState = train(field.getFieldInfo(), knnVectorValuesSupplier, totalLiveDocs); + VectorEngineServiceWrapper.ingestData(knnVectorValuesSupplier.get(), segmentWriteState, fieldInfo); // Check only after quantization state writer finish writing its state, since it is required // even if there are no graph files in segment, which will be later used by exact search if (shouldSkipBuildingVectorDataStructure(totalLiveDocs)) { @@ -165,6 +167,7 @@ public void mergeOneField(final FieldInfo fieldInfo, final MergeState mergeState long time_in_millis = stopWatch.stop().totalTime().millis(); KNNGraphValue.MERGE_TOTAL_TIME_IN_MILLIS.incrementBy(time_in_millis); log.debug("Merge took {} ms for vector field [{}]", time_in_millis, fieldInfo.getName()); + } /** @@ -201,6 +204,7 @@ public void close() throws IOException { quantizationStateWriter.closeOutput(); } IOUtils.close(flatVectorsWriter); + VectorEngineServiceWrapper.close(segmentWriteState); } /** diff --git a/src/main/java/org/opensearch/knn/index/query/KNNQuery.java b/src/main/java/org/opensearch/knn/index/query/KNNQuery.java index f0974f7e9..088db0cc2 100644 --- a/src/main/java/org/opensearch/knn/index/query/KNNQuery.java +++ b/src/main/java/org/opensearch/knn/index/query/KNNQuery.java @@ -10,6 +10,7 @@ import lombok.EqualsAndHashCode; import lombok.Getter; import lombok.Setter; +import org.apache.lucene.index.FieldInfo; import org.apache.lucene.search.BooleanClause; import org.apache.lucene.search.BooleanQuery; import org.apache.lucene.search.FieldExistsQuery; @@ -22,9 +23,13 @@ import org.opensearch.knn.index.KNNSettings; import org.opensearch.knn.index.VectorDataType; import org.opensearch.knn.index.query.rescore.RescoreContext; +import org.opensearch.knn.service.OSLuceneDocId; +import org.opensearch.knn.service.VectorEngineService; +import org.opensearch.knn.service.VectorEngineServiceWrapper; import java.io.IOException; import java.util.Arrays; +import java.util.List; import java.util.Map; import java.util.Objects; @@ -167,6 +172,12 @@ public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float bo if (!KNNSettings.isKNNPluginEnabled()) { throw new IllegalStateException("KNN plugin is disabled. To enable update knn.plugin.enabled to true"); } + if (KNNSettings.isUseNewQuery(indexName)) { + FieldInfo fieldInfo = + searcher.getIndexReader().leaves().get(0).reader().getFieldInfos().fieldInfo(this.getField()); + List docIdList = VectorEngineServiceWrapper.search(this, fieldInfo); + return new KNNWeightV2(this, docIdList); + } final Weight filterWeight = getFilterWeight(searcher); if (filterWeight != null) { return new KNNWeight(this, boost, filterWeight); diff --git a/src/main/java/org/opensearch/knn/index/query/KNNScorer.java b/src/main/java/org/opensearch/knn/index/query/KNNScorer.java index 99962d307..1d1e26016 100644 --- a/src/main/java/org/opensearch/knn/index/query/KNNScorer.java +++ b/src/main/java/org/opensearch/knn/index/query/KNNScorer.java @@ -63,7 +63,7 @@ public int docID() { * @param knnWeight {@link KNNWeight} * @return {@link KNNScorer} */ - public static Scorer emptyScorer(KNNWeight knnWeight) { + public static Scorer emptyScorer(Weight knnWeight) { return new Scorer(knnWeight) { private final DocIdSetIterator docIdsIter = DocIdSetIterator.empty(); diff --git a/src/main/java/org/opensearch/knn/index/query/KNNWeightV2.java b/src/main/java/org/opensearch/knn/index/query/KNNWeightV2.java new file mode 100644 index 000000000..ad1954456 --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/query/KNNWeightV2.java @@ -0,0 +1,86 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.query; + +import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.index.SegmentReader; +import org.apache.lucene.search.Explanation; +import org.apache.lucene.search.Scorer; +import org.apache.lucene.search.Weight; +import org.opensearch.common.lucene.Lucene; +import org.opensearch.knn.service.OSLuceneDocId; + +import java.io.IOException; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +public class KNNWeightV2 extends Weight { + + private final KNNQuery knnQuery; + private final List osLuceneDocIds; + + public KNNWeightV2(KNNQuery knnQuery, List osLuceneDocIds) { + super(knnQuery); + this.knnQuery = knnQuery; + this.osLuceneDocIds = osLuceneDocIds; + } + + /** + * An explanation of the score computation for the named document. + * + * @param context the readers context to create the {@link Explanation} for. + * @param doc the document's id relative to the given context's reader + * @return an Explanation for the score + * @throws IOException if an {@link IOException} occurs + */ + @Override + public Explanation explain(LeafReaderContext context, int doc) throws IOException { + return null; + } + + /** + * Returns a {@link Scorer} which can iterate in order over all matching documents and assign them + * a score. + * + *

NOTE: null can be returned if no documents will be scored by this query. + * + *

NOTE: The returned {@link Scorer} does not have {@link LeafReader#getLiveDocs()} + * applied, they need to be checked on top. + * + * @param context the {@link LeafReaderContext} for which to return the + * {@link Scorer}. + * @return a {@link Scorer} which scores documents in/out-of order. + * @throws IOException if there is a low-level I/O error + */ + @Override + public Scorer scorer(LeafReaderContext context) throws IOException { + SegmentReader segmentReader = Lucene.segmentReader(context.reader()); + byte[] segmentId = segmentReader.getSegmentInfo().info.getId(); + Map docIdsToScoreMap = new HashMap<>(); + for (OSLuceneDocId osLuceneDocId : osLuceneDocIds) { + if (Arrays.equals(osLuceneDocId.getSegmentId(), segmentId)) { + docIdsToScoreMap.put(osLuceneDocId.getSegmentDocId(), osLuceneDocId.getScore()); + } + } + if (docIdsToScoreMap.isEmpty()) { + return KNNScorer.emptyScorer(this); + } + final int maxDoc = Collections.max(docIdsToScoreMap.keySet()) + 1; + return new KNNScorer(this, ResultUtil.resultMapToDocIds(docIdsToScoreMap, maxDoc), docIdsToScoreMap, 1); + } + + /** + * @param ctx + * @return {@code true} if the object can be cached against a given leaf + */ + @Override + public boolean isCacheable(LeafReaderContext ctx) { + return true; + } +} diff --git a/src/main/java/org/opensearch/knn/service/OSLuceneDocId.java b/src/main/java/org/opensearch/knn/service/OSLuceneDocId.java new file mode 100644 index 000000000..8ee273851 --- /dev/null +++ b/src/main/java/org/opensearch/knn/service/OSLuceneDocId.java @@ -0,0 +1,33 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.service; + +import lombok.Builder; +import lombok.EqualsAndHashCode; +import lombok.ToString; +import lombok.Value; + +@Value +@Builder +@EqualsAndHashCode +@ToString +public class OSLuceneDocId { + @Builder.Default + String opensearchIndexName = "my-index"; + byte[] segmentId; + int segmentDocId; + @Builder.Default + float score = 0; + + public OSLuceneDocId cloneWithScore(float score) { + return OSLuceneDocId.builder() + .score(score) + .segmentDocId(segmentDocId) + .segmentId(segmentId) + .opensearchIndexName(opensearchIndexName) + .build(); + } +} diff --git a/src/main/java/org/opensearch/knn/service/VectorEngineService.java b/src/main/java/org/opensearch/knn/service/VectorEngineService.java new file mode 100644 index 000000000..48a7b7f2a --- /dev/null +++ b/src/main/java/org/opensearch/knn/service/VectorEngineService.java @@ -0,0 +1,105 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.service; + +import lombok.AccessLevel; +import lombok.Builder; +import lombok.NoArgsConstructor; +import lombok.Value; +import lombok.extern.log4j.Log4j2; +import org.opensearch.knn.index.SpaceType; + +import java.util.Arrays; +import java.util.Collections; +import java.util.LinkedList; +import java.util.List; +import java.util.Map; +import java.util.PriorityQueue; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.stream.Collectors; + +@NoArgsConstructor(access = AccessLevel.PRIVATE) +@Log4j2 +public class VectorEngineService { + private final Map luceneDocIdToVectorEngineDocId = new ConcurrentHashMap<>(); + // Keeping the OSLuceneDocIds as list to ensure that if IndexReader is open we can reuse the docIds + private final Map> vectorEngineDocIdToLuceneDocId = new ConcurrentHashMap<>(); + private final Map vectorEngineDocIdToVector = new ConcurrentHashMap<>(); + private final AtomicInteger currentVectorDocId = new AtomicInteger(0); + + private static VectorEngineService INSTANCE = null; + + public static VectorEngineService getInstance() { + if (INSTANCE == null) { + INSTANCE = new VectorEngineService(); + } + return INSTANCE; + } + + public void ingestData(final OSLuceneDocId luceneDocId, float[] vector, final SpaceType spaceType) { + log.debug("SpaceType during ingestion is : {}", spaceType); + luceneDocIdToVectorEngineDocId.put(luceneDocId, currentVectorDocId.intValue()); + int currentDocId = currentVectorDocId.intValue(); + vectorEngineDocIdToLuceneDocId.getOrDefault(currentDocId, Collections.synchronizedList(new LinkedList<>())).add(luceneDocId); + vectorEngineDocIdToVector.put(currentDocId, vector); + currentVectorDocId.incrementAndGet(); + } + + public List search(int k, float[] queryVector, final SpaceType spaceType) { + int finalk = Math.min(k, vectorEngineDocIdToVector.size()); + PriorityQueue scoreDocsQueue = new PriorityQueue<>((a, b) -> Float.compare(a.score, b.score)); + for (int docId : vectorEngineDocIdToVector.keySet()) { + float score = spaceType.getKnnVectorSimilarityFunction() + .getVectorSimilarityFunction() + .compare(queryVector, vectorEngineDocIdToVector.get(docId)); + if (scoreDocsQueue.size() < finalk) { + scoreDocsQueue.add(new VectorScoreDoc(score, docId)); + } else { + assert scoreDocsQueue.peek() != null; + if (score > scoreDocsQueue.peek().score) { + scoreDocsQueue.poll(); + scoreDocsQueue.add(new VectorScoreDoc(score, docId)); + } + } + } + + return scoreDocsQueue.parallelStream() + .flatMap(s -> vectorEngineDocIdToLuceneDocId.get(s.getDocId()).parallelStream() + .map(osLuceneDocId -> osLuceneDocId.cloneWithScore(s.getScore()))) + .collect(Collectors.toList()); + } + + public void removeOldSegmentKeys(final byte[] segmentId) { + for(OSLuceneDocId osLuceneDocId : luceneDocIdToVectorEngineDocId.keySet()) { + if(Arrays.equals(osLuceneDocId.getSegmentId(), segmentId)) { + int vectorEngineDocId = luceneDocIdToVectorEngineDocId.remove(osLuceneDocId); + + List osLuceneDocIds = vectorEngineDocIdToLuceneDocId.get(vectorEngineDocId); + for(OSLuceneDocId docId : osLuceneDocIds) { + luceneDocIdToVectorEngineDocId.remove(docId); + } + // if all the keys are removed + if(luceneDocIdToVectorEngineDocId.isEmpty()) { + // remove the node from vector search DS. + vectorEngineDocIdToVector.remove(vectorEngineDocId); + // remove the node from VectorEngine to Lucene DocId Map + vectorEngineDocIdToLuceneDocId.remove(vectorEngineDocId); + // remove the luceneDocId to Vector Engine DocId too. + luceneDocIdToVectorEngineDocId.remove(osLuceneDocId); + } + } + } + } + + @Value + @Builder + private static class VectorScoreDoc { + float score; + int docId; + } + +} diff --git a/src/main/java/org/opensearch/knn/service/VectorEngineServiceWrapper.java b/src/main/java/org/opensearch/knn/service/VectorEngineServiceWrapper.java new file mode 100644 index 000000000..ba4f4ae93 --- /dev/null +++ b/src/main/java/org/opensearch/knn/service/VectorEngineServiceWrapper.java @@ -0,0 +1,49 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.service; + +import lombok.extern.log4j.Log4j2; +import org.apache.lucene.index.FieldInfo; +import org.apache.lucene.index.SegmentWriteState; +import org.apache.lucene.search.DocIdSetIterator; +import org.opensearch.knn.common.FieldInfoExtractor; +import org.opensearch.knn.index.SpaceType; +import org.opensearch.knn.index.query.KNNQuery; +import org.opensearch.knn.index.vectorvalues.KNNVectorValues; + +import java.io.IOException; +import java.util.List; + +@Log4j2 +public class VectorEngineServiceWrapper { + + public static void ingestData(final KNNVectorValues knnVectorValues, + final SegmentWriteState segmentWriteState, final FieldInfo fieldInfo) + throws IOException { + byte[] segmentId = segmentWriteState.segmentInfo.getId(); + VectorEngineService vectorEngineService = VectorEngineService.getInstance(); + SpaceType spaceType = FieldInfoExtractor.getSpaceType(null, fieldInfo); + + for (int docId = knnVectorValues.nextDoc(); docId != DocIdSetIterator.NO_MORE_DOCS; docId = knnVectorValues.nextDoc()) { + log.debug("Adding DocId: {}", docId); + // we need to + vectorEngineService.ingestData( + OSLuceneDocId.builder().segmentDocId(docId).segmentId(segmentId).build(), + (float[]) knnVectorValues.getVector(), spaceType + ); + } + } + + public static List search(final KNNQuery knnQuery, final FieldInfo fieldInfo) { + VectorEngineService vectorEngineService = VectorEngineService.getInstance(); + SpaceType spaceType = FieldInfoExtractor.getSpaceType(null, fieldInfo); + return vectorEngineService.search(knnQuery.getK(), knnQuery.getQueryVector(), spaceType); + } + + public static void close(final SegmentWriteState segmentWriteState) { + VectorEngineService.getInstance().removeOldSegmentKeys(segmentWriteState.segmentInfo.getId()); + } +}