diff --git a/ansible_ai_connect/ai/api/model_pipelines/dummy/configuration.py b/ansible_ai_connect/ai/api/model_pipelines/dummy/configuration.py
index 40947286c..64d2995b8 100644
--- a/ansible_ai_connect/ai/api/model_pipelines/dummy/configuration.py
+++ b/ansible_ai_connect/ai/api/model_pipelines/dummy/configuration.py
@@ -19,9 +19,6 @@
     BaseConfig,
     PipelineConfiguration,
 )
-from ansible_ai_connect.ai.api.model_pipelines.config_serializers import (
-    BaseConfigSerializer,
-)
 from ansible_ai_connect.ai.api.model_pipelines.registry import Register
 
 # ANSIBLE_AI_MODEL_MESH_API_URL
@@ -29,20 +26,22 @@
 # DUMMY_MODEL_RESPONSE_MAX_LATENCY_MSEC
 # DUMMY_MODEL_RESPONSE_BODY
 
+DEFAULT_BODY = (
+    '{"predictions":'
+    '["ansible.builtin.apt:\\n name: nginx\\n update_cache: true\\n state: present\\n"]}'
+)
+
 
 
 class DummyConfiguration(BaseConfig):
     def __init__(
         self,
-        inference_url: str,
-        model_id: str,
-        timeout: Optional[int],
         enable_health_check: Optional[bool],
         latency_use_jitter: bool,
         latency_max_msec: int,
         body: str,
     ):
-        super().__init__(inference_url, model_id, timeout, enable_health_check)
+        super().__init__("dummy", "dummy", None, enable_health_check)
         self.latency_use_jitter = latency_use_jitter
         self.latency_max_msec = latency_max_msec
         self.body = body
@@ -55,9 +54,6 @@ def __init__(self, **kwargs):
         super().__init__(
             "dummy",
             DummyConfiguration(
-                inference_url=kwargs["inference_url"],
-                model_id=kwargs["model_id"],
-                timeout=kwargs["timeout"],
                 enable_health_check=kwargs["enable_health_check"],
                 latency_use_jitter=kwargs["latency_use_jitter"],
                 latency_max_msec=kwargs["latency_max_msec"],
@@ -67,7 +63,10 @@ def __init__(self, **kwargs):
 
 
 @Register(api_type="dummy")
-class DummyConfigurationSerializer(BaseConfigSerializer):
+class DummyConfigurationSerializer(serializers.Serializer):
+    enable_health_check = serializers.BooleanField(required=False, default=False)
     latency_use_jitter = serializers.BooleanField(required=False, default=False)
     latency_max_msec = serializers.IntegerField(required=False, default=3000)
-    body = serializers.CharField(required=True)
+    body = serializers.CharField(
+        required=False, allow_null=True, allow_blank=True, default=DEFAULT_BODY
+    )
diff --git a/ansible_ai_connect/ai/api/model_pipelines/dummy/tests/test_configuration.py b/ansible_ai_connect/ai/api/model_pipelines/dummy/tests/test_configuration.py
new file mode 100644
index 000000000..21cb087e8
--- /dev/null
+++ b/ansible_ai_connect/ai/api/model_pipelines/dummy/tests/test_configuration.py
@@ -0,0 +1,44 @@
+#  Copyright Red Hat
+#
+#  Licensed under the Apache License, Version 2.0 (the "License");
+#  you may not use this file except in compliance with the License.
+#  You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+#  Unless required by applicable law or agreed to in writing, software
+#  distributed under the License is distributed on an "AS IS" BASIS,
+#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+#  See the License for the specific language governing permissions and
+#  limitations under the License.
+
+from ansible_ai_connect.ai.api.model_pipelines.dummy.configuration import (
+    DEFAULT_BODY,
+    DummyConfiguration,
+    DummyConfigurationSerializer,
+)
+from ansible_ai_connect.ai.api.model_pipelines.tests import mock_pipeline_config
+from ansible_ai_connect.test_utils import WisdomServiceAPITestCaseBaseOIDC
+
+
+class TestDummyConfigurationSerializer(WisdomServiceAPITestCaseBaseOIDC):
+
+    def test_empty(self):
+        serializer = DummyConfigurationSerializer(data={})
+        self.assertTrue(serializer.is_valid())
+        self.assertEqual(serializer.validated_data["enable_health_check"], False)
+        self.assertEqual(serializer.validated_data["latency_use_jitter"], False)
+        self.assertEqual(serializer.validated_data["latency_max_msec"], 3000)
+        self.assertEqual(serializer.validated_data["body"], DEFAULT_BODY)
+
+    def test_serializer_with_body(self):
+        config: DummyConfiguration = mock_pipeline_config("dummy")
+        serializer = DummyConfigurationSerializer(data=config.__dict__)
+        self.assertTrue(serializer.is_valid())
+
+    def test_serializer_without_body(self):
+        config: DummyConfiguration = mock_pipeline_config("dummy")
+        del config.__dict__["body"]
+        serializer = DummyConfigurationSerializer(data=config.__dict__)
+        self.assertTrue(serializer.is_valid())
+        self.assertEqual(serializer.validated_data["body"], DEFAULT_BODY)
diff --git a/ansible_ai_connect/ai/api/model_pipelines/tests/__init__.py b/ansible_ai_connect/ai/api/model_pipelines/tests/__init__.py
index f70aa9086..4ee8e825d 100644
--- a/ansible_ai_connect/ai/api/model_pipelines/tests/__init__.py
+++ b/ansible_ai_connect/ai/api/model_pipelines/tests/__init__.py
@@ -66,17 +66,12 @@ def extract(name: str, default: T, **kwargs) -> T:
 def mock_pipeline_config(pipeline_provider: t_model_mesh_api_type, **kwargs):
     match pipeline_provider:
         case "dummy":
-            inference_url = extract("inference_url", "http://localhost", **kwargs)
-            c = DummyConfiguration(
-                inference_url=inference_url,
-                model_id=extract("model_id", "a-model-id", **kwargs),
-                timeout=extract("timeout", 1000, **kwargs),
+            return DummyConfiguration(
                 enable_health_check=extract("enable_health_check", False, **kwargs),
                 latency_use_jitter=extract("latency_use_jitter", False, **kwargs),
                 latency_max_msec=extract("latency_max_msec", 0, **kwargs),
                 body=extract("body", "body", **kwargs),
             )
-            return c
         case "grpc":
             return GrpcConfiguration(
                 inference_url=extract("inference_url", "http://localhost", **kwargs),
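
For reference, a minimal sketch of how the relaxed serializer behaves after this change, assuming a configured Django/DRF environment; the names and default values below are taken directly from the diff, nothing else is implied.

    from ansible_ai_connect.ai.api.model_pipelines.dummy.configuration import (
        DEFAULT_BODY,
        DummyConfigurationSerializer,
    )

    # An empty payload is now valid: every field falls back to its declared default.
    serializer = DummyConfigurationSerializer(data={})
    assert serializer.is_valid()
    assert serializer.validated_data["enable_health_check"] is False
    assert serializer.validated_data["latency_use_jitter"] is False
    assert serializer.validated_data["latency_max_msec"] == 3000
    assert serializer.validated_data["body"] == DEFAULT_BODY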