Skip to content

Commit

Permalink
refactor: remove optional task parameter from watch methods in Graph and ProvenanceTracker
Browse files (browse the repository at this point in the history)
  • Loading branch information
provos committed Jan 28, 2025
1 parent e139f97 commit a0a63b0
Show file tree
Hide file tree
Showing 5 changed files with 45 additions and 33 deletions.
6 changes: 2 additions & 4 deletions src/planai/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,10 +88,8 @@ def __init__(self, **data):
def trace(self, prefix: ProvenanceChain):
    """Forward a trace request for ``prefix`` to this graph's provenance tracker."""
    tracker = self._provenance_tracker
    tracker.trace(prefix)

def watch(
self, prefix: ProvenanceChain, notifier: TaskWorker, task: Optional[Task] = None
) -> bool:
return self._provenance_tracker.watch(prefix, notifier, task)
def watch(self, prefix: ProvenanceChain, notifier: TaskWorker) -> bool:
return self._provenance_tracker.watch(prefix, notifier)

def unwatch(self, prefix: ProvenanceChain, notifier: TaskWorker) -> bool:
    """Forward an unwatch request to the provenance tracker and return its result."""
    tracker = self._provenance_tracker
    outcome = tracker.unwatch(prefix, notifier)
    return outcome
Expand Down
29 changes: 7 additions & 22 deletions src/planai/provenance.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,11 +173,9 @@ def get_prefix_by_type(
ProvenanceChain: A tuple containing the provenance chain elements corresponding to the specified worker type.
None: If the worker type is not found in the task's provenance chain.
"""
# XXX - This should always be true - but I fear it might not be
assert len(task._provenance) == len(task.input_provenance)
for index, entry in enumerate(task.input_provenance):
for index, entry in enumerate(task._provenance):
if entry[0] == worker_type.__name__:
return tuple(task.input_provenance[: index + 1])
return tuple(task._provenance[: index + 1])
return None

def _add_provenance(self, task: Task):
Expand Down Expand Up @@ -262,8 +260,9 @@ def _remove_provenance(self, task: Task, worker: TaskWorker):

if final_notify:
logging.info(
"Re-adding provenance for %s - as we need to wait for the notification to complete before completely removing it",
"Re-adding provenance for %s (%s) - as we need to wait for the notification to complete before completely removing it",
task.name,
task._provenance,
)
self._add_provenance(task)
self._notify_task_completion(final_notify, worker, task)
Expand All @@ -280,7 +279,7 @@ def _notify_task_completion(
self,
to_notify: List[Tuple[TaskWorker, ProvenanceChain]],
worker: TaskWorker,
task: Optional[Task],
task: Task,
):
if not to_notify:
if task is not None:
Expand Down Expand Up @@ -309,7 +308,7 @@ def _notify_task_completion(

# Use a named function instead of a lambda to avoid closure issues
def task_completed_callback(
future, worker=notifier, to_notify=sorted_to_notify, task=task
future, worker: TaskWorker = notifier, task: Task = task
):
dispatcher: Dispatcher = worker._graph._dispatcher
dispatcher._task_completed(worker, None, future)
Expand Down Expand Up @@ -340,9 +339,7 @@ def get_traces(self) -> Dict:
with self.provenance_lock:
return self.provenance_trace

def watch(
self, prefix: ProvenanceChain, notifier: TaskWorker, task: Optional[Task] = None
) -> bool:
def watch(self, prefix: ProvenanceChain, notifier: TaskWorker) -> bool:
"""
Watches the given prefix and notifies the specified notifier when the prefix is no longer tracked
as part of the provenance of all tasks.
Expand All @@ -361,9 +358,6 @@ def watch(
The object to be notified when the watched prefix is no longer in use.
Its notify method will be called with the watched prefix as an argument.
task : Task
The task associated with this watch operation if it was called from consume_work.
Returns:
--------
bool
Expand All @@ -384,15 +378,6 @@ def watch(
self.notifiers[prefix].append(notifier)
added = True

if task is not None:
should_notify = False
with self.provenance_lock:
if self.provenance.get(prefix, 0) == 0:
should_notify = True

if should_notify:
self._notify_task_completion([(notifier, prefix)], notifier, None)

return added

def unwatch(self, prefix: ProvenanceChain, notifier: TaskWorker) -> bool:
Expand Down
8 changes: 2 additions & 6 deletions src/planai/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -366,7 +366,7 @@ def trace(self, prefix: "ProvenanceChain"):
"""
self._graph.trace(prefix)

def watch(self, prefix: "ProvenanceChain", task: Optional[Task] = None) -> bool:
def watch(self, prefix: "ProvenanceChain") -> bool:
"""
Watches for the completion of a specific provenance chain prefix in the task graph.
Expand All @@ -380,10 +380,6 @@ def watch(self, prefix: "ProvenanceChain", task: Optional[Task] = None) -> bool:
The prefix to watch. Must be a tuple representing a part of a task's provenance chain.
This is the sequence of task identifiers leading up to (but not including) the current task.
task : Optional[Task], default=None
The task associated with this watch operation. This parameter is optional and may be
used for additional context or functionality in the underlying implementation.
Returns:
--------
bool
Expand All @@ -397,7 +393,7 @@ def watch(self, prefix: "ProvenanceChain", task: Optional[Task] = None) -> bool:
"""
if not isinstance(prefix, tuple):
raise ValueError("Prefix must be a tuple")
return self._graph.watch(prefix, self, task)
return self._graph.watch(prefix, self)

def unwatch(self, prefix: "ProvenanceChain") -> bool:
"""
Expand Down
33 changes: 33 additions & 0 deletions tests/planai/test_provenance.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,39 @@ def test_callback_execution_on_provenance_removal(self):
mock_metadata, (("Task1", 1),), None, None, "Task removed"
)

def test_get_prefix_by_type(self):
    """Check that get_prefix_by_type returns the chain up to and including the requested worker."""

    # Worker types whose class names match the entries placed in the chain below
    class Worker1(TaskWorker):
        def consume_work(self, task):
            pass

    class Worker2(TaskWorker):
        def consume_work(self, task):
            pass

    class Worker3(TaskWorker):
        def consume_work(self, task):
            pass

    # A task whose provenance chain spans three workers
    task = DummyTask(data="test")
    task._provenance = [("Worker1", 1), ("Worker2", 2), ("Worker3", 3)]
    task._input_provenance = []

    lookup = self.provenance_tracker.get_prefix_by_type

    # Each worker type yields the prefix ending at its own entry
    self.assertEqual(lookup(task, Worker1), (("Worker1", 1),))
    self.assertEqual(lookup(task, Worker2), (("Worker1", 1), ("Worker2", 2)))
    self.assertEqual(
        lookup(task, Worker3),
        (("Worker1", 1), ("Worker2", 2), ("Worker3", 3)),
    )

    # A worker type that is not part of the chain yields None
    self.assertIsNone(lookup(task, DummyTaskWorkerSimple))


# Run this module's tests through unittest when the file is executed directly as a script.
if __name__ == "__main__":
unittest.main()
2 changes: 1 addition & 1 deletion tests/planai/test_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,7 @@ def test_watch(self):
self.assertIsNotNone(prefix)
result = self.worker.watch(prefix)
self.assertIsNotNone(prefix)
graph.watch.assert_called_once_with(prefix, self.worker, None)
graph.watch.assert_called_once_with(prefix, self.worker)
self.assertEqual(result, graph.watch.return_value)

def test_unwatch(self):
Expand Down

0 comments on commit a0a63b0

Please sign in to comment.