diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/StructFilters.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/StructFilters.scala
index 4ac62b987b151..1b2013d87eedf 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/StructFilters.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/StructFilters.scala
@@ -156,6 +156,8 @@ object StructFilters {
         Some(Literal(true, BooleanType))
       case sources.AlwaysFalse() =>
         Some(Literal(false, BooleanType))
+      case _: sources.CollatedFilter =>
+        None
     }
     translate(filter)
   }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/sources/filters.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/sources/filters.scala
index a52bca1066059..88f556130bfe6 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/sources/filters.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/sources/filters.scala
@@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.parser.ParseException
 import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.parseColumnPath
 import org.apache.spark.sql.connector.expressions.{FieldReference, LiteralValue, NamedReference}
 import org.apache.spark.sql.connector.expressions.filter.{AlwaysFalse => V2AlwaysFalse, AlwaysTrue => V2AlwaysTrue, And => V2And, Not => V2Not, Or => V2Or, Predicate}
-import org.apache.spark.sql.types.StringType
+import org.apache.spark.sql.types.{DataType, StringType}
 import org.apache.spark.unsafe.types.UTF8String
 
 ////////////////////////////////////////////////////////////////////////////////////////////////////
@@ -381,3 +381,87 @@ case class AlwaysFalse() extends Filter {
 @Evolving
 object AlwaysFalse extends AlwaysFalse {
 }
+
+/**
+ * Base class for collation aware string filters.
+ */
+@Evolving
+abstract class CollatedFilter() extends Filter {
+
+  /** The corresponding non-collation aware filter. */
+  def correspondingFilter: Filter
+  def dataType: DataType
+
+  override def references: Array[String] = correspondingFilter.references
+  override def toV2: Predicate = correspondingFilter.toV2
+}
+
+/** Collation aware equivalent of [[EqualTo]]. */
+@Evolving
+case class CollatedEqualTo(attribute: String, value: Any, dataType: DataType)
+  extends CollatedFilter {
+  override def correspondingFilter: Filter = EqualTo(attribute, value)
+}
+
+/** Collation aware equivalent of [[EqualNullSafe]]. */
+@Evolving
+case class CollatedEqualNullSafe(attribute: String, value: Any, dataType: DataType)
+  extends CollatedFilter {
+  override def correspondingFilter: Filter = EqualNullSafe(attribute, value)
+}
+
+/** Collation aware equivalent of [[GreaterThan]]. */
+@Evolving
+case class CollatedGreaterThan(attribute: String, value: Any, dataType: DataType)
+  extends CollatedFilter {
+  override def correspondingFilter: Filter = GreaterThan(attribute, value)
+}
+
+/** Collation aware equivalent of [[GreaterThanOrEqual]]. */
+@Evolving
+case class CollatedGreaterThanOrEqual(attribute: String, value: Any, dataType: DataType)
+  extends CollatedFilter {
+  override def correspondingFilter: Filter = GreaterThanOrEqual(attribute, value)
+}
+
+/** Collation aware equivalent of [[LessThan]]. */
+@Evolving
+case class CollatedLessThan(attribute: String, value: Any, dataType: DataType)
+  extends CollatedFilter {
+  override def correspondingFilter: Filter = LessThan(attribute, value)
+}
+
+/** Collation aware equivalent of [[LessThanOrEqual]].
*/ +@Evolving +case class CollatedLessThanOrEqual(attribute: String, value: Any, dataType: DataType) + extends CollatedFilter { + override def correspondingFilter: Filter = LessThanOrEqual(attribute, value) +} + +/** Collation aware equivalent of [[In]]. */ +@Evolving +case class CollatedIn(attribute: String, values: Array[Any], dataType: DataType) + extends CollatedFilter { + override def correspondingFilter: Filter = In(attribute, values) +} + +/** Collation aware equivalent of [[StringStartsWith]]. */ +@Evolving +case class CollatedStringStartsWith(attribute: String, value: String, dataType: DataType) + extends CollatedFilter { + override def correspondingFilter: Filter = StringStartsWith(attribute, value) +} + +/** Collation aware equivalent of [[StringEndsWith]]. */ +@Evolving +case class CollatedStringEndsWith(attribute: String, value: String, dataType: DataType) + extends CollatedFilter { + override def correspondingFilter: Filter = StringEndsWith(attribute, value) +} + +/** Collation aware equivalent of [[StringContains]]. */ +@Evolving +case class CollatedStringContains(attribute: String, value: String, dataType: DataType) + extends CollatedFilter { + override def correspondingFilter: Filter = StringContains(attribute, value) +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala index 7cda347ce581b..5d2310c130703 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala @@ -500,42 +500,76 @@ object DataSourceStrategy } } + /** + * Creates a collation aware filter if the input data type is string with non-default collation + */ + private def collationAwareFilter(filter: sources.Filter, dataType: DataType): Filter = { + if (!SchemaUtils.hasNonUTF8BinaryCollation(dataType)) { + return filter + } + + filter match { + case sources.EqualTo(attribute, value) => + CollatedEqualTo(attribute, value, dataType) + case sources.EqualNullSafe(attribute, value) => + CollatedEqualNullSafe(attribute, value, dataType) + case sources.GreaterThan(attribute, value) => + CollatedGreaterThan(attribute, value, dataType) + case sources.GreaterThanOrEqual(attribute, value) => + CollatedGreaterThanOrEqual(attribute, value, dataType) + case sources.LessThan(attribute, value) => + CollatedLessThan(attribute, value, dataType) + case sources.LessThanOrEqual(attribute, value) => + CollatedLessThanOrEqual(attribute, value, dataType) + case sources.In(attribute, values) => + CollatedIn(attribute, values, dataType) + case sources.StringStartsWith(attribute, value) => + CollatedStringStartsWith(attribute, value, dataType) + case sources.StringEndsWith(attribute, value) => + CollatedStringEndsWith(attribute, value, dataType) + case sources.StringContains(attribute, value) => + CollatedStringContains(attribute, value, dataType) + case other => + other + } + } + private def translateLeafNodeFilter( predicate: Expression, pushableColumn: PushableColumnBase): Option[Filter] = predicate match { - case expressions.EqualTo(pushableColumn(name), Literal(v, t)) => - Some(sources.EqualTo(name, convertToScala(v, t))) - case expressions.EqualTo(Literal(v, t), pushableColumn(name)) => - Some(sources.EqualTo(name, convertToScala(v, t))) - - case expressions.EqualNullSafe(pushableColumn(name), Literal(v, t)) => - Some(sources.EqualNullSafe(name, 
convertToScala(v, t))) - case expressions.EqualNullSafe(Literal(v, t), pushableColumn(name)) => - Some(sources.EqualNullSafe(name, convertToScala(v, t))) - - case expressions.GreaterThan(pushableColumn(name), Literal(v, t)) => - Some(sources.GreaterThan(name, convertToScala(v, t))) - case expressions.GreaterThan(Literal(v, t), pushableColumn(name)) => - Some(sources.LessThan(name, convertToScala(v, t))) - - case expressions.LessThan(pushableColumn(name), Literal(v, t)) => - Some(sources.LessThan(name, convertToScala(v, t))) - case expressions.LessThan(Literal(v, t), pushableColumn(name)) => - Some(sources.GreaterThan(name, convertToScala(v, t))) - - case expressions.GreaterThanOrEqual(pushableColumn(name), Literal(v, t)) => - Some(sources.GreaterThanOrEqual(name, convertToScala(v, t))) - case expressions.GreaterThanOrEqual(Literal(v, t), pushableColumn(name)) => - Some(sources.LessThanOrEqual(name, convertToScala(v, t))) - - case expressions.LessThanOrEqual(pushableColumn(name), Literal(v, t)) => - Some(sources.LessThanOrEqual(name, convertToScala(v, t))) - case expressions.LessThanOrEqual(Literal(v, t), pushableColumn(name)) => - Some(sources.GreaterThanOrEqual(name, convertToScala(v, t))) + case expressions.EqualTo(e @ pushableColumn(name), Literal(v, t)) => + Some(collationAwareFilter(sources.EqualTo(name, convertToScala(v, t)), e.dataType)) + case expressions.EqualTo(Literal(v, t), e @ pushableColumn(name)) => + Some(collationAwareFilter(sources.EqualTo(name, convertToScala(v, t)), e.dataType)) + + case expressions.EqualNullSafe(e @ pushableColumn(name), Literal(v, t)) => + Some(collationAwareFilter(sources.EqualNullSafe(name, convertToScala(v, t)), e.dataType)) + case expressions.EqualNullSafe(Literal(v, t), e @ pushableColumn(name)) => + Some(collationAwareFilter(sources.EqualNullSafe(name, convertToScala(v, t)), e.dataType)) + + case expressions.GreaterThan(e @ pushableColumn(name), Literal(v, t)) => + Some(collationAwareFilter(sources.GreaterThan(name, convertToScala(v, t)), e.dataType)) + case expressions.GreaterThan(Literal(v, t), e @ pushableColumn(name)) => + Some(collationAwareFilter(sources.LessThan(name, convertToScala(v, t)), e.dataType)) + + case expressions.LessThan(e @ pushableColumn(name), Literal(v, t)) => + Some(collationAwareFilter(sources.LessThan(name, convertToScala(v, t)), e.dataType)) + case expressions.LessThan(Literal(v, t), e @ pushableColumn(name)) => + Some(collationAwareFilter(sources.GreaterThan(name, convertToScala(v, t)), e.dataType)) + + case expressions.GreaterThanOrEqual(e @ pushableColumn(name), Literal(v, t)) => + Some(collationAwareFilter(sources.GreaterThanOrEqual(name, convertToScala(v, t)), e.dataType)) + case expressions.GreaterThanOrEqual(Literal(v, t), e @ pushableColumn(name)) => + Some(collationAwareFilter(sources.LessThanOrEqual(name, convertToScala(v, t)), e.dataType)) + + case expressions.LessThanOrEqual(e @ pushableColumn(name), Literal(v, t)) => + Some(collationAwareFilter(sources.LessThanOrEqual(name, convertToScala(v, t)), e.dataType)) + case expressions.LessThanOrEqual(Literal(v, t), e @ pushableColumn(name)) => + Some(collationAwareFilter(sources.GreaterThanOrEqual(name, convertToScala(v, t)), e.dataType)) case expressions.InSet(e @ pushableColumn(name), set) => val toScala = CatalystTypeConverters.createToScalaConverter(e.dataType) - Some(sources.In(name, set.toArray.map(toScala))) + Some(collationAwareFilter(sources.In(name, set.toArray.map(toScala)), e.dataType)) // Because we only convert In to InSet in Optimizer when there 
are more than certain // items. So it is possible we still get an In expression here that needs to be pushed @@ -543,20 +577,20 @@ object DataSourceStrategy case expressions.In(e @ pushableColumn(name), list) if list.forall(_.isInstanceOf[Literal]) => val hSet = list.map(_.eval(EmptyRow)) val toScala = CatalystTypeConverters.createToScalaConverter(e.dataType) - Some(sources.In(name, hSet.toArray.map(toScala))) + Some(collationAwareFilter(sources.In(name, hSet.toArray.map(toScala)), e.dataType)) case expressions.IsNull(pushableColumn(name)) => Some(sources.IsNull(name)) case expressions.IsNotNull(pushableColumn(name)) => Some(sources.IsNotNull(name)) - case expressions.StartsWith(pushableColumn(name), Literal(v: UTF8String, StringType)) => - Some(sources.StringStartsWith(name, v.toString)) + case expressions.StartsWith(e @ pushableColumn(name), Literal(v: UTF8String, StringType)) => + Some(collationAwareFilter(sources.StringStartsWith(name, v.toString), e.dataType)) - case expressions.EndsWith(pushableColumn(name), Literal(v: UTF8String, StringType)) => - Some(sources.StringEndsWith(name, v.toString)) + case expressions.EndsWith(e @ pushableColumn(name), Literal(v: UTF8String, StringType)) => + Some(collationAwareFilter(sources.StringEndsWith(name, v.toString), e.dataType)) - case expressions.Contains(pushableColumn(name), Literal(v: UTF8String, StringType)) => - Some(sources.StringContains(name, v.toString)) + case expressions.Contains(e @ pushableColumn(name), Literal(v: UTF8String, StringType)) => + Some(collationAwareFilter(sources.StringContains(name, v.toString), e.dataType)) case expressions.Literal(true, BooleanType) => Some(sources.AlwaysTrue) @@ -595,16 +629,6 @@ object DataSourceStrategy translatedFilterToExpr: Option[mutable.HashMap[sources.Filter, Expression]], nestedPredicatePushdownEnabled: Boolean) : Option[Filter] = { - - def translateAndRecordLeafNodeFilter(filter: Expression): Option[Filter] = { - val translatedFilter = - translateLeafNodeFilter(filter, PushableColumn(nestedPredicatePushdownEnabled)) - if (translatedFilter.isDefined && translatedFilterToExpr.isDefined) { - translatedFilterToExpr.get(translatedFilter.get) = predicate - } - translatedFilter - } - predicate match { case expressions.And(left, right) => // See SPARK-12218 for detailed discussion @@ -631,25 +655,16 @@ object DataSourceStrategy right, translatedFilterToExpr, nestedPredicatePushdownEnabled) } yield sources.Or(leftFilter, rightFilter) - case notNull @ expressions.IsNotNull(_: AttributeReference) => - // Not null filters on attribute references can always be pushed, also for collated columns. - translateAndRecordLeafNodeFilter(notNull) - - case isNull @ expressions.IsNull(_: AttributeReference) => - // Is null filters on attribute references can always be pushed, also for collated columns. - translateAndRecordLeafNodeFilter(isNull) - - case p if p.references.exists(ref => SchemaUtils.hasNonUTF8BinaryCollation(ref.dataType)) => - // The filter cannot be pushed and we widen it to be AlwaysTrue(). This is only valid if - // the result of the filter is not negated by a Not expression it is wrapped in. 
- translateAndRecordLeafNodeFilter(Literal.TrueLiteral) - case expressions.Not(child) => translateFilterWithMapping(child, translatedFilterToExpr, nestedPredicatePushdownEnabled) .map(sources.Not) case other => - translateAndRecordLeafNodeFilter(other) + val filter = translateLeafNodeFilter(other, PushableColumn(nestedPredicatePushdownEnabled)) + if (filter.isDefined && translatedFilterToExpr.isDefined) { + translatedFilterToExpr.get(filter.get) = predicate + } + filter } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceUtils.scala index b6aee77577a42..c80dc83079675 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceUtils.scala @@ -26,9 +26,9 @@ import org.json4s.{Formats, NoTypeHints} import org.json4s.jackson.Serialization import org.apache.spark.{SparkException, SparkUpgradeException} -import org.apache.spark.sql.{SPARK_LEGACY_DATETIME_METADATA_KEY, SPARK_LEGACY_INT96_METADATA_KEY, SPARK_TIMEZONE_METADATA_KEY, SPARK_VERSION_METADATA_KEY} +import org.apache.spark.sql.{sources, SPARK_LEGACY_DATETIME_METADATA_KEY, SPARK_LEGACY_INT96_METADATA_KEY, SPARK_TIMEZONE_METADATA_KEY, SPARK_VERSION_METADATA_KEY} import org.apache.spark.sql.catalyst.catalog.{CatalogTable, CatalogUtils} -import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, AttributeSet, Expression, ExpressionSet, GetStructField, PredicateHelper} +import org.apache.spark.sql.catalyst.expressions.{AttributeReference, AttributeSet, Expression, ExpressionSet, PredicateHelper} import org.apache.spark.sql.catalyst.util.RebaseDateTime import org.apache.spark.sql.catalyst.util.RebaseDateTime.RebaseSpec import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors} @@ -36,7 +36,7 @@ import org.apache.spark.sql.execution.datasources.parquet.ParquetOptions import org.apache.spark.sql.internal.{LegacyBehaviorPolicy, SQLConf} import org.apache.spark.sql.sources.BaseRelation import org.apache.spark.sql.types._ -import org.apache.spark.sql.util.{CaseInsensitiveStringMap, SchemaUtils} +import org.apache.spark.sql.util.CaseInsensitiveStringMap import org.apache.spark.util.Utils @@ -280,22 +280,15 @@ object DataSourceUtils extends PredicateHelper { (ExpressionSet(partitionFilters ++ extraPartitionFilter).toSeq, dataFilters) } - /** - * Determines whether a filter should be pushed down to the data source or not. - * - * @param expression The filter expression to be evaluated. - * @param isCollationPushDownSupported Whether the data source supports collation push down. - * @return A boolean indicating whether the filter should be pushed down or not. 
- */ - def shouldPushFilter(expression: Expression, isCollationPushDownSupported: Boolean): Boolean = { - if (!expression.deterministic) return false - - isCollationPushDownSupported || !expression.exists { - case childExpression @ (_: Attribute | _: GetStructField) => - // don't push down filters for types with non-binary sortable collation - // as it could lead to incorrect results - SchemaUtils.hasNonUTF8BinaryCollation(childExpression.dataType) - + def containsFiltersWithCollation(filter: sources.Filter): Boolean = { + filter match { + case sources.And(left, right) => + containsFiltersWithCollation(left) || containsFiltersWithCollation(right) + case sources.Or(left, right) => + containsFiltersWithCollation(left) || containsFiltersWithCollation(right) + case sources.Not(child) => + containsFiltersWithCollation(child) + case _: sources.CollatedFilter => true case _ => false } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormat.scala index 0785b0cbe9e23..36c59950fe209 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormat.scala @@ -223,12 +223,6 @@ trait FileFormat { */ def fileConstantMetadataExtractors: Map[String, PartitionedFile => Any] = FileFormat.BASE_METADATA_EXTRACTORS - - /** - * Returns whether the file format supports filter push down - * for non utf8 binary collated columns. - */ - def supportsCollationPushDown: Boolean = false } object FileFormat { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala index d31cb111924b3..27019ab047ff2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala @@ -160,11 +160,8 @@ object FileSourceStrategy extends Strategy with PredicateHelper with Logging { // - filters that need to be evaluated again after the scan val filterSet = ExpressionSet(filters) - val filtersToPush = filters.filter(f => - DataSourceUtils.shouldPushFilter(f, fsRelation.fileFormat.supportsCollationPushDown)) - val normalizedFilters = DataSourceStrategy.normalizeExprs( - filtersToPush, l.output) + filters.filter(_.deterministic), l.output) val partitionColumns = l.resolve( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PruneFileSourcePartitions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PruneFileSourcePartitions.scala index b0431d1df3987..1dffea4e1bc87 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PruneFileSourcePartitions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PruneFileSourcePartitions.scala @@ -63,8 +63,7 @@ private[sql] object PruneFileSourcePartitions extends Rule[LogicalPlan] { _)) if filters.nonEmpty && fsRelation.partitionSchema.nonEmpty => val normalizedFilters = DataSourceStrategy.normalizeExprs( - filters.filter(f => !SubqueryExpression.hasSubquery(f) && - DataSourceUtils.shouldPushFilter(f, fsRelation.fileFormat.supportsCollationPushDown)), + filters.filter(f => f.deterministic && !SubqueryExpression.hasSubquery(f)), logicalRelation.output) val (partitionKeyFilters, _) = DataSourceUtils 
        .getPartitionFiltersAndDataFilters(partitionSchema, normalizedFilters)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileScanBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileScanBuilder.scala
index 7cd2779f86f95..447a36fe622c9 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileScanBuilder.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileScanBuilder.scala
@@ -70,10 +70,9 @@ abstract class FileScanBuilder(
   }
 
   override def pushFilters(filters: Seq[Expression]): Seq[Expression] = {
-    val (filtersToPush, filtersToRemain) = filters.partition(
-      f => DataSourceUtils.shouldPushFilter(f, supportsCollationPushDown))
+    val (deterministicFilters, nonDeterministicFilters) = filters.partition(_.deterministic)
     val (partitionFilters, dataFilters) =
-      DataSourceUtils.getPartitionFiltersAndDataFilters(partitionSchema, filtersToPush)
+      DataSourceUtils.getPartitionFiltersAndDataFilters(partitionSchema, deterministicFilters)
     this.partitionFilters = partitionFilters
     this.dataFilters = dataFilters
     val translatedFilters = mutable.ArrayBuffer.empty[sources.Filter]
@@ -84,7 +83,7 @@ abstract class FileScanBuilder(
       }
     }
     pushedDataFilters = pushDataFilters(translatedFilters.toArray)
-    dataFilters ++ filtersToRemain
+    dataFilters ++ nonDeterministicFilters
   }
 
   override def pushedFilters: Array[Predicate] = pushedDataFilters.map(_.toV2)
@@ -96,12 +95,6 @@ abstract class FileScanBuilder(
    */
   protected def pushDataFilters(dataFilters: Array[Filter]): Array[Filter] = Array.empty[Filter]
-
-  /**
-   * Returns whether the file scan builder supports filter pushdown
-   * for non utf8 binary collated columns.
-   */
-  protected def supportsCollationPushDown: Boolean = false
 
   private def createRequiredNameSet(): Set[String] =
     requiredSchema.fields.map(PartitioningUtils.getColName(_, isCaseSensitive)).toSet
 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala
index 5fbe88a09e7cc..229677d208136 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala
@@ -33,7 +33,7 @@ import org.apache.spark.sql.catalyst.expressions.{AttributeReference, GreaterTha
 import org.apache.spark.sql.catalyst.expressions.IntegralLiteralTestUtils.{negativeInt, positiveInt}
 import org.apache.spark.sql.catalyst.plans.logical.Filter
 import org.apache.spark.sql.catalyst.types.DataTypeUtils
-import org.apache.spark.sql.execution.{ExplainMode, FileSourceScanLike, SimpleMode}
+import org.apache.spark.sql.execution.{FileSourceScanLike, SimpleMode}
 import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
 import org.apache.spark.sql.execution.datasources.FilePartition
 import org.apache.spark.sql.execution.datasources.v2.{BatchScanExec, FileScan}
@@ -1242,59 +1242,6 @@ class FileBasedDataSourceSuite extends QueryTest
       }
     }
   }
