Skip to content

Commit

Permalink
postgres training plan
Browse files Browse the repository at this point in the history
  • Loading branch information
zainhoda committed Aug 3, 2023
1 parent a670ff5 commit fef7ff2
Showing 1 changed file with 38 additions and 0 deletions.
38 changes: 38 additions & 0 deletions src/vanna/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -602,6 +602,43 @@ def __get_information_schema_tables(database: str) -> pd.DataFrame:

return df_tables

def get_training_plan_postgres(filter_databases: Union[List[str], None] = None, filter_schemas: Union[List[str], None] = None, include_information_schema: bool = False, use_historical_queries: bool = True) -> TrainingPlan:
    """
    Build a training plan from the column metadata of the connected Postgres database.

    Runs ``select * from INFORMATION_SCHEMA.COLUMNS`` through the module-level
    ``run_sql`` and adds one plan item per table. Each item's value is a short
    sentence followed by a markdown table of that table's columns.

    Args:
        filter_databases (Union[List[str], None]): If not None, only databases
            (``table_catalog`` values) contained in this list are included.
        filter_schemas (Union[List[str], None]): If not None, only schemas
            (``table_schema`` values) contained in this list are included.
        include_information_schema (bool): When False, the ``information_schema``
            and ``pg_catalog`` schemas are skipped.
        use_historical_queries (bool): Not used by this function; kept so the
            signature matches the other training-plan builders.

    Returns:
        TrainingPlan: The plan, with items grouped as ``"<database>.<schema>"``
        and named after the table.

    Raises:
        ValidationError: If ``run_sql`` has not been set (no database connected).
    """
    plan = TrainingPlan([])

    if run_sql is None:
        raise ValidationError("Please connect to a database first.")

    df_columns = run_sql("select * from INFORMATION_SCHEMA.COLUMNS")

    # Schemas that hold Postgres' own metadata rather than user tables.
    system_schemas = ("information_schema", "pg_catalog")
    doc_columns = ["table_catalog", "table_schema", "table_name", "column_name", "data_type"]

    for database in df_columns['table_catalog'].unique().tolist():
        if filter_databases is not None and database not in filter_databases:
            continue

        # Boolean masks compare names as plain values. Building a
        # DataFrame.query() expression from the name breaks as soon as an
        # identifier contains a double quote or backslash.
        df_database = df_columns[df_columns['table_catalog'] == database]

        for schema in df_database['table_schema'].unique().tolist():
            if filter_schemas is not None and schema not in filter_schemas:
                continue

            if not include_information_schema and schema in system_schemas:
                continue

            df_schema = df_database[df_database['table_schema'] == schema]

            for table in df_schema['table_name'].unique().tolist():
                df_table = df_schema[df_schema['table_name'] == table]

                doc = f"The following columns are in the {table} table in the {database} database:\n\n"
                doc += df_table[doc_columns].to_markdown()

                # TrainingPlan exposes no public "add" in the visible code, so
                # items are appended to its internal list directly.
                plan._plan.append(TrainingPlanItem(
                    item_type=TrainingPlanItem.ITEM_TYPE_IS,
                    item_group=f"{database}.{schema}",
                    item_name=table,
                    item_value=doc
                ))

    return plan


def get_training_plan_experimental(filter_databases: Union[List[str], None] = None, filter_schemas: Union[List[str], None] = None, include_information_schema: bool = False, use_historical_queries: bool = True) -> TrainingPlan:
"""
Expand Down Expand Up @@ -1632,6 +1669,7 @@ def run_sql_postgres(sql: str) -> Union[pd.DataFrame, None]:
return df

except psycopg2.Error as e:
conn.rollback()
raise ValidationError(e)

global run_sql
Expand Down

0 comments on commit fef7ff2

Please sign in to comment.