From eebad33be5ff96a04a9eec024a767d2eb2facd7d Mon Sep 17 00:00:00 2001
From: Moshe Malawachh
Date: Thu, 9 Jan 2025 15:20:43 +0100
Subject: [PATCH] support passing a session to generate_answer

---
 libertai_agents/libertai_agents/agents.py | 64 ++++++++++++-----------
 1 file changed, 34 insertions(+), 30 deletions(-)

diff --git a/libertai_agents/libertai_agents/agents.py b/libertai_agents/libertai_agents/agents.py
index e9910cf..f14607e 100644
--- a/libertai_agents/libertai_agents/agents.py
+++ b/libertai_agents/libertai_agents/agents.py
@@ -90,6 +90,7 @@ async def generate_answer(
         messages: list[Message],
         only_final_answer: bool = True,
         system_prompt: str | None = None,
+        session: ClientSession | None = None,
     ) -> AsyncIterable[Message]:
         """
         Generate an answer based on a conversation
@@ -108,42 +109,45 @@ async def generate_answer(
             prompt = self.model.generate_prompt(
                 messages, self.tools, system_prompt=system_prompt or self.system_prompt
             )
-            async with aiohttp.ClientSession() as session:
+            if session is None:
+                async with aiohttp.ClientSession() as session:
+                    response = await self.__call_model(session, prompt)
+            else:
                 response = await self.__call_model(session, prompt)
 
-                if response is None:
-                    # TODO: handle error correctly
-                    raise ValueError("Model didn't respond")
+            if response is None:
+                # TODO: handle error correctly
+                raise ValueError("Model didn't respond")
 
-                tool_calls = self.model.extract_tool_calls_from_response(response)
-                if len(tool_calls) == 0:
-                    yield Message(role=MessageRoleEnum.assistant, content=response)
-                    return
+            tool_calls = self.model.extract_tool_calls_from_response(response)
+            if len(tool_calls) == 0:
+                yield Message(role=MessageRoleEnum.assistant, content=response)
+                return
 
-                # Executing the detected tool calls
-                tool_calls_message = self.__create_tool_calls_message(tool_calls)
-                messages.append(tool_calls_message)
-                if not only_final_answer:
-                    yield tool_calls_message
+            # Executing the detected tool calls
+            tool_calls_message = self.__create_tool_calls_message(tool_calls)
+            messages.append(tool_calls_message)
+            if not only_final_answer:
+                yield tool_calls_message
 
-                executed_calls = self.__execute_tool_calls(
-                    tool_calls_message.tool_calls
+            executed_calls = self.__execute_tool_calls(
+                tool_calls_message.tool_calls
+            )
+            results = await asyncio.gather(*executed_calls)
+            tool_results_messages: list[Message] = [
+                ToolResponseMessage(
+                    role=MessageRoleEnum.tool,
+                    name=call.function.name,
+                    tool_call_id=call.id,
+                    content=str(results[i]),
                 )
-                results = await asyncio.gather(*executed_calls)
-                tool_results_messages: list[Message] = [
-                    ToolResponseMessage(
-                        role=MessageRoleEnum.tool,
-                        name=call.function.name,
-                        tool_call_id=call.id,
-                        content=str(results[i]),
-                    )
-                    for i, call in enumerate(tool_calls_message.tool_calls)
-                ]
-                if not only_final_answer:
-                    for tool_result_message in tool_results_messages:
-                        yield tool_result_message
-                # Doing the next iteration of the loop with the results to make other tool calls or to answer
-                messages = messages + tool_results_messages
+                for i, call in enumerate(tool_calls_message.tool_calls)
+            ]
+            if not only_final_answer:
+                for tool_result_message in tool_results_messages:
+                    yield tool_result_message
+            # Doing the next iteration of the loop with the results to make other tool calls or to answer
+            messages = messages + tool_results_messages
 
     async def __api_generate_answer(
         self,
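
Below is a minimal usage sketch of the new session parameter, showing a caller reusing one aiohttp.ClientSession (and its connection pool) across several generate_answer calls instead of letting each call open its own. The ChatAgent class and the import paths for Message and MessageRoleEnum are illustrative assumptions and are not part of this patch.

import aiohttp

# Assumed import paths; only generate_answer(..., session=...) comes from this patch.
from libertai_agents.agents import ChatAgent
from libertai_agents.interfaces.messages import Message, MessageRoleEnum


async def chat_with_shared_session(agent: ChatAgent, questions: list[str]) -> None:
    # Open one session up front and hand it to every generate_answer call,
    # so the agent does not create a fresh ClientSession per question.
    async with aiohttp.ClientSession() as session:
        for question in questions:
            messages = [Message(role=MessageRoleEnum.user, content=question)]
            async for message in agent.generate_answer(messages, session=session):
                print(message.content)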