Skip to content

Commit

Permalink
#398 Do not correct decimals in schema if it is correct already.
Browse files Browse the repository at this point in the history
  • Loading branch information
yruslan committed May 21, 2024
1 parent 1cebc22 commit 562960b
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -276,13 +276,15 @@ object JdbcSparkUtils {
case t: DecimalType if t.scale == 0 && t.precision <= 18 =>
log.info(s"Correct '${field.name}' (prec=${t.precision}, scale=${t.scale}) to long")
newSchema += s"${field.name} long"
case t: DecimalType if t.scale >= 18 =>
case t: DecimalType if t.scale > 18 =>
log.info(s"Correct '${field.name}' (prec=${t.precision}, scale=${t.scale}) to decimal(38, 18)")
newSchema += s"${field.name} decimal(38, 18)"
case t: DecimalType if fixPrecision && t.scale > 0 =>
val fixedPrecision = if (t.precision + t.scale > 38) 38 else t.precision + t.scale
log.info(s"Correct '${field.name}' (prec=${t.precision}, scale=${t.scale}) to decimal($fixedPrecision, ${t.scale})")
newSchema += s"${field.name} decimal($fixedPrecision, ${t.scale})"
if (fixedPrecision > t.precision) {
log.info(s"Correct '${field.name}' (prec=${t.precision}, scale=${t.scale}) to decimal($fixedPrecision, ${t.scale})")
newSchema += s"${field.name} decimal($fixedPrecision, ${t.scale})"
}
case _ =>
field
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,19 @@ class JdbcSparkUtilsSuite extends AnyWordSpec with BeforeAndAfterAll with SparkT
assert(customFields.contains("value decimal(38, 18)"))
}

"correct invalid precision with small scale" in {
  // The declared column type is decimal(30, 16), while the rows themselves are
  // produced by a cast to decimal(38, 18).
  val declaredSchema = StructType(Array(StructField("value", DecimalType(30, 16))))

  val castedDf: DataFrame = Seq("1234567890")
    .toDF("value")
    .withColumn("value", $"value".cast("decimal(38, 18)"))

  // Re-create the data frame so that it carries the declared schema.
  val dfWithDeclaredSchema = spark.createDataFrame(castedDf.rdd, declaredSchema)

  val correctedFields =
    JdbcSparkUtils.getCorrectedDecimalsSchema(dfWithDeclaredSchema, fixPrecision = true)

  // With precision fixing enabled, the precision is expected to be widened to 38
  // and the original scale of 16 retained.
  assert(correctedFields.contains("value decimal(38, 16)"))
}

"do nothing if the field is okay" in {
val df: DataFrame = Seq("12345").toDF("value")
.withColumn("value", $"value".cast("int"))
Expand Down

0 comments on commit 562960b

Please sign in to comment.