diff --git a/src/planai/testing/helpers.py b/src/planai/testing/helpers.py index be1024d..40c5a49 100644 --- a/src/planai/testing/helpers.py +++ b/src/planai/testing/helpers.py @@ -1,10 +1,11 @@ import logging from collections import defaultdict from threading import Lock -from typing import Dict, List, Optional, Type +from typing import Dict, List, Optional, Type, Union from planai.graph import Graph from planai.graph_task import SubGraphWorkerInternal +from planai.joined_task import JoinedTaskWorker from planai.task import Task, TaskWorker @@ -79,7 +80,9 @@ class InvokeTaskWorker: >>> worker.assert_published_task_types([OutputTask]) """ - def __init__(self, worker_class: Type[TaskWorker], **kwargs): + def __init__( + self, worker_class: Union[Type[TaskWorker], Type[JoinedTaskWorker]], **kwargs + ): """ Args: worker_class: The TaskWorker class to test @@ -87,26 +90,40 @@ def __init__(self, worker_class: Type[TaskWorker], **kwargs): """ self.context = TestTaskContext() self.worker = worker_class(**kwargs) + self.is_joined_worker = issubclass(worker_class, JoinedTaskWorker) + + def _setup_patch(self): + """Set up patching of publish_work and reset context.""" + + def patched_publish_work(task: Task, input_task: Optional[Task]): + self.context.published_tasks.append(task) + + original_publish_work = self.worker.publish_work + object.__setattr__(self.worker, "publish_work", patched_publish_work) + self.context.reset() + + return original_publish_work def invoke(self, input_task: Task) -> List[Task]: """ Invoke the worker with an input task and return published tasks. + Only valid for TaskWorker instances. Args: input_task: The input task to process Returns: List of tasks published during processing - """ - - def patched_publish_work(task: Task, input_task: Optional[Task]): - self.context.published_tasks.append(task) - original_publish_work = self.worker.publish_work - object.__setattr__(self.worker, "publish_work", patched_publish_work) + Raises: + TypeError: If worker is a JoinedTaskWorker + """ + if self.is_joined_worker: + raise TypeError("Use invoke_joined() for JoinedTaskWorker instances") - self.context.reset() + original_publish_work = self._setup_patch() self.context.current_input_task = input_task + try: self.worker.consume_work(input_task) finally: @@ -114,6 +131,32 @@ def patched_publish_work(task: Task, input_task: Optional[Task]): return self.context.published_tasks + def invoke_joined(self, input_tasks: List[Task]) -> List[Task]: + """ + Invoke the worker with multiple input tasks and return published tasks. + Only valid for JoinedTaskWorker instances. + + Args: + input_tasks: The list of input tasks to process + + Returns: + List of tasks published during processing + + Raises: + TypeError: If worker is not a JoinedTaskWorker + """ + if not self.is_joined_worker: + raise TypeError("Use invoke() for regular TaskWorker instances") + + original_publish_work = self._setup_patch() + + try: + self.worker.consume_work_joined(input_tasks) + finally: + object.__setattr__(self.worker, "publish_work", original_publish_work) + + return self.context.published_tasks + def assert_published_task_count(self, expected: int): """Assert the number of published tasks.""" actual = len(self.context.published_tasks) diff --git a/tests/planai/testing/test_helpers.py b/tests/planai/testing/test_helpers.py new file mode 100644 index 0000000..c25f4a2 --- /dev/null +++ b/tests/planai/testing/test_helpers.py @@ -0,0 +1,155 @@ +import unittest +from typing import List, Type + +from planai.cached_task import CachedTaskWorker +from planai.graph import Graph +from planai.joined_task import JoinedTaskWorker +from planai.task import Task, TaskWorker +from planai.testing.helpers import ( + InvokeTaskWorker, + MockCache, + add_input_provenance, + inject_mock_cache, +) + + +# Sample Task classes for testing +class InputTask(Task): + data: str + + +class OutputTask(Task): + result: str + + +class JoinedOutputTask(Task): + results: List[str] + + +# Sample Worker classes for testing +class SimpleWorker(TaskWorker): + output_types: List[Type[Task]] = [OutputTask] + + def consume_work(self, task: InputTask): + output = OutputTask(result=f"Processed: {task.data}") + self.publish_work(output, task) + + +class MultiOutputWorker(TaskWorker): + output_types: List[Type[Task]] = [OutputTask] + + def consume_work(self, task: InputTask): + for i in range(3): + output = OutputTask(result=f"Batch {i}: {task.data}") + self.publish_work(output, task) + + +class JoinedWorker(JoinedTaskWorker): + join_type: Type[TaskWorker] = SimpleWorker + output_types: List[Type[Task]] = [JoinedOutputTask] + + def consume_work_joined( + self, tasks: List[OutputTask] + ): # Changed from consume_joined + results = [task.result for task in tasks] + output = JoinedOutputTask(results=results) + self.publish_work(output, tasks[0]) + + +class TestMockCache(unittest.TestCase): + def setUp(self): + self.cache = MockCache() + + def test_set_get(self): + self.cache.set("key1", "value1") + self.assertEqual(self.cache.get("key1"), "value1") + self.assertEqual(self.cache.get("nonexistent"), None) + + def test_stats(self): + self.cache.set("key1", "value1") + self.cache.get("key1") + self.cache.get("key1") + + self.assertEqual(self.cache.set_stats["key1"], 1) + self.assertEqual(self.cache.get_stats["key1"], 2) + + def test_dont_store(self): + cache = MockCache(dont_store=True) + cache.set("key1", "value1") + self.assertIsNone(cache.get("key1")) + + +class TestInvokeTaskWorker(unittest.TestCase): + def test_simple_worker(self): + worker = InvokeTaskWorker(SimpleWorker) + input_task = InputTask(data="test") + + published = worker.invoke(input_task) + + worker.assert_published_task_count(1) + worker.assert_published_task_types([OutputTask]) + self.assertEqual(published[0].result, "Processed: test") + + def test_multi_output_worker(self): + worker = InvokeTaskWorker(MultiOutputWorker) + input_task = InputTask(data="test") + + published = worker.invoke(input_task) + + worker.assert_published_task_count(3) + worker.assert_published_task_types([OutputTask] * 3) + self.assertEqual(len(published), 3) + + def test_joined_worker(self): + worker = InvokeTaskWorker(JoinedWorker) + input_tasks = [ + OutputTask(result="result1"), + OutputTask(result="result2"), + OutputTask(result="result3"), + ] + + published = worker.invoke_joined(input_tasks) + + worker.assert_published_task_count(1) + worker.assert_published_task_types([JoinedOutputTask]) + self.assertEqual(published[0].results, ["result1", "result2", "result3"]) + + def test_wrong_invocation_method(self): + regular_worker = InvokeTaskWorker(SimpleWorker) + joined_worker = InvokeTaskWorker(JoinedWorker) + + with self.assertRaises(TypeError): + regular_worker.invoke_joined([InputTask(data="test")]) + + with self.assertRaises(TypeError): + joined_worker.invoke(InputTask(data="test")) + + +class TestHelperFunctions(unittest.TestCase): + def test_add_input_provenance(self): + task1 = InputTask(data="original") + task2 = OutputTask(result="result") + + result = add_input_provenance(task2, task1) + self.assertEqual(result._input_provenance, [task1]) + + def test_inject_mock_cache(self): + graph = Graph(name="Test Graph") + + class MyWorker(CachedTaskWorker): + output_types: List[Type[Task]] = [InputTask] + + def consume_work(self, task: InputTask): + pass + + worker = MyWorker() + graph.add_workers(worker) + + mock_cache = MockCache() + inject_mock_cache(graph, mock_cache) + + self.assertIs(worker._cache, mock_cache) + + +if __name__ == "__main__": + unittest.main()