Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: created custom error when env var for providers does not exist or unknown provider, moderator is attempted to be loaded #69

Closed
wants to merge 8 commits into from
13 changes: 13 additions & 0 deletions src/exchange/load_exchange_attribute_error.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
from typing import List


class LoadExchangeAttributeError(Exception):
def __init__(self, attribute_name: str, attribute_value: str, available_values: List[str]) -> None:
self.attribute_name = attribute_name
self.attribute_value = attribute_value
self.available_values = available_values
self.message = (
f"Unknown {attribute_name}: {attribute_value}."
+ f" Available {attribute_name}s: {', '.join(available_values)}"
)
super().__init__(self.message)
6 changes: 5 additions & 1 deletion src/exchange/moderators/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from functools import cache
from typing import Type

from exchange.load_exchange_attribute_error import LoadExchangeAttributeError
from exchange.moderators.base import Moderator
from exchange.utils import load_plugins
from exchange.moderators.passive import PassiveModerator # noqa
Expand All @@ -10,4 +11,7 @@

@cache
def get_moderator(name: str) -> Type[Moderator]:
return load_plugins(group="exchange.moderator")[name]
moderators = load_plugins(group="exchange.moderator")
if name not in moderators:
raise LoadExchangeAttributeError("moderator", name, moderators.keys())
return moderators[name]
6 changes: 5 additions & 1 deletion src/exchange/providers/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from functools import cache
from typing import Type

from exchange.load_exchange_attribute_error import LoadExchangeAttributeError
from exchange.providers.anthropic import AnthropicProvider # noqa
from exchange.providers.base import Provider, Usage # noqa
from exchange.providers.databricks import DatabricksProvider # noqa
Expand All @@ -14,4 +15,7 @@

@cache
def get_provider(name: str) -> Type[Provider]:
return load_plugins(group="exchange.provider")[name]
providers = load_plugins(group="exchange.provider")
if name not in providers:
raise LoadExchangeAttributeError("provider", name, providers.keys())
return providers[name]
7 changes: 2 additions & 5 deletions src/exchange/providers/anthropic.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from exchange.content import Text, ToolResult, ToolUse
from exchange.providers.base import Provider, Usage
from tenacity import retry, wait_fixed, stop_after_attempt
from exchange.providers.utils import retry_if_status
from exchange.providers.utils import get_provider_env_value, retry_if_status
from exchange.providers.utils import raise_for_status

