Skip to content

Commit

Permalink
Code Optimization
Browse files Browse the repository at this point in the history
  • Loading branch information
Darshan Prajapati committed Jan 1, 2024
1 parent 9208c50 commit 23e2996
Show file tree
Hide file tree
Showing 2 changed files with 107 additions and 29 deletions.
129 changes: 101 additions & 28 deletions xql/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import readline # noqa

import numpy as np
import xarray as xr
import typing as t
import xarray as xr

from sqlglot import parse_one, exp
from xarray.core.groupby import DatasetGroupBy

operate = {
"and" : lambda a, b: a & b,
Expand All @@ -31,11 +34,25 @@

# Maps an SQL aggregate name to a reducer taking (data, dim).
# When `dim` is truthy the reduction runs along that dimension; otherwise the
# same reduction is applied with no `dim` argument.
# Each key appears exactly once: earlier duplicate 'min'/'max' entries that
# fell back to `.mean()` were dead code shadowed by the entries below.
aggregate_function_map = {
    'avg': lambda data, dim: data.mean(dim=dim) if dim else data.mean(),
    'min': lambda data, dim: data.min(dim=dim) if dim else data.min(),
    'max': lambda data, dim: data.max(dim=dim) if dim else data.max(),
    'sum': lambda data, dim: data.sum(dim=dim) if dim else data.sum(),
}

def parse(a, b):
def parse(a: t.Union[xr.DataArray, str], b: t.Union[xr.DataArray, str]) -> t.Tuple[t.Union[xr.DataArray, str], t.Union[xr.DataArray, str]]:
"""
Parse input values 'a' and 'b' into NumPy arrays with compatible types for evaluation.
Parameters:
- a (Union[xr.DataArray, str]): The first input value.
- b (Union[xr.DataArray, str]): The second input value.
Returns:
- Tuple[xr.DataArray, Union[np.float64, np.float32, np.datetime64]]: Parsed NumPy arrays 'a' and 'b'.
"""

if isinstance(a, str):
a, b = b, a
arr_type = a.dtype.name
if arr_type == 'float64':
b = np.float64(b)
Expand All @@ -45,18 +62,42 @@ def parse(a, b):
b = np.datetime64(b)
return a, b

def evaluate(a: t.Union[xr.DataArray, str], b: t.Union[xr.DataArray, str], operator: str) -> xr.DataArray:
    """
    Evaluate the expression 'a operator b' and return the result.

    The operands are first passed through `parse`, then the callable stored
    under `operator` in the module-level `operate` table is applied to them.

    Parameters:
    - a (Union[xr.DataArray, str]): The first input value.
    - b (Union[xr.DataArray, str]): The second input value.
    - operator (str): Key of the operation to look up in `operate`.

    Returns:
    - xr.DataArray: The result of the evaluation.

    Raises:
    - KeyError: If `operator` is not a key of `operate`.
    """
    # NOTE(review): `parse` swaps the operands when `a` is a string, so the
    # operation may run as 'b operator a' -- confirm this is intended for
    # non-commutative operators.
    a, b = parse(a, b)
    return operate[operator](a, b)

def inorder(exp, ds):
if(exp.key == "identifier"):
return ds[exp.args['this']]

if(exp.key == "literal"):
return exp.args['this']
def inorder(expression: exp.Expression, ds: xr.Dataset) -> xr.DataArray:
"""
Evaluate an expression using an xarray Dataset and return the result.
Parameters:
- expression (exp.Expression): The expression to be evaluated.
- ds (xr.Dataset): The xarray Dataset used for evaluation.
Returns:
- xr.DataArray: The result of evaluating the expression on the given dataset.
"""

if(expression.key == "identifier"):
return ds[expression.args['this']]

if(expression.key == "literal"):
return expression.args['this']

args = exp.args
args = expression.args

left = args['this']
right = None
Expand All @@ -70,26 +111,54 @@ def inorder(exp, ds):
right_sol = inorder(right, ds)

if right_sol is not None:
return evaluate(left_sol, right_sol, exp.key)
return evaluate(left_sol, right_sol, expression.key)
else:
return left_sol

def apply_group_by(fields, ds: xr.Dataset, agg_funcs):

def apply_group_by(fields: t.List[str], ds: xr.Dataset, agg_funcs: t.Dict[str, str]) -> xr.Dataset:
"""
Apply group-by and aggregation operations to the dataset based on specified fields and aggregation functions.
Parameters:
- fields (List[str]): List of fields (variables or coordinates) to be used for grouping.
- ds (xarray.Dataset): The input dataset.
- agg_funcs (Dict[str, str]): Dictionary mapping aggregation function names to their corresponding xarray-compatible string representations.
Returns:
- xarray.Dataset: The dataset after applying group-by and aggregation operations.
"""

grouped_ds = ds
for field in fields:
field_parts = field.split("_")
groupby_field = field_parts[0]
if len(field_parts) > 1:
groupby_field = f"{groupby_field}.{field_parts[1]}"
groups = grouped_ds.groupby(groupby_field)
grouped_ds = apply_aggregation(groups, list(agg_funcs.values())[0])
else:
if field in ds.coords:
grouped_ds = apply_aggregation(grouped_ds, list(agg_funcs.values())[0], field)
else:
field_parts = field.split("_")
groupby_field = field_parts[0]
if len(field_parts) > 1:
groupby_field = f"{groupby_field}.{field_parts[1]}"
groups = grouped_ds.groupby(groupby_field)
grouped_ds = apply_aggregation(groups, list(agg_funcs.values())[0])
return grouped_ds

def apply_aggregation(groups: t.Union[xr.Dataset, DatasetGroupBy], fun: str, dim: t.Optional[str] = None) -> xr.Dataset:
    """
    Apply the aggregation named by 'fun' to a dataset or dataset group-by object.

    The reducer is looked up in the module-level `aggregate_function_map` and
    called with `groups` and `dim`.

    Parameters:
    - groups (Union[xr.Dataset, DatasetGroupBy]): The input dataset or dataset groupby object.
    - fun (str): Key of the aggregation function in `aggregate_function_map`.
    - dim (Optional[str]): The dimension along which to aggregate. If None (or
      otherwise falsy), the reducer is called without a `dim` argument.

    Returns:
    - xr.Dataset: The dataset after applying the aggregation.

    Raises:
    - KeyError: If `fun` is not a key of `aggregate_function_map`.
    """
    return aggregate_function_map[fun](groups, dim)


def parse_query(query: str) -> xr.Dataset:

expr = parse_one(query)
Expand Down Expand Up @@ -117,20 +186,21 @@ def parse_query(query: str) -> xr.Dataset:
if len(agg_funcs):
data_vars = agg_funcs.keys()

ds = xr.open_zarr(table, chunks=None)
ds = xr.open_zarr(table)

if is_star is None:
ds = ds[data_vars]

mask = inorder(where, ds)

filterd_ds = ds.where(mask, drop=True)
if where:
mask = inorder(where, ds)
ds = ds.where(mask, drop=True)

if group_by:
groupby_fields = [ e.args['this'].args['this'] for e in group_by.args['expressions'] ]
filterd_ds = apply_group_by(groupby_fields, filterd_ds, agg_funcs)
ds = apply_group_by(groupby_fields, ds, agg_funcs)

return ds

return filterd_ds

if __name__ == "__main__":

Expand All @@ -141,7 +211,10 @@ def parse_query(query: str) -> xr.Dataset:
if query == "exit":
break

result = parse_query(query)
try:
result = parse_query(query)
except:
result = "Something wrong with the query."

if isinstance(result, xr.Dataset):
print(result.to_dataframe())
Expand Down
7 changes: 6 additions & 1 deletion xql/requirements.txt
Original file line number Diff line number Diff line change
@@ -1 +1,6 @@
sqlglot
fsspec
gcsfs
numpy
sqlglot
xarray
zarr

0 comments on commit 23e2996

Please sign in to comment.