From 750a3830f1b4ab293dfb30bcc9aa496778de36b5 Mon Sep 17 00:00:00 2001 From: Future-Outlier Date: Wed, 5 Jun 2024 07:32:17 +0800 Subject: [PATCH] Fix output_prefix in do() method for ChatGPT Agent (#2457) Signed-off-by: Future-Outlier Co-authored-by: pingsutw --- flytekit/extend/backend/base_agent.py | 8 +++++-- .../flytekitplugins/openai/chatgpt/agent.py | 1 + .../tests/chatgpt/test_chatgpt.py | 22 +++++++++++++++++++ 3 files changed, 29 insertions(+), 2 deletions(-) diff --git a/flytekit/extend/backend/base_agent.py b/flytekit/extend/backend/base_agent.py index 33a03e282b..e8ec18806e 100644 --- a/flytekit/extend/backend/base_agent.py +++ b/flytekit/extend/backend/base_agent.py @@ -119,7 +119,9 @@ class SyncAgentBase(AgentBase): name = "Base Sync Agent" @abstractmethod - def do(self, task_template: TaskTemplate, inputs: Optional[LiteralMap], output_prefix: str, **kwargs) -> Resource: + def do( + self, task_template: TaskTemplate, output_prefix: str, inputs: Optional[LiteralMap] = None, **kwargs + ) -> Resource: """ This is the method that the agent will run. """ @@ -247,7 +249,9 @@ def execute(self: PythonTask, **kwargs) -> LiteralMap: agent = AgentRegistry.get_agent(task_template.type, task_template.task_type_version) - resource = asyncio.run(self._do(agent, task_template, output_prefix, kwargs)) + resource = asyncio.run( + self._do(agent=agent, template=task_template, output_prefix=output_prefix, inputs=kwargs) + ) if resource.phase != TaskExecution.SUCCEEDED: raise FlyteUserException(f"Failed to run the task {self.name} with error: {resource.message}") diff --git a/plugins/flytekit-openai/flytekitplugins/openai/chatgpt/agent.py b/plugins/flytekit-openai/flytekitplugins/openai/chatgpt/agent.py index afd3af1321..e4f24baa5a 100644 --- a/plugins/flytekit-openai/flytekitplugins/openai/chatgpt/agent.py +++ b/plugins/flytekit-openai/flytekitplugins/openai/chatgpt/agent.py @@ -27,6 +27,7 @@ async def do( self, task_template: TaskTemplate, inputs: Optional[LiteralMap] = None, + **kwargs, ) -> Resource: ctx = FlyteContextManager.current_context() input_python_value = TypeEngine.literal_map_to_kwargs(ctx, inputs, {"message": str}) diff --git a/plugins/flytekit-openai/tests/chatgpt/test_chatgpt.py b/plugins/flytekit-openai/tests/chatgpt/test_chatgpt.py index 6298bdf52c..12de3da23b 100644 --- a/plugins/flytekit-openai/tests/chatgpt/test_chatgpt.py +++ b/plugins/flytekit-openai/tests/chatgpt/test_chatgpt.py @@ -1,4 +1,5 @@ from collections import OrderedDict +from unittest import mock from flytekitplugins.openai import ChatGPTTask @@ -7,6 +8,14 @@ from flytekit.models.types import SimpleType +async def mock_acreate(*args, **kwargs) -> str: + mock_response = mock.MagicMock() + mock_choice = mock.MagicMock() + mock_choice.message.content = "mocked_message" + mock_response.choices = [mock_choice] + return mock_response + + def test_chatgpt_task(): chatgpt_task = ChatGPTTask( name="chatgpt", @@ -40,3 +49,16 @@ def test_chatgpt_task(): assert chatgpt_task_spec.template.interface.inputs["message"].type.simple == SimpleType.STRING assert chatgpt_task_spec.template.interface.outputs["o0"].type.simple == SimpleType.STRING + + with mock.patch("openai.resources.chat.completions.AsyncCompletions.create", new=mock_acreate): + chatgpt_task = ChatGPTTask( + name="chatgpt", + openai_organization="TEST ORGANIZATION ID", + chatgpt_config={ + "model": "gpt-3.5-turbo", + "temperature": 0.7, + }, + ) + + response = chatgpt_task(message="hi") + assert response == "mocked_message"