From b0cdc37bc29fbfb4c7301523dbdc4723cde2b917 Mon Sep 17 00:00:00 2001 From: Jonathan Ellis Date: Wed, 26 Feb 2025 18:44:22 -0600 Subject: [PATCH] Implement synthetic columns and ORDER BY BM25 (#1434) ### What is the issue https://github.com/riptano/cndb/issues/11725 # New functionality ORDER BY BM25, enabled by SAI Version.EC # Enhanced functionality ORDER BY ANN can also use the synthetic score column so that the coordinator does not need to recompute similarity for every row returned by the different replicas. Controlled by SelectStatement.ANN_USE_SYNTHETIC_SCORE (default false) --- src/antlr/Lexer.g | 1 + src/antlr/Parser.g | 15 +- .../cassandra/cql3/GeoDistanceRelation.java | 6 + .../cassandra/cql3/MultiColumnRelation.java | 6 + .../org/apache/cassandra/cql3/Operator.java | 21 +- .../org/apache/cassandra/cql3/Ordering.java | 75 +++ .../org/apache/cassandra/cql3/Relation.java | 7 + .../cassandra/cql3/SingleColumnRelation.java | 36 +- .../apache/cassandra/cql3/TokenRelation.java | 8 +- .../ClusteringColumnRestrictions.java | 4 +- .../PartitionKeySingleRestrictionSet.java | 2 +- .../cql3/restrictions/RestrictionSet.java | 35 +- .../restrictions/SingleColumnRestriction.java | 125 ++++- .../cql3/restrictions/SingleRestriction.java | 24 +- .../restrictions/StatementRestrictions.java | 16 +- .../cql3/selection/ColumnFilterFactory.java | 16 +- .../cassandra/cql3/selection/Selection.java | 23 +- .../cql3/statements/SelectStatement.java | 114 ++-- src/java/org/apache/cassandra/db/Columns.java | 94 +++- .../org/apache/cassandra/db/ReadCommand.java | 22 +- .../cassandra/db/RegularAndStaticColumns.java | 4 +- .../cassandra/db/filter/ColumnFilter.java | 40 +- .../apache/cassandra/db/filter/RowFilter.java | 1 + .../db/monitoring/MonitorableImpl.java | 5 + .../partitions/ParallelCommandProcessor.java | 65 --- .../apache/cassandra/index/IndexRegistry.java | 85 ++- .../cassandra/index/RowFilterValidator.java | 103 ---- .../index/SecondaryIndexManager.java | 6 - .../cassandra/index/sai/IndexContext.java | 10 +- .../cassandra/index/sai/QueryContext.java | 2 +- .../index/sai/StorageAttachedIndex.java | 30 +- .../analyzer/AnalyzerEqOperatorSupport.java | 9 +- .../index/sai/disk/MemtableTermsIterator.java | 30 +- .../cassandra/index/sai/disk/PostingList.java | 8 + .../index/sai/disk/PrimaryKeyWithSource.java | 11 + .../index/sai/disk/RAMPostingSlices.java | 65 ++- .../index/sai/disk/RAMStringIndexer.java | 84 ++- .../sai/disk/format/IndexComponentType.java | 7 +- .../index/sai/disk/format/Version.java | 28 +- .../index/sai/disk/v1/DocLengthsReader.java | 59 ++ .../index/sai/disk/v1/IndexSearcher.java | 24 +- .../sai/disk/v1/InvertedIndexSearcher.java | 146 ++++- .../sai/disk/v1/KDTreeIndexSearcher.java | 2 +- .../sai/disk/v1/MemtableIndexWriter.java | 33 +- .../v1/PartitionAwarePrimaryKeyFactory.java | 14 + .../index/sai/disk/v1/PerIndexFiles.java | 7 + .../index/sai/disk/v1/SSTableIndexWriter.java | 2 +- .../cassandra/index/sai/disk/v1/Segment.java | 2 +- .../index/sai/disk/v1/SegmentBuilder.java | 90 ++-- .../index/sai/disk/v1/SegmentMetadata.java | 5 + .../index/sai/disk/v1/TermsReader.java | 12 +- .../index/sai/disk/v1/V1OnDiskFormat.java | 9 +- .../v1/postings/IntersectingPostingList.java | 141 +++++ .../sai/disk/v1/postings/PostingsReader.java | 63 ++- .../sai/disk/v1/postings/PostingsWriter.java | 106 ++-- ...ngList.java => ReorderingPostingList.java} | 10 +- .../v1/postings/ScanningPostingsReader.java | 4 +- .../sai/disk/v1/trie/DocLengthsWriter.java | 73 +++
.../sai/disk/v1/trie/InvertedIndexWriter.java | 24 +- .../disk/v2/RowAwarePrimaryKeyFactory.java | 18 + .../sai/disk/v2/V2VectorIndexSearcher.java | 22 +- .../sai/disk/v4/V4InvertedIndexSearcher.java | 2 +- .../index/sai/disk/v7/V7OnDiskFormat.java | 50 ++ .../sai/disk/vector/VectorMemtableIndex.java | 7 +- .../index/sai/memory/MemoryIndex.java | 17 +- .../index/sai/memory/MemtableIndex.java | 4 +- .../index/sai/memory/RowMapping.java | 40 +- .../index/sai/memory/TrieMemoryIndex.java | 125 ++++- .../index/sai/memory/TrieMemtableIndex.java | 215 ++++++-- .../cassandra/index/sai/plan/Expression.java | 2 + .../cassandra/index/sai/plan/Operation.java | 3 +- .../cassandra/index/sai/plan/Orderer.java | 55 +- .../apache/cassandra/index/sai/plan/Plan.java | 115 +++- .../index/sai/plan/QueryController.java | 6 + .../plan/StorageAttachedIndexQueryPlan.java | 2 +- .../plan/StorageAttachedIndexSearcher.java | 160 ++++-- .../index/sai/plan/TopKProcessor.java | 297 ++++------ .../cassandra/index/sai/utils/BM25Utils.java | 185 +++++++ .../cassandra/index/sai/utils/PrimaryKey.java | 3 +- .../utils/PrimaryKeyWithByteComparable.java | 17 +- .../index/sai/utils/PrimaryKeyWithScore.java | 19 +- .../sai/utils/PrimaryKeyWithSortKey.java | 21 +- .../index/sai/utils/PrimaryKeys.java | 1 + .../index/sai/utils/RowIdWithScore.java | 2 +- .../index/sai/utils/RowWithSourceTable.java | 9 + .../cassandra/schema/ColumnMetadata.java | 37 +- .../cassandra/schema/TableMetadata.java | 3 +- .../apache/cassandra/service/ClientWarn.java | 19 +- .../cassandra/index/sai/LongBM25Test.java | 251 +++++++++ .../test/sai/BM25DistributedTest.java | 123 +++++ .../org/apache/cassandra/cql3/CQLTester.java | 5 + .../operations/SelectOrderByTest.java | 2 +- .../cassandra/index/sai/cql/BM25Test.java | 510 ++++++++++++++++++ .../sai/cql/MultipleColumnIndexTest.java | 26 +- .../index/sai/cql/NativeIndexDDLTest.java | 2 +- .../index/sai/disk/RAMPostingSlicesTest.java | 38 +- .../index/sai/disk/RAMStringIndexerTest.java | 62 ++- .../sai/disk/v1/InvertedIndexBuilder.java | 16 +- .../disk/v1/InvertedIndexSearcherTest.java | 34 +- .../sai/disk/v1/KDTreeIndexSearcherTest.java | 12 +- .../index/sai/disk/v1/TermsReaderTest.java | 52 +- .../ImmutableOneDimPointValuesTest.java | 14 +- .../disk/v1/kdtree/KDTreeIndexBuilder.java | 20 +- .../v1/kdtree/NumericIndexWriterTest.java | 11 +- .../postings/IntersectingPostingListTest.java | 210 ++++++++ .../sai/disk/v1/postings/PostingsTest.java | 2 +- ...st.java => ReorderingPostingListTest.java} | 8 +- .../index/sai/memory/TrieMemoryIndexTest.java | 22 +- .../sai/memory/TrieMemtableIndexTestBase.java | 8 +- .../sai/memory/VectorMemtableIndexTest.java | 4 +- .../index/sai/metrics/IndexMetricsTest.java | 5 +- 111 files changed, 3937 insertions(+), 1025 deletions(-) delete mode 100644 src/java/org/apache/cassandra/db/partitions/ParallelCommandProcessor.java delete mode 100644 src/java/org/apache/cassandra/index/RowFilterValidator.java create mode 100644 src/java/org/apache/cassandra/index/sai/disk/v1/DocLengthsReader.java create mode 100644 src/java/org/apache/cassandra/index/sai/disk/v1/postings/IntersectingPostingList.java rename src/java/org/apache/cassandra/index/sai/disk/v1/postings/{VectorPostingList.java => ReorderingPostingList.java} (83%) create mode 100644 src/java/org/apache/cassandra/index/sai/disk/v1/trie/DocLengthsWriter.java create mode 100644 src/java/org/apache/cassandra/index/sai/disk/v7/V7OnDiskFormat.java create mode 100644 src/java/org/apache/cassandra/index/sai/utils/BM25Utils.java create 
mode 100644 test/burn/org/apache/cassandra/index/sai/LongBM25Test.java create mode 100644 test/distributed/org/apache/cassandra/distributed/test/sai/BM25DistributedTest.java create mode 100644 test/unit/org/apache/cassandra/index/sai/cql/BM25Test.java create mode 100644 test/unit/org/apache/cassandra/index/sai/disk/v1/postings/IntersectingPostingListTest.java rename test/unit/org/apache/cassandra/index/sai/disk/v1/postings/{VectorPostingListTest.java => ReorderingPostingListTest.java} (91%) diff --git a/src/antlr/Lexer.g b/src/antlr/Lexer.g index 3da08245c3d6..dd050c1478e7 100644 --- a/src/antlr/Lexer.g +++ b/src/antlr/Lexer.g @@ -227,6 +227,7 @@ K_DROPPED: D R O P P E D; K_COLUMN: C O L U M N; K_RECORD: R E C O R D; K_ANN: A N N; +K_BM25: B M '2' '5'; // Case-insensitive alpha characters fragment A: ('a'|'A'); diff --git a/src/antlr/Parser.g b/src/antlr/Parser.g index c0544620d51d..6b76c9d74bed 100644 --- a/src/antlr/Parser.g +++ b/src/antlr/Parser.g @@ -459,14 +459,18 @@ customIndexExpression [WhereClause.Builder clause] ; orderByClause[List orderings] - @init{ + @init { Ordering.Direction direction = Ordering.Direction.ASC; + Ordering.Raw.Expression expr = null; } - : c=cident (K_ANN K_OF t=term)? (K_ASC | K_DESC { direction = Ordering.Direction.DESC; })? + : c=cident + ( K_ANN K_OF t=term { expr = new Ordering.Raw.Ann(c, t); } + | K_BM25 K_OF t=term { expr = new Ordering.Raw.Bm25(c, t); } + )? + (K_ASC | K_DESC { direction = Ordering.Direction.DESC; })? { - Ordering.Raw.Expression expr = (t == null) - ? new Ordering.Raw.SingleColumn(c) - : new Ordering.Raw.Ann(c, t); + if (expr == null) + expr = new Ordering.Raw.SingleColumn(c); orderings.add(new Ordering.Raw(expr, direction)); } ; @@ -1969,6 +1973,7 @@ basic_unreserved_keyword returns [String str] | K_COLUMN | K_RECORD | K_ANN + | K_BM25 | K_OFFSET ) { $str = $k.text; } ; diff --git a/src/java/org/apache/cassandra/cql3/GeoDistanceRelation.java b/src/java/org/apache/cassandra/cql3/GeoDistanceRelation.java index 3d5fb2eeeef7..640e7e600686 100644 --- a/src/java/org/apache/cassandra/cql3/GeoDistanceRelation.java +++ b/src/java/org/apache/cassandra/cql3/GeoDistanceRelation.java @@ -141,6 +141,12 @@ protected Restriction newAnnRestriction(TableMetadata table, VariableSpecificati throw invalidRequest("%s cannot be used with the GEO_DISTANCE function", operator()); } + @Override + protected Restriction newBm25Restriction(TableMetadata table, VariableSpecifications boundNames) + { + throw invalidRequest("%s cannot be used with the GEO_DISTANCE function", operator()); + } + @Override protected Restriction newAnalyzerMatchesRestriction(TableMetadata table, VariableSpecifications boundNames) { diff --git a/src/java/org/apache/cassandra/cql3/MultiColumnRelation.java b/src/java/org/apache/cassandra/cql3/MultiColumnRelation.java index fcd505fcf5df..f56d76e2ced9 100644 --- a/src/java/org/apache/cassandra/cql3/MultiColumnRelation.java +++ b/src/java/org/apache/cassandra/cql3/MultiColumnRelation.java @@ -250,6 +250,12 @@ protected Restriction newAnnRestriction(TableMetadata table, VariableSpecificati throw invalidRequest("%s cannot be used for multi-column relations", operator()); } + @Override + protected Restriction newBm25Restriction(TableMetadata table, VariableSpecifications boundNames) + { + throw invalidRequest("%s cannot be used for multi-column relations", operator()); + } + @Override protected Restriction newAnalyzerMatchesRestriction(TableMetadata table, VariableSpecifications boundNames) { diff --git 
a/src/java/org/apache/cassandra/cql3/Operator.java b/src/java/org/apache/cassandra/cql3/Operator.java index 9908c64c39ea..d393fc1aa416 100644 --- a/src/java/org/apache/cassandra/cql3/Operator.java +++ b/src/java/org/apache/cassandra/cql3/Operator.java @@ -323,7 +323,7 @@ public boolean isSatisfiedBy(AbstractType type, @Nullable Index.Analyzer indexAnalyzer, @Nullable Index.Analyzer queryAnalyzer) { - return true; + throw new UnsupportedOperationException(); } }, NOT_IN(16) @@ -523,6 +523,7 @@ private boolean hasToken(AbstractType type, List tokens, ByteBuff return false; } }, + /** * An operator that performs a distance bounded approximate nearest neighbor search against a vector column such * that all result vectors are within a given distance of the query vector. The notable difference between this @@ -584,6 +585,24 @@ public boolean isSatisfiedBy(AbstractType type, { throw new UnsupportedOperationException(); } + }, + BM25(104) + { + @Override + public String toString() + { + return "BM25"; + } + + @Override + public boolean isSatisfiedBy(AbstractType type, + ByteBuffer leftOperand, + ByteBuffer rightOperand, + @Nullable Index.Analyzer indexAnalyzer, + @Nullable Index.Analyzer queryAnalyzer) + { + throw new UnsupportedOperationException(); + } }; /** diff --git a/src/java/org/apache/cassandra/cql3/Ordering.java b/src/java/org/apache/cassandra/cql3/Ordering.java index 81aa94a076cd..2dd817818a76 100644 --- a/src/java/org/apache/cassandra/cql3/Ordering.java +++ b/src/java/org/apache/cassandra/cql3/Ordering.java @@ -20,6 +20,7 @@ import org.apache.cassandra.cql3.restrictions.SingleColumnRestriction; import org.apache.cassandra.cql3.restrictions.SingleRestriction; +import org.apache.cassandra.cql3.statements.SelectStatement; import org.apache.cassandra.schema.ColumnMetadata; import org.apache.cassandra.schema.TableMetadata; @@ -48,6 +49,11 @@ public interface Expression SingleRestriction toRestriction(); ColumnMetadata getColumn(); + + default boolean isScored() + { + return false; + } } /** @@ -118,6 +124,54 @@ public ColumnMetadata getColumn() { return column; } + + @Override + public boolean isScored() + { + return SelectStatement.ANN_USE_SYNTHETIC_SCORE; + } + } + + /** + * An expression used in BM25 ordering. 
+ * ORDER BY column BM25 OF value + */ + public static class Bm25 implements Expression + { + final ColumnMetadata column; + final Term queryValue; + final Direction direction; + + public Bm25(ColumnMetadata column, Term queryValue, Direction direction) + { + this.column = column; + this.queryValue = queryValue; + this.direction = direction; + } + + @Override + public boolean hasNonClusteredOrdering() + { + return true; + } + + @Override + public SingleRestriction toRestriction() + { + return new SingleColumnRestriction.Bm25Restriction(column, queryValue); + } + + @Override + public ColumnMetadata getColumn() + { + return column; + } + + @Override + public boolean isScored() + { + return true; + } } public enum Direction @@ -190,6 +244,27 @@ public Ordering.Expression bind(TableMetadata table, VariableSpecifications boun return new Ordering.Ann(column, value, direction); } } + + public static class Bm25 implements Expression + { + final ColumnIdentifier columnId; + final Term.Raw queryValue; + + Bm25(ColumnIdentifier column, Term.Raw queryValue) + { + this.columnId = column; + this.queryValue = queryValue; + } + + @Override + public Ordering.Expression bind(TableMetadata table, VariableSpecifications boundNames, Direction direction) + { + ColumnMetadata column = table.getExistingColumn(columnId); + Term value = queryValue.prepare(table.keyspace, column); + value.collectMarkerSpecification(boundNames); + return new Ordering.Bm25(column, value, direction); + } + } } } diff --git a/src/java/org/apache/cassandra/cql3/Relation.java b/src/java/org/apache/cassandra/cql3/Relation.java index 5cca2d257323..42cf3c9c8287 100644 --- a/src/java/org/apache/cassandra/cql3/Relation.java +++ b/src/java/org/apache/cassandra/cql3/Relation.java @@ -202,6 +202,8 @@ public final Restriction toRestriction(TableMetadata table, VariableSpecificatio return newLikeRestriction(table, boundNames, relationType); case ANN: return newAnnRestriction(table, boundNames); + case BM25: + return newBm25Restriction(table, boundNames); case ANALYZER_MATCHES: return newAnalyzerMatchesRestriction(table, boundNames); default: throw invalidRequest("Unsupported \"!=\" relation: %s", this); @@ -296,6 +298,11 @@ protected abstract Restriction newSliceRestriction(TableMetadata table, */ protected abstract Restriction newAnnRestriction(TableMetadata table, VariableSpecifications boundNames); + /** + * Creates a new BM25 restriction instance. + */ + protected abstract Restriction newBm25Restriction(TableMetadata table, VariableSpecifications boundNames); + /** * Creates a new Analyzer Matches restriction instance. 
*/ diff --git a/src/java/org/apache/cassandra/cql3/SingleColumnRelation.java b/src/java/org/apache/cassandra/cql3/SingleColumnRelation.java index ec66ad70b529..2b97b99c3562 100644 --- a/src/java/org/apache/cassandra/cql3/SingleColumnRelation.java +++ b/src/java/org/apache/cassandra/cql3/SingleColumnRelation.java @@ -21,8 +21,11 @@ import java.util.ArrayList; import java.util.List; import java.util.Objects; +import java.util.stream.Collectors; import org.apache.cassandra.db.marshal.VectorType; +import org.apache.cassandra.index.IndexRegistry; +import org.apache.cassandra.index.sai.analyzer.AnalyzerEqOperatorSupport; import org.apache.cassandra.schema.ColumnMetadata; import org.apache.cassandra.schema.TableMetadata; import org.apache.cassandra.cql3.Term.Raw; @@ -33,6 +36,7 @@ import org.apache.cassandra.db.marshal.ListType; import org.apache.cassandra.db.marshal.MapType; import org.apache.cassandra.exceptions.InvalidRequestException; +import org.apache.cassandra.service.ClientWarn; import static org.apache.cassandra.cql3.statements.RequestValidations.checkFalse; import static org.apache.cassandra.cql3.statements.RequestValidations.checkTrue; @@ -191,7 +195,29 @@ protected Restriction newEQRestriction(TableMetadata table, VariableSpecificatio if (mapKey == null) { Term term = toTerm(toReceivers(columnDef), value, table.keyspace, boundNames); - return new SingleColumnRestriction.EQRestriction(columnDef, term); + // Leave the restriction as EQ if no analyzed index in backwards compatibility mode is present + var ebi = IndexRegistry.obtain(table).getEqBehavior(columnDef); + if (ebi.behavior == IndexRegistry.EqBehavior.EQ) + return new SingleColumnRestriction.EQRestriction(columnDef, term); + + // the index is configured to transform EQ into MATCH for backwards compatibility + var matchIndexName = ebi.matchIndex.getIndexMetadata() == null ? "Unknown" : ebi.matchIndex.getIndexMetadata().name; + if (ebi.behavior == IndexRegistry.EqBehavior.MATCH) + { + ClientWarn.instance.warn(String.format(AnalyzerEqOperatorSupport.EQ_RESTRICTION_ON_ANALYZED_WARNING, + columnDef.toString(), + matchIndexName), + columnDef); + return new SingleColumnRestriction.AnalyzerMatchesRestriction(columnDef, term); + } + + // multiple indexes support EQ, this is unsupported + assert ebi.behavior == IndexRegistry.EqBehavior.AMBIGUOUS; + var eqIndexName = ebi.eqIndex.getIndexMetadata() == null ? 
"Unknown" : ebi.eqIndex.getIndexMetadata().name; + throw invalidRequest(AnalyzerEqOperatorSupport.EQ_AMBIGUOUS_ERROR, + columnDef.toString(), + matchIndexName, + eqIndexName); } List receivers = toReceivers(columnDef); Term entryKey = toTerm(Collections.singletonList(receivers.get(0)), mapKey, table.keyspace, boundNames); @@ -333,6 +359,14 @@ protected Restriction newAnnRestriction(TableMetadata table, VariableSpecificati return new SingleColumnRestriction.AnnRestriction(columnDef, term); } + @Override + protected Restriction newBm25Restriction(TableMetadata table, VariableSpecifications boundNames) + { + ColumnMetadata columnDef = table.getExistingColumn(entity); + Term term = toTerm(toReceivers(columnDef), value, table.keyspace, boundNames); + return new SingleColumnRestriction.Bm25Restriction(columnDef, term); + } + @Override protected Restriction newAnalyzerMatchesRestriction(TableMetadata table, VariableSpecifications boundNames) { diff --git a/src/java/org/apache/cassandra/cql3/TokenRelation.java b/src/java/org/apache/cassandra/cql3/TokenRelation.java index a3ca586eee76..ca849dc82a30 100644 --- a/src/java/org/apache/cassandra/cql3/TokenRelation.java +++ b/src/java/org/apache/cassandra/cql3/TokenRelation.java @@ -138,7 +138,13 @@ protected Restriction newLikeRestriction(TableMetadata table, VariableSpecificat @Override protected Restriction newAnnRestriction(TableMetadata table, VariableSpecifications boundNames) { - throw invalidRequest("%s cannot be used for toekn relations", operator()); + throw invalidRequest("%s cannot be used for token relations", operator()); + } + + @Override + protected Restriction newBm25Restriction(TableMetadata table, VariableSpecifications boundNames) + { + throw invalidRequest("%s cannot be used for token relations", operator()); } @Override diff --git a/src/java/org/apache/cassandra/cql3/restrictions/ClusteringColumnRestrictions.java b/src/java/org/apache/cassandra/cql3/restrictions/ClusteringColumnRestrictions.java index 8ccd7fcaf37a..b2c08a065045 100644 --- a/src/java/org/apache/cassandra/cql3/restrictions/ClusteringColumnRestrictions.java +++ b/src/java/org/apache/cassandra/cql3/restrictions/ClusteringColumnRestrictions.java @@ -191,7 +191,7 @@ public ClusteringColumnRestrictions.Builder addRestriction(Restriction restricti SingleRestriction lastRestriction = restrictions.lastRestriction(); ColumnMetadata lastRestrictionStart = lastRestriction.getFirstColumn(); ColumnMetadata newRestrictionStart = newRestriction.getFirstColumn(); - restrictions.addRestriction(newRestriction, isDisjunction, indexRegistry); + restrictions.addRestriction(newRestriction, isDisjunction); checkFalse(lastRestriction.isSlice() && newRestrictionStart.position() > lastRestrictionStart.position(), "Clustering column \"%s\" cannot be restricted (preceding column \"%s\" is restricted by a non-EQ relation)", @@ -205,7 +205,7 @@ public ClusteringColumnRestrictions.Builder addRestriction(Restriction restricti } else { - restrictions.addRestriction(newRestriction, isDisjunction, indexRegistry); + restrictions.addRestriction(newRestriction, isDisjunction); } return this; diff --git a/src/java/org/apache/cassandra/cql3/restrictions/PartitionKeySingleRestrictionSet.java b/src/java/org/apache/cassandra/cql3/restrictions/PartitionKeySingleRestrictionSet.java index 2536f94f482d..ccb144b4b174 100644 --- a/src/java/org/apache/cassandra/cql3/restrictions/PartitionKeySingleRestrictionSet.java +++ b/src/java/org/apache/cassandra/cql3/restrictions/PartitionKeySingleRestrictionSet.java @@ -191,7 
+191,7 @@ public PartitionKeyRestrictions build(IndexRegistry indexRegistry, boolean isDis if (restriction.isOnToken()) return buildWithTokens(restrictionSet, i, indexRegistry); - restrictionSet.addRestriction((SingleRestriction) restriction, isDisjunction, indexRegistry); + restrictionSet.addRestriction((SingleRestriction) restriction, isDisjunction); } return buildPartitionKeyRestrictions(restrictionSet); diff --git a/src/java/org/apache/cassandra/cql3/restrictions/RestrictionSet.java b/src/java/org/apache/cassandra/cql3/restrictions/RestrictionSet.java index 4cde1c434f33..dacb018e6186 100644 --- a/src/java/org/apache/cassandra/cql3/restrictions/RestrictionSet.java +++ b/src/java/org/apache/cassandra/cql3/restrictions/RestrictionSet.java @@ -398,45 +398,36 @@ private Builder() { } - public void addRestriction(SingleRestriction restriction, boolean isDisjunction, IndexRegistry indexRegistry) + public void addRestriction(SingleRestriction restriction, boolean isDisjunction) { List columnDefs = restriction.getColumnDefs(); if (isDisjunction) { // If this restriction is part of a disjunction query then we don't want - // to merge the restrictions (if that is possible), we just add the - // restriction to the set of restrictions for the column. + // to merge the restrictions, we just add the new restriction addRestrictionForColumns(columnDefs, restriction, null); } else { - // In some special cases such as EQ in analyzed index we need to skip merging the restriction, - // so we can send multiple EQ restrictions to the index. - if (restriction.skipMerge(indexRegistry)) - { - addRestrictionForColumns(columnDefs, restriction, null); - return; - } - - // If this restriction isn't part of a disjunction then we need to get - // the set of existing restrictions for the column and merge them with the - // new restriction + // ANDed together restrictions against the same columns should be merged. Set existingRestrictions = getRestrictions(newRestrictions, columnDefs); - SingleRestriction merged = restriction; - Set replacedRestrictions = new HashSet<>(); - - for (SingleRestriction existing : existingRestrictions) + // merge the new restriction into an existing one. note that there is only ever a single + // restriction (per column), UNLESS one is ORDER BY BM25 and the other is MATCH. 
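+ // For example (assuming this codebase's ':' analyzer-match syntax from
+ // AnalyzerEqOperatorSupport): in "WHERE v : 'a' AND v : 'b'", the two MATCH
+ // restrictions on v merge into a single AnalyzerMatchesRestriction, whereas in
+ // "WHERE v : 'a' ORDER BY v BM25 OF 'b'" the MATCH restriction and the BM25
+ // ordering restriction are kept separate.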
+ for (var existing : existingRestrictions) { - if (!existing.skipMerge(indexRegistry)) + // shouldMerge exists for the BM25/MATCH case + if (existing.shouldMerge(restriction)) { - merged = existing.mergeWith(merged); - replacedRestrictions.add(existing); + var merged = existing.mergeWith(restriction); + addRestrictionForColumns(merged.getColumnDefs(), merged, Set.of(existing)); + return; } } - addRestrictionForColumns(merged.getColumnDefs(), merged, replacedRestrictions); + // no existing restrictions that we should merge the new one with, add a new one + addRestrictionForColumns(columnDefs, restriction, null); } } diff --git a/src/java/org/apache/cassandra/cql3/restrictions/SingleColumnRestriction.java b/src/java/org/apache/cassandra/cql3/restrictions/SingleColumnRestriction.java index 46c2afcc1d8f..cd44a6f09000 100644 --- a/src/java/org/apache/cassandra/cql3/restrictions/SingleColumnRestriction.java +++ b/src/java/org/apache/cassandra/cql3/restrictions/SingleColumnRestriction.java @@ -24,15 +24,19 @@ import java.util.List; import java.util.Map; -import org.apache.cassandra.db.filter.RowFilter; -import org.apache.cassandra.db.filter.ANNOptions; -import org.apache.cassandra.schema.ColumnMetadata; -import org.apache.cassandra.cql3.*; +import org.apache.cassandra.cql3.MarkerOrTerms; +import org.apache.cassandra.cql3.Operator; +import org.apache.cassandra.cql3.QueryOptions; +import org.apache.cassandra.cql3.Term; +import org.apache.cassandra.cql3.Terms; import org.apache.cassandra.cql3.functions.Function; import org.apache.cassandra.cql3.statements.Bound; import org.apache.cassandra.db.MultiClusteringBuilder; +import org.apache.cassandra.db.filter.ANNOptions; +import org.apache.cassandra.db.filter.RowFilter; import org.apache.cassandra.index.Index; import org.apache.cassandra.index.IndexRegistry; +import org.apache.cassandra.schema.ColumnMetadata; import org.apache.cassandra.serializers.ListSerializer; import org.apache.cassandra.transport.ProtocolVersion; import org.apache.cassandra.utils.ByteBufferUtil; @@ -76,12 +80,17 @@ public ColumnMetadata getLastColumn() @Override public boolean hasSupportingIndex(IndexRegistry indexRegistry) + { + return findSupportingIndex(indexRegistry) != null; + } + + public Index findSupportingIndex(IndexRegistry indexRegistry) { for (Index index : indexRegistry.listIndexes()) if (isSupportedBy(index)) - return true; + return index; - return false; + return null; } @Override @@ -190,24 +199,6 @@ public String toString() return String.format("EQ(%s)", term); } - @Override - public boolean skipMerge(IndexRegistry indexRegistry) - { - // We should skip merging this EQ if there is an analyzed index for this column that supports EQ, - // so there can be multiple EQs for the same column. 
- - if (indexRegistry == null) - return false; - - for (Index index : indexRegistry.listIndexes()) - { - if (index.supportsExpression(columnDef, Operator.ANALYZER_MATCHES) && - index.supportsExpression(columnDef, Operator.EQ)) - return true; - } - return false; - } - @Override public SingleRestriction doMergeWith(SingleRestriction otherRestriction) { @@ -1188,6 +1179,86 @@ public boolean isBoundedAnn() } } + public static final class Bm25Restriction extends SingleColumnRestriction + { + private final Term value; + + public Bm25Restriction(ColumnMetadata columnDef, Term value) + { + super(columnDef); + this.value = value; + } + + public ByteBuffer value(QueryOptions options) + { + return value.bindAndGet(options); + } + + @Override + public void addFunctionsTo(List functions) + { + value.addFunctionsTo(functions); + } + + @Override + MultiColumnRestriction toMultiColumnRestriction() + { + throw new UnsupportedOperationException(); + } + + @Override + public void addToRowFilter(RowFilter.Builder filter, IndexRegistry indexRegistry, QueryOptions options, ANNOptions annOptions) + { + var index = findSupportingIndex(indexRegistry); + var valueBytes = value.bindAndGet(options); + var terms = index.getQueryAnalyzer().get().analyze(valueBytes); + if (terms.isEmpty()) + throw invalidRequest("BM25 query must contain at least one term (perhaps your analyzer is discarding tokens you didn't expect)"); + filter.add(columnDef, Operator.BM25, valueBytes); + } + + @Override + public MultiClusteringBuilder appendTo(MultiClusteringBuilder builder, QueryOptions options) + { + throw new UnsupportedOperationException(); + } + + @Override + public String toString() + { + return String.format("BM25(%s)", value); + } + + @Override + public SingleRestriction doMergeWith(SingleRestriction otherRestriction) + { + throw invalidRequest("%s cannot be restricted by both BM25 and %s", columnDef.name, otherRestriction.toString()); + } + + @Override + protected boolean isSupportedBy(Index index) + { + return index.supportsExpression(columnDef, Operator.BM25); + } + + @Override + public boolean isIndexBasedOrdering() + { + return true; + } + + @Override + public boolean shouldMerge(SingleRestriction other) + { + // we don't want to merge MATCH restrictions with ORDER BY BM25 + // so shouldMerge = false for that scenario, and true for others + // (because even though we can't meaningfully merge with others, we want doMergeWith to be called to throw) + // + // (Note that because ORDER BY is processed before WHERE, we only need this check in the BM25 class) + return !other.isAnalyzerMatches(); + } + } + /** * A Bounded ANN Restriction is one that uses a similarity score as the limiting factor for ANN instead of a number * of results. @@ -1335,10 +1406,12 @@ public String toString() @Override public SingleRestriction doMergeWith(SingleRestriction otherRestriction) { - if (!(otherRestriction.isAnalyzerMatches())) + if (!otherRestriction.isAnalyzerMatches()) throw invalidRequest(CANNOT_BE_MERGED_ERROR, columnDef.name); - List otherValues = ((AnalyzerMatchesRestriction) otherRestriction).getValues(); + List otherValues = otherRestriction instanceof AnalyzerMatchesRestriction + ? 
((AnalyzerMatchesRestriction) otherRestriction).getValues() + : List.of(((EQRestriction) otherRestriction).term); List newValues = new ArrayList<>(values.size() + otherValues.size()); newValues.addAll(values); newValues.addAll(otherValues); diff --git a/src/java/org/apache/cassandra/cql3/restrictions/SingleRestriction.java b/src/java/org/apache/cassandra/cql3/restrictions/SingleRestriction.java index 595451f812de..bdd80badc0ae 100644 --- a/src/java/org/apache/cassandra/cql3/restrictions/SingleRestriction.java +++ b/src/java/org/apache/cassandra/cql3/restrictions/SingleRestriction.java @@ -20,7 +20,6 @@ import org.apache.cassandra.cql3.QueryOptions; import org.apache.cassandra.cql3.statements.Bound; import org.apache.cassandra.db.MultiClusteringBuilder; -import org.apache.cassandra.index.IndexRegistry; /** * A single restriction/clause on one or multiple column. @@ -97,17 +96,6 @@ public default boolean isInclusive(Bound b) return true; } - /** - * Checks if this restriction shouldn't be merged with other restrictions. - * - * @param indexRegistry the index registry - * @return {@code true} if this shouldn't be merged with other restrictions - */ - default boolean skipMerge(IndexRegistry indexRegistry) - { - return false; - } - /** * Merges this restriction with the specified one. * @@ -141,4 +129,16 @@ public default MultiClusteringBuilder appendBoundTo(MultiClusteringBuilder build { return appendTo(builder, options); } + + /** + * @return true if the other restriction should be merged with this one. + * This is NOT for preventing illegal combinations of restrictions, e.g. + * a=1 AND a=2; that is handled by mergeWith. Instead, this is for the case + * where we want two completely different semantics against the same column. + * Currently the only such case is BM25 with MATCH. 
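+ * For example, an EQ restriction on the same column is still offered to
+ * mergeWith (where the BM25 restriction's doMergeWith then throws), while a
+ * MATCH restriction is kept entirely separate so that match filtering and
+ * BM25 ordering can both target the analyzed index.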
+ */ + default boolean shouldMerge(SingleRestriction other) + { + return true; + } } diff --git a/src/java/org/apache/cassandra/cql3/restrictions/StatementRestrictions.java b/src/java/org/apache/cassandra/cql3/restrictions/StatementRestrictions.java index b20d658867e8..ac49d7d6f694 100644 --- a/src/java/org/apache/cassandra/cql3/restrictions/StatementRestrictions.java +++ b/src/java/org/apache/cassandra/cql3/restrictions/StatementRestrictions.java @@ -83,6 +83,7 @@ public class StatementRestrictions "Restriction on partition key column %s must not be nested under OR operator"; public static final String GEO_DISTANCE_REQUIRES_INDEX_MESSAGE = "GEO_DISTANCE requires the vector column to be indexed"; + public static final String BM25_ORDERING_REQUIRES_ANALYZED_INDEX_MESSAGE = "BM25 ordering on column %s requires an analyzed index"; public static final String NON_CLUSTER_ORDERING_REQUIRES_INDEX_MESSAGE = "Ordering on non-clustering column %s requires the column to be indexed."; public static final String NON_CLUSTER_ORDERING_REQUIRES_ALL_RESTRICTED_NON_PARTITION_KEY_COLUMNS_INDEXED_MESSAGE = "Ordering on non-clustering column requires each restricted column to be indexed except for fully-specified partition keys"; @@ -445,7 +446,7 @@ else if (def.isClusteringColumn() && nestingLevel == 0) } else { - nonPrimaryKeyRestrictionSet.addRestriction((SingleRestriction) restriction, element.isDisjunction(), indexRegistry); + nonPrimaryKeyRestrictionSet.addRestriction((SingleRestriction) restriction, element.isDisjunction()); } } } @@ -685,7 +686,8 @@ else if (indexOrderings.size() == 1) if (orderings.size() > 1) throw new InvalidRequestException("Cannot combine clustering column ordering with non-clustering column ordering"); Ordering ordering = indexOrderings.get(0); - if (ordering.direction != Ordering.Direction.ASC && ordering.expression instanceof Ordering.Ann) + // TODO remove the instanceof with SelectStatement.ANN_USE_SYNTHETIC_SCORE. 
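+ // (Bm25.isScored() is always true, and Ann.isScored() follows ANN_USE_SYNTHETIC_SCORE,
+ // which defaults to false; the instanceof check therefore still catches descending
+ // "ORDER BY ... ANN OF" orderings, and descending BM25 orderings are caught by isScored().)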
+ if (ordering.direction != Ordering.Direction.ASC && (ordering.expression.isScored() || ordering.expression instanceof Ordering.Ann)) throw new InvalidRequestException("Descending ANN ordering is not supported"); if (!ENABLE_SAI_GENERAL_ORDER_BY && ordering.expression instanceof Ordering.SingleColumn) throw new InvalidRequestException("SAI based ORDER BY on non-vector column is not supported"); @@ -698,10 +700,14 @@ else if (indexOrderings.size() == 1) throw new InvalidRequestException(String.format("SAI based ordering on column %s of type %s is not supported", restriction.getFirstColumn(), restriction.getFirstColumn().type.asCQL3Type())); - throw new InvalidRequestException(String.format(NON_CLUSTER_ORDERING_REQUIRES_INDEX_MESSAGE, - restriction.getFirstColumn())); + if (ordering.expression instanceof Ordering.Bm25) + throw new InvalidRequestException(String.format(BM25_ORDERING_REQUIRES_ANALYZED_INDEX_MESSAGE, + restriction.getFirstColumn())); + else + throw new InvalidRequestException(String.format(NON_CLUSTER_ORDERING_REQUIRES_INDEX_MESSAGE, + restriction.getFirstColumn())); } - receiver.addRestriction(restriction, false, indexRegistry); + receiver.addRestriction(restriction, false); } } diff --git a/src/java/org/apache/cassandra/cql3/selection/ColumnFilterFactory.java b/src/java/org/apache/cassandra/cql3/selection/ColumnFilterFactory.java index 63fa0520101e..00225cca4108 100644 --- a/src/java/org/apache/cassandra/cql3/selection/ColumnFilterFactory.java +++ b/src/java/org/apache/cassandra/cql3/selection/ColumnFilterFactory.java @@ -38,9 +38,21 @@ abstract class ColumnFilterFactory */ abstract ColumnFilter newInstance(List selectors); - public static ColumnFilterFactory wildcard(TableMetadata table) + public static ColumnFilterFactory wildcard(TableMetadata table, Set orderingColumns) { - return new PrecomputedColumnFilter(ColumnFilter.all(table)); + ColumnFilter cf; + if (orderingColumns.isEmpty()) + { + cf = ColumnFilter.all(table); + } + else + { + ColumnFilter.Builder builder = ColumnFilter.selectionBuilder(); + builder.addAll(table.regularAndStaticColumns()); + builder.addAll(orderingColumns); + cf = builder.build(); + } + return new PrecomputedColumnFilter(cf); } public static ColumnFilterFactory fromColumns(TableMetadata table, diff --git a/src/java/org/apache/cassandra/cql3/selection/Selection.java b/src/java/org/apache/cassandra/cql3/selection/Selection.java index 02aae61dd5ff..12d8aa014e19 100644 --- a/src/java/org/apache/cassandra/cql3/selection/Selection.java +++ b/src/java/org/apache/cassandra/cql3/selection/Selection.java @@ -43,10 +43,24 @@ public abstract class Selection private static final Predicate STATIC_COLUMN_FILTER = (column) -> column.isStatic(); private final TableMetadata table; + + // Full list of columns needed for processing the query, including selected columns, ordering columns, + // and columns needed for restrictions. Wildcard columns are fully materialized here. + // + // This also includes synthetic columns, because unlike all the other non-physical selectables, they are + // computed on the replica instead of the coordinator and so, like physical columns, they need to be sent back + // as part of the result. 
private final List columns; + + // maps ColumnSpecifications (columns, function calls, aliases) to the columns backing them private final SelectionColumnMapping columnMapping; + + // metadata matching the ColumnSpecifications protected final ResultSet.ResultMetadata metadata; + + // creates a ColumnFilter that breaks columns into `queried` and `fetched` protected final ColumnFilterFactory columnFilterFactory; + protected final boolean isJson; // Columns used to order the result set for JSON queries with post ordering. @@ -126,10 +140,15 @@ public ResultSet.ResultMetadata getResultMetadata() } public static Selection wildcard(TableMetadata table, boolean isJson, boolean returnStaticContentOnPartitionWithNoRows) + { + return wildcard(table, Collections.emptySet(), isJson, returnStaticContentOnPartitionWithNoRows); + } + + public static Selection wildcard(TableMetadata table, Set orderingColumns, boolean isJson, boolean returnStaticContentOnPartitionWithNoRows) { List all = new ArrayList<>(table.columns().size()); Iterators.addAll(all, table.allColumnsInSelectOrder()); - return new SimpleSelection(table, all, Collections.emptySet(), true, isJson, returnStaticContentOnPartitionWithNoRows); + return new SimpleSelection(table, all, orderingColumns, true, isJson, returnStaticContentOnPartitionWithNoRows); } public static Selection wildcardWithGroupBy(TableMetadata table, @@ -400,7 +419,7 @@ public SimpleSelection(TableMetadata table, selectedColumns, orderingColumns, SelectionColumnMapping.simpleMapping(selectedColumns), - isWildcard ? ColumnFilterFactory.wildcard(table) + isWildcard ? ColumnFilterFactory.wildcard(table, orderingColumns) : ColumnFilterFactory.fromColumns(table, selectedColumns, orderingColumns, Collections.emptySet(), returnStaticContentOnPartitionWithNoRows), isWildcard, isJson); diff --git a/src/java/org/apache/cassandra/cql3/statements/SelectStatement.java b/src/java/org/apache/cassandra/cql3/statements/SelectStatement.java index 58b75aacab5a..63fd7abd2b70 100644 --- a/src/java/org/apache/cassandra/cql3/statements/SelectStatement.java +++ b/src/java/org/apache/cassandra/cql3/statements/SelectStatement.java @@ -40,6 +40,7 @@ import org.apache.cassandra.cql3.restrictions.ExternalRestriction; import org.apache.cassandra.cql3.restrictions.Restrictions; import org.apache.cassandra.cql3.selection.SortedRowsBuilder; +import org.apache.cassandra.db.marshal.FloatType; import org.apache.cassandra.guardrails.Guardrails; import org.apache.cassandra.sensors.SensorsCustomParams; import org.apache.cassandra.schema.ColumnMetadata; @@ -105,6 +106,12 @@ */ public class SelectStatement implements CQLStatement.SingleKeyspaceCqlStatement { + // TODO remove this when we no longer need to downgrade to replicas that don't know about synthetic columns, + // and the related code in + // - StatementRestrictions.addOrderingRestrictions + // - StorageAttachedIndexSearcher.PrimaryKeyIterator constructor + public static final boolean ANN_USE_SYNTHETIC_SCORE = Boolean.parseBoolean(System.getProperty("cassandra.sai.ann_use_synthetic_score", "false")); + private static final Logger logger = LoggerFactory.getLogger(SelectStatement.class); private static final NoSpamLogger noSpamLogger = NoSpamLogger.getLogger(SelectStatement.logger, 1, TimeUnit.MINUTES); public static final String TOPK_CONSISTENCY_LEVEL_ERROR = "Top-K queries can only be run with consistency level ONE/LOCAL_ONE. 
Consistency level %s was used."; @@ -1100,12 +1107,16 @@ void processPartition(RowIterator partition, QueryOptions options, ResultSetBuil case CLUSTERING: result.add(row.clustering().bufferAt(def.position())); break; + case SYNTHETIC: + // treat as REGULAR case REGULAR: result.add(row.getColumnData(def), nowInSec); break; case STATIC: result.add(staticRow.getColumnData(def), nowInSec); break; + default: + throw new AssertionError(); } } } @@ -1191,6 +1202,9 @@ public SelectStatement prepare(boolean forView, UnaryOperator keyspaceMa List selectables = RawSelector.toSelectables(selectClause, table); boolean containsOnlyStaticColumns = selectOnlyStaticColumns(table, selectables); + // Besides actual restrictions (where clauses), prepareRestrictions will include pseudo-restrictions + // on indexed columns to allow pushing ORDER BY into the index; see StatementRestrictions::addOrderingRestrictions. + // Therefore, we don't want to convert an ANN Ordering column into a +score column until after that. List orderings = getOrderings(table); StatementRestrictions restrictions = prepareRestrictions( table, bindVariables, orderings, containsOnlyStaticColumns, forView); @@ -1198,6 +1212,11 @@ public SelectStatement prepare(boolean forView, UnaryOperator keyspaceMa // If we order post-query, the sorted column needs to be in the ResultSet for sorting, // even if we don't ultimately ship them to the client (CASSANDRA-4911). Map orderingColumns = getOrderingColumns(orderings); + // +score column for ANN/BM25 + var scoreOrdering = getScoreOrdering(orderings); + assert scoreOrdering == null || orderingColumns.isEmpty() : "can't have both scored ordering and column ordering"; + if (scoreOrdering != null) + orderingColumns = scoreOrdering; Set resultSetOrderingColumns = getResultSetOrdering(restrictions, orderingColumns); Selection selection = prepareSelection(table, @@ -1226,9 +1245,9 @@ public SelectStatement prepare(boolean forView, UnaryOperator keyspaceMa if (!orderingColumns.isEmpty()) { assert !forView; - verifyOrderingIsAllowed(restrictions, orderingColumns); + verifyOrderingIsAllowed(table, restrictions, orderingColumns); orderingComparator = getOrderingComparator(selection, restrictions, orderingColumns); - isReversed = isReversed(table, orderingColumns, restrictions); + isReversed = isReversed(table, orderingColumns); if (isReversed && orderingComparator != null) orderingComparator = orderingComparator.reverse(); } @@ -1252,6 +1271,21 @@ public SelectStatement prepare(boolean forView, UnaryOperator keyspaceMa options); } + private Map getScoreOrdering(List orderings) + { + if (orderings.isEmpty()) + return null; + + var expr = orderings.get(0).expression; + if (!expr.isScored()) + return null; + + // Create synthetic score column + ColumnMetadata sourceColumn = expr.getColumn(); + var cm = ColumnMetadata.syntheticColumn(sourceColumn.ksName, sourceColumn.cfName, ColumnMetadata.SYNTHETIC_SCORE_ID, FloatType.instance); + return Map.of(cm, orderings.get(0)); + } + private Set getResultSetOrdering(StatementRestrictions restrictions, Map orderingColumns) { if (restrictions.keyIsInRelation() || orderingColumns.values().stream().anyMatch(o -> o.expression.hasNonClusteredOrdering())) @@ -1269,8 +1303,9 @@ private Selection prepareSelection(TableMetadata table, if (selectables.isEmpty()) // wildcard query { - return hasGroupBy ? 
Selection.wildcardWithGroupBy(table, boundNames, parameters.isJson, restrictions.returnStaticContentOnPartitionWithNoRows()) - : Selection.wildcard(table, parameters.isJson, restrictions.returnStaticContentOnPartitionWithNoRows()); + return hasGroupBy + ? Selection.wildcardWithGroupBy(table, boundNames, parameters.isJson, restrictions.returnStaticContentOnPartitionWithNoRows()) + : Selection.wildcard(table, parameters.isJson, restrictions.returnStaticContentOnPartitionWithNoRows()); } return Selection.fromSelectors(table, @@ -1312,13 +1347,14 @@ private Map getOrderingColumns(List ordering if (orderings.isEmpty()) return Collections.emptyMap(); - Map orderingColumns = new LinkedHashMap<>(); - for (Ordering ordering : orderings) - { - ColumnMetadata column = ordering.expression.getColumn(); - orderingColumns.put(column, ordering); - } - return orderingColumns; + return orderings.stream() + .filter(ordering -> !ordering.expression.isScored()) + .collect(Collectors.toMap(ordering -> ordering.expression.getColumn(), + ordering -> ordering, + (a, b) -> { + throw new IllegalStateException("Duplicate keys"); + }, + LinkedHashMap::new)); } private List getOrderings(TableMetadata table) @@ -1365,12 +1401,28 @@ private Term prepareLimit(VariableSpecifications boundNames, Term.Raw limit, return prepLimit; } - private static void verifyOrderingIsAllowed(StatementRestrictions restrictions, Map orderingColumns) throws InvalidRequestException + private static void verifyOrderingIsAllowed(TableMetadata table, StatementRestrictions restrictions, Map orderingColumns) throws InvalidRequestException { if (orderingColumns.values().stream().anyMatch(o -> o.expression.hasNonClusteredOrdering())) return; + checkFalse(restrictions.usesSecondaryIndexing(), "ORDER BY with 2ndary indexes is not supported."); checkFalse(restrictions.isKeyRange(), "ORDER BY is only supported when the partition key is restricted by an EQ or an IN."); + + // check that clustering columns are valid + int i = 0; + for (var entry : orderingColumns.entrySet()) + { + ColumnMetadata def = entry.getKey(); + checkTrue(def.isClusteringColumn(), + "Order by is currently only supported on indexed columns and the clustered columns of the PRIMARY KEY, got %s", def.name); + while (i != def.position()) + { + checkTrue(restrictions.isColumnRestrictedByEq(table.clusteringColumns().get(i++)), + "Ordering by clustered columns must follow the declared order in the PRIMARY KEY"); + } + i++; + } } private static void validateDistinctSelection(TableMetadata metadata, @@ -1485,35 +1537,30 @@ private ColumnComparator> getOrderingComparator(Selection selec : new CompositeComparator(sorters, idToSort); } - private boolean isReversed(TableMetadata table, Map orderingColumns, StatementRestrictions restrictions) throws InvalidRequestException + private boolean isReversed(TableMetadata table, Map orderingColumns) throws InvalidRequestException { - // Nonclustered ordering handles descending logic in a different way + // Nonclustered ordering handles descending logic through ScoreOrderedResultRetriever and TKP if (orderingColumns.values().stream().anyMatch(o -> o.expression.hasNonClusteredOrdering())) return false; - Boolean[] reversedMap = new Boolean[table.clusteringColumns().size()]; - int i = 0; + Boolean[] clusteredMap = new Boolean[table.clusteringColumns().size()]; for (var entry : orderingColumns.entrySet()) { ColumnMetadata def = entry.getKey(); Ordering ordering = entry.getValue(); - boolean reversed = ordering.direction == Ordering.Direction.DESC; - - // 
VSTODO move this to verifyOrderingIsAllowed? - checkTrue(def.isClusteringColumn(), - "Order by is currently only supported on the clustered columns of the PRIMARY KEY, got %s", def.name); - while (i != def.position()) - { - checkTrue(restrictions.isColumnRestrictedByEq(table.clusteringColumns().get(i++)), - "Order by currently only supports the ordering of columns following their declared order in the PRIMARY KEY"); - } - i++; - reversedMap[def.position()] = (reversed != def.isReversedType()); + // We defined ANN OF to be ASC ordering, as in, "order by near-ness". But since score goes from + // 0 (worst) to 1 (closest), we need to reverse the ordering for the comparator when we're sorting + // by synthetic +score column. + boolean cqlReversed = ordering.direction == Ordering.Direction.DESC; + if (def.position() == ColumnMetadata.NO_POSITION) + return ordering.expression.isScored() || cqlReversed; + else + clusteredMap[def.position()] = (cqlReversed != def.isReversedType()); } - // Check that all boolean in reversedMap, if set, agrees + // Check that all booleans in clusteredMap, if set, agree Boolean isReversed = null; - for (Boolean b : reversedMap) + for (Boolean b : clusteredMap) { // Column on which order is specified can be in any order if (b == null) @@ -1658,7 +1705,14 @@ public int compare(T o1, T o2) { return wrapped.compare(o2, o1); } + + @Override + public boolean indexOrdering() + { + return wrapped.indexOrdering(); + } } + /** * Used in orderResults(...) method when single 'ORDER BY' condition where given */ diff --git a/src/java/org/apache/cassandra/db/Columns.java b/src/java/org/apache/cassandra/db/Columns.java index 9a904de12253..45b3da97a596 100644 --- a/src/java/org/apache/cassandra/db/Columns.java +++ b/src/java/org/apache/cassandra/db/Columns.java @@ -28,6 +28,7 @@ import net.nicoulaj.compilecommand.annotations.DontInline; import org.apache.cassandra.cql3.ColumnIdentifier; +import org.apache.cassandra.db.marshal.AbstractType; import org.apache.cassandra.db.marshal.SetType; import org.apache.cassandra.db.marshal.UTF8Type; import org.apache.cassandra.db.rows.ColumnData; @@ -36,6 +37,7 @@ import org.apache.cassandra.io.util.DataOutputPlus; import org.apache.cassandra.schema.ColumnMetadata; import org.apache.cassandra.schema.TableMetadata; +import org.apache.cassandra.serializers.AbstractTypeSerializer; import org.apache.cassandra.utils.ByteBufferUtil; import org.apache.cassandra.utils.ObjectSizes; import org.apache.cassandra.utils.SearchIterator; @@ -459,37 +461,107 @@ public String toString() public static class Serializer { + AbstractTypeSerializer typeSerializer = new AbstractTypeSerializer(); + public void serialize(Columns columns, DataOutputPlus out) throws IOException { - out.writeUnsignedVInt(columns.size()); + int regularCount = 0; + int syntheticCount = 0; + + // Count regular and synthetic columns + for (ColumnMetadata column : columns) + { + if (column.isSynthetic()) + syntheticCount++; + else + regularCount++; + } + + // Jam the two counts into a single value to avoid massive backwards compatibility issues + long packedCount = getPackedCount(syntheticCount, regularCount); + out.writeUnsignedVInt(packedCount); + + // First pass - write synthetic columns with their full metadata for (ColumnMetadata column : columns) - ByteBufferUtil.writeWithVIntLength(column.name.bytes, out); + { + if (column.isSynthetic()) + { + ByteBufferUtil.writeWithVIntLength(column.name.bytes, out); + typeSerializer.serialize(column.type, out); + } + } + + // Second pass - write regular 
columns + for (ColumnMetadata column : columns) + { + if (!column.isSynthetic()) + ByteBufferUtil.writeWithVIntLength(column.name.bytes, out); + } + } + + private static long getPackedCount(int syntheticCount, int regularCount) + { + // Left shift of 20 gives us over 1M regular columns, and up to 4 synthetic columns + // before overflowing to a 4th byte. + return ((long) syntheticCount << 20) | regularCount; } public long serializedSize(Columns columns) { - long size = TypeSizes.sizeofUnsignedVInt(columns.size()); + int regularCount = 0; + int syntheticCount = 0; + long size = 0; + + // Count and calculate sizes for (ColumnMetadata column : columns) - size += ByteBufferUtil.serializedSizeWithVIntLength(column.name.bytes); - return size; + { + if (column.isSynthetic()) + { + syntheticCount++; + size += ByteBufferUtil.serializedSizeWithVIntLength(column.name.bytes); + size += typeSerializer.serializedSize(column.type); + } + else + { + regularCount++; + size += ByteBufferUtil.serializedSizeWithVIntLength(column.name.bytes); + } + } + + return TypeSizes.sizeofUnsignedVInt(getPackedCount(syntheticCount, regularCount)) + + size; } public Columns deserialize(DataInputPlus in, TableMetadata metadata) throws IOException { - int length = (int)in.readUnsignedVInt(); try (BTree.FastBuilder builder = BTree.fastBuilder()) { - for (int i = 0; i < length; i++) + long packedCount = in.readUnsignedVInt() ; + int regularCount = (int) (packedCount & 0xFFFFF); + int syntheticCount = (int) (packedCount >> 20); + + // First pass - synthetic columns + for (int i = 0; i < syntheticCount; i++) + { + ByteBuffer name = ByteBufferUtil.readWithVIntLength(in); + AbstractType type = typeSerializer.deserialize(in); + + if (!name.equals(ColumnMetadata.SYNTHETIC_SCORE_ID.bytes)) + throw new IllegalStateException("Unknown synthetic column " + UTF8Type.instance.getString(name)); + + ColumnMetadata column = ColumnMetadata.syntheticColumn(metadata.keyspace, metadata.name, ColumnMetadata.SYNTHETIC_SCORE_ID, type); + builder.add(column); + } + + // Second pass - regular columns + for (int i = 0; i < regularCount; i++) { ByteBuffer name = ByteBufferUtil.readWithVIntLength(in); ColumnMetadata column = metadata.getColumn(name); if (column == null) { - // If we don't find the definition, it could be we have data for a dropped column, and we shouldn't - // fail deserialization because of that. So we grab a "fake" ColumnMetadata that ensure proper - // deserialization. The column will be ignore later on anyway. + // If we don't find the definition, it could be we have data for a dropped column column = metadata.getDroppedColumn(name); - if (column == null) throw new RuntimeException("Unknown column " + UTF8Type.instance.getString(name) + " during deserialization of " + metadata.keyspace + '.' 
+ metadata.name); } diff --git a/src/java/org/apache/cassandra/db/ReadCommand.java b/src/java/org/apache/cassandra/db/ReadCommand.java index b110a3ada2de..7fa00f0436e5 100644 --- a/src/java/org/apache/cassandra/db/ReadCommand.java +++ b/src/java/org/apache/cassandra/db/ReadCommand.java @@ -74,6 +74,7 @@ import org.apache.cassandra.net.Message; import org.apache.cassandra.net.MessageFlag; import org.apache.cassandra.net.Verb; +import org.apache.cassandra.schema.ColumnMetadata; import org.apache.cassandra.schema.IndexMetadata; import org.apache.cassandra.schema.Schema; import org.apache.cassandra.schema.SchemaConstants; @@ -384,8 +385,6 @@ static Index.QueryPlan findIndexQueryPlan(TableMetadata table, RowFilter rowFilt @Override public void maybeValidateIndexes() { - IndexRegistry.obtain(metadata()).validate(rowFilter()); - if (null != indexQueryPlan) indexQueryPlan.validate(this); } @@ -415,9 +414,9 @@ public UnfilteredPartitionIterator executeLocally(ReadExecutionController execut } Context context = Context.from(this); - UnfilteredPartitionIterator iterator = (null == searcher) ? Transformation.apply(queryStorage(cfs, executionController), new TrackingRowIterator(context)) - : Transformation.apply(searchStorage(searcher, executionController), new TrackingRowIterator(context)); - + var storageTarget = (null == searcher) ? queryStorage(cfs, executionController) + : searchStorage(searcher, executionController); + UnfilteredPartitionIterator iterator = Transformation.apply(storageTarget, new TrackingRowIterator(context)); iterator = RTBoundValidator.validate(iterator, Stage.MERGED, false); try @@ -1054,6 +1053,19 @@ public ReadCommand deserialize(DataInputPlus in, int version) throws IOException TableMetadata metadata = schema.getExistingTableMetadata(TableId.deserialize(in)); int nowInSec = in.readInt(); ColumnFilter columnFilter = ColumnFilter.serializer.deserialize(in, version, metadata); + + // add synthetic columns to the tablemetadata so we can serialize them in our response + var tmb = metadata.unbuild(); + for (var it = columnFilter.fetchedColumns().regulars.simpleColumns(); it.hasNext(); ) + { + var c = it.next(); + // synthetic columns sort first, so when we hit the first non-synthetic, we're done + if (!c.isSynthetic()) + break; + tmb.addColumn(ColumnMetadata.syntheticColumn(c.ksName, c.cfName, c.name, c.type)); + } + metadata = tmb.build(); + RowFilter rowFilter = RowFilter.serializer.deserialize(in, version, metadata); DataLimits limits = DataLimits.serializer.deserialize(in, version, metadata.comparator); Index.QueryPlan indexQueryPlan = null; diff --git a/src/java/org/apache/cassandra/db/RegularAndStaticColumns.java b/src/java/org/apache/cassandra/db/RegularAndStaticColumns.java index b6da183d013f..55533eda0e97 100644 --- a/src/java/org/apache/cassandra/db/RegularAndStaticColumns.java +++ b/src/java/org/apache/cassandra/db/RegularAndStaticColumns.java @@ -163,7 +163,7 @@ public Builder add(ColumnMetadata c) } else { - assert c.isRegular(); + assert c.isRegular() || c.isSynthetic(); if (regularColumns == null) regularColumns = BTree.builder(naturalOrder()); regularColumns.add(c); @@ -197,7 +197,7 @@ public Builder addAll(RegularAndStaticColumns columns) public RegularAndStaticColumns build() { - return new RegularAndStaticColumns(staticColumns == null ? Columns.NONE : Columns.from(staticColumns), + return new RegularAndStaticColumns(staticColumns == null ? Columns.NONE : Columns.from(staticColumns), regularColumns == null ? 
Columns.NONE : Columns.from(regularColumns)); } } diff --git a/src/java/org/apache/cassandra/db/filter/ColumnFilter.java b/src/java/org/apache/cassandra/db/filter/ColumnFilter.java index d9a1b9d4e51a..644e6d661a61 100644 --- a/src/java/org/apache/cassandra/db/filter/ColumnFilter.java +++ b/src/java/org/apache/cassandra/db/filter/ColumnFilter.java @@ -75,6 +75,9 @@ public abstract class ColumnFilter public static final Serializer serializer = new Serializer(); + // TODO remove this with ANN_USE_SYNTHETIC_SCORE + public abstract boolean fetchesExplicitly(ColumnMetadata column); + /** * The fetching strategy for the different queries. */ @@ -103,7 +106,8 @@ boolean fetchesAllColumns(boolean isStatic) @Override RegularAndStaticColumns getFetchedColumns(TableMetadata metadata, RegularAndStaticColumns queried) { - return metadata.regularAndStaticColumns(); + var merged = queried.regulars.mergeTo(metadata.regularColumns()); + return new RegularAndStaticColumns(metadata.staticColumns(), merged); } }, @@ -124,7 +128,8 @@ boolean fetchesAllColumns(boolean isStatic) @Override RegularAndStaticColumns getFetchedColumns(TableMetadata metadata, RegularAndStaticColumns queried) { - return new RegularAndStaticColumns(queried.statics, metadata.regularColumns()); + var merged = queried.regulars.mergeTo(metadata.regularColumns()); + return new RegularAndStaticColumns(queried.statics, merged); } }, @@ -295,14 +300,16 @@ public static ColumnFilter selection(TableMetadata metadata, } /** - * The columns that needs to be fetched internally for this filter. + * The columns that need to be fetched internally. See FetchingStrategy for why this is + * always a superset of the queried columns. * * @return the columns to fetch for this filter. */ public abstract RegularAndStaticColumns fetchedColumns(); /** - * The columns actually queried by the user. + * The columns needed to process the query, including selected columns, ordering columns, + * restriction (predicate) columns, and synthetic columns. *

* Note that this is in general not all the columns that are fetched internally (see {@link #fetchedColumns}). */ @@ -619,9 +626,7 @@ private SortedSetMultimap buildSubSelectio */ public static class WildCardColumnFilter extends ColumnFilter { - /** - * The queried and fetched columns. - */ + // for wildcards, there is no distinction between fetched and queried because queried is already "everything" private final RegularAndStaticColumns fetchedAndQueried; /** @@ -667,6 +672,12 @@ public boolean fetches(ColumnMetadata column) return true; } + @Override + public boolean fetchesExplicitly(ColumnMetadata column) + { + return false; + } + @Override public boolean fetchedColumnIsQueried(ColumnMetadata column) { @@ -739,14 +750,9 @@ public static class SelectionColumnFilter extends ColumnFilter { public final FetchingStrategy fetchingStrategy; - /** - * The selected columns - */ + // Materializes the columns required to implement queriedColumns() and fetchedColumns(), + // see the comments on the superclass's methods private final RegularAndStaticColumns queried; - - /** - * The columns that need to be fetched to be able - */ private final RegularAndStaticColumns fetched; private final SortedSetMultimap subSelections; // can be null @@ -820,6 +826,12 @@ public boolean fetches(ColumnMetadata column) return fetchingStrategy.fetchesAllColumns(column.isStatic()) || fetched.contains(column); } + @Override + public boolean fetchesExplicitly(ColumnMetadata column) + { + return fetched.contains(column); + } + /** * Whether the provided complex cell (identified by its column and path), which is assumed to be _fetched_ by * this filter, is also _queried_ by the user. diff --git a/src/java/org/apache/cassandra/db/filter/RowFilter.java b/src/java/org/apache/cassandra/db/filter/RowFilter.java index 96ad240b539c..4d35566fa41f 100644 --- a/src/java/org/apache/cassandra/db/filter/RowFilter.java +++ b/src/java/org/apache/cassandra/db/filter/RowFilter.java @@ -1239,6 +1239,7 @@ public boolean isSatisfiedBy(TableMetadata metadata, DecoratedKey partitionKey, case LIKE_MATCHES: case ANALYZER_MATCHES: case ANN: + case BM25: { assert !column.isComplex() : "Only CONTAINS and CONTAINS_KEY are supported for 'complex' types"; ByteBuffer foundValue = getValue(metadata, partitionKey, row); diff --git a/src/java/org/apache/cassandra/db/monitoring/MonitorableImpl.java b/src/java/org/apache/cassandra/db/monitoring/MonitorableImpl.java index a6e7947b23f1..30376eb385f3 100644 --- a/src/java/org/apache/cassandra/db/monitoring/MonitorableImpl.java +++ b/src/java/org/apache/cassandra/db/monitoring/MonitorableImpl.java @@ -18,6 +18,8 @@ package org.apache.cassandra.db.monitoring; +import org.apache.cassandra.index.sai.QueryContext; + import static org.apache.cassandra.utils.MonotonicClock.approxTime; public abstract class MonitorableImpl implements Monitorable @@ -123,6 +125,9 @@ public boolean complete() private void check() { + if (QueryContext.DISABLE_TIMEOUT) + return; + if (approxCreationTimeNanos < 0 || state != MonitoringState.IN_PROGRESS) return; diff --git a/src/java/org/apache/cassandra/db/partitions/ParallelCommandProcessor.java b/src/java/org/apache/cassandra/db/partitions/ParallelCommandProcessor.java deleted file mode 100644 index 19d8523bd5c9..000000000000 --- a/src/java/org/apache/cassandra/db/partitions/ParallelCommandProcessor.java +++ /dev/null @@ -1,65 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements.
See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.cassandra.db.partitions; - -import java.util.List; - -import org.apache.cassandra.db.SinglePartitionReadCommand; -import org.apache.cassandra.db.rows.UnfilteredRowIterator; -import org.apache.cassandra.index.sai.utils.PrimaryKey; -import org.apache.cassandra.utils.Pair; - -/** - * An "processor" over a number of unfiltered partitions (i.e. partitions containing deletion information). - * - * Unlike {@link UnfilteredPartitionIterator}, this is designed to be used concurrently. - * - * Unlike UnfilteredPartitionIterator which requires single-threaded - * while (partitions.hasNext()) - * { - * var part = partitions.next(); - * ... - * part.close(); - * } - * - * this one allows concurrency, like - * var commands = partitions.getUninitializedCommands(); - * commands.parallelStream().forEach(tuple -> { - * var iter = partitions.commandToIterator(tuple.left(), tuple.right()); - * } - */ -public interface ParallelCommandProcessor -{ - /** - * Single-threaded call to get all commands and corresponding keys. - * - * @return the list of partition read commands. - */ - List> getUninitializedCommands(); - - /** - * Get an iterator for a given command and key. - * This method can be called concurrently for reulst of getUninitializedCommands(). - * - * @param command - * @param key - * @return - */ - UnfilteredRowIterator commandToIterator(PrimaryKey key, SinglePartitionReadCommand command); -} diff --git a/src/java/org/apache/cassandra/index/IndexRegistry.java b/src/java/org/apache/cassandra/index/IndexRegistry.java index 639d9aec350b..94bdfe0792fc 100644 --- a/src/java/org/apache/cassandra/index/IndexRegistry.java +++ b/src/java/org/apache/cassandra/index/IndexRegistry.java @@ -21,6 +21,7 @@ package org.apache.cassandra.index; import java.util.Collection; +import java.util.HashSet; import java.util.Collections; import java.util.Optional; import java.util.Set; @@ -103,12 +104,6 @@ public Optional getBestIndexFor(RowFilter.Expression expression) public void validate(PartitionUpdate update) { } - - @Override - public void validate(RowFilter filter) - { - // no-op since it's an empty registry - } }; /** @@ -296,12 +291,6 @@ public Optional getBestIndexFor(RowFilter.Expression expression) public void validate(PartitionUpdate update) { } - - @Override - public void validate(RowFilter filter) - { - // no-op since it's an empty registry - } }; default void registerIndex(Index index) @@ -354,8 +343,6 @@ default Optional getAnalyzerFor(ColumnMetadata column, */ void validate(PartitionUpdate update); - void validate(RowFilter filter); - /** * Returns the {@code IndexRegistry} associated to the specified table. * @@ -369,4 +356,74 @@ public static IndexRegistry obtain(TableMetadata table) return table.isVirtual() ? 
EMPTY : Keyspace.openAndGetStore(table).indexManager; } + + enum EqBehavior + { + EQ, + MATCH, + AMBIGUOUS + } + + class EqBehaviorIndexes + { + public EqBehavior behavior; + public final Index eqIndex; + public final Index matchIndex; + + private EqBehaviorIndexes(Index eqIndex, Index matchIndex, EqBehavior behavior) + { + this.eqIndex = eqIndex; + this.matchIndex = matchIndex; + this.behavior = behavior; + } + + public static EqBehaviorIndexes eq(Index eqIndex) + { + return new EqBehaviorIndexes(eqIndex, null, EqBehavior.EQ); + } + + public static EqBehaviorIndexes match(Index eqAndMatchIndex) + { + return new EqBehaviorIndexes(eqAndMatchIndex, eqAndMatchIndex, EqBehavior.MATCH); + } + + public static EqBehaviorIndexes ambiguous(Index firstEqIndex, Index secondEqIndex) + { + return new EqBehaviorIndexes(firstEqIndex, secondEqIndex, EqBehavior.AMBIGUOUS); + } + } + + /** + * @return + * - AMBIGUOUS if an index supports EQ and a different one supports both EQ and ANALYZER_MATCHES + * - MATCH if an index supports both EQ and ANALYZER_MATCHES + * - otherwise EQ + */ + default EqBehaviorIndexes getEqBehavior(ColumnMetadata cm) + { + Index eqOnlyIndex = null; + Index bothIndex = null; + + for (Index index : listIndexes()) + { + boolean supportsEq = index.supportsExpression(cm, Operator.EQ); + boolean supportsMatches = index.supportsExpression(cm, Operator.ANALYZER_MATCHES); + + if (supportsEq && supportsMatches) + bothIndex = index; + else if (supportsEq) + eqOnlyIndex = index; + } + + // If we have one index supporting only EQ and another supporting both, return AMBIGUOUS + if (eqOnlyIndex != null && bothIndex != null) + return EqBehaviorIndexes.ambiguous(eqOnlyIndex, bothIndex); + + // If we have an index supporting both EQ and ANALYZER_MATCHES, return MATCH + if (bothIndex != null) + return EqBehaviorIndexes.match(bothIndex); + + // Otherwise return EQ + return EqBehaviorIndexes.eq(eqOnlyIndex == null ? bothIndex : eqOnlyIndex); + } } diff --git a/src/java/org/apache/cassandra/index/RowFilterValidator.java b/src/java/org/apache/cassandra/index/RowFilterValidator.java deleted file mode 100644 index fb70fbfc1452..000000000000 --- a/src/java/org/apache/cassandra/index/RowFilterValidator.java +++ /dev/null @@ -1,103 +0,0 @@ -/* - * Copyright DataStax, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.cassandra.index; - -import java.util.HashSet; -import java.util.Set; -import java.util.StringJoiner; - -import org.apache.cassandra.cql3.Operator; -import org.apache.cassandra.db.filter.RowFilter; -import org.apache.cassandra.index.sai.analyzer.AnalyzerEqOperatorSupport; -import org.apache.cassandra.schema.ColumnMetadata; -import org.apache.cassandra.service.ClientWarn; - -/** - * Class for validating the index-related aspects of a {@link RowFilter}, without considering what index is actually used. - *

- * It will emit a client warning when a query has EQ restrictions on columns having an analyzed index. - */ -class RowFilterValidator -{ - private final Iterable allIndexes; - - private Set columns; - private Set indexes; - - private RowFilterValidator(Iterable allIndexes) - { - this.allIndexes = allIndexes; - } - - private void addEqRestriction(ColumnMetadata column) - { - for (Index index : allIndexes) - { - if (index.supportsExpression(column, Operator.EQ) && - index.supportsExpression(column, Operator.ANALYZER_MATCHES)) - { - if (columns == null) - columns = new HashSet<>(); - columns.add(column); - - if (indexes == null) - indexes = new HashSet<>(); - indexes.add(index); - } - } - } - - private void validate() - { - if (columns == null || indexes == null) - return; - - StringJoiner columnNames = new StringJoiner(", "); - StringJoiner indexNames = new StringJoiner(", "); - columns.forEach(column -> columnNames.add(column.name.toString())); - indexes.forEach(index -> indexNames.add(index.getIndexMetadata().name)); - - ClientWarn.instance.warn(String.format(AnalyzerEqOperatorSupport.EQ_RESTRICTION_ON_ANALYZED_WARNING, columnNames, indexNames)); - } - - /** - * Emits a client warning if the filter contains EQ restrictions on columns having an analyzed index. - * - * @param filter the filter to validate - * @param indexes the existing indexes - */ - public static void validate(RowFilter filter, Iterable indexes) - { - RowFilterValidator validator = new RowFilterValidator(indexes); - validate(filter.root(), validator); - validator.validate(); - } - - private static void validate(RowFilter.FilterElement element, RowFilterValidator validator) - { - for (RowFilter.Expression expression : element.expressions()) - { - if (expression.operator() == Operator.EQ) - validator.addEqRestriction(expression.column()); - } - - for (RowFilter.FilterElement child : element.children()) - { - validate(child, validator); - } - } -} diff --git a/src/java/org/apache/cassandra/index/SecondaryIndexManager.java b/src/java/org/apache/cassandra/index/SecondaryIndexManager.java index b1930e741804..d4c75ed2a4f1 100644 --- a/src/java/org/apache/cassandra/index/SecondaryIndexManager.java +++ b/src/java/org/apache/cassandra/index/SecondaryIndexManager.java @@ -1277,12 +1277,6 @@ public void validate(PartitionUpdate update) throws InvalidRequestException index.validate(update); } - @Override - public void validate(RowFilter filter) - { - RowFilterValidator.validate(filter, indexes.values()); - } - /* * IndexRegistry methods */ diff --git a/src/java/org/apache/cassandra/index/sai/IndexContext.java b/src/java/org/apache/cassandra/index/sai/IndexContext.java index 7820400a6c25..1de515e16267 100644 --- a/src/java/org/apache/cassandra/index/sai/IndexContext.java +++ b/src/java/org/apache/cassandra/index/sai/IndexContext.java @@ -717,8 +717,8 @@ public boolean supports(Operator op) { if (op.isLike() || op == Operator.LIKE) return false; // Analyzed columns store the indexed result, so we are unable to compute raw equality. - // The only supported operator is ANALYZER_MATCHES. - if (op == Operator.ANALYZER_MATCHES) return isAnalyzed; + // The only supported operators are ANALYZER_MATCHES and BM25. + if (op == Operator.ANALYZER_MATCHES || op == Operator.BM25) return isAnalyzed; // If the column is analyzed and the operator is EQ, we need to check if the analyzer supports it. 
if (op == Operator.EQ && isAnalyzed && !analyzerFactory.supportsEquals()) @@ -742,7 +742,6 @@ public boolean supports(Operator op) || column.type instanceof IntegerType); // Currently truncates to 20 bytes Expression.Op operator = Expression.Op.valueOf(op); - if (isNonFrozenCollection()) { if (indexType == IndexTarget.Type.KEYS) @@ -754,17 +753,12 @@ public boolean supports(Operator op) return indexType == IndexTarget.Type.KEYS_AND_VALUES && (operator == Expression.Op.EQ || operator == Expression.Op.NOT_EQ || operator == Expression.Op.RANGE); } - if (indexType == IndexTarget.Type.FULL) return operator == Expression.Op.EQ; - AbstractType validator = getValidator(); - if (operator == Expression.Op.IN) return true; - if (operator != Expression.Op.EQ && EQ_ONLY_TYPES.contains(validator)) return false; - // RANGE only applicable to non-literal indexes return (operator != null) && !(TypeUtil.isLiteral(validator) && operator == Expression.Op.RANGE); } diff --git a/src/java/org/apache/cassandra/index/sai/QueryContext.java b/src/java/org/apache/cassandra/index/sai/QueryContext.java index a0b1719049ec..5f1f7cda8498 100644 --- a/src/java/org/apache/cassandra/index/sai/QueryContext.java +++ b/src/java/org/apache/cassandra/index/sai/QueryContext.java @@ -36,7 +36,7 @@ @NotThreadSafe public class QueryContext { - private static final boolean DISABLE_TIMEOUT = Boolean.getBoolean("cassandra.sai.test.disable.timeout"); + public static final boolean DISABLE_TIMEOUT = Boolean.getBoolean("cassandra.sai.test.disable.timeout"); protected final long queryStartTimeNanos; diff --git a/src/java/org/apache/cassandra/index/sai/StorageAttachedIndex.java b/src/java/org/apache/cassandra/index/sai/StorageAttachedIndex.java index 28cb42ab8dc8..12557cd66a2d 100644 --- a/src/java/org/apache/cassandra/index/sai/StorageAttachedIndex.java +++ b/src/java/org/apache/cassandra/index/sai/StorageAttachedIndex.java @@ -299,20 +299,26 @@ public static Map validateOptions(Map options, T throw new InvalidRequestException("Failed to retrieve target column for: " + targetColumn); } - // In order to support different index target on non-frozen map, ie. KEYS, VALUE, ENTRIES, we need to put index - // name as part of index file name instead of column name. We only need to check that the target is different - // between indexes. This will only allow indexes in the same column with a different IndexTarget.Type. 
- // - // Note that: "metadata.indexes" already includes current index - if (metadata.indexes.stream().filter(index -> index.getIndexClassName().equals(StorageAttachedIndex.class.getName())) .map(index -> TargetParser.parse(metadata, index.options.get(IndexTarget.TARGET_OPTION_NAME))) .filter(Objects::nonNull).filter(t -> t.equals(target)).count() > 1) - { - throw new InvalidRequestException("Cannot create more than one storage-attached index on the same column: " + target.left); - } + // Check for duplicate indexes considering both target and analyzer configuration + boolean isAnalyzed = AbstractAnalyzer.isAnalyzed(options); + long duplicateCount = metadata.indexes.stream() + .filter(index -> index.getIndexClassName().equals(StorageAttachedIndex.class.getName())) + .filter(index -> { + // Indexes on the same column with different target (KEYS, VALUES, ENTRIES) + // are allowed on non-frozen Maps + var existingTarget = TargetParser.parse(metadata, index.options.get(IndexTarget.TARGET_OPTION_NAME)); + if (existingTarget == null || !existingTarget.equals(target)) + return false; + // Also allow different indexes if one is analyzed and the other isn't + return isAnalyzed == AbstractAnalyzer.isAnalyzed(index.options); + }) + .count(); + // >1 because "metadata.indexes" already includes current index + if (duplicateCount > 1) + throw new InvalidRequestException(String.format("Cannot create duplicate storage-attached index on column: %s", target.left)); // Analyzer is not supported against PK columns - if (AbstractAnalyzer.isAnalyzed(options)) + if (isAnalyzed) { for (ColumnMetadata column : metadata.primaryKeyColumns()) { diff --git a/src/java/org/apache/cassandra/index/sai/analyzer/AnalyzerEqOperatorSupport.java b/src/java/org/apache/cassandra/index/sai/analyzer/AnalyzerEqOperatorSupport.java index 116bc7f62832..30408c9b986f 100644 --- a/src/java/org/apache/cassandra/index/sai/analyzer/AnalyzerEqOperatorSupport.java +++ b/src/java/org/apache/cassandra/index/sai/analyzer/AnalyzerEqOperatorSupport.java @@ -49,7 +49,7 @@ public class AnalyzerEqOperatorSupport OPTION, Arrays.toString(Value.values())); public static final String EQ_RESTRICTION_ON_ANALYZED_WARNING = - String.format("Columns [%%s] are restricted by '=' and have analyzed indexes [%%s] able to process those restrictions. " + + String.format("Column [%%s] is restricted by '=' and has an analyzed index [%%s] able to process that restriction. " + "Analyzed indexes might process '=' restrictions in a way that is inconsistent with non-indexed queries. " + "While '=' is still supported on analyzed indexes for backwards compatibility, " + "it is recommended to use the ':' operator instead to prevent the ambiguity. " + @@ -58,6 +58,13 @@ public class AnalyzerEqOperatorSupport "please use '%s':'%s' in the index options.", OPTION, Value.UNSUPPORTED.toString().toLowerCase()); + public static final String EQ_AMBIGUOUS_ERROR = + String.format("Column [%%s] equality predicate is ambiguous. It has both an analyzed index [%%s] configured with '%s':'%s', " + + "and an un-analyzed index [%%s].
" + + "To avoid ambiguity, drop the analyzed index and recreate it with option '%s':'%s'.", + OPTION, Value.MATCH.toString().toLowerCase(), OPTION, Value.UNSUPPORTED.toString().toLowerCase()); + + public static final String LWT_CONDITION_ON_ANALYZED_WARNING = "Index analyzers not applied to LWT conditions on columns [%s]."; diff --git a/src/java/org/apache/cassandra/index/sai/disk/MemtableTermsIterator.java b/src/java/org/apache/cassandra/index/sai/disk/MemtableTermsIterator.java index 9f0253e61ed8..615f3cfb7a4e 100644 --- a/src/java/org/apache/cassandra/index/sai/disk/MemtableTermsIterator.java +++ b/src/java/org/apache/cassandra/index/sai/disk/MemtableTermsIterator.java @@ -19,11 +19,11 @@ import java.nio.ByteBuffer; import java.util.Iterator; +import java.util.List; import com.google.common.base.Preconditions; -import com.carrotsearch.hppc.IntArrayList; -import com.carrotsearch.hppc.cursors.IntCursor; +import org.apache.cassandra.index.sai.memory.RowMapping; import org.apache.cassandra.utils.Pair; import org.apache.cassandra.utils.bytecomparable.ByteComparable; @@ -34,16 +34,16 @@ public class MemtableTermsIterator implements TermsIterator { private final ByteBuffer minTerm; private final ByteBuffer maxTerm; - private final Iterator> iterator; + private final Iterator>> iterator; - private Pair current; + private Pair> current; private int maxSSTableRowId = -1; private int minSSTableRowId = Integer.MAX_VALUE; public MemtableTermsIterator(ByteBuffer minTerm, ByteBuffer maxTerm, - Iterator> iterator) + Iterator>> iterator) { Preconditions.checkArgument(iterator != null); this.minTerm = minTerm; @@ -69,22 +69,24 @@ public void close() {} @Override public PostingList postings() { - final IntArrayList list = current.right; + var list = current.right; assert list.size() > 0; - final int minSegmentRowID = list.get(0); - final int maxSegmentRowID = list.get(list.size() - 1); + final int minSegmentRowID = list.get(0).rowId; + final int maxSegmentRowID = list.get(list.size() - 1).rowId; // Because we are working with postings from the memtable, there is only one segment, so segment row ids // and sstable row ids are the same. minSSTableRowId = Math.min(minSSTableRowId, minSegmentRowID); maxSSTableRowId = Math.max(maxSSTableRowId, maxSegmentRowID); - final Iterator it = list.iterator(); + var it = list.iterator(); return new PostingList() { + int frequency; + @Override public int nextPosting() { @@ -93,7 +95,9 @@ public int nextPosting() return END_OF_STREAM; } - return it.next().value; + var rowIdWithFrequency = it.next(); + frequency = rowIdWithFrequency.frequency; + return rowIdWithFrequency.rowId; } @Override @@ -102,6 +106,12 @@ public int size() return list.size(); } + @Override + public int frequency() + { + return frequency; + } + @Override public int advance(int targetRowID) { diff --git a/src/java/org/apache/cassandra/index/sai/disk/PostingList.java b/src/java/org/apache/cassandra/index/sai/disk/PostingList.java index b7dffb972599..4959c6f0be6a 100644 --- a/src/java/org/apache/cassandra/index/sai/disk/PostingList.java +++ b/src/java/org/apache/cassandra/index/sai/disk/PostingList.java @@ -42,6 +42,14 @@ default void close() throws IOException {} */ int nextPosting() throws IOException; + /** + * @return the number of occurrences of the term in the current row (the one most recently returned by nextPosting). 
+ */ + default int frequency() + { + return 1; + } + int size(); /** diff --git a/src/java/org/apache/cassandra/index/sai/disk/PrimaryKeyWithSource.java b/src/java/org/apache/cassandra/index/sai/disk/PrimaryKeyWithSource.java index 7f9f7e89f51e..d303bd4d27c7 100644 --- a/src/java/org/apache/cassandra/index/sai/disk/PrimaryKeyWithSource.java +++ b/src/java/org/apache/cassandra/index/sai/disk/PrimaryKeyWithSource.java @@ -18,6 +18,7 @@ package org.apache.cassandra.index.sai.disk; +import io.github.jbellis.jvector.util.RamUsageEstimator; import org.apache.cassandra.db.Clustering; import org.apache.cassandra.db.DecoratedKey; import org.apache.cassandra.dht.Token; @@ -127,4 +128,14 @@ public String toString() { return String.format("%s (source sstable: %s, %s)", primaryKey, sourceSstableId, sourceRowId); } + + @Override + public long ramBytesUsed() + { + // Object header + 2 references (primaryKey, sourceSstableId) + long value + return RamUsageEstimator.NUM_BYTES_OBJECT_HEADER + + 2L * RamUsageEstimator.NUM_BYTES_OBJECT_REF + + Long.BYTES + + primaryKey.ramBytesUsed(); + } } diff --git a/src/java/org/apache/cassandra/index/sai/disk/RAMPostingSlices.java b/src/java/org/apache/cassandra/index/sai/disk/RAMPostingSlices.java index 5879e5f1abdf..1313f81569ae 100644 --- a/src/java/org/apache/cassandra/index/sai/disk/RAMPostingSlices.java +++ b/src/java/org/apache/cassandra/index/sai/disk/RAMPostingSlices.java @@ -25,22 +25,33 @@ import org.apache.lucene.util.mutable.MutableValueInt; /** - * Encodes postings as variable integers into slices + * Encodes postings as variable integers into "slices" of byte blocks for efficient memory usage. */ class RAMPostingSlices { static final int DEFAULT_TERM_DICT_SIZE = 1024; + /** Pool of byte blocks storing the actual posting data */ private final ByteBlockPool postingsPool; + /** true if we're also writing term frequencies for an analyzed index */ + private final boolean includeFrequencies; + + /** The starting positions of postings for each term. Term id = index in array. */ private int[] postingStarts = new int[DEFAULT_TERM_DICT_SIZE]; + /** The current write positions for each term's postings. Term id = index in array. */ private int[] postingUptos = new int[DEFAULT_TERM_DICT_SIZE]; + /** The number of postings for each term. Term id = index in array. */ private int[] sizes = new int[DEFAULT_TERM_DICT_SIZE]; - RAMPostingSlices(Counter memoryUsage) + RAMPostingSlices(Counter memoryUsage, boolean includeFrequencies) { postingsPool = new ByteBlockPool(new ByteBlockPool.DirectTrackingAllocator(memoryUsage)); + this.includeFrequencies = includeFrequencies; } + /** + * Creates and returns a PostingList for the given term ID.
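+ * Row ids are decoded as deltas from the term's byte slices; when frequencies are enabled, each posting is followed by its term frequency.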
+ */ PostingList postingList(int termID, final ByteSliceReader reader, long maxSegmentRowID) { initReader(reader, termID); @@ -49,20 +60,35 @@ PostingList postingList(int termID, final ByteSliceReader reader, long maxSegmen return new PostingList() { + int frequency = Integer.MIN_VALUE; + @Override public int nextPosting() throws IOException { if (reader.eof()) { + frequency = Integer.MIN_VALUE; return PostingList.END_OF_STREAM; } else { lastSegmentRowId.value += reader.readVInt(); + if (includeFrequencies) + frequency = reader.readVInt(); return lastSegmentRowId.value; } } + @Override + public int frequency() + { + if (!includeFrequencies) + return 1; + if (frequency <= 0) + throw new IllegalStateException("frequency() called before nextPosting()"); + return frequency; + } + @Override public int size() { @@ -77,12 +103,20 @@ public int advance(int targetRowID) }; } + /** + * Initializes a ByteSliceReader for reading postings for a specific term. + */ void initReader(ByteSliceReader reader, int termID) { final int upto = postingUptos[termID]; reader.init(postingsPool, postingStarts[termID], upto); } + /** + * Creates a new slice for storing postings for a given term ID. + * Grows the internal arrays if necessary and allocates a new block + * if the current block cannot accommodate a new slice. + */ void createNewSlice(int termID) { if (termID >= postingStarts.length - 1) @@ -103,7 +137,27 @@ void createNewSlice(int termID) postingUptos[termID] = upto + postingsPool.byteOffset; } - void writeVInt(int termID, int i) + void writePosting(int termID, int deltaRowId, int frequency) + { + assert termID >= 0 : termID; + assert deltaRowId >= 0 : deltaRowId; + writeVInt(termID, deltaRowId); + + if (includeFrequencies) + { + assert frequency > 0 : frequency; + writeVInt(termID, frequency); + } + + sizes[termID]++; + } + + /** + * Writes a variable-length integer to the posting list for a given term. + * The integer is encoded using a variable-length encoding scheme where each + * byte uses 7 bits for the value and 1 bit to indicate if more bytes follow. + */ + private void writeVInt(int termID, int i) { while ((i & ~0x7F) != 0) { @@ -111,9 +165,12 @@ void writeVInt(int termID, int i) i >>>= 7; } writeByte(termID, (byte) i); - sizes[termID]++; } + /** + * Writes a single byte to the posting list for a given term. + * If the current slice is full, it automatically allocates a new slice. 
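+ * Slices are chained within the ByteBlockPool, so a term's postings need not be contiguous in memory.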
+ */ private void writeByte(int termID, byte b) { int upto = postingUptos[termID]; diff --git a/src/java/org/apache/cassandra/index/sai/disk/RAMStringIndexer.java b/src/java/org/apache/cassandra/index/sai/disk/RAMStringIndexer.java index 8b6f07d43597..40c7d8fa37a8 100644 --- a/src/java/org/apache/cassandra/index/sai/disk/RAMStringIndexer.java +++ b/src/java/org/apache/cassandra/index/sai/disk/RAMStringIndexer.java @@ -18,10 +18,12 @@ package org.apache.cassandra.index.sai.disk; import java.nio.ByteBuffer; +import java.util.List; import java.util.NoSuchElementException; import com.google.common.annotations.VisibleForTesting; +import org.agrona.collections.Int2IntHashMap; import org.apache.cassandra.utils.bytecomparable.ByteComparable; import org.apache.lucene.util.ArrayUtil; import org.apache.lucene.util.ByteBlockPool; @@ -43,11 +45,14 @@ public class RAMStringIndexer private final Counter termsBytesUsed; private final Counter slicesBytesUsed; - private int rowCount = 0; private int[] lastSegmentRowID = new int[RAMPostingSlices.DEFAULT_TERM_DICT_SIZE]; - public RAMStringIndexer() + private final boolean writeFrequencies; + private final Int2IntHashMap docLengths = new Int2IntHashMap(Integer.MIN_VALUE); + + public RAMStringIndexer(boolean writeFrequencies) { + this.writeFrequencies = writeFrequencies; termsBytesUsed = Counter.newCounter(); slicesBytesUsed = Counter.newCounter(); @@ -55,7 +60,7 @@ public RAMStringIndexer() termsHash = new BytesRefHash(termsPool); - slices = new RAMPostingSlices(slicesBytesUsed); + slices = new RAMPostingSlices(slicesBytesUsed, writeFrequencies); } public long estimatedBytesUsed() @@ -75,7 +80,12 @@ public boolean requiresFlush() public boolean isEmpty() { - return rowCount == 0; + return docLengths.isEmpty(); + } + + public Int2IntHashMap getDocLengths() + { + return docLengths; } /** @@ -140,36 +150,58 @@ private ByteComparable asByteComparable(byte[] bytes, int offset, int length) }; } - public long add(BytesRef term, int segmentRowId) + /** + * @return bytes allocated. may be zero if the (term, row) pair is a duplicate + */ + public long addAll(List terms, int segmentRowId) { long startBytes = estimatedBytesUsed(); - int termID = termsHash.add(term); - - if (termID >= 0) - { - // firs time seeing this term, create the term's first slice ! - slices.createNewSlice(termID); - } - else - { - termID = (-termID) - 1; - } + Int2IntHashMap frequencies = new Int2IntHashMap(Integer.MIN_VALUE); + Int2IntHashMap deltas = new Int2IntHashMap(Integer.MIN_VALUE); - if (termID >= lastSegmentRowID.length - 1) + for (BytesRef term : terms) { - lastSegmentRowID = ArrayUtil.grow(lastSegmentRowID, termID + 1); - } + int termID = termsHash.add(term); + boolean firstOccurrence = termID >= 0; - int delta = segmentRowId - lastSegmentRowID[termID]; - - lastSegmentRowID[termID] = segmentRowId; + if (firstOccurrence) + { + // first time seeing this term in any row, create the term's first slice ! 
+ slices.createNewSlice(termID); + // grow the termID -> last segment array if necessary + if (termID >= lastSegmentRowID.length - 1) + lastSegmentRowID = ArrayUtil.grow(lastSegmentRowID, termID + 1); + if (writeFrequencies) + frequencies.put(termID, 1); + } + else + { + termID = (-termID) - 1; + // compaction should call this method only with increasing segmentRowIds + assert segmentRowId >= lastSegmentRowID[termID]; + // increment frequency + if (writeFrequencies) + frequencies.put(termID, frequencies.getOrDefault(termID, 0) + 1); + // Skip computing a delta if we've already seen this term in this row + if (segmentRowId == lastSegmentRowID[termID]) + continue; + } - slices.writeVInt(termID, delta); + // Compute the delta from the last time this term was seen, to this row + int delta = segmentRowId - lastSegmentRowID[termID]; + // sanity check that we're advancing the row id, i.e. no duplicate entries. + assert firstOccurrence || delta > 0; + deltas.put(termID, delta); + lastSegmentRowID[termID] = segmentRowId; + } - long allocatedBytes = estimatedBytesUsed() - startBytes; + // add the postings now that we know the frequencies + deltas.forEachInt((termID, delta) -> { + slices.writePosting(termID, delta, frequencies.get(termID)); + }); - rowCount++; + docLengths.put(segmentRowId, terms.size()); - return allocatedBytes; + return estimatedBytesUsed() - startBytes; } } diff --git a/src/java/org/apache/cassandra/index/sai/disk/format/IndexComponentType.java b/src/java/org/apache/cassandra/index/sai/disk/format/IndexComponentType.java index 165f484eae66..bf38d6b8fbd8 100644 --- a/src/java/org/apache/cassandra/index/sai/disk/format/IndexComponentType.java +++ b/src/java/org/apache/cassandra/index/sai/disk/format/IndexComponentType.java @@ -111,7 +111,12 @@ public enum IndexComponentType * * V1 V2 */ - GROUP_COMPLETION_MARKER("GroupComplete"); + GROUP_COMPLETION_MARKER("GroupComplete"), + + /** + * Stores document length information for BM25 scoring + */ + DOC_LENGTHS("DocLengths"); public final String representation; diff --git a/src/java/org/apache/cassandra/index/sai/disk/format/Version.java b/src/java/org/apache/cassandra/index/sai/disk/format/Version.java index 6411072486b8..77757eec4851 100644 --- a/src/java/org/apache/cassandra/index/sai/disk/format/Version.java +++ b/src/java/org/apache/cassandra/index/sai/disk/format/Version.java @@ -34,6 +34,7 @@ import org.apache.cassandra.index.sai.disk.v4.V4OnDiskFormat; import org.apache.cassandra.index.sai.disk.v5.V5OnDiskFormat; import org.apache.cassandra.index.sai.disk.v6.V6OnDiskFormat; +import org.apache.cassandra.index.sai.disk.v7.V7OnDiskFormat; import org.apache.cassandra.index.sai.utils.TypeUtil; import org.apache.cassandra.utils.bytecomparable.ByteComparable; @@ -43,7 +44,7 @@ * Format version of indexing component, denoted as [major][minor]. Same forward-compatibility rules apply as to * {@link org.apache.cassandra.io.sstable.format.Version}. */ -public class Version +public class Version implements Comparable { // 6.8 formats public static final Version AA = new Version("aa", V1OnDiskFormat.instance, Version::aaFileNameFormat); @@ -53,15 +54,17 @@ public class Version public static final Version CA = new Version("ca", V3OnDiskFormat.instance, (c, i, g) -> stargazerFileNameFormat(c, i, g, "ca")); // NOTE: use DB to prevent collisions with upstream file formats // Encode trie entries using their AbstractType to ensure trie entries are sorted for range queries and are prefix free. 
- public static final Version DB = new Version("db", V4OnDiskFormat.instance, (c, i, g) -> stargazerFileNameFormat(c, i, g,"db")); + public static final Version DB = new Version("db", V4OnDiskFormat.instance, (c, i, g) -> stargazerFileNameFormat(c, i, g, "db")); // revamps vector postings lists to cause fewer reads from disk public static final Version DC = new Version("dc", V5OnDiskFormat.instance, (c, i, g) -> stargazerFileNameFormat(c, i, g, "dc")); // histograms in index metadata public static final Version EB = new Version("eb", V6OnDiskFormat.instance, (c, i, g) -> stargazerFileNameFormat(c, i, g, "eb")); + // term frequencies index component + public static final Version EC = new Version("ec", V7OnDiskFormat.instance, (c, i, g) -> stargazerFileNameFormat(c, i, g, "ec")); // These are in reverse-chronological order so that the latest version is first. Version matching tests // are more likely to match the latest version so we want to test that one first. - public static final List ALL = Lists.newArrayList(EB, DC, DB, CA, BA, AA); + public static final List ALL = Lists.newArrayList(EC, EB, DC, DB, CA, BA, AA); public static final Version EARLIEST = AA; public static final Version VECTOR_EARLIEST = BA; @@ -87,7 +90,8 @@ public static Version parse(String input) { checkArgument(input != null); checkArgument(input.length() == 2); - for (var v : ALL) { + for (var v : ALL) + { if (input.equals(v.version)) return v; } @@ -110,7 +114,7 @@ public boolean equals(Object o) { if (this == o) return true; if (o == null || getClass() != o.getClass()) return false; - Version other = (Version)o; + Version other = (Version) o; return Objects.equal(version, other.version); } @@ -143,6 +147,11 @@ public boolean useImmutableComponentFiles() return CassandraRelevantProperties.IMMUTABLE_SAI_COMPONENTS.getBoolean() && onOrAfter(Version.CA); } + @Override + public int compareTo(Version other) + { + return this.version.compareTo(other.version); + } public interface FileNameFormatter { @@ -152,7 +161,7 @@ public interface FileNameFormatter */ default String format(IndexComponentType indexComponentType, IndexContext indexContext, int generation) { - return format(indexComponentType, indexContext == null ? null : indexContext.getIndexName(), generation); + return format(indexComponentType, indexContext == null ? null : indexContext.getIndexName(), generation); } /** @@ -160,8 +169,8 @@ default String format(IndexComponentType indexComponentType, IndexContext indexC * filename is returned (so the suffix of the full filename), not a full path. * * @param indexComponentType the type of the index component. - * @param indexName the name of the index, or {@code null} for a per-sstable component. - * @param generation the generation of the build of the component. + * @param indexName the name of the index, or {@code null} for a per-sstable component. + * @param generation the generation of the build of the component. */ String format(IndexComponentType indexComponentType, @Nullable String indexName, int generation); } @@ -191,7 +200,6 @@ else if (componentStr.startsWith("SAI" + SAI_SEPARATOR)) return tryParseStargazerFileName(componentStr); else return Optional.empty(); - } public static class ParsedFileName @@ -222,7 +230,7 @@ private static String aaFileNameFormat(IndexComponentType indexComponentType, @N StringBuilder stringBuilder = new StringBuilder(); stringBuilder.append(indexName == null ? 
String.format(VERSION_AA_PER_SSTABLE_FORMAT, indexComponentType.representation) - : String.format(VERSION_AA_PER_INDEX_FORMAT, indexName, indexComponentType.representation)); + : String.format(VERSION_AA_PER_INDEX_FORMAT, indexName, indexComponentType.representation)); return stringBuilder.toString(); } diff --git a/src/java/org/apache/cassandra/index/sai/disk/v1/DocLengthsReader.java b/src/java/org/apache/cassandra/index/sai/disk/v1/DocLengthsReader.java new file mode 100644 index 000000000000..45769f9337d8 --- /dev/null +++ b/src/java/org/apache/cassandra/index/sai/disk/v1/DocLengthsReader.java @@ -0,0 +1,59 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.cassandra.index.sai.disk.v1; + +import java.io.Closeable; +import java.io.IOException; + +import org.apache.cassandra.index.sai.disk.io.IndexInputReader; +import org.apache.cassandra.index.sai.utils.IndexFileUtils; +import org.apache.cassandra.index.sai.utils.SAICodecUtils; +import org.apache.cassandra.io.util.FileHandle; +import org.apache.cassandra.io.util.FileUtils; +import org.apache.lucene.codecs.CodecUtil; + +public class DocLengthsReader implements Closeable +{ + private final FileHandle fileHandle; + private final IndexInputReader input; + private final SegmentMetadata.ComponentMetadata componentMetadata; + + public DocLengthsReader(FileHandle fileHandle, SegmentMetadata.ComponentMetadata componentMetadata) + { + this.fileHandle = fileHandle; + this.input = IndexFileUtils.instance.openInput(fileHandle); + this.componentMetadata = componentMetadata; + } + + public int get(int rowID) throws IOException + { + // Account for header size in offset calculation + long position = componentMetadata.offset + (long) rowID * Integer.BYTES; + if (position >= componentMetadata.offset + componentMetadata.length) + return 0; + input.seek(position); + return input.readInt(); + } + + @Override + public void close() throws IOException + { + FileUtils.close(fileHandle, input); + } +} + diff --git a/src/java/org/apache/cassandra/index/sai/disk/v1/IndexSearcher.java b/src/java/org/apache/cassandra/index/sai/disk/v1/IndexSearcher.java index 52988a685cd5..16d14bdd8ac3 100644 --- a/src/java/org/apache/cassandra/index/sai/disk/v1/IndexSearcher.java +++ b/src/java/org/apache/cassandra/index/sai/disk/v1/IndexSearcher.java @@ -68,9 +68,7 @@ public abstract class IndexSearcher implements Closeable, SegmentOrdering protected final SegmentMetadata metadata; protected final IndexContext indexContext; - private static final SSTableReadsListener NOOP_LISTENER = new SSTableReadsListener() {}; - - private final ColumnFilter columnFilter; + protected final ColumnFilter columnFilter; protected IndexSearcher(PrimaryKeyMap.Factory primaryKeyMapFactory, PerIndexFiles perIndexFiles, @@ -90,30 
+88,36 @@ protected IndexSearcher(PrimaryKeyMap.Factory primaryKeyMapFactory, public abstract long indexFileCacheSize(); /** - * Search on-disk index synchronously + * Search on-disk index synchronously. Used for WHERE clause predicates, including BOUNDED_ANN. * * @param expression to filter on disk index * @param keyRange key range specific in read command, used by ANN index * @param queryContext to track per sstable cache and per query metrics * @param defer create the iterator in a deferred state - * @param limit the num of rows to returned, used by ANN index * @return {@link KeyRangeIterator} that matches given expression */ - public abstract KeyRangeIterator search(Expression expression, AbstractBounds keyRange, QueryContext queryContext, boolean defer, int limit) throws IOException; + public abstract KeyRangeIterator search(Expression expression, AbstractBounds keyRange, QueryContext queryContext, boolean defer) throws IOException; /** - * Order the on-disk index synchronously and produce an iterator in score order + * Order the rows by the given Orderer. Used for ORDER BY clause when + * (1) the WHERE predicate is either a partition restriction or a range restriction on the index, + * (2) there is no WHERE predicate, or + * (3) the planner determines it is better to post-filter the ordered results by the predicate. * * @param orderer the object containing the ordering logic * @param slice optional predicate to get a slice of the index * @param keyRange key range specific in read command, used by ANN index * @param queryContext to track per sstable cache and per query metrics - * @param limit the num of rows to returned, used by ANN index + * @param limit the initial number of rows to return, used by ANN index. More rows may be requested if filtering throws away more than expected! * @return an iterator of {@link PrimaryKeyWithSortKey} in score order */ public abstract CloseableIterator orderBy(Orderer orderer, Expression slice, AbstractBounds keyRange, QueryContext queryContext, int limit) throws IOException; - + /** + * Order the rows by the given Orderer. Used for ORDER BY clause when the WHERE predicates + * have been applied first, yielding a list of primary keys. Again, `limit` is a planner hint for ANN to determine + * the initial number of results returned, not a maximum. + */ @Override public CloseableIterator orderResultsBy(SSTableReader reader, QueryContext context, List keys, Orderer orderer, int limit) throws IOException { @@ -124,7 +128,7 @@ public CloseableIterator orderResultsBy(SSTableReader rea { var slices = Slices.with(indexContext.comparator(), Slice.make(key.clustering())); // TODO if we end up needing to read the row still, is it better to store offset and use reader.unfilteredAt?
- try (var iter = reader.iterator(key.partitionKey(), slices, columnFilter, false, NOOP_LISTENER)) + try (var iter = reader.iterator(key.partitionKey(), slices, columnFilter, false, SSTableReadsListener.NOOP_LISTENER)) { if (iter.hasNext()) { diff --git a/src/java/org/apache/cassandra/index/sai/disk/v1/InvertedIndexSearcher.java b/src/java/org/apache/cassandra/index/sai/disk/v1/InvertedIndexSearcher.java index 15aaa349e1c8..08a63818c709 100644 --- a/src/java/org/apache/cassandra/index/sai/disk/v1/InvertedIndexSearcher.java +++ b/src/java/org/apache/cassandra/index/sai/disk/v1/InvertedIndexSearcher.java @@ -19,15 +19,27 @@ package org.apache.cassandra.index.sai.disk.v1; import java.io.IOException; +import java.io.UncheckedIOException; import java.lang.invoke.MethodHandles; +import java.nio.ByteBuffer; +import java.util.HashMap; +import java.util.List; import java.util.Map; +import java.util.function.Function; +import java.util.stream.Collectors; import com.google.common.base.MoreObjects; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import org.apache.cassandra.cql3.Operator; import org.apache.cassandra.db.PartitionPosition; +import org.apache.cassandra.db.Slice; +import org.apache.cassandra.db.Slices; +import org.apache.cassandra.db.rows.Cell; +import org.apache.cassandra.db.rows.Row; import org.apache.cassandra.dht.AbstractBounds; +import org.apache.cassandra.exceptions.InvalidRequestException; import org.apache.cassandra.index.sai.IndexContext; import org.apache.cassandra.index.sai.QueryContext; import org.apache.cassandra.index.sai.SSTableContext; @@ -35,24 +47,31 @@ import org.apache.cassandra.index.sai.disk.TermsIterator; import org.apache.cassandra.index.sai.disk.format.IndexComponentType; import org.apache.cassandra.index.sai.disk.format.Version; +import org.apache.cassandra.index.sai.disk.v1.postings.IntersectingPostingList; import org.apache.cassandra.index.sai.iterators.KeyRangeIterator; import org.apache.cassandra.index.sai.metrics.MulticastQueryEventListeners; import org.apache.cassandra.index.sai.metrics.QueryEventListener; import org.apache.cassandra.index.sai.plan.Expression; import org.apache.cassandra.index.sai.plan.Orderer; +import org.apache.cassandra.index.sai.utils.BM25Utils; +import org.apache.cassandra.index.sai.utils.BM25Utils.DocTF; +import org.apache.cassandra.index.sai.utils.PrimaryKey; import org.apache.cassandra.index.sai.utils.PrimaryKeyWithSortKey; import org.apache.cassandra.index.sai.utils.RowIdWithByteComparable; import org.apache.cassandra.index.sai.utils.SAICodecUtils; -import org.apache.cassandra.index.sai.utils.SegmentOrdering; +import org.apache.cassandra.io.sstable.format.SSTableReader; +import org.apache.cassandra.io.sstable.format.SSTableReadsListener; import org.apache.cassandra.io.util.FileUtils; import org.apache.cassandra.utils.AbstractIterator; import org.apache.cassandra.utils.CloseableIterator; import org.apache.cassandra.utils.bytecomparable.ByteComparable; +import static org.apache.cassandra.index.sai.disk.PostingList.END_OF_STREAM; + /** * Executes {@link Expression}s against the trie-based terms dictionary for an individual index segment. 
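+ * With EC-version indexes, which store term frequencies and document lengths, it can also order results by BM25 score.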
*/ -public class InvertedIndexSearcher extends IndexSearcher implements SegmentOrdering +public class InvertedIndexSearcher extends IndexSearcher { private static final Logger logger = LoggerFactory.getLogger(MethodHandles.lookup().lookupClass()); @@ -60,6 +79,9 @@ public class InvertedIndexSearcher extends IndexSearcher implements SegmentOrder private final QueryEventListener.TrieIndexEventListener perColumnEventListener; private final Version version; private final boolean filterRangeResults; + private final SSTableReader sstable; + private final DocLengthsReader docLengthsReader; + private final long segmentRowIdOffset; protected InvertedIndexSearcher(SSTableContext sstableContext, PerIndexFiles perIndexFiles, @@ -69,6 +91,7 @@ protected InvertedIndexSearcher(SSTableContext sstableContext, boolean filterRangeResults) throws IOException { super(sstableContext.primaryKeyMapFactory(), perIndexFiles, segmentMetadata, indexContext); + this.sstable = sstableContext.sstable; long root = metadata.getIndexRoot(IndexComponentType.TERMS_DATA); assert root >= 0; @@ -76,6 +99,9 @@ protected InvertedIndexSearcher(SSTableContext sstableContext, this.version = version; this.filterRangeResults = filterRangeResults; perColumnEventListener = (QueryEventListener.TrieIndexEventListener)indexContext.getColumnQueryMetrics(); + var docLengthsMeta = segmentMetadata.componentMetadatas.getOptional(IndexComponentType.DOC_LENGTHS); + this.segmentRowIdOffset = segmentMetadata.segmentRowIdOffset; + this.docLengthsReader = docLengthsMeta == null ? null : new DocLengthsReader(indexFiles.docLengths(), docLengthsMeta); Map map = metadata.componentMetadatas.get(IndexComponentType.TERMS_DATA).attributes; String footerPointerString = map.get(SAICodecUtils.FOOTER_POINTER); @@ -100,7 +126,7 @@ public long indexFileCacheSize() } @SuppressWarnings("resource") - public KeyRangeIterator search(Expression exp, AbstractBounds keyRange, QueryContext context, boolean defer, int limit) throws IOException + public KeyRangeIterator search(Expression exp, AbstractBounds keyRange, QueryContext context, boolean defer) throws IOException { PostingList postingList = searchPosting(exp, context); return toPrimaryKeyIterator(postingList, context); @@ -129,11 +155,117 @@ else if (exp.getOp() == Expression.Op.RANGE) throw new IllegalArgumentException(indexContext.logMessage("Unsupported expression: " + exp)); } + private Cell readColumn(SSTableReader sstable, PrimaryKey primaryKey) + { + var dk = primaryKey.partitionKey(); + var slices = Slices.with(indexContext.comparator(), Slice.make(primaryKey.clustering())); + try (var rowIterator = sstable.iterator(dk, slices, columnFilter, false, SSTableReadsListener.NOOP_LISTENER)) + { + var unfiltered = rowIterator.next(); + assert unfiltered.isRow() : unfiltered; + Row row = (Row) unfiltered; + return row.getCell(indexContext.getDefinition()); + } + } + @Override public CloseableIterator orderBy(Orderer orderer, Expression slice, AbstractBounds keyRange, QueryContext queryContext, int limit) throws IOException { - var iter = new RowIdWithTermsIterator(reader.allTerms(orderer.isAscending())); - return toMetaSortedIterator(iter, queryContext); + if (!orderer.isBM25()) + { + var iter = new RowIdWithTermsIterator(reader.allTerms(orderer.isAscending())); + return toMetaSortedIterator(iter, queryContext); + } + if (docLengthsReader == null) + throw new InvalidRequestException(indexContext.getIndexName() + " does not support BM25 scoring until it is rebuilt"); + + // find documents that match each term + var 
queryTerms = orderer.getQueryTerms(); + var postingLists = queryTerms.stream() + .collect(Collectors.toMap(Function.identity(), term -> + { + var encodedTerm = version.onDiskFormat().encodeForTrie(term, indexContext.getValidator()); + var listener = MulticastQueryEventListeners.of(queryContext, perColumnEventListener); + var postings = reader.exactMatch(encodedTerm, listener, queryContext); + return postings == null ? PostingList.EMPTY : postings; + })); + // extract the match count for each + var documentFrequencies = postingLists.entrySet().stream().collect(Collectors.toMap(Map.Entry::getKey, e -> (long) e.getValue().size())); + + var pkm = primaryKeyMapFactory.newPerSSTablePrimaryKeyMap(); + var merged = IntersectingPostingList.intersect(postingLists); + + // Wrap the iterator with resource management + var it = new AbstractIterator() { // Anonymous class extends AbstractIterator + private boolean closed; + + @Override + protected DocTF computeNext() + { + try + { + int rowId = merged.nextPosting(); + if (rowId == PostingList.END_OF_STREAM) + return endOfData(); + int docLength = docLengthsReader.get(rowId); // segment-local rowid + var pk = pkm.primaryKeyFromRowId(segmentRowIdOffset + rowId); // sstable-global rowid + return new DocTF(pk, docLength, merged.frequencies()); + } + catch (IOException e) + { + throw new UncheckedIOException(e); + } + } + + @Override + public void close() + { + if (closed) return; + closed = true; + FileUtils.closeQuietly(pkm, merged); + } + }; + return bm25Internal(it, queryTerms, documentFrequencies); + } + + private CloseableIterator bm25Internal(CloseableIterator keyIterator, + List queryTerms, + Map documentFrequencies) + { + var totalRows = sstable.getTotalRows(); + // since doc frequencies can be an estimate from the index histogram, which does not have bounded error, + // cap frequencies to total rows so that the IDF term doesn't turn negative + var cappedFrequencies = documentFrequencies.entrySet().stream() + .collect(Collectors.toMap(Map.Entry::getKey, e -> Math.min(e.getValue(), totalRows))); + var docStats = new BM25Utils.DocStats(cappedFrequencies, totalRows); + return BM25Utils.computeScores(keyIterator, + queryTerms, + docStats, + indexContext, + sstable.descriptor.id); + } + + @Override + public CloseableIterator orderResultsBy(SSTableReader reader, QueryContext queryContext, List keys, Orderer orderer, int limit) throws IOException + { + if (!orderer.isBM25()) + return super.orderResultsBy(reader, queryContext, keys, orderer, limit); + if (docLengthsReader == null) + throw new InvalidRequestException(indexContext.getIndexName() + " does not support BM25 scoring until it is rebuilt"); + + var queryTerms = orderer.getQueryTerms(); + // compute documentFrequencies from either histogram or an index search + var documentFrequencies = new HashMap(); + // any index new enough to support BM25 should also support histograms + assert metadata.version.onDiskFormat().indexFeatureSet().hasTermsHistogram(); + for (ByteBuffer term : queryTerms) + { + long matches = metadata.estimateNumRowsMatching(new Expression(indexContext).add(Operator.ANALYZER_MATCHES, term)); + documentFrequencies.put(term, matches); + } + var analyzer = indexContext.getAnalyzerFactory().create(); + var it = keys.stream().map(pk -> DocTF.createFromDocument(pk, readColumn(sstable, pk), analyzer, queryTerms)).iterator(); + return bm25Internal(CloseableIterator.wrap(it), queryTerms, documentFrequencies); } @Override @@ -147,7 +279,7 @@ public String toString() @Override public void 
close() { - reader.close(); + FileUtils.closeQuietly(reader, docLengthsReader); } /** @@ -172,7 +304,7 @@ protected RowIdWithByteComparable computeNext() while (true) { long nextPosting = currentPostingList.nextPosting(); - if (nextPosting != PostingList.END_OF_STREAM) + if (nextPosting != END_OF_STREAM) return new RowIdWithByteComparable(Math.toIntExact(nextPosting), currentTerm); if (!source.hasNext()) diff --git a/src/java/org/apache/cassandra/index/sai/disk/v1/KDTreeIndexSearcher.java b/src/java/org/apache/cassandra/index/sai/disk/v1/KDTreeIndexSearcher.java index 8a2aa354bd47..5911f2351014 100644 --- a/src/java/org/apache/cassandra/index/sai/disk/v1/KDTreeIndexSearcher.java +++ b/src/java/org/apache/cassandra/index/sai/disk/v1/KDTreeIndexSearcher.java @@ -84,7 +84,7 @@ public long indexFileCacheSize() } @Override - public KeyRangeIterator search(Expression exp, AbstractBounds keyRange, QueryContext context, boolean defer, int limit) throws IOException + public KeyRangeIterator search(Expression exp, AbstractBounds keyRange, QueryContext context, boolean defer) throws IOException { PostingList postingList = searchPosting(exp, context); return toPrimaryKeyIterator(postingList, context); diff --git a/src/java/org/apache/cassandra/index/sai/disk/v1/MemtableIndexWriter.java b/src/java/org/apache/cassandra/index/sai/disk/v1/MemtableIndexWriter.java index 102699ca0288..8683409896ad 100644 --- a/src/java/org/apache/cassandra/index/sai/disk/v1/MemtableIndexWriter.java +++ b/src/java/org/apache/cassandra/index/sai/disk/v1/MemtableIndexWriter.java @@ -19,15 +19,15 @@ import java.io.IOException; import java.lang.invoke.MethodHandles; +import java.util.Arrays; import java.util.Collections; -import java.util.Iterator; import java.util.concurrent.TimeUnit; import com.google.common.base.Stopwatch; import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import com.carrotsearch.hppc.IntArrayList; +import org.agrona.collections.Int2IntHashMap; import org.apache.cassandra.db.DecoratedKey; import org.apache.cassandra.db.marshal.AbstractType; import org.apache.cassandra.db.rows.Row; @@ -36,17 +36,17 @@ import org.apache.cassandra.index.sai.disk.PerIndexWriter; import org.apache.cassandra.index.sai.disk.format.IndexComponents; import org.apache.cassandra.index.sai.disk.format.Version; -import org.apache.cassandra.index.sai.disk.vector.VectorMemtableIndex; import org.apache.cassandra.index.sai.disk.v1.kdtree.ImmutableOneDimPointValues; import org.apache.cassandra.index.sai.disk.v1.kdtree.NumericIndexWriter; import org.apache.cassandra.index.sai.disk.v1.trie.InvertedIndexWriter; +import org.apache.cassandra.index.sai.disk.vector.VectorMemtableIndex; import org.apache.cassandra.index.sai.memory.MemtableIndex; import org.apache.cassandra.index.sai.memory.RowMapping; +import org.apache.cassandra.index.sai.memory.TrieMemoryIndex; +import org.apache.cassandra.index.sai.memory.TrieMemtableIndex; import org.apache.cassandra.index.sai.utils.PrimaryKey; import org.apache.cassandra.index.sai.utils.TypeUtil; import org.apache.cassandra.utils.ByteBufferUtil; -import org.apache.cassandra.utils.Pair; -import org.apache.cassandra.utils.bytecomparable.ByteComparable; /** * Column index writer that flushes indexed data directly from the corresponding Memtable index, without buffering index @@ -126,7 +126,7 @@ public void complete(Stopwatch stopwatch) throws IOException } else { - final Iterator> iterator = rowMapping.merge(memtableIndex); + var iterator = rowMapping.merge(memtableIndex); try (MemtableTermsIterator terms = 
new MemtableTermsIterator(memtableIndex.getMinTerm(), memtableIndex.getMaxTerm(), iterator)) { long cellCount = flush(minKey, maxKey, indexContext().getValidator(), terms, rowMapping.maxSegmentRowId); @@ -151,9 +151,21 @@ private long flush(DecoratedKey minKey, DecoratedKey maxKey, AbstractType ter SegmentMetadata.ComponentMetadataMap indexMetas; if (TypeUtil.isLiteral(termComparator)) { - try (InvertedIndexWriter writer = new InvertedIndexWriter(perIndexComponents)) + try (InvertedIndexWriter writer = new InvertedIndexWriter(perIndexComponents, writeFrequencies())) { - indexMetas = writer.writeAll(metadataBuilder.intercept(terms)); + // Convert PrimaryKey->length map to rowId->length using RowMapping + var docLengths = new Int2IntHashMap(Integer.MIN_VALUE); + Arrays.stream(((TrieMemtableIndex) memtableIndex).getRangeIndexes()) + .map(TrieMemoryIndex.class::cast) + .forEach(trieMemoryIndex -> + trieMemoryIndex.getDocLengths().forEach((pk, length) -> { + int rowId = rowMapping.get(pk); + if (rowId >= 0) + docLengths.put(rowId, (int) length); + }) + ); + + indexMetas = writer.writeAll(metadataBuilder.intercept(terms), docLengths); numRows = writer.getPostingsCount(); } } @@ -194,6 +206,11 @@ private long flush(DecoratedKey minKey, DecoratedKey maxKey, AbstractType ter return numRows; } + private boolean writeFrequencies() + { + return indexContext().isAnalyzed() && Version.latest().onOrAfter(Version.EC); + } + private void flushVectorIndex(DecoratedKey minKey, DecoratedKey maxKey, long startTime, Stopwatch stopwatch) throws IOException { var vectorIndex = (VectorMemtableIndex) memtableIndex; diff --git a/src/java/org/apache/cassandra/index/sai/disk/v1/PartitionAwarePrimaryKeyFactory.java b/src/java/org/apache/cassandra/index/sai/disk/v1/PartitionAwarePrimaryKeyFactory.java index 3b49e22a84fd..c1b47916229b 100644 --- a/src/java/org/apache/cassandra/index/sai/disk/v1/PartitionAwarePrimaryKeyFactory.java +++ b/src/java/org/apache/cassandra/index/sai/disk/v1/PartitionAwarePrimaryKeyFactory.java @@ -21,6 +21,7 @@ import java.util.Objects; import java.util.function.Supplier; +import io.github.jbellis.jvector.util.RamUsageEstimator; import org.apache.cassandra.db.Clustering; import org.apache.cassandra.db.DecoratedKey; import org.apache.cassandra.dht.Token; @@ -136,6 +137,19 @@ private ByteSource asComparableBytes(int terminator, ByteComparable.Version vers return ByteSource.withTerminator(terminator, tokenComparable, keyComparable, null); } + @Override + public long ramBytesUsed() + { + // Compute shallow size: object header + 4 references (3 declared + 1 implicit outer reference) + long shallowSize = RamUsageEstimator.NUM_BYTES_OBJECT_HEADER + 4L * RamUsageEstimator.NUM_BYTES_OBJECT_REF; + long preHashedDecoratedKeySize = partitionKey == null + ? 
0 + : RamUsageEstimator.NUM_BYTES_OBJECT_HEADER + + 2L * RamUsageEstimator.NUM_BYTES_OBJECT_REF // token and key references + + 2L * Long.BYTES; + return shallowSize + token.getHeapSize() + preHashedDecoratedKeySize; + } + @Override public int compareTo(PrimaryKey o) { diff --git a/src/java/org/apache/cassandra/index/sai/disk/v1/PerIndexFiles.java b/src/java/org/apache/cassandra/index/sai/disk/v1/PerIndexFiles.java index e4a1ff6bca9e..b82c23cc589d 100644 --- a/src/java/org/apache/cassandra/index/sai/disk/v1/PerIndexFiles.java +++ b/src/java/org/apache/cassandra/index/sai/disk/v1/PerIndexFiles.java @@ -28,6 +28,7 @@ import org.apache.cassandra.index.sai.disk.format.IndexComponents; import org.apache.cassandra.index.sai.disk.format.IndexComponentType; +import org.apache.cassandra.index.sai.disk.format.Version; import org.apache.cassandra.io.util.FileHandle; import org.apache.cassandra.io.util.FileUtils; @@ -104,6 +105,12 @@ public FileHandle pq() return getFile(IndexComponentType.PQ).sharedCopy(); } + /** It is the caller's responsibility to close the returned file handle. */ + public FileHandle docLengths() + { + return getFile(IndexComponentType.DOC_LENGTHS).sharedCopy(); + } + public FileHandle getFile(IndexComponentType indexComponentType) { FileHandle file = files.get(indexComponentType); diff --git a/src/java/org/apache/cassandra/index/sai/disk/v1/SSTableIndexWriter.java b/src/java/org/apache/cassandra/index/sai/disk/v1/SSTableIndexWriter.java index bb85ad583cba..a4812fb120ac 100644 --- a/src/java/org/apache/cassandra/index/sai/disk/v1/SSTableIndexWriter.java +++ b/src/java/org/apache/cassandra/index/sai/disk/v1/SSTableIndexWriter.java @@ -241,7 +241,7 @@ else if (shouldFlush(sstableRowId)) if (term.remaining() == 0 && TypeUtil.skipsEmptyValue(indexContext.getValidator())) return; - long allocated = currentBuilder.addAll(term, type, key, sstableRowId); + long allocated = currentBuilder.analyzeAndAdd(term, type, key, sstableRowId); limiter.increment(allocated); } diff --git a/src/java/org/apache/cassandra/index/sai/disk/v1/Segment.java b/src/java/org/apache/cassandra/index/sai/disk/v1/Segment.java index 87fe0a998e5f..3d52c944d726 100644 --- a/src/java/org/apache/cassandra/index/sai/disk/v1/Segment.java +++ b/src/java/org/apache/cassandra/index/sai/disk/v1/Segment.java @@ -144,7 +144,7 @@ public long indexFileCacheSize() */ public KeyRangeIterator search(Expression expression, AbstractBounds keyRange, QueryContext context, boolean defer, int limit) throws IOException { - return index.search(expression, keyRange, context, defer, limit); + return index.search(expression, keyRange, context, defer); } /** diff --git a/src/java/org/apache/cassandra/index/sai/disk/v1/SegmentBuilder.java b/src/java/org/apache/cassandra/index/sai/disk/v1/SegmentBuilder.java index e1cbfc318180..095b43d3926f 100644 --- a/src/java/org/apache/cassandra/index/sai/disk/v1/SegmentBuilder.java +++ b/src/java/org/apache/cassandra/index/sai/disk/v1/SegmentBuilder.java @@ -29,6 +29,7 @@ import java.util.concurrent.atomic.AtomicLong; import java.util.concurrent.atomic.AtomicReference; import java.util.concurrent.atomic.LongAdder; +import java.util.stream.Collectors; import javax.annotation.concurrent.NotThreadSafe; import com.google.common.annotations.VisibleForTesting; @@ -42,6 +43,7 @@ import org.apache.cassandra.index.sai.IndexContext; import org.apache.cassandra.index.sai.analyzer.AbstractAnalyzer; import org.apache.cassandra.index.sai.analyzer.ByteLimitedMaterializer; +import 
org.apache.cassandra.index.sai.analyzer.NoOpAnalyzer; import org.apache.cassandra.index.sai.disk.PostingList; import org.apache.cassandra.index.sai.disk.RAMStringIndexer; import org.apache.cassandra.index.sai.disk.TermsIterator; @@ -154,9 +156,11 @@ public boolean isEmpty() return kdTreeRamBuffer.numRows() == 0; } - protected long addInternal(ByteBuffer term, int segmentRowId) + @Override + protected long addInternal(List terms, int segmentRowId) { - TypeUtil.toComparableBytes(term, termComparator, buffer); + assert terms.size() == 1; + TypeUtil.toComparableBytes(terms.get(0), termComparator, buffer); return kdTreeRamBuffer.addPackedValue(segmentRowId, new BytesRef(buffer)); } @@ -192,10 +196,14 @@ public static class RAMStringSegmentBuilder extends SegmentBuilder { super(components, rowIdOffset, limiter); this.byteComparableVersion = components.byteComparableVersionFor(IndexComponentType.TERMS_DATA); - ramIndexer = new RAMStringIndexer(); + ramIndexer = new RAMStringIndexer(writeFrequencies()); totalBytesAllocated = ramIndexer.estimatedBytesUsed(); totalBytesAllocatedConcurrent.add(totalBytesAllocated); + } + private boolean writeFrequencies() + { + return !(analyzer instanceof NoOpAnalyzer) && Version.latest().onOrAfter(Version.EC); } public boolean isEmpty() @@ -203,21 +211,26 @@ public boolean isEmpty() return ramIndexer.isEmpty(); } - protected long addInternal(ByteBuffer term, int segmentRowId) + @Override + protected long addInternal(List terms, int segmentRowId) { - var encodedTerm = components.onDiskFormat().encodeForTrie(term, termComparator); - var bytes = ByteSourceInverse.readBytes(encodedTerm.asComparableBytes(byteComparableVersion)); - var bytesRef = new BytesRef(bytes); - return ramIndexer.add(bytesRef, segmentRowId); + var bytesRefs = terms.stream() + .map(term -> components.onDiskFormat().encodeForTrie(term, termComparator)) + .map(encodedTerm -> ByteSourceInverse.readBytes(encodedTerm.asComparableBytes(byteComparableVersion))) + .map(BytesRef::new) + .collect(Collectors.toList()); + // ramIndexer is responsible for merging duplicate (term, row) pairs + return ramIndexer.addAll(bytesRefs, segmentRowId); } @Override protected void flushInternal(SegmentMetadataBuilder metadataBuilder) throws IOException { - try (InvertedIndexWriter writer = new InvertedIndexWriter(components)) + try (InvertedIndexWriter writer = new InvertedIndexWriter(components, writeFrequencies())) { TermsIterator termsWithPostings = ramIndexer.getTermsWithPostings(minTerm, maxTerm, byteComparableVersion); - var metadataMap = writer.writeAll(metadataBuilder.intercept(termsWithPostings)); + var docLengths = ramIndexer.getDocLengths(); + var metadataMap = writer.writeAll(metadataBuilder.intercept(termsWithPostings), docLengths); metadataBuilder.setComponentsMetadata(metadataMap); } } @@ -261,21 +274,23 @@ public boolean isEmpty() } @Override - protected long addInternal(ByteBuffer term, int segmentRowId) + protected long addInternal(List terms, int segmentRowId) { throw new UnsupportedOperationException(); } @Override - protected long addInternalAsync(ByteBuffer term, int segmentRowId) + protected long addInternalAsync(List terms, int segmentRowId) { + assert terms.size() == 1; + // CompactionGraph splits adding a node into two parts: // (1) maybeAddVector, which must be done serially because it writes to disk incrementally // (2) addGraphNode, which may be done asynchronously CompactionGraph.InsertionResult result; try { - result = graphIndex.maybeAddVector(term, segmentRowId); + result = 
graphIndex.maybeAddVector(terms.get(0), segmentRowId); } catch (IOException e) { @@ -367,19 +382,20 @@ public boolean isEmpty() } @Override - protected long addInternal(ByteBuffer term, int segmentRowId) + protected long addInternal(List terms, int segmentRowId) { - return graphIndex.add(term, segmentRowId); + assert terms.size() == 1; + return graphIndex.add(terms.get(0), segmentRowId); } @Override - protected long addInternalAsync(ByteBuffer term, int segmentRowId) + protected long addInternalAsync(List terms, int segmentRowId) { updatesInFlight.incrementAndGet(); compactionExecutor.submit(() -> { try { - long bytesAdded = addInternal(term, segmentRowId); + long bytesAdded = addInternal(terms, segmentRowId); totalBytesAllocatedConcurrent.add(bytesAdded); termSizeReservoir.update(bytesAdded); } @@ -454,23 +470,22 @@ public SegmentMetadata flush() throws IOException return metadataBuilder.build(); } - public long addAll(ByteBuffer term, AbstractType type, PrimaryKey key, long sstableRowId) + public long analyzeAndAdd(ByteBuffer rawTerm, AbstractType type, PrimaryKey key, long sstableRowId) { long totalSize = 0; if (TypeUtil.isLiteral(type)) { - List tokens = ByteLimitedMaterializer.materializeTokens(analyzer, term, components.context(), key); - for (ByteBuffer tokenTerm : tokens) - totalSize += add(tokenTerm, key, sstableRowId); + var terms = ByteLimitedMaterializer.materializeTokens(analyzer, rawTerm, components.context(), key); + totalSize += add(terms, key, sstableRowId); } else { - totalSize += add(term, key, sstableRowId); + totalSize += add(List.of(rawTerm), key, sstableRowId); } return totalSize; } - private long add(ByteBuffer term, PrimaryKey key, long sstableRowId) + private long add(List terms, PrimaryKey key, long sstableRowId) { assert !flushed : "Cannot add to flushed segment."; assert sstableRowId >= maxSSTableRowId; @@ -481,9 +496,12 @@ private long add(ByteBuffer term, PrimaryKey key, long sstableRowId) minKey = minKey == null ? key : minKey; maxKey = key; - // Note that the min and max terms are not encoded. - minTerm = TypeUtil.min(term, minTerm, termComparator, Version.latest()); - maxTerm = TypeUtil.max(term, maxTerm, termComparator, Version.latest()); + // Update term boundaries for all terms in this row + for (ByteBuffer term : terms) + { + minTerm = TypeUtil.min(term, minTerm, termComparator, Version.latest()); + maxTerm = TypeUtil.max(term, maxTerm, termComparator, Version.latest()); + } rowCount++; @@ -495,15 +513,23 @@ private long add(ByteBuffer term, PrimaryKey key, long sstableRowId) maxSegmentRowId = Math.max(maxSegmentRowId, segmentRowId); - long bytesAllocated = supportsAsyncAdd() - ? 
addInternalAsync(term, segmentRowId) - : addInternal(term, segmentRowId); - totalBytesAllocated += bytesAllocated; + long bytesAllocated; + if (supportsAsyncAdd()) + { + // only vector indexing is done async and there can only be one term + assert terms.size() == 1; + bytesAllocated = addInternalAsync(terms, segmentRowId); + } + else + { + bytesAllocated = addInternal(terms, segmentRowId); + } + totalBytesAllocated += bytesAllocated; return bytesAllocated; } - protected long addInternalAsync(ByteBuffer term, int segmentRowId) + protected long addInternalAsync(List terms, int segmentRowId) { throw new UnsupportedOperationException(); } @@ -566,7 +592,7 @@ long release(IndexContext indexContext) public abstract boolean isEmpty(); - protected abstract long addInternal(ByteBuffer term, int segmentRowId); + protected abstract long addInternal(List terms, int segmentRowId); protected abstract void flushInternal(SegmentMetadataBuilder metadataBuilder) throws IOException; diff --git a/src/java/org/apache/cassandra/index/sai/disk/v1/SegmentMetadata.java b/src/java/org/apache/cassandra/index/sai/disk/v1/SegmentMetadata.java index 00361d6f27ab..b8f4cd769873 100644 --- a/src/java/org/apache/cassandra/index/sai/disk/v1/SegmentMetadata.java +++ b/src/java/org/apache/cassandra/index/sai/disk/v1/SegmentMetadata.java @@ -376,6 +376,11 @@ public ComponentMetadata get(IndexComponentType indexComponentType) return metas.get(indexComponentType); } + public ComponentMetadata getOptional(IndexComponentType indexComponentType) + { + return metas.get(indexComponentType); + } + public Map> asMap() { Map> metaAttributes = new HashMap<>(); diff --git a/src/java/org/apache/cassandra/index/sai/disk/v1/TermsReader.java b/src/java/org/apache/cassandra/index/sai/disk/v1/TermsReader.java index a75871ec28cd..f05511b195a5 100644 --- a/src/java/org/apache/cassandra/index/sai/disk/v1/TermsReader.java +++ b/src/java/org/apache/cassandra/index/sai/disk/v1/TermsReader.java @@ -228,7 +228,7 @@ public PostingsReader getPostingReader(long offset) throws IOException { PostingsReader.BlocksSummary header = new PostingsReader.BlocksSummary(postingsSummaryInput, offset); - return new PostingsReader(postingsInput, header, listener.postingListEventListener()); + return new PostingsReader(postingsInput, header, readFrequencies(), listener.postingListEventListener()); } } @@ -363,11 +363,17 @@ private PostingsReader currentReader(IndexInput postingsInput, PostingsReader.InputCloser.NOOP); return new PostingsReader(postingsInput, blocksSummary, + readFrequencies(), listener.postingListEventListener(), PostingsReader.InputCloser.NOOP); } } + private boolean readFrequencies() + { + return indexContext.isAnalyzed() && version.onOrAfter(Version.EC); + } + private class TermsScanner implements TermsIterator { private final TrieTermsDictionaryReader termsDictionaryReader; @@ -400,7 +406,7 @@ public PostingList postings() throws IOException { assert entry != null; var blockSummary = new PostingsReader.BlocksSummary(postingsSummaryInput, entry.right, PostingsReader.InputCloser.NOOP); - return new ScanningPostingsReader(postingsInput, blockSummary); + return new ScanningPostingsReader(postingsInput, blockSummary, readFrequencies()); } @Override @@ -461,7 +467,7 @@ public PostingList postings() throws IOException { assert entry != null; var blockSummary = new PostingsReader.BlocksSummary(postingsSummaryInput, entry.right, PostingsReader.InputCloser.NOOP); - return new ScanningPostingsReader(postingsInput, blockSummary); + return new 
ScanningPostingsReader(postingsInput, blockSummary, readFrequencies()); } @Override diff --git a/src/java/org/apache/cassandra/index/sai/disk/v1/V1OnDiskFormat.java b/src/java/org/apache/cassandra/index/sai/disk/v1/V1OnDiskFormat.java index 5dbd732f4676..8afe5d62b4c1 100644 --- a/src/java/org/apache/cassandra/index/sai/disk/v1/V1OnDiskFormat.java +++ b/src/java/org/apache/cassandra/index/sai/disk/v1/V1OnDiskFormat.java @@ -83,10 +83,11 @@ public class V1OnDiskFormat implements OnDiskFormat IndexComponentType.META, IndexComponentType.TERMS_DATA, IndexComponentType.POSTING_LISTS); - private static final Set NUMERIC_COMPONENTS = EnumSet.of(IndexComponentType.COLUMN_COMPLETION_MARKER, - IndexComponentType.META, - IndexComponentType.KD_TREE, - IndexComponentType.KD_TREE_POSTING_LISTS); + + public static final Set NUMERIC_COMPONENTS = EnumSet.of(IndexComponentType.COLUMN_COMPLETION_MARKER, + IndexComponentType.META, + IndexComponentType.KD_TREE, + IndexComponentType.KD_TREE_POSTING_LISTS); /** * Global limit on heap consumed by all index segment building that occurs outside the context of Memtable flush. diff --git a/src/java/org/apache/cassandra/index/sai/disk/v1/postings/IntersectingPostingList.java b/src/java/org/apache/cassandra/index/sai/disk/v1/postings/IntersectingPostingList.java new file mode 100644 index 000000000000..295796c9066f --- /dev/null +++ b/src/java/org/apache/cassandra/index/sai/disk/v1/postings/IntersectingPostingList.java @@ -0,0 +1,141 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.cassandra.index.sai.disk.v1.postings; + +import java.io.IOException; +import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import javax.annotation.concurrent.NotThreadSafe; + +import org.apache.cassandra.index.sai.disk.PostingList; +import org.apache.cassandra.io.util.FileUtils; + +/** + * Performs intersection operations on multiple PostingLists, returning only postings + * that appear in all inputs. 
+ */ +@NotThreadSafe +public class IntersectingPostingList implements PostingList +{ + private final Map<ByteBuffer, PostingList> postingsByTerm; + private final List<PostingList> postingLists; // so we can access by ordinal in intersection code + private final int size; + + private IntersectingPostingList(Map<ByteBuffer, PostingList> postingsByTerm) + { + if (postingsByTerm.isEmpty()) + throw new AssertionError("postingsByTerm must not be empty"); + this.postingsByTerm = postingsByTerm; + this.postingLists = new ArrayList<>(postingsByTerm.values()); + // upper bound: the intersection can never be larger than its smallest input + this.size = postingLists.stream() + .mapToInt(PostingList::size) + .min() + .orElse(0); + } + + /** + * @return the intersection of the provided term-posting list mappings + */ + public static IntersectingPostingList intersect(Map<ByteBuffer, PostingList> postingsByTerm) + { + // TODO optimize cases where + // - we have a single posting list + // - any posting list is empty (intersection is also empty) + return new IntersectingPostingList(postingsByTerm); + } + + @Override + public int nextPosting() throws IOException + { + return findNextIntersection(Integer.MIN_VALUE, false); + } + + @Override + public int advance(int targetRowID) throws IOException + { + assert targetRowID >= 0 : targetRowID; + return findNextIntersection(targetRowID, true); + } + + @Override + public int frequency() + { + // call frequencies() instead + throw new UnsupportedOperationException(); + } + + public Map<ByteBuffer, Integer> frequencies() + { + Map<ByteBuffer, Integer> result = new HashMap<>(); + for (Map.Entry<ByteBuffer, PostingList> entry : postingsByTerm.entrySet()) + result.put(entry.getKey(), entry.getValue().frequency()); + return result; + } + + private int findNextIntersection(int targetRowID, boolean isAdvance) throws IOException + { + int maxRowId = targetRowID; + int maxRowIdIndex = -1; + + // Scan through all posting lists looking for a common row ID + for (int i = 0; i < postingLists.size(); i++) + { + // don't advance the sublist in which we found our current max + if (i == maxRowIdIndex) + continue; + + // Advance this sublist to the current max, special casing the first one as needed + PostingList list = postingLists.get(i); + int rowId = (isAdvance || maxRowIdIndex >= 0) + ?
list.advance(maxRowId) + : list.nextPosting(); + if (rowId == END_OF_STREAM) + return END_OF_STREAM; + + // Update maxRowId + index if we find a larger value, or this was the first sublist evaluated + if (rowId > maxRowId || maxRowIdIndex < 0) + { + maxRowId = rowId; + maxRowIdIndex = i; + i = -1; // restart the scan with new maxRowId + } + } + + // Once we complete a full scan without finding a larger rowId, we've found an intersection + return maxRowId; + } + + @Override + public int size() + { + return size; + } + + @Override + public void close() + { + for (PostingList list : postingLists) + FileUtils.closeQuietly(list); + } +} + + diff --git a/src/java/org/apache/cassandra/index/sai/disk/v1/postings/PostingsReader.java b/src/java/org/apache/cassandra/index/sai/disk/v1/postings/PostingsReader.java index ed3556c4ffe5..0ee2ab63b132 100644 --- a/src/java/org/apache/cassandra/index/sai/disk/v1/postings/PostingsReader.java +++ b/src/java/org/apache/cassandra/index/sai/disk/v1/postings/PostingsReader.java @@ -51,7 +51,7 @@ public class PostingsReader implements OrdinalPostingList protected final IndexInput input; protected final InputCloser runOnClose; - private final int blockSize; + private final int blockEntries; private final int numPostings; private final LongArray blockOffsets; private final LongArray blockMaxValues; @@ -69,6 +69,8 @@ public class PostingsReader implements OrdinalPostingList private long currentPosition; private LongValues currentFORValues; private int postingsDecoded = 0; + private int currentFrequency = Integer.MIN_VALUE; + private final boolean readFrequencies; @VisibleForTesting public PostingsReader(IndexInput input, long summaryOffset, QueryEventListener.PostingListEventListener listener) throws IOException @@ -78,7 +80,7 @@ public PostingsReader(IndexInput input, long summaryOffset, QueryEventListener.P public PostingsReader(IndexInput input, BlocksSummary summary, QueryEventListener.PostingListEventListener listener) throws IOException { - this(input, summary, listener, () -> { + this(input, summary, false, listener, () -> { try { input.close(); @@ -90,14 +92,29 @@ public PostingsReader(IndexInput input, BlocksSummary summary, QueryEventListene }); } - public PostingsReader(IndexInput input, BlocksSummary summary, QueryEventListener.PostingListEventListener listener, InputCloser runOnClose) throws IOException + public PostingsReader(IndexInput input, BlocksSummary summary, boolean readFrequencies, QueryEventListener.PostingListEventListener listener) throws IOException + { + this(input, summary, readFrequencies, listener, () -> { + try + { + input.close(); + } + finally + { + summary.close(); + } + }); + } + + public PostingsReader(IndexInput input, BlocksSummary summary, boolean readFrequencies, QueryEventListener.PostingListEventListener listener, InputCloser runOnClose) throws IOException { assert input instanceof IndexInputReader; logger.trace("Opening postings reader for {}", input); + this.readFrequencies = readFrequencies; this.input = input; this.seekingInput = new SeekingRandomAccessInput(input); this.blockOffsets = summary.offsets; - this.blockSize = summary.blockSize; + this.blockEntries = summary.blockEntries; this.numPostings = summary.numPostings; this.blockMaxValues = summary.maxValues; this.listener = listener; @@ -122,7 +139,7 @@ public interface InputCloser public static class BlocksSummary { - final int blockSize; + final int blockEntries; final int numPostings; final LongArray offsets; final LongArray maxValues; @@ -139,7 +156,7 @@ public 
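findNextIntersection above is a leapfrog scan: each list is advanced to the current candidate row id, any list that overshoots raises the candidate and restarts the pass, and a pass that completes without movement means every list contains the candidate. A hedged usage sketch; term1, term2, the posting lists, and process are placeholders, and the surrounding method is assumed to declare IOException:

    // Sketch: intersect the postings of two analyzed terms and visit each row containing both.
    Map<ByteBuffer, PostingList> byTerm = new HashMap<>();
    byTerm.put(term1, postingsForTerm1);
    byTerm.put(term2, postingsForTerm2);
    try (IntersectingPostingList both = IntersectingPostingList.intersect(byTerm))
    {
        for (int rowId = both.nextPosting(); rowId != PostingList.END_OF_STREAM; rowId = both.nextPosting())
            process(rowId, both.frequencies()); // per-term frequencies at the current row
    }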
BlocksSummary(IndexInput input, long offset, InputCloser runOnClose) thro this.runOnClose = runOnClose; input.seek(offset); - this.blockSize = input.readVInt(); + this.blockEntries = input.readVInt(); // This is the count of row ids in a single posting list. For now, a segment cannot have more than // Integer.MAX_VALUE row ids, so it is safe to use an int here. this.numPostings = input.readVInt(); @@ -323,10 +340,10 @@ private void lastPosInBlock(int block) // blockMaxValues is integer only actualSegmentRowId = Math.toIntExact(blockMaxValues.get(block)); //upper bound, since we might've advanced to the last block, but upper bound is enough - totalPostingsRead += (blockSize - blockIdx) + (block - postingsBlockIdx + 1) * blockSize; + totalPostingsRead += (blockEntries - blockIdx) + (block - postingsBlockIdx + 1) * blockEntries; postingsBlockIdx = block + 1; - blockIdx = blockSize; + blockIdx = blockEntries; } @Override @@ -341,9 +358,9 @@ public int nextPosting() throws IOException } @VisibleForTesting - int getBlockSize() + int getBlockEntries() { - return blockSize; + return blockEntries; } private int peekNext() throws IOException @@ -352,27 +369,28 @@ private int peekNext() throws IOException { return END_OF_STREAM; } - if (blockIdx == blockSize) + if (blockIdx == blockEntries) { reBuffer(); } - return actualSegmentRowId + nextRowID(); + return actualSegmentRowId + nextRowDelta(); } - private int nextRowID() + private int nextRowDelta() { - // currentFORValues is null when the all the values in the block are the same if (currentFORValues == null) { + currentFrequency = Integer.MIN_VALUE; return 0; } - else - { - final long id = currentFORValues.get(blockIdx); - postingsDecoded++; - return Math.toIntExact(id); - } + + long offset = readFrequencies ? 2L * blockIdx : blockIdx; + long id = currentFORValues.get(offset); + if (readFrequencies) + currentFrequency = Math.toIntExact(currentFORValues.get(offset + 1)); + postingsDecoded++; + return Math.toIntExact(id); } private void advanceOnePosition(int nextRowID) @@ -420,4 +438,9 @@ else if (bitsPerValue > 64) } currentFORValues = LuceneCompat.directReaderGetInstance(seekingInput, bitsPerValue, currentPosition); } + + @Override + public int frequency() { + return currentFrequency; + } } diff --git a/src/java/org/apache/cassandra/index/sai/disk/v1/postings/PostingsWriter.java b/src/java/org/apache/cassandra/index/sai/disk/v1/postings/PostingsWriter.java index ef027e1e2a39..5bf0decfb8e4 100644 --- a/src/java/org/apache/cassandra/index/sai/disk/v1/postings/PostingsWriter.java +++ b/src/java/org/apache/cassandra/index/sai/disk/v1/postings/PostingsWriter.java @@ -27,7 +27,6 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import org.agrona.collections.IntArrayList; import org.agrona.collections.LongArrayList; import org.apache.cassandra.index.sai.disk.PostingList; import org.apache.cassandra.index.sai.disk.format.IndexComponentType; @@ -42,6 +41,7 @@ import static com.google.common.base.Preconditions.checkArgument; import static java.lang.Math.max; +import static java.lang.Math.min; /** @@ -65,7 +65,7 @@ *

* Each posting list ends with a meta section and a skip table, which are written right after all postings blocks. Skip * interval is the same as block size, and each skip entry points to the end of each block. Skip table consists of - block offsets and last values of each block, compressed as two FoR blocks. + block offsets and maximum rowids of each block, compressed as two FoR blocks. *

* * Visual representation of the disk format: @@ -80,7 +80,7 @@ * | LIST SIZE | SKIP TABLE | * +---------------+------------+ * | BLOCKS POS.| - * | MAX VALUES | + * | MAX ROWIDS | * +------------+ * * @@ -91,13 +91,14 @@ public class PostingsWriter implements Closeable protected static final Logger logger = LoggerFactory.getLogger(PostingsWriter.class); // import static org.apache.lucene.codecs.lucene50.Lucene50PostingsFormat.BLOCK_SIZE; - private final static int BLOCK_SIZE = 128; + private final static int BLOCK_ENTRIES = 128; private static final String POSTINGS_MUST_BE_SORTED_ERROR_MSG = "Postings must be sorted ascending, got [%s] after [%s]"; private final IndexOutput dataOutput; - private final int blockSize; + private final int blockEntries; private final long[] deltaBuffer; + private final int[] freqBuffer; // frequency is capped at 255 private final LongArrayList blockOffsets = new LongArrayList(); private final LongArrayList blockMaxIDs = new LongArrayList(); private final ResettableByteBuffersIndexOutput inMemoryOutput; @@ -106,35 +107,43 @@ public class PostingsWriter implements Closeable private int bufferUpto; private long lastSegmentRowId; - private long maxDelta; // This number is the count of row ids written to the postings for this segment. Because a segment row id can be in // multiple postings list for the segment, this number could exceed Integer.MAX_VALUE, so we use a long. private long totalPostings; + private final boolean writeFrequencies; public PostingsWriter(IndexComponents.ForWrite components) throws IOException { - this(components, BLOCK_SIZE); + this(components, BLOCK_ENTRIES); } + public PostingsWriter(IndexComponents.ForWrite components, boolean writeFrequencies) throws IOException + { + this(components.addOrGet(IndexComponentType.POSTING_LISTS).openOutput(true), BLOCK_ENTRIES, writeFrequencies); + } + + public PostingsWriter(IndexOutput dataOutput) throws IOException { - this(dataOutput, BLOCK_SIZE); + this(dataOutput, BLOCK_ENTRIES, false); } @VisibleForTesting - PostingsWriter(IndexComponents.ForWrite components, int blockSize) throws IOException + PostingsWriter(IndexComponents.ForWrite components, int blockEntries) throws IOException { - this(components.addOrGet(IndexComponentType.POSTING_LISTS).openOutput(true), blockSize); + this(components.addOrGet(IndexComponentType.POSTING_LISTS).openOutput(true), blockEntries, false); } - private PostingsWriter(IndexOutput dataOutput, int blockSize) throws IOException + private PostingsWriter(IndexOutput dataOutput, int blockEntries, boolean writeFrequencies) throws IOException { assert dataOutput instanceof IndexOutputWriter; logger.debug("Creating postings writer for output {}", dataOutput); - this.blockSize = blockSize; + this.writeFrequencies = writeFrequencies; + this.blockEntries = blockEntries; this.dataOutput = dataOutput; startOffset = dataOutput.getFilePointer(); - deltaBuffer = new long[blockSize]; + deltaBuffer = new long[blockEntries]; + freqBuffer = new int[blockEntries]; inMemoryOutput = LuceneCompat.getResettableByteBuffersIndexOutput(dataOutput.order(), 1024, "blockOffsets"); SAICodecUtils.writeHeader(dataOutput); } @@ -191,7 +200,7 @@ public long write(PostingList postings) throws IOException int size = 0; while ((segmentRowId = postings.nextPosting()) != PostingList.END_OF_STREAM) { - writePosting(segmentRowId); + writePosting(segmentRowId, postings.frequency()); size++; totalPostings++; } @@ -210,19 +219,19 @@ public long getTotalPostings() return totalPostings; } - private void 
writePosting(long segmentRowId) throws IOException - { + private void writePosting(long segmentRowId, int freq) throws IOException { if (!(segmentRowId >= lastSegmentRowId || lastSegmentRowId == 0)) throw new IllegalArgumentException(String.format(POSTINGS_MUST_BE_SORTED_ERROR_MSG, segmentRowId, lastSegmentRowId)); + assert freq > 0; final long delta = segmentRowId - lastSegmentRowId; - maxDelta = max(maxDelta, delta); - deltaBuffer[bufferUpto++] = delta; + deltaBuffer[bufferUpto] = delta; + freqBuffer[bufferUpto] = min(freq, 255); + bufferUpto++; - if (bufferUpto == blockSize) - { + if (bufferUpto == blockEntries) { addBlockToSkipTable(segmentRowId); - writePostingsBlock(maxDelta, bufferUpto); + writePostingsBlock(bufferUpto); resetBlockCounters(); } lastSegmentRowId = segmentRowId; @@ -234,7 +243,7 @@ private void finish() throws IOException { addBlockToSkipTable(lastSegmentRowId); - writePostingsBlock(maxDelta, bufferUpto); + writePostingsBlock(bufferUpto); } } @@ -242,7 +251,6 @@ private void resetBlockCounters() { bufferUpto = 0; lastSegmentRowId = 0; - maxDelta = 0; } private void addBlockToSkipTable(long maxSegmentRowID) @@ -253,7 +261,7 @@ private void addBlockToSkipTable(long maxSegmentRowID) private void writeSummary(int exactSize) throws IOException { - dataOutput.writeVInt(blockSize); + dataOutput.writeVInt(blockEntries); dataOutput.writeVInt(exactSize); writeSkipTable(); } @@ -272,19 +280,29 @@ private void writeSkipTable() throws IOException writeSortedFoRBlock(blockMaxIDs, dataOutput); } - private void writePostingsBlock(long maxValue, int blockSize) throws IOException - { + private void writePostingsBlock(int entries) throws IOException { + // Find max value to determine bits needed + long maxValue = 0; + for (int i = 0; i < entries; i++) { + maxValue = max(maxValue, deltaBuffer[i]); + if (writeFrequencies) + maxValue = max(maxValue, freqBuffer[i]); + } + + // Use the maximum bits needed for either value type final int bitsPerValue = maxValue == 0 ? 0 : LuceneCompat.directWriterUnsignedBitsRequired(dataOutput.order(), maxValue); - - assert bitsPerValue < Byte.MAX_VALUE; - + dataOutput.writeByte((byte) bitsPerValue); - if (bitsPerValue > 0) - { - final DirectWriterAdapter writer = LuceneCompat.directWriterGetInstance(dataOutput.order(), dataOutput, blockSize, bitsPerValue); - for (int i = 0; i < blockSize; ++i) - { + if (bitsPerValue > 0) { + // Write interleaved [delta][freq] pairs + final DirectWriterAdapter writer = LuceneCompat.directWriterGetInstance(dataOutput.order(), + dataOutput, + writeFrequencies ? entries * 2L : entries, + bitsPerValue); + for (int i = 0; i < entries; ++i) { writer.add(deltaBuffer[i]); + if (writeFrequencies) + writer.add(freqBuffer[i]); } writer.finish(); } @@ -292,9 +310,9 @@ private void writePostingsBlock(long maxValue, int blockSize) throws IOException private void writeSortedFoRBlock(LongArrayList values, IndexOutput output) throws IOException { + assert !values.isEmpty(); final long maxValue = values.getLong(values.size() - 1); - assert values.size() > 0; final int bitsPerValue = maxValue == 0 ? 
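With frequencies enabled, writePostingsBlock packs [delta][freq] pairs into a single FoR block at one shared bit width, which is why the maximum is taken over both arrays; the reader side (nextRowDelta above) then addresses slot 2*i for the delta and 2*i + 1 for the frequency. A sketch of the width selection under that layout; the real writer delegates to Lucene's DirectWriter via LuceneCompat, which may round the width up to a supported size:

    // Illustrative: one bit width wide enough for every delta and every capped frequency.
    static int bitsNeeded(long[] deltaBuffer, int[] freqBuffer, int entries, boolean writeFrequencies)
    {
        long maxValue = 0;
        for (int i = 0; i < entries; i++)
        {
            maxValue = Math.max(maxValue, deltaBuffer[i]);
            if (writeFrequencies)
                maxValue = Math.max(maxValue, freqBuffer[i]); // frequencies were already capped at 255
        }
        // 0 signals an elided block; otherwise the exact bit count for the largest value
        return maxValue == 0 ? 0 : 64 - Long.numberOfLeadingZeros(maxValue);
    }

The trade-off is that a large row-id delta inflates the width used for small frequencies, in exchange for a single bitsPerValue byte per block and direct random addressing within it.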
0 : LuceneCompat.directWriterUnsignedBitsRequired(output.order(), maxValue); output.writeByte((byte) bitsPerValue); if (bitsPerValue > 0) @@ -307,22 +325,4 @@ private void writeSortedFoRBlock(LongArrayList values, IndexOutput output) throw writer.finish(); } } - - private void writeSortedFoRBlock(IntArrayList values, IndexOutput output) throws IOException - { - final int maxValue = values.getInt(values.size() - 1); - - assert values.size() > 0; - final int bitsPerValue = maxValue == 0 ? 0 : LuceneCompat.directWriterUnsignedBitsRequired(output.order(), maxValue); - output.writeByte((byte) bitsPerValue); - if (bitsPerValue > 0) - { - final DirectWriterAdapter writer = LuceneCompat.directWriterGetInstance(output.order(), output, values.size(), bitsPerValue); - for (int i = 0; i < values.size(); ++i) - { - writer.add(values.getInt(i)); - } - writer.finish(); - } - } } diff --git a/src/java/org/apache/cassandra/index/sai/disk/v1/postings/VectorPostingList.java b/src/java/org/apache/cassandra/index/sai/disk/v1/postings/ReorderingPostingList.java similarity index 83% rename from src/java/org/apache/cassandra/index/sai/disk/v1/postings/VectorPostingList.java rename to src/java/org/apache/cassandra/index/sai/disk/v1/postings/ReorderingPostingList.java index 52155bf6ed59..f43c7c4e8dce 100644 --- a/src/java/org/apache/cassandra/index/sai/disk/v1/postings/VectorPostingList.java +++ b/src/java/org/apache/cassandra/index/sai/disk/v1/postings/ReorderingPostingList.java @@ -19,31 +19,29 @@ package org.apache.cassandra.index.sai.disk.v1.postings; import java.io.IOException; +import java.util.function.ToIntFunction; import org.apache.cassandra.index.sai.disk.PostingList; -import org.apache.cassandra.index.sai.utils.RowIdWithMeta; import org.apache.cassandra.utils.CloseableIterator; import org.apache.lucene.util.LongHeap; /** * A posting list for ANN search results. Transforms results from similarity order to rowId order. */ -public class VectorPostingList implements PostingList +public class ReorderingPostingList implements PostingList { private final LongHeap segmentRowIds; private final int size; - public VectorPostingList(CloseableIterator source) + public ReorderingPostingList(CloseableIterator source, ToIntFunction rowIdTransformer) { - // TODO find int specific data structure? segmentRowIds = new LongHeap(32); int n = 0; - // Once the source is consumed, we have to close it. 
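The rename records that this class is no longer vector-specific: any iterator whose elements carry a segment row id can be reordered into rowId order by pushing the ids onto the heap. A usage sketch matching the V2VectorIndexSearcher call site later in this patch; annSearchResults is a placeholder source:

    // Sketch: convert similarity-ordered ANN results into a rowId-ordered PostingList
    // so they can be merged with other per-segment results downstream.
    CloseableIterator<RowIdWithScore> scored = annSearchResults(); // placeholder
    PostingList inRowIdOrder = new ReorderingPostingList(scored, RowIdWithMeta::getSegmentRowId);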
try (source) { while (source.hasNext()) { - segmentRowIds.push(source.next().getSegmentRowId()); + segmentRowIds.push(rowIdTransformer.applyAsInt(source.next())); n++; } } diff --git a/src/java/org/apache/cassandra/index/sai/disk/v1/postings/ScanningPostingsReader.java b/src/java/org/apache/cassandra/index/sai/disk/v1/postings/ScanningPostingsReader.java index 2fd7c2336ab2..5c9da6169e08 100644 --- a/src/java/org/apache/cassandra/index/sai/disk/v1/postings/ScanningPostingsReader.java +++ b/src/java/org/apache/cassandra/index/sai/disk/v1/postings/ScanningPostingsReader.java @@ -31,9 +31,9 @@ */ public class ScanningPostingsReader extends PostingsReader { - public ScanningPostingsReader(IndexInput input, BlocksSummary summary) throws IOException + public ScanningPostingsReader(IndexInput input, BlocksSummary summary, boolean readFrequencies) throws IOException { - super(input, summary, QueryEventListener.PostingListEventListener.NO_OP, InputCloser.NOOP); + super(input, summary, readFrequencies, QueryEventListener.PostingListEventListener.NO_OP, InputCloser.NOOP); } @Override diff --git a/src/java/org/apache/cassandra/index/sai/disk/v1/trie/DocLengthsWriter.java b/src/java/org/apache/cassandra/index/sai/disk/v1/trie/DocLengthsWriter.java new file mode 100644 index 000000000000..90dc6b5c6ff9 --- /dev/null +++ b/src/java/org/apache/cassandra/index/sai/disk/v1/trie/DocLengthsWriter.java @@ -0,0 +1,73 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.cassandra.index.sai.disk.v1.trie; + +import java.io.Closeable; +import java.io.IOException; + +import org.agrona.collections.Int2IntHashMap; +import org.apache.cassandra.index.sai.disk.format.IndexComponents; +import org.apache.cassandra.index.sai.disk.format.IndexComponentType; +import org.apache.cassandra.index.sai.disk.io.IndexOutputWriter; +import org.apache.cassandra.index.sai.utils.SAICodecUtils; + +/** + * Writes document length information to disk for use in text scoring + */ +public class DocLengthsWriter implements Closeable +{ + private final IndexOutputWriter output; + + public DocLengthsWriter(IndexComponents.ForWrite components) throws IOException + { + this.output = components.addOrGet(IndexComponentType.DOC_LENGTHS).openOutput(true); + SAICodecUtils.writeHeader(output); + } + + public void writeDocLengths(Int2IntHashMap lengths) throws IOException + { + // Calculate max row ID from doc lengths map + int maxRowId = -1; + for (var keyIterator = lengths.keySet().iterator(); keyIterator.hasNext(); ) + { + int key = keyIterator.nextValue(); + if (key > maxRowId) + maxRowId = key; + } + + // write out the doc lengths in row order + for (int rowId = 0; rowId <= maxRowId; rowId++) + { + final int length = lengths.get(rowId); + output.writeInt(length == lengths.missingValue() ? 0 : length); + } + + SAICodecUtils.writeFooter(output); + } + + public long getFilePointer() + { + return output.getFilePointer(); + } + + @Override + public void close() throws IOException + { + output.close(); + } +} diff --git a/src/java/org/apache/cassandra/index/sai/disk/v1/trie/InvertedIndexWriter.java b/src/java/org/apache/cassandra/index/sai/disk/v1/trie/InvertedIndexWriter.java index e6adb61ce863..a61df65f2239 100644 --- a/src/java/org/apache/cassandra/index/sai/disk/v1/trie/InvertedIndexWriter.java +++ b/src/java/org/apache/cassandra/index/sai/disk/v1/trie/InvertedIndexWriter.java @@ -25,10 +25,12 @@ import org.apache.commons.lang3.mutable.MutableLong; +import org.agrona.collections.Int2IntHashMap; import org.apache.cassandra.index.sai.disk.PostingList; import org.apache.cassandra.index.sai.disk.TermsIterator; import org.apache.cassandra.index.sai.disk.format.IndexComponents; import org.apache.cassandra.index.sai.disk.format.IndexComponentType; +import org.apache.cassandra.index.sai.disk.format.Version; import org.apache.cassandra.index.sai.disk.v1.SegmentMetadata; import org.apache.cassandra.index.sai.disk.v1.postings.PostingsWriter; import org.apache.cassandra.index.sai.utils.SAICodecUtils; @@ -42,12 +44,19 @@ public class InvertedIndexWriter implements Closeable { private final TrieTermsDictionaryWriter termsDictionaryWriter; private final PostingsWriter postingsWriter; + private final DocLengthsWriter docLengthsWriter; private long postingsAdded; public InvertedIndexWriter(IndexComponents.ForWrite components) throws IOException + { + this(components, false); + } + + public InvertedIndexWriter(IndexComponents.ForWrite components, boolean writeFrequencies) throws IOException { this.termsDictionaryWriter = new TrieTermsDictionaryWriter(components); - this.postingsWriter = new PostingsWriter(components); + this.postingsWriter = new PostingsWriter(components, writeFrequencies); + this.docLengthsWriter = Version.latest().onOrAfter(Version.EC) ? 
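The component DocLengthsWriter produces is a dense fixed-width array: after the SAI header, the length of row r is a plain 4-byte int at offset header + 4L * r, with 0 recorded for rows that contributed no indexed terms. A sketch of the matching lookup, assuming the DocLengthsReader counterpart added elsewhere in this patch seeks into that layout; the field names are illustrative:

    // Hypothetical mirror of the write loop: doc lengths are fixed-width ints in row-id order.
    public int get(int segmentRowId) throws IOException
    {
        input.seek(headerLength + 4L * segmentRowId); // headerLength: bytes written by SAICodecUtils.writeHeader
        return input.readInt(); // 0 means the row had no indexed terms
    }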
new DocLengthsWriter(components) : null; } /** @@ -58,7 +67,7 @@ public InvertedIndexWriter(IndexComponents.ForWrite components) throws IOExcepti * @return metadata describing the location of this inverted index in the overall SSTable * terms and postings component files */ - public SegmentMetadata.ComponentMetadataMap writeAll(TermsIterator terms) throws IOException + public SegmentMetadata.ComponentMetadataMap writeAll(TermsIterator terms, Int2IntHashMap docLengths) throws IOException { // Terms and postings writers are opened in append mode with pointers at the end of their respective files. long termsOffset = termsDictionaryWriter.getStartOffset(); @@ -91,6 +100,15 @@ public SegmentMetadata.ComponentMetadataMap writeAll(TermsIterator terms) throws components.put(IndexComponentType.POSTING_LISTS, -1, postingsOffset, postingsLength); components.put(IndexComponentType.TERMS_DATA, termsRoot, termsOffset, termsLength, map); + // Write doc lengths + if (docLengthsWriter != null) + { + long docLengthsOffset = docLengthsWriter.getFilePointer(); + docLengthsWriter.writeDocLengths(docLengths); + long docLengthsLength = docLengthsWriter.getFilePointer() - docLengthsOffset; + components.put(IndexComponentType.DOC_LENGTHS, -1, docLengthsOffset, docLengthsLength); + } + return components; } @@ -99,6 +117,8 @@ public void close() throws IOException { postingsWriter.close(); termsDictionaryWriter.close(); + if (docLengthsWriter != null) + docLengthsWriter.close(); } /** diff --git a/src/java/org/apache/cassandra/index/sai/disk/v2/RowAwarePrimaryKeyFactory.java b/src/java/org/apache/cassandra/index/sai/disk/v2/RowAwarePrimaryKeyFactory.java index e51a6229de58..cba1e1124ce3 100644 --- a/src/java/org/apache/cassandra/index/sai/disk/v2/RowAwarePrimaryKeyFactory.java +++ b/src/java/org/apache/cassandra/index/sai/disk/v2/RowAwarePrimaryKeyFactory.java @@ -23,6 +23,7 @@ import java.util.function.Supplier; import java.util.stream.Collectors; +import io.github.jbellis.jvector.util.RamUsageEstimator; import org.apache.cassandra.db.Clustering; import org.apache.cassandra.db.ClusteringComparator; import org.apache.cassandra.db.DecoratedKey; @@ -212,5 +213,22 @@ public String toString() .map(ByteBufferUtil::bytesToHex) .collect(Collectors.toList()))); } + + @Override + public long ramBytesUsed() + { + // Object header + 4 references (token, partitionKey, clustering, primaryKeySupplier) + implicit outer reference + long size = RamUsageEstimator.NUM_BYTES_OBJECT_HEADER + + 5L * RamUsageEstimator.NUM_BYTES_OBJECT_REF; + + if (token != null) + size += token.getHeapSize(); + if (partitionKey != null) + size += RamUsageEstimator.NUM_BYTES_OBJECT_HEADER + + 2L * RamUsageEstimator.NUM_BYTES_OBJECT_REF + // token and key references + 2L * Long.BYTES; + // We don't count clustering size here as it's managed elsewhere + return size; + } } } diff --git a/src/java/org/apache/cassandra/index/sai/disk/v2/V2VectorIndexSearcher.java b/src/java/org/apache/cassandra/index/sai/disk/v2/V2VectorIndexSearcher.java index 80870dab36bf..7da1097e64ae 100644 --- a/src/java/org/apache/cassandra/index/sai/disk/v2/V2VectorIndexSearcher.java +++ b/src/java/org/apache/cassandra/index/sai/disk/v2/V2VectorIndexSearcher.java @@ -49,7 +49,7 @@ import org.apache.cassandra.index.sai.disk.v1.IndexSearcher; import org.apache.cassandra.index.sai.disk.v1.PerIndexFiles; import org.apache.cassandra.index.sai.disk.v1.SegmentMetadata; -import org.apache.cassandra.index.sai.disk.v1.postings.VectorPostingList; +import 
org.apache.cassandra.index.sai.disk.v1.postings.ReorderingPostingList; import org.apache.cassandra.index.sai.disk.v5.V5VectorPostingsWriter; import org.apache.cassandra.index.sai.disk.vector.BruteForceRowIdIterator; import org.apache.cassandra.index.sai.disk.vector.CassandraDiskAnn; @@ -64,8 +64,8 @@ import org.apache.cassandra.index.sai.utils.PrimaryKey; import org.apache.cassandra.index.sai.utils.PrimaryKeyWithSortKey; import org.apache.cassandra.index.sai.utils.RangeUtil; +import org.apache.cassandra.index.sai.utils.RowIdWithMeta; import org.apache.cassandra.index.sai.utils.RowIdWithScore; -import org.apache.cassandra.index.sai.utils.SegmentOrdering; import org.apache.cassandra.io.sstable.format.SSTableReader; import org.apache.cassandra.metrics.LinearFit; import org.apache.cassandra.metrics.PairedSlidingWindowReservoir; @@ -81,7 +81,7 @@ /** * Executes ann search against the graph for an individual index segment. */ -public class V2VectorIndexSearcher extends IndexSearcher implements SegmentOrdering +public class V2VectorIndexSearcher extends IndexSearcher { private static final Logger logger = LoggerFactory.getLogger(MethodHandles.lookup().lookupClass()); private static final VectorTypeSupport vts = VectorizationProvider.getInstance().getVectorTypeSupport(); @@ -133,13 +133,13 @@ public ProductQuantization getPQ() } @Override - public KeyRangeIterator search(Expression exp, AbstractBounds keyRange, QueryContext context, boolean defer, int limit) throws IOException + public KeyRangeIterator search(Expression exp, AbstractBounds keyRange, QueryContext context, boolean defer) throws IOException { - PostingList results = searchPosting(context, exp, keyRange, limit); + PostingList results = searchPosting(context, exp, keyRange); return toPrimaryKeyIterator(results, context); } - private PostingList searchPosting(QueryContext context, Expression exp, AbstractBounds keyRange, int limit) throws IOException + private PostingList searchPosting(QueryContext context, Expression exp, AbstractBounds keyRange) throws IOException { if (logger.isTraceEnabled()) logger.trace(indexContext.logMessage("Searching on expression '{}'..."), exp); @@ -151,7 +151,7 @@ private PostingList searchPosting(QueryContext context, Expression exp, Abstract // this is a thresholded query, so pass graph.size() as top k to get all results satisfying the threshold var result = searchInternal(keyRange, context, queryVector, graph.size(), graph.size(), exp.getEuclideanSearchThreshold()); - return new VectorPostingList(result); + return new ReorderingPostingList(result, RowIdWithMeta::getSegmentRowId); } @Override @@ -160,11 +160,11 @@ public CloseableIterator orderBy(Orderer orderer, Express if (logger.isTraceEnabled()) logger.trace(indexContext.logMessage("Searching on expression '{}'..."), orderer); - if (orderer.vector == null) + if (!orderer.isANN()) throw new IllegalArgumentException(indexContext.logMessage("Unsupported expression during ANN index query: " + orderer)); int rerankK = indexContext.getIndexWriterConfig().getSourceModel().rerankKFor(limit, graph.getCompression()); - var queryVector = vts.createFloatVector(orderer.vector); + var queryVector = vts.createFloatVector(orderer.getVectorTerm()); var result = searchInternal(keyRange, context, queryVector, limit, rerankK, 0); return toMetaSortedIterator(result, context); @@ -485,14 +485,14 @@ public CloseableIterator orderResultsBy(SSTableReader rea if (cost.shouldUseBruteForce()) { // brute force using the in-memory compressed vectors to cut down the number of 
results returned - var queryVector = vts.createFloatVector(orderer.vector); + var queryVector = vts.createFloatVector(orderer.getVectorTerm()); return toMetaSortedIterator(this.orderByBruteForce(queryVector, segmentOrdinalPairs, limit, rerankK), context); } // Create bits from the mapping var bits = bitSetForSearch(); segmentOrdinalPairs.forEachRightInt(bits::set); // else ask the index to perform a search limited to the bits we created - var queryVector = vts.createFloatVector(orderer.vector); + var queryVector = vts.createFloatVector(orderer.getVectorTerm()); var results = graph.search(queryVector, limit, rerankK, 0, bits, context, cost::updateStatistics); return toMetaSortedIterator(results, context); } diff --git a/src/java/org/apache/cassandra/index/sai/disk/v4/V4InvertedIndexSearcher.java b/src/java/org/apache/cassandra/index/sai/disk/v4/V4InvertedIndexSearcher.java index ac11b238844e..f819b6e1ee6f 100644 --- a/src/java/org/apache/cassandra/index/sai/disk/v4/V4InvertedIndexSearcher.java +++ b/src/java/org/apache/cassandra/index/sai/disk/v4/V4InvertedIndexSearcher.java @@ -37,6 +37,6 @@ class V4InvertedIndexSearcher extends InvertedIndexSearcher SegmentMetadata segmentMetadata, IndexContext indexContext) throws IOException { - super(sstableContext, perIndexFiles, segmentMetadata, indexContext, Version.DB, false); + super(sstableContext, perIndexFiles, segmentMetadata, indexContext, segmentMetadata.version, false); } } diff --git a/src/java/org/apache/cassandra/index/sai/disk/v7/V7OnDiskFormat.java b/src/java/org/apache/cassandra/index/sai/disk/v7/V7OnDiskFormat.java new file mode 100644 index 000000000000..5457685cf1b0 --- /dev/null +++ b/src/java/org/apache/cassandra/index/sai/disk/v7/V7OnDiskFormat.java @@ -0,0 +1,50 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.cassandra.index.sai.disk.v7; + +import java.util.EnumSet; +import java.util.Set; + +import org.apache.cassandra.db.marshal.AbstractType; +import org.apache.cassandra.index.sai.disk.format.IndexComponentType; +import org.apache.cassandra.index.sai.disk.v1.V1OnDiskFormat; +import org.apache.cassandra.index.sai.disk.v3.V3OnDiskFormat; +import org.apache.cassandra.index.sai.disk.v6.V6OnDiskFormat; +import org.apache.cassandra.index.sai.utils.TypeUtil; + +public class V7OnDiskFormat extends V6OnDiskFormat +{ + public static final V7OnDiskFormat instance = new V7OnDiskFormat(); + + private static final Set LITERAL_COMPONENTS = EnumSet.of(IndexComponentType.COLUMN_COMPLETION_MARKER, + IndexComponentType.META, + IndexComponentType.TERMS_DATA, + IndexComponentType.POSTING_LISTS, + IndexComponentType.DOC_LENGTHS); + + @Override + public Set perIndexComponentTypes(AbstractType validator) + { + if (validator.isVector()) + return V3OnDiskFormat.VECTOR_COMPONENTS_V3; + if (TypeUtil.isLiteral(validator)) + return LITERAL_COMPONENTS; + return V1OnDiskFormat.NUMERIC_COMPONENTS; + } +} diff --git a/src/java/org/apache/cassandra/index/sai/disk/vector/VectorMemtableIndex.java b/src/java/org/apache/cassandra/index/sai/disk/vector/VectorMemtableIndex.java index 25231748149f..a6a52c462f3b 100644 --- a/src/java/org/apache/cassandra/index/sai/disk/vector/VectorMemtableIndex.java +++ b/src/java/org/apache/cassandra/index/sai/disk/vector/VectorMemtableIndex.java @@ -54,6 +54,7 @@ import org.apache.cassandra.index.sai.disk.format.IndexComponents; import org.apache.cassandra.index.sai.disk.v1.SegmentMetadata; import org.apache.cassandra.index.sai.iterators.KeyRangeIterator; +import org.apache.cassandra.index.sai.memory.MemoryIndex; import org.apache.cassandra.index.sai.memory.MemtableIndex; import org.apache.cassandra.index.sai.plan.Expression; import org.apache.cassandra.index.sai.plan.Orderer; @@ -222,7 +223,7 @@ public List> orderBy(QueryContext conte assert slice == null : "ANN does not support index slicing"; assert orderer.isANN() : "Only ANN is supported for vector search, received " + orderer.operator; - var qv = vts.createFloatVector(orderer.vector); + var qv = vts.createFloatVector(orderer.getVectorTerm()); return List.of(searchInternal(context, qv, keyRange, limit, 0)); } @@ -310,7 +311,7 @@ public CloseableIterator orderResultsBy(QueryContext cont relevantOrdinals.size(), keys.size(), maxBruteForceRows, graph.size(), limit); // convert the expression value to query vector - var qv = vts.createFloatVector(orderer.vector); + var qv = vts.createFloatVector(orderer.getVectorTerm()); // brute force path if (keysInGraph.size() <= maxBruteForceRows) { @@ -422,7 +423,7 @@ public static int ensureSaneEstimate(int rawEstimate, int rerankK, int graphSize } @Override - public Iterator>> iterator(DecoratedKey min, DecoratedKey max) + public Iterator>> iterator(DecoratedKey min, DecoratedKey max) { // This method is only used when merging an in-memory index with a RowMapping. This is done a different // way with the graph using the writeData method below. 
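The iterator signature change above is the memtable half of the frequency plumbing: instead of Term -> PrimaryKeys, in-memory indexes now yield Term -> List<PkWithFrequency>, and RowMapping.merge (below) rewrites each surviving key into a RowIdWithFrequency for the flush-time postings writer. A minimal consumption sketch using the types introduced in the following hunks; lookupRowId and emit are placeholders:

    // Sketch: drain a memtable index range and hand (term, rowId, frequency) to a sink.
    Iterator<Pair<ByteComparable, List<MemoryIndex.PkWithFrequency>>> it = index.iterator(min, max);
    while (it.hasNext())
    {
        Pair<ByteComparable, List<MemoryIndex.PkWithFrequency>> entry = it.next();
        for (MemoryIndex.PkWithFrequency pkf : entry.right)
        {
            Integer rowId = lookupRowId(pkf.pk); // placeholder: RowMapping's pk -> segment row id
            if (rowId != null)
                emit(entry.left, rowId, pkf.frequency); // placeholder postings sink
        }
    }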
diff --git a/src/java/org/apache/cassandra/index/sai/memory/MemoryIndex.java b/src/java/org/apache/cassandra/index/sai/memory/MemoryIndex.java index 83c43e212d78..b50a58805430 100644 --- a/src/java/org/apache/cassandra/index/sai/memory/MemoryIndex.java +++ b/src/java/org/apache/cassandra/index/sai/memory/MemoryIndex.java @@ -20,6 +20,7 @@ import java.nio.ByteBuffer; import java.util.Iterator; +import java.util.List; import java.util.function.LongConsumer; import org.apache.cassandra.db.Clustering; @@ -30,8 +31,8 @@ import org.apache.cassandra.index.sai.iterators.KeyRangeIterator; import org.apache.cassandra.index.sai.plan.Expression; import org.apache.cassandra.index.sai.plan.Orderer; +import org.apache.cassandra.index.sai.utils.PrimaryKey; import org.apache.cassandra.index.sai.utils.PrimaryKeyWithSortKey; -import org.apache.cassandra.index.sai.utils.PrimaryKeys; import org.apache.cassandra.utils.CloseableIterator; import org.apache.cassandra.utils.Pair; import org.apache.cassandra.utils.bytecomparable.ByteComparable; @@ -64,5 +65,17 @@ public abstract void add(DecoratedKey key, /** * Iterate all Term -> (PrimaryKey, frequency) mappings in sorted order */ - public abstract Iterator<Pair<ByteComparable, PrimaryKeys>> iterator(); + public abstract Iterator<Pair<ByteComparable, List<PkWithFrequency>>> iterator(); + + public static class PkWithFrequency + { + public final PrimaryKey pk; + public final int frequency; + + public PkWithFrequency(PrimaryKey pk, int frequency) + { + this.pk = pk; + this.frequency = frequency; + } + } } diff --git a/src/java/org/apache/cassandra/index/sai/memory/MemtableIndex.java b/src/java/org/apache/cassandra/index/sai/memory/MemtableIndex.java index 22bdc284bf3f..f340b06a5e19 100644 --- a/src/java/org/apache/cassandra/index/sai/memory/MemtableIndex.java +++ b/src/java/org/apache/cassandra/index/sai/memory/MemtableIndex.java @@ -20,6 +20,7 @@ import java.nio.ByteBuffer; import java.util.Iterator; +import java.util.List; import javax.annotation.Nullable; import org.apache.cassandra.db.Clustering; @@ -32,7 +33,6 @@ import org.apache.cassandra.index.sai.disk.vector.VectorMemtableIndex; import org.apache.cassandra.index.sai.iterators.KeyRangeIterator; import org.apache.cassandra.index.sai.plan.Expression; -import org.apache.cassandra.index.sai.utils.PrimaryKey; import org.apache.cassandra.index.sai.utils.MemtableOrdering; import org.apache.cassandra.utils.Pair; import org.apache.cassandra.utils.bytecomparable.ByteComparable; @@ -79,7 +79,7 @@ default void update(DecoratedKey key, Clustering clustering, ByteBuffer oldValue long estimateMatchingRowsCount(Expression expression, AbstractBounds keyRange); - Iterator<Pair<ByteComparable, Iterator<PrimaryKey>>> iterator(DecoratedKey min, DecoratedKey max); + Iterator<Pair<ByteComparable, List<MemoryIndex.PkWithFrequency>>> iterator(DecoratedKey min, DecoratedKey max); static MemtableIndex createIndex(IndexContext indexContext, Memtable mt) { diff --git a/src/java/org/apache/cassandra/index/sai/memory/RowMapping.java b/src/java/org/apache/cassandra/index/sai/memory/RowMapping.java index a435a6e47169..e0cdc610a63c 100644 --- a/src/java/org/apache/cassandra/index/sai/memory/RowMapping.java +++ b/src/java/org/apache/cassandra/index/sai/memory/RowMapping.java @@ -17,10 +17,11 @@ */ package org.apache.cassandra.index.sai.memory; +import java.util.ArrayList; import java.util.Collections; import java.util.Iterator; +import java.util.List; -import com.carrotsearch.hppc.IntArrayList; import org.apache.cassandra.db.compaction.OperationType; import org.apache.cassandra.db.rows.RangeTombstoneMarker; import org.apache.cassandra.db.rows.Row; @@ -44,7 +45,7 @@ public class RowMapping public static final RowMapping DUMMY =
new RowMapping() { @Override - public Iterator> merge(MemtableIndex index) { return Collections.emptyIterator(); } + public Iterator>> merge(MemtableIndex index) { return Collections.emptyIterator(); } @Override public void complete() {} @@ -89,6 +90,16 @@ public static RowMapping create(OperationType opType) return DUMMY; } + public static class RowIdWithFrequency { + public final int rowId; + public final int frequency; + + public RowIdWithFrequency(int rowId, int frequency) { + this.rowId = rowId; + this.frequency = frequency; + } + } + /** * Merge IndexMemtable(index term to PrimaryKeys mappings) with row mapping of a sstable * (PrimaryKey to RowId mappings). @@ -97,33 +108,32 @@ public static RowMapping create(OperationType opType) * * @return iterator of index term to postings mapping exists in the sstable */ - public Iterator> merge(MemtableIndex index) + public Iterator>> merge(MemtableIndex index) { assert complete : "RowMapping is not built."; - Iterator>> iterator = index.iterator(minKey.partitionKey(), maxKey.partitionKey()); - return new AbstractGuavaIterator>() + var it = index.iterator(minKey.partitionKey(), maxKey.partitionKey()); + return new AbstractGuavaIterator<>() { @Override - protected Pair computeNext() + protected Pair> computeNext() { - while (iterator.hasNext()) + while (it.hasNext()) { - Pair> pair = iterator.next(); + var pair = it.next(); - IntArrayList postings = null; - Iterator primaryKeys = pair.right; + List postings = null; + var primaryKeysWithFreq = pair.right; - while (primaryKeys.hasNext()) + for (var pkWithFreq : primaryKeysWithFreq) { - PrimaryKey primaryKey = primaryKeys.next(); - ByteComparable byteComparable = v -> primaryKey.asComparableBytes(v); + ByteComparable byteComparable = pkWithFreq.pk::asComparableBytes; Integer segmentRowId = rowMapping.get(byteComparable); if (segmentRowId != null) { - postings = postings == null ? new IntArrayList() : postings; - postings.add(segmentRowId); + postings = postings == null ? 
new ArrayList<>() : postings; + postings.add(new RowIdWithFrequency(segmentRowId, pkWithFreq.frequency)); } } if (postings != null && !postings.isEmpty()) diff --git a/src/java/org/apache/cassandra/index/sai/memory/TrieMemoryIndex.java b/src/java/org/apache/cassandra/index/sai/memory/TrieMemoryIndex.java index 57517d9192d0..593f11acccca 100644 --- a/src/java/org/apache/cassandra/index/sai/memory/TrieMemoryIndex.java +++ b/src/java/org/apache/cassandra/index/sai/memory/TrieMemoryIndex.java @@ -28,13 +28,15 @@ import java.math.BigDecimal; import java.nio.ByteBuffer; import java.util.ArrayList; - import java.util.Collection; - +import java.util.HashMap; import java.util.Iterator; import java.util.List; import java.util.Map; +import java.util.Objects; import java.util.SortedSet; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.atomic.AtomicLong; import java.util.concurrent.atomic.LongAdder; import java.util.function.LongConsumer; import javax.annotation.Nullable; @@ -43,6 +45,8 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import io.github.jbellis.jvector.util.Accountable; +import io.github.jbellis.jvector.util.RamUsageEstimator; import io.netty.util.concurrent.FastThreadLocal; import org.apache.cassandra.db.Clustering; import org.apache.cassandra.db.DecoratedKey; @@ -57,6 +61,7 @@ import org.apache.cassandra.dht.AbstractBounds; import org.apache.cassandra.index.sai.IndexContext; import org.apache.cassandra.index.sai.analyzer.AbstractAnalyzer; +import org.apache.cassandra.index.sai.analyzer.NoOpAnalyzer; import org.apache.cassandra.index.sai.disk.format.Version; import org.apache.cassandra.index.sai.disk.v6.TermsDistribution; import org.apache.cassandra.index.sai.iterators.KeyRangeIterator; @@ -83,6 +88,8 @@ public class TrieMemoryIndex extends MemoryIndex private final InMemoryTrie data; private final PrimaryKeysReducer primaryKeysReducer; + private final Map termFrequencies; + private final Map docLengths = new HashMap<>(); private final Memtable memtable; private AbstractBounds keyBounds; @@ -111,6 +118,48 @@ public TrieMemoryIndex(IndexContext indexContext, Memtable memtable, AbstractBou this.data = InMemoryTrie.longLived(TypeUtil.BYTE_COMPARABLE_VERSION, TrieMemtable.BUFFER_TYPE, indexContext.columnFamilyStore().readOrdering()); this.primaryKeysReducer = new PrimaryKeysReducer(); this.memtable = memtable; + termFrequencies = new ConcurrentHashMap<>(); + } + + public synchronized Map getDocLengths() + { + return docLengths; + } + + private static class PkWithTerm implements Accountable + { + private final PrimaryKey pk; + private final ByteComparable term; + + private PkWithTerm(PrimaryKey pk, ByteComparable term) + { + this.pk = pk; + this.term = term; + } + + @Override + public long ramBytesUsed() + { + return RamUsageEstimator.NUM_BYTES_OBJECT_HEADER + + 2L * RamUsageEstimator.NUM_BYTES_OBJECT_REF + + pk.ramBytesUsed() + + ByteComparable.length(term, TypeUtil.BYTE_COMPARABLE_VERSION); + } + + @Override + public int hashCode() + { + return Objects.hash(pk, ByteComparable.length(term, TypeUtil.BYTE_COMPARABLE_VERSION)); + } + + @Override + public boolean equals(Object o) + { + if (o == null || getClass() != o.getClass()) return false; + PkWithTerm that = (PkWithTerm) o; + return Objects.equals(pk, that.pk) + && ByteComparable.compare(term, that.term, TypeUtil.BYTE_COMPARABLE_VERSION) == 0; + } } public synchronized void add(DecoratedKey key, @@ -129,12 +178,33 @@ public synchronized void add(DecoratedKey key, final long initialSizeOffHeap = 
data.usedSizeOffHeap(); final long reducerHeapSize = primaryKeysReducer.heapAllocations(); + if (docLengths.containsKey(primaryKey) && !(analyzer instanceof NoOpAnalyzer)) + { + AtomicLong heapReclaimed = new AtomicLong(); + // we're overwriting an existing cell, clear out the old term counts + for (Map.Entry entry : data.entrySet()) + { + var termInTrie = entry.getKey(); + entry.getValue().forEach(pkInTrie -> { + if (pkInTrie.equals(primaryKey)) + { + var t = new PkWithTerm(pkInTrie, termInTrie); + termFrequencies.remove(t); + heapReclaimed.addAndGet(RamUsageEstimator.HASHTABLE_RAM_BYTES_PER_ENTRY + t.ramBytesUsed() + Integer.BYTES); + } + }); + } + } + + int tokenCount = 0; while (analyzer.hasNext()) { final ByteBuffer term = analyzer.next(); if (!indexContext.validateMaxTermSize(key, term)) continue; + tokenCount++; + // Note that this term is already encoded once by the TypeUtil.encode call above. setMinMaxTerm(term.duplicate()); @@ -142,7 +212,25 @@ public synchronized void add(DecoratedKey key, try { - data.putSingleton(encodedTerm, primaryKey, primaryKeysReducer, term.limit() <= MAX_RECURSIVE_KEY_LENGTH); + data.putSingleton(encodedTerm, primaryKey, (existing, update) -> { + // First do the normal primary keys reduction + PrimaryKeys result = primaryKeysReducer.apply(existing, update); + if (analyzer instanceof NoOpAnalyzer) + return result; + + // Then update term frequency + var pkbc = new PkWithTerm(update, encodedTerm); + termFrequencies.compute(pkbc, (k, oldValue) -> { + if (oldValue == null) { + // New key added, track heap allocation + onHeapAllocationsTracker.accept(RamUsageEstimator.HASHTABLE_RAM_BYTES_PER_ENTRY + k.ramBytesUsed() + Integer.BYTES); + return 1; + } + return oldValue + 1; + }); + + return result; + }, term.limit() <= MAX_RECURSIVE_KEY_LENGTH); } catch (TrieSpaceExhaustedException e) { @@ -150,6 +238,14 @@ public synchronized void add(DecoratedKey key, } } + docLengths.put(primaryKey, tokenCount); + // heap used for term frequencies and doc lengths + long heapUsed = RamUsageEstimator.HASHTABLE_RAM_BYTES_PER_ENTRY + + primaryKey.ramBytesUsed() + + Integer.BYTES; + onHeapAllocationsTracker.accept(heapUsed); + + // memory used by the trie onHeapAllocationsTracker.accept((data.usedSizeOnHeap() - initialSizeOnHeap) + (primaryKeysReducer.heapAllocations() - reducerHeapSize)); offHeapAllocationsTracker.accept(data.usedSizeOffHeap() - initialSizeOffHeap); @@ -161,10 +257,10 @@ public synchronized void add(DecoratedKey key, } @Override - public Iterator> iterator() + public Iterator>> iterator() { Iterator> iterator = data.entrySet().iterator(); - return new Iterator>() + return new Iterator<>() { @Override public boolean hasNext() @@ -173,10 +269,17 @@ public boolean hasNext() } @Override - public Pair next() + public Pair> next() { Map.Entry entry = iterator.next(); - return Pair.create(entry.getKey(), entry.getValue()); + var pairs = new ArrayList(entry.getValue().size()); + for (PrimaryKey pk : entry.getValue().keys()) + { + var frequencyRaw = termFrequencies.get(new PkWithTerm(pk, entry.getKey())); + int frequency = frequencyRaw == null ? 
1 : frequencyRaw; + pairs.add(new PkWithFrequency(pk, frequency)); + } + return Pair.create(entry.getKey(), pairs); } }; } @@ -418,7 +521,6 @@ private ByteComparable asByteComparable(ByteBuffer input) return Version.latest().onDiskFormat().encodeForTrie(input, indexContext.getValidator()); } - class PrimaryKeysReducer implements InMemoryTrie.UpsertTransformer { private final LongAdder heapAllocations = new LongAdder(); @@ -662,6 +764,13 @@ public void close() throws IOException } } + /** + * Iterator that provides ordered access to all indexed terms and their associated primary keys + * in the TrieMemoryIndex. For each term in the index, yields PrimaryKeyWithSortKey objects that + * combine a primary key with its associated term. + *
<p>
+ * A more verbose name could be KeysMatchingTermsByTermIterator. + */ private class AllTermsIterator extends AbstractIterator { private final Iterator> iterator; diff --git a/src/java/org/apache/cassandra/index/sai/memory/TrieMemtableIndex.java b/src/java/org/apache/cassandra/index/sai/memory/TrieMemtableIndex.java index f1cb9f5f47e7..dd440a4c42aa 100644 --- a/src/java/org/apache/cassandra/index/sai/memory/TrieMemtableIndex.java +++ b/src/java/org/apache/cassandra/index/sai/memory/TrieMemtableIndex.java @@ -22,6 +22,7 @@ import java.util.ArrayList; import java.util.Arrays; import java.util.Comparator; +import java.util.HashMap; import java.util.Iterator; import java.util.List; import java.util.Objects; @@ -30,28 +31,37 @@ import com.google.common.annotations.VisibleForTesting; import com.google.common.base.Preconditions; +import com.google.common.collect.Iterators; import com.google.common.util.concurrent.Runnables; +import org.apache.cassandra.cql3.Operator; import org.apache.cassandra.db.Clustering; +import org.apache.cassandra.db.DataRange; import org.apache.cassandra.db.DecoratedKey; import org.apache.cassandra.db.PartitionPosition; +import org.apache.cassandra.db.RegularAndStaticColumns; +import org.apache.cassandra.db.filter.ColumnFilter; import org.apache.cassandra.db.marshal.AbstractType; import org.apache.cassandra.db.memtable.Memtable; import org.apache.cassandra.db.memtable.ShardBoundaries; import org.apache.cassandra.db.memtable.TrieMemtable; +import org.apache.cassandra.db.rows.Row; import org.apache.cassandra.dht.AbstractBounds; +import org.apache.cassandra.dht.Bounds; import org.apache.cassandra.index.sai.IndexContext; import org.apache.cassandra.index.sai.QueryContext; import org.apache.cassandra.index.sai.disk.format.Version; -import org.apache.cassandra.index.sai.iterators.KeyRangeLazyIterator; import org.apache.cassandra.index.sai.iterators.KeyRangeConcatIterator; +import org.apache.cassandra.index.sai.iterators.KeyRangeIntersectionIterator; import org.apache.cassandra.index.sai.iterators.KeyRangeIterator; +import org.apache.cassandra.index.sai.iterators.KeyRangeLazyIterator; +import org.apache.cassandra.index.sai.memory.MemoryIndex.PkWithFrequency; import org.apache.cassandra.index.sai.plan.Expression; import org.apache.cassandra.index.sai.plan.Orderer; +import org.apache.cassandra.index.sai.utils.BM25Utils; import org.apache.cassandra.index.sai.utils.PrimaryKey; import org.apache.cassandra.index.sai.utils.PrimaryKeyWithByteComparable; import org.apache.cassandra.index.sai.utils.PrimaryKeyWithSortKey; -import org.apache.cassandra.index.sai.utils.PrimaryKeys; import org.apache.cassandra.index.sai.utils.TypeUtil; import org.apache.cassandra.sensors.Context; import org.apache.cassandra.sensors.RequestSensors; @@ -63,7 +73,6 @@ import org.apache.cassandra.utils.Reducer; import org.apache.cassandra.utils.SortingIterator; import org.apache.cassandra.utils.bytecomparable.ByteComparable; -import org.apache.cassandra.utils.bytecomparable.ByteSource; import org.apache.cassandra.utils.concurrent.OpOrder; public class TrieMemtableIndex implements MemtableIndex @@ -236,15 +245,46 @@ public List> orderBy(QueryContext query int startShard = boundaries.getShardForToken(keyRange.left.getToken()); int endShard = keyRange.right.isMinimum() ? 
boundaries.shardCount() - 1 : boundaries.getShardForToken(keyRange.right.getToken()); - var iterators = new ArrayList>(endShard - startShard + 1); - - for (int shard = startShard; shard <= endShard; ++shard) + if (!orderer.isBM25()) { - assert rangeIndexes[shard] != null; - iterators.add(rangeIndexes[shard].orderBy(orderer, slice)); + var iterators = new ArrayList>(endShard - startShard + 1); + for (int shard = startShard; shard <= endShard; ++shard) + { + assert rangeIndexes[shard] != null; + iterators.add(rangeIndexes[shard].orderBy(orderer, slice)); + } + return iterators; } - return iterators; + // BM25 + var queryTerms = orderer.getQueryTerms(); + + // Intersect iterators to find documents containing all terms + var termIterators = keyIteratorsPerTerm(queryContext, keyRange, queryTerms); + var intersectedIterator = KeyRangeIntersectionIterator.builder(termIterators).build(); + + // Compute BM25 scores + var docStats = computeDocumentFrequencies(queryContext, queryTerms); + var analyzer = indexContext.getAnalyzerFactory().create(); + var it = Iterators.transform(intersectedIterator, pk -> BM25Utils.DocTF.createFromDocument(pk, getCellForKey(pk), analyzer, queryTerms)); + return List.of(BM25Utils.computeScores(CloseableIterator.wrap(it), + queryTerms, + docStats, + indexContext, + memtable)); + } + + private List keyIteratorsPerTerm(QueryContext queryContext, AbstractBounds keyRange, List queryTerms) + { + List termIterators = new ArrayList<>(queryTerms.size()); + for (ByteBuffer term : queryTerms) + { + Expression expr = new Expression(indexContext); + expr.add(Operator.ANALYZER_MATCHES, term); + KeyRangeIterator iterator = search(queryContext, expr, keyRange, Integer.MAX_VALUE); + termIterators.add(iterator); + } + return termIterators; } @Override @@ -256,38 +296,105 @@ public long estimateMatchingRowsCount(Expression expression, AbstractBounds orderResultsBy(QueryContext context, List keys, Orderer orderer, int limit) + public CloseableIterator orderResultsBy(QueryContext queryContext, List keys, Orderer orderer, int limit) { if (keys.isEmpty()) return CloseableIterator.emptyIterator(); - return SortingIterator.createCloseable( - orderer.getComparator(), - keys, - key -> + + if (!orderer.isBM25()) + { + return SortingIterator.createCloseable( + orderer.getComparator(), + keys, + key -> + { + var partition = memtable.getPartition(key.partitionKey()); + if (partition == null) + return null; + var row = partition.getRow(key.clustering()); + if (row == null) + return null; + var cell = row.getCell(indexContext.getDefinition()); + if (cell == null) + return null; + + // We do two kinds of encoding... it'd be great to make this more straight forward, but this is what + // we have for now. I leave it to the reader to inspect the two methods to see the nuanced differences. 
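The BM25 branch above builds one ANALYZER_MATCHES iterator per query term (keyIteratorsPerTerm) and intersects them, so only rows containing every term are scored. The intersection is the classic sorted-postings merge; the self-contained sketch below shows the idea under simplified assumptions. The real KeyRangeIntersectionIterator additionally handles key ranges and lazy skipping, which this toy version omits.

    import java.util.ArrayList;
    import java.util.Iterator;
    import java.util.List;
    import com.google.common.collect.Iterators;
    import com.google.common.collect.PeekingIterator;

    // Toy version of sorted-postings intersection: emit a key only when every
    // per-term iterator produces it.
    static <T extends Comparable<T>> List<T> intersectSorted(List<Iterator<T>> inputs)
    {
        List<T> result = new ArrayList<>();
        if (inputs.isEmpty())
            return result;
        List<PeekingIterator<T>> iters = new ArrayList<>();
        for (Iterator<T> in : inputs)
            iters.add(Iterators.peekingIterator(in));

        outer:
        while (true)
        {
            // the candidate is the largest current head; all others must catch up to it
            T candidate = null;
            for (PeekingIterator<T> it : iters)
            {
                if (!it.hasNext())
                    break outer;
                if (candidate == null || it.peek().compareTo(candidate) > 0)
                    candidate = it.peek();
            }
            boolean everyoneMatches = true;
            for (PeekingIterator<T> it : iters)
            {
                while (it.hasNext() && it.peek().compareTo(candidate) < 0)
                    it.next(); // skip keys that cannot be in the intersection
                if (!it.hasNext())
                    break outer;
                if (it.peek().compareTo(candidate) != 0)
                    everyoneMatches = false;
            }
            if (everyoneMatches)
            {
                result.add(candidate);
                for (PeekingIterator<T> it : iters)
                    it.next();
            }
        }
        return result;
    }

Being able to skip each iterator ahead to the current candidate is what keeps the AND semantics cheap relative to scoring every row that contains any single term.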
+ var encoding = encode(TypeUtil.encode(cell.buffer(), validator)); + return new PrimaryKeyWithByteComparable(indexContext, memtable, key, encoding); + }, + Runnables.doNothing() + ); + } + + // BM25 + var analyzer = indexContext.getAnalyzerFactory().create(); + var queryTerms = orderer.getQueryTerms(); + var docStats = computeDocumentFrequencies(queryContext, queryTerms); + var it = keys.stream().map(pk -> BM25Utils.DocTF.createFromDocument(pk, getCellForKey(pk), analyzer, queryTerms)).iterator(); + return BM25Utils.computeScores(CloseableIterator.wrap(it), + queryTerms, + docStats, + indexContext, + memtable); + } + + /** + * Count document frequencies for each term using brute force + */ + private BM25Utils.DocStats computeDocumentFrequencies(QueryContext queryContext, List queryTerms) + { + var termIterators = keyIteratorsPerTerm(queryContext, Bounds.unbounded(indexContext.getPartitioner()), queryTerms); + var documentFrequencies = new HashMap(); + for (int i = 0; i < queryTerms.size(); i++) + { + // KeyRangeIterator.getMaxKeys is not accurate enough, we have to count them + long keys = 0; + for (var it = termIterators.get(i); it.hasNext(); it.next()) + keys++; + documentFrequencies.put(queryTerms.get(i), keys); + } + long docCount = 0; + + // count all documents in the queried column + try (var it = memtable.makePartitionIterator(ColumnFilter.selection(RegularAndStaticColumns.of(indexContext.getDefinition())), + DataRange.allData(memtable.metadata().partitioner))) + { + while (it.hasNext()) { - var partition = memtable.getPartition(key.partitionKey()); - if (partition == null) - return null; - var row = partition.getRow(key.clustering()); - if (row == null) - return null; - var cell = row.getCell(indexContext.getDefinition()); - if (cell == null) - return null; - - // We do two kinds of encoding... it'd be great to make this more straight forward, but this is what - // we have for now. I leave it to the reader to inspect the two methods to see the nuanced differences. - var encoding = encode(TypeUtil.encode(cell.buffer(), validator)); - return new PrimaryKeyWithByteComparable(indexContext, memtable, key, encoding); - }, - Runnables.doNothing() - ); + var partitions = it.next(); + while (partitions.hasNext()) + { + var unfiltered = partitions.next(); + if (!unfiltered.isRow()) + continue; + var row = (Row) unfiltered; + var cell = row.getCell(indexContext.getDefinition()); + if (cell == null) + continue; + + docCount++; + } + } + } + return new BM25Utils.DocStats(documentFrequencies, docCount); + } + + @Nullable + private org.apache.cassandra.db.rows.Cell getCellForKey(PrimaryKey key) + { + var partition = memtable.getPartition(key.partitionKey()); + if (partition == null) + return null; + var row = partition.getRow(key.clustering()); + if (row == null) + return null; + return row.getCell(indexContext.getDefinition()); } private ByteComparable encode(ByteBuffer input) { - return indexContext.isLiteral() ? v -> ByteSource.preencoded(input) - : v -> TypeUtil.asComparableBytes(input, indexContext.getValidator(), v); + return Version.latest().onDiskFormat().encodeForTrie(input, indexContext.getValidator()); } /** @@ -302,26 +409,34 @@ private ByteComparable encode(ByteBuffer input) * @return iterator of indexed term to primary keys mapping in sorted by indexed term and primary key. */ @Override - public Iterator>> iterator(DecoratedKey min, DecoratedKey max) + public Iterator>> iterator(DecoratedKey min, DecoratedKey max) { int minSubrange = min == null ? 
0 : boundaries.getShardForKey(min); int maxSubrange = max == null ? rangeIndexes.length - 1 : boundaries.getShardForKey(max); - List>> rangeIterators = new ArrayList<>(maxSubrange - minSubrange + 1); + List>>> rangeIterators = new ArrayList<>(maxSubrange - minSubrange + 1); for (int i = minSubrange; i <= maxSubrange; i++) rangeIterators.add(rangeIndexes[i].iterator()); - return MergeIterator.get(rangeIterators, (o1, o2) -> ByteComparable.compare(o1.left, o2.left, TypeUtil.BYTE_COMPARABLE_VERSION), + return MergeIterator.get(rangeIterators, + (o1, o2) -> ByteComparable.compare(o1.left, o2.left, TypeUtil.BYTE_COMPARABLE_VERSION), new PrimaryKeysMergeReducer(rangeIterators.size())); } - // The PrimaryKeysMergeReducer receives the range iterators from each of the range indexes selected based on the - // min and max keys passed to the iterator method. It doesn't strictly do any reduction because the terms in each - // range index are unique. It will receive at most one range index entry per selected range index before getReduced - // is called. - private static class PrimaryKeysMergeReducer extends Reducer, Pair>> + /** + * Used to merge sorted primary keys from multiple TrieMemoryIndex shards for a given indexed term. + * For each term that appears in multiple shards, the reducer: + * 1. Receives exactly one call to reduce() per shard containing that term + * 2. Merges all the primary keys for that term via getReduced() + * 3. Resets state via onKeyChange() before processing the next term + *
<p>
+ * While this follows the Reducer pattern, its "reduction" operation is a simple merge since each term + * appears at most once per shard, and each key will only be found in a given shard, so there are no values to aggregate; + * we simply combine and sort the primary keys from each shard that contains the term. + */ + private static class PrimaryKeysMergeReducer extends Reducer>, Pair>> { - private final Pair[] rangeIndexEntriesToMerge; + private final Pair>[] rangeIndexEntriesToMerge; private final Comparator comparator; private ByteComparable term; @@ -337,7 +452,7 @@ private static class PrimaryKeysMergeReducer extends Reducer termPair) + public void reduce(int index, Pair> termPair) { Preconditions.checkArgument(rangeIndexEntriesToMerge[index] == null, "Terms should be unique in the memory index"); @@ -348,17 +463,17 @@ public void reduce(int index, Pair termPair) @Override // Return a merger of the term keys for the term. - public Pair> getReduced() + public Pair> getReduced() { Preconditions.checkArgument(term != null, "The term must exist in the memory index"); - List> keyIterators = new ArrayList<>(rangeIndexEntriesToMerge.length); - for (Pair p : rangeIndexEntriesToMerge) - if (p != null && p.right != null && !p.right.isEmpty()) - keyIterators.add(p.right.iterator()); + var merged = new ArrayList(); + for (var p : rangeIndexEntriesToMerge) + if (p != null && p.right != null) + merged.addAll(p.right); - Iterator primaryKeys = MergeIterator.get(keyIterators, comparator, Reducer.getIdentity()); - return Pair.create(term, primaryKeys); + merged.sort((o1, o2) -> comparator.compare(o1.pk, o2.pk)); + return Pair.create(term, merged); } @Override diff --git a/src/java/org/apache/cassandra/index/sai/plan/Expression.java b/src/java/org/apache/cassandra/index/sai/plan/Expression.java index a1bd9acddc4b..aac8e240829b 100644 --- a/src/java/org/apache/cassandra/index/sai/plan/Expression.java +++ b/src/java/org/apache/cassandra/index/sai/plan/Expression.java @@ -101,6 +101,7 @@ public static Op valueOf(Operator operator) return IN; case ANN: + case BM25: case ORDER_BY_ASC: case ORDER_BY_DESC: return ORDER_BY; @@ -250,6 +251,7 @@ public Expression add(Operator op, ByteBuffer value) boundedAnnEuclideanDistanceThreshold = GeoUtil.amplifiedEuclideanSimilarityThreshold(lower.value.vector, searchRadiusMeters); break; case ANN: + case BM25: case ORDER_BY_ASC: case ORDER_BY_DESC: // If we alread have an operation on the column, we don't need to set the ORDER_BY op because diff --git a/src/java/org/apache/cassandra/index/sai/plan/Operation.java b/src/java/org/apache/cassandra/index/sai/plan/Operation.java index 7722ada87463..8968d4390f5a 100644 --- a/src/java/org/apache/cassandra/index/sai/plan/Operation.java +++ b/src/java/org/apache/cassandra/index/sai/plan/Operation.java @@ -93,7 +93,7 @@ protected static ListMultimap analyzeGroup(QueryCont analyzer.reset(e.getIndexValue()); // EQ/LIKE_*/NOT_EQ can have multiple expressions e.g. 
text = "Hello World", - // becomes text = "Hello" OR text = "World" because "space" is always interpreted as a split point (by analyzer), + // becomes text = "Hello" AND text = "World" because "space" is always interpreted as a split point (by analyzer), // CONTAINS/CONTAINS_KEY are always treated as multiple expressions since they currently only targetting // collections, NOT_EQ is made an independent expression only in case of pre-existing multiple EQ expressions, or // if there is no EQ operations and NOT_EQ is met or a single NOT_EQ expression present, @@ -102,6 +102,7 @@ protected static ListMultimap analyzeGroup(QueryCont boolean isMultiExpression = columnIsMultiExpression.getOrDefault(e.column(), Boolean.FALSE); switch (e.operator()) { + // case BM25: leave it at the default of `false` case EQ: // EQ operator will always be a multiple expression because it is being used by map entries isMultiExpression = indexContext.isNonFrozenCollection(); diff --git a/src/java/org/apache/cassandra/index/sai/plan/Orderer.java b/src/java/org/apache/cassandra/index/sai/plan/Orderer.java index e23c83ed91e2..af202bae5f5b 100644 --- a/src/java/org/apache/cassandra/index/sai/plan/Orderer.java +++ b/src/java/org/apache/cassandra/index/sai/plan/Orderer.java @@ -19,9 +19,12 @@ package org.apache.cassandra.index.sai.plan; import java.nio.ByteBuffer; +import java.util.ArrayList; import java.util.Arrays; import java.util.Comparator; import java.util.EnumSet; +import java.util.HashSet; +import java.util.List; import java.util.stream.Collectors; import javax.annotation.Nullable; @@ -41,12 +44,15 @@ public class Orderer { // The list of operators that are valid for order by clauses. static final EnumSet ORDER_BY_OPERATORS = EnumSet.of(Operator.ANN, + Operator.BM25, Operator.ORDER_BY_ASC, Operator.ORDER_BY_DESC); public final IndexContext context; public final Operator operator; - public final float[] vector; + public final ByteBuffer term; + private float[] vector; + private List queryTerms; /** * Create an orderer for the given index context, operator, and term. @@ -59,7 +65,7 @@ public Orderer(IndexContext context, Operator operator, ByteBuffer term) this.context = context; assert ORDER_BY_OPERATORS.contains(operator) : "Invalid operator for order by clause " + operator; this.operator = operator; - this.vector = context.getValidator().isVector() ? TypeUtil.decomposeVector(context.getValidator(), term) : null; + this.term = term; } public String getIndexName() @@ -75,8 +81,8 @@ public boolean isAscending() public Comparator getComparator() { - // ANN's PrimaryKeyWithSortKey is always descending, so we use the natural order for the priority queue - return isAscending() || isANN() ? Comparator.naturalOrder() : Comparator.reverseOrder(); + // ANN/BM25's PrimaryKeyWithSortKey is always descending, so we use the natural order for the priority queue + return (isAscending() || isANN() || isBM25()) ? Comparator.naturalOrder() : Comparator.reverseOrder(); } public boolean isLiteral() @@ -89,6 +95,11 @@ public boolean isANN() return operator == Operator.ANN; } + public boolean isBM25() + { + return operator == Operator.BM25; + } + @Nullable public static Orderer from(SecondaryIndexManager indexManager, RowFilter filter) { @@ -110,8 +121,38 @@ public static boolean isFilterExpressionOrderer(RowFilter.Expression expression) public String toString() { String direction = isAscending() ? "ASC" : "DESC"; - return isANN() - ? 
context.getColumnName() + " ANN OF " + Arrays.toString(vector) + ' ' + direction - : context.getColumnName() + ' ' + direction; + if (isANN()) + return context.getColumnName() + " ANN OF " + Arrays.toString(getVectorTerm()) + ' ' + direction; + if (isBM25()) + return context.getColumnName() + " BM25 OF " + TypeUtil.getString(term, context.getValidator()) + ' ' + direction; + return context.getColumnName() + ' ' + direction; + } + + public float[] getVectorTerm() + { + if (vector == null) + vector = TypeUtil.decomposeVector(context.getValidator(), term); + return vector; + } + + public List getQueryTerms() + { + if (queryTerms != null) + return queryTerms; + + var queryAnalyzer = context.getQueryAnalyzerFactory().create(); + // Split query into terms + var uniqueTerms = new HashSet(); + queryAnalyzer.reset(term); + try + { + queryAnalyzer.forEachRemaining(uniqueTerms::add); + } + finally + { + queryAnalyzer.end(); + } + queryTerms = new ArrayList<>(uniqueTerms); + return queryTerms; } } diff --git a/src/java/org/apache/cassandra/index/sai/plan/Plan.java b/src/java/org/apache/cassandra/index/sai/plan/Plan.java index e40906c73d52..d82b92905403 100644 --- a/src/java/org/apache/cassandra/index/sai/plan/Plan.java +++ b/src/java/org/apache/cassandra/index/sai/plan/Plan.java @@ -1257,10 +1257,12 @@ protected double estimateSelectivity() @Override protected KeysIterationCost estimateCost() { - return ordering.isANN() - ? estimateAnnSortCost() - : estimateGlobalSortCost(); - + if (ordering.isANN()) + return estimateAnnSortCost(); + else if (ordering.isBM25()) + return estimateBm25SortCost(); + else + return estimateGlobalSortCost(); } private KeysIterationCost estimateAnnSortCost() @@ -1277,6 +1279,21 @@ private KeysIterationCost estimateAnnSortCost() return new KeysIterationCost(expectedKeys, initCost, searchCost); } + private KeysIterationCost estimateBm25SortCost() + { + double expectedKeys = access.expectedAccessCount(source.expectedKeys()); + + int termCount = ordering.getQueryTerms().size(); + // all of the cost for BM25 is up front since the index doesn't give us the information we need + // to return results in order, in isolation. The big cost is reading the indexed cells out of + // the sstables. + // VSTODO if we had stats on cell size _per column_ we could usefully include ROW_BYTE_COST + double initCost = source.fullCost() + + source.expectedKeys() * (hrs(ROW_CELL_COST) + ROW_CELL_COST) + + termCount * BM25_SCORE_COST; + return new KeysIterationCost(expectedKeys, initCost, 0); + } + private KeysIterationCost estimateGlobalSortCost() { return new KeysIterationCost(source.expectedKeys(), @@ -1310,19 +1327,51 @@ protected KeysSort withAccess(Access access) } /** - * Returns all keys in ANN order. - * Contrary to {@link KeysSort}, there is no input node here and the output is generated lazily. + * Base class for index scans that return results in a computed order (ANN, BM25) + * rather than the natural index order. 
*/ - final static class AnnIndexScan extends Leaf + abstract static class ScoredIndexScan extends Leaf { final Orderer ordering; - AnnIndexScan(Factory factory, int id, Access access, Orderer ordering) + protected ScoredIndexScan(Factory factory, int id, Access access, Orderer ordering) { super(factory, id, access); this.ordering = ordering; } + @Nullable + @Override + protected Orderer ordering() + { + return ordering; + } + + @Override + protected double estimateSelectivity() + { + return 1.0; + } + + @Override + protected Iterator execute(Executor executor) + { + int softLimit = max(1, round((float) access.expectedAccessCount(factory.tableMetrics.rows))); + return executor.getTopKRows((Expression) null, softLimit); + } + } + + /** + * Returns all keys in ANN order. + * Contrary to {@link KeysSort}, there is no input node here and the output is generated lazily. + */ + final static class AnnIndexScan extends ScoredIndexScan + { + protected AnnIndexScan(Factory factory, int id, Access access, Orderer ordering) + { + super(factory, id, access, ordering); + } + @Override protected KeysIterationCost estimateCost() { @@ -1335,32 +1384,53 @@ protected KeysIterationCost estimateCost() return new KeysIterationCost(expectedKeys, initCost, searchCost); } - @Nullable @Override - protected Orderer ordering() + protected KeysIteration withAccess(Access access) { - return ordering; + return Objects.equals(access, this.access) + ? this + : new AnnIndexScan(factory, id, access, ordering); } + @Nullable @Override - protected Iterator execute(Executor executor) + protected IndexContext getIndexContext() { - int softLimit = max(1, round((float) access.expectedAccessCount(factory.tableMetrics.rows))); - return executor.getTopKRows((Expression) null, softLimit); + return ordering.context; } + } + /** + * Returns all keys in BM25 order. + * Like AnnIndexScan, this generates results lazily without an input node. + */ + final static class Bm25IndexScan extends ScoredIndexScan + { + protected Bm25IndexScan(Factory factory, int id, Access access, Orderer ordering) + { + super(factory, id, access, ordering); + } + + @Nonnull @Override - protected KeysIteration withAccess(Access access) + protected KeysIterationCost estimateCost() { - return Objects.equals(access, this.access) - ? this - : new AnnIndexScan(factory, id, access, ordering); + double expectedKeys = access.expectedAccessCount(factory.tableMetrics.rows); + int expectedKeysInt = Math.max(1, (int) Math.ceil(expectedKeys)); + + int termCount = ordering.getQueryTerms().size(); + double initCost = expectedKeysInt * (hrs(ROW_CELL_COST) + ROW_CELL_COST) + + termCount * BM25_SCORE_COST; + + return new KeysIterationCost(expectedKeys, initCost, 0); } @Override - protected double estimateSelectivity() + protected KeysIteration withAccess(Access access) { - return 1.0; + return Objects.equals(access, this.access) + ? 
this + : new Bm25IndexScan(factory, id, access, ordering); } @Override @@ -1684,6 +1754,8 @@ private KeysIteration indexScan(Expression predicate, long matchingKeysCount, Or if (ordering != null) if (ordering.isANN()) return new AnnIndexScan(this, id, defaultAccess, ordering); + else if (ordering.isBM25()) + return new Bm25IndexScan(this, id, defaultAccess, ordering); else if (ordering.isLiteral()) return new LiteralIndexScan(this, id, predicate, matchingKeysCount, defaultAccess, ordering); else @@ -1938,6 +2010,9 @@ public static class CostCoefficients /** Additional cost added to row fetch cost per each serialized byte of the row */ public final static double ROW_BYTE_COST = 0.005; + + /** Cost to perform BM25 scoring, per query term */ + public final static double BM25_SCORE_COST = 0.5; } /** Convenience builder for building intersection and union nodes */ diff --git a/src/java/org/apache/cassandra/index/sai/plan/QueryController.java b/src/java/org/apache/cassandra/index/sai/plan/QueryController.java index 15b0965fe82e..992b5f1334ab 100644 --- a/src/java/org/apache/cassandra/index/sai/plan/QueryController.java +++ b/src/java/org/apache/cassandra/index/sai/plan/QueryController.java @@ -195,6 +195,11 @@ public TableMetadata metadata() return command.metadata(); } + public ReadCommand command() + { + return command; + } + RowFilter.FilterElement filterOperation() { // NOTE: we cannot remove the order by filter expression here yet because it is used in the FilterTree class @@ -883,6 +888,7 @@ private long estimateMatchingRowCount(Expression predicate) switch (predicate.getOp()) { case EQ: + case MATCH: case CONTAINS_KEY: case CONTAINS_VALUE: case NOT_EQ: diff --git a/src/java/org/apache/cassandra/index/sai/plan/StorageAttachedIndexQueryPlan.java b/src/java/org/apache/cassandra/index/sai/plan/StorageAttachedIndexQueryPlan.java index 22e564365bea..39a66cbad9de 100644 --- a/src/java/org/apache/cassandra/index/sai/plan/StorageAttachedIndexQueryPlan.java +++ b/src/java/org/apache/cassandra/index/sai/plan/StorageAttachedIndexQueryPlan.java @@ -216,7 +216,7 @@ public Function postProcessor(ReadCommand return partitions -> partitions; // in case of top-k query, filter out rows that are not actually global top-K - return partitions -> new TopKProcessor(command).filter(partitions); + return partitions -> new TopKProcessor(command).reorder(partitions); } /** diff --git a/src/java/org/apache/cassandra/index/sai/plan/StorageAttachedIndexSearcher.java b/src/java/org/apache/cassandra/index/sai/plan/StorageAttachedIndexSearcher.java index 2f0cabd05c2c..f5ebe85a1ed9 100644 --- a/src/java/org/apache/cassandra/index/sai/plan/StorageAttachedIndexSearcher.java +++ b/src/java/org/apache/cassandra/index/sai/plan/StorageAttachedIndexSearcher.java @@ -30,12 +30,10 @@ import java.util.Queue; import java.util.function.Supplier; import java.util.stream.Collectors; - import javax.annotation.Nonnull; import javax.annotation.Nullable; import com.google.common.base.Preconditions; - import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -45,8 +43,13 @@ import org.apache.cassandra.db.PartitionPosition; import org.apache.cassandra.db.ReadCommand; import org.apache.cassandra.db.ReadExecutionController; +import org.apache.cassandra.db.filter.ColumnFilter; +import org.apache.cassandra.db.marshal.FloatType; import org.apache.cassandra.db.partitions.UnfilteredPartitionIterator; import org.apache.cassandra.db.rows.AbstractUnfilteredRowIterator; +import org.apache.cassandra.db.rows.BTreeRow; +import 
org.apache.cassandra.db.rows.BufferCell; +import org.apache.cassandra.db.rows.ColumnData; import org.apache.cassandra.db.rows.Row; import org.apache.cassandra.db.rows.Unfiltered; import org.apache.cassandra.db.rows.UnfilteredRowIterator; @@ -60,13 +63,16 @@ import org.apache.cassandra.index.sai.iterators.KeyRangeIterator; import org.apache.cassandra.index.sai.metrics.TableQueryMetrics; import org.apache.cassandra.index.sai.utils.PrimaryKey; -import org.apache.cassandra.index.sai.utils.RangeUtil; +import org.apache.cassandra.index.sai.utils.PrimaryKeyWithScore; import org.apache.cassandra.index.sai.utils.PrimaryKeyWithSortKey; +import org.apache.cassandra.index.sai.utils.RangeUtil; import org.apache.cassandra.io.util.FileUtils; +import org.apache.cassandra.schema.ColumnMetadata; import org.apache.cassandra.schema.TableMetadata; import org.apache.cassandra.utils.AbstractIterator; import org.apache.cassandra.utils.CloseableIterator; import org.apache.cassandra.utils.FBUtilities; +import org.apache.cassandra.utils.btree.BTree; public class StorageAttachedIndexSearcher implements Index.Searcher { @@ -109,19 +115,17 @@ public UnfilteredPartitionIterator search(ReadExecutionController executionContr // Can't check for `command.isTopK()` because the planner could optimize sorting out Orderer ordering = plan.ordering(); - if (ordering != null) - { - assert !(keysIterator instanceof KeyRangeIterator); - var scoredKeysIterator = (CloseableIterator) keysIterator; - var result = new ScoreOrderedResultRetriever(scoredKeysIterator, filterTree, controller, - executionController, queryContext, command.limits().count()); - return new TopKProcessor(command).filter(result); - } - else + if (ordering == null) { assert keysIterator instanceof KeyRangeIterator; return new ResultRetriever((KeyRangeIterator) keysIterator, filterTree, controller, executionController, queryContext); } + + assert !(keysIterator instanceof KeyRangeIterator); + var scoredKeysIterator = (CloseableIterator) keysIterator; + var result = new ScoreOrderedResultRetriever(scoredKeysIterator, filterTree, controller, + executionController, queryContext, command.limits().count()); + return new TopKProcessor(command).filter(result); } catch (QueryView.Builder.MissingIndexException e) { @@ -521,48 +525,50 @@ public UnfilteredRowIterator computeNext() */ private void fillPendingRows() { + // Group PKs by source sstable/memtable + var groupedKeys = new HashMap>(); // We always want to get at least 1. 
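fillPendingRows below pulls keys in score order but issues the actual partition reads in primary-key order, grouping duplicate keys first because the same row can surface once per sstable or memtable that indexed it. A schematic of that drain/group/read cycle, with generic stand-ins for the types involved; this is a sketch of the shape of the logic, not the patch's code.

    import java.util.ArrayList;
    import java.util.HashMap;
    import java.util.Iterator;
    import java.util.List;
    import java.util.Map;
    import java.util.TreeSet;
    import java.util.function.BiConsumer;
    import java.util.function.Function;

    // Schematic of the drain/group/read cycle: consume the score-ordered stream
    // until `batchSize` unique keys are collected, then issue one read per unique
    // key in key order for cache and disk locality. K and S stand in for
    // PrimaryKey and PrimaryKeyWithSortKey respectively.
    static <K extends Comparable<K>, S> void readBatch(Iterator<S> scoreOrdered,
                                                       Function<S, K> keyOf,
                                                       int batchSize,
                                                       BiConsumer<K, List<S>> readPartition)
    {
        Map<K, List<S>> grouped = new HashMap<>();
        while (grouped.size() < batchSize && scoreOrdered.hasNext())
        {
            S scored = scoreOrdered.next();
            grouped.computeIfAbsent(keyOf.apply(scored), k -> new ArrayList<>()).add(scored);
        }
        for (K key : new TreeSet<>(grouped.keySet()))
            readPartition.accept(key, grouped.get(key));
    }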
int rowsToRetrieve = Math.max(1, softLimit - returnedRowCount); - var keys = new HashMap>(); // We want to get the first unique `rowsToRetrieve` keys to materialize // Don't pass the priority queue here because it is more efficient to add keys in bulk - fillKeys(keys, rowsToRetrieve, null); + fillKeys(groupedKeys, rowsToRetrieve, null); // Sort the primary keys by PrK order, just in case that helps with cache and disk efficiency - var primaryKeyPriorityQueue = new PriorityQueue<>(keys.keySet()); + var primaryKeyPriorityQueue = new PriorityQueue<>(groupedKeys.keySet()); - while (!keys.isEmpty()) + // drain groupedKeys into pendingRows + while (!groupedKeys.isEmpty()) { - var primaryKey = primaryKeyPriorityQueue.poll(); - var primaryKeyWithSortKeys = keys.remove(primaryKey); - var partitionIterator = readAndValidatePartition(primaryKey, primaryKeyWithSortKeys); + var pk = primaryKeyPriorityQueue.poll(); + var sourceKeys = groupedKeys.remove(pk); + var partitionIterator = readAndValidatePartition(pk, sourceKeys); if (partitionIterator != null) pendingRows.add(partitionIterator); else // The current primaryKey did not produce a partition iterator. We know the caller will need // `rowsToRetrieve` rows, so we get the next unique key and add it to the queue. - fillKeys(keys, 1, primaryKeyPriorityQueue); + fillKeys(groupedKeys, 1, primaryKeyPriorityQueue); } } /** - * Fills the keys map with the next `count` unique primary keys that are in the keys produced by calling + * Fills the `groupedKeys` Map with the next `count` unique primary keys that are in the keys produced by calling * {@link #nextSelectedKeyInRange()}. We map PrimaryKey to List because the same * primary key can be in the result set multiple times, but with different source tables. - * @param keys the map to fill + * @param groupedKeys the map to fill * @param count the number of unique PrimaryKeys to consume from the iterator * @param primaryKeyPriorityQueue the priority queue to add new keys to. If the queue is null, we do not add * keys to the queue. */ - private void fillKeys(Map> keys, int count, PriorityQueue primaryKeyPriorityQueue) + private void fillKeys(Map> groupedKeys, int count, PriorityQueue primaryKeyPriorityQueue) { - int initialSize = keys.size(); - while (keys.size() - initialSize < count) + int initialSize = groupedKeys.size(); + while (groupedKeys.size() - initialSize < count) { var primaryKeyWithSortKey = nextSelectedKeyInRange(); if (primaryKeyWithSortKey == null) return; var nextPrimaryKey = primaryKeyWithSortKey.primaryKey(); - var accumulator = keys.computeIfAbsent(nextPrimaryKey, k -> new ArrayList<>()); + var accumulator = groupedKeys.computeIfAbsent(nextPrimaryKey, k -> new ArrayList<>()); if (primaryKeyPriorityQueue != null && accumulator.isEmpty()) primaryKeyPriorityQueue.add(nextPrimaryKey); accumulator.add(primaryKeyWithSortKey); @@ -602,15 +608,29 @@ private boolean isInRange(DecoratedKey key) return null; } - public UnfilteredRowIterator readAndValidatePartition(PrimaryKey key, List primaryKeys) + /** + * Reads and validates a partition for a given primary key against its sources. + *
<p>
+ * @param pk The primary key of the partition to read and validate + * @param sourceKeys A list of PrimaryKeyWithSortKey objects associated with the primary key. + * Multiple sort keys can exist for the same primary key when data comes from different + * sstables or memtables. + * + * @return An UnfilteredRowIterator containing the validated partition data, or null if: + * - The key has already been processed + * - The partition does not pass index filters + * - The partition contains no valid rows + * - The row data does not match the index metadata for any of the provided primary keys + */ + public UnfilteredRowIterator readAndValidatePartition(PrimaryKey pk, List sourceKeys) { // If we've already processed the key, we can skip it. Because the score ordered iterator does not // deduplicate rows, we could see dupes if a row is in the ordering index multiple times. This happens // in the case of dupes and of overwrites. - if (processedKeys.contains(key)) + if (processedKeys.contains(pk)) return null; - try (UnfilteredRowIterator partition = controller.getPartition(key, view, executionController)) + try (UnfilteredRowIterator partition = controller.getPartition(pk, view, executionController)) { queryContext.addPartitionsRead(1); queryContext.checkpoint(); @@ -619,7 +639,7 @@ public UnfilteredRowIterator readAndValidatePartition(PrimaryKey key, List primaryKeysWithScore, ReadCommand command) { super(partition.metadata(), partition.partitionKey(), @@ -662,7 +694,47 @@ public PrimaryKeyIterator(UnfilteredRowIterator partition, Row staticRow, Unfilt partition.isReverseOrder(), partition.stats()); - row = content; + assert !primaryKeysWithScore.isEmpty(); + var isScoredRow = primaryKeysWithScore.get(0) instanceof PrimaryKeyWithScore; + if (!content.isRow() || !isScoredRow) + { + this.row = content; + return; + } + + // When +score is added on the coordinator side, it's represented as a PrecomputedColumnFilter + // even in a 'SELECT *' because WCF is not capable of representing synthetic columns. 
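Once the synthetic score cell has been attached (the BTreeRow rebuild just below), downstream consumers can read the score straight off the row instead of recomputing similarity per replica response. A schematic of the consuming side, assuming a Row produced by this iterator and the scoreColumn metadata built via the patch's ColumnMetadata.syntheticColumn; this helper is hypothetical, not code from this change.

    import org.apache.cassandra.db.marshal.FloatType;
    import org.apache.cassandra.db.rows.Cell;
    import org.apache.cassandra.db.rows.Row;
    import org.apache.cassandra.schema.ColumnMetadata;

    // Hypothetical helper: read the injected synthetic score back from a result
    // row. Returns 0 when the row carries no score cell.
    static float scoreOf(Row row, ColumnMetadata scoreColumn)
    {
        Cell<?> cell = row.getCell(scoreColumn);
        return cell == null ? 0f : FloatType.instance.compose(cell.buffer());
    }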
+ // This can be simplified when we remove ANN_USE_SYNTHETIC_SCORE + var tm = metadata(); + var scoreColumn = ColumnMetadata.syntheticColumn(tm.keyspace, + tm.name, + ColumnMetadata.SYNTHETIC_SCORE_ID, + FloatType.instance); + var isScoreFetched = command.columnFilter().fetchesExplicitly(scoreColumn); + if (!isScoreFetched) + { + this.row = content; + return; + } + + // Clone the original Row + Row originalRow = (Row) content; + ArrayList columnData = new ArrayList<>(originalRow.columnCount() + 1); + columnData.addAll(originalRow.columnData()); + + // inject +score as a new column + var pkWithScore = (PrimaryKeyWithScore) primaryKeysWithScore.get(0); + columnData.add(BufferCell.live(scoreColumn, + FBUtilities.nowInSeconds(), + FloatType.instance.decompose(pkWithScore.indexScore))); + + this.row = BTreeRow.create(originalRow.clustering(), + originalRow.primaryKeyLivenessInfo(), + originalRow.deletion(), + BTree.builder(ColumnData.comparator) + .auto(true) + .addAll(columnData) + .build()); } @Override @@ -674,18 +746,6 @@ protected Unfiltered computeNext() return row; } } - - @Override - public TableMetadata metadata() - { - return controller.metadata(); - } - - public void close() - { - FileUtils.closeQuietly(scoredPrimaryKeyIterator); - controller.finish(); - } } private static UnfilteredRowIterator applyIndexFilter(UnfilteredRowIterator partition, FilterTree tree, QueryContext queryContext) diff --git a/src/java/org/apache/cassandra/index/sai/plan/TopKProcessor.java b/src/java/org/apache/cassandra/index/sai/plan/TopKProcessor.java index a2e4b315f74f..5760109328b7 100644 --- a/src/java/org/apache/cassandra/index/sai/plan/TopKProcessor.java +++ b/src/java/org/apache/cassandra/index/sai/plan/TopKProcessor.java @@ -27,8 +27,6 @@ import java.util.SortedSet; import java.util.TreeMap; import java.util.TreeSet; -import java.util.concurrent.CompletableFuture; -import java.util.concurrent.CompletionException; import javax.annotation.Nullable; import org.apache.commons.lang3.tuple.Triple; @@ -38,28 +36,25 @@ import io.github.jbellis.jvector.vector.VectorizationProvider; import io.github.jbellis.jvector.vector.types.VectorFloat; import io.github.jbellis.jvector.vector.types.VectorTypeSupport; -import org.apache.cassandra.concurrent.ImmediateExecutor; -import org.apache.cassandra.concurrent.LocalAwareExecutorService; -import org.apache.cassandra.concurrent.SharedExecutorPool; -import org.apache.cassandra.config.CassandraRelevantProperties; import org.apache.cassandra.cql3.Operator; +import org.apache.cassandra.cql3.statements.SelectStatement; import org.apache.cassandra.db.ColumnFamilyStore; import org.apache.cassandra.db.DecoratedKey; import org.apache.cassandra.db.Keyspace; import org.apache.cassandra.db.ReadCommand; import org.apache.cassandra.db.filter.RowFilter; -import org.apache.cassandra.db.partitions.BasePartitionIterator; -import org.apache.cassandra.db.partitions.ParallelCommandProcessor; +import org.apache.cassandra.db.marshal.FloatType; import org.apache.cassandra.db.partitions.PartitionIterator; import org.apache.cassandra.db.partitions.UnfilteredPartitionIterator; import org.apache.cassandra.db.rows.BaseRowIterator; +import org.apache.cassandra.db.rows.Cell; import org.apache.cassandra.db.rows.Row; import org.apache.cassandra.db.rows.Unfiltered; import org.apache.cassandra.index.Index; import org.apache.cassandra.index.SecondaryIndexManager; import org.apache.cassandra.index.sai.IndexContext; import org.apache.cassandra.index.sai.StorageAttachedIndex; -import 
org.apache.cassandra.index.sai.utils.AbortedOperationException; +import org.apache.cassandra.index.sai.plan.StorageAttachedIndexSearcher.ScoreOrderedResultRetriever; import org.apache.cassandra.index.sai.utils.InMemoryPartitionIterator; import org.apache.cassandra.index.sai.utils.InMemoryUnfilteredPartitionIterator; import org.apache.cassandra.index.sai.utils.PartitionInfo; @@ -72,33 +67,30 @@ import static org.apache.cassandra.cql3.statements.RequestValidations.invalidRequest; /** - * Processor applied to SAI based ORDER BY queries. This class could likely be refactored into either two filter - * methods depending on where the processing is happening or into two classes. - *
<p>
- * This processor performs the following steps on a replica: - * - collect LIMIT rows from partition iterator, making sure that all are valid. - * - return rows in Primary Key order - *
<p>
- * This processor performs the following steps on a coordinator: - * - consume all rows from the provided partition iterator and sort them according to the specified order. - * For vectors, that is similarit score and for all others, that is the ordering defined by their - * {@link org.apache.cassandra.db.marshal.AbstractType}. If there are multiple vector indexes, - * the final score is the sum of all vector index scores. - * - remove rows with the lowest scores from PQ if PQ size exceeds limit - * - return rows from PQ in primary key order to caller + * Processor applied to SAI based ORDER BY queries. + * + * * On a replica: + * * - filter(ScoreOrderedResultRetriever) is used to collect up to the top-K rows. + * * - We store any tombstones as well, to avoid losing them during coordinator reconciliation. + * * - The result is returned in PK order so that coordinator can merge from multiple replicas. + * + * On a coordinator: + * - reorder(PartitionIterator) is used to consume all rows from the provided partitions, + * compute the order based on either a column ordering or a similarity score, and keep top-K. + * - The result is returned in score/sortkey order. */ public class TopKProcessor { public static final String INDEX_MAY_HAVE_BEEN_DROPPED = "An index may have been dropped. Ordering on non-clustering " + "column requires the column to be indexed"; protected static final Logger logger = LoggerFactory.getLogger(TopKProcessor.class); - private static final LocalAwareExecutorService PARALLEL_EXECUTOR = getExecutor(); private static final VectorTypeSupport vts = VectorizationProvider.getInstance().getVectorTypeSupport(); private final ReadCommand command; private final IndexContext indexContext; private final RowFilter.Expression expression; private final VectorFloat queryVector; + private final ColumnMetadata scoreColumn; private final int limit; @@ -106,52 +98,29 @@ public TopKProcessor(ReadCommand command) { this.command = command; - Pair annIndexAndExpression = findTopKIndexContext(); + Pair indexAndExpression = findTopKIndexContext(); // this can happen in case an index was dropped after the query was initiated - if (annIndexAndExpression == null) + if (indexAndExpression == null) throw invalidRequest(INDEX_MAY_HAVE_BEEN_DROPPED); - this.indexContext = annIndexAndExpression.left; - this.expression = annIndexAndExpression.right; - if (expression.operator() == Operator.ANN) + this.indexContext = indexAndExpression.left; + this.expression = indexAndExpression.right; + if (expression.operator() == Operator.ANN && !SelectStatement.ANN_USE_SYNTHETIC_SCORE) this.queryVector = vts.createFloatVector(TypeUtil.decomposeVector(indexContext, expression.getIndexValue().duplicate())); else this.queryVector = null; this.limit = command.limits().count(); - } - - /** - * Executor to use for parallel index reads. - * Defined by -Dcassandra.index_read.parallele=true/false, true by default. - *
<p>
- * INDEX_READ uses 2 * cpus threads by default but can be overridden with -Dcassandra.index_read.parallel_thread_num= - * - * @return stage to use, default INDEX_READ - */ - private static LocalAwareExecutorService getExecutor() - { - boolean isParallel = CassandraRelevantProperties.USE_PARALLEL_INDEX_READ.getBoolean(); - - if (isParallel) - { - int numThreads = CassandraRelevantProperties.PARALLEL_INDEX_READ_NUM_THREADS.isPresent() - ? CassandraRelevantProperties.PARALLEL_INDEX_READ_NUM_THREADS.getInt() - : FBUtilities.getAvailableProcessors() * 2; - return SharedExecutorPool.SHARED.newExecutor(numThreads, maximumPoolSize -> {}, "request", "IndexParallelRead"); - } - else - return ImmediateExecutor.INSTANCE; + this.scoreColumn = ColumnMetadata.syntheticColumn(indexContext.getKeyspace(), indexContext.getTable(), ColumnMetadata.SYNTHETIC_SCORE_ID, FloatType.instance); } /** * Sort the specified filtered rows according to the {@code ORDER BY} clause and keep the first {@link #limit} rows. - * This is meant to be used on the coordinator-side to sort the rows collected from the replicas. - * Caller must close the supplied iterator. + * Called on the coordinator side. * - * @param partitions the partitions collected by the coordinator - * @return the provided rows, sorted by the requested {@code ORDER BY} chriteria and trimmed to {@link #limit} rows + * @param partitions the partitions collected by the coordinator. It will be closed as a side-effect. + * @return the provided rows, sorted and trimmed to {@link #limit} rows */ - public PartitionIterator filter(PartitionIterator partitions) + public PartitionIterator reorder(PartitionIterator partitions) { // We consume the partitions iterator and create a new one. Use a try-with-resources block to ensure the // original iterator is closed. We do not expect exceptions here, but if they happen, we want to make sure the @@ -159,17 +128,35 @@ public PartitionIterator filter(PartitionIterator partitions) try (partitions) { Comparator> comparator = comparator() - .thenComparing(Triple::getLeft, Comparator.comparing(p -> p.key)) - .thenComparing(Triple::getMiddle, command.metadata().comparator); + .thenComparing(Triple::getLeft, Comparator.comparing(pi -> pi.key)) + .thenComparing(Triple::getMiddle, command.metadata().comparator); TopKSelector> topK = new TopKSelector<>(comparator, limit); + while (partitions.hasNext()) + { + try (BaseRowIterator partitionRowIterator = partitions.next()) + { + if (expression.operator() == Operator.ANN || expression.operator() == Operator.BM25) + { + PartitionResults pr = processScoredPartition(partitionRowIterator); + topK.addAll(pr.rows); + } + else + { + while (partitionRowIterator.hasNext()) + { + Row row = (Row) partitionRowIterator.next(); + ByteBuffer value = row.getCell(expression.column()).buffer(); + topK.add(Triple.of(PartitionInfo.create(partitionRowIterator), row, value)); + } + } + } + } - processPartitions(partitions, topK, null); - + // Convert the topK results to a PartitionIterator List> sortedRows = new ArrayList<>(topK.size()); for (Triple triple : topK.getShared()) sortedRows.add(Pair.create(triple.getLeft(), triple.getMiddle())); - return InMemoryPartitionIterator.create(command, sortedRows); } } @@ -186,149 +173,53 @@ public PartitionIterator filter(PartitionIterator partitions) *
<p>
     * All tombstones will be kept. Caller must close the supplied iterator.
     *
-     * @param partitions the partitions collected in the replica side of a query
+     * @param partitions the partitions collected in the replica side of a query. It will be closed as a side-effect.
     * @return the provided rows, sorted by the requested {@code ORDER BY} criteria, trimmed to {@link #limit} rows,
     * and then sorted again by primary key.
     */
-    public UnfilteredPartitionIterator filter(UnfilteredPartitionIterator partitions)
+    public UnfilteredPartitionIterator filter(ScoreOrderedResultRetriever partitions)
    {
-        // We consume the partitions iterator and create a new one. Use a try-with-resources block to ensure the
-        // original iterator is closed. We do not expect exceptions here, but if they happen, we want to make sure the
-        // original iterator is closed to prevent leaking resources, which could compound the effect of an exception.
        try (partitions)
        {
-            TopKSelector<Triple<PartitionInfo, Row, ?>> topK = new TopKSelector<>(comparator(), limit);
-
-            TreeMap<PartitionInfo, TreeSet<Unfiltered>> unfilteredByPartition = new TreeMap<>(Comparator.comparing(p -> p.key));
-
-            processPartitions(partitions, topK, unfilteredByPartition);
-
-            // Reorder the rows by primary key.
-            for (var triple : topK.getUnsortedShared())
-                addUnfiltered(unfilteredByPartition, triple.getLeft(), triple.getMiddle());
-
-            return new InMemoryUnfilteredPartitionIterator(command, unfilteredByPartition);
-        }
-    }
-
-    private Comparator<Triple<PartitionInfo, Row, ?>> comparator()
-    {
-        Comparator<Triple<PartitionInfo, Row, ?>> comparator;
-        if (queryVector != null)
-        {
-            comparator = Comparator.comparing((Triple<PartitionInfo, Row, ?> t) -> (Float) t.getRight()).reversed();
-        }
-        else
-        {
-            comparator = Comparator.comparing(t -> (ByteBuffer) t.getRight(), indexContext.getValidator());
-            if (expression.operator() == Operator.ORDER_BY_DESC)
-                comparator = comparator.reversed();
-        }
-        return comparator;
-    }
-
-    private <T extends BaseRowIterator<?>, P extends BasePartitionIterator<T>>
-    void processPartitions(P partitions,
-                           TopKSelector<Triple<PartitionInfo, Row, ?>> topK,
-                           @Nullable TreeMap<PartitionInfo, TreeSet<Unfiltered>> unfilteredByPartition)
-    {
-        if (PARALLEL_EXECUTOR != ImmediateExecutor.INSTANCE && partitions instanceof ParallelCommandProcessor)
-        {
-            ParallelCommandProcessor pIter = (ParallelCommandProcessor) partitions;
-            var commands = pIter.getUninitializedCommands();
-            List<CompletableFuture<PartitionResults>> results = new ArrayList<>(commands.size());
+            TreeMap<PartitionInfo, TreeSet<Unfiltered>> unfilteredByPartition = new TreeMap<>(Comparator.comparing(pi -> pi.key));

-            int count = commands.size();
-            for (var command: commands) {
-                CompletableFuture<PartitionResults> future = new CompletableFuture<>();
-                results.add(future);
-
-                // run last command immediately, others in parallel (if possible)
-                count--;
-                var executor = count == 0 ? ImmediateExecutor.INSTANCE : PARALLEL_EXECUTOR;
-
-                executor.maybeExecuteImmediately(() -> {
-                    try (var partitionRowIterator = pIter.commandToIterator(command.left(), command.right()))
-                    {
-                        future.complete(partitionRowIterator == null ? null : processPartition(partitionRowIterator));
-                    }
-                    catch (Throwable t)
-                    {
-                        future.completeExceptionally(t);
-                    }
-                });
-            }
-
-            for (CompletableFuture<PartitionResults> triplesFuture: results)
-            {
-                PartitionResults pr;
-                try
-                {
-                    pr = triplesFuture.join();
-                }
-                catch (CompletionException t)
-                {
-                    if (t.getCause() instanceof AbortedOperationException)
-                        throw (AbortedOperationException) t.getCause();
-                    throw t;
-                }
-                if (pr == null)
-                    continue;
-                topK.addAll(pr.rows);
-                if (unfilteredByPartition != null)
-                {
-                    for (var uf : pr.tombstones)
-                        addUnfiltered(unfilteredByPartition, pr.partitionInfo, uf);
-                }
-            }
-        }
-        else if (partitions instanceof StorageAttachedIndexSearcher.ScoreOrderedResultRetriever)
-        {
-            // FilteredPartitions does not implement ParallelizablePartitionIterator.
-            // Realistically, this won't benefit from parallelizm as these are coming from in-memory/memtable data.
            int rowsMatched = 0;
-            // Check rowsMatched first to prevent fetching one more partition than needed.
+            // Because each "partition" from ScoreOrderedResultRetriever is actually a single row
+            // or tombstone, we can simply read them until we have enough.
            while (rowsMatched < limit && partitions.hasNext())
            {
-                // Must close to move to the next partition, otherwise hasNext() fails
-                try (var partitionRowIterator = partitions.next())
+                try (BaseRowIterator<Unfiltered> partitionRowIterator = partitions.next())
                {
                    rowsMatched += processSingleRowPartition(unfilteredByPartition, partitionRowIterator);
                }
            }
+
+            return new InMemoryUnfilteredPartitionIterator(command, unfilteredByPartition);
        }
-        else
+    }
+
+    /**
+     * Constructs a comparator for triple (PartitionInfo, Row, X) used for top-K ranking.
+     * For ANN/BM25 we compare descending by X (float score). For ORDER_BY_ASC or DESC,
+     * we compare ascending/descending by the row's relevant ByteBuffer data.
+     */
+    private Comparator<Triple<PartitionInfo, Row, ?>> comparator()
+    {
+        if (expression.operator() == Operator.ANN || expression.operator() == Operator.BM25)
        {
-            // FilteredPartitions does not implement ParallelizablePartitionIterator.
-            // Realistically, this won't benefit from parallelizm as these are coming from in-memory/memtable data.
-            while (partitions.hasNext())
-            {
-                // have to close to move to the next partition, otherwise hasNext() fails
-                try (var partitionRowIterator = partitions.next())
-                {
-                    if (queryVector != null)
-                    {
-                        PartitionResults pr = processPartition(partitionRowIterator);
-                        topK.addAll(pr.rows);
-                        if (unfilteredByPartition != null)
-                        {
-                            for (var uf : pr.tombstones)
-                                addUnfiltered(unfilteredByPartition, pr.partitionInfo, uf);
-                        }
-                    }
-                    else
-                    {
-                        while (partitionRowIterator.hasNext())
-                        {
-                            Row row = (Row) partitionRowIterator.next();
-                            topK.add(Triple.of(PartitionInfo.create(partitionRowIterator), row, row.getCell(expression.column()).buffer()));
-                        }
-                    }
-                }
-            }
+            // For similarity, higher is better, so reversed
+            return Comparator.comparing((Triple<PartitionInfo, Row, ?> t) -> (Float) t.getRight()).reversed();
        }
+
+        Comparator<Triple<PartitionInfo, Row, ?>> comparator = Comparator.comparing(t -> (ByteBuffer) t.getRight(), indexContext.getValidator());
+        if (expression.operator() == Operator.ORDER_BY_DESC)
+            comparator = comparator.reversed();
+        return comparator;
    }

+    /**
+     * Simple holder for partial results of a single partition (score-based path).
+     */
    private class PartitionResults
    {
        final PartitionInfo partitionInfo;
@@ -352,9 +243,9 @@ void addRow(Triple<PartitionInfo, Row, ?> triple)
    }

    /**
-     * Processes a single partition, calculating scores for rows and extracting tombstones.
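The comparator built above feeds Cassandra's TopKSelector, which retains only the best `limit` triples while the per-replica results stream through. As a rough, self-contained illustration of that bounded-selection idea (a sketch using a JDK PriorityQueue; BoundedTopK and its members are hypothetical names, not the actual TopKSelector API):

    import java.util.ArrayList;
    import java.util.Comparator;
    import java.util.List;
    import java.util.PriorityQueue;

    // Keeps the `limit` smallest elements under `comparator` (i.e. the best,
    // since the comparator above orders best-first, e.g. score descending).
    class BoundedTopK<T>
    {
        private final Comparator<T> comparator;
        private final PriorityQueue<T> heap; // head = worst element currently kept
        private final int limit;

        BoundedTopK(Comparator<T> comparator, int limit)
        {
            this.comparator = comparator;
            this.limit = limit;
            this.heap = new PriorityQueue<>(Math.max(1, limit), comparator.reversed());
        }

        void add(T item)
        {
            if (heap.size() < limit)
            {
                heap.add(item);
            }
            else if (!heap.isEmpty() && comparator.compare(item, heap.peek()) < 0)
            {
                heap.poll(); // evict the current worst before adding the better item
                heap.add(item);
            }
        }

        List<T> sorted()
        {
            List<T> result = new ArrayList<>(heap);
            result.sort(comparator); // best first
            return result;
        }
    }
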
+ * Processes all rows in a single partition to compute scores (for ANN or BM25) */ - private PartitionResults processPartition(BaseRowIterator partitionRowIterator) + private PartitionResults processScoredPartition(BaseRowIterator partitionRowIterator) { // Compute key and static row score once per partition DecoratedKey key = partitionRowIterator.partitionKey(); @@ -400,15 +291,14 @@ private int processSingleRowPartition(TreeMap return unfiltered.isRangeTombstoneMarker() ? 0 : 1; } - private void addUnfiltered(SortedMap> unfilteredByPartition, PartitionInfo partitionInfo, Unfiltered unfiltered) + private void addUnfiltered(SortedMap> unfilteredByPartition, + PartitionInfo partitionInfo, + Unfiltered unfiltered) { var map = unfilteredByPartition.computeIfAbsent(partitionInfo, k -> new TreeSet<>(command.metadata().comparator)); map.add(unfiltered); } - /** - * Sum the scores from different vector indexes for the row - */ private float getScoreForRow(DecoratedKey key, Row row) { ColumnMetadata column = indexContext.getDefinition(); @@ -422,6 +312,15 @@ private float getScoreForRow(DecoratedKey key, Row row) if ((column.isClusteringColumn() || column.isRegular()) && row.isStatic()) return 0; + // If we have a synthetic score column, use it + var scoreData = row.getColumnData(scoreColumn); + if (scoreData != null) + { + var cell = (Cell) scoreData; + return FloatType.instance.compose(cell.buffer()); + } + + // TODO remove this once we enable ANN_USE_SYNTHETIC_SCORE ByteBuffer value = indexContext.getValueOf(key, row, FBUtilities.nowInSeconds()); if (value != null) { @@ -431,28 +330,30 @@ private float getScoreForRow(DecoratedKey key, Row row) return 0; } - private Pair findTopKIndexContext() { ColumnFamilyStore cfs = Keyspace.openAndGetStore(command.metadata()); for (RowFilter.Expression expression : command.rowFilter().expressions()) { - StorageAttachedIndex sai = findVectorIndexFor(cfs.indexManager, expression); + StorageAttachedIndex sai = findOrderingIndexFor(cfs.indexManager, expression); if (sai != null) - { return Pair.create(sai.getIndexContext(), expression); - } } return null; } @Nullable - private StorageAttachedIndex findVectorIndexFor(SecondaryIndexManager sim, RowFilter.Expression e) + private StorageAttachedIndex findOrderingIndexFor(SecondaryIndexManager sim, RowFilter.Expression e) { - if (e.operator() != Operator.ANN && e.operator() != Operator.ORDER_BY_ASC && e.operator() != Operator.ORDER_BY_DESC) + if (e.operator() != Operator.ANN + && e.operator() != Operator.BM25 + && e.operator() != Operator.ORDER_BY_ASC + && e.operator() != Operator.ORDER_BY_DESC) + { return null; + } Optional index = sim.getBestIndexFor(e); return (StorageAttachedIndex) index.filter(i -> i instanceof StorageAttachedIndex).orElse(null); diff --git a/src/java/org/apache/cassandra/index/sai/utils/BM25Utils.java b/src/java/org/apache/cassandra/index/sai/utils/BM25Utils.java new file mode 100644 index 000000000000..ae101c6f0373 --- /dev/null +++ b/src/java/org/apache/cassandra/index/sai/utils/BM25Utils.java @@ -0,0 +1,185 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.cassandra.index.sai.utils;
+
+import java.nio.ByteBuffer;
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.Iterator;
+import java.util.List;
+import java.util.Map;
+
+import org.apache.cassandra.db.memtable.Memtable;
+import org.apache.cassandra.db.rows.Cell;
+import org.apache.cassandra.index.sai.IndexContext;
+import org.apache.cassandra.index.sai.analyzer.AbstractAnalyzer;
+import org.apache.cassandra.io.sstable.SSTableId;
+import org.apache.cassandra.io.util.FileUtils;
+import org.apache.cassandra.utils.CloseableIterator;
+
+public class BM25Utils
+{
+    private static final float K1 = 1.2f;  // BM25 term frequency saturation parameter
+    private static final float B = 0.75f;  // BM25 length normalization parameter
+
+    /**
+     * Term frequencies across all documents. Each document is only counted once.
+     */
+    public static class DocStats
+    {
+        // Map of term -> count of docs containing that term
+        private final Map<ByteBuffer, Long> frequencies;
+        // total number of docs in the index
+        private final long docCount;
+
+        public DocStats(Map<ByteBuffer, Long> frequencies, long docCount)
+        {
+            this.frequencies = frequencies;
+            this.docCount = docCount;
+        }
+    }
+
+    /**
+     * Term frequencies within a single document. All instances of a term are counted.
+     */
+    public static class DocTF
+    {
+        private final PrimaryKey pk;
+        private final Map<ByteBuffer, Integer> frequencies;
+        private final int termCount;
+
+        public DocTF(PrimaryKey pk, int termCount, Map<ByteBuffer, Integer> frequencies)
+        {
+            this.pk = pk;
+            this.frequencies = frequencies;
+            this.termCount = termCount;
+        }
+
+        public int getTermFrequency(ByteBuffer term)
+        {
+            return frequencies.getOrDefault(term, 0);
+        }
+
+        public static DocTF createFromDocument(PrimaryKey pk,
+                                               Cell<?> cell,
+                                               AbstractAnalyzer docAnalyzer,
+                                               Collection<ByteBuffer> queryTerms)
+        {
+            int count = 0;
+            Map<ByteBuffer, Integer> frequencies = new HashMap<>();
+
+            docAnalyzer.reset(cell.buffer());
+            try
+            {
+                while (docAnalyzer.hasNext())
+                {
+                    ByteBuffer term = docAnalyzer.next();
+                    count++;
+                    if (queryTerms.contains(term))
+                        frequencies.merge(term, 1, Integer::sum);
+                }
+            }
+            finally
+            {
+                docAnalyzer.end();
+            }
+
+            return new DocTF(pk, count, frequencies);
+        }
+    }
+
+    public static CloseableIterator<PrimaryKeyWithSortKey> computeScores(CloseableIterator<DocTF> docIterator,
+                                                                         List<ByteBuffer> queryTerms,
+                                                                         DocStats docStats,
+                                                                         IndexContext indexContext,
+                                                                         Object source)
+    {
+        // data structures for document stats and frequencies
+        ArrayList<DocTF> documents = new ArrayList<>();
+        double totalTermCount = 0;
+
+        // Compute TF within each document
+        while (docIterator.hasNext())
+        {
+            var tf = docIterator.next();
+            documents.add(tf);
+            totalTermCount += tf.termCount;
+        }
+        if (documents.isEmpty())
+            return CloseableIterator.emptyIterator();
+
+        // Calculate average document length
+        double avgDocLength = totalTermCount / documents.size();
+
+        // Calculate BM25 scores
+        var scoredDocs = new ArrayList<PrimaryKeyWithScore>(documents.size());
+        for (var doc : documents)
+        {
+            double score = 0.0;
+            for (var queryTerm : queryTerms)
+            {
+                int tf = doc.getTermFrequency(queryTerm);
+                Long df = docStats.frequencies.get(queryTerm);
+                // we shouldn't have more hits for a term than we counted total documents
+                assert df <= docStats.docCount : String.format("df=%d, totalDocs=%d", df, docStats.docCount);
+
+                double normalizedTf = tf / (tf + K1 * (1 - B + B * doc.termCount / avgDocLength));
+                double idf = Math.log(1 + (docStats.docCount - df + 0.5) / (df + 0.5));
+                double deltaScore = normalizedTf * idf;
+                assert deltaScore >= 0 : String.format("BM25 score for tf=%d, df=%d, tc=%d, totalDocs=%d is %f",
+                                                       tf, df, doc.termCount, docStats.docCount, deltaScore);
+                score += deltaScore;
+            }
+            if (source instanceof Memtable)
+                scoredDocs.add(new PrimaryKeyWithScore(indexContext, (Memtable) source, doc.pk, (float) score));
+            else if (source instanceof SSTableId)
+                scoredDocs.add(new PrimaryKeyWithScore(indexContext, (SSTableId) source, doc.pk, (float) score));
+            else
+                throw new IllegalArgumentException("Invalid source " + source.getClass());
+        }
+
+        // sort by score (PrimaryKeyWithScore implements Comparable for us, descending by score)
+        Collections.sort(scoredDocs);
+
+        return new CloseableIterator<>()
+        {
+            private final Iterator<PrimaryKeyWithScore> iterator = scoredDocs.iterator();
+
+            @Override
+            public boolean hasNext()
+            {
+                return iterator.hasNext();
+            }
+
+            @Override
+            public PrimaryKeyWithSortKey next()
+            {
+                return iterator.next();
+            }
+
+            @Override
+            public void close()
+            {
+                FileUtils.closeQuietly(docIterator);
+            }
+        };
+    }
+}
diff --git a/src/java/org/apache/cassandra/index/sai/utils/PrimaryKey.java b/src/java/org/apache/cassandra/index/sai/utils/PrimaryKey.java
index 516e26caaeea..0bcd16c0a7f3 100644
--- a/src/java/org/apache/cassandra/index/sai/utils/PrimaryKey.java
+++ b/src/java/org/apache/cassandra/index/sai/utils/PrimaryKey.java
@@ -19,6 +19,7 @@
 import java.util.function.Supplier;

+import io.github.jbellis.jvector.util.Accountable;
 import org.apache.cassandra.db.Clustering;
 import org.apache.cassandra.db.ClusteringComparator;
 import org.apache.cassandra.db.DecoratedKey;
@@ -38,7 +39,7 @@
 * For the V2 on-disk format the {@link DecoratedKey} and {@link Clustering} are supported.
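To make the scoring loop above concrete, here is a self-contained sketch (hypothetical class name; K1 and B copied from BM25Utils) that replays the formula on the three-document "apple" dataset used by the tests later in this patch, confirming that the score grows with term frequency there:

    public class Bm25ScoreSketch
    {
        static final double K1 = 1.2, B = 0.75; // same constants as BM25Utils

        // Per-term BM25 contribution, mirroring the loop in computeScores
        static double score(int tf, long df, long docCount, int docLen, double avgDocLen)
        {
            double normalizedTf = tf / (tf + K1 * (1 - B + B * docLen / avgDocLen));
            double idf = Math.log(1 + (docCount - df + 0.5) / (df + 0.5));
            return normalizedTf * idf;
        }

        public static void main(String[] args)
        {
            // Docs: "apple", "apple apple", "apple apple apple" -> df = 3, docCount = 3
            double avgDocLen = (1 + 2 + 3) / 3.0;
            for (int tf = 1; tf <= 3; tf++) // here each doc's length equals its tf
                System.out.printf("tf=%d score=%.4f%n", tf, score(tf, 3, 3, tf, avgDocLen));
            // Prints increasing scores (~0.076, ~0.083, ~0.086), which is why the
            // LIMIT 3 query in testTermFrequencyOrdering returns keys 3, 2, 1.
        }
    }
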
* */ -public interface PrimaryKey extends Comparable +public interface PrimaryKey extends Comparable, Accountable { /** * A factory for creating {@link PrimaryKey} instances diff --git a/src/java/org/apache/cassandra/index/sai/utils/PrimaryKeyWithByteComparable.java b/src/java/org/apache/cassandra/index/sai/utils/PrimaryKeyWithByteComparable.java index 837c032b7952..bb931c942fbf 100644 --- a/src/java/org/apache/cassandra/index/sai/utils/PrimaryKeyWithByteComparable.java +++ b/src/java/org/apache/cassandra/index/sai/utils/PrimaryKeyWithByteComparable.java @@ -21,7 +21,10 @@ import java.nio.ByteBuffer; import java.util.Arrays; +import io.github.jbellis.jvector.util.RamUsageEstimator; +import org.apache.cassandra.db.memtable.Memtable; import org.apache.cassandra.index.sai.IndexContext; +import org.apache.cassandra.io.sstable.SSTableId; import org.apache.cassandra.utils.bytecomparable.ByteComparable; import org.apache.cassandra.utils.bytecomparable.ByteSource; import org.apache.cassandra.utils.bytecomparable.ByteSourceInverse; @@ -34,7 +37,13 @@ public class PrimaryKeyWithByteComparable extends PrimaryKeyWithSortKey { private final ByteComparable byteComparable; - public PrimaryKeyWithByteComparable(IndexContext context, Object sourceTable, PrimaryKey primaryKey, ByteComparable byteComparable) + public PrimaryKeyWithByteComparable(IndexContext context, Memtable sourceTable, PrimaryKey primaryKey, ByteComparable byteComparable) + { + super(context, sourceTable, primaryKey); + this.byteComparable = byteComparable; + } + + public PrimaryKeyWithByteComparable(IndexContext context, SSTableId sourceTable, PrimaryKey primaryKey, ByteComparable byteComparable) { super(context, sourceTable, primaryKey); this.byteComparable = byteComparable; @@ -65,4 +74,10 @@ public int compareTo(PrimaryKey o) return ByteComparable.compare(byteComparable, ((PrimaryKeyWithByteComparable) o).byteComparable, TypeUtil.BYTE_COMPARABLE_VERSION); } + + @Override + public long ramBytesUsed() + { + return super.ramBytesUsed() + RamUsageEstimator.NUM_BYTES_OBJECT_REF; + } } diff --git a/src/java/org/apache/cassandra/index/sai/utils/PrimaryKeyWithScore.java b/src/java/org/apache/cassandra/index/sai/utils/PrimaryKeyWithScore.java index 8b7b4acba9c6..b88c210d65f2 100644 --- a/src/java/org/apache/cassandra/index/sai/utils/PrimaryKeyWithScore.java +++ b/src/java/org/apache/cassandra/index/sai/utils/PrimaryKeyWithScore.java @@ -20,7 +20,9 @@ import java.nio.ByteBuffer; +import org.apache.cassandra.db.memtable.Memtable; import org.apache.cassandra.index.sai.IndexContext; +import org.apache.cassandra.io.sstable.SSTableId; /** * A {@link PrimaryKey} that includes a score from a source index. 
@@ -28,9 +30,15 @@ */ public class PrimaryKeyWithScore extends PrimaryKeyWithSortKey { - private final float indexScore; + public final float indexScore; - public PrimaryKeyWithScore(IndexContext context, Object source, PrimaryKey primaryKey, float indexScore) + public PrimaryKeyWithScore(IndexContext context, Memtable source, PrimaryKey primaryKey, float indexScore) + { + super(context, source, primaryKey); + this.indexScore = indexScore; + } + + public PrimaryKeyWithScore(IndexContext context, SSTableId source, PrimaryKey primaryKey, float indexScore) { super(context, source, primaryKey); this.indexScore = indexScore; @@ -53,4 +61,11 @@ public int compareTo(PrimaryKey o) // Descending order return Float.compare(((PrimaryKeyWithScore) o).indexScore, indexScore); } + + @Override + public long ramBytesUsed() + { + // Include super class fields plus float value + return super.ramBytesUsed() + Float.BYTES; + } } diff --git a/src/java/org/apache/cassandra/index/sai/utils/PrimaryKeyWithSortKey.java b/src/java/org/apache/cassandra/index/sai/utils/PrimaryKeyWithSortKey.java index 2e79b0402124..b3a6fb4338e5 100644 --- a/src/java/org/apache/cassandra/index/sai/utils/PrimaryKeyWithSortKey.java +++ b/src/java/org/apache/cassandra/index/sai/utils/PrimaryKeyWithSortKey.java @@ -20,11 +20,14 @@ import java.nio.ByteBuffer; +import io.github.jbellis.jvector.util.RamUsageEstimator; import org.apache.cassandra.db.Clustering; import org.apache.cassandra.db.DecoratedKey; +import org.apache.cassandra.db.memtable.Memtable; import org.apache.cassandra.db.rows.Row; import org.apache.cassandra.dht.Token; import org.apache.cassandra.index.sai.IndexContext; +import org.apache.cassandra.io.sstable.SSTableId; import org.apache.cassandra.utils.bytecomparable.ByteComparable; import org.apache.cassandra.utils.bytecomparable.ByteSource; @@ -41,7 +44,14 @@ public abstract class PrimaryKeyWithSortKey implements PrimaryKey // Either a Memtable reference or an SSTableId reference private final Object sourceTable; - protected PrimaryKeyWithSortKey(IndexContext context, Object sourceTable, PrimaryKey primaryKey) + protected PrimaryKeyWithSortKey(IndexContext context, Memtable sourceTable, PrimaryKey primaryKey) + { + this.context = context; + this.sourceTable = sourceTable; + this.primaryKey = primaryKey; + } + + protected PrimaryKeyWithSortKey(IndexContext context, SSTableId sourceTable, PrimaryKey primaryKey) { this.context = context; this.sourceTable = sourceTable; @@ -137,4 +147,13 @@ public ByteSource asComparableBytesMaxPrefix(ByteComparable.Version version) return primaryKey.asComparableBytesMaxPrefix(version); } + @Override + public long ramBytesUsed() + { + // Object header + 3 references (context, primaryKey, sourceTable) + return RamUsageEstimator.NUM_BYTES_OBJECT_HEADER + + 3L * RamUsageEstimator.NUM_BYTES_OBJECT_REF + + primaryKey.ramBytesUsed(); + } + } diff --git a/src/java/org/apache/cassandra/index/sai/utils/PrimaryKeys.java b/src/java/org/apache/cassandra/index/sai/utils/PrimaryKeys.java index 173dc045965f..75298ac217da 100644 --- a/src/java/org/apache/cassandra/index/sai/utils/PrimaryKeys.java +++ b/src/java/org/apache/cassandra/index/sai/utils/PrimaryKeys.java @@ -41,6 +41,7 @@ public class PrimaryKeys implements Iterable * Adds the specified {@link PrimaryKey}. 
 *
 * @param key a primary key
+ * @return the bytes allocated for the key (0 if it already existed in the set)
 */
 public long add(PrimaryKey key)
 {
diff --git a/src/java/org/apache/cassandra/index/sai/utils/RowIdWithScore.java b/src/java/org/apache/cassandra/index/sai/utils/RowIdWithScore.java
index b4b1f129567f..c6ed4708e0c8 100644
--- a/src/java/org/apache/cassandra/index/sai/utils/RowIdWithScore.java
+++ b/src/java/org/apache/cassandra/index/sai/utils/RowIdWithScore.java
@@ -26,7 +26,7 @@
 */
 public class RowIdWithScore extends RowIdWithMeta
 {
-    private final float score;
+    public final float score;

    public RowIdWithScore(int segmentRowId, float score)
    {
diff --git a/src/java/org/apache/cassandra/index/sai/utils/RowWithSourceTable.java b/src/java/org/apache/cassandra/index/sai/utils/RowWithSourceTable.java
index 3064c701fc56..d93198a94d1a 100644
--- a/src/java/org/apache/cassandra/index/sai/utils/RowWithSourceTable.java
+++ b/src/java/org/apache/cassandra/index/sai/utils/RowWithSourceTable.java
@@ -375,4 +375,13 @@ private Row maybeWrapRow(Row r)
            return this;
        return new RowWithSourceTable(r, source);
    }
+
+    @Override
+    public String toString()
+    {
+        return "RowWithSourceTable{" +
+               row +
+               ", source=" + source +
+               '}';
+    }
 }
diff --git a/src/java/org/apache/cassandra/schema/ColumnMetadata.java b/src/java/org/apache/cassandra/schema/ColumnMetadata.java
index 38784e72acb5..008813564e57 100644
--- a/src/java/org/apache/cassandra/schema/ColumnMetadata.java
+++ b/src/java/org/apache/cassandra/schema/ColumnMetadata.java
@@ -71,9 +71,9 @@ public enum ClusteringOrder

    /**
     * The type of CQL3 column this definition represents.
-     * There is 4 main type of CQL3 columns: those parts of the partition key,
-     * those parts of the clustering columns and amongst the others, regular and
-     * static ones.
+     * There are 5 types of columns: those that are part of the partition key,
+     * those that are part of the clustering columns, and, amongst the others,
+     * regular, static, and synthetic ones.
     *
     * IMPORTANT: this enum is serialized as toString() and deserialized by calling
     * Kind.valueOf(), so do not override toString() or rename existing values.
@@ -81,18 +81,22 @@ public enum ClusteringOrder
    public enum Kind
    {
        // NOTE: if adding a new type, must modify comparisonOrder
+        SYNTHETIC,
        PARTITION_KEY,
        CLUSTERING,
        REGULAR,
        STATIC;

+        // it is not possible to add new Kinds after Synthetic without invasive changes to BTreeRow, which
+        // assumes that complex regular/static columns are the last ones
        public boolean isPrimaryKeyKind()
        {
            return this == PARTITION_KEY || this == CLUSTERING;
        }
-    }
+    public static final ColumnIdentifier SYNTHETIC_SCORE_ID = ColumnIdentifier.getInterned("+:!score", true);
+
    /**
     * Whether this is a dropped column.
     */
@@ -121,10 +125,17 @@ public boolean isPrimaryKeyKind()
     */
    private final long comparisonOrder;

+    /**
+     * Bit layout (from most to least significant):
+     * - Bits 61-63: Kind ordinal (3 bits, supporting up to 8 Kind values)
+     * - Bit 60: isComplex flag
+     * - Bits 48-59: position (12 bits, see assert)
+     * - Bits 0-47: name.prefixComparison (shifted right by 16)
+     */
    private static long comparisonOrder(Kind kind, boolean isComplex, long position, ColumnIdentifier name)
    {
        assert position >= 0 && position < 1 << 12;
-        return (((long) kind.ordinal()) << 61)
+        return (((long) kind.ordinal()) << 61)
              | (isComplex ?
1L << 60 : 0) | (position << 48) | (name.prefixComparison >>> 16); @@ -170,6 +181,14 @@ public static ColumnMetadata staticColumn(String keyspace, String table, String return new ColumnMetadata(keyspace, table, ColumnIdentifier.getInterned(name, true), type, NO_POSITION, Kind.STATIC); } + /** + * Creates a new synthetic column metadata instance. + */ + public static ColumnMetadata syntheticColumn(String keyspace, String table, ColumnIdentifier id, AbstractType type) + { + return new ColumnMetadata(keyspace, table, id, type, NO_POSITION, Kind.SYNTHETIC); + } + /** * Rebuild the metadata for a dropped column from its recorded data. * @@ -225,6 +244,7 @@ public ColumnMetadata(String ksName, this.kind = kind; this.position = position; this.cellPathComparator = makeCellPathComparator(kind, type); + assert kind != Kind.SYNTHETIC || cellPathComparator == null; this.cellComparator = cellPathComparator == null ? ColumnData.comparator : new Comparator>() { @Override @@ -461,7 +481,7 @@ public int compareTo(ColumnMetadata other) return 0; if (comparisonOrder != other.comparisonOrder) - return Long.compare(comparisonOrder, other.comparisonOrder); + return Long.compareUnsigned(comparisonOrder, other.comparisonOrder); return this.name.compareTo(other.name); } @@ -593,6 +613,11 @@ public boolean isCounterColumn() return type.isCounter(); } + public boolean isSynthetic() + { + return kind == Kind.SYNTHETIC; + } + public Selector.Factory newSelectorFactory(TableMetadata table, AbstractType expectedType, List defs, VariableSpecifications boundNames) throws InvalidRequestException { return SimpleSelector.newFactory(this, addAndGetIndex(this, defs)); diff --git a/src/java/org/apache/cassandra/schema/TableMetadata.java b/src/java/org/apache/cassandra/schema/TableMetadata.java index ba5e1db8d84d..8885b85fdd43 100644 --- a/src/java/org/apache/cassandra/schema/TableMetadata.java +++ b/src/java/org/apache/cassandra/schema/TableMetadata.java @@ -1122,8 +1122,7 @@ public Builder addStaticColumn(ColumnIdentifier name, AbstractType type) public Builder addColumn(ColumnMetadata column) { - if (columns.containsKey(column.name.bytes)) - throw new IllegalArgumentException(); + assert !columns.containsKey(column.name.bytes) : column.name + " is already present"; switch (column.kind) { diff --git a/src/java/org/apache/cassandra/service/ClientWarn.java b/src/java/org/apache/cassandra/service/ClientWarn.java index 5a6a878681e1..38570a06d2b8 100644 --- a/src/java/org/apache/cassandra/service/ClientWarn.java +++ b/src/java/org/apache/cassandra/service/ClientWarn.java @@ -18,7 +18,9 @@ package org.apache.cassandra.service; import java.util.ArrayList; +import java.util.HashSet; import java.util.List; +import java.util.Set; import io.netty.util.concurrent.FastThreadLocal; import org.apache.cassandra.concurrent.ExecutorLocal; @@ -45,10 +47,18 @@ public void set(State value) } public void warn(String text) + { + warn(text, null); + } + + /** + * Issue the given warning if this is the first time `key` is seen. 
+ */ + public void warn(String text, Object key) { State state = warnLocal.get(); if (state != null) - state.add(text); + state.add(text, key); } public void captureWarnings() @@ -72,11 +82,16 @@ public void resetWarnings() public static class State { private final List warnings = new ArrayList<>(); + private final Set keysAdded = new HashSet<>(); - private void add(String warning) + private void add(String warning, Object key) { if (warnings.size() < FBUtilities.MAX_UNSIGNED_SHORT) + { + if (key != null && !keysAdded.add(key)) + return; warnings.add(maybeTruncate(warning)); + } } private static String maybeTruncate(String warning) diff --git a/test/burn/org/apache/cassandra/index/sai/LongBM25Test.java b/test/burn/org/apache/cassandra/index/sai/LongBM25Test.java new file mode 100644 index 000000000000..49b5b5118540 --- /dev/null +++ b/test/burn/org/apache/cassandra/index/sai/LongBM25Test.java @@ -0,0 +1,251 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.cassandra.index.sai; + +import java.io.IOException; +import java.net.URISyntaxException; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.ForkJoinPool; +import java.util.concurrent.ThreadLocalRandom; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.stream.Collectors; +import java.util.stream.IntStream; + +import org.junit.Before; +import org.junit.Test; +import org.slf4j.Logger; + +import org.apache.cassandra.db.memtable.TrieMemtable; + +public class LongBM25Test extends SAITester +{ + private static final Logger logger = org.slf4j.LoggerFactory.getLogger(LongBM25Test.class); + + private static final List documentLines = new ArrayList<>(); + + static + { + try + { + var cl = LongBM25Test.class.getClassLoader(); + var resourceDir = cl.getResource("bm25"); + if (resourceDir == null) + throw new RuntimeException("Could not find resource directory test/resources/bm25/"); + + var dirPath = java.nio.file.Paths.get(resourceDir.toURI()); + try (var files = java.nio.file.Files.list(dirPath)) + { + files.forEach(file -> { + try (var lines = java.nio.file.Files.lines(file)) + { + lines.map(String::trim) + .filter(line -> !line.isEmpty()) + .forEach(documentLines::add); + } + catch (IOException e) + { + throw new RuntimeException("Failed to read file: " + file, e); + } + }); + } + if (documentLines.isEmpty()) + { + throw new RuntimeException("No document lines loaded from test/resources/bm25/"); + } + } + catch (IOException | URISyntaxException e) + { + throw new RuntimeException("Failed to load test documents", e); + } + } + + KeySet keysInserted = new KeySet(); + private 
final int threadCount = 12; + + @Before + public void setup() throws Throwable + { + // we don't get loaded until after TM, so we can't affect the very first memtable, + // but this will affect all subsequent ones + TrieMemtable.SHARD_COUNT = 4 * threadCount; + } + + @FunctionalInterface + private interface Op + { + void run(int i) throws Throwable; + } + + public void testConcurrentOps(Op op) throws ExecutionException, InterruptedException + { + createTable("CREATE TABLE %s (key int primary key, value text)"); + // Create analyzed index following BM25Test pattern + createIndex("CREATE CUSTOM INDEX ON %s(value) " + + "USING 'org.apache.cassandra.index.sai.StorageAttachedIndex' " + + "WITH OPTIONS = {" + + "'index_analyzer': '{" + + "\"tokenizer\" : {\"name\" : \"standard\"}, " + + "\"filters\" : [{\"name\" : \"porterstem\"}]" + + "}'}" + ); + + AtomicInteger counter = new AtomicInteger(); + long start = System.currentTimeMillis(); + var fjp = new ForkJoinPool(threadCount); + var keys = IntStream.range(0, 10_000_000).boxed().collect(Collectors.toList()); + Collections.shuffle(keys); + var task = fjp.submit(() -> keys.stream().parallel().forEach(i -> + { + wrappedOp(op, i); + if (counter.incrementAndGet() % 10_000 == 0) + { + var elapsed = System.currentTimeMillis() - start; + logger.info("{} ops in {}ms = {} ops/s", counter.get(), elapsed, counter.get() * 1000.0 / elapsed); + } + if (ThreadLocalRandom.current().nextDouble() < 0.001) + flush(); + })); + fjp.shutdown(); + task.get(); // re-throw + } + + private static void wrappedOp(Op op, Integer i) + { + try + { + op.run(i); + } + catch (Throwable e) + { + throw new RuntimeException(e); + } + } + + private static String randomDocument() + { + var R = ThreadLocalRandom.current(); + int numLines = R.nextInt(5, 51); // 5 to 50 lines inclusive + var selectedLines = new ArrayList(); + + for (int i = 0; i < numLines; i++) + { + selectedLines.add(randomQuery(R)); + } + + return String.join("\n", selectedLines); + } + + private static String randomLine(ThreadLocalRandom R) + { + return documentLines.get(R.nextInt(documentLines.size())); + } + + @Test + public void testConcurrentReadsWritesDeletes() throws ExecutionException, InterruptedException + { + testConcurrentOps(i -> { + var R = ThreadLocalRandom.current(); + if (R.nextDouble() < 0.2 || keysInserted.isEmpty()) + { + var doc = randomDocument(); + execute("INSERT INTO %s (key, value) VALUES (?, ?)", i, doc); + keysInserted.add(i); + } + else if (R.nextDouble() < 0.1) + { + var key = keysInserted.getRandom(); + execute("DELETE FROM %s WHERE key = ?", key); + } + else + { + var line = randomQuery(R); + execute("SELECT * FROM %s ORDER BY value BM25 OF ? LIMIT ?", line, R.nextInt(1, 100)); + } + }); + } + + private static String randomQuery(ThreadLocalRandom R) + { + while (true) + { + var line = randomLine(R); + if (line.chars().anyMatch(Character::isAlphabetic)) + return line; + } + } + + @Test + public void testConcurrentReadsWrites() throws ExecutionException, InterruptedException + { + testConcurrentOps(i -> { + var R = ThreadLocalRandom.current(); + if (R.nextDouble() < 0.1 || keysInserted.isEmpty()) + { + var doc = randomDocument(); + execute("INSERT INTO %s (key, value) VALUES (?, ?)", i, doc); + keysInserted.add(i); + } + else + { + var line = randomQuery(R); + execute("SELECT * FROM %s ORDER BY value BM25 OF ? 
LIMIT ?", line, R.nextInt(1, 100)); + } + }); + } + + @Test + public void testConcurrentWrites() throws ExecutionException, InterruptedException + { + testConcurrentOps(i -> { + var doc = randomDocument(); + execute("INSERT INTO %s (key, value) VALUES (?, ?)", i, doc); + }); + } + + private static class KeySet + { + private final Map keys = new ConcurrentHashMap<>(); + private final AtomicInteger ordinal = new AtomicInteger(); + + public void add(int key) + { + var i = ordinal.getAndIncrement(); + keys.put(i, key); + } + + public int getRandom() + { + if (isEmpty()) + throw new IllegalStateException(); + var i = ThreadLocalRandom.current().nextInt(ordinal.get()); + // in case there is race with add(key), retry another random + return keys.containsKey(i) ? keys.get(i) : getRandom(); + } + + public boolean isEmpty() + { + return keys.isEmpty(); + } + } +} diff --git a/test/distributed/org/apache/cassandra/distributed/test/sai/BM25DistributedTest.java b/test/distributed/org/apache/cassandra/distributed/test/sai/BM25DistributedTest.java new file mode 100644 index 000000000000..95c7587ea044 --- /dev/null +++ b/test/distributed/org/apache/cassandra/distributed/test/sai/BM25DistributedTest.java @@ -0,0 +1,123 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.cassandra.distributed.test.sai; + +import java.util.concurrent.atomic.AtomicInteger; + +import org.junit.AfterClass; +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.Test; + +import org.apache.cassandra.distributed.Cluster; +import org.apache.cassandra.distributed.api.ConsistencyLevel; +import org.apache.cassandra.distributed.test.TestBaseImpl; +import org.apache.cassandra.index.sai.disk.format.Version; + +import static org.apache.cassandra.distributed.api.Feature.GOSSIP; +import static org.apache.cassandra.distributed.api.Feature.NETWORK; +import static org.assertj.core.api.Assertions.assertThat; + +public class BM25DistributedTest extends TestBaseImpl +{ + private static final String CREATE_KEYSPACE = "CREATE KEYSPACE %%s WITH replication = {'class': 'SimpleStrategy', 'replication_factor': %d}"; + private static final String CREATE_TABLE = "CREATE TABLE %s (k int PRIMARY KEY, v text)"; + private static final String CREATE_INDEX = "CREATE CUSTOM INDEX ON %s(v) USING 'StorageAttachedIndex' WITH OPTIONS = {'index_analyzer': '{\"tokenizer\" : {\"name\" : \"standard\"}, \"filters\" : [{\"name\" : \"porterstem\"}]}'}"; + + // To get consistent results from BM25 we need to know which docs are evaluated, the easiest way + // to do that is to put all the docs on every replica + private static final int NUM_NODES = 3; + private static final int RF = 3; + + private static Cluster cluster; + private static String table; + + private static final AtomicInteger seq = new AtomicInteger(); + + @BeforeClass + public static void setupCluster() throws Exception + { + cluster = Cluster.build(NUM_NODES) + .withTokenCount(1) + .withDataDirCount(1) + .withConfig(config -> config.with(GOSSIP).with(NETWORK)) + .start(); + + cluster.schemaChange(withKeyspace(String.format(CREATE_KEYSPACE, RF))); + cluster.forEach(i -> i.runOnInstance(() -> org.apache.cassandra.index.sai.SAIUtil.setLatestVersion(Version.EC))); + } + + @AfterClass + public static void closeCluster() + { + if (cluster != null) + cluster.close(); + } + + @Before + public void before() + { + table = "table_" + seq.getAndIncrement(); + cluster.schemaChange(formatQuery(CREATE_TABLE)); + cluster.schemaChange(formatQuery(CREATE_INDEX)); + SAIUtil.waitForIndexQueryable(cluster, KEYSPACE); + } + + @Test + public void testTermFrequencyOrdering() + { + // Insert documents with varying frequencies of the term "apple" + execute("INSERT INTO %s (k, v) VALUES (1, 'apple')"); + execute("INSERT INTO %s (k, v) VALUES (2, 'apple apple')"); + execute("INSERT INTO %s (k, v) VALUES (3, 'apple apple apple')"); + + // Query memtable index + assertBM25Ordering(); + + // Flush and query on-disk index + cluster.forEach(n -> n.flush(KEYSPACE)); + assertBM25Ordering(); + } + + private void assertBM25Ordering() + { + Object[][] result = execute("SELECT k FROM %s ORDER BY v BM25 OF 'apple' LIMIT 3"); + assertThat(result).hasNumberOfRows(3); + + // Results should be ordered by term frequency (highest to lowest) + assertThat((Integer) result[0][0]).isEqualTo(3); // 3 occurrences + assertThat((Integer) result[1][0]).isEqualTo(2); // 2 occurrences + assertThat((Integer) result[2][0]).isEqualTo(1); // 1 occurrence + } + + private static Object[][] execute(String query) + { + return execute(query, ConsistencyLevel.QUORUM); + } + + private static Object[][] execute(String query, ConsistencyLevel consistencyLevel) + { + return cluster.coordinator(1).execute(formatQuery(query), consistencyLevel); + } + + private static String 
formatQuery(String query) + { + return String.format(query, KEYSPACE + '.' + table); + } +} diff --git a/test/unit/org/apache/cassandra/cql3/CQLTester.java b/test/unit/org/apache/cassandra/cql3/CQLTester.java index 38b937041caf..aec7238c9348 100644 --- a/test/unit/org/apache/cassandra/cql3/CQLTester.java +++ b/test/unit/org/apache/cassandra/cql3/CQLTester.java @@ -768,6 +768,11 @@ protected String currentIndex() return indexes.get(indexes.size() - 1); } + protected String getIndex(int i) + { + return indexes.get(i); + } + protected Collection currentTables() { if (tables == null || tables.isEmpty()) diff --git a/test/unit/org/apache/cassandra/cql3/validation/operations/SelectOrderByTest.java b/test/unit/org/apache/cassandra/cql3/validation/operations/SelectOrderByTest.java index dac86c76e5d8..34f7d606ce55 100644 --- a/test/unit/org/apache/cassandra/cql3/validation/operations/SelectOrderByTest.java +++ b/test/unit/org/apache/cassandra/cql3/validation/operations/SelectOrderByTest.java @@ -659,7 +659,7 @@ public void testAllowSkippingEqualityAndSingleValueInRestrictedClusteringColumns assertInvalidMessage("Cannot combine clustering column ordering with non-clustering column ordering", "SELECT * FROM %s WHERE a=? ORDER BY b ASC, c ASC, d ASC", 0); - String errorMsg = "Order by currently only supports the ordering of columns following their declared order in the PRIMARY KEY"; + String errorMsg = "Ordering by clustered columns must follow the declared order in the PRIMARY KEY"; assertRows(execute("SELECT * FROM %s WHERE a=? AND b=? ORDER BY c", 0, 0), row(0, 0, 0, 0), diff --git a/test/unit/org/apache/cassandra/index/sai/cql/BM25Test.java b/test/unit/org/apache/cassandra/index/sai/cql/BM25Test.java new file mode 100644 index 000000000000..3a40e075114a --- /dev/null +++ b/test/unit/org/apache/cassandra/index/sai/cql/BM25Test.java @@ -0,0 +1,510 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.cassandra.index.sai.cql; + +import org.junit.Before; +import org.junit.Test; + +import org.apache.cassandra.index.sai.SAITester; +import org.apache.cassandra.index.sai.SAIUtil; +import org.apache.cassandra.index.sai.disk.format.Version; +import org.apache.cassandra.index.sai.disk.v1.SegmentBuilder; +import org.apache.cassandra.index.sai.plan.QueryController; + +import static org.apache.cassandra.index.sai.analyzer.AnalyzerEqOperatorSupport.EQ_AMBIGUOUS_ERROR; +import static org.assertj.core.api.AssertionsForInterfaceTypes.assertThat; + +public class BM25Test extends SAITester +{ + @Before + public void setup() throws Throwable + { + SAIUtil.setLatestVersion(Version.EC); + } + + @Test + public void testTwoIndexes() + { + // create un-analyzed index + createTable("CREATE TABLE %s (k int PRIMARY KEY, v text)"); + createIndex("CREATE CUSTOM INDEX ON %s(v) USING 'org.apache.cassandra.index.sai.StorageAttachedIndex'"); + execute("INSERT INTO %s (k, v) VALUES (1, 'apple')"); + + // BM25 should fail with only an equality index + assertInvalidMessage("BM25 ordering on column v requires an analyzed index", + "SELECT k FROM %s WHERE v : 'apple' ORDER BY v BM25 OF 'apple' LIMIT 3"); + + // create analyzed index + analyzeIndex(); + // BM25 query should work now + var result = execute("SELECT k FROM %s WHERE v : 'apple' ORDER BY v BM25 OF 'apple' LIMIT 3"); + assertRows(result, row(1)); + } + + @Test + public void testTwoIndexesAmbiguousPredicate() throws Throwable + { + createTable("CREATE TABLE %s (k int PRIMARY KEY, v text)"); + + // Create analyzed and un-analyzed indexes + analyzeIndex(); + createIndex("CREATE CUSTOM INDEX ON %s(v) USING 'org.apache.cassandra.index.sai.StorageAttachedIndex'"); + + execute("INSERT INTO %s (k, v) VALUES (1, 'apple')"); + execute("INSERT INTO %s (k, v) VALUES (2, 'apple juice')"); + execute("INSERT INTO %s (k, v) VALUES (3, 'orange juice')"); + + // equality predicate is ambiguous (both analyzed and un-analyzed indexes could support it) so it should + // be rejected + beforeAndAfterFlush(() -> { + // Single predicate + assertInvalidMessage(String.format(EQ_AMBIGUOUS_ERROR, "v", getIndex(0), getIndex(1)), + "SELECT k FROM %s WHERE v = 'apple'"); + + // AND + assertInvalidMessage(String.format(EQ_AMBIGUOUS_ERROR, "v", getIndex(0), getIndex(1)), + "SELECT k FROM %s WHERE v = 'apple' AND v : 'juice'"); + + // OR + assertInvalidMessage(String.format(EQ_AMBIGUOUS_ERROR, "v", getIndex(0), getIndex(1)), + "SELECT k FROM %s WHERE v = 'apple' OR v : 'juice'"); + }); + } + + @Test + public void testTwoIndexesWithEqualsUnsupported() throws Throwable + { + createTable("CREATE TABLE %s (k int PRIMARY KEY, v text)"); + createIndex("CREATE CUSTOM INDEX ON %s(v) USING 'org.apache.cassandra.index.sai.StorageAttachedIndex'"); + // analyzed index with equals_behavior:unsupported option + createIndex("CREATE CUSTOM INDEX ON %s(v) USING 'org.apache.cassandra.index.sai.StorageAttachedIndex' " + + "WITH OPTIONS = { 'equals_behaviour_when_analyzed': 'unsupported', " + + "'index_analyzer':'{\"tokenizer\":{\"name\":\"standard\"},\"filters\":[{\"name\":\"porterstem\"}]}' }"); + + execute("INSERT INTO %s (k, v) VALUES (1, 'apple')"); + execute("INSERT INTO %s (k, v) VALUES (2, 'apple juice')"); + + beforeAndAfterFlush(() -> { + // combining two EQ predicates is not allowed + assertInvalid("SELECT k FROM %s WHERE v = 'apple' AND v = 'juice'"); + + // combining EQ and MATCH predicates is also not allowed (when we're not converting EQ to MATCH) + assertInvalid("SELECT k 
FROM %s WHERE v = 'apple' AND v : 'apple'"); + + // combining two MATCH predicates is fine + assertRows(execute("SELECT k FROM %s WHERE v : 'apple' AND v : 'juice'"), + row(2)); + + // = operator should use un-analyzed index since equals is unsupported in analyzed index + assertRows(execute("SELECT k FROM %s WHERE v = 'apple'"), + row(1)); + + // : operator should use analyzed index + assertRows(execute("SELECT k FROM %s WHERE v : 'apple'"), + row(1), row(2)); + }); + } + + @Test + public void testComplexQueriesWithMultipleIndexes() throws Throwable + { + createTable("CREATE TABLE %s (k int PRIMARY KEY, v1 text, v2 text, v3 int)"); + + // Create mix of analyzed, unanalyzed, and non-text indexes + createIndex("CREATE CUSTOM INDEX ON %s(v1) USING 'org.apache.cassandra.index.sai.StorageAttachedIndex'"); + createIndex("CREATE CUSTOM INDEX ON %s(v2) " + + "USING 'org.apache.cassandra.index.sai.StorageAttachedIndex' " + + "WITH OPTIONS = {" + + "'index_analyzer': '{" + + "\"tokenizer\" : {\"name\" : \"standard\"}, " + + "\"filters\" : [{\"name\" : \"porterstem\"}]" + + "}'" + + "}"); + createIndex("CREATE CUSTOM INDEX ON %s(v3) USING 'org.apache.cassandra.index.sai.StorageAttachedIndex'"); + + execute("INSERT INTO %s (k, v1, v2, v3) VALUES (1, 'apple', 'orange juice', 5)"); + execute("INSERT INTO %s (k, v1, v2, v3) VALUES (2, 'apple juice', 'apple', 10)"); + execute("INSERT INTO %s (k, v1, v2, v3) VALUES (3, 'banana', 'grape juice', 5)"); + + beforeAndAfterFlush(() -> { + // Complex query mixing different types of indexes and operators + assertRows(execute("SELECT k FROM %s WHERE v1 = 'apple' AND v2 : 'juice' AND v3 = 5"), + row(1)); + + // Mix of AND and OR conditions across different index types + assertRows(execute("SELECT k FROM %s WHERE v3 = 5 AND (v1 = 'apple' OR v2 : 'apple')"), + row(1)); + + // Multi-term analyzed query + assertRows(execute("SELECT k FROM %s WHERE v2 : 'orange juice'"), + row(1)); + + // Range query with text match + assertRows(execute("SELECT k FROM %s WHERE v3 >= 5 AND v2 : 'juice'"), + row(1), row(3)); + }); + } + + @Test + public void testMatchingAllowed() throws Throwable + { + // match operator should be allowed with BM25 on the same column + // (seems obvious but exercises a corner case in the internal RestrictionSet processing) + createSimpleTable(); + + execute("INSERT INTO %s (k, v) VALUES (1, 'apple')"); + + beforeAndAfterFlush(() -> + { + var result = execute("SELECT k FROM %s WHERE v : 'apple' ORDER BY v BM25 OF 'apple' LIMIT 3"); + assertRows(result, row(1)); + }); + } + + @Test + public void testUnknownQueryTerm() throws Throwable + { + createSimpleTable(); + + execute("INSERT INTO %s (k, v) VALUES (1, 'apple')"); + + beforeAndAfterFlush(() -> + { + var result = execute("SELECT k FROM %s ORDER BY v BM25 OF 'orange' LIMIT 1"); + assertEmpty(result); + }); + } + + @Test + public void testDuplicateQueryTerm() throws Throwable + { + createSimpleTable(); + + execute("INSERT INTO %s (k, v) VALUES (1, 'apple')"); + + beforeAndAfterFlush(() -> + { + var result = execute("SELECT k FROM %s ORDER BY v BM25 OF 'apple apple' LIMIT 1"); + assertRows(result, row(1)); + }); + } + + @Test + public void testEmptyQuery() throws Throwable + { + createSimpleTable(); + + execute("INSERT INTO %s (k, v) VALUES (1, 'apple')"); + + beforeAndAfterFlush(() -> + { + assertInvalidMessage("BM25 query must contain at least one term (perhaps your analyzer is discarding tokens you didn't expect)", + "SELECT k FROM %s ORDER BY v BM25 OF '+' LIMIT 1"); + }); + } + + @Test + public void 
testTermFrequencyOrdering() throws Throwable + { + createSimpleTable(); + + // Insert documents with varying frequencies of the term "apple" + execute("INSERT INTO %s (k, v) VALUES (1, 'apple')"); + execute("INSERT INTO %s (k, v) VALUES (2, 'apple apple')"); + execute("INSERT INTO %s (k, v) VALUES (3, 'apple apple apple')"); + + beforeAndAfterFlush(() -> + { + // Results should be ordered by term frequency (highest to lowest) + var result = execute("SELECT k FROM %s ORDER BY v BM25 OF 'apple' LIMIT 3"); + assertRows(result, + row(3), // 3 occurrences + row(2), // 2 occurrences + row(1)); // 1 occurrence + }); + } + + @Test + public void testTermFrequenciesWithOverwrites() throws Throwable + { + createSimpleTable(); + + // Insert documents with varying frequencies of the term "apple", but overwrite the first term + // This exercises the code that is supposed to reset frequency counts for overwrites + execute("INSERT INTO %s (k, v) VALUES (1, 'apple')"); + execute("INSERT INTO %s (k, v) VALUES (1, 'apple')"); + execute("INSERT INTO %s (k, v) VALUES (1, 'apple')"); + execute("INSERT INTO %s (k, v) VALUES (1, 'apple')"); + execute("INSERT INTO %s (k, v) VALUES (2, 'apple apple')"); + execute("INSERT INTO %s (k, v) VALUES (3, 'apple apple apple')"); + + beforeAndAfterFlush(() -> + { + // Results should be ordered by term frequency (highest to lowest) + var result = execute("SELECT k FROM %s ORDER BY v BM25 OF 'apple' LIMIT 3"); + assertRows(result, + row(3), // 3 occurrences + row(2), // 2 occurrences + row(1)); // 1 occurrence + }); + } + + @Test + public void testDocumentLength() throws Throwable + { + createSimpleTable(); + // Create documents with same term frequency but different lengths + execute("INSERT INTO %s (k, v) VALUES (1, 'test test')"); + execute("INSERT INTO %s (k, v) VALUES (2, 'test test other words here to make it longer')"); + execute("INSERT INTO %s (k, v) VALUES (3, 'test test extremely long document with many additional words to significantly increase the document length while maintaining the same term frequency for our target term')"); + + beforeAndAfterFlush(() -> + { + // Documents with same term frequency should be ordered by length (shorter first) + var result = execute("SELECT k FROM %s ORDER BY v BM25 OF 'test' LIMIT 3"); + assertRows(result, + row(1), + row(2), + row(3)); + }); + } + + @Test + public void testMultiTermQueryScoring() throws Throwable + { + createSimpleTable(); + // Two terms, but "apple" appears in fewer documents + execute("INSERT INTO %s (k, v) VALUES (1, 'apple banana')"); + execute("INSERT INTO %s (k, v) VALUES (2, 'apple apple banana')"); + execute("INSERT INTO %s (k, v) VALUES (3, 'apple banana banana')"); + execute("INSERT INTO %s (k, v) VALUES (4, 'apple apple banana banana')"); + execute("INSERT INTO %s (k, v) VALUES (5, 'banana banana')"); + + beforeAndAfterFlush(() -> + { + var result = execute("SELECT k FROM %s ORDER BY v BM25 OF 'apple banana' LIMIT 4"); + assertRows(result, + row(2), // Highest frequency of most important term + row(4), // More mentions of both terms + row(1), // One of each term + row(3)); // Low frequency of most important term + }); + } + + @Test + public void testIrrelevantRowsScoring() throws Throwable + { + createSimpleTable(); + // Insert pizza reviews with varying relevance to "crispy crust" + execute("INSERT INTO %s (k, v) VALUES (1, 'The pizza had a crispy crust and was delicious')"); // Basic mention + execute("INSERT INTO %s (k, v) VALUES (2, 'Very crispy crispy crust, perfectly cooked')"); // Emphasized 
crispy
+        execute("INSERT INTO %s (k, v) VALUES (3, 'The crust crust crust was okay, nothing special')"); // Only crust mentions
+        execute("INSERT INTO %s (k, v) VALUES (4, 'Super crispy crispy crust crust, best pizza ever!')"); // Most mentions of both
+        execute("INSERT INTO %s (k, v) VALUES (5, 'The toppings were good but the pizza was soggy')"); // Irrelevant review
+
+        beforeAndAfterFlush(this::assertIrrelevantRowsCorrect);
+    }
+
+    private void assertIrrelevantRowsCorrect()
+    {
+        var result = execute("SELECT k FROM %s ORDER BY v BM25 OF 'crispy crust' LIMIT 5");
+        assertRows(result,
+                   row(4),  // Highest frequency of both terms
+                   row(2),  // High frequency of 'crispy', one 'crust'
+                   row(1)); // One mention of each term
+        // Rows 3 and 5 are not returned because they do not contain all the query terms
+    }
+
+    @Test
+    public void testIrrelevantRowsWithCompaction()
+    {
+        // same dataset as testIrrelevantRowsScoring, but split across two sstables
+        createSimpleTable();
+        disableCompaction();
+
+        execute("INSERT INTO %s (k, v) VALUES (1, 'The pizza had a crispy crust and was delicious')"); // Basic mention
+        execute("INSERT INTO %s (k, v) VALUES (2, 'Very crispy crispy crust, perfectly cooked')"); // Emphasized crispy
+        flush();
+
+        execute("INSERT INTO %s (k, v) VALUES (3, 'The crust crust crust was okay, nothing special')"); // Only crust mentions
+        execute("INSERT INTO %s (k, v) VALUES (4, 'Super crispy crispy crust crust, best pizza ever!')"); // Most mentions of both
+        execute("INSERT INTO %s (k, v) VALUES (5, 'The toppings were good but the pizza was soggy')"); // Irrelevant review
+        flush();
+
+        assertIrrelevantRowsCorrect();
+
+        compact();
+        assertIrrelevantRowsCorrect();
+
+        // Force segmentation and requery
+        SegmentBuilder.updateLastValidSegmentRowId(2);
+        compact();
+        assertIrrelevantRowsCorrect();
+    }
+
+    private void createSimpleTable()
+    {
+        createTable("CREATE TABLE %s (k int PRIMARY KEY, v text)");
+        analyzeIndex();
+    }
+
+    private String analyzeIndex()
+    {
+        return createIndex("CREATE CUSTOM INDEX ON %s(v) " +
+                           "USING 'org.apache.cassandra.index.sai.StorageAttachedIndex' " +
+                           "WITH OPTIONS = {" +
+                           "'index_analyzer': '{" +
+                           "\"tokenizer\" : {\"name\" : \"standard\"}, " +
+                           "\"filters\" : [{\"name\" : \"porterstem\"}]" +
+                           "}'}"
+        );
+    }
+
+    @Test
+    public void testWithPredicate() throws Throwable
+    {
+        createTable("CREATE TABLE %s (k int PRIMARY KEY, p int, v text)");
+        analyzeIndex();
+        execute("CREATE CUSTOM INDEX ON %s(p) USING 'StorageAttachedIndex'");
+
+        // Insert documents with varying frequencies of the term "apple"
+        execute("INSERT INTO %s (k, p, v) VALUES (1, 5, 'apple')");
+        execute("INSERT INTO %s (k, p, v) VALUES (2, 5, 'apple apple')");
+        execute("INSERT INTO %s (k, p, v) VALUES (3, 5, 'apple apple apple')");
+        execute("INSERT INTO %s (k, p, v) VALUES (4, 6, 'apple apple apple')");
+        execute("INSERT INTO %s (k, p, v) VALUES (5, 7, 'apple apple apple')");
+
+        beforeAndAfterFlush(() ->
+        {
+            // Results should be ordered by term frequency (highest to lowest)
+            var result = execute("SELECT k FROM %s WHERE p = 5 ORDER BY v BM25 OF 'apple' LIMIT 3");
+            assertRows(result,
+                       row(3),  // 3 occurrences
+                       row(2),  // 2 occurrences
+                       row(1)); // 1 occurrence
+        });
+    }
+
+    @Test
+    public void testWidePartition() throws Throwable
+    {
+        createTable("CREATE TABLE %s (k1 int, k2 int, v text, PRIMARY KEY (k1, k2))");
+        analyzeIndex();
+
+        // Insert documents with varying frequencies of the term "apple"
+        execute("INSERT INTO %s (k1, k2, v) VALUES (0, 1, 'apple')");
+        execute("INSERT INTO %s (k1, k2, v) VALUES (0, 2,
'apple apple')"); + execute("INSERT INTO %s (k1, k2, v) VALUES (0, 3, 'apple apple apple')"); + + beforeAndAfterFlush(() -> + { + // Results should be ordered by term frequency (highest to lowest) + var result = execute("SELECT k2 FROM %s ORDER BY v BM25 OF 'apple' LIMIT 3"); + assertRows(result, + row(3), // 3 occurrences + row(2), // 2 occurrences + row(1)); // 1 occurrence + }); + } + + @Test + public void testWidePartitionWithPkPredicate() throws Throwable + { + createTable("CREATE TABLE %s (k1 int, k2 int, v text, PRIMARY KEY (k1, k2))"); + analyzeIndex(); + + // Insert documents with varying frequencies of the term "apple" + execute("INSERT INTO %s (k1, k2, v) VALUES (0, 1, 'apple')"); + execute("INSERT INTO %s (k1, k2, v) VALUES (0, 2, 'apple apple')"); + execute("INSERT INTO %s (k1, k2, v) VALUES (0, 3, 'apple apple apple')"); + execute("INSERT INTO %s (k1, k2, v) VALUES (1, 3, 'apple apple apple')"); + execute("INSERT INTO %s (k1, k2, v) VALUES (2, 3, 'apple apple apple')"); + + beforeAndAfterFlush(() -> + { + // Results should be ordered by term frequency (highest to lowest) + var result = execute("SELECT k2 FROM %s WHERE k1 = 0 ORDER BY v BM25 OF 'apple' LIMIT 3"); + assertRows(result, + row(3), // 3 occurrences + row(2), // 2 occurrences + row(1)); // 1 occurrence + }); + } + + @Test + public void testWidePartitionWithPredicate() throws Throwable + { + createTable("CREATE TABLE %s (k1 int, k2 int, p int, v text, PRIMARY KEY (k1, k2))"); + analyzeIndex(); + execute("CREATE CUSTOM INDEX ON %s(p) USING 'StorageAttachedIndex'"); + + // Insert documents with varying frequencies of the term "apple" + execute("INSERT INTO %s (k1, k2, p, v) VALUES (0, 1, 5, 'apple')"); + execute("INSERT INTO %s (k1, k2, p, v) VALUES (0, 2, 5, 'apple apple')"); + execute("INSERT INTO %s (k1, k2, p, v) VALUES (0, 3, 5, 'apple apple apple')"); + execute("INSERT INTO %s (k1, k2, p, v) VALUES (0, 4, 6, 'apple apple apple')"); + execute("INSERT INTO %s (k1, k2, p, v) VALUES (0, 5, 7, 'apple apple apple')"); + + beforeAndAfterFlush(() -> + { + // Results should be ordered by term frequency (highest to lowest) + var result = execute("SELECT k2 FROM %s WHERE p = 5 ORDER BY v BM25 OF 'apple' LIMIT 3"); + assertRows(result, + row(3), // 3 occurrences + row(2), // 2 occurrences + row(1)); // 1 occurrence + }); + } + + @Test + public void testWithPredicateSearchThenOrder() throws Throwable + { + QueryController.QUERY_OPT_LEVEL = 0; + testWithPredicate(); + } + + @Test + public void testWidePartitionWithPredicateOrderThenSearch() throws Throwable + { + QueryController.QUERY_OPT_LEVEL = 1; + testWidePartitionWithPredicate(); + } + + @Test + public void testQueryWithNulls() throws Throwable + { + createSimpleTable(); + + execute("INSERT INTO %s (k, v) VALUES (0, null)"); + execute("INSERT INTO %s (k, v) VALUES (1, 'test document')"); + beforeAndAfterFlush(() -> + { + var result = execute("SELECT k FROM %s ORDER BY v BM25 OF 'test' LIMIT 1"); + assertRows(result, row(1)); + }); + } + + @Test + public void testQueryEmptyTable() + { + createSimpleTable(); + var result = execute("SELECT k FROM %s ORDER BY v BM25 OF 'test' LIMIT 1"); + assertThat(result).hasSize(0); + } +} diff --git a/test/unit/org/apache/cassandra/index/sai/cql/MultipleColumnIndexTest.java b/test/unit/org/apache/cassandra/index/sai/cql/MultipleColumnIndexTest.java index 5364717a8906..8d952dc1f4f0 100644 --- a/test/unit/org/apache/cassandra/index/sai/cql/MultipleColumnIndexTest.java +++ 
@@ -40,12 +40,30 @@ public void canCreateMultipleMapIndexesOnSameColumn() throws Throwable
 }
 
 @Test
-    public void cannotHaveMultipleLiteralIndexesWithDifferentOptions() throws Throwable
+    public void canHaveAnalyzedAndUnanalyzedIndexesOnSameColumn() throws Throwable
 {
-        createTable("CREATE TABLE %s (pk int, ck int, value text, PRIMARY KEY(pk, ck))");
+        createTable("CREATE TABLE %s (pk int, value text, PRIMARY KEY(pk))");
 createIndex("CREATE CUSTOM INDEX ON %s(value) USING 'StorageAttachedIndex' WITH OPTIONS = { 'case_sensitive' : true }");
-        assertThatThrownBy(() -> createIndex("CREATE CUSTOM INDEX ON %s(value) USING 'StorageAttachedIndex' WITH OPTIONS = { 'case_sensitive' : false }"))
-            .isInstanceOf(InvalidRequestException.class);
+        createIndex("CREATE CUSTOM INDEX ON %s(value) USING 'StorageAttachedIndex' WITH OPTIONS = { 'case_sensitive' : false, 'equals_behaviour_when_analyzed': 'unsupported' }");
+
+        execute("INSERT INTO %s (pk, value) VALUES (?, ?)", 1, "a");
+        execute("INSERT INTO %s (pk, value) VALUES (?, ?)", 2, "A");
+        beforeAndAfterFlush(() -> {
+            assertRows(execute("SELECT pk FROM %s WHERE value = 'a'"),
+                       row(1));
+            assertRows(execute("SELECT pk FROM %s WHERE value : 'a'"),
+                       row(1),
+                       row(2));
+        });
+    }
+
+    @Test
+    public void cannotHaveMultipleAnalyzingIndexesOnSameColumn() throws Throwable
+    {
+        createTable("CREATE TABLE %s (pk int, ck int, value text, PRIMARY KEY(pk, ck))");
+        createIndex("CREATE CUSTOM INDEX ON %s(value) USING 'StorageAttachedIndex' WITH OPTIONS = { 'case_sensitive' : false }");
+        assertThatThrownBy(() -> createIndex("CREATE CUSTOM INDEX ON %s(value) USING 'StorageAttachedIndex' WITH OPTIONS = { 'normalize' : true }"))
+            .isInstanceOf(InvalidRequestException.class);
 }
 
 @Test
diff --git a/test/unit/org/apache/cassandra/index/sai/cql/NativeIndexDDLTest.java b/test/unit/org/apache/cassandra/index/sai/cql/NativeIndexDDLTest.java
index ebda5a426430..b7b5a078b2c0 100644
--- a/test/unit/org/apache/cassandra/index/sai/cql/NativeIndexDDLTest.java
+++ b/test/unit/org/apache/cassandra/index/sai/cql/NativeIndexDDLTest.java
@@ -581,7 +581,7 @@ public void shouldFailCreationMultipleIndexesOnSimpleColumn()
 // different name, different option, same target.
 assertThatThrownBy(() -> executeNet("CREATE CUSTOM INDEX ON %s(v1) USING 'StorageAttachedIndex' WITH OPTIONS = { 'case_sensitive' : true }"))
 .isInstanceOf(InvalidQueryException.class)
-        .hasMessageContaining("Cannot create more than one storage-attached index on the same column: v1" );
+        .hasMessageContaining("Cannot create duplicate storage-attached index on column: v1" );
 
 ResultSet rows = executeNet("SELECT id FROM %s WHERE v1 = '1'");
 assertEquals(1, rows.all().size());
diff --git a/test/unit/org/apache/cassandra/index/sai/disk/RAMPostingSlicesTest.java b/test/unit/org/apache/cassandra/index/sai/disk/RAMPostingSlicesTest.java
index 8d49798d906b..d6ca88d935ee 100644
--- a/test/unit/org/apache/cassandra/index/sai/disk/RAMPostingSlicesTest.java
+++ b/test/unit/org/apache/cassandra/index/sai/disk/RAMPostingSlicesTest.java
@@ -19,6 +19,8 @@
 import java.util.Arrays;
+import java.util.HashMap;
+import java.util.Map;
 
 import org.junit.Test;
@@ -31,7 +33,7 @@ public class RAMPostingSlicesTest extends SaiRandomizedTest
 @Test
 public void testRAMPostingSlices() throws Exception
 {
-        RAMPostingSlices slices = new RAMPostingSlices(Counter.newCounter());
+        RAMPostingSlices slices = new RAMPostingSlices(Counter.newCounter(), false);
 
 int[] segmentRowIdUpto = new int[1024];
 Arrays.fill(segmentRowIdUpto, -1);
@@ -56,7 +58,7 @@ public void testRAMPostingSlices() throws Exception
 
 bitSets[termID].set(segmentRowIdUpto[termID]);
 
-        slices.writeVInt(termID, segmentRowIdUpto[termID]);
+        slices.writePosting(termID, segmentRowIdUpto[termID], 1);
 }
 
 for (int termID = 0; termID < segmentRowIdUpto.length; termID++)
@@ -74,4 +76,36 @@ public void testRAMPostingSlices() throws Exception
 assertEquals(segmentRowId, segmentRowIdUpto[termID]);
 }
 }
+
+    @Test
+    public void testRAMPostingSlicesWithFrequencies() throws Exception {
+        RAMPostingSlices slices = new RAMPostingSlices(Counter.newCounter(), true);
+
+        // Test with just 3 terms and known frequencies
+        for (int termId = 0; termId < 3; termId++) {
+            slices.createNewSlice(termId);
+
+            // Write a sequence of rows with different frequencies for each term
+            slices.writePosting(termId, 5, 1); // first posting at row 5
+            slices.writePosting(termId, 3, 2); // next at row 8 (delta=3)
+            slices.writePosting(termId, 2, 3); // next at row 10 (delta=2)
+        }
+
+        // Verify each term's postings
+        for (int termId = 0; termId < 3; termId++) {
+            ByteSliceReader reader = new ByteSliceReader();
+            PostingList postings = slices.postingList(termId, reader, 10);
+
+            assertEquals(5, postings.nextPosting());
+            assertEquals(1, postings.frequency());
+
+            assertEquals(8, postings.nextPosting());
+            assertEquals(2, postings.frequency());
+
+            assertEquals(10, postings.nextPosting());
+            assertEquals(3, postings.frequency());
+
+            assertEquals(PostingList.END_OF_STREAM, postings.nextPosting());
+        }
+    }
 }
diff --git a/test/unit/org/apache/cassandra/index/sai/disk/RAMStringIndexerTest.java b/test/unit/org/apache/cassandra/index/sai/disk/RAMStringIndexerTest.java
index 9ee1605512cf..15d47013f946 100644
--- a/test/unit/org/apache/cassandra/index/sai/disk/RAMStringIndexerTest.java
+++ b/test/unit/org/apache/cassandra/index/sai/disk/RAMStringIndexerTest.java
@@ -21,11 +21,15 @@
 import java.nio.ByteBuffer;
 import java.util.ArrayList;
 import java.util.Arrays;
+import java.util.HashMap;
 import java.util.List;
+import java.util.Map;
 
 import org.junit.Assert;
 import org.junit.Test;
 
+import org.apache.cassandra.index.sai.disk.format.IndexComponent;
+import org.apache.cassandra.index.sai.disk.v1.trie.InvertedIndexWriter;
 import org.apache.cassandra.index.sai.utils.SaiRandomizedTest;
 import org.apache.cassandra.index.sai.utils.TypeUtil;
 import org.apache.cassandra.utils.ByteBufferUtil;
@@ -40,13 +44,13 @@ public class RAMStringIndexerTest extends SaiRandomizedTest
 @Test
 public void test() throws Exception
 {
-        RAMStringIndexer indexer = new RAMStringIndexer();
+        RAMStringIndexer indexer = new RAMStringIndexer(false);
 
-        indexer.add(new BytesRef("0"), 100);
-        indexer.add(new BytesRef("2"), 102);
-        indexer.add(new BytesRef("0"), 200);
-        indexer.add(new BytesRef("2"), 202);
-        indexer.add(new BytesRef("2"), 302);
+        indexer.addAll(List.of(new BytesRef("0")), 100);
+        indexer.addAll(List.of(new BytesRef("2")), 102);
+        indexer.addAll(List.of(new BytesRef("0")), 200);
+        indexer.addAll(List.of(new BytesRef("2")), 202);
+        indexer.addAll(List.of(new BytesRef("2")), 302);
 
 List<List<Long>> matches = new ArrayList<>();
 matches.add(Arrays.asList(100L, 200L));
@@ -75,10 +79,48 @@ public void test() throws Exception
 }
 }
 
+    @Test
+    public void testWithFrequencies() throws Exception
+    {
+        RAMStringIndexer indexer = new RAMStringIndexer(true);
+
+        // Add same term twice in same row to increment frequency
+        indexer.addAll(List.of(new BytesRef("A"), new BytesRef("A")), 100);
+        indexer.addAll(List.of(new BytesRef("B")), 102);
+        indexer.addAll(List.of(new BytesRef("A"), new BytesRef("A"), new BytesRef("A")), 200);
+        indexer.addAll(List.of(new BytesRef("B"), new BytesRef("B")), 202);
+        indexer.addAll(List.of(new BytesRef("B")), 302);
+
+        // Expected results: rowID -> frequency
+        List<Map<Long, Integer>> matches = Arrays.asList(Map.of(100L, 2, 200L, 3),
+                                                         Map.of(102L, 1, 202L, 2, 302L, 1));
+
+        try (TermsIterator terms = indexer.getTermsWithPostings(ByteBufferUtil.bytes("A"), ByteBufferUtil.bytes("B"), TypeUtil.BYTE_COMPARABLE_VERSION))
+        {
+            int ord = 0;
+            while (terms.hasNext())
+            {
+                terms.next();
+                try (PostingList postings = terms.postings())
+                {
+                    Map<Long, Integer> results = new HashMap<>();
+                    long segmentRowId;
+                    while ((segmentRowId = postings.nextPosting()) != PostingList.END_OF_STREAM)
+                    {
+                        results.put(segmentRowId, postings.frequency());
+                    }
+                    assertEquals(matches.get(ord++), results);
+                }
+            }
+            assertArrayEquals("A".getBytes(), terms.getMinTerm().array());
+            assertArrayEquals("B".getBytes(), terms.getMaxTerm().array());
+        }
+    }
+
 @Test
 public void testLargeSegment() throws IOException
 {
-        final RAMStringIndexer indexer = new RAMStringIndexer();
+        final RAMStringIndexer indexer = new RAMStringIndexer(false);
 
 final int numTerms = between(1 << 10, 1 << 13);
 final int numPostings = between(1 << 5, 1 << 10);
@@ -87,7 +129,7 @@ public void testLargeSegment() throws IOException
 final BytesRef term = new BytesRef(String.format("%04d", id));
 for (int posting = 0; posting < numPostings; ++posting)
 {
-            indexer.add(term, posting);
+            indexer.addAll(List.of(term), posting);
 }
 }
 
@@ -124,14 +166,14 @@ public void testRequiresFlush()
 {
 RAMStringIndexer.MAX_BLOCK_BYTE_POOL_SIZE = 1024 * 1024 * 100;
 
 // primary behavior we're testing is that exceptions aren't thrown due to overflowing backing structures
-        RAMStringIndexer indexer = new RAMStringIndexer();
+        RAMStringIndexer indexer = new RAMStringIndexer(false);
 
 Assert.assertFalse(indexer.requiresFlush());
 for (int i = 0; i < Integer.MAX_VALUE; i++)
 {
 if (indexer.requiresFlush())
 break;
-            indexer.add(new BytesRef(String.format("%5000d", i)), i);
+            indexer.addAll(List.of(new BytesRef(String.format("%5000d", i))), i);
 }
 
 // If we don't require a flush before MAX_VALUE, the implementation of RAMStringIndexer has sufficiently
 // changed to warrant changes to the test.
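For the arithmetic asserted in the two tests above: postings are stored as deltas between consecutive row ids, and, when frequencies are enabled, each delta is followed by the term's frequency in that row. A minimal sketch of that pairing, assuming flat int arrays rather than the byte-slice pools RAMPostingSlices actually uses:

    // Delta-encode row ids and interleave per-row term frequencies.
    // A sketch of the layout, not the real byte-pool implementation.
    static int[] encodePostings(int[] rowIds, int[] frequencies)
    {
        int[] encoded = new int[rowIds.length * 2];
        int previous = 0;
        for (int i = 0; i < rowIds.length; i++)
        {
            encoded[2 * i] = rowIds[i] - previous; // delta to the previous row id
            encoded[2 * i + 1] = frequencies[i];   // occurrences of the term in this row
            previous = rowIds[i];
        }
        return encoded;
    }

Decoding accumulates the deltas, so writing 5, 3, 2 yields row ids 5, 8, 10 with frequencies 1, 2, 3, which is exactly what testRAMPostingSlicesWithFrequencies asserts.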
diff --git a/test/unit/org/apache/cassandra/index/sai/disk/v1/InvertedIndexBuilder.java b/test/unit/org/apache/cassandra/index/sai/disk/v1/InvertedIndexBuilder.java
index d5699c195305..db2f61370754 100644
--- a/test/unit/org/apache/cassandra/index/sai/disk/v1/InvertedIndexBuilder.java
+++ b/test/unit/org/apache/cassandra/index/sai/disk/v1/InvertedIndexBuilder.java
@@ -19,15 +19,18 @@
 import java.nio.ByteBuffer;
 import java.util.ArrayList;
+import java.util.Arrays;
 import java.util.List;
 import java.util.function.IntSupplier;
 import java.util.function.Supplier;
+import java.util.stream.Collectors;
 import java.util.stream.IntStream;
 import java.util.stream.Stream;
 
 import com.carrotsearch.hppc.IntArrayList;
 import org.apache.cassandra.db.marshal.UTF8Type;
 import org.apache.cassandra.index.sai.disk.format.Version;
+import org.apache.cassandra.index.sai.memory.RowMapping;
 import org.apache.cassandra.utils.Pair;
 import org.apache.cassandra.utils.bytecomparable.ByteComparable;
@@ -83,10 +86,15 @@ public static class TermsEnum
 this.byteComparableBytes = byteComparableBytes;
 this.postings = postings;
 }
+    }
 
-        public Pair<ByteComparable, IntArrayList> toPair()
-        {
-            return Pair.create(byteComparableBytes, postings);
-        }
+    /**
+     * Adds default frequency of 1 to postings
+     */
+    static Pair<ByteComparable, List<RowMapping.RowIdWithFrequency>> toTermWithFrequency(TermsEnum te)
+    {
+        return Pair.create(te.byteComparableBytes, Arrays.stream(te.postings.toArray()).boxed()
+                                                         .map(p -> new RowMapping.RowIdWithFrequency(p, 1))
+                                                         .collect(Collectors.toList()));
 }
 }
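toTermWithFrequency assigns every posting a frequency of 1 because these builders index each term at most once per row; only analyzed columns can see a term several times in one document. The same wrapping reappears below in KDTreeIndexBuilder. A sketch of the idea in isolation (RowIdWithFrequency is the patch's own pair type; the helper name here is hypothetical):

    // Wrap bare row ids with the default frequency of 1, mirroring toTermWithFrequency.
    static List<RowMapping.RowIdWithFrequency> withUnitFrequency(IntArrayList rowIds)
    {
        List<RowMapping.RowIdWithFrequency> postings = new ArrayList<>(rowIds.size());
        for (int i = 0; i < rowIds.size(); i++)
            postings.add(new RowMapping.RowIdWithFrequency(rowIds.get(i), 1));
        return postings;
    }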
diff --git a/test/unit/org/apache/cassandra/index/sai/disk/v1/InvertedIndexSearcherTest.java b/test/unit/org/apache/cassandra/index/sai/disk/v1/InvertedIndexSearcherTest.java
index 47565eb28739..a15b11f6abf3 100644
--- a/test/unit/org/apache/cassandra/index/sai/disk/v1/InvertedIndexSearcherTest.java
+++ b/test/unit/org/apache/cassandra/index/sai/disk/v1/InvertedIndexSearcherTest.java
@@ -23,6 +23,7 @@
 import java.util.List;
 import java.util.stream.Collectors;
 
+import org.agrona.collections.Int2IntHashMap;
 import org.junit.BeforeClass;
 import org.junit.Test;
@@ -101,12 +102,12 @@ private void doTestEqQueriesAgainstStringIndex(Version version) throws Exception
 final int numTerms = randomIntBetween(64, 512), numPostings = randomIntBetween(256, 1024);
 final List<InvertedIndexBuilder.TermsEnum> termsEnum = buildTermsEnum(version, numTerms, numPostings);
 
-        try (IndexSearcher searcher = buildIndexAndOpenSearcher(numTerms, numPostings, termsEnum))
+        try (IndexSearcher searcher = buildIndexAndOpenSearcher(numTerms, termsEnum))
 {
 for (int t = 0; t < numTerms; ++t)
 {
 try (KeyRangeIterator results = searcher.search(new Expression(indexContext)
-                                                .add(Operator.EQ, termsEnum.get(t).originalTermBytes), null, new QueryContext(), false, LIMIT))
+                                                .add(Operator.EQ, termsEnum.get(t).originalTermBytes), null, new QueryContext(), false))
 {
 assertTrue(results.hasNext());
@@ -121,7 +122,7 @@ private void doTestEqQueriesAgainstStringIndex(Version version) throws Exception
 }
 
 try (KeyRangeIterator results = searcher.search(new Expression(indexContext)
-                                                .add(Operator.EQ, termsEnum.get(t).originalTermBytes), null, new QueryContext(), false, LIMIT))
+                                                .add(Operator.EQ, termsEnum.get(t).originalTermBytes), null, new QueryContext(), false))
 {
 assertTrue(results.hasNext());
@@ -143,12 +144,12 @@ private void doTestEqQueriesAgainstStringIndex(Version version) throws Exception
 // try searching for terms that weren't indexed
 final String tooLongTerm = randomSimpleString(10, 12);
 KeyRangeIterator results = searcher.search(new Expression(indexContext)
-                                           .add(Operator.EQ, UTF8Type.instance.decompose(tooLongTerm)), null, new QueryContext(), false, LIMIT);
+                                           .add(Operator.EQ, UTF8Type.instance.decompose(tooLongTerm)), null, new QueryContext(), false);
 assertFalse(results.hasNext());
 
 final String tooShortTerm = randomSimpleString(1, 2);
 results = searcher.search(new Expression(indexContext)
-                          .add(Operator.EQ, UTF8Type.instance.decompose(tooShortTerm)), null, new QueryContext(), false, LIMIT);
+                          .add(Operator.EQ, UTF8Type.instance.decompose(tooShortTerm)), null, new QueryContext(), false);
 assertFalse(results.hasNext());
 }
 }
@@ -159,10 +160,10 @@ public void testUnsupportedOperator() throws Exception
 final int numTerms = randomIntBetween(5, 15), numPostings = randomIntBetween(5, 20);
 final List<InvertedIndexBuilder.TermsEnum> termsEnum = buildTermsEnum(version, numTerms, numPostings);
 
-        try (IndexSearcher searcher = buildIndexAndOpenSearcher(numTerms, numPostings, termsEnum))
+        try (IndexSearcher searcher = buildIndexAndOpenSearcher(numTerms, termsEnum))
 {
 searcher.search(new Expression(indexContext)
-                .add(Operator.NEQ, UTF8Type.instance.decompose("a")), null, new QueryContext(), false, LIMIT);
+                .add(Operator.NEQ, UTF8Type.instance.decompose("a")), null, new QueryContext(), false);
 
 fail("Expect IllegalArgumentException thrown, but didn't");
 }
@@ -172,9 +173,8 @@ public void testUnsupportedOperator() throws Exception
 }
 }
 
-    private IndexSearcher buildIndexAndOpenSearcher(int terms, int postings, List<InvertedIndexBuilder.TermsEnum> termsEnum) throws IOException
+    private IndexSearcher buildIndexAndOpenSearcher(int terms, List<InvertedIndexBuilder.TermsEnum> termsEnum) throws IOException
 {
-        final int size = terms * postings;
 final IndexDescriptor indexDescriptor = newIndexDescriptor();
 final String index = newIndex();
 final IndexContext indexContext = SAITester.createIndexContext(index, UTF8Type.instance);
@@ -189,9 +189,10 @@ private IndexSearcher buildIndexAndOpenSearcher(int terms, int postings, List<InvertedIndexBuilder.TermsEnum> termsEnum)
-            var iter = termsEnum.stream().map(InvertedIndexBuilder.TermsEnum::toPair).iterator();
-            indexMetas = writer.writeAll(new MemtableTermsIterator(null, null, iter));
+            var iter = termsEnum.stream().map(InvertedIndexBuilder::toTermWithFrequency).iterator();
+            Int2IntHashMap docLengths = createMockDocLengths(termsEnum);
+            indexMetas = writer.writeAll(new MemtableTermsIterator(null, null, iter), docLengths);
 
 private List<InvertedIndexBuilder.TermsEnum> buildTermsEnum(Version version, int terms, int postings)
 {
 return InvertedIndexBuilder.buildStringTermsEnum(version, terms, postings, () -> randomSimpleString(3, 5), () -> nextInt(0, Integer.MAX_VALUE));
 }
 
+    private Int2IntHashMap createMockDocLengths(List<InvertedIndexBuilder.TermsEnum> termsEnum)
+    {
+        Int2IntHashMap docLengths = new Int2IntHashMap(Integer.MIN_VALUE);
+        for (InvertedIndexBuilder.TermsEnum term : termsEnum)
+        {
+            for (var cursor : term.postings)
+                docLengths.put(cursor.value, 1);
+        }
+        return docLengths;
+    }
+
 private ByteBuffer wrap(ByteComparable bc)
 {
 return ByteBuffer.wrap(ByteSourceInverse.readBytes(bc.asComparableBytes(TypeUtil.BYTE_COMPARABLE_VERSION)));
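createMockDocLengths exists because the frequency-aware postings format also records each document's token count, which BM25 needs for length normalization; a constant length of 1 makes the |D|/avgdl ratio 1 for every document, so the mock lengths never perturb the expected orderings. A hedged sketch of the average that normalization divides by, over the same agrona Int2IntHashMap (the on-disk counterpart in this patch is DocLengthsReader):

    // Average document length from a doc-length map; sketch only.
    static double averageDocLength(Int2IntHashMap docLengths)
    {
        if (docLengths.isEmpty())
            return 0;
        long total = 0;
        for (int length : docLengths.values())
            total += length;
        return (double) total / docLengths.size();
    }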
diff --git a/test/unit/org/apache/cassandra/index/sai/disk/v1/KDTreeIndexSearcherTest.java b/test/unit/org/apache/cassandra/index/sai/disk/v1/KDTreeIndexSearcherTest.java
index 8f34b9a808aa..65242e319a37 100644
--- a/test/unit/org/apache/cassandra/index/sai/disk/v1/KDTreeIndexSearcherTest.java
+++ b/test/unit/org/apache/cassandra/index/sai/disk/v1/KDTreeIndexSearcherTest.java
@@ -151,7 +151,7 @@ public void testUnsupportedOperator() throws Exception
 {{
 operation = Op.NOT_EQ;
 lower = upper = new Bound(ShortType.instance.decompose((short) 0), Int32Type.instance, true);
-        }}, null, new QueryContext(), false, LIMIT);
+        }}, null, new QueryContext(), false);
 
 fail("Expect IllegalArgumentException thrown, but didn't");
 }
@@ -169,7 +169,7 @@ private void testEqQueries(final IndexSearcher indexSearcher,
 {{
 operation = Op.EQ;
 lower = upper = new Bound(rawType.decompose(rawValueProducer.apply(EQ_TEST_LOWER_BOUND_INCLUSIVE)), encodedType, true);
-        }}, null, new QueryContext(), false, LIMIT))
+        }}, null, new QueryContext(), false))
 {
 assertTrue(results.hasNext());
@@ -180,7 +180,7 @@ private void testEqQueries(final IndexSearcher indexSearcher,
 {{
 operation = Op.EQ;
 lower = upper = new Bound(rawType.decompose(rawValueProducer.apply(EQ_TEST_UPPER_BOUND_EXCLUSIVE)), encodedType, true);
-        }}, null, new QueryContext(), false, LIMIT))
+        }}, null, new QueryContext(), false))
 {
 assertFalse(results.hasNext());
 indexSearcher.close();
@@ -206,7 +206,7 @@ private void testRangeQueries(final IndexSearcher indexSearch
 lower = new Bound(rawType.decompose(rawValueProducer.apply((short)2)), encodedType, false);
 upper = new Bound(rawType.decompose(rawValueProducer.apply((short)7)), encodedType, true);
-        }}, null, new QueryContext(), false, LIMIT))
+        }}, null, new QueryContext(), false))
 {
 assertTrue(results.hasNext());
@@ -218,7 +218,7 @@ private void testRangeQueries(final IndexSearcher indexSearch
 {{
 operation = Op.RANGE;
 lower = new Bound(rawType.decompose(rawValueProducer.apply(RANGE_TEST_UPPER_BOUND_EXCLUSIVE)), encodedType, true);
-        }}, null, new QueryContext(), false, LIMIT))
+        }}, null, new QueryContext(), false))
 {
 assertFalse(results.hasNext());
 }
@@ -227,7 +227,7 @@ private void testRangeQueries(final IndexSearcher indexSearch
 {{
 operation = Op.RANGE;
 upper = new Bound(rawType.decompose(rawValueProducer.apply(RANGE_TEST_LOWER_BOUND_INCLUSIVE)), encodedType, false);
-        }}, null, new QueryContext(), false, LIMIT))
+        }}, null, new QueryContext(), false))
 {
 assertFalse(results.hasNext());
 indexSearcher.close();
diff --git a/test/unit/org/apache/cassandra/index/sai/disk/v1/TermsReaderTest.java b/test/unit/org/apache/cassandra/index/sai/disk/v1/TermsReaderTest.java
index bdfbf529e86b..97277da2490a 100644
--- a/test/unit/org/apache/cassandra/index/sai/disk/v1/TermsReaderTest.java
+++ b/test/unit/org/apache/cassandra/index/sai/disk/v1/TermsReaderTest.java
@@ -18,42 +18,49 @@
 package org.apache.cassandra.index.sai.disk.v1;
 
 import java.io.IOException;
+import java.util.ArrayList;
 import java.util.Collection;
+import java.util.HashMap;
 import java.util.List;
+import java.util.Map;
 import java.util.stream.Collectors;
 
 import org.junit.Test;
 
-import com.carrotsearch.hppc.IntArrayList;
 import com.carrotsearch.randomizedtesting.annotations.ParametersFactory;
+import org.agrona.collections.Int2IntHashMap;
 import org.apache.cassandra.config.DatabaseDescriptor;
 import org.apache.cassandra.db.marshal.UTF8Type;
 import org.apache.cassandra.index.sai.IndexContext;
 import org.apache.cassandra.index.sai.QueryContext;
 import org.apache.cassandra.index.sai.SAITester;
+import org.apache.cassandra.index.sai.SAIUtil;
 import org.apache.cassandra.index.sai.disk.MemtableTermsIterator;
 import org.apache.cassandra.index.sai.disk.PostingList;
+import org.apache.cassandra.index.sai.disk.RAMStringIndexer;
 import org.apache.cassandra.index.sai.disk.TermsIterator;
 import org.apache.cassandra.index.sai.disk.format.IndexComponentType;
 import org.apache.cassandra.index.sai.disk.format.IndexComponents;
 import org.apache.cassandra.index.sai.disk.format.IndexDescriptor;
 import org.apache.cassandra.index.sai.disk.format.Version;
 import org.apache.cassandra.index.sai.disk.v1.trie.InvertedIndexWriter;
+import org.apache.cassandra.index.sai.memory.RowMapping;
 import org.apache.cassandra.index.sai.metrics.QueryEventListener;
 import org.apache.cassandra.index.sai.utils.SAICodecUtils;
 import org.apache.cassandra.index.sai.utils.SaiRandomizedTest;
 import org.apache.cassandra.index.sai.utils.TypeUtil;
 import org.apache.cassandra.io.util.FileHandle;
+import org.apache.cassandra.utils.ByteBufferUtil;
 import org.apache.cassandra.utils.Pair;
 import org.apache.cassandra.utils.bytecomparable.ByteComparable;
 import org.apache.cassandra.utils.bytecomparable.ByteSourceInverse;
+import org.apache.lucene.util.BytesRef;
 
 import static org.apache.cassandra.index.sai.disk.v1.InvertedIndexBuilder.buildStringTermsEnum;
 import static org.apache.cassandra.index.sai.metrics.QueryEventListeners.NO_OP_TRIE_LISTENER;
 
 public class TermsReaderTest extends SaiRandomizedTest
 {
-    public static final ByteComparable.Version VERSION = TypeUtil.BYTE_COMPARABLE_VERSION;
 
 @ParametersFactory()
@@ -102,8 +109,11 @@ private void doTestTermsIteration(Version version) throws IOException
 IndexComponents.ForWrite components = indexDescriptor.newPerIndexComponentsForWrite(indexContext);
 try (InvertedIndexWriter writer = new InvertedIndexWriter(components))
 {
-            var iter = termsEnum.stream().map(InvertedIndexBuilder.TermsEnum::toPair).iterator();
-            indexMetas = writer.writeAll(new MemtableTermsIterator(null, null, iter));
+            var iter = termsEnum.stream()
+                                .map(InvertedIndexBuilder::toTermWithFrequency)
+                                .iterator();
+            Int2IntHashMap docLengths = createMockDocLengths(termsEnum);
+            indexMetas = writer.writeAll(new MemtableTermsIterator(null, null, iter), docLengths);
 }
 
 FileHandle termsData = components.get(IndexComponentType.TERMS_DATA).createFileHandle();
@@ -142,8 +152,11 @@ private void testTermQueries(Version version, int numTerms, int numPostings) throws IOException
 IndexComponents.ForWrite components = indexDescriptor.newPerIndexComponentsForWrite(indexContext);
 try (InvertedIndexWriter writer = new InvertedIndexWriter(components))
 {
-            var iter = termsEnum.stream().map(InvertedIndexBuilder.TermsEnum::toPair).iterator();
-            indexMetas = writer.writeAll(new MemtableTermsIterator(null, null, iter));
+            var iter = termsEnum.stream()
+                                .map(InvertedIndexBuilder::toTermWithFrequency)
+                                .iterator();
+            Int2IntHashMap docLengths = createMockDocLengths(termsEnum);
+            indexMetas = writer.writeAll(new MemtableTermsIterator(null, null, iter), docLengths);
 }
 
 FileHandle termsData = components.get(IndexComponentType.TERMS_DATA).createFileHandle();
@@ -159,22 +172,24 @@ private void testTermQueries(Version version, int numTerms, int numPostings) throws IOException
 termsFooterPointer,
 version))
 {
-            var iter = termsEnum.stream().map(InvertedIndexBuilder.TermsEnum::toPair).collect(Collectors.toList());
-            for (Pair<ByteComparable, IntArrayList> pair : iter)
+            var iter = termsEnum.stream()
+                                .map(InvertedIndexBuilder::toTermWithFrequency)
+                                .collect(Collectors.toList());
+            for (Pair<ByteComparable, List<RowMapping.RowIdWithFrequency>> pair : iter)
 {
 final byte[] bytes = ByteSourceInverse.readBytes(pair.left.asComparableBytes(VERSION));
 try (PostingList actualPostingList = reader.exactMatch(ByteComparable.preencoded(VERSION, bytes),
 (QueryEventListener.TrieIndexEventListener)NO_OP_TRIE_LISTENER,
 new QueryContext()))
 {
-                final IntArrayList expectedPostingList = pair.right;
+                final List<RowMapping.RowIdWithFrequency> expectedPostingList = pair.right;
 
 assertNotNull(actualPostingList);
 assertEquals(expectedPostingList.size(), actualPostingList.size());
 
 for (int i = 0; i < expectedPostingList.size(); ++i)
 {
-                    final long expectedRowID = expectedPostingList.get(i);
+                    final long expectedRowID = expectedPostingList.get(i).rowId;
 long result = actualPostingList.nextPosting();
 assertEquals(String.format("row %d mismatch of %d in enum %d", i, expectedPostingList.size(), termsEnum.indexOf(pair)), expectedRowID, result);
 }
@@ -188,18 +203,18 @@ private void testTermQueries(Version version, int numTerms, int numPostings) throws IOException
 (QueryEventListener.TrieIndexEventListener)NO_OP_TRIE_LISTENER,
 new QueryContext()))
 {
-                final IntArrayList expectedPostingList = pair.right;
+                final List<RowMapping.RowIdWithFrequency> expectedPostingList = pair.right;
 
 // test skipping to the last block
 final int idxToSkip = numPostings - 2;
 
 // tokens are equal to their corresponding row IDs
-                final int tokenToSkip = expectedPostingList.get(idxToSkip);
+                final int tokenToSkip = expectedPostingList.get(idxToSkip).rowId;
 
 long advanceResult = actualPostingList.advance(tokenToSkip);
 assertEquals(tokenToSkip, advanceResult);
 
 for (int i = idxToSkip + 1; i < expectedPostingList.size(); ++i)
 {
-                    final long expectedRowID = expectedPostingList.get(i);
+                    final long expectedRowID = expectedPostingList.get(i).rowId;
 long result = actualPostingList.nextPosting();
 assertEquals(expectedRowID, result);
 }
@@ -211,6 +226,17 @@ private void testTermQueries(Version version, int numTerms, int numPostings) throws IOException
 }
 }
 
+    private Int2IntHashMap createMockDocLengths(List<InvertedIndexBuilder.TermsEnum> termsEnum)
+    {
+        Int2IntHashMap docLengths = new Int2IntHashMap(Integer.MIN_VALUE);
+        for (InvertedIndexBuilder.TermsEnum term : termsEnum)
+        {
+            for (var cursor : term.postings)
+                docLengths.put(cursor.value, 1);
+        }
+        return docLengths;
+    }
+
 private List<InvertedIndexBuilder.TermsEnum> buildTermsEnum(Version version, int terms, int postings)
 {
 return buildStringTermsEnum(version, terms, postings, () -> randomSimpleString(4, 10), () -> nextInt(0, Integer.MAX_VALUE));
diff --git a/test/unit/org/apache/cassandra/index/sai/disk/v1/kdtree/ImmutableOneDimPointValuesTest.java b/test/unit/org/apache/cassandra/index/sai/disk/v1/kdtree/ImmutableOneDimPointValuesTest.java
index 9f086249f331..ba8f33cd9951 100644
--- a/test/unit/org/apache/cassandra/index/sai/disk/v1/kdtree/ImmutableOneDimPointValuesTest.java
+++ b/test/unit/org/apache/cassandra/index/sai/disk/v1/kdtree/ImmutableOneDimPointValuesTest.java
@@ -19,12 +19,14 @@
 import java.io.IOException;
 import java.nio.ByteBuffer;
+import java.util.Arrays;
+import java.util.List;
 
 import org.junit.Rule;
 import org.junit.Test;
 import org.junit.rules.ExpectedException;
 
-import com.carrotsearch.hppc.IntArrayList;
+import org.apache.cassandra.index.sai.memory.RowMapping;
 import org.apache.cassandra.db.marshal.Int32Type;
 import org.apache.cassandra.index.sai.disk.MemtableTermsIterator;
 import org.apache.cassandra.index.sai.disk.TermsIterator;
@@ -111,20 +113,22 @@ private TermsIterator buildDescTermEnum(int from, int to)
 final ByteBuffer minTerm = Int32Type.instance.decompose(from);
 final ByteBuffer maxTerm = Int32Type.instance.decompose(to);
 
-        final AbstractGuavaIterator<Pair<ByteComparable, IntArrayList>> iterator = new AbstractGuavaIterator<Pair<ByteComparable, IntArrayList>>()
+        final AbstractGuavaIterator<Pair<ByteComparable, List<RowMapping.RowIdWithFrequency>>> iterator = new AbstractGuavaIterator<>()
 {
 private int currentTerm = from;
 
 @Override
-            protected Pair<ByteComparable, IntArrayList> computeNext()
+            protected Pair<ByteComparable, List<RowMapping.RowIdWithFrequency>> computeNext()
 {
 if (currentTerm <= to)
 {
 return endOfData();
 }
 final ByteBuffer term = Int32Type.instance.decompose(currentTerm++);
-                IntArrayList postings = new IntArrayList();
-                postings.add(0, 1, 2);
+                List<RowMapping.RowIdWithFrequency> postings = Arrays.asList(
+                    new RowMapping.RowIdWithFrequency(0, 1),
+                    new RowMapping.RowIdWithFrequency(1, 1),
+                    new RowMapping.RowIdWithFrequency(2, 1));
 return Pair.create(v -> ByteSource.preencoded(term), postings);
 }
 };
diff --git a/test/unit/org/apache/cassandra/index/sai/disk/v1/kdtree/KDTreeIndexBuilder.java b/test/unit/org/apache/cassandra/index/sai/disk/v1/kdtree/KDTreeIndexBuilder.java
index fdbd1f9d91f8..3a867e41d90f 100644
--- a/test/unit/org/apache/cassandra/index/sai/disk/v1/kdtree/KDTreeIndexBuilder.java
+++ b/test/unit/org/apache/cassandra/index/sai/disk/v1/kdtree/KDTreeIndexBuilder.java
@@ -21,7 +21,9 @@
 import java.math.BigDecimal;
 import java.math.BigInteger;
 import java.nio.ByteBuffer;
+import java.util.ArrayList;
 import java.util.Iterator;
+import java.util.List;
 import java.util.function.Supplier;
 import java.util.stream.Collectors;
 import java.util.stream.IntStream;
@@ -31,6 +33,7 @@
 import org.junit.Assert;
 
 import com.carrotsearch.hppc.IntArrayList;
+import org.apache.cassandra.index.sai.memory.RowMapping;
 import org.apache.cassandra.db.marshal.AbstractType;
 import org.apache.cassandra.db.marshal.DecimalType;
 import org.apache.cassandra.db.marshal.Int32Type;
@@ -142,7 +145,22 @@ public KDTreeIndexBuilder(IndexDescriptor indexDescriptor,
 KDTreeIndexSearcher flushAndOpen() throws IOException
 {
-        final TermsIterator termEnum = new MemtableTermsIterator(null, null, terms);
+        // Wrap postings with RowIdWithFrequency using default frequency of 1
+        final TermsIterator termEnum = new MemtableTermsIterator(null, null, new AbstractGuavaIterator<>()
+        {
+            @Override
+            protected Pair<ByteComparable, List<RowMapping.RowIdWithFrequency>> computeNext()
+            {
+                if (!terms.hasNext())
+                    return endOfData();
+
+                Pair<ByteComparable, IntArrayList> pair = terms.next();
+                List<RowMapping.RowIdWithFrequency> postings = new ArrayList<>(pair.right.size());
+                for (int i = 0; i < pair.right.size(); i++)
+                    postings.add(new RowMapping.RowIdWithFrequency(pair.right.get(i), 1));
+                return Pair.create(pair.left, postings);
+            }
+        });
 final ImmutableOneDimPointValues pointValues = ImmutableOneDimPointValues.fromTermEnum(termEnum, type);
 
 final SegmentMetadata metadata;
diff --git a/test/unit/org/apache/cassandra/index/sai/disk/v1/kdtree/NumericIndexWriterTest.java b/test/unit/org/apache/cassandra/index/sai/disk/v1/kdtree/NumericIndexWriterTest.java
index dcc1e7ef056d..27bff7fa40cb 100644
--- a/test/unit/org/apache/cassandra/index/sai/disk/v1/kdtree/NumericIndexWriterTest.java
+++ b/test/unit/org/apache/cassandra/index/sai/disk/v1/kdtree/NumericIndexWriterTest.java
@@ -18,6 +18,8 @@
 package org.apache.cassandra.index.sai.disk.v1.kdtree;
 
 import java.nio.ByteBuffer;
+import java.util.ArrayList;
+import java.util.List;
 
 import org.junit.Before;
 import org.junit.Test;
@@ -35,6 +37,7 @@
 import org.apache.cassandra.index.sai.disk.format.IndexDescriptor;
 import org.apache.cassandra.index.sai.disk.v1.IndexWriterConfig;
 import org.apache.cassandra.index.sai.disk.v1.SegmentMetadata;
+import org.apache.cassandra.index.sai.memory.RowMapping;
 import org.apache.cassandra.index.sai.metrics.QueryEventListener;
 import org.apache.cassandra.index.sai.metrics.QueryEventListeners;
 import org.apache.cassandra.index.sai.utils.SaiRandomizedTest;
@@ -190,21 +193,21 @@ private TermsIterator buildTermEnum(int startTermInclusive, int endTermExclusive)
 final ByteBuffer minTerm = Int32Type.instance.decompose(startTermInclusive);
 final ByteBuffer maxTerm = Int32Type.instance.decompose(endTermExclusive);
 
-        final AbstractGuavaIterator<Pair<ByteComparable, IntArrayList>> iterator = new AbstractGuavaIterator<Pair<ByteComparable, IntArrayList>>()
+        final AbstractGuavaIterator<Pair<ByteComparable, List<RowMapping.RowIdWithFrequency>>> iterator = new AbstractGuavaIterator<>()
 {
 private int currentTerm = startTermInclusive;
 private int currentRowId = 0;
 
 @Override
-            protected Pair<ByteComparable, IntArrayList> computeNext()
+            protected Pair<ByteComparable, List<RowMapping.RowIdWithFrequency>> computeNext()
 {
 if (currentTerm >= endTermExclusive)
 {
 return endOfData();
 }
 final ByteBuffer term = Int32Type.instance.decompose(currentTerm++);
-                final IntArrayList postings = new IntArrayList();
-                postings.add(currentRowId++);
+                final List<RowMapping.RowIdWithFrequency> postings = new ArrayList<>();
+                postings.add(new RowMapping.RowIdWithFrequency(currentRowId++, 1));
 final ByteSource encoded = Int32Type.instance.asComparableBytes(term, TypeUtil.BYTE_COMPARABLE_VERSION);
 return Pair.create(v -> encoded, postings);
 }
diff --git a/test/unit/org/apache/cassandra/index/sai/disk/v1/postings/IntersectingPostingListTest.java b/test/unit/org/apache/cassandra/index/sai/disk/v1/postings/IntersectingPostingListTest.java
new file mode 100644
index 000000000000..76285a4f0bc4
--- /dev/null
+++ b/test/unit/org/apache/cassandra/index/sai/disk/v1/postings/IntersectingPostingListTest.java
@@ -0,0 +1,210 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.cassandra.index.sai.disk.v1.postings;
+
+import java.io.IOException;
+import java.nio.ByteBuffer;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.concurrent.atomic.AtomicInteger;
+import java.util.stream.IntStream;
+
+import org.junit.Test;
+
+import org.apache.cassandra.index.sai.disk.PostingList;
+import org.apache.cassandra.index.sai.postings.IntArrayPostingList;
+import org.apache.cassandra.index.sai.utils.SaiRandomizedTest;
+import org.apache.cassandra.utils.ByteBufferUtil;
+
+public class IntersectingPostingListTest extends SaiRandomizedTest
+{
+    private Map<ByteBuffer, PostingList> createPostingMap(PostingList... lists)
+    {
+        Map<ByteBuffer, PostingList> map = new HashMap<>();
+        for (int i = 0; i < lists.length; i++)
+        {
+            map.put(ByteBufferUtil.bytes(String.valueOf((char) ('A' + i))), lists[i]);
+        }
+        return map;
+    }
+
+    @Test
+    public void shouldIntersectOverlappingPostingLists() throws IOException
+    {
+        var map = createPostingMap(new IntArrayPostingList(new int[]{ 1, 4, 6, 8 }),
+                                   new IntArrayPostingList(new int[]{ 2, 4, 6, 9 }),
+                                   new IntArrayPostingList(new int[]{ 4, 6, 7 }));
+
+        final PostingList intersected = IntersectingPostingList.intersect(map);
+        assertPostingListEquals(new IntArrayPostingList(new int[]{ 4, 6 }), intersected);
+    }
+
+    @Test
+    public void shouldIntersectDisjointPostingLists() throws IOException
+    {
+        var map = createPostingMap(new IntArrayPostingList(new int[]{ 1, 3, 5 }),
+                                   new IntArrayPostingList(new int[]{ 2, 4, 6 }));
+
+        final PostingList intersected = IntersectingPostingList.intersect(map);
+        assertPostingListEquals(new IntArrayPostingList(new int[]{}), intersected);
+    }
+
+    @Test
+    public void shouldIntersectSinglePostingList() throws IOException
+    {
+        var map = createPostingMap(new IntArrayPostingList(new int[]{ 1, 4, 6 }));
+
+        final PostingList intersected = IntersectingPostingList.intersect(map);
+        assertPostingListEquals(new IntArrayPostingList(new int[]{ 1, 4, 6 }), intersected);
+    }
+
+    @Test
+    public void shouldIntersectIdenticalPostingLists() throws IOException
+    {
+        var map = createPostingMap(new IntArrayPostingList(new int[]{ 1, 2, 3 }),
+                                   new IntArrayPostingList(new int[]{ 1, 2, 3 }));
+
+        final PostingList intersected = IntersectingPostingList.intersect(map);
+        assertPostingListEquals(new IntArrayPostingList(new int[]{ 1, 2, 3 }), intersected);
+    }
+
+    @Test
+    public void shouldAdvanceAllIntersectedLists() throws IOException
+    {
+        var map = createPostingMap(new IntArrayPostingList(new int[]{ 1, 3, 5, 7, 9 }),
+                                   new IntArrayPostingList(new int[]{ 2, 3, 5, 7, 8 }),
+                                   new IntArrayPostingList(new int[]{ 3, 5, 7, 10 }));
+
+        final PostingList intersected = IntersectingPostingList.intersect(map);
+        final PostingList expected = new IntArrayPostingList(new int[]{ 3, 5, 7 });
+
+        assertEquals(expected.advance(5), intersected.advance(5));
+        assertPostingListEquals(expected, intersected);
+    }
+
+    @Test
+    public void shouldHandleEmptyList() throws IOException
+    {
+        var map = createPostingMap(new IntArrayPostingList(new int[]{}),
+                                   new IntArrayPostingList(new int[]{ 1, 2, 3 }));
+
+        final PostingList intersected = IntersectingPostingList.intersect(map);
+        assertEquals(PostingList.END_OF_STREAM, intersected.advance(1));
+    }
+
+    @Test
+    public void shouldInterleaveNextAndAdvance() throws IOException
+    {
+        var map = createPostingMap(new IntArrayPostingList(new int[]{ 1, 3, 5, 7, 9 }),
+                                   new IntArrayPostingList(new int[]{ 1, 3, 5, 7, 9 }),
+                                   new IntArrayPostingList(new int[]{ 1, 3, 5, 7, 9 }));
+
+        final PostingList intersected = IntersectingPostingList.intersect(map);
+
+        assertEquals(1, intersected.nextPosting());
+        assertEquals(5, intersected.advance(5));
+        assertEquals(7, intersected.nextPosting());
+        assertEquals(9, intersected.advance(9));
+    }
+
+    @Test
+    public void shouldInterleaveNextAndAdvanceOnRandom() throws IOException
+    {
+        for (int i = 0; i < 1000; ++i)
+        {
+            testAdvancingOnRandom();
+        }
+    }
+
+    private void testAdvancingOnRandom() throws IOException
+    {
+        final int postingsCount = nextInt(1, 50_000);
+        final int postingListCount = nextInt(2, 10);
+
+        final AtomicInteger rowId = new AtomicInteger();
+        final int[] commonPostings = IntStream.generate(() -> rowId.addAndGet(nextInt(1, 10)))
+                                              .limit(postingsCount / 4)
+                                              .toArray();
+
+        var splitPostingLists = new ArrayList<PostingList>();
+        for (int i = 0; i < postingListCount; i++)
+        {
+            final int[] uniquePostings = IntStream.generate(() -> rowId.addAndGet(nextInt(1, 10)))
+                                                  .limit(postingsCount)
+                                                  .toArray();
+            int[] combined = IntStream.concat(IntStream.of(commonPostings),
+                                              IntStream.of(uniquePostings))
+                                      .distinct()
+                                      .sorted()
+                                      .toArray();
+            splitPostingLists.add(new IntArrayPostingList(combined));
+        }
+
+        final PostingList intersected = IntersectingPostingList.intersect(createPostingMap(splitPostingLists.toArray(new PostingList[0])));
+        final PostingList expected = new IntArrayPostingList(commonPostings);
+
+        final List<PostingListAdvance> actions = new ArrayList<>();
+        for (int idx = 0; idx < commonPostings.length; idx++)
+        {
+            if (nextInt(0, 8) == 0)
+            {
+                actions.add((postingList) -> {
+                    try
+                    {
+                        return postingList.nextPosting();
+                    }
+                    catch (IOException e)
+                    {
+                        fail(e.getMessage());
+                        throw new RuntimeException(e);
+                    }
+                });
+            }
+            else
+            {
+                final int skips = nextInt(0, 5);
+                idx = Math.min(idx + skips, commonPostings.length - 1);
+                final int rowID = commonPostings[idx];
+                actions.add((postingList) -> {
+                    try
+                    {
+                        return postingList.advance(rowID);
+                    }
+                    catch (IOException e)
+                    {
+                        fail(e.getMessage());
+                        throw new RuntimeException(e);
+                    }
+                });
+            }
+        }
+
+        for (PostingListAdvance action : actions)
+        {
+            assertEquals(action.advance(expected), action.advance(intersected));
+        }
+    }
+
+    private interface PostingListAdvance
+    {
+        long advance(PostingList list) throws IOException;
+    }
+}
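IntersectingPostingList, exercised above, is a leapfrog intersection: every list is advanced to the current candidate row id, and a row id is emitted only when all lists land on it; any overshoot becomes the new candidate. A self-contained sketch of the same control flow over plain sorted arrays (the real class works against the PostingList advance()/nextPosting() API and also carries per-term frequencies, which this sketch omits):

    // Leapfrog intersection of k sorted, duplicate-free arrays; algorithm sketch only.
    static List<Integer> intersect(int[][] lists)
    {
        List<Integer> result = new ArrayList<>();
        int[] pos = new int[lists.length];            // one cursor per list
        outer:
        while (pos[0] < lists[0].length)
        {
            int candidate = lists[0][pos[0]];
            for (int i = 1; i < lists.length; i++)
            {
                while (pos[i] < lists[i].length && lists[i][pos[i]] < candidate)
                    pos[i]++;                         // the advance(candidate) step
                if (pos[i] == lists[i].length)
                    return result;                    // one list exhausted: no more matches
                if (lists[i][pos[i]] > candidate)
                {
                    while (pos[0] < lists[0].length && lists[0][pos[0]] < lists[i][pos[i]])
                        pos[0]++;                     // overshoot becomes the next candidate
                    continue outer;
                }
            }
            result.add(candidate);                    // every list contains this row id
            pos[0]++;
        }
        return result;
    }

Run against { 1, 4, 6, 8 }, { 2, 4, 6, 9 } and { 4, 6, 7 } this yields [4, 6], the result shouldIntersectOverlappingPostingLists asserts.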
diff --git a/test/unit/org/apache/cassandra/index/sai/disk/v1/postings/PostingsTest.java b/test/unit/org/apache/cassandra/index/sai/disk/v1/postings/PostingsTest.java
index 184bfd3af8f0..451789e872c9 100644
--- a/test/unit/org/apache/cassandra/index/sai/disk/v1/postings/PostingsTest.java
+++ b/test/unit/org/apache/cassandra/index/sai/disk/v1/postings/PostingsTest.java
@@ -328,7 +328,7 @@ private PostingsReader openReader(IndexComponent.ForRead postingLists, long fp,
 private PostingsReader.BlocksSummary assertBlockSummary(int blockSize, PostingList expected, IndexInput input) throws IOException
 {
 final PostingsReader.BlocksSummary summary = new PostingsReader.BlocksSummary(input, input.getFilePointer());
-        assertEquals(blockSize, summary.blockSize);
+        assertEquals(blockSize, summary.blockEntries);
 assertEquals(expected.size(), summary.numPostings);
 assertTrue(summary.offsets.length() > 0);
 assertEquals(summary.offsets.length(), summary.maxValues.length());
diff --git a/test/unit/org/apache/cassandra/index/sai/disk/v1/postings/VectorPostingListTest.java b/test/unit/org/apache/cassandra/index/sai/disk/v1/postings/ReorderingPostingListTest.java
similarity index 91%
rename from test/unit/org/apache/cassandra/index/sai/disk/v1/postings/VectorPostingListTest.java
rename to test/unit/org/apache/cassandra/index/sai/disk/v1/postings/ReorderingPostingListTest.java
index 3135db3c7748..10b2b6a33f47 100644
--- a/test/unit/org/apache/cassandra/index/sai/disk/v1/postings/VectorPostingListTest.java
+++ b/test/unit/org/apache/cassandra/index/sai/disk/v1/postings/ReorderingPostingListTest.java
@@ -30,14 +30,14 @@
 import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertTrue;
 
-public class VectorPostingListTest
+public class ReorderingPostingListTest
 {
 @Test
 public void ensureEmptySourceBehavesCorrectly() throws Throwable
 {
 var source = new TestIterator(CloseableIterator.emptyIterator());
-        try (var postingList = new VectorPostingList(source))
+        try (var postingList = new ReorderingPostingList(source, RowIdWithScore::getSegmentRowId))
 {
 // Even an empty source should be closed
 assertTrue(source.isClosed);
@@ -55,7 +55,7 @@ public void ensureIteratorIsConsumedClosedAndReordered() throws Throwable
 new RowIdWithScore(4, 4),
 }).iterator());
 
-        try (var postingList = new VectorPostingList(source))
+        try (var postingList = new ReorderingPostingList(source, RowIdWithScore::getSegmentRowId))
 {
 // The posting list is eagerly consumed, so it should be closed before
 // we close postingList
@@ -80,7 +80,7 @@ public void ensureAdvanceWorksCorrectly() throws Throwable
 new RowIdWithScore(2, 2),
 }).iterator());
 
-        try (var postingList = new VectorPostingList(source))
+        try (var postingList = new ReorderingPostingList(source, RowIdWithScore::getSegmentRowId))
 {
 assertEquals(3, postingList.advance(3));
 assertEquals(PostingList.END_OF_STREAM, postingList.advance(4));
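The rename reflects what the class actually does: it eagerly drains a source iterator that arrives in score order (vector similarity before, BM25 now as well) and re-sorts it by segment row id, because everything downstream consumes a PostingList in ascending row-id order; the generalized constructor takes the row-id extractor explicitly. A sketch of the reordering step, with the extractor shape assumed from the test's RowIdWithScore::getSegmentRowId usage:

    // Drain a score-ordered iterator, then expose row ids in ascending order. Sketch only.
    static <T> int[] reorderByRowId(Iterator<T> source, ToIntFunction<T> rowId)
    {
        List<Integer> ids = new ArrayList<>();
        while (source.hasNext())
            ids.add(rowId.applyAsInt(source.next())); // consume eagerly so the source can be closed
        return ids.stream().mapToInt(Integer::intValue).sorted().toArray();
    }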
diff --git a/test/unit/org/apache/cassandra/index/sai/memory/TrieMemoryIndexTest.java b/test/unit/org/apache/cassandra/index/sai/memory/TrieMemoryIndexTest.java
index 4119b3afb872..6199266d6da0 100644
--- a/test/unit/org/apache/cassandra/index/sai/memory/TrieMemoryIndexTest.java
+++ b/test/unit/org/apache/cassandra/index/sai/memory/TrieMemoryIndexTest.java
@@ -82,17 +82,16 @@ public void iteratorShouldReturnAllValuesNumeric()
 index.add(makeKey(table, Integer.toString(row)), Clustering.EMPTY, Int32Type.instance.decompose(row / 10), allocatedBytes -> {}, allocatesBytes -> {});
 }
 
-        Iterator<Pair<ByteComparable, PrimaryKeys>> iterator = index.iterator();
+        var iterator = index.iterator();
 int valueCount = 0;
 while(iterator.hasNext())
 {
-            Pair<ByteComparable, PrimaryKeys> pair = iterator.next();
+            var pair = iterator.next();
 int value = ByteSourceInverse.getSignedInt(pair.left.asComparableBytes(TypeUtil.BYTE_COMPARABLE_VERSION));
 int idCount = 0;
-            Iterator<PrimaryKey> primaryKeyIterator = pair.right.iterator();
-            while (primaryKeyIterator.hasNext())
+            for (var pkf : pair.right)
 {
-                PrimaryKey primaryKey = primaryKeyIterator.next();
+                PrimaryKey primaryKey = pkf.pk;
 int id = Int32Type.instance.compose(primaryKey.partitionKey().getKey());
 assertEquals(id/10, value);
 idCount++;
@@ -113,17 +112,16 @@ public void iteratorShouldReturnAllValuesString()
 index.add(makeKey(table, Integer.toString(row)), Clustering.EMPTY, UTF8Type.instance.decompose(Integer.toString(row / 10)), allocatedBytes -> {}, allocatesBytes -> {});
 }
 
-        Iterator<Pair<ByteComparable, PrimaryKeys>> iterator = index.iterator();
+        var iterator = index.iterator();
 int valueCount = 0;
 while(iterator.hasNext())
 {
-            Pair<ByteComparable, PrimaryKeys> pair = iterator.next();
+            var pair = iterator.next();
 String value = new String(ByteSourceInverse.readBytes(pair.left.asPeekableBytes(TypeUtil.BYTE_COMPARABLE_VERSION)), StandardCharsets.UTF_8);
 int idCount = 0;
-            Iterator<PrimaryKey> primaryKeyIterator = pair.right.iterator();
-            while (primaryKeyIterator.hasNext())
+            for (var pkf : pair.right)
 {
-                PrimaryKey primaryKey = primaryKeyIterator.next();
+                PrimaryKey primaryKey = pkf.pk;
 String id = UTF8Type.instance.compose(primaryKey.partitionKey().getKey());
 assertEquals(Integer.toString(Integer.parseInt(id) / 10), value);
 idCount++;
@@ -149,11 +147,11 @@ private void shouldAcceptPrefixValuesForType(AbstractType type, IntFunction
 {}, allocatesBytes -> {});
 }
 
-        final Iterator<Pair<ByteComparable, PrimaryKeys>> iterator = index.iterator();
+        final var iterator = index.iterator();
 int i = 0;
 while (iterator.hasNext())
 {
-            Pair<ByteComparable, PrimaryKeys> pair = iterator.next();
+            var pair = iterator.next();
 assertEquals(1, pair.right.size());
 
 final int rowId = i;
diff --git a/test/unit/org/apache/cassandra/index/sai/memory/TrieMemtableIndexTestBase.java b/test/unit/org/apache/cassandra/index/sai/memory/TrieMemtableIndexTestBase.java
index 0f1089250348..21b13059de61 100644
--- a/test/unit/org/apache/cassandra/index/sai/memory/TrieMemtableIndexTestBase.java
+++ b/test/unit/org/apache/cassandra/index/sai/memory/TrieMemtableIndexTestBase.java
@@ -225,11 +225,11 @@ public void indexIteratorTest()
 DecoratedKey minimum = temp1.compareTo(temp2) <= 0 ? temp1 : temp2;
 DecoratedKey maximum = temp1.compareTo(temp2) <= 0 ? temp2 : temp1;
 
-        Iterator<Pair<ByteComparable, Iterator<PrimaryKey>>> iterator = memtableIndex.iterator(minimum, maximum);
+        var iterator = memtableIndex.iterator(minimum, maximum);
 while (iterator.hasNext())
 {
-            Pair<ByteComparable, Iterator<PrimaryKey>> termPair = iterator.next();
+            var termPair = iterator.next();
 int term = termFromComparable(termPair.left);
 
 // The iterator will return keys outside the range of min/max, so we need to filter here to
 // get the correct keys
@@ -239,9 +239,9 @@ public void indexIteratorTest()
 .sorted()
 .collect(Collectors.toList());
 List<DecoratedKey> termPks = new ArrayList<>();
-            while (termPair.right.hasNext())
+            for (var pkWithFreq : termPair.right)
 {
-                DecoratedKey pk = termPair.right.next().partitionKey();
+                DecoratedKey pk = pkWithFreq.pk.partitionKey();
 if (pk.compareTo(minimum) >= 0 && pk.compareTo(maximum) <= 0)
 termPks.add(pk);
 }
diff --git a/test/unit/org/apache/cassandra/index/sai/memory/VectorMemtableIndexTest.java b/test/unit/org/apache/cassandra/index/sai/memory/VectorMemtableIndexTest.java
index 75e8c83f30be..2197c00fe231 100644
--- a/test/unit/org/apache/cassandra/index/sai/memory/VectorMemtableIndexTest.java
+++ b/test/unit/org/apache/cassandra/index/sai/memory/VectorMemtableIndexTest.java
@@ -148,7 +148,7 @@ private void validate(List keys)
 {
 IntStream.range(0, 1_000).parallel().forEach(i -> {
-            var orderer = generateRandomOrderer();
+            var orderer = randomVectorOrderer();
 AbstractBounds keyRange = generateRandomBounds(keys);
 // compute keys in range of the bounds
 Set keysInRange = keys.stream().filter(keyRange::contains)
@@ -197,7 +197,7 @@ public void indexIteratorTest()
 // VSTODO
 }
 
-    private Orderer generateRandomOrderer()
+    private Orderer randomVectorOrderer()
 {
 return new Orderer(indexContext, Operator.ANN, randomVectorSerialized());
 }
diff --git a/test/unit/org/apache/cassandra/index/sai/metrics/IndexMetricsTest.java b/test/unit/org/apache/cassandra/index/sai/metrics/IndexMetricsTest.java
index cc65bd20c293..c59f10c8e42f 100644
--- a/test/unit/org/apache/cassandra/index/sai/metrics/IndexMetricsTest.java
+++ b/test/unit/org/apache/cassandra/index/sai/metrics/IndexMetricsTest.java
@@ -30,6 +30,7 @@
 public class IndexMetricsTest extends AbstractMetricsTest
 {
+    private static final String TABLE = "table_name";
 private static final String INDEX = "table_name_index";
@@ -155,7 +156,7 @@ public void testQueriesCount()
 int rowCount = 10;
 for (int i = 0; i < rowCount; i++)
-            execute("INSERT INTO %s (id1, v1, v2, v3) VALUES (?, ?, '0', [?, 0.0])", Integer.toString(i), i, i);
+            execute("INSERT INTO %s (id1, v1, v2, v3) VALUES (?, ?, '0', ?)", Integer.toString(i), i, vector(i, i));
 
 assertIndexQueryCount(indexV1, 0L);
@@ -180,7 +181,7 @@ public void testQueriesCount()
 assertIndexQueryCount(indexV1, 4L);
 assertIndexQueryCount(indexV2, 2L);
 
-        String indexV3 = createIndex("CREATE CUSTOM INDEX ON %s (v3) USING 'StorageAttachedIndex'");
+        String indexV3 = createIndex("CREATE CUSTOM INDEX ON %s (v3) USING 'StorageAttachedIndex' WITH OPTIONS = {'similarity_function': 'euclidean'}");
 assertIndexQueryCount(indexV3, 0L);
 
 executeNet("SELECT id1 FROM %s WHERE v2 = '2' ORDER BY v3 ANN OF [5,0] LIMIT 10");
 assertIndexQueryCount(indexV1, 4L);