Skip to content

Commit

Permalink
feat: typings and error propagating
Browse files Browse the repository at this point in the history
Ref: #378
Signed-off-by: Tomas Dvorak <[email protected]>
  • Loading branch information
Tomas2D committed Feb 24, 2025
1 parent 7be2056 commit a66ad9e
Showing 1 changed file with 10 additions and 17 deletions.
27 changes: 10 additions & 17 deletions python/beeai_framework/backend/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from collections.abc import AsyncGenerator, Callable
from typing import Annotated, Any, Literal, Self, TypeVar

from pydantic import BaseModel, BeforeValidator, ConfigDict, Field, ValidationError
from pydantic import BaseModel, BeforeValidator, ConfigDict, Field, InstanceOf, ValidationError

from beeai_framework.backend.constants import ProviderName
from beeai_framework.backend.errors import ChatModelError
Expand Down Expand Up @@ -95,17 +95,10 @@ class ChatModelUsage(BaseModel):
total_tokens: int


class ChatModelOutput:
def __init__(
self,
*,
messages: list[Message],
usage: ChatModelUsage | None = None,
finish_reason: str | None = None,
) -> None:
self.messages = messages
self.usage = usage
self.finish_reason = finish_reason
class ChatModelOutput(BaseModel):
messages: list[InstanceOf[Message]]
usage: InstanceOf[ChatModelUsage] | None = None
finish_reason: str | None = None

@classmethod
def from_chunks(cls, chunks: list) -> Self:
Expand Down Expand Up @@ -211,7 +204,7 @@ class DefaultChatModelStructureSchema(BaseModel):
# TODO: validate result matches expected schema
return ChatModelStructureOutput(object=result)

def create(self, chat_model_input: ModelLike[ChatModelInput]) -> Run:
def create(self, chat_model_input: ModelLike[ChatModelInput]) -> Run[ChatModelOutput]:
input = to_model(ChatModelInput, chat_model_input)

async def run_create(context: RunContext) -> ChatModelOutput:
Expand All @@ -235,11 +228,11 @@ async def run_create(context: RunContext) -> ChatModelOutput:
await context.emitter.emit("success", {"value": result})
return result
except ChatModelError as error:
await context.emitter.emit("error", {input, error})
await context.emitter.emit("error", error)
raise error
except Exception as ex:
await context.emitter.emit("error", {input, ex})
raise ChatModelError("Model error has occurred.") from ex
except Exception as error:
await context.emitter.emit("error", error)
raise ChatModelError("Model error has occurred.") from error
finally:
await context.emitter.emit("finish", None)

Expand Down

0 comments on commit a66ad9e

Please sign in to comment.