[ML] Inference duration and error metrics
Add `es.inference.requests.time` metric around `infer` API.

As recommended by the OTel spec, errors are determined by the
presence or absence of the `error.type` attribute in the metric.
`error.type` will be the HTTP status code (as a string) if one is
available; otherwise it will be the simple name of the exception
(e.g. `NullPointerException`).
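
For illustration, a sketch of the attribute maps this convention produces
(the keys mirror the errorAttributes logic in InferenceStats further down;
the 429 status is just an example):

import java.util.Map;

// Illustrative only: the error-related attributes recorded per outcome.
// The service/task_type/model_id attributes are merged in as well (omitted here).
class ErrorAttributeExamples {
    public static void main(String[] args) {
        // Success: status_code is present, error.type is absent.
        Map<String, Object> success = Map.of("status_code", 200);

        // HTTP failure (ElasticsearchStatusException): error.type is the
        // status code rendered as a string, e.g. a 429 rate limit.
        Map<String, Object> httpFailure = Map.of("status_code", 429, "error.type", "429");

        // Any other exception: error.type is the exception's simple class name.
        Map<String, Object> unexpected = Map.of("error.type", "NullPointerException");

        System.out.println(success + " " + httpFailure + " " + unexpected);
    }
}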

Additional notes:
- ApmInferenceStats is merged into InferenceStats. Originally we planned
  to have multiple implementations, but now we're only using APM.
- The request count is now always recorded, even when loading the endpoint
  configuration fails.
- Added a hook in streaming for cancel messages, so we can close the
  metrics when a user cancels the stream (see the sketch below).
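
A self-contained sketch of that cancel hook (the real change is in
DelegatingProcessor below; this stand-in class and its names are
illustrative only): the downstream subscription's cancel() fires
onCancel() exactly once, which the metrics wrapper uses to close out
the duration histogram.

import java.util.concurrent.Flow;
import java.util.concurrent.atomic.AtomicBoolean;

abstract class CancellableProcessor<T> implements Flow.Processor<T, T> {
    private final AtomicBoolean isClosed = new AtomicBoolean(false);
    private Flow.Subscription upstream;
    private Flow.Subscriber<? super T> downstream;

    // Override to record metrics when the user cancels the stream.
    protected void onCancel() {}

    @Override
    public void subscribe(Flow.Subscriber<? super T> subscriber) {
        downstream = subscriber;
        subscriber.onSubscribe(new Flow.Subscription() {
            @Override
            public void request(long n) {
                if (upstream != null) {
                    upstream.request(n);
                }
            }

            @Override
            public void cancel() {
                // Close at most once: propagate upstream, then fire the hook.
                if (isClosed.compareAndSet(false, true) && upstream != null) {
                    upstream.cancel();
                    onCancel();
                }
            }
        });
    }

    @Override
    public void onSubscribe(Flow.Subscription subscription) {
        upstream = subscription;
    }

    @Override
    public void onNext(T item) {
        downstream.onNext(item);
    }

    @Override
    public void onError(Throwable throwable) {
        downstream.onError(throwable);
    }

    @Override
    public void onComplete() {
        downstream.onComplete();
    }
}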
prwhelan committed Oct 29, 2024
1 parent 9adbebb commit c83298d
Showing 10 changed files with 815 additions and 151 deletions.

InferencePlugin.java
@@ -98,7 +98,6 @@
import org.elasticsearch.xpack.inference.services.ibmwatsonx.IbmWatsonxService;
import org.elasticsearch.xpack.inference.services.mistral.MistralService;
import org.elasticsearch.xpack.inference.services.openai.OpenAiService;
-import org.elasticsearch.xpack.inference.telemetry.ApmInferenceStats;
import org.elasticsearch.xpack.inference.telemetry.InferenceStats;

import java.util.ArrayList;
@@ -234,7 +233,7 @@ public Collection<?> createComponents(PluginServices services) {
shardBulkInferenceActionFilter.set(actionFilter);

var meterRegistry = services.telemetryProvider().getMeterRegistry();
-        var stats = new PluginComponentBinding<>(InferenceStats.class, ApmInferenceStats.create(meterRegistry));
var stats = new PluginComponentBinding<>(InferenceStats.class, InferenceStats.create(meterRegistry));

return List.of(modelRegistry, registry, httpClientManager, stats);
}

TransportInferenceAction.java
@@ -13,6 +13,7 @@
import org.elasticsearch.action.support.HandledTransportAction;
import org.elasticsearch.common.util.concurrent.EsExecutors;
import org.elasticsearch.common.xcontent.ChunkedToXContent;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.inference.InferenceService;
import org.elasticsearch.inference.InferenceServiceRegistry;
import org.elasticsearch.inference.InferenceServiceResults;
@@ -25,20 +26,21 @@
import org.elasticsearch.transport.TransportService;
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
import org.elasticsearch.xpack.inference.action.task.StreamingTaskManager;
import org.elasticsearch.xpack.inference.common.DelegatingProcessor;
import org.elasticsearch.xpack.inference.registry.ModelRegistry;
import org.elasticsearch.xpack.inference.telemetry.InferenceStats;
import org.elasticsearch.xpack.inference.telemetry.InferenceTimer;

import java.util.Set;
import java.util.stream.Collectors;

import static org.elasticsearch.core.Strings.format;
import static org.elasticsearch.xpack.inference.telemetry.InferenceStats.modelAttributes;
import static org.elasticsearch.xpack.inference.telemetry.InferenceStats.responseAttributes;

public class TransportInferenceAction extends HandledTransportAction<InferenceAction.Request, InferenceAction.Response> {
private static final String STREAMING_INFERENCE_TASK_TYPE = "streaming_inference";
private static final String STREAMING_TASK_ACTION = "xpack/inference/streaming_inference[n]";

private static final Set<Class<? extends InferenceService>> supportsStreaming = Set.of();

private final ModelRegistry modelRegistry;
private final InferenceServiceRegistry serviceRegistry;
private final InferenceStats inferenceStats;
@@ -62,17 +64,22 @@ public TransportInferenceAction(

@Override
protected void doExecute(Task task, InferenceAction.Request request, ActionListener<InferenceAction.Response> listener) {
var timer = InferenceTimer.start();

-        ActionListener<UnparsedModel> getModelListener = listener.delegateFailureAndWrap((delegate, unparsedModel) -> {
var getModelListener = ActionListener.wrap((UnparsedModel unparsedModel) -> {
var service = serviceRegistry.getService(unparsedModel.service());
if (service.isEmpty()) {
-                listener.onFailure(unknownServiceException(unparsedModel.service(), request.getInferenceEntityId()));
var e = unknownServiceException(unparsedModel.service(), request.getInferenceEntityId());
listener.onFailure(e);
recordMetrics(unparsedModel, timer, e);
return;
}

if (request.getTaskType().isAnyOrSame(unparsedModel.taskType()) == false) {
// not the wildcard task type and not the model task type
-                listener.onFailure(incompatibleTaskTypeException(request.getTaskType(), unparsedModel.taskType()));
var e = incompatibleTaskTypeException(request.getTaskType(), unparsedModel.taskType());
listener.onFailure(e);
recordMetrics(unparsedModel, timer, e);
return;
}

@@ -83,20 +90,57 @@ protected void doExecute(Task task, InferenceAction.Request request, ActionListener<InferenceAction.Response> listener) {
unparsedModel.settings(),
unparsedModel.secrets()
);
-            inferOnService(model, request, service.get(), delegate);
inferOnServiceWithMetrics(model, request, service.get(), timer, listener);
}, e -> {
listener.onFailure(e);
inferenceStats.inferenceDuration().record(timer.elapsedMillis(), responseAttributes(e));
});

modelRegistry.getModelWithSecrets(request.getInferenceEntityId(), getModelListener);
}

-    private void inferOnService(
private void recordMetrics(UnparsedModel model, InferenceTimer timer, @Nullable Throwable t) {
inferenceStats.inferenceDuration().record(timer.elapsedMillis(), responseAttributes(model, t));
}

private void inferOnServiceWithMetrics(
Model model,
InferenceAction.Request request,
InferenceService service,
InferenceTimer timer,
ActionListener<InferenceAction.Response> listener
) {
inferenceStats.requestCount().incrementBy(1, modelAttributes(model));
inferOnService(model, request, service, ActionListener.wrap(inferenceResults -> {
if (request.isStreaming()) {
var taskProcessor = streamingTaskManager.<ChunkedToXContent>create(STREAMING_INFERENCE_TASK_TYPE, STREAMING_TASK_ACTION);
inferenceResults.publisher().subscribe(taskProcessor);

var instrumentedStream = new PublisherWithMetrics(timer, model);
taskProcessor.subscribe(instrumentedStream);

listener.onResponse(new InferenceAction.Response(inferenceResults, instrumentedStream));
} else {
listener.onResponse(new InferenceAction.Response(inferenceResults));
recordMetrics(model, timer, null);
}
}, e -> {
listener.onFailure(e);
recordMetrics(model, timer, e);
}));
}

private void recordMetrics(Model model, InferenceTimer timer, @Nullable Throwable t) {
inferenceStats.inferenceDuration().record(timer.elapsedMillis(), responseAttributes(model, t));
}

private void inferOnService(
Model model,
InferenceAction.Request request,
InferenceService service,
ActionListener<InferenceServiceResults> listener
) {
if (request.isStreaming() == false || service.canStream(request.getTaskType())) {
-            inferenceStats.incrementRequestCount(model);
service.infer(
model,
request.getQuery(),
@@ -105,7 +149,7 @@ private void inferOnService(
request.getTaskSettings(),
request.getInputType(),
request.getInferenceTimeout(),
-                createListener(request, listener)
listener
);
} else {
listener.onFailure(unsupportedStreamingTaskException(request, service));
@@ -133,20 +177,6 @@ private ElasticsearchStatusException unsupportedStreamingTaskException(Inference
}
}

-    private ActionListener<InferenceServiceResults> createListener(
-        InferenceAction.Request request,
-        ActionListener<InferenceAction.Response> listener
-    ) {
-        if (request.isStreaming()) {
-            return listener.delegateFailureAndWrap((l, inferenceResults) -> {
-                var taskProcessor = streamingTaskManager.<ChunkedToXContent>create(STREAMING_INFERENCE_TASK_TYPE, STREAMING_TASK_ACTION);
-                inferenceResults.publisher().subscribe(taskProcessor);
-                l.onResponse(new InferenceAction.Response(inferenceResults, taskProcessor));
-            });
-        }
-        return listener.delegateFailureAndWrap((l, inferenceResults) -> l.onResponse(new InferenceAction.Response(inferenceResults)));
-    }

private static ElasticsearchStatusException unknownServiceException(String service, String inferenceId) {
return new ElasticsearchStatusException("Unknown service [{}] for model [{}]. ", RestStatus.BAD_REQUEST, service, inferenceId);
}
@@ -160,4 +190,46 @@ private static ElasticsearchStatusException incompatibleTaskTypeException(TaskTy
);
}

private class PublisherWithMetrics extends DelegatingProcessor<ChunkedToXContent, ChunkedToXContent> {
private final InferenceTimer timer;
private final Model model;

private PublisherWithMetrics(InferenceTimer timer, Model model) {
this.timer = timer;
this.model = model;
}

@Override
protected void next(ChunkedToXContent item) {
downstream().onNext(item);
}

@Override
public void onError(Throwable throwable) {
try {
super.onError(throwable);
} finally {
recordMetrics(model, timer, throwable);
}
}

@Override
protected void onCancel() {
try {
super.onCancel();
} finally {
recordMetrics(model, timer, null);
}
}

@Override
public void onComplete() {
try {
super.onComplete();
} finally {
recordMetrics(model, timer, null);
}
}
}

}
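
InferenceTimer's own source is not shown in this commit; a minimal sketch
of the API used above, assuming it simply captures a monotonic start time
(the class name here is a hypothetical stand-in):

import java.util.concurrent.TimeUnit;

final class InferenceTimerSketch {
    private final long startNanos = System.nanoTime();

    // Mirrors InferenceTimer.start() as called in doExecute above.
    static InferenceTimerSketch start() {
        return new InferenceTimerSketch();
    }

    // Mirrors timer.elapsedMillis() as passed to the duration histogram.
    long elapsedMillis() {
        return TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - startNanos);
    }
}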

DelegatingProcessor.java
@@ -61,11 +61,14 @@ public void request(long n) {
public void cancel() {
if (isClosed.compareAndSet(false, true) && upstream != null) {
upstream.cancel();
onCancel();
}
}
};
}

protected void onCancel() {}

@Override
public void onSubscribe(Flow.Subscription subscription) {
if (upstream != null) {

ApmInferenceStats.java: this file was deleted (its functionality was merged into InferenceStats, per the commit notes).

InferenceStats.java
@@ -7,15 +7,87 @@

package org.elasticsearch.xpack.inference.telemetry;

import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.inference.Model;
import org.elasticsearch.inference.UnparsedModel;
import org.elasticsearch.telemetry.metric.LongCounter;
import org.elasticsearch.telemetry.metric.LongHistogram;
import org.elasticsearch.telemetry.metric.MeterRegistry;

-public interface InferenceStats {
-
-    /**
-     * Increment the counter for a particular value in a thread safe manner.
-     * @param model the model to increment request count for
-     */
-    void incrementRequestCount(Model model);
-
-    InferenceStats NOOP = model -> {};

import java.util.Map;
import java.util.Objects;
import java.util.stream.Collectors;
import java.util.stream.Stream;

import static java.util.Map.entry;
import static java.util.stream.Stream.concat;

public record InferenceStats(LongCounter requestCount, LongHistogram inferenceDuration) {

public InferenceStats {
Objects.requireNonNull(requestCount);
Objects.requireNonNull(inferenceDuration);
}

public static InferenceStats create(MeterRegistry meterRegistry) {
return new InferenceStats(
meterRegistry.registerLongCounter(
"es.inference.requests.count.total",
"Inference API request counts for a particular service, task type, model ID",
"operations"
),
meterRegistry.registerLongHistogram(
"es.inference.requests.time",
"Inference API request counts for a particular service, task type, model ID",
"ms"
)
);
}

public static Map<String, Object> modelAttributes(Model model) {
return toMap(modelAttributeEntries(model));
}

private static Stream<Map.Entry<String, Object>> modelAttributeEntries(Model model) {
var stream = Stream.<Map.Entry<String, Object>>builder()
.add(entry("service", model.getConfigurations().getService()))
.add(entry("task_type", model.getTaskType().toString()));
if (model.getServiceSettings().modelId() != null) {
stream.add(entry("model_id", model.getServiceSettings().modelId()));
}
return stream.build();
}

private static Map<String, Object> toMap(Stream<Map.Entry<String, Object>> stream) {
return stream.collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue));
}

public static Map<String, Object> responseAttributes(Model model, @Nullable Throwable t) {
return toMap(concat(modelAttributeEntries(model), errorAttributes(t)));
}

public static Map<String, Object> responseAttributes(UnparsedModel model, @Nullable Throwable t) {
var unknownModelAttributes = Stream.<Map.Entry<String, Object>>builder()
.add(entry("service", model.service()))
.add(entry("task_type", model.taskType().toString()))
.build();

return toMap(concat(unknownModelAttributes, errorAttributes(t)));
}

public static Map<String, Object> responseAttributes(@Nullable Throwable t) {
return toMap(errorAttributes(t));
}

private static Stream<Map.Entry<String, Object>> errorAttributes(@Nullable Throwable t) {
return switch (t) {
case null -> Stream.of(entry("status_code", 200));
case ElasticsearchStatusException ese -> Stream.<Map.Entry<String, Object>>builder()
.add(entry("status_code", ese.status().getStatus()))
.add(entry("error.type", String.valueOf(ese.status().getStatus())))
.build();
default -> Stream.of(entry("error.type", t.getClass().getSimpleName()));
};
}
}
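
A sketch of how the pieces fit together at a call site (the method and
variable names here are assumptions; the accessors and static helpers are
the ones defined above):

import org.elasticsearch.inference.Model;
import org.elasticsearch.telemetry.metric.MeterRegistry;
import org.elasticsearch.xpack.inference.telemetry.InferenceStats;
import org.elasticsearch.xpack.inference.telemetry.InferenceTimer;

class InferenceStatsUsageSketch {
    // Count the request up front, then record the duration once the call
    // resolves; a null throwable yields status_code=200, a non-null one
    // adds error.type per the OTel convention in the commit message.
    void instrument(MeterRegistry meterRegistry, Model model, InferenceTimer timer, Throwable failureOrNull) {
        var stats = InferenceStats.create(meterRegistry);
        stats.requestCount().incrementBy(1, InferenceStats.modelAttributes(model));
        stats.inferenceDuration().record(timer.elapsedMillis(), InferenceStats.responseAttributes(model, failureOrNull));
    }
}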

(The remaining changed files in this commit are not shown.)