Skip to content

Commit

Permalink
Update docs and usages, simplify capabilities
Browse files Browse the repository at this point in the history
  • Loading branch information
jackgerrits committed Nov 25, 2024
1 parent df73f35 commit e4dc62b
Show file tree
Hide file tree
Showing 7 changed files with 19 additions and 45 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -137,16 +137,12 @@
"token_provider = get_bearer_token_provider(DefaultAzureCredential(), \"https://cognitiveservices.azure.com/.default\")\n",
"\n",
"az_model_client = AzureOpenAIChatCompletionClient(\n",
" model=\"{your-azure-deployment}\",\n",
" azure_deployment=\"{your-azure-deployment}\",\n",
" model=\"{model-name, such as gpt-4o}\",\n",
" api_version=\"2024-06-01\",\n",
" azure_endpoint=\"https://{your-custom-endpoint}.openai.azure.com/\",\n",
" azure_ad_token_provider=token_provider, # Optional if you choose key-based authentication.\n",
" # api_key=\"sk-...\", # For key-based authentication.\n",
" model_capabilities={\n",
" \"vision\": True,\n",
" \"function_calling\": True,\n",
" \"json_output\": True,\n",
" },\n",
")"
]
},
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,15 +24,11 @@ token_provider = get_bearer_token_provider(
)

