diff --git a/src/vanna/openai/openai_chat.py b/src/vanna/openai/openai_chat.py index 1dd3dcf2..3125c36b 100644 --- a/src/vanna/openai/openai_chat.py +++ b/src/vanna/openai/openai_chat.py @@ -1,30 +1,42 @@ +import os import re from abc import abstractmethod -import openai import pandas as pd +from openai import OpenAI from ..base import VannaBase class OpenAI_Chat(VannaBase): - def __init__(self, config=None): + def __init__(self, client=None, config=None): VannaBase.__init__(self, config=config) - if config is None: + if client is not None: + self.client = client + return + + if config is None and client is None: + self.client = OpenAI(api_key=os.getenv("OPENAI_API_KEY")) return if "api_type" in config: - openai.api_type = config["api_type"] + raise Exception( + "Passing api_type is now deprecated. Please pass an OpenAI client instead." + ) if "api_base" in config: - openai.api_base = config["api_base"] + raise Exception( + "Passing api_base is now deprecated. Please pass an OpenAI client instead." + ) if "api_version" in config: - openai.api_version = config["api_version"] + raise Exception( + "Passing api_version is now deprecated. Please pass an OpenAI client instead." + ) if "api_key" in config: - openai.api_key = config["api_key"] + self.client = OpenAI(api_key=config["api_key"]) @staticmethod def system_message(message: str) -> dict: @@ -43,34 +55,52 @@ def str_to_approx_token_count(string: str) -> int: return len(string) / 4 @staticmethod - def add_ddl_to_prompt(initial_prompt: str, ddl_list: list[str], max_tokens: int = 14000) -> str: + def add_ddl_to_prompt( + initial_prompt: str, ddl_list: list[str], max_tokens: int = 14000 + ) -> str: if len(ddl_list) > 0: initial_prompt += f"\nYou may use the following DDL statements as a reference for what tables might be available. Use responses to past questions also to guide you:\n\n" for ddl in ddl_list: - if OpenAI_Chat.str_to_approx_token_count(initial_prompt) + OpenAI_Chat.str_to_approx_token_count(ddl) < max_tokens: + if ( + OpenAI_Chat.str_to_approx_token_count(initial_prompt) + + OpenAI_Chat.str_to_approx_token_count(ddl) + < max_tokens + ): initial_prompt += f"{ddl}\n\n" return initial_prompt @staticmethod - def add_documentation_to_prompt(initial_prompt: str, documentation_list: list[str], max_tokens: int = 14000) -> str: + def add_documentation_to_prompt( + initial_prompt: str, documentation_list: list[str], max_tokens: int = 14000 + ) -> str: if len(documentation_list) > 0: initial_prompt += f"\nYou may use the following documentation as a reference for what tables might be available. Use responses to past questions also to guide you:\n\n" for documentation in documentation_list: - if OpenAI_Chat.str_to_approx_token_count(initial_prompt) + OpenAI_Chat.str_to_approx_token_count(documentation) < max_tokens: + if ( + OpenAI_Chat.str_to_approx_token_count(initial_prompt) + + OpenAI_Chat.str_to_approx_token_count(documentation) + < max_tokens + ): initial_prompt += f"{documentation}\n\n" return initial_prompt @staticmethod - def add_sql_to_prompt(initial_prompt: str, sql_list: list[str], max_tokens: int = 14000) -> str: + def add_sql_to_prompt( + initial_prompt: str, sql_list: list[str], max_tokens: int = 14000 + ) -> str: if len(sql_list) > 0: initial_prompt += f"\nYou may use the following SQL statements as a reference for what tables might be available. Use responses to past questions also to guide you:\n\n" for question in sql_list: - if OpenAI_Chat.str_to_approx_token_count(initial_prompt) + OpenAI_Chat.str_to_approx_token_count(question["sql"]) < max_tokens: + if ( + OpenAI_Chat.str_to_approx_token_count(initial_prompt) + + OpenAI_Chat.str_to_approx_token_count(question["sql"]) + < max_tokens + ): initial_prompt += f"{question['question']}\n{question['sql']}\n\n" return initial_prompt @@ -85,9 +115,13 @@ def get_sql_prompt( ): initial_prompt = "The user provides a question and you provide SQL. You will only respond with SQL code and not with any explanations.\n\nRespond with only SQL code. Do not answer with any explanations -- just the code.\n" - initial_prompt = OpenAI_Chat.add_ddl_to_prompt(initial_prompt, ddl_list, max_tokens=14000) + initial_prompt = OpenAI_Chat.add_ddl_to_prompt( + initial_prompt, ddl_list, max_tokens=14000 + ) - initial_prompt = OpenAI_Chat.add_documentation_to_prompt(initial_prompt, doc_list, max_tokens=14000) + initial_prompt = OpenAI_Chat.add_documentation_to_prompt( + initial_prompt, doc_list, max_tokens=14000 + ) message_log = [OpenAI_Chat.system_message(initial_prompt)] @@ -104,24 +138,34 @@ def get_sql_prompt( return message_log def get_followup_questions_prompt( - self, - question: str, + self, + question: str, df: pd.DataFrame, question_sql_list: list, ddl_list: list, - doc_list: list, - **kwargs + doc_list: list, + **kwargs, ): initial_prompt = f"The user initially asked the question: '{question}': \n\n" - initial_prompt = OpenAI_Chat.add_ddl_to_prompt(initial_prompt, ddl_list, max_tokens=14000) + initial_prompt = OpenAI_Chat.add_ddl_to_prompt( + initial_prompt, ddl_list, max_tokens=14000 + ) - initial_prompt = OpenAI_Chat.add_documentation_to_prompt(initial_prompt, doc_list, max_tokens=14000) + initial_prompt = OpenAI_Chat.add_documentation_to_prompt( + initial_prompt, doc_list, max_tokens=14000 + ) - initial_prompt = OpenAI_Chat.add_sql_to_prompt(initial_prompt, question_sql_list, max_tokens=14000) + initial_prompt = OpenAI_Chat.add_sql_to_prompt( + initial_prompt, question_sql_list, max_tokens=14000 + ) message_log = [OpenAI_Chat.system_message(initial_prompt)] - message_log.append(OpenAI_Chat.user_message("Generate a list of followup questions that the user might ask about this data. Respond with a list of questions, one per line. Do not answer with any explanations -- just the questions.")) + message_log.append( + OpenAI_Chat.user_message( + "Generate a list of followup questions that the user might ask about this data. Respond with a list of questions, one per line. Do not answer with any explanations -- just the questions." + ) + ) return message_log @@ -204,7 +248,7 @@ def submit_prompt(self, prompt, **kwargs) -> str: print( f"Using engine {self.config['engine']} for {num_tokens} tokens (approx)" ) - response = openai.chat.completions.create( + response = self.client.chat.completions.create( engine=self.config["engine"], messages=prompt, max_tokens=500, @@ -215,7 +259,7 @@ def submit_prompt(self, prompt, **kwargs) -> str: print( f"Using model {self.config['model']} for {num_tokens} tokens (approx)" ) - response = openai.chat.completions.create( + response = self.client.chat.completions.create( model=self.config["model"], messages=prompt, max_tokens=500, @@ -229,7 +273,7 @@ def submit_prompt(self, prompt, **kwargs) -> str: model = "gpt-3.5-turbo" print(f"Using model {model} for {num_tokens} tokens (approx)") - response = openai.chat.completions.create( + response = self.client.chat.completions.create( model=model, messages=prompt, max_tokens=500, stop=None, temperature=0.7 )