From fef7ff2bae2860d92565cc573162e31ea062d1f2 Mon Sep 17 00:00:00 2001 From: Zain Hoda <7146154+zainhoda@users.noreply.github.com> Date: Thu, 3 Aug 2023 13:02:36 -0400 Subject: [PATCH] postgres training plan --- src/vanna/__init__.py | 38 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 38 insertions(+) diff --git a/src/vanna/__init__.py b/src/vanna/__init__.py index ae3c06ab..34b15e5c 100644 --- a/src/vanna/__init__.py +++ b/src/vanna/__init__.py @@ -602,6 +602,43 @@ def __get_information_schema_tables(database: str) -> pd.DataFrame: return df_tables +def get_training_plan_postgres(filter_databases: Union[List[str], None] = None, filter_schemas: Union[List[str], None] = None, include_information_schema: bool = False, use_historical_queries: bool = True) -> TrainingPlan: + plan = TrainingPlan([]) + + if run_sql is None: + raise ValidationError("Please connect to a database first.") + + df_columns = run_sql("select * from INFORMATION_SCHEMA.COLUMNS") + + databases = df_columns['table_catalog'].unique().tolist() + + for database in databases: + if filter_databases is not None and database not in filter_databases: + continue + + for schema in df_columns.query(f'table_catalog == "{database}"')['table_schema'].unique().tolist(): + if filter_schemas is not None and schema not in filter_schemas: + continue + + if not include_information_schema and (schema == "information_schema" or schema == "pg_catalog"): + continue + + df_columns_filtered = df_columns.query(f'table_catalog == "{database}" and table_schema == "{schema}"') + + for table in df_columns_filtered['table_name'].unique().tolist(): + df_columns_filtered_to_table = df_columns_filtered.query(f'table_name == "{table}"') + doc = f"The following columns are in the {table} table in the {database} database:\n\n" + doc += df_columns_filtered_to_table[["table_catalog", "table_schema", "table_name", "column_name", "data_type"]].to_markdown() + + plan._plan.append(TrainingPlanItem( + item_type=TrainingPlanItem.ITEM_TYPE_IS, + item_group=f"{database}.{schema}", + item_name=table, + item_value=doc + )) + + return plan + def get_training_plan_experimental(filter_databases: Union[List[str], None] = None, filter_schemas: Union[List[str], None] = None, include_information_schema: bool = False, use_historical_queries: bool = True) -> TrainingPlan: """ @@ -1632,6 +1669,7 @@ def run_sql_postgres(sql: str) -> Union[pd.DataFrame, None]: return df except psycopg2.Error as e: + conn.rollback() raise ValidationError(e) global run_sql