From bc3e518741e7e66d87a1fe8442916d45d41a8a67 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tom=C3=A1=C5=A1=20Dvo=C5=99=C3=A1k?= Date: Tue, 25 Feb 2025 10:01:12 +0100 Subject: [PATCH 1/4] feat: typings and error propagating (#383) Ref: #378 Signed-off-by: Tomas Dvorak --- python/beeai_framework/agents/base.py | 32 +++++++++++++------------- python/beeai_framework/backend/chat.py | 27 ++++++++-------------- 2 files changed, 26 insertions(+), 33 deletions(-) diff --git a/python/beeai_framework/agents/base.py b/python/beeai_framework/agents/base.py index 15dbd2fa..7312a448 100644 --- a/python/beeai_framework/agents/base.py +++ b/python/beeai_framework/agents/base.py @@ -38,24 +38,24 @@ def run(self, run_input: ModelLike[BeeRunInput], options: ModelLike[BeeRunOption if self.is_running: raise RuntimeError("Agent is already running!") - try: - self.is_running = True + self.is_running = True - async def handler(context: RunContext) -> T: + async def handler(context: RunContext) -> T: + try: return await self._run(run_input, options, context) - - return RunContext.enter( - RunInstance(emitter=self.emitter), - RunContextInput(signal=options.signal if options else None, params=(run_input, options)), - handler, - ) - except Exception as e: - if isinstance(e, RuntimeError): - raise e - else: - raise RuntimeError("Error has occurred!") from e - finally: - self.is_running = False + except Exception as e: + if isinstance(e, RuntimeError): + raise e + else: + raise RuntimeError("Error has occurred!") from e + finally: + self.is_running = False + + return RunContext.enter( + RunInstance(emitter=self.emitter), + RunContextInput(signal=options.signal if options else None, params=(run_input, options)), + handler, + ) @abstractmethod async def _run(self, run_input: BeeRunInput, options: BeeRunOptions | None, context: RunContext) -> T: diff --git a/python/beeai_framework/backend/chat.py b/python/beeai_framework/backend/chat.py index 1d2b8f35..76aef256 100644 --- a/python/beeai_framework/backend/chat.py +++ b/python/beeai_framework/backend/chat.py @@ -18,7 +18,7 @@ from collections.abc import AsyncGenerator, Callable from typing import Annotated, Any, Literal, Self, TypeVar -from pydantic import BaseModel, BeforeValidator, ConfigDict, Field, ValidationError +from pydantic import BaseModel, BeforeValidator, ConfigDict, Field, InstanceOf, ValidationError from beeai_framework.backend.constants import ProviderName from beeai_framework.backend.errors import ChatModelError @@ -95,17 +95,10 @@ class ChatModelUsage(BaseModel): total_tokens: int -class ChatModelOutput: - def __init__( - self, - *, - messages: list[Message], - usage: ChatModelUsage | None = None, - finish_reason: str | None = None, - ) -> None: - self.messages = messages - self.usage = usage - self.finish_reason = finish_reason +class ChatModelOutput(BaseModel): + messages: list[InstanceOf[Message]] + usage: InstanceOf[ChatModelUsage] | None = None + finish_reason: str | None = None @classmethod def from_chunks(cls, chunks: list) -> Self: @@ -211,7 +204,7 @@ class DefaultChatModelStructureSchema(BaseModel): # TODO: validate result matches expected schema return ChatModelStructureOutput(object=result) - def create(self, chat_model_input: ModelLike[ChatModelInput]) -> Run: + def create(self, chat_model_input: ModelLike[ChatModelInput]) -> Run[ChatModelOutput]: input = to_model(ChatModelInput, chat_model_input) async def run_create(context: RunContext) -> ChatModelOutput: @@ -235,11 +228,11 @@ async def run_create(context: RunContext) -> ChatModelOutput: await 
context.emitter.emit("success", {"value": result}) return result except ChatModelError as error: - await context.emitter.emit("error", {input, error}) + await context.emitter.emit("error", error) raise error - except Exception as ex: - await context.emitter.emit("error", {input, ex}) - raise ChatModelError("Model error has occurred.") from ex + except Exception as error: + await context.emitter.emit("error", error) + raise ChatModelError("Model error has occurred.") from error finally: await context.emitter.emit("finish", None) From 5d538f4bde3421c4e5fbda1198ec8bfbc768d97b Mon Sep 17 00:00:00 2001 From: Alex Bozarth Date: Tue, 25 Feb 2025 08:34:41 -0600 Subject: [PATCH 2/4] fix: update some type hints (#381) --- python/beeai_framework/agents/bee/agent.py | 5 +++-- python/beeai_framework/context.py | 4 ++-- python/beeai_framework/emitter/emitter.py | 12 +++++++----- python/beeai_framework/utils/{_types.py => types.py} | 0 python/beeai_framework/workflows/workflow.py | 2 +- 5 files changed, 13 insertions(+), 10 deletions(-) rename python/beeai_framework/utils/{_types.py => types.py} (100%) diff --git a/python/beeai_framework/agents/bee/agent.py b/python/beeai_framework/agents/bee/agent.py index 3f0b51a9..8fc0b9f6 100644 --- a/python/beeai_framework/agents/bee/agent.py +++ b/python/beeai_framework/agents/bee/agent.py @@ -38,13 +38,14 @@ from beeai_framework.context import RunContext from beeai_framework.emitter import Emitter, EmitterInput from beeai_framework.memory import BaseMemory +from beeai_framework.utils.models import ModelLike, to_model class BeeAgent(BaseAgent[BeeRunOutput]): runner: Callable[..., BaseRunner] - def __init__(self, bee_input: BeeInput) -> None: - self.input = bee_input + def __init__(self, bee_input: ModelLike[BeeInput]) -> None: + self.input = to_model(BeeInput, bee_input) if "granite" in self.input.llm.model_id: self.runner = GraniteRunner else: diff --git a/python/beeai_framework/context.py b/python/beeai_framework/context.py index d36ff8e8..8c116c1b 100644 --- a/python/beeai_framework/context.py +++ b/python/beeai_framework/context.py @@ -58,7 +58,7 @@ def observe(self, fn: Callable[[Emitter], Any]) -> Self: self.tasks.append((fn, self.run_context.emitter)) return self - def context(self, context: "RunContext") -> Self: + def context(self, context: dict) -> Self: self.tasks.append((self._set_context, context)) return self @@ -73,7 +73,7 @@ async def _run_tasks(self) -> R: self.tasks.clear() return await self.handler() - def _set_context(self, context: "RunContext") -> None: + def _set_context(self, context: dict) -> None: self.run_context.context = context self.run_context.emitter.context = context diff --git a/python/beeai_framework/emitter/emitter.py b/python/beeai_framework/emitter/emitter.py index fe6e6cee..b7b78902 100644 --- a/python/beeai_framework/emitter/emitter.py +++ b/python/beeai_framework/emitter/emitter.py @@ -20,7 +20,7 @@ import uuid from collections.abc import Callable from datetime import UTC, datetime -from typing import Any, Generic, TypeVar +from typing import Any, Generic, ParamSpec, TypeAlias, TypeVar from pydantic import BaseModel, ConfigDict, InstanceOf @@ -30,13 +30,15 @@ assert_valid_name, assert_valid_namespace, ) +from beeai_framework.utils.types import MaybeAsync T = TypeVar("T", bound=BaseModel) +P = ParamSpec("P") -MatcherFn = Callable[["EventMeta"], bool] -Matcher = str | MatcherFn -Callback = Callable[[Any, "EventMeta"], Any] -CleanupFn = Callable[[], None] +MatcherFn: TypeAlias = Callable[["EventMeta"], bool] +Matcher: TypeAlias 
= str | MatcherFn
+Callback: TypeAlias = MaybeAsync[[P, "EventMeta"], None]
+CleanupFn: TypeAlias = Callable[[], None]
 
 
 class Listener(BaseModel):
diff --git a/python/beeai_framework/utils/_types.py b/python/beeai_framework/utils/types.py
similarity index 100%
rename from python/beeai_framework/utils/_types.py
rename to python/beeai_framework/utils/types.py
diff --git a/python/beeai_framework/workflows/workflow.py b/python/beeai_framework/workflows/workflow.py
index 6049071a..c1b1e2e2 100644
--- a/python/beeai_framework/workflows/workflow.py
+++ b/python/beeai_framework/workflows/workflow.py
@@ -20,8 +20,8 @@
 from pydantic import BaseModel
 from typing_extensions import TypeVar
 
-from beeai_framework.utils._types import MaybeAsync
 from beeai_framework.utils.models import ModelLike, check_model, to_model
+from beeai_framework.utils.types import MaybeAsync
 from beeai_framework.workflows.errors import WorkflowError
 
 T = TypeVar("T", bound=BaseModel)
 
From 91cbbb9dbbd1396761df6fd2ad15a4359b61a626 Mon Sep 17 00:00:00 2001
From: Graham White
Date: Tue, 25 Feb 2025 16:43:36 +0000
Subject: [PATCH 3/4] fix: ability to specify external Ollama server (#389)

Closes: #387
---
 python/.env.example                                  |  7 +++++++
 .../beeai_framework/adapters/ollama/backend/chat.py  |  7 ++++++-
 python/tests/backend/test_chatmodel.py               | 12 +++++++++++-
 3 files changed, 24 insertions(+), 2 deletions(-)

diff --git a/python/.env.example b/python/.env.example
index b93a37b2..9499e42f 100644
--- a/python/.env.example
+++ b/python/.env.example
@@ -17,3 +17,10 @@ BEEAI_LOG_LEVEL=INFO
 # WATSONX_URL=your-watsonx-instance-base-url
 # WATSONX_PROJECT_ID=your-watsonx-project-id
 # WATSONX_APIKEY=your-watsonx-api-key
+
+########################
+### Ollama specific configuration
+########################
+
+# OLLAMA_BASE_URL=http://localhost:11434
+# OLLAMA_CHAT_MODEL=llama3.1:8b
diff --git a/python/beeai_framework/adapters/ollama/backend/chat.py b/python/beeai_framework/adapters/ollama/backend/chat.py
index 1ab3c902..a3d42a48 100644
--- a/python/beeai_framework/adapters/ollama/backend/chat.py
+++ b/python/beeai_framework/adapters/ollama/backend/chat.py
@@ -28,7 +28,12 @@ def provider_id(self) -> ProviderName:
         return "ollama"
 
     def __init__(self, model_id: str | None = None, settings: dict | None = None) -> None:
+        _settings = settings.copy() if settings is not None else {}
+
+        if "base_url" not in _settings and "OLLAMA_BASE_URL" in os.environ:
+            _settings["base_url"] = os.getenv("OLLAMA_BASE_URL")
+
         super().__init__(
             model_id if model_id else os.getenv("OLLAMA_CHAT_MODEL", "llama3.1:8b"),
-            settings={"base_url": "http://localhost:11434"} | (settings or {}),
+            settings={"base_url": "http://localhost:11434"} | (_settings or {}),
         )
diff --git a/python/tests/backend/test_chatmodel.py b/python/tests/backend/test_chatmodel.py
index 08477163..93e00534 100644
--- a/python/tests/backend/test_chatmodel.py
+++ b/python/tests/backend/test_chatmodel.py
@@ -14,6 +14,7 @@
 
 
 import asyncio
+import os
 from collections.abc import AsyncGenerator
 
 import pytest
@@ -129,8 +130,17 @@ async def test_chat_model_abort(reverse_words_chat: ChatModel, chat_messages_lis
 
 @pytest.mark.unit
 def test_chat_model_from() -> None:
-    ollama_chat_model = ChatModel.from_name("ollama:llama3.1")
+    # Ollama with Llama model and base_url specified in code
+    os.environ.pop("OLLAMA_BASE_URL", None)
+    ollama_chat_model = ChatModel.from_name("ollama:llama3.1", {"base_url": "http://somewhere:12345"})
     assert isinstance(ollama_chat_model, OllamaChatModel)
+    assert 
ollama_chat_model.settings["base_url"] == "http://somewhere:12345" + + # Ollama with Granite model and base_url specified in env var + os.environ["OLLAMA_BASE_URL"] = "http://somewhere-else:12345" + ollama_chat_model = ChatModel.from_name("ollama:granite3.1-dense:8b") + assert isinstance(ollama_chat_model, OllamaChatModel) + assert ollama_chat_model.settings["base_url"] == "http://somewhere-else:12345" watsonx_chat_model = ChatModel.from_name("watsonx:ibm/granite-3-8b-instruct") assert isinstance(watsonx_chat_model, WatsonxChatModel) From 765a4bad3e0187615d6b515dd2554dea3af8e10b Mon Sep 17 00:00:00 2001 From: //va Date: Tue, 25 Feb 2025 12:13:46 -0500 Subject: [PATCH 4/4] feat: retryable implementation (#363) * feat: initial retryable implementation Signed-off-by: va * fix: feedback review Signed-off-by: va --------- Signed-off-by: va --- python/beeai_framework/agents/runners/base.py | 4 +- .../agents/runners/default/prompts.py | 10 + .../agents/runners/default/runner.py | 237 ++++++++++++------ .../agents/runners/granite/prompts.py | 7 + .../agents/runners/granite/runner.py | 2 + python/beeai_framework/agents/types.py | 2 +- python/beeai_framework/backend/chat.py | 37 ++- python/beeai_framework/cancellation.py | 30 ++- python/beeai_framework/errors.py | 7 +- python/beeai_framework/retryable.py | 222 ++++++++++++++++ python/examples/backend/providers/ollama.py | 17 +- python/examples/backend/providers/watsonx.py | 20 +- python/tests/test_retryable.py | 131 ++++++++++ 13 files changed, 618 insertions(+), 108 deletions(-) create mode 100644 python/beeai_framework/retryable.py create mode 100644 python/tests/test_retryable.py diff --git a/python/beeai_framework/agents/runners/base.py b/python/beeai_framework/agents/runners/base.py index 0c5d3c9b..0e6c5f3f 100644 --- a/python/beeai_framework/agents/runners/base.py +++ b/python/beeai_framework/agents/runners/base.py @@ -14,6 +14,7 @@ from abc import ABC, abstractmethod +from collections.abc import Awaitable from dataclasses import dataclass from beeai_framework.agents.types import ( @@ -113,14 +114,13 @@ async def create_iteration(self) -> RunnerIteration: BeeRunnerLLMInput(emitter=emitter, signal=self._run.signal, meta=meta) ) self._iterations.append(iteration) - return RunnerIteration(emitter=emitter, state=iteration.state, meta=meta, signal=self._run.signal) async def init(self, input: BeeRunInput) -> None: self._memory = await self.init_memory(input) @abstractmethod - async def llm(self, input: BeeRunnerLLMInput) -> BeeAgentRunIteration: + async def llm(self, input: BeeRunnerLLMInput) -> Awaitable[BeeAgentRunIteration]: pass @abstractmethod diff --git a/python/beeai_framework/agents/runners/default/prompts.py b/python/beeai_framework/agents/runners/default/prompts.py index c0c1c9c0..d7069680 100644 --- a/python/beeai_framework/agents/runners/default/prompts.py +++ b/python/beeai_framework/agents/runners/default/prompts.py @@ -49,6 +49,10 @@ class ToolInputErrorTemplateInput(BaseModel): reason: str +class SchemaErrorTemplateInput(BaseModel): + pass + + UserPromptTemplate = PromptTemplate(schema=UserPromptTemplateInput, template="Message: {{input}}") AssistantPromptTemplate = PromptTemplate( @@ -150,3 +154,9 @@ class ToolInputErrorTemplateInput(BaseModel): schema=AssistantPromptTemplateInput, template="""{{#thought}}Thought: {{&.}}\n{{/thought}}{{#tool_name}}Function Name: {{&.}}\n{{/tool_name}}{{#tool_input}}Function Input: {{&.}}\n{{/tool_input}}{{#tool_output}}Function Output: {{&.}}\n{{/tool_output}}{{#final_answer}}Final Answer: 
{{&.}}{{/final_answer}}""",  # noqa: E501
 )
+
+SchemaErrorTemplate = PromptTemplate(
+    schema=SchemaErrorTemplateInput,
+    template="""Error: The generated response does not adhere to the communication structure mentioned in the system prompt.
+You communicate only in instruction lines. Valid instruction lines are 'Thought' followed by either 'Function Name' + 'Function Input' + 'Function Output' or 'Final Answer'.""",  # noqa: E501
+)
diff --git a/python/beeai_framework/agents/runners/default/runner.py b/python/beeai_framework/agents/runners/default/runner.py
index bb4b1e14..18a3ed35 100644
--- a/python/beeai_framework/agents/runners/default/runner.py
+++ b/python/beeai_framework/agents/runners/default/runner.py
@@ -13,7 +13,7 @@
 # limitations under the License.
 
 import json
-from collections.abc import Callable
+from collections.abc import Awaitable, Callable
 
 from beeai_framework.agents.runners.base import (
     BaseRunner,
@@ -23,6 +23,8 @@
 )
 from beeai_framework.agents.runners.default.prompts import (
     AssistantPromptTemplate,
+    SchemaErrorTemplate,
+    SchemaErrorTemplateInput,
     SystemPromptTemplate,
     SystemPromptTemplateInput,
     ToolDefinition,
@@ -37,12 +39,19 @@
     BeeRunInput,
 )
 from beeai_framework.backend.chat import ChatModelInput, ChatModelOutput
-from beeai_framework.backend.message import SystemMessage, UserMessage
+from beeai_framework.backend.message import AssistantMessage, SystemMessage, UserMessage
 from beeai_framework.emitter.emitter import EventMeta
+from beeai_framework.errors import FrameworkError
 from beeai_framework.memory.base_memory import BaseMemory
 from beeai_framework.memory.token_memory import TokenMemory
 from beeai_framework.parsers.field import ParserField
-from beeai_framework.parsers.line_prefix import LinePrefixParser, LinePrefixParserNode, LinePrefixParserUpdate
+from beeai_framework.parsers.line_prefix import (
+    LinePrefixParser,
+    LinePrefixParserError,
+    LinePrefixParserNode,
+    LinePrefixParserUpdate,
+)
+from beeai_framework.retryable import Retryable, RetryableConfig, RetryableContext, RetryableInput
 from beeai_framework.tools import ToolError, ToolInputValidationError
 from beeai_framework.tools.tool import StringToolOutput, Tool, ToolOutput
 from beeai_framework.utils.strings import create_strenum
@@ -56,6 +65,7 @@ def default_templates(self) -> BeeAgentTemplates:
             user=UserPromptTemplate,
             tool_not_found_error=ToolNotFoundErrorTemplate,
             tool_input_error=ToolInputErrorTemplate,
+            schema_error=SchemaErrorTemplate,
         )
 
     def create_parser(self) -> LinePrefixParser:
@@ -89,64 +99,95 @@ def create_parser(self) -> LinePrefixParser:
         )
 
     async def llm(self, input: BeeRunnerLLMInput) -> BeeAgentRunIteration:
-        await input.emitter.emit("start", {"meta": input.meta, "tools": self._input.tools, "memory": self.memory})
-
-        parser = self.create_parser()
-
-        async def on_update(data: LinePrefixParserUpdate, event: EventMeta) -> None:
-            if data.key == "tool_output" and parser.done:
-                return
-
-            await input.emitter.emit(
-                "update",
-                {
-                    "data": parser.final_state,
-                    "update": {"key": data.key, "value": data.field.raw, "parsedValue": data.value.model_dump()},
-                    "meta": {**input.meta.model_dump(), "success": True},
-                    "tools": self._input.tools,
-                    "memory": self.memory,
-                },
+        async def on_retry(ctx: RetryableContext, last_error: Exception) -> None:
+            await input.emitter.emit("retry", {"meta": input.meta})
+
+        async def on_error(error: Exception, _: RetryableContext) -> None:
+            await input.emitter.emit("error", {"error": error, "meta": input.meta})
+            self._failed_attempts_counter.use(error) 
+ + if isinstance(error, LinePrefixParserError): + if error.reason == LinePrefixParserError.Reason.NoDataReceived: + await self.memory.add(AssistantMessage("\n", {"tempMessage": True})) + else: + schema_error_prompt: str = self.templates.schema_error.render(SchemaErrorTemplateInput()) + await self.memory.add(UserMessage(schema_error_prompt, {"tempMessage": True})) + + async def executor(_: RetryableContext) -> Awaitable[BeeAgentRunIteration]: + await input.emitter.emit("start", {"meta": input.meta, "tools": self._input.tools, "memory": self.memory}) + + parser = self.create_parser() + + async def on_update(data: LinePrefixParserUpdate, event: EventMeta) -> None: + if data.key == "tool_output" and parser.done: + return + + await input.emitter.emit( + "update", + { + "data": parser.final_state, + "update": {"key": data.key, "value": data.field.raw, "parsedValue": data.value.model_dump()}, + "meta": {**input.meta.model_dump(), "success": True}, + "tools": self._input.tools, + "memory": self.memory, + }, + ) + + async def on_partial_update(data: LinePrefixParserUpdate, event: EventMeta) -> None: + await input.emitter.emit( + "partialUpdate", + { + "data": parser.final_state, + "update": {"key": data.key, "value": data.delta, "parsedValue": data.value.model_dump()}, + "meta": {**input.meta.model_dump(), "success": True}, + "tools": self._input.tools, + "memory": self.memory, + }, + ) + + parser.emitter.on("update", on_update) + parser.emitter.on("partialUpdate", on_partial_update) + + async def on_new_token(value: tuple[ChatModelOutput, Callable], event: EventMeta) -> None: + data, abort = value + + if parser.done: + abort() + return + + chunk = data.get_text_content() + await parser.add(chunk) + + if parser.partial_state.get("tool_output") is not None: + abort() + + output: ChatModelOutput = await self._input.llm.create( + ChatModelInput(messages=self.memory.messages[:], stream=True) + ).observe(lambda llm_emitter: llm_emitter.on("newToken", on_new_token)) + + await parser.end() + + await self.memory.delete_many([msg for msg in self.memory.messages if not msg.meta.get("success", True)]) + + return BeeAgentRunIteration( + raw=output, state=BeeIterationResult.model_validate(parser.final_state, strict=False) ) - async def on_partial_update(data: LinePrefixParserUpdate, event: EventMeta) -> None: - await input.emitter.emit( - "partialUpdate", - { - "data": parser.final_state, - "update": {"key": data.key, "value": data.delta, "parsedValue": data.value.model_dump()}, - "meta": {**input.meta.model_dump(), "success": True}, - "tools": self._input.tools, - "memory": self.memory, - }, + if self._options and self._options.execution and self._options.execution.max_retries_per_step: + max_retries = self._options.execution.max_retries_per_step + else: + max_retries = 0 + + retryable_state = await Retryable( + RetryableInput( + on_retry=on_retry, + on_error=on_error, + executor=executor, + config=RetryableConfig(max_retries=max_retries, signal=input.signal), ) + ).get() - parser.emitter.on("update", on_update) - parser.emitter.on("partialUpdate", on_partial_update) - - async def on_new_token(value: tuple[ChatModelOutput, Callable], event: EventMeta) -> None: - data, abort = value - - if parser.done: - abort() - return - - chunk = data.get_text_content() - await parser.add(chunk) - - if parser.partial_state.get("tool_output") is not None: - abort() - - output: ChatModelOutput = await self._input.llm.create( - ChatModelInput(messages=self.memory.messages[:], stream=True) - ).observe(lambda llm_emitter: 
llm_emitter.on("newToken", on_new_token)) - - await parser.end() - - await self.memory.delete_many([msg for msg in self.memory.messages if not msg.meta.get("success", True)]) - - return BeeAgentRunIteration( - raw=output, state=BeeIterationResult.model_validate(parser.final_state, strict=False) - ) + return retryable_state.value async def tool(self, input: BeeRunnerToolInput) -> BeeRunnerToolResult: tool: Tool | None = next( @@ -174,33 +215,65 @@ async def tool(self, input: BeeRunnerToolInput) -> BeeRunnerToolResult: ), ) - try: - # tool_options = copy.copy(self._options) - # TODO Tool run is not async - # Convert tool input to dict - tool_output: ToolOutput = tool.run(input.state.tool_input, options={}) # TODO: pass tool options - return BeeRunnerToolResult(output=tool_output, success=True) - # TODO These error templates should be customized to help the LLM to recover - except ToolInputValidationError as e: - self._failed_attempts_counter.use(e) - return BeeRunnerToolResult( - success=False, - output=StringToolOutput(self.templates.tool_input_error.render({"reason": str(e)})), - ) - - except ToolError as e: - self._failed_attempts_counter.use(e) - - return BeeRunnerToolResult( - success=False, - output=StringToolOutput(self.templates.tool_input_error.render({"reason": str(e)})), + async def on_error(error: Exception, _: RetryableContext) -> None: + await input.emitter.emit( + "toolError", + { + "data": { + "iteration": input.state, + "tool": tool, + "input": input.state.tool_input, + "options": self._options, + "error": FrameworkError.ensure(error), + }, + "meta": input.meta, + }, ) - except json.JSONDecodeError as e: - self._failed_attempts_counter.use(e) - return BeeRunnerToolResult( - success=False, - output=StringToolOutput(self.templates.tool_input_error.render({"reason": str(e)})), + self._failed_attempts_counter.use(error) + + async def executor(_: RetryableContext) -> Awaitable[BeeRunnerToolResult]: + try: + # tool_options = copy.copy(self._options) + # TODO Tool run is not async + # Convert tool input to dict + tool_output: ToolOutput = tool.run(input.state.tool_input, options={}) # TODO: pass tool options + return BeeRunnerToolResult(output=tool_output, success=True) + # TODO These error templates should be customized to help the LLM to recover + except ToolInputValidationError as e: + self._failed_attempts_counter.use(e) + return BeeRunnerToolResult( + success=False, + output=StringToolOutput(self.templates.tool_input_error.render({"reason": str(e)})), + ) + + except ToolError as e: + self._failed_attempts_counter.use(e) + + return BeeRunnerToolResult( + success=False, + output=StringToolOutput(self.templates.tool_input_error.render({"reason": str(e)})), + ) + except json.JSONDecodeError as e: + self._failed_attempts_counter.use(e) + return BeeRunnerToolResult( + success=False, + output=StringToolOutput(self.templates.tool_input_error.render({"reason": str(e)})), + ) + + if self._options and self._options.execution and self._options.execution.max_retries_per_step: + max_retries = self._options.execution.max_retries_per_step + else: + max_retries = 0 + + retryable_state = await Retryable( + RetryableInput( + on_error=on_error, + executor=executor, + config=RetryableConfig(max_retries=max_retries), ) + ).get() + + return retryable_state.value async def init_memory(self, input: BeeRunInput) -> BaseMemory: memory = TokenMemory( diff --git a/python/beeai_framework/agents/runners/granite/prompts.py b/python/beeai_framework/agents/runners/granite/prompts.py index 16830ce9..bb810e6e 
100644 --- a/python/beeai_framework/agents/runners/granite/prompts.py +++ b/python/beeai_framework/agents/runners/granite/prompts.py @@ -16,6 +16,7 @@ from beeai_framework.agents.runners.default.prompts import ( AssistantPromptTemplateInput, + SchemaErrorTemplateInput, SystemPromptTemplateInput, ToolInputErrorTemplateInput, ToolNotFoundErrorTemplateInput, @@ -92,3 +93,9 @@ HINT: If you're convinced that the input was correct but the tool cannot process it then use a different tool or say I don't know.""", # noqa: E501 ) + +GraniteSchemaErrorTemplate = PromptTemplate( + schema=SchemaErrorTemplateInput, + template="""Error: The generated response does not adhere to the communication structure mentioned in the system prompt. +You communicate only in instruction lines. Valid instruction lines are 'Thought' followed by either 'Function Name' + 'Function Input' + 'Function Output' or 'Final Answer'.""", # noqa: E501 +) diff --git a/python/beeai_framework/agents/runners/granite/runner.py b/python/beeai_framework/agents/runners/granite/runner.py index 39dc0c70..c89c4190 100644 --- a/python/beeai_framework/agents/runners/granite/runner.py +++ b/python/beeai_framework/agents/runners/granite/runner.py @@ -17,6 +17,7 @@ from beeai_framework.agents.runners.default.runner import DefaultRunner from beeai_framework.agents.runners.granite.prompts import ( GraniteAssistantPromptTemplate, + GraniteSchemaErrorTemplate, GraniteSystemPromptTemplate, GraniteToolInputErrorTemplate, GraniteToolNotFoundErrorTemplate, @@ -86,6 +87,7 @@ def default_templates(self) -> BeeAgentTemplates: user=GraniteUserPromptTemplate, tool_not_found_error=GraniteToolNotFoundErrorTemplate, tool_input_error=GraniteToolInputErrorTemplate, + schema_error=GraniteSchemaErrorTemplate, ) async def init_memory(self, input: BeeRunInput) -> BaseMemory: diff --git a/python/beeai_framework/agents/types.py b/python/beeai_framework/agents/types.py index 8eec0d1c..62be7441 100644 --- a/python/beeai_framework/agents/types.py +++ b/python/beeai_framework/agents/types.py @@ -80,7 +80,7 @@ class BeeAgentTemplates(BaseModel): tool_input_error: InstanceOf[PromptTemplate] # tool_no_result_error: InstanceOf[PromptTemplate] tool_not_found_error: InstanceOf[PromptTemplate] - # schema_error: InstanceOf[PromptTemplate] + schema_error: InstanceOf[PromptTemplate] class AgentMeta(BaseModel): diff --git a/python/beeai_framework/backend/chat.py b/python/beeai_framework/backend/chat.py index 76aef256..a09af87c 100644 --- a/python/beeai_framework/backend/chat.py +++ b/python/beeai_framework/backend/chat.py @@ -15,7 +15,7 @@ import json from abc import ABC, abstractmethod -from collections.abc import AsyncGenerator, Callable +from collections.abc import AsyncGenerator, Awaitable, Callable from typing import Annotated, Any, Literal, Self, TypeVar from pydantic import BaseModel, BeforeValidator, ConfigDict, Field, InstanceOf, ValidationError @@ -27,6 +27,7 @@ from beeai_framework.cancellation import AbortController, AbortSignal from beeai_framework.context import Run, RunContext, RunContextInput, RunInstance from beeai_framework.emitter import Emitter, EmitterInput +from beeai_framework.retryable import Retryable, RetryableConfig, RetryableContext from beeai_framework.tools.tool import Tool from beeai_framework.utils.custom_logger import BeeLogger from beeai_framework.utils.models import ModelLike, to_model @@ -193,16 +194,34 @@ class DefaultChatModelStructureSchema(BaseModel): *input_messages, ] - response = await self._create( - ChatModelInput(messages=messages, 
response_format={"type": "object-json"}, abort_signal=input.abort_signal),
-        )
+        class DefaultChatModelStructureErrorSchema(BaseModel):
+            errors: str
+            expected: str
+            received: str
+
+        async def executor(_: RetryableContext) -> Awaitable[ChatModelOutput]:
+            response = await self._create(
+                ChatModelInput(
+                    messages=messages, response_format={"type": "object-json"}, abort_signal=input.abort_signal
+                ),
+                run,
+            )
+
+            logger.debug(f"Received structured response:\n{response}")
+
+            text_response = response.get_text_content()
+            result = parse_broken_json(text_response)
+            # TODO: validate result matches expected schema
+            return ChatModelStructureOutput(object=result)
 
-        logger.debug(f"Recieved structured response:\n{response}")
+        retryable_state = await Retryable(
+            {
+                "executor": executor,
+                "config": RetryableConfig(max_retries=input.max_retries if input else 1, signal=run.signal),
+            }
+        ).get()
 
-        text_response = response.get_text_content()
-        result = parse_broken_json(text_response)
-        # TODO: validate result matches expected schema
-        return ChatModelStructureOutput(object=result)
+        return retryable_state.value
 
     def create(self, chat_model_input: ModelLike[ChatModelInput]) -> Run[ChatModelOutput]:
         input = to_model(ChatModelInput, chat_model_input)
diff --git a/python/beeai_framework/cancellation.py b/python/beeai_framework/cancellation.py
index d92a16d0..abbeef7d 100644
--- a/python/beeai_framework/cancellation.py
+++ b/python/beeai_framework/cancellation.py
@@ -15,10 +15,12 @@
 
 import contextlib
 import threading
-from collections.abc import Callable
+from collections.abc import Awaitable, Callable
+from typing import Any
 
 from pydantic import BaseModel
 
+from beeai_framework.errors import AbortError
 from beeai_framework.utils.custom_logger import BeeLogger
 
 logger = BeeLogger(__name__)
@@ -66,6 +68,10 @@ def _callback() -> None:
 
         return signal
 
+    def throw_if_aborted(self) -> None:
+        if self._aborted:
+            raise AbortError(self._reason)
+
 
 class AbortController:
     def __init__(self) -> None:
@@ -87,3 +93,25 @@ def trigger_abort(reason: str | None = None) -> None:
     if signal.aborted:
         trigger_abort(signal.reason)
     signal.add_event_listener(trigger_abort)
+
+
+async def abort_signal_handler(
+    fn: Callable[[], Awaitable[Any]], signal: AbortSignal | None = None, on_abort: Callable[[], None] | None = None
+) -> Awaitable[Any]:
+    def abort_handler() -> None:
+        if on_abort:
+            on_abort()
+        if signal:
+            signal.cancel()
+
+    if signal:
+        if signal.aborted:
+            raise AbortError(signal.reason)
+        else:
+            signal.add_event_listener(abort_handler)
+
+    try:
+        return await fn()
+    finally:
+        if signal:
+            signal.remove_event_listener(abort_handler)
diff --git a/python/beeai_framework/errors.py b/python/beeai_framework/errors.py
index e4210762..9163d376 100644
--- a/python/beeai_framework/errors.py
+++ b/python/beeai_framework/errors.py
@@ -47,9 +47,12 @@ def __get_message(error: Exception) -> str:
         message = str(error) if len(str(error)) > 0 else type(error).__name__
         return message
 
-    def is_retryable(self) -> bool:
+    @staticmethod
+    def is_retryable(error: Exception) -> bool:
         """is error retryable?."""
-        return self._is_retryable
+        if isinstance(error, FrameworkError):
+            return error._is_retryable
+        return isinstance(error, CancelledError)
 
     def is_fatal(self) -> bool:
         """is error fatal?"""
diff --git a/python/beeai_framework/retryable.py b/python/beeai_framework/retryable.py
new file mode 100644
index 00000000..0195341f
--- /dev/null
+++ b/python/beeai_framework/retryable.py
@@ -0,0 +1,222 @@
+# Copyright 2025 IBM Corp. 
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import asyncio
+import uuid
+from collections.abc import AsyncGenerator, Awaitable, Callable
+from typing import Any, Literal, Self, TypeVar
+
+from pydantic import BaseModel
+
+from beeai_framework.cancellation import AbortController, AbortSignal, abort_signal_handler
+from beeai_framework.errors import FrameworkError
+from beeai_framework.utils.custom_logger import BeeLogger
+from beeai_framework.utils.models import ModelLike, to_model
+
+T = TypeVar("T", bound=BaseModel)
+logger = BeeLogger(__name__)
+
+
+class RetryableState(BaseModel):
+    state: Literal["pending", "resolved", "rejected"] = "pending"
+    value: Any | None = None
+
+    def resolve(self, value: Any) -> None:
+        self.state = "resolved"
+        self.value = value
+
+    def reject(self, error: Exception) -> None:
+        self.state = "rejected"
+        self.value = error
+
+    @property
+    def is_resolved(self) -> bool:
+        return self.state == "resolved"
+
+    @property
+    def is_rejected(self) -> bool:
+        return self.state == "rejected"
+
+
+class Meta(BaseModel):
+    attempt: int
+    remaining: int
+
+
+class RetryableConfig(BaseModel):
+    max_retries: int
+    factor: float | None = None
+    signal: AbortSignal | None = None
+
+
+class RetryableContext(BaseModel):
+    execution_id: str
+    attempt: int
+    signal: AbortSignal | None
+
+
+class RetryableInput(BaseModel):
+    executor: Callable[[RetryableContext], Awaitable[T]]
+    on_reset: Callable[[], None] | None = None
+    on_error: Callable[[Exception, RetryableContext], Awaitable[None]] | None = None
+    on_retry: Callable[[RetryableContext, Exception], Awaitable[None]] | None = None
+    config: RetryableConfig
+
+
+class RetryableRunConfig(BaseModel):
+    group_signal: AbortSignal
+
+
+async def do_retry(fn: Callable[[int], Awaitable[Any]], options: dict[str, Any]) -> Awaitable[Any]:
+    async def handler(attempt: int, remaining: int) -> Awaitable:
+        logger.debug(f"Entering p_retry handler({attempt}, {remaining})")
+        try:
+            factor = options.get("factor", 2) or 2
+
+            if attempt > 1:
+                await asyncio.sleep(factor ** (attempt - 1))
+
+            return await fn(attempt)
+        except Exception as e:
+            logger.debug(f"p_retry exception: {e}")
+            meta = Meta(attempt=attempt, remaining=remaining)
+
+            if isinstance(e, asyncio.CancelledError):
+                raise e
+
+            if options.get("on_failed_attempt"):
+                await options["on_failed_attempt"](e, meta)
+
+            if remaining <= 0:
+                raise e
+
+            if (options.get("should_retry", lambda _: False)(e)) is False:
+                raise e
+
+            return await handler(attempt + 1, remaining - 1)
+
+    return await abort_signal_handler(lambda: handler(1, options.get("retries", 0)), options.get("signal"))
+
+
+class Retryable:
+    def __init__(self, retryable_input: ModelLike[RetryableInput]) -> None:
+        self._id = str(uuid.uuid4())
+        self._retry_state: RetryableState | None = None
+        retry_input = to_model(RetryableInput, retryable_input)
+        self._handlers = retry_input
+        self._config = retry_input.config
+
+    @staticmethod
+    async def run_group(inputs: list[Self]) -> list[T]:
+        async 
def input_get(input: Self, controller: AbortController) -> RetryableState | None:
+            try:
+                return await input.get(RetryableRunConfig(group_signal=controller.signal))
+            except Exception as err:
+                controller.abort(err)
+                raise err
+
+        controller = AbortController()
+        results = await asyncio.gather(*[input_get(input, controller) for input in inputs])
+        controller.signal.throw_if_aborted()
+        return [result.value for result in results]
+
+    @staticmethod
+    async def run_sequence(inputs: list[Self]) -> AsyncGenerator[T]:
+        for input in inputs:
+            yield await input.get()
+
+    @staticmethod
+    async def collect(inputs: dict[str, Self]) -> dict[str, Any]:
+        results = await asyncio.gather(*[value.get() for value in inputs.values()])
+        return dict(zip(inputs.keys(), results, strict=True))
+
+    def _get_context(self, attempt: int) -> RetryableContext:
+        ctx = RetryableContext(
+            execution_id=self._id,
+            attempt=attempt,
+            signal=self._config.signal,
+        )
+        return ctx
+
+    def is_resolved(self) -> bool:
+        return self._retry_state.is_resolved if self._retry_state else False
+
+    def is_rejected(self) -> bool:
+        return self._retry_state.is_rejected if self._retry_state else False
+
+    async def _run(self, config: RetryableRunConfig | None = None) -> RetryableState:
+        retry_state = RetryableState()
+
+        def assert_aborted() -> None:
+            if self._config.signal:
+                self._config.signal.throw_if_aborted()
+            if config and config.group_signal:
+                config.group_signal.throw_if_aborted()
+
+        last_error: Exception | None = None
+
+        async def _retry(attempt: int) -> Awaitable:
+            assert_aborted()
+            ctx = self._get_context(attempt)
+            if attempt > 1 and self._handlers.on_retry:
+                await self._handlers.on_retry(ctx, last_error)
+            return await self._handlers.executor(ctx)
+
+        def _should_retry(e: FrameworkError) -> bool:
+            should_retry = not (
+                not FrameworkError.is_retryable(e)
+                or (config and config.group_signal and config.group_signal.aborted)
+                or (self._config.signal and self._config.signal.aborted)
+            )
+            logger.trace(f"Retryable run should retry: {should_retry}")
+            return should_retry
+
+        async def _on_failed_attempt(e: FrameworkError, meta: Meta) -> None:
+            nonlocal last_error
+            last_error = e
+            if self._handlers.on_error:
+                await self._handlers.on_error(e, self._get_context(meta.attempt))
+            if not FrameworkError.is_retryable(e):
+                raise e
+            assert_aborted()
+
+        options = {
+            "retries": self._config.max_retries,
+            "factor": self._config.factor,
+            "signal": self._config.signal,
+            "should_retry": _should_retry,
+            "on_failed_attempt": _on_failed_attempt,
+        }
+
+        try:
+            retry_response = await do_retry(_retry, options)
+            retry_state.resolve(retry_response)
+        except Exception as e:
+            retry_state.reject(e)
+
+        return retry_state
+
+    async def get(self, config: RetryableRunConfig | None = None) -> Awaitable[T]:
+        if self.is_resolved():
+            return self._retry_state.value
+        elif self.is_rejected():
+            raise self._retry_state.value
+        elif (self._retry_state.state not in ["resolved", "rejected"] if self._retry_state else False) and not config:
+            return self._retry_state
+        else:
+            self._retry_state = await self._run(config)
+            return self._retry_state
+
+    def reset(self) -> None:
+        self._retry_state = None
+        if self._handlers.on_reset:
+            self._handlers.on_reset()
diff --git a/python/examples/backend/providers/ollama.py b/python/examples/backend/providers/ollama.py
index 0bf10678..9433e0ca 100644
--- a/python/examples/backend/providers/ollama.py
+++ b/python/examples/backend/providers/ollama.py
@@ -8,6 +8,7 @@
 from beeai_framework.backend.message import UserMessage
 from 
beeai_framework.cancellation import AbortSignal from beeai_framework.emitter import EventMeta +from beeai_framework.errors import AbortError from beeai_framework.parsers.field import ParserField from beeai_framework.parsers.line_prefix import LinePrefixParser, LinePrefixParserNode @@ -43,12 +44,18 @@ async def ollama_stream() -> None: async def ollama_stream_abort() -> None: llm = OllamaChatModel("llama3.1") user_message = UserMessage("What is the smallest of the Cape Verde islands?") - response = await llm.create({"messages": [user_message], "stream": True, "abort_signal": AbortSignal.timeout(0.5)}) - if response is not None: - print(response.get_text_content()) - else: - print("No response returned.") + try: + response = await llm.create( + {"messages": [user_message], "stream": True, "abort_signal": AbortSignal.timeout(0.5)} + ) + + if response is not None: + print(response.get_text_content()) + else: + print("No response returned.") + except AbortError as err: + print(f"Aborted: {err}") async def ollama_structure() -> None: diff --git a/python/examples/backend/providers/watsonx.py b/python/examples/backend/providers/watsonx.py index 66555b7d..06bb54a4 100644 --- a/python/examples/backend/providers/watsonx.py +++ b/python/examples/backend/providers/watsonx.py @@ -3,8 +3,10 @@ from pydantic import BaseModel, Field from beeai_framework.adapters.watsonx.backend.chat import WatsonxChatModel +from beeai_framework.backend.chat import ChatModel from beeai_framework.backend.message import UserMessage from beeai_framework.cancellation import AbortSignal +from beeai_framework.errors import AbortError # Setting can be passed here during initiation or pre-configured via environment variables llm = WatsonxChatModel( @@ -18,7 +20,7 @@ async def watsonx_from_name() -> None: - watsonx_llm = WatsonxChatModel.from_name( + watsonx_llm = ChatModel.from_name( "watsonx:ibm/granite-3-8b-instruct", # { # "project_id": "WATSONX_PROJECT_ID", @@ -45,12 +47,18 @@ async def watsonx_stream() -> None: async def watsonx_stream_abort() -> None: user_message = UserMessage("What is the smallest of the Cape Verde islands?") - response = await llm.create({"messages": [user_message], "stream": True, "abort_signal": AbortSignal.timeout(0.5)}) - if response is not None: - print(response.get_text_content()) - else: - print("No response returned.") + try: + response = await llm.create( + {"messages": [user_message], "stream": True, "abort_signal": AbortSignal.timeout(0.5)} + ) + + if response is not None: + print(response.get_text_content()) + else: + print("No response returned.") + except AbortError as err: + print(f"Aborted: {err}") async def watson_structure() -> None: diff --git a/python/tests/test_retryable.py b/python/tests/test_retryable.py new file mode 100644 index 00000000..943a6c5f --- /dev/null +++ b/python/tests/test_retryable.py @@ -0,0 +1,131 @@ +# Copyright 2025 IBM Corp. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from collections.abc import Awaitable + +import pytest + +from beeai_framework.errors import FrameworkError +from beeai_framework.retryable import Retryable, RetryableConfig, RetryableContext, RetryableInput + + +async def executor(ctx: RetryableContext) -> Awaitable: + print(f"running executor: {ctx}") + + +def on_reset() -> None: + print("on_reset") + + +async def on_error(e: Exception, ctx: RetryableContext) -> None: + print(f"on_error: {e}") + + +async def on_retry(ctx: RetryableContext, last_error: Exception) -> None: + print(f"on_retry: {ctx}") + + +@pytest.mark.asyncio +async def test_retryable() -> None: + retry_state = await Retryable( + { + "executor": executor, + "on_reset": on_reset, + "on_error": on_error, + "on_retry": on_retry, + "config": RetryableConfig(max_retries=3), + } + ).get() + + assert retry_state.is_resolved + + +@pytest.mark.asyncio +async def test_retryable_error() -> None: + async def executor(ctx: RetryableContext) -> Awaitable: + raise FrameworkError("frameworkerror:test_retryable_error") + + retry = Retryable( + RetryableInput( + executor=executor, + on_reset=on_reset, + on_error=on_error, + on_retry=on_retry, + config=RetryableConfig(max_retries=3), + ) + ) + + retry_state = await retry.get() + assert retry_state.is_rejected + + +@pytest.mark.asyncio +async def test_retryable_retries() -> None: + async def executor(ctx: RetryableContext) -> Awaitable: + print(f"Executing attempt: {ctx.attempt}") + raise FrameworkError(f"frameworkerror:test_retryable_retries:{ctx.attempt}", is_retryable=True) + + max_retries = 3 + + retry = Retryable( + { + "executor": executor, + "on_reset": on_reset, + "on_error": on_error, + "on_retry": on_retry, + "config": RetryableConfig(max_retries=max_retries), + } + ) + + retry_state = await retry.get() + + assert retry_state.is_rejected + assert retry_state.value.message == f"frameworkerror:test_retryable_retries:{max_retries + 1}" + assert retry.is_rejected() + + +@pytest.mark.asyncio +async def test_retryable_reset() -> None: + counter = 0 + + async def executor(ctx: RetryableContext) -> Awaitable: + nonlocal counter + counter += 1 + print(f"Executing count: {counter}") + if counter > 1: + return {"counter": counter} + raise FrameworkError(f"frameworkerror:test_retryable_reset:{counter}") + + retry = Retryable( + RetryableInput( + executor=executor, + on_reset=on_reset, + on_error=on_error, + on_retry=on_retry, + config=RetryableConfig(max_retries=0), + ) + ) + + retry_state = await retry.get() + + assert retry_state.is_rejected + assert retry_state.value.message == "frameworkerror:test_retryable_reset:1" + assert retry.is_rejected() + + retry.reset() + retry_state = await retry.get() + + assert retry_state.is_resolved + assert retry_state.value.get("counter") == counter + assert retry.is_resolved()
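
Example usage of the Retryable API introduced in PATCH 4/4 (a minimal sketch written
against the signatures added above; the `main` function, the attempt counts, and the
printed result are illustrative only, while `FrameworkError(..., is_retryable=True)`
mirrors its use in test_retryable.py):

    import asyncio

    from beeai_framework.errors import FrameworkError
    from beeai_framework.retryable import Retryable, RetryableConfig, RetryableContext, RetryableInput


    async def main() -> None:
        attempts = 0

        async def executor(ctx: RetryableContext) -> dict:
            # Fail twice with a retryable error, then succeed on the third attempt.
            nonlocal attempts
            attempts += 1
            if attempts < 3:
                raise FrameworkError(f"transient failure on attempt {ctx.attempt}", is_retryable=True)
            return {"attempts": attempts}

        retry_state = await Retryable(
            RetryableInput(
                executor=executor,
                # factor drives the exponential backoff: sleep of factor ** (attempt - 1) seconds
                config=RetryableConfig(max_retries=3, factor=1.0),
            )
        ).get()

        # The first call to get() returns the RetryableState; the result is in .value.
        assert retry_state.is_resolved
        print(retry_state.value)  # -> {'attempts': 3}


    asyncio.run(main())

Note that on_reset, on_error, and on_retry are optional hooks, so the sketch omits
them; the agent runner and ChatModel.create_structure shown in the diffs above pass
them explicitly to surface retries through the emitter.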