From 372f65df42a4b13f70fe4b5aabbbe02742f926fa Mon Sep 17 00:00:00 2001 From: Massimiliano Pronesti Date: Wed, 15 Nov 2023 11:37:25 +0100 Subject: [PATCH] feat(llms): support OpenAI v1 for Azure OpenAI (#755) * feat(llms): support OpenAI v1 for Azure OpenAI * 'Refactored by Sourcery' --------- Co-authored-by: Sourcery AI <> --- pandasai/llm/azure_openai.py | 126 +++++++++++++++++++++----------- tests/llms/test_azure_openai.py | 59 +++++++-------- 2 files changed, 112 insertions(+), 73 deletions(-) diff --git a/pandasai/llm/azure_openai.py b/pandasai/llm/azure_openai.py index d854ffe58..03b540e1a 100644 --- a/pandasai/llm/azure_openai.py +++ b/pandasai/llm/azure_openai.py @@ -12,13 +12,13 @@ """ import os -from typing import Any, Dict, Optional +from typing import Any, Dict, Optional, Union import openai from ..helpers import load_dotenv from ..exceptions import APIKeyNotFoundError, MissingModelError -from ..prompts.base import AbstractPrompt +from ..helpers.openai import is_openai_v1 from .base import BaseOpenAI load_dotenv() @@ -29,45 +29,78 @@ class AzureOpenAI(BaseOpenAI): This class uses `BaseOpenAI` class to support Azure OpenAI features. """ + azure_endpoint: Union[str, None] = None + """Your Azure Active Directory token. + Automatically inferred from env var `AZURE_OPENAI_AD_TOKEN` if not provided. + For more: + https://www.microsoft.com/en-us/security/business/identity-access/microsoft-entra-id. + """ + azure_ad_token: Union[str, None] = None + """A function that returns an Azure Active Directory token. + Will be invoked on every request. 
+ """ + azure_ad_token_provider: Union[str, None] = None + deployment_name: str + api_version: str = "" + """Legacy, for openai<1.0.0 support.""" api_base: str + """Legacy, for openai<1.0.0 support.""" api_type: str = "azure" - api_version: str - engine: str def __init__( - self, - api_token: Optional[str] = None, - api_base: Optional[str] = None, - api_version: Optional[str] = None, - deployment_name: str = None, - is_chat_model: bool = True, - **kwargs, + self, + api_token: Optional[str] = None, + azure_endpoint: Union[str, None] = None, + azure_ad_token: Union[str, None] = None, + azure_ad_token_provider: Union[str, None] = None, + api_base: Optional[str] = None, + api_version: Optional[str] = None, + deployment_name: str = None, + is_chat_model: bool = True, + **kwargs, ): """ __init__ method of AzureOpenAI Class. Args: api_token (str): Azure OpenAI API token. - api_base (str): Base url of the Azure endpoint. + azure_endpoint (str): Azure endpoint. It should look like the following: `https://YOUR_RESOURCE_NAME.openai.azure.com/` + azure_ad_token (str): Your Azure Active Directory token. + Automatically inferred from env var `AZURE_OPENAI_AD_TOKEN` if not provided. + For more: https://www.microsoft.com/en-us/security/business/identity-access/microsoft-entra-id. + azure_ad_token_provider (callable): A function that returns an Azure Active Directory token. + Will be invoked on every request. api_version (str): Version of the Azure OpenAI API. Be aware the API version may change. + api_base (str): Legacy, kept for backward compatibility with openai < 1.0 deployment_name (str): Custom name of the deployed model is_chat_model (bool): Whether ``deployment_name`` corresponds to a Chat or a Completion model. **kwargs: Inference Parameters. 
""" - self.api_token = api_token or os.getenv("OPENAI_API_KEY") or None - self.api_base = api_base or os.getenv("OPENAI_API_BASE") or None + self.api_token = ( + api_token + or os.getenv("OPENAI_API_KEY") + or os.getenv("AZURE_OPENAI_API_KEY") + ) + self.azure_endpoint = azure_endpoint or os.getenv("AZURE_OPENAI_ENDPOINT") + self.api_base = api_base or os.getenv("OPENAI_API_BASE") self.api_version = api_version or os.getenv("OPENAI_API_VERSION") if self.api_token is None: raise APIKeyNotFoundError( "Azure OpenAI key is required. Please add an environment variable " "`OPENAI_API_KEY` or pass `api_token` as a named parameter" ) - if self.api_base is None: + if is_openai_v1(): + if self.azure_endpoint is None: + raise APIKeyNotFoundError( + "Azure endpoint is required. Please add an environment variable " + "`AZURE_OPENAI_ENDPOINT` or pass `azure_endpoint` as a named parameter" + ) + elif self.api_base is None: raise APIKeyNotFoundError( "Azure OpenAI base is required. Please add an environment variable " "`OPENAI_API_BASE` or pass `api_base` as a named parameter" @@ -77,25 +110,33 @@ def __init__( "Azure OpenAI version is required. 
Please add an environment variable " "`OPENAI_API_VERSION` or pass `api_version` as a named parameter" ) - openai.api_key = self.api_token - openai.api_base = self.api_base - openai.api_version = self.api_version - openai.api_type = self.api_type if deployment_name is None: raise MissingModelError( "No deployment name provided.", "Please include deployment name from Azure dashboard.", ) - - self.is_chat_model = is_chat_model - self.engine = deployment_name + self.azure_ad_token = azure_ad_token or os.getenv("AZURE_OPENAI_AD_TOKEN") + self.azure_ad_token_provider = azure_ad_token_provider + self._is_chat_model = is_chat_model + self.deployment_name = deployment_name self.openai_proxy = kwargs.get("openai_proxy") or os.getenv("OPENAI_PROXY") if self.openai_proxy: openai.proxy = {"http": self.openai_proxy, "https": self.openai_proxy} self._set_params(**kwargs) + # set the client + if self._is_chat_model: + self.client = ( + openai.AzureOpenAI(**self._client_params).chat.completions + if is_openai_v1() + else openai.ChatCompletion + ) + elif is_openai_v1(): + self.client = openai.AzureOpenAI(**self._client_params).completions + else: + self.client = openai.Completion @property def _default_params(self) -> Dict[str, Any]: @@ -106,27 +147,30 @@ def _default_params(self) -> Dict[str, Any]: dict: A dictionary containing Default Params. """ - return {**super()._default_params, "engine": self.engine} - - def call(self, instruction: AbstractPrompt, suffix: str = "") -> str: - """ - Call the Azure OpenAI LLM. + return {**super()._default_params, "model" if is_openai_v1() else "engine": self.deployment_name} - Args: - instruction (AbstractPrompt): A prompt object with instruction for LLM. - suffix (str): Suffix to pass. - - Returns: - str: LLM response. 
- - """ - self.last_prompt = instruction.to_string() + suffix + @property + def _invocation_params(self) -> Dict[str, Any]: + """Get the parameters used to invoke the model.""" + if is_openai_v1(): + return super()._invocation_params + else: + return { + **super()._invocation_params, + "api_type": self.api_type, + "api_version": self.api_version, + } - return ( - self.chat_completion(self.last_prompt) - if self.is_chat_model - else self.completion(self.last_prompt) - ) + @property + def _client_params(self) -> Dict[str, any]: + client_params = { + "api_version": self.api_version, + "azure_endpoint": self.azure_endpoint, + "azure_deployment": self.deployment_name, + "azure_ad_token": self.azure_ad_token, + "azure_ad_token_provider": self.azure_ad_token_provider, + } + return {**client_params, **super()._client_params} @property def type(self) -> str: diff --git a/tests/llms/test_azure_openai.py b/tests/llms/test_azure_openai.py index 1b0c7f021..65058ae06 100644 --- a/tests/llms/test_azure_openai.py +++ b/tests/llms/test_azure_openai.py @@ -5,7 +5,11 @@ from pandasai.exceptions import APIKeyNotFoundError, MissingModelError from pandasai.llm import AzureOpenAI -pytest.skip(allow_module_level=True) + +class OpenAIObject: + def __init__(self, dictionary): + self.__dict__.update(dictionary) + class TestAzureOpenAILLM: """Unit tests for the Azure Openai LLM class""" @@ -20,28 +24,28 @@ def test_type_without_endpoint(self): def test_type_without_api_version(self): with pytest.raises(APIKeyNotFoundError): - AzureOpenAI(api_token="test", api_base="test") + AzureOpenAI(api_token="test", azure_endpoint="test") def test_type_without_deployment(self): with pytest.raises(MissingModelError): - AzureOpenAI(api_token="test", api_base="test", api_version="test") + AzureOpenAI(api_token="test", azure_endpoint="test", api_version="test") def test_type_with_token(self): assert ( - AzureOpenAI( - api_token="test", - api_base="test", - api_version="test", - deployment_name="test", - 
).type - == "azure-openai" + AzureOpenAI( + api_token="test", + azure_endpoint="test", + api_version="test", + deployment_name="test", + ).type + == "azure-openai" ) def test_proxy(self): proxy = "http://proxy.mycompany.com:8080" client = AzureOpenAI( api_token="test", - api_base="test", + azure_endpoint="test", api_version="test", deployment_name="test", openai_proxy=proxy, @@ -53,7 +57,7 @@ def test_proxy(self): def test_params_setting(self): llm = AzureOpenAI( api_token="test", - api_base="test", + azure_endpoint="test", api_version="test", deployment_name="Deployed-GPT-3", is_chat_model=True, @@ -65,8 +69,8 @@ def test_params_setting(self): stop=["\n"], ) - assert llm.engine == "Deployed-GPT-3" - assert llm.is_chat_model + assert llm.deployment_name == "Deployed-GPT-3" + assert llm._is_chat_model assert llm.temperature == 0.5 assert llm.max_tokens == 50 assert llm.top_p == 1.0 @@ -75,9 +79,8 @@ def test_params_setting(self): assert llm.stop == ["\n"] def test_completion(self, mocker): - openai_mock = mocker.patch("openai.Completion.create") expected_text = "This is the generated text." 
- openai_mock.return_value = OpenAIObject.construct_from( + expected_response = OpenAIObject( { "choices": [{"text": expected_text}], "usage": { @@ -91,34 +94,25 @@ def test_completion(self, mocker): openai = AzureOpenAI( api_token="test", - api_base="test", + azure_endpoint="test", api_version="test", deployment_name="test", ) + mocker.patch.object(openai, "completion", return_value=expected_response) result = openai.completion("Some prompt.") - openai_mock.assert_called_once_with( - engine=openai.engine, - prompt="Some prompt.", - temperature=openai.temperature, - max_tokens=openai.max_tokens, - top_p=openai.top_p, - frequency_penalty=openai.frequency_penalty, - presence_penalty=openai.presence_penalty, - seed=openai.seed - ) - - assert result == expected_text + openai.completion.assert_called_once_with("Some prompt.") + assert result == expected_response def test_chat_completion(self, mocker): openai = AzureOpenAI( api_token="test", - api_base="test", + azure_endpoint="test", api_version="test", deployment_name="test", is_chat_model=True, ) - expected_response = OpenAIObject.construct_from( + expected_response = OpenAIObject( { "choices": [ { @@ -135,4 +129,5 @@ def test_chat_completion(self, mocker): mocker.patch.object(openai, "chat_completion", return_value=expected_response) result = openai.chat_completion("Hi") - assert result == expected_response \ No newline at end of file + openai.chat_completion.assert_called_once_with("Hi") + assert result == expected_response