diff --git a/plugin/src/main/java/org/opensearch/ml/action/undeploy/TransportUndeployModelsAction.java b/plugin/src/main/java/org/opensearch/ml/action/undeploy/TransportUndeployModelsAction.java index dc93419070..bf9f630097 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/undeploy/TransportUndeployModelsAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/undeploy/TransportUndeployModelsAction.java @@ -194,7 +194,7 @@ private void bulkSetModelIndexToUndeploy( } bulkUpdateRequest.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); - log.info("No nodes service: {}", modelIds.toString()); + log.info("No nodes service: {}", Arrays.toString(modelIds)); client.bulk(bulkUpdateRequest, ActionListener.wrap(br -> { log.debug("Successfully set modelIds to UNDEPLOY in index"); diff --git a/plugin/src/test/java/org/opensearch/ml/action/undeploy/TransportUndeployModelsActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/undeploy/TransportUndeployModelsActionTests.java index 42152f473d..c0405d766c 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/undeploy/TransportUndeployModelsActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/undeploy/TransportUndeployModelsActionTests.java @@ -13,14 +13,17 @@ import static org.mockito.Mockito.doReturn; import static org.mockito.Mockito.doThrow; import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; import static org.mockito.Mockito.spy; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; +import static org.opensearch.ml.common.CommonValue.ML_MODEL_INDEX; import static org.opensearch.ml.task.MLPredictTaskRunnerTests.USER_STRING; import java.io.IOException; import java.util.ArrayList; import java.util.List; +import java.util.Map; import org.junit.Before; import org.junit.Rule; @@ -29,7 +32,10 @@ import org.mockito.Mock; import org.mockito.MockitoAnnotations; import org.opensearch.action.FailedNodeException; +import org.opensearch.action.bulk.BulkRequest; +import org.opensearch.action.bulk.BulkResponse; import org.opensearch.action.support.ActionFilters; +import org.opensearch.action.update.UpdateRequest; import org.opensearch.client.Client; import org.opensearch.cluster.ClusterName; import org.opensearch.cluster.service.ClusterService; @@ -42,6 +48,7 @@ import org.opensearch.ml.cluster.DiscoveryNodeHelper; import org.opensearch.ml.common.FunctionName; import org.opensearch.ml.common.MLModel; +import org.opensearch.ml.common.model.MLModelState; import org.opensearch.ml.common.transport.undeploy.MLUndeployModelNodeResponse; import org.opensearch.ml.common.transport.undeploy.MLUndeployModelNodesResponse; import org.opensearch.ml.common.transport.undeploy.MLUndeployModelsRequest; @@ -164,6 +171,129 @@ public void setup() throws IOException { }).when(mlModelManager).getModel(any(), any(), any(), isA(ActionListener.class)); } + public void testDoExecute_undeployModelIndex_WhenNoNodesServiceModel() { + String modelId = "someModelId"; + MLModel mlModel = MLModel + .builder() + .user(User.parse(USER_STRING)) + .modelGroupId("111") + .version("111") + .name("Test Model") + .modelId(modelId) + .algorithm(FunctionName.BATCH_RCF) + .content("content") + .totalChunks(2) + .isHidden(true) + .build(); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(3); + listener.onResponse(mlModel); + return null; + }).when(mlModelManager).getModel(any(), any(), any(), isA(ActionListener.class)); + + doReturn(true).when(transportUndeployModelsAction).isSuperAdminUserWrapper(clusterService, client); + + List responseList = new ArrayList<>(); + List failuresList = new ArrayList<>(); + MLUndeployModelNodesResponse nodesResponse = new MLUndeployModelNodesResponse(clusterName, responseList, failuresList); + + // Send back a response with no nodes associated to the model. Thus, will write back to the model index that its UNDEPLOYED + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(2); + listener.onResponse(nodesResponse); + return null; + }).when(client).execute(any(), any(), isA(ActionListener.class)); + + ArgumentCaptor bulkRequestCaptor = ArgumentCaptor.forClass(BulkRequest.class); + + // mock the bulk response that can be captured for inspecting the contents of the write to index + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + BulkResponse bulkResponse = mock(BulkResponse.class); + when(bulkResponse.hasFailures()).thenReturn(false); + listener.onResponse(bulkResponse); + return null; + }).when(client).bulk(bulkRequestCaptor.capture(), any(ActionListener.class)); + + String[] modelIds = new String[] { modelId }; + String[] nodeIds = new String[] { "test_node_id1", "test_node_id2" }; + MLUndeployModelsRequest request = new MLUndeployModelsRequest(modelIds, nodeIds); + + transportUndeployModelsAction.doExecute(task, request, actionListener); + + BulkRequest capturedBulkRequest = bulkRequestCaptor.getValue(); + assertEquals(1, capturedBulkRequest.numberOfActions()); + UpdateRequest updateRequest = (UpdateRequest) capturedBulkRequest.requests().get(0); + + @SuppressWarnings("unchecked") + Map updateDoc = updateRequest.doc().sourceAsMap(); + String modelIdFromBulkRequest = updateRequest.id(); + String indexNameFromBulkRequest = updateRequest.index(); + + assertEquals("Check that the write happened at the model index", ML_MODEL_INDEX, indexNameFromBulkRequest); + assertEquals("Check that the result bulk write hit this specific modelId", modelId, modelIdFromBulkRequest); + + assertEquals(MLModelState.UNDEPLOYED.name(), updateDoc.get(MLModel.MODEL_STATE_FIELD)); + assertEquals(0, updateDoc.get(MLModel.CURRENT_WORKER_NODE_COUNT_FIELD)); + assertEquals(0, updateDoc.get(MLModel.PLANNING_WORKER_NODE_COUNT_FIELD)); + assertEquals(List.of(), updateDoc.get(MLModel.PLANNING_WORKER_NODES_FIELD)); + assertTrue(updateDoc.containsKey(MLModel.LAST_UPDATED_TIME_FIELD)); + + verify(actionListener).onResponse(any(MLUndeployModelsResponse.class)); + verify(client).bulk(any(BulkRequest.class), any(ActionListener.class)); + } + + public void testDoExecute_noBulkRequestFired_WhenSomeNodesServiceModel() { + String modelId = "someModelId"; + MLModel mlModel = MLModel + .builder() + .user(User.parse(USER_STRING)) + .modelGroupId("111") + .version("111") + .name("Test Model") + .modelId(modelId) + .algorithm(FunctionName.BATCH_RCF) + .content("content") + .totalChunks(2) + .isHidden(true) + .build(); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(3); + listener.onResponse(mlModel); + return null; + }).when(mlModelManager).getModel(any(), any(), any(), isA(ActionListener.class)); + + doReturn(true).when(transportUndeployModelsAction).isSuperAdminUserWrapper(clusterService, client); + + List responseList = new ArrayList<>(); + responseList.add(mock(MLUndeployModelNodeResponse.class)); + responseList.add(mock(MLUndeployModelNodeResponse.class)); + List failuresList = new ArrayList<>(); + failuresList.add(mock(FailedNodeException.class)); + failuresList.add(mock(FailedNodeException.class)); + + MLUndeployModelNodesResponse nodesResponse = new MLUndeployModelNodesResponse(clusterName, responseList, failuresList); + + // Send back a response with nodes associated to the model + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(2); + listener.onResponse(nodesResponse); + return null; + }).when(client).execute(any(), any(), isA(ActionListener.class)); + + String[] modelIds = new String[] { modelId }; + String[] nodeIds = new String[] { "test_node_id1", "test_node_id2" }; + MLUndeployModelsRequest request = new MLUndeployModelsRequest(modelIds, nodeIds); + + transportUndeployModelsAction.doExecute(task, request, actionListener); + + verify(actionListener).onResponse(any(MLUndeployModelsResponse.class)); + // Check that no bulk write occurred Since there were nodes servicing the model + verify(client, never()).bulk(any(BulkRequest.class), any(ActionListener.class)); + } + public void testHiddenModelSuccess() { MLModel mlModel = MLModel .builder() @@ -186,16 +316,28 @@ public void testHiddenModelSuccess() { List responseList = new ArrayList<>(); List failuresList = new ArrayList<>(); MLUndeployModelNodesResponse response = new MLUndeployModelNodesResponse(clusterName, responseList, failuresList); + doAnswer(invocation -> { ActionListener listener = invocation.getArgument(2); listener.onResponse(response); return null; }).when(client).execute(any(), any(), isA(ActionListener.class)); + // Mock the client.bulk call + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + BulkResponse bulkResponse = mock(BulkResponse.class); + when(bulkResponse.hasFailures()).thenReturn(false); + listener.onResponse(bulkResponse); + return null; + }).when(client).bulk(any(BulkRequest.class), any(ActionListener.class)); + doReturn(true).when(transportUndeployModelsAction).isSuperAdminUserWrapper(clusterService, client); MLUndeployModelsRequest request = new MLUndeployModelsRequest(modelIds, nodeIds); transportUndeployModelsAction.doExecute(task, request, actionListener); + verify(actionListener).onResponse(any(MLUndeployModelsResponse.class)); + verify(client).bulk(any(BulkRequest.class), any(ActionListener.class)); } public void testHiddenModelPermissionError() { @@ -249,9 +391,19 @@ public void testDoExecute() { listener.onResponse(response); return null; }).when(client).execute(any(), any(), isA(ActionListener.class)); + // Mock the client.bulk call + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + BulkResponse bulkResponse = mock(BulkResponse.class); + when(bulkResponse.hasFailures()).thenReturn(false); + listener.onResponse(bulkResponse); + return null; + }).when(client).bulk(any(BulkRequest.class), any(ActionListener.class)); + MLUndeployModelsRequest request = new MLUndeployModelsRequest(modelIds, nodeIds); transportUndeployModelsAction.doExecute(task, request, actionListener); verify(actionListener).onResponse(any(MLUndeployModelsResponse.class)); + verify(client).bulk(any(BulkRequest.class), any(ActionListener.class)); } public void testDoExecute_modelAccessControl_notEnabled() {