diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/variant/variantExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/variant/variantExpressions.scala
index c19df82e6576b..ba910b8c7e5fd 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/variant/variantExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/variant/variantExpressions.scala
@@ -427,24 +427,15 @@ case object VariantGet {
         messageParameters = Map("id" -> v.getTypeInfo.toString)
       )
     }
-    // We mostly use the `Cast` expression to implement the cast. However, `Cast` silently
-    // ignores the overflow in the long/decimal -> timestamp cast, and we want to enforce
-    // strict overflow checks.
     input.dataType match {
       case LongType if dataType == TimestampType =>
-        try Math.multiplyExact(input.value.asInstanceOf[Long], MICROS_PER_SECOND)
+        try castLongToTimestamp(input.value.asInstanceOf[Long])
        catch {
           case _: ArithmeticException => invalidCast()
         }
       case _: DecimalType if dataType == TimestampType =>
-        try {
-          input.value
-            .asInstanceOf[Decimal]
-            .toJavaBigDecimal
-            .multiply(new java.math.BigDecimal(MICROS_PER_SECOND))
-            .toBigInteger
-            .longValueExact()
-        } catch {
+        try castDecimalToTimestamp(input.value.asInstanceOf[Decimal])
+        catch {
           case _: ArithmeticException => invalidCast()
         }
       case _ =>
@@ -497,6 +488,27 @@ case object VariantGet {
       }
     }
   }
+
+  // We mostly use the `Cast` expression to implement the cast, but we need some custom logic for
+  // certain type combinations.
+  //
+  // `castLongToTimestamp/castDecimalToTimestamp`: `Cast` silently ignores the overflow in the
+  // long/decimal -> timestamp cast, and we want to enforce strict overflow checks. They both throw
+  // an `ArithmeticException` when overflow happens.
+  def castLongToTimestamp(input: Long): Long =
+    Math.multiplyExact(input, MICROS_PER_SECOND)
+
+  def castDecimalToTimestamp(input: Decimal): Long = {
+    val multiplier = new java.math.BigDecimal(MICROS_PER_SECOND)
+    input.toJavaBigDecimal.multiply(multiplier).toBigInteger.longValueExact()
+  }
+
+  // Cast decimal to string, but strip any trailing zeros. We don't have to call it if the decimal
+  // is returned by `Variant.getDecimal`, which already strips any trailing zeros. But we need it
+  // if the decimal is produced by Spark internally, e.g., on a shredded decimal produced by the
+  // Spark Parquet reader.
+  def castDecimalToString(input: Decimal): UTF8String =
+    UTF8String.fromString(input.toJavaBigDecimal.stripTrailingZeros.toPlainString)
 }
 
 abstract class ParseJsonExpressionBuilderBase(failOnError: Boolean) extends ExpressionBuilder {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/SparkShreddingUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/SparkShreddingUtils.scala
index a83ca78455faa..34c167aea363a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/SparkShreddingUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/SparkShreddingUtils.scala
@@ -111,7 +111,31 @@ case class ScalarCastHelper(
     } else {
       ""
     }
-    if (cast != null) {
+    val customCast = (child.dataType, dataType) match {
+      case (_: LongType, _: TimestampType) => "castLongToTimestamp"
+      case (_: DecimalType, _: TimestampType) => "castDecimalToTimestamp"
+      case (_: DecimalType, _: StringType) => "castDecimalToString"
+      case _ => null
+    }
+    if (customCast != null) {
+      val childCode = child.genCode(ctx)
+      // We can avoid the try-catch block for decimal -> string, but the performance benefit is
+      // little. We can also be more specific in the exception type, like catching
+      // `ArithmeticException` instead of `Exception`, but it is unnecessary. The `try_cast`
+      // codegen also catches `Exception` instead of specific exceptions.
+      val code = code"""
+        ${childCode.code}
+        boolean ${ev.isNull} = false;
+        ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
+        try {
+          ${ev.value} = ${classOf[VariantGet].getName}.$customCast(${childCode.value});
+        } catch (Exception e) {
+          ${ev.isNull} = true;
+          $invalidCastCode
+        }
+      """
+      ev.copy(code = code)
+    } else if (cast != null) {
       val castCode = cast.genCode(ctx)
       val code = code"""
         ${castCode.code}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/VariantShreddingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/VariantShreddingSuite.scala
index b6623bb57a716..3443028ba45b0 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/VariantShreddingSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/VariantShreddingSuite.scala
@@ -24,6 +24,8 @@ import java.time.LocalDateTime
 import org.apache.spark.SparkThrowable
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.variant.VariantExpressionEvalUtils
+import org.apache.spark.sql.catalyst.util.DateTimeConstants._
+import org.apache.spark.sql.catalyst.util.DateTimeUtils._
 import org.apache.spark.sql.execution.datasources.parquet.{ParquetTest, SparkShreddingUtils}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SharedSparkSession
@@ -349,4 +351,33 @@ class VariantShreddingSuite extends QueryTest with SharedSparkSession with Parqu
     checkExpr(path, "variant_get(v, '$.a')", null,
       parseJson("null"), parseJson("1"), null, parseJson("null"), parseJson("3"))
   }
+
+  testWithTempPath("custom casts") { path =>
+    writeRows(path, writeSchema(LongType),
+      Row(metadata(Nil), null, Long.MaxValue / MICROS_PER_SECOND + 1),
+      Row(metadata(Nil), null, Long.MaxValue / MICROS_PER_SECOND))
+
+    // long -> timestamp
+    checkException(path, "cast(v as timestamp)", "INVALID_VARIANT_CAST")
+    checkExpr(path, "try_cast(v as timestamp)",
+      null, toJavaTimestamp(Long.MaxValue / MICROS_PER_SECOND * MICROS_PER_SECOND))
+
+    writeRows(path, writeSchema(DecimalType(38, 19)),
+      Row(metadata(Nil), null, Decimal("1E18")),
+      Row(metadata(Nil), null, Decimal("100")),
+      Row(metadata(Nil), null, Decimal("10")),
+      Row(metadata(Nil), null, Decimal("1")),
+      Row(metadata(Nil), null, Decimal("0")),
+      Row(metadata(Nil), null, Decimal("0.1")),
+      Row(metadata(Nil), null, Decimal("0.01")),
+      Row(metadata(Nil), null, Decimal("1E-18")))
+
+    // decimal -> timestamp
+    checkException(path, "cast(v as timestamp)", "INVALID_VARIANT_CAST")
+    checkExpr(path, "try_cast(v as timestamp)",
+      (null +: Seq(100000000, 10000000, 1000000, 0, 100000, 10000, 0).map(toJavaTimestamp(_))): _*)
+    // decimal -> string
+    checkExpr(path, "cast(v as string)",
+      "1000000000000000000", "100", "10", "1", "0", "0.1", "0.01", "0.000000000000000001")
+  }
 }