ANTHROPIC_HOST = "https://api.anthropic.com/v1/messages"
Expand All @@ -27,10 +27,7 @@ def __init__(self, client: httpx.Client) -> None:
@classmethod
def from_env(cls: Type["AnthropicProvider"]) -> "AnthropicProvider":
url = os.environ.get("ANTHROPIC_HOST", ANTHROPIC_HOST)
try:
key = os.environ["ANTHROPIC_API_KEY"]
except KeyError:
raise RuntimeError("Failed to get ANTHROPIC_API_KEY from the environment")
key = get_provider_env_value("ANTHROPIC_API_KEY", "anthropic")
client = httpx.Client(
base_url=url,
headers={
Expand Down
30 changes: 10 additions & 20 deletions src/exchange/providers/azure.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
import os
from typing import Type

import httpx

from exchange.providers import OpenAiProvider
from exchange.providers.utils import get_provider_env_value


class AzureProvider(OpenAiProvider):
Expand All @@ -14,25 +14,11 @@ def __init__(self, client: httpx.Client) -> None:

@classmethod
def from_env(cls: Type["AzureProvider"]) -> "AzureProvider":
try:
url = os.environ["AZURE_CHAT_COMPLETIONS_HOST_NAME"]
except KeyError:
raise RuntimeError("Failed to get AZURE_CHAT_COMPLETIONS_HOST_NAME from the environment.")

try:
deployment_name = os.environ["AZURE_CHAT_COMPLETIONS_DEPLOYMENT_NAME"]
except KeyError:
raise RuntimeError("Failed to get AZURE_CHAT_COMPLETIONS_DEPLOYMENT_NAME from the environment.")

try:
api_version = os.environ["AZURE_CHAT_COMPLETIONS_DEPLOYMENT_API_VERSION"]
except KeyError:
raise RuntimeError("Failed to get AZURE_CHAT_COMPLETIONS_DEPLOYMENT_API_VERSION from the environment.")

try:
key = os.environ["AZURE_CHAT_COMPLETIONS_KEY"]
except KeyError:
raise RuntimeError("Failed to get AZURE_CHAT_COMPLETIONS_KEY from the environment.")
url = cls._get_env_variable("AZURE_CHAT_COMPLETIONS_HOST_NAME")
deployment_name = cls._get_env_variable("AZURE_CHAT_COMPLETIONS_DEPLOYMENT_NAME")

api_version = cls._get_env_variable("AZURE_CHAT_COMPLETIONS_DEPLOYMENT_API_VERSION")
key = cls._get_env_variable("AZURE_CHAT_COMPLETIONS_KEY")

# format the url host/"openai/deployments/" + deployment_name + "/?api-version=" + api_version
url = f"{url}/openai/deployments/{deployment_name}/"
Expand All @@ -43,3 +29,7 @@ def from_env(cls: Type["AzureProvider"]) -> "AzureProvider":
timeout=httpx.Timeout(60 * 10),
)
return cls(client)

@classmethod
def _get_env_variable(cls: Type["AzureProvider"], key: str) -> str:
return get_provider_env_value(key, "azure")
13 changes: 12 additions & 1 deletion src/exchange/providers/base.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from abc import ABC, abstractmethod
from attrs import define, field
from typing import List, Tuple, Type
from typing import List, Optional, Tuple, Type

from exchange.message import Message
from exchange.tool import Tool
Expand Down Expand Up @@ -28,3 +28,14 @@ def complete(
) -> Tuple[Message, Usage]:
"""Generate the next message using the specified model"""
pass


class MissingProviderEnvVariableError(Exception):
def __init__(self, env_variable: str, provider: str, instructions_url: Optional[str] = None) -> None:
self.env_variable = env_variable
self.provider = provider
self.instructions_url = instructions_url
self.message = f"Missing environment variable: {env_variable} for provider {provider}."
if instructions_url:
self.message += f"\n Please see {instructions_url} for instructions"
super().__init__(self.message)
15 changes: 8 additions & 7 deletions src/exchange/providers/bedrock.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from exchange.message import Message
from exchange.providers import Provider, Usage
from tenacity import retry, wait_fixed, stop_after_attempt
from exchange.providers.utils import retry_if_status
from exchange.providers.utils import get_provider_env_value, retry_if_status
from exchange.providers.utils import raise_for_status
from exchange.tool import Tool

Expand Down Expand Up @@ -154,12 +154,9 @@ def __init__(self, client: AwsClient) -> None:
@classmethod
def from_env(cls: Type["BedrockProvider"]) -> "BedrockProvider":
aws_region = os.environ.get("AWS_REGION", "us-east-1")
try:
aws_access_key = os.environ["AWS_ACCESS_KEY_ID"]
aws_secret_key = os.environ["AWS_SECRET_ACCESS_KEY"]
aws_session_token = os.environ.get("AWS_SESSION_TOKEN")
except KeyError:
raise RuntimeError("Failed to get AWS_ACCESS_KEY_ID or AWS_SECRET_ACCESS_KEY from the environment")
aws_access_key = cls._get_env_variable("AWS_ACCESS_KEY_ID")
aws_secret_key = cls._get_env_variable("AWS_SECRET_ACCESS_KEY")
aws_session_token = cls._get_env_variable("AWS_SESSION_TOKEN")

client = AwsClient(
aws_region=aws_region,
Expand Down Expand Up @@ -326,3 +323,7 @@ def tools_to_bedrock_spec(tools: Tuple[Tool]) -> Optional[dict]:
tools_added.add(tool.name)
tool_config = {"tools": tool_config_list}
return tool_config

@classmethod
def _get_env_variable(cls: Type["BedrockProvider"], key: str) -> str:
return get_provider_env_value(key, "bedrock")
22 changes: 8 additions & 14 deletions src/exchange/providers/databricks.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
import os
from typing import Any, Dict, List, Tuple, Type

import httpx

from exchange.message import Message
from exchange.providers.base import Provider, Usage
from tenacity import retry, wait_fixed, stop_after_attempt
from exchange.providers.utils import raise_for_status, retry_if_status
from exchange.providers.utils import get_provider_env_value, raise_for_status, retry_if_status
from exchange.providers.utils import (
messages_to_openai_spec,
openai_response_to_message,
Expand Down Expand Up @@ -37,18 +36,8 @@ def __init__(self, client: httpx.Client) -> None:

@classmethod
def from_env(cls: Type["DatabricksProvider"]) -> "DatabricksProvider":
try:
url = os.environ["DATABRICKS_HOST"]
except KeyError:
raise RuntimeError(
"Failed to get DATABRICKS_HOST from the environment. See https://docs.databricks.com/en/dev-tools/auth/index.html#general-host-token-and-account-id-environment-variables-and-fields"
)
try:
key = os.environ["DATABRICKS_TOKEN"]
except KeyError:
raise RuntimeError(
"Failed to get DATABRICKS_TOKEN from the environment. See https://docs.databricks.com/en/dev-tools/auth/index.html#general-host-token-and-account-id-environment-variables-and-fields"
)
url = cls._get_env_variable("DATABRICKS_HOST")
key = cls._get_env_variable("DATABRICKS_TOKEN")
client = httpx.Client(
base_url=url,
auth=("token", key),
Expand Down Expand Up @@ -100,3 +89,8 @@ def _post(self, model: str, payload: dict) -> httpx.Response:
json=payload,
)
return raise_for_status(response).json()

@classmethod
def _get_env_variable(cls: Type["DatabricksProvider"], key: str) -> str:
instruction = "https://docs.databricks.com/en/dev-tools/auth/index.html#general-host-token-and-account-id-environment-variables-and-fields"
return get_provider_env_value(key, "databricks", instruction)
11 changes: 3 additions & 8 deletions src/exchange/providers/google.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from exchange.content import Text, ToolResult, ToolUse
from exchange.providers.base import Provider, Usage
from tenacity import retry, wait_fixed, stop_after_attempt
from exchange.providers.utils import retry_if_status
from exchange.providers.utils import get_provider_env_value, retry_if_status
from exchange.providers.utils import raise_for_status

GOOGLE_HOST = "https://generativelanguage.googleapis.com/v1beta"
Expand All @@ -27,13 +27,8 @@ def __init__(self, client: httpx.Client) -> None:
@classmethod
def from_env(cls: Type["GoogleProvider"]) -> "GoogleProvider":
url = os.environ.get("GOOGLE_HOST", GOOGLE_HOST)
try:
key = os.environ["GOOGLE_API_KEY"]
except KeyError:
raise RuntimeError(
"Failed to get GOOGLE_API_KEY from the environment, see https://ai.google.dev/gemini-api/docs/api-key"
)

api_key_instructions_url = "see https://ai.google.dev/gemini-api/docs/api-key"
key = get_provider_env_value("GOOGLE_API_KEY", "google", api_key_instructions_url)
client = httpx.Client(
base_url=url,
headers={
Expand Down
9 changes: 3 additions & 6 deletions src/exchange/providers/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from exchange.message import Message
from exchange.providers.base import Provider, Usage
from exchange.providers.utils import (
get_provider_env_value,
messages_to_openai_spec,
openai_response_to_message,
openai_single_message_context_length_exceeded,
Expand Down Expand Up @@ -36,12 +37,8 @@ def __init__(self, client: httpx.Client) -> None:
@classmethod
def from_env(cls: Type["OpenAiProvider"]) -> "OpenAiProvider":
url = os.environ.get("OPENAI_HOST", OPENAI_HOST)
try:
key = os.environ["OPENAI_API_KEY"]
except KeyError:
raise RuntimeError(
"Failed to get OPENAI_API_KEY from the environment, see https://platform.openai.com/docs/api-reference/api-keys"
)
api_key_instructions_url = "see https://platform.openai.com/docs/api-reference/api-keys"
key = get_provider_env_value("OPENAI_API_KEY", "openai", api_key_instructions_url)
client = httpx.Client(
base_url=url + "v1/",
auth=("Bearer", key),
Expand Down
9 changes: 9 additions & 0 deletions src/exchange/providers/utils.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
import base64
import json
import os
import re
from typing import Any, Callable, Dict, List, Optional, Tuple

import httpx
from exchange.content import Text, ToolResult, ToolUse
from exchange.message import Message
from exchange.providers.base import MissingProviderEnvVariableError
from exchange.tool import Tool
from tenacity import retry_if_exception

Expand Down Expand Up @@ -179,6 +181,13 @@ def openai_single_message_context_length_exceeded(error_dict: dict) -> None:
raise InitialMessageTooLargeError(f"Input message too long. Message: {error_dict.get('message')}")


def get_provider_env_value(env_variable: str, provider: str, instructions_url: Optional[str] = None) -> str:
try:
return os.environ[env_variable]
except KeyError:
raise MissingProviderEnvVariableError(env_variable, provider, instructions_url)


class InitialMessageTooLargeError(Exception):
"""Custom error raised when the first input message in an exchange is too large."""

Expand Down
10 changes: 10 additions & 0 deletions tests/providers/test_anthropic.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from exchange import Message, Text
from exchange.content import ToolResult, ToolUse
from exchange.providers.anthropic import AnthropicProvider
from exchange.providers.base import MissingProviderEnvVariableError
from exchange.tool import Tool


Expand All @@ -25,6 +26,15 @@ def anthropic_provider():
return AnthropicProvider.from_env()


def test_from_env_throw_error_when_missing_api_key():
with patch.dict(os.environ, {}, clear=True):
with pytest.raises(MissingProviderEnvVariableError) as context:
AnthropicProvider.from_env()
assert context.value.provider == "anthropic"
assert context.value.env_variable == "ANTHROPIC_API_KEY"
assert context.value.message == "Missing environment variable: ANTHROPIC_API_KEY for provider anthropic."


def test_anthropic_response_to_text_message() -> None:
response = {
"content": [{"type": "text", "text": "Hello from Claude!"}],
Expand Down
30 changes: 30 additions & 0 deletions tests/providers/test_azure.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,44 @@
import os
from unittest.mock import patch

import pytest

from exchange import Text, ToolUse
from exchange.providers.azure import AzureProvider
from exchange.providers.base import MissingProviderEnvVariableError
from .conftest import complete, tools

AZURE_MODEL = os.getenv("AZURE_MODEL", "gpt-4o-mini")


@pytest.mark.parametrize(
"env_var_name",
[
("AZURE_CHAT_COMPLETIONS_HOST_NAME"),
("AZURE_CHAT_COMPLETIONS_DEPLOYMENT_NAME"),
("AZURE_CHAT_COMPLETIONS_DEPLOYMENT_API_VERSION"),
("AZURE_CHAT_COMPLETIONS_KEY"),
],
)
def test_from_env_throw_error_when_missing_env_var(env_var_name):
with patch.dict(
os.environ,
{
"AZURE_CHAT_COMPLETIONS_HOST_NAME": "test_host_name",
"AZURE_CHAT_COMPLETIONS_DEPLOYMENT_NAME": "test_deployment_name",
"AZURE_CHAT_COMPLETIONS_DEPLOYMENT_API_VERSION": "test_api_version",
"AZURE_CHAT_COMPLETIONS_KEY": "test_api_key",
},
clear=True,
):
os.environ.pop(env_var_name)
with pytest.raises(MissingProviderEnvVariableError) as context:
AzureProvider.from_env()
assert context.value.provider == "azure"
assert context.value.env_variable == env_var_name
assert context.value.message == f"Missing environment variable: {env_var_name} for provider azure."


@pytest.mark.vcr()
def test_azure_complete(default_azure_env):
reply_message, reply_usage = complete(AzureProvider, AZURE_MODEL)
Expand Down
Loading