Skip to content

Commit

Permalink
Merge pull request #178 from vanna-ai/update-azure-openai
Browse files Browse the repository at this point in the history
Update OpenAI class for new Azure API
  • Loading branch information
zainhoda authored Jan 23, 2024
2 parents 2237677 + 4b7502d commit b425397
Showing 1 changed file with 70 additions and 26 deletions.
96 changes: 70 additions & 26 deletions src/vanna/openai/openai_chat.py
Original file line number Diff line number Diff line change
@@ -1,30 +1,42 @@
import os
import re
from abc import abstractmethod

import openai
import pandas as pd
from openai import OpenAI

from ..base import VannaBase


class OpenAI_Chat(VannaBase):
def __init__(self, config=None):
def __init__(self, client=None, config=None):
VannaBase.__init__(self, config=config)

if config is None:
if client is not None:
self.client = client
return

if config is None and client is None:
self.client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
return

if "api_type" in config:
openai.api_type = config["api_type"]
raise Exception(
"Passing api_type is now deprecated. Please pass an OpenAI client instead."
)

if "api_base" in config:
openai.api_base = config["api_base"]
raise Exception(
"Passing api_base is now deprecated. Please pass an OpenAI client instead."
)

if "api_version" in config:
openai.api_version = config["api_version"]
raise Exception(
"Passing api_version is now deprecated. Please pass an OpenAI client instead."
)

if "api_key" in config:
openai.api_key = config["api_key"]
self.client = OpenAI(api_key=config["api_key"])

@staticmethod
def system_message(message: str) -> dict:
Expand All @@ -43,34 +55,52 @@ def str_to_approx_token_count(string: str) -> int:
return len(string) / 4

@staticmethod
def add_ddl_to_prompt(initial_prompt: str, ddl_list: list[str], max_tokens: int = 14000) -> str:
def add_ddl_to_prompt(
initial_prompt: str, ddl_list: list[str], max_tokens: int = 14000
) -> str:
if len(ddl_list) > 0:
initial_prompt += f"\nYou may use the following DDL statements as a reference for what tables might be available. Use responses to past questions also to guide you:\n\n"

for ddl in ddl_list:
if OpenAI_Chat.str_to_approx_token_count(initial_prompt) + OpenAI_Chat.str_to_approx_token_count(ddl) < max_tokens:
if (
OpenAI_Chat.str_to_approx_token_count(initial_prompt)
+ OpenAI_Chat.str_to_approx_token_count(ddl)
< max_tokens
):
initial_prompt += f"{ddl}\n\n"

return initial_prompt

@staticmethod
def add_documentation_to_prompt(initial_prompt: str, documentation_list: list[str], max_tokens: int = 14000) -> str:
def add_documentation_to_prompt(
initial_prompt: str, documentation_list: list[str], max_tokens: int = 14000
) -> str:
if len(documentation_list) > 0:
initial_prompt += f"\nYou may use the following documentation as a reference for what tables might be available. Use responses to past questions also to guide you:\n\n"

for documentation in documentation_list:
if OpenAI_Chat.str_to_approx_token_count(initial_prompt) + OpenAI_Chat.str_to_approx_token_count(documentation) < max_tokens:
if (
OpenAI_Chat.str_to_approx_token_count(initial_prompt)
+ OpenAI_Chat.str_to_approx_token_count(documentation)
< max_tokens
):
initial_prompt += f"{documentation}\n\n"

return initial_prompt

@staticmethod
def add_sql_to_prompt(initial_prompt: str, sql_list: list[str], max_tokens: int = 14000) -> str:
def add_sql_to_prompt(
initial_prompt: str, sql_list: list[str], max_tokens: int = 14000
) -> str:
if len(sql_list) > 0:
initial_prompt += f"\nYou may use the following SQL statements as a reference for what tables might be available. Use responses to past questions also to guide you:\n\n"

