diff --git a/src/antlr/Lexer.g b/src/antlr/Lexer.g index 3da08245c3d6..dd050c1478e7 100644 --- a/src/antlr/Lexer.g +++ b/src/antlr/Lexer.g @@ -227,6 +227,7 @@ K_DROPPED: D R O P P E D; K_COLUMN: C O L U M N; K_RECORD: R E C O R D; K_ANN: A N N; +K_BM25: B M '2' '5'; // Case-insensitive alpha characters fragment A: ('a'|'A'); diff --git a/src/antlr/Parser.g b/src/antlr/Parser.g index c0544620d51d..6b76c9d74bed 100644 --- a/src/antlr/Parser.g +++ b/src/antlr/Parser.g @@ -459,14 +459,18 @@ customIndexExpression [WhereClause.Builder clause] ; orderByClause[List orderings] - @init{ + @init { Ordering.Direction direction = Ordering.Direction.ASC; + Ordering.Raw.Expression expr = null; } - : c=cident (K_ANN K_OF t=term)? (K_ASC | K_DESC { direction = Ordering.Direction.DESC; })? + : c=cident + ( K_ANN K_OF t=term { expr = new Ordering.Raw.Ann(c, t); } + | K_BM25 K_OF t=term { expr = new Ordering.Raw.Bm25(c, t); } + )? + (K_ASC | K_DESC { direction = Ordering.Direction.DESC; })? { - Ordering.Raw.Expression expr = (t == null) - ? new Ordering.Raw.SingleColumn(c) - : new Ordering.Raw.Ann(c, t); + if (expr == null) + expr = new Ordering.Raw.SingleColumn(c); orderings.add(new Ordering.Raw(expr, direction)); } ; @@ -1969,6 +1973,7 @@ basic_unreserved_keyword returns [String str] | K_COLUMN | K_RECORD | K_ANN + | K_BM25 | K_OFFSET ) { $str = $k.text; } ; diff --git a/src/java/org/apache/cassandra/cql3/GeoDistanceRelation.java b/src/java/org/apache/cassandra/cql3/GeoDistanceRelation.java index 3d5fb2eeeef7..640e7e600686 100644 --- a/src/java/org/apache/cassandra/cql3/GeoDistanceRelation.java +++ b/src/java/org/apache/cassandra/cql3/GeoDistanceRelation.java @@ -141,6 +141,12 @@ protected Restriction newAnnRestriction(TableMetadata table, VariableSpecificati throw invalidRequest("%s cannot be used with the GEO_DISTANCE function", operator()); } + @Override + protected Restriction newBm25Restriction(TableMetadata table, VariableSpecifications boundNames) + { + throw invalidRequest("%s cannot be used with the GEO_DISTANCE function", operator()); + } + @Override protected Restriction newAnalyzerMatchesRestriction(TableMetadata table, VariableSpecifications boundNames) { diff --git a/src/java/org/apache/cassandra/cql3/MultiColumnRelation.java b/src/java/org/apache/cassandra/cql3/MultiColumnRelation.java index fcd505fcf5df..f56d76e2ced9 100644 --- a/src/java/org/apache/cassandra/cql3/MultiColumnRelation.java +++ b/src/java/org/apache/cassandra/cql3/MultiColumnRelation.java @@ -250,6 +250,12 @@ protected Restriction newAnnRestriction(TableMetadata table, VariableSpecificati throw invalidRequest("%s cannot be used for multi-column relations", operator()); } + @Override + protected Restriction newBm25Restriction(TableMetadata table, VariableSpecifications boundNames) + { + throw invalidRequest("%s cannot be used for multi-column relations", operator()); + } + @Override protected Restriction newAnalyzerMatchesRestriction(TableMetadata table, VariableSpecifications boundNames) { diff --git a/src/java/org/apache/cassandra/cql3/Operator.java b/src/java/org/apache/cassandra/cql3/Operator.java index 9908c64c39ea..d393fc1aa416 100644 --- a/src/java/org/apache/cassandra/cql3/Operator.java +++ b/src/java/org/apache/cassandra/cql3/Operator.java @@ -323,7 +323,7 @@ public boolean isSatisfiedBy(AbstractType type, @Nullable Index.Analyzer indexAnalyzer, @Nullable Index.Analyzer queryAnalyzer) { - return true; + throw new UnsupportedOperationException(); } }, NOT_IN(16) @@ -523,6 +523,7 @@ private boolean hasToken(AbstractType 
type, List tokens, ByteBuff return false; } }, + /** * An operator that performs a distance bounded approximate nearest neighbor search against a vector column such * that all result vectors are within a given distance of the query vector. The notable difference between this @@ -584,6 +585,24 @@ public boolean isSatisfiedBy(AbstractType type, { throw new UnsupportedOperationException(); } + }, + BM25(104) + { + @Override + public String toString() + { + return "BM25"; + } + + @Override + public boolean isSatisfiedBy(AbstractType type, + ByteBuffer leftOperand, + ByteBuffer rightOperand, + @Nullable Index.Analyzer indexAnalyzer, + @Nullable Index.Analyzer queryAnalyzer) + { + throw new UnsupportedOperationException(); + } }; /** diff --git a/src/java/org/apache/cassandra/cql3/Ordering.java b/src/java/org/apache/cassandra/cql3/Ordering.java index 81aa94a076cd..2dd817818a76 100644 --- a/src/java/org/apache/cassandra/cql3/Ordering.java +++ b/src/java/org/apache/cassandra/cql3/Ordering.java @@ -20,6 +20,7 @@ import org.apache.cassandra.cql3.restrictions.SingleColumnRestriction; import org.apache.cassandra.cql3.restrictions.SingleRestriction; +import org.apache.cassandra.cql3.statements.SelectStatement; import org.apache.cassandra.schema.ColumnMetadata; import org.apache.cassandra.schema.TableMetadata; @@ -48,6 +49,11 @@ public interface Expression SingleRestriction toRestriction(); ColumnMetadata getColumn(); + + default boolean isScored() + { + return false; + } } /** @@ -118,6 +124,54 @@ public ColumnMetadata getColumn() { return column; } + + @Override + public boolean isScored() + { + return SelectStatement.ANN_USE_SYNTHETIC_SCORE; + } + } + + /** + * An expression used in BM25 ordering. + * ORDER BY column BM25 OF value + */ + public static class Bm25 implements Expression + { + final ColumnMetadata column; + final Term queryValue; + final Direction direction; + + public Bm25(ColumnMetadata column, Term queryValue, Direction direction) + { + this.column = column; + this.queryValue = queryValue; + this.direction = direction; + } + + @Override + public boolean hasNonClusteredOrdering() + { + return true; + } + + @Override + public SingleRestriction toRestriction() + { + return new SingleColumnRestriction.Bm25Restriction(column, queryValue); + } + + @Override + public ColumnMetadata getColumn() + { + return column; + } + + @Override + public boolean isScored() + { + return true; + } } public enum Direction @@ -190,6 +244,27 @@ public Ordering.Expression bind(TableMetadata table, VariableSpecifications boun return new Ordering.Ann(column, value, direction); } } + + public static class Bm25 implements Expression + { + final ColumnIdentifier columnId; + final Term.Raw queryValue; + + Bm25(ColumnIdentifier column, Term.Raw queryValue) + { + this.columnId = column; + this.queryValue = queryValue; + } + + @Override + public Ordering.Expression bind(TableMetadata table, VariableSpecifications boundNames, Direction direction) + { + ColumnMetadata column = table.getExistingColumn(columnId); + Term value = queryValue.prepare(table.keyspace, column); + value.collectMarkerSpecification(boundNames); + return new Ordering.Bm25(column, value, direction); + } + } } } diff --git a/src/java/org/apache/cassandra/cql3/Relation.java b/src/java/org/apache/cassandra/cql3/Relation.java index 5cca2d257323..42cf3c9c8287 100644 --- a/src/java/org/apache/cassandra/cql3/Relation.java +++ b/src/java/org/apache/cassandra/cql3/Relation.java @@ -202,6 +202,8 @@ public final Restriction toRestriction(TableMetadata table, 
VariableSpecificatio return newLikeRestriction(table, boundNames, relationType); case ANN: return newAnnRestriction(table, boundNames); + case BM25: + return newBm25Restriction(table, boundNames); case ANALYZER_MATCHES: return newAnalyzerMatchesRestriction(table, boundNames); default: throw invalidRequest("Unsupported \"!=\" relation: %s", this); @@ -296,6 +298,11 @@ protected abstract Restriction newSliceRestriction(TableMetadata table, */ protected abstract Restriction newAnnRestriction(TableMetadata table, VariableSpecifications boundNames); + /** + * Creates a new BM25 restriction instance. + */ + protected abstract Restriction newBm25Restriction(TableMetadata table, VariableSpecifications boundNames); + /** * Creates a new Analyzer Matches restriction instance. */ diff --git a/src/java/org/apache/cassandra/cql3/SingleColumnRelation.java b/src/java/org/apache/cassandra/cql3/SingleColumnRelation.java index ec66ad70b529..2b97b99c3562 100644 --- a/src/java/org/apache/cassandra/cql3/SingleColumnRelation.java +++ b/src/java/org/apache/cassandra/cql3/SingleColumnRelation.java @@ -21,8 +21,11 @@ import java.util.ArrayList; import java.util.List; import java.util.Objects; +import java.util.stream.Collectors; import org.apache.cassandra.db.marshal.VectorType; +import org.apache.cassandra.index.IndexRegistry; +import org.apache.cassandra.index.sai.analyzer.AnalyzerEqOperatorSupport; import org.apache.cassandra.schema.ColumnMetadata; import org.apache.cassandra.schema.TableMetadata; import org.apache.cassandra.cql3.Term.Raw; @@ -33,6 +36,7 @@ import org.apache.cassandra.db.marshal.ListType; import org.apache.cassandra.db.marshal.MapType; import org.apache.cassandra.exceptions.InvalidRequestException; +import org.apache.cassandra.service.ClientWarn; import static org.apache.cassandra.cql3.statements.RequestValidations.checkFalse; import static org.apache.cassandra.cql3.statements.RequestValidations.checkTrue; @@ -191,7 +195,29 @@ protected Restriction newEQRestriction(TableMetadata table, VariableSpecificatio if (mapKey == null) { Term term = toTerm(toReceivers(columnDef), value, table.keyspace, boundNames); - return new SingleColumnRestriction.EQRestriction(columnDef, term); + // Leave the restriction as EQ if no analyzed index in backwards compatibility mode is present + var ebi = IndexRegistry.obtain(table).getEqBehavior(columnDef); + if (ebi.behavior == IndexRegistry.EqBehavior.EQ) + return new SingleColumnRestriction.EQRestriction(columnDef, term); + + // the index is configured to transform EQ into MATCH for backwards compatibility + var matchIndexName = ebi.matchIndex.getIndexMetadata() == null ? "Unknown" : ebi.matchIndex.getIndexMetadata().name; + if (ebi.behavior == IndexRegistry.EqBehavior.MATCH) + { + ClientWarn.instance.warn(String.format(AnalyzerEqOperatorSupport.EQ_RESTRICTION_ON_ANALYZED_WARNING, + columnDef.toString(), + matchIndexName), + columnDef); + return new SingleColumnRestriction.AnalyzerMatchesRestriction(columnDef, term); + } + + // multiple indexes support EQ, this is unsupported + assert ebi.behavior == IndexRegistry.EqBehavior.AMBIGUOUS; + var eqIndexName = ebi.eqIndex.getIndexMetadata() == null ? 
"Unknown" : ebi.eqIndex.getIndexMetadata().name; + throw invalidRequest(AnalyzerEqOperatorSupport.EQ_AMBIGUOUS_ERROR, + columnDef.toString(), + matchIndexName, + eqIndexName); } List receivers = toReceivers(columnDef); Term entryKey = toTerm(Collections.singletonList(receivers.get(0)), mapKey, table.keyspace, boundNames); @@ -333,6 +359,14 @@ protected Restriction newAnnRestriction(TableMetadata table, VariableSpecificati return new SingleColumnRestriction.AnnRestriction(columnDef, term); } + @Override + protected Restriction newBm25Restriction(TableMetadata table, VariableSpecifications boundNames) + { + ColumnMetadata columnDef = table.getExistingColumn(entity); + Term term = toTerm(toReceivers(columnDef), value, table.keyspace, boundNames); + return new SingleColumnRestriction.Bm25Restriction(columnDef, term); + } + @Override protected Restriction newAnalyzerMatchesRestriction(TableMetadata table, VariableSpecifications boundNames) { diff --git a/src/java/org/apache/cassandra/cql3/TokenRelation.java b/src/java/org/apache/cassandra/cql3/TokenRelation.java index a3ca586eee76..ca849dc82a30 100644 --- a/src/java/org/apache/cassandra/cql3/TokenRelation.java +++ b/src/java/org/apache/cassandra/cql3/TokenRelation.java @@ -138,7 +138,13 @@ protected Restriction newLikeRestriction(TableMetadata table, VariableSpecificat @Override protected Restriction newAnnRestriction(TableMetadata table, VariableSpecifications boundNames) { - throw invalidRequest("%s cannot be used for toekn relations", operator()); + throw invalidRequest("%s cannot be used for token relations", operator()); + } + + @Override + protected Restriction newBm25Restriction(TableMetadata table, VariableSpecifications boundNames) + { + throw invalidRequest("%s cannot be used for token relations", operator()); } @Override diff --git a/src/java/org/apache/cassandra/cql3/restrictions/ClusteringColumnRestrictions.java b/src/java/org/apache/cassandra/cql3/restrictions/ClusteringColumnRestrictions.java index 8ccd7fcaf37a..b2c08a065045 100644 --- a/src/java/org/apache/cassandra/cql3/restrictions/ClusteringColumnRestrictions.java +++ b/src/java/org/apache/cassandra/cql3/restrictions/ClusteringColumnRestrictions.java @@ -191,7 +191,7 @@ public ClusteringColumnRestrictions.Builder addRestriction(Restriction restricti SingleRestriction lastRestriction = restrictions.lastRestriction(); ColumnMetadata lastRestrictionStart = lastRestriction.getFirstColumn(); ColumnMetadata newRestrictionStart = newRestriction.getFirstColumn(); - restrictions.addRestriction(newRestriction, isDisjunction, indexRegistry); + restrictions.addRestriction(newRestriction, isDisjunction); checkFalse(lastRestriction.isSlice() && newRestrictionStart.position() > lastRestrictionStart.position(), "Clustering column \"%s\" cannot be restricted (preceding column \"%s\" is restricted by a non-EQ relation)", @@ -205,7 +205,7 @@ public ClusteringColumnRestrictions.Builder addRestriction(Restriction restricti } else { - restrictions.addRestriction(newRestriction, isDisjunction, indexRegistry); + restrictions.addRestriction(newRestriction, isDisjunction); } return this; diff --git a/src/java/org/apache/cassandra/cql3/restrictions/PartitionKeySingleRestrictionSet.java b/src/java/org/apache/cassandra/cql3/restrictions/PartitionKeySingleRestrictionSet.java index 2536f94f482d..ccb144b4b174 100644 --- a/src/java/org/apache/cassandra/cql3/restrictions/PartitionKeySingleRestrictionSet.java +++ b/src/java/org/apache/cassandra/cql3/restrictions/PartitionKeySingleRestrictionSet.java @@ -191,7 
+191,7 @@ public PartitionKeyRestrictions build(IndexRegistry indexRegistry, boolean isDis if (restriction.isOnToken()) return buildWithTokens(restrictionSet, i, indexRegistry); - restrictionSet.addRestriction((SingleRestriction) restriction, isDisjunction, indexRegistry); + restrictionSet.addRestriction((SingleRestriction) restriction, isDisjunction); } return buildPartitionKeyRestrictions(restrictionSet); diff --git a/src/java/org/apache/cassandra/cql3/restrictions/RestrictionSet.java b/src/java/org/apache/cassandra/cql3/restrictions/RestrictionSet.java index 4cde1c434f33..dacb018e6186 100644 --- a/src/java/org/apache/cassandra/cql3/restrictions/RestrictionSet.java +++ b/src/java/org/apache/cassandra/cql3/restrictions/RestrictionSet.java @@ -398,45 +398,36 @@ private Builder() { } - public void addRestriction(SingleRestriction restriction, boolean isDisjunction, IndexRegistry indexRegistry) + public void addRestriction(SingleRestriction restriction, boolean isDisjunction) { List columnDefs = restriction.getColumnDefs(); if (isDisjunction) { // If this restriction is part of a disjunction query then we don't want - // to merge the restrictions (if that is possible), we just add the - // restriction to the set of restrictions for the column. + // to merge the restrictions, we just add the new restriction addRestrictionForColumns(columnDefs, restriction, null); } else { - // In some special cases such as EQ in analyzed index we need to skip merging the restriction, - // so we can send multiple EQ restrictions to the index. - if (restriction.skipMerge(indexRegistry)) - { - addRestrictionForColumns(columnDefs, restriction, null); - return; - } - - // If this restriction isn't part of a disjunction then we need to get - // the set of existing restrictions for the column and merge them with the - // new restriction + // ANDed together restrictions against the same columns should be merged. Set existingRestrictions = getRestrictions(newRestrictions, columnDefs); - SingleRestriction merged = restriction; - Set replacedRestrictions = new HashSet<>(); - - for (SingleRestriction existing : existingRestrictions) + // merge the new restriction into an existing one. note that there is only ever a single + // restriction (per column), UNLESS one is ORDER BY BM25 and the other is MATCH. 
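To make the merge rule described in the comment above concrete: the loop that follows folds an incoming restriction into an existing one via mergeWith unless shouldMerge vetoes the merge, in which case both restrictions are kept side by side on the same column. A minimal standalone sketch of that flow, using a hypothetical Restriction interface rather than Cassandra's real types:

import java.util.ArrayList;
import java.util.List;

class AndMergeSketch
{
    interface Restriction
    {
        // false keeps the two restrictions separate (the BM25-vs-MATCH case)
        default boolean shouldMerge(Restriction other) { return true; }
        // may reject illegal combinations by throwing, mirroring doMergeWith
        Restriction mergeWith(Restriction other);
    }

    static List<Restriction> add(List<Restriction> existing, Restriction incoming)
    {
        List<Restriction> result = new ArrayList<>(existing);
        for (Restriction r : existing)
        {
            if (r.shouldMerge(incoming))
            {
                // merge into the existing restriction and replace it
                result.remove(r);
                result.add(r.mergeWith(incoming));
                return result;
            }
        }
        // nothing mergeable: keep both restrictions on the column
        result.add(incoming);
        return result;
    }
}
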
+ for (var existing : existingRestrictions) { - if (!existing.skipMerge(indexRegistry)) + // shouldMerge exists for the BM25/MATCH case + if (existing.shouldMerge(restriction)) { - merged = existing.mergeWith(merged); - replacedRestrictions.add(existing); + var merged = existing.mergeWith(restriction); + addRestrictionForColumns(merged.getColumnDefs(), merged, Set.of(existing)); + return; } } - addRestrictionForColumns(merged.getColumnDefs(), merged, replacedRestrictions); + // no existing restrictions that we should merge the new one with, add a new one + addRestrictionForColumns(columnDefs, restriction, null); } } diff --git a/src/java/org/apache/cassandra/cql3/restrictions/SingleColumnRestriction.java b/src/java/org/apache/cassandra/cql3/restrictions/SingleColumnRestriction.java index 46c2afcc1d8f..cd44a6f09000 100644 --- a/src/java/org/apache/cassandra/cql3/restrictions/SingleColumnRestriction.java +++ b/src/java/org/apache/cassandra/cql3/restrictions/SingleColumnRestriction.java @@ -24,15 +24,19 @@ import java.util.List; import java.util.Map; -import org.apache.cassandra.db.filter.RowFilter; -import org.apache.cassandra.db.filter.ANNOptions; -import org.apache.cassandra.schema.ColumnMetadata; -import org.apache.cassandra.cql3.*; +import org.apache.cassandra.cql3.MarkerOrTerms; +import org.apache.cassandra.cql3.Operator; +import org.apache.cassandra.cql3.QueryOptions; +import org.apache.cassandra.cql3.Term; +import org.apache.cassandra.cql3.Terms; import org.apache.cassandra.cql3.functions.Function; import org.apache.cassandra.cql3.statements.Bound; import org.apache.cassandra.db.MultiClusteringBuilder; +import org.apache.cassandra.db.filter.ANNOptions; +import org.apache.cassandra.db.filter.RowFilter; import org.apache.cassandra.index.Index; import org.apache.cassandra.index.IndexRegistry; +import org.apache.cassandra.schema.ColumnMetadata; import org.apache.cassandra.serializers.ListSerializer; import org.apache.cassandra.transport.ProtocolVersion; import org.apache.cassandra.utils.ByteBufferUtil; @@ -76,12 +80,17 @@ public ColumnMetadata getLastColumn() @Override public boolean hasSupportingIndex(IndexRegistry indexRegistry) + { + return findSupportingIndex(indexRegistry) != null; + } + + public Index findSupportingIndex(IndexRegistry indexRegistry) { for (Index index : indexRegistry.listIndexes()) if (isSupportedBy(index)) - return true; + return index; - return false; + return null; } @Override @@ -190,24 +199,6 @@ public String toString() return String.format("EQ(%s)", term); } - @Override - public boolean skipMerge(IndexRegistry indexRegistry) - { - // We should skip merging this EQ if there is an analyzed index for this column that supports EQ, - // so there can be multiple EQs for the same column. 
- - if (indexRegistry == null) - return false; - - for (Index index : indexRegistry.listIndexes()) - { - if (index.supportsExpression(columnDef, Operator.ANALYZER_MATCHES) && - index.supportsExpression(columnDef, Operator.EQ)) - return true; - } - return false; - } - @Override public SingleRestriction doMergeWith(SingleRestriction otherRestriction) { @@ -1188,6 +1179,86 @@ public boolean isBoundedAnn() } } + public static final class Bm25Restriction extends SingleColumnRestriction + { + private final Term value; + + public Bm25Restriction(ColumnMetadata columnDef, Term value) + { + super(columnDef); + this.value = value; + } + + public ByteBuffer value(QueryOptions options) + { + return value.bindAndGet(options); + } + + @Override + public void addFunctionsTo(List functions) + { + value.addFunctionsTo(functions); + } + + @Override + MultiColumnRestriction toMultiColumnRestriction() + { + throw new UnsupportedOperationException(); + } + + @Override + public void addToRowFilter(RowFilter.Builder filter, IndexRegistry indexRegistry, QueryOptions options, ANNOptions annOptions) + { + var index = findSupportingIndex(indexRegistry); + var valueBytes = value.bindAndGet(options); + var terms = index.getQueryAnalyzer().get().analyze(valueBytes); + if (terms.isEmpty()) + throw invalidRequest("BM25 query must contain at least one term (perhaps your analyzer is discarding tokens you didn't expect)"); + filter.add(columnDef, Operator.BM25, valueBytes); + } + + @Override + public MultiClusteringBuilder appendTo(MultiClusteringBuilder builder, QueryOptions options) + { + throw new UnsupportedOperationException(); + } + + @Override + public String toString() + { + return String.format("BM25(%s)", value); + } + + @Override + public SingleRestriction doMergeWith(SingleRestriction otherRestriction) + { + throw invalidRequest("%s cannot be restricted by both BM25 and %s", columnDef.name, otherRestriction.toString()); + } + + @Override + protected boolean isSupportedBy(Index index) + { + return index.supportsExpression(columnDef, Operator.BM25); + } + + @Override + public boolean isIndexBasedOrdering() + { + return true; + } + + @Override + public boolean shouldMerge(SingleRestriction other) + { + // we don't want to merge MATCH restrictions with ORDER BY BM25 + // so shouldMerge = false for that scenario, and true for others + // (because even though we can't meaningfully merge with others, we want doMergeWith to be called to throw) + // + // (Note that because ORDER BY is processed before WHERE, we only need this check in the BM25 class) + return !other.isAnalyzerMatches(); + } + } + /** * A Bounded ANN Restriction is one that uses a similarity score as the limiting factor for ANN instead of a number * of results. @@ -1335,10 +1406,12 @@ public String toString() @Override public SingleRestriction doMergeWith(SingleRestriction otherRestriction) { - if (!(otherRestriction.isAnalyzerMatches())) + if (!otherRestriction.isAnalyzerMatches()) throw invalidRequest(CANNOT_BE_MERGED_ERROR, columnDef.name); - List otherValues = ((AnalyzerMatchesRestriction) otherRestriction).getValues(); + List otherValues = otherRestriction instanceof AnalyzerMatchesRestriction + ? 
((AnalyzerMatchesRestriction) otherRestriction).getValues() + : List.of(((EQRestriction) otherRestriction).term); List newValues = new ArrayList<>(values.size() + otherValues.size()); newValues.addAll(values); newValues.addAll(otherValues); diff --git a/src/java/org/apache/cassandra/cql3/restrictions/SingleRestriction.java b/src/java/org/apache/cassandra/cql3/restrictions/SingleRestriction.java index 595451f812de..bdd80badc0ae 100644 --- a/src/java/org/apache/cassandra/cql3/restrictions/SingleRestriction.java +++ b/src/java/org/apache/cassandra/cql3/restrictions/SingleRestriction.java @@ -20,7 +20,6 @@ import org.apache.cassandra.cql3.QueryOptions; import org.apache.cassandra.cql3.statements.Bound; import org.apache.cassandra.db.MultiClusteringBuilder; -import org.apache.cassandra.index.IndexRegistry; /** * A single restriction/clause on one or multiple column. @@ -97,17 +96,6 @@ public default boolean isInclusive(Bound b) return true; } - /** - * Checks if this restriction shouldn't be merged with other restrictions. - * - * @param indexRegistry the index registry - * @return {@code true} if this shouldn't be merged with other restrictions - */ - default boolean skipMerge(IndexRegistry indexRegistry) - { - return false; - } - /** * Merges this restriction with the specified one. * @@ -141,4 +129,16 @@ public default MultiClusteringBuilder appendBoundTo(MultiClusteringBuilder build { return appendTo(builder, options); } + + /** + * @return true if the other restriction should be merged with this one. + * This is NOT for preventing illegal combinations of restrictions, e.g. + * a=1 AND a=2; that is handled by mergeWith. Instead, this is for the case + * where we want two completely different semantics against the same column. + * Currently the only such case is BM25 with MATCH. 
+ */ + default boolean shouldMerge(SingleRestriction other) + { + return true; + } } diff --git a/src/java/org/apache/cassandra/cql3/restrictions/StatementRestrictions.java b/src/java/org/apache/cassandra/cql3/restrictions/StatementRestrictions.java index b20d658867e8..ac49d7d6f694 100644 --- a/src/java/org/apache/cassandra/cql3/restrictions/StatementRestrictions.java +++ b/src/java/org/apache/cassandra/cql3/restrictions/StatementRestrictions.java @@ -83,6 +83,7 @@ public class StatementRestrictions "Restriction on partition key column %s must not be nested under OR operator"; public static final String GEO_DISTANCE_REQUIRES_INDEX_MESSAGE = "GEO_DISTANCE requires the vector column to be indexed"; + public static final String BM25_ORDERING_REQUIRES_ANALYZED_INDEX_MESSAGE = "BM25 ordering on column %s requires an analyzed index"; public static final String NON_CLUSTER_ORDERING_REQUIRES_INDEX_MESSAGE = "Ordering on non-clustering column %s requires the column to be indexed."; public static final String NON_CLUSTER_ORDERING_REQUIRES_ALL_RESTRICTED_NON_PARTITION_KEY_COLUMNS_INDEXED_MESSAGE = "Ordering on non-clustering column requires each restricted column to be indexed except for fully-specified partition keys"; @@ -445,7 +446,7 @@ else if (def.isClusteringColumn() && nestingLevel == 0) } else { - nonPrimaryKeyRestrictionSet.addRestriction((SingleRestriction) restriction, element.isDisjunction(), indexRegistry); + nonPrimaryKeyRestrictionSet.addRestriction((SingleRestriction) restriction, element.isDisjunction()); } } } @@ -685,7 +686,8 @@ else if (indexOrderings.size() == 1) if (orderings.size() > 1) throw new InvalidRequestException("Cannot combine clustering column ordering with non-clustering column ordering"); Ordering ordering = indexOrderings.get(0); - if (ordering.direction != Ordering.Direction.ASC && ordering.expression instanceof Ordering.Ann) + // TODO remove the instanceof with SelectStatement.ANN_USE_SYNTHETIC_SCORE. 
+ if (ordering.direction != Ordering.Direction.ASC && (ordering.expression.isScored() || ordering.expression instanceof Ordering.Ann)) throw new InvalidRequestException("Descending ANN ordering is not supported"); if (!ENABLE_SAI_GENERAL_ORDER_BY && ordering.expression instanceof Ordering.SingleColumn) throw new InvalidRequestException("SAI based ORDER BY on non-vector column is not supported"); @@ -698,10 +700,14 @@ else if (indexOrderings.size() == 1) throw new InvalidRequestException(String.format("SAI based ordering on column %s of type %s is not supported", restriction.getFirstColumn(), restriction.getFirstColumn().type.asCQL3Type())); - throw new InvalidRequestException(String.format(NON_CLUSTER_ORDERING_REQUIRES_INDEX_MESSAGE, - restriction.getFirstColumn())); + if (ordering.expression instanceof Ordering.Bm25) + throw new InvalidRequestException(String.format(BM25_ORDERING_REQUIRES_ANALYZED_INDEX_MESSAGE, + restriction.getFirstColumn())); + else + throw new InvalidRequestException(String.format(NON_CLUSTER_ORDERING_REQUIRES_INDEX_MESSAGE, + restriction.getFirstColumn())); } - receiver.addRestriction(restriction, false, indexRegistry); + receiver.addRestriction(restriction, false); } } diff --git a/src/java/org/apache/cassandra/cql3/selection/ColumnFilterFactory.java b/src/java/org/apache/cassandra/cql3/selection/ColumnFilterFactory.java index 63fa0520101e..00225cca4108 100644 --- a/src/java/org/apache/cassandra/cql3/selection/ColumnFilterFactory.java +++ b/src/java/org/apache/cassandra/cql3/selection/ColumnFilterFactory.java @@ -38,9 +38,21 @@ abstract class ColumnFilterFactory */ abstract ColumnFilter newInstance(List selectors); - public static ColumnFilterFactory wildcard(TableMetadata table) + public static ColumnFilterFactory wildcard(TableMetadata table, Set orderingColumns) { - return new PrecomputedColumnFilter(ColumnFilter.all(table)); + ColumnFilter cf; + if (orderingColumns.isEmpty()) + { + cf = ColumnFilter.all(table); + } + else + { + ColumnFilter.Builder builder = ColumnFilter.selectionBuilder(); + builder.addAll(table.regularAndStaticColumns()); + builder.addAll(orderingColumns); + cf = builder.build(); + } + return new PrecomputedColumnFilter(cf); } public static ColumnFilterFactory fromColumns(TableMetadata table, diff --git a/src/java/org/apache/cassandra/cql3/selection/Selection.java b/src/java/org/apache/cassandra/cql3/selection/Selection.java index 02aae61dd5ff..12d8aa014e19 100644 --- a/src/java/org/apache/cassandra/cql3/selection/Selection.java +++ b/src/java/org/apache/cassandra/cql3/selection/Selection.java @@ -43,10 +43,24 @@ public abstract class Selection private static final Predicate STATIC_COLUMN_FILTER = (column) -> column.isStatic(); private final TableMetadata table; + + // Full list of columns needed for processing the query, including selected columns, ordering columns, + // and columns needed for restrictions. Wildcard columns are fully materialized here. + // + // This also includes synthetic columns, because unlike all the other not-physical-columns selectables, they are + // computed on the replica instead of the coordinator and so, like physical columns, they need to be sent back + // as part of the result. 
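The only synthetic column materialized so far is the +score column used for ANN/BM25 ordering. A condensed sketch of how it is derived from the ordered column, reusing the ColumnMetadata.syntheticColumn and SYNTHETIC_SCORE_ID helpers this patch introduces elsewhere (the wrapper class below is hypothetical, shown only to make the call sequence explicit):

import org.apache.cassandra.db.marshal.FloatType;
import org.apache.cassandra.schema.ColumnMetadata;

final class ScoreColumnSketch
{
    // one synthetic float-typed +score column, keyed off the ordered column's table
    static ColumnMetadata scoreColumnFor(ColumnMetadata ordered)
    {
        return ColumnMetadata.syntheticColumn(ordered.ksName,
                                              ordered.cfName,
                                              ColumnMetadata.SYNTHETIC_SCORE_ID,
                                              FloatType.instance);
    }
}
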
private final List columns; + + // maps ColumnSpecifications (columns, function calls, aliases) to the columns backing them private final SelectionColumnMapping columnMapping; + + // metadata matching the ColumnSpecifications protected final ResultSet.ResultMetadata metadata; + + // creates a ColumnFilter that breaks columns into `queried` and `fetched` protected final ColumnFilterFactory columnFilterFactory; + protected final boolean isJson; // Columns used to order the result set for JSON queries with post ordering. @@ -126,10 +140,15 @@ public ResultSet.ResultMetadata getResultMetadata() } public static Selection wildcard(TableMetadata table, boolean isJson, boolean returnStaticContentOnPartitionWithNoRows) + { + return wildcard(table, Collections.emptySet(), isJson, returnStaticContentOnPartitionWithNoRows); + } + + public static Selection wildcard(TableMetadata table, Set orderingColumns, boolean isJson, boolean returnStaticContentOnPartitionWithNoRows) { List all = new ArrayList<>(table.columns().size()); Iterators.addAll(all, table.allColumnsInSelectOrder()); - return new SimpleSelection(table, all, Collections.emptySet(), true, isJson, returnStaticContentOnPartitionWithNoRows); + return new SimpleSelection(table, all, orderingColumns, true, isJson, returnStaticContentOnPartitionWithNoRows); } public static Selection wildcardWithGroupBy(TableMetadata table, @@ -400,7 +419,7 @@ public SimpleSelection(TableMetadata table, selectedColumns, orderingColumns, SelectionColumnMapping.simpleMapping(selectedColumns), - isWildcard ? ColumnFilterFactory.wildcard(table) + isWildcard ? ColumnFilterFactory.wildcard(table, orderingColumns) : ColumnFilterFactory.fromColumns(table, selectedColumns, orderingColumns, Collections.emptySet(), returnStaticContentOnPartitionWithNoRows), isWildcard, isJson); diff --git a/src/java/org/apache/cassandra/cql3/statements/SelectStatement.java index 58b75aacab5a..63fd7abd2b70 100644 --- a/src/java/org/apache/cassandra/cql3/statements/SelectStatement.java +++ b/src/java/org/apache/cassandra/cql3/statements/SelectStatement.java @@ -40,6 +40,7 @@ import org.apache.cassandra.cql3.restrictions.ExternalRestriction; import org.apache.cassandra.cql3.restrictions.Restrictions; import org.apache.cassandra.cql3.selection.SortedRowsBuilder; +import org.apache.cassandra.db.marshal.FloatType; import org.apache.cassandra.guardrails.Guardrails; import org.apache.cassandra.sensors.SensorsCustomParams; import org.apache.cassandra.schema.ColumnMetadata; @@ -105,6 +106,12 @@ */ public class SelectStatement implements CQLStatement.SingleKeyspaceCqlStatement { + // TODO remove this when we no longer need to downgrade to replicas that don't know about synthetic columns, + // and the related code in + // - StatementRestrictions.addOrderingRestrictions + // - StorageAttachedIndexSearcher.PrimaryKeyIterator constructor + public static final boolean ANN_USE_SYNTHETIC_SCORE = Boolean.parseBoolean(System.getProperty("cassandra.sai.ann_use_synthetic_score", "false")); + private static final Logger logger = LoggerFactory.getLogger(SelectStatement.class); private static final NoSpamLogger noSpamLogger = NoSpamLogger.getLogger(SelectStatement.logger, 1, TimeUnit.MINUTES); public static final String TOPK_CONSISTENCY_LEVEL_ERROR = "Top-K queries can only be run with consistency level ONE/LOCAL_ONE. 
Consistency level %s was used."; @@ -1100,12 +1107,16 @@ void processPartition(RowIterator partition, QueryOptions options, ResultSetBuil case CLUSTERING: result.add(row.clustering().bufferAt(def.position())); break; + case SYNTHETIC: + // treat as REGULAR case REGULAR: result.add(row.getColumnData(def), nowInSec); break; case STATIC: result.add(staticRow.getColumnData(def), nowInSec); break; + default: + throw new AssertionError(); } } } @@ -1191,6 +1202,9 @@ public SelectStatement prepare(boolean forView, UnaryOperator keyspaceMa List selectables = RawSelector.toSelectables(selectClause, table); boolean containsOnlyStaticColumns = selectOnlyStaticColumns(table, selectables); + // Besides actual restrictions (where clauses), prepareRestrictions will include pseudo-restrictions + // on indexed columns to allow pushing ORDER BY into the index; see StatementRestrictions::addOrderingRestrictions. + // Therefore, we don't want to convert an ANN Ordering column into a +score column until after that. List orderings = getOrderings(table); StatementRestrictions restrictions = prepareRestrictions( table, bindVariables, orderings, containsOnlyStaticColumns, forView); @@ -1198,6 +1212,11 @@ public SelectStatement prepare(boolean forView, UnaryOperator keyspaceMa // If we order post-query, the sorted column needs to be in the ResultSet for sorting, // even if we don't ultimately ship them to the client (CASSANDRA-4911). Map orderingColumns = getOrderingColumns(orderings); + // +score column for ANN/BM25 + var scoreOrdering = getScoreOrdering(orderings); + assert scoreOrdering == null || orderingColumns.isEmpty() : "can't have both scored ordering and column ordering"; + if (scoreOrdering != null) + orderingColumns = scoreOrdering; Set resultSetOrderingColumns = getResultSetOrdering(restrictions, orderingColumns); Selection selection = prepareSelection(table, @@ -1226,9 +1245,9 @@ public SelectStatement prepare(boolean forView, UnaryOperator keyspaceMa if (!orderingColumns.isEmpty()) { assert !forView; - verifyOrderingIsAllowed(restrictions, orderingColumns); + verifyOrderingIsAllowed(table, restrictions, orderingColumns); orderingComparator = getOrderingComparator(selection, restrictions, orderingColumns); - isReversed = isReversed(table, orderingColumns, restrictions); + isReversed = isReversed(table, orderingColumns); if (isReversed && orderingComparator != null) orderingComparator = orderingComparator.reverse(); } @@ -1252,6 +1271,21 @@ public SelectStatement prepare(boolean forView, UnaryOperator keyspaceMa options); } + private Map getScoreOrdering(List orderings) + { + if (orderings.isEmpty()) + return null; + + var expr = orderings.get(0).expression; + if (!expr.isScored()) + return null; + + // Create synthetic score column + ColumnMetadata sourceColumn = expr.getColumn(); + var cm = ColumnMetadata.syntheticColumn(sourceColumn.ksName, sourceColumn.cfName, ColumnMetadata.SYNTHETIC_SCORE_ID, FloatType.instance); + return Map.of(cm, orderings.get(0)); + } + private Set getResultSetOrdering(StatementRestrictions restrictions, Map orderingColumns) { if (restrictions.keyIsInRelation() || orderingColumns.values().stream().anyMatch(o -> o.expression.hasNonClusteredOrdering())) @@ -1269,8 +1303,9 @@ private Selection prepareSelection(TableMetadata table, if (selectables.isEmpty()) // wildcard query { - return hasGroupBy ? 
Selection.wildcardWithGroupBy(table, boundNames, parameters.isJson, restrictions.returnStaticContentOnPartitionWithNoRows()) - : Selection.wildcard(table, parameters.isJson, restrictions.returnStaticContentOnPartitionWithNoRows()); + return hasGroupBy + ? Selection.wildcardWithGroupBy(table, boundNames, parameters.isJson, restrictions.returnStaticContentOnPartitionWithNoRows()) + : Selection.wildcard(table, parameters.isJson, restrictions.returnStaticContentOnPartitionWithNoRows()); } return Selection.fromSelectors(table, @@ -1312,13 +1347,14 @@ private Map getOrderingColumns(List ordering if (orderings.isEmpty()) return Collections.emptyMap(); - Map orderingColumns = new LinkedHashMap<>(); - for (Ordering ordering : orderings) - { - ColumnMetadata column = ordering.expression.getColumn(); - orderingColumns.put(column, ordering); - } - return orderingColumns; + return orderings.stream() + .filter(ordering -> !ordering.expression.isScored()) + .collect(Collectors.toMap(ordering -> ordering.expression.getColumn(), + ordering -> ordering, + (a, b) -> { + throw new IllegalStateException("Duplicate keys"); + }, + LinkedHashMap::new)); } private List getOrderings(TableMetadata table) @@ -1365,12 +1401,28 @@ private Term prepareLimit(VariableSpecifications boundNames, Term.Raw limit, return prepLimit; } - private static void verifyOrderingIsAllowed(StatementRestrictions restrictions, Map orderingColumns) throws InvalidRequestException + private static void verifyOrderingIsAllowed(TableMetadata table, StatementRestrictions restrictions, Map orderingColumns) throws InvalidRequestException { if (orderingColumns.values().stream().anyMatch(o -> o.expression.hasNonClusteredOrdering())) return; + checkFalse(restrictions.usesSecondaryIndexing(), "ORDER BY with 2ndary indexes is not supported."); checkFalse(restrictions.isKeyRange(), "ORDER BY is only supported when the partition key is restricted by an EQ or an IN."); + + // check that clustering columns are valid + int i = 0; + for (var entry : orderingColumns.entrySet()) + { + ColumnMetadata def = entry.getKey(); + checkTrue(def.isClusteringColumn(), + "Order by is currently only supported on indexed columns and the clustered columns of the PRIMARY KEY, got %s", def.name); + while (i != def.position()) + { + checkTrue(restrictions.isColumnRestrictedByEq(table.clusteringColumns().get(i++)), + "Ordering by clustered columns must follow the declared order in the PRIMARY KEY"); + } + i++; + } } private static void validateDistinctSelection(TableMetadata metadata, @@ -1485,35 +1537,30 @@ private ColumnComparator> getOrderingComparator(Selection selec : new CompositeComparator(sorters, idToSort); } - private boolean isReversed(TableMetadata table, Map orderingColumns, StatementRestrictions restrictions) throws InvalidRequestException + private boolean isReversed(TableMetadata table, Map orderingColumns) throws InvalidRequestException { - // Nonclustered ordering handles descending logic in a different way + // Nonclustered ordering handles descending logic through ScoreOrderedResultRetriever and TKP if (orderingColumns.values().stream().anyMatch(o -> o.expression.hasNonClusteredOrdering())) return false; - Boolean[] reversedMap = new Boolean[table.clusteringColumns().size()]; - int i = 0; + Boolean[] clusteredMap = new Boolean[table.clusteringColumns().size()]; for (var entry : orderingColumns.entrySet()) { ColumnMetadata def = entry.getKey(); Ordering ordering = entry.getValue(); - boolean reversed = ordering.direction == Ordering.Direction.DESC; - - // 
VSTODO move this to verifyOrderingIsAllowed? - checkTrue(def.isClusteringColumn(), - "Order by is currently only supported on the clustered columns of the PRIMARY KEY, got %s", def.name); - while (i != def.position()) - { - checkTrue(restrictions.isColumnRestrictedByEq(table.clusteringColumns().get(i++)), - "Order by currently only supports the ordering of columns following their declared order in the PRIMARY KEY"); - } - i++; - reversedMap[def.position()] = (reversed != def.isReversedType()); + // We defined ANN OF to be ASC ordering, as in, "order by near-ness". But since score goes from + // 0 (worst) to 1 (closest), we need to reverse the ordering for the comparator when we're sorting + // by synthetic +score column. + boolean cqlReversed = ordering.direction == Ordering.Direction.DESC; + if (def.position() == ColumnMetadata.NO_POSITION) + return ordering.expression.isScored() || cqlReversed; + else + clusteredMap[def.position()] = (cqlReversed != def.isReversedType()); } - // Check that all boolean in reversedMap, if set, agrees + // Check that all boolean in clusteredMap, if set, agrees Boolean isReversed = null; - for (Boolean b : reversedMap) + for (Boolean b : clusteredMap) { // Column on which order is specified can be in any order if (b == null) @@ -1658,7 +1705,14 @@ public int compare(T o1, T o2) { return wrapped.compare(o2, o1); } + + @Override + public boolean indexOrdering() + { + return wrapped.indexOrdering(); + } } + /** * Used in orderResults(...) method when single 'ORDER BY' condition where given */ diff --git a/src/java/org/apache/cassandra/db/Columns.java b/src/java/org/apache/cassandra/db/Columns.java index 9a904de12253..45b3da97a596 100644 --- a/src/java/org/apache/cassandra/db/Columns.java +++ b/src/java/org/apache/cassandra/db/Columns.java @@ -28,6 +28,7 @@ import net.nicoulaj.compilecommand.annotations.DontInline; import org.apache.cassandra.cql3.ColumnIdentifier; +import org.apache.cassandra.db.marshal.AbstractType; import org.apache.cassandra.db.marshal.SetType; import org.apache.cassandra.db.marshal.UTF8Type; import org.apache.cassandra.db.rows.ColumnData; @@ -36,6 +37,7 @@ import org.apache.cassandra.io.util.DataOutputPlus; import org.apache.cassandra.schema.ColumnMetadata; import org.apache.cassandra.schema.TableMetadata; +import org.apache.cassandra.serializers.AbstractTypeSerializer; import org.apache.cassandra.utils.ByteBufferUtil; import org.apache.cassandra.utils.ObjectSizes; import org.apache.cassandra.utils.SearchIterator; @@ -459,37 +461,107 @@ public String toString() public static class Serializer { + AbstractTypeSerializer typeSerializer = new AbstractTypeSerializer(); + public void serialize(Columns columns, DataOutputPlus out) throws IOException { - out.writeUnsignedVInt(columns.size()); + int regularCount = 0; + int syntheticCount = 0; + + // Count regular and synthetic columns + for (ColumnMetadata column : columns) + { + if (column.isSynthetic()) + syntheticCount++; + else + regularCount++; + } + + // Jam the two counts into a single value to avoid massive backwards compatibility issues + long packedCount = getPackedCount(syntheticCount, regularCount); + out.writeUnsignedVInt(packedCount); + + // First pass - write synthetic columns with their full metadata for (ColumnMetadata column : columns) - ByteBufferUtil.writeWithVIntLength(column.name.bytes, out); + { + if (column.isSynthetic()) + { + ByteBufferUtil.writeWithVIntLength(column.name.bytes, out); + typeSerializer.serialize(column.type, out); + } + } + + // Second pass - write regular 
columns + for (ColumnMetadata column : columns) + { + if (!column.isSynthetic()) + ByteBufferUtil.writeWithVIntLength(column.name.bytes, out); + } + } + + private static long getPackedCount(int syntheticCount, int regularCount) + { + // Left shift of 20 gives us over 1M regular columns, and up to 4 synthetic columns + // before overflowing to a 4th byte. + return ((long) syntheticCount << 20) | regularCount; } public long serializedSize(Columns columns) { - long size = TypeSizes.sizeofUnsignedVInt(columns.size()); + int regularCount = 0; + int syntheticCount = 0; + long size = 0; + + // Count and calculate sizes for (ColumnMetadata column : columns) - size += ByteBufferUtil.serializedSizeWithVIntLength(column.name.bytes); - return size; + { + if (column.isSynthetic()) + { + syntheticCount++; + size += ByteBufferUtil.serializedSizeWithVIntLength(column.name.bytes); + size += typeSerializer.serializedSize(column.type); + } + else + { + regularCount++; + size += ByteBufferUtil.serializedSizeWithVIntLength(column.name.bytes); + } + } + + return TypeSizes.sizeofUnsignedVInt(getPackedCount(syntheticCount, regularCount)) + + size; } public Columns deserialize(DataInputPlus in, TableMetadata metadata) throws IOException { - int length = (int)in.readUnsignedVInt(); try (BTree.FastBuilder builder = BTree.fastBuilder()) { - for (int i = 0; i < length; i++) + long packedCount = in.readUnsignedVInt() ; + int regularCount = (int) (packedCount & 0xFFFFF); + int syntheticCount = (int) (packedCount >> 20); + + // First pass - synthetic columns + for (int i = 0; i < syntheticCount; i++) + { + ByteBuffer name = ByteBufferUtil.readWithVIntLength(in); + AbstractType type = typeSerializer.deserialize(in); + + if (!name.equals(ColumnMetadata.SYNTHETIC_SCORE_ID.bytes)) + throw new IllegalStateException("Unknown synthetic column " + UTF8Type.instance.getString(name)); + + ColumnMetadata column = ColumnMetadata.syntheticColumn(metadata.keyspace, metadata.name, ColumnMetadata.SYNTHETIC_SCORE_ID, type); + builder.add(column); + } + + // Second pass - regular columns + for (int i = 0; i < regularCount; i++) { ByteBuffer name = ByteBufferUtil.readWithVIntLength(in); ColumnMetadata column = metadata.getColumn(name); if (column == null) { - // If we don't find the definition, it could be we have data for a dropped column, and we shouldn't - // fail deserialization because of that. So we grab a "fake" ColumnMetadata that ensure proper - // deserialization. The column will be ignore later on anyway. + // If we don't find the definition, it could be we have data for a dropped column column = metadata.getDroppedColumn(name); - if (column == null) throw new RuntimeException("Unknown column " + UTF8Type.instance.getString(name) + " during deserialization of " + metadata.keyspace + '.' 
+ metadata.name); } diff --git a/src/java/org/apache/cassandra/db/ReadCommand.java b/src/java/org/apache/cassandra/db/ReadCommand.java index b110a3ada2de..7fa00f0436e5 100644 --- a/src/java/org/apache/cassandra/db/ReadCommand.java +++ b/src/java/org/apache/cassandra/db/ReadCommand.java @@ -74,6 +74,7 @@ import org.apache.cassandra.net.Message; import org.apache.cassandra.net.MessageFlag; import org.apache.cassandra.net.Verb; +import org.apache.cassandra.schema.ColumnMetadata; import org.apache.cassandra.schema.IndexMetadata; import org.apache.cassandra.schema.Schema; import org.apache.cassandra.schema.SchemaConstants; @@ -384,8 +385,6 @@ static Index.QueryPlan findIndexQueryPlan(TableMetadata table, RowFilter rowFilt @Override public void maybeValidateIndexes() { - IndexRegistry.obtain(metadata()).validate(rowFilter()); - if (null != indexQueryPlan) indexQueryPlan.validate(this); } @@ -415,9 +414,9 @@ public UnfilteredPartitionIterator executeLocally(ReadExecutionController execut } Context context = Context.from(this); - UnfilteredPartitionIterator iterator = (null == searcher) ? Transformation.apply(queryStorage(cfs, executionController), new TrackingRowIterator(context)) - : Transformation.apply(searchStorage(searcher, executionController), new TrackingRowIterator(context)); - + var storageTarget = (null == searcher) ? queryStorage(cfs, executionController) + : searchStorage(searcher, executionController); + UnfilteredPartitionIterator iterator = Transformation.apply(storageTarget, new TrackingRowIterator(context)); iterator = RTBoundValidator.validate(iterator, Stage.MERGED, false); try @@ -1054,6 +1053,19 @@ public ReadCommand deserialize(DataInputPlus in, int version) throws IOException TableMetadata metadata = schema.getExistingTableMetadata(TableId.deserialize(in)); int nowInSec = in.readInt(); ColumnFilter columnFilter = ColumnFilter.serializer.deserialize(in, version, metadata); + + // add synthetic columns to the tablemetadata so we can serialize them in our response + var tmb = metadata.unbuild(); + for (var it = columnFilter.fetchedColumns().regulars.simpleColumns(); it.hasNext(); ) + { + var c = it.next(); + // synthetic columns sort first, so when we hit the first non-synthetic, we're done + if (!c.isSynthetic()) + break; + tmb.addColumn(ColumnMetadata.syntheticColumn(c.ksName, c.cfName, c.name, c.type)); + } + metadata = tmb.build(); + RowFilter rowFilter = RowFilter.serializer.deserialize(in, version, metadata); DataLimits limits = DataLimits.serializer.deserialize(in, version, metadata.comparator); Index.QueryPlan indexQueryPlan = null; diff --git a/src/java/org/apache/cassandra/db/RegularAndStaticColumns.java b/src/java/org/apache/cassandra/db/RegularAndStaticColumns.java index b6da183d013f..55533eda0e97 100644 --- a/src/java/org/apache/cassandra/db/RegularAndStaticColumns.java +++ b/src/java/org/apache/cassandra/db/RegularAndStaticColumns.java @@ -163,7 +163,7 @@ public Builder add(ColumnMetadata c) } else { - assert c.isRegular(); + assert c.isRegular() || c.isSynthetic(); if (regularColumns == null) regularColumns = BTree.builder(naturalOrder()); regularColumns.add(c); @@ -197,7 +197,7 @@ public Builder addAll(RegularAndStaticColumns columns) public RegularAndStaticColumns build() { - return new RegularAndStaticColumns(staticColumns == null ? Columns.NONE : Columns.from(staticColumns), + return new RegularAndStaticColumns(staticColumns == null ? Columns.NONE : Columns.from(staticColumns), regularColumns == null ? 
Columns.NONE : Columns.from(regularColumns)); } } diff --git a/src/java/org/apache/cassandra/db/filter/ColumnFilter.java index d9a1b9d4e51a..644e6d661a61 100644 --- a/src/java/org/apache/cassandra/db/filter/ColumnFilter.java +++ b/src/java/org/apache/cassandra/db/filter/ColumnFilter.java @@ -75,6 +75,9 @@ public abstract class ColumnFilter public static final Serializer serializer = new Serializer(); + // TODO remove this with ANN_USE_SYNTHETIC_SCORE + public abstract boolean fetchesExplicitly(ColumnMetadata column); + /** * The fetching strategy for the different queries. */ @@ -103,7 +106,8 @@ boolean fetchesAllColumns(boolean isStatic) @Override RegularAndStaticColumns getFetchedColumns(TableMetadata metadata, RegularAndStaticColumns queried) { - return metadata.regularAndStaticColumns(); + var merged = queried.regulars.mergeTo(metadata.regularColumns()); + return new RegularAndStaticColumns(metadata.staticColumns(), merged); } }, @@ -124,7 +128,8 @@ boolean fetchesAllColumns(boolean isStatic) @Override RegularAndStaticColumns getFetchedColumns(TableMetadata metadata, RegularAndStaticColumns queried) { - return new RegularAndStaticColumns(queried.statics, metadata.regularColumns()); + var merged = queried.regulars.mergeTo(metadata.regularColumns()); + return new RegularAndStaticColumns(queried.statics, merged); } }, @@ -295,14 +300,16 @@ public static ColumnFilter selection(TableMetadata metadata, } /** - * The columns that needs to be fetched internally for this filter. + * The columns that need to be fetched internally. See FetchingStrategy for why this is + always a superset of the queried columns. * * @return the columns to fetch for this filter. */ public abstract RegularAndStaticColumns fetchedColumns(); /** - * The columns actually queried by the user. + * The columns needed to process the query, including selected columns, ordering columns, + * restriction (predicate) columns, and synthetic columns. *
<p>
* Note that this is in general not all the columns that are fetched internally (see {@link #fetchedColumns}). */ @@ -619,9 +626,7 @@ private SortedSetMultimap buildSubSelectio */ public static class WildCardColumnFilter extends ColumnFilter { - /** - * The queried and fetched columns. - */ + // for wildcards, there is no distinction between fetched and queried because queried is already "everything" private final RegularAndStaticColumns fetchedAndQueried; /** @@ -667,6 +672,12 @@ public boolean fetches(ColumnMetadata column) return true; } + @Override + public boolean fetchesExplicitly(ColumnMetadata column) + { + return false; + } + @Override public boolean fetchedColumnIsQueried(ColumnMetadata column) { @@ -739,14 +750,9 @@ public static class SelectionColumnFilter extends ColumnFilter { public final FetchingStrategy fetchingStrategy; - /** - * The selected columns - */ + // Materializes the columns required to implement queriedColumns() and fetchedColumns(), + // see the comments to superclass's methods private final RegularAndStaticColumns queried; - - /** - * The columns that need to be fetched to be able - */ private final RegularAndStaticColumns fetched; private final SortedSetMultimap subSelections; // can be null @@ -820,6 +826,12 @@ public boolean fetches(ColumnMetadata column) return fetchingStrategy.fetchesAllColumns(column.isStatic()) || fetched.contains(column); } + @Override + public boolean fetchesExplicitly(ColumnMetadata column) + { + return fetched.contains(column); + } + /** * Whether the provided complex cell (identified by its column and path), which is assumed to be _fetched_ by * this filter, is also _queried_ by the user. diff --git a/src/java/org/apache/cassandra/db/filter/RowFilter.java b/src/java/org/apache/cassandra/db/filter/RowFilter.java index 96ad240b539c..4d35566fa41f 100644 --- a/src/java/org/apache/cassandra/db/filter/RowFilter.java +++ b/src/java/org/apache/cassandra/db/filter/RowFilter.java @@ -1239,6 +1239,7 @@ public boolean isSatisfiedBy(TableMetadata metadata, DecoratedKey partitionKey, case LIKE_MATCHES: case ANALYZER_MATCHES: case ANN: + case BM25: { assert !column.isComplex() : "Only CONTAINS and CONTAINS_KEY are supported for 'complex' types"; ByteBuffer foundValue = getValue(metadata, partitionKey, row); diff --git a/src/java/org/apache/cassandra/db/monitoring/MonitorableImpl.java b/src/java/org/apache/cassandra/db/monitoring/MonitorableImpl.java index a6e7947b23f1..30376eb385f3 100644 --- a/src/java/org/apache/cassandra/db/monitoring/MonitorableImpl.java +++ b/src/java/org/apache/cassandra/db/monitoring/MonitorableImpl.java @@ -18,6 +18,8 @@ package org.apache.cassandra.db.monitoring; +import org.apache.cassandra.index.sai.QueryContext; + import static org.apache.cassandra.utils.MonotonicClock.approxTime; public abstract class MonitorableImpl implements Monitorable @@ -123,6 +125,9 @@ public boolean complete() private void check() { + if (QueryContext.DISABLE_TIMEOUT) + return; + if (approxCreationTimeNanos < 0 || state != MonitoringState.IN_PROGRESS) return; diff --git a/src/java/org/apache/cassandra/db/partitions/ParallelCommandProcessor.java b/src/java/org/apache/cassandra/db/partitions/ParallelCommandProcessor.java deleted file mode 100644 index 19d8523bd5c9..000000000000 --- a/src/java/org/apache/cassandra/db/partitions/ParallelCommandProcessor.java +++ /dev/null @@ -1,65 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. 
See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.cassandra.db.partitions; - -import java.util.List; - -import org.apache.cassandra.db.SinglePartitionReadCommand; -import org.apache.cassandra.db.rows.UnfilteredRowIterator; -import org.apache.cassandra.index.sai.utils.PrimaryKey; -import org.apache.cassandra.utils.Pair; - -/** - * An "processor" over a number of unfiltered partitions (i.e. partitions containing deletion information). - * - * Unlike {@link UnfilteredPartitionIterator}, this is designed to be used concurrently. - * - * Unlike UnfilteredPartitionIterator which requires single-threaded - * while (partitions.hasNext()) - * { - * var part = partitions.next(); - * ... - * part.close(); - * } - * - * this one allows concurrency, like - * var commands = partitions.getUninitializedCommands(); - * commands.parallelStream().forEach(tuple -> { - * var iter = partitions.commandToIterator(tuple.left(), tuple.right()); - * } - */ -public interface ParallelCommandProcessor -{ - /** - * Single-threaded call to get all commands and corresponding keys. - * - * @return the list of partition read commands. - */ - List> getUninitializedCommands(); - - /** - * Get an iterator for a given command and key. - * This method can be called concurrently for reulst of getUninitializedCommands(). - * - * @param command - * @param key - * @return - */ - UnfilteredRowIterator commandToIterator(PrimaryKey key, SinglePartitionReadCommand command); -} diff --git a/src/java/org/apache/cassandra/index/IndexRegistry.java b/src/java/org/apache/cassandra/index/IndexRegistry.java index 639d9aec350b..94bdfe0792fc 100644 --- a/src/java/org/apache/cassandra/index/IndexRegistry.java +++ b/src/java/org/apache/cassandra/index/IndexRegistry.java @@ -21,6 +21,7 @@ package org.apache.cassandra.index; import java.util.Collection; +import java.util.HashSet; import java.util.Collections; import java.util.Optional; import java.util.Set; @@ -103,12 +104,6 @@ public Optional getBestIndexFor(RowFilter.Expression expression) public void validate(PartitionUpdate update) { } - - @Override - public void validate(RowFilter filter) - { - // no-op since it's an empty registry - } }; /** @@ -296,12 +291,6 @@ public Optional getBestIndexFor(RowFilter.Expression expression) public void validate(PartitionUpdate update) { } - - @Override - public void validate(RowFilter filter) - { - // no-op since it's an empty registry - } }; default void registerIndex(Index index) @@ -354,8 +343,6 @@ default Optional getAnalyzerFor(ColumnMetadata column, */ void validate(PartitionUpdate update); - void validate(RowFilter filter); - /** * Returns the {@code IndexRegistry} associated to the specified table. * @@ -369,4 +356,74 @@ public static IndexRegistry obtain(TableMetadata table) return table.isVirtual() ? 
EMPTY : Keyspace.openAndGetStore(table).indexManager;
     }
+
+    enum EqBehavior
+    {
+        EQ,
+        MATCH,
+        AMBIGUOUS
+    }
+
+    class EqBehaviorIndexes
+    {
+        public final EqBehavior behavior;
+        public final Index eqIndex;
+        public final Index matchIndex;
+
+        private EqBehaviorIndexes(Index eqIndex, Index matchIndex, EqBehavior behavior)
+        {
+            this.eqIndex = eqIndex;
+            this.matchIndex = matchIndex;
+            this.behavior = behavior;
+        }
+
+        public static EqBehaviorIndexes eq(Index eqIndex)
+        {
+            return new EqBehaviorIndexes(eqIndex, null, EqBehavior.EQ);
+        }
+
+        public static EqBehaviorIndexes match(Index eqAndMatchIndex)
+        {
+            return new EqBehaviorIndexes(eqAndMatchIndex, eqAndMatchIndex, EqBehavior.MATCH);
+        }
+
+        public static EqBehaviorIndexes ambiguous(Index firstEqIndex, Index secondEqIndex)
+        {
+            return new EqBehaviorIndexes(firstEqIndex, secondEqIndex, EqBehavior.AMBIGUOUS);
+        }
+    }
+
+    /**
+     * @return
+     * - AMBIGUOUS if one index supports EQ and a different one supports both EQ and ANALYZER_MATCHES
+     * - MATCH if an index supports both EQ and ANALYZER_MATCHES
+     * - otherwise EQ
+     */
+    default EqBehaviorIndexes getEqBehavior(ColumnMetadata cm)
+    {
+        Index eqOnlyIndex = null;
+        Index bothIndex = null;
+
+        for (Index index : listIndexes())
+        {
+            boolean supportsEq = index.supportsExpression(cm, Operator.EQ);
+            boolean supportsMatches = index.supportsExpression(cm, Operator.ANALYZER_MATCHES);
+
+            if (supportsEq && supportsMatches)
+                bothIndex = index;
+            else if (supportsEq)
+                eqOnlyIndex = index;
+        }
+
+        // If we have one index supporting only EQ and another supporting both, return AMBIGUOUS
+        if (eqOnlyIndex != null && bothIndex != null)
+            return EqBehaviorIndexes.ambiguous(eqOnlyIndex, bothIndex);
+
+        // If we have an index supporting both EQ and ANALYZER_MATCHES, return MATCH
+        if (bothIndex != null)
+            return EqBehaviorIndexes.match(bothIndex);
+
+        // Otherwise return EQ
+        return EqBehaviorIndexes.eq(eqOnlyIndex == null ? bothIndex : eqOnlyIndex);
+    }
 }
diff --git a/src/java/org/apache/cassandra/index/RowFilterValidator.java b/src/java/org/apache/cassandra/index/RowFilterValidator.java
deleted file mode 100644
index fb70fbfc1452..000000000000
--- a/src/java/org/apache/cassandra/index/RowFilterValidator.java
+++ /dev/null
@@ -1,103 +0,0 @@
-/*
- * Copyright DataStax, Inc.
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.cassandra.index;
-
-import java.util.HashSet;
-import java.util.Set;
-import java.util.StringJoiner;
-
-import org.apache.cassandra.cql3.Operator;
-import org.apache.cassandra.db.filter.RowFilter;
-import org.apache.cassandra.index.sai.analyzer.AnalyzerEqOperatorSupport;
-import org.apache.cassandra.schema.ColumnMetadata;
-import org.apache.cassandra.service.ClientWarn;
-
-/**
- * Class for validating the index-related aspects of a {@link RowFilter}, without considering what index is actually used.
- *

- * It will emit a client warning when a query has EQ restrictions on columns having an analyzed index. - */ -class RowFilterValidator -{ - private final Iterable allIndexes; - - private Set columns; - private Set indexes; - - private RowFilterValidator(Iterable allIndexes) - { - this.allIndexes = allIndexes; - } - - private void addEqRestriction(ColumnMetadata column) - { - for (Index index : allIndexes) - { - if (index.supportsExpression(column, Operator.EQ) && - index.supportsExpression(column, Operator.ANALYZER_MATCHES)) - { - if (columns == null) - columns = new HashSet<>(); - columns.add(column); - - if (indexes == null) - indexes = new HashSet<>(); - indexes.add(index); - } - } - } - - private void validate() - { - if (columns == null || indexes == null) - return; - - StringJoiner columnNames = new StringJoiner(", "); - StringJoiner indexNames = new StringJoiner(", "); - columns.forEach(column -> columnNames.add(column.name.toString())); - indexes.forEach(index -> indexNames.add(index.getIndexMetadata().name)); - - ClientWarn.instance.warn(String.format(AnalyzerEqOperatorSupport.EQ_RESTRICTION_ON_ANALYZED_WARNING, columnNames, indexNames)); - } - - /** - * Emits a client warning if the filter contains EQ restrictions on columns having an analyzed index. - * - * @param filter the filter to validate - * @param indexes the existing indexes - */ - public static void validate(RowFilter filter, Iterable indexes) - { - RowFilterValidator validator = new RowFilterValidator(indexes); - validate(filter.root(), validator); - validator.validate(); - } - - private static void validate(RowFilter.FilterElement element, RowFilterValidator validator) - { - for (RowFilter.Expression expression : element.expressions()) - { - if (expression.operator() == Operator.EQ) - validator.addEqRestriction(expression.column()); - } - - for (RowFilter.FilterElement child : element.children()) - { - validate(child, validator); - } - } -} diff --git a/src/java/org/apache/cassandra/index/SecondaryIndexManager.java b/src/java/org/apache/cassandra/index/SecondaryIndexManager.java index b1930e741804..d4c75ed2a4f1 100644 --- a/src/java/org/apache/cassandra/index/SecondaryIndexManager.java +++ b/src/java/org/apache/cassandra/index/SecondaryIndexManager.java @@ -1277,12 +1277,6 @@ public void validate(PartitionUpdate update) throws InvalidRequestException index.validate(update); } - @Override - public void validate(RowFilter filter) - { - RowFilterValidator.validate(filter, indexes.values()); - } - /* * IndexRegistry methods */ diff --git a/src/java/org/apache/cassandra/index/sai/IndexContext.java b/src/java/org/apache/cassandra/index/sai/IndexContext.java index 7820400a6c25..1de515e16267 100644 --- a/src/java/org/apache/cassandra/index/sai/IndexContext.java +++ b/src/java/org/apache/cassandra/index/sai/IndexContext.java @@ -717,8 +717,8 @@ public boolean supports(Operator op) { if (op.isLike() || op == Operator.LIKE) return false; // Analyzed columns store the indexed result, so we are unable to compute raw equality. - // The only supported operator is ANALYZER_MATCHES. - if (op == Operator.ANALYZER_MATCHES) return isAnalyzed; + // The only supported operators are ANALYZER_MATCHES and BM25. + if (op == Operator.ANALYZER_MATCHES || op == Operator.BM25) return isAnalyzed; // If the column is analyzed and the operator is EQ, we need to check if the analyzer supports it. 
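+        // (that support is governed by the index's AnalyzerEqOperatorSupport option, which decides
+        // whether an analyzed index still accepts '=' or only the ':' match operator)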
if (op == Operator.EQ && isAnalyzed && !analyzerFactory.supportsEquals()) @@ -742,7 +742,6 @@ public boolean supports(Operator op) || column.type instanceof IntegerType); // Currently truncates to 20 bytes Expression.Op operator = Expression.Op.valueOf(op); - if (isNonFrozenCollection()) { if (indexType == IndexTarget.Type.KEYS) @@ -754,17 +753,12 @@ public boolean supports(Operator op) return indexType == IndexTarget.Type.KEYS_AND_VALUES && (operator == Expression.Op.EQ || operator == Expression.Op.NOT_EQ || operator == Expression.Op.RANGE); } - if (indexType == IndexTarget.Type.FULL) return operator == Expression.Op.EQ; - AbstractType validator = getValidator(); - if (operator == Expression.Op.IN) return true; - if (operator != Expression.Op.EQ && EQ_ONLY_TYPES.contains(validator)) return false; - // RANGE only applicable to non-literal indexes return (operator != null) && !(TypeUtil.isLiteral(validator) && operator == Expression.Op.RANGE); } diff --git a/src/java/org/apache/cassandra/index/sai/QueryContext.java b/src/java/org/apache/cassandra/index/sai/QueryContext.java index a0b1719049ec..5f1f7cda8498 100644 --- a/src/java/org/apache/cassandra/index/sai/QueryContext.java +++ b/src/java/org/apache/cassandra/index/sai/QueryContext.java @@ -36,7 +36,7 @@ @NotThreadSafe public class QueryContext { - private static final boolean DISABLE_TIMEOUT = Boolean.getBoolean("cassandra.sai.test.disable.timeout"); + public static final boolean DISABLE_TIMEOUT = Boolean.getBoolean("cassandra.sai.test.disable.timeout"); protected final long queryStartTimeNanos; diff --git a/src/java/org/apache/cassandra/index/sai/StorageAttachedIndex.java b/src/java/org/apache/cassandra/index/sai/StorageAttachedIndex.java index 28cb42ab8dc8..12557cd66a2d 100644 --- a/src/java/org/apache/cassandra/index/sai/StorageAttachedIndex.java +++ b/src/java/org/apache/cassandra/index/sai/StorageAttachedIndex.java @@ -299,20 +299,26 @@ public static Map validateOptions(Map options, T throw new InvalidRequestException("Failed to retrieve target column for: " + targetColumn); } - // In order to support different index target on non-frozen map, ie. KEYS, VALUE, ENTRIES, we need to put index - // name as part of index file name instead of column name. We only need to check that the target is different - // between indexes. This will only allow indexes in the same column with a different IndexTarget.Type. 
- // - // Note that: "metadata.indexes" already includes current index - if (metadata.indexes.stream().filter(index -> index.getIndexClassName().equals(StorageAttachedIndex.class.getName())) - .map(index -> TargetParser.parse(metadata, index.options.get(IndexTarget.TARGET_OPTION_NAME))) - .filter(Objects::nonNull).filter(t -> t.equals(target)).count() > 1) - { - throw new InvalidRequestException("Cannot create more than one storage-attached index on the same column: " + target.left); - } + // Check for duplicate indexes considering both target and analyzer configuration + boolean isAnalyzed = AbstractAnalyzer.isAnalyzed(options); + long duplicateCount = metadata.indexes.stream() + .filter(index -> index.getIndexClassName().equals(StorageAttachedIndex.class.getName())) + .filter(index -> { + // Indexes on the same column with different target (KEYS, VALUES, ENTRIES) + // are allowed on non-frozen Maps + var existingTarget = TargetParser.parse(metadata, index.options.get(IndexTarget.TARGET_OPTION_NAME)); + if (existingTarget == null || !existingTarget.equals(target)) + return false; + // Also allow different indexes if one is analyzed and the other isn't + return isAnalyzed == AbstractAnalyzer.isAnalyzed(index.options); + }) + .count(); + // >1 because "metadata.indexes" already includes current index + if (duplicateCount > 1) + throw new InvalidRequestException(String.format("Cannot create duplicate storage-attached index on column: %s", target.left)); // Analyzer is not supported against PK columns - if (AbstractAnalyzer.isAnalyzed(options)) + if (isAnalyzed) { for (ColumnMetadata column : metadata.primaryKeyColumns()) { diff --git a/src/java/org/apache/cassandra/index/sai/analyzer/AnalyzerEqOperatorSupport.java b/src/java/org/apache/cassandra/index/sai/analyzer/AnalyzerEqOperatorSupport.java index 116bc7f62832..30408c9b986f 100644 --- a/src/java/org/apache/cassandra/index/sai/analyzer/AnalyzerEqOperatorSupport.java +++ b/src/java/org/apache/cassandra/index/sai/analyzer/AnalyzerEqOperatorSupport.java @@ -49,7 +49,7 @@ public class AnalyzerEqOperatorSupport OPTION, Arrays.toString(Value.values())); public static final String EQ_RESTRICTION_ON_ANALYZED_WARNING = - String.format("Columns [%%s] are restricted by '=' and have analyzed indexes [%%s] able to process those restrictions. " + + String.format("Column [%%s] is restricted by '=' and has an analyzed index [%%s] able to process those restrictions. " + "Analyzed indexes might process '=' restrictions in a way that is inconsistent with non-indexed queries. " + "While '=' is still supported on analyzed indexes for backwards compatibility, " + "it is recommended to use the ':' operator instead to prevent the ambiguity. " + @@ -58,6 +58,13 @@ public class AnalyzerEqOperatorSupport "please use '%s':'%s' in the index options.", OPTION, Value.UNSUPPORTED.toString().toLowerCase()); + public static final String EQ_AMBIGUOUS_ERROR = + String.format("Column [%%s] equality predicate is ambiguous. It has both an analyzed index [%%s] configured with '%s':'%s', " + + "and an un-analyzed index [%%s]. 
" + + "To avoid ambiguity, drop the analyzed index and recreate it with option '%s':'%s'.", + OPTION, Value.MATCH.toString().toLowerCase(), OPTION, Value.UNSUPPORTED.toString().toLowerCase()); + + public static final String LWT_CONDITION_ON_ANALYZED_WARNING = "Index analyzers not applied to LWT conditions on columns [%s]."; diff --git a/src/java/org/apache/cassandra/index/sai/disk/MemtableTermsIterator.java b/src/java/org/apache/cassandra/index/sai/disk/MemtableTermsIterator.java index 9f0253e61ed8..615f3cfb7a4e 100644 --- a/src/java/org/apache/cassandra/index/sai/disk/MemtableTermsIterator.java +++ b/src/java/org/apache/cassandra/index/sai/disk/MemtableTermsIterator.java @@ -19,11 +19,11 @@ import java.nio.ByteBuffer; import java.util.Iterator; +import java.util.List; import com.google.common.base.Preconditions; -import com.carrotsearch.hppc.IntArrayList; -import com.carrotsearch.hppc.cursors.IntCursor; +import org.apache.cassandra.index.sai.memory.RowMapping; import org.apache.cassandra.utils.Pair; import org.apache.cassandra.utils.bytecomparable.ByteComparable; @@ -34,16 +34,16 @@ public class MemtableTermsIterator implements TermsIterator { private final ByteBuffer minTerm; private final ByteBuffer maxTerm; - private final Iterator> iterator; + private final Iterator>> iterator; - private Pair current; + private Pair> current; private int maxSSTableRowId = -1; private int minSSTableRowId = Integer.MAX_VALUE; public MemtableTermsIterator(ByteBuffer minTerm, ByteBuffer maxTerm, - Iterator> iterator) + Iterator>> iterator) { Preconditions.checkArgument(iterator != null); this.minTerm = minTerm; @@ -69,22 +69,24 @@ public void close() {} @Override public PostingList postings() { - final IntArrayList list = current.right; + var list = current.right; assert list.size() > 0; - final int minSegmentRowID = list.get(0); - final int maxSegmentRowID = list.get(list.size() - 1); + final int minSegmentRowID = list.get(0).rowId; + final int maxSegmentRowID = list.get(list.size() - 1).rowId; // Because we are working with postings from the memtable, there is only one segment, so segment row ids // and sstable row ids are the same. minSSTableRowId = Math.min(minSSTableRowId, minSegmentRowID); maxSSTableRowId = Math.max(maxSSTableRowId, maxSegmentRowID); - final Iterator it = list.iterator(); + var it = list.iterator(); return new PostingList() { + int frequency; + @Override public int nextPosting() { @@ -93,7 +95,9 @@ public int nextPosting() return END_OF_STREAM; } - return it.next().value; + var rowIdWithFrequency = it.next(); + frequency = rowIdWithFrequency.frequency; + return rowIdWithFrequency.rowId; } @Override @@ -102,6 +106,12 @@ public int size() return list.size(); } + @Override + public int frequency() + { + return frequency; + } + @Override public int advance(int targetRowID) { diff --git a/src/java/org/apache/cassandra/index/sai/disk/PostingList.java b/src/java/org/apache/cassandra/index/sai/disk/PostingList.java index b7dffb972599..4959c6f0be6a 100644 --- a/src/java/org/apache/cassandra/index/sai/disk/PostingList.java +++ b/src/java/org/apache/cassandra/index/sai/disk/PostingList.java @@ -42,6 +42,14 @@ default void close() throws IOException {} */ int nextPosting() throws IOException; + /** + * @return the number of occurrences of the term in the current row (the one most recently returned by nextPosting). 
+     */
+    default int frequency()
+    {
+        return 1;
+    }
+
     int size();
 
     /**
diff --git a/src/java/org/apache/cassandra/index/sai/disk/PrimaryKeyWithSource.java b/src/java/org/apache/cassandra/index/sai/disk/PrimaryKeyWithSource.java
index 7f9f7e89f51e..d303bd4d27c7 100644
--- a/src/java/org/apache/cassandra/index/sai/disk/PrimaryKeyWithSource.java
+++ b/src/java/org/apache/cassandra/index/sai/disk/PrimaryKeyWithSource.java
@@ -18,6 +18,7 @@
 package org.apache.cassandra.index.sai.disk;
 
+import io.github.jbellis.jvector.util.RamUsageEstimator;
 import org.apache.cassandra.db.Clustering;
 import org.apache.cassandra.db.DecoratedKey;
 import org.apache.cassandra.dht.Token;
@@ -127,4 +128,14 @@ public String toString()
     {
         return String.format("%s (source sstable: %s, %s)", primaryKey, sourceSstableId, sourceRowId);
     }
+
+    @Override
+    public long ramBytesUsed()
+    {
+        // Object header + 2 references (primaryKey, sourceSstableId) + long sourceRowId
+        return RamUsageEstimator.NUM_BYTES_OBJECT_HEADER +
+               2L * RamUsageEstimator.NUM_BYTES_OBJECT_REF +
+               Long.BYTES +
+               primaryKey.ramBytesUsed();
+    }
 }
diff --git a/src/java/org/apache/cassandra/index/sai/disk/RAMPostingSlices.java b/src/java/org/apache/cassandra/index/sai/disk/RAMPostingSlices.java
index 5879e5f1abdf..1313f81569ae 100644
--- a/src/java/org/apache/cassandra/index/sai/disk/RAMPostingSlices.java
+++ b/src/java/org/apache/cassandra/index/sai/disk/RAMPostingSlices.java
@@ -25,22 +25,33 @@
 import org.apache.lucene.util.mutable.MutableValueInt;
 
 /**
- * Encodes postings as variable integers into slices
+ * Encodes postings as variable integers into "slices" of byte blocks for efficient memory usage.
  */
 class RAMPostingSlices
 {
     static final int DEFAULT_TERM_DICT_SIZE = 1024;
 
+    /** Pool of byte blocks storing the actual posting data */
     private final ByteBlockPool postingsPool;
 
+    /** true if we're also writing term frequencies for an analyzed index */
+    private final boolean includeFrequencies;
+
+    /** The starting positions of postings for each term. Term id = index in array. */
     private int[] postingStarts = new int[DEFAULT_TERM_DICT_SIZE];
+    /** The current write positions for each term's postings. Term id = index in array. */
     private int[] postingUptos = new int[DEFAULT_TERM_DICT_SIZE];
+    /** The number of postings for each term. Term id = index in array. */
     private int[] sizes = new int[DEFAULT_TERM_DICT_SIZE];
 
-    RAMPostingSlices(Counter memoryUsage)
+    RAMPostingSlices(Counter memoryUsage, boolean includeFrequencies)
     {
         postingsPool = new ByteBlockPool(new ByteBlockPool.DirectTrackingAllocator(memoryUsage));
+        this.includeFrequencies = includeFrequencies;
     }
 
+    /**
+     * Creates and returns a PostingList for the given term ID.
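+     * The list decodes the slice data written by {@link #writePosting}: one delta-encoded row id
+     * per posting, followed by the term's frequency in that row when frequencies are enabled.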
+ */ PostingList postingList(int termID, final ByteSliceReader reader, long maxSegmentRowID) { initReader(reader, termID); @@ -49,20 +60,35 @@ PostingList postingList(int termID, final ByteSliceReader reader, long maxSegmen return new PostingList() { + int frequency = Integer.MIN_VALUE; + @Override public int nextPosting() throws IOException { if (reader.eof()) { + frequency = Integer.MIN_VALUE; return PostingList.END_OF_STREAM; } else { lastSegmentRowId.value += reader.readVInt(); + if (includeFrequencies) + frequency = reader.readVInt(); return lastSegmentRowId.value; } } + @Override + public int frequency() + { + if (!includeFrequencies) + return 1; + if (frequency <= 0) + throw new IllegalStateException("frequency() called before nextPosting()"); + return frequency; + } + @Override public int size() { @@ -77,12 +103,20 @@ public int advance(int targetRowID) }; } + /** + * Initializes a ByteSliceReader for reading postings for a specific term. + */ void initReader(ByteSliceReader reader, int termID) { final int upto = postingUptos[termID]; reader.init(postingsPool, postingStarts[termID], upto); } + /** + * Creates a new slice for storing postings for a given term ID. + * Grows the internal arrays if necessary and allocates a new block + * if the current block cannot accommodate a new slice. + */ void createNewSlice(int termID) { if (termID >= postingStarts.length - 1) @@ -103,7 +137,27 @@ void createNewSlice(int termID) postingUptos[termID] = upto + postingsPool.byteOffset; } - void writeVInt(int termID, int i) + void writePosting(int termID, int deltaRowId, int frequency) + { + assert termID >= 0 : termID; + assert deltaRowId >= 0 : deltaRowId; + writeVInt(termID, deltaRowId); + + if (includeFrequencies) + { + assert frequency > 0 : frequency; + writeVInt(termID, frequency); + } + + sizes[termID]++; + } + + /** + * Writes a variable-length integer to the posting list for a given term. + * The integer is encoded using a variable-length encoding scheme where each + * byte uses 7 bits for the value and 1 bit to indicate if more bytes follow. + */ + private void writeVInt(int termID, int i) { while ((i & ~0x7F) != 0) { @@ -111,9 +165,12 @@ void writeVInt(int termID, int i) i >>>= 7; } writeByte(termID, (byte) i); - sizes[termID]++; } + /** + * Writes a single byte to the posting list for a given term. + * If the current slice is full, it automatically allocates a new slice. 
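+     * Slice allocation follows Lucene's ByteBlockPool scheme of chaining progressively larger
+     * slices, so short posting lists stay compact while long ones grow cheaply.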
+ */ private void writeByte(int termID, byte b) { int upto = postingUptos[termID]; diff --git a/src/java/org/apache/cassandra/index/sai/disk/RAMStringIndexer.java b/src/java/org/apache/cassandra/index/sai/disk/RAMStringIndexer.java index 8b6f07d43597..40c7d8fa37a8 100644 --- a/src/java/org/apache/cassandra/index/sai/disk/RAMStringIndexer.java +++ b/src/java/org/apache/cassandra/index/sai/disk/RAMStringIndexer.java @@ -18,10 +18,12 @@ package org.apache.cassandra.index.sai.disk; import java.nio.ByteBuffer; +import java.util.List; import java.util.NoSuchElementException; import com.google.common.annotations.VisibleForTesting; +import org.agrona.collections.Int2IntHashMap; import org.apache.cassandra.utils.bytecomparable.ByteComparable; import org.apache.lucene.util.ArrayUtil; import org.apache.lucene.util.ByteBlockPool; @@ -43,11 +45,14 @@ public class RAMStringIndexer private final Counter termsBytesUsed; private final Counter slicesBytesUsed; - private int rowCount = 0; private int[] lastSegmentRowID = new int[RAMPostingSlices.DEFAULT_TERM_DICT_SIZE]; - public RAMStringIndexer() + private final boolean writeFrequencies; + private final Int2IntHashMap docLengths = new Int2IntHashMap(Integer.MIN_VALUE); + + public RAMStringIndexer(boolean writeFrequencies) { + this.writeFrequencies = writeFrequencies; termsBytesUsed = Counter.newCounter(); slicesBytesUsed = Counter.newCounter(); @@ -55,7 +60,7 @@ public RAMStringIndexer() termsHash = new BytesRefHash(termsPool); - slices = new RAMPostingSlices(slicesBytesUsed); + slices = new RAMPostingSlices(slicesBytesUsed, writeFrequencies); } public long estimatedBytesUsed() @@ -75,7 +80,12 @@ public boolean requiresFlush() public boolean isEmpty() { - return rowCount == 0; + return docLengths.isEmpty(); + } + + public Int2IntHashMap getDocLengths() + { + return docLengths; } /** @@ -140,36 +150,58 @@ private ByteComparable asByteComparable(byte[] bytes, int offset, int length) }; } - public long add(BytesRef term, int segmentRowId) + /** + * @return bytes allocated. may be zero if the (term, row) pair is a duplicate + */ + public long addAll(List terms, int segmentRowId) { long startBytes = estimatedBytesUsed(); - int termID = termsHash.add(term); - - if (termID >= 0) - { - // firs time seeing this term, create the term's first slice ! - slices.createNewSlice(termID); - } - else - { - termID = (-termID) - 1; - } + Int2IntHashMap frequencies = new Int2IntHashMap(Integer.MIN_VALUE); + Int2IntHashMap deltas = new Int2IntHashMap(Integer.MIN_VALUE); - if (termID >= lastSegmentRowID.length - 1) + for (BytesRef term : terms) { - lastSegmentRowID = ArrayUtil.grow(lastSegmentRowID, termID + 1); - } + int termID = termsHash.add(term); + boolean firstOccurrence = termID >= 0; - int delta = segmentRowId - lastSegmentRowID[termID]; - - lastSegmentRowID[termID] = segmentRowId; + if (firstOccurrence) + { + // first time seeing this term in any row, create the term's first slice ! 
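+                // (BytesRefHash.add returns the new term id on first insertion, or -(id + 1) if the
+                // term was already present; the else branch below decodes that encoding)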
+ slices.createNewSlice(termID); + // grow the termID -> last segment array if necessary + if (termID >= lastSegmentRowID.length - 1) + lastSegmentRowID = ArrayUtil.grow(lastSegmentRowID, termID + 1); + if (writeFrequencies) + frequencies.put(termID, 1); + } + else + { + termID = (-termID) - 1; + // compaction should call this method only with increasing segmentRowIds + assert segmentRowId >= lastSegmentRowID[termID]; + // increment frequency + if (writeFrequencies) + frequencies.put(termID, frequencies.getOrDefault(termID, 0) + 1); + // Skip computing a delta if we've already seen this term in this row + if (segmentRowId == lastSegmentRowID[termID]) + continue; + } - slices.writeVInt(termID, delta); + // Compute the delta from the last time this term was seen, to this row + int delta = segmentRowId - lastSegmentRowID[termID]; + // sanity check that we're advancing the row id, i.e. no duplicate entries. + assert firstOccurrence || delta > 0; + deltas.put(termID, delta); + lastSegmentRowID[termID] = segmentRowId; + } - long allocatedBytes = estimatedBytesUsed() - startBytes; + // add the postings now that we know the frequencies + deltas.forEachInt((termID, delta) -> { + slices.writePosting(termID, delta, frequencies.get(termID)); + }); - rowCount++; + docLengths.put(segmentRowId, terms.size()); - return allocatedBytes; + return estimatedBytesUsed() - startBytes; } } diff --git a/src/java/org/apache/cassandra/index/sai/disk/format/IndexComponentType.java b/src/java/org/apache/cassandra/index/sai/disk/format/IndexComponentType.java index 165f484eae66..bf38d6b8fbd8 100644 --- a/src/java/org/apache/cassandra/index/sai/disk/format/IndexComponentType.java +++ b/src/java/org/apache/cassandra/index/sai/disk/format/IndexComponentType.java @@ -111,7 +111,12 @@ public enum IndexComponentType * * V1 V2 */ - GROUP_COMPLETION_MARKER("GroupComplete"); + GROUP_COMPLETION_MARKER("GroupComplete"), + + /** + * Stores document length information for BM25 scoring + */ + DOC_LENGTHS("DocLengths"); public final String representation; diff --git a/src/java/org/apache/cassandra/index/sai/disk/format/Version.java b/src/java/org/apache/cassandra/index/sai/disk/format/Version.java index 6411072486b8..77757eec4851 100644 --- a/src/java/org/apache/cassandra/index/sai/disk/format/Version.java +++ b/src/java/org/apache/cassandra/index/sai/disk/format/Version.java @@ -34,6 +34,7 @@ import org.apache.cassandra.index.sai.disk.v4.V4OnDiskFormat; import org.apache.cassandra.index.sai.disk.v5.V5OnDiskFormat; import org.apache.cassandra.index.sai.disk.v6.V6OnDiskFormat; +import org.apache.cassandra.index.sai.disk.v7.V7OnDiskFormat; import org.apache.cassandra.index.sai.utils.TypeUtil; import org.apache.cassandra.utils.bytecomparable.ByteComparable; @@ -43,7 +44,7 @@ * Format version of indexing component, denoted as [major][minor]. Same forward-compatibility rules apply as to * {@link org.apache.cassandra.io.sstable.format.Version}. */ -public class Version +public class Version implements Comparable { // 6.8 formats public static final Version AA = new Version("aa", V1OnDiskFormat.instance, Version::aaFileNameFormat); @@ -53,15 +54,17 @@ public class Version public static final Version CA = new Version("ca", V3OnDiskFormat.instance, (c, i, g) -> stargazerFileNameFormat(c, i, g, "ca")); // NOTE: use DB to prevent collisions with upstream file formats // Encode trie entries using their AbstractType to ensure trie entries are sorted for range queries and are prefix free. 
- public static final Version DB = new Version("db", V4OnDiskFormat.instance, (c, i, g) -> stargazerFileNameFormat(c, i, g,"db")); + public static final Version DB = new Version("db", V4OnDiskFormat.instance, (c, i, g) -> stargazerFileNameFormat(c, i, g, "db")); // revamps vector postings lists to cause fewer reads from disk public static final Version DC = new Version("dc", V5OnDiskFormat.instance, (c, i, g) -> stargazerFileNameFormat(c, i, g, "dc")); // histograms in index metadata public static final Version EB = new Version("eb", V6OnDiskFormat.instance, (c, i, g) -> stargazerFileNameFormat(c, i, g, "eb")); + // term frequencies index component + public static final Version EC = new Version("ec", V7OnDiskFormat.instance, (c, i, g) -> stargazerFileNameFormat(c, i, g, "ec")); // These are in reverse-chronological order so that the latest version is first. Version matching tests // are more likely to match the latest version so we want to test that one first. - public static final List ALL = Lists.newArrayList(EB, DC, DB, CA, BA, AA); + public static final List ALL = Lists.newArrayList(EC, EB, DC, DB, CA, BA, AA); public static final Version EARLIEST = AA; public static final Version VECTOR_EARLIEST = BA; @@ -87,7 +90,8 @@ public static Version parse(String input) { checkArgument(input != null); checkArgument(input.length() == 2); - for (var v : ALL) { + for (var v : ALL) + { if (input.equals(v.version)) return v; } @@ -110,7 +114,7 @@ public boolean equals(Object o) { if (this == o) return true; if (o == null || getClass() != o.getClass()) return false; - Version other = (Version)o; + Version other = (Version) o; return Objects.equal(version, other.version); } @@ -143,6 +147,11 @@ public boolean useImmutableComponentFiles() return CassandraRelevantProperties.IMMUTABLE_SAI_COMPONENTS.getBoolean() && onOrAfter(Version.CA); } + @Override + public int compareTo(Version other) + { + return this.version.compareTo(other.version); + } public interface FileNameFormatter { @@ -152,7 +161,7 @@ public interface FileNameFormatter */ default String format(IndexComponentType indexComponentType, IndexContext indexContext, int generation) { - return format(indexComponentType, indexContext == null ? null : indexContext.getIndexName(), generation); + return format(indexComponentType, indexContext == null ? null : indexContext.getIndexName(), generation); } /** @@ -160,8 +169,8 @@ default String format(IndexComponentType indexComponentType, IndexContext indexC * filename is returned (so the suffix of the full filename), not a full path. * * @param indexComponentType the type of the index component. - * @param indexName the name of the index, or {@code null} for a per-sstable component. - * @param generation the generation of the build of the component. + * @param indexName the name of the index, or {@code null} for a per-sstable component. + * @param generation the generation of the build of the component. */ String format(IndexComponentType indexComponentType, @Nullable String indexName, int generation); } @@ -191,7 +200,6 @@ else if (componentStr.startsWith("SAI" + SAI_SEPARATOR)) return tryParseStargazerFileName(componentStr); else return Optional.empty(); - } public static class ParsedFileName @@ -222,7 +230,7 @@ private static String aaFileNameFormat(IndexComponentType indexComponentType, @N StringBuilder stringBuilder = new StringBuilder(); stringBuilder.append(indexName == null ? 
String.format(VERSION_AA_PER_SSTABLE_FORMAT, indexComponentType.representation) - : String.format(VERSION_AA_PER_INDEX_FORMAT, indexName, indexComponentType.representation)); + : String.format(VERSION_AA_PER_INDEX_FORMAT, indexName, indexComponentType.representation)); return stringBuilder.toString(); } diff --git a/src/java/org/apache/cassandra/index/sai/disk/v1/DocLengthsReader.java b/src/java/org/apache/cassandra/index/sai/disk/v1/DocLengthsReader.java new file mode 100644 index 000000000000..45769f9337d8 --- /dev/null +++ b/src/java/org/apache/cassandra/index/sai/disk/v1/DocLengthsReader.java @@ -0,0 +1,59 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.cassandra.index.sai.disk.v1; + +import java.io.Closeable; +import java.io.IOException; + +import org.apache.cassandra.index.sai.disk.io.IndexInputReader; +import org.apache.cassandra.index.sai.utils.IndexFileUtils; +import org.apache.cassandra.index.sai.utils.SAICodecUtils; +import org.apache.cassandra.io.util.FileHandle; +import org.apache.cassandra.io.util.FileUtils; +import org.apache.lucene.codecs.CodecUtil; + +public class DocLengthsReader implements Closeable +{ + private final FileHandle fileHandle; + private final IndexInputReader input; + private final SegmentMetadata.ComponentMetadata componentMetadata; + + public DocLengthsReader(FileHandle fileHandle, SegmentMetadata.ComponentMetadata componentMetadata) + { + this.fileHandle = fileHandle; + this.input = IndexFileUtils.instance.openInput(fileHandle); + this.componentMetadata = componentMetadata; + } + + public int get(int rowID) throws IOException + { + // Account for header size in offset calculation + long position = componentMetadata.offset + (long) rowID * Integer.BYTES; + if (position >= componentMetadata.offset + componentMetadata.length) + return 0; + input.seek(position); + return input.readInt(); + } + + @Override + public void close() throws IOException + { + FileUtils.close(fileHandle, input); + } +} + diff --git a/src/java/org/apache/cassandra/index/sai/disk/v1/IndexSearcher.java b/src/java/org/apache/cassandra/index/sai/disk/v1/IndexSearcher.java index 52988a685cd5..16d14bdd8ac3 100644 --- a/src/java/org/apache/cassandra/index/sai/disk/v1/IndexSearcher.java +++ b/src/java/org/apache/cassandra/index/sai/disk/v1/IndexSearcher.java @@ -68,9 +68,7 @@ public abstract class IndexSearcher implements Closeable, SegmentOrdering protected final SegmentMetadata metadata; protected final IndexContext indexContext; - private static final SSTableReadsListener NOOP_LISTENER = new SSTableReadsListener() {}; - - private final ColumnFilter columnFilter; + protected final ColumnFilter columnFilter; protected IndexSearcher(PrimaryKeyMap.Factory primaryKeyMapFactory, PerIndexFiles perIndexFiles, @@ -90,30 
+88,36 @@ protected IndexSearcher(PrimaryKeyMap.Factory primaryKeyMapFactory,
     public abstract long indexFileCacheSize();
 
     /**
-     * Search on-disk index synchronously
+     * Search on-disk index synchronously. Used for WHERE clause predicates, including BOUNDED_ANN.
      *
      * @param expression to filter on disk index
      * @param keyRange key range specific in read command, used by ANN index
     * @param queryContext to track per sstable cache and per query metrics
      * @param defer create the iterator in a deferred state
-     * @param limit the num of rows to returned, used by ANN index
      * @return {@link KeyRangeIterator} that matches given expression
      */
-    public abstract KeyRangeIterator search(Expression expression, AbstractBounds<PartitionPosition> keyRange, QueryContext queryContext, boolean defer, int limit) throws IOException;
+    public abstract KeyRangeIterator search(Expression expression, AbstractBounds<PartitionPosition> keyRange, QueryContext queryContext, boolean defer) throws IOException;
 
     /**
-     * Order the on-disk index synchronously and produce an iterator in score order
+     * Order the rows by the given Orderer. Used for ORDER BY clause when
+     * (1) the WHERE predicate is either a partition restriction or a range restriction on the index,
+     * (2) there is no WHERE predicate, or
+     * (3) the planner determines it is better to post-filter the ordered results by the predicate.
      *
      * @param orderer the object containing the ordering logic
      * @param slice optional predicate to get a slice of the index
      * @param keyRange key range specific in read command, used by ANN index
      * @param queryContext to track per sstable cache and per query metrics
-     * @param limit the num of rows to returned, used by ANN index
+     * @param limit the initial number of rows to return, used by the ANN index. More rows may be requested if filtering throws away more than expected!
      * @return an iterator of {@link PrimaryKeyWithSortKey} in score order
      */
     public abstract CloseableIterator<PrimaryKeyWithSortKey> orderBy(Orderer orderer, Expression slice, AbstractBounds<PartitionPosition> keyRange, QueryContext queryContext, int limit) throws IOException;
-
+    /**
+     * Order the rows by the given Orderer. Used for ORDER BY clause when the WHERE predicates
+     * have been applied first, yielding a list of primary keys. Again, {@code limit} is a planner hint for ANN to determine
+     * the initial number of results returned, not a maximum.
+     */
     @Override
     public CloseableIterator<PrimaryKeyWithSortKey> orderResultsBy(SSTableReader reader, QueryContext context, List<PrimaryKey> keys, Orderer orderer, int limit) throws IOException
     {
@@ -124,7 +128,7 @@ public CloseableIterator orderResultsBy(SSTableReader rea
         {
             var slices = Slices.with(indexContext.comparator(), Slice.make(key.clustering()));
             // TODO if we end up needing to read the row still, is it better to store offset and use reader.unfilteredAt?
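+            // fetch only the row at this key's clustering, restricted to the searcher's column filter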
- try (var iter = reader.iterator(key.partitionKey(), slices, columnFilter, false, NOOP_LISTENER)) + try (var iter = reader.iterator(key.partitionKey(), slices, columnFilter, false, SSTableReadsListener.NOOP_LISTENER)) { if (iter.hasNext()) { diff --git a/src/java/org/apache/cassandra/index/sai/disk/v1/InvertedIndexSearcher.java b/src/java/org/apache/cassandra/index/sai/disk/v1/InvertedIndexSearcher.java index 15aaa349e1c8..08a63818c709 100644 --- a/src/java/org/apache/cassandra/index/sai/disk/v1/InvertedIndexSearcher.java +++ b/src/java/org/apache/cassandra/index/sai/disk/v1/InvertedIndexSearcher.java @@ -19,15 +19,27 @@ package org.apache.cassandra.index.sai.disk.v1; import java.io.IOException; +import java.io.UncheckedIOException; import java.lang.invoke.MethodHandles; +import java.nio.ByteBuffer; +import java.util.HashMap; +import java.util.List; import java.util.Map; +import java.util.function.Function; +import java.util.stream.Collectors; import com.google.common.base.MoreObjects; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import org.apache.cassandra.cql3.Operator; import org.apache.cassandra.db.PartitionPosition; +import org.apache.cassandra.db.Slice; +import org.apache.cassandra.db.Slices; +import org.apache.cassandra.db.rows.Cell; +import org.apache.cassandra.db.rows.Row; import org.apache.cassandra.dht.AbstractBounds; +import org.apache.cassandra.exceptions.InvalidRequestException; import org.apache.cassandra.index.sai.IndexContext; import org.apache.cassandra.index.sai.QueryContext; import org.apache.cassandra.index.sai.SSTableContext; @@ -35,24 +47,31 @@ import org.apache.cassandra.index.sai.disk.TermsIterator; import org.apache.cassandra.index.sai.disk.format.IndexComponentType; import org.apache.cassandra.index.sai.disk.format.Version; +import org.apache.cassandra.index.sai.disk.v1.postings.IntersectingPostingList; import org.apache.cassandra.index.sai.iterators.KeyRangeIterator; import org.apache.cassandra.index.sai.metrics.MulticastQueryEventListeners; import org.apache.cassandra.index.sai.metrics.QueryEventListener; import org.apache.cassandra.index.sai.plan.Expression; import org.apache.cassandra.index.sai.plan.Orderer; +import org.apache.cassandra.index.sai.utils.BM25Utils; +import org.apache.cassandra.index.sai.utils.BM25Utils.DocTF; +import org.apache.cassandra.index.sai.utils.PrimaryKey; import org.apache.cassandra.index.sai.utils.PrimaryKeyWithSortKey; import org.apache.cassandra.index.sai.utils.RowIdWithByteComparable; import org.apache.cassandra.index.sai.utils.SAICodecUtils; -import org.apache.cassandra.index.sai.utils.SegmentOrdering; +import org.apache.cassandra.io.sstable.format.SSTableReader; +import org.apache.cassandra.io.sstable.format.SSTableReadsListener; import org.apache.cassandra.io.util.FileUtils; import org.apache.cassandra.utils.AbstractIterator; import org.apache.cassandra.utils.CloseableIterator; import org.apache.cassandra.utils.bytecomparable.ByteComparable; +import static org.apache.cassandra.index.sai.disk.PostingList.END_OF_STREAM; + /** * Executes {@link Expression}s against the trie-based terms dictionary for an individual index segment. 
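+ * <p>
+ * For analyzed indexes whose segments store term frequencies and document lengths (format version
+ * EC and later), it can also order results by BM25 score.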
*/ -public class InvertedIndexSearcher extends IndexSearcher implements SegmentOrdering +public class InvertedIndexSearcher extends IndexSearcher { private static final Logger logger = LoggerFactory.getLogger(MethodHandles.lookup().lookupClass()); @@ -60,6 +79,9 @@ public class InvertedIndexSearcher extends IndexSearcher implements SegmentOrder private final QueryEventListener.TrieIndexEventListener perColumnEventListener; private final Version version; private final boolean filterRangeResults; + private final SSTableReader sstable; + private final DocLengthsReader docLengthsReader; + private final long segmentRowIdOffset; protected InvertedIndexSearcher(SSTableContext sstableContext, PerIndexFiles perIndexFiles, @@ -69,6 +91,7 @@ protected InvertedIndexSearcher(SSTableContext sstableContext, boolean filterRangeResults) throws IOException { super(sstableContext.primaryKeyMapFactory(), perIndexFiles, segmentMetadata, indexContext); + this.sstable = sstableContext.sstable; long root = metadata.getIndexRoot(IndexComponentType.TERMS_DATA); assert root >= 0; @@ -76,6 +99,9 @@ protected InvertedIndexSearcher(SSTableContext sstableContext, this.version = version; this.filterRangeResults = filterRangeResults; perColumnEventListener = (QueryEventListener.TrieIndexEventListener)indexContext.getColumnQueryMetrics(); + var docLengthsMeta = segmentMetadata.componentMetadatas.getOptional(IndexComponentType.DOC_LENGTHS); + this.segmentRowIdOffset = segmentMetadata.segmentRowIdOffset; + this.docLengthsReader = docLengthsMeta == null ? null : new DocLengthsReader(indexFiles.docLengths(), docLengthsMeta); Map map = metadata.componentMetadatas.get(IndexComponentType.TERMS_DATA).attributes; String footerPointerString = map.get(SAICodecUtils.FOOTER_POINTER); @@ -100,7 +126,7 @@ public long indexFileCacheSize() } @SuppressWarnings("resource") - public KeyRangeIterator search(Expression exp, AbstractBounds keyRange, QueryContext context, boolean defer, int limit) throws IOException + public KeyRangeIterator search(Expression exp, AbstractBounds keyRange, QueryContext context, boolean defer) throws IOException { PostingList postingList = searchPosting(exp, context); return toPrimaryKeyIterator(postingList, context); @@ -129,11 +155,117 @@ else if (exp.getOp() == Expression.Op.RANGE) throw new IllegalArgumentException(indexContext.logMessage("Unsupported expression: " + exp)); } + private Cell readColumn(SSTableReader sstable, PrimaryKey primaryKey) + { + var dk = primaryKey.partitionKey(); + var slices = Slices.with(indexContext.comparator(), Slice.make(primaryKey.clustering())); + try (var rowIterator = sstable.iterator(dk, slices, columnFilter, false, SSTableReadsListener.NOOP_LISTENER)) + { + var unfiltered = rowIterator.next(); + assert unfiltered.isRow() : unfiltered; + Row row = (Row) unfiltered; + return row.getCell(indexContext.getDefinition()); + } + } + @Override public CloseableIterator orderBy(Orderer orderer, Expression slice, AbstractBounds keyRange, QueryContext queryContext, int limit) throws IOException { - var iter = new RowIdWithTermsIterator(reader.allTerms(orderer.isAscending())); - return toMetaSortedIterator(iter, queryContext); + if (!orderer.isBM25()) + { + var iter = new RowIdWithTermsIterator(reader.allTerms(orderer.isAscending())); + return toMetaSortedIterator(iter, queryContext); + } + if (docLengthsReader == null) + throw new InvalidRequestException(indexContext.getIndexName() + " does not support BM25 scoring until it is rebuilt"); + + // find documents that match each term + var 
queryTerms = orderer.getQueryTerms(); + var postingLists = queryTerms.stream() + .collect(Collectors.toMap(Function.identity(), term -> + { + var encodedTerm = version.onDiskFormat().encodeForTrie(term, indexContext.getValidator()); + var listener = MulticastQueryEventListeners.of(queryContext, perColumnEventListener); + var postings = reader.exactMatch(encodedTerm, listener, queryContext); + return postings == null ? PostingList.EMPTY : postings; + })); + // extract the match count for each + var documentFrequencies = postingLists.entrySet().stream().collect(Collectors.toMap(Map.Entry::getKey, e -> (long) e.getValue().size())); + + var pkm = primaryKeyMapFactory.newPerSSTablePrimaryKeyMap(); + var merged = IntersectingPostingList.intersect(postingLists); + + // Wrap the iterator with resource management + var it = new AbstractIterator() { // Anonymous class extends AbstractIterator + private boolean closed; + + @Override + protected DocTF computeNext() + { + try + { + int rowId = merged.nextPosting(); + if (rowId == PostingList.END_OF_STREAM) + return endOfData(); + int docLength = docLengthsReader.get(rowId); // segment-local rowid + var pk = pkm.primaryKeyFromRowId(segmentRowIdOffset + rowId); // sstable-global rowid + return new DocTF(pk, docLength, merged.frequencies()); + } + catch (IOException e) + { + throw new UncheckedIOException(e); + } + } + + @Override + public void close() + { + if (closed) return; + closed = true; + FileUtils.closeQuietly(pkm, merged); + } + }; + return bm25Internal(it, queryTerms, documentFrequencies); + } + + private CloseableIterator bm25Internal(CloseableIterator keyIterator, + List queryTerms, + Map documentFrequencies) + { + var totalRows = sstable.getTotalRows(); + // since doc frequencies can be an estimate from the index histogram, which does not have bounded error, + // cap frequencies to total rows so that the IDF term doesn't turn negative + var cappedFrequencies = documentFrequencies.entrySet().stream() + .collect(Collectors.toMap(Map.Entry::getKey, e -> Math.min(e.getValue(), totalRows))); + var docStats = new BM25Utils.DocStats(cappedFrequencies, totalRows); + return BM25Utils.computeScores(keyIterator, + queryTerms, + docStats, + indexContext, + sstable.descriptor.id); + } + + @Override + public CloseableIterator orderResultsBy(SSTableReader reader, QueryContext queryContext, List keys, Orderer orderer, int limit) throws IOException + { + if (!orderer.isBM25()) + return super.orderResultsBy(reader, queryContext, keys, orderer, limit); + if (docLengthsReader == null) + throw new InvalidRequestException(indexContext.getIndexName() + " does not support BM25 scoring until it is rebuilt"); + + var queryTerms = orderer.getQueryTerms(); + // compute documentFrequencies from either histogram or an index search + var documentFrequencies = new HashMap(); + // any index new enough to support BM25 should also support histograms + assert metadata.version.onDiskFormat().indexFeatureSet().hasTermsHistogram(); + for (ByteBuffer term : queryTerms) + { + long matches = metadata.estimateNumRowsMatching(new Expression(indexContext).add(Operator.ANALYZER_MATCHES, term)); + documentFrequencies.put(term, matches); + } + var analyzer = indexContext.getAnalyzerFactory().create(); + var it = keys.stream().map(pk -> DocTF.createFromDocument(pk, readColumn(sstable, pk), analyzer, queryTerms)).iterator(); + return bm25Internal(CloseableIterator.wrap(it), queryTerms, documentFrequencies); } @Override @@ -147,7 +279,7 @@ public String toString() @Override public void 
close() { - reader.close(); + FileUtils.closeQuietly(reader, docLengthsReader); } /** @@ -172,7 +304,7 @@ protected RowIdWithByteComparable computeNext() while (true) { long nextPosting = currentPostingList.nextPosting(); - if (nextPosting != PostingList.END_OF_STREAM) + if (nextPosting != END_OF_STREAM) return new RowIdWithByteComparable(Math.toIntExact(nextPosting), currentTerm); if (!source.hasNext()) diff --git a/src/java/org/apache/cassandra/index/sai/disk/v1/KDTreeIndexSearcher.java b/src/java/org/apache/cassandra/index/sai/disk/v1/KDTreeIndexSearcher.java index 8a2aa354bd47..5911f2351014 100644 --- a/src/java/org/apache/cassandra/index/sai/disk/v1/KDTreeIndexSearcher.java +++ b/src/java/org/apache/cassandra/index/sai/disk/v1/KDTreeIndexSearcher.java @@ -84,7 +84,7 @@ public long indexFileCacheSize() } @Override - public KeyRangeIterator search(Expression exp, AbstractBounds keyRange, QueryContext context, boolean defer, int limit) throws IOException + public KeyRangeIterator search(Expression exp, AbstractBounds keyRange, QueryContext context, boolean defer) throws IOException { PostingList postingList = searchPosting(exp, context); return toPrimaryKeyIterator(postingList, context); diff --git a/src/java/org/apache/cassandra/index/sai/disk/v1/MemtableIndexWriter.java b/src/java/org/apache/cassandra/index/sai/disk/v1/MemtableIndexWriter.java index 102699ca0288..8683409896ad 100644 --- a/src/java/org/apache/cassandra/index/sai/disk/v1/MemtableIndexWriter.java +++ b/src/java/org/apache/cassandra/index/sai/disk/v1/MemtableIndexWriter.java @@ -19,15 +19,15 @@ import java.io.IOException; import java.lang.invoke.MethodHandles; +import java.util.Arrays; import java.util.Collections; -import java.util.Iterator; import java.util.concurrent.TimeUnit; import com.google.common.base.Stopwatch; import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import com.carrotsearch.hppc.IntArrayList; +import org.agrona.collections.Int2IntHashMap; import org.apache.cassandra.db.DecoratedKey; import org.apache.cassandra.db.marshal.AbstractType; import org.apache.cassandra.db.rows.Row; @@ -36,17 +36,17 @@ import org.apache.cassandra.index.sai.disk.PerIndexWriter; import org.apache.cassandra.index.sai.disk.format.IndexComponents; import org.apache.cassandra.index.sai.disk.format.Version; -import org.apache.cassandra.index.sai.disk.vector.VectorMemtableIndex; import org.apache.cassandra.index.sai.disk.v1.kdtree.ImmutableOneDimPointValues; import org.apache.cassandra.index.sai.disk.v1.kdtree.NumericIndexWriter; import org.apache.cassandra.index.sai.disk.v1.trie.InvertedIndexWriter; +import org.apache.cassandra.index.sai.disk.vector.VectorMemtableIndex; import org.apache.cassandra.index.sai.memory.MemtableIndex; import org.apache.cassandra.index.sai.memory.RowMapping; +import org.apache.cassandra.index.sai.memory.TrieMemoryIndex; +import org.apache.cassandra.index.sai.memory.TrieMemtableIndex; import org.apache.cassandra.index.sai.utils.PrimaryKey; import org.apache.cassandra.index.sai.utils.TypeUtil; import org.apache.cassandra.utils.ByteBufferUtil; -import org.apache.cassandra.utils.Pair; -import org.apache.cassandra.utils.bytecomparable.ByteComparable; /** * Column index writer that flushes indexed data directly from the corresponding Memtable index, without buffering index @@ -126,7 +126,7 @@ public void complete(Stopwatch stopwatch) throws IOException } else { - final Iterator> iterator = rowMapping.merge(memtableIndex); + var iterator = rowMapping.merge(memtableIndex); try (MemtableTermsIterator terms = 
new MemtableTermsIterator(memtableIndex.getMinTerm(), memtableIndex.getMaxTerm(), iterator)) { long cellCount = flush(minKey, maxKey, indexContext().getValidator(), terms, rowMapping.maxSegmentRowId); @@ -151,9 +151,21 @@ private long flush(DecoratedKey minKey, DecoratedKey maxKey, AbstractType ter SegmentMetadata.ComponentMetadataMap indexMetas; if (TypeUtil.isLiteral(termComparator)) { - try (InvertedIndexWriter writer = new InvertedIndexWriter(perIndexComponents)) + try (InvertedIndexWriter writer = new InvertedIndexWriter(perIndexComponents, writeFrequencies())) { - indexMetas = writer.writeAll(metadataBuilder.intercept(terms)); + // Convert PrimaryKey->length map to rowId->length using RowMapping + var docLengths = new Int2IntHashMap(Integer.MIN_VALUE); + Arrays.stream(((TrieMemtableIndex) memtableIndex).getRangeIndexes()) + .map(TrieMemoryIndex.class::cast) + .forEach(trieMemoryIndex -> + trieMemoryIndex.getDocLengths().forEach((pk, length) -> { + int rowId = rowMapping.get(pk); + if (rowId >= 0) + docLengths.put(rowId, (int) length); + }) + ); + + indexMetas = writer.writeAll(metadataBuilder.intercept(terms), docLengths); numRows = writer.getPostingsCount(); } } @@ -194,6 +206,11 @@ private long flush(DecoratedKey minKey, DecoratedKey maxKey, AbstractType ter return numRows; } + private boolean writeFrequencies() + { + return indexContext().isAnalyzed() && Version.latest().onOrAfter(Version.EC); + } + private void flushVectorIndex(DecoratedKey minKey, DecoratedKey maxKey, long startTime, Stopwatch stopwatch) throws IOException { var vectorIndex = (VectorMemtableIndex) memtableIndex; diff --git a/src/java/org/apache/cassandra/index/sai/disk/v1/PartitionAwarePrimaryKeyFactory.java b/src/java/org/apache/cassandra/index/sai/disk/v1/PartitionAwarePrimaryKeyFactory.java index 3b49e22a84fd..c1b47916229b 100644 --- a/src/java/org/apache/cassandra/index/sai/disk/v1/PartitionAwarePrimaryKeyFactory.java +++ b/src/java/org/apache/cassandra/index/sai/disk/v1/PartitionAwarePrimaryKeyFactory.java @@ -21,6 +21,7 @@ import java.util.Objects; import java.util.function.Supplier; +import io.github.jbellis.jvector.util.RamUsageEstimator; import org.apache.cassandra.db.Clustering; import org.apache.cassandra.db.DecoratedKey; import org.apache.cassandra.dht.Token; @@ -136,6 +137,19 @@ private ByteSource asComparableBytes(int terminator, ByteComparable.Version vers return ByteSource.withTerminator(terminator, tokenComparable, keyComparable, null); } + @Override + public long ramBytesUsed() + { + // Compute shallow size: object header + 4 references (3 declared + 1 implicit outer reference) + long shallowSize = RamUsageEstimator.NUM_BYTES_OBJECT_HEADER + 4L * RamUsageEstimator.NUM_BYTES_OBJECT_REF; + long preHashedDecoratedKeySize = partitionKey == null + ? 
0 + : RamUsageEstimator.NUM_BYTES_OBJECT_HEADER + + 2L * RamUsageEstimator.NUM_BYTES_OBJECT_REF // token and key references + + 2L * Long.BYTES; + return shallowSize + token.getHeapSize() + preHashedDecoratedKeySize; + } + @Override public int compareTo(PrimaryKey o) { diff --git a/src/java/org/apache/cassandra/index/sai/disk/v1/PerIndexFiles.java b/src/java/org/apache/cassandra/index/sai/disk/v1/PerIndexFiles.java index e4a1ff6bca9e..b82c23cc589d 100644 --- a/src/java/org/apache/cassandra/index/sai/disk/v1/PerIndexFiles.java +++ b/src/java/org/apache/cassandra/index/sai/disk/v1/PerIndexFiles.java @@ -28,6 +28,7 @@ import org.apache.cassandra.index.sai.disk.format.IndexComponents; import org.apache.cassandra.index.sai.disk.format.IndexComponentType; +import org.apache.cassandra.index.sai.disk.format.Version; import org.apache.cassandra.io.util.FileHandle; import org.apache.cassandra.io.util.FileUtils; @@ -104,6 +105,12 @@ public FileHandle pq() return getFile(IndexComponentType.PQ).sharedCopy(); } + /** It is the caller's responsibility to close the returned file handle. */ + public FileHandle docLengths() + { + return getFile(IndexComponentType.DOC_LENGTHS).sharedCopy(); + } + public FileHandle getFile(IndexComponentType indexComponentType) { FileHandle file = files.get(indexComponentType); diff --git a/src/java/org/apache/cassandra/index/sai/disk/v1/SSTableIndexWriter.java b/src/java/org/apache/cassandra/index/sai/disk/v1/SSTableIndexWriter.java index bb85ad583cba..a4812fb120ac 100644 --- a/src/java/org/apache/cassandra/index/sai/disk/v1/SSTableIndexWriter.java +++ b/src/java/org/apache/cassandra/index/sai/disk/v1/SSTableIndexWriter.java @@ -241,7 +241,7 @@ else if (shouldFlush(sstableRowId)) if (term.remaining() == 0 && TypeUtil.skipsEmptyValue(indexContext.getValidator())) return; - long allocated = currentBuilder.addAll(term, type, key, sstableRowId); + long allocated = currentBuilder.analyzeAndAdd(term, type, key, sstableRowId); limiter.increment(allocated); } diff --git a/src/java/org/apache/cassandra/index/sai/disk/v1/Segment.java b/src/java/org/apache/cassandra/index/sai/disk/v1/Segment.java index 87fe0a998e5f..3d52c944d726 100644 --- a/src/java/org/apache/cassandra/index/sai/disk/v1/Segment.java +++ b/src/java/org/apache/cassandra/index/sai/disk/v1/Segment.java @@ -144,7 +144,7 @@ public long indexFileCacheSize() */ public KeyRangeIterator search(Expression expression, AbstractBounds keyRange, QueryContext context, boolean defer, int limit) throws IOException { - return index.search(expression, keyRange, context, defer, limit); + return index.search(expression, keyRange, context, defer); } /** diff --git a/src/java/org/apache/cassandra/index/sai/disk/v1/SegmentBuilder.java b/src/java/org/apache/cassandra/index/sai/disk/v1/SegmentBuilder.java index e1cbfc318180..095b43d3926f 100644 --- a/src/java/org/apache/cassandra/index/sai/disk/v1/SegmentBuilder.java +++ b/src/java/org/apache/cassandra/index/sai/disk/v1/SegmentBuilder.java @@ -29,6 +29,7 @@ import java.util.concurrent.atomic.AtomicLong; import java.util.concurrent.atomic.AtomicReference; import java.util.concurrent.atomic.LongAdder; +import java.util.stream.Collectors; import javax.annotation.concurrent.NotThreadSafe; import com.google.common.annotations.VisibleForTesting; @@ -42,6 +43,7 @@ import org.apache.cassandra.index.sai.IndexContext; import org.apache.cassandra.index.sai.analyzer.AbstractAnalyzer; import org.apache.cassandra.index.sai.analyzer.ByteLimitedMaterializer; +import 
org.apache.cassandra.index.sai.analyzer.NoOpAnalyzer; import org.apache.cassandra.index.sai.disk.PostingList; import org.apache.cassandra.index.sai.disk.RAMStringIndexer; import org.apache.cassandra.index.sai.disk.TermsIterator; @@ -154,9 +156,11 @@ public boolean isEmpty() return kdTreeRamBuffer.numRows() == 0; } - protected long addInternal(ByteBuffer term, int segmentRowId) + @Override + protected long addInternal(List terms, int segmentRowId) { - TypeUtil.toComparableBytes(term, termComparator, buffer); + assert terms.size() == 1; + TypeUtil.toComparableBytes(terms.get(0), termComparator, buffer); return kdTreeRamBuffer.addPackedValue(segmentRowId, new BytesRef(buffer)); } @@ -192,10 +196,14 @@ public static class RAMStringSegmentBuilder extends SegmentBuilder { super(components, rowIdOffset, limiter); this.byteComparableVersion = components.byteComparableVersionFor(IndexComponentType.TERMS_DATA); - ramIndexer = new RAMStringIndexer(); + ramIndexer = new RAMStringIndexer(writeFrequencies()); totalBytesAllocated = ramIndexer.estimatedBytesUsed(); totalBytesAllocatedConcurrent.add(totalBytesAllocated); + } + private boolean writeFrequencies() + { + return !(analyzer instanceof NoOpAnalyzer) && Version.latest().onOrAfter(Version.EC); } public boolean isEmpty() @@ -203,21 +211,26 @@ public boolean isEmpty() return ramIndexer.isEmpty(); } - protected long addInternal(ByteBuffer term, int segmentRowId) + @Override + protected long addInternal(List terms, int segmentRowId) { - var encodedTerm = components.onDiskFormat().encodeForTrie(term, termComparator); - var bytes = ByteSourceInverse.readBytes(encodedTerm.asComparableBytes(byteComparableVersion)); - var bytesRef = new BytesRef(bytes); - return ramIndexer.add(bytesRef, segmentRowId); + var bytesRefs = terms.stream() + .map(term -> components.onDiskFormat().encodeForTrie(term, termComparator)) + .map(encodedTerm -> ByteSourceInverse.readBytes(encodedTerm.asComparableBytes(byteComparableVersion))) + .map(BytesRef::new) + .collect(Collectors.toList()); + // ramIndexer is responsible for merging duplicate (term, row) pairs + return ramIndexer.addAll(bytesRefs, segmentRowId); } @Override protected void flushInternal(SegmentMetadataBuilder metadataBuilder) throws IOException { - try (InvertedIndexWriter writer = new InvertedIndexWriter(components)) + try (InvertedIndexWriter writer = new InvertedIndexWriter(components, writeFrequencies())) { TermsIterator termsWithPostings = ramIndexer.getTermsWithPostings(minTerm, maxTerm, byteComparableVersion); - var metadataMap = writer.writeAll(metadataBuilder.intercept(termsWithPostings)); + var docLengths = ramIndexer.getDocLengths(); + var metadataMap = writer.writeAll(metadataBuilder.intercept(termsWithPostings), docLengths); metadataBuilder.setComponentsMetadata(metadataMap); } } @@ -261,21 +274,23 @@ public boolean isEmpty() } @Override - protected long addInternal(ByteBuffer term, int segmentRowId) + protected long addInternal(List terms, int segmentRowId) { throw new UnsupportedOperationException(); } @Override - protected long addInternalAsync(ByteBuffer term, int segmentRowId) + protected long addInternalAsync(List terms, int segmentRowId) { + assert terms.size() == 1; + // CompactionGraph splits adding a node into two parts: // (1) maybeAddVector, which must be done serially because it writes to disk incrementally // (2) addGraphNode, which may be done asynchronously CompactionGraph.InsertionResult result; try { - result = graphIndex.maybeAddVector(term, segmentRowId); + result = 
graphIndex.maybeAddVector(terms.get(0), segmentRowId); } catch (IOException e) { @@ -367,19 +382,20 @@ public boolean isEmpty() } @Override - protected long addInternal(ByteBuffer term, int segmentRowId) + protected long addInternal(List terms, int segmentRowId) { - return graphIndex.add(term, segmentRowId); + assert terms.size() == 1; + return graphIndex.add(terms.get(0), segmentRowId); } @Override - protected long addInternalAsync(ByteBuffer term, int segmentRowId) + protected long addInternalAsync(List terms, int segmentRowId) { updatesInFlight.incrementAndGet(); compactionExecutor.submit(() -> { try { - long bytesAdded = addInternal(term, segmentRowId); + long bytesAdded = addInternal(terms, segmentRowId); totalBytesAllocatedConcurrent.add(bytesAdded); termSizeReservoir.update(bytesAdded); } @@ -454,23 +470,22 @@ public SegmentMetadata flush() throws IOException return metadataBuilder.build(); } - public long addAll(ByteBuffer term, AbstractType type, PrimaryKey key, long sstableRowId) + public long analyzeAndAdd(ByteBuffer rawTerm, AbstractType type, PrimaryKey key, long sstableRowId) { long totalSize = 0; if (TypeUtil.isLiteral(type)) { - List tokens = ByteLimitedMaterializer.materializeTokens(analyzer, term, components.context(), key); - for (ByteBuffer tokenTerm : tokens) - totalSize += add(tokenTerm, key, sstableRowId); + var terms = ByteLimitedMaterializer.materializeTokens(analyzer, rawTerm, components.context(), key); + totalSize += add(terms, key, sstableRowId); } else { - totalSize += add(term, key, sstableRowId); + totalSize += add(List.of(rawTerm), key, sstableRowId); } return totalSize; } - private long add(ByteBuffer term, PrimaryKey key, long sstableRowId) + private long add(List terms, PrimaryKey key, long sstableRowId) { assert !flushed : "Cannot add to flushed segment."; assert sstableRowId >= maxSSTableRowId; @@ -481,9 +496,12 @@ private long add(ByteBuffer term, PrimaryKey key, long sstableRowId) minKey = minKey == null ? key : minKey; maxKey = key; - // Note that the min and max terms are not encoded. - minTerm = TypeUtil.min(term, minTerm, termComparator, Version.latest()); - maxTerm = TypeUtil.max(term, maxTerm, termComparator, Version.latest()); + // Update term boundaries for all terms in this row + for (ByteBuffer term : terms) + { + minTerm = TypeUtil.min(term, minTerm, termComparator, Version.latest()); + maxTerm = TypeUtil.max(term, maxTerm, termComparator, Version.latest()); + } rowCount++; @@ -495,15 +513,23 @@ private long add(ByteBuffer term, PrimaryKey key, long sstableRowId) maxSegmentRowId = Math.max(maxSegmentRowId, segmentRowId); - long bytesAllocated = supportsAsyncAdd() - ? 
addInternalAsync(term, segmentRowId) - : addInternal(term, segmentRowId); - totalBytesAllocated += bytesAllocated; + long bytesAllocated; + if (supportsAsyncAdd()) + { + // only vector indexing is done async and there can only be one term + assert terms.size() == 1; + bytesAllocated = addInternalAsync(terms, segmentRowId); + } + else + { + bytesAllocated = addInternal(terms, segmentRowId); + } + totalBytesAllocated += bytesAllocated; return bytesAllocated; } - protected long addInternalAsync(ByteBuffer term, int segmentRowId) + protected long addInternalAsync(List terms, int segmentRowId) { throw new UnsupportedOperationException(); } @@ -566,7 +592,7 @@ long release(IndexContext indexContext) public abstract boolean isEmpty(); - protected abstract long addInternal(ByteBuffer term, int segmentRowId); + protected abstract long addInternal(List terms, int segmentRowId); protected abstract void flushInternal(SegmentMetadataBuilder metadataBuilder) throws IOException; diff --git a/src/java/org/apache/cassandra/index/sai/disk/v1/SegmentMetadata.java b/src/java/org/apache/cassandra/index/sai/disk/v1/SegmentMetadata.java index 00361d6f27ab..b8f4cd769873 100644 --- a/src/java/org/apache/cassandra/index/sai/disk/v1/SegmentMetadata.java +++ b/src/java/org/apache/cassandra/index/sai/disk/v1/SegmentMetadata.java @@ -376,6 +376,11 @@ public ComponentMetadata get(IndexComponentType indexComponentType) return metas.get(indexComponentType); } + public ComponentMetadata getOptional(IndexComponentType indexComponentType) + { + return metas.get(indexComponentType); + } + public Map> asMap() { Map> metaAttributes = new HashMap<>(); diff --git a/src/java/org/apache/cassandra/index/sai/disk/v1/TermsReader.java b/src/java/org/apache/cassandra/index/sai/disk/v1/TermsReader.java index a75871ec28cd..f05511b195a5 100644 --- a/src/java/org/apache/cassandra/index/sai/disk/v1/TermsReader.java +++ b/src/java/org/apache/cassandra/index/sai/disk/v1/TermsReader.java @@ -228,7 +228,7 @@ public PostingsReader getPostingReader(long offset) throws IOException { PostingsReader.BlocksSummary header = new PostingsReader.BlocksSummary(postingsSummaryInput, offset); - return new PostingsReader(postingsInput, header, listener.postingListEventListener()); + return new PostingsReader(postingsInput, header, readFrequencies(), listener.postingListEventListener()); } } @@ -363,11 +363,17 @@ private PostingsReader currentReader(IndexInput postingsInput, PostingsReader.InputCloser.NOOP); return new PostingsReader(postingsInput, blocksSummary, + readFrequencies(), listener.postingListEventListener(), PostingsReader.InputCloser.NOOP); } } + private boolean readFrequencies() + { + return indexContext.isAnalyzed() && version.onOrAfter(Version.EC); + } + private class TermsScanner implements TermsIterator { private final TrieTermsDictionaryReader termsDictionaryReader; @@ -400,7 +406,7 @@ public PostingList postings() throws IOException { assert entry != null; var blockSummary = new PostingsReader.BlocksSummary(postingsSummaryInput, entry.right, PostingsReader.InputCloser.NOOP); - return new ScanningPostingsReader(postingsInput, blockSummary); + return new ScanningPostingsReader(postingsInput, blockSummary, readFrequencies()); } @Override @@ -461,7 +467,7 @@ public PostingList postings() throws IOException { assert entry != null; var blockSummary = new PostingsReader.BlocksSummary(postingsSummaryInput, entry.right, PostingsReader.InputCloser.NOOP); - return new ScanningPostingsReader(postingsInput, blockSummary); + return new 
ScanningPostingsReader(postingsInput, blockSummary, readFrequencies()); } @Override diff --git a/src/java/org/apache/cassandra/index/sai/disk/v1/V1OnDiskFormat.java b/src/java/org/apache/cassandra/index/sai/disk/v1/V1OnDiskFormat.java index 5dbd732f4676..8afe5d62b4c1 100644 --- a/src/java/org/apache/cassandra/index/sai/disk/v1/V1OnDiskFormat.java +++ b/src/java/org/apache/cassandra/index/sai/disk/v1/V1OnDiskFormat.java @@ -83,10 +83,11 @@ public class V1OnDiskFormat implements OnDiskFormat IndexComponentType.META, IndexComponentType.TERMS_DATA, IndexComponentType.POSTING_LISTS); - private static final Set NUMERIC_COMPONENTS = EnumSet.of(IndexComponentType.COLUMN_COMPLETION_MARKER, - IndexComponentType.META, - IndexComponentType.KD_TREE, - IndexComponentType.KD_TREE_POSTING_LISTS); + + public static final Set NUMERIC_COMPONENTS = EnumSet.of(IndexComponentType.COLUMN_COMPLETION_MARKER, + IndexComponentType.META, + IndexComponentType.KD_TREE, + IndexComponentType.KD_TREE_POSTING_LISTS); /** * Global limit on heap consumed by all index segment building that occurs outside the context of Memtable flush. diff --git a/src/java/org/apache/cassandra/index/sai/disk/v1/postings/IntersectingPostingList.java b/src/java/org/apache/cassandra/index/sai/disk/v1/postings/IntersectingPostingList.java new file mode 100644 index 000000000000..295796c9066f --- /dev/null +++ b/src/java/org/apache/cassandra/index/sai/disk/v1/postings/IntersectingPostingList.java @@ -0,0 +1,141 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.cassandra.index.sai.disk.v1.postings; + +import java.io.IOException; +import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import javax.annotation.concurrent.NotThreadSafe; + +import org.apache.cassandra.index.sai.disk.PostingList; +import org.apache.cassandra.io.util.FileUtils; + +/** + * Performs intersection operations on multiple PostingLists, returning only postings + * that appear in all inputs. 
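+ * For example, intersecting the postings [0, 2, 4] and [2, 4, 6] yields [2, 4]. Note + * that size() returns the smallest input's size, which is only an upper bound on the + * size of the intersection, and that frequencies() exposes each input's frequency at + * the current intersection row.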
+ */ +@NotThreadSafe +public class IntersectingPostingList implements PostingList +{ + private final Map postingsByTerm; + private final List postingLists; // so we can access by ordinal in intersection code + private final int size; + + private IntersectingPostingList(Map postingsByTerm) + { + if (postingsByTerm.isEmpty()) + throw new AssertionError(); + this.postingsByTerm = postingsByTerm; + this.postingLists = new ArrayList<>(postingsByTerm.values()); + this.size = postingLists.stream() + .mapToInt(PostingList::size) + .min() + .orElse(0); + } + + /** + * @return the intersection of the provided term-posting list mappings + */ + public static IntersectingPostingList intersect(Map postingsByTerm) + { + // TODO optimize cases where + // - we have a single postinglist + // - any posting list is empty (intersection also empty) + return new IntersectingPostingList(postingsByTerm); + } + + @Override + public int nextPosting() throws IOException + { + return findNextIntersection(Integer.MIN_VALUE, false); + } + + @Override + public int advance(int targetRowID) throws IOException + { + assert targetRowID >= 0 : targetRowID; + return findNextIntersection(targetRowID, true); + } + + @Override + public int frequency() + { + // call frequencies() instead + throw new UnsupportedOperationException(); + } + + public Map frequencies() + { + Map result = new HashMap<>(); + for (Map.Entry entry : postingsByTerm.entrySet()) + result.put(entry.getKey(), entry.getValue().frequency()); + return result; + } + + private int findNextIntersection(int targetRowID, boolean isAdvance) throws IOException + { + int maxRowId = targetRowID; + int maxRowIdIndex = -1; + + // Scan through all posting lists looking for a common row ID + for (int i = 0; i < postingLists.size(); i++) + { + // don't advance the sublist in which we found our current max + if (i == maxRowIdIndex) + continue; + + // Advance this sublist to the current max, special casing the first one as needed + PostingList list = postingLists.get(i); + int rowId = (isAdvance || maxRowIdIndex >= 0) + ? 
list.advance(maxRowId) + : list.nextPosting(); + if (rowId == END_OF_STREAM) + return END_OF_STREAM; + + // Update maxRowId + index if we find a larger value, or this was the first sublist evaluated + if (rowId > maxRowId || maxRowIdIndex < 0) + { + maxRowId = rowId; + maxRowIdIndex = i; + i = -1; // restart the scan with new maxRowId + } + } + + // Once we complete a full scan without finding a larger rowId, we've found an intersection + return maxRowId; + } + + @Override + public int size() + { + return size; + } + + @Override + public void close() + { + for (PostingList list : postingLists) + FileUtils.closeQuietly(list); + } +} + + diff --git a/src/java/org/apache/cassandra/index/sai/disk/v1/postings/PostingsReader.java b/src/java/org/apache/cassandra/index/sai/disk/v1/postings/PostingsReader.java index ed3556c4ffe5..0ee2ab63b132 100644 --- a/src/java/org/apache/cassandra/index/sai/disk/v1/postings/PostingsReader.java +++ b/src/java/org/apache/cassandra/index/sai/disk/v1/postings/PostingsReader.java @@ -51,7 +51,7 @@ public class PostingsReader implements OrdinalPostingList protected final IndexInput input; protected final InputCloser runOnClose; - private final int blockSize; + private final int blockEntries; private final int numPostings; private final LongArray blockOffsets; private final LongArray blockMaxValues; @@ -69,6 +69,8 @@ public class PostingsReader implements OrdinalPostingList private long currentPosition; private LongValues currentFORValues; private int postingsDecoded = 0; + private int currentFrequency = Integer.MIN_VALUE; + private final boolean readFrequencies; @VisibleForTesting public PostingsReader(IndexInput input, long summaryOffset, QueryEventListener.PostingListEventListener listener) throws IOException @@ -78,7 +80,7 @@ public PostingsReader(IndexInput input, long summaryOffset, QueryEventListener.P public PostingsReader(IndexInput input, BlocksSummary summary, QueryEventListener.PostingListEventListener listener) throws IOException { - this(input, summary, listener, () -> { + this(input, summary, false, listener, () -> { try { input.close(); @@ -90,14 +92,29 @@ public PostingsReader(IndexInput input, BlocksSummary summary, QueryEventListene }); } - public PostingsReader(IndexInput input, BlocksSummary summary, QueryEventListener.PostingListEventListener listener, InputCloser runOnClose) throws IOException + public PostingsReader(IndexInput input, BlocksSummary summary, boolean readFrequencies, QueryEventListener.PostingListEventListener listener) throws IOException + { + this(input, summary, readFrequencies, listener, () -> { + try + { + input.close(); + } + finally + { + summary.close(); + } + }); + } + + public PostingsReader(IndexInput input, BlocksSummary summary, boolean readFrequencies, QueryEventListener.PostingListEventListener listener, InputCloser runOnClose) throws IOException { assert input instanceof IndexInputReader; logger.trace("Opening postings reader for {}", input); + this.readFrequencies = readFrequencies; this.input = input; this.seekingInput = new SeekingRandomAccessInput(input); this.blockOffsets = summary.offsets; - this.blockSize = summary.blockSize; + this.blockEntries = summary.blockEntries; this.numPostings = summary.numPostings; this.blockMaxValues = summary.maxValues; this.listener = listener; @@ -122,7 +139,7 @@ public interface InputCloser public static class BlocksSummary { - final int blockSize; + final int blockEntries; final int numPostings; final LongArray offsets; final LongArray maxValues; @@ -139,7 +156,7 @@ public 
BlocksSummary(IndexInput input, long offset, InputCloser runOnClose) thro this.runOnClose = runOnClose; input.seek(offset); - this.blockSize = input.readVInt(); + this.blockEntries = input.readVInt(); // This is the count of row ids in a single posting list. For now, a segment cannot have more than // Integer.MAX_VALUE row ids, so it is safe to use an int here. this.numPostings = input.readVInt(); @@ -323,10 +340,10 @@ private void lastPosInBlock(int block) // blockMaxValues is integer only actualSegmentRowId = Math.toIntExact(blockMaxValues.get(block)); //upper bound, since we might've advanced to the last block, but upper bound is enough - totalPostingsRead += (blockSize - blockIdx) + (block - postingsBlockIdx + 1) * blockSize; + totalPostingsRead += (blockEntries - blockIdx) + (block - postingsBlockIdx + 1) * blockEntries; postingsBlockIdx = block + 1; - blockIdx = blockSize; + blockIdx = blockEntries; } @Override @@ -341,9 +358,9 @@ public int nextPosting() throws IOException } @VisibleForTesting - int getBlockSize() + int getBlockEntries() { - return blockSize; + return blockEntries; } private int peekNext() throws IOException @@ -352,27 +369,28 @@ private int peekNext() throws IOException { return END_OF_STREAM; } - if (blockIdx == blockSize) + if (blockIdx == blockEntries) { reBuffer(); } - return actualSegmentRowId + nextRowID(); + return actualSegmentRowId + nextRowDelta(); } - private int nextRowID() + private int nextRowDelta() { - // currentFORValues is null when the all the values in the block are the same if (currentFORValues == null) { + currentFrequency = Integer.MIN_VALUE; return 0; } - else - { - final long id = currentFORValues.get(blockIdx); - postingsDecoded++; - return Math.toIntExact(id); - } + + long offset = readFrequencies ? 2L * blockIdx : blockIdx; + long id = currentFORValues.get(offset); + if (readFrequencies) + currentFrequency = Math.toIntExact(currentFORValues.get(offset + 1)); + postingsDecoded++; + return Math.toIntExact(id); } private void advanceOnePosition(int nextRowID) @@ -420,4 +438,9 @@ else if (bitsPerValue > 64) } currentFORValues = LuceneCompat.directReaderGetInstance(seekingInput, bitsPerValue, currentPosition); } + + @Override + public int frequency() { + return currentFrequency; + } } diff --git a/src/java/org/apache/cassandra/index/sai/disk/v1/postings/PostingsWriter.java b/src/java/org/apache/cassandra/index/sai/disk/v1/postings/PostingsWriter.java index ef027e1e2a39..5bf0decfb8e4 100644 --- a/src/java/org/apache/cassandra/index/sai/disk/v1/postings/PostingsWriter.java +++ b/src/java/org/apache/cassandra/index/sai/disk/v1/postings/PostingsWriter.java @@ -27,7 +27,6 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import org.agrona.collections.IntArrayList; import org.agrona.collections.LongArrayList; import org.apache.cassandra.index.sai.disk.PostingList; import org.apache.cassandra.index.sai.disk.format.IndexComponentType; @@ -42,6 +41,7 @@ import static com.google.common.base.Preconditions.checkArgument; import static java.lang.Math.max; +import static java.lang.Math.min; /** @@ -65,7 +65,7 @@ *
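+ * When frequencies are enabled, each postings block stores interleaved [delta][frequency] + * pairs instead of bare row id deltas; frequencies are capped at 255 (see freqBuffer). + *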

* Each posting list ends with a meta section and a skip table, that are written right after all postings blocks. Skip * interval is the same as block size, and each skip entry points to the end of each block. Skip table consist of - * block offsets and last values of each block, compressed as two FoR blocks. + * block offsets and maximum rowids of each block, compressed as two FoR blocks. *
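+ * As a worked example: row ids [1, 4, 6] with frequencies [2, 1, 3] are buffered as the + * deltas [1, 3, 2] and, with frequencies enabled, written as the single interleaved + * sequence [1, 2, 3, 1, 2, 3], using the widest bit width needed by either value type. + *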

* * Visual representation of the disk format: @@ -80,7 +80,7 @@ * | LIST SIZE | SKIP TABLE | * +---------------+------------+ * | BLOCKS POS.| - * | MAX VALUES | + * | MAX ROWIDS | * +------------+ * * @@ -91,13 +91,14 @@ public class PostingsWriter implements Closeable protected static final Logger logger = LoggerFactory.getLogger(PostingsWriter.class); // import static org.apache.lucene.codecs.lucene50.Lucene50PostingsFormat.BLOCK_SIZE; - private final static int BLOCK_SIZE = 128; + private final static int BLOCK_ENTRIES = 128; private static final String POSTINGS_MUST_BE_SORTED_ERROR_MSG = "Postings must be sorted ascending, got [%s] after [%s]"; private final IndexOutput dataOutput; - private final int blockSize; + private final int blockEntries; private final long[] deltaBuffer; + private final int[] freqBuffer; // frequency is capped at 255 private final LongArrayList blockOffsets = new LongArrayList(); private final LongArrayList blockMaxIDs = new LongArrayList(); private final ResettableByteBuffersIndexOutput inMemoryOutput; @@ -106,35 +107,43 @@ public class PostingsWriter implements Closeable private int bufferUpto; private long lastSegmentRowId; - private long maxDelta; // This number is the count of row ids written to the postings for this segment. Because a segment row id can be in // multiple postings list for the segment, this number could exceed Integer.MAX_VALUE, so we use a long. private long totalPostings; + private final boolean writeFrequencies; public PostingsWriter(IndexComponents.ForWrite components) throws IOException { - this(components, BLOCK_SIZE); + this(components, BLOCK_ENTRIES); } + public PostingsWriter(IndexComponents.ForWrite components, boolean writeFrequencies) throws IOException + { + this(components.addOrGet(IndexComponentType.POSTING_LISTS).openOutput(true), BLOCK_ENTRIES, writeFrequencies); + } + + public PostingsWriter(IndexOutput dataOutput) throws IOException { - this(dataOutput, BLOCK_SIZE); + this(dataOutput, BLOCK_ENTRIES, false); } @VisibleForTesting - PostingsWriter(IndexComponents.ForWrite components, int blockSize) throws IOException + PostingsWriter(IndexComponents.ForWrite components, int blockEntries) throws IOException { - this(components.addOrGet(IndexComponentType.POSTING_LISTS).openOutput(true), blockSize); + this(components.addOrGet(IndexComponentType.POSTING_LISTS).openOutput(true), blockEntries, false); } - private PostingsWriter(IndexOutput dataOutput, int blockSize) throws IOException + private PostingsWriter(IndexOutput dataOutput, int blockEntries, boolean writeFrequencies) throws IOException { assert dataOutput instanceof IndexOutputWriter; logger.debug("Creating postings writer for output {}", dataOutput); - this.blockSize = blockSize; + this.writeFrequencies = writeFrequencies; + this.blockEntries = blockEntries; this.dataOutput = dataOutput; startOffset = dataOutput.getFilePointer(); - deltaBuffer = new long[blockSize]; + deltaBuffer = new long[blockEntries]; + freqBuffer = new int[blockEntries]; inMemoryOutput = LuceneCompat.getResettableByteBuffersIndexOutput(dataOutput.order(), 1024, "blockOffsets"); SAICodecUtils.writeHeader(dataOutput); } @@ -191,7 +200,7 @@ public long write(PostingList postings) throws IOException int size = 0; while ((segmentRowId = postings.nextPosting()) != PostingList.END_OF_STREAM) { - writePosting(segmentRowId); + writePosting(segmentRowId, postings.frequency()); size++; totalPostings++; } @@ -210,19 +219,19 @@ public long getTotalPostings() return totalPostings; } - private void 
writePosting(long segmentRowId) throws IOException - { + private void writePosting(long segmentRowId, int freq) throws IOException + { if (!(segmentRowId >= lastSegmentRowId || lastSegmentRowId == 0)) throw new IllegalArgumentException(String.format(POSTINGS_MUST_BE_SORTED_ERROR_MSG, segmentRowId, lastSegmentRowId)); + assert freq > 0; final long delta = segmentRowId - lastSegmentRowId; - maxDelta = max(maxDelta, delta); - deltaBuffer[bufferUpto++] = delta; + deltaBuffer[bufferUpto] = delta; + freqBuffer[bufferUpto] = min(freq, 255); + bufferUpto++; - if (bufferUpto == blockSize) - { + if (bufferUpto == blockEntries) + { addBlockToSkipTable(segmentRowId); - writePostingsBlock(maxDelta, bufferUpto); + writePostingsBlock(bufferUpto); resetBlockCounters(); } lastSegmentRowId = segmentRowId; @@ -234,7 +243,7 @@ private void finish() throws IOException { addBlockToSkipTable(lastSegmentRowId); - writePostingsBlock(maxDelta, bufferUpto); + writePostingsBlock(bufferUpto); } } @@ -242,7 +251,6 @@ private void resetBlockCounters() { bufferUpto = 0; lastSegmentRowId = 0; - maxDelta = 0; } private void addBlockToSkipTable(long maxSegmentRowID) @@ -253,7 +261,7 @@ private void addBlockToSkipTable(long maxSegmentRowID) private void writeSummary(int exactSize) throws IOException { - dataOutput.writeVInt(blockSize); + dataOutput.writeVInt(blockEntries); dataOutput.writeVInt(exactSize); writeSkipTable(); } @@ -272,19 +280,29 @@ private void writeSkipTable() throws IOException writeSortedFoRBlock(blockMaxIDs, dataOutput); } - private void writePostingsBlock(long maxValue, int blockSize) throws IOException - { + private void writePostingsBlock(int entries) throws IOException + { + // Find max value to determine bits needed + long maxValue = 0; + for (int i = 0; i < entries; i++) + { + maxValue = max(maxValue, deltaBuffer[i]); + if (writeFrequencies) + maxValue = max(maxValue, freqBuffer[i]); + } + + // Use the maximum bits needed for either value type final int bitsPerValue = maxValue == 0 ? 0 : LuceneCompat.directWriterUnsignedBitsRequired(dataOutput.order(), maxValue); - - assert bitsPerValue < Byte.MAX_VALUE; - + dataOutput.writeByte((byte) bitsPerValue); - if (bitsPerValue > 0) - { - final DirectWriterAdapter writer = LuceneCompat.directWriterGetInstance(dataOutput.order(), dataOutput, blockSize, bitsPerValue); - for (int i = 0; i < blockSize; ++i) - { + if (bitsPerValue > 0) + { + // Write interleaved [delta][freq] pairs + final DirectWriterAdapter writer = LuceneCompat.directWriterGetInstance(dataOutput.order(), + dataOutput, + writeFrequencies ? entries * 2L : entries, + bitsPerValue); + for (int i = 0; i < entries; ++i) + { writer.add(deltaBuffer[i]); + if (writeFrequencies) + writer.add(freqBuffer[i]); } writer.finish(); } @@ -292,9 +310,9 @@ private void writePostingsBlock(long maxValue, int blockSize) throws IOException private void writeSortedFoRBlock(LongArrayList values, IndexOutput output) throws IOException { + assert !values.isEmpty(); final long maxValue = values.getLong(values.size() - 1); - assert values.size() > 0; final int bitsPerValue = maxValue == 0 ? 
0 : LuceneCompat.directWriterUnsignedBitsRequired(output.order(), maxValue); output.writeByte((byte) bitsPerValue); if (bitsPerValue > 0) @@ -307,22 +325,4 @@ private void writeSortedFoRBlock(LongArrayList values, IndexOutput output) throw writer.finish(); } } - - private void writeSortedFoRBlock(IntArrayList values, IndexOutput output) throws IOException - { - final int maxValue = values.getInt(values.size() - 1); - - assert values.size() > 0; - final int bitsPerValue = maxValue == 0 ? 0 : LuceneCompat.directWriterUnsignedBitsRequired(output.order(), maxValue); - output.writeByte((byte) bitsPerValue); - if (bitsPerValue > 0) - { - final DirectWriterAdapter writer = LuceneCompat.directWriterGetInstance(output.order(), output, values.size(), bitsPerValue); - for (int i = 0; i < values.size(); ++i) - { - writer.add(values.getInt(i)); - } - writer.finish(); - } - } } diff --git a/src/java/org/apache/cassandra/index/sai/disk/v1/postings/VectorPostingList.java b/src/java/org/apache/cassandra/index/sai/disk/v1/postings/ReorderingPostingList.java similarity index 83% rename from src/java/org/apache/cassandra/index/sai/disk/v1/postings/VectorPostingList.java rename to src/java/org/apache/cassandra/index/sai/disk/v1/postings/ReorderingPostingList.java index 52155bf6ed59..f43c7c4e8dce 100644 --- a/src/java/org/apache/cassandra/index/sai/disk/v1/postings/VectorPostingList.java +++ b/src/java/org/apache/cassandra/index/sai/disk/v1/postings/ReorderingPostingList.java @@ -19,31 +19,29 @@ package org.apache.cassandra.index.sai.disk.v1.postings; import java.io.IOException; +import java.util.function.ToIntFunction; import org.apache.cassandra.index.sai.disk.PostingList; -import org.apache.cassandra.index.sai.utils.RowIdWithMeta; import org.apache.cassandra.utils.CloseableIterator; import org.apache.lucene.util.LongHeap; -/** - * A posting list for ANN search results. Transforms results from similarity order to rowId order. - */ -public class VectorPostingList implements PostingList +/** + * A posting list that reorders results from their source order (e.g. ANN similarity order) into rowId order. + */ +public class ReorderingPostingList implements PostingList { private final LongHeap segmentRowIds; private final int size; - public VectorPostingList(CloseableIterator source) + public ReorderingPostingList(CloseableIterator source, ToIntFunction rowIdTransformer) { - // TODO find int specific data structure? segmentRowIds = new LongHeap(32); int n = 0; - // Once the source is consumed, we have to close it. 
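+ // LongHeap is a min-heap, so pushing every transformed row id here lets the + // postings be consumed in ascending rowId order once the source is drained.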
try (source) { while (source.hasNext()) { - segmentRowIds.push(source.next().getSegmentRowId()); + segmentRowIds.push(rowIdTransformer.applyAsInt(source.next())); n++; } } diff --git a/src/java/org/apache/cassandra/index/sai/disk/v1/postings/ScanningPostingsReader.java b/src/java/org/apache/cassandra/index/sai/disk/v1/postings/ScanningPostingsReader.java index 2fd7c2336ab2..5c9da6169e08 100644 --- a/src/java/org/apache/cassandra/index/sai/disk/v1/postings/ScanningPostingsReader.java +++ b/src/java/org/apache/cassandra/index/sai/disk/v1/postings/ScanningPostingsReader.java @@ -31,9 +31,9 @@ */ public class ScanningPostingsReader extends PostingsReader { - public ScanningPostingsReader(IndexInput input, BlocksSummary summary) throws IOException + public ScanningPostingsReader(IndexInput input, BlocksSummary summary, boolean readFrequencies) throws IOException { - super(input, summary, QueryEventListener.PostingListEventListener.NO_OP, InputCloser.NOOP); + super(input, summary, readFrequencies, QueryEventListener.PostingListEventListener.NO_OP, InputCloser.NOOP); } @Override diff --git a/src/java/org/apache/cassandra/index/sai/disk/v1/trie/DocLengthsWriter.java b/src/java/org/apache/cassandra/index/sai/disk/v1/trie/DocLengthsWriter.java new file mode 100644 index 000000000000..90dc6b5c6ff9 --- /dev/null +++ b/src/java/org/apache/cassandra/index/sai/disk/v1/trie/DocLengthsWriter.java @@ -0,0 +1,73 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.cassandra.index.sai.disk.v1.trie; + +import java.io.Closeable; +import java.io.IOException; + +import org.agrona.collections.Int2IntHashMap; +import org.apache.cassandra.index.sai.disk.format.IndexComponents; +import org.apache.cassandra.index.sai.disk.format.IndexComponentType; +import org.apache.cassandra.index.sai.disk.io.IndexOutputWriter; +import org.apache.cassandra.index.sai.utils.SAICodecUtils; + +/** + * Writes document length information to disk for use in text scoring + */ +public class DocLengthsWriter implements Closeable +{ + private final IndexOutputWriter output; + + public DocLengthsWriter(IndexComponents.ForWrite components) throws IOException + { + this.output = components.addOrGet(IndexComponentType.DOC_LENGTHS).openOutput(true); + SAICodecUtils.writeHeader(output); + } + + public void writeDocLengths(Int2IntHashMap lengths) throws IOException + { + // Calculate max row ID from doc lengths map + int maxRowId = -1; + for (var keyIterator = lengths.keySet().iterator(); keyIterator.hasNext(); ) + { + int key = keyIterator.nextValue(); + if (key > maxRowId) + maxRowId = key; + } + + // write out the doc lengths in row order + for (int rowId = 0; rowId <= maxRowId; rowId++) + { + final int length = lengths.get(rowId); + output.writeInt(length == lengths.missingValue() ? 0 : length); + } + + SAICodecUtils.writeFooter(output); + } + + public long getFilePointer() + { + return output.getFilePointer(); + } + + @Override + public void close() throws IOException + { + output.close(); + } +} diff --git a/src/java/org/apache/cassandra/index/sai/disk/v1/trie/InvertedIndexWriter.java b/src/java/org/apache/cassandra/index/sai/disk/v1/trie/InvertedIndexWriter.java index e6adb61ce863..a61df65f2239 100644 --- a/src/java/org/apache/cassandra/index/sai/disk/v1/trie/InvertedIndexWriter.java +++ b/src/java/org/apache/cassandra/index/sai/disk/v1/trie/InvertedIndexWriter.java @@ -25,10 +25,12 @@ import org.apache.commons.lang3.mutable.MutableLong; +import org.agrona.collections.Int2IntHashMap; import org.apache.cassandra.index.sai.disk.PostingList; import org.apache.cassandra.index.sai.disk.TermsIterator; import org.apache.cassandra.index.sai.disk.format.IndexComponents; import org.apache.cassandra.index.sai.disk.format.IndexComponentType; +import org.apache.cassandra.index.sai.disk.format.Version; import org.apache.cassandra.index.sai.disk.v1.SegmentMetadata; import org.apache.cassandra.index.sai.disk.v1.postings.PostingsWriter; import org.apache.cassandra.index.sai.utils.SAICodecUtils; @@ -42,12 +44,19 @@ public class InvertedIndexWriter implements Closeable { private final TrieTermsDictionaryWriter termsDictionaryWriter; private final PostingsWriter postingsWriter; + private final DocLengthsWriter docLengthsWriter; private long postingsAdded; public InvertedIndexWriter(IndexComponents.ForWrite components) throws IOException + { + this(components, false); + } + + public InvertedIndexWriter(IndexComponents.ForWrite components, boolean writeFrequencies) throws IOException { this.termsDictionaryWriter = new TrieTermsDictionaryWriter(components); - this.postingsWriter = new PostingsWriter(components); + this.postingsWriter = new PostingsWriter(components, writeFrequencies); + this.docLengthsWriter = Version.latest().onOrAfter(Version.EC) ? 
new DocLengthsWriter(components) : null; } /** @@ -58,7 +67,7 @@ public InvertedIndexWriter(IndexComponents.ForWrite components) throws IOExcepti * @return metadata describing the location of this inverted index in the overall SSTable * terms and postings component files */ - public SegmentMetadata.ComponentMetadataMap writeAll(TermsIterator terms) throws IOException + public SegmentMetadata.ComponentMetadataMap writeAll(TermsIterator terms, Int2IntHashMap docLengths) throws IOException { // Terms and postings writers are opened in append mode with pointers at the end of their respective files. long termsOffset = termsDictionaryWriter.getStartOffset(); @@ -91,6 +100,15 @@ public SegmentMetadata.ComponentMetadataMap writeAll(TermsIterator terms) throws components.put(IndexComponentType.POSTING_LISTS, -1, postingsOffset, postingsLength); components.put(IndexComponentType.TERMS_DATA, termsRoot, termsOffset, termsLength, map); + // Write doc lengths + if (docLengthsWriter != null) + { + long docLengthsOffset = docLengthsWriter.getFilePointer(); + docLengthsWriter.writeDocLengths(docLengths); + long docLengthsLength = docLengthsWriter.getFilePointer() - docLengthsOffset; + components.put(IndexComponentType.DOC_LENGTHS, -1, docLengthsOffset, docLengthsLength); + } + return components; } @@ -99,6 +117,8 @@ public void close() throws IOException { postingsWriter.close(); termsDictionaryWriter.close(); + if (docLengthsWriter != null) + docLengthsWriter.close(); } /** diff --git a/src/java/org/apache/cassandra/index/sai/disk/v2/RowAwarePrimaryKeyFactory.java b/src/java/org/apache/cassandra/index/sai/disk/v2/RowAwarePrimaryKeyFactory.java index e51a6229de58..cba1e1124ce3 100644 --- a/src/java/org/apache/cassandra/index/sai/disk/v2/RowAwarePrimaryKeyFactory.java +++ b/src/java/org/apache/cassandra/index/sai/disk/v2/RowAwarePrimaryKeyFactory.java @@ -23,6 +23,7 @@ import java.util.function.Supplier; import java.util.stream.Collectors; +import io.github.jbellis.jvector.util.RamUsageEstimator; import org.apache.cassandra.db.Clustering; import org.apache.cassandra.db.ClusteringComparator; import org.apache.cassandra.db.DecoratedKey; @@ -212,5 +213,22 @@ public String toString() .map(ByteBufferUtil::bytesToHex) .collect(Collectors.toList()))); } + + @Override + public long ramBytesUsed() + { + // Object header + 4 references (token, partitionKey, clustering, primaryKeySupplier) + implicit outer reference + long size = RamUsageEstimator.NUM_BYTES_OBJECT_HEADER + + 5L * RamUsageEstimator.NUM_BYTES_OBJECT_REF; + + if (token != null) + size += token.getHeapSize(); + if (partitionKey != null) + size += RamUsageEstimator.NUM_BYTES_OBJECT_HEADER + + 2L * RamUsageEstimator.NUM_BYTES_OBJECT_REF + // token and key references + 2L * Long.BYTES; + // We don't count clustering size here as it's managed elsewhere + return size; + } } } diff --git a/src/java/org/apache/cassandra/index/sai/disk/v2/V2VectorIndexSearcher.java b/src/java/org/apache/cassandra/index/sai/disk/v2/V2VectorIndexSearcher.java index 80870dab36bf..7da1097e64ae 100644 --- a/src/java/org/apache/cassandra/index/sai/disk/v2/V2VectorIndexSearcher.java +++ b/src/java/org/apache/cassandra/index/sai/disk/v2/V2VectorIndexSearcher.java @@ -49,7 +49,7 @@ import org.apache.cassandra.index.sai.disk.v1.IndexSearcher; import org.apache.cassandra.index.sai.disk.v1.PerIndexFiles; import org.apache.cassandra.index.sai.disk.v1.SegmentMetadata; -import org.apache.cassandra.index.sai.disk.v1.postings.VectorPostingList; +import 
org.apache.cassandra.index.sai.disk.v1.postings.ReorderingPostingList; import org.apache.cassandra.index.sai.disk.v5.V5VectorPostingsWriter; import org.apache.cassandra.index.sai.disk.vector.BruteForceRowIdIterator; import org.apache.cassandra.index.sai.disk.vector.CassandraDiskAnn; @@ -64,8 +64,8 @@ import org.apache.cassandra.index.sai.utils.PrimaryKey; import org.apache.cassandra.index.sai.utils.PrimaryKeyWithSortKey; import org.apache.cassandra.index.sai.utils.RangeUtil; +import org.apache.cassandra.index.sai.utils.RowIdWithMeta; import org.apache.cassandra.index.sai.utils.RowIdWithScore; -import org.apache.cassandra.index.sai.utils.SegmentOrdering; import org.apache.cassandra.io.sstable.format.SSTableReader; import org.apache.cassandra.metrics.LinearFit; import org.apache.cassandra.metrics.PairedSlidingWindowReservoir; @@ -81,7 +81,7 @@ /** * Executes ann search against the graph for an individual index segment. */ -public class V2VectorIndexSearcher extends IndexSearcher implements SegmentOrdering +public class V2VectorIndexSearcher extends IndexSearcher { private static final Logger logger = LoggerFactory.getLogger(MethodHandles.lookup().lookupClass()); private static final VectorTypeSupport vts = VectorizationProvider.getInstance().getVectorTypeSupport(); @@ -133,13 +133,13 @@ public ProductQuantization getPQ() } @Override - public KeyRangeIterator search(Expression exp, AbstractBounds keyRange, QueryContext context, boolean defer, int limit) throws IOException + public KeyRangeIterator search(Expression exp, AbstractBounds keyRange, QueryContext context, boolean defer) throws IOException { - PostingList results = searchPosting(context, exp, keyRange, limit); + PostingList results = searchPosting(context, exp, keyRange); return toPrimaryKeyIterator(results, context); } - private PostingList searchPosting(QueryContext context, Expression exp, AbstractBounds keyRange, int limit) throws IOException + private PostingList searchPosting(QueryContext context, Expression exp, AbstractBounds keyRange) throws IOException { if (logger.isTraceEnabled()) logger.trace(indexContext.logMessage("Searching on expression '{}'..."), exp); @@ -151,7 +151,7 @@ private PostingList searchPosting(QueryContext context, Expression exp, Abstract // this is a thresholded query, so pass graph.size() as top k to get all results satisfying the threshold var result = searchInternal(keyRange, context, queryVector, graph.size(), graph.size(), exp.getEuclideanSearchThreshold()); - return new VectorPostingList(result); + return new ReorderingPostingList(result, RowIdWithMeta::getSegmentRowId); } @Override @@ -160,11 +160,11 @@ public CloseableIterator orderBy(Orderer orderer, Express if (logger.isTraceEnabled()) logger.trace(indexContext.logMessage("Searching on expression '{}'..."), orderer); - if (orderer.vector == null) + if (!orderer.isANN()) throw new IllegalArgumentException(indexContext.logMessage("Unsupported expression during ANN index query: " + orderer)); int rerankK = indexContext.getIndexWriterConfig().getSourceModel().rerankKFor(limit, graph.getCompression()); - var queryVector = vts.createFloatVector(orderer.vector); + var queryVector = vts.createFloatVector(orderer.getVectorTerm()); var result = searchInternal(keyRange, context, queryVector, limit, rerankK, 0); return toMetaSortedIterator(result, context); @@ -485,14 +485,14 @@ public CloseableIterator orderResultsBy(SSTableReader rea if (cost.shouldUseBruteForce()) { // brute force using the in-memory compressed vectors to cut down the number of 
results returned - var queryVector = vts.createFloatVector(orderer.vector); + var queryVector = vts.createFloatVector(orderer.getVectorTerm()); return toMetaSortedIterator(this.orderByBruteForce(queryVector, segmentOrdinalPairs, limit, rerankK), context); } // Create bits from the mapping var bits = bitSetForSearch(); segmentOrdinalPairs.forEachRightInt(bits::set); // else ask the index to perform a search limited to the bits we created - var queryVector = vts.createFloatVector(orderer.vector); + var queryVector = vts.createFloatVector(orderer.getVectorTerm()); var results = graph.search(queryVector, limit, rerankK, 0, bits, context, cost::updateStatistics); return toMetaSortedIterator(results, context); } diff --git a/src/java/org/apache/cassandra/index/sai/disk/v4/V4InvertedIndexSearcher.java b/src/java/org/apache/cassandra/index/sai/disk/v4/V4InvertedIndexSearcher.java index ac11b238844e..f819b6e1ee6f 100644 --- a/src/java/org/apache/cassandra/index/sai/disk/v4/V4InvertedIndexSearcher.java +++ b/src/java/org/apache/cassandra/index/sai/disk/v4/V4InvertedIndexSearcher.java @@ -37,6 +37,6 @@ class V4InvertedIndexSearcher extends InvertedIndexSearcher SegmentMetadata segmentMetadata, IndexContext indexContext) throws IOException { - super(sstableContext, perIndexFiles, segmentMetadata, indexContext, Version.DB, false); + super(sstableContext, perIndexFiles, segmentMetadata, indexContext, segmentMetadata.version, false); } } diff --git a/src/java/org/apache/cassandra/index/sai/disk/v7/V7OnDiskFormat.java b/src/java/org/apache/cassandra/index/sai/disk/v7/V7OnDiskFormat.java new file mode 100644 index 000000000000..5457685cf1b0 --- /dev/null +++ b/src/java/org/apache/cassandra/index/sai/disk/v7/V7OnDiskFormat.java @@ -0,0 +1,50 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.cassandra.index.sai.disk.v7; + +import java.util.EnumSet; +import java.util.Set; + +import org.apache.cassandra.db.marshal.AbstractType; +import org.apache.cassandra.index.sai.disk.format.IndexComponentType; +import org.apache.cassandra.index.sai.disk.v1.V1OnDiskFormat; +import org.apache.cassandra.index.sai.disk.v3.V3OnDiskFormat; +import org.apache.cassandra.index.sai.disk.v6.V6OnDiskFormat; +import org.apache.cassandra.index.sai.utils.TypeUtil; + +public class V7OnDiskFormat extends V6OnDiskFormat +{ + public static final V7OnDiskFormat instance = new V7OnDiskFormat(); + + private static final Set LITERAL_COMPONENTS = EnumSet.of(IndexComponentType.COLUMN_COMPLETION_MARKER, + IndexComponentType.META, + IndexComponentType.TERMS_DATA, + IndexComponentType.POSTING_LISTS, + IndexComponentType.DOC_LENGTHS); + + @Override + public Set perIndexComponentTypes(AbstractType validator) + { + if (validator.isVector()) + return V3OnDiskFormat.VECTOR_COMPONENTS_V3; + if (TypeUtil.isLiteral(validator)) + return LITERAL_COMPONENTS; + return V1OnDiskFormat.NUMERIC_COMPONENTS; + } +} diff --git a/src/java/org/apache/cassandra/index/sai/disk/vector/VectorMemtableIndex.java b/src/java/org/apache/cassandra/index/sai/disk/vector/VectorMemtableIndex.java index 25231748149f..a6a52c462f3b 100644 --- a/src/java/org/apache/cassandra/index/sai/disk/vector/VectorMemtableIndex.java +++ b/src/java/org/apache/cassandra/index/sai/disk/vector/VectorMemtableIndex.java @@ -54,6 +54,7 @@ import org.apache.cassandra.index.sai.disk.format.IndexComponents; import org.apache.cassandra.index.sai.disk.v1.SegmentMetadata; import org.apache.cassandra.index.sai.iterators.KeyRangeIterator; +import org.apache.cassandra.index.sai.memory.MemoryIndex; import org.apache.cassandra.index.sai.memory.MemtableIndex; import org.apache.cassandra.index.sai.plan.Expression; import org.apache.cassandra.index.sai.plan.Orderer; @@ -222,7 +223,7 @@ public List> orderBy(QueryContext conte assert slice == null : "ANN does not support index slicing"; assert orderer.isANN() : "Only ANN is supported for vector search, received " + orderer.operator; - var qv = vts.createFloatVector(orderer.vector); + var qv = vts.createFloatVector(orderer.getVectorTerm()); return List.of(searchInternal(context, qv, keyRange, limit, 0)); } @@ -310,7 +311,7 @@ public CloseableIterator orderResultsBy(QueryContext cont relevantOrdinals.size(), keys.size(), maxBruteForceRows, graph.size(), limit); // convert the expression value to query vector - var qv = vts.createFloatVector(orderer.vector); + var qv = vts.createFloatVector(orderer.getVectorTerm()); // brute force path if (keysInGraph.size() <= maxBruteForceRows) { @@ -422,7 +423,7 @@ public static int ensureSaneEstimate(int rawEstimate, int rerankK, int graphSize } @Override - public Iterator>> iterator(DecoratedKey min, DecoratedKey max) + public Iterator>> iterator(DecoratedKey min, DecoratedKey max) { // This method is only used when merging an in-memory index with a RowMapping. This is done a different // way with the graph using the writeData method below. 
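The per-term frequencies written to the postings (PostingsReader.frequency()), the per-row token counts in the new DOC_LENGTHS component, and the Ordering.Raw.Bm25 grammar hook are the inputs to BM25 ranking. A minimal sketch of how a per-term score would combine them (names are illustrative only; this helper is not part of the patch):

    // Standard BM25 per-term contribution; k1 and b are the usual free parameters.
    static double bm25TermScore(int termFreq,        // term frequency in this row (postings)
                                int docLength,       // token count of this row (DOC_LENGTHS)
                                double avgDocLength, // mean token count across the segment
                                long docCount,       // number of rows in the segment
                                long docFreq)        // number of rows containing the term
    {
        double k1 = 1.2, b = 0.75;
        double idf = Math.log(1 + (docCount - docFreq + 0.5) / (docFreq + 0.5));
        double norm = k1 * (1 - b + b * docLength / avgDocLength);
        return idf * termFreq * (k1 + 1) / (termFreq + norm);
    }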
diff --git a/src/java/org/apache/cassandra/index/sai/memory/MemoryIndex.java b/src/java/org/apache/cassandra/index/sai/memory/MemoryIndex.java index 83c43e212d78..b50a58805430 100644 --- a/src/java/org/apache/cassandra/index/sai/memory/MemoryIndex.java +++ b/src/java/org/apache/cassandra/index/sai/memory/MemoryIndex.java @@ -20,6 +20,7 @@ import java.nio.ByteBuffer; import java.util.Iterator; +import java.util.List; import java.util.function.LongConsumer; import org.apache.cassandra.db.Clustering; @@ -30,8 +31,8 @@ import org.apache.cassandra.index.sai.iterators.KeyRangeIterator; import org.apache.cassandra.index.sai.plan.Expression; import org.apache.cassandra.index.sai.plan.Orderer; +import org.apache.cassandra.index.sai.utils.PrimaryKey; import org.apache.cassandra.index.sai.utils.PrimaryKeyWithSortKey; -import org.apache.cassandra.index.sai.utils.PrimaryKeys; import org.apache.cassandra.utils.CloseableIterator; import org.apache.cassandra.utils.Pair; import org.apache.cassandra.utils.bytecomparable.ByteComparable; @@ -64,5 +65,17 @@ public abstract void add(DecoratedKey key, -/** - * Iterate all Term->PrimaryKeys mappings in sorted order - */ - public abstract Iterator> iterator(); +/** + * Iterate all term -> (PrimaryKey, frequency) mappings in sorted order + */ + public abstract Iterator>> iterator(); + + public static class PkWithFrequency + { + public final PrimaryKey pk; + public final int frequency; + + public PkWithFrequency(PrimaryKey pk, int frequency) + { + this.pk = pk; + this.frequency = frequency; + } + } } diff --git a/src/java/org/apache/cassandra/index/sai/memory/MemtableIndex.java b/src/java/org/apache/cassandra/index/sai/memory/MemtableIndex.java index 22bdc284bf3f..f340b06a5e19 100644 --- a/src/java/org/apache/cassandra/index/sai/memory/MemtableIndex.java +++ b/src/java/org/apache/cassandra/index/sai/memory/MemtableIndex.java @@ -20,6 +20,7 @@ import java.nio.ByteBuffer; import java.util.Iterator; +import java.util.List; import javax.annotation.Nullable; import org.apache.cassandra.db.Clustering; @@ -32,7 +33,6 @@ import org.apache.cassandra.index.sai.disk.vector.VectorMemtableIndex; import org.apache.cassandra.index.sai.iterators.KeyRangeIterator; import org.apache.cassandra.index.sai.plan.Expression; -import org.apache.cassandra.index.sai.utils.PrimaryKey; import org.apache.cassandra.index.sai.utils.MemtableOrdering; import org.apache.cassandra.utils.Pair; import org.apache.cassandra.utils.bytecomparable.ByteComparable; @@ -79,7 +79,7 @@ default void update(DecoratedKey key, Clustering clustering, ByteBuffer oldValue long estimateMatchingRowsCount(Expression expression, AbstractBounds keyRange); - Iterator>> iterator(DecoratedKey min, DecoratedKey max); + Iterator>> iterator(DecoratedKey min, DecoratedKey max); static MemtableIndex createIndex(IndexContext indexContext, Memtable mt) { diff --git a/src/java/org/apache/cassandra/index/sai/memory/RowMapping.java b/src/java/org/apache/cassandra/index/sai/memory/RowMapping.java index a435a6e47169..e0cdc610a63c 100644 --- a/src/java/org/apache/cassandra/index/sai/memory/RowMapping.java +++ b/src/java/org/apache/cassandra/index/sai/memory/RowMapping.java @@ -17,10 +17,11 @@ */ package org.apache.cassandra.index.sai.memory; +import java.util.ArrayList; import java.util.Collections; import java.util.Iterator; +import java.util.List; -import com.carrotsearch.hppc.IntArrayList; import org.apache.cassandra.db.compaction.OperationType; import org.apache.cassandra.db.rows.RangeTombstoneMarker; import org.apache.cassandra.db.rows.Row; @@ -44,7 +45,7 @@ public class RowMapping public static final RowMapping DUMMY = 
new RowMapping() { @Override - public Iterator> merge(MemtableIndex index) { return Collections.emptyIterator(); } + public Iterator>> merge(MemtableIndex index) { return Collections.emptyIterator(); } @Override public void complete() {} @@ -89,6 +90,16 @@ public static RowMapping create(OperationType opType) return DUMMY; } + public static class RowIdWithFrequency { + public final int rowId; + public final int frequency; + + public RowIdWithFrequency(int rowId, int frequency) { + this.rowId = rowId; + this.frequency = frequency; + } + } + /** * Merge IndexMemtable(index term to PrimaryKeys mappings) with row mapping of a sstable * (PrimaryKey to RowId mappings). @@ -97,33 +108,32 @@ public static RowMapping create(OperationType opType) * * @return iterator of index term to postings mapping exists in the sstable */ - public Iterator> merge(MemtableIndex index) + public Iterator>> merge(MemtableIndex index) { assert complete : "RowMapping is not built."; - Iterator>> iterator = index.iterator(minKey.partitionKey(), maxKey.partitionKey()); - return new AbstractGuavaIterator>() + var it = index.iterator(minKey.partitionKey(), maxKey.partitionKey()); + return new AbstractGuavaIterator<>() { @Override - protected Pair computeNext() + protected Pair> computeNext() { - while (iterator.hasNext()) + while (it.hasNext()) { - Pair> pair = iterator.next(); + var pair = it.next(); - IntArrayList postings = null; - Iterator primaryKeys = pair.right; + List postings = null; + var primaryKeysWithFreq = pair.right; - while (primaryKeys.hasNext()) + for (var pkWithFreq : primaryKeysWithFreq) { - PrimaryKey primaryKey = primaryKeys.next(); - ByteComparable byteComparable = v -> primaryKey.asComparableBytes(v); + ByteComparable byteComparable = pkWithFreq.pk::asComparableBytes; Integer segmentRowId = rowMapping.get(byteComparable); if (segmentRowId != null) { - postings = postings == null ? new IntArrayList() : postings; - postings.add(segmentRowId); + postings = postings == null ? 
new ArrayList<>() : postings; + postings.add(new RowIdWithFrequency(segmentRowId, pkWithFreq.frequency)); } } if (postings != null && !postings.isEmpty()) diff --git a/src/java/org/apache/cassandra/index/sai/memory/TrieMemoryIndex.java b/src/java/org/apache/cassandra/index/sai/memory/TrieMemoryIndex.java index 57517d9192d0..593f11acccca 100644 --- a/src/java/org/apache/cassandra/index/sai/memory/TrieMemoryIndex.java +++ b/src/java/org/apache/cassandra/index/sai/memory/TrieMemoryIndex.java @@ -28,13 +28,15 @@ import java.math.BigDecimal; import java.nio.ByteBuffer; import java.util.ArrayList; - import java.util.Collection; - +import java.util.HashMap; import java.util.Iterator; import java.util.List; import java.util.Map; +import java.util.Objects; import java.util.SortedSet; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.atomic.AtomicLong; import java.util.concurrent.atomic.LongAdder; import java.util.function.LongConsumer; import javax.annotation.Nullable; @@ -43,6 +45,8 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import io.github.jbellis.jvector.util.Accountable; +import io.github.jbellis.jvector.util.RamUsageEstimator; import io.netty.util.concurrent.FastThreadLocal; import org.apache.cassandra.db.Clustering; import org.apache.cassandra.db.DecoratedKey; @@ -57,6 +61,7 @@ import org.apache.cassandra.dht.AbstractBounds; import org.apache.cassandra.index.sai.IndexContext; import org.apache.cassandra.index.sai.analyzer.AbstractAnalyzer; +import org.apache.cassandra.index.sai.analyzer.NoOpAnalyzer; import org.apache.cassandra.index.sai.disk.format.Version; import org.apache.cassandra.index.sai.disk.v6.TermsDistribution; import org.apache.cassandra.index.sai.iterators.KeyRangeIterator; @@ -83,6 +88,8 @@ public class TrieMemoryIndex extends MemoryIndex private final InMemoryTrie data; private final PrimaryKeysReducer primaryKeysReducer; + private final Map termFrequencies; + private final Map docLengths = new HashMap<>(); private final Memtable memtable; private AbstractBounds keyBounds; @@ -111,6 +118,48 @@ public TrieMemoryIndex(IndexContext indexContext, Memtable memtable, AbstractBou this.data = InMemoryTrie.longLived(TypeUtil.BYTE_COMPARABLE_VERSION, TrieMemtable.BUFFER_TYPE, indexContext.columnFamilyStore().readOrdering()); this.primaryKeysReducer = new PrimaryKeysReducer(); this.memtable = memtable; + termFrequencies = new ConcurrentHashMap<>(); + } + + public synchronized Map getDocLengths() + { + return docLengths; + } + + private static class PkWithTerm implements Accountable + { + private final PrimaryKey pk; + private final ByteComparable term; + + private PkWithTerm(PrimaryKey pk, ByteComparable term) + { + this.pk = pk; + this.term = term; + } + + @Override + public long ramBytesUsed() + { + return RamUsageEstimator.NUM_BYTES_OBJECT_HEADER + + 2L * RamUsageEstimator.NUM_BYTES_OBJECT_REF + + pk.ramBytesUsed() + + ByteComparable.length(term, TypeUtil.BYTE_COMPARABLE_VERSION); + } + + @Override + public int hashCode() + { + return Objects.hash(pk, ByteComparable.length(term, TypeUtil.BYTE_COMPARABLE_VERSION)); + } + + @Override + public boolean equals(Object o) + { + if (o == null || getClass() != o.getClass()) return false; + PkWithTerm that = (PkWithTerm) o; + return Objects.equals(pk, that.pk) + && ByteComparable.compare(term, that.term, TypeUtil.BYTE_COMPARABLE_VERSION) == 0; + } } public synchronized void add(DecoratedKey key, @@ -129,12 +178,33 @@ public synchronized void add(DecoratedKey key, final long initialSizeOffHeap = 
data.usedSizeOffHeap(); final long reducerHeapSize = primaryKeysReducer.heapAllocations(); + if (docLengths.containsKey(primaryKey) && !(analyzer instanceof NoOpAnalyzer)) + { + AtomicLong heapReclaimed = new AtomicLong(); + // we're overwriting an existing cell, clear out the old term counts + for (Map.Entry entry : data.entrySet()) + { + var termInTrie = entry.getKey(); + entry.getValue().forEach(pkInTrie -> { + if (pkInTrie.equals(primaryKey)) + { + var t = new PkWithTerm(pkInTrie, termInTrie); + termFrequencies.remove(t); + heapReclaimed.addAndGet(RamUsageEstimator.HASHTABLE_RAM_BYTES_PER_ENTRY + t.ramBytesUsed() + Integer.BYTES); + } + }); + } + } + + int tokenCount = 0; while (analyzer.hasNext()) { final ByteBuffer term = analyzer.next(); if (!indexContext.validateMaxTermSize(key, term)) continue; + tokenCount++; + // Note that this term is already encoded once by the TypeUtil.encode call above. setMinMaxTerm(term.duplicate()); @@ -142,7 +212,25 @@ public synchronized void add(DecoratedKey key, try { - data.putSingleton(encodedTerm, primaryKey, primaryKeysReducer, term.limit() <= MAX_RECURSIVE_KEY_LENGTH); + data.putSingleton(encodedTerm, primaryKey, (existing, update) -> { + // First do the normal primary keys reduction + PrimaryKeys result = primaryKeysReducer.apply(existing, update); + if (analyzer instanceof NoOpAnalyzer) + return result; + + // Then update term frequency + var pkbc = new PkWithTerm(update, encodedTerm); + termFrequencies.compute(pkbc, (k, oldValue) -> { + if (oldValue == null) { + // New key added, track heap allocation + onHeapAllocationsTracker.accept(RamUsageEstimator.HASHTABLE_RAM_BYTES_PER_ENTRY + k.ramBytesUsed() + Integer.BYTES); + return 1; + } + return oldValue + 1; + }); + + return result; + }, term.limit() <= MAX_RECURSIVE_KEY_LENGTH); } catch (TrieSpaceExhaustedException e) { @@ -150,6 +238,14 @@ public synchronized void add(DecoratedKey key, } } + docLengths.put(primaryKey, tokenCount); + // heap used for term frequencies and doc lengths + long heapUsed = RamUsageEstimator.HASHTABLE_RAM_BYTES_PER_ENTRY + + primaryKey.ramBytesUsed() + + Integer.BYTES; + onHeapAllocationsTracker.accept(heapUsed); + + // memory used by the trie onHeapAllocationsTracker.accept((data.usedSizeOnHeap() - initialSizeOnHeap) + (primaryKeysReducer.heapAllocations() - reducerHeapSize)); offHeapAllocationsTracker.accept(data.usedSizeOffHeap() - initialSizeOffHeap); @@ -161,10 +257,10 @@ public synchronized void add(DecoratedKey key, } @Override - public Iterator> iterator() + public Iterator>> iterator() { Iterator> iterator = data.entrySet().iterator(); - return new Iterator>() + return new Iterator<>() { @Override public boolean hasNext() @@ -173,10 +269,17 @@ public boolean hasNext() } @Override - public Pair next() + public Pair> next() { Map.Entry entry = iterator.next(); - return Pair.create(entry.getKey(), entry.getValue()); + var pairs = new ArrayList(entry.getValue().size()); + for (PrimaryKey pk : entry.getValue().keys()) + { + var frequencyRaw = termFrequencies.get(new PkWithTerm(pk, entry.getKey())); + int frequency = frequencyRaw == null ? 
1 : frequencyRaw; + pairs.add(new PkWithFrequency(pk, frequency)); + } + return Pair.create(entry.getKey(), pairs); } }; } @@ -418,7 +521,6 @@ private ByteComparable asByteComparable(ByteBuffer input) return Version.latest().onDiskFormat().encodeForTrie(input, indexContext.getValidator()); } - class PrimaryKeysReducer implements InMemoryTrie.UpsertTransformer { private final LongAdder heapAllocations = new LongAdder(); @@ -662,6 +764,13 @@ public void close() throws IOException } } + /** + * Iterator that provides ordered access to all indexed terms and their associated primary keys + * in the TrieMemoryIndex. For each term in the index, yields PrimaryKeyWithSortKey objects that + * combine a primary key with its associated term. + *
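+ * Unlike {@link #iterator()}, which pairs each primary key with its term frequency
+ * ({@code PkWithFrequency}), this iterator carries no frequency information.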
+ * A more verbose name could be KeysMatchingTermsByTermIterator. + */ private class AllTermsIterator extends AbstractIterator { private final Iterator> iterator; diff --git a/src/java/org/apache/cassandra/index/sai/memory/TrieMemtableIndex.java b/src/java/org/apache/cassandra/index/sai/memory/TrieMemtableIndex.java index f1cb9f5f47e7..dd440a4c42aa 100644 --- a/src/java/org/apache/cassandra/index/sai/memory/TrieMemtableIndex.java +++ b/src/java/org/apache/cassandra/index/sai/memory/TrieMemtableIndex.java @@ -22,6 +22,7 @@ import java.util.ArrayList; import java.util.Arrays; import java.util.Comparator; +import java.util.HashMap; import java.util.Iterator; import java.util.List; import java.util.Objects; @@ -30,28 +31,37 @@ import com.google.common.annotations.VisibleForTesting; import com.google.common.base.Preconditions; +import com.google.common.collect.Iterators; import com.google.common.util.concurrent.Runnables; +import org.apache.cassandra.cql3.Operator; import org.apache.cassandra.db.Clustering; +import org.apache.cassandra.db.DataRange; import org.apache.cassandra.db.DecoratedKey; import org.apache.cassandra.db.PartitionPosition; +import org.apache.cassandra.db.RegularAndStaticColumns; +import org.apache.cassandra.db.filter.ColumnFilter; import org.apache.cassandra.db.marshal.AbstractType; import org.apache.cassandra.db.memtable.Memtable; import org.apache.cassandra.db.memtable.ShardBoundaries; import org.apache.cassandra.db.memtable.TrieMemtable; +import org.apache.cassandra.db.rows.Row; import org.apache.cassandra.dht.AbstractBounds; +import org.apache.cassandra.dht.Bounds; import org.apache.cassandra.index.sai.IndexContext; import org.apache.cassandra.index.sai.QueryContext; import org.apache.cassandra.index.sai.disk.format.Version; -import org.apache.cassandra.index.sai.iterators.KeyRangeLazyIterator; import org.apache.cassandra.index.sai.iterators.KeyRangeConcatIterator; +import org.apache.cassandra.index.sai.iterators.KeyRangeIntersectionIterator; import org.apache.cassandra.index.sai.iterators.KeyRangeIterator; +import org.apache.cassandra.index.sai.iterators.KeyRangeLazyIterator; +import org.apache.cassandra.index.sai.memory.MemoryIndex.PkWithFrequency; import org.apache.cassandra.index.sai.plan.Expression; import org.apache.cassandra.index.sai.plan.Orderer; +import org.apache.cassandra.index.sai.utils.BM25Utils; import org.apache.cassandra.index.sai.utils.PrimaryKey; import org.apache.cassandra.index.sai.utils.PrimaryKeyWithByteComparable; import org.apache.cassandra.index.sai.utils.PrimaryKeyWithSortKey; -import org.apache.cassandra.index.sai.utils.PrimaryKeys; import org.apache.cassandra.index.sai.utils.TypeUtil; import org.apache.cassandra.sensors.Context; import org.apache.cassandra.sensors.RequestSensors; @@ -63,7 +73,6 @@ import org.apache.cassandra.utils.Reducer; import org.apache.cassandra.utils.SortingIterator; import org.apache.cassandra.utils.bytecomparable.ByteComparable; -import org.apache.cassandra.utils.bytecomparable.ByteSource; import org.apache.cassandra.utils.concurrent.OpOrder; public class TrieMemtableIndex implements MemtableIndex @@ -236,15 +245,46 @@ public List> orderBy(QueryContext query int startShard = boundaries.getShardForToken(keyRange.left.getToken()); int endShard = keyRange.right.isMinimum() ? 
boundaries.shardCount() - 1 : boundaries.getShardForToken(keyRange.right.getToken()); - var iterators = new ArrayList>(endShard - startShard + 1); - - for (int shard = startShard; shard <= endShard; ++shard) + if (!orderer.isBM25()) { - assert rangeIndexes[shard] != null; - iterators.add(rangeIndexes[shard].orderBy(orderer, slice)); + var iterators = new ArrayList>(endShard - startShard + 1); + for (int shard = startShard; shard <= endShard; ++shard) + { + assert rangeIndexes[shard] != null; + iterators.add(rangeIndexes[shard].orderBy(orderer, slice)); + } + return iterators; } - return iterators; + // BM25 + var queryTerms = orderer.getQueryTerms(); + + // Intersect iterators to find documents containing all terms + var termIterators = keyIteratorsPerTerm(queryContext, keyRange, queryTerms); + var intersectedIterator = KeyRangeIntersectionIterator.builder(termIterators).build(); + + // Compute BM25 scores + var docStats = computeDocumentFrequencies(queryContext, queryTerms); + var analyzer = indexContext.getAnalyzerFactory().create(); + var it = Iterators.transform(intersectedIterator, pk -> BM25Utils.DocTF.createFromDocument(pk, getCellForKey(pk), analyzer, queryTerms)); + return List.of(BM25Utils.computeScores(CloseableIterator.wrap(it), + queryTerms, + docStats, + indexContext, + memtable)); + } + + private List keyIteratorsPerTerm(QueryContext queryContext, AbstractBounds keyRange, List queryTerms) + { + List termIterators = new ArrayList<>(queryTerms.size()); + for (ByteBuffer term : queryTerms) + { + Expression expr = new Expression(indexContext); + expr.add(Operator.ANALYZER_MATCHES, term); + KeyRangeIterator iterator = search(queryContext, expr, keyRange, Integer.MAX_VALUE); + termIterators.add(iterator); + } + return termIterators; } @Override @@ -256,38 +296,105 @@ public long estimateMatchingRowsCount(Expression expression, AbstractBounds orderResultsBy(QueryContext context, List keys, Orderer orderer, int limit) + public CloseableIterator orderResultsBy(QueryContext queryContext, List keys, Orderer orderer, int limit) { if (keys.isEmpty()) return CloseableIterator.emptyIterator(); - return SortingIterator.createCloseable( - orderer.getComparator(), - keys, - key -> + + if (!orderer.isBM25()) + { + return SortingIterator.createCloseable( + orderer.getComparator(), + keys, + key -> + { + var partition = memtable.getPartition(key.partitionKey()); + if (partition == null) + return null; + var row = partition.getRow(key.clustering()); + if (row == null) + return null; + var cell = row.getCell(indexContext.getDefinition()); + if (cell == null) + return null; + + // We do two kinds of encoding... it'd be great to make this more straight forward, but this is what + // we have for now. I leave it to the reader to inspect the two methods to see the nuanced differences. 
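
// For reference, a minimal sketch of the per-document Okapi BM25 score that
// BM25Utils.computeScores derives from these inputs: per-document term counts
// (DocTF) and corpus-level statistics (DocStats). K1 = 1.2 and B = 0.75 match
// the constants declared in BM25Utils at the end of this patch; the idf
// expression below is the standard Okapi formulation and is an assumption,
// since the scoring loop in BM25Utils is not shown in full here.
static double bm25Score(Map<ByteBuffer, Integer> tfInDoc,  // term -> occurrences within this document
                        int docLength,                     // total tokens in this document
                        Map<ByteBuffer, Long> docFreq,     // term -> number of documents containing it
                        long docCount,                     // total documents in the queried column
                        double avgDocLength,
                        Iterable<ByteBuffer> queryTerms)
{
    final double K1 = 1.2; // term frequency saturation
    final double B = 0.75; // length normalization
    double score = 0.0;
    for (ByteBuffer term : queryTerms)
    {
        int tf = tfInDoc.getOrDefault(term, 0);
        long df = docFreq.getOrDefault(term, 0L);
        double idf = Math.log(1.0 + (docCount - df + 0.5) / (df + 0.5));
        double norm = 1.0 - B + B * (docLength / avgDocLength);
        score += idf * (tf * (K1 + 1.0)) / (tf + K1 * norm);
    }
    return score;
}
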
+ var encoding = encode(TypeUtil.encode(cell.buffer(), validator)); + return new PrimaryKeyWithByteComparable(indexContext, memtable, key, encoding); + }, + Runnables.doNothing() + ); + } + + // BM25 + var analyzer = indexContext.getAnalyzerFactory().create(); + var queryTerms = orderer.getQueryTerms(); + var docStats = computeDocumentFrequencies(queryContext, queryTerms); + var it = keys.stream().map(pk -> BM25Utils.DocTF.createFromDocument(pk, getCellForKey(pk), analyzer, queryTerms)).iterator(); + return BM25Utils.computeScores(CloseableIterator.wrap(it), + queryTerms, + docStats, + indexContext, + memtable); + } + + /** + * Count document frequencies for each term using brute force + */ + private BM25Utils.DocStats computeDocumentFrequencies(QueryContext queryContext, List queryTerms) + { + var termIterators = keyIteratorsPerTerm(queryContext, Bounds.unbounded(indexContext.getPartitioner()), queryTerms); + var documentFrequencies = new HashMap(); + for (int i = 0; i < queryTerms.size(); i++) + { + // KeyRangeIterator.getMaxKeys is not accurate enough, we have to count them + long keys = 0; + for (var it = termIterators.get(i); it.hasNext(); it.next()) + keys++; + documentFrequencies.put(queryTerms.get(i), keys); + } + long docCount = 0; + + // count all documents in the queried column + try (var it = memtable.makePartitionIterator(ColumnFilter.selection(RegularAndStaticColumns.of(indexContext.getDefinition())), + DataRange.allData(memtable.metadata().partitioner))) + { + while (it.hasNext()) { - var partition = memtable.getPartition(key.partitionKey()); - if (partition == null) - return null; - var row = partition.getRow(key.clustering()); - if (row == null) - return null; - var cell = row.getCell(indexContext.getDefinition()); - if (cell == null) - return null; - - // We do two kinds of encoding... it'd be great to make this more straight forward, but this is what - // we have for now. I leave it to the reader to inspect the two methods to see the nuanced differences. - var encoding = encode(TypeUtil.encode(cell.buffer(), validator)); - return new PrimaryKeyWithByteComparable(indexContext, memtable, key, encoding); - }, - Runnables.doNothing() - ); + var partitions = it.next(); + while (partitions.hasNext()) + { + var unfiltered = partitions.next(); + if (!unfiltered.isRow()) + continue; + var row = (Row) unfiltered; + var cell = row.getCell(indexContext.getDefinition()); + if (cell == null) + continue; + + docCount++; + } + } + } + return new BM25Utils.DocStats(documentFrequencies, docCount); + } + + @Nullable + private org.apache.cassandra.db.rows.Cell getCellForKey(PrimaryKey key) + { + var partition = memtable.getPartition(key.partitionKey()); + if (partition == null) + return null; + var row = partition.getRow(key.clustering()); + if (row == null) + return null; + return row.getCell(indexContext.getDefinition()); } private ByteComparable encode(ByteBuffer input) { - return indexContext.isLiteral() ? v -> ByteSource.preencoded(input) - : v -> TypeUtil.asComparableBytes(input, indexContext.getValidator(), v); + return Version.latest().onDiskFormat().encodeForTrie(input, indexContext.getValidator()); } /** @@ -302,26 +409,34 @@ private ByteComparable encode(ByteBuffer input) * @return iterator of indexed term to primary keys mapping in sorted by indexed term and primary key. */ @Override - public Iterator>> iterator(DecoratedKey min, DecoratedKey max) + public Iterator>> iterator(DecoratedKey min, DecoratedKey max) { int minSubrange = min == null ? 
0 : boundaries.getShardForKey(min); int maxSubrange = max == null ? rangeIndexes.length - 1 : boundaries.getShardForKey(max); - List>> rangeIterators = new ArrayList<>(maxSubrange - minSubrange + 1); + List>>> rangeIterators = new ArrayList<>(maxSubrange - minSubrange + 1); for (int i = minSubrange; i <= maxSubrange; i++) rangeIterators.add(rangeIndexes[i].iterator()); - return MergeIterator.get(rangeIterators, (o1, o2) -> ByteComparable.compare(o1.left, o2.left, TypeUtil.BYTE_COMPARABLE_VERSION), + return MergeIterator.get(rangeIterators, + (o1, o2) -> ByteComparable.compare(o1.left, o2.left, TypeUtil.BYTE_COMPARABLE_VERSION), new PrimaryKeysMergeReducer(rangeIterators.size())); } - // The PrimaryKeysMergeReducer receives the range iterators from each of the range indexes selected based on the - // min and max keys passed to the iterator method. It doesn't strictly do any reduction because the terms in each - // range index are unique. It will receive at most one range index entry per selected range index before getReduced - // is called. - private static class PrimaryKeysMergeReducer extends Reducer, Pair>> + /** + * Used to merge sorted primary keys from multiple TrieMemoryIndex shards for a given indexed term. + * For each term that appears in multiple shards, the reducer: + * 1. Receives exactly one call to reduce() per shard containing that term + * 2. Merges all the primary keys for that term via getReduced() + * 3. Resets state via onKeyChange() before processing the next term + *
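+ * For example, if shard 0 holds keys [a, c] and shard 1 holds [b] for the same term,
+ * getReduced() returns that term paired with [a, b, c], each key keeping the
+ * frequency it had in its own shard.
+ *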
+ * While this follows the Reducer pattern, its "reduction" operation is a simple merge since each term + * appears at most once per shard, and each key will only be found in a given shard, so there are no values to aggregate; + * we simply combine and sort the primary keys from each shard that contains the term. + */ + private static class PrimaryKeysMergeReducer extends Reducer>, Pair>> { - private final Pair[] rangeIndexEntriesToMerge; + private final Pair>[] rangeIndexEntriesToMerge; private final Comparator comparator; private ByteComparable term; @@ -337,7 +452,7 @@ private static class PrimaryKeysMergeReducer extends Reducer termPair) + public void reduce(int index, Pair> termPair) { Preconditions.checkArgument(rangeIndexEntriesToMerge[index] == null, "Terms should be unique in the memory index"); @@ -348,17 +463,17 @@ public void reduce(int index, Pair termPair) @Override // Return a merger of the term keys for the term. - public Pair> getReduced() + public Pair> getReduced() { Preconditions.checkArgument(term != null, "The term must exist in the memory index"); - List> keyIterators = new ArrayList<>(rangeIndexEntriesToMerge.length); - for (Pair p : rangeIndexEntriesToMerge) - if (p != null && p.right != null && !p.right.isEmpty()) - keyIterators.add(p.right.iterator()); + var merged = new ArrayList(); + for (var p : rangeIndexEntriesToMerge) + if (p != null && p.right != null) + merged.addAll(p.right); - Iterator primaryKeys = MergeIterator.get(keyIterators, comparator, Reducer.getIdentity()); - return Pair.create(term, primaryKeys); + merged.sort((o1, o2) -> comparator.compare(o1.pk, o2.pk)); + return Pair.create(term, merged); } @Override diff --git a/src/java/org/apache/cassandra/index/sai/plan/Expression.java b/src/java/org/apache/cassandra/index/sai/plan/Expression.java index a1bd9acddc4b..aac8e240829b 100644 --- a/src/java/org/apache/cassandra/index/sai/plan/Expression.java +++ b/src/java/org/apache/cassandra/index/sai/plan/Expression.java @@ -101,6 +101,7 @@ public static Op valueOf(Operator operator) return IN; case ANN: + case BM25: case ORDER_BY_ASC: case ORDER_BY_DESC: return ORDER_BY; @@ -250,6 +251,7 @@ public Expression add(Operator op, ByteBuffer value) boundedAnnEuclideanDistanceThreshold = GeoUtil.amplifiedEuclideanSimilarityThreshold(lower.value.vector, searchRadiusMeters); break; case ANN: + case BM25: case ORDER_BY_ASC: case ORDER_BY_DESC: // If we alread have an operation on the column, we don't need to set the ORDER_BY op because diff --git a/src/java/org/apache/cassandra/index/sai/plan/Operation.java b/src/java/org/apache/cassandra/index/sai/plan/Operation.java index 7722ada87463..8968d4390f5a 100644 --- a/src/java/org/apache/cassandra/index/sai/plan/Operation.java +++ b/src/java/org/apache/cassandra/index/sai/plan/Operation.java @@ -93,7 +93,7 @@ protected static ListMultimap analyzeGroup(QueryCont analyzer.reset(e.getIndexValue()); // EQ/LIKE_*/NOT_EQ can have multiple expressions e.g. 
text = "Hello World", - // becomes text = "Hello" OR text = "World" because "space" is always interpreted as a split point (by analyzer), + // becomes text = "Hello" AND text = "World" because "space" is always interpreted as a split point (by analyzer), // CONTAINS/CONTAINS_KEY are always treated as multiple expressions since they currently only targetting // collections, NOT_EQ is made an independent expression only in case of pre-existing multiple EQ expressions, or // if there is no EQ operations and NOT_EQ is met or a single NOT_EQ expression present, @@ -102,6 +102,7 @@ protected static ListMultimap analyzeGroup(QueryCont boolean isMultiExpression = columnIsMultiExpression.getOrDefault(e.column(), Boolean.FALSE); switch (e.operator()) { + // case BM25: leave it at the default of `false` case EQ: // EQ operator will always be a multiple expression because it is being used by map entries isMultiExpression = indexContext.isNonFrozenCollection(); diff --git a/src/java/org/apache/cassandra/index/sai/plan/Orderer.java b/src/java/org/apache/cassandra/index/sai/plan/Orderer.java index e23c83ed91e2..af202bae5f5b 100644 --- a/src/java/org/apache/cassandra/index/sai/plan/Orderer.java +++ b/src/java/org/apache/cassandra/index/sai/plan/Orderer.java @@ -19,9 +19,12 @@ package org.apache.cassandra.index.sai.plan; import java.nio.ByteBuffer; +import java.util.ArrayList; import java.util.Arrays; import java.util.Comparator; import java.util.EnumSet; +import java.util.HashSet; +import java.util.List; import java.util.stream.Collectors; import javax.annotation.Nullable; @@ -41,12 +44,15 @@ public class Orderer { // The list of operators that are valid for order by clauses. static final EnumSet ORDER_BY_OPERATORS = EnumSet.of(Operator.ANN, + Operator.BM25, Operator.ORDER_BY_ASC, Operator.ORDER_BY_DESC); public final IndexContext context; public final Operator operator; - public final float[] vector; + public final ByteBuffer term; + private float[] vector; + private List queryTerms; /** * Create an orderer for the given index context, operator, and term. @@ -59,7 +65,7 @@ public Orderer(IndexContext context, Operator operator, ByteBuffer term) this.context = context; assert ORDER_BY_OPERATORS.contains(operator) : "Invalid operator for order by clause " + operator; this.operator = operator; - this.vector = context.getValidator().isVector() ? TypeUtil.decomposeVector(context.getValidator(), term) : null; + this.term = term; } public String getIndexName() @@ -75,8 +81,8 @@ public boolean isAscending() public Comparator getComparator() { - // ANN's PrimaryKeyWithSortKey is always descending, so we use the natural order for the priority queue - return isAscending() || isANN() ? Comparator.naturalOrder() : Comparator.reverseOrder(); + // ANN/BM25's PrimaryKeyWithSortKey is always descending, so we use the natural order for the priority queue + return (isAscending() || isANN() || isBM25()) ? Comparator.naturalOrder() : Comparator.reverseOrder(); } public boolean isLiteral() @@ -89,6 +95,11 @@ public boolean isANN() return operator == Operator.ANN; } + public boolean isBM25() + { + return operator == Operator.BM25; + } + @Nullable public static Orderer from(SecondaryIndexManager indexManager, RowFilter filter) { @@ -110,8 +121,38 @@ public static boolean isFilterExpressionOrderer(RowFilter.Expression expression) public String toString() { String direction = isAscending() ? "ASC" : "DESC"; - return isANN() - ? 
context.getColumnName() + " ANN OF " + Arrays.toString(vector) + ' ' + direction - : context.getColumnName() + ' ' + direction; + if (isANN()) + return context.getColumnName() + " ANN OF " + Arrays.toString(getVectorTerm()) + ' ' + direction; + if (isBM25()) + return context.getColumnName() + " BM25 OF " + TypeUtil.getString(term, context.getValidator()) + ' ' + direction; + return context.getColumnName() + ' ' + direction; + } + + public float[] getVectorTerm() + { + if (vector == null) + vector = TypeUtil.decomposeVector(context.getValidator(), term); + return vector; + } + + public List getQueryTerms() + { + if (queryTerms != null) + return queryTerms; + + var queryAnalyzer = context.getQueryAnalyzerFactory().create(); + // Split query into terms + var uniqueTerms = new HashSet(); + queryAnalyzer.reset(term); + try + { + queryAnalyzer.forEachRemaining(uniqueTerms::add); + } + finally + { + queryAnalyzer.end(); + } + queryTerms = new ArrayList<>(uniqueTerms); + return queryTerms; } } diff --git a/src/java/org/apache/cassandra/index/sai/plan/Plan.java b/src/java/org/apache/cassandra/index/sai/plan/Plan.java index e40906c73d52..d82b92905403 100644 --- a/src/java/org/apache/cassandra/index/sai/plan/Plan.java +++ b/src/java/org/apache/cassandra/index/sai/plan/Plan.java @@ -1257,10 +1257,12 @@ protected double estimateSelectivity() @Override protected KeysIterationCost estimateCost() { - return ordering.isANN() - ? estimateAnnSortCost() - : estimateGlobalSortCost(); - + if (ordering.isANN()) + return estimateAnnSortCost(); + else if (ordering.isBM25()) + return estimateBm25SortCost(); + else + return estimateGlobalSortCost(); } private KeysIterationCost estimateAnnSortCost() @@ -1277,6 +1279,21 @@ private KeysIterationCost estimateAnnSortCost() return new KeysIterationCost(expectedKeys, initCost, searchCost); } + private KeysIterationCost estimateBm25SortCost() + { + double expectedKeys = access.expectedAccessCount(source.expectedKeys()); + + int termCount = ordering.getQueryTerms().size(); + // all of the cost for BM25 is up front since the index doesn't give us the information we need + // to return results in order, in isolation. The big cost is reading the indexed cells out of + // the sstables. + // VSTODO if we had stats on cell size _per column_ we could usefully include ROW_BYTE_COST + double initCost = source.fullCost() + + source.expectedKeys() * (hrs(ROW_CELL_COST) + ROW_CELL_COST) + + termCount * BM25_SCORE_COST; + return new KeysIterationCost(expectedKeys, initCost, 0); + } + private KeysIterationCost estimateGlobalSortCost() { return new KeysIterationCost(source.expectedKeys(), @@ -1310,19 +1327,51 @@ protected KeysSort withAccess(Access access) } /** - * Returns all keys in ANN order. - * Contrary to {@link KeysSort}, there is no input node here and the output is generated lazily. + * Base class for index scans that return results in a computed order (ANN, BM25) + * rather than the natural index order. 
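+ * <p>
+ * For example (illustrative query), {@code SELECT * FROM t ORDER BY body BM25 OF 'query terms' LIMIT 10}
+ * plans a Bm25IndexScan; the ordering term is tokenized into the individual query
+ * terms by the column's query analyzer (see Orderer#getQueryTerms).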
*/ - final static class AnnIndexScan extends Leaf + abstract static class ScoredIndexScan extends Leaf { final Orderer ordering; - AnnIndexScan(Factory factory, int id, Access access, Orderer ordering) + protected ScoredIndexScan(Factory factory, int id, Access access, Orderer ordering) { super(factory, id, access); this.ordering = ordering; } + @Nullable + @Override + protected Orderer ordering() + { + return ordering; + } + + @Override + protected double estimateSelectivity() + { + return 1.0; + } + + @Override + protected Iterator execute(Executor executor) + { + int softLimit = max(1, round((float) access.expectedAccessCount(factory.tableMetrics.rows))); + return executor.getTopKRows((Expression) null, softLimit); + } + } + + /** + * Returns all keys in ANN order. + * Contrary to {@link KeysSort}, there is no input node here and the output is generated lazily. + */ + final static class AnnIndexScan extends ScoredIndexScan + { + protected AnnIndexScan(Factory factory, int id, Access access, Orderer ordering) + { + super(factory, id, access, ordering); + } + @Override protected KeysIterationCost estimateCost() { @@ -1335,32 +1384,53 @@ protected KeysIterationCost estimateCost() return new KeysIterationCost(expectedKeys, initCost, searchCost); } - @Nullable @Override - protected Orderer ordering() + protected KeysIteration withAccess(Access access) { - return ordering; + return Objects.equals(access, this.access) + ? this + : new AnnIndexScan(factory, id, access, ordering); } + @Nullable @Override - protected Iterator execute(Executor executor) + protected IndexContext getIndexContext() { - int softLimit = max(1, round((float) access.expectedAccessCount(factory.tableMetrics.rows))); - return executor.getTopKRows((Expression) null, softLimit); + return ordering.context; } + } + /** + * Returns all keys in BM25 order. + * Like AnnIndexScan, this generates results lazily without an input node. + */ + final static class Bm25IndexScan extends ScoredIndexScan + { + protected Bm25IndexScan(Factory factory, int id, Access access, Orderer ordering) + { + super(factory, id, access, ordering); + } + + @Nonnull @Override - protected KeysIteration withAccess(Access access) + protected KeysIterationCost estimateCost() { - return Objects.equals(access, this.access) - ? this - : new AnnIndexScan(factory, id, access, ordering); + double expectedKeys = access.expectedAccessCount(factory.tableMetrics.rows); + int expectedKeysInt = Math.max(1, (int) Math.ceil(expectedKeys)); + + int termCount = ordering.getQueryTerms().size(); + double initCost = expectedKeysInt * (hrs(ROW_CELL_COST) + ROW_CELL_COST) + + termCount * BM25_SCORE_COST; + + return new KeysIterationCost(expectedKeys, initCost, 0); } @Override - protected double estimateSelectivity() + protected KeysIteration withAccess(Access access) { - return 1.0; + return Objects.equals(access, this.access) + ? 
this + : new Bm25IndexScan(factory, id, access, ordering); } @Override @@ -1684,6 +1754,8 @@ private KeysIteration indexScan(Expression predicate, long matchingKeysCount, Or if (ordering != null) if (ordering.isANN()) return new AnnIndexScan(this, id, defaultAccess, ordering); + else if (ordering.isBM25()) + return new Bm25IndexScan(this, id, defaultAccess, ordering); else if (ordering.isLiteral()) return new LiteralIndexScan(this, id, predicate, matchingKeysCount, defaultAccess, ordering); else @@ -1938,6 +2010,9 @@ public static class CostCoefficients /** Additional cost added to row fetch cost per each serialized byte of the row */ public final static double ROW_BYTE_COST = 0.005; + + /** Cost to perform BM25 scoring, per query term */ + public final static double BM25_SCORE_COST = 0.5; } /** Convenience builder for building intersection and union nodes */ diff --git a/src/java/org/apache/cassandra/index/sai/plan/QueryController.java b/src/java/org/apache/cassandra/index/sai/plan/QueryController.java index 15b0965fe82e..992b5f1334ab 100644 --- a/src/java/org/apache/cassandra/index/sai/plan/QueryController.java +++ b/src/java/org/apache/cassandra/index/sai/plan/QueryController.java @@ -195,6 +195,11 @@ public TableMetadata metadata() return command.metadata(); } + public ReadCommand command() + { + return command; + } + RowFilter.FilterElement filterOperation() { // NOTE: we cannot remove the order by filter expression here yet because it is used in the FilterTree class @@ -883,6 +888,7 @@ private long estimateMatchingRowCount(Expression predicate) switch (predicate.getOp()) { case EQ: + case MATCH: case CONTAINS_KEY: case CONTAINS_VALUE: case NOT_EQ: diff --git a/src/java/org/apache/cassandra/index/sai/plan/StorageAttachedIndexQueryPlan.java b/src/java/org/apache/cassandra/index/sai/plan/StorageAttachedIndexQueryPlan.java index 22e564365bea..39a66cbad9de 100644 --- a/src/java/org/apache/cassandra/index/sai/plan/StorageAttachedIndexQueryPlan.java +++ b/src/java/org/apache/cassandra/index/sai/plan/StorageAttachedIndexQueryPlan.java @@ -216,7 +216,7 @@ public Function postProcessor(ReadCommand return partitions -> partitions; // in case of top-k query, filter out rows that are not actually global top-K - return partitions -> new TopKProcessor(command).filter(partitions); + return partitions -> new TopKProcessor(command).reorder(partitions); } /** diff --git a/src/java/org/apache/cassandra/index/sai/plan/StorageAttachedIndexSearcher.java b/src/java/org/apache/cassandra/index/sai/plan/StorageAttachedIndexSearcher.java index 2f0cabd05c2c..f5ebe85a1ed9 100644 --- a/src/java/org/apache/cassandra/index/sai/plan/StorageAttachedIndexSearcher.java +++ b/src/java/org/apache/cassandra/index/sai/plan/StorageAttachedIndexSearcher.java @@ -30,12 +30,10 @@ import java.util.Queue; import java.util.function.Supplier; import java.util.stream.Collectors; - import javax.annotation.Nonnull; import javax.annotation.Nullable; import com.google.common.base.Preconditions; - import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -45,8 +43,13 @@ import org.apache.cassandra.db.PartitionPosition; import org.apache.cassandra.db.ReadCommand; import org.apache.cassandra.db.ReadExecutionController; +import org.apache.cassandra.db.filter.ColumnFilter; +import org.apache.cassandra.db.marshal.FloatType; import org.apache.cassandra.db.partitions.UnfilteredPartitionIterator; import org.apache.cassandra.db.rows.AbstractUnfilteredRowIterator; +import org.apache.cassandra.db.rows.BTreeRow; +import 
org.apache.cassandra.db.rows.BufferCell; +import org.apache.cassandra.db.rows.ColumnData; import org.apache.cassandra.db.rows.Row; import org.apache.cassandra.db.rows.Unfiltered; import org.apache.cassandra.db.rows.UnfilteredRowIterator; @@ -60,13 +63,16 @@ import org.apache.cassandra.index.sai.iterators.KeyRangeIterator; import org.apache.cassandra.index.sai.metrics.TableQueryMetrics; import org.apache.cassandra.index.sai.utils.PrimaryKey; -import org.apache.cassandra.index.sai.utils.RangeUtil; +import org.apache.cassandra.index.sai.utils.PrimaryKeyWithScore; import org.apache.cassandra.index.sai.utils.PrimaryKeyWithSortKey; +import org.apache.cassandra.index.sai.utils.RangeUtil; import org.apache.cassandra.io.util.FileUtils; +import org.apache.cassandra.schema.ColumnMetadata; import org.apache.cassandra.schema.TableMetadata; import org.apache.cassandra.utils.AbstractIterator; import org.apache.cassandra.utils.CloseableIterator; import org.apache.cassandra.utils.FBUtilities; +import org.apache.cassandra.utils.btree.BTree; public class StorageAttachedIndexSearcher implements Index.Searcher { @@ -109,19 +115,17 @@ public UnfilteredPartitionIterator search(ReadExecutionController executionContr // Can't check for `command.isTopK()` because the planner could optimize sorting out Orderer ordering = plan.ordering(); - if (ordering != null) - { - assert !(keysIterator instanceof KeyRangeIterator); - var scoredKeysIterator = (CloseableIterator) keysIterator; - var result = new ScoreOrderedResultRetriever(scoredKeysIterator, filterTree, controller, - executionController, queryContext, command.limits().count()); - return new TopKProcessor(command).filter(result); - } - else + if (ordering == null) { assert keysIterator instanceof KeyRangeIterator; return new ResultRetriever((KeyRangeIterator) keysIterator, filterTree, controller, executionController, queryContext); } + + assert !(keysIterator instanceof KeyRangeIterator); + var scoredKeysIterator = (CloseableIterator) keysIterator; + var result = new ScoreOrderedResultRetriever(scoredKeysIterator, filterTree, controller, + executionController, queryContext, command.limits().count()); + return new TopKProcessor(command).filter(result); } catch (QueryView.Builder.MissingIndexException e) { @@ -521,48 +525,50 @@ public UnfilteredRowIterator computeNext() */ private void fillPendingRows() { + // Group PKs by source sstable/memtable + var groupedKeys = new HashMap>(); // We always want to get at least 1. 
int rowsToRetrieve = Math.max(1, softLimit - returnedRowCount); - var keys = new HashMap>(); // We want to get the first unique `rowsToRetrieve` keys to materialize // Don't pass the priority queue here because it is more efficient to add keys in bulk - fillKeys(keys, rowsToRetrieve, null); + fillKeys(groupedKeys, rowsToRetrieve, null); // Sort the primary keys by PrK order, just in case that helps with cache and disk efficiency - var primaryKeyPriorityQueue = new PriorityQueue<>(keys.keySet()); + var primaryKeyPriorityQueue = new PriorityQueue<>(groupedKeys.keySet()); - while (!keys.isEmpty()) + // drain groupedKeys into pendingRows + while (!groupedKeys.isEmpty()) { - var primaryKey = primaryKeyPriorityQueue.poll(); - var primaryKeyWithSortKeys = keys.remove(primaryKey); - var partitionIterator = readAndValidatePartition(primaryKey, primaryKeyWithSortKeys); + var pk = primaryKeyPriorityQueue.poll(); + var sourceKeys = groupedKeys.remove(pk); + var partitionIterator = readAndValidatePartition(pk, sourceKeys); if (partitionIterator != null) pendingRows.add(partitionIterator); else // The current primaryKey did not produce a partition iterator. We know the caller will need // `rowsToRetrieve` rows, so we get the next unique key and add it to the queue. - fillKeys(keys, 1, primaryKeyPriorityQueue); + fillKeys(groupedKeys, 1, primaryKeyPriorityQueue); } } /** - * Fills the keys map with the next `count` unique primary keys that are in the keys produced by calling + * Fills the `groupedKeys` Map with the next `count` unique primary keys that are in the keys produced by calling * {@link #nextSelectedKeyInRange()}. We map PrimaryKey to List because the same * primary key can be in the result set multiple times, but with different source tables. - * @param keys the map to fill + * @param groupedKeys the map to fill * @param count the number of unique PrimaryKeys to consume from the iterator * @param primaryKeyPriorityQueue the priority queue to add new keys to. If the queue is null, we do not add * keys to the queue. */ - private void fillKeys(Map> keys, int count, PriorityQueue primaryKeyPriorityQueue) + private void fillKeys(Map> groupedKeys, int count, PriorityQueue primaryKeyPriorityQueue) { - int initialSize = keys.size(); - while (keys.size() - initialSize < count) + int initialSize = groupedKeys.size(); + while (groupedKeys.size() - initialSize < count) { var primaryKeyWithSortKey = nextSelectedKeyInRange(); if (primaryKeyWithSortKey == null) return; var nextPrimaryKey = primaryKeyWithSortKey.primaryKey(); - var accumulator = keys.computeIfAbsent(nextPrimaryKey, k -> new ArrayList<>()); + var accumulator = groupedKeys.computeIfAbsent(nextPrimaryKey, k -> new ArrayList<>()); if (primaryKeyPriorityQueue != null && accumulator.isEmpty()) primaryKeyPriorityQueue.add(nextPrimaryKey); accumulator.add(primaryKeyWithSortKey); @@ -602,15 +608,29 @@ private boolean isInRange(DecoratedKey key) return null; } - public UnfilteredRowIterator readAndValidatePartition(PrimaryKey key, List primaryKeys) + /** + * Reads and validates a partition for a given primary key against its sources. + *
+ * @param pk The primary key of the partition to read and validate + * @param sourceKeys A list of PrimaryKeyWithSortKey objects associated with the primary key. + * Multiple sort keys can exist for the same primary key when data comes from different + * sstables or memtables. + * + * @return An UnfilteredRowIterator containing the validated partition data, or null if: + * - The key has already been processed + * - The partition does not pass index filters + * - The partition contains no valid rows + * - The row data does not match the index metadata for any of the provided primary keys + */ + public UnfilteredRowIterator readAndValidatePartition(PrimaryKey pk, List sourceKeys) { // If we've already processed the key, we can skip it. Because the score ordered iterator does not // deduplicate rows, we could see dupes if a row is in the ordering index multiple times. This happens // in the case of dupes and of overwrites. - if (processedKeys.contains(key)) + if (processedKeys.contains(pk)) return null; - try (UnfilteredRowIterator partition = controller.getPartition(key, view, executionController)) + try (UnfilteredRowIterator partition = controller.getPartition(pk, view, executionController)) { queryContext.addPartitionsRead(1); queryContext.checkpoint(); @@ -619,7 +639,7 @@ public UnfilteredRowIterator readAndValidatePartition(PrimaryKey key, List primaryKeysWithScore, ReadCommand command) { super(partition.metadata(), partition.partitionKey(), @@ -662,7 +694,47 @@ public PrimaryKeyIterator(UnfilteredRowIterator partition, Row staticRow, Unfilt partition.isReverseOrder(), partition.stats()); - row = content; + assert !primaryKeysWithScore.isEmpty(); + var isScoredRow = primaryKeysWithScore.get(0) instanceof PrimaryKeyWithScore; + if (!content.isRow() || !isScoredRow) + { + this.row = content; + return; + } + + // When +score is added on the coordinator side, it's represented as a PrecomputedColumnFilter + // even in a 'SELECT *' because WCF is not capable of representing synthetic columns. 
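+ // The score travels as an ordinary float cell, so the coordinator-side
+ // TopKProcessor.getScoreForRow can read it back via row.getColumnData(scoreColumn)
+ // instead of recomputing it.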
+ // This can be simplified when we remove ANN_USE_SYNTHETIC_SCORE + var tm = metadata(); + var scoreColumn = ColumnMetadata.syntheticColumn(tm.keyspace, + tm.name, + ColumnMetadata.SYNTHETIC_SCORE_ID, + FloatType.instance); + var isScoreFetched = command.columnFilter().fetchesExplicitly(scoreColumn); + if (!isScoreFetched) + { + this.row = content; + return; + } + + // Clone the original Row + Row originalRow = (Row) content; + ArrayList columnData = new ArrayList<>(originalRow.columnCount() + 1); + columnData.addAll(originalRow.columnData()); + + // inject +score as a new column + var pkWithScore = (PrimaryKeyWithScore) primaryKeysWithScore.get(0); + columnData.add(BufferCell.live(scoreColumn, + FBUtilities.nowInSeconds(), + FloatType.instance.decompose(pkWithScore.indexScore))); + + this.row = BTreeRow.create(originalRow.clustering(), + originalRow.primaryKeyLivenessInfo(), + originalRow.deletion(), + BTree.builder(ColumnData.comparator) + .auto(true) + .addAll(columnData) + .build()); } @Override @@ -674,18 +746,6 @@ protected Unfiltered computeNext() return row; } } - - @Override - public TableMetadata metadata() - { - return controller.metadata(); - } - - public void close() - { - FileUtils.closeQuietly(scoredPrimaryKeyIterator); - controller.finish(); - } } private static UnfilteredRowIterator applyIndexFilter(UnfilteredRowIterator partition, FilterTree tree, QueryContext queryContext) diff --git a/src/java/org/apache/cassandra/index/sai/plan/TopKProcessor.java b/src/java/org/apache/cassandra/index/sai/plan/TopKProcessor.java index a2e4b315f74f..5760109328b7 100644 --- a/src/java/org/apache/cassandra/index/sai/plan/TopKProcessor.java +++ b/src/java/org/apache/cassandra/index/sai/plan/TopKProcessor.java @@ -27,8 +27,6 @@ import java.util.SortedSet; import java.util.TreeMap; import java.util.TreeSet; -import java.util.concurrent.CompletableFuture; -import java.util.concurrent.CompletionException; import javax.annotation.Nullable; import org.apache.commons.lang3.tuple.Triple; @@ -38,28 +36,25 @@ import io.github.jbellis.jvector.vector.VectorizationProvider; import io.github.jbellis.jvector.vector.types.VectorFloat; import io.github.jbellis.jvector.vector.types.VectorTypeSupport; -import org.apache.cassandra.concurrent.ImmediateExecutor; -import org.apache.cassandra.concurrent.LocalAwareExecutorService; -import org.apache.cassandra.concurrent.SharedExecutorPool; -import org.apache.cassandra.config.CassandraRelevantProperties; import org.apache.cassandra.cql3.Operator; +import org.apache.cassandra.cql3.statements.SelectStatement; import org.apache.cassandra.db.ColumnFamilyStore; import org.apache.cassandra.db.DecoratedKey; import org.apache.cassandra.db.Keyspace; import org.apache.cassandra.db.ReadCommand; import org.apache.cassandra.db.filter.RowFilter; -import org.apache.cassandra.db.partitions.BasePartitionIterator; -import org.apache.cassandra.db.partitions.ParallelCommandProcessor; +import org.apache.cassandra.db.marshal.FloatType; import org.apache.cassandra.db.partitions.PartitionIterator; import org.apache.cassandra.db.partitions.UnfilteredPartitionIterator; import org.apache.cassandra.db.rows.BaseRowIterator; +import org.apache.cassandra.db.rows.Cell; import org.apache.cassandra.db.rows.Row; import org.apache.cassandra.db.rows.Unfiltered; import org.apache.cassandra.index.Index; import org.apache.cassandra.index.SecondaryIndexManager; import org.apache.cassandra.index.sai.IndexContext; import org.apache.cassandra.index.sai.StorageAttachedIndex; -import 
org.apache.cassandra.index.sai.utils.AbortedOperationException; +import org.apache.cassandra.index.sai.plan.StorageAttachedIndexSearcher.ScoreOrderedResultRetriever; import org.apache.cassandra.index.sai.utils.InMemoryPartitionIterator; import org.apache.cassandra.index.sai.utils.InMemoryUnfilteredPartitionIterator; import org.apache.cassandra.index.sai.utils.PartitionInfo; @@ -72,33 +67,30 @@ import static org.apache.cassandra.cql3.statements.RequestValidations.invalidRequest; /** - * Processor applied to SAI based ORDER BY queries. This class could likely be refactored into either two filter - * methods depending on where the processing is happening or into two classes. - *
- * This processor performs the following steps on a replica: - * - collect LIMIT rows from partition iterator, making sure that all are valid. - * - return rows in Primary Key order - *
- * This processor performs the following steps on a coordinator: - * - consume all rows from the provided partition iterator and sort them according to the specified order. - * For vectors, that is similarit score and for all others, that is the ordering defined by their - * {@link org.apache.cassandra.db.marshal.AbstractType}. If there are multiple vector indexes, - * the final score is the sum of all vector index scores. - * - remove rows with the lowest scores from PQ if PQ size exceeds limit - * - return rows from PQ in primary key order to caller + * Processor applied to SAI based ORDER BY queries. + * + * * On a replica: + * * - filter(ScoreOrderedResultRetriever) is used to collect up to the top-K rows. + * * - We store any tombstones as well, to avoid losing them during coordinator reconciliation. + * * - The result is returned in PK order so that coordinator can merge from multiple replicas. + * + * On a coordinator: + * - reorder(PartitionIterator) is used to consume all rows from the provided partitions, + * compute the order based on either a column ordering or a similarity score, and keep top-K. + * - The result is returned in score/sortkey order. */ public class TopKProcessor { public static final String INDEX_MAY_HAVE_BEEN_DROPPED = "An index may have been dropped. Ordering on non-clustering " + "column requires the column to be indexed"; protected static final Logger logger = LoggerFactory.getLogger(TopKProcessor.class); - private static final LocalAwareExecutorService PARALLEL_EXECUTOR = getExecutor(); private static final VectorTypeSupport vts = VectorizationProvider.getInstance().getVectorTypeSupport(); private final ReadCommand command; private final IndexContext indexContext; private final RowFilter.Expression expression; private final VectorFloat queryVector; + private final ColumnMetadata scoreColumn; private final int limit; @@ -106,52 +98,29 @@ public TopKProcessor(ReadCommand command) { this.command = command; - Pair annIndexAndExpression = findTopKIndexContext(); + Pair indexAndExpression = findTopKIndexContext(); // this can happen in case an index was dropped after the query was initiated - if (annIndexAndExpression == null) + if (indexAndExpression == null) throw invalidRequest(INDEX_MAY_HAVE_BEEN_DROPPED); - this.indexContext = annIndexAndExpression.left; - this.expression = annIndexAndExpression.right; - if (expression.operator() == Operator.ANN) + this.indexContext = indexAndExpression.left; + this.expression = indexAndExpression.right; + if (expression.operator() == Operator.ANN && !SelectStatement.ANN_USE_SYNTHETIC_SCORE) this.queryVector = vts.createFloatVector(TypeUtil.decomposeVector(indexContext, expression.getIndexValue().duplicate())); else this.queryVector = null; this.limit = command.limits().count(); - } - - /** - * Executor to use for parallel index reads. - * Defined by -Dcassandra.index_read.parallele=true/false, true by default. - *
- * INDEX_READ uses 2 * cpus threads by default but can be overridden with -Dcassandra.index_read.parallel_thread_num= - * - * @return stage to use, default INDEX_READ - */ - private static LocalAwareExecutorService getExecutor() - { - boolean isParallel = CassandraRelevantProperties.USE_PARALLEL_INDEX_READ.getBoolean(); - - if (isParallel) - { - int numThreads = CassandraRelevantProperties.PARALLEL_INDEX_READ_NUM_THREADS.isPresent() - ? CassandraRelevantProperties.PARALLEL_INDEX_READ_NUM_THREADS.getInt() - : FBUtilities.getAvailableProcessors() * 2; - return SharedExecutorPool.SHARED.newExecutor(numThreads, maximumPoolSize -> {}, "request", "IndexParallelRead"); - } - else - return ImmediateExecutor.INSTANCE; + this.scoreColumn = ColumnMetadata.syntheticColumn(indexContext.getKeyspace(), indexContext.getTable(), ColumnMetadata.SYNTHETIC_SCORE_ID, FloatType.instance); } /** * Sort the specified filtered rows according to the {@code ORDER BY} clause and keep the first {@link #limit} rows. - * This is meant to be used on the coordinator-side to sort the rows collected from the replicas. - * Caller must close the supplied iterator. + * Called on the coordinator side. * - * @param partitions the partitions collected by the coordinator - * @return the provided rows, sorted by the requested {@code ORDER BY} chriteria and trimmed to {@link #limit} rows + * @param partitions the partitions collected by the coordinator. It will be closed as a side-effect. + * @return the provided rows, sorted and trimmed to {@link #limit} rows */ - public PartitionIterator filter(PartitionIterator partitions) + public PartitionIterator reorder(PartitionIterator partitions) { // We consume the partitions iterator and create a new one. Use a try-with-resources block to ensure the // original iterator is closed. We do not expect exceptions here, but if they happen, we want to make sure the @@ -159,17 +128,35 @@ public PartitionIterator filter(PartitionIterator partitions) try (partitions) { Comparator> comparator = comparator() - .thenComparing(Triple::getLeft, Comparator.comparing(p -> p.key)) - .thenComparing(Triple::getMiddle, command.metadata().comparator); + .thenComparing(Triple::getLeft, Comparator.comparing(pi -> pi.key)) + .thenComparing(Triple::getMiddle, command.metadata().comparator); TopKSelector> topK = new TopKSelector<>(comparator, limit); + while (partitions.hasNext()) + { + try (BaseRowIterator partitionRowIterator = partitions.next()) + { + if (expression.operator() == Operator.ANN || expression.operator() == Operator.BM25) + { + PartitionResults pr = processScoredPartition(partitionRowIterator); + topK.addAll(pr.rows); + } + else + { + while (partitionRowIterator.hasNext()) + { + Row row = (Row) partitionRowIterator.next(); + ByteBuffer value = row.getCell(expression.column()).buffer(); + topK.add(Triple.of(PartitionInfo.create(partitionRowIterator), row, value)); + } + } + } + } - processPartitions(partitions, topK, null); - + // Convert the topK results to a PartitionIterator List> sortedRows = new ArrayList<>(topK.size()); for (Triple triple : topK.getShared()) sortedRows.add(Pair.create(triple.getLeft(), triple.getMiddle())); - return InMemoryPartitionIterator.create(command, sortedRows); } } @@ -186,149 +173,53 @@ public PartitionIterator filter(PartitionIterator partitions) *
* All tombstones will be kept. Caller must close the supplied iterator.
- * @param partitions the partitions collected in the replica side of a query
+ * @param partitions the partitions collected in the replica side of a query. It will be closed as a side-effect.
 * @return the provided rows, sorted by the requested {@code ORDER BY} criteria, trimmed to {@link #limit} rows,
 * and then sorted again by primary key.
 */
- public UnfilteredPartitionIterator filter(UnfilteredPartitionIterator partitions)
+ public UnfilteredPartitionIterator filter(ScoreOrderedResultRetriever partitions)
 {
- // We consume the partitions iterator and create a new one. Use a try-with-resources block to ensure the
- // original iterator is closed. We do not expect exceptions here, but if they happen, we want to make sure the
- // original iterator is closed to prevent leaking resources, which could compound the effect of an exception.
 try (partitions)
 {
- TopKSelector> topK = new TopKSelector<>(comparator(), limit);
-
- TreeMap> unfilteredByPartition = new TreeMap<>(Comparator.comparing(p -> p.key));
-
- processPartitions(partitions, topK, unfilteredByPartition);
-
- // Reorder the rows by primary key.
- for (var triple : topK.getUnsortedShared())
- addUnfiltered(unfilteredByPartition, triple.getLeft(), triple.getMiddle());
-
- return new InMemoryUnfilteredPartitionIterator(command, unfilteredByPartition);
- }
- }
-
- private Comparator> comparator()
- {
- Comparator> comparator;
- if (queryVector != null)
- {
- comparator = Comparator.comparing((Triple t) -> (Float) t.getRight()).reversed();
- }
- else
- {
- comparator = Comparator.comparing(t -> (ByteBuffer) t.getRight(), indexContext.getValidator());
- if (expression.operator() == Operator.ORDER_BY_DESC)
- comparator = comparator.reversed();
- }
- return comparator;
- }
-
- private , P extends BasePartitionIterator>
- void processPartitions(P partitions,
- TopKSelector> topK,
- @Nullable TreeMap> unfilteredByPartition)
- {
- if (PARALLEL_EXECUTOR != ImmediateExecutor.INSTANCE && partitions instanceof ParallelCommandProcessor)
- {
- ParallelCommandProcessor pIter = (ParallelCommandProcessor) partitions;
- var commands = pIter.getUninitializedCommands();
- List> results = new ArrayList<>(commands.size());
+ TreeMap> unfilteredByPartition = new TreeMap<>(Comparator.comparing(pi -> pi.key));
- int count = commands.size();
- for (var command: commands) {
- CompletableFuture future = new CompletableFuture<>();
- results.add(future);
-
- // run last command immediately, others in parallel (if possible)
- count--;
- var executor = count == 0 ? ImmediateExecutor.INSTANCE : PARALLEL_EXECUTOR;
-
- executor.maybeExecuteImmediately(() -> {
- try (var partitionRowIterator = pIter.commandToIterator(command.left(), command.right()))
- {
- future.complete(partitionRowIterator == null ?
null : processPartition(partitionRowIterator)); - } - catch (Throwable t) - { - future.completeExceptionally(t); - } - }); - } - - for (CompletableFuture triplesFuture: results) - { - PartitionResults pr; - try - { - pr = triplesFuture.join(); - } - catch (CompletionException t) - { - if (t.getCause() instanceof AbortedOperationException) - throw (AbortedOperationException) t.getCause(); - throw t; - } - if (pr == null) - continue; - topK.addAll(pr.rows); - if (unfilteredByPartition != null) - { - for (var uf : pr.tombstones) - addUnfiltered(unfilteredByPartition, pr.partitionInfo, uf); - } - } - } - else if (partitions instanceof StorageAttachedIndexSearcher.ScoreOrderedResultRetriever) - { - // FilteredPartitions does not implement ParallelizablePartitionIterator. - // Realistically, this won't benefit from parallelizm as these are coming from in-memory/memtable data. int rowsMatched = 0; - // Check rowsMatched first to prevent fetching one more partition than needed. + // Because each “partition” from ScoreOrderedResultRetriever is actually a single row + // or tombstone, we can simply read them until we have enough. while (rowsMatched < limit && partitions.hasNext()) { - // Must close to move to the next partition, otherwise hasNext() fails - try (var partitionRowIterator = partitions.next()) + try (BaseRowIterator partitionRowIterator = partitions.next()) { rowsMatched += processSingleRowPartition(unfilteredByPartition, partitionRowIterator); } } + + return new InMemoryUnfilteredPartitionIterator(command, unfilteredByPartition); } - else + } + + /** + * Constructs a comparator for triple (PartitionInfo, Row, X) used for top-K ranking. + * For ANN/BM25 we compare descending by X (float score). For ORDER_BY_ASC or DESC, + * we compare ascending/descending by the row’s relevant ByteBuffer data. + */ + private Comparator> comparator() + { + if (expression.operator() == Operator.ANN || expression.operator() == Operator.BM25) { - // FilteredPartitions does not implement ParallelizablePartitionIterator. - // Realistically, this won't benefit from parallelizm as these are coming from in-memory/memtable data. - while (partitions.hasNext()) - { - // have to close to move to the next partition, otherwise hasNext() fails - try (var partitionRowIterator = partitions.next()) - { - if (queryVector != null) - { - PartitionResults pr = processPartition(partitionRowIterator); - topK.addAll(pr.rows); - if (unfilteredByPartition != null) - { - for (var uf : pr.tombstones) - addUnfiltered(unfilteredByPartition, pr.partitionInfo, uf); - } - } - else - { - while (partitionRowIterator.hasNext()) - { - Row row = (Row) partitionRowIterator.next(); - topK.add(Triple.of(PartitionInfo.create(partitionRowIterator), row, row.getCell(expression.column()).buffer())); - } - } - } - } + // For similarity, higher is better, so reversed + return Comparator.comparing((Triple t) -> (Float) t.getRight()).reversed(); } + + Comparator> comparator = Comparator.comparing(t -> (ByteBuffer) t.getRight(), indexContext.getValidator()); + if (expression.operator() == Operator.ORDER_BY_DESC) + comparator = comparator.reversed(); + return comparator; } + /** + * Simple holder for partial results of a single partition (score-based path). + */ private class PartitionResults { final PartitionInfo partitionInfo; @@ -352,9 +243,9 @@ void addRow(Triple triple) } /** - * Processes a single partition, calculating scores for rows and extracting tombstones. 
+ * Processes all rows in a single partition to compute scores (for ANN or BM25) */ - private PartitionResults processPartition(BaseRowIterator partitionRowIterator) + private PartitionResults processScoredPartition(BaseRowIterator partitionRowIterator) { // Compute key and static row score once per partition DecoratedKey key = partitionRowIterator.partitionKey(); @@ -400,15 +291,14 @@ private int processSingleRowPartition(TreeMap return unfiltered.isRangeTombstoneMarker() ? 0 : 1; } - private void addUnfiltered(SortedMap> unfilteredByPartition, PartitionInfo partitionInfo, Unfiltered unfiltered) + private void addUnfiltered(SortedMap> unfilteredByPartition, + PartitionInfo partitionInfo, + Unfiltered unfiltered) { var map = unfilteredByPartition.computeIfAbsent(partitionInfo, k -> new TreeSet<>(command.metadata().comparator)); map.add(unfiltered); } - /** - * Sum the scores from different vector indexes for the row - */ private float getScoreForRow(DecoratedKey key, Row row) { ColumnMetadata column = indexContext.getDefinition(); @@ -422,6 +312,15 @@ private float getScoreForRow(DecoratedKey key, Row row) if ((column.isClusteringColumn() || column.isRegular()) && row.isStatic()) return 0; + // If we have a synthetic score column, use it + var scoreData = row.getColumnData(scoreColumn); + if (scoreData != null) + { + var cell = (Cell) scoreData; + return FloatType.instance.compose(cell.buffer()); + } + + // TODO remove this once we enable ANN_USE_SYNTHETIC_SCORE ByteBuffer value = indexContext.getValueOf(key, row, FBUtilities.nowInSeconds()); if (value != null) { @@ -431,28 +330,30 @@ private float getScoreForRow(DecoratedKey key, Row row) return 0; } - private Pair findTopKIndexContext() { ColumnFamilyStore cfs = Keyspace.openAndGetStore(command.metadata()); for (RowFilter.Expression expression : command.rowFilter().expressions()) { - StorageAttachedIndex sai = findVectorIndexFor(cfs.indexManager, expression); + StorageAttachedIndex sai = findOrderingIndexFor(cfs.indexManager, expression); if (sai != null) - { return Pair.create(sai.getIndexContext(), expression); - } } return null; } @Nullable - private StorageAttachedIndex findVectorIndexFor(SecondaryIndexManager sim, RowFilter.Expression e) + private StorageAttachedIndex findOrderingIndexFor(SecondaryIndexManager sim, RowFilter.Expression e) { - if (e.operator() != Operator.ANN && e.operator() != Operator.ORDER_BY_ASC && e.operator() != Operator.ORDER_BY_DESC) + if (e.operator() != Operator.ANN + && e.operator() != Operator.BM25 + && e.operator() != Operator.ORDER_BY_ASC + && e.operator() != Operator.ORDER_BY_DESC) + { return null; + } Optional index = sim.getBestIndexFor(e); return (StorageAttachedIndex) index.filter(i -> i instanceof StorageAttachedIndex).orElse(null); diff --git a/src/java/org/apache/cassandra/index/sai/utils/BM25Utils.java b/src/java/org/apache/cassandra/index/sai/utils/BM25Utils.java new file mode 100644 index 000000000000..ae101c6f0373 --- /dev/null +++ b/src/java/org/apache/cassandra/index/sai/utils/BM25Utils.java @@ -0,0 +1,185 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.cassandra.index.sai.utils; + +import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.Collection; +import java.util.Collections; +import java.util.HashMap; +import java.util.Iterator; +import java.util.List; +import java.util.Map; + +import org.apache.cassandra.db.memtable.Memtable; +import org.apache.cassandra.db.rows.Cell; +import org.apache.cassandra.index.sai.IndexContext; +import org.apache.cassandra.index.sai.analyzer.AbstractAnalyzer; +import org.apache.cassandra.io.sstable.SSTableId; +import org.apache.cassandra.io.util.FileUtils; +import org.apache.cassandra.utils.CloseableIterator; + +public class BM25Utils +{ + private static final float K1 = 1.2f; // BM25 term frequency saturation parameter + private static final float B = 0.75f; // BM25 length normalization parameter + + /** + * Term frequencies across all documents. Each document is only counted once. + */ + public static class DocStats + { + // Map of term -> count of docs containing that term + private final Map frequencies; + // total number of docs in the index + private final long docCount; + + public DocStats(Map frequencies, long docCount) + { + this.frequencies = frequencies; + this.docCount = docCount; + } + } + + /** + * Term frequencies within a single document. All instances of a term are counted. + */ + public static class DocTF + { + private final PrimaryKey pk; + private final Map frequencies; + private final int termCount; + + public DocTF(PrimaryKey pk, int termCount, Map frequencies) + { + this.pk = pk; + this.frequencies = frequencies; + this.termCount = termCount; + } + + public int getTermFrequency(ByteBuffer term) + { + return frequencies.getOrDefault(term, 0); + } + + public static DocTF createFromDocument(PrimaryKey pk, + Cell cell, + AbstractAnalyzer docAnalyzer, + Collection queryTerms) + { + int count = 0; + Map frequencies = new HashMap<>(); + + docAnalyzer.reset(cell.buffer()); + try + { + while (docAnalyzer.hasNext()) + { + ByteBuffer term = docAnalyzer.next(); + count++; + if (queryTerms.contains(term)) + frequencies.merge(term, 1, Integer::sum); + } + } + finally + { + docAnalyzer.end(); + } + + return new DocTF(pk, count, frequencies); + } + } + + public static CloseableIterator computeScores(CloseableIterator docIterator, + List queryTerms, + DocStats docStats, + IndexContext indexContext, + Object source) + { + // data structures for document stats and frequencies + ArrayList documents = new ArrayList<>(); + double totalTermCount = 0; + + // Compute TF within each document + while (docIterator.hasNext()) + { + var tf = docIterator.next(); + documents.add(tf); + totalTermCount += tf.termCount; + } + if (documents.isEmpty()) + return CloseableIterator.emptyIterator(); + + // Calculate average document length + double avgDocLength = totalTermCount / documents.size(); + + // Calculate BM25 scores + var scoredDocs = new ArrayList(documents.size()); + for (var doc : documents) + { + double score = 0.0; + for (var queryTerm : queryTerms) + { + int tf = doc.getTermFrequency(queryTerm); + Long df = docStats.frequencies.get(queryTerm); + // we 
shouldn't see a term in more documents than the total number of documents + assert df <= docStats.docCount : String.format("df=%d, totalDocs=%d", df, docStats.docCount); + + double normalizedTf = tf / (tf + K1 * (1 - B + B * doc.termCount / avgDocLength)); + double idf = Math.log(1 + (docStats.docCount - df + 0.5) / (df + 0.5)); + double deltaScore = normalizedTf * idf; + assert deltaScore >= 0 : String.format("BM25 score for tf=%d, df=%d, tc=%d, totalDocs=%d is %f", + tf, df, doc.termCount, docStats.docCount, deltaScore); + score += deltaScore; + } + if (source instanceof Memtable) + scoredDocs.add(new PrimaryKeyWithScore(indexContext, (Memtable) source, doc.pk, (float) score)); + else if (source instanceof SSTableId) + scoredDocs.add(new PrimaryKeyWithScore(indexContext, (SSTableId) source, doc.pk, (float) score)); + else + throw new IllegalArgumentException("Invalid source " + source.getClass()); + } + + // sort by score (PrimaryKeyWithScore implements Comparable, ordering by descending score) + Collections.sort(scoredDocs); + + return new CloseableIterator<>() + { + private final Iterator iterator = scoredDocs.iterator(); + + @Override + public boolean hasNext() + { + return iterator.hasNext(); + } + + @Override + public PrimaryKeyWithSortKey next() + { + return iterator.next(); + } + + @Override + public void close() + { + FileUtils.closeQuietly(docIterator); + } + }; + } +} diff --git a/src/java/org/apache/cassandra/index/sai/utils/PrimaryKey.java b/src/java/org/apache/cassandra/index/sai/utils/PrimaryKey.java index 516e26caaeea..0bcd16c0a7f3 100644 --- a/src/java/org/apache/cassandra/index/sai/utils/PrimaryKey.java +++ b/src/java/org/apache/cassandra/index/sai/utils/PrimaryKey.java @@ -19,6 +19,7 @@ import java.util.function.Supplier; +import io.github.jbellis.jvector.util.Accountable; import org.apache.cassandra.db.Clustering; import org.apache.cassandra.db.ClusteringComparator; import org.apache.cassandra.db.DecoratedKey; @@ -38,7 +39,7 @@ * For the V2 on-disk format the {@link DecoratedKey} and {@link Clustering} are supported.
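For reference, the per-term arithmetic in computeScores above is the classic BM25 formula with k1 = 1.2 and b = 0.75. A minimal standalone sketch of the same math, using illustrative names that are not part of the patch:

// Standalone sketch (not part of the patch) of the per-term BM25 score above.
class Bm25Sketch
{
    static final float K1 = 1.2f; // term frequency saturation
    static final float B = 0.75f; // document length normalization

    // tf: occurrences of the term in the document; df: number of documents containing it
    static double scoreTerm(int tf, long df, long docCount, int docLength, double avgDocLength)
    {
        double normalizedTf = tf / (tf + K1 * (1 - B + B * docLength / avgDocLength));
        double idf = Math.log(1 + (docCount - df + 0.5) / (df + 0.5));
        return normalizedTf * idf;
    }

    public static void main(String[] args)
    {
        // Mirrors testTermFrequencyOrdering: docs 'apple', 'apple apple', 'apple apple apple'
        double avgDocLength = (1 + 2 + 3) / 3.0;
        double s1 = scoreTerm(1, 3, 3, 1, avgDocLength);
        double s2 = scoreTerm(2, 3, 3, 2, avgDocLength);
        double s3 = scoreTerm(3, 3, 3, 3, avgDocLength);
        assert s3 > s2 && s2 > s1 : "higher term frequency should score higher";
    }
}

Note that even when df equals docCount the IDF factor stays positive (log(1 + 0.5/(df + 0.5))), so documents that all contain the query term are still ranked by normalized term frequency.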
* */ -public interface PrimaryKey extends Comparable +public interface PrimaryKey extends Comparable, Accountable { /** * A factory for creating {@link PrimaryKey} instances diff --git a/src/java/org/apache/cassandra/index/sai/utils/PrimaryKeyWithByteComparable.java b/src/java/org/apache/cassandra/index/sai/utils/PrimaryKeyWithByteComparable.java index 837c032b7952..bb931c942fbf 100644 --- a/src/java/org/apache/cassandra/index/sai/utils/PrimaryKeyWithByteComparable.java +++ b/src/java/org/apache/cassandra/index/sai/utils/PrimaryKeyWithByteComparable.java @@ -21,7 +21,10 @@ import java.nio.ByteBuffer; import java.util.Arrays; +import io.github.jbellis.jvector.util.RamUsageEstimator; +import org.apache.cassandra.db.memtable.Memtable; import org.apache.cassandra.index.sai.IndexContext; +import org.apache.cassandra.io.sstable.SSTableId; import org.apache.cassandra.utils.bytecomparable.ByteComparable; import org.apache.cassandra.utils.bytecomparable.ByteSource; import org.apache.cassandra.utils.bytecomparable.ByteSourceInverse; @@ -34,7 +37,13 @@ public class PrimaryKeyWithByteComparable extends PrimaryKeyWithSortKey { private final ByteComparable byteComparable; - public PrimaryKeyWithByteComparable(IndexContext context, Object sourceTable, PrimaryKey primaryKey, ByteComparable byteComparable) + public PrimaryKeyWithByteComparable(IndexContext context, Memtable sourceTable, PrimaryKey primaryKey, ByteComparable byteComparable) + { + super(context, sourceTable, primaryKey); + this.byteComparable = byteComparable; + } + + public PrimaryKeyWithByteComparable(IndexContext context, SSTableId sourceTable, PrimaryKey primaryKey, ByteComparable byteComparable) { super(context, sourceTable, primaryKey); this.byteComparable = byteComparable; @@ -65,4 +74,10 @@ public int compareTo(PrimaryKey o) return ByteComparable.compare(byteComparable, ((PrimaryKeyWithByteComparable) o).byteComparable, TypeUtil.BYTE_COMPARABLE_VERSION); } + + @Override + public long ramBytesUsed() + { + return super.ramBytesUsed() + RamUsageEstimator.NUM_BYTES_OBJECT_REF; + } } diff --git a/src/java/org/apache/cassandra/index/sai/utils/PrimaryKeyWithScore.java b/src/java/org/apache/cassandra/index/sai/utils/PrimaryKeyWithScore.java index 8b7b4acba9c6..b88c210d65f2 100644 --- a/src/java/org/apache/cassandra/index/sai/utils/PrimaryKeyWithScore.java +++ b/src/java/org/apache/cassandra/index/sai/utils/PrimaryKeyWithScore.java @@ -20,7 +20,9 @@ import java.nio.ByteBuffer; +import org.apache.cassandra.db.memtable.Memtable; import org.apache.cassandra.index.sai.IndexContext; +import org.apache.cassandra.io.sstable.SSTableId; /** * A {@link PrimaryKey} that includes a score from a source index. 
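PrimaryKeyWithScore sorts descending by score: its compareTo (shown in the next hunk's context) passes the scores to Float.compare in reverse order. A tiny standalone sketch of that convention, not part of the patch:

// Standalone sketch: reversing Float.compare's argument order yields a descending sort.
class DescendingScoreSketch
{
    static int compare(float thisScore, float otherScore)
    {
        return Float.compare(otherScore, thisScore); // swapped -> descending
    }

    public static void main(String[] args)
    {
        assert compare(0.9f, 0.3f) < 0 : "higher score sorts first";
        assert compare(0.3f, 0.9f) > 0 : "lower score sorts last";
    }
}

BM25Utils.computeScores relies on this natural ordering when it calls Collections.sort on the scored keys, so the best-scoring primary keys come first.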
@@ -28,9 +30,15 @@ */ public class PrimaryKeyWithScore extends PrimaryKeyWithSortKey { - private final float indexScore; + public final float indexScore; - public PrimaryKeyWithScore(IndexContext context, Object source, PrimaryKey primaryKey, float indexScore) + public PrimaryKeyWithScore(IndexContext context, Memtable source, PrimaryKey primaryKey, float indexScore) + { + super(context, source, primaryKey); + this.indexScore = indexScore; + } + + public PrimaryKeyWithScore(IndexContext context, SSTableId source, PrimaryKey primaryKey, float indexScore) { super(context, source, primaryKey); this.indexScore = indexScore; @@ -53,4 +61,11 @@ public int compareTo(PrimaryKey o) // Descending order return Float.compare(((PrimaryKeyWithScore) o).indexScore, indexScore); } + + @Override + public long ramBytesUsed() + { + // Include super class fields plus float value + return super.ramBytesUsed() + Float.BYTES; + } } diff --git a/src/java/org/apache/cassandra/index/sai/utils/PrimaryKeyWithSortKey.java b/src/java/org/apache/cassandra/index/sai/utils/PrimaryKeyWithSortKey.java index 2e79b0402124..b3a6fb4338e5 100644 --- a/src/java/org/apache/cassandra/index/sai/utils/PrimaryKeyWithSortKey.java +++ b/src/java/org/apache/cassandra/index/sai/utils/PrimaryKeyWithSortKey.java @@ -20,11 +20,14 @@ import java.nio.ByteBuffer; +import io.github.jbellis.jvector.util.RamUsageEstimator; import org.apache.cassandra.db.Clustering; import org.apache.cassandra.db.DecoratedKey; +import org.apache.cassandra.db.memtable.Memtable; import org.apache.cassandra.db.rows.Row; import org.apache.cassandra.dht.Token; import org.apache.cassandra.index.sai.IndexContext; +import org.apache.cassandra.io.sstable.SSTableId; import org.apache.cassandra.utils.bytecomparable.ByteComparable; import org.apache.cassandra.utils.bytecomparable.ByteSource; @@ -41,7 +44,14 @@ public abstract class PrimaryKeyWithSortKey implements PrimaryKey // Either a Memtable reference or an SSTableId reference private final Object sourceTable; - protected PrimaryKeyWithSortKey(IndexContext context, Object sourceTable, PrimaryKey primaryKey) + protected PrimaryKeyWithSortKey(IndexContext context, Memtable sourceTable, PrimaryKey primaryKey) + { + this.context = context; + this.sourceTable = sourceTable; + this.primaryKey = primaryKey; + } + + protected PrimaryKeyWithSortKey(IndexContext context, SSTableId sourceTable, PrimaryKey primaryKey) { this.context = context; this.sourceTable = sourceTable; @@ -137,4 +147,13 @@ public ByteSource asComparableBytesMaxPrefix(ByteComparable.Version version) return primaryKey.asComparableBytesMaxPrefix(version); } + @Override + public long ramBytesUsed() + { + // Object header + 3 references (context, primaryKey, sourceTable) + return RamUsageEstimator.NUM_BYTES_OBJECT_HEADER + + 3L * RamUsageEstimator.NUM_BYTES_OBJECT_REF + + primaryKey.ramBytesUsed(); + } + } diff --git a/src/java/org/apache/cassandra/index/sai/utils/PrimaryKeys.java b/src/java/org/apache/cassandra/index/sai/utils/PrimaryKeys.java index 173dc045965f..75298ac217da 100644 --- a/src/java/org/apache/cassandra/index/sai/utils/PrimaryKeys.java +++ b/src/java/org/apache/cassandra/index/sai/utils/PrimaryKeys.java @@ -41,6 +41,7 @@ public class PrimaryKeys implements Iterable * Adds the specified {@link PrimaryKey}. 
* * @param key a primary key + * @return the bytes allocated for the key (0 if it already existed in the set) */ public long add(PrimaryKey key) { diff --git a/src/java/org/apache/cassandra/index/sai/utils/RowIdWithScore.java b/src/java/org/apache/cassandra/index/sai/utils/RowIdWithScore.java index b4b1f129567f..c6ed4708e0c8 100644 --- a/src/java/org/apache/cassandra/index/sai/utils/RowIdWithScore.java +++ b/src/java/org/apache/cassandra/index/sai/utils/RowIdWithScore.java @@ -26,7 +26,7 @@ */ public class RowIdWithScore extends RowIdWithMeta { - private final float score; + public final float score; public RowIdWithScore(int segmentRowId, float score) { diff --git a/src/java/org/apache/cassandra/index/sai/utils/RowWithSourceTable.java b/src/java/org/apache/cassandra/index/sai/utils/RowWithSourceTable.java index 3064c701fc56..d93198a94d1a 100644 --- a/src/java/org/apache/cassandra/index/sai/utils/RowWithSourceTable.java +++ b/src/java/org/apache/cassandra/index/sai/utils/RowWithSourceTable.java @@ -375,4 +375,13 @@ private Row maybeWrapRow(Row r) return this; return new RowWithSourceTable(r, source); } + + @Override + public String toString() + { + return "RowWithSourceTable{" + + row + + ", source=" + source + + '}'; + } } diff --git a/src/java/org/apache/cassandra/schema/ColumnMetadata.java b/src/java/org/apache/cassandra/schema/ColumnMetadata.java index 38784e72acb5..008813564e57 100644 --- a/src/java/org/apache/cassandra/schema/ColumnMetadata.java +++ b/src/java/org/apache/cassandra/schema/ColumnMetadata.java @@ -71,9 +71,9 @@ public enum ClusteringOrder /** * The type of CQL3 column this definition represents. - * There is 4 main type of CQL3 columns: those parts of the partition key, - * those parts of the clustering columns and amongst the others, regular and - * static ones. + * There are 5 types of CQL3 columns: those that are part of the partition key, + * those that are part of the clustering columns, and, amongst the others, the + * regular, static, and synthetic ones. * * IMPORTANT: this enum is serialized as toString() and deserialized by calling * Kind.valueOf(), so do not override toString() or rename existing values. @@ -81,18 +81,22 @@ public enum ClusteringOrder public enum Kind { // NOTE: if adding a new type, must modify comparisonOrder + SYNTHETIC, PARTITION_KEY, CLUSTERING, REGULAR, STATIC; + // it is not possible to add new Kinds after Synthetic without invasive changes to BTreeRow, which + // assumes that complex regular/static columns are the last ones public boolean isPrimaryKeyKind() { return this == PARTITION_KEY || this == CLUSTERING; } - } + public static final ColumnIdentifier SYNTHETIC_SCORE_ID = ColumnIdentifier.getInterned("+:!score", true); + /** * Whether this is a dropped column. */ @@ -121,10 +125,17 @@ public boolean isPrimaryKeyKind() */ private final long comparisonOrder; + /** + * Bit layout (from most to least significant): + * - Bits 61-63: Kind ordinal (3 bits, supporting up to 8 Kind values) + * - Bit 60: isComplex flag + * - Bits 48-59: position (12 bits, see assert) + * - Bits 0-47: name.prefixComparison (shifted right by 16) + */ private static long comparisonOrder(Kind kind, boolean isComplex, long position, ColumnIdentifier name) { assert position >= 0 && position < 1 << 12; - return (((long) kind.ordinal()) << 61) + return (((long) kind.ordinal()) << 61) | (isComplex ?
1L << 60 : 0) | (position << 48) | (name.prefixComparison >>> 16); @@ -170,6 +181,14 @@ public static ColumnMetadata staticColumn(String keyspace, String table, String return new ColumnMetadata(keyspace, table, ColumnIdentifier.getInterned(name, true), type, NO_POSITION, Kind.STATIC); } + /** + * Creates a new synthetic column metadata instance. + */ + public static ColumnMetadata syntheticColumn(String keyspace, String table, ColumnIdentifier id, AbstractType type) + { + return new ColumnMetadata(keyspace, table, id, type, NO_POSITION, Kind.SYNTHETIC); + } + /** * Rebuild the metadata for a dropped column from its recorded data. * @@ -225,6 +244,7 @@ public ColumnMetadata(String ksName, this.kind = kind; this.position = position; this.cellPathComparator = makeCellPathComparator(kind, type); + assert kind != Kind.SYNTHETIC || cellPathComparator == null; this.cellComparator = cellPathComparator == null ? ColumnData.comparator : new Comparator>() { @Override @@ -461,7 +481,7 @@ public int compareTo(ColumnMetadata other) return 0; if (comparisonOrder != other.comparisonOrder) - return Long.compare(comparisonOrder, other.comparisonOrder); + return Long.compareUnsigned(comparisonOrder, other.comparisonOrder); return this.name.compareTo(other.name); } @@ -593,6 +613,11 @@ public boolean isCounterColumn() return type.isCounter(); } + public boolean isSynthetic() + { + return kind == Kind.SYNTHETIC; + } + public Selector.Factory newSelectorFactory(TableMetadata table, AbstractType expectedType, List defs, VariableSpecifications boundNames) throws InvalidRequestException { return SimpleSelector.newFactory(this, addAndGetIndex(this, defs)); diff --git a/src/java/org/apache/cassandra/schema/TableMetadata.java b/src/java/org/apache/cassandra/schema/TableMetadata.java index ba5e1db8d84d..8885b85fdd43 100644 --- a/src/java/org/apache/cassandra/schema/TableMetadata.java +++ b/src/java/org/apache/cassandra/schema/TableMetadata.java @@ -1122,8 +1122,7 @@ public Builder addStaticColumn(ColumnIdentifier name, AbstractType type) public Builder addColumn(ColumnMetadata column) { - if (columns.containsKey(column.name.bytes)) - throw new IllegalArgumentException(); + assert !columns.containsKey(column.name.bytes) : column.name + " is already present"; switch (column.kind) { diff --git a/src/java/org/apache/cassandra/service/ClientWarn.java b/src/java/org/apache/cassandra/service/ClientWarn.java index 5a6a878681e1..38570a06d2b8 100644 --- a/src/java/org/apache/cassandra/service/ClientWarn.java +++ b/src/java/org/apache/cassandra/service/ClientWarn.java @@ -18,7 +18,9 @@ package org.apache.cassandra.service; import java.util.ArrayList; +import java.util.HashSet; import java.util.List; +import java.util.Set; import io.netty.util.concurrent.FastThreadLocal; import org.apache.cassandra.concurrent.ExecutorLocal; @@ -45,10 +47,18 @@ public void set(State value) } public void warn(String text) + { + warn(text, null); + } + + /** + * Issue the given warning if this is the first time `key` is seen. 
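+ * For example, a hypothetical caller that wants at most one such warning per captured
+ * state, no matter how many rows trigger it (keys are compared with equals()):
+ * <pre>
+ * ClientWarn.instance.warn("query exceeded the warn threshold", "warn-threshold");
+ * ClientWarn.instance.warn("query exceeded the warn threshold", "warn-threshold"); // suppressed
+ * </pre>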
+ */ + public void warn(String text, Object key) { State state = warnLocal.get(); if (state != null) - state.add(text); + state.add(text, key); } public void captureWarnings() @@ -72,11 +82,16 @@ public void resetWarnings() public static class State { private final List warnings = new ArrayList<>(); + private final Set keysAdded = new HashSet<>(); - private void add(String warning) + private void add(String warning, Object key) { if (warnings.size() < FBUtilities.MAX_UNSIGNED_SHORT) + { + if (key != null && !keysAdded.add(key)) + return; warnings.add(maybeTruncate(warning)); + } } private static String maybeTruncate(String warning) diff --git a/test/burn/org/apache/cassandra/index/sai/LongBM25Test.java b/test/burn/org/apache/cassandra/index/sai/LongBM25Test.java new file mode 100644 index 000000000000..49b5b5118540 --- /dev/null +++ b/test/burn/org/apache/cassandra/index/sai/LongBM25Test.java @@ -0,0 +1,251 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.cassandra.index.sai; + +import java.io.IOException; +import java.net.URISyntaxException; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.ForkJoinPool; +import java.util.concurrent.ThreadLocalRandom; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.stream.Collectors; +import java.util.stream.IntStream; + +import org.junit.Before; +import org.junit.Test; +import org.slf4j.Logger; + +import org.apache.cassandra.db.memtable.TrieMemtable; + +public class LongBM25Test extends SAITester +{ + private static final Logger logger = org.slf4j.LoggerFactory.getLogger(LongBM25Test.class); + + private static final List documentLines = new ArrayList<>(); + + static + { + try + { + var cl = LongBM25Test.class.getClassLoader(); + var resourceDir = cl.getResource("bm25"); + if (resourceDir == null) + throw new RuntimeException("Could not find resource directory test/resources/bm25/"); + + var dirPath = java.nio.file.Paths.get(resourceDir.toURI()); + try (var files = java.nio.file.Files.list(dirPath)) + { + files.forEach(file -> { + try (var lines = java.nio.file.Files.lines(file)) + { + lines.map(String::trim) + .filter(line -> !line.isEmpty()) + .forEach(documentLines::add); + } + catch (IOException e) + { + throw new RuntimeException("Failed to read file: " + file, e); + } + }); + } + if (documentLines.isEmpty()) + { + throw new RuntimeException("No document lines loaded from test/resources/bm25/"); + } + } + catch (IOException | URISyntaxException e) + { + throw new RuntimeException("Failed to load test documents", e); + } + } + + KeySet keysInserted = new KeySet(); + private 
final int threadCount = 12; + + @Before + public void setup() throws Throwable + { + // we don't get loaded until after TM, so we can't affect the very first memtable, + // but this will affect all subsequent ones + TrieMemtable.SHARD_COUNT = 4 * threadCount; + } + + @FunctionalInterface + private interface Op + { + void run(int i) throws Throwable; + } + + public void testConcurrentOps(Op op) throws ExecutionException, InterruptedException + { + createTable("CREATE TABLE %s (key int primary key, value text)"); + // Create analyzed index following BM25Test pattern + createIndex("CREATE CUSTOM INDEX ON %s(value) " + + "USING 'org.apache.cassandra.index.sai.StorageAttachedIndex' " + + "WITH OPTIONS = {" + + "'index_analyzer': '{" + + "\"tokenizer\" : {\"name\" : \"standard\"}, " + + "\"filters\" : [{\"name\" : \"porterstem\"}]" + + "}'}" + ); + + AtomicInteger counter = new AtomicInteger(); + long start = System.currentTimeMillis(); + var fjp = new ForkJoinPool(threadCount); + var keys = IntStream.range(0, 10_000_000).boxed().collect(Collectors.toList()); + Collections.shuffle(keys); + var task = fjp.submit(() -> keys.stream().parallel().forEach(i -> + { + wrappedOp(op, i); + if (counter.incrementAndGet() % 10_000 == 0) + { + var elapsed = System.currentTimeMillis() - start; + logger.info("{} ops in {}ms = {} ops/s", counter.get(), elapsed, counter.get() * 1000.0 / elapsed); + } + if (ThreadLocalRandom.current().nextDouble() < 0.001) + flush(); + })); + fjp.shutdown(); + task.get(); // re-throw + } + + private static void wrappedOp(Op op, Integer i) + { + try + { + op.run(i); + } + catch (Throwable e) + { + throw new RuntimeException(e); + } + } + + private static String randomDocument() + { + var R = ThreadLocalRandom.current(); + int numLines = R.nextInt(5, 51); // 5 to 50 lines inclusive + var selectedLines = new ArrayList(); + + for (int i = 0; i < numLines; i++) + { + selectedLines.add(randomQuery(R)); + } + + return String.join("\n", selectedLines); + } + + private static String randomLine(ThreadLocalRandom R) + { + return documentLines.get(R.nextInt(documentLines.size())); + } + + @Test + public void testConcurrentReadsWritesDeletes() throws ExecutionException, InterruptedException + { + testConcurrentOps(i -> { + var R = ThreadLocalRandom.current(); + if (R.nextDouble() < 0.2 || keysInserted.isEmpty()) + { + var doc = randomDocument(); + execute("INSERT INTO %s (key, value) VALUES (?, ?)", i, doc); + keysInserted.add(i); + } + else if (R.nextDouble() < 0.1) + { + var key = keysInserted.getRandom(); + execute("DELETE FROM %s WHERE key = ?", key); + } + else + { + var line = randomQuery(R); + execute("SELECT * FROM %s ORDER BY value BM25 OF ? LIMIT ?", line, R.nextInt(1, 100)); + } + }); + } + + private static String randomQuery(ThreadLocalRandom R) + { + while (true) + { + var line = randomLine(R); + if (line.chars().anyMatch(Character::isAlphabetic)) + return line; + } + } + + @Test + public void testConcurrentReadsWrites() throws ExecutionException, InterruptedException + { + testConcurrentOps(i -> { + var R = ThreadLocalRandom.current(); + if (R.nextDouble() < 0.1 || keysInserted.isEmpty()) + { + var doc = randomDocument(); + execute("INSERT INTO %s (key, value) VALUES (?, ?)", i, doc); + keysInserted.add(i); + } + else + { + var line = randomQuery(R); + execute("SELECT * FROM %s ORDER BY value BM25 OF ? 
LIMIT ?", line, R.nextInt(1, 100)); + } + }); + } + + @Test + public void testConcurrentWrites() throws ExecutionException, InterruptedException + { + testConcurrentOps(i -> { + var doc = randomDocument(); + execute("INSERT INTO %s (key, value) VALUES (?, ?)", i, doc); + }); + } + + private static class KeySet + { + private final Map keys = new ConcurrentHashMap<>(); + private final AtomicInteger ordinal = new AtomicInteger(); + + public void add(int key) + { + var i = ordinal.getAndIncrement(); + keys.put(i, key); + } + + public int getRandom() + { + if (isEmpty()) + throw new IllegalStateException(); + var i = ThreadLocalRandom.current().nextInt(ordinal.get()); + // in case there is race with add(key), retry another random + return keys.containsKey(i) ? keys.get(i) : getRandom(); + } + + public boolean isEmpty() + { + return keys.isEmpty(); + } + } +} diff --git a/test/distributed/org/apache/cassandra/distributed/test/sai/BM25DistributedTest.java b/test/distributed/org/apache/cassandra/distributed/test/sai/BM25DistributedTest.java new file mode 100644 index 000000000000..95c7587ea044 --- /dev/null +++ b/test/distributed/org/apache/cassandra/distributed/test/sai/BM25DistributedTest.java @@ -0,0 +1,123 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.cassandra.distributed.test.sai; + +import java.util.concurrent.atomic.AtomicInteger; + +import org.junit.AfterClass; +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.Test; + +import org.apache.cassandra.distributed.Cluster; +import org.apache.cassandra.distributed.api.ConsistencyLevel; +import org.apache.cassandra.distributed.test.TestBaseImpl; +import org.apache.cassandra.index.sai.disk.format.Version; + +import static org.apache.cassandra.distributed.api.Feature.GOSSIP; +import static org.apache.cassandra.distributed.api.Feature.NETWORK; +import static org.assertj.core.api.Assertions.assertThat; + +public class BM25DistributedTest extends TestBaseImpl +{ + private static final String CREATE_KEYSPACE = "CREATE KEYSPACE %%s WITH replication = {'class': 'SimpleStrategy', 'replication_factor': %d}"; + private static final String CREATE_TABLE = "CREATE TABLE %s (k int PRIMARY KEY, v text)"; + private static final String CREATE_INDEX = "CREATE CUSTOM INDEX ON %s(v) USING 'StorageAttachedIndex' WITH OPTIONS = {'index_analyzer': '{\"tokenizer\" : {\"name\" : \"standard\"}, \"filters\" : [{\"name\" : \"porterstem\"}]}'}"; + + // To get consistent results from BM25 we need to know which docs are evaluated; the easiest way + // to do that is to put all the docs on every replica + private static final int NUM_NODES = 3; + private static final int RF = 3; + + private static Cluster cluster; + private static String table; + + private static final AtomicInteger seq = new AtomicInteger(); + + @BeforeClass + public static void setupCluster() throws Exception + { + cluster = Cluster.build(NUM_NODES) + .withTokenCount(1) + .withDataDirCount(1) + .withConfig(config -> config.with(GOSSIP).with(NETWORK)) + .start(); + + cluster.schemaChange(withKeyspace(String.format(CREATE_KEYSPACE, RF))); + cluster.forEach(i -> i.runOnInstance(() -> org.apache.cassandra.index.sai.SAIUtil.setLatestVersion(Version.EC))); + } + + @AfterClass + public static void closeCluster() + { + if (cluster != null) + cluster.close(); + } + + @Before + public void before() + { + table = "table_" + seq.getAndIncrement(); + cluster.schemaChange(formatQuery(CREATE_TABLE)); + cluster.schemaChange(formatQuery(CREATE_INDEX)); + SAIUtil.waitForIndexQueryable(cluster, KEYSPACE); + } + + @Test + public void testTermFrequencyOrdering() + { + // Insert documents with varying frequencies of the term "apple" + execute("INSERT INTO %s (k, v) VALUES (1, 'apple')"); + execute("INSERT INTO %s (k, v) VALUES (2, 'apple apple')"); + execute("INSERT INTO %s (k, v) VALUES (3, 'apple apple apple')"); + + // Query memtable index + assertBM25Ordering(); + + // Flush and query on-disk index + cluster.forEach(n -> n.flush(KEYSPACE)); + assertBM25Ordering(); + } + + private void assertBM25Ordering() + { + Object[][] result = execute("SELECT k FROM %s ORDER BY v BM25 OF 'apple' LIMIT 3"); + assertThat(result).hasNumberOfRows(3); + + // Results should be ordered by term frequency (highest to lowest) + assertThat((Integer) result[0][0]).isEqualTo(3); // 3 occurrences + assertThat((Integer) result[1][0]).isEqualTo(2); // 2 occurrences + assertThat((Integer) result[2][0]).isEqualTo(1); // 1 occurrence + } + + private static Object[][] execute(String query) + { + return execute(query, ConsistencyLevel.QUORUM); + } + + private static Object[][] execute(String query, ConsistencyLevel consistencyLevel) + { + return cluster.coordinator(1).execute(formatQuery(query), consistencyLevel); + } + + private static String
formatQuery(String query) + { + return String.format(query, KEYSPACE + '.' + table); + } +} diff --git a/test/unit/org/apache/cassandra/cql3/CQLTester.java b/test/unit/org/apache/cassandra/cql3/CQLTester.java index 38b937041caf..aec7238c9348 100644 --- a/test/unit/org/apache/cassandra/cql3/CQLTester.java +++ b/test/unit/org/apache/cassandra/cql3/CQLTester.java @@ -768,6 +768,11 @@ protected String currentIndex() return indexes.get(indexes.size() - 1); } + protected String getIndex(int i) + { + return indexes.get(i); + } + protected Collection currentTables() { if (tables == null || tables.isEmpty()) diff --git a/test/unit/org/apache/cassandra/cql3/validation/operations/SelectOrderByTest.java b/test/unit/org/apache/cassandra/cql3/validation/operations/SelectOrderByTest.java index dac86c76e5d8..34f7d606ce55 100644 --- a/test/unit/org/apache/cassandra/cql3/validation/operations/SelectOrderByTest.java +++ b/test/unit/org/apache/cassandra/cql3/validation/operations/SelectOrderByTest.java @@ -659,7 +659,7 @@ public void testAllowSkippingEqualityAndSingleValueInRestrictedClusteringColumns assertInvalidMessage("Cannot combine clustering column ordering with non-clustering column ordering", "SELECT * FROM %s WHERE a=? ORDER BY b ASC, c ASC, d ASC", 0); - String errorMsg = "Order by currently only supports the ordering of columns following their declared order in the PRIMARY KEY"; + String errorMsg = "Ordering by clustered columns must follow the declared order in the PRIMARY KEY"; assertRows(execute("SELECT * FROM %s WHERE a=? AND b=? ORDER BY c", 0, 0), row(0, 0, 0, 0), diff --git a/test/unit/org/apache/cassandra/index/sai/cql/BM25Test.java b/test/unit/org/apache/cassandra/index/sai/cql/BM25Test.java new file mode 100644 index 000000000000..3a40e075114a --- /dev/null +++ b/test/unit/org/apache/cassandra/index/sai/cql/BM25Test.java @@ -0,0 +1,510 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.cassandra.index.sai.cql; + +import org.junit.Before; +import org.junit.Test; + +import org.apache.cassandra.index.sai.SAITester; +import org.apache.cassandra.index.sai.SAIUtil; +import org.apache.cassandra.index.sai.disk.format.Version; +import org.apache.cassandra.index.sai.disk.v1.SegmentBuilder; +import org.apache.cassandra.index.sai.plan.QueryController; + +import static org.apache.cassandra.index.sai.analyzer.AnalyzerEqOperatorSupport.EQ_AMBIGUOUS_ERROR; +import static org.assertj.core.api.AssertionsForInterfaceTypes.assertThat; + +public class BM25Test extends SAITester +{ + @Before + public void setup() throws Throwable + { + SAIUtil.setLatestVersion(Version.EC); + } + + @Test + public void testTwoIndexes() + { + // create un-analyzed index + createTable("CREATE TABLE %s (k int PRIMARY KEY, v text)"); + createIndex("CREATE CUSTOM INDEX ON %s(v) USING 'org.apache.cassandra.index.sai.StorageAttachedIndex'"); + execute("INSERT INTO %s (k, v) VALUES (1, 'apple')"); + + // BM25 should fail with only an equality index + assertInvalidMessage("BM25 ordering on column v requires an analyzed index", + "SELECT k FROM %s WHERE v : 'apple' ORDER BY v BM25 OF 'apple' LIMIT 3"); + + // create analyzed index + analyzeIndex(); + // BM25 query should work now + var result = execute("SELECT k FROM %s WHERE v : 'apple' ORDER BY v BM25 OF 'apple' LIMIT 3"); + assertRows(result, row(1)); + } + + @Test + public void testTwoIndexesAmbiguousPredicate() throws Throwable + { + createTable("CREATE TABLE %s (k int PRIMARY KEY, v text)"); + + // Create analyzed and un-analyzed indexes + analyzeIndex(); + createIndex("CREATE CUSTOM INDEX ON %s(v) USING 'org.apache.cassandra.index.sai.StorageAttachedIndex'"); + + execute("INSERT INTO %s (k, v) VALUES (1, 'apple')"); + execute("INSERT INTO %s (k, v) VALUES (2, 'apple juice')"); + execute("INSERT INTO %s (k, v) VALUES (3, 'orange juice')"); + + // equality predicate is ambiguous (both analyzed and un-analyzed indexes could support it) so it should + // be rejected + beforeAndAfterFlush(() -> { + // Single predicate + assertInvalidMessage(String.format(EQ_AMBIGUOUS_ERROR, "v", getIndex(0), getIndex(1)), + "SELECT k FROM %s WHERE v = 'apple'"); + + // AND + assertInvalidMessage(String.format(EQ_AMBIGUOUS_ERROR, "v", getIndex(0), getIndex(1)), + "SELECT k FROM %s WHERE v = 'apple' AND v : 'juice'"); + + // OR + assertInvalidMessage(String.format(EQ_AMBIGUOUS_ERROR, "v", getIndex(0), getIndex(1)), + "SELECT k FROM %s WHERE v = 'apple' OR v : 'juice'"); + }); + } + + @Test + public void testTwoIndexesWithEqualsUnsupported() throws Throwable + { + createTable("CREATE TABLE %s (k int PRIMARY KEY, v text)"); + createIndex("CREATE CUSTOM INDEX ON %s(v) USING 'org.apache.cassandra.index.sai.StorageAttachedIndex'"); + // analyzed index with equals_behavior:unsupported option + createIndex("CREATE CUSTOM INDEX ON %s(v) USING 'org.apache.cassandra.index.sai.StorageAttachedIndex' " + + "WITH OPTIONS = { 'equals_behaviour_when_analyzed': 'unsupported', " + + "'index_analyzer':'{\"tokenizer\":{\"name\":\"standard\"},\"filters\":[{\"name\":\"porterstem\"}]}' }"); + + execute("INSERT INTO %s (k, v) VALUES (1, 'apple')"); + execute("INSERT INTO %s (k, v) VALUES (2, 'apple juice')"); + + beforeAndAfterFlush(() -> { + // combining two EQ predicates is not allowed + assertInvalid("SELECT k FROM %s WHERE v = 'apple' AND v = 'juice'"); + + // combining EQ and MATCH predicates is also not allowed (when we're not converting EQ to MATCH) + assertInvalid("SELECT k 
FROM %s WHERE v = 'apple' AND v : 'apple'"); + + // combining two MATCH predicates is fine + assertRows(execute("SELECT k FROM %s WHERE v : 'apple' AND v : 'juice'"), + row(2)); + + // = operator should use un-analyzed index since equals is unsupported in analyzed index + assertRows(execute("SELECT k FROM %s WHERE v = 'apple'"), + row(1)); + + // : operator should use analyzed index + assertRows(execute("SELECT k FROM %s WHERE v : 'apple'"), + row(1), row(2)); + }); + } + + @Test + public void testComplexQueriesWithMultipleIndexes() throws Throwable + { + createTable("CREATE TABLE %s (k int PRIMARY KEY, v1 text, v2 text, v3 int)"); + + // Create mix of analyzed, unanalyzed, and non-text indexes + createIndex("CREATE CUSTOM INDEX ON %s(v1) USING 'org.apache.cassandra.index.sai.StorageAttachedIndex'"); + createIndex("CREATE CUSTOM INDEX ON %s(v2) " + + "USING 'org.apache.cassandra.index.sai.StorageAttachedIndex' " + + "WITH OPTIONS = {" + + "'index_analyzer': '{" + + "\"tokenizer\" : {\"name\" : \"standard\"}, " + + "\"filters\" : [{\"name\" : \"porterstem\"}]" + + "}'" + + "}"); + createIndex("CREATE CUSTOM INDEX ON %s(v3) USING 'org.apache.cassandra.index.sai.StorageAttachedIndex'"); + + execute("INSERT INTO %s (k, v1, v2, v3) VALUES (1, 'apple', 'orange juice', 5)"); + execute("INSERT INTO %s (k, v1, v2, v3) VALUES (2, 'apple juice', 'apple', 10)"); + execute("INSERT INTO %s (k, v1, v2, v3) VALUES (3, 'banana', 'grape juice', 5)"); + + beforeAndAfterFlush(() -> { + // Complex query mixing different types of indexes and operators + assertRows(execute("SELECT k FROM %s WHERE v1 = 'apple' AND v2 : 'juice' AND v3 = 5"), + row(1)); + + // Mix of AND and OR conditions across different index types + assertRows(execute("SELECT k FROM %s WHERE v3 = 5 AND (v1 = 'apple' OR v2 : 'apple')"), + row(1)); + + // Multi-term analyzed query + assertRows(execute("SELECT k FROM %s WHERE v2 : 'orange juice'"), + row(1)); + + // Range query with text match + assertRows(execute("SELECT k FROM %s WHERE v3 >= 5 AND v2 : 'juice'"), + row(1), row(3)); + }); + } + + @Test + public void testMatchingAllowed() throws Throwable + { + // match operator should be allowed with BM25 on the same column + // (seems obvious but exercises a corner case in the internal RestrictionSet processing) + createSimpleTable(); + + execute("INSERT INTO %s (k, v) VALUES (1, 'apple')"); + + beforeAndAfterFlush(() -> + { + var result = execute("SELECT k FROM %s WHERE v : 'apple' ORDER BY v BM25 OF 'apple' LIMIT 3"); + assertRows(result, row(1)); + }); + } + + @Test + public void testUnknownQueryTerm() throws Throwable + { + createSimpleTable(); + + execute("INSERT INTO %s (k, v) VALUES (1, 'apple')"); + + beforeAndAfterFlush(() -> + { + var result = execute("SELECT k FROM %s ORDER BY v BM25 OF 'orange' LIMIT 1"); + assertEmpty(result); + }); + } + + @Test + public void testDuplicateQueryTerm() throws Throwable + { + createSimpleTable(); + + execute("INSERT INTO %s (k, v) VALUES (1, 'apple')"); + + beforeAndAfterFlush(() -> + { + var result = execute("SELECT k FROM %s ORDER BY v BM25 OF 'apple apple' LIMIT 1"); + assertRows(result, row(1)); + }); + } + + @Test + public void testEmptyQuery() throws Throwable + { + createSimpleTable(); + + execute("INSERT INTO %s (k, v) VALUES (1, 'apple')"); + + beforeAndAfterFlush(() -> + { + assertInvalidMessage("BM25 query must contain at least one term (perhaps your analyzer is discarding tokens you didn't expect)", + "SELECT k FROM %s ORDER BY v BM25 OF '+' LIMIT 1"); + }); + } + + @Test + public void 
testTermFrequencyOrdering() throws Throwable + { + createSimpleTable(); + + // Insert documents with varying frequencies of the term "apple" + execute("INSERT INTO %s (k, v) VALUES (1, 'apple')"); + execute("INSERT INTO %s (k, v) VALUES (2, 'apple apple')"); + execute("INSERT INTO %s (k, v) VALUES (3, 'apple apple apple')"); + + beforeAndAfterFlush(() -> + { + // Results should be ordered by term frequency (highest to lowest) + var result = execute("SELECT k FROM %s ORDER BY v BM25 OF 'apple' LIMIT 3"); + assertRows(result, + row(3), // 3 occurrences + row(2), // 2 occurrences + row(1)); // 1 occurrence + }); + } + + @Test + public void testTermFrequenciesWithOverwrites() throws Throwable + { + createSimpleTable(); + + // Insert documents with varying frequencies of the term "apple", but overwrite the first row + // This exercises the code that is supposed to reset frequency counts for overwrites + execute("INSERT INTO %s (k, v) VALUES (1, 'apple')"); + execute("INSERT INTO %s (k, v) VALUES (1, 'apple')"); + execute("INSERT INTO %s (k, v) VALUES (1, 'apple')"); + execute("INSERT INTO %s (k, v) VALUES (1, 'apple')"); + execute("INSERT INTO %s (k, v) VALUES (2, 'apple apple')"); + execute("INSERT INTO %s (k, v) VALUES (3, 'apple apple apple')"); + + beforeAndAfterFlush(() -> + { + // Results should be ordered by term frequency (highest to lowest) + var result = execute("SELECT k FROM %s ORDER BY v BM25 OF 'apple' LIMIT 3"); + assertRows(result, + row(3), // 3 occurrences + row(2), // 2 occurrences + row(1)); // 1 occurrence + }); + } + + @Test + public void testDocumentLength() throws Throwable + { + createSimpleTable(); + // Create documents with same term frequency but different lengths + execute("INSERT INTO %s (k, v) VALUES (1, 'test test')"); + execute("INSERT INTO %s (k, v) VALUES (2, 'test test other words here to make it longer')"); + execute("INSERT INTO %s (k, v) VALUES (3, 'test test extremely long document with many additional words to significantly increase the document length while maintaining the same term frequency for our target term')"); + + beforeAndAfterFlush(() -> + { + // Documents with same term frequency should be ordered by length (shorter first) + var result = execute("SELECT k FROM %s ORDER BY v BM25 OF 'test' LIMIT 3"); + assertRows(result, + row(1), + row(2), + row(3)); + }); + } + + @Test + public void testMultiTermQueryScoring() throws Throwable + { + createSimpleTable(); + // Two terms, but "apple" appears in fewer documents + execute("INSERT INTO %s (k, v) VALUES (1, 'apple banana')"); + execute("INSERT INTO %s (k, v) VALUES (2, 'apple apple banana')"); + execute("INSERT INTO %s (k, v) VALUES (3, 'apple banana banana')"); + execute("INSERT INTO %s (k, v) VALUES (4, 'apple apple banana banana')"); + execute("INSERT INTO %s (k, v) VALUES (5, 'banana banana')"); + + beforeAndAfterFlush(() -> + { + var result = execute("SELECT k FROM %s ORDER BY v BM25 OF 'apple banana' LIMIT 4"); + assertRows(result, + row(2), // Highest frequency of most important term + row(4), // More mentions of both terms + row(1), // One of each term + row(3)); // Low frequency of most important term + }); + } + + @Test + public void testIrrelevantRowsScoring() throws Throwable + { + createSimpleTable(); + // Insert pizza reviews with varying relevance to "crispy crust" + execute("INSERT INTO %s (k, v) VALUES (1, 'The pizza had a crispy crust and was delicious')"); // Basic mention + execute("INSERT INTO %s (k, v) VALUES (2, 'Very crispy crispy crust, perfectly cooked')"); // Emphasized 
crispy + execute("INSERT INTO %s (k, v) VALUES (3, 'The crust crust crust was okay, nothing special')"); // Only crust mentions + execute("INSERT INTO %s (k, v) VALUES (4, 'Super crispy crispy crust crust, best pizza ever!')"); // Most mentions of both + execute("INSERT INTO %s (k, v) VALUES (5, 'The toppings were good but the pizza was soggy')"); // Irrelevant review + + beforeAndAfterFlush(this::assertIrrelevantRowsCorrect); + } + + private void assertIrrelevantRowsCorrect() + { + var result = execute("SELECT k FROM %s ORDER BY v BM25 OF 'crispy crust' LIMIT 5"); + assertRows(result, + row(4), // Highest frequency of both terms + row(2), // High frequency of 'crispy', one 'crust' + row(1)); // One mention of each term + // Rows 3 and 5 are omitted: they do not contain all query terms + } + + @Test + public void testIrrelevantRowsWithCompaction() + { + // same dataset as testIrrelevantRowsScoring, but split across two sstables + createSimpleTable(); + disableCompaction(); + + execute("INSERT INTO %s (k, v) VALUES (1, 'The pizza had a crispy crust and was delicious')"); // Basic mention + execute("INSERT INTO %s (k, v) VALUES (2, 'Very crispy crispy crust, perfectly cooked')"); // Emphasized crispy + flush(); + + execute("INSERT INTO %s (k, v) VALUES (3, 'The crust crust crust was okay, nothing special')"); // Only crust mentions + execute("INSERT INTO %s (k, v) VALUES (4, 'Super crispy crispy crust crust, best pizza ever!')"); // Most mentions of both + execute("INSERT INTO %s (k, v) VALUES (5, 'The toppings were good but the pizza was soggy')"); // Irrelevant review + flush(); + + assertIrrelevantRowsCorrect(); + + compact(); + assertIrrelevantRowsCorrect(); + + // Force segmentation and requery + SegmentBuilder.updateLastValidSegmentRowId(2); + compact(); + assertIrrelevantRowsCorrect(); + } + + private void createSimpleTable() + { + createTable("CREATE TABLE %s (k int PRIMARY KEY, v text)"); + analyzeIndex(); + } + + private String analyzeIndex() + { + return createIndex("CREATE CUSTOM INDEX ON %s(v) " + + "USING 'org.apache.cassandra.index.sai.StorageAttachedIndex' " + + "WITH OPTIONS = {" + + "'index_analyzer': '{" + + "\"tokenizer\" : {\"name\" : \"standard\"}, " + + "\"filters\" : [{\"name\" : \"porterstem\"}]" + + "}'}" + ); + } + + @Test + public void testWithPredicate() throws Throwable + { + createTable("CREATE TABLE %s (k int PRIMARY KEY, p int, v text)"); + analyzeIndex(); + execute("CREATE CUSTOM INDEX ON %s(p) USING 'StorageAttachedIndex'"); + + // Insert documents with varying frequencies of the term "apple" + execute("INSERT INTO %s (k, p, v) VALUES (1, 5, 'apple')"); + execute("INSERT INTO %s (k, p, v) VALUES (2, 5, 'apple apple')"); + execute("INSERT INTO %s (k, p, v) VALUES (3, 5, 'apple apple apple')"); + execute("INSERT INTO %s (k, p, v) VALUES (4, 6, 'apple apple apple')"); + execute("INSERT INTO %s (k, p, v) VALUES (5, 7, 'apple apple apple')"); + + beforeAndAfterFlush(() -> + { + // Results should be ordered by term frequency (highest to lowest) + var result = execute("SELECT k FROM %s WHERE p = 5 ORDER BY v BM25 OF 'apple' LIMIT 3"); + assertRows(result, + row(3), // 3 occurrences + row(2), // 2 occurrences + row(1)); // 1 occurrence + }); + } + + @Test + public void testWidePartition() throws Throwable + { + createTable("CREATE TABLE %s (k1 int, k2 int, v text, PRIMARY KEY (k1, k2))"); + analyzeIndex(); + + // Insert documents with varying frequencies of the term "apple" + execute("INSERT INTO %s (k1, k2, v) VALUES (0, 1, 'apple')"); + execute("INSERT INTO %s (k1, k2, v) VALUES (0, 2, 
'apple apple')"); + execute("INSERT INTO %s (k1, k2, v) VALUES (0, 3, 'apple apple apple')"); + + beforeAndAfterFlush(() -> + { + // Results should be ordered by term frequency (highest to lowest) + var result = execute("SELECT k2 FROM %s ORDER BY v BM25 OF 'apple' LIMIT 3"); + assertRows(result, + row(3), // 3 occurrences + row(2), // 2 occurrences + row(1)); // 1 occurrence + }); + } + + @Test + public void testWidePartitionWithPkPredicate() throws Throwable + { + createTable("CREATE TABLE %s (k1 int, k2 int, v text, PRIMARY KEY (k1, k2))"); + analyzeIndex(); + + // Insert documents with varying frequencies of the term "apple" + execute("INSERT INTO %s (k1, k2, v) VALUES (0, 1, 'apple')"); + execute("INSERT INTO %s (k1, k2, v) VALUES (0, 2, 'apple apple')"); + execute("INSERT INTO %s (k1, k2, v) VALUES (0, 3, 'apple apple apple')"); + execute("INSERT INTO %s (k1, k2, v) VALUES (1, 3, 'apple apple apple')"); + execute("INSERT INTO %s (k1, k2, v) VALUES (2, 3, 'apple apple apple')"); + + beforeAndAfterFlush(() -> + { + // Results should be ordered by term frequency (highest to lowest) + var result = execute("SELECT k2 FROM %s WHERE k1 = 0 ORDER BY v BM25 OF 'apple' LIMIT 3"); + assertRows(result, + row(3), // 3 occurrences + row(2), // 2 occurrences + row(1)); // 1 occurrence + }); + } + + @Test + public void testWidePartitionWithPredicate() throws Throwable + { + createTable("CREATE TABLE %s (k1 int, k2 int, p int, v text, PRIMARY KEY (k1, k2))"); + analyzeIndex(); + execute("CREATE CUSTOM INDEX ON %s(p) USING 'StorageAttachedIndex'"); + + // Insert documents with varying frequencies of the term "apple" + execute("INSERT INTO %s (k1, k2, p, v) VALUES (0, 1, 5, 'apple')"); + execute("INSERT INTO %s (k1, k2, p, v) VALUES (0, 2, 5, 'apple apple')"); + execute("INSERT INTO %s (k1, k2, p, v) VALUES (0, 3, 5, 'apple apple apple')"); + execute("INSERT INTO %s (k1, k2, p, v) VALUES (0, 4, 6, 'apple apple apple')"); + execute("INSERT INTO %s (k1, k2, p, v) VALUES (0, 5, 7, 'apple apple apple')"); + + beforeAndAfterFlush(() -> + { + // Results should be ordered by term frequency (highest to lowest) + var result = execute("SELECT k2 FROM %s WHERE p = 5 ORDER BY v BM25 OF 'apple' LIMIT 3"); + assertRows(result, + row(3), // 3 occurrences + row(2), // 2 occurrences + row(1)); // 1 occurrence + }); + } + + @Test + public void testWithPredicateSearchThenOrder() throws Throwable + { + QueryController.QUERY_OPT_LEVEL = 0; + testWithPredicate(); + } + + @Test + public void testWidePartitionWithPredicateOrderThenSearch() throws Throwable + { + QueryController.QUERY_OPT_LEVEL = 1; + testWidePartitionWithPredicate(); + } + + @Test + public void testQueryWithNulls() throws Throwable + { + createSimpleTable(); + + execute("INSERT INTO %s (k, v) VALUES (0, null)"); + execute("INSERT INTO %s (k, v) VALUES (1, 'test document')"); + beforeAndAfterFlush(() -> + { + var result = execute("SELECT k FROM %s ORDER BY v BM25 OF 'test' LIMIT 1"); + assertRows(result, row(1)); + }); + } + + @Test + public void testQueryEmptyTable() + { + createSimpleTable(); + var result = execute("SELECT k FROM %s ORDER BY v BM25 OF 'test' LIMIT 1"); + assertThat(result).hasSize(0); + } +} diff --git a/test/unit/org/apache/cassandra/index/sai/cql/MultipleColumnIndexTest.java b/test/unit/org/apache/cassandra/index/sai/cql/MultipleColumnIndexTest.java index 5364717a8906..8d952dc1f4f0 100644 --- a/test/unit/org/apache/cassandra/index/sai/cql/MultipleColumnIndexTest.java +++ 
b/test/unit/org/apache/cassandra/index/sai/cql/MultipleColumnIndexTest.java @@ -40,12 +40,30 @@ public void canCreateMultipleMapIndexesOnSameColumn() throws Throwable } @Test - public void cannotHaveMultipleLiteralIndexesWithDifferentOptions() throws Throwable + public void canHaveAnalyzedAndUnanalyzedIndexesOnSameColumn() throws Throwable { - createTable("CREATE TABLE %s (pk int, ck int, value text, PRIMARY KEY(pk, ck))"); + createTable("CREATE TABLE %s (pk int, value text, PRIMARY KEY(pk))"); createIndex("CREATE CUSTOM INDEX ON %s(value) USING 'StorageAttachedIndex' WITH OPTIONS = { 'case_sensitive' : true }"); - assertThatThrownBy(() -> createIndex("CREATE CUSTOM INDEX ON %s(value) USING 'StorageAttachedIndex' WITH OPTIONS = { 'case_sensitive' : false }")) - .isInstanceOf(InvalidRequestException.class); + createIndex("CREATE CUSTOM INDEX ON %s(value) USING 'StorageAttachedIndex' WITH OPTIONS = { 'case_sensitive' : false, 'equals_behaviour_when_analyzed': 'unsupported' }"); + + execute("INSERT INTO %s (pk, value) VALUES (?, ?)", 1, "a"); + execute("INSERT INTO %s (pk, value) VALUES (?, ?)", 2, "A"); + beforeAndAfterFlush(() -> { + assertRows(execute("SELECT pk FROM %s WHERE value = 'a'"), + row(1)); + assertRows(execute("SELECT pk FROM %s WHERE value : 'a'"), + row(1), + row(2)); + }); + } + + @Test + public void cannotHaveMultipleAnalyzingIndexesOnSameColumn() throws Throwable + { + createTable("CREATE TABLE %s (pk int, ck int, value text, PRIMARY KEY(pk, ck))"); + createIndex("CREATE CUSTOM INDEX ON %s(value) USING 'StorageAttachedIndex' WITH OPTIONS = { 'case_sensitive' : false }"); + assertThatThrownBy(() -> createIndex("CREATE CUSTOM INDEX ON %s(value) USING 'StorageAttachedIndex' WITH OPTIONS = { 'normalize' : true }")) + .isInstanceOf(InvalidRequestException.class); } @Test diff --git a/test/unit/org/apache/cassandra/index/sai/cql/NativeIndexDDLTest.java b/test/unit/org/apache/cassandra/index/sai/cql/NativeIndexDDLTest.java index ebda5a426430..b7b5a078b2c0 100644 --- a/test/unit/org/apache/cassandra/index/sai/cql/NativeIndexDDLTest.java +++ b/test/unit/org/apache/cassandra/index/sai/cql/NativeIndexDDLTest.java @@ -581,7 +581,7 @@ public void shouldFailCreationMultipleIndexesOnSimpleColumn() // different name, different option, same target. 
assertThatThrownBy(() -> executeNet("CREATE CUSTOM INDEX ON %s(v1) USING 'StorageAttachedIndex' WITH OPTIONS = { 'case_sensitive' : true }")) .isInstanceOf(InvalidQueryException.class) - .hasMessageContaining("Cannot create more than one storage-attached index on the same column: v1" ); + .hasMessageContaining("Cannot create duplicate storage-attached index on column: v1" ); ResultSet rows = executeNet("SELECT id FROM %s WHERE v1 = '1'"); assertEquals(1, rows.all().size()); diff --git a/test/unit/org/apache/cassandra/index/sai/disk/RAMPostingSlicesTest.java b/test/unit/org/apache/cassandra/index/sai/disk/RAMPostingSlicesTest.java index 8d49798d906b..d6ca88d935ee 100644 --- a/test/unit/org/apache/cassandra/index/sai/disk/RAMPostingSlicesTest.java +++ b/test/unit/org/apache/cassandra/index/sai/disk/RAMPostingSlicesTest.java @@ -19,6 +19,8 @@ import java.util.Arrays; +import java.util.HashMap; +import java.util.Map; import org.junit.Test; @@ -31,7 +33,7 @@ public class RAMPostingSlicesTest extends SaiRandomizedTest @Test public void testRAMPostingSlices() throws Exception { - RAMPostingSlices slices = new RAMPostingSlices(Counter.newCounter()); + RAMPostingSlices slices = new RAMPostingSlices(Counter.newCounter(), false); int[] segmentRowIdUpto = new int[1024]; Arrays.fill(segmentRowIdUpto, -1); @@ -56,7 +58,7 @@ public void testRAMPostingSlices() throws Exception bitSets[termID].set(segmentRowIdUpto[termID]); - slices.writeVInt(termID, segmentRowIdUpto[termID]); + slices.writePosting(termID, segmentRowIdUpto[termID], 1); } for (int termID = 0; termID < segmentRowIdUpto.length; termID++) @@ -74,4 +76,36 @@ public void testRAMPostingSlices() throws Exception assertEquals(segmentRowId, segmentRowIdUpto[termID]); } } + + @Test + public void testRAMPostingSlicesWithFrequencies() throws Exception { + RAMPostingSlices slices = new RAMPostingSlices(Counter.newCounter(), true); + + // Test with just 3 terms and known frequencies + for (int termId = 0; termId < 3; termId++) { + slices.createNewSlice(termId); + + // Write a sequence of rows with different frequencies for each term + slices.writePosting(termId, 5, 1); // first posting at row 5 + slices.writePosting(termId, 3, 2); // next at row 8 (delta=3) + slices.writePosting(termId, 2, 3); // next at row 10 (delta=2) + } + + // Verify each term's postings + for (int termId = 0; termId < 3; termId++) { + ByteSliceReader reader = new ByteSliceReader(); + PostingList postings = slices.postingList(termId, reader, 10); + + assertEquals(5, postings.nextPosting()); + assertEquals(1, postings.frequency()); + + assertEquals(8, postings.nextPosting()); + assertEquals(2, postings.frequency()); + + assertEquals(10, postings.nextPosting()); + assertEquals(3, postings.frequency()); + + assertEquals(PostingList.END_OF_STREAM, postings.nextPosting()); + } + } } diff --git a/test/unit/org/apache/cassandra/index/sai/disk/RAMStringIndexerTest.java b/test/unit/org/apache/cassandra/index/sai/disk/RAMStringIndexerTest.java index 9ee1605512cf..15d47013f946 100644 --- a/test/unit/org/apache/cassandra/index/sai/disk/RAMStringIndexerTest.java +++ b/test/unit/org/apache/cassandra/index/sai/disk/RAMStringIndexerTest.java @@ -21,11 +21,15 @@ import java.nio.ByteBuffer; import java.util.ArrayList; import java.util.Arrays; +import java.util.HashMap; import java.util.List; +import java.util.Map; import org.junit.Assert; import org.junit.Test; +import org.apache.cassandra.index.sai.disk.format.IndexComponent; +import org.apache.cassandra.index.sai.disk.v1.trie.InvertedIndexWriter; 
import org.apache.cassandra.index.sai.utils.SaiRandomizedTest; import org.apache.cassandra.index.sai.utils.TypeUtil; import org.apache.cassandra.utils.ByteBufferUtil; @@ -40,13 +44,13 @@ public class RAMStringIndexerTest extends SaiRandomizedTest @Test public void test() throws Exception { - RAMStringIndexer indexer = new RAMStringIndexer(); + RAMStringIndexer indexer = new RAMStringIndexer(false); - indexer.add(new BytesRef("0"), 100); - indexer.add(new BytesRef("2"), 102); - indexer.add(new BytesRef("0"), 200); - indexer.add(new BytesRef("2"), 202); - indexer.add(new BytesRef("2"), 302); + indexer.addAll(List.of(new BytesRef("0")), 100); + indexer.addAll(List.of(new BytesRef("2")), 102); + indexer.addAll(List.of(new BytesRef("0")), 200); + indexer.addAll(List.of(new BytesRef("2")), 202); + indexer.addAll(List.of(new BytesRef("2")), 302); List> matches = new ArrayList<>(); matches.add(Arrays.asList(100L, 200L)); @@ -75,10 +79,48 @@ public void test() throws Exception } } + @Test + public void testWithFrequencies() throws Exception + { + RAMStringIndexer indexer = new RAMStringIndexer(true); + + // Add same term twice in same row to increment frequency + indexer.addAll(List.of(new BytesRef("A"), new BytesRef("A")), 100); + indexer.addAll(List.of(new BytesRef("B")), 102); + indexer.addAll(List.of(new BytesRef("A"), new BytesRef("A"), new BytesRef("A")), 200); + indexer.addAll(List.of(new BytesRef("B"), new BytesRef("B")), 202); + indexer.addAll(List.of(new BytesRef("B")), 302); + + // Expected results: rowID -> frequency + List> matches = Arrays.asList(Map.of(100L, 2, 200L, 3), + Map.of(102L, 1, 202L, 2, 302L, 1)); + + try (TermsIterator terms = indexer.getTermsWithPostings(ByteBufferUtil.bytes("A"), ByteBufferUtil.bytes("B"), TypeUtil.BYTE_COMPARABLE_VERSION)) + { + int ord = 0; + while (terms.hasNext()) + { + terms.next(); + try (PostingList postings = terms.postings()) + { + Map results = new HashMap<>(); + long segmentRowId; + while ((segmentRowId = postings.nextPosting()) != PostingList.END_OF_STREAM) + { + results.put(segmentRowId, postings.frequency()); + } + assertEquals(matches.get(ord++), results); + } + } + assertArrayEquals("A".getBytes(), terms.getMinTerm().array()); + assertArrayEquals("B".getBytes(), terms.getMaxTerm().array()); + } + } + @Test public void testLargeSegment() throws IOException { - final RAMStringIndexer indexer = new RAMStringIndexer(); + final RAMStringIndexer indexer = new RAMStringIndexer(false); final int numTerms = between(1 << 10, 1 << 13); final int numPostings = between(1 << 5, 1 << 10); @@ -87,7 +129,7 @@ public void testLargeSegment() throws IOException final BytesRef term = new BytesRef(String.format("%04d", id)); for (int posting = 0; posting < numPostings; ++posting) { - indexer.add(term, posting); + indexer.addAll(List.of(term), posting); } } @@ -124,14 +166,14 @@ public void testRequiresFlush() { RAMStringIndexer.MAX_BLOCK_BYTE_POOL_SIZE = 1024 * 1024 * 100; // primary behavior we're testing is that exceptions aren't thrown due to overflowing backing structures - RAMStringIndexer indexer = new RAMStringIndexer(); + RAMStringIndexer indexer = new RAMStringIndexer(false); Assert.assertFalse(indexer.requiresFlush()); for (int i = 0; i < Integer.MAX_VALUE; i++) { if (indexer.requiresFlush()) break; - indexer.add(new BytesRef(String.format("%5000d", i)), i); + indexer.addAll(List.of(new BytesRef(String.format("%5000d", i))), i); } // If we don't require a flush before MAX_VALUE, the implementation of RAMStringIndexer has sufficiently // changed to 
diff --git a/test/unit/org/apache/cassandra/index/sai/disk/v1/InvertedIndexBuilder.java b/test/unit/org/apache/cassandra/index/sai/disk/v1/InvertedIndexBuilder.java index d5699c195305..db2f61370754 100644 --- a/test/unit/org/apache/cassandra/index/sai/disk/v1/InvertedIndexBuilder.java +++ b/test/unit/org/apache/cassandra/index/sai/disk/v1/InvertedIndexBuilder.java @@ -19,15 +19,18 @@ import java.nio.ByteBuffer; import java.util.ArrayList; +import java.util.Arrays; import java.util.List; import java.util.function.IntSupplier; import java.util.function.Supplier; +import java.util.stream.Collectors; import java.util.stream.IntStream; import java.util.stream.Stream; import com.carrotsearch.hppc.IntArrayList; import org.apache.cassandra.db.marshal.UTF8Type; import org.apache.cassandra.index.sai.disk.format.Version; +import org.apache.cassandra.index.sai.memory.RowMapping; import org.apache.cassandra.utils.Pair; import org.apache.cassandra.utils.bytecomparable.ByteComparable; @@ -83,10 +86,15 @@ public static class TermsEnum this.byteComparableBytes = byteComparableBytes; this.postings = postings; } + } - public Pair<ByteComparable, IntArrayList> toPair() - { - return Pair.create(byteComparableBytes, postings); - } + /** + * Adds default frequency of 1 to postings + */ + static Pair<ByteComparable, List<RowMapping.RowIdWithFrequency>> toTermWithFrequency(TermsEnum te) + { + return Pair.create(te.byteComparableBytes, Arrays.stream(te.postings.toArray()).boxed() + .map(p -> new RowMapping.RowIdWithFrequency(p, 1)) + .collect(Collectors.toList())); } } diff --git a/test/unit/org/apache/cassandra/index/sai/disk/v1/InvertedIndexSearcherTest.java b/test/unit/org/apache/cassandra/index/sai/disk/v1/InvertedIndexSearcherTest.java index 47565eb28739..a15b11f6abf3 100644 --- a/test/unit/org/apache/cassandra/index/sai/disk/v1/InvertedIndexSearcherTest.java +++ b/test/unit/org/apache/cassandra/index/sai/disk/v1/InvertedIndexSearcherTest.java @@ -23,6 +23,7 @@ import java.util.List; import java.util.stream.Collectors; +import org.agrona.collections.Int2IntHashMap; import org.junit.BeforeClass; import org.junit.Test; @@ -101,12 +102,12 @@ private void doTestEqQueriesAgainstStringIndex(Version version) throws Exception final int numTerms = randomIntBetween(64, 512), numPostings = randomIntBetween(256, 1024); final List<InvertedIndexBuilder.TermsEnum> termsEnum = buildTermsEnum(version, numTerms, numPostings); - try (IndexSearcher searcher = buildIndexAndOpenSearcher(numTerms, numPostings, termsEnum)) + try (IndexSearcher searcher = buildIndexAndOpenSearcher(numTerms, termsEnum)) { for (int t = 0; t < numTerms; ++t) { try (KeyRangeIterator results = searcher.search(new Expression(indexContext) - .add(Operator.EQ, termsEnum.get(t).originalTermBytes), null, new QueryContext(), false, LIMIT)) + .add(Operator.EQ, termsEnum.get(t).originalTermBytes), null, new QueryContext(), false)) { assertTrue(results.hasNext()); @@ -121,7 +122,7 @@ private void doTestEqQueriesAgainstStringIndex(Version version) throws Exception } try (KeyRangeIterator results = searcher.search(new Expression(indexContext) - .add(Operator.EQ, termsEnum.get(t).originalTermBytes), null, new QueryContext(), false, LIMIT)) + .add(Operator.EQ, termsEnum.get(t).originalTermBytes), null, new QueryContext(), false)) { assertTrue(results.hasNext()); @@ -143,12 +144,12 @@ private void doTestEqQueriesAgainstStringIndex(Version version) throws Exception // try searching for terms that weren't indexed final String tooLongTerm = randomSimpleString(10, 12); KeyRangeIterator results = searcher.search(new Expression(indexContext) - .add(Operator.EQ,
UTF8Type.instance.decompose(tooLongTerm)), null, new QueryContext(), false, LIMIT); + .add(Operator.EQ, UTF8Type.instance.decompose(tooLongTerm)), null, new QueryContext(), false); assertFalse(results.hasNext()); final String tooShortTerm = randomSimpleString(1, 2); results = searcher.search(new Expression(indexContext) - .add(Operator.EQ, UTF8Type.instance.decompose(tooShortTerm)), null, new QueryContext(), false, LIMIT); + .add(Operator.EQ, UTF8Type.instance.decompose(tooShortTerm)), null, new QueryContext(), false); assertFalse(results.hasNext()); } } @@ -159,10 +160,10 @@ public void testUnsupportedOperator() throws Exception final int numTerms = randomIntBetween(5, 15), numPostings = randomIntBetween(5, 20); final List<InvertedIndexBuilder.TermsEnum> termsEnum = buildTermsEnum(version, numTerms, numPostings); - try (IndexSearcher searcher = buildIndexAndOpenSearcher(numTerms, numPostings, termsEnum)) + try (IndexSearcher searcher = buildIndexAndOpenSearcher(numTerms, termsEnum)) { searcher.search(new Expression(indexContext) - .add(Operator.NEQ, UTF8Type.instance.decompose("a")), null, new QueryContext(), false, LIMIT); + .add(Operator.NEQ, UTF8Type.instance.decompose("a")), null, new QueryContext(), false); fail("Expect IllegalArgumentException thrown, but didn't"); } @@ -172,9 +173,8 @@ public void testUnsupportedOperator() throws Exception } } - private IndexSearcher buildIndexAndOpenSearcher(int terms, int postings, List<InvertedIndexBuilder.TermsEnum> termsEnum) throws IOException + private IndexSearcher buildIndexAndOpenSearcher(int terms, List<InvertedIndexBuilder.TermsEnum> termsEnum) throws IOException { - final int size = terms * postings; final IndexDescriptor indexDescriptor = newIndexDescriptor(); final String index = newIndex(); final IndexContext indexContext = SAITester.createIndexContext(index, UTF8Type.instance); @@ -189,9 +189,10 @@ private IndexSearcher buildIndexAndOpenSearcher(int terms, int postings, List buildTermsEnum(Version version, int return InvertedIndexBuilder.buildStringTermsEnum(version, terms, postings, () -> randomSimpleString(3, 5), () -> nextInt(0, Integer.MAX_VALUE)); } + private Int2IntHashMap createMockDocLengths(List<InvertedIndexBuilder.TermsEnum> termsEnum) + { + Int2IntHashMap docLengths = new Int2IntHashMap(Integer.MIN_VALUE); + for (InvertedIndexBuilder.TermsEnum term : termsEnum) + { + for (var cursor : term.postings) + docLengths.put(cursor.value, 1); + } + return docLengths; + } + private ByteBuffer wrap(ByteComparable bc) { return ByteBuffer.wrap(ByteSourceInverse.readBytes(bc.asComparableBytes(TypeUtil.BYTE_COMPARABLE_VERSION))); diff --git a/test/unit/org/apache/cassandra/index/sai/disk/v1/KDTreeIndexSearcherTest.java b/test/unit/org/apache/cassandra/index/sai/disk/v1/KDTreeIndexSearcherTest.java index 8f34b9a808aa..65242e319a37 100644 --- a/test/unit/org/apache/cassandra/index/sai/disk/v1/KDTreeIndexSearcherTest.java +++ b/test/unit/org/apache/cassandra/index/sai/disk/v1/KDTreeIndexSearcherTest.java @@ -151,7 +151,7 @@ public void testUnsupportedOperator() throws Exception {{ operation = Op.NOT_EQ; lower = upper = new Bound(ShortType.instance.decompose((short) 0), Int32Type.instance, true); - }}, null, new QueryContext(), false, LIMIT); + }}, null, new QueryContext(), false); fail("Expect IllegalArgumentException thrown, but didn't"); } @@ -169,7 +169,7 @@ private void testEqQueries(final IndexSearcher indexSearcher, {{ operation = Op.EQ; lower = upper = new Bound(rawType.decompose(rawValueProducer.apply(EQ_TEST_LOWER_BOUND_INCLUSIVE)), encodedType, true); - }}, null, new QueryContext(), false, LIMIT)) + }}, null, new QueryContext(), false)) {
assertTrue(results.hasNext()); @@ -180,7 +180,7 @@ private void testEqQueries(final IndexSearcher indexSearcher, {{ operation = Op.EQ; lower = upper = new Bound(rawType.decompose(rawValueProducer.apply(EQ_TEST_UPPER_BOUND_EXCLUSIVE)), encodedType, true); - }}, null, new QueryContext(), false, LIMIT)) + }}, null, new QueryContext(), false)) { assertFalse(results.hasNext()); indexSearcher.close(); @@ -206,7 +206,7 @@ private void testRangeQueries(final IndexSearcher indexSearch lower = new Bound(rawType.decompose(rawValueProducer.apply((short)2)), encodedType, false); upper = new Bound(rawType.decompose(rawValueProducer.apply((short)7)), encodedType, true); - }}, null, new QueryContext(), false, LIMIT)) + }}, null, new QueryContext(), false)) { assertTrue(results.hasNext()); @@ -218,7 +218,7 @@ private void testRangeQueries(final IndexSearcher indexSearch {{ operation = Op.RANGE; lower = new Bound(rawType.decompose(rawValueProducer.apply(RANGE_TEST_UPPER_BOUND_EXCLUSIVE)), encodedType, true); - }}, null, new QueryContext(), false, LIMIT)) + }}, null, new QueryContext(), false)) { assertFalse(results.hasNext()); } @@ -227,7 +227,7 @@ private void testRangeQueries(final IndexSearcher indexSearch {{ operation = Op.RANGE; upper = new Bound(rawType.decompose(rawValueProducer.apply(RANGE_TEST_LOWER_BOUND_INCLUSIVE)), encodedType, false); - }}, null, new QueryContext(), false, LIMIT)) + }}, null, new QueryContext(), false)) { assertFalse(results.hasNext()); indexSearcher.close(); diff --git a/test/unit/org/apache/cassandra/index/sai/disk/v1/TermsReaderTest.java b/test/unit/org/apache/cassandra/index/sai/disk/v1/TermsReaderTest.java index bdfbf529e86b..97277da2490a 100644 --- a/test/unit/org/apache/cassandra/index/sai/disk/v1/TermsReaderTest.java +++ b/test/unit/org/apache/cassandra/index/sai/disk/v1/TermsReaderTest.java @@ -18,42 +18,49 @@ package org.apache.cassandra.index.sai.disk.v1; import java.io.IOException; +import java.util.ArrayList; import java.util.Collection; +import java.util.HashMap; import java.util.List; +import java.util.Map; import java.util.stream.Collectors; import org.junit.Test; -import com.carrotsearch.hppc.IntArrayList; import com.carrotsearch.randomizedtesting.annotations.ParametersFactory; +import org.agrona.collections.Int2IntHashMap; import org.apache.cassandra.config.DatabaseDescriptor; import org.apache.cassandra.db.marshal.UTF8Type; import org.apache.cassandra.index.sai.IndexContext; import org.apache.cassandra.index.sai.QueryContext; import org.apache.cassandra.index.sai.SAITester; +import org.apache.cassandra.index.sai.SAIUtil; import org.apache.cassandra.index.sai.disk.MemtableTermsIterator; import org.apache.cassandra.index.sai.disk.PostingList; +import org.apache.cassandra.index.sai.disk.RAMStringIndexer; import org.apache.cassandra.index.sai.disk.TermsIterator; import org.apache.cassandra.index.sai.disk.format.IndexComponentType; import org.apache.cassandra.index.sai.disk.format.IndexComponents; import org.apache.cassandra.index.sai.disk.format.IndexDescriptor; import org.apache.cassandra.index.sai.disk.format.Version; import org.apache.cassandra.index.sai.disk.v1.trie.InvertedIndexWriter; +import org.apache.cassandra.index.sai.memory.RowMapping; import org.apache.cassandra.index.sai.metrics.QueryEventListener; import org.apache.cassandra.index.sai.utils.SAICodecUtils; import org.apache.cassandra.index.sai.utils.SaiRandomizedTest; import org.apache.cassandra.index.sai.utils.TypeUtil; import org.apache.cassandra.io.util.FileHandle; +import 
org.apache.cassandra.utils.ByteBufferUtil; import org.apache.cassandra.utils.Pair; import org.apache.cassandra.utils.bytecomparable.ByteComparable; import org.apache.cassandra.utils.bytecomparable.ByteSourceInverse; +import org.apache.lucene.util.BytesRef; import static org.apache.cassandra.index.sai.disk.v1.InvertedIndexBuilder.buildStringTermsEnum; import static org.apache.cassandra.index.sai.metrics.QueryEventListeners.NO_OP_TRIE_LISTENER; public class TermsReaderTest extends SaiRandomizedTest { - public static final ByteComparable.Version VERSION = TypeUtil.BYTE_COMPARABLE_VERSION; @ParametersFactory() @@ -102,8 +109,11 @@ private void doTestTermsIteration(Version version) throws IOException IndexComponents.ForWrite components = indexDescriptor.newPerIndexComponentsForWrite(indexContext); try (InvertedIndexWriter writer = new InvertedIndexWriter(components)) { - var iter = termsEnum.stream().map(InvertedIndexBuilder.TermsEnum::toPair).iterator(); - indexMetas = writer.writeAll(new MemtableTermsIterator(null, null, iter)); + var iter = termsEnum.stream() + .map(InvertedIndexBuilder::toTermWithFrequency) + .iterator(); + Int2IntHashMap docLengths = createMockDocLengths(termsEnum); + indexMetas = writer.writeAll(new MemtableTermsIterator(null, null, iter), docLengths); } FileHandle termsData = components.get(IndexComponentType.TERMS_DATA).createFileHandle(); @@ -142,8 +152,11 @@ private void testTermQueries(Version version, int numTerms, int numPostings) thr IndexComponents.ForWrite components = indexDescriptor.newPerIndexComponentsForWrite(indexContext); try (InvertedIndexWriter writer = new InvertedIndexWriter(components)) { - var iter = termsEnum.stream().map(InvertedIndexBuilder.TermsEnum::toPair).iterator(); - indexMetas = writer.writeAll(new MemtableTermsIterator(null, null, iter)); + var iter = termsEnum.stream() + .map(InvertedIndexBuilder::toTermWithFrequency) + .iterator(); + Int2IntHashMap docLengths = createMockDocLengths(termsEnum); + indexMetas = writer.writeAll(new MemtableTermsIterator(null, null, iter), docLengths); } FileHandle termsData = components.get(IndexComponentType.TERMS_DATA).createFileHandle(); @@ -159,22 +172,24 @@ private void testTermQueries(Version version, int numTerms, int numPostings) thr termsFooterPointer, version)) { - var iter = termsEnum.stream().map(InvertedIndexBuilder.TermsEnum::toPair).collect(Collectors.toList()); - for (Pair<ByteComparable, IntArrayList> pair : iter) + var iter = termsEnum.stream() + .map(InvertedIndexBuilder::toTermWithFrequency) + .collect(Collectors.toList()); + for (Pair<ByteComparable, List<RowMapping.RowIdWithFrequency>> pair : iter) { final byte[] bytes = ByteSourceInverse.readBytes(pair.left.asComparableBytes(VERSION)); try (PostingList actualPostingList = reader.exactMatch(ByteComparable.preencoded(VERSION, bytes), (QueryEventListener.TrieIndexEventListener)NO_OP_TRIE_LISTENER, new QueryContext())) { - final IntArrayList expectedPostingList = pair.right; + final List<RowMapping.RowIdWithFrequency> expectedPostingList = pair.right; assertNotNull(actualPostingList); assertEquals(expectedPostingList.size(), actualPostingList.size()); for (int i = 0; i < expectedPostingList.size(); ++i) { - final long expectedRowID = expectedPostingList.get(i); + final long expectedRowID = expectedPostingList.get(i).rowId; long result = actualPostingList.nextPosting(); assertEquals(String.format("row %d mismatch of %d in enum %d", i, expectedPostingList.size(), termsEnum.indexOf(pair)), expectedRowID, result); }
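The docLengths map passed to writeAll above is what lets BM25 normalize term frequency by document length at query time; these tests pin every row's length to 1 so the ranking arithmetic stays trivial. For orientation only, the textbook BM25 per-term contribution looks like the following; this is the standard formula with the usual defaults, not code from this patch:

class Bm25Example
{
    // Standard BM25: tf = term frequency in the document, dl = document length
    static double bm25(int tf, int dl, double avgDl, long totalDocs, long docsWithTerm)
    {
        final double k1 = 1.2, b = 0.75; // common defaults
        double idf = Math.log(1 + (totalDocs - docsWithTerm + 0.5) / (docsWithTerm + 0.5));
        double lengthNorm = 1 - b + b * (dl / avgDl);
        return idf * (tf * (k1 + 1)) / (tf + k1 * lengthNorm);
    }
}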
@@ -188,18 +203,18 @@ private void testTermQueries(Version version, int numTerms, int numPostings) thr (QueryEventListener.TrieIndexEventListener)NO_OP_TRIE_LISTENER, new QueryContext())) { - final IntArrayList expectedPostingList = pair.right; + final List<RowMapping.RowIdWithFrequency> expectedPostingList = pair.right; // test skipping to the last block final int idxToSkip = numPostings - 2; // tokens are equal to their corresponding row IDs - final int tokenToSkip = expectedPostingList.get(idxToSkip); + final int tokenToSkip = expectedPostingList.get(idxToSkip).rowId; long advanceResult = actualPostingList.advance(tokenToSkip); assertEquals(tokenToSkip, advanceResult); for (int i = idxToSkip + 1; i < expectedPostingList.size(); ++i) { - final long expectedRowID = expectedPostingList.get(i); + final long expectedRowID = expectedPostingList.get(i).rowId; long result = actualPostingList.nextPosting(); assertEquals(expectedRowID, result); } @@ -211,6 +226,17 @@ private void testTermQueries(Version version, int numTerms, int numPostings) thr } } + private Int2IntHashMap createMockDocLengths(List<InvertedIndexBuilder.TermsEnum> termsEnum) + { + Int2IntHashMap docLengths = new Int2IntHashMap(Integer.MIN_VALUE); + for (InvertedIndexBuilder.TermsEnum term : termsEnum) + { + for (var cursor : term.postings) + docLengths.put(cursor.value, 1); + } + return docLengths; + } + private List<InvertedIndexBuilder.TermsEnum> buildTermsEnum(Version version, int terms, int postings) { return buildStringTermsEnum(version, terms, postings, () -> randomSimpleString(4, 10), () -> nextInt(0, Integer.MAX_VALUE)); diff --git a/test/unit/org/apache/cassandra/index/sai/disk/v1/kdtree/ImmutableOneDimPointValuesTest.java b/test/unit/org/apache/cassandra/index/sai/disk/v1/kdtree/ImmutableOneDimPointValuesTest.java index 9f086249f331..ba8f33cd9951 100644 --- a/test/unit/org/apache/cassandra/index/sai/disk/v1/kdtree/ImmutableOneDimPointValuesTest.java +++ b/test/unit/org/apache/cassandra/index/sai/disk/v1/kdtree/ImmutableOneDimPointValuesTest.java @@ -19,12 +19,14 @@ import java.io.IOException; import java.nio.ByteBuffer; +import java.util.Arrays; +import java.util.List; import org.junit.Rule; import org.junit.Test; import org.junit.rules.ExpectedException; -import com.carrotsearch.hppc.IntArrayList; +import org.apache.cassandra.index.sai.memory.RowMapping; import org.apache.cassandra.db.marshal.Int32Type; import org.apache.cassandra.index.sai.disk.MemtableTermsIterator; import org.apache.cassandra.index.sai.disk.TermsIterator; @@ -111,20 +113,22 @@ private TermsIterator buildDescTermEnum(int from, int to) final ByteBuffer minTerm = Int32Type.instance.decompose(from); final ByteBuffer maxTerm = Int32Type.instance.decompose(to); - final AbstractGuavaIterator<Pair<ByteComparable, IntArrayList>> iterator = new AbstractGuavaIterator<Pair<ByteComparable, IntArrayList>>() + final AbstractGuavaIterator<Pair<ByteComparable, List<RowMapping.RowIdWithFrequency>>> iterator = new AbstractGuavaIterator<>() { private int currentTerm = from; @Override - protected Pair<ByteComparable, IntArrayList> computeNext() + protected Pair<ByteComparable, List<RowMapping.RowIdWithFrequency>> computeNext() { if (currentTerm <= to) { return endOfData(); } final ByteBuffer term = Int32Type.instance.decompose(currentTerm++); - IntArrayList postings = new IntArrayList(); - postings.add(0, 1, 2); + List<RowMapping.RowIdWithFrequency> postings = Arrays.asList( + new RowMapping.RowIdWithFrequency(0, 1), + new RowMapping.RowIdWithFrequency(1, 1), + new RowMapping.RowIdWithFrequency(2, 1)); return Pair.create(v -> ByteSource.preencoded(term), postings); } }; diff --git a/test/unit/org/apache/cassandra/index/sai/disk/v1/kdtree/KDTreeIndexBuilder.java b/test/unit/org/apache/cassandra/index/sai/disk/v1/kdtree/KDTreeIndexBuilder.java index fdbd1f9d91f8..3a867e41d90f 100644 --- a/test/unit/org/apache/cassandra/index/sai/disk/v1/kdtree/KDTreeIndexBuilder.java +++
b/test/unit/org/apache/cassandra/index/sai/disk/v1/kdtree/KDTreeIndexBuilder.java @@ -21,7 +21,9 @@ import java.math.BigDecimal; import java.math.BigInteger; import java.nio.ByteBuffer; +import java.util.ArrayList; import java.util.Iterator; +import java.util.List; import java.util.function.Supplier; import java.util.stream.Collectors; import java.util.stream.IntStream; @@ -31,6 +33,7 @@ import org.junit.Assert; import com.carrotsearch.hppc.IntArrayList; +import org.apache.cassandra.index.sai.memory.RowMapping; import org.apache.cassandra.db.marshal.AbstractType; import org.apache.cassandra.db.marshal.DecimalType; import org.apache.cassandra.db.marshal.Int32Type; @@ -142,7 +145,22 @@ public KDTreeIndexBuilder(IndexDescriptor indexDescriptor, KDTreeIndexSearcher flushAndOpen() throws IOException { - final TermsIterator termEnum = new MemtableTermsIterator(null, null, terms); + // Wrap postings with RowIdWithFrequency using default frequency of 1 + final TermsIterator termEnum = new MemtableTermsIterator(null, null, new AbstractGuavaIterator<>() + { + @Override + protected Pair<ByteComparable, List<RowMapping.RowIdWithFrequency>> computeNext() + { + if (!terms.hasNext()) + return endOfData(); + + Pair<ByteComparable, IntArrayList> pair = terms.next(); + List<RowMapping.RowIdWithFrequency> postings = new ArrayList<>(pair.right.size()); + for (int i = 0; i < pair.right.size(); i++) + postings.add(new RowMapping.RowIdWithFrequency(pair.right.get(i), 1)); + return Pair.create(pair.left, postings); + } + }); final ImmutableOneDimPointValues pointValues = ImmutableOneDimPointValues.fromTermEnum(termEnum, type); final SegmentMetadata metadata;
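The anonymous iterator above is one of several spots in this patch where legacy IntArrayList postings are adapted to the new (rowId, frequency) shape by attaching a default frequency of 1. Factored out, the adapter reduces to the sketch below; withDefaultFrequency is an illustrative name, while RowMapping.RowIdWithFrequency is the type introduced by the patch:

import java.util.ArrayList;
import java.util.List;

import com.carrotsearch.hppc.IntArrayList;
import org.apache.cassandra.index.sai.memory.RowMapping;

class FrequencyAdapterExample
{
    static List<RowMapping.RowIdWithFrequency> withDefaultFrequency(IntArrayList rowIds)
    {
        List<RowMapping.RowIdWithFrequency> postings = new ArrayList<>(rowIds.size());
        for (int i = 0; i < rowIds.size(); i++)
            postings.add(new RowMapping.RowIdWithFrequency(rowIds.get(i), 1)); // frequency defaults to 1
        return postings;
    }
}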
diff --git a/test/unit/org/apache/cassandra/index/sai/disk/v1/kdtree/NumericIndexWriterTest.java b/test/unit/org/apache/cassandra/index/sai/disk/v1/kdtree/NumericIndexWriterTest.java index dcc1e7ef056d..27bff7fa40cb 100644 --- a/test/unit/org/apache/cassandra/index/sai/disk/v1/kdtree/NumericIndexWriterTest.java +++ b/test/unit/org/apache/cassandra/index/sai/disk/v1/kdtree/NumericIndexWriterTest.java @@ -18,6 +18,8 @@ package org.apache.cassandra.index.sai.disk.v1.kdtree; import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.List; import org.junit.Before; import org.junit.Test; @@ -35,6 +37,7 @@ import org.apache.cassandra.index.sai.disk.format.IndexDescriptor; import org.apache.cassandra.index.sai.disk.v1.IndexWriterConfig; import org.apache.cassandra.index.sai.disk.v1.SegmentMetadata; +import org.apache.cassandra.index.sai.memory.RowMapping; import org.apache.cassandra.index.sai.metrics.QueryEventListener; import org.apache.cassandra.index.sai.metrics.QueryEventListeners; import org.apache.cassandra.index.sai.utils.SaiRandomizedTest; @@ -190,21 +193,21 @@ private TermsIterator buildTermEnum(int startTermInclusive, int endTermExclusive final ByteBuffer minTerm = Int32Type.instance.decompose(startTermInclusive); final ByteBuffer maxTerm = Int32Type.instance.decompose(endTermExclusive); - final AbstractGuavaIterator<Pair<ByteComparable, IntArrayList>> iterator = new AbstractGuavaIterator<Pair<ByteComparable, IntArrayList>>() + final AbstractGuavaIterator<Pair<ByteComparable, List<RowMapping.RowIdWithFrequency>>> iterator = new AbstractGuavaIterator<>() { private int currentTerm = startTermInclusive; private int currentRowId = 0; @Override - protected Pair<ByteComparable, IntArrayList> computeNext() + protected Pair<ByteComparable, List<RowMapping.RowIdWithFrequency>> computeNext() { if (currentTerm >= endTermExclusive) { return endOfData(); } final ByteBuffer term = Int32Type.instance.decompose(currentTerm++); - final IntArrayList postings = new IntArrayList(); - postings.add(currentRowId++); + final List<RowMapping.RowIdWithFrequency> postings = new ArrayList<>(); + postings.add(new RowMapping.RowIdWithFrequency(currentRowId++, 1)); final ByteSource encoded = Int32Type.instance.asComparableBytes(term, TypeUtil.BYTE_COMPARABLE_VERSION); return Pair.create(v -> encoded, postings); } diff --git a/test/unit/org/apache/cassandra/index/sai/disk/v1/postings/IntersectingPostingListTest.java b/test/unit/org/apache/cassandra/index/sai/disk/v1/postings/IntersectingPostingListTest.java new file mode 100644 index 000000000000..76285a4f0bc4 --- /dev/null +++ b/test/unit/org/apache/cassandra/index/sai/disk/v1/postings/IntersectingPostingListTest.java @@ -0,0 +1,210 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.cassandra.index.sai.disk.v1.postings; + +import java.io.IOException; +import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.stream.IntStream; + +import org.junit.Test; + +import org.apache.cassandra.index.sai.disk.PostingList; +import org.apache.cassandra.index.sai.postings.IntArrayPostingList; +import org.apache.cassandra.index.sai.utils.SaiRandomizedTest; +import org.apache.cassandra.utils.ByteBufferUtil; + +public class IntersectingPostingListTest extends SaiRandomizedTest { + private Map<ByteBuffer, PostingList> createPostingMap(PostingList... lists)
+ { + Map<ByteBuffer, PostingList> map = new HashMap<>(); + for (int i = 0; i < lists.length; i++) + { + map.put(ByteBufferUtil.bytes(String.valueOf((char) ('A' + i))), lists[i]); + } + return map; + } + + @Test + public void shouldIntersectOverlappingPostingLists() throws IOException + { + var map = createPostingMap(new IntArrayPostingList(new int[]{ 1, 4, 6, 8 }), + new IntArrayPostingList(new int[]{ 2, 4, 6, 9 }), + new IntArrayPostingList(new int[]{ 4, 6, 7 })); + + final PostingList intersected = IntersectingPostingList.intersect(map); + assertPostingListEquals(new IntArrayPostingList(new int[]{ 4, 6 }), intersected); + } + + @Test + public void shouldIntersectDisjointPostingLists() throws IOException + { + var map = createPostingMap(new IntArrayPostingList(new int[]{ 1, 3, 5 }), + new IntArrayPostingList(new int[]{ 2, 4, 6 })); + + final PostingList intersected = IntersectingPostingList.intersect(map); + assertPostingListEquals(new IntArrayPostingList(new int[]{}), intersected); + } + + @Test + public void shouldIntersectSinglePostingList() throws IOException + { + var map = createPostingMap(new IntArrayPostingList(new int[]{ 1, 4, 6 })); + + final PostingList intersected = IntersectingPostingList.intersect(map); + assertPostingListEquals(new IntArrayPostingList(new int[]{ 1, 4, 6 }), intersected); + } + + @Test + public void shouldIntersectIdenticalPostingLists() throws IOException + { + var map = createPostingMap(new IntArrayPostingList(new int[]{ 1, 2, 3 }), + new IntArrayPostingList(new int[]{ 1, 2, 3 })); + + final PostingList intersected = IntersectingPostingList.intersect(map); + assertPostingListEquals(new IntArrayPostingList(new int[]{ 1, 2, 3 }), intersected); + } + + @Test + public void shouldAdvanceAllIntersectedLists() throws IOException + { + var map = createPostingMap(new IntArrayPostingList(new int[]{ 1, 3, 5, 7, 9 }), + new IntArrayPostingList(new int[]{ 2, 3, 5, 7, 8 }), + new IntArrayPostingList(new int[]{ 3, 5, 7, 10 })); + + final PostingList intersected = IntersectingPostingList.intersect(map); + final PostingList expected = new IntArrayPostingList(new int[]{ 3, 5, 7 }); + + assertEquals(expected.advance(5), intersected.advance(5)); + assertPostingListEquals(expected, intersected); + } + + @Test + public void shouldHandleEmptyList() throws IOException + { + var map = createPostingMap(new IntArrayPostingList(new int[]{}), + new IntArrayPostingList(new int[]{ 1, 2, 3 })); + + final PostingList intersected = IntersectingPostingList.intersect(map); + assertEquals(PostingList.END_OF_STREAM, intersected.advance(1)); + } + + @Test + public void shouldInterleaveNextAndAdvance() throws IOException + { + var map = createPostingMap(new IntArrayPostingList(new int[]{ 1, 3, 5, 7, 9 }), + new IntArrayPostingList(new int[]{ 1, 3, 5, 7, 9 }), + new IntArrayPostingList(new int[]{ 1, 3, 5, 7, 9 })); + + final PostingList intersected = IntersectingPostingList.intersect(map); + + assertEquals(1, intersected.nextPosting()); + assertEquals(5, intersected.advance(5)); + assertEquals(7, intersected.nextPosting()); + assertEquals(9, intersected.advance(9)); + } + + @Test + public void shouldInterleaveNextAndAdvanceOnRandom() throws IOException + { + for (int i = 0; i < 1000; ++i) + { + testAdvancingOnRandom(); + } + } + + private void testAdvancingOnRandom() throws IOException + { + final int postingsCount = nextInt(1, 50_000); + final int postingListCount = nextInt(2, 10); + + final AtomicInteger rowId = new AtomicInteger(); + final int[] commonPostings = IntStream.generate(() ->
rowId.addAndGet(nextInt(1, 10))) + .limit(postingsCount / 4) + .toArray(); + + var splitPostingLists = new ArrayList<PostingList>(); + for (int i = 0; i < postingListCount; i++) + { + final int[] uniquePostings = IntStream.generate(() -> rowId.addAndGet(nextInt(1, 10))) + .limit(postingsCount) + .toArray(); + int[] combined = IntStream.concat(IntStream.of(commonPostings), + IntStream.of(uniquePostings)) + .distinct() + .sorted() + .toArray(); + splitPostingLists.add(new IntArrayPostingList(combined)); + } + + final PostingList intersected = IntersectingPostingList.intersect(createPostingMap(splitPostingLists.toArray(new PostingList[0]))); + final PostingList expected = new IntArrayPostingList(commonPostings); + + final List<PostingListAdvance> actions = new ArrayList<>(); + for (int idx = 0; idx < commonPostings.length; idx++) + { + if (nextInt(0, 8) == 0) + { + actions.add((postingList) -> { + try + { + return postingList.nextPosting(); + } + catch (IOException e) + { + fail(e.getMessage()); + throw new RuntimeException(e); + } + }); + } + else + { + final int skips = nextInt(0, 5); + idx = Math.min(idx + skips, commonPostings.length - 1); + final int rowID = commonPostings[idx]; + actions.add((postingList) -> { + try + { + return postingList.advance(rowID); + } + catch (IOException e) + { + fail(e.getMessage()); + throw new RuntimeException(e); + } + }); + } + } + + for (PostingListAdvance action : actions) + { + assertEquals(action.advance(expected), action.advance(intersected)); + } + } + + private interface PostingListAdvance + { + long advance(PostingList list) throws IOException; + } +}
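The new test class above pins down the contract of IntersectingPostingList. The heart of any such intersection is a leapfrog loop that advances every list toward a shared candidate row id; a simplified sketch of that idea follows. It is not the patch's implementation (which must also surface per-term frequencies for BM25), and it assumes advance(target) returns the first row id greater than or equal to target:

import java.io.IOException;

import org.apache.cassandra.index.sai.disk.PostingList;

class LeapfrogExample
{
    static long nextIntersection(PostingList[] lists, long target) throws IOException
    {
        int agreeing = 0; // how many lists are known to sit on the current candidate
        for (int i = 0; agreeing < lists.length; i = (i + 1) % lists.length)
        {
            long rowId = lists[i].advance(target);
            if (rowId == PostingList.END_OF_STREAM)
                return PostingList.END_OF_STREAM; // one list exhausted: no further matches
            if (rowId == target)
            {
                agreeing++;        // this list agrees with the candidate
            }
            else
            {
                target = rowId;    // overshoot becomes the new candidate
                agreeing = 1;      // only this list is known to sit on it
            }
        }
        return target;
    }
}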
diff --git a/test/unit/org/apache/cassandra/index/sai/disk/v1/postings/PostingsTest.java b/test/unit/org/apache/cassandra/index/sai/disk/v1/postings/PostingsTest.java index 184bfd3af8f0..451789e872c9 100644 --- a/test/unit/org/apache/cassandra/index/sai/disk/v1/postings/PostingsTest.java +++ b/test/unit/org/apache/cassandra/index/sai/disk/v1/postings/PostingsTest.java @@ -328,7 +328,7 @@ private PostingsReader openReader(IndexComponent.ForRead postingLists, long fp, private PostingsReader.BlocksSummary assertBlockSummary(int blockSize, PostingList expected, IndexInput input) throws IOException { final PostingsReader.BlocksSummary summary = new PostingsReader.BlocksSummary(input, input.getFilePointer()); - assertEquals(blockSize, summary.blockSize); + assertEquals(blockSize, summary.blockEntries); assertEquals(expected.size(), summary.numPostings); assertTrue(summary.offsets.length() > 0); assertEquals(summary.offsets.length(), summary.maxValues.length()); diff --git a/test/unit/org/apache/cassandra/index/sai/disk/v1/postings/VectorPostingListTest.java b/test/unit/org/apache/cassandra/index/sai/disk/v1/postings/ReorderingPostingListTest.java similarity index 91% rename from test/unit/org/apache/cassandra/index/sai/disk/v1/postings/VectorPostingListTest.java rename to test/unit/org/apache/cassandra/index/sai/disk/v1/postings/ReorderingPostingListTest.java index 3135db3c7748..10b2b6a33f47 100644 --- a/test/unit/org/apache/cassandra/index/sai/disk/v1/postings/VectorPostingListTest.java +++ b/test/unit/org/apache/cassandra/index/sai/disk/v1/postings/ReorderingPostingListTest.java @@ -30,14 +30,14 @@ import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertTrue; -public class VectorPostingListTest +public class ReorderingPostingListTest { @Test public void ensureEmptySourceBehavesCorrectly() throws Throwable { var source = new TestIterator(CloseableIterator.emptyIterator()); - try (var postingList = new VectorPostingList(source)) + try (var postingList = new ReorderingPostingList(source, RowIdWithScore::getSegmentRowId)) { // Even an empty source should be closed assertTrue(source.isClosed); @@ -55,7 +55,7 @@ public void ensureIteratorIsConsumedClosedAndReordered() throws Throwable new RowIdWithScore(4, 4), }).iterator()); - try (var postingList = new VectorPostingList(source)) + try (var postingList = new ReorderingPostingList(source, RowIdWithScore::getSegmentRowId)) { // The posting list is eagerly consumed, so it should be closed before // we close postingList @@ -80,7 +80,7 @@ public void ensureAdvanceWorksCorrectly() throws Throwable new RowIdWithScore(2, 2), }).iterator()); - try (var postingList = new VectorPostingList(source)) + try (var postingList = new ReorderingPostingList(source, RowIdWithScore::getSegmentRowId)) { assertEquals(3, postingList.advance(3)); assertEquals(PostingList.END_OF_STREAM, postingList.advance(4)); diff --git a/test/unit/org/apache/cassandra/index/sai/memory/TrieMemoryIndexTest.java b/test/unit/org/apache/cassandra/index/sai/memory/TrieMemoryIndexTest.java index 4119b3afb872..6199266d6da0 100644 --- a/test/unit/org/apache/cassandra/index/sai/memory/TrieMemoryIndexTest.java +++ b/test/unit/org/apache/cassandra/index/sai/memory/TrieMemoryIndexTest.java @@ -82,17 +82,16 @@ public void iteratorShouldReturnAllValuesNumeric() index.add(makeKey(table, Integer.toString(row)), Clustering.EMPTY, Int32Type.instance.decompose(row / 10), allocatedBytes -> {}, allocatesBytes -> {}); } - Iterator<Pair<ByteComparable, PrimaryKeys>> iterator = index.iterator(); + var iterator = index.iterator(); int valueCount = 0; while(iterator.hasNext()) { - Pair<ByteComparable, PrimaryKeys> pair = iterator.next(); + var pair = iterator.next(); int value = ByteSourceInverse.getSignedInt(pair.left.asComparableBytes(TypeUtil.BYTE_COMPARABLE_VERSION)); int idCount = 0; - Iterator<PrimaryKey> primaryKeyIterator = pair.right.iterator(); - while (primaryKeyIterator.hasNext()) + for (var pkf : pair.right) { - PrimaryKey primaryKey = primaryKeyIterator.next(); + PrimaryKey primaryKey = pkf.pk; int id = Int32Type.instance.compose(primaryKey.partitionKey().getKey()); assertEquals(id/10, value); idCount++; @@ -113,17 +112,16 @@ public void iteratorShouldReturnAllValuesString() index.add(makeKey(table, Integer.toString(row)), Clustering.EMPTY, UTF8Type.instance.decompose(Integer.toString(row / 10)), allocatedBytes -> {}, allocatesBytes -> {}); } - Iterator<Pair<ByteComparable, PrimaryKeys>> iterator = index.iterator(); + var iterator = index.iterator(); int valueCount = 0; while(iterator.hasNext()) { - Pair<ByteComparable, PrimaryKeys> pair = iterator.next(); + var pair = iterator.next(); String value = new String(ByteSourceInverse.readBytes(pair.left.asPeekableBytes(TypeUtil.BYTE_COMPARABLE_VERSION)), StandardCharsets.UTF_8); int idCount = 0; - Iterator<PrimaryKey> primaryKeyIterator = pair.right.iterator(); - while (primaryKeyIterator.hasNext()) + for (var pkf : pair.right) { - PrimaryKey primaryKey = primaryKeyIterator.next(); + PrimaryKey primaryKey = pkf.pk; String id = UTF8Type.instance.compose(primaryKey.partitionKey().getKey()); assertEquals(Integer.toString(Integer.parseInt(id) / 10), value); idCount++; @@ -149,11 +147,11 @@ private void shouldAcceptPrefixValuesForType(AbstractType type, IntFunction {}, allocatesBytes -> {}); } - final Iterator<Pair<ByteComparable, PrimaryKeys>> iterator = index.iterator(); + final var iterator = index.iterator(); int i = 0; while (iterator.hasNext()) { - Pair<ByteComparable, PrimaryKeys> pair = iterator.next(); + var pair = iterator.next(); assertEquals(1, pair.right.size()); final int rowId = i; diff --git
a/test/unit/org/apache/cassandra/index/sai/memory/TrieMemtableIndexTestBase.java b/test/unit/org/apache/cassandra/index/sai/memory/TrieMemtableIndexTestBase.java index 0f1089250348..21b13059de61 100644 --- a/test/unit/org/apache/cassandra/index/sai/memory/TrieMemtableIndexTestBase.java +++ b/test/unit/org/apache/cassandra/index/sai/memory/TrieMemtableIndexTestBase.java @@ -225,11 +225,11 @@ public void indexIteratorTest() DecoratedKey minimum = temp1.compareTo(temp2) <= 0 ? temp1 : temp2; DecoratedKey maximum = temp1.compareTo(temp2) <= 0 ? temp2 : temp1; - Iterator<Pair<ByteComparable, Iterator<PrimaryKey>>> iterator = memtableIndex.iterator(minimum, maximum); + var iterator = memtableIndex.iterator(minimum, maximum); while (iterator.hasNext()) { - Pair<ByteComparable, Iterator<PrimaryKey>> termPair = iterator.next(); + var termPair = iterator.next(); int term = termFromComparable(termPair.left); // The iterator will return keys outside the range of min/max, so we need to filter here to // get the correct keys @@ -239,9 +239,9 @@ public void indexIteratorTest() .sorted() .collect(Collectors.toList()); List<DecoratedKey> termPks = new ArrayList<>(); - while (termPair.right.hasNext()) + for (var pkWithFreq : termPair.right) { - DecoratedKey pk = termPair.right.next().partitionKey(); + DecoratedKey pk = pkWithFreq.pk.partitionKey(); if (pk.compareTo(minimum) >= 0 && pk.compareTo(maximum) <= 0) termPks.add(pk); } diff --git a/test/unit/org/apache/cassandra/index/sai/memory/VectorMemtableIndexTest.java b/test/unit/org/apache/cassandra/index/sai/memory/VectorMemtableIndexTest.java index 75e8c83f30be..2197c00fe231 100644 --- a/test/unit/org/apache/cassandra/index/sai/memory/VectorMemtableIndexTest.java +++ b/test/unit/org/apache/cassandra/index/sai/memory/VectorMemtableIndexTest.java @@ -148,7 +148,7 @@ private void validate(List<DecoratedKey> keys) { IntStream.range(0, 1_000).parallel().forEach(i -> { - var orderer = generateRandomOrderer(); + var orderer = randomVectorOrderer(); AbstractBounds<PartitionPosition> keyRange = generateRandomBounds(keys); // compute keys in range of the bounds Set<DecoratedKey> keysInRange = keys.stream().filter(keyRange::contains) @@ -197,7 +197,7 @@ public void indexIteratorTest() // VSTODO } - private Orderer generateRandomOrderer() + private Orderer randomVectorOrderer() { return new Orderer(indexContext, Operator.ANN, randomVectorSerialized()); } diff --git a/test/unit/org/apache/cassandra/index/sai/metrics/IndexMetricsTest.java b/test/unit/org/apache/cassandra/index/sai/metrics/IndexMetricsTest.java index cc65bd20c293..c59f10c8e42f 100644 --- a/test/unit/org/apache/cassandra/index/sai/metrics/IndexMetricsTest.java +++ b/test/unit/org/apache/cassandra/index/sai/metrics/IndexMetricsTest.java @@ -30,6 +30,7 @@ public class IndexMetricsTest extends AbstractMetricsTest { + private static final String TABLE = "table_name"; private static final String INDEX = "table_name_index"; @@ -155,7 +156,7 @@ public void testQueriesCount() int rowCount = 10; for (int i = 0; i < rowCount; i++) - execute("INSERT INTO %s (id1, v1, v2, v3) VALUES (?, ?, '0', [?, 0.0])", Integer.toString(i), i, i); + execute("INSERT INTO %s (id1, v1, v2, v3) VALUES (?, ?, '0', ?)", Integer.toString(i), i, vector(i, i)); assertIndexQueryCount(indexV1, 0L); @@ -180,7 +181,7 @@ public void testQueriesCount() assertIndexQueryCount(indexV1, 4L); assertIndexQueryCount(indexV2, 2L); - String indexV3 = createIndex("CREATE CUSTOM INDEX ON %s (v3) USING 'StorageAttachedIndex'"); + String indexV3 = createIndex("CREATE CUSTOM INDEX ON %s (v3) USING 'StorageAttachedIndex' WITH OPTIONS = {'similarity_function': 'euclidean'}");
assertIndexQueryCount(indexV3, 0L); executeNet("SELECT id1 FROM %s WHERE v2 = '2' ORDER BY v3 ANN OF [5,0] LIMIT 10"); assertIndexQueryCount(indexV1, 4L);
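Taken together with the grammar and restriction changes, the end-to-end behavior these tests build toward can be pictured as a CQL test in the style of the suites above. This sketch is illustrative, not a test from the patch: the schema and data are invented, and the analyzer option follows SAI's existing index_analyzer syntax, which BM25 ranking presupposes:

// Hypothetical usage sketch in the CQLTester style used throughout this patch
createTable("CREATE TABLE %s (id int PRIMARY KEY, body text)");
createIndex("CREATE CUSTOM INDEX ON %s(body) USING 'StorageAttachedIndex' " +
            "WITH OPTIONS = {'index_analyzer': 'standard'}");
execute("INSERT INTO %s (id, body) VALUES (1, 'apple apple banana')");
execute("INSERT INTO %s (id, body) VALUES (2, 'banana')");
// Rows come back ordered by BM25 relevance to the query terms, analogous to ANN OF for vectors
ResultSet rows = executeNet("SELECT id FROM %s ORDER BY body BM25 OF 'apple' LIMIT 10");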