Skip to content

Commit

Permalink
Prompt format adjustment #1
Browse files Browse the repository at this point in the history
  • Loading branch information
pramitchoudhary committed May 15, 2023
1 parent f8f8dc5 commit f9deda3
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 13 deletions.
24 changes: 16 additions & 8 deletions sidekick/configs/prompt_template.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,16 @@
# Prompt templates for the task-generation step of the text-to-SQL pipeline
# (consumed by _query_tasks in sidekick/query.py via TASK_PROMPT["user_prompt"].format(...)).
#
# NOTE(review): this span is a rendered diff, not clean source — the
# triple-quoted user_prompt below contains BOTH the old prompt text
# (placeholders {table_name}, {samples}, {question_str}) and the new prompt
# text (placeholders {_table_name}, {_data_info}, {_sample_queries},
# {_context}, {_question_str}). Only the underscore-prefixed placeholders
# match the keyword arguments passed by the .format(...) call shown in the
# query.py hunk; the old-style lines should be removed in the real file —
# TODO confirm against the repository.
TASK_PROMPT = {
# Role instruction sent as the system message to the LLM.
"system_prompt": "Act as a Data Analyst",
# User message template; filled in with table name, sample data/queries,
# gathered context, and the natural-language question.
"user_prompt": """
### For Table: {table_name} Given an input (question), only return specific tasks as an ordered itemized list for SQL generation that answer the question. Don't generate SQL code.
Infer the return type of the question. Add a task to return final output.
# Example data: \n {samples}
# (question):\n {question_str}
# Output format: Tasks: list of tasks
### For Table: {_table_name} Given an input *Question*, only return specific and informative tasks as an ordered itemized list for SQL generation that answer the question.
Extract all of the proper nouns (generally capitalized, abbreviated) from the Samples section and add to Context section as Key, Value pair.
Use the Context section and Samples section to establish relationship when tokens from Question does not match column names.
If information is not found in Context or Samples section, attempt to reason for possible tasks but also ask questions for.
Infer the return type of the Question. Don't generate SQL code.
# Data information: \n{_data_info}
# Samples: \n{_sample_queries}
# Context: {_context}
# *Question*: {_question_str};
# Output: Tasks: ordered list of tasks
""",
}

Expand All @@ -16,14 +21,17 @@
# Reference: https://arxiv.org/pdf/2005.14165.pdf
QUERY_PROMPT = """
### System: Act as a SQL Expert
# Given an input (question), only generate syntactically correct SQL queries
# Given an input *Question*, only generate syntactically correct SQL queries using step by step reasoning from Tasks section.
# Extract all of the proper nouns (generally capitalized, abbreviated) from the Examples section and add to Context section as Key, Value pair.
# Use the context section to establish relationship when tokens from Question does not match column names.
# Pick the SQL query which has the highest average log probability of explaining the
candidate question
candidate question.
### {dialect} SQL tables
Examples:\n{_sample_queries}
### *question*:\n{_question};
### *Question*: {_question};
# SELECT 1
### Tasks:\n{_tasks}
### Context: {_context}
### Suggestions:
# Don't use aggregate and window function together;
# Avoid COUNT(*) and prefer COUNT(1);
Expand Down
14 changes: 9 additions & 5 deletions sidekick/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@
from configs.prompt_template import DEBUGGING_PROMPT, QUERY_PROMPT, TASK_PROMPT
from examples.sample_data import sample_values, samples_queries
from langchain import OpenAI
from llama_index import GPTSimpleVectorIndex, GPTSQLStructStoreIndex, LLMPredictor, ServiceContext, SQLDatabase
from llama_index import (GPTSimpleVectorIndex, GPTSQLStructStoreIndex,
LLMPredictor, ServiceContext, SQLDatabase)
from llama_index.indices.struct_store import SQLContextContainerBuilder
from loguru import logger
from sqlalchemy import create_engine
Expand Down Expand Up @@ -35,14 +36,18 @@ def build_index(self, persist: bool = True):
table_schema_index.save_to_disk(f"{self.path}/sql_index_check.json")
return table_schema_index

def _query_tasks(self, question_str, context_info: dict, samples, table_name: str):
def _query_tasks(self, question_str, context_info: dict, data_info, sample_queries, table_name: str):
keys = [table_name]
# TODO: Throw error if context_info is not a dict.
schema_info = list(map(context_info.get, keys))
try:
system_prompt = TASK_PROMPT["system_prompt"]
user_prompt = TASK_PROMPT["user_prompt"].format(
table_name=table_name, schema_info=schema_info, samples=samples, question_str=question_str
_table_name=table_name,
_schema_info=schema_info,
_data_info=data_info,
_sample_queries=sample_queries,
_question_str=question_str,
)

# Role and content
Expand Down Expand Up @@ -98,8 +103,7 @@ def generate_tasks(self, table_name: str, input_question: str, path: str = "./va
context_builder = SQLContextContainerBuilder(self.sql_database)

c_info = context_builder.full_context_dict
_sample_values = sample_values
task_list = self._query_tasks(input_question, c_info, _sample_values, table_name)
task_list = self._query_tasks(input_question, c_info, sample_values, samples_queries, table_name)
with open(f"{path}/tasks.txt", "w") as f:
f.write(task_list)
return task_list
Expand Down

0 comments on commit f9deda3

Please sign in to comment.