From 176e7eee6a4aba510c4a5f21d93192f7c542687c Mon Sep 17 00:00:00 2001 From: Chenhao Li Date: Wed, 18 Dec 2024 17:06:02 -0800 Subject: [PATCH] initial --- .../spark/sql/catalyst/expressions/Cast.scala | 13 ++-- .../variant/variantExpressions.scala | 61 ++++++++----------- 2 files changed, 34 insertions(+), 40 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index d4ebdf10ef11f..abd635e22f261 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -572,6 +572,11 @@ case class Cast( } } + private lazy val castArgs = variant.VariantCastArgs( + evalMode != EvalMode.TRY, + timeZoneId, + zoneId) + def needsTimeZone: Boolean = Cast.needsTimeZone(child.dataType, dataType) // [[func]] assumes the input is no longer null because eval already does the null check. @@ -1127,7 +1132,7 @@ case class Cast( _ => throw QueryExecutionErrors.cannotCastFromNullTypeError(to) } else if (from.isInstanceOf[VariantType]) { buildCast[VariantVal](_, v => { - variant.VariantGet.cast(v, to, evalMode != EvalMode.TRY, timeZoneId, zoneId) + variant.VariantGet.cast(v, to, castArgs) }) } else { to match { @@ -1225,12 +1230,10 @@ case class Cast( case _ if from.isInstanceOf[VariantType] => (c, evPrim, evNull) => val tmp = ctx.freshVariable("tmp", classOf[Object]) val dataTypeArg = ctx.addReferenceObj("dataType", to) - val zoneStrArg = ctx.addReferenceObj("zoneStr", timeZoneId) - val zoneIdArg = ctx.addReferenceObj("zoneId", zoneId, classOf[ZoneId].getName) - val failOnError = evalMode != EvalMode.TRY + val castArgsArg = ctx.addReferenceObj("castArgs", castArgs) val cls = classOf[variant.VariantGet].getName code""" - Object $tmp = $cls.cast($c, $dataTypeArg, $failOnError, $zoneStrArg, $zoneIdArg); + Object $tmp = $cls.cast($c, $dataTypeArg, $castArgsArg); if ($tmp == null) { $evNull = true; } else { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/variant/variantExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/variant/variantExpressions.scala index 2fa0ce0f570c9..c19df82e6576b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/variant/variantExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/variant/variantExpressions.scala @@ -278,14 +278,13 @@ case class VariantGet( override def nullable: Boolean = true override def nullIntolerant: Boolean = true + private lazy val castArgs = VariantCastArgs( + failOnError, + timeZoneId, + zoneId) + protected override def nullSafeEval(input: Any, path: Any): Any = { - VariantGet.variantGet( - input.asInstanceOf[VariantVal], - parsedPath, - dataType, - failOnError, - timeZoneId, - zoneId) + VariantGet.variantGet(input.asInstanceOf[VariantVal], parsedPath, dataType, castArgs) } protected override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { @@ -293,15 +292,14 @@ case class VariantGet( val tmp = ctx.freshVariable("tmp", classOf[Object]) val parsedPathArg = ctx.addReferenceObj("parsedPath", parsedPath) val dataTypeArg = ctx.addReferenceObj("dataType", dataType) - val zoneStrArg = ctx.addReferenceObj("zoneStr", timeZoneId) - val zoneIdArg = ctx.addReferenceObj("zoneId", zoneId, classOf[ZoneId].getName) + val castArgsArg = ctx.addReferenceObj("castArgs", castArgs) val code = code""" ${childCode.code} boolean ${ev.isNull} = ${childCode.isNull}; ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; if (!${ev.isNull}) { Object $tmp = org.apache.spark.sql.catalyst.expressions.variant.VariantGet.variantGet( - ${childCode.value}, $parsedPathArg, $dataTypeArg, $failOnError, $zoneStrArg, $zoneIdArg); + ${childCode.value}, $parsedPathArg, $dataTypeArg, $castArgsArg); if ($tmp == null) { ${ev.isNull} = true; } else { @@ -323,6 +321,12 @@ case class VariantGet( override def withTimeZone(timeZoneId: String): VariantGet = copy(timeZoneId = Option(timeZoneId)) } +// Several parameters used by `VariantGet.cast`. Packed together to simplify parameter passing. +case class VariantCastArgs( + failOnError: Boolean, + zoneStr: Option[String], + zoneId: ZoneId) + case object VariantGet { /** * Returns whether a data type can be cast into/from variant. For scalar types, we allow a subset @@ -347,9 +351,7 @@ case object VariantGet { input: VariantVal, parsedPath: Array[VariantPathParser.PathSegment], dataType: DataType, - failOnError: Boolean, - zoneStr: Option[String], - zoneId: ZoneId): Any = { + castArgs: VariantCastArgs): Any = { var v = new Variant(input.getValue, input.getMetadata) for (path <- parsedPath) { v = path match { @@ -359,21 +361,16 @@ case object VariantGet { } if (v == null) return null } - VariantGet.cast(v, dataType, failOnError, zoneStr, zoneId) + VariantGet.cast(v, dataType, castArgs) } /** * A simple wrapper of the `cast` function that takes `Variant` rather than `VariantVal`. The * `Cast` expression uses it and makes the implementation simpler. */ - def cast( - input: VariantVal, - dataType: DataType, - failOnError: Boolean, - zoneStr: Option[String], - zoneId: ZoneId): Any = { + def cast(input: VariantVal, dataType: DataType, castArgs: VariantCastArgs): Any = { val v = new Variant(input.getValue, input.getMetadata) - VariantGet.cast(v, dataType, failOnError, zoneStr, zoneId) + VariantGet.cast(v, dataType, castArgs) } /** @@ -383,15 +380,10 @@ case object VariantGet { * "hello" to int). If the cast fails, throw an exception when `failOnError` is true, or return a * SQL NULL when it is false. */ - def cast( - v: Variant, - dataType: DataType, - failOnError: Boolean, - zoneStr: Option[String], - zoneId: ZoneId): Any = { + def cast(v: Variant, dataType: DataType, castArgs: VariantCastArgs): Any = { def invalidCast(): Any = { - if (failOnError) { - throw QueryExecutionErrors.invalidVariantCast(v.toJson(zoneId), dataType) + if (castArgs.failOnError) { + throw QueryExecutionErrors.invalidVariantCast(v.toJson(castArgs.zoneId), dataType) } else { null } @@ -411,7 +403,7 @@ case object VariantGet { val input = variantType match { case Type.OBJECT | Type.ARRAY => return if (dataType.isInstanceOf[StringType]) { - UTF8String.fromString(v.toJson(zoneId)) + UTF8String.fromString(v.toJson(castArgs.zoneId)) } else { invalidCast() } @@ -457,7 +449,7 @@ case object VariantGet { } case _ => if (Cast.canAnsiCast(input.dataType, dataType)) { - val result = Cast(input, dataType, zoneStr, EvalMode.TRY).eval() + val result = Cast(input, dataType, castArgs.zoneStr, EvalMode.TRY).eval() if (result == null) invalidCast() else result } else { invalidCast() @@ -468,7 +460,7 @@ case object VariantGet { val size = v.arraySize() val array = new Array[Any](size) for (i <- 0 until size) { - array(i) = cast(v.getElementAtIndex(i), elementType, failOnError, zoneStr, zoneId) + array(i) = cast(v.getElementAtIndex(i), elementType, castArgs) } new GenericArrayData(array) } else { @@ -482,7 +474,7 @@ case object VariantGet { for (i <- 0 until size) { val field = v.getFieldAtIndex(i) keyArray(i) = UTF8String.fromString(field.key) - valueArray(i) = cast(field.value, valueType, failOnError, zoneStr, zoneId) + valueArray(i) = cast(field.value, valueType, castArgs) } ArrayBasedMapData(keyArray, valueArray) } else { @@ -495,8 +487,7 @@ case object VariantGet { val field = v.getFieldAtIndex(i) st.getFieldIndex(field.key) match { case Some(idx) => - row.update(idx, - cast(field.value, fields(idx).dataType, failOnError, zoneStr, zoneId)) + row.update(idx, cast(field.value, fields(idx).dataType, castArgs)) case _ => } }