#534 Allow specifying the explicit number of partitions for metastore tables.
yruslan committed Jan 10, 2025
1 parent 46dbb9d commit d4a2e3d
Showing 23 changed files with 198 additions and 102 deletions.
DataFormat.scala

@@ -28,7 +28,7 @@ sealed trait DataFormat {
}

object DataFormat {
case class Parquet(path: String, recordsPerPartition: Option[Long]) extends DataFormat {
case class Parquet(path: String, partitionInfo: PartitionInfo = PartitionInfo.Default) extends DataFormat {
override def name: String = "parquet"

override val isTransient: Boolean = false
@@ -38,7 +38,7 @@ object DataFormat {
override val isRaw: Boolean = false
}

case class Delta(query: Query, recordsPerPartition: Option[Long]) extends DataFormat {
case class Delta(query: Query, partitionInfo: PartitionInfo = PartitionInfo.Default) extends DataFormat {
override def name: String = "delta"

override val isTransient: Boolean = false
PartitionInfo.scala
@@ -0,0 +1,28 @@
/*
* Copyright 2022 ABSA Group Limited
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package za.co.absa.pramen.api

trait PartitionInfo

object PartitionInfo {
case object Default extends PartitionInfo

case class Explicit(numberOfPartitions: Int) extends PartitionInfo

case class PerRecordCount(recordsPerPartition: Long) extends PartitionInfo

}
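
For reference, a minimal sketch of how the three variants introduced here can be constructed (the values are illustrative, not defaults from this commit):

import za.co.absa.pramen.api.PartitionInfo

// Keep whatever partitioning the DataFrame already has.
val default: PartitionInfo = PartitionInfo.Default

// Write exactly 10 output partitions.
val explicit: PartitionInfo = PartitionInfo.Explicit(10)

// Target roughly 500,000 records per output partition.
val perRecord: PartitionInfo = PartitionInfo.PerRecordCount(500000L)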
MetastoreImpl.scala

@@ -287,7 +287,6 @@ object MetastoreImpl {
private val log = LoggerFactory.getLogger(this.getClass)

val METASTORE_KEY = "pramen.metastore.tables"
val DEFAULT_RECORDS_PER_PARTITION = 500000

def fromConfig(conf: Config,
runtimeConfig: RuntimeConfig,
DataFormatParser.scala

@@ -17,7 +17,7 @@
package za.co.absa.pramen.core.metastore.model

import com.typesafe.config.Config
import za.co.absa.pramen.api.{CachePolicy, DataFormat, Query}
import za.co.absa.pramen.api.{CachePolicy, DataFormat, PartitionInfo, Query}
import za.co.absa.pramen.core.utils.ConfigUtils

object DataFormatParser {
@@ -30,6 +30,7 @@ object DataFormatParser {
val FORMAT_KEY = "format"
val PATH_KEY = "path"
val TABLE_KEY = "table"
val NUMBER_OF_PARTITIONS_KEY = "number.of.partitions"
val RECORDS_PER_PARTITION_KEY = "records.per.partition"
val CACHE_POLICY_KEY = "cache.policy"
val DEFAULT_FORMAT = "parquet"
@@ -44,14 +45,12 @@
format match {
case FORMAT_PARQUET =>
val path = getPath(conf)
val recordsPerPartition = ConfigUtils.getOptionLong(conf, RECORDS_PER_PARTITION_KEY)
.orElse(defaultRecordsPerPartition)
DataFormat.Parquet(path, recordsPerPartition)
val partitionInfo = getPartitionInfo(conf, defaultRecordsPerPartition)
DataFormat.Parquet(path, partitionInfo)
case FORMAT_DELTA =>
val query = getQuery(conf)
val recordsPerPartition = ConfigUtils.getOptionLong(conf, RECORDS_PER_PARTITION_KEY)
.orElse(defaultRecordsPerPartition)
DataFormat.Delta(query, recordsPerPartition)
val partitionInfo = getPartitionInfo(conf, defaultRecordsPerPartition)
DataFormat.Delta(query, partitionInfo)
case FORMAT_RAW =>
if (!conf.hasPath(PATH_KEY)) throw new IllegalArgumentException(s"Mandatory option for a metastore table having 'raw' format: $PATH_KEY")
val path = Query.Path(conf.getString(PATH_KEY)).path
@@ -66,6 +65,26 @@ }
}
}

private[core] def getPartitionInfo(conf: Config, defaultRecordsPerPartition: Option[Long]): PartitionInfo = {
val numberOfPartitionsOpt = ConfigUtils.getOptionInt(conf, NUMBER_OF_PARTITIONS_KEY)
val recordsPerPartitionOpt = ConfigUtils.getOptionLong(conf, RECORDS_PER_PARTITION_KEY)

(numberOfPartitionsOpt, recordsPerPartitionOpt) match {
case (Some(_), Some(_)) =>
throw new IllegalArgumentException(
s"Both '$NUMBER_OF_PARTITIONS_KEY' and '$RECORDS_PER_PARTITION_KEY' are specified. Please specify only one of those options.")
case (Some(nop), None) =>
PartitionInfo.Explicit(nop)
case (None, Some(rpp)) =>
PartitionInfo.PerRecordCount(rpp)
case (None, None) =>
defaultRecordsPerPartition match {
case Some(rpp) => PartitionInfo.PerRecordCount(rpp)
case None => PartitionInfo.Default
}
}
}

private[core] def getCachePolicy(conf: Config): Option[CachePolicy] = {
if (conf.hasPath(CACHE_POLICY_KEY)) {
conf.getString(CACHE_POLICY_KEY).trim.toLowerCase match {
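To illustrate the parsing logic above, a short sketch of how a metastore table config resolves (this mirrors the test suite shown further below):

import com.typesafe.config.ConfigFactory
import za.co.absa.pramen.api.{DataFormat, PartitionInfo}
import za.co.absa.pramen.core.metastore.model.DataFormatParser

val conf = ConfigFactory.parseString(
  """format = parquet
    |path = /a/b/c
    |number.of.partitions = 10
    |""".stripMargin)

// Resolves to DataFormat.Parquet("/a/b/c", PartitionInfo.Explicit(10));
// with records.per.partition instead, it would be PerRecordCount.
val format = DataFormatParser.fromConfig(conf, conf)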
MetastorePersistence.scala

@@ -47,13 +47,13 @@ object MetastorePersistence {
val saveModeOpt = saveModeOverride.orElse(metaTable.saveModeOpt)

metaTable.format match {
case DataFormat.Parquet(path, recordsPerPartition) =>
case DataFormat.Parquet(path, partitionInfo) =>
new MetastorePersistenceParquet(
path, metaTable.infoDateColumn, metaTable.infoDateFormat, metaTable.batchIdColumn, batchId, recordsPerPartition, saveModeOpt, metaTable.readOptions, metaTable.writeOptions
path, metaTable.infoDateColumn, metaTable.infoDateFormat, metaTable.batchIdColumn, batchId, partitionInfo, saveModeOpt, metaTable.readOptions, metaTable.writeOptions
)
case DataFormat.Delta(query, recordsPerPartition) =>
case DataFormat.Delta(query, partitionInfo) =>
new MetastorePersistenceDelta(
query, metaTable.infoDateColumn, metaTable.infoDateFormat, metaTable.batchIdColumn, batchId, metaTable.partitionByInfoDate, recordsPerPartition, saveModeOpt, metaTable.readOptions, metaTable.writeOptions
query, metaTable.infoDateColumn, metaTable.infoDateFormat, metaTable.batchIdColumn, batchId, metaTable.partitionByInfoDate, partitionInfo, saveModeOpt, metaTable.readOptions, metaTable.writeOptions
)
case DataFormat.Raw(path) =>
new MetastorePersistenceRaw(path, metaTable.infoDateColumn, metaTable.infoDateFormat, saveModeOpt)
MetastorePersistenceDelta.scala

@@ -21,9 +21,10 @@ import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.DateType
import org.apache.spark.sql.{Column, DataFrame, SaveMode, SparkSession}
import org.slf4j.LoggerFactory
import za.co.absa.pramen.api.Query
import za.co.absa.pramen.api.{PartitionInfo, Query}
import za.co.absa.pramen.core.metastore.MetaTableStats
import za.co.absa.pramen.core.metastore.model.HiveConfig
import za.co.absa.pramen.core.metastore.peristence.MetastorePersistenceParquet.applyPartitioning
import za.co.absa.pramen.core.utils.Emoji.SUCCESS
import za.co.absa.pramen.core.utils.hive.QueryExecutor
import za.co.absa.pramen.core.utils.{FsUtils, StringUtils}
@@ -38,7 +39,7 @@ class MetastorePersistenceDelta(query: Query,
batchIdColumn: String,
batchId: Long,
partitionByInfoDate: Boolean,
recordsPerPartition: Option[Long],
partitionInfo: PartitionInfo,
saveModeOpt: Option[SaveMode],
readOptions: Map[String, String],
writeOptions: Map[String, String]
@@ -67,13 +68,8 @@

val whereCondition = s"$infoDateColumn='$infoDateStr'"

val dfRepartitioned = if (partitionByInfoDate && recordsPerPartition.nonEmpty) {
val recordCount = numberOfRecordsEstimate match {
case Some(count) => count
case None => df.count()
}

applyRepartitioning(df, recordCount)
val dfRepartitioned = if (partitionByInfoDate) {
applyPartitioning(df, partitionInfo, numberOfRecordsEstimate)
} else {
df
}
@@ -206,15 +202,6 @@ }
}
}

def applyRepartitioning(dfIn: DataFrame, recordCount: Long): DataFrame = {
recordsPerPartition match {
case None => dfIn
case Some(rpp) =>
val numPartitions = Math.max(1, Math.ceil(recordCount.toDouble / rpp)).toInt
dfIn.repartition(numPartitions)
}
}

private def getPartitionPath(infoDate: LocalDate): Path = {
val partition = s"$infoDateColumn=${dateFormatter.format(infoDate)}"
new Path(query.query, partition)
MetastorePersistenceParquet.scala

@@ -20,6 +20,7 @@ import org.apache.hadoop.fs.Path
import org.apache.spark.sql.functions._
import org.apache.spark.sql.{Column, DataFrame, SaveMode, SparkSession}
import org.slf4j.LoggerFactory
import za.co.absa.pramen.api.PartitionInfo
import za.co.absa.pramen.core.config.Keys
import za.co.absa.pramen.core.metastore.MetaTableStats
import za.co.absa.pramen.core.metastore.model.HiveConfig
@@ -38,12 +39,14 @@ class MetastorePersistenceParquet(path: String,
infoDateFormat: String,
batchIdColumn: String,
batchId: Long,
recordsPerPartition: Option[Long],
partitionInfo: PartitionInfo,
saveModeOpt: Option[SaveMode],
readOptions: Map[String, String],
writeOptions: Map[String, String]
)(implicit spark: SparkSession) extends MetastorePersistence {

import MetastorePersistenceParquet._

private val log = LoggerFactory.getLogger(this.getClass)
private val dateFormatter = DateTimeFormatter.ofPattern(infoDateFormat)

@@ -83,16 +86,7 @@
df
}

val dfRepartitioned = if (recordsPerPartition.nonEmpty) {
val recordCount = numberOfRecordsEstimate match {
case Some(count) => count
case None => dfIn.count()
}

applyRepartitioning(dfIn, recordCount)
} else {
dfIn
}
val dfRepartitioned = applyPartitioning(dfIn, partitionInfo, numberOfRecordsEstimate)

writeAndCleanOnFailure(dfRepartitioned, outputDirStr, fsUtils, saveMode)

@@ -191,15 +185,6 @@
}
}

def applyRepartitioning(dfIn: DataFrame, recordCount: Long): DataFrame = {
recordsPerPartition match {
case None => dfIn
case Some(rpp) =>
val numPartitions = Math.max(1, Math.ceil(recordCount.toDouble / rpp)).toInt
dfIn.repartition(numPartitions)
}
}

private[core] def writeAndCleanOnFailure(df: DataFrame,
outputDirStr: String,
fsUtils: FsUtils,
@@ -231,3 +216,17 @@
}
}
}

object MetastorePersistenceParquet {
def applyPartitioning(dfIn: DataFrame, partitionInfo: PartitionInfo, recordCountEstimate: Option[Long]): DataFrame = {
partitionInfo match {
case PartitionInfo.Default => dfIn
case PartitionInfo.Explicit(nop) =>
dfIn.coalesce(nop)
case PartitionInfo.PerRecordCount(rpp) =>
val recordCount = recordCountEstimate.getOrElse(dfIn.count())
val numPartitions = Math.max(1, Math.ceil(recordCount.toDouble / rpp)).toInt
dfIn.repartition(numPartitions)
}
}
}
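
To make the PerRecordCount arithmetic concrete, a worked example with illustrative numbers:

// PerRecordCount(500000) with 1,200,000 records to write:
val numPartitions = Math.max(1, Math.ceil(1200000.toDouble / 500000)).toInt  // 3
// so the writer calls df.repartition(3). Explicit(n) uses coalesce(n) instead,
// which only merges existing partitions and avoids a full shuffle.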
PythonTransformationJob.scala

@@ -19,13 +19,12 @@ package za.co.absa.pramen.core.pipeline
import com.typesafe.config.Config
import org.apache.spark.sql.{AnalysisException, DataFrame, SparkSession}
import za.co.absa.pramen.api.status.{DependencyWarning, JobType, TaskRunReason}
import za.co.absa.pramen.api.{DataFormat, Reason}
import za.co.absa.pramen.api.{DataFormat, PartitionInfo, Reason}
import za.co.absa.pramen.core.app.config.GeneralConfig.TEMPORARY_DIRECTORY_KEY
import za.co.absa.pramen.core.bookkeeper.Bookkeeper
import za.co.absa.pramen.core.databricks.{DatabricksClient, PramenPyJobTemplate}
import za.co.absa.pramen.core.exceptions.ProcessFailedException
import za.co.absa.pramen.core.metastore.Metastore
import za.co.absa.pramen.core.metastore.MetastoreImpl.DEFAULT_RECORDS_PER_PARTITION
import za.co.absa.pramen.core.metastore.model.MetaTable
import za.co.absa.pramen.core.pipeline.PythonTransformationJob._
import za.co.absa.pramen.core.process.ProcessRunner
@@ -250,8 +249,8 @@ class PythonTransformationJob(operationDef: OperationDef,
def getTable(mt: MetaTable): String = {
val description = if (mt.description.isEmpty) "" else s"\n description: ${escapeString(mt.description)}"
val recordsPerPartition = mt.format match {
case f: DataFormat.Parquet => s"\n records_per_partition: ${f.recordsPerPartition.getOrElse(DEFAULT_RECORDS_PER_PARTITION)}"
case f: DataFormat.Delta => s"\n records_per_partition: ${f.recordsPerPartition.getOrElse(DEFAULT_RECORDS_PER_PARTITION)}"
case f: DataFormat.Parquet => getPartitionYaml(f.partitionInfo)
case f: DataFormat.Delta => getPartitionYaml(f.partitionInfo)
case _ => ""
}

@@ -282,6 +281,17 @@
sb.toString
}

private[core] def getPartitionYaml(partitionInfo: PartitionInfo): String = {
partitionInfo match {
case PartitionInfo.Default =>
""
case PartitionInfo.Explicit(npp) =>
s"\n number_of_partitions: $npp"
case PartitionInfo.PerRecordCount(rpp) =>
s"\n records_per_partition: $rpp"
}
}

private[core] def getTemporaryPathForYamlConfig(conf: Config) = {
val temporaryDirectoryBase = if (conf.hasPath(TEMPORARY_DIRECTORY_KEY) && conf.getString(TEMPORARY_DIRECTORY_KEY).nonEmpty) {
conf.getString(TEMPORARY_DIRECTORY_KEY).stripSuffix("/")
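For illustration, the strings getPartitionYaml produces for each variant (whitespace abbreviated; the exact indentation comes from the surrounding YAML template):

getPartitionYaml(PartitionInfo.Default)               // ""  (key omitted entirely)
getPartitionYaml(PartitionInfo.Explicit(10))          // "\n  number_of_partitions: 10"
getPartitionYaml(PartitionInfo.PerRecordCount(100L))  // "\n  records_per_partition: 100"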
DataFormatSuite.scala

@@ -19,7 +19,7 @@ package za.co.absa.pramen.core.metastore.model
import com.typesafe.config.ConfigFactory
import org.scalatest.wordspec.AnyWordSpec
import za.co.absa.pramen.api.DataFormat._
import za.co.absa.pramen.api.{CachePolicy, Query}
import za.co.absa.pramen.api.{CachePolicy, PartitionInfo, Query}
import za.co.absa.pramen.core.metastore.model.DataFormatParser.{PATH_KEY, TABLE_KEY}

class DataFormatSuite extends AnyWordSpec {
@@ -33,10 +33,10 @@ class DataFormatSuite extends AnyWordSpec {
assert(!format.isTransient)
assert(format.isInstanceOf[Parquet])
assert(format.asInstanceOf[Parquet].path == "/a/b/c")
assert(format.asInstanceOf[Parquet].recordsPerPartition.isEmpty)
assert(format.asInstanceOf[Parquet].partitionInfo == PartitionInfo.Default)
}

"use 'parquet' when specified explicitly" in {
"use 'parquet' when rpp specified explicitly" in {
val conf = ConfigFactory.parseString(
"""format = parquet
|path = /a/b/c
@@ -49,10 +49,26 @@
assert(!format.isTransient)
assert(format.isInstanceOf[Parquet])
assert(format.asInstanceOf[Parquet].path == "/a/b/c")
assert(format.asInstanceOf[Parquet].recordsPerPartition.contains(100))
assert(format.asInstanceOf[Parquet].partitionInfo == PartitionInfo.PerRecordCount(100L))
}

"use 'delta' when specified explicitly" in {
"use 'parquet' when npp specified explicitly" in {
val conf = ConfigFactory.parseString(
"""format = parquet
|path = /a/b/c
|number.of.partitions = 10
|""".stripMargin)

val format = DataFormatParser.fromConfig(conf, conf)

assert(format.name == "parquet")
assert(!format.isTransient)
assert(format.isInstanceOf[Parquet])
assert(format.asInstanceOf[Parquet].path == "/a/b/c")
assert(format.asInstanceOf[Parquet].partitionInfo == PartitionInfo.Explicit(10))
}

"use 'delta' when rpp specified explicitly" in {
val conf = ConfigFactory.parseString(
"""format = delta
|path = /a/b/c
@@ -66,7 +82,24 @@
assert(format.isInstanceOf[Delta])
assert(format.asInstanceOf[Delta].query.isInstanceOf[Query.Path])
assert(format.asInstanceOf[Delta].query.query == "/a/b/c")
assert(format.asInstanceOf[Delta].recordsPerPartition.contains(200))
assert(format.asInstanceOf[Delta].partitionInfo == PartitionInfo.PerRecordCount(200L))
}

"use 'delta' when npp specified explicitly" in {
val conf = ConfigFactory.parseString(
"""format = delta
|path = /a/b/c
|number.of.partitions = 10
|""".stripMargin)

val format = DataFormatParser.fromConfig(conf, conf)

assert(format.name == "delta")
assert(!format.isTransient)
assert(format.isInstanceOf[Delta])
assert(format.asInstanceOf[Delta].query.isInstanceOf[Query.Path])
assert(format.asInstanceOf[Delta].query.query == "/a/b/c")
assert(format.asInstanceOf[Delta].partitionInfo == PartitionInfo.Explicit(10))
}

"use 'raw' when specified explicitly" in {
@@ -151,7 +184,7 @@
assert(format.isInstanceOf[Delta])
assert(format.asInstanceOf[Delta].query.isInstanceOf[Query.Path])
assert(format.asInstanceOf[Delta].query.query == "/a/b/c")
assert(format.asInstanceOf[Delta].recordsPerPartition.contains(100))
assert(format.asInstanceOf[Delta].partitionInfo == PartitionInfo.PerRecordCount(100))
}

"throw an exception on unknown format" in {
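
Not shown in the visible hunks is the conflicting-keys case; a hypothetical test matching the exception thrown by getPartitionInfo could look like this:

"throw an exception when both partition options are set" in {
  val conf = ConfigFactory.parseString(
    """format = parquet
      |path = /a/b/c
      |number.of.partitions = 10
      |records.per.partition = 100
      |""".stripMargin)

  val ex = intercept[IllegalArgumentException] {
    DataFormatParser.fromConfig(conf, conf)
  }

  assert(ex.getMessage.contains("Please specify only one of those options"))
}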