Skip to content

Commit

Permalink
Chatbot debug mode (#1401)
Browse files Browse the repository at this point in the history
* Chatbot debug mode

* Enable chatbot debug mode through a setting

* Fix unit test and logic to check CHATBOT_DEBUG_UI setting
  • Loading branch information
TamiTakamiya authored Nov 17, 2024
1 parent 06f0c74 commit 20de6ca
Show file tree
Hide file tree
Showing 11 changed files with 146 additions and 10 deletions.
18 changes: 16 additions & 2 deletions ansible_ai_connect/ai/api/tests/test_views.py
Original file line number Diff line number Diff line change
Expand Up @@ -3819,6 +3819,12 @@ class TestChatView(WisdomServiceAPITestCaseBase):
"query": "Return the internal server error status code",
}

PAYLOAD_WITH_MODEL_AND_PROVIDER = {
"query": "Payload with a non-default model and a non-default provider",
"model": "non_default_model",
"provider": "non_default_provider",
}

JSON_RESPONSE = {
"response": "AAP 2.5 introduces an updated, unified UI.",
"conversation_id": "123e4567-e89b-12d3-a456-426614174000",
Expand All @@ -3838,7 +3844,7 @@ def json(self):
return self.json_data

# Make sure that the given json data is serializable
json.dumps(kwargs["json"])
input = json.dumps(kwargs["json"])

json_response = {
"response": "AAP 2.5 introduces an updated, unified UI.",
Expand Down Expand Up @@ -3880,7 +3886,9 @@ def json(self):
json_response = {
"detail": "Internal server error",
}

elif kwargs["json"]["query"] == TestChatView.PAYLOAD_WITH_MODEL_AND_PROVIDER["query"]:
status_code = 200
json_response["response"] = input
return MockResponse(json_response, status_code)

@override_settings(CHATBOT_URL="http://localhost:8080")
Expand Down Expand Up @@ -3926,6 +3934,7 @@ def assert_test(
r, expected_exception().default_code, expected_exception().default_detail
)
self.assertInLog(expected_log_message, log)
return r

def test_chat(self):
self.assert_test(TestChatView.VALID_PAYLOAD)
Expand Down Expand Up @@ -3993,3 +4002,8 @@ def test_chat_internal_server_exception(self):
ChatbotInternalServerException,
"ChatbotInternalServerException",
)

def test_chat_with_model_and_provider(self):
r = self.assert_test(TestChatView.PAYLOAD_WITH_MODEL_AND_PROVIDER)
self.assertIn('"model": "non_default_model"', r.data["response"])
self.assertIn('"provider": "non_default_provider"', r.data["response"])
8 changes: 6 additions & 2 deletions ansible_ai_connect/ai/api/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -1130,8 +1130,12 @@ def post(self, request) -> Response:

data = {
"query": request_serializer.validated_data["query"],
"model": settings.CHATBOT_DEFAULT_MODEL,
"provider": settings.CHATBOT_DEFAULT_PROVIDER,
"model": request_serializer.validated_data.get(
"model", settings.CHATBOT_DEFAULT_MODEL
),
"provider": request_serializer.validated_data.get(
"provider", settings.CHATBOT_DEFAULT_PROVIDER
),
}
if "conversation_id" in request_serializer.validated_data:
data["conversation_id"] = str(request_serializer.validated_data["conversation_id"])
Expand Down
1 change: 1 addition & 0 deletions ansible_ai_connect/main/settings/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -624,6 +624,7 @@ def is_ssl_enabled(value: str) -> bool:
CHATBOT_URL = os.getenv("CHATBOT_URL")
CHATBOT_DEFAULT_PROVIDER = os.getenv("CHATBOT_DEFAULT_PROVIDER")
CHATBOT_DEFAULT_MODEL = os.getenv("CHATBOT_DEFAULT_MODEL")
CHATBOT_DEBUG_UI = os.getenv("CHATBOT_DEBUG_UI", "False").lower() == "true"
# ==========================================

# ==========================================
Expand Down
1 change: 1 addition & 0 deletions ansible_ai_connect/main/templates/chatbot/index.html
Original file line number Diff line number Diff line change
Expand Up @@ -20,5 +20,6 @@
<!-- as it is only used for display purposes and servers no other value -->
<div id="user_name" hidden>{{user_name}}</div>
<div id="bot_name" hidden>{{bot_name}}</div>
<div id="debug" hidden>{{debug}}</div>
{% endblock content %}
</html>
9 changes: 9 additions & 0 deletions ansible_ai_connect/main/tests/test_views.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,7 @@ def test_get_view_expired_trial(self):
@override_settings(CHATBOT_DEFAULT_PROVIDER="wisdom")
@override_settings(CHATBOT_DEFAULT_MODEL="granite-8b")
@override_settings(ANSIBLE_AI_CHATBOT_NAME="Awesome Chatbot")
@override_settings(CHATBOT_DEBUG_UI=False)
class TestChatbotView(TestCase):
CHATBOT_PAGE_TITLE = "<title>Awesome Chatbot</title>"
DOCUMENT_URL = (
Expand Down Expand Up @@ -332,3 +333,11 @@ def test_chatbot_view_with_rh_user(self):
self.assertEqual(r.status_code, HTTPStatus.OK)
self.assertContains(r, TestChatbotView.CHATBOT_PAGE_TITLE)
self.assertContains(r, self.rh_user.username)
self.assertContains(r, '<div id="debug" hidden>false</div>')

@override_settings(CHATBOT_DEBUG_UI=True)
def test_chatbot_view_with_debug_ui(self):
self.client.force_login(user=self.rh_user)
r = self.client.get(reverse("chatbot"), {"debug": "true"})
self.assertEqual(r.status_code, HTTPStatus.OK)
self.assertContains(r, '<div id="debug" hidden>true</div>')
1 change: 1 addition & 0 deletions ansible_ai_connect/main/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,7 @@ def get_context_data(self, **kwargs):
user = self.request.user
if user and user.is_authenticated:
context["user_name"] = user.username
context["debug"] = "true" if settings.CHATBOT_DEBUG_UI else "false"

return context

Expand Down
1 change: 1 addition & 0 deletions ansible_ai_connect_chatbot/index.html
Original file line number Diff line number Diff line change
Expand Up @@ -23,5 +23,6 @@
</div>
<div id="user_name" hidden>{{user_name}}</div>
<div id="bot_name" hidden>{{bot_name}}</div>
<div id="debug" hidden>{{debug}}</div>
</body>
</html>
40 changes: 37 additions & 3 deletions ansible_ai_connect_chatbot/src/AnsibleChatbot/AnsibleChatbot.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -33,12 +33,18 @@ import lightspeedLogo from "../assets/lightspeed.svg";
import lightspeedLogoDark from "../assets/lightspeed_dark.svg";

import "./AnsibleChatbot.scss";
import { botMessage, useChatbot } from "../useChatbot/useChatbot";
import {
botMessage,
inDebugMode,
modelsSupported,
useChatbot,
} from "../useChatbot/useChatbot";
import { ReferencedDocuments } from "../ReferencedDocuments/ReferencedDocuments";

import type { ExtendedMessage } from "../types/Message";
import {
ChatbotAlert,
ChatbotHeaderSelectorDropdown,
ChatbotToggle,
FileDropZone,
} from "@patternfly/virtual-assistant";
Expand Down Expand Up @@ -66,8 +72,15 @@ const footnoteProps = {
};

export const AnsibleChatbot: React.FunctionComponent = () => {
const { messages, isLoading, handleSend, alertMessage, setAlertMessage } =
useChatbot();
const {
messages,
isLoading,
handleSend,
alertMessage,
setAlertMessage,
selectedModel,
setSelectedModel,
} = useChatbot();
const [chatbotVisible, setChatbotVisible] = useState<boolean>(true);
const [displayMode, setDisplayMode] = useState<ChatbotDisplayMode>(
ChatbotDisplayMode.default,
Expand All @@ -82,6 +95,13 @@ export const AnsibleChatbot: React.FunctionComponent = () => {
scrollToBottom();
}, [messages]);

const onSelectModel = (
_event: React.MouseEvent<Element, MouseEvent> | undefined,
value: string | number | undefined,
) => {
setSelectedModel(value as string);
};

const onSelectDisplayMode = (
_event: React.MouseEvent<Element, MouseEvent> | undefined,
value: string | number | undefined,
Expand Down Expand Up @@ -114,6 +134,20 @@ export const AnsibleChatbot: React.FunctionComponent = () => {
</Bullseye>
</ChatbotHeaderTitle>
<ChatbotHeaderActions>
{inDebugMode() && (
<ChatbotHeaderSelectorDropdown
value={selectedModel}
onSelect={onSelectModel}
>
<DropdownList>
{modelsSupported.map((m) => (
<DropdownItem value={m.model} key={m.model}>
{m.model}
</DropdownItem>
))}
</DropdownList>
</ChatbotHeaderSelectorDropdown>
)}
<ChatbotHeaderOptionsDropdown onSelect={onSelectDisplayMode}>
<DropdownGroup label="Display mode">
<DropdownList>
Expand Down
42 changes: 40 additions & 2 deletions ansible_ai_connect_chatbot/src/App.test.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,25 @@ import userEvent from "@testing-library/user-event";
import axios from "axios";

describe("App tests", () => {
const renderApp = () =>
render(
const renderApp = (debug = false) => {
const debugDiv = document.createElement("div");
debugDiv.setAttribute("id", "debug");
debugDiv.innerText = debug.toString();
document.body.appendChild(debugDiv);
const rootDiv = document.createElement("div");
rootDiv.setAttribute("id", "root");
return render(
<MemoryRouter>
<div className="pf-v6-l-flex pf-m-column pf-m-gap-lg ws-full-page-utils pf-v6-m-dir-ltr ">
<ColorThemeSwitch />
</div>
<App />
</MemoryRouter>,
{
container: document.body.appendChild(rootDiv),
},
);
};
const mockAxios = (status: number, reject = false) => {
const spy = vi.spyOn(axios, "post");
if (reject) {
Expand Down Expand Up @@ -46,6 +56,10 @@ describe("App tests", () => {

afterEach(() => {
vi.restoreAllMocks();
const rootDiv = document.getElementById("root");
rootDiv?.remove();
const debugDiv = document.getElementById("debug");
debugDiv?.remove();
});

it("App renders", () => {
Expand Down Expand Up @@ -119,4 +133,28 @@ describe("App tests", () => {
// expect(getComputedStyle(showDark!).display).toEqual("block")
}
});

it("Debug mode test", async () => {
mockAxios(200);

renderApp(true);
const modelSelection = screen.getByText("granite-8b");
await act(async () => fireEvent.click(modelSelection));
expect(screen.getByRole("menuitem", { name: "granite-8b" })).toBeTruthy();
expect(screen.getByRole("menuitem", { name: "granite3-8b" })).toBeTruthy();
await act(async () =>
screen.getByRole("menuitem", { name: "granite3-8b" }).click(),
);

const textArea = screen.getByLabelText("Send a message...");
await act(async () => userEvent.type(textArea, "Hello"));
const sendButton = screen.getByLabelText("Send button");
await act(async () => fireEvent.click(sendButton));
expect(
screen.getByText(
"In Ansible, the precedence of variables is determined by the order...",
),
).toBeInTheDocument();
expect(screen.getByText("Create variables")).toBeInTheDocument();
});
});
4 changes: 4 additions & 0 deletions ansible_ai_connect_chatbot/src/types/Model.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
export type LLMModel = {
model: string;
provider: string;
};
31 changes: 30 additions & 1 deletion ansible_ai_connect_chatbot/src/useChatbot/useChatbot.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,19 @@ import type {
ChatRequest,
ChatResponse,
} from "../types/Message";
import type { LLMModel } from "../types/Model";
import logo from "../assets/lightspeed.svg";
import userLogo from "../assets/user_logo.png";

const userName = document.getElementById("user_name")?.innerText ?? "User";
const botName =
document.getElementById("bot_name")?.innerText ?? "Ansible Lightspeed";

export const modelsSupported: LLMModel[] = [
{ model: "granite-8b", provider: "my_rhoai" },
{ model: "granite3-8b", provider: "my_rhoai_g3" },
];

export const readCookie = (name: string): string | null => {
const nameEQ = name + "=";
const ca = document.cookie.split(";");
Expand All @@ -30,6 +36,11 @@ const getTimestamp = () => {
return `${date.toLocaleDateString()} ${date.toLocaleTimeString()}`;
};

export const inDebugMode = () => {
const debug = document.getElementById("debug")?.innerText ?? "false";
return debug === "true";
};

export const botMessage = (content: string): MessageProps => ({
role: "bot",
content,
Expand Down Expand Up @@ -66,6 +77,7 @@ export const useChatbot = () => {
const [conversationId, setConversationId] = useState<
string | null | undefined
>(undefined);
const [selectedModel, setSelectedModel] = useState("granite-8b");

const handleSend = async (message: string) => {
const userMessage: ExtendedMessage = {
Expand All @@ -83,6 +95,15 @@ export const useChatbot = () => {
query: message,
};

if (inDebugMode()) {
for (const m of modelsSupported) {
if (selectedModel === m.model) {
chatRequest.model = m.model;
chatRequest.provider = m.provider;
}
}
}

setIsLoading(true);
try {
const csrfToken = readCookie("csrftoken");
Expand Down Expand Up @@ -129,5 +150,13 @@ export const useChatbot = () => {
}
};

return { messages, isLoading, handleSend, alertMessage, setAlertMessage };
return {
messages,
isLoading,
handleSend,
alertMessage,
setAlertMessage,
selectedModel,
setSelectedModel,
};
};

0 comments on commit 20de6ca

Please sign in to comment.