diff --git a/.github/workflows/pre-commit.yml b/.github/workflows/pre-commit.yml index aadb22d54..8fd2a1fb8 100644 --- a/.github/workflows/pre-commit.yml +++ b/.github/workflows/pre-commit.yml @@ -30,7 +30,7 @@ jobs: - name: Set up python3 uses: actions/setup-python@v4 with: - python-version: '3.10' + python-version: '3.11' - name: Install dependencies run: | diff --git a/README-ANSIBLE_AI_MODEL_MESH_CONFIG.md b/README-ANSIBLE_AI_MODEL_MESH_CONFIG.md new file mode 100644 index 000000000..0444c0d9b --- /dev/null +++ b/README-ANSIBLE_AI_MODEL_MESH_CONFIG.md @@ -0,0 +1,40 @@ +# Example `ANSIBLE_AI_MODEL_MESH_CONFIG` configuration + +Pay close attention to the formatting of the blocks. + +Each ends with `}},` otherwise conversion of the multi-line setting to a `str` can fail. + +```text +ANSIBLE_AI_MODEL_MESH_CONFIG="{ + "ModelPipelineCompletions": { + "provider": "ollama", + "config": { + "inference_url": "http://host.containers.internal:11434", + "model_id": "mistral:instruct"}}, + "ModelPipelineContentMatch": { + "provider": "ollama", + "config": { + "inference_url": "http://host.containers.internal:11434", + "model_id": "mistral:instruct"}}, + "ModelPipelinePlaybookGeneration": { + "provider": "ollama", + "config": { + "inference_url": "http://host.containers.internal:11434", + "model_id": "mistral:instruct"}}, + "ModelPipelineRoleGeneration": { + "provider": "ollama", + "config": { + "inference_url": "http://host.containers.internal:11434", + "model_id": "mistral:instruct"}}, + "ModelPipelinePlaybookExplanation": { + "provider": "ollama", + "config": { + "inference_url": "http://host.containers.internal:11434", + "model_id": "mistral:instruct"}}, + "ModelPipelineChatBot": { + "provider": "http", + "config": { + "inference_url": "http://localhost:8000", + "model_id": "granite3-8b"}} +}" +``` diff --git a/README.md b/README.md index 63dd98e21..c870bbb3a 100644 --- a/README.md +++ b/README.md @@ -26,10 +26,9 @@ SECRET_KEY="somesecretvalue" 
ENABLE_ARI_POSTPROCESS="False" WCA_SECRET_BACKEND_TYPE="dummy" # configure model server -ANSIBLE_AI_MODEL_MESH_API_URL="http://host.containers.internal:11434" -ANSIBLE_AI_MODEL_MESH_API_TYPE="ollama" -ANSIBLE_AI_MODEL_MESH_MODEL_ID="mistral:instruct" +ANSIBLE_AI_MODEL_MESH_CONFIG="..." ``` +See the example [ANSIBLE_AI_MODEL_MESH_CONFIG](./README-ANSIBLE_AI_MODEL_MESH_CONFIG.md). ### Start service and dependencies @@ -108,9 +107,9 @@ command line the variable `DEBUG=True`. The Django service listens on . -Note that there is no pytorch service defined in the docker-compose -file. You should adjust the `ANSIBLE_AI_MODEL_MESH_API_URL` -configuration key to point on an existing service. +Note that there is no pytorch service defined in the `docker-compose` +file. You should adjust the `ANSIBLE_AI_MODEL_MESH_CONFIG` +configuration to point to an existing service. ## Use the WCA API Keys Manager @@ -460,11 +459,10 @@ To connect to the Mistal 7b Instruct model running on locally on [llama.cpp](htt ``` 1. Set the appropriate environment variables ```bash - ANSIBLE_AI_MODEL_MESH_API_URL=http://$YOUR_REAL_IP:8080 - ANSIBLE_AI_MODEL_MESH_API_TYPE=llamacpp - ANSIBLE_AI_MODEL_MESH_MODEL_ID=mistral-7b-instruct-v0.2.Q5_K_M.gguf + ANSIBLE_AI_MODEL_MESH_CONFIG="..." ENABLE_ARI_POSTPROCESS=False ``` +See the example [ANSIBLE_AI_MODEL_MESH_CONFIG](./README-ANSIBLE_AI_MODEL_MESH_CONFIG.md). # Testing diff --git a/ansible_ai_connect/ai/api/model_pipelines/config_loader.py b/ansible_ai_connect/ai/api/model_pipelines/config_loader.py index 302a9f388..b204e0efe 100644 --- a/ansible_ai_connect/ai/api/model_pipelines/config_loader.py +++ b/ansible_ai_connect/ai/api/model_pipelines/config_loader.py @@ -12,18 +12,57 @@ # See the License for the specific language governing permissions and # limitations under the License. 
import json +import logging +from json import JSONDecodeError +import yaml from django.conf import settings +from yaml import YAMLError from ansible_ai_connect.ai.api.model_pipelines.config_providers import Configuration from ansible_ai_connect.ai.api.model_pipelines.config_serializers import ( ConfigurationSerializer, ) +logger = logging.getLogger(__name__) + def load_config() -> Configuration: - source = json.loads(settings.ANSIBLE_AI_MODEL_MESH_CONFIG) - serializer = ConfigurationSerializer(data=source) + # yaml.safe_load(..) seems to also support loading JSON. Nice. + # However, try to load JSON with the correct _loader_ first in case of corner cases + errors: list[Exception] = [] + result = load_json() + if isinstance(result, Exception): + errors.append(result) + result = load_yaml() + if isinstance(result, Exception): + errors.append(result) + else: + errors = [] + + if len(errors) > 0: + raise ExceptionGroup("Unable to parse ANSIBLE_AI_MODEL_MESH_CONFIG", errors) + + serializer = ConfigurationSerializer(data=result) serializer.is_valid(raise_exception=True) serializer.save() return serializer.instance + + +def load_json() -> dict | Exception: + try: + logger.info("Attempting to parse ANSIBLE_AI_MODEL_MESH_CONFIG as JSON...") + return json.loads(settings.ANSIBLE_AI_MODEL_MESH_CONFIG) + except JSONDecodeError as e: + logger.exception(f"An error occurred parsing ANSIBLE_AI_MODEL_MESH_CONFIG as JSON:\n{e}") + return e + + +def load_yaml() -> dict | Exception: + try: + logger.info("Attempting to parse ANSIBLE_AI_MODEL_MESH_CONFIG as YAML...") + y = yaml.safe_load(settings.ANSIBLE_AI_MODEL_MESH_CONFIG) + return y + except YAMLError as e: + logger.exception(f"An error occurred parsing ANSIBLE_AI_MODEL_MESH_CONFIG as YAML:\n{e}") + return e diff --git a/ansible_ai_connect/ai/api/model_pipelines/config_serializers.py b/ansible_ai_connect/ai/api/model_pipelines/config_serializers.py index e0c3e4072..e3448cff0 100644 --- 
a/ansible_ai_connect/ai/api/model_pipelines/config_serializers.py +++ b/ansible_ai_connect/ai/api/model_pipelines/config_serializers.py @@ -35,7 +35,8 @@ class PipelineConfigurationSerializer(serializers.Serializer): def to_internal_value(self, data): provider_part = super().to_internal_value(data) - serializer = REGISTRY[provider_part["provider"]][Serializer](data=data["config"]) + config_part = data["config"] if "config" in data else {} + serializer = REGISTRY[provider_part["provider"]][Serializer](data=config_part) serializer.is_valid(raise_exception=True) return {**provider_part, "config": serializer.validated_data} diff --git a/ansible_ai_connect/ai/api/model_pipelines/dummy/pipelines.py b/ansible_ai_connect/ai/api/model_pipelines/dummy/pipelines.py index 3c750b95f..9b99127df 100644 --- a/ansible_ai_connect/ai/api/model_pipelines/dummy/pipelines.py +++ b/ansible_ai_connect/ai/api/model_pipelines/dummy/pipelines.py @@ -134,7 +134,7 @@ def __init__(self, config: DummyConfiguration): super().__init__(config=config) def invoke(self, params: CompletionsParameters) -> CompletionsResponse: - logger.debug("!!!! settings.ANSIBLE_AI_MODEL_MESH_API_TYPE == 'dummy' !!!!") + logger.debug("!!!! ModelPipelineCompletions.provider == 'dummy' !!!!") logger.debug("!!!! 
Mocking Model response !!!!") if self.config.latency_use_jitter: jitter: float = secrets.randbelow(1000) * 0.001 diff --git a/ansible_ai_connect/ai/api/model_pipelines/http/configuration.py b/ansible_ai_connect/ai/api/model_pipelines/http/configuration.py index d23f75275..c9a361ed7 100644 --- a/ansible_ai_connect/ai/api/model_pipelines/http/configuration.py +++ b/ansible_ai_connect/ai/api/model_pipelines/http/configuration.py @@ -66,4 +66,4 @@ def __init__(self, **kwargs): @Register(api_type="http") class HttpConfigurationSerializer(BaseConfigSerializer): - verify_ssl = serializers.BooleanField(required=False, default=False) + verify_ssl = serializers.BooleanField(required=False, default=True) diff --git a/ansible_ai_connect/ai/api/model_pipelines/llamacpp/configuration.py b/ansible_ai_connect/ai/api/model_pipelines/llamacpp/configuration.py index a1a3d60e8..a96a141f5 100644 --- a/ansible_ai_connect/ai/api/model_pipelines/llamacpp/configuration.py +++ b/ansible_ai_connect/ai/api/model_pipelines/llamacpp/configuration.py @@ -66,4 +66,4 @@ def __init__(self, **kwargs): @Register(api_type="llamacpp") class LlamaCppConfigurationSerializer(BaseConfigSerializer): - verify_ssl = serializers.BooleanField(required=False, default=False) + verify_ssl = serializers.BooleanField(required=False, default=True) diff --git a/ansible_ai_connect/ai/api/model_pipelines/tests/test_config_loader.py b/ansible_ai_connect/ai/api/model_pipelines/tests/test_config_loader.py new file mode 100644 index 000000000..d5740a4a4 --- /dev/null +++ b/ansible_ai_connect/ai/api/model_pipelines/tests/test_config_loader.py @@ -0,0 +1,85 @@ +# Copyright Red Hat +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import json +from json import JSONDecodeError + +import yaml +from django.test import override_settings +from rest_framework.exceptions import ValidationError +from yaml import YAMLError + +from ansible_ai_connect.ai.api.model_pipelines.config_loader import load_config +from ansible_ai_connect.ai.api.model_pipelines.config_providers import Configuration +from ansible_ai_connect.ai.api.model_pipelines.pipelines import MetaData +from ansible_ai_connect.ai.api.model_pipelines.registry import REGISTRY_ENTRY +from ansible_ai_connect.ai.api.model_pipelines.tests import mock_config +from ansible_ai_connect.test_utils import WisdomTestCase + +EMPTY = { + "MetaData": { + "provider": "dummy", + }, +} + + +def _convert_json_to_yaml(json_config: str): + yaml_config = yaml.safe_load(json_config) + return yaml.safe_dump(yaml_config) + + +class TestConfigLoader(WisdomTestCase): + + def assert_config(self): + config: Configuration = load_config() + pipelines = [i for i in REGISTRY_ENTRY.keys() if issubclass(i, MetaData)] + for k in pipelines: + self.assertTrue(k.__name__ in config) + + def assert_invalid_config(self): + with self.assertRaises(ExceptionGroup) as e: + load_config() + exceptions = e.exception.exceptions + self.assertEqual(len(exceptions), 2) + self.assertIsInstance(exceptions[0], JSONDecodeError) + self.assertIsInstance(exceptions[1], YAMLError) + + @override_settings(ANSIBLE_AI_MODEL_MESH_CONFIG=None) + def test_config_undefined(self): + with self.assertRaises(TypeError): + load_config() + + @override_settings(ANSIBLE_AI_MODEL_MESH_CONFIG=json.dumps(EMPTY)) 
+ def test_config_empty(self): + self.assert_config() + + @override_settings(ANSIBLE_AI_MODEL_MESH_CONFIG="") + def test_config_empty_string(self): + with self.assertRaises(ValidationError): + self.assert_config() + + @override_settings(ANSIBLE_AI_MODEL_MESH_CONFIG='{"MetaData" : {') + def test_config_invalid_json(self): + self.assert_invalid_config() + + @override_settings(ANSIBLE_AI_MODEL_MESH_CONFIG="MetaData:\nbanana") + def test_config_invalid_yaml(self): + self.assert_invalid_config() + + @override_settings(ANSIBLE_AI_MODEL_MESH_CONFIG=mock_config("ollama")) + def test_config_json(self): + self.assert_config() + + @override_settings(ANSIBLE_AI_MODEL_MESH_CONFIG=_convert_json_to_yaml(mock_config("ollama"))) + def test_config_yaml(self): + self.assert_config() diff --git a/ansible_ai_connect/ai/api/model_pipelines/tests/test_wca_client.py b/ansible_ai_connect/ai/api/model_pipelines/tests/test_wca_client.py index 839ab111f..b94247157 100644 --- a/ansible_ai_connect/ai/api/model_pipelines/tests/test_wca_client.py +++ b/ansible_ai_connect/ai/api/model_pipelines/tests/test_wca_client.py @@ -1325,7 +1325,6 @@ def test_codematch_empty_response(self): self.assertEqual(e.exception.model_id, model_id) -@override_settings(ANSIBLE_AI_MODEL_MESH_MODEL_ID=None) class TestDummySecretManager(TestCase): def setUp(self): super().setUp() diff --git a/ansible_ai_connect/ai/api/model_pipelines/wca/configuration_base.py b/ansible_ai_connect/ai/api/model_pipelines/wca/configuration_base.py index b67603d1f..fdbcaf215 100644 --- a/ansible_ai_connect/ai/api/model_pipelines/wca/configuration_base.py +++ b/ansible_ai_connect/ai/api/model_pipelines/wca/configuration_base.py @@ -77,7 +77,7 @@ def __init__(self, provider: t_model_mesh_api_type, config: WCABaseConfiguration class WCABaseConfigurationSerializer(BaseConfigSerializer): api_key = serializers.CharField(required=False, allow_null=True, allow_blank=True) - verify_ssl = serializers.BooleanField(required=False, default=False) + 
verify_ssl = serializers.BooleanField(required=False, default=True) retry_count = serializers.IntegerField(required=False, default=4) enable_ari_postprocessing = serializers.BooleanField(required=False, default=False) health_check_api_key = serializers.CharField(required=True) diff --git a/ansible_ai_connect/ai/api/model_pipelines/wca/pipelines_onprem.py b/ansible_ai_connect/ai/api/model_pipelines/wca/pipelines_onprem.py index 060df3314..3297f5dc8 100644 --- a/ansible_ai_connect/ai/api/model_pipelines/wca/pipelines_onprem.py +++ b/ansible_ai_connect/ai/api/model_pipelines/wca/pipelines_onprem.py @@ -108,8 +108,8 @@ def __init__(self, config: WCAOnPremConfiguration): raise WcaUsernameNotFound if not self.config.api_key: raise WcaKeyNotFound - # ANSIBLE_AI_MODEL_MESH_MODEL_ID cannot be validated until runtime. The - # User may provide an override value if the Environment Variable is not set. + # WCAOnPremConfiguration.model_id cannot be validated until runtime. The + # User may provide an override value if the setting is not defined. 
def get_request_headers( self, api_key: str, identifier: Optional[str] diff --git a/ansible_ai_connect/ai/api/model_pipelines/wca/pipelines_saas.py b/ansible_ai_connect/ai/api/model_pipelines/wca/pipelines_saas.py index 9a8b60bf4..7e2180236 100644 --- a/ansible_ai_connect/ai/api/model_pipelines/wca/pipelines_saas.py +++ b/ansible_ai_connect/ai/api/model_pipelines/wca/pipelines_saas.py @@ -144,7 +144,7 @@ def get_api_key(self, user, organization_id: Optional[int]) -> str: if organization_id is None: logger.error( - "User does not have an organization and no ANSIBLE_AI_MODEL_MESH_API_KEY is set" + "User does not have an organization and WCASaaSConfiguration.api_key is not set" ) raise WcaKeyNotFound diff --git a/ansible_ai_connect/ai/api/permissions.py b/ansible_ai_connect/ai/api/permissions.py index a0cfcd80e..62d4cc0cb 100644 --- a/ansible_ai_connect/ai/api/permissions.py +++ b/ansible_ai_connect/ai/api/permissions.py @@ -164,4 +164,4 @@ class IsWCASaaSModelPipeline(permissions.BasePermission): message = "User doesn't have access to the IBM watsonx Code Assistant." 
def has_permission(self, request, view): - return CONTINUE if settings.ANSIBLE_AI_MODEL_MESH_API_TYPE == "wca" else BLOCK + return CONTINUE if settings.DEPLOYMENT_MODE == "saas" else BLOCK diff --git a/ansible_ai_connect/ai/api/pipelines/completion_stages/inference.py b/ansible_ai_connect/ai/api/pipelines/completion_stages/inference.py index 97dd4de6e..193a8b191 100644 --- a/ansible_ai_connect/ai/api/pipelines/completion_stages/inference.py +++ b/ansible_ai_connect/ai/api/pipelines/completion_stages/inference.py @@ -19,7 +19,6 @@ from ansible_anonymizer import anonymizer from django.apps import apps -from django.conf import settings from django_prometheus.conf import NAMESPACE from prometheus_client import Histogram @@ -124,7 +123,7 @@ def process(self, context: CompletionContext) -> None: except ModelTimeoutError as e: exception = e logger.warning( - f"model timed out after {settings.ANSIBLE_AI_MODEL_MESH_API_TIMEOUT} " + f"model timed out after {model_mesh_client.config.timeout} " f"seconds (per task) for suggestion {suggestion_id}" ) raise ModelTimeoutException(cause=e) diff --git a/ansible_ai_connect/ai/api/tests/test_permissions.py b/ansible_ai_connect/ai/api/tests/test_permissions.py index a3d3f3ee6..b3328e713 100644 --- a/ansible_ai_connect/ai/api/tests/test_permissions.py +++ b/ansible_ai_connect/ai/api/tests/test_permissions.py @@ -209,10 +209,10 @@ def test_ensure_trial_user_can_pass_through_despite_trial_disabled(self): class TestBlockUserWithoutWCASaaSConfiguration(WisdomAppsBackendMocking): - @override_settings(ANSIBLE_AI_MODEL_MESH_API_TYPE="wca") + @override_settings(DEPLOYMENT_MODE="saas") def test_wca_saas_enabled(self): self.assertEqual(IsWCASaaSModelPipeline().has_permission(Mock(), None), CONTINUE) - @override_settings(ANSIBLE_AI_MODEL_MESH_API_TYPE="wca-onprem") + @override_settings(DEPLOYMENT_MODE="onprem") def test_wca_saas_not_enabled(self): self.assertEqual(IsWCASaaSModelPipeline().has_permission(Mock(), None), BLOCK) diff --git 
a/ansible_ai_connect/ai/api/tests/test_views.py b/ansible_ai_connect/ai/api/tests/test_views.py index 5cc1218b4..56a02c5dc 100644 --- a/ansible_ai_connect/ai/api/tests/test_views.py +++ b/ansible_ai_connect/ai/api/tests/test_views.py @@ -27,7 +27,6 @@ import requests from django.apps import apps -from django.conf import settings from django.contrib.auth import get_user_model from django.test import modify_settings, override_settings from django.urls import reverse @@ -1017,7 +1016,7 @@ def test_full_payload(self): "suggestionId": str(uuid.uuid4()), } response_data = { - "model_id": settings.ANSIBLE_AI_MODEL_MESH_MODEL_ID, + "model_id": "a-model-id", "predictions": [" ansible.builtin.apt:\n name: apache2"], } self.client.force_authenticate(user=self.user) @@ -1039,7 +1038,7 @@ def test_multi_task_prompt_commercial(self): "suggestionId": str(uuid.uuid4()), } response_data = { - "model_id": settings.ANSIBLE_AI_MODEL_MESH_MODEL_ID, + "model_id": "a-model-id", "predictions": [ "- name: Install Apache\n ansible.builtin.apt:\n name: apache2\n state: latest\n- name: start Apache\n ansible.builtin.service:\n name: apache2\n state: started\n enabled: yes\n" # noqa: E501 ], @@ -1087,7 +1086,7 @@ def test_multi_task_prompt_commercial_with_pii(self): "suggestionId": str(uuid.uuid4()), } response_data = { - "model_id": settings.ANSIBLE_AI_MODEL_MESH_MODEL_ID, + "model_id": "a-model-id", "predictions": [ " - name: Install Apache\n ansible.builtin.apt:\n name: apache2\n state: latest\n - name: say hello test@example.com\n ansible.builtin.debug:\n msg: Hello there olivia1@example.com\n" # noqa: E501 ], @@ -1126,7 +1125,7 @@ def test_rate_limit(self): "suggestionId": str(uuid.uuid4()), } response_data = { - "model_id": settings.ANSIBLE_AI_MODEL_MESH_MODEL_ID, + "model_id": "a-model-id", "predictions": [" ansible.builtin.apt:\n name: apache2"], } self.client.force_authenticate(user=self.user) @@ -1150,7 +1149,7 @@ def test_missing_prompt(self): "suggestionId": str(uuid.uuid4()), } 
response_data = { - "model_id": settings.ANSIBLE_AI_MODEL_MESH_MODEL_ID, + "model_id": "a-model-id", "predictions": [" ansible.builtin.apt:\n name: apache2"], } self.client.force_authenticate(user=self.user) @@ -1171,7 +1170,7 @@ def test_authentication_error(self): "suggestionId": str(uuid.uuid4()), } response_data = { - "model_id": settings.ANSIBLE_AI_MODEL_MESH_MODEL_ID, + "model_id": "a-model-id", "predictions": [" ansible.builtin.apt:\n name: apache2"], } # self.client.force_authenticate(user=self.user) @@ -1200,7 +1199,7 @@ def test_completions_preprocessing_error(self): "suggestionId": str(uuid.uuid4()), } response_data = { - "model_id": settings.ANSIBLE_AI_MODEL_MESH_MODEL_ID, + "model_id": "a-model-id", "predictions": [" ansible.builtin.apt:\n name: apache2"], } self.client.force_authenticate(user=self.user) @@ -1226,7 +1225,7 @@ def test_completions_preprocessing_error_without_name_prompt(self): "suggestionId": str(uuid.uuid4()), } response_data = { - "model_id": settings.ANSIBLE_AI_MODEL_MESH_MODEL_ID, + "model_id": "a-model-id", "predictions": [" ansible.builtin.apt:\n name: apache2"], } self.client.force_authenticate(user=self.user) @@ -1250,7 +1249,7 @@ def test_full_payload_without_ARI(self): "suggestionId": str(uuid.uuid4()), } response_data = { - "model_id": settings.ANSIBLE_AI_MODEL_MESH_MODEL_ID, + "model_id": "a-model-id", "predictions": [" ansible.builtin.apt:\n name: apache2"], } self.client.force_authenticate(user=self.user) @@ -1275,7 +1274,7 @@ def test_full_payload_with_recommendation_with_broken_last_line(self): } # quotation in the last line is not closed, but the truncate function can handle this. 
response_data = { - "model_id": settings.ANSIBLE_AI_MODEL_MESH_MODEL_ID, + "model_id": "a-model-id", "predictions": [ ' ansible.builtin.apt:\n name: apache2\n register: "test' ], @@ -1302,7 +1301,7 @@ def test_completions_postprocessing_error_for_invalid_yaml(self): } # this prediction has indentation problem with the prompt above response_data = { - "model_id": settings.ANSIBLE_AI_MODEL_MESH_MODEL_ID, + "model_id": "a-model-id", "predictions": [" ansible.builtin.apt:\n garbage name: apache2"], } self.client.force_authenticate(user=self.user) @@ -1374,7 +1373,7 @@ def test_full_payload_without_ansible_lint_with_commercial_user(self): "suggestionId": str(uuid.uuid4()), } response_data = { - "model_id": settings.ANSIBLE_AI_MODEL_MESH_MODEL_ID, + "model_id": "a-model-id", "predictions": [" ansible.builtin.apt:\n name: apache2"], } self.client.force_authenticate(user=self.user) @@ -1557,7 +1556,7 @@ def test_completions_pii_clean_up(self): "suggestionId": str(uuid.uuid4()), } response_data = { - "model_id": settings.ANSIBLE_AI_MODEL_MESH_MODEL_ID, + "model_id": "a-model-id", "predictions": [""], } self.client.force_authenticate(user=self.user) @@ -4049,9 +4048,7 @@ def json(self): json_response["response"] = input return MockResponse(json_response, status_code) - @override_settings(CHATBOT_URL="http://localhost:8080") @override_settings(CHATBOT_DEFAULT_PROVIDER="wisdom") - @override_settings(CHATBOT_DEFAULT_MODEL="granite-8b") @mock.patch( "requests.post", side_effect=mocked_requests_post, @@ -4063,7 +4060,6 @@ def query_with_no_error(self, payload, mock_post): "requests.post", side_effect=mocked_requests_post, ) - @override_settings(CHATBOT_URL="") def query_without_chat_config(self, payload, mock_post): return self.client.post(reverse("chat"), payload, format="json") @@ -4175,7 +4171,11 @@ def test_operational_telemetry(self): patch.object( apps.get_app_config("ai"), "get_model_pipeline", - Mock(return_value=HttpChatBotPipeline(mock_pipeline_config("http"))), + Mock( 
+ return_value=HttpChatBotPipeline( + mock_pipeline_config("http", model_id="granite-8b") + ) + ), ), self.assertLogs(logger="root", level="DEBUG") as log, ): @@ -4282,7 +4282,11 @@ def test_operational_telemetry_with_system_prompt_override(self): patch.object( apps.get_app_config("ai"), "get_model_pipeline", - Mock(return_value=HttpChatBotPipeline(mock_pipeline_config("http"))), + Mock( + return_value=HttpChatBotPipeline( + mock_pipeline_config("http", model_id="granite-8b") + ) + ), ), self.assertLogs(logger="root", level="DEBUG") as log, ): diff --git a/ansible_ai_connect/ai/api/views.py b/ansible_ai_connect/ai/api/views.py index a5f66d95e..224298aca 100644 --- a/ansible_ai_connect/ai/api/views.py +++ b/ansible_ai_connect/ai/api/views.py @@ -611,7 +611,7 @@ def perform_content_matching( except ModelTimeoutError as e: exception = e logger.warn( - f"model timed out after {settings.ANSIBLE_AI_MODEL_MESH_API_TIMEOUT} seconds" + f"model timed out after {model_mesh_client.config.timeout} seconds" f" for suggestion {suggestion_id}" ) raise ModelTimeoutException(cause=e) @@ -957,11 +957,15 @@ class ChatEndpointThrottle(EndpointRateThrottle): request_serializer_class = ChatRequestSerializer throttle_classes = [ChatEndpointThrottle] + llm: ModelPipelineChatBot + def __init__(self): super().__init__() + self.llm = apps.get_app_config("ai").get_model_pipeline(ModelPipelineChatBot) + self.chatbot_enabled = ( - settings.CHATBOT_URL - and settings.CHATBOT_DEFAULT_MODEL + self.llm.config.inference_url + and self.llm.config.model_id and settings.CHATBOT_DEFAULT_PROVIDER ) if self.chatbot_enabled: @@ -998,18 +1002,14 @@ def post(self, request) -> Response: self.event.chat_system_prompt = req_system_prompt self.event.provider_id = req_provider self.event.conversation_id = conversation_id - self.event.modelName = self.req_model_id or settings.CHATBOT_DEFAULT_MODEL - - llm: ModelPipelineChatBot = apps.get_app_config("ai").get_model_pipeline( - ModelPipelineChatBot - ) + 
self.event.modelName = self.req_model_id or self.llm.config.model_id - data = llm.invoke( + data = self.llm.invoke( ChatBotParameters.init( request=request, query=req_query, system_prompt=req_system_prompt, - model_id=self.req_model_id or settings.CHATBOT_DEFAULT_MODEL, + model_id=self.req_model_id or self.llm.config.model_id, provider=req_provider, conversation_id=conversation_id, ) diff --git a/ansible_ai_connect/ai/api/wca/tests/test_api_key_views.py b/ansible_ai_connect/ai/api/wca/tests/test_api_key_views.py index f6def4626..008fbbbc8 100644 --- a/ansible_ai_connect/ai/api/wca/tests/test_api_key_views.py +++ b/ansible_ai_connect/ai/api/wca/tests/test_api_key_views.py @@ -43,7 +43,7 @@ from ansible_ai_connect.test_utils import WisdomAppsBackendMocking -@override_settings(ANSIBLE_AI_MODEL_MESH_API_TYPE="wca") +@override_settings(DEPLOYMENT_MODE="saas") @override_settings(WCA_SECRET_BACKEND_TYPE="aws_sm") @patch.object(IsOrganisationAdministrator, "has_permission", return_value=True) @patch.object(IsOrganisationLightspeedSubscriber, "has_permission", return_value=True) @@ -547,7 +547,7 @@ def test_get_api_key_as_non_subscriber(self, *args): self.assertEqual(r.status_code, HTTPStatus.FORBIDDEN) -@override_settings(ANSIBLE_AI_MODEL_MESH_API_TYPE="wca") +@override_settings(DEPLOYMENT_MODE="saas") @override_settings(WCA_SECRET_BACKEND_TYPE="aws_sm") @patch.object(IsOrganisationAdministrator, "has_permission", return_value=True) @patch.object(IsOrganisationLightspeedSubscriber, "has_permission", return_value=True) diff --git a/ansible_ai_connect/ai/api/wca/tests/test_model_id_views.py b/ansible_ai_connect/ai/api/wca/tests/test_model_id_views.py index d3ea1ed80..495043ee4 100644 --- a/ansible_ai_connect/ai/api/wca/tests/test_model_id_views.py +++ b/ansible_ai_connect/ai/api/wca/tests/test_model_id_views.py @@ -44,7 +44,7 @@ from ansible_ai_connect.test_utils import WisdomAppsBackendMocking, WisdomLogAwareMixin -@override_settings(ANSIBLE_AI_MODEL_MESH_API_TYPE="wca") 
+@override_settings(DEPLOYMENT_MODE="saas") @override_settings(WCA_SECRET_BACKEND_TYPE="aws_sm") @patch.object(IsOrganisationAdministrator, "has_permission", return_value=True) @patch.object(IsOrganisationLightspeedSubscriber, "has_permission", return_value=True) @@ -319,7 +319,7 @@ def test_get_model_id_as_non_subscriber(self, *args): self.assertEqual(r.status_code, HTTPStatus.FORBIDDEN) -@override_settings(ANSIBLE_AI_MODEL_MESH_API_TYPE="wca") +@override_settings(DEPLOYMENT_MODE="saas") @override_settings(WCA_SECRET_BACKEND_TYPE="aws_sm") @patch.object(IsOrganisationAdministrator, "has_permission", return_value=True) @patch.object(IsOrganisationLightspeedSubscriber, "has_permission", return_value=True) diff --git a/ansible_ai_connect/ai/tests/test_apps.py b/ansible_ai_connect/ai/tests/test_apps.py index 06bf8d9b2..08ec8ccdb 100644 --- a/ansible_ai_connect/ai/tests/test_apps.py +++ b/ansible_ai_connect/ai/tests/test_apps.py @@ -119,7 +119,6 @@ def test_enable_ari_wca_cloud(self): self.assertIsNotNone(app_config.get_ari_caller()) @override_settings(ENABLE_ARI_POSTPROCESS=True) - @override_settings(WCA_ENABLE_ARI_POSTPROCESS=False) @override_settings(ANSIBLE_AI_MODEL_MESH_CONFIG=mock_config("wca")) def test_enable_ari_wca_cloud_disable_wca(self): app_config = AppConfig.create("ansible_ai_connect.ai") @@ -127,7 +126,6 @@ def test_enable_ari_wca_cloud_disable_wca(self): self.assertIsNone(app_config.get_ari_caller()) @override_settings(ENABLE_ARI_POSTPROCESS=False) - @override_settings(WCA_ENABLE_ARI_POSTPROCESS=True) @override_settings(ANSIBLE_AI_MODEL_MESH_CONFIG=mock_config("wca")) def test_disable_ari_wca_cloud_enable_wca(self): app_config = AppConfig.create("ansible_ai_connect.ai") @@ -135,7 +133,6 @@ def test_disable_ari_wca_cloud_enable_wca(self): self.assertIsNone(app_config.get_ari_caller()) @override_settings(ENABLE_ARI_POSTPROCESS=False) - @override_settings(WCA_ENABLE_ARI_POSTPROCESS=False) @override_settings(ANSIBLE_AI_MODEL_MESH_CONFIG=mock_config("wca")) 
def test_disable_ari_wca_cloud_disable_wca(self): app_config = AppConfig.create("ansible_ai_connect.ai") diff --git a/ansible_ai_connect/healthcheck/tests/test_healthcheck.py b/ansible_ai_connect/healthcheck/tests/test_healthcheck.py index d3294d043..ab7d9b913 100644 --- a/ansible_ai_connect/healthcheck/tests/test_healthcheck.py +++ b/ansible_ai_connect/healthcheck/tests/test_healthcheck.py @@ -81,9 +81,7 @@ @override_settings(LAUNCHDARKLY_SDK_KEY=None) @override_settings(AUTHZ_BACKEND_TYPE="dummy") @override_settings(WCA_SECRET_BACKEND_TYPE="dummy") -@override_settings(CHATBOT_URL="dummy") @override_settings(CHATBOT_DEFAULT_PROVIDER="wisdom") -@override_settings(CHATBOT_DEFAULT_MODEL="granite-8b") class BaseTestHealthCheck(WisdomAppsBackendMocking, APITestCase, WisdomServiceLogAwareTestCase): def setUp(self): super().setUp() @@ -390,7 +388,6 @@ def test_health_check_authorization_disabled(self): else: self.assertTrue(self.is_status_ok(dependency["status"], "dummy")) - @override_settings(CHATBOT_URL="http://localhost:8080") @mock.patch( "requests.get", side_effect=lambda *args, **kwargs: BaseTestHealthCheck.mocked_requests_succeed( @@ -414,7 +411,6 @@ def test_health_check_chatbot_service(self, mock_get): for dependency in dependencies: self.assertTrue(self.is_status_ok(dependency["status"], "http")) - @override_settings(CHATBOT_URL="http://localhost:8080") @mock.patch( "requests.get", side_effect=lambda *args, **kwargs: BaseTestHealthCheck.mocked_requests_succeed( @@ -449,7 +445,6 @@ def test_health_check_chatbot_service_index_not_ready(self, mock_get): {"provider": "http", "models": "unavailable: index is not ready"}, ) - @override_settings(CHATBOT_URL="http://localhost:8080") @mock.patch( "requests.get", side_effect=lambda *args, **kwargs: BaseTestHealthCheck.mocked_requests_succeed( @@ -484,7 +479,6 @@ def test_health_check_chatbot_service_llm_not_ready(self, mock_get): {"provider": "http", "models": "unavailable: llm is not ready"}, ) - 
@override_settings(CHATBOT_URL="http://localhost:8080") @mock.patch( "requests.get", side_effect=HTTPError, @@ -517,7 +511,6 @@ def test_health_check_chatbot_service_error(self, mock_get): {"provider": "http", "models": "unavailable: An error occurred"}, ) - @override_settings(CHATBOT_URL="http://localhost:8080") @mock.patch( "requests.get", side_effect=BaseTestHealthCheck.mocked_requests_failed, @@ -550,8 +543,6 @@ def test_health_check_chatbot_service_non_200_response(self, mock_get): {"provider": "http", "models": "unavailable: An error occurred"}, ) - @override_settings(ENABLE_HEALTHCHECK_CHATBOT_SERVICE=False) - @override_settings(CHATBOT_URL="http://localhost:8080") @mock.patch( "requests.get", side_effect=BaseTestHealthCheck.mocked_requests_succeed, diff --git a/ansible_ai_connect/main/settings/base.py b/ansible_ai_connect/main/settings/base.py index 8d186634b..e477895f7 100644 --- a/ansible_ai_connect/main/settings/base.py +++ b/ansible_ai_connect/main/settings/base.py @@ -28,14 +28,13 @@ import logging import os import sys -from copy import deepcopy from importlib.resources import files from pathlib import Path from typing import cast +from ansible_ai_connect.main.settings.legacy import load_from_env_vars from ansible_ai_connect.main.settings.types import ( t_deployment_mode, - t_model_mesh_api_type, t_one_click_reports_postman_type, t_wca_secret_backend_type, ) @@ -49,52 +48,6 @@ # Quick-start development settings - unsuitable for production # See https://docs.djangoproject.com/en/4.1/howto/deployment/checklist/ -# ========================================== -# Model Provider -# ------------------------------------------ -ANSIBLE_AI_MODEL_MESH_API_TYPE: t_model_mesh_api_type = os.getenv( - "ANSIBLE_AI_MODEL_MESH_API_TYPE" -) or cast(t_model_mesh_api_type, "http") - -ANSIBLE_AI_MODEL_MESH_API_URL = ( - os.getenv("ANSIBLE_AI_MODEL_MESH_API_URL") or "https://model.wisdom.testing.ansible.com:443" -) - -ANSIBLE_AI_MODEL_MESH_API_KEY = 
os.getenv("ANSIBLE_AI_MODEL_MESH_API_KEY") -ANSIBLE_AI_MODEL_MESH_MODEL_ID = os.getenv("ANSIBLE_AI_MODEL_MESH_MODEL_ID") -if "ANSIBLE_AI_MODEL_MESH_MODEL_NAME" in os.environ: - logger.warning( - "Use of ANSIBLE_AI_MODEL_MESH_MODEL_NAME is deprecated and " - "should be replaced with ANSIBLE_AI_MODEL_MESH_MODEL_ID." - ) - if "ANSIBLE_AI_MODEL_MESH_MODEL_ID" in os.environ: - logger.warning( - "Environment variable ANSIBLE_AI_MODEL_MESH_MODEL_ID is set and will take precedence." - ) - else: - logger.warning( - "Setting the value of ANSIBLE_AI_MODEL_MESH_MODEL_ID to " - "the value of ANSIBLE_AI_MODEL_MESH_MODEL_NAME." - ) - ANSIBLE_AI_MODEL_MESH_MODEL_ID = os.getenv("ANSIBLE_AI_MODEL_MESH_MODEL_NAME") - -# Model API Timeout (in seconds). Default is None. -ANSIBLE_AI_MODEL_MESH_API_TIMEOUT = os.getenv("ANSIBLE_AI_MODEL_MESH_API_TIMEOUT") - -# WCA - General -ANSIBLE_WCA_IDP_URL = os.getenv("ANSIBLE_WCA_IDP_URL") or "https://iam.cloud.ibm.com/identity" -ANSIBLE_WCA_IDP_LOGIN = os.getenv("ANSIBLE_WCA_IDP_LOGIN") -ANSIBLE_WCA_IDP_PASSWORD = os.getenv("ANSIBLE_WCA_IDP_PASSWORD") -ANSIBLE_WCA_RETRY_COUNT = int(os.getenv("ANSIBLE_WCA_RETRY_COUNT") or "4") -ANSIBLE_WCA_HEALTHCHECK_API_KEY = os.getenv("ANSIBLE_WCA_HEALTHCHECK_API_KEY") -ANSIBLE_WCA_HEALTHCHECK_MODEL_ID = os.getenv("ANSIBLE_WCA_HEALTHCHECK_MODEL_ID") -# WCA - "On prem" -ANSIBLE_WCA_USERNAME = os.getenv("ANSIBLE_WCA_USERNAME") - -# GRPC -ANSIBLE_GRPC_HEALTHCHECK_URL = os.getenv("ANSIBLE_GRPC_HEALTHCHECK_URL") -# ========================================== - SECRET_KEY = os.environ["SECRET_KEY"] ALLOWED_HOSTS = ["localhost"] @@ -187,9 +140,6 @@ def is_ssl_enabled(value: str) -> bool: AAP_API_URL = os.environ.get("AAP_API_URL") AAP_API_PROVIDER_NAME = os.environ.get("AAP_API_PROVIDER_NAME", "Ansible Automation Platform") SOCIAL_AUTH_VERIFY_SSL = is_ssl_enabled(os.getenv("SOCIAL_AUTH_VERIFY_SSL", "True")) -ANSIBLE_AI_MODEL_MESH_API_VERIFY_SSL = is_ssl_enabled( - os.getenv("ANSIBLE_AI_MODEL_MESH_API_VERIFY_SSL", "True") 
-) SOCIAL_AUTH_AAP_KEY = os.environ.get("SOCIAL_AUTH_AAP_KEY") SOCIAL_AUTH_AAP_SECRET = os.environ.get("SOCIAL_AUTH_AAP_SECRET") SOCIAL_AUTH_AAP_SCOPE = ["read"] @@ -465,21 +415,6 @@ def is_ssl_enabled(value: str) -> bool: APPEND_SLASH = True -DUMMY_MODEL_RESPONSE_BODY = os.environ.get( - "DUMMY_MODEL_RESPONSE_BODY", - ( - '{"predictions":["ansible.builtin.apt:\\n name: nginx\\n' - ' update_cache: true\\n state: present\\n"]}' - ), -) - -DUMMY_MODEL_RESPONSE_MAX_LATENCY_MSEC = int( - os.environ.get("DUMMY_MODEL_RESPONSE_MAX_LATENCY_MSEC", 3000) -) -DUMMY_MODEL_RESPONSE_LATENCY_USE_JITTER = bool( - os.environ.get("DUMMY_MODEL_RESPONSE_LATENCY_USE_JITTER", False) -) - ENABLE_ARI_POSTPROCESS = os.getenv("ENABLE_ARI_POSTPROCESS", "False").lower() == "true" ARI_BASE_DIR = os.getenv("ARI_KB_PATH") or "/etc/ari/kb/" ARI_RULES_DIR = os.path.join(ARI_BASE_DIR, "rules") @@ -546,7 +481,6 @@ def is_ssl_enabled(value: str) -> bool: WCA_SECRET_MANAGER_REPLICA_REGIONS = [ c.strip() for c in os.getenv("WCA_SECRET_MANAGER_REPLICA_REGIONS", "").split(",") if c ] -WCA_ENABLE_ARI_POSTPROCESS = os.getenv("WCA_ENABLE_ARI_POSTPROCESS", "False").lower() == "true" CSP_DEFAULT_SRC = ("'self'", "data:") CSP_INCLUDE_NONCE_IN = ["script-src-elem"] @@ -560,7 +494,6 @@ def is_ssl_enabled(value: str) -> bool: # ------------------------------------------ # Support to disable health checks. The default is that they are enabled. # The naming convention in the existing settings is to ENABLE_XXX and not DISABLE_XXX. 
-ENABLE_HEALTHCHECK_MODEL_MESH = os.getenv("ENABLE_HEALTHCHECK_MODEL_MESH", "True").lower() == "true" ENABLE_HEALTHCHECK_SECRET_MANAGER = ( os.getenv("ENABLE_HEALTHCHECK_SECRET_MANAGER", "True").lower() == "true" ) @@ -570,10 +503,6 @@ def is_ssl_enabled(value: str) -> bool: ENABLE_HEALTHCHECK_ATTRIBUTION = ( os.getenv("ENABLE_HEALTHCHECK_ATTRIBUTION", "True").lower() == "true" ) -ENABLE_HEALTHCHECK_CHATBOT_SERVICE = ( - os.getenv("ENABLE_HEALTHCHECK_CHATBOT_SERVICE", "True").lower() == "true" -) - # ========================================== # ========================================== @@ -593,16 +522,9 @@ def is_ssl_enabled(value: str) -> bool: os.getenv("ANSIBLE_AI_ENABLE_ONE_CLICK_TRIAL", "False").lower() == "true" ) -ANSIBLE_AI_ENABLE_ONE_CLICK_DEFAULT_API_KEY: str = ( - os.getenv("ANSIBLE_AI_ENABLE_ONE_CLICK_DEFAULT_API_KEY") or "" -) -ANSIBLE_AI_ENABLE_ONE_CLICK_DEFAULT_MODEL_ID: str = ( - os.getenv("ANSIBLE_AI_ENABLE_ONE_CLICK_DEFAULT_MODEL_ID") or "" +ANSIBLE_AI_ONE_CLICK_REPORTS_POSTMAN: t_one_click_reports_postman_type = cast( + t_one_click_reports_postman_type, os.getenv("ANSIBLE_AI_ONE_CLICK_REPORTS_POSTMAN") or "none" ) - -ANSIBLE_AI_ONE_CLICK_REPORTS_POSTMAN: t_one_click_reports_postman_type = os.getenv( - "ANSIBLE_AI_ONE_CLICK_REPORTS_POSTMAN" -) or cast(t_one_click_reports_postman_type, "none") ANSIBLE_AI_ONE_CLICK_REPORTS_CONFIG: dict = ( json.loads(os.getenv("ANSIBLE_AI_ONE_CLICK_REPORTS_CONFIG"), strict=False) if os.getenv("ANSIBLE_AI_ONE_CLICK_REPORTS_CONFIG") @@ -613,9 +535,7 @@ def is_ssl_enabled(value: str) -> bool: # ========================================== # Chatbot # ------------------------------------------ -CHATBOT_URL = os.getenv("CHATBOT_URL") CHATBOT_DEFAULT_PROVIDER = os.getenv("CHATBOT_DEFAULT_PROVIDER") -CHATBOT_DEFAULT_MODEL = os.getenv("CHATBOT_DEFAULT_MODEL") CHATBOT_DEBUG_UI = os.getenv("CHATBOT_DEBUG_UI", "False").lower() == "true" # ========================================== @@ -630,99 +550,5 @@ def is_ssl_enabled(value: 
str) -> bool: # ========================================== # Pipeline configuration # ------------------------------------------ -# [manstis] This will be enabled when we update AWS SM configuration -# ANSIBLE_AI_MODEL_MESH_CONFIG = os.getenv("ANSIBLE_AI_MODEL_MESH_CONFIG") -# -# [manstis] For now, populate the configuration from the environment variables - -if ANSIBLE_AI_MODEL_MESH_API_TYPE == "wca": - ANSIBLE_AI_PIPELINE_CONFIG = { - "provider": "wca", - "config": { - "inference_url": ANSIBLE_AI_MODEL_MESH_API_URL, - "api_key": ANSIBLE_AI_MODEL_MESH_API_KEY, - "model_id": ANSIBLE_AI_MODEL_MESH_MODEL_ID, - "timeout": ANSIBLE_AI_MODEL_MESH_API_TIMEOUT, - "verify_ssl": ANSIBLE_AI_MODEL_MESH_API_VERIFY_SSL, - "retry_count": ANSIBLE_WCA_RETRY_COUNT, - "enable_ari_postprocessing": WCA_ENABLE_ARI_POSTPROCESS, - "health_check_api_key": ANSIBLE_WCA_HEALTHCHECK_API_KEY, - "health_check_model_id": ANSIBLE_WCA_HEALTHCHECK_MODEL_ID, - "idp_url": ANSIBLE_WCA_IDP_URL, - "idp_login": ANSIBLE_WCA_IDP_LOGIN, - "idp_password": ANSIBLE_WCA_IDP_PASSWORD, - "one_click_default_api_key": ANSIBLE_AI_ENABLE_ONE_CLICK_DEFAULT_API_KEY, - "one_click_default_model_id": ANSIBLE_AI_ENABLE_ONE_CLICK_DEFAULT_MODEL_ID, - }, - } -elif ANSIBLE_AI_MODEL_MESH_API_TYPE == "wca-onprem": - ANSIBLE_AI_PIPELINE_CONFIG = { - "provider": "wca-onprem", - "config": { - "inference_url": ANSIBLE_AI_MODEL_MESH_API_URL, - "api_key": ANSIBLE_AI_MODEL_MESH_API_KEY, - "model_id": ANSIBLE_AI_MODEL_MESH_MODEL_ID, - "timeout": ANSIBLE_AI_MODEL_MESH_API_TIMEOUT, - "verify_ssl": ANSIBLE_AI_MODEL_MESH_API_VERIFY_SSL, - "retry_count": ANSIBLE_WCA_RETRY_COUNT, - "enable_ari_postprocessing": WCA_ENABLE_ARI_POSTPROCESS, - "health_check_api_key": ANSIBLE_WCA_HEALTHCHECK_API_KEY, - "health_check_model_id": ANSIBLE_WCA_HEALTHCHECK_MODEL_ID, - "username": ANSIBLE_WCA_USERNAME, - }, - } -elif ANSIBLE_AI_MODEL_MESH_API_TYPE == "dummy": - ANSIBLE_AI_PIPELINE_CONFIG = { - "provider": "dummy", - "config": { - "inference_url": 
ANSIBLE_AI_MODEL_MESH_API_URL, - "body": DUMMY_MODEL_RESPONSE_BODY, - "latency_max_msec": DUMMY_MODEL_RESPONSE_MAX_LATENCY_MSEC, - "latency_use_jitter": DUMMY_MODEL_RESPONSE_LATENCY_USE_JITTER, - }, - } -elif ANSIBLE_AI_MODEL_MESH_API_TYPE == "ollama": - ANSIBLE_AI_PIPELINE_CONFIG = { - "provider": "ollama", - "config": { - "inference_url": ANSIBLE_AI_MODEL_MESH_API_URL, - "model_id": ANSIBLE_AI_MODEL_MESH_MODEL_ID, - "timeout": ANSIBLE_AI_MODEL_MESH_API_TIMEOUT, - }, - } -else: - ANSIBLE_AI_PIPELINE_CONFIG = { - "provider": "wca-dummy", - "config": { - "inference_url": ANSIBLE_AI_MODEL_MESH_API_URL, - }, - } - - -# Lazy import to avoid circular dependencies -from ansible_ai_connect.ai.api.model_pipelines.pipelines import MetaData # noqa -from ansible_ai_connect.ai.api.model_pipelines.registry import REGISTRY_ENTRY # noqa - -pipelines = [i for i in REGISTRY_ENTRY.keys() if issubclass(i, MetaData)] -pipeline_config: dict = {k.__name__: deepcopy(ANSIBLE_AI_PIPELINE_CONFIG) for k in pipelines} - -# The ChatBot does not use the same configuration as everything else -pipeline_config["ModelPipelineChatBot"] = { - "provider": "http", - "config": { - "inference_url": CHATBOT_URL or "http://localhost:8000", - "model_id": CHATBOT_DEFAULT_MODEL or "granite3-8b", - "verify_ssl": ANSIBLE_AI_MODEL_MESH_API_VERIFY_SSL, - }, -} - -# Enable Health Checks where we have them implemented -pipeline_config["ModelPipelineCompletions"]["config"][ - "enable_health_check" -] = ENABLE_HEALTHCHECK_MODEL_MESH -pipeline_config["ModelPipelineChatBot"]["config"][ - "enable_health_check" -] = ENABLE_HEALTHCHECK_MODEL_MESH - -ANSIBLE_AI_MODEL_MESH_CONFIG = json.dumps(pipeline_config) +ANSIBLE_AI_MODEL_MESH_CONFIG = os.getenv("ANSIBLE_AI_MODEL_MESH_CONFIG") or load_from_env_vars() # ========================================== diff --git a/ansible_ai_connect/main/settings/legacy.py b/ansible_ai_connect/main/settings/legacy.py new file mode 100644 index 000000000..605855acc --- /dev/null +++ 
b/ansible_ai_connect/main/settings/legacy.py @@ -0,0 +1,204 @@ +# Copyright Red Hat +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import logging +import os +from copy import deepcopy +from typing import cast + +from ansible_ai_connect.main.settings.types import t_model_mesh_api_type + +logger = logging.getLogger(__name__) + + +def is_ssl_enabled(value: str) -> bool: + """SSL should be enabled if value is not recognized""" + disabled = value.lower() in ("false", "f", "0", "-1") + return not disabled + + +def load_from_env_vars(): + # ========================================== + # Model Provider + # ------------------------------------------ + model_service_type: t_model_mesh_api_type = cast( + t_model_mesh_api_type, os.getenv("ANSIBLE_AI_MODEL_MESH_API_TYPE") or "http" + ) + model_service_url = ( + os.getenv("ANSIBLE_AI_MODEL_MESH_API_URL") or "https://model.wisdom.testing.ansible.com:443" + ) + model_service_api_key = os.getenv("ANSIBLE_AI_MODEL_MESH_API_KEY") + model_service_model_id = os.getenv("ANSIBLE_AI_MODEL_MESH_MODEL_ID") + if "ANSIBLE_AI_MODEL_MESH_MODEL_NAME" in os.environ: + logger.warning( + "Use of ANSIBLE_AI_MODEL_MESH_MODEL_NAME is deprecated and " + "should be replaced with ANSIBLE_AI_MODEL_MESH_MODEL_ID." + ) + if "ANSIBLE_AI_MODEL_MESH_MODEL_ID" in os.environ: + logger.warning( + "Environment variable ANSIBLE_AI_MODEL_MESH_MODEL_ID is set " + "and will take precedence." 
+ ) + else: + logger.warning( + "Setting the value of ANSIBLE_AI_MODEL_MESH_MODEL_ID to " + "the value of ANSIBLE_AI_MODEL_MESH_MODEL_NAME." + ) + model_service_model_id = os.getenv("ANSIBLE_AI_MODEL_MESH_MODEL_NAME") + + # Model API Timeout (in seconds). Default is None. + model_service_timeout = ( + int(os.getenv("ANSIBLE_AI_MODEL_MESH_API_TIMEOUT")) + if os.getenv("ANSIBLE_AI_MODEL_MESH_API_TIMEOUT") + else None + ) + model_service_retry_count = int(os.getenv("ANSIBLE_WCA_RETRY_COUNT") or "4") + model_service_verify_ssl = is_ssl_enabled( + os.getenv("ANSIBLE_AI_MODEL_MESH_API_VERIFY_SSL", "True") + ) + + model_service_enable_health_check = ( + os.getenv("ENABLE_HEALTHCHECK_MODEL_MESH", "True").lower() == "true" + ) + model_service_health_check_api_key = os.getenv("ANSIBLE_WCA_HEALTHCHECK_API_KEY") + model_service_health_check_model_id = os.getenv("ANSIBLE_WCA_HEALTHCHECK_MODEL_ID") + model_service_enable_ari_postprocessing = ( + os.getenv("WCA_ENABLE_ARI_POSTPROCESS", "False").lower() == "true" + ) + # ========================================== + + # ========================================== + # Pipeline JSON configuration + # ------------------------------------------ + if model_service_type == "wca": + wca_saas_idp_url = os.getenv("ANSIBLE_WCA_IDP_URL") or "https://iam.cloud.ibm.com/identity" + wca_saas_idp_login = os.getenv("ANSIBLE_WCA_IDP_LOGIN") + wca_saas_idp_password = os.getenv("ANSIBLE_WCA_IDP_PASSWORD") + wca_saas_one_click_default_api_key: str = ( + os.getenv("ANSIBLE_AI_ENABLE_ONE_CLICK_DEFAULT_API_KEY") or "" + ) + wca_saas_one_click_default_model_id: str = ( + os.getenv("ANSIBLE_AI_ENABLE_ONE_CLICK_DEFAULT_MODEL_ID") or "" + ) + model_pipeline_config = { + "provider": "wca", + "config": { + "inference_url": model_service_url, + "api_key": model_service_api_key, + "model_id": model_service_model_id, + "timeout": model_service_timeout, + "verify_ssl": model_service_verify_ssl, + "retry_count": model_service_retry_count, + 
"enable_ari_postprocessing": model_service_enable_ari_postprocessing, + "health_check_api_key": model_service_health_check_api_key, + "health_check_model_id": model_service_health_check_model_id, + "idp_url": wca_saas_idp_url, + "idp_login": wca_saas_idp_login, + "idp_password": wca_saas_idp_password, + "one_click_default_api_key": wca_saas_one_click_default_api_key, + "one_click_default_model_id": wca_saas_one_click_default_model_id, + }, + } + elif model_service_type == "wca-onprem": + wca_onprem_username = os.getenv("ANSIBLE_WCA_USERNAME") + model_pipeline_config = { + "provider": "wca-onprem", + "config": { + "inference_url": model_service_url, + "api_key": model_service_api_key, + "model_id": model_service_model_id, + "timeout": model_service_timeout, + "verify_ssl": model_service_verify_ssl, + "retry_count": model_service_retry_count, + "enable_ari_postprocessing": model_service_enable_ari_postprocessing, + "health_check_api_key": model_service_health_check_api_key, + "health_check_model_id": model_service_health_check_model_id, + "username": wca_onprem_username, + }, + } + elif model_service_type == "dummy": + dummy_response_body = os.environ.get( + "DUMMY_MODEL_RESPONSE_BODY", + ( + '{"predictions":["ansible.builtin.apt:\\n name: nginx\\n' + ' update_cache: true\\n state: present\\n"]}' + ), + ) + dummy_response_max_latency_msec = int( + os.environ.get("DUMMY_MODEL_RESPONSE_MAX_LATENCY_MSEC", 3000) + ) + dummy_response_latency_use_jitter = bool( + os.environ.get("DUMMY_MODEL_RESPONSE_LATENCY_USE_JITTER", False) + ) + model_pipeline_config = { + "provider": "dummy", + "config": { + "inference_url": model_service_url, + "body": dummy_response_body, + "latency_max_msec": dummy_response_max_latency_msec, + "latency_use_jitter": dummy_response_latency_use_jitter, + }, + } + elif model_service_type == "ollama": + model_pipeline_config = { + "provider": "ollama", + "config": { + "inference_url": model_service_url, + "model_id": model_service_model_id, + "timeout": 
model_service_timeout, + }, + } + else: + model_pipeline_config = { + "provider": "wca-dummy", + "config": { + "inference_url": model_service_url, + }, + } + + # Lazy import to avoid circular dependencies + from ansible_ai_connect.ai.api.model_pipelines.pipelines import MetaData # noqa + from ansible_ai_connect.ai.api.model_pipelines.registry import ( # noqa + REGISTRY_ENTRY, + ) + + pipelines = [i for i in REGISTRY_ENTRY.keys() if issubclass(i, MetaData)] + model_pipelines_config: dict = {k.__name__: deepcopy(model_pipeline_config) for k in pipelines} + + # The ChatBot does not use the same configuration as everything else + chatbot_service_url = os.getenv("CHATBOT_URL") + chatbot_service_model_id = os.getenv("CHATBOT_DEFAULT_MODEL") + chatbot_service_enable_health_check = ( + os.getenv("ENABLE_HEALTHCHECK_CHATBOT_SERVICE", "True").lower() == "true" + ) + model_pipelines_config["ModelPipelineChatBot"] = { + "provider": "http", + "config": { + "inference_url": chatbot_service_url or "http://localhost:8000", + "model_id": chatbot_service_model_id or "granite3-8b", + "verify_ssl": model_service_verify_ssl, + }, + } + + # Enable Health Checks where we have them implemented + model_pipelines_config["ModelPipelineCompletions"]["config"][ + "enable_health_check" + ] = model_service_enable_health_check + model_pipelines_config["ModelPipelineChatBot"]["config"][ + "enable_health_check" + ] = chatbot_service_enable_health_check + # ========================================== + + return json.dumps(model_pipelines_config) diff --git a/ansible_ai_connect/main/settings/tests/test_legacy.py b/ansible_ai_connect/main/settings/tests/test_legacy.py new file mode 100644 index 000000000..f7c38cd0b --- /dev/null +++ b/ansible_ai_connect/main/settings/tests/test_legacy.py @@ -0,0 +1,259 @@ +# Copyright Red Hat +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import importlib +import json +import os +from unittest.mock import patch + +import django.conf +from django.test import SimpleTestCase + +import ansible_ai_connect.main.settings.base +from ansible_ai_connect.test_utils import WisdomLogAwareMixin + + +class TestLegacySettings(SimpleTestCase, WisdomLogAwareMixin): + @classmethod + def reload_settings(cls): + module_name = os.getenv("DJANGO_SETTINGS_MODULE") + settings_module = importlib.import_module( + module_name.replace("ansible_wisdom.", "ansible_ai_connect.") + ) + + importlib.reload(ansible_ai_connect.main.settings.base) + importlib.reload(settings_module) + importlib.reload(django.conf) + from django.conf import settings + + settings.configure(default_settings=settings_module) + return settings + + @classmethod + def tearDownClass(cls): + cls.reload_settings() + + def test_model_service_sections(self): + settings = self.reload_settings() + config = json.loads(settings.ANSIBLE_AI_MODEL_MESH_CONFIG) + + # Lazy import to avoid circular dependencies + from ansible_ai_connect.ai.api.model_pipelines.pipelines import MetaData # noqa + from ansible_ai_connect.ai.api.model_pipelines.registry import ( # noqa + REGISTRY_ENTRY, + ) + + pipelines = [i for i in REGISTRY_ENTRY.keys() if issubclass(i, MetaData)] + for k in pipelines: + self.assertTrue(k.__name__ in config) + + @patch.dict( + os.environ, + { + "ANSIBLE_AI_MODEL_MESH_API_TYPE": "wca", + "ANSIBLE_AI_MODEL_MESH_MODEL_NAME": "a-model", + }, + ) + def test_use_of_model_mesh_model_name(self): + with self.assertLogs(logger="root", level="DEBUG") as log: + 
settings = self.reload_settings() + config = json.loads(settings.ANSIBLE_AI_MODEL_MESH_CONFIG) + + self.assertEqual(config["ModelPipelineCompletions"]["config"]["model_id"], "a-model") + self.assertTrue( + self.searchInLogOutput("Use of ANSIBLE_AI_MODEL_MESH_MODEL_NAME is deprecated", log) + ) + self.assertTrue( + self.searchInLogOutput("Setting the value of ANSIBLE_AI_MODEL_MESH_MODEL_ID", log) + ) + + @patch.dict( + os.environ, + { + "ANSIBLE_AI_MODEL_MESH_API_TYPE": "wca", + "ANSIBLE_AI_MODEL_MESH_MODEL_NAME": "a-model", + "ANSIBLE_AI_MODEL_MESH_MODEL_ID": "b-model", + }, + ) + def test_use_of_model_mesh_model_name_and_model_id(self): + with self.assertLogs(logger="root", level="DEBUG") as log: + settings = self.reload_settings() + config = json.loads(settings.ANSIBLE_AI_MODEL_MESH_CONFIG) + + self.assertEqual(config["ModelPipelineCompletions"]["config"]["model_id"], "b-model") + self.assertTrue( + self.searchInLogOutput("Use of ANSIBLE_AI_MODEL_MESH_MODEL_NAME is deprecated", log) + ) + self.assertTrue( + self.searchInLogOutput( + "ANSIBLE_AI_MODEL_MESH_MODEL_ID is set and will take precedence", log + ) + ) + + @patch.dict( + os.environ, + { + "ANSIBLE_AI_MODEL_MESH_API_TYPE": "wca", + "ANSIBLE_AI_MODEL_MESH_API_URL": "http://a-url", + "ANSIBLE_AI_MODEL_MESH_API_KEY": "api-key", + "ANSIBLE_AI_MODEL_MESH_MODEL_ID": "model-id", + "ANSIBLE_AI_MODEL_MESH_API_TIMEOUT": "999", + "ANSIBLE_AI_MODEL_MESH_API_VERIFY_SSL": "True", + "ANSIBLE_WCA_RETRY_COUNT": "9", + "WCA_ENABLE_ARI_POSTPROCESS": "True", + "ENABLE_HEALTHCHECK_MODEL_MESH": "True", + "ANSIBLE_WCA_HEALTHCHECK_API_KEY": "health-check-api-key", + "ANSIBLE_WCA_HEALTHCHECK_MODEL_ID": "health-check-model-id", + "ANSIBLE_WCA_IDP_URL": "http://idp-url", + "ANSIBLE_WCA_IDP_LOGIN": "idp-login", + "ANSIBLE_WCA_IDP_PASSWORD": "idp-password", + "ANSIBLE_AI_ENABLE_ONE_CLICK_DEFAULT_API_KEY": "trial-api-key", + "ANSIBLE_AI_ENABLE_ONE_CLICK_DEFAULT_MODEL_ID": "trial-model-id", + }, + ) + def test_wca_saas(self): + settings 
= self.reload_settings() + config = json.loads(settings.ANSIBLE_AI_MODEL_MESH_CONFIG) + + self.assertEqual( + config["ModelPipelineCompletions"]["config"]["inference_url"], "http://a-url" + ) + self.assertEqual(config["ModelPipelineCompletions"]["config"]["api_key"], "api-key") + self.assertEqual(config["ModelPipelineCompletions"]["config"]["model_id"], "model-id") + self.assertEqual(config["ModelPipelineCompletions"]["config"]["timeout"], 999) + self.assertEqual(config["ModelPipelineCompletions"]["config"]["verify_ssl"], True) + self.assertEqual(config["ModelPipelineCompletions"]["config"]["retry_count"], 9) + self.assertEqual( + config["ModelPipelineCompletions"]["config"]["enable_ari_postprocessing"], True + ) + self.assertEqual(config["ModelPipelineCompletions"]["config"]["enable_health_check"], True) + self.assertEqual( + config["ModelPipelineCompletions"]["config"]["health_check_api_key"], + "health-check-api-key", + ) + self.assertEqual( + config["ModelPipelineCompletions"]["config"]["health_check_model_id"], + "health-check-model-id", + ) + self.assertEqual(config["ModelPipelineCompletions"]["config"]["idp_url"], "http://idp-url") + self.assertEqual(config["ModelPipelineCompletions"]["config"]["idp_login"], "idp-login") + self.assertEqual( + config["ModelPipelineCompletions"]["config"]["idp_password"], "idp-password" + ) + self.assertEqual( + config["ModelPipelineCompletions"]["config"]["one_click_default_api_key"], + "trial-api-key", + ) + self.assertEqual( + config["ModelPipelineCompletions"]["config"]["one_click_default_model_id"], + "trial-model-id", + ) + + @patch.dict( + os.environ, + { + "ANSIBLE_AI_MODEL_MESH_API_TYPE": "wca-onprem", + "ANSIBLE_AI_MODEL_MESH_API_URL": "http://a-url", + "ANSIBLE_AI_MODEL_MESH_API_KEY": "api-key", + "ANSIBLE_AI_MODEL_MESH_MODEL_ID": "model-id", + "ANSIBLE_AI_MODEL_MESH_API_TIMEOUT": "999", + "ANSIBLE_AI_MODEL_MESH_API_VERIFY_SSL": "True", + "ANSIBLE_WCA_RETRY_COUNT": "9", + "WCA_ENABLE_ARI_POSTPROCESS": "True", + 
"ENABLE_HEALTHCHECK_MODEL_MESH": "True", + "ANSIBLE_WCA_HEALTHCHECK_API_KEY": "health-check-api-key", + "ANSIBLE_WCA_HEALTHCHECK_MODEL_ID": "health-check-model-id", + "ANSIBLE_WCA_USERNAME": "username", + }, + ) + def test_wca_onprem(self): + settings = self.reload_settings() + config = json.loads(settings.ANSIBLE_AI_MODEL_MESH_CONFIG) + + self.assertEqual( + config["ModelPipelineCompletions"]["config"]["inference_url"], "http://a-url" + ) + self.assertEqual(config["ModelPipelineCompletions"]["config"]["api_key"], "api-key") + self.assertEqual(config["ModelPipelineCompletions"]["config"]["model_id"], "model-id") + self.assertEqual(config["ModelPipelineCompletions"]["config"]["timeout"], 999) + self.assertEqual(config["ModelPipelineCompletions"]["config"]["verify_ssl"], True) + self.assertEqual(config["ModelPipelineCompletions"]["config"]["retry_count"], 9) + self.assertEqual( + config["ModelPipelineCompletions"]["config"]["enable_ari_postprocessing"], True + ) + self.assertEqual(config["ModelPipelineCompletions"]["config"]["enable_health_check"], True) + self.assertEqual( + config["ModelPipelineCompletions"]["config"]["health_check_api_key"], + "health-check-api-key", + ) + self.assertEqual( + config["ModelPipelineCompletions"]["config"]["health_check_model_id"], + "health-check-model-id", + ) + self.assertEqual(config["ModelPipelineCompletions"]["config"]["username"], "username") + + @patch.dict( + os.environ, + { + "ANSIBLE_AI_MODEL_MESH_API_TYPE": "wca-dummy", + "ANSIBLE_AI_MODEL_MESH_API_URL": "http://a-url", + }, + ) + def test_wca_dummy(self): + settings = self.reload_settings() + config = json.loads(settings.ANSIBLE_AI_MODEL_MESH_CONFIG) + + self.assertEqual( + config["ModelPipelineCompletions"]["config"]["inference_url"], "http://a-url" + ) + + @patch.dict( + os.environ, + { + "ANSIBLE_AI_MODEL_MESH_API_TYPE": "dummy", + "ANSIBLE_AI_MODEL_MESH_API_URL": "http://a-url", + "DUMMY_MODEL_RESPONSE_BODY": "dummy-body", + "DUMMY_MODEL_RESPONSE_MAX_LATENCY_MSEC": 
"999", + "DUMMY_MODEL_RESPONSE_LATENCY_USE_JITTER": "True", + }, + ) + def test_dummy(self): + settings = self.reload_settings() + config = json.loads(settings.ANSIBLE_AI_MODEL_MESH_CONFIG) + + self.assertEqual( + config["ModelPipelineCompletions"]["config"]["inference_url"], "http://a-url" + ) + self.assertEqual(config["ModelPipelineCompletions"]["config"]["body"], "dummy-body") + self.assertEqual(config["ModelPipelineCompletions"]["config"]["latency_max_msec"], 999) + self.assertEqual(config["ModelPipelineCompletions"]["config"]["latency_use_jitter"], True) + + @patch.dict( + os.environ, + { + "ANSIBLE_AI_MODEL_MESH_API_TYPE": "ollama", + "ANSIBLE_AI_MODEL_MESH_API_URL": "http://a-url", + "ANSIBLE_AI_MODEL_MESH_MODEL_ID": "model-id", + "ANSIBLE_AI_MODEL_MESH_API_TIMEOUT": "999", + }, + ) + def test_ollama(self): + settings = self.reload_settings() + config = json.loads(settings.ANSIBLE_AI_MODEL_MESH_CONFIG) + + self.assertEqual( + config["ModelPipelineCompletions"]["config"]["inference_url"], "http://a-url" + ) + self.assertEqual(config["ModelPipelineCompletions"]["config"]["model_id"], "model-id") + self.assertEqual(config["ModelPipelineCompletions"]["config"]["timeout"], 999) diff --git a/ansible_ai_connect/main/tests/test_middleware.py b/ansible_ai_connect/main/tests/test_middleware.py index 9ac1ae665..b60685bb8 100644 --- a/ansible_ai_connect/main/tests/test_middleware.py +++ b/ansible_ai_connect/main/tests/test_middleware.py @@ -19,7 +19,6 @@ from urllib.parse import urlencode from django.apps import apps -from django.conf import settings from django.test import override_settings from django.urls import reverse from segment import analytics @@ -39,7 +38,6 @@ def dummy_redact_seated_users_data(event, allow_list): @override_settings(WCA_SECRET_DUMMY_SECRETS="1981:valid") @override_settings(AUTHZ_BACKEND_TYPE="dummy") @override_settings(AUTHZ_DUMMY_ORGS_WITH_SUBSCRIPTION="*") -@override_settings(ANSIBLE_AI_MODEL_MESH_MODEL_ID="wisdom") class 
TestMiddleware(WisdomAppsBackendMocking, WisdomServiceAPITestCaseBaseOIDC): @override_settings(ENABLE_ARI_POSTPROCESS=True) @@ -73,7 +71,7 @@ def test_full_payload(self): }, } response_data = { - "model_id": settings.ANSIBLE_AI_MODEL_MESH_MODEL_ID, + "model_id": "a-model-id", "predictions": [" ansible.builtin.apt:\n name: apache2"], } @@ -100,9 +98,7 @@ def test_full_payload(self): for event in segment_events: properties = event["properties"] self.assertTrue("modelName" in properties) - self.assertEqual( - properties["modelName"], settings.ANSIBLE_AI_MODEL_MESH_MODEL_ID - ) + self.assertEqual(properties["modelName"], "a-model-id") self.assertTrue("imageTags" in properties) self.assertTrue("groups" in properties) self.assertTrue("Group 1" in properties["groups"]) @@ -184,7 +180,7 @@ def test_segment_error(self): }, } response_data = { - "model_id": settings.ANSIBLE_AI_MODEL_MESH_MODEL_ID, + "model_id": "a-model-id", "predictions": [" ansible.builtin.apt:\n name: apache2"], } self.client.force_authenticate(user=self.user) @@ -231,7 +227,7 @@ def test_204_empty_response(self): "status_code": 204, } response_data = { - "model_id": settings.ANSIBLE_AI_MODEL_MESH_MODEL_ID, + "model_id": "a-model-id", } self.client.force_authenticate(user=self.user) @@ -307,7 +303,7 @@ def test_segment_error_with_data_exceeding_limit(self): }, } response_data = { - "model_id": settings.ANSIBLE_AI_MODEL_MESH_MODEL_ID, + "model_id": "a-model-id", "predictions": [" ansible.builtin.apt:\n name: apache2"], } diff --git a/ansible_ai_connect/main/tests/test_settings.py b/ansible_ai_connect/main/tests/test_settings.py index d01eaaa46..364af600a 100644 --- a/ansible_ai_connect/main/tests/test_settings.py +++ b/ansible_ai_connect/main/tests/test_settings.py @@ -97,47 +97,6 @@ def test_github_auth_team_empty_key(self): self.assertEqual(settings.SOCIAL_AUTH_GITHUB_SCOPE, [""]) self.assertEqual(settings.SOCIAL_AUTH_GITHUB_EXTRA_DATA, ["login"]) - @patch.dict( - os.environ, - { - 
"ANSIBLE_AI_MODEL_MESH_MODEL_NAME": "a-model", - }, - ) - def test_use_of_model_mesh_model_name(self): - with self.assertLogs(logger="root", level="DEBUG") as log: - settings = self.reload_settings() - self.assertEqual(settings.ANSIBLE_AI_MODEL_MESH_MODEL_ID, "a-model") - self.assertTrue( - self.searchInLogOutput("Use of ANSIBLE_AI_MODEL_MESH_MODEL_NAME is deprecated", log) - ) - self.assertTrue( - self.searchInLogOutput("Setting the value of ANSIBLE_AI_MODEL_MESH_MODEL_ID", log) - ) - - @patch.dict( - os.environ, - { - "ANSIBLE_AI_MODEL_MESH_MODEL_NAME": "a-model", - "ANSIBLE_AI_MODEL_MESH_MODEL_ID": "b-model", - }, - ) - def test_use_of_model_mesh_model_name_and_model_id(self): - with self.assertLogs(logger="root", level="DEBUG") as log: - settings = self.reload_settings() - self.assertEqual(settings.ANSIBLE_AI_MODEL_MESH_MODEL_ID, "b-model") - self.assertTrue( - self.searchInLogOutput("Use of ANSIBLE_AI_MODEL_MESH_MODEL_NAME is deprecated", log) - ) - self.assertTrue( - self.searchInLogOutput( - "ANSIBLE_AI_MODEL_MESH_MODEL_ID is set and will take precedence", log - ) - ) - - def test_ansible_ai_model_mesh_model_id_has_no_default(self): - settings = self.reload_settings() - self.assertIsNone(settings.ANSIBLE_AI_MODEL_MESH_MODEL_ID) - @patch.dict( os.environ, { diff --git a/ansible_ai_connect/main/tests/test_views.py b/ansible_ai_connect/main/tests/test_views.py index a1035ac99..53158cce4 100644 --- a/ansible_ai_connect/main/tests/test_views.py +++ b/ansible_ai_connect/main/tests/test_views.py @@ -237,9 +237,7 @@ def test_get_view_expired_trial(self): trial_plan.delete() -@override_settings(CHATBOT_URL="http://127.0.0.1:8080") @override_settings(CHATBOT_DEFAULT_PROVIDER="wisdom") -@override_settings(CHATBOT_DEFAULT_MODEL="granite-8b") @override_settings(ANSIBLE_AI_CHATBOT_NAME="Awesome Chatbot") @override_settings(CHATBOT_DEBUG_UI=False) class TestChatbotView(TestCase): @@ -299,9 +297,7 @@ def test_chatbot_link_with_non_rh_test_user(self): self.assertContains(r, 
TestChatbotView.DOCUMENT_URL) self.assertContains(r, "Chatbot") - @override_settings(CHATBOT_URL="") @override_settings(CHATBOT_DEFAULT_PROVIDER="") - @override_settings(CHATBOT_DEFAULT_MODEL="") def test_chatbot_link_with_rh_user_but_chatbot_disabled(self): self.client.force_login(user=self.rh_user) r = self.client.get(reverse("home")) @@ -326,9 +322,7 @@ def test_chatbot_view_with_non_rh_user(self): r = self.client.get(reverse("chatbot")) self.assertEqual(r.status_code, HTTPStatus.FORBIDDEN) - @override_settings(CHATBOT_URL="") @override_settings(CHATBOT_DEFAULT_PROVIDER="") - @override_settings(CHATBOT_DEFAULT_MODEL="") def test_chatbot_view_with_rh_user_but_chatbot_disabled(self): self.client.force_login(user=self.rh_user) r = self.client.get(reverse("chatbot")) diff --git a/ansible_ai_connect/main/views.py b/ansible_ai_connect/main/views.py index 00fe279fa..620b84bc0 100644 --- a/ansible_ai_connect/main/views.py +++ b/ansible_ai_connect/main/views.py @@ -16,6 +16,7 @@ import logging +from django.apps import apps from django.conf import settings from django.contrib.auth import views as auth_views from django.contrib.auth.models import AnonymousUser @@ -27,6 +28,7 @@ from rest_framework.renderers import BaseRenderer from rest_framework.views import APIView +from ansible_ai_connect.ai.api.model_pipelines.pipelines import ModelPipelineChatBot from ansible_ai_connect.ai.api.permissions import ( IsOrganisationAdministrator, IsOrganisationLightspeedSubscriber, @@ -121,11 +123,10 @@ class ChatbotView(ProtectedTemplateView): def get(self, request): # Open the chatbot page when the chatbot service is configured. 
- if ( - settings.CHATBOT_URL - and settings.CHATBOT_DEFAULT_MODEL - and settings.CHATBOT_DEFAULT_PROVIDER - ): + llm: ModelPipelineChatBot = apps.get_app_config("ai").get_model_pipeline( + ModelPipelineChatBot + ) + if llm.config.inference_url and llm.config.model_id and settings.CHATBOT_DEFAULT_PROVIDER: return super().get(request) # Otherwise, redirect to the home page. diff --git a/ansible_ai_connect/users/views.py b/ansible_ai_connect/users/views.py index bf0d41fa4..a7d761de4 100644 --- a/ansible_ai_connect/users/views.py +++ b/ansible_ai_connect/users/views.py @@ -31,6 +31,7 @@ from ansible_ai_connect.ai.api.aws.exceptions import ( WcaSecretManagerMissingCredentialsError, ) +from ansible_ai_connect.ai.api.model_pipelines.pipelines import ModelPipelineChatBot from ansible_ai_connect.ai.api.telemetry import schema1 from ansible_ai_connect.ai.api.telemetry import schema2_utils as schema2 from ansible_ai_connect.ai.api.utils.segment import send_schema1_event @@ -92,10 +93,12 @@ def get_context_data(self, **kwargs): user.rh_internal or user.groups.filter(name="test").exists() ) + # Show chatbot link when the chatbot service is configured. + llm: ModelPipelineChatBot = apps.get_app_config("ai").get_model_pipeline( + ModelPipelineChatBot + ) context["chatbot_enabled"] = ( - settings.CHATBOT_URL - and settings.CHATBOT_DEFAULT_MODEL - and settings.CHATBOT_DEFAULT_PROVIDER + llm.config.inference_url and llm.config.model_id and settings.CHATBOT_DEFAULT_PROVIDER ) return context diff --git a/ansible_ai_connect_chatbot/README.md b/ansible_ai_connect_chatbot/README.md index 34462f1e0..284b68dd6 100644 --- a/ansible_ai_connect_chatbot/README.md +++ b/ansible_ai_connect_chatbot/README.md @@ -51,18 +51,26 @@ in the `coverage` sub-directory. ## Test Chatbot in Local environment **Chatbot is enabled when all of -the following three environment variables are defined:** +the following three parameters are defined:** -1. `CHATBOT_URL` URL of the chat service to be used. -2. 
`CHATBOT_DEFAULT_PROVIDER` Default AI model provider. It should be -one of providers defined in the configuration used by the chat service. -3. `CHATBOT_DEFAULT_MODEL` Default AI model. It should be +1. `ModelPipelineChatBot.config.inference_url` URL of the chat service to be used. +2. `ModelPipelineChatBot.config.model_id` Default AI model. It should be one of models defined in the configuration used by the chat service. +3. `CHATBOT_DEFAULT_PROVIDER` Default AI model provider. It should be + one of providers defined in the configuration used by the chat service. ```commandline -CHATBOT_URL=http://127.0.0.1:8080 CHATBOT_DEFAULT_PROVIDER=wisdom -CHATBOT_DEFAULT_MODEL=granite3-8b +``` +```json +{ + "ModelPipelineChatBot": { + "config": { + "inference_url": "http://127.0.0.1:8080", + "model_id": "granite3-8b" + } + } +} ``` You also need to configure Red Hat SSO authentication on your local diff --git a/docs/config/examples/README-ANSIBLE_AI_MODEL_MESH_CONFIG.md b/docs/config/examples/README-ANSIBLE_AI_MODEL_MESH_CONFIG.md new file mode 100644 index 000000000..a3a8d4e39 --- /dev/null +++ b/docs/config/examples/README-ANSIBLE_AI_MODEL_MESH_CONFIG.md @@ -0,0 +1,87 @@ +# Example `ANSIBLE_AI_MODEL_MESH_CONFIG` configuration + +`ANSIBLE_AI_MODEL_MESH_CONFIG` can be defined with either JSON or YAML. 
+ +## JSON Configuration + +```json +{ + "ModelPipelineCompletions": { + "provider": "ollama", + "config": { + "inference_url": "http://host.containers.internal:11434", + "model_id": "mistral:instruct" + } + }, + "ModelPipelineContentMatch": { + "provider": "ollama", + "config": { + "inference_url": "http://host.containers.internal:11434", + "model_id": "mistral:instruct" + } + }, + "ModelPipelinePlaybookGeneration": { + "provider": "ollama", + "config": { + "inference_url": "http://host.containers.internal:11434", + "model_id": "mistral:instruct" + } + }, + "ModelPipelineRoleGeneration": { + "provider": "ollama", + "config": { + "inference_url": "http://host.containers.internal:11434", + "model_id": "mistral:instruct" + } + }, + "ModelPipelinePlaybookExplanation": { + "provider": "ollama", + "config": { + "inference_url": "http://host.containers.internal:11434", + "model_id": "mistral:instruct" + } + }, + "ModelPipelineChatBot": { + "provider": "http", + "config": { + "inference_url": "http://localhost:8000", + "model_id": "granite3-8b" + } + } +} +``` + +## YAML Configuration + +```yaml +MetaData: + provider: ollama + config: + inference_url: http://localhost + model_id: a-model-id +ModelPipelineCompletions: + provider: ollama + config: + inference_url: http://localhost + model_id: a-model-id +ModelPipelinePlaybookGeneration: + provider: ollama + config: + inference_url: http://localhost + model_id: a-model-id +ModelPipelineRoleGeneration: + provider: ollama + config: + inference_url: http://localhost + model_id: a-model-id +ModelPipelinePlaybookExplanation: + provider: ollama + config: + inference_url: http://localhost + model_id: a-model-id +ModelPipelineChatBot: + provider: http + config: + inference_url: http://localhost + model_id: granite3-8b +``` diff --git a/docs/config/examples/README-example-dummy.md b/docs/config/examples/README-example-dummy.md new file mode 100644 index 000000000..96fd580d7 --- /dev/null +++ 
b/docs/config/examples/README-example-dummy.md @@ -0,0 +1,58 @@ +# Example `ANSIBLE_AI_MODEL_MESH_CONFIG` configuration for `dummy` + +```json +{ + "ModelPipelineCompletions": { + "provider": "dummy", + "config": { + "inference_url": "http://localhost:8000", + "body": "{\"predictions\":[\"ansible.builtin.apt:\\n name: nginx\\n update_cache: true\\n state: present\\n\"]}", + "latency_max_msec": "3000", + "latency_use_jitter": "False" + } + }, + "ModelPipelineContentMatch": { + "provider": "dummy", + "config": { + "inference_url": "http://localhost:8000", + "body": "{\"predictions\":[\"ansible.builtin.apt:\\n name: nginx\\n update_cache: true\\n state: present\\n\"]}", + "latency_max_msec": "3000", + "latency_use_jitter": "False" + } + }, + "ModelPipelinePlaybookGeneration": { + "provider": "dummy", + "config": { + "inference_url": "http://localhost:8000", + "body": "{\"predictions\":[\"ansible.builtin.apt:\\n name: nginx\\n update_cache: true\\n state: present\\n\"]}", + "latency_max_msec": "3000", + "latency_use_jitter": "False" + } + }, + "ModelPipelineRoleGeneration": { + "provider": "dummy", + "config": { + "inference_url": "http://localhost:8000", + "body": "{\"predictions\":[\"ansible.builtin.apt:\\n name: nginx\\n update_cache: true\\n state: present\\n\"]}", + "latency_max_msec": "3000", + "latency_use_jitter": "False" + } + }, + "ModelPipelinePlaybookExplanation": { + "provider": "dummy", + "config": { + "inference_url": "http://localhost:8000", + "body": "{\"predictions\":[\"ansible.builtin.apt:\\n name: nginx\\n update_cache: true\\n state: present\\n\"]}", + "latency_max_msec": "3000", + "latency_use_jitter": "False" + } + }, + "ModelPipelineChatBot": { + "provider": "http", + "config": { + "inference_url": "", + "model_id": "" + } + } +} +``` diff --git a/docs/config/examples/README-example-hybrid.md b/docs/config/examples/README-example-hybrid.md new file mode 100644 index 000000000..71746190f --- /dev/null +++ 
b/docs/config/examples/README-example-hybrid.md @@ -0,0 +1,38 @@ +# Example _hybrid_ `ANSIBLE_AI_MODEL_MESH_CONFIG` configuration + +Separating the configuration for each pipeline allows different settings to be used for each; or disabled all together. + +For example the following configuration uses `ollama` for "Completions" however `wca-onprem` for "Playbook Generation". "ContentMatches", "Playbook Explanation" and "Role Generation" are not configured and would fall back to a "No Operation" implementation. The "Chat Bot" uses a plain `http` implementation to another service. + +```json +{ + "ModelPipelineCompletions": { + "provider": "ollama", + "config": { + "inference_url": "http://localhost:8000", + "model_id": "ollama-model" + } + }, + "ModelPipelinePlaybookGeneration": { + "provider": "wca-onprem", + "config": { + "inference_url": "", + "api_key": "", + "model_id": "", + "verify_ssl": "True", + "retry_count": "4", + "enable_ari_postprocessing": "False", + "health_check_api_key": "", + "health_check_model_id": "", + "username": "" + } + }, + "ModelPipelineChatBot": { + "provider": "http", + "config": { + "inference_url": "", + "model_id": "" + } + } +} +``` diff --git a/docs/config/examples/README-example-ollama.md b/docs/config/examples/README-example-ollama.md new file mode 100644 index 000000000..627a646b5 --- /dev/null +++ b/docs/config/examples/README-example-ollama.md @@ -0,0 +1,48 @@ +# Example `ANSIBLE_AI_MODEL_MESH_CONFIG` configuration for `ollama` + +```json +{ + "ModelPipelineCompletions": { + "provider": "ollama", + "config": { + "inference_url": "http://localhost:8000", + "model_id": "ollama-model" + } + }, + "ModelPipelineContentMatch": { + "provider": "ollama", + "config": { + "inference_url": "http://localhost:8000", + "model_id": "ollama-model" + } + }, + "ModelPipelinePlaybookGeneration": { + "provider": "ollama", + "config": { + "inference_url": "http://localhost:8000", + "model_id": "ollama-model" + } + }, + "ModelPipelineRoleGeneration": 
{ + "provider": "ollama", + "config": { + "inference_url": "http://localhost:8000", + "model_id": "ollama-model" + } + }, + "ModelPipelinePlaybookExplanation": { + "provider": "ollama", + "config": { + "inference_url": "http://localhost:8000", + "model_id": "ollama-model" + } + }, + "ModelPipelineChatBot": { + "provider": "http", + "config": { + "inference_url": "", + "model_id": "" + } + } +} +``` diff --git a/docs/config/examples/README-example-wca-dummy.md b/docs/config/examples/README-example-wca-dummy.md new file mode 100644 index 000000000..94aac3552 --- /dev/null +++ b/docs/config/examples/README-example-wca-dummy.md @@ -0,0 +1,43 @@ +# Example `ANSIBLE_AI_MODEL_MESH_CONFIG` configuration for `wca-dummy` + +```json +{ + "ModelPipelineCompletions": { + "provider": "wca-dummy", + "config": { + "inference_url": "http://localhost:8000" + } + }, + "ModelPipelineContentMatch": { + "provider": "wca-dummy", + "config": { + "inference_url": "http://localhost:8000" + } + }, + "ModelPipelinePlaybookGeneration": { + "provider": "wca-dummy", + "config": { + "inference_url": "http://localhost:8000" + } + }, + "ModelPipelineRoleGeneration": { + "provider": "wca-dummy", + "config": { + "inference_url": "http://localhost:8000" + } + }, + "ModelPipelinePlaybookExplanation": { + "provider": "wca-dummy", + "config": { + "inference_url": "http://localhost:8000" + } + }, + "ModelPipelineChatBot": { + "provider": "http", + "config": { + "inference_url": "", + "model_id": "" + } + } +} +``` diff --git a/docs/config/examples/README-example-wca-onprem.md b/docs/config/examples/README-example-wca-onprem.md new file mode 100644 index 000000000..951aa637e --- /dev/null +++ b/docs/config/examples/README-example-wca-onprem.md @@ -0,0 +1,84 @@ +# Example `ANSIBLE_AI_MODEL_MESH_CONFIG` configuration for `wca-onprem` + +```json +{ + "ModelPipelineCompletions": { + "provider": "wca-onprem", + "config": { + "inference_url": "", + "api_key": "", + "model_id": "", + "verify_ssl": "True", + 
"retry_count": "4", + "enable_ari_postprocessing": "False", + "health_check_api_key": "", + "health_check_model_id": "", + "username": "", + "enable_health_check": "True" + } + }, + "ModelPipelineContentMatch": { + "provider": "wca-onprem", + "config": { + "inference_url": "", + "api_key": "", + "model_id": "", + "verify_ssl": "True", + "retry_count": "4", + "enable_ari_postprocessing": "False", + "health_check_api_key": "", + "health_check_model_id": "", + "username": "" + } + }, + "ModelPipelinePlaybookGeneration": { + "provider": "wca-onprem", + "config": { + "inference_url": "", + "api_key": "", + "model_id": "", + "verify_ssl": "True", + "retry_count": "4", + "enable_ari_postprocessing": "False", + "health_check_api_key": "", + "health_check_model_id": "", + "username": "" + } + }, + "ModelPipelineRoleGeneration": { + "provider": "wca-onprem", + "config": { + "inference_url": "", + "api_key": "", + "model_id": "", + "verify_ssl": "True", + "retry_count": "4", + "enable_ari_postprocessing": "False", + "health_check_api_key": "", + "health_check_model_id": "", + "username": "" + } + }, + "ModelPipelinePlaybookExplanation": { + "provider": "wca-onprem", + "config": { + "inference_url": "", + "api_key": "", + "model_id": "", + "verify_ssl": "True", + "retry_count": "4", + "enable_ari_postprocessing": "False", + "health_check_api_key": "", + "health_check_model_id": "", + "username": "" + } + }, + "ModelPipelineChatBot": { + "provider": "http", + "config": { + "inference_url": "", + "model_id": "" + } + } +} +``` diff --git a/docs/config/examples/README-example-wca-saas.md b/docs/config/examples/README-example-wca-saas.md new file mode 100644 index 000000000..762ef8620 --- /dev/null +++ b/docs/config/examples/README-example-wca-saas.md @@ -0,0 +1,104 @@ +# Example `ANSIBLE_AI_MODEL_MESH_CONFIG` configuration for `wca` (SaaS) + +```json +{ + "ModelPipelineCompletions": { + "provider": "wca", + "config": { + "inference_url": 
"https://api.dataplatform.test.cloud.ibm.com/", + "api_key": "", + "model_id": "", + "verify_ssl": "True", + "retry_count": "4", + "enable_ari_postprocessing": "False", + "health_check_api_key": "", + "health_check_model_id": "", + "idp_url": "", + "idp_login": "", + "idp_password": "", + "one_click_default_api_key": "", + "one_click_default_model_id": "", + "enable_health_check": "True" + } + }, + "ModelPipelineContentMatch": { + "provider": "wca", + "config": { + "inference_url": "https://api.dataplatform.test.cloud.ibm.com/", + "api_key": "", + "model_id": "", + "verify_ssl": "True", + "retry_count": "4", + "enable_ari_postprocessing": "False", + "health_check_api_key": "", + "health_check_model_id": "", + "idp_url": "", + "idp_login": "", + "idp_password": "", + "one_click_default_api_key": "", + "one_click_default_model_id": "" + } + }, + "ModelPipelinePlaybookGeneration": { + "provider": "wca", + "config": { + "inference_url": "https://api.dataplatform.test.cloud.ibm.com/", + "api_key": "", + "model_id": "", + "verify_ssl": "True", + "retry_count": "4", + "enable_ari_postprocessing": "False", + "health_check_api_key": "", + "health_check_model_id": "", + "idp_url": "", + "idp_login": "", + "idp_password": "", + "one_click_default_api_key": "", + "one_click_default_model_id": "4" + } + }, + "ModelPipelineRoleGeneration": { + "provider": "wca", + "config": { + "inference_url": "https://api.dataplatform.test.cloud.ibm.com/", + "api_key": "", + "model_id": "", + "verify_ssl": "True", + "retry_count": "4", + "enable_ari_postprocessing": "False", + "health_check_api_key": "", + "health_check_model_id": "", + "idp_url": "", + "idp_login": "", + "idp_password": "", + "one_click_default_api_key": "", + "one_click_default_model_id": "" + } + }, + "ModelPipelinePlaybookExplanation": { + "provider": "wca", + "config": { + "inference_url": "https://api.dataplatform.test.cloud.ibm.com/", + "api_key": "", + "model_id": "", + "verify_ssl": "True", + "retry_count": "4", + 
"enable_ari_postprocessing": "False", + "health_check_api_key": "", + "health_check_model_id": "", + "idp_url": "", + "idp_login": "", + "idp_password": "", + "one_click_default_api_key": "", + "one_click_default_model_id": "" + } + }, + "ModelPipelineChatBot": { + "provider": "http", + "config": { + "inference_url": "", + "model_id": "" + } + } +} +``` diff --git a/docs/pycharm-setup.md b/docs/pycharm-setup.md index 64e81bf2f..7dce39160 100644 --- a/docs/pycharm-setup.md +++ b/docs/pycharm-setup.md @@ -194,16 +194,15 @@ ANSIBLE_AI_DATABASE_USER=wisdom DJANGO_SETTINGS_MODULE=main.settings.development PYTHONUNBUFFERED=1 SECRET_KEY=somesecret -ANSIBLE_AI_MODEL_MESH_API_TYPE=ollama -ANSIBLE_AI_MODEL_MESH_MODEL_ID="mistral:instruct" -ANSIBLE_AI_MODEL_MESH_API_URL=http://127.0.0.1:11434 ENABLE_ARI_POSTPROCESS=False DEPLOYMENT_MODE=upstream +ANSIBLE_AI_MODEL_MESH_CONFIG="..." ``` +See the example [ANSIBLE_AI_MODEL_MESH_CONFIG](./config/examples/README-ANSIBLE_AI_MODEL_MESH_CONFIG.md). > [!TIP] -> The example shown above uses local Ollama server with Mistral 7B Instruct model. -> For using a other type of model server that provides prediction results, +> The example referenced above uses local Ollama server with Mistral 7B Instruct model. +> For using a different type of model server that provides prediction results, > you need to set extra environment variables for a client type that is used > to connect to the model server. 
diff --git a/tools/docker-compose/compose.yaml b/tools/docker-compose/compose.yaml index a8b607a09..8b4c9348f 100644 --- a/tools/docker-compose/compose.yaml +++ b/tools/docker-compose/compose.yaml @@ -75,7 +75,7 @@ services: - ANSIBLE_AI_ENABLE_ONE_CLICK_DEFAULT_API_KEY=${ANSIBLE_AI_ENABLE_ONE_CLICK_DEFAULT_API_KEY} - CHATBOT_URL=${CHATBOT_URL} - CHATBOT_DEFAULT_PROVIDER=${CHATBOT_DEFAULT_PROVIDER} - - CHATBOT_DEFAULT_MODEL=${CHATBOT_DEFAULT_MODEL} + - ANSIBLE_AI_MODEL_MESH_CONFIG=${ANSIBLE_AI_MODEL_MESH_CONFIG} command: - /etc/wisdom/scripts/launch-wisdom.sh networks: