Skip to content

Commit

Permalink
Merge pull request #508 from Molrn/split-flask-api-and-app
Browse files Browse the repository at this point in the history
Split VannaFlaskApp into an API and full app
  • Loading branch information
zainhoda authored Jul 25, 2024
2 parents de94875 + d253073 commit efd4cf5
Showing 1 changed file with 151 additions and 131 deletions.
282 changes: 151 additions & 131 deletions src/vanna/flask/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

from .assets import css_content, html_content, js_content
from .auth import AuthInterface, NoAuth
from ..base import VannaBase


class Cache(ABC):
Expand Down Expand Up @@ -88,7 +89,8 @@ def delete(self, id):
if id in self.cache:
del self.cache[id]

class VannaFlaskApp:

class VannaFlaskAPI:
flask_app = None

def requires_cache(self, required_fields, optional_fields=[]):
Expand Down Expand Up @@ -135,83 +137,44 @@ def decorated(*args, **kwargs):

return decorated

def __init__(self, vn, cache: Cache = MemoryCache(),
auth: AuthInterface = NoAuth(),
debug=True,
allow_llm_to_see_data=False,
logo="https://img.vanna.ai/vanna-flask.svg",
title="Welcome to Vanna.AI",
subtitle="Your AI-powered copilot for SQL queries.",
show_training_data=True,
suggested_questions=True,
sql=True,
table=True,
csv_download=True,
chart=True,
redraw_chart=True,
auto_fix_sql=True,
ask_results_correct=True,
followup_questions=True,
summarization=True,
function_generation=True,
index_html_path=None,
assets_folder=None,
):
def __init__(
self,
vn: VannaBase,
cache: Cache = MemoryCache(),
auth: AuthInterface = NoAuth(),
debug=True,
allow_llm_to_see_data=False,
chart=True,
):
"""
Expose a Flask app that can be used to interact with a Vanna instance.
Expose a Flask API that can be used to interact with a Vanna instance.
Args:
vn: The Vanna instance to interact with.
cache: The cache to use. Defaults to MemoryCache, which uses an in-memory cache. You can also pass in a custom cache that implements the Cache interface.
auth: The authentication method to use. Defaults to NoAuth, which doesn't require authentication. You can also pass in a custom authentication method that implements the AuthInterface interface.
debug: Show the debug console. Defaults to True.
allow_llm_to_see_data: Whether to allow the LLM to see data. Defaults to False.
logo: The logo to display in the UI. Defaults to the Vanna logo.
title: The title to display in the UI. Defaults to "Welcome to Vanna.AI".
subtitle: The subtitle to display in the UI. Defaults to "Your AI-powered copilot for SQL queries.".
show_training_data: Whether to show the training data in the UI. Defaults to True.
suggested_questions: Whether to show suggested questions in the UI. Defaults to True.
sql: Whether to show the SQL input in the UI. Defaults to True.
table: Whether to show the table output in the UI. Defaults to True.
csv_download: Whether to allow downloading the table output as a CSV file. Defaults to True.
chart: Whether to show the chart output in the UI. Defaults to True.
redraw_chart: Whether to allow redrawing the chart. Defaults to True.
auto_fix_sql: Whether to allow auto-fixing SQL errors. Defaults to True.
ask_results_correct: Whether to ask the user if the results are correct. Defaults to True.
followup_questions: Whether to show followup questions. Defaults to True.
summarization: Whether to show summarization. Defaults to True.
index_html_path: Path to the index.html. Defaults to None, which will use the default index.html
assets_folder: The location where you'd like to serve the static assets from. Defaults to None, which will use hardcoded Python variables.
Returns:
None
"""

self.flask_app = Flask(__name__)
self.sock = Sock(self.flask_app)
self.ws_clients = []
self.vn = vn
self.debug = debug
self.auth = auth
self.cache = cache
self.debug = debug
self.allow_llm_to_see_data = allow_llm_to_see_data
self.logo = logo
self.title = title
self.subtitle = subtitle
self.show_training_data = show_training_data
self.suggested_questions = suggested_questions
self.sql = sql
self.table = table
self.csv_download = csv_download
self.chart = chart
self.redraw_chart = redraw_chart
self.auto_fix_sql = auto_fix_sql
self.ask_results_correct = ask_results_correct
self.followup_questions = followup_questions
self.summarization = summarization
self.function_generation = function_generation and hasattr(vn, "get_function")
self.index_html_path = index_html_path
self.assets_folder = assets_folder

self.config = {
"debug": debug,
"allow_llm_to_see_data": allow_llm_to_see_data,
"chart": chart,
}
log = logging.getLogger("werkzeug")
log.setLevel(logging.ERROR)

Expand All @@ -225,42 +188,10 @@ def log(message, title="Info"):

self.vn.log = log

@self.flask_app.route("/auth/login", methods=["POST"])
def login():
return self.auth.login_handler(flask.request)

@self.flask_app.route("/auth/callback", methods=["GET"])
def callback():
return self.auth.callback_handler(flask.request)

@self.flask_app.route("/auth/logout", methods=["GET"])
def logout():
return self.auth.logout_handler(flask.request)

@self.flask_app.route("/api/v0/get_config", methods=["GET"])
@self.requires_auth
def get_config(user: any):
config = {
"debug": self.debug,
"logo": self.logo,
"title": self.title,
"subtitle": self.subtitle,
"show_training_data": self.show_training_data,
"suggested_questions": self.suggested_questions,
"sql": self.sql,
"table": self.table,
"csv_download": self.csv_download,
"chart": self.chart,
"redraw_chart": self.redraw_chart,
"auto_fix_sql": self.auto_fix_sql,
"ask_results_correct": self.ask_results_correct,
"followup_questions": self.followup_questions,
"summarization": self.summarization,
"function_generation": self.function_generation,
}

config = self.auth.override_config_for_user(user, config)

