diff --git a/CITATION.cff b/CITATION.cff index 80e1830..fe78f22 100644 --- a/CITATION.cff +++ b/CITATION.cff @@ -3,10 +3,10 @@ message: "If you use this software, please cite it as below." title: "xarray-units" abstract: "xarray extension for handling units" -version: 0.1.0 -date-released: 2023-12-11 +version: 0.2.0 +date-released: 2023-12-18 license: "MIT" -doi: "" +doi: "10.5281/zenodo.10354517" url: "https://github.com/astropenguin/xarray-units/" authors: - given-names: "Akio" diff --git a/README.md b/README.md index 922194b..4bca70c 100644 --- a/README.md +++ b/README.md @@ -1,2 +1,15 @@ # xarray-units + +[![Release](https://img.shields.io/pypi/v/xarray-units?label=Release&color=cornflowerblue&style=flat-square)](https://pypi.org/project/xarray-units/) +[![Python](https://img.shields.io/pypi/pyversions/xarray-units?label=Python&color=cornflowerblue&style=flat-square)](https://pypi.org/project/xarray-units/) +[![Downloads](https://img.shields.io/pypi/dm/xarray-units?label=Downloads&color=cornflowerblue&style=flat-square)](https://pepy.tech/project/xarray-units) +[![DOI](https://img.shields.io/badge/DOI-10.5281/zenodo.10354517-cornflowerblue?style=flat-square)](https://doi.org/10.5281/zenodo.10354517) +[![Tests](https://img.shields.io/github/actions/workflow/status/astropenguin/xarray-units/tests.yaml?label=Tests&style=flat-square)](https://github.com/astropenguin/xarray-units/actions) + xarray extension for handling units + +## Installation + +```shell +pip install xarray-units==0.2.0 +``` diff --git a/pyproject.toml b/pyproject.toml index 3ebd823..9f38c81 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "xarray-units" -version = "0.1.0" +version = "0.2.0" description = "xarray extension for handling units" authors = ["Akio Taniguchi "] documentation = "https://astropenguin.github.io/xarray-units/" diff --git a/tests/test_operator.py b/tests/test_operator.py new file mode 100644 index 0000000..1163148 --- /dev/null +++ b/tests/test_operator.py @@ -0,0 +1,121 @@ +# standard library +from typing import Any + + +# dependencies +from astropy.units import Quantity +from pytest import mark, raises +from xarray import DataArray +from xarray.testing import assert_identical # type: ignore +from xarray_units import operator as opr +from xarray_units.operator import Operator, take +from xarray_units.quantity import set +from xarray_units.utils import UnitsApplicationError + + +# test data +km = set(DataArray([1, 2, 3]), "km") +mm = set(DataArray([1, 2, 3]) * 1e6, "mm") +sc_1 = 2 +sc_2 = Quantity(2000, "m") + + +data_take: list[tuple[DataArray, Operator, Any, Any]] = [ + (km, "mul", sc_1, set(DataArray([2, 4, 6]), "km")), + (km, "mul", sc_2, set(DataArray([2, 4, 6]) * 1e3, "km m")), + (km, "mul", mm, set(DataArray([1, 4, 9]) * 1e6, "km mm")), + (mm, "mul", km, set(DataArray([1, 4, 9]) * 1e6, "km mm")), + # + (km, "pow", sc_1, set(DataArray([1, 4, 9]), "km2")), + (km, "pow", sc_2, UnitsApplicationError), + (km, "pow", mm, UnitsApplicationError), + (mm, "pow", km, UnitsApplicationError), + # + (km, "matmul", sc_1, UnitsApplicationError), + (km, "matmul", sc_2, UnitsApplicationError), + (km, "matmul", mm, set(DataArray(14) * 1e6, "km mm")), + (mm, "matmul", km, set(DataArray(14) * 1e6, "km mm")), + # + (km, "truediv", sc_1, set(DataArray([0.5, 1, 1.5]), "km")), + (km, "truediv", sc_2, set(DataArray([0.5, 1.0, 1.5]) * 1e-3, "km m-1")), + (km, "truediv", mm, set(DataArray([1, 1, 1]) * 1e-6, "km mm-1")), + (mm, "truediv", km, set(DataArray([1, 1, 1]) * 1e6, "mm km-1")), + # + (km, "add", sc_1, UnitsApplicationError), + (km, "add", sc_2, set(DataArray([3, 4, 5]), "km")), + (km, "add", mm, set(DataArray([2, 4, 6]), "km")), + (mm, "add", km, set(DataArray([2, 4, 6]) * 1e6, "mm")), + # + (km, "sub", sc_1, UnitsApplicationError), + (km, "sub", sc_2, set(DataArray([-1, 0, 1]), "km")), + (km, "sub", mm, set(DataArray([0, 0, 0]), "km")), + (mm, "sub", km, set(DataArray([0, 0, 0]), "mm")), + # + (km, "floordiv", sc_1, UnitsApplicationError), + (km, "floordiv", sc_2, set(DataArray([0, 1, 1]), "1")), + (km, "floordiv", mm, set(DataArray([1, 1, 1]), "1")), + (mm, "floordiv", km, set(DataArray([1, 1, 1]), "1")), + # + (km, "mod", sc_1, UnitsApplicationError), + (km, "mod", sc_2, set(DataArray([1, 0, 1]), "km")), + (km, "mod", mm, set(DataArray([0, 0, 0]), "km")), + (mm, "mod", km, set(DataArray([0, 0, 0]), "mm")), + # + (km, "lt", sc_1, UnitsApplicationError), + (km, "lt", sc_2, DataArray([True, False, False])), + (km, "lt", mm, DataArray([False, False, False])), + (mm, "lt", km, DataArray([False, False, False])), + # + (km, "le", sc_1, UnitsApplicationError), + (km, "le", sc_2, DataArray([True, True, False])), + (km, "le", mm, DataArray([True, True, True])), + (mm, "le", km, DataArray([True, True, True])), + # + (km, "eq", sc_1, UnitsApplicationError), + (km, "eq", sc_2, DataArray([False, True, False])), + (km, "eq", mm, DataArray([True, True, True])), + (mm, "eq", km, DataArray([True, True, True])), + # + (km, "ne", sc_1, UnitsApplicationError), + (km, "ne", sc_2, DataArray([True, False, True])), + (km, "ne", mm, DataArray([False, False, False])), + (mm, "ne", km, DataArray([False, False, False])), + # + (km, "ge", sc_1, UnitsApplicationError), + (km, "ge", sc_2, DataArray([False, True, True])), + (km, "ge", mm, DataArray([True, True, True])), + (mm, "ge", km, DataArray([True, True, True])), + # + (km, "gt", sc_1, UnitsApplicationError), + (km, "gt", sc_2, DataArray([False, False, True])), + (km, "gt", mm, DataArray([False, False, False])), + (mm, "gt", km, DataArray([False, False, False])), +] + + +@mark.parametrize("left, operator, right, expected", data_take) +def test_take( + left: DataArray, + operator: Operator, + right: Any, + expected: Any, +) -> None: + if expected is UnitsApplicationError: + with raises(expected): + take(left, operator, right) + else: + assert_identical(take(left, operator, right), expected) + + +@mark.parametrize("left, operator, right, expected", data_take) +def test_take_alias( + left: DataArray, + operator: Operator, + right: Any, + expected: DataArray, +) -> None: + if expected is UnitsApplicationError: + with raises(expected): + getattr(opr, operator)(left, right) + else: + assert_identical(getattr(opr, operator)(left, right), expected) diff --git a/tests/test_methods.py b/tests/test_quantity.py similarity index 94% rename from tests/test_methods.py rename to tests/test_quantity.py index e1b39c8..305c73c 100644 --- a/tests/test_methods.py +++ b/tests/test_quantity.py @@ -3,7 +3,7 @@ from astropy.constants import c # type: ignore from astropy.units import spectral # type: ignore from xarray.testing import assert_identical # type: ignore -from xarray_units.methods import apply, decompose, like, set, to +from xarray_units.quantity import apply, decompose, like, set, to # test datasets diff --git a/tests/test_utils.py b/tests/test_utils.py new file mode 100644 index 0000000..4f6e22e --- /dev/null +++ b/tests/test_utils.py @@ -0,0 +1,30 @@ +# standard library +from typing import Any + + +# dependencies +from astropy.units import Unit +from pytest import mark, raises +from xarray import DataArray +from xarray_units.utils import UnitsNotValidError, units_of + + +# test data +data_units_of: list[tuple[Any, Any]] = [ + (1, None), + (Unit("m"), None), + (1 * Unit("m"), Unit("m")), + (DataArray(1), None), + (DataArray(1, attrs={"units": "m"}), Unit("m")), + (DataArray(1, attrs={"units": "m, s"}), UnitsNotValidError), + (DataArray(1, attrs={"units": "spam"}), UnitsNotValidError), +] + + +@mark.parametrize("obj, expected", data_units_of) +def test_units_of(obj: Any, expected: Any) -> None: + if expected is UnitsNotValidError: + with raises(expected): + assert units_of(obj) + else: + assert units_of(obj) == expected diff --git a/xarray_units/__init__.py b/xarray_units/__init__.py index a389847..fc4d437 100644 --- a/xarray_units/__init__.py +++ b/xarray_units/__init__.py @@ -1,7 +1,8 @@ -__all__ = ["exceptions", "methods"] -__version__ = "0.1.0" +__all__ = ["operator", "quantity", "utils"] +__version__ = "0.2.0" # submodules -from . import exceptions -from . import methods +from . import operator +from . import quantity +from . import utils diff --git a/xarray_units/exceptions.py b/xarray_units/exceptions.py deleted file mode 100644 index 4878674..0000000 --- a/xarray_units/exceptions.py +++ /dev/null @@ -1,37 +0,0 @@ -__all__ = [ - "UnitsError", - "UnitsApplicationError", - "UnitsExistError", - "UnitsNotFoundError", - "UnitsNotValidError", -] - - -class UnitsError(Exception): - """Base exception for handling units.""" - - pass - - -class UnitsApplicationError(UnitsError): - """Units application is not successful.""" - - pass - - -class UnitsExistError(UnitsError): - """Units already exist in a DataArray.""" - - pass - - -class UnitsNotFoundError(UnitsError): - """Units do not exist in a DataArray.""" - - pass - - -class UnitsNotValidError(UnitsError): - """Units are not valid for a DataArray.""" - - pass diff --git a/xarray_units/operator.py b/xarray_units/operator.py new file mode 100644 index 0000000..4689ffd --- /dev/null +++ b/xarray_units/operator.py @@ -0,0 +1,181 @@ +__all__ = [ + "take", + # any-units operators + "mul", # * + "pow", # ** + "matmul", # @ + "truediv", # / + # same-units operators + "add", # + + "sub", # - + "floordiv", # // + "mod", # % + "lt", # < + "le", # <= + "eq", # == + "ne", # != + "ge", # >= + "gt", # > +] + + +# standard library +import operator as opr +from typing import Any, Literal, Union, get_args + + +# dependencies +from astropy.units import Quantity +from xarray import DataArray +from xarray_units.quantity import apply_any, set, to, unset +from .utils import TESTER, TDataArray, UnitsApplicationError, units_of + + +# type hints +AnyUnitsOperator = Literal[ + "mul", # * + "pow", # ** + "matmul", # @ + "truediv", # / +] +SameUnitsOperator = Literal[ + "add", # + + "sub", # - + "floordiv", # // + "mod", # % + "lt", # < + "le", # <= + "eq", # == + "ne", # != + "ge", # >= + "gt", # > +] +Operator = Union[AnyUnitsOperator, SameUnitsOperator] + + +def take(left: TDataArray, operator: Operator, right: Any, /) -> TDataArray: + """Perform an operation between a DataArray and any data with units. + + Args: + left: DataArray with units on the left side of the operator. + operator: Name of the operator (e.g. ``"add"``, ``"gt"``). + right: Any data on the right side of the operator. + + Returns: + DataArray of the result of the operation. Units are + the same as ``left`` in a numerical operation (e.g. ``"add"``) + or nothing in a relational operation (e.g. ``"gt"``). + + Raises: + UnitsApplicationError: Raised if the application fails. + UnitsNotFoundError: Raised if units are not found. + UnitsNotValidError: Raised if units are not valid. + + """ + left_units = units_of(left, strict=True) + right_units = units_of(right) + + if operator == "pow": + method = f"__{operator}__" + args = (Quantity(right, right_units),) + elif operator == "matmul": + method = "__mul__" + args = (Quantity(TESTER, right_units),) + elif operator == "eq" or operator == "ne": + method = "__lt__" + args = (Quantity(TESTER, right_units),) + else: + method = f"__{operator}__" + args = (Quantity(TESTER, right_units),) + + try: + test = apply_any(TESTER, left_units, method, *args) + except Exception as error: + raise UnitsApplicationError(error) + + if operator in get_args(SameUnitsOperator): + if isinstance(right, Quantity): + right = right.to(left_units).value # type: ignore + + if isinstance(right, DataArray): + right = to(right, left_units) + + try: + result = getattr(opr, operator)(left, right) + except Exception as error: + raise UnitsApplicationError(error) + + if (units := units_of(test)) is None: + return unset(result) + else: + return set(result, units, overwrite=True) + + +def mul(left: TDataArray, right: Any) -> TDataArray: + """Perform ``(left) * (right)`` with units.""" + return take(left, "mul", right) + + +def pow(left: TDataArray, right: Any) -> TDataArray: + """Perform ``(left) ** (right)`` with units.""" + return take(left, "pow", right) + + +def matmul(left: TDataArray, right: Any) -> TDataArray: + """Perform ``(left) @ (right)`` with units.""" + return take(left, "matmul", right) + + +def truediv(left: TDataArray, right: Any) -> TDataArray: + """Perform ``(left) / (right)`` with units.""" + return take(left, "truediv", right) + + +def add(left: TDataArray, right: Any) -> TDataArray: + """Perform ``(left) + (right)`` with units.""" + return take(left, "add", right) + + +def sub(left: TDataArray, right: Any) -> TDataArray: + """Perform ``(left) - (right)`` with units.""" + return take(left, "sub", right) + + +def floordiv(left: TDataArray, right: Any) -> TDataArray: + """Perform ``(left) // (right)`` with units.""" + return take(left, "floordiv", right) + + +def mod(left: TDataArray, right: Any) -> TDataArray: + """Perform ``(left) % (right)`` with units.""" + return take(left, "mod", right) + + +def lt(left: TDataArray, right: Any) -> TDataArray: + """Perform ``(left) < (right)`` with units.""" + return take(left, "lt", right) + + +def le(left: TDataArray, right: Any) -> TDataArray: + """Perform ``(left) <= (right)`` with units.""" + return take(left, "le", right) + + +def eq(left: TDataArray, right: Any) -> TDataArray: + """Perform ``(left) == (right)`` with units.""" + return take(left, "eq", right) + + +def ne(left: TDataArray, right: Any) -> TDataArray: + """Perform ``(left) != (right)`` with units.""" + return take(left, "ne", right) + + +def ge(left: TDataArray, right: Any) -> TDataArray: + """Perform ``(left) >= (right)`` with units.""" + return take(left, "ge", right) + + +def gt(left: TDataArray, right: Any) -> TDataArray: + """Perform ``(left) > (right)`` with units.""" + return take(left, "gt", right) diff --git a/xarray_units/py.typed b/xarray_units/py.typed new file mode 100644 index 0000000..e69de29 diff --git a/xarray_units/methods.py b/xarray_units/quantity.py similarity index 57% rename from xarray_units/methods.py rename to xarray_units/quantity.py index 886e33e..cc3979d 100644 --- a/xarray_units/methods.py +++ b/xarray_units/quantity.py @@ -1,38 +1,38 @@ -__all__ = ["apply", "decompose", "like", "set", "to"] +__all__ = ["apply", "decompose", "like", "set", "to", "unset"] # standard library -from types import MethodType -from typing import Any, Optional, TypeVar, Union, overload +from types import MethodType, MethodWrapperType +from typing import Any # dependencies -from astropy.units import Equivalency, Quantity, Unit, UnitBase -from xarray import DataArray, map_blocks -from .exceptions import ( +from astropy.units import Quantity +from xarray import DataArray +from .utils import ( + TESTER, + UNITS_ATTR, + Equivalencies, + TDataArray, UnitsApplicationError, UnitsExistError, - UnitsNotFoundError, - UnitsNotValidError, + UnitsLike, + units_of, ) -# type hints -TDataArray = TypeVar("TDataArray", bound=DataArray) -Equivalencies = Optional[Equivalency] -UnitsLike = Union[UnitBase, str] - - -# constants -UNITS_ATTR = "units" - - -def apply(da: TDataArray, name: str, /, *args: Any, **kwargs: Any) -> TDataArray: +def apply( + da: TDataArray, + method: str, + /, + *args: Any, + **kwargs: Any, +) -> TDataArray: """Apply a method of Astropy Quantity to a DataArray. Args: da: Input DataArray with units. - name: Method (or property) name of Astropy Quantity. + method: Method (or property) name of Astropy Quantity. *args: Positional arguments of the method. *kwargs: Keyword arguments of the method. @@ -45,20 +45,40 @@ def apply(da: TDataArray, name: str, /, *args: Any, **kwargs: Any) -> TDataArray UnitsNotValidError: Raised if units are not valid. """ - if (da_units := units_of(da)) is None: - raise UnitsNotFoundError(repr(da)) + units = units_of(da, strict=True) + + def per_block(block: TDataArray) -> TDataArray: + data = apply_any(block, units, method, *args, **kwargs) + return block.copy(data=data) - # test application try: - test = apply_any(1, da_units, name, *args, **kwargs) + test = apply_any(TESTER, units, method, *args, **kwargs) except Exception as error: raise UnitsApplicationError(error) - def per_block(block: TDataArray) -> TDataArray: - data = apply_any(block, da_units, name, *args, **kwargs) - return block.copy(data=data) + try: + result = da.map_blocks(per_block) + except Exception as error: + raise UnitsApplicationError(error) + + return set(result, units_of(test, strict=True), overwrite=True) + + +def apply_any( + data: Any, + units: UnitsLike, + method: str, + /, + *args: Any, + **kwargs: Any, +) -> Any: + """Apply a method of Astropy Quantity to any data.""" + attr = getattr(Quantity(data, units), method) - return set(map_blocks(per_block, da), units_of(test), True) + if isinstance(attr, (MethodType, MethodWrapperType)): + return attr(*args, **kwargs) + else: + return attr def decompose(da: TDataArray, /) -> TDataArray: @@ -101,16 +121,14 @@ def like( UnitsNotValidError: Raised if units are not valid. """ - if (units := units_of(other)) is None: - raise UnitsNotFoundError(repr(other)) - - return apply(da, "to", units, equivalencies) + return apply(da, "to", units_of(other, strict=True), equivalencies) def set( da: TDataArray, units: UnitsLike, /, + *, overwrite: bool = False, ) -> TDataArray: """Set units to a DataArray. @@ -118,6 +136,8 @@ def set( Args: da: Input DataArray. units: Units to be set to the input. + + Keyword Args: overwrite: Whether to overwrite existing units. Returns: @@ -132,7 +152,7 @@ def set( if not overwrite and units_of(da) is not None: raise UnitsExistError(repr(da)) - return da.assign_attrs(units=units) + return da.assign_attrs({UNITS_ATTR: units}) def to( @@ -160,59 +180,16 @@ def to( return apply(da, "to", units, equivalencies) -# helper functions -def apply_any( - data: Any, - units: UnitsLike, - name: str, - /, - *args: Any, - **kwargs: Any, -) -> Quantity: - """Apply a method of Astropy Quantity to any data.""" - data = Quantity(data, units) - - if isinstance(attr := getattr(data, name), MethodType): - return ensure_consistency(data, attr(*args, **kwargs)) - else: - return ensure_consistency(data, attr) - - -def ensure_consistency(data_in: Any, data_out: Any, /) -> Quantity: - """Ensure consistency between input and output data.""" - if not isinstance(data_in, Quantity): - raise TypeError("Input must be Astropy Quantity.") +def unset(da: TDataArray, /) -> TDataArray: + """Remove units from a DataArray. - if not isinstance(data_out, Quantity): - raise TypeError("Output must be Astropy Quantity.") - - if data_out.shape != data_in.shape: - raise ValueError("Input and output shapes must be same.") - - return data_out - - -@overload -def units_of(obj: Quantity) -> UnitBase: - ... - - -@overload -def units_of(obj: DataArray) -> Optional[UnitBase]: - ... - - -def units_of(obj: Any) -> Any: - """Return units of an object if they exist and are valid.""" - if isinstance(obj, Quantity): - if isinstance(units := obj.unit, UnitBase): - return units - - if isinstance(obj, DataArray): - if (units := obj.attrs.get(UNITS_ATTR)) is None: - return None + Args: + da: Input DataArray. - if isinstance(units := Unit(units), UnitBase): # type: ignore - return units + Returns: + DataArray with units removed. - raise UnitsNotValidError(repr(obj)) + """ + da = da.copy(data=da.data) + da.attrs.pop(UNITS_ATTR, None) + return da diff --git a/xarray_units/utils.py b/xarray_units/utils.py new file mode 100644 index 0000000..7803979 --- /dev/null +++ b/xarray_units/utils.py @@ -0,0 +1,106 @@ +__all__ = [ + "UnitsError", + "UnitsApplicationError", + "UnitsExistError", + "UnitsNotFoundError", + "UnitsNotValidError", + "units_of", +] + + +# standard library +from typing import Any, Literal, Optional, TypeVar, Union, overload + + +# dependencies +from astropy.units import Equivalency, Quantity, Unit, UnitBase +from xarray import DataArray + + +# type hints +TDataArray = TypeVar("TDataArray", bound=DataArray) +Equivalencies = Optional[Equivalency] +UnitsLike = Union[UnitBase, str] + + +# constants +TESTER = 1 +UNITS_ATTR = "units" + + +class UnitsError(Exception): + """Base exception for handling units.""" + + pass + + +class UnitsApplicationError(UnitsError): + """Units application is not successful.""" + + pass + + +class UnitsExistError(UnitsError): + """Units already exist in a DataArray.""" + + pass + + +class UnitsNotFoundError(UnitsError): + """Units do not exist in a DataArray.""" + + pass + + +class UnitsNotValidError(UnitsError): + """Units are not valid for a DataArray.""" + + pass + + +@overload +def units_of(obj: Any, /, *, strict: Literal[False] = False) -> Optional[UnitBase]: + ... + + +@overload +def units_of(obj: Any, /, *, strict: Literal[True] = True) -> UnitBase: + ... + + +def units_of(obj: Any, /, *, strict: bool = False) -> Optional[UnitBase]: + """Return units of an object if they exist and are valid. + + Args: + obj: Any object from which units are extracted. + + Keyword Args: + strict: Whether to allow None as the return value. + + Raises: + UnitsNotFoundError: Raised if ``strict`` is ``True`` + but units do not exist in the object. + UnitsNotValidError: Raised if units exist in the object + but they cannot be parsed into ``UnitBase``. + + """ + if isinstance(obj, Quantity): + if isinstance(units := obj.unit, UnitBase): + return units + else: + raise UnitsNotValidError(repr(obj)) + + if isinstance(obj, DataArray): + if (units := obj.attrs.get(UNITS_ATTR)) is not None: + try: + units = Unit(units) # type: ignore + except Exception: + raise UnitsNotValidError(repr(obj)) + + if isinstance(units, UnitBase): + return units + else: + raise UnitsNotValidError(repr(obj)) + + if strict: + raise UnitsNotFoundError(repr(obj))