diff --git a/src/promptflow/promptflow/executor/_line_execution_process_pool.py b/src/promptflow/promptflow/executor/_line_execution_process_pool.py index 0669d1e88e7..24341c54a13 100644 --- a/src/promptflow/promptflow/executor/_line_execution_process_pool.py +++ b/src/promptflow/promptflow/executor/_line_execution_process_pool.py @@ -379,6 +379,11 @@ def run(self, batch_inputs): while not async_result.ready(): # Check every 1 second async_result.wait(1) + # To ensure exceptions in thread-pool calls are propagated to the main process for proper handling + # The exceptions raised will be re-raised by the get() method. + # Related link: + # https://docs.python.org/3/library/multiprocessing.html#multiprocessing.pool.AsyncResult + async_result.get() except KeyboardInterrupt: raise except PromptflowException: diff --git a/src/promptflow/tests/executor/unittests/processpool/test_line_execution_process_pool.py b/src/promptflow/tests/executor/unittests/processpool/test_line_execution_process_pool.py index efc345959ac..d7a4b6a2c0a 100644 --- a/src/promptflow/tests/executor/unittests/processpool/test_line_execution_process_pool.py +++ b/src/promptflow/tests/executor/unittests/processpool/test_line_execution_process_pool.py @@ -224,3 +224,37 @@ def test_get_multiprocessing_context(self): # Not set start method context = get_multiprocessing_context() assert context.get_start_method() == multiprocessing.get_start_method() + + @pytest.mark.parametrize( + "flow_folder", + [ + SAMPLE_FLOW, + ], + ) + def test_process_pool_run_with_exception(self, flow_folder, dev_connections, mocker: MockFixture): + # mock process pool run execution raise error + test_error_msg = "Test user error" + mocker.patch( + "promptflow.executor._line_execution_process_pool.LineExecutionProcessPool." "_timeout_process_wrapper", + side_effect=UserErrorException(message=test_error_msg, target=ErrorTarget.AZURE_RUN_STORAGE), + ) + executor = FlowExecutor.create( + get_yaml_file(flow_folder), + dev_connections, + ) + run_id = str(uuid.uuid4()) + bulk_inputs = self.get_bulk_inputs() + nlines = len(bulk_inputs) + with LineExecutionProcessPool( + executor, + nlines, + run_id, + "", + False, + None, + ) as pool: + with pytest.raises(UserErrorException) as e: + pool.run(zip(range(nlines), bulk_inputs)) + assert e.value.message == test_error_msg + assert e.value.target == ErrorTarget.AZURE_RUN_STORAGE + assert e.value.error_codes[0] == "UserError"