Commit bc1727b: tool loop
zby committed Dec 10, 2024
1 parent 7f44963
Showing 2 changed files with 66 additions and 9 deletions.
examples/tool_loop.py (32 additions, 0 deletions)
@@ -0,0 +1,32 @@
import logging
from prompete import Chat
from pprint import pprint

# Configure logging
logging.basicConfig(
    level=logging.DEBUG,
    format='%(asctime)s - %(levelname)s - %(message)s'
)

def get_current_weather(location: str, unit: str = "celsius") -> dict:
    """Get the current weather in a given location"""
    # In a real scenario, you would call an actual weather API here
    return {
        "location": location,
        "temperature": 22,
        "unit": unit,
        "forecast": ["sunny", "windy"],
    }

# Create a Chat instance
chat = Chat(model="gpt-4o-mini")

# Define the user's question
user_question = "Please check the weather in London (using the `get_current_weather` function) and suggest an appropriate outfit."
answer = chat.tool_loop(user_question, max_loops=3, tools=[get_current_weather])

# Print the results
print("User: ", user_question)
print("Answer: ", answer)

pprint(chat.messages)
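
After a successful run, chat.messages should contain the standard OpenAI-style sequence: user, assistant (with tool_calls), tool, assistant. A minimal sketch for printing that transcript compactly — it assumes the stored messages are plain dicts with role/content/tool_calls keys (which matches how Chat.append stores them in the diff below), and that the tool_calls entries are dict-shaped, which may need adjusting if they turn out to be objects:

def print_transcript(messages: list[dict]) -> None:
    """Print a compact, one-line-per-message view of the chat transcript."""
    for msg in messages:
        role = msg.get("role", "?")
        if msg.get("tool_calls"):
            # Show which tools the assistant asked for instead of raw JSON.
            names = [tc["function"]["name"] for tc in msg["tool_calls"]]
            print(f"{role}: requested {names}")
        else:
            print(f"{role}: {msg.get('content')}")

print_transcript(chat.messages)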
prompete/chat.py (34 additions, 9 deletions)
@@ -8,6 +8,7 @@
from llm_easy_tools.types import ChatCompletionMessageToolCall

import logging
import json


# Configure logging for this module
@@ -102,6 +103,7 @@ def append(self, message: Union[Prompt, str, dict, Message]) -> None:
        Append a message to the chat.
        """
        message_dict = self.make_message(message)
        logging.debug(f"Appending message: {message_dict}")
        self.messages.append(message_dict)

    def __call__(
@@ -114,7 +116,10 @@ def __call__(
        If the underlying LLM does not support response_format, we emulate it by using tools - but this is not perfectly reliable.
        """
        self.append(message)
        response_content = self.get_llm_response(response_format=response_format, **kwargs)
        return response_content

    def get_llm_response(self, response_format=None, **kwargs) -> str:
        if response_format:
            if kwargs.get("tools"):
                raise ValueError("tools and response_format cannot be used together")
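
For context, the emulation mentioned in the docstring above is the standard trick of exposing the desired schema as the only tool and forcing it via tool_choice, so the model's "tool call" arguments become the structured output. A generic sketch of that pattern against a raw OpenAI-style API via litellm (not prompete's actual code path; the WeatherReport schema is made up for illustration):

import json
import litellm

# Hypothetical schema we want the model's reply to conform to.
report_tool = {
    "type": "function",
    "function": {
        "name": "WeatherReport",
        "parameters": {
            "type": "object",
            "properties": {
                "city": {"type": "string"},
                "temperature": {"type": "number"},
            },
            "required": ["city", "temperature"],
        },
    },
}

response = litellm.completion(
    model="gpt-4o-mini",
    messages=[{"role": "user", "content": "What is the weather in London?"}],
    tools=[report_tool],
    # Forcing the single "tool" makes the model emit arguments matching the schema.
    tool_choice={"type": "function", "function": {"name": "WeatherReport"}},
)
# The structured output arrives as the tool call's JSON arguments.
data = json.loads(response.choices[0].message.tool_calls[0].function.arguments)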
@@ -147,13 +152,13 @@ def llm_reply(self, tools=[], strict=False, **kwargs) -> ModelResponse:

        if len(schemas) > 0:
            args["tools"] = schemas
            if len(schemas) == 1:
                args["tool_choice"] = {
                    "type": "function",
                    "function": {"name": schemas[0]["function"]["name"]},
                }
            else:
                args["tool_choice"] = "auto"
            #if len(schemas) == 1:
            #    args["tool_choice"] = {
            #        "type": "function",
            #        "function": {"name": schemas[0]["function"]["name"]},
            #    }
            #else:
            args["tool_choice"] = "auto"

        args.update(kwargs)

@@ -215,8 +220,28 @@ def get_tool_calls_message(self) -> Message:
"""
if not self.messages:
return None
message = Message(**self.messages[-1])
return message if hasattr(message, "tool_calls") else None
dict_message = self.messages[-1]
message = Message(**dict_message)
if hasattr(message, "tool_calls") and message.tool_calls:
return message
else:
return None

    def tool_loop(self, message: Prompt | dict | Message | str, max_loops: int, tools: list[Callable], **kwargs) -> Optional[str]:
        """
        Call the LLM in a loop, executing tool calls between rounds, until a
        response contains no tool call or the maximum number of loops is reached.
        Returns the final response content, or None if the loop limit was exhausted.
        """
        self.__call__(message, tools=tools, **kwargs)
        loop_count = 0
        while loop_count < max_loops:
            if not self.get_tool_calls_message():
                return self.messages[-1]['content']
            self.process()  # execute the pending tool calls; results are appended to the chat
            self.get_llm_response(tools=tools, **kwargs)
            loop_count += 1
        return None


if __name__ == "__main__":
    import os
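
One caller-side note: tool_loop returns Optional[str], so a None result means the model was still requesting tools when max_loops ran out. A small sketch of handling that case, reusing the names from the example file above (the fallback call is an assumption — get_llm_response without tools should yield a plain completion):

answer = chat.tool_loop(user_question, max_loops=3, tools=[get_current_weather])
if answer is None:
    # The loop limit was exhausted while the model kept issuing tool calls;
    # ask once more without tools to force a textual answer.
    answer = chat.get_llm_response()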
