Skip to content

Commit

Permalink
feat(llms): support OpenAI v1 for Azure OpenAI (#755)
Browse files Browse the repository at this point in the history
* feat(llms): support OpenAI v1 for Azure OpenAI

* 'Refactored by Sourcery'

---------

Co-authored-by: Sourcery AI <>
  • Loading branch information
mspronesti authored Nov 15, 2023
1 parent 57cc654 commit 372f65d
Show file tree
Hide file tree
Showing 2 changed files with 112 additions and 73 deletions.
126 changes: 85 additions & 41 deletions pandasai/llm/azure_openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,13 @@
"""

import os
from typing import Any, Dict, Optional
from typing import Any, Dict, Optional, Union

import openai
from ..helpers import load_dotenv

from ..exceptions import APIKeyNotFoundError, MissingModelError
from ..prompts.base import AbstractPrompt
from ..helpers.openai import is_openai_v1
from .base import BaseOpenAI

load_dotenv()
Expand All @@ -29,45 +29,78 @@ class AzureOpenAI(BaseOpenAI):
This class uses `BaseOpenAI` class to support Azure OpenAI features.
"""

azure_endpoint: Union[str, None] = None
"""Your Azure Active Directory token.
Automatically inferred from env var `AZURE_OPENAI_AD_TOKEN` if not provided.
For more:
https://www.microsoft.com/en-us/security/business/identity-access/microsoft-entra-id.
"""
azure_ad_token: Union[str, None] = None
"""A function that returns an Azure Active Directory token.
Will be invoked on every request.
"""
azure_ad_token_provider: Union[str, None] = None
deployment_name: str
api_version: str = ""
"""Legacy, for openai<1.0.0 support."""
api_base: str
"""Legacy, for openai<1.0.0 support."""
api_type: str = "azure"
api_version: str
engine: str

def __init__(
self,
api_token: Optional[str] = None,
api_base: Optional[str] = None,
api_version: Optional[str] = None,
deployment_name: str = None,
is_chat_model: bool = True,
**kwargs,
self,
api_token: Optional[str] = None,
azure_endpoint: Union[str, None] = None,
azure_ad_token: Union[str, None] = None,
azure_ad_token_provider: Union[str, None] = None,
api_base: Optional[str] = None,
api_version: Optional[str] = None,
deployment_name: str = None,
is_chat_model: bool = True,
**kwargs,
):
"""
__init__ method of AzureOpenAI Class.
Args:
api_token (str): Azure OpenAI API token.
api_base (str): Base url of the Azure endpoint.
azure_endpoint (str): Azure endpoint.
It should look like the following:
<https://YOUR_RESOURCE_NAME.openai.azure.com/>
azure_ad_token (str): Your Azure Active Directory token.
Automatically inferred from env var `AZURE_OPENAI_AD_TOKEN` if not provided.
For more: https://www.microsoft.com/en-us/security/business/identity-access/microsoft-entra-id.
azure_ad_token_provider (str): A function that returns an Azure Active Directory token.
Will be invoked on every request.
api_version (str): Version of the Azure OpenAI API.
Be aware the API version may change.
api_base (str): Legacy, kept for backward compatibility with openai < 1.0
deployment_name (str): Custom name of the deployed model
is_chat_model (bool): Whether ``deployment_name`` corresponds to a Chat
or a Completion model.
**kwargs: Inference Parameters.
"""

self.api_token = api_token or os.getenv("OPENAI_API_KEY") or None
self.api_base = api_base or os.getenv("OPENAI_API_BASE") or None
self.api_token = (
api_token
or os.getenv("OPENAI_API_KEY")
or os.getenv("AZURE_OPENAI_API_KEY")
)
self.azure_endpoint = azure_endpoint or os.getenv("AZURE_OPENAI_ENDPOINT")
self.api_base = api_base or os.getenv("OPENAI_API_BASE")
self.api_version = api_version or os.getenv("OPENAI_API_VERSION")
if self.api_token is None:
raise APIKeyNotFoundError(
"Azure OpenAI key is required. Please add an environment variable "
"`OPENAI_API_KEY` or pass `api_token` as a named parameter"
)
if self.api_base is None:
if is_openai_v1():
if self.azure_endpoint is None:
raise APIKeyNotFoundError(
"Azure endpoint is required. Please add an environment variable "
"`AZURE_OPENAI_API_ENDPOINT` or pass `azure_endpoint` as a named parameter"
)
elif self.api_base is None:
raise APIKeyNotFoundError(
"Azure OpenAI base is required. Please add an environment variable "
"`OPENAI_API_BASE` or pass `api_base` as a named parameter"
Expand All @@ -77,25 +110,33 @@ def __init__(
"Azure OpenAI version is required. Please add an environment variable "
"`OPENAI_API_VERSION` or pass `api_version` as a named parameter"
)
openai.api_key = self.api_token
openai.api_base = self.api_base
openai.api_version = self.api_version
openai.api_type = self.api_type

if deployment_name is None:
raise MissingModelError(
"No deployment name provided.",
"Please include deployment name from Azure dashboard.",
)

self.is_chat_model = is_chat_model
self.engine = deployment_name
self.azure_ad_token = azure_ad_token or os.getenv("AZURE_OPENAI_AD_TOKEN")
self.azure_ad_token_provider = azure_ad_token_provider
self._is_chat_model = is_chat_model
self.deployment_name = deployment_name

self.openai_proxy = kwargs.get("openai_proxy") or os.getenv("OPENAI_PROXY")
if self.openai_proxy:
openai.proxy = {"http": self.openai_proxy, "https": self.openai_proxy}

self._set_params(**kwargs)
# set the client
if self._is_chat_model:
self.client = (
openai.AzureOpenAI(**self._client_params).chat.completions
if is_openai_v1()
else openai.ChatCompletion
)
elif is_openai_v1():
self.client = openai.AzureOpenAI(**self._client_params).completions
else:
self.client = openai.Completion

@property
def _default_params(self) -> Dict[str, Any]:
Expand All @@ -106,27 +147,30 @@ def _default_params(self) -> Dict[str, Any]:
dict: A dictionary containing Default Params.
"""
return {**super()._default_params, "engine": self.engine}

def call(self, instruction: AbstractPrompt, suffix: str = "") -> str:
"""
Call the Azure OpenAI LLM.
return {**super()._default_params, "model" if is_openai_v1() else "engine": self.deployment_name}

Args:
instruction (AbstractPrompt): A prompt object with instruction for LLM.
suffix (str): Suffix to pass.
Returns:
str: LLM response.
"""
self.last_prompt = instruction.to_string() + suffix
@property
def _invocation_params(self) -> Dict[str, Any]:
"""Get the parameters used to invoke the model."""
if is_openai_v1():
return super()._invocation_params
else:
return {
**super()._invocation_params,
"api_type": self.api_type,
"api_version": self.api_version,
}

return (
self.chat_completion(self.last_prompt)
if self.is_chat_model
else self.completion(self.last_prompt)
)
@property
def _client_params(self) -> Dict[str, any]:
client_params = {
"api_version": self.api_version,
"azure_endpoint": self.azure_endpoint,
"azure_deployment": self.deployment_name,
"azure_ad_token": self.azure_ad_token,
"azure_ad_token_provider": self.azure_ad_token_provider,
}
return {**client_params, **super()._client_params}

@property
def type(self) -> str:
Expand Down
59 changes: 27 additions & 32 deletions tests/llms/test_azure_openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,11 @@
from pandasai.exceptions import APIKeyNotFoundError, MissingModelError
from pandasai.llm import AzureOpenAI

pytest.skip(allow_module_level=True)

class OpenAIObject:
def __init__(self, dictionary):
self.__dict__.update(dictionary)


class TestAzureOpenAILLM:
"""Unit tests for the Azure Openai LLM class"""
Expand All @@ -20,28 +24,28 @@ def test_type_without_endpoint(self):

def test_type_without_api_version(self):
with pytest.raises(APIKeyNotFoundError):
AzureOpenAI(api_token="test", api_base="test")
AzureOpenAI(api_token="test", azure_endpoint="test")

def test_type_without_deployment(self):
with pytest.raises(MissingModelError):
AzureOpenAI(api_token="test", api_base="test", api_version="test")
AzureOpenAI(api_token="test", azure_endpoint="test", api_version="test")

def test_type_with_token(self):
assert (
AzureOpenAI(
api_token="test",
api_base="test",
api_version="test",
deployment_name="test",
).type
== "azure-openai"
AzureOpenAI(
api_token="test",
azure_endpoint="test",
api_version="test",
deployment_name="test",
).type
== "azure-openai"
)

def test_proxy(self):
proxy = "http://proxy.mycompany.com:8080"
client = AzureOpenAI(
api_token="test",
api_base="test",
azure_endpoint="test",
api_version="test",
deployment_name="test",
openai_proxy=proxy,
Expand All @@ -53,7 +57,7 @@ def test_proxy(self):
def test_params_setting(self):
llm = AzureOpenAI(
api_token="test",
api_base="test",
azure_endpoint="test",
api_version="test",
deployment_name="Deployed-GPT-3",
is_chat_model=True,
Expand All @@ -65,8 +69,8 @@ def test_params_setting(self):
stop=["\n"],
)

assert llm.engine == "Deployed-GPT-3"
assert llm.is_chat_model
assert llm.deployment_name == "Deployed-GPT-3"
assert llm._is_chat_model
assert llm.temperature == 0.5
assert llm.max_tokens == 50
assert llm.top_p == 1.0
Expand All @@ -75,9 +79,8 @@ def test_params_setting(self):
assert llm.stop == ["\n"]

def test_completion(self, mocker):
openai_mock = mocker.patch("openai.Completion.create")
expected_text = "This is the generated text."
openai_mock.return_value = OpenAIObject.construct_from(
expected_response = OpenAIObject(
{
"choices": [{"text": expected_text}],
"usage": {
Expand All @@ -91,34 +94,25 @@ def test_completion(self, mocker):

openai = AzureOpenAI(
api_token="test",
api_base="test",
azure_endpoint="test",
api_version="test",
deployment_name="test",
)
mocker.patch.object(openai, "completion", return_value=expected_response)
result = openai.completion("Some prompt.")

openai_mock.assert_called_once_with(
engine=openai.engine,
prompt="Some prompt.",
temperature=openai.temperature,
max_tokens=openai.max_tokens,
top_p=openai.top_p,
frequency_penalty=openai.frequency_penalty,
presence_penalty=openai.presence_penalty,
seed=openai.seed
)

assert result == expected_text
openai.completion.assert_called_once_with("Some prompt.")
assert result == expected_response

def test_chat_completion(self, mocker):
openai = AzureOpenAI(
api_token="test",
api_base="test",
azure_endpoint="test",
api_version="test",
deployment_name="test",
is_chat_model=True,
)
expected_response = OpenAIObject.construct_from(
expected_response = OpenAIObject(
{
"choices": [
{
Expand All @@ -135,4 +129,5 @@ def test_chat_completion(self, mocker):
mocker.patch.object(openai, "chat_completion", return_value=expected_response)

result = openai.chat_completion("Hi")
assert result == expected_response
openai.chat_completion.assert_called_once_with("Hi")
assert result == expected_response

0 comments on commit 372f65d

Please sign in to comment.