config = self.auth.override_config_for_user(user, self.config)
return jsonify(
{
"type": "config",
Expand Down Expand Up @@ -718,6 +649,136 @@ def catch_all(catch_all):
{"type": "error", "error": "The rest of the API is not ported yet."}
)

if self.debug:
@self.sock.route("/api/v0/log")
def sock_log(ws):
self.ws_clients.append(ws)

try:
while True:
message = ws.receive() # This example just reads and ignores to keep the socket open
finally:
self.ws_clients.remove(ws)

def run(self, *args, **kwargs):
"""
Run the Flask app.
Args:
*args: Arguments to pass to Flask's run method.
**kwargs: Keyword arguments to pass to Flask's run method.
Returns:
None
"""
if args or kwargs:
self.flask_app.run(*args, **kwargs)

else:
try:
from google.colab import output

output.serve_kernel_port_as_window(8084)
from google.colab.output import eval_js

print("Your app is running at:")
print(eval_js("google.colab.kernel.proxyPort(8084)"))
except:
print("Your app is running at:")
print("http://localhost:8084")

self.flask_app.run(host="0.0.0.0", port=8084, debug=self.debug, use_reloader=False)


class VannaFlaskApp(VannaFlaskAPI):
def __init__(
self,
vn: VannaBase,
cache: Cache = MemoryCache(),
auth: AuthInterface = NoAuth(),
debug=True,
allow_llm_to_see_data=False,
logo="https://img.vanna.ai/vanna-flask.svg",
title="Welcome to Vanna.AI",
subtitle="Your AI-powered copilot for SQL queries.",
show_training_data=True,
suggested_questions=True,
sql=True,
table=True,
csv_download=True,
chart=True,
redraw_chart=True,
auto_fix_sql=True,
ask_results_correct=True,
followup_questions=True,
summarization=True,
function_generation=True,
index_html_path=None,
assets_folder=None,
):
"""
Expose a Flask app that can be used to interact with a Vanna instance.
Args:
vn: The Vanna instance to interact with.
cache: The cache to use. Defaults to MemoryCache, which uses an in-memory cache. You can also pass in a custom cache that implements the Cache interface.
auth: The authentication method to use. Defaults to NoAuth, which doesn't require authentication. You can also pass in a custom authentication method that implements the AuthInterface interface.
debug: Show the debug console. Defaults to True.
allow_llm_to_see_data: Whether to allow the LLM to see data. Defaults to False.
logo: The logo to display in the UI. Defaults to the Vanna logo.
title: The title to display in the UI. Defaults to "Welcome to Vanna.AI".
subtitle: The subtitle to display in the UI. Defaults to "Your AI-powered copilot for SQL queries.".
show_training_data: Whether to show the training data in the UI. Defaults to True.
suggested_questions: Whether to show suggested questions in the UI. Defaults to True.
sql: Whether to show the SQL input in the UI. Defaults to True.
table: Whether to show the table output in the UI. Defaults to True.
csv_download: Whether to allow downloading the table output as a CSV file. Defaults to True.
chart: Whether to show the chart output in the UI. Defaults to True.
redraw_chart: Whether to allow redrawing the chart. Defaults to True.
auto_fix_sql: Whether to allow auto-fixing SQL errors. Defaults to True.
ask_results_correct: Whether to ask the user if the results are correct. Defaults to True.
followup_questions: Whether to show followup questions. Defaults to True.
summarization: Whether to show summarization. Defaults to True.
index_html_path: Path to the index.html. Defaults to None, which will use the default index.html
assets_folder: The location where you'd like to serve the static assets from. Defaults to None, which will use hardcoded Python variables.
Returns:
None
"""
super().__init__(vn, cache, auth, debug, allow_llm_to_see_data, chart)

self.config["logo"] = logo
self.config["title"] = title
self.config["subtitle"] = subtitle
self.config["show_training_data"] = show_training_data
self.config["suggested_questions"] = suggested_questions
self.config["sql"] = sql
self.config["table"] = table
self.config["csv_download"] = csv_download
self.config["chart"] = chart
self.config["redraw_chart"] = redraw_chart
self.config["auto_fix_sql"] = auto_fix_sql
self.config["ask_results_correct"] = ask_results_correct
self.config["followup_questions"] = followup_questions
self.config["summarization"] = summarization
self.config["function_generation"] = function_generation

self.index_html_path = index_html_path
self.assets_folder = assets_folder

@self.flask_app.route("/auth/login", methods=["POST"])
def login():
return self.auth.login_handler(flask.request)

@self.flask_app.route("/auth/callback", methods=["GET"])
def callback():
return self.auth.callback_handler(flask.request)

@self.flask_app.route("/auth/logout", methods=["GET"])
def logout():
return self.auth.logout_handler(flask.request)


@self.flask_app.route("/assets/<path:filename>")
def proxy_assets(filename):
if self.assets_folder:
Expand Down Expand Up @@ -755,18 +816,6 @@ def proxy_vanna_svg():
else:
return "Error fetching file from remote server", response.status_code

if self.debug:
@self.sock.route("/api/v0/log")
def sock_log(ws):
self.ws_clients.append(ws)

try:
while True:
message = ws.receive() # This example just reads and ignores to keep the socket open
finally:
self.ws_clients.remove(ws)


@self.flask_app.route("/", defaults={"path": ""})
@self.flask_app.route("/<path:path>")
def hello(path: str):
Expand All @@ -775,32 +824,3 @@ def hello(path: str):
filename = os.path.basename(self.index_html_path)
return send_from_directory(directory=directory, path=filename)
return html_content

def run(self, *args, **kwargs):
"""
Run the Flask app.
Args:
*args: Arguments to pass to Flask's run method.
**kwargs: Keyword arguments to pass to Flask's run method.
Returns:
None
"""
if args or kwargs:
self.flask_app.run(*args, **kwargs)

else:
try:
from google.colab import output

output.serve_kernel_port_as_window(8084)
from google.colab.output import eval_js

print("Your app is running at:")
print(eval_js("google.colab.kernel.proxyPort(8084)"))
except:
print("Your app is running at:")
print("http://localhost:8084")

self.flask_app.run(host="0.0.0.0", port=8084, debug=self.debug, use_reloader=False)

0 comments on commit efd4cf5

Please sign in to comment.