diff --git a/plugin/src/main/java/org/opensearch/ml/action/undeploy/TransportUndeployModelsAction.java b/plugin/src/main/java/org/opensearch/ml/action/undeploy/TransportUndeployModelsAction.java index d64f817fa4..4fed08ff8c 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/undeploy/TransportUndeployModelsAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/undeploy/TransportUndeployModelsAction.java @@ -7,6 +7,7 @@ import static org.opensearch.ml.common.CommonValue.ML_MODEL_INDEX; +import java.time.Instant; import java.util.Arrays; import java.util.List; import java.util.stream.Collectors; @@ -14,10 +15,14 @@ import org.opensearch.ExceptionsHelper; import org.opensearch.OpenSearchStatusException; import org.opensearch.action.ActionRequest; +import org.opensearch.action.bulk.BulkRequest; +import org.opensearch.action.bulk.BulkResponse; import org.opensearch.action.search.SearchRequest; import org.opensearch.action.search.SearchResponse; import org.opensearch.action.support.ActionFilters; import org.opensearch.action.support.HandledTransportAction; +import org.opensearch.action.support.WriteRequest; +import org.opensearch.action.update.UpdateRequest; import org.opensearch.client.Client; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.inject.Inject; @@ -33,9 +38,11 @@ import org.opensearch.index.query.TermsQueryBuilder; import org.opensearch.ml.cluster.DiscoveryNodeHelper; import org.opensearch.ml.common.MLModel; +import org.opensearch.ml.common.model.MLModelState; import org.opensearch.ml.common.transport.deploy.MLDeployModelRequest; import org.opensearch.ml.common.transport.undeploy.MLUndeployModelAction; import org.opensearch.ml.common.transport.undeploy.MLUndeployModelNodesRequest; +import org.opensearch.ml.common.transport.undeploy.MLUndeployModelNodesResponse; import org.opensearch.ml.common.transport.undeploy.MLUndeployModelsAction; import org.opensearch.ml.common.transport.undeploy.MLUndeployModelsRequest; 
import org.opensearch.ml.common.transport.undeploy.MLUndeployModelsResponse; @@ -57,6 +64,7 @@ import org.opensearch.transport.TransportService; import com.google.common.annotations.VisibleForTesting; +import com.google.common.collect.ImmutableMap; import lombok.extern.log4j.Log4j2; @@ -179,11 +187,85 @@ private void undeployModels( MLUndeployModelNodesRequest mlUndeployModelNodesRequest = new MLUndeployModelNodesRequest(targetNodeIds, modelIds); mlUndeployModelNodesRequest.setTenantId(tenantId); - client.execute(MLUndeployModelAction.INSTANCE, mlUndeployModelNodesRequest, ActionListener.wrap(r -> { - listener.onResponse(new MLUndeployModelsResponse(r)); + client.execute(MLUndeployModelAction.INSTANCE, mlUndeployModelNodesRequest, ActionListener.wrap(response -> { + /* + * The method TransportUndeployModelsAction.processUndeployModelResponseAndUpdate(...) performs + * undeploy action of models by removing the models from the nodes cache and updating the index when it's able to find it. + * + * The problem becomes when the models index is incorrect and no node(s) are servicing the model. This results in + * `{}` responses (on undeploy action), with no update to the model index thus, causing incorrect model state status. + * + * Having this change enables a check that this edge case occurs along with having access to the model id + * allowing us to update the stale model index correctly to `UNDEPLOYED` since no nodes service the model. 
+ */ + if (response.getNodes().isEmpty()) { + bulkSetModelIndexToUndeploy(modelIds, listener, response); + return; + } + listener.onResponse(new MLUndeployModelsResponse(response)); }, listener::onFailure)); }
+    /**
+     * Marks the given models as {@code UNDEPLOYED} in the model index with a single bulk update.
+     * Invoked when the undeploy response holds no node results, to repair a stale index entry.
+     *
+     * @param modelIds ids of the model documents to update
+     * @param listener receives the wrapped undeploy response on success, or the failure otherwise
+     * @param response node-level undeploy response handed back to the caller after the update
+     */
+    private void bulkSetModelIndexToUndeploy(
+        String[] modelIds,
+        ActionListener<MLUndeployModelsResponse> listener,
+        MLUndeployModelNodesResponse response
+    ) {
+        String modelIdsText = Arrays.toString(modelIds);
+        String failureMessage = "Failed to set the following modelId(s) to UNDEPLOY in index: " + modelIdsText;
+
+        BulkRequest bulkUpdateRequest = new BulkRequest();
+        for (String modelId : modelIds) {
+            ImmutableMap.Builder<String, Object> builder = ImmutableMap.builder();
+            builder.put(MLModel.MODEL_STATE_FIELD, MLModelState.UNDEPLOYED.name());
+            builder.put(MLModel.PLANNING_WORKER_NODES_FIELD, List.of());
+            builder.put(MLModel.PLANNING_WORKER_NODE_COUNT_FIELD, 0);
+            builder.put(MLModel.LAST_UPDATED_TIME_FIELD, Instant.now().toEpochMilli());
+            builder.put(MLModel.CURRENT_WORKER_NODE_COUNT_FIELD, 0);
+
+            UpdateRequest updateRequest = new UpdateRequest();
+            updateRequest.index(ML_MODEL_INDEX).id(modelId).doc(builder.build());
+            bulkUpdateRequest.add(updateRequest);
+        }
+
+        bulkUpdateRequest.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
+        log.info("No nodes running these models: {}", modelIdsText);
+
+        try (ThreadContext.StoredContext threadContext = client.threadPool().getThreadContext().stashContext()) {
+            ActionListener<MLUndeployModelsResponse> listenerWithContextRestoration = ActionListener
+                .runBefore(listener, () -> threadContext.restore());
+            ActionListener<BulkResponse> bulkResponseListener = ActionListener.wrap(br -> {
+                // A bulk call can succeed as a request while individual document updates fail.
+                if (br.hasFailures()) {
+                    String detail = failureMessage + ": " + br.buildFailureMessage();
+                    log.error(detail);
+                    listenerWithContextRestoration.onFailure(new OpenSearchStatusException(detail, RestStatus.INTERNAL_SERVER_ERROR));
+                    return;
+                }
+                log.debug("Successfully set the following modelId(s) to UNDEPLOY in index: {}", modelIdsText);
+                listenerWithContextRestoration.onResponse(new MLUndeployModelsResponse(response));
+            }, e -> {
+                log.error(failureMessage, e);
+                // Pass the original exception as the cause so its stack trace is preserved.
+                String detail = failureMessage + ": " + e.getMessage();
+                listenerWithContextRestoration.onFailure(new OpenSearchStatusException(detail, RestStatus.INTERNAL_SERVER_ERROR, e));
+            });
+
+            client.bulk(bulkUpdateRequest, bulkResponseListener);
+        } catch (Exception e) {
+            log.error("Unexpected error while setting the following modelId(s) to UNDEPLOY in index: {}", modelIdsText, e);
+            listener.onFailure(e);
+        }
+    }
+
 private void validateAccess(String modelId, String tenantId, ActionListener listener) { User user = RestActionUtils.getUserContext(client); boolean isSuperAdmin = isSuperAdminUserWrapper(clusterService, client); diff --git a/plugin/src/test/java/org/opensearch/ml/action/undeploy/TransportUndeployModelsActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/undeploy/TransportUndeployModelsActionTests.java index 2964ed583b..7abde9aa33 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/undeploy/TransportUndeployModelsActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/undeploy/TransportUndeployModelsActionTests.java @@ -13,14 +13,17 @@ import static org.mockito.Mockito.doReturn; import static org.mockito.Mockito.doThrow; import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; import static org.mockito.Mockito.spy; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; +import static org.opensearch.ml.common.CommonValue.ML_MODEL_INDEX; import static org.opensearch.ml.task.MLPredictTaskRunnerTests.USER_STRING; import java.io.IOException; import java.util.ArrayList; import java.util.List; +import java.util.Map; import org.junit.Before; import org.junit.Rule; @@ -29,7 +32,10 @@ import org.mockito.Mock; import org.mockito.MockitoAnnotations; import org.opensearch.action.FailedNodeException; +import org.opensearch.action.bulk.BulkRequest; +import org.opensearch.action.bulk.BulkResponse; import org.opensearch.action.support.ActionFilters; +import org.opensearch.action.update.UpdateRequest; import org.opensearch.client.Client; import org.opensearch.cluster.ClusterName; import
org.opensearch.cluster.service.ClusterService; @@ -42,6 +48,7 @@ import org.opensearch.ml.cluster.DiscoveryNodeHelper; import org.opensearch.ml.common.FunctionName; import org.opensearch.ml.common.MLModel; +import org.opensearch.ml.common.model.MLModelState; import org.opensearch.ml.common.transport.undeploy.MLUndeployModelNodeResponse; import org.opensearch.ml.common.transport.undeploy.MLUndeployModelNodesResponse; import org.opensearch.ml.common.transport.undeploy.MLUndeployModelsRequest; @@ -172,6 +179,129 @@ public void setup() throws IOException { }).when(mlModelManager).getModel(any(), any(), any(), any(), isA(ActionListener.class)); }
+    /** With no node results in the undeploy response, one bulk update must mark the model UNDEPLOYED in the index. */
+    public void testDoExecute_undeployModelIndex_WhenNoNodesServiceModel() {
+        String modelId = "someModelId";
+        MLModel mlModel = MLModel
+            .builder()
+            .user(User.parse(USER_STRING))
+            .modelGroupId("111")
+            .version("111")
+            .name("Test Model")
+            .modelId(modelId)
+            .algorithm(FunctionName.BATCH_RCF)
+            .content("content")
+            .totalChunks(2)
+            .isHidden(true)
+            .build();
+
+        doAnswer(invocation -> {
+            ActionListener<MLModel> listener = invocation.getArgument(4);
+            listener.onResponse(mlModel);
+            return null;
+        }).when(mlModelManager).getModel(any(), any(), any(), any(), isA(ActionListener.class));
+
+        doReturn(true).when(transportUndeployModelsAction).isSuperAdminUserWrapper(clusterService, client);
+
+        List<MLUndeployModelNodeResponse> responseList = new ArrayList<>();
+        List<FailedNodeException> failuresList = new ArrayList<>();
+        MLUndeployModelNodesResponse nodesResponse = new MLUndeployModelNodesResponse(clusterName, responseList, failuresList);
+
+        // Reply with no node results for the model, so the action should write UNDEPLOYED back to the model index
+        doAnswer(invocation -> {
+            ActionListener<MLUndeployModelNodesResponse> listener = invocation.getArgument(2);
+            listener.onResponse(nodesResponse);
+            return null;
+        }).when(client).execute(any(), any(), isA(ActionListener.class));
+
+        ArgumentCaptor<BulkRequest> bulkRequestCaptor = ArgumentCaptor.forClass(BulkRequest.class);
+
+        // Stub the bulk call and capture its request so the documents written to the index can be inspected
+        doAnswer(invocation -> {
+            ActionListener<BulkResponse> listener = invocation.getArgument(1);
+            BulkResponse bulkResponse = mock(BulkResponse.class);
+            when(bulkResponse.hasFailures()).thenReturn(false);
+            listener.onResponse(bulkResponse);
+            return null;
+        }).when(client).bulk(bulkRequestCaptor.capture(), any(ActionListener.class));
+
+        String[] modelIds = new String[] { modelId };
+        String[] nodeIds = new String[] { "test_node_id1", "test_node_id2" };
+        MLUndeployModelsRequest request = new MLUndeployModelsRequest(modelIds, nodeIds, null);
+
+        transportUndeployModelsAction.doExecute(task, request, actionListener);
+
+        BulkRequest capturedBulkRequest = bulkRequestCaptor.getValue();
+        assertEquals(1, capturedBulkRequest.numberOfActions());
+        UpdateRequest updateRequest = (UpdateRequest) capturedBulkRequest.requests().get(0);
+
+        Map<String, Object> updateDoc = updateRequest.doc().sourceAsMap();
+        String modelIdFromBulkRequest = updateRequest.id();
+        String indexNameFromBulkRequest = updateRequest.index();
+
+        assertEquals("Check that the write happened at the model index", ML_MODEL_INDEX, indexNameFromBulkRequest);
+        assertEquals("Check that the result bulk write hit this specific modelId", modelId, modelIdFromBulkRequest);
+
+        assertEquals(MLModelState.UNDEPLOYED.name(), updateDoc.get(MLModel.MODEL_STATE_FIELD));
+        assertEquals(0, updateDoc.get(MLModel.CURRENT_WORKER_NODE_COUNT_FIELD));
+        assertEquals(0, updateDoc.get(MLModel.PLANNING_WORKER_NODE_COUNT_FIELD));
+        assertEquals(List.of(), updateDoc.get(MLModel.PLANNING_WORKER_NODES_FIELD));
+        assertTrue(updateDoc.containsKey(MLModel.LAST_UPDATED_TIME_FIELD));
+
+        verify(actionListener).onResponse(any(MLUndeployModelsResponse.class));
+        verify(client).bulk(any(BulkRequest.class), any(ActionListener.class));
+    }
+
+    /** With node results present in the undeploy response, the model index must not be bulk-updated. */
+    public void testDoExecute_noBulkRequestFired_WhenSomeNodesServiceModel() {
+        String modelId = "someModelId";
+        MLModel mlModel = MLModel
+            .builder()
+            .user(User.parse(USER_STRING))
+            .modelGroupId("111")
+            .version("111")
+            .name("Test Model")
+            .modelId(modelId)
+            .algorithm(FunctionName.BATCH_RCF)
+            .content("content")
+            .totalChunks(2)
+            .isHidden(true)
+            .build();
+
+        doAnswer(invocation -> {
+            ActionListener<MLModel> listener = invocation.getArgument(4);
+            listener.onResponse(mlModel);
+            return null;
+        }).when(mlModelManager).getModel(any(), any(), any(), any(), isA(ActionListener.class));
+
+        doReturn(true).when(transportUndeployModelsAction).isSuperAdminUserWrapper(clusterService, client);
+
+        List<MLUndeployModelNodeResponse> responseList = new ArrayList<>();
+        responseList.add(mock(MLUndeployModelNodeResponse.class));
+        responseList.add(mock(MLUndeployModelNodeResponse.class));
+        List<FailedNodeException> failuresList = new ArrayList<>();
+        failuresList.add(mock(FailedNodeException.class));
+        failuresList.add(mock(FailedNodeException.class));
+        MLUndeployModelNodesResponse nodesResponse = new MLUndeployModelNodesResponse(clusterName, responseList, failuresList);
+
+        // Reply with node results for the model
+        doAnswer(invocation -> {
+            ActionListener<MLUndeployModelNodesResponse> listener = invocation.getArgument(2);
+            listener.onResponse(nodesResponse);
+            return null;
+        }).when(client).execute(any(), any(), isA(ActionListener.class));
+
+        String[] modelIds = new String[] { modelId };
+        String[] nodeIds = new String[] { "test_node_id1", "test_node_id2" };
+        MLUndeployModelsRequest request = new MLUndeployModelsRequest(modelIds, nodeIds, null);
+
+        transportUndeployModelsAction.doExecute(task, request, actionListener);
+
+        verify(actionListener).onResponse(any(MLUndeployModelsResponse.class));
+        // No bulk write may occur, because the response carried node results for the model
+        verify(client, never()).bulk(any(BulkRequest.class), any(ActionListener.class));
+    }
+
 public void testHiddenModelSuccess() { MLModel mlModel = MLModel .builder() @@ -194,16 +324,28 @@ public void testHiddenModelSuccess() { List responseList = new ArrayList<>(); List failuresList = new ArrayList<>(); MLUndeployModelNodesResponse response = new MLUndeployModelNodesResponse(clusterName, responseList, failuresList); + doAnswer(invocation -> { ActionListener listener = invocation.getArgument(2); listener.onResponse(response); return null; }).when(client).execute(any(), any(), isA(ActionListener.class)); + // Mock the client.bulk call + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + BulkResponse bulkResponse = mock(BulkResponse.class); + when(bulkResponse.hasFailures()).thenReturn(false); + listener.onResponse(bulkResponse); + return null; + }).when(client).bulk(any(BulkRequest.class), any(ActionListener.class)); + doReturn(true).when(transportUndeployModelsAction).isSuperAdminUserWrapper(clusterService, client); MLUndeployModelsRequest request = new MLUndeployModelsRequest(modelIds, nodeIds, null); transportUndeployModelsAction.doExecute(task, request, actionListener); + verify(actionListener).onResponse(any(MLUndeployModelsResponse.class)); + verify(client).bulk(any(BulkRequest.class), any(ActionListener.class)); } public void testHiddenModelPermissionError() { @@ -257,9 +399,19 @@ public void testDoExecute() { listener.onResponse(response); return null; }).when(client).execute(any(), any(), isA(ActionListener.class)); + // Mock the client.bulk call + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + BulkResponse bulkResponse = mock(BulkResponse.class); + when(bulkResponse.hasFailures()).thenReturn(false); + listener.onResponse(bulkResponse); +
return null; + }).when(client).bulk(any(BulkRequest.class), any(ActionListener.class)); + MLUndeployModelsRequest request = new MLUndeployModelsRequest(modelIds, nodeIds, null); transportUndeployModelsAction.doExecute(task, request, actionListener); verify(actionListener).onResponse(any(MLUndeployModelsResponse.class)); + verify(client).bulk(any(BulkRequest.class), any(ActionListener.class)); } public void testDoExecute_modelAccessControl_notEnabled() {