Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Chatbot debug mode #1401

Merged
merged 3 commits into from
Nov 17, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 16 additions & 2 deletions ansible_ai_connect/ai/api/tests/test_views.py
Original file line number Diff line number Diff line change
Expand Up @@ -3819,6 +3819,12 @@ class TestChatView(WisdomServiceAPITestCaseBase):
"query": "Return the internal server error status code",
}

PAYLOAD_WITH_MODEL_AND_PROVIDER = {
"query": "Payload with a non-default model and a non-default provider",
"model": "non_default_model",
"provider": "non_default_provider",
}

JSON_RESPONSE = {
"response": "AAP 2.5 introduces an updated, unified UI.",
"conversation_id": "123e4567-e89b-12d3-a456-426614174000",
Expand All @@ -3838,7 +3844,7 @@ def json(self):
return self.json_data

# Make sure that the given json data is serializable
json.dumps(kwargs["json"])
input = json.dumps(kwargs["json"])

json_response = {
"response": "AAP 2.5 introduces an updated, unified UI.",
Expand Down Expand Up @@ -3880,7 +3886,9 @@ def json(self):
json_response = {
"detail": "Internal server error",
}

elif kwargs["json"]["query"] == TestChatView.PAYLOAD_WITH_MODEL_AND_PROVIDER["query"]:
status_code = 200
json_response["response"] = input
return MockResponse(json_response, status_code)

@override_settings(CHATBOT_URL="http://localhost:8080")
Expand Down Expand Up @@ -3926,6 +3934,7 @@ def assert_test(
r, expected_exception().default_code, expected_exception().default_detail
)
self.assertInLog(expected_log_message, log)
return r

def test_chat(self):
self.assert_test(TestChatView.VALID_PAYLOAD)
Expand Down Expand Up @@ -3993,3 +4002,8 @@ def test_chat_internal_server_exception(self):
ChatbotInternalServerException,
"ChatbotInternalServerException",
)

def test_chat_with_model_and_provider(self):
r = self.assert_test(TestChatView.PAYLOAD_WITH_MODEL_AND_PROVIDER)
self.assertIn('"model": "non_default_model"', r.data["response"])
self.assertIn('"provider": "non_default_provider"', r.data["response"])
8 changes: 6 additions & 2 deletions ansible_ai_connect/ai/api/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -1130,8 +1130,12 @@ def post(self, request) -> Response:

data = {
"query": request_serializer.validated_data["query"],
"model": settings.CHATBOT_DEFAULT_MODEL,
"provider": settings.CHATBOT_DEFAULT_PROVIDER,
"model": request_serializer.validated_data.get(
"model", settings.CHATBOT_DEFAULT_MODEL
),
"provider": request_serializer.validated_data.get(
"provider", settings.CHATBOT_DEFAULT_PROVIDER
),
}
if "conversation_id" in request_serializer.validated_data:
data["conversation_id"] = str(request_serializer.validated_data["conversation_id"])
Expand Down
1 change: 1 addition & 0 deletions ansible_ai_connect/main/templates/chatbot/index.html
Original file line number Diff line number Diff line change
Expand Up @@ -20,5 +20,6 @@
<!-- as it is only used for display purposes and servers no other value -->
<div id="user_name" hidden>{{user_name}}</div>
<div id="bot_name" hidden>{{bot_name}}</div>
<div id="debug" hidden>{{debug}}</div>
{% endblock content %}
</html>
6 changes: 6 additions & 0 deletions ansible_ai_connect/main/tests/test_views.py
Original file line number Diff line number Diff line change
Expand Up @@ -332,3 +332,9 @@ def test_chatbot_view_with_rh_user(self):
self.assertEqual(r.status_code, HTTPStatus.OK)
self.assertContains(r, TestChatbotView.CHATBOT_PAGE_TITLE)
self.assertContains(r, self.rh_user.username)

def test_chatbot_view_with_debug_option(self):
self.client.force_login(user=self.rh_user)
r = self.client.get(reverse("chatbot"), {"debug": "true"})
self.assertEqual(r.status_code, HTTPStatus.OK)
self.assertContains(r, '<div id="debug" hidden>true</div>')
3 changes: 3 additions & 0 deletions ansible_ai_connect/main/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,8 @@ def get(self, request):
and settings.CHATBOT_DEFAULT_MODEL
and settings.CHATBOT_DEFAULT_PROVIDER
):
debug = request.GET.get("debug", "false")
self.debug = debug.lower() == "true"
return super().get(request)

