diff --git a/ansible_ai_connect/ai/api/tests/test_views.py b/ansible_ai_connect/ai/api/tests/test_views.py index db2c51d5b..eac79b57c 100644 --- a/ansible_ai_connect/ai/api/tests/test_views.py +++ b/ansible_ai_connect/ai/api/tests/test_views.py @@ -3819,6 +3819,12 @@ class TestChatView(WisdomServiceAPITestCaseBase): "query": "Return the internal server error status code", } + PAYLOAD_WITH_MODEL_AND_PROVIDER = { + "query": "Payload with a non-default model and a non-default provider", + "model": "non_default_model", + "provider": "non_default_provider", + } + JSON_RESPONSE = { "response": "AAP 2.5 introduces an updated, unified UI.", "conversation_id": "123e4567-e89b-12d3-a456-426614174000", @@ -3838,7 +3844,7 @@ def json(self): return self.json_data # Make sure that the given json data is serializable - json.dumps(kwargs["json"]) + input = json.dumps(kwargs["json"]) json_response = { "response": "AAP 2.5 introduces an updated, unified UI.", @@ -3880,7 +3886,9 @@ def json(self): json_response = { "detail": "Internal server error", } - + elif kwargs["json"]["query"] == TestChatView.PAYLOAD_WITH_MODEL_AND_PROVIDER["query"]: + status_code = 200 + json_response["response"] = input return MockResponse(json_response, status_code) @override_settings(CHATBOT_URL="http://localhost:8080") @@ -3926,6 +3934,7 @@ def assert_test( r, expected_exception().default_code, expected_exception().default_detail ) self.assertInLog(expected_log_message, log) + return r def test_chat(self): self.assert_test(TestChatView.VALID_PAYLOAD) @@ -3993,3 +4002,8 @@ def test_chat_internal_server_exception(self): ChatbotInternalServerException, "ChatbotInternalServerException", ) + + def test_chat_with_model_and_provider(self): + r = self.assert_test(TestChatView.PAYLOAD_WITH_MODEL_AND_PROVIDER) + self.assertIn('"model": "non_default_model"', r.data["response"]) + self.assertIn('"provider": "non_default_provider"', r.data["response"]) diff --git a/ansible_ai_connect/ai/api/views.py b/ansible_ai_connect/ai/api/views.py index 0d7652572..cf8456774 100644 --- a/ansible_ai_connect/ai/api/views.py +++ b/ansible_ai_connect/ai/api/views.py @@ -1130,8 +1130,12 @@ def post(self, request) -> Response: data = { "query": request_serializer.validated_data["query"], - "model": settings.CHATBOT_DEFAULT_MODEL, - "provider": settings.CHATBOT_DEFAULT_PROVIDER, + "model": request_serializer.validated_data.get( + "model", settings.CHATBOT_DEFAULT_MODEL + ), + "provider": request_serializer.validated_data.get( + "provider", settings.CHATBOT_DEFAULT_PROVIDER + ), } if "conversation_id" in request_serializer.validated_data: data["conversation_id"] = str(request_serializer.validated_data["conversation_id"]) diff --git a/ansible_ai_connect/main/settings/base.py b/ansible_ai_connect/main/settings/base.py index e8903fbd6..07a00cfed 100644 --- a/ansible_ai_connect/main/settings/base.py +++ b/ansible_ai_connect/main/settings/base.py @@ -624,6 +624,7 @@ def is_ssl_enabled(value: str) -> bool: CHATBOT_URL = os.getenv("CHATBOT_URL") CHATBOT_DEFAULT_PROVIDER = os.getenv("CHATBOT_DEFAULT_PROVIDER") CHATBOT_DEFAULT_MODEL = os.getenv("CHATBOT_DEFAULT_MODEL") +CHATBOT_DEBUG_UI = os.getenv("CHATBOT_DEBUG_UI", "False").lower() == "true" # ========================================== # ========================================== diff --git a/ansible_ai_connect/main/templates/chatbot/index.html b/ansible_ai_connect/main/templates/chatbot/index.html index 1df3341ba..6aadc4c8b 100644 --- a/ansible_ai_connect/main/templates/chatbot/index.html +++ b/ansible_ai_connect/main/templates/chatbot/index.html @@ -20,5 +20,6 @@