Skip to content

Commit

Permalink
[LTR] Update the feature name from "learn to rank" to "learning to ra…
Browse files Browse the repository at this point in the history
…nk". (elastic#102938)
  • Loading branch information
afoucret authored Dec 5, 2023
1 parent 85311b2 commit be98a46
Show file tree
Hide file tree
Showing 23 changed files with 351 additions and 333 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ public class SearchUsageStatsTests extends AbstractWireSerializingTestCase<Searc
"script_score"
);

private static final List<String> RESCORER_TYPES = List.of("query", "learn_to_rank");
private static final List<String> RESCORER_TYPES = List.of("query", "learning_to_rank");

private static final List<String> SECTIONS = List.of(
"highlight",
Expand Down Expand Up @@ -136,14 +136,14 @@ public void testAdd() {
searchUsageStats.add(
new SearchUsageStats(
Map.of("term", 1L, "match", 1L),
Map.of("query", 5L, "learn_to_rank", 2L),
Map.of("query", 5L, "learning_to_rank", 2L),
Map.of("query", 10L, "knn", 1L),
10L
)
);
assertEquals(Map.of("match", 11L, "term", 1L), searchUsageStats.getQueryUsage());
assertEquals(Map.of("query", 20L, "knn", 1L), searchUsageStats.getSectionsUsage());
assertEquals(Map.of("query", 10L, "learn_to_rank", 2L), searchUsageStats.getRescorerUsage());
assertEquals(Map.of("query", 10L, "learning_to_rank", 2L), searchUsageStats.getRescorerUsage());
assertEquals(20L, searchUsageStats.getTotalSearchCount());
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
*/
public enum FeatureFlag {
TIME_SERIES_MODE("es.index_mode_feature_flag_registered=true", Version.fromString("8.0.0"), null),
LEARN_TO_RANK("es.learn_to_rank_feature_flag_enabled=true", Version.fromString("8.10.0"), null),
LEARNING_TO_RANK("es.learning_to_rank_feature_flag_enabled=true", Version.fromString("8.12.0"), null),
FAILURE_STORE_ENABLED("es.failure_store_feature_flag_enabled=true", Version.fromString("8.12.0"), null);

public final String systemProperty;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xcontent.XContentParser;
import org.elasticsearch.xpack.core.ml.MlConfigVersion;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ltr.LearnToRankFeatureExtractorBuilder;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ltr.LearningToRankFeatureExtractorBuilder;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ltr.QueryExtractorBuilder;
import org.elasticsearch.xpack.core.ml.utils.NamedXContentObjectHelper;

Expand All @@ -30,60 +30,60 @@
import java.util.Set;
import java.util.stream.Collectors;

public class LearnToRankConfig extends RegressionConfig implements Rewriteable<LearnToRankConfig> {
public class LearningToRankConfig extends RegressionConfig implements Rewriteable<LearningToRankConfig> {

public static final ParseField NAME = new ParseField("learn_to_rank");
public static final ParseField NAME = new ParseField("learning_to_rank");
static final TransportVersion MIN_SUPPORTED_TRANSPORT_VERSION = TransportVersion.current();
public static final ParseField NUM_TOP_FEATURE_IMPORTANCE_VALUES = new ParseField("num_top_feature_importance_values");
public static final ParseField FEATURE_EXTRACTORS = new ParseField("feature_extractors");
public static final ParseField DEFAULT_PARAMS = new ParseField("default_params");

public static LearnToRankConfig EMPTY_PARAMS = new LearnToRankConfig(null, null, null);
public static LearningToRankConfig EMPTY_PARAMS = new LearningToRankConfig(null, null, null);

private static final ObjectParser<LearnToRankConfig.Builder, Boolean> LENIENT_PARSER = createParser(true);
private static final ObjectParser<LearnToRankConfig.Builder, Boolean> STRICT_PARSER = createParser(false);
private static final ObjectParser<LearningToRankConfig.Builder, Boolean> LENIENT_PARSER = createParser(true);
private static final ObjectParser<LearningToRankConfig.Builder, Boolean> STRICT_PARSER = createParser(false);

private static ObjectParser<LearnToRankConfig.Builder, Boolean> createParser(boolean lenient) {
ObjectParser<LearnToRankConfig.Builder, Boolean> parser = new ObjectParser<>(
private static ObjectParser<LearningToRankConfig.Builder, Boolean> createParser(boolean lenient) {
ObjectParser<LearningToRankConfig.Builder, Boolean> parser = new ObjectParser<>(
NAME.getPreferredName(),
lenient,
LearnToRankConfig.Builder::new
LearningToRankConfig.Builder::new
);
parser.declareInt(Builder::setNumTopFeatureImportanceValues, NUM_TOP_FEATURE_IMPORTANCE_VALUES);
parser.declareNamedObjects(
Builder::setLearnToRankFeatureExtractorBuilders,
(p, c, n) -> p.namedObject(LearnToRankFeatureExtractorBuilder.class, n, lenient),
Builder::setLearningToRankFeatureExtractorBuilders,
(p, c, n) -> p.namedObject(LearningToRankFeatureExtractorBuilder.class, n, lenient),
b -> {},
FEATURE_EXTRACTORS
);
parser.declareObject(Builder::setParamsDefaults, (p, c) -> p.map(), DEFAULT_PARAMS);
return parser;
}

public static LearnToRankConfig fromXContentStrict(XContentParser parser) {
public static LearningToRankConfig fromXContentStrict(XContentParser parser) {
return STRICT_PARSER.apply(parser, null).build();
}

public static LearnToRankConfig fromXContentLenient(XContentParser parser) {
public static LearningToRankConfig fromXContentLenient(XContentParser parser) {
return LENIENT_PARSER.apply(parser, null).build();
}

public static Builder builder(LearnToRankConfig config) {
public static Builder builder(LearningToRankConfig config) {
return new Builder(config);
}

private final List<LearnToRankFeatureExtractorBuilder> featureExtractorBuilders;
private final List<LearningToRankFeatureExtractorBuilder> featureExtractorBuilders;
private final Map<String, Object> paramsDefaults;

public LearnToRankConfig(
public LearningToRankConfig(
Integer numTopFeatureImportanceValues,
List<LearnToRankFeatureExtractorBuilder> featureExtractorBuilders,
List<LearningToRankFeatureExtractorBuilder> featureExtractorBuilders,
Map<String, Object> paramsDefaults
) {
super(DEFAULT_RESULTS_FIELD, numTopFeatureImportanceValues);
if (featureExtractorBuilders != null) {
Set<String> featureNames = featureExtractorBuilders.stream()
.map(LearnToRankFeatureExtractorBuilder::featureName)
.map(LearningToRankFeatureExtractorBuilder::featureName)
.collect(Collectors.toSet());
if (featureNames.size() < featureExtractorBuilders.size()) {
throw new IllegalArgumentException(
Expand All @@ -95,19 +95,19 @@ public LearnToRankConfig(
this.paramsDefaults = Collections.unmodifiableMap(Objects.requireNonNullElse(paramsDefaults, Map.of()));
}

public LearnToRankConfig(StreamInput in) throws IOException {
public LearningToRankConfig(StreamInput in) throws IOException {
super(in);
this.featureExtractorBuilders = in.readNamedWriteableCollectionAsList(LearnToRankFeatureExtractorBuilder.class);
this.featureExtractorBuilders = in.readNamedWriteableCollectionAsList(LearningToRankFeatureExtractorBuilder.class);
this.paramsDefaults = in.readMap();
}

public List<LearnToRankFeatureExtractorBuilder> getFeatureExtractorBuilders() {
public List<LearningToRankFeatureExtractorBuilder> getFeatureExtractorBuilders() {
return featureExtractorBuilders;
}

public List<QueryExtractorBuilder> getQueryFeatureExtractorBuilders() {
List<QueryExtractorBuilder> queryExtractorBuilders = new ArrayList<>();
for (LearnToRankFeatureExtractorBuilder featureExtractorBuilder : featureExtractorBuilders) {
for (LearningToRankFeatureExtractorBuilder featureExtractorBuilder : featureExtractorBuilders) {
if (featureExtractorBuilder instanceof QueryExtractorBuilder queryExtractorBuilder) {
queryExtractorBuilders.add(queryExtractorBuilder);
}
Expand Down Expand Up @@ -189,7 +189,7 @@ public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
if (super.equals(o) == false) return false;
LearnToRankConfig that = (LearnToRankConfig) o;
LearningToRankConfig that = (LearningToRankConfig) o;
return Objects.equals(featureExtractorBuilders, that.featureExtractorBuilders)
&& Objects.equals(paramsDefaults, that.paramsDefaults);
}
Expand Down Expand Up @@ -220,33 +220,33 @@ public TransportVersion getMinimalSupportedTransportVersion() {
}

@Override
public LearnToRankConfig rewrite(QueryRewriteContext ctx) throws IOException {
public LearningToRankConfig rewrite(QueryRewriteContext ctx) throws IOException {
if (this.featureExtractorBuilders.isEmpty()) {
return this;
}
boolean rewritten = false;
List<LearnToRankFeatureExtractorBuilder> rewrittenExtractors = new ArrayList<>(this.featureExtractorBuilders.size());
for (LearnToRankFeatureExtractorBuilder extractorBuilder : this.featureExtractorBuilders) {
LearnToRankFeatureExtractorBuilder rewrittenExtractor = Rewriteable.rewrite(extractorBuilder, ctx);
List<LearningToRankFeatureExtractorBuilder> rewrittenExtractors = new ArrayList<>(this.featureExtractorBuilders.size());
for (LearningToRankFeatureExtractorBuilder extractorBuilder : this.featureExtractorBuilders) {
LearningToRankFeatureExtractorBuilder rewrittenExtractor = Rewriteable.rewrite(extractorBuilder, ctx);
rewrittenExtractors.add(rewrittenExtractor);
rewritten |= (rewrittenExtractor != extractorBuilder);
}
if (rewritten) {
return new LearnToRankConfig(getNumTopFeatureImportanceValues(), rewrittenExtractors, paramsDefaults);
return new LearningToRankConfig(getNumTopFeatureImportanceValues(), rewrittenExtractors, paramsDefaults);
}
return this;
}

public static class Builder {
private Integer numTopFeatureImportanceValues;
private List<LearnToRankFeatureExtractorBuilder> learnToRankFeatureExtractorBuilders;
private List<LearningToRankFeatureExtractorBuilder> learningToRankFeatureExtractorBuilders;
private Map<String, Object> paramsDefaults = Map.of();

Builder() {}

Builder(LearnToRankConfig config) {
Builder(LearningToRankConfig config) {
this.numTopFeatureImportanceValues = config.getNumTopFeatureImportanceValues();
this.learnToRankFeatureExtractorBuilders = config.featureExtractorBuilders;
this.learningToRankFeatureExtractorBuilders = config.featureExtractorBuilders;
this.paramsDefaults = config.getParamsDefaults();
}

Expand All @@ -255,10 +255,10 @@ public Builder setNumTopFeatureImportanceValues(Integer numTopFeatureImportanceV
return this;
}

public Builder setLearnToRankFeatureExtractorBuilders(
List<LearnToRankFeatureExtractorBuilder> learnToRankFeatureExtractorBuilders
public Builder setLearningToRankFeatureExtractorBuilders(
List<LearningToRankFeatureExtractorBuilder> learningToRankFeatureExtractorBuilders
) {
this.learnToRankFeatureExtractorBuilders = learnToRankFeatureExtractorBuilders;
this.learningToRankFeatureExtractorBuilders = learningToRankFeatureExtractorBuilders;
return this;
}

Expand All @@ -267,8 +267,8 @@ public Builder setParamsDefaults(Map<String, Object> paramsDefaults) {
return this;
}

public LearnToRankConfig build() {
return new LearnToRankConfig(numTopFeatureImportanceValues, learnToRankFeatureExtractorBuilders, paramsDefaults);
public LearningToRankConfig build() {
return new LearningToRankConfig(numTopFeatureImportanceValues, learningToRankFeatureExtractorBuilders, paramsDefaults);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,11 @@
import org.elasticsearch.xcontent.ParseField;
import org.elasticsearch.xpack.core.ml.utils.NamedXContentObject;

public interface LearnToRankFeatureExtractorBuilder
public interface LearningToRankFeatureExtractorBuilder
extends
NamedXContentObject,
NamedWriteable,
Rewriteable<LearnToRankFeatureExtractorBuilder> {
Rewriteable<LearningToRankFeatureExtractorBuilder> {

ParseField FEATURE_NAME = new ParseField("feature_name");

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@

public record QueryExtractorBuilder(String featureName, QueryProvider query, float defaultScore)
implements
LearnToRankFeatureExtractorBuilder {
LearningToRankFeatureExtractorBuilder {

public static final ParseField NAME = new ParseField("query_extractor");
public static final ParseField FEATURE_NAME = new ParseField("feature_name");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,10 @@
import org.elasticsearch.plugins.spi.NamedXContentProvider;
import org.elasticsearch.xcontent.NamedXContentRegistry;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.LearnToRankConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.LearningToRankConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.LenientlyParsedInferenceConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.StrictlyParsedInferenceConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ltr.LearnToRankFeatureExtractorBuilder;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ltr.LearningToRankFeatureExtractorBuilder;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ltr.QueryExtractorBuilder;

import java.util.ArrayList;
Expand All @@ -32,22 +32,22 @@ public List<NamedXContentRegistry.Entry> getNamedXContentParsers() {
namedXContent.add(
new NamedXContentRegistry.Entry(
LenientlyParsedInferenceConfig.class,
LearnToRankConfig.NAME,
LearnToRankConfig::fromXContentLenient
LearningToRankConfig.NAME,
LearningToRankConfig::fromXContentLenient
)
);
// Strict Inference Config
namedXContent.add(
new NamedXContentRegistry.Entry(
StrictlyParsedInferenceConfig.class,
LearnToRankConfig.NAME,
LearnToRankConfig::fromXContentStrict
LearningToRankConfig.NAME,
LearningToRankConfig::fromXContentStrict
)
);
// LTR extractors
namedXContent.add(
new NamedXContentRegistry.Entry(
LearnToRankFeatureExtractorBuilder.class,
LearningToRankFeatureExtractorBuilder.class,
QueryExtractorBuilder.NAME,
QueryExtractorBuilder::fromXContent
)
Expand All @@ -59,12 +59,12 @@ public List<NamedWriteableRegistry.Entry> getNamedWriteables() {
List<NamedWriteableRegistry.Entry> namedWriteables = new ArrayList<>();
// Inference config
namedWriteables.add(
new NamedWriteableRegistry.Entry(InferenceConfig.class, LearnToRankConfig.NAME.getPreferredName(), LearnToRankConfig::new)
new NamedWriteableRegistry.Entry(InferenceConfig.class, LearningToRankConfig.NAME.getPreferredName(), LearningToRankConfig::new)
);
// LTR Extractors
namedWriteables.add(
new NamedWriteableRegistry.Entry(
LearnToRankFeatureExtractorBuilder.class,
LearningToRankFeatureExtractorBuilder.class,
QueryExtractorBuilder.NAME.getPreferredName(),
QueryExtractorBuilder::new
)
Expand Down
Loading

0 comments on commit be98a46

Please sign in to comment.