From 0f38b2b10e898b2478debd1a25525f1b2ff11bfb Mon Sep 17 00:00:00 2001 From: Ying Mao Date: Thu, 31 Oct 2024 14:08:58 -0400 Subject: [PATCH] Fixing tests (#116032) --- .../xpack/inference/InferenceCrudIT.java | 38 ++++++++++++++----- 1 file changed, 28 insertions(+), 10 deletions(-) diff --git a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceCrudIT.java b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceCrudIT.java index fed63477701e3..f9a1318cd9740 100644 --- a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceCrudIT.java +++ b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceCrudIT.java @@ -13,8 +13,10 @@ import org.elasticsearch.client.ResponseException; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.inference.TaskType; +import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceFeature; import java.io.IOException; +import java.util.ArrayList; import java.util.Arrays; import java.util.List; import java.util.Map; @@ -132,7 +134,11 @@ public void testApisWithoutTaskType() throws IOException { @SuppressWarnings("unchecked") public void testGetServicesWithoutTaskType() throws IOException { List services = getAllServices(); - assertThat(services.size(), equalTo(19)); + if (ElasticInferenceServiceFeature.ELASTIC_INFERENCE_SERVICE_FEATURE_FLAG.isEnabled()) { + assertThat(services.size(), equalTo(19)); + } else { + assertThat(services.size(), equalTo(18)); + } String[] providers = new String[services.size()]; for (int i = 0; i < services.size(); i++) { @@ -141,16 +147,15 @@ public void testGetServicesWithoutTaskType() throws IOException { } Arrays.sort(providers); - assertArrayEquals( - providers, - List.of( + + var providerList = new ArrayList<>( + Arrays.asList( "alibabacloud-ai-search", "amazonbedrock", "anthropic", "azureaistudio", "azureopenai", "cohere", - "elastic", "elasticsearch", "googleaistudio", "googlevertexai", @@ -163,8 +168,12 @@ public void testGetServicesWithoutTaskType() throws IOException { "test_service", "text_embedding_test_service", "watsonxai" - ).toArray() + ) ); + if (ElasticInferenceServiceFeature.ELASTIC_INFERENCE_SERVICE_FEATURE_FLAG.isEnabled()) { + providerList.add(6, "elastic"); + } + assertArrayEquals(providers, providerList.toArray()); } @SuppressWarnings("unchecked") @@ -248,7 +257,12 @@ public void testGetServicesWithCompletionTaskType() throws IOException { @SuppressWarnings("unchecked") public void testGetServicesWithSparseEmbeddingTaskType() throws IOException { List services = getServices(TaskType.SPARSE_EMBEDDING); - assertThat(services.size(), equalTo(6)); + + if (ElasticInferenceServiceFeature.ELASTIC_INFERENCE_SERVICE_FEATURE_FLAG.isEnabled()) { + assertThat(services.size(), equalTo(6)); + } else { + assertThat(services.size(), equalTo(5)); + } String[] providers = new String[services.size()]; for (int i = 0; i < services.size(); i++) { @@ -257,10 +271,14 @@ public void testGetServicesWithSparseEmbeddingTaskType() throws IOException { } Arrays.sort(providers); - assertArrayEquals( - providers, - List.of("alibabacloud-ai-search", "elastic", "elasticsearch", "hugging_face", "hugging_face_elser", "test_service").toArray() + + var providerList = new ArrayList<>( + Arrays.asList("alibabacloud-ai-search", "elasticsearch", "hugging_face", "hugging_face_elser", "test_service") ); + if (ElasticInferenceServiceFeature.ELASTIC_INFERENCE_SERVICE_FEATURE_FLAG.isEnabled()) { + providerList.add(1, "elastic"); + } + assertArrayEquals(providers, providerList.toArray()); } public void testSkipValidationAndStart() throws IOException {