diff --git a/src/vanna/flask/__init__.py b/src/vanna/flask/__init__.py index 83c2f8f4..9d51268c 100644 --- a/src/vanna/flask/__init__.py +++ b/src/vanna/flask/__init__.py @@ -13,6 +13,7 @@ from .assets import css_content, html_content, js_content from .auth import AuthInterface, NoAuth +from ..base import VannaBase class Cache(ABC): @@ -88,7 +89,8 @@ def delete(self, id): if id in self.cache: del self.cache[id] -class VannaFlaskApp: + +class VannaFlaskAPI: flask_app = None def requires_cache(self, required_fields, optional_fields=[]): @@ -135,30 +137,17 @@ def decorated(*args, **kwargs): return decorated - def __init__(self, vn, cache: Cache = MemoryCache(), - auth: AuthInterface = NoAuth(), - debug=True, - allow_llm_to_see_data=False, - logo="https://img.vanna.ai/vanna-flask.svg", - title="Welcome to Vanna.AI", - subtitle="Your AI-powered copilot for SQL queries.", - show_training_data=True, - suggested_questions=True, - sql=True, - table=True, - csv_download=True, - chart=True, - redraw_chart=True, - auto_fix_sql=True, - ask_results_correct=True, - followup_questions=True, - summarization=True, - function_generation=True, - index_html_path=None, - assets_folder=None, - ): + def __init__( + self, + vn: VannaBase, + cache: Cache = MemoryCache(), + auth: AuthInterface = NoAuth(), + debug=True, + allow_llm_to_see_data=False, + chart=True, + ): """ - Expose a Flask app that can be used to interact with a Vanna instance. + Expose a Flask API that can be used to interact with a Vanna instance. Args: vn: The Vanna instance to interact with. @@ -166,52 +155,26 @@ def __init__(self, vn, cache: Cache = MemoryCache(), auth: The authentication method to use. Defaults to NoAuth, which doesn't require authentication. You can also pass in a custom authentication method that implements the AuthInterface interface. debug: Show the debug console. Defaults to True. allow_llm_to_see_data: Whether to allow the LLM to see data. Defaults to False. - logo: The logo to display in the UI. Defaults to the Vanna logo. - title: The title to display in the UI. Defaults to "Welcome to Vanna.AI". - subtitle: The subtitle to display in the UI. Defaults to "Your AI-powered copilot for SQL queries.". - show_training_data: Whether to show the training data in the UI. Defaults to True. - suggested_questions: Whether to show suggested questions in the UI. Defaults to True. - sql: Whether to show the SQL input in the UI. Defaults to True. - table: Whether to show the table output in the UI. Defaults to True. - csv_download: Whether to allow downloading the table output as a CSV file. Defaults to True. chart: Whether to show the chart output in the UI. Defaults to True. - redraw_chart: Whether to allow redrawing the chart. Defaults to True. - auto_fix_sql: Whether to allow auto-fixing SQL errors. Defaults to True. - ask_results_correct: Whether to ask the user if the results are correct. Defaults to True. - followup_questions: Whether to show followup questions. Defaults to True. - summarization: Whether to show summarization. Defaults to True. - index_html_path: Path to the index.html. Defaults to None, which will use the default index.html - assets_folder: The location where you'd like to serve the static assets from. Defaults to None, which will use hardcoded Python variables. Returns: None """ + self.flask_app = Flask(__name__) self.sock = Sock(self.flask_app) self.ws_clients = [] self.vn = vn - self.debug = debug self.auth = auth self.cache = cache + self.debug = debug self.allow_llm_to_see_data = allow_llm_to_see_data - self.logo = logo - self.title = title - self.subtitle = subtitle - self.show_training_data = show_training_data - self.suggested_questions = suggested_questions - self.sql = sql - self.table = table - self.csv_download = csv_download self.chart = chart - self.redraw_chart = redraw_chart - self.auto_fix_sql = auto_fix_sql - self.ask_results_correct = ask_results_correct - self.followup_questions = followup_questions - self.summarization = summarization - self.function_generation = function_generation and hasattr(vn, "get_function") - self.index_html_path = index_html_path - self.assets_folder = assets_folder - + self.config = { + "debug": debug, + "allow_llm_to_see_data": allow_llm_to_see_data, + "chart": chart, + } log = logging.getLogger("werkzeug") log.setLevel(logging.ERROR) @@ -225,42 +188,10 @@ def log(message, title="Info"): self.vn.log = log - @self.flask_app.route("/auth/login", methods=["POST"]) - def login(): - return self.auth.login_handler(flask.request) - - @self.flask_app.route("/auth/callback", methods=["GET"]) - def callback(): - return self.auth.callback_handler(flask.request) - - @self.flask_app.route("/auth/logout", methods=["GET"]) - def logout(): - return self.auth.logout_handler(flask.request) - @self.flask_app.route("/api/v0/get_config", methods=["GET"]) @self.requires_auth def get_config(user: any): - config = { - "debug": self.debug, - "logo": self.logo, - "title": self.title, - "subtitle": self.subtitle, - "show_training_data": self.show_training_data, - "suggested_questions": self.suggested_questions, - "sql": self.sql, - "table": self.table, - "csv_download": self.csv_download, - "chart": self.chart, - "redraw_chart": self.redraw_chart, - "auto_fix_sql": self.auto_fix_sql, - "ask_results_correct": self.ask_results_correct, - "followup_questions": self.followup_questions, - "summarization": self.summarization, - "function_generation": self.function_generation, - } - - config = self.auth.override_config_for_user(user, config) - + config = self.auth.override_config_for_user(user, self.config) return jsonify( { "type": "config", @@ -718,6 +649,136 @@ def catch_all(catch_all): {"type": "error", "error": "The rest of the API is not ported yet."} ) + if self.debug: + @self.sock.route("/api/v0/log") + def sock_log(ws): + self.ws_clients.append(ws) + + try: + while True: + message = ws.receive() # This example just reads and ignores to keep the socket open + finally: + self.ws_clients.remove(ws) + + def run(self, *args, **kwargs): + """ + Run the Flask app. + + Args: + *args: Arguments to pass to Flask's run method. + **kwargs: Keyword arguments to pass to Flask's run method. + + Returns: + None + """ + if args or kwargs: + self.flask_app.run(*args, **kwargs) + + else: + try: + from google.colab import output + + output.serve_kernel_port_as_window(8084) + from google.colab.output import eval_js + + print("Your app is running at:") + print(eval_js("google.colab.kernel.proxyPort(8084)")) + except: + print("Your app is running at:") + print("http://localhost:8084") + + self.flask_app.run(host="0.0.0.0", port=8084, debug=self.debug, use_reloader=False) + + +class VannaFlaskApp(VannaFlaskAPI): + def __init__( + self, + vn: VannaBase, + cache: Cache = MemoryCache(), + auth: AuthInterface = NoAuth(), + debug=True, + allow_llm_to_see_data=False, + logo="https://img.vanna.ai/vanna-flask.svg", + title="Welcome to Vanna.AI", + subtitle="Your AI-powered copilot for SQL queries.", + show_training_data=True, + suggested_questions=True, + sql=True, + table=True, + csv_download=True, + chart=True, + redraw_chart=True, + auto_fix_sql=True, + ask_results_correct=True, + followup_questions=True, + summarization=True, + function_generation=True, + index_html_path=None, + assets_folder=None, + ): + """ + Expose a Flask app that can be used to interact with a Vanna instance. + + Args: + vn: The Vanna instance to interact with. + cache: The cache to use. Defaults to MemoryCache, which uses an in-memory cache. You can also pass in a custom cache that implements the Cache interface. + auth: The authentication method to use. Defaults to NoAuth, which doesn't require authentication. You can also pass in a custom authentication method that implements the AuthInterface interface. + debug: Show the debug console. Defaults to True. + allow_llm_to_see_data: Whether to allow the LLM to see data. Defaults to False. + logo: The logo to display in the UI. Defaults to the Vanna logo. + title: The title to display in the UI. Defaults to "Welcome to Vanna.AI". + subtitle: The subtitle to display in the UI. Defaults to "Your AI-powered copilot for SQL queries.". + show_training_data: Whether to show the training data in the UI. Defaults to True. + suggested_questions: Whether to show suggested questions in the UI. Defaults to True. + sql: Whether to show the SQL input in the UI. Defaults to True. + table: Whether to show the table output in the UI. Defaults to True. + csv_download: Whether to allow downloading the table output as a CSV file. Defaults to True. + chart: Whether to show the chart output in the UI. Defaults to True. + redraw_chart: Whether to allow redrawing the chart. Defaults to True. + auto_fix_sql: Whether to allow auto-fixing SQL errors. Defaults to True. + ask_results_correct: Whether to ask the user if the results are correct. Defaults to True. + followup_questions: Whether to show followup questions. Defaults to True. + summarization: Whether to show summarization. Defaults to True. + index_html_path: Path to the index.html. Defaults to None, which will use the default index.html + assets_folder: The location where you'd like to serve the static assets from. Defaults to None, which will use hardcoded Python variables. + + Returns: + None + """ + super().__init__(vn, cache, auth, debug, allow_llm_to_see_data, chart) + + self.config["logo"] = logo + self.config["title"] = title + self.config["subtitle"] = subtitle + self.config["show_training_data"] = show_training_data + self.config["suggested_questions"] = suggested_questions + self.config["sql"] = sql + self.config["table"] = table + self.config["csv_download"] = csv_download + self.config["chart"] = chart + self.config["redraw_chart"] = redraw_chart + self.config["auto_fix_sql"] = auto_fix_sql + self.config["ask_results_correct"] = ask_results_correct + self.config["followup_questions"] = followup_questions + self.config["summarization"] = summarization + self.config["function_generation"] = function_generation + + self.index_html_path = index_html_path + self.assets_folder = assets_folder + + @self.flask_app.route("/auth/login", methods=["POST"]) + def login(): + return self.auth.login_handler(flask.request) + + @self.flask_app.route("/auth/callback", methods=["GET"]) + def callback(): + return self.auth.callback_handler(flask.request) + + @self.flask_app.route("/auth/logout", methods=["GET"]) + def logout(): + return self.auth.logout_handler(flask.request) + + @self.flask_app.route("/assets/") def proxy_assets(filename): if self.assets_folder: @@ -755,18 +816,6 @@ def proxy_vanna_svg(): else: return "Error fetching file from remote server", response.status_code - if self.debug: - @self.sock.route("/api/v0/log") - def sock_log(ws): - self.ws_clients.append(ws) - - try: - while True: - message = ws.receive() # This example just reads and ignores to keep the socket open - finally: - self.ws_clients.remove(ws) - - @self.flask_app.route("/", defaults={"path": ""}) @self.flask_app.route("/") def hello(path: str): @@ -775,32 +824,3 @@ def hello(path: str): filename = os.path.basename(self.index_html_path) return send_from_directory(directory=directory, path=filename) return html_content - - def run(self, *args, **kwargs): - """ - Run the Flask app. - - Args: - *args: Arguments to pass to Flask's run method. - **kwargs: Keyword arguments to pass to Flask's run method. - - Returns: - None - """ - if args or kwargs: - self.flask_app.run(*args, **kwargs) - - else: - try: - from google.colab import output - - output.serve_kernel_port_as_window(8084) - from google.colab.output import eval_js - - print("Your app is running at:") - print(eval_js("google.colab.kernel.proxyPort(8084)")) - except: - print("Your app is running at:") - print("http://localhost:8084") - - self.flask_app.run(host="0.0.0.0", port=8084, debug=self.debug, use_reloader=False)