[ML] Stream Cohere Completion (elastic#114080)
Implement and enable streaming for Cohere chat completions (v1).

Includes a processor for newline-delimited JSON (ND-JSON) streaming responses.

Co-authored-by: Elastic Machine <[email protected]>
prwhelan and elasticmachine committed Oct 7, 2024
1 parent 46b0696 commit c8390ef
Showing 17 changed files with 585 additions and 26 deletions.
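At a glance, the change wires a small java.util.concurrent.Flow pipeline: the HTTP layer publishes raw HttpResult chunks, the new NewlineDelimitedByteProcessor splits them into complete ND-JSON lines, and the new CohereStreamingProcessor turns those lines into streamed chat-completion chunks. A condensed sketch of the chain, using the types and calls exactly as they appear in CohereResponseHandler.parseResult below:

    // flow is the Flow.Publisher<HttpResult> handed to the response handler
    var ndProcessor = new NewlineDelimitedByteProcessor();   // HttpResult -> Deque<String> of JSON lines
    var cohereProcessor = new CohereStreamingProcessor();    // Deque<String> -> streaming chat results
    flow.subscribe(ndProcessor);
    ndProcessor.subscribe(cohereProcessor);
    return new StreamingChatCompletionResults(cohereProcessor);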
5 changes: 5 additions & 0 deletions docs/changelog/114080.yaml
@@ -0,0 +1,5 @@
pr: 114080
summary: Stream Cohere Completion
area: Machine Learning
type: enhancement
issues: []
DelegatingProcessor.java
@@ -21,7 +21,7 @@
public abstract class DelegatingProcessor<T, R> implements Flow.Processor<T, R> {
private static final Logger log = LogManager.getLogger(DelegatingProcessor.class);
private final AtomicLong pendingRequests = new AtomicLong();
private final AtomicBoolean isClosed = new AtomicBoolean(false);
protected final AtomicBoolean isClosed = new AtomicBoolean(false);
private Flow.Subscriber<? super R> downstream;
private Flow.Subscription upstream;

@@ -49,7 +49,7 @@ private Flow.Subscription forwardingSubscription() {
@Override
public void request(long n) {
if (isClosed.get()) {
downstream.onComplete(); // shouldn't happen, but reinforce that we're no longer listening
downstream.onComplete();
} else if (upstream != null) {
upstream.request(n);
} else {
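This hunk makes isClosed protected so that subclasses can take part in the close handshake. The new NewlineDelimitedByteProcessor, further down in this diff, relies on it to flush its buffered partial line exactly once when the stream completes:

    @Override
    public void onComplete() {
        if (previousTokens.isBlank()) {
            super.onComplete();
        } else if (isClosed.compareAndSet(false, true)) {   // relies on the now-protected flag
            var results = new ArrayDeque<String>(1);
            results.offer(previousTokens);
            downstream().onNext(results);
        }
    }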
CohereResponseHandler.java
@@ -8,14 +8,19 @@
package org.elasticsearch.xpack.inference.external.cohere;

import org.apache.logging.log4j.Logger;
import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.xpack.core.inference.results.StreamingChatCompletionResults;
import org.elasticsearch.xpack.inference.external.http.HttpResult;
import org.elasticsearch.xpack.inference.external.http.retry.BaseResponseHandler;
import org.elasticsearch.xpack.inference.external.http.retry.ResponseParser;
import org.elasticsearch.xpack.inference.external.http.retry.RetryException;
import org.elasticsearch.xpack.inference.external.request.Request;
import org.elasticsearch.xpack.inference.external.response.cohere.CohereErrorResponseEntity;
import org.elasticsearch.xpack.inference.external.response.streaming.NewlineDelimitedByteProcessor;
import org.elasticsearch.xpack.inference.logging.ThrottlerManager;

import java.util.concurrent.Flow;

import static org.elasticsearch.xpack.inference.external.http.HttpUtils.checkForEmptyBody;

/**
@@ -33,9 +38,11 @@
public class CohereResponseHandler extends BaseResponseHandler {
static final String TEXTS_ARRAY_TOO_LARGE_MESSAGE_MATCHER = "invalid request: total number of texts must be at most";
static final String TEXTS_ARRAY_ERROR_MESSAGE = "Received a texts array too large response";
private final boolean canHandleStreamingResponse;

public CohereResponseHandler(String requestType, ResponseParser parseFunction) {
public CohereResponseHandler(String requestType, ResponseParser parseFunction, boolean canHandleStreamingResponse) {
super(requestType, parseFunction, CohereErrorResponseEntity::fromResponse);
this.canHandleStreamingResponse = canHandleStreamingResponse;
}

@Override
@@ -45,6 +52,20 @@ public void validateResponse(ThrottlerManager throttlerManager, Logger logger, R
checkForEmptyBody(throttlerManager, logger, request, result);
}

@Override
public boolean canHandleStreamingResponses() {
return canHandleStreamingResponse;
}

@Override
public InferenceServiceResults parseResult(Request request, Flow.Publisher<HttpResult> flow) {
var ndProcessor = new NewlineDelimitedByteProcessor();
var cohereProcessor = new CohereStreamingProcessor();
flow.subscribe(ndProcessor);
ndProcessor.subscribe(cohereProcessor);
return new StreamingChatCompletionResults(cohereProcessor);
}

/**
* Validates the status code, throwing a RetryException if it is not in the range [200, 300).
*
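The handler now advertises streaming support per instance and gains a Flow-based parseResult overload. How those two pieces are used is outside this file; a hedged sketch of the dispatch the shared HTTP sender presumably performs (variable names are illustrative, not taken from this diff):

    // illustrative only -- the real gating lives in the shared sender/retry code, not in this PR
    if (request.isStreaming() && responseHandler.canHandleStreamingResponses()) {
        // hand the live byte stream to the handler; it returns a publisher-backed result
        InferenceServiceResults results = responseHandler.parseResult(request, httpResultPublisher);
        listener.onResponse(results);
    } else {
        // existing buffered path: parse the fully-read HttpResult as before
        InferenceServiceResults results = responseHandler.parseResult(request, bufferedHttpResult);
        listener.onResponse(results);
    }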
CohereStreamingProcessor.java (new file)
@@ -0,0 +1,101 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.inference.external.cohere;

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.xcontent.XContentFactory;
import org.elasticsearch.xcontent.XContentParser;
import org.elasticsearch.xcontent.XContentParserConfiguration;
import org.elasticsearch.xcontent.XContentType;
import org.elasticsearch.xpack.core.inference.results.StreamingChatCompletionResults;
import org.elasticsearch.xpack.inference.common.DelegatingProcessor;

import java.io.IOException;
import java.util.ArrayDeque;
import java.util.Deque;
import java.util.Map;
import java.util.Optional;

class CohereStreamingProcessor extends DelegatingProcessor<Deque<String>, StreamingChatCompletionResults.Results> {
private static final Logger log = LogManager.getLogger(CohereStreamingProcessor.class);

@Override
protected void next(Deque<String> item) throws Exception {
if (item.isEmpty()) {
// discard empty result and go to the next
upstream().request(1);
return;
}

var results = new ArrayDeque<StreamingChatCompletionResults.Result>(item.size());
for (String json : item) {
try (var jsonParser = jsonParser(json)) {
var responseMap = jsonParser.map();
var eventType = (String) responseMap.get("event_type");
switch (eventType) {
case "text-generation" -> parseText(responseMap).ifPresent(results::offer);
case "stream-end" -> validateResponse(responseMap);
case "stream-start", "search-queries-generation", "search-results", "citation-generation", "tool-calls-generation",
"tool-calls-chunk" -> {
log.debug("Skipping event type [{}] for line [{}].", eventType, item);
}
default -> throw new IOException("Unknown eventType found: " + eventType);
}
} catch (ElasticsearchStatusException e) {
throw e;
} catch (Exception e) {
log.warn("Failed to parse json from cohere: {}", json);
throw e;
}
}

if (results.isEmpty()) {
upstream().request(1);
} else {
downstream().onNext(new StreamingChatCompletionResults.Results(results));
}
}

private static XContentParser jsonParser(String line) throws IOException {
return XContentFactory.xContent(XContentType.JSON).createParser(XContentParserConfiguration.EMPTY, line);
}

private Optional<StreamingChatCompletionResults.Result> parseText(Map<String, Object> responseMap) throws IOException {
var text = (String) responseMap.get("text");
if (text != null) {
return Optional.of(new StreamingChatCompletionResults.Result(text));
} else {
throw new IOException("Null text found in text-generation cohere event");
}
}

private void validateResponse(Map<String, Object> responseMap) {
var finishReason = (String) responseMap.get("finish_reason");
switch (finishReason) {
case "ERROR", "ERROR_TOXIC" -> throw new ElasticsearchStatusException(
"Cohere stopped the stream due to an error: {}",
RestStatus.INTERNAL_SERVER_ERROR,
parseErrorMessage(responseMap)
);
case "ERROR_LIMIT" -> throw new ElasticsearchStatusException(
"Cohere stopped the stream due to an error: {}",
RestStatus.TOO_MANY_REQUESTS,
parseErrorMessage(responseMap)
);
}
}

@SuppressWarnings("unchecked")
private String parseErrorMessage(Map<String, Object> responseMap) {
var innerResponseMap = (Map<String, Object>) responseMap.get("response");
return (String) innerResponseMap.get("text");
}
}
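For context, the event_type values handled above come from Cohere's v1 chat streaming API. A few illustrative lines as the processor would receive them from NewlineDelimitedByteProcessor (field names match the parsing code above; payload values are made up):

    // skipped with a debug log
    String streamStart = "{\"event_type\":\"stream-start\"}";
    // contributes one chunk to the streamed answer
    String textGeneration = "{\"event_type\":\"text-generation\",\"text\":\"Hello\"}";
    // finish_reason is validated; ERROR and ERROR_TOXIC map to 500, ERROR_LIMIT maps to 429
    String streamEnd = "{\"event_type\":\"stream-end\",\"finish_reason\":\"COMPLETE\",\"response\":{\"text\":\"Hello\"}}";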
CohereCompletionRequestManager.java
@@ -19,7 +19,6 @@
import org.elasticsearch.xpack.inference.external.response.cohere.CohereCompletionResponseEntity;
import org.elasticsearch.xpack.inference.services.cohere.completion.CohereCompletionModel;

import java.util.List;
import java.util.Objects;
import java.util.function.Supplier;

Expand All @@ -30,7 +29,7 @@ public class CohereCompletionRequestManager extends CohereRequestManager {
private static final ResponseHandler HANDLER = createCompletionHandler();

private static ResponseHandler createCompletionHandler() {
return new CohereResponseHandler("cohere completion", CohereCompletionResponseEntity::fromResponse);
return new CohereResponseHandler("cohere completion", CohereCompletionResponseEntity::fromResponse, true);
}

public static CohereCompletionRequestManager of(CohereCompletionModel model, ThreadPool threadPool) {
@@ -51,8 +50,10 @@ public void execute(
Supplier<Boolean> hasRequestCompletedFunction,
ActionListener<InferenceServiceResults> listener
) {
List<String> docsInput = DocumentsOnlyInput.of(inferenceInputs).getInputs();
CohereCompletionRequest request = new CohereCompletionRequest(docsInput, model);
var docsOnly = DocumentsOnlyInput.of(inferenceInputs);
var docsInput = docsOnly.getInputs();
var stream = docsOnly.stream();
CohereCompletionRequest request = new CohereCompletionRequest(docsInput, model, stream);

execute(new ExecutableInferenceRequest(requestSender, logger, request, HANDLER, hasRequestCompletedFunction, listener));
}
CohereEmbeddingsRequestManager.java
@@ -28,7 +28,7 @@ public class CohereEmbeddingsRequestManager extends CohereRequestManager {
private static final ResponseHandler HANDLER = createEmbeddingsHandler();

private static ResponseHandler createEmbeddingsHandler() {
return new CohereResponseHandler("cohere text embedding", CohereEmbeddingsResponseEntity::fromResponse);
return new CohereResponseHandler("cohere text embedding", CohereEmbeddingsResponseEntity::fromResponse, false);
}

public static CohereEmbeddingsRequestManager of(CohereEmbeddingsModel model, ThreadPool threadPool) {
CohereRerankRequestManager.java
@@ -27,7 +27,7 @@ public class CohereRerankRequestManager extends CohereRequestManager {
private static final ResponseHandler HANDLER = createCohereResponseHandler();

private static ResponseHandler createCohereResponseHandler() {
return new CohereResponseHandler("cohere rerank", (request, response) -> CohereRankedResponseEntity.fromResponse(response));
return new CohereResponseHandler("cohere rerank", (request, response) -> CohereRankedResponseEntity.fromResponse(response), false);
}

public static CohereRerankRequestManager of(CohereRerankModel model, ThreadPool threadPool) {
CohereCompletionRequest.java
@@ -25,30 +25,28 @@
import java.util.Objects;

public class CohereCompletionRequest extends CohereRequest {

private final CohereAccount account;

private final List<String> input;

private final String modelId;

private final String inferenceEntityId;
private final boolean stream;

public CohereCompletionRequest(List<String> input, CohereCompletionModel model) {
public CohereCompletionRequest(List<String> input, CohereCompletionModel model, boolean stream) {
Objects.requireNonNull(model);

this.account = CohereAccount.of(model, CohereCompletionRequest::buildDefaultUri);
this.input = Objects.requireNonNull(input);
this.modelId = model.getServiceSettings().modelId();
this.inferenceEntityId = model.getInferenceEntityId();
this.stream = stream;
}

@Override
public HttpRequest createHttpRequest() {
HttpPost httpPost = new HttpPost(account.uri());

ByteArrayEntity byteEntity = new ByteArrayEntity(
Strings.toString(new CohereCompletionRequestEntity(input, modelId)).getBytes(StandardCharsets.UTF_8)
Strings.toString(new CohereCompletionRequestEntity(input, modelId, isStreaming())).getBytes(StandardCharsets.UTF_8)
);
httpPost.setEntity(byteEntity);

@@ -62,6 +60,11 @@ public String getInferenceEntityId() {
return inferenceEntityId;
}

@Override
public boolean isStreaming() {
return stream;
}

@Override
public URI getURI() {
return account.uri();
CohereCompletionRequestEntity.java
@@ -15,11 +15,11 @@
import java.util.List;
import java.util.Objects;

public record CohereCompletionRequestEntity(List<String> input, @Nullable String model) implements ToXContentObject {
public record CohereCompletionRequestEntity(List<String> input, @Nullable String model, boolean stream) implements ToXContentObject {

private static final String MESSAGE_FIELD = "message";

private static final String MODEL = "model";
private static final String STREAM = "stream";

public CohereCompletionRequestEntity {
Objects.requireNonNull(input);
@@ -36,6 +36,10 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
builder.field(MODEL, model);
}

if (stream) {
builder.field(STREAM, true);
}

builder.endObject();

return builder;
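With the new flag in place, a streamed completion request body would look roughly like the following (assuming, as the MESSAGE_FIELD constant suggests, that the first input string is written under "message"; values are illustrative):

    {
      "message": "Tell me a joke",
      "model": "command-r",
      "stream": true
    }

Because the field is only written when stream is true, non-streaming requests serialize exactly as they did before this change.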
NewlineDelimitedByteProcessor.java (new file)
@@ -0,0 +1,67 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.inference.external.response.streaming;

import org.elasticsearch.xpack.inference.common.DelegatingProcessor;
import org.elasticsearch.xpack.inference.external.http.HttpResult;

import java.nio.charset.StandardCharsets;
import java.util.ArrayDeque;
import java.util.Deque;
import java.util.regex.Pattern;

/**
* Processes HttpResult bytes into lines separated by newlines, delimited by either line-feed or carriage-return line-feed.
* Downstream is responsible for validating the structure of the lines after they have been separated.
* Because Upstream (Apache) can send us a single line split between two HttpResults, this processor will aggregate bytes from the last
* HttpResult and append them to the front of the next HttpResult.
* When onComplete is called, the last batch is always flushed to the downstream onNext.
*/
public class NewlineDelimitedByteProcessor extends DelegatingProcessor<HttpResult, Deque<String>> {
private static final Pattern END_OF_LINE_REGEX = Pattern.compile("\\n|\\r\\n");
private volatile String previousTokens = "";

@Override
protected void next(HttpResult item) {
// discard empty result and go to the next
if (item.isBodyEmpty()) {
upstream().request(1);
return;
}

var body = previousTokens + new String(item.body(), StandardCharsets.UTF_8);
var lines = END_OF_LINE_REGEX.split(body, -1); // -1 because we actually want trailing empty strings

var results = new ArrayDeque<String>(lines.length);
for (var i = 0; i < lines.length - 1; i++) {
var line = lines[i].trim();
if (line.isBlank() == false) {
results.offer(line);
}
}

previousTokens = lines[lines.length - 1].trim();

if (results.isEmpty()) {
upstream().request(1);
} else {
downstream().onNext(results);
}
}

@Override
public void onComplete() {
if (previousTokens.isBlank()) {
super.onComplete();
} else if (isClosed.compareAndSet(false, true)) {
var results = new ArrayDeque<String>(1);
results.offer(previousTokens);
downstream().onNext(results);
}
}
}
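The -1 limit passed to split is what makes the partial-line buffering work: a chunk that ends exactly on a newline produces a trailing empty string (nothing to carry over), while a chunk cut mid-line produces the unfinished fragment, which is prepended to the next chunk. A minimal standalone illustration of that behaviour, independent of the Elasticsearch classes (sample payloads are made up):

    import java.util.regex.Pattern;

    class NdSplitDemo {
        private static final Pattern END_OF_LINE = Pattern.compile("\\n|\\r\\n");

        public static void main(String[] args) {
            // two HTTP chunks that split one ND-JSON line across the chunk boundary
            String chunk1 = "{\"event_type\":\"text-generation\",\"text\":\"Hel";
            String chunk2 = "lo\"}\n{\"event_type\":\"stream-end\"}\n";

            String buffered = "";
            for (String chunk : new String[] { chunk1, chunk2 }) {
                String[] lines = END_OF_LINE.split(buffered + chunk, -1); // -1 keeps the trailing fragment
                for (int i = 0; i < lines.length - 1; i++) {
                    String line = lines[i].trim();
                    if (line.isBlank() == false) {
                        System.out.println("complete line: " + line);
                    }
                }
                buffered = lines[lines.length - 1].trim(); // carry the unfinished tail into the next chunk
            }
            // chunk1 emits nothing; chunk2 emits both complete lines
        }
    }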
@@ -39,6 +39,7 @@

import java.util.List;
import java.util.Map;
import java.util.Set;

import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidModelException;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.parsePersistedConfigErrorMsg;
@@ -288,4 +289,9 @@ static SimilarityMeasure defaultSimilarity() {
public TransportVersion getMinimalSupportedVersion() {
return TransportVersions.ML_INFERENCE_RATE_LIMIT_SETTINGS_ADDED;
}

@Override
public Set<TaskType> supportedStreamingTasks() {
return COMPLETION_ONLY;
}
}
@@ -132,7 +132,7 @@ private static void callCheckForFailureStatusCode(int statusCode, @Nullable Stri
var mockRequest = mock(Request.class);
when(mockRequest.getInferenceEntityId()).thenReturn(modelId);
var httpResult = new HttpResult(httpResponse, errorMessage == null ? new byte[] {} : responseJson.getBytes(StandardCharsets.UTF_8));
var handler = new CohereResponseHandler("", (request, result) -> null);
var handler = new CohereResponseHandler("", (request, result) -> null, false);

handler.checkForFailureStatusCode(mockRequest, httpResult);
}