Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore(wren-ai-service): improve ai service #1068

Merged
merged 8 commits into from
Dec 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion wren-ai-service/Justfile
Original file line number Diff line number Diff line change
Expand Up @@ -65,4 +65,7 @@ prepare-wren-engine:
> tools/dev/etc/mdl/sample.json

use-wren-ui-as-engine:
poetry run python -m src.force_update_config
poetry run python -m src.force_update_config

run-sql mdl_path="" data_source="" sample_dataset="":
poetry run python demo/run_sql.py --mdl-path "{{mdl_path}}" --data-source "{{data_source}}" --sample-dataset "{{sample_dataset}}"
79 changes: 79 additions & 0 deletions wren-ai-service/demo/run_sql.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
import argparse
import json

from utils import get_data_from_wren_engine, rerun_wren_engine


def main():
parser = argparse.ArgumentParser(
description="Execute SQL query against MDL manifest"
)

parser.add_argument(
"--mdl-path",
type=str,
required=True,
help="Path to MDL JSON file",
)

parser.add_argument(
"--data-source",
type=str,
default="bigquery",
choices=["bigquery", "duckdb"],
help="Data source (default: bigquery)",
)

parser.add_argument(
"--sample-dataset",
type=str,
default="ecommerce",
choices=["ecommerce", "hr", ""],
help="Sample dataset (default: ecommerce)",
)

args = parser.parse_args()

mdl_path = args.mdl_path
data_source = args.data_source
sample_dataset = args.sample_dataset

# Load MDL JSON file
try:
with open(mdl_path, "r") as f:
mdl_json = json.load(f)
except FileNotFoundError:
print(f"Error: MDL file not found at {mdl_path}")
return
except json.JSONDecodeError:
print(f"Error: Invalid JSON in MDL file {mdl_path}")
return

rerun_wren_engine(mdl_json, data_source, sample_dataset)

# Execute query
print("Enter SQL query (end with semicolon on a new line to execute, 'q' to quit):")
lines = []
while True:
line = input()
if line.strip() == "q":
break
if line.strip() == ";":
command = "\n".join(lines)
lines = []
try:
df = get_data_from_wren_engine(
sql=command,
dataset_type=data_source,
manifest=mdl_json,
limit=10,
)
print(f"\nExecution result:\n{df.to_string()}\n")
except Exception as e:
print(f"\nError executing query: {str(e)}")
else:
lines.append(line)


if __name__ == "__main__":
main()
8 changes: 3 additions & 5 deletions wren-ai-service/demo/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def _update_wren_engine_configs(configs: list[dict]):
assert response.status_code == 200


def rerun_wren_engine(mdl_json: Dict, dataset_type: str, dataset: str):
def rerun_wren_engine(mdl_json: Dict, dataset_type: str, dataset: Optional[str] = None):
assert dataset_type in DATA_SOURCES

SOURCE = dataset_type
Expand Down Expand Up @@ -118,7 +118,6 @@ def get_mdl_json(database_name: str):
return mdl_json


@st.cache_data
def get_data_from_wren_engine(
sql: str,
dataset_type: str,
Expand All @@ -139,7 +138,7 @@ def get_data_from_wren_engine(
},
)

assert response.status_code == 200, response.json()
assert response.status_code == 200, response.text

data = response.json()

Expand All @@ -162,7 +161,7 @@ def get_data_from_wren_engine(
},
)

assert response.status_code == 200, response.json()
assert response.status_code == 200, response.text

data = response.json()

Expand Down Expand Up @@ -803,7 +802,6 @@ def sql_regeneration(sql_regeneration_data: dict):
return None


@st.cache_data
def fill_vega_lite_values(vega_lite_schema: dict, df: pd.DataFrame) -> dict:
"""Fill Vega-Lite schema values from pandas DataFrame based on x/y encodings.
Expand Down
3 changes: 2 additions & 1 deletion wren-ai-service/src/pipelines/generation/utils/chart.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@
- Please generate vega-lite schema using v5 version, which is https://vega.github.io/schema/vega-lite/v5.json
- Chart types: Bar chart, Line chart, Area chart, Pie chart, Stacked bar chart, Grouped bar chart
- You can only use the chart types provided in the instructions
- If you think the data is not suitable for visualization, you can return an empty string for the schema
- If the sample data is not suitable for visualization, you must return an empty string for the schema
- If the sample data is empty, you must return an empty string for the schema
- The language for the chart and reasoning must be the same language provided by the user
- Please use the current time provided by the user to generate the chart
- In order to generate the grouped bar chart, you need to follow the given instructions:
Expand Down
Loading