diff --git a/spark-jobs/pom.xml b/spark-jobs/pom.xml
index 00d22f50b..eee6587ea 100644
--- a/spark-jobs/pom.xml
+++ b/spark-jobs/pom.xml
@@ -71,6 +71,11 @@
             <artifactId>spark-cobol_${scala.compat.version}</artifactId>
             <version>${cobrix.version}</version>
         </dependency>
+        <dependency>
+            <groupId>za.co.absa</groupId>
+            <artifactId>spark-partition-sizing_${scala.compat.version}</artifactId>
+            <version>0.1.0</version>
+        </dependency>
diff --git a/spark-jobs/src/main/resources/reference.conf b/spark-jobs/src/main/resources/reference.conf
index 8701c3dbc..e1dbc6bb8 100644
--- a/spark-jobs/src/main/resources/reference.conf
+++ b/spark-jobs/src/main/resources/reference.conf
@@ -80,10 +80,16 @@ spline.mode=BEST_EFFORT
#
spline.producer.url="http://localhost:8085/producer"
-#Block size in bytes needed for repartition/coalesce
+#Possible values for the repartitioning strategy: plan, dataframe, sample
+
+partition.strategy="plan"
+#Block size in bytes used for repartition/coalesce; required by every strategy except recordCount
#min.processing.partition.size=31457280
#max.processing.partition.size=134217728
+#Number of rows to sample when the "sample" strategy is selected
+#partition.sample.size=100
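+#Example (illustrative values): bound output partitions between ~30MiB and ~128MiB using a 100-row sample
+#partition.strategy="sample"
+#partition.sample.size=100
+#min.processing.partition.size=31457280
+#max.processing.partition.size=134217728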
+
# Control plugins
# Several plugins can be used. In this case the last element of the key needs to be incremented for each plugin.
#standardization.plugin.control.metrics.1=za.co.absa.enceladus.KafkaPluginFactory
diff --git a/spark-jobs/src/main/scala/za/co/absa/enceladus/common/CommonJobExecution.scala b/spark-jobs/src/main/scala/za/co/absa/enceladus/common/CommonJobExecution.scala
index 67943e38a..19dd9ddf2 100644
--- a/spark-jobs/src/main/scala/za/co/absa/enceladus/common/CommonJobExecution.scala
+++ b/spark-jobs/src/main/scala/za/co/absa/enceladus/common/CommonJobExecution.scala
@@ -182,43 +182,6 @@ trait CommonJobExecution extends ProjectMetadata {
}
}
- protected def repartitionDataFrame(df: DataFrame, minBlockSize: Option[Long], maxBlockSize: Option[Long])
- (implicit spark: SparkSession): DataFrame = {
- def computeBlockCount(desiredBlockSize: Long, totalByteSize: BigInt, addRemainder: Boolean): Int = {
- val int = (totalByteSize / desiredBlockSize).toInt
- val blockCount = int + (if (addRemainder && (totalByteSize % desiredBlockSize != 0)) 1 else 0)
- blockCount max 1
- }
-
- def changePartitionCount(blockCount: Int, fnc: Int => DataFrame): DataFrame = {
- val outputDf = fnc(blockCount)
- log.info(s"Number of output partitions: ${outputDf.rdd.getNumPartitions}")
- outputDf
- }
-
- val currentPartionCount = df.rdd.getNumPartitions
-
- if (currentPartionCount > 0) {
- val catalystPlan = df.queryExecution.logical
- val sizeInBytes = spark.sessionState.executePlan(catalystPlan).optimizedPlan.stats.sizeInBytes
-
- val currentBlockSize = sizeInBytes / df.rdd.getNumPartitions
-
- (minBlockSize, maxBlockSize) match {
- case (Some(min), None) if currentBlockSize < min =>
- changePartitionCount(computeBlockCount(min, sizeInBytes, addRemainder = false), df.coalesce)
- case (None, Some(max)) if currentBlockSize > max =>
- changePartitionCount(computeBlockCount(max, sizeInBytes, addRemainder = true), df.repartition)
- case (Some(min), Some(max)) if currentBlockSize < min || currentBlockSize > max =>
- changePartitionCount(computeBlockCount(max, sizeInBytes, addRemainder = true), df.repartition)
- case _ => df
- }
- } else {
- // empty dataframe
- df
- }
- }
-
protected def finishJob[T](jobConfig: JobConfigParser[T]): Unit = {
val name = jobConfig.datasetName
val version = jobConfig.datasetVersion
diff --git a/spark-jobs/src/main/scala/za/co/absa/enceladus/common/Repartitioner.scala b/spark-jobs/src/main/scala/za/co/absa/enceladus/common/Repartitioner.scala
new file mode 100644
index 000000000..b70a98cb1
--- /dev/null
+++ b/spark-jobs/src/main/scala/za/co/absa/enceladus/common/Repartitioner.scala
@@ -0,0 +1,67 @@
+/*
+ * Copyright 2018 ABSA Group Limited
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package za.co.absa.enceladus.common
+
+import org.apache.spark.sql.DataFrame
+import org.slf4j.Logger
+import za.co.absa.enceladus.common.config.CommonConfConstants
+import za.co.absa.enceladus.utils.config.ConfigReader
+import za.co.absa.spark.partition.sizing.DataFramePartitioner.DataFrameFunctions
+import za.co.absa.spark.partition.sizing.sizer._
+import za.co.absa.spark.partition.sizing.types.DataTypeSizes
+import za.co.absa.spark.partition.sizing.types.DataTypeSizes.DefaultDataTypeSizes
+
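+/**
+ * Repartitions DataFrames via the spark-partition-sizing library, driven by the
+ * `partition.strategy` configuration key:
+ *  - "plan"      - sizes partitions from the optimized query plan statistics (repartitionByPlanSize)
+ *  - "dataframe" - estimates row sizes from the DataFrame itself (FromDataframeSizer)
+ *  - "sample"    - estimates row sizes from a sample of `partition.sample.size` rows
+ * Resulting partition sizes are bounded by `min.processing.partition.size` and
+ * `max.processing.partition.size` (bytes); an unknown or missing strategy leaves the DataFrame unchanged.
+ */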
+class Repartitioner(configReader: ConfigReader, log: Logger) {
+
+ val minPartition: Option[Long] = configReader.getLongOption(CommonConfConstants.minPartitionSizeKey)
+ val maxPartition: Option[Long] = configReader.getLongOption(CommonConfConstants.maxPartitionSizeKey)
+
+ implicit val dataTypeSizes: DataTypeSizes = DefaultDataTypeSizes
+
+ def repartition(df: DataFrame): DataFrame = {
+ val partitionStrategy = configReader.getStringOption(CommonConfConstants.partitionStrategy)
+ if (minPartition.isEmpty && maxPartition.isEmpty) {
+ log.warn(s"No partitioning applied doe to missing: ${CommonConfConstants.minPartitionSizeKey}, " +
+ s"${CommonConfConstants.minPartitionSizeKey} keys")
+ }
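+    // An unknown or unset strategy leaves the DataFrame unchanged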
+ partitionStrategy match {
+ case Some("plan") => df.repartitionByPlanSize(minPartition, maxPartition)
+ case Some("dataframe") => repartitionByDf(df)
+ case Some("sample") => repartitionBySample(df)
+ case _ => df
+ }
+ }
+
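+  /** Repartitions using row sizes estimated from a `partition.sample.size`-row sample;
+    * warns and returns the input unchanged when the key is not configured. */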
+  private def repartitionBySample(df: DataFrame): DataFrame = {
+    val maybeSampleSize = configReader.getIntOption(CommonConfConstants.partitionSampleSizeKey)
+    maybeSampleSize match {
+      case None =>
+        log.warn(s"No repartitioning applied due to missing ${CommonConfConstants.partitionSampleSizeKey} key")
+        df
+      case Some(sampleSize) =>
+        val sizer = new FromDataframeSampleSizer(sampleSize)
+        df.repartitionByDesiredSize(sizer)(minPartition, maxPartition)
+    }
+  }
+
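+  /** Repartitions using row sizes estimated from the DataFrame itself via FromDataframeSizer. */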
+ private def repartitionByDf(df: DataFrame): DataFrame = {
+ val sizer = new FromDataframeSizer()
+ df.repartitionByDesiredSize(sizer)(minPartition, maxPartition)
+ }
+
+}
diff --git a/spark-jobs/src/main/scala/za/co/absa/enceladus/common/config/CommonConfConstants.scala b/spark-jobs/src/main/scala/za/co/absa/enceladus/common/config/CommonConfConstants.scala
index 4bb20939f..7b007b5af 100644
--- a/spark-jobs/src/main/scala/za/co/absa/enceladus/common/config/CommonConfConstants.scala
+++ b/spark-jobs/src/main/scala/za/co/absa/enceladus/common/config/CommonConfConstants.scala
@@ -16,6 +16,10 @@
package za.co.absa.enceladus.common.config
object CommonConfConstants {
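+  // Configuration keys driving output repartitioning (see reference.conf)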
+ val partitionStrategy = "partition.strategy"
+ val maxRecordsPerPartitionKey = "max.record.partition.count"
+ val partitionSampleSizeKey = "partition.sample.size"
+
val minPartitionSizeKey = "min.processing.partition.size"
val maxPartitionSizeKey = "max.processing.partition.size"
}
diff --git a/spark-jobs/src/main/scala/za/co/absa/enceladus/conformance/ConformanceExecution.scala b/spark-jobs/src/main/scala/za/co/absa/enceladus/conformance/ConformanceExecution.scala
index be5c448b1..47eb4666c 100644
--- a/spark-jobs/src/main/scala/za/co/absa/enceladus/conformance/ConformanceExecution.scala
+++ b/spark-jobs/src/main/scala/za/co/absa/enceladus/conformance/ConformanceExecution.scala
@@ -25,7 +25,7 @@ import za.co.absa.atum.core.Atum
import za.co.absa.enceladus.common.RecordIdGeneration._
import za.co.absa.enceladus.common.config.{CommonConfConstants, JobConfigParser, PathConfig}
import za.co.absa.enceladus.common.plugin.menas.MenasPlugin
-import za.co.absa.enceladus.common.{CommonJobExecution, Constants, RecordIdGeneration}
+import za.co.absa.enceladus.common.{CommonJobExecution, Constants, RecordIdGeneration, Repartitioner}
import za.co.absa.enceladus.conformance.config.{ConformanceConfig, ConformanceConfigParser}
import za.co.absa.enceladus.conformance.interpreter.rules.ValidationException
import za.co.absa.enceladus.conformance.interpreter.{DynamicInterpreter, FeatureSwitches}
@@ -154,9 +154,8 @@ trait ConformanceExecution extends CommonJobExecution {
handleEmptyOutput(SourcePhase.Conformance)
}
- val minBlockSize = configReader.getLongOption(CommonConfConstants.minPartitionSizeKey)
- val maxBlockSize = configReader.getLongOption(CommonConfConstants.maxPartitionSizeKey)
- val withRepartitioning = repartitionDataFrame(withPartCols, minBlockSize, maxBlockSize)
+ val repartitioner = new Repartitioner(configReader, log)
+ val withRepartitioning = repartitioner.repartition(withPartCols)
withRepartitioning.write.parquet(preparationResult.pathCfg.publish.path)
diff --git a/spark-jobs/src/main/scala/za/co/absa/enceladus/standardization/StandardizationExecution.scala b/spark-jobs/src/main/scala/za/co/absa/enceladus/standardization/StandardizationExecution.scala
index 9e7615064..297ce9a04 100644
--- a/spark-jobs/src/main/scala/za/co/absa/enceladus/standardization/StandardizationExecution.scala
+++ b/spark-jobs/src/main/scala/za/co/absa/enceladus/standardization/StandardizationExecution.scala
@@ -25,9 +25,9 @@ import za.co.absa.atum.AtumImplicits._
import za.co.absa.atum.core.Atum
import za.co.absa.enceladus.utils.schema.SchemaUtils
import za.co.absa.enceladus.common.RecordIdGeneration.getRecordIdGenerationStrategyFromConfig
-import za.co.absa.enceladus.common.config.{CommonConfConstants, JobConfigParser, PathConfig}
+import za.co.absa.enceladus.common.config.{JobConfigParser, PathConfig}
import za.co.absa.enceladus.common.plugin.menas.MenasPlugin
-import za.co.absa.enceladus.common.{CommonJobExecution, Constants}
+import za.co.absa.enceladus.common.{CommonJobExecution, Constants, Repartitioner}
import za.co.absa.enceladus.dao.MenasDAO
import za.co.absa.enceladus.dao.auth.MenasCredentials
import za.co.absa.enceladus.model.Dataset
@@ -198,11 +198,9 @@ trait StandardizationExecution extends CommonJobExecution {
log.info(s"Writing into standardized path ${preparationResult.pathCfg.standardization.path}")
- val minPartitionSize = configReader.getLongOption(CommonConfConstants.minPartitionSizeKey)
- val maxPartitionSize = configReader.getLongOption(CommonConfConstants.maxPartitionSizeKey)
-
val withRepartitioning = if (cmd.isInstanceOf[StandardizationConfig]) {
- repartitionDataFrame(standardizedDF, minPartitionSize, maxPartitionSize)
+ val repartitioner = new Repartitioner(configReader, log)
+ repartitioner.repartition(standardizedDF)
} else {
standardizedDF
}
diff --git a/spark-jobs/src/test/scala/za/co/absa/enceladus/common/CommonExecutionSuite.scala b/spark-jobs/src/test/scala/za/co/absa/enceladus/common/CommonExecutionSuite.scala
index fe59bbc31..009dba888 100644
--- a/spark-jobs/src/test/scala/za/co/absa/enceladus/common/CommonExecutionSuite.scala
+++ b/spark-jobs/src/test/scala/za/co/absa/enceladus/common/CommonExecutionSuite.scala
@@ -16,7 +16,6 @@
package za.co.absa.enceladus.common
import org.apache.spark.sql.types.{StringType, StructType}
-import org.apache.spark.sql.{DataFrame, SparkSession}
import org.mockito.Mockito
import org.mockito.scalatest.MockitoSugar
import org.scalatest.flatspec.AnyFlatSpec
@@ -27,6 +26,7 @@ import za.co.absa.enceladus.model.{Dataset, Validation}
import za.co.absa.enceladus.standardization.config.StandardizationConfig
import za.co.absa.enceladus.utils.testUtils.TZNormalizedSparkTestBase
import za.co.absa.enceladus.utils.validation.ValidationLevel
+import za.co.absa.spark.partition.sizing.DataFramePartitioner.DataFrameFunctions
class CommonExecutionSuite extends AnyFlatSpec with Matchers with TZNormalizedSparkTestBase with MockitoSugar {
@@ -35,9 +35,6 @@ class CommonExecutionSuite extends AnyFlatSpec with Matchers with TZNormalizedSp
prepareJob()
}
override protected def validatePaths(pathConfig: PathConfig): Unit = {}
- override def repartitionDataFrame(df: DataFrame, minBlockSize: Option[Long], maxBlockSize: Option[Long])
- (implicit spark: SparkSession): DataFrame =
- super.repartitionDataFrame(df, minBlockSize, maxBlockSize)
}
Seq(
@@ -70,7 +67,7 @@ class CommonExecutionSuite extends AnyFlatSpec with Matchers with TZNormalizedSp
val df = spark.read.schema(schema).parquet("src/test/resources/data/empty")
df.rdd.getNumPartitions shouldBe 0 // ensure there are 0 partitions for the test
val commonJob = new CommonJobExecutionTest
- val result = commonJob.repartitionDataFrame(df, Option(1), Option(2))
+ val result = df.repartitionByPlanSize(Option(1), Option(2))
result shouldBe df
}