Skip to content

Commit

Permalink
[ML] Fix stream support for TaskType.ANY
Browse files Browse the repository at this point in the history
If we support one, then we support any.
  • Loading branch information
prwhelan committed Oct 25, 2024
1 parent 6e0bdbe commit 3ad8595
Show file tree
Hide file tree
Showing 8 changed files with 51 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,14 @@
import org.elasticsearch.xpack.inference.external.http.sender.Sender;

import java.io.IOException;
import java.util.EnumSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;

public abstract class SenderService implements InferenceService {
protected static final Set<TaskType> COMPLETION_ONLY = Set.of(TaskType.COMPLETION);
protected static final Set<TaskType> COMPLETION_ONLY = EnumSet.of(TaskType.COMPLETION, TaskType.ANY);
private final Sender sender;
private final ServiceComponents serviceComponents;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1304,6 +1304,13 @@ public void testInfer_UnauthorizedResponse() throws IOException {
}
}

public void testSupportsStreaming() throws IOException {
try (var service = new AmazonBedrockService(mock(), mock(), createWithEmptySettings(mock()))) {
assertTrue(service.canStream(TaskType.COMPLETION));
assertTrue(service.canStream(TaskType.ANY));
}
}

public void testChunkedInfer_CallsInfer_ConvertsFloatResponse_ForEmbeddings() throws IOException {
var model = AmazonBedrockEmbeddingsModelTests.createModel(
"id",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -593,6 +593,13 @@ public void testInfer_StreamRequest_ErrorResponse() throws Exception {
.hasErrorContaining("blah");
}

public void testSupportsStreaming() throws IOException {
try (var service = new AnthropicService(mock(), createWithEmptySettings(mock()))) {
assertTrue(service.canStream(TaskType.COMPLETION));
assertTrue(service.canStream(TaskType.ANY));
}
}

private AnthropicService createServiceWithMockSender() {
return new AnthropicService(mock(HttpRequestSender.Factory.class), createWithEmptySettings(threadPool));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1384,6 +1384,13 @@ public void testInfer_StreamRequest_ErrorResponse() throws Exception {
.hasErrorContaining("You didn't provide an API key...");
}

public void testSupportsStreaming() throws IOException {
try (var service = new AzureAiStudioService(mock(), createWithEmptySettings(mock()))) {
assertTrue(service.canStream(TaskType.COMPLETION));
assertTrue(service.canStream(TaskType.ANY));
}
}

// ----------------------------------------------------------------

private AzureAiStudioService createService() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1504,6 +1504,13 @@ public void testInfer_StreamRequest_ErrorResponse() throws Exception {
.hasErrorContaining("You didn't provide an API key...");
}

public void testSupportsStreaming() throws IOException {
try (var service = new AzureOpenAiService(mock(), createWithEmptySettings(mock()))) {
assertTrue(service.canStream(TaskType.COMPLETION));
assertTrue(service.canStream(TaskType.ANY));
}
}

private AzureOpenAiService createAzureOpenAiService() {
return new AzureOpenAiService(mock(HttpRequestSender.Factory.class), createWithEmptySettings(threadPool));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1683,6 +1683,13 @@ public void testInfer_StreamRequest_ErrorResponse() throws Exception {
.hasErrorContaining("how dare you");
}

public void testSupportsStreaming() throws IOException {
try (var service = new CohereService(mock(), createWithEmptySettings(mock()))) {
assertTrue(service.canStream(TaskType.COMPLETION));
assertTrue(service.canStream(TaskType.ANY));
}
}

private Map<String, Object> getRequestConfigMap(
Map<String, Object> serviceSettings,
Map<String, Object> taskSettings,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1219,6 +1219,13 @@ private void testUpdateModelWithEmbeddingDetails_Successful(SimilarityMeasure si
}
}

public void testSupportsStreaming() throws IOException {
try (var service = new GoogleAiStudioService(mock(), createWithEmptySettings(mock()))) {
assertTrue(service.canStream(TaskType.COMPLETION));
assertTrue(service.canStream(TaskType.ANY));
}
}

public static Map<String, Object> buildExpectationCompletions(List<String> completions) {
return Map.of(
ChatCompletionResults.COMPLETION,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1077,6 +1077,13 @@ public void testInfer_StreamRequest_ErrorResponse() throws Exception {
.hasErrorContaining("You didn't provide an API key...");
}

public void testSupportsStreaming() throws IOException {
try (var service = new OpenAiService(mock(), createWithEmptySettings(mock()))) {
assertTrue(service.canStream(TaskType.COMPLETION));
assertTrue(service.canStream(TaskType.ANY));
}
}

public void testCheckModelConfig_IncludesMaxTokens() throws IOException {
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);

Expand Down

0 comments on commit 3ad8595

Please sign in to comment.