Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

xarray parameters #404

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
119 changes: 119 additions & 0 deletions param/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -1658,6 +1658,125 @@ def _validate(self, val):
self._length_bounds_check(self.rows, len(val), 'Row')


class _XrBase:
"""
Provides mixin methods for DataArray and Dataset. Not useful on its own.
"""
__slots__ = []

def _validate_dims_coords(self, val):
if self.dims is not None:
self._validate_property(self.dims, val.dims, 'dimensions')

if self.coords is not None:
self._validate_property(self.coords, list(val.coords),
'coordinates')
if isinstance(self.coords, dict):
self._coord_check(self.coords, val.coords)


def _validate_property(self, expected, property, name):
difference = set(expected) - set([str(el) for el in property])
if difference:
msg = ('Provided DataArray {name} {found} do not contain '
jbednar marked this conversation as resolved.
Show resolved Hide resolved
'required {name} {expected}')
raise ValueError(msg.format(
found=list(property),
expected=sorted(expected),
name=name,
))

if isinstance(expected, (list, tuple)):
if not set(expected) == set(property):
msg = ('Provided DataArray {name} {found} must '
'exactly match {expected}')
raise ValueError(msg.format(
found=list(property),
expected=sorted(expected),
name=name,
))


def _coord_check(self, expected, coords):
if not all (coords[k].values.tolist() == list(v)
for k, v in expected.items()):
msg = 'Provided DataArray does not have expected coordinates'
raise ValueError(msg)


class DataArray(ClassSelector, _XrBase):
"""
Parameter whose value is an xarray DataArray.

dims: If specified, may be a tuple, list, or set. If a set is used, the supplied
DataArray must contain the specified dims and if a list or tuple is used, the
supplied DataArray must contain exactly the same dimensions and no other
dimensions.

coords: If specified, may be a set, tuple, list, or dict. For a set, tuple, or
list, the same validation is conducted as for dims (see above). For a dict, keys
must be coordinate names present in the supplied DataArray, and values must
match the respective coordinates exactly.
"""
__slots__ = ['dims', 'coords']

def __init__(self, default=None, dims=None, coords=None, **params):
from xarray import DataArray as xrArray
self.dims = dims
self.coords = coords
super(DataArray, self).__init__(xrArray, allow_None=True, default=default, **params)
self._validate(self.default)


def _validate(self, val):
if self.allow_None and val is None:
return

super(DataArray, self)._validate(val)
self._validate_dims_coords(val)


class Dataset(ClassSelector, _XrBase):
"""
Parameter whose value is an xarray Dataset.

dims: If specified, may be a tuple, list, or set. If a set is used, the supplied
Dataset must contain the specified dims and if a list or tuple is used, the
supplied Dataset must contain exactly the same dimensions and no other
dimensions.

coords: If specified, may be a set, tuple, list, or dict. For a set, tuple, or
list, the same validation is conducted as for dims (see above). For a dict, keys
must be coordinate names present in the supplied Dataset, and values must
match the respective coordinates exactly.

data_vars: Analogous to dims (see above).
"""
__slots__ = ['dims', 'coords', 'data_vars']

def __init__(self, default=None, dims=None, coords=None, data_vars=None,
**params):
from xarray import Dataset as xrDataset
self.dims = dims
self.coords = coords
self.data_vars = data_vars
super(Dataset, self).__init__(xrDataset, allow_None=True, default=default,
**params)
self._validate(self.default)


def _validate(self, val):
if self.allow_None and val is None:
return

super(Dataset, self)._validate(val)

self._validate_dims_coords(val)

if self.data_vars is not None:
self._validate_property(self.data_vars, val.data_vars,
'data variables')


# For portable code:
# - specify paths in unix (rather than Windows) style;
Expand Down