Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Expose attributes used when aggregating/combining datasets #268

Merged
merged 3 commits into from
Aug 4, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 2 additions & 11 deletions docs/source/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,5 @@ For more details and examples, refer to the relevant chapters in the main part o
ESM Datastore (intake.open_esm_datastore)
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. module:: intake_esm.core

.. autoclass:: esm_datastore

.. automethod:: __getitem__
.. automethod:: from_df
.. automethod:: to_dataset_dict
.. automethod:: nunique
.. automethod:: unique
.. automethod:: search
.. automethod:: serialize
.. automodule:: intake_esm.core
:members:
265 changes: 208 additions & 57 deletions intake_esm/core.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import json
import logging
from typing import Any, Dict, List, Union
from collections import namedtuple
from typing import Any, Dict, List, Tuple, Union
from warnings import warn

import dask
Expand Down Expand Up @@ -112,10 +113,52 @@ def __init__(
super(esm_datastore, self).__init__(**kwargs)

def _set_groups_and_keys(self):
    """
    Group the catalog dataframe by the configured groupby attributes and
    cache the resulting group keys.

    Called whenever ``self.df`` is (re)assigned so that ``self._grouped``
    and ``self._keys`` stay in sync with the dataframe.
    """
    # aggregation_info is a namedtuple, so use attribute access
    # (the old dict-style aggregation_info['groupby_attrs'] no longer applies).
    self._grouped = self.df.groupby(self.aggregation_info.groupby_attrs)
    self._keys = list(self._grouped.groups.keys())

def _allnan_or_nonan(self, column: str) -> bool:
    """
    Helper function used to filter groupby_attrs to ensure no columns with all NaNs.

    Parameters
    ----------
    column : str
        Column name

    Returns
    -------
    bool
        True when the dataframe column has no NaN values, False when it is
        entirely NaN.

    Raises
    ------
    ValueError
        When the column has a mix of NaN and non-NaN values.
    """
    if self.df[column].isnull().all():
        return False
    if self.df[column].isnull().any():
        raise ValueError(
            f'The data in the {column} column should either be all NaN or there should be no NaNs'
        )
    return True

def _get_aggregation_info(self):

AggregationInfo = namedtuple(
'AggregationInfo',
[
'groupby_attrs',
'variable_column_name',
'aggregations',
'agg_columns',
'aggregation_dict',
'path_column_name',
'data_format',
'format_column_name',
],
)

groupby_attrs = []
data_format = None
format_column_name = None
Expand All @@ -131,45 +174,27 @@ def _get_aggregation_info(self):
format_column_name = self.esmcol_data['assets']['format_column_name']

if 'aggregation_control' in self.esmcol_data:
aggregation_dict = {}
variable_column_name = self.esmcol_data['aggregation_control']['variable_column_name']
groupby_attrs = self.esmcol_data['aggregation_control'].get('groupby_attrs', [])
aggregations = self.esmcol_data['aggregation_control'].get('aggregations', [])
# Sort aggregations to make sure join_existing is always done before join_new
aggregations = sorted(aggregations, key=lambda i: i['type'], reverse=True)
for agg in aggregations:
key = agg['attribute_name']
rest = agg.copy()
del rest['attribute_name']
aggregation_dict[key] = rest
agg_columns = list(aggregation_dict.keys())
aggregations, aggregation_dict, agg_columns = _construct_agg_info(aggregations)

if not groupby_attrs:
groupby_attrs = self.df.columns.tolist()

# filter groupby_attrs to ensure no columns with all nans
def _allnan_or_nonan(column):
if self.df[column].isnull().all():
return False
if self.df[column].isnull().any():
raise ValueError(
f'The data in the {column} column should either be all NaN or there should be no NaNs'
)
return True

groupby_attrs = list(filter(_allnan_or_nonan, groupby_attrs))

info = {
'groupby_attrs': groupby_attrs,
'variable_column_name': variable_column_name,
'aggregations': aggregations,
'agg_columns': agg_columns,
'aggregation_dict': aggregation_dict,
'path_column_name': path_column_name,
'data_format': data_format,
'format_column_name': format_column_name,
}
return info
groupby_attrs = list(filter(self._allnan_or_nonan, groupby_attrs))

aggregation_info = AggregationInfo(
groupby_attrs,
variable_column_name,
aggregations,
agg_columns,
aggregation_dict,
path_column_name,
data_format,
format_column_name,
)
return aggregation_info

def keys(self) -> List:
"""
Expand All @@ -193,7 +218,98 @@ def key_template(self) -> str:
str
string template used to create catalog entry keys
"""
return self.sep.join(self.aggregation_info['groupby_attrs'])
return self.sep.join(self.aggregation_info.groupby_attrs)

@property
def df(self) -> pd.DataFrame:
    """The catalog contents as a pandas :py:class:`~pandas.DataFrame`."""
    return self._df

@df.setter
def df(self, value: pd.DataFrame):
    # Refresh the cached groups/keys every time the dataframe is replaced.
    self._df = value
    self._set_groups_and_keys()

@property
def groupby_attrs(self) -> list:
    """
    Dataframe columns used to determine groups of compatible datasets.

    Returns
    -------
    list
        Columns used to determine groups of compatible datasets.
    """
    return self.aggregation_info.groupby_attrs

@groupby_attrs.setter
def groupby_attrs(self, value: list) -> None:
    # Drop columns that are entirely NaN; a mixed NaN/non-NaN column raises.
    usable = [col for col in value if self._allnan_or_nonan(col)]
    self.aggregation_info = self.aggregation_info._replace(groupby_attrs=usable)
    self._set_groups_and_keys()

@property
def variable_column_name(self) -> str:
    """Name of the dataframe column that contains the variable name."""
    return self.aggregation_info.variable_column_name

@variable_column_name.setter
def variable_column_name(self, value: str) -> None:
    # aggregation_info is an immutable namedtuple; build a new one.
    self.aggregation_info = self.aggregation_info._replace(variable_column_name=value)

@property
def aggregations(self):
    """Aggregation settings carried by ``aggregation_info``."""
    return self.aggregation_info.aggregations

@property
def agg_columns(self) -> list:
    """
    List of columns used to merge/concatenate compatible
    multiple :py:class:`~xarray.Dataset` into a single :py:class:`~xarray.Dataset`.
    """
    return self.aggregation_info.agg_columns

@property
def aggregation_dict(self) -> dict:
    """Mapping of aggregation column name to its aggregation settings."""
    return self.aggregation_info.aggregation_dict

@property
def path_column_name(self) -> str:
    """Name of the column containing the path to the asset."""
    return self.aggregation_info.path_column_name

@path_column_name.setter
def path_column_name(self, value: str) -> None:
    # Rebuild the immutable aggregation_info namedtuple with the new value.
    self.aggregation_info = self.aggregation_info._replace(path_column_name=value)

@property
def data_format(self) -> str:
    """
    The data format. Valid values are netcdf and zarr.
    If specified, it means that all data assets in the catalog use the same data format.
    """
    return self.aggregation_info.data_format

@data_format.setter
def data_format(self, value: str) -> None:
    # Rebuild the immutable aggregation_info namedtuple with the new value.
    self.aggregation_info = self.aggregation_info._replace(data_format=value)

@property
def format_column_name(self) -> str:
    """Name of the column which contains the data format."""
    return self.aggregation_info.format_column_name

@format_column_name.setter
def format_column_name(self, value: str) -> None:
    # Rebuild the immutable aggregation_info namedtuple with the new value.
    self.aggregation_info = self.aggregation_info._replace(format_column_name=value)

def __len__(self):
    """Number of catalog entries, i.e. the number of unique group keys."""
    entries = self.keys()
    return len(entries)
Expand Down Expand Up @@ -288,6 +404,14 @@ def __dir__(self):
'unique',
'nunique',
'key_template',
'groupby_attrs',
'variable_column_name',
'aggregations',
'agg_columns',
'aggregation_dict',
'path_column_name',
'data_format',
'format_column_name',
]
return sorted(list(self.__dict__.keys()) + rv)