# Otherwise, redirect to the home page.
Expand All @@ -136,6 +138,7 @@ def get_context_data(self, **kwargs):
user = self.request.user
if user and user.is_authenticated:
context["user_name"] = user.username
context["debug"] = "true" if self.debug else "false"

return context

Expand Down
1 change: 1 addition & 0 deletions ansible_ai_connect_chatbot/index.html
Original file line number Diff line number Diff line change
Expand Up @@ -23,5 +23,6 @@
</div>
<div id="user_name" hidden>{{user_name}}</div>
<div id="bot_name" hidden>{{bot_name}}</div>
<div id="debug" hidden>{{debug}}</div>
</body>
</html>
40 changes: 37 additions & 3 deletions ansible_ai_connect_chatbot/src/AnsibleChatbot/AnsibleChatbot.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -33,12 +33,18 @@ import lightspeedLogo from "../assets/lightspeed.svg";
import lightspeedLogoDark from "../assets/lightspeed_dark.svg";

import "./AnsibleChatbot.scss";
import { botMessage, useChatbot } from "../useChatbot/useChatbot";
import {
botMessage,
inDebugMode,
modelsSupported,
useChatbot,
} from "../useChatbot/useChatbot";
import { ReferencedDocuments } from "../ReferencedDocuments/ReferencedDocuments";

import type { ExtendedMessage } from "../types/Message";
import {
ChatbotAlert,
ChatbotHeaderSelectorDropdown,
ChatbotToggle,
FileDropZone,
} from "@patternfly/virtual-assistant";
Expand Down Expand Up @@ -66,8 +72,15 @@ const footnoteProps = {
};

