diff --git a/pramen/core/src/main/scala/za/co/absa/pramen/core/utils/JdbcSparkUtils.scala b/pramen/core/src/main/scala/za/co/absa/pramen/core/utils/JdbcSparkUtils.scala
index bf08b7e57..92a91ae54 100644
--- a/pramen/core/src/main/scala/za/co/absa/pramen/core/utils/JdbcSparkUtils.scala
+++ b/pramen/core/src/main/scala/za/co/absa/pramen/core/utils/JdbcSparkUtils.scala
@@ -276,13 +276,15 @@ object JdbcSparkUtils {
         case t: DecimalType if t.scale == 0 && t.precision <= 18 =>
           log.info(s"Correct '${field.name}' (prec=${t.precision}, scale=${t.scale}) to long")
           newSchema += s"${field.name} long"
-        case t: DecimalType if t.scale >= 18 =>
+        case t: DecimalType if t.scale > 18 =>
           log.info(s"Correct '${field.name}' (prec=${t.precision}, scale=${t.scale}) to decimal(38, 18)")
           newSchema += s"${field.name} decimal(38, 18)"
         case t: DecimalType if fixPrecision && t.scale > 0 =>
           val fixedPrecision = if (t.precision + t.scale > 38) 38 else t.precision + t.scale
-          log.info(s"Correct '${field.name}' (prec=${t.precision}, scale=${t.scale}) to decimal($fixedPrecision, ${t.scale})")
-          newSchema += s"${field.name} decimal($fixedPrecision, ${t.scale})"
+          if (fixedPrecision > t.precision) {
+            log.info(s"Correct '${field.name}' (prec=${t.precision}, scale=${t.scale}) to decimal($fixedPrecision, ${t.scale})")
+            newSchema += s"${field.name} decimal($fixedPrecision, ${t.scale})"
+          }
         case _ =>
           field
       }
diff --git a/pramen/core/src/test/scala/za/co/absa/pramen/core/tests/utils/JdbcSparkUtilsSuite.scala b/pramen/core/src/test/scala/za/co/absa/pramen/core/tests/utils/JdbcSparkUtilsSuite.scala
index 2cfd5b7f8..386727777 100644
--- a/pramen/core/src/test/scala/za/co/absa/pramen/core/tests/utils/JdbcSparkUtilsSuite.scala
+++ b/pramen/core/src/test/scala/za/co/absa/pramen/core/tests/utils/JdbcSparkUtilsSuite.scala
@@ -228,6 +228,19 @@ class JdbcSparkUtilsSuite extends AnyWordSpec with BeforeAndAfterAll with SparkT
       assert(customFields.contains("value decimal(38, 18)"))
     }
 
+    "correct invalid precision with small scale" in {
+      val schema = StructType(Array(StructField("value", DecimalType(30, 16))))
+
+      val dfOrig: DataFrame = Seq("1234567890").toDF("value")
+        .withColumn("value", $"value".cast("decimal(38, 18)"))
+
+      val df = spark.createDataFrame(dfOrig.rdd, schema)
+
+      val customFields = JdbcSparkUtils.getCorrectedDecimalsSchema(df, fixPrecision = true)
+
+      assert(customFields.contains("value decimal(38, 16)"))
+    }
+
     "do nothing if the field is okay" in {
       val df: DataFrame = Seq("12345").toDF("value")
        .withColumn("value", $"value".cast("int"))