Skip to content

Commit

Permalink
add UTs for undeploy stale model index fix
Browse files Browse the repository at this point in the history
Added UTs for the 2 scenarios 1. Check that the bulk operation occured when no nodes are returned from the Undeploy response is , 2. Check that the bulk operation did not occur when there are nodes that have found the model within their cache.

Signed-off-by: Brian Flores <[email protected]>
  • Loading branch information
brianf-aws committed Jan 12, 2025
1 parent 9104cb8 commit 6d3b398
Show file tree
Hide file tree
Showing 2 changed files with 153 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,7 @@ private void bulkSetModelIndexToUndeploy(
}

bulkUpdateRequest.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
log.info("No nodes service: {}", modelIds.toString());
log.info("No nodes service: {}", Arrays.toString(modelIds));

client.bulk(bulkUpdateRequest, ActionListener.wrap(br -> {
log.debug("Successfully set modelIds to UNDEPLOY in index");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,17 @@
import static org.mockito.Mockito.doReturn;
import static org.mockito.Mockito.doThrow;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import static org.opensearch.ml.common.CommonValue.ML_MODEL_INDEX;
import static org.opensearch.ml.task.MLPredictTaskRunnerTests.USER_STRING;

import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;

import org.junit.Before;
import org.junit.Rule;
Expand All @@ -29,7 +32,10 @@
import org.mockito.Mock;
import org.mockito.MockitoAnnotations;
import org.opensearch.action.FailedNodeException;
import org.opensearch.action.bulk.BulkRequest;
import org.opensearch.action.bulk.BulkResponse;
import org.opensearch.action.support.ActionFilters;
import org.opensearch.action.update.UpdateRequest;
import org.opensearch.client.Client;
import org.opensearch.cluster.ClusterName;
import org.opensearch.cluster.service.ClusterService;
Expand All @@ -42,6 +48,7 @@
import org.opensearch.ml.cluster.DiscoveryNodeHelper;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.MLModel;
import org.opensearch.ml.common.model.MLModelState;
import org.opensearch.ml.common.transport.undeploy.MLUndeployModelNodeResponse;
import org.opensearch.ml.common.transport.undeploy.MLUndeployModelNodesResponse;
import org.opensearch.ml.common.transport.undeploy.MLUndeployModelsRequest;
Expand Down Expand Up @@ -164,6 +171,129 @@ public void setup() throws IOException {
}).when(mlModelManager).getModel(any(), any(), any(), isA(ActionListener.class));
}

public void testDoExecute_undeployModelIndex_WhenNoNodesServiceModel() {
String modelId = "someModelId";
MLModel mlModel = MLModel
.builder()
.user(User.parse(USER_STRING))
.modelGroupId("111")
.version("111")
.name("Test Model")
.modelId(modelId)
.algorithm(FunctionName.BATCH_RCF)
.content("content")
.totalChunks(2)
.isHidden(true)
.build();

doAnswer(invocation -> {
ActionListener<MLModel> listener = invocation.getArgument(3);
listener.onResponse(mlModel);
return null;
}).when(mlModelManager).getModel(any(), any(), any(), isA(ActionListener.class));

doReturn(true).when(transportUndeployModelsAction).isSuperAdminUserWrapper(clusterService, client);

List<MLUndeployModelNodeResponse> responseList = new ArrayList<>();
List<FailedNodeException> failuresList = new ArrayList<>();
MLUndeployModelNodesResponse nodesResponse = new MLUndeployModelNodesResponse(clusterName, responseList, failuresList);

// Send back a response with no nodes associated to the model. Thus, will write back to the model index that its UNDEPLOYED
doAnswer(invocation -> {
ActionListener<MLUndeployModelNodesResponse> listener = invocation.getArgument(2);
listener.onResponse(nodesResponse);
return null;
}).when(client).execute(any(), any(), isA(ActionListener.class));

ArgumentCaptor<BulkRequest> bulkRequestCaptor = ArgumentCaptor.forClass(BulkRequest.class);

// mock the bulk response that can be captured for inspecting the contents of the write to index
doAnswer(invocation -> {
ActionListener<BulkResponse> listener = invocation.getArgument(1);
BulkResponse bulkResponse = mock(BulkResponse.class);
when(bulkResponse.hasFailures()).thenReturn(false);
listener.onResponse(bulkResponse);
return null;
}).when(client).bulk(bulkRequestCaptor.capture(), any(ActionListener.class));

String[] modelIds = new String[] { modelId };
String[] nodeIds = new String[] { "test_node_id1", "test_node_id2" };
MLUndeployModelsRequest request = new MLUndeployModelsRequest(modelIds, nodeIds);

transportUndeployModelsAction.doExecute(task, request, actionListener);

BulkRequest capturedBulkRequest = bulkRequestCaptor.getValue();
assertEquals(1, capturedBulkRequest.numberOfActions());
UpdateRequest updateRequest = (UpdateRequest) capturedBulkRequest.requests().get(0);

@SuppressWarnings("unchecked")
Map<String, Object> updateDoc = updateRequest.doc().sourceAsMap();
String modelIdFromBulkRequest = updateRequest.id();
String indexNameFromBulkRequest = updateRequest.index();

assertEquals("Check that the write happened at the model index", ML_MODEL_INDEX, indexNameFromBulkRequest);
assertEquals("Check that the result bulk write hit this specific modelId", modelId, modelIdFromBulkRequest);

assertEquals(MLModelState.UNDEPLOYED.name(), updateDoc.get(MLModel.MODEL_STATE_FIELD));
assertEquals(0, updateDoc.get(MLModel.CURRENT_WORKER_NODE_COUNT_FIELD));
assertEquals(0, updateDoc.get(MLModel.PLANNING_WORKER_NODE_COUNT_FIELD));
assertEquals(List.of(), updateDoc.get(MLModel.PLANNING_WORKER_NODES_FIELD));
assertTrue(updateDoc.containsKey(MLModel.LAST_UPDATED_TIME_FIELD));

verify(actionListener).onResponse(any(MLUndeployModelsResponse.class));
verify(client).bulk(any(BulkRequest.class), any(ActionListener.class));
}

public void testDoExecute_noBulkRequestFired_WhenSomeNodesServiceModel() {
String modelId = "someModelId";
MLModel mlModel = MLModel
.builder()
.user(User.parse(USER_STRING))
.modelGroupId("111")
.version("111")
.name("Test Model")
.modelId(modelId)
.algorithm(FunctionName.BATCH_RCF)
.content("content")
.totalChunks(2)
.isHidden(true)
.build();

doAnswer(invocation -> {
ActionListener<MLModel> listener = invocation.getArgument(3);
listener.onResponse(mlModel);
return null;
}).when(mlModelManager).getModel(any(), any(), any(), isA(ActionListener.class));

doReturn(true).when(transportUndeployModelsAction).isSuperAdminUserWrapper(clusterService, client);

List<MLUndeployModelNodeResponse> responseList = new ArrayList<>();
responseList.add(mock(MLUndeployModelNodeResponse.class));
responseList.add(mock(MLUndeployModelNodeResponse.class));
List<FailedNodeException> failuresList = new ArrayList<>();
failuresList.add(mock(FailedNodeException.class));
failuresList.add(mock(FailedNodeException.class));

MLUndeployModelNodesResponse nodesResponse = new MLUndeployModelNodesResponse(clusterName, responseList, failuresList);

// Send back a response with nodes associated to the model
doAnswer(invocation -> {
ActionListener<MLUndeployModelNodesResponse> listener = invocation.getArgument(2);
listener.onResponse(nodesResponse);
return null;
}).when(client).execute(any(), any(), isA(ActionListener.class));

String[] modelIds = new String[] { modelId };
String[] nodeIds = new String[] { "test_node_id1", "test_node_id2" };
MLUndeployModelsRequest request = new MLUndeployModelsRequest(modelIds, nodeIds);

transportUndeployModelsAction.doExecute(task, request, actionListener);

verify(actionListener).onResponse(any(MLUndeployModelsResponse.class));
// Check that no bulk write occurred Since there were nodes servicing the model
verify(client, never()).bulk(any(BulkRequest.class), any(ActionListener.class));
}

public void testHiddenModelSuccess() {
MLModel mlModel = MLModel
.builder()
Expand All @@ -186,16 +316,28 @@ public void testHiddenModelSuccess() {
List<MLUndeployModelNodeResponse> responseList = new ArrayList<>();
List<FailedNodeException> failuresList = new ArrayList<>();
MLUndeployModelNodesResponse response = new MLUndeployModelNodesResponse(clusterName, responseList, failuresList);

doAnswer(invocation -> {
ActionListener<MLUndeployModelNodesResponse> listener = invocation.getArgument(2);
listener.onResponse(response);
return null;
}).when(client).execute(any(), any(), isA(ActionListener.class));

// Mock the client.bulk call
doAnswer(invocation -> {
ActionListener<BulkResponse> listener = invocation.getArgument(1);
BulkResponse bulkResponse = mock(BulkResponse.class);
when(bulkResponse.hasFailures()).thenReturn(false);
listener.onResponse(bulkResponse);
return null;
}).when(client).bulk(any(BulkRequest.class), any(ActionListener.class));

doReturn(true).when(transportUndeployModelsAction).isSuperAdminUserWrapper(clusterService, client);
MLUndeployModelsRequest request = new MLUndeployModelsRequest(modelIds, nodeIds);
transportUndeployModelsAction.doExecute(task, request, actionListener);

verify(actionListener).onResponse(any(MLUndeployModelsResponse.class));
verify(client).bulk(any(BulkRequest.class), any(ActionListener.class));
}

public void testHiddenModelPermissionError() {
Expand Down Expand Up @@ -249,9 +391,19 @@ public void testDoExecute() {
listener.onResponse(response);
return null;
}).when(client).execute(any(), any(), isA(ActionListener.class));
// Mock the client.bulk call
doAnswer(invocation -> {
ActionListener<BulkResponse> listener = invocation.getArgument(1);
BulkResponse bulkResponse = mock(BulkResponse.class);
when(bulkResponse.hasFailures()).thenReturn(false);
listener.onResponse(bulkResponse);
return null;
}).when(client).bulk(any(BulkRequest.class), any(ActionListener.class));

MLUndeployModelsRequest request = new MLUndeployModelsRequest(modelIds, nodeIds);
transportUndeployModelsAction.doExecute(task, request, actionListener);
verify(actionListener).onResponse(any(MLUndeployModelsResponse.class));
verify(client).bulk(any(BulkRequest.class), any(ActionListener.class));
}

public void testDoExecute_modelAccessControl_notEnabled() {
Expand Down

0 comments on commit 6d3b398

Please sign in to comment.