Skip to content

Commit

Permalink
Variables for OpenAI API and Model
Browse files Browse the repository at this point in the history
  • Loading branch information
drikster80 committed Feb 12, 2024
1 parent 2e527f5 commit 2d4c1fa
Show file tree
Hide file tree
Showing 5 changed files with 28 additions and 10 deletions.
7 changes: 7 additions & 0 deletions .env
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,16 @@
# Get key from https://dashboard.cohere.ai/api-keys
# COHERE_API_KEY=

OPENAI_API_BASE=https://api.openai.com/v1/
API_MODEL='gpt-3.5-turbo-1106'

# Get key from https://platform.openai.com/account/api-keys
# OPENAI_API_KEY=

# If you are using a different intrerence engine that adheres to the OpenAI, set it here
# OPENAI_API_BASE='http://127.0.0.1:8080/v1'
# OPENAI_VERSION='2023-05-15'

# Get key from https://makersuite.google.com/app/apikey

# HuggingFace demos: machine that uploads to HuggingFace.
Expand Down
11 changes: 8 additions & 3 deletions lilac/data/cluster_titling.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
)
from ..tasks import TaskInfo
from ..utils import chunks, log
from ..env import env

_TOP_K_CENTRAL_DOCS = 7
_TOP_K_CENTRAL_TITLES = 20
Expand Down Expand Up @@ -184,9 +185,11 @@ def request_with_retries() -> list[str]:
@functools.cache
def _openai_client() -> Any:
"""Get an OpenAI client."""
api_base = env('OPENAI_API_BASE')
try:
import openai


except ImportError:
raise ImportError(
'Could not import the "openai" python package. '
Expand All @@ -196,7 +199,7 @@ def _openai_client() -> Any:
# OpenAI requests sometimes hang, without any errors, and the default connection timeout is 10
# mins, which is too long. Set it to 7 seconds (99%-tile for latency is 3-4 sec). Also set
# `max_retries` to 0 to disable internal retries so we handle retries ourselves.
return instructor.patch(openai.OpenAI(timeout=7, max_retries=0))
return instructor.patch(openai.OpenAI(timeout=7, max_retries=0, base_url=api_base))


class Title(BaseModel):
Expand Down Expand Up @@ -234,11 +237,12 @@ def generate_title_openai(ranked_docs: list[tuple[str, float]]) -> str:
stop=stop_after_attempt(_NUM_RETRIES),
)
def request_with_retries() -> str:
api_model = env('API_MODEL')
max_tokens = _OPENAI_INITIAL_MAX_TOKENS
while max_tokens <= _OPENAI_FINAL_MAX_TOKENS:
try:
title = _openai_client().chat.completions.create(
model='gpt-3.5-turbo-1106',
model=api_model,
response_model=Title,
temperature=0.0,
max_tokens=max_tokens,
Expand Down Expand Up @@ -295,11 +299,12 @@ def generate_category_openai(ranked_docs: list[tuple[str, float]]) -> str:
stop=stop_after_attempt(_NUM_RETRIES),
)
def request_with_retries() -> str:
api_model = env('API_MODEL')
max_tokens = _OPENAI_INITIAL_MAX_TOKENS
while max_tokens <= _OPENAI_FINAL_MAX_TOKENS:
try:
category = _openai_client().chat.completions.create(
model='gpt-3.5-turbo-1106',
model=api_model,
response_model=Category,
temperature=0.0,
max_tokens=max_tokens,
Expand Down
3 changes: 2 additions & 1 deletion lilac/embeddings/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ class OpenAIEmbedding(TextEmbeddingSignal):
@override
def setup(self) -> None:
api_key = env('OPENAI_API_KEY')
api_base = env('OPENAI_API_BASE')
azure_api_key = env('AZURE_OPENAI_KEY')
azure_api_version = env('AZURE_OPENAI_VERSION')
azure_api_endpoint = env('AZURE_OPENAI_ENDPOINT')
Expand All @@ -64,7 +65,7 @@ def setup(self) -> None:

else:
if api_key:
self._client = openai.OpenAI(api_key=api_key)
self._client = openai.OpenAI(api_key=api_key, base_url=api_base)
self._azure = False

elif azure_api_key:
Expand Down
10 changes: 6 additions & 4 deletions lilac/gen/generator_openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,7 @@

class OpenAIChatCompletionGenerator(TextGenerator):
"""An interface for OpenAI chat completion."""

model: str = 'gpt-3.5-turbo-0613'
model: str = env('API_MODEL')
response_description: str = ''

@override
Expand All @@ -25,11 +24,14 @@ def generate(self, prompt: str) -> str:
api_key = env('OPENAI_API_KEY')
api_type = env('OPENAI_API_TYPE')
api_version = env('OPENAI_API_VERSION')
api_base = env('OPENAI_API_BASE')
api_model = env('API_MODEL')
if not api_key:
raise ValueError('`OPENAI_API_KEY` environment variable not set.')

try:
import openai

except ImportError:
raise ImportError(
'Could not import the "openai" python package. '
Expand All @@ -41,15 +43,15 @@ def generate(self, prompt: str) -> str:
openai.api_version = api_version

# Enables response_model in the openai client.
client = instructor.patch(openai.OpenAI())
client = instructor.patch(openai.OpenAI(base_url=api_base))

class Completion(OpenAISchema):
"""Generated completion of a prompt."""

completion: str = Field(..., description=self.response_description)

return client.chat.completions.create(
model='gpt-3.5-turbo',
model=api_model,
response_model=Completion,
messages=[
{
Expand Down
7 changes: 5 additions & 2 deletions lilac/router_concept.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,6 +248,8 @@ def generate_examples(description: str) -> list[str]:
api_type = env('OPENAI_API_TYPE')
api_version = env('OPENAI_API_VERSION')
api_engine = env('OPENAI_API_ENGINE_CHAT')
api_base = env('OPENAI_API_BASE')
api_model = env('API_MODEL')
if not api_key:
raise ValueError('`OPENAI_API_KEY` environment variable not set.')
try:
Expand All @@ -262,16 +264,17 @@ def generate_examples(description: str) -> list[str]:
openai.api_key = api_key
api_engine = api_engine


if api_type:
openai.api_type = api_type
openai.api_version = api_version

try:
# Enables response_model in the openai client.
client = instructor.patch(openai.OpenAI())
client = instructor.patch(openai.OpenAI(base_url=api_base))

completion = client.chat.completions.create(
model='gpt-3.5-turbo-1106',
model=api_model,
response_model=Examples,
temperature=0.0,
messages=[
Expand Down

0 comments on commit 2d4c1fa

Please sign in to comment.