Skip to content

Commit

Permalink
Apply Aggregation Without Group By (#429)
Browse files Browse the repository at this point in the history
* Apply Aggregation Without Group By

* Added note

* Minor nits

* Default coord_to_squeeze to None

---------

Co-authored-by: Darshan Prajapati <[email protected]>
  • Loading branch information
DarshanSP19 and Darshan Prajapati authored Jan 17, 2024
1 parent 6fee8fb commit 21552c6
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 20 deletions.
3 changes: 2 additions & 1 deletion xql/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ Running SQL like queries on Xarray Datasets. Consider dataset as a table and dat
* **`aggregate` Functions** - Aggregate functions `AVG()`, `MIN()`, `MAX()`, etc. Only supported on data variables.
* For more checkout the [road-map](https://github.com/google/weather-tools/tree/xql-init/xql#roadmap).
> Note: For now, we support `where` conditions on coordinates only.
> Note: For now, only a single aggregate function is supported per query.
# Quickstart

Expand Down Expand Up @@ -105,7 +106,7 @@ _Updated on 2024-01-08_
2. [ ] On Variables
3. [x] **Aggregate Functions**: Only `AVG()`, `MIN()`, `MAX()`, `SUM()` are supported.
1. [x] With Group By
2. [ ] Without Group By
2. [x] Without Group By
3. [ ] Multiple Aggregate function in a single query
4. [ ] **Order By**: Only supported for coordinates.
5. [ ] **Limit**: Limiting the result to display.
Expand Down
75 changes: 56 additions & 19 deletions xql/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import readline # noqa

import numpy as np
import pandas as pd
import typing as t
import xarray as xr

Expand All @@ -42,10 +43,16 @@
}

# Map SQL aggregate-function names to xarray reductions.
# Each callable reduces `x` (a Dataset or DatasetGroupBy) over dimension(s)
# `y`; xarray treats dim=None as "reduce over all dimensions", so no
# conditional fallback is needed.
# NOTE: the previous duplicate entries (stale `... if y else ...` lambdas,
# shadowed by the later keys via dict last-wins semantics) were removed.
aggregate_function_map = {
    'avg': lambda x, y: x.mean(dim=y),
    'min': lambda x, y: x.min(dim=y),
    'max': lambda x, y: x.max(dim=y),
    'sum': lambda x, y: x.sum(dim=y),
}

# strftime patterns used to bucket the 'time' coordinate when grouping by a
# derived time field (calendar day, month, or year).
timestamp_formats = {
    'time_date': "%Y-%m-%d",
    'time_month': "%Y-%m",
    'time_year': "%Y",
}

def parse(a: t.Union[xr.DataArray, str], b: t.Union[xr.DataArray, str]) -> t.Tuple[t.Union[xr.DataArray, str],
Expand Down Expand Up @@ -163,20 +170,20 @@ def apply_group_by(fields: t.List[str], ds: xr.Dataset, agg_funcs: t.Dict[str, s
"""

grouped_ds = ds
for field in fields:
if field in ds.coords:
grouped_ds = apply_aggregation(grouped_ds, list(agg_funcs.values())[0], field)
else:
field_parts = field.split("_")
groupby_field = field_parts[0]
if len(field_parts) > 1:
groupby_field = f"{groupby_field}.{field_parts[1]}"
groups = grouped_ds.groupby(groupby_field)
grouped_ds = apply_aggregation(groups, list(agg_funcs.values())[0])
time_fields = list(filter(lambda field: "time" in field, fields))

if len(time_fields) > 1:
raise NotImplementedError("GroupBy using multiple time fields is not supported.")

elif len(time_fields) == 1:
groups = grouped_ds.groupby(grouped_ds['time'].dt.strftime(timestamp_formats[time_fields[0]]))
grouped_ds = apply_aggregation(groups, list(agg_funcs.values())[0], None)
grouped_ds = grouped_ds.rename({"strftime" : time_fields[0]})

return grouped_ds


def apply_aggregation(groups: t.Union[xr.Dataset, DatasetGroupBy], fun: str, dim: t.Optional[str] = None) -> xr.Dataset:
def apply_aggregation(groups: t.Union[xr.Dataset, DatasetGroupBy], fun: str, dim: t.List[str] = []) -> xr.Dataset:
"""
Apply aggregation to the groups based on the specified aggregation function.
Expand All @@ -193,6 +200,26 @@ def apply_aggregation(groups: t.Union[xr.Dataset, DatasetGroupBy], fun: str, dim
return aggregate_function_map[fun](groups, dim)


def get_coords_to_squeeze(fields: t.List[str], ds: xr.Dataset) -> t.List[str]:
    """
    Collect the dataset coordinates that should be squeezed away.

    A coordinate qualifies when it is neither one of the requested
    fields nor the special 'time' coordinate.

    Args:
        fields (List[str]): List of field names.
        ds (xr.Dataset): The xarray dataset.

    Returns:
        List[str]: Coordinates of `ds` not referenced by `fields`
            (the 'time' coordinate is always kept).
    """
    requested = set(fields)
    coords_to_squeeze = []
    for coord in ds.coords:
        # 'time' is handled by the time-grouping logic, so never squeeze it.
        if coord == "time" or coord in requested:
            continue
        coords_to_squeeze.append(coord)
    return coords_to_squeeze


def get_table(e: exp.Expression) -> str:
"""
Get the table name from an expression.
Expand Down Expand Up @@ -248,9 +275,14 @@ def parse_query(query: str) -> xr.Dataset:
mask = inorder(where, ds)
ds = ds.where(mask, drop=True)

coord_to_squeeze = None
if group_by:
groupby_fields = [ e.args['this'].args['this'] for e in group_by.args['expressions'] ]
ds = apply_group_by(groupby_fields, ds, agg_funcs)
fields = [ e.args['this'].args['this'] for e in group_by.args['expressions'] ]
coord_to_squeeze = get_coords_to_squeeze(fields, ds)
ds = apply_group_by(fields, ds, agg_funcs)

if len(agg_funcs):
ds = apply_aggregation(ds, list(agg_funcs.values())[0], coord_to_squeeze)

return ds

Expand Down Expand Up @@ -330,7 +362,7 @@ def display_table_dataset_map(cmd: str) -> None:

while True:

query = input("xql>")
query = input("xql> ")

if query == ".exit":
break
Expand All @@ -351,6 +383,11 @@ def display_table_dataset_map(cmd: str) -> None:
result = f"ERROR: {type(e).__name__}: {e.__str__()}."

if isinstance(result, xr.Dataset):
print(result.to_dataframe())
if len(result.coords):
print(result.to_dataframe().reset_index())
else:
result = result.compute().to_dict(data="list")
df = pd.DataFrame({ k: [v['data']] for k, v in result['data_vars'].items() })
print(df)
else:
print(result)

0 comments on commit 21552c6

Please sign in to comment.