Commit 68163d5

feat(llms): support OpenAI API v1 (#753)
* feat(llms): support OpenAI API v1

* fmt

* temporarily skip azure openai tests

* 'Refactored by Sourcery'

---------

Co-authored-by: Sourcery AI <>
mspronesti authored Nov 15, 2023
1 parent 6de6b39 commit 68163d5
Showing 9 changed files with 518 additions and 165 deletions.
10 changes: 10 additions & 0 deletions pandasai/helpers/openai.py
@@ -0,0 +1,10 @@
+from __future__ import annotations
+
+from importlib.metadata import version
+
+from packaging.version import parse
+
+
+def is_openai_v1() -> bool:
+    _version = parse(version("openai"))
+    return _version.major >= 1
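As a quick illustration of the gate above, `packaging.version.parse` returns a `Version` object whose `major` attribute drives the check; a minimal sketch (the version strings are examples):

```python
from packaging.version import parse

# Pre-v1 SDKs ("0.x") fail the gate; v1+ SDKs pass it.
assert parse("0.28.1").major == 0   # is_openai_v1() -> False
assert parse("1.3.5").major >= 1    # is_openai_v1() -> True
```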
5 changes: 2 additions & 3 deletions pandasai/helpers/openai_info.py
@@ -2,7 +2,6 @@
 from contextvars import ContextVar
 from typing import Optional, Generator
 
-from openai.openai_object import OpenAIObject
 
 MODEL_COST_PER_1K_TOKENS = {
     # GPT-4 input
@@ -132,10 +131,10 @@ def __repr__(self) -> str:
f"Total Cost (USD): ${self.total_cost:9.6f}"
)

def __call__(self, response: OpenAIObject) -> None:
def __call__(self, response) -> None:
"""Collect token usage"""
usage = response.usage
if "total_tokens" not in usage:
if not hasattr(usage, "total_tokens"):
return None

model_name = standardize_model_name(response.model)
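The `hasattr` check replaces the membership test because openai>=1.0 returns typed pydantic models rather than the dict-like `OpenAIObject` that was just dropped from the imports. A minimal sketch with a stand-in object (not a real API response):

```python
from types import SimpleNamespace

# Stand-in for openai>=1.0's CompletionUsage model.
usage = SimpleNamespace(prompt_tokens=12, completion_tokens=30, total_tokens=42)

assert hasattr(usage, "total_tokens")   # attribute access works on the model
# "total_tokens" in usage              # the old dict-style test would raise
                                       # TypeError on a non-iterable object
```

Since v0's `OpenAIObject` also supported attribute access, `hasattr` stays compatible with both SDK generations.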
118 changes: 90 additions & 28 deletions pandasai/llm/base.py
@@ -17,10 +17,9 @@ class CustomLLM(BaseOpenAI):
 import os
 import ast
 import re
-from abc import ABC, abstractmethod
-from typing import Any, Dict, Optional
+from abc import abstractmethod
+from typing import Any, Dict, Optional, Union, Mapping, Tuple
 
-import openai
 import requests
 
 from ..exceptions import (
@@ -29,6 +28,7 @@ class CustomLLM(BaseOpenAI):
     NoCodeFoundError,
     LLMResponseHTTPError,
 )
+from ..helpers.openai import is_openai_v1
 from ..helpers.openai_info import openai_callback_var
 from ..prompts.base import AbstractPrompt

@@ -133,9 +133,9 @@ def _extract_tag_text(self, response: str, tag: str) -> str:
"""

if match := re.search(
f"(<{tag}>)(.*)(</{tag}>)",
response,
re.DOTALL | re.MULTILINE,
f"(<{tag}>)(.*)(</{tag}>)",
response,
re.DOTALL | re.MULTILINE,
):
return match[2]
return None
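For reference, the pattern built here captures everything between a pair of XML-style tags, with `re.DOTALL` letting `.*` span newlines; a small sketch (tag and response are examples):

```python
import re

tag = "code"
response = "reply:\n<code>\nprint('hi')\n</code>\ndone"
if match := re.search(f"(<{tag}>)(.*)(</{tag}>)", response, re.DOTALL | re.MULTILINE):
    print(match[2])   # group 2 is the text between the tags: "\nprint('hi')\n"
```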
@@ -207,30 +207,42 @@ def generate_code(self, instruction: AbstractPrompt) -> [str, str, str]:
         ]
 
 
-class BaseOpenAI(LLM, ABC):
+class BaseOpenAI(LLM):
     """Base class to implement a new OpenAI LLM.
     LLM base class, this class is extended to be used with OpenAI API.
     """
 
     api_token: str
     api_base: str
     temperature: float = 0
     max_tokens: int = 1000
     top_p: float = 1
     frequency_penalty: float = 0
     presence_penalty: float = 0.6
     best_of: int = 1
     n: int = 1
     stop: Optional[str] = None
+    request_timeout: Union[float, Tuple[float, float], Any, None] = None
+    max_retries: int = 2
     seed: Optional[int] = None
     # support explicit proxy for OpenAI
     openai_proxy: Optional[str] = None
+    default_headers: Union[Mapping[str, str], None] = None
+    default_query: Union[Mapping[str, object], None] = None
+    # Configure a custom httpx client. See the
+    # [httpx documentation](https://www.python-httpx.org/api/#client) for more details.
+    http_client: Union[Any, None] = None
+    client: Any
+    _is_chat_model: bool
 
     def _set_params(self, **kwargs):
         """
         Set Parameters
         Args:
-            **kwargs: ["model", "engine", "deployment_id", "temperature","max_tokens",
-                "top_p", "frequency_penalty", "presence_penalty", "stop", "seed", ]
+            **kwargs: ["model", "deployment_name", "temperature","max_tokens",
+                "top_p", "frequency_penalty", "presence_penalty", "stop", "seed"]
         Returns:
             None.
@@ -239,8 +251,7 @@ def _set_params(self, **kwargs):

         valid_params = [
             "model",
-            "engine",
-            "deployment_id",
+            "deployment_name",
             "temperature",
             "max_tokens",
             "top_p",

     @property
     def _default_params(self) -> Dict[str, Any]:
-        """
-        Get the default parameters for calling OpenAI API
-        Returns
-            Dict: A dict of OpenAi API parameters.
-        """
-
-        return {
+        """Get the default parameters for calling OpenAI API."""
+        params: Dict[str, Any] = {
             "temperature": self.temperature,
-            "max_tokens": self.max_tokens,
             "top_p": self.top_p,
             "frequency_penalty": self.frequency_penalty,
             "presence_penalty": self.presence_penalty,
             "seed": self.seed,
             "stop": self.stop,
+            "n": self.n,
         }
+
+        if self.max_tokens is not None:
+            params["max_tokens"] = self.max_tokens
+
+        # Azure gpt-35-turbo doesn't support best_of
+        # don't specify best_of if it is 1
+        if self.best_of > 1:
+            params["best_of"] = self.best_of
+
+        return params
+
+    @property
+    def _invocation_params(self) -> Dict[str, Any]:
+        """Get the parameters used to invoke the model."""
+        openai_creds: Dict[str, Any] = {}
+        if not is_openai_v1():
+            openai_creds.update(
+                {
+                    "api_key": self.api_token,
+                    "api_base": self.api_base,
+                }
+            )
+
+        return {**openai_creds, **self._default_params}
+
+    @property
+    def _client_params(self) -> Dict[str, any]:
+        return {
+            "api_key": self.api_token,
+            "timeout": self.request_timeout,
+            "max_retries": self.max_retries,
+            "default_headers": self.default_headers,
+            "default_query": self.default_query,
+            "http_client": self.http_client,
+        }
 
     def completion(self, prompt: str) -> str:
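The net effect is a three-way split: `_default_params` carries sampling options for every request, `_invocation_params` adds credentials only on pre-v1 SDKs (under v1 the client object owns them), and `_client_params` feeds the v1 client constructor. A hedged sketch of how the pieces are consumed (model and prompt are illustrative):

```python
import openai
from pandasai.helpers.openai import is_openai_v1

# a sketch, assuming `llm` is a constructed BaseOpenAI subclass
if is_openai_v1():
    # connection-level settings are bound once, to the client object
    client = openai.OpenAI(**llm._client_params).completions
else:
    # pre-v1: api_key/api_base travel inside _invocation_params instead
    client = openai.Completion

# request-level settings are merged into every call
response = client.create(model="gpt-3.5-turbo-instruct",
                         prompt="2 + 2 =", **llm._invocation_params)
```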
@@ -283,17 +323,17 @@ def completion(self, prompt: str) -> str:
             str: LLM response.
         """
-        params = {**self._default_params, "prompt": prompt}
+        params = {**self._invocation_params, "prompt": prompt}
 
         if self.stop is not None:
             params["stop"] = [self.stop]
 
-        response = openai.Completion.create(**params)
+        response = self.client.create(**params)
 
         if openai_handler := openai_callback_var.get():
             openai_handler(response)
 
-        return response["choices"][0]["text"]
+        return response.choices[0].text
 
     def chat_completion(self, value: str) -> str:
         """
@@ -307,7 +347,7 @@ def chat_completion(self, value: str) -> str:
"""
params = {
**self._default_params,
**self._invocation_params,
"messages": [
{
"role": "system",
@@ -319,12 +359,34 @@ def chat_completion(self, value: str) -> str:
         if self.stop is not None:
             params["stop"] = [self.stop]
 
-        response = openai.ChatCompletion.create(**params)
+        response = self.client.create(**params)
 
         if openai_handler := openai_callback_var.get():
             openai_handler(response)
 
-        return response["choices"][0]["message"]["content"]
+        return response.choices[0].message.content
+
+    def call(self, instruction: AbstractPrompt, suffix: str = ""):
+        """
+        Call the OpenAI LLM.
+        Args:
+            instruction (AbstractPrompt): A prompt object with instruction for LLM.
+            suffix (str): Suffix to pass.
+        Raises:
+            UnsupportedModelError: Unsupported model
+        Returns:
+            str: Response
+        """
+        self.last_prompt = instruction.to_string() + suffix
+
+        return (
+            self.chat_completion(self.last_prompt)
+            if self._is_chat_model
+            else self.completion(self.last_prompt)
+        )
 
 
 class HuggingFaceLLM(LLM):
@@ -352,7 +414,7 @@ def _setup(self, **kwargs):
"""
self.api_token = (
kwargs.get("api_token") or os.getenv("HUGGINGFACE_API_KEY") or None
kwargs.get("api_token") or os.getenv("HUGGINGFACE_API_KEY") or None
)
if self.api_token is None:
raise APIKeyNotFoundError("HuggingFace Hub API key is required")
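With `call()` hoisted into `BaseOpenAI`, chat-vs-completion routing is decided once, from the `_is_chat_model` flag the subclass sets at construction time, instead of re-deriving the model family on every call. A hedged usage sketch (token value illustrative):

```python
from pandasai.llm import OpenAI

llm = OpenAI(api_token="sk-...", model="gpt-4")
assert llm._is_chat_model   # gpt-4 is in _supported_chat_models
# llm.call(prompt) therefore routes through chat_completion();
# a completion-family model would route through completion() instead.
```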
61 changes: 22 additions & 39 deletions pandasai/llm/openai.py
@@ -15,7 +15,7 @@
 from ..helpers import load_dotenv
 
 from ..exceptions import APIKeyNotFoundError, UnsupportedModelError
-from ..prompts.base import AbstractPrompt
+from ..helpers.openai import is_openai_v1
 from .base import BaseOpenAI
 
 load_dotenv()
@@ -50,10 +50,9 @@ class OpenAI(BaseOpenAI):
model: str = "gpt-3.5-turbo"

def __init__(
self,
api_token: Optional[str] = None,
api_key_path: Optional[str] = None,
**kwargs,
self,
api_token: Optional[str] = None,
**kwargs,
):
"""
__init__ method of OpenAI Class
@@ -64,21 +63,31 @@ def __init__(
"""
self.api_token = api_token or os.getenv("OPENAI_API_KEY") or None
self.api_key_path = api_key_path

if (not self.api_token) and (not self.api_key_path):
raise APIKeyNotFoundError("Either OpenAI API key or key path is required")

if self.api_token:
openai.api_key = self.api_token
else:
openai.api_key_path = self.api_key_path
if not self.api_token:
raise APIKeyNotFoundError("OpenAI API key is required")

self.openai_proxy = kwargs.get("openai_proxy") or os.getenv("OPENAI_PROXY")
if self.openai_proxy:
openai.proxy = {"http": self.openai_proxy, "https": self.openai_proxy}

self._set_params(**kwargs)
# set the client
model_name = self.model.split(":")[1] if "ft:" in self.model else self.model
if model_name in self._supported_chat_models:
self._is_chat_model = True
if is_openai_v1():
self.client = openai.OpenAI(**self._client_params).chat.completions
else:
self.client = openai.ChatCompletion
elif model_name in self._supported_completion_models:
self._is_chat_model = False
if is_openai_v1():
self.client = openai.OpenAI(**self._client_params).completions
else:
self.client = openai.Completion
else:
raise UnsupportedModelError(self.model)

@property
def _default_params(self) -> Dict[str, Any]:
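So `self.client` always exposes a version-appropriate `.create(**params)`: a bound resource object under openai>=1.0, the module-level resource class under older SDKs. Reproducing the wiring by hand, as a sketch (key value illustrative):

```python
import openai
from pandasai.helpers.openai import is_openai_v1

if is_openai_v1():
    client = openai.OpenAI(api_key="sk-...").chat.completions  # bound resource
else:
    client = openai.ChatCompletion                             # module resource

# Either way, completion()/chat_completion() only ever call client.create(...).
```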
@@ -88,32 +97,6 @@ def _default_params(self) -> Dict[str, Any]:
"model": self.model,
}

def call(self, instruction: AbstractPrompt, suffix: str = "") -> str:
"""
Call the OpenAI LLM.
Args:
instruction (AbstractPrompt): A prompt object with instruction for LLM.
suffix (str): Suffix to pass.
Raises:
UnsupportedModelError: Unsupported model
Returns:
str: Response
"""
self.last_prompt = instruction.to_string() + suffix

model_name = self.model.split(":")[1] if "ft:" in self.model else self.model
if model_name in self._supported_chat_models:
response = self.chat_completion(self.last_prompt)
elif model_name in self._supported_completion_models:
response = self.completion(self.last_prompt)
else:
raise UnsupportedModelError(self.model)

return response

@property
def type(self) -> str:
return "openai"
[Diffs for the remaining 5 of the 9 changed files are not shown.]
