From b415b007a386ff2dea188dbf1cf1ab81269dd20c Mon Sep 17 00:00:00 2001 From: Pat Whelan Date: Mon, 9 Sep 2024 16:12:10 -0400 Subject: [PATCH] [ML] Treat chunked content as streaming OpenAI/Apache returns a -1 content length for streaming. Apache asserts content length is > 0 whenever the entity's content is read. Now, we detect if the content is > 0 before trying to set it in the initial streaming response. --- .../http/StreamingHttpResultPublisher.java | 9 +++- .../StreamingHttpResultPublisherTests.java | 48 +++++++++++++++++++ 2 files changed, 55 insertions(+), 2 deletions(-) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/StreamingHttpResultPublisher.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/StreamingHttpResultPublisher.java index 49a9048a69df1..aad262617d0b3 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/StreamingHttpResultPublisher.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/StreamingHttpResultPublisher.java @@ -71,8 +71,13 @@ class StreamingHttpResultPublisher implements HttpAsyncResponseConsumer subscriber.onNext(firstResponse)); + if (response.getEntity() == null || response.getEntity().getContentLength() <= 0) { + // on success, we may receive an empty content payload to initiate the stream + this.queue.offer(() -> subscriber.onNext(new HttpResult(response, new byte[0]))); + } else { + var firstResponse = HttpResult.create(settings.getMaxResponseSize(), response); + this.queue.offer(() -> subscriber.onNext(firstResponse)); + } this.listener.onResponse(this); } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/StreamingHttpResultPublisherTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/StreamingHttpResultPublisherTests.java index 92a332fe545e3..337d7f5b2b236 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/StreamingHttpResultPublisherTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/StreamingHttpResultPublisherTests.java @@ -7,6 +7,7 @@ package org.elasticsearch.xpack.inference.external.http; +import org.apache.http.HttpEntity; import org.apache.http.HttpResponse; import org.apache.http.nio.ContentDecoder; import org.apache.http.nio.IOControl; @@ -17,6 +18,7 @@ import org.elasticsearch.threadpool.ThreadPool; import org.junit.Before; +import java.io.ByteArrayInputStream; import java.io.IOException; import java.nio.ByteBuffer; import java.nio.charset.StandardCharsets; @@ -84,6 +86,52 @@ public void testFirstResponseCallsListener() throws IOException { assertThat("Listener's onResponse should be called when we receive a response", latch.getCount(), equalTo(0L)); } + /** + * When we receive an http response with an entity with no content + * Then we call the listener + * And we queue the initial payload + */ + public void testEmptyFirstResponseCallsListener() throws IOException { + var latch = new CountDownLatch(1); + var listener = ActionListener.>wrap( + r -> latch.countDown(), + e -> fail("Listener onFailure should never be called.") + ); + publisher = new StreamingHttpResultPublisher(threadPool, settings, listener); + + var response = mock(HttpResponse.class); + var entity = mock(HttpEntity.class); + when(entity.getContentLength()).thenReturn(-1L); + when(response.getEntity()).thenReturn(entity); + publisher.responseReceived(response); + + assertThat("Listener's onResponse should be called when we receive a response", latch.getCount(), equalTo(0L)); + } + + /** + * When we receive an http response with an entity with content + * Then we call the listener + * And we queue the initial payload + */ + public void testNonEmptyFirstResponseCallsListener() throws IOException { + var latch = new CountDownLatch(1); + var listener = ActionListener.>wrap( + r -> latch.countDown(), + e -> fail("Listener onFailure should never be called.") + ); + publisher = new StreamingHttpResultPublisher(threadPool, settings, listener); + + when(settings.getMaxResponseSize()).thenReturn(ByteSizeValue.ofBytes(9000)); + var response = mock(HttpResponse.class); + var entity = mock(HttpEntity.class); + when(entity.getContentLength()).thenReturn(5L); + when(entity.getContent()).thenReturn(new ByteArrayInputStream(message)); + when(response.getEntity()).thenReturn(entity); + publisher.responseReceived(response); + + assertThat("Listener's onResponse should be called when we receive a response", latch.getCount(), equalTo(0L)); + } + /** * This test combines 4 test since it's easier to verify the exchange of data at once. *