From 2d4c1fa6e1b24a244df16b2732f6ef26bd32c471 Mon Sep 17 00:00:00 2001
From: drikster80
Date: Mon, 12 Feb 2024 02:18:11 +0000
Subject: [PATCH] Variables for OpenAI API and Model

---
 .env                          |  7 +++++++
 lilac/data/cluster_titling.py | 11 ++++++++---
 lilac/embeddings/openai.py    |  3 ++-
 lilac/gen/generator_openai.py | 10 ++++++----
 lilac/router_concept.py       |  7 +++++--
 5 files changed, 28 insertions(+), 10 deletions(-)

diff --git a/.env b/.env
index b8a0aa2e4..525c67d2e 100644
--- a/.env
+++ b/.env
@@ -13,9 +13,16 @@
 # Get key from https://dashboard.cohere.ai/api-keys
 # COHERE_API_KEY=
 
+OPENAI_API_BASE=https://api.openai.com/v1/
+API_MODEL='gpt-3.5-turbo-1106'
 # Get key from https://platform.openai.com/account/api-keys
 # OPENAI_API_KEY=
 
+
+# If you are using a different inference engine that is compatible with the OpenAI API, set it here
+# OPENAI_API_BASE='http://127.0.0.1:8080/v1'
+# OPENAI_API_VERSION='2023-05-15'
+
 # Get key from https://makersuite.google.com/app/apikey
 
 # HuggingFace demos: machine that uploads to HuggingFace.
diff --git a/lilac/data/cluster_titling.py b/lilac/data/cluster_titling.py
index 47c22b320..805735d56 100644
--- a/lilac/data/cluster_titling.py
+++ b/lilac/data/cluster_titling.py
@@ -24,6 +24,7 @@
 )
 from ..tasks import TaskInfo
 from ..utils import chunks, log
+from ..env import env
 
 _TOP_K_CENTRAL_DOCS = 7
 _TOP_K_CENTRAL_TITLES = 20
@@ -184,9 +185,11 @@ def request_with_retries() -> list[str]:
 
 @functools.cache
 def _openai_client() -> Any:
   """Get an OpenAI client."""
+  api_base = env('OPENAI_API_BASE')
   try:
     import openai
+
   except ImportError:
     raise ImportError(
       'Could not import the "openai" python package. '
@@ -196,7 +199,7 @@ def _openai_client() -> Any:
   # OpenAI requests sometimes hang, without any errors, and the default connection timeout is 10
   # mins, which is too long. Set it to 7 seconds (99%-tile for latency is 3-4 sec). Also set
   # `max_retries` to 0 to disable internal retries so we handle retries ourselves.
-  return instructor.patch(openai.OpenAI(timeout=7, max_retries=0))
+  return instructor.patch(openai.OpenAI(timeout=7, max_retries=0, base_url=api_base))
 
 
 class Title(BaseModel):
@@ -234,11 +237,12 @@ def generate_title_openai(ranked_docs: list[tuple[str, float]]) -> str:
     stop=stop_after_attempt(_NUM_RETRIES),
   )
   def request_with_retries() -> str:
+    api_model = env('API_MODEL') or 'gpt-3.5-turbo-1106'
     max_tokens = _OPENAI_INITIAL_MAX_TOKENS
     while max_tokens <= _OPENAI_FINAL_MAX_TOKENS:
       try:
         title = _openai_client().chat.completions.create(
-          model='gpt-3.5-turbo-1106',
+          model=api_model,
           response_model=Title,
           temperature=0.0,
           max_tokens=max_tokens,
@@ -295,11 +299,12 @@ def generate_category_openai(ranked_docs: list[tuple[str, float]]) -> str:
     stop=stop_after_attempt(_NUM_RETRIES),
   )
   def request_with_retries() -> str:
+    api_model = env('API_MODEL') or 'gpt-3.5-turbo-1106'
     max_tokens = _OPENAI_INITIAL_MAX_TOKENS
     while max_tokens <= _OPENAI_FINAL_MAX_TOKENS:
       try:
         category = _openai_client().chat.completions.create(
-          model='gpt-3.5-turbo-1106',
+          model=api_model,
           response_model=Category,
           temperature=0.0,
           max_tokens=max_tokens,
diff --git a/lilac/embeddings/openai.py b/lilac/embeddings/openai.py
index b388a3cd7..a9a70124c 100644
--- a/lilac/embeddings/openai.py
+++ b/lilac/embeddings/openai.py
@@ -39,6 +39,7 @@ class OpenAIEmbedding(TextEmbeddingSignal):
   @override
   def setup(self) -> None:
     api_key = env('OPENAI_API_KEY')
+    api_base = env('OPENAI_API_BASE')
     azure_api_key = env('AZURE_OPENAI_KEY')
     azure_api_version = env('AZURE_OPENAI_VERSION')
     azure_api_endpoint = env('AZURE_OPENAI_ENDPOINT')
@@ -64,7 +65,7 @@ def setup(self) -> None:
 
     else:
       if api_key:
-        self._client = openai.OpenAI(api_key=api_key)
+        self._client = openai.OpenAI(api_key=api_key, base_url=api_base)
         self._azure = False
 
       elif azure_api_key:
diff --git a/lilac/gen/generator_openai.py b/lilac/gen/generator_openai.py
index a3b05c3ec..4c14721a7 100644
--- a/lilac/gen/generator_openai.py
+++ b/lilac/gen/generator_openai.py
@@ -15,8 +15,7 @@
 
 class OpenAIChatCompletionGenerator(TextGenerator):
   """An interface for OpenAI chat completion."""
-
-  model: str = 'gpt-3.5-turbo-0613'
+  model: str = env('API_MODEL') or 'gpt-3.5-turbo-0613'
   response_description: str = ''
 
   @override
@@ -25,11 +24,14 @@ def generate(self, prompt: str) -> str:
     api_key = env('OPENAI_API_KEY')
     api_type = env('OPENAI_API_TYPE')
     api_version = env('OPENAI_API_VERSION')
+    api_base = env('OPENAI_API_BASE')
+    api_model = env('API_MODEL') or 'gpt-3.5-turbo'
     if not api_key:
       raise ValueError('`OPENAI_API_KEY` environment variable not set.')
 
     try:
       import openai
+
     except ImportError:
       raise ImportError(
         'Could not import the "openai" python package. '
@@ -41,7 +43,7 @@ def generate(self, prompt: str) -> str:
       openai.api_version = api_version
 
     # Enables response_model in the openai client.
-    client = instructor.patch(openai.OpenAI())
+    client = instructor.patch(openai.OpenAI(base_url=api_base))
 
     class Completion(OpenAISchema):
       """Generated completion of a prompt."""
@@ -49,7 +51,7 @@ class Completion(OpenAISchema):
       completion: str = Field(..., description=self.response_description)
 
     return client.chat.completions.create(
-      model='gpt-3.5-turbo',
+      model=api_model,
       response_model=Completion,
       messages=[
         {
diff --git a/lilac/router_concept.py b/lilac/router_concept.py
index af712ed80..5bb791b83 100644
--- a/lilac/router_concept.py
+++ b/lilac/router_concept.py
@@ -248,6 +248,8 @@ def generate_examples(description: str) -> list[str]:
   api_type = env('OPENAI_API_TYPE')
   api_version = env('OPENAI_API_VERSION')
   api_engine = env('OPENAI_API_ENGINE_CHAT')
+  api_base = env('OPENAI_API_BASE')
+  api_model = env('API_MODEL') or 'gpt-3.5-turbo-1106'
   if not api_key:
     raise ValueError('`OPENAI_API_KEY` environment variable not set.')
   try:
@@ -262,16 +264,17 @@ def generate_examples(description: str) -> list[str]:
 
   openai.api_key = api_key
   api_engine = api_engine
+
   if api_type:
     openai.api_type = api_type
     openai.api_version = api_version
 
   try:
     # Enables response_model in the openai client.
-    client = instructor.patch(openai.OpenAI())
+    client = instructor.patch(openai.OpenAI(base_url=api_base))
 
     completion = client.chat.completions.create(
-      model='gpt-3.5-turbo-1106',
+      model=api_model,
       response_model=Examples,
       temperature=0.0,
       messages=[
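
Usage note (not part of the patch): the sketch below shows how the two new variables are expected to be combined, assuming an OpenAI-compatible server such as the commented-out http://127.0.0.1:8080/v1 example in .env. It is a standalone illustration that reads os.environ directly rather than using Lilac's env() helper; the placeholder API key and the "Say hello." prompt are invented for the example.

# Illustrative sketch only, not Lilac code. Assumes an OpenAI-compatible
# endpoint, e.g. the commented-out http://127.0.0.1:8080/v1 example in .env.
import os

import openai

api_base = os.environ.get('OPENAI_API_BASE')  # None keeps the default endpoint
api_model = os.environ.get('API_MODEL', 'gpt-3.5-turbo-1106')

client = openai.OpenAI(
  # Many local OpenAI-compatible servers ignore the key; this value is a placeholder.
  api_key=os.environ.get('OPENAI_API_KEY', 'not-needed-for-local'),
  base_url=api_base,  # openai>=1.x accepts None and then uses https://api.openai.com/v1
)
completion = client.chat.completions.create(
  model=api_model,
  messages=[{'role': 'user', 'content': 'Say hello.'}],
)
print(completion.choices[0].message.content)

Leaving OPENAI_API_BASE unset keeps the public OpenAI endpoint, so the pre-patch behavior is preserved; only setting it redirects the clients patched above.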