Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Revert "Backport/backport 3380 to feature/multi tenancy" #3506

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -8,21 +8,16 @@
import static org.opensearch.ml.common.CommonValue.ML_MODEL_INDEX;
import static org.opensearch.ml.plugin.MachineLearningPlugin.GENERAL_THREAD_POOL;

import java.time.Instant;
import java.util.Arrays;
import java.util.List;
import java.util.stream.Collectors;

import org.opensearch.OpenSearchStatusException;
import org.opensearch.action.ActionRequest;
import org.opensearch.action.bulk.BulkRequest;
import org.opensearch.action.bulk.BulkResponse;
import org.opensearch.action.search.SearchRequest;
import org.opensearch.action.search.SearchResponse;
import org.opensearch.action.support.ActionFilters;
import org.opensearch.action.support.HandledTransportAction;
import org.opensearch.action.support.WriteRequest;
import org.opensearch.action.update.UpdateRequest;
import org.opensearch.client.Client;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.inject.Inject;
Expand All @@ -38,11 +33,9 @@
import org.opensearch.index.query.TermsQueryBuilder;
import org.opensearch.ml.cluster.DiscoveryNodeHelper;
import org.opensearch.ml.common.MLModel;
import org.opensearch.ml.common.model.MLModelState;
import org.opensearch.ml.common.transport.deploy.MLDeployModelRequest;
import org.opensearch.ml.common.transport.undeploy.MLUndeployModelAction;
import org.opensearch.ml.common.transport.undeploy.MLUndeployModelNodesRequest;
import org.opensearch.ml.common.transport.undeploy.MLUndeployModelNodesResponse;
import org.opensearch.ml.common.transport.undeploy.MLUndeployModelsAction;
import org.opensearch.ml.common.transport.undeploy.MLUndeployModelsRequest;
import org.opensearch.ml.common.transport.undeploy.MLUndeployModelsResponse;
Expand All @@ -64,7 +57,6 @@
import org.opensearch.transport.TransportService;

import com.google.common.annotations.VisibleForTesting;
import com.google.common.collect.ImmutableMap;

import lombok.extern.log4j.Log4j2;

