Skip to content

Commit

Permalink
fix: multiple bugs with JoinedTaskWorker where it would not be correc…
Browse files Browse the repository at this point in the history
…tly called

1. If the provenance prefix was alrady removed, notify() was not called
2. Race conditions could lead to not all published work to be received by the JoinedTaskWorker.
  • Loading branch information
provos committed Sep 2, 2024
1 parent bf15499 commit 2d6edbb
Show file tree
Hide file tree
Showing 3 changed files with 129 additions and 24 deletions.
87 changes: 77 additions & 10 deletions src/planai/dispatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,16 @@
from collections import defaultdict, deque
from queue import Empty, Queue
from threading import Event, Lock
from typing import TYPE_CHECKING, DefaultDict, Deque, Dict, List, Tuple
from typing import (
TYPE_CHECKING,
DefaultDict,
Deque,
Dict,
Generator,
List,
Optional,
Tuple,
)

from .task import Task, TaskWorker
from .web_interface import is_quit_requested, run_web_interface
Expand Down Expand Up @@ -58,17 +67,20 @@ def __init__(self, graph: "Graph", web_port=5000):
self.task_id_counter = 0
self.task_lock = threading.Lock()

def _generate_prefixes(self, task: Task) -> Generator[Tuple, None, None]:
provenance = task._provenance
for i in range(1, len(provenance) + 1):
yield tuple(provenance[:i])

def _add_provenance(self, task: Task):
with self.provenance_lock:
for i in range(1, len(task._provenance) + 1):
prefix = tuple(task._provenance[:i])
for prefix in self._generate_prefixes(task):
with self.provenance_lock:
self.provenance[prefix] = self.provenance.get(prefix, 0) + 1

def _remove_provenance(self, task: Task):
to_notify = set()
with self.provenance_lock:
for i in range(1, len(task._provenance) + 1):
prefix = tuple(task._provenance[:i])
for prefix in self._generate_prefixes(task):
with self.provenance_lock:
self.provenance[prefix] -= 1
if self.provenance[prefix] == 0:
del self.provenance[prefix]
Expand All @@ -77,14 +89,60 @@ def _remove_provenance(self, task: Task):
for prefix in to_notify:
self._notify_task_completion(prefix, task)

def watch(self, prefix: ProvenanceChain, notifier: TaskWorker) -> bool:
def watch(
self, prefix: ProvenanceChain, notifier: TaskWorker, task: Optional[Task] = None
) -> bool:
"""
Watches the given prefix and notifies the specified notifier when the prefix is no longer tracked
as part of the provenance of all tasks.
This method sets up a watch on a specific prefix in the provenance chain. When the prefix is
no longer part of any task's provenance, the provided notifier will be called with the prefix
as an argument. If the prefix is already not part of any task's provenance, the notifier may
be called immediately.
Parameters:
-----------
prefix : ProvenanceChain
The prefix to watch. Must be a tuple representing a part of a task's provenance chain.
notifier : TaskWorker
The object to be notified when the watched prefix is no longer in use.
Its notify method will be called with the watched prefix as an argument.
task : Task
The task associated with this watch operation if it was called from consume_work.
Returns:
--------
bool
True if the notifier was successfully added to the watch list for the given prefix.
False if the notifier was already in the watch list for this prefix.
Raises:
-------
ValueError
If the provided prefix is not a tuple.
"""
if not isinstance(prefix, tuple):
raise ValueError("Prefix must be a tuple")

added = False
with self.notifiers_lock:
if notifier not in self.notifiers[prefix]:
self.notifiers[prefix].append(notifier)
return True
return False
added = True

if task is not None:
should_notify = False
with self.provenance_lock:
if self.provenance.get(prefix, 0) == 0:
should_notify = True

if should_notify:
self._notify_task_completion(prefix, task)

return added

def unwatch(self, prefix: ProvenanceChain, notifier: TaskWorker) -> bool:
if not isinstance(prefix, tuple):
Expand Down Expand Up @@ -266,6 +324,15 @@ def add_work(self, worker: TaskWorker, task: Task):
self._add_provenance(task)
self.work_queue.put((worker, task))

def add_multiple_work(self, work_items: List[Tuple[TaskWorker, Task]]):
# the ordering of adding provenance first is important for join tasks to
# work correctly. Otherwise, caching may lead to fast execution of tasks
# before all the provenance is added.
for worker, task in work_items:
self._add_provenance(task)
for item in work_items:
self.work_queue.put(item)

def stop(self):
self.stop_event.set()

Expand Down
53 changes: 45 additions & 8 deletions src/planai/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ def previous_input_task(self):
return self._input_provenance[-1] if self._input_provenance else None

