Skip to content

Commit

Permalink
Aggregate Function with group by running.
Browse files Browse the repository at this point in the history
  • Loading branch information
dabhicusp committed Jan 18, 2024
1 parent c79b28c commit f7d23e0
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 19 deletions.
4 changes: 2 additions & 2 deletions xql/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -77,9 +77,9 @@ python xql/main.py
time >= '2022-01-01' AND
time < '2022-02-01' AND
latitude >= 66.5
GROUP BY time_day
GROUP BY time_date
```
Replace `time_day` to `time_month` or `time_year` if monthly or yearly average is needed. Also use `MIN()` and `MAX()` functions same way as `AVG()`.
Replace `time_date` with `time_month` or `time_year` if a monthly or yearly average is needed. The `MIN()` and `MAX()` functions can be used in the same way as `AVG()`.
3. `caveat`: The above queries run on the client's local machine and generate a large two-dimensional array, so querying a very large amount of data will result in out-of-memory errors.
Expand Down
47 changes: 30 additions & 17 deletions xql/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,37 +155,46 @@ def apply_order_by(fields: t.List[str], ds: xr.Dataset) -> xr.Dataset:
return ordered_ds


def apply_group_by(fields: t.List[str], ds: xr.Dataset, agg_funcs: t.Dict[str, str]) -> xr.Dataset:
def apply_group_by(time_fields: t.List[str], ds: xr.Dataset, agg_funcs: t.Dict[str, str], dim: t.List[str] = []) -> xr.Dataset:
"""
Apply group-by and aggregation operations to the dataset based on specified fields and aggregation functions.
Parameters:
- fields (List[str]): List of fields (variables or coordinates) to be used for grouping.
- time_fields (List[str]): List of time fields (coordinates) to be used for grouping.
- ds (xarray.Dataset): The input dataset.
- agg_funcs (Dict[str, str]): Dictionary mapping aggregation function names to their corresponding
xarray-compatible string representations.
- dim (List[str]): The dimension(s) along which to apply the aggregation. If None, aggregation is applied
to the entire dataset.
Returns:
- xarray.Dataset: The dataset after applying group-by and aggregation operations.
"""

grouped_ds = ds
time_fields = list(filter(lambda field: "time" in field, fields))

if len(time_fields) > 1:
raise NotImplementedError("GroupBy using multiple time fields is not supported.")

elif len(time_fields) == 1:
groups = grouped_ds.groupby(grouped_ds['time'].dt.strftime(timestamp_formats[time_fields[0]]))
grouped_ds = apply_aggregation(groups, list(agg_funcs.values())[0], None)
grouped_ds = grouped_ds.rename({"strftime" : time_fields[0]})
agg_datasets = []
for agg_func in agg_funcs:
grouped_ds = ds[[agg_func['var']]]
groups = grouped_ds.groupby(grouped_ds['time'].dt.strftime(timestamp_formats[time_fields[0]]))
grouped_ds = apply_aggregation(groups, agg_func['func'], None)
grouped_ds = grouped_ds.rename({"strftime" : time_fields[0]})
grouped_ds = apply_aggregation(grouped_ds, agg_func['func'], dim)
grouped_ds = grouped_ds.rename({agg_func['var'] : f"{agg_func['func']}_{agg_func['var']}"})
agg_datasets.append(grouped_ds)
grouped_ds = xr.merge(agg_datasets)


return grouped_ds


def apply_aggregation(groups: t.Union[xr.Dataset, DatasetGroupBy], fun: str, dim: t.List[str] = []) -> xr.Dataset:
"""
Apply aggregation to the groups based on the specified aggregation function.
Apply aggregation to the groups based on the specified aggregation function.
Parameters:
- groups (Union[xr.Dataset, xr.core.groupby.DatasetGroupBy]): The input dataset or dataset groupby object.
Expand Down Expand Up @@ -282,19 +291,23 @@ def parse_query(query: str) -> xr.Dataset:
ds = ds.where(mask, drop=True)

coord_to_squeeze = None
time_fields = []
if group_by:
fields = [ e.args['this'].args['this'] for e in group_by.args['expressions'] ]
time_fields = list(filter(lambda field: "time" in field, fields))
coord_to_squeeze = get_coords_to_squeeze(fields, ds)
ds = apply_group_by(fields, ds, agg_funcs)

agg_datasets = []
for agg_func in agg_funcs:
key, value = agg_func['var'], agg_func['func']
agg_result = apply_aggregation(ds[[key]], value, coord_to_squeeze)
agg_result = agg_result.rename({key : f'{value}_{key}'})
agg_datasets.append(agg_result)
if len(agg_funcs):
ds = xr.merge(agg_datasets)
ds = apply_group_by(time_fields, ds, agg_funcs, coord_to_squeeze)

if len(time_fields) == 0:
coord_to_squeeze.append('time')
agg_datasets = []
for agg_func in agg_funcs:
key, value = agg_func['var'], agg_func['func']
agg_result = apply_aggregation(ds[[key]], value, coord_to_squeeze)
agg_result = agg_result.rename({key : f'{value}_{key}'})
agg_datasets.append(agg_result)
if len(agg_funcs):
ds = xr.merge(agg_datasets)

return ds

Expand Down

0 comments on commit f7d23e0

Please sign in to comment.