Skip to content

Commit

Permalink
[ML] Treat chunked content as streaming
Browse files Browse the repository at this point in the history
OpenAI/Apache returns a content length of -1 for streaming (chunked) responses.
Apache asserts that the content length is > 0 whenever the entity's
content is read. Now, we check that the content length is > 0
before trying to set the content in the initial streaming response.
  • Loading branch information
prwhelan committed Sep 9, 2024
1 parent 72248e3 commit b415b00
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -71,8 +71,13 @@ class StreamingHttpResultPublisher implements HttpAsyncResponseConsumer<HttpResp
/**
 * Handles the initial HTTP response of the exchange: stores it, queues the first
 * {@link HttpResult} for delivery to the subscriber, and notifies the listener that
 * this publisher is ready.
 *
 * @param httpResponse the response received from the remote service
 * @throws IOException if the first result cannot be created from the response
 */
@Override
public void responseReceived(HttpResponse httpResponse) throws IOException {
    this.response = httpResponse;

    var entity = httpResponse.getEntity();
    final HttpResult firstResponse;
    if (entity == null || entity.getContentLength() <= 0) {
        // On success, we may receive an empty content payload to initiate the stream.
        // Streaming (chunked) responses report a content length of -1, and the entity's
        // content must not be read in that case, so start with an empty body instead.
        firstResponse = new HttpResult(httpResponse, new byte[0]);
    } else {
        firstResponse = HttpResult.create(settings.getMaxResponseSize(), httpResponse);
    }

    // Exactly one initial result is queued, whichever branch produced it.
    this.queue.offer(() -> subscriber.onNext(firstResponse));
    this.listener.onResponse(this);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

package org.elasticsearch.xpack.inference.external.http;

import org.apache.http.HttpEntity;
import org.apache.http.HttpResponse;
import org.apache.http.nio.ContentDecoder;
import org.apache.http.nio.IOControl;
Expand All @@ -17,6 +18,7 @@
import org.elasticsearch.threadpool.ThreadPool;
import org.junit.Before;

import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;
Expand Down Expand Up @@ -84,6 +86,52 @@ public void testFirstResponseCallsListener() throws IOException {
assertThat("Listener's onResponse should be called when we receive a response", latch.getCount(), equalTo(0L));
}

/**
 * Given an http response whose entity reports a content length of -1
 * When the publisher receives that response
 * Then the listener's onResponse is called
 * (only the listener call is asserted here)
 */
public void testEmptyFirstResponseCallsListener() throws IOException {
    var responseLatch = new CountDownLatch(1);
    publisher = new StreamingHttpResultPublisher(
        threadPool,
        settings,
        ActionListener.<Flow.Publisher<HttpResult>>wrap(
            r -> responseLatch.countDown(),
            e -> fail("Listener onFailure should never be called.")
        )
    );

    // An entity with no known length, as seen on chunked/streaming responses.
    var emptyEntity = mock(HttpEntity.class);
    when(emptyEntity.getContentLength()).thenReturn(-1L);
    var httpResponse = mock(HttpResponse.class);
    when(httpResponse.getEntity()).thenReturn(emptyEntity);

    publisher.responseReceived(httpResponse);

    assertThat(
        "Listener's onResponse should be called when we receive a response",
        responseLatch.getCount(),
        equalTo(0L)
    );
}

/**
 * Given an http response whose entity reports a positive content length and has readable content
 * When the publisher receives that response
 * Then the listener's onResponse is called
 * (only the listener call is asserted here)
 */
public void testNonEmptyFirstResponseCallsListener() throws IOException {
    var responseLatch = new CountDownLatch(1);
    publisher = new StreamingHttpResultPublisher(
        threadPool,
        settings,
        ActionListener.<Flow.Publisher<HttpResult>>wrap(
            r -> responseLatch.countDown(),
            e -> fail("Listener onFailure should never be called.")
        )
    );

    // Allow a response size comfortably larger than the mocked payload.
    when(settings.getMaxResponseSize()).thenReturn(ByteSizeValue.ofBytes(9000));

    var populatedEntity = mock(HttpEntity.class);
    when(populatedEntity.getContentLength()).thenReturn(5L);
    when(populatedEntity.getContent()).thenReturn(new ByteArrayInputStream(message));
    var httpResponse = mock(HttpResponse.class);
    when(httpResponse.getEntity()).thenReturn(populatedEntity);

    publisher.responseReceived(httpResponse);

    assertThat(
        "Listener's onResponse should be called when we receive a response",
        responseLatch.getCount(),
        equalTo(0L)
    );
}

/**
* This test combines 4 test since it's easier to verify the exchange of data at once.
*
Expand Down

0 comments on commit b415b00

Please sign in to comment.