Skip to content

Commit

Permalink
fix: correctly deal with exceptions in TaskWorkers and maintain state…
Browse files Browse the repository at this point in the history
… correctly.
  • Loading branch information
provos committed Sep 1, 2024
1 parent 7344aad commit 79c7bd0
Show file tree
Hide file tree
Showing 2 changed files with 106 additions and 23 deletions.
53 changes: 32 additions & 21 deletions src/planai/dispatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from collections import defaultdict, deque
from queue import Empty, Queue
from threading import Event, Lock
from typing import TYPE_CHECKING, DefaultDict, Dict, List, Tuple
from typing import TYPE_CHECKING, DefaultDict, Deque, Dict, List, Tuple

from .task import TaskWorker, TaskWorkItem
from .web_interface import is_quit_requested, run_web_interface
Expand Down Expand Up @@ -46,8 +46,13 @@ def __init__(self, graph: "Graph", web_port=5000):
self.active_tasks = 0
self.task_completion_event = threading.Event()
self.web_port = web_port
self.debug_active_tasks: Dict[str, Tuple[TaskWorker, TaskWorkItem]] = {}
self.completed_tasks: deque = deque(maxlen=100) # Keep last 100 completed tasks
self.debug_active_tasks: Dict[int, Tuple[TaskWorker, TaskWorkItem]] = {}
self.completed_tasks: Deque[Tuple[TaskWorker, TaskWorkItem]] = deque(
maxlen=100
) # Keep last 100 completed tasks
self.failed_tasks: Deque[Tuple[TaskWorker, TaskWorkItem]] = deque(
maxlen=100
) # Keep last 100 failed tasks
self.total_completed_tasks = 0
self.task_id_counter = 0
self.task_lock = threading.Lock()
Expand Down Expand Up @@ -133,15 +138,14 @@ def _execute_task(self, worker: TaskWorker, task: TaskWorkItem):

try:
worker._pre_consume_work(task)
except Exception:
raise # Re-raise the caught exception
finally:
with self.task_lock:
if task_id in self.debug_active_tasks:
del self.debug_active_tasks[task_id]
self.completed_tasks.appendleft(
(task_id, worker, task)
) # may need to move this to _task_completed

def _get_next_task_id(self):
def _get_next_task_id(self) -> int:
with self.task_lock:
self.task_id_counter += 1
return self.task_id_counter
Expand Down Expand Up @@ -173,7 +177,7 @@ def get_completed_tasks(self) -> List[Dict]:
with self.task_lock:
return [
self._task_to_dict(worker, task)
for task_id, worker, task in self.completed_tasks
for worker, task in self.completed_tasks
]

def _get_task_id(self, task: TaskWorkItem) -> str:
Expand Down Expand Up @@ -210,19 +214,26 @@ def _task_completed(self, worker: TaskWorker, task: TaskWorkItem, future):
# This code will run whether the task succeeded or failed

# Determine whether we should retry the task
if not success and worker.num_retries > 0:
if task._retry_count < worker.num_retries:
task._retry_count += 1
self.work_queue.put((worker, task))
logging.info(
f"Retrying task {task.name} for the {task._retry_count} time"
)
return
else:
logging.error(
f"Task {task.name} failed after {task._retry_count} retries"
)
# we'll fall through and do the clean up
if not success:
if worker.num_retries > 0:
if task._retry_count < worker.num_retries:
task._retry_count += 1
self.active_tasks -= 1
self.work_queue.put((worker, task))
logging.info(
f"Retrying task {task.name} for the {task._retry_count} time"
)
return

with self.task_lock:
self.failed_tasks.appendleft((worker, task))
logging.error(
f"Task {task.name} failed after {task._retry_count} retries"
)
# we'll fall through and do the clean up
else:
with self.task_lock:
self.completed_tasks.appendleft((worker, task))

self._remove_provenance(task)
self.active_tasks -= 1
Expand Down
76 changes: 74 additions & 2 deletions tests/planai/test_dispatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,10 +164,12 @@ def test_dispatch(self):

def test_execute_task(self):
worker = Mock(spec=TaskWorker)
future = Mock()
task = DummyTaskWorkItem(data="test")
self.dispatcher._execute_task(worker, task)
self.dispatcher._task_completed(worker, task, future)
worker._pre_consume_work.assert_called_once_with(task)
self.assertIn(task, [t[2] for t in self.dispatcher.completed_tasks])
self.assertIn(task, [t[1] for t in self.dispatcher.completed_tasks])

