Skip to content

Commit

Permalink
support passing a session to generate_answer
Browse files Browse the repository at this point in the history
  • Loading branch information
moshemalawach committed Jan 9, 2025
1 parent b4bf68a commit eebad33
Showing 1 changed file with 34 additions and 30 deletions.
64 changes: 34 additions & 30 deletions libertai_agents/libertai_agents/agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ async def generate_answer(
messages: list[Message],
only_final_answer: bool = True,
system_prompt: str | None = None,
session: ClientSession | None = None,
) -> AsyncIterable[Message]:
"""
Generate an answer based on a conversation
Expand All @@ -108,42 +109,45 @@ async def generate_answer(
prompt = self.model.generate_prompt(
messages, self.tools, system_prompt=system_prompt or self.system_prompt
)
async with aiohttp.ClientSession() as session:
if session is None:
async with aiohttp.ClientSession() as session:
response = await self.__call_model(session, prompt)
else:
response = await self.__call_model(session, prompt)

if response is None:
# TODO: handle error correctly
raise ValueError("Model didn't respond")
if response is None:
# TODO: handle error correctly
raise ValueError("Model didn't respond")

tool_calls = self.model.extract_tool_calls_from_response(response)
if len(tool_calls) == 0:
yield Message(role=MessageRoleEnum.assistant, content=response)
return
tool_calls = self.model.extract_tool_calls_from_response(response)
if len(tool_calls) == 0:
yield Message(role=MessageRoleEnum.assistant, content=response)
return

# Executing the detected tool calls
tool_calls_message = self.__create_tool_calls_message(tool_calls)
messages.append(tool_calls_message)
if not only_final_answer:
yield tool_calls_message
# Executing the detected tool calls
tool_calls_message = self.__create_tool_calls_message(tool_calls)
messages.append(tool_calls_message)
if not only_final_answer:
yield tool_calls_message

executed_calls = self.__execute_tool_calls(
tool_calls_message.tool_calls
executed_calls = self.__execute_tool_calls(
tool_calls_message.tool_calls
)
results = await asyncio.gather(*executed_calls)
tool_results_messages: list[Message] = [
ToolResponseMessage(
role=MessageRoleEnum.tool,
name=call.function.name,
tool_call_id=call.id,
content=str(results[i]),
)
results = await asyncio.gather(*executed_calls)
tool_results_messages: list[Message] = [
ToolResponseMessage(
role=MessageRoleEnum.tool,
name=call.function.name,
tool_call_id=call.id,
content=str(results[i]),
)
for i, call in enumerate(tool_calls_message.tool_calls)
]
if not only_final_answer:
for tool_result_message in tool_results_messages:
yield tool_result_message
# Doing the next iteration of the loop with the results to make other tool calls or to answer
messages = messages + tool_results_messages
for i, call in enumerate(tool_calls_message.tool_calls)
]
if not only_final_answer:
for tool_result_message in tool_results_messages:
yield tool_result_message
# Doing the next iteration of the loop with the results to make other tool calls or to answer
messages = messages + tool_results_messages

async def __api_generate_answer(
self,
Expand Down

0 comments on commit eebad33

Please sign in to comment.