client = AzureOpenAIChatCompletionClient(
model="{your-azure-deployment}",
azure_deployment="{your-azure-deployment}",
model="{model-name, such as gpt-4o}",
api_version="2024-02-01",
azure_endpoint="https://{your-custom-endpoint}.openai.azure.com/",
azure_ad_token_provider=token_provider,
model_capabilities={
"vision":True,
"function_calling":True,
"json_output":True,
}
)
```

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@
},
{
"cell_type": "code",
"execution_count": 7,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -79,15 +79,11 @@
"\n",
"# Create the client with type-checked environment variables\n",
"client = AzureOpenAIChatCompletionClient(\n",
" model=get_env_variable(\"AZURE_OPENAI_DEPLOYMENT_NAME\"),\n",
" azure_deployment=get_env_variable(\"AZURE_OPENAI_DEPLOYMENT_NAME\"),\n",
" model=get_env_variable(\"AZURE_OPENAI_MODEL\"),\n",
" api_version=get_env_variable(\"AZURE_OPENAI_API_VERSION\"),\n",
" azure_endpoint=get_env_variable(\"AZURE_OPENAI_ENDPOINT\"),\n",
" api_key=get_env_variable(\"AZURE_OPENAI_API_KEY\"),\n",
" model_capabilities={\n",
" \"vision\": False,\n",
" \"function_calling\": True,\n",
" \"json_output\": True,\n",
" },\n",
")"
]
},
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -283,7 +283,7 @@
},
{
"cell_type": "code",
"execution_count": 15,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -294,16 +294,12 @@
"token_provider = get_bearer_token_provider(DefaultAzureCredential(), \"https://cognitiveservices.azure.com/.default\")\n",
"\n",
"az_model_client = AzureOpenAIChatCompletionClient(\n",
" model=\"{your-azure-deployment}\",\n",
" azure_deployment=\"{your-azure-deployment}\",\n",
" model=\"{model-name, such as gpt-4o}\",\n",
" api_version=\"2024-06-01\",\n",
" azure_endpoint=\"https://{your-custom-endpoint}.openai.azure.com/\",\n",
" azure_ad_token_provider=token_provider, # Optional if you choose key-based authentication.\n",
" # api_key=\"sk-...\", # For key-based authentication.\n",
" model_capabilities={\n",
" \"vision\": True,\n",
" \"function_calling\": True,\n",
" \"json_output\": True,\n",
" },\n",
")"
]
},
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -90,13 +90,13 @@ def _azure_openai_client_from_config(config: Mapping[str, Any]) -> AsyncAzureOpe

if "azure_deployment" not in copied_config and "model" in copied_config:
warnings.warn(
"Previous behavior of using the model name as the deployment name is deprecated and will be removed in 0.4",
"Previous behavior of using the model name as the deployment name is deprecated and will be removed in 0.4. Please specify azure_deployment",
stacklevel=2,
)

if "azure_endpoint" not in copied_config and "base_url" in copied_config:
warnings.warn(
"Previous behavior of using the base_url as the endpoint is deprecated and will be removed in 0.4",
"Previous behavior of using the base_url as the endpoint is deprecated and will be removed in 0.4. Please specify azure_endpoint",
stacklevel=2,
)

Expand Down Expand Up @@ -350,9 +350,7 @@ def __init__(
model_capabilities: Optional[ModelCapabilities] = None,
):
self._client = client
if model_capabilities is None and isinstance(client, AsyncAzureOpenAI):
raise ValueError("AzureOpenAIChatCompletionClient requires explicit model capabilities")
elif model_capabilities is None:
if model_capabilities is None:
self._model_capabilities = _model_info.get_capabilities(create_args["model"])
else:
self._model_capabilities = model_capabilities
Expand Down Expand Up @@ -963,7 +961,7 @@ class AzureOpenAIChatCompletionClient(BaseOpenAIChatCompletionClient):
api_version (str): The API version to use. **Required for Azure models.**
azure_ad_token (str): The Azure AD token to use. Provide this or `azure_ad_token_provider` for token-based authentication.
azure_ad_token_provider (Callable[[], Awaitable[str]]): The Azure AD token provider to use. Provide this or `azure_ad_token` for token-based authentication.
model_capabilities (ModelCapabilities): The capabilities of the model. **Required for Azure models.**
model_capabilities (ModelCapabilities): The capabilities of the model if default resolved values are not correct.
api_key (optional, str): The API key to use, use this if you are using key based authentication. It is optional if you are using Azure AD token based authentication or `AZURE_OPENAI_API_KEY` environment variable.
timeout (optional, int): The timeout for the request in seconds.
max_retries (optional, int): The maximum number of retries to attempt.
Expand All @@ -990,26 +988,19 @@ class AzureOpenAIChatCompletionClient(BaseOpenAIChatCompletionClient):
token_provider = get_bearer_token_provider(DefaultAzureCredential(), "https://cognitiveservices.azure.com/.default")
az_model_client = AzureOpenAIChatCompletionClient(
model="{your-azure-deployment}",
azure_deployment="{your-azure-deployment}",
model="{deployed-model, such as 'gpt-4o'}",
api_version="2024-06-01",
azure_endpoint="https://{your-custom-endpoint}.openai.azure.com/",
azure_ad_token_provider=token_provider, # Optional if you choose key-based authentication.
# api_key="sk-...", # For key-based authentication. `AZURE_OPENAI_API_KEY` environment variable can also be used instead.
model_capabilities={
"vision": True,
"function_calling": True,
"json_output": True,
},
)
See `here <https://learn.microsoft.com/en-us/azure/ai-services/openai/how-to/managed-identity#chat-completions>`_ for how to use the Azure client directly or for more info.
"""

def __init__(self, **kwargs: Unpack[AzureOpenAIClientConfiguration]):
if "model" not in kwargs:
raise ValueError("model is required for OpenAIChatCompletionClient")

model_capabilities: Optional[ModelCapabilities] = None
copied_args = dict(kwargs).copy()
if "model_capabilities" in kwargs:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,14 +30,14 @@ class BaseOpenAIClientConfiguration(CreateArguments, total=False):
api_key: str
timeout: Union[float, None]
max_retries: int
model_capabilities: ModelCapabilities
"""What functionality the model supports, determined by default from model name but is overriden if value passed."""


# See OpenAI docs for explanation of these parameters
class OpenAIClientConfiguration(BaseOpenAIClientConfiguration, total=False):
organization: str
base_url: str
# Not required
model_capabilities: ModelCapabilities


class AzureOpenAIClientConfiguration(BaseOpenAIClientConfiguration, total=False):
Expand All @@ -47,8 +47,6 @@ class AzureOpenAIClientConfiguration(BaseOpenAIClientConfiguration, total=False)
api_version: Required[str]
azure_ad_token: str
azure_ad_token_provider: AsyncAzureADTokenProvider
# Must be provided
model_capabilities: Required[ModelCapabilities]


__all__ = ["AzureOpenAIClientConfiguration", "OpenAIClientConfiguration"]
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,7 @@ async def test_openai_chat_completion_client() -> None:
@pytest.mark.asyncio
async def test_azure_openai_chat_completion_client() -> None:
client = AzureOpenAIChatCompletionClient(
azure_deployment="gpt-4o-1",
model="gpt-4o",
api_key="api_key",
api_version="2020-08-04",
Expand Down

0 comments on commit e4dc62b

Please sign in to comment.