Skip to content

Commit

Permalink
XQL: Where Clause Optimization (#435)
Browse files Browse the repository at this point in the history
* xql where clause optimization

* Measure time for query execution

* Rephrase Comments

---------

Co-authored-by: Darshan Prajapati <[email protected]>
  • Loading branch information
DarshanSP19 and Darshan Prajapati authored Feb 6, 2024
1 parent ce24141 commit d9fc06d
Show file tree
Hide file tree
Showing 5 changed files with 324 additions and 30 deletions.
6 changes: 5 additions & 1 deletion xql/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,11 @@

if __name__ == '__main__':
try:
main()
while True:
query = input("xql> ")
if query == ".exit":
break
main(query)
except ImportError as e:
raise ImportError('main function is not imported please try again.') from e

54 changes: 25 additions & 29 deletions xql/src/xql/apply.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,9 @@
from sqlglot import parse_one, exp
from xarray.core.groupby import DatasetGroupBy

from .utils import timing
from .where import apply_where

command_info = {
".exit": "To exit from the current session.",
".set": "To set the dataset uri as a shortened key. e.g. .set era5 gs://{BUCKET}/dataset-uri",
Expand Down Expand Up @@ -186,7 +189,7 @@ def aggregate_variables(agg_funcs: t.List[t.Dict[str, str]],
for agg_func in agg_funcs:
variable, function = agg_func['var'], agg_func['func']
grouped_ds = ds[variable]
dims = [value for value in coords_to_squeeze if value in ds[variable].coords]
dims = [value for value in coords_to_squeeze if value in ds[variable].coords] if coords_to_squeeze else None

# If time fields are specified, group by time
if len(time_fields):
Expand Down Expand Up @@ -303,7 +306,7 @@ def parse_query(query: str) -> xr.Dataset:
for var in expr.expressions
if (var.key == "column" or (var.key == "literal" and var.args.get("is_string") is True))]

where = expr.find(exp.Where)
where_clause = expr.find(exp.Where)
group_by = expr.find(exp.Group)

agg_funcs = [
Expand All @@ -323,9 +326,8 @@ def parse_query(query: str) -> xr.Dataset:
if is_star is None:
ds = ds[data_vars]

if where:
mask = inorder(where, ds)
ds = ds.where(mask, drop=True)
if where_clause is not None:
ds = apply_where(ds, where_clause.args['this'])

coords_to_squeeze = None
time_fields = []
Expand All @@ -336,8 +338,9 @@ def parse_query(query: str) -> xr.Dataset:
ds = apply_group_by(time_fields, ds, agg_funcs, coords_to_squeeze)

if len(time_fields) == 0 and len(agg_funcs):
coords_to_squeeze.append('time')
aggregate_variables(agg_funcs, ds, time_fields, coords_to_squeeze)
if isinstance(coords_to_squeeze, t.List):
coords_to_squeeze.append('time')
ds = aggregate_variables(agg_funcs, ds, time_fields, coords_to_squeeze)

return ds

Expand Down Expand Up @@ -430,7 +433,7 @@ def display_result(result: t.Any) -> None:
else:
print(result)


@timing
def run_query(query: str) -> None:
"""
Run a query and display the result.
Expand All @@ -441,31 +444,24 @@ def run_query(query: str) -> None:
result = parse_query(query)
display_result(result)


def main():
@timing
def main(query: str):
"""
Main function for runnning this file.
"""
while True:

query = input("xql> ")
if ".help" in query:
display_help(query)

if query == ".exit":
break
elif ".set" in query:
set_dataset_table(query)

elif ".help" in query:
display_help(query)
elif ".show" in query:
display_table_dataset_map(query)

elif ".set" in query:
set_dataset_table(query)

elif ".show" in query:
display_table_dataset_map(query)

else:
try:
result = parse_query(query)
except Exception as e:
result = f"ERROR: {type(e).__name__}: {e.__str__()}."
else:
try:
result = parse_query(query)
except Exception as e:
result = f"ERROR: {type(e).__name__}: {e.__str__()}."

display_result(result)
display_result(result)
34 changes: 34 additions & 0 deletions xql/src/xql/constant.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
#!/usr/bin/env python3
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Names of the custom (non-dataset) coordinates that are declared as supported.
# NOTE(review): presumably these are the column names a query may filter on
# in addition to the dataset's own coordinates — confirm against the consumer.
SUPPORTED_CUSTOM_COORDS = ['city', 'country']

# Bounding boxes keyed by lower-case country name; each value is a 4-tuple of
# numbers.
# NOTE(review): the values look like (lat_min, lat_max, lon_min, lon_max) in
# degrees — confirm the ordering against the code that reads this mapping.
COUNTRIES_BOUNDING_BOXES = {
'india': (6.5546079, 35.4940095078, 68.1766451354, 97.4025614766),
'canada': (41.6751050889, 83.23324, -140.99778, -52.6480987209),
'japan': (31.0295791692, 45.5514834662, 129.408463169, 145.543137242),
'united kingdom': (49.959999905, 58.6350001085, -7.57216793459, 1.68153079591),
'south africa': (-34.8191663551, -22.0913127581, 16.3449768409, 32.830120477),
'australia': (-44, -10, 113, 154),
'united states': (24.396308, 49.384358, -125.0, -66.93457)
}

# Bounding boxes keyed by lower-case city name; same 4-tuple layout as
# COUNTRIES_BOUNDING_BOXES.
CITIES_BOUNDING_BOXES = {
'delhi': (28.404, 28.883, 76.838, 77.348),
'new york': (40.4774, 40.9176, -74.2591, -73.7002),
'san francisco': (37.6398, 37.9298, -122.5975, -122.3210),
'los angeles': (33.7036, 34.3373, -118.6682, -118.1553),
'london': (51.3849, 51.6724, -0.3515, 0.1482)
}
28 changes: 28 additions & 0 deletions xql/src/xql/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
#!/usr/bin/env python3
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from functools import wraps
from time import gmtime, strftime, time

def timing(f):
    """Wrap ``f`` so that every call prints how long it took.

    After ``f`` returns, the elapsed time is printed to stdout as
    ``Query took: HH:MM:SS`` (whole seconds, truncated). Nothing is printed
    when ``f`` raises; the exception propagates unchanged.

    Args:
        f: The callable to measure.

    Returns:
        A wrapper that forwards all positional and keyword arguments to ``f``,
        returns whatever ``f`` returns, and keeps ``f``'s metadata.
    """
    # Imported here so this function does not rely on the module-level
    # import list being extended.
    from time import perf_counter

    @wraps(f)
    def wrap(*args, **kw):
        # perf_counter() is monotonic: the measured interval cannot be skewed
        # (or become negative) if the system clock is adjusted mid-call.
        started = perf_counter()
        result = f(*args, **kw)
        elapsed = int(perf_counter() - started)
        # Split by hand instead of going through gmtime(), whose hour field
        # wraps around every 24 hours.
        minutes, seconds = divmod(elapsed, 60)
        hours, minutes = divmod(minutes, 60)
        print(f"Query took: {hours:02d}:{minutes:02d}:{seconds:02d}")
        return result
    return wrap
Loading

0 comments on commit d9fc06d

Please sign in to comment.