diff --git a/ansible_ai_connect/ai/api/model_pipelines/__init__.py b/ansible_ai_connect/ai/api/model_pipelines/__init__.py index 13e5e9bb2..5cf8e245c 100644 --- a/ansible_ai_connect/ai/api/model_pipelines/__init__.py +++ b/ansible_ai_connect/ai/api/model_pipelines/__init__.py @@ -1,5 +1,3 @@ -import ansible_ai_connect.ai.api.model_pipelines.bam.configuration # noqa -import ansible_ai_connect.ai.api.model_pipelines.bam.pipelines # noqa import ansible_ai_connect.ai.api.model_pipelines.dummy.configuration # noqa import ansible_ai_connect.ai.api.model_pipelines.dummy.pipelines # noqa import ansible_ai_connect.ai.api.model_pipelines.grpc.configuration # noqa diff --git a/ansible_ai_connect/ai/api/model_pipelines/bam/__init__.py b/ansible_ai_connect/ai/api/model_pipelines/bam/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/ansible_ai_connect/ai/api/model_pipelines/bam/configuration.py b/ansible_ai_connect/ai/api/model_pipelines/bam/configuration.py deleted file mode 100644 index 18afecd8c..000000000 --- a/ansible_ai_connect/ai/api/model_pipelines/bam/configuration.py +++ /dev/null @@ -1,82 +0,0 @@ -# Copyright Red Hat -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from dataclasses import dataclass -from typing import Optional - -from rest_framework import serializers - -from ansible_ai_connect.ai.api.model_pipelines.langchain.configuration import ( - LangchainBasePipelineConfiguration, - LangchainConfiguration, - LangchainConfigurationSerializer, -) -from ansible_ai_connect.ai.api.model_pipelines.registry import Register - -# -- Base -# ANSIBLE_AI_MODEL_MESH_API_URL -# ANSIBLE_AI_MODEL_MESH_MODEL_ID -# ANSIBLE_AI_MODEL_MESH_API_TIMEOUT -# ENABLE_HEALTHCHECK_XXX - -# -- BAM -# ANSIBLE_AI_MODEL_MESH_API_KEY -# ANSIBLE_AI_MODEL_MESH_API_VERIFY_SSL - - -@dataclass -class BAMConfiguration(LangchainConfiguration): - - def __init__( - self, - inference_url: str, - model_id: str, - timeout: Optional[int], - enable_health_check: Optional[bool], - api_key: str, - verify_ssl: bool, - ): - super().__init__( - inference_url, - model_id, - timeout, - enable_health_check, - ) - self.api_key = api_key - self.verify_ssl = verify_ssl - - api_key: str - verify_ssl: bool - - -@Register(api_type="bam") -class BAMPipelineConfiguration(LangchainBasePipelineConfiguration): - - def __init__(self, **kwargs): - super().__init__( - "bam", - BAMConfiguration( - inference_url=kwargs["inference_url"], - model_id=kwargs["model_id"], - timeout=kwargs["timeout"], - enable_health_check=kwargs["enable_health_check"], - api_key=kwargs["api_key"], - verify_ssl=kwargs["verify_ssl"], - ), - ) - - -@Register(api_type="bam") -class BAMConfigurationSerializer(LangchainConfigurationSerializer): - api_key = serializers.CharField(required=True) - verify_ssl = serializers.BooleanField(required=False, default=False) diff --git a/ansible_ai_connect/ai/api/model_pipelines/bam/pipelines.py b/ansible_ai_connect/ai/api/model_pipelines/bam/pipelines.py deleted file mode 100644 index 5713f17df..000000000 --- a/ansible_ai_connect/ai/api/model_pipelines/bam/pipelines.py +++ /dev/null @@ -1,146 +0,0 @@ -# Copyright Red Hat -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import json -import logging -from typing import Any, Callable, List, Optional, Union - -import requests -from django.conf import settings -from langchain_core.callbacks import CallbackManagerForLLMRun -from langchain_core.language_models.chat_models import SimpleChatModel -from langchain_core.messages import BaseMessage - -from ansible_ai_connect.ai.api.model_pipelines.bam.configuration import BAMConfiguration -from ansible_ai_connect.ai.api.model_pipelines.langchain.pipelines import ( - LangchainCompletionsPipeline, - LangchainMetaData, - LangchainPlaybookExplanationPipeline, - LangchainPlaybookGenerationPipeline, -) -from ansible_ai_connect.ai.api.model_pipelines.registry import Register - -logger = logging.getLogger(__name__) - - -class ChatBAM(SimpleChatModel): - api_key: str - model_id: str - prediction_url: str - timeout: Callable[[int], Union[int, None]] - - @property - def _llm_type(self) -> str: - return "BAM" - - def _call( - self, - messages: List[BaseMessage], - stop: Optional[List[str]] = None, - run_manager: Optional[CallbackManagerForLLMRun] = None, - **kwargs: Any, - ) -> str: - if stop is not None: - raise ValueError("stop kwargs are not permitted.") - - bam_messages = list( - map(lambda x: {"role": x.additional_kwargs["role"], "content": x.content}, messages) - ) - session = requests.Session() - headers = { - "Content-Type": "application/json", - "Authorization": f"Bearer {self.api_key}", - } - - params = { - "model_id": self.model_id, - "messages": bam_messages, - "parameters": { - "temperature": 0.1, - "decoding_method": "greedy", - "repetition_penalty": 1.05, - "min_new_tokens": 1, - "max_new_tokens": 2048, - }, - } - - logger.info(f"request: {params}") - - result = session.post( - self.prediction_url, - headers=headers, - json=params, - timeout=self.timeout(1), - verify=settings.ANSIBLE_AI_MODEL_MESH_API_VERIFY_SSL, - ) - result.raise_for_status() - body = json.loads(result.text) - logger.info(f"response: {body}") - response = body.get("results", [{}])[0].get("generated_text", "") - return response - - -@Register(api_type="bam") -class BAMMetaData(LangchainMetaData[BAMConfiguration]): - - def __init__(self, config: BAMConfiguration): - super().__init__(config=config) - - -@Register(api_type="bam") -class BAMCompletionsPipeline(LangchainCompletionsPipeline[BAMConfiguration]): - - def __init__(self, config: BAMConfiguration): - super().__init__(config=config) - - def infer_from_parameters(self, api_key, model_id, context, prompt, suggestion_id=None): - raise NotImplementedError - - def get_chat_model(self, model_id): - return ChatBAM( - api_key=self.config.api_key, - model_id=model_id, - prediction_url=f"{self.config.inference_url}/v2/text/chat?version=2024-01-10", - timeout=self.timeout, - ) - - -@Register(api_type="bam") -class BAMPlaybookGenerationPipeline(LangchainPlaybookGenerationPipeline[BAMConfiguration]): - - def __init__(self, config: BAMConfiguration): - super().__init__(config=config) - - def get_chat_model(self, model_id): - return ChatBAM( - api_key=self.config.api_key, - model_id=model_id, - prediction_url=f"{self.config.inference_url}/v2/text/chat?version=2024-01-10", - timeout=self.timeout, - ) - - -@Register(api_type="bam") -class BAMPlaybookExplanationPipeline(LangchainPlaybookExplanationPipeline[BAMConfiguration]): - - def __init__(self, config: BAMConfiguration): - super().__init__(config=config) - - def get_chat_model(self, model_id): - return ChatBAM( - api_key=self.config.api_key, - model_id=model_id, - prediction_url=f"{self.config.inference_url}/v2/text/chat?version=2024-01-10", - timeout=self.timeout, - ) diff --git a/ansible_ai_connect/ai/api/model_pipelines/bam/tests/__init__.py b/ansible_ai_connect/ai/api/model_pipelines/bam/tests/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/ansible_ai_connect/ai/api/model_pipelines/bam/tests/test_factory.py b/ansible_ai_connect/ai/api/model_pipelines/bam/tests/test_factory.py deleted file mode 100644 index f1a8f73ab..000000000 --- a/ansible_ai_connect/ai/api/model_pipelines/bam/tests/test_factory.py +++ /dev/null @@ -1,64 +0,0 @@ -# Copyright Red Hat -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from django.test import override_settings - -from ansible_ai_connect.ai.api.model_pipelines.bam.pipelines import ( - BAMCompletionsPipeline, - BAMPlaybookExplanationPipeline, - BAMPlaybookGenerationPipeline, -) -from ansible_ai_connect.ai.api.model_pipelines.nop.pipelines import ( - NopChatBotPipeline, - NopContentMatchPipeline, - NopRoleGenerationPipeline, -) -from ansible_ai_connect.ai.api.model_pipelines.pipelines import ( - ModelPipelineChatBot, - ModelPipelineCompletions, - ModelPipelineContentMatch, - ModelPipelinePlaybookExplanation, - ModelPipelinePlaybookGeneration, - ModelPipelineRoleGeneration, -) -from ansible_ai_connect.ai.api.model_pipelines.tests import mock_config -from ansible_ai_connect.ai.api.model_pipelines.tests.test_factory import ( - TestModelPipelineFactoryImplementations, -) - - -@override_settings(ANSIBLE_AI_MODEL_MESH_CONFIG=mock_config("bam")) -class TestModelPipelineFactory(TestModelPipelineFactoryImplementations): - - def test_completions_pipeline(self): - self.assert_concrete_implementation(ModelPipelineCompletions, BAMCompletionsPipeline) - - def test_content_match_pipeline(self): - self.assert_default_implementation(ModelPipelineContentMatch, NopContentMatchPipeline) - - def test_playbook_generation_pipeline(self): - self.assert_concrete_implementation( - ModelPipelinePlaybookGeneration, BAMPlaybookGenerationPipeline - ) - - def test_role_generation_pipeline(self): - self.assert_default_implementation(ModelPipelineRoleGeneration, NopRoleGenerationPipeline) - - def test_playbook_explanation_pipeline(self): - self.assert_concrete_implementation( - ModelPipelinePlaybookExplanation, BAMPlaybookExplanationPipeline - ) - - def test_chatbot_pipeline(self): - self.assert_default_implementation(ModelPipelineChatBot, NopChatBotPipeline) diff --git a/ansible_ai_connect/ai/api/model_pipelines/bam/tests/test_pipeline.py b/ansible_ai_connect/ai/api/model_pipelines/bam/tests/test_pipeline.py deleted file mode 100644 index 635cd970a..000000000 --- a/ansible_ai_connect/ai/api/model_pipelines/bam/tests/test_pipeline.py +++ /dev/null @@ -1,98 +0,0 @@ -# Copyright Red Hat -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import json -from unittest.mock import Mock - -import responses -from django.test import TestCase, override_settings -from responses import matchers - -from ansible_ai_connect.ai.api.model_pipelines.bam.pipelines import ( - BAMCompletionsPipeline, -) -from ansible_ai_connect.ai.api.model_pipelines.pipelines import CompletionsParameters -from ansible_ai_connect.ai.api.model_pipelines.tests import mock_pipeline_config - - -class TestBam(TestCase): - def setUp(self): - super().setUp() - self.inference_url = "http://localhost" - self.prediction_url = f"{self.inference_url}/v2/text/chat?version=2024-01-10" - self.model_input = { - "instances": [ - { - "context": "", - "prompt": "- name: hey siri, return a task that installs ffmpeg", - } - ] - } - - self.expected_task_body = "ansible.builtin.debug:\n msg: something went wrong" - self.expected_response = { - "predictions": [self.expected_task_body], - "model_id": "test", - } - - @override_settings(ANSIBLE_AI_MODEL_MESH_API_KEY="my_key") - @override_settings(ANSIBLE_AI_MODEL_MESH_MODEL_ID="test") - @responses.activate - def test_infer(self): - model = "test" - model_client = BAMCompletionsPipeline(mock_pipeline_config("bam")) - responses.post( - self.prediction_url, - match=[ - matchers.header_matcher({"Content-Type": "application/json"}), - matchers.json_params_matcher( - { - "model_id": "test", - "messages": [ - { - "content": ( - "You are an Ansible expert. Return a single task that " - "best completes the partial playbook. Return only the " - "task as YAML. Do not return multiple tasks. Do not explain " - "your response. Do not include the prompt in your response." - ), - "role": "system", - }, - { - "content": "- name: hey siri, return a task that installs ffmpeg\n", - "role": "user", - }, - ], - "parameters": { - "temperature": 0.1, - "decoding_method": "greedy", - "repetition_penalty": 1.05, - "min_new_tokens": 1, - "max_new_tokens": 2048, - }, - }, - strict_match=False, - ), - ], - json={ - "results": [ - {"generated_text": "ansible.builtin.debug:\n msg: something went wrong"} - ], - }, - ) - - response = model_client.invoke( - CompletionsParameters.init(request=Mock(), model_input=self.model_input, model_id=model) - ) - self.assertEqual(json.dumps(self.expected_response), json.dumps(response)) diff --git a/ansible_ai_connect/ai/api/model_pipelines/tests/__init__.py b/ansible_ai_connect/ai/api/model_pipelines/tests/__init__.py index fd97486ca..f70aa9086 100644 --- a/ansible_ai_connect/ai/api/model_pipelines/tests/__init__.py +++ b/ansible_ai_connect/ai/api/model_pipelines/tests/__init__.py @@ -14,7 +14,6 @@ import json from typing import TypeVar -from ansible_ai_connect.ai.api.model_pipelines.bam.configuration import BAMConfiguration from ansible_ai_connect.ai.api.model_pipelines.dummy.configuration import ( DummyConfiguration, ) @@ -66,15 +65,6 @@ def extract(name: str, default: T, **kwargs) -> T: def mock_pipeline_config(pipeline_provider: t_model_mesh_api_type, **kwargs): match pipeline_provider: - case "bam": - return BAMConfiguration( - inference_url=extract("inference_url", "http://localhost", **kwargs), - api_key=extract("api_key", "an-api-key", **kwargs), - model_id=extract("model_id", "a-model-id", **kwargs), - timeout=extract("timeout", 1000, **kwargs), - enable_health_check=extract("enable_health_check", False, **kwargs), - verify_ssl=extract("verify_ssl", False, **kwargs), - ) case "dummy": inference_url = extract("inference_url", "http://localhost", **kwargs) c = DummyConfiguration( diff --git a/ansible_ai_connect/main/settings/types.py b/ansible_ai_connect/main/settings/types.py index 0645613ab..a8f876900 100644 --- a/ansible_ai_connect/main/settings/types.py +++ b/ansible_ai_connect/main/settings/types.py @@ -28,7 +28,6 @@ "wca-dummy", "ollama", "llamacpp", - "bam", "nop", ] diff --git a/pyproject.toml b/pyproject.toml index d9b4b7c60..f04f86c98 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -77,7 +77,6 @@ profile = "black" [tool.pyright] include = [ "ansible_ai_connect/ai/api/aws/wca_secret_manager.py", - "ansible_ai_connect/ai/api/model_pipelines/bam/pipelines.py", "ansible_ai_connect/ai/api/model_pipelines/dummy/pipelines.py", "ansible_ai_connect/ai/api/model_pipelines/http/pipelines.py", "ansible_ai_connect/ai/api/model_pipelines/langchain/pipelines.py",