Commit: initial

chenhao-db committed Dec 25, 2024
1 parent 9c9bdab commit 9fafdfa

Showing 3 changed files with 80 additions and 13 deletions.
@@ -427,24 +427,15 @@ case object VariantGet {
         messageParameters = Map("id" -> v.getTypeInfo.toString)
       )
     }
-    // We mostly use the `Cast` expression to implement the cast. However, `Cast` silently
-    // ignores the overflow in the long/decimal -> timestamp cast, and we want to enforce
-    // strict overflow checks.
     input.dataType match {
       case LongType if dataType == TimestampType =>
-        try Math.multiplyExact(input.value.asInstanceOf[Long], MICROS_PER_SECOND)
+        try castLongToTimestamp(input.value.asInstanceOf[Long])
         catch {
           case _: ArithmeticException => invalidCast()
         }
       case _: DecimalType if dataType == TimestampType =>
-        try {
-          input.value
-            .asInstanceOf[Decimal]
-            .toJavaBigDecimal
-            .multiply(new java.math.BigDecimal(MICROS_PER_SECOND))
-            .toBigInteger
-            .longValueExact()
-        } catch {
+        try castDecimalToTimestamp(input.value.asInstanceOf[Decimal])
+        catch {
           case _: ArithmeticException => invalidCast()
         }
       case _ =>
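
Why the strict check matters: plain `Long` multiplication on the JVM wraps around
silently, while `Math.multiplyExact` throws `ArithmeticException` on overflow. A
minimal sketch of the difference, pasteable into a Scala REPL (the constant mirrors
Spark's `DateTimeConstants.MICROS_PER_SECOND`):

    val MICROS_PER_SECOND = 1000000L                    // DateTimeConstants.MICROS_PER_SECOND
    val seconds = Long.MaxValue / MICROS_PER_SECOND + 1 // smallest second count that overflows

    // Plain multiplication wraps around to a meaningless negative value.
    println(seconds * MICROS_PER_SECOND)                // -9223372036854551616

    // multiplyExact surfaces the overflow, which VariantGet turns into an
    // INVALID_VARIANT_CAST error (or a null result under try_cast).
    try Math.multiplyExact(seconds, MICROS_PER_SECOND)
    catch { case e: ArithmeticException => println(s"overflow: $e") }
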
@@ -497,6 +488,27 @@ case object VariantGet {
       }
     }
   }
+
+  // We mostly use the `Cast` expression to implement the cast, but we need some custom logic for
+  // certain type combinations.
+  //
+  // `castLongToTimestamp/castDecimalToTimestamp`: `Cast` silently ignores the overflow in the
+  // long/decimal -> timestamp cast, and we want to enforce strict overflow checks. They both throw
+  // an `ArithmeticException` when overflow happens.
+  def castLongToTimestamp(input: Long): Long =
+    Math.multiplyExact(input, MICROS_PER_SECOND)
+
+  def castDecimalToTimestamp(input: Decimal): Long = {
+    val multiplier = new java.math.BigDecimal(MICROS_PER_SECOND)
+    input.toJavaBigDecimal.multiply(multiplier).toBigInteger.longValueExact()
+  }
+
+  // Cast decimal to string, but strip any trailing zeros. We don't have to call it if the decimal
+  // is returned by `Variant.getDecimal`, which already strips any trailing zeros. But we need it
+  // if the decimal is produced by Spark internally, e.g., on a shredded decimal produced by the
+  // Spark Parquet reader.
+  def castDecimalToString(input: Decimal): UTF8String =
+    UTF8String.fromString(input.toJavaBigDecimal.stripTrailingZeros.toPlainString)
 }

 abstract class ParseJsonExpressionBuilderBase(failOnError: Boolean) extends ExpressionBuilder {
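
`castDecimalToTimestamp` gets the same strictness from `longValueExact`, which throws
`ArithmeticException` when the scaled value does not fit in a `Long`; note that
`toBigInteger` truncates sub-microsecond digits rather than rounding. A REPL-pasteable
sketch using plain `java.math.BigDecimal`, skipping the `Decimal` unwrapping that the
real helper does:

    import java.math.{BigDecimal => JBigDecimal}

    val multiplier = new JBigDecimal(1000000L)   // MICROS_PER_SECOND
    def toMicros(seconds: JBigDecimal): Long =
      seconds.multiply(multiplier).toBigInteger.longValueExact()

    println(toMicros(new JBigDecimal("0.1")))    // 100000
    println(toMicros(new JBigDecimal("1E-18")))  // 0: sub-microsecond digits truncate
    try toMicros(new JBigDecimal("1E18"))        // 1E24 microseconds cannot fit in a Long
    catch { case e: ArithmeticException => println(s"overflow: $e") }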
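
A subtlety behind `castDecimalToString`: `stripTrailingZeros` alone can leave a
`BigDecimal` in scientific notation (stripping `100.00` yields `1E+2`), so the helper
finishes with `toPlainString` rather than `toString`. For example:

    import java.math.{BigDecimal => JBigDecimal}

    // A shredded DecimalType(38, 19) value of 100 carries 19 trailing zeros.
    val d = new JBigDecimal("100.0000000000000000000")
    println(d.stripTrailingZeros.toString)       // 1E+2 -- scientific notation
    println(d.stripTrailingZeros.toPlainString)  // 100
    println(new JBigDecimal("0.0100").stripTrailingZeros.toPlainString) // 0.01
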
@@ -111,7 +111,31 @@ case class ScalarCastHelper(
     } else {
       ""
     }
-    if (cast != null) {
+    val customCast = (child.dataType, dataType) match {
+      case (_: LongType, _: TimestampType) => "castLongToTimestamp"
+      case (_: DecimalType, _: TimestampType) => "castDecimalToTimestamp"
+      case (_: DecimalType, _: StringType) => "castDecimalToString"
+      case _ => null
+    }
+    if (customCast != null) {
+      val childCode = child.genCode(ctx)
+      // We can avoid the try-catch block for decimal -> string, but the performance benefit is
+      // little. We can also be more specific in the exception type, like catching
+      // `ArithmeticException` instead of `Exception`, but it is unnecessary. The `try_cast` codegen
+      // also catches `Exception` instead of specific exceptions.
+      val code = code"""
+        ${childCode.code}
+        boolean ${ev.isNull} = false;
+        ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
+        try {
+          ${ev.value} = ${classOf[VariantGet].getName}.$customCast(${childCode.value});
+        } catch (Exception e) {
+          ${ev.isNull} = true;
+          $invalidCastCode
+        }
+      """
+      ev.copy(code = code)
+    } else if (cast != null) {
       val castCode = cast.genCode(ctx)
       val code = code"""
         ${castCode.code}
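
For readers less familiar with Catalyst codegen: once the interpolations are filled in,
the template becomes straight-line Java. A hypothetical expansion for a long ->
timestamp cast might look roughly like the following, where the `value_N`/`isNull_N`
names stand in for whatever fresh names the codegen context allocates and the child
evaluation is elided:

    // Hypothetical expansion, for illustration only.
    long value_1 = ...;       // ${childCode.code} evaluates the shredded long here
    boolean isNull_2 = false;
    long value_2 = -1L;       // ${CodeGenerator.defaultValue(dataType)}
    try {
      value_2 = org.apache.spark.sql.catalyst.expressions.variant.VariantGet
        .castLongToTimestamp(value_1);
    } catch (Exception e) {
      isNull_2 = true;
      // $invalidCastCode: throws INVALID_VARIANT_CAST for `cast`; empty for
      // `try_cast`, which leaves a null result instead.
    }
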
@@ -24,6 +24,8 @@ import java.time.LocalDateTime
 import org.apache.spark.SparkThrowable
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.variant.VariantExpressionEvalUtils
+import org.apache.spark.sql.catalyst.util.DateTimeConstants._
+import org.apache.spark.sql.catalyst.util.DateTimeUtils._
 import org.apache.spark.sql.execution.datasources.parquet.{ParquetTest, SparkShreddingUtils}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SharedSparkSession
@@ -349,4 +351,33 @@ class VariantShreddingSuite extends QueryTest with SharedSparkSession with Parqu
     checkExpr(path, "variant_get(v, '$.a')", null, parseJson("null"), parseJson("1"), null,
       parseJson("null"), parseJson("3"))
   }
+
+  testWithTempPath("custom casts") { path =>
+    writeRows(path, writeSchema(LongType),
+      Row(metadata(Nil), null, Long.MaxValue / MICROS_PER_SECOND + 1),
+      Row(metadata(Nil), null, Long.MaxValue / MICROS_PER_SECOND))
+
+    // long -> timestamp
+    checkException(path, "cast(v as timestamp)", "INVALID_VARIANT_CAST")
+    checkExpr(path, "try_cast(v as timestamp)",
+      null, toJavaTimestamp(Long.MaxValue / MICROS_PER_SECOND * MICROS_PER_SECOND))
+
+    writeRows(path, writeSchema(DecimalType(38, 19)),
+      Row(metadata(Nil), null, Decimal("1E18")),
+      Row(metadata(Nil), null, Decimal("100")),
+      Row(metadata(Nil), null, Decimal("10")),
+      Row(metadata(Nil), null, Decimal("1")),
+      Row(metadata(Nil), null, Decimal("0")),
+      Row(metadata(Nil), null, Decimal("0.1")),
+      Row(metadata(Nil), null, Decimal("0.01")),
+      Row(metadata(Nil), null, Decimal("1E-18")))
+
+    checkException(path, "cast(v as timestamp)", "INVALID_VARIANT_CAST")
+    // decimal -> timestamp
+    checkExpr(path, "try_cast(v as timestamp)",
+      (null +: Seq(100000000, 10000000, 1000000, 0, 100000, 10000, 0).map(toJavaTimestamp(_))): _*)
+    // decimal -> string
+    checkExpr(path, "cast(v as string)",
+      "1000000000000000000", "100", "10", "1", "0", "0.1", "0.01", "0.000000000000000001")
+  }
 }
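
The long-typed rows sit exactly on the overflow boundary: Long.MaxValue / MICROS_PER_SECOND
(9223372036854 seconds) is the largest whole-second value whose microsecond form still fits
in a Long, so the second row survives try_cast while the first becomes null. The arithmetic,
checkable in a Scala REPL:

    val maxSeconds = Long.MaxValue / 1000000L      // 9223372036854
    println(maxSeconds * 1000000L)                 // 9223372036854000000 <= Long.MaxValue
    println(BigInt(maxSeconds + 1) * 1000000L)     // 9223372036855000000 >  Long.MaxValue
    // The expected non-null try_cast value is therefore
    // toJavaTimestamp(9223372036854000000L).
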
