diff --git a/src/test/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcherTests.java b/src/test/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcherTests.java
index 2aafa2ece..bd920f41a 100644
--- a/src/test/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcherTests.java
+++ b/src/test/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcherTests.java
@@ -58,6 +58,7 @@
 import org.opensearch.index.mapper.TextFieldMapper;
 import org.opensearch.index.query.BoolQueryBuilder;
 import org.opensearch.index.query.QueryBuilders;
+import org.opensearch.index.query.DisMaxQueryBuilder;
 import org.opensearch.index.query.QueryShardContext;
 import org.opensearch.index.query.TermQueryBuilder;
 import org.opensearch.index.remote.RemoteStoreEnums;
@@ -516,6 +517,104 @@ public void testWrappedHybridQuery_whenHybridWrappedIntoBool_thenFail() {
         releaseResources(directory, w, reader);
     }
 
+    @SneakyThrows
+    public void testWrappedHybridQuery_whenHybridNestedInDisjunctionQuery_thenFail() {
+        HybridQueryPhaseSearcher hybridQueryPhaseSearcher = new HybridQueryPhaseSearcher();
+        QueryShardContext mockQueryShardContext = mock(QueryShardContext.class);
+        when(mockQueryShardContext.index()).thenReturn(dummyIndex);
+        TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) createMapperService().fieldType(TEXT_FIELD_NAME);
+        when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType);
+        MapperService mapperService = mock(MapperService.class);
+        when(mapperService.hasNested()).thenReturn(false);
+
+        Directory directory = newDirectory();
+        IndexWriter w = new IndexWriter(directory, newIndexWriterConfig(new MockAnalyzer(random())));
+        FieldType ft = new FieldType(TextField.TYPE_NOT_STORED);
+        ft.setIndexOptions(random().nextBoolean() ? IndexOptions.DOCS : IndexOptions.DOCS_AND_FREQS);
+        ft.setOmitNorms(random().nextBoolean());
+        ft.freeze();
+        int docId1 = RandomizedTest.randomInt();
+        w.addDocument(getDocument(TEXT_FIELD_NAME, docId1, TEST_DOC_TEXT1, ft));
+        w.commit();
+
+        IndexReader reader = DirectoryReader.open(w);
+        SearchContext searchContext = mock(SearchContext.class);
+
+        ContextIndexSearcher contextIndexSearcher = new ContextIndexSearcher(
+            reader,
+            IndexSearcher.getDefaultSimilarity(),
+            IndexSearcher.getDefaultQueryCache(),
+            IndexSearcher.getDefaultQueryCachingPolicy(),
+            true,
+            null,
+            searchContext
+        );
+
+        ShardId shardId = new ShardId(dummyIndex, 1);
+        SearchShardTarget shardTarget = new SearchShardTarget(
+            randomAlphaOfLength(10),
+            shardId,
+            randomAlphaOfLength(10),
+            OriginalIndices.NONE
+        );
+        when(searchContext.shardTarget()).thenReturn(shardTarget);
+        when(searchContext.searcher()).thenReturn(contextIndexSearcher);
+        when(searchContext.size()).thenReturn(4);
+        QuerySearchResult querySearchResult = new QuerySearchResult();
+        when(searchContext.queryResult()).thenReturn(querySearchResult);
+        when(searchContext.numberOfShards()).thenReturn(1);
+        when(searchContext.searcher()).thenReturn(contextIndexSearcher);
+        IndexShard indexShard = mock(IndexShard.class);
+        when(indexShard.shardId()).thenReturn(new ShardId("test", "test", 0));
+        when(indexShard.getSearchOperationListener()).thenReturn(mock(SearchOperationListener.class));
+        when(searchContext.indexShard()).thenReturn(indexShard);
+        when(searchContext.bucketCollectorProcessor()).thenReturn(SearchContext.NO_OP_BUCKET_COLLECTOR_PROCESSOR);
+        when(searchContext.mapperService()).thenReturn(mapperService);
+        when(searchContext.getQueryShardContext()).thenReturn(mockQueryShardContext);
+        IndexMetadata indexMetadata = getIndexMetadata();
+        Settings settings = Settings.builder().put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, Integer.toString(1)).build();
+        IndexSettings indexSettings = new IndexSettings(indexMetadata, settings);
+        when(mockQueryShardContext.getIndexSettings()).thenReturn(indexSettings);
+
+        LinkedList collectors = new LinkedList<>();
+        boolean hasFilterCollector = randomBoolean();
+        boolean hasTimeout = randomBoolean();
+
+        // Create a HybridQueryBuilder
+        HybridQueryBuilder hybridQueryBuilder = new HybridQueryBuilder();
+        hybridQueryBuilder.add(QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY_TEXT1));
+        hybridQueryBuilder.add(QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY_TEXT2));
+        hybridQueryBuilder.paginationDepth(10);
+
+        // Create a regular term query
+        TermQueryBuilder termQuery = QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY_TEXT2);
+
+        // Create a disjunction query (OR) with the hybrid query and the term query
+        DisMaxQueryBuilder disjunctionMaxQueryBuilder = QueryBuilders.disMaxQuery().add(hybridQueryBuilder).add(termQuery);
+
+        Query query = disjunctionMaxQueryBuilder.toQuery(mockQueryShardContext);
+        when(searchContext.query()).thenReturn(query);
+
+        IllegalArgumentException exception = expectThrows(
+            IllegalArgumentException.class,
+            () -> hybridQueryPhaseSearcher.searchWith(
+                searchContext,
+                contextIndexSearcher,
+                query,
+                collectors,
+                hasFilterCollector,
+                hasTimeout
+            )
+        );
+
+        org.hamcrest.MatcherAssert.assertThat(
+            exception.getMessage(),
+            containsString("hybrid query must be a top level query and cannot be wrapped into other queries")
+        );
+
+        releaseResources(directory, w, reader);
+    }
+
     @SneakyThrows
     public void testWrappedHybridQuery_whenHybridWrappedIntoBoolAndIncorrectStructure_thenFail() {
         HybridQueryPhaseSearcher hybridQueryPhaseSearcher = new HybridQueryPhaseSearcher();