diff --git a/src/promptflow-core/promptflow/executor/_script_executor.py b/src/promptflow-core/promptflow/executor/_script_executor.py index 616667eccb9..ed1855eb537 100644 --- a/src/promptflow-core/promptflow/executor/_script_executor.py +++ b/src/promptflow-core/promptflow/executor/_script_executor.py @@ -121,6 +121,17 @@ def exec_line( allow_generator_output: bool = False, **kwargs, ) -> LineResult: + if self._is_async: + from promptflow._utils.async_utils import async_run_allowing_running_loop + + return async_run_allowing_running_loop( + self.exec_line_async, + inputs=inputs, + index=index, + run_id=run_id, + allow_generator_output=allow_generator_output, + **kwargs, + ) run_id = run_id or str(uuid.uuid4()) inputs = self._apply_sample_inputs(inputs=inputs) inputs = apply_default_value_for_input(self._inputs_sign, inputs) @@ -309,8 +320,12 @@ async def _exec_line_async( line_run_id = run_info.run_id try: Tracer.start_tracing(line_run_id) - output = await self._func_async(**inputs) - output = self._stringify_generator_output(output) if not allow_generator_output else output + output = self._func_async(**inputs) + # Get the result of the output if it is an awaitable. + # Note that if it is an async generator, it would not be awaitable. + if inspect.isawaitable(output): + output = await output + output = await self._stringify_generator_output_async(output) if not allow_generator_output else output traces = Tracer.end_tracing(line_run_id) output_dict = convert_eager_flow_output_to_dict(output) run_info.api_calls = traces @@ -324,6 +339,17 @@ async def _exec_line_async( run_tracker.persist_flow_run(run_info) return self._construct_line_result(output, run_info) + async def _stringify_generator_output_async(self, output): + if isinstance(output, dict): + return await super()._stringify_generator_output_async(output) + if is_dataclass(output): + kv = {field.name: getattr(output, field.name) for field in dataclasses.fields(output)} + updated_kv = await super()._stringify_generator_output_async(kv) + return dataclasses.replace(output, **updated_kv) + kv = {"output": output} + updated_kv = await super()._stringify_generator_output_async(kv) + return updated_kv["output"] + def _stringify_generator_output(self, output): if isinstance(output, dict): return super()._stringify_generator_output(output) @@ -465,6 +491,8 @@ def _parse_entry_func(self): def _initialize_function(self): func = self._parse_entry_func() + original_func = getattr(func, "__original_function", func) + # If the function is not decorated with trace, add trace for it. if not hasattr(func, "__original_function"): func = _traced(func, trace_type=TraceType.FLOW) @@ -482,10 +510,12 @@ def _initialize_function(self): func = _traced(getattr(func, "__original_function"), trace_type=TraceType.FLOW) inputs, _, _, _ = function_to_interface(func) self._inputs = {k: v.to_flow_input_definition() for k, v in inputs.items()} - if inspect.iscoroutinefunction(func): + if inspect.iscoroutinefunction(original_func) or inspect.isasyncgenfunction(original_func): + self._is_async = True self._func = async_to_sync(func) self._func_async = func else: + self._is_async = False self._func = func self._func_async = sync_to_async(func) self._func_name = self._get_func_name(func=func) diff --git a/src/promptflow-core/tests/core/e2etests/test_eager_flow.py b/src/promptflow-core/tests/core/e2etests/test_eager_flow.py index 8be30b0d157..3e00a3fd460 100644 --- a/src/promptflow-core/tests/core/e2etests/test_eager_flow.py +++ b/src/promptflow-core/tests/core/e2etests/test_eager_flow.py @@ -39,12 +39,30 @@ async def func_entry_async(input_str: str) -> str: return "Hello " + input_str +async def gen_func(input_str: str): + for i in range(5): + await asyncio.sleep(0.1) + yield str(i) + + +class ClassEntryGen: + async def __call__(self, input_str: str): + for i in range(5): + await asyncio.sleep(0.1) + yield str(i) + + function_entries = [ (ClassEntry(), {"input_str": "world"}, "Hello world"), (func_entry, {"input_str": "world"}, "Hello world"), (func_entry_async, {"input_str": "world"}, "Hello world"), ] +generator_entries = [ + (gen_func, {"input_str": "world"}, ["0", "1", "2", "3", "4"]), + (ClassEntryGen(), {"input_str": "world"}, ["0", "1", "2", "3", "4"]), +] + @pytest.mark.usefixtures("recording_injection", "setup_connection_provider", "dev_connections") @pytest.mark.e2etest @@ -208,6 +226,26 @@ async def test_flow_run_with_function_entry_async(self, entry, inputs, expected_ msg = f"The two tasks should run concurrently, but got {delta_desc}" assert 0 <= delta_sec < 0.1, msg + @pytest.mark.asyncio + @pytest.mark.parametrize("entry, inputs, expected_output", generator_entries) + async def test_flow_run_with_generator_entry(self, entry, inputs, expected_output): + executor = FlowExecutor.create(entry, {}) + + line_result = executor.exec_line(inputs=inputs) + assert line_result.run_info.status == Status.Completed + assert line_result.output == "".join(expected_output) # When stream=False, it should be a string + + line_result = await executor.exec_line_async(inputs=inputs) + assert line_result.run_info.status == Status.Completed + assert line_result.output == "".join(expected_output) # When stream=False, it should be a string + + line_result = await executor.exec_line_async(inputs=inputs, allow_generator_output=True) + assert line_result.run_info.status == Status.Completed + list_result = [] + async for item in line_result.output: + list_result.append(item) + assert list_result == expected_output # When stream=True, it should be an async generator + def test_flow_run_with_invalid_inputs(self): # Case 1: input not found flow_file = get_yaml_file("flow_with_signature", root=EAGER_FLOW_ROOT)