diff --git a/docs/source/api.rst b/docs/source/api.rst
index 41ec70e7..d1a13328 100644
--- a/docs/source/api.rst
+++ b/docs/source/api.rst
@@ -8,14 +8,5 @@ For more details and examples, refer to the relevant chapters in the main part o
 ESM Datastore (intake.open_esm_datastore)
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
-.. module:: intake_esm.core
-
-.. autoclass:: esm_datastore
-
-   .. automethod:: __getitem__
-   .. automethod:: from_df
-   .. automethod:: to_dataset_dict
-   .. automethod:: nunique
-   .. automethod:: unique
-   .. automethod:: search
-   .. automethod:: serialize
+.. automodule:: intake_esm.core
+   :members:
diff --git a/intake_esm/core.py b/intake_esm/core.py
index dd852ac7..acae1492 100644
--- a/intake_esm/core.py
+++ b/intake_esm/core.py
@@ -1,6 +1,7 @@
 import json
 import logging
-from typing import Any, Dict, List, Union
+from collections import namedtuple
+from typing import Any, Dict, List, Tuple, Union
 from warnings import warn
 
 import dask
@@ -112,10 +113,52 @@ def __init__(
         super(esm_datastore, self).__init__(**kwargs)
 
     def _set_groups_and_keys(self):
-        self._grouped = self.df.groupby(self.aggregation_info['groupby_attrs'])
+        self._grouped = self.df.groupby(self.aggregation_info.groupby_attrs)
         self._keys = list(self._grouped.groups.keys())
 
+    def _allnan_or_nonan(self, column: str) -> bool:
+        """
+        Helper function used to filter groupby_attrs, ensuring that no column is all NaNs.
+
+        Parameters
+        ----------
+        column : str
+            Column name
+
+        Returns
+        -------
+        bool
+            True if the column contains no NaN values; False if it is all NaNs.
+
+        Raises
+        ------
+        ValueError
+            When the column has a mix of NaN and non-NaN values.
+        """
+        if self.df[column].isnull().all():
+            return False
+        if self.df[column].isnull().any():
+            raise ValueError(
+                f'The data in the {column} column should either be all NaN or there should be no NaNs'
+            )
+        return True
+
     def _get_aggregation_info(self):
+
+        AggregationInfo = namedtuple(
+            'AggregationInfo',
+            [
+                'groupby_attrs',
+                'variable_column_name',
+                'aggregations',
+                'agg_columns',
+                'aggregation_dict',
+                'path_column_name',
+                'data_format',
+                'format_column_name',
+            ],
+        )
+
         groupby_attrs = []
         data_format = None
         format_column_name = None
@@ -131,45 +174,27 @@ def _get_aggregation_info(self):
             format_column_name = self.esmcol_data['assets']['format_column_name']
 
         if 'aggregation_control' in self.esmcol_data:
-            aggregation_dict = {}
             variable_column_name = self.esmcol_data['aggregation_control']['variable_column_name']
             groupby_attrs = self.esmcol_data['aggregation_control'].get('groupby_attrs', [])
             aggregations = self.esmcol_data['aggregation_control'].get('aggregations', [])
-            # Sort aggregations to make sure join_existing is always done before join_new
-            aggregations = sorted(aggregations, key=lambda i: i['type'], reverse=True)
-            for agg in aggregations:
-                key = agg['attribute_name']
-                rest = agg.copy()
-                del rest['attribute_name']
-                aggregation_dict[key] = rest
-            agg_columns = list(aggregation_dict.keys())
+            aggregations, aggregation_dict, agg_columns = _construct_agg_info(aggregations)
 
         if not groupby_attrs:
             groupby_attrs = self.df.columns.tolist()
 
-        # filter groupby_attrs to ensure no columns with all nans
-        def _allnan_or_nonan(column):
-            if self.df[column].isnull().all():
-                return False
-            if self.df[column].isnull().any():
-                raise ValueError(
-                    f'The data in the {column} column should either be all NaN or there should be no NaNs'
-                )
-            return True
-
-        groupby_attrs = list(filter(_allnan_or_nonan, groupby_attrs))
-
-        info = {
-            'groupby_attrs': groupby_attrs,
-            'variable_column_name': variable_column_name,
-            'aggregations': aggregations,
-            'agg_columns': agg_columns,
-            'aggregation_dict': aggregation_dict,
-            'path_column_name': path_column_name,
-            'data_format': data_format,
-            'format_column_name': format_column_name,
-        }
-        return info
+        groupby_attrs = list(filter(self._allnan_or_nonan, groupby_attrs))
+
+        aggregation_info = AggregationInfo(
+            groupby_attrs,
+            variable_column_name,
+            aggregations,
+            agg_columns,
+            aggregation_dict,
+            path_column_name,
+            data_format,
+            format_column_name,
+        )
+        return aggregation_info
 
     def keys(self) -> List:
         """
@@ -193,7 +218,98 @@ def key_template(self) -> str:
         str
           string template used to create catalog entry keys
         """
-        return self.sep.join(self.aggregation_info['groupby_attrs'])
+        return self.sep.join(self.aggregation_info.groupby_attrs)
+
+    @property
+    def df(self) -> pd.DataFrame:
+        """
+        Return pandas :py:class:`~pandas.DataFrame`.
+        """
+        return self._df
+
+    @df.setter
+    def df(self, value: pd.DataFrame):
+        self._df = value
+        self._set_groups_and_keys()
+
+    @property
+    def groupby_attrs(self) -> list:
+        """
+        Dataframe columns used to determine groups of compatible datasets.
+
+        Returns
+        -------
+        list
+            Columns used to determine groups of compatible datasets.
+        """
+        return self.aggregation_info.groupby_attrs
+
+    @groupby_attrs.setter
+    def groupby_attrs(self, value: list) -> None:
+        groupby_attrs = list(filter(self._allnan_or_nonan, value))
+        self.aggregation_info = self.aggregation_info._replace(groupby_attrs=groupby_attrs)
+        self._set_groups_and_keys()
+
+    @property
+    def variable_column_name(self) -> str:
+        """
+        Name of the column that contains the variable name.
+        """
+        return self.aggregation_info.variable_column_name
+
+    @variable_column_name.setter
+    def variable_column_name(self, value: str) -> None:
+        self.aggregation_info = self.aggregation_info._replace(variable_column_name=value)
+
+    @property
+    def aggregations(self):
+        return self.aggregation_info.aggregations
+
+    @property
+    def agg_columns(self) -> list:
+        """
+        List of columns used to merge/concatenate multiple compatible
+        :py:class:`~xarray.Dataset` objects into a single :py:class:`~xarray.Dataset`.
+        """
+        return self.aggregation_info.agg_columns
+
+    @property
+    def aggregation_dict(self) -> dict:
+        return self.aggregation_info.aggregation_dict
+
+    @property
+    def path_column_name(self) -> str:
+        """
+        The name of the column containing the path to the asset.
+        """
+        return self.aggregation_info.path_column_name
+
+    @path_column_name.setter
+    def path_column_name(self, value: str) -> None:
+        self.aggregation_info = self.aggregation_info._replace(path_column_name=value)
+
+    @property
+    def data_format(self) -> str:
+        """
+        The data format. Valid values are netcdf and zarr.
+        If specified, it means that all data assets in the catalog use the same data format.
+        """
+        return self.aggregation_info.data_format
+
+    @data_format.setter
+    def data_format(self, value: str) -> None:
+        self.aggregation_info = self.aggregation_info._replace(data_format=value)
+
+    @property
+    def format_column_name(self) -> str:
+        """
+        Name of the column that contains the data format.
+        """
+        return self.aggregation_info.format_column_name
+
+    @format_column_name.setter
+    def format_column_name(self, value: str) -> None:
+        self.aggregation_info = self.aggregation_info._replace(format_column_name=value)
 
     def __len__(self):
         return len(self.keys())
@@ -288,6 +404,14 @@ def __dir__(self):
             'unique',
             'nunique',
             'key_template',
+            'groupby_attrs',
+            'variable_column_name',
+            'aggregations',
+            'agg_columns',
+            'aggregation_dict',
+            'path_column_name',
+            'data_format',
+            'format_column_name',
         ]
         return sorted(list(self.__dict__.keys()) + rv)
@@ -324,7 +448,7 @@ def from_df(
 
         Returns
         -------
-        intake_esm.core.esm_datastore
+        :py:class:`~intake_esm.core.esm_datastore`
            Catalog object
        """
        return cls(
@@ -336,18 +460,6 @@ def from_df(
             **kwargs,
         )
 
-    @property
-    def df(self) -> pd.DataFrame:
-        """
-        Return pandas dataframe.
-        """
-        return self._df
-
-    @df.setter
-    def df(self, value: pd.DataFrame):
-        self._df = value
-        self._set_groups_and_keys()
-
     def search(self, require_all_on: Union[str, List] = None, **query):
         """Search for entries in the catalog.
 
@@ -386,7 +498,7 @@ def search(self, require_all_on: Union[str, List] = None, **query):
         401   CMIP  BCC  BCC-CSM2-MR  ...  gn  gs://cmip6/CMIP/BCC/BCC-CSM2-MR/historical/r3i...  NaN
 
         The search method also accepts compiled regular expression objects
-        from :py:method:`~re.compile` as patterns.
+        from :py:func:`~re.compile` as patterns.
 
         >>> import re
         >>> # Let's search for variables containing "Frac" in their name
@@ -568,9 +680,9 @@ def to_dataset_dict(
         Parameters
         ----------
         zarr_kwargs : dict
-            Keyword arguments to pass to `xarray.open_zarr()` function
+            Keyword arguments to pass to :py:func:`~xarray.open_zarr` function
         cdf_kwargs : dict
-            Keyword arguments to pass to `xarray.open_dataset()` function
+            Keyword arguments to pass to :py:func:`~xarray.open_dataset` function
         preprocess : callable, optional
             If provided, call this function on each dataset prior to aggregation.
         aggregate : bool, optional
@@ -585,7 +697,7 @@ def to_dataset_dict(
         Returns
         -------
         dsets : dict
-            A dictionary of xarray :py:class:`~xarray.Dataset`s.
+            A dictionary of xarray :py:class:`~xarray.Dataset`.
 
         Examples
         --------
@@ -671,14 +783,53 @@ def _load_source(source):
 
 def _make_entry(key, df, aggregation_info):
     args = dict(
         df=df,
-        aggregation_dict=aggregation_info['aggregation_dict'],
-        path_column=aggregation_info['path_column_name'],
-        variable_column=aggregation_info['variable_column_name'],
-        data_format=aggregation_info['data_format'],
-        format_column=aggregation_info['format_column_name'],
+        aggregation_dict=aggregation_info.aggregation_dict,
+        path_column=aggregation_info.path_column_name,
+        variable_column=aggregation_info.variable_column_name,
+        data_format=aggregation_info.data_format,
+        format_column=aggregation_info.format_column_name,
         key=key,
     )
     entry = intake.catalog.local.LocalCatalogEntry(
         name=key, description='', driver='esm_group', args=args, metadata={}
     )
     return entry.get()
+
+
+def _construct_agg_info(aggregations: List[Dict]) -> Tuple[List[Dict], Dict, List]:
+    """
+    Helper function used to determine the aggregation columns and their
+    respective settings.
+
+    Examples
+    --------
+
+    >>> a = [{'type': 'union', 'attribute_name': 'variable_id'},
+    ...      {'type': 'join_new',
+    ...       'attribute_name': 'member_id',
+    ...       'options': {'coords': 'minimal', 'compat': 'override'}},
+    ...      {'type': 'join_new',
+    ...       'attribute_name': 'dcpp_init_year',
+    ...       'options': {'coords': 'minimal', 'compat': 'override'}}]
+    >>> aggregations, aggregation_dict, agg_columns = _construct_agg_info(a)
+    >>> agg_columns
+    ['variable_id', 'member_id', 'dcpp_init_year']
+    >>> aggregation_dict
+    {'variable_id': {'type': 'union'},
+     'member_id': {'type': 'join_new',
+      'options': {'coords': 'minimal', 'compat': 'override'}},
+     'dcpp_init_year': {'type': 'join_new',
+      'options': {'coords': 'minimal', 'compat': 'override'}}}
+    """
+    agg_columns = []
+    aggregation_dict = {}
+    if aggregations:
+        # Sort aggregations to make sure join_existing is always done before join_new
+        aggregations = sorted(aggregations, key=lambda i: i['type'], reverse=True)
+        for agg in aggregations:
+            key = agg['attribute_name']
+            rest = agg.copy()
+            del rest['attribute_name']
+            aggregation_dict[key] = rest
+        agg_columns = list(aggregation_dict.keys())
+    return aggregations, aggregation_dict, agg_columns
diff --git a/tests/test_core.py b/tests/test_core.py
index 5b2b7b49..da04f638 100644
--- a/tests/test_core.py
+++ b/tests/test_core.py
@@ -124,6 +124,123 @@ def test_df_property():
     assert len(col.df) == 2
 
 
+@pytest.mark.parametrize(
+    'property, expected',
+    [
+        ('groupby_attrs', ['component', 'experiment', 'frequency']),
+        ('variable_column_name', 'variable'),
+        (
+            'aggregations',
+            [{'type': 'union', 'attribute_name': 'variable', 'options': {'compat': 'override'}}],
+        ),
+        ('agg_columns', ['variable']),
+        ('aggregation_dict', {'variable': {'type': 'union', 'options': {'compat': 'override'}}}),
+        ('path_column_name', 'path'),
+        ('data_format', 'zarr'),
+        ('format_column_name', None),
+    ],
+)
+def test_aggregation_properties(property, expected):
+    col = intake.open_esm_datastore(catalog_dict_records)
+    value = getattr(col, property)
+    assert value == expected
+
+
+@pytest.mark.parametrize(
+    'aggregations, expected_aggregations, expected_aggregation_dict, expected_agg_columns',
+    [
+        (
+            [
+                {'type': 'union', 'attribute_name': 'variable_id'},
+                {
+                    'type': 'join_new',
+                    'attribute_name': 'member_id',
+                    'options': {'coords': 'minimal', 'compat': 'override'},
+                },
+                {
+                    'type': 'join_new',
+                    'attribute_name': 'dcpp_init_year',
+                    'options': {'coords': 'minimal', 'compat': 'override'},
+                },
+            ],
+            [
+                {'type': 'union', 'attribute_name': 'variable_id'},
+                {
+                    'type': 'join_new',
+                    'attribute_name': 'member_id',
+                    'options': {'coords': 'minimal', 'compat': 'override'},
+                },
+                {
+                    'type': 'join_new',
+                    'attribute_name': 'dcpp_init_year',
+                    'options': {'coords': 'minimal', 'compat': 'override'},
+                },
+            ],
+            {
+                'variable_id': {'type': 'union'},
+                'member_id': {
+                    'type': 'join_new',
+                    'options': {'coords': 'minimal', 'compat': 'override'},
+                },
+                'dcpp_init_year': {
+                    'type': 'join_new',
+                    'options': {'coords': 'minimal', 'compat': 'override'},
+                },
+            },
+            ['variable_id', 'member_id', 'dcpp_init_year'],
+        ),
+        (
+            [
+                {
+                    'type': 'join_new',
+                    'attribute_name': 'member_id',
+                    'options': {'coords': 'minimal', 'compat': 'override'},
+                },
+                {'type': 'union', 'attribute_name': 'variable_id'},
+                {
+                    'type': 'join_new',
+                    'attribute_name': 'dcpp_init_year',
+                    'options': {'coords': 'minimal', 'compat': 'override'},
+                },
+            ],
+            [
+                {'type': 'union', 'attribute_name': 'variable_id'},
+                {
+                    'type': 'join_new',
+                    'attribute_name': 'member_id',
+                    'options': {'coords': 'minimal', 'compat': 'override'},
+                },
+                {
+                    'type': 'join_new',
+                    'attribute_name': 'dcpp_init_year',
+                    'options': {'coords': 'minimal', 'compat': 'override'},
+                },
+            ],
+            {
+                'variable_id': {'type': 'union'},
+                'member_id': {
+                    'type': 'join_new',
+                    'options': {'coords': 'minimal', 'compat': 'override'},
+                },
+                'dcpp_init_year': {
+                    'type': 'join_new',
+                    'options': {'coords': 'minimal', 'compat': 'override'},
+                },
+            },
+            ['variable_id', 'member_id', 'dcpp_init_year'],
+        ),
+        ([], [], {}, []),
+    ],
+)
+def test_construct_agg_info(
+    aggregations, expected_aggregations, expected_aggregation_dict, expected_agg_columns
+):
+    r_agg, r_agg_dict, r_agg_columns = intake_esm.core._construct_agg_info(aggregations)
+    assert r_agg == expected_aggregations
+    assert r_agg_dict == expected_aggregation_dict
+    assert r_agg_columns == expected_agg_columns
+
+
 def test_serialize_to_json():
     with TemporaryDirectory() as local_store:
        col = intake.open_esm_datastore(catalog_dict_records)
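
A minimal usage sketch of the property accessors and setters introduced above. The spec file name is hypothetical, and the example column values follow the test catalog used in tests/test_core.py; any ESM collection spec will do:

import intake

col = intake.open_esm_datastore('collection-spec.json')  # hypothetical spec file

# Each property is a read-only view into the AggregationInfo namedtuple.
col.path_column_name   # e.g. 'path'
col.groupby_attrs      # e.g. ['component', 'experiment', 'frequency']

# Setters rebuild the namedtuple via ._replace(); assigning to groupby_attrs
# also filters out all-NaN columns and recomputes the catalog's groups and keys.
col.groupby_attrs = ['component', 'frequency']
col.key_template       # 'component.frequency'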