Commit
Merge pull request #271 from vanna-ai/see-query-results
Generate summaries and followup questions
zainhoda authored Mar 2, 2024
2 parents 9870b87 + da4f301 commit 35c7d4a
Showing 5 changed files with 230 additions and 155 deletions.
6 changes: 0 additions & 6 deletions .pre-commit-config.yaml
@@ -17,9 +17,3 @@ repos:
hooks:
- id: isort
args: [ "--profile", "black", "--filter-files" ]

- repo: https://github.com/odwyersoftware/brunette
rev: 238bead5ec5c58935d6bb12c70f435f70b2bf785
hooks:
- id: brunette
args: [ '--config=setup.cfg' ]
118 changes: 91 additions & 27 deletions src/vanna/base/base.py
@@ -72,6 +72,7 @@ class VannaBase(ABC):
def __init__(self, config=None):
self.config = config
self.run_sql_is_set = False
self.static_documentation = ""

def log(self, message: str):
print(message)
@@ -140,18 +141,35 @@ def is_sql_valid(self, sql: str) -> bool:
else:
return False

def generate_followup_questions(self, question: str, **kwargs) -> str:
question_sql_list = self.get_similar_question_sql(question, **kwargs)
ddl_list = self.get_related_ddl(question, **kwargs)
doc_list = self.get_related_documentation(question, **kwargs)
prompt = self.get_followup_questions_prompt(
question=question,
question_sql_list=question_sql_list,
ddl_list=ddl_list,
doc_list=doc_list,
**kwargs,
)
llm_response = self.submit_prompt(prompt, **kwargs)
def generate_followup_questions(
self, question: str, sql: str, df: pd.DataFrame, **kwargs
) -> list:
"""
**Example:**
```python
vn.generate_followup_questions("What are the top 10 customers by sales?", df)
```
Generate a list of followup questions that you can ask Vanna.AI.
Args:
question (str): The question that was asked.
df (pd.DataFrame): The results of the SQL query.
Returns:
list: A list of followup questions that you can ask Vanna.AI.
"""

message_log = [
self.system_message(
f"You are a helpful data assistant. The user asked the question: '{question}'\n\nThe SQL query for this question was: {sql}\n\nThe following is a pandas DataFrame with the results of the query: \n{df.to_markdown()}\n\n"
),
self.user_message(
"Generate a list of followup questions that the user might ask about this data. Respond with a list of questions, one per line. Do not answer with any explanations -- just the questions. Remember that there should be an unambiguous SQL query that can be generated from the question. Prefer questions that are answerable outside of the context of this conversation. Prefer questions that are slight modifications of the SQL query that was generated that allow digging deeper into the data. Each question will be turned into a button that the user can click to generate a new SQL query so don't use 'example' type questions. Each question must have a one-to-one correspondence with an instantiated SQL query."
),
]

llm_response = self.submit_prompt(message_log, **kwargs)

numbers_removed = re.sub(r"^\d+\.\s*", "", llm_response, flags=re.MULTILINE)
return numbers_removed.split("\n")
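With the new signature, callers pass the executed SQL and the resulting DataFrame instead of having the method retrieve similar questions itself, and the final regex strips any leading "1." style numbering the model adds before splitting on newlines. A minimal usage sketch, assuming a connected and trained `vn` instance (the question text and variable names are illustrative, not from the diff):

```python
question = "What are the top 10 customers by sales?"

sql = vn.generate_sql(question)   # text-to-SQL, unchanged by this PR
df = vn.run_sql(sql)              # execute against the connected database

# New in this PR: followups are grounded in the question, the SQL, and the results.
followups = vn.generate_followup_questions(question, sql, df)
for followup in followups:
    print(followup)
```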
@@ -169,6 +187,36 @@ def generate_questions(self, **kwargs) -> List[str]:

return [q["question"] for q in question_sql]

def generate_summary(self, question: str, df: pd.DataFrame, **kwargs) -> str:
"""
**Example:**
```python
vn.generate_summary("What are the top 10 customers by sales?", df)
```
Generate a summary of the results of a SQL query.
Args:
question (str): The question that was asked.
df (pd.DataFrame): The results of the SQL query.
Returns:
str: The summary of the results of the SQL query.
"""

message_log = [
self.system_message(
f"You are a helpful data assistant. The user asked the question: '{question}'\n\nThe following is a pandas DataFrame with the results of the query: \n{df.to_markdown()}\n\n"
),
self.user_message(
"Briefly summarize the data based on the question that was asked. Do not respond with any additional explanation beyond the summary."
),
]

summary = self.submit_prompt(message_log, **kwargs)

return summary
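A companion sketch for the new summary method, again assuming a connected `vn` instance and an illustrative question. Note that the prompt builds on `df.to_markdown()`, which requires the `tabulate` package to be installed alongside pandas.

```python
question = "What are the top 10 customers by sales?"
sql = vn.generate_sql(question)
df = vn.run_sql(sql)

# New in this PR: a short natural-language summary of the query results.
summary = vn.generate_summary(question, df)
print(summary)
```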

# ----------------- Use Any Embeddings API ----------------- #
@abstractmethod
def generate_embedding(self, data: str, **kwargs) -> List[float]:
@@ -184,7 +232,7 @@ def get_similar_question_sql(self, question: str, **kwargs) -> list:
question (str): The question to get similar questions and their corresponding SQL statements for.
Returns:
list: A list of similar questions and their corresponding SQL statements.
"""
pass

@@ -224,15 +272,15 @@ def add_question_sql(self, question: str, sql: str, **kwargs) -> str:
sql (str): The SQL query to add.
Returns:
str: The ID of the training data that was added.
"""
pass

@abstractmethod
def add_ddl(self, ddl: str, **kwargs) -> str:
"""
This method is used to add a DDL statement to the training data.
Args:
ddl (str): The DDL statement to add.
@@ -265,7 +313,7 @@ def get_training_data(self, **kwargs) -> pd.DataFrame:
This method is used to get all the training data from the retrieval layer.
Returns:
pd.DataFrame: The training data.
"""
pass

@@ -321,7 +369,10 @@ def add_ddl_to_prompt(
return initial_prompt

def add_documentation_to_prompt(
self, initial_prompt: str, documentation_list: list[str], max_tokens: int = 14000
self,
initial_prompt: str,
documentation_list: list[str],
max_tokens: int = 14000,
) -> str:
if len(documentation_list) > 0:
initial_prompt += f"\nYou may use the following documentation as a reference for what tables might be available. Use responses to past questions also to guide you:\n\n"
@@ -389,6 +440,9 @@ def get_sql_prompt(
initial_prompt, ddl_list, max_tokens=14000
)

if self.static_documentation != "":
doc_list.append(self.static_documentation)

initial_prompt = self.add_documentation_to_prompt(
initial_prompt, doc_list, max_tokens=14000
)
@@ -599,6 +653,7 @@ def run_sql_snowflake(sql: str) -> pd.DataFrame:

return df

self.static_documentation = "This is a Snowflake database"
self.run_sql = run_sql_snowflake
self.run_sql_is_set = True

@@ -632,6 +687,7 @@ def connect_to_sqlite(self, url: str):
def run_sql_sqlite(sql: str):
return pd.read_sql_query(sql, conn)

self.static_documentation = "This is a SQLite database"
self.run_sql = run_sql_sqlite
self.run_sql_is_set = True

@@ -731,11 +787,12 @@ def run_sql_postgres(sql: str) -> Union[pd.DataFrame, None]:
except psycopg2.Error as e:
conn.rollback()
raise ValidationError(e)

except Exception as e:
conn.rollback()
raise e

self.static_documentation = "This is a Postgres database"
self.run_sql_is_set = True
self.run_sql = run_sql_postgres

@@ -825,6 +882,7 @@ def run_sql_bigquery(sql: str) -> Union[pd.DataFrame, None]:
raise errors
return None

self.static_documentation = "This is a BigQuery database"
self.run_sql_is_set = True
self.run_sql = run_sql_bigquery

@@ -847,13 +905,13 @@ def connect_to_duckdb(self, url: str, init_sql: str = None):
" run command: \npip install vanna[duckdb]"
)
# URL of the database to download
if url==":memory:" or url=="":
path=":memory:"
if url == ":memory:" or url == "":
path = ":memory:"
else:
# Path to save the downloaded database
print(os.path.exists(url))
if os.path.exists(url):
path=url
path = url
elif url.startswith("md") or url.startswith("motherduck"):
path = url
else:
@@ -873,6 +931,7 @@ def connect_to_duckdb(self, url: str, init_sql: str = None):
def run_sql_duckdb(sql: str):
return conn.query(sql).to_df()

self.static_documentation = "This is a DuckDB database"
self.run_sql = run_sql_duckdb
self.run_sql_is_set = True
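For reference, the reformatted URL handling above accepts an in-memory database, an existing local file, a MotherDuck-style URL (passed through as-is), or otherwise a remote file that gets downloaded, and `init_sql`, when given, runs once after connecting. Hedged examples of those forms (all paths and table names are placeholders):

```python
vn.connect_to_duckdb(":memory:")              # fresh in-memory database
vn.connect_to_duckdb("local_data.duckdb")     # existing local file
vn.connect_to_duckdb("md:my_database")        # MotherDuck-style URL is used as-is

# init_sql runs once after the connection is opened, e.g. to load a CSV:
vn.connect_to_duckdb(
    ":memory:",
    init_sql="CREATE TABLE sales AS SELECT * FROM read_csv_auto('sales.csv')",
)
```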

@@ -895,27 +954,31 @@ def connect_to_mssql(self, odbc_conn_str: str):
)

try:
import sqlalchemy as sa
from sqlalchemy.engine import URL
except ImportError:
raise DependencyError(
"You need to install required dependencies to execute this method,"
" run command: pip install sqlalchemy"
)

connection_url = URL.create("mssql+pyodbc", query={"odbc_connect": odbc_conn_str})
connection_url = URL.create(
"mssql+pyodbc", query={"odbc_connect": odbc_conn_str}
)

from sqlalchemy import create_engine

engine = create_engine(connection_url)

def run_sql_mssql(sql: str):
# Execute the SQL statement and return the result as a pandas DataFrame
with engine.begin() as conn:
df = pd.read_sql_query(sa.text(sql), conn)
return df

raise Exception("Couldn't run sql")

self.static_documentation = "This is a Microsoft SQL Server database"
self.run_sql = run_sql_mssql
self.run_sql_is_set = True
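The argument is a raw ODBC connection string that gets wrapped into a SQLAlchemy `mssql+pyodbc` URL via `URL.create`. A sketch of calling it, where the server, credentials, and driver name are placeholders and must match your own ODBC setup:

```python
odbc_conn_str = (
    "DRIVER={ODBC Driver 17 for SQL Server};"
    "SERVER=my-server.example.com;"
    "DATABASE=my_database;"
    "UID=my_user;"
    "PWD=my_password"
)
vn.connect_to_mssql(odbc_conn_str)

df = vn.run_sql("SELECT TOP 10 * FROM dbo.customers")
```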

@@ -943,7 +1006,7 @@ def ask(
question: Union[str, None] = None,
print_results: bool = True,
auto_train: bool = True,
visualize: bool = True, # if False, will not generate plotly code
) -> Union[
Tuple[
Union[str, None],
Expand Down Expand Up @@ -1024,7 +1087,9 @@ def ask(
display = __import__(
"IPython.display", fromlist=["display"]
).display
Image = __import__("IPython.display", fromlist=["Image"]).Image
Image = __import__(
"IPython.display", fromlist=["Image"]
).Image
img_bytes = fig.to_image(format="png", scale=2)
display(Image(img_bytes))
except Exception as e:
@@ -1377,4 +1442,3 @@ def get_plotly_figure(
fig.update_layout(template="plotly_dark")

return fig
