Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Executor] Update script executor to support async generator in batch run #3262

Merged
merged 3 commits into from
May 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 33 additions & 3 deletions src/promptflow-core/promptflow/executor/_script_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,17 @@ def exec_line(
allow_generator_output: bool = False,
**kwargs,
) -> LineResult:
if self._is_async:
from promptflow._utils.async_utils import async_run_allowing_running_loop

return async_run_allowing_running_loop(
self.exec_line_async,
inputs=inputs,
index=index,
run_id=run_id,
allow_generator_output=allow_generator_output,
**kwargs,
)
run_id = run_id or str(uuid.uuid4())
inputs = self._apply_sample_inputs(inputs=inputs)
inputs = apply_default_value_for_input(self._inputs_sign, inputs)
Expand Down Expand Up @@ -309,8 +320,12 @@ async def _exec_line_async(
line_run_id = run_info.run_id
try:
Tracer.start_tracing(line_run_id)
output = await self._func_async(**inputs)
output = self._stringify_generator_output(output) if not allow_generator_output else output
output = self._func_async(**inputs)
# Get the result of the output if it is an awaitable.
# Note that if it is an async generator, it would not be awaitable.
if inspect.isawaitable(output):
output = await output
output = await self._stringify_generator_output_async(output) if not allow_generator_output else output
traces = Tracer.end_tracing(line_run_id)
output_dict = convert_eager_flow_output_to_dict(output)
run_info.api_calls = traces
Expand All @@ -324,6 +339,17 @@ async def _exec_line_async(
run_tracker.persist_flow_run(run_info)
return self._construct_line_result(output, run_info)

async def _stringify_generator_output_async(self, output):
    """Async variant: collapse any (async) generator values in *output* into strings.

    The base-class implementation only understands dict outputs, so non-dict
    outputs are wrapped before — and unwrapped after — delegating to it.
    """
    if isinstance(output, dict):
        # Dict outputs go straight to the base implementation.
        return await super()._stringify_generator_output_async(output)
    if is_dataclass(output):
        # Stringify every field, then rebuild the dataclass from the results.
        field_map = {f.name: getattr(output, f.name) for f in dataclasses.fields(output)}
        stringified = await super()._stringify_generator_output_async(field_map)
        return dataclasses.replace(output, **stringified)
    # Plain value: wrap under a synthetic key so the base class can process it.
    wrapped = await super()._stringify_generator_output_async({"output": output})
    return wrapped["output"]

def _stringify_generator_output(self, output):
if isinstance(output, dict):
return super()._stringify_generator_output(output)
Expand Down Expand Up @@ -465,6 +491,8 @@ def _parse_entry_func(self):

def _initialize_function(self):
func = self._parse_entry_func()
original_func = getattr(func, "__original_function", func)

# If the function is not decorated with trace, add trace for it.
if not hasattr(func, "__original_function"):
func = _traced(func, trace_type=TraceType.FLOW)
Expand All @@ -482,10 +510,12 @@ def _initialize_function(self):
func = _traced(getattr(func, "__original_function"), trace_type=TraceType.FLOW)
inputs, _, _, _ = function_to_interface(func)
self._inputs = {k: v.to_flow_input_definition() for k, v in inputs.items()}
if inspect.iscoroutinefunction(func):
if inspect.iscoroutinefunction(original_func) or inspect.isasyncgenfunction(original_func):
self._is_async = True
self._func = async_to_sync(func)
self._func_async = func
else:
self._is_async = False
self._func = func
self._func_async = sync_to_async(func)
self._func_name = self._get_func_name(func=func)
Expand Down
38 changes: 38 additions & 0 deletions src/promptflow-core/tests/core/e2etests/test_eager_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,12 +39,30 @@ async def func_entry_async(input_str: str) -> str:
return "Hello " + input_str


async def gen_func(input_str: str):
    """Async-generator flow entry: yields the strings "0" through "4" with short pauses."""
    count = 0
    while count < 5:
        await asyncio.sleep(0.1)
        yield str(count)
        count += 1


class ClassEntryGen:
    """Callable class flow entry whose __call__ is an async generator."""

    async def __call__(self, input_str: str):
        # Yield "0".."4", pausing briefly between items to simulate streaming.
        value = 0
        while value < 5:
            await asyncio.sleep(0.1)
            yield str(value)
            value += 1


# (entry, inputs, expected_output) cases for non-generator flow entries:
# a callable class instance, a sync function, and an async function.
function_entries = [
    (ClassEntry(), {"input_str": "world"}, "Hello world"),
    (func_entry, {"input_str": "world"}, "Hello world"),
    (func_entry_async, {"input_str": "world"}, "Hello world"),
]

# (entry, inputs, expected_output) cases for async-generator flow entries;
# expected_output lists the individual items yielded by the generator.
generator_entries = [
    (gen_func, {"input_str": "world"}, ["0", "1", "2", "3", "4"]),
    (ClassEntryGen(), {"input_str": "world"}, ["0", "1", "2", "3", "4"]),
]


@pytest.mark.usefixtures("recording_injection", "setup_connection_provider", "dev_connections")
@pytest.mark.e2etest
Expand Down Expand Up @@ -208,6 +226,26 @@ async def test_flow_run_with_function_entry_async(self, entry, inputs, expected_
msg = f"The two tasks should run concurrently, but got {delta_desc}"
assert 0 <= delta_sec < 0.1, msg

@pytest.mark.asyncio
@pytest.mark.parametrize("entry, inputs, expected_output", generator_entries)
async def test_flow_run_with_generator_entry(self, entry, inputs, expected_output):
    """Generator entries yield a joined string by default, an async stream when allowed."""
    executor = FlowExecutor.create(entry, {})
    expected_string = "".join(expected_output)

    # Sync execution: generator output is collapsed into a single string.
    sync_result = executor.exec_line(inputs=inputs)
    assert sync_result.run_info.status == Status.Completed
    assert sync_result.output == expected_string

    # Async execution without streaming: same collapsed string.
    async_result = await executor.exec_line_async(inputs=inputs)
    assert async_result.run_info.status == Status.Completed
    assert async_result.output == expected_string

    # Async execution with streaming: output is an async generator of the items.
    stream_result = await executor.exec_line_async(inputs=inputs, allow_generator_output=True)
    assert stream_result.run_info.status == Status.Completed
    collected = [item async for item in stream_result.output]
    assert collected == expected_output

def test_flow_run_with_invalid_inputs(self):
# Case 1: input not found
flow_file = get_yaml_file("flow_with_signature", root=EAGER_FLOW_ROOT)
Expand Down
Loading