export const AnsibleChatbot: React.FunctionComponent = () => {
const { messages, isLoading, handleSend, alertMessage, setAlertMessage } =
useChatbot();
const {
messages,
isLoading,
handleSend,
alertMessage,
setAlertMessage,
selectedModel,
setSelectedModel,
} = useChatbot();
const [chatbotVisible, setChatbotVisible] = useState<boolean>(true);
const [displayMode, setDisplayMode] = useState<ChatbotDisplayMode>(
ChatbotDisplayMode.default,
Expand All @@ -82,6 +95,13 @@ export const AnsibleChatbot: React.FunctionComponent = () => {
scrollToBottom();
}, [messages]);

const onSelectModel = (
_event: React.MouseEvent<Element, MouseEvent> | undefined,
value: string | number | undefined,
) => {
setSelectedModel(value as string);
};

const onSelectDisplayMode = (
_event: React.MouseEvent<Element, MouseEvent> | undefined,
value: string | number | undefined,
Expand Down Expand Up @@ -114,6 +134,20 @@ export const AnsibleChatbot: React.FunctionComponent = () => {
</Bullseye>
</ChatbotHeaderTitle>
<ChatbotHeaderActions>
{inDebugMode() && (
<ChatbotHeaderSelectorDropdown
value={selectedModel}
onSelect={onSelectModel}
>
<DropdownList>
{modelsSupported.map((m) => (
<DropdownItem value={m.model} key={m.model}>
{m.model}
</DropdownItem>
))}
</DropdownList>
</ChatbotHeaderSelectorDropdown>
)}
<ChatbotHeaderOptionsDropdown onSelect={onSelectDisplayMode}>
<DropdownGroup label="Display mode">
<DropdownList>
Expand Down
42 changes: 40 additions & 2 deletions ansible_ai_connect_chatbot/src/App.test.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,25 @@ import userEvent from "@testing-library/user-event";
import axios from "axios";

describe("App tests", () => {
const renderApp = () =>
render(
const renderApp = (debug = false) => {
const debugDiv = document.createElement("div");
debugDiv.setAttribute("id", "debug");
debugDiv.innerText = debug.toString();
document.body.appendChild(debugDiv);
const rootDiv = document.createElement("div");
rootDiv.setAttribute("id", "root");
return render(
<MemoryRouter>
<div className="pf-v6-l-flex pf-m-column pf-m-gap-lg ws-full-page-utils pf-v6-m-dir-ltr ">
<ColorThemeSwitch />
</div>
<App />
</MemoryRouter>,
{
container: document.body.appendChild(rootDiv),
},
);
};
const mockAxios = (status: number, reject = false) => {
const spy = vi.spyOn(axios, "post");
if (reject) {
Expand Down Expand Up @@ -46,6 +56,10 @@ describe("App tests", () => {

afterEach(() => {
vi.restoreAllMocks();
const rootDiv = document.getElementById("root");
rootDiv?.remove();
const debugDiv = document.getElementById("debug");
debugDiv?.remove();
});

it("App renders", () => {
Expand Down Expand Up @@ -119,4 +133,28 @@ describe("App tests", () => {
// expect(getComputedStyle(showDark!).display).toEqual("block")
}
});

it("Debug mode test", async () => {
mockAxios(200);

renderApp(true);
const modelSelection = screen.getByText("granite-8b");
await act(async () => fireEvent.click(modelSelection));
expect(screen.getByRole("menuitem", { name: "granite-8b" })).toBeTruthy();
expect(screen.getByRole("menuitem", { name: "granite3-8b" })).toBeTruthy();
await act(async () =>
screen.getByRole("menuitem", { name: "granite3-8b" }).click(),
);

const textArea = screen.getByLabelText("Send a message...");
await act(async () => userEvent.type(textArea, "Hello"));
const sendButton = screen.getByLabelText("Send button");
await act(async () => fireEvent.click(sendButton));
expect(
screen.getByText(
"In Ansible, the precedence of variables is determined by the order...",
),
).toBeInTheDocument();
expect(screen.getByText("Create variables")).toBeInTheDocument();
});
});
4 changes: 4 additions & 0 deletions ansible_ai_connect_chatbot/src/types/Model.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
export type LLMModel = {
model: string;
provider: string;
};
31 changes: 30 additions & 1 deletion ansible_ai_connect_chatbot/src/useChatbot/useChatbot.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,19 @@ import type {
ChatRequest,
ChatResponse,
} from "../types/Message";
import type { LLMModel } from "../types/Model";
import logo from "../assets/lightspeed.svg";
import userLogo from "../assets/user_logo.png";

const userName = document.getElementById("user_name")?.innerText ?? "User";
const botName =
document.getElementById("bot_name")?.innerText ?? "Ansible Lightspeed";

export const modelsSupported: LLMModel[] = [
{ model: "granite-8b", provider: "my_rhoai" },
{ model: "granite3-8b", provider: "my_rhoai_g3" },
];

export const readCookie = (name: string): string | null => {
const nameEQ = name + "=";
const ca = document.cookie.split(";");
Expand All @@ -30,6 +36,11 @@ const getTimestamp = () => {
return `${date.toLocaleDateString()} ${date.toLocaleTimeString()}`;
};

export const inDebugMode = () => {
const debug = document.getElementById("debug")?.innerText ?? "false";
return debug === "true";
};

export const botMessage = (content: string): MessageProps => ({
role: "bot",
content,
Expand Down Expand Up @@ -66,6 +77,7 @@ export const useChatbot = () => {
const [conversationId, setConversationId] = useState<
string | null | undefined
>(undefined);
const [selectedModel, setSelectedModel] = useState("granite-8b");

const handleSend = async (message: string) => {
const userMessage: ExtendedMessage = {
Expand All @@ -83,6 +95,15 @@ export const useChatbot = () => {
query: message,
};

if (inDebugMode()) {
for (const m of modelsSupported) {
if (selectedModel === m.model) {
chatRequest.model = m.model;
chatRequest.provider = m.provider;
}
}
}

setIsLoading(true);
try {
const csrfToken = readCookie("csrftoken");
Expand Down Expand Up @@ -129,5 +150,13 @@ export const useChatbot = () => {
}
};

return { messages, isLoading, handleSend, alertMessage, setAlertMessage };
return {
messages,
isLoading,
handleSend,
alertMessage,
setAlertMessage,
selectedModel,
setSelectedModel,
};
};
Loading