diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java index ebbf1e59e8b1f..7ee828702ba07 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java @@ -98,7 +98,6 @@ import org.elasticsearch.xpack.inference.services.ibmwatsonx.IbmWatsonxService; import org.elasticsearch.xpack.inference.services.mistral.MistralService; import org.elasticsearch.xpack.inference.services.openai.OpenAiService; -import org.elasticsearch.xpack.inference.telemetry.ApmInferenceStats; import org.elasticsearch.xpack.inference.telemetry.InferenceStats; import java.util.ArrayList; @@ -234,7 +233,7 @@ public Collection createComponents(PluginServices services) { shardBulkInferenceActionFilter.set(actionFilter); var meterRegistry = services.telemetryProvider().getMeterRegistry(); - var stats = new PluginComponentBinding<>(InferenceStats.class, ApmInferenceStats.create(meterRegistry)); + var stats = new PluginComponentBinding<>(InferenceStats.class, InferenceStats.create(meterRegistry)); return List.of(modelRegistry, registry, httpClientManager, stats); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportInferenceAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportInferenceAction.java index e046e2aad463b..d109a170a57c4 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportInferenceAction.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportInferenceAction.java @@ -13,6 +13,7 @@ import org.elasticsearch.action.support.HandledTransportAction; import org.elasticsearch.common.util.concurrent.EsExecutors; import 
org.elasticsearch.common.xcontent.ChunkedToXContent; +import org.elasticsearch.core.Nullable; import org.elasticsearch.inference.InferenceService; import org.elasticsearch.inference.InferenceServiceRegistry; import org.elasticsearch.inference.InferenceServiceResults; @@ -25,20 +26,21 @@ import org.elasticsearch.transport.TransportService; import org.elasticsearch.xpack.core.inference.action.InferenceAction; import org.elasticsearch.xpack.inference.action.task.StreamingTaskManager; +import org.elasticsearch.xpack.inference.common.DelegatingProcessor; import org.elasticsearch.xpack.inference.registry.ModelRegistry; import org.elasticsearch.xpack.inference.telemetry.InferenceStats; +import org.elasticsearch.xpack.inference.telemetry.InferenceTimer; -import java.util.Set; import java.util.stream.Collectors; import static org.elasticsearch.core.Strings.format; +import static org.elasticsearch.xpack.inference.telemetry.InferenceStats.modelAttributes; +import static org.elasticsearch.xpack.inference.telemetry.InferenceStats.responseAttributes; public class TransportInferenceAction extends HandledTransportAction { private static final String STREAMING_INFERENCE_TASK_TYPE = "streaming_inference"; private static final String STREAMING_TASK_ACTION = "xpack/inference/streaming_inference[n]"; - private static final Set> supportsStreaming = Set.of(); - private final ModelRegistry modelRegistry; private final InferenceServiceRegistry serviceRegistry; private final InferenceStats inferenceStats; @@ -62,17 +64,22 @@ public TransportInferenceAction( @Override protected void doExecute(Task task, InferenceAction.Request request, ActionListener listener) { + var timer = InferenceTimer.start(); - ActionListener getModelListener = listener.delegateFailureAndWrap((delegate, unparsedModel) -> { + var getModelListener = ActionListener.wrap((UnparsedModel unparsedModel) -> { var service = serviceRegistry.getService(unparsedModel.service()); if (service.isEmpty()) { - 
listener.onFailure(unknownServiceException(unparsedModel.service(), request.getInferenceEntityId())); + var e = unknownServiceException(unparsedModel.service(), request.getInferenceEntityId()); + listener.onFailure(e); + recordMetrics(unparsedModel, timer, e); return; } if (request.getTaskType().isAnyOrSame(unparsedModel.taskType()) == false) { // not the wildcard task type and not the model task type - listener.onFailure(incompatibleTaskTypeException(request.getTaskType(), unparsedModel.taskType())); + var e = incompatibleTaskTypeException(request.getTaskType(), unparsedModel.taskType()); + listener.onFailure(e); + recordMetrics(unparsedModel, timer, e); return; } @@ -83,20 +90,57 @@ protected void doExecute(Task task, InferenceAction.Request request, ActionListe unparsedModel.settings(), unparsedModel.secrets() ); - inferOnService(model, request, service.get(), delegate); + inferOnServiceWithMetrics(model, request, service.get(), timer, listener); + }, e -> { + listener.onFailure(e); + inferenceStats.inferenceDuration().record(timer.elapsedMillis(), responseAttributes(e)); }); modelRegistry.getModelWithSecrets(request.getInferenceEntityId(), getModelListener); } - private void inferOnService( + private void recordMetrics(UnparsedModel model, InferenceTimer timer, @Nullable Throwable t) { + inferenceStats.inferenceDuration().record(timer.elapsedMillis(), responseAttributes(model, t)); + } + + private void inferOnServiceWithMetrics( Model model, InferenceAction.Request request, InferenceService service, + InferenceTimer timer, ActionListener listener + ) { + inferenceStats.requestCount().incrementBy(1, modelAttributes(model)); + inferOnService(model, request, service, ActionListener.wrap(inferenceResults -> { + if (request.isStreaming()) { + var taskProcessor = streamingTaskManager.create(STREAMING_INFERENCE_TASK_TYPE, STREAMING_TASK_ACTION); + inferenceResults.publisher().subscribe(taskProcessor); + + var instrumentedStream = new PublisherWithMetrics(timer, 
model); + taskProcessor.subscribe(instrumentedStream); + + listener.onResponse(new InferenceAction.Response(inferenceResults, instrumentedStream)); + } else { + listener.onResponse(new InferenceAction.Response(inferenceResults)); + recordMetrics(model, timer, null); + } + }, e -> { + listener.onFailure(e); + recordMetrics(model, timer, e); + })); + } + + private void recordMetrics(Model model, InferenceTimer timer, @Nullable Throwable t) { + inferenceStats.inferenceDuration().record(timer.elapsedMillis(), responseAttributes(model, t)); + } + + private void inferOnService( + Model model, + InferenceAction.Request request, + InferenceService service, + ActionListener listener ) { if (request.isStreaming() == false || service.canStream(request.getTaskType())) { - inferenceStats.incrementRequestCount(model); service.infer( model, request.getQuery(), @@ -105,7 +149,7 @@ private void inferOnService( request.getTaskSettings(), request.getInputType(), request.getInferenceTimeout(), - createListener(request, listener) + listener ); } else { listener.onFailure(unsupportedStreamingTaskException(request, service)); @@ -133,20 +177,6 @@ private ElasticsearchStatusException unsupportedStreamingTaskException(Inference } } - private ActionListener createListener( - InferenceAction.Request request, - ActionListener listener - ) { - if (request.isStreaming()) { - return listener.delegateFailureAndWrap((l, inferenceResults) -> { - var taskProcessor = streamingTaskManager.create(STREAMING_INFERENCE_TASK_TYPE, STREAMING_TASK_ACTION); - inferenceResults.publisher().subscribe(taskProcessor); - l.onResponse(new InferenceAction.Response(inferenceResults, taskProcessor)); - }); - } - return listener.delegateFailureAndWrap((l, inferenceResults) -> l.onResponse(new InferenceAction.Response(inferenceResults))); - } - private static ElasticsearchStatusException unknownServiceException(String service, String inferenceId) { return new ElasticsearchStatusException("Unknown service [{}] for model 
[{}]. ", RestStatus.BAD_REQUEST, service, inferenceId); } @@ -160,4 +190,46 @@ private static ElasticsearchStatusException incompatibleTaskTypeException(TaskTy ); } + private class PublisherWithMetrics extends DelegatingProcessor { + private final InferenceTimer timer; + private final Model model; + + private PublisherWithMetrics(InferenceTimer timer, Model model) { + this.timer = timer; + this.model = model; + } + + @Override + protected void next(ChunkedToXContent item) { + downstream().onNext(item); + } + + @Override + public void onError(Throwable throwable) { + try { + super.onError(throwable); + } finally { + recordMetrics(model, timer, throwable); + } + } + + @Override + protected void onCancel() { + try { + super.onCancel(); + } finally { + recordMetrics(model, timer, null); + } + } + + @Override + public void onComplete() { + try { + super.onComplete(); + } finally { + recordMetrics(model, timer, null); + } + } + } + } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/common/DelegatingProcessor.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/common/DelegatingProcessor.java index fc2d890dd89e6..03e794e42c3a2 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/common/DelegatingProcessor.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/common/DelegatingProcessor.java @@ -61,11 +61,14 @@ public void request(long n) { public void cancel() { if (isClosed.compareAndSet(false, true) && upstream != null) { upstream.cancel(); + onCancel(); } } }; } + protected void onCancel() {} + @Override public void onSubscribe(Flow.Subscription subscription) { if (upstream != null) { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/telemetry/ApmInferenceStats.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/telemetry/ApmInferenceStats.java deleted file mode 100644 index 
ae14a0792dead..0000000000000 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/telemetry/ApmInferenceStats.java +++ /dev/null @@ -1,49 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. - */ - -package org.elasticsearch.xpack.inference.telemetry; - -import org.elasticsearch.inference.Model; -import org.elasticsearch.telemetry.metric.LongCounter; -import org.elasticsearch.telemetry.metric.MeterRegistry; - -import java.util.HashMap; -import java.util.Objects; - -public class ApmInferenceStats implements InferenceStats { - private final LongCounter inferenceAPMRequestCounter; - - public ApmInferenceStats(LongCounter inferenceAPMRequestCounter) { - this.inferenceAPMRequestCounter = Objects.requireNonNull(inferenceAPMRequestCounter); - } - - @Override - public void incrementRequestCount(Model model) { - var service = model.getConfigurations().getService(); - var taskType = model.getTaskType(); - var modelId = model.getServiceSettings().modelId(); - - var attributes = new HashMap(5); - attributes.put("service", service); - attributes.put("task_type", taskType.toString()); - if (modelId != null) { - attributes.put("model_id", modelId); - } - - inferenceAPMRequestCounter.incrementBy(1, attributes); - } - - public static ApmInferenceStats create(MeterRegistry meterRegistry) { - return new ApmInferenceStats( - meterRegistry.registerLongCounter( - "es.inference.requests.count.total", - "Inference API request counts for a particular service, task type, model ID", - "operations" - ) - ); - } -} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/telemetry/InferenceStats.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/telemetry/InferenceStats.java index d080e818e45fc..afdbc21bae319 
100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/telemetry/InferenceStats.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/telemetry/InferenceStats.java @@ -7,15 +7,87 @@ package org.elasticsearch.xpack.inference.telemetry; +import org.elasticsearch.ElasticsearchStatusException; +import org.elasticsearch.core.Nullable; import org.elasticsearch.inference.Model; +import org.elasticsearch.inference.UnparsedModel; +import org.elasticsearch.telemetry.metric.LongCounter; +import org.elasticsearch.telemetry.metric.LongHistogram; +import org.elasticsearch.telemetry.metric.MeterRegistry; -public interface InferenceStats { +import java.util.Map; +import java.util.Objects; +import java.util.stream.Collectors; +import java.util.stream.Stream; - /** - * Increment the counter for a particular value in a thread safe manner. - * @param model the model to increment request count for - */ - void incrementRequestCount(Model model); +import static java.util.Map.entry; +import static java.util.stream.Stream.concat; - InferenceStats NOOP = model -> {}; +public record InferenceStats(LongCounter requestCount, LongHistogram inferenceDuration) { + + public InferenceStats { + Objects.requireNonNull(requestCount); + Objects.requireNonNull(inferenceDuration); + } + + public static InferenceStats create(MeterRegistry meterRegistry) { + return new InferenceStats( + meterRegistry.registerLongCounter( + "es.inference.requests.count.total", + "Inference API request counts for a particular service, task type, model ID", + "operations" + ), + meterRegistry.registerLongHistogram( + "es.inference.requests.time", + "Inference API request counts for a particular service, task type, model ID", + "ms" + ) + ); + } + + public static Map modelAttributes(Model model) { + return toMap(modelAttributeEntries(model)); + } + + private static Stream> modelAttributeEntries(Model model) { + var stream = Stream.>builder() + .add(entry("service", 
model.getConfigurations().getService())) + .add(entry("task_type", model.getTaskType().toString())); + if (model.getServiceSettings().modelId() != null) { + stream.add(entry("model_id", model.getServiceSettings().modelId())); + } + return stream.build(); + } + + private static Map toMap(Stream> stream) { + return stream.collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue)); + } + + public static Map responseAttributes(Model model, @Nullable Throwable t) { + return toMap(concat(modelAttributeEntries(model), errorAttributes(t))); + } + + public static Map responseAttributes(UnparsedModel model, @Nullable Throwable t) { + var unknownModelAttributes = Stream.>builder() + .add(entry("service", model.service())) + .add(entry("task_type", model.taskType().toString())) + .build(); + + return toMap(concat(unknownModelAttributes, errorAttributes(t))); + } + + public static Map responseAttributes(@Nullable Throwable t) { + return toMap(errorAttributes(t)); + } + + private static Stream> errorAttributes(@Nullable Throwable t) { + return switch (t) { + case null -> Stream.of(entry("status_code", 200)); + case ElasticsearchStatusException ese -> Stream.>builder() + .add(entry("status_code", ese.status().getStatus())) + .add(entry("error.type", String.valueOf(ese.status().getStatus()))) + .build(); + default -> Stream.of(entry("error.type", t.getClass().getSimpleName())); + }; + } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/telemetry/InferenceTimer.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/telemetry/InferenceTimer.java new file mode 100644 index 0000000000000..d43f4954edb52 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/telemetry/InferenceTimer.java @@ -0,0 +1,33 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. 
Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.telemetry;
+
+import java.time.Clock;
+import java.time.Duration;
+import java.time.Instant;
+import java.util.Objects;
+
+/** Stopwatch over an injectable {@link Clock}: captures a start instant and reports the milliseconds elapsed since then. */
+public record InferenceTimer(Instant startTime, Clock clock) {
+    public InferenceTimer {
+        Objects.requireNonNull(startTime);
+        Objects.requireNonNull(clock);
+    }
+
+    public static InferenceTimer start() {
+        return start(Clock.systemUTC());
+    }
+
+    public static InferenceTimer start(Clock clock) {
+        return new InferenceTimer(clock.instant(), clock);
+    }
+
+    public long elapsedMillis() {
+        return Math.max(0L, Duration.between(startTime(), clock().instant()).toMillis()); // clamp: a wall clock may step backwards
+    }
+}
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/TransportInferenceActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/TransportInferenceActionTests.java
new file mode 100644
index 0000000000000..0ed9cbf56b3fa
--- /dev/null
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/TransportInferenceActionTests.java
@@ -0,0 +1,354 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */ + +package org.elasticsearch.xpack.inference.action; + +import org.elasticsearch.ElasticsearchStatusException; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.support.ActionFilters; +import org.elasticsearch.common.xcontent.ChunkedToXContent; +import org.elasticsearch.inference.InferenceService; +import org.elasticsearch.inference.InferenceServiceRegistry; +import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.inference.Model; +import org.elasticsearch.inference.ModelConfigurations; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.inference.UnparsedModel; +import org.elasticsearch.rest.RestStatus; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.transport.TransportService; +import org.elasticsearch.xpack.core.inference.action.InferenceAction; +import org.elasticsearch.xpack.inference.action.task.StreamingTaskManager; +import org.elasticsearch.xpack.inference.registry.ModelRegistry; +import org.elasticsearch.xpack.inference.telemetry.InferenceStats; +import org.junit.Before; +import org.mockito.ArgumentCaptor; + +import java.util.Map; +import java.util.Optional; +import java.util.Set; +import java.util.concurrent.Flow; +import java.util.function.Consumer; + +import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.isA; +import static org.hamcrest.Matchers.nullValue; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyBoolean; +import static org.mockito.ArgumentMatchers.anyLong; +import static org.mockito.ArgumentMatchers.assertArg; +import static org.mockito.ArgumentMatchers.same; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +public class TransportInferenceActionTests extends ESTestCase { + private static final String serviceId = "serviceId"; + private static final 
TaskType taskType = TaskType.COMPLETION; + private static final String inferenceId = "inferenceEntityId"; + private ModelRegistry modelRegistry; + private InferenceServiceRegistry serviceRegistry; + private InferenceStats inferenceStats; + private StreamingTaskManager streamingTaskManager; + private TransportInferenceAction action; + + @Before + public void setUp() throws Exception { + super.setUp(); + TransportService transportService = mock(); + ActionFilters actionFilters = mock(); + modelRegistry = mock(); + serviceRegistry = mock(); + inferenceStats = new InferenceStats(mock(), mock()); + streamingTaskManager = mock(); + action = new TransportInferenceAction( + transportService, + actionFilters, + modelRegistry, + serviceRegistry, + inferenceStats, + streamingTaskManager + ); + } + + public void testMetricsAfterModelRegistryError() { + var expectedException = new IllegalStateException("hello"); + var expectedError = expectedException.getClass().getSimpleName(); + + doAnswer(ans -> { + ActionListener listener = ans.getArgument(1); + listener.onFailure(expectedException); + return null; + }).when(modelRegistry).getModelWithSecrets(any(), any()); + + var listener = doExecute(taskType); + verify(listener).onFailure(same(expectedException)); + + verify(inferenceStats.inferenceDuration()).record(anyLong(), assertArg(attributes -> { + assertThat(attributes.get("service"), nullValue()); + assertThat(attributes.get("task_type"), nullValue()); + assertThat(attributes.get("model_id"), nullValue()); + assertThat(attributes.get("status_code"), nullValue()); + assertThat(attributes.get("error.type"), is(expectedError)); + })); + } + + private ActionListener doExecute(TaskType taskType) { + return doExecute(taskType, false); + } + + private ActionListener doExecute(TaskType taskType, boolean stream) { + InferenceAction.Request request = mock(); + when(request.getInferenceEntityId()).thenReturn(inferenceId); + when(request.getTaskType()).thenReturn(taskType); + 
when(request.isStreaming()).thenReturn(stream); + ActionListener listener = mock(); + action.doExecute(mock(), request, listener); + return listener; + } + + public void testMetricsAfterMissingService() { + mockModelRegistry(taskType); + + when(serviceRegistry.getService(any())).thenReturn(Optional.empty()); + + var listener = doExecute(taskType); + + verify(listener).onFailure(assertArg(e -> { + assertThat(e, isA(ElasticsearchStatusException.class)); + assertThat(e.getMessage(), is("Unknown service [" + serviceId + "] for model [" + inferenceId + "]. ")); + assertThat(((ElasticsearchStatusException) e).status(), is(RestStatus.BAD_REQUEST)); + })); + verify(inferenceStats.inferenceDuration()).record(anyLong(), assertArg(attributes -> { + assertThat(attributes.get("service"), is(serviceId)); + assertThat(attributes.get("task_type"), is(taskType.toString())); + assertThat(attributes.get("model_id"), nullValue()); + assertThat(attributes.get("status_code"), is(RestStatus.BAD_REQUEST.getStatus())); + assertThat(attributes.get("error.type"), is(String.valueOf(RestStatus.BAD_REQUEST.getStatus()))); + })); + } + + private void mockModelRegistry(TaskType expectedTaskType) { + var unparsedModel = new UnparsedModel(inferenceId, expectedTaskType, serviceId, Map.of(), Map.of()); + doAnswer(ans -> { + ActionListener listener = ans.getArgument(1); + listener.onResponse(unparsedModel); + return null; + }).when(modelRegistry).getModelWithSecrets(any(), any()); + } + + public void testMetricsAfterUnknownTaskType() { + var modelTaskType = TaskType.RERANK; + var requestTaskType = TaskType.SPARSE_EMBEDDING; + mockModelRegistry(modelTaskType); + when(serviceRegistry.getService(any())).thenReturn(Optional.of(mock())); + + var listener = doExecute(requestTaskType); + + verify(listener).onFailure(assertArg(e -> { + assertThat(e, isA(ElasticsearchStatusException.class)); + assertThat( + e.getMessage(), + is( + "Incompatible task_type, the requested type [" + + requestTaskType + + "] does 
not match the model type [" + + modelTaskType + + "]" + ) + ); + assertThat(((ElasticsearchStatusException) e).status(), is(RestStatus.BAD_REQUEST)); + })); + verify(inferenceStats.inferenceDuration()).record(anyLong(), assertArg(attributes -> { + assertThat(attributes.get("service"), is(serviceId)); + assertThat(attributes.get("task_type"), is(modelTaskType.toString())); + assertThat(attributes.get("model_id"), nullValue()); + assertThat(attributes.get("status_code"), is(RestStatus.BAD_REQUEST.getStatus())); + assertThat(attributes.get("error.type"), is(String.valueOf(RestStatus.BAD_REQUEST.getStatus()))); + })); + } + + public void testMetricsAfterInferError() { + var expectedException = new IllegalStateException("hello"); + var expectedError = expectedException.getClass().getSimpleName(); + mockService(listener -> listener.onFailure(expectedException)); + + var listener = doExecute(taskType); + + verify(listener).onFailure(same(expectedException)); + verify(inferenceStats.inferenceDuration()).record(anyLong(), assertArg(attributes -> { + assertThat(attributes.get("service"), is(serviceId)); + assertThat(attributes.get("task_type"), is(taskType.toString())); + assertThat(attributes.get("model_id"), nullValue()); + assertThat(attributes.get("status_code"), nullValue()); + assertThat(attributes.get("error.type"), is(expectedError)); + })); + } + + public void testMetricsAfterStreamUnsupported() { + var expectedStatus = RestStatus.METHOD_NOT_ALLOWED; + var expectedError = String.valueOf(expectedStatus.getStatus()); + mockService(l -> {}); + + var listener = doExecute(taskType, true); + + verify(listener).onFailure(assertArg(e -> { + assertThat(e, isA(ElasticsearchStatusException.class)); + var ese = (ElasticsearchStatusException) e; + assertThat(ese.getMessage(), is("Streaming is not allowed for service [" + serviceId + "].")); + assertThat(ese.status(), is(expectedStatus)); + })); + verify(inferenceStats.inferenceDuration()).record(anyLong(), assertArg(attributes 
-> { + assertThat(attributes.get("service"), is(serviceId)); + assertThat(attributes.get("task_type"), is(taskType.toString())); + assertThat(attributes.get("model_id"), nullValue()); + assertThat(attributes.get("status_code"), is(expectedStatus.getStatus())); + assertThat(attributes.get("error.type"), is(expectedError)); + })); + } + + public void testMetricsAfterInferSuccess() { + mockService(listener -> listener.onResponse(mock())); + + var listener = doExecute(taskType); + + verify(listener).onResponse(any()); + verify(inferenceStats.inferenceDuration()).record(anyLong(), assertArg(attributes -> { + assertThat(attributes.get("service"), is(serviceId)); + assertThat(attributes.get("task_type"), is(taskType.toString())); + assertThat(attributes.get("model_id"), nullValue()); + assertThat(attributes.get("status_code"), is(200)); + assertThat(attributes.get("error.type"), nullValue()); + })); + } + + public void testMetricsAfterStreamInferSuccess() { + mockStreamResponse(Flow.Subscriber::onComplete); + verify(inferenceStats.inferenceDuration()).record(anyLong(), assertArg(attributes -> { + assertThat(attributes.get("service"), is(serviceId)); + assertThat(attributes.get("task_type"), is(taskType.toString())); + assertThat(attributes.get("model_id"), nullValue()); + assertThat(attributes.get("status_code"), is(200)); + assertThat(attributes.get("error.type"), nullValue()); + })); + } + + public void testMetricsAfterStreamInferFailure() { + var expectedException = new IllegalStateException("hello"); + var expectedError = expectedException.getClass().getSimpleName(); + mockStreamResponse(subscriber -> { + subscriber.subscribe(mock()); + subscriber.onError(expectedException); + }); + verify(inferenceStats.inferenceDuration()).record(anyLong(), assertArg(attributes -> { + assertThat(attributes.get("service"), is(serviceId)); + assertThat(attributes.get("task_type"), is(taskType.toString())); + assertThat(attributes.get("model_id"), nullValue()); + 
assertThat(attributes.get("status_code"), nullValue()); + assertThat(attributes.get("error.type"), is(expectedError)); + })); + } + + public void testMetricsAfterStreamCancel() { + var response = mockStreamResponse(s -> s.onSubscribe(mock())); + response.subscribe(new Flow.Subscriber<>() { + @Override + public void onSubscribe(Flow.Subscription subscription) { + subscription.cancel(); + } + + @Override + public void onNext(ChunkedToXContent item) { + + } + + @Override + public void onError(Throwable throwable) { + + } + + @Override + public void onComplete() { + + } + }); + + verify(inferenceStats.inferenceDuration()).record(anyLong(), assertArg(attributes -> { + assertThat(attributes.get("service"), is(serviceId)); + assertThat(attributes.get("task_type"), is(taskType.toString())); + assertThat(attributes.get("model_id"), nullValue()); + assertThat(attributes.get("status_code"), is(200)); + assertThat(attributes.get("error.type"), nullValue()); + })); + } + + private Flow.Publisher mockStreamResponse(Consumer> action) { + mockService(true, Set.of(), listener -> { + Flow.Processor taskProcessor = mock(); + doAnswer(innerAns -> { + action.accept(innerAns.getArgument(0)); + return null; + }).when(taskProcessor).subscribe(any()); + when(streamingTaskManager.create(any(), any())).thenReturn(taskProcessor); + var inferenceServiceResults = mock(InferenceServiceResults.class); + when(inferenceServiceResults.publisher()).thenReturn(mock()); + listener.onResponse(inferenceServiceResults); + }); + + var listener = doExecute(taskType, true); + var captor = ArgumentCaptor.forClass(InferenceAction.Response.class); + verify(listener).onResponse(captor.capture()); + assertTrue(captor.getValue().isStreaming()); + assertNotNull(captor.getValue().publisher()); + return captor.getValue().publisher(); + } + + private void mockService(Consumer> listenerAction) { + mockService(false, Set.of(), listenerAction); + } + + private void mockService( + boolean stream, + Set 
supportedStreamingTasks, + Consumer> listenerAction + ) { + InferenceService service = mock(); + Model model = mockModel(); + when(service.parsePersistedConfigWithSecrets(any(), any(), any(), any())).thenReturn(model); + when(service.name()).thenReturn(serviceId); + + when(service.canStream(any())).thenReturn(stream); + when(service.supportedStreamingTasks()).thenReturn(supportedStreamingTasks); + doAnswer(ans -> { + listenerAction.accept(ans.getArgument(7)); + return null; + }).when(service).infer(any(), any(), any(), anyBoolean(), any(), any(), any(), any()); + mockModelAndServiceRegistry(service); + } + + private Model mockModel() { + Model model = mock(); + ModelConfigurations modelConfigurations = mock(); + when(modelConfigurations.getService()).thenReturn(serviceId); + when(model.getConfigurations()).thenReturn(modelConfigurations); + when(model.getTaskType()).thenReturn(taskType); + when(model.getServiceSettings()).thenReturn(mock()); + return model; + } + + private void mockModelAndServiceRegistry(InferenceService service) { + var unparsedModel = new UnparsedModel(inferenceId, taskType, serviceId, Map.of(), Map.of()); + doAnswer(ans -> { + ActionListener listener = ans.getArgument(1); + listener.onResponse(unparsedModel); + return null; + }).when(modelRegistry).getModelWithSecrets(any(), any()); + + when(serviceRegistry.getService(any())).thenReturn(Optional.of(service)); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/telemetry/ApmInferenceStatsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/telemetry/ApmInferenceStatsTests.java deleted file mode 100644 index 1a5aba5f89ad2..0000000000000 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/telemetry/ApmInferenceStatsTests.java +++ /dev/null @@ -1,69 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. 
Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. - */ - -package org.elasticsearch.xpack.inference.telemetry; - -import org.elasticsearch.inference.Model; -import org.elasticsearch.inference.ModelConfigurations; -import org.elasticsearch.inference.ServiceSettings; -import org.elasticsearch.inference.TaskType; -import org.elasticsearch.telemetry.metric.LongCounter; -import org.elasticsearch.telemetry.metric.MeterRegistry; -import org.elasticsearch.test.ESTestCase; - -import java.util.Map; - -import static org.mockito.ArgumentMatchers.eq; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.when; - -public class ApmInferenceStatsTests extends ESTestCase { - - public void testRecordWithModel() { - var longCounter = mock(LongCounter.class); - - var stats = new ApmInferenceStats(longCounter); - - stats.incrementRequestCount(model("service", TaskType.ANY, "modelId")); - - verify(longCounter).incrementBy( - eq(1L), - eq(Map.of("service", "service", "task_type", TaskType.ANY.toString(), "model_id", "modelId")) - ); - } - - public void testRecordWithoutModel() { - var longCounter = mock(LongCounter.class); - - var stats = new ApmInferenceStats(longCounter); - - stats.incrementRequestCount(model("service", TaskType.ANY, null)); - - verify(longCounter).incrementBy(eq(1L), eq(Map.of("service", "service", "task_type", TaskType.ANY.toString()))); - } - - public void testCreation() { - assertNotNull(ApmInferenceStats.create(MeterRegistry.NOOP)); - } - - private Model model(String service, TaskType taskType, String modelId) { - var configuration = mock(ModelConfigurations.class); - when(configuration.getService()).thenReturn(service); - var settings = mock(ServiceSettings.class); - if (modelId != null) { - when(settings.modelId()).thenReturn(modelId); - } - - var model = mock(Model.class); - when(model.getTaskType()).thenReturn(taskType); - 
when(model.getConfigurations()).thenReturn(configuration); - when(model.getServiceSettings()).thenReturn(settings); - - return model; - } -} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/telemetry/InferenceStatsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/telemetry/InferenceStatsTests.java new file mode 100644 index 0000000000000..d9327295ba5fa --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/telemetry/InferenceStatsTests.java @@ -0,0 +1,217 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.telemetry; + +import org.elasticsearch.ElasticsearchStatusException; +import org.elasticsearch.inference.Model; +import org.elasticsearch.inference.ModelConfigurations; +import org.elasticsearch.inference.ServiceSettings; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.inference.UnparsedModel; +import org.elasticsearch.rest.RestStatus; +import org.elasticsearch.telemetry.metric.LongCounter; +import org.elasticsearch.telemetry.metric.LongHistogram; +import org.elasticsearch.telemetry.metric.MeterRegistry; +import org.elasticsearch.test.ESTestCase; + +import java.util.Map; + +import static org.elasticsearch.xpack.inference.telemetry.InferenceStats.modelAttributes; +import static org.elasticsearch.xpack.inference.telemetry.InferenceStats.responseAttributes; +import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.nullValue; +import static org.mockito.ArgumentMatchers.assertArg; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +public class 
InferenceStatsTests extends ESTestCase { + + public void testRecordWithModel() { + var longCounter = mock(LongCounter.class); + var stats = new InferenceStats(longCounter, mock()); + + stats.requestCount().incrementBy(1, modelAttributes(model("service", TaskType.ANY, "modelId"))); + + verify(longCounter).incrementBy( + eq(1L), + eq(Map.of("service", "service", "task_type", TaskType.ANY.toString(), "model_id", "modelId")) + ); + } + + public void testRecordWithoutModel() { + var longCounter = mock(LongCounter.class); + var stats = new InferenceStats(longCounter, mock()); + + stats.requestCount().incrementBy(1, modelAttributes(model("service", TaskType.ANY, null))); + + verify(longCounter).incrementBy(eq(1L), eq(Map.of("service", "service", "task_type", TaskType.ANY.toString()))); + } + + public void testCreation() { + assertNotNull(InferenceStats.create(MeterRegistry.NOOP)); + } + + public void testRecordDurationWithoutError() { + var expectedLong = randomLong(); + var histogramCounter = mock(LongHistogram.class); + var stats = new InferenceStats(mock(), histogramCounter); + + stats.inferenceDuration().record(expectedLong, responseAttributes(model("service", TaskType.ANY, "modelId"), null)); + + verify(histogramCounter).record(eq(expectedLong), assertArg(attributes -> { + assertThat(attributes.get("service"), is("service")); + assertThat(attributes.get("task_type"), is(TaskType.ANY.toString())); + assertThat(attributes.get("model_id"), is("modelId")); + assertThat(attributes.get("status_code"), is(200)); + assertThat(attributes.get("error.type"), nullValue()); + })); + } + + /** + * "If response status code was sent or received and status indicates an error according to HTTP span status definition, + * error.type SHOULD be set to the status code number (represented as a string)" + * - https://opentelemetry.io/docs/specs/semconv/http/http-metrics/ + */ + public void testRecordDurationWithElasticsearchStatusException() { + var expectedLong = randomLong(); + var 
histogramCounter = mock(LongHistogram.class); + var stats = new InferenceStats(mock(), histogramCounter); + var statusCode = RestStatus.BAD_REQUEST; + var exception = new ElasticsearchStatusException("hello", statusCode); + var expectedError = String.valueOf(statusCode.getStatus()); + + stats.inferenceDuration().record(expectedLong, responseAttributes(model("service", TaskType.ANY, "modelId"), exception)); + + verify(histogramCounter).record(eq(expectedLong), assertArg(attributes -> { + assertThat(attributes.get("service"), is("service")); + assertThat(attributes.get("task_type"), is(TaskType.ANY.toString())); + assertThat(attributes.get("model_id"), is("modelId")); + assertThat(attributes.get("status_code"), is(statusCode.getStatus())); + assertThat(attributes.get("error.type"), is(expectedError)); + })); + } + + /** + * "If the request fails with an error before response status code was sent or received, + * error.type SHOULD be set to exception type" + * - https://opentelemetry.io/docs/specs/semconv/http/http-metrics/ + */ + public void testRecordDurationWithOtherException() { + var expectedLong = randomLong(); + var histogramCounter = mock(LongHistogram.class); + var stats = new InferenceStats(mock(), histogramCounter); + var exception = new IllegalStateException("ahh"); + var expectedError = exception.getClass().getSimpleName(); + + stats.inferenceDuration().record(expectedLong, responseAttributes(model("service", TaskType.ANY, "modelId"), exception)); + + verify(histogramCounter).record(eq(expectedLong), assertArg(attributes -> { + assertThat(attributes.get("service"), is("service")); + assertThat(attributes.get("task_type"), is(TaskType.ANY.toString())); + assertThat(attributes.get("model_id"), is("modelId")); + assertThat(attributes.get("status_code"), nullValue()); + assertThat(attributes.get("error.type"), is(expectedError)); + })); + } + + public void testRecordDurationWithUnparsedModelAndElasticsearchStatusException() { + var expectedLong = 
randomLong(); + var histogramCounter = mock(LongHistogram.class); + var stats = new InferenceStats(mock(), histogramCounter); + var statusCode = RestStatus.BAD_REQUEST; + var exception = new ElasticsearchStatusException("hello", statusCode); + var expectedError = String.valueOf(statusCode.getStatus()); + + var unparsedModel = new UnparsedModel("inferenceEntityId", TaskType.ANY, "service", Map.of(), Map.of()); + + stats.inferenceDuration().record(expectedLong, responseAttributes(unparsedModel, exception)); + + verify(histogramCounter).record(eq(expectedLong), assertArg(attributes -> { + assertThat(attributes.get("service"), is("service")); + assertThat(attributes.get("task_type"), is(TaskType.ANY.toString())); + assertThat(attributes.get("model_id"), nullValue()); + assertThat(attributes.get("status_code"), is(statusCode.getStatus())); + assertThat(attributes.get("error.type"), is(expectedError)); + })); + } + + public void testRecordDurationWithUnparsedModelAndOtherException() { + var expectedLong = randomLong(); + var histogramCounter = mock(LongHistogram.class); + var stats = new InferenceStats(mock(), histogramCounter); + var exception = new IllegalStateException("ahh"); + var expectedError = exception.getClass().getSimpleName(); + + var unparsedModel = new UnparsedModel("inferenceEntityId", TaskType.ANY, "service", Map.of(), Map.of()); + + stats.inferenceDuration().record(expectedLong, responseAttributes(unparsedModel, exception)); + + verify(histogramCounter).record(eq(expectedLong), assertArg(attributes -> { + assertThat(attributes.get("service"), is("service")); + assertThat(attributes.get("task_type"), is(TaskType.ANY.toString())); + assertThat(attributes.get("model_id"), nullValue()); + assertThat(attributes.get("status_code"), nullValue()); + assertThat(attributes.get("error.type"), is(expectedError)); + })); + } + + public void testRecordDurationWithUnknownModelAndElasticsearchStatusException() { + var expectedLong = randomLong(); + var histogramCounter 
= mock(LongHistogram.class); + var stats = new InferenceStats(mock(), histogramCounter); + var statusCode = RestStatus.BAD_REQUEST; + var exception = new ElasticsearchStatusException("hello", statusCode); + var expectedError = String.valueOf(statusCode.getStatus()); + + stats.inferenceDuration().record(expectedLong, responseAttributes(exception)); + + verify(histogramCounter).record(eq(expectedLong), assertArg(attributes -> { + assertThat(attributes.get("service"), nullValue()); + assertThat(attributes.get("task_type"), nullValue()); + assertThat(attributes.get("model_id"), nullValue()); + assertThat(attributes.get("status_code"), is(statusCode.getStatus())); + assertThat(attributes.get("error.type"), is(expectedError)); + })); + } + + public void testRecordDurationWithUnknownModelAndOtherException() { + var expectedLong = randomLong(); + var histogramCounter = mock(LongHistogram.class); + var stats = new InferenceStats(mock(), histogramCounter); + var exception = new IllegalStateException("ahh"); + var expectedError = exception.getClass().getSimpleName(); + + stats.inferenceDuration().record(expectedLong, responseAttributes(exception)); + + verify(histogramCounter).record(eq(expectedLong), assertArg(attributes -> { + assertThat(attributes.get("service"), nullValue()); + assertThat(attributes.get("task_type"), nullValue()); + assertThat(attributes.get("model_id"), nullValue()); + assertThat(attributes.get("status_code"), nullValue()); + assertThat(attributes.get("error.type"), is(expectedError)); + })); + } + + private Model model(String service, TaskType taskType, String modelId) { + var configuration = mock(ModelConfigurations.class); + when(configuration.getService()).thenReturn(service); + var settings = mock(ServiceSettings.class); + if (modelId != null) { + when(settings.modelId()).thenReturn(modelId); + } + + var model = mock(Model.class); + when(model.getTaskType()).thenReturn(taskType); + when(model.getConfigurations()).thenReturn(configuration); + 
when(model.getServiceSettings()).thenReturn(settings); + + return model; + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/telemetry/InferenceTimerTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/telemetry/InferenceTimerTests.java new file mode 100644 index 0000000000000..72b29d176f8c1 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/telemetry/InferenceTimerTests.java @@ -0,0 +1,32 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.telemetry; + +import org.elasticsearch.test.ESTestCase; + +import java.time.Clock; +import java.time.Instant; +import java.time.temporal.ChronoUnit; + +import static org.hamcrest.Matchers.is; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +public class InferenceTimerTests extends ESTestCase { + + public void testElapsedMillis() { + var expectedDuration = randomLongBetween(10, 300); + + var startTime = Instant.now(); + var clock = mock(Clock.class); + when(clock.instant()).thenReturn(startTime).thenReturn(startTime.plus(expectedDuration, ChronoUnit.MILLIS)); + var timer = InferenceTimer.start(clock); + + assertThat(expectedDuration, is(timer.elapsedMillis())); + } +}