Skip to content

Commit

Permalink
AAP-38593: ModelPipelines: dummy: All parameters are optional (#1493)
Browse files Browse the repository at this point in the history
  • Loading branch information
manstis authored Jan 14, 2025
1 parent a495222 commit fb78dfe
Show file tree
Hide file tree
Showing 3 changed files with 56 additions and 18 deletions.
23 changes: 11 additions & 12 deletions ansible_ai_connect/ai/api/model_pipelines/dummy/configuration.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,30 +19,29 @@
BaseConfig,
PipelineConfiguration,
)
from ansible_ai_connect.ai.api.model_pipelines.config_serializers import (
BaseConfigSerializer,
)
from ansible_ai_connect.ai.api.model_pipelines.registry import Register

# ANSIBLE_AI_MODEL_MESH_API_URL
# DUMMY_MODEL_RESPONSE_LATENCY_USE_JITTER
# DUMMY_MODEL_RESPONSE_MAX_LATENCY_MSEC
# DUMMY_MODEL_RESPONSE_BODY

DEFAULT_BODY = (
'{"predictions":'
'["ansible.builtin.apt:\\n name: nginx\\n update_cache: true\\n state: present\\n"]}'
)


class DummyConfiguration(BaseConfig):

def __init__(
self,
inference_url: str,
model_id: str,
timeout: Optional[int],
enable_health_check: Optional[bool],
latency_use_jitter: bool,
latency_max_msec: int,
body: str,
):
super().__init__(inference_url, model_id, timeout, enable_health_check)
super().__init__("dummy", "dummy", None, enable_health_check)
self.latency_use_jitter = latency_use_jitter
self.latency_max_msec = latency_max_msec
self.body = body
Expand All @@ -55,9 +54,6 @@ def __init__(self, **kwargs):
super().__init__(
"dummy",
DummyConfiguration(
inference_url=kwargs["inference_url"],
model_id=kwargs["model_id"],
timeout=kwargs["timeout"],
enable_health_check=kwargs["enable_health_check"],
latency_use_jitter=kwargs["latency_use_jitter"],
latency_max_msec=kwargs["latency_max_msec"],
Expand All @@ -67,7 +63,10 @@ def __init__(self, **kwargs):


@Register(api_type="dummy")
class DummyConfigurationSerializer(BaseConfigSerializer):
class DummyConfigurationSerializer(serializers.Serializer):
enable_health_check = serializers.BooleanField(required=False, default=False)
latency_use_jitter = serializers.BooleanField(required=False, default=False)
latency_max_msec = serializers.IntegerField(required=False, default=3000)
body = serializers.CharField(required=True)
body = serializers.CharField(
required=False, allow_null=True, allow_blank=True, default=DEFAULT_BODY
)
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
# Copyright Red Hat
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from ansible_ai_connect.ai.api.model_pipelines.dummy.configuration import (
DEFAULT_BODY,
DummyConfiguration,
DummyConfigurationSerializer,
)
from ansible_ai_connect.ai.api.model_pipelines.tests import mock_pipeline_config
from ansible_ai_connect.test_utils import WisdomServiceAPITestCaseBaseOIDC


class TestDummyConfigurationSerializer(WisdomServiceAPITestCaseBaseOIDC):

def test_empty(self):
serializer = DummyConfigurationSerializer(data={})
self.assertTrue(serializer.is_valid())
self.assertEqual(serializer.validated_data["enable_health_check"], False)
self.assertEqual(serializer.validated_data["latency_use_jitter"], False)
self.assertEqual(serializer.validated_data["latency_max_msec"], 3000)
self.assertEqual(serializer.validated_data["body"], DEFAULT_BODY)

def test_serializer_with_body(self):
config: DummyConfiguration = mock_pipeline_config("dummy")
serializer = DummyConfigurationSerializer(data=config.__dict__)
self.assertTrue(serializer.is_valid())

def test_serializer_without_body(self):
config: DummyConfiguration = mock_pipeline_config("dummy")
del config.__dict__["body"]
serializer = DummyConfigurationSerializer(data=config.__dict__)
self.assertTrue(serializer.is_valid())
self.assertEqual(serializer.validated_data["body"], DEFAULT_BODY)
7 changes: 1 addition & 6 deletions ansible_ai_connect/ai/api/model_pipelines/tests/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,17 +66,12 @@ def extract(name: str, default: T, **kwargs) -> T:
def mock_pipeline_config(pipeline_provider: t_model_mesh_api_type, **kwargs):
match pipeline_provider:
case "dummy":
inference_url = extract("inference_url", "http://localhost", **kwargs)
c = DummyConfiguration(
inference_url=inference_url,
model_id=extract("model_id", "a-model-id", **kwargs),
timeout=extract("timeout", 1000, **kwargs),
return DummyConfiguration(
enable_health_check=extract("enable_health_check", False, **kwargs),
latency_use_jitter=extract("latency_use_jitter", False, **kwargs),
latency_max_msec=extract("latency_max_msec", 0, **kwargs),
body=extract("body", "body", **kwargs),
)
return c
case "grpc":
return GrpcConfiguration(
inference_url=extract("inference_url", "http://localhost", **kwargs),
Expand Down

0 comments on commit fb78dfe

Please sign in to comment.