Skip to content

Commit

Permalink
Updated README file
Browse files Browse the repository at this point in the history
  • Loading branch information
Darshan Prajapati committed Jan 1, 2024
1 parent 078c4af commit 5a7b0ff
Show file tree
Hide file tree
Showing 3 changed files with 84 additions and 23 deletions.
10 changes: 9 additions & 1 deletion xql/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,15 @@ Running SQL like queries on Xarray Datasets.
* > Note: For now, we support conditions on coordinates.
* **`group by` and `aggregate` Functions** - Aggregate functions `AVG()`, `MIN()`, `MAX()` supported after applying `group-by` on any coordinate like time.

# Usage
# Quickstart

## Prerequisites

Get access to the dataset you want to query. As an example, we're using the analysis-ready ERA5 public dataset: [full_37-1h-0p25deg-chunk-1.zarr-v3](https://console.cloud.google.com/storage/browser/gcp-public-data-arco-era5/ar/full_37-1h-0p25deg-chunk-1.zarr-v3).

For this, the gcloud CLI must be configured in your environment. See [Initializing the gcloud CLI](https://cloud.google.com/sdk/docs/initializing).

## Usage

Install required packages
```
Expand Down
90 changes: 69 additions & 21 deletions xql/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,9 @@
import xarray as xr
import typing as t

from xarray.core.groupby import DatasetGroupBy
from sqlglot import parse_one, exp
import readline # noqa

operate = {
"and" : lambda a, b: a & b,
Expand All @@ -31,10 +33,12 @@

# Maps an SQL aggregate-function name to a callable taking (data, dim):
# when `dim` is given the reduction runs along that dimension only,
# otherwise it reduces over the whole object.
# NOTE: diff residue had stale duplicate 'min'/'max' entries whose
# dim-less fallback wrongly called .mean(); only the corrected entries remain.
aggregate_function_map = {
    'avg': lambda x, y: x.mean(dim=y) if y else x.mean(),
    'min': lambda x, y: x.min(dim=y) if y else x.min(),
    'max': lambda x, y: x.max(dim=y) if y else x.max(),
    'sum': lambda x, y: x.sum(dim=y) if y else x.sum(),
}


def parse(a, b):
arr_type = a.dtype.name
if arr_type == 'float64':
Expand All @@ -45,18 +49,31 @@ def parse(a, b):
b = np.datetime64(b)
return a, b


def evaluate(a: xr.DataArray, b, operator):
    """
    Apply the binary comparison/boolean `operator` to `a` and `b`.

    `b` is first coerced (via `parse`) to a value compatible with `a`'s
    dtype, then the callable registered under `operator` in `operate`
    is invoked on the pair.

    Parameters:
    - a (xr.DataArray): Left-hand operand.
    - b: Right-hand operand, typically a raw string from the parsed query.
    - operator (str): Key into the `operate` dispatch table.

    Returns:
    - The result of the dispatched operation (e.g. a boolean mask).
    """
    lhs, rhs = parse(a, b)
    op_fn = operate[operator]
    return op_fn(lhs, rhs)

def inorder(exp, ds):
if(exp.key == "identifier"):
return ds[exp.args['this']]

if(exp.key == "literal"):
return exp.args['this']
def inorder(expression: exp.Expression, ds: xr.Dataset):
"""
Evaluate an expression using an xarray Dataset and return the result.
Parameters:
- expression (exp.Expression): The expression to be evaluated.
- ds (xr.Dataset): The xarray Dataset used for evaluation.
Returns:
- Any: The result of evaluating the expression on the given dataset.
"""

if(expression.key == "identifier"):
return ds[expression.args['this']]

args = exp.args
if(expression.key == "literal"):
return expression.args['this']

args = expression.args

left = args['this']
right = None
Expand All @@ -70,26 +87,54 @@ def inorder(exp, ds):
right_sol = inorder(right, ds)

if right_sol is not None:
return evaluate(left_sol, right_sol, exp.key)
return evaluate(left_sol, right_sol, expression.key)
else:
return left_sol

def apply_group_by(fields: t.List[str], ds: xr.Dataset, agg_funcs: t.Dict[str, str]) -> xr.Dataset:
    """
    Apply group-by and aggregation operations to the dataset.

    Parameters:
    - fields (List[str]): Fields to group on. A plain coordinate name
      aggregates directly along that coordinate; a name of the form
      "<coord>_<component>" (e.g. "time_month") groups by the derived
      component via xarray's "coord.component" groupby syntax.
    - ds (xr.Dataset): The input dataset.
    - agg_funcs (Dict[str, str]): Maps variable names to aggregation
      function names; only the first function is applied to every group.
      # NOTE(review): only list(agg_funcs.values())[0] is used — per-variable
      # aggregation is not yet supported; confirm this is intentional.

    Returns:
    - xr.Dataset: The dataset after grouping and aggregation.
    """
    grouped_ds = ds
    agg_fun = list(agg_funcs.values())[0]
    for field in fields:
        if field in ds.coords:
            # Direct coordinate: reduce along that dimension.
            grouped_ds = apply_aggregation(grouped_ds, agg_fun, field)
        else:
            # Derived component, e.g. "time_month" -> groupby "time.month".
            field_parts = field.split("_")
            groupby_field = field_parts[0]
            if len(field_parts) > 1:
                groupby_field = f"{groupby_field}.{field_parts[1]}"
            groups = grouped_ds.groupby(groupby_field)
            grouped_ds = apply_aggregation(groups, agg_fun)
    return grouped_ds

def apply_aggregation(groups, fun: str, dim: t.Optional[str] = None):

def apply_aggregation(groups: t.Union[xr.Dataset, DatasetGroupBy], fun: str, dim: t.Optional[str] = None) -> xr.Dataset:
    """
    Reduce `groups` using the aggregation named by `fun`.

    Parameters:
    - groups (Union[xr.Dataset, DatasetGroupBy]): A dataset or a groupby
      object produced by `Dataset.groupby`.
    - fun (str): Name of the aggregation; key into `aggregate_function_map`.
    - dim (Optional[str]): Dimension to reduce along; when None the
      aggregation runs over the entire object.

    Returns:
    - xr.Dataset: The aggregated dataset.
    """
    reducer = aggregate_function_map[fun]
    return reducer(groups, dim)


def parse_query(query: str) -> xr.Dataset:

expr = parse_one(query)
Expand Down Expand Up @@ -117,21 +162,24 @@ def parse_query(query: str) -> xr.Dataset:
if len(agg_funcs):
data_vars = agg_funcs.keys()

ds = xr.open_zarr(table, chunks=None)
ds = xr.open_zarr(table)

if is_star is None:
ds = ds[data_vars]

mask = inorder(where, ds)
filterd_ds = ds

filterd_ds = ds.where(mask, drop=True)
if where:
mask = inorder(where, ds)
filterd_ds = ds.where(mask, drop=True)

if group_by:
groupby_fields = [ e.args['this'].args['this'] for e in group_by.args['expressions'] ]
filterd_ds = apply_group_by(groupby_fields, filterd_ds, agg_funcs)

return filterd_ds


if __name__ == "__main__":

while True:
Expand Down
7 changes: 6 additions & 1 deletion xql/requirements.txt
Original file line number Diff line number Diff line change
@@ -1 +1,6 @@
sqlglot
fsspec
gcsfs
numpy
sqlglot
xarray
zarr

0 comments on commit 5a7b0ff

Please sign in to comment.