Expose attributes used when aggregating/combining datasets (#268)
andersy005 authored Aug 4, 2020
1 parent 20b5272 commit d7e8e82
Showing 3 changed files with 327 additions and 68 deletions.
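In short, this commit replaces the internal aggregation-info dict with a namedtuple and exposes its fields as read/write properties on esm_datastore. A minimal usage sketch of the new surface (the catalog URL and attribute values are illustrative, not part of this commit):

import intake

# Any ESM collection spec works; this URL is one commonly used example.
cat = intake.open_esm_datastore('https://storage.googleapis.com/cmip6/pangeo-cmip6.json')

# Read access to the aggregation metadata:
cat.groupby_attrs          # e.g. ['activity_id', 'institution_id', ...]
cat.variable_column_name   # e.g. 'variable_id'
cat.data_format            # e.g. 'zarr', when uniform across all assets

# Write access: assigning new groupby_attrs regroups the catalog
# and regenerates its keys.
cat.groupby_attrs = ['source_id', 'experiment_id']
cat.key_template           # 'source_id.experiment_id'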
13 changes: 2 additions & 11 deletions docs/source/api.rst
@@ -8,14 +8,5 @@ For more details and examples, refer to the relevant chapters in the main part o
ESM Datastore (intake.open_esm_datastore)
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. module:: intake_esm.core

.. autoclass:: esm_datastore

.. automethod:: __getitem__
.. automethod:: from_df
.. automethod:: to_dataset_dict
.. automethod:: nunique
.. automethod:: unique
.. automethod:: search
.. automethod:: serialize
.. automodule:: intake_esm.core
:members:
265 changes: 208 additions & 57 deletions intake_esm/core.py
@@ -1,6 +1,7 @@
import json
import logging
from typing import Any, Dict, List, Union
from collections import namedtuple
from typing import Any, Dict, List, Tuple, Union
from warnings import warn

import dask
@@ -112,10 +113,52 @@ def __init__(
super(esm_datastore, self).__init__(**kwargs)

def _set_groups_and_keys(self):
self._grouped = self.df.groupby(self.aggregation_info['groupby_attrs'])
self._grouped = self.df.groupby(self.aggregation_info.groupby_attrs)
self._keys = list(self._grouped.groups.keys())

def _allnan_or_nonan(self, column: str) -> bool:
"""
Helper function used to filter groupby_attrs to ensure no column contains all NaNs
Parameters
----------
column : str
Column name
Returns
-------
bool
Whether the dataframe column has all NaNs or no NaN values
Raises
------
ValueError
When the column has a mix of NaN and non-NaN values
"""
if self.df[column].isnull().all():
return False
if self.df[column].isnull().any():
raise ValueError(
f'The data in the {column} column should either be all NaN or there should be no NaNs'
)
return True
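A quick sketch of the three cases on a toy dataframe (illustrative, not part of the commit):

import numpy as np
import pandas as pd

df = pd.DataFrame(
    {
        'all_nan': [np.nan, np.nan],  # all NaN: filtered out (returns False)
        'no_nan': ['a', 'b'],         # no NaN: kept (returns True)
        'mixed': ['a', np.nan],       # mixed: rejected with ValueError
    }
)

df['all_nan'].isnull().all()  # True  -> column dropped from groupby_attrs
df['mixed'].isnull().any()    # True  -> would raise ValueError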

def _get_aggregation_info(self):

AggregationInfo = namedtuple(
'AggregationInfo',
[
'groupby_attrs',
'variable_column_name',
'aggregations',
'agg_columns',
'aggregation_dict',
'path_column_name',
'data_format',
'format_column_name',
],
)

groupby_attrs = []
data_format = None
format_column_name = None
@@ -131,45 +174,27 @@ def _get_aggregation_info(self):
format_column_name = self.esmcol_data['assets']['format_column_name']

if 'aggregation_control' in self.esmcol_data:
aggregation_dict = {}
variable_column_name = self.esmcol_data['aggregation_control']['variable_column_name']
groupby_attrs = self.esmcol_data['aggregation_control'].get('groupby_attrs', [])
aggregations = self.esmcol_data['aggregation_control'].get('aggregations', [])
# Sort aggregations to make sure join_existing is always done before join_new
aggregations = sorted(aggregations, key=lambda i: i['type'], reverse=True)
for agg in aggregations:
key = agg['attribute_name']
rest = agg.copy()
del rest['attribute_name']
aggregation_dict[key] = rest
agg_columns = list(aggregation_dict.keys())
aggregations, aggregation_dict, agg_columns = _construct_agg_info(aggregations)

if not groupby_attrs:
groupby_attrs = self.df.columns.tolist()

# filter groupby_attrs to ensure no columns with all nans
def _allnan_or_nonan(column):
if self.df[column].isnull().all():
return False
if self.df[column].isnull().any():
raise ValueError(
f'The data in the {column} column should either be all NaN or there should be no NaNs'
)
return True

groupby_attrs = list(filter(_allnan_or_nonan, groupby_attrs))

info = {
'groupby_attrs': groupby_attrs,
'variable_column_name': variable_column_name,
'aggregations': aggregations,
'agg_columns': agg_columns,
'aggregation_dict': aggregation_dict,
'path_column_name': path_column_name,
'data_format': data_format,
'format_column_name': format_column_name,
}
return info
groupby_attrs = list(filter(self._allnan_or_nonan, groupby_attrs))

aggregation_info = AggregationInfo(
groupby_attrs,
variable_column_name,
aggregations,
agg_columns,
aggregation_dict,
path_column_name,
data_format,
format_column_name,
)
return aggregation_info
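Switching from a plain dict to a namedtuple is what enables both the attribute-style access used throughout (aggregation_info.groupby_attrs) and the copy-on-write updates in the setters below. A standalone sketch of the _replace semantics those setters rely on (field names abbreviated for illustration):

from collections import namedtuple

Info = namedtuple('Info', ['groupby_attrs', 'data_format'])
info = Info(groupby_attrs=['source_id'], data_format='zarr')

# _replace returns a new tuple with the given fields swapped out;
# the original is left untouched.
updated = info._replace(data_format='netcdf')
assert info.data_format == 'zarr'
assert updated.data_format == 'netcdf'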

def keys(self) -> List:
"""
@@ -193,7 +218,98 @@ def key_template(self) -> str:
str
string template used to create catalog entry keys
"""
return self.sep.join(self.aggregation_info['groupby_attrs'])
return self.sep.join(self.aggregation_info.groupby_attrs)

@property
def df(self) -> pd.DataFrame:
"""
Return pandas :py:class:`~pandas.DataFrame`.
"""
return self._df

@df.setter
def df(self, value: pd.DataFrame):
self._df = value
self._set_groups_and_keys()

@property
def groupby_attrs(self) -> list:
"""
Dataframe columns used to determine groups of compatible datasets.
Returns
-------
list
Columns used to determine groups of compatible datasets.
"""
return self.aggregation_info.groupby_attrs

@groupby_attrs.setter
def groupby_attrs(self, value: list) -> None:
groupby_attrs = list(filter(self._allnan_or_nonan, value))
self.aggregation_info = self.aggregation_info._replace(groupby_attrs=groupby_attrs)
self._set_groups_and_keys()
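Note that this setter does more than store the value: it re-applies the all-NaN filter and regroups the dataframe, so the catalog's keys always match the active grouping. A hedged sketch of the expected effect (key values hypothetical):

cat.groupby_attrs = ['source_id', 'experiment_id']
list(cat.keys())  # e.g. ['CESM2.historical', 'CESM2.ssp585', ...]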

@property
def variable_column_name(self) -> str:
"""
Name of the column that contains the variable name.
"""
return self.aggregation_info.variable_column_name

@variable_column_name.setter
def variable_column_name(self, value: str) -> None:
self.aggregation_info = self.aggregation_info._replace(variable_column_name=value)

@property
def aggregations(self):
return self.aggregation_info.aggregations

@property
def agg_columns(self) -> list:
"""
List of columns used to merge/concatenate multiple compatible
:py:class:`~xarray.Dataset` objects into a single :py:class:`~xarray.Dataset`.
"""
return self.aggregation_info.agg_columns

@property
def aggregation_dict(self) -> dict:
return self.aggregation_info.aggregation_dict

@property
def path_column_name(self) -> str:
"""
The name of the column containing the path to the asset.
"""
return self.aggregation_info.path_column_name

@path_column_name.setter
def path_column_name(self, value: str) -> None:
self.aggregation_info = self.aggregation_info._replace(path_column_name=value)

@property
def data_format(self) -> str:
"""
The data format. Valid values are netcdf and zarr.
If specified, it means that all data assets in the catalog use the same data format.
"""
return self.aggregation_info.data_format

@data_format.setter
def data_format(self, value: str) -> None:
self.aggregation_info = self.aggregation_info._replace(data_format=value)

@property
def format_column_name(self) -> str:
"""
Name of the column that contains the data format.
"""
return self.aggregation_info.format_column_name

@format_column_name.setter
def format_column_name(self, value: str) -> None:
self.aggregation_info = self.aggregation_info._replace(format_column_name=value)
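In practice, these setters let a user adapt a datastore whose collection spec uses non-default column names without editing the JSON itself. A hedged example (column names hypothetical):

# Suppose the collection stores asset locations in a column named 'uri'
cat.path_column_name = 'uri'

# ... and per-asset formats in a column named 'fmt'
cat.format_column_name = 'fmt'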

def __len__(self):
return len(self.keys())
@@ -288,6 +404,14 @@ def __dir__(self):
'unique',
'nunique',
'key_template',
'groupby_attrs',
'variable_column_name',
'aggregations',
'agg_columns',
'aggregation_dict',
'path_column_name',
'data_format',
'format_column_name',
]
return sorted(list(self.__dict__.keys()) + rv)

@@ -324,7 +448,7 @@ def from_df(
Returns
-------
intake_esm.core.esm_datastore
:py:class:`~intake_esm.core.esm_datastore`
Catalog object
"""
return cls(
@@ -336,18 +460,6 @@
**kwargs,
)

@property
def df(self) -> pd.DataFrame:
"""
Return pandas dataframe.
"""
return self._df

@df.setter
def df(self, value: pd.DataFrame):
self._df = value
self._set_groups_and_keys()

def search(self, require_all_on: Union[str, List] = None, **query):
"""Search for entries in the catalog.
@@ -386,7 +498,7 @@ def search(self, require_all_on: Union[str, List] = None, **query):
401 CMIP BCC BCC-CSM2-MR ... gn gs://cmip6/CMIP/BCC/BCC-CSM2-MR/historical/r3i... NaN
The search method also accepts compiled regular expression objects
from :py:method:`~re.compile` as patterns.
from :py:func:`~re.compile` as patterns.
>>> import re
>>> # Let's search for variables containing "Frac" in their name
@@ -568,9 +680,9 @@ def to_dataset_dict(
Parameters
----------
zarr_kwargs : dict
Keyword arguments to pass to `xarray.open_zarr()` function
Keyword arguments to pass to :py:func:`~xarray.open_zarr` function
cdf_kwargs : dict
Keyword arguments to pass to `xarray.open_dataset()` function
Keyword arguments to pass to :py:func:`~xarray.open_dataset` function
preprocess : callable, optional
If provided, call this function on each dataset prior to aggregation.
aggregate : bool, optional
@@ -585,7 +697,7 @@
Returns
-------
dsets : dict
A dictionary of xarray :py:class:`~xarray.Dataset`s.
A dictionary of xarray :py:class:`~xarray.Dataset`.
Examples
--------
@@ -671,14 +783,53 @@ def _load_source(source):
def _make_entry(key, df, aggregation_info):
args = dict(
df=df,
aggregation_dict=aggregation_info['aggregation_dict'],
path_column=aggregation_info['path_column_name'],
variable_column=aggregation_info['variable_column_name'],
data_format=aggregation_info['data_format'],
format_column=aggregation_info['format_column_name'],
aggregation_dict=aggregation_info.aggregation_dict,
path_column=aggregation_info.path_column_name,
variable_column=aggregation_info.variable_column_name,
data_format=aggregation_info.data_format,
format_column=aggregation_info.format_column_name,
key=key,
)
entry = intake.catalog.local.LocalCatalogEntry(
name=key, description='', driver='esm_group', args=args, metadata={}
)
return entry.get()


def _construct_agg_info(aggregations: List[Dict]) -> Tuple[List[Dict], Dict, List]:
"""
Helper function used to determine the aggregation columns and their
respective settings.
Examples
--------
>>> a = [{'type': 'union', 'attribute_name': 'variable_id'},
... {'type': 'join_new',
... 'attribute_name': 'member_id',
... 'options': {'coords': 'minimal', 'compat': 'override'}},
... {'type': 'join_new',
... 'attribute_name': 'dcpp_init_year',
... 'options': {'coords': 'minimal', 'compat': 'override'}}]
>>> aggregations, aggregation_dict, agg_columns = _construct_agg_info(a)
>>> agg_columns
['variable_id', 'member_id', 'dcpp_init_year']
>>> aggregation_dict
{'variable_id': {'type': 'union'},
'member_id': {'type': 'join_new',
'options': {'coords': 'minimal', 'compat': 'override'}},
'dcpp_init_year': {'type': 'join_new',
'options': {'coords': 'minimal', 'compat': 'override'}}}
"""
agg_columns = []
aggregation_dict = {}
if aggregations:
# Sort aggregations to make sure join_existing is always done before join_new
aggregations = sorted(aggregations, key=lambda i: i['type'], reverse=True)
for agg in aggregations:
key = agg['attribute_name']
rest = agg.copy()
del rest['attribute_name']
aggregation_dict[key] = rest
agg_columns = list(aggregation_dict.keys())
return aggregations, aggregation_dict, agg_columns