Skip to content

Commit

Permalink
Undeploy models with no WorkerNodes (#3380) (#3447)
Browse files Browse the repository at this point in the history
* undeploy models with no WorkerNodes

This commit aims to undeploy modelIds that have no nodes associated to them so as to keep the intention of undeploy truthful.

Signed-off-by: Brian Flores <[email protected]>

# Conflicts:
#	plugin/src/main/java/org/opensearch/ml/action/undeploy/TransportUndeployModelsAction.java

* Exit early when no nodes service the model

Now, when entering this method, it's guaranteed to write to the index first before sending back the MLUndeploy response. It will also send back an exception if the write-back fails.

Signed-off-by: Brian Flores <[email protected]>

* add UTs for undeploy stale model index fix

Added UTs for the 2 scenarios: 1. Check that the bulk operation occurred when no nodes are returned from the Undeploy response, 2. Check that the bulk operation did not occur when there are nodes that have found the model within their cache.

Signed-off-by: Brian Flores <[email protected]>

* update code change with comment explaining the change

Signed-off-by: Brian Flores <[email protected]>

* add context stash/restore to write operation

Signed-off-by: Brian Flores <[email protected]>

* Apply spotless

Signed-off-by: Brian Flores <[email protected]>

* Add better logging to write request

Signed-off-by: Brian Flores <[email protected]>

* wrap exception into 5xx

Signed-off-by: Brian Flores <[email protected]>

* adapts undeploy code change to multi-tenancy feature

Signed-off-by: Brian Flores <[email protected]>

* applies spotless

Signed-off-by: Brian Flores <[email protected]>

---------

Signed-off-by: Brian Flores <[email protected]>
(cherry picked from commit 18bcaae)

Co-authored-by: Brian Flores <[email protected]>
  • Loading branch information
opensearch-trigger-bot[bot] and brianf-aws authored Jan 28, 2025
1 parent aadc422 commit 431c31b
Show file tree
Hide file tree
Showing 2 changed files with 226 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -7,17 +7,22 @@

import static org.opensearch.ml.common.CommonValue.ML_MODEL_INDEX;

import java.time.Instant;
import java.util.Arrays;
import java.util.List;
import java.util.stream.Collectors;

import org.opensearch.ExceptionsHelper;
import org.opensearch.OpenSearchStatusException;
import org.opensearch.action.ActionRequest;
import org.opensearch.action.bulk.BulkRequest;
import org.opensearch.action.bulk.BulkResponse;
import org.opensearch.action.search.SearchRequest;
import org.opensearch.action.search.SearchResponse;
import org.opensearch.action.support.ActionFilters;
import org.opensearch.action.support.HandledTransportAction;
import org.opensearch.action.support.WriteRequest;
import org.opensearch.action.update.UpdateRequest;
import org.opensearch.client.Client;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.inject.Inject;
Expand All @@ -33,9 +38,11 @@
import org.opensearch.index.query.TermsQueryBuilder;
import org.opensearch.ml.cluster.DiscoveryNodeHelper;
import org.opensearch.ml.common.MLModel;
import org.opensearch.ml.common.model.MLModelState;
import org.opensearch.ml.common.transport.deploy.MLDeployModelRequest;
import org.opensearch.ml.common.transport.undeploy.MLUndeployModelAction;
import org.opensearch.ml.common.transport.undeploy.MLUndeployModelNodesRequest;
import org.opensearch.ml.common.transport.undeploy.MLUndeployModelNodesResponse;
import org.opensearch.ml.common.transport.undeploy.MLUndeployModelsAction;
import org.opensearch.ml.common.transport.undeploy.MLUndeployModelsRequest;
import org.opensearch.ml.common.transport.undeploy.MLUndeployModelsResponse;
Expand All @@ -57,6 +64,7 @@
import org.opensearch.transport.TransportService;

import com.google.common.annotations.VisibleForTesting;
import com.google.common.collect.ImmutableMap;

import lombok.extern.log4j.Log4j2;

Expand Down Expand Up @@ -179,11 +187,75 @@ private void undeployModels(
MLUndeployModelNodesRequest mlUndeployModelNodesRequest = new MLUndeployModelNodesRequest(targetNodeIds, modelIds);
mlUndeployModelNodesRequest.setTenantId(tenantId);

client.execute(MLUndeployModelAction.INSTANCE, mlUndeployModelNodesRequest, ActionListener.wrap(r -> {
listener.onResponse(new MLUndeployModelsResponse(r));
client.execute(MLUndeployModelAction.INSTANCE, mlUndeployModelNodesRequest, ActionListener.wrap(response -> {
/*
* The method TransportUndeployModelsAction.processUndeployModelResponseAndUpdate(...) performs
* undeploy action of models by removing the models from the nodes cache and updating the index when it's able to find it.
*
* The problem arises when the model index is stale and no node is servicing the model. This results in
* `{}` responses (on undeploy action) with no update to the model index, thus leaving an incorrect model state.
*
* Having this change enables a check that this edge case occurs along with having access to the model id
* allowing us to update the stale model index correctly to `UNDEPLOYED` since no nodes service the model.
*/
if (response.getNodes().isEmpty()) {
bulkSetModelIndexToUndeploy(modelIds, listener, response);
return;
}
listener.onResponse(new MLUndeployModelsResponse(response));
}, listener::onFailure));
}

/**
 * Marks every given model as {@code UNDEPLOYED} in the model index with a single bulk update and
 * then completes {@code listener}.
 *
 * <p>For each model id the update sets the model state to {@code UNDEPLOYED}, clears the planning
 * worker nodes, zeroes the planning and current worker node counts and stamps the last-updated
 * time. The bulk request is sent with an immediate refresh policy and under a stashed thread
 * context; the stashed context is restored right before {@code listener} is notified.
 *
 * @param modelIds ids of the models whose index documents are updated; when {@code null} or empty
 *                 no write is attempted and {@code listener} receives the response directly
 * @param listener notified with an {@link MLUndeployModelsResponse} wrapping {@code response} once
 *                 the bulk call returns, or with an {@link OpenSearchStatusException}
 *                 ({@code INTERNAL_SERVER_ERROR}) when the bulk call itself fails
 * @param response the node-level undeploy response that is handed back to the caller
 */
private void bulkSetModelIndexToUndeploy(
    String[] modelIds,
    ActionListener<MLUndeployModelsResponse> listener,
    MLUndeployModelNodesResponse response
) {
    // Nothing to update: avoid iterating a null array or sending a bulk request without any action.
    if (modelIds == null || modelIds.length == 0) {
        listener.onResponse(new MLUndeployModelsResponse(response));
        return;
    }

    // Computed once so every log line/error message and every document in this batch agree.
    final String modelIdsText = Arrays.toString(modelIds);
    final long updatedAtMillis = Instant.now().toEpochMilli();

    BulkRequest bulkUpdateRequest = new BulkRequest();
    for (String modelId : modelIds) {
        ImmutableMap.Builder<String, Object> builder = ImmutableMap.builder();
        builder.put(MLModel.MODEL_STATE_FIELD, MLModelState.UNDEPLOYED.name());

        builder.put(MLModel.PLANNING_WORKER_NODES_FIELD, List.of());
        builder.put(MLModel.PLANNING_WORKER_NODE_COUNT_FIELD, 0);

        builder.put(MLModel.LAST_UPDATED_TIME_FIELD, updatedAtMillis);
        builder.put(MLModel.CURRENT_WORKER_NODE_COUNT_FIELD, 0);

        UpdateRequest updateRequest = new UpdateRequest();
        updateRequest.index(ML_MODEL_INDEX).id(modelId).doc(builder.build());
        bulkUpdateRequest.add(updateRequest);
    }

    bulkUpdateRequest.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
    log.info("No nodes running these models: {}", modelIdsText);

    try (ThreadContext.StoredContext threadContext = client.threadPool().getThreadContext().stashContext()) {
        ActionListener<MLUndeployModelsResponse> listenerWithContextRestoration = ActionListener
            .runBefore(listener, () -> threadContext.restore());
        ActionListener<BulkResponse> bulkResponseListener = ActionListener.wrap(br -> {
            // A bulk call can return normally while individual items failed; do not report
            // those as a success in the logs. The caller still receives the undeploy response.
            if (br.hasFailures()) {
                log
                    .error(
                        "Failed to set some of the following modelId(s) to UNDEPLOY in index: {}. Details: {}",
                        modelIdsText,
                        br.buildFailureMessage()
                    );
            } else {
                log.debug("Successfully set the following modelId(s) to UNDEPLOY in index: {}", modelIdsText);
            }
            listenerWithContextRestoration.onResponse(new MLUndeployModelsResponse(response));
        }, e -> {
            String failureMessage = String.format("Failed to set the following modelId(s) to UNDEPLOY in index: %s", modelIdsText);
            log.error(failureMessage, e);

            // Keep the original exception as the cause and separate its message from ours.
            OpenSearchStatusException exception = new OpenSearchStatusException(
                failureMessage + ": " + e.getMessage(),
                RestStatus.INTERNAL_SERVER_ERROR,
                e
            );
            listenerWithContextRestoration.onFailure(exception);
        });

        client.bulk(bulkUpdateRequest, bulkResponseListener);
    } catch (Exception e) {
        log.error("Unexpected error while setting the following modelId(s) to UNDEPLOY in index: {}", modelIdsText, e);
        listener.onFailure(e);
    }
}

private void validateAccess(String modelId, String tenantId, ActionListener<Boolean> listener) {
User user = RestActionUtils.getUserContext(client);
boolean isSuperAdmin = isSuperAdminUserWrapper(clusterService, client);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,17 @@
import static org.mockito.Mockito.doReturn;
import static org.mockito.Mockito.doThrow;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import static org.opensearch.ml.common.CommonValue.ML_MODEL_INDEX;
import static org.opensearch.ml.task.MLPredictTaskRunnerTests.USER_STRING;

import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;

import org.junit.Before;
import org.junit.Rule;
Expand All @@ -29,7 +32,10 @@
import org.mockito.Mock;
import org.mockito.MockitoAnnotations;
import org.opensearch.action.FailedNodeException;
import org.opensearch.action.bulk.BulkRequest;
import org.opensearch.action.bulk.BulkResponse;
import org.opensearch.action.support.ActionFilters;
import org.opensearch.action.update.UpdateRequest;
import org.opensearch.client.Client;
import org.opensearch.cluster.ClusterName;
import org.opensearch.cluster.service.ClusterService;
Expand All @@ -42,6 +48,7 @@
import org.opensearch.ml.cluster.DiscoveryNodeHelper;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.MLModel;
import org.opensearch.ml.common.model.MLModelState;
import org.opensearch.ml.common.transport.undeploy.MLUndeployModelNodeResponse;
import org.opensearch.ml.common.transport.undeploy.MLUndeployModelNodesResponse;
import org.opensearch.ml.common.transport.undeploy.MLUndeployModelsRequest;
Expand Down Expand Up @@ -172,6 +179,129 @@ public void setup() throws IOException {
}).when(mlModelManager).getModel(any(), any(), any(), any(), isA(ActionListener.class));
}

/**
 * When the undeploy response reports no nodes, the action is expected to issue exactly one bulk
 * update that marks the model as UNDEPLOYED in the model index before answering the caller.
 */
public void testDoExecute_undeployModelIndex_WhenNoNodesServiceModel() {
    final String modelId = "someModelId";
    final MLModel hiddenModel = MLModel
        .builder()
        .user(User.parse(USER_STRING))
        .modelGroupId("111")
        .version("111")
        .name("Test Model")
        .modelId(modelId)
        .algorithm(FunctionName.BATCH_RCF)
        .content("content")
        .totalChunks(2)
        .isHidden(true)
        .build();

    // Model lookup always resolves to the hidden model built above.
    doAnswer(invocation -> {
        ActionListener<MLModel> modelListener = invocation.getArgument(4);
        modelListener.onResponse(hiddenModel);
        return null;
    }).when(mlModelManager).getModel(any(), any(), any(), any(), isA(ActionListener.class));

    doReturn(true).when(transportUndeployModelsAction).isSuperAdminUserWrapper(clusterService, client);

    // An undeploy response without any node entries: nobody was servicing the model,
    // so the action should write UNDEPLOYED back to the model index.
    List<MLUndeployModelNodeResponse> noNodeResponses = new ArrayList<>();
    List<FailedNodeException> noFailures = new ArrayList<>();
    MLUndeployModelNodesResponse emptyNodesResponse = new MLUndeployModelNodesResponse(clusterName, noNodeResponses, noFailures);
    doAnswer(invocation -> {
        ActionListener<MLUndeployModelNodesResponse> undeployListener = invocation.getArgument(2);
        undeployListener.onResponse(emptyNodesResponse);
        return null;
    }).when(client).execute(any(), any(), isA(ActionListener.class));

    // Answer the bulk call with a failure-free response while capturing the request for inspection.
    ArgumentCaptor<BulkRequest> bulkCaptor = ArgumentCaptor.forClass(BulkRequest.class);
    doAnswer(invocation -> {
        ActionListener<BulkResponse> bulkListener = invocation.getArgument(1);
        BulkResponse okBulkResponse = mock(BulkResponse.class);
        when(okBulkResponse.hasFailures()).thenReturn(false);
        bulkListener.onResponse(okBulkResponse);
        return null;
    }).when(client).bulk(bulkCaptor.capture(), any(ActionListener.class));

    MLUndeployModelsRequest request = new MLUndeployModelsRequest(
        new String[] { modelId },
        new String[] { "test_node_id1", "test_node_id2" },
        null
    );

    transportUndeployModelsAction.doExecute(task, request, actionListener);

    // Exactly one update, addressed to the model document in the model index.
    BulkRequest sentBulkRequest = bulkCaptor.getValue();
    assertEquals(1, sentBulkRequest.numberOfActions());
    UpdateRequest sentUpdate = (UpdateRequest) sentBulkRequest.requests().get(0);

    assertEquals("Check that the write happened at the model index", ML_MODEL_INDEX, sentUpdate.index());
    assertEquals("Check that the result bulk write hit this specific modelId", modelId, sentUpdate.id());

    // The partial document must reset state, node counts and planning nodes, and stamp the update time.
    @SuppressWarnings("unchecked")
    Map<String, Object> writtenFields = sentUpdate.doc().sourceAsMap();
    assertEquals(MLModelState.UNDEPLOYED.name(), writtenFields.get(MLModel.MODEL_STATE_FIELD));
    assertEquals(0, writtenFields.get(MLModel.CURRENT_WORKER_NODE_COUNT_FIELD));
    assertEquals(0, writtenFields.get(MLModel.PLANNING_WORKER_NODE_COUNT_FIELD));
    assertEquals(List.of(), writtenFields.get(MLModel.PLANNING_WORKER_NODES_FIELD));
    assertTrue(writtenFields.containsKey(MLModel.LAST_UPDATED_TIME_FIELD));

    verify(actionListener).onResponse(any(MLUndeployModelsResponse.class));
    verify(client).bulk(any(BulkRequest.class), any(ActionListener.class));
}

/**
 * When the undeploy response contains node entries, the action should answer the caller
 * without issuing any bulk write against the model index.
 */
public void testDoExecute_noBulkRequestFired_WhenSomeNodesServiceModel() {
    final String modelId = "someModelId";
    final MLModel hiddenModel = MLModel
        .builder()
        .user(User.parse(USER_STRING))
        .modelGroupId("111")
        .version("111")
        .name("Test Model")
        .modelId(modelId)
        .algorithm(FunctionName.BATCH_RCF)
        .content("content")
        .totalChunks(2)
        .isHidden(true)
        .build();

    // Model lookup always resolves to the hidden model built above.
    doAnswer(invocation -> {
        ActionListener<MLModel> modelListener = invocation.getArgument(4);
        modelListener.onResponse(hiddenModel);
        return null;
    }).when(mlModelManager).getModel(any(), any(), any(), any(), isA(ActionListener.class));

    doReturn(true).when(transportUndeployModelsAction).isSuperAdminUserWrapper(clusterService, client);

    // Two node responses and two node failures: the response's node list is not empty.
    List<MLUndeployModelNodeResponse> nodeResponses = new ArrayList<>(
        List.of(mock(MLUndeployModelNodeResponse.class), mock(MLUndeployModelNodeResponse.class))
    );
    List<FailedNodeException> nodeFailures = new ArrayList<>(
        List.of(mock(FailedNodeException.class), mock(FailedNodeException.class))
    );
    MLUndeployModelNodesResponse populatedNodesResponse = new MLUndeployModelNodesResponse(clusterName, nodeResponses, nodeFailures);

    // Send back a response with nodes associated to the model
    doAnswer(invocation -> {
        ActionListener<MLUndeployModelNodesResponse> undeployListener = invocation.getArgument(2);
        undeployListener.onResponse(populatedNodesResponse);
        return null;
    }).when(client).execute(any(), any(), isA(ActionListener.class));

    MLUndeployModelsRequest request = new MLUndeployModelsRequest(
        new String[] { modelId },
        new String[] { "test_node_id1", "test_node_id2" },
        null
    );

    transportUndeployModelsAction.doExecute(task, request, actionListener);

    // Nodes were servicing the model, so no bulk write should have been attempted.
    verify(client, never()).bulk(any(BulkRequest.class), any(ActionListener.class));
    verify(actionListener).onResponse(any(MLUndeployModelsResponse.class));
}

public void testHiddenModelSuccess() {
MLModel mlModel = MLModel
.builder()
Expand All @@ -194,16 +324,28 @@ public void testHiddenModelSuccess() {
List<MLUndeployModelNodeResponse> responseList = new ArrayList<>();
List<FailedNodeException> failuresList = new ArrayList<>();
MLUndeployModelNodesResponse response = new MLUndeployModelNodesResponse(clusterName, responseList, failuresList);

doAnswer(invocation -> {
ActionListener<MLUndeployModelNodesResponse> listener = invocation.getArgument(2);
listener.onResponse(response);
return null;
}).when(client).execute(any(), any(), isA(ActionListener.class));

// Mock the client.bulk call
doAnswer(invocation -> {
ActionListener<BulkResponse> listener = invocation.getArgument(1);
BulkResponse bulkResponse = mock(BulkResponse.class);
when(bulkResponse.hasFailures()).thenReturn(false);
listener.onResponse(bulkResponse);
return null;
}).when(client).bulk(any(BulkRequest.class), any(ActionListener.class));

doReturn(true).when(transportUndeployModelsAction).isSuperAdminUserWrapper(clusterService, client);
MLUndeployModelsRequest request = new MLUndeployModelsRequest(modelIds, nodeIds, null);
transportUndeployModelsAction.doExecute(task, request, actionListener);

verify(actionListener).onResponse(any(MLUndeployModelsResponse.class));
verify(client).bulk(any(BulkRequest.class), any(ActionListener.class));
}

public void testHiddenModelPermissionError() {
Expand Down Expand Up @@ -257,9 +399,19 @@ public void testDoExecute() {
listener.onResponse(response);
return null;
}).when(client).execute(any(), any(), isA(ActionListener.class));
// Mock the client.bulk call
doAnswer(invocation -> {
ActionListener<BulkResponse> listener = invocation.getArgument(1);
BulkResponse bulkResponse = mock(BulkResponse.class);
when(bulkResponse.hasFailures()).thenReturn(false);
listener.onResponse(bulkResponse);
return null;
}).when(client).bulk(any(BulkRequest.class), any(ActionListener.class));

MLUndeployModelsRequest request = new MLUndeployModelsRequest(modelIds, nodeIds, null);
transportUndeployModelsAction.doExecute(task, request, actionListener);
verify(actionListener).onResponse(any(MLUndeployModelsResponse.class));
verify(client).bulk(any(BulkRequest.class), any(ActionListener.class));
}

public void testDoExecute_modelAccessControl_notEnabled() {
Expand Down

0 comments on commit 431c31b

Please sign in to comment.