From 2107467d8bfea24b2ed00e8a56a8917fa5dc398b Mon Sep 17 00:00:00 2001 From: Arsenii Shatokhin Date: Tue, 27 Aug 2024 10:39:20 +0400 Subject: [PATCH] Adjusted logic for uploading message files --- agency_swarm/agency/agency.py | 8 +++----- agency_swarm/util/__init__.py | 2 +- agency_swarm/util/files.py | 15 +++++++++------ 3 files changed, 13 insertions(+), 12 deletions(-) diff --git a/agency_swarm/agency/agency.py b/agency_swarm/agency/agency.py index 6738e9b5..3c147b14 100644 --- a/agency_swarm/agency/agency.py +++ b/agency_swarm/agency/agency.py @@ -27,7 +27,7 @@ from agency_swarm.tools import BaseTool, CodeInterpreter, FileSearch from agency_swarm.user import User from agency_swarm.util.errors import RefusalError -from agency_swarm.util.files import determine_file_type, get_tools +from agency_swarm.util.files import get_tools, get_file_purpose from agency_swarm.util.shared_state import SharedState from agency_swarm.util.streaming import AgencyEventHandler @@ -330,9 +330,7 @@ def handle_file_upload(file_list): if file_list: try: for file_obj in file_list: - file_type = determine_file_type(file_obj.name) - purpose = "assistants" if file_type != "vision" else "vision" - tools = [{"type": "code_interpreter"}] if file_type == "assistants.code_interpreter" else [{"type": "file_search"}] + purpose = get_file_purpose(file_obj.name) with open(file_obj.name, 'rb') as f: # Upload the file to OpenAI @@ -341,7 +339,7 @@ def handle_file_upload(file_list): purpose=purpose ) - if file_type == "vision": + if purpose == "vision": images.append({ "type": "image_file", "image_file": {"file_id": file.id} diff --git a/agency_swarm/util/__init__.py b/agency_swarm/util/__init__.py index b4c5feb0..479b725a 100644 --- a/agency_swarm/util/__init__.py +++ b/agency_swarm/util/__init__.py @@ -1,5 +1,5 @@ from .cli.create_agent_template import create_agent_template from .cli.import_agent import import_agent from .oai import set_openai_key, get_openai_client, set_openai_client -from .files import determine_file_type +from .files import get_tools, get_file_purpose from .validators import llm_validator \ No newline at end of file diff --git a/agency_swarm/util/files.py b/agency_swarm/util/files.py index 3da1b5ac..010d6ca1 100644 --- a/agency_swarm/util/files.py +++ b/agency_swarm/util/files.py @@ -1,5 +1,9 @@ import mimetypes +image_types = [ + "image/jpeg", "image/jpg", "image/png", "image/webp", "image/gif" +] + code_interpreter_types = [ "application/csv", "image/jpeg", "image/gif", "image/png", "application/x-tar", "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet", @@ -17,18 +21,17 @@ "application/typescript" ] -def determine_file_type(file_path): +def get_file_purpose(file_path): mime_type, _ = mimetypes.guess_type(file_path) if mime_type: - if mime_type in code_interpreter_types: - return "assistants.code_interpreter" - elif mime_type.startswith('image/'): + if mime_type in image_types: return "vision" - elif mime_type in dual_types: - return "assistants.file_search" + if mime_type in code_interpreter_types or mime_type in dual_types: + return "assistants" raise ValueError(f"Unsupported file type: {mime_type}") def get_tools(file_path): + """Returns the tools for the given file path""" mime_type, _ = mimetypes.guess_type(file_path) if mime_type in code_interpreter_types: return [{"type": "code_interpreter"}]