Skip to content

Commit

Permalink
feat: Identify Composio tools (#721)
Browse files Browse the repository at this point in the history
Co-authored-by: Caren Thomas <[email protected]>
Co-authored-by: Shubham Naik <[email protected]>
Co-authored-by: Shubham Naik <[email protected]>
Co-authored-by: mlong93 <[email protected]>
Co-authored-by: Mindy Long <[email protected]>
  • Loading branch information
6 people authored Jan 23, 2025
1 parent bb91dab commit cc8f93c
Show file tree
Hide file tree
Showing 12 changed files with 210 additions and 72 deletions.
51 changes: 51 additions & 0 deletions alembic/versions/f895232c144a_backfill_composio_tools.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
"""Backfill composio tools
Revision ID: f895232c144a
Revises: 25fc99e97839
Create Date: 2025-01-16 14:21:33.764332
"""

from typing import Sequence, Union

from alembic import op
from letta.orm.enums import ToolType

# revision identifiers, used by Alembic.
revision: str = "f895232c144a"
down_revision: Union[str, None] = "416b9d2db10b"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
# Define the value for EXTERNAL_COMPOSIO
external_composio_value = ToolType.EXTERNAL_COMPOSIO.value

# Update tool_type to EXTERNAL_COMPOSIO if the tags field includes "composio"
# This is super brittle and awful but no other way to do this
op.execute(
f"""
UPDATE tools
SET tool_type = '{external_composio_value}'
WHERE tags::jsonb @> '["composio"]';
"""
)
# ### end Alembic commands ###


def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
custom_value = ToolType.CUSTOM.value

# Update tool_type to CUSTOM if the tags field includes "composio"
# This is super brittle and awful but no other way to do this
op.execute(
f"""
UPDATE tools
SET tool_type = '{custom_value}'
WHERE tags::jsonb @> '["composio"]';
"""
)
# ### end Alembic commands ###
2 changes: 1 addition & 1 deletion letta/client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -2893,7 +2893,7 @@ def load_crewai_tool(self, crewai_tool: "CrewAIBaseTool", additional_imports_mod

def load_composio_tool(self, action: "ActionType") -> Tool:
tool_create = ToolCreate.from_composio(action_name=action.name)
return self.server.tool_manager.create_or_update_tool(pydantic_tool=Tool(**tool_create.model_dump()), actor=self.user)
return self.server.tool_manager.create_or_update_composio_tool(pydantic_tool=Tool(**tool_create.model_dump()), actor=self.user)

def create_tool(
self,
Expand Down
33 changes: 29 additions & 4 deletions letta/functions/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,37 @@
from letta.schemas.message import MessageCreate


def generate_composio_tool_wrapper(action_name: str) -> tuple[str, str]:
# Instantiate the object
tool_instantiation_str = f"composio_toolset.get_tools(actions=['{action_name}'])[0]"
# TODO: This is kind of hacky, as this is used to search up the action later on composio's side
# TODO: So be very careful changing/removing these pair of functions
def generate_func_name_from_composio_action(action_name: str) -> str:
"""
Generates the composio function name from the composio action.
Args:
action_name: The composio action name
Returns:
function name
"""
return action_name.lower()


def generate_composio_action_from_func_name(func_name: str) -> str:
"""
Generates the composio action from the composio function name.
Args:
func_name: The composio function name
Returns:
composio action name
"""
return func_name.upper()


def generate_composio_tool_wrapper(action_name: str) -> tuple[str, str]:
# Generate func name
func_name = action_name.lower()
func_name = generate_func_name_from_composio_action(action_name)

wrapper_function_str = f"""
def {func_name}(**kwargs):
Expand Down
55 changes: 55 additions & 0 deletions letta/functions/schema_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import warnings
from typing import Any, Dict, List, Optional, Type, Union, get_args, get_origin

from composio.client.collections import ActionParametersModel
from docstring_parser import parse
from pydantic import BaseModel

Expand Down Expand Up @@ -429,3 +430,57 @@ def generate_schema_from_args_schema_v2(
function_call_json["parameters"]["required"].append("request_heartbeat")

return function_call_json


def generate_tool_schema_for_composio(
parameters_model: ActionParametersModel,
name: str,
description: str,
append_heartbeat: bool = True,
) -> Dict[str, Any]:
properties_json = {}
required_fields = parameters_model.required or []

# Extract properties from the ActionParametersModel
for field_name, field_props in parameters_model.properties.items():
# Initialize the property structure
property_schema = {
"type": field_props["type"],
"description": field_props.get("description", ""),
}

# Handle optional default values
if "default" in field_props:
property_schema["default"] = field_props["default"]

# Handle enumerations
if "enum" in field_props:
property_schema["enum"] = field_props["enum"]

# Handle array item types
if field_props["type"] == "array" and "items" in field_props:
property_schema["items"] = field_props["items"]

# Add the property to the schema
properties_json[field_name] = property_schema

# Add the optional heartbeat parameter
if append_heartbeat:
properties_json["request_heartbeat"] = {
"type": "boolean",
"description": "Request an immediate heartbeat after function execution. Set to `True` if you want to send a follow-up message or run a follow-up function.",
}
required_fields.append("request_heartbeat")

# Return the final schema
return {
"name": name,
"description": description,
"strict": True,
"parameters": {
"type": "object",
"properties": properties_json,
"additionalProperties": False,
"required": required_fields,
},
}
1 change: 1 addition & 0 deletions letta/orm/enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ class ToolType(str, Enum):
LETTA_CORE = "letta_core"
LETTA_MEMORY_CORE = "letta_memory_core"
LETTA_MULTI_AGENT_CORE = "letta_multi_agent_core"
EXTERNAL_COMPOSIO = "external_composio"


class JobType(str, Enum):
Expand Down
81 changes: 38 additions & 43 deletions letta/schemas/tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,14 @@
LETTA_MULTI_AGENT_TOOL_MODULE_NAME,
)
from letta.functions.functions import derive_openai_json_schema, get_json_schema_from_module
from letta.functions.helpers import generate_composio_tool_wrapper, generate_langchain_tool_wrapper
from letta.functions.schema_generator import generate_schema_from_args_schema_v2
from letta.functions.helpers import generate_composio_action_from_func_name, generate_composio_tool_wrapper, generate_langchain_tool_wrapper
from letta.functions.schema_generator import generate_schema_from_args_schema_v2, generate_tool_schema_for_composio
from letta.log import get_logger
from letta.orm.enums import ToolType
from letta.schemas.letta_base import LettaBase

logger = get_logger(__name__)


class BaseTool(LettaBase):
__id_prefix__ = "tool"
Expand Down Expand Up @@ -52,14 +55,16 @@ class Tool(BaseTool):
last_updated_by_id: Optional[str] = Field(None, description="The id of the user that made this Tool.")

@model_validator(mode="after")
def populate_missing_fields(self):
def refresh_source_code_and_json_schema(self):
"""
Populate missing fields: name, description, and json_schema.
Refresh name, description, source_code, and json_schema.
"""
if self.tool_type == ToolType.CUSTOM:
# If it's a custom tool, we need to ensure source_code is present
if not self.source_code:
raise ValueError(f"Custom tool with id={self.id} is missing source_code field.")
error_msg = f"Custom tool with id={self.id} is missing source_code field."
logger.error(error_msg)
raise ValueError(error_msg)

# Always derive json_schema for freshest possible json_schema
# TODO: Instead of checking the tag, we should having `COMPOSIO` as a specific ToolType
Expand All @@ -72,6 +77,24 @@ def populate_missing_fields(self):
elif self.tool_type in {ToolType.LETTA_MULTI_AGENT_CORE}:
# If it's letta multi-agent tool, we also generate the json_schema on the fly here
self.json_schema = get_json_schema_from_module(module_name=LETTA_MULTI_AGENT_TOOL_MODULE_NAME, function_name=self.name)
elif self.tool_type == ToolType.EXTERNAL_COMPOSIO:
# If it is a composio tool, we generate both the source code and json schema on the fly here
# TODO: This is brittle, need to think long term about how to improve this
try:
composio_action = generate_composio_action_from_func_name(self.name)
tool_create = ToolCreate.from_composio(composio_action)
self.source_code = tool_create.source_code
self.json_schema = tool_create.json_schema
self.description = tool_create.description
self.tags = tool_create.tags
except Exception as e:
logger.error(f"Encountered exception while attempting to refresh source_code and json_schema for composio_tool: {e}")

# At this point, we need to validate that at least json_schema is populated
if not self.json_schema:
error_msg = f"Tool with id={self.id} name={self.name} tool_type={self.tool_type} is missing a json_schema."
logger.error(error_msg)
raise ValueError(error_msg)

# Derive name from the JSON schema if not provided
if not self.name:
Expand Down Expand Up @@ -100,7 +123,7 @@ class ToolCreate(LettaBase):
return_char_limit: int = Field(FUNCTION_RETURN_CHAR_LIMIT, description="The maximum number of characters in the response.")

@classmethod
def from_composio(cls, action_name: str, api_key: Optional[str] = None) -> "ToolCreate":
def from_composio(cls, action_name: str) -> "ToolCreate":
"""
Class method to create an instance of Letta-compatible Composio Tool.
Check https://docs.composio.dev/introduction/intro/overview to look at options for from_composio
Expand All @@ -115,24 +138,21 @@ def from_composio(cls, action_name: str, api_key: Optional[str] = None) -> "Tool
from composio import LogLevel
from composio_langchain import ComposioToolSet

if api_key:
# Pass in an external API key
composio_toolset = ComposioToolSet(logging_level=LogLevel.ERROR, api_key=api_key)
else:
# Use environmental variable
composio_toolset = ComposioToolSet(logging_level=LogLevel.ERROR)
composio_tools = composio_toolset.get_tools(actions=[action_name])
composio_toolset = ComposioToolSet(logging_level=LogLevel.ERROR)
composio_action_schemas = composio_toolset.get_action_schemas(actions=[action_name], check_connected_accounts=False)

assert len(composio_tools) > 0, "User supplied parameters do not match any Composio tools"
assert len(composio_tools) == 1, f"User supplied parameters match too many Composio tools; {len(composio_tools)} > 1"
assert len(composio_action_schemas) > 0, "User supplied parameters do not match any Composio tools"
assert (
len(composio_action_schemas) == 1
), f"User supplied parameters match too many Composio tools; {len(composio_action_schemas)} > 1"

composio_tool = composio_tools[0]
composio_action_schema = composio_action_schemas[0]

description = composio_tool.description
description = composio_action_schema.description
source_type = "python"
tags = [COMPOSIO_TOOL_TAG_NAME]
wrapper_func_name, wrapper_function_str = generate_composio_tool_wrapper(action_name)
json_schema = generate_schema_from_args_schema_v2(composio_tool.args_schema, name=wrapper_func_name, description=description)
json_schema = generate_tool_schema_for_composio(composio_action_schema.parameters, name=wrapper_func_name, description=description)

return cls(
name=wrapper_func_name,
Expand Down Expand Up @@ -175,31 +195,6 @@ def from_langchain(
json_schema=json_schema,
)

@classmethod
def load_default_langchain_tools(cls) -> List["ToolCreate"]:
# For now, we only support wikipedia tool
from langchain_community.tools import WikipediaQueryRun
from langchain_community.utilities import WikipediaAPIWrapper

wikipedia_tool = ToolCreate.from_langchain(
WikipediaQueryRun(api_wrapper=WikipediaAPIWrapper()), {"langchain_community.utilities": "WikipediaAPIWrapper"}
)

return [wikipedia_tool]

@classmethod
def load_default_composio_tools(cls) -> List["ToolCreate"]:
pass

# TODO: Disable composio tools for now
# TODO: Naming is causing issues
# calculator = ToolCreate.from_composio(action_name=Action.MATHEMATICAL_CALCULATOR.name)
# serp_news = ToolCreate.from_composio(action_name=Action.SERPAPI_NEWS_SEARCH.name)
# serp_google_search = ToolCreate.from_composio(action_name=Action.SERPAPI_SEARCH.name)
# serp_google_maps = ToolCreate.from_composio(action_name=Action.SERPAPI_GOOGLE_MAPS_SEARCH.name)

return []


class ToolUpdate(LettaBase):
description: Optional[str] = Field(None, description="The description of the tool.")
Expand Down
5 changes: 2 additions & 3 deletions letta/server/rest_api/routers/v1/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,11 +220,10 @@ def add_composio_tool(
Add a new Composio tool by action name (Composio refers to each tool as an `Action`)
"""
actor = server.user_manager.get_user_or_default(user_id=user_id)
composio_api_key = get_composio_key(server, actor=actor)

try:
tool_create = ToolCreate.from_composio(action_name=composio_action_name, api_key=composio_api_key)
return server.tool_manager.create_or_update_tool(pydantic_tool=Tool(**tool_create.model_dump()), actor=actor)
tool_create = ToolCreate.from_composio(action_name=composio_action_name)
return server.tool_manager.create_or_update_composio_tool(pydantic_tool=Tool(**tool_create.model_dump()), actor=actor)
except EnumStringNotFound as e:
raise HTTPException(
status_code=400, # Bad Request
Expand Down
5 changes: 5 additions & 0 deletions letta/services/tool_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,11 @@ def create_or_update_tool(self, pydantic_tool: PydanticTool, actor: PydanticUser

return tool

@enforce_types
def create_or_update_composio_tool(self, pydantic_tool: PydanticTool, actor: PydanticUser) -> PydanticTool:
pydantic_tool.tool_type = ToolType.EXTERNAL_COMPOSIO
return self.create_or_update_tool(pydantic_tool, actor)

@enforce_types
def create_tool(self, pydantic_tool: PydanticTool, actor: PydanticUser) -> PydanticTool:
"""Create a new tool based on the ToolCreate schema."""
Expand Down
4 changes: 2 additions & 2 deletions tests/integration_test_tool_execution_sandbox.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,15 +183,15 @@ def create_list():
def composio_github_star_tool(test_user):
tool_manager = ToolManager()
tool_create = ToolCreate.from_composio(action_name="GITHUB_STAR_A_REPOSITORY_FOR_THE_AUTHENTICATED_USER")
tool = tool_manager.create_or_update_tool(pydantic_tool=Tool(**tool_create.model_dump()), actor=test_user)
tool = tool_manager.create_or_update_composio_tool(pydantic_tool=Tool(**tool_create.model_dump()), actor=test_user)
yield tool


@pytest.fixture
def composio_gmail_get_profile_tool(test_user):
tool_manager = ToolManager()
tool_create = ToolCreate.from_composio(action_name="GMAIL_GET_PROFILE")
tool = tool_manager.create_or_update_tool(pydantic_tool=Tool(**tool_create.model_dump()), actor=test_user)
tool = tool_manager.create_or_update_composio_tool(pydantic_tool=Tool(**tool_create.model_dump()), actor=test_user)
yield tool


Expand Down
17 changes: 16 additions & 1 deletion tests/test_managers.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@
from letta.schemas.source import Source as PydanticSource
from letta.schemas.source import SourceUpdate
from letta.schemas.tool import Tool as PydanticTool
from letta.schemas.tool import ToolUpdate
from letta.schemas.tool import ToolCreate, ToolUpdate
from letta.schemas.tool_rule import InitToolRule
from letta.schemas.user import User as PydanticUser
from letta.schemas.user import UserUpdate
Expand Down Expand Up @@ -195,6 +195,13 @@ def print_tool(message: str):
yield tool


@pytest.fixture
def composio_github_star_tool(server, default_user):
tool_create = ToolCreate.from_composio(action_name="GITHUB_STAR_A_REPOSITORY_FOR_THE_AUTHENTICATED_USER")
tool = server.tool_manager.create_or_update_composio_tool(pydantic_tool=PydanticTool(**tool_create.model_dump()), actor=default_user)
yield tool


@pytest.fixture
def default_job(server: SyncServer, default_user):
"""Fixture to create and return a default job."""
Expand Down Expand Up @@ -1548,6 +1555,14 @@ def test_create_tool(server: SyncServer, print_tool, default_user, default_organ
# Assertions to ensure the created tool matches the expected values
assert print_tool.created_by_id == default_user.id
assert print_tool.organization_id == default_organization.id
assert print_tool.tool_type == ToolType.CUSTOM


def test_create_composio_tool(server: SyncServer, composio_github_star_tool, default_user, default_organization):
# Assertions to ensure the created tool matches the expected values
assert composio_github_star_tool.created_by_id == default_user.id
assert composio_github_star_tool.organization_id == default_organization.id
assert composio_github_star_tool.tool_type == ToolType.EXTERNAL_COMPOSIO


@pytest.mark.skipif(USING_SQLITE, reason="Test not applicable when using SQLite.")
Expand Down
Loading

0 comments on commit cc8f93c

Please sign in to comment.