Skip to content

Commit

Permalink
[ML] Stream Anthropic Completion
Browse files Browse the repository at this point in the history
Enable chat completion streaming responses for Anthropic's server-sent
events.
  • Loading branch information
prwhelan committed Oct 8, 2024
1 parent 10f6f25 commit e447dee
Show file tree
Hide file tree
Showing 11 changed files with 418 additions and 13 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,20 @@

import org.apache.logging.log4j.Logger;
import org.elasticsearch.common.Strings;
import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.xpack.core.inference.results.StreamingChatCompletionResults;
import org.elasticsearch.xpack.inference.external.http.HttpResult;
import org.elasticsearch.xpack.inference.external.http.retry.BaseResponseHandler;
import org.elasticsearch.xpack.inference.external.http.retry.ResponseParser;
import org.elasticsearch.xpack.inference.external.http.retry.RetryException;
import org.elasticsearch.xpack.inference.external.request.Request;
import org.elasticsearch.xpack.inference.external.response.ErrorMessageResponseEntity;
import org.elasticsearch.xpack.inference.external.response.streaming.ServerSentEventParser;
import org.elasticsearch.xpack.inference.external.response.streaming.ServerSentEventProcessor;
import org.elasticsearch.xpack.inference.logging.ThrottlerManager;

import java.util.concurrent.Flow;

import static org.elasticsearch.xpack.inference.external.http.HttpUtils.checkForEmptyBody;
import static org.elasticsearch.xpack.inference.external.http.retry.ResponseHandlerUtils.getFirstHeaderOrUnknown;

