Skip to content

Commit

Permalink
feat: enhance InvokeTaskWorker to support JoinedTaskWorker and add co…
Browse files Browse the repository at this point in the history
…rresponding tests
  • Loading branch information
provos committed Jan 25, 2025
1 parent 0bd7ec8 commit 79f1c80
Show file tree
Hide file tree
Showing 2 changed files with 207 additions and 9 deletions.
61 changes: 52 additions & 9 deletions src/planai/testing/helpers.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
import logging
from collections import defaultdict
from threading import Lock
from typing import Dict, List, Optional, Type
from typing import Dict, List, Optional, Type, Union

from planai.graph import Graph
from planai.graph_task import SubGraphWorkerInternal
from planai.joined_task import JoinedTaskWorker
from planai.task import Task, TaskWorker


Expand Down Expand Up @@ -79,41 +80,83 @@ class InvokeTaskWorker:
>>> worker.assert_published_task_types([OutputTask])
"""

def __init__(self, worker_class: Type[TaskWorker], **kwargs):
def __init__(
self, worker_class: Union[Type[TaskWorker], Type[JoinedTaskWorker]], **kwargs
):
"""
Args:
worker_class: The TaskWorker class to test
**kwargs: Arguments to pass to the worker constructor
"""
self.context = TestTaskContext()
self.worker = worker_class(**kwargs)
self.is_joined_worker = issubclass(worker_class, JoinedTaskWorker)

def _setup_patch(self):
"""Set up patching of publish_work and reset context."""

def patched_publish_work(task: Task, input_task: Optional[Task]):
self.context.published_tasks.append(task)

original_publish_work = self.worker.publish_work
object.__setattr__(self.worker, "publish_work", patched_publish_work)
self.context.reset()

return original_publish_work

def invoke(self, input_task: Task) -> List[Task]:
"""
Invoke the worker with an input task and return published tasks.
Only valid for TaskWorker instances.
Args:
input_task: The input task to process
Returns:
List of tasks published during processing
"""

def patched_publish_work(task: Task, input_task: Optional[Task]):
self.context.published_tasks.append(task)
original_publish_work = self.worker.publish_work
object.__setattr__(self.worker, "publish_work", patched_publish_work)
Raises:
TypeError: If worker is a JoinedTaskWorker
"""
if self.is_joined_worker:
raise TypeError("Use invoke_joined() for JoinedTaskWorker instances")

self.context.reset()
original_publish_work = self._setup_patch()
self.context.current_input_task = input_task

try:
self.worker.consume_work(input_task)
finally:
object.__setattr__(self.worker, "publish_work", original_publish_work)

return self.context.published_tasks

def invoke_joined(self, input_tasks: List[Task]) -> List[Task]:
"""
Invoke the worker with multiple input tasks and return published tasks.
Only valid for JoinedTaskWorker instances.
Args:
input_tasks: The list of input tasks to process
Returns:
List of tasks published during processing
Raises:
TypeError: If worker is not a JoinedTaskWorker
"""
if not self.is_joined_worker:
raise TypeError("Use invoke() for regular TaskWorker instances")

original_publish_work = self._setup_patch()

try:
self.worker.consume_work_joined(input_tasks)
finally:
object.__setattr__(self.worker, "publish_work", original_publish_work)

return self.context.published_tasks

def assert_published_task_count(self, expected: int):
"""Assert the number of published tasks."""
actual = len(self.context.published_tasks)
Expand Down
155 changes: 155 additions & 0 deletions tests/planai/testing/test_helpers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,155 @@
import unittest
from typing import List, Type

from planai.cached_task import CachedTaskWorker
from planai.graph import Graph
from planai.joined_task import JoinedTaskWorker
from planai.task import Task, TaskWorker
from planai.testing.helpers import (
InvokeTaskWorker,
MockCache,
add_input_provenance,
inject_mock_cache,
)


# Sample Task classes for testing
class InputTask(Task):
data: str


