From 7b4abc5328d0d639252e64dc48891b65aa692575 Mon Sep 17 00:00:00 2001 From: pramitchoudhary Date: Tue, 16 May 2023 15:54:32 -0700 Subject: [PATCH] Cache table name for current session #1 --- sidekick/prompter.py | 20 ++++++++++++++++---- sidekick/query.py | 10 +++++----- 2 files changed, 21 insertions(+), 9 deletions(-) diff --git a/sidekick/prompter.py b/sidekick/prompter.py index 4545e7e..315c0bf 100644 --- a/sidekick/prompter.py +++ b/sidekick/prompter.py @@ -98,10 +98,22 @@ def learn(context: str): @cli.command() -@click.option("--table_info", "-t", help="Table info", prompt="Which table to use?") @click.option("--question", "-q", help="Database name", prompt="Ask a question") -def query(table_info: str, question: str): +def query(question: str): """Asks question and returns SQL.""" + + # Check if table exists + path = f"{base_path}/var/lib/tmp/data/" + table_context_file = f"{path}/table_context.json" + table_context = json.load(open(table_context_file, "r")) if Path(table_context_file).exists() else {} + if table_context: + table_name = table_context.get("tables_in_use", None) + else: + table_name = [click.prompt("Which table to use?")] + table_context["tables_in_use"] = table_name + with open(f"{path}/table_context.json", "w") as outfile: + json.dump(table_context, outfile, indent=4, sort_keys=False) + logger.info(f"Table in use: {table_name}") # Check if .env.toml file exists api_key = env_settings["OPENAI"]["OPENAI_API_KEY"] if api_key is None or api_key == "": @@ -133,7 +145,7 @@ def query(table_info: str, question: str): ) sql_g = SQLGenerator(db_url, api_key) - sql_g._tasks = sql_g.generate_tasks(table_info, question) + sql_g._tasks = sql_g.generate_tasks(table_name, question) click.echo(sql_g._tasks) updated_tasks = None @@ -146,7 +158,7 @@ def query(table_info: str, question: str): click.echo("Skipping edit...") if updated_tasks is not None: sql_g._tasks = updated_tasks - res = sql_g.generate_sql(table_info, question) + res = sql_g.generate_sql(table_name, question) logger.info(f"Generated response:\n\n{res}") if res is not None: diff --git a/sidekick/query.py b/sidekick/query.py index dec7cf6..9ae492a 100644 --- a/sidekick/query.py +++ b/sidekick/query.py @@ -38,10 +38,11 @@ def build_index(self, persist: bool = True): table_schema_index.save_to_disk(f"{self.path}/sql_index_check.json") return table_schema_index - def _query_tasks(self, question_str, context_info: dict, data_info, sample_queries, table_name: str): - keys = [table_name] + def _query_tasks(self, question_str, context_info: dict, data_info, sample_queries, table_name: list): + keys = table_name # TODO: Throw error if context_info is not a dict. schema_info = list(map(context_info.get, keys)) + try: context_file = f"{self.path}/context.json" additional_context = json.load(open(context_file, "r")) if Path(context_file).exists() else {} @@ -98,8 +99,8 @@ def generate_response(self, context_container, sql_index, input_prompt, attempt_ res = qry_txt return res except Exception as se: - _, ex_value, ex_traceback = sys.exc_info() - res = ex_value.statement + # Another exception occurred, return the original SQL + res = qry_txt return res def generate_tasks(self, table_name: str, input_question: str, path: str = "./var/lib/tmp/data"): @@ -131,7 +132,6 @@ def generate_sql( _tasks=_tasks.lower(), ) - logger.info(f"Prompt:\n {query_str}") table_schema_index = self.build_index(persist=False) self.context_builder.query_index_for_context(table_schema_index, query_str, store_context_str=True) context_container = self.context_builder.build_context_container()