forked from elastic/elasticsearch
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[ML] Stream Cohere Completion (elastic#114080)
Implement and enable streaming for Cohere chat completions (v1). Includes processor for ND JSON streaming responses. Co-authored-by: Elastic Machine <[email protected]>
- Loading branch information
1 parent
46b0696
commit c8390ef
Showing
17 changed files
with
585 additions
and
26 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
pr: 114080 | ||
summary: Stream Cohere Completion | ||
area: Machine Learning | ||
type: enhancement | ||
issues: [] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
101 changes: 101 additions & 0 deletions
101
...main/java/org/elasticsearch/xpack/inference/external/cohere/CohereStreamingProcessor.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,101 @@ | ||
/* | ||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one | ||
* or more contributor license agreements. Licensed under the Elastic License | ||
* 2.0; you may not use this file except in compliance with the Elastic License | ||
* 2.0. | ||
*/ | ||
|
||
package org.elasticsearch.xpack.inference.external.cohere; | ||
|
||
import org.apache.logging.log4j.LogManager; | ||
import org.apache.logging.log4j.Logger; | ||
import org.elasticsearch.ElasticsearchStatusException; | ||
import org.elasticsearch.rest.RestStatus; | ||
import org.elasticsearch.xcontent.XContentFactory; | ||
import org.elasticsearch.xcontent.XContentParser; | ||
import org.elasticsearch.xcontent.XContentParserConfiguration; | ||
import org.elasticsearch.xcontent.XContentType; | ||
import org.elasticsearch.xpack.core.inference.results.StreamingChatCompletionResults; | ||
import org.elasticsearch.xpack.inference.common.DelegatingProcessor; | ||
|
||
import java.io.IOException; | ||
import java.util.ArrayDeque; | ||
import java.util.Deque; | ||
import java.util.Map; | ||
import java.util.Optional; | ||
|
||
class CohereStreamingProcessor extends DelegatingProcessor<Deque<String>, StreamingChatCompletionResults.Results> { | ||
private static final Logger log = LogManager.getLogger(CohereStreamingProcessor.class); | ||
|
||
@Override | ||
protected void next(Deque<String> item) throws Exception { | ||
if (item.isEmpty()) { | ||
// discard empty result and go to the next | ||
upstream().request(1); | ||
return; | ||
} | ||
|
||
var results = new ArrayDeque<StreamingChatCompletionResults.Result>(item.size()); | ||
for (String json : item) { | ||
try (var jsonParser = jsonParser(json)) { | ||
var responseMap = jsonParser.map(); | ||
var eventType = (String) responseMap.get("event_type"); | ||
switch (eventType) { | ||
case "text-generation" -> parseText(responseMap).ifPresent(results::offer); | ||
case "stream-end" -> validateResponse(responseMap); | ||
case "stream-start", "search-queries-generation", "search-results", "citation-generation", "tool-calls-generation", | ||
"tool-calls-chunk" -> { | ||
log.debug("Skipping event type [{}] for line [{}].", eventType, item); | ||
} | ||
default -> throw new IOException("Unknown eventType found: " + eventType); | ||
} | ||
} catch (ElasticsearchStatusException e) { | ||
throw e; | ||
} catch (Exception e) { | ||
log.warn("Failed to parse json from cohere: {}", json); | ||
throw e; | ||
} | ||
} | ||
|
||
if (results.isEmpty()) { | ||
upstream().request(1); | ||
} else { | ||
downstream().onNext(new StreamingChatCompletionResults.Results(results)); | ||
} | ||
} | ||
|
||
private static XContentParser jsonParser(String line) throws IOException { | ||
return XContentFactory.xContent(XContentType.JSON).createParser(XContentParserConfiguration.EMPTY, line); | ||
} | ||
|
||
private Optional<StreamingChatCompletionResults.Result> parseText(Map<String, Object> responseMap) throws IOException { | ||
var text = (String) responseMap.get("text"); | ||
if (text != null) { | ||
return Optional.of(new StreamingChatCompletionResults.Result(text)); | ||
} else { | ||
throw new IOException("Null text found in text-generation cohere event"); | ||
} | ||
} | ||
|
||
private void validateResponse(Map<String, Object> responseMap) { | ||
var finishReason = (String) responseMap.get("finish_reason"); | ||
switch (finishReason) { | ||
case "ERROR", "ERROR_TOXIC" -> throw new ElasticsearchStatusException( | ||
"Cohere stopped the stream due to an error: {}", | ||
RestStatus.INTERNAL_SERVER_ERROR, | ||
parseErrorMessage(responseMap) | ||
); | ||
case "ERROR_LIMIT" -> throw new ElasticsearchStatusException( | ||
"Cohere stopped the stream due to an error: {}", | ||
RestStatus.TOO_MANY_REQUESTS, | ||
parseErrorMessage(responseMap) | ||
); | ||
} | ||
} | ||
|
||
@SuppressWarnings("unchecked") | ||
private String parseErrorMessage(Map<String, Object> responseMap) { | ||
var innerResponseMap = (Map<String, Object>) responseMap.get("response"); | ||
return (String) innerResponseMap.get("text"); | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
67 changes: 67 additions & 0 deletions
67
...sticsearch/xpack/inference/external/response/streaming/NewlineDelimitedByteProcessor.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,67 @@ | ||
/* | ||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one | ||
* or more contributor license agreements. Licensed under the Elastic License | ||
* 2.0; you may not use this file except in compliance with the Elastic License | ||
* 2.0. | ||
*/ | ||
|
||
package org.elasticsearch.xpack.inference.external.response.streaming; | ||
|
||
import org.elasticsearch.xpack.inference.common.DelegatingProcessor; | ||
import org.elasticsearch.xpack.inference.external.http.HttpResult; | ||
|
||
import java.nio.charset.StandardCharsets; | ||
import java.util.ArrayDeque; | ||
import java.util.Deque; | ||
import java.util.regex.Pattern; | ||
|
||
/** | ||
* Processes HttpResult bytes into lines separated by newlines, delimited by either line-feed or carriage-return line-feed. | ||
* Downstream is responsible for validating the structure of the lines after they have been separated. | ||
* Because Upstream (Apache) can send us a single line split between two HttpResults, this processor will aggregate bytes from the last | ||
* HttpResult and append them to the front of the next HttpResult. | ||
* When onComplete is called, the last batch is always flushed to the downstream onNext. | ||
*/ | ||
public class NewlineDelimitedByteProcessor extends DelegatingProcessor<HttpResult, Deque<String>> { | ||
private static final Pattern END_OF_LINE_REGEX = Pattern.compile("\\n|\\r\\n"); | ||
private volatile String previousTokens = ""; | ||
|
||
@Override | ||
protected void next(HttpResult item) { | ||
// discard empty result and go to the next | ||
if (item.isBodyEmpty()) { | ||
upstream().request(1); | ||
return; | ||
} | ||
|
||
var body = previousTokens + new String(item.body(), StandardCharsets.UTF_8); | ||
var lines = END_OF_LINE_REGEX.split(body, -1); // -1 because we actually want trailing empty strings | ||
|
||
var results = new ArrayDeque<String>(lines.length); | ||
for (var i = 0; i < lines.length - 1; i++) { | ||
var line = lines[i].trim(); | ||
if (line.isBlank() == false) { | ||
results.offer(line); | ||
} | ||
} | ||
|
||
previousTokens = lines[lines.length - 1].trim(); | ||
|
||
if (results.isEmpty()) { | ||
upstream().request(1); | ||
} else { | ||
downstream().onNext(results); | ||
} | ||
} | ||
|
||
@Override | ||
public void onComplete() { | ||
if (previousTokens.isBlank()) { | ||
super.onComplete(); | ||
} else if (isClosed.compareAndSet(false, true)) { | ||
var results = new ArrayDeque<String>(1); | ||
results.offer(previousTokens); | ||
downstream().onNext(results); | ||
} | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.