def prefix_for_input_task(
self, task_class: Type["TaskWorker"]
self, task_class: Type["Task"]
) -> Optional["ProvenanceChain"]:
"""
Finds the provenance chain for the most recent input task of the specified class.
Expand Down Expand Up @@ -132,6 +132,7 @@ class TaskWorker(BaseModel, ABC):
_graph: Optional["Graph"] = PrivateAttr(default=None)
_last_input_task: Optional[Task] = PrivateAttr(default=None)
_instance_id: uuid.UUID = PrivateAttr(default_factory=uuid.uuid4)
_local: threading.local = PrivateAttr(default_factory=threading.local)

def __hash__(self):
return hash(self._instance_id)
Expand All @@ -141,6 +142,13 @@ def __eq__(self, other):
return self._instance_id == other._instance_id
return False

def _init_work_buffer(self):
if not hasattr(self._local, "work_buffer"):
self._local.work_buffer = []

def _clear_work_buffer(self):
self._local.work_buffer = []

@property
def name(self) -> str:
"""
Expand Down Expand Up @@ -185,19 +193,38 @@ def next(self, downstream: "TaskWorker"):
self._graph.set_dependency(self, downstream)
return downstream

def watch(self, prefix: "ProvenanceChain") -> bool:
def watch(self, prefix: "ProvenanceChain", task: Optional[Task] = None) -> bool:
"""
Watches for this task provenance to be completed in the graph.
Watches for the completion of a specific provenance chain prefix in the task graph.
This method sets up a watch on a given prefix in the provenance chain. It will be notified
in its notify method when this prefix is no longer part of any active task's provenance, indicating
that all tasks with this prefix have been completed.
Parameters:
worker (Type["Task"]): The worker to watch.
-----------
prefix : ProvenanceChain
The prefix to watch. Must be a tuple representing a part of a task's provenance chain.
This is the sequence of task identifiers leading up to (but not including) the current task.
task : Optional[Task], default=None
The task associated with this watch operation. This parameter is optional and may be
used for additional context or functionality in the underlying implementation.
Returns:
True if the watch was added, False if the watch was already present.
--------
bool
True if the watch was successfully added for the given prefix.
False if a watch for this prefix was already present.
Raises:
-------
ValueError
If the provided prefix is not a tuple.
"""
if not isinstance(prefix, tuple):
raise ValueError("Prefix must be a tuple")
return self._graph._dispatcher.watch(prefix, self)
return self._graph._dispatcher.watch(prefix, self, task)

def unwatch(self, prefix: "ProvenanceChain") -> bool:
"""
Expand All @@ -216,7 +243,9 @@ def unwatch(self, prefix: "ProvenanceChain") -> bool:
def _pre_consume_work(self, task: Task):
with self._state_lock:
self._last_input_task = task
self._init_work_buffer()
self.consume_work(task)
self.flush_work_buffer()

def init(self):
"""
Expand Down Expand Up @@ -282,10 +311,18 @@ def publish_work(self, task: Task, input_task: Optional[Task]):
consumer.name,
task.__class__.__name__,
)

self._init_work_buffer()
self._local.work_buffer.append((consumer, task))

def flush_work_buffer(self):
self._init_work_buffer()
if self._graph and self._graph._dispatcher:
self._graph._dispatcher.add_work(consumer, task)
self._graph._dispatcher.add_multiple_work(self._local.work_buffer)
else:
self._dispatch_work(task)
for consumer, task in self._local.work_buffer:
self._dispatch_work(task)
self._clear_work_buffer()

def completed(self):
"""Called to let the worker know that it has finished processing all work."""
Expand Down
13 changes: 7 additions & 6 deletions tests/planai/test_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,11 +99,11 @@ def test_watch(self):
graph._dispatcher = mock_dispatcher
task = DummyTask()
task._provenance = [("DummyTask", 1)]
result = self.worker.watch(task.prefix_for_input_task(DummyTask))
self.assertIsNotNone(task.prefix_for_input_task(DummyTask))
mock_dispatcher.watch.assert_called_once_with(
task.prefix_for_input_task(DummyTask), self.worker
)
prefix = task.prefix_for_input_task(DummyTask)
self.assertIsNotNone(prefix)
result = self.worker.watch(prefix)
self.assertIsNotNone(prefix)
mock_dispatcher.watch.assert_called_once_with(prefix, self.worker, None)
self.assertEqual(result, mock_dispatcher.watch.return_value)

def test_unwatch(self):
Expand Down Expand Up @@ -141,12 +141,13 @@ def test_publish_work(self):
self.worker.register_consumer(DummyTask, self.worker)

self.worker.publish_work(task, input_task)
self.worker.flush_work_buffer()

self.assertEqual(len(task._provenance), 1)
self.assertEqual(task._provenance[0][0], self.worker.name)
self.assertEqual(len(task._input_provenance), 1)
self.assertIs(task._input_provenance[0], input_task)
mock_dispatcher.add_work.assert_called_once_with(self.worker, task)
mock_dispatcher.add_multiple_work.assert_called_once_with([(self.worker, task)])

def test_publish_work_invalid_type(self):
class InvalidTask(Task):
Expand Down

0 comments on commit 2d6edbb

Please sign in to comment.