diff --git a/src/planai/dispatcher.py b/src/planai/dispatcher.py index 8af81cd..8f45e66 100644 --- a/src/planai/dispatcher.py +++ b/src/planai/dispatcher.py @@ -17,7 +17,16 @@ from collections import defaultdict, deque from queue import Empty, Queue from threading import Event, Lock -from typing import TYPE_CHECKING, DefaultDict, Deque, Dict, List, Tuple +from typing import ( + TYPE_CHECKING, + DefaultDict, + Deque, + Dict, + Generator, + List, + Optional, + Tuple, +) from .task import Task, TaskWorker from .web_interface import is_quit_requested, run_web_interface @@ -58,17 +67,20 @@ def __init__(self, graph: "Graph", web_port=5000): self.task_id_counter = 0 self.task_lock = threading.Lock() + def _generate_prefixes(self, task: Task) -> Generator[Tuple, None, None]: + provenance = task._provenance + for i in range(1, len(provenance) + 1): + yield tuple(provenance[:i]) + def _add_provenance(self, task: Task): - with self.provenance_lock: - for i in range(1, len(task._provenance) + 1): - prefix = tuple(task._provenance[:i]) + for prefix in self._generate_prefixes(task): + with self.provenance_lock: self.provenance[prefix] = self.provenance.get(prefix, 0) + 1 def _remove_provenance(self, task: Task): to_notify = set() - with self.provenance_lock: - for i in range(1, len(task._provenance) + 1): - prefix = tuple(task._provenance[:i]) + for prefix in self._generate_prefixes(task): + with self.provenance_lock: self.provenance[prefix] -= 1 if self.provenance[prefix] == 0: del self.provenance[prefix] @@ -77,14 +89,60 @@ def _remove_provenance(self, task: Task): for prefix in to_notify: self._notify_task_completion(prefix, task) - def watch(self, prefix: ProvenanceChain, notifier: TaskWorker) -> bool: + def watch( + self, prefix: ProvenanceChain, notifier: TaskWorker, task: Optional[Task] = None + ) -> bool: + """ + Watches the given prefix and notifies the specified notifier when the prefix is no longer tracked + as part of the provenance of all tasks. 
+ + This method sets up a watch on a specific prefix in the provenance chain. When the prefix is + no longer part of any task's provenance, the provided notifier will be called with the prefix + as an argument. If the prefix is already absent from every task's provenance and a task is + given, the completion notification is triggered immediately. + + Parameters: + ----------- + prefix : ProvenanceChain + The prefix to watch. Must be a tuple representing a part of a task's provenance chain. + + notifier : TaskWorker + The object to be notified when the watched prefix is no longer in use. + Its notify method will be called with the watched prefix as an argument. + + task : Optional[Task], default=None + The task associated with this watch operation if it was called from consume_work. + + Returns: + -------- + bool + True if the notifier was successfully added to the watch list for the given prefix. + False if the notifier was already in the watch list for this prefix. + + Raises: + ------- + ValueError + If the provided prefix is not a tuple. + """ if not isinstance(prefix, tuple): raise ValueError("Prefix must be a tuple") + + added = False with self.notifiers_lock: if notifier not in self.notifiers[prefix]: self.notifiers[prefix].append(notifier) - return True - return False + added = True + + if task is not None: + should_notify = False + with self.provenance_lock: + if self.provenance.get(prefix, 0) == 0: + should_notify = True + + if should_notify: + self._notify_task_completion(prefix, task) + + return added def unwatch(self, prefix: ProvenanceChain, notifier: TaskWorker) -> bool: if not isinstance(prefix, tuple): @@ -266,6 +324,15 @@ def add_work(self, worker: TaskWorker, task: Task): self._add_provenance(task) self.work_queue.put((worker, task)) + def add_multiple_work(self, work_items: List[Tuple[TaskWorker, Task]]): + # the ordering of adding provenance first is important for join tasks to + # work correctly. Otherwise, caching may lead to fast execution of tasks + # before all the provenance is added. 
+ for worker, task in work_items: + self._add_provenance(task) + for item in work_items: + self.work_queue.put(item) + def stop(self): self.stop_event.set() diff --git a/src/planai/task.py b/src/planai/task.py index bd4753c..8575888 100644 --- a/src/planai/task.py +++ b/src/planai/task.py @@ -84,7 +84,7 @@ def previous_input_task(self): return self._input_provenance[-1] if self._input_provenance else None def prefix_for_input_task( - self, task_class: Type["TaskWorker"] + self, task_class: Type["Task"] ) -> Optional["ProvenanceChain"]: """ Finds the provenance chain for the most recent input task of the specified class. @@ -132,6 +132,7 @@ class TaskWorker(BaseModel, ABC): _graph: Optional["Graph"] = PrivateAttr(default=None) _last_input_task: Optional[Task] = PrivateAttr(default=None) _instance_id: uuid.UUID = PrivateAttr(default_factory=uuid.uuid4) + _local: threading.local = PrivateAttr(default_factory=threading.local) def __hash__(self): return hash(self._instance_id) @@ -141,6 +142,13 @@ def __eq__(self, other): return self._instance_id == other._instance_id return False + def _init_work_buffer(self): + if not hasattr(self._local, "work_buffer"): + self._local.work_buffer = [] + + def _clear_work_buffer(self): + self._local.work_buffer = [] + @property def name(self) -> str: """ @@ -185,19 +193,38 @@ def next(self, downstream: "TaskWorker"): self._graph.set_dependency(self, downstream) return downstream - def watch(self, prefix: "ProvenanceChain") -> bool: + def watch(self, prefix: "ProvenanceChain", task: Optional[Task] = None) -> bool: """ - Watches for this task provenance to be completed in the graph. + Watches for the completion of a specific provenance chain prefix in the task graph. + + This method sets up a watch on a given prefix in the provenance chain. It will be notified + in its notify method when this prefix is no longer part of any active task's provenance, indicating + that all tasks with this prefix have been completed. 
Parameters: - worker (Type["Task"]): The worker to watch. + ----------- + prefix : ProvenanceChain + The prefix to watch. Must be a tuple representing a part of a task's provenance chain. + This is the sequence of task identifiers leading up to (but not including) the current task. + + task : Optional[Task], default=None + The task associated with this watch operation. It is optional and is passed through + unchanged to the dispatcher's watch call together with the prefix and this worker. Returns: - True if the watch was added, False if the watch was already present. + -------- + bool + True if the watch was successfully added for the given prefix. + False if a watch for this prefix was already present. + + Raises: + ------- + ValueError + If the provided prefix is not a tuple. """ if not isinstance(prefix, tuple): raise ValueError("Prefix must be a tuple") - return self._graph._dispatcher.watch(prefix, self) + return self._graph._dispatcher.watch(prefix, self, task) def unwatch(self, prefix: "ProvenanceChain") -> bool: """ @@ -216,7 +243,9 @@ def unwatch(self, prefix: "ProvenanceChain") -> bool: def _pre_consume_work(self, task: Task): with self._state_lock: self._last_input_task = task + self._init_work_buffer() self.consume_work(task) + self.flush_work_buffer() def init(self): """ @@ -282,10 +311,18 @@ def publish_work(self, task: Task, input_task: Optional[Task]): consumer.name, task.__class__.__name__, ) + + self._init_work_buffer() + self._local.work_buffer.append((consumer, task)) + + def flush_work_buffer(self): + self._init_work_buffer() if self._graph and self._graph._dispatcher: - self._graph._dispatcher.add_work(consumer, task) + self._graph._dispatcher.add_multiple_work(self._local.work_buffer) else: - self._dispatch_work(task) + for consumer, task in self._local.work_buffer: + self._dispatch_work(task) + self._clear_work_buffer() def completed(self): """Called to let the worker know that it has finished processing all work.""" diff --git 
a/tests/planai/test_task.py b/tests/planai/test_task.py index 2455813..a9eaaf1 100644 --- a/tests/planai/test_task.py +++ b/tests/planai/test_task.py @@ -99,11 +99,11 @@ def test_watch(self): graph._dispatcher = mock_dispatcher task = DummyTask() task._provenance = [("DummyTask", 1)] - result = self.worker.watch(task.prefix_for_input_task(DummyTask)) - self.assertIsNotNone(task.prefix_for_input_task(DummyTask)) - mock_dispatcher.watch.assert_called_once_with( - task.prefix_for_input_task(DummyTask), self.worker - ) + prefix = task.prefix_for_input_task(DummyTask) + self.assertIsNotNone(prefix) + result = self.worker.watch(prefix) + self.assertIsNotNone(prefix) + mock_dispatcher.watch.assert_called_once_with(prefix, self.worker, None) self.assertEqual(result, mock_dispatcher.watch.return_value) def test_unwatch(self): @@ -141,12 +141,13 @@ def test_publish_work(self): self.worker.register_consumer(DummyTask, self.worker) self.worker.publish_work(task, input_task) + self.worker.flush_work_buffer() self.assertEqual(len(task._provenance), 1) self.assertEqual(task._provenance[0][0], self.worker.name) self.assertEqual(len(task._input_provenance), 1) self.assertIs(task._input_provenance[0], input_task) - mock_dispatcher.add_work.assert_called_once_with(self.worker, task) + mock_dispatcher.add_multiple_work.assert_called_once_with([(self.worker, task)]) def test_publish_work_invalid_type(self): class InvalidTask(Task):