From 704996ff2db9f52a23779508c468c54d13a32612 Mon Sep 17 00:00:00 2001 From: Heyi Date: Tue, 14 May 2024 17:29:10 +0800 Subject: [PATCH 1/3] Update script executor to support async generator in batch run --- .../promptflow/executor/_script_executor.py | 37 +++++++++++++++++-- 1 file changed, 34 insertions(+), 3 deletions(-) diff --git a/src/promptflow-core/promptflow/executor/_script_executor.py b/src/promptflow-core/promptflow/executor/_script_executor.py index 616667eccb9..fcb72e59a4f 100644 --- a/src/promptflow-core/promptflow/executor/_script_executor.py +++ b/src/promptflow-core/promptflow/executor/_script_executor.py @@ -1,3 +1,4 @@ +import asyncio import contextlib import dataclasses import importlib @@ -121,6 +122,17 @@ def exec_line( allow_generator_output: bool = False, **kwargs, ) -> LineResult: + if self._is_async: + from promptflow._utils.async_utils import async_run_allowing_running_loop + + return async_run_allowing_running_loop( + self.exec_line_async, + inputs=inputs, + index=index, + run_id=run_id, + allow_generator_output=allow_generator_output, + **kwargs, + ) run_id = run_id or str(uuid.uuid4()) inputs = self._apply_sample_inputs(inputs=inputs) inputs = apply_default_value_for_input(self._inputs_sign, inputs) @@ -309,8 +321,12 @@ async def _exec_line_async( line_run_id = run_info.run_id try: Tracer.start_tracing(line_run_id) - output = await self._func_async(**inputs) - output = self._stringify_generator_output(output) if not allow_generator_output else output + output = self._func_async(**inputs) + # Get the result of the task if it is a task + # Note that if it is an async generator, it would not be a task. + if isinstance(output, asyncio.Task): + output = await output + output = await self._stringify_generator_output_async(output) if not allow_generator_output else output traces = Tracer.end_tracing(line_run_id) output_dict = convert_eager_flow_output_to_dict(output) run_info.api_calls = traces @@ -324,6 +340,17 @@ async def _exec_line_async( run_tracker.persist_flow_run(run_info) return self._construct_line_result(output, run_info) + async def _stringify_generator_output_async(self, output): + if isinstance(output, dict): + return await super()._stringify_generator_output_async(output) + if is_dataclass(output): + kv = {field.name: getattr(output, field.name) for field in dataclasses.fields(output)} + updated_kv = await super()._stringify_generator_output_async(kv) + return dataclasses.replace(output, **updated_kv) + kv = {"output": output} + updated_kv = await super()._stringify_generator_output_async(kv) + return updated_kv["output"] + def _stringify_generator_output(self, output): if isinstance(output, dict): return super()._stringify_generator_output(output) @@ -465,6 +492,8 @@ def _parse_entry_func(self): def _initialize_function(self): func = self._parse_entry_func() + original_func = getattr(func, "__original_function", func) + # If the function is not decorated with trace, add trace for it. if not hasattr(func, "__original_function"): func = _traced(func, trace_type=TraceType.FLOW) @@ -482,10 +511,12 @@ def _initialize_function(self): func = _traced(getattr(func, "__original_function"), trace_type=TraceType.FLOW) inputs, _, _, _ = function_to_interface(func) self._inputs = {k: v.to_flow_input_definition() for k, v in inputs.items()} - if inspect.iscoroutinefunction(func): + if inspect.iscoroutinefunction(original_func) or inspect.isasyncgenfunction(original_func): + self._is_async = True self._func = async_to_sync(func) self._func_async = func else: + self._is_async = False self._func = func self._func_async = sync_to_async(func) self._func_name = self._get_func_name(func=func) From 00895c2d9b49d4b87b4130ec7968ef5d2707a6a2 Mon Sep 17 00:00:00 2001 From: Heyi Tang Date: Wed, 15 May 2024 16:42:24 +0800 Subject: [PATCH 2/3] Fix bug and add tests --- .../promptflow/executor/_script_executor.py | 6 +-- .../tests/core/e2etests/test_eager_flow.py | 38 +++++++++++++++++++ 2 files changed, 41 insertions(+), 3 deletions(-) diff --git a/src/promptflow-core/promptflow/executor/_script_executor.py b/src/promptflow-core/promptflow/executor/_script_executor.py index fcb72e59a4f..cdb38deeb77 100644 --- a/src/promptflow-core/promptflow/executor/_script_executor.py +++ b/src/promptflow-core/promptflow/executor/_script_executor.py @@ -322,9 +322,9 @@ async def _exec_line_async( try: Tracer.start_tracing(line_run_id) output = self._func_async(**inputs) - # Get the result of the task if it is a task - # Note that if it is an async generator, it would not be a task. - if isinstance(output, asyncio.Task): + # Get the result of the output if it is an awaitable. + # Note that if it is an async generator, it would not be awaitable. + if inspect.isawaitable(output): output = await output output = await self._stringify_generator_output_async(output) if not allow_generator_output else output traces = Tracer.end_tracing(line_run_id) diff --git a/src/promptflow-core/tests/core/e2etests/test_eager_flow.py b/src/promptflow-core/tests/core/e2etests/test_eager_flow.py index 8be30b0d157..ce641fb09ee 100644 --- a/src/promptflow-core/tests/core/e2etests/test_eager_flow.py +++ b/src/promptflow-core/tests/core/e2etests/test_eager_flow.py @@ -39,12 +39,30 @@ async def func_entry_async(input_str: str) -> str: return "Hello " + input_str +async def gen_func(input_str: str): + for i in range(5): + await asyncio.sleep(0.1) + yield str(i) + + +class ClassEntryGen: + async def __call__(self, input_str: str): + for i in range(5): + await asyncio.sleep(0.1) + yield str(i) + + function_entries = [ (ClassEntry(), {"input_str": "world"}, "Hello world"), (func_entry, {"input_str": "world"}, "Hello world"), (func_entry_async, {"input_str": "world"}, "Hello world"), ] +generator_entries = [ + (gen_func, {"input_str": "world"}, ["0", "1", "2", "3", "4"]), + (ClassEntryGen(), {"input_str": "world"}, ["0", "1", "2", "3", "4"]), +] + @pytest.mark.usefixtures("recording_injection", "setup_connection_provider", "dev_connections") @pytest.mark.e2etest @@ -207,6 +225,26 @@ async def test_flow_run_with_function_entry_async(self, entry, inputs, expected_ delta_desc = f"{delta_sec}s from {line_result1.run_info.end_time} to {line_result2.run_info.end_time}" msg = f"The two tasks should run concurrently, but got {delta_desc}" assert 0 <= delta_sec < 0.1, msg + + @pytest.mark.asyncio + @pytest.mark.parametrize("entry, inputs, expected_output", generator_entries) + async def test_flow_run_with_generator_entry(self, entry, inputs, expected_output): + executor = FlowExecutor.create(entry, {}) + + line_result = executor.exec_line(inputs=inputs) + assert line_result.run_info.status == Status.Completed + assert line_result.output == "".join(expected_output) # When stream=False, it should be a string + + line_result = await executor.exec_line_async(inputs=inputs) + assert line_result.run_info.status == Status.Completed + assert line_result.output == "".join(expected_output) # When stream=False, it should be a string + + line_result = await executor.exec_line_async(inputs=inputs, allow_generator_output=True) + assert line_result.run_info.status == Status.Completed + list_result = [] + async for item in line_result.output: + list_result.append(item) + assert list_result == expected_output # When stream=True, it should be an async generator def test_flow_run_with_invalid_inputs(self): # Case 1: input not found From 352cc2b636000d3910cff3fb55d718d3ebfe9659 Mon Sep 17 00:00:00 2001 From: Heyi Tang Date: Wed, 15 May 2024 17:10:46 +0800 Subject: [PATCH 3/3] Fix flake8 --- src/promptflow-core/promptflow/executor/_script_executor.py | 1 - src/promptflow-core/tests/core/e2etests/test_eager_flow.py | 4 ++-- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/src/promptflow-core/promptflow/executor/_script_executor.py b/src/promptflow-core/promptflow/executor/_script_executor.py index cdb38deeb77..ed1855eb537 100644 --- a/src/promptflow-core/promptflow/executor/_script_executor.py +++ b/src/promptflow-core/promptflow/executor/_script_executor.py @@ -1,4 +1,3 @@ -import asyncio import contextlib import dataclasses import importlib diff --git a/src/promptflow-core/tests/core/e2etests/test_eager_flow.py b/src/promptflow-core/tests/core/e2etests/test_eager_flow.py index ce641fb09ee..3e00a3fd460 100644 --- a/src/promptflow-core/tests/core/e2etests/test_eager_flow.py +++ b/src/promptflow-core/tests/core/e2etests/test_eager_flow.py @@ -225,12 +225,12 @@ async def test_flow_run_with_function_entry_async(self, entry, inputs, expected_ delta_desc = f"{delta_sec}s from {line_result1.run_info.end_time} to {line_result2.run_info.end_time}" msg = f"The two tasks should run concurrently, but got {delta_desc}" assert 0 <= delta_sec < 0.1, msg - + @pytest.mark.asyncio @pytest.mark.parametrize("entry, inputs, expected_output", generator_entries) async def test_flow_run_with_generator_entry(self, entry, inputs, expected_output): executor = FlowExecutor.create(entry, {}) - + line_result = executor.exec_line(inputs=inputs) assert line_result.run_info.status == Status.Completed assert line_result.output == "".join(expected_output) # When stream=False, it should be a string