Skip to content

Commit

Permalink
Initial commit for segment-free-vector-search
Browse files Browse the repository at this point in the history
Signed-off-by: Navneet Verma <[email protected]>
  • Loading branch information
navneet1v committed Nov 4, 2024
1 parent 7d87c52 commit cf1b124
Show file tree
Hide file tree
Showing 8 changed files with 302 additions and 2 deletions.
14 changes: 13 additions & 1 deletion src/main/java/org/opensearch/knn/index/KNNSettings.java
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ public class KNNSettings {
public static final String QUANTIZATION_STATE_CACHE_EXPIRY_TIME_MINUTES = "knn.quantization.cache.expiry.minutes";
public static final String KNN_FAISS_AVX512_DISABLED = "knn.faiss.avx512.disabled";
public static final String KNN_DISK_VECTOR_SHARD_LEVEL_RESCORING_DISABLED = "index.knn.disk.vector.shard_level_rescoring_disabled";
public static final String USE_NEW_QUERY = "index.knn.use_new_query";

/**
* Default setting values
Expand Down Expand Up @@ -143,6 +144,8 @@ public class KNNSettings {
Dynamic
);

// Index-scoped, dynamically updatable boolean registered under the USE_NEW_QUERY key; defaults to false.
// When it is true for an index, KNNQuery#createWeight obtains its results from VectorEngineServiceWrapper
// and returns a KNNWeightV2 instead of building a KNNWeight.
public static final Setting<Boolean> USE_NEW_QUERY_SETTING = Setting.boolSetting(USE_NEW_QUERY, false, IndexScope, Dynamic);

// This setting controls how much memory should be used to transfer vectors from Java to JNI Layer. The default
// 1% of the JVM heap
public static final Setting<ByteSizeValue> KNN_VECTOR_STREAMING_MEMORY_LIMIT_PCT_SETTING = Setting.memorySizeSetting(
Expand Down Expand Up @@ -499,6 +502,10 @@ private Setting<?> getSetting(String key) {
return KNN_DISK_VECTOR_SHARD_LEVEL_RESCORING_DISABLED_SETTING;
}

if (USE_NEW_QUERY.equals(key)) {
return USE_NEW_QUERY_SETTING;
}

throw new IllegalArgumentException("Cannot find setting by key [" + key + "]");
}

Expand All @@ -522,7 +529,8 @@ public List<Setting<?>> getSettings() {
KNN_FAISS_AVX512_DISABLED_SETTING,
QUANTIZATION_STATE_CACHE_SIZE_LIMIT_SETTING,
QUANTIZATION_STATE_CACHE_EXPIRY_TIME_MINUTES_SETTING,
KNN_DISK_VECTOR_SHARD_LEVEL_RESCORING_DISABLED_SETTING
KNN_DISK_VECTOR_SHARD_LEVEL_RESCORING_DISABLED_SETTING,
USE_NEW_QUERY_SETTING
);
return Stream.concat(settings.stream(), Stream.concat(getFeatureFlags().stream(), dynamicCacheSettings.values().stream()))
.collect(Collectors.toList());
Expand Down Expand Up @@ -577,6 +585,10 @@ public static Integer getFilteredExactSearchThreshold(final String indexName) {
.getAsInt(ADVANCED_FILTERED_EXACT_SEARCH_THRESHOLD, ADVANCED_FILTERED_EXACT_SEARCH_THRESHOLD_DEFAULT_VALUE);
}

/**
 * Tells whether the given index has opted in to the new query path.
 *
 * @param indexName name of the index whose settings are looked up in the current cluster state
 * @return the value of the {@code index.knn.use_new_query} index setting, or {@code false} when it is not set
 */
public static boolean isUseNewQuery(final String indexName) {
    return KNNSettings.state().clusterService.state()
        .getMetadata()
        .index(indexName)
        .getSettings()
        .getAsBoolean(USE_NEW_QUERY, false);
}

public static boolean isShardLevelRescoringEnabledForDiskBasedVector(String indexName) {
return KNNSettings.state().clusterService.state()
.getMetadata()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
import org.opensearch.knn.plugin.stats.KNNGraphValue;
import org.opensearch.knn.quantization.models.quantizationParams.QuantizationParams;
import org.opensearch.knn.quantization.models.quantizationState.QuantizationState;
import org.opensearch.knn.service.VectorEngineServiceWrapper;

import java.io.IOException;
import java.util.ArrayList;
Expand Down Expand Up @@ -104,6 +105,7 @@ public void flush(int maxDoc, final Sorter.DocMap sortMap) throws IOException {
field.getVectors()
);
final QuantizationState quantizationState = train(field.getFieldInfo(), knnVectorValuesSupplier, totalLiveDocs);
VectorEngineServiceWrapper.ingestData(knnVectorValuesSupplier.get(), segmentWriteState, fieldInfo);
// Check only after quantization state writer finish writing its state, since it is required
// even if there are no graph files in segment, which will be later used by exact search
if (shouldSkipBuildingVectorDataStructure(totalLiveDocs)) {
Expand Down Expand Up @@ -165,6 +167,7 @@ public void mergeOneField(final FieldInfo fieldInfo, final MergeState mergeState
long time_in_millis = stopWatch.stop().totalTime().millis();
KNNGraphValue.MERGE_TOTAL_TIME_IN_MILLIS.incrementBy(time_in_millis);
log.debug("Merge took {} ms for vector field [{}]", time_in_millis, fieldInfo.getName());

}

/**
Expand Down Expand Up @@ -201,6 +204,7 @@ public void close() throws IOException {
quantizationStateWriter.closeOutput();
}
IOUtils.close(flatVectorsWriter);
VectorEngineServiceWrapper.close(segmentWriteState);
}

/**
Expand Down
11 changes: 11 additions & 0 deletions src/main/java/org/opensearch/knn/index/query/KNNQuery.java
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import lombok.EqualsAndHashCode;
import lombok.Getter;
import lombok.Setter;
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.search.BooleanClause;
import org.apache.lucene.search.BooleanQuery;
import org.apache.lucene.search.FieldExistsQuery;
Expand All @@ -22,9 +23,13 @@
import org.opensearch.knn.index.KNNSettings;
import org.opensearch.knn.index.VectorDataType;
import org.opensearch.knn.index.query.rescore.RescoreContext;
import org.opensearch.knn.service.OSLuceneDocId;
import org.opensearch.knn.service.VectorEngineService;
import org.opensearch.knn.service.VectorEngineServiceWrapper;

import java.io.IOException;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.Objects;

Expand Down Expand Up @@ -167,6 +172,12 @@ public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float bo
if (!KNNSettings.isKNNPluginEnabled()) {
throw new IllegalStateException("KNN plugin is disabled. To enable update knn.plugin.enabled to true");
}
if (KNNSettings.isUseNewQuery(indexName)) {
FieldInfo fieldInfo =
searcher.getIndexReader().leaves().get(0).reader().getFieldInfos().fieldInfo(this.getField());
List<OSLuceneDocId> docIdList = VectorEngineServiceWrapper.search(this, fieldInfo);
return new KNNWeightV2(this, docIdList);
}
final Weight filterWeight = getFilterWeight(searcher);
if (filterWeight != null) {
return new KNNWeight(this, boost, filterWeight);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ public int docID() {
* @param knnWeight {@link KNNWeight}
* @return {@link KNNScorer}
*/
public static Scorer emptyScorer(KNNWeight knnWeight) {
public static Scorer emptyScorer(Weight knnWeight) {
return new Scorer(knnWeight) {
private final DocIdSetIterator docIdsIter = DocIdSetIterator.empty();

Expand Down
86 changes: 86 additions & 0 deletions src/main/java/org/opensearch/knn/index/query/KNNWeightV2.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.knn.index.query;

import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.SegmentReader;
import org.apache.lucene.search.Explanation;
import org.apache.lucene.search.Scorer;
import org.apache.lucene.search.Weight;
import org.opensearch.common.lucene.Lucene;
import org.opensearch.knn.service.OSLuceneDocId;

import java.io.IOException;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

/**
 * {@link Weight} backed by a precomputed list of k-NN hits.
 *
 * <p>The hits are handed in at construction time as {@link OSLuceneDocId}s, each carrying the id of the
 * segment it belongs to, its segment-local doc id and its score. For every leaf this weight selects the
 * hits whose segment id equals the leaf's segment id and exposes them through a {@link KNNScorer}.
 */
public class KNNWeightV2 extends Weight {

    private final KNNQuery knnQuery;
    private final List<OSLuceneDocId> osLuceneDocIds;

    /**
     * @param knnQuery       the query this weight was created for
     * @param osLuceneDocIds precomputed hits, possibly spanning several segments
     */
    public KNNWeightV2(KNNQuery knnQuery, List<OSLuceneDocId> osLuceneDocIds) {
        super(knnQuery);
        this.knnQuery = knnQuery;
        this.osLuceneDocIds = osLuceneDocIds;
    }

    /**
     * Explains the score of a document in terms of the precomputed hit list.
     *
     * @param context the leaf the document lives in
     * @param doc     the document's id relative to {@code context}'s reader
     * @return a match carrying the precomputed score when {@code doc} is one of the hits for this
     *         segment, otherwise a no-match explanation; never {@code null}
     * @throws IOException declared by the overridden method
     */
    @Override
    public Explanation explain(LeafReaderContext context, int doc) throws IOException {
        final Float score = collectSegmentScores(context).get(doc);
        if (score == null) {
            return Explanation.noMatch(
                "doc [" + doc + "] is not among the precomputed k-NN results of field [" + knnQuery.getField() + "] for this segment"
            );
        }
        return Explanation.match(score, "precomputed k-NN score for field [" + knnQuery.getField() + "]");
    }

    /**
     * Builds a scorer over the precomputed hits that belong to the segment behind {@code context}.
     *
     * @param context the leaf for which to return the {@link Scorer}
     * @return an empty scorer when no hit belongs to this segment, otherwise a {@link KNNScorer}
     *         over the matching doc ids
     * @throws IOException declared by the overridden method
     */
    @Override
    public Scorer scorer(LeafReaderContext context) throws IOException {
        final Map<Integer, Float> docIdsToScoreMap = collectSegmentScores(context);
        if (docIdsToScoreMap.isEmpty()) {
            return KNNScorer.emptyScorer(this);
        }
        // The doc-id structure is sized by the largest matching doc id, not by the reader's maxDoc.
        final int maxDoc = Collections.max(docIdsToScoreMap.keySet()) + 1;
        return new KNNScorer(this, ResultUtil.resultMapToDocIds(docIdsToScoreMap, maxDoc), docIdsToScoreMap, 1);
    }

    /**
     * @param ctx the leaf the caching decision is made for
     * @return always {@code true}
     */
    @Override
    public boolean isCacheable(LeafReaderContext ctx) {
        return true;
    }

    /**
     * Selects the precomputed hits whose segment id equals the id of the segment behind {@code context}.
     *
     * @return segment-local doc id to score; empty when no hit belongs to the segment
     */
    private Map<Integer, Float> collectSegmentScores(LeafReaderContext context) {
        final SegmentReader segmentReader = Lucene.segmentReader(context.reader());
        final byte[] segmentId = segmentReader.getSegmentInfo().info.getId();
        final Map<Integer, Float> docIdsToScoreMap = new HashMap<>();
        for (OSLuceneDocId osLuceneDocId : osLuceneDocIds) {
            if (Arrays.equals(osLuceneDocId.getSegmentId(), segmentId)) {
                docIdsToScoreMap.put(osLuceneDocId.getSegmentDocId(), osLuceneDocId.getScore());
            }
        }
        return docIdsToScoreMap;
    }
}
33 changes: 33 additions & 0 deletions src/main/java/org/opensearch/knn/service/OSLuceneDocId.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.knn.service;

import lombok.Builder;
import lombok.EqualsAndHashCode;
import lombok.ToString;
import lombok.Value;

/**
 * Immutable value identifying a single document inside a Lucene segment, together with a score.
 *
 * <p>Instances are created through the generated builder; {@code opensearchIndexName} falls back to
 * {@code "my-index"} and {@code score} to {@code 0} when they are not set explicitly.
 */
@Value
@Builder
@EqualsAndHashCode
@ToString
public class OSLuceneDocId {
    /** Name of the OpenSearch index the document belongs to. */
    @Builder.Default
    String opensearchIndexName = "my-index";
    /** Id of the segment that holds the document. */
    byte[] segmentId;
    /** Doc id of the document inside that segment. */
    int segmentDocId;
    /** Score attached to the document. */
    @Builder.Default
    float score = 0;

    /**
     * Creates a copy of this id that differs only in its score.
     *
     * @param score the score the copy should carry
     * @return a new instance with the same index name, segment id and segment doc id
     */
    public OSLuceneDocId cloneWithScore(float score) {
        final OSLuceneDocIdBuilder copy = OSLuceneDocId.builder();
        copy.opensearchIndexName(opensearchIndexName);
        copy.segmentId(segmentId);
        copy.segmentDocId(segmentDocId);
        copy.score(score);
        return copy.build();
    }
}
105 changes: 105 additions & 0 deletions src/main/java/org/opensearch/knn/service/VectorEngineService.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.knn.service;

import lombok.AccessLevel;
import lombok.Builder;
import lombok.NoArgsConstructor;
import lombok.Value;
import lombok.extern.log4j.Log4j2;
import org.opensearch.knn.index.SpaceType;

import java.util.Arrays;
import java.util.Collections;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.PriorityQueue;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.stream.Collectors;

/**
 * Process-wide, in-memory store of vectors that is searched by brute force.
 *
 * <p>Every ingested vector gets its own vector-engine doc id. Three maps tie that id to the vector
 * itself and to the {@link OSLuceneDocId}s that refer to it, so that search results can be translated
 * back to Lucene documents.
 */
@NoArgsConstructor(access = AccessLevel.PRIVATE)
@Log4j2
public class VectorEngineService {
    private final Map<OSLuceneDocId, Integer> luceneDocIdToVectorEngineDocId = new ConcurrentHashMap<>();
    // Keeping the OSLuceneDocIds as list to ensure that if IndexReader is open we can reuse the docIds
    private final Map<Integer, List<OSLuceneDocId>> vectorEngineDocIdToLuceneDocId = new ConcurrentHashMap<>();
    private final Map<Integer, float[]> vectorEngineDocIdToVector = new ConcurrentHashMap<>();
    private final AtomicInteger currentVectorDocId = new AtomicInteger(0);

    /** Defers creation of the singleton until first use; class initialisation makes it thread-safe. */
    private static final class InstanceHolder {
        private static final VectorEngineService INSTANCE = new VectorEngineService();
    }

    /**
     * @return the single shared instance of this service
     */
    public static VectorEngineService getInstance() {
        return InstanceHolder.INSTANCE;
    }

    /**
     * Stores a vector under a freshly allocated vector-engine doc id and links it to {@code luceneDocId}.
     *
     * @param luceneDocId the Lucene document the vector belongs to
     * @param vector      the vector to store
     * @param spaceType   only logged
     */
    public void ingestData(final OSLuceneDocId luceneDocId, float[] vector, final SpaceType spaceType) {
        log.debug("SpaceType during ingestion is : {}", spaceType);
        // Allocate the id in one atomic step so that concurrent ingests never share an id.
        final int vectorDocId = currentVectorDocId.getAndIncrement();
        // computeIfAbsent actually stores the list; getOrDefault would add to a list that is thrown away.
        vectorEngineDocIdToLuceneDocId.computeIfAbsent(vectorDocId, id -> Collections.synchronizedList(new LinkedList<>()))
            .add(luceneDocId);
        luceneDocIdToVectorEngineDocId.put(luceneDocId, vectorDocId);
        // Publish the vector last: search() iterates this map and then looks up the doc-id list.
        vectorEngineDocIdToVector.put(vectorDocId, vector);
    }

    /**
     * Scores every stored vector against {@code queryVector} and keeps the {@code k} highest-scoring ones.
     *
     * @param k           maximum number of vectors to keep; values {@code <= 0} yield an empty result
     * @param queryVector the query vector
     * @param spaceType   supplies the similarity function used for scoring
     * @return the Lucene doc ids linked to the kept vectors, each carrying the vector's score; the
     *         order of the list is unspecified
     */
    public List<OSLuceneDocId> search(int k, float[] queryVector, final SpaceType spaceType) {
        final int finalk = Math.min(k, vectorEngineDocIdToVector.size());
        // Min-heap on score: the head is the weakest of the current best hits.
        final PriorityQueue<VectorScoreDoc> scoreDocsQueue = new PriorityQueue<>((a, b) -> Float.compare(a.score, b.score));
        // With nothing to collect the queue stays empty, so peek() below can never return null.
        if (finalk > 0) {
            for (Map.Entry<Integer, float[]> entry : vectorEngineDocIdToVector.entrySet()) {
                final float score = spaceType.getKnnVectorSimilarityFunction()
                    .getVectorSimilarityFunction()
                    .compare(queryVector, entry.getValue());
                if (scoreDocsQueue.size() < finalk) {
                    scoreDocsQueue.add(new VectorScoreDoc(score, entry.getKey()));
                } else if (score > scoreDocsQueue.peek().score) {
                    scoreDocsQueue.poll();
                    scoreDocsQueue.add(new VectorScoreDoc(score, entry.getKey()));
                }
            }
        }

        return scoreDocsQueue.stream()
            // A vector whose doc-id list has been removed in the meantime contributes no hits.
            .flatMap(
                s -> vectorEngineDocIdToLuceneDocId.getOrDefault(s.getDocId(), Collections.<OSLuceneDocId>emptyList())
                    .stream()
                    .map(osLuceneDocId -> osLuceneDocId.cloneWithScore(s.getScore()))
            )
            .collect(Collectors.toList());
    }

    /**
     * Drops every Lucene doc id that belongs to the given segment. A vector is removed as well once
     * no Lucene doc id refers to it any more.
     *
     * @param segmentId id of the segment whose doc ids should be forgotten
     */
    public void removeOldSegmentKeys(final byte[] segmentId) {
        for (OSLuceneDocId osLuceneDocId : luceneDocIdToVectorEngineDocId.keySet()) {
            if (!Arrays.equals(osLuceneDocId.getSegmentId(), segmentId)) {
                continue;
            }
            final Integer vectorEngineDocId = luceneDocIdToVectorEngineDocId.remove(osLuceneDocId);
            if (vectorEngineDocId == null) {
                // Already removed by someone else.
                continue;
            }
            vectorEngineDocIdToLuceneDocId.compute(vectorEngineDocId, (id, linkedDocIds) -> {
                if (linkedDocIds != null) {
                    linkedDocIds.remove(osLuceneDocId);
                    if (!linkedDocIds.isEmpty()) {
                        // Doc ids of other segments still refer to this vector; keep it.
                        return linkedDocIds;
                    }
                }
                // if all the keys are removed, remove the node from vector search DS
                vectorEngineDocIdToVector.remove(id);
                // returning null removes the node from the VectorEngine to Lucene DocId map
                return null;
            });
        }
    }

    /** A vector-engine doc id paired with the score it received for the current query. */
    @Value
    @Builder
    private static class VectorScoreDoc {
        float score;
        int docId;
    }

}
Loading

0 comments on commit cf1b124

Please sign in to comment.