diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelProvider.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelProvider.java index 7ed036090fe5b..5cf349b96a4f7 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelProvider.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelProvider.java @@ -1068,74 +1068,88 @@ public void getInferenceStats(String[] modelIds, @Nullable TaskId parentTaskId, delegate, client.admin().cluster()::health ); - }).>andThen((delegate, clusterHealthResponse) -> { - if (clusterHealthResponse.isTimedOut()) { - logger.error( - "getInferenceStats Timed out waiting for index [{}] to be available, this will probably cause the request to fail", - MlStatsIndex.indexPattern() - ); - } - - MultiSearchRequest multiSearchRequest = new MultiSearchRequest(); - Arrays.stream(modelIds).map(TrainedModelProvider::buildStatsSearchRequest).forEach(multiSearchRequest::add); - if (multiSearchRequest.requests().isEmpty()) { - listener.onResponse(Collections.emptyList()); - return; - } - if (parentTaskId != null) { - multiSearchRequest.setParentTask(parentTaskId); - } - executeAsyncWithOrigin( + }) + .>andThen( + client.threadPool().executor(MachineLearning.UTILITY_THREAD_POOL_NAME), client.threadPool().getThreadContext(), - ML_ORIGIN, - multiSearchRequest, - ActionListener.wrap(responses -> { - List allStats = new ArrayList<>(modelIds.length); - int modelIndex = 0; - assert responses.getResponses().length == modelIds.length - : "mismatch between search response size and models requested"; - for (MultiSearchResponse.Item response : responses.getResponses()) { - if (response.isFailure()) { - if (ExceptionsHelper.unwrapCause(response.getFailure()) instanceof ResourceNotFoundException) { - modelIndex++; - continue; - } - logger.error( - () -> "[" + Strings.arrayToCommaDelimitedString(modelIds) + "] search failed for models", - response.getFailure() - ); - listener.onFailure( - ExceptionsHelper.serverError( - "Searching for stats for models [{}] failed", - response.getFailure(), - Strings.arrayToCommaDelimitedString(modelIds) - ) - ); - return; - } - try { - InferenceStats inferenceStats = handleMultiNodeStatsResponse(response.getResponse(), modelIds[modelIndex++]); - if (inferenceStats != null) { - allStats.add(inferenceStats); - } - } catch (Exception e) { - listener.onFailure(e); - return; - } + (delegate, clusterHealthResponse) -> { + if (clusterHealthResponse.isTimedOut()) { + logger.error( + "getInferenceStats Timed out waiting for index [{}] to be available, " + + "this will probably cause the request to fail", + MlStatsIndex.indexPattern() + ); } - listener.onResponse(allStats); - }, e -> { - Throwable unwrapped = ExceptionsHelper.unwrapCause(e); - if (unwrapped instanceof ResourceNotFoundException) { - listener.onResponse(Collections.emptyList()); + + MultiSearchRequest multiSearchRequest = new MultiSearchRequest(); + Arrays.stream(modelIds).map(TrainedModelProvider::buildStatsSearchRequest).forEach(multiSearchRequest::add); + if (multiSearchRequest.requests().isEmpty()) { + delegate.onResponse(Collections.emptyList()); return; } - listener.onFailure((Exception) unwrapped); - }), - client::multiSearch - ); + if (parentTaskId != null) { + multiSearchRequest.setParentTask(parentTaskId); + } + executeAsyncWithOrigin( + client.threadPool().getThreadContext(), + ML_ORIGIN, + multiSearchRequest, + ActionListener.wrap(responses -> { + List allStats = new ArrayList<>(modelIds.length); + int modelIndex = 0; + assert responses.getResponses().length == modelIds.length + : "mismatch between search response size and models requested"; + for (MultiSearchResponse.Item response : responses.getResponses()) { + if (response.isFailure()) { + if (ExceptionsHelper.unwrapCause(response.getFailure()) instanceof ResourceNotFoundException) { + modelIndex++; + continue; + } + logger.error( + () -> "[" + Strings.arrayToCommaDelimitedString(modelIds) + "] search failed for models", + response.getFailure() + ); + delegate.onFailure( + ExceptionsHelper.serverError( + "Searching for stats for models [{}] failed", + response.getFailure(), + Strings.arrayToCommaDelimitedString(modelIds) + ) + ); + return; + } + try { + InferenceStats inferenceStats = handleMultiNodeStatsResponse( + response.getResponse(), + modelIds[modelIndex++] + ); + if (inferenceStats != null) { + allStats.add(inferenceStats); + } + } catch (Exception e) { + delegate.onFailure(e); + return; + } + } + delegate.onResponse(allStats); + }, e -> { + Throwable unwrapped = ExceptionsHelper.unwrapCause(e); + if (unwrapped instanceof ResourceNotFoundException) { + delegate.onResponse(Collections.emptyList()); + return; + } + delegate.onFailure((Exception) unwrapped); + }), + client::multiSearch + ); - }).addListener(listener); + } + ) + .addListener( + listener, + client.threadPool().executor(MachineLearning.UTILITY_THREAD_POOL_NAME), + client.threadPool().getThreadContext() + ); } private static SearchRequest buildStatsSearchRequest(String modelId) {