From 1c49942f4279e4bc61bf377a5fbc45fdf55c0918 Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" Date: Fri, 29 Dec 2023 21:26:47 +0000 Subject: [PATCH] Support deploy=true on RegisterRemoteModelStep (#340) * Support deploy=true on RegisterRemoteModelStep Signed-off-by: Daniel Widdis * Hardcode function_name to remote Signed-off-by: Daniel Widdis --------- Signed-off-by: Daniel Widdis (cherry picked from commit a45b38c0fb76cb9450c219e9babd7a6966b97c0c) Signed-off-by: github-actions[bot] --- .../flowframework/common/CommonValue.java | 2 + .../workflow/RegisterRemoteModelStep.java | 132 +++++++++++------- .../resources/mappings/workflow-steps.json | 1 - .../RegisterRemoteModelStepTests.java | 56 +++++++- 4 files changed, 135 insertions(+), 56 deletions(-) diff --git a/src/main/java/org/opensearch/flowframework/common/CommonValue.java b/src/main/java/org/opensearch/flowframework/common/CommonValue.java index f6bc8459c..b4955de15 100644 --- a/src/main/java/org/opensearch/flowframework/common/CommonValue.java +++ b/src/main/java/org/opensearch/flowframework/common/CommonValue.java @@ -106,6 +106,8 @@ private CommonValue() {} public static final String MODEL_GROUP_STATUS = "model_group_status"; /** Description field */ public static final String DESCRIPTION_FIELD = "description"; + /** Description field */ + public static final String DEPLOY_FIELD = "deploy"; /** Model format field */ public static final String MODEL_FORMAT = "model_format"; /** Model content hash value field */ diff --git a/src/main/java/org/opensearch/flowframework/workflow/RegisterRemoteModelStep.java b/src/main/java/org/opensearch/flowframework/workflow/RegisterRemoteModelStep.java index f3989287c..4ce4eed78 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/RegisterRemoteModelStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/RegisterRemoteModelStep.java @@ -11,6 +11,7 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.opensearch.ExceptionsHelper; +import org.opensearch.action.update.UpdateResponse; import org.opensearch.core.action.ActionListener; import org.opensearch.flowframework.exception.FlowFrameworkException; import org.opensearch.flowframework.indices.FlowFrameworkIndicesHandler; @@ -21,13 +22,12 @@ import org.opensearch.ml.common.transport.register.MLRegisterModelInput.MLRegisterModelInputBuilder; import org.opensearch.ml.common.transport.register.MLRegisterModelResponse; -import java.util.Locale; import java.util.Map; import java.util.Set; import java.util.concurrent.CompletableFuture; +import static org.opensearch.flowframework.common.CommonValue.DEPLOY_FIELD; import static org.opensearch.flowframework.common.CommonValue.DESCRIPTION_FIELD; -import static org.opensearch.flowframework.common.CommonValue.FUNCTION_NAME; import static org.opensearch.flowframework.common.CommonValue.NAME_FIELD; import static org.opensearch.flowframework.common.CommonValue.REGISTER_MODEL_STATUS; import static org.opensearch.flowframework.common.WorkflowResources.CONNECTOR_ID; @@ -68,53 +68,8 @@ public CompletableFuture execute( CompletableFuture registerRemoteModelFuture = new CompletableFuture<>(); - ActionListener actionListener = new ActionListener<>() { - @Override - public void onResponse(MLRegisterModelResponse mlRegisterModelResponse) { - - try { - logger.info("Remote Model registration successful"); - String resourceName = getResourceByWorkflowStep(getName()); - flowFrameworkIndicesHandler.updateResourceInStateIndex( - currentNodeInputs.getWorkflowId(), - currentNodeId, - getName(), - mlRegisterModelResponse.getModelId(), - ActionListener.wrap(response -> { - logger.info("successfully updated resources created in state index: {}", response.getIndex()); - registerRemoteModelFuture.complete( - new WorkflowData( - Map.ofEntries( - Map.entry(resourceName, mlRegisterModelResponse.getModelId()), - Map.entry(REGISTER_MODEL_STATUS, mlRegisterModelResponse.getStatus()) - ), - currentNodeInputs.getWorkflowId(), - currentNodeInputs.getNodeId() - ) - ); - }, exception -> { - logger.error("Failed to update new created resource", exception); - registerRemoteModelFuture.completeExceptionally( - new FlowFrameworkException(exception.getMessage(), ExceptionsHelper.status(exception)) - ); - }) - ); - - } catch (Exception e) { - logger.error("Failed to parse and update new created resource", e); - registerRemoteModelFuture.completeExceptionally(new FlowFrameworkException(e.getMessage(), ExceptionsHelper.status(e))); - } - } - - @Override - public void onFailure(Exception e) { - logger.error("Failed to register remote model"); - registerRemoteModelFuture.completeExceptionally(new FlowFrameworkException(e.getMessage(), ExceptionsHelper.status(e))); - } - }; - - Set requiredKeys = Set.of(NAME_FIELD, FUNCTION_NAME, CONNECTOR_ID); - Set optionalKeys = Set.of(MODEL_GROUP_ID, DESCRIPTION_FIELD); + Set requiredKeys = Set.of(NAME_FIELD, CONNECTOR_ID); + Set optionalKeys = Set.of(MODEL_GROUP_ID, DESCRIPTION_FIELD, DEPLOY_FIELD); try { Map inputs = ParseUtils.getInputsFromPreviousSteps( @@ -126,13 +81,13 @@ public void onFailure(Exception e) { ); String modelName = (String) inputs.get(NAME_FIELD); - FunctionName functionName = FunctionName.from(((String) inputs.get(FUNCTION_NAME)).toUpperCase(Locale.ROOT)); String modelGroupId = (String) inputs.get(MODEL_GROUP_ID); String description = (String) inputs.get(DESCRIPTION_FIELD); String connectorId = (String) inputs.get(CONNECTOR_ID); + final Boolean deploy = (Boolean) inputs.get(DEPLOY_FIELD); MLRegisterModelInputBuilder builder = MLRegisterModelInput.builder() - .functionName(functionName) + .functionName(FunctionName.REMOTE) .modelName(modelName) .connectorId(connectorId); @@ -142,9 +97,82 @@ public void onFailure(Exception e) { if (description != null) { builder.description(description); } + if (deploy != null) { + builder.deployModel(deploy); + } MLRegisterModelInput mlInput = builder.build(); - mlClient.register(mlInput, actionListener); + mlClient.register(mlInput, new ActionListener() { + @Override + public void onResponse(MLRegisterModelResponse mlRegisterModelResponse) { + + try { + logger.info("Remote Model registration successful"); + String resourceName = getResourceByWorkflowStep(getName()); + flowFrameworkIndicesHandler.updateResourceInStateIndex( + currentNodeInputs.getWorkflowId(), + currentNodeId, + getName(), + mlRegisterModelResponse.getModelId(), + ActionListener.wrap(response -> { + // If we deployed, simulate the deploy step has been called + if (Boolean.TRUE.equals(deploy)) { + flowFrameworkIndicesHandler.updateResourceInStateIndex( + currentNodeInputs.getWorkflowId(), + currentNodeId, + DeployModelStep.NAME, + mlRegisterModelResponse.getModelId(), + ActionListener.wrap(deployUpdateResponse -> { + completeRegisterFuture(deployUpdateResponse, resourceName, mlRegisterModelResponse); + }, deployUpdateException -> { + logger.error("Failed to update simulated deploy step resource", deployUpdateException); + registerRemoteModelFuture.completeExceptionally( + new FlowFrameworkException( + deployUpdateException.getMessage(), + ExceptionsHelper.status(deployUpdateException) + ) + ); + }) + ); + } else { + completeRegisterFuture(response, resourceName, mlRegisterModelResponse); + } + }, exception -> { + logger.error("Failed to update new created resource", exception); + registerRemoteModelFuture.completeExceptionally( + new FlowFrameworkException(exception.getMessage(), ExceptionsHelper.status(exception)) + ); + }) + ); + + } catch (Exception e) { + logger.error("Failed to parse and update new created resource", e); + registerRemoteModelFuture.completeExceptionally( + new FlowFrameworkException(e.getMessage(), ExceptionsHelper.status(e)) + ); + } + } + + void completeRegisterFuture(UpdateResponse response, String resourceName, MLRegisterModelResponse mlRegisterModelResponse) { + logger.info("successfully updated resources created in state index: {}", response.getIndex()); + registerRemoteModelFuture.complete( + new WorkflowData( + Map.ofEntries( + Map.entry(resourceName, mlRegisterModelResponse.getModelId()), + Map.entry(REGISTER_MODEL_STATUS, mlRegisterModelResponse.getStatus()) + ), + currentNodeInputs.getWorkflowId(), + currentNodeInputs.getNodeId() + ) + ); + } + + @Override + public void onFailure(Exception e) { + logger.error("Failed to register remote model"); + registerRemoteModelFuture.completeExceptionally(new FlowFrameworkException(e.getMessage(), ExceptionsHelper.status(e))); + } + }); } catch (FlowFrameworkException e) { registerRemoteModelFuture.completeExceptionally(e); diff --git a/src/main/resources/mappings/workflow-steps.json b/src/main/resources/mappings/workflow-steps.json index 24c740473..6431b2aa6 100644 --- a/src/main/resources/mappings/workflow-steps.json +++ b/src/main/resources/mappings/workflow-steps.json @@ -79,7 +79,6 @@ "register_remote_model": { "inputs": [ "name", - "function_name", "connector_id" ], "outputs": [ diff --git a/src/test/java/org/opensearch/flowframework/workflow/RegisterRemoteModelStepTests.java b/src/test/java/org/opensearch/flowframework/workflow/RegisterRemoteModelStepTests.java index a44a8290d..2bc57f888 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/RegisterRemoteModelStepTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/RegisterRemoteModelStepTests.java @@ -30,6 +30,7 @@ import org.mockito.MockitoAnnotations; import static org.opensearch.action.DocWriteResponse.Result.UPDATED; +import static org.opensearch.flowframework.common.CommonValue.DEPLOY_FIELD; import static org.opensearch.flowframework.common.CommonValue.REGISTER_MODEL_STATUS; import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_STATE_INDEX; import static org.opensearch.flowframework.common.WorkflowResources.CONNECTOR_ID; @@ -59,7 +60,7 @@ public void setUp() throws Exception { this.registerRemoteModelStep = new RegisterRemoteModelStep(mlNodeClient, flowFrameworkIndicesHandler); this.workflowData = new WorkflowData( Map.ofEntries( - Map.entry("function_name", "remote"), + Map.entry("function_name", "ignored"), Map.entry("name", "xyz"), Map.entry("description", "description"), Map.entry(CONNECTOR_ID, "abcdefg") @@ -96,14 +97,63 @@ public void testRegisterRemoteModelSuccess() throws Exception { ); verify(mlNodeClient, times(1)).register(any(MLRegisterModelInput.class), any()); + // only updates register resource + verify(flowFrameworkIndicesHandler, times(1)).updateResourceInStateIndex(anyString(), anyString(), anyString(), anyString(), any()); assertTrue(future.isDone()); - assertTrue(!future.isCompletedExceptionally()); + assertFalse(future.isCompletedExceptionally()); assertEquals(modelId, future.get().getContent().get(MODEL_ID)); assertEquals(status, future.get().getContent().get(REGISTER_MODEL_STATUS)); } + public void testRegisterAndDeployRemoteModelSuccess() throws Exception { + + String taskId = "abcd"; + String modelId = "efgh"; + String status = MLTaskState.CREATED.name(); + + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(1); + MLRegisterModelResponse output = new MLRegisterModelResponse(taskId, status, modelId); + actionListener.onResponse(output); + return null; + }).when(mlNodeClient).register(any(MLRegisterModelInput.class), any()); + + doAnswer(invocation -> { + ActionListener updateResponseListener = invocation.getArgument(4); + updateResponseListener.onResponse(new UpdateResponse(new ShardId(WORKFLOW_STATE_INDEX, "", 1), "id", -2, 0, 0, UPDATED)); + return null; + }).when(flowFrameworkIndicesHandler).updateResourceInStateIndex(anyString(), anyString(), anyString(), anyString(), any()); + + WorkflowData deployWorkflowData = new WorkflowData( + Map.ofEntries( + Map.entry("name", "xyz"), + Map.entry("description", "description"), + Map.entry(CONNECTOR_ID, "abcdefg"), + Map.entry(DEPLOY_FIELD, true) + ), + "test-id", + "test-node-id" + ); + + CompletableFuture future = this.registerRemoteModelStep.execute( + deployWorkflowData.getNodeId(), + deployWorkflowData, + Collections.emptyMap(), + Collections.emptyMap() + ); + + verify(mlNodeClient, times(1)).register(any(MLRegisterModelInput.class), any()); + // updates both register and deploy resources + verify(flowFrameworkIndicesHandler, times(2)).updateResourceInStateIndex(anyString(), anyString(), anyString(), anyString(), any()); + + assertTrue(future.isDone()); + assertFalse(future.isCompletedExceptionally()); + assertEquals(modelId, future.get().getContent().get(MODEL_ID)); + assertEquals(status, future.get().getContent().get(REGISTER_MODEL_STATUS)); + } + public void testRegisterRemoteModelFailure() { doAnswer(invocation -> { ActionListener actionListener = invocation.getArgument(1); @@ -137,7 +187,7 @@ public void testMissingInputs() { ExecutionException ex = assertThrows(ExecutionException.class, () -> future.get().getContent()); assertTrue(ex.getCause() instanceof FlowFrameworkException); assertTrue(ex.getCause().getMessage().startsWith("Missing required inputs [")); - for (String s : new String[] { "name", "function_name", CONNECTOR_ID }) { + for (String s : new String[] { "name", CONNECTOR_ID }) { assertTrue(ex.getCause().getMessage().contains(s)); } assertTrue(ex.getCause().getMessage().endsWith("] in workflow [test-id] node [test-node-id]"));