Skip to content

Commit

Permalink
feat: support importing tools from LangChain (#1745)
Browse files Browse the repository at this point in the history
Co-authored-by: Charles Packer <[email protected]>
  • Loading branch information
sarahwooders and cpacker authored Sep 11, 2024
1 parent 7b830a8 commit 459c367
Show file tree
Hide file tree
Showing 3 changed files with 81 additions and 4 deletions.
4 changes: 4 additions & 0 deletions memgpt/client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -1721,6 +1721,10 @@ def send_message(
messages = self.interface.to_list()
for m in messages:
assert isinstance(m, Message), f"Expected Message object, got {type(m)}"
memgpt_messages = []
for m in messages:
memgpt_messages += m.to_memgpt_message()
return MemGPTResponse(messages=memgpt_messages, usage=usage)

# format messages
if include_full_message:
Expand Down
35 changes: 33 additions & 2 deletions memgpt/functions/schema_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,15 @@ def generate_schema_from_args_schema(
properties = {}
required = []
for field_name, field in args_schema.__fields__.items():
properties[field_name] = {"type": field.type_.__name__, "description": field.field_info.description}
if field.type_.__name__ == "str":
field_type = "string"
elif field.type_.__name__ == "int":
field_type = "integer"
elif field.type_.__name__ == "bool":
field_type = "boolean"
else:
field_type = field.type_.__name__
properties[field_name] = {"type": field_type, "description": field.field_info.description}
if field.required:
required.append(field_name)

Expand All @@ -158,7 +166,28 @@ def generate_schema_from_args_schema(
return function_call_json


def generate_tool_wrapper(tool_name: str) -> str:
def generate_langchain_tool_wrapper(tool_name: str) -> str:
import_statement = f"from langchain_community.tools import {tool_name}"

# NOTE: this will fail for tools like 'wikipedia = WikipediaQueryRun(api_wrapper=WikipediaAPIWrapper())' since it needs to pass an argument to the tool instantiation
# https://python.langchain.com/v0.1/docs/integrations/tools/wikipedia/
tool_instantiation = f"tool = {tool_name}()"
run_call = f"return tool._run(**kwargs)"
func_name = f"run_{tool_name.lower()}"

# Combine all parts into the wrapper function
wrapper_function_str = f"""
def {func_name}(**kwargs):
if 'self' in kwargs:
del kwargs['self']
{import_statement}
{tool_instantiation}
{run_call}
"""
return func_name, wrapper_function_str


def generate_crewai_tool_wrapper(tool_name: str) -> str:
import_statement = f"from crewai_tools import {tool_name}"
tool_instantiation = f"tool = {tool_name}()"
run_call = f"return tool._run(**kwargs)"
Expand All @@ -167,6 +196,8 @@ def generate_tool_wrapper(tool_name: str) -> str:
# Combine all parts into the wrapper function
wrapper_function_str = f"""
def {func_name}(**kwargs):
if 'self' in kwargs:
del kwargs['self']
{import_statement}
{tool_instantiation}
{run_call}
Expand Down
46 changes: 44 additions & 2 deletions memgpt/schemas/tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,9 @@
from pydantic import Field

from memgpt.functions.schema_generator import (
generate_crewai_tool_wrapper,
generate_langchain_tool_wrapper,
generate_schema_from_args_schema,
generate_tool_wrapper,
)
from memgpt.schemas.memgpt_base import MemGPTBase
from memgpt.schemas.openai.chat_completions import ToolCall
Expand Down Expand Up @@ -56,6 +57,40 @@ def to_dict(self):
)
)

@classmethod
def from_langchain(cls, langchain_tool) -> "Tool":
"""
Class method to create an instance of Tool from a Langchain tool (must be from langchain_community.tools).
Args:
langchain_tool (LangchainTool): An instance of a crewAI BaseTool (BaseTool from crewai)
Returns:
Tool: A MemGPT Tool initialized with attributes derived from the provided crewAI BaseTool object.
"""
description = langchain_tool.description
source_type = "python"
tags = ["langchain"]
# NOTE: langchain tools may come from different packages
wrapper_func_name, wrapper_function_str = generate_langchain_tool_wrapper(langchain_tool.__class__.__name__)
json_schema = generate_schema_from_args_schema(langchain_tool.args_schema, name=wrapper_func_name, description=description)

# append heartbeat (necessary for triggering another reasoning step after this tool call)
json_schema["parameters"]["properties"]["request_heartbeat"] = {
"type": "boolean",
"description": "Request an immediate heartbeat after function execution. Set to 'true' if you want to send a follow-up message or run a follow-up function.",
}
json_schema["parameters"]["required"].append("request_heartbeat")

return cls(
name=wrapper_func_name,
description=description,
source_type=source_type,
tags=tags,
source_code=wrapper_function_str,
json_schema=json_schema,
)

@classmethod
def from_crewai(cls, crewai_tool) -> "Tool":
"""
Expand All @@ -71,9 +106,16 @@ def from_crewai(cls, crewai_tool) -> "Tool":
description = crewai_tool.description
source_type = "python"
tags = ["crew-ai"]
wrapper_func_name, wrapper_function_str = generate_tool_wrapper(crewai_tool.__class__.__name__)
wrapper_func_name, wrapper_function_str = generate_crewai_tool_wrapper(crewai_tool.__class__.__name__)
json_schema = generate_schema_from_args_schema(crewai_tool.args_schema, name=wrapper_func_name, description=description)

# append heartbeat (necessary for triggering another reasoning step after this tool call)
json_schema["parameters"]["properties"]["request_heartbeat"] = {
"type": "boolean",
"description": "Request an immediate heartbeat after function execution. Set to 'true' if you want to send a follow-up message or run a follow-up function.",
}
json_schema["parameters"]["required"].append("request_heartbeat")

return cls(
name=wrapper_func_name,
description=description,
Expand Down

0 comments on commit 459c367

Please sign in to comment.