From 7f2e89121766b5d124c68c722415bae9c2ff4110 Mon Sep 17 00:00:00 2001 From: Arsenii Shatokhin Date: Tue, 13 Aug 2024 11:22:48 +0400 Subject: [PATCH] Added get_completion_parse method --- agency_swarm/agency/agency.py | 78 +++++++++++++++++++++++++++++----- agency_swarm/agents/agent.py | 3 +- agency_swarm/threads/thread.py | 29 +++++++------ agency_swarm/util/errors.py | 2 + tests/test_agency.py | 5 +++ 5 files changed, 92 insertions(+), 25 deletions(-) create mode 100644 agency_swarm/util/errors.py diff --git a/agency_swarm/agency/agency.py b/agency_swarm/agency/agency.py index 5f4a3bf4..45412c28 100644 --- a/agency_swarm/agency/agency.py +++ b/agency_swarm/agency/agency.py @@ -3,14 +3,20 @@ import os import queue import threading -import time import uuid from enum import Enum -from typing import List, TypedDict, Callable, Any, Dict, Literal, Union, Optional +from typing import Any, Callable, Dict, List, Literal, Optional, Type, TypeVar, TypedDict, Union +from openai.lib._parsing._completions import type_to_response_format_param from openai.types.beta.threads import Message from openai.types.beta.threads.runs import RunStep -from pydantic import Field, field_validator, model_validator +from openai.types.beta.threads.runs.tool_call import ( + CodeInterpreterToolCall, + FileSearchToolCall, + FunctionToolCall, + ToolCall, +) +from pydantic import BaseModel, Field, field_validator, model_validator from rich.console import Console from typing_extensions import override @@ -18,17 +24,16 @@ from agency_swarm.messages import MessageOutput from agency_swarm.messages.message_output import MessageOutputLive from agency_swarm.threads import Thread -from agency_swarm.tools import BaseTool, FileSearch, CodeInterpreter +from agency_swarm.tools import BaseTool, CodeInterpreter, FileSearch from agency_swarm.user import User +from agency_swarm.util.errors import RefusalError from agency_swarm.util.files import determine_file_type from agency_swarm.util.shared_state import 
SharedState -from openai.types.beta.threads.runs.tool_call import ToolCall, FunctionToolCall, CodeInterpreterToolCall, FileSearchToolCall - - from agency_swarm.util.streaming import AgencyEventHandler console = Console() +T = TypeVar('T', bound=BaseModel) class SettingsCallbacks(TypedDict): load: Callable[[], List[Dict]] @@ -127,7 +132,8 @@ def get_completion(self, message: str, additional_instructions: str = None, attachments: List[dict] = None, tool_choice: dict = None, - verbose: bool = False): + verbose: bool = False, + response_format: dict = None): """ Retrieves the completion for a given message from the main thread. @@ -141,6 +147,7 @@ def get_completion(self, message: str, tool_choice (dict, optional): The tool choice for the recipient agent to use. Defaults to None. parallel_tool_calls (bool, optional): Whether to enable parallel function calling during tool use. Defaults to True. verbose (bool, optional): Whether to print the intermediary messages in console. Defaults to False. + response_format (dict, optional): The response format to use for the completion. Returns: Generator or final response: Depending on the 'yield_messages' flag, this method returns either a generator yielding intermediate messages or the final response from the main thread. @@ -154,7 +161,9 @@ def get_completion(self, message: str, recipient_agent=recipient_agent, additional_instructions=additional_instructions, tool_choice=tool_choice, - yield_messages=yield_messages or verbose) + yield_messages=yield_messages or verbose, + response_format=response_format) + if not yield_messages or verbose: while True: try: @@ -174,7 +183,8 @@ def get_completion_stream(self, recipient_agent: Agent = None, additional_instructions: str = None, attachments: List[dict] = None, - tool_choice: dict = None): + tool_choice: dict = None, + response_format: dict = None): """ Generates a stream of completions for a given message from the main thread. 
@@ -200,14 +210,60 @@ def get_completion_stream(self, attachments=attachments, recipient_agent=recipient_agent, additional_instructions=additional_instructions, - tool_choice=tool_choice) + tool_choice=tool_choice, + response_format=response_format) while True: try: next(res) except StopIteration as e: event_handler.on_all_streams_end() + return e.value + + def get_completion_parse(self, message: str, + response_format: Type[T], + message_files: List[str] = None, + recipient_agent: Agent = None, + additional_instructions: str = None, + attachments: List[dict] = None, + tool_choice: dict = None) -> T: + """ + Retrieves the completion for a given message from the main thread and parses the response using the provided response format. + + Parameters: + message (str): The message for which completion is to be retrieved. + response_format (type(T)): The response format to use for the completion. + message_files (list, optional): A list of file ids to be sent as attachments with the message. When using this parameter, files will be assigned both to file_search and code_interpreter tools if available. It is recommended to assign files to the most suitable tool manually, using the attachments parameter. Defaults to None. + recipient_agent (Agent, optional): The agent to which the message should be sent. Defaults to the first agent in the agency chart. + additional_instructions (str, optional): Additional instructions to be sent with the message. Defaults to None. + attachments (List[dict], optional): A list of attachments to be sent with the message, following openai format. Defaults to None. + tool_choice (dict, optional): The tool choice for the recipient agent to use. Defaults to None. + + Returns: + Final response: The final response from the main thread, parsed using the provided response format. 
+ """ + response_model = None + if isinstance(response_format, type): + response_model = response_format + response_format = type_to_response_format_param(response_format) + + res = self.get_completion(message=message, + message_files=message_files, + recipient_agent=recipient_agent, + additional_instructions=additional_instructions, + attachments=attachments, + tool_choice=tool_choice, + response_format=response_format) + + try: + return response_model.model_validate_json(res) + except: + parsed_res = json.loads(res) + if 'refusal' in parsed_res: + raise RefusalError(parsed_res['refusal']) + else: + raise Exception("Failed to parse response: " + res) def demo_gradio(self, height=450, dark_mode=True, **kwargs): """ diff --git a/agency_swarm/agents/agent.py b/agency_swarm/agents/agent.py index eaee5bf6..e49229ae 100644 --- a/agency_swarm/agents/agent.py +++ b/agency_swarm/agents/agent.py @@ -16,7 +16,6 @@ from agency_swarm.util.openapi import validate_openapi_spec from agency_swarm.util.shared_state import SharedState from pydantic import BaseModel -from openai import pydantic_function_tool from openai.lib._parsing._completions import type_to_response_format_param class ExampleMessage(TypedDict): @@ -106,7 +105,7 @@ def __init__( tool_resources (ToolResources, optional): A set of resources that are used by the assistant's tools. The resources are specific to the type of tool. For example, the code_interpreter tool requires a list of file IDs, while the file_search tool requires a list of vector store IDs. Defaults to None. temperature (float, optional): The temperature parameter for the OpenAI API. Defaults to None. top_p (float, optional): The top_p parameter for the OpenAI API. Defaults to None. - response_format (Dict, optional): The response format for the OpenAI API. Defaults to None. + response_format (Union[str, Dict, type], optional): The response format for the OpenAI API. If BaseModel is provided, it will be converted to a response format. 
Defaults to None. tools_folder (str, optional): Path to a directory containing tools associated with the agent. Each tool must be defined in a separate file. File must be named as the class name of the tool. Defaults to None. files_folder (Union[List[str], str], optional): Path or list of paths to directories containing files associated with the agent. Defaults to None. schemas_folder (Union[List[str], str], optional): Path or list of paths to directories containing OpenAPI schemas associated with the agent. Defaults to None. diff --git a/agency_swarm/threads/thread.py b/agency_swarm/threads/thread.py index 54d164ea..5462870d 100644 --- a/agency_swarm/threads/thread.py +++ b/agency_swarm/threads/thread.py @@ -3,7 +3,7 @@ import json import os import time -from typing import List, Optional, Union +from typing import List, Optional, Type, Union from openai import BadRequestError from openai.types.beta import AssistantToolChoice @@ -59,9 +59,10 @@ def get_completion_stream(self, event_handler: type(AgencyEventHandler), message_files: List[str] = None, attachments: Optional[List[Attachment]] = None, - recipient_agent=None, + recipient_agent:Agent=None, additional_instructions: str = None, - tool_choice: AssistantToolChoice = None): + tool_choice: AssistantToolChoice = None, + response_format: Optional[dict] = None): return self.get_completion(message, message_files, @@ -70,17 +71,19 @@ def get_completion_stream(self, additional_instructions, event_handler, tool_choice, - yield_messages=False) + yield_messages=False, + response_format=response_format) def get_completion(self, message: str | List[dict], message_files: List[str] = None, attachments: Optional[List[dict]] = None, - recipient_agent=None, + recipient_agent: Agent = None, additional_instructions: str = None, event_handler: type(AgencyEventHandler) = None, tool_choice: AssistantToolChoice = None, - yield_messages: bool = False + yield_messages: bool = False, + response_format: Optional[dict] = None ): if not 
recipient_agent: recipient_agent = self.recipient_agent @@ -121,7 +124,7 @@ def get_completion(self, if yield_messages: yield MessageOutput("text", self.agent.name, recipient_agent.name, message, message_obj) - self._create_run(recipient_agent, additional_instructions, event_handler, tool_choice) + self._create_run(recipient_agent, additional_instructions, event_handler, tool_choice, response_format=response_format) error_attempts = 0 validation_attempts = 0 @@ -235,14 +238,14 @@ def handle_output(tool_call, output): # retry run 2 times if error_attempts < 1 and "something went wrong" in self.run.last_error.message.lower(): time.sleep(1) - self._create_run(recipient_agent, additional_instructions, event_handler, tool_choice) + self._create_run(recipient_agent, additional_instructions, event_handler, tool_choice, response_format=response_format) error_attempts += 1 elif 1 <= error_attempts < 5 and "something went wrong" in self.run.last_error.message.lower(): self.create_message( message="Continue.", role="user" ) - self._create_run(recipient_agent, additional_instructions, event_handler, tool_choice) + self._create_run(recipient_agent, additional_instructions, event_handler, tool_choice, response_format=response_format) error_attempts += 1 else: raise Exception("OpenAI Run Failed. 
Error: ", self.run.last_error.message) @@ -292,13 +295,13 @@ def handle_output(tool_call, output): validation_attempts += 1 - self._create_run(recipient_agent, additional_instructions, event_handler, tool_choice) + self._create_run(recipient_agent, additional_instructions, event_handler, tool_choice, response_format=response_format) continue return last_message - def _create_run(self, recipient_agent, additional_instructions, event_handler, tool_choice, temperature=None): + def _create_run(self, recipient_agent, additional_instructions, event_handler, tool_choice, temperature=None, response_format: Optional[dict] = None): if event_handler: with self.client.beta.threads.runs.stream( thread_id=self.thread.id, @@ -311,6 +314,7 @@ def _create_run(self, recipient_agent, additional_instructions, event_handler, t truncation_strategy=recipient_agent.truncation_strategy, temperature=temperature, extra_body={"parallel_tool_calls": recipient_agent.parallel_tool_calls}, + response_format=response_format ) as stream: stream.until_done() self.run = stream.get_final_run() @@ -324,7 +328,8 @@ def _create_run(self, recipient_agent, additional_instructions, event_handler, t max_completion_tokens=recipient_agent.max_completion_tokens, truncation_strategy=recipient_agent.truncation_strategy, temperature=temperature, - parallel_tool_calls=recipient_agent.parallel_tool_calls + parallel_tool_calls=recipient_agent.parallel_tool_calls, + response_format=response_format ) self.run = self.client.beta.threads.runs.poll( thread_id=self.thread.id, diff --git a/agency_swarm/util/errors.py b/agency_swarm/util/errors.py new file mode 100644 index 00000000..4831a8f0 --- /dev/null +++ b/agency_swarm/util/errors.py @@ -0,0 +1,2 @@ +class RefusalError(Exception): + pass \ No newline at end of file diff --git a/tests/test_agency.py b/tests/test_agency.py index 64ea77f9..68f4eeab 100644 --- a/tests/test_agency.py +++ b/tests/test_agency.py @@ -521,6 +521,11 @@ class Step(BaseModel): # check if result is 
a MathReasoning object self.assertTrue(MathReasoning.model_validate_json(result)) + result = agency.get_completion_parse("how can I solve 3x + 2 = 14", response_format=MathReasoning) + + # check if result is a MathReasoning object + self.assertTrue(isinstance(result, MathReasoning)) + # --- Helper methods --- def get_class_folder_path(self):