Skip to content

Commit

Permalink
Fixing tests (#116032)
Browse files Browse the repository at this point in the history
  • Loading branch information
ymao1 authored Oct 31, 2024
1 parent b280e94 commit 0f38b2b
Showing 1 changed file with 28 additions and 10 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,10 @@
import org.elasticsearch.client.ResponseException;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceFeature;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
Expand Down Expand Up @@ -132,7 +134,11 @@ public void testApisWithoutTaskType() throws IOException {
@SuppressWarnings("unchecked")
public void testGetServicesWithoutTaskType() throws IOException {
List<Object> services = getAllServices();
assertThat(services.size(), equalTo(19));
if (ElasticInferenceServiceFeature.ELASTIC_INFERENCE_SERVICE_FEATURE_FLAG.isEnabled()) {
assertThat(services.size(), equalTo(19));
} else {
assertThat(services.size(), equalTo(18));
}

String[] providers = new String[services.size()];
for (int i = 0; i < services.size(); i++) {
Expand All @@ -141,16 +147,15 @@ public void testGetServicesWithoutTaskType() throws IOException {
}

Arrays.sort(providers);
assertArrayEquals(
providers,
List.of(

var providerList = new ArrayList<>(
Arrays.asList(
"alibabacloud-ai-search",
"amazonbedrock",
"anthropic",
"azureaistudio",
"azureopenai",
"cohere",
"elastic",
"elasticsearch",
"googleaistudio",
"googlevertexai",
Expand All @@ -163,8 +168,12 @@ public void testGetServicesWithoutTaskType() throws IOException {
"test_service",
"text_embedding_test_service",
"watsonxai"
).toArray()
)
);
if (ElasticInferenceServiceFeature.ELASTIC_INFERENCE_SERVICE_FEATURE_FLAG.isEnabled()) {
providerList.add(6, "elastic");
}
assertArrayEquals(providers, providerList.toArray());
}

@SuppressWarnings("unchecked")
Expand Down Expand Up @@ -248,7 +257,12 @@ public void testGetServicesWithCompletionTaskType() throws IOException {
@SuppressWarnings("unchecked")
public void testGetServicesWithSparseEmbeddingTaskType() throws IOException {
List<Object> services = getServices(TaskType.SPARSE_EMBEDDING);
assertThat(services.size(), equalTo(6));

if (ElasticInferenceServiceFeature.ELASTIC_INFERENCE_SERVICE_FEATURE_FLAG.isEnabled()) {
assertThat(services.size(), equalTo(6));
} else {
assertThat(services.size(), equalTo(5));
}

String[] providers = new String[services.size()];
for (int i = 0; i < services.size(); i++) {
Expand All @@ -257,10 +271,14 @@ public void testGetServicesWithSparseEmbeddingTaskType() throws IOException {
}

Arrays.sort(providers);
assertArrayEquals(
providers,
List.of("alibabacloud-ai-search", "elastic", "elasticsearch", "hugging_face", "hugging_face_elser", "test_service").toArray()

var providerList = new ArrayList<>(
Arrays.asList("alibabacloud-ai-search", "elasticsearch", "hugging_face", "hugging_face_elser", "test_service")
);
if (ElasticInferenceServiceFeature.ELASTIC_INFERENCE_SERVICE_FEATURE_FLAG.isEnabled()) {
providerList.add(1, "elastic");
}
assertArrayEquals(providers, providerList.toArray());
}

public void testSkipValidationAndStart() throws IOException {
Expand Down

0 comments on commit 0f38b2b

Please sign in to comment.