Skip to content

Commit

Permalink
Cache table name for current session #1
Browse files Browse the repository at this point in the history
  • Loading branch information
pramitchoudhary committed May 16, 2023
1 parent e150cad commit 7b4abc5
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 9 deletions.
20 changes: 16 additions & 4 deletions sidekick/prompter.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,10 +98,22 @@ def learn(context: str):


@cli.command()
@click.option("--table_info", "-t", help="Table info", prompt="Which table to use?")
@click.option("--question", "-q", help="Database name", prompt="Ask a question")
def query(table_info: str, question: str):
def query(question: str):
"""Asks question and returns SQL."""

# Check if table exists
path = f"{base_path}/var/lib/tmp/data/"
table_context_file = f"{path}/table_context.json"
table_context = json.load(open(table_context_file, "r")) if Path(table_context_file).exists() else {}
if table_context:
table_name = table_context.get("tables_in_use", None)
else:
table_name = [click.prompt("Which table to use?")]
table_context["tables_in_use"] = table_name
with open(f"{path}/table_context.json", "w") as outfile:
json.dump(table_context, outfile, indent=4, sort_keys=False)
logger.info(f"Table in use: {table_name}")
# Check if .env.toml file exists
api_key = env_settings["OPENAI"]["OPENAI_API_KEY"]
if api_key is None or api_key == "":
Expand Down Expand Up @@ -133,7 +145,7 @@ def query(table_info: str, question: str):
)

sql_g = SQLGenerator(db_url, api_key)
sql_g._tasks = sql_g.generate_tasks(table_info, question)
sql_g._tasks = sql_g.generate_tasks(table_name, question)
click.echo(sql_g._tasks)

updated_tasks = None
Expand All @@ -146,7 +158,7 @@ def query(table_info: str, question: str):
click.echo("Skipping edit...")
if updated_tasks is not None:
sql_g._tasks = updated_tasks
res = sql_g.generate_sql(table_info, question)
res = sql_g.generate_sql(table_name, question)
logger.info(f"Generated response:\n\n{res}")

if res is not None:
Expand Down
10 changes: 5 additions & 5 deletions sidekick/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,10 +38,11 @@ def build_index(self, persist: bool = True):
table_schema_index.save_to_disk(f"{self.path}/sql_index_check.json")
return table_schema_index

def _query_tasks(self, question_str, context_info: dict, data_info, sample_queries, table_name: str):
keys = [table_name]
def _query_tasks(self, question_str, context_info: dict, data_info, sample_queries, table_name: list):
keys = table_name
# TODO: Throw error if context_info is not a dict.
schema_info = list(map(context_info.get, keys))

try:
context_file = f"{self.path}/context.json"
additional_context = json.load(open(context_file, "r")) if Path(context_file).exists() else {}
Expand Down Expand Up @@ -98,8 +99,8 @@ def generate_response(self, context_container, sql_index, input_prompt, attempt_
res = qry_txt
return res
except Exception as se:
_, ex_value, ex_traceback = sys.exc_info()
res = ex_value.statement
# Another exception occurred, return the original SQL
res = qry_txt
return res

def generate_tasks(self, table_name: str, input_question: str, path: str = "./var/lib/tmp/data"):
Expand Down Expand Up @@ -131,7 +132,6 @@ def generate_sql(
_tasks=_tasks.lower(),
)

logger.info(f"Prompt:\n {query_str}")
table_schema_index = self.build_index(persist=False)
self.context_builder.query_index_for_context(table_schema_index, query_str, store_context_str=True)
context_container = self.context_builder.build_context_container()
Expand Down

0 comments on commit 7b4abc5

Please sign in to comment.