Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Variables for OpenAI API and Model #1172

Merged
merged 1 commit into from
Feb 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions .env
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,16 @@
# Get key from https://dashboard.cohere.ai/api-keys
# COHERE_API_KEY=

OPENAI_API_BASE=https://api.openai.com/v1/
API_MODEL='gpt-3.5-turbo-1106'

# Get key from https://platform.openai.com/account/api-keys
# OPENAI_API_KEY=

# If you are using a different intrerence engine that adheres to the OpenAI, set it here
# OPENAI_API_BASE='http://127.0.0.1:8080/v1'
# OPENAI_VERSION='2023-05-15'

# Get key from https://makersuite.google.com/app/apikey

# HuggingFace demos: machine that uploads to HuggingFace.
Expand Down
11 changes: 8 additions & 3 deletions lilac/data/cluster_titling.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
)
from ..tasks import TaskInfo
from ..utils import chunks, log
from ..env import env

_TOP_K_CENTRAL_DOCS = 7
_TOP_K_CENTRAL_TITLES = 20
Expand Down Expand Up @@ -184,9 +185,11 @@ def request_with_retries() -> list[str]:
@functools.cache
def _openai_client() -> Any:
"""Get an OpenAI client."""
api_base = env('OPENAI_API_BASE')
try:
import openai


except ImportError:
raise ImportError(
'Could not import the "openai" python package. '
Expand All @@ -196,7 +199,7 @@ def _openai_client() -> Any:
# OpenAI requests sometimes hang, without any errors, and the default connection timeout is 10
# mins, which is too long. Set it to 7 seconds (99%-tile for latency is 3-4 sec). Also set
# `max_retries` to 0 to disable internal retries so we handle retries ourselves.
return instructor.patch(openai.OpenAI(timeout=7, max_retries=0))
return instructor.patch(openai.OpenAI(timeout=7, max_retries=0, base_url=api_base))


class Title(BaseModel):
Expand Down Expand Up @@ -234,11 +237,12 @@ def generate_title_openai(ranked_docs: list[tuple[str, float]]) -> str:
stop=stop_after_attempt(_NUM_RETRIES),
)
def request_with_retries() -> str:
api_model = env('API_MODEL')
max_tokens = _OPENAI_INITIAL_MAX_TOKENS
while max_tokens <= _OPENAI_FINAL_MAX_TOKENS:
try:
title = _openai_client().chat.completions.create(
model='gpt-3.5-turbo-1106',
model=api_model,
response_model=Title,
temperature=0.0,
max_tokens=max_tokens,
Expand Down Expand Up @@ -295,11 +299,12 @@ def generate_category_openai(ranked_docs: list[tuple[str, float]]) -> str:
stop=stop_after_attempt(_NUM_RETRIES),
)
def request_with_retries() -> str:
api_model = env('API_MODEL')
max_tokens = _OPENAI_INITIAL_MAX_TOKENS
while max_tokens <= _OPENAI_FINAL_MAX_TOKENS:
try:
category = _openai_client().chat.completions.create(
model='gpt-3.5-turbo-1106',
model=api_model,
response_model=Category,
temperature=0.0,
max_tokens=max_tokens,
Expand Down
3 changes: 2 additions & 1 deletion lilac/embeddings/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ class OpenAIEmbedding(TextEmbeddingSignal):
@override
def setup(self) -> None:
api_key = env('OPENAI_API_KEY')
api_base = env('OPENAI_API_BASE')
azure_api_key = env('AZURE_OPENAI_KEY')
azure_api_version = env('AZURE_OPENAI_VERSION')
azure_api_endpoint = env('AZURE_OPENAI_ENDPOINT')
Expand All @@ -64,7 +65,7 @@ def setup(self) -> None:

else:
if api_key:
self._client = openai.OpenAI(api_key=api_key)
self._client = openai.OpenAI(api_key=api_key, base_url=api_base)
self._azure = False

elif azure_api_key:
Expand Down
10 changes: 6 additions & 4 deletions lilac/gen/generator_openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,7 @@

class OpenAIChatCompletionGenerator(TextGenerator):
"""An interface for OpenAI chat completion."""

model: str = 'gpt-3.5-turbo-0613'
model: str = env('API_MODEL')
response_description: str = ''

@override
Expand All @@ -25,11 +24,14 @@ def generate(self, prompt: str) -> str:
api_key = env('OPENAI_API_KEY')
api_type = env('OPENAI_API_TYPE')
api_version = env('OPENAI_API_VERSION')
api_base = env('OPENAI_API_BASE')
api_model = env('API_MODEL')
if not api_key:
raise ValueError('`OPENAI_API_KEY` environment variable not set.')

try:
import openai

except ImportError:
raise ImportError(
'Could not import the "openai" python package. '
Expand All @@ -41,15 +43,15 @@ def generate(self, prompt: str) -> str:
openai.api_version = api_version

# Enables response_model in the openai client.
client = instructor.patch(openai.OpenAI())
client = instructor.patch(openai.OpenAI(base_url=api_base))

class Completion(OpenAISchema):
"""Generated completion of a prompt."""

completion: str = Field(..., description=self.response_description)

return client.chat.completions.create(
model='gpt-3.5-turbo',
model=api_model,
response_model=Completion,
messages=[
{
Expand Down
7 changes: 5 additions & 2 deletions lilac/router_concept.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,6 +248,8 @@ def generate_examples(description: str) -> list[str]:
api_type = env('OPENAI_API_TYPE')
api_version = env('OPENAI_API_VERSION')
api_engine = env('OPENAI_API_ENGINE_CHAT')
api_base = env('OPENAI_API_BASE')
api_model = env('API_MODEL')
if not api_key:
raise ValueError('`OPENAI_API_KEY` environment variable not set.')
try:
Expand All @@ -262,16 +264,17 @@ def generate_examples(description: str) -> list[str]:
openai.api_key = api_key
api_engine = api_engine


if api_type:
openai.api_type = api_type
openai.api_version = api_version

try:
# Enables response_model in the openai client.
client = instructor.patch(openai.OpenAI())
client = instructor.patch(openai.OpenAI(base_url=api_base))

completion = client.chat.completions.create(
model='gpt-3.5-turbo-1106',
model=api_model,
response_model=Examples,
temperature=0.0,
messages=[
Expand Down
Loading