Skip to content

Commit

Permalink
remove unused code
Browse files Browse the repository at this point in the history
  • Loading branch information
cyyeh committed Feb 7, 2025
1 parent bab3392 commit 0f33dbb
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 232 deletions.
16 changes: 0 additions & 16 deletions wren-ai-service/demo/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,18 +52,6 @@
st.session_state["preview_sql"] = None
if "query_history" not in st.session_state:
st.session_state["query_history"] = None
if "sql_explanation_question" not in st.session_state:
st.session_state["sql_explanation_question"] = None
if "sql_explanation_steps_with_analysis" not in st.session_state:
st.session_state["sql_explanation_steps_with_analysis"] = None
if "sql_analysis_results" not in st.session_state:
st.session_state["sql_analysis_results"] = None
if "sql_explanation_results" not in st.session_state:
st.session_state["sql_explanation_results"] = None
if "sql_user_corrections_by_step" not in st.session_state:
st.session_state["sql_user_corrections_by_step"] = []
if "sql_regeneration_results" not in st.session_state:
st.session_state["sql_regeneration_results"] = None
if "language" not in st.session_state:
st.session_state["language"] = "English"
if "timezone" not in st.session_state:
Expand Down Expand Up @@ -242,10 +230,6 @@ def onchange_timezone():
ask_details_result = ask_details()
if ask_details_response := ask_details_result.get("response"):
st.session_state["asks_details_result"] = ask_details_response
st.session_state["sql_explanation_question"] = None
st.session_state["sql_explanation_steps_with_analysis"] = None
st.session_state["sql_analysis_results"] = None
st.session_state["sql_explanation_results"] = None
else:
st.error(
f'An error occurred while processing the query: {ask_details_result.get("error")}',
Expand Down
223 changes: 7 additions & 216 deletions wren-ai-service/demo/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,15 +181,8 @@ def on_change_sql_generation_reasoning():


def on_click_regenerate_sql():
if not st.session_state["sql_generation_reasoning"]:
sql_generation_reasoning = st.session_state["asks_results"][
"sql_generation_reasoning"
]
else:
sql_generation_reasoning = st.session_state["sql_generation_reasoning"]

ask_feedback(
sql_generation_reasoning,
st.session_state["sql_generation_reasoning"],
st.session_state["asks_results"]["response"][0]["sql"],
)

Expand Down Expand Up @@ -225,9 +218,9 @@ def show_asks_results():
st.markdown("### SQL Generation Reasoning")
st.text_area(
"SQL Generation Reasoning",
st.session_state["asks_results"]["sql_generation_reasoning"],
st.session_state["sql_generation_reasoning"],
key="sql_generation_reasoning_input",
height=300,
height=250,
on_change=on_change_sql_generation_reasoning,
)

Expand Down Expand Up @@ -298,49 +291,6 @@ def on_click_preview_data_button(index: int, full_sqls: List[str]):
st.session_state["preview_sql"] = full_sqls[index]


def get_sql_analysis_results(sqls: List[str], manifest: Dict):
results = []
for sql in sqls:
response = requests.get(
f"{WREN_ENGINE_API_URL}/v1/analysis/sql",
json={
"sql": sql,
"manifest": manifest,
},
)

assert response.status_code == 200, response.json()

results.append(response.json())

return results


def on_click_sql_explanation_button(
question: str,
sqls: List[str],
summaries: List[str],
manifest: Dict,
):
sql_analysis_results = get_sql_analysis_results(sqls, manifest)

st.session_state["sql_explanation_question"] = question
st.session_state["sql_analysis_results"] = sql_analysis_results
st.session_state["sql_explanation_steps_with_analysis"] = [
{"sql": sql, "summary": summary, "sql_analysis_results": sql_analysis_results}
for sql, summary, sql_analysis_results in zip(
sqls, summaries, sql_analysis_results
)
]

sql_explanation_results = sql_explanation()
st.session_state["sql_explanation_results"] = sql_explanation_results
if sql_explanation_results:
st.session_state["sql_user_corrections_by_step"] = [
[] for _ in range(len(sql_explanation_results))
]


def on_change_user_correction(
step_idx: int, explanation_index: int, explanation_result: dict
):
Expand Down Expand Up @@ -414,25 +364,6 @@ def _get_decision_point(explanation_result: dict):
)


def on_click_sql_regeneration_button(
ask_details_results: dict,
sql_user_corrections_by_step: List[List[dict]],
):
sql_regeneration_data = copy.deepcopy(ask_details_results)
for i, (_, sql_user_corrections) in enumerate(
zip(sql_regeneration_data["steps"], sql_user_corrections_by_step)
):
if sql_user_corrections:
sql_regeneration_data["steps"][i]["corrections"] = sql_user_corrections
else:
sql_regeneration_data["steps"][i]["corrections"] = []

st.session_state["sql_regeneration_results"] = sql_regeneration(
sql_regeneration_data
)
show_sql_regeneration_results_dialog(sql_user_corrections_by_step)


