Skip to content

Commit

Permalink
fix: protect active count with locks and add a concurrent stress test
Browse files Browse the repository at this point in the history
  • Loading branch information
provos committed Sep 2, 2024
1 parent 22b87a8 commit e6756d8
Show file tree
Hide file tree
Showing 2 changed files with 119 additions and 17 deletions.
35 changes: 19 additions & 16 deletions src/planai/dispatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ def __init__(self, graph: "Graph", web_port=5000):
maxlen=100
) # Keep last 100 failed tasks
self.total_completed_tasks = 0
self.total_failed_tasks = 0
self.task_id_counter = 0
self.task_lock = threading.Lock()

Expand Down Expand Up @@ -103,14 +104,16 @@ def _notify_task_completion(self, prefix: tuple):
to_notify.append((notifier, prefix))

for notifier, prefix in to_notify:
self.active_tasks += 1
with self.task_lock:
self.active_tasks += 1
future = self.graph._thread_pool.submit(notifier.notify, prefix)
future.add_done_callback(self._notify_completed)

def _dispatch_once(self) -> bool:
try:
worker, task = self.work_queue.get(timeout=1)
self.active_tasks += 1
with self.task_lock:
self.active_tasks += 1
future = self.graph._thread_pool.submit(self._execute_task, worker, task)

# Use a named function instead of a lambda to avoid closure issues
Expand All @@ -124,11 +127,7 @@ def task_completed_callback(future, worker=worker, task=task):
return False

def dispatch(self):
while (
not self.stop_event.is_set()
or not self.work_queue.empty()
or self.active_tasks > 0
):
while not self.stop_event.is_set() or not self.work_queue.empty():
self._dispatch_once()

def _execute_task(self, worker: TaskWorker, task: Task):
Expand Down Expand Up @@ -199,9 +198,10 @@ def _get_task_id(self, task: Task) -> str:
return f"unknown_{id(task)}"

def _notify_completed(self, future):
self.active_tasks -= 1
if self.active_tasks == 0 and self.work_queue.empty():
self.task_completion_event.set()
with self.task_lock:
self.active_tasks -= 1
if self.active_tasks == 0 and self.work_queue.empty():
self.task_completion_event.set()

def _task_completed(self, worker: TaskWorker, task: Task, future):
success: bool = False
Expand All @@ -212,7 +212,6 @@ def _task_completed(self, worker: TaskWorker, task: Task, future):

# Handle successful task completion
logging.info(f"Task {task.name} completed successfully")
self.total_completed_tasks += 1
success = True

except Exception as e:
Expand All @@ -230,27 +229,31 @@ def _task_completed(self, worker: TaskWorker, task: Task, future):
if worker.num_retries > 0:
if task.retry_count < worker.num_retries:
task.increment_retry_count()
self.active_tasks -= 1
self.work_queue.put((worker, task))
with self.task_lock:
self.active_tasks -= 1
self.work_queue.put((worker, task))
logging.info(
f"Retrying task {task.name} for the {task.retry_count} time"
)
return

with self.task_lock:
self.failed_tasks.appendleft((worker, task, error_message))
self.total_failed_tasks += 1
logging.error(
f"Task {task.name} failed after {task.retry_count} retries"
)
# we'll fall through and do the clean up
else:
with self.task_lock:
self.completed_tasks.appendleft((worker, task))
self.total_completed_tasks += 1

self._remove_provenance(task)
self.active_tasks -= 1
if self.active_tasks == 0 and self.work_queue.empty():
self.task_completion_event.set()
with self.task_lock:
self.active_tasks -= 1
if self.active_tasks == 0 and self.work_queue.empty():
self.task_completion_event.set()

def add_work(self, worker: TaskWorker, task: Task):
self._add_provenance(task)
Expand Down
101 changes: 100 additions & 1 deletion tests/planai/test_dispatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from typing import List, Type
from unittest.mock import Mock, patch

from pydantic import PrivateAttr
from pydantic import Field, PrivateAttr

from planai.dispatcher import Dispatcher
from planai.graph import Graph
Expand Down Expand Up @@ -74,6 +74,19 @@ def consume_work(self, task: DummyTask):
raise ValueError(f"Simulated failure (attempt {self._attempt_count})")


