From c0179343f773a2a6d64c8de0906006063e6df0b4 Mon Sep 17 00:00:00 2001
From: Chenhao Li
Date: Thu, 19 Dec 2024 13:16:01 -0800
Subject: [PATCH] minor fix

---
 .../spark/sql/execution/SparkOptimizer.scala  |  6 ++--
 .../datasources/PushVariantIntoScan.scala     | 28 +++++++++++--------
 2 files changed, 20 insertions(+), 14 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala
index d5f70afb70dcc..a51870cfd7fdd 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala
@@ -43,7 +43,8 @@ class SparkOptimizer(
     V2ScanRelationPushDown,
     V2ScanPartitioningAndOrdering,
     V2Writes,
-    PruneFileSourcePartitions)
+    PruneFileSourcePartitions,
+    PushVariantIntoScan)
 
   override def preCBORules: Seq[Rule[LogicalPlan]] =
     Seq(OptimizeMetadataOnlyDeleteFromTable)
@@ -95,8 +96,7 @@ class SparkOptimizer(
       EliminateLimits,
       ConstantFolding),
     Batch("User Provided Optimizers", fixedPoint, experimentalMethods.extraOptimizations: _*),
-    Batch("Replace CTE with Repartition", Once, ReplaceCTERefWithRepartition),
-    Batch("Push Variant Into Scan", Once, PushVariantIntoScan)))
+    Batch("Replace CTE with Repartition", Once, ReplaceCTERefWithRepartition)))
 
   override def nonExcludableRules: Seq[String] = super.nonExcludableRules ++
     Seq(
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PushVariantIntoScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PushVariantIntoScan.scala
index 6bf6cd770ce26..83d219c28983b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PushVariantIntoScan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PushVariantIntoScan.scala
@@ -143,8 +143,11 @@ class VariantInRelation {
   // Find eligible variants recursively. `attrId` is the root attribute id.
   // `path` is the current struct access path. `dataType` is the child data type after extracting
   // `path` from the root attribute struct.
-  def addVariantFields(attrId: ExprId, dataType: DataType, defaultValue: Any,
-      path: Seq[Int]): Unit = {
+  def addVariantFields(
+      attrId: ExprId,
+      dataType: DataType,
+      defaultValue: Any,
+      path: Seq[Int]): Unit = {
     dataType match {
       // TODO(SHREDDING): non-null default value is not yet supported.
       case _: VariantType if defaultValue == null =>
@@ -195,8 +198,9 @@ class VariantInRelation {
   }
 
   // Add a requested field to a variant column.
-  private def addField(map: HashMap[RequestedVariantField, Int],
-      field: RequestedVariantField): Unit = {
+  private def addField(
+      map: HashMap[RequestedVariantField, Int],
+      field: RequestedVariantField): Unit = {
     val idx = map.size
     map.getOrElseUpdate(field, idx)
   }
@@ -227,8 +231,9 @@
     case _ => expr.children.foreach(collectRequestedFields)
   }
 
-  def rewriteExpr(expr: Expression,
-      attributeMap: Map[ExprId, AttributeReference]): Expression = {
+  def rewriteExpr(
+      expr: Expression,
+      attributeMap: Map[ExprId, AttributeReference]): Expression = {
     def rewriteAttribute(expr: Expression): Expression = expr.transformDown {
       case a: Attribute => attributeMap.getOrElse(a.exprId, a)
     }
@@ -275,11 +280,12 @@ object PushVariantIntoScan extends Rule[LogicalPlan] {
     }
   }
 
-  private def rewritePlan(originalPlan: LogicalPlan,
-      projectList: Seq[NamedExpression],
-      filters: Seq[Expression],
-      relation: LogicalRelation,
-      hadoopFsRelation: HadoopFsRelation): LogicalPlan = {
+  private def rewritePlan(
+      originalPlan: LogicalPlan,
+      projectList: Seq[NamedExpression],
+      filters: Seq[Expression],
+      relation: LogicalRelation,
+      hadoopFsRelation: HadoopFsRelation): LogicalPlan = {
     val variants = new VariantInRelation
     val defaultValues = ResolveDefaultColumns.existenceDefaultValues(hadoopFsRelation.schema)
     // I'm not aware of any case that an attribute `relation.output` can have a different data type