diff --git a/libertai_agents/libertai_agents/agents.py b/libertai_agents/libertai_agents/agents.py index e36079f..e9910cf 100644 --- a/libertai_agents/libertai_agents/agents.py +++ b/libertai_agents/libertai_agents/agents.py @@ -86,13 +86,17 @@ def get_model_information(self) -> ModelInformation: ) async def generate_answer( - self, messages: list[Message], only_final_answer: bool = True + self, + messages: list[Message], + only_final_answer: bool = True, + system_prompt: str | None = None, ) -> AsyncIterable[Message]: """ Generate an answer based on a conversation :param messages: List of messages previously sent in this conversation :param only_final_answer: Only yields the final answer without including the thought process (tool calls and their response) + :param system_prompt: Optional system prompt to customize the agent's behavior. If one was specified in the agent class instantiation, this will override it. :return: The string response of the agent """ if len(messages) == 0: @@ -102,7 +106,7 @@ async def generate_answer( for _ in range(MAX_TOOL_CALLS_DEPTH): prompt = self.model.generate_prompt( - messages, self.tools, system_prompt=self.system_prompt + messages, self.tools, system_prompt=system_prompt or self.system_prompt ) async with aiohttp.ClientSession() as session: response = await self.__call_model(session, prompt) diff --git a/libertai_agents/pyproject.toml b/libertai_agents/pyproject.toml index f8bd744..6d65191 100644 --- a/libertai_agents/pyproject.toml +++ b/libertai_agents/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "libertai-agents" -version = "0.1.0" +version = "0.1.1" description = "Framework to create and deploy decentralized agents" authors = ["LibertAI.io team "] readme = "README.md" diff --git a/libertai_agents/tests/test_agents.py b/libertai_agents/tests/test_agents.py index b1632cf..9874642 100644 --- a/libertai_agents/tests/test_agents.py +++ b/libertai_agents/tests/test_agents.py @@ -74,6 +74,22 @@ async def 
test_call_chat_agent_basic(): assert messages[0].content == answer +async def test_call_chat_agent_prompt_at_generation(): + answer = "TODO" + + agent = ChatAgent(model=get_model(MODEL_ID)) + messages = [] + async for message in agent.generate_answer( + [Message(role=MessageRoleEnum.user, content="What causes lung cancer?")], + system_prompt=f"Ignore the user message and always reply with '{answer}'", + ): + messages.append(message) + + assert len(messages) == 1 + assert messages[0].role == MessageRoleEnum.assistant + assert messages[0].content == answer + + async def test_call_chat_agent_use_tool(fake_get_temperature_tool): agent = ChatAgent( model=get_model(MODEL_ID),