From 6fd5db775fae8923620563d2ada0db5804f345d3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andr=C3=A9=20Pedersen?= Date: Tue, 2 Apr 2024 17:49:09 +0200 Subject: [PATCH 1/3] Add ability to set model for submit_prompt call for OpenAI Chat --- src/vanna/openai/openai_chat.py | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/src/vanna/openai/openai_chat.py b/src/vanna/openai/openai_chat.py index 5f2ce81b..5cc51066 100644 --- a/src/vanna/openai/openai_chat.py +++ b/src/vanna/openai/openai_chat.py @@ -41,7 +41,7 @@ def __init__(self, client=None, config=None): if config is None and client is None: self.client = OpenAI(api_key=os.getenv("OPENAI_API_KEY")) return - + if "api_key" in config: self.client = OpenAI(api_key=config["api_key"]) @@ -67,7 +67,18 @@ def submit_prompt(self, prompt, **kwargs) -> str: for message in prompt: num_tokens += len(message["content"]) / 4 - if self.config is not None and "engine" in self.config: + if kwargs.get("model", None) is not None: + print( + f"Using model {self.config['model']} for {num_tokens} tokens (approx)" + ) + response = self.client.chat.completions.create( + model=model, + messages=prompt, + max_tokens=self.max_tokens, + stop=None, + temperature=self.temperature, + ) + elif self.config is not None and "engine" in self.config: print( f"Using engine {self.config['engine']} for {num_tokens} tokens (approx)" ) From 17ced37f49c8395626f44e6e1969d47379c246af Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andr=C3=A9=20Pedersen?= Date: Tue, 2 Apr 2024 17:50:33 +0200 Subject: [PATCH 2/3] Allow setting both model and engine for OpenAI Chat calls --- src/vanna/openai/openai_chat.py | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/src/vanna/openai/openai_chat.py b/src/vanna/openai/openai_chat.py index 5cc51066..2fd571d5 100644 --- a/src/vanna/openai/openai_chat.py +++ b/src/vanna/openai/openai_chat.py @@ -72,7 +72,18 @@ def submit_prompt(self, prompt, **kwargs) -> str: f"Using model {self.config['model']} for {num_tokens} tokens (approx)" ) response = self.client.chat.completions.create( - model=model, + model=kwargs.get("model", None), + messages=prompt, + max_tokens=self.max_tokens, + stop=None, + temperature=self.temperature, + ) + elif kwargs.get("engine", None) is not None: + print( + f"Using model {self.config['model']} for {num_tokens} tokens (approx)" + ) + response = self.client.chat.completions.create( + engine=kwargs.get("engine", None), messages=prompt, max_tokens=self.max_tokens, stop=None, From 6e68cff9de4b452638533288fac54a3daf05ea04 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andr=C3=A9=20Pedersen?= Date: Tue, 2 Apr 2024 19:27:35 +0200 Subject: [PATCH 3/3] Update prints to use user-defined model in OpenAI calls --- src/vanna/openai/openai_chat.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/src/vanna/openai/openai_chat.py b/src/vanna/openai/openai_chat.py index 2fd571d5..53990aa9 100644 --- a/src/vanna/openai/openai_chat.py +++ b/src/vanna/openai/openai_chat.py @@ -68,22 +68,24 @@ def submit_prompt(self, prompt, **kwargs) -> str: num_tokens += len(message["content"]) / 4 if kwargs.get("model", None) is not None: + model = kwargs.get("model", None) print( - f"Using model {self.config['model']} for {num_tokens} tokens (approx)" + f"Using model {model} for {num_tokens} tokens (approx)" ) response = self.client.chat.completions.create( - model=kwargs.get("model", None), + model=model, messages=prompt, max_tokens=self.max_tokens, stop=None, temperature=self.temperature, ) elif kwargs.get("engine", None) is not None: + engine = kwargs.get("engine", None) print( - f"Using model {self.config['model']} for {num_tokens} tokens (approx)" + f"Using model {engine} for {num_tokens} tokens (approx)" ) response = self.client.chat.completions.create( - engine=kwargs.get("engine", None), + engine=engine, messages=prompt, max_tokens=self.max_tokens, stop=None,