From 8ac94be9c19be7f4f1c6c76cc206ad59f1f59e72 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=B5=B5=E8=8B=B1=E8=B6=85?= Date: Tue, 7 May 2024 09:46:59 +0800 Subject: [PATCH] feat:add hive and kyuubi to be supportted --- src/vanna/base/base.py | 96 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 96 insertions(+) diff --git a/src/vanna/base/base.py b/src/vanna/base/base.py index 53e1f675..00017043 100644 --- a/src/vanna/base/base.py +++ b/src/vanna/base/base.py @@ -1393,6 +1393,102 @@ def run_sql_presto(sql: str) -> Union[pd.DataFrame, None]: self.run_sql_is_set = True self.run_sql = run_sql_presto + def connect_to_hive( + self, + host: str = None, + dbname: str = 'default', + user: str = None, + password: str = None, + port: int = None, + auth: str = 'CUSTOM' + ): + """ + Connect to a Hive database. This is just a helper function to set [`vn.run_sql`][vanna.base.base.VannaBase.run_sql] + Connect to a Hive database. This is just a helper function to set [`vn.run_sql`][vanna.base.base.VannaBase.run_sql] + + Args: + host (str): The host of the Hive database. + dbname (str): The name of the database to connect to. + user (str): The username to use for authentication. + password (str): The password to use for authentication. + port (int): The port to use for the connection. + auth (str): The authentication method to use. + + Returns: + None + """ + + try: + from pyhive import hive + except ImportError: + raise DependencyError( + "You need to install required dependencies to execute this method," + " run command: \npip install pyhive" + ) + + if not host: + host = os.getenv("HIVE_HOST") + + if not host: + raise ImproperlyConfigured("Please set your hive host") + + if not dbname: + dbname = os.getenv("HIVE_DATABASE") + + if not dbname: + raise ImproperlyConfigured("Please set your hive database") + + if not user: + user = os.getenv("HIVE_USER") + + if not user: + raise ImproperlyConfigured("Please set your hive user") + + if not password: + password = os.getenv("HIVE_PASSWORD") + + if not port: + port = os.getenv("HIVE_PORT") + + if not port: + raise ImproperlyConfigured("Please set your hive port") + + conn = None + + try: + conn = hive.Connection(host=host, + username=user, + password=password, + database=dbname, + port=port, + auth=auth) + except hive.Error as e: + raise ValidationError(e) + + def run_sql_hive(sql: str) -> Union[pd.DataFrame, None]: + if conn: + try: + cs = conn.cursor() + cs.execute(sql) + results = cs.fetchall() + + # Create a pandas dataframe from the results + df = pd.DataFrame( + results, columns=[desc[0] for desc in cs.description] + ) + return df + + except hive.Error as e: + print(e) + raise ValidationError(e) + + except Exception as e: + print(e) + raise e + + self.run_sql_is_set = True + self.run_sql = run_sql_hive + def run_sql(self, sql: str, **kwargs) -> pd.DataFrame: """ Example: