Skip to content

Commit

Permalink
Moved get_openapi_schema method into tool factory
Browse files Browse the repository at this point in the history
  • Loading branch information
VRSEN committed Feb 24, 2024
1 parent 63265c8 commit 2f85913
Show file tree
Hide file tree
Showing 5 changed files with 75 additions and 134 deletions.
63 changes: 1 addition & 62 deletions agency_swarm/agents/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -367,68 +367,7 @@ def get_openapi_schema(self, url):
if self.assistant is None:
raise Exception("Assistant is not initialized. Please initialize the agency first, before using this method")

schema = {
"openapi": "3.1.0",
"info": {
"title": self.name,
"description": self.description if self.description else "",
"version": "v1.0.0"
},
"servers": [
{
"url": url,
}
],
"paths": {},
"components": {
"schemas": {},
"securitySchemes": {
"apiKey": {
"type": "apiKey"
}
}
},
}

for tool in self.tools:
if issubclass(tool, BaseTool):
openai_schema = tool.openai_schema
defs = {}
if '$defs' in openai_schema['parameters']:
defs = openai_schema['parameters']['$defs']
del openai_schema['parameters']['$defs']

schema['paths']["/" + openai_schema['name']] = {
"post": {
"description": openai_schema['description'],
"operationId": openai_schema['name'],
"parameters": [],
"requestBody": {
"content": {
"application/json": {
"schema": openai_schema['parameters']
}
},
"required": True,
},
"deprecated": False,
"security": [
{
"apiKey": []
}
],
"x-openai-isConsequential": False,
}
}

if defs:
schema['components']['schemas'].update(**defs)

print(openai_schema)

schema = json.dumps(schema, indent=2).replace("#/$defs/", "#/components/schemas/")

return schema
return ToolFactory.get_openapi_schema(self.tools, url)

# --- Settings Methods ---

Expand Down
59 changes: 0 additions & 59 deletions agency_swarm/tools/BaseTool.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,65 +48,6 @@ def openai_schema(cls):

return schema

@classmethod
def openapi_schema(cls, url):
openai_schema = cls.openai_schema
defs = {}
if '$defs' in openai_schema['parameters']:
defs = openai_schema['parameters']['$defs']
del openai_schema['parameters']['$defs']

schema = {
"openapi": "3.1.0",
"info": {
"title": openai_schema['name'],
"description": openai_schema['description'],
"version": "v1.0.0"
},
"servers": [
{
"url": url,
}
],
"paths": {},
"components": {
"schemas": {},
"securitySchemes": {
"apiKey": {
"type": "apiKey"
}
}
},
}

schema['paths']["/" + openai_schema['name']] = {
"post": {
"description": openai_schema['description'],
"operationId": openai_schema['name'],
"parameters": [],
"requestBody": {
"content": {
"application/json": {
"schema": openai_schema['parameters']
}
},
"required": True,
},
"deprecated": False,
"security": [
{
"apiKey": []
}
],
"x-openai-isConsequential": False,
}
}

if defs:
schema['components']['schemas'].update(**defs)

return schema

@abstractmethod
def run(self, **kwargs):
pass
66 changes: 66 additions & 0 deletions agency_swarm/tools/ToolFactory.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import importlib.util
import inspect
import json
import os
import sys
from typing import Any, Dict, List, Type, Union
Expand Down Expand Up @@ -280,4 +281,69 @@ def from_file(file_path: str):

return tool

@staticmethod
def get_openapi_schema(tools: List[Type[BaseTool]], url: str, title="Agent Tools",
description="A collection of tools."):
"""
Generates an OpenAPI schema from a list of BaseTools.
:param tools: BaseTools to generate the schema from.
:param url: The base URL for the schema.
:param title: The title of the schema.
:param description: The description of the schema.
:return: A JSON string representing the OpenAPI schema with all the tools combined as separate endpoints.
"""
schema = {
"openapi": "3.1.0",
"info": {
"title": title,
"description": description,
"version": "v1.0.0"
},
"servers": [
{
"url": url,
}
],
"paths": {},
"components": {
"schemas": {},
"securitySchemes": {
"apiKey": {
"type": "apiKey"
}
}
},
}

for tool in tools:
if not issubclass(tool, BaseTool):
continue

openai_schema = tool.openai_schema
defs = {}
if '$defs' in openai_schema['parameters']:
defs = openai_schema['parameters']['$defs']
del openai_schema['parameters']['$defs']

schema['paths']["/" + openai_schema['name']] = {
"post": {
"description": openai_schema['description'],
"operationId": openai_schema['name'],
"parameters": [],
"requestBody": {
"content": {
"application/json": {
"schema": openai_schema['parameters']
}
}
}
}
}

schema['components']['schemas'].update(defs)

schema = json.dumps(schema, indent=2).replace("#/$defs/", "#/components/schemas/")

return schema
8 changes: 8 additions & 0 deletions tests/test_tool_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,14 @@ def test_import_from_file(self):

self.assertTrue(tool(content='test').run() == "Tool output")

def test_openapi_schema(self):
with open("./data/schemas/get-headers-params.json", "r") as f:
tools = ToolFactory.from_openapi_schema(f.read())

schema = ToolFactory.get_openapi_schema(tools, "123")

self.assertTrue(schema)




Expand Down
13 changes: 0 additions & 13 deletions tests/test_tools.py

This file was deleted.

0 comments on commit 2f85913

Please sign in to comment.