From 2f85913f82d18b34ad8e973a17abaf00158a366c Mon Sep 17 00:00:00 2001 From: Arsenii Shatokhin Date: Sat, 24 Feb 2024 08:48:45 +0400 Subject: [PATCH] Moved get_openapi_schema method into tool factory --- agency_swarm/agents/agent.py | 63 +---------------------------- agency_swarm/tools/BaseTool.py | 59 --------------------------- agency_swarm/tools/ToolFactory.py | 66 +++++++++++++++++++++++++++++++ tests/test_tool_factory.py | 8 ++++ tests/test_tools.py | 13 ------ 5 files changed, 75 insertions(+), 134 deletions(-) delete mode 100644 tests/test_tools.py diff --git a/agency_swarm/agents/agent.py b/agency_swarm/agents/agent.py index 6a8cf178..122c8ddf 100644 --- a/agency_swarm/agents/agent.py +++ b/agency_swarm/agents/agent.py @@ -367,68 +367,7 @@ def get_openapi_schema(self, url): if self.assistant is None: raise Exception("Assistant is not initialized. Please initialize the agency first, before using this method") - schema = { - "openapi": "3.1.0", - "info": { - "title": self.name, - "description": self.description if self.description else "", - "version": "v1.0.0" - }, - "servers": [ - { - "url": url, - } - ], - "paths": {}, - "components": { - "schemas": {}, - "securitySchemes": { - "apiKey": { - "type": "apiKey" - } - } - }, - } - - for tool in self.tools: - if issubclass(tool, BaseTool): - openai_schema = tool.openai_schema - defs = {} - if '$defs' in openai_schema['parameters']: - defs = openai_schema['parameters']['$defs'] - del openai_schema['parameters']['$defs'] - - schema['paths']["/" + openai_schema['name']] = { - "post": { - "description": openai_schema['description'], - "operationId": openai_schema['name'], - "parameters": [], - "requestBody": { - "content": { - "application/json": { - "schema": openai_schema['parameters'] - } - }, - "required": True, - }, - "deprecated": False, - "security": [ - { - "apiKey": [] - } - ], - "x-openai-isConsequential": False, - } - } - - if defs: - schema['components']['schemas'].update(**defs) - - print(openai_schema) - - schema = json.dumps(schema, indent=2).replace("#/$defs/", "#/components/schemas/") - - return schema + return ToolFactory.get_openapi_schema(self.tools, url) # --- Settings Methods --- diff --git a/agency_swarm/tools/BaseTool.py b/agency_swarm/tools/BaseTool.py index 0acdc3b2..2056fd52 100644 --- a/agency_swarm/tools/BaseTool.py +++ b/agency_swarm/tools/BaseTool.py @@ -48,65 +48,6 @@ def openai_schema(cls): return schema - @classmethod - def openapi_schema(cls, url): - openai_schema = cls.openai_schema - defs = {} - if '$defs' in openai_schema['parameters']: - defs = openai_schema['parameters']['$defs'] - del openai_schema['parameters']['$defs'] - - schema = { - "openapi": "3.1.0", - "info": { - "title": openai_schema['name'], - "description": openai_schema['description'], - "version": "v1.0.0" - }, - "servers": [ - { - "url": url, - } - ], - "paths": {}, - "components": { - "schemas": {}, - "securitySchemes": { - "apiKey": { - "type": "apiKey" - } - } - }, - } - - schema['paths']["/" + openai_schema['name']] = { - "post": { - "description": openai_schema['description'], - "operationId": openai_schema['name'], - "parameters": [], - "requestBody": { - "content": { - "application/json": { - "schema": openai_schema['parameters'] - } - }, - "required": True, - }, - "deprecated": False, - "security": [ - { - "apiKey": [] - } - ], - "x-openai-isConsequential": False, - } - } - - if defs: - schema['components']['schemas'].update(**defs) - - return schema - @abstractmethod def run(self, **kwargs): pass diff --git a/agency_swarm/tools/ToolFactory.py b/agency_swarm/tools/ToolFactory.py index cd56b9b1..a26c2fab 100644 --- a/agency_swarm/tools/ToolFactory.py +++ b/agency_swarm/tools/ToolFactory.py @@ -1,5 +1,6 @@ import importlib.util import inspect +import json import os import sys from typing import Any, Dict, List, Type, Union @@ -280,4 +281,69 @@ def from_file(file_path: str): return tool + @staticmethod + def get_openapi_schema(tools: List[Type[BaseTool]], url: str, title="Agent Tools", + description="A collection of tools."): + """ + Generates an OpenAPI schema from a list of BaseTools. + + :param tools: BaseTools to generate the schema from. + :param url: The base URL for the schema. + :param title: The title of the schema. + :param description: The description of the schema. + + :return: A JSON string representing the OpenAPI schema with all the tools combined as separate endpoints. + """ + schema = { + "openapi": "3.1.0", + "info": { + "title": title, + "description": description, + "version": "v1.0.0" + }, + "servers": [ + { + "url": url, + } + ], + "paths": {}, + "components": { + "schemas": {}, + "securitySchemes": { + "apiKey": { + "type": "apiKey" + } + } + }, + } + + for tool in tools: + if not issubclass(tool, BaseTool): + continue + + openai_schema = tool.openai_schema + defs = {} + if '$defs' in openai_schema['parameters']: + defs = openai_schema['parameters']['$defs'] + del openai_schema['parameters']['$defs'] + + schema['paths']["/" + openai_schema['name']] = { + "post": { + "description": openai_schema['description'], + "operationId": openai_schema['name'], + "parameters": [], + "requestBody": { + "content": { + "application/json": { + "schema": openai_schema['parameters'] + } + } + } + } + } + + schema['components']['schemas'].update(defs) + + schema = json.dumps(schema, indent=2).replace("#/$defs/", "#/components/schemas/") + return schema diff --git a/tests/test_tool_factory.py b/tests/test_tool_factory.py index 2dcc0f1c..d55a2bbe 100644 --- a/tests/test_tool_factory.py +++ b/tests/test_tool_factory.py @@ -126,6 +126,14 @@ def test_import_from_file(self): self.assertTrue(tool(content='test').run() == "Tool output") + def test_openapi_schema(self): + with open("./data/schemas/get-headers-params.json", "r") as f: + tools = ToolFactory.from_openapi_schema(f.read()) + + schema = ToolFactory.get_openapi_schema(tools, "123") + + self.assertTrue(schema) + diff --git a/tests/test_tools.py b/tests/test_tools.py deleted file mode 100644 index ef689862..00000000 --- a/tests/test_tools.py +++ /dev/null @@ -1,13 +0,0 @@ -import unittest - -from agency_swarm.tools.coding import ChangeDir - - -class ToolsTest(unittest.TestCase): - def test_change_dir_example(self): - output = ChangeDir(path="./").run() - self.assertFalse("error" in output.lower()) - - -if __name__ == '__main__': - unittest.main()