From bc1727b5929a547a794347a624801134f0a02c11 Mon Sep 17 00:00:00 2001
From: Zbigniew Lukasiak
Date: Tue, 10 Dec 2024 19:08:00 +0100
Subject: [PATCH] tool loop

---
 examples/tool_loop.py | 32 ++++++++++++++++++++++++++++++++
 prompete/chat.py      | 43 ++++++++++++++++++++++++++++++++++---------
 2 files changed, 66 insertions(+), 9 deletions(-)
 create mode 100644 examples/tool_loop.py

diff --git a/examples/tool_loop.py b/examples/tool_loop.py
new file mode 100644
index 0000000..ee89d81
--- /dev/null
+++ b/examples/tool_loop.py
@@ -0,0 +1,32 @@
+import logging
+from prompete import Chat
+from pprint import pprint
+
+# Configure logging
+logging.basicConfig(
+    level=logging.DEBUG,
+    format='%(asctime)s - %(levelname)s - %(message)s'
+)
+
+def get_current_weather(location: str, unit: str = "celsius") -> dict:
+    """Get the current weather in a given location"""
+    # In a real scenario, you would call an actual weather API here
+    return {
+        "location": location,
+        "temperature": 22,
+        "unit": unit,
+        "forecast": ["sunny", "windy"],
+    }
+
+# Create a Chat instance
+chat = Chat(model="gpt-4o-mini")
+
+# Define the user's question
+user_question = "Please check the weather in London (using the `get_current_weather` function) and suggest an appropriate outfit."
+answer = chat.tool_loop(user_question, max_loops=3, tools=[get_current_weather])
+
+# Print the results
+print("User: ", user_question)
+print("Answer: ", answer)
+
+pprint(chat.messages)
diff --git a/prompete/chat.py b/prompete/chat.py
index 3beb363..67eac0b 100644
--- a/prompete/chat.py
+++ b/prompete/chat.py
@@ -8,6 +8,7 @@ from llm_easy_tools.types import ChatCompletionMessageToolCall
 
 import logging
+import json
 
 
 # Configure logging for this module
@@ -102,6 +103,7 @@ def append(self, message: Union[Prompt, str, dict, Message]) -> None:
         Append a message to the chat.
         """
         message_dict = self.make_message(message)
+        logging.debug(f"Appending message: {message_dict}")
         self.messages.append(message_dict)
 
     def __call__(
@@ -114,7 +116,10 @@ def __call__(
         If the underlying LLM does not support response_format, we emulate it
         by using tools - but this is not perfectly reliable.
""" self.append(message) + response_content = self.get_llm_response(response_format=response_format, **kwargs) + return response_content + def get_llm_response(self, response_format=None, **kwargs) -> str: if response_format: if kwargs.get("tools"): raise ValueError("tools and response_format cannot be used together") @@ -147,13 +152,13 @@ def llm_reply(self, tools=[], strict=False, **kwargs) -> ModelResponse: if len(schemas) > 0: args["tools"] = schemas - if len(schemas) == 1: - args["tool_choice"] = { - "type": "function", - "function": {"name": schemas[0]["function"]["name"]}, - } - else: - args["tool_choice"] = "auto" + #if len(schemas) == 1: + # args["tool_choice"] = { + # "type": "function", + # "function": {"name": schemas[0]["function"]["name"]}, + # } + #else: + args["tool_choice"] = "auto" args.update(kwargs) @@ -215,8 +220,28 @@ def get_tool_calls_message(self) -> Message: """ if not self.messages: return None - message = Message(**self.messages[-1]) - return message if hasattr(message, "tool_calls") else None + dict_message = self.messages[-1] + message = Message(**dict_message) + if hasattr(message, "tool_calls") and message.tool_calls: + return message + else: + return None + + def tool_loop(self, message: Prompt | dict | Message | str, max_loops: int, tools: list[Callable], **kwargs) -> Optional[str]: + """ + Repeatedly call the __call__ method until the LLM response does not contain a tool call + or the maximum number of loops is reached. + """ + response = self.__call__(message, tools=tools, **kwargs) + loop_count = 0 + while loop_count < max_loops: + if not self.get_tool_calls_message(): + return self.messages[-1]['content'] + self.process() + self.get_llm_response(tools=tools, **kwargs) + loop_count += 1 + return None + if __name__ == "__main__": import os