Commit

Add check to directly use ANN Search when filters match all docs. (opensearch-project#2320)

* Add check to directly use ANN Search when filters match all docs.

Signed-off-by: Wei Wang <[email protected]>

* Fix failing tests and rebase on the main branch

Signed-off-by: Wei Wang <[email protected]>

* Pass filterBitSet as null and add integration tests.

Signed-off-by: Wei Wang <[email protected]>

---------

Signed-off-by: Wei Wang <[email protected]>
Co-authored-by: Wei Wang <[email protected]>
weiwang118 and Wei Wang authored Jan 3, 2025
1 parent c969f1d commit 6f5313f
Showing 5 changed files with 159 additions and 6 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -23,6 +23,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
- Introduced a writing layer in native engines where relies on the writing interface to process IO. (#2241)[https://github.com/opensearch-project/k-NN/pull/2241]
- Allow method parameter override for training based indices (#2290) https://github.com/opensearch-project/k-NN/pull/2290]
- Optimizes lucene query execution to prevent unnecessary rewrites (#2305)[https://github.com/opensearch-project/k-NN/pull/2305]
- Add check to directly use ANN Search when filters match all docs. (#2320)[https://github.com/opensearch-project/k-NN/pull/2320]
### Bug Fixes
* Fixing the bug when a segment has no vector field present for disk based vector search (#2282)[https://github.com/opensearch-project/k-NN/pull/2282]
* Allow validation for non knn index only after 2.17.0 (#2315)[https://github.com/opensearch-project/k-NN/pull/2315]
@@ -78,7 +78,10 @@ public enum FilterIdsSelectorType {
public static FilterIdsSelector getFilterIdSelector(final BitSet filterIdsBitSet, final int cardinality) throws IOException {
long[] filterIds;
FilterIdsSelector.FilterIdsSelectorType filterType;
if (filterIdsBitSet instanceof FixedBitSet) {
if (filterIdsBitSet == null) {
filterIds = null;
filterType = FilterIdsSelector.FilterIdsSelectorType.BITMAP;
} else if (filterIdsBitSet instanceof FixedBitSet) {
/**
* When filterIds is dense filter, using fixed bitset
*/
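For readers skimming the hunk above, here is a compact standalone sketch of the branching it introduces. This is an editor's illustration under stated assumptions, not the repository's FilterIdsSelector: only the null branch mirrors the added code, the class and method names are hypothetical, and the sparse branch is elided.

import org.apache.lucene.util.BitSet;
import org.apache.lucene.util.FixedBitSet;

// Sketch only: a null bitset means the filter matched every doc, so no id array
// is materialized and the caller can skip per-doc membership lookups entirely.
final class FilterIdsSketch {
    static long[] selectFilterIds(final BitSet filterIdsBitSet) {
        if (filterIdsBitSet == null) {
            return null;                                      // filters match all docs
        }
        if (filterIdsBitSet instanceof FixedBitSet) {
            return ((FixedBitSet) filterIdsBitSet).getBits(); // dense filter: reuse the backing words
        }
        // Sparse filters would be enumerated into an explicit id batch (omitted here).
        return new long[0];
    }
}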
10 changes: 9 additions & 1 deletion src/main/java/org/opensearch/knn/index/query/KNNWeight.java
@@ -129,6 +129,7 @@ public Scorer scorer(LeafReaderContext context) throws IOException {
*/
public PerLeafResult searchLeaf(LeafReaderContext context, int k) throws IOException {
final BitSet filterBitSet = getFilteredDocsBitSet(context);
final int maxDoc = context.reader().maxDoc();
int cardinality = filterBitSet.cardinality();
// We don't need to go to JNI layer if no documents are found which satisfy the filters
// We should give this condition a deeper look that where it should be placed. For now I feel this is a good
@@ -145,7 +146,14 @@ public PerLeafResult searchLeaf(LeafReaderContext context, int k) throws IOExcep
Map<Integer, Float> result = doExactSearch(context, new BitSetIterator(filterBitSet, cardinality), cardinality, k);
return new PerLeafResult(filterWeight == null ? null : filterBitSet, result);
}
Map<Integer, Float> docIdsToScoreMap = doANNSearch(context, filterBitSet, cardinality, k);

/*
* If filters match all docs in this segment, then null should be passed as filterBitSet
* so that no bitset lookup is done in the bottom search layer.
*/
final BitSet annFilter = (filterWeight != null && cardinality == maxDoc) ? null : filterBitSet;
final Map<Integer, Float> docIdsToScoreMap = doANNSearch(context, annFilter, cardinality, k);

// See whether we have to perform exact search based on approx search results
// This is required if there are no native engine files or if approximate search returned
// results less than K, though we have more than k filtered docs
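To make the new control flow in searchLeaf easier to follow, here is a standalone sketch of the decision it adds. This is an editor's illustration, not the repository's KNNWeight; the helper and class names are hypothetical, while filterWeight, cardinality, and maxDoc correspond to the values used in the hunk above.

import org.apache.lucene.search.Weight;
import org.apache.lucene.util.BitSet;

// Sketch only: when a filter is present but its cardinality equals maxDoc for the
// segment, every doc already passes, so the ANN call receives null and the native
// layer skips per-candidate bitset lookups.
final class AnnFilterSketch {
    static BitSet resolveAnnFilter(final Weight filterWeight, final BitSet filterBitSet,
                                   final int cardinality, final int maxDoc) {
        if (filterWeight != null && cardinality == maxDoc) {
            return null;         // filters match all docs in this segment
        }
        return filterBitSet;     // otherwise keep the bitset for filtered ANN search
    }
}

Passing null rather than an all-set bitset avoids a membership check per visited candidate during graph traversal, which is pure overhead when every document qualifies.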
92 changes: 88 additions & 4 deletions src/test/java/org/opensearch/knn/index/query/KNNWeightTests.java
@@ -671,7 +671,7 @@ public void validateANNWithFilterQuery_whenDoingANN_thenSuccess(final boolean is
when(liveDocsBits.length()).thenReturn(1000);

final SegmentReader reader = mockSegmentReader();
when(reader.maxDoc()).thenReturn(filterDocIds.length);
when(reader.maxDoc()).thenReturn(filterDocIds.length + 1);
when(reader.getLiveDocs()).thenReturn(liveDocsBits);

final LeafReaderContext leafReaderContext = mock(LeafReaderContext.class);
@@ -758,6 +758,88 @@ public void validateANNWithFilterQuery_whenDoingANN_thenSuccess(final boolean is
assertTrue(Comparators.isInOrder(actualDocIds, Comparator.naturalOrder()));
}

@SneakyThrows
public void testANNWithFilterQuery_whenFiltersMatchAllDocs_thenSuccess() {
// Given
int k = 3;
final int[] filterDocIds = new int[] { 0, 1, 2, 3, 4, 5 };
FixedBitSet filterBitSet = new FixedBitSet(filterDocIds.length);
for (int docId : filterDocIds) {
filterBitSet.set(docId);
}

jniServiceMockedStatic.when(
() -> JNIService.queryIndex(anyLong(), eq(QUERY_VECTOR), eq(k), eq(HNSW_METHOD_PARAMETERS), any(), eq(null), anyInt(), any())
).thenReturn(getFilteredKNNQueryResults());

final Bits liveDocsBits = mock(Bits.class);
for (int filterDocId : filterDocIds) {
when(liveDocsBits.get(filterDocId)).thenReturn(true);
}
when(liveDocsBits.length()).thenReturn(1000);

final SegmentReader reader = mockSegmentReader();
when(reader.maxDoc()).thenReturn(filterDocIds.length);
when(reader.getLiveDocs()).thenReturn(liveDocsBits);

final LeafReaderContext leafReaderContext = mock(LeafReaderContext.class);
when(leafReaderContext.reader()).thenReturn(reader);

final KNNQuery query = KNNQuery.builder()
.field(FIELD_NAME)
.queryVector(QUERY_VECTOR)
.k(k)
.indexName(INDEX_NAME)
.filterQuery(FILTER_QUERY)
.methodParameters(HNSW_METHOD_PARAMETERS)
.build();

final Weight filterQueryWeight = mock(Weight.class);
final Scorer filterScorer = mock(Scorer.class);
when(filterQueryWeight.scorer(leafReaderContext)).thenReturn(filterScorer);
// Just to make sure that we are not hitting the exact search condition
when(filterScorer.iterator()).thenReturn(DocIdSetIterator.all(filterDocIds.length + 1));

final float boost = (float) randomDoubleBetween(0, 10, true);
final KNNWeight knnWeight = new KNNWeight(query, boost, filterQueryWeight);

final FieldInfos fieldInfos = mock(FieldInfos.class);
final FieldInfo fieldInfo = mock(FieldInfo.class);
final Map<String, String> attributesMap = ImmutableMap.of(
KNN_ENGINE,
KNNEngine.FAISS.getName(),
SPACE_TYPE,
SpaceType.L2.getValue()
);

when(reader.getFieldInfos()).thenReturn(fieldInfos);
when(fieldInfos.fieldInfo(any())).thenReturn(fieldInfo);
when(fieldInfo.attributes()).thenReturn(attributesMap);

// When
final KNNScorer knnScorer = (KNNScorer) knnWeight.scorer(leafReaderContext);

// Then
assertNotNull(knnScorer);
final DocIdSetIterator docIdSetIterator = knnScorer.iterator();
assertNotNull(docIdSetIterator);
assertEquals(FILTERED_DOC_ID_TO_SCORES.size(), docIdSetIterator.cost());

jniServiceMockedStatic.verify(
() -> JNIService.queryIndex(anyLong(), eq(QUERY_VECTOR), eq(k), eq(HNSW_METHOD_PARAMETERS), any(), any(), anyInt(), any()),
times(1)
);

final List<Integer> actualDocIds = new ArrayList<>();
final Map<Integer, Float> translatedScores = getTranslatedScores(SpaceType.L2::scoreTranslation);
for (int docId = docIdSetIterator.nextDoc(); docId != NO_MORE_DOCS; docId = docIdSetIterator.nextDoc()) {
actualDocIds.add(docId);
assertEquals(translatedScores.get(docId) * boost, knnScorer.score(), 0.01f);
}
assertEquals(docIdSetIterator.cost(), actualDocIds.size());
assertTrue(Comparators.isInOrder(actualDocIds, Comparator.naturalOrder()));
}

private SegmentReader mockSegmentReader() {
Path path = mock(Path.class);

@@ -815,7 +897,7 @@ public void validateANNWithFilterQuery_whenExactSearch_thenSuccess(final boolean
when(filterQueryWeight.scorer(leafReaderContext)).thenReturn(filterScorer);
// scorer will return 2 documents
when(filterScorer.iterator()).thenReturn(DocIdSetIterator.all(1));
when(reader.maxDoc()).thenReturn(1);
when(reader.maxDoc()).thenReturn(2);
final Bits liveDocsBits = mock(Bits.class);
when(reader.getLiveDocs()).thenReturn(liveDocsBits);
when(liveDocsBits.get(filterDocId)).thenReturn(true);
@@ -891,6 +973,7 @@ public void testRadialSearch_whenNoEngineFiles_thenPerformExactSearch() {
final LeafReaderContext leafReaderContext = mock(LeafReaderContext.class);
final SegmentReader reader = mock(SegmentReader.class);
when(leafReaderContext.reader()).thenReturn(reader);
when(reader.maxDoc()).thenReturn(1);

final FSDirectory directory = mock(FSDirectory.class);
when(reader.directory()).thenReturn(directory);
@@ -968,7 +1051,7 @@ public void testANNWithFilterQuery_whenExactSearchAndThresholdComputations_thenS
when(filterQueryWeight.scorer(leafReaderContext)).thenReturn(filterScorer);
// scorer will return 2 documents
when(filterScorer.iterator()).thenReturn(DocIdSetIterator.all(1));
when(reader.maxDoc()).thenReturn(1);
when(reader.maxDoc()).thenReturn(2);
final Bits liveDocsBits = mock(Bits.class);
when(reader.getLiveDocs()).thenReturn(liveDocsBits);
when(liveDocsBits.get(filterDocId)).thenReturn(true);
@@ -1168,6 +1251,7 @@ public void testANNWithFilterQuery_whenEmptyFilterIds_thenReturnEarly() {
final LeafReaderContext leafReaderContext = mock(LeafReaderContext.class);
final SegmentReader reader = mock(SegmentReader.class);
when(leafReaderContext.reader()).thenReturn(reader);
when(reader.maxDoc()).thenReturn(1);

final Weight filterQueryWeight = mock(Weight.class);
final Scorer filterScorer = mock(Scorer.class);
@@ -1202,7 +1286,7 @@ public void testANNWithParentsFilter_whenExactSearch_thenSuccess() {
// We will have 0, 1 for filteredIds and 2 will be the parent id for both of them
final Scorer filterScorer = mock(Scorer.class);
when(filterScorer.iterator()).thenReturn(DocIdSetIterator.all(2));
when(reader.maxDoc()).thenReturn(2);
when(reader.maxDoc()).thenReturn(3);

// Query vector is {1.8f, 2.4f}, therefore, second vector {1.9f, 2.5f} should be returned in a result
final List<float[]> vectors = Arrays.asList(new float[] { 0.1f, 0.3f }, new float[] { 1.9f, 2.5f });
57 changes: 57 additions & 0 deletions src/test/java/org/opensearch/knn/integ/FilteredSearchANNSearchIT.java
@@ -0,0 +1,57 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.knn.integ;

import com.google.common.collect.ImmutableMap;
import lombok.SneakyThrows;
import lombok.extern.log4j.Log4j2;
import org.apache.hc.core5.http.io.entity.EntityUtils;
import org.opensearch.client.Response;
import org.opensearch.common.settings.Settings;
import org.opensearch.knn.KNNJsonQueryBuilder;
import org.opensearch.knn.KNNRestTestCase;
import org.opensearch.knn.index.KNNSettings;
import java.util.List;

import static org.opensearch.knn.common.KNNConstants.FAISS_NAME;
import static org.opensearch.knn.common.KNNConstants.METHOD_HNSW;

@Log4j2
public class FilteredSearchANNSearchIT extends KNNRestTestCase {
@SneakyThrows
public void testFilteredSearchWithFaissHnsw_whenFiltersMatchAllDocs_thenReturnCorrectResults() {
String filterFieldName = "color";
final int expectResultSize = randomIntBetween(1, 3);
final String filterValue = "red";
createKnnIndex(INDEX_NAME, getKNNDefaultIndexSettings(), createKnnIndexMapping(FIELD_NAME, 3, METHOD_HNSW, FAISS_NAME));

// ingest 4 vector docs into the index with the same field {"color": "red"}
for (int i = 0; i < 4; i++) {
addKnnDocWithAttributes(String.valueOf(i), new float[] { i, i, i }, ImmutableMap.of(filterFieldName, filterValue));
}

refreshIndex(INDEX_NAME);
forceMergeKnnIndex(INDEX_NAME);

updateIndexSettings(INDEX_NAME, Settings.builder().put(KNNSettings.ADVANCED_FILTERED_EXACT_SEARCH_THRESHOLD, 0));

Float[] queryVector = { 3f, 3f, 3f };
// All docs in one segment will match the filter value
String query = KNNJsonQueryBuilder.builder()
.fieldName(FIELD_NAME)
.vector(queryVector)
.k(expectResultSize)
.filterFieldName(filterFieldName)
.filterValue(filterValue)
.build()
.getQueryString();
Response response = searchKNNIndex(INDEX_NAME, query, expectResultSize);
String entity = EntityUtils.toString(response.getEntity());
List<String> docIds = parseIds(entity);
assertEquals(expectResultSize, docIds.size());
assertEquals(expectResultSize, parseTotalSearchHits(entity));
}
}
