Skip to content

Commit

Permalink
improvements from review
Browse files Browse the repository at this point in the history
  • Loading branch information
maxhniebergall committed Oct 31, 2024
1 parent a2e763b commit d97e769
Showing 1 changed file with 77 additions and 63 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -1068,74 +1068,88 @@ public void getInferenceStats(String[] modelIds, @Nullable TaskId parentTaskId,
delegate,
client.admin().cluster()::health
);
}).<List<InferenceStats>>andThen((delegate, clusterHealthResponse) -> {
if (clusterHealthResponse.isTimedOut()) {
logger.error(
"getInferenceStats Timed out waiting for index [{}] to be available, this will probably cause the request to fail",
MlStatsIndex.indexPattern()
);
}

MultiSearchRequest multiSearchRequest = new MultiSearchRequest();
Arrays.stream(modelIds).map(TrainedModelProvider::buildStatsSearchRequest).forEach(multiSearchRequest::add);
if (multiSearchRequest.requests().isEmpty()) {
listener.onResponse(Collections.emptyList());
return;
}
if (parentTaskId != null) {
multiSearchRequest.setParentTask(parentTaskId);
}
executeAsyncWithOrigin(
})
.<List<InferenceStats>>andThen(
client.threadPool().executor(MachineLearning.UTILITY_THREAD_POOL_NAME),
client.threadPool().getThreadContext(),
ML_ORIGIN,
multiSearchRequest,
ActionListener.<MultiSearchResponse>wrap(responses -> {
List<InferenceStats> allStats = new ArrayList<>(modelIds.length);
int modelIndex = 0;
assert responses.getResponses().length == modelIds.length
: "mismatch between search response size and models requested";
for (MultiSearchResponse.Item response : responses.getResponses()) {
if (response.isFailure()) {
if (ExceptionsHelper.unwrapCause(response.getFailure()) instanceof ResourceNotFoundException) {
modelIndex++;
continue;
}
logger.error(
() -> "[" + Strings.arrayToCommaDelimitedString(modelIds) + "] search failed for models",
response.getFailure()
);
listener.onFailure(
ExceptionsHelper.serverError(
"Searching for stats for models [{}] failed",
response.getFailure(),
Strings.arrayToCommaDelimitedString(modelIds)
)
);
return;
}
try {
InferenceStats inferenceStats = handleMultiNodeStatsResponse(response.getResponse(), modelIds[modelIndex++]);
if (inferenceStats != null) {
allStats.add(inferenceStats);
}
} catch (Exception e) {
listener.onFailure(e);
return;
}
(delegate, clusterHealthResponse) -> {
if (clusterHealthResponse.isTimedOut()) {
logger.error(
"getInferenceStats Timed out waiting for index [{}] to be available, "
+ "this will probably cause the request to fail",
MlStatsIndex.indexPattern()
);
}
listener.onResponse(allStats);
}, e -> {
Throwable unwrapped = ExceptionsHelper.unwrapCause(e);
if (unwrapped instanceof ResourceNotFoundException) {
listener.onResponse(Collections.emptyList());

MultiSearchRequest multiSearchRequest = new MultiSearchRequest();
Arrays.stream(modelIds).map(TrainedModelProvider::buildStatsSearchRequest).forEach(multiSearchRequest::add);
if (multiSearchRequest.requests().isEmpty()) {
delegate.onResponse(Collections.emptyList());
return;
}
listener.onFailure((Exception) unwrapped);
}),
client::multiSearch
);
if (parentTaskId != null) {
multiSearchRequest.setParentTask(parentTaskId);
}
executeAsyncWithOrigin(
client.threadPool().getThreadContext(),
ML_ORIGIN,
multiSearchRequest,
ActionListener.<MultiSearchResponse>wrap(responses -> {
List<InferenceStats> allStats = new ArrayList<>(modelIds.length);
int modelIndex = 0;
assert responses.getResponses().length == modelIds.length
: "mismatch between search response size and models requested";
for (MultiSearchResponse.Item response : responses.getResponses()) {
if (response.isFailure()) {
if (ExceptionsHelper.unwrapCause(response.getFailure()) instanceof ResourceNotFoundException) {
modelIndex++;
continue;
}
logger.error(
() -> "[" + Strings.arrayToCommaDelimitedString(modelIds) + "] search failed for models",
response.getFailure()
);
delegate.onFailure(
ExceptionsHelper.serverError(
"Searching for stats for models [{}] failed",
response.getFailure(),
Strings.arrayToCommaDelimitedString(modelIds)
)
);
return;
}
try {
InferenceStats inferenceStats = handleMultiNodeStatsResponse(
response.getResponse(),
modelIds[modelIndex++]
);
if (inferenceStats != null) {
allStats.add(inferenceStats);
}
} catch (Exception e) {
delegate.onFailure(e);
return;
}
}
delegate.onResponse(allStats);
}, e -> {
Throwable unwrapped = ExceptionsHelper.unwrapCause(e);
if (unwrapped instanceof ResourceNotFoundException) {
delegate.onResponse(Collections.emptyList());
return;
}
delegate.onFailure((Exception) unwrapped);
}),
client::multiSearch
);

}).addListener(listener);
}
)
.addListener(
listener,
client.threadPool().executor(MachineLearning.UTILITY_THREAD_POOL_NAME),
client.threadPool().getThreadContext()
);
}

private static SearchRequest buildStatsSearchRequest(String modelId) {
Expand Down

0 comments on commit d97e769

Please sign in to comment.