Expand All @@ -41,8 +47,11 @@ public class AnthropicResponseHandler extends BaseResponseHandler {

static final String SERVER_BUSY = "Received an Anthropic server is temporarily overloaded status code";

public AnthropicResponseHandler(String requestType, ResponseParser parseFunction) {
private final boolean canHandleStreamingResponses;

/**
 * Creates a response handler for Anthropic requests.
 *
 * @param requestType forwarded unchanged to the {@code BaseResponseHandler} constructor
 * @param parseFunction forwarded unchanged to the {@code BaseResponseHandler} constructor
 * @param canHandleStreamingResponses value later returned by {@link #canHandleStreamingResponses()}
 */
public AnthropicResponseHandler(String requestType, ResponseParser parseFunction, boolean canHandleStreamingResponses) {
// NOTE(review): the third super argument is presumably the parser used for error bodies - confirm in BaseResponseHandler.
super(requestType, parseFunction, ErrorMessageResponseEntity::fromResponse);
this.canHandleStreamingResponses = canHandleStreamingResponses;
}

@Override
Expand All @@ -52,6 +61,20 @@ public void validateResponse(ThrottlerManager throttlerManager, Logger logger, R
checkForEmptyBody(throttlerManager, logger, request, result);
}

/**
 * Reports whether this handler was configured to accept streaming responses.
 *
 * @return the {@code canHandleStreamingResponses} flag supplied to the constructor
 */
@Override
public boolean canHandleStreamingResponses() {
return canHandleStreamingResponses;
}

/**
 * Builds the streaming pipeline for an Anthropic response and returns a result backed by its last stage.
 * <p>
 * The chain is: {@code flow} -> {@link ServerSentEventProcessor} -> {@link AnthropicStreamingProcessor}.
 *
 * @param request the originating request (not read by this method)
 * @param flow publisher of the raw HTTP results to be processed
 * @return a {@link StreamingChatCompletionResults} wrapping the Anthropic processor
 */
@Override
public InferenceServiceResults parseResult(Request request, Flow.Publisher<HttpResult> flow) {
    var eventProcessor = new ServerSentEventProcessor(new ServerSentEventParser());
    var completionProcessor = new AnthropicStreamingProcessor();

    // Attach the downstream stage before subscribing to the HTTP flow, in the same order as before.
    eventProcessor.subscribe(completionProcessor);
    flow.subscribe(eventProcessor);

    return new StreamingChatCompletionResults(completionProcessor);
}

/**
* Validates the status code and throws a RetryException if it is not in the range [200, 300).
*
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.inference.external.anthropic;

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.xcontent.XContentFactory;
import org.elasticsearch.xcontent.XContentParser;
import org.elasticsearch.xcontent.XContentParserConfiguration;
import org.elasticsearch.xcontent.XContentType;
import org.elasticsearch.xpack.core.inference.results.StreamingChatCompletionResults;
import org.elasticsearch.xpack.inference.common.DelegatingProcessor;
import org.elasticsearch.xpack.inference.external.response.streaming.ServerSentEvent;
import org.elasticsearch.xpack.inference.external.response.streaming.ServerSentEventField;

import java.io.IOException;
import java.util.ArrayDeque;
import java.util.Deque;
import java.util.Optional;

import static org.elasticsearch.common.xcontent.XContentParserUtils.ensureExpectedToken;
import static org.elasticsearch.xpack.inference.external.response.XContentUtils.moveToFirstToken;
import static org.elasticsearch.xpack.inference.external.response.XContentUtils.positionParserAtTokenAfterField;

/**
 * Turns batches of Anthropic server-sent events into {@link StreamingChatCompletionResults.Results} chunks.
 * <p>
 * Only events whose name is {@link ServerSentEventField#DATA} and that carry a value are inspected. The JSON payload's
 * top-level {@code type} field selects the handling:
 * <ul>
 *   <li>{@code error} - the stream is failed via {@code onError} and processing of the batch stops;</li>
 *   <li>{@code content_block_start} / {@code content_block_delta} - the {@code text} value is forwarded when non-empty;</li>
 *   <li>anything else - logged at debug level and ignored.</li>
 * </ul>
 */
public class AnthropicStreamingProcessor extends DelegatingProcessor<Deque<ServerSentEvent>, StreamingChatCompletionResults.Results> {
    private static final Logger log = LogManager.getLogger(AnthropicStreamingProcessor.class);
    private static final String FAILED_TO_FIND_FIELD_TEMPLATE = "Failed to find required field [%s] in Anthropic chat completions response";

    /**
     * Processes one batch of server-sent events.
     * <p>
     * If the batch yields no text, one more item is requested from upstream; otherwise the collected results are
     * pushed downstream in a single {@link StreamingChatCompletionResults.Results}.
     *
     * @param item the events received from upstream
     * @throws Exception if an event payload cannot be parsed; the failure is logged and rethrown
     */
    @Override
    protected void next(Deque<ServerSentEvent> item) throws Exception {
        if (item.isEmpty()) {
            upstream().request(1);
            return;
        }

        var results = new ArrayDeque<StreamingChatCompletionResults.Result>(item.size());
        for (var event : item) {
            if (event.name() == ServerSentEventField.DATA && event.hasValue()) {
                try (var parser = parser(event.value())) {
                    var eventType = eventType(parser);
                    switch (eventType) {
                        case "error" -> {
                            // Fail the stream; results gathered earlier in this batch are not forwarded.
                            onError(parseError(parser));
                            return;
                        }
                        case "content_block_start" -> parseText(parser, "content_block").ifPresent(results::offer);
                        case "content_block_delta" -> parseText(parser, "delta").ifPresent(results::offer);
                        case "message_start", "message_stop", "message_delta", "content_block_stop", "ping" -> {
                            // Log the single event being skipped (not the whole batch), matching the warn below.
                            log.debug("Skipping event type [{}] for line [{}].", eventType, event);
                        }
                        default -> {
                            // "handle unknown events gracefully" https://docs.anthropic.com/en/api/messages-streaming#other-events
                            // we'll ignore unknown events
                            log.debug("Unknown event type [{}] for line [{}].", eventType, event);
                        }
                    }
                } catch (Exception e) {
                    log.warn("Failed to parse line {}", event);
                    throw e;
                }
            }
        }

        if (results.isEmpty()) {
            upstream().request(1);
        } else {
            downstream().onNext(new StreamingChatCompletionResults.Results(results));
        }
    }

    /**
     * Builds an exception from an {@code error} event, mapping the Anthropic error type to a {@link RestStatus}.
     * Unrecognised types map to {@link RestStatus#INTERNAL_SERVER_ERROR}.
     * <p>
     * NOTE(review): {@code type} is read before {@code message}; this appears to rely on that field order in the
     * payload - confirm against {@code positionParserAtTokenAfterField}.
     *
     * @param parser parser over the event payload, already advanced past the top-level {@code type} field
     * @return an {@link ElasticsearchStatusException} carrying the error message and mapped status
     * @throws IOException if the payload cannot be read
     */
    private Throwable parseError(XContentParser parser) throws IOException {
        positionParserAtTokenAfterField(parser, "error", FAILED_TO_FIND_FIELD_TEMPLATE);
        var type = parseString(parser, "type");
        var message = parseString(parser, "message");
        var statusCode = switch (type) {
            case "invalid_request_error" -> RestStatus.BAD_REQUEST;
            case "authentication_error" -> RestStatus.UNAUTHORIZED;
            case "permission_error" -> RestStatus.FORBIDDEN;
            case "not_found_error" -> RestStatus.NOT_FOUND;
            case "request_too_large" -> RestStatus.REQUEST_ENTITY_TOO_LARGE;
            case "rate_limit_error" -> RestStatus.TOO_MANY_REQUESTS;
            default -> RestStatus.INTERNAL_SERVER_ERROR;
        };
        return new ElasticsearchStatusException(message, statusCode);
    }

    /**
     * Reads the {@code text} string found after {@code objectField} and wraps it in a result.
     * <p>
     * Only an empty string is discarded. Whitespace-only text (for example a delta of {@code "\n\n"}) is part of the
     * generated completion and is forwarded so the streamed output is not altered.
     *
     * @param parser parser over the event payload
     * @param objectField name of the field holding the text, {@code content_block} or {@code delta}
     * @return the text as a result, or empty when the text is the empty string
     * @throws IOException if the payload cannot be read
     */
    private static Optional<StreamingChatCompletionResults.Result> parseText(XContentParser parser, String objectField)
        throws IOException {
        positionParserAtTokenAfterField(parser, objectField, FAILED_TO_FIND_FIELD_TEMPLATE);
        var text = parseString(parser, "text");
        return text.isEmpty() ? Optional.empty() : Optional.of(new StreamingChatCompletionResults.Result(text));
    }

    /** Creates a JSON parser over a single event payload; the caller is responsible for closing it. */
    private static XContentParser parser(String line) throws IOException {
        return XContentFactory.xContent(XContentType.JSON).createParser(XContentParserConfiguration.EMPTY, line);
    }

    /** Moves to the first token, checks that the payload is a JSON object, and returns its {@code type} string. */
    private static String eventType(XContentParser parser) throws IOException {
        moveToFirstToken(parser);
        ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser);
        return parseString(parser, "type");
    }

    /** Positions the parser after {@code fieldName}, checks that the value is a string, and returns it. */
    private static String parseString(XContentParser parser, String fieldName) throws IOException {
        positionParserAtTokenAfterField(parser, fieldName, FAILED_TO_FIND_FIELD_TEMPLATE);
        ensureExpectedToken(XContentParser.Token.VALUE_STRING, parser.currentToken(), parser);
        return parser.text();
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
import org.elasticsearch.xpack.inference.external.response.anthropic.AnthropicChatCompletionResponseEntity;
import org.elasticsearch.xpack.inference.services.anthropic.completion.AnthropicChatCompletionModel;

import java.util.List;
import java.util.Objects;
import java.util.function.Supplier;

Expand Down Expand Up @@ -47,13 +46,15 @@ public void execute(
Supplier<Boolean> hasRequestCompletedFunction,
ActionListener<InferenceServiceResults> listener
) {
List<String> docsInput = DocumentsOnlyInput.of(inferenceInputs).getInputs();
AnthropicChatCompletionRequest request = new AnthropicChatCompletionRequest(docsInput, model);
var docsOnly = DocumentsOnlyInput.of(inferenceInputs);
var docsInput = docsOnly.getInputs();
var stream = docsOnly.stream();
AnthropicChatCompletionRequest request = new AnthropicChatCompletionRequest(docsInput, model, stream);

execute(new ExecutableInferenceRequest(requestSender, logger, request, HANDLER, hasRequestCompletedFunction, listener));
}

private static ResponseHandler createCompletionHandler() {
return new AnthropicResponseHandler("anthropic completions", AnthropicChatCompletionResponseEntity::fromResponse);
return new AnthropicResponseHandler("anthropic completions", AnthropicChatCompletionResponseEntity::fromResponse, true);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -29,19 +29,21 @@ public class AnthropicChatCompletionRequest implements Request {
private final AnthropicAccount account;
private final List<String> input;
private final AnthropicChatCompletionModel model;
private final boolean stream;

public AnthropicChatCompletionRequest(List<String> input, AnthropicChatCompletionModel model) {
public AnthropicChatCompletionRequest(List<String> input, AnthropicChatCompletionModel model, boolean stream) {
this.account = AnthropicAccount.of(model);
this.input = Objects.requireNonNull(input);
this.model = Objects.requireNonNull(model);
this.stream = stream;
}

@Override
public HttpRequest createHttpRequest() {
HttpPost httpPost = new HttpPost(account.uri());

ByteArrayEntity byteEntity = new ByteArrayEntity(
Strings.toString(new AnthropicChatCompletionRequestEntity(input, model.getServiceSettings(), model.getTaskSettings()))
Strings.toString(new AnthropicChatCompletionRequestEntity(input, model.getServiceSettings(), model.getTaskSettings(), stream))
.getBytes(StandardCharsets.UTF_8)
);
httpPost.setEntity(byteEntity);
Expand Down Expand Up @@ -75,4 +77,9 @@ public String getInferenceEntityId() {
return model.getInferenceEntityId();
}

/**
 * Reports whether this request asks for a streamed response.
 *
 * @return the {@code stream} flag supplied to the constructor, the same value passed to the request entity
 */
@Override
public boolean isStreaming() {
return stream;
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -28,19 +28,23 @@ public class AnthropicChatCompletionRequestEntity implements ToXContentObject {
private static final String TEMPERATURE_FIELD = "temperature";
private static final String TOP_P_FIELD = "top_p";
private static final String TOP_K_FIELD = "top_k";
private static final String STREAM = "stream";

private final List<String> messages;
private final AnthropicChatCompletionServiceSettings serviceSettings;
private final AnthropicChatCompletionTaskSettings taskSettings;
private final boolean stream;

public AnthropicChatCompletionRequestEntity(
List<String> messages,
AnthropicChatCompletionServiceSettings serviceSettings,
AnthropicChatCompletionTaskSettings taskSettings
AnthropicChatCompletionTaskSettings taskSettings,
boolean stream
) {
this.messages = Objects.requireNonNull(messages);
this.serviceSettings = Objects.requireNonNull(serviceSettings);
this.taskSettings = Objects.requireNonNull(taskSettings);
this.stream = stream;
}

@Override
Expand Down Expand Up @@ -77,6 +81,10 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
builder.field(TOP_K_FIELD, taskSettings.topK());
}

if (stream) {
builder.field(STREAM, true);
}

builder.endObject();

return builder;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@

import java.util.List;
import java.util.Map;
import java.util.Set;

import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidModelException;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.parsePersistedConfigErrorMsg;
Expand Down Expand Up @@ -199,4 +200,9 @@ protected void doChunkedInfer(
public TransportVersion getMinimalSupportedVersion() {
return TransportVersions.ML_ANTHROPIC_INTEGRATION_ADDED;
}

/**
 * Lists the task types for which this service supports streaming.
 *
 * @return {@code COMPLETION_ONLY} (defined outside this view; presumably a set holding only the completion task type - confirm)
 */
@Override
public Set<TaskType> supportedStreamingTasks() {
return COMPLETION_ONLY;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ private static void callCheckForFailureStatusCode(int statusCode, String inferen
var mockRequest = mock(Request.class);
when(mockRequest.getInferenceEntityId()).thenReturn(inferenceEntityId);
var httpResult = new HttpResult(httpResponse, new byte[] {});
var handler = new AnthropicResponseHandler("", (request, result) -> null);
var handler = new AnthropicResponseHandler("", (request, result) -> null, false);

handler.checkForFailureStatusCode(mockRequest, httpResult);
}
Expand Down
Loading

0 comments on commit e447dee

Please sign in to comment.