Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

#284 Fix lock/semaphore/channel resource release on fatal errors. #285

Merged
merged 1 commit into from
Nov 8, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,6 @@ class IngestionJob(operationDef: OperationDef,
extends JobBase(operationDef, metastore, bookkeeper, notificationTargets, outputTable) {
import JobBase._

private val log = LoggerFactory.getLogger(this.getClass)

override val scheduleStrategy: ScheduleStrategy = new ScheduleStrategySourcing

override def preRunCheckJob(infoDate: LocalDate, jobConfig: Config, dependencyWarnings: Seq[DependencyWarning]): JobPreRunResult = {
Expand Down Expand Up @@ -153,13 +151,17 @@ class IngestionJob(operationDef: OperationDef,
inputRecordCount: Option[Long]): SaveResult = {
val stats = metastore.saveTable(outputTable.name, infoDate, df, inputRecordCount)

source.postProcess(
sourceTable.query,
outputTable.name,
metastore.getMetastoreReader(Seq(outputTable.name), infoDate),
infoDate,
operationDef.extraOptions
)
try {
source.postProcess(
sourceTable.query,
outputTable.name,
metastore.getMetastoreReader(Seq(outputTable.name), infoDate),
infoDate,
operationDef.extraOptions
)
} catch {
case _: AbstractMethodError => log.warn(s"Sources were built using old version of Pramen that does not support post processing. Ignoring...")
}

source.close()

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ package za.co.absa.pramen.core.pipeline

import com.typesafe.config.Config
import org.apache.spark.sql.types.StructType
import org.slf4j.LoggerFactory
import org.slf4j.{Logger, LoggerFactory}
import za.co.absa.pramen.core.bookkeeper.Bookkeeper
import za.co.absa.pramen.core.expr.DateExprEvaluator
import za.co.absa.pramen.core.metastore.Metastore
Expand All @@ -35,7 +35,7 @@ abstract class JobBase(operationDef: OperationDef,
jobNotificationTargets: Seq[JobNotificationTarget],
outputTableDef: MetaTable
) extends Job {
private val log = LoggerFactory.getLogger(this.getClass)
protected val log: Logger = LoggerFactory.getLogger(this.getClass)

override val name: String = operationDef.name

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,8 +67,6 @@ class PythonTransformationJob(operationDef: OperationDef,
(implicit spark: SparkSession)
extends JobBase(operationDef, metastore, bookkeeper,notificationTargets, outputTable) {

private val log = LoggerFactory.getLogger(this.getClass)

private val minimumRecords: Int = operationDef.extraOptions.getOrElse(MINIMUM_RECORDS_OPTION, "0").toInt

override val scheduleStrategy: ScheduleStrategy = new ScheduleStrategySourcing
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,6 @@ class SinkJob(operationDef: OperationDef,
extends JobBase(operationDef, metastore, bookkeeper, notificationTargets, outputTable) {
import JobBase._

private val log = LoggerFactory.getLogger(this.getClass)

private val inputTables = operationDef.dependencies.flatMap(_.tables).distinct

override val scheduleStrategy: ScheduleStrategy = new ScheduleStrategySourcing
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,11 +64,15 @@ class TransformationJob(operationDef: OperationDef,
inputRecordCount: Option[Long]): SaveResult = {
val saveResults = SaveResult(metastore.saveTable(outputTable.name, infoDate, df, None))

transformer.postProcess(
outputTable.name,
metastore.getMetastoreReader(inputTables :+ outputTable.name, infoDate),
infoDate, operationDef.extraOptions
)
try {
transformer.postProcess(
outputTable.name,
metastore.getMetastoreReader(inputTables :+ outputTable.name, infoDate),
infoDate, operationDef.extraOptions
)
} catch {
case _: AbstractMethodError => log.warn(s"Transformers were built using old version of Pramen that does not support post processing. Ignoring...")
}

saveResults
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import za.co.absa.pramen.core.pipeline.Job
import za.co.absa.pramen.core.runner.jobrunner.ConcurrentJobRunner.JobRunResults
import za.co.absa.pramen.core.runner.splitter.ScheduleParams
import za.co.absa.pramen.core.runner.task.{RunStatus, TaskResult, TaskRunner}
import za.co.absa.pramen.core.utils.Emoji

import java.util.concurrent.ExecutorService
import java.util.concurrent.Executors.newFixedThreadPool
Expand Down Expand Up @@ -76,18 +77,23 @@ class ConcurrentJobRunnerImpl(runtimeConfig: RuntimeConfig,
}

private def workerLoop(workerNum: Int, incomingJobs: ReadChannel[Job]): Unit = {
incomingJobs.foreach(job => {
incomingJobs.foreach { job =>
val isTransient = job.outputTable.format.isInstanceOf[DataFormat.Transient]
Try {
try {
log.info(s"Worker $workerNum starting job '${job.name}' that outputs to '${job.outputTable.name}'...")
val isSucceeded = runJob(job)

completedJobsChannel.send((job, Nil, isSucceeded))
}.recover({
} catch {
case NonFatal(ex) =>
completedJobsChannel.send((job, TaskResult(job, RunStatus.Failed(ex), None, applicationId, isTransient, Nil, Nil, Nil) :: Nil, false))
})
})
case ex: Throwable =>
log.error(s"${Emoji.FAILURE} A FATAL error has been encountered.", ex)
val fatalEx = new RuntimeException(s"FATAL exception encountered, stopping the pipeline.", ex)
completedJobsChannel.send((job, TaskResult(job, RunStatus.Failed(fatalEx), None, applicationId, isTransient, Nil, Nil, Nil) :: Nil, false))
completedJobsChannel.close()
}
}
completedJobsChannel.close()
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -241,88 +241,90 @@ abstract class TaskRunnerBase(conf: Config,
val isTransient = task.job.outputTable.format.isInstanceOf[DataFormat.Transient]
val lock = lockFactory.getLock(getTokenName(task))

val attempt = Try {
if (runtimeConfig.useLocks && !lock.tryAcquire())
throw new IllegalStateException(s"Another instance is already running for ${task.job.outputTable.name} for ${task.infoDate}")
val attempt = try {
Try {
if (runtimeConfig.useLocks && !lock.tryAcquire())
throw new IllegalStateException(s"Another instance is already running for ${task.job.outputTable.name} for ${task.infoDate}")

val recordCountOldOpt = bookkeeper.getLatestDataChunk(task.job.outputTable.name, task.infoDate, task.infoDate).map(_.outputRecordCount)
val recordCountOldOpt = bookkeeper.getLatestDataChunk(task.job.outputTable.name, task.infoDate, task.infoDate).map(_.outputRecordCount)

val runResult = task.job.run(task.infoDate, conf)
val runResult = task.job.run(task.infoDate, conf)

val schemaChangesBeforeTransform = handleSchemaChange(runResult.data, task.job.outputTable, task.infoDate)
val schemaChangesBeforeTransform = handleSchemaChange(runResult.data, task.job.outputTable, task.infoDate)

val dfWithTimestamp = task.job.operation.processingTimestampColumn match {
case Some(timestampCol) => addProcessingTimestamp(runResult.data, timestampCol)
case None => runResult.data
}

val dfWithInfoDate = if (dfWithTimestamp.schema.exists(f => f.name.equals(task.job.outputTable.infoDateColumn)) || task.job.outputTable.infoDateColumn.isEmpty) {
dfWithTimestamp
} else {
dfWithTimestamp.withColumn(task.job.outputTable.infoDateColumn, lit(Date.valueOf(task.infoDate)))
}
val dfWithTimestamp = task.job.operation.processingTimestampColumn match {
case Some(timestampCol) => addProcessingTimestamp(runResult.data, timestampCol)
case None => runResult.data
}

val postProcessed = task.job.postProcessing(dfWithInfoDate, task.infoDate, conf)
val dfWithInfoDate = if (dfWithTimestamp.schema.exists(f => f.name.equals(task.job.outputTable.infoDateColumn)) || task.job.outputTable.infoDateColumn.isEmpty) {
dfWithTimestamp
} else {
dfWithTimestamp.withColumn(task.job.outputTable.infoDateColumn, lit(Date.valueOf(task.infoDate)))
}

val dfTransformed = applyFilters(
applyTransformations(postProcessed, task.job.operation.schemaTransformations),
task.job.operation.filters,
task.infoDate,
task.infoDate,
task.infoDate
)
val postProcessed = task.job.postProcessing(dfWithInfoDate, task.infoDate, conf)

val schemaChangesAfterTransform = if (task.job.operation.schemaTransformations.nonEmpty) {
val transformedTable = task.job.outputTable.copy(name = s"${task.job.outputTable.name}_transformed")
handleSchemaChange(dfTransformed, transformedTable, task.infoDate)
} else {
Nil
}
val dfTransformed = applyFilters(
applyTransformations(postProcessed, task.job.operation.schemaTransformations),
task.job.operation.filters,
task.infoDate,
task.infoDate,
task.infoDate
)

val saveResult = if (runtimeConfig.isDryRun) {
log.warn(s"$WARNING DRY RUN mode, no actual writes to ${task.job.outputTable.name} for ${task.infoDate} will be performed.")
SaveResult(MetaTableStats(dfTransformed.count(), None))
} else {
task.job.save(dfTransformed, task.infoDate, conf, started, validationResult.inputRecordsCount)
}
val schemaChangesAfterTransform = if (task.job.operation.schemaTransformations.nonEmpty) {
val transformedTable = task.job.outputTable.copy(name = s"${task.job.outputTable.name}_transformed")
handleSchemaChange(dfTransformed, transformedTable, task.infoDate)
} else {
Nil
}

val hiveWarnings = if (task.job.outputTable.hiveTable.nonEmpty) {
val recreate = schemaChangesBeforeTransform.nonEmpty || schemaChangesAfterTransform.nonEmpty || task.reason == TaskRunReason.Rerun
task.job.createOrRefreshHiveTable(dfTransformed.schema, task.infoDate, recreate)
} else {
Seq.empty
}
val saveResult = if (runtimeConfig.isDryRun) {
log.warn(s"$WARNING DRY RUN mode, no actual writes to ${task.job.outputTable.name} for ${task.infoDate} will be performed.")
SaveResult(MetaTableStats(dfTransformed.count(), None))
} else {
task.job.save(dfTransformed, task.infoDate, conf, started, validationResult.inputRecordsCount)
}

val outputMetastoreHiveTable = task.job.outputTable.hiveTable.map(table => HiveHelper.getFullTable(task.job.outputTable.hiveConfig.database, table))
val hiveTableUpdates = (saveResult.hiveTablesUpdates ++ outputMetastoreHiveTable).distinct

val stats = saveResult.stats

val finished = Instant.now()

val completionReason = if (validationResult.status == NeedsUpdate || (validationResult.status == AlreadyRan && task.reason != TaskRunReason.Rerun))
TaskRunReason.Update else task.reason

val warnings = validationResult.warnings ++ runResult.warnings ++ saveResult.warnings ++ hiveWarnings

TaskResult(task.job,
RunStatus.Succeeded(recordCountOldOpt,
stats.recordCount,
stats.dataSizeBytes,
completionReason,
runResult.filesRead,
saveResult.filesSent,
hiveTableUpdates,
warnings),
Some(RunInfo(task.infoDate, started, finished)),
applicationId,
isTransient,
schemaChangesBeforeTransform ::: schemaChangesAfterTransform,
validationResult.dependencyWarnings,
Seq.empty)
}
val hiveWarnings = if (task.job.outputTable.hiveTable.nonEmpty) {
val recreate = schemaChangesBeforeTransform.nonEmpty || schemaChangesAfterTransform.nonEmpty || task.reason == TaskRunReason.Rerun
task.job.createOrRefreshHiveTable(dfTransformed.schema, task.infoDate, recreate)
} else {
Seq.empty
}

if (runtimeConfig.useLocks) {
val outputMetastoreHiveTable = task.job.outputTable.hiveTable.map(table => HiveHelper.getFullTable(task.job.outputTable.hiveConfig.database, table))
val hiveTableUpdates = (saveResult.hiveTablesUpdates ++ outputMetastoreHiveTable).distinct

val stats = saveResult.stats

val finished = Instant.now()

val completionReason = if (validationResult.status == NeedsUpdate || (validationResult.status == AlreadyRan && task.reason != TaskRunReason.Rerun))
TaskRunReason.Update else task.reason

val warnings = validationResult.warnings ++ runResult.warnings ++ saveResult.warnings ++ hiveWarnings

TaskResult(task.job,
RunStatus.Succeeded(recordCountOldOpt,
stats.recordCount,
stats.dataSizeBytes,
completionReason,
runResult.filesRead,
saveResult.filesSent,
hiveTableUpdates,
warnings),
Some(RunInfo(task.infoDate, started, finished)),
applicationId,
isTransient,
schemaChangesBeforeTransform ::: schemaChangesAfterTransform,
validationResult.dependencyWarnings,
Seq.empty)
}
} catch {
case ex: Throwable => Failure(ex)
} finally {
lock.release()
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,12 @@ import za.co.absa.pramen.core.journal.Journal
import za.co.absa.pramen.core.lock.TokenLockFactory
import za.co.absa.pramen.core.pipeline.Task
import za.co.absa.pramen.core.state.PipelineState
import za.co.absa.pramen.core.utils.Emoji

import java.util.concurrent.Executors.newFixedThreadPool
import java.util.concurrent.{ExecutorService, Semaphore}
import scala.concurrent.ExecutionContext.fromExecutorService
import scala.concurrent.{ExecutionContextExecutorService, Future}
import scala.util.Try

/**
* The responsibility of this class is to handle the execution method.
Expand Down Expand Up @@ -59,15 +59,18 @@ class TaskRunnerMultithreaded(conf: Config,
val resourceCount = getTruncatedResourceCount(requestedCount)

availableResources.acquire(resourceCount)
val result = Try { action }
availableResources.release(resourceCount)
val result = try {
action
} finally {
availableResources.release(resourceCount)
}

result.get
result
}

private[core] def getTruncatedResourceCount(requestedCount: Int): Int = {
if (requestedCount > maxResources) {
log.warn(s"Asked for $requestedCount resources but maximum allowed is $maxResources. Truncating to $maxResources")
log.warn(s"${Emoji.WARNING} Asked for $requestedCount resources but maximum allowed is $maxResources. Truncating to $maxResources")
maxResources
} else {
requestedCount
Expand All @@ -76,7 +79,9 @@ class TaskRunnerMultithreaded(conf: Config,

override def runParallel(tasks: Seq[Task]): Seq[Future[RunStatus]] = {
tasks.map(task => Future {
log.warn(s"${Emoji.PARALLEL}The task has requested ${task.job.operation.consumeThreads} threads...")
whenEnoughResourcesAreAvailable(task.job.operation.consumeThreads) {
log.warn(s"${Emoji.PARALLEL}Running task for the table: '${task.job.outputTable.name}' for '${task.infoDate}'...")
runTask(task)
}
})
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ object Emoji {
val WARNING = "\u26A0\uFE0F"
val FAILURE = "\u274C"
val EXCLAMATION = s"\u2757"
val WRENCH = "\uD83D\uDD27"
val PARALLEL = "\u29B7"

val ROCKET = "\uD83D\uDE80"
val EMAIL1 = "\uD83D\uDCE7"
Expand Down
Loading