Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support deploy=true on RegisterRemoteModelStep #340

Merged
merged 2 commits into from
Dec 29, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,8 @@ private CommonValue() {}
public static final String MODEL_GROUP_STATUS = "model_group_status";
/** Description field */
public static final String DESCRIPTION_FIELD = "description";
/** Description field */
public static final String DEPLOY_FIELD = "deploy";
/** Model format field */
public static final String MODEL_FORMAT = "model_format";
/** Model content hash value field */
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.ExceptionsHelper;
import org.opensearch.action.update.UpdateResponse;
import org.opensearch.core.action.ActionListener;
import org.opensearch.flowframework.exception.FlowFrameworkException;
import org.opensearch.flowframework.indices.FlowFrameworkIndicesHandler;
Expand All @@ -21,13 +22,12 @@
import org.opensearch.ml.common.transport.register.MLRegisterModelInput.MLRegisterModelInputBuilder;
import org.opensearch.ml.common.transport.register.MLRegisterModelResponse;

import java.util.Locale;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.CompletableFuture;

import static org.opensearch.flowframework.common.CommonValue.DEPLOY_FIELD;
import static org.opensearch.flowframework.common.CommonValue.DESCRIPTION_FIELD;
import static org.opensearch.flowframework.common.CommonValue.FUNCTION_NAME;
import static org.opensearch.flowframework.common.CommonValue.NAME_FIELD;
import static org.opensearch.flowframework.common.CommonValue.REGISTER_MODEL_STATUS;
import static org.opensearch.flowframework.common.WorkflowResources.CONNECTOR_ID;
Expand Down Expand Up @@ -68,53 +68,8 @@

CompletableFuture<WorkflowData> registerRemoteModelFuture = new CompletableFuture<>();

ActionListener<MLRegisterModelResponse> actionListener = new ActionListener<>() {
@Override
public void onResponse(MLRegisterModelResponse mlRegisterModelResponse) {

try {
logger.info("Remote Model registration successful");
String resourceName = getResourceByWorkflowStep(getName());
flowFrameworkIndicesHandler.updateResourceInStateIndex(
currentNodeInputs.getWorkflowId(),
currentNodeId,
getName(),
mlRegisterModelResponse.getModelId(),
ActionListener.wrap(response -> {
logger.info("successfully updated resources created in state index: {}", response.getIndex());
registerRemoteModelFuture.complete(
new WorkflowData(
Map.ofEntries(
Map.entry(resourceName, mlRegisterModelResponse.getModelId()),
Map.entry(REGISTER_MODEL_STATUS, mlRegisterModelResponse.getStatus())
),
currentNodeInputs.getWorkflowId(),
currentNodeInputs.getNodeId()
)
);
}, exception -> {
logger.error("Failed to update new created resource", exception);
registerRemoteModelFuture.completeExceptionally(
new FlowFrameworkException(exception.getMessage(), ExceptionsHelper.status(exception))
);
})
);

} catch (Exception e) {
logger.error("Failed to parse and update new created resource", e);
registerRemoteModelFuture.completeExceptionally(new FlowFrameworkException(e.getMessage(), ExceptionsHelper.status(e)));
}
}

@Override
public void onFailure(Exception e) {
logger.error("Failed to register remote model");
registerRemoteModelFuture.completeExceptionally(new FlowFrameworkException(e.getMessage(), ExceptionsHelper.status(e)));
}
};

Set<String> requiredKeys = Set.of(NAME_FIELD, FUNCTION_NAME, CONNECTOR_ID);
Set<String> optionalKeys = Set.of(MODEL_GROUP_ID, DESCRIPTION_FIELD);
Set<String> requiredKeys = Set.of(NAME_FIELD, CONNECTOR_ID);
Set<String> optionalKeys = Set.of(MODEL_GROUP_ID, DESCRIPTION_FIELD, DEPLOY_FIELD);

