From 98387f38fa75dbdcac137329a3712e1feb2a5b6b Mon Sep 17 00:00:00 2001 From: Junwei Dai Date: Mon, 6 Jan 2025 13:08:22 -0800 Subject: [PATCH 01/16] Add synchronous execution option to workflow provisioning Signed-off-by: Junwei Dai --- .../flowframework/common/CommonValue.java | 6 + .../rest/RestCreateWorkflowAction.java | 6 +- .../rest/RestProvisionWorkflowAction.java | 5 +- .../CreateWorkflowTransportAction.java | 16 +- .../ProvisionWorkflowTransportAction.java | 127 +++++++++++++- .../transport/WorkflowRequest.java | 48 +++++- .../transport/WorkflowResponse.java | 39 ++++- .../rest/RestCreateWorkflowActionTests.java | 35 ++++ .../RestProvisionWorkflowActionTests.java | 22 +++ .../CreateWorkflowTransportActionTests.java | 156 ++++++++++++++++-- .../WorkflowRequestResponseTests.java | 40 ++++- 11 files changed, 476 insertions(+), 24 deletions(-) diff --git a/src/main/java/org/opensearch/flowframework/common/CommonValue.java b/src/main/java/org/opensearch/flowframework/common/CommonValue.java index 9c88788b3..2fe46996b 100644 --- a/src/main/java/org/opensearch/flowframework/common/CommonValue.java +++ b/src/main/java/org/opensearch/flowframework/common/CommonValue.java @@ -8,6 +8,8 @@ */ package org.opensearch.flowframework.common; +import org.opensearch.common.unit.TimeValue; + /** * Representation of common values that are used across project */ @@ -55,6 +57,8 @@ private CommonValue() {} /** The last provisioned time field */ public static final String LAST_PROVISIONED_TIME_FIELD = "last_provisioned_time"; + public static final TimeValue DEFAULT_WAIT_FOR_COMPLETION_TIMEOUT = TimeValue.timeValueSeconds(1); + /* * Constants associated with Rest or Transport actions */ @@ -74,6 +78,8 @@ private CommonValue() {} public static final String PROVISION_WORKFLOW = "provision"; /** The param name for update workflow field in create API */ public static final String UPDATE_WORKFLOW_FIELDS = "update_fields"; + /** The param name for specifying the timeout duration in seconds to wait for workflow completion */ + public static final String WAIT_FOR_COMPLETION_TIMEOUT = "wait_for_completion_timeout"; /** The field name for workflow steps. This field represents the name of the workflow steps to be fetched. */ public static final String WORKFLOW_STEP = "workflow_step"; /** The param name for default use case, used by the create workflow API */ diff --git a/src/main/java/org/opensearch/flowframework/rest/RestCreateWorkflowAction.java b/src/main/java/org/opensearch/flowframework/rest/RestCreateWorkflowAction.java index 4abedc365..1e0ad5088 100644 --- a/src/main/java/org/opensearch/flowframework/rest/RestCreateWorkflowAction.java +++ b/src/main/java/org/opensearch/flowframework/rest/RestCreateWorkflowAction.java @@ -12,6 +12,7 @@ import org.apache.logging.log4j.Logger; import org.opensearch.ExceptionsHelper; import org.opensearch.client.node.NodeClient; +import org.opensearch.common.unit.TimeValue; import org.opensearch.core.action.ActionListener; import org.opensearch.core.rest.RestStatus; import org.opensearch.core.xcontent.ToXContent; @@ -43,6 +44,7 @@ import static org.opensearch.flowframework.common.CommonValue.UPDATE_WORKFLOW_FIELDS; import static org.opensearch.flowframework.common.CommonValue.USE_CASE; import static org.opensearch.flowframework.common.CommonValue.VALIDATION; +import static org.opensearch.flowframework.common.CommonValue.WAIT_FOR_COMPLETION_TIMEOUT; import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_ID; import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_URI; import static org.opensearch.flowframework.common.FlowFrameworkSettings.FLOW_FRAMEWORK_ENABLED; @@ -87,6 +89,7 @@ protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient cli boolean provision = request.paramAsBoolean(PROVISION_WORKFLOW, false); boolean reprovision = request.paramAsBoolean(REPROVISION_WORKFLOW, false); boolean updateFields = request.paramAsBoolean(UPDATE_WORKFLOW_FIELDS, false); + TimeValue waitForCompletionTimeout = request.paramAsTime(WAIT_FOR_COMPLETION_TIMEOUT, null); String useCase = request.param(USE_CASE); // If provisioning, consume all other params and pass to provision transport action @@ -226,7 +229,8 @@ protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient cli validation, provision || updateFields, params, - reprovision + reprovision, + waitForCompletionTimeout ); return channel -> client.execute(CreateWorkflowAction.INSTANCE, workflowRequest, ActionListener.wrap(response -> { diff --git a/src/main/java/org/opensearch/flowframework/rest/RestProvisionWorkflowAction.java b/src/main/java/org/opensearch/flowframework/rest/RestProvisionWorkflowAction.java index 6ae56905c..502bf9423 100644 --- a/src/main/java/org/opensearch/flowframework/rest/RestProvisionWorkflowAction.java +++ b/src/main/java/org/opensearch/flowframework/rest/RestProvisionWorkflowAction.java @@ -12,6 +12,7 @@ import org.apache.logging.log4j.Logger; import org.opensearch.ExceptionsHelper; import org.opensearch.client.node.NodeClient; +import org.opensearch.common.unit.TimeValue; import org.opensearch.core.action.ActionListener; import org.opensearch.core.rest.RestStatus; import org.opensearch.core.xcontent.ToXContent; @@ -33,6 +34,7 @@ import java.util.stream.Collectors; import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; +import static org.opensearch.flowframework.common.CommonValue.WAIT_FOR_COMPLETION_TIMEOUT; import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_ID; import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_URI; import static org.opensearch.flowframework.common.FlowFrameworkSettings.FLOW_FRAMEWORK_ENABLED; @@ -73,6 +75,7 @@ public List routes() { @Override protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient client) throws IOException { String workflowId = request.param(WORKFLOW_ID); + TimeValue waitForCompletionTimeout = request.paramAsTime(WAIT_FOR_COMPLETION_TIMEOUT, null); try { Map params = parseParamsAndContent(request); if (!flowFrameworkFeatureEnabledSetting.isFlowFrameworkEnabled()) { @@ -86,7 +89,7 @@ protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient cli throw new FlowFrameworkException("workflow_id cannot be null", RestStatus.BAD_REQUEST); } // Create request and provision - WorkflowRequest workflowRequest = new WorkflowRequest(workflowId, null, params); + WorkflowRequest workflowRequest = new WorkflowRequest(workflowId, null, params, waitForCompletionTimeout); return channel -> client.execute(ProvisionWorkflowAction.INSTANCE, workflowRequest, ActionListener.wrap(response -> { XContentBuilder builder = response.toXContent(channel.newBuilder(), ToXContent.EMPTY_PARAMS); channel.sendResponse(new BytesRestResponse(RestStatus.OK, builder)); diff --git a/src/main/java/org/opensearch/flowframework/transport/CreateWorkflowTransportAction.java b/src/main/java/org/opensearch/flowframework/transport/CreateWorkflowTransportAction.java index 813613a32..eda7e42f3 100644 --- a/src/main/java/org/opensearch/flowframework/transport/CreateWorkflowTransportAction.java +++ b/src/main/java/org/opensearch/flowframework/transport/CreateWorkflowTransportAction.java @@ -251,7 +251,8 @@ private void createExecute(WorkflowRequest request, User user, ActionListener { - listener.onResponse(new WorkflowResponse(provisionResponse.getWorkflowId())); + if (request.getWaitForCompletionTimeout() != null) { + listener.onResponse( + new WorkflowResponse( + provisionResponse.getWorkflowId(), + provisionResponse.getWorkflowState() + ) + ); + } else { + listener.onResponse( + new WorkflowResponse(provisionResponse.getWorkflowId()) + ); + } }, exception -> { String errorMessage = "Provisioning failed."; logger.error(errorMessage, exception); diff --git a/src/main/java/org/opensearch/flowframework/transport/ProvisionWorkflowTransportAction.java b/src/main/java/org/opensearch/flowframework/transport/ProvisionWorkflowTransportAction.java index 45f374161..841c76cf5 100644 --- a/src/main/java/org/opensearch/flowframework/transport/ProvisionWorkflowTransportAction.java +++ b/src/main/java/org/opensearch/flowframework/transport/ProvisionWorkflowTransportAction.java @@ -45,6 +45,8 @@ import java.util.List; import java.util.Locale; import java.util.Map; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.atomic.AtomicBoolean; import java.util.stream.Collectors; import static org.opensearch.flowframework.common.CommonValue.ERROR_FIELD; @@ -210,14 +212,27 @@ private void executeProvisionRequest( ), ActionListener.wrap(updateResponse -> { logger.info("updated workflow {} state to {}", request.getWorkflowId(), State.PROVISIONING); - executeWorkflowAsync(workflowId, provisionProcessSequence, listener); + if (request.getWaitForCompletionTimeout() != null) { + executeWorkflowSync( + workflowId, + provisionProcessSequence, + listener, + request.getWaitForCompletionTimeout().getMillis() + ); + } else { + executeWorkflowAsync(workflowId, provisionProcessSequence, listener); + } // update last provisioned field in template Template newTemplate = Template.builder(template).lastProvisionedTime(Instant.now()).build(); flowFrameworkIndicesHandler.updateTemplateInGlobalContext( request.getWorkflowId(), newTemplate, ActionListener.wrap(templateResponse -> { - listener.onResponse(new WorkflowResponse(request.getWorkflowId())); + if (request.getWaitForCompletionTimeout() != null) { + logger.info("Waiting for workflow completion"); + } else { + listener.onResponse(new WorkflowResponse(request.getWorkflowId())); + } }, exception -> { String errorMessage = ParameterizedMessageFactory.INSTANCE.newMessage( "Failed to update use case template {}", @@ -275,18 +290,105 @@ private void executeProvisionRequest( */ private void executeWorkflowAsync(String workflowId, List workflowSequence, ActionListener listener) { try { - threadPool.executor(PROVISION_WORKFLOW_THREAD_POOL).execute(() -> { executeWorkflow(workflowSequence, workflowId); }); + threadPool.executor(PROVISION_WORKFLOW_THREAD_POOL) + .execute(() -> { executeWorkflow(workflowSequence, workflowId, listener, false); }); } catch (Exception exception) { listener.onFailure(new FlowFrameworkException("Failed to execute workflow " + workflowId, ExceptionsHelper.status(exception))); } } + /** + * Retrieves a thread from the provision thread pool to execute a workflow with a timeout mechanism. + * If the execution exceeds the specified timeout, it will return the current status of the workflow. + * + * @param workflowId The id of the workflow + * @param workflowSequence The sorted workflow to execute + * @param listener ActionListener for any failures or responses + * @param timeout The timeout duration in milliseconds + */ + private void executeWorkflowSync( + String workflowId, + List workflowSequence, + ActionListener listener, + long timeout + ) { + PlainActionFuture workflowFuture = new PlainActionFuture<>(); + AtomicBoolean isResponseSent = new AtomicBoolean(false); + CompletableFuture.runAsync(() -> { + try { + executeWorkflow(workflowSequence, workflowId, new ActionListener<>() { + @Override + public void onResponse(WorkflowResponse workflowResponse) { + if (isResponseSent.get()) { + logger.info("Ignoring onResponse for workflowId: {} as timeout already occurred", workflowId); + return; + } + isResponseSent.set(true); + workflowFuture.onResponse(null); + listener.onResponse(new WorkflowResponse(workflowResponse.getWorkflowId(), workflowResponse.getWorkflowState())); + } + + @Override + public void onFailure(Exception e) { + if (isResponseSent.get()) { + logger.info("Ignoring onFailure for workflowId: {} as timeout already occurred", workflowId); + return; + } + isResponseSent.set(true); + workflowFuture.onFailure( + new FlowFrameworkException("Failed to execute workflow " + workflowId, ExceptionsHelper.status(e)) + ); + listener.onFailure( + new FlowFrameworkException("Failed to execute workflow " + workflowId, ExceptionsHelper.status(e)) + ); + } + }, true); + } catch (Exception ex) { + if (!isResponseSent.get()) { + isResponseSent.set(true); + workflowFuture.onFailure( + new FlowFrameworkException("Failed to execute workflow " + workflowId, ExceptionsHelper.status(ex)) + ); + listener.onFailure(new FlowFrameworkException("Failed to execute workflow " + workflowId, ExceptionsHelper.status(ex))); + } + } + }, threadPool.executor(PROVISION_WORKFLOW_THREAD_POOL)); + + threadPool.executor(PROVISION_WORKFLOW_THREAD_POOL).execute(() -> { + try { + Thread.sleep(timeout); + if (isResponseSent.compareAndSet(false, true)) { + logger.warn("Workflow execution timed out for workflowId: {}", workflowId); + client.execute( + GetWorkflowStateAction.INSTANCE, + new GetWorkflowStateRequest(workflowId, false), + ActionListener.wrap( + response -> listener.onResponse(new WorkflowResponse(workflowId, response.getWorkflowState())), + exception -> listener.onFailure( + new FlowFrameworkException("Failed to get workflow state after timeout", ExceptionsHelper.status(exception)) + ) + ) + ); + } + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + } + }); + } + /** * Executes the given workflow sequence * @param workflowSequence The topologically sorted workflow to execute * @param workflowId The workflowId associated with the workflow that is executing + * @param listener The ActionListener to handle the workflow response or failure + * @param isSyncExecution Flag indicating whether the workflow should be executed synchronously (true) or asynchronously (false) */ - private void executeWorkflow(List workflowSequence, String workflowId) { + private void executeWorkflow( + List workflowSequence, + String workflowId, + ActionListener listener, + boolean isSyncExecution + ) { String currentStepId = ""; try { Map> workflowFutureMap = new LinkedHashMap<>(); @@ -324,6 +426,23 @@ private void executeWorkflow(List workflowSequence, String workflow ), ActionListener.wrap(updateResponse -> { logger.info("updated workflow {} state to {}", workflowId, State.COMPLETED); + if (isSyncExecution) { + client.execute( + GetWorkflowStateAction.INSTANCE, + new GetWorkflowStateRequest(workflowId, false), + ActionListener.wrap(response -> { + listener.onResponse(new WorkflowResponse(workflowId, response.getWorkflowState())); + }, exception -> { + String errorMessage = "Failed to get workflow state."; + logger.error(errorMessage, exception); + if (exception instanceof FlowFrameworkException) { + listener.onFailure(exception); + } else { + listener.onFailure(new FlowFrameworkException(errorMessage, ExceptionsHelper.status(exception))); + } + }) + ); + } }, exception -> { logger.error("Failed to update workflow state for workflow {}", workflowId, exception); }) ); } catch (Exception ex) { diff --git a/src/main/java/org/opensearch/flowframework/transport/WorkflowRequest.java b/src/main/java/org/opensearch/flowframework/transport/WorkflowRequest.java index 97f032e31..9c480be65 100644 --- a/src/main/java/org/opensearch/flowframework/transport/WorkflowRequest.java +++ b/src/main/java/org/opensearch/flowframework/transport/WorkflowRequest.java @@ -11,6 +11,7 @@ import org.opensearch.action.ActionRequest; import org.opensearch.action.ActionRequestValidationException; import org.opensearch.common.Nullable; +import org.opensearch.common.unit.TimeValue; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.flowframework.model.Template; @@ -62,13 +63,20 @@ public class WorkflowRequest extends ActionRequest { */ private Map params; + /** + * The timeout duration to wait for workflow completion. + * If null, the request will respond immediately with the workflowId. + */ + @Nullable + private TimeValue waitForCompletionTimeout; + /** * Instantiates a new WorkflowRequest, set validation to all, no provisioning * @param workflowId the documentId of the workflow * @param template the use case template which describes the workflow */ public WorkflowRequest(@Nullable String workflowId, @Nullable Template template) { - this(workflowId, template, new String[] { "all" }, false, Collections.emptyMap(), false); + this(workflowId, template, new String[] { "all" }, false, Collections.emptyMap(), false, null); } /** @@ -78,7 +86,27 @@ public WorkflowRequest(@Nullable String workflowId, @Nullable Template template) * @param params The parameters from the REST path */ public WorkflowRequest(@Nullable String workflowId, @Nullable Template template, Map params) { - this(workflowId, template, new String[] { "all" }, true, params, false); + this(workflowId, template, new String[] { "all" }, true, params, false, null); + } + + /** + * Instantiates a new WorkflowRequest with a specified wait-for-completion timeout. + * This constructor allows the caller to specify a custom timeout for the workflow execution, + * which determines how long the system will wait for the workflow to complete before returning a response. + * By default, the validation is set to "all", and provisioning is set to true. + * @param workflowId The unique document ID of the workflow. Can be null for new workflows. + * @param template The use case template that defines the structure and logic of the workflow. Can be null if not provided. + * @param params A map of parameters extracted from the REST request path, used to customize the workflow execution. + * @param waitForCompletionTimeout The maximum duration to wait for the workflow execution to complete. + * If the workflow does not complete within this timeout, the request will return a timeout response. + */ + public WorkflowRequest( + @Nullable String workflowId, + @Nullable Template template, + Map params, + TimeValue waitForCompletionTimeout + ) { + this(workflowId, template, new String[] { "all" }, true, params, false, waitForCompletionTimeout); } /** @@ -89,6 +117,7 @@ public WorkflowRequest(@Nullable String workflowId, @Nullable Template template, * @param provisionOrUpdate provision or updateFields flag. Only one may be true, the presence of update_fields key in map indicates if updating fields, otherwise true means it's provisioning. * @param params map of REST path params. If provisionOrUpdate is false, must be an empty map. If update_fields key is present, must be only key. * @param reprovision flag to indicate if request is to reprovision + * @param waitForCompletionTimeout the timeout duration (in milliseconds) to wait for workflow completion */ public WorkflowRequest( @Nullable String workflowId, @@ -96,7 +125,8 @@ public WorkflowRequest( String[] validation, boolean provisionOrUpdate, Map params, - boolean reprovision + boolean reprovision, + TimeValue waitForCompletionTimeout ) { this.workflowId = workflowId; this.template = template; @@ -108,6 +138,7 @@ public WorkflowRequest( } this.params = this.updateFields ? Collections.emptyMap() : params; this.reprovision = reprovision; + this.waitForCompletionTimeout = waitForCompletionTimeout; } /** @@ -133,6 +164,7 @@ public WorkflowRequest(StreamInput in) throws IOException { this.params = Collections.emptyMap(); } this.reprovision = !provision && Boolean.parseBoolean(params.get(REPROVISION_WORKFLOW)); + this.waitForCompletionTimeout = in.readOptionalTimeValue(); } /** @@ -193,6 +225,15 @@ public boolean isReprovision() { return this.reprovision; } + /** + * Gets the timeout duration (in milliseconds) to wait for workflow completion. + * @return the timeout duration, or null if the request should return immediately + */ + @Nullable + public TimeValue getWaitForCompletionTimeout() { + return this.waitForCompletionTimeout; + } + @Override public void writeTo(StreamOutput out) throws IOException { super.writeTo(out); @@ -207,6 +248,7 @@ public void writeTo(StreamOutput out) throws IOException { } else if (reprovision) { out.writeMap(Map.of(REPROVISION_WORKFLOW, "true"), StreamOutput::writeString, StreamOutput::writeString); } + out.writeOptionalTimeValue(waitForCompletionTimeout); } @Override diff --git a/src/main/java/org/opensearch/flowframework/transport/WorkflowResponse.java b/src/main/java/org/opensearch/flowframework/transport/WorkflowResponse.java index 20a7700a3..8a9f21d93 100644 --- a/src/main/java/org/opensearch/flowframework/transport/WorkflowResponse.java +++ b/src/main/java/org/opensearch/flowframework/transport/WorkflowResponse.java @@ -13,6 +13,7 @@ import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.core.xcontent.ToXContentObject; import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.flowframework.model.WorkflowState; import java.io.IOException; @@ -27,6 +28,8 @@ public class WorkflowResponse extends ActionResponse implements ToXContentObject * The documentId of the workflow entry within the Global Context index */ private String workflowId; + /** The workflow state */ + private WorkflowState workflowState; /** * Instantiates a new WorkflowResponse from params @@ -44,6 +47,8 @@ public WorkflowResponse(String workflowId) { public WorkflowResponse(StreamInput in) throws IOException { super(in); this.workflowId = in.readString(); + this.workflowState = in.readOptionalWriteable(WorkflowState::new); + } /** @@ -54,14 +59,46 @@ public String getWorkflowId() { return this.workflowId; } + /** + * Gets the workflowState of this repsonse + * @return the workflowState + */ + public WorkflowState getWorkflowState() { + return this.workflowState; + } + + /** + * Constructs a new WorkflowResponse object with the specified workflowId and workflowState. + * The WorkflowResponse is typically returned as part of a `wait_for_completion` request, + * indicating the final state of a workflow after execution. + * @param workflowId The unique identifier for the workflow. + * @param workflowState The current state of the workflow, including status, errors (if any), + * and resources created as part of the workflow execution. + */ + public WorkflowResponse(String workflowId, WorkflowState workflowState) { + this.workflowId = workflowId; + this.workflowState = WorkflowState.builder() + .workflowId(workflowId) + .error(workflowState.getError()) + .state(workflowState.getState()) + .resourcesCreated(workflowState.resourcesCreated()) + .build(); + + } + @Override public void writeTo(StreamOutput out) throws IOException { out.writeString(workflowId); + out.writeOptionalWriteable(workflowState); } @Override public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { - return builder.startObject().field(WORKFLOW_ID, this.workflowId).endObject(); + if (workflowState != null) { + return workflowState.toXContent(builder, params); + } else { + return builder.startObject().field(WORKFLOW_ID, this.workflowId).endObject(); + } } } diff --git a/src/test/java/org/opensearch/flowframework/rest/RestCreateWorkflowActionTests.java b/src/test/java/org/opensearch/flowframework/rest/RestCreateWorkflowActionTests.java index f6b1a5fc7..fd31cb823 100644 --- a/src/test/java/org/opensearch/flowframework/rest/RestCreateWorkflowActionTests.java +++ b/src/test/java/org/opensearch/flowframework/rest/RestCreateWorkflowActionTests.java @@ -128,6 +128,41 @@ public void testCreateWorkflowRequestWithParamsAndProvision() throws Exception { assertTrue(channel.capturedResponse().content().utf8ToString().contains("id-123")); } + public void testRestCreateWorkflow_withWaitForCompletionTimeout() throws Exception { + RestRequest request = new FakeRestRequest.Builder(xContentRegistry()).withMethod(RestRequest.Method.POST) + .withParams(Map.of("wait_for_completion_timeout", "5s")) + .withContent(new BytesArray(validTemplate), MediaTypeRegistry.JSON) + .build(); + + FakeRestChannel channel = new FakeRestChannel(request, false, 1); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(2); + listener.onResponse(new WorkflowResponse("workflow_1")); + return null; + }).when(nodeClient).execute(any(), any(WorkflowRequest.class), any()); + + createWorkflowRestAction.handleRequest(request, channel, nodeClient); + + assertEquals(RestStatus.CREATED, channel.capturedResponse().status()); + assertTrue(channel.capturedResponse().content().utf8ToString().contains("workflow_1")); + } + + public void testInvalidValueForCompletionTimeout() throws Exception { + RestRequest request = new FakeRestRequest.Builder(xContentRegistry()).withMethod(RestRequest.Method.POST) + .withParams(Map.of("wait_for_completion_timeout", "invalid_value")) + .withContent(new BytesArray(validTemplate), MediaTypeRegistry.JSON) + .build(); + + FakeRestChannel channel = new FakeRestChannel(request, false, 1); + + IllegalArgumentException exception = expectThrows(IllegalArgumentException.class, () -> { + createWorkflowRestAction.handleRequest(request, channel, nodeClient); + }); + + assertTrue(exception.getMessage().contains("failed to parse setting [wait_for_completion_timeout] with value [invalid_value]")); + } + public void testCreateWorkflowRequestWithParamsButNoProvision() throws Exception { RestRequest request = new FakeRestRequest.Builder(xContentRegistry()).withMethod(RestRequest.Method.POST) .withPath(this.createWorkflowPath) diff --git a/src/test/java/org/opensearch/flowframework/rest/RestProvisionWorkflowActionTests.java b/src/test/java/org/opensearch/flowframework/rest/RestProvisionWorkflowActionTests.java index fd5cd478d..625e48e34 100644 --- a/src/test/java/org/opensearch/flowframework/rest/RestProvisionWorkflowActionTests.java +++ b/src/test/java/org/opensearch/flowframework/rest/RestProvisionWorkflowActionTests.java @@ -144,4 +144,26 @@ public void testFeatureFlagNotEnabled() throws Exception { assertEquals(RestStatus.FORBIDDEN, channel.capturedResponse().status()); assertTrue(channel.capturedResponse().content().utf8ToString().contains("This API is disabled.")); } + + public void testProvisionWorkflowWithValidWaitForCompletionTimeout() throws Exception { + RestRequest request = new FakeRestRequest.Builder(xContentRegistry()).withMethod(RestRequest.Method.POST) + .withPath(this.provisionWorkflowPath) + .withParams(Map.of("workflow_id", "abc", "wait_for_completion_timeout", "5s")) + .withContent(new BytesArray("{\"foo\": \"bar\"}"), MediaTypeRegistry.JSON) + .build(); + + FakeRestChannel channel = new FakeRestChannel(request, false, 1); + + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(2); + actionListener.onResponse(new WorkflowResponse("workflow_1")); + return null; + }).when(nodeClient).execute(any(), any(WorkflowRequest.class), any()); + + provisionWorkflowRestAction.handleRequest(request, channel, nodeClient); + + assertEquals(RestStatus.OK, channel.capturedResponse().status()); + assertTrue(channel.capturedResponse().content().utf8ToString().contains("workflow_1")); + } + } diff --git a/src/test/java/org/opensearch/flowframework/transport/CreateWorkflowTransportActionTests.java b/src/test/java/org/opensearch/flowframework/transport/CreateWorkflowTransportActionTests.java index ba76bc833..86a950175 100644 --- a/src/test/java/org/opensearch/flowframework/transport/CreateWorkflowTransportActionTests.java +++ b/src/test/java/org/opensearch/flowframework/transport/CreateWorkflowTransportActionTests.java @@ -38,6 +38,7 @@ import org.opensearch.flowframework.model.Workflow; import org.opensearch.flowframework.model.WorkflowEdge; import org.opensearch.flowframework.model.WorkflowNode; +import org.opensearch.flowframework.model.WorkflowState; import org.opensearch.flowframework.workflow.WorkflowProcessSorter; import org.opensearch.plugins.PluginsService; import org.opensearch.search.SearchHit; @@ -48,6 +49,7 @@ import org.opensearch.transport.TransportService; import java.io.IOException; +import java.time.Instant; import java.util.Arrays; import java.util.Collections; import java.util.HashMap; @@ -252,7 +254,7 @@ public void testMaxWorkflow() { @SuppressWarnings("unchecked") ActionListener listener = mock(ActionListener.class); - WorkflowRequest workflowRequest = new WorkflowRequest(null, template, new String[] { "off" }, false, Collections.emptyMap(), false); + WorkflowRequest workflowRequest = new WorkflowRequest(null, template, new String[] { "off" }, false, Collections.emptyMap(), false,null); doAnswer(invocation -> { ActionListener searchListener = invocation.getArgument(1); @@ -289,7 +291,15 @@ public void onFailure(Exception e) { public void testFailedToCreateNewWorkflow() { @SuppressWarnings("unchecked") ActionListener listener = mock(ActionListener.class); - WorkflowRequest workflowRequest = new WorkflowRequest(null, template, new String[] { "off" }, false, Collections.emptyMap(), false); + WorkflowRequest workflowRequest = new WorkflowRequest( + null, + template, + new String[] { "off" }, + false, + Collections.emptyMap(), + false, + null + ); // Bypass checkMaxWorkflows and force onResponse doAnswer(invocation -> { @@ -320,7 +330,15 @@ public void testFailedToCreateNewWorkflow() { public void testCreateNewWorkflow() { @SuppressWarnings("unchecked") ActionListener listener = mock(ActionListener.class); - WorkflowRequest workflowRequest = new WorkflowRequest(null, template, new String[] { "off" }, false, Collections.emptyMap(), false); + WorkflowRequest workflowRequest = new WorkflowRequest( + null, + template, + new String[] { "off" }, + false, + Collections.emptyMap(), + false, + null + ); // Bypass checkMaxWorkflows and force onResponse doAnswer(invocation -> { @@ -384,7 +402,15 @@ public void testCreateWithUserAndFilterOn() { ); ActionListener listener = mock(ActionListener.class); - WorkflowRequest workflowRequest = new WorkflowRequest(null, template, new String[] { "off" }, false, Collections.emptyMap(), false); + WorkflowRequest workflowRequest = new WorkflowRequest( + null, + template, + new String[] { "off" }, + false, + Collections.emptyMap(), + false, + null + ); // Bypass checkMaxWorkflows and force onResponse doAnswer(invocation -> { @@ -448,7 +474,15 @@ public void testFailedToCreateNewWorkflowWithNullUser() { ActionListener listener = mock(ActionListener.class); - WorkflowRequest workflowRequest = new WorkflowRequest(null, template, new String[] { "off" }, false, Collections.emptyMap(), false); + WorkflowRequest workflowRequest = new WorkflowRequest( + null, + template, + new String[] { "off" }, + false, + Collections.emptyMap(), + false, + null + ); createWorkflowTransportAction1.doExecute(mock(Task.class), workflowRequest, listener); ArgumentCaptor exceptionCaptor = ArgumentCaptor.forClass(Exception.class); @@ -483,7 +517,15 @@ public void testFailedToCreateNewWorkflowWithNoBackendRoleUser() { ActionListener listener = mock(ActionListener.class); - WorkflowRequest workflowRequest = new WorkflowRequest(null, template, new String[] { "off" }, false, Collections.emptyMap(), false); + WorkflowRequest workflowRequest = new WorkflowRequest( + null, + template, + new String[] { "off" }, + false, + Collections.emptyMap(), + false, + null + ); createWorkflowTransportAction1.doExecute(mock(Task.class), workflowRequest, listener); ArgumentCaptor exceptionCaptor = ArgumentCaptor.forClass(Exception.class); @@ -497,7 +539,15 @@ public void testFailedToCreateNewWorkflowWithNoBackendRoleUser() { public void testUpdateWorkflowWithReprovision() throws IOException { @SuppressWarnings("unchecked") ActionListener listener = mock(ActionListener.class); - WorkflowRequest workflowRequest = new WorkflowRequest("1", template, new String[] { "off" }, false, Collections.emptyMap(), true); + WorkflowRequest workflowRequest = new WorkflowRequest( + "1", + template, + new String[] { "off" }, + false, + Collections.emptyMap(), + true, + null + ); doAnswer(invocation -> { ActionListener getListener = invocation.getArgument(1); @@ -541,7 +591,15 @@ public void testUpdateWorkflowWithReprovision() throws IOException { public void testFailedToUpdateWorkflowWithReprovision() throws IOException { @SuppressWarnings("unchecked") ActionListener listener = mock(ActionListener.class); - WorkflowRequest workflowRequest = new WorkflowRequest("1", template, new String[] { "off" }, false, Collections.emptyMap(), true); + WorkflowRequest workflowRequest = new WorkflowRequest( + "1", + template, + new String[] { "off" }, + false, + Collections.emptyMap(), + true, + null + ); doAnswer(invocation -> { ActionListener getListener = invocation.getArgument(1); @@ -841,7 +899,8 @@ public void testCreateWorkflow_withValidation_withProvision_Success() throws Exc new String[] { "all" }, true, Collections.emptyMap(), - false + false, + null ); // Bypass checkMaxWorkflows and force onResponse @@ -888,6 +947,82 @@ public void testCreateWorkflow_withValidation_withProvision_Success() throws Exc assertEquals("1", workflowResponseCaptor.getValue().getWorkflowId()); } + public void testCreateWorkflow_withValidation_withWaitForCompletion_withProvision_Success() throws Exception { + + Template validTemplate = generateValidTemplate(); + + @SuppressWarnings("unchecked") + ActionListener listener = mock(ActionListener.class); + + doNothing().when(workflowProcessSorter).validate(any(), any()); + WorkflowRequest workflowRequest = new WorkflowRequest( + null, + validTemplate, + new String[] { "all" }, + true, + Collections.emptyMap(), + false, + TimeValue.timeValueSeconds(5) + ); + + // Bypass checkMaxWorkflows and force onResponse + doAnswer(invocation -> { + ActionListener checkMaxWorkflowListener = invocation.getArgument(2); + checkMaxWorkflowListener.onResponse(true); + return null; + }).when(createWorkflowTransportAction).checkMaxWorkflows(any(TimeValue.class), anyInt(), any()); + + // Bypass initializeConfigIndex and force onResponse + doAnswer(invocation -> { + ActionListener initalizeMasterKeyIndexListener = invocation.getArgument(0); + initalizeMasterKeyIndexListener.onResponse(true); + return null; + }).when(flowFrameworkIndicesHandler).initializeConfigIndex(any()); + + // Bypass putTemplateToGlobalContext and force onResponse + doAnswer(invocation -> { + ActionListener responseListener = invocation.getArgument(1); + responseListener.onResponse(new IndexResponse(new ShardId(GLOBAL_CONTEXT_INDEX, "", 1), "1", 1L, 1L, 1L, true)); + return null; + }).when(flowFrameworkIndicesHandler).putTemplateToGlobalContext(any(), any()); + + // Bypass putInitialStateToWorkflowState and force on response + doAnswer(invocation -> { + ActionListener responseListener = invocation.getArgument(2); + responseListener.onResponse(new IndexResponse(new ShardId(WORKFLOW_STATE_INDEX, "", 1), "1", 1L, 1L, 1L, true)); + return null; + }).when(flowFrameworkIndicesHandler).putInitialStateToWorkflowState(any(), any(), any()); + + doAnswer(invocation -> { + ActionListener responseListener = invocation.getArgument(2); + WorkflowResponse response = mock(WorkflowResponse.class); + when(response.getWorkflowId()).thenReturn("1"); + when(response.getWorkflowState()).thenReturn( + new WorkflowState( + "1", + "test", + "PROVISIONING", + "IN_PROGRESS", + Instant.now(), + Instant.now(), + TestHelpers.randomUser(), + Collections.emptyMap(), + Collections.emptyList() + ) + ); + responseListener.onResponse(response); + return null; + }).when(client).execute(eq(ProvisionWorkflowAction.INSTANCE), any(WorkflowRequest.class), any(ActionListener.class)); + + ArgumentCaptor workflowResponseCaptor = ArgumentCaptor.forClass(WorkflowResponse.class); + + createWorkflowTransportAction.doExecute(mock(Task.class), workflowRequest, listener); + + verify(listener, times(1)).onResponse(workflowResponseCaptor.capture()); + assertEquals("1", workflowResponseCaptor.getValue().getWorkflowId()); + assertEquals("PROVISIONING", workflowResponseCaptor.getValue().getWorkflowState().getState()); + } + public void testCreateWorkflow_withValidation_withProvision_FailedProvisioning() throws Exception { Template validTemplate = generateValidTemplate(); @@ -901,7 +1036,8 @@ public void testCreateWorkflow_withValidation_withProvision_FailedProvisioning() new String[] { "all" }, true, Collections.emptyMap(), - false + false, + null ); // Bypass checkMaxWorkflows and force onResponse diff --git a/src/test/java/org/opensearch/flowframework/transport/WorkflowRequestResponseTests.java b/src/test/java/org/opensearch/flowframework/transport/WorkflowRequestResponseTests.java index e92255e0f..50c60a19e 100644 --- a/src/test/java/org/opensearch/flowframework/transport/WorkflowRequestResponseTests.java +++ b/src/test/java/org/opensearch/flowframework/transport/WorkflowRequestResponseTests.java @@ -21,9 +21,11 @@ import org.opensearch.flowframework.model.Workflow; import org.opensearch.flowframework.model.WorkflowEdge; import org.opensearch.flowframework.model.WorkflowNode; +import org.opensearch.flowframework.model.WorkflowState; import org.opensearch.test.OpenSearchTestCase; import java.io.IOException; +import java.time.Instant; import java.util.Collections; import java.util.List; import java.util.Map; @@ -156,7 +158,7 @@ public void testWorkflowRequestWithParams() throws IOException { public void testWorkflowRequestWithParamsNoProvision() throws IOException { IllegalArgumentException ex = assertThrows( IllegalArgumentException.class, - () -> new WorkflowRequest("123", template, new String[] { "all" }, false, Map.of("foo", "bar"), false) + () -> new WorkflowRequest("123", template, new String[] { "all" }, false, Map.of("foo", "bar"), false, null) ); assertEquals("Params may only be included when provisioning.", ex.getMessage()); } @@ -168,7 +170,8 @@ public void testWorkflowRequestWithOnlyUpdateParamNoProvision() throws IOExcepti new String[] { "all" }, true, Map.of(UPDATE_WORKFLOW_FIELDS, "true"), - false + false, + null ); assertNotNull(workflowRequest.getWorkflowId()); assertEquals(template, workflowRequest.getTemplate()); @@ -208,4 +211,37 @@ public void testWorkflowResponse() throws IOException { assertEquals("{\"workflow_id\":\"123\"}", builder.toString()); } + public void testWorkflowResponseWithWaitForCompletionTimeOut() throws IOException { + WorkflowState workFlowState = new WorkflowState( + "123", + "test", + "PROVISIONING", + "IN_PROGRESS", + Instant.now(), + Instant.now(), + TestHelpers.randomUser(), + Collections.emptyMap(), + Collections.emptyList() + ); + + WorkflowResponse response = new WorkflowResponse("123", workFlowState); + assertEquals("123", response.getWorkflowId()); + assertEquals("PROVISIONING", response.getWorkflowState().getState()); + + BytesStreamOutput out = new BytesStreamOutput(); + response.writeTo(out); + BytesStreamInput in = new BytesStreamInput(BytesReference.toBytes(out.bytes())); + WorkflowResponse streamInputResponse = new WorkflowResponse(in); + + assertEquals(response.getWorkflowId(), streamInputResponse.getWorkflowId()); + assertEquals(response.getWorkflowState().getState(), streamInputResponse.getWorkflowState().getState()); + + XContentBuilder builder = MediaTypeRegistry.contentBuilder(XContentType.JSON); + response.toXContent(builder, ToXContent.EMPTY_PARAMS); + + assertNotNull(builder); + assertTrue(builder.toString().contains("\"workflow_id\":\"123\"")); + assertTrue(builder.toString().contains("\"state\":\"PROVISIONING\"")); + } + } From f3898cff5fbf90c4830fa748aa02dd557c51f65c Mon Sep 17 00:00:00 2001 From: Junwei Dai Date: Wed, 8 Jan 2025 13:08:34 -0800 Subject: [PATCH 02/16] code refactor Signed-off-by: Junwei Dai --- .../CreateWorkflowTransportAction.java | 14 ++--- .../ProvisionWorkflowTransportAction.java | 53 +++++++++++++++---- 2 files changed, 48 insertions(+), 19 deletions(-) diff --git a/src/main/java/org/opensearch/flowframework/transport/CreateWorkflowTransportAction.java b/src/main/java/org/opensearch/flowframework/transport/CreateWorkflowTransportAction.java index eda7e42f3..ab5e45df6 100644 --- a/src/main/java/org/opensearch/flowframework/transport/CreateWorkflowTransportAction.java +++ b/src/main/java/org/opensearch/flowframework/transport/CreateWorkflowTransportAction.java @@ -262,18 +262,14 @@ private void createExecute(WorkflowRequest request, User user, ActionListener { - if (request.getWaitForCompletionTimeout() != null) { - listener.onResponse( - new WorkflowResponse( + listener.onResponse( + request.getWaitForCompletionTimeout() != null + ? new WorkflowResponse( provisionResponse.getWorkflowId(), provisionResponse.getWorkflowState() ) - ); - } else { - listener.onResponse( - new WorkflowResponse(provisionResponse.getWorkflowId()) - ); - } + : new WorkflowResponse(provisionResponse.getWorkflowId()) + ); }, exception -> { String errorMessage = "Provisioning failed."; logger.error(errorMessage, exception); diff --git a/src/main/java/org/opensearch/flowframework/transport/ProvisionWorkflowTransportAction.java b/src/main/java/org/opensearch/flowframework/transport/ProvisionWorkflowTransportAction.java index 841c76cf5..17aa4cdff 100644 --- a/src/main/java/org/opensearch/flowframework/transport/ProvisionWorkflowTransportAction.java +++ b/src/main/java/org/opensearch/flowframework/transport/ProvisionWorkflowTransportAction.java @@ -354,21 +354,33 @@ public void onFailure(Exception e) { } }, threadPool.executor(PROVISION_WORKFLOW_THREAD_POOL)); + // Schedule timeout handler + scheduleTimeoutHandler(workflowId, listener, timeout, isResponseSent); + } + + /** + * Schedules a timeout handler for workflow execution. + * This method starts a new task in the thread pool to wait for the specified timeout duration. + * If the workflow does not complete within the given timeout, it triggers a follow-up action + * to fetch the workflow's state and notify the listener. + * + * @param workflowId The unique identifier of the workflow being executed. + * @param listener The ActionListener to notify with the workflow's response or failure. + * @param timeout The maximum time (in milliseconds) to wait for the workflow to complete before timing out. + * @param isResponseSent An AtomicBoolean flag to ensure the response is sent only once. + */ + private void scheduleTimeoutHandler( + String workflowId, + ActionListener listener, + long timeout, + AtomicBoolean isResponseSent + ) { threadPool.executor(PROVISION_WORKFLOW_THREAD_POOL).execute(() -> { try { Thread.sleep(timeout); if (isResponseSent.compareAndSet(false, true)) { logger.warn("Workflow execution timed out for workflowId: {}", workflowId); - client.execute( - GetWorkflowStateAction.INSTANCE, - new GetWorkflowStateRequest(workflowId, false), - ActionListener.wrap( - response -> listener.onResponse(new WorkflowResponse(workflowId, response.getWorkflowState())), - exception -> listener.onFailure( - new FlowFrameworkException("Failed to get workflow state after timeout", ExceptionsHelper.status(exception)) - ) - ) - ); + fetchWorkflowStateAfterTimeout(workflowId, listener); } } catch (InterruptedException e) { Thread.currentThread().interrupt(); @@ -376,6 +388,27 @@ public void onFailure(Exception e) { }); } + /** + * Fetches the workflow state after a timeout has occurred. + * This method sends a request to retrieve the current state of the workflow + * and notifies the listener with the updated state or an error if the request fails. + * + * @param workflowId The unique identifier of the workflow whose state needs to be fetched. + * @param listener The ActionListener to notify with the workflow's updated state or failure. + */ + private void fetchWorkflowStateAfterTimeout(String workflowId, ActionListener listener) { + client.execute( + GetWorkflowStateAction.INSTANCE, + new GetWorkflowStateRequest(workflowId, false), + ActionListener.wrap( + response -> listener.onResponse(new WorkflowResponse(workflowId, response.getWorkflowState())), + exception -> listener.onFailure( + new FlowFrameworkException("Failed to get workflow state after timeout", ExceptionsHelper.status(exception)) + ) + ) + ); + } + /** * Executes the given workflow sequence * @param workflowSequence The topologically sorted workflow to execute From 80c5340afba205c462498f4fc132f6ba562d6739 Mon Sep 17 00:00:00 2001 From: Junwei Dai Date: Wed, 8 Jan 2025 13:31:03 -0800 Subject: [PATCH 03/16] add change log Signed-off-by: Junwei Dai --- CHANGELOG.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index c8f99f0bb..cd347365f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -18,6 +18,8 @@ Inspired from [Keep a Changelog](https://keepachangelog.com/en/1.1.0/) ## [Unreleased 2.x](https://github.com/opensearch-project/flow-framework/compare/2.18...2.x) ### Features +- Add synchronous execution option to workflow provisioning ([#990](https://github.com/opensearch-project/flow-framework/pull/990)) + ### Enhancements ### Bug Fixes - Remove useCase and defaultParams field in WorkflowRequest ([#758](https://github.com/opensearch-project/flow-framework/pull/758)) From a2321e8fc2bf71e89207d3f35cf1d5b18b6929d4 Mon Sep 17 00:00:00 2001 From: Junwei Dai Date: Thu, 9 Jan 2025 14:56:07 -0800 Subject: [PATCH 04/16] refactor code based on comment Signed-off-by: Junwei Dai --- .../rest/RestCreateWorkflowAction.java | 6 +- .../rest/RestProvisionWorkflowAction.java | 2 +- .../CreateWorkflowTransportAction.java | 20 +- .../ProvisionWorkflowTransportAction.java | 98 ++------- .../transport/WorkflowRequest.java | 24 ++- .../transport/WorkflowResponse.java | 11 +- .../util/WorkflowTimeoutUtility.java | 194 ++++++++++++++++++ .../rest/RestCreateWorkflowActionTests.java | 16 +- .../CreateWorkflowTransportActionTests.java | 81 +++++++- .../util/WorkflowTimeoutUtilityTests.java | 135 ++++++++++++ 10 files changed, 468 insertions(+), 119 deletions(-) create mode 100644 src/main/java/org/opensearch/flowframework/util/WorkflowTimeoutUtility.java create mode 100644 src/test/java/org/opensearch/flowframework/util/WorkflowTimeoutUtilityTests.java diff --git a/src/main/java/org/opensearch/flowframework/rest/RestCreateWorkflowAction.java b/src/main/java/org/opensearch/flowframework/rest/RestCreateWorkflowAction.java index 1e0ad5088..4abedc365 100644 --- a/src/main/java/org/opensearch/flowframework/rest/RestCreateWorkflowAction.java +++ b/src/main/java/org/opensearch/flowframework/rest/RestCreateWorkflowAction.java @@ -12,7 +12,6 @@ import org.apache.logging.log4j.Logger; import org.opensearch.ExceptionsHelper; import org.opensearch.client.node.NodeClient; -import org.opensearch.common.unit.TimeValue; import org.opensearch.core.action.ActionListener; import org.opensearch.core.rest.RestStatus; import org.opensearch.core.xcontent.ToXContent; @@ -44,7 +43,6 @@ import static org.opensearch.flowframework.common.CommonValue.UPDATE_WORKFLOW_FIELDS; import static org.opensearch.flowframework.common.CommonValue.USE_CASE; import static org.opensearch.flowframework.common.CommonValue.VALIDATION; -import static org.opensearch.flowframework.common.CommonValue.WAIT_FOR_COMPLETION_TIMEOUT; import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_ID; import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_URI; import static org.opensearch.flowframework.common.FlowFrameworkSettings.FLOW_FRAMEWORK_ENABLED; @@ -89,7 +87,6 @@ protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient cli boolean provision = request.paramAsBoolean(PROVISION_WORKFLOW, false); boolean reprovision = request.paramAsBoolean(REPROVISION_WORKFLOW, false); boolean updateFields = request.paramAsBoolean(UPDATE_WORKFLOW_FIELDS, false); - TimeValue waitForCompletionTimeout = request.paramAsTime(WAIT_FOR_COMPLETION_TIMEOUT, null); String useCase = request.param(USE_CASE); // If provisioning, consume all other params and pass to provision transport action @@ -229,8 +226,7 @@ protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient cli validation, provision || updateFields, params, - reprovision, - waitForCompletionTimeout + reprovision ); return channel -> client.execute(CreateWorkflowAction.INSTANCE, workflowRequest, ActionListener.wrap(response -> { diff --git a/src/main/java/org/opensearch/flowframework/rest/RestProvisionWorkflowAction.java b/src/main/java/org/opensearch/flowframework/rest/RestProvisionWorkflowAction.java index 502bf9423..e197312ed 100644 --- a/src/main/java/org/opensearch/flowframework/rest/RestProvisionWorkflowAction.java +++ b/src/main/java/org/opensearch/flowframework/rest/RestProvisionWorkflowAction.java @@ -75,7 +75,7 @@ public List routes() { @Override protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient client) throws IOException { String workflowId = request.param(WORKFLOW_ID); - TimeValue waitForCompletionTimeout = request.paramAsTime(WAIT_FOR_COMPLETION_TIMEOUT, null); + TimeValue waitForCompletionTimeout = request.paramAsTime(WAIT_FOR_COMPLETION_TIMEOUT, TimeValue.MINUS_ONE); try { Map params = parseParamsAndContent(request); if (!flowFrameworkFeatureEnabledSetting.isFlowFrameworkEnabled()) { diff --git a/src/main/java/org/opensearch/flowframework/transport/CreateWorkflowTransportAction.java b/src/main/java/org/opensearch/flowframework/transport/CreateWorkflowTransportAction.java index ab5e45df6..44c8fab1e 100644 --- a/src/main/java/org/opensearch/flowframework/transport/CreateWorkflowTransportAction.java +++ b/src/main/java/org/opensearch/flowframework/transport/CreateWorkflowTransportAction.java @@ -53,6 +53,7 @@ import static org.opensearch.flowframework.common.CommonValue.GLOBAL_CONTEXT_INDEX; import static org.opensearch.flowframework.common.CommonValue.PROVISIONING_PROGRESS_FIELD; import static org.opensearch.flowframework.common.CommonValue.STATE_FIELD; +import static org.opensearch.flowframework.common.CommonValue.WAIT_FOR_COMPLETION_TIMEOUT; import static org.opensearch.flowframework.common.FlowFrameworkSettings.FILTER_BY_BACKEND_ROLES; import static org.opensearch.flowframework.util.ParseUtils.checkFilterByBackendRoles; import static org.opensearch.flowframework.util.ParseUtils.getUserContext; @@ -248,11 +249,14 @@ private void createExecute(WorkflowRequest request, User user, ActionListener { logger.info("Creating state workflow doc: {}", globalContextResponse.getId()); if (request.isProvision()) { + String waitForTimeCompletion = request.getParams() + .getOrDefault(WAIT_FOR_COMPLETION_TIMEOUT, TimeValue.MINUS_ONE.toString()); WorkflowRequest workflowRequest = new WorkflowRequest( globalContextResponse.getId(), null, request.getParams(), - request.getWaitForCompletionTimeout() + // todo : what is this setting name represent? + TimeValue.parseTimeValue(waitForTimeCompletion, "provision.timout") ); logger.info( "Provisioning parameter is set, continuing to provision workflow {}", @@ -262,14 +266,18 @@ private void createExecute(WorkflowRequest request, User user, ActionListener { - listener.onResponse( - request.getWaitForCompletionTimeout() != null - ? new WorkflowResponse( + if (workflowRequest.getWaitForCompletionTimeout() == TimeValue.MINUS_ONE) { + listener.onResponse( + new WorkflowResponse(provisionResponse.getWorkflowId()) + ); + } else { + listener.onResponse( + new WorkflowResponse( provisionResponse.getWorkflowId(), provisionResponse.getWorkflowState() ) - : new WorkflowResponse(provisionResponse.getWorkflowId()) - ); + ); + } }, exception -> { String errorMessage = "Provisioning failed."; logger.error(errorMessage, exception); diff --git a/src/main/java/org/opensearch/flowframework/transport/ProvisionWorkflowTransportAction.java b/src/main/java/org/opensearch/flowframework/transport/ProvisionWorkflowTransportAction.java index 17aa4cdff..bfe9aee0e 100644 --- a/src/main/java/org/opensearch/flowframework/transport/ProvisionWorkflowTransportAction.java +++ b/src/main/java/org/opensearch/flowframework/transport/ProvisionWorkflowTransportAction.java @@ -20,6 +20,7 @@ import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.inject.Inject; import org.opensearch.common.settings.Settings; +import org.opensearch.common.unit.TimeValue; import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.commons.authuser.User; import org.opensearch.core.action.ActionListener; @@ -32,6 +33,7 @@ import org.opensearch.flowframework.model.Template; import org.opensearch.flowframework.model.Workflow; import org.opensearch.flowframework.util.EncryptorUtils; +import org.opensearch.flowframework.util.WorkflowTimeoutUtility; import org.opensearch.flowframework.workflow.ProcessNode; import org.opensearch.flowframework.workflow.WorkflowProcessSorter; import org.opensearch.plugins.PluginsService; @@ -212,15 +214,15 @@ private void executeProvisionRequest( ), ActionListener.wrap(updateResponse -> { logger.info("updated workflow {} state to {}", request.getWorkflowId(), State.PROVISIONING); - if (request.getWaitForCompletionTimeout() != null) { + if (request.getWaitForCompletionTimeout() == TimeValue.MINUS_ONE) { + executeWorkflowAsync(workflowId, provisionProcessSequence, listener); + } else { executeWorkflowSync( workflowId, provisionProcessSequence, listener, request.getWaitForCompletionTimeout().getMillis() ); - } else { - executeWorkflowAsync(workflowId, provisionProcessSequence, listener); } // update last provisioned field in template Template newTemplate = Template.builder(template).lastProvisionedTime(Instant.now()).build(); @@ -228,10 +230,10 @@ private void executeProvisionRequest( request.getWorkflowId(), newTemplate, ActionListener.wrap(templateResponse -> { - if (request.getWaitForCompletionTimeout() != null) { - logger.info("Waiting for workflow completion"); - } else { + if (request.getWaitForCompletionTimeout() == TimeValue.MINUS_ONE) { listener.onResponse(new WorkflowResponse(request.getWorkflowId())); + } else { + logger.info("Waiting for workflow completion"); } }, exception -> { String errorMessage = ParameterizedMessageFactory.INSTANCE.newMessage( @@ -312,101 +314,27 @@ private void executeWorkflowSync( ActionListener listener, long timeout ) { - PlainActionFuture workflowFuture = new PlainActionFuture<>(); AtomicBoolean isResponseSent = new AtomicBoolean(false); + CompletableFuture.runAsync(() -> { try { executeWorkflow(workflowSequence, workflowId, new ActionListener<>() { @Override public void onResponse(WorkflowResponse workflowResponse) { - if (isResponseSent.get()) { - logger.info("Ignoring onResponse for workflowId: {} as timeout already occurred", workflowId); - return; - } - isResponseSent.set(true); - workflowFuture.onResponse(null); - listener.onResponse(new WorkflowResponse(workflowResponse.getWorkflowId(), workflowResponse.getWorkflowState())); + WorkflowTimeoutUtility.handleResponse(workflowId, workflowResponse, isResponseSent, listener); } @Override public void onFailure(Exception e) { - if (isResponseSent.get()) { - logger.info("Ignoring onFailure for workflowId: {} as timeout already occurred", workflowId); - return; - } - isResponseSent.set(true); - workflowFuture.onFailure( - new FlowFrameworkException("Failed to execute workflow " + workflowId, ExceptionsHelper.status(e)) - ); - listener.onFailure( - new FlowFrameworkException("Failed to execute workflow " + workflowId, ExceptionsHelper.status(e)) - ); + WorkflowTimeoutUtility.handleFailure(workflowId, e, isResponseSent, listener); } }, true); } catch (Exception ex) { - if (!isResponseSent.get()) { - isResponseSent.set(true); - workflowFuture.onFailure( - new FlowFrameworkException("Failed to execute workflow " + workflowId, ExceptionsHelper.status(ex)) - ); - listener.onFailure(new FlowFrameworkException("Failed to execute workflow " + workflowId, ExceptionsHelper.status(ex))); - } + WorkflowTimeoutUtility.handleFailure(workflowId, ex, isResponseSent, listener); } }, threadPool.executor(PROVISION_WORKFLOW_THREAD_POOL)); - // Schedule timeout handler - scheduleTimeoutHandler(workflowId, listener, timeout, isResponseSent); - } - - /** - * Schedules a timeout handler for workflow execution. - * This method starts a new task in the thread pool to wait for the specified timeout duration. - * If the workflow does not complete within the given timeout, it triggers a follow-up action - * to fetch the workflow's state and notify the listener. - * - * @param workflowId The unique identifier of the workflow being executed. - * @param listener The ActionListener to notify with the workflow's response or failure. - * @param timeout The maximum time (in milliseconds) to wait for the workflow to complete before timing out. - * @param isResponseSent An AtomicBoolean flag to ensure the response is sent only once. - */ - private void scheduleTimeoutHandler( - String workflowId, - ActionListener listener, - long timeout, - AtomicBoolean isResponseSent - ) { - threadPool.executor(PROVISION_WORKFLOW_THREAD_POOL).execute(() -> { - try { - Thread.sleep(timeout); - if (isResponseSent.compareAndSet(false, true)) { - logger.warn("Workflow execution timed out for workflowId: {}", workflowId); - fetchWorkflowStateAfterTimeout(workflowId, listener); - } - } catch (InterruptedException e) { - Thread.currentThread().interrupt(); - } - }); - } - - /** - * Fetches the workflow state after a timeout has occurred. - * This method sends a request to retrieve the current state of the workflow - * and notifies the listener with the updated state or an error if the request fails. - * - * @param workflowId The unique identifier of the workflow whose state needs to be fetched. - * @param listener The ActionListener to notify with the workflow's updated state or failure. - */ - private void fetchWorkflowStateAfterTimeout(String workflowId, ActionListener listener) { - client.execute( - GetWorkflowStateAction.INSTANCE, - new GetWorkflowStateRequest(workflowId, false), - ActionListener.wrap( - response -> listener.onResponse(new WorkflowResponse(workflowId, response.getWorkflowState())), - exception -> listener.onFailure( - new FlowFrameworkException("Failed to get workflow state after timeout", ExceptionsHelper.status(exception)) - ) - ) - ); + WorkflowTimeoutUtility.scheduleTimeoutHandler(client, threadPool, workflowId, listener, timeout, isResponseSent); } /** diff --git a/src/main/java/org/opensearch/flowframework/transport/WorkflowRequest.java b/src/main/java/org/opensearch/flowframework/transport/WorkflowRequest.java index 9c480be65..709f1ca8a 100644 --- a/src/main/java/org/opensearch/flowframework/transport/WorkflowRequest.java +++ b/src/main/java/org/opensearch/flowframework/transport/WorkflowRequest.java @@ -76,7 +76,7 @@ public class WorkflowRequest extends ActionRequest { * @param template the use case template which describes the workflow */ public WorkflowRequest(@Nullable String workflowId, @Nullable Template template) { - this(workflowId, template, new String[] { "all" }, false, Collections.emptyMap(), false, null); + this(workflowId, template, new String[] { "all" }, false, Collections.emptyMap(), false, TimeValue.MINUS_ONE); } /** @@ -86,7 +86,7 @@ public WorkflowRequest(@Nullable String workflowId, @Nullable Template template) * @param params The parameters from the REST path */ public WorkflowRequest(@Nullable String workflowId, @Nullable Template template, Map params) { - this(workflowId, template, new String[] { "all" }, true, params, false, null); + this(workflowId, template, new String[] { "all" }, true, params, false, TimeValue.MINUS_ONE); } /** @@ -109,6 +109,26 @@ public WorkflowRequest( this(workflowId, template, new String[] { "all" }, true, params, false, waitForCompletionTimeout); } + /** + * Instantiates a new WorkflowRequest + * @param workflowId the documentId of the workflow + * @param template the use case template which describes the workflow + * @param validation flag to indicate if validation is necessary + * @param provisionOrUpdate provision or updateFields flag. Only one may be true, the presence of update_fields key in map indicates if updating fields, otherwise true means it's provisioning. + * @param params map of REST path params. If provisionOrUpdate is false, must be an empty map. If update_fields key is present, must be only key. + * @param reprovision flag to indicate if request is to reprovision + */ + public WorkflowRequest( + @Nullable String workflowId, + @Nullable Template template, + String[] validation, + boolean provisionOrUpdate, + Map params, + boolean reprovision + ) { + this(workflowId, template, validation, provisionOrUpdate, params, reprovision, TimeValue.MINUS_ONE); + } + /** * Instantiates a new WorkflowRequest * @param workflowId the documentId of the workflow diff --git a/src/main/java/org/opensearch/flowframework/transport/WorkflowResponse.java b/src/main/java/org/opensearch/flowframework/transport/WorkflowResponse.java index 8a9f21d93..1f8dd5681 100644 --- a/src/main/java/org/opensearch/flowframework/transport/WorkflowResponse.java +++ b/src/main/java/org/opensearch/flowframework/transport/WorkflowResponse.java @@ -8,6 +8,7 @@ */ package org.opensearch.flowframework.transport; +import org.opensearch.Version; import org.opensearch.core.action.ActionResponse; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; @@ -47,7 +48,10 @@ public WorkflowResponse(String workflowId) { public WorkflowResponse(StreamInput in) throws IOException { super(in); this.workflowId = in.readString(); - this.workflowState = in.readOptionalWriteable(WorkflowState::new); + // todo : change version to 2_19_0 + if (in.getVersion().onOrAfter(Version.CURRENT)) { + this.workflowState = in.readOptionalWriteable(WorkflowState::new); + } } @@ -89,7 +93,10 @@ public WorkflowResponse(String workflowId, WorkflowState workflowState) { @Override public void writeTo(StreamOutput out) throws IOException { out.writeString(workflowId); - out.writeOptionalWriteable(workflowState); + // todo : change version to 2_19_0 + if (out.getVersion().onOrAfter(Version.CURRENT)) { + out.writeOptionalWriteable(workflowState); + } } @Override diff --git a/src/main/java/org/opensearch/flowframework/util/WorkflowTimeoutUtility.java b/src/main/java/org/opensearch/flowframework/util/WorkflowTimeoutUtility.java new file mode 100644 index 000000000..4991e0491 --- /dev/null +++ b/src/main/java/org/opensearch/flowframework/util/WorkflowTimeoutUtility.java @@ -0,0 +1,194 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.flowframework.util; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.ExceptionsHelper; +import org.opensearch.client.Client; +import org.opensearch.common.unit.TimeValue; +import org.opensearch.core.action.ActionListener; +import org.opensearch.flowframework.exception.FlowFrameworkException; +import org.opensearch.flowframework.transport.GetWorkflowStateAction; +import org.opensearch.flowframework.transport.GetWorkflowStateRequest; +import org.opensearch.flowframework.transport.WorkflowResponse; +import org.opensearch.threadpool.Scheduler; +import org.opensearch.threadpool.ThreadPool; + +import java.util.concurrent.atomic.AtomicBoolean; + +/** + * Utility class for managing timeout tasks in workflow execution. + * This class provides methods to schedule timeout handlers, wrap listeners with timeout cancellation logic, + * and fetch workflow states after timeouts. + */ +public class WorkflowTimeoutUtility { + + private static final Logger logger = LogManager.getLogger(WorkflowTimeoutUtility.class); + private static final TimeValue MAX_TIMEOUT_MILLIS = TimeValue.timeValueSeconds(300); + + /** + * Schedules a timeout task for a workflow execution. + * + * @param client The OpenSearch client used to interact with the cluster. + * @param threadPool The thread pool to schedule the timeout task. + * @param workflowId The unique identifier of the workflow being executed. + * @param listener The listener to notify when the task completes or times out. + * @param timeout The timeout duration in milliseconds. + * @param isResponseSent An atomic boolean to ensure the response is sent only once. + * @return A wrapped ActionListener with timeout cancellation logic. + */ + public static ActionListener scheduleTimeoutHandler( + Client client, + ThreadPool threadPool, + final String workflowId, + ActionListener listener, + long timeout, + AtomicBoolean isResponseSent + ) { + long adjustedTimeout = Math.min(timeout, MAX_TIMEOUT_MILLIS.millis()); + Scheduler.ScheduledCancellable scheduledCancellable = threadPool.schedule( + new WorkflowTimeoutListener(client, workflowId, listener, isResponseSent), + TimeValue.timeValueMillis(adjustedTimeout), + ThreadPool.Names.GENERIC + ); + + return wrapWithTimeoutCancellationListener(listener, scheduledCancellable, isResponseSent); + } + + /** + * A listener that handles timeout for a workflow execution. + */ + private static class WorkflowTimeoutListener implements Runnable { + private final Client client; + private final String workflowId; + private final ActionListener listener; + private final AtomicBoolean isResponseSent; + + WorkflowTimeoutListener(Client client, String workflowId, ActionListener listener, AtomicBoolean isResponseSent) { + this.client = client; + this.workflowId = workflowId; + this.listener = listener; + this.isResponseSent = isResponseSent; + } + + @Override + public void run() { + if (isResponseSent.compareAndSet(false, true)) { + logger.warn("Workflow execution timed out for workflowId: {}", workflowId); + fetchWorkflowStateAfterTimeout(client, workflowId, listener); + } + } + } + + /** + * Wraps a listener with a timeout cancellation listener to cancel the timeout task when the workflow completes. + * + * @param listener The original listener to wrap. + * @param scheduledCancellable The cancellable timeout task. + * @param isResponseSent An atomic boolean to ensure the response is sent only once. + * @param The type of the response expected by the listener. + * @return A wrapped ActionListener with timeout cancellation logic. + */ + public static ActionListener wrapWithTimeoutCancellationListener( + ActionListener listener, + Scheduler.ScheduledCancellable scheduledCancellable, + AtomicBoolean isResponseSent + ) { + return new ActionListener<>() { + @Override + public void onResponse(Response response) { + if (isResponseSent.compareAndSet(false, true)) { + scheduledCancellable.cancel(); + } + listener.onResponse(response); + } + + @Override + public void onFailure(Exception e) { + if (isResponseSent.compareAndSet(false, true)) { + scheduledCancellable.cancel(); + } + listener.onFailure(e); + } + }; + } + + /** + * Handles the successful completion of a workflow. + * + * @param workflowId The unique identifier of the workflow. + * @param workflowResponse The response from the workflow execution. + * @param isResponseSent An atomic boolean to ensure the response is sent only once. + * @param listener The listener to notify of the workflow completion. + */ + public static void handleResponse( + String workflowId, + WorkflowResponse workflowResponse, + AtomicBoolean isResponseSent, + ActionListener listener + ) { + if (isResponseSent.compareAndSet(false, true)) { + listener.onResponse(new WorkflowResponse(workflowResponse.getWorkflowId(), workflowResponse.getWorkflowState())); + } else { + logger.info("Ignoring onResponse for workflowId: {} as timeout already occurred", workflowId); + } + } + + /** + * Handles the failure of a workflow execution. + * + * @param workflowId The unique identifier of the workflow. + * @param e The exception that occurred during workflow execution. + * @param isResponseSent An atomic boolean to ensure the response is sent only once. + * @param listener The listener to notify of the workflow failure. + */ + public static void handleFailure( + String workflowId, + Exception e, + AtomicBoolean isResponseSent, + ActionListener listener + ) { + if (isResponseSent.compareAndSet(false, true)) { + FlowFrameworkException exception = new FlowFrameworkException( + "Failed to execute workflow " + workflowId, + ExceptionsHelper.status(e) + ); + listener.onFailure(exception); + } else { + logger.info("Ignoring onFailure for workflowId: {} as timeout already occurred", workflowId); + } + } + + /** + * Fetches the workflow state after a timeout has occurred. + * This method sends a request to retrieve the current state of the workflow + * and notifies the listener with the updated state or an error if the request fails. + * + * @param client The OpenSearch client used to fetch the workflow state. + * @param workflowId The unique identifier of the workflow. + * @param listener The listener to notify with the updated state or failure. + */ + public static void fetchWorkflowStateAfterTimeout( + final Client client, + final String workflowId, + final ActionListener listener + ) { + client.execute( + GetWorkflowStateAction.INSTANCE, + new GetWorkflowStateRequest(workflowId, false), + ActionListener.wrap( + response -> listener.onResponse(new WorkflowResponse(workflowId, response.getWorkflowState())), + exception -> listener.onFailure( + new FlowFrameworkException("Failed to get workflow state after timeout", ExceptionsHelper.status(exception)) + ) + ) + ); + } +} diff --git a/src/test/java/org/opensearch/flowframework/rest/RestCreateWorkflowActionTests.java b/src/test/java/org/opensearch/flowframework/rest/RestCreateWorkflowActionTests.java index fd31cb823..063881d26 100644 --- a/src/test/java/org/opensearch/flowframework/rest/RestCreateWorkflowActionTests.java +++ b/src/test/java/org/opensearch/flowframework/rest/RestCreateWorkflowActionTests.java @@ -130,7 +130,7 @@ public void testCreateWorkflowRequestWithParamsAndProvision() throws Exception { public void testRestCreateWorkflow_withWaitForCompletionTimeout() throws Exception { RestRequest request = new FakeRestRequest.Builder(xContentRegistry()).withMethod(RestRequest.Method.POST) - .withParams(Map.of("wait_for_completion_timeout", "5s")) + .withParams(Map.ofEntries(Map.entry(PROVISION_WORKFLOW, "true"), Map.entry("wait_for_completion_timeout", "5s"))) .withContent(new BytesArray(validTemplate), MediaTypeRegistry.JSON) .build(); @@ -148,20 +148,6 @@ public void testRestCreateWorkflow_withWaitForCompletionTimeout() throws Excepti assertTrue(channel.capturedResponse().content().utf8ToString().contains("workflow_1")); } - public void testInvalidValueForCompletionTimeout() throws Exception { - RestRequest request = new FakeRestRequest.Builder(xContentRegistry()).withMethod(RestRequest.Method.POST) - .withParams(Map.of("wait_for_completion_timeout", "invalid_value")) - .withContent(new BytesArray(validTemplate), MediaTypeRegistry.JSON) - .build(); - - FakeRestChannel channel = new FakeRestChannel(request, false, 1); - - IllegalArgumentException exception = expectThrows(IllegalArgumentException.class, () -> { - createWorkflowRestAction.handleRequest(request, channel, nodeClient); - }); - - assertTrue(exception.getMessage().contains("failed to parse setting [wait_for_completion_timeout] with value [invalid_value]")); - } public void testCreateWorkflowRequestWithParamsButNoProvision() throws Exception { RestRequest request = new FakeRestRequest.Builder(xContentRegistry()).withMethod(RestRequest.Method.POST) diff --git a/src/test/java/org/opensearch/flowframework/transport/CreateWorkflowTransportActionTests.java b/src/test/java/org/opensearch/flowframework/transport/CreateWorkflowTransportActionTests.java index 86a950175..a38688757 100644 --- a/src/test/java/org/opensearch/flowframework/transport/CreateWorkflowTransportActionTests.java +++ b/src/test/java/org/opensearch/flowframework/transport/CreateWorkflowTransportActionTests.java @@ -64,6 +64,7 @@ import static org.opensearch.action.DocWriteResponse.Result.UPDATED; import static org.opensearch.flowframework.common.CommonValue.GLOBAL_CONTEXT_INDEX; import static org.opensearch.flowframework.common.CommonValue.UPDATE_WORKFLOW_FIELDS; +import static org.opensearch.flowframework.common.CommonValue.WAIT_FOR_COMPLETION_TIMEOUT; import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_STATE_INDEX; import static org.opensearch.flowframework.common.WorkflowResources.CONNECTOR_ID; import static org.opensearch.flowframework.common.WorkflowResources.CREATE_CONNECTOR; @@ -960,9 +961,83 @@ public void testCreateWorkflow_withValidation_withWaitForCompletion_withProvisio validTemplate, new String[] { "all" }, true, - Collections.emptyMap(), - false, - TimeValue.timeValueSeconds(5) + Map.of(WAIT_FOR_COMPLETION_TIMEOUT, "5s"), + false + ); + + // Bypass checkMaxWorkflows and force onResponse + doAnswer(invocation -> { + ActionListener checkMaxWorkflowListener = invocation.getArgument(2); + checkMaxWorkflowListener.onResponse(true); + return null; + }).when(createWorkflowTransportAction).checkMaxWorkflows(any(TimeValue.class), anyInt(), any()); + + // Bypass initializeConfigIndex and force onResponse + doAnswer(invocation -> { + ActionListener initalizeMasterKeyIndexListener = invocation.getArgument(0); + initalizeMasterKeyIndexListener.onResponse(true); + return null; + }).when(flowFrameworkIndicesHandler).initializeConfigIndex(any()); + + // Bypass putTemplateToGlobalContext and force onResponse + doAnswer(invocation -> { + ActionListener responseListener = invocation.getArgument(1); + responseListener.onResponse(new IndexResponse(new ShardId(GLOBAL_CONTEXT_INDEX, "", 1), "1", 1L, 1L, 1L, true)); + return null; + }).when(flowFrameworkIndicesHandler).putTemplateToGlobalContext(any(), any()); + + // Bypass putInitialStateToWorkflowState and force on response + doAnswer(invocation -> { + ActionListener responseListener = invocation.getArgument(2); + responseListener.onResponse(new IndexResponse(new ShardId(WORKFLOW_STATE_INDEX, "", 1), "1", 1L, 1L, 1L, true)); + return null; + }).when(flowFrameworkIndicesHandler).putInitialStateToWorkflowState(any(), any(), any()); + + doAnswer(invocation -> { + ActionListener responseListener = invocation.getArgument(2); + WorkflowResponse response = mock(WorkflowResponse.class); + when(response.getWorkflowId()).thenReturn("1"); + when(response.getWorkflowState()).thenReturn( + new WorkflowState( + "1", + "test", + "PROVISIONING", + "IN_PROGRESS", + Instant.now(), + Instant.now(), + TestHelpers.randomUser(), + Collections.emptyMap(), + Collections.emptyList() + ) + ); + responseListener.onResponse(response); + return null; + }).when(client).execute(eq(ProvisionWorkflowAction.INSTANCE), any(WorkflowRequest.class), any(ActionListener.class)); + + ArgumentCaptor workflowResponseCaptor = ArgumentCaptor.forClass(WorkflowResponse.class); + + createWorkflowTransportAction.doExecute(mock(Task.class), workflowRequest, listener); + + verify(listener, times(1)).onResponse(workflowResponseCaptor.capture()); + assertEquals("1", workflowResponseCaptor.getValue().getWorkflowId()); + assertEquals("PROVISIONING", workflowResponseCaptor.getValue().getWorkflowState().getState()); + } + + public void testCreateWorkflow_withValidation_withWaitForCompletionTimeSetZero_withProvision_Success() throws Exception { + + Template validTemplate = generateValidTemplate(); + + @SuppressWarnings("unchecked") + ActionListener listener = mock(ActionListener.class); + + doNothing().when(workflowProcessSorter).validate(any(), any()); + WorkflowRequest workflowRequest = new WorkflowRequest( + null, + validTemplate, + new String[] { "all" }, + true, + Map.of(WAIT_FOR_COMPLETION_TIMEOUT, "0s"), + false ); // Bypass checkMaxWorkflows and force onResponse diff --git a/src/test/java/org/opensearch/flowframework/util/WorkflowTimeoutUtilityTests.java b/src/test/java/org/opensearch/flowframework/util/WorkflowTimeoutUtilityTests.java new file mode 100644 index 000000000..4cbf8a61d --- /dev/null +++ b/src/test/java/org/opensearch/flowframework/util/WorkflowTimeoutUtilityTests.java @@ -0,0 +1,135 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.flowframework.util; + +import org.opensearch.client.Client; +import org.opensearch.common.unit.TimeValue; +import org.opensearch.core.action.ActionListener; +import org.opensearch.flowframework.TestHelpers; +import org.opensearch.flowframework.model.WorkflowState; +import org.opensearch.flowframework.transport.GetWorkflowStateAction; +import org.opensearch.flowframework.transport.GetWorkflowStateRequest; +import org.opensearch.flowframework.transport.WorkflowResponse; +import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.threadpool.Scheduler; +import org.opensearch.threadpool.ThreadPool; + +import java.time.Instant; +import java.util.Collections; +import java.util.concurrent.atomic.AtomicBoolean; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +public class WorkflowTimeoutUtilityTests extends OpenSearchTestCase { + + private Client mockClient; + private ThreadPool mockThreadPool; + private Scheduler.ScheduledCancellable mockScheduledCancellable; + private AtomicBoolean isResponseSent; + private ActionListener mockListener; + + @Override + public void setUp() throws Exception { + super.setUp(); + mockClient = mock(Client.class); + mockThreadPool = mock(ThreadPool.class); + mockScheduledCancellable = mock(Scheduler.ScheduledCancellable.class); + isResponseSent = new AtomicBoolean(false); + mockListener = mock(ActionListener.class); + + when(mockThreadPool.schedule(any(Runnable.class), any(TimeValue.class), anyString())).thenReturn(mockScheduledCancellable); + } + + public void testScheduleTimeoutHandler() { + String workflowId = "testWorkflowId"; + long timeout = 1000L; + + ActionListener returnedListener = WorkflowTimeoutUtility.scheduleTimeoutHandler( + mockClient, + mockThreadPool, + workflowId, + mockListener, + timeout, + isResponseSent + ); + + assertNotNull(returnedListener); + verify(mockThreadPool, times(1)).schedule( + any(Runnable.class), + eq(TimeValue.timeValueMillis(timeout)), + eq(ThreadPool.Names.GENERIC) + ); + } + + public void testWrapWithTimeoutCancellationListener_OnResponse() { + WorkflowResponse response = new WorkflowResponse( + "testWorkflowId", + new WorkflowState( + "1", + "test", + "PROVISIONING", + "IN_PROGRESS", + Instant.now(), + Instant.now(), + TestHelpers.randomUser(), + Collections.emptyMap(), + Collections.emptyList() + ) + ); + Scheduler.ScheduledCancellable scheduledCancellable = mock(Scheduler.ScheduledCancellable.class); + + ActionListener wrappedListener = WorkflowTimeoutUtility.wrapWithTimeoutCancellationListener( + mockListener, + scheduledCancellable, + isResponseSent + ); + + wrappedListener.onResponse(response); + + assertTrue(isResponseSent.get()); + verify(scheduledCancellable, times(1)).cancel(); + verify(mockListener, times(1)).onResponse(response); + } + + public void testWrapWithTimeoutCancellationListener_OnFailure() { + Exception exception = new Exception("Test exception"); + Scheduler.ScheduledCancellable scheduledCancellable = mock(Scheduler.ScheduledCancellable.class); + + ActionListener wrappedListener = WorkflowTimeoutUtility.wrapWithTimeoutCancellationListener( + mockListener, + scheduledCancellable, + isResponseSent + ); + + wrappedListener.onFailure(exception); + + assertTrue(isResponseSent.get()); + verify(scheduledCancellable, times(1)).cancel(); + verify(mockListener, times(1)).onFailure(exception); + } + + public void testFetchWorkflowStateAfterTimeout() { + String workflowId = "testWorkflowId"; + ActionListener mockListener = mock(ActionListener.class); + + WorkflowTimeoutUtility.fetchWorkflowStateAfterTimeout(mockClient, workflowId, mockListener); + + verify(mockClient, times(1)).execute( + eq(GetWorkflowStateAction.INSTANCE), + any(GetWorkflowStateRequest.class), + any(ActionListener.class) + ); + } +} From e690bd3fefa9b07f96ca4171f3953d39167b99e3 Mon Sep 17 00:00:00 2001 From: Junwei Dai Date: Mon, 13 Jan 2025 10:39:34 -0800 Subject: [PATCH 05/16] fix spotless check Signed-off-by: Junwei Dai --- .../flowframework/rest/RestCreateWorkflowActionTests.java | 1 - 1 file changed, 1 deletion(-) diff --git a/src/test/java/org/opensearch/flowframework/rest/RestCreateWorkflowActionTests.java b/src/test/java/org/opensearch/flowframework/rest/RestCreateWorkflowActionTests.java index 063881d26..baf8852a8 100644 --- a/src/test/java/org/opensearch/flowframework/rest/RestCreateWorkflowActionTests.java +++ b/src/test/java/org/opensearch/flowframework/rest/RestCreateWorkflowActionTests.java @@ -148,7 +148,6 @@ public void testRestCreateWorkflow_withWaitForCompletionTimeout() throws Excepti assertTrue(channel.capturedResponse().content().utf8ToString().contains("workflow_1")); } - public void testCreateWorkflowRequestWithParamsButNoProvision() throws Exception { RestRequest request = new FakeRestRequest.Builder(xContentRegistry()).withMethod(RestRequest.Method.POST) .withPath(this.createWorkflowPath) From 014e6c149fb38560f5361e3a6bdc3b3ad806154a Mon Sep 17 00:00:00 2001 From: Junwei Dai Date: Mon, 13 Jan 2025 11:08:53 -0800 Subject: [PATCH 06/16] Limit workflow timeout to a range of 1 to 300 seconds Signed-off-by: Junwei Dai --- .../opensearch/flowframework/util/WorkflowTimeoutUtility.java | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/main/java/org/opensearch/flowframework/util/WorkflowTimeoutUtility.java b/src/main/java/org/opensearch/flowframework/util/WorkflowTimeoutUtility.java index 4991e0491..e5956e921 100644 --- a/src/main/java/org/opensearch/flowframework/util/WorkflowTimeoutUtility.java +++ b/src/main/java/org/opensearch/flowframework/util/WorkflowTimeoutUtility.java @@ -18,6 +18,7 @@ import org.opensearch.flowframework.transport.GetWorkflowStateAction; import org.opensearch.flowframework.transport.GetWorkflowStateRequest; import org.opensearch.flowframework.transport.WorkflowResponse; +import org.opensearch.search.aggregations.metrics.Min; import org.opensearch.threadpool.Scheduler; import org.opensearch.threadpool.ThreadPool; @@ -32,7 +33,7 @@ public class WorkflowTimeoutUtility { private static final Logger logger = LogManager.getLogger(WorkflowTimeoutUtility.class); private static final TimeValue MAX_TIMEOUT_MILLIS = TimeValue.timeValueSeconds(300); - + private static final TimeValue MIN_TIMEOUT_MILLIS = TimeValue.timeValueSeconds(1); /** * Schedules a timeout task for a workflow execution. * @@ -53,6 +54,7 @@ public static ActionListener scheduleTimeoutHandler( AtomicBoolean isResponseSent ) { long adjustedTimeout = Math.min(timeout, MAX_TIMEOUT_MILLIS.millis()); + adjustedTimeout = Math.max(adjustedTimeout, MIN_TIMEOUT_MILLIS.millis()); Scheduler.ScheduledCancellable scheduledCancellable = threadPool.schedule( new WorkflowTimeoutListener(client, workflowId, listener, isResponseSent), TimeValue.timeValueMillis(adjustedTimeout), From ce70128cc17a1a087fd8e5ae2967b05a434d7b52 Mon Sep 17 00:00:00 2001 From: Junwei Dai Date: Mon, 13 Jan 2025 11:08:53 -0800 Subject: [PATCH 07/16] Limit workflow timeout to a range of 1 to 300 seconds Signed-off-by: Junwei Dai # Conflicts: # src/main/java/org/opensearch/flowframework/util/WorkflowTimeoutUtility.java --- .../opensearch/flowframework/util/WorkflowTimeoutUtility.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/main/java/org/opensearch/flowframework/util/WorkflowTimeoutUtility.java b/src/main/java/org/opensearch/flowframework/util/WorkflowTimeoutUtility.java index e5956e921..7ecad2e11 100644 --- a/src/main/java/org/opensearch/flowframework/util/WorkflowTimeoutUtility.java +++ b/src/main/java/org/opensearch/flowframework/util/WorkflowTimeoutUtility.java @@ -18,7 +18,6 @@ import org.opensearch.flowframework.transport.GetWorkflowStateAction; import org.opensearch.flowframework.transport.GetWorkflowStateRequest; import org.opensearch.flowframework.transport.WorkflowResponse; -import org.opensearch.search.aggregations.metrics.Min; import org.opensearch.threadpool.Scheduler; import org.opensearch.threadpool.ThreadPool; @@ -34,6 +33,7 @@ public class WorkflowTimeoutUtility { private static final Logger logger = LogManager.getLogger(WorkflowTimeoutUtility.class); private static final TimeValue MAX_TIMEOUT_MILLIS = TimeValue.timeValueSeconds(300); private static final TimeValue MIN_TIMEOUT_MILLIS = TimeValue.timeValueSeconds(1); + /** * Schedules a timeout task for a workflow execution. * From b9060b4e79c0f656ca4ea18d4dfc8d0c95054725 Mon Sep 17 00:00:00 2001 From: Junwei Dai Date: Mon, 13 Jan 2025 13:17:02 -0800 Subject: [PATCH 08/16] Limit workflow timeout to non-negative Signed-off-by: Junwei Dai --- .../flowframework/util/WorkflowTimeoutUtility.java | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/src/main/java/org/opensearch/flowframework/util/WorkflowTimeoutUtility.java b/src/main/java/org/opensearch/flowframework/util/WorkflowTimeoutUtility.java index 7ecad2e11..341de98da 100644 --- a/src/main/java/org/opensearch/flowframework/util/WorkflowTimeoutUtility.java +++ b/src/main/java/org/opensearch/flowframework/util/WorkflowTimeoutUtility.java @@ -31,8 +31,7 @@ public class WorkflowTimeoutUtility { private static final Logger logger = LogManager.getLogger(WorkflowTimeoutUtility.class); - private static final TimeValue MAX_TIMEOUT_MILLIS = TimeValue.timeValueSeconds(300); - private static final TimeValue MIN_TIMEOUT_MILLIS = TimeValue.timeValueSeconds(1); + private static final TimeValue MIN_TIMEOUT_MILLIS = TimeValue.timeValueSeconds(0); /** * Schedules a timeout task for a workflow execution. @@ -53,8 +52,8 @@ public static ActionListener scheduleTimeoutHandler( long timeout, AtomicBoolean isResponseSent ) { - long adjustedTimeout = Math.min(timeout, MAX_TIMEOUT_MILLIS.millis()); - adjustedTimeout = Math.max(adjustedTimeout, MIN_TIMEOUT_MILLIS.millis()); + // Ensure timeout is within the valid range (non-negative) + long adjustedTimeout = Math.max(timeout, MIN_TIMEOUT_MILLIS.millis()); Scheduler.ScheduledCancellable scheduledCancellable = threadPool.schedule( new WorkflowTimeoutListener(client, workflowId, listener, isResponseSent), TimeValue.timeValueMillis(adjustedTimeout), From edaafe80d1d8cf455ee857119f55a6755d6d9755 Mon Sep 17 00:00:00 2001 From: Junwei Dai Date: Tue, 14 Jan 2025 13:05:42 -0800 Subject: [PATCH 09/16] Add synchronous execution to reprovision Signed-off-by: Junwei Dai --- .../flowframework/common/CommonValue.java | 2 ++ .../CreateWorkflowTransportAction.java | 24 +++++++------- .../transport/ReprovisionWorkflowRequest.java | 32 ++++++++++++++++++- .../transport/WorkflowRequest.java | 12 +++++-- .../ReprovisionWorkflowRequestTests.java | 3 +- ...provisionWorkflowTransportActionTests.java | 11 ++++--- 6 files changed, 62 insertions(+), 22 deletions(-) diff --git a/src/main/java/org/opensearch/flowframework/common/CommonValue.java b/src/main/java/org/opensearch/flowframework/common/CommonValue.java index 2fe46996b..e99e689df 100644 --- a/src/main/java/org/opensearch/flowframework/common/CommonValue.java +++ b/src/main/java/org/opensearch/flowframework/common/CommonValue.java @@ -192,6 +192,8 @@ private CommonValue() {} public static final String SOURCE_INDEX = "source_index"; /** The destination index field for reindex */ public static final String DESTINATION_INDEX = "destination_index"; + /** Provision Timeout field */ + public static final String PROVISION_TIMEOUT_FIELD = "provision.timeout"; /* * Constants associated with resource provisioning / state */ diff --git a/src/main/java/org/opensearch/flowframework/transport/CreateWorkflowTransportAction.java b/src/main/java/org/opensearch/flowframework/transport/CreateWorkflowTransportAction.java index 44c8fab1e..d629d5aee 100644 --- a/src/main/java/org/opensearch/flowframework/transport/CreateWorkflowTransportAction.java +++ b/src/main/java/org/opensearch/flowframework/transport/CreateWorkflowTransportAction.java @@ -52,6 +52,7 @@ import static java.lang.Boolean.FALSE; import static org.opensearch.flowframework.common.CommonValue.GLOBAL_CONTEXT_INDEX; import static org.opensearch.flowframework.common.CommonValue.PROVISIONING_PROGRESS_FIELD; +import static org.opensearch.flowframework.common.CommonValue.PROVISION_TIMEOUT_FIELD; import static org.opensearch.flowframework.common.CommonValue.STATE_FIELD; import static org.opensearch.flowframework.common.CommonValue.WAIT_FOR_COMPLETION_TIMEOUT; import static org.opensearch.flowframework.common.FlowFrameworkSettings.FILTER_BY_BACKEND_ROLES; @@ -255,8 +256,7 @@ private void createExecute(WorkflowRequest request, User user, ActionListener { - if (workflowRequest.getWaitForCompletionTimeout() == TimeValue.MINUS_ONE) { - listener.onResponse( - new WorkflowResponse(provisionResponse.getWorkflowId()) - ); - } else { - listener.onResponse( - new WorkflowResponse( + listener.onResponse( + (workflowRequest.getWaitForCompletionTimeout() == TimeValue.MINUS_ONE) + ? new WorkflowResponse(provisionResponse.getWorkflowId()) + : new WorkflowResponse( provisionResponse.getWorkflowId(), provisionResponse.getWorkflowState() ) - ); - } + ); }, exception -> { String errorMessage = "Provisioning failed."; logger.error(errorMessage, exception); @@ -362,12 +358,14 @@ private void createExecute(WorkflowRequest request, User user, ActionListener listener = mock(ActionListener.class); - ReprovisionWorkflowRequest request = new ReprovisionWorkflowRequest(workflowId, mockTemplate, mockTemplate); + ReprovisionWorkflowRequest request = new ReprovisionWorkflowRequest(workflowId, mockTemplate, mockTemplate, TimeValue.MINUS_ONE); reprovisionWorkflowTransportAction.doExecute(mock(Task.class), request, listener); ArgumentCaptor responseCaptor = ArgumentCaptor.forClass(WorkflowResponse.class); @@ -189,7 +190,7 @@ public void testReprovisionProvisioningWorkflow() throws Exception { @SuppressWarnings("unchecked") ActionListener listener = mock(ActionListener.class); - ReprovisionWorkflowRequest request = new ReprovisionWorkflowRequest(workflowId, mockTemplate, mockTemplate); + ReprovisionWorkflowRequest request = new ReprovisionWorkflowRequest(workflowId, mockTemplate, mockTemplate, TimeValue.MINUS_ONE); reprovisionWorkflowTransportAction.doExecute(mock(Task.class), request, listener); ArgumentCaptor exceptionCaptor = ArgumentCaptor.forClass(Exception.class); @@ -229,7 +230,7 @@ public void testReprovisionNotStartedWorkflow() throws Exception { @SuppressWarnings("unchecked") ActionListener listener = mock(ActionListener.class); - ReprovisionWorkflowRequest request = new ReprovisionWorkflowRequest(workflowId, mockTemplate, mockTemplate); + ReprovisionWorkflowRequest request = new ReprovisionWorkflowRequest(workflowId, mockTemplate, mockTemplate, TimeValue.MINUS_ONE); reprovisionWorkflowTransportAction.doExecute(mock(Task.class), request, listener); ArgumentCaptor exceptionCaptor = ArgumentCaptor.forClass(Exception.class); @@ -280,7 +281,7 @@ public void testFailedStateUpdate() throws Exception { @SuppressWarnings("unchecked") ActionListener listener = mock(ActionListener.class); - ReprovisionWorkflowRequest request = new ReprovisionWorkflowRequest(workflowId, mockTemplate, mockTemplate); + ReprovisionWorkflowRequest request = new ReprovisionWorkflowRequest(workflowId, mockTemplate, mockTemplate, TimeValue.MINUS_ONE); reprovisionWorkflowTransportAction.doExecute(mock(Task.class), request, listener); ArgumentCaptor exceptionCaptor = ArgumentCaptor.forClass(Exception.class); @@ -312,7 +313,7 @@ public void testFailedWorkflowStateRetrieval() throws Exception { @SuppressWarnings("unchecked") ActionListener listener = mock(ActionListener.class); - ReprovisionWorkflowRequest request = new ReprovisionWorkflowRequest(workflowId, mockTemplate, mockTemplate); + ReprovisionWorkflowRequest request = new ReprovisionWorkflowRequest(workflowId, mockTemplate, mockTemplate, TimeValue.MINUS_ONE); reprovisionWorkflowTransportAction.doExecute(mock(Task.class), request, listener); ArgumentCaptor exceptionCaptor = ArgumentCaptor.forClass(Exception.class); From 258432956c4f2da3dae51c5495fb2ed30c650ed5 Mon Sep 17 00:00:00 2001 From: Junwei Dai Date: Tue, 14 Jan 2025 15:57:41 -0800 Subject: [PATCH 10/16] remove unsued common value Signed-off-by: Junwei Dai --- .../java/org/opensearch/flowframework/common/CommonValue.java | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/main/java/org/opensearch/flowframework/common/CommonValue.java b/src/main/java/org/opensearch/flowframework/common/CommonValue.java index e99e689df..73163ecdf 100644 --- a/src/main/java/org/opensearch/flowframework/common/CommonValue.java +++ b/src/main/java/org/opensearch/flowframework/common/CommonValue.java @@ -57,8 +57,6 @@ private CommonValue() {} /** The last provisioned time field */ public static final String LAST_PROVISIONED_TIME_FIELD = "last_provisioned_time"; - public static final TimeValue DEFAULT_WAIT_FOR_COMPLETION_TIMEOUT = TimeValue.timeValueSeconds(1); - /* * Constants associated with Rest or Transport actions */ From 15e052b36043f64b1973dc9216860e8fc6201c2f Mon Sep 17 00:00:00 2001 From: Junwei Dai Date: Tue, 14 Jan 2025 18:43:48 -0800 Subject: [PATCH 11/16] add reprovision sync execution Signed-off-by: Junwei Dai # Conflicts: # src/test/java/org/opensearch/flowframework/workflow/DeleteConnectorStepTests.java --- .../flowframework/common/CommonValue.java | 2 - .../rest/RestCreateWorkflowAction.java | 18 ++++- .../CreateWorkflowTransportAction.java | 9 ++- .../ReprovisionWorkflowTransportAction.java | 81 +++++++++++++++++-- .../transport/WorkflowRequest.java | 4 +- .../transport/WorkflowResponse.java | 2 + .../util/WorkflowTimeoutUtility.java | 5 +- .../rest/RestCreateWorkflowActionTests.java | 20 ++++- 8 files changed, 128 insertions(+), 13 deletions(-) diff --git a/src/main/java/org/opensearch/flowframework/common/CommonValue.java b/src/main/java/org/opensearch/flowframework/common/CommonValue.java index 73163ecdf..0a4af0758 100644 --- a/src/main/java/org/opensearch/flowframework/common/CommonValue.java +++ b/src/main/java/org/opensearch/flowframework/common/CommonValue.java @@ -8,8 +8,6 @@ */ package org.opensearch.flowframework.common; -import org.opensearch.common.unit.TimeValue; - /** * Representation of common values that are used across project */ diff --git a/src/main/java/org/opensearch/flowframework/rest/RestCreateWorkflowAction.java b/src/main/java/org/opensearch/flowframework/rest/RestCreateWorkflowAction.java index 4abedc365..3de1d1980 100644 --- a/src/main/java/org/opensearch/flowframework/rest/RestCreateWorkflowAction.java +++ b/src/main/java/org/opensearch/flowframework/rest/RestCreateWorkflowAction.java @@ -12,6 +12,7 @@ import org.apache.logging.log4j.Logger; import org.opensearch.ExceptionsHelper; import org.opensearch.client.node.NodeClient; +import org.opensearch.common.unit.TimeValue; import org.opensearch.core.action.ActionListener; import org.opensearch.core.rest.RestStatus; import org.opensearch.core.xcontent.ToXContent; @@ -43,6 +44,7 @@ import static org.opensearch.flowframework.common.CommonValue.UPDATE_WORKFLOW_FIELDS; import static org.opensearch.flowframework.common.CommonValue.USE_CASE; import static org.opensearch.flowframework.common.CommonValue.VALIDATION; +import static org.opensearch.flowframework.common.CommonValue.WAIT_FOR_COMPLETION_TIMEOUT; import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_ID; import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_URI; import static org.opensearch.flowframework.common.FlowFrameworkSettings.FLOW_FRAMEWORK_ENABLED; @@ -88,6 +90,7 @@ protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient cli boolean reprovision = request.paramAsBoolean(REPROVISION_WORKFLOW, false); boolean updateFields = request.paramAsBoolean(UPDATE_WORKFLOW_FIELDS, false); String useCase = request.param(USE_CASE); + TimeValue waitForCompletionTimeout = request.paramAsTime(WAIT_FOR_COMPLETION_TIMEOUT, TimeValue.MINUS_ONE); // If provisioning, consume all other params and pass to provision transport action Map params = provision @@ -145,6 +148,17 @@ protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient cli ); return processError(ffe, params, request); } + // Ensure wait_for_completion is not set unless reprovision or provision is true + if (waitForCompletionTimeout != TimeValue.MINUS_ONE && !(reprovision || provision)) { + FlowFrameworkException ffe = new FlowFrameworkException( + "Request parameters " + + request.consumedParams() + + " are not allowed unless the 'provision' or 'reprovision' parameter is set to true.", + RestStatus.BAD_REQUEST + ); + return processError(ffe, params, request); + } + try { Template template; Map useCaseDefaultsMap = Collections.emptyMap(); @@ -219,7 +233,9 @@ protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient cli if (updateFields) { params = Map.of(UPDATE_WORKFLOW_FIELDS, "true"); } - + if (waitForCompletionTimeout != TimeValue.MINUS_ONE) { + params = Map.of(WAIT_FOR_COMPLETION_TIMEOUT, waitForCompletionTimeout.toString()); + } WorkflowRequest workflowRequest = new WorkflowRequest( workflowId, template, diff --git a/src/main/java/org/opensearch/flowframework/transport/CreateWorkflowTransportAction.java b/src/main/java/org/opensearch/flowframework/transport/CreateWorkflowTransportAction.java index d629d5aee..317ae0b26 100644 --- a/src/main/java/org/opensearch/flowframework/transport/CreateWorkflowTransportAction.java +++ b/src/main/java/org/opensearch/flowframework/transport/CreateWorkflowTransportAction.java @@ -372,7 +372,14 @@ private void createExecute(WorkflowRequest request, User user, ActionListener { - listener.onResponse(new WorkflowResponse(reprovisionResponse.getWorkflowId())); + listener.onResponse( + reprovisionRequest.getWaitForCompletionTimeout() == TimeValue.MINUS_ONE + ? new WorkflowResponse(reprovisionResponse.getWorkflowId()) + : new WorkflowResponse( + reprovisionResponse.getWorkflowId(), + reprovisionResponse.getWorkflowState() + ) + ); }, exception -> { String errorMessage = ParameterizedMessageFactory.INSTANCE.newMessage( "Reprovisioning failed for workflow {}", diff --git a/src/main/java/org/opensearch/flowframework/transport/ReprovisionWorkflowTransportAction.java b/src/main/java/org/opensearch/flowframework/transport/ReprovisionWorkflowTransportAction.java index 8e501228b..9c9681dc8 100644 --- a/src/main/java/org/opensearch/flowframework/transport/ReprovisionWorkflowTransportAction.java +++ b/src/main/java/org/opensearch/flowframework/transport/ReprovisionWorkflowTransportAction.java @@ -19,6 +19,7 @@ import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.inject.Inject; import org.opensearch.common.settings.Settings; +import org.opensearch.common.unit.TimeValue; import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.commons.authuser.User; import org.opensearch.core.action.ActionListener; @@ -34,6 +35,7 @@ import org.opensearch.flowframework.model.Workflow; import org.opensearch.flowframework.model.WorkflowState; import org.opensearch.flowframework.util.EncryptorUtils; +import org.opensearch.flowframework.util.WorkflowTimeoutUtility; import org.opensearch.flowframework.workflow.ProcessNode; import org.opensearch.flowframework.workflow.WorkflowProcessSorter; import org.opensearch.flowframework.workflow.WorkflowStepFactory; @@ -48,6 +50,8 @@ import java.util.List; import java.util.Locale; import java.util.Map; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.atomic.AtomicBoolean; import java.util.stream.Collectors; import static org.opensearch.flowframework.common.CommonValue.ERROR_FIELD; @@ -243,9 +247,23 @@ private void executeReprovisionRequest( Template updatedTemplateWithProvisionedTime = Template.builder(updatedTemplate) .lastProvisionedTime(Instant.now()) .build(); - executeWorkflowAsync(workflowId, updatedTemplateWithProvisionedTime, reprovisionProcessSequence, listener); - - listener.onResponse(new WorkflowResponse(workflowId)); + if (request.getWaitForCompletionTimeout() == TimeValue.MINUS_ONE) { + executeWorkflowAsync(workflowId, updatedTemplateWithProvisionedTime, reprovisionProcessSequence, listener); + } else { + executeWorkflowSync( + workflowId, + updatedTemplate, + reprovisionProcessSequence, + listener, + request.getWaitForCompletionTimeout().getMillis() + ); + } + + if (request.getWaitForCompletionTimeout() == TimeValue.MINUS_ONE) { + listener.onResponse(new WorkflowResponse(workflowId)); + } else { + logger.info("Waiting for workflow completion"); + } }, exception -> { String errorMessage = ParameterizedMessageFactory.INSTANCE.newMessage("Failed to update workflow state: {}", workflowId) @@ -284,13 +302,42 @@ private void executeWorkflowAsync( try { threadPool.executor(PROVISION_WORKFLOW_THREAD_POOL).execute(() -> { updateTemplate(template, workflowId); - executeWorkflow(template, workflowSequence, workflowId); + executeWorkflow(template, workflowSequence, workflowId, listener, false); }); } catch (Exception exception) { listener.onFailure(new FlowFrameworkException("Failed to execute workflow " + workflowId, ExceptionsHelper.status(exception))); } } + private void executeWorkflowSync( + String workflowId, + Template template, + List workflowSequence, + ActionListener listener, + long timeout + ) { + AtomicBoolean isResponseSent = new AtomicBoolean(false); + CompletableFuture.runAsync(() -> { + try { + updateTemplate(template, workflowId); + executeWorkflow(template, workflowSequence, workflowId, new ActionListener<>() { + @Override + public void onResponse(WorkflowResponse workflowResponse) { + WorkflowTimeoutUtility.handleResponse(workflowId, workflowResponse, isResponseSent, listener); + } + + @Override + public void onFailure(Exception e) { + WorkflowTimeoutUtility.handleFailure(workflowId, e, isResponseSent, listener); + } + }, true); + } catch (Exception ex) { + WorkflowTimeoutUtility.handleFailure(workflowId, ex, isResponseSent, listener); + } + }, threadPool.executor(PROVISION_WORKFLOW_THREAD_POOL)); + WorkflowTimeoutUtility.scheduleTimeoutHandler(client, threadPool, workflowId, listener, timeout, isResponseSent); + } + /** * Replace template document * @param template The template to store after reprovisioning completes successfully @@ -310,7 +357,13 @@ private void updateTemplate(Template template, String workflowId) { * @param workflowSequence The topologically sorted workflow to execute * @param workflowId The workflowId associated with the workflow that is executing */ - private void executeWorkflow(Template template, List workflowSequence, String workflowId) { + private void executeWorkflow( + Template template, + List workflowSequence, + String workflowId, + ActionListener listener, + boolean isSyncExecution + ) { String currentStepId = ""; try { Map> workflowFutureMap = new LinkedHashMap<>(); @@ -349,7 +402,23 @@ private void executeWorkflow(Template template, List workflowSequen ActionListener.wrap(updateResponse -> { logger.info("updated workflow {} state to {}", workflowId, State.COMPLETED); - + if (isSyncExecution) { + client.execute( + GetWorkflowStateAction.INSTANCE, + new GetWorkflowStateRequest(workflowId, false), + ActionListener.wrap(response -> { + listener.onResponse(new WorkflowResponse(workflowId, response.getWorkflowState())); + }, exception -> { + String errorMessage = "Failed to get workflow state."; + logger.error(errorMessage, exception); + if (exception instanceof FlowFrameworkException) { + listener.onFailure(exception); + } else { + listener.onFailure(new FlowFrameworkException(errorMessage, ExceptionsHelper.status(exception))); + } + }) + ); + } }, exception -> { logger.error("Failed to update workflow state for workflow {}", workflowId, exception); }) ); } catch (Exception ex) { diff --git a/src/main/java/org/opensearch/flowframework/transport/WorkflowRequest.java b/src/main/java/org/opensearch/flowframework/transport/WorkflowRequest.java index 0e1d32cce..ec7b762e0 100644 --- a/src/main/java/org/opensearch/flowframework/transport/WorkflowRequest.java +++ b/src/main/java/org/opensearch/flowframework/transport/WorkflowRequest.java @@ -23,6 +23,7 @@ import static org.opensearch.flowframework.common.CommonValue.REPROVISION_WORKFLOW; import static org.opensearch.flowframework.common.CommonValue.UPDATE_WORKFLOW_FIELDS; +import static org.opensearch.flowframework.common.CommonValue.WAIT_FOR_COMPLETION_TIMEOUT; /** * Transport Request to create, provision, and deprovision a workflow @@ -154,7 +155,8 @@ public WorkflowRequest( this.validation = validation; this.provision = provisionOrUpdate && !params.containsKey(UPDATE_WORKFLOW_FIELDS); this.updateFields = !provision && Boolean.parseBoolean(params.get(UPDATE_WORKFLOW_FIELDS)); - if (!this.provision && params.keySet().stream().anyMatch(k -> !UPDATE_WORKFLOW_FIELDS.equals(k))) { + if (!this.provision + && params.keySet().stream().anyMatch(k -> !UPDATE_WORKFLOW_FIELDS.equals(k) && !WAIT_FOR_COMPLETION_TIMEOUT.equals(k))) { throw new IllegalArgumentException("Params may only be included when provisioning."); } this.params = this.updateFields ? Collections.emptyMap() : params; diff --git a/src/main/java/org/opensearch/flowframework/transport/WorkflowResponse.java b/src/main/java/org/opensearch/flowframework/transport/WorkflowResponse.java index 1f8dd5681..a75999d85 100644 --- a/src/main/java/org/opensearch/flowframework/transport/WorkflowResponse.java +++ b/src/main/java/org/opensearch/flowframework/transport/WorkflowResponse.java @@ -9,6 +9,7 @@ package org.opensearch.flowframework.transport; import org.opensearch.Version; +import org.opensearch.common.Nullable; import org.opensearch.core.action.ActionResponse; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; @@ -67,6 +68,7 @@ public String getWorkflowId() { * Gets the workflowState of this repsonse * @return the workflowState */ + @Nullable public WorkflowState getWorkflowState() { return this.workflowState; } diff --git a/src/main/java/org/opensearch/flowframework/util/WorkflowTimeoutUtility.java b/src/main/java/org/opensearch/flowframework/util/WorkflowTimeoutUtility.java index 341de98da..703da5af6 100644 --- a/src/main/java/org/opensearch/flowframework/util/WorkflowTimeoutUtility.java +++ b/src/main/java/org/opensearch/flowframework/util/WorkflowTimeoutUtility.java @@ -23,6 +23,8 @@ import java.util.concurrent.atomic.AtomicBoolean; +import static org.opensearch.flowframework.common.CommonValue.PROVISION_WORKFLOW_THREAD_POOL; + /** * Utility class for managing timeout tasks in workflow execution. * This class provides methods to schedule timeout handlers, wrap listeners with timeout cancellation logic, @@ -57,7 +59,7 @@ public static ActionListener scheduleTimeoutHandler( Scheduler.ScheduledCancellable scheduledCancellable = threadPool.schedule( new WorkflowTimeoutListener(client, workflowId, listener, isResponseSent), TimeValue.timeValueMillis(adjustedTimeout), - ThreadPool.Names.GENERIC + PROVISION_WORKFLOW_THREAD_POOL ); return wrapWithTimeoutCancellationListener(listener, scheduledCancellable, isResponseSent); @@ -181,6 +183,7 @@ public static void fetchWorkflowStateAfterTimeout( final String workflowId, final ActionListener listener ) { + logger.info("Fetching workflow state after timeout"); client.execute( GetWorkflowStateAction.INSTANCE, new GetWorkflowStateRequest(workflowId, false), diff --git a/src/test/java/org/opensearch/flowframework/rest/RestCreateWorkflowActionTests.java b/src/test/java/org/opensearch/flowframework/rest/RestCreateWorkflowActionTests.java index baf8852a8..747de4351 100644 --- a/src/test/java/org/opensearch/flowframework/rest/RestCreateWorkflowActionTests.java +++ b/src/test/java/org/opensearch/flowframework/rest/RestCreateWorkflowActionTests.java @@ -39,6 +39,7 @@ import static org.opensearch.flowframework.common.CommonValue.REPROVISION_WORKFLOW; import static org.opensearch.flowframework.common.CommonValue.UPDATE_WORKFLOW_FIELDS; import static org.opensearch.flowframework.common.CommonValue.USE_CASE; +import static org.opensearch.flowframework.common.CommonValue.WAIT_FOR_COMPLETION_TIMEOUT; import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_URI; import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.doAnswer; @@ -128,7 +129,7 @@ public void testCreateWorkflowRequestWithParamsAndProvision() throws Exception { assertTrue(channel.capturedResponse().content().utf8ToString().contains("id-123")); } - public void testRestCreateWorkflow_withWaitForCompletionTimeout() throws Exception { + public void testRestCreateWorkflowWithWaitForCompletionTimeout() throws Exception { RestRequest request = new FakeRestRequest.Builder(xContentRegistry()).withMethod(RestRequest.Method.POST) .withParams(Map.ofEntries(Map.entry(PROVISION_WORKFLOW, "true"), Map.entry("wait_for_completion_timeout", "5s"))) .withContent(new BytesArray(validTemplate), MediaTypeRegistry.JSON) @@ -162,6 +163,23 @@ public void testCreateWorkflowRequestWithParamsButNoProvision() throws Exception ); } + public void testCreateWorkflowRequestWithWaitForTimeCompletionTimeoutButNoProvision() throws Exception { + RestRequest request = new FakeRestRequest.Builder(xContentRegistry()).withMethod(RestRequest.Method.POST) + .withPath(this.createWorkflowPath) + .withParams(Map.of(WAIT_FOR_COMPLETION_TIMEOUT, "1s")) + .withContent(new BytesArray(validTemplate), MediaTypeRegistry.JSON) + .build(); + FakeRestChannel channel = new FakeRestChannel(request, false, 1); + createWorkflowRestAction.handleRequest(request, channel, nodeClient); + assertEquals(RestStatus.BAD_REQUEST, channel.capturedResponse().status()); + assertTrue( + channel.capturedResponse() + .content() + .utf8ToString() + .contains("are not allowed unless the 'provision' or 'reprovision' parameter is set to true.") + ); + } + public void testCreateWorkflowRequestWithUpdateAndProvision() throws Exception { RestRequest request = new FakeRestRequest.Builder(xContentRegistry()).withMethod(RestRequest.Method.POST) .withPath(this.createWorkflowPath) From 95816c9813f896ef3e1606d516173ae7e29752ed Mon Sep 17 00:00:00 2001 From: Junwei Dai Date: Wed, 15 Jan 2025 10:27:44 -0800 Subject: [PATCH 12/16] fix test for WorkflowTimeoutUtilityTests Signed-off-by: Junwei Dai --- .../flowframework/util/WorkflowTimeoutUtilityTests.java | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/test/java/org/opensearch/flowframework/util/WorkflowTimeoutUtilityTests.java b/src/test/java/org/opensearch/flowframework/util/WorkflowTimeoutUtilityTests.java index 4cbf8a61d..24ceb77ca 100644 --- a/src/test/java/org/opensearch/flowframework/util/WorkflowTimeoutUtilityTests.java +++ b/src/test/java/org/opensearch/flowframework/util/WorkflowTimeoutUtilityTests.java @@ -24,6 +24,7 @@ import java.util.Collections; import java.util.concurrent.atomic.AtomicBoolean; +import static org.opensearch.flowframework.common.CommonValue.PROVISION_WORKFLOW_THREAD_POOL; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.ArgumentMatchers.eq; @@ -69,7 +70,7 @@ public void testScheduleTimeoutHandler() { verify(mockThreadPool, times(1)).schedule( any(Runnable.class), eq(TimeValue.timeValueMillis(timeout)), - eq(ThreadPool.Names.GENERIC) + eq(PROVISION_WORKFLOW_THREAD_POOL) ); } From 18a1dbb84b9060498e9b4e2a74f824aff7288485 Mon Sep 17 00:00:00 2001 From: Junwei Dai Date: Wed, 15 Jan 2025 10:29:13 -0800 Subject: [PATCH 13/16] fix test name for WorkflowTimeoutUtilityTests Signed-off-by: Junwei Dai --- .../flowframework/util/WorkflowTimeoutUtilityTests.java | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/test/java/org/opensearch/flowframework/util/WorkflowTimeoutUtilityTests.java b/src/test/java/org/opensearch/flowframework/util/WorkflowTimeoutUtilityTests.java index 24ceb77ca..d7dcaedab 100644 --- a/src/test/java/org/opensearch/flowframework/util/WorkflowTimeoutUtilityTests.java +++ b/src/test/java/org/opensearch/flowframework/util/WorkflowTimeoutUtilityTests.java @@ -74,7 +74,7 @@ public void testScheduleTimeoutHandler() { ); } - public void testWrapWithTimeoutCancellationListener_OnResponse() { + public void testWrapWithTimeoutCancellationListenerOnResponse() { WorkflowResponse response = new WorkflowResponse( "testWorkflowId", new WorkflowState( @@ -104,7 +104,7 @@ public void testWrapWithTimeoutCancellationListener_OnResponse() { verify(mockListener, times(1)).onResponse(response); } - public void testWrapWithTimeoutCancellationListener_OnFailure() { + public void testWrapWithTimeoutCancellationListenerOnFailure() { Exception exception = new Exception("Test exception"); Scheduler.ScheduledCancellable scheduledCancellable = mock(Scheduler.ScheduledCancellable.class); From 6e8f1472167ec238b524948d4405f1f538595aec Mon Sep 17 00:00:00 2001 From: Junwei Dai Date: Wed, 15 Jan 2025 12:16:17 -0800 Subject: [PATCH 14/16] Add comments to explain AtomicBoolean usage in WorkflowTimeoutUtility, update error message Signed-off-by: Junwei Dai --- .../rest/RestCreateWorkflowAction.java | 4 +--- .../CreateWorkflowTransportAction.java | 24 ++++++++++++++----- .../util/WorkflowTimeoutUtility.java | 8 +++++-- 3 files changed, 25 insertions(+), 11 deletions(-) diff --git a/src/main/java/org/opensearch/flowframework/rest/RestCreateWorkflowAction.java b/src/main/java/org/opensearch/flowframework/rest/RestCreateWorkflowAction.java index 3de1d1980..b106b05f2 100644 --- a/src/main/java/org/opensearch/flowframework/rest/RestCreateWorkflowAction.java +++ b/src/main/java/org/opensearch/flowframework/rest/RestCreateWorkflowAction.java @@ -151,9 +151,7 @@ protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient cli // Ensure wait_for_completion is not set unless reprovision or provision is true if (waitForCompletionTimeout != TimeValue.MINUS_ONE && !(reprovision || provision)) { FlowFrameworkException ffe = new FlowFrameworkException( - "Request parameters " - + request.consumedParams() - + " are not allowed unless the 'provision' or 'reprovision' parameter is set to true.", + "Request parameters 'wait_for_completion_timeout' are not allowed unless the 'provision' or 'reprovision' parameter is set to true.", RestStatus.BAD_REQUEST ); return processError(ffe, params, request); diff --git a/src/main/java/org/opensearch/flowframework/transport/CreateWorkflowTransportAction.java b/src/main/java/org/opensearch/flowframework/transport/CreateWorkflowTransportAction.java index 317ae0b26..a7fa3b7d0 100644 --- a/src/main/java/org/opensearch/flowframework/transport/CreateWorkflowTransportAction.java +++ b/src/main/java/org/opensearch/flowframework/transport/CreateWorkflowTransportAction.java @@ -250,13 +250,19 @@ private void createExecute(WorkflowRequest request, User user, ActionListener { logger.info("Creating state workflow doc: {}", globalContextResponse.getId()); if (request.isProvision()) { - String waitForTimeCompletion = request.getParams() - .getOrDefault(WAIT_FOR_COMPLETION_TIMEOUT, TimeValue.MINUS_ONE.toString()); + // default to minus one indicate async execution + TimeValue waitForTimeCompletion = TimeValue.MINUS_ONE; + if (request.getParams().containsKey(WAIT_FOR_COMPLETION_TIMEOUT)) { + waitForTimeCompletion = TimeValue.parseTimeValue( + request.getParams().get(WAIT_FOR_COMPLETION_TIMEOUT), + WAIT_FOR_COMPLETION_TIMEOUT + ); + } WorkflowRequest workflowRequest = new WorkflowRequest( globalContextResponse.getId(), null, request.getParams(), - TimeValue.parseTimeValue(waitForTimeCompletion, PROVISION_TIMEOUT_FIELD) + waitForTimeCompletion ); logger.info( "Provisioning parameter is set, continuing to provision workflow {}", @@ -358,14 +364,20 @@ private void createExecute(WorkflowRequest request, User user, ActionListener scheduleTimeoutHandler( AtomicBoolean isResponseSent ) { // Ensure timeout is within the valid range (non-negative) - long adjustedTimeout = Math.max(timeout, MIN_TIMEOUT_MILLIS.millis()); + long adjustedTimeout = Math.max(timeout, TimeValue.timeValueMillis(0).millis()); Scheduler.ScheduledCancellable scheduledCancellable = threadPool.schedule( new WorkflowTimeoutListener(client, workflowId, listener, isResponseSent), TimeValue.timeValueMillis(adjustedTimeout), @@ -83,6 +82,7 @@ private static class WorkflowTimeoutListener implements Runnable { @Override public void run() { + // This AtomicBoolean ensures that the timeout logic is executed only once, preventing duplicate responses. if (isResponseSent.compareAndSet(false, true)) { logger.warn("Workflow execution timed out for workflowId: {}", workflowId); fetchWorkflowStateAfterTimeout(client, workflowId, listener); @@ -107,6 +107,7 @@ public static ActionListener wrapWithTimeoutCancellationLis return new ActionListener<>() { @Override public void onResponse(Response response) { + // Cancel the timeout task if the response is successfully sent. if (isResponseSent.compareAndSet(false, true)) { scheduledCancellable.cancel(); } @@ -115,6 +116,7 @@ public void onResponse(Response response) { @Override public void onFailure(Exception e) { + // Cancel the timeout task if an error occurs and the failure is reported. if (isResponseSent.compareAndSet(false, true)) { scheduledCancellable.cancel(); } @@ -137,6 +139,7 @@ public static void handleResponse( AtomicBoolean isResponseSent, ActionListener listener ) { + // Check if the response has already been sent, and send it only if it hasn't been sent yet. if (isResponseSent.compareAndSet(false, true)) { listener.onResponse(new WorkflowResponse(workflowResponse.getWorkflowId(), workflowResponse.getWorkflowState())); } else { @@ -158,6 +161,7 @@ public static void handleFailure( AtomicBoolean isResponseSent, ActionListener listener ) { + // Check if the failure has already been reported, and report it only if it hasn't been reported yet. if (isResponseSent.compareAndSet(false, true)) { FlowFrameworkException exception = new FlowFrameworkException( "Failed to execute workflow " + workflowId, From 257564a2e6a6052b40d14f8f758da851b4defc00 Mon Sep 17 00:00:00 2001 From: Junwei Dai Date: Wed, 15 Jan 2025 13:26:54 -0800 Subject: [PATCH 15/16] fix spotless check Signed-off-by: Junwei Dai --- .../flowframework/transport/CreateWorkflowTransportAction.java | 1 - 1 file changed, 1 deletion(-) diff --git a/src/main/java/org/opensearch/flowframework/transport/CreateWorkflowTransportAction.java b/src/main/java/org/opensearch/flowframework/transport/CreateWorkflowTransportAction.java index a7fa3b7d0..7a9c45480 100644 --- a/src/main/java/org/opensearch/flowframework/transport/CreateWorkflowTransportAction.java +++ b/src/main/java/org/opensearch/flowframework/transport/CreateWorkflowTransportAction.java @@ -52,7 +52,6 @@ import static java.lang.Boolean.FALSE; import static org.opensearch.flowframework.common.CommonValue.GLOBAL_CONTEXT_INDEX; import static org.opensearch.flowframework.common.CommonValue.PROVISIONING_PROGRESS_FIELD; -import static org.opensearch.flowframework.common.CommonValue.PROVISION_TIMEOUT_FIELD; import static org.opensearch.flowframework.common.CommonValue.STATE_FIELD; import static org.opensearch.flowframework.common.CommonValue.WAIT_FOR_COMPLETION_TIMEOUT; import static org.opensearch.flowframework.common.FlowFrameworkSettings.FILTER_BY_BACKEND_ROLES; From eac9355479dacfc44acd43bc3d81df0302c85eb8 Mon Sep 17 00:00:00 2001 From: Junwei Dai Date: Wed, 15 Jan 2025 15:14:03 -0800 Subject: [PATCH 16/16] addressed some comments Signed-off-by: Junwei Dai --- .../CreateWorkflowTransportAction.java | 26 +++++++------------ .../transport/ReprovisionWorkflowRequest.java | 6 ++--- .../transport/WorkflowRequest.java | 12 ++++----- .../transport/WorkflowResponse.java | 6 ++--- 4 files changed, 19 insertions(+), 31 deletions(-) diff --git a/src/main/java/org/opensearch/flowframework/transport/CreateWorkflowTransportAction.java b/src/main/java/org/opensearch/flowframework/transport/CreateWorkflowTransportAction.java index 7a9c45480..ff1df88b0 100644 --- a/src/main/java/org/opensearch/flowframework/transport/CreateWorkflowTransportAction.java +++ b/src/main/java/org/opensearch/flowframework/transport/CreateWorkflowTransportAction.java @@ -215,6 +215,16 @@ private void createExecute(WorkflowRequest request, User user, ActionListener { logger.info("Creating state workflow doc: {}", globalContextResponse.getId()); if (request.isProvision()) { - // default to minus one indicate async execution - TimeValue waitForTimeCompletion = TimeValue.MINUS_ONE; - if (request.getParams().containsKey(WAIT_FOR_COMPLETION_TIMEOUT)) { - waitForTimeCompletion = TimeValue.parseTimeValue( - request.getParams().get(WAIT_FOR_COMPLETION_TIMEOUT), - WAIT_FOR_COMPLETION_TIMEOUT - ); - } WorkflowRequest workflowRequest = new WorkflowRequest( globalContextResponse.getId(), null, @@ -363,14 +365,6 @@ private void createExecute(WorkflowRequest request, User user, ActionListener