for question in sql_list:
if OpenAI_Chat.str_to_approx_token_count(initial_prompt) + OpenAI_Chat.str_to_approx_token_count(question["sql"]) < max_tokens:
if (
OpenAI_Chat.str_to_approx_token_count(initial_prompt)
+ OpenAI_Chat.str_to_approx_token_count(question["sql"])
< max_tokens
):
initial_prompt += f"{question['question']}\n{question['sql']}\n\n"

return initial_prompt
Expand All @@ -85,9 +115,13 @@ def get_sql_prompt(
):
initial_prompt = "The user provides a question and you provide SQL. You will only respond with SQL code and not with any explanations.\n\nRespond with only SQL code. Do not answer with any explanations -- just the code.\n"

initial_prompt = OpenAI_Chat.add_ddl_to_prompt(initial_prompt, ddl_list, max_tokens=14000)
initial_prompt = OpenAI_Chat.add_ddl_to_prompt(
initial_prompt, ddl_list, max_tokens=14000
)

initial_prompt = OpenAI_Chat.add_documentation_to_prompt(initial_prompt, doc_list, max_tokens=14000)
initial_prompt = OpenAI_Chat.add_documentation_to_prompt(
initial_prompt, doc_list, max_tokens=14000
)

message_log = [OpenAI_Chat.system_message(initial_prompt)]

Expand All @@ -104,24 +138,34 @@ def get_sql_prompt(
return message_log

def get_followup_questions_prompt(
self,
question: str,
self,
question: str,
df: pd.DataFrame,
question_sql_list: list,
ddl_list: list,
doc_list: list,
**kwargs
doc_list: list,
**kwargs,
):
initial_prompt = f"The user initially asked the question: '{question}': \n\n"

initial_prompt = OpenAI_Chat.add_ddl_to_prompt(initial_prompt, ddl_list, max_tokens=14000)
initial_prompt = OpenAI_Chat.add_ddl_to_prompt(
initial_prompt, ddl_list, max_tokens=14000
)

initial_prompt = OpenAI_Chat.add_documentation_to_prompt(initial_prompt, doc_list, max_tokens=14000)
initial_prompt = OpenAI_Chat.add_documentation_to_prompt(
initial_prompt, doc_list, max_tokens=14000
)

initial_prompt = OpenAI_Chat.add_sql_to_prompt(initial_prompt, question_sql_list, max_tokens=14000)
initial_prompt = OpenAI_Chat.add_sql_to_prompt(
initial_prompt, question_sql_list, max_tokens=14000
)

message_log = [OpenAI_Chat.system_message(initial_prompt)]
message_log.append(OpenAI_Chat.user_message("Generate a list of followup questions that the user might ask about this data. Respond with a list of questions, one per line. Do not answer with any explanations -- just the questions."))
message_log.append(
OpenAI_Chat.user_message(
"Generate a list of followup questions that the user might ask about this data. Respond with a list of questions, one per line. Do not answer with any explanations -- just the questions."
)
)

return message_log

Expand Down Expand Up @@ -204,7 +248,7 @@ def submit_prompt(self, prompt, **kwargs) -> str:
print(
f"Using engine {self.config['engine']} for {num_tokens} tokens (approx)"
)
response = openai.chat.completions.create(
response = self.client.chat.completions.create(
engine=self.config["engine"],
messages=prompt,
max_tokens=500,
Expand All @@ -215,7 +259,7 @@ def submit_prompt(self, prompt, **kwargs) -> str:
print(
f"Using model {self.config['model']} for {num_tokens} tokens (approx)"
)
response = openai.chat.completions.create(
response = self.client.chat.completions.create(
model=self.config["model"],
messages=prompt,
max_tokens=500,
Expand All @@ -229,7 +273,7 @@ def submit_prompt(self, prompt, **kwargs) -> str:
model = "gpt-3.5-turbo"

print(f"Using model {model} for {num_tokens} tokens (approx)")
response = openai.chat.completions.create(
response = self.client.chat.completions.create(
model=model, messages=prompt, max_tokens=500, stop=None, temperature=0.7
)

Expand Down

0 comments on commit b425397

Please sign in to comment.