Skip to content

Commit

Permalink
[Backport 2.x] ByFieldRerank Processor (ReRankProcessor enhancement) (#…
Browse files Browse the repository at this point in the history
…960) (#962)

* ByFieldRerank Processor (ReRankProcessor enhancement) (#932)

Signed-off-by: Brian Flores <[email protected]>
(cherry picked from commit 858ff28)

Signed-off-by: Martin Gaievski <[email protected]>
Co-authored-by: Brian Flores <[email protected]>
Co-authored-by: Martin Gaievski <[email protected]>
  • Loading branch information
3 people authored Oct 22, 2024
1 parent 5396c6e commit d4d13c6
Show file tree
Hide file tree
Showing 13 changed files with 2,050 additions and 27 deletions.
2 changes: 1 addition & 1 deletion DEVELOPER_GUIDE.md
Original file line number Diff line number Diff line change
Expand Up @@ -313,7 +313,7 @@ merged to main, the workflow will create a backport PR to the `2.x` branch.

## Building On Lucene Version Updates
There may be a Lucene version update that can affect your workflow causing errors like
`java.lang.NoClassDefFoundError: org/apache/lucene/codecs/lucene99/Lucene99Codec` or
`java.lang.NoClassDefFoundError: org/apache/lucene/codecs/lucene99/Lucene99Codec` or
`Provider org.opensearch.knn.index.codec.KNN910Codec.KNN910Codec could not be instantiated`. In this case
we can observe there may be an issue with a dependency with [K-NN](https://github.com/opensearch-project/k-NN).
As a result, `./gradlew run` and `./gradlew build` fail.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,16 +50,16 @@ public void testHybridQueryWithRescore_whenIndexWithMultipleShards_E2EFlow() thr
loadModel(modelId);
createPipelineProcessor(modelId, PIPELINE_NAME);
createIndexWithConfiguration(
getIndexNameForTest(),
Files.readString(Path.of(classLoader.getResource("processor/IndexMappings.json").toURI())),
PIPELINE_NAME
getIndexNameForTest(),
Files.readString(Path.of(classLoader.getResource("processor/IndexMappings.json").toURI())),
PIPELINE_NAME
);
addDocument(getIndexNameForTest(), "0", TEST_FIELD, TEXT, null, null);
createSearchPipeline(
SEARCH_PIPELINE_NAME,
DEFAULT_NORMALIZATION_METHOD,
DEFAULT_COMBINATION_METHOD,
Map.of(PARAM_NAME_WEIGHTS, Arrays.toString(new float[] { 0.3f, 0.7f }))
SEARCH_PIPELINE_NAME,
DEFAULT_NORMALIZATION_METHOD,
DEFAULT_COMBINATION_METHOD,
Map.of(PARAM_NAME_WEIGHTS, Arrays.toString(new float[] { 0.3f, 0.7f }))
);
break;
case MIXED:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@

Compatible with OpenSearch 2.18.0

### Features
- Introduces ByFieldRerankProcessor for second level reranking on documents ([#932](https://github.com/opensearch-project/neural-search/pull/932))
### Bug Fixes
- Fixed incorrect document order for nested aggregations in hybrid query ([#956](https://github.com/opensearch-project/neural-search/pull/956))
### Enhancements
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,12 @@
*/
package org.opensearch.neuralsearch.processor.factory;

import java.util.ArrayList;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Set;
import java.util.StringJoiner;

import com.google.common.collect.Sets;
import lombok.AllArgsConstructor;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.ingest.ConfigurationUtils;
import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor;
import org.opensearch.neuralsearch.processor.rerank.ByFieldRerankProcessor;
import org.opensearch.neuralsearch.processor.rerank.MLOpenSearchRerankProcessor;
import org.opensearch.neuralsearch.processor.rerank.RerankType;
import org.opensearch.neuralsearch.processor.rerank.context.ContextSourceFetcher;
Expand All @@ -22,9 +18,17 @@
import org.opensearch.search.pipeline.Processor;
import org.opensearch.search.pipeline.SearchResponseProcessor;

import com.google.common.collect.Sets;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Set;
import java.util.StringJoiner;

import lombok.AllArgsConstructor;
import static org.opensearch.neuralsearch.processor.rerank.ByFieldRerankProcessor.DEFAULT_KEEP_PREVIOUS_SCORE;
import static org.opensearch.neuralsearch.processor.rerank.ByFieldRerankProcessor.DEFAULT_REMOVE_TARGET_FIELD;
import static org.opensearch.neuralsearch.processor.rerank.RerankProcessor.processorRequiresContext;

/**
* Factory for rerank processors. Must:
Expand All @@ -51,22 +55,55 @@ public SearchResponseProcessor create(
) {
RerankType type = findRerankType(config);
boolean includeQueryContextFetcher = ContextFetcherFactory.shouldIncludeQueryContextFetcher(type);
List<ContextSourceFetcher> contextFetchers = ContextFetcherFactory.createFetchers(
config,
includeQueryContextFetcher,
tag,
clusterService
);

// Currently the createFetchers method requires that you provide a context map, this branch makes sure we can ignore this on
// processors that don't need the context map
List<ContextSourceFetcher> contextFetchers = processorRequiresContext(type)
? ContextFetcherFactory.createFetchers(config, includeQueryContextFetcher, tag, clusterService)
: Collections.emptyList();

Map<String, Object> rerankerConfig = ConfigurationUtils.readMap(RERANK_PROCESSOR_TYPE, tag, config, type.getLabel());

switch (type) {
case ML_OPENSEARCH:
Map<String, Object> rerankerConfig = ConfigurationUtils.readMap(RERANK_PROCESSOR_TYPE, tag, config, type.getLabel());
String modelId = ConfigurationUtils.readStringProperty(
RERANK_PROCESSOR_TYPE,
tag,
rerankerConfig,
MLOpenSearchRerankProcessor.MODEL_ID_FIELD
);
return new MLOpenSearchRerankProcessor(description, tag, ignoreFailure, modelId, contextFetchers, clientAccessor);
case BY_FIELD:
String targetField = ConfigurationUtils.readStringProperty(
RERANK_PROCESSOR_TYPE,
tag,
rerankerConfig,
ByFieldRerankProcessor.TARGET_FIELD
);
boolean removeTargetField = ConfigurationUtils.readBooleanProperty(
RERANK_PROCESSOR_TYPE,
tag,
rerankerConfig,
ByFieldRerankProcessor.REMOVE_TARGET_FIELD,
DEFAULT_REMOVE_TARGET_FIELD
);
boolean keepPreviousScore = ConfigurationUtils.readBooleanProperty(
RERANK_PROCESSOR_TYPE,
tag,
rerankerConfig,
ByFieldRerankProcessor.KEEP_PREVIOUS_SCORE,
DEFAULT_KEEP_PREVIOUS_SCORE
);

return new ByFieldRerankProcessor(
description,
tag,
ignoreFailure,
targetField,
removeTargetField,
keepPreviousScore,
contextFetchers
);
default:
throw new IllegalArgumentException(String.format(Locale.ROOT, "Cannot build reranker type %s", type.getLabel()));
}
Expand Down Expand Up @@ -100,6 +137,7 @@ private static class ContextFetcherFactory {

/**
* Map rerank types to whether they should include the query context source fetcher
*
* @param type the constructing RerankType
* @return does this RerankType depend on the QueryContextSourceFetcher?
*/
Expand All @@ -109,8 +147,8 @@ public static boolean shouldIncludeQueryContextFetcher(RerankType type) {

/**
* Create necessary queryContextFetchers for this processor
* @param config processor config object. Look for "context" field to find fetchers
* @param includeQueryContextFetcher should I include the queryContextFetcher?
* @param config Processor config object. Look for "context" field to find fetchers
* @param includeQueryContextFetcher Should I include the queryContextFetcher?
* @return list of contextFetchers for the processor to use
*/
public static List<ContextSourceFetcher> createFetchers(
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,189 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/
package org.opensearch.neuralsearch.processor.rerank;

import lombok.extern.log4j.Log4j2;
import org.opensearch.action.search.SearchResponse;
import org.opensearch.common.xcontent.XContentType;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.common.bytes.BytesReference;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.neuralsearch.processor.rerank.context.ContextSourceFetcher;
import org.opensearch.neuralsearch.processor.util.ProcessorUtils.SearchHitValidator;
import org.opensearch.search.SearchHit;

import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Optional;

import static org.opensearch.neuralsearch.processor.util.ProcessorUtils.getScoreFromSourceMap;
import static org.opensearch.neuralsearch.processor.util.ProcessorUtils.getValueFromSource;
import static org.opensearch.neuralsearch.processor.util.ProcessorUtils.mappingExistsInSource;
import static org.opensearch.neuralsearch.processor.util.ProcessorUtils.removeTargetFieldFromSource;
import static org.opensearch.neuralsearch.processor.util.ProcessorUtils.validateRerankCriteria;

/**
* A reranking processor that reorders search results based on the content of a specified field.
* <p>
* The ByFieldRerankProcessor allows for reordering of search results by considering the content of a
* designated target field within each document. This processor will update the <code>_score</code> field with what has been provided
* by {@code target_field}. When {@code keep_previous_score} is enabled a new field is appended called <code>previous_score</code> which was the score prior to reranking.
* <p>
* Key features:
* <ul>
* <li>Reranks search results based on a specified target field</li>
* <li>Optionally removes the target field from the final search results</li>
* <li>Supports nested field structures using dot notation</li>
* </ul>
* <p>
* The processor uses the following configuration parameters:
* <ul>
* <li>{@code target_field}: The field to be used for reranking (required)</li>
* <li>{@code remove_target_field}: Whether to remove the target field from the final results (optional, default: false)</li>
* <li>{@code keep_previous_score}: Whether to append the previous score in a field called <code>previous_score</code> (optional, default: false)</li>
* </ul>
* <p>
* Usage example:
* <pre>
* {
* "rerank": {
* "by_field": {
* "target_field": "document.relevance_score",
* "remove_target_field": true,
* "keep_previous_score": false
* }
* }
* }
* </pre>
* <p>
* This processor is useful in scenarios where additional, document-specific
* information stored in a field can be used to improve the relevance of search results
* beyond the initial scoring.
*/
@Log4j2
public class ByFieldRerankProcessor extends RescoringRerankProcessor {

public static final String TARGET_FIELD = "target_field";
public static final String REMOVE_TARGET_FIELD = "remove_target_field";
public static final String KEEP_PREVIOUS_SCORE = "keep_previous_score";

public static final boolean DEFAULT_REMOVE_TARGET_FIELD = false;
public static final boolean DEFAULT_KEEP_PREVIOUS_SCORE = false;

protected final String targetField;
protected final boolean removeTargetField;
protected final boolean keepPreviousScore;

/**
* Constructor to pass values to the RerankProcessor constructor.
*
* @param description The description of the processor
* @param tag The processor's identifier
* @param ignoreFailure If true, OpenSearch ignores any failure of this processor and
* continues to run the remaining processors in the search pipeline.
* @param targetField The field you want to replace your <code>_score</code> with
* @param removeTargetField A flag to let you delete the target_field for better visualization (i.e. removes a duplicate value)
* @param keepPreviousScore A flag to let you decide to stash your previous <code>_score</code> in a field called <code>previous_score</code> (i.e. for debugging purposes)
* @param contextSourceFetchers Context from some source and puts it in a map for a reranking processor to use <b> (Unused in ByFieldRerankProcessor)</b>
*/
public ByFieldRerankProcessor(
final String description,
final String tag,
final boolean ignoreFailure,
final String targetField,
final boolean removeTargetField,
final boolean keepPreviousScore,
final List<ContextSourceFetcher> contextSourceFetchers
) {
super(RerankType.BY_FIELD, description, tag, ignoreFailure, contextSourceFetchers);
this.targetField = targetField;
this.removeTargetField = removeTargetField;
this.keepPreviousScore = keepPreviousScore;
}

@Override
public void rescoreSearchResponse(
final SearchResponse response,
final Map<String, Object> rerankingContext,
final ActionListener<List<Float>> listener
) {
SearchHit[] searchHits = response.getHits().getHits();

SearchHitValidator searchHitValidator = this::byFieldSearchHitValidator;

if (!validateRerankCriteria(searchHits, searchHitValidator, listener)) {
return;
}

List<Float> scores = new ArrayList<>(searchHits.length);

for (SearchHit hit : searchHits) {
Map<String, Object> sourceAsMap = hit.getSourceAsMap();

float score = getScoreFromSourceMap(sourceAsMap, targetField);
scores.add(score);

if (keepPreviousScore) {
sourceAsMap.put("previous_score", hit.getScore());
}

if (removeTargetField) {
removeTargetFieldFromSource(sourceAsMap, targetField);
}

try {
XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent());
BytesReference sourceMapAsBytes = BytesReference.bytes(builder.map(sourceAsMap));
hit.sourceRef(sourceMapAsBytes);
} catch (IOException e) {
log.error(e.getMessage());
listener.onFailure(new RuntimeException(e));
return;
}
}

listener.onResponse(scores);
}

/**
* Implements the behavior of the SearchHit validator {@code SearchHitValidator}
* It checks all the following
* <ul>
* <li>Checks the search hit has a source mapping</li>
* <li>Checks that the mapping exists in the source mapping using the target_field</li>
* <li>Checks that the mapping has a numerical score for it to rerank</li>
* </ul>
* @param hit A search hit to validate
*/
public void byFieldSearchHitValidator(final SearchHit hit) {
if (!hit.hasSource()) {
log.error(String.format(Locale.ROOT, "There is no source field to be able to perform rerank on hit [%d]", hit.docId()));
throw new IllegalArgumentException(
String.format(Locale.ROOT, "There is no source field to be able to perform rerank on hit [%d]", hit.docId())
);
}

Map<String, Object> sourceMap = hit.getSourceAsMap();
if (!mappingExistsInSource(sourceMap, targetField)) {
log.error(String.format(Locale.ROOT, "The field to rerank [%s] is not found at hit [%d]", targetField, hit.docId()));

throw new IllegalArgumentException(String.format(Locale.ROOT, "The field to rerank by is not found at hit [%d]", hit.docId()));
}

Optional<Object> val = getValueFromSource(sourceMap, targetField);

if (!(val.get() instanceof Number)) {
log.error(String.format(Locale.ROOT, "The field mapping to rerank [%s: %s] is not Numerical", targetField, val.orElse(null)));

throw new IllegalArgumentException(
String.format(Locale.ROOT, "The field mapping to rerank by [%s] is not Numerical", val.orElse(null))
);
}

}
}
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ public abstract class RerankProcessor implements SearchResponseProcessor {
@Getter
private final boolean ignoreFailure;
protected List<ContextSourceFetcher> contextSourceFetchers;
// Rerank types that can rescore without a context map; consulted by processorRequiresContext.
protected static final List<RerankType> processorsWithNoContext = List.of(RerankType.BY_FIELD);

/**
* Generate the information that this processor needs in order to rerank.
Expand All @@ -48,6 +49,11 @@ public void generateRerankingContext(
final SearchResponse searchResponse,
final ActionListener<Map<String, Object>> listener
) {
// Processors that don't require context must be answered here; without this check the listener would wait forever for a response
if (!processorRequiresContext(subType)) {
listener.onResponse(Map.of());
}

Map<String, Object> overallContext = new ConcurrentHashMap<>();
AtomicInteger successfulContexts = new AtomicInteger(contextSourceFetchers.size());
for (ContextSourceFetcher csf : contextSourceFetchers) {
Expand Down Expand Up @@ -102,4 +108,19 @@ public void processResponseAsync(
responseListener.onFailure(e);
}
}

/**
 * Tells whether a rerank sub-type needs a reranking context in order to rescore a search response.
 * <p>
 * A sub-type needs context unless it is listed in {@code processorsWithNoContext}. Currently the
 * processors that do not require the context mapping are:
 * <ul>
 * <li>
 * ByFieldRerankProcessor - Uses the search response to get the value to rescore by
 * </li>
 * </ul>
 *
 * @param subType The kind of rerank processor
 * @return {@code true} if the sub-type needs context to perform the rescore search response action,
 *         {@code false} if it is one of the context-free sub-types
 */
public static boolean processorRequiresContext(RerankType subType) {
    final boolean contextFree = processorsWithNoContext.contains(subType);
    return !contextFree;
}
}
Loading

0 comments on commit d4d13c6

Please sign in to comment.