From 20de6ca084428a27e0b061157b2d56b4f77884c9 Mon Sep 17 00:00:00 2001 From: Tami Takamiya Date: Sat, 16 Nov 2024 19:06:43 -0500 Subject: [PATCH] Chatbot debug mode (#1401) * Chatbot debug mode * Enable chatbot debug mode through a setting * Fix unit test and logic to check CHATBOT_DEBUG_UI setting --- ansible_ai_connect/ai/api/tests/test_views.py | 18 +++++++- ansible_ai_connect/ai/api/views.py | 8 +++- ansible_ai_connect/main/settings/base.py | 1 + .../main/templates/chatbot/index.html | 1 + ansible_ai_connect/main/tests/test_views.py | 9 ++++ ansible_ai_connect/main/views.py | 1 + ansible_ai_connect_chatbot/index.html | 1 + .../src/AnsibleChatbot/AnsibleChatbot.tsx | 40 ++++++++++++++++-- ansible_ai_connect_chatbot/src/App.test.tsx | 42 ++++++++++++++++++- ansible_ai_connect_chatbot/src/types/Model.ts | 4 ++ .../src/useChatbot/useChatbot.ts | 31 +++++++++++++- 11 files changed, 146 insertions(+), 10 deletions(-) create mode 100644 ansible_ai_connect_chatbot/src/types/Model.ts diff --git a/ansible_ai_connect/ai/api/tests/test_views.py b/ansible_ai_connect/ai/api/tests/test_views.py index db2c51d5b..eac79b57c 100644 --- a/ansible_ai_connect/ai/api/tests/test_views.py +++ b/ansible_ai_connect/ai/api/tests/test_views.py @@ -3819,6 +3819,12 @@ class TestChatView(WisdomServiceAPITestCaseBase): "query": "Return the internal server error status code", } + PAYLOAD_WITH_MODEL_AND_PROVIDER = { + "query": "Payload with a non-default model and a non-default provider", + "model": "non_default_model", + "provider": "non_default_provider", + } + JSON_RESPONSE = { "response": "AAP 2.5 introduces an updated, unified UI.", "conversation_id": "123e4567-e89b-12d3-a456-426614174000", @@ -3838,7 +3844,7 @@ def json(self): return self.json_data # Make sure that the given json data is serializable - json.dumps(kwargs["json"]) + input = json.dumps(kwargs["json"]) json_response = { "response": "AAP 2.5 introduces an updated, unified UI.", @@ -3880,7 +3886,9 @@ def json(self): json_response = { "detail": "Internal server error", } - + elif kwargs["json"]["query"] == TestChatView.PAYLOAD_WITH_MODEL_AND_PROVIDER["query"]: + status_code = 200 + json_response["response"] = input return MockResponse(json_response, status_code) @override_settings(CHATBOT_URL="http://localhost:8080") @@ -3926,6 +3934,7 @@ def assert_test( r, expected_exception().default_code, expected_exception().default_detail ) self.assertInLog(expected_log_message, log) + return r def test_chat(self): self.assert_test(TestChatView.VALID_PAYLOAD) @@ -3993,3 +4002,8 @@ def test_chat_internal_server_exception(self): ChatbotInternalServerException, "ChatbotInternalServerException", ) + + def test_chat_with_model_and_provider(self): + r = self.assert_test(TestChatView.PAYLOAD_WITH_MODEL_AND_PROVIDER) + self.assertIn('"model": "non_default_model"', r.data["response"]) + self.assertIn('"provider": "non_default_provider"', r.data["response"]) diff --git a/ansible_ai_connect/ai/api/views.py b/ansible_ai_connect/ai/api/views.py index 0d7652572..cf8456774 100644 --- a/ansible_ai_connect/ai/api/views.py +++ b/ansible_ai_connect/ai/api/views.py @@ -1130,8 +1130,12 @@ def post(self, request) -> Response: data = { "query": request_serializer.validated_data["query"], - "model": settings.CHATBOT_DEFAULT_MODEL, - "provider": settings.CHATBOT_DEFAULT_PROVIDER, + "model": request_serializer.validated_data.get( + "model", settings.CHATBOT_DEFAULT_MODEL + ), + "provider": request_serializer.validated_data.get( + "provider", settings.CHATBOT_DEFAULT_PROVIDER + ), } if "conversation_id" in request_serializer.validated_data: data["conversation_id"] = str(request_serializer.validated_data["conversation_id"]) diff --git a/ansible_ai_connect/main/settings/base.py b/ansible_ai_connect/main/settings/base.py index e8903fbd6..07a00cfed 100644 --- a/ansible_ai_connect/main/settings/base.py +++ b/ansible_ai_connect/main/settings/base.py @@ -624,6 +624,7 @@ def is_ssl_enabled(value: str) -> bool: CHATBOT_URL = os.getenv("CHATBOT_URL") CHATBOT_DEFAULT_PROVIDER = os.getenv("CHATBOT_DEFAULT_PROVIDER") CHATBOT_DEFAULT_MODEL = os.getenv("CHATBOT_DEFAULT_MODEL") +CHATBOT_DEBUG_UI = os.getenv("CHATBOT_DEBUG_UI", "False").lower() == "true" # ========================================== # ========================================== diff --git a/ansible_ai_connect/main/templates/chatbot/index.html b/ansible_ai_connect/main/templates/chatbot/index.html index 1df3341ba..6aadc4c8b 100644 --- a/ansible_ai_connect/main/templates/chatbot/index.html +++ b/ansible_ai_connect/main/templates/chatbot/index.html @@ -20,5 +20,6 @@ + {% endblock content %} diff --git a/ansible_ai_connect/main/tests/test_views.py b/ansible_ai_connect/main/tests/test_views.py index 1459a9eaa..022d5f301 100644 --- a/ansible_ai_connect/main/tests/test_views.py +++ b/ansible_ai_connect/main/tests/test_views.py @@ -233,6 +233,7 @@ def test_get_view_expired_trial(self): @override_settings(CHATBOT_DEFAULT_PROVIDER="wisdom") @override_settings(CHATBOT_DEFAULT_MODEL="granite-8b") @override_settings(ANSIBLE_AI_CHATBOT_NAME="Awesome Chatbot") +@override_settings(CHATBOT_DEBUG_UI=False) class TestChatbotView(TestCase): CHATBOT_PAGE_TITLE = "Awesome Chatbot" DOCUMENT_URL = ( @@ -332,3 +333,11 @@ def test_chatbot_view_with_rh_user(self): self.assertEqual(r.status_code, HTTPStatus.OK) self.assertContains(r, TestChatbotView.CHATBOT_PAGE_TITLE) self.assertContains(r, self.rh_user.username) + self.assertContains(r, '') + + @override_settings(CHATBOT_DEBUG_UI=True) + def test_chatbot_view_with_debug_ui(self): + self.client.force_login(user=self.rh_user) + r = self.client.get(reverse("chatbot"), {"debug": "true"}) + self.assertEqual(r.status_code, HTTPStatus.OK) + self.assertContains(r, '') diff --git a/ansible_ai_connect/main/views.py b/ansible_ai_connect/main/views.py index dd48d0d0e..77333403c 100644 --- a/ansible_ai_connect/main/views.py +++ b/ansible_ai_connect/main/views.py @@ -136,6 +136,7 @@ def get_context_data(self, **kwargs): user = self.request.user if user and user.is_authenticated: context["user_name"] = user.username + context["debug"] = "true" if settings.CHATBOT_DEBUG_UI else "false" return context diff --git a/ansible_ai_connect_chatbot/index.html b/ansible_ai_connect_chatbot/index.html index 63117ff28..4e56da4f6 100644 --- a/ansible_ai_connect_chatbot/index.html +++ b/ansible_ai_connect_chatbot/index.html @@ -23,5 +23,6 @@ + diff --git a/ansible_ai_connect_chatbot/src/AnsibleChatbot/AnsibleChatbot.tsx b/ansible_ai_connect_chatbot/src/AnsibleChatbot/AnsibleChatbot.tsx index 2af447040..573a9eeb6 100644 --- a/ansible_ai_connect_chatbot/src/AnsibleChatbot/AnsibleChatbot.tsx +++ b/ansible_ai_connect_chatbot/src/AnsibleChatbot/AnsibleChatbot.tsx @@ -33,12 +33,18 @@ import lightspeedLogo from "../assets/lightspeed.svg"; import lightspeedLogoDark from "../assets/lightspeed_dark.svg"; import "./AnsibleChatbot.scss"; -import { botMessage, useChatbot } from "../useChatbot/useChatbot"; +import { + botMessage, + inDebugMode, + modelsSupported, + useChatbot, +} from "../useChatbot/useChatbot"; import { ReferencedDocuments } from "../ReferencedDocuments/ReferencedDocuments"; import type { ExtendedMessage } from "../types/Message"; import { ChatbotAlert, + ChatbotHeaderSelectorDropdown, ChatbotToggle, FileDropZone, } from "@patternfly/virtual-assistant"; @@ -66,8 +72,15 @@ const footnoteProps = { }; export const AnsibleChatbot: React.FunctionComponent = () => { - const { messages, isLoading, handleSend, alertMessage, setAlertMessage } = - useChatbot(); + const { + messages, + isLoading, + handleSend, + alertMessage, + setAlertMessage, + selectedModel, + setSelectedModel, + } = useChatbot(); const [chatbotVisible, setChatbotVisible] = useState(true); const [displayMode, setDisplayMode] = useState( ChatbotDisplayMode.default, @@ -82,6 +95,13 @@ export const AnsibleChatbot: React.FunctionComponent = () => { scrollToBottom(); }, [messages]); + const onSelectModel = ( + _event: React.MouseEvent | undefined, + value: string | number | undefined, + ) => { + setSelectedModel(value as string); + }; + const onSelectDisplayMode = ( _event: React.MouseEvent | undefined, value: string | number | undefined, @@ -114,6 +134,20 @@ export const AnsibleChatbot: React.FunctionComponent = () => { + {inDebugMode() && ( + + + {modelsSupported.map((m) => ( + + {m.model} + + ))} + + + )} diff --git a/ansible_ai_connect_chatbot/src/App.test.tsx b/ansible_ai_connect_chatbot/src/App.test.tsx index 578dc89f3..53d90b95a 100644 --- a/ansible_ai_connect_chatbot/src/App.test.tsx +++ b/ansible_ai_connect_chatbot/src/App.test.tsx @@ -8,15 +8,25 @@ import userEvent from "@testing-library/user-event"; import axios from "axios"; describe("App tests", () => { - const renderApp = () => - render( + const renderApp = (debug = false) => { + const debugDiv = document.createElement("div"); + debugDiv.setAttribute("id", "debug"); + debugDiv.innerText = debug.toString(); + document.body.appendChild(debugDiv); + const rootDiv = document.createElement("div"); + rootDiv.setAttribute("id", "root"); + return render(
, + { + container: document.body.appendChild(rootDiv), + }, ); + }; const mockAxios = (status: number, reject = false) => { const spy = vi.spyOn(axios, "post"); if (reject) { @@ -46,6 +56,10 @@ describe("App tests", () => { afterEach(() => { vi.restoreAllMocks(); + const rootDiv = document.getElementById("root"); + rootDiv?.remove(); + const debugDiv = document.getElementById("debug"); + debugDiv?.remove(); }); it("App renders", () => { @@ -119,4 +133,28 @@ describe("App tests", () => { // expect(getComputedStyle(showDark!).display).toEqual("block") } }); + + it("Debug mode test", async () => { + mockAxios(200); + + renderApp(true); + const modelSelection = screen.getByText("granite-8b"); + await act(async () => fireEvent.click(modelSelection)); + expect(screen.getByRole("menuitem", { name: "granite-8b" })).toBeTruthy(); + expect(screen.getByRole("menuitem", { name: "granite3-8b" })).toBeTruthy(); + await act(async () => + screen.getByRole("menuitem", { name: "granite3-8b" }).click(), + ); + + const textArea = screen.getByLabelText("Send a message..."); + await act(async () => userEvent.type(textArea, "Hello")); + const sendButton = screen.getByLabelText("Send button"); + await act(async () => fireEvent.click(sendButton)); + expect( + screen.getByText( + "In Ansible, the precedence of variables is determined by the order...", + ), + ).toBeInTheDocument(); + expect(screen.getByText("Create variables")).toBeInTheDocument(); + }); }); diff --git a/ansible_ai_connect_chatbot/src/types/Model.ts b/ansible_ai_connect_chatbot/src/types/Model.ts new file mode 100644 index 000000000..7d3cdcc71 --- /dev/null +++ b/ansible_ai_connect_chatbot/src/types/Model.ts @@ -0,0 +1,4 @@ +export type LLMModel = { + model: string; + provider: string; +}; diff --git a/ansible_ai_connect_chatbot/src/useChatbot/useChatbot.ts b/ansible_ai_connect_chatbot/src/useChatbot/useChatbot.ts index ca31ddd54..7297600d2 100644 --- a/ansible_ai_connect_chatbot/src/useChatbot/useChatbot.ts +++ b/ansible_ai_connect_chatbot/src/useChatbot/useChatbot.ts @@ -6,6 +6,7 @@ import type { ChatRequest, ChatResponse, } from "../types/Message"; +import type { LLMModel } from "../types/Model"; import logo from "../assets/lightspeed.svg"; import userLogo from "../assets/user_logo.png"; @@ -13,6 +14,11 @@ const userName = document.getElementById("user_name")?.innerText ?? "User"; const botName = document.getElementById("bot_name")?.innerText ?? "Ansible Lightspeed"; +export const modelsSupported: LLMModel[] = [ + { model: "granite-8b", provider: "my_rhoai" }, + { model: "granite3-8b", provider: "my_rhoai_g3" }, +]; + export const readCookie = (name: string): string | null => { const nameEQ = name + "="; const ca = document.cookie.split(";"); @@ -30,6 +36,11 @@ const getTimestamp = () => { return `${date.toLocaleDateString()} ${date.toLocaleTimeString()}`; }; +export const inDebugMode = () => { + const debug = document.getElementById("debug")?.innerText ?? "false"; + return debug === "true"; +}; + export const botMessage = (content: string): MessageProps => ({ role: "bot", content, @@ -66,6 +77,7 @@ export const useChatbot = () => { const [conversationId, setConversationId] = useState< string | null | undefined >(undefined); + const [selectedModel, setSelectedModel] = useState("granite-8b"); const handleSend = async (message: string) => { const userMessage: ExtendedMessage = { @@ -83,6 +95,15 @@ export const useChatbot = () => { query: message, }; + if (inDebugMode()) { + for (const m of modelsSupported) { + if (selectedModel === m.model) { + chatRequest.model = m.model; + chatRequest.provider = m.provider; + } + } + } + setIsLoading(true); try { const csrfToken = readCookie("csrftoken"); @@ -129,5 +150,13 @@ export const useChatbot = () => { } }; - return { messages, isLoading, handleSend, alertMessage, setAlertMessage }; + return { + messages, + isLoading, + handleSend, + alertMessage, + setAlertMessage, + selectedModel, + setSelectedModel, + }; };