From 4756cd1d009a2afd3b34cba6c3c38df5d66cb4e9 Mon Sep 17 00:00:00 2001 From: Adam Dougal Date: Thu, 8 Feb 2024 20:51:40 +0000 Subject: [PATCH] Fix IndexError when collating chat history (#195) * Fix IndexError when collating chat history This fixes a bug which causes the exception: ``` ERROR:root:Exception in /api/conversation/custom | list index out of range Traceback (most recent call last): File "/workspaces/chat-with-your-data-solution-accelerator/code/app/app.py", line 283, in conversation_custom chat_history.append((user_assistant_messages[i]['content'],user_assistant_messages[i+1]['content'])) ``` This is caused when there has been an error providing a response, and the latest message in the history is from a user, rather than the assistant. Our code assumes a user message is always followed by an assistant message. This change removes that assumption and explicitly retrieves the role for each message when collating the chat history. Required by https://github.com/Azure-Samples/chat-with-your-data-solution-accelerator/issues/114 * Add python formatter to dev container * Add tests for conversation custom - Extract some elements to dedicated function to allow mocking * Switch to black formatter to align with precommit hook * Add test to cover error scenario when message index is out of range * Add dependencies required for running app tests --------- Co-authored-by: Ross Smith --- .devcontainer/devcontainer.json | 1 + .github/workflows/unittests.yml | 2 +- code/app/app.py | 30 ++-- code/utilities/orchestrator/LangChainAgent.py | 6 +- .../utilities/orchestrator/OpenAIFunctions.py | 3 +- tests/test_app.py | 161 ++++++++++++++++++ 6 files changed, 183 insertions(+), 20 deletions(-) create mode 100644 tests/test_app.py diff --git a/.devcontainer/devcontainer.json b/.devcontainer/devcontainer.json index 6e66ca3a3..b213342df 100644 --- a/.devcontainer/devcontainer.json +++ b/.devcontainer/devcontainer.json @@ -23,6 +23,7 @@ "ms-azuretools.vscode-bicep", 
"ms-azuretools.vscode-docker", "ms-python.python", + "ms-python.black-formatter", "ms-python.vscode-pylance", "ms-vscode.vscode-node-azure-pack", "TeamsDevApp.ms-teams-vscode-extension" diff --git a/.github/workflows/unittests.yml b/.github/workflows/unittests.yml index eb0cefec8..7c0f6583f 100644 --- a/.github/workflows/unittests.yml +++ b/.github/workflows/unittests.yml @@ -20,6 +20,6 @@ jobs: architecture: x64 - name: Install dependencies run: | - pip install -r code/requirements.txt -r code/dev-requirements.txt + pip install -r code/requirements.txt -r code/dev-requirements.txt -r code/app/requirements.txt - name: Run Python tests run: python -m pytest --rootdir=code -m "not azure" diff --git a/code/app/app.py b/code/app/app.py index 3b92116b7..5ec28fd4b 100644 --- a/code/app/app.py +++ b/code/app/app.py @@ -313,11 +313,21 @@ def conversation_azure_byod(): ) -@app.route("/api/conversation/custom", methods=["GET", "POST"]) -def conversation_custom(): +def get_message_orchestrator(): from utilities.helpers.OrchestratorHelper import Orchestrator - message_orchestrator = Orchestrator() + return Orchestrator() + + +def get_orchestrator_config(): + from utilities.helpers.ConfigHelper import ConfigHelper + + return ConfigHelper.get_active_config_or_default().orchestrator + + +@app.route("/api/conversation/custom", methods=["GET", "POST"]) +def conversation_custom(): + message_orchestrator = get_message_orchestrator() try: user_message = request.json["messages"][-1]["content"] @@ -328,22 +338,12 @@ def conversation_custom(): request.json["messages"][0:-1], ) ) - chat_history = [] - for i, k in enumerate(user_assistant_messages): - if i % 2 == 0: - chat_history.append( - ( - user_assistant_messages[i]["content"], - user_assistant_messages[i + 1]["content"], - ) - ) - from utilities.helpers.ConfigHelper import ConfigHelper messages = message_orchestrator.handle_message( user_message=user_message, - chat_history=chat_history, + chat_history=user_assistant_messages, 
conversation_id=conversation_id, - orchestrator=ConfigHelper.get_active_config_or_default().orchestrator, + orchestrator=get_orchestrator_config(), ) response_obj = { diff --git a/code/utilities/orchestrator/LangChainAgent.py b/code/utilities/orchestrator/LangChainAgent.py index 445cf1598..5f02f60b7 100644 --- a/code/utilities/orchestrator/LangChainAgent.py +++ b/code/utilities/orchestrator/LangChainAgent.py @@ -89,8 +89,10 @@ def orchestrate( memory_key="chat_history", return_messages=True ) for message in chat_history: - memory.chat_memory.add_user_message(message[0]) - memory.chat_memory.add_ai_message(message[1]) + if message["role"] == "user": + memory.chat_memory.add_user_message(message["content"]) + elif message["role"] == "assistant": + memory.chat_memory.add_ai_message(message["content"]) # Define Agent and Agent Chain llm_chain = LLMChain(llm=llm_helper.get_llm(), prompt=prompt) agent = ZeroShotAgent(llm_chain=llm_chain, tools=self.tools, verbose=True) diff --git a/code/utilities/orchestrator/OpenAIFunctions.py b/code/utilities/orchestrator/OpenAIFunctions.py index c132f2ce0..ab7c8b8f4 100644 --- a/code/utilities/orchestrator/OpenAIFunctions.py +++ b/code/utilities/orchestrator/OpenAIFunctions.py @@ -81,8 +81,7 @@ def orchestrate( # Create conversation history messages = [{"role": "system", "content": system_message}] for message in chat_history: - messages.append({"role": "user", "content": message[0]}) - messages.append({"role": "assistant", "content": message[1]}) + messages.append({"role": message["role"], "content": message["content"]}) messages.append({"role": "user", "content": user_message}) result = llm_helper.get_chat_completion_with_functions( diff --git a/tests/test_app.py b/tests/test_app.py new file mode 100644 index 000000000..6466f5698 --- /dev/null +++ b/tests/test_app.py @@ -0,0 +1,161 @@ +import os + +from unittest.mock import Mock +from unittest.mock import patch + +from code.app.app import app + + +class TestConfig: + def 
test_returns_correct_config(self): + response = app.test_client().get("/api/config") + + assert response.status_code == 200 + assert response.json == {"azureSpeechKey": None, "azureSpeechRegion": None} + + +class TestCoversationCustom: + def setup_method(self): + self.orchestrator_config = {"strategy": "langchain"} + self.messages = [ + { + "content": '{"citations": [], "intent": "A question?"}', + "end_turn": False, + "role": "tool", + }, + {"content": "An answer", "end_turn": True, "role": "assistant"}, + ] + self.openai_model = "some-model" + self.body = { + "conversation_id": "123", + "messages": [ + {"role": "user", "content": "Hello"}, + {"role": "assistant", "content": "Hi, how can I help?"}, + {"role": "user", "content": "What is the meaning of life?"}, + ], + } + + @patch("code.app.app.get_message_orchestrator") + @patch("code.app.app.get_orchestrator_config") + def test_converstation_custom_returns_correct_response( + self, get_orchestrator_config_mock, get_message_orchestrator_mock + ): + # given + get_orchestrator_config_mock.return_value = self.orchestrator_config + + message_orchestrator_mock = Mock() + message_orchestrator_mock.handle_message.return_value = self.messages + get_message_orchestrator_mock.return_value = message_orchestrator_mock + + os.environ["AZURE_OPENAI_MODEL"] = self.openai_model + + # when + response = app.test_client().post( + "/api/conversation/custom", + headers={"content-type": "application/json"}, + json=self.body, + ) + + # then + assert response.status_code == 200 + assert response.json == { + "choices": [{"messages": self.messages}], + "created": "response.created", + "id": "response.id", + "model": self.openai_model, + "object": "response.object", + } + + @patch("code.app.app.get_message_orchestrator") + @patch("code.app.app.get_orchestrator_config") + def test_converstation_custom_calls_message_orchestrator_correctly( + self, get_orchestrator_config_mock, get_message_orchestrator_mock + ): + # given + 
get_orchestrator_config_mock.return_value = self.orchestrator_config + + message_orchestrator_mock = Mock() + message_orchestrator_mock.handle_message.return_value = self.messages + get_message_orchestrator_mock.return_value = message_orchestrator_mock + + os.environ["AZURE_OPENAI_MODEL"] = self.openai_model + + # when + app.test_client().post( + "/api/conversation/custom", + headers={"content-type": "application/json"}, + json=self.body, + ) + + # then + message_orchestrator_mock.handle_message.assert_called_once_with( + user_message=self.body["messages"][-1]["content"], + chat_history=self.body["messages"][:-1], + conversation_id=self.body["conversation_id"], + orchestrator=self.orchestrator_config, + ) + + @patch("code.app.app.get_orchestrator_config") + def test_converstation_custom_returns_error_resonse_on_exception( + self, get_orchestrator_config_mock + ): + # given + get_orchestrator_config_mock.side_effect = Exception("An error occurred") + + # when + response = app.test_client().post( + "/api/conversation/custom", + headers={"content-type": "application/json"}, + json=self.body, + ) + + # then + assert response.status_code == 500 + assert response.json == { + "error": "Exception in /api/conversation/custom. See log for more details." 
+ } + + @patch("code.app.app.get_message_orchestrator") + @patch("code.app.app.get_orchestrator_config") + def test_converstation_custom_allows_multiple_messages_from_user( + self, get_orchestrator_config_mock, get_message_orchestrator_mock + ): + """This can happen if there was an error getting a response from the assistant for the previous user message.""" + + # given + get_orchestrator_config_mock.return_value = self.orchestrator_config + + message_orchestrator_mock = Mock() + message_orchestrator_mock.handle_message.return_value = self.messages + get_message_orchestrator_mock.return_value = message_orchestrator_mock + + os.environ["AZURE_OPENAI_MODEL"] = self.openai_model + + body = { + "conversation_id": "123", + "messages": [ + {"role": "user", "content": "Hello"}, + {"role": "assistant", "content": "Hi, how can I help?"}, + {"role": "user", "content": "What is the meaning of life?"}, + { + "role": "user", + "content": "Please, what is the meaning of life?", + }, + ], + } + + # when + response = app.test_client().post( + "/api/conversation/custom", + headers={"content-type": "application/json"}, + json=body, + ) + + # then + assert response.status_code == 200 + message_orchestrator_mock.handle_message.assert_called_once_with( + user_message=body["messages"][-1]["content"], + chat_history=body["messages"][:-1], + conversation_id=body["conversation_id"], + orchestrator=self.orchestrator_config, + )