xql - Initial Commit #427

Merged · 6 commits · Jan 10, 2024
41 changes: 41 additions & 0 deletions xql/README.md
@@ -0,0 +1,41 @@
# `xql` - Querying Xarray Datasets with SQL

Run SQL-like queries on Xarray datasets. (Only Zarr datasets are supported for now.)

# Supported Features

* **Select Variables** - From a large dataset with hundreds of variables, select only the variables you need.
* **Apply Where Clause** - A general WHERE condition, as in SQL. Useful for queries that select data for a specific time range or region. (Only conditions on coordinates are supported for now.)
* **Group By and Aggregate Functions** - The aggregate functions `AVG()`, `MIN()`, and `MAX()` are supported after grouping by any coordinate, such as time.

# Usage

Install the required packages
```
pip install -r xql/requirements.txt
```

Jump into xql
```
python xql/main.py
```
---

Run a simple query on a dataset. Compared with SQL, a data variable is like a column and a dataset is like a table.
```
SELECT evaporation, geopotential_at_surface, temperature FROM '{TABLE}'
```
Replace `{TABLE}` with the dataset URI, e.g. `gs://gcp-public-data-arco-era5/ar/full_37-1h-0p25deg-chunk-1.zarr-v3`.
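For intuition, this is roughly equivalent to opening the store and selecting variables with xarray directly (a minimal sketch of what `xql` does internally, not its exact code):
```
import xarray as xr

# Open the Zarr store and keep only the requested variables ("columns").
ds = xr.open_zarr("gs://gcp-public-data-arco-era5/ar/full_37-1h-0p25deg-chunk-1.zarr-v3", chunks=None)
subset = ds[["evaporation", "geopotential_at_surface", "temperature"]]
```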

---
Apply conditions. A query to get the temperature of the Arctic region last winter:
```
SELECT temperature FROM '{TABLE}' WHERE time >= '2022-12-01' AND time < '2023-03-01' AND latitude >= 66.5
```
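In xarray terms, the WHERE clause is evaluated as a boolean mask over the coordinates, roughly like the sketch below (assuming `ds` from the previous snippet):
```
import numpy as np

# Build the mask from the coordinate conditions and drop everything outside it.
mask = (
    (ds["time"] >= np.datetime64("2022-12-01"))
    & (ds["time"] < np.datetime64("2023-03-01"))
    & (ds["latitude"] >= 66.5)
)
arctic_winter = ds[["temperature"]].where(mask, drop=True)
```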
---
Aggregate results using GROUP BY and an aggregate function. Daily average temperature for last winter in the Arctic region:
```
SELECT AVG(temperature) FROM '{TABLE}' WHERE time >= '2022-12-01' AND time < '2023-03-01' AND latitude >= 66.5
GROUP BY time_day
```
Replace `time_day` with `time_month` or `time_year` if a monthly or yearly average is needed. `MIN()` and `MAX()` are used the same way as `AVG()`.
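This corresponds roughly to an xarray groupby over `time.day` (`xql` rewrites `time_day` to `time.day`); a sketch, assuming `arctic_winter` from the previous snippet:
```
# Daily mean; swap .mean() for .min() or .max() to mirror MIN()/MAX().
daily_avg = arctic_winter.groupby("time.day").mean()
```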
135 changes: 135 additions & 0 deletions xql/main.py
@@ -0,0 +1,135 @@

import numpy as np
import xarray as xr
import typing as t

from sqlglot import parse_one, exp

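# Element-wise operators keyed by sqlglot expression names, applied to DataArrays and parsed literals.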
operate = {
    "and": lambda a, b: a & b,
    "or": lambda a, b: a | b,
    "eq": lambda a, b: a == b,
    "gt": lambda a, b: a > b,
    "lt": lambda a, b: a < b,
    "gte": lambda a, b: a >= b,
    "lte": lambda a, b: a <= b,
}

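# Aggregations keyed by SQL function name; reduce over a dimension when one is given.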
aggregate_function_map = {
    'avg': lambda x, y: x.mean(dim=y) if y else x.mean(),
    'min': lambda x, y: x.min(dim=y) if y else x.min(),
    'max': lambda x, y: x.max(dim=y) if y else x.max(),
}

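# Coerce the right-hand literal to the DataArray's dtype so comparisons behave correctly.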
def parse(a, b):
    arr_type = a.dtype.name
    if arr_type == 'float64':
        b = np.float64(b)
    elif arr_type == 'float32':
        b = np.float32(b)
    elif arr_type == 'datetime64[ns]':
        b = np.datetime64(b)
    return a, b

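# Evaluate a single comparison between a coordinate/variable and a literal.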
def evaluate(a: xr.DataArray, b, operator):
    a, b = parse(a, b)
    return operate[operator](a, b)

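# Recursively walk the parsed WHERE expression and build a boolean mask over the dataset.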
def inorder(node, ds):
    if node.key == "identifier":
        return ds[node.args['this']]

    if node.key == "literal":
        return node.args['this']

    args = node.args

    left = args['this']
    right = None
    if 'expression' in args:
        right = args['expression']

    left_sol = inorder(left, ds)

    right_sol = None
    if right is not None:
        right_sol = inorder(right, ds)

    if right_sol is not None:
        return evaluate(left_sol, right_sol, node.key)
    else:
        return left_sol

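# Apply GROUP BY fields such as `time_day` (mapped to `time.day`) and run the aggregation.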
def apply_group_by(fields, ds: xr.Dataset, agg_funcs):
    grouped_ds = ds
    for field in fields:
        field_parts = field.split("_")
        groupby_field = field_parts[0]
        if len(field_parts) > 1:
            groupby_field = f"{groupby_field}.{field_parts[1]}"
            groups = grouped_ds.groupby(groupby_field)
            grouped_ds = apply_aggregation(groups, list(agg_funcs.values())[0])
        else:
            grouped_ds = apply_aggregation(grouped_ds, list(agg_funcs.values())[0], field)
    return grouped_ds

def apply_aggregation(groups, fun: str, dim: t.Optional[str] = None):
    return aggregate_function_map[fun](groups, dim)

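# Parse a SQL string with sqlglot and translate it into xarray operations on a Zarr dataset.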
def parse_query(query: str) -> xr.Dataset:

    expr = parse_one(query)

    if not isinstance(expr, exp.Select):
        return "ERROR: Only select queries are supported."

    table = expr.find(exp.Table).args['this'].args['this']

    is_star = expr.find(exp.Star)

    data_vars = []
    if is_star is None:
        data_vars = [var.args['this'].args['this'] for var in expr.expressions if var.key == "column"]

    where = expr.find(exp.Where)

    group_by = expr.find(exp.Group)

    agg_funcs = {
        var.args['this'].args['this'].args['this']: var.key
        for var in expr.expressions if var.key in aggregate_function_map
    }

    if len(agg_funcs):
        data_vars = list(agg_funcs.keys())

    ds = xr.open_zarr(table, chunks=None)

    if is_star is None:
        ds = ds[data_vars]

    # A query without a WHERE clause should return the unfiltered dataset.
    mask = inorder(where, ds) if where is not None else None

    filtered_ds = ds.where(mask, drop=True) if mask is not None else ds

    if group_by:
        groupby_fields = [e.args['this'].args['this'] for e in group_by.args['expressions']]
        filtered_ds = apply_group_by(groupby_fields, filtered_ds, agg_funcs)

    return filtered_ds

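# Minimal REPL: read and execute queries until the user types 'exit'.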
if __name__ == "__main__":

    while True:

        query = input("xql>")

        if query == "exit":
            break

        result = parse_query(query)

        if isinstance(result, xr.Dataset):
            print(result.to_dataframe())
        else:
            print(result)
1 change: 1 addition & 0 deletions xql/requirements.txt
@@ -0,0 +1 @@
sqlglot