def on_click_adjust_chart(
query: str,
sql: str,
Expand Down Expand Up @@ -581,6 +512,7 @@ def prepare_semantics(mdl_json: dict):
st.session_state["preview_data_button_index"] = None
st.session_state["preview_sql"] = None
st.session_state["query_history"] = None
st.session_state["sql_generation_reasoning"] = None

if st.session_state["semantics_preparation_status"] == "failed":
st.toast("An error occurred while preparing the semantics", icon="🚨")
Expand Down Expand Up @@ -630,6 +562,9 @@ def ask(query: str, timezone: str, query_history: Optional[dict] = None):
display_general_response(query_id)
elif asks_type == "TEXT_TO_SQL":
st.session_state["asks_results"] = asks_status_response.json()
st.session_state["sql_generation_reasoning"] = st.session_state[
"asks_results"
]["sql_generation_reasoning"]
else:
st.session_state["asks_results"] = asks_type
elif asks_status == "failed":
Expand Down Expand Up @@ -783,73 +718,6 @@ def ask_details():
return asks_details_status_response.json()


def sql_explanation():
sql_explanation_response = requests.post(
f"{WREN_AI_SERVICE_BASE_URL}/v1/sql-explanations",
json={
"question": st.session_state["sql_explanation_question"],
"steps_with_analysis_results": st.session_state[
"sql_explanation_steps_with_analysis"
],
},
)

assert sql_explanation_response.status_code == 200
query_id = sql_explanation_response.json()["query_id"]
sql_explanation_status = None

while (
sql_explanation_status != "finished" and sql_explanation_status != "failed"
) or not sql_explanation_status:
sql_explanation_status_response = requests.get(
f"{WREN_AI_SERVICE_BASE_URL}/v1/sql-explanations/{query_id}/result"
)
assert sql_explanation_status_response.status_code == 200
sql_explanation_status = sql_explanation_status_response.json()["status"]
st.toast(f"The query processing status: {sql_explanation_status}")
time.sleep(POLLING_INTERVAL)

if sql_explanation_status == "finished":
return sql_explanation_status_response.json()["response"]
elif sql_explanation_status == "failed":
st.error(
f'An error occurred while processing the query: {sql_explanation_status_response.json()['error']}',
icon="🚨",
)
return None


def sql_regeneration(sql_regeneration_data: dict):
sql_regeneration_response = requests.post(
f"{WREN_AI_SERVICE_BASE_URL}/v1/sql-regenerations",
json=sql_regeneration_data,
)

assert sql_regeneration_response.status_code == 200
query_id = sql_regeneration_response.json()["query_id"]
sql_regeneration_status = None

while (
sql_regeneration_status != "finished" and sql_regeneration_status != "failed"
) or not sql_regeneration_status:
sql_regeneration_status_response = requests.get(
f"{WREN_AI_SERVICE_BASE_URL}/v1/sql-regenerations/{query_id}/result"
)
assert sql_regeneration_status_response.status_code == 200
sql_regeneration_status = sql_regeneration_status_response.json()["status"]
st.toast(f"The query processing status: {sql_regeneration_status}")
time.sleep(POLLING_INTERVAL)

if sql_regeneration_status == "finished":
return sql_regeneration_status_response.json()["response"]
elif sql_regeneration_status == "failed":
st.error(
f'An error occurred while processing the query: {sql_regeneration_status_response.json()['error']}',
icon="🚨",
)
return None


def fill_vega_lite_values(vega_lite_schema: dict, df: pd.DataFrame) -> dict:
"""Fill Vega-Lite schema values from pandas DataFrame based on x/y encodings.
Expand Down Expand Up @@ -991,83 +859,6 @@ def adjust_chart(
return chart_response


@st.dialog(
"Comparing SQL step-by-step breakdown before and after SQL Generation Feedback",
width="large",
)
def show_sql_regeneration_results_dialog(
sql_user_corrections_by_step: List[List[dict]],
):
st.markdown("### Adjustments")
st.json(sql_user_corrections_by_step, expanded=True)

col1, col2 = st.columns(2)
original_sqls = []
with col1:
st.markdown("### Before SQL Generation Feedback")
st.markdown(
f'Description: {st.session_state['asks_details_result']["description"]}'
)

sqls_with_cte = []
for i, step in enumerate(st.session_state["asks_details_result"]["steps"]):
st.markdown(f"#### Step {i + 1}")
st.markdown(f'Summary: {step["summary"]}')

sql = ""
if sqls_with_cte:
sql += "WITH " + ",\n".join(sqls_with_cte) + "\n\n"
sql += step["sql"]
original_sqls.append(sql)

st.markdown("SQL")
st.code(
body=sqlparse.format(sql, reindent=True, keyword_case="upper"),
language="sql",
)
sqls_with_cte.append(f"{step['cte_name']} AS ( {step['sql']} )")
with col2:
st.markdown("### After SQL Generation Feedback")

if (
st.session_state["sql_regeneration_results"]["description"]
== st.session_state["asks_details_result"]["description"]
):
st.markdown(
f'Description: {st.session_state['sql_regeneration_results']["description"]}'
)
else:
st.markdown(
f':red[Description:] {st.session_state['sql_regeneration_results']["description"]}'
)

sqls_with_cte = []
for i, step in enumerate(st.session_state["sql_regeneration_results"]["steps"]):
st.markdown(f"#### Step {i + 1}")
if (
step["summary"]
== st.session_state["asks_details_result"]["steps"][i]["summary"]
):
st.markdown(f'Summary: {step["summary"]}')
else:
st.markdown(f':red[Summary:] {step["summary"]}')

sql = ""
if sqls_with_cte:
sql += "WITH " + ",\n".join(sqls_with_cte) + "\n\n"
sql += step["sql"]

if sql == original_sqls[i]:
st.markdown("SQL")
else:
st.markdown(":red[SQL:]")
st.code(
body=sqlparse.format(sql, reindent=True, keyword_case="upper"),
language="sql",
)
sqls_with_cte.append(f"{step['cte_name']} AS ( {step['sql']} )")


def show_original_chart(chart_schema: dict, reasoning: str, chart_type: str):
st.markdown("### Original")
st.markdown(f"#### Chart Type: {chart_type}")
Expand Down

0 comments on commit 0f33dbb

Please sign in to comment.