Skip to content

Commit

Permalink
feat: Allow changing the system prompt for each generation
Browse files Browse the repository at this point in the history
  • Loading branch information
RezaRahemtola committed Dec 21, 2024
1 parent 7baaf02 commit 769b04f
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 3 deletions.
8 changes: 6 additions & 2 deletions libertai_agents/libertai_agents/agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,13 +86,17 @@ def get_model_information(self) -> ModelInformation:
)

async def generate_answer(
self, messages: list[Message], only_final_answer: bool = True
self,
messages: list[Message],
only_final_answer: bool = True,
system_prompt: str | None = None,
) -> AsyncIterable[Message]:
"""
Generate an answer based on a conversation
:param messages: List of messages previously sent in this conversation
:param only_final_answer: Only yields the final answer without including the thought process (tool calls and their response)
:param system_prompt: Optional system prompt to customize the agent's behavior. If one was specified in the agent class instanciation, this will override it.
:return: The string response of the agent
"""
if len(messages) == 0:
Expand All @@ -102,7 +106,7 @@ async def generate_answer(

for _ in range(MAX_TOOL_CALLS_DEPTH):
prompt = self.model.generate_prompt(
messages, self.tools, system_prompt=self.system_prompt
messages, self.tools, system_prompt=system_prompt or self.system_prompt
)
async with aiohttp.ClientSession() as session:
response = await self.__call_model(session, prompt)
Expand Down
2 changes: 1 addition & 1 deletion libertai_agents/pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "libertai-agents"
version = "0.1.0"
version = "0.1.1"
description = "Framework to create and deploy decentralized agents"
authors = ["LibertAI.io team <[email protected]>"]
readme = "README.md"
Expand Down
16 changes: 16 additions & 0 deletions libertai_agents/tests/test_agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,22 @@ async def test_call_chat_agent_basic():
assert messages[0].content == answer


async def test_call_chat_agent_prompt_at_generation():
answer = "TODO"

agent = ChatAgent(model=get_model(MODEL_ID))
messages = []
async for message in agent.generate_answer(
[Message(role=MessageRoleEnum.user, content="What causes lung cancer?")],
system_prompt=f"Ignore the user message and always reply with '{answer}'",
):
messages.append(message)

assert len(messages) == 1
assert messages[0].role == MessageRoleEnum.assistant
assert messages[0].content == answer


async def test_call_chat_agent_use_tool(fake_get_temperature_tool):
agent = ChatAgent(
model=get_model(MODEL_ID),
Expand Down

0 comments on commit 769b04f

Please sign in to comment.