Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Language Model Integration #445

Merged
merged 6 commits into from
Apr 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 53 additions & 2 deletions xql/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ For this `gcloud` must be configured in your local environment. Refer [Initializ
pip install xql

# Jump into xql
python xql/main.py
python xql/main.py xql
```
---
### Supported meta commands
Expand Down Expand Up @@ -136,4 +136,55 @@ _Updated on 2024-01-08_
7. [ ] **Aliases**: Add support to alias while querying.
8. [ ] **Join Operations**: Support joining tables and apply query.
9. [ ] **Nested Queries**: Add support to write nested queries.
10. [ ] **Custom Aggregate Functions**: Support custom aggregate functions
10. [ ] **Custom Aggregate Functions**: Support custom aggregate functions

# `weather-lm` - Querying weather data using Natural Language prompts

Querying weather data using Natural Language prompts. This uses a gemini (large language model from Google) to generate SQL like queries and `xql` to execute that query.

# Quickstart

## Prerequisites

Google API Key is needed to initiate Language Model. Refer [Setup your API key](https://ai.google.dev/tutorials/python_quickstart#setup_your_api_key) to generate that key.

Set that key as an environment variable. Run below command.
```
export GOOGLE_API_KEY="generate_key"
```

## Usage
```
# Install required packages
pip install xql

# Jump into language model
python xql/main.py lm
```
---
### Examples
`Input Prompt`: Daily average temperature of New York for January 2015

Relevant SQL Query:
```
SELECT
AVG('temperature')
FROM
'gs://darshan-store/ar/2013-2022-full_37-1h-0p25deg-chunk-1.zarr-v3'
WHERE
time >= '2015-01-01' AND
time < '2015-02-01' AND
city = 'New York'
GROUP BY time_date
```
Output Data:
```
time_date avg_temperature
0 2015-01-01 240.978073
1 2015-01-02 243.375031
2 2015-01-03 244.584747
3 2015-01-04 249.673065
4 2015-01-05 245.650833
...
Query took: 00:01:55
```
24 changes: 22 additions & 2 deletions xql/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,18 +11,38 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse

from src.weather_lm import nl_to_weather_data
from src.xql import main
from src.xql.utils import connect_dask_cluster
from typing import List, Tuple

def parse_args() -> Tuple[argparse.Namespace, List[str]]:
parser = argparse.ArgumentParser()

parser.add_argument('mode', type=str, help='Select one from [xql, lm]')

return parser.parse_known_args()


if __name__ == '__main__':
known_args, _ = parse_args()
if known_args.mode not in ["xql", "lm"]:
raise RuntimeError("Invalid mode type. Select one from [xql, lm]")
prefix = "xql" if known_args.mode == "xql" else "lm"
try:
# Connect Dask Cluster
connect_dask_cluster()
while True:
query = input("xql> ")
query = input(f"{prefix}> ")
if query == ".exit":
break
main(query)
if known_args.mode == "xql":
main(query)
else:
print(nl_to_weather_data(query))
except ImportError as e:
raise ImportError('main function is not imported please try again.') from e

8 changes: 6 additions & 2 deletions xql/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,11 @@
"toolz==0.12.0",
"xarray==2024.01.0",
"xee",
"zarr==2.17.0"
"zarr==2.17.0",
"langchain",
"langchain-experimental",
"langchain-openai",
"langchain-google-genai"
]

setup(
Expand All @@ -40,5 +44,5 @@
description=("Running SQL queries on Xarray Datasets. Consider dataset as a table and data variable as a column."),
long_description=open('README.md', 'r', encoding='utf-8').read(),
long_description_content_type='text/markdown',
python_requires='>=3.8, <3.11',
python_requires='>=3.9, <3.11',
)
16 changes: 16 additions & 0 deletions xql/src/weather_lm/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
#!/usr/bin/env python3
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .gemini import nl_to_sql_query, nl_to_weather_data #noqa
89 changes: 89 additions & 0 deletions xql/src/weather_lm/constant.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
#!/usr/bin/env python3
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# ruff: noqa: E501

METADATA_URI = "gs://darshan-store/xql/metadata.json"

GENERATE_SQL_TEMPLATE = """You are a SQL expert. Given an input question, first create a syntactically correct SQL query to execute.
Never query for all columns from a table. You must query only the columns that are needed to answer the question.
Pay attention to use only the column names you can see in the tables below.
Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.
Everytime wrap table name in single quotes ('').
Specify the time range, latitude, longitude as follows: (time >= '2012-12-01' AND time < '2013-04-01').

While accessing the variable name from the table don't use "\" this.
Ex. ( SELECT MAX("vertical_velocity") FROM 'table' ) => True syntax
( SELECT MAX(\"vertical_velocity\") FROM 'table' ) => False syntax
Avoid using the 'time BETWEEN', 'latitude BETWEEN' syntax, opt for the former style instead.

Note: At present, only data variables are supported in the SELECT Clause. Coordinates (latitude, longitude, time) are not supported. Therefore,
coordinates should not be used in the SELECT Clause.

Example:

Some important data details to consider:
- Use latitude and longitude ranges for cities and countries.
- Standard aggregations are applied to the data. A unique convention for aggregation for daily, monthly and yearly are time_date, time_month and time_year.
- The WHERE clause and GROUP BY is specifically applies to coordinates variables. e.g. timestamp, latitude, longitude, and level coordinates.
For "timestamp," use time_date for grouping by date and time_month for grouping by month. Standard SQL GROUP BY operations apply only to "latitude",
"longitude", and "level" column.
- Write time always into 'YYYY-MM-DD' format. i.e. '2021-12-01'.

Please use the following format:

Question: "Question here"
SQLQuery: "SQL Query to run"

Use the following information for the database:
- Use {table} as table name.
- The dataset includes columns like {columns}. Select appropriate columns from these which are most relevant to the Question.
- Latitude range is {latitude_range}, and longitude range is {longitude_range}. Generate query accordingly.
- {latitude_dim} and {longitude_dim} are my columns for latitude and longitude so use them everywhere in query.
Ex. If lat and lon are in the {dims} then instead of latitude > x AND longitude > y use lat > x AND lon > y.
- The interpretation of the "organic" soil type is value of soil type is equal to 6.
- "Over all locations", "globally" entails iterating through "latitude" & "longitude."

Some examples of SQL queries that correspond to questions are:

{few_shot_examples}

Question: {question}"""

few_shots = {
"Aggregate precipitation over months over all locations?" : "SELECT SUM(precipitation) FROM 'gs://gcp-public-data-arco-era5/ar/full_37-1h-0p25deg-chunk-1.zarr-v3' GROUP BY latitude, longitude, time_month",
"Daily average temperature.":"""SELECT AVG(temperature) FROM "gs://gcp-public-data-arco-era5/ar/full_37-1h-0p25deg-chunk-1.zarr-v3" GROUP BY time_date""",
"Average temperature of the Antarctic Area during last monsoon over months.":"SELECT AVG(temperature) FROM 'gs://gcp-public-data-arco-era5/ar/full_37-1h-0p25deg-chunk-1.zarr-v3' WHERE time >= '2022-06-01' AND time < '2022-11-01' AND latitude < 66.5 GROUP BY time_month",
"Average temperature over years.":"SELECT AVG(temperature) FROM 'gs://gcp-public-data-arco-era5/ar/full_37-1h-0p25deg-chunk-1.zarr-v3' GROUP BY time_year",
"Aggregate precipitation globally?" : "SELECT SUM(precipitation) FROM 'gs://gcp-public-data-arco-era5/ar/full_37-1h-0p25deg-chunk-1.zarr-v3' GROUP BY latitude, longitude",
"For January 2000" : "SELECT * from TABLE where time >= '2000-01-01 00:00:00' AND time < '2000-02-01 00:00:00' ",
"Daily average temperature of city x for January 2015?": "SELECT AVG(temperature) FROM 'gs://gcp-public-data-arco-era5/ar/full_37-1h-0p25deg-chunk-1.zarr-v3' WHERE time >= '2015-01-01' AND time < '2015-02-01' AND latitude > 40 AND latitude < 41 AND longitude > 286 AND longitude < 287 GROUP BY time_date",
"Daily min reflectivity of city x for January 2015?": "SELECT AVG(reflectivity) FROM 'ee://projects/anthromet-prod/assets/opera/instantaneous_maximum_reflectivity' WHERE time >= '2015-01-01' AND time < '2015-02-01' AND lat > 40.48 AND lat < 41.87 AND lon > -74.25 AND lon < -71.98 GROUP BY time_date"
}

SELECT_DATASET_TEMPLATE = """
I have some description of tables that stores weather related data.
Analyze and give me an table that i need to query for provided question.
Sometimes the exact column not be there in the table so select table that contains most relevant columns.
Ex. Daily average of precipitation rate asked but exact precipitation column is not there then select the table that contains relevant column like total_precipitation, precipitation_rate, total_precipitation_rate, etc.
Below is the description and input qustion
{table_map}

Question: {question}

Please use the following format:

{question}:appropriate table
"""
97 changes: 97 additions & 0 deletions xql/src/weather_lm/gemini.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
#!/usr/bin/env python3
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os

from langchain_google_genai import ChatGoogleGenerativeAI

from .constant import few_shots
from .template import DEFINED_PROMPTS
from .utils import get_invocation_steps, get_table_map_prompt
from xql import run_query

def nl_to_sql_query(input_statement: str) -> str:
"""
Convert a natural language query to SQL.

Parameters:
- input_statement (str): The natural language query.

Returns:
- str: The generated SQL query.
"""
# Check if API key is provided either directly or through environment variable
api_key = os.getenv("GOOGLE_API_KEY")

if api_key is None:
raise RuntimeError("Environment variable GOOGLE_API_KEY is not set.")

# Get table map
table_map_prompt, table_map = get_table_map_prompt()

# Initialize model for natural language processing
model = ChatGoogleGenerativeAI(model="gemini-pro")

# Get invocation steps for selecting dataset
select_dataset_model = get_invocation_steps(DEFINED_PROMPTS['select_dataset'], model)

# Get invocation steps for generating SQL
generate_sql_model = get_invocation_steps(DEFINED_PROMPTS['generate_sql'], model)

# Invoke pipeline to select dataset based on input statement
select_dataset_res = select_dataset_model.invoke({ "question": input_statement, "table_map": table_map_prompt })

# Extract dataset key from result
dataset_key = select_dataset_res.split(":")[-1].strip()

# Retrieve dataset metadata using dataset key
dataset_metadata = table_map[dataset_key]

# Invoke pipeline to generate SQL query
generate_sql_res = generate_sql_model.invoke({
"question": input_statement,
"table": dataset_metadata['uri'],
"columns": dataset_metadata['columns'],
"few_shot_examples": few_shots,
"dims": dataset_metadata["dims"],
'latitude_dim': dataset_metadata["latitude_dim"],
'latitude_range': dataset_metadata["latitude_range"],
'longitude_dim': dataset_metadata["longitude_dim"],
'longitude_range': dataset_metadata["longitude_range"]
})

# Extract SQL query from result.
# The response will look like [SQLQuery: SELECT * FROM {table} WHERE ...].
# So slice the sql query from string.
sql_query = generate_sql_res[11:-1]

return sql_query


def nl_to_weather_data(input_statement: str):
"""
Convert a natural language query to SQL and fetch weather data.

Parameters:
- input_statement (str): The natural language query.
"""
# Generate SQL query
sql_query = nl_to_sql_query(input_statement)

# Print generated SQL statement for debugging
print("Generated SQL Statement:", sql_query)

# Execute SQL query to fetch weather data
return run_query(sql_query)
39 changes: 39 additions & 0 deletions xql/src/weather_lm/template.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
#!/usr/bin/env python3
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from langchain.prompts import PromptTemplate

from .constant import GENERATE_SQL_TEMPLATE, SELECT_DATASET_TEMPLATE

DEFINED_PROMPTS = {
'select_dataset': PromptTemplate(
input_variables = ["table_map", "question"],
template = SELECT_DATASET_TEMPLATE,
),
'generate_sql': PromptTemplate(
input_variables = [
"question",
"few_shot_examples",
"table",
"columns",
"dims",
"latitude_dim",
"latitude_range",
"longitude_dim",
"longitude_range"
],
template = GENERATE_SQL_TEMPLATE,
)
}
Loading
Loading