Skip to content

Commit

Permalink
#469 Improve unit test coverage.
Browse files Browse the repository at this point in the history
  • Loading branch information
yruslan committed Aug 21, 2024
1 parent f96b3bc commit 393c4bc
Show file tree
Hide file tree
Showing 4 changed files with 132 additions and 10 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ abstract class TaskRunnerBase(conf: Config,
protected def runTask(task: Task): RunStatus = {
val started = Instant.now()
task.job.operation.killMaxExecutionTimeSeconds match {
case Some(timeout) =>
case Some(timeout) if timeout > 0 =>
@volatile var runStatus: RunStatus = null

try {
Expand All @@ -132,6 +132,9 @@ abstract class TaskRunnerBase(conf: Config,
case NonFatal(ex) =>
failTask(task, started, ex)
}
case Some(timeout) =>
log.error(s"Incorrect timeout for the task: ${task.job.name}. Should be bigger than zero, got: $timeout.")
doValidateAndRunTask(task)
case None =>
doValidateAndRunTask(task)
}
Expand Down Expand Up @@ -445,12 +448,42 @@ abstract class TaskRunnerBase(conf: Config,
val updatedResult = taskResult.copy(notificationTargetErrors = notificationTargetErrors)

logTaskResult(updatedResult, isLazy)
pipelineState.addTaskCompletion(Seq(updatedResult))
addJournalEntry(task, updatedResult, pipelineState.getState().pipelineInfo)
val wasInterrupted = isTaskInterrupted(task, taskResult)
if (wasInterrupted) {
log.warn("Skipping the interrupted exception of the killed task.")
} else {
pipelineState.addTaskCompletion(Seq(updatedResult))
addJournalEntry(task, updatedResult, pipelineState.getState().pipelineInfo)
}

updatedResult.runStatus
}

private def isTaskInterrupted(task: Task, taskResult: TaskResult): Boolean = {
val hasTimeout = task.job.operation.killMaxExecutionTimeSeconds.nonEmpty

taskResult.runStatus match {
case _: RunStatus.Failed if hasTimeout =>
val failureException = taskResult.runStatus.asInstanceOf[RunStatus.Failed].ex

failureException match {
case _: InterruptedException =>
true
case _: FatalErrorWrapper =>
failureException.getCause match {
case _: InterruptedException =>
true
case _ =>
false
}
case _ =>
false
}
case _ =>
false
}
}

private def addJournalEntry(task: Task, taskResult: TaskResult, pipelineInfo: PipelineInfo): Unit = {
val taskCompleted = TaskCompleted.fromTaskResult(task, taskResult, pipelineInfo)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ class MetastoreSpy(registeredTables: Seq[String] = Seq("table1", "table2"),
isTableAvailable: Boolean = true,
isTableEmpty: Boolean = false,
trackDays: Int = 0,
failHive: Boolean = false,
readOptions: Map[String, String] = Map.empty[String, String],
writeOptions: Map[String, String] = Map.empty[String, String]) extends Metastore {

Expand Down Expand Up @@ -86,7 +87,10 @@ class MetastoreSpy(registeredTables: Seq[String] = Seq("table1", "table2"),
schema: Option[StructType],
hiveHelper: HiveHelper,
recreate: Boolean): Unit = {
hiveCreationInvocations.append((tableName, infoDate, schema, recreate))
if (failHive) {
throw new RuntimeException("Test exception")
} else
hiveCreationInvocations.append((tableName, infoDate, schema, recreate))
}

override def getStats(tableName: String, infoDate: LocalDate): MetaTableStats = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import za.co.absa.pramen.core.OperationDefFactory
import za.co.absa.pramen.core.base.SparkTestBase
import za.co.absa.pramen.core.expr.exceptions.SyntaxErrorException
import za.co.absa.pramen.core.fixtures.TextComparisonFixture
import za.co.absa.pramen.core.metastore.model.HiveConfig
import za.co.absa.pramen.core.mocks.MetaTableFactory
import za.co.absa.pramen.core.mocks.bookkeeper.SyncBookkeeperMock
import za.co.absa.pramen.core.mocks.job.JobBaseDummy
Expand Down Expand Up @@ -212,22 +213,66 @@ class JobBaseSuite extends AnyWordSpec with SparkTestBase with TextComparisonFix
}
}

"createOrRefreshHiveTable" should {
"do nothing if Hive table is not defined" in {
val job = getUseCase()

val warnings = job.createOrRefreshHiveTable(null, infoDate, recreate = false)

assert(warnings.isEmpty)

}

"return an empty seq of warnings when the operation succeeded" in {
val job = getUseCase(hiveTable = Some("test_hive_table"))

val warnings = job.createOrRefreshHiveTable(null, infoDate, recreate = false)

assert(warnings.isEmpty)
}

"return warnings if ignore failures enabled" in {
val job = getUseCase(hiveTable = Some("test_hive_table"), hiveFailure = true, ignoreHiveFailures = true)

val warnings = job.createOrRefreshHiveTable(null, infoDate, recreate = false)

assert(warnings.nonEmpty)
assert(warnings.head == "Failed to create or update Hive table 'test_hive_table': Test exception")
}

"re-throw the exception if ignore failures disabled" in {
val job = getUseCase(hiveTable = Some("test_hive_table"), hiveFailure = true)

val ex = intercept[RuntimeException] {
job.createOrRefreshHiveTable(null, infoDate, recreate = false)
}

assert(ex.getMessage == "Test exception")
}
}

def getUseCase(tableDf: DataFrame = null,
dependencies: Seq[MetastoreDependency] = Nil,
isTableAvailable: Boolean = true,
isTableEmpty: Boolean = false,
allowParallel: Boolean = true,
warnMaxExecutionTimeSeconds: Option[Int] = None): JobBase = {
warnMaxExecutionTimeSeconds: Option[Int] = None,
hiveTable: Option[String] = None,
hiveFailure: Boolean = false,
ignoreHiveFailures: Boolean = false): JobBase = {
val operation = OperationDefFactory.getDummyOperationDef(dependencies = dependencies,
allowParallel = allowParallel,
warnMaxExecutionTimeSeconds = warnMaxExecutionTimeSeconds,
extraOptions = Map[String, String]("value" -> "7"))

val bk = new SyncBookkeeperMock

val metastore = new MetastoreSpy(tableDf = tableDf, isTableAvailable = isTableAvailable, isTableEmpty = isTableEmpty)
val metastore = new MetastoreSpy(tableDf = tableDf, isTableAvailable = isTableAvailable, isTableEmpty = isTableEmpty, failHive = hiveFailure)

val outputTable = MetaTableFactory.getDummyMetaTable(name = "test_output_table")
val outputTable = MetaTableFactory.getDummyMetaTable(name = "test_output_table",
hiveTable = hiveTable,
hiveConfig = HiveConfig.getNullConfig.copy(ignoreFailures = ignoreHiveFailures)
)

new JobBaseDummy(operation, Nil, metastore, bk, outputTable)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ import scala.concurrent.Await
import scala.concurrent.duration.Duration

class TaskRunnerBaseSuite extends AnyWordSpec with SparkTestBase with TextComparisonFixture {

import spark.implicits._

private val infoDate = LocalDate.of(2022, 2, 18)
Expand Down Expand Up @@ -145,6 +146,43 @@ class TaskRunnerBaseSuite extends AnyWordSpec with SparkTestBase with TextCompar
assert(journalEntries.head.status == "Failed")
}

"run a job that is failing with timeout" in {
val now = Instant.now()

val runFunction: () => RunResult = () => {
Thread.sleep(2000)
null
}

val (runner, _, journal, state, tasks) = getUseCase(runFunction = runFunction,
isRerun = true,
allowParallel = false,
timeoutTask = true)

val taskPreDefs = Seq(core.pipeline.TaskPreDef(infoDate, TaskRunReason.New))

val fut = runner.runJobTasks(tasks.head.job, taskPreDefs)

Await.result(fut, Duration.Inf)

val result = state.completedStatuses

val job = tasks.head.job.asInstanceOf[JobSpy]

assert(job.validateCount == 1)
assert(job.runCount == 1)
assert(job.postProcessingCount == 0)
assert(job.saveCount == 0)
assert(job.createHiveTableCount == 0)
assert(result.length == 1)
assert(result.head.runStatus.isInstanceOf[Failed])

val journalEntries = journal.getEntries(now, now.plusSeconds(30))

assert(journalEntries.length == 1)
assert(journalEntries.head.status == "Failed")
}

"run a single lazy job" in {
val now = Instant.now()
val notificationTarget = new NotificationTargetSpy(ConfigFactory.empty(), (action: TaskResult) => ())
Expand Down Expand Up @@ -581,7 +619,8 @@ class TaskRunnerBaseSuite extends AnyWordSpec with SparkTestBase with TextCompar
bookkeeperIn: Bookkeeper = null,
allowParallel: Boolean = true,
hiveTable: Option[String] = None,
jobNotificationTargets: Seq[JobNotificationTarget] = Nil
jobNotificationTargets: Seq[JobNotificationTarget] = Nil,
timeoutTask: Boolean = false
): (TaskRunnerBase, Bookkeeper, Journal, PipelineStateSpy, Seq[Task]) = {
val conf = ConfigFactory.empty()

Expand All @@ -590,12 +629,13 @@ class TaskRunnerBaseSuite extends AnyWordSpec with SparkTestBase with TextCompar
val bookkeeper = if (bookkeeperIn == null) new SyncBookkeeperMock else bookkeeperIn
val journal = new JournalMock
val tokenLockFactory = new TokenLockFactoryMock

val state = new PipelineStateSpy
val killTimer = if (timeoutTask) Some(1) else None

val operationDef = OperationDefFactory.getDummyOperationDef(
schemaTransformations = List(TransformExpression("c", Some("cast(b as string)"), None)),
filters = List("b > 1")
filters = List("b > 1"),
killMaxExecutionTimeSeconds = killTimer
)

val stats = MetaTableStats(2, Some(100))
Expand Down

0 comments on commit 393c4bc

Please sign in to comment.