diff --git a/xql/README.md b/xql/README.md index 70ee3000..bb23c9f8 100644 --- a/xql/README.md +++ b/xql/README.md @@ -28,7 +28,7 @@ For this `gcloud` must be configured in your local environment. Refer [Initializ pip install xql # Jump into xql -python xql/main.py +python xql/main.py xql ``` --- ### Supported meta commands @@ -136,4 +136,55 @@ _Updated on 2024-01-08_ 7. [ ] **Aliases**: Add support to alias while querying. 8. [ ] **Join Operations**: Support joining tables and apply query. 9. [ ] **Nested Queries**: Add support to write nested queries. -10. [ ] **Custom Aggregate Functions**: Support custom aggregate functions \ No newline at end of file +10. [ ] **Custom Aggregate Functions**: Support custom aggregate functions + +# `weather-lm` - Querying weather data using Natural Language prompts + +Query weather data using natural-language prompts. This uses Gemini (a large language model from Google) to generate SQL-like queries and `xql` to execute them. + +# Quickstart + +## Prerequisites + +A Google API key is needed to initialize the language model. Refer to [Setup your API key](https://ai.google.dev/tutorials/python_quickstart#setup_your_api_key) to generate one. + +Set the key as an environment variable by running the command below. +``` +export GOOGLE_API_KEY="generate_key" +``` + +## Usage +``` +# Install required packages +pip install xql + +# Jump into language model +python xql/main.py lm +``` +--- +### Examples +`Input Prompt`: Daily average temperature of New York for January 2015 + +Relevant SQL Query: +``` +SELECT + AVG('temperature') +FROM + 'gs://darshan-store/ar/2013-2022-full_37-1h-0p25deg-chunk-1.zarr-v3' +WHERE + time >= '2015-01-01' AND + time < '2015-02-01' AND + city = 'New York' +GROUP BY time_date +``` +Output Data: +``` + time_date avg_temperature +0 2015-01-01 240.978073 +1 2015-01-02 243.375031 +2 2015-01-03 244.584747 +3 2015-01-04 249.673065 +4 2015-01-05 245.650833 +...
def parse_args(argv: "List[str] | None" = None) -> Tuple[argparse.Namespace, List[str]]:
    """Parse the launcher's command-line arguments.

    Parameters:
    - argv: Arguments to parse, without the program name. When left as
      ``None`` argparse falls back to ``sys.argv[1:]``, which is what the
      script entry point relies on.

    Returns:
    - tuple: ``(known_args, unknown_args)`` exactly as returned by
      ``ArgumentParser.parse_known_args``. ``known_args.mode`` holds the
      requested mode; arguments the parser does not recognise are handed
      back untouched in ``unknown_args``.
    """
    parser = argparse.ArgumentParser()

    parser.add_argument('mode', type=str, help='Select one from [xql, lm]')

    # parse_known_args tolerates extra arguments instead of aborting on them.
    return parser.parse_known_args(argv)


if __name__ == '__main__':
    known_args, _ = parse_args()
    if known_args.mode not in ("xql", "lm"):
        raise RuntimeError("Invalid mode type. Select one from [xql, lm]")
    # The mode was validated just above, so it doubles as the prompt prefix.
    prompt = f"{known_args.mode}> "
    try:
        # Connect Dask Cluster
        connect_dask_cluster()
        while True:
            try:
                query = input(prompt)
            except EOFError:
                # Ctrl-D or exhausted piped input: leave the shell cleanly
                # instead of dying with a traceback.
                print()
                break
            if query == ".exit":
                break
            if known_args.mode == "xql":
                main(query)
            else:
                print(nl_to_weather_data(query))
    except ImportError as e:
        raise ImportError('main function is not imported please try again.') from e
Consider dataset as a table and data variable as a column."), long_description=open('README.md', 'r', encoding='utf-8').read(), long_description_content_type='text/markdown', - python_requires='>=3.8, <3.11', + python_requires='>=3.9, <3.11', ) diff --git a/xql/src/weather_lm/__init__.py b/xql/src/weather_lm/__init__.py new file mode 100644 index 00000000..f0d9589b --- /dev/null +++ b/xql/src/weather_lm/__init__.py @@ -0,0 +1,16 @@ +#!/usr/bin/env python3 +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .gemini import nl_to_sql_query, nl_to_weather_data #noqa diff --git a/xql/src/weather_lm/constant.py b/xql/src/weather_lm/constant.py new file mode 100644 index 00000000..576814f3 --- /dev/null +++ b/xql/src/weather_lm/constant.py @@ -0,0 +1,89 @@ +#!/usr/bin/env python3 +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
# ruff: noqa: E501

# Location of the JSON document describing every queryable table.
METADATA_URI = "gs://darshan-store/xql/metadata.json"

# Prompt used to turn a question into a SQL statement. The placeholders in
# braces are filled in by the 'generate_sql' PromptTemplate.
# NOTE: backslashes are doubled on purpose. In a regular (non-raw) string
# literal `\"` collapses to a plain `"`, which would make the "True syntax" and
# "False syntax" examples below render identically.
GENERATE_SQL_TEMPLATE = """You are a SQL expert. Given an input question, first create a syntactically correct SQL query to execute.
Never query for all columns from a table. You must query only the columns that are needed to answer the question.
Pay attention to use only the column names you can see in the tables below.
Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.
Everytime wrap table name in single quotes ('').
Specify the time range, latitude, longitude as follows: (time >= '2012-12-01' AND time < '2013-04-01').

While accessing the variable name from the table don't use "\\" this.
Ex. ( SELECT MAX("vertical_velocity") FROM 'table' ) => True syntax
    ( SELECT MAX(\\"vertical_velocity\\") FROM 'table' ) => False syntax
Avoid using the 'time BETWEEN', 'latitude BETWEEN' syntax, opt for the former style instead.

Note: At present, only data variables are supported in the SELECT Clause. Coordinates (latitude, longitude, time) are not supported. Therefore,
coordinates should not be used in the SELECT Clause.

Example:

Some important data details to consider:
- Use latitude and longitude ranges for cities and countries.
- Standard aggregations are applied to the data. A unique convention for aggregation for daily, monthly and yearly are time_date, time_month and time_year.
- The WHERE clause and GROUP BY is specifically applies to coordinates variables. e.g. timestamp, latitude, longitude, and level coordinates.
    For "timestamp," use time_date for grouping by date and time_month for grouping by month. Standard SQL GROUP BY operations apply only to "latitude",
    "longitude", and "level" column.
- Write time always into 'YYYY-MM-DD' format. i.e. '2021-12-01'.

Please use the following format:

Question: "Question here"
SQLQuery: "SQL Query to run"

Use the following information for the database:
- Use {table} as table name.
- The dataset includes columns like {columns}. Select appropriate columns from these which are most relevant to the Question.
- Latitude range is {latitude_range}, and longitude range is {longitude_range}. Generate query accordingly.
- {latitude_dim} and {longitude_dim} are my columns for latitude and longitude so use them everywhere in query.
    Ex. If lat and lon are in the {dims} then instead of latitude > x AND longitude > y use lat > x AND lon > y.
- The interpretation of the "organic" soil type is value of soil type is equal to 6.
- "Over all locations", "globally" entails iterating through "latitude" & "longitude."

Some examples of SQL queries that correspond to questions are:

{few_shot_examples}

Question: {question}"""

# Question -> SQL pairs injected into the prompt as few-shot examples.
# Every example follows the rules stated in the template above: table names in
# single quotes and the aggregate that the question actually asks for.
few_shots = {
    "Aggregate precipitation over months over all locations?": "SELECT SUM(precipitation) FROM 'gs://gcp-public-data-arco-era5/ar/full_37-1h-0p25deg-chunk-1.zarr-v3' GROUP BY latitude, longitude, time_month",
    "Daily average temperature.": "SELECT AVG(temperature) FROM 'gs://gcp-public-data-arco-era5/ar/full_37-1h-0p25deg-chunk-1.zarr-v3' GROUP BY time_date",
    # The Antarctic lies south of 66.5 degrees S, i.e. at negative latitudes.
    "Average temperature of the Antarctic Area during last monsoon over months.": "SELECT AVG(temperature) FROM 'gs://gcp-public-data-arco-era5/ar/full_37-1h-0p25deg-chunk-1.zarr-v3' WHERE time >= '2022-06-01' AND time < '2022-11-01' AND latitude < -66.5 GROUP BY time_month",
    "Average temperature over years.": "SELECT AVG(temperature) FROM 'gs://gcp-public-data-arco-era5/ar/full_37-1h-0p25deg-chunk-1.zarr-v3' GROUP BY time_year",
    "Aggregate precipitation globally?": "SELECT SUM(precipitation) FROM 'gs://gcp-public-data-arco-era5/ar/full_37-1h-0p25deg-chunk-1.zarr-v3' GROUP BY latitude, longitude",
    "For January 2000": "SELECT * from TABLE where time >= '2000-01-01 00:00:00' AND time < '2000-02-01 00:00:00' ",
    "Daily average temperature of city x for January 2015?": "SELECT AVG(temperature) FROM 'gs://gcp-public-data-arco-era5/ar/full_37-1h-0p25deg-chunk-1.zarr-v3' WHERE time >= '2015-01-01' AND time < '2015-02-01' AND latitude > 40 AND latitude < 41 AND longitude > 286 AND longitude < 287 GROUP BY time_date",
    # A "min" question must be answered with MIN, not AVG.
    "Daily min reflectivity of city x for January 2015?": "SELECT MIN(reflectivity) FROM 'ee://projects/anthromet-prod/assets/opera/instantaneous_maximum_reflectivity' WHERE time >= '2015-01-01' AND time < '2015-02-01' AND lat > 40.48 AND lat < 41.87 AND lon > -74.25 AND lon < -71.98 GROUP BY time_date",
}

# Prompt used to pick the table that best matches a question. It is filled in
# by the 'select_dataset' PromptTemplate; the caller takes the table name from
# the text after the last colon of the reply.
SELECT_DATASET_TEMPLATE = """
I have some description of tables that stores weather related data.
Analyze and give me an table that i need to query for provided question.
Sometimes the exact column not be there in the table so select table that contains most relevant columns.
    Ex. Daily average of precipitation rate asked but exact precipitation column is not there then select the table that contains relevant column like total_precipitation, precipitation_rate, total_precipitation_rate, etc.
Below is the description and input qustion
{table_map}

Question: {question}

Please use the following format:

{question}:appropriate table
"""
# Marker that precedes the statement in the model's reply; the generate-SQL
# prompt asks for the answer format `SQLQuery: "SQL Query to run"`.
_SQL_MARKER = "SQLQuery:"
_QUOTE_CHARS = ('"', "'")


def _strip_matching_quotes(text: str) -> str:
    """Trim whitespace and at most one pair of identical quotes wrapping *text*."""
    text = text.strip()
    if len(text) >= 2 and text[0] == text[-1] and text[0] in _QUOTE_CHARS:
        return text[1:-1].strip()
    return text


def _extract_sql_query(response: str) -> str:
    """Return the SQL statement contained in a ``SQLQuery: "<sql>"`` style reply."""
    # Locate the marker instead of slicing at fixed offsets, so a reply without
    # the surrounding quotes, with a trailing newline, or with the question
    # echoed in front of it still yields an intact statement.
    head, marker, tail = response.partition(_SQL_MARKER)
    return _strip_matching_quotes(tail if marker else head)


def _lookup_dataset(selection: str, table_map: dict) -> dict:
    """Return the metadata of the table named in the dataset-selection reply."""
    # The reply is requested as "<question>:<table>", so the table name is
    # whatever follows the last colon.
    dataset_key = selection.split(":")[-1].strip()
    if dataset_key not in table_map:
        # Retry without wrapping quotes before giving up on the name.
        dataset_key = _strip_matching_quotes(dataset_key)
    try:
        return table_map[dataset_key]
    except KeyError:
        raise KeyError(
            f"Model selected unknown table {dataset_key!r}; "
            f"known tables: {sorted(table_map)}"
        ) from None


def nl_to_sql_query(input_statement: str) -> str:
    """
    Convert a natural language query to SQL.

    Two model invocations are made: the first picks the table that best
    matches the question, the second generates the SQL statement for it.

    Parameters:
    - input_statement (str): The natural language query.

    Returns:
    - str: The generated SQL query.

    Raises:
    - RuntimeError: If GOOGLE_API_KEY is unset or empty.
    - KeyError: If the model selects a table that is not in the table map.
    """
    # An empty key is as unusable as a missing one, so reject both before any
    # remote call is made.
    if not os.getenv("GOOGLE_API_KEY"):
        raise RuntimeError("Environment variable GOOGLE_API_KEY is not set.")

    # Get table map
    table_map_prompt, table_map = get_table_map_prompt()

    # Initialize model for natural language processing
    model = ChatGoogleGenerativeAI(model="gemini-pro")

    # Build one pipeline per prompt: dataset selection and SQL generation
    select_dataset_model = get_invocation_steps(DEFINED_PROMPTS['select_dataset'], model)
    generate_sql_model = get_invocation_steps(DEFINED_PROMPTS['generate_sql'], model)

    # Invoke pipeline to select dataset based on input statement
    select_dataset_res = select_dataset_model.invoke({
        "question": input_statement,
        "table_map": table_map_prompt,
    })

    # Retrieve dataset metadata for the table the model picked
    dataset_metadata = _lookup_dataset(select_dataset_res, table_map)

    # Invoke pipeline to generate SQL query
    generate_sql_res = generate_sql_model.invoke({
        "question": input_statement,
        "table": dataset_metadata['uri'],
        "columns": dataset_metadata['columns'],
        "few_shot_examples": few_shots,
        "dims": dataset_metadata["dims"],
        'latitude_dim': dataset_metadata["latitude_dim"],
        'latitude_range': dataset_metadata["latitude_range"],
        'longitude_dim': dataset_metadata["longitude_dim"],
        'longitude_range': dataset_metadata["longitude_range"]
    })

    # The response is expected to look like [SQLQuery: "SELECT ... FROM ..."];
    # keep only the statement itself.
    return _extract_sql_query(generate_sql_res)


def nl_to_weather_data(input_statement: str):
    """
    Convert a natural language query to SQL and fetch weather data.

    Parameters:
    - input_statement (str): The natural language query.

    Returns:
    - Whatever `run_query` returns for the generated SQL statement.
    """
    # Generate SQL query
    sql_query = nl_to_sql_query(input_statement)

    # Print generated SQL statement for debugging
    print("Generated SQL Statement:", sql_query)

    # Execute SQL query to fetch weather data
    return run_query(sql_query)
def get_table_map() -> t.Dict:
    """
    Load and return the table map stored as JSON at METADATA_URI.

    Returns:
        dict: The parsed JSON document. Callers treat it as a mapping from
        table name to that table's metadata.
    """
    with GCSFileSystem().open(METADATA_URI) as metadata_file:
        return json.load(metadata_file)

def get_table_map_prompt() -> t.Tuple:
    """
    Describe every table of the table map in a single prompt string.

    Returns:
        tuple: The prompt (one description per table, newline-separated)
        followed by the table map dictionary it was built from.
    """
    tables = get_table_map()
    # One description per table: its name, its URI and its column names.
    descriptions = (
        f"""Table name is {name}.
        It's located at {meta['uri']} and containing following columns: {', '.join(meta['columns'])}"""
        for name, meta in tables.items()
    )
    return "\n".join(descriptions), tables


def get_invocation_steps(prompt: PromptTemplate, model: ChatGoogleGenerativeAI):
    """
    Compose the pipeline that runs a prompt through a model.

    Parameters:
    - prompt (PromptTemplate): The prompt template to use.
    - model (ChatGoogleGenerativeAI): The generative model to use.

    Returns:
    - Pipeline: The prompt piped into `model.bind()` and then into a
      `StrOutputParser`.
    """
    bound_model = model.bind()
    return prompt | bound_model | StrOutputParser()