diff --git a/src/antlr/Lexer.g b/src/antlr/Lexer.g
index 3da08245c3d6..dd050c1478e7 100644
--- a/src/antlr/Lexer.g
+++ b/src/antlr/Lexer.g
@@ -227,6 +227,7 @@ K_DROPPED: D R O P P E D;
K_COLUMN: C O L U M N;
K_RECORD: R E C O R D;
K_ANN: A N N;
+K_BM25: B M '2' '5';
// Case-insensitive alpha characters
fragment A: ('a'|'A');
diff --git a/src/antlr/Parser.g b/src/antlr/Parser.g
index c0544620d51d..6b76c9d74bed 100644
--- a/src/antlr/Parser.g
+++ b/src/antlr/Parser.g
@@ -459,14 +459,18 @@ customIndexExpression [WhereClause.Builder clause]
;
orderByClause[List<Ordering.Raw> orderings]
- @init{
+ @init {
Ordering.Direction direction = Ordering.Direction.ASC;
+ Ordering.Raw.Expression expr = null;
}
- : c=cident (K_ANN K_OF t=term)? (K_ASC | K_DESC { direction = Ordering.Direction.DESC; })?
+ : c=cident
+ ( K_ANN K_OF t=term { expr = new Ordering.Raw.Ann(c, t); }
+ | K_BM25 K_OF t=term { expr = new Ordering.Raw.Bm25(c, t); }
+ )?
+ (K_ASC | K_DESC { direction = Ordering.Direction.DESC; })?
{
- Ordering.Raw.Expression expr = (t == null)
- ? new Ordering.Raw.SingleColumn(c)
- : new Ordering.Raw.Ann(c, t);
+ if (expr == null)
+ expr = new Ordering.Raw.SingleColumn(c);
orderings.add(new Ordering.Raw(expr, direction));
}
;
@@ -1969,6 +1973,7 @@ basic_unreserved_keyword returns [String str]
| K_COLUMN
| K_RECORD
| K_ANN
+ | K_BM25
| K_OFFSET
) { $str = $k.text; }
;
diff --git a/src/java/org/apache/cassandra/cql3/GeoDistanceRelation.java b/src/java/org/apache/cassandra/cql3/GeoDistanceRelation.java
index 3d5fb2eeeef7..640e7e600686 100644
--- a/src/java/org/apache/cassandra/cql3/GeoDistanceRelation.java
+++ b/src/java/org/apache/cassandra/cql3/GeoDistanceRelation.java
@@ -141,6 +141,12 @@ protected Restriction newAnnRestriction(TableMetadata table, VariableSpecificati
throw invalidRequest("%s cannot be used with the GEO_DISTANCE function", operator());
}
+ @Override
+ protected Restriction newBm25Restriction(TableMetadata table, VariableSpecifications boundNames)
+ {
+ throw invalidRequest("%s cannot be used with the GEO_DISTANCE function", operator());
+ }
+
@Override
protected Restriction newAnalyzerMatchesRestriction(TableMetadata table, VariableSpecifications boundNames)
{
diff --git a/src/java/org/apache/cassandra/cql3/MultiColumnRelation.java b/src/java/org/apache/cassandra/cql3/MultiColumnRelation.java
index fcd505fcf5df..f56d76e2ced9 100644
--- a/src/java/org/apache/cassandra/cql3/MultiColumnRelation.java
+++ b/src/java/org/apache/cassandra/cql3/MultiColumnRelation.java
@@ -250,6 +250,12 @@ protected Restriction newAnnRestriction(TableMetadata table, VariableSpecificati
throw invalidRequest("%s cannot be used for multi-column relations", operator());
}
+ @Override
+ protected Restriction newBm25Restriction(TableMetadata table, VariableSpecifications boundNames)
+ {
+ throw invalidRequest("%s cannot be used for multi-column relations", operator());
+ }
+
@Override
protected Restriction newAnalyzerMatchesRestriction(TableMetadata table, VariableSpecifications boundNames)
{
diff --git a/src/java/org/apache/cassandra/cql3/Operator.java b/src/java/org/apache/cassandra/cql3/Operator.java
index 9908c64c39ea..d393fc1aa416 100644
--- a/src/java/org/apache/cassandra/cql3/Operator.java
+++ b/src/java/org/apache/cassandra/cql3/Operator.java
@@ -323,7 +323,7 @@ public boolean isSatisfiedBy(AbstractType<?> type,
@Nullable Index.Analyzer indexAnalyzer,
@Nullable Index.Analyzer queryAnalyzer)
{
- return true;
+ throw new UnsupportedOperationException();
}
},
NOT_IN(16)
@@ -523,6 +523,7 @@ private boolean hasToken(AbstractType<?> type, List<ByteBuffer> tokens, ByteBuff
return false;
}
},
+
/**
* An operator that performs a distance bounded approximate nearest neighbor search against a vector column such
* that all result vectors are within a given distance of the query vector. The notable difference between this
@@ -584,6 +585,24 @@ public boolean isSatisfiedBy(AbstractType<?> type,
{
throw new UnsupportedOperationException();
}
+ },
+ BM25(104)
+ {
+ @Override
+ public String toString()
+ {
+ return "BM25";
+ }
+
+ @Override
+ public boolean isSatisfiedBy(AbstractType<?> type,
+ ByteBuffer leftOperand,
+ ByteBuffer rightOperand,
+ @Nullable Index.Analyzer indexAnalyzer,
+ @Nullable Index.Analyzer queryAnalyzer)
+ {
+ throw new UnsupportedOperationException();
+ }
};
/**
diff --git a/src/java/org/apache/cassandra/cql3/Ordering.java b/src/java/org/apache/cassandra/cql3/Ordering.java
index 81aa94a076cd..2dd817818a76 100644
--- a/src/java/org/apache/cassandra/cql3/Ordering.java
+++ b/src/java/org/apache/cassandra/cql3/Ordering.java
@@ -20,6 +20,7 @@
import org.apache.cassandra.cql3.restrictions.SingleColumnRestriction;
import org.apache.cassandra.cql3.restrictions.SingleRestriction;
+import org.apache.cassandra.cql3.statements.SelectStatement;
import org.apache.cassandra.schema.ColumnMetadata;
import org.apache.cassandra.schema.TableMetadata;
@@ -48,6 +49,11 @@ public interface Expression
SingleRestriction toRestriction();
ColumnMetadata getColumn();
+
+ default boolean isScored()
+ {
+ return false;
+ }
}
/**
@@ -118,6 +124,54 @@ public ColumnMetadata getColumn()
{
return column;
}
+
+ @Override
+ public boolean isScored()
+ {
+ return SelectStatement.ANN_USE_SYNTHETIC_SCORE;
+ }
+ }
+
+ /**
+ * An expression used in BM25 ordering.
+ * ORDER BY column BM25 OF value
+ */
+ public static class Bm25 implements Expression
+ {
+ final ColumnMetadata column;
+ final Term queryValue;
+ final Direction direction;
+
+ public Bm25(ColumnMetadata column, Term queryValue, Direction direction)
+ {
+ this.column = column;
+ this.queryValue = queryValue;
+ this.direction = direction;
+ }
+
+ @Override
+ public boolean hasNonClusteredOrdering()
+ {
+ return true;
+ }
+
+ @Override
+ public SingleRestriction toRestriction()
+ {
+ return new SingleColumnRestriction.Bm25Restriction(column, queryValue);
+ }
+
+ @Override
+ public ColumnMetadata getColumn()
+ {
+ return column;
+ }
+
+ @Override
+ public boolean isScored()
+ {
+ return true;
+ }
}
public enum Direction
@@ -190,6 +244,27 @@ public Ordering.Expression bind(TableMetadata table, VariableSpecifications boun
return new Ordering.Ann(column, value, direction);
}
}
+
+ public static class Bm25 implements Expression
+ {
+ final ColumnIdentifier columnId;
+ final Term.Raw queryValue;
+
+ Bm25(ColumnIdentifier column, Term.Raw queryValue)
+ {
+ this.columnId = column;
+ this.queryValue = queryValue;
+ }
+
+ @Override
+ public Ordering.Expression bind(TableMetadata table, VariableSpecifications boundNames, Direction direction)
+ {
+ ColumnMetadata column = table.getExistingColumn(columnId);
+ Term value = queryValue.prepare(table.keyspace, column);
+ value.collectMarkerSpecification(boundNames);
+ return new Ordering.Bm25(column, value, direction);
+ }
+ }
}
}
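Because Raw.Bm25.bind() prepares the query value and calls collectMarkerSpecification on it, the BM25 search text can be supplied as a bind variable rather than a literal. A sketch of the statement shape (schema names are illustrative):

    final class Bm25PreparedExample
    {
        // '?' is legal here because Raw.Bm25.bind() calls value.collectMarkerSpecification(boundNames).
        static final String CQL = "SELECT id, body FROM docs ORDER BY body BM25 OF ? LIMIT 20";
    }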
diff --git a/src/java/org/apache/cassandra/cql3/Relation.java b/src/java/org/apache/cassandra/cql3/Relation.java
index 5cca2d257323..42cf3c9c8287 100644
--- a/src/java/org/apache/cassandra/cql3/Relation.java
+++ b/src/java/org/apache/cassandra/cql3/Relation.java
@@ -202,6 +202,8 @@ public final Restriction toRestriction(TableMetadata table, VariableSpecificatio
return newLikeRestriction(table, boundNames, relationType);
case ANN:
return newAnnRestriction(table, boundNames);
+ case BM25:
+ return newBm25Restriction(table, boundNames);
case ANALYZER_MATCHES:
return newAnalyzerMatchesRestriction(table, boundNames);
default: throw invalidRequest("Unsupported \"!=\" relation: %s", this);
@@ -296,6 +298,11 @@ protected abstract Restriction newSliceRestriction(TableMetadata table,
*/
protected abstract Restriction newAnnRestriction(TableMetadata table, VariableSpecifications boundNames);
+ /**
+ * Creates a new BM25 restriction instance.
+ */
+ protected abstract Restriction newBm25Restriction(TableMetadata table, VariableSpecifications boundNames);
+
/**
* Creates a new Analyzer Matches restriction instance.
*/
diff --git a/src/java/org/apache/cassandra/cql3/SingleColumnRelation.java b/src/java/org/apache/cassandra/cql3/SingleColumnRelation.java
index ec66ad70b529..2b97b99c3562 100644
--- a/src/java/org/apache/cassandra/cql3/SingleColumnRelation.java
+++ b/src/java/org/apache/cassandra/cql3/SingleColumnRelation.java
@@ -21,8 +21,11 @@
import java.util.ArrayList;
import java.util.List;
import java.util.Objects;
+import java.util.stream.Collectors;
import org.apache.cassandra.db.marshal.VectorType;
+import org.apache.cassandra.index.IndexRegistry;
+import org.apache.cassandra.index.sai.analyzer.AnalyzerEqOperatorSupport;
import org.apache.cassandra.schema.ColumnMetadata;
import org.apache.cassandra.schema.TableMetadata;
import org.apache.cassandra.cql3.Term.Raw;
@@ -33,6 +36,7 @@
import org.apache.cassandra.db.marshal.ListType;
import org.apache.cassandra.db.marshal.MapType;
import org.apache.cassandra.exceptions.InvalidRequestException;
+import org.apache.cassandra.service.ClientWarn;
import static org.apache.cassandra.cql3.statements.RequestValidations.checkFalse;
import static org.apache.cassandra.cql3.statements.RequestValidations.checkTrue;
@@ -191,7 +195,29 @@ protected Restriction newEQRestriction(TableMetadata table, VariableSpecificatio
if (mapKey == null)
{
Term term = toTerm(toReceivers(columnDef), value, table.keyspace, boundNames);
- return new SingleColumnRestriction.EQRestriction(columnDef, term);
+ // Leave the restriction as EQ if there is no analyzed index in backwards-compatibility mode
+ var ebi = IndexRegistry.obtain(table).getEqBehavior(columnDef);
+ if (ebi.behavior == IndexRegistry.EqBehavior.EQ)
+ return new SingleColumnRestriction.EQRestriction(columnDef, term);
+
+ // the index is configured to transform EQ into MATCH for backwards compatibility
+ var matchIndexName = ebi.matchIndex.getIndexMetadata() == null ? "Unknown" : ebi.matchIndex.getIndexMetadata().name;
+ if (ebi.behavior == IndexRegistry.EqBehavior.MATCH)
+ {
+ ClientWarn.instance.warn(String.format(AnalyzerEqOperatorSupport.EQ_RESTRICTION_ON_ANALYZED_WARNING,
+ columnDef.toString(),
+ matchIndexName),
+ columnDef);
+ return new SingleColumnRestriction.AnalyzerMatchesRestriction(columnDef, term);
+ }
+
+ // multiple indexes support EQ, this is unsupported
+ assert ebi.behavior == IndexRegistry.EqBehavior.AMBIGUOUS;
+ var eqIndexName = ebi.eqIndex.getIndexMetadata() == null ? "Unknown" : ebi.eqIndex.getIndexMetadata().name;
+ throw invalidRequest(AnalyzerEqOperatorSupport.EQ_AMBIGUOUS_ERROR,
+ columnDef.toString(),
+ matchIndexName,
+ eqIndexName);
}
List<? extends ColumnSpecification> receivers = toReceivers(columnDef);
Term entryKey = toTerm(Collections.singletonList(receivers.get(0)), mapKey, table.keyspace, boundNames);
@@ -333,6 +359,14 @@ protected Restriction newAnnRestriction(TableMetadata table, VariableSpecificati
return new SingleColumnRestriction.AnnRestriction(columnDef, term);
}
+ @Override
+ protected Restriction newBm25Restriction(TableMetadata table, VariableSpecifications boundNames)
+ {
+ ColumnMetadata columnDef = table.getExistingColumn(entity);
+ Term term = toTerm(toReceivers(columnDef), value, table.keyspace, boundNames);
+ return new SingleColumnRestriction.Bm25Restriction(columnDef, term);
+ }
+
@Override
protected Restriction newAnalyzerMatchesRestriction(TableMetadata table, VariableSpecifications boundNames)
{
diff --git a/src/java/org/apache/cassandra/cql3/TokenRelation.java b/src/java/org/apache/cassandra/cql3/TokenRelation.java
index a3ca586eee76..ca849dc82a30 100644
--- a/src/java/org/apache/cassandra/cql3/TokenRelation.java
+++ b/src/java/org/apache/cassandra/cql3/TokenRelation.java
@@ -138,7 +138,13 @@ protected Restriction newLikeRestriction(TableMetadata table, VariableSpecificat
@Override
protected Restriction newAnnRestriction(TableMetadata table, VariableSpecifications boundNames)
{
- throw invalidRequest("%s cannot be used for toekn relations", operator());
+ throw invalidRequest("%s cannot be used for token relations", operator());
+ }
+
+ @Override
+ protected Restriction newBm25Restriction(TableMetadata table, VariableSpecifications boundNames)
+ {
+ throw invalidRequest("%s cannot be used for token relations", operator());
}
@Override
diff --git a/src/java/org/apache/cassandra/cql3/restrictions/ClusteringColumnRestrictions.java b/src/java/org/apache/cassandra/cql3/restrictions/ClusteringColumnRestrictions.java
index 8ccd7fcaf37a..b2c08a065045 100644
--- a/src/java/org/apache/cassandra/cql3/restrictions/ClusteringColumnRestrictions.java
+++ b/src/java/org/apache/cassandra/cql3/restrictions/ClusteringColumnRestrictions.java
@@ -191,7 +191,7 @@ public ClusteringColumnRestrictions.Builder addRestriction(Restriction restricti
SingleRestriction lastRestriction = restrictions.lastRestriction();
ColumnMetadata lastRestrictionStart = lastRestriction.getFirstColumn();
ColumnMetadata newRestrictionStart = newRestriction.getFirstColumn();
- restrictions.addRestriction(newRestriction, isDisjunction, indexRegistry);
+ restrictions.addRestriction(newRestriction, isDisjunction);
checkFalse(lastRestriction.isSlice() && newRestrictionStart.position() > lastRestrictionStart.position(),
"Clustering column \"%s\" cannot be restricted (preceding column \"%s\" is restricted by a non-EQ relation)",
@@ -205,7 +205,7 @@ public ClusteringColumnRestrictions.Builder addRestriction(Restriction restricti
}
else
{
- restrictions.addRestriction(newRestriction, isDisjunction, indexRegistry);
+ restrictions.addRestriction(newRestriction, isDisjunction);
}
return this;
diff --git a/src/java/org/apache/cassandra/cql3/restrictions/PartitionKeySingleRestrictionSet.java b/src/java/org/apache/cassandra/cql3/restrictions/PartitionKeySingleRestrictionSet.java
index 2536f94f482d..ccb144b4b174 100644
--- a/src/java/org/apache/cassandra/cql3/restrictions/PartitionKeySingleRestrictionSet.java
+++ b/src/java/org/apache/cassandra/cql3/restrictions/PartitionKeySingleRestrictionSet.java
@@ -191,7 +191,7 @@ public PartitionKeyRestrictions build(IndexRegistry indexRegistry, boolean isDis
if (restriction.isOnToken())
return buildWithTokens(restrictionSet, i, indexRegistry);
- restrictionSet.addRestriction((SingleRestriction) restriction, isDisjunction, indexRegistry);
+ restrictionSet.addRestriction((SingleRestriction) restriction, isDisjunction);
}
return buildPartitionKeyRestrictions(restrictionSet);
diff --git a/src/java/org/apache/cassandra/cql3/restrictions/RestrictionSet.java b/src/java/org/apache/cassandra/cql3/restrictions/RestrictionSet.java
index 4cde1c434f33..dacb018e6186 100644
--- a/src/java/org/apache/cassandra/cql3/restrictions/RestrictionSet.java
+++ b/src/java/org/apache/cassandra/cql3/restrictions/RestrictionSet.java
@@ -398,45 +398,36 @@ private Builder()
{
}
- public void addRestriction(SingleRestriction restriction, boolean isDisjunction, IndexRegistry indexRegistry)
+ public void addRestriction(SingleRestriction restriction, boolean isDisjunction)
{
List<ColumnMetadata> columnDefs = restriction.getColumnDefs();
if (isDisjunction)
{
// If this restriction is part of a disjunction query then we don't want
- // to merge the restrictions (if that is possible), we just add the
- // restriction to the set of restrictions for the column.
+ // to merge the restrictions, we just add the new restriction
addRestrictionForColumns(columnDefs, restriction, null);
}
else
{
- // In some special cases such as EQ in analyzed index we need to skip merging the restriction,
- // so we can send multiple EQ restrictions to the index.
- if (restriction.skipMerge(indexRegistry))
- {
- addRestrictionForColumns(columnDefs, restriction, null);
- return;
- }
-
- // If this restriction isn't part of a disjunction then we need to get
- // the set of existing restrictions for the column and merge them with the
- // new restriction
+ // Restrictions ANDed together against the same columns should be merged.
Set<SingleRestriction> existingRestrictions = getRestrictions(newRestrictions, columnDefs);
- SingleRestriction merged = restriction;
- Set<SingleRestriction> replacedRestrictions = new HashSet<>();
-
- for (SingleRestriction existing : existingRestrictions)
+ // merge the new restriction into an existing one. note that there is only ever a single
+ // restriction (per column), UNLESS one is ORDER BY BM25 and the other is MATCH.
+ for (var existing : existingRestrictions)
{
- if (!existing.skipMerge(indexRegistry))
+ // shouldMerge exists for the BM25/MATCH case
+ if (existing.shouldMerge(restriction))
{
- merged = existing.mergeWith(merged);
- replacedRestrictions.add(existing);
+ var merged = existing.mergeWith(restriction);
+ addRestrictionForColumns(merged.getColumnDefs(), merged, Set.of(existing));
+ return;
}
}
- addRestrictionForColumns(merged.getColumnDefs(), merged, replacedRestrictions);
+ // no existing restrictions that we should merge the new one with, add a new one
+ addRestrictionForColumns(columnDefs, restriction, null);
}
}
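The shouldMerge hook referenced above exists so that an analyzer MATCH restriction and a BM25 ordering restriction on the same column can coexist as separate restrictions instead of being merged and rejected. A sketch of the query shape this permits, assuming the ':' analyzer-match operator and a hypothetical schema:

    final class Bm25WithMatchExample
    {
        // WHERE body : '...' filters via ANALYZER_MATCHES; ORDER BY ... BM25 OF ranks the survivors.
        static final String CQL =
            "SELECT id FROM docs WHERE body : 'cassandra' ORDER BY body BM25 OF 'cassandra bm25 ranking' LIMIT 10";
    }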
diff --git a/src/java/org/apache/cassandra/cql3/restrictions/SingleColumnRestriction.java b/src/java/org/apache/cassandra/cql3/restrictions/SingleColumnRestriction.java
index 46c2afcc1d8f..cd44a6f09000 100644
--- a/src/java/org/apache/cassandra/cql3/restrictions/SingleColumnRestriction.java
+++ b/src/java/org/apache/cassandra/cql3/restrictions/SingleColumnRestriction.java
@@ -24,15 +24,19 @@
import java.util.List;
import java.util.Map;
-import org.apache.cassandra.db.filter.RowFilter;
-import org.apache.cassandra.db.filter.ANNOptions;
-import org.apache.cassandra.schema.ColumnMetadata;
-import org.apache.cassandra.cql3.*;
+import org.apache.cassandra.cql3.MarkerOrTerms;
+import org.apache.cassandra.cql3.Operator;
+import org.apache.cassandra.cql3.QueryOptions;
+import org.apache.cassandra.cql3.Term;
+import org.apache.cassandra.cql3.Terms;
import org.apache.cassandra.cql3.functions.Function;
import org.apache.cassandra.cql3.statements.Bound;
import org.apache.cassandra.db.MultiClusteringBuilder;
+import org.apache.cassandra.db.filter.ANNOptions;
+import org.apache.cassandra.db.filter.RowFilter;
import org.apache.cassandra.index.Index;
import org.apache.cassandra.index.IndexRegistry;
+import org.apache.cassandra.schema.ColumnMetadata;
import org.apache.cassandra.serializers.ListSerializer;
import org.apache.cassandra.transport.ProtocolVersion;
import org.apache.cassandra.utils.ByteBufferUtil;
@@ -76,12 +80,17 @@ public ColumnMetadata getLastColumn()
@Override
public boolean hasSupportingIndex(IndexRegistry indexRegistry)
+ {
+ return findSupportingIndex(indexRegistry) != null;
+ }
+
+ public Index findSupportingIndex(IndexRegistry indexRegistry)
{
for (Index index : indexRegistry.listIndexes())
if (isSupportedBy(index))
- return true;
+ return index;
- return false;
+ return null;
}
@Override
@@ -190,24 +199,6 @@ public String toString()
return String.format("EQ(%s)", term);
}
- @Override
- public boolean skipMerge(IndexRegistry indexRegistry)
- {
- // We should skip merging this EQ if there is an analyzed index for this column that supports EQ,
- // so there can be multiple EQs for the same column.
-
- if (indexRegistry == null)
- return false;
-
- for (Index index : indexRegistry.listIndexes())
- {
- if (index.supportsExpression(columnDef, Operator.ANALYZER_MATCHES) &&
- index.supportsExpression(columnDef, Operator.EQ))
- return true;
- }
- return false;
- }
-
@Override
public SingleRestriction doMergeWith(SingleRestriction otherRestriction)
{
@@ -1188,6 +1179,86 @@ public boolean isBoundedAnn()
}
}
+ public static final class Bm25Restriction extends SingleColumnRestriction
+ {
+ private final Term value;
+
+ public Bm25Restriction(ColumnMetadata columnDef, Term value)
+ {
+ super(columnDef);
+ this.value = value;
+ }
+
+ public ByteBuffer value(QueryOptions options)
+ {
+ return value.bindAndGet(options);
+ }
+
+ @Override
+ public void addFunctionsTo(List<Function> functions)
+ {
+ value.addFunctionsTo(functions);
+ }
+
+ @Override
+ MultiColumnRestriction toMultiColumnRestriction()
+ {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public void addToRowFilter(RowFilter.Builder filter, IndexRegistry indexRegistry, QueryOptions options, ANNOptions annOptions)
+ {
+ var index = findSupportingIndex(indexRegistry);
+ var valueBytes = value.bindAndGet(options);
+ var terms = index.getQueryAnalyzer().get().analyze(valueBytes);
+ if (terms.isEmpty())
+ throw invalidRequest("BM25 query must contain at least one term (perhaps your analyzer is discarding tokens you didn't expect)");
+ filter.add(columnDef, Operator.BM25, valueBytes);
+ }
+
+ @Override
+ public MultiClusteringBuilder appendTo(MultiClusteringBuilder builder, QueryOptions options)
+ {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public String toString()
+ {
+ return String.format("BM25(%s)", value);
+ }
+
+ @Override
+ public SingleRestriction doMergeWith(SingleRestriction otherRestriction)
+ {
+ throw invalidRequest("%s cannot be restricted by both BM25 and %s", columnDef.name, otherRestriction.toString());
+ }
+
+ @Override
+ protected boolean isSupportedBy(Index index)
+ {
+ return index.supportsExpression(columnDef, Operator.BM25);
+ }
+
+ @Override
+ public boolean isIndexBasedOrdering()
+ {
+ return true;
+ }
+
+ @Override
+ public boolean shouldMerge(SingleRestriction other)
+ {
+ // we don't want to merge MATCH restrictions with ORDER BY BM25
+ // so shouldMerge = false for that scenario, and true for others
+ // (because even though we can't meaningfully merge with others, we want doMergeWith to be called to throw)
+ //
+ // (Note that because ORDER BY is processed before WHERE, we only need this check in the BM25 class)
+ return !other.isAnalyzerMatches();
+ }
+ }
+
/**
* A Bounded ANN Restriction is one that uses a similarity score as the limiting factor for ANN instead of a number
* of results.
@@ -1335,10 +1406,12 @@ public String toString()
@Override
public SingleRestriction doMergeWith(SingleRestriction otherRestriction)
{
- if (!(otherRestriction.isAnalyzerMatches()))
+ if (!otherRestriction.isAnalyzerMatches())
throw invalidRequest(CANNOT_BE_MERGED_ERROR, columnDef.name);
- List<Term> otherValues = ((AnalyzerMatchesRestriction) otherRestriction).getValues();
+ List<Term> otherValues = otherRestriction instanceof AnalyzerMatchesRestriction
+ ? ((AnalyzerMatchesRestriction) otherRestriction).getValues()
+ : List.of(((EQRestriction) otherRestriction).term);
List<Term> newValues = new ArrayList<>(values.size() + otherValues.size());
newValues.addAll(values);
newValues.addAll(otherValues);
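Bm25Restriction.addToRowFilter analyzes the query text eagerly and rejects it if the analyzer emits no terms. For example, with an analyzer configured to drop stopwords (an assumption for illustration), a query consisting only of stopwords fails up front rather than silently returning nothing:

    final class Bm25EmptyAnalysisExample
    {
        // Expected to fail with: "BM25 query must contain at least one term
        // (perhaps your analyzer is discarding tokens you didn't expect)"
        static final String CQL = "SELECT id FROM docs ORDER BY body BM25 OF 'the of and' LIMIT 10";
    }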
diff --git a/src/java/org/apache/cassandra/cql3/restrictions/SingleRestriction.java b/src/java/org/apache/cassandra/cql3/restrictions/SingleRestriction.java
index 595451f812de..bdd80badc0ae 100644
--- a/src/java/org/apache/cassandra/cql3/restrictions/SingleRestriction.java
+++ b/src/java/org/apache/cassandra/cql3/restrictions/SingleRestriction.java
@@ -20,7 +20,6 @@
import org.apache.cassandra.cql3.QueryOptions;
import org.apache.cassandra.cql3.statements.Bound;
import org.apache.cassandra.db.MultiClusteringBuilder;
-import org.apache.cassandra.index.IndexRegistry;
/**
* A single restriction/clause on one or multiple column.
@@ -97,17 +96,6 @@ public default boolean isInclusive(Bound b)
return true;
}
- /**
- * Checks if this restriction shouldn't be merged with other restrictions.
- *
- * @param indexRegistry the index registry
- * @return {@code true} if this shouldn't be merged with other restrictions
- */
- default boolean skipMerge(IndexRegistry indexRegistry)
- {
- return false;
- }
-
/**
* Merges this restriction with the specified one.
*
@@ -141,4 +129,16 @@ public default MultiClusteringBuilder appendBoundTo(MultiClusteringBuilder build
{
return appendTo(builder, options);
}
+
+ /**
+ * @return true if the other restriction should be merged with this one.
+ * This is NOT for preventing illegal combinations of restrictions, e.g.
+ * a=1 AND a=2; that is handled by mergeWith. Instead, this is for the case
+ * where we want two completely different semantics against the same column.
+ * Currently the only such case is BM25 with MATCH.
+ */
+ default boolean shouldMerge(SingleRestriction other)
+ {
+ return true;
+ }
}
diff --git a/src/java/org/apache/cassandra/cql3/restrictions/StatementRestrictions.java b/src/java/org/apache/cassandra/cql3/restrictions/StatementRestrictions.java
index b20d658867e8..ac49d7d6f694 100644
--- a/src/java/org/apache/cassandra/cql3/restrictions/StatementRestrictions.java
+++ b/src/java/org/apache/cassandra/cql3/restrictions/StatementRestrictions.java
@@ -83,6 +83,7 @@ public class StatementRestrictions
"Restriction on partition key column %s must not be nested under OR operator";
public static final String GEO_DISTANCE_REQUIRES_INDEX_MESSAGE = "GEO_DISTANCE requires the vector column to be indexed";
+ public static final String BM25_ORDERING_REQUIRES_ANALYZED_INDEX_MESSAGE = "BM25 ordering on column %s requires an analyzed index";
public static final String NON_CLUSTER_ORDERING_REQUIRES_INDEX_MESSAGE = "Ordering on non-clustering column %s requires the column to be indexed.";
public static final String NON_CLUSTER_ORDERING_REQUIRES_ALL_RESTRICTED_NON_PARTITION_KEY_COLUMNS_INDEXED_MESSAGE =
"Ordering on non-clustering column requires each restricted column to be indexed except for fully-specified partition keys";
@@ -445,7 +446,7 @@ else if (def.isClusteringColumn() && nestingLevel == 0)
}
else
{
- nonPrimaryKeyRestrictionSet.addRestriction((SingleRestriction) restriction, element.isDisjunction(), indexRegistry);
+ nonPrimaryKeyRestrictionSet.addRestriction((SingleRestriction) restriction, element.isDisjunction());
}
}
}
@@ -685,7 +686,8 @@ else if (indexOrderings.size() == 1)
if (orderings.size() > 1)
throw new InvalidRequestException("Cannot combine clustering column ordering with non-clustering column ordering");
Ordering ordering = indexOrderings.get(0);
- if (ordering.direction != Ordering.Direction.ASC && ordering.expression instanceof Ordering.Ann)
+ // TODO remove the instanceof with SelectStatement.ANN_USE_SYNTHETIC_SCORE.
+ if (ordering.direction != Ordering.Direction.ASC && (ordering.expression.isScored() || ordering.expression instanceof Ordering.Ann))
throw new InvalidRequestException("Descending ANN ordering is not supported");
if (!ENABLE_SAI_GENERAL_ORDER_BY && ordering.expression instanceof Ordering.SingleColumn)
throw new InvalidRequestException("SAI based ORDER BY on non-vector column is not supported");
@@ -698,10 +700,14 @@ else if (indexOrderings.size() == 1)
throw new InvalidRequestException(String.format("SAI based ordering on column %s of type %s is not supported",
restriction.getFirstColumn(),
restriction.getFirstColumn().type.asCQL3Type()));
- throw new InvalidRequestException(String.format(NON_CLUSTER_ORDERING_REQUIRES_INDEX_MESSAGE,
- restriction.getFirstColumn()));
+ if (ordering.expression instanceof Ordering.Bm25)
+ throw new InvalidRequestException(String.format(BM25_ORDERING_REQUIRES_ANALYZED_INDEX_MESSAGE,
+ restriction.getFirstColumn()));
+ else
+ throw new InvalidRequestException(String.format(NON_CLUSTER_ORDERING_REQUIRES_INDEX_MESSAGE,
+ restriction.getFirstColumn()));
}
- receiver.addRestriction(restriction, false, indexRegistry);
+ receiver.addRestriction(restriction, false);
}
}
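With the new BM25_ORDERING_REQUIRES_ANALYZED_INDEX_MESSAGE, a BM25 ordering against a column that has no analyzed index now gets a targeted error instead of the generic non-clustered-ordering message. A sketch (hypothetical schema; message text from the patch):

    final class Bm25NoIndexExample
    {
        // Fails with: "BM25 ordering on column body requires an analyzed index"
        static final String CQL = "SELECT id FROM docs ORDER BY body BM25 OF 'query terms' LIMIT 10";
    }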
diff --git a/src/java/org/apache/cassandra/cql3/selection/ColumnFilterFactory.java b/src/java/org/apache/cassandra/cql3/selection/ColumnFilterFactory.java
index 63fa0520101e..00225cca4108 100644
--- a/src/java/org/apache/cassandra/cql3/selection/ColumnFilterFactory.java
+++ b/src/java/org/apache/cassandra/cql3/selection/ColumnFilterFactory.java
@@ -38,9 +38,21 @@ abstract class ColumnFilterFactory
*/
abstract ColumnFilter newInstance(List<Selector> selectors);
- public static ColumnFilterFactory wildcard(TableMetadata table)
+ public static ColumnFilterFactory wildcard(TableMetadata table, Set<ColumnMetadata> orderingColumns)
{
- return new PrecomputedColumnFilter(ColumnFilter.all(table));
+ ColumnFilter cf;
+ if (orderingColumns.isEmpty())
+ {
+ cf = ColumnFilter.all(table);
+ }
+ else
+ {
+ ColumnFilter.Builder builder = ColumnFilter.selectionBuilder();
+ builder.addAll(table.regularAndStaticColumns());
+ builder.addAll(orderingColumns);
+ cf = builder.build();
+ }
+ return new PrecomputedColumnFilter(cf);
}
public static ColumnFilterFactory fromColumns(TableMetadata table,
diff --git a/src/java/org/apache/cassandra/cql3/selection/Selection.java b/src/java/org/apache/cassandra/cql3/selection/Selection.java
index 02aae61dd5ff..12d8aa014e19 100644
--- a/src/java/org/apache/cassandra/cql3/selection/Selection.java
+++ b/src/java/org/apache/cassandra/cql3/selection/Selection.java
@@ -43,10 +43,24 @@ public abstract class Selection
private static final Predicate<ColumnMetadata> STATIC_COLUMN_FILTER = (column) -> column.isStatic();
private final TableMetadata table;
+
+ // Full list of columns needed for processing the query, including selected columns, ordering columns,
+ // and columns needed for restrictions. Wildcard columns are fully materialized here.
+ //
+ // This also includes synthetic columns, because unlike all the other not-physical-columns selectables, they are
+ // computed on the replica instead of the coordinator and so, like physical columns, they need to be sent back
+ // as part of the result.
private final List<ColumnMetadata> columns;
+
+ // maps ColumnSpecifications (columns, function calls, aliases) to the columns backing them
private final SelectionColumnMapping columnMapping;
+
+ // metadata matching the ColumnSpecifications
protected final ResultSet.ResultMetadata metadata;
+
+ // creates a ColumnFilter that breaks columns into `queried` and `fetched`
protected final ColumnFilterFactory columnFilterFactory;
+
protected final boolean isJson;
// Columns used to order the result set for JSON queries with post ordering.
@@ -126,10 +140,15 @@ public ResultSet.ResultMetadata getResultMetadata()
}
public static Selection wildcard(TableMetadata table, boolean isJson, boolean returnStaticContentOnPartitionWithNoRows)
+ {
+ return wildcard(table, Collections.emptySet(), isJson, returnStaticContentOnPartitionWithNoRows);
+ }
+
+ public static Selection wildcard(TableMetadata table, Set<ColumnMetadata> orderingColumns, boolean isJson, boolean returnStaticContentOnPartitionWithNoRows)
{
List<ColumnMetadata> all = new ArrayList<>(table.columns().size());
Iterators.addAll(all, table.allColumnsInSelectOrder());
- return new SimpleSelection(table, all, Collections.emptySet(), true, isJson, returnStaticContentOnPartitionWithNoRows);
+ return new SimpleSelection(table, all, orderingColumns, true, isJson, returnStaticContentOnPartitionWithNoRows);
}
public static Selection wildcardWithGroupBy(TableMetadata table,
@@ -400,7 +419,7 @@ public SimpleSelection(TableMetadata table,
selectedColumns,
orderingColumns,
SelectionColumnMapping.simpleMapping(selectedColumns),
- isWildcard ? ColumnFilterFactory.wildcard(table)
+ isWildcard ? ColumnFilterFactory.wildcard(table, orderingColumns)
: ColumnFilterFactory.fromColumns(table, selectedColumns, orderingColumns, Collections.emptySet(), returnStaticContentOnPartitionWithNoRows),
isWildcard,
isJson);
diff --git a/src/java/org/apache/cassandra/cql3/statements/SelectStatement.java b/src/java/org/apache/cassandra/cql3/statements/SelectStatement.java
index 58b75aacab5a..63fd7abd2b70 100644
--- a/src/java/org/apache/cassandra/cql3/statements/SelectStatement.java
+++ b/src/java/org/apache/cassandra/cql3/statements/SelectStatement.java
@@ -40,6 +40,7 @@
import org.apache.cassandra.cql3.restrictions.ExternalRestriction;
import org.apache.cassandra.cql3.restrictions.Restrictions;
import org.apache.cassandra.cql3.selection.SortedRowsBuilder;
+import org.apache.cassandra.db.marshal.FloatType;
import org.apache.cassandra.guardrails.Guardrails;
import org.apache.cassandra.sensors.SensorsCustomParams;
import org.apache.cassandra.schema.ColumnMetadata;
@@ -105,6 +106,12 @@
*/
public class SelectStatement implements CQLStatement.SingleKeyspaceCqlStatement
{
+ // TODO remove this when we no longer need to downgrade to replicas that don't know about synthetic columns,
+ // and the related code in
+ // - StatementRestrictions.addOrderingRestrictions
+ // - StorageAttachedIndexSearcher.PrimaryKeyIterator constructor
+ public static final boolean ANN_USE_SYNTHETIC_SCORE = Boolean.parseBoolean(System.getProperty("cassandra.sai.ann_use_synthetic_score", "false"));
+
private static final Logger logger = LoggerFactory.getLogger(SelectStatement.class);
private static final NoSpamLogger noSpamLogger = NoSpamLogger.getLogger(SelectStatement.logger, 1, TimeUnit.MINUTES);
public static final String TOPK_CONSISTENCY_LEVEL_ERROR = "Top-K queries can only be run with consistency level ONE/LOCAL_ONE. Consistency level %s was used.";
@@ -1100,12 +1107,16 @@ void processPartition(RowIterator partition, QueryOptions options, ResultSetBuil
case CLUSTERING:
result.add(row.clustering().bufferAt(def.position()));
break;
+ case SYNTHETIC:
+ // treat as REGULAR
case REGULAR:
result.add(row.getColumnData(def), nowInSec);
break;
case STATIC:
result.add(staticRow.getColumnData(def), nowInSec);
break;
+ default:
+ throw new AssertionError();
}
}
}
@@ -1191,6 +1202,9 @@ public SelectStatement prepare(boolean forView, UnaryOperator<String> keyspaceMa
List<Selectable> selectables = RawSelector.toSelectables(selectClause, table);
boolean containsOnlyStaticColumns = selectOnlyStaticColumns(table, selectables);
+ // Besides actual restrictions (where clauses), prepareRestrictions will include pseudo-restrictions
+ // on indexed columns to allow pushing ORDER BY into the index; see StatementRestrictions::addOrderingRestrictions.
+ // Therefore, we don't want to convert an ANN Ordering column into a +score column until after that.
List<Ordering> orderings = getOrderings(table);
StatementRestrictions restrictions = prepareRestrictions(
table, bindVariables, orderings, containsOnlyStaticColumns, forView);
@@ -1198,6 +1212,11 @@ public SelectStatement prepare(boolean forView, UnaryOperator keyspaceMa
// If we order post-query, the sorted column needs to be in the ResultSet for sorting,
// even if we don't ultimately ship them to the client (CASSANDRA-4911).
Map<ColumnMetadata, Ordering> orderingColumns = getOrderingColumns(orderings);
+ // +score column for ANN/BM25
+ var scoreOrdering = getScoreOrdering(orderings);
+ assert scoreOrdering == null || orderingColumns.isEmpty() : "can't have both scored ordering and column ordering";
+ if (scoreOrdering != null)
+ orderingColumns = scoreOrdering;
Set<ColumnMetadata> resultSetOrderingColumns = getResultSetOrdering(restrictions, orderingColumns);
Selection selection = prepareSelection(table,
@@ -1226,9 +1245,9 @@ public SelectStatement prepare(boolean forView, UnaryOperator keyspaceMa
if (!orderingColumns.isEmpty())
{
assert !forView;
- verifyOrderingIsAllowed(restrictions, orderingColumns);
+ verifyOrderingIsAllowed(table, restrictions, orderingColumns);
orderingComparator = getOrderingComparator(selection, restrictions, orderingColumns);
- isReversed = isReversed(table, orderingColumns, restrictions);
+ isReversed = isReversed(table, orderingColumns);
if (isReversed && orderingComparator != null)
orderingComparator = orderingComparator.reverse();
}
@@ -1252,6 +1271,21 @@ public SelectStatement prepare(boolean forView, UnaryOperator keyspaceMa
options);
}
+ private Map<ColumnMetadata, Ordering> getScoreOrdering(List<Ordering> orderings)
+ {
+ if (orderings.isEmpty())
+ return null;
+
+ var expr = orderings.get(0).expression;
+ if (!expr.isScored())
+ return null;
+
+ // Create synthetic score column
+ ColumnMetadata sourceColumn = expr.getColumn();
+ var cm = ColumnMetadata.syntheticColumn(sourceColumn.ksName, sourceColumn.cfName, ColumnMetadata.SYNTHETIC_SCORE_ID, FloatType.instance);
+ return Map.of(cm, orderings.get(0));
+ }
+
private Set<ColumnMetadata> getResultSetOrdering(StatementRestrictions restrictions, Map<ColumnMetadata, Ordering> orderingColumns)
{
if (restrictions.keyIsInRelation() || orderingColumns.values().stream().anyMatch(o -> o.expression.hasNonClusteredOrdering()))
@@ -1269,8 +1303,9 @@ private Selection prepareSelection(TableMetadata table,
if (selectables.isEmpty()) // wildcard query
{
- return hasGroupBy ? Selection.wildcardWithGroupBy(table, boundNames, parameters.isJson, restrictions.returnStaticContentOnPartitionWithNoRows())
- : Selection.wildcard(table, parameters.isJson, restrictions.returnStaticContentOnPartitionWithNoRows());
+ return hasGroupBy
+ ? Selection.wildcardWithGroupBy(table, boundNames, parameters.isJson, restrictions.returnStaticContentOnPartitionWithNoRows())
+ : Selection.wildcard(table, parameters.isJson, restrictions.returnStaticContentOnPartitionWithNoRows());
}
return Selection.fromSelectors(table,
@@ -1312,13 +1347,14 @@ private Map<ColumnMetadata, Ordering> getOrderingColumns(List<Ordering> ordering
if (orderings.isEmpty())
return Collections.emptyMap();
- Map<ColumnMetadata, Ordering> orderingColumns = new LinkedHashMap<>();
- for (Ordering ordering : orderings)
- {
- ColumnMetadata column = ordering.expression.getColumn();
- orderingColumns.put(column, ordering);
- }
- return orderingColumns;
+ return orderings.stream()
+ .filter(ordering -> !ordering.expression.isScored())
+ .collect(Collectors.toMap(ordering -> ordering.expression.getColumn(),
+ ordering -> ordering,
+ (a, b) -> {
+ throw new IllegalStateException("Duplicate keys");
+ },
+ LinkedHashMap::new));
}
private List<Ordering> getOrderings(TableMetadata table)
@@ -1365,12 +1401,28 @@ private Term prepareLimit(VariableSpecifications boundNames, Term.Raw limit,
return prepLimit;
}
- private static void verifyOrderingIsAllowed(StatementRestrictions restrictions, Map<ColumnMetadata, Ordering> orderingColumns) throws InvalidRequestException
+ private static void verifyOrderingIsAllowed(TableMetadata table, StatementRestrictions restrictions, Map<ColumnMetadata, Ordering> orderingColumns) throws InvalidRequestException
{
if (orderingColumns.values().stream().anyMatch(o -> o.expression.hasNonClusteredOrdering()))
return;
+
checkFalse(restrictions.usesSecondaryIndexing(), "ORDER BY with 2ndary indexes is not supported.");
checkFalse(restrictions.isKeyRange(), "ORDER BY is only supported when the partition key is restricted by an EQ or an IN.");
+
+ // check that clustering columns are valid
+ int i = 0;
+ for (var entry : orderingColumns.entrySet())
+ {
+ ColumnMetadata def = entry.getKey();
+ checkTrue(def.isClusteringColumn(),
+ "Order by is currently only supported on indexed columns and the clustered columns of the PRIMARY KEY, got %s", def.name);
+ while (i != def.position())
+ {
+ checkTrue(restrictions.isColumnRestrictedByEq(table.clusteringColumns().get(i++)),
+ "Ordering by clustered columns must follow the declared order in the PRIMARY KEY");
+ }
+ i++;
+ }
}
private static void validateDistinctSelection(TableMetadata metadata,
@@ -1485,35 +1537,30 @@ private ColumnComparator<List<ByteBuffer>> getOrderingComparator(Selection selec
: new CompositeComparator(sorters, idToSort);
}
- private boolean isReversed(TableMetadata table, Map<ColumnMetadata, Ordering> orderingColumns, StatementRestrictions restrictions) throws InvalidRequestException
+ private boolean isReversed(TableMetadata table, Map<ColumnMetadata, Ordering> orderingColumns) throws InvalidRequestException
{
- // Nonclustered ordering handles descending logic in a different way
+ // Nonclustered ordering handles descending logic through ScoreOrderedResultRetriever and TKP
if (orderingColumns.values().stream().anyMatch(o -> o.expression.hasNonClusteredOrdering()))
return false;
- Boolean[] reversedMap = new Boolean[table.clusteringColumns().size()];
- int i = 0;
+ Boolean[] clusteredMap = new Boolean[table.clusteringColumns().size()];
for (var entry : orderingColumns.entrySet())
{
ColumnMetadata def = entry.getKey();
Ordering ordering = entry.getValue();
- boolean reversed = ordering.direction == Ordering.Direction.DESC;
-
- // VSTODO move this to verifyOrderingIsAllowed?
- checkTrue(def.isClusteringColumn(),
- "Order by is currently only supported on the clustered columns of the PRIMARY KEY, got %s", def.name);
- while (i != def.position())
- {
- checkTrue(restrictions.isColumnRestrictedByEq(table.clusteringColumns().get(i++)),
- "Order by currently only supports the ordering of columns following their declared order in the PRIMARY KEY");
- }
- i++;
- reversedMap[def.position()] = (reversed != def.isReversedType());
+ // We defined ANN OF to be ASC ordering, as in, "order by near-ness". But since score goes from
+ // 0 (worst) to 1 (closest), we need to reverse the ordering for the comparator when we're sorting
+ // by the synthetic +score column.
+ boolean cqlReversed = ordering.direction == Ordering.Direction.DESC;
+ if (def.position() == ColumnMetadata.NO_POSITION)
+ return ordering.expression.isScored() || cqlReversed;
+ else
+ clusteredMap[def.position()] = (cqlReversed != def.isReversedType());
}
- // Check that all boolean in reversedMap, if set, agrees
+ // Check that all booleans in clusteredMap, if set, agree
Boolean isReversed = null;
- for (Boolean b : reversedMap)
+ for (Boolean b : clusteredMap)
{
// Column on which order is specified can be in any order
if (b == null)
@@ -1658,7 +1705,14 @@ public int compare(T o1, T o2)
{
return wrapped.compare(o2, o1);
}
+
+ @Override
+ public boolean indexOrdering()
+ {
+ return wrapped.indexOrdering();
+ }
}
+
/**
* Used in orderResults(...) method when single 'ORDER BY' condition where given
*/
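The clustering checks that moved from isReversed into verifyOrderingIsAllowed keep their old semantics: a clustering column can be ordered on only if every earlier clustering column is restricted by EQ. A sketch, assuming a hypothetical table with PRIMARY KEY (pk, c1, c2):

    final class ClusteringOrderExamples
    {
        // Allowed: c1 is EQ-restricted, so ordering may start at c2.
        static final String OK  = "SELECT * FROM t WHERE pk = 0 AND c1 = 1 ORDER BY c2 ASC";
        // Rejected: c1 is unrestricted, so ordering by c2 skips a clustering column.
        static final String BAD = "SELECT * FROM t WHERE pk = 0 ORDER BY c2 ASC";
    }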
diff --git a/src/java/org/apache/cassandra/db/Columns.java b/src/java/org/apache/cassandra/db/Columns.java
index 9a904de12253..45b3da97a596 100644
--- a/src/java/org/apache/cassandra/db/Columns.java
+++ b/src/java/org/apache/cassandra/db/Columns.java
@@ -28,6 +28,7 @@
import net.nicoulaj.compilecommand.annotations.DontInline;
import org.apache.cassandra.cql3.ColumnIdentifier;
+import org.apache.cassandra.db.marshal.AbstractType;
import org.apache.cassandra.db.marshal.SetType;
import org.apache.cassandra.db.marshal.UTF8Type;
import org.apache.cassandra.db.rows.ColumnData;
@@ -36,6 +37,7 @@
import org.apache.cassandra.io.util.DataOutputPlus;
import org.apache.cassandra.schema.ColumnMetadata;
import org.apache.cassandra.schema.TableMetadata;
+import org.apache.cassandra.serializers.AbstractTypeSerializer;
import org.apache.cassandra.utils.ByteBufferUtil;
import org.apache.cassandra.utils.ObjectSizes;
import org.apache.cassandra.utils.SearchIterator;
@@ -459,37 +461,107 @@ public String toString()
public static class Serializer
{
+ AbstractTypeSerializer typeSerializer = new AbstractTypeSerializer();
+
public void serialize(Columns columns, DataOutputPlus out) throws IOException
{
- out.writeUnsignedVInt(columns.size());
+ int regularCount = 0;
+ int syntheticCount = 0;
+
+ // Count regular and synthetic columns
+ for (ColumnMetadata column : columns)
+ {
+ if (column.isSynthetic())
+ syntheticCount++;
+ else
+ regularCount++;
+ }
+
+ // Jam the two counts into a single value to avoid massive backwards compatibility issues
+ long packedCount = getPackedCount(syntheticCount, regularCount);
+ out.writeUnsignedVInt(packedCount);
+
+ // First pass - write synthetic columns with their full metadata
for (ColumnMetadata column : columns)
- ByteBufferUtil.writeWithVIntLength(column.name.bytes, out);
+ {
+ if (column.isSynthetic())
+ {
+ ByteBufferUtil.writeWithVIntLength(column.name.bytes, out);
+ typeSerializer.serialize(column.type, out);
+ }
+ }
+
+ // Second pass - write regular columns
+ for (ColumnMetadata column : columns)
+ {
+ if (!column.isSynthetic())
+ ByteBufferUtil.writeWithVIntLength(column.name.bytes, out);
+ }
+ }
+
+ private static long getPackedCount(int syntheticCount, int regularCount)
+ {
+ // Left shift of 20 gives us over 1M regular columns, and up to 4 synthetic columns
+ // before overflowing to a 4th byte.
+ return ((long) syntheticCount << 20) | regularCount;
}
public long serializedSize(Columns columns)
{
- long size = TypeSizes.sizeofUnsignedVInt(columns.size());
+ int regularCount = 0;
+ int syntheticCount = 0;
+ long size = 0;
+
+ // Count and calculate sizes
for (ColumnMetadata column : columns)
- size += ByteBufferUtil.serializedSizeWithVIntLength(column.name.bytes);
- return size;
+ {
+ if (column.isSynthetic())
+ {
+ syntheticCount++;
+ size += ByteBufferUtil.serializedSizeWithVIntLength(column.name.bytes);
+ size += typeSerializer.serializedSize(column.type);
+ }
+ else
+ {
+ regularCount++;
+ size += ByteBufferUtil.serializedSizeWithVIntLength(column.name.bytes);
+ }
+ }
+
+ return TypeSizes.sizeofUnsignedVInt(getPackedCount(syntheticCount, regularCount)) +
+ size;
}
public Columns deserialize(DataInputPlus in, TableMetadata metadata) throws IOException
{
- int length = (int)in.readUnsignedVInt();
try (BTree.FastBuilder builder = BTree.fastBuilder())
{
- for (int i = 0; i < length; i++)
+ long packedCount = in.readUnsignedVInt();
+ int regularCount = (int) (packedCount & 0xFFFFF);
+ int syntheticCount = (int) (packedCount >> 20);
+
+ // First pass - synthetic columns
+ for (int i = 0; i < syntheticCount; i++)
+ {
+ ByteBuffer name = ByteBufferUtil.readWithVIntLength(in);
+ AbstractType<?> type = typeSerializer.deserialize(in);
+
+ if (!name.equals(ColumnMetadata.SYNTHETIC_SCORE_ID.bytes))
+ throw new IllegalStateException("Unknown synthetic column " + UTF8Type.instance.getString(name));
+
+ ColumnMetadata column = ColumnMetadata.syntheticColumn(metadata.keyspace, metadata.name, ColumnMetadata.SYNTHETIC_SCORE_ID, type);
+ builder.add(column);
+ }
+
+ // Second pass - regular columns
+ for (int i = 0; i < regularCount; i++)
{
ByteBuffer name = ByteBufferUtil.readWithVIntLength(in);
ColumnMetadata column = metadata.getColumn(name);
if (column == null)
{
- // If we don't find the definition, it could be we have data for a dropped column, and we shouldn't
- // fail deserialization because of that. So we grab a "fake" ColumnMetadata that ensure proper
- // deserialization. The column will be ignore later on anyway.
+ // If we don't find the definition, it could be we have data for a dropped column
column = metadata.getDroppedColumn(name);
-
if (column == null)
throw new RuntimeException("Unknown column " + UTF8Type.instance.getString(name) + " during deserialization of " + metadata.keyspace + '.' + metadata.name);
}
diff --git a/src/java/org/apache/cassandra/db/ReadCommand.java b/src/java/org/apache/cassandra/db/ReadCommand.java
index b110a3ada2de..7fa00f0436e5 100644
--- a/src/java/org/apache/cassandra/db/ReadCommand.java
+++ b/src/java/org/apache/cassandra/db/ReadCommand.java
@@ -74,6 +74,7 @@
import org.apache.cassandra.net.Message;
import org.apache.cassandra.net.MessageFlag;
import org.apache.cassandra.net.Verb;
+import org.apache.cassandra.schema.ColumnMetadata;
import org.apache.cassandra.schema.IndexMetadata;
import org.apache.cassandra.schema.Schema;
import org.apache.cassandra.schema.SchemaConstants;
@@ -384,8 +385,6 @@ static Index.QueryPlan findIndexQueryPlan(TableMetadata table, RowFilter rowFilt
@Override
public void maybeValidateIndexes()
{
- IndexRegistry.obtain(metadata()).validate(rowFilter());
-
if (null != indexQueryPlan)
indexQueryPlan.validate(this);
}
@@ -415,9 +414,9 @@ public UnfilteredPartitionIterator executeLocally(ReadExecutionController execut
}
Context context = Context.from(this);
- UnfilteredPartitionIterator iterator = (null == searcher) ? Transformation.apply(queryStorage(cfs, executionController), new TrackingRowIterator(context))
- : Transformation.apply(searchStorage(searcher, executionController), new TrackingRowIterator(context));
-
+ var storageTarget = (null == searcher) ? queryStorage(cfs, executionController)
+ : searchStorage(searcher, executionController);
+ UnfilteredPartitionIterator iterator = Transformation.apply(storageTarget, new TrackingRowIterator(context));
iterator = RTBoundValidator.validate(iterator, Stage.MERGED, false);
try
@@ -1054,6 +1053,19 @@ public ReadCommand deserialize(DataInputPlus in, int version) throws IOException
TableMetadata metadata = schema.getExistingTableMetadata(TableId.deserialize(in));
int nowInSec = in.readInt();
ColumnFilter columnFilter = ColumnFilter.serializer.deserialize(in, version, metadata);
+
+ // add synthetic columns to the TableMetadata so we can serialize them in our response
+ var tmb = metadata.unbuild();
+ for (var it = columnFilter.fetchedColumns().regulars.simpleColumns(); it.hasNext(); )
+ {
+ var c = it.next();
+ // synthetic columns sort first, so when we hit the first non-synthetic, we're done
+ if (!c.isSynthetic())
+ break;
+ tmb.addColumn(ColumnMetadata.syntheticColumn(c.ksName, c.cfName, c.name, c.type));
+ }
+ metadata = tmb.build();
+
RowFilter rowFilter = RowFilter.serializer.deserialize(in, version, metadata);
DataLimits limits = DataLimits.serializer.deserialize(in, version, metadata.comparator);
Index.QueryPlan indexQueryPlan = null;
diff --git a/src/java/org/apache/cassandra/db/RegularAndStaticColumns.java b/src/java/org/apache/cassandra/db/RegularAndStaticColumns.java
index b6da183d013f..55533eda0e97 100644
--- a/src/java/org/apache/cassandra/db/RegularAndStaticColumns.java
+++ b/src/java/org/apache/cassandra/db/RegularAndStaticColumns.java
@@ -163,7 +163,7 @@ public Builder add(ColumnMetadata c)
}
else
{
- assert c.isRegular();
+ assert c.isRegular() || c.isSynthetic();
if (regularColumns == null)
regularColumns = BTree.builder(naturalOrder());
regularColumns.add(c);
@@ -197,7 +197,7 @@ public Builder addAll(RegularAndStaticColumns columns)
public RegularAndStaticColumns build()
{
- return new RegularAndStaticColumns(staticColumns == null ? Columns.NONE : Columns.from(staticColumns),
+ return new RegularAndStaticColumns(staticColumns == null ? Columns.NONE : Columns.from(staticColumns),
regularColumns == null ? Columns.NONE : Columns.from(regularColumns));
}
}
diff --git a/src/java/org/apache/cassandra/db/filter/ColumnFilter.java b/src/java/org/apache/cassandra/db/filter/ColumnFilter.java
index d9a1b9d4e51a..644e6d661a61 100644
--- a/src/java/org/apache/cassandra/db/filter/ColumnFilter.java
+++ b/src/java/org/apache/cassandra/db/filter/ColumnFilter.java
@@ -75,6 +75,9 @@ public abstract class ColumnFilter
public static final Serializer serializer = new Serializer();
+ // TODO remove this with ANN_USE_SYNTHETIC_SCORE
+ public abstract boolean fetchesExplicitly(ColumnMetadata column);
+
/**
* The fetching strategy for the different queries.
*/
@@ -103,7 +106,8 @@ boolean fetchesAllColumns(boolean isStatic)
@Override
RegularAndStaticColumns getFetchedColumns(TableMetadata metadata, RegularAndStaticColumns queried)
{
- return metadata.regularAndStaticColumns();
+ var merged = queried.regulars.mergeTo(metadata.regularColumns());
+ return new RegularAndStaticColumns(metadata.staticColumns(), merged);
}
},
@@ -124,7 +128,8 @@ boolean fetchesAllColumns(boolean isStatic)
@Override
RegularAndStaticColumns getFetchedColumns(TableMetadata metadata, RegularAndStaticColumns queried)
{
- return new RegularAndStaticColumns(queried.statics, metadata.regularColumns());
+ var merged = queried.regulars.mergeTo(metadata.regularColumns());
+ return new RegularAndStaticColumns(queried.statics, merged);
}
},
@@ -295,14 +300,16 @@ public static ColumnFilter selection(TableMetadata metadata,
}
/**
- * The columns that needs to be fetched internally for this filter.
+ * The columns that need to be fetched internally. See FetchingStrategy for why this is
+ * always a superset of the queried columns.
*
* @return the columns to fetch for this filter.
*/
public abstract RegularAndStaticColumns fetchedColumns();
/**
- * The columns actually queried by the user.
+ * The columns needed to process the query, including selected columns, ordering columns,
+ * restriction (predicate) columns, and synthetic columns.
*
* Note that this is in general not all the columns that are fetched internally (see {@link #fetchedColumns}).
*/
@@ -619,9 +626,7 @@ private SortedSetMultimap<ColumnIdentifier, ColumnSubselection> buildSubSelectio
*/
public static class WildCardColumnFilter extends ColumnFilter
{
- /**
- * The queried and fetched columns.
- */
+ // for wildcards, there is no distinction between fetched and queried because queried is already "everything"
private final RegularAndStaticColumns fetchedAndQueried;
/**
@@ -667,6 +672,12 @@ public boolean fetches(ColumnMetadata column)
return true;
}
+ @Override
+ public boolean fetchesExplicitly(ColumnMetadata column)
+ {
+ return false;
+ }
+
@Override
public boolean fetchedColumnIsQueried(ColumnMetadata column)
{
@@ -739,14 +750,9 @@ public static class SelectionColumnFilter extends ColumnFilter
{
public final FetchingStrategy fetchingStrategy;
- /**
- * The selected columns
- */
+ // Materializes the columns required to implement queriedColumns() and fetchedColumns(),
+ // see the comments on the superclass's methods
private final RegularAndStaticColumns queried;
-
- /**
- * The columns that need to be fetched to be able
- */
private final RegularAndStaticColumns fetched;
private final SortedSetMultimap<ColumnIdentifier, ColumnSubselection> subSelections; // can be null
@@ -820,6 +826,12 @@ public boolean fetches(ColumnMetadata column)
return fetchingStrategy.fetchesAllColumns(column.isStatic()) || fetched.contains(column);
}
+ @Override
+ public boolean fetchesExplicitly(ColumnMetadata column)
+ {
+ return fetched.contains(column);
+ }
+
/**
* Whether the provided complex cell (identified by its column and path), which is assumed to be _fetched_ by
* this filter, is also _queried_ by the user.
diff --git a/src/java/org/apache/cassandra/db/filter/RowFilter.java b/src/java/org/apache/cassandra/db/filter/RowFilter.java
index 96ad240b539c..4d35566fa41f 100644
--- a/src/java/org/apache/cassandra/db/filter/RowFilter.java
+++ b/src/java/org/apache/cassandra/db/filter/RowFilter.java
@@ -1239,6 +1239,7 @@ public boolean isSatisfiedBy(TableMetadata metadata, DecoratedKey partitionKey,
case LIKE_MATCHES:
case ANALYZER_MATCHES:
case ANN:
+ case BM25:
{
assert !column.isComplex() : "Only CONTAINS and CONTAINS_KEY are supported for 'complex' types";
ByteBuffer foundValue = getValue(metadata, partitionKey, row);
diff --git a/src/java/org/apache/cassandra/db/monitoring/MonitorableImpl.java b/src/java/org/apache/cassandra/db/monitoring/MonitorableImpl.java
index a6e7947b23f1..30376eb385f3 100644
--- a/src/java/org/apache/cassandra/db/monitoring/MonitorableImpl.java
+++ b/src/java/org/apache/cassandra/db/monitoring/MonitorableImpl.java
@@ -18,6 +18,8 @@
package org.apache.cassandra.db.monitoring;
+import org.apache.cassandra.index.sai.QueryContext;
+
import static org.apache.cassandra.utils.MonotonicClock.approxTime;
public abstract class MonitorableImpl implements Monitorable
@@ -123,6 +125,9 @@ public boolean complete()
private void check()
{
+ if (QueryContext.DISABLE_TIMEOUT)
+ return;
+
if (approxCreationTimeNanos < 0 || state != MonitoringState.IN_PROGRESS)
return;
diff --git a/src/java/org/apache/cassandra/db/partitions/ParallelCommandProcessor.java b/src/java/org/apache/cassandra/db/partitions/ParallelCommandProcessor.java
deleted file mode 100644
index 19d8523bd5c9..000000000000
--- a/src/java/org/apache/cassandra/db/partitions/ParallelCommandProcessor.java
+++ /dev/null
@@ -1,65 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.cassandra.db.partitions;
-
-import java.util.List;
-
-import org.apache.cassandra.db.SinglePartitionReadCommand;
-import org.apache.cassandra.db.rows.UnfilteredRowIterator;
-import org.apache.cassandra.index.sai.utils.PrimaryKey;
-import org.apache.cassandra.utils.Pair;
-
-/**
- * An "processor" over a number of unfiltered partitions (i.e. partitions containing deletion information).
- *
- * Unlike {@link UnfilteredPartitionIterator}, this is designed to be used concurrently.
- *
- * Unlike UnfilteredPartitionIterator which requires single-threaded
- * while (partitions.hasNext())
- * {
- * var part = partitions.next();
- * ...
- * part.close();
- * }
- *
- * this one allows concurrency, like
- * var commands = partitions.getUninitializedCommands();
- * commands.parallelStream().forEach(tuple -> {
- * var iter = partitions.commandToIterator(tuple.left(), tuple.right());
- * }
- */
-public interface ParallelCommandProcessor
-{
- /**
- * Single-threaded call to get all commands and corresponding keys.
- *
- * @return the list of partition read commands.
- */
- List<Pair<PrimaryKey, SinglePartitionReadCommand>> getUninitializedCommands();
-
- /**
- * Get an iterator for a given command and key.
- * This method can be called concurrently for reulst of getUninitializedCommands().
- *
- * @param command
- * @param key
- * @return
- */
- UnfilteredRowIterator commandToIterator(PrimaryKey key, SinglePartitionReadCommand command);
-}
diff --git a/src/java/org/apache/cassandra/index/IndexRegistry.java b/src/java/org/apache/cassandra/index/IndexRegistry.java
index 639d9aec350b..94bdfe0792fc 100644
--- a/src/java/org/apache/cassandra/index/IndexRegistry.java
+++ b/src/java/org/apache/cassandra/index/IndexRegistry.java
@@ -21,6 +21,7 @@
package org.apache.cassandra.index;
import java.util.Collection;
+import java.util.HashSet;
import java.util.Collections;
import java.util.Optional;
import java.util.Set;
@@ -103,12 +104,6 @@ public Optional getBestIndexFor(RowFilter.Expression expression)
public void validate(PartitionUpdate update)
{
}
-
- @Override
- public void validate(RowFilter filter)
- {
- // no-op since it's an empty registry
- }
};
/**
@@ -296,12 +291,6 @@ public Optional getBestIndexFor(RowFilter.Expression expression)
public void validate(PartitionUpdate update)
{
}
-
- @Override
- public void validate(RowFilter filter)
- {
- // no-op since it's an empty registry
- }
};
default void registerIndex(Index index)
@@ -354,8 +343,6 @@ default Optional getAnalyzerFor(ColumnMetadata column,
*/
void validate(PartitionUpdate update);
- void validate(RowFilter filter);
-
/**
* Returns the {@code IndexRegistry} associated to the specified table.
*
@@ -369,4 +356,74 @@ public static IndexRegistry obtain(TableMetadata table)
return table.isVirtual() ? EMPTY : Keyspace.openAndGetStore(table).indexManager;
}
+
+ enum EqBehavior
+ {
+ EQ,
+ MATCH,
+ AMBIGUOUS
+ }
+
+ class EqBehaviorIndexes
+ {
+ public EqBehavior behavior;
+ public final Index eqIndex;
+ public final Index matchIndex;
+
+ private EqBehaviorIndexes(Index eqIndex, Index matchIndex, EqBehavior behavior)
+ {
+ this.eqIndex = eqIndex;
+ this.matchIndex = matchIndex;
+ this.behavior = behavior;
+ }
+
+ public static EqBehaviorIndexes eq(Index eqIndex)
+ {
+ return new EqBehaviorIndexes(eqIndex, null, EqBehavior.EQ);
+ }
+
+ public static EqBehaviorIndexes match(Index eqAndMatchIndex)
+ {
+ return new EqBehaviorIndexes(eqAndMatchIndex, eqAndMatchIndex, EqBehavior.MATCH);
+ }
+
+ public static EqBehaviorIndexes ambiguous(Index firstEqIndex, Index secondEqIndex)
+ {
+ return new EqBehaviorIndexes(firstEqIndex, secondEqIndex, EqBehavior.AMBIGUOUS);
+ }
+ }
+
+ /**
+ * @return
+ * - AMBIGUOUS if an index supports EQ and a different one supports both EQ and ANALYZER_MATCHES
+ * - MATCH if an index supports both EQ and ANALYZER_MATCHES
+ * - otherwise EQ
+ */
+ default EqBehaviorIndexes getEqBehavior(ColumnMetadata cm)
+ {
+ Index eqOnlyIndex = null;
+ Index bothIndex = null;
+
+ for (Index index : listIndexes())
+ {
+ boolean supportsEq = index.supportsExpression(cm, Operator.EQ);
+ boolean supportsMatches = index.supportsExpression(cm, Operator.ANALYZER_MATCHES);
+
+ if (supportsEq && supportsMatches)
+ bothIndex = index;
+ else if (supportsEq)
+ eqOnlyIndex = index;
+ }
+
+ // If we have one index supporting only EQ and another supporting both, return AMBIGUOUS
+ if (eqOnlyIndex != null && bothIndex != null)
+ return EqBehaviorIndexes.ambiguous(eqOnlyIndex, bothIndex);
+
+ // If we have an index supporting both EQ and MATCHES, return MATCHES
+ if (bothIndex != null)
+ return EqBehaviorIndexes.match(bothIndex);
+
+ // Otherwise return EQ
+ return EqBehaviorIndexes.eq(eqOnlyIndex == null ? bothIndex : eqOnlyIndex);
+ }
}
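
The resolution rules above are easiest to read as a truth table over the indexes that exist on a column. A minimal standalone sketch of the same logic (simplified inputs, not the actual Cassandra API):

    final class EqBehaviorSketch
    {
        enum EqBehavior { EQ, MATCH, AMBIGUOUS }

        // hasEqOnly: some index supports only EQ on the column
        // hasAnalyzed: some index supports both EQ and ANALYZER_MATCHES on the column
        static EqBehavior resolve(boolean hasEqOnly, boolean hasAnalyzed)
        {
            if (hasEqOnly && hasAnalyzed)
                return EqBehavior.AMBIGUOUS; // surfaces as EQ_AMBIGUOUS_ERROR below
            if (hasAnalyzed)
                return EqBehavior.MATCH;     // '=' executes as analyzer matching, with a warning
            return EqBehavior.EQ;            // plain equality; the eq index may be null
        }
    }
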
diff --git a/src/java/org/apache/cassandra/index/RowFilterValidator.java b/src/java/org/apache/cassandra/index/RowFilterValidator.java
deleted file mode 100644
index fb70fbfc1452..000000000000
--- a/src/java/org/apache/cassandra/index/RowFilterValidator.java
+++ /dev/null
@@ -1,103 +0,0 @@
-/*
- * Copyright DataStax, Inc.
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.cassandra.index;
-
-import java.util.HashSet;
-import java.util.Set;
-import java.util.StringJoiner;
-
-import org.apache.cassandra.cql3.Operator;
-import org.apache.cassandra.db.filter.RowFilter;
-import org.apache.cassandra.index.sai.analyzer.AnalyzerEqOperatorSupport;
-import org.apache.cassandra.schema.ColumnMetadata;
-import org.apache.cassandra.service.ClientWarn;
-
-/**
- * Class for validating the index-related aspects of a {@link RowFilter}, without considering what index is actually used.
- *
- * It will emit a client warning when a query has EQ restrictions on columns having an analyzed index.
- */
-class RowFilterValidator
-{
- private final Iterable<Index> allIndexes;
-
- private Set<ColumnMetadata> columns;
- private Set<Index> indexes;
-
- private RowFilterValidator(Iterable<Index> allIndexes)
- {
- this.allIndexes = allIndexes;
- }
-
- private void addEqRestriction(ColumnMetadata column)
- {
- for (Index index : allIndexes)
- {
- if (index.supportsExpression(column, Operator.EQ) &&
- index.supportsExpression(column, Operator.ANALYZER_MATCHES))
- {
- if (columns == null)
- columns = new HashSet<>();
- columns.add(column);
-
- if (indexes == null)
- indexes = new HashSet<>();
- indexes.add(index);
- }
- }
- }
-
- private void validate()
- {
- if (columns == null || indexes == null)
- return;
-
- StringJoiner columnNames = new StringJoiner(", ");
- StringJoiner indexNames = new StringJoiner(", ");
- columns.forEach(column -> columnNames.add(column.name.toString()));
- indexes.forEach(index -> indexNames.add(index.getIndexMetadata().name));
-
- ClientWarn.instance.warn(String.format(AnalyzerEqOperatorSupport.EQ_RESTRICTION_ON_ANALYZED_WARNING, columnNames, indexNames));
- }
-
- /**
- * Emits a client warning if the filter contains EQ restrictions on columns having an analyzed index.
- *
- * @param filter the filter to validate
- * @param indexes the existing indexes
- */
- public static void validate(RowFilter filter, Iterable<Index> indexes)
- {
- RowFilterValidator validator = new RowFilterValidator(indexes);
- validate(filter.root(), validator);
- validator.validate();
- }
-
- private static void validate(RowFilter.FilterElement element, RowFilterValidator validator)
- {
- for (RowFilter.Expression expression : element.expressions())
- {
- if (expression.operator() == Operator.EQ)
- validator.addEqRestriction(expression.column());
- }
-
- for (RowFilter.FilterElement child : element.children())
- {
- validate(child, validator);
- }
- }
-}
diff --git a/src/java/org/apache/cassandra/index/SecondaryIndexManager.java b/src/java/org/apache/cassandra/index/SecondaryIndexManager.java
index b1930e741804..d4c75ed2a4f1 100644
--- a/src/java/org/apache/cassandra/index/SecondaryIndexManager.java
+++ b/src/java/org/apache/cassandra/index/SecondaryIndexManager.java
@@ -1277,12 +1277,6 @@ public void validate(PartitionUpdate update) throws InvalidRequestException
index.validate(update);
}
- @Override
- public void validate(RowFilter filter)
- {
- RowFilterValidator.validate(filter, indexes.values());
- }
-
/*
* IndexRegistry methods
*/
diff --git a/src/java/org/apache/cassandra/index/sai/IndexContext.java b/src/java/org/apache/cassandra/index/sai/IndexContext.java
index 7820400a6c25..1de515e16267 100644
--- a/src/java/org/apache/cassandra/index/sai/IndexContext.java
+++ b/src/java/org/apache/cassandra/index/sai/IndexContext.java
@@ -717,8 +717,8 @@ public boolean supports(Operator op)
{
if (op.isLike() || op == Operator.LIKE) return false;
// Analyzed columns store the indexed result, so we are unable to compute raw equality.
- // The only supported operator is ANALYZER_MATCHES.
- if (op == Operator.ANALYZER_MATCHES) return isAnalyzed;
+ // The only supported operators are ANALYZER_MATCHES and BM25.
+ if (op == Operator.ANALYZER_MATCHES || op == Operator.BM25) return isAnalyzed;
// If the column is analyzed and the operator is EQ, we need to check if the analyzer supports it.
if (op == Operator.EQ && isAnalyzed && !analyzerFactory.supportsEquals())
@@ -742,7 +742,6 @@ public boolean supports(Operator op)
|| column.type instanceof IntegerType); // Currently truncates to 20 bytes
Expression.Op operator = Expression.Op.valueOf(op);
-
if (isNonFrozenCollection())
{
if (indexType == IndexTarget.Type.KEYS)
@@ -754,17 +753,12 @@ public boolean supports(Operator op)
return indexType == IndexTarget.Type.KEYS_AND_VALUES &&
(operator == Expression.Op.EQ || operator == Expression.Op.NOT_EQ || operator == Expression.Op.RANGE);
}
-
if (indexType == IndexTarget.Type.FULL)
return operator == Expression.Op.EQ;
-
AbstractType<?> validator = getValidator();
-
if (operator == Expression.Op.IN)
return true;
-
if (operator != Expression.Op.EQ && EQ_ONLY_TYPES.contains(validator)) return false;
-
// RANGE only applicable to non-literal indexes
return (operator != null) && !(TypeUtil.isLiteral(validator) && operator == Expression.Op.RANGE);
}
diff --git a/src/java/org/apache/cassandra/index/sai/QueryContext.java b/src/java/org/apache/cassandra/index/sai/QueryContext.java
index a0b1719049ec..5f1f7cda8498 100644
--- a/src/java/org/apache/cassandra/index/sai/QueryContext.java
+++ b/src/java/org/apache/cassandra/index/sai/QueryContext.java
@@ -36,7 +36,7 @@
@NotThreadSafe
public class QueryContext
{
- private static final boolean DISABLE_TIMEOUT = Boolean.getBoolean("cassandra.sai.test.disable.timeout");
+ public static final boolean DISABLE_TIMEOUT = Boolean.getBoolean("cassandra.sai.test.disable.timeout");
protected final long queryStartTimeNanos;
diff --git a/src/java/org/apache/cassandra/index/sai/StorageAttachedIndex.java b/src/java/org/apache/cassandra/index/sai/StorageAttachedIndex.java
index 28cb42ab8dc8..12557cd66a2d 100644
--- a/src/java/org/apache/cassandra/index/sai/StorageAttachedIndex.java
+++ b/src/java/org/apache/cassandra/index/sai/StorageAttachedIndex.java
@@ -299,20 +299,26 @@ public static Map validateOptions(Map options, T
throw new InvalidRequestException("Failed to retrieve target column for: " + targetColumn);
}
- // In order to support different index target on non-frozen map, ie. KEYS, VALUE, ENTRIES, we need to put index
- // name as part of index file name instead of column name. We only need to check that the target is different
- // between indexes. This will only allow indexes in the same column with a different IndexTarget.Type.
- //
- // Note that: "metadata.indexes" already includes current index
- if (metadata.indexes.stream().filter(index -> index.getIndexClassName().equals(StorageAttachedIndex.class.getName()))
- .map(index -> TargetParser.parse(metadata, index.options.get(IndexTarget.TARGET_OPTION_NAME)))
- .filter(Objects::nonNull).filter(t -> t.equals(target)).count() > 1)
- {
- throw new InvalidRequestException("Cannot create more than one storage-attached index on the same column: " + target.left);
- }
+ // Check for duplicate indexes considering both target and analyzer configuration
+ boolean isAnalyzed = AbstractAnalyzer.isAnalyzed(options);
+ long duplicateCount = metadata.indexes.stream()
+ .filter(index -> index.getIndexClassName().equals(StorageAttachedIndex.class.getName()))
+ .filter(index -> {
+ // Indexes on the same column with different target (KEYS, VALUES, ENTRIES)
+ // are allowed on non-frozen Maps
+ var existingTarget = TargetParser.parse(metadata, index.options.get(IndexTarget.TARGET_OPTION_NAME));
+ if (existingTarget == null || !existingTarget.equals(target))
+ return false;
+ // Also allow different indexes if one is analyzed and the other isn't
+ return isAnalyzed == AbstractAnalyzer.isAnalyzed(index.options);
+ })
+ .count();
+ // >1 because "metadata.indexes" already includes current index
+ if (duplicateCount > 1)
+ throw new InvalidRequestException(String.format("Cannot create duplicate storage-attached index on column: %s", target.left));
// Analyzer is not supported against PK columns
- if (AbstractAnalyzer.isAnalyzed(options))
+ if (isAnalyzed)
{
for (ColumnMetadata column : metadata.primaryKeyColumns())
{
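
The duplicate check above boils down to a single predicate: two storage-attached indexes collide only when they agree on both the parsed target and the analyzed/non-analyzed mode. A sketch of that rule (hypothetical helper, not part of the patch):

    // KEYS vs VALUES on the same map column: allowed (different target).
    // Analyzed + non-analyzed on the same text column: now allowed (different mode).
    // Same target and same mode: rejected as a duplicate.
    static boolean isDuplicate(String targetA, boolean analyzedA,
                               String targetB, boolean analyzedB)
    {
        return targetA.equals(targetB) && analyzedA == analyzedB;
    }
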
diff --git a/src/java/org/apache/cassandra/index/sai/analyzer/AnalyzerEqOperatorSupport.java b/src/java/org/apache/cassandra/index/sai/analyzer/AnalyzerEqOperatorSupport.java
index 116bc7f62832..30408c9b986f 100644
--- a/src/java/org/apache/cassandra/index/sai/analyzer/AnalyzerEqOperatorSupport.java
+++ b/src/java/org/apache/cassandra/index/sai/analyzer/AnalyzerEqOperatorSupport.java
@@ -49,7 +49,7 @@ public class AnalyzerEqOperatorSupport
OPTION, Arrays.toString(Value.values()));
public static final String EQ_RESTRICTION_ON_ANALYZED_WARNING =
- String.format("Columns [%%s] are restricted by '=' and have analyzed indexes [%%s] able to process those restrictions. " +
+ String.format("Column [%%s] is restricted by '=' and has an analyzed index [%%s] able to process those restrictions. " +
"Analyzed indexes might process '=' restrictions in a way that is inconsistent with non-indexed queries. " +
"While '=' is still supported on analyzed indexes for backwards compatibility, " +
"it is recommended to use the ':' operator instead to prevent the ambiguity. " +
@@ -58,6 +58,13 @@ public class AnalyzerEqOperatorSupport
"please use '%s':'%s' in the index options.",
OPTION, Value.UNSUPPORTED.toString().toLowerCase());
+ public static final String EQ_AMBIGUOUS_ERROR =
+ String.format("Column [%%s] equality predicate is ambiguous. It has both an analyzed index [%%s] configured with '%s':'%s', " +
+ "and an un-analyzed index [%%s]. " +
+ "To avoid ambiguity, drop the analyzed index and recreate it with option '%s':'%s'.",
+ OPTION, Value.MATCH.toString().toLowerCase(), OPTION, Value.UNSUPPORTED.toString().toLowerCase());
+
+
public static final String LWT_CONDITION_ON_ANALYZED_WARNING =
"Index analyzers not applied to LWT conditions on columns [%s].";
diff --git a/src/java/org/apache/cassandra/index/sai/disk/MemtableTermsIterator.java b/src/java/org/apache/cassandra/index/sai/disk/MemtableTermsIterator.java
index 9f0253e61ed8..615f3cfb7a4e 100644
--- a/src/java/org/apache/cassandra/index/sai/disk/MemtableTermsIterator.java
+++ b/src/java/org/apache/cassandra/index/sai/disk/MemtableTermsIterator.java
@@ -19,11 +19,11 @@
import java.nio.ByteBuffer;
import java.util.Iterator;
+import java.util.List;
import com.google.common.base.Preconditions;
-import com.carrotsearch.hppc.IntArrayList;
-import com.carrotsearch.hppc.cursors.IntCursor;
+import org.apache.cassandra.index.sai.memory.RowMapping;
import org.apache.cassandra.utils.Pair;
import org.apache.cassandra.utils.bytecomparable.ByteComparable;
@@ -34,16 +34,16 @@ public class MemtableTermsIterator implements TermsIterator
{
private final ByteBuffer minTerm;
private final ByteBuffer maxTerm;
- private final Iterator<Pair<ByteComparable, IntArrayList>> iterator;
+ private final Iterator<Pair<ByteComparable, List<RowMapping.RowIdWithFrequency>>> iterator;
- private Pair<ByteComparable, IntArrayList> current;
+ private Pair<ByteComparable, List<RowMapping.RowIdWithFrequency>> current;
private int maxSSTableRowId = -1;
private int minSSTableRowId = Integer.MAX_VALUE;
public MemtableTermsIterator(ByteBuffer minTerm,
ByteBuffer maxTerm,
- Iterator<Pair<ByteComparable, IntArrayList>> iterator)
+ Iterator<Pair<ByteComparable, List<RowMapping.RowIdWithFrequency>>> iterator)
{
Preconditions.checkArgument(iterator != null);
this.minTerm = minTerm;
@@ -69,22 +69,24 @@ public void close() {}
@Override
public PostingList postings()
{
- final IntArrayList list = current.right;
+ var list = current.right;
assert list.size() > 0;
- final int minSegmentRowID = list.get(0);
- final int maxSegmentRowID = list.get(list.size() - 1);
+ final int minSegmentRowID = list.get(0).rowId;
+ final int maxSegmentRowID = list.get(list.size() - 1).rowId;
// Because we are working with postings from the memtable, there is only one segment, so segment row ids
// and sstable row ids are the same.
minSSTableRowId = Math.min(minSSTableRowId, minSegmentRowID);
maxSSTableRowId = Math.max(maxSSTableRowId, maxSegmentRowID);
- final Iterator<IntCursor> it = list.iterator();
+ var it = list.iterator();
return new PostingList()
{
+ int frequency;
+
@Override
public int nextPosting()
{
@@ -93,7 +95,9 @@ public int nextPosting()
return END_OF_STREAM;
}
- return it.next().value;
+ var rowIdWithFrequency = it.next();
+ frequency = rowIdWithFrequency.frequency;
+ return rowIdWithFrequency.rowId;
}
@Override
@@ -102,6 +106,12 @@ public int size()
return list.size();
}
+ @Override
+ public int frequency()
+ {
+ return frequency;
+ }
+
@Override
public int advance(int targetRowID)
{
diff --git a/src/java/org/apache/cassandra/index/sai/disk/PostingList.java b/src/java/org/apache/cassandra/index/sai/disk/PostingList.java
index b7dffb972599..4959c6f0be6a 100644
--- a/src/java/org/apache/cassandra/index/sai/disk/PostingList.java
+++ b/src/java/org/apache/cassandra/index/sai/disk/PostingList.java
@@ -42,6 +42,14 @@ default void close() throws IOException {}
*/
int nextPosting() throws IOException;
+ /**
+ * @return the number of occurrences of the term in the current row (the one most recently returned by nextPosting).
+ */
+ default int frequency()
+ {
+ return 1;
+ }
+
int size();
/**
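
The contract here is that frequency() reports the occurrence count for the row most recently returned by nextPosting(), defaulting to 1 for posting lists written before term frequencies existed. A small consumer sketch of that contract (hypothetical usage, for illustration):

    // Sums term frequencies across a posting list.
    static long totalTermFrequency(PostingList postings) throws java.io.IOException
    {
        long total = 0;
        for (int rowId = postings.nextPosting(); rowId != PostingList.END_OF_STREAM; rowId = postings.nextPosting())
            total += postings.frequency(); // 1 unless the list carries frequencies
        return total;
    }
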
diff --git a/src/java/org/apache/cassandra/index/sai/disk/PrimaryKeyWithSource.java b/src/java/org/apache/cassandra/index/sai/disk/PrimaryKeyWithSource.java
index 7f9f7e89f51e..d303bd4d27c7 100644
--- a/src/java/org/apache/cassandra/index/sai/disk/PrimaryKeyWithSource.java
+++ b/src/java/org/apache/cassandra/index/sai/disk/PrimaryKeyWithSource.java
@@ -18,6 +18,7 @@
package org.apache.cassandra.index.sai.disk;
+import io.github.jbellis.jvector.util.RamUsageEstimator;
import org.apache.cassandra.db.Clustering;
import org.apache.cassandra.db.DecoratedKey;
import org.apache.cassandra.dht.Token;
@@ -127,4 +128,14 @@ public String toString()
{
return String.format("%s (source sstable: %s, %s)", primaryKey, sourceSstableId, sourceRowId);
}
+
+ @Override
+ public long ramBytesUsed()
+ {
+ // Object header + 2 references (primaryKey, sourceSstableId) + long sourceRowId
+ return RamUsageEstimator.NUM_BYTES_OBJECT_HEADER +
+ 2L * RamUsageEstimator.NUM_BYTES_OBJECT_REF +
+ Long.BYTES +
+ primaryKey.ramBytesUsed();
+ }
}
diff --git a/src/java/org/apache/cassandra/index/sai/disk/RAMPostingSlices.java b/src/java/org/apache/cassandra/index/sai/disk/RAMPostingSlices.java
index 5879e5f1abdf..1313f81569ae 100644
--- a/src/java/org/apache/cassandra/index/sai/disk/RAMPostingSlices.java
+++ b/src/java/org/apache/cassandra/index/sai/disk/RAMPostingSlices.java
@@ -25,22 +25,33 @@
import org.apache.lucene.util.mutable.MutableValueInt;
/**
- * Encodes postings as variable integers into slices
+ * Encodes postings as variable integers into "slices" of byte blocks for efficient memory usage.
*/
class RAMPostingSlices
{
static final int DEFAULT_TERM_DICT_SIZE = 1024;
+ /** Pool of byte blocks storing the actual posting data */
private final ByteBlockPool postingsPool;
+ /** true if we're also writing term frequencies for an analyzed index */
+ private final boolean includeFrequencies;
+
+ /** The starting positions of postings for each term. Term id = index in array. */
private int[] postingStarts = new int[DEFAULT_TERM_DICT_SIZE];
+ /** The current write positions for each term's postings. Term id = index in array. */
private int[] postingUptos = new int[DEFAULT_TERM_DICT_SIZE];
+ /** The number of postings for each term. Term id = index in array. */
private int[] sizes = new int[DEFAULT_TERM_DICT_SIZE];
- RAMPostingSlices(Counter memoryUsage)
+ RAMPostingSlices(Counter memoryUsage, boolean includeFrequencies)
{
postingsPool = new ByteBlockPool(new ByteBlockPool.DirectTrackingAllocator(memoryUsage));
+ this.includeFrequencies = includeFrequencies;
}
+ /**
+ * Creates and returns a PostingList for the given term ID.
+ */
PostingList postingList(int termID, final ByteSliceReader reader, long maxSegmentRowID)
{
initReader(reader, termID);
@@ -49,20 +60,35 @@ PostingList postingList(int termID, final ByteSliceReader reader, long maxSegmen
return new PostingList()
{
+ int frequency = Integer.MIN_VALUE;
+
@Override
public int nextPosting() throws IOException
{
if (reader.eof())
{
+ frequency = Integer.MIN_VALUE;
return PostingList.END_OF_STREAM;
}
else
{
lastSegmentRowId.value += reader.readVInt();
+ if (includeFrequencies)
+ frequency = reader.readVInt();
return lastSegmentRowId.value;
}
}
+ @Override
+ public int frequency()
+ {
+ if (!includeFrequencies)
+ return 1;
+ if (frequency <= 0)
+ throw new IllegalStateException("frequency() called before nextPosting()");
+ return frequency;
+ }
+
@Override
public int size()
{
@@ -77,12 +103,20 @@ public int advance(int targetRowID)
};
}
+ /**
+ * Initializes a ByteSliceReader for reading postings for a specific term.
+ */
void initReader(ByteSliceReader reader, int termID)
{
final int upto = postingUptos[termID];
reader.init(postingsPool, postingStarts[termID], upto);
}
+ /**
+ * Creates a new slice for storing postings for a given term ID.
+ * Grows the internal arrays if necessary and allocates a new block
+ * if the current block cannot accommodate a new slice.
+ */
void createNewSlice(int termID)
{
if (termID >= postingStarts.length - 1)
@@ -103,7 +137,27 @@ void createNewSlice(int termID)
postingUptos[termID] = upto + postingsPool.byteOffset;
}
- void writeVInt(int termID, int i)
+ void writePosting(int termID, int deltaRowId, int frequency)
+ {
+ assert termID >= 0 : termID;
+ assert deltaRowId >= 0 : deltaRowId;
+ writeVInt(termID, deltaRowId);
+
+ if (includeFrequencies)
+ {
+ assert frequency > 0 : frequency;
+ writeVInt(termID, frequency);
+ }
+
+ sizes[termID]++;
+ }
+
+ /**
+ * Writes a variable-length integer to the posting list for a given term.
+ * The integer is encoded using a variable-length encoding scheme where each
+ * byte uses 7 bits for the value and 1 bit to indicate if more bytes follow.
+ */
+ private void writeVInt(int termID, int i)
{
while ((i & ~0x7F) != 0)
{
@@ -111,9 +165,12 @@ void writeVInt(int termID, int i)
i >>>= 7;
}
writeByte(termID, (byte) i);
- sizes[termID]++;
}
+ /**
+ * Writes a single byte to the posting list for a given term.
+ * If the current slice is full, it automatically allocates a new slice.
+ */
private void writeByte(int termID, byte b)
{
int upto = postingUptos[termID];
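
A round-trip sketch of the 7-bit varint scheme described in the writeVInt javadoc (standalone, for illustration only; the real code writes into ByteBlockPool slices):

    import java.io.ByteArrayInputStream;
    import java.io.ByteArrayOutputStream;

    final class VIntSketch
    {
        static void writeVInt(ByteArrayOutputStream out, int i)
        {
            while ((i & ~0x7F) != 0)
            {
                out.write((i & 0x7F) | 0x80); // low 7 bits, continuation bit set
                i >>>= 7;
            }
            out.write(i); // final byte, continuation bit clear
        }

        static int readVInt(ByteArrayInputStream in)
        {
            int value = 0, shift = 0, b;
            while (((b = in.read()) & 0x80) != 0)
            {
                value |= (b & 0x7F) << shift;
                shift += 7;
            }
            return value | (b << shift);
        }
    }

    // e.g. 300 encodes as bytes 0xAC 0x02 and decodes back to 300
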
diff --git a/src/java/org/apache/cassandra/index/sai/disk/RAMStringIndexer.java b/src/java/org/apache/cassandra/index/sai/disk/RAMStringIndexer.java
index 8b6f07d43597..40c7d8fa37a8 100644
--- a/src/java/org/apache/cassandra/index/sai/disk/RAMStringIndexer.java
+++ b/src/java/org/apache/cassandra/index/sai/disk/RAMStringIndexer.java
@@ -18,10 +18,12 @@
package org.apache.cassandra.index.sai.disk;
import java.nio.ByteBuffer;
+import java.util.List;
import java.util.NoSuchElementException;
import com.google.common.annotations.VisibleForTesting;
+import org.agrona.collections.Int2IntHashMap;
import org.apache.cassandra.utils.bytecomparable.ByteComparable;
import org.apache.lucene.util.ArrayUtil;
import org.apache.lucene.util.ByteBlockPool;
@@ -43,11 +45,14 @@ public class RAMStringIndexer
private final Counter termsBytesUsed;
private final Counter slicesBytesUsed;
- private int rowCount = 0;
private int[] lastSegmentRowID = new int[RAMPostingSlices.DEFAULT_TERM_DICT_SIZE];
- public RAMStringIndexer()
+ private final boolean writeFrequencies;
+ private final Int2IntHashMap docLengths = new Int2IntHashMap(Integer.MIN_VALUE);
+
+ public RAMStringIndexer(boolean writeFrequencies)
{
+ this.writeFrequencies = writeFrequencies;
termsBytesUsed = Counter.newCounter();
slicesBytesUsed = Counter.newCounter();
@@ -55,7 +60,7 @@ public RAMStringIndexer()
termsHash = new BytesRefHash(termsPool);
- slices = new RAMPostingSlices(slicesBytesUsed);
+ slices = new RAMPostingSlices(slicesBytesUsed, writeFrequencies);
}
public long estimatedBytesUsed()
@@ -75,7 +80,12 @@ public boolean requiresFlush()
public boolean isEmpty()
{
- return rowCount == 0;
+ return docLengths.isEmpty();
+ }
+
+ public Int2IntHashMap getDocLengths()
+ {
+ return docLengths;
}
/**
@@ -140,36 +150,58 @@ private ByteComparable asByteComparable(byte[] bytes, int offset, int length)
};
}
- public long add(BytesRef term, int segmentRowId)
+ /**
+ * @return bytes allocated; may be zero if the (term, row) pairs are all duplicates
+ */
+ public long addAll(List<BytesRef> terms, int segmentRowId)
{
long startBytes = estimatedBytesUsed();
- int termID = termsHash.add(term);
-
- if (termID >= 0)
- {
- // firs time seeing this term, create the term's first slice !
- slices.createNewSlice(termID);
- }
- else
- {
- termID = (-termID) - 1;
- }
+ Int2IntHashMap frequencies = new Int2IntHashMap(Integer.MIN_VALUE);
+ Int2IntHashMap deltas = new Int2IntHashMap(Integer.MIN_VALUE);
- if (termID >= lastSegmentRowID.length - 1)
+ for (BytesRef term : terms)
{
- lastSegmentRowID = ArrayUtil.grow(lastSegmentRowID, termID + 1);
- }
+ int termID = termsHash.add(term);
+ boolean firstOccurrence = termID >= 0;
- int delta = segmentRowId - lastSegmentRowID[termID];
-
- lastSegmentRowID[termID] = segmentRowId;
+ if (firstOccurrence)
+ {
+ // first time seeing this term in any row, create the term's first slice!
+ slices.createNewSlice(termID);
+ // grow the termID -> last segment array if necessary
+ if (termID >= lastSegmentRowID.length - 1)
+ lastSegmentRowID = ArrayUtil.grow(lastSegmentRowID, termID + 1);
+ if (writeFrequencies)
+ frequencies.put(termID, 1);
+ }
+ else
+ {
+ termID = (-termID) - 1;
+ // compaction should call this method only with increasing segmentRowIds
+ assert segmentRowId >= lastSegmentRowID[termID];
+ // increment frequency
+ if (writeFrequencies)
+ frequencies.put(termID, frequencies.getOrDefault(termID, 0) + 1);
+ // Skip computing a delta if we've already seen this term in this row
+ if (segmentRowId == lastSegmentRowID[termID])
+ continue;
+ }
- slices.writeVInt(termID, delta);
+ // Compute the delta from the last time this term was seen, to this row
+ int delta = segmentRowId - lastSegmentRowID[termID];
+ // sanity check that we're advancing the row id, i.e. no duplicate entries.
+ assert firstOccurrence || delta > 0;
+ deltas.put(termID, delta);
+ lastSegmentRowID[termID] = segmentRowId;
+ }
- long allocatedBytes = estimatedBytesUsed() - startBytes;
+ // add the postings now that we know the frequencies
+ deltas.forEachInt((termID, delta) -> {
+ slices.writePosting(termID, delta, frequencies.get(termID));
+ });
- rowCount++;
+ docLengths.put(segmentRowId, terms.size());
- return allocatedBytes;
+ return estimatedBytesUsed() - startBytes;
}
}
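
What addAll builds per term is a delta-encoded posting list with a per-row term frequency, plus a separate rowId -> docLength map. A self-contained model of that layout (illustrative names only; the real class streams into RAMPostingSlices):

    import java.util.ArrayList;
    import java.util.List;

    final class PostingLayoutSketch
    {
        // Input: all row ids where one term occurs, sorted, duplicated for repeats
        // within a row. Output entries are {deltaFromPreviousRow, frequencyInRow}.
        static List<int[]> encodePostings(int[] sortedRowIds)
        {
            List<int[]> postings = new ArrayList<>();
            int last = 0;
            for (int i = 0; i < sortedRowIds.length; )
            {
                int row = sortedRowIds[i];
                int freq = 0;
                while (i < sortedRowIds.length && sortedRowIds[i] == row) { freq++; i++; }
                postings.add(new int[]{ row - last, freq });
                last = row;
            }
            return postings;
        }
    }

    // encodePostings(new int[]{ 3, 3, 10 }) -> [{3, 2}, {7, 1}]: row 3 holds the term
    // twice (merged into frequency 2), and row 10 is stored as delta 7 from row 3.
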
diff --git a/src/java/org/apache/cassandra/index/sai/disk/format/IndexComponentType.java b/src/java/org/apache/cassandra/index/sai/disk/format/IndexComponentType.java
index 165f484eae66..bf38d6b8fbd8 100644
--- a/src/java/org/apache/cassandra/index/sai/disk/format/IndexComponentType.java
+++ b/src/java/org/apache/cassandra/index/sai/disk/format/IndexComponentType.java
@@ -111,7 +111,12 @@ public enum IndexComponentType
*
* V1 V2
*/
- GROUP_COMPLETION_MARKER("GroupComplete");
+ GROUP_COMPLETION_MARKER("GroupComplete"),
+
+ /**
+ * Stores document length information for BM25 scoring
+ */
+ DOC_LENGTHS("DocLengths");
public final String representation;
diff --git a/src/java/org/apache/cassandra/index/sai/disk/format/Version.java b/src/java/org/apache/cassandra/index/sai/disk/format/Version.java
index 6411072486b8..77757eec4851 100644
--- a/src/java/org/apache/cassandra/index/sai/disk/format/Version.java
+++ b/src/java/org/apache/cassandra/index/sai/disk/format/Version.java
@@ -34,6 +34,7 @@
import org.apache.cassandra.index.sai.disk.v4.V4OnDiskFormat;
import org.apache.cassandra.index.sai.disk.v5.V5OnDiskFormat;
import org.apache.cassandra.index.sai.disk.v6.V6OnDiskFormat;
+import org.apache.cassandra.index.sai.disk.v7.V7OnDiskFormat;
import org.apache.cassandra.index.sai.utils.TypeUtil;
import org.apache.cassandra.utils.bytecomparable.ByteComparable;
@@ -43,7 +44,7 @@
* Format version of indexing component, denoted as [major][minor]. Same forward-compatibility rules apply as to
* {@link org.apache.cassandra.io.sstable.format.Version}.
*/
-public class Version
+public class Version implements Comparable<Version>
{
// 6.8 formats
public static final Version AA = new Version("aa", V1OnDiskFormat.instance, Version::aaFileNameFormat);
@@ -53,15 +54,17 @@ public class Version
public static final Version CA = new Version("ca", V3OnDiskFormat.instance, (c, i, g) -> stargazerFileNameFormat(c, i, g, "ca"));
// NOTE: use DB to prevent collisions with upstream file formats
// Encode trie entries using their AbstractType to ensure trie entries are sorted for range queries and are prefix free.
- public static final Version DB = new Version("db", V4OnDiskFormat.instance, (c, i, g) -> stargazerFileNameFormat(c, i, g,"db"));
+ public static final Version DB = new Version("db", V4OnDiskFormat.instance, (c, i, g) -> stargazerFileNameFormat(c, i, g, "db"));
// revamps vector postings lists to cause fewer reads from disk
public static final Version DC = new Version("dc", V5OnDiskFormat.instance, (c, i, g) -> stargazerFileNameFormat(c, i, g, "dc"));
// histograms in index metadata
public static final Version EB = new Version("eb", V6OnDiskFormat.instance, (c, i, g) -> stargazerFileNameFormat(c, i, g, "eb"));
+ // term frequencies index component
+ public static final Version EC = new Version("ec", V7OnDiskFormat.instance, (c, i, g) -> stargazerFileNameFormat(c, i, g, "ec"));
// These are in reverse-chronological order so that the latest version is first. Version matching tests
// are more likely to match the latest version so we want to test that one first.
- public static final List<Version> ALL = Lists.newArrayList(EB, DC, DB, CA, BA, AA);
+ public static final List<Version> ALL = Lists.newArrayList(EC, EB, DC, DB, CA, BA, AA);
public static final Version EARLIEST = AA;
public static final Version VECTOR_EARLIEST = BA;
@@ -87,7 +90,8 @@ public static Version parse(String input)
{
checkArgument(input != null);
checkArgument(input.length() == 2);
- for (var v : ALL) {
+ for (var v : ALL)
+ {
if (input.equals(v.version))
return v;
}
@@ -110,7 +114,7 @@ public boolean equals(Object o)
{
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
- Version other = (Version)o;
+ Version other = (Version) o;
return Objects.equal(version, other.version);
}
@@ -143,6 +147,11 @@ public boolean useImmutableComponentFiles()
return CassandraRelevantProperties.IMMUTABLE_SAI_COMPONENTS.getBoolean() && onOrAfter(Version.CA);
}
+ @Override
+ public int compareTo(Version other)
+ {
+ return this.version.compareTo(other.version);
+ }
public interface FileNameFormatter
{
@@ -152,7 +161,7 @@ public interface FileNameFormatter
*/
default String format(IndexComponentType indexComponentType, IndexContext indexContext, int generation)
{
- return format(indexComponentType, indexContext == null ? null : indexContext.getIndexName(), generation);
+ return format(indexComponentType, indexContext == null ? null : indexContext.getIndexName(), generation);
}
/**
@@ -160,8 +169,8 @@ default String format(IndexComponentType indexComponentType, IndexContext indexC
* filename is returned (so the suffix of the full filename), not a full path.
*
* @param indexComponentType the type of the index component.
- * @param indexName the name of the index, or {@code null} for a per-sstable component.
- * @param generation the generation of the build of the component.
+ * @param indexName the name of the index, or {@code null} for a per-sstable component.
+ * @param generation the generation of the build of the component.
*/
String format(IndexComponentType indexComponentType, @Nullable String indexName, int generation);
}
@@ -191,7 +200,6 @@ else if (componentStr.startsWith("SAI" + SAI_SEPARATOR))
return tryParseStargazerFileName(componentStr);
else
return Optional.empty();
-
}
public static class ParsedFileName
@@ -222,7 +230,7 @@ private static String aaFileNameFormat(IndexComponentType indexComponentType, @N
StringBuilder stringBuilder = new StringBuilder();
stringBuilder.append(indexName == null ? String.format(VERSION_AA_PER_SSTABLE_FORMAT, indexComponentType.representation)
- : String.format(VERSION_AA_PER_INDEX_FORMAT, indexName, indexComponentType.representation));
+ : String.format(VERSION_AA_PER_INDEX_FORMAT, indexName, indexComponentType.representation));
return stringBuilder.toString();
}
diff --git a/src/java/org/apache/cassandra/index/sai/disk/v1/DocLengthsReader.java b/src/java/org/apache/cassandra/index/sai/disk/v1/DocLengthsReader.java
new file mode 100644
index 000000000000..45769f9337d8
--- /dev/null
+++ b/src/java/org/apache/cassandra/index/sai/disk/v1/DocLengthsReader.java
@@ -0,0 +1,59 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.cassandra.index.sai.disk.v1;
+
+import java.io.Closeable;
+import java.io.IOException;
+
+import org.apache.cassandra.index.sai.disk.io.IndexInputReader;
+import org.apache.cassandra.index.sai.utils.IndexFileUtils;
+import org.apache.cassandra.index.sai.utils.SAICodecUtils;
+import org.apache.cassandra.io.util.FileHandle;
+import org.apache.cassandra.io.util.FileUtils;
+import org.apache.lucene.codecs.CodecUtil;
+
+public class DocLengthsReader implements Closeable
+{
+ private final FileHandle fileHandle;
+ private final IndexInputReader input;
+ private final SegmentMetadata.ComponentMetadata componentMetadata;
+
+ public DocLengthsReader(FileHandle fileHandle, SegmentMetadata.ComponentMetadata componentMetadata)
+ {
+ this.fileHandle = fileHandle;
+ this.input = IndexFileUtils.instance.openInput(fileHandle);
+ this.componentMetadata = componentMetadata;
+ }
+
+ public int get(int rowID) throws IOException
+ {
+ // Account for header size in offset calculation
+ long position = componentMetadata.offset + (long) rowID * Integer.BYTES;
+ if (position >= componentMetadata.offset + componentMetadata.length)
+ return 0;
+ input.seek(position);
+ return input.readInt();
+ }
+
+ @Override
+ public void close() throws IOException
+ {
+ FileUtils.close(fileHandle, input);
+ }
+}
+
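
get() implies a dense fixed-width layout: one 4-byte int per segment row id, starting at the component's offset, with out-of-range reads defaulting to 0. A hypothetical writer consistent with that reader (the patch's actual writer is not shown in this hunk):

    import java.io.DataOutputStream;
    import java.io.IOException;
    import java.util.Map;

    final class DocLengthsWriterSketch
    {
        // Rows that indexed no terms get length 0, matching the reader's default.
        static void write(DataOutputStream out, Map<Integer, Integer> docLengths, int maxRowId) throws IOException
        {
            for (int rowId = 0; rowId <= maxRowId; rowId++)
                out.writeInt(docLengths.getOrDefault(rowId, 0));
        }
    }
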
diff --git a/src/java/org/apache/cassandra/index/sai/disk/v1/IndexSearcher.java b/src/java/org/apache/cassandra/index/sai/disk/v1/IndexSearcher.java
index 52988a685cd5..16d14bdd8ac3 100644
--- a/src/java/org/apache/cassandra/index/sai/disk/v1/IndexSearcher.java
+++ b/src/java/org/apache/cassandra/index/sai/disk/v1/IndexSearcher.java
@@ -68,9 +68,7 @@ public abstract class IndexSearcher implements Closeable, SegmentOrdering
protected final SegmentMetadata metadata;
protected final IndexContext indexContext;
- private static final SSTableReadsListener NOOP_LISTENER = new SSTableReadsListener() {};
-
- private final ColumnFilter columnFilter;
+ protected final ColumnFilter columnFilter;
protected IndexSearcher(PrimaryKeyMap.Factory primaryKeyMapFactory,
PerIndexFiles perIndexFiles,
@@ -90,30 +88,36 @@ protected IndexSearcher(PrimaryKeyMap.Factory primaryKeyMapFactory,
public abstract long indexFileCacheSize();
/**
- * Search on-disk index synchronously
+ * Search on-disk index synchronously. Used for WHERE clause predicates, including BOUNDED_ANN.
*
* @param expression to filter on disk index
* @param keyRange key range specific in read command, used by ANN index
* @param queryContext to track per sstable cache and per query metrics
* @param defer create the iterator in a deferred state
- * @param limit the num of rows to returned, used by ANN index
* @return {@link KeyRangeIterator} that matches given expression
*/
- public abstract KeyRangeIterator search(Expression expression, AbstractBounds<PartitionPosition> keyRange, QueryContext queryContext, boolean defer, int limit) throws IOException;
+ public abstract KeyRangeIterator search(Expression expression, AbstractBounds<PartitionPosition> keyRange, QueryContext queryContext, boolean defer) throws IOException;
/**
- * Order the on-disk index synchronously and produce an iterator in score order
+ * Order the rows by the given Orderer. Used for ORDER BY clause when
+ * (1) the WHERE predicate is either a partition restriction or a range restriction on the index,
+ * (2) there is no WHERE predicate, or
+ * (3) the planner determines it is better to post-filter the ordered results by the predicate.
*
* @param orderer the object containing the ordering logic
* @param slice optional predicate to get a slice of the index
* @param keyRange key range specific in read command, used by ANN index
* @param queryContext to track per sstable cache and per query metrics
- * @param limit the num of rows to returned, used by ANN index
+ * @param limit the initial number of rows to return, used by ANN index. More rows may be requested if filtering throws away more than expected!
* @return an iterator of {@link PrimaryKeyWithSortKey} in score order
*/
public abstract CloseableIterator<PrimaryKeyWithSortKey> orderBy(Orderer orderer, Expression slice, AbstractBounds<PartitionPosition> keyRange, QueryContext queryContext, int limit) throws IOException;
-
+ /**
+ * Order the rows by the given Orderer. Used for ORDER BY clause when the WHERE predicates
+ * have been applied first, yielding a list of primary keys. Again, `limit` is a planner hint for ANN to determine
+ * the initial number of results returned, not a maximum.
+ */
@Override
public CloseableIterator<PrimaryKeyWithSortKey> orderResultsBy(SSTableReader reader, QueryContext context, List<PrimaryKey> keys, Orderer orderer, int limit) throws IOException
{
@@ -124,7 +128,7 @@ public CloseableIterator orderResultsBy(SSTableReader rea
{
var slices = Slices.with(indexContext.comparator(), Slice.make(key.clustering()));
// TODO if we end up needing to read the row still, is it better to store offset and use reader.unfilteredAt?
- try (var iter = reader.iterator(key.partitionKey(), slices, columnFilter, false, NOOP_LISTENER))
+ try (var iter = reader.iterator(key.partitionKey(), slices, columnFilter, false, SSTableReadsListener.NOOP_LISTENER))
{
if (iter.hasNext())
{
diff --git a/src/java/org/apache/cassandra/index/sai/disk/v1/InvertedIndexSearcher.java b/src/java/org/apache/cassandra/index/sai/disk/v1/InvertedIndexSearcher.java
index 15aaa349e1c8..08a63818c709 100644
--- a/src/java/org/apache/cassandra/index/sai/disk/v1/InvertedIndexSearcher.java
+++ b/src/java/org/apache/cassandra/index/sai/disk/v1/InvertedIndexSearcher.java
@@ -19,15 +19,27 @@
package org.apache.cassandra.index.sai.disk.v1;
import java.io.IOException;
+import java.io.UncheckedIOException;
import java.lang.invoke.MethodHandles;
+import java.nio.ByteBuffer;
+import java.util.HashMap;
+import java.util.List;
import java.util.Map;
+import java.util.function.Function;
+import java.util.stream.Collectors;
import com.google.common.base.MoreObjects;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
+import org.apache.cassandra.cql3.Operator;
import org.apache.cassandra.db.PartitionPosition;
+import org.apache.cassandra.db.Slice;
+import org.apache.cassandra.db.Slices;
+import org.apache.cassandra.db.rows.Cell;
+import org.apache.cassandra.db.rows.Row;
import org.apache.cassandra.dht.AbstractBounds;
+import org.apache.cassandra.exceptions.InvalidRequestException;
import org.apache.cassandra.index.sai.IndexContext;
import org.apache.cassandra.index.sai.QueryContext;
import org.apache.cassandra.index.sai.SSTableContext;
@@ -35,24 +47,31 @@
import org.apache.cassandra.index.sai.disk.TermsIterator;
import org.apache.cassandra.index.sai.disk.format.IndexComponentType;
import org.apache.cassandra.index.sai.disk.format.Version;
+import org.apache.cassandra.index.sai.disk.v1.postings.IntersectingPostingList;
import org.apache.cassandra.index.sai.iterators.KeyRangeIterator;
import org.apache.cassandra.index.sai.metrics.MulticastQueryEventListeners;
import org.apache.cassandra.index.sai.metrics.QueryEventListener;
import org.apache.cassandra.index.sai.plan.Expression;
import org.apache.cassandra.index.sai.plan.Orderer;
+import org.apache.cassandra.index.sai.utils.BM25Utils;
+import org.apache.cassandra.index.sai.utils.BM25Utils.DocTF;
+import org.apache.cassandra.index.sai.utils.PrimaryKey;
import org.apache.cassandra.index.sai.utils.PrimaryKeyWithSortKey;
import org.apache.cassandra.index.sai.utils.RowIdWithByteComparable;
import org.apache.cassandra.index.sai.utils.SAICodecUtils;
-import org.apache.cassandra.index.sai.utils.SegmentOrdering;
+import org.apache.cassandra.io.sstable.format.SSTableReader;
+import org.apache.cassandra.io.sstable.format.SSTableReadsListener;
import org.apache.cassandra.io.util.FileUtils;
import org.apache.cassandra.utils.AbstractIterator;
import org.apache.cassandra.utils.CloseableIterator;
import org.apache.cassandra.utils.bytecomparable.ByteComparable;
+import static org.apache.cassandra.index.sai.disk.PostingList.END_OF_STREAM;
+
/**
* Executes {@link Expression}s against the trie-based terms dictionary for an individual index segment.
*/
-public class InvertedIndexSearcher extends IndexSearcher implements SegmentOrdering
+public class InvertedIndexSearcher extends IndexSearcher
{
private static final Logger logger = LoggerFactory.getLogger(MethodHandles.lookup().lookupClass());
@@ -60,6 +79,9 @@ public class InvertedIndexSearcher extends IndexSearcher implements SegmentOrder
private final QueryEventListener.TrieIndexEventListener perColumnEventListener;
private final Version version;
private final boolean filterRangeResults;
+ private final SSTableReader sstable;
+ private final DocLengthsReader docLengthsReader;
+ private final long segmentRowIdOffset;
protected InvertedIndexSearcher(SSTableContext sstableContext,
PerIndexFiles perIndexFiles,
@@ -69,6 +91,7 @@ protected InvertedIndexSearcher(SSTableContext sstableContext,
boolean filterRangeResults) throws IOException
{
super(sstableContext.primaryKeyMapFactory(), perIndexFiles, segmentMetadata, indexContext);
+ this.sstable = sstableContext.sstable;
long root = metadata.getIndexRoot(IndexComponentType.TERMS_DATA);
assert root >= 0;
@@ -76,6 +99,9 @@ protected InvertedIndexSearcher(SSTableContext sstableContext,
this.version = version;
this.filterRangeResults = filterRangeResults;
perColumnEventListener = (QueryEventListener.TrieIndexEventListener)indexContext.getColumnQueryMetrics();
+ var docLengthsMeta = segmentMetadata.componentMetadatas.getOptional(IndexComponentType.DOC_LENGTHS);
+ this.segmentRowIdOffset = segmentMetadata.segmentRowIdOffset;
+ this.docLengthsReader = docLengthsMeta == null ? null : new DocLengthsReader(indexFiles.docLengths(), docLengthsMeta);
Map<String, String> map = metadata.componentMetadatas.get(IndexComponentType.TERMS_DATA).attributes;
String footerPointerString = map.get(SAICodecUtils.FOOTER_POINTER);
@@ -100,7 +126,7 @@ public long indexFileCacheSize()
}
@SuppressWarnings("resource")
- public KeyRangeIterator search(Expression exp, AbstractBounds<PartitionPosition> keyRange, QueryContext context, boolean defer, int limit) throws IOException
+ public KeyRangeIterator search(Expression exp, AbstractBounds<PartitionPosition> keyRange, QueryContext context, boolean defer) throws IOException
{
PostingList postingList = searchPosting(exp, context);
return toPrimaryKeyIterator(postingList, context);
@@ -129,11 +155,117 @@ else if (exp.getOp() == Expression.Op.RANGE)
throw new IllegalArgumentException(indexContext.logMessage("Unsupported expression: " + exp));
}
+ private Cell<?> readColumn(SSTableReader sstable, PrimaryKey primaryKey)
+ {
+ var dk = primaryKey.partitionKey();
+ var slices = Slices.with(indexContext.comparator(), Slice.make(primaryKey.clustering()));
+ try (var rowIterator = sstable.iterator(dk, slices, columnFilter, false, SSTableReadsListener.NOOP_LISTENER))
+ {
+ var unfiltered = rowIterator.next();
+ assert unfiltered.isRow() : unfiltered;
+ Row row = (Row) unfiltered;
+ return row.getCell(indexContext.getDefinition());
+ }
+ }
+
@Override
public CloseableIterator<PrimaryKeyWithSortKey> orderBy(Orderer orderer, Expression slice, AbstractBounds<PartitionPosition> keyRange, QueryContext queryContext, int limit) throws IOException
{
- var iter = new RowIdWithTermsIterator(reader.allTerms(orderer.isAscending()));
- return toMetaSortedIterator(iter, queryContext);
+ if (!orderer.isBM25())
+ {
+ var iter = new RowIdWithTermsIterator(reader.allTerms(orderer.isAscending()));
+ return toMetaSortedIterator(iter, queryContext);
+ }
+ if (docLengthsReader == null)
+ throw new InvalidRequestException(indexContext.getIndexName() + " does not support BM25 scoring until it is rebuilt");
+
+ // find documents that match each term
+ var queryTerms = orderer.getQueryTerms();
+ var postingLists = queryTerms.stream()
+ .collect(Collectors.toMap(Function.identity(), term ->
+ {
+ var encodedTerm = version.onDiskFormat().encodeForTrie(term, indexContext.getValidator());
+ var listener = MulticastQueryEventListeners.of(queryContext, perColumnEventListener);
+ var postings = reader.exactMatch(encodedTerm, listener, queryContext);
+ return postings == null ? PostingList.EMPTY : postings;
+ }));
+ // extract the match count for each
+ var documentFrequencies = postingLists.entrySet().stream().collect(Collectors.toMap(Map.Entry::getKey, e -> (long) e.getValue().size()));
+
+ var pkm = primaryKeyMapFactory.newPerSSTablePrimaryKeyMap();
+ var merged = IntersectingPostingList.intersect(postingLists);
+
+ // Wrap the iterator with resource management
+ var it = new AbstractIterator<DocTF>() {
+ private boolean closed;
+
+ @Override
+ protected DocTF computeNext()
+ {
+ try
+ {
+ int rowId = merged.nextPosting();
+ if (rowId == PostingList.END_OF_STREAM)
+ return endOfData();
+ int docLength = docLengthsReader.get(rowId); // segment-local rowid
+ var pk = pkm.primaryKeyFromRowId(segmentRowIdOffset + rowId); // sstable-global rowid
+ return new DocTF(pk, docLength, merged.frequencies());
+ }
+ catch (IOException e)
+ {
+ throw new UncheckedIOException(e);
+ }
+ }
+
+ @Override
+ public void close()
+ {
+ if (closed) return;
+ closed = true;
+ FileUtils.closeQuietly(pkm, merged);
+ }
+ };
+ return bm25Internal(it, queryTerms, documentFrequencies);
+ }
+
+ private CloseableIterator<PrimaryKeyWithSortKey> bm25Internal(CloseableIterator<DocTF> keyIterator,
+ List<ByteBuffer> queryTerms,
+ Map<ByteBuffer, Long> documentFrequencies)
+ {
+ var totalRows = sstable.getTotalRows();
+ // since doc frequencies can be an estimate from the index histogram, which does not have bounded error,
+ // cap frequencies to total rows so that the IDF term doesn't turn negative
+ var cappedFrequencies = documentFrequencies.entrySet().stream()
+ .collect(Collectors.toMap(Map.Entry::getKey, e -> Math.min(e.getValue(), totalRows)));
+ var docStats = new BM25Utils.DocStats(cappedFrequencies, totalRows);
+ return BM25Utils.computeScores(keyIterator,
+ queryTerms,
+ docStats,
+ indexContext,
+ sstable.descriptor.id);
+ }
+
+ @Override
+ public CloseableIterator<PrimaryKeyWithSortKey> orderResultsBy(SSTableReader reader, QueryContext queryContext, List<PrimaryKey> keys, Orderer orderer, int limit) throws IOException
+ {
+ if (!orderer.isBM25())
+ return super.orderResultsBy(reader, queryContext, keys, orderer, limit);
+ if (docLengthsReader == null)
+ throw new InvalidRequestException(indexContext.getIndexName() + " does not support BM25 scoring until it is rebuilt");
+
+ var queryTerms = orderer.getQueryTerms();
+ // compute documentFrequencies from either histogram or an index search
+ var documentFrequencies = new HashMap<ByteBuffer, Long>();
+ // any index new enough to support BM25 should also support histograms
+ assert metadata.version.onDiskFormat().indexFeatureSet().hasTermsHistogram();
+ for (ByteBuffer term : queryTerms)
+ {
+ long matches = metadata.estimateNumRowsMatching(new Expression(indexContext).add(Operator.ANALYZER_MATCHES, term));
+ documentFrequencies.put(term, matches);
+ }
+ var analyzer = indexContext.getAnalyzerFactory().create();
+ var it = keys.stream().map(pk -> DocTF.createFromDocument(pk, readColumn(sstable, pk), analyzer, queryTerms)).iterator();
+ return bm25Internal(CloseableIterator.wrap(it), queryTerms, documentFrequencies);
}
@Override
@@ -147,7 +279,7 @@ public String toString()
@Override
public void close()
{
- reader.close();
+ FileUtils.closeQuietly(reader, docLengthsReader);
}
/**
@@ -172,7 +304,7 @@ protected RowIdWithByteComparable computeNext()
while (true)
{
long nextPosting = currentPostingList.nextPosting();
- if (nextPosting != PostingList.END_OF_STREAM)
+ if (nextPosting != END_OF_STREAM)
return new RowIdWithByteComparable(Math.toIntExact(nextPosting), currentTerm);
if (!source.hasNext())
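
Both BM25 paths above hand the same ingredients to BM25Utils.computeScores: per-row term frequencies, per-row document length, per-term document frequency, and the total row count. The scoring function itself is not visible in this diff; for orientation, classic BM25 with the conventional k1 = 1.2 and b = 0.75 (assumed constants, not confirmed by the patch):

    static double bm25(int tf, int docLen, double avgDocLen, long docFreq, long totalDocs)
    {
        double k1 = 1.2, b = 0.75;
        double idf = Math.log(1 + (totalDocs - docFreq + 0.5) / (docFreq + 0.5));
        double norm = k1 * (1 - b + b * docLen / avgDocLen);
        return idf * tf * (k1 + 1) / (tf + norm);
    }

This is also why bm25Internal caps each document frequency at the sstable's total row count: histogram-estimated frequencies can exceed it, which would push the log argument below 1 and make the idf term negative.
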
diff --git a/src/java/org/apache/cassandra/index/sai/disk/v1/KDTreeIndexSearcher.java b/src/java/org/apache/cassandra/index/sai/disk/v1/KDTreeIndexSearcher.java
index 8a2aa354bd47..5911f2351014 100644
--- a/src/java/org/apache/cassandra/index/sai/disk/v1/KDTreeIndexSearcher.java
+++ b/src/java/org/apache/cassandra/index/sai/disk/v1/KDTreeIndexSearcher.java
@@ -84,7 +84,7 @@ public long indexFileCacheSize()
}
@Override
- public KeyRangeIterator search(Expression exp, AbstractBounds<PartitionPosition> keyRange, QueryContext context, boolean defer, int limit) throws IOException
+ public KeyRangeIterator search(Expression exp, AbstractBounds<PartitionPosition> keyRange, QueryContext context, boolean defer) throws IOException
{
PostingList postingList = searchPosting(exp, context);
return toPrimaryKeyIterator(postingList, context);
diff --git a/src/java/org/apache/cassandra/index/sai/disk/v1/MemtableIndexWriter.java b/src/java/org/apache/cassandra/index/sai/disk/v1/MemtableIndexWriter.java
index 102699ca0288..8683409896ad 100644
--- a/src/java/org/apache/cassandra/index/sai/disk/v1/MemtableIndexWriter.java
+++ b/src/java/org/apache/cassandra/index/sai/disk/v1/MemtableIndexWriter.java
@@ -19,15 +19,15 @@
import java.io.IOException;
import java.lang.invoke.MethodHandles;
+import java.util.Arrays;
import java.util.Collections;
-import java.util.Iterator;
import java.util.concurrent.TimeUnit;
import com.google.common.base.Stopwatch;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
-import com.carrotsearch.hppc.IntArrayList;
+import org.agrona.collections.Int2IntHashMap;
import org.apache.cassandra.db.DecoratedKey;
import org.apache.cassandra.db.marshal.AbstractType;
import org.apache.cassandra.db.rows.Row;
@@ -36,17 +36,17 @@
import org.apache.cassandra.index.sai.disk.PerIndexWriter;
import org.apache.cassandra.index.sai.disk.format.IndexComponents;
import org.apache.cassandra.index.sai.disk.format.Version;
-import org.apache.cassandra.index.sai.disk.vector.VectorMemtableIndex;
import org.apache.cassandra.index.sai.disk.v1.kdtree.ImmutableOneDimPointValues;
import org.apache.cassandra.index.sai.disk.v1.kdtree.NumericIndexWriter;
import org.apache.cassandra.index.sai.disk.v1.trie.InvertedIndexWriter;
+import org.apache.cassandra.index.sai.disk.vector.VectorMemtableIndex;
import org.apache.cassandra.index.sai.memory.MemtableIndex;
import org.apache.cassandra.index.sai.memory.RowMapping;
+import org.apache.cassandra.index.sai.memory.TrieMemoryIndex;
+import org.apache.cassandra.index.sai.memory.TrieMemtableIndex;
import org.apache.cassandra.index.sai.utils.PrimaryKey;
import org.apache.cassandra.index.sai.utils.TypeUtil;
import org.apache.cassandra.utils.ByteBufferUtil;
-import org.apache.cassandra.utils.Pair;
-import org.apache.cassandra.utils.bytecomparable.ByteComparable;
/**
* Column index writer that flushes indexed data directly from the corresponding Memtable index, without buffering index
@@ -126,7 +126,7 @@ public void complete(Stopwatch stopwatch) throws IOException
}
else
{
- final Iterator<Pair<ByteComparable, IntArrayList>> iterator = rowMapping.merge(memtableIndex);
+ var iterator = rowMapping.merge(memtableIndex);
try (MemtableTermsIterator terms = new MemtableTermsIterator(memtableIndex.getMinTerm(), memtableIndex.getMaxTerm(), iterator))
{
long cellCount = flush(minKey, maxKey, indexContext().getValidator(), terms, rowMapping.maxSegmentRowId);
@@ -151,9 +151,21 @@ private long flush(DecoratedKey minKey, DecoratedKey maxKey, AbstractType> ter
SegmentMetadata.ComponentMetadataMap indexMetas;
if (TypeUtil.isLiteral(termComparator))
{
- try (InvertedIndexWriter writer = new InvertedIndexWriter(perIndexComponents))
+ try (InvertedIndexWriter writer = new InvertedIndexWriter(perIndexComponents, writeFrequencies()))
{
- indexMetas = writer.writeAll(metadataBuilder.intercept(terms));
+ // Convert PrimaryKey->length map to rowId->length using RowMapping
+ var docLengths = new Int2IntHashMap(Integer.MIN_VALUE);
+ Arrays.stream(((TrieMemtableIndex) memtableIndex).getRangeIndexes())
+ .map(TrieMemoryIndex.class::cast)
+ .forEach(trieMemoryIndex ->
+ trieMemoryIndex.getDocLengths().forEach((pk, length) -> {
+ int rowId = rowMapping.get(pk);
+ if (rowId >= 0)
+ docLengths.put(rowId, (int) length);
+ })
+ );
+
+ indexMetas = writer.writeAll(metadataBuilder.intercept(terms), docLengths);
numRows = writer.getPostingsCount();
}
}
@@ -194,6 +206,11 @@ private long flush(DecoratedKey minKey, DecoratedKey maxKey, AbstractType> ter
return numRows;
}
+ private boolean writeFrequencies()
+ {
+ return indexContext().isAnalyzed() && Version.latest().onOrAfter(Version.EC);
+ }
+
private void flushVectorIndex(DecoratedKey minKey, DecoratedKey maxKey, long startTime, Stopwatch stopwatch) throws IOException
{
var vectorIndex = (VectorMemtableIndex) memtableIndex;
diff --git a/src/java/org/apache/cassandra/index/sai/disk/v1/PartitionAwarePrimaryKeyFactory.java b/src/java/org/apache/cassandra/index/sai/disk/v1/PartitionAwarePrimaryKeyFactory.java
index 3b49e22a84fd..c1b47916229b 100644
--- a/src/java/org/apache/cassandra/index/sai/disk/v1/PartitionAwarePrimaryKeyFactory.java
+++ b/src/java/org/apache/cassandra/index/sai/disk/v1/PartitionAwarePrimaryKeyFactory.java
@@ -21,6 +21,7 @@
import java.util.Objects;
import java.util.function.Supplier;
+import io.github.jbellis.jvector.util.RamUsageEstimator;
import org.apache.cassandra.db.Clustering;
import org.apache.cassandra.db.DecoratedKey;
import org.apache.cassandra.dht.Token;
@@ -136,6 +137,19 @@ private ByteSource asComparableBytes(int terminator, ByteComparable.Version vers
return ByteSource.withTerminator(terminator, tokenComparable, keyComparable, null);
}
+ @Override
+ public long ramBytesUsed()
+ {
+ // Compute shallow size: object header + 4 references (3 declared + 1 implicit outer reference)
+ long shallowSize = RamUsageEstimator.NUM_BYTES_OBJECT_HEADER + 4L * RamUsageEstimator.NUM_BYTES_OBJECT_REF;
+ long preHashedDecoratedKeySize = partitionKey == null
+ ? 0
+ : RamUsageEstimator.NUM_BYTES_OBJECT_HEADER
+ + 2L * RamUsageEstimator.NUM_BYTES_OBJECT_REF // token and key references
+ + 2L * Long.BYTES;
+ return shallowSize + token.getHeapSize() + preHashedDecoratedKeySize;
+ }
+
@Override
public int compareTo(PrimaryKey o)
{
diff --git a/src/java/org/apache/cassandra/index/sai/disk/v1/PerIndexFiles.java b/src/java/org/apache/cassandra/index/sai/disk/v1/PerIndexFiles.java
index e4a1ff6bca9e..b82c23cc589d 100644
--- a/src/java/org/apache/cassandra/index/sai/disk/v1/PerIndexFiles.java
+++ b/src/java/org/apache/cassandra/index/sai/disk/v1/PerIndexFiles.java
@@ -28,6 +28,7 @@
import org.apache.cassandra.index.sai.disk.format.IndexComponents;
import org.apache.cassandra.index.sai.disk.format.IndexComponentType;
+import org.apache.cassandra.index.sai.disk.format.Version;
import org.apache.cassandra.io.util.FileHandle;
import org.apache.cassandra.io.util.FileUtils;
@@ -104,6 +105,12 @@ public FileHandle pq()
return getFile(IndexComponentType.PQ).sharedCopy();
}
+ /** It is the caller's responsibility to close the returned file handle. */
+ public FileHandle docLengths()
+ {
+ return getFile(IndexComponentType.DOC_LENGTHS).sharedCopy();
+ }
+
public FileHandle getFile(IndexComponentType indexComponentType)
{
FileHandle file = files.get(indexComponentType);
diff --git a/src/java/org/apache/cassandra/index/sai/disk/v1/SSTableIndexWriter.java b/src/java/org/apache/cassandra/index/sai/disk/v1/SSTableIndexWriter.java
index bb85ad583cba..a4812fb120ac 100644
--- a/src/java/org/apache/cassandra/index/sai/disk/v1/SSTableIndexWriter.java
+++ b/src/java/org/apache/cassandra/index/sai/disk/v1/SSTableIndexWriter.java
@@ -241,7 +241,7 @@ else if (shouldFlush(sstableRowId))
if (term.remaining() == 0 && TypeUtil.skipsEmptyValue(indexContext.getValidator()))
return;
- long allocated = currentBuilder.addAll(term, type, key, sstableRowId);
+ long allocated = currentBuilder.analyzeAndAdd(term, type, key, sstableRowId);
limiter.increment(allocated);
}
diff --git a/src/java/org/apache/cassandra/index/sai/disk/v1/Segment.java b/src/java/org/apache/cassandra/index/sai/disk/v1/Segment.java
index 87fe0a998e5f..3d52c944d726 100644
--- a/src/java/org/apache/cassandra/index/sai/disk/v1/Segment.java
+++ b/src/java/org/apache/cassandra/index/sai/disk/v1/Segment.java
@@ -144,7 +144,7 @@ public long indexFileCacheSize()
*/
public KeyRangeIterator search(Expression expression, AbstractBounds<PartitionPosition> keyRange, QueryContext context, boolean defer, int limit) throws IOException
{
- return index.search(expression, keyRange, context, defer, limit);
+ return index.search(expression, keyRange, context, defer);
}
/**
diff --git a/src/java/org/apache/cassandra/index/sai/disk/v1/SegmentBuilder.java b/src/java/org/apache/cassandra/index/sai/disk/v1/SegmentBuilder.java
index e1cbfc318180..095b43d3926f 100644
--- a/src/java/org/apache/cassandra/index/sai/disk/v1/SegmentBuilder.java
+++ b/src/java/org/apache/cassandra/index/sai/disk/v1/SegmentBuilder.java
@@ -29,6 +29,7 @@
import java.util.concurrent.atomic.AtomicLong;
import java.util.concurrent.atomic.AtomicReference;
import java.util.concurrent.atomic.LongAdder;
+import java.util.stream.Collectors;
import javax.annotation.concurrent.NotThreadSafe;
import com.google.common.annotations.VisibleForTesting;
@@ -42,6 +43,7 @@
import org.apache.cassandra.index.sai.IndexContext;
import org.apache.cassandra.index.sai.analyzer.AbstractAnalyzer;
import org.apache.cassandra.index.sai.analyzer.ByteLimitedMaterializer;
+import org.apache.cassandra.index.sai.analyzer.NoOpAnalyzer;
import org.apache.cassandra.index.sai.disk.PostingList;
import org.apache.cassandra.index.sai.disk.RAMStringIndexer;
import org.apache.cassandra.index.sai.disk.TermsIterator;
@@ -154,9 +156,11 @@ public boolean isEmpty()
return kdTreeRamBuffer.numRows() == 0;
}
- protected long addInternal(ByteBuffer term, int segmentRowId)
+ @Override
+ protected long addInternal(List<ByteBuffer> terms, int segmentRowId)
{
- TypeUtil.toComparableBytes(term, termComparator, buffer);
+ assert terms.size() == 1;
+ TypeUtil.toComparableBytes(terms.get(0), termComparator, buffer);
return kdTreeRamBuffer.addPackedValue(segmentRowId, new BytesRef(buffer));
}
@@ -192,10 +196,14 @@ public static class RAMStringSegmentBuilder extends SegmentBuilder
{
super(components, rowIdOffset, limiter);
this.byteComparableVersion = components.byteComparableVersionFor(IndexComponentType.TERMS_DATA);
- ramIndexer = new RAMStringIndexer();
+ ramIndexer = new RAMStringIndexer(writeFrequencies());
totalBytesAllocated = ramIndexer.estimatedBytesUsed();
totalBytesAllocatedConcurrent.add(totalBytesAllocated);
+ }
+ private boolean writeFrequencies()
+ {
+ return !(analyzer instanceof NoOpAnalyzer) && Version.latest().onOrAfter(Version.EC);
}
public boolean isEmpty()
@@ -203,21 +211,26 @@ public boolean isEmpty()
return ramIndexer.isEmpty();
}
- protected long addInternal(ByteBuffer term, int segmentRowId)
+ @Override
+ protected long addInternal(List<ByteBuffer> terms, int segmentRowId)
{
- var encodedTerm = components.onDiskFormat().encodeForTrie(term, termComparator);
- var bytes = ByteSourceInverse.readBytes(encodedTerm.asComparableBytes(byteComparableVersion));
- var bytesRef = new BytesRef(bytes);
- return ramIndexer.add(bytesRef, segmentRowId);
+ var bytesRefs = terms.stream()
+ .map(term -> components.onDiskFormat().encodeForTrie(term, termComparator))
+ .map(encodedTerm -> ByteSourceInverse.readBytes(encodedTerm.asComparableBytes(byteComparableVersion)))
+ .map(BytesRef::new)
+ .collect(Collectors.toList());
+ // ramIndexer is responsible for merging duplicate (term, row) pairs
+ return ramIndexer.addAll(bytesRefs, segmentRowId);
}
@Override
protected void flushInternal(SegmentMetadataBuilder metadataBuilder) throws IOException
{
- try (InvertedIndexWriter writer = new InvertedIndexWriter(components))
+ try (InvertedIndexWriter writer = new InvertedIndexWriter(components, writeFrequencies()))
{
TermsIterator termsWithPostings = ramIndexer.getTermsWithPostings(minTerm, maxTerm, byteComparableVersion);
- var metadataMap = writer.writeAll(metadataBuilder.intercept(termsWithPostings));
+ var docLengths = ramIndexer.getDocLengths();
+ var metadataMap = writer.writeAll(metadataBuilder.intercept(termsWithPostings), docLengths);
metadataBuilder.setComponentsMetadata(metadataMap);
}
}
@@ -261,21 +274,23 @@ public boolean isEmpty()
}
@Override
- protected long addInternal(ByteBuffer term, int segmentRowId)
+ protected long addInternal(List<ByteBuffer> terms, int segmentRowId)
{
throw new UnsupportedOperationException();
}
@Override
- protected long addInternalAsync(ByteBuffer term, int segmentRowId)
+ protected long addInternalAsync(List<ByteBuffer> terms, int segmentRowId)
{
+ assert terms.size() == 1;
+
// CompactionGraph splits adding a node into two parts:
// (1) maybeAddVector, which must be done serially because it writes to disk incrementally
// (2) addGraphNode, which may be done asynchronously
CompactionGraph.InsertionResult result;
try
{
- result = graphIndex.maybeAddVector(term, segmentRowId);
+ result = graphIndex.maybeAddVector(terms.get(0), segmentRowId);
}
catch (IOException e)
{
@@ -367,19 +382,20 @@ public boolean isEmpty()
}
@Override
- protected long addInternal(ByteBuffer term, int segmentRowId)
+ protected long addInternal(List<ByteBuffer> terms, int segmentRowId)
{
- return graphIndex.add(term, segmentRowId);
+ assert terms.size() == 1;
+ return graphIndex.add(terms.get(0), segmentRowId);
}
@Override
- protected long addInternalAsync(ByteBuffer term, int segmentRowId)
+ protected long addInternalAsync(List<ByteBuffer> terms, int segmentRowId)
{
updatesInFlight.incrementAndGet();
compactionExecutor.submit(() -> {
try
{
- long bytesAdded = addInternal(term, segmentRowId);
+ long bytesAdded = addInternal(terms, segmentRowId);
totalBytesAllocatedConcurrent.add(bytesAdded);
termSizeReservoir.update(bytesAdded);
}
@@ -454,23 +470,22 @@ public SegmentMetadata flush() throws IOException
return metadataBuilder.build();
}
- public long addAll(ByteBuffer term, AbstractType<?> type, PrimaryKey key, long sstableRowId)
+ public long analyzeAndAdd(ByteBuffer rawTerm, AbstractType<?> type, PrimaryKey key, long sstableRowId)
{
long totalSize = 0;
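+ // Literal (text) columns are run through the analyzer, which may produce multiple
+ // terms for a single row; all other types are indexed as a single raw term.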
if (TypeUtil.isLiteral(type))
{
- List<ByteBuffer> tokens = ByteLimitedMaterializer.materializeTokens(analyzer, term, components.context(), key);
- for (ByteBuffer tokenTerm : tokens)
- totalSize += add(tokenTerm, key, sstableRowId);
+ var terms = ByteLimitedMaterializer.materializeTokens(analyzer, rawTerm, components.context(), key);
+ totalSize += add(terms, key, sstableRowId);
}
else
{
- totalSize += add(term, key, sstableRowId);
+ totalSize += add(List.of(rawTerm), key, sstableRowId);
}
return totalSize;
}
- private long add(ByteBuffer term, PrimaryKey key, long sstableRowId)
+ private long add(List<ByteBuffer> terms, PrimaryKey key, long sstableRowId)
{
assert !flushed : "Cannot add to flushed segment.";
assert sstableRowId >= maxSSTableRowId;
@@ -481,9 +496,12 @@ private long add(ByteBuffer term, PrimaryKey key, long sstableRowId)
minKey = minKey == null ? key : minKey;
maxKey = key;
- // Note that the min and max terms are not encoded.
- minTerm = TypeUtil.min(term, minTerm, termComparator, Version.latest());
- maxTerm = TypeUtil.max(term, maxTerm, termComparator, Version.latest());
+ // Update term boundaries for all terms in this row
+ for (ByteBuffer term : terms)
+ {
+ minTerm = TypeUtil.min(term, minTerm, termComparator, Version.latest());
+ maxTerm = TypeUtil.max(term, maxTerm, termComparator, Version.latest());
+ }
rowCount++;
@@ -495,15 +513,23 @@ private long add(ByteBuffer term, PrimaryKey key, long sstableRowId)
maxSegmentRowId = Math.max(maxSegmentRowId, segmentRowId);
- long bytesAllocated = supportsAsyncAdd()
- ? addInternalAsync(term, segmentRowId)
- : addInternal(term, segmentRowId);
- totalBytesAllocated += bytesAllocated;
+ long bytesAllocated;
+ if (supportsAsyncAdd())
+ {
+ // only vector indexing is done async and there can only be one term
+ assert terms.size() == 1;
+ bytesAllocated = addInternalAsync(terms, segmentRowId);
+ }
+ else
+ {
+ bytesAllocated = addInternal(terms, segmentRowId);
+ }
+ totalBytesAllocated += bytesAllocated;
return bytesAllocated;
}
- protected long addInternalAsync(ByteBuffer term, int segmentRowId)
+ protected long addInternalAsync(List<ByteBuffer> terms, int segmentRowId)
{
throw new UnsupportedOperationException();
}
@@ -566,7 +592,7 @@ long release(IndexContext indexContext)
public abstract boolean isEmpty();
- protected abstract long addInternal(ByteBuffer term, int segmentRowId);
+ protected abstract long addInternal(List<ByteBuffer> terms, int segmentRowId);
protected abstract void flushInternal(SegmentMetadataBuilder metadataBuilder) throws IOException;
diff --git a/src/java/org/apache/cassandra/index/sai/disk/v1/SegmentMetadata.java b/src/java/org/apache/cassandra/index/sai/disk/v1/SegmentMetadata.java
index 00361d6f27ab..b8f4cd769873 100644
--- a/src/java/org/apache/cassandra/index/sai/disk/v1/SegmentMetadata.java
+++ b/src/java/org/apache/cassandra/index/sai/disk/v1/SegmentMetadata.java
@@ -376,6 +376,11 @@ public ComponentMetadata get(IndexComponentType indexComponentType)
return metas.get(indexComponentType);
}
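+ /** Returns the metadata for the given component, or null if this segment does not have that component. */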
+ public ComponentMetadata getOptional(IndexComponentType indexComponentType)
+ {
+ return metas.get(indexComponentType);
+ }
+
public Map<String, Map<String, String>> asMap()
{
Map<String, Map<String, String>> metaAttributes = new HashMap<>();
diff --git a/src/java/org/apache/cassandra/index/sai/disk/v1/TermsReader.java b/src/java/org/apache/cassandra/index/sai/disk/v1/TermsReader.java
index a75871ec28cd..f05511b195a5 100644
--- a/src/java/org/apache/cassandra/index/sai/disk/v1/TermsReader.java
+++ b/src/java/org/apache/cassandra/index/sai/disk/v1/TermsReader.java
@@ -228,7 +228,7 @@ public PostingsReader getPostingReader(long offset) throws IOException
{
PostingsReader.BlocksSummary header = new PostingsReader.BlocksSummary(postingsSummaryInput, offset);
- return new PostingsReader(postingsInput, header, listener.postingListEventListener());
+ return new PostingsReader(postingsInput, header, readFrequencies(), listener.postingListEventListener());
}
}
@@ -363,11 +363,17 @@ private PostingsReader currentReader(IndexInput postingsInput,
PostingsReader.InputCloser.NOOP);
return new PostingsReader(postingsInput,
blocksSummary,
+ readFrequencies(),
listener.postingListEventListener(),
PostingsReader.InputCloser.NOOP);
}
}
+ private boolean readFrequencies()
+ {
+ return indexContext.isAnalyzed() && version.onOrAfter(Version.EC);
+ }
+
private class TermsScanner implements TermsIterator
{
private final TrieTermsDictionaryReader termsDictionaryReader;
@@ -400,7 +406,7 @@ public PostingList postings() throws IOException
{
assert entry != null;
var blockSummary = new PostingsReader.BlocksSummary(postingsSummaryInput, entry.right, PostingsReader.InputCloser.NOOP);
- return new ScanningPostingsReader(postingsInput, blockSummary);
+ return new ScanningPostingsReader(postingsInput, blockSummary, readFrequencies());
}
@Override
@@ -461,7 +467,7 @@ public PostingList postings() throws IOException
{
assert entry != null;
var blockSummary = new PostingsReader.BlocksSummary(postingsSummaryInput, entry.right, PostingsReader.InputCloser.NOOP);
- return new ScanningPostingsReader(postingsInput, blockSummary);
+ return new ScanningPostingsReader(postingsInput, blockSummary, readFrequencies());
}
@Override
diff --git a/src/java/org/apache/cassandra/index/sai/disk/v1/V1OnDiskFormat.java b/src/java/org/apache/cassandra/index/sai/disk/v1/V1OnDiskFormat.java
index 5dbd732f4676..8afe5d62b4c1 100644
--- a/src/java/org/apache/cassandra/index/sai/disk/v1/V1OnDiskFormat.java
+++ b/src/java/org/apache/cassandra/index/sai/disk/v1/V1OnDiskFormat.java
@@ -83,10 +83,11 @@ public class V1OnDiskFormat implements OnDiskFormat
IndexComponentType.META,
IndexComponentType.TERMS_DATA,
IndexComponentType.POSTING_LISTS);
- private static final Set<IndexComponentType> NUMERIC_COMPONENTS = EnumSet.of(IndexComponentType.COLUMN_COMPLETION_MARKER,
- IndexComponentType.META,
- IndexComponentType.KD_TREE,
- IndexComponentType.KD_TREE_POSTING_LISTS);
+
+ public static final Set<IndexComponentType> NUMERIC_COMPONENTS = EnumSet.of(IndexComponentType.COLUMN_COMPLETION_MARKER,
+ IndexComponentType.META,
+ IndexComponentType.KD_TREE,
+ IndexComponentType.KD_TREE_POSTING_LISTS);
/**
* Global limit on heap consumed by all index segment building that occurs outside the context of Memtable flush.
diff --git a/src/java/org/apache/cassandra/index/sai/disk/v1/postings/IntersectingPostingList.java b/src/java/org/apache/cassandra/index/sai/disk/v1/postings/IntersectingPostingList.java
new file mode 100644
index 000000000000..295796c9066f
--- /dev/null
+++ b/src/java/org/apache/cassandra/index/sai/disk/v1/postings/IntersectingPostingList.java
@@ -0,0 +1,141 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.cassandra.index.sai.disk.v1.postings;
+
+import java.io.IOException;
+import java.nio.ByteBuffer;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import javax.annotation.concurrent.NotThreadSafe;
+
+import org.apache.cassandra.index.sai.disk.PostingList;
+import org.apache.cassandra.io.util.FileUtils;
+
+/**
+ * Performs intersection operations on multiple PostingLists, returning only postings
+ * that appear in all inputs.
+ */
+@NotThreadSafe
+public class IntersectingPostingList implements PostingList
+{
+ private final Map<ByteBuffer, PostingList> postingsByTerm;
+ private final List<PostingList> postingLists; // so we can access by ordinal in intersection code
+ private final int size;
+
+ private IntersectingPostingList(Map<ByteBuffer, PostingList> postingsByTerm)
+ {
+ if (postingsByTerm.isEmpty())
+ throw new AssertionError();
+ this.postingsByTerm = postingsByTerm;
+ this.postingLists = new ArrayList<>(postingsByTerm.values());
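+ // The intersection can never contain more postings than the smallest input,
+ // so the minimum input size serves as an upper bound for size().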
+ this.size = postingLists.stream()
+ .mapToInt(PostingList::size)
+ .min()
+ .orElse(0);
+ }
+
+ /**
+ * @return the intersection of the provided term-posting list mappings
+ */
+ public static IntersectingPostingList intersect(Map<ByteBuffer, PostingList> postingsByTerm)
+ {
+ // TODO optimize cases where
+ // - we have a single postinglist
+ // - any posting list is empty (intersection also empty)
+ return new IntersectingPostingList(postingsByTerm);
+ }
+
+ @Override
+ public int nextPosting() throws IOException
+ {
+ return findNextIntersection(Integer.MIN_VALUE, false);
+ }
+
+ @Override
+ public int advance(int targetRowID) throws IOException
+ {
+ assert targetRowID >= 0 : targetRowID;
+ return findNextIntersection(targetRowID, true);
+ }
+
+ @Override
+ public int frequency()
+ {
+ // call frequencies() instead
+ throw new UnsupportedOperationException();
+ }
+
+ public Map<ByteBuffer, Integer> frequencies()
+ {
+ Map<ByteBuffer, Integer> result = new HashMap<>();
+ for (Map.Entry<ByteBuffer, PostingList> entry : postingsByTerm.entrySet())
+ result.put(entry.getKey(), entry.getValue().frequency());
+ return result;
+ }
+
+ private int findNextIntersection(int targetRowID, boolean isAdvance) throws IOException
+ {
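+ // Leapfrog-style intersection: repeatedly advance every other list to the current
+ // maximum row id. When a full pass completes without the maximum moving, all lists
+ // are positioned on the same row id, which is therefore in the intersection.
+ // E.g. with postings [1, 3, 5] and [3, 5, 7]: 1 is advanced to 3, 3 matches, and 3 is returned.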
+ int maxRowId = targetRowID;
+ int maxRowIdIndex = -1;
+
+ // Scan through all posting lists looking for a common row ID
+ for (int i = 0; i < postingLists.size(); i++)
+ {
+ // don't advance the sublist in which we found our current max
+ if (i == maxRowIdIndex)
+ continue;
+
+ // Advance this sublist to the current max, special casing the first one as needed
+ PostingList list = postingLists.get(i);
+ int rowId = (isAdvance || maxRowIdIndex >= 0)
+ ? list.advance(maxRowId)
+ : list.nextPosting();
+ if (rowId == END_OF_STREAM)
+ return END_OF_STREAM;
+
+ // Update maxRowId + index if we find a larger value, or this was the first sublist evaluated
+ if (rowId > maxRowId || maxRowIdIndex < 0)
+ {
+ maxRowId = rowId;
+ maxRowIdIndex = i;
+ i = -1; // restart the scan with new maxRowId
+ }
+ }
+
+ // Once we complete a full scan without finding a larger rowId, we've found an intersection
+ return maxRowId;
+ }
+
+ @Override
+ public int size()
+ {
+ return size;
+ }
+
+ @Override
+ public void close()
+ {
+ for (PostingList list : postingLists)
+ FileUtils.closeQuietly(list);
+ }
+}
diff --git a/src/java/org/apache/cassandra/index/sai/disk/v1/postings/PostingsReader.java b/src/java/org/apache/cassandra/index/sai/disk/v1/postings/PostingsReader.java
index ed3556c4ffe5..0ee2ab63b132 100644
--- a/src/java/org/apache/cassandra/index/sai/disk/v1/postings/PostingsReader.java
+++ b/src/java/org/apache/cassandra/index/sai/disk/v1/postings/PostingsReader.java
@@ -51,7 +51,7 @@ public class PostingsReader implements OrdinalPostingList
protected final IndexInput input;
protected final InputCloser runOnClose;
- private final int blockSize;
+ private final int blockEntries;
private final int numPostings;
private final LongArray blockOffsets;
private final LongArray blockMaxValues;
@@ -69,6 +69,8 @@ public class PostingsReader implements OrdinalPostingList
private long currentPosition;
private LongValues currentFORValues;
private int postingsDecoded = 0;
+ private int currentFrequency = Integer.MIN_VALUE; // sentinel: no frequency decoded yet
+ private final boolean readFrequencies;
@VisibleForTesting
public PostingsReader(IndexInput input, long summaryOffset, QueryEventListener.PostingListEventListener listener) throws IOException
@@ -78,7 +80,7 @@ public PostingsReader(IndexInput input, long summaryOffset, QueryEventListener.P
public PostingsReader(IndexInput input, BlocksSummary summary, QueryEventListener.PostingListEventListener listener) throws IOException
{
- this(input, summary, listener, () -> {
+ this(input, summary, false, listener, () -> {
try
{
input.close();
@@ -90,14 +92,29 @@ public PostingsReader(IndexInput input, BlocksSummary summary, QueryEventListene
});
}
- public PostingsReader(IndexInput input, BlocksSummary summary, QueryEventListener.PostingListEventListener listener, InputCloser runOnClose) throws IOException
+ public PostingsReader(IndexInput input, BlocksSummary summary, boolean readFrequencies, QueryEventListener.PostingListEventListener listener) throws IOException
+ {
+ this(input, summary, readFrequencies, listener, () -> {
+ try
+ {
+ input.close();
+ }
+ finally
+ {
+ summary.close();
+ }
+ });
+ }
+
+ public PostingsReader(IndexInput input, BlocksSummary summary, boolean readFrequencies, QueryEventListener.PostingListEventListener listener, InputCloser runOnClose) throws IOException
{
assert input instanceof IndexInputReader;
logger.trace("Opening postings reader for {}", input);
+ this.readFrequencies = readFrequencies;
this.input = input;
this.seekingInput = new SeekingRandomAccessInput(input);
this.blockOffsets = summary.offsets;
- this.blockSize = summary.blockSize;
+ this.blockEntries = summary.blockEntries;
this.numPostings = summary.numPostings;
this.blockMaxValues = summary.maxValues;
this.listener = listener;
@@ -122,7 +139,7 @@ public interface InputCloser
public static class BlocksSummary
{
- final int blockSize;
+ final int blockEntries;
final int numPostings;
final LongArray offsets;
final LongArray maxValues;
@@ -139,7 +156,7 @@ public BlocksSummary(IndexInput input, long offset, InputCloser runOnClose) thro
this.runOnClose = runOnClose;
input.seek(offset);
- this.blockSize = input.readVInt();
+ this.blockEntries = input.readVInt();
// This is the count of row ids in a single posting list. For now, a segment cannot have more than
// Integer.MAX_VALUE row ids, so it is safe to use an int here.
this.numPostings = input.readVInt();
@@ -323,10 +340,10 @@ private void lastPosInBlock(int block)
// blockMaxValues is integer only
actualSegmentRowId = Math.toIntExact(blockMaxValues.get(block));
//upper bound, since we might've advanced to the last block, but upper bound is enough
- totalPostingsRead += (blockSize - blockIdx) + (block - postingsBlockIdx + 1) * blockSize;
+ totalPostingsRead += (blockEntries - blockIdx) + (block - postingsBlockIdx + 1) * blockEntries;
postingsBlockIdx = block + 1;
- blockIdx = blockSize;
+ blockIdx = blockEntries;
}
@Override
@@ -341,9 +358,9 @@ public int nextPosting() throws IOException
}
@VisibleForTesting
- int getBlockSize()
+ int getBlockEntries()
{
- return blockSize;
+ return blockEntries;
}
private int peekNext() throws IOException
@@ -352,27 +369,28 @@ private int peekNext() throws IOException
{
return END_OF_STREAM;
}
- if (blockIdx == blockSize)
+ if (blockIdx == blockEntries)
{
reBuffer();
}
- return actualSegmentRowId + nextRowID();
+ return actualSegmentRowId + nextRowDelta();
}
- private int nextRowID()
+ private int nextRowDelta()
{
- // currentFORValues is null when the all the values in the block are the same
if (currentFORValues == null)
{
+ currentFrequency = Integer.MIN_VALUE;
return 0;
}
- else
- {
- final long id = currentFORValues.get(blockIdx);
- postingsDecoded++;
- return Math.toIntExact(id);
- }
+
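+ // With frequencies enabled each block entry is an interleaved [delta, frequency] pair, so
+ // the delta for blockIdx is at position 2 * blockIdx and its frequency immediately follows.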
+ long offset = readFrequencies ? 2L * blockIdx : blockIdx;
+ long id = currentFORValues.get(offset);
+ if (readFrequencies)
+ currentFrequency = Math.toIntExact(currentFORValues.get(offset + 1));
+ postingsDecoded++;
+ return Math.toIntExact(id);
}
private void advanceOnePosition(int nextRowID)
@@ -420,4 +438,9 @@ else if (bitsPerValue > 64)
}
currentFORValues = LuceneCompat.directReaderGetInstance(seekingInput, bitsPerValue, currentPosition);
}
+
+ @Override
+ public int frequency()
+ {
+ return currentFrequency;
+ }
}
diff --git a/src/java/org/apache/cassandra/index/sai/disk/v1/postings/PostingsWriter.java b/src/java/org/apache/cassandra/index/sai/disk/v1/postings/PostingsWriter.java
index ef027e1e2a39..5bf0decfb8e4 100644
--- a/src/java/org/apache/cassandra/index/sai/disk/v1/postings/PostingsWriter.java
+++ b/src/java/org/apache/cassandra/index/sai/disk/v1/postings/PostingsWriter.java
@@ -27,7 +27,6 @@
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
-import org.agrona.collections.IntArrayList;
import org.agrona.collections.LongArrayList;
import org.apache.cassandra.index.sai.disk.PostingList;
import org.apache.cassandra.index.sai.disk.format.IndexComponentType;
@@ -42,6 +41,7 @@
import static com.google.common.base.Preconditions.checkArgument;
import static java.lang.Math.max;
+import static java.lang.Math.min;
/**
@@ -65,7 +65,7 @@
*
* Each posting list ends with a meta section and a skip table, that are written right after all postings blocks. Skip
* interval is the same as block size, and each skip entry points to the end of each block. Skip table consist of
- * block offsets and last values of each block, compressed as two FoR blocks.
+ * block offsets and maximum rowids of each block, compressed as two FoR blocks.
*
*
* Visual representation of the disk format:
@@ -80,7 +80,7 @@
* | LIST SIZE | SKIP TABLE |
* +---------------+------------+
* | BLOCKS POS.|
- * | MAX VALUES |
+ * | MAX ROWIDS |
* +------------+
*
*
@@ -91,13 +91,14 @@ public class PostingsWriter implements Closeable
protected static final Logger logger = LoggerFactory.getLogger(PostingsWriter.class);
// import static org.apache.lucene.codecs.lucene50.Lucene50PostingsFormat.BLOCK_SIZE;
- private final static int BLOCK_SIZE = 128;
+ private final static int BLOCK_ENTRIES = 128;
private static final String POSTINGS_MUST_BE_SORTED_ERROR_MSG = "Postings must be sorted ascending, got [%s] after [%s]";
private final IndexOutput dataOutput;
- private final int blockSize;
+ private final int blockEntries;
private final long[] deltaBuffer;
+ private final int[] freqBuffer; // frequency is capped at 255
private final LongArrayList blockOffsets = new LongArrayList();
private final LongArrayList blockMaxIDs = new LongArrayList();
private final ResettableByteBuffersIndexOutput inMemoryOutput;
@@ -106,35 +107,43 @@ public class PostingsWriter implements Closeable
private int bufferUpto;
private long lastSegmentRowId;
- private long maxDelta;
// This number is the count of row ids written to the postings for this segment. Because a segment row id can be in
// multiple postings list for the segment, this number could exceed Integer.MAX_VALUE, so we use a long.
private long totalPostings;
+ private final boolean writeFrequencies;
public PostingsWriter(IndexComponents.ForWrite components) throws IOException
{
- this(components, BLOCK_SIZE);
+ this(components, BLOCK_ENTRIES);
}
+ public PostingsWriter(IndexComponents.ForWrite components, boolean writeFrequencies) throws IOException
+ {
+ this(components.addOrGet(IndexComponentType.POSTING_LISTS).openOutput(true), BLOCK_ENTRIES, writeFrequencies);
+ }
+
public PostingsWriter(IndexOutput dataOutput) throws IOException
{
- this(dataOutput, BLOCK_SIZE);
+ this(dataOutput, BLOCK_ENTRIES, false);
}
@VisibleForTesting
- PostingsWriter(IndexComponents.ForWrite components, int blockSize) throws IOException
+ PostingsWriter(IndexComponents.ForWrite components, int blockEntries) throws IOException
{
- this(components.addOrGet(IndexComponentType.POSTING_LISTS).openOutput(true), blockSize);
+ this(components.addOrGet(IndexComponentType.POSTING_LISTS).openOutput(true), blockEntries, false);
}
- private PostingsWriter(IndexOutput dataOutput, int blockSize) throws IOException
+ private PostingsWriter(IndexOutput dataOutput, int blockEntries, boolean writeFrequencies) throws IOException
{
assert dataOutput instanceof IndexOutputWriter;
logger.debug("Creating postings writer for output {}", dataOutput);
- this.blockSize = blockSize;
+ this.writeFrequencies = writeFrequencies;
+ this.blockEntries = blockEntries;
this.dataOutput = dataOutput;
startOffset = dataOutput.getFilePointer();
- deltaBuffer = new long[blockSize];
+ deltaBuffer = new long[blockEntries];
+ freqBuffer = new int[blockEntries];
inMemoryOutput = LuceneCompat.getResettableByteBuffersIndexOutput(dataOutput.order(), 1024, "blockOffsets");
SAICodecUtils.writeHeader(dataOutput);
}
@@ -191,7 +200,7 @@ public long write(PostingList postings) throws IOException
int size = 0;
while ((segmentRowId = postings.nextPosting()) != PostingList.END_OF_STREAM)
{
- writePosting(segmentRowId);
+ writePosting(segmentRowId, postings.frequency());
size++;
totalPostings++;
}
@@ -210,19 +219,19 @@ public long getTotalPostings()
return totalPostings;
}
- private void writePosting(long segmentRowId) throws IOException
- {
+ private void writePosting(long segmentRowId, int freq) throws IOException
+ {
if (!(segmentRowId >= lastSegmentRowId || lastSegmentRowId == 0))
throw new IllegalArgumentException(String.format(POSTINGS_MUST_BE_SORTED_ERROR_MSG, segmentRowId, lastSegmentRowId));
+ assert !writeFrequencies || freq > 0 : "frequency must be positive when writing frequencies";
final long delta = segmentRowId - lastSegmentRowId;
- maxDelta = max(maxDelta, delta);
- deltaBuffer[bufferUpto++] = delta;
+ deltaBuffer[bufferUpto] = delta;
+ freqBuffer[bufferUpto] = min(freq, 255);
+ bufferUpto++;
- if (bufferUpto == blockSize)
- {
+ if (bufferUpto == blockEntries)
+ {
addBlockToSkipTable(segmentRowId);
- writePostingsBlock(maxDelta, bufferUpto);
+ writePostingsBlock(bufferUpto);
resetBlockCounters();
}
lastSegmentRowId = segmentRowId;
@@ -234,7 +243,7 @@ private void finish() throws IOException
{
addBlockToSkipTable(lastSegmentRowId);
- writePostingsBlock(maxDelta, bufferUpto);
+ writePostingsBlock(bufferUpto);
}
}
@@ -242,7 +251,6 @@ private void resetBlockCounters()
{
bufferUpto = 0;
lastSegmentRowId = 0;
- maxDelta = 0;
}
private void addBlockToSkipTable(long maxSegmentRowID)
@@ -253,7 +261,7 @@ private void addBlockToSkipTable(long maxSegmentRowID)
private void writeSummary(int exactSize) throws IOException
{
- dataOutput.writeVInt(blockSize);
+ dataOutput.writeVInt(blockEntries);
dataOutput.writeVInt(exactSize);
writeSkipTable();
}
@@ -272,19 +280,29 @@ private void writeSkipTable() throws IOException
writeSortedFoRBlock(blockMaxIDs, dataOutput);
}
- private void writePostingsBlock(long maxValue, int blockSize) throws IOException
- {
+ private void writePostingsBlock(int entries) throws IOException
+ {
+ // Find max value to determine bits needed
+ long maxValue = 0;
+ for (int i = 0; i < entries; i++)
+ {
+ maxValue = max(maxValue, deltaBuffer[i]);
+ if (writeFrequencies)
+ maxValue = max(maxValue, freqBuffer[i]);
+ }
+
+ // Use the maximum bits needed for either value type
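+ // Since frequencies are capped at 255 they can force at most 8 bits per value,
+ // even when the row-id deltas alone would have needed fewer.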
final int bitsPerValue = maxValue == 0 ? 0 : LuceneCompat.directWriterUnsignedBitsRequired(dataOutput.order(), maxValue);
-
- assert bitsPerValue < Byte.MAX_VALUE;
-
+
dataOutput.writeByte((byte) bitsPerValue);
- if (bitsPerValue > 0)
- {
- final DirectWriterAdapter writer = LuceneCompat.directWriterGetInstance(dataOutput.order(), dataOutput, blockSize, bitsPerValue);
- for (int i = 0; i < blockSize; ++i)
- {
+ if (bitsPerValue > 0)
+ {
+ // Write interleaved [delta][freq] pairs
+ final DirectWriterAdapter writer = LuceneCompat.directWriterGetInstance(dataOutput.order(),
+ dataOutput,
+ writeFrequencies ? entries * 2L : entries,
+ bitsPerValue);
+ for (int i = 0; i < entries; ++i)
+ {
writer.add(deltaBuffer[i]);
+ if (writeFrequencies)
+ writer.add(freqBuffer[i]);
}
writer.finish();
}
@@ -292,9 +310,9 @@ private void writePostingsBlock(long maxValue, int blockSize) throws IOException
private void writeSortedFoRBlock(LongArrayList values, IndexOutput output) throws IOException
{
+ assert !values.isEmpty();
final long maxValue = values.getLong(values.size() - 1);
- assert values.size() > 0;
final int bitsPerValue = maxValue == 0 ? 0 : LuceneCompat.directWriterUnsignedBitsRequired(output.order(), maxValue);
output.writeByte((byte) bitsPerValue);
if (bitsPerValue > 0)
@@ -307,22 +325,4 @@ private void writeSortedFoRBlock(LongArrayList values, IndexOutput output) throw
writer.finish();
}
}
-
- private void writeSortedFoRBlock(IntArrayList values, IndexOutput output) throws IOException
- {
- final int maxValue = values.getInt(values.size() - 1);
-
- assert values.size() > 0;
- final int bitsPerValue = maxValue == 0 ? 0 : LuceneCompat.directWriterUnsignedBitsRequired(output.order(), maxValue);
- output.writeByte((byte) bitsPerValue);
- if (bitsPerValue > 0)
- {
- final DirectWriterAdapter writer = LuceneCompat.directWriterGetInstance(output.order(), output, values.size(), bitsPerValue);
- for (int i = 0; i < values.size(); ++i)
- {
- writer.add(values.getInt(i));
- }
- writer.finish();
- }
- }
}
diff --git a/src/java/org/apache/cassandra/index/sai/disk/v1/postings/VectorPostingList.java b/src/java/org/apache/cassandra/index/sai/disk/v1/postings/ReorderingPostingList.java
similarity index 83%
rename from src/java/org/apache/cassandra/index/sai/disk/v1/postings/VectorPostingList.java
rename to src/java/org/apache/cassandra/index/sai/disk/v1/postings/ReorderingPostingList.java
index 52155bf6ed59..f43c7c4e8dce 100644
--- a/src/java/org/apache/cassandra/index/sai/disk/v1/postings/VectorPostingList.java
+++ b/src/java/org/apache/cassandra/index/sai/disk/v1/postings/ReorderingPostingList.java
@@ -19,31 +19,29 @@
package org.apache.cassandra.index.sai.disk.v1.postings;
import java.io.IOException;
+import java.util.function.ToIntFunction;
import org.apache.cassandra.index.sai.disk.PostingList;
-import org.apache.cassandra.index.sai.utils.RowIdWithMeta;
import org.apache.cassandra.utils.CloseableIterator;
import org.apache.lucene.util.LongHeap;
/**
- * A posting list for ANN search results. Transforms results from similarity order to rowId order.
+ * A posting list that transforms results from their source order (e.g. similarity order for ANN)
+ * into rowId order.
*/
-public class VectorPostingList implements PostingList
+public class ReorderingPostingList implements PostingList
{
private final LongHeap segmentRowIds;
private final int size;
- public VectorPostingList(CloseableIterator<? extends RowIdWithMeta> source)
+ public <T> ReorderingPostingList(CloseableIterator<T> source, ToIntFunction<T> rowIdTransformer)
{
- // TODO find int specific data structure?
segmentRowIds = new LongHeap(32);
int n = 0;
- // Once the source is consumed, we have to close it.
try (source)
{
while (source.hasNext())
{
- segmentRowIds.push(source.next().getSegmentRowId());
+ segmentRowIds.push(rowIdTransformer.applyAsInt(source.next()));
n++;
}
}
diff --git a/src/java/org/apache/cassandra/index/sai/disk/v1/postings/ScanningPostingsReader.java b/src/java/org/apache/cassandra/index/sai/disk/v1/postings/ScanningPostingsReader.java
index 2fd7c2336ab2..5c9da6169e08 100644
--- a/src/java/org/apache/cassandra/index/sai/disk/v1/postings/ScanningPostingsReader.java
+++ b/src/java/org/apache/cassandra/index/sai/disk/v1/postings/ScanningPostingsReader.java
@@ -31,9 +31,9 @@
*/
public class ScanningPostingsReader extends PostingsReader
{
- public ScanningPostingsReader(IndexInput input, BlocksSummary summary) throws IOException
+ public ScanningPostingsReader(IndexInput input, BlocksSummary summary, boolean readFrequencies) throws IOException
{
- super(input, summary, QueryEventListener.PostingListEventListener.NO_OP, InputCloser.NOOP);
+ super(input, summary, readFrequencies, QueryEventListener.PostingListEventListener.NO_OP, InputCloser.NOOP);
}
@Override
diff --git a/src/java/org/apache/cassandra/index/sai/disk/v1/trie/DocLengthsWriter.java b/src/java/org/apache/cassandra/index/sai/disk/v1/trie/DocLengthsWriter.java
new file mode 100644
index 000000000000..90dc6b5c6ff9
--- /dev/null
+++ b/src/java/org/apache/cassandra/index/sai/disk/v1/trie/DocLengthsWriter.java
@@ -0,0 +1,73 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.cassandra.index.sai.disk.v1.trie;
+
+import java.io.Closeable;
+import java.io.IOException;
+
+import org.agrona.collections.Int2IntHashMap;
+import org.apache.cassandra.index.sai.disk.format.IndexComponents;
+import org.apache.cassandra.index.sai.disk.format.IndexComponentType;
+import org.apache.cassandra.index.sai.disk.io.IndexOutputWriter;
+import org.apache.cassandra.index.sai.utils.SAICodecUtils;
+
+/**
+ * Writes document length information to disk for use in text scoring
+ */
+public class DocLengthsWriter implements Closeable
+{
+ private final IndexOutputWriter output;
+
+ public DocLengthsWriter(IndexComponents.ForWrite components) throws IOException
+ {
+ this.output = components.addOrGet(IndexComponentType.DOC_LENGTHS).openOutput(true);
+ SAICodecUtils.writeHeader(output);
+ }
+
+ public void writeDocLengths(Int2IntHashMap lengths) throws IOException
+ {
+ // Calculate max row ID from doc lengths map
+ int maxRowId = -1;
+ for (var keyIterator = lengths.keySet().iterator(); keyIterator.hasNext(); )
+ {
+ int key = keyIterator.nextValue();
+ if (key > maxRowId)
+ maxRowId = key;
+ }
+
+ // write out the doc lengths in row order
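+ // (a dense array of 32-bit ints indexed by segment row id, with 0 for absent rows,
+ // so a reader can compute a row's offset as rowId * Integer.BYTES)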
+ for (int rowId = 0; rowId <= maxRowId; rowId++)
+ {
+ final int length = lengths.get(rowId);
+ output.writeInt(length == lengths.missingValue() ? 0 : length);
+ }
+
+ SAICodecUtils.writeFooter(output);
+ }
+
+ public long getFilePointer()
+ {
+ return output.getFilePointer();
+ }
+
+ @Override
+ public void close() throws IOException
+ {
+ output.close();
+ }
+}
diff --git a/src/java/org/apache/cassandra/index/sai/disk/v1/trie/InvertedIndexWriter.java b/src/java/org/apache/cassandra/index/sai/disk/v1/trie/InvertedIndexWriter.java
index e6adb61ce863..a61df65f2239 100644
--- a/src/java/org/apache/cassandra/index/sai/disk/v1/trie/InvertedIndexWriter.java
+++ b/src/java/org/apache/cassandra/index/sai/disk/v1/trie/InvertedIndexWriter.java
@@ -25,10 +25,12 @@
import org.apache.commons.lang3.mutable.MutableLong;
+import org.agrona.collections.Int2IntHashMap;
import org.apache.cassandra.index.sai.disk.PostingList;
import org.apache.cassandra.index.sai.disk.TermsIterator;
import org.apache.cassandra.index.sai.disk.format.IndexComponents;
import org.apache.cassandra.index.sai.disk.format.IndexComponentType;
+import org.apache.cassandra.index.sai.disk.format.Version;
import org.apache.cassandra.index.sai.disk.v1.SegmentMetadata;
import org.apache.cassandra.index.sai.disk.v1.postings.PostingsWriter;
import org.apache.cassandra.index.sai.utils.SAICodecUtils;
@@ -42,12 +44,19 @@ public class InvertedIndexWriter implements Closeable
{
private final TrieTermsDictionaryWriter termsDictionaryWriter;
private final PostingsWriter postingsWriter;
+ private final DocLengthsWriter docLengthsWriter;
private long postingsAdded;
public InvertedIndexWriter(IndexComponents.ForWrite components) throws IOException
+ {
+ this(components, false);
+ }
+
+ public InvertedIndexWriter(IndexComponents.ForWrite components, boolean writeFrequencies) throws IOException
{
this.termsDictionaryWriter = new TrieTermsDictionaryWriter(components);
- this.postingsWriter = new PostingsWriter(components);
+ this.postingsWriter = new PostingsWriter(components, writeFrequencies);
+ this.docLengthsWriter = Version.latest().onOrAfter(Version.EC) ? new DocLengthsWriter(components) : null;
}
/**
@@ -58,7 +67,7 @@ public InvertedIndexWriter(IndexComponents.ForWrite components) throws IOExcepti
* @return metadata describing the location of this inverted index in the overall SSTable
* terms and postings component files
*/
- public SegmentMetadata.ComponentMetadataMap writeAll(TermsIterator terms) throws IOException
+ public SegmentMetadata.ComponentMetadataMap writeAll(TermsIterator terms, Int2IntHashMap docLengths) throws IOException
{
// Terms and postings writers are opened in append mode with pointers at the end of their respective files.
long termsOffset = termsDictionaryWriter.getStartOffset();
@@ -91,6 +100,15 @@ public SegmentMetadata.ComponentMetadataMap writeAll(TermsIterator terms) throws
components.put(IndexComponentType.POSTING_LISTS, -1, postingsOffset, postingsLength);
components.put(IndexComponentType.TERMS_DATA, termsRoot, termsOffset, termsLength, map);
+ // Write doc lengths
+ if (docLengthsWriter != null)
+ {
+ long docLengthsOffset = docLengthsWriter.getFilePointer();
+ docLengthsWriter.writeDocLengths(docLengths);
+ long docLengthsLength = docLengthsWriter.getFilePointer() - docLengthsOffset;
+ components.put(IndexComponentType.DOC_LENGTHS, -1, docLengthsOffset, docLengthsLength);
+ }
+
return components;
}
@@ -99,6 +117,8 @@ public void close() throws IOException
{
postingsWriter.close();
termsDictionaryWriter.close();
+ if (docLengthsWriter != null)
+ docLengthsWriter.close();
}
/**
diff --git a/src/java/org/apache/cassandra/index/sai/disk/v2/RowAwarePrimaryKeyFactory.java b/src/java/org/apache/cassandra/index/sai/disk/v2/RowAwarePrimaryKeyFactory.java
index e51a6229de58..cba1e1124ce3 100644
--- a/src/java/org/apache/cassandra/index/sai/disk/v2/RowAwarePrimaryKeyFactory.java
+++ b/src/java/org/apache/cassandra/index/sai/disk/v2/RowAwarePrimaryKeyFactory.java
@@ -23,6 +23,7 @@
import java.util.function.Supplier;
import java.util.stream.Collectors;
+import io.github.jbellis.jvector.util.RamUsageEstimator;
import org.apache.cassandra.db.Clustering;
import org.apache.cassandra.db.ClusteringComparator;
import org.apache.cassandra.db.DecoratedKey;
@@ -212,5 +213,22 @@ public String toString()
.map(ByteBufferUtil::bytesToHex)
.collect(Collectors.toList())));
}
+
+ @Override
+ public long ramBytesUsed()
+ {
+ // Object header + 5 references: token, partitionKey, clustering, primaryKeySupplier, plus the implicit outer reference
+ long size = RamUsageEstimator.NUM_BYTES_OBJECT_HEADER +
+ 5L * RamUsageEstimator.NUM_BYTES_OBJECT_REF;
+
+ if (token != null)
+ size += token.getHeapSize();
+ if (partitionKey != null)
+ size += RamUsageEstimator.NUM_BYTES_OBJECT_HEADER +
+ 2L * RamUsageEstimator.NUM_BYTES_OBJECT_REF + // token and key references
+ 2L * Long.BYTES;
+ // We don't count clustering size here as it's managed elsewhere
+ return size;
+ }
}
}
diff --git a/src/java/org/apache/cassandra/index/sai/disk/v2/V2VectorIndexSearcher.java b/src/java/org/apache/cassandra/index/sai/disk/v2/V2VectorIndexSearcher.java
index 80870dab36bf..7da1097e64ae 100644
--- a/src/java/org/apache/cassandra/index/sai/disk/v2/V2VectorIndexSearcher.java
+++ b/src/java/org/apache/cassandra/index/sai/disk/v2/V2VectorIndexSearcher.java
@@ -49,7 +49,7 @@
import org.apache.cassandra.index.sai.disk.v1.IndexSearcher;
import org.apache.cassandra.index.sai.disk.v1.PerIndexFiles;
import org.apache.cassandra.index.sai.disk.v1.SegmentMetadata;
-import org.apache.cassandra.index.sai.disk.v1.postings.VectorPostingList;
+import org.apache.cassandra.index.sai.disk.v1.postings.ReorderingPostingList;
import org.apache.cassandra.index.sai.disk.v5.V5VectorPostingsWriter;
import org.apache.cassandra.index.sai.disk.vector.BruteForceRowIdIterator;
import org.apache.cassandra.index.sai.disk.vector.CassandraDiskAnn;
@@ -64,8 +64,8 @@
import org.apache.cassandra.index.sai.utils.PrimaryKey;
import org.apache.cassandra.index.sai.utils.PrimaryKeyWithSortKey;
import org.apache.cassandra.index.sai.utils.RangeUtil;
+import org.apache.cassandra.index.sai.utils.RowIdWithMeta;
import org.apache.cassandra.index.sai.utils.RowIdWithScore;
-import org.apache.cassandra.index.sai.utils.SegmentOrdering;
import org.apache.cassandra.io.sstable.format.SSTableReader;
import org.apache.cassandra.metrics.LinearFit;
import org.apache.cassandra.metrics.PairedSlidingWindowReservoir;
@@ -81,7 +81,7 @@
/**
* Executes ann search against the graph for an individual index segment.
*/
-public class V2VectorIndexSearcher extends IndexSearcher implements SegmentOrdering
+public class V2VectorIndexSearcher extends IndexSearcher
{
private static final Logger logger = LoggerFactory.getLogger(MethodHandles.lookup().lookupClass());
private static final VectorTypeSupport vts = VectorizationProvider.getInstance().getVectorTypeSupport();
@@ -133,13 +133,13 @@ public ProductQuantization getPQ()
}
@Override
- public KeyRangeIterator search(Expression exp, AbstractBounds<PartitionPosition> keyRange, QueryContext context, boolean defer, int limit) throws IOException
+ public KeyRangeIterator search(Expression exp, AbstractBounds<PartitionPosition> keyRange, QueryContext context, boolean defer) throws IOException
{
- PostingList results = searchPosting(context, exp, keyRange, limit);
+ PostingList results = searchPosting(context, exp, keyRange);
return toPrimaryKeyIterator(results, context);
}
- private PostingList searchPosting(QueryContext context, Expression exp, AbstractBounds<PartitionPosition> keyRange, int limit) throws IOException
+ private PostingList searchPosting(QueryContext context, Expression exp, AbstractBounds<PartitionPosition> keyRange) throws IOException
{
if (logger.isTraceEnabled())
logger.trace(indexContext.logMessage("Searching on expression '{}'..."), exp);
@@ -151,7 +151,7 @@ private PostingList searchPosting(QueryContext context, Expression exp, Abstract
// this is a thresholded query, so pass graph.size() as top k to get all results satisfying the threshold
var result = searchInternal(keyRange, context, queryVector, graph.size(), graph.size(), exp.getEuclideanSearchThreshold());
- return new VectorPostingList(result);
+ return new ReorderingPostingList(result, RowIdWithMeta::getSegmentRowId);
}
@Override
@@ -160,11 +160,11 @@ public CloseableIterator orderBy(Orderer orderer, Express
if (logger.isTraceEnabled())
logger.trace(indexContext.logMessage("Searching on expression '{}'..."), orderer);
- if (orderer.vector == null)
+ if (!orderer.isANN())
throw new IllegalArgumentException(indexContext.logMessage("Unsupported expression during ANN index query: " + orderer));
int rerankK = indexContext.getIndexWriterConfig().getSourceModel().rerankKFor(limit, graph.getCompression());
- var queryVector = vts.createFloatVector(orderer.vector);
+ var queryVector = vts.createFloatVector(orderer.getVectorTerm());
var result = searchInternal(keyRange, context, queryVector, limit, rerankK, 0);
return toMetaSortedIterator(result, context);
@@ -485,14 +485,14 @@ public CloseableIterator orderResultsBy(SSTableReader rea
if (cost.shouldUseBruteForce())
{
// brute force using the in-memory compressed vectors to cut down the number of results returned
- var queryVector = vts.createFloatVector(orderer.vector);
+ var queryVector = vts.createFloatVector(orderer.getVectorTerm());
return toMetaSortedIterator(this.orderByBruteForce(queryVector, segmentOrdinalPairs, limit, rerankK), context);
}
// Create bits from the mapping
var bits = bitSetForSearch();
segmentOrdinalPairs.forEachRightInt(bits::set);
// else ask the index to perform a search limited to the bits we created
- var queryVector = vts.createFloatVector(orderer.vector);
+ var queryVector = vts.createFloatVector(orderer.getVectorTerm());
var results = graph.search(queryVector, limit, rerankK, 0, bits, context, cost::updateStatistics);
return toMetaSortedIterator(results, context);
}
diff --git a/src/java/org/apache/cassandra/index/sai/disk/v4/V4InvertedIndexSearcher.java b/src/java/org/apache/cassandra/index/sai/disk/v4/V4InvertedIndexSearcher.java
index ac11b238844e..f819b6e1ee6f 100644
--- a/src/java/org/apache/cassandra/index/sai/disk/v4/V4InvertedIndexSearcher.java
+++ b/src/java/org/apache/cassandra/index/sai/disk/v4/V4InvertedIndexSearcher.java
@@ -37,6 +37,6 @@ class V4InvertedIndexSearcher extends InvertedIndexSearcher
SegmentMetadata segmentMetadata,
IndexContext indexContext) throws IOException
{
- super(sstableContext, perIndexFiles, segmentMetadata, indexContext, Version.DB, false);
+ super(sstableContext, perIndexFiles, segmentMetadata, indexContext, segmentMetadata.version, false);
}
}
diff --git a/src/java/org/apache/cassandra/index/sai/disk/v7/V7OnDiskFormat.java b/src/java/org/apache/cassandra/index/sai/disk/v7/V7OnDiskFormat.java
new file mode 100644
index 000000000000..5457685cf1b0
--- /dev/null
+++ b/src/java/org/apache/cassandra/index/sai/disk/v7/V7OnDiskFormat.java
@@ -0,0 +1,50 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.cassandra.index.sai.disk.v7;
+
+import java.util.EnumSet;
+import java.util.Set;
+
+import org.apache.cassandra.db.marshal.AbstractType;
+import org.apache.cassandra.index.sai.disk.format.IndexComponentType;
+import org.apache.cassandra.index.sai.disk.v1.V1OnDiskFormat;
+import org.apache.cassandra.index.sai.disk.v3.V3OnDiskFormat;
+import org.apache.cassandra.index.sai.disk.v6.V6OnDiskFormat;
+import org.apache.cassandra.index.sai.utils.TypeUtil;
+
+public class V7OnDiskFormat extends V6OnDiskFormat
+{
+ public static final V7OnDiskFormat instance = new V7OnDiskFormat();
+
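+ // V7 differs from V6 only in adding the DOC_LENGTHS component for literal (text) indexes,
+ // which frequency-based scoring such as BM25 uses for length normalization.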
+ private static final Set<IndexComponentType> LITERAL_COMPONENTS = EnumSet.of(IndexComponentType.COLUMN_COMPLETION_MARKER,
+ IndexComponentType.META,
+ IndexComponentType.TERMS_DATA,
+ IndexComponentType.POSTING_LISTS,
+ IndexComponentType.DOC_LENGTHS);
+
+ @Override
+ public Set<IndexComponentType> perIndexComponentTypes(AbstractType<?> validator)
+ {
+ if (validator.isVector())
+ return V3OnDiskFormat.VECTOR_COMPONENTS_V3;
+ if (TypeUtil.isLiteral(validator))
+ return LITERAL_COMPONENTS;
+ return V1OnDiskFormat.NUMERIC_COMPONENTS;
+ }
+}
diff --git a/src/java/org/apache/cassandra/index/sai/disk/vector/VectorMemtableIndex.java b/src/java/org/apache/cassandra/index/sai/disk/vector/VectorMemtableIndex.java
index 25231748149f..a6a52c462f3b 100644
--- a/src/java/org/apache/cassandra/index/sai/disk/vector/VectorMemtableIndex.java
+++ b/src/java/org/apache/cassandra/index/sai/disk/vector/VectorMemtableIndex.java
@@ -54,6 +54,7 @@
import org.apache.cassandra.index.sai.disk.format.IndexComponents;
import org.apache.cassandra.index.sai.disk.v1.SegmentMetadata;
import org.apache.cassandra.index.sai.iterators.KeyRangeIterator;
+import org.apache.cassandra.index.sai.memory.MemoryIndex;
import org.apache.cassandra.index.sai.memory.MemtableIndex;
import org.apache.cassandra.index.sai.plan.Expression;
import org.apache.cassandra.index.sai.plan.Orderer;
@@ -222,7 +223,7 @@ public List> orderBy(QueryContext conte
assert slice == null : "ANN does not support index slicing";
assert orderer.isANN() : "Only ANN is supported for vector search, received " + orderer.operator;
- var qv = vts.createFloatVector(orderer.vector);
+ var qv = vts.createFloatVector(orderer.getVectorTerm());
return List.of(searchInternal(context, qv, keyRange, limit, 0));
}
@@ -310,7 +311,7 @@ public CloseableIterator orderResultsBy(QueryContext cont
relevantOrdinals.size(), keys.size(), maxBruteForceRows, graph.size(), limit);
// convert the expression value to query vector
- var qv = vts.createFloatVector(orderer.vector);
+ var qv = vts.createFloatVector(orderer.getVectorTerm());
// brute force path
if (keysInGraph.size() <= maxBruteForceRows)
{
@@ -422,7 +423,7 @@ public static int ensureSaneEstimate(int rawEstimate, int rerankK, int graphSize
}
@Override
- public Iterator<Pair<ByteComparable, Iterator<PrimaryKey>>> iterator(DecoratedKey min, DecoratedKey max)
+ public Iterator<Pair<ByteComparable, List<MemoryIndex.PkWithFrequency>>> iterator(DecoratedKey min, DecoratedKey max)
{
// This method is only used when merging an in-memory index with a RowMapping. This is done a different
// way with the graph using the writeData method below.
diff --git a/src/java/org/apache/cassandra/index/sai/memory/MemoryIndex.java b/src/java/org/apache/cassandra/index/sai/memory/MemoryIndex.java
index 83c43e212d78..b50a58805430 100644
--- a/src/java/org/apache/cassandra/index/sai/memory/MemoryIndex.java
+++ b/src/java/org/apache/cassandra/index/sai/memory/MemoryIndex.java
@@ -20,6 +20,7 @@
import java.nio.ByteBuffer;
import java.util.Iterator;
+import java.util.List;
import java.util.function.LongConsumer;
import org.apache.cassandra.db.Clustering;
@@ -30,8 +31,8 @@
import org.apache.cassandra.index.sai.iterators.KeyRangeIterator;
import org.apache.cassandra.index.sai.plan.Expression;
import org.apache.cassandra.index.sai.plan.Orderer;
+import org.apache.cassandra.index.sai.utils.PrimaryKey;
import org.apache.cassandra.index.sai.utils.PrimaryKeyWithSortKey;
-import org.apache.cassandra.index.sai.utils.PrimaryKeys;
import org.apache.cassandra.utils.CloseableIterator;
import org.apache.cassandra.utils.Pair;
import org.apache.cassandra.utils.bytecomparable.ByteComparable;
@@ -64,5 +65,17 @@ public abstract void add(DecoratedKey key,
/**
* Iterate all Term->PrimaryKeys mappings in sorted order
*/
- public abstract Iterator<Pair<ByteComparable, PrimaryKeys>> iterator();
+ public abstract Iterator<Pair<ByteComparable, List<PkWithFrequency>>> iterator();
+
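+ /** A primary key paired with the number of times the indexed term occurs in that row. */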
+ public static class PkWithFrequency
+ {
+ public final PrimaryKey pk;
+ public final int frequency;
+
+ public PkWithFrequency(PrimaryKey pk, int frequency)
+ {
+ this.pk = pk;
+ this.frequency = frequency;
+ }
+ }
}
diff --git a/src/java/org/apache/cassandra/index/sai/memory/MemtableIndex.java b/src/java/org/apache/cassandra/index/sai/memory/MemtableIndex.java
index 22bdc284bf3f..f340b06a5e19 100644
--- a/src/java/org/apache/cassandra/index/sai/memory/MemtableIndex.java
+++ b/src/java/org/apache/cassandra/index/sai/memory/MemtableIndex.java
@@ -20,6 +20,7 @@
import java.nio.ByteBuffer;
import java.util.Iterator;
+import java.util.List;
import javax.annotation.Nullable;
import org.apache.cassandra.db.Clustering;
@@ -32,7 +33,6 @@
import org.apache.cassandra.index.sai.disk.vector.VectorMemtableIndex;
import org.apache.cassandra.index.sai.iterators.KeyRangeIterator;
import org.apache.cassandra.index.sai.plan.Expression;
-import org.apache.cassandra.index.sai.utils.PrimaryKey;
import org.apache.cassandra.index.sai.utils.MemtableOrdering;
import org.apache.cassandra.utils.Pair;
import org.apache.cassandra.utils.bytecomparable.ByteComparable;
@@ -79,7 +79,7 @@ default void update(DecoratedKey key, Clustering clustering, ByteBuffer oldValue
long estimateMatchingRowsCount(Expression expression, AbstractBounds<PartitionPosition> keyRange);
- Iterator<Pair<ByteComparable, Iterator<PrimaryKey>>> iterator(DecoratedKey min, DecoratedKey max);
+ Iterator<Pair<ByteComparable, List<MemoryIndex.PkWithFrequency>>> iterator(DecoratedKey min, DecoratedKey max);
static MemtableIndex createIndex(IndexContext indexContext, Memtable mt)
{
diff --git a/src/java/org/apache/cassandra/index/sai/memory/RowMapping.java b/src/java/org/apache/cassandra/index/sai/memory/RowMapping.java
index a435a6e47169..e0cdc610a63c 100644
--- a/src/java/org/apache/cassandra/index/sai/memory/RowMapping.java
+++ b/src/java/org/apache/cassandra/index/sai/memory/RowMapping.java
@@ -17,10 +17,11 @@
*/
package org.apache.cassandra.index.sai.memory;
+import java.util.ArrayList;
import java.util.Collections;
import java.util.Iterator;
+import java.util.List;
-import com.carrotsearch.hppc.IntArrayList;
import org.apache.cassandra.db.compaction.OperationType;
import org.apache.cassandra.db.rows.RangeTombstoneMarker;
import org.apache.cassandra.db.rows.Row;
@@ -44,7 +45,7 @@ public class RowMapping
public static final RowMapping DUMMY = new RowMapping()
{
@Override
- public Iterator<Pair<ByteComparable, IntArrayList>> merge(MemtableIndex index) { return Collections.emptyIterator(); }
+ public Iterator<Pair<ByteComparable, List<RowIdWithFrequency>>> merge(MemtableIndex index) { return Collections.emptyIterator(); }
@Override
public void complete() {}
@@ -89,6 +90,16 @@ public static RowMapping create(OperationType opType)
return DUMMY;
}
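+ /** A segment row id paired with the term's per-row frequency, carried through the memtable flush path. */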
+ public static class RowIdWithFrequency
+ {
+ public final int rowId;
+ public final int frequency;
+
+ public RowIdWithFrequency(int rowId, int frequency)
+ {
+ this.rowId = rowId;
+ this.frequency = frequency;
+ }
+ }
+
/**
* Merge IndexMemtable(index term to PrimaryKeys mappings) with row mapping of a sstable
* (PrimaryKey to RowId mappings).
@@ -97,33 +108,32 @@ public static RowMapping create(OperationType opType)
*
* @return iterator of index term to postings mapping exists in the sstable
*/
- public Iterator<Pair<ByteComparable, IntArrayList>> merge(MemtableIndex index)
+ public Iterator<Pair<ByteComparable, List<RowIdWithFrequency>>> merge(MemtableIndex index)
{
assert complete : "RowMapping is not built.";
- Iterator<Pair<ByteComparable, Iterator<PrimaryKey>>> iterator = index.iterator(minKey.partitionKey(), maxKey.partitionKey());
- return new AbstractGuavaIterator<Pair<ByteComparable, IntArrayList>>()
+ var it = index.iterator(minKey.partitionKey(), maxKey.partitionKey());
+ return new AbstractGuavaIterator<>()
{
@Override
- protected Pair<ByteComparable, IntArrayList> computeNext()
+ protected Pair<ByteComparable, List<RowIdWithFrequency>> computeNext()
{
- while (iterator.hasNext())
+ while (it.hasNext())
{
- Pair<ByteComparable, Iterator<PrimaryKey>> pair = iterator.next();
+ var pair = it.next();
- IntArrayList postings = null;
- Iterator<PrimaryKey> primaryKeys = pair.right;
+ List<RowIdWithFrequency> postings = null;
+ var primaryKeysWithFreq = pair.right;
- while (primaryKeys.hasNext())
+ for (var pkWithFreq : primaryKeysWithFreq)
{
- PrimaryKey primaryKey = primaryKeys.next();
- ByteComparable byteComparable = v -> primaryKey.asComparableBytes(v);
+ ByteComparable byteComparable = pkWithFreq.pk::asComparableBytes;
Integer segmentRowId = rowMapping.get(byteComparable);
if (segmentRowId != null)
{
- postings = postings == null ? new IntArrayList() : postings;
- postings.add(segmentRowId);
+ postings = postings == null ? new ArrayList<>() : postings;
+ postings.add(new RowIdWithFrequency(segmentRowId, pkWithFreq.frequency));
}
}
if (postings != null && !postings.isEmpty())
diff --git a/src/java/org/apache/cassandra/index/sai/memory/TrieMemoryIndex.java b/src/java/org/apache/cassandra/index/sai/memory/TrieMemoryIndex.java
index 57517d9192d0..593f11acccca 100644
--- a/src/java/org/apache/cassandra/index/sai/memory/TrieMemoryIndex.java
+++ b/src/java/org/apache/cassandra/index/sai/memory/TrieMemoryIndex.java
@@ -28,13 +28,15 @@
import java.math.BigDecimal;
import java.nio.ByteBuffer;
import java.util.ArrayList;
-
import java.util.Collection;
-
+import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
+import java.util.Objects;
import java.util.SortedSet;
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.atomic.AtomicLong;
import java.util.concurrent.atomic.LongAdder;
import java.util.function.LongConsumer;
import javax.annotation.Nullable;
@@ -43,6 +45,8 @@
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
+import io.github.jbellis.jvector.util.Accountable;
+import io.github.jbellis.jvector.util.RamUsageEstimator;
import io.netty.util.concurrent.FastThreadLocal;
import org.apache.cassandra.db.Clustering;
import org.apache.cassandra.db.DecoratedKey;
@@ -57,6 +61,7 @@
import org.apache.cassandra.dht.AbstractBounds;
import org.apache.cassandra.index.sai.IndexContext;
import org.apache.cassandra.index.sai.analyzer.AbstractAnalyzer;
+import org.apache.cassandra.index.sai.analyzer.NoOpAnalyzer;
import org.apache.cassandra.index.sai.disk.format.Version;
import org.apache.cassandra.index.sai.disk.v6.TermsDistribution;
import org.apache.cassandra.index.sai.iterators.KeyRangeIterator;
@@ -83,6 +88,8 @@ public class TrieMemoryIndex extends MemoryIndex
private final InMemoryTrie<PrimaryKeys> data;
private final PrimaryKeysReducer primaryKeysReducer;
+ private final Map<PkWithTerm, Integer> termFrequencies;
+ private final Map<PrimaryKey, Integer> docLengths = new HashMap<>();
private final Memtable memtable;
private AbstractBounds keyBounds;
@@ -111,6 +118,48 @@ public TrieMemoryIndex(IndexContext indexContext, Memtable memtable, AbstractBou
this.data = InMemoryTrie.longLived(TypeUtil.BYTE_COMPARABLE_VERSION, TrieMemtable.BUFFER_TYPE, indexContext.columnFamilyStore().readOrdering());
this.primaryKeysReducer = new PrimaryKeysReducer();
this.memtable = memtable;
+ termFrequencies = new ConcurrentHashMap<>();
+ }
+
+ public synchronized Map<PrimaryKey, Integer> getDocLengths()
+ {
+ return docLengths;
+ }
+
+ private static class PkWithTerm implements Accountable
+ {
+ private final PrimaryKey pk;
+ private final ByteComparable term;
+
+ private PkWithTerm(PrimaryKey pk, ByteComparable term)
+ {
+ this.pk = pk;
+ this.term = term;
+ }
+
+ @Override
+ public long ramBytesUsed()
+ {
+ return RamUsageEstimator.NUM_BYTES_OBJECT_HEADER +
+ 2L * RamUsageEstimator.NUM_BYTES_OBJECT_REF +
+ pk.ramBytesUsed() +
+ ByteComparable.length(term, TypeUtil.BYTE_COMPARABLE_VERSION);
+ }
+
+ @Override
+ public int hashCode()
+ {
+ return Objects.hash(pk, ByteComparable.length(term, TypeUtil.BYTE_COMPARABLE_VERSION));
+ }
+
+ @Override
+ public boolean equals(Object o)
+ {
+ if (o == null || getClass() != o.getClass()) return false;
+ PkWithTerm that = (PkWithTerm) o;
+ return Objects.equals(pk, that.pk)
+ && ByteComparable.compare(term, that.term, TypeUtil.BYTE_COMPARABLE_VERSION) == 0;
+ }
}
public synchronized void add(DecoratedKey key,
@@ -129,12 +178,33 @@ public synchronized void add(DecoratedKey key,
final long initialSizeOffHeap = data.usedSizeOffHeap();
final long reducerHeapSize = primaryKeysReducer.heapAllocations();
+ if (docLengths.containsKey(primaryKey) && !(analyzer instanceof NoOpAnalyzer))
+ {
+ AtomicLong heapReclaimed = new AtomicLong();
+ // we're overwriting an existing cell, clear out the old term counts
+ for (Map.Entry<ByteComparable, PrimaryKeys> entry : data.entrySet())
+ {
+ var termInTrie = entry.getKey();
+ entry.getValue().forEach(pkInTrie -> {
+ if (pkInTrie.equals(primaryKey))
+ {
+ var t = new PkWithTerm(pkInTrie, termInTrie);
+ termFrequencies.remove(t);
+ heapReclaimed.addAndGet(RamUsageEstimator.HASHTABLE_RAM_BYTES_PER_ENTRY + t.ramBytesUsed() + Integer.BYTES);
+ }
+ });
+ }
+ }
+
+ int tokenCount = 0;
while (analyzer.hasNext())
{
final ByteBuffer term = analyzer.next();
if (!indexContext.validateMaxTermSize(key, term))
continue;
+ tokenCount++;
+
// Note that this term is already encoded once by the TypeUtil.encode call above.
setMinMaxTerm(term.duplicate());
@@ -142,7 +212,25 @@ public synchronized void add(DecoratedKey key,
try
{
- data.putSingleton(encodedTerm, primaryKey, primaryKeysReducer, term.limit() <= MAX_RECURSIVE_KEY_LENGTH);
+ data.putSingleton(encodedTerm, primaryKey, (existing, update) -> {
+ // First do the normal primary keys reduction
+ PrimaryKeys result = primaryKeysReducer.apply(existing, update);
+ if (analyzer instanceof NoOpAnalyzer)
+ return result;
+
+ // Then update term frequency
+ var pkbc = new PkWithTerm(update, encodedTerm);
+ termFrequencies.compute(pkbc, (k, oldValue) -> {
+ if (oldValue == null) {
+ // New key added, track heap allocation
+ onHeapAllocationsTracker.accept(RamUsageEstimator.HASHTABLE_RAM_BYTES_PER_ENTRY + k.ramBytesUsed() + Integer.BYTES);
+ return 1;
+ }
+ return oldValue + 1;
+ });
+
+ return result;
+ }, term.limit() <= MAX_RECURSIVE_KEY_LENGTH);
}
catch (TrieSpaceExhaustedException e)
{
@@ -150,6 +238,14 @@ public synchronized void add(DecoratedKey key,
}
}
+ docLengths.put(primaryKey, tokenCount);
+ // heap used for term frequencies and doc lengths
+ long heapUsed = RamUsageEstimator.HASHTABLE_RAM_BYTES_PER_ENTRY
+ + primaryKey.ramBytesUsed()
+ + Integer.BYTES;
+ onHeapAllocationsTracker.accept(heapUsed);
+
+ // memory used by the trie
onHeapAllocationsTracker.accept((data.usedSizeOnHeap() - initialSizeOnHeap) +
(primaryKeysReducer.heapAllocations() - reducerHeapSize));
offHeapAllocationsTracker.accept(data.usedSizeOffHeap() - initialSizeOffHeap);
@@ -161,10 +257,10 @@ public synchronized void add(DecoratedKey key,
}
@Override
- public Iterator<Pair<ByteComparable, PrimaryKeys>> iterator()
+ public Iterator<Pair<ByteComparable, List<PkWithFrequency>>> iterator()
{
Iterator<Map.Entry<ByteComparable, PrimaryKeys>> iterator = data.entrySet().iterator();
- return new Iterator<Pair<ByteComparable, PrimaryKeys>>()
+ return new Iterator<>()
{
@Override
public boolean hasNext()
@@ -173,10 +269,17 @@ public boolean hasNext()
}
@Override
- public Pair<ByteComparable, PrimaryKeys> next()
+ public Pair<ByteComparable, List<PkWithFrequency>> next()
{
Map.Entry<ByteComparable, PrimaryKeys> entry = iterator.next();
- return Pair.create(entry.getKey(), entry.getValue());
+ var pairs = new ArrayList<PkWithFrequency>(entry.getValue().size());
+ for (PrimaryKey pk : entry.getValue().keys())
+ {
+ var frequencyRaw = termFrequencies.get(new PkWithTerm(pk, entry.getKey()));
+ int frequency = frequencyRaw == null ? 1 : frequencyRaw;
+ pairs.add(new PkWithFrequency(pk, frequency));
+ }
+ return Pair.create(entry.getKey(), pairs);
}
};
}
@@ -418,7 +521,6 @@ private ByteComparable asByteComparable(ByteBuffer input)
return Version.latest().onDiskFormat().encodeForTrie(input, indexContext.getValidator());
}
-
class PrimaryKeysReducer implements InMemoryTrie.UpsertTransformer<PrimaryKeys, PrimaryKey>
{
private final LongAdder heapAllocations = new LongAdder();
@@ -662,6 +764,13 @@ public void close() throws IOException
}
}
+ /**
+ * Iterator that provides ordered access to all indexed terms and their associated primary keys
+ * in the TrieMemoryIndex. For each term in the index, yields PrimaryKeyWithSortKey objects that
+ * combine a primary key with its associated term.
+ *
+ * A more verbose name could be KeysMatchingTermsByTermIterator.
+ */
private class AllTermsIterator extends AbstractIterator<PrimaryKeyWithSortKey>
{
private final Iterator<Map.Entry<ByteComparable, PrimaryKeys>> iterator;
diff --git a/src/java/org/apache/cassandra/index/sai/memory/TrieMemtableIndex.java b/src/java/org/apache/cassandra/index/sai/memory/TrieMemtableIndex.java
index f1cb9f5f47e7..dd440a4c42aa 100644
--- a/src/java/org/apache/cassandra/index/sai/memory/TrieMemtableIndex.java
+++ b/src/java/org/apache/cassandra/index/sai/memory/TrieMemtableIndex.java
@@ -22,6 +22,7 @@
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Comparator;
+import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Objects;
@@ -30,28 +31,37 @@
import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Preconditions;
+import com.google.common.collect.Iterators;
import com.google.common.util.concurrent.Runnables;
+import org.apache.cassandra.cql3.Operator;
import org.apache.cassandra.db.Clustering;
+import org.apache.cassandra.db.DataRange;
import org.apache.cassandra.db.DecoratedKey;
import org.apache.cassandra.db.PartitionPosition;
+import org.apache.cassandra.db.RegularAndStaticColumns;
+import org.apache.cassandra.db.filter.ColumnFilter;
import org.apache.cassandra.db.marshal.AbstractType;
import org.apache.cassandra.db.memtable.Memtable;
import org.apache.cassandra.db.memtable.ShardBoundaries;
import org.apache.cassandra.db.memtable.TrieMemtable;
+import org.apache.cassandra.db.rows.Row;
import org.apache.cassandra.dht.AbstractBounds;
+import org.apache.cassandra.dht.Bounds;
import org.apache.cassandra.index.sai.IndexContext;
import org.apache.cassandra.index.sai.QueryContext;
import org.apache.cassandra.index.sai.disk.format.Version;
-import org.apache.cassandra.index.sai.iterators.KeyRangeLazyIterator;
import org.apache.cassandra.index.sai.iterators.KeyRangeConcatIterator;
+import org.apache.cassandra.index.sai.iterators.KeyRangeIntersectionIterator;
import org.apache.cassandra.index.sai.iterators.KeyRangeIterator;
+import org.apache.cassandra.index.sai.iterators.KeyRangeLazyIterator;
+import org.apache.cassandra.index.sai.memory.MemoryIndex.PkWithFrequency;
import org.apache.cassandra.index.sai.plan.Expression;
import org.apache.cassandra.index.sai.plan.Orderer;
+import org.apache.cassandra.index.sai.utils.BM25Utils;
import org.apache.cassandra.index.sai.utils.PrimaryKey;
import org.apache.cassandra.index.sai.utils.PrimaryKeyWithByteComparable;
import org.apache.cassandra.index.sai.utils.PrimaryKeyWithSortKey;
-import org.apache.cassandra.index.sai.utils.PrimaryKeys;
import org.apache.cassandra.index.sai.utils.TypeUtil;
import org.apache.cassandra.sensors.Context;
import org.apache.cassandra.sensors.RequestSensors;
@@ -63,7 +73,6 @@
import org.apache.cassandra.utils.Reducer;
import org.apache.cassandra.utils.SortingIterator;
import org.apache.cassandra.utils.bytecomparable.ByteComparable;
-import org.apache.cassandra.utils.bytecomparable.ByteSource;
import org.apache.cassandra.utils.concurrent.OpOrder;
public class TrieMemtableIndex implements MemtableIndex
@@ -236,15 +245,46 @@ public List> orderBy(QueryContext query
int startShard = boundaries.getShardForToken(keyRange.left.getToken());
int endShard = keyRange.right.isMinimum() ? boundaries.shardCount() - 1 : boundaries.getShardForToken(keyRange.right.getToken());
- var iterators = new ArrayList<CloseableIterator<PrimaryKeyWithSortKey>>(endShard - startShard + 1);
-
- for (int shard = startShard; shard <= endShard; ++shard)
+ if (!orderer.isBM25())
{
- assert rangeIndexes[shard] != null;
- iterators.add(rangeIndexes[shard].orderBy(orderer, slice));
+ var iterators = new ArrayList<CloseableIterator<PrimaryKeyWithSortKey>>(endShard - startShard + 1);
+ for (int shard = startShard; shard <= endShard; ++shard)
+ {
+ assert rangeIndexes[shard] != null;
+ iterators.add(rangeIndexes[shard].orderBy(orderer, slice));
+ }
+ return iterators;
}
- return iterators;
+ // BM25
+ var queryTerms = orderer.getQueryTerms();
+
+ // Intersect iterators to find documents containing all terms
+ var termIterators = keyIteratorsPerTerm(queryContext, keyRange, queryTerms);
+ var intersectedIterator = KeyRangeIntersectionIterator.builder(termIterators).build();
+
+ // Compute BM25 scores
+ var docStats = computeDocumentFrequencies(queryContext, queryTerms);
+ var analyzer = indexContext.getAnalyzerFactory().create();
+ var it = Iterators.transform(intersectedIterator, pk -> BM25Utils.DocTF.createFromDocument(pk, getCellForKey(pk), analyzer, queryTerms));
+ return List.of(BM25Utils.computeScores(CloseableIterator.wrap(it),
+ queryTerms,
+ docStats,
+ indexContext,
+ memtable));
+ }
+
+ private List<KeyRangeIterator> keyIteratorsPerTerm(QueryContext queryContext, AbstractBounds<PartitionPosition> keyRange, List<ByteBuffer> queryTerms)
+ {
+ List<KeyRangeIterator> termIterators = new ArrayList<>(queryTerms.size());
+ for (ByteBuffer term : queryTerms)
+ {
+ Expression expr = new Expression(indexContext);
+ expr.add(Operator.ANALYZER_MATCHES, term);
+ KeyRangeIterator iterator = search(queryContext, expr, keyRange, Integer.MAX_VALUE);
+ termIterators.add(iterator);
+ }
+ return termIterators;
}
@Override
@@ -256,38 +296,105 @@ public long estimateMatchingRowsCount(Expression expression, AbstractBounds<PartitionPosition> keyRange)
- public CloseableIterator<PrimaryKeyWithSortKey> orderResultsBy(QueryContext context, List<PrimaryKey> keys, Orderer orderer, int limit)
+ public CloseableIterator<PrimaryKeyWithSortKey> orderResultsBy(QueryContext queryContext, List<PrimaryKey> keys, Orderer orderer, int limit)
{
if (keys.isEmpty())
return CloseableIterator.emptyIterator();
- return SortingIterator.createCloseable(
- orderer.getComparator(),
- keys,
- key ->
+
+ if (!orderer.isBM25())
+ {
+ return SortingIterator.createCloseable(
+ orderer.getComparator(),
+ keys,
+ key ->
+ {
+ var partition = memtable.getPartition(key.partitionKey());
+ if (partition == null)
+ return null;
+ var row = partition.getRow(key.clustering());
+ if (row == null)
+ return null;
+ var cell = row.getCell(indexContext.getDefinition());
+ if (cell == null)
+ return null;
+
+ // We do two kinds of encoding... it'd be great to make this more straightforward, but this is what
+ // we have for now. I leave it to the reader to inspect the two methods to see the nuanced differences.
+ var encoding = encode(TypeUtil.encode(cell.buffer(), validator));
+ return new PrimaryKeyWithByteComparable(indexContext, memtable, key, encoding);
+ },
+ Runnables.doNothing()
+ );
+ }
+
+ // BM25
+ var analyzer = indexContext.getAnalyzerFactory().create();
+ var queryTerms = orderer.getQueryTerms();
+ var docStats = computeDocumentFrequencies(queryContext, queryTerms);
+ var it = keys.stream().map(pk -> BM25Utils.DocTF.createFromDocument(pk, getCellForKey(pk), analyzer, queryTerms)).iterator();
+ return BM25Utils.computeScores(CloseableIterator.wrap(it),
+ queryTerms,
+ docStats,
+ indexContext,
+ memtable);
+ }
+
+ /**
+ * Count document frequencies for each term using brute force
+ */
+ private BM25Utils.DocStats computeDocumentFrequencies(QueryContext queryContext, List<ByteBuffer> queryTerms)
+ {
+ var termIterators = keyIteratorsPerTerm(queryContext, Bounds.unbounded(indexContext.getPartitioner()), queryTerms);
+ var documentFrequencies = new HashMap<ByteBuffer, Long>();
+ for (int i = 0; i < queryTerms.size(); i++)
+ {
+ // KeyRangeIterator.getMaxKeys is not accurate enough, we have to count them
+ long keys = 0;
+ for (var it = termIterators.get(i); it.hasNext(); it.next())
+ keys++;
+ documentFrequencies.put(queryTerms.get(i), keys);
+ }
+ long docCount = 0;
+
+ // count all documents in the queried column
+ try (var it = memtable.makePartitionIterator(ColumnFilter.selection(RegularAndStaticColumns.of(indexContext.getDefinition())),
+ DataRange.allData(memtable.metadata().partitioner)))
+ {
+ while (it.hasNext())
{
- var partition = memtable.getPartition(key.partitionKey());
- if (partition == null)
- return null;
- var row = partition.getRow(key.clustering());
- if (row == null)
- return null;
- var cell = row.getCell(indexContext.getDefinition());
- if (cell == null)
- return null;
-
- // We do two kinds of encoding... it'd be great to make this more straight forward, but this is what
- // we have for now. I leave it to the reader to inspect the two methods to see the nuanced differences.
- var encoding = encode(TypeUtil.encode(cell.buffer(), validator));
- return new PrimaryKeyWithByteComparable(indexContext, memtable, key, encoding);
- },
- Runnables.doNothing()
- );
+ var partitions = it.next();
+ while (partitions.hasNext())
+ {
+ var unfiltered = partitions.next();
+ if (!unfiltered.isRow())
+ continue;
+ var row = (Row) unfiltered;
+ var cell = row.getCell(indexContext.getDefinition());
+ if (cell == null)
+ continue;
+
+ docCount++;
+ }
+ }
+ }
+ return new BM25Utils.DocStats(documentFrequencies, docCount);
+ }
+
+ @Nullable
+ private org.apache.cassandra.db.rows.Cell<?> getCellForKey(PrimaryKey key)
+ {
+ var partition = memtable.getPartition(key.partitionKey());
+ if (partition == null)
+ return null;
+ var row = partition.getRow(key.clustering());
+ if (row == null)
+ return null;
+ return row.getCell(indexContext.getDefinition());
}
private ByteComparable encode(ByteBuffer input)
{
- return indexContext.isLiteral() ? v -> ByteSource.preencoded(input)
- : v -> TypeUtil.asComparableBytes(input, indexContext.getValidator(), v);
+ return Version.latest().onDiskFormat().encodeForTrie(input, indexContext.getValidator());
}
/**
@@ -302,26 +409,34 @@ private ByteComparable encode(ByteBuffer input)
* @return iterator of indexed term to primary keys mapping in sorted by indexed term and primary key.
*/
@Override
- public Iterator<Pair<ByteComparable, Iterator<PrimaryKey>>> iterator(DecoratedKey min, DecoratedKey max)
+ public Iterator<Pair<ByteComparable, List<PkWithFrequency>>> iterator(DecoratedKey min, DecoratedKey max)
{
int minSubrange = min == null ? 0 : boundaries.getShardForKey(min);
int maxSubrange = max == null ? rangeIndexes.length - 1 : boundaries.getShardForKey(max);
- List<Iterator<Pair<ByteComparable, PrimaryKeys>>> rangeIterators = new ArrayList<>(maxSubrange - minSubrange + 1);
+ List<Iterator<Pair<ByteComparable, List<PkWithFrequency>>>> rangeIterators = new ArrayList<>(maxSubrange - minSubrange + 1);
for (int i = minSubrange; i <= maxSubrange; i++)
rangeIterators.add(rangeIndexes[i].iterator());
- return MergeIterator.get(rangeIterators, (o1, o2) -> ByteComparable.compare(o1.left, o2.left, TypeUtil.BYTE_COMPARABLE_VERSION),
+ return MergeIterator.get(rangeIterators,
+ (o1, o2) -> ByteComparable.compare(o1.left, o2.left, TypeUtil.BYTE_COMPARABLE_VERSION),
new PrimaryKeysMergeReducer(rangeIterators.size()));
}
- // The PrimaryKeysMergeReducer receives the range iterators from each of the range indexes selected based on the
- // min and max keys passed to the iterator method. It doesn't strictly do any reduction because the terms in each
- // range index are unique. It will receive at most one range index entry per selected range index before getReduced
- // is called.
- private static class PrimaryKeysMergeReducer extends Reducer<Pair<ByteComparable, PrimaryKeys>, Pair<ByteComparable, Iterator<PrimaryKey>>>
+ /**
+ * Used to merge sorted primary keys from multiple TrieMemoryIndex shards for a given indexed term.
+ * For each term that appears in multiple shards, the reducer:
+ * 1. Receives exactly one call to reduce() per shard containing that term
+ * 2. Merges all the primary keys for that term via getReduced()
+ * 3. Resets state via onKeyChange() before processing the next term
+ *
+ * While this follows the Reducer pattern, its "reduction" operation is a simple merge since each term
+ * appears at most once per shard and each key lives in exactly one shard, so there are no values to aggregate;
+ * we simply combine and sort the primary keys from each shard that contains the term.
+ */
+ private static class PrimaryKeysMergeReducer extends Reducer<Pair<ByteComparable, List<PkWithFrequency>>, Pair<ByteComparable, List<PkWithFrequency>>>
{
- private final Pair<ByteComparable, PrimaryKeys>[] rangeIndexEntriesToMerge;
+ private final Pair<ByteComparable, List<PkWithFrequency>>[] rangeIndexEntriesToMerge;
private final Comparator<PrimaryKey> comparator;
private ByteComparable term;
@@ -337,7 +452,7 @@ private static class PrimaryKeysMergeReducer extends Reducer<Pair<ByteComparable, PrimaryKeys>, Pair<ByteComparable, Iterator<PrimaryKey>>>
- public void reduce(int index, Pair<ByteComparable, PrimaryKeys> termPair)
+ public void reduce(int index, Pair<ByteComparable, List<PkWithFrequency>> termPair)
{
Preconditions.checkArgument(rangeIndexEntriesToMerge[index] == null, "Terms should be unique in the memory index");
@@ -348,17 +463,17 @@ public void reduce(int index, Pair<ByteComparable, PrimaryKeys> termPair)
@Override
// Return a merger of the term keys for the term.
- public Pair<ByteComparable, Iterator<PrimaryKey>> getReduced()
+ public Pair<ByteComparable, List<PkWithFrequency>> getReduced()
{
Preconditions.checkArgument(term != null, "The term must exist in the memory index");
- List<Iterator<PrimaryKey>> keyIterators = new ArrayList<>(rangeIndexEntriesToMerge.length);
- for (Pair<ByteComparable, PrimaryKeys> p : rangeIndexEntriesToMerge)
- if (p != null && p.right != null && !p.right.isEmpty())
- keyIterators.add(p.right.iterator());
+ var merged = new ArrayList<PkWithFrequency>();
+ for (var p : rangeIndexEntriesToMerge)
+ if (p != null && p.right != null)
+ merged.addAll(p.right);
- Iterator<PrimaryKey> primaryKeys = MergeIterator.get(keyIterators, comparator, Reducer.getIdentity());
- return Pair.create(term, primaryKeys);
+ merged.sort((o1, o2) -> comparator.compare(o1.pk, o2.pk));
+ return Pair.create(term, merged);
}
@Override
diff --git a/src/java/org/apache/cassandra/index/sai/plan/Expression.java b/src/java/org/apache/cassandra/index/sai/plan/Expression.java
index a1bd9acddc4b..aac8e240829b 100644
--- a/src/java/org/apache/cassandra/index/sai/plan/Expression.java
+++ b/src/java/org/apache/cassandra/index/sai/plan/Expression.java
@@ -101,6 +101,7 @@ public static Op valueOf(Operator operator)
return IN;
case ANN:
+ case BM25:
case ORDER_BY_ASC:
case ORDER_BY_DESC:
return ORDER_BY;
@@ -250,6 +251,7 @@ public Expression add(Operator op, ByteBuffer value)
boundedAnnEuclideanDistanceThreshold = GeoUtil.amplifiedEuclideanSimilarityThreshold(lower.value.vector, searchRadiusMeters);
break;
case ANN:
+ case BM25:
case ORDER_BY_ASC:
case ORDER_BY_DESC:
// If we already have an operation on the column, we don't need to set the ORDER_BY op because
diff --git a/src/java/org/apache/cassandra/index/sai/plan/Operation.java b/src/java/org/apache/cassandra/index/sai/plan/Operation.java
index 7722ada87463..8968d4390f5a 100644
--- a/src/java/org/apache/cassandra/index/sai/plan/Operation.java
+++ b/src/java/org/apache/cassandra/index/sai/plan/Operation.java
@@ -93,7 +93,7 @@ protected static ListMultimap analyzeGroup(QueryCont
analyzer.reset(e.getIndexValue());
// EQ/LIKE_*/NOT_EQ can have multiple expressions e.g. text = "Hello World",
- // becomes text = "Hello" OR text = "World" because "space" is always interpreted as a split point (by analyzer),
+ // becomes text = "Hello" AND text = "World" because "space" is always interpreted as a split point (by analyzer),
// CONTAINS/CONTAINS_KEY are always treated as multiple expressions since they currently only target
// collections, NOT_EQ is made an independent expression only in case of pre-existing multiple EQ expressions, or
// if there is no EQ operations and NOT_EQ is met or a single NOT_EQ expression present,
@@ -102,6 +102,7 @@ protected static ListMultimap analyzeGroup(QueryCont
boolean isMultiExpression = columnIsMultiExpression.getOrDefault(e.column(), Boolean.FALSE);
switch (e.operator())
{
+ // case BM25: leave it at the default of `false`
case EQ:
// EQ operator will always be a multiple expression because it is being used by map entries
isMultiExpression = indexContext.isNonFrozenCollection();
diff --git a/src/java/org/apache/cassandra/index/sai/plan/Orderer.java b/src/java/org/apache/cassandra/index/sai/plan/Orderer.java
index e23c83ed91e2..af202bae5f5b 100644
--- a/src/java/org/apache/cassandra/index/sai/plan/Orderer.java
+++ b/src/java/org/apache/cassandra/index/sai/plan/Orderer.java
@@ -19,9 +19,12 @@
package org.apache.cassandra.index.sai.plan;
import java.nio.ByteBuffer;
+import java.util.ArrayList;
import java.util.Arrays;
import java.util.Comparator;
import java.util.EnumSet;
+import java.util.HashSet;
+import java.util.List;
import java.util.stream.Collectors;
import javax.annotation.Nullable;
@@ -41,12 +44,15 @@ public class Orderer
{
// The list of operators that are valid for order by clauses.
static final EnumSet<Operator> ORDER_BY_OPERATORS = EnumSet.of(Operator.ANN,
+ Operator.BM25,
Operator.ORDER_BY_ASC,
Operator.ORDER_BY_DESC);
public final IndexContext context;
public final Operator operator;
- public final float[] vector;
+ public final ByteBuffer term;
+ private float[] vector;
+ private List<ByteBuffer> queryTerms;
/**
* Create an orderer for the given index context, operator, and term.
@@ -59,7 +65,7 @@ public Orderer(IndexContext context, Operator operator, ByteBuffer term)
this.context = context;
assert ORDER_BY_OPERATORS.contains(operator) : "Invalid operator for order by clause " + operator;
this.operator = operator;
- this.vector = context.getValidator().isVector() ? TypeUtil.decomposeVector(context.getValidator(), term) : null;
+ this.term = term;
}
public String getIndexName()
@@ -75,8 +81,8 @@ public boolean isAscending()
public Comparator<? super PrimaryKeyWithSortKey> getComparator()
{
- // ANN's PrimaryKeyWithSortKey is always descending, so we use the natural order for the priority queue
- return isAscending() || isANN() ? Comparator.naturalOrder() : Comparator.reverseOrder();
+ // ANN/BM25's PrimaryKeyWithSortKey is always descending, so we use the natural order for the priority queue
+ return (isAscending() || isANN() || isBM25()) ? Comparator.naturalOrder() : Comparator.reverseOrder();
}
public boolean isLiteral()
@@ -89,6 +95,11 @@ public boolean isANN()
return operator == Operator.ANN;
}
+ public boolean isBM25()
+ {
+ return operator == Operator.BM25;
+ }
+
@Nullable
public static Orderer from(SecondaryIndexManager indexManager, RowFilter filter)
{
@@ -110,8 +121,38 @@ public static boolean isFilterExpressionOrderer(RowFilter.Expression expression)
public String toString()
{
String direction = isAscending() ? "ASC" : "DESC";
- return isANN()
- ? context.getColumnName() + " ANN OF " + Arrays.toString(vector) + ' ' + direction
- : context.getColumnName() + ' ' + direction;
+ if (isANN())
+ return context.getColumnName() + " ANN OF " + Arrays.toString(getVectorTerm()) + ' ' + direction;
+ if (isBM25())
+ return context.getColumnName() + " BM25 OF " + TypeUtil.getString(term, context.getValidator()) + ' ' + direction;
+ return context.getColumnName() + ' ' + direction;
+ }
+
+ public float[] getVectorTerm()
+ {
+ if (vector == null)
+ vector = TypeUtil.decomposeVector(context.getValidator(), term);
+ return vector;
+ }
+
+ public List<ByteBuffer> getQueryTerms()
+ {
+ if (queryTerms != null)
+ return queryTerms;
+
+ var queryAnalyzer = context.getQueryAnalyzerFactory().create();
+ // Split query into terms
+ var uniqueTerms = new HashSet<ByteBuffer>();
+ queryAnalyzer.reset(term);
+ try
+ {
+ queryAnalyzer.forEachRemaining(uniqueTerms::add);
+ }
+ finally
+ {
+ queryAnalyzer.end();
+ }
+ queryTerms = new ArrayList<>(uniqueTerms);
+ return queryTerms;
}
}
diff --git a/src/java/org/apache/cassandra/index/sai/plan/Plan.java b/src/java/org/apache/cassandra/index/sai/plan/Plan.java
index e40906c73d52..d82b92905403 100644
--- a/src/java/org/apache/cassandra/index/sai/plan/Plan.java
+++ b/src/java/org/apache/cassandra/index/sai/plan/Plan.java
@@ -1257,10 +1257,12 @@ protected double estimateSelectivity()
@Override
protected KeysIterationCost estimateCost()
{
- return ordering.isANN()
- ? estimateAnnSortCost()
- : estimateGlobalSortCost();
-
+ if (ordering.isANN())
+ return estimateAnnSortCost();
+ else if (ordering.isBM25())
+ return estimateBm25SortCost();
+ else
+ return estimateGlobalSortCost();
}
private KeysIterationCost estimateAnnSortCost()
@@ -1277,6 +1279,21 @@ private KeysIterationCost estimateAnnSortCost()
return new KeysIterationCost(expectedKeys, initCost, searchCost);
}
+ private KeysIterationCost estimateBm25SortCost()
+ {
+ double expectedKeys = access.expectedAccessCount(source.expectedKeys());
+
+ int termCount = ordering.getQueryTerms().size();
+ // all of the cost for BM25 is paid up front, since the index alone doesn't give us enough
+ // information to return results in order. The dominant cost is reading the indexed cells out of
+ // the sstables.
+ // VSTODO if we had stats on cell size _per column_ we could usefully include ROW_BYTE_COST
+ double initCost = source.fullCost()
+ + source.expectedKeys() * (hrs(ROW_CELL_COST) + ROW_CELL_COST)
+ + termCount * BM25_SCORE_COST;
+ return new KeysIterationCost(expectedKeys, initCost, 0);
+ }
+
private KeysIterationCost estimateGlobalSortCost()
{
return new KeysIterationCost(source.expectedKeys(),
@@ -1310,19 +1327,51 @@ protected KeysSort withAccess(Access access)
}
/**
- * Returns all keys in ANN order.
- * Contrary to {@link KeysSort}, there is no input node here and the output is generated lazily.
+ * Base class for index scans that return results in a computed order (ANN, BM25)
+ * rather than the natural index order.
*/
- final static class AnnIndexScan extends Leaf
+ abstract static class ScoredIndexScan extends Leaf
{
final Orderer ordering;
- AnnIndexScan(Factory factory, int id, Access access, Orderer ordering)
+ protected ScoredIndexScan(Factory factory, int id, Access access, Orderer ordering)
{
super(factory, id, access);
this.ordering = ordering;
}
+ @Nullable
+ @Override
+ protected Orderer ordering()
+ {
+ return ordering;
+ }
+
+ @Override
+ protected double estimateSelectivity()
+ {
+ return 1.0;
+ }
+
+ @Override
+ protected Iterator<? extends PrimaryKey> execute(Executor executor)
+ {
+ int softLimit = max(1, round((float) access.expectedAccessCount(factory.tableMetrics.rows)));
+ return executor.getTopKRows((Expression) null, softLimit);
+ }
+ }
+
+ /**
+ * Returns all keys in ANN order.
+ * Contrary to {@link KeysSort}, there is no input node here and the output is generated lazily.
+ */
+ final static class AnnIndexScan extends ScoredIndexScan
+ {
+ protected AnnIndexScan(Factory factory, int id, Access access, Orderer ordering)
+ {
+ super(factory, id, access, ordering);
+ }
+
@Override
protected KeysIterationCost estimateCost()
{
@@ -1335,32 +1384,53 @@ protected KeysIterationCost estimateCost()
return new KeysIterationCost(expectedKeys, initCost, searchCost);
}
- @Nullable
@Override
- protected Orderer ordering()
+ protected KeysIteration withAccess(Access access)
{
- return ordering;
+ return Objects.equals(access, this.access)
+ ? this
+ : new AnnIndexScan(factory, id, access, ordering);
}
+ @Nullable
@Override
- protected Iterator<? extends PrimaryKey> execute(Executor executor)
+ protected IndexContext getIndexContext()
{
- int softLimit = max(1, round((float) access.expectedAccessCount(factory.tableMetrics.rows)));
- return executor.getTopKRows((Expression) null, softLimit);
+ return ordering.context;
}
+ }
+ /**
+ * Returns all keys in BM25 order.
+ * Like AnnIndexScan, this generates results lazily without an input node.
+ */
+ final static class Bm25IndexScan extends ScoredIndexScan
+ {
+ protected Bm25IndexScan(Factory factory, int id, Access access, Orderer ordering)
+ {
+ super(factory, id, access, ordering);
+ }
+
+ @Nonnull
@Override
- protected KeysIteration withAccess(Access access)
+ protected KeysIterationCost estimateCost()
{
- return Objects.equals(access, this.access)
- ? this
- : new AnnIndexScan(factory, id, access, ordering);
+ double expectedKeys = access.expectedAccessCount(factory.tableMetrics.rows);
+ int expectedKeysInt = Math.max(1, (int) Math.ceil(expectedKeys));
+
+ int termCount = ordering.getQueryTerms().size();
+ double initCost = expectedKeysInt * (hrs(ROW_CELL_COST) + ROW_CELL_COST)
+ + termCount * BM25_SCORE_COST;
+
+ return new KeysIterationCost(expectedKeys, initCost, 0);
}
@Override
- protected double estimateSelectivity()
+ protected KeysIteration withAccess(Access access)
{
- return 1.0;
+ return Objects.equals(access, this.access)
+ ? this
+ : new Bm25IndexScan(factory, id, access, ordering);
}
@Override
@@ -1684,6 +1754,8 @@ private KeysIteration indexScan(Expression predicate, long matchingKeysCount, Or
if (ordering != null)
if (ordering.isANN())
return new AnnIndexScan(this, id, defaultAccess, ordering);
+ else if (ordering.isBM25())
+ return new Bm25IndexScan(this, id, defaultAccess, ordering);
else if (ordering.isLiteral())
return new LiteralIndexScan(this, id, predicate, matchingKeysCount, defaultAccess, ordering);
else
@@ -1938,6 +2010,9 @@ public static class CostCoefficients
/** Additional cost added to row fetch cost per each serialized byte of the row */
public final static double ROW_BYTE_COST = 0.005;
+
+ /** Cost to perform BM25 scoring, per query term */
+ public final static double BM25_SCORE_COST = 0.5;
}
/** Convenience builder for building intersection and union nodes */
diff --git a/src/java/org/apache/cassandra/index/sai/plan/QueryController.java b/src/java/org/apache/cassandra/index/sai/plan/QueryController.java
index 15b0965fe82e..992b5f1334ab 100644
--- a/src/java/org/apache/cassandra/index/sai/plan/QueryController.java
+++ b/src/java/org/apache/cassandra/index/sai/plan/QueryController.java
@@ -195,6 +195,11 @@ public TableMetadata metadata()
return command.metadata();
}
+ public ReadCommand command()
+ {
+ return command;
+ }
+
RowFilter.FilterElement filterOperation()
{
// NOTE: we cannot remove the order by filter expression here yet because it is used in the FilterTree class
@@ -883,6 +888,7 @@ private long estimateMatchingRowCount(Expression predicate)
switch (predicate.getOp())
{
case EQ:
+ case MATCH:
case CONTAINS_KEY:
case CONTAINS_VALUE:
case NOT_EQ:
diff --git a/src/java/org/apache/cassandra/index/sai/plan/StorageAttachedIndexQueryPlan.java b/src/java/org/apache/cassandra/index/sai/plan/StorageAttachedIndexQueryPlan.java
index 22e564365bea..39a66cbad9de 100644
--- a/src/java/org/apache/cassandra/index/sai/plan/StorageAttachedIndexQueryPlan.java
+++ b/src/java/org/apache/cassandra/index/sai/plan/StorageAttachedIndexQueryPlan.java
@@ -216,7 +216,7 @@ public Function postProcessor(ReadCommand
return partitions -> partitions;
// in case of top-k query, filter out rows that are not actually global top-K
- return partitions -> new TopKProcessor(command).filter(partitions);
+ return partitions -> new TopKProcessor(command).reorder(partitions);
}
/**
diff --git a/src/java/org/apache/cassandra/index/sai/plan/StorageAttachedIndexSearcher.java b/src/java/org/apache/cassandra/index/sai/plan/StorageAttachedIndexSearcher.java
index 2f0cabd05c2c..f5ebe85a1ed9 100644
--- a/src/java/org/apache/cassandra/index/sai/plan/StorageAttachedIndexSearcher.java
+++ b/src/java/org/apache/cassandra/index/sai/plan/StorageAttachedIndexSearcher.java
@@ -30,12 +30,10 @@
import java.util.Queue;
import java.util.function.Supplier;
import java.util.stream.Collectors;
-
import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import com.google.common.base.Preconditions;
-
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -45,8 +43,13 @@
import org.apache.cassandra.db.PartitionPosition;
import org.apache.cassandra.db.ReadCommand;
import org.apache.cassandra.db.ReadExecutionController;
+import org.apache.cassandra.db.filter.ColumnFilter;
+import org.apache.cassandra.db.marshal.FloatType;
import org.apache.cassandra.db.partitions.UnfilteredPartitionIterator;
import org.apache.cassandra.db.rows.AbstractUnfilteredRowIterator;
+import org.apache.cassandra.db.rows.BTreeRow;
+import org.apache.cassandra.db.rows.BufferCell;
+import org.apache.cassandra.db.rows.ColumnData;
import org.apache.cassandra.db.rows.Row;
import org.apache.cassandra.db.rows.Unfiltered;
import org.apache.cassandra.db.rows.UnfilteredRowIterator;
@@ -60,13 +63,16 @@
import org.apache.cassandra.index.sai.iterators.KeyRangeIterator;
import org.apache.cassandra.index.sai.metrics.TableQueryMetrics;
import org.apache.cassandra.index.sai.utils.PrimaryKey;
-import org.apache.cassandra.index.sai.utils.RangeUtil;
+import org.apache.cassandra.index.sai.utils.PrimaryKeyWithScore;
import org.apache.cassandra.index.sai.utils.PrimaryKeyWithSortKey;
+import org.apache.cassandra.index.sai.utils.RangeUtil;
import org.apache.cassandra.io.util.FileUtils;
+import org.apache.cassandra.schema.ColumnMetadata;
import org.apache.cassandra.schema.TableMetadata;
import org.apache.cassandra.utils.AbstractIterator;
import org.apache.cassandra.utils.CloseableIterator;
import org.apache.cassandra.utils.FBUtilities;
+import org.apache.cassandra.utils.btree.BTree;
public class StorageAttachedIndexSearcher implements Index.Searcher
{
@@ -109,19 +115,17 @@ public UnfilteredPartitionIterator search(ReadExecutionController executionContr
// Can't check for `command.isTopK()` because the planner could optimize sorting out
Orderer ordering = plan.ordering();
- if (ordering != null)
- {
- assert !(keysIterator instanceof KeyRangeIterator);
- var scoredKeysIterator = (CloseableIterator<PrimaryKeyWithSortKey>) keysIterator;
- var result = new ScoreOrderedResultRetriever(scoredKeysIterator, filterTree, controller,
- executionController, queryContext, command.limits().count());
- return new TopKProcessor(command).filter(result);
- }
- else
+ if (ordering == null)
{
assert keysIterator instanceof KeyRangeIterator;
return new ResultRetriever((KeyRangeIterator) keysIterator, filterTree, controller, executionController, queryContext);
}
+
+ assert !(keysIterator instanceof KeyRangeIterator);
+ var scoredKeysIterator = (CloseableIterator<PrimaryKeyWithSortKey>) keysIterator;
+ var result = new ScoreOrderedResultRetriever(scoredKeysIterator, filterTree, controller,
+ executionController, queryContext, command.limits().count());
+ return new TopKProcessor(command).filter(result);
}
catch (QueryView.Builder.MissingIndexException e)
{
@@ -521,48 +525,50 @@ public UnfilteredRowIterator computeNext()
*/
private void fillPendingRows()
{
+ // Group PKs by source sstable/memtable
+ var groupedKeys = new HashMap<PrimaryKey, List<PrimaryKeyWithSortKey>>();
// We always want to get at least 1.
int rowsToRetrieve = Math.max(1, softLimit - returnedRowCount);
- var keys = new HashMap<PrimaryKey, List<PrimaryKeyWithSortKey>>();
// We want to get the first unique `rowsToRetrieve` keys to materialize
// Don't pass the priority queue here because it is more efficient to add keys in bulk
- fillKeys(keys, rowsToRetrieve, null);
+ fillKeys(groupedKeys, rowsToRetrieve, null);
// Sort the primary keys by PK order, just in case that helps with cache and disk efficiency
- var primaryKeyPriorityQueue = new PriorityQueue<>(keys.keySet());
+ var primaryKeyPriorityQueue = new PriorityQueue<>(groupedKeys.keySet());
- while (!keys.isEmpty())
+ // drain groupedKeys into pendingRows
+ while (!groupedKeys.isEmpty())
{
- var primaryKey = primaryKeyPriorityQueue.poll();
- var primaryKeyWithSortKeys = keys.remove(primaryKey);
- var partitionIterator = readAndValidatePartition(primaryKey, primaryKeyWithSortKeys);
+ var pk = primaryKeyPriorityQueue.poll();
+ var sourceKeys = groupedKeys.remove(pk);
+ var partitionIterator = readAndValidatePartition(pk, sourceKeys);
if (partitionIterator != null)
pendingRows.add(partitionIterator);
else
// The current primaryKey did not produce a partition iterator. We know the caller will need
// `rowsToRetrieve` rows, so we get the next unique key and add it to the queue.
- fillKeys(keys, 1, primaryKeyPriorityQueue);
+ fillKeys(groupedKeys, 1, primaryKeyPriorityQueue);
}
}
/**
- * Fills the keys map with the next `count` unique primary keys that are in the keys produced by calling
+ * Fills the `groupedKeys` Map with the next `count` unique primary keys that are in the keys produced by calling
* {@link #nextSelectedKeyInRange()}. We map PrimaryKey to List<PrimaryKeyWithSortKey> because the same
* primary key can be in the result set multiple times, but with different source tables.
- * @param keys the map to fill
+ * @param groupedKeys the map to fill
* @param count the number of unique PrimaryKeys to consume from the iterator
* @param primaryKeyPriorityQueue the priority queue to add new keys to. If the queue is null, we do not add
* keys to the queue.
*/
- private void fillKeys(Map<PrimaryKey, List<PrimaryKeyWithSortKey>> keys, int count, PriorityQueue<PrimaryKey> primaryKeyPriorityQueue)
+ private void fillKeys(Map<PrimaryKey, List<PrimaryKeyWithSortKey>> groupedKeys, int count, PriorityQueue<PrimaryKey> primaryKeyPriorityQueue)
{
- int initialSize = keys.size();
- while (keys.size() - initialSize < count)
+ int initialSize = groupedKeys.size();
+ while (groupedKeys.size() - initialSize < count)
{
var primaryKeyWithSortKey = nextSelectedKeyInRange();
if (primaryKeyWithSortKey == null)
return;
var nextPrimaryKey = primaryKeyWithSortKey.primaryKey();
- var accumulator = keys.computeIfAbsent(nextPrimaryKey, k -> new ArrayList<>());
+ var accumulator = groupedKeys.computeIfAbsent(nextPrimaryKey, k -> new ArrayList<>());
if (primaryKeyPriorityQueue != null && accumulator.isEmpty())
primaryKeyPriorityQueue.add(nextPrimaryKey);
accumulator.add(primaryKeyWithSortKey);
@@ -602,15 +608,29 @@ private boolean isInRange(DecoratedKey key)
return null;
}
- public UnfilteredRowIterator readAndValidatePartition(PrimaryKey key, List<PrimaryKeyWithSortKey> primaryKeys)
+ /**
+ * Reads and validates a partition for a given primary key against its sources.
+ *
+ * @param pk The primary key of the partition to read and validate
+ * @param sourceKeys A list of PrimaryKeyWithSortKey objects associated with the primary key.
+ * Multiple sort keys can exist for the same primary key when data comes from different
+ * sstables or memtables.
+ *
+ * @return An UnfilteredRowIterator containing the validated partition data, or null if:
+ * - The key has already been processed
+ * - The partition does not pass index filters
+ * - The partition contains no valid rows
+ * - The row data does not match the index metadata for any of the provided primary keys
+ */
+ public UnfilteredRowIterator readAndValidatePartition(PrimaryKey pk, List<PrimaryKeyWithSortKey> sourceKeys)
{
// If we've already processed the key, we can skip it. Because the score ordered iterator does not
// deduplicate rows, we could see dupes if a row is in the ordering index multiple times. This happens
// in the case of dupes and of overwrites.
- if (processedKeys.contains(key))
+ if (processedKeys.contains(pk))
return null;
- try (UnfilteredRowIterator partition = controller.getPartition(key, view, executionController))
+ try (UnfilteredRowIterator partition = controller.getPartition(pk, view, executionController))
{
queryContext.addPartitionsRead(1);
queryContext.checkpoint();
@@ -619,7 +639,7 @@ public UnfilteredRowIterator readAndValidatePartition(PrimaryKey key, List<PrimaryKeyWithSortKey> primaryKeys)
- public PrimaryKeyIterator(UnfilteredRowIterator partition, Row staticRow, Unfiltered content)
+ public PrimaryKeyIterator(UnfilteredRowIterator partition, Row staticRow, Unfiltered content, List<PrimaryKeyWithSortKey> primaryKeysWithScore, ReadCommand command)
{
super(partition.metadata(),
partition.partitionKey(),
@@ -662,7 +694,47 @@ public PrimaryKeyIterator(UnfilteredRowIterator partition, Row staticRow, Unfilt
partition.isReverseOrder(),
partition.stats());
- row = content;
+ assert !primaryKeysWithScore.isEmpty();
+ var isScoredRow = primaryKeysWithScore.get(0) instanceof PrimaryKeyWithScore;
+ if (!content.isRow() || !isScoredRow)
+ {
+ this.row = content;
+ return;
+ }
+
+ // When +score is added on the coordinator side, it's represented as a PrecomputedColumnFilter
+ // even in a 'SELECT *' because WCF is not capable of representing synthetic columns.
+ // This can be simplified when we remove ANN_USE_SYNTHETIC_SCORE
+ var tm = metadata();
+ var scoreColumn = ColumnMetadata.syntheticColumn(tm.keyspace,
+ tm.name,
+ ColumnMetadata.SYNTHETIC_SCORE_ID,
+ FloatType.instance);
+ var isScoreFetched = command.columnFilter().fetchesExplicitly(scoreColumn);
+ if (!isScoreFetched)
+ {
+ this.row = content;
+ return;
+ }
+
+ // Clone the original Row
+ Row originalRow = (Row) content;
+ ArrayList<ColumnData> columnData = new ArrayList<>(originalRow.columnCount() + 1);
+ columnData.addAll(originalRow.columnData());
+
+ // inject +score as a new column
+ var pkWithScore = (PrimaryKeyWithScore) primaryKeysWithScore.get(0);
+ columnData.add(BufferCell.live(scoreColumn,
+ FBUtilities.nowInSeconds(),
+ FloatType.instance.decompose(pkWithScore.indexScore)));
+
+ this.row = BTreeRow.create(originalRow.clustering(),
+ originalRow.primaryKeyLivenessInfo(),
+ originalRow.deletion(),
+ BTree.builder(ColumnData.comparator)
+ .auto(true)
+ .addAll(columnData)
+ .build());
}
@Override
@@ -674,18 +746,6 @@ protected Unfiltered computeNext()
return row;
}
}
-
- @Override
- public TableMetadata metadata()
- {
- return controller.metadata();
- }
-
- public void close()
- {
- FileUtils.closeQuietly(scoredPrimaryKeyIterator);
- controller.finish();
- }
}
private static UnfilteredRowIterator applyIndexFilter(UnfilteredRowIterator partition, FilterTree tree, QueryContext queryContext)
diff --git a/src/java/org/apache/cassandra/index/sai/plan/TopKProcessor.java b/src/java/org/apache/cassandra/index/sai/plan/TopKProcessor.java
index a2e4b315f74f..5760109328b7 100644
--- a/src/java/org/apache/cassandra/index/sai/plan/TopKProcessor.java
+++ b/src/java/org/apache/cassandra/index/sai/plan/TopKProcessor.java
@@ -27,8 +27,6 @@
import java.util.SortedSet;
import java.util.TreeMap;
import java.util.TreeSet;
-import java.util.concurrent.CompletableFuture;
-import java.util.concurrent.CompletionException;
import javax.annotation.Nullable;
import org.apache.commons.lang3.tuple.Triple;
@@ -38,28 +36,25 @@
import io.github.jbellis.jvector.vector.VectorizationProvider;
import io.github.jbellis.jvector.vector.types.VectorFloat;
import io.github.jbellis.jvector.vector.types.VectorTypeSupport;
-import org.apache.cassandra.concurrent.ImmediateExecutor;
-import org.apache.cassandra.concurrent.LocalAwareExecutorService;
-import org.apache.cassandra.concurrent.SharedExecutorPool;
-import org.apache.cassandra.config.CassandraRelevantProperties;
import org.apache.cassandra.cql3.Operator;
+import org.apache.cassandra.cql3.statements.SelectStatement;
import org.apache.cassandra.db.ColumnFamilyStore;
import org.apache.cassandra.db.DecoratedKey;
import org.apache.cassandra.db.Keyspace;
import org.apache.cassandra.db.ReadCommand;
import org.apache.cassandra.db.filter.RowFilter;
-import org.apache.cassandra.db.partitions.BasePartitionIterator;
-import org.apache.cassandra.db.partitions.ParallelCommandProcessor;
+import org.apache.cassandra.db.marshal.FloatType;
import org.apache.cassandra.db.partitions.PartitionIterator;
import org.apache.cassandra.db.partitions.UnfilteredPartitionIterator;
import org.apache.cassandra.db.rows.BaseRowIterator;
+import org.apache.cassandra.db.rows.Cell;
import org.apache.cassandra.db.rows.Row;
import org.apache.cassandra.db.rows.Unfiltered;
import org.apache.cassandra.index.Index;
import org.apache.cassandra.index.SecondaryIndexManager;
import org.apache.cassandra.index.sai.IndexContext;
import org.apache.cassandra.index.sai.StorageAttachedIndex;
-import org.apache.cassandra.index.sai.utils.AbortedOperationException;
+import org.apache.cassandra.index.sai.plan.StorageAttachedIndexSearcher.ScoreOrderedResultRetriever;
import org.apache.cassandra.index.sai.utils.InMemoryPartitionIterator;
import org.apache.cassandra.index.sai.utils.InMemoryUnfilteredPartitionIterator;
import org.apache.cassandra.index.sai.utils.PartitionInfo;
@@ -72,33 +67,30 @@
import static org.apache.cassandra.cql3.statements.RequestValidations.invalidRequest;
/**
- * Processor applied to SAI based ORDER BY queries. This class could likely be refactored into either two filter
- * methods depending on where the processing is happening or into two classes.
- *
- * This processor performs the following steps on a replica:
- * - collect LIMIT rows from partition iterator, making sure that all are valid.
- * - return rows in Primary Key order
- *
- * This processor performs the following steps on a coordinator:
- * - consume all rows from the provided partition iterator and sort them according to the specified order.
- * For vectors, that is similarit score and for all others, that is the ordering defined by their
- * {@link org.apache.cassandra.db.marshal.AbstractType}. If there are multiple vector indexes,
- * the final score is the sum of all vector index scores.
- * - remove rows with the lowest scores from PQ if PQ size exceeds limit
- * - return rows from PQ in primary key order to caller
+ * Processor applied to SAI based ORDER BY queries.
+ *
+ * On a replica:
+ * - filter(ScoreOrderedResultRetriever) is used to collect up to the top-K rows.
+ * - We store any tombstones as well, to avoid losing them during coordinator reconciliation.
+ * - The result is returned in PK order so that the coordinator can merge from multiple replicas.
+ *
+ * On a coordinator:
+ * - reorder(PartitionIterator) is used to consume all rows from the provided partitions,
+ * compute the order based on either a column ordering or a similarity score, and keep top-K.
+ * - The result is returned in score/sortkey order.
*/
public class TopKProcessor
{
public static final String INDEX_MAY_HAVE_BEEN_DROPPED = "An index may have been dropped. Ordering on non-clustering " +
"column requires the column to be indexed";
protected static final Logger logger = LoggerFactory.getLogger(TopKProcessor.class);
- private static final LocalAwareExecutorService PARALLEL_EXECUTOR = getExecutor();
private static final VectorTypeSupport vts = VectorizationProvider.getInstance().getVectorTypeSupport();
private final ReadCommand command;
private final IndexContext indexContext;
private final RowFilter.Expression expression;
private final VectorFloat<?> queryVector;
+ private final ColumnMetadata scoreColumn;
private final int limit;
@@ -106,52 +98,29 @@ public TopKProcessor(ReadCommand command)
{
this.command = command;
- Pair<IndexContext, RowFilter.Expression> annIndexAndExpression = findTopKIndexContext();
+ Pair<IndexContext, RowFilter.Expression> indexAndExpression = findTopKIndexContext();
// this can happen in case an index was dropped after the query was initiated
- if (annIndexAndExpression == null)
+ if (indexAndExpression == null)
throw invalidRequest(INDEX_MAY_HAVE_BEEN_DROPPED);
- this.indexContext = annIndexAndExpression.left;
- this.expression = annIndexAndExpression.right;
- if (expression.operator() == Operator.ANN)
+ this.indexContext = indexAndExpression.left;
+ this.expression = indexAndExpression.right;
+ if (expression.operator() == Operator.ANN && !SelectStatement.ANN_USE_SYNTHETIC_SCORE)
this.queryVector = vts.createFloatVector(TypeUtil.decomposeVector(indexContext, expression.getIndexValue().duplicate()));
else
this.queryVector = null;
this.limit = command.limits().count();
- }
-
- /**
- * Executor to use for parallel index reads.
- * Defined by -Dcassandra.index_read.parallele=true/false, true by default.
- *
- * INDEX_READ uses 2 * cpus threads by default but can be overridden with -Dcassandra.index_read.parallel_thread_num=
- *
- * @return stage to use, default INDEX_READ
- */
- private static LocalAwareExecutorService getExecutor()
- {
- boolean isParallel = CassandraRelevantProperties.USE_PARALLEL_INDEX_READ.getBoolean();
-
- if (isParallel)
- {
- int numThreads = CassandraRelevantProperties.PARALLEL_INDEX_READ_NUM_THREADS.isPresent()
- ? CassandraRelevantProperties.PARALLEL_INDEX_READ_NUM_THREADS.getInt()
- : FBUtilities.getAvailableProcessors() * 2;
- return SharedExecutorPool.SHARED.newExecutor(numThreads, maximumPoolSize -> {}, "request", "IndexParallelRead");
- }
- else
- return ImmediateExecutor.INSTANCE;
+ this.scoreColumn = ColumnMetadata.syntheticColumn(indexContext.getKeyspace(), indexContext.getTable(), ColumnMetadata.SYNTHETIC_SCORE_ID, FloatType.instance);
}
/**
* Sort the specified filtered rows according to the {@code ORDER BY} clause and keep the first {@link #limit} rows.
- * This is meant to be used on the coordinator-side to sort the rows collected from the replicas.
- * Caller must close the supplied iterator.
+ * Called on the coordinator side.
*
- * @param partitions the partitions collected by the coordinator
- * @return the provided rows, sorted by the requested {@code ORDER BY} chriteria and trimmed to {@link #limit} rows
+ * @param partitions the partitions collected by the coordinator. It will be closed as a side-effect.
+ * @return the provided rows, sorted and trimmed to {@link #limit} rows
*/
- public PartitionIterator filter(PartitionIterator partitions)
+ public PartitionIterator reorder(PartitionIterator partitions)
{
// We consume the partitions iterator and create a new one. Use a try-with-resources block to ensure the
// original iterator is closed. We do not expect exceptions here, but if they happen, we want to make sure the
@@ -159,17 +128,35 @@ public PartitionIterator filter(PartitionIterator partitions)
try (partitions)
{
Comparator<Triple<PartitionInfo, Row, ?>> comparator = comparator()
- .thenComparing(Triple::getLeft, Comparator.comparing(p -> p.key))
- .thenComparing(Triple::getMiddle, command.metadata().comparator);
+ .thenComparing(Triple::getLeft, Comparator.comparing(pi -> pi.key))
+ .thenComparing(Triple::getMiddle, command.metadata().comparator);
TopKSelector<Triple<PartitionInfo, Row, ?>> topK = new TopKSelector<>(comparator, limit);
+ while (partitions.hasNext())
+ {
+ try (BaseRowIterator<?> partitionRowIterator = partitions.next())
+ {
+ if (expression.operator() == Operator.ANN || expression.operator() == Operator.BM25)
+ {
+ PartitionResults pr = processScoredPartition(partitionRowIterator);
+ topK.addAll(pr.rows);
+ }
+ else
+ {
+ while (partitionRowIterator.hasNext())
+ {
+ Row row = (Row) partitionRowIterator.next();
+ ByteBuffer value = row.getCell(expression.column()).buffer();
+ topK.add(Triple.of(PartitionInfo.create(partitionRowIterator), row, value));
+ }
+ }
+ }
+ }
- processPartitions(partitions, topK, null);
-
+ // Convert the topK results to a PartitionIterator
List<Pair<PartitionInfo, Row>> sortedRows = new ArrayList<>(topK.size());
for (Triple<PartitionInfo, Row, ?> triple : topK.getShared())
sortedRows.add(Pair.create(triple.getLeft(), triple.getMiddle()));
-
return InMemoryPartitionIterator.create(command, sortedRows);
}
}
@@ -186,149 +173,53 @@ public PartitionIterator filter(PartitionIterator partitions)
*
* All tombstones will be kept. Caller must close the supplied iterator.
*
- * @param partitions the partitions collected in the replica side of a query
+ * @param partitions the partitions collected in the replica side of a query. It will be closed as a side-effect.
* @return the provided rows, sorted by the requested {@code ORDER BY} criteria, trimmed to {@link #limit} rows,
* and then sorted again by primary key.
*/
- public UnfilteredPartitionIterator filter(UnfilteredPartitionIterator partitions)
+ public UnfilteredPartitionIterator filter(ScoreOrderedResultRetriever partitions)
{
- // We consume the partitions iterator and create a new one. Use a try-with-resources block to ensure the
- // original iterator is closed. We do not expect exceptions here, but if they happen, we want to make sure the
- // original iterator is closed to prevent leaking resources, which could compound the effect of an exception.
try (partitions)
{
- TopKSelector<Triple<PartitionInfo, Row, ?>> topK = new TopKSelector<>(comparator(), limit);
-
- TreeMap<PartitionInfo, TreeSet<Unfiltered>> unfilteredByPartition = new TreeMap<>(Comparator.comparing(p -> p.key));
-
- processPartitions(partitions, topK, unfilteredByPartition);
-
- // Reorder the rows by primary key.
- for (var triple : topK.getUnsortedShared())
- addUnfiltered(unfilteredByPartition, triple.getLeft(), triple.getMiddle());
-
- return new InMemoryUnfilteredPartitionIterator(command, unfilteredByPartition);
- }
- }
-
- private Comparator<Triple<PartitionInfo, Row, ?>> comparator()
- {
- Comparator<Triple<PartitionInfo, Row, ?>> comparator;
- if (queryVector != null)
- {
- comparator = Comparator.comparing((Triple<PartitionInfo, Row, ?> t) -> (Float) t.getRight()).reversed();
- }
- else
- {
- comparator = Comparator.comparing(t -> (ByteBuffer) t.getRight(), indexContext.getValidator());
- if (expression.operator() == Operator.ORDER_BY_DESC)
- comparator = comparator.reversed();
- }
- return comparator;
- }
-
- private <U extends Unfiltered, R extends BaseRowIterator<U>, P extends BasePartitionIterator<R>>
- void processPartitions(P partitions,
- TopKSelector<Triple<PartitionInfo, Row, ?>> topK,
- @Nullable TreeMap<PartitionInfo, TreeSet<Unfiltered>> unfilteredByPartition)
- {
- if (PARALLEL_EXECUTOR != ImmediateExecutor.INSTANCE && partitions instanceof ParallelCommandProcessor)
- {
- ParallelCommandProcessor pIter = (ParallelCommandProcessor) partitions;
- var commands = pIter.getUninitializedCommands();
- List<CompletableFuture<PartitionResults>> results = new ArrayList<>(commands.size());
+ TreeMap<PartitionInfo, TreeSet<Unfiltered>> unfilteredByPartition = new TreeMap<>(Comparator.comparing(pi -> pi.key));
- int count = commands.size();
- for (var command: commands) {
- CompletableFuture<PartitionResults> future = new CompletableFuture<>();
- results.add(future);
-
- // run last command immediately, others in parallel (if possible)
- count--;
- var executor = count == 0 ? ImmediateExecutor.INSTANCE : PARALLEL_EXECUTOR;
-
- executor.maybeExecuteImmediately(() -> {
- try (var partitionRowIterator = pIter.commandToIterator(command.left(), command.right()))
- {
- future.complete(partitionRowIterator == null ? null : processPartition(partitionRowIterator));
- }
- catch (Throwable t)
- {
- future.completeExceptionally(t);
- }
- });
- }
-
- for (CompletableFuture<PartitionResults> triplesFuture: results)
- {
- PartitionResults pr;
- try
- {
- pr = triplesFuture.join();
- }
- catch (CompletionException t)
- {
- if (t.getCause() instanceof AbortedOperationException)
- throw (AbortedOperationException) t.getCause();
- throw t;
- }
- if (pr == null)
- continue;
- topK.addAll(pr.rows);
- if (unfilteredByPartition != null)
- {
- for (var uf : pr.tombstones)
- addUnfiltered(unfilteredByPartition, pr.partitionInfo, uf);
- }
- }
- }
- else if (partitions instanceof StorageAttachedIndexSearcher.ScoreOrderedResultRetriever)
- {
- // FilteredPartitions does not implement ParallelizablePartitionIterator.
- // Realistically, this won't benefit from parallelizm as these are coming from in-memory/memtable data.
int rowsMatched = 0;
- // Check rowsMatched first to prevent fetching one more partition than needed.
+ // Because each "partition" from ScoreOrderedResultRetriever is actually a single row
+ // or tombstone, we can simply read them until we have enough.
while (rowsMatched < limit && partitions.hasNext())
{
- // Must close to move to the next partition, otherwise hasNext() fails
- try (var partitionRowIterator = partitions.next())
+ try (BaseRowIterator<?> partitionRowIterator = partitions.next())
{
rowsMatched += processSingleRowPartition(unfilteredByPartition, partitionRowIterator);
}
}
+
+ return new InMemoryUnfilteredPartitionIterator(command, unfilteredByPartition);
}
- else
+ }
+
+ /**
+ * Constructs a comparator for triple (PartitionInfo, Row, X) used for top-K ranking.
+ * For ANN/BM25 we compare descending by X (float score). For ORDER_BY_ASC or DESC,
+ * we compare ascending/descending by the row's relevant ByteBuffer data.
+ */
+ private Comparator<Triple<PartitionInfo, Row, ?>> comparator()
+ {
+ if (expression.operator() == Operator.ANN || expression.operator() == Operator.BM25)
{
- // FilteredPartitions does not implement ParallelizablePartitionIterator.
- // Realistically, this won't benefit from parallelizm as these are coming from in-memory/memtable data.
- while (partitions.hasNext())
- {
- // have to close to move to the next partition, otherwise hasNext() fails
- try (var partitionRowIterator = partitions.next())
- {
- if (queryVector != null)
- {
- PartitionResults pr = processPartition(partitionRowIterator);
- topK.addAll(pr.rows);
- if (unfilteredByPartition != null)
- {
- for (var uf : pr.tombstones)
- addUnfiltered(unfilteredByPartition, pr.partitionInfo, uf);
- }
- }
- else
- {
- while (partitionRowIterator.hasNext())
- {
- Row row = (Row) partitionRowIterator.next();
- topK.add(Triple.of(PartitionInfo.create(partitionRowIterator), row, row.getCell(expression.column()).buffer()));
- }
- }
- }
- }
+ // For similarity, higher is better, so reversed
+ return Comparator.comparing((Triple<PartitionInfo, Row, ?> t) -> (Float) t.getRight()).reversed();
}
+
+ Comparator<Triple<PartitionInfo, Row, ?>> comparator = Comparator.comparing(t -> (ByteBuffer) t.getRight(), indexContext.getValidator());
+ if (expression.operator() == Operator.ORDER_BY_DESC)
+ comparator = comparator.reversed();
+ return comparator;
}
+ /**
+ * Simple holder for partial results of a single partition (score-based path).
+ */
private class PartitionResults
{
final PartitionInfo partitionInfo;
@@ -352,9 +243,9 @@ void addRow(Triple triple)
}
/**
- * Processes a single partition, calculating scores for rows and extracting tombstones.
+ * Processes all rows in a single partition to compute scores (for ANN or BM25)
*/
- private PartitionResults processPartition(BaseRowIterator<?> partitionRowIterator)
+ private PartitionResults processScoredPartition(BaseRowIterator<?> partitionRowIterator)
{
// Compute key and static row score once per partition
DecoratedKey key = partitionRowIterator.partitionKey();
@@ -400,15 +291,14 @@ private int processSingleRowPartition(TreeMap
return unfiltered.isRangeTombstoneMarker() ? 0 : 1;
}
- private void addUnfiltered(SortedMap<PartitionInfo, TreeSet<Unfiltered>> unfilteredByPartition, PartitionInfo partitionInfo, Unfiltered unfiltered)
+ private void addUnfiltered(SortedMap<PartitionInfo, TreeSet<Unfiltered>> unfilteredByPartition,
+ PartitionInfo partitionInfo,
+ Unfiltered unfiltered)
{
var map = unfilteredByPartition.computeIfAbsent(partitionInfo, k -> new TreeSet<>(command.metadata().comparator));
map.add(unfiltered);
}
- /**
- * Sum the scores from different vector indexes for the row
- */
private float getScoreForRow(DecoratedKey key, Row row)
{
ColumnMetadata column = indexContext.getDefinition();
@@ -422,6 +312,15 @@ private float getScoreForRow(DecoratedKey key, Row row)
if ((column.isClusteringColumn() || column.isRegular()) && row.isStatic())
return 0;
+ // If we have a synthetic score column, use it
+ var scoreData = row.getColumnData(scoreColumn);
+ if (scoreData != null)
+ {
+ var cell = (Cell<?>) scoreData;
+ return FloatType.instance.compose(cell.buffer());
+ }
+
+ // TODO remove this once we enable ANN_USE_SYNTHETIC_SCORE
ByteBuffer value = indexContext.getValueOf(key, row, FBUtilities.nowInSeconds());
if (value != null)
{
@@ -431,28 +330,30 @@ private float getScoreForRow(DecoratedKey key, Row row)
return 0;
}
-
private Pair findTopKIndexContext()
{
ColumnFamilyStore cfs = Keyspace.openAndGetStore(command.metadata());
for (RowFilter.Expression expression : command.rowFilter().expressions())
{
- StorageAttachedIndex sai = findVectorIndexFor(cfs.indexManager, expression);
+ StorageAttachedIndex sai = findOrderingIndexFor(cfs.indexManager, expression);
if (sai != null)
- {
return Pair.create(sai.getIndexContext(), expression);
- }
}
return null;
}
@Nullable
- private StorageAttachedIndex findVectorIndexFor(SecondaryIndexManager sim, RowFilter.Expression e)
+ private StorageAttachedIndex findOrderingIndexFor(SecondaryIndexManager sim, RowFilter.Expression e)
{
- if (e.operator() != Operator.ANN && e.operator() != Operator.ORDER_BY_ASC && e.operator() != Operator.ORDER_BY_DESC)
+ if (e.operator() != Operator.ANN
+ && e.operator() != Operator.BM25
+ && e.operator() != Operator.ORDER_BY_ASC
+ && e.operator() != Operator.ORDER_BY_DESC)
+ {
return null;
+ }
Optional<Index> index = sim.getBestIndexFor(e);
return (StorageAttachedIndex) index.filter(i -> i instanceof StorageAttachedIndex).orElse(null);
diff --git a/src/java/org/apache/cassandra/index/sai/utils/BM25Utils.java b/src/java/org/apache/cassandra/index/sai/utils/BM25Utils.java
new file mode 100644
index 000000000000..ae101c6f0373
--- /dev/null
+++ b/src/java/org/apache/cassandra/index/sai/utils/BM25Utils.java
@@ -0,0 +1,185 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.cassandra.index.sai.utils;
+
+import java.nio.ByteBuffer;
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.Iterator;
+import java.util.List;
+import java.util.Map;
+
+import org.apache.cassandra.db.memtable.Memtable;
+import org.apache.cassandra.db.rows.Cell;
+import org.apache.cassandra.index.sai.IndexContext;
+import org.apache.cassandra.index.sai.analyzer.AbstractAnalyzer;
+import org.apache.cassandra.io.sstable.SSTableId;
+import org.apache.cassandra.io.util.FileUtils;
+import org.apache.cassandra.utils.CloseableIterator;
+
+public class BM25Utils
+{
+ private static final float K1 = 1.2f; // BM25 term frequency saturation parameter
+ private static final float B = 0.75f; // BM25 length normalization parameter
+
+ /**
+ * Term frequencies across all documents. Each document is only counted once.
+ */
+ public static class DocStats
+ {
+ // Map of term -> count of docs containing that term
+ private final Map<ByteBuffer, Long> frequencies;
+ // total number of docs in the index
+ private final long docCount;
+
+ public DocStats(Map<ByteBuffer, Long> frequencies, long docCount)
+ {
+ this.frequencies = frequencies;
+ this.docCount = docCount;
+ }
+ }
+
+ /**
+ * Term frequencies within a single document. All instances of a term are counted.
+ */
+ public static class DocTF
+ {
+ private final PrimaryKey pk;
+ private final Map<ByteBuffer, Integer> frequencies;
+ private final int termCount;
+
+ public DocTF(PrimaryKey pk, int termCount, Map<ByteBuffer, Integer> frequencies)
+ {
+ this.pk = pk;
+ this.frequencies = frequencies;
+ this.termCount = termCount;
+ }
+
+ public int getTermFrequency(ByteBuffer term)
+ {
+ return frequencies.getOrDefault(term, 0);
+ }
+
+ public static DocTF createFromDocument(PrimaryKey pk,
+ Cell> cell,
+ AbstractAnalyzer docAnalyzer,
+ Collection<ByteBuffer> queryTerms)
+ {
+ int count = 0;
+ Map<ByteBuffer, Integer> frequencies = new HashMap<>();
+
+ docAnalyzer.reset(cell.buffer());
+ try
+ {
+ while (docAnalyzer.hasNext())
+ {
+ ByteBuffer term = docAnalyzer.next();
+ count++;
+ if (queryTerms.contains(term))
+ frequencies.merge(term, 1, Integer::sum);
+ }
+ }
+ finally
+ {
+ docAnalyzer.end();
+ }
+
+ return new DocTF(pk, count, frequencies);
+ }
+ }
+
+ public static CloseableIterator<PrimaryKeyWithSortKey> computeScores(CloseableIterator<DocTF> docIterator,
+ List<ByteBuffer> queryTerms,
+ DocStats docStats,
+ IndexContext indexContext,
+ Object source)
+ {
+ // data structures for document stats and frequencies
+ ArrayList<DocTF> documents = new ArrayList<>();
+ double totalTermCount = 0;
+
+ // Compute TF within each document
+ while (docIterator.hasNext())
+ {
+ var tf = docIterator.next();
+ documents.add(tf);
+ totalTermCount += tf.termCount;
+ }
+ if (documents.isEmpty())
+ return CloseableIterator.emptyIterator();
+
+ // Calculate average document length
+ double avgDocLength = totalTermCount / documents.size();
+
+ // Calculate BM25 scores
+ var scoredDocs = new ArrayList<PrimaryKeyWithScore>(documents.size());
+ for (var doc : documents)
+ {
+ double score = 0.0;
+ for (var queryTerm : queryTerms)
+ {
+ int tf = doc.getTermFrequency(queryTerm);
+ Long df = docStats.frequencies.get(queryTerm);
+ // we shouldn't have more hits for a term than we counted total documents
+ assert df <= docStats.docCount : String.format("df=%d, totalDocs=%d", df, docStats.docCount);
+
+ double normalizedTf = tf / (tf + K1 * (1 - B + B * doc.termCount / avgDocLength));
+ double idf = Math.log(1 + (docStats.docCount - df + 0.5) / (df + 0.5));
+ double deltaScore = normalizedTf * idf;
+ assert deltaScore >= 0 : String.format("BM25 score for tf=%d, df=%d, tc=%d, totalDocs=%d is %f",
+ tf, df, doc.termCount, docStats.docCount, deltaScore);
+ score += deltaScore;
+ }
+ if (source instanceof Memtable)
+ scoredDocs.add(new PrimaryKeyWithScore(indexContext, (Memtable) source, doc.pk, (float) score));
+ else if (source instanceof SSTableId)
+ scoredDocs.add(new PrimaryKeyWithScore(indexContext, (SSTableId) source, doc.pk, (float) score));
+ else
+ throw new IllegalArgumentException("Invalid source " + source.getClass());
+ }
+
+ // sort by score (PrimaryKeyWithScore implements compareTo as descending score for us)
+ Collections.sort(scoredDocs);
+
+ return new CloseableIterator<>()
+ {
+ private final Iterator<PrimaryKeyWithScore> iterator = scoredDocs.iterator();
+
+ @Override
+ public boolean hasNext()
+ {
+ return iterator.hasNext();
+ }
+
+ @Override
+ public PrimaryKeyWithSortKey next()
+ {
+ return iterator.next();
+ }
+
+ @Override
+ public void close()
+ {
+ FileUtils.closeQuietly(docIterator);
+ }
+ };
+ }
+}
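For readers less familiar with BM25, the scoring loop in computeScores above implements the standard Okapi BM25 formula with k1 = 1.2 and b = 0.75. A minimal standalone sketch of the per-term math it applies (same constants and formulas as the loop; the method name and parameters are illustrative):

    // Per-term BM25 contribution, mirroring the arithmetic in computeScores.
    // tf: occurrences of the term in this document
    // df: number of documents containing the term
    // docCount: total documents; docLen / avgDocLen: document length vs corpus mean
    static double bm25TermScore(int tf, long df, long docCount, int docLen, double avgDocLen)
    {
        final double K1 = 1.2; // term frequency saturation
        final double B = 0.75; // length normalization
        double normalizedTf = tf / (tf + K1 * (1 - B + B * docLen / avgDocLen));
        double idf = Math.log(1 + (docCount - df + 0.5) / (df + 0.5));
        return normalizedTf * idf;
    }

A document's total score is the sum of this quantity over all query terms, which is why the assertion above can insist that each per-term contribution is non-negative.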
diff --git a/src/java/org/apache/cassandra/index/sai/utils/PrimaryKey.java b/src/java/org/apache/cassandra/index/sai/utils/PrimaryKey.java
index 516e26caaeea..0bcd16c0a7f3 100644
--- a/src/java/org/apache/cassandra/index/sai/utils/PrimaryKey.java
+++ b/src/java/org/apache/cassandra/index/sai/utils/PrimaryKey.java
@@ -19,6 +19,7 @@
import java.util.function.Supplier;
+import io.github.jbellis.jvector.util.Accountable;
import org.apache.cassandra.db.Clustering;
import org.apache.cassandra.db.ClusteringComparator;
import org.apache.cassandra.db.DecoratedKey;
@@ -38,7 +39,7 @@
* For the V2 on-disk format the {@link DecoratedKey} and {@link Clustering} are supported.
*
*/
-public interface PrimaryKey extends Comparable<PrimaryKey>
+public interface PrimaryKey extends Comparable<PrimaryKey>, Accountable
{
/**
* A factory for creating {@link PrimaryKey} instances
diff --git a/src/java/org/apache/cassandra/index/sai/utils/PrimaryKeyWithByteComparable.java b/src/java/org/apache/cassandra/index/sai/utils/PrimaryKeyWithByteComparable.java
index 837c032b7952..bb931c942fbf 100644
--- a/src/java/org/apache/cassandra/index/sai/utils/PrimaryKeyWithByteComparable.java
+++ b/src/java/org/apache/cassandra/index/sai/utils/PrimaryKeyWithByteComparable.java
@@ -21,7 +21,10 @@
import java.nio.ByteBuffer;
import java.util.Arrays;
+import io.github.jbellis.jvector.util.RamUsageEstimator;
+import org.apache.cassandra.db.memtable.Memtable;
import org.apache.cassandra.index.sai.IndexContext;
+import org.apache.cassandra.io.sstable.SSTableId;
import org.apache.cassandra.utils.bytecomparable.ByteComparable;
import org.apache.cassandra.utils.bytecomparable.ByteSource;
import org.apache.cassandra.utils.bytecomparable.ByteSourceInverse;
@@ -34,7 +37,13 @@ public class PrimaryKeyWithByteComparable extends PrimaryKeyWithSortKey
{
private final ByteComparable byteComparable;
- public PrimaryKeyWithByteComparable(IndexContext context, Object sourceTable, PrimaryKey primaryKey, ByteComparable byteComparable)
+ public PrimaryKeyWithByteComparable(IndexContext context, Memtable sourceTable, PrimaryKey primaryKey, ByteComparable byteComparable)
+ {
+ super(context, sourceTable, primaryKey);
+ this.byteComparable = byteComparable;
+ }
+
+ public PrimaryKeyWithByteComparable(IndexContext context, SSTableId sourceTable, PrimaryKey primaryKey, ByteComparable byteComparable)
{
super(context, sourceTable, primaryKey);
this.byteComparable = byteComparable;
@@ -65,4 +74,10 @@ public int compareTo(PrimaryKey o)
return ByteComparable.compare(byteComparable, ((PrimaryKeyWithByteComparable) o).byteComparable, TypeUtil.BYTE_COMPARABLE_VERSION);
}
+
+ @Override
+ public long ramBytesUsed()
+ {
+ return super.ramBytesUsed() + RamUsageEstimator.NUM_BYTES_OBJECT_REF;
+ }
}
diff --git a/src/java/org/apache/cassandra/index/sai/utils/PrimaryKeyWithScore.java b/src/java/org/apache/cassandra/index/sai/utils/PrimaryKeyWithScore.java
index 8b7b4acba9c6..b88c210d65f2 100644
--- a/src/java/org/apache/cassandra/index/sai/utils/PrimaryKeyWithScore.java
+++ b/src/java/org/apache/cassandra/index/sai/utils/PrimaryKeyWithScore.java
@@ -20,7 +20,9 @@
import java.nio.ByteBuffer;
+import org.apache.cassandra.db.memtable.Memtable;
import org.apache.cassandra.index.sai.IndexContext;
+import org.apache.cassandra.io.sstable.SSTableId;
/**
* A {@link PrimaryKey} that includes a score from a source index.
@@ -28,9 +30,15 @@
*/
public class PrimaryKeyWithScore extends PrimaryKeyWithSortKey
{
- private final float indexScore;
+ public final float indexScore;
- public PrimaryKeyWithScore(IndexContext context, Object source, PrimaryKey primaryKey, float indexScore)
+ public PrimaryKeyWithScore(IndexContext context, Memtable source, PrimaryKey primaryKey, float indexScore)
+ {
+ super(context, source, primaryKey);
+ this.indexScore = indexScore;
+ }
+
+ public PrimaryKeyWithScore(IndexContext context, SSTableId source, PrimaryKey primaryKey, float indexScore)
{
super(context, source, primaryKey);
this.indexScore = indexScore;
@@ -53,4 +61,11 @@ public int compareTo(PrimaryKey o)
// Descending order
return Float.compare(((PrimaryKeyWithScore) o).indexScore, indexScore);
}
+
+ @Override
+ public long ramBytesUsed()
+ {
+ // Include super class fields plus float value
+ return super.ramBytesUsed() + Float.BYTES;
+ }
}
diff --git a/src/java/org/apache/cassandra/index/sai/utils/PrimaryKeyWithSortKey.java b/src/java/org/apache/cassandra/index/sai/utils/PrimaryKeyWithSortKey.java
index 2e79b0402124..b3a6fb4338e5 100644
--- a/src/java/org/apache/cassandra/index/sai/utils/PrimaryKeyWithSortKey.java
+++ b/src/java/org/apache/cassandra/index/sai/utils/PrimaryKeyWithSortKey.java
@@ -20,11 +20,14 @@
import java.nio.ByteBuffer;
+import io.github.jbellis.jvector.util.RamUsageEstimator;
import org.apache.cassandra.db.Clustering;
import org.apache.cassandra.db.DecoratedKey;
+import org.apache.cassandra.db.memtable.Memtable;
import org.apache.cassandra.db.rows.Row;
import org.apache.cassandra.dht.Token;
import org.apache.cassandra.index.sai.IndexContext;
+import org.apache.cassandra.io.sstable.SSTableId;
import org.apache.cassandra.utils.bytecomparable.ByteComparable;
import org.apache.cassandra.utils.bytecomparable.ByteSource;
@@ -41,7 +44,14 @@ public abstract class PrimaryKeyWithSortKey implements PrimaryKey
// Either a Memtable reference or an SSTableId reference
private final Object sourceTable;
- protected PrimaryKeyWithSortKey(IndexContext context, Object sourceTable, PrimaryKey primaryKey)
+ protected PrimaryKeyWithSortKey(IndexContext context, Memtable sourceTable, PrimaryKey primaryKey)
+ {
+ this.context = context;
+ this.sourceTable = sourceTable;
+ this.primaryKey = primaryKey;
+ }
+
+ protected PrimaryKeyWithSortKey(IndexContext context, SSTableId sourceTable, PrimaryKey primaryKey)
{
this.context = context;
this.sourceTable = sourceTable;
@@ -137,4 +147,13 @@ public ByteSource asComparableBytesMaxPrefix(ByteComparable.Version version)
return primaryKey.asComparableBytesMaxPrefix(version);
}
+ @Override
+ public long ramBytesUsed()
+ {
+ // Object header + 3 references (context, primaryKey, sourceTable)
+ return RamUsageEstimator.NUM_BYTES_OBJECT_HEADER +
+ 3L * RamUsageEstimator.NUM_BYTES_OBJECT_REF +
+ primaryKey.ramBytesUsed();
+ }
+
}
diff --git a/src/java/org/apache/cassandra/index/sai/utils/PrimaryKeys.java b/src/java/org/apache/cassandra/index/sai/utils/PrimaryKeys.java
index 173dc045965f..75298ac217da 100644
--- a/src/java/org/apache/cassandra/index/sai/utils/PrimaryKeys.java
+++ b/src/java/org/apache/cassandra/index/sai/utils/PrimaryKeys.java
@@ -41,6 +41,7 @@ public class PrimaryKeys implements Iterable
* Adds the specified {@link PrimaryKey}.
*
* @param key a primary key
+ * @return the bytes allocated for the key (0 if it already existed in the set)
*/
public long add(PrimaryKey key)
{
diff --git a/src/java/org/apache/cassandra/index/sai/utils/RowIdWithScore.java b/src/java/org/apache/cassandra/index/sai/utils/RowIdWithScore.java
index b4b1f129567f..c6ed4708e0c8 100644
--- a/src/java/org/apache/cassandra/index/sai/utils/RowIdWithScore.java
+++ b/src/java/org/apache/cassandra/index/sai/utils/RowIdWithScore.java
@@ -26,7 +26,7 @@
*/
public class RowIdWithScore extends RowIdWithMeta
{
- private final float score;
+ public final float score;
public RowIdWithScore(int segmentRowId, float score)
{
diff --git a/src/java/org/apache/cassandra/index/sai/utils/RowWithSourceTable.java b/src/java/org/apache/cassandra/index/sai/utils/RowWithSourceTable.java
index 3064c701fc56..d93198a94d1a 100644
--- a/src/java/org/apache/cassandra/index/sai/utils/RowWithSourceTable.java
+++ b/src/java/org/apache/cassandra/index/sai/utils/RowWithSourceTable.java
@@ -375,4 +375,13 @@ private Row maybeWrapRow(Row r)
return this;
return new RowWithSourceTable(r, source);
}
+
+ @Override
+ public String toString()
+ {
+ return "RowWithSourceTable{" +
+ row +
+ ", source=" + source +
+ '}';
+ }
}
diff --git a/src/java/org/apache/cassandra/schema/ColumnMetadata.java b/src/java/org/apache/cassandra/schema/ColumnMetadata.java
index 38784e72acb5..008813564e57 100644
--- a/src/java/org/apache/cassandra/schema/ColumnMetadata.java
+++ b/src/java/org/apache/cassandra/schema/ColumnMetadata.java
@@ -71,9 +71,9 @@ public enum ClusteringOrder
/**
* The type of CQL3 column this definition represents.
- * There is 4 main type of CQL3 columns: those parts of the partition key,
- * those parts of the clustering columns and amongst the others, regular and
- * static ones.
+ * There are 5 types of columns: those that are part of the partition key,
+ * those that are part of the clustering columns and, amongst the others, regular,
+ * static, and synthetic ones.
*
* IMPORTANT: this enum is serialized as toString() and deserialized by calling
* Kind.valueOf(), so do not override toString() or rename existing values.
@@ -81,18 +81,22 @@ public enum ClusteringOrder
public enum Kind
{
// NOTE: if adding a new type, must modify comparisonOrder
+ SYNTHETIC,
PARTITION_KEY,
CLUSTERING,
REGULAR,
STATIC;
+ // it is not possible to add new Kinds after SYNTHETIC without invasive changes to BTreeRow, which
+ // assumes that complex regular/static columns are the last ones
public boolean isPrimaryKeyKind()
{
return this == PARTITION_KEY || this == CLUSTERING;
}
-
}
+ public static final ColumnIdentifier SYNTHETIC_SCORE_ID = ColumnIdentifier.getInterned("+:!score", true);
+
/**
* Whether this is a dropped column.
*/
@@ -121,10 +125,17 @@ public boolean isPrimaryKeyKind()
*/
private final long comparisonOrder;
+ /**
+ * Bit layout (from most to least significant):
+ * - Bits 61-63: Kind ordinal (3 bits, supporting up to 8 Kind values)
+ * - Bit 60: isComplex flag
+ * - Bits 48-59: position (12 bits, see assert)
+ * - Bits 0-47: name.prefixComparison (shifted right by 16)
+ */
private static long comparisonOrder(Kind kind, boolean isComplex, long position, ColumnIdentifier name)
{
assert position >= 0 && position < 1 << 12;
- return (((long) kind.ordinal()) << 61)
+ return (((long) kind.ordinal()) << 61)
| (isComplex ? 1L << 60 : 0)
| (position << 48)
| (name.prefixComparison >>> 16);
@@ -170,6 +181,14 @@ public static ColumnMetadata staticColumn(String keyspace, String table, String
return new ColumnMetadata(keyspace, table, ColumnIdentifier.getInterned(name, true), type, NO_POSITION, Kind.STATIC);
}
+ /**
+ * Creates a new synthetic column metadata instance.
+ */
+ public static ColumnMetadata syntheticColumn(String keyspace, String table, ColumnIdentifier id, AbstractType<?> type)
+ {
+ return new ColumnMetadata(keyspace, table, id, type, NO_POSITION, Kind.SYNTHETIC);
+ }
+
/**
* Rebuild the metadata for a dropped column from its recorded data.
*
@@ -225,6 +244,7 @@ public ColumnMetadata(String ksName,
this.kind = kind;
this.position = position;
this.cellPathComparator = makeCellPathComparator(kind, type);
+ assert kind != Kind.SYNTHETIC || cellPathComparator == null;
this.cellComparator = cellPathComparator == null ? ColumnData.comparator : new Comparator<Cell<?>>()
{
@Override
@@ -461,7 +481,7 @@ public int compareTo(ColumnMetadata other)
return 0;
if (comparisonOrder != other.comparisonOrder)
- return Long.compare(comparisonOrder, other.comparisonOrder);
+ return Long.compareUnsigned(comparisonOrder, other.comparisonOrder);
return this.name.compareTo(other.name);
}
@@ -593,6 +613,11 @@ public boolean isCounterColumn()
return type.isCounter();
}
+ public boolean isSynthetic()
+ {
+ return kind == Kind.SYNTHETIC;
+ }
+
public Selector.Factory newSelectorFactory(TableMetadata table, AbstractType> expectedType, List defs, VariableSpecifications boundNames) throws InvalidRequestException
{
return SimpleSelector.newFactory(this, addAndGetIndex(this, defs));
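The change from Long.compare to Long.compareUnsigned above is load-bearing: with SYNTHETIC inserted at ordinal 0, STATIC moves to ordinal 4, and 4L << 61 sets the sign bit of comparisonOrder, so a signed comparison would sort STATIC columns before every other kind. A small illustrative check (not part of the patch):

    long syntheticOrder = 0L << 61; // SYNTHETIC, ordinal 0
    long staticOrder = 4L << 61;    // STATIC, ordinal 4 -- negative as a signed long
    assert Long.compare(syntheticOrder, staticOrder) > 0;          // signed: wrongly puts STATIC first
    assert Long.compareUnsigned(syntheticOrder, staticOrder) < 0;  // unsigned: preserves declaration order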
diff --git a/src/java/org/apache/cassandra/schema/TableMetadata.java b/src/java/org/apache/cassandra/schema/TableMetadata.java
index ba5e1db8d84d..8885b85fdd43 100644
--- a/src/java/org/apache/cassandra/schema/TableMetadata.java
+++ b/src/java/org/apache/cassandra/schema/TableMetadata.java
@@ -1122,8 +1122,7 @@ public Builder addStaticColumn(ColumnIdentifier name, AbstractType type)
public Builder addColumn(ColumnMetadata column)
{
- if (columns.containsKey(column.name.bytes))
- throw new IllegalArgumentException();
+ assert !columns.containsKey(column.name.bytes) : column.name + " is already present";
switch (column.kind)
{
diff --git a/src/java/org/apache/cassandra/service/ClientWarn.java b/src/java/org/apache/cassandra/service/ClientWarn.java
index 5a6a878681e1..38570a06d2b8 100644
--- a/src/java/org/apache/cassandra/service/ClientWarn.java
+++ b/src/java/org/apache/cassandra/service/ClientWarn.java
@@ -18,7 +18,9 @@
package org.apache.cassandra.service;
import java.util.ArrayList;
+import java.util.HashSet;
import java.util.List;
+import java.util.Set;
import io.netty.util.concurrent.FastThreadLocal;
import org.apache.cassandra.concurrent.ExecutorLocal;
@@ -45,10 +47,18 @@ public void set(State value)
}
public void warn(String text)
+ {
+ warn(text, null);
+ }
+
+ /**
+ * Issue the given warning if this is the first time `key` is seen.
+ */
+ public void warn(String text, Object key)
{
State state = warnLocal.get();
if (state != null)
- state.add(text);
+ state.add(text, key);
}
public void captureWarnings()
@@ -72,11 +82,16 @@ public void resetWarnings()
public static class State
{
private final List<String> warnings = new ArrayList<>();
+ private final Set<Object> keysAdded = new HashSet<>();
- private void add(String warning)
+ private void add(String warning, Object key)
{
if (warnings.size() < FBUtilities.MAX_UNSIGNED_SHORT)
+ {
+ if (key != null && !keysAdded.add(key))
+ return;
warnings.add(maybeTruncate(warning));
+ }
}
private static String maybeTruncate(String warning)
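A hypothetical call site for the new keyed overload, which emits a given warning at most once per captured request no matter how many times it is triggered (indexName and the message text are illustrative):

    ClientWarn.instance.warn("Query on " + indexName + " took a degraded path", indexName);
    // A second call with the same key within the same request is suppressed:
    ClientWarn.instance.warn("Query on " + indexName + " took a degraded path", indexName);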
diff --git a/test/burn/org/apache/cassandra/index/sai/LongBM25Test.java b/test/burn/org/apache/cassandra/index/sai/LongBM25Test.java
new file mode 100644
index 000000000000..49b5b5118540
--- /dev/null
+++ b/test/burn/org/apache/cassandra/index/sai/LongBM25Test.java
@@ -0,0 +1,251 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.cassandra.index.sai;
+
+import java.io.IOException;
+import java.net.URISyntaxException;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.List;
+import java.util.Map;
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.ForkJoinPool;
+import java.util.concurrent.ThreadLocalRandom;
+import java.util.concurrent.atomic.AtomicInteger;
+import java.util.stream.Collectors;
+import java.util.stream.IntStream;
+
+import org.junit.Before;
+import org.junit.Test;
+import org.slf4j.Logger;
+
+import org.apache.cassandra.db.memtable.TrieMemtable;
+
+public class LongBM25Test extends SAITester
+{
+ private static final Logger logger = org.slf4j.LoggerFactory.getLogger(LongBM25Test.class);
+
+ private static final List<String> documentLines = new ArrayList<>();
+
+ static
+ {
+ try
+ {
+ var cl = LongBM25Test.class.getClassLoader();
+ var resourceDir = cl.getResource("bm25");
+ if (resourceDir == null)
+ throw new RuntimeException("Could not find resource directory test/resources/bm25/");
+
+ var dirPath = java.nio.file.Paths.get(resourceDir.toURI());
+ try (var files = java.nio.file.Files.list(dirPath))
+ {
+ files.forEach(file -> {
+ try (var lines = java.nio.file.Files.lines(file))
+ {
+ lines.map(String::trim)
+ .filter(line -> !line.isEmpty())
+ .forEach(documentLines::add);
+ }
+ catch (IOException e)
+ {
+ throw new RuntimeException("Failed to read file: " + file, e);
+ }
+ });
+ }
+ if (documentLines.isEmpty())
+ {
+ throw new RuntimeException("No document lines loaded from test/resources/bm25/");
+ }
+ }
+ catch (IOException | URISyntaxException e)
+ {
+ throw new RuntimeException("Failed to load test documents", e);
+ }
+ }
+
+ KeySet keysInserted = new KeySet();
+ private final int threadCount = 12;
+
+ @Before
+ public void setup() throws Throwable
+ {
+ // we don't get loaded until after TM, so we can't affect the very first memtable,
+ // but this will affect all subsequent ones
+ TrieMemtable.SHARD_COUNT = 4 * threadCount;
+ }
+
+ @FunctionalInterface
+ private interface Op
+ {
+ void run(int i) throws Throwable;
+ }
+
+ public void testConcurrentOps(Op op) throws ExecutionException, InterruptedException
+ {
+ createTable("CREATE TABLE %s (key int primary key, value text)");
+ // Create analyzed index following BM25Test pattern
+ createIndex("CREATE CUSTOM INDEX ON %s(value) " +
+ "USING 'org.apache.cassandra.index.sai.StorageAttachedIndex' " +
+ "WITH OPTIONS = {" +
+ "'index_analyzer': '{" +
+ "\"tokenizer\" : {\"name\" : \"standard\"}, " +
+ "\"filters\" : [{\"name\" : \"porterstem\"}]" +
+ "}'}"
+ );
+
+ AtomicInteger counter = new AtomicInteger();
+ long start = System.currentTimeMillis();
+ var fjp = new ForkJoinPool(threadCount);
+ var keys = IntStream.range(0, 10_000_000).boxed().collect(Collectors.toList());
+ Collections.shuffle(keys);
+ var task = fjp.submit(() -> keys.stream().parallel().forEach(i ->
+ {
+ wrappedOp(op, i);
+ if (counter.incrementAndGet() % 10_000 == 0)
+ {
+ var elapsed = System.currentTimeMillis() - start;
+ logger.info("{} ops in {}ms = {} ops/s", counter.get(), elapsed, counter.get() * 1000.0 / elapsed);
+ }
+ if (ThreadLocalRandom.current().nextDouble() < 0.001)
+ flush();
+ }));
+ fjp.shutdown();
+ task.get(); // re-throw
+ }
+
+ private static void wrappedOp(Op op, Integer i)
+ {
+ try
+ {
+ op.run(i);
+ }
+ catch (Throwable e)
+ {
+ throw new RuntimeException(e);
+ }
+ }
+
+ private static String randomDocument()
+ {
+ var R = ThreadLocalRandom.current();
+ int numLines = R.nextInt(5, 51); // 5 to 50 lines inclusive
+ var selectedLines = new ArrayList<String>();
+
+ for (int i = 0; i < numLines; i++)
+ {
+ selectedLines.add(randomQuery(R));
+ }
+
+ return String.join("\n", selectedLines);
+ }
+
+ private static String randomLine(ThreadLocalRandom R)
+ {
+ return documentLines.get(R.nextInt(documentLines.size()));
+ }
+
+ @Test
+ public void testConcurrentReadsWritesDeletes() throws ExecutionException, InterruptedException
+ {
+ testConcurrentOps(i -> {
+ var R = ThreadLocalRandom.current();
+ if (R.nextDouble() < 0.2 || keysInserted.isEmpty())
+ {
+ var doc = randomDocument();
+ execute("INSERT INTO %s (key, value) VALUES (?, ?)", i, doc);
+ keysInserted.add(i);
+ }
+ else if (R.nextDouble() < 0.1)
+ {
+ var key = keysInserted.getRandom();
+ execute("DELETE FROM %s WHERE key = ?", key);
+ }
+ else
+ {
+ var line = randomQuery(R);
+ execute("SELECT * FROM %s ORDER BY value BM25 OF ? LIMIT ?", line, R.nextInt(1, 100));
+ }
+ });
+ }
+
+ private static String randomQuery(ThreadLocalRandom R)
+ {
+ while (true)
+ {
+ var line = randomLine(R);
+ if (line.chars().anyMatch(Character::isAlphabetic))
+ return line;
+ }
+ }
+
+ @Test
+ public void testConcurrentReadsWrites() throws ExecutionException, InterruptedException
+ {
+ testConcurrentOps(i -> {
+ var R = ThreadLocalRandom.current();
+ if (R.nextDouble() < 0.1 || keysInserted.isEmpty())
+ {
+ var doc = randomDocument();
+ execute("INSERT INTO %s (key, value) VALUES (?, ?)", i, doc);
+ keysInserted.add(i);
+ }
+ else
+ {
+ var line = randomQuery(R);
+ execute("SELECT * FROM %s ORDER BY value BM25 OF ? LIMIT ?", line, R.nextInt(1, 100));
+ }
+ });
+ }
+
+ @Test
+ public void testConcurrentWrites() throws ExecutionException, InterruptedException
+ {
+ testConcurrentOps(i -> {
+ var doc = randomDocument();
+ execute("INSERT INTO %s (key, value) VALUES (?, ?)", i, doc);
+ });
+ }
+
+ private static class KeySet
+ {
+ private final Map<Integer, Integer> keys = new ConcurrentHashMap<>();
+ private final AtomicInteger ordinal = new AtomicInteger();
+
+ public void add(int key)
+ {
+ var i = ordinal.getAndIncrement();
+ keys.put(i, key);
+ }
+
+ public int getRandom()
+ {
+ if (isEmpty())
+ throw new IllegalStateException();
+ var i = ThreadLocalRandom.current().nextInt(ordinal.get());
+ // in case there is a race with add(key), retry with another random index
+ return keys.containsKey(i) ? keys.get(i) : getRandom();
+ }
+
+ public boolean isEmpty()
+ {
+ return keys.isEmpty();
+ }
+ }
+}
diff --git a/test/distributed/org/apache/cassandra/distributed/test/sai/BM25DistributedTest.java b/test/distributed/org/apache/cassandra/distributed/test/sai/BM25DistributedTest.java
new file mode 100644
index 000000000000..95c7587ea044
--- /dev/null
+++ b/test/distributed/org/apache/cassandra/distributed/test/sai/BM25DistributedTest.java
@@ -0,0 +1,123 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.cassandra.distributed.test.sai;
+
+import java.util.concurrent.atomic.AtomicInteger;
+
+import org.junit.AfterClass;
+import org.junit.Before;
+import org.junit.BeforeClass;
+import org.junit.Test;
+
+import org.apache.cassandra.distributed.Cluster;
+import org.apache.cassandra.distributed.api.ConsistencyLevel;
+import org.apache.cassandra.distributed.test.TestBaseImpl;
+import org.apache.cassandra.index.sai.disk.format.Version;
+
+import static org.apache.cassandra.distributed.api.Feature.GOSSIP;
+import static org.apache.cassandra.distributed.api.Feature.NETWORK;
+import static org.assertj.core.api.Assertions.assertThat;
+
+public class BM25DistributedTest extends TestBaseImpl
+{
+ private static final String CREATE_KEYSPACE = "CREATE KEYSPACE %%s WITH replication = {'class': 'SimpleStrategy', 'replication_factor': %d}";
+ private static final String CREATE_TABLE = "CREATE TABLE %s (k int PRIMARY KEY, v text)";
+ private static final String CREATE_INDEX = "CREATE CUSTOM INDEX ON %s(v) USING 'StorageAttachedIndex' WITH OPTIONS = {'index_analyzer': '{\"tokenizer\" : {\"name\" : \"standard\"}, \"filters\" : [{\"name\" : \"porterstem\"}]}'}";
+
+ // To get consistent results from BM25 we need to know which docs are evaluated; the easiest way
+ // to do that is to put all the docs on every replica
+ private static final int NUM_NODES = 3;
+ private static final int RF = 3;
+
+ private static Cluster cluster;
+ private static String table;
+
+ private static final AtomicInteger seq = new AtomicInteger();
+
+ @BeforeClass
+ public static void setupCluster() throws Exception
+ {
+ cluster = Cluster.build(NUM_NODES)
+ .withTokenCount(1)
+ .withDataDirCount(1)
+ .withConfig(config -> config.with(GOSSIP).with(NETWORK))
+ .start();
+
+ cluster.schemaChange(withKeyspace(String.format(CREATE_KEYSPACE, RF)));
+ cluster.forEach(i -> i.runOnInstance(() -> org.apache.cassandra.index.sai.SAIUtil.setLatestVersion(Version.EC)));
+ }
+
+ @AfterClass
+ public static void closeCluster()
+ {
+ if (cluster != null)
+ cluster.close();
+ }
+
+ @Before
+ public void before()
+ {
+ table = "table_" + seq.getAndIncrement();
+ cluster.schemaChange(formatQuery(CREATE_TABLE));
+ cluster.schemaChange(formatQuery(CREATE_INDEX));
+ SAIUtil.waitForIndexQueryable(cluster, KEYSPACE);
+ }
+
+ @Test
+ public void testTermFrequencyOrdering()
+ {
+ // Insert documents with varying frequencies of the term "apple"
+ execute("INSERT INTO %s (k, v) VALUES (1, 'apple')");
+ execute("INSERT INTO %s (k, v) VALUES (2, 'apple apple')");
+ execute("INSERT INTO %s (k, v) VALUES (3, 'apple apple apple')");
+
+ // Query memtable index
+ assertBM25Ordering();
+
+ // Flush and query on-disk index
+ cluster.forEach(n -> n.flush(KEYSPACE));
+ assertBM25Ordering();
+ }
+
+ private void assertBM25Ordering()
+ {
+ Object[][] result = execute("SELECT k FROM %s ORDER BY v BM25 OF 'apple' LIMIT 3");
+ assertThat(result).hasNumberOfRows(3);
+
+ // Results should be ordered by term frequency (highest to lowest)
+ assertThat((Integer) result[0][0]).isEqualTo(3); // 3 occurrences
+ assertThat((Integer) result[1][0]).isEqualTo(2); // 2 occurrences
+ assertThat((Integer) result[2][0]).isEqualTo(1); // 1 occurrence
+ }
+
+ private static Object[][] execute(String query)
+ {
+ return execute(query, ConsistencyLevel.QUORUM);
+ }
+
+ private static Object[][] execute(String query, ConsistencyLevel consistencyLevel)
+ {
+ return cluster.coordinator(1).execute(formatQuery(query), consistencyLevel);
+ }
+
+ private static String formatQuery(String query)
+ {
+ return String.format(query, KEYSPACE + '.' + table);
+ }
+}
diff --git a/test/unit/org/apache/cassandra/cql3/CQLTester.java b/test/unit/org/apache/cassandra/cql3/CQLTester.java
index 38b937041caf..aec7238c9348 100644
--- a/test/unit/org/apache/cassandra/cql3/CQLTester.java
+++ b/test/unit/org/apache/cassandra/cql3/CQLTester.java
@@ -768,6 +768,11 @@ protected String currentIndex()
return indexes.get(indexes.size() - 1);
}
+ protected String getIndex(int i)
+ {
+ return indexes.get(i);
+ }
+
protected Collection currentTables()
{
if (tables == null || tables.isEmpty())
diff --git a/test/unit/org/apache/cassandra/cql3/validation/operations/SelectOrderByTest.java b/test/unit/org/apache/cassandra/cql3/validation/operations/SelectOrderByTest.java
index dac86c76e5d8..34f7d606ce55 100644
--- a/test/unit/org/apache/cassandra/cql3/validation/operations/SelectOrderByTest.java
+++ b/test/unit/org/apache/cassandra/cql3/validation/operations/SelectOrderByTest.java
@@ -659,7 +659,7 @@ public void testAllowSkippingEqualityAndSingleValueInRestrictedClusteringColumns
assertInvalidMessage("Cannot combine clustering column ordering with non-clustering column ordering",
"SELECT * FROM %s WHERE a=? ORDER BY b ASC, c ASC, d ASC", 0);
- String errorMsg = "Order by currently only supports the ordering of columns following their declared order in the PRIMARY KEY";
+ String errorMsg = "Ordering by clustered columns must follow the declared order in the PRIMARY KEY";
assertRows(execute("SELECT * FROM %s WHERE a=? AND b=? ORDER BY c", 0, 0),
row(0, 0, 0, 0),
diff --git a/test/unit/org/apache/cassandra/index/sai/cql/BM25Test.java b/test/unit/org/apache/cassandra/index/sai/cql/BM25Test.java
new file mode 100644
index 000000000000..3a40e075114a
--- /dev/null
+++ b/test/unit/org/apache/cassandra/index/sai/cql/BM25Test.java
@@ -0,0 +1,510 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.cassandra.index.sai.cql;
+
+import org.junit.Before;
+import org.junit.Test;
+
+import org.apache.cassandra.index.sai.SAITester;
+import org.apache.cassandra.index.sai.SAIUtil;
+import org.apache.cassandra.index.sai.disk.format.Version;
+import org.apache.cassandra.index.sai.disk.v1.SegmentBuilder;
+import org.apache.cassandra.index.sai.plan.QueryController;
+
+import static org.apache.cassandra.index.sai.analyzer.AnalyzerEqOperatorSupport.EQ_AMBIGUOUS_ERROR;
+import static org.assertj.core.api.AssertionsForInterfaceTypes.assertThat;
+
+public class BM25Test extends SAITester
+{
+ @Before
+ public void setup() throws Throwable
+ {
+ SAIUtil.setLatestVersion(Version.EC);
+ }
+
+ @Test
+ public void testTwoIndexes()
+ {
+ // create un-analyzed index
+ createTable("CREATE TABLE %s (k int PRIMARY KEY, v text)");
+ createIndex("CREATE CUSTOM INDEX ON %s(v) USING 'org.apache.cassandra.index.sai.StorageAttachedIndex'");
+ execute("INSERT INTO %s (k, v) VALUES (1, 'apple')");
+
+ // BM25 should fail with only an equality index
+ assertInvalidMessage("BM25 ordering on column v requires an analyzed index",
+ "SELECT k FROM %s WHERE v : 'apple' ORDER BY v BM25 OF 'apple' LIMIT 3");
+
+ // create analyzed index
+ analyzeIndex();
+ // BM25 query should work now
+ var result = execute("SELECT k FROM %s WHERE v : 'apple' ORDER BY v BM25 OF 'apple' LIMIT 3");
+ assertRows(result, row(1));
+ }
+
+ @Test
+ public void testTwoIndexesAmbiguousPredicate() throws Throwable
+ {
+ createTable("CREATE TABLE %s (k int PRIMARY KEY, v text)");
+
+ // Create analyzed and un-analyzed indexes
+ analyzeIndex();
+ createIndex("CREATE CUSTOM INDEX ON %s(v) USING 'org.apache.cassandra.index.sai.StorageAttachedIndex'");
+
+ execute("INSERT INTO %s (k, v) VALUES (1, 'apple')");
+ execute("INSERT INTO %s (k, v) VALUES (2, 'apple juice')");
+ execute("INSERT INTO %s (k, v) VALUES (3, 'orange juice')");
+
+ // equality predicate is ambiguous (both analyzed and un-analyzed indexes could support it) so it should
+ // be rejected
+ beforeAndAfterFlush(() -> {
+ // Single predicate
+ assertInvalidMessage(String.format(EQ_AMBIGUOUS_ERROR, "v", getIndex(0), getIndex(1)),
+ "SELECT k FROM %s WHERE v = 'apple'");
+
+ // AND
+ assertInvalidMessage(String.format(EQ_AMBIGUOUS_ERROR, "v", getIndex(0), getIndex(1)),
+ "SELECT k FROM %s WHERE v = 'apple' AND v : 'juice'");
+
+ // OR
+ assertInvalidMessage(String.format(EQ_AMBIGUOUS_ERROR, "v", getIndex(0), getIndex(1)),
+ "SELECT k FROM %s WHERE v = 'apple' OR v : 'juice'");
+ });
+ }
+
+ @Test
+ public void testTwoIndexesWithEqualsUnsupported() throws Throwable
+ {
+ createTable("CREATE TABLE %s (k int PRIMARY KEY, v text)");
+ createIndex("CREATE CUSTOM INDEX ON %s(v) USING 'org.apache.cassandra.index.sai.StorageAttachedIndex'");
+ // analyzed index with equals_behavior:unsupported option
+ createIndex("CREATE CUSTOM INDEX ON %s(v) USING 'org.apache.cassandra.index.sai.StorageAttachedIndex' " +
+ "WITH OPTIONS = { 'equals_behaviour_when_analyzed': 'unsupported', " +
+ "'index_analyzer':'{\"tokenizer\":{\"name\":\"standard\"},\"filters\":[{\"name\":\"porterstem\"}]}' }");
+
+ execute("INSERT INTO %s (k, v) VALUES (1, 'apple')");
+ execute("INSERT INTO %s (k, v) VALUES (2, 'apple juice')");
+
+ beforeAndAfterFlush(() -> {
+ // combining two EQ predicates is not allowed
+ assertInvalid("SELECT k FROM %s WHERE v = 'apple' AND v = 'juice'");
+
+ // combining EQ and MATCH predicates is also not allowed (when we're not converting EQ to MATCH)
+ assertInvalid("SELECT k FROM %s WHERE v = 'apple' AND v : 'apple'");
+
+ // combining two MATCH predicates is fine
+ assertRows(execute("SELECT k FROM %s WHERE v : 'apple' AND v : 'juice'"),
+ row(2));
+
+ // = operator should use un-analyzed index since equals is unsupported in analyzed index
+ assertRows(execute("SELECT k FROM %s WHERE v = 'apple'"),
+ row(1));
+
+ // : operator should use analyzed index
+ assertRows(execute("SELECT k FROM %s WHERE v : 'apple'"),
+ row(1), row(2));
+ });
+ }
+
+ @Test
+ public void testComplexQueriesWithMultipleIndexes() throws Throwable
+ {
+ createTable("CREATE TABLE %s (k int PRIMARY KEY, v1 text, v2 text, v3 int)");
+
+ // Create mix of analyzed, unanalyzed, and non-text indexes
+ createIndex("CREATE CUSTOM INDEX ON %s(v1) USING 'org.apache.cassandra.index.sai.StorageAttachedIndex'");
+ createIndex("CREATE CUSTOM INDEX ON %s(v2) " +
+ "USING 'org.apache.cassandra.index.sai.StorageAttachedIndex' " +
+ "WITH OPTIONS = {" +
+ "'index_analyzer': '{" +
+ "\"tokenizer\" : {\"name\" : \"standard\"}, " +
+ "\"filters\" : [{\"name\" : \"porterstem\"}]" +
+ "}'" +
+ "}");
+ createIndex("CREATE CUSTOM INDEX ON %s(v3) USING 'org.apache.cassandra.index.sai.StorageAttachedIndex'");
+
+ execute("INSERT INTO %s (k, v1, v2, v3) VALUES (1, 'apple', 'orange juice', 5)");
+ execute("INSERT INTO %s (k, v1, v2, v3) VALUES (2, 'apple juice', 'apple', 10)");
+ execute("INSERT INTO %s (k, v1, v2, v3) VALUES (3, 'banana', 'grape juice', 5)");
+
+ beforeAndAfterFlush(() -> {
+ // Complex query mixing different types of indexes and operators
+ assertRows(execute("SELECT k FROM %s WHERE v1 = 'apple' AND v2 : 'juice' AND v3 = 5"),
+ row(1));
+
+ // Mix of AND and OR conditions across different index types
+ assertRows(execute("SELECT k FROM %s WHERE v3 = 5 AND (v1 = 'apple' OR v2 : 'apple')"),
+ row(1));
+
+ // Multi-term analyzed query
+ assertRows(execute("SELECT k FROM %s WHERE v2 : 'orange juice'"),
+ row(1));
+
+ // Range query with text match
+ assertRows(execute("SELECT k FROM %s WHERE v3 >= 5 AND v2 : 'juice'"),
+ row(1), row(3));
+ });
+ }
+
+ @Test
+ public void testMatchingAllowed() throws Throwable
+ {
+ // match operator should be allowed with BM25 on the same column
+ // (seems obvious but exercises a corner case in the internal RestrictionSet processing)
+ createSimpleTable();
+
+ execute("INSERT INTO %s (k, v) VALUES (1, 'apple')");
+
+ beforeAndAfterFlush(() ->
+ {
+ var result = execute("SELECT k FROM %s WHERE v : 'apple' ORDER BY v BM25 OF 'apple' LIMIT 3");
+ assertRows(result, row(1));
+ });
+ }
+
+ @Test
+ public void testUnknownQueryTerm() throws Throwable
+ {
+ createSimpleTable();
+
+ execute("INSERT INTO %s (k, v) VALUES (1, 'apple')");
+
+ beforeAndAfterFlush(() ->
+ {
+ var result = execute("SELECT k FROM %s ORDER BY v BM25 OF 'orange' LIMIT 1");
+ assertEmpty(result);
+ });
+ }
+
+ @Test
+ public void testDuplicateQueryTerm() throws Throwable
+ {
+ createSimpleTable();
+
+ execute("INSERT INTO %s (k, v) VALUES (1, 'apple')");
+
+ beforeAndAfterFlush(() ->
+ {
+ var result = execute("SELECT k FROM %s ORDER BY v BM25 OF 'apple apple' LIMIT 1");
+ assertRows(result, row(1));
+ });
+ }
+
+ @Test
+ public void testEmptyQuery() throws Throwable
+ {
+ createSimpleTable();
+
+ execute("INSERT INTO %s (k, v) VALUES (1, 'apple')");
+
+ beforeAndAfterFlush(() ->
+ {
+ assertInvalidMessage("BM25 query must contain at least one term (perhaps your analyzer is discarding tokens you didn't expect)",
+ "SELECT k FROM %s ORDER BY v BM25 OF '+' LIMIT 1");
+ });
+ }
+
+ @Test
+ public void testTermFrequencyOrdering() throws Throwable
+ {
+ createSimpleTable();
+
+ // Insert documents with varying frequencies of the term "apple"
+ execute("INSERT INTO %s (k, v) VALUES (1, 'apple')");
+ execute("INSERT INTO %s (k, v) VALUES (2, 'apple apple')");
+ execute("INSERT INTO %s (k, v) VALUES (3, 'apple apple apple')");
+
+ beforeAndAfterFlush(() ->
+ {
+ // Results should be ordered by term frequency (highest to lowest)
+ var result = execute("SELECT k FROM %s ORDER BY v BM25 OF 'apple' LIMIT 3");
+ assertRows(result,
+ row(3), // 3 occurrences
+ row(2), // 2 occurrences
+ row(1)); // 1 occurrence
+ });
+ }
+
+ @Test
+ public void testTermFrequenciesWithOverwrites() throws Throwable
+ {
+ createSimpleTable();
+
+ // Insert documents with varying frequencies of the term "apple", but overwrite the first term
+ // This exercises the code that is supposed to reset frequency counts for overwrites
+ execute("INSERT INTO %s (k, v) VALUES (1, 'apple')");
+ execute("INSERT INTO %s (k, v) VALUES (1, 'apple')");
+ execute("INSERT INTO %s (k, v) VALUES (1, 'apple')");
+ execute("INSERT INTO %s (k, v) VALUES (1, 'apple')");
+ execute("INSERT INTO %s (k, v) VALUES (2, 'apple apple')");
+ execute("INSERT INTO %s (k, v) VALUES (3, 'apple apple apple')");
+
+ beforeAndAfterFlush(() ->
+ {
+ // Results should be ordered by term frequency (highest to lowest)
+ var result = execute("SELECT k FROM %s ORDER BY v BM25 OF 'apple' LIMIT 3");
+ assertRows(result,
+ row(3), // 3 occurrences
+ row(2), // 2 occurrences
+ row(1)); // 1 occurrence
+ });
+ }
+
+ @Test
+ public void testDocumentLength() throws Throwable
+ {
+ createSimpleTable();
+ // Create documents with same term frequency but different lengths
+ execute("INSERT INTO %s (k, v) VALUES (1, 'test test')");
+ execute("INSERT INTO %s (k, v) VALUES (2, 'test test other words here to make it longer')");
+ execute("INSERT INTO %s (k, v) VALUES (3, 'test test extremely long document with many additional words to significantly increase the document length while maintaining the same term frequency for our target term')");
+
+ beforeAndAfterFlush(() ->
+ {
+ // Documents with same term frequency should be ordered by length (shorter first)
+ var result = execute("SELECT k FROM %s ORDER BY v BM25 OF 'test' LIMIT 3");
+ assertRows(result,
+ row(1),
+ row(2),
+ row(3));
+ });
+ }
+
+ @Test
+ public void testMultiTermQueryScoring() throws Throwable
+ {
+ createSimpleTable();
+ // Two terms, but "apple" appears in fewer documents
+ execute("INSERT INTO %s (k, v) VALUES (1, 'apple banana')");
+ execute("INSERT INTO %s (k, v) VALUES (2, 'apple apple banana')");
+ execute("INSERT INTO %s (k, v) VALUES (3, 'apple banana banana')");
+ execute("INSERT INTO %s (k, v) VALUES (4, 'apple apple banana banana')");
+ execute("INSERT INTO %s (k, v) VALUES (5, 'banana banana')");
+
+ beforeAndAfterFlush(() ->
+ {
+ var result = execute("SELECT k FROM %s ORDER BY v BM25 OF 'apple banana' LIMIT 4");
+ assertRows(result,
+ row(2), // Highest frequency of most important term
+ row(4), // More mentions of both terms
+ row(1), // One of each term
+ row(3)); // Low frequency of most important term
+ });
+ }
+
+ @Test
+ public void testIrrelevantRowsScoring() throws Throwable
+ {
+ createSimpleTable();
+ // Insert pizza reviews with varying relevance to "crispy crust"
+ execute("INSERT INTO %s (k, v) VALUES (1, 'The pizza had a crispy crust and was delicious')"); // Basic mention
+ execute("INSERT INTO %s (k, v) VALUES (2, 'Very crispy crispy crust, perfectly cooked')"); // Emphasized crispy
+ execute("INSERT INTO %s (k, v) VALUES (3, 'The crust crust crust was okay, nothing special')"); // Only crust mentions
+ execute("INSERT INTO %s (k, v) VALUES (4, 'Super crispy crispy crust crust, best pizza ever!')"); // Most mentions of both
+ execute("INSERT INTO %s (k, v) VALUES (5, 'The toppings were good but the pizza was soggy')"); // Irrelevant review
+
+ beforeAndAfterFlush(this::assertIrrelevantRowsCorrect);
+ }
+
+ private void assertIrrelevantRowsCorrect()
+ {
+ var result = execute("SELECT k FROM %s ORDER BY v BM25 OF 'crispy crust' LIMIT 5");
+ assertRows(result,
+ row(4), // Highest frequency of both terms
+ row(2), // High frequency of 'crispy', one 'crust'
+ row(1)); // One mention of each term
+ // Rows 3 and 5 do not contain all terms
+ }
+
+ @Test
+ public void testIrrelevantRowsWithCompaction()
+ {
+ // same dataset as testIrrelevantRowsScoring, but split across two sstables
+ createSimpleTable();
+ disableCompaction();
+
+ execute("INSERT INTO %s (k, v) VALUES (1, 'The pizza had a crispy crust and was delicious')"); // Basic mention
+ execute("INSERT INTO %s (k, v) VALUES (2, 'Very crispy crispy crust, perfectly cooked')"); // Emphasized crispy
+ flush();
+
+ execute("INSERT INTO %s (k, v) VALUES (3, 'The crust crust crust was okay, nothing special')"); // Only crust mentions
+ execute("INSERT INTO %s (k, v) VALUES (4, 'Super crispy crispy crust crust, best pizza ever!')"); // Most mentions of both
+ execute("INSERT INTO %s (k, v) VALUES (5, 'The toppings were good but the pizza was soggy')"); // Irrelevant review
+ flush();
+
+ assertIrrelevantRowsCorrect();
+
+ compact();
+ assertIrrelevantRowsCorrect();
+
+ // Force segmentation and requery
+ SegmentBuilder.updateLastValidSegmentRowId(2);
+ compact();
+ assertIrrelevantRowsCorrect();
+ }
+
+ private void createSimpleTable()
+ {
+ createTable("CREATE TABLE %s (k int PRIMARY KEY, v text)");
+ analyzeIndex();
+ }
+
+ private String analyzeIndex()
+ {
+ return createIndex("CREATE CUSTOM INDEX ON %s(v) " +
+ "USING 'org.apache.cassandra.index.sai.StorageAttachedIndex' " +
+ "WITH OPTIONS = {" +
+ "'index_analyzer': '{" +
+ "\"tokenizer\" : {\"name\" : \"standard\"}, " +
+ "\"filters\" : [{\"name\" : \"porterstem\"}]" +
+ "}'}"
+ );
+ }
+
+ @Test
+ public void testWithPredicate() throws Throwable
+ {
+ createTable("CREATE TABLE %s (k int PRIMARY KEY, p int, v text)");
+ analyzeIndex();
+ execute("CREATE CUSTOM INDEX ON %s(p) USING 'StorageAttachedIndex'");
+
+ // Insert documents with varying frequencies of the term "apple"
+ execute("INSERT INTO %s (k, p, v) VALUES (1, 5, 'apple')");
+ execute("INSERT INTO %s (k, p, v) VALUES (2, 5, 'apple apple')");
+ execute("INSERT INTO %s (k, p, v) VALUES (3, 5, 'apple apple apple')");
+ execute("INSERT INTO %s (k, p, v) VALUES (4, 6, 'apple apple apple')");
+ execute("INSERT INTO %s (k, p, v) VALUES (5, 7, 'apple apple apple')");
+
+ beforeAndAfterFlush(() ->
+ {
+ // Results should be ordered by term frequency (highest to lowest)
+ var result = execute("SELECT k FROM %s WHERE p = 5 ORDER BY v BM25 OF 'apple' LIMIT 3");
+ assertRows(result,
+ row(3), // 3 occurrences
+ row(2), // 2 occurrences
+ row(1)); // 1 occurrence
+ });
+ }
+
+ @Test
+ public void testWidePartition() throws Throwable
+ {
+ createTable("CREATE TABLE %s (k1 int, k2 int, v text, PRIMARY KEY (k1, k2))");
+ analyzeIndex();
+
+ // Insert documents with varying frequencies of the term "apple"
+ execute("INSERT INTO %s (k1, k2, v) VALUES (0, 1, 'apple')");
+ execute("INSERT INTO %s (k1, k2, v) VALUES (0, 2, 'apple apple')");
+ execute("INSERT INTO %s (k1, k2, v) VALUES (0, 3, 'apple apple apple')");
+
+ beforeAndAfterFlush(() ->
+ {
+ // Results should be ordered by term frequency (highest to lowest)
+ var result = execute("SELECT k2 FROM %s ORDER BY v BM25 OF 'apple' LIMIT 3");
+ assertRows(result,
+ row(3), // 3 occurrences
+ row(2), // 2 occurrences
+ row(1)); // 1 occurrence
+ });
+ }
+
+ @Test
+ public void testWidePartitionWithPkPredicate() throws Throwable
+ {
+ createTable("CREATE TABLE %s (k1 int, k2 int, v text, PRIMARY KEY (k1, k2))");
+ analyzeIndex();
+
+ // Insert documents with varying frequencies of the term "apple"
+ execute("INSERT INTO %s (k1, k2, v) VALUES (0, 1, 'apple')");
+ execute("INSERT INTO %s (k1, k2, v) VALUES (0, 2, 'apple apple')");
+ execute("INSERT INTO %s (k1, k2, v) VALUES (0, 3, 'apple apple apple')");
+ execute("INSERT INTO %s (k1, k2, v) VALUES (1, 3, 'apple apple apple')");
+ execute("INSERT INTO %s (k1, k2, v) VALUES (2, 3, 'apple apple apple')");
+
+ beforeAndAfterFlush(() ->
+ {
+ // Results should be ordered by term frequency (highest to lowest)
+ var result = execute("SELECT k2 FROM %s WHERE k1 = 0 ORDER BY v BM25 OF 'apple' LIMIT 3");
+ assertRows(result,
+ row(3), // 3 occurrences
+ row(2), // 2 occurrences
+ row(1)); // 1 occurrence
+ });
+ }
+
+ @Test
+ public void testWidePartitionWithPredicate() throws Throwable
+ {
+ createTable("CREATE TABLE %s (k1 int, k2 int, p int, v text, PRIMARY KEY (k1, k2))");
+ analyzeIndex();
+ execute("CREATE CUSTOM INDEX ON %s(p) USING 'StorageAttachedIndex'");
+
+ // Insert documents with varying frequencies of the term "apple"
+ execute("INSERT INTO %s (k1, k2, p, v) VALUES (0, 1, 5, 'apple')");
+ execute("INSERT INTO %s (k1, k2, p, v) VALUES (0, 2, 5, 'apple apple')");
+ execute("INSERT INTO %s (k1, k2, p, v) VALUES (0, 3, 5, 'apple apple apple')");
+ execute("INSERT INTO %s (k1, k2, p, v) VALUES (0, 4, 6, 'apple apple apple')");
+ execute("INSERT INTO %s (k1, k2, p, v) VALUES (0, 5, 7, 'apple apple apple')");
+
+ beforeAndAfterFlush(() ->
+ {
+ // Results should be ordered by term frequency (highest to lowest)
+ var result = execute("SELECT k2 FROM %s WHERE p = 5 ORDER BY v BM25 OF 'apple' LIMIT 3");
+ assertRows(result,
+ row(3), // 3 occurrences
+ row(2), // 2 occurrences
+ row(1)); // 1 occurrence
+ });
+ }
+
+ @Test
+ public void testWithPredicateSearchThenOrder() throws Throwable
+ {
+ QueryController.QUERY_OPT_LEVEL = 0;
+ testWithPredicate();
+ }
+
+ @Test
+ public void testWidePartitionWithPredicateOrderThenSearch() throws Throwable
+ {
+ QueryController.QUERY_OPT_LEVEL = 1;
+ testWidePartitionWithPredicate();
+ }
+
+ @Test
+ public void testQueryWithNulls() throws Throwable
+ {
+ createSimpleTable();
+
+ execute("INSERT INTO %s (k, v) VALUES (0, null)");
+ execute("INSERT INTO %s (k, v) VALUES (1, 'test document')");
+ beforeAndAfterFlush(() ->
+ {
+ var result = execute("SELECT k FROM %s ORDER BY v BM25 OF 'test' LIMIT 1");
+ assertRows(result, row(1));
+ });
+ }
+
+ @Test
+ public void testQueryEmptyTable()
+ {
+ createSimpleTable();
+ var result = execute("SELECT k FROM %s ORDER BY v BM25 OF 'test' LIMIT 1");
+ assertThat(result).hasSize(0);
+ }
+}
diff --git a/test/unit/org/apache/cassandra/index/sai/cql/MultipleColumnIndexTest.java b/test/unit/org/apache/cassandra/index/sai/cql/MultipleColumnIndexTest.java
index 5364717a8906..8d952dc1f4f0 100644
--- a/test/unit/org/apache/cassandra/index/sai/cql/MultipleColumnIndexTest.java
+++ b/test/unit/org/apache/cassandra/index/sai/cql/MultipleColumnIndexTest.java
@@ -40,12 +40,30 @@ public void canCreateMultipleMapIndexesOnSameColumn() throws Throwable
}
@Test
- public void cannotHaveMultipleLiteralIndexesWithDifferentOptions() throws Throwable
+ public void canHaveAnalyzedAndUnanalyzedIndexesOnSameColumn() throws Throwable
{
- createTable("CREATE TABLE %s (pk int, ck int, value text, PRIMARY KEY(pk, ck))");
+ createTable("CREATE TABLE %s (pk int, value text, PRIMARY KEY(pk))");
createIndex("CREATE CUSTOM INDEX ON %s(value) USING 'StorageAttachedIndex' WITH OPTIONS = { 'case_sensitive' : true }");
- assertThatThrownBy(() -> createIndex("CREATE CUSTOM INDEX ON %s(value) USING 'StorageAttachedIndex' WITH OPTIONS = { 'case_sensitive' : false }"))
- .isInstanceOf(InvalidRequestException.class);
+ createIndex("CREATE CUSTOM INDEX ON %s(value) USING 'StorageAttachedIndex' WITH OPTIONS = { 'case_sensitive' : false, 'equals_behaviour_when_analyzed': 'unsupported' }");
+
+ execute("INSERT INTO %s (pk, value) VALUES (?, ?)", 1, "a");
+ execute("INSERT INTO %s (pk, value) VALUES (?, ?)", 2, "A");
+ beforeAndAfterFlush(() -> {
+ assertRows(execute("SELECT pk FROM %s WHERE value = 'a'"),
+ row(1));
+ assertRows(execute("SELECT pk FROM %s WHERE value : 'a'"),
+ row(1),
+ row(2));
+ });
+ }
+
+ @Test
+ public void cannotHaveMultipleAnalyzingIndexesOnSameColumn() throws Throwable
+ {
+ createTable("CREATE TABLE %s (pk int, ck int, value text, PRIMARY KEY(pk, ck))");
+ createIndex("CREATE CUSTOM INDEX ON %s(value) USING 'StorageAttachedIndex' WITH OPTIONS = { 'case_sensitive' : false }");
+ assertThatThrownBy(() -> createIndex("CREATE CUSTOM INDEX ON %s(value) USING 'StorageAttachedIndex' WITH OPTIONS = { 'normalize' : true }"))
+ .isInstanceOf(InvalidRequestException.class);
}
@Test
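
The behaviour the rewritten test relies on: with both indexes present, `=` is served by the unanalyzed case-sensitive index while `:` goes through the analyzer, so 'a' and 'A' collide only under `:`. A rough model of that distinction, assuming a simple lower-casing analyzer (the names here are illustrative, not the SAI analyzer API):

import java.util.Locale;

public class MatchSemanticsSketch
{
    // Unanalyzed equality: raw values must match exactly.
    static boolean equalsMatch(String indexed, String query)
    {
        return indexed.equals(query);
    }

    // Analyzed match (':'): both sides pass through the analyzer first.
    static boolean analyzedMatch(String indexed, String query)
    {
        return indexed.toLowerCase(Locale.ROOT).equals(query.toLowerCase(Locale.ROOT));
    }

    public static void main(String[] args)
    {
        System.out.println(equalsMatch("A", "a"));   // false -> only pk 1 matches value = 'a'
        System.out.println(analyzedMatch("A", "a")); // true  -> pk 1 and pk 2 match value : 'a'
    }
}
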
diff --git a/test/unit/org/apache/cassandra/index/sai/cql/NativeIndexDDLTest.java b/test/unit/org/apache/cassandra/index/sai/cql/NativeIndexDDLTest.java
index ebda5a426430..b7b5a078b2c0 100644
--- a/test/unit/org/apache/cassandra/index/sai/cql/NativeIndexDDLTest.java
+++ b/test/unit/org/apache/cassandra/index/sai/cql/NativeIndexDDLTest.java
@@ -581,7 +581,7 @@ public void shouldFailCreationMultipleIndexesOnSimpleColumn()
// different name, different option, same target.
assertThatThrownBy(() -> executeNet("CREATE CUSTOM INDEX ON %s(v1) USING 'StorageAttachedIndex' WITH OPTIONS = { 'case_sensitive' : true }"))
.isInstanceOf(InvalidQueryException.class)
- .hasMessageContaining("Cannot create more than one storage-attached index on the same column: v1" );
+ .hasMessageContaining("Cannot create duplicate storage-attached index on column: v1" );
ResultSet rows = executeNet("SELECT id FROM %s WHERE v1 = '1'");
assertEquals(1, rows.all().size());
diff --git a/test/unit/org/apache/cassandra/index/sai/disk/RAMPostingSlicesTest.java b/test/unit/org/apache/cassandra/index/sai/disk/RAMPostingSlicesTest.java
index 8d49798d906b..d6ca88d935ee 100644
--- a/test/unit/org/apache/cassandra/index/sai/disk/RAMPostingSlicesTest.java
+++ b/test/unit/org/apache/cassandra/index/sai/disk/RAMPostingSlicesTest.java
@@ -31,7 +33,7 @@ public class RAMPostingSlicesTest extends SaiRandomizedTest
@Test
public void testRAMPostingSlices() throws Exception
{
- RAMPostingSlices slices = new RAMPostingSlices(Counter.newCounter());
+ RAMPostingSlices slices = new RAMPostingSlices(Counter.newCounter(), false);
int[] segmentRowIdUpto = new int[1024];
Arrays.fill(segmentRowIdUpto, -1);
@@ -56,7 +58,7 @@ public void testRAMPostingSlices() throws Exception
bitSets[termID].set(segmentRowIdUpto[termID]);
- slices.writeVInt(termID, segmentRowIdUpto[termID]);
+ slices.writePosting(termID, segmentRowIdUpto[termID], 1);
}
for (int termID = 0; termID < segmentRowIdUpto.length; termID++)
@@ -74,4 +76,36 @@ public void testRAMPostingSlices() throws Exception
assertEquals(segmentRowId, segmentRowIdUpto[termID]);
}
}
+
+ @Test
+ public void testRAMPostingSlicesWithFrequencies() throws Exception
+ {
+ RAMPostingSlices slices = new RAMPostingSlices(Counter.newCounter(), true);
+
+ // Test with just 3 terms and known frequencies
+ for (int termId = 0; termId < 3; termId++)
+ {
+ slices.createNewSlice(termId);
+
+ // Write a sequence of rows with different frequencies for each term
+ slices.writePosting(termId, 5, 1); // first posting at row 5
+ slices.writePosting(termId, 3, 2); // next at row 8 (delta=3)
+ slices.writePosting(termId, 2, 3); // next at row 10 (delta=2)
+ }
+
+ // Verify each term's postings
+ for (int termId = 0; termId < 3; termId++)
+ {
+ ByteSliceReader reader = new ByteSliceReader();
+ PostingList postings = slices.postingList(termId, reader, 10);
+
+ assertEquals(5, postings.nextPosting());
+ assertEquals(1, postings.frequency());
+
+ assertEquals(8, postings.nextPosting());
+ assertEquals(2, postings.frequency());
+
+ assertEquals(10, postings.nextPosting());
+ assertEquals(3, postings.frequency());
+
+ assertEquals(PostingList.END_OF_STREAM, postings.nextPosting());
+ }
+ }
}
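
For context on the assertions above: writePosting takes a row-id delta plus a term frequency, so absolute row ids come back by prefix-summing the deltas. A self-contained sketch of that decoding, assuming the same interleaved (delta, frequency) layout the test writes (a model of the scheme, not the RAMPostingSlices internals):

public class DeltaFrequencySketch
{
    // Decode interleaved (delta, frequency) pairs into (rowId, frequency) pairs.
    static int[][] decode(int[] interleaved)
    {
        int[][] postings = new int[interleaved.length / 2][2];
        int rowId = 0;
        for (int i = 0; i < interleaved.length; i += 2)
        {
            rowId += interleaved[i];                 // delta from previous row id
            postings[i / 2][0] = rowId;              // absolute row id
            postings[i / 2][1] = interleaved[i + 1]; // term frequency
        }
        return postings;
    }

    public static void main(String[] args)
    {
        // Same writes as the test: deltas 5, 3, 2 with frequencies 1, 2, 3
        int[][] decoded = decode(new int[]{ 5, 1, 3, 2, 2, 3 });
        for (int[] p : decoded)
            System.out.println("row " + p[0] + " freq " + p[1]); // rows 5, 8, 10
    }
}
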
diff --git a/test/unit/org/apache/cassandra/index/sai/disk/RAMStringIndexerTest.java b/test/unit/org/apache/cassandra/index/sai/disk/RAMStringIndexerTest.java
index 9ee1605512cf..15d47013f946 100644
--- a/test/unit/org/apache/cassandra/index/sai/disk/RAMStringIndexerTest.java
+++ b/test/unit/org/apache/cassandra/index/sai/disk/RAMStringIndexerTest.java
@@ -21,11 +21,15 @@
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.Arrays;
+import java.util.HashMap;
import java.util.List;
+import java.util.Map;
import org.junit.Assert;
import org.junit.Test;
import org.apache.cassandra.index.sai.utils.SaiRandomizedTest;
import org.apache.cassandra.index.sai.utils.TypeUtil;
import org.apache.cassandra.utils.ByteBufferUtil;
@@ -40,13 +44,13 @@ public class RAMStringIndexerTest extends SaiRandomizedTest
@Test
public void test() throws Exception
{
- RAMStringIndexer indexer = new RAMStringIndexer();
+ RAMStringIndexer indexer = new RAMStringIndexer(false);
- indexer.add(new BytesRef("0"), 100);
- indexer.add(new BytesRef("2"), 102);
- indexer.add(new BytesRef("0"), 200);
- indexer.add(new BytesRef("2"), 202);
- indexer.add(new BytesRef("2"), 302);
+ indexer.addAll(List.of(new BytesRef("0")), 100);
+ indexer.addAll(List.of(new BytesRef("2")), 102);
+ indexer.addAll(List.of(new BytesRef("0")), 200);
+ indexer.addAll(List.of(new BytesRef("2")), 202);
+ indexer.addAll(List.of(new BytesRef("2")), 302);
List<List<Long>> matches = new ArrayList<>();
matches.add(Arrays.asList(100L, 200L));
@@ -75,10 +79,48 @@ public void test() throws Exception
}
}
+ @Test
+ public void testWithFrequencies() throws Exception
+ {
+ RAMStringIndexer indexer = new RAMStringIndexer(true);
+
+ // Add same term twice in same row to increment frequency
+ indexer.addAll(List.of(new BytesRef("A"), new BytesRef("A")), 100);
+ indexer.addAll(List.of(new BytesRef("B")), 102);
+ indexer.addAll(List.of(new BytesRef("A"), new BytesRef("A"), new BytesRef("A")), 200);
+ indexer.addAll(List.of(new BytesRef("B"), new BytesRef("B")), 202);
+ indexer.addAll(List.of(new BytesRef("B")), 302);
+
+ // Expected results: rowID -> frequency
+ List<Map<Long, Integer>> matches = Arrays.asList(Map.of(100L, 2, 200L, 3),
+ Map.of(102L, 1, 202L, 2, 302L, 1));
+
+ try (TermsIterator terms = indexer.getTermsWithPostings(ByteBufferUtil.bytes("A"), ByteBufferUtil.bytes("B"), TypeUtil.BYTE_COMPARABLE_VERSION))
+ {
+ int ord = 0;
+ while (terms.hasNext())
+ {
+ terms.next();
+ try (PostingList postings = terms.postings())
+ {
+ Map<Long, Integer> results = new HashMap<>();
+ long segmentRowId;
+ while ((segmentRowId = postings.nextPosting()) != PostingList.END_OF_STREAM)
+ {
+ results.put(segmentRowId, postings.frequency());
+ }
+ assertEquals(matches.get(ord++), results);
+ }
+ }
+ assertArrayEquals("A".getBytes(), terms.getMinTerm().array());
+ assertArrayEquals("B".getBytes(), terms.getMaxTerm().array());
+ }
+ }
+
@Test
public void testLargeSegment() throws IOException
{
- final RAMStringIndexer indexer = new RAMStringIndexer();
+ final RAMStringIndexer indexer = new RAMStringIndexer(false);
final int numTerms = between(1 << 10, 1 << 13);
final int numPostings = between(1 << 5, 1 << 10);
@@ -87,7 +129,7 @@ public void testLargeSegment() throws IOException
final BytesRef term = new BytesRef(String.format("%04d", id));
for (int posting = 0; posting < numPostings; ++posting)
{
- indexer.add(term, posting);
+ indexer.addAll(List.of(term), posting);
}
}
@@ -124,14 +166,14 @@ public void testRequiresFlush()
{
RAMStringIndexer.MAX_BLOCK_BYTE_POOL_SIZE = 1024 * 1024 * 100;
// primary behavior we're testing is that exceptions aren't thrown due to overflowing backing structures
- RAMStringIndexer indexer = new RAMStringIndexer();
+ RAMStringIndexer indexer = new RAMStringIndexer(false);
Assert.assertFalse(indexer.requiresFlush());
for (int i = 0; i < Integer.MAX_VALUE; i++)
{
if (indexer.requiresFlush())
break;
- indexer.add(new BytesRef(String.format("%5000d", i)), i);
+ indexer.addAll(List.of(new BytesRef(String.format("%5000d", i))), i);
}
// If we don't require a flush before MAX_VALUE, the implementation of RAMStringIndexer has sufficiently
// changed to warrant changes to the test.
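
The frequency semantics of addAll, for reference: repeating a term within one call bumps that row's frequency instead of appending another posting. A hedged sketch of the equivalent bookkeeping with plain maps (the real indexer uses byte pools, not this):

import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

public class TermFrequencySketch
{
    // term -> (rowId -> frequency), mirroring what the test asserts
    private final Map<String, Map<Long, Integer>> postings = new HashMap<>();

    void addAll(List<String> terms, long rowId)
    {
        for (String term : terms)
            postings.computeIfAbsent(term, t -> new LinkedHashMap<>())
                    .merge(rowId, 1, Integer::sum);
    }

    public static void main(String[] args)
    {
        TermFrequencySketch indexer = new TermFrequencySketch();
        indexer.addAll(List.of("A", "A"), 100);
        indexer.addAll(List.of("B"), 102);
        indexer.addAll(List.of("A", "A", "A"), 200);
        indexer.addAll(List.of("B", "B"), 202);
        indexer.addAll(List.of("B"), 302);
        // {A={100=2, 200=3}, B={102=1, 202=2, 302=1}} -- the 'matches' fixture above
        System.out.println(indexer.postings);
    }
}
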
diff --git a/test/unit/org/apache/cassandra/index/sai/disk/v1/InvertedIndexBuilder.java b/test/unit/org/apache/cassandra/index/sai/disk/v1/InvertedIndexBuilder.java
index d5699c195305..db2f61370754 100644
--- a/test/unit/org/apache/cassandra/index/sai/disk/v1/InvertedIndexBuilder.java
+++ b/test/unit/org/apache/cassandra/index/sai/disk/v1/InvertedIndexBuilder.java
@@ -19,15 +19,18 @@
import java.nio.ByteBuffer;
import java.util.ArrayList;
+import java.util.Arrays;
import java.util.List;
import java.util.function.IntSupplier;
import java.util.function.Supplier;
+import java.util.stream.Collectors;
import java.util.stream.IntStream;
import java.util.stream.Stream;
import com.carrotsearch.hppc.IntArrayList;
import org.apache.cassandra.db.marshal.UTF8Type;
import org.apache.cassandra.index.sai.disk.format.Version;
+import org.apache.cassandra.index.sai.memory.RowMapping;
import org.apache.cassandra.utils.Pair;
import org.apache.cassandra.utils.bytecomparable.ByteComparable;
@@ -83,10 +86,15 @@ public static class TermsEnum
this.byteComparableBytes = byteComparableBytes;
this.postings = postings;
}
+ }
- public Pair<ByteComparable, IntArrayList> toPair()
- {
- return Pair.create(byteComparableBytes, postings);
- }
+ /**
+ * Adds default frequency of 1 to postings
+ */
+ static Pair<ByteComparable, List<RowMapping.RowIdWithFrequency>> toTermWithFrequency(TermsEnum te)
+ {
+ return Pair.create(te.byteComparableBytes, Arrays.stream(te.postings.toArray()).boxed()
+ .map(p -> new RowMapping.RowIdWithFrequency(p, 1))
+ .collect(Collectors.toList()));
}
}
diff --git a/test/unit/org/apache/cassandra/index/sai/disk/v1/InvertedIndexSearcherTest.java b/test/unit/org/apache/cassandra/index/sai/disk/v1/InvertedIndexSearcherTest.java
index 47565eb28739..a15b11f6abf3 100644
--- a/test/unit/org/apache/cassandra/index/sai/disk/v1/InvertedIndexSearcherTest.java
+++ b/test/unit/org/apache/cassandra/index/sai/disk/v1/InvertedIndexSearcherTest.java
@@ -23,6 +23,7 @@
import java.util.List;
import java.util.stream.Collectors;
+import org.agrona.collections.Int2IntHashMap;
import org.junit.BeforeClass;
import org.junit.Test;
@@ -101,12 +102,12 @@ private void doTestEqQueriesAgainstStringIndex(Version version) throws Exception
final int numTerms = randomIntBetween(64, 512), numPostings = randomIntBetween(256, 1024);
final List<InvertedIndexBuilder.TermsEnum> termsEnum = buildTermsEnum(version, numTerms, numPostings);
- try (IndexSearcher searcher = buildIndexAndOpenSearcher(numTerms, numPostings, termsEnum))
+ try (IndexSearcher searcher = buildIndexAndOpenSearcher(numTerms, termsEnum))
{
for (int t = 0; t < numTerms; ++t)
{
try (KeyRangeIterator results = searcher.search(new Expression(indexContext)
- .add(Operator.EQ, termsEnum.get(t).originalTermBytes), null, new QueryContext(), false, LIMIT))
+ .add(Operator.EQ, termsEnum.get(t).originalTermBytes), null, new QueryContext(), false))
{
assertTrue(results.hasNext());
@@ -121,7 +122,7 @@ private void doTestEqQueriesAgainstStringIndex(Version version) throws Exception
}
try (KeyRangeIterator results = searcher.search(new Expression(indexContext)
- .add(Operator.EQ, termsEnum.get(t).originalTermBytes), null, new QueryContext(), false, LIMIT))
+ .add(Operator.EQ, termsEnum.get(t).originalTermBytes), null, new QueryContext(), false))
{
assertTrue(results.hasNext());
@@ -143,12 +144,12 @@ private void doTestEqQueriesAgainstStringIndex(Version version) throws Exception
// try searching for terms that weren't indexed
final String tooLongTerm = randomSimpleString(10, 12);
KeyRangeIterator results = searcher.search(new Expression(indexContext)
- .add(Operator.EQ, UTF8Type.instance.decompose(tooLongTerm)), null, new QueryContext(), false, LIMIT);
+ .add(Operator.EQ, UTF8Type.instance.decompose(tooLongTerm)), null, new QueryContext(), false);
assertFalse(results.hasNext());
final String tooShortTerm = randomSimpleString(1, 2);
results = searcher.search(new Expression(indexContext)
- .add(Operator.EQ, UTF8Type.instance.decompose(tooShortTerm)), null, new QueryContext(), false, LIMIT);
+ .add(Operator.EQ, UTF8Type.instance.decompose(tooShortTerm)), null, new QueryContext(), false);
assertFalse(results.hasNext());
}
}
@@ -159,10 +160,10 @@ public void testUnsupportedOperator() throws Exception
final int numTerms = randomIntBetween(5, 15), numPostings = randomIntBetween(5, 20);
final List<InvertedIndexBuilder.TermsEnum> termsEnum = buildTermsEnum(version, numTerms, numPostings);
- try (IndexSearcher searcher = buildIndexAndOpenSearcher(numTerms, numPostings, termsEnum))
+ try (IndexSearcher searcher = buildIndexAndOpenSearcher(numTerms, termsEnum))
{
searcher.search(new Expression(indexContext)
- .add(Operator.NEQ, UTF8Type.instance.decompose("a")), null, new QueryContext(), false, LIMIT);
+ .add(Operator.NEQ, UTF8Type.instance.decompose("a")), null, new QueryContext(), false);
fail("Expect IllegalArgumentException thrown, but didn't");
}
@@ -172,9 +173,8 @@ public void testUnsupportedOperator() throws Exception
}
}
- private IndexSearcher buildIndexAndOpenSearcher(int terms, int postings, List<InvertedIndexBuilder.TermsEnum> termsEnum) throws IOException
+ private IndexSearcher buildIndexAndOpenSearcher(int terms, List<InvertedIndexBuilder.TermsEnum> termsEnum) throws IOException
{
- final int size = terms * postings;
final IndexDescriptor indexDescriptor = newIndexDescriptor();
final String index = newIndex();
final IndexContext indexContext = SAITester.createIndexContext(index, UTF8Type.instance);
@@ -189,9 +189,10 @@ private List<InvertedIndexBuilder.TermsEnum> buildTermsEnum(Version version, int
return InvertedIndexBuilder.buildStringTermsEnum(version, terms, postings, () -> randomSimpleString(3, 5), () -> nextInt(0, Integer.MAX_VALUE));
}
+ private Int2IntHashMap createMockDocLengths(List<InvertedIndexBuilder.TermsEnum> termsEnum)
+ {
+ Int2IntHashMap docLengths = new Int2IntHashMap(Integer.MIN_VALUE);
+ for (InvertedIndexBuilder.TermsEnum term : termsEnum)
+ {
+ for (var cursor : term.postings)
+ docLengths.put(cursor.value, 1);
+ }
+ return docLengths;
+ }
+
private ByteBuffer wrap(ByteComparable bc)
{
return ByteBuffer.wrap(ByteSourceInverse.readBytes(bc.asComparableBytes(TypeUtil.BYTE_COMPARABLE_VERSION)));
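
Why writeAll now takes a doc-lengths map: BM25 divides each row's term frequency by a length-normalization factor built from that row's length and the corpus average, and stubbing every length to 1, as createMockDocLengths does, makes that factor exactly 1 so ordering reduces to raw frequency. A small sketch of the factor under that assumption, reusing the same agrona map type the tests use:

import org.agrona.collections.Int2IntHashMap;

public class DocLengthSketch
{
    // BM25 length normalization: 1 - b + b * docLen / avgDocLen
    static double lengthNorm(Int2IntHashMap docLengths, int rowId, double b)
    {
        double total = 0;
        for (int len : docLengths.values())
            total += len;
        double avg = total / docLengths.size();
        return 1 - b + b * docLengths.get(rowId) / avg;
    }

    public static void main(String[] args)
    {
        Int2IntHashMap docLengths = new Int2IntHashMap(Integer.MIN_VALUE);
        docLengths.put(0, 1);
        docLengths.put(1, 1);
        // With every length stubbed to 1 (as in createMockDocLengths),
        // the factor is exactly 1 for every row.
        System.out.println(lengthNorm(docLengths, 0, 0.75));
    }
}
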
diff --git a/test/unit/org/apache/cassandra/index/sai/disk/v1/KDTreeIndexSearcherTest.java b/test/unit/org/apache/cassandra/index/sai/disk/v1/KDTreeIndexSearcherTest.java
index 8f34b9a808aa..65242e319a37 100644
--- a/test/unit/org/apache/cassandra/index/sai/disk/v1/KDTreeIndexSearcherTest.java
+++ b/test/unit/org/apache/cassandra/index/sai/disk/v1/KDTreeIndexSearcherTest.java
@@ -151,7 +151,7 @@ public void testUnsupportedOperator() throws Exception
{{
operation = Op.NOT_EQ;
lower = upper = new Bound(ShortType.instance.decompose((short) 0), Int32Type.instance, true);
- }}, null, new QueryContext(), false, LIMIT);
+ }}, null, new QueryContext(), false);
fail("Expect IllegalArgumentException thrown, but didn't");
}
@@ -169,7 +169,7 @@ private void testEqQueries(final IndexSearcher indexSearcher,
{{
operation = Op.EQ;
lower = upper = new Bound(rawType.decompose(rawValueProducer.apply(EQ_TEST_LOWER_BOUND_INCLUSIVE)), encodedType, true);
- }}, null, new QueryContext(), false, LIMIT))
+ }}, null, new QueryContext(), false))
{
assertTrue(results.hasNext());
@@ -180,7 +180,7 @@ private void testEqQueries(final IndexSearcher indexSearcher,
{{
operation = Op.EQ;
lower = upper = new Bound(rawType.decompose(rawValueProducer.apply(EQ_TEST_UPPER_BOUND_EXCLUSIVE)), encodedType, true);
- }}, null, new QueryContext(), false, LIMIT))
+ }}, null, new QueryContext(), false))
{
assertFalse(results.hasNext());
indexSearcher.close();
@@ -206,7 +206,7 @@ private void testRangeQueries(final IndexSearcher indexSearch
lower = new Bound(rawType.decompose(rawValueProducer.apply((short)2)), encodedType, false);
upper = new Bound(rawType.decompose(rawValueProducer.apply((short)7)), encodedType, true);
- }}, null, new QueryContext(), false, LIMIT))
+ }}, null, new QueryContext(), false))
{
assertTrue(results.hasNext());
@@ -218,7 +218,7 @@ private void testRangeQueries(final IndexSearcher indexSearch
{{
operation = Op.RANGE;
lower = new Bound(rawType.decompose(rawValueProducer.apply(RANGE_TEST_UPPER_BOUND_EXCLUSIVE)), encodedType, true);
- }}, null, new QueryContext(), false, LIMIT))
+ }}, null, new QueryContext(), false))
{
assertFalse(results.hasNext());
}
@@ -227,7 +227,7 @@ private void testRangeQueries(final IndexSearcher indexSearch
{{
operation = Op.RANGE;
upper = new Bound(rawType.decompose(rawValueProducer.apply(RANGE_TEST_LOWER_BOUND_INCLUSIVE)), encodedType, false);
- }}, null, new QueryContext(), false, LIMIT))
+ }}, null, new QueryContext(), false))
{
assertFalse(results.hasNext());
indexSearcher.close();
diff --git a/test/unit/org/apache/cassandra/index/sai/disk/v1/TermsReaderTest.java b/test/unit/org/apache/cassandra/index/sai/disk/v1/TermsReaderTest.java
index bdfbf529e86b..97277da2490a 100644
--- a/test/unit/org/apache/cassandra/index/sai/disk/v1/TermsReaderTest.java
+++ b/test/unit/org/apache/cassandra/index/sai/disk/v1/TermsReaderTest.java
@@ -18,42 +18,49 @@
package org.apache.cassandra.index.sai.disk.v1;
import java.io.IOException;
import java.util.Collection;
import java.util.List;
import java.util.stream.Collectors;
import org.junit.Test;
-import com.carrotsearch.hppc.IntArrayList;
import com.carrotsearch.randomizedtesting.annotations.ParametersFactory;
+import org.agrona.collections.Int2IntHashMap;
import org.apache.cassandra.config.DatabaseDescriptor;
import org.apache.cassandra.db.marshal.UTF8Type;
import org.apache.cassandra.index.sai.IndexContext;
import org.apache.cassandra.index.sai.QueryContext;
import org.apache.cassandra.index.sai.SAITester;
import org.apache.cassandra.index.sai.disk.MemtableTermsIterator;
import org.apache.cassandra.index.sai.disk.PostingList;
import org.apache.cassandra.index.sai.disk.TermsIterator;
import org.apache.cassandra.index.sai.disk.format.IndexComponentType;
import org.apache.cassandra.index.sai.disk.format.IndexComponents;
import org.apache.cassandra.index.sai.disk.format.IndexDescriptor;
import org.apache.cassandra.index.sai.disk.format.Version;
import org.apache.cassandra.index.sai.disk.v1.trie.InvertedIndexWriter;
+import org.apache.cassandra.index.sai.memory.RowMapping;
import org.apache.cassandra.index.sai.metrics.QueryEventListener;
import org.apache.cassandra.index.sai.utils.SAICodecUtils;
import org.apache.cassandra.index.sai.utils.SaiRandomizedTest;
import org.apache.cassandra.index.sai.utils.TypeUtil;
import org.apache.cassandra.io.util.FileHandle;
import org.apache.cassandra.utils.Pair;
import org.apache.cassandra.utils.bytecomparable.ByteComparable;
import org.apache.cassandra.utils.bytecomparable.ByteSourceInverse;
import static org.apache.cassandra.index.sai.disk.v1.InvertedIndexBuilder.buildStringTermsEnum;
import static org.apache.cassandra.index.sai.metrics.QueryEventListeners.NO_OP_TRIE_LISTENER;
public class TermsReaderTest extends SaiRandomizedTest
{
-
public static final ByteComparable.Version VERSION = TypeUtil.BYTE_COMPARABLE_VERSION;
@ParametersFactory()
@@ -102,8 +109,11 @@ private void doTestTermsIteration(Version version) throws IOException
IndexComponents.ForWrite components = indexDescriptor.newPerIndexComponentsForWrite(indexContext);
try (InvertedIndexWriter writer = new InvertedIndexWriter(components))
{
- var iter = termsEnum.stream().map(InvertedIndexBuilder.TermsEnum::toPair).iterator();
- indexMetas = writer.writeAll(new MemtableTermsIterator(null, null, iter));
+ var iter = termsEnum.stream()
+ .map(InvertedIndexBuilder::toTermWithFrequency)
+ .iterator();
+ Int2IntHashMap docLengths = createMockDocLengths(termsEnum);
+ indexMetas = writer.writeAll(new MemtableTermsIterator(null, null, iter), docLengths);
}
FileHandle termsData = components.get(IndexComponentType.TERMS_DATA).createFileHandle();
@@ -142,8 +152,11 @@ private void testTermQueries(Version version, int numTerms, int numPostings) thr
IndexComponents.ForWrite components = indexDescriptor.newPerIndexComponentsForWrite(indexContext);
try (InvertedIndexWriter writer = new InvertedIndexWriter(components))
{
- var iter = termsEnum.stream().map(InvertedIndexBuilder.TermsEnum::toPair).iterator();
- indexMetas = writer.writeAll(new MemtableTermsIterator(null, null, iter));
+ var iter = termsEnum.stream()
+ .map(InvertedIndexBuilder::toTermWithFrequency)
+ .iterator();
+ Int2IntHashMap docLengths = createMockDocLengths(termsEnum);
+ indexMetas = writer.writeAll(new MemtableTermsIterator(null, null, iter), docLengths);
}
FileHandle termsData = components.get(IndexComponentType.TERMS_DATA).createFileHandle();
@@ -159,22 +172,24 @@ private void testTermQueries(Version version, int numTerms, int numPostings) thr
termsFooterPointer,
version))
{
- var iter = termsEnum.stream().map(InvertedIndexBuilder.TermsEnum::toPair).collect(Collectors.toList());
- for (Pair<ByteComparable, IntArrayList> pair : iter)
+ var iter = termsEnum.stream()
+ .map(InvertedIndexBuilder::toTermWithFrequency)
+ .collect(Collectors.toList());
+ for (Pair<ByteComparable, List<RowMapping.RowIdWithFrequency>> pair : iter)
{
final byte[] bytes = ByteSourceInverse.readBytes(pair.left.asComparableBytes(VERSION));
try (PostingList actualPostingList = reader.exactMatch(ByteComparable.preencoded(VERSION, bytes),
(QueryEventListener.TrieIndexEventListener)NO_OP_TRIE_LISTENER,
new QueryContext()))
{
- final IntArrayList expectedPostingList = pair.right;
+ final List<RowMapping.RowIdWithFrequency> expectedPostingList = pair.right;
assertNotNull(actualPostingList);
assertEquals(expectedPostingList.size(), actualPostingList.size());
for (int i = 0; i < expectedPostingList.size(); ++i)
{
- final long expectedRowID = expectedPostingList.get(i);
+ final long expectedRowID = expectedPostingList.get(i).rowId;
long result = actualPostingList.nextPosting();
assertEquals(String.format("row %d mismatch of %d in enum %d", i, expectedPostingList.size(), termsEnum.indexOf(pair)), expectedRowID, result);
}
@@ -188,18 +203,18 @@ private void testTermQueries(Version version, int numTerms, int numPostings) thr
(QueryEventListener.TrieIndexEventListener)NO_OP_TRIE_LISTENER,
new QueryContext()))
{
- final IntArrayList expectedPostingList = pair.right;
+ final List<RowMapping.RowIdWithFrequency> expectedPostingList = pair.right;
// test skipping to the last block
final int idxToSkip = numPostings - 2;
// tokens are equal to their corresponding row IDs
- final int tokenToSkip = expectedPostingList.get(idxToSkip);
+ final int tokenToSkip = expectedPostingList.get(idxToSkip).rowId;
long advanceResult = actualPostingList.advance(tokenToSkip);
assertEquals(tokenToSkip, advanceResult);
for (int i = idxToSkip + 1; i < expectedPostingList.size(); ++i)
{
- final long expectedRowID = expectedPostingList.get(i);
+ final long expectedRowID = expectedPostingList.get(i).rowId;
long result = actualPostingList.nextPosting();
assertEquals(expectedRowID, result);
}
@@ -211,6 +226,17 @@ private void testTermQueries(Version version, int numTerms, int numPostings) thr
}
}
+ private Int2IntHashMap createMockDocLengths(List<InvertedIndexBuilder.TermsEnum> termsEnum)
+ {
+ Int2IntHashMap docLengths = new Int2IntHashMap(Integer.MIN_VALUE);
+ for (InvertedIndexBuilder.TermsEnum term : termsEnum)
+ {
+ for (var cursor : term.postings)
+ docLengths.put(cursor.value, 1);
+ }
+ return docLengths;
+ }
+
private List<InvertedIndexBuilder.TermsEnum> buildTermsEnum(Version version, int terms, int postings)
{
return buildStringTermsEnum(version, terms, postings, () -> randomSimpleString(4, 10), () -> nextInt(0, Integer.MAX_VALUE));
diff --git a/test/unit/org/apache/cassandra/index/sai/disk/v1/kdtree/ImmutableOneDimPointValuesTest.java b/test/unit/org/apache/cassandra/index/sai/disk/v1/kdtree/ImmutableOneDimPointValuesTest.java
index 9f086249f331..ba8f33cd9951 100644
--- a/test/unit/org/apache/cassandra/index/sai/disk/v1/kdtree/ImmutableOneDimPointValuesTest.java
+++ b/test/unit/org/apache/cassandra/index/sai/disk/v1/kdtree/ImmutableOneDimPointValuesTest.java
@@ -19,12 +19,14 @@
import java.io.IOException;
import java.nio.ByteBuffer;
+import java.util.Arrays;
+import java.util.List;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.ExpectedException;
-import com.carrotsearch.hppc.IntArrayList;
+import org.apache.cassandra.index.sai.memory.RowMapping;
import org.apache.cassandra.db.marshal.Int32Type;
import org.apache.cassandra.index.sai.disk.MemtableTermsIterator;
import org.apache.cassandra.index.sai.disk.TermsIterator;
@@ -111,20 +113,22 @@ private TermsIterator buildDescTermEnum(int from, int to)
final ByteBuffer minTerm = Int32Type.instance.decompose(from);
final ByteBuffer maxTerm = Int32Type.instance.decompose(to);
- final AbstractGuavaIterator<Pair<ByteComparable, IntArrayList>> iterator = new AbstractGuavaIterator<Pair<ByteComparable, IntArrayList>>()
+ final AbstractGuavaIterator<Pair<ByteComparable, List<RowMapping.RowIdWithFrequency>>> iterator = new AbstractGuavaIterator<>()
{
private int currentTerm = from;
@Override
- protected Pair<ByteComparable, IntArrayList> computeNext()
+ protected Pair<ByteComparable, List<RowMapping.RowIdWithFrequency>> computeNext()
{
if (currentTerm <= to)
{
return endOfData();
}
final ByteBuffer term = Int32Type.instance.decompose(currentTerm++);
- IntArrayList postings = new IntArrayList();
- postings.add(0, 1, 2);
+ List<RowMapping.RowIdWithFrequency> postings = Arrays.asList(
+ new RowMapping.RowIdWithFrequency(0, 1),
+ new RowMapping.RowIdWithFrequency(1, 1),
+ new RowMapping.RowIdWithFrequency(2, 1));
return Pair.create(v -> ByteSource.preencoded(term), postings);
}
};
diff --git a/test/unit/org/apache/cassandra/index/sai/disk/v1/kdtree/KDTreeIndexBuilder.java b/test/unit/org/apache/cassandra/index/sai/disk/v1/kdtree/KDTreeIndexBuilder.java
index fdbd1f9d91f8..3a867e41d90f 100644
--- a/test/unit/org/apache/cassandra/index/sai/disk/v1/kdtree/KDTreeIndexBuilder.java
+++ b/test/unit/org/apache/cassandra/index/sai/disk/v1/kdtree/KDTreeIndexBuilder.java
@@ -21,7 +21,9 @@
import java.math.BigDecimal;
import java.math.BigInteger;
import java.nio.ByteBuffer;
+import java.util.ArrayList;
import java.util.Iterator;
+import java.util.List;
import java.util.function.Supplier;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
@@ -31,6 +33,7 @@
import org.junit.Assert;
import com.carrotsearch.hppc.IntArrayList;
+import org.apache.cassandra.index.sai.memory.RowMapping;
import org.apache.cassandra.db.marshal.AbstractType;
import org.apache.cassandra.db.marshal.DecimalType;
import org.apache.cassandra.db.marshal.Int32Type;
@@ -142,7 +145,22 @@ public KDTreeIndexBuilder(IndexDescriptor indexDescriptor,
KDTreeIndexSearcher flushAndOpen() throws IOException
{
- final TermsIterator termEnum = new MemtableTermsIterator(null, null, terms);
+ // Wrap postings with RowIdWithFrequency using default frequency of 1
+ final TermsIterator termEnum = new MemtableTermsIterator(null, null, new AbstractGuavaIterator<>()
+ {
+ @Override
+ protected Pair<ByteComparable, List<RowMapping.RowIdWithFrequency>> computeNext()
+ {
+ if (!terms.hasNext())
+ return endOfData();
+
+ Pair<ByteComparable, IntArrayList> pair = terms.next();
+ List<RowMapping.RowIdWithFrequency> postings = new ArrayList<>(pair.right.size());
+ for (int i = 0; i < pair.right.size(); i++)
+ postings.add(new RowMapping.RowIdWithFrequency(pair.right.get(i), 1));
+ return Pair.create(pair.left, postings);
+ }
+ });
final ImmutableOneDimPointValues pointValues = ImmutableOneDimPointValues.fromTermEnum(termEnum, type);
final SegmentMetadata metadata;
diff --git a/test/unit/org/apache/cassandra/index/sai/disk/v1/kdtree/NumericIndexWriterTest.java b/test/unit/org/apache/cassandra/index/sai/disk/v1/kdtree/NumericIndexWriterTest.java
index dcc1e7ef056d..27bff7fa40cb 100644
--- a/test/unit/org/apache/cassandra/index/sai/disk/v1/kdtree/NumericIndexWriterTest.java
+++ b/test/unit/org/apache/cassandra/index/sai/disk/v1/kdtree/NumericIndexWriterTest.java
@@ -18,6 +18,8 @@
package org.apache.cassandra.index.sai.disk.v1.kdtree;
import java.nio.ByteBuffer;
+import java.util.ArrayList;
+import java.util.List;
import org.junit.Before;
import org.junit.Test;
@@ -35,6 +37,7 @@
import org.apache.cassandra.index.sai.disk.format.IndexDescriptor;
import org.apache.cassandra.index.sai.disk.v1.IndexWriterConfig;
import org.apache.cassandra.index.sai.disk.v1.SegmentMetadata;
+import org.apache.cassandra.index.sai.memory.RowMapping;
import org.apache.cassandra.index.sai.metrics.QueryEventListener;
import org.apache.cassandra.index.sai.metrics.QueryEventListeners;
import org.apache.cassandra.index.sai.utils.SaiRandomizedTest;
@@ -190,21 +193,21 @@ private TermsIterator buildTermEnum(int startTermInclusive, int endTermExclusive
final ByteBuffer minTerm = Int32Type.instance.decompose(startTermInclusive);
final ByteBuffer maxTerm = Int32Type.instance.decompose(endTermExclusive);
- final AbstractGuavaIterator<Pair<ByteComparable, IntArrayList>> iterator = new AbstractGuavaIterator<Pair<ByteComparable, IntArrayList>>()
+ final AbstractGuavaIterator<Pair<ByteComparable, List<RowMapping.RowIdWithFrequency>>> iterator = new AbstractGuavaIterator<>()
{
private int currentTerm = startTermInclusive;
private int currentRowId = 0;
@Override
- protected Pair<ByteComparable, IntArrayList> computeNext()
+ protected Pair<ByteComparable, List<RowMapping.RowIdWithFrequency>> computeNext()
{
if (currentTerm >= endTermExclusive)
{
return endOfData();
}
final ByteBuffer term = Int32Type.instance.decompose(currentTerm++);
- final IntArrayList postings = new IntArrayList();
- postings.add(currentRowId++);
+ final List<RowMapping.RowIdWithFrequency> postings = new ArrayList<>();
+ postings.add(new RowMapping.RowIdWithFrequency(currentRowId++, 1));
final ByteSource encoded = Int32Type.instance.asComparableBytes(term, TypeUtil.BYTE_COMPARABLE_VERSION);
return Pair.create(v -> encoded, postings);
}
diff --git a/test/unit/org/apache/cassandra/index/sai/disk/v1/postings/IntersectingPostingListTest.java b/test/unit/org/apache/cassandra/index/sai/disk/v1/postings/IntersectingPostingListTest.java
new file mode 100644
index 000000000000..76285a4f0bc4
--- /dev/null
+++ b/test/unit/org/apache/cassandra/index/sai/disk/v1/postings/IntersectingPostingListTest.java
@@ -0,0 +1,210 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.cassandra.index.sai.disk.v1.postings;
+
+import java.io.IOException;
+import java.nio.ByteBuffer;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.concurrent.atomic.AtomicInteger;
+import java.util.stream.IntStream;
+
+import org.junit.Test;
+
+import org.apache.cassandra.index.sai.disk.PostingList;
+import org.apache.cassandra.index.sai.postings.IntArrayPostingList;
+import org.apache.cassandra.index.sai.utils.SaiRandomizedTest;
+import org.apache.cassandra.utils.ByteBufferUtil;
+
+public class IntersectingPostingListTest extends SaiRandomizedTest
+{
+ private Map<ByteBuffer, PostingList> createPostingMap(PostingList... lists)
+ {
+ Map<ByteBuffer, PostingList> map = new HashMap<>();
+ for (int i = 0; i < lists.length; i++)
+ {
+ map.put(ByteBufferUtil.bytes(String.valueOf((char) ('A' + i))), lists[i]);
+ }
+ return map;
+ }
+
+ @Test
+ public void shouldIntersectOverlappingPostingLists() throws IOException
+ {
+ var map = createPostingMap(new IntArrayPostingList(new int[]{ 1, 4, 6, 8 }),
+ new IntArrayPostingList(new int[]{ 2, 4, 6, 9 }),
+ new IntArrayPostingList(new int[]{ 4, 6, 7 }));
+
+ final PostingList intersected = IntersectingPostingList.intersect(map);
+ assertPostingListEquals(new IntArrayPostingList(new int[]{ 4, 6 }), intersected);
+ }
+
+ @Test
+ public void shouldIntersectDisjointPostingLists() throws IOException
+ {
+ var map = createPostingMap(new IntArrayPostingList(new int[]{ 1, 3, 5 }),
+ new IntArrayPostingList(new int[]{ 2, 4, 6 }));
+
+ final PostingList intersected = IntersectingPostingList.intersect(map);
+ assertPostingListEquals(new IntArrayPostingList(new int[]{}), intersected);
+ }
+
+ @Test
+ public void shouldIntersectSinglePostingList() throws IOException
+ {
+ var map = createPostingMap(new IntArrayPostingList(new int[]{ 1, 4, 6 }));
+
+ final PostingList intersected = IntersectingPostingList.intersect(map);
+ assertPostingListEquals(new IntArrayPostingList(new int[]{ 1, 4, 6 }), intersected);
+ }
+
+ @Test
+ public void shouldIntersectIdenticalPostingLists() throws IOException
+ {
+ var map = createPostingMap(new IntArrayPostingList(new int[]{ 1, 2, 3 }),
+ new IntArrayPostingList(new int[]{ 1, 2, 3 }));
+
+ final PostingList intersected = IntersectingPostingList.intersect(map);
+ assertPostingListEquals(new IntArrayPostingList(new int[]{ 1, 2, 3 }), intersected);
+ }
+
+ @Test
+ public void shouldAdvanceAllIntersectedLists() throws IOException
+ {
+ var map = createPostingMap(new IntArrayPostingList(new int[]{ 1, 3, 5, 7, 9 }),
+ new IntArrayPostingList(new int[]{ 2, 3, 5, 7, 8 }),
+ new IntArrayPostingList(new int[]{ 3, 5, 7, 10 }));
+
+ final PostingList intersected = IntersectingPostingList.intersect(map);
+ final PostingList expected = new IntArrayPostingList(new int[]{ 3, 5, 7 });
+
+ assertEquals(expected.advance(5), intersected.advance(5));
+ assertPostingListEquals(expected, intersected);
+ }
+
+ @Test
+ public void shouldHandleEmptyList() throws IOException
+ {
+ var map = createPostingMap(new IntArrayPostingList(new int[]{}),
+ new IntArrayPostingList(new int[]{ 1, 2, 3 }));
+
+ final PostingList intersected = IntersectingPostingList.intersect(map);
+ assertEquals(PostingList.END_OF_STREAM, intersected.advance(1));
+ }
+
+ @Test
+ public void shouldInterleaveNextAndAdvance() throws IOException
+ {
+ var map = createPostingMap(new IntArrayPostingList(new int[]{ 1, 3, 5, 7, 9 }),
+ new IntArrayPostingList(new int[]{ 1, 3, 5, 7, 9 }),
+ new IntArrayPostingList(new int[]{ 1, 3, 5, 7, 9 }));
+
+ final PostingList intersected = IntersectingPostingList.intersect(map);
+
+ assertEquals(1, intersected.nextPosting());
+ assertEquals(5, intersected.advance(5));
+ assertEquals(7, intersected.nextPosting());
+ assertEquals(9, intersected.advance(9));
+ }
+
+ @Test
+ public void shouldInterleaveNextAndAdvanceOnRandom() throws IOException
+ {
+ for (int i = 0; i < 1000; ++i)
+ {
+ testAdvancingOnRandom();
+ }
+ }
+
+ private void testAdvancingOnRandom() throws IOException
+ {
+ final int postingsCount = nextInt(1, 50_000);
+ final int postingListCount = nextInt(2, 10);
+
+ final AtomicInteger rowId = new AtomicInteger();
+ final int[] commonPostings = IntStream.generate(() -> rowId.addAndGet(nextInt(1, 10)))
+ .limit(postingsCount / 4)
+ .toArray();
+
+ var splitPostingLists = new ArrayList<PostingList>();
+ for (int i = 0; i < postingListCount; i++)
+ {
+ final int[] uniquePostings = IntStream.generate(() -> rowId.addAndGet(nextInt(1, 10)))
+ .limit(postingsCount)
+ .toArray();
+ int[] combined = IntStream.concat(IntStream.of(commonPostings),
+ IntStream.of(uniquePostings))
+ .distinct()
+ .sorted()
+ .toArray();
+ splitPostingLists.add(new IntArrayPostingList(combined));
+ }
+
+ final PostingList intersected = IntersectingPostingList.intersect(createPostingMap(splitPostingLists.toArray(new PostingList[0])));
+ final PostingList expected = new IntArrayPostingList(commonPostings);
+
+ final List<PostingListAdvance> actions = new ArrayList<>();
+ for (int idx = 0; idx < commonPostings.length; idx++)
+ {
+ if (nextInt(0, 8) == 0)
+ {
+ actions.add((postingList) -> {
+ try
+ {
+ return postingList.nextPosting();
+ }
+ catch (IOException e)
+ {
+ fail(e.getMessage());
+ throw new RuntimeException(e);
+ }
+ });
+ }
+ else
+ {
+ final int skips = nextInt(0, 5);
+ idx = Math.min(idx + skips, commonPostings.length - 1);
+ final int rowID = commonPostings[idx];
+ actions.add((postingList) -> {
+ try
+ {
+ return postingList.advance(rowID);
+ }
+ catch (IOException e)
+ {
+ fail(e.getMessage());
+ throw new RuntimeException(e);
+ }
+ });
+ }
+ }
+
+ for (PostingListAdvance action : actions)
+ {
+ assertEquals(action.advance(expected), action.advance(intersected));
+ }
+ }
+
+ private interface PostingListAdvance
+ {
+ long advance(PostingList list) throws IOException;
+ }
+}
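
The behaviour pinned down by these tests is ordinary sorted-list intersection: a row id is emitted only once every source list can advance to it, and advance() keeps all cursors aligned. A compact model of the usual leapfrog scheme over sorted arrays (an illustration of the algorithm, not the IntersectingPostingList internals):

import java.util.ArrayList;
import java.util.List;

public class LeapfrogIntersectSketch
{
    static List<Integer> intersect(int[][] lists)
    {
        List<Integer> out = new ArrayList<>();
        int[] pos = new int[lists.length];
        outer:
        while (true)
        {
            if (pos[0] >= lists[0].length)
                break;
            int candidate = lists[0][pos[0]]; // current head of the first list
            for (int i = 1; i < lists.length; i++)
            {
                // Advance list i to the first value >= candidate
                while (pos[i] < lists[i].length && lists[i][pos[i]] < candidate)
                    pos[i]++;
                if (pos[i] == lists[i].length)
                    break outer;              // some list exhausted: done
                if (lists[i][pos[i]] > candidate)
                {
                    // Overshot: restart with the larger value as the new candidate
                    while (pos[0] < lists[0].length && lists[0][pos[0]] < lists[i][pos[i]])
                        pos[0]++;
                    continue outer;
                }
            }
            out.add(candidate);               // present in every list
            pos[0]++;
        }
        return out;
    }

    public static void main(String[] args)
    {
        int[][] lists = { { 1, 4, 6, 8 }, { 2, 4, 6, 9 }, { 4, 6, 7 } };
        System.out.println(intersect(lists)); // [4, 6], as in the overlapping-lists test
    }
}
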
diff --git a/test/unit/org/apache/cassandra/index/sai/disk/v1/postings/PostingsTest.java b/test/unit/org/apache/cassandra/index/sai/disk/v1/postings/PostingsTest.java
index 184bfd3af8f0..451789e872c9 100644
--- a/test/unit/org/apache/cassandra/index/sai/disk/v1/postings/PostingsTest.java
+++ b/test/unit/org/apache/cassandra/index/sai/disk/v1/postings/PostingsTest.java
@@ -328,7 +328,7 @@ private PostingsReader openReader(IndexComponent.ForRead postingLists, long fp,
private PostingsReader.BlocksSummary assertBlockSummary(int blockSize, PostingList expected, IndexInput input) throws IOException
{
final PostingsReader.BlocksSummary summary = new PostingsReader.BlocksSummary(input, input.getFilePointer());
- assertEquals(blockSize, summary.blockSize);
+ assertEquals(blockSize, summary.blockEntries);
assertEquals(expected.size(), summary.numPostings);
assertTrue(summary.offsets.length() > 0);
assertEquals(summary.offsets.length(), summary.maxValues.length());
diff --git a/test/unit/org/apache/cassandra/index/sai/disk/v1/postings/VectorPostingListTest.java b/test/unit/org/apache/cassandra/index/sai/disk/v1/postings/ReorderingPostingListTest.java
similarity index 91%
rename from test/unit/org/apache/cassandra/index/sai/disk/v1/postings/VectorPostingListTest.java
rename to test/unit/org/apache/cassandra/index/sai/disk/v1/postings/ReorderingPostingListTest.java
index 3135db3c7748..10b2b6a33f47 100644
--- a/test/unit/org/apache/cassandra/index/sai/disk/v1/postings/VectorPostingListTest.java
+++ b/test/unit/org/apache/cassandra/index/sai/disk/v1/postings/ReorderingPostingListTest.java
@@ -30,14 +30,14 @@
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
-public class VectorPostingListTest
+public class ReorderingPostingListTest
{
@Test
public void ensureEmptySourceBehavesCorrectly() throws Throwable
{
var source = new TestIterator(CloseableIterator.emptyIterator());
- try (var postingList = new VectorPostingList(source))
+ try (var postingList = new ReorderingPostingList(source, RowIdWithScore::getSegmentRowId))
{
// Even an empty source should be closed
assertTrue(source.isClosed);
@@ -55,7 +55,7 @@ public void ensureIteratorIsConsumedClosedAndReordered() throws Throwable
new RowIdWithScore(4, 4),
}).iterator());
- try (var postingList = new VectorPostingList(source))
+ try (var postingList = new ReorderingPostingList(source, RowIdWithScore::getSegmentRowId))
{
// The posting list is eagerly consumed, so it should be closed before
// we close postingList
@@ -80,7 +80,7 @@ public void ensureAdvanceWorksCorrectly() throws Throwable
new RowIdWithScore(2, 2),
}).iterator());
- try (var postingList = new VectorPostingList(source))
+ try (var postingList = new ReorderingPostingList(source, RowIdWithScore::getSegmentRowId))
{
assertEquals(3, postingList.advance(3));
assertEquals(PostingList.END_OF_STREAM, postingList.advance(4));
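
What the rename captures: the class was never vector-specific. It eagerly drains any scored iterator (closing it), sorts by row id, and replays the postings in order, which is exactly what the generic row-id mapper in these tests exercises. A rough generic sketch under that reading (not the ReorderingPostingList implementation):

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.function.ToIntFunction;

public class ReorderingSketch<T>
{
    static final long END_OF_STREAM = Long.MAX_VALUE;

    private final int[] rowIds;
    private int idx;

    ReorderingSketch(Iterator<T> source, ToIntFunction<T> rowIdMapper)
    {
        // Eagerly consume: results arrive in score order, not row-id order
        ArrayList<T> buffer = new ArrayList<>();
        source.forEachRemaining(buffer::add);
        rowIds = buffer.stream().mapToInt(rowIdMapper).sorted().toArray();
    }

    long nextPosting()
    {
        return idx < rowIds.length ? rowIds[idx++] : END_OF_STREAM;
    }

    long advance(long target)
    {
        while (idx < rowIds.length && rowIds[idx] < target)
            idx++;
        return nextPosting();
    }

    public static void main(String[] args)
    {
        var list = new ReorderingSketch<>(Arrays.asList(3, 1, 2).iterator(), (Integer i) -> i);
        System.out.println(list.advance(3));    // 3
        System.out.println(list.nextPosting()); // END_OF_STREAM
    }
}
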
diff --git a/test/unit/org/apache/cassandra/index/sai/memory/TrieMemoryIndexTest.java b/test/unit/org/apache/cassandra/index/sai/memory/TrieMemoryIndexTest.java
index 4119b3afb872..6199266d6da0 100644
--- a/test/unit/org/apache/cassandra/index/sai/memory/TrieMemoryIndexTest.java
+++ b/test/unit/org/apache/cassandra/index/sai/memory/TrieMemoryIndexTest.java
@@ -82,17 +82,16 @@ public void iteratorShouldReturnAllValuesNumeric()
index.add(makeKey(table, Integer.toString(row)), Clustering.EMPTY, Int32Type.instance.decompose(row / 10), allocatedBytes -> {}, allocatesBytes -> {});
}
- Iterator<Pair<ByteComparable, PrimaryKeys>> iterator = index.iterator();
+ var iterator = index.iterator();
int valueCount = 0;
while(iterator.hasNext())
{
- Pair<ByteComparable, PrimaryKeys> pair = iterator.next();
+ var pair = iterator.next();
int value = ByteSourceInverse.getSignedInt(pair.left.asComparableBytes(TypeUtil.BYTE_COMPARABLE_VERSION));
int idCount = 0;
- Iterator<PrimaryKey> primaryKeyIterator = pair.right.iterator();
- while (primaryKeyIterator.hasNext())
+ for (var pkf : pair.right)
{
- PrimaryKey primaryKey = primaryKeyIterator.next();
+ PrimaryKey primaryKey = pkf.pk;
int id = Int32Type.instance.compose(primaryKey.partitionKey().getKey());
assertEquals(id/10, value);
idCount++;
@@ -113,17 +112,16 @@ public void iteratorShouldReturnAllValuesString()
index.add(makeKey(table, Integer.toString(row)), Clustering.EMPTY, UTF8Type.instance.decompose(Integer.toString(row / 10)), allocatedBytes -> {}, allocatesBytes -> {});
}
- Iterator<Pair<ByteComparable, PrimaryKeys>> iterator = index.iterator();
+ var iterator = index.iterator();
int valueCount = 0;
while(iterator.hasNext())
{
- Pair<ByteComparable, PrimaryKeys> pair = iterator.next();
+ var pair = iterator.next();
String value = new String(ByteSourceInverse.readBytes(pair.left.asPeekableBytes(TypeUtil.BYTE_COMPARABLE_VERSION)), StandardCharsets.UTF_8);
int idCount = 0;
- Iterator<PrimaryKey> primaryKeyIterator = pair.right.iterator();
- while (primaryKeyIterator.hasNext())
+ for (var pkf : pair.right)
{
- PrimaryKey primaryKey = primaryKeyIterator.next();
+ PrimaryKey primaryKey = pkf.pk;
String id = UTF8Type.instance.compose(primaryKey.partitionKey().getKey());
assertEquals(Integer.toString(Integer.parseInt(id) / 10), value);
idCount++;
@@ -149,11 +147,11 @@ private void shouldAcceptPrefixValuesForType(AbstractType<?> type, IntFunction<ByteBuffer> decompose)
 index.add(makeKey(table, Integer.toString(i)), Clustering.EMPTY, decompose.apply(i), allocatedBytes -> {}, allocatesBytes -> {});
}
- final Iterator<Pair<ByteComparable, PrimaryKeys>> iterator = index.iterator();
+ final var iterator = index.iterator();
int i = 0;
while (iterator.hasNext())
{
- Pair<ByteComparable, PrimaryKeys> pair = iterator.next();
+ var pair = iterator.next();
assertEquals(1, pair.right.size());
final int rowId = i;
diff --git a/test/unit/org/apache/cassandra/index/sai/memory/TrieMemtableIndexTestBase.java b/test/unit/org/apache/cassandra/index/sai/memory/TrieMemtableIndexTestBase.java
index 0f1089250348..21b13059de61 100644
--- a/test/unit/org/apache/cassandra/index/sai/memory/TrieMemtableIndexTestBase.java
+++ b/test/unit/org/apache/cassandra/index/sai/memory/TrieMemtableIndexTestBase.java
@@ -225,11 +225,11 @@ public void indexIteratorTest()
DecoratedKey minimum = temp1.compareTo(temp2) <= 0 ? temp1 : temp2;
DecoratedKey maximum = temp1.compareTo(temp2) <= 0 ? temp2 : temp1;
- Iterator<Pair<ByteComparable, Iterator<PrimaryKey>>> iterator = memtableIndex.iterator(minimum, maximum);
+ var iterator = memtableIndex.iterator(minimum, maximum);
while (iterator.hasNext())
{
- Pair<ByteComparable, Iterator<PrimaryKey>> termPair = iterator.next();
+ var termPair = iterator.next();
int term = termFromComparable(termPair.left);
// The iterator will return keys outside the range of min/max, so we need to filter here to
// get the correct keys
@@ -239,9 +239,9 @@ public void indexIteratorTest()
.sorted()
.collect(Collectors.toList());
List<DecoratedKey> termPks = new ArrayList<>();
- while (termPair.right.hasNext())
+ for (var pkWithFreq : termPair.right)
{
- DecoratedKey pk = termPair.right.next().partitionKey();
+ DecoratedKey pk = pkWithFreq.pk.partitionKey();
if (pk.compareTo(minimum) >= 0 && pk.compareTo(maximum) <= 0)
termPks.add(pk);
}
diff --git a/test/unit/org/apache/cassandra/index/sai/memory/VectorMemtableIndexTest.java b/test/unit/org/apache/cassandra/index/sai/memory/VectorMemtableIndexTest.java
index 75e8c83f30be..2197c00fe231 100644
--- a/test/unit/org/apache/cassandra/index/sai/memory/VectorMemtableIndexTest.java
+++ b/test/unit/org/apache/cassandra/index/sai/memory/VectorMemtableIndexTest.java
@@ -148,7 +148,7 @@ private void validate(List<DecoratedKey> keys)
{
IntStream.range(0, 1_000).parallel().forEach(i ->
{
- var orderer = generateRandomOrderer();
+ var orderer = randomVectorOrderer();
AbstractBounds<PartitionPosition> keyRange = generateRandomBounds(keys);
// compute keys in range of the bounds
Set<DecoratedKey> keysInRange = keys.stream().filter(keyRange::contains)
@@ -197,7 +197,7 @@ public void indexIteratorTest()
// VSTODO
}
- private Orderer generateRandomOrderer()
+ private Orderer randomVectorOrderer()
{
return new Orderer(indexContext, Operator.ANN, randomVectorSerialized());
}
diff --git a/test/unit/org/apache/cassandra/index/sai/metrics/IndexMetricsTest.java b/test/unit/org/apache/cassandra/index/sai/metrics/IndexMetricsTest.java
index cc65bd20c293..c59f10c8e42f 100644
--- a/test/unit/org/apache/cassandra/index/sai/metrics/IndexMetricsTest.java
+++ b/test/unit/org/apache/cassandra/index/sai/metrics/IndexMetricsTest.java
@@ -155,7 +156,7 @@ public void testQueriesCount()
int rowCount = 10;
for (int i = 0; i < rowCount; i++)
- execute("INSERT INTO %s (id1, v1, v2, v3) VALUES (?, ?, '0', [?, 0.0])", Integer.toString(i), i, i);
+ execute("INSERT INTO %s (id1, v1, v2, v3) VALUES (?, ?, '0', ?)", Integer.toString(i), i, vector(i, i));
assertIndexQueryCount(indexV1, 0L);
@@ -180,7 +181,7 @@ public void testQueriesCount()
assertIndexQueryCount(indexV1, 4L);
assertIndexQueryCount(indexV2, 2L);
- String indexV3 = createIndex("CREATE CUSTOM INDEX ON %s (v3) USING 'StorageAttachedIndex'");
+ String indexV3 = createIndex("CREATE CUSTOM INDEX ON %s (v3) USING 'StorageAttachedIndex' WITH OPTIONS = {'similarity_function': 'euclidean'}");
assertIndexQueryCount(indexV3, 0L);
executeNet("SELECT id1 FROM %s WHERE v2 = '2' ORDER BY v3 ANN OF [5,0] LIMIT 10");
assertIndexQueryCount(indexV1, 4L);
|