diff --git a/python/beeai_framework/backend/chat.py b/python/beeai_framework/backend/chat.py index aeb105af..88462e89 100644 --- a/python/beeai_framework/backend/chat.py +++ b/python/beeai_framework/backend/chat.py @@ -18,7 +18,7 @@ from collections.abc import AsyncGenerator, Callable from typing import Annotated, Any, Literal, Self, TypeVar -from pydantic import BaseModel, BeforeValidator, ConfigDict, Field, ValidationError +from pydantic import BaseModel, BeforeValidator, ConfigDict, Field, InstanceOf, ValidationError from beeai_framework.backend.constants import ProviderName from beeai_framework.backend.errors import ChatModelError @@ -95,17 +95,10 @@ class ChatModelUsage(BaseModel): total_tokens: int -class ChatModelOutput: - def __init__( - self, - *, - messages: list[Message], - usage: ChatModelUsage | None = None, - finish_reason: str | None = None, - ) -> None: - self.messages = messages - self.usage = usage - self.finish_reason = finish_reason +class ChatModelOutput(BaseModel): + messages: list[InstanceOf[Message]] + usage: InstanceOf[ChatModelUsage] | None = None + finish_reason: str | None = None @classmethod def from_chunks(cls, chunks: list) -> Self: @@ -211,7 +204,7 @@ class DefaultChatModelStructureSchema(BaseModel): # TODO: validate result matches expected schema return ChatModelStructureOutput(object=result) - def create(self, chat_model_input: ModelLike[ChatModelInput]) -> Run: + def create(self, chat_model_input: ModelLike[ChatModelInput]) -> Run[ChatModelOutput]: input = to_model(ChatModelInput, chat_model_input) async def run_create(context: RunContext) -> ChatModelOutput: @@ -235,11 +228,11 @@ async def run_create(context: RunContext) -> ChatModelOutput: await context.emitter.emit("success", {"value": result}) return result except ChatModelError as error: - await context.emitter.emit("error", {input, error}) + await context.emitter.emit("error", error) raise error - except Exception as ex: - await context.emitter.emit("error", {input, ex}) - raise ChatModelError("Model error has occurred.") from ex + except Exception as error: + await context.emitter.emit("error", error) + raise ChatModelError("Model error has occurred.") from error finally: await context.emitter.emit("finish", None)