From aadf74ea75303fee8811651d547b6ff4ee18b38f Mon Sep 17 00:00:00 2001 From: Tami Takamiya Date: Thu, 30 Jan 2025 14:30:53 -0500 Subject: [PATCH] Add stream to HttpConfiguration --- .../ai/api/model_pipelines/http/configuration.py | 5 +++++ ansible_ai_connect/main/settings/base.py | 1 - ansible_ai_connect/main/settings/legacy.py | 1 + ansible_ai_connect/main/tests/test_views.py | 7 ++++++- ansible_ai_connect/main/views.py | 6 +++++- 5 files changed, 17 insertions(+), 3 deletions(-) diff --git a/ansible_ai_connect/ai/api/model_pipelines/http/configuration.py b/ansible_ai_connect/ai/api/model_pipelines/http/configuration.py index c9a361ed7..38bc2678b 100644 --- a/ansible_ai_connect/ai/api/model_pipelines/http/configuration.py +++ b/ansible_ai_connect/ai/api/model_pipelines/http/configuration.py @@ -41,11 +41,14 @@ def __init__( timeout: Optional[int], enable_health_check: Optional[bool], verify_ssl: bool, + stream: bool = False, ): super().__init__(inference_url, model_id, timeout, enable_health_check) self.verify_ssl = verify_ssl + self.stream = stream verify_ssl: bool + stream: bool @Register(api_type="http") @@ -60,6 +63,7 @@ def __init__(self, **kwargs): timeout=kwargs["timeout"], enable_health_check=kwargs["enable_health_check"], verify_ssl=kwargs["verify_ssl"], + stream=kwargs["stream"], ), ) @@ -67,3 +71,4 @@ def __init__(self, **kwargs): @Register(api_type="http") class HttpConfigurationSerializer(BaseConfigSerializer): verify_ssl = serializers.BooleanField(required=False, default=True) + stream = serializers.BooleanField(required=False, default=False) diff --git a/ansible_ai_connect/main/settings/base.py b/ansible_ai_connect/main/settings/base.py index b09858864..5d617a763 100644 --- a/ansible_ai_connect/main/settings/base.py +++ b/ansible_ai_connect/main/settings/base.py @@ -541,7 +541,6 @@ def is_ssl_enabled(value: str) -> bool: # ------------------------------------------ CHATBOT_DEFAULT_PROVIDER = os.getenv("CHATBOT_DEFAULT_PROVIDER") CHATBOT_DEBUG_UI = os.getenv("CHATBOT_DEBUG_UI", "False").lower() == "true" -CHATBOT_STREAM = os.getenv("CHATBOT_STREAM", "False").lower() == "true" # ========================================== # ========================================== diff --git a/ansible_ai_connect/main/settings/legacy.py b/ansible_ai_connect/main/settings/legacy.py index 605855acc..6543f43c8 100644 --- a/ansible_ai_connect/main/settings/legacy.py +++ b/ansible_ai_connect/main/settings/legacy.py @@ -189,6 +189,7 @@ def load_from_env_vars(): "inference_url": chatbot_service_url or "http://localhost:8000", "model_id": chatbot_service_model_id or "granite3-8b", "verify_ssl": model_service_verify_ssl, + "stream": False, }, } diff --git a/ansible_ai_connect/main/tests/test_views.py b/ansible_ai_connect/main/tests/test_views.py index f5a96c92e..059f7fc28 100644 --- a/ansible_ai_connect/main/tests/test_views.py +++ b/ansible_ai_connect/main/tests/test_views.py @@ -18,6 +18,7 @@ from http import HTTPStatus from textwrap import dedent +from django.apps import apps from django.contrib.auth import get_user_model from django.contrib.auth.models import AnonymousUser, Group from django.http import HttpResponseRedirect @@ -25,6 +26,7 @@ from django.urls import reverse from rest_framework.test import APITransactionTestCase +from ansible_ai_connect.ai.api.model_pipelines.pipelines import ModelPipelineChatBot from ansible_ai_connect.main.settings.base import SOCIAL_AUTH_OIDC_KEY from ansible_ai_connect.main.views import LoginView from ansible_ai_connect.test_utils import ( @@ -348,8 +350,11 @@ def test_chatbot_view_with_debug_ui(self): self.assertEqual(r.status_code, HTTPStatus.OK) self.assertContains(r, '') - @override_settings(CHATBOT_STREAM=True) def test_chatbot_view_with_streaming_enabled(self): + llm: ModelPipelineChatBot = apps.get_app_config("ai").get_model_pipeline( + ModelPipelineChatBot + ) + llm.config.stream = True self.client.force_login(user=self.rh_user) r = self.client.get(reverse("chatbot"), {"stream": "true"}) self.assertEqual(r.status_code, HTTPStatus.OK) diff --git a/ansible_ai_connect/main/views.py b/ansible_ai_connect/main/views.py index 5e81b8ad4..7632b75fd 100644 --- a/ansible_ai_connect/main/views.py +++ b/ansible_ai_connect/main/views.py @@ -139,7 +139,11 @@ def get_context_data(self, **kwargs): if user and user.is_authenticated: context["user_name"] = user.username context["debug"] = "true" if settings.CHATBOT_DEBUG_UI else "false" - context["stream"] = "true" if settings.CHATBOT_STREAM else "false" + + llm: ModelPipelineChatBot = apps.get_app_config("ai").get_model_pipeline( + ModelPipelineChatBot + ) + context["stream"] = "true" if llm.config.stream else "false" return context