Expand Down Expand Up @@ -186,75 +178,11 @@ private void undeployModels(
MLUndeployModelNodesRequest mlUndeployModelNodesRequest = new MLUndeployModelNodesRequest(targetNodeIds, modelIds);
mlUndeployModelNodesRequest.setTenantId(tenantId);

client.execute(MLUndeployModelAction.INSTANCE, mlUndeployModelNodesRequest, ActionListener.wrap(response -> {
/*
* The method TransportUndeployModelsAction.processUndeployModelResponseAndUpdate(...) performs
* undeploy action of models by removing the models from the nodes cache and updating the index when it's able to find it.
*
* The problem becomes when the models index is incorrect and no node(s) are servicing the model. This results in
* `{}` responses (on undeploy action), with no update to the model index thus, causing incorrect model state status.
*
* Having this change enables a check that this edge case occurs along with having access to the model id
* allowing us to update the stale model index correctly to `UNDEPLOYED` since no nodes service the model.
*/
if (response.getNodes().isEmpty()) {
bulkSetModelIndexToUndeploy(modelIds, listener, response);
return;
}
listener.onResponse(new MLUndeployModelsResponse(response));
client.execute(MLUndeployModelAction.INSTANCE, mlUndeployModelNodesRequest, ActionListener.wrap(r -> {
listener.onResponse(new MLUndeployModelsResponse(r));
}, listener::onFailure));
}

private void bulkSetModelIndexToUndeploy(
String[] modelIds,
ActionListener<MLUndeployModelsResponse> listener,
MLUndeployModelNodesResponse response
) {
BulkRequest bulkUpdateRequest = new BulkRequest();
for (String modelId : modelIds) {
UpdateRequest updateRequest = new UpdateRequest();

ImmutableMap.Builder<String, Object> builder = ImmutableMap.builder();
builder.put(MLModel.MODEL_STATE_FIELD, MLModelState.UNDEPLOYED.name());

builder.put(MLModel.PLANNING_WORKER_NODES_FIELD, List.of());
builder.put(MLModel.PLANNING_WORKER_NODE_COUNT_FIELD, 0);

builder.put(MLModel.LAST_UPDATED_TIME_FIELD, Instant.now().toEpochMilli());
builder.put(MLModel.CURRENT_WORKER_NODE_COUNT_FIELD, 0);
updateRequest.index(ML_MODEL_INDEX).id(modelId).doc(builder.build());
bulkUpdateRequest.add(updateRequest);
}

bulkUpdateRequest.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
log.info("No nodes running these models: {}", Arrays.toString(modelIds));

try (ThreadContext.StoredContext threadContext = client.threadPool().getThreadContext().stashContext()) {
ActionListener<MLUndeployModelsResponse> listenerWithContextRestoration = ActionListener
.runBefore(listener, () -> threadContext.restore());
ActionListener<BulkResponse> bulkResponseListener = ActionListener.wrap(br -> {
log.debug("Successfully set the following modelId(s) to UNDEPLOY in index: {}", Arrays.toString(modelIds));
listenerWithContextRestoration.onResponse(new MLUndeployModelsResponse(response));
}, e -> {
String modelsNotFoundMessage = String
.format("Failed to set the following modelId(s) to UNDEPLOY in index: %s", Arrays.toString(modelIds));
log.error(modelsNotFoundMessage, e);

OpenSearchStatusException exception = new OpenSearchStatusException(
modelsNotFoundMessage + e.getMessage(),
RestStatus.INTERNAL_SERVER_ERROR
);
listenerWithContextRestoration.onFailure(exception);
});

client.bulk(bulkUpdateRequest, bulkResponseListener);
} catch (Exception e) {
log.error("Unexpected error while setting the following modelId(s) to UNDEPLOY in index: {}", Arrays.toString(modelIds), e);
listener.onFailure(e);
}

}

private void validateAccess(String modelId, String tenantId, ActionListener<Boolean> listener) {
User user = RestActionUtils.getUserContext(client);
boolean isSuperAdmin = isSuperAdminUserWrapper(clusterService, client);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,19 +14,16 @@
import static org.mockito.Mockito.doReturn;
import static org.mockito.Mockito.doThrow;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import static org.opensearch.ml.common.CommonValue.ML_MODEL_INDEX;
import static org.opensearch.ml.plugin.MachineLearningPlugin.GENERAL_THREAD_POOL;
import static org.opensearch.ml.plugin.MachineLearningPlugin.ML_THREAD_POOL_PREFIX;
import static org.opensearch.ml.task.MLPredictTaskRunnerTests.USER_STRING;

import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.concurrent.TimeUnit;

import org.junit.AfterClass;
Expand All @@ -37,10 +34,7 @@
import org.mockito.Mock;
import org.mockito.MockitoAnnotations;
import org.opensearch.action.FailedNodeException;
import org.opensearch.action.bulk.BulkRequest;
import org.opensearch.action.bulk.BulkResponse;
import org.opensearch.action.support.ActionFilters;
import org.opensearch.action.update.UpdateRequest;
import org.opensearch.client.Client;
import org.opensearch.cluster.ClusterName;
import org.opensearch.cluster.service.ClusterService;
Expand All @@ -55,7 +49,6 @@
import org.opensearch.ml.cluster.DiscoveryNodeHelper;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.MLModel;
import org.opensearch.ml.common.model.MLModelState;
import org.opensearch.ml.common.transport.undeploy.MLUndeployModelNodeResponse;
import org.opensearch.ml.common.transport.undeploy.MLUndeployModelNodesResponse;
import org.opensearch.ml.common.transport.undeploy.MLUndeployModelsRequest;
Expand Down Expand Up @@ -201,129 +194,6 @@ public void setup() throws IOException {
}).when(mlModelManager).getModel(any(), any(), any(), any(), isA(ActionListener.class));
}

public void testDoExecute_undeployModelIndex_WhenNoNodesServiceModel() {
String modelId = "someModelId";
MLModel mlModel = MLModel
.builder()
.user(User.parse(USER_STRING))
.modelGroupId("111")
.version("111")
.name("Test Model")
.modelId(modelId)
.algorithm(FunctionName.BATCH_RCF)
.content("content")
.totalChunks(2)
.isHidden(true)
.build();

doAnswer(invocation -> {
ActionListener<MLModel> listener = invocation.getArgument(4);
listener.onResponse(mlModel);
return null;
}).when(mlModelManager).getModel(any(), any(), any(), any(), isA(ActionListener.class));

doReturn(true).when(transportUndeployModelsAction).isSuperAdminUserWrapper(clusterService, client);

List<MLUndeployModelNodeResponse> responseList = new ArrayList<>();
List<FailedNodeException> failuresList = new ArrayList<>();
MLUndeployModelNodesResponse nodesResponse = new MLUndeployModelNodesResponse(clusterName, responseList, failuresList);

// Send back a response with no nodes associated to the model. Thus, will write back to the model index that its UNDEPLOYED
doAnswer(invocation -> {
ActionListener<MLUndeployModelNodesResponse> listener = invocation.getArgument(2);
listener.onResponse(nodesResponse);
return null;
}).when(client).execute(any(), any(), isA(ActionListener.class));

ArgumentCaptor<BulkRequest> bulkRequestCaptor = ArgumentCaptor.forClass(BulkRequest.class);

// mock the bulk response that can be captured for inspecting the contents of the write to index
doAnswer(invocation -> {
ActionListener<BulkResponse> listener = invocation.getArgument(1);
BulkResponse bulkResponse = mock(BulkResponse.class);
when(bulkResponse.hasFailures()).thenReturn(false);
listener.onResponse(bulkResponse);
return null;
}).when(client).bulk(bulkRequestCaptor.capture(), any(ActionListener.class));

String[] modelIds = new String[] { modelId };
String[] nodeIds = new String[] { "test_node_id1", "test_node_id2" };
MLUndeployModelsRequest request = new MLUndeployModelsRequest(modelIds, nodeIds, null);

transportUndeployModelsAction.doExecute(task, request, actionListener);

BulkRequest capturedBulkRequest = bulkRequestCaptor.getValue();
assertEquals(1, capturedBulkRequest.numberOfActions());
UpdateRequest updateRequest = (UpdateRequest) capturedBulkRequest.requests().get(0);

@SuppressWarnings("unchecked")
Map<String, Object> updateDoc = updateRequest.doc().sourceAsMap();
String modelIdFromBulkRequest = updateRequest.id();
String indexNameFromBulkRequest = updateRequest.index();

assertEquals("Check that the write happened at the model index", ML_MODEL_INDEX, indexNameFromBulkRequest);
assertEquals("Check that the result bulk write hit this specific modelId", modelId, modelIdFromBulkRequest);

assertEquals(MLModelState.UNDEPLOYED.name(), updateDoc.get(MLModel.MODEL_STATE_FIELD));
assertEquals(0, updateDoc.get(MLModel.CURRENT_WORKER_NODE_COUNT_FIELD));
assertEquals(0, updateDoc.get(MLModel.PLANNING_WORKER_NODE_COUNT_FIELD));
assertEquals(List.of(), updateDoc.get(MLModel.PLANNING_WORKER_NODES_FIELD));
assertTrue(updateDoc.containsKey(MLModel.LAST_UPDATED_TIME_FIELD));

verify(actionListener).onResponse(any(MLUndeployModelsResponse.class));
verify(client).bulk(any(BulkRequest.class), any(ActionListener.class));
}

public void testDoExecute_noBulkRequestFired_WhenSomeNodesServiceModel() {
String modelId = "someModelId";
MLModel mlModel = MLModel
.builder()
.user(User.parse(USER_STRING))
.modelGroupId("111")
.version("111")
.name("Test Model")
.modelId(modelId)
.algorithm(FunctionName.BATCH_RCF)
.content("content")
.totalChunks(2)
.isHidden(true)
.build();

doAnswer(invocation -> {
ActionListener<MLModel> listener = invocation.getArgument(4);
listener.onResponse(mlModel);
return null;
}).when(mlModelManager).getModel(any(), any(), any(), any(), isA(ActionListener.class));

doReturn(true).when(transportUndeployModelsAction).isSuperAdminUserWrapper(clusterService, client);

List<MLUndeployModelNodeResponse> responseList = new ArrayList<>();
responseList.add(mock(MLUndeployModelNodeResponse.class));
responseList.add(mock(MLUndeployModelNodeResponse.class));
List<FailedNodeException> failuresList = new ArrayList<>();
failuresList.add(mock(FailedNodeException.class));
failuresList.add(mock(FailedNodeException.class));

MLUndeployModelNodesResponse nodesResponse = new MLUndeployModelNodesResponse(clusterName, responseList, failuresList);

// Send back a response with nodes associated to the model
doAnswer(invocation -> {
ActionListener<MLUndeployModelNodesResponse> listener = invocation.getArgument(2);
listener.onResponse(nodesResponse);
return null;
}).when(client).execute(any(), any(), isA(ActionListener.class));

String[] modelIds = new String[] { modelId };
String[] nodeIds = new String[] { "test_node_id1", "test_node_id2" };
MLUndeployModelsRequest request = new MLUndeployModelsRequest(modelIds, nodeIds, null);

transportUndeployModelsAction.doExecute(task, request, actionListener);

verify(actionListener).onResponse(any(MLUndeployModelsResponse.class));
// Check that no bulk write occurred Since there were nodes servicing the model
verify(client, never()).bulk(any(BulkRequest.class), any(ActionListener.class));
}

@AfterClass
public static void cleanup() {
ThreadPool.terminate(testThreadPool, 500, TimeUnit.MILLISECONDS);
Expand Down Expand Up @@ -351,28 +221,16 @@ public void testHiddenModelSuccess() {
List<MLUndeployModelNodeResponse> responseList = new ArrayList<>();
List<FailedNodeException> failuresList = new ArrayList<>();
MLUndeployModelNodesResponse response = new MLUndeployModelNodesResponse(clusterName, responseList, failuresList);

doAnswer(invocation -> {
ActionListener<MLUndeployModelNodesResponse> listener = invocation.getArgument(2);
listener.onResponse(response);
return null;
}).when(client).execute(any(), any(), isA(ActionListener.class));

// Mock the client.bulk call
doAnswer(invocation -> {
ActionListener<BulkResponse> listener = invocation.getArgument(1);
BulkResponse bulkResponse = mock(BulkResponse.class);
when(bulkResponse.hasFailures()).thenReturn(false);
listener.onResponse(bulkResponse);
return null;
}).when(client).bulk(any(BulkRequest.class), any(ActionListener.class));

doReturn(true).when(transportUndeployModelsAction).isSuperAdminUserWrapper(clusterService, client);
MLUndeployModelsRequest request = new MLUndeployModelsRequest(modelIds, nodeIds, null);
transportUndeployModelsAction.doExecute(task, request, actionListener);

verify(actionListener).onResponse(any(MLUndeployModelsResponse.class));
verify(client).bulk(any(BulkRequest.class), any(ActionListener.class));
}

public void testHiddenModelPermissionError() {
Expand Down Expand Up @@ -426,25 +284,15 @@ public void testDoExecute() {
listener.onResponse(response);
return null;
}).when(client).execute(any(), any(), isA(ActionListener.class));
// Mock the client.bulk call
doAnswer(invocation -> {
ActionListener<BulkResponse> listener = invocation.getArgument(1);
BulkResponse bulkResponse = mock(BulkResponse.class);
when(bulkResponse.hasFailures()).thenReturn(false);
listener.onResponse(bulkResponse);
return null;
}).when(client).bulk(any(BulkRequest.class), any(ActionListener.class));

MLUndeployModelsRequest request = new MLUndeployModelsRequest(modelIds, nodeIds, null);
transportUndeployModelsAction.doExecute(task, request, actionListener);
verify(actionListener).onResponse(any(MLUndeployModelsResponse.class));
verify(client).bulk(any(BulkRequest.class), any(ActionListener.class));
}

public void testDoExecute_modelAccessControl_notEnabled() {
when(modelAccessControlHelper.isModelAccessControlEnabled()).thenReturn(false);
doAnswer(invocation -> {
ActionListener<Boolean> listener = invocation.getArgument(3);
ActionListener<Boolean> listener = invocation.getArgument(4);
listener.onResponse(true);
return null;
}).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), isA(ActionListener.class));
Expand All @@ -463,7 +311,7 @@ public void testDoExecute_modelAccessControl_notEnabled() {
public void testDoExecute_validate_false() {
doAnswer(invocation -> {
ActionListener<Boolean> listener = invocation.getArgument(3);
listener.onResponse(true);
listener.onResponse(false);
return null;
}).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), isA(ActionListener.class));

Expand Down
Loading