Skip to content

Commit

Permalink
Merge pull request #333 from andreped/openai-model-fix
Browse files Browse the repository at this point in the history
Allow setting `model` for separate `submit_prompt` calls to OpenAI
  • Loading branch information
zainhoda authored Apr 5, 2024
2 parents fbe4b62 + 6e68cff commit 70cdf8d
Showing 1 changed file with 26 additions and 2 deletions.
28 changes: 26 additions & 2 deletions src/vanna/openai/openai_chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def __init__(self, client=None, config=None):
if config is None and client is None:
self.client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
return

if "api_key" in config:
self.client = OpenAI(api_key=config["api_key"])

Expand All @@ -67,7 +67,31 @@ def submit_prompt(self, prompt, **kwargs) -> str:
for message in prompt:
num_tokens += len(message["content"]) / 4

if self.config is not None and "engine" in self.config:
if kwargs.get("model", None) is not None:
model = kwargs.get("model", None)
print(
f"Using model {model} for {num_tokens} tokens (approx)"
)
response = self.client.chat.completions.create(
model=model,
messages=prompt,
max_tokens=self.max_tokens,
stop=None,
temperature=self.temperature,
)
elif kwargs.get("engine", None) is not None:
engine = kwargs.get("engine", None)
print(
f"Using model {engine} for {num_tokens} tokens (approx)"
)
response = self.client.chat.completions.create(
engine=engine,
messages=prompt,
max_tokens=self.max_tokens,
stop=None,
temperature=self.temperature,
)
elif self.config is not None and "engine" in self.config:
print(
f"Using engine {self.config['engine']} for {num_tokens} tokens (approx)"
)
Expand Down

0 comments on commit 70cdf8d

Please sign in to comment.