Expand Down Expand Up @@ -324,7 +448,7 @@ def from_df(

Returns
-------
intake_esm.core.esm_datastore
:py:class:`~intake_esm.core.esm_datastore`
Catalog object
"""
return cls(
Expand All @@ -336,18 +460,6 @@ def from_df(
**kwargs,
)

@property
def df(self) -> pd.DataFrame:
    """Underlying pandas :py:class:`~pandas.DataFrame` of catalog entries."""
    return self._df

@df.setter
def df(self, value: pd.DataFrame):
    # Keep the cached groups and keys consistent with the new dataframe.
    self._df = value
    self._set_groups_and_keys()

def search(self, require_all_on: Union[str, List] = None, **query):
"""Search for entries in the catalog.

Expand Down Expand Up @@ -386,7 +498,7 @@ def search(self, require_all_on: Union[str, List] = None, **query):
401 CMIP BCC BCC-CSM2-MR ... gn gs://cmip6/CMIP/BCC/BCC-CSM2-MR/historical/r3i... NaN

The search method also accepts compiled regular expression objects
from :py:method:`~re.compile` as patterns.
from :py:func:`~re.compile` as patterns.

>>> import re
>>> # Let's search for variables containing "Frac" in their name
Expand Down Expand Up @@ -568,9 +680,9 @@ def to_dataset_dict(
Parameters
----------
zarr_kwargs : dict
Keyword arguments to pass to `xarray.open_zarr()` function
Keyword arguments to pass to :py:func:`~xarray.open_zarr` function
cdf_kwargs : dict
Keyword arguments to pass to `xarray.open_dataset()` function
Keyword arguments to pass to :py:func:`~xarray.open_dataset` function
preprocess : callable, optional
If provided, call this function on each dataset prior to aggregation.
aggregate : bool, optional
Expand All @@ -585,7 +697,7 @@ def to_dataset_dict(
Returns
-------
dsets : dict
A dictionary of xarray :py:class:`~xarray.Dataset`s.
A dictionary of xarray :py:class:`~xarray.Dataset`.

Examples
--------
Expand Down Expand Up @@ -671,14 +783,53 @@ def _load_source(source):
def _make_entry(key, df, aggregation_info):
    """
    Build and instantiate an intake ``esm_group`` catalog entry for one
    group of compatible assets.

    Parameters
    ----------
    key : str
        Unique key identifying this group of assets.
    df : pandas.DataFrame
        Subset of the catalog dataframe belonging to this group.
    aggregation_info : AggregationInfo
        Namedtuple carrying aggregation and column-name settings.

    Returns
    -------
    The instantiated data source for this catalog entry.
    """
    # aggregation_info is a namedtuple: use attribute access, not dict indexing.
    args = dict(
        df=df,
        aggregation_dict=aggregation_info.aggregation_dict,
        path_column=aggregation_info.path_column_name,
        variable_column=aggregation_info.variable_column_name,
        data_format=aggregation_info.data_format,
        format_column=aggregation_info.format_column_name,
        key=key,
    )
    entry = intake.catalog.local.LocalCatalogEntry(
        name=key, description='', driver='esm_group', args=args, metadata={}
    )
    return entry.get()


def _construct_agg_info(aggregations: List[Dict]) -> Tuple[List[Dict], Dict, List]:
    """
    Derive aggregation column names and their per-column settings from a
    list of aggregation specifications.

    Each specification is a dict containing an ``attribute_name`` key plus
    its remaining settings (``type``, ``options``, ...). Returns the sorted
    specification list, a mapping from attribute name to its settings, and
    the list of aggregation column names, e.g. for::

        [{'type': 'union', 'attribute_name': 'variable_id'},
         {'type': 'join_new', 'attribute_name': 'member_id',
          'options': {'coords': 'minimal', 'compat': 'override'}}]

    the returned ``agg_columns`` is ``['variable_id', 'member_id']`` and
    ``aggregation_dict`` maps each name to its settings minus
    ``attribute_name``.
    """
    aggregation_dict: Dict = {}
    agg_columns: List = []
    if aggregations:
        # Sort specs by 'type' in reverse so joins are ordered consistently.
        # NOTE(review): reverse lexicographic order yields
        # union > join_new > join_existing; the upstream comment claims
        # join_existing runs before join_new -- confirm intended order.
        aggregations = sorted(aggregations, key=lambda spec: spec['type'], reverse=True)
        for spec in aggregations:
            name = spec['attribute_name']
            settings = {k: v for k, v in spec.items() if k != 'attribute_name'}
            aggregation_dict[name] = settings
        agg_columns = list(aggregation_dict)
    return aggregations, aggregation_dict, agg_columns
Loading