Skip to content

Commit

Permalink
Merge pull request #1 from vanna-ai/init
Browse files Browse the repository at this point in the history
Init
  • Loading branch information
zainhoda authored Jun 21, 2023
2 parents 6d03e4a + cdbad80 commit 9f90a30
Show file tree
Hide file tree
Showing 9 changed files with 2,418 additions and 2 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
build
**.egg-info
venv
.DS_Store
7 changes: 7 additions & 0 deletions docs/index.html
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
<!doctype html>
<html>
<head>
<meta charset="utf-8">
<meta http-equiv="refresh" content="0; url=./vanna.html"/>
</head>
</html>
46 changes: 46 additions & 0 deletions docs/search.js

Large diffs are not rendered by default.

81 changes: 81 additions & 0 deletions docs/vanna-py-overview.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
---
marp: true
theme: gaia
_class: lead
paginate: true
backgroundColor: #111827
color: #fff
header: 'Updated: 2023-05-22'
---
<style>
strong {
font-family: 'Roboto Slab';
color: transparent !important;
background: linear-gradient(15deg, #009efd, #2af598);
background-clip: text;
-webkit-background-clip: text;
}
marp-pre {
font-family: 'Fira Code Light';
font-size: 0.75em;
background: #000;
border-radius: 30px;
}
</style>

![bg left:40% 80%](https://ask.vanna.ai/static/img/vanna.svg)

# **Vanna.AI**
## Python Package

For Natural Language to SQL
(and associated functionality)

[email protected]

---
# What can you do with **Vanna.AI**?

**Vanna.AI** has a Python package that allows you to convert natural language to SQL.

```python
import vanna as vn

vn.api_key = 'vanna-key-...' # Set your API key
vn.set_org('') # Set your organization name

my_question = 'What are the top 10 ABC by XYZ?'

sql = vn.generate_sql(question=my_question)
# SELECT * FROM table_name WHERE column_name = 'value'

my_df = vn.get_results(cs, 'my_database', sql)  # cs is a Snowflake cursor

plotly_code = vn.generate_plotly_code(question=my_question, sql=sql, df=my_df)
# fig = px.bar(df, x='column_name', y='column_name')

fig = vn.get_plotly_figure(plotly_code=plotly_code, df=my_df)

```

---

# Installation

## Global Installation
```bash
pip install vanna
```
or
```bash
pip3 install vanna
```

## Use a Virtual Environment
```bash
python3 -m venv venv
source venv/bin/activate
pip install vanna
```

---
494 changes: 494 additions & 0 deletions docs/vanna.html

Large diffs are not rendered by default.

1,406 changes: 1,406 additions & 0 deletions docs/vanna/types.html

Large diffs are not rendered by default.

5 changes: 4 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "vanna"
version = "0.0.1"
version = "0.0.2"
authors = [
{ name="Zain Hoda", email="[email protected]" },
]
Expand All @@ -12,6 +12,9 @@ classifiers = [
"License :: OSI Approved :: MIT License",
"Operating System :: OS Independent",
]
# Runtime requirements of the vanna package.
dependencies = [
    "requests",
    "tabulate",
    "plotly",
    "pandas",  # imported unconditionally by src/vanna/__init__.py
]

[project.urls]
"Homepage" = "https://github.com/vanna-ai/vanna-py"
Expand Down
282 changes: 281 additions & 1 deletion src/vanna/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,281 @@
print("Vanna.AI Imported")
r'''
A module to interact with the Vanna.AI API, providing the functionality to generate SQL explanations.
```python
import vanna as vn
vn.api_key = 'vanna-key-...' # Set your API key
vn.set_org('') # Set your organization name
vn.store_sql(question="Who are the top 10 customers by Sales?", sql="SELECT customer_name, sales FROM customers ORDER BY sales DESC LIMIT 10")
my_question = 'What are the top 10 ABC by XYZ?'
sql = vn.generate_sql(question=my_question)
# SELECT * FROM table_name WHERE column_name = 'value'
conn = snowflake.connector.connect(
user='my_user',
password='my_password',
account='my_account',
database='my_database',
)
cs = conn.cursor()
df = vn.get_results(cs, my_default_db, sql)
plotly_code = vn.generate_plotly_code(question="Who are the top 10 customers by Sales?", sql=sql, df=df)
# px.bar(df, x='column_name', y='column_name')
fig = vn.get_plotly_figure(plotly_code=plotly_code, df=df)
```
'''
# Import-time side effect: writes a notice to stdout whenever the package is imported.
print("Vanna.AI Imported")

import requests
import pandas as pd
import json
import dataclasses
import plotly
import plotly.express as px
import plotly.graph_objects as go
from .types import SQLAnswer, Explanation, QuestionSQLPair, Question, QuestionId, DataResult, PlotlyResult, Status
from typing import List, Dict, Any, Union, Optional

# API key for Vanna.AI; sent as the 'Vanna-Key' header by __rpc_call, which
# raises if this is still None. Set it directly: `vn.api_key = '...'`.
api_key: Union[str, None] = None
# Organization name for Vanna.AI; sent as the 'Vanna-Org' header by __rpc_call.
# Assigned through set_org().
__org: Union[str, None] = None
# URL that every RPC request is POSTed to.
_endpoint = "https://ask.vanna.ai/rpc"

def __rpc_call(method, params):
    """
    POST one RPC request to the Vanna.AI endpoint and return the decoded reply.

    Args:
        method (str): Name of the remote method to invoke.
        params (list): Dataclass instances; each is converted to a plain
            dictionary before being serialized into the request body.

    Returns:
        The response body decoded from JSON (``response.json()``).

    Raises:
        Exception: If the API key or the organization name has not been set.
    """
    # Refuse to send anything until both credentials are available.
    if api_key is None:
        raise Exception("API key not set")
    if __org is None:
        raise Exception("Organization name not set")

    payload = json.dumps({
        "method": method,
        "params": [__dataclass_to_dict(item) for item in params],
    })
    request_headers = {
        'Content-Type': 'application/json',
        'Vanna-Key': api_key,
        'Vanna-Org': __org,
    }

    reply = requests.post(_endpoint, headers=request_headers, data=payload)
    return reply.json()

def __dataclass_to_dict(obj):
    """
    Turn a dataclass instance into a plain dictionary.

    Args:
        obj (object): The dataclass instance to convert.

    Returns:
        dict: Field names mapped to their values, as produced by
        ``dataclasses.asdict``.
    """
    as_dict = dataclasses.asdict(obj)
    return as_dict

def set_org(org: str) -> None:
    """
    Record the organization name used for subsequent Vanna.AI API calls.

    The value is kept in the module-level ``__org`` variable.

    Args:
        org (str): The organization name.
    """
    global __org
    __org = org

def store_sql(question: str, sql: str) -> bool:
    """
    Save a question together with its SQL query through the ``store_sql`` RPC method.

    Args:
        question (str): The question to store.
        sql (str): The SQL query to store.

    Returns:
        bool: The ``success`` field of the returned status, or False when the
        response contains no ``result`` entry.
    """
    pair = QuestionSQLPair(question=question, sql=sql)
    reply = __rpc_call(method="store_sql", params=[pair])

    if 'result' in reply:
        return Status(**reply['result']).success
    return False

def remove_sql(question: str) -> bool:
    """
    Delete a stored question (and its SQL query) through the ``remove_sql`` RPC method.

    Args:
        question (str): The question to remove.

    Returns:
        bool: The ``success`` field of the returned status, or False when the
        response contains no ``result`` entry.
    """
    reply = __rpc_call(method="remove_sql", params=[Question(question=question)])

    if 'result' in reply:
        return Status(**reply['result']).success
    return False

def generate_sql(question: str) -> str | None:
    """
    Ask the Vanna.AI API for an SQL query that answers a question.

    Args:
        question (str): The question to generate an SQL query for.

    Returns:
        str or None: The ``sql`` field of the answer, or None when the
        response contains no ``result`` entry.
    """
    reply = __rpc_call(
        method="generate_sql_from_question",
        params=[Question(question=question)],
    )

    if 'result' not in reply:
        return None

    # The result payload carries exactly the fields of SQLAnswer.
    answer = SQLAnswer(**reply['result'])
    return answer.sql

def generate_plotly_code(question: str | None, sql: str | None, df: pd.DataFrame) -> str | None:
    """
    Ask the Vanna.AI API for Plotly code that visualizes a query result.

    Only the first rows of the dataframe (``df.head()``), rendered as a
    Markdown table, are included in the request.

    Args:
        question (str): The question the chart should answer.
        sql (str): The SQL query that produced the dataframe.
        df (pd.DataFrame): The dataframe to generate Plotly code for.

    Returns:
        str or None: The generated Plotly code, or None when the response
        contains no ``result`` entry.
    """
    preview = df.head().to_markdown()
    request = DataResult(
        question=question,
        sql=sql,
        table_markdown=preview,
        error=None,
        correction_attempts=0,
    )

    reply = __rpc_call(method="generate_plotly_code", params=[request])

    if 'result' in reply:
        # The result payload carries exactly the fields of PlotlyResult.
        return PlotlyResult(**reply['result']).plotly_code
    return None

def get_plotly_figure(plotly_code: str, df: pd.DataFrame, dark_mode: bool = True) -> plotly.graph_objs.Figure | None:
    """
    Execute Plotly code against a dataframe and return the figure it builds.

    The code runs with ``df`` (the dataframe), ``px`` (plotly.express) and
    ``go`` (plotly.graph_objects) available, on top of this module's globals.
    It is expected to assign its result to a variable named ``fig``.

    Args:
        plotly_code (str): The Plotly code to execute.
        df (pd.DataFrame): The dataframe exposed to the code as ``df``.
        dark_mode (bool): When True (the default), the ``plotly_dark``
            template is applied to the figure before it is returned.

    Returns:
        plotly.graph_objs.Figure or None: The object bound to ``fig`` by the
        code, or None if the code did not define ``fig``.
    """
    # SECURITY: exec() runs arbitrary Python with the privileges of this
    # process. Only pass code from a source you trust.
    #
    # A single namespace is used on purpose. When exec() receives separate
    # globals and locals dicts the code runs as if it were a class body, so
    # names such as `df` are not visible inside functions, lambdas or (before
    # Python 3.12) comprehensions defined by the executed code. Working on a
    # copy of the module globals also keeps the executed code from rebinding
    # this module's own names.
    namespace = dict(globals())
    # Only a `fig` created by the executed code may be returned.
    namespace.pop('fig', None)
    namespace.update(df=df, px=px, go=go)

    exec(plotly_code, namespace)

    fig = namespace.get('fig', None)

    if fig is None:
        return None

    if dark_mode:
        fig.update_layout(template="plotly_dark")

    return fig

def get_results(cs, default_database: str, sql: str) -> pd.DataFrame:
    """
    Run an SQL query on the given cursor and return the rows as a dataframe.

    :param cs: The Snowflake cursor to use.
    :type cs: snowflake.connector.cursor.SnowflakeCursor
    :param default_database: The database to switch to first (executed as "USE DATABASE {default_database}").
    :type default_database: str
    :param sql: The SQL query to run.
    :type sql: str
    :return: All fetched rows, with columns named after the first element of each cursor description entry.
    :rtype: pd.DataFrame
    """
    # Switch to the requested database before running the query itself.
    cs.execute(f"USE DATABASE {default_database}")

    result_cursor = cs.execute(sql)
    rows = result_cursor.fetchall()

    column_names = [column[0] for column in result_cursor.description]
    return pd.DataFrame(rows, columns=column_names)


def generate_explanation(sql: str) -> str | None:
    """
    Ask the Vanna.AI API to explain an SQL query.

    ## Example
    ```python
    vn.generate_explanation(sql="SELECT * FROM students WHERE name = 'John Doe'")
    # 'AI Response'
    ```

    :param sql: The SQL query to explain.
    :type sql: str
    :return: The explanation of the SQL query, or None when the response contains no ``result`` entry.
    :rtype: str or None
    """
    # Only the SQL is known here, so the remaining SQLAnswer fields are sent empty.
    query = SQLAnswer(raw_answer="", prefix="", postfix="", sql=sql)

    reply = __rpc_call(method="generate_explanation", params=[query])

    if 'result' in reply:
        return Explanation(**reply['result']).explanation
    return None
Loading

0 comments on commit 9f90a30

Please sign in to comment.