class SuccessFailTaskWorker(TaskWorker):
    """Test worker whose ``consume_work`` fails at random with a set probability."""

    # Probability, between 0.0 and 1.0, that one consume_work call raises.
    failure_rate: float = Field(default=0.3, ge=0.0, le=1.0)

    def __init__(self, failure_rate=0.3, **data):
        """Create the worker.

        Args:
            failure_rate: Chance that each ``consume_work`` call raises
                ``ValueError``.
            **data: Remaining keyword arguments, forwarded to ``TaskWorker``.
        """
        # Forward failure_rate to the base initializer instead of assigning it
        # afterwards, so the value is checked against the field's ge/le bounds
        # at construction time.
        super().__init__(failure_rate=failure_rate, **data)

    def consume_work(self, task: DummyTask):
        """Raise ``ValueError`` with probability ``failure_rate``, else sleep briefly.

        Args:
            task: The task being processed; its contents are not inspected.

        Raises:
            ValueError: When the random draw falls below ``failure_rate``.
        """
        if random.random() < self.failure_rate:
            raise ValueError("Simulated failure")
        time.sleep(random.uniform(0.001, 0.01))  # Simulate some work

class SingleThreadedExecutor:
def __init__(self):
self.tasks = []
Expand Down Expand Up @@ -639,6 +652,92 @@ def add_initial_work():

graph._thread_pool.shutdown(wait=False)

def test_concurrent_success_and_failures(self):
    """Stress the dispatcher with concurrent producers and randomly failing workers.

    Five workers with failure rates 0.0 through 0.4 each receive 200 tasks,
    queued from one producer thread per worker while the dispatcher runs in
    its own thread. Afterwards the failed/completed totals must be close to
    the statistically expected values, must sum to the number of tasks
    submitted, and the dispatcher must be fully drained.
    """
    graph = Graph(name="Test Graph")
    dispatcher = Dispatcher(graph)
    graph._dispatcher = dispatcher

    num_workers = 5
    tasks_per_worker = 200
    total_tasks = num_workers * tasks_per_worker

    # Worker number i fails with probability i * 0.1.
    workers = []
    for index in range(num_workers):
        workers.append(SuccessFailTaskWorker(failure_rate=index * 0.1))
    graph.add_workers(*workers)

    # Run the dispatch loop in the background.
    dispatch_thread = threading.Thread(target=dispatcher.dispatch)
    dispatch_thread.start()

    def add_work_for_worker(worker):
        """Queue ``tasks_per_worker`` dummy tasks for a single worker."""
        for _ in range(tasks_per_worker):
            dispatcher.add_work(worker, DummyTask(data=f"Task for {worker.name}"))

    # One producer thread per worker, each started as soon as it is created.
    add_work_threads = []
    for worker in workers:
        producer = threading.Thread(target=add_work_for_worker, args=(worker,))
        add_work_threads.append(producer)
        producer.start()

    # All work must be queued before waiting on the dispatcher.
    for producer in add_work_threads:
        producer.join()

    # Let every task finish, then shut the dispatch loop down.
    dispatcher.wait_for_completion()
    dispatcher.stop()
    dispatch_thread.join()

    # Expected counts derived from each worker's configured failure rate.
    expected_failures = 0
    for worker in workers:
        expected_failures += int(tasks_per_worker * worker.failure_rate)
    expected_successes = total_tasks - expected_failures

    actual_failures = dispatcher.total_failed_tasks
    actual_successes = dispatcher.total_completed_tasks

    print(
        f"Expected failures: {expected_failures}, Actual failures: {actual_failures}"
    )
    print(
        f"Expected successes: {expected_successes}, Actual successes: {actual_successes}"
    )

    # Failures are random, so allow a tolerance of 5% of all tasks.
    margin = total_tasks * 0.05
    self.assertAlmostEqual(
        actual_failures,
        expected_failures,
        delta=margin,
        msg=f"Failed tasks count is off by more than {margin}",
    )
    self.assertAlmostEqual(
        actual_successes,
        expected_successes,
        delta=margin,
        msg=f"Completed tasks count is off by more than {margin}",
    )

    # No task may be lost or double counted.
    self.assertEqual(
        actual_failures + actual_successes,
        total_tasks,
        "Total of failed and completed tasks should equal total tasks",
    )

    # The dispatcher must be completely drained.
    self.assertEqual(dispatcher.work_queue.qsize(), 0, "Work queue should be empty")
    self.assertEqual(dispatcher.active_tasks, 0, "No active tasks should remain")

    graph._thread_pool.shutdown(wait=True)


# Run the tests in this module when the file is executed directly.
if __name__ == "__main__":
    unittest.main()

0 comments on commit e6756d8

Please sign in to comment.