
Commit

Merge branch 'main' into feat/openai
vabarbosa committed Feb 25, 2025
2 parents 651bea3 + 765a4ba commit 9d52d6d
Showing 10 changed files with 63 additions and 45 deletions.
7 changes: 7 additions & 0 deletions python/.env.example
@@ -17,3 +17,10 @@ BEEAI_LOG_LEVEL=INFO
 # WATSONX_URL=your-watsonx-instance-base-url
 # WATSONX_PROJECT_ID=your-watsonx-project-id
 # WATSONX_APIKEY=your-watsonx-api-key
+
+########################
+### Ollama specific configuration
+########################
+
+# OLLAMA_BASE_URL=http://localhost:11434
+# OLLAMA_CHAT_MODEL=llama3.1:8b
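For orientation, a minimal sketch of how these two variables are meant to be consumed; the defaults mirror the commented values above, and the snippet itself is ours, not framework code:

import os

# Fall back to the documented defaults when the variables are unset.
base_url = os.getenv("OLLAMA_BASE_URL", "http://localhost:11434")
chat_model = os.getenv("OLLAMA_CHAT_MODEL", "llama3.1:8b")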
7 changes: 6 additions & 1 deletion python/beeai_framework/adapters/ollama/backend/chat.py
@@ -28,7 +28,12 @@ def provider_id(self) -> ProviderName:
         return "ollama"
 
     def __init__(self, model_id: str | None = None, settings: dict | None = None) -> None:
+        _settings = settings.copy() if settings is not None else None
+
+        if _settings is not None and not hasattr(_settings, "base_url") and "OLLAMA_BASE_URL" in os.environ:
+            _settings["base_url"] = os.getenv("OLLAMA_BASE_URL")
+
         super().__init__(
             model_id if model_id else os.getenv("OLLAMA_CHAT_MODEL", "llama3.1:8b"),
-            settings={"base_url": "http://localhost:11434"} | (settings or {}),
+            settings={"base_url": "http://localhost:11434"} | (_settings or {}),
         )
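One subtlety worth flagging in this hunk: _settings is a plain dict, and hasattr(_settings, "base_url") tests for an attribute rather than a key, so it is False for any dict, which lets OLLAMA_BASE_URL overwrite an explicitly passed base_url. A key-based guard is presumably what was intended; a minimal sketch (the helper name is ours, not the framework's):

import os

def resolve_settings(settings: dict | None) -> dict | None:
    # Copy the caller's settings, filling base_url from the env var
    # only when the caller did not supply one.
    _settings = settings.copy() if settings is not None else None
    if _settings is not None and "base_url" not in _settings and "OLLAMA_BASE_URL" in os.environ:
        _settings["base_url"] = os.environ["OLLAMA_BASE_URL"]
    return _settings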
32 changes: 16 additions & 16 deletions python/beeai_framework/agents/base.py
@@ -38,24 +38,24 @@ def run(self, run_input: ModelLike[BeeRunInput], options: ModelLike[BeeRunOption
         if self.is_running:
             raise RuntimeError("Agent is already running!")
 
-        try:
-            self.is_running = True
+        self.is_running = True
 
-            async def handler(context: RunContext) -> T:
+        async def handler(context: RunContext) -> T:
+            try:
                 return await self._run(run_input, options, context)
-
-            return RunContext.enter(
-                RunInstance(emitter=self.emitter),
-                RunContextInput(signal=options.signal if options else None, params=(run_input, options)),
-                handler,
-            )
-        except Exception as e:
-            if isinstance(e, RuntimeError):
-                raise e
-            else:
-                raise RuntimeError("Error has occurred!") from e
-        finally:
-            self.is_running = False
+            except Exception as e:
+                if isinstance(e, RuntimeError):
+                    raise e
+                else:
+                    raise RuntimeError("Error has occurred!") from e
+            finally:
+                self.is_running = False
+
+        return RunContext.enter(
+            RunInstance(emitter=self.emitter),
+            RunContextInput(signal=options.signal if options else None, params=(run_input, options)),
+            handler,
+        )
 
     @abstractmethod
     async def _run(self, run_input: BeeRunInput, options: BeeRunOptions | None, context: RunContext) -> T:
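The restructuring matters because run() only schedules the work: RunContext.enter receives the handler, so a try/finally wrapped around the enter call would reset is_running (and translate errors) before _run had actually executed. Moving the guard inside the handler ties it to the awaited work. A standalone sketch of the general pattern, assuming nothing from the framework:

import asyncio

def schedule(work):
    state = {"running": True}

    async def handler():
        try:
            return await work()
        finally:
            state["running"] = False  # cleared only when the work finishes

    return handler(), state  # nothing has run yet at this point

async def main():
    async def work():
        await asyncio.sleep(0)
        return "done"

    coro, state = schedule(work)
    print(state["running"])  # True: the coroutine has not been awaited yet
    print(await coro)        # "done"
    print(state["running"])  # False

asyncio.run(main())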
5 changes: 3 additions & 2 deletions python/beeai_framework/agents/bee/agent.py
@@ -38,13 +38,14 @@
 from beeai_framework.context import RunContext
 from beeai_framework.emitter import Emitter, EmitterInput
 from beeai_framework.memory import BaseMemory
+from beeai_framework.utils.models import ModelLike, to_model
 
 
 class BeeAgent(BaseAgent[BeeRunOutput]):
     runner: Callable[..., BaseRunner]
 
-    def __init__(self, bee_input: BeeInput) -> None:
-        self.input = bee_input
+    def __init__(self, bee_input: ModelLike[BeeInput]) -> None:
+        self.input = to_model(BeeInput, bee_input)
         if "granite" in self.input.llm.model_id:
             self.runner = GraniteRunner
         else:
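With ModelLike and to_model in place, BeeAgent should accept either a validated BeeInput instance or a plain dict describing one. A self-contained sketch of the pattern (ExampleInput and to_model_sketch are stand-ins, not framework code):

from pydantic import BaseModel

class ExampleInput(BaseModel):  # stand-in for BeeInput
    model_id: str

def to_model_sketch(cls, value):  # mirrors what to_model evidently does
    return value if isinstance(value, cls) else cls.model_validate(value)

print(to_model_sketch(ExampleInput, {"model_id": "granite3.1-dense:8b"}))
print(to_model_sketch(ExampleInput, ExampleInput(model_id="llama3.1:8b")))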
27 changes: 10 additions & 17 deletions python/beeai_framework/backend/chat.py
@@ -18,7 +18,7 @@
 from collections.abc import AsyncGenerator, Awaitable, Callable
 from typing import Annotated, Any, Literal, Self, TypeVar
 
-from pydantic import BaseModel, BeforeValidator, ConfigDict, Field, ValidationError
+from pydantic import BaseModel, BeforeValidator, ConfigDict, Field, InstanceOf, ValidationError
 
 from beeai_framework.backend.constants import ProviderName
 from beeai_framework.backend.errors import ChatModelError
@@ -96,17 +96,10 @@ class ChatModelUsage(BaseModel):
     total_tokens: int
 
 
-class ChatModelOutput:
-    def __init__(
-        self,
-        *,
-        messages: list[Message],
-        usage: ChatModelUsage | None = None,
-        finish_reason: str | None = None,
-    ) -> None:
-        self.messages = messages
-        self.usage = usage
-        self.finish_reason = finish_reason
+class ChatModelOutput(BaseModel):
+    messages: list[InstanceOf[Message]]
+    usage: InstanceOf[ChatModelUsage] | None = None
+    finish_reason: str | None = None
 
     @classmethod
     def from_chunks(cls, chunks: list) -> Self:
@@ -230,7 +223,7 @@ async def executor(_: RetryableContext) -> Awaitable[ChatModelOutput]:
 
         return retryable_state.value
 
-    def create(self, chat_model_input: ModelLike[ChatModelInput]) -> Run:
+    def create(self, chat_model_input: ModelLike[ChatModelInput]) -> Run[ChatModelOutput]:
         input = to_model(ChatModelInput, chat_model_input)
 
         async def run_create(context: RunContext) -> ChatModelOutput:
@@ -254,11 +247,11 @@ async def run_create(context: RunContext) -> ChatModelOutput:
                 await context.emitter.emit("success", {"value": result})
                 return result
             except ChatModelError as error:
-                await context.emitter.emit("error", {input, error})
+                await context.emitter.emit("error", error)
                 raise error
-            except Exception as ex:
-                await context.emitter.emit("error", {input, ex})
-                raise ChatModelError("Model error has occurred.") from ex
+            except Exception as error:
+                await context.emitter.emit("error", error)
+                raise ChatModelError("Model error has occurred.") from error
             finally:
                 await context.emitter.emit("finish", None)
 
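Two effects of this file's changes are worth spelling out: ChatModelOutput is now a pydantic model whose InstanceOf fields accept only real Message and ChatModelUsage instances (no dict coercion), and create() now advertises its result type as Run[ChatModelOutput]. A runnable sketch of the InstanceOf behavior, using a stand-in Message class rather than the framework's:

from pydantic import BaseModel, InstanceOf, ValidationError

class Message(BaseModel):  # stand-in for the framework's Message
    content: str

class OutputSketch(BaseModel):  # shaped like the new ChatModelOutput
    messages: list[InstanceOf[Message]]
    finish_reason: str | None = None

OutputSketch(messages=[Message(content="hi")])  # accepted
try:
    OutputSketch(messages=[{"content": "hi"}])  # rejected: not an instance
except ValidationError:
    print("plain dicts are not coerced into Message")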
4 changes: 2 additions & 2 deletions python/beeai_framework/context.py
@@ -58,7 +58,7 @@ def observe(self, fn: Callable[[Emitter], Any]) -> Self:
         self.tasks.append((fn, self.run_context.emitter))
         return self
 
-    def context(self, context: "RunContext") -> Self:
+    def context(self, context: dict) -> Self:
         self.tasks.append((self._set_context, context))
         return self
 
@@ -73,7 +73,7 @@ async def _run_tasks(self) -> R:
         self.tasks.clear()
         return await self.handler()
 
-    def _set_context(self, context: "RunContext") -> None:
+    def _set_context(self, context: dict) -> None:
         self.run_context.context = context
         self.run_context.emitter.context = context
 
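In other words, Run.context() attaches ad-hoc metadata as a plain dict; the old "RunContext" annotation evidently never matched what callers pass. A hedged usage sketch (chat_model, messages, and the emitter.on wiring are assumed from the rest of this diff, not confirmed API):

run = chat_model.create({"messages": messages})  # returns Run[ChatModelOutput]
run.context({"trace_id": "abc-123"})             # a dict, not a RunContext
run.observe(lambda emitter: emitter.on("error", print))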
12 changes: 7 additions & 5 deletions python/beeai_framework/emitter/emitter.py
@@ -20,7 +20,7 @@
 import uuid
 from collections.abc import Callable
 from datetime import UTC, datetime
-from typing import Any, Generic, TypeVar
+from typing import Any, Generic, ParamSpec, TypeAlias, TypeVar
 
 from pydantic import BaseModel, ConfigDict, InstanceOf
 
@@ -30,13 +30,15 @@
     assert_valid_name,
     assert_valid_namespace,
 )
+from beeai_framework.utils.types import MaybeAsync
 
 T = TypeVar("T", bound=BaseModel)
+P = ParamSpec("P")
 
-MatcherFn = Callable[["EventMeta"], bool]
-Matcher = str | MatcherFn
-Callback = Callable[[Any, "EventMeta"], Any]
-CleanupFn = Callable[[], None]
+MatcherFn: TypeAlias = Callable[["EventMeta"], bool]
+Matcher: TypeAlias = str | MatcherFn
+Callback: TypeAlias = MaybeAsync[[P, "EventMeta"], None]
+CleanupFn: TypeAlias = Callable[[], None]
 
 
 class Listener(BaseModel):
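The practical effect of the widened Callback alias: handlers registered on an Emitter may now be either synchronous or asynchronous (MaybeAsync covers both). A hedged sketch, assuming an emitter.on(event, callback) registration API and a meta.name attribute, neither of which is confirmed by this diff:

def on_update(data, meta):  # synchronous handler
    print("sync:", meta.name, data)

async def on_update_async(data, meta):  # async handler, awaited by the emitter
    print("async:", meta.name, data)

emitter.on("update", on_update)
emitter.on("update", on_update_async)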
File renamed without changes. (Evidently python/beeai_framework/utils/_types.py → python/beeai_framework/utils/types.py, given the MaybeAsync import updates in emitter.py above and workflow.py below.)
2 changes: 1 addition & 1 deletion python/beeai_framework/workflows/workflow.py
@@ -20,8 +20,8 @@
 from pydantic import BaseModel
 from typing_extensions import TypeVar
 
-from beeai_framework.utils._types import MaybeAsync
 from beeai_framework.utils.models import ModelLike, check_model, to_model
+from beeai_framework.utils.types import MaybeAsync
 from beeai_framework.workflows.errors import WorkflowError
 
 T = TypeVar("T", bound=BaseModel)
12 changes: 11 additions & 1 deletion python/tests/backend/test_chatmodel.py
@@ -14,6 +14,7 @@
 
 
 import asyncio
+import os
 from collections.abc import AsyncGenerator
 
 import pytest
@@ -130,8 +131,17 @@ async def test_chat_model_abort(reverse_words_chat: ChatModel, chat_messages_lis
 
 @pytest.mark.unit
 def test_chat_model_from() -> None:
-    ollama_chat_model = ChatModel.from_name("ollama:llama3.1")
+    # Ollama with Llama model and base_url specified in code
+    os.environ.pop("OLLAMA_BASE_URL", None)
+    ollama_chat_model = ChatModel.from_name("ollama:llama3.1", {"base_url": "http://somewhere:12345"})
     assert isinstance(ollama_chat_model, OllamaChatModel)
+    assert ollama_chat_model.settings["base_url"] == "http://somewhere:12345"
+
+    # Ollama with Granite model and base_url specified in env var
+    os.environ["OLLAMA_BASE_URL"] = "http://somewhere-else:12345"
+    ollama_chat_model = ChatModel.from_name("ollama:granite3.1-dense:8b")
+    assert isinstance(ollama_chat_model, OllamaChatModel)
+    assert ollama_chat_model.settings["base_url"] == "http://somewhere-else:12345"
 
     watsonx_chat_model = ChatModel.from_name("watsonx:ibm/granite-3-8b-instruct")
     assert isinstance(watsonx_chat_model, WatsonxChatModel)
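Read alongside the adapter change above, these tests pin down the intended resolution order for base_url: an explicit settings entry when the env var is absent, OLLAMA_BASE_URL when it is set, and the localhost default otherwise. A standalone sketch of that presumably intended precedence (the helper is ours, not the framework's):

import os

def resolve_base_url(settings: dict | None) -> str:
    if settings and "base_url" in settings:
        return settings["base_url"]  # explicit settings win
    return os.getenv("OLLAMA_BASE_URL", "http://localhost:11434")

assert resolve_base_url({"base_url": "http://somewhere:12345"}) == "http://somewhere:12345"
os.environ["OLLAMA_BASE_URL"] = "http://somewhere-else:12345"
assert resolve_base_url(None) == "http://somewhere-else:12345"
del os.environ["OLLAMA_BASE_URL"]
assert resolve_base_url(None) == "http://localhost:11434"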