try {
Map<String, Object> inputs = ParseUtils.getInputsFromPreviousSteps(
Expand All @@ -126,13 +81,13 @@
);

String modelName = (String) inputs.get(NAME_FIELD);
FunctionName functionName = FunctionName.from(((String) inputs.get(FUNCTION_NAME)).toUpperCase(Locale.ROOT));
String modelGroupId = (String) inputs.get(MODEL_GROUP_ID);
String description = (String) inputs.get(DESCRIPTION_FIELD);
String connectorId = (String) inputs.get(CONNECTOR_ID);
final Boolean deploy = (Boolean) inputs.get(DEPLOY_FIELD);

MLRegisterModelInputBuilder builder = MLRegisterModelInput.builder()
.functionName(functionName)
.functionName(FunctionName.REMOTE)
.modelName(modelName)
.connectorId(connectorId);

Expand All @@ -142,9 +97,82 @@
if (description != null) {
builder.description(description);
}
if (deploy != null) {
builder.deployModel(deploy);
}
MLRegisterModelInput mlInput = builder.build();

mlClient.register(mlInput, actionListener);
mlClient.register(mlInput, new ActionListener<MLRegisterModelResponse>() {
@Override
public void onResponse(MLRegisterModelResponse mlRegisterModelResponse) {

try {
logger.info("Remote Model registration successful");
String resourceName = getResourceByWorkflowStep(getName());
flowFrameworkIndicesHandler.updateResourceInStateIndex(
currentNodeInputs.getWorkflowId(),
currentNodeId,
getName(),
mlRegisterModelResponse.getModelId(),
ActionListener.wrap(response -> {
// If we deployed, simulate the deploy step has been called
if (Boolean.TRUE.equals(deploy)) {
flowFrameworkIndicesHandler.updateResourceInStateIndex(
currentNodeInputs.getWorkflowId(),
currentNodeId,
DeployModelStep.NAME,
mlRegisterModelResponse.getModelId(),
ActionListener.wrap(deployUpdateResponse -> {
completeRegisterFuture(deployUpdateResponse, resourceName, mlRegisterModelResponse);
}, deployUpdateException -> {
logger.error("Failed to update simulated deploy step resource", deployUpdateException);
registerRemoteModelFuture.completeExceptionally(

Check warning on line 129 in src/main/java/org/opensearch/flowframework/workflow/RegisterRemoteModelStep.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/flowframework/workflow/RegisterRemoteModelStep.java#L128-L129

Added lines #L128 - L129 were not covered by tests
new FlowFrameworkException(
deployUpdateException.getMessage(),
ExceptionsHelper.status(deployUpdateException)

Check warning on line 132 in src/main/java/org/opensearch/flowframework/workflow/RegisterRemoteModelStep.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/flowframework/workflow/RegisterRemoteModelStep.java#L131-L132

Added lines #L131 - L132 were not covered by tests
)
);
})

Check warning on line 135 in src/main/java/org/opensearch/flowframework/workflow/RegisterRemoteModelStep.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/flowframework/workflow/RegisterRemoteModelStep.java#L135

Added line #L135 was not covered by tests
);
} else {
completeRegisterFuture(response, resourceName, mlRegisterModelResponse);
}
}, exception -> {
logger.error("Failed to update new created resource", exception);
registerRemoteModelFuture.completeExceptionally(
new FlowFrameworkException(exception.getMessage(), ExceptionsHelper.status(exception))

Check warning on line 143 in src/main/java/org/opensearch/flowframework/workflow/RegisterRemoteModelStep.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/flowframework/workflow/RegisterRemoteModelStep.java#L141-L143

Added lines #L141 - L143 were not covered by tests
);
})

Check warning on line 145 in src/main/java/org/opensearch/flowframework/workflow/RegisterRemoteModelStep.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/flowframework/workflow/RegisterRemoteModelStep.java#L145

Added line #L145 was not covered by tests
);

} catch (Exception e) {
logger.error("Failed to parse and update new created resource", e);
registerRemoteModelFuture.completeExceptionally(
new FlowFrameworkException(e.getMessage(), ExceptionsHelper.status(e))

Check warning on line 151 in src/main/java/org/opensearch/flowframework/workflow/RegisterRemoteModelStep.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/flowframework/workflow/RegisterRemoteModelStep.java#L148-L151

Added lines #L148 - L151 were not covered by tests
);
}
}

void completeRegisterFuture(UpdateResponse response, String resourceName, MLRegisterModelResponse mlRegisterModelResponse) {
logger.info("successfully updated resources created in state index: {}", response.getIndex());
registerRemoteModelFuture.complete(
new WorkflowData(
Map.ofEntries(
Map.entry(resourceName, mlRegisterModelResponse.getModelId()),
Map.entry(REGISTER_MODEL_STATUS, mlRegisterModelResponse.getStatus())
),
currentNodeInputs.getWorkflowId(),
currentNodeInputs.getNodeId()
)
);
}

@Override
public void onFailure(Exception e) {
logger.error("Failed to register remote model");
registerRemoteModelFuture.completeExceptionally(new FlowFrameworkException(e.getMessage(), ExceptionsHelper.status(e)));
}
});

} catch (FlowFrameworkException e) {
registerRemoteModelFuture.completeExceptionally(e);
Expand Down
1 change: 0 additions & 1 deletion src/main/resources/mappings/workflow-steps.json
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,6 @@
"register_remote_model": {
"inputs": [
"name",
"function_name",
"connector_id"
],
"outputs": [
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
import org.mockito.MockitoAnnotations;

import static org.opensearch.action.DocWriteResponse.Result.UPDATED;
import static org.opensearch.flowframework.common.CommonValue.DEPLOY_FIELD;
import static org.opensearch.flowframework.common.CommonValue.REGISTER_MODEL_STATUS;
import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_STATE_INDEX;
import static org.opensearch.flowframework.common.WorkflowResources.CONNECTOR_ID;
Expand Down Expand Up @@ -59,7 +60,7 @@ public void setUp() throws Exception {
this.registerRemoteModelStep = new RegisterRemoteModelStep(mlNodeClient, flowFrameworkIndicesHandler);
this.workflowData = new WorkflowData(
Map.ofEntries(
Map.entry("function_name", "remote"),
Map.entry("function_name", "ignored"),
Map.entry("name", "xyz"),
Map.entry("description", "description"),
Map.entry(CONNECTOR_ID, "abcdefg")
Expand Down Expand Up @@ -96,14 +97,63 @@ public void testRegisterRemoteModelSuccess() throws Exception {
);

verify(mlNodeClient, times(1)).register(any(MLRegisterModelInput.class), any());
// only updates register resource
verify(flowFrameworkIndicesHandler, times(1)).updateResourceInStateIndex(anyString(), anyString(), anyString(), anyString(), any());

assertTrue(future.isDone());
assertTrue(!future.isCompletedExceptionally());
assertFalse(future.isCompletedExceptionally());
assertEquals(modelId, future.get().getContent().get(MODEL_ID));
assertEquals(status, future.get().getContent().get(REGISTER_MODEL_STATUS));

}

public void testRegisterAndDeployRemoteModelSuccess() throws Exception {

String taskId = "abcd";
String modelId = "efgh";
String status = MLTaskState.CREATED.name();

doAnswer(invocation -> {
ActionListener<MLRegisterModelResponse> actionListener = invocation.getArgument(1);
MLRegisterModelResponse output = new MLRegisterModelResponse(taskId, status, modelId);
actionListener.onResponse(output);
return null;
}).when(mlNodeClient).register(any(MLRegisterModelInput.class), any());

doAnswer(invocation -> {
ActionListener<UpdateResponse> updateResponseListener = invocation.getArgument(4);
updateResponseListener.onResponse(new UpdateResponse(new ShardId(WORKFLOW_STATE_INDEX, "", 1), "id", -2, 0, 0, UPDATED));
return null;
}).when(flowFrameworkIndicesHandler).updateResourceInStateIndex(anyString(), anyString(), anyString(), anyString(), any());

WorkflowData deployWorkflowData = new WorkflowData(
Map.ofEntries(
Map.entry("name", "xyz"),
Map.entry("description", "description"),
Map.entry(CONNECTOR_ID, "abcdefg"),
Map.entry(DEPLOY_FIELD, true)
),
"test-id",
"test-node-id"
);

CompletableFuture<WorkflowData> future = this.registerRemoteModelStep.execute(
deployWorkflowData.getNodeId(),
deployWorkflowData,
Collections.emptyMap(),
Collections.emptyMap()
);

verify(mlNodeClient, times(1)).register(any(MLRegisterModelInput.class), any());
// updates both register and deploy resources
verify(flowFrameworkIndicesHandler, times(2)).updateResourceInStateIndex(anyString(), anyString(), anyString(), anyString(), any());

assertTrue(future.isDone());
assertFalse(future.isCompletedExceptionally());
assertEquals(modelId, future.get().getContent().get(MODEL_ID));
assertEquals(status, future.get().getContent().get(REGISTER_MODEL_STATUS));
}

public void testRegisterRemoteModelFailure() {
doAnswer(invocation -> {
ActionListener<MLRegisterModelResponse> actionListener = invocation.getArgument(1);
Expand Down Expand Up @@ -137,7 +187,7 @@ public void testMissingInputs() {
ExecutionException ex = assertThrows(ExecutionException.class, () -> future.get().getContent());
assertTrue(ex.getCause() instanceof FlowFrameworkException);
assertTrue(ex.getCause().getMessage().startsWith("Missing required inputs ["));
for (String s : new String[] { "name", "function_name", CONNECTOR_ID }) {
for (String s : new String[] { "name", CONNECTOR_ID }) {
assertTrue(ex.getCause().getMessage().contains(s));
}
assertTrue(ex.getCause().getMessage().endsWith("] in workflow [test-id] node [test-node-id]"));
Expand Down
Loading