-
Notifications
You must be signed in to change notification settings - Fork 1.4k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feat(llms): support OpenAI v1 for Azure OpenAI #755
Changes from 1 commit
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -12,13 +12,13 @@ | |||||||||||||||||||||||
""" | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
import os | ||||||||||||||||||||||||
from typing import Any, Dict, Optional | ||||||||||||||||||||||||
from typing import Any, Dict, Optional, Union | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
import openai | ||||||||||||||||||||||||
from ..helpers import load_dotenv | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
from ..exceptions import APIKeyNotFoundError, MissingModelError | ||||||||||||||||||||||||
from ..prompts.base import AbstractPrompt | ||||||||||||||||||||||||
from ..helpers.openai import is_openai_v1 | ||||||||||||||||||||||||
from .base import BaseOpenAI | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
load_dotenv() | ||||||||||||||||||||||||
|
@@ -29,73 +29,115 @@ class AzureOpenAI(BaseOpenAI): | |||||||||||||||||||||||
This class uses `BaseOpenAI` class to support Azure OpenAI features. | ||||||||||||||||||||||||
""" | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
azure_endpoint: Union[str, None] = None | ||||||||||||||||||||||||
"""Your Azure Active Directory token. | ||||||||||||||||||||||||
Automatically inferred from env var `AZURE_OPENAI_AD_TOKEN` if not provided. | ||||||||||||||||||||||||
For more: | ||||||||||||||||||||||||
https://www.microsoft.com/en-us/security/business/identity-access/microsoft-entra-id. | ||||||||||||||||||||||||
""" | ||||||||||||||||||||||||
azure_ad_token: Union[str, None] = None | ||||||||||||||||||||||||
"""A function that returns an Azure Active Directory token. | ||||||||||||||||||||||||
Will be invoked on every request. | ||||||||||||||||||||||||
""" | ||||||||||||||||||||||||
azure_ad_token_provider: Union[str, None] = None | ||||||||||||||||||||||||
deployment_name: str | ||||||||||||||||||||||||
api_version: str = "" | ||||||||||||||||||||||||
"""Legacy, for openai<1.0.0 support.""" | ||||||||||||||||||||||||
api_base: str | ||||||||||||||||||||||||
"""Legacy, for openai<1.0.0 support.""" | ||||||||||||||||||||||||
api_type: str = "azure" | ||||||||||||||||||||||||
api_version: str | ||||||||||||||||||||||||
engine: str | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
def __init__( | ||||||||||||||||||||||||
self, | ||||||||||||||||||||||||
api_token: Optional[str] = None, | ||||||||||||||||||||||||
api_base: Optional[str] = None, | ||||||||||||||||||||||||
api_version: Optional[str] = None, | ||||||||||||||||||||||||
deployment_name: str = None, | ||||||||||||||||||||||||
is_chat_model: bool = True, | ||||||||||||||||||||||||
**kwargs, | ||||||||||||||||||||||||
self, | ||||||||||||||||||||||||
api_token: Optional[str] = None, | ||||||||||||||||||||||||
azure_endpoint: Union[str, None] = None, | ||||||||||||||||||||||||
azure_ad_token: Union[str, None] = None, | ||||||||||||||||||||||||
azure_ad_token_provider: Union[str, None] = None, | ||||||||||||||||||||||||
api_base: Optional[str] = None, | ||||||||||||||||||||||||
api_version: Optional[str] = None, | ||||||||||||||||||||||||
deployment_name: str = None, | ||||||||||||||||||||||||
is_chat_model: bool = True, | ||||||||||||||||||||||||
**kwargs, | ||||||||||||||||||||||||
): | ||||||||||||||||||||||||
""" | ||||||||||||||||||||||||
__init__ method of AzureOpenAI Class. | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
Args: | ||||||||||||||||||||||||
api_token (str): Azure OpenAI API token. | ||||||||||||||||||||||||
api_base (str): Base url of the Azure endpoint. | ||||||||||||||||||||||||
azure_endpoint (str): Azure endpoint. | ||||||||||||||||||||||||
It should look like the following: | ||||||||||||||||||||||||
<https://YOUR_RESOURCE_NAME.openai.azure.com/> | ||||||||||||||||||||||||
azure_ad_token (str): Your Azure Active Directory token. | ||||||||||||||||||||||||
Automatically inferred from env var `AZURE_OPENAI_AD_TOKEN` if not provided. | ||||||||||||||||||||||||
For more: https://www.microsoft.com/en-us/security/business/identity-access/microsoft-entra-id. | ||||||||||||||||||||||||
azure_ad_token_provider (str): A function that returns an Azure Active Directory token. | ||||||||||||||||||||||||
Will be invoked on every request. | ||||||||||||||||||||||||
api_version (str): Version of the Azure OpenAI API. | ||||||||||||||||||||||||
Be aware the API version may change. | ||||||||||||||||||||||||
api_base (str): Legacy, kept for backward compatibility with openai < 1.0 | ||||||||||||||||||||||||
deployment_name (str): Custom name of the deployed model | ||||||||||||||||||||||||
is_chat_model (bool): Whether ``deployment_name`` corresponds to a Chat | ||||||||||||||||||||||||
or a Completion model. | ||||||||||||||||||||||||
**kwargs: Inference Parameters. | ||||||||||||||||||||||||
""" | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
self.api_token = api_token or os.getenv("OPENAI_API_KEY") or None | ||||||||||||||||||||||||
self.api_base = api_base or os.getenv("OPENAI_API_BASE") or None | ||||||||||||||||||||||||
self.api_token = ( | ||||||||||||||||||||||||
api_token | ||||||||||||||||||||||||
or os.getenv("OPENAI_API_KEY") | ||||||||||||||||||||||||
or os.getenv("AZURE_OPENAI_API_KEY") | ||||||||||||||||||||||||
) | ||||||||||||||||||||||||
self.azure_endpoint = azure_endpoint or os.getenv("AZURE_OPENAI_ENDPOINT") | ||||||||||||||||||||||||
self.api_base = api_base or os.getenv("OPENAI_API_BASE") | ||||||||||||||||||||||||
self.api_version = api_version or os.getenv("OPENAI_API_VERSION") | ||||||||||||||||||||||||
if self.api_token is None: | ||||||||||||||||||||||||
raise APIKeyNotFoundError( | ||||||||||||||||||||||||
"Azure OpenAI key is required. Please add an environment variable " | ||||||||||||||||||||||||
"`OPENAI_API_KEY` or pass `api_token` as a named parameter" | ||||||||||||||||||||||||
) | ||||||||||||||||||||||||
if self.api_base is None: | ||||||||||||||||||||||||
raise APIKeyNotFoundError( | ||||||||||||||||||||||||
"Azure OpenAI base is required. Please add an environment variable " | ||||||||||||||||||||||||
"`OPENAI_API_BASE` or pass `api_base` as a named parameter" | ||||||||||||||||||||||||
) | ||||||||||||||||||||||||
if is_openai_v1(): | ||||||||||||||||||||||||
if self.azure_endpoint is None: | ||||||||||||||||||||||||
raise APIKeyNotFoundError( | ||||||||||||||||||||||||
"Azure endpoint is required. Please add an environment variable " | ||||||||||||||||||||||||
"`AZURE_OPENAI_API_ENDPOINT` or pass `azure_endpoint` as a named parameter" | ||||||||||||||||||||||||
) | ||||||||||||||||||||||||
else: | ||||||||||||||||||||||||
if self.api_base is None: | ||||||||||||||||||||||||
raise APIKeyNotFoundError( | ||||||||||||||||||||||||
"Azure OpenAI base is required. Please add an environment variable " | ||||||||||||||||||||||||
"`OPENAI_API_BASE` or pass `api_base` as a named parameter" | ||||||||||||||||||||||||
) | ||||||||||||||||||||||||
if self.api_version is None: | ||||||||||||||||||||||||
raise APIKeyNotFoundError( | ||||||||||||||||||||||||
"Azure OpenAI version is required. Please add an environment variable " | ||||||||||||||||||||||||
"`OPENAI_API_VERSION` or pass `api_version` as a named parameter" | ||||||||||||||||||||||||
) | ||||||||||||||||||||||||
openai.api_key = self.api_token | ||||||||||||||||||||||||
openai.api_base = self.api_base | ||||||||||||||||||||||||
openai.api_version = self.api_version | ||||||||||||||||||||||||
openai.api_type = self.api_type | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
if deployment_name is None: | ||||||||||||||||||||||||
raise MissingModelError( | ||||||||||||||||||||||||
"No deployment name provided.", | ||||||||||||||||||||||||
"Please include deployment name from Azure dashboard.", | ||||||||||||||||||||||||
) | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
self.is_chat_model = is_chat_model | ||||||||||||||||||||||||
self.engine = deployment_name | ||||||||||||||||||||||||
self.azure_ad_token = azure_ad_token or os.getenv("AZURE_OPENAI_AD_TOKEN") | ||||||||||||||||||||||||
self.azure_ad_token_provider = azure_ad_token_provider | ||||||||||||||||||||||||
self._is_chat_model = is_chat_model | ||||||||||||||||||||||||
self.deployment_name = deployment_name | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
self.openai_proxy = kwargs.get("openai_proxy") or os.getenv("OPENAI_PROXY") | ||||||||||||||||||||||||
if self.openai_proxy: | ||||||||||||||||||||||||
openai.proxy = {"http": self.openai_proxy, "https": self.openai_proxy} | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
self._set_params(**kwargs) | ||||||||||||||||||||||||
# set the client | ||||||||||||||||||||||||
if self._is_chat_model: | ||||||||||||||||||||||||
if is_openai_v1(): | ||||||||||||||||||||||||
self.client = openai.AzureOpenAI(**self._client_params).chat.completions | ||||||||||||||||||||||||
else: | ||||||||||||||||||||||||
self.client = openai.ChatCompletion | ||||||||||||||||||||||||
else: | ||||||||||||||||||||||||
if is_openai_v1(): | ||||||||||||||||||||||||
self.client = openai.AzureOpenAI(**self._client_params).completions | ||||||||||||||||||||||||
else: | ||||||||||||||||||||||||
self.client = openai.Completion | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
@property | ||||||||||||||||||||||||
def _default_params(self) -> Dict[str, Any]: | ||||||||||||||||||||||||
|
@@ -106,27 +148,30 @@ def _default_params(self) -> Dict[str, Any]: | |||||||||||||||||||||||
dict: A dictionary containing Default Params. | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
""" | ||||||||||||||||||||||||
return {**super()._default_params, "engine": self.engine} | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
def call(self, instruction: AbstractPrompt, suffix: str = "") -> str: | ||||||||||||||||||||||||
""" | ||||||||||||||||||||||||
Call the Azure OpenAI LLM. | ||||||||||||||||||||||||
return {**super()._default_params, "model" if is_openai_v1() else "engine": self.deployment_name} | ||||||||||||||||||||||||
Comment on lines
147
to
+150
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. There is a syntax error in the - return {**super()._default_params, "model" if is_openai_v1() else "engine": self.deployment_name}
+ return {**super()._default_params, ("model" if is_openai_v1() else "engine"): self.deployment_name} Commitable suggestion
Suggested change
|
||||||||||||||||||||||||
|
||||||||||||||||||||||||
Args: | ||||||||||||||||||||||||
instruction (AbstractPrompt): A prompt object with instruction for LLM. | ||||||||||||||||||||||||
suffix (str): Suffix to pass. | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
Returns: | ||||||||||||||||||||||||
str: LLM response. | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
""" | ||||||||||||||||||||||||
self.last_prompt = instruction.to_string() + suffix | ||||||||||||||||||||||||
@property | ||||||||||||||||||||||||
def _invocation_params(self) -> Dict[str, Any]: | ||||||||||||||||||||||||
"""Get the parameters used to invoke the model.""" | ||||||||||||||||||||||||
if is_openai_v1(): | ||||||||||||||||||||||||
return super()._invocation_params | ||||||||||||||||||||||||
else: | ||||||||||||||||||||||||
return { | ||||||||||||||||||||||||
**super()._invocation_params, | ||||||||||||||||||||||||
"api_type": self.api_type, | ||||||||||||||||||||||||
"api_version": self.api_version, | ||||||||||||||||||||||||
} | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
return ( | ||||||||||||||||||||||||
self.chat_completion(self.last_prompt) | ||||||||||||||||||||||||
if self.is_chat_model | ||||||||||||||||||||||||
else self.completion(self.last_prompt) | ||||||||||||||||||||||||
) | ||||||||||||||||||||||||
@property | ||||||||||||||||||||||||
def _client_params(self) -> Dict[str, any]: | ||||||||||||||||||||||||
client_params = { | ||||||||||||||||||||||||
"api_version": self.api_version, | ||||||||||||||||||||||||
"azure_endpoint": self.azure_endpoint, | ||||||||||||||||||||||||
"azure_deployment": self.deployment_name, | ||||||||||||||||||||||||
"azure_ad_token": self.azure_ad_token, | ||||||||||||||||||||||||
"azure_ad_token_provider": self.azure_ad_token_provider, | ||||||||||||||||||||||||
} | ||||||||||||||||||||||||
return {**client_params, **super()._client_params} | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
@property | ||||||||||||||||||||||||
def type(self) -> str: | ||||||||||||||||||||||||
Comment on lines
147
to
176
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
The properties - return {**super()._default_params, "model" if is_openai_v1() else "engine": self.deployment_name}
+ return {**super()._default_params, ("model" if is_openai_v1() else "engine"): self.deployment_name} |
||||||||||||||||||||||||
|
Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -5,7 +5,11 @@ | |||||||||||||||||||||||||||||||
from pandasai.exceptions import APIKeyNotFoundError, MissingModelError | ||||||||||||||||||||||||||||||||
from pandasai.llm import AzureOpenAI | ||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||
pytest.skip(allow_module_level=True) | ||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||
class OpenAIObject: | ||||||||||||||||||||||||||||||||
def __init__(self, dictionary): | ||||||||||||||||||||||||||||||||
self.__dict__.update(dictionary) | ||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||
class TestAzureOpenAILLM: | ||||||||||||||||||||||||||||||||
"""Unit tests for the Azure Openai LLM class""" | ||||||||||||||||||||||||||||||||
|
@@ -20,28 +24,28 @@ def test_type_without_endpoint(self): | |||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||
def test_type_without_api_version(self): | ||||||||||||||||||||||||||||||||
with pytest.raises(APIKeyNotFoundError): | ||||||||||||||||||||||||||||||||
AzureOpenAI(api_token="test", api_base="test") | ||||||||||||||||||||||||||||||||
AzureOpenAI(api_token="test", azure_endpoint="test") | ||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||
def test_type_without_deployment(self): | ||||||||||||||||||||||||||||||||
with pytest.raises(MissingModelError): | ||||||||||||||||||||||||||||||||
AzureOpenAI(api_token="test", api_base="test", api_version="test") | ||||||||||||||||||||||||||||||||
AzureOpenAI(api_token="test", azure_endpoint="test", api_version="test") | ||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||
def test_type_with_token(self): | ||||||||||||||||||||||||||||||||
assert ( | ||||||||||||||||||||||||||||||||
AzureOpenAI( | ||||||||||||||||||||||||||||||||
api_token="test", | ||||||||||||||||||||||||||||||||
api_base="test", | ||||||||||||||||||||||||||||||||
api_version="test", | ||||||||||||||||||||||||||||||||
deployment_name="test", | ||||||||||||||||||||||||||||||||
).type | ||||||||||||||||||||||||||||||||
== "azure-openai" | ||||||||||||||||||||||||||||||||
AzureOpenAI( | ||||||||||||||||||||||||||||||||
api_token="test", | ||||||||||||||||||||||||||||||||
azure_endpoint="test", | ||||||||||||||||||||||||||||||||
api_version="test", | ||||||||||||||||||||||||||||||||
deployment_name="test", | ||||||||||||||||||||||||||||||||
).type | ||||||||||||||||||||||||||||||||
== "azure-openai" | ||||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||||
Comment on lines
25
to
42
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The tests are checking for the correct exceptions when required parameters are not provided. This is good practice as it ensures that the class constructor correctly handles missing arguments. However, the exception |
||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||
def test_proxy(self): | ||||||||||||||||||||||||||||||||
proxy = "http://proxy.mycompany.com:8080" | ||||||||||||||||||||||||||||||||
client = AzureOpenAI( | ||||||||||||||||||||||||||||||||
api_token="test", | ||||||||||||||||||||||||||||||||
api_base="test", | ||||||||||||||||||||||||||||||||
azure_endpoint="test", | ||||||||||||||||||||||||||||||||
api_version="test", | ||||||||||||||||||||||||||||||||
deployment_name="test", | ||||||||||||||||||||||||||||||||
openai_proxy=proxy, | ||||||||||||||||||||||||||||||||
|
@@ -53,7 +57,7 @@ def test_proxy(self): | |||||||||||||||||||||||||||||||
def test_params_setting(self): | ||||||||||||||||||||||||||||||||
llm = AzureOpenAI( | ||||||||||||||||||||||||||||||||
api_token="test", | ||||||||||||||||||||||||||||||||
api_base="test", | ||||||||||||||||||||||||||||||||
azure_endpoint="test", | ||||||||||||||||||||||||||||||||
api_version="test", | ||||||||||||||||||||||||||||||||
deployment_name="Deployed-GPT-3", | ||||||||||||||||||||||||||||||||
is_chat_model=True, | ||||||||||||||||||||||||||||||||
Comment on lines
57
to
63
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The test is setting up an - is_chat_model=True,
+ _is_chat_model=True, Commitable suggestion
Suggested change
|
||||||||||||||||||||||||||||||||
|
@@ -65,8 +69,8 @@ def test_params_setting(self): | |||||||||||||||||||||||||||||||
stop=["\n"], | ||||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||
assert llm.engine == "Deployed-GPT-3" | ||||||||||||||||||||||||||||||||
assert llm.is_chat_model | ||||||||||||||||||||||||||||||||
assert llm.deployment_name == "Deployed-GPT-3" | ||||||||||||||||||||||||||||||||
assert llm._is_chat_model | ||||||||||||||||||||||||||||||||
assert llm.temperature == 0.5 | ||||||||||||||||||||||||||||||||
assert llm.max_tokens == 50 | ||||||||||||||||||||||||||||||||
assert llm.top_p == 1.0 | ||||||||||||||||||||||||||||||||
Comment on lines
69
to
76
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The assertions here are checking that the properties of the
|
||||||||||||||||||||||||||||||||
|
@@ -75,9 +79,8 @@ def test_params_setting(self): | |||||||||||||||||||||||||||||||
assert llm.stop == ["\n"] | ||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||
def test_completion(self, mocker): | ||||||||||||||||||||||||||||||||
openai_mock = mocker.patch("openai.Completion.create") | ||||||||||||||||||||||||||||||||
expected_text = "This is the generated text." | ||||||||||||||||||||||||||||||||
openai_mock.return_value = OpenAIObject.construct_from( | ||||||||||||||||||||||||||||||||
expected_response = OpenAIObject( | ||||||||||||||||||||||||||||||||
{ | ||||||||||||||||||||||||||||||||
"choices": [{"text": expected_text}], | ||||||||||||||||||||||||||||||||
"usage": { | ||||||||||||||||||||||||||||||||
|
@@ -91,34 +94,25 @@ def test_completion(self, mocker): | |||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||
openai = AzureOpenAI( | ||||||||||||||||||||||||||||||||
api_token="test", | ||||||||||||||||||||||||||||||||
api_base="test", | ||||||||||||||||||||||||||||||||
azure_endpoint="test", | ||||||||||||||||||||||||||||||||
api_version="test", | ||||||||||||||||||||||||||||||||
deployment_name="test", | ||||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||||
mocker.patch.object(openai, "completion", return_value=expected_response) | ||||||||||||||||||||||||||||||||
result = openai.completion("Some prompt.") | ||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||
openai_mock.assert_called_once_with( | ||||||||||||||||||||||||||||||||
engine=openai.engine, | ||||||||||||||||||||||||||||||||
prompt="Some prompt.", | ||||||||||||||||||||||||||||||||
temperature=openai.temperature, | ||||||||||||||||||||||||||||||||
max_tokens=openai.max_tokens, | ||||||||||||||||||||||||||||||||
top_p=openai.top_p, | ||||||||||||||||||||||||||||||||
frequency_penalty=openai.frequency_penalty, | ||||||||||||||||||||||||||||||||
presence_penalty=openai.presence_penalty, | ||||||||||||||||||||||||||||||||
seed=openai.seed | ||||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||
assert result == expected_text | ||||||||||||||||||||||||||||||||
openai.completion.assert_called_once_with("Some prompt.") | ||||||||||||||||||||||||||||||||
assert result == expected_response | ||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||
def test_chat_completion(self, mocker): | ||||||||||||||||||||||||||||||||
openai = AzureOpenAI( | ||||||||||||||||||||||||||||||||
api_token="test", | ||||||||||||||||||||||||||||||||
api_base="test", | ||||||||||||||||||||||||||||||||
azure_endpoint="test", | ||||||||||||||||||||||||||||||||
api_version="test", | ||||||||||||||||||||||||||||||||
deployment_name="test", | ||||||||||||||||||||||||||||||||
is_chat_model=True, | ||||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||||
expected_response = OpenAIObject.construct_from( | ||||||||||||||||||||||||||||||||
expected_response = OpenAIObject( | ||||||||||||||||||||||||||||||||
{ | ||||||||||||||||||||||||||||||||
"choices": [ | ||||||||||||||||||||||||||||||||
{ | ||||||||||||||||||||||||||||||||
|
@@ -135,4 +129,5 @@ def test_chat_completion(self, mocker): | |||||||||||||||||||||||||||||||
mocker.patch.object(openai, "chat_completion", return_value=expected_response) | ||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||
result = openai.chat_completion("Hi") | ||||||||||||||||||||||||||||||||
assert result == expected_response | ||||||||||||||||||||||||||||||||
openai.chat_completion.assert_called_once_with("Hi") | ||||||||||||||||||||||||||||||||
assert result == expected_response |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The class attributes are well-documented. However, the
azure_ad_token_provider
should be a callable that returns a token, not a string. This needs to be corrected to reflect the proper type.Commitable suggestion