diff --git a/xql/README.md b/xql/README.md
index 0f41a527..b07d2035 100644
--- a/xql/README.md
+++ b/xql/README.md
@@ -11,6 +11,7 @@ Running SQL like queries on Xarray Datasets. Consider dataset as a table and dat
 * **`aggregate` Functions** - Aggregate functions `AVG()`, `MIN()`, `MAX()`, etc. Only supported on data variables.
 * For more, check out the [road-map](https://github.com/google/weather-tools/tree/xql-init/xql#roadmap).
 > Note: For now, we support `where` conditions on coordinates only.
+> Note: For now, only a single aggregate function is supported per query.
 
 # Quickstart
 
@@ -105,7 +106,7 @@ _Updated on 2024-01-08_
 2. [ ] On Variables
 3. [x] **Aggregate Functions**: Only `AVG()`, `MIN()`, `MAX()`, `SUM()` are supported.
     1. [x] With Group By
-    2. [ ] Without Group By
+    2. [x] Without Group By
     3. [ ] Multiple aggregate functions in a single query
 4. [ ] **Order By**: Only supported for coordinates.
 5. [ ] **Limit**: Limit the number of results displayed.
diff --git a/xql/main.py b/xql/main.py
index d8a1a7ec..bf9ec61d 100644
--- a/xql/main.py
+++ b/xql/main.py
@@ -16,6 +16,7 @@
 import readline # noqa
 
 import numpy as np
+import pandas as pd
 import typing as t
 import xarray as xr
 
@@ -42,10 +43,16 @@
 }
 
 aggregate_function_map = {
-    'avg': lambda x, y: x.mean(dim=y) if y else x.mean(),
-    'min': lambda x, y: x.min(dim=y) if y else x.min(),
-    'max': lambda x, y: x.max(dim=y) if y else x.max(),
-    'sum': lambda x, y: x.sum(dim=y) if y else x.sum(),
+    'avg': lambda x, y: x.mean(dim=y),
+    'min': lambda x, y: x.min(dim=y),
+    'max': lambda x, y: x.max(dim=y),
+    'sum': lambda x, y: x.sum(dim=y),
+}
+
+timestamp_formats = {
+    'time_date': "%Y-%m-%d",
+    'time_month': "%Y-%m",
+    'time_year': "%Y",
 }
 
 def parse(a: t.Union[xr.DataArray, str], b: t.Union[xr.DataArray, str]) -> t.Tuple[t.Union[xr.DataArray, str],
@@ -163,20 +170,20 @@ def apply_group_by(fields: t.List[str], ds: xr.Dataset, agg_funcs: t.Dict[str, s
     """
     grouped_ds = ds
 
-    for field in fields:
-        if field in ds.coords:
-            grouped_ds = apply_aggregation(grouped_ds, list(agg_funcs.values())[0], field)
-        else:
-            field_parts = field.split("_")
-            groupby_field = field_parts[0]
-            if len(field_parts) > 1:
-                groupby_field = f"{groupby_field}.{field_parts[1]}"
-            groups = grouped_ds.groupby(groupby_field)
-            grouped_ds = apply_aggregation(groups, list(agg_funcs.values())[0])
+    time_fields = [field for field in fields if "time" in field]
+
+    if len(time_fields) > 1:
+        raise NotImplementedError("GroupBy using multiple time fields is not supported.")
+
+    elif len(time_fields) == 1:
+        groups = grouped_ds.groupby(grouped_ds['time'].dt.strftime(timestamp_formats[time_fields[0]]))
+        grouped_ds = apply_aggregation(groups, list(agg_funcs.values())[0], None)
+        grouped_ds = grouped_ds.rename({"strftime": time_fields[0]})
+
     return grouped_ds
 
 
-def apply_aggregation(groups: t.Union[xr.Dataset, DatasetGroupBy], fun: str, dim: t.Optional[str] = None) -> xr.Dataset:
+def apply_aggregation(groups: t.Union[xr.Dataset, DatasetGroupBy], fun: str, dim: t.Optional[t.List[str]] = None) -> xr.Dataset:
     """
     Apply aggregation to the groups based on the specified aggregation function.
 
@@ -193,6 +200,26 @@ def apply_aggregation(groups: t.Union[xr.Dataset, DatasetGroupBy], fun: str, dim
     return aggregate_function_map[fun](groups, dim)
 
 
+def get_coords_to_squeeze(fields: t.List[str], ds: xr.Dataset) -> t.List[str]:
+    """
+    Get the coordinates to squeeze from an xarray dataset.
+
+    The function identifies coordinates in the dataset that are not part of the specified fields
+    and are not the 'time' coordinate.
+
+    Args:
+        fields (List[str]): List of field names.
+        ds (xr.Dataset): The xarray dataset.
+
+    Returns:
+        List[str]: List of coordinates to squeeze.
+    """
+    # Identify coordinates not in fields and not 'time'.
+    coords_to_squeeze = [coord for coord in ds.coords if coord not in fields and coord != "time"]
+
+    return coords_to_squeeze
+
+
 def get_table(e: exp.Expression) -> str:
     """
     Get the table name from an expression.
 
@@ -248,9 +275,14 @@ def parse_query(query: str) -> xr.Dataset:
         mask = inorder(where, ds)
         ds = ds.where(mask, drop=True)
 
+    coord_to_squeeze = None
     if group_by:
-        groupby_fields = [ e.args['this'].args['this'] for e in group_by.args['expressions'] ]
-        ds = apply_group_by(groupby_fields, ds, agg_funcs)
+        fields = [e.args['this'].args['this'] for e in group_by.args['expressions']]
+        coord_to_squeeze = get_coords_to_squeeze(fields, ds)
+        ds = apply_group_by(fields, ds, agg_funcs)
+
+    if agg_funcs:
+        ds = apply_aggregation(ds, list(agg_funcs.values())[0], coord_to_squeeze)
 
     return ds
 
@@ -330,7 +362,7 @@ def display_table_dataset_map(cmd: str) -> None:
 
     while True:
 
-        query = input("xql>")
+        query = input("xql> ")
 
         if query == ".exit":
             break
@@ -351,6 +383,11 @@ def display_table_dataset_map(cmd: str) -> None:
             result = f"ERROR: {type(e).__name__}: {e.__str__()}."
 
         if isinstance(result, xr.Dataset):
-            print(result.to_dataframe())
+            if result.coords:
+                print(result.to_dataframe().reset_index())
+            else:
+                result = result.compute().to_dict(data="list")
+                df = pd.DataFrame({k: [v['data']] for k, v in result['data_vars'].items()})
+                print(df)
         else:
             print(result)
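For reviewers: below is a minimal standalone sketch of the two techniques this change combines, strftime-based time grouping and aggregating away the remaining coordinates. The dataset, the `temperature` variable, and the `latitude` coordinate are synthetic stand-ins for illustration, not part of the repo.

```python
import numpy as np
import pandas as pd
import xarray as xr

# Synthetic stand-in dataset: two months of daily values on a tiny grid.
times = pd.date_range("2021-01-01", "2021-02-28", freq="D")
ds = xr.Dataset(
    {"temperature": (("time", "latitude"), 280 + 10 * np.random.rand(len(times), 3))},
    coords={"time": times, "latitude": [10.0, 20.0, 30.0]},
)

# GROUP BY time_month: strftime yields a string DataArray named 'strftime';
# grouping by it and renaming produces the user-facing field, as in
# apply_group_by above. "%Y-%m" is timestamp_formats['time_month'].
monthly = ds.groupby(ds["time"].dt.strftime("%Y-%m")).mean()
monthly = monthly.rename({"strftime": "time_month"})

# Coordinates absent from the GROUP BY list (here 'latitude') are then
# reduced with the same aggregate, mirroring get_coords_to_squeeze()
# followed by apply_aggregation().
result = monthly.mean(dim="latitude")
print(result.to_dataframe().reset_index())
```

Running this prints one averaged `temperature` row per `time_month` group, the same shape of output the updated REPL produces via `to_dataframe().reset_index()`.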