-
-  test("disable filter pushdown for collated strings") {
-    Seq("parquet").foreach { format =>
-      Seq(format, "").foreach { conf =>
-        withSQLConf(SQLConf.USE_V1_SOURCE_LIST.key -> conf) {
-          withTempPath { path =>
-            val collation = "'UTF8_LCASE'"
-            val df = sql(
-              s"""SELECT
-                 |  COLLATE(c, $collation) as c1,
-                 |  struct(COLLATE(c, $collation)) as str,
-                 |  named_struct('f1', named_struct('f2',
-                 |    COLLATE(c, $collation), 'f3', 1)) as namedstr,
-                 |  array(COLLATE(c, $collation))
as arr, - | map(COLLATE(c, $collation), 1) as map1, - | map(1, COLLATE(c, $collation)) as map2 - |FROM VALUES ('aaa'), ('AAA'), ('bbb') - |as data(c) - |""".stripMargin) - - df.write.format(format).save(path.getAbsolutePath) - - // filter and expected result - val filters = Seq( - ("==", Seq(Row("aaa"), Row("AAA"))), - ("!=", Seq(Row("bbb"))), - ("<", Seq()), - ("<=", Seq(Row("aaa"), Row("AAA"))), - (">", Seq(Row("bbb"))), - (">=", Seq(Row("aaa"), Row("AAA"), Row("bbb")))) - - filters.foreach { filter => - val readback = spark.read - .format(format) - .load(path.getAbsolutePath) - .where(s"c1 ${filter._1} collate('aaa', $collation)") - .where(s"str ${filter._1} struct(collate('aaa', $collation))") - .where(s"namedstr.f1.f2 ${filter._1} collate('aaa', $collation)") - .where(s"arr ${filter._1} array(collate('aaa', $collation))") - .where(s"map_keys(map1) ${filter._1} array(collate('aaa', $collation))") - .where(s"map_values(map2) ${filter._1} array(collate('aaa', $collation))") - .select("c1") - - val explain = readback.queryExecution.explainString( - ExplainMode.fromString("extended")) - assert(explain.contains("PushedFilters: []")) - checkAnswer(readback, filter._2) - } - } - } - } - } - } } object TestingUDT { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/collation/CollatedFilterPushDownToParquetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/collation/CollatedFilterPushDownToParquetSuite.scala new file mode 100644 index 0000000000000..ab8e82162ce10 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/collation/CollatedFilterPushDownToParquetSuite.scala @@ -0,0 +1,265 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.collation + +import org.apache.parquet.schema.MessageType + +import org.apache.spark.SparkConf +import org.apache.spark.sql.{DataFrame, QueryTest} +import org.apache.spark.sql.catalyst.planning.PhysicalOperation +import org.apache.spark.sql.catalyst.util.RebaseDateTime.RebaseSpec +import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper +import org.apache.spark.sql.execution.datasources.{DataSourceStrategy, HadoopFsRelation, LogicalRelation} +import org.apache.spark.sql.execution.datasources.parquet.{ParquetFilters, SparkToParquetSchemaConverter} +import org.apache.spark.sql.execution.datasources.v2.DataSourceV2ScanRelation +import org.apache.spark.sql.execution.datasources.v2.parquet.ParquetScan +import org.apache.spark.sql.internal.{LegacyBehaviorPolicy, SQLConf} +import org.apache.spark.sql.sources.{EqualTo, Filter, IsNotNull} +import org.apache.spark.sql.test.SharedSparkSession + +abstract class CollatedFilterPushDownToParquetSuite extends QueryTest + with SharedSparkSession + with AdaptiveSparkPlanHelper { + + val dataSource = "parquet" + val nonCollatedCol = "c0" + val collatedCol = "c1" + val collatedStructCol = "c2" + val collatedStructNestedCol = "f1" + val collatedStructFieldAccess = s"$collatedStructCol.$collatedStructNestedCol" + val collatedArrayCol = "c3" + val collatedMapCol = "c4" + + val lcaseCollation = "'UTF8_LCASE'" + + def getPushedDownFilters(query: DataFrame): Seq[Filter] + + protected def createParquetFilters(schema: MessageType): ParquetFilters = + new ParquetFilters(schema, conf.parquetFilterPushDownDate, conf.parquetFilterPushDownTimestamp, + conf.parquetFilterPushDownDecimal, conf.parquetFilterPushDownStringPredicate, + conf.parquetFilterPushDownInFilterThreshold, + conf.caseSensitiveAnalysis, + RebaseSpec(LegacyBehaviorPolicy.CORRECTED)) + + def testPushDown( + filterString: String, + expectedPushedFilters: Seq[Filter], + expectedRowCount: Int): Unit = { + withTempPath { path => + val df = sql( + s""" + |SELECT + | c as $nonCollatedCol, + | COLLATE(c, $lcaseCollation) as $collatedCol, + | named_struct('$collatedStructNestedCol', + | COLLATE(c, $lcaseCollation)) as $collatedStructCol, + | array(COLLATE(c, $lcaseCollation)) as $collatedArrayCol, + | map(COLLATE(c, $lcaseCollation), 1) as $collatedMapCol + |FROM VALUES ('aaa'), ('AAA'), ('bbb') + |as data(c) + |""".stripMargin) + + df.write.format(dataSource).save(path.getAbsolutePath) + + val query = spark.read.format(dataSource).load(path.getAbsolutePath) + .filter(filterString) + + val actualPushedFilters = getPushedDownFilters(query) + assert(actualPushedFilters.toSet === expectedPushedFilters.toSet) + assert(query.count() === expectedRowCount) + } + } + + test("do not push down anything for literal comparison") { + testPushDown( + filterString = s"'aaa' COLLATE UNICODE = 'bbb' COLLATE UNICODE", + expectedPushedFilters = Seq.empty, + expectedRowCount = 0) + } + + test("push down null check for collated column") { + testPushDown( + filterString = s"$collatedCol = 'aaa'", + expectedPushedFilters = Seq(IsNotNull(collatedCol)), + expectedRowCount = 2) + } + + test("push down null check for non-equality check") { + testPushDown( + filterString = s"$collatedCol != 'aaa'", + expectedPushedFilters = Seq(IsNotNull(collatedCol)), + expectedRowCount = 1) + } + + test("push down null check for greater than check") { + testPushDown( + filterString = s"$collatedCol > 'aaa'", + expectedPushedFilters = Seq(IsNotNull(collatedCol)), + expectedRowCount = 1) + } + + test("push 
down null check for gte check") { + testPushDown( + filterString = s"$collatedCol >= 'aaa'", + expectedPushedFilters = Seq(IsNotNull(collatedCol)), + expectedRowCount = 3) + } + + test("push down null check for less than check") { + testPushDown( + filterString = s"$collatedCol < 'aaa'", + expectedPushedFilters = Seq(IsNotNull(collatedCol)), + expectedRowCount = 0) + } + + test("push down null check for lte check") { + testPushDown( + filterString = s"$collatedCol <= 'aaa'", + expectedPushedFilters = Seq(IsNotNull(collatedCol)), + expectedRowCount = 2) + } + + test("push down null check for STARTSWITH") { + testPushDown( + filterString = s"STARTSWITH($collatedCol, 'a')", + expectedPushedFilters = Seq(IsNotNull(collatedCol)), + expectedRowCount = 2) + } + + test("push down null check for ENDSWITH") { + testPushDown( + filterString = s"ENDSWITH($collatedCol, 'a')", + expectedPushedFilters = Seq(IsNotNull(collatedCol)), + expectedRowCount = 2) + } + + test("push down null check for CONTAINS") { + testPushDown( + filterString = s"CONTAINS($collatedCol, 'a')", + expectedPushedFilters = Seq(IsNotNull(collatedCol)), + expectedRowCount = 2) + } + + test("no push down for IN") { + testPushDown( + filterString = s"$collatedCol IN ('aaa', 'bbb')", + expectedPushedFilters = Seq.empty, + expectedRowCount = 3) + } + + test("push down null check for equality for non-collated column in AND") { + testPushDown( + filterString = s"$collatedCol = 'aaa' AND $nonCollatedCol = 'aaa'", + expectedPushedFilters = + Seq(IsNotNull(collatedCol), IsNotNull(nonCollatedCol), EqualTo(nonCollatedCol, "aaa")), + expectedRowCount = 1) + } + + test("for OR do not push down anything") { + testPushDown( + filterString = s"$collatedCol = 'aaa' OR $nonCollatedCol = 'aaa'", + expectedPushedFilters = Seq.empty, + expectedRowCount = 2) + } + + test("mix OR and AND") { + testPushDown( + filterString = s"$collatedCol = 'aaa' AND ($nonCollatedCol = 'aaa' OR $collatedCol = 'aaa')", + expectedPushedFilters = Seq(IsNotNull(collatedCol)), + expectedRowCount = 2) + } + + test("negate check on collated column") { + testPushDown( + filterString = s"NOT($collatedCol == 'aaa')", + expectedPushedFilters = Seq(IsNotNull(collatedCol)), + expectedRowCount = 1) + } + + test("compare entire struct - parquet does not support null check on complex types") { + testPushDown( + filterString = s"$collatedStructCol = " + + s"named_struct('$collatedStructNestedCol', collate('aaa', $lcaseCollation))", + expectedPushedFilters = Seq.empty, + expectedRowCount = 2) + } + + test("inner struct field access") { + testPushDown( + filterString = s"$collatedStructFieldAccess = 'aaa'", + expectedPushedFilters = Seq(IsNotNull(collatedStructFieldAccess)), + expectedRowCount = 2) + } + + test("array - parquet does not support null check on complex types") { + testPushDown( + filterString = s"$collatedArrayCol = array(collate('aaa', $lcaseCollation))", + expectedPushedFilters = Seq.empty, + expectedRowCount = 2) + } + + test("map - parquet does not support null check on complex types") { + testPushDown( + filterString = s"map_keys($collatedMapCol) != array(collate('aaa', $lcaseCollation))", + expectedPushedFilters = Seq.empty, + expectedRowCount = 1) + } +} + +class CollatedFilterPushDownToParquetV1Suite extends CollatedFilterPushDownToParquetSuite { + override protected def sparkConf: SparkConf = + super + .sparkConf + .set(SQLConf.USE_V1_SOURCE_LIST, dataSource) + + override def getPushedDownFilters(query: DataFrame): Seq[Filter] = { + var maybeRelation: 
Option[HadoopFsRelation] = None + val maybeAnalyzedPredicate = query.queryExecution.optimizedPlan.collect { + case PhysicalOperation(_, filters, + LogicalRelation(relation: HadoopFsRelation, _, _, _)) => + maybeRelation = Some(relation) + filters + }.flatten + + if (maybeAnalyzedPredicate.isEmpty) { + return Seq.empty + } + + val (_, selectedFilters, _) = + DataSourceStrategy.selectFilters(maybeRelation.get, maybeAnalyzedPredicate) + + val schema = new SparkToParquetSchemaConverter(conf).convert(query.schema) + val parquetFilters = createParquetFilters(schema) + parquetFilters.convertibleFilters(selectedFilters) + } +} + +class CollatedFilterPushDownToParquetV2Suite extends CollatedFilterPushDownToParquetSuite { + override protected def sparkConf: SparkConf = + super + .sparkConf + .set(SQLConf.USE_V1_SOURCE_LIST, "") + + override def getPushedDownFilters(query: DataFrame): Seq[Filter] = { + query.queryExecution.optimizedPlan.collectFirst { + case PhysicalOperation(_, _, + DataSourceV2ScanRelation(_, scan: ParquetScan, _, _, _)) => + scan.pushedFilters.toSeq + }.getOrElse(Seq.empty) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategySuite.scala index 834225baf070e..9f0396ab60e32 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategySuite.scala @@ -357,36 +357,4 @@ class DataSourceStrategySuite extends PlanTest with SharedSparkSession { Some(sources.GreaterThanOrEqual("col", "value"))) testTranslateFilter(IsNotNull(colAttr), Some(sources.IsNotNull("col"))) } - - for (collation <- Seq("UTF8_LCASE", "UNICODE")) { - test(s"SPARK-48431: Filter pushdown on columns with $collation collation") { - val colAttr = $"col".string(collation) - - // No pushdown for all comparison based filters. - testTranslateFilter(EqualTo(colAttr, Literal("value")), Some(sources.AlwaysTrue)) - testTranslateFilter(LessThan(colAttr, Literal("value")), Some(sources.AlwaysTrue)) - testTranslateFilter(LessThan(colAttr, Literal("value")), Some(sources.AlwaysTrue)) - testTranslateFilter(LessThanOrEqual(colAttr, Literal("value")), Some(sources.AlwaysTrue)) - testTranslateFilter(GreaterThan(colAttr, Literal("value")), Some(sources.AlwaysTrue)) - testTranslateFilter(GreaterThanOrEqual(colAttr, Literal("value")), Some(sources.AlwaysTrue)) - - // Allow pushdown of Is(Not)Null filter. - testTranslateFilter(IsNotNull(colAttr), Some(sources.IsNotNull("col"))) - testTranslateFilter(IsNull(colAttr), Some(sources.IsNull("col"))) - - // Top level filter splitting at And and Or. - testTranslateFilter(And(EqualTo(colAttr, Literal("value")), IsNotNull(colAttr)), - Some(sources.And(sources.AlwaysTrue, sources.IsNotNull("col")))) - testTranslateFilter(Or(EqualTo(colAttr, Literal("value")), IsNotNull(colAttr)), - Some(sources.Or(sources.AlwaysTrue, sources.IsNotNull("col")))) - - // Different cases involving Not. - testTranslateFilter(Not(EqualTo(colAttr, Literal("value"))), Some(sources.AlwaysTrue)) - testTranslateFilter(And(Not(EqualTo(colAttr, Literal("value"))), IsNotNull(colAttr)), - Some(sources.And(sources.AlwaysTrue, sources.IsNotNull("col")))) - // This filter would work, but we want to keep the translation logic simple. 
-      testTranslateFilter(And(EqualTo(colAttr, Literal("value")), Not(IsNotNull(colAttr))),
-        Some(sources.And(sources.AlwaysTrue, sources.AlwaysTrue)))
-    }
-  }
 }