class OutputTask(Task):
result: str


class JoinedOutputTask(Task):
results: List[str]


# Sample Worker classes for testing
class SimpleWorker(TaskWorker):
output_types: List[Type[Task]] = [OutputTask]

def consume_work(self, task: InputTask):
output = OutputTask(result=f"Processed: {task.data}")
self.publish_work(output, task)


class MultiOutputWorker(TaskWorker):
output_types: List[Type[Task]] = [OutputTask]

def consume_work(self, task: InputTask):
for i in range(3):
output = OutputTask(result=f"Batch {i}: {task.data}")
self.publish_work(output, task)


class JoinedWorker(JoinedTaskWorker):
join_type: Type[TaskWorker] = SimpleWorker
output_types: List[Type[Task]] = [JoinedOutputTask]

def consume_work_joined(
self, tasks: List[OutputTask]
): # Changed from consume_joined
results = [task.result for task in tasks]
output = JoinedOutputTask(results=results)
self.publish_work(output, tasks[0])


class TestMockCache(unittest.TestCase):
def setUp(self):
self.cache = MockCache()

def test_set_get(self):
self.cache.set("key1", "value1")
self.assertEqual(self.cache.get("key1"), "value1")
self.assertEqual(self.cache.get("nonexistent"), None)

def test_stats(self):
self.cache.set("key1", "value1")
self.cache.get("key1")
self.cache.get("key1")

self.assertEqual(self.cache.set_stats["key1"], 1)
self.assertEqual(self.cache.get_stats["key1"], 2)

def test_dont_store(self):
cache = MockCache(dont_store=True)
cache.set("key1", "value1")
self.assertIsNone(cache.get("key1"))


class TestInvokeTaskWorker(unittest.TestCase):
def test_simple_worker(self):
worker = InvokeTaskWorker(SimpleWorker)
input_task = InputTask(data="test")

published = worker.invoke(input_task)

worker.assert_published_task_count(1)
worker.assert_published_task_types([OutputTask])
self.assertEqual(published[0].result, "Processed: test")

def test_multi_output_worker(self):
worker = InvokeTaskWorker(MultiOutputWorker)
input_task = InputTask(data="test")

published = worker.invoke(input_task)

worker.assert_published_task_count(3)
worker.assert_published_task_types([OutputTask] * 3)
self.assertEqual(len(published), 3)

def test_joined_worker(self):
worker = InvokeTaskWorker(JoinedWorker)
input_tasks = [
OutputTask(result="result1"),
OutputTask(result="result2"),
OutputTask(result="result3"),
]

published = worker.invoke_joined(input_tasks)

worker.assert_published_task_count(1)
worker.assert_published_task_types([JoinedOutputTask])
self.assertEqual(published[0].results, ["result1", "result2", "result3"])

def test_wrong_invocation_method(self):
regular_worker = InvokeTaskWorker(SimpleWorker)
joined_worker = InvokeTaskWorker(JoinedWorker)

with self.assertRaises(TypeError):
regular_worker.invoke_joined([InputTask(data="test")])

with self.assertRaises(TypeError):
joined_worker.invoke(InputTask(data="test"))


class TestHelperFunctions(unittest.TestCase):
def test_add_input_provenance(self):
task1 = InputTask(data="original")
task2 = OutputTask(result="result")

result = add_input_provenance(task2, task1)
self.assertEqual(result._input_provenance, [task1])

def test_inject_mock_cache(self):
graph = Graph(name="Test Graph")

class MyWorker(CachedTaskWorker):
output_types: List[Type[Task]] = [InputTask]

def consume_work(self, task: InputTask):
pass

worker = MyWorker()
graph.add_workers(worker)

mock_cache = MockCache()
inject_mock_cache(graph, mock_cache)

self.assertIs(worker._cache, mock_cache)


if __name__ == "__main__":
unittest.main()

0 comments on commit 79f1c80

Please sign in to comment.