def test_task_to_dict(self):
worker = DummyTaskWorkerSimple()
Expand Down Expand Up @@ -201,7 +203,7 @@ def test_get_active_tasks(self):
def test_get_completed_tasks(self):
worker = DummyTaskWorkerSimple()
task = DummyTaskWorkItem(data="test")
self.dispatcher.completed_tasks = deque([(1, worker, task)])
self.dispatcher.completed_tasks = deque([(worker, task)])
result = self.dispatcher.get_completed_tasks()
self.assertIsInstance(result, list)
self.assertEqual(len(result), 1)
Expand Down Expand Up @@ -447,6 +449,7 @@ def test_task_retry(self, mock_log_exception, mock_log_error, mock_log_info):

# Second attempt (should succeed)
self.dispatcher.work_queue.get() # Remove the task from the queue
self.dispatcher.active_tasks += 1
worker._attempt_count = 2 # Simulate successful attempt
future.result.side_effect = None # Remove the exception
self.dispatcher._task_completed(worker, task, future)
Expand Down Expand Up @@ -483,11 +486,13 @@ def test_task_retry_exhausted(

# Second attempt
self.dispatcher.work_queue.get() # Remove the task from the queue
self.dispatcher.active_tasks += 1
self.dispatcher._task_completed(worker, task, future)
self.assertEqual(task._retry_count, 2)

# Third attempt (should not retry anymore)
self.dispatcher.work_queue.get() # Remove the task from the queue
self.dispatcher.active_tasks += 1
self.dispatcher._task_completed(worker, task, future)

# Check final state
Expand All @@ -503,6 +508,73 @@ def test_task_retry_exhausted(
mock_log_info.assert_any_call("Retrying task DummyTaskWorkItem for the 1 time")
mock_log_info.assert_any_call("Retrying task DummyTaskWorkItem for the 2 time")

def test_exception_handling_end_to_end(self):
dispatcher = Dispatcher(self.graph)
self.graph._dispatcher = dispatcher

# Create an ExceptionRaisingTaskWorker
worker = ExceptionRaisingTaskWorker()
self.graph.add_workers(worker)

# Create a task
task = DummyTaskWorkItem(data="test_exception")

# Set up a thread to run the dispatcher
dispatcher_thread = threading.Thread(target=dispatcher.dispatch)
dispatcher_thread.start()

# Add the work to the dispatcher
dispatcher.add_work(worker, task)

# Wait for the task to be processed
dispatcher.wait_for_completion()

# Stop the dispatcher
dispatcher.stop()
dispatcher_thread.join(timeout=5)

# Assertions
self.assertEqual(dispatcher.work_queue.qsize(), 0, "Work queue should be empty")
self.assertEqual(dispatcher.active_tasks, 0, "No active tasks should remain")
self.assertEqual(
dispatcher.total_completed_tasks,
0,
"No tasks should be marked as completed",
)
self.assertEqual(
len(dispatcher.failed_tasks),
1,
"One task should be in the failed tasks list",
)

# Check the failed task
failed_worker, failed_task = dispatcher.failed_tasks[0]
self.assertIsInstance(failed_worker, ExceptionRaisingTaskWorker)
self.assertEqual(failed_task.data, "test_exception")

# Check that the provenance was properly removed
self.assertEqual(len(dispatcher.provenance), 0, "Provenance should be empty")

# Verify that the task is not in the active tasks list
self.assertEqual(
len(dispatcher.debug_active_tasks),
0,
"No tasks should be in the active tasks list",
)

# Check that the exception was logged
with self.assertLogs(level="ERROR") as cm:
dispatcher._task_completed(
worker,
task,
Mock(result=Mock(side_effect=ValueError("Test exception"))),
)
self.assertIn(
"Task DummyTaskWorkItem failed with exception: Test exception", cm.output[0]
)

self.graph._thread_pool.shutdown(wait=True)


class TestDispatcherConcurrent(unittest.TestCase):
def test_concurrent_execution(self):
Expand Down

0 comments on commit 79c7bd0

Please sign in to comment.