From 45afbc6b5b5aa03bd092b0ba8b30db46524c69a0 Mon Sep 17 00:00:00 2001 From: Akio Taniguchi Date: Mon, 11 Dec 2023 13:41:39 +0000 Subject: [PATCH 01/31] #5 Add initial operators module --- xarray_units/__init__.py | 3 ++- xarray_units/operators.py | 1 + 2 files changed, 3 insertions(+), 1 deletion(-) create mode 100644 xarray_units/operators.py diff --git a/xarray_units/__init__.py b/xarray_units/__init__.py index a389847..45a7249 100644 --- a/xarray_units/__init__.py +++ b/xarray_units/__init__.py @@ -1,7 +1,8 @@ -__all__ = ["exceptions", "methods"] +__all__ = ["exceptions", "methods", "operators"] __version__ = "0.1.0" # submodules from . import exceptions from . import methods +from . import operators diff --git a/xarray_units/operators.py b/xarray_units/operators.py new file mode 100644 index 0000000..a9a2c5b --- /dev/null +++ b/xarray_units/operators.py @@ -0,0 +1 @@ +__all__ = [] From ae80ad7eba9e903122a7e661acc2bc229a1398f1 Mon Sep 17 00:00:00 2001 From: Akio Taniguchi Date: Mon, 11 Dec 2023 14:00:14 +0000 Subject: [PATCH 02/31] #5 Add type hints for supported operators --- xarray_units/operators.py | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/xarray_units/operators.py b/xarray_units/operators.py index a9a2c5b..6face50 100644 --- a/xarray_units/operators.py +++ b/xarray_units/operators.py @@ -1 +1,24 @@ __all__ = [] + + +# standard library +from typing import Literal + + +# type hints +Operator = Literal[ + "lt", # < + "le", # <= + "eq", # == + "ne", # != + "ge", # >= + "gt", # > + "add", # + + "sub", # - + "mul", # * + "pow", # ** + "matmul", # @ + "truediv", # / + "floordiv", # // + "mod", # % +] From 957134c65699c2a753fba75c3997f64a529911b0 Mon Sep 17 00:00:00 2001 From: Akio Taniguchi Date: Mon, 11 Dec 2023 16:05:46 +0000 Subject: [PATCH 03/31] =?UTF-8?q?#5=20Rename=20module=20(exceptions=20?= =?UTF-8?q?=E2=86=92=20utils)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- xarray_units/__init__.py | 4 ++-- xarray_units/methods.py | 2 +- xarray_units/{exceptions.py => utils.py} | 0 3 files changed, 3 insertions(+), 3 deletions(-) rename xarray_units/{exceptions.py => utils.py} (100%) diff --git a/xarray_units/__init__.py b/xarray_units/__init__.py index 45a7249..70ac860 100644 --- a/xarray_units/__init__.py +++ b/xarray_units/__init__.py @@ -1,8 +1,8 @@ -__all__ = ["exceptions", "methods", "operators"] +__all__ = ["methods", "operators", "utils"] __version__ = "0.1.0" # submodules -from . import exceptions from . import methods from . import operators +from . import utils diff --git a/xarray_units/methods.py b/xarray_units/methods.py index 886e33e..885cbfe 100644 --- a/xarray_units/methods.py +++ b/xarray_units/methods.py @@ -9,7 +9,7 @@ # dependencies from astropy.units import Equivalency, Quantity, Unit, UnitBase from xarray import DataArray, map_blocks -from .exceptions import ( +from .utils import ( UnitsApplicationError, UnitsExistError, UnitsNotFoundError, diff --git a/xarray_units/exceptions.py b/xarray_units/utils.py similarity index 100% rename from xarray_units/exceptions.py rename to xarray_units/utils.py From a0507b326e03a50a94b3ab2722a16fd5ae4b4bd8 Mon Sep 17 00:00:00 2001 From: Akio Taniguchi Date: Mon, 11 Dec 2023 16:20:06 +0000 Subject: [PATCH 04/31] #5 Move common items to utils module --- xarray_units/methods.py | 53 ++++++++++------------------------------- xarray_units/utils.py | 48 +++++++++++++++++++++++++++++++++++++ 2 files changed, 61 insertions(+), 40 deletions(-) diff --git a/xarray_units/methods.py b/xarray_units/methods.py index 885cbfe..6086ece 100644 --- a/xarray_units/methods.py +++ b/xarray_units/methods.py @@ -3,31 +3,30 @@ # standard library from types import MethodType -from typing import Any, Optional, TypeVar, Union, overload +from typing import Any # dependencies -from astropy.units import Equivalency, Quantity, Unit, UnitBase +from astropy.units import Quantity from xarray import DataArray, map_blocks from .utils import ( + Equivalencies, + TDataArray, UnitsApplicationError, UnitsExistError, + UnitsLike, UnitsNotFoundError, - UnitsNotValidError, + units_of, ) -# type hints -TDataArray = TypeVar("TDataArray", bound=DataArray) -Equivalencies = Optional[Equivalency] -UnitsLike = Union[UnitBase, str] - - -# constants -UNITS_ATTR = "units" - - -def apply(da: TDataArray, name: str, /, *args: Any, **kwargs: Any) -> TDataArray: +def apply( + da: TDataArray, + name: str, + /, + *args: Any, + **kwargs: Any, +) -> TDataArray: """Apply a method of Astropy Quantity to a DataArray. Args: @@ -190,29 +189,3 @@ def ensure_consistency(data_in: Any, data_out: Any, /) -> Quantity: raise ValueError("Input and output shapes must be same.") return data_out - - -@overload -def units_of(obj: Quantity) -> UnitBase: - ... - - -@overload -def units_of(obj: DataArray) -> Optional[UnitBase]: - ... - - -def units_of(obj: Any) -> Any: - """Return units of an object if they exist and are valid.""" - if isinstance(obj, Quantity): - if isinstance(units := obj.unit, UnitBase): - return units - - if isinstance(obj, DataArray): - if (units := obj.attrs.get(UNITS_ATTR)) is None: - return None - - if isinstance(units := Unit(units), UnitBase): # type: ignore - return units - - raise UnitsNotValidError(repr(obj)) diff --git a/xarray_units/utils.py b/xarray_units/utils.py index 4878674..1e579e3 100644 --- a/xarray_units/utils.py +++ b/xarray_units/utils.py @@ -4,9 +4,29 @@ "UnitsExistError", "UnitsNotFoundError", "UnitsNotValidError", + "units_of", ] +# standard library +from typing import Any, Optional, TypeVar, Union, overload + + +# dependencies +from astropy.units import Equivalency, Quantity, Unit, UnitBase +from xarray import DataArray + + +# type hints +TDataArray = TypeVar("TDataArray", bound=DataArray) +Equivalencies = Optional[Equivalency] +UnitsLike = Union[UnitBase, str] + + +# constants +UNITS_ATTR = "units" + + class UnitsError(Exception): """Base exception for handling units.""" @@ -35,3 +55,31 @@ class UnitsNotValidError(UnitsError): """Units are not valid for a DataArray.""" pass + + +@overload +def units_of(obj: Quantity) -> UnitBase: + ... + + +@overload +def units_of(obj: Any) -> Optional[UnitBase]: + ... + + +def units_of(obj: Any) -> Optional[UnitBase]: + """Return units of an object if they exist and are valid.""" + if isinstance(obj, Quantity): + if isinstance(units := obj.unit, UnitBase): + return units + + raise UnitsNotValidError(repr(obj)) + + if isinstance(obj, DataArray): + if (units := obj.attrs.get(UNITS_ATTR)) is None: + return None + + if isinstance(units := Unit(units), UnitBase): # type: ignore + return units + + raise UnitsNotValidError(repr(obj)) From eb49052f364e706dcd03641bbe9f7f56b644d5c9 Mon Sep 17 00:00:00 2001 From: Akio Taniguchi Date: Mon, 11 Dec 2023 17:37:43 +0000 Subject: [PATCH 05/31] #5 Add function for operation between two datasets --- xarray_units/methods.py | 1 - xarray_units/operators.py | 79 ++++++++++++++++++++++++++++++++++++++- xarray_units/utils.py | 1 + 3 files changed, 78 insertions(+), 3 deletions(-) diff --git a/xarray_units/methods.py b/xarray_units/methods.py index 6086ece..0bfe560 100644 --- a/xarray_units/methods.py +++ b/xarray_units/methods.py @@ -47,7 +47,6 @@ def apply( if (da_units := units_of(da)) is None: raise UnitsNotFoundError(repr(da)) - # test application try: test = apply_any(1, da_units, name, *args, **kwargs) except Exception as error: diff --git a/xarray_units/operators.py b/xarray_units/operators.py index 6face50..ac8dd0a 100644 --- a/xarray_units/operators.py +++ b/xarray_units/operators.py @@ -1,8 +1,25 @@ -__all__ = [] +__all__ = ["take"] # standard library -from typing import Literal +import operator as opr +from typing import Any, Literal, Union + + +# dependencies +from astropy.units import Quantity +from numpy import bool_ +from numpy.typing import NDArray +from xarray import map_blocks +from .methods import set +from .utils import ( + UNITS_ONE, + TDataArray, + UnitsApplicationError, + UnitsLike, + UnitsNotFoundError, + units_of, +) # type hints @@ -22,3 +39,61 @@ "floordiv", # // "mod", # % ] + + +def take(left: TDataArray, operator: Operator, right: Any, /) -> TDataArray: + """Perform an operation between a DataArray and any data with units. + + Args: + left: DataArray with units on the left side of the operator.. + operator: Name of the operator (e.g. ``"add"``, ``"gt"``). + right: Any data on the right side of the operator. + + Returns: + DataArray of the result of the operation. Units are + the same as ``left`` in a numerical operation (e.g. ``"add"``) + or nothing in a relational operation (e.g. ``"gt"``). + + Raises: + UnitsApplicationError: Raised if the application fails. + UnitsNotFoundError: Raised if units are not found. + UnitsNotValidError: Raised if units are not valid. + + """ + if (left_units := units_of(left)) is None: + raise UnitsNotFoundError(repr(left)) + + if (right_units := units_of(right)) is None: + right_units = UNITS_ONE + + try: + if operator == "pow": + test = take_any(1, left_units, operator, right, right_units) + else: + test = take_any(1, left_units, operator, 1, right_units) + except Exception as error: + raise UnitsApplicationError(error) + + def per_block(left_: TDataArray, right_: Any) -> TDataArray: + data = take_any(left_, left_units, operator, right_, right_units) + return left.copy(data=data) + + if (units := units_of(test)) is None: + return map_blocks(per_block, left, (right,)) + else: + return set(map_blocks(per_block, left, (right,)), units, True) + + +# helper functions +def take_any( + left: Any, + left_units: UnitsLike, + operator: Operator, + right: Any, + right_units: UnitsLike, + /, +) -> Union[Quantity, NDArray[bool_], bool_]: + """Perform an operation between two any datasets.""" + left = Quantity(left, left_units) + right = Quantity(right, right_units) + return getattr(opr, operator)(left, right) diff --git a/xarray_units/utils.py b/xarray_units/utils.py index 1e579e3..1daf5df 100644 --- a/xarray_units/utils.py +++ b/xarray_units/utils.py @@ -25,6 +25,7 @@ # constants UNITS_ATTR = "units" +UNITS_ONE = "1" class UnitsError(Exception): From b4ad7d477e71eccd04a9042308858518981a913f Mon Sep 17 00:00:00 2001 From: Akio Taniguchi Date: Mon, 11 Dec 2023 17:44:16 +0000 Subject: [PATCH 06/31] #5 Add operator functions --- xarray_units/operators.py | 89 ++++++++++++++++++++++++++++++++++++++- 1 file changed, 87 insertions(+), 2 deletions(-) diff --git a/xarray_units/operators.py b/xarray_units/operators.py index ac8dd0a..b76cca8 100644 --- a/xarray_units/operators.py +++ b/xarray_units/operators.py @@ -1,4 +1,20 @@ -__all__ = ["take"] +__all__ = [ + "take", + "lt", # < + "le", # <= + "eq", # == + "ne", # != + "ge", # >= + "gt", # > + "add", # + + "sub", # - + "mul", # * + "pow", # ** + "matmul", # @ + "truediv", # / + "floordiv", # // + "mod", # % +] # standard library @@ -84,7 +100,6 @@ def per_block(left_: TDataArray, right_: Any) -> TDataArray: return set(map_blocks(per_block, left, (right,)), units, True) -# helper functions def take_any( left: Any, left_units: UnitsLike, @@ -97,3 +112,73 @@ def take_any( left = Quantity(left, left_units) right = Quantity(right, right_units) return getattr(opr, operator)(left, right) + + +def lt(left: TDataArray, right: Any) -> TDataArray: + """Perform ``(left) < (right)`` with units.""" + return take(left, "lt", right) + + +def le(left: TDataArray, right: Any) -> TDataArray: + """Perform ``(left) <= (right)`` with units.""" + return take(left, "le", right) + + +def eq(left: TDataArray, right: Any) -> TDataArray: + """Perform ``(left) == (right)`` with units.""" + return take(left, "eq", right) + + +def ne(left: TDataArray, right: Any) -> TDataArray: + """Perform ``(left) != (right)`` with units.""" + return take(left, "ne", right) + + +def ge(left: TDataArray, right: Any) -> TDataArray: + """Perform ``(left) >= (right)`` with units.""" + return take(left, "ge", right) + + +def gt(left: TDataArray, right: Any) -> TDataArray: + """Perform ``(left) > (right)`` with units.""" + return take(left, "gt", right) + + +def add(left: TDataArray, right: Any) -> TDataArray: + """Perform ``(left) + (right)`` with units.""" + return take(left, "add", right) + + +def sub(left: TDataArray, right: Any) -> TDataArray: + """Perform ``(left) - (right)`` with units.""" + return take(left, "sub", right) + + +def mul(left: TDataArray, right: Any) -> TDataArray: + """Perform ``(left) * (right)`` with units.""" + return take(left, "mul", right) + + +def pow(left: TDataArray, right: Any) -> TDataArray: + """Perform ``(left) ** (right)`` with units.""" + return take(left, "pow", right) + + +def matmul(left: TDataArray, right: Any) -> TDataArray: + """Perform ``(left) @ (right)`` with units.""" + return take(left, "matmul", right) + + +def truediv(left: TDataArray, right: Any) -> TDataArray: + """Perform ``(left) / (right)`` with units.""" + return take(left, "truediv", right) + + +def floordiv(left: TDataArray, right: Any) -> TDataArray: + """Perform ``(left) // (right)`` with units.""" + return take(left, "floordiv", right) + + +def mod(left: TDataArray, right: Any) -> TDataArray: + """Perform ``(left) % (right)`` with units.""" + return take(left, "mod", right) From 6970793e732aca116b302660db0042f5d862d0fc Mon Sep 17 00:00:00 2001 From: Akio Taniguchi Date: Mon, 11 Dec 2023 17:45:16 +0000 Subject: [PATCH 07/31] #5 Reorder functions in method module --- xarray_units/methods.py | 63 ++++++++++++++++++++--------------------- 1 file changed, 31 insertions(+), 32 deletions(-) diff --git a/xarray_units/methods.py b/xarray_units/methods.py index 0bfe560..104c66f 100644 --- a/xarray_units/methods.py +++ b/xarray_units/methods.py @@ -59,6 +59,37 @@ def per_block(block: TDataArray) -> TDataArray: return set(map_blocks(per_block, da), units_of(test), True) +def apply_any( + data: Any, + units: UnitsLike, + name: str, + /, + *args: Any, + **kwargs: Any, +) -> Quantity: + """Apply a method of Astropy Quantity to any data.""" + data = Quantity(data, units) + + if isinstance(attr := getattr(data, name), MethodType): + return ensure_consistency(data, attr(*args, **kwargs)) + else: + return ensure_consistency(data, attr) + + +def ensure_consistency(data_in: Any, data_out: Any, /) -> Quantity: + """Ensure consistency between input and output data.""" + if not isinstance(data_in, Quantity): + raise TypeError("Input must be Astropy Quantity.") + + if not isinstance(data_out, Quantity): + raise TypeError("Output must be Astropy Quantity.") + + if data_out.shape != data_in.shape: + raise ValueError("Input and output shapes must be same.") + + return data_out + + def decompose(da: TDataArray, /) -> TDataArray: """Convert a DataArray with units to decomposed ones. @@ -156,35 +187,3 @@ def to( """ return apply(da, "to", units, equivalencies) - - -# helper functions -def apply_any( - data: Any, - units: UnitsLike, - name: str, - /, - *args: Any, - **kwargs: Any, -) -> Quantity: - """Apply a method of Astropy Quantity to any data.""" - data = Quantity(data, units) - - if isinstance(attr := getattr(data, name), MethodType): - return ensure_consistency(data, attr(*args, **kwargs)) - else: - return ensure_consistency(data, attr) - - -def ensure_consistency(data_in: Any, data_out: Any, /) -> Quantity: - """Ensure consistency between input and output data.""" - if not isinstance(data_in, Quantity): - raise TypeError("Input must be Astropy Quantity.") - - if not isinstance(data_out, Quantity): - raise TypeError("Output must be Astropy Quantity.") - - if data_out.shape != data_in.shape: - raise ValueError("Input and output shapes must be same.") - - return data_out From bce830903adb9bd616f313fb78db55555f06dde8 Mon Sep 17 00:00:00 2001 From: Akio Taniguchi Date: Mon, 11 Dec 2023 18:01:24 +0000 Subject: [PATCH 08/31] #5 Add py.typed --- xarray_units/py.typed | 0 1 file changed, 0 insertions(+), 0 deletions(-) create mode 100644 xarray_units/py.typed diff --git a/xarray_units/py.typed b/xarray_units/py.typed new file mode 100644 index 0000000..e69de29 From 29035ca7f18c7fd379cc71f29d9a671e5119a6b6 Mon Sep 17 00:00:00 2001 From: Akio Taniguchi Date: Mon, 11 Dec 2023 23:45:52 +0000 Subject: [PATCH 09/31] #5 Update units_of to catch ValueError --- xarray_units/utils.py | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/xarray_units/utils.py b/xarray_units/utils.py index 1daf5df..afd81a0 100644 --- a/xarray_units/utils.py +++ b/xarray_units/utils.py @@ -73,14 +73,19 @@ def units_of(obj: Any) -> Optional[UnitBase]: if isinstance(obj, Quantity): if isinstance(units := obj.unit, UnitBase): return units - - raise UnitsNotValidError(repr(obj)) + else: + raise UnitsNotValidError(repr(obj)) if isinstance(obj, DataArray): if (units := obj.attrs.get(UNITS_ATTR)) is None: return None - if isinstance(units := Unit(units), UnitBase): # type: ignore - return units + try: + units = Unit(units) # type: ignore + except Exception: + raise UnitsNotValidError(repr(obj)) - raise UnitsNotValidError(repr(obj)) + if isinstance(units, UnitBase): + return units + else: + raise UnitsNotValidError(repr(obj)) From a9973102b2b5e761487bccc1231b99d772e54604 Mon Sep 17 00:00:00 2001 From: Akio Taniguchi Date: Mon, 11 Dec 2023 23:46:10 +0000 Subject: [PATCH 10/31] #5 Add test for utils module --- tests/test_utils.py | 30 ++++++++++++++++++++++++++++++ 1 file changed, 30 insertions(+) create mode 100644 tests/test_utils.py diff --git a/tests/test_utils.py b/tests/test_utils.py new file mode 100644 index 0000000..4f6e22e --- /dev/null +++ b/tests/test_utils.py @@ -0,0 +1,30 @@ +# standard library +from typing import Any + + +# dependencies +from astropy.units import Unit +from pytest import mark, raises +from xarray import DataArray +from xarray_units.utils import UnitsNotValidError, units_of + + +# test data +data_units_of: list[tuple[Any, Any]] = [ + (1, None), + (Unit("m"), None), + (1 * Unit("m"), Unit("m")), + (DataArray(1), None), + (DataArray(1, attrs={"units": "m"}), Unit("m")), + (DataArray(1, attrs={"units": "m, s"}), UnitsNotValidError), + (DataArray(1, attrs={"units": "spam"}), UnitsNotValidError), +] + + +@mark.parametrize("obj, expected", data_units_of) +def test_units_of(obj: Any, expected: Any) -> None: + if expected is UnitsNotValidError: + with raises(expected): + assert units_of(obj) + else: + assert units_of(obj) == expected From 8ebf362650e64540f076ed2a7a206577222a7a49 Mon Sep 17 00:00:00 2001 From: Akio Taniguchi Date: Wed, 13 Dec 2023 15:57:12 +0000 Subject: [PATCH 11/31] #5 Add function to unset units of DataArray --- xarray_units/methods.py | 20 ++++++++++++++++++-- 1 file changed, 18 insertions(+), 2 deletions(-) diff --git a/xarray_units/methods.py b/xarray_units/methods.py index 104c66f..80361da 100644 --- a/xarray_units/methods.py +++ b/xarray_units/methods.py @@ -1,4 +1,4 @@ -__all__ = ["apply", "decompose", "like", "set", "to"] +__all__ = ["apply", "decompose", "like", "set", "to", "unset"] # standard library @@ -10,6 +10,7 @@ from astropy.units import Quantity from xarray import DataArray, map_blocks from .utils import ( + UNITS_ATTR, Equivalencies, TDataArray, UnitsApplicationError, @@ -161,7 +162,7 @@ def set( if not overwrite and units_of(da) is not None: raise UnitsExistError(repr(da)) - return da.assign_attrs(units=units) + return da.assign_attrs({UNITS_ATTR: units}) def to( @@ -187,3 +188,18 @@ def to( """ return apply(da, "to", units, equivalencies) + + +def unset(da: TDataArray, /) -> TDataArray: + """Remove units from a DataArray. + + Args: + da: Input DataArray. + + Returns: + DataArray with units removed. + + """ + da = da.copy(data=da.data) + da.attrs.pop(UNITS_ATTR) + return da From 963664bd50d4540a24b591e828275faa00c03459 Mon Sep 17 00:00:00 2001 From: Akio Taniguchi Date: Wed, 13 Dec 2023 16:01:26 +0000 Subject: [PATCH 12/31] #5 Add strict option to units_of --- xarray_units/methods.py | 13 ++++--------- xarray_units/utils.py | 43 ++++++++++++++++++++++++++--------------- 2 files changed, 31 insertions(+), 25 deletions(-) diff --git a/xarray_units/methods.py b/xarray_units/methods.py index 80361da..ecc6c35 100644 --- a/xarray_units/methods.py +++ b/xarray_units/methods.py @@ -16,7 +16,6 @@ UnitsApplicationError, UnitsExistError, UnitsLike, - UnitsNotFoundError, units_of, ) @@ -45,16 +44,15 @@ def apply( UnitsNotValidError: Raised if units are not valid. """ - if (da_units := units_of(da)) is None: - raise UnitsNotFoundError(repr(da)) + units = units_of(da, True) try: - test = apply_any(1, da_units, name, *args, **kwargs) + test = apply_any(1, units, name, *args, **kwargs) except Exception as error: raise UnitsApplicationError(error) def per_block(block: TDataArray) -> TDataArray: - data = apply_any(block, da_units, name, *args, **kwargs) + data = apply_any(block, units, name, *args, **kwargs) return block.copy(data=data) return set(map_blocks(per_block, da), units_of(test), True) @@ -131,10 +129,7 @@ def like( UnitsNotValidError: Raised if units are not valid. """ - if (units := units_of(other)) is None: - raise UnitsNotFoundError(repr(other)) - - return apply(da, "to", units, equivalencies) + return apply(da, "to", units_of(other, True), equivalencies) def set( diff --git a/xarray_units/utils.py b/xarray_units/utils.py index afd81a0..bda0624 100644 --- a/xarray_units/utils.py +++ b/xarray_units/utils.py @@ -9,7 +9,7 @@ # standard library -from typing import Any, Optional, TypeVar, Union, overload +from typing import Any, Literal, Optional, TypeVar, Union, overload # dependencies @@ -59,16 +59,26 @@ class UnitsNotValidError(UnitsError): @overload -def units_of(obj: Quantity) -> UnitBase: +def units_of(obj: Quantity, /, strict: Literal[False] = False) -> UnitBase: ... @overload -def units_of(obj: Any) -> Optional[UnitBase]: +def units_of(obj: Quantity, /, strict: Literal[True] = True) -> UnitBase: ... -def units_of(obj: Any) -> Optional[UnitBase]: +@overload +def units_of(obj: Any, /, strict: Literal[False] = False) -> Optional[UnitBase]: + ... + + +@overload +def units_of(obj: Any, /, strict: Literal[True] = True) -> UnitBase: + ... + + +def units_of(obj: Any, /, strict: bool = False) -> Optional[UnitBase]: """Return units of an object if they exist and are valid.""" if isinstance(obj, Quantity): if isinstance(units := obj.unit, UnitBase): @@ -77,15 +87,16 @@ def units_of(obj: Any) -> Optional[UnitBase]: raise UnitsNotValidError(repr(obj)) if isinstance(obj, DataArray): - if (units := obj.attrs.get(UNITS_ATTR)) is None: - return None - - try: - units = Unit(units) # type: ignore - except Exception: - raise UnitsNotValidError(repr(obj)) - - if isinstance(units, UnitBase): - return units - else: - raise UnitsNotValidError(repr(obj)) + if (units := obj.attrs.get(UNITS_ATTR)) is not None: + try: + units = Unit(units) # type: ignore + except Exception: + raise UnitsNotValidError(repr(obj)) + + if isinstance(units, UnitBase): + return units + else: + raise UnitsNotValidError(repr(obj)) + + if strict: + raise UnitsNotFoundError(repr(obj)) From 7437f99b5fc8940a548091ec846b7fb8eefb94d3 Mon Sep 17 00:00:00 2001 From: Akio Taniguchi Date: Sun, 17 Dec 2023 05:50:46 +0000 Subject: [PATCH 13/31] #5 Update apply and apply_any --- xarray_units/methods.py | 43 ++++++++++++++++------------------------- 1 file changed, 17 insertions(+), 26 deletions(-) diff --git a/xarray_units/methods.py b/xarray_units/methods.py index ecc6c35..41032f3 100644 --- a/xarray_units/methods.py +++ b/xarray_units/methods.py @@ -2,13 +2,13 @@ # standard library -from types import MethodType +from types import MethodType, MethodWrapperType from typing import Any # dependencies from astropy.units import Quantity -from xarray import DataArray, map_blocks +from xarray import DataArray from .utils import ( UNITS_ATTR, Equivalencies, @@ -46,16 +46,21 @@ def apply( """ units = units_of(da, True) + def per_block(block: TDataArray) -> TDataArray: + data = apply_any(block, units, name, *args, **kwargs) + return block.copy(data=data) + try: - test = apply_any(1, units, name, *args, **kwargs) + tested = apply_any(1, units, name, *args, **kwargs) except Exception as error: raise UnitsApplicationError(error) - def per_block(block: TDataArray) -> TDataArray: - data = apply_any(block, units, name, *args, **kwargs) - return block.copy(data=data) + try: + applied = da.map_blocks(per_block) + except Exception as error: + raise UnitsApplicationError(error) - return set(map_blocks(per_block, da), units_of(test), True) + return set(applied, units_of(tested, True), True) def apply_any( @@ -65,28 +70,14 @@ def apply_any( /, *args: Any, **kwargs: Any, -) -> Quantity: +) -> Any: """Apply a method of Astropy Quantity to any data.""" - data = Quantity(data, units) + attr = getattr(Quantity(data, units), name) - if isinstance(attr := getattr(data, name), MethodType): - return ensure_consistency(data, attr(*args, **kwargs)) + if isinstance(attr, (MethodType, MethodWrapperType)): + return attr(*args, **kwargs) else: - return ensure_consistency(data, attr) - - -def ensure_consistency(data_in: Any, data_out: Any, /) -> Quantity: - """Ensure consistency between input and output data.""" - if not isinstance(data_in, Quantity): - raise TypeError("Input must be Astropy Quantity.") - - if not isinstance(data_out, Quantity): - raise TypeError("Output must be Astropy Quantity.") - - if data_out.shape != data_in.shape: - raise ValueError("Input and output shapes must be same.") - - return data_out + return attr def decompose(da: TDataArray, /) -> TDataArray: From 0391b938a25b35b94842dc470b26db2951044a3c Mon Sep 17 00:00:00 2001 From: Akio Taniguchi Date: Sun, 17 Dec 2023 05:57:54 +0000 Subject: [PATCH 14/31] =?UTF-8?q?#5=20Rename=20module=20(methods=20?= =?UTF-8?q?=E2=86=92=20quantity)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- tests/{test_methods.py => test_quantity.py} | 2 +- xarray_units/__init__.py | 4 ++-- xarray_units/{methods.py => quantity.py} | 0 3 files changed, 3 insertions(+), 3 deletions(-) rename tests/{test_methods.py => test_quantity.py} (94%) rename xarray_units/{methods.py => quantity.py} (100%) diff --git a/tests/test_methods.py b/tests/test_quantity.py similarity index 94% rename from tests/test_methods.py rename to tests/test_quantity.py index e1b39c8..305c73c 100644 --- a/tests/test_methods.py +++ b/tests/test_quantity.py @@ -3,7 +3,7 @@ from astropy.constants import c # type: ignore from astropy.units import spectral # type: ignore from xarray.testing import assert_identical # type: ignore -from xarray_units.methods import apply, decompose, like, set, to +from xarray_units.quantity import apply, decompose, like, set, to # test datasets diff --git a/xarray_units/__init__.py b/xarray_units/__init__.py index 70ac860..477778f 100644 --- a/xarray_units/__init__.py +++ b/xarray_units/__init__.py @@ -1,8 +1,8 @@ -__all__ = ["methods", "operators", "utils"] +__all__ = ["operators", "quantity", "utils"] __version__ = "0.1.0" # submodules -from . import methods from . import operators +from . import quantity from . import utils diff --git a/xarray_units/methods.py b/xarray_units/quantity.py similarity index 100% rename from xarray_units/methods.py rename to xarray_units/quantity.py From f09f109d28c731d7be1e0462bfec8624b3cf8fb2 Mon Sep 17 00:00:00 2001 From: Akio Taniguchi Date: Sun, 17 Dec 2023 12:13:44 +0000 Subject: [PATCH 15/31] #5 Update type hints of units_of --- xarray_units/quantity.py | 7 ++++--- xarray_units/utils.py | 16 +++------------- 2 files changed, 7 insertions(+), 16 deletions(-) diff --git a/xarray_units/quantity.py b/xarray_units/quantity.py index 41032f3..258c2d8 100644 --- a/xarray_units/quantity.py +++ b/xarray_units/quantity.py @@ -44,7 +44,7 @@ def apply( UnitsNotValidError: Raised if units are not valid. """ - units = units_of(da, True) + units = units_of(da, strict=True) def per_block(block: TDataArray) -> TDataArray: data = apply_any(block, units, name, *args, **kwargs) @@ -60,7 +60,7 @@ def per_block(block: TDataArray) -> TDataArray: except Exception as error: raise UnitsApplicationError(error) - return set(applied, units_of(tested, True), True) + return set(applied, units_of(tested, strict=True), True) def apply_any( @@ -120,7 +120,8 @@ def like( UnitsNotValidError: Raised if units are not valid. """ - return apply(da, "to", units_of(other, True), equivalencies) + units = units_of(other, strict=True) + return apply(da, "to", units, equivalencies) def set( diff --git a/xarray_units/utils.py b/xarray_units/utils.py index bda0624..ce7d32c 100644 --- a/xarray_units/utils.py +++ b/xarray_units/utils.py @@ -59,26 +59,16 @@ class UnitsNotValidError(UnitsError): @overload -def units_of(obj: Quantity, /, strict: Literal[False] = False) -> UnitBase: +def units_of(obj: Any, /, *, strict: Literal[False] = False) -> Optional[UnitBase]: ... @overload -def units_of(obj: Quantity, /, strict: Literal[True] = True) -> UnitBase: +def units_of(obj: Any, /, *, strict: Literal[True] = True) -> UnitBase: ... -@overload -def units_of(obj: Any, /, strict: Literal[False] = False) -> Optional[UnitBase]: - ... - - -@overload -def units_of(obj: Any, /, strict: Literal[True] = True) -> UnitBase: - ... - - -def units_of(obj: Any, /, strict: bool = False) -> Optional[UnitBase]: +def units_of(obj: Any, /, *, strict: bool = False) -> Optional[UnitBase]: """Return units of an object if they exist and are valid.""" if isinstance(obj, Quantity): if isinstance(units := obj.unit, UnitBase): From c03edf5870efdf3ab39f90d5d0f84f76c1c1c80b Mon Sep 17 00:00:00 2001 From: Akio Taniguchi Date: Sun, 17 Dec 2023 12:14:48 +0000 Subject: [PATCH 16/31] #5 Fix unset so as not to raise KeyError --- xarray_units/quantity.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xarray_units/quantity.py b/xarray_units/quantity.py index 258c2d8..325cdd3 100644 --- a/xarray_units/quantity.py +++ b/xarray_units/quantity.py @@ -188,5 +188,5 @@ def unset(da: TDataArray, /) -> TDataArray: """ da = da.copy(data=da.data) - da.attrs.pop(UNITS_ATTR) + da.attrs.pop(UNITS_ATTR, None) return da From 297246c4f3af5f757c3b33fb8c870bc6a77dc0f3 Mon Sep 17 00:00:00 2001 From: Akio Taniguchi Date: Sun, 17 Dec 2023 12:39:12 +0000 Subject: [PATCH 17/31] #5 Update take and take_any --- xarray_units/operators.py | 66 ++++++++++++++++++--------------------- 1 file changed, 31 insertions(+), 35 deletions(-) diff --git a/xarray_units/operators.py b/xarray_units/operators.py index b76cca8..5770b67 100644 --- a/xarray_units/operators.py +++ b/xarray_units/operators.py @@ -19,27 +19,25 @@ # standard library import operator as opr -from typing import Any, Literal, Union +from typing import Any, Literal, Optional, Union, get_args # dependencies from astropy.units import Quantity -from numpy import bool_ -from numpy.typing import NDArray -from xarray import map_blocks -from .methods import set -from .utils import ( - UNITS_ONE, - TDataArray, - UnitsApplicationError, - UnitsLike, - UnitsNotFoundError, - units_of, -) +from xarray import DataArray +from xarray_units.quantity import set, to, unset +from .utils import TDataArray, UnitsApplicationError, UnitsLike, units_of # type hints -Operator = Literal[ +AnyUnitsOperator = Literal[ + "mul", # * + "pow", # ** + "matmul", # @ + "truediv", # / + "mod", # % +] +SameUnitsOperator = Literal[ "lt", # < "le", # <= "eq", # == @@ -48,20 +46,16 @@ "gt", # > "add", # + "sub", # - - "mul", # * - "pow", # ** - "matmul", # @ - "truediv", # / "floordiv", # // - "mod", # % ] +Operator = Union[AnyUnitsOperator, SameUnitsOperator] def take(left: TDataArray, operator: Operator, right: Any, /) -> TDataArray: """Perform an operation between a DataArray and any data with units. Args: - left: DataArray with units on the left side of the operator.. + left: DataArray with units on the left side of the operator. operator: Name of the operator (e.g. ``"add"``, ``"gt"``). right: Any data on the right side of the operator. @@ -76,28 +70,29 @@ def take(left: TDataArray, operator: Operator, right: Any, /) -> TDataArray: UnitsNotValidError: Raised if units are not valid. """ - if (left_units := units_of(left)) is None: - raise UnitsNotFoundError(repr(left)) - - if (right_units := units_of(right)) is None: - right_units = UNITS_ONE + left_units = units_of(left, strict=True) + right_units = units_of(right) try: if operator == "pow": - test = take_any(1, left_units, operator, right, right_units) + tested = take_any(1, left_units, operator, right, right_units) else: - test = take_any(1, left_units, operator, 1, right_units) + tested = take_any(1, left_units, operator, 1, right_units) except Exception as error: raise UnitsApplicationError(error) - def per_block(left_: TDataArray, right_: Any) -> TDataArray: - data = take_any(left_, left_units, operator, right_, right_units) - return left.copy(data=data) + if operator in get_args(SameUnitsOperator): + if isinstance(right, DataArray): + right = to(right, left_units) + elif isinstance(right, Quantity): + right = right.to(left_units) # type: ignore + + result = getattr(opr, operator)(left, right) - if (units := units_of(test)) is None: - return map_blocks(per_block, left, (right,)) + if (units := units_of(tested)) is None: + return unset(result) else: - return set(map_blocks(per_block, left, (right,)), units, True) + return set(result, units, True) def take_any( @@ -105,12 +100,13 @@ def take_any( left_units: UnitsLike, operator: Operator, right: Any, - right_units: UnitsLike, + right_units: Optional[UnitsLike], /, -) -> Union[Quantity, NDArray[bool_], bool_]: +) -> Any: """Perform an operation between two any datasets.""" left = Quantity(left, left_units) right = Quantity(right, right_units) + return getattr(opr, operator)(left, right) From 523b36bdd535a1d2a9d7fec4a6ef38c6f2bc003f Mon Sep 17 00:00:00 2001 From: Akio Taniguchi Date: Sun, 17 Dec 2023 12:49:12 +0000 Subject: [PATCH 18/31] #5 Reorder functions --- xarray_units/operators.py | 100 +++++++++++++++++++------------------- 1 file changed, 51 insertions(+), 49 deletions(-) diff --git a/xarray_units/operators.py b/xarray_units/operators.py index 5770b67..b269d78 100644 --- a/xarray_units/operators.py +++ b/xarray_units/operators.py @@ -1,19 +1,21 @@ __all__ = [ "take", + # any-units operators + "mul", # * + "pow", # ** + "matmul", # @ + "truediv", # / + "mod", # % + # same-units operators + "add", # + + "sub", # - + "floordiv", # // "lt", # < "le", # <= "eq", # == "ne", # != "ge", # >= "gt", # > - "add", # + - "sub", # - - "mul", # * - "pow", # ** - "matmul", # @ - "truediv", # / - "floordiv", # // - "mod", # % ] @@ -38,15 +40,15 @@ "mod", # % ] SameUnitsOperator = Literal[ + "add", # + + "sub", # - + "floordiv", # // "lt", # < "le", # <= "eq", # == "ne", # != "ge", # >= "gt", # > - "add", # + - "sub", # - - "floordiv", # // ] Operator = Union[AnyUnitsOperator, SameUnitsOperator] @@ -110,34 +112,29 @@ def take_any( return getattr(opr, operator)(left, right) -def lt(left: TDataArray, right: Any) -> TDataArray: - """Perform ``(left) < (right)`` with units.""" - return take(left, "lt", right) - - -def le(left: TDataArray, right: Any) -> TDataArray: - """Perform ``(left) <= (right)`` with units.""" - return take(left, "le", right) +def mul(left: TDataArray, right: Any) -> TDataArray: + """Perform ``(left) * (right)`` with units.""" + return take(left, "mul", right) -def eq(left: TDataArray, right: Any) -> TDataArray: - """Perform ``(left) == (right)`` with units.""" - return take(left, "eq", right) +def pow(left: TDataArray, right: Any) -> TDataArray: + """Perform ``(left) ** (right)`` with units.""" + return take(left, "pow", right) -def ne(left: TDataArray, right: Any) -> TDataArray: - """Perform ``(left) != (right)`` with units.""" - return take(left, "ne", right) +def matmul(left: TDataArray, right: Any) -> TDataArray: + """Perform ``(left) @ (right)`` with units.""" + return take(left, "matmul", right) -def ge(left: TDataArray, right: Any) -> TDataArray: - """Perform ``(left) >= (right)`` with units.""" - return take(left, "ge", right) +def truediv(left: TDataArray, right: Any) -> TDataArray: + """Perform ``(left) / (right)`` with units.""" + return take(left, "truediv", right) -def gt(left: TDataArray, right: Any) -> TDataArray: - """Perform ``(left) > (right)`` with units.""" - return take(left, "gt", right) +def mod(left: TDataArray, right: Any) -> TDataArray: + """Perform ``(left) % (right)`` with units.""" + return take(left, "mod", right) def add(left: TDataArray, right: Any) -> TDataArray: @@ -150,31 +147,36 @@ def sub(left: TDataArray, right: Any) -> TDataArray: return take(left, "sub", right) -def mul(left: TDataArray, right: Any) -> TDataArray: - """Perform ``(left) * (right)`` with units.""" - return take(left, "mul", right) +def floordiv(left: TDataArray, right: Any) -> TDataArray: + """Perform ``(left) // (right)`` with units.""" + return take(left, "floordiv", right) -def pow(left: TDataArray, right: Any) -> TDataArray: - """Perform ``(left) ** (right)`` with units.""" - return take(left, "pow", right) +def lt(left: TDataArray, right: Any) -> TDataArray: + """Perform ``(left) < (right)`` with units.""" + return take(left, "lt", right) -def matmul(left: TDataArray, right: Any) -> TDataArray: - """Perform ``(left) @ (right)`` with units.""" - return take(left, "matmul", right) +def le(left: TDataArray, right: Any) -> TDataArray: + """Perform ``(left) <= (right)`` with units.""" + return take(left, "le", right) -def truediv(left: TDataArray, right: Any) -> TDataArray: - """Perform ``(left) / (right)`` with units.""" - return take(left, "truediv", right) +def eq(left: TDataArray, right: Any) -> TDataArray: + """Perform ``(left) == (right)`` with units.""" + return take(left, "eq", right) -def floordiv(left: TDataArray, right: Any) -> TDataArray: - """Perform ``(left) // (right)`` with units.""" - return take(left, "floordiv", right) +def ne(left: TDataArray, right: Any) -> TDataArray: + """Perform ``(left) != (right)`` with units.""" + return take(left, "ne", right) -def mod(left: TDataArray, right: Any) -> TDataArray: - """Perform ``(left) % (right)`` with units.""" - return take(left, "mod", right) +def ge(left: TDataArray, right: Any) -> TDataArray: + """Perform ``(left) >= (right)`` with units.""" + return take(left, "ge", right) + + +def gt(left: TDataArray, right: Any) -> TDataArray: + """Perform ``(left) > (right)`` with units.""" + return take(left, "gt", right) From 8e5f929bb9f4cf4a6cf0d7bbba0ac198693632eb Mon Sep 17 00:00:00 2001 From: Akio Taniguchi Date: Sun, 17 Dec 2023 13:34:53 +0000 Subject: [PATCH 19/31] #5 Update apply and apply_any --- xarray_units/quantity.py | 17 +++++++++-------- xarray_units/utils.py | 2 +- 2 files changed, 10 insertions(+), 9 deletions(-) diff --git a/xarray_units/quantity.py b/xarray_units/quantity.py index 325cdd3..c23f52f 100644 --- a/xarray_units/quantity.py +++ b/xarray_units/quantity.py @@ -10,6 +10,7 @@ from astropy.units import Quantity from xarray import DataArray from .utils import ( + TEST_DATA, UNITS_ATTR, Equivalencies, TDataArray, @@ -22,7 +23,7 @@ def apply( da: TDataArray, - name: str, + method: str, /, *args: Any, **kwargs: Any, @@ -31,7 +32,7 @@ def apply( Args: da: Input DataArray with units. - name: Method (or property) name of Astropy Quantity. + method: Method (or property) name of Astropy Quantity. *args: Positional arguments of the method. *kwargs: Keyword arguments of the method. @@ -47,32 +48,32 @@ def apply( units = units_of(da, strict=True) def per_block(block: TDataArray) -> TDataArray: - data = apply_any(block, units, name, *args, **kwargs) + data = apply_any(block, units, method, *args, **kwargs) return block.copy(data=data) try: - tested = apply_any(1, units, name, *args, **kwargs) + test = apply_any(TEST_DATA, units, method, *args, **kwargs) except Exception as error: raise UnitsApplicationError(error) try: - applied = da.map_blocks(per_block) + result = da.map_blocks(per_block) except Exception as error: raise UnitsApplicationError(error) - return set(applied, units_of(tested, strict=True), True) + return set(result, units_of(test, strict=True), True) def apply_any( data: Any, units: UnitsLike, - name: str, + method: str, /, *args: Any, **kwargs: Any, ) -> Any: """Apply a method of Astropy Quantity to any data.""" - attr = getattr(Quantity(data, units), name) + attr = getattr(Quantity(data, units), method) if isinstance(attr, (MethodType, MethodWrapperType)): return attr(*args, **kwargs) diff --git a/xarray_units/utils.py b/xarray_units/utils.py index ce7d32c..5f56d56 100644 --- a/xarray_units/utils.py +++ b/xarray_units/utils.py @@ -24,8 +24,8 @@ # constants +TEST_DATA = 1 UNITS_ATTR = "units" -UNITS_ONE = "1" class UnitsError(Exception): From d5419d1951ba4cd3c765695dc3cf90b19b327cd9 Mon Sep 17 00:00:00 2001 From: Akio Taniguchi Date: Sun, 17 Dec 2023 13:43:28 +0000 Subject: [PATCH 20/31] #5 Update take --- xarray_units/operators.py | 34 ++++++++++++---------------------- xarray_units/quantity.py | 4 ++-- xarray_units/utils.py | 2 +- 3 files changed, 15 insertions(+), 25 deletions(-) diff --git a/xarray_units/operators.py b/xarray_units/operators.py index b269d78..226ff3f 100644 --- a/xarray_units/operators.py +++ b/xarray_units/operators.py @@ -21,14 +21,14 @@ # standard library import operator as opr -from typing import Any, Literal, Optional, Union, get_args +from typing import Any, Literal, Union, get_args # dependencies from astropy.units import Quantity from xarray import DataArray -from xarray_units.quantity import set, to, unset -from .utils import TDataArray, UnitsApplicationError, UnitsLike, units_of +from xarray_units.quantity import apply_any, set, to, unset +from .utils import TESTER, TDataArray, UnitsApplicationError, units_of # type hints @@ -77,9 +77,11 @@ def take(left: TDataArray, operator: Operator, right: Any, /) -> TDataArray: try: if operator == "pow": - tested = take_any(1, left_units, operator, right, right_units) + args = (Quantity(right, right_units),) else: - tested = take_any(1, left_units, operator, 1, right_units) + args = (Quantity(TESTER, right_units),) + + test = apply_any(TESTER, left_units, f"__{operator}__", *args) except Exception as error: raise UnitsApplicationError(error) @@ -89,29 +91,17 @@ def take(left: TDataArray, operator: Operator, right: Any, /) -> TDataArray: elif isinstance(right, Quantity): right = right.to(left_units) # type: ignore - result = getattr(opr, operator)(left, right) + try: + result = getattr(opr, operator)(left, right) + except Exception as error: + raise UnitsApplicationError(error) - if (units := units_of(tested)) is None: + if (units := units_of(test)) is None: return unset(result) else: return set(result, units, True) -def take_any( - left: Any, - left_units: UnitsLike, - operator: Operator, - right: Any, - right_units: Optional[UnitsLike], - /, -) -> Any: - """Perform an operation between two any datasets.""" - left = Quantity(left, left_units) - right = Quantity(right, right_units) - - return getattr(opr, operator)(left, right) - - def mul(left: TDataArray, right: Any) -> TDataArray: """Perform ``(left) * (right)`` with units.""" return take(left, "mul", right) diff --git a/xarray_units/quantity.py b/xarray_units/quantity.py index c23f52f..021a10c 100644 --- a/xarray_units/quantity.py +++ b/xarray_units/quantity.py @@ -10,7 +10,7 @@ from astropy.units import Quantity from xarray import DataArray from .utils import ( - TEST_DATA, + TESTER, UNITS_ATTR, Equivalencies, TDataArray, @@ -52,7 +52,7 @@ def per_block(block: TDataArray) -> TDataArray: return block.copy(data=data) try: - test = apply_any(TEST_DATA, units, method, *args, **kwargs) + test = apply_any(TESTER, units, method, *args, **kwargs) except Exception as error: raise UnitsApplicationError(error) diff --git a/xarray_units/utils.py b/xarray_units/utils.py index 5f56d56..f046790 100644 --- a/xarray_units/utils.py +++ b/xarray_units/utils.py @@ -24,7 +24,7 @@ # constants -TEST_DATA = 1 +TESTER = 1 UNITS_ATTR = "units" From 92b905281afc1d5bb1a6e49ede8ec3959c823660 Mon Sep 17 00:00:00 2001 From: Akio Taniguchi Date: Sun, 17 Dec 2023 13:57:17 +0000 Subject: [PATCH 21/31] #5 Make overwrite keyword argument --- xarray_units/operators.py | 2 +- xarray_units/quantity.py | 8 +++++--- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/xarray_units/operators.py b/xarray_units/operators.py index 226ff3f..ffabad2 100644 --- a/xarray_units/operators.py +++ b/xarray_units/operators.py @@ -99,7 +99,7 @@ def take(left: TDataArray, operator: Operator, right: Any, /) -> TDataArray: if (units := units_of(test)) is None: return unset(result) else: - return set(result, units, True) + return set(result, units, overwrite=True) def mul(left: TDataArray, right: Any) -> TDataArray: diff --git a/xarray_units/quantity.py b/xarray_units/quantity.py index 021a10c..cc3979d 100644 --- a/xarray_units/quantity.py +++ b/xarray_units/quantity.py @@ -61,7 +61,7 @@ def per_block(block: TDataArray) -> TDataArray: except Exception as error: raise UnitsApplicationError(error) - return set(result, units_of(test, strict=True), True) + return set(result, units_of(test, strict=True), overwrite=True) def apply_any( @@ -121,14 +121,14 @@ def like( UnitsNotValidError: Raised if units are not valid. """ - units = units_of(other, strict=True) - return apply(da, "to", units, equivalencies) + return apply(da, "to", units_of(other, strict=True), equivalencies) def set( da: TDataArray, units: UnitsLike, /, + *, overwrite: bool = False, ) -> TDataArray: """Set units to a DataArray. @@ -136,6 +136,8 @@ def set( Args: da: Input DataArray. units: Units to be set to the input. + + Keyword Args: overwrite: Whether to overwrite existing units. Returns: From 9e6d1709e828b3394d05140b1ec3e07ef1101e26 Mon Sep 17 00:00:00 2001 From: Akio Taniguchi Date: Sun, 17 Dec 2023 13:58:10 +0000 Subject: [PATCH 22/31] =?UTF-8?q?#5=20Rename=20module=20(operators=20?= =?UTF-8?q?=E2=86=92=20operator)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- xarray_units/__init__.py | 4 ++-- xarray_units/{operators.py => operator.py} | 0 2 files changed, 2 insertions(+), 2 deletions(-) rename xarray_units/{operators.py => operator.py} (100%) diff --git a/xarray_units/__init__.py b/xarray_units/__init__.py index 477778f..245fa95 100644 --- a/xarray_units/__init__.py +++ b/xarray_units/__init__.py @@ -1,8 +1,8 @@ -__all__ = ["operators", "quantity", "utils"] +__all__ = ["operator", "quantity", "utils"] __version__ = "0.1.0" # submodules -from . import operators +from . import operator from . import quantity from . import utils diff --git a/xarray_units/operators.py b/xarray_units/operator.py similarity index 100% rename from xarray_units/operators.py rename to xarray_units/operator.py From 16756e668d26daa559237c5b533331b5de86299a Mon Sep 17 00:00:00 2001 From: Akio Taniguchi Date: Sun, 17 Dec 2023 14:07:19 +0000 Subject: [PATCH 23/31] #5 Update docstrings of units_of --- xarray_units/utils.py | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/xarray_units/utils.py b/xarray_units/utils.py index f046790..7803979 100644 --- a/xarray_units/utils.py +++ b/xarray_units/utils.py @@ -69,7 +69,21 @@ def units_of(obj: Any, /, *, strict: Literal[True] = True) -> UnitBase: def units_of(obj: Any, /, *, strict: bool = False) -> Optional[UnitBase]: - """Return units of an object if they exist and are valid.""" + """Return units of an object if they exist and are valid. + + Args: + obj: Any object from which units are extracted. + + Keyword Args: + strict: Whether to allow None as the return value. + + Raises: + UnitsNotFoundError: Raised if ``strict`` is ``True`` + but units do not exist in the object. + UnitsNotValidError: Raised if units exist in the object + but they cannot be parsed into ``UnitBase``. + + """ if isinstance(obj, Quantity): if isinstance(units := obj.unit, UnitBase): return units From a976ffb3f3e6986c8a676d9f433ed5539b925465 Mon Sep 17 00:00:00 2001 From: Akio Taniguchi Date: Sun, 17 Dec 2023 14:11:29 +0000 Subject: [PATCH 24/31] #5 Update README --- README.md | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/README.md b/README.md index 922194b..83f404b 100644 --- a/README.md +++ b/README.md @@ -1,2 +1,15 @@ # xarray-units + +[![Release](https://img.shields.io/pypi/v/xarray-units?label=Release&color=cornflowerblue&style=flat-square)](https://pypi.org/project/xarray-units/) +[![Python](https://img.shields.io/pypi/pyversions/xarray-units?label=Python&color=cornflowerblue&style=flat-square)](https://pypi.org/project/xarray-units/) +[![Downloads](https://img.shields.io/pypi/dm/xarray-units?label=Downloads&color=cornflowerblue&style=flat-square)](https://pepy.tech/project/xarray-units) +[![DOI](https://img.shields.io/badge/DOI-10.5281/zenodo.10354517-cornflowerblue?style=flat-square)](https://doi.org/10.5281/zenodo.10354517) +[![Tests](https://img.shields.io/github/actions/workflow/status/astropenguin/xarray-units/tests.yaml?label=Tests&style=flat-square)](https://github.com/astropenguin/xarray-units/actions) + xarray extension for handling units + +## Installation + +```shell +pip install xarray-units==0.1.0 +``` From eb0b14b4566a425d03ee72bdb98a2e094fba30ef Mon Sep 17 00:00:00 2001 From: Akio Taniguchi Date: Sun, 17 Dec 2023 14:12:48 +0000 Subject: [PATCH 25/31] #5 Add DOI to citation file --- CITATION.cff | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CITATION.cff b/CITATION.cff index 80e1830..b3ad7f2 100644 --- a/CITATION.cff +++ b/CITATION.cff @@ -6,7 +6,7 @@ abstract: "xarray extension for handling units" version: 0.1.0 date-released: 2023-12-11 license: "MIT" -doi: "" +doi: "10.5281/zenodo.10354517" url: "https://github.com/astropenguin/xarray-units/" authors: - given-names: "Akio" From e34c33c5eb7e73ddc9075532b7eb9cc34105f849 Mon Sep 17 00:00:00 2001 From: Akio Taniguchi Date: Sun, 17 Dec 2023 15:06:15 +0000 Subject: [PATCH 26/31] #5 Fix test operation in take --- xarray_units/operator.py | 22 ++++++++++++++-------- 1 file changed, 14 insertions(+), 8 deletions(-) diff --git a/xarray_units/operator.py b/xarray_units/operator.py index ffabad2..81070b0 100644 --- a/xarray_units/operator.py +++ b/xarray_units/operator.py @@ -75,21 +75,27 @@ def take(left: TDataArray, operator: Operator, right: Any, /) -> TDataArray: left_units = units_of(left, strict=True) right_units = units_of(right) - try: - if operator == "pow": - args = (Quantity(right, right_units),) - else: - args = (Quantity(TESTER, right_units),) + if operator == "pow": + method = f"__{operator}__" + args = (Quantity(right, right_units),) + elif operator == "matmul": + method = "__mul__" + args = (Quantity(TESTER, right_units),) + else: + method = f"__{operator}__" + args = (Quantity(TESTER, right_units),) - test = apply_any(TESTER, left_units, f"__{operator}__", *args) + try: + test = apply_any(TESTER, left_units, method, *args) except Exception as error: raise UnitsApplicationError(error) if operator in get_args(SameUnitsOperator): + if isinstance(right, Quantity): + right = right.to(left_units).value # type: ignore + if isinstance(right, DataArray): right = to(right, left_units) - elif isinstance(right, Quantity): - right = right.to(left_units) # type: ignore try: result = getattr(opr, operator)(left, right) From 46e3ab91041148fde65e83a6e93598c78aefb15a Mon Sep 17 00:00:00 2001 From: Akio Taniguchi Date: Sun, 17 Dec 2023 15:41:18 +0000 Subject: [PATCH 27/31] #5 Fix wrong type of mod operator --- xarray_units/operator.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/xarray_units/operator.py b/xarray_units/operator.py index 81070b0..7ca5165 100644 --- a/xarray_units/operator.py +++ b/xarray_units/operator.py @@ -5,11 +5,11 @@ "pow", # ** "matmul", # @ "truediv", # / - "mod", # % # same-units operators "add", # + "sub", # - "floordiv", # // + "mod", # % "lt", # < "le", # <= "eq", # == @@ -37,12 +37,12 @@ "pow", # ** "matmul", # @ "truediv", # / - "mod", # % ] SameUnitsOperator = Literal[ "add", # + "sub", # - "floordiv", # // + "mod", # % "lt", # < "le", # <= "eq", # == @@ -128,11 +128,6 @@ def truediv(left: TDataArray, right: Any) -> TDataArray: return take(left, "truediv", right) -def mod(left: TDataArray, right: Any) -> TDataArray: - """Perform ``(left) % (right)`` with units.""" - return take(left, "mod", right) - - def add(left: TDataArray, right: Any) -> TDataArray: """Perform ``(left) + (right)`` with units.""" return take(left, "add", right) @@ -148,6 +143,11 @@ def floordiv(left: TDataArray, right: Any) -> TDataArray: return take(left, "floordiv", right) +def mod(left: TDataArray, right: Any) -> TDataArray: + """Perform ``(left) % (right)`` with units.""" + return take(left, "mod", right) + + def lt(left: TDataArray, right: Any) -> TDataArray: """Perform ``(left) < (right)`` with units.""" return take(left, "lt", right) From 7b13b93cde753c3375ccf19241903889e8444d67 Mon Sep 17 00:00:00 2001 From: Akio Taniguchi Date: Sun, 17 Dec 2023 16:16:22 +0000 Subject: [PATCH 28/31] #5 Add test for operator module --- tests/test_operator.py | 121 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 121 insertions(+) create mode 100644 tests/test_operator.py diff --git a/tests/test_operator.py b/tests/test_operator.py new file mode 100644 index 0000000..d3b0ce7 --- /dev/null +++ b/tests/test_operator.py @@ -0,0 +1,121 @@ +# standard library +from typing import Any + + +# dependencies +from astropy.units import Quantity +from pytest import mark, raises +from xarray import DataArray +from xarray.testing import assert_identical # type: ignore +from xarray_units import operator as opr +from xarray_units.operator import Operator, take +from xarray_units.quantity import set +from xarray_units.utils import UnitsApplicationError + + +# test data +km = set(DataArray([1, 2, 3]), "km") +mm = set(DataArray([1, 2, 3]) * 1e6, "mm") +sc_1 = 2 +sc_2 = Quantity(2000, "m") + + +data_take: list[tuple[DataArray, Operator, Any, Any]] = [ + (km, "mul", sc_1, set(DataArray([2, 4, 6]), "km")), + (km, "mul", sc_2, set(DataArray([2, 4, 6]) * 1e3, "km m")), + (km, "mul", mm, set(DataArray([1, 4, 9]) * 1e6, "km mm")), + (mm, "mul", km, set(DataArray([1, 4, 9]) * 1e6, "km mm")), + # + (km, "pow", sc_1, set(DataArray([1, 4, 9]), "km2")), + (km, "pow", sc_2, UnitsApplicationError), + (km, "pow", mm, UnitsApplicationError), + (mm, "pow", km, UnitsApplicationError), + # + (km, "matmul", sc_1, UnitsApplicationError), + (km, "matmul", sc_2, UnitsApplicationError), + (km, "matmul", mm, set(DataArray(14) * 1e6, "km mm")), + (mm, "matmul", km, set(DataArray(14) * 1e6, "km mm")), + # + (km, "truediv", sc_1, set(DataArray([0.5, 1, 1.5]), "km")), + (km, "truediv", sc_2, set(DataArray([0.5, 1.0, 1.5]) * 1e-3, "km m-1")), + (km, "truediv", mm, set(DataArray([1, 1, 1]) * 1e-6, "km mm-1")), + (mm, "truediv", km, set(DataArray([1, 1, 1]) * 1e6, "mm km-1")), + # + (km, "add", sc_1, UnitsApplicationError), + (km, "add", sc_2, set(DataArray([3, 4, 5]), "km")), + (km, "add", mm, set(DataArray([2, 4, 6]), "km")), + (mm, "add", km, set(DataArray([2, 4, 6]) * 1e6, "mm")), + # + (km, "sub", sc_1, UnitsApplicationError), + (km, "sub", sc_2, set(DataArray([-1, 0, 1]), "km")), + (km, "sub", mm, set(DataArray([0, 0, 0]), "km")), + (mm, "sub", km, set(DataArray([0, 0, 0]), "mm")), + # + (km, "floordiv", sc_1, UnitsApplicationError), + (km, "floordiv", sc_2, set(DataArray([0, 1, 1]), "1")), + (km, "floordiv", mm, set(DataArray([1, 1, 1]), "1")), + (mm, "floordiv", km, set(DataArray([1, 1, 1]), "1")), + # + (km, "mod", sc_1, UnitsApplicationError), + (km, "mod", sc_2, set(DataArray([1, 0, 1]), "km")), + (km, "mod", mm, set(DataArray([0, 0, 0]), "km")), + (mm, "mod", km, set(DataArray([0, 0, 0]), "mm")), + # + (km, "lt", sc_1, UnitsApplicationError), + (km, "lt", sc_2, DataArray([True, False, False])), + (km, "lt", mm, DataArray([False, False, False])), + (mm, "lt", km, DataArray([False, False, False])), + # + (km, "le", sc_1, UnitsApplicationError), + (km, "le", sc_2, DataArray([True, True, False])), + (km, "le", mm, DataArray([True, True, True])), + (mm, "le", km, DataArray([True, True, True])), + # + # (km, "eq", sc_1, UnitsApplicationError), + (km, "eq", sc_2, DataArray([False, True, False])), + (km, "eq", mm, DataArray([True, True, True])), + (mm, "eq", km, DataArray([True, True, True])), + # + # (km, "ne", sc_1, UnitsApplicationError), + (km, "ne", sc_2, DataArray([True, False, True])), + (km, "ne", mm, DataArray([False, False, False])), + (mm, "ne", km, DataArray([False, False, False])), + # + (km, "ge", sc_1, UnitsApplicationError), + (km, "ge", sc_2, DataArray([False, True, True])), + (km, "ge", mm, DataArray([True, True, True])), + (mm, "ge", km, DataArray([True, True, True])), + # + (km, "gt", sc_1, UnitsApplicationError), + (km, "gt", sc_2, DataArray([False, False, True])), + (km, "gt", mm, DataArray([False, False, False])), + (mm, "gt", km, DataArray([False, False, False])), +] + + +@mark.parametrize("left, operator, right, expected", data_take) +def test_take( + left: DataArray, + operator: Operator, + right: Any, + expected: Any, +) -> None: + if expected is UnitsApplicationError: + with raises(expected): + take(left, operator, right) + else: + assert_identical(take(left, operator, right), expected) + + +@mark.parametrize("left, operator, right, expected", data_take) +def test_take_alias( + left: DataArray, + operator: Operator, + right: Any, + expected: DataArray, +) -> None: + if expected is UnitsApplicationError: + with raises(expected): + getattr(opr, operator)(left, right) + else: + assert_identical(getattr(opr, operator)(left, right), expected) From 627d8a9d377ecc9ee2290feefefcda05ee85c473 Mon Sep 17 00:00:00 2001 From: Akio Taniguchi Date: Sun, 17 Dec 2023 22:45:34 +0000 Subject: [PATCH 29/31] #5 Fix test operation in take --- tests/test_operator.py | 4 ++-- xarray_units/operator.py | 3 +++ 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/tests/test_operator.py b/tests/test_operator.py index d3b0ce7..1163148 100644 --- a/tests/test_operator.py +++ b/tests/test_operator.py @@ -71,12 +71,12 @@ (km, "le", mm, DataArray([True, True, True])), (mm, "le", km, DataArray([True, True, True])), # - # (km, "eq", sc_1, UnitsApplicationError), + (km, "eq", sc_1, UnitsApplicationError), (km, "eq", sc_2, DataArray([False, True, False])), (km, "eq", mm, DataArray([True, True, True])), (mm, "eq", km, DataArray([True, True, True])), # - # (km, "ne", sc_1, UnitsApplicationError), + (km, "ne", sc_1, UnitsApplicationError), (km, "ne", sc_2, DataArray([True, False, True])), (km, "ne", mm, DataArray([False, False, False])), (mm, "ne", km, DataArray([False, False, False])), diff --git a/xarray_units/operator.py b/xarray_units/operator.py index 7ca5165..4689ffd 100644 --- a/xarray_units/operator.py +++ b/xarray_units/operator.py @@ -81,6 +81,9 @@ def take(left: TDataArray, operator: Operator, right: Any, /) -> TDataArray: elif operator == "matmul": method = "__mul__" args = (Quantity(TESTER, right_units),) + elif operator == "eq" or operator == "ne": + method = "__lt__" + args = (Quantity(TESTER, right_units),) else: method = f"__{operator}__" args = (Quantity(TESTER, right_units),) From c174550050cc6847978e63e7c08e17d48e269525 Mon Sep 17 00:00:00 2001 From: Akio Taniguchi Date: Sun, 17 Dec 2023 22:48:05 +0000 Subject: [PATCH 30/31] =?UTF-8?q?#5=20Update=20package=20version=20(0.1.0?= =?UTF-8?q?=20=E2=86=92=200.2.0)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- CITATION.cff | 2 +- README.md | 2 +- pyproject.toml | 2 +- xarray_units/__init__.py | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/CITATION.cff b/CITATION.cff index b3ad7f2..36b5d70 100644 --- a/CITATION.cff +++ b/CITATION.cff @@ -3,7 +3,7 @@ message: "If you use this software, please cite it as below." title: "xarray-units" abstract: "xarray extension for handling units" -version: 0.1.0 +version: 0.2.0 date-released: 2023-12-11 license: "MIT" doi: "10.5281/zenodo.10354517" diff --git a/README.md b/README.md index 83f404b..4bca70c 100644 --- a/README.md +++ b/README.md @@ -11,5 +11,5 @@ xarray extension for handling units ## Installation ```shell -pip install xarray-units==0.1.0 +pip install xarray-units==0.2.0 ``` diff --git a/pyproject.toml b/pyproject.toml index 3ebd823..9f38c81 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "xarray-units" -version = "0.1.0" +version = "0.2.0" description = "xarray extension for handling units" authors = ["Akio Taniguchi "] documentation = "https://astropenguin.github.io/xarray-units/" diff --git a/xarray_units/__init__.py b/xarray_units/__init__.py index 245fa95..fc4d437 100644 --- a/xarray_units/__init__.py +++ b/xarray_units/__init__.py @@ -1,5 +1,5 @@ __all__ = ["operator", "quantity", "utils"] -__version__ = "0.1.0" +__version__ = "0.2.0" # submodules From 5376ff2604ab671c81089cdfec8bfabd6fe14792 Mon Sep 17 00:00:00 2001 From: Akio Taniguchi Date: Sun, 17 Dec 2023 22:48:26 +0000 Subject: [PATCH 31/31] #5 Update release date --- CITATION.cff | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CITATION.cff b/CITATION.cff index 36b5d70..fe78f22 100644 --- a/CITATION.cff +++ b/CITATION.cff @@ -4,7 +4,7 @@ message: "If you use this software, please cite it as below." title: "xarray-units" abstract: "xarray extension for handling units" version: 0.2.0 -date-released: 2023-12-11 +date-released: 2023-12-18 license: "MIT" doi: "10.5281/zenodo.10354517" url: "https://github.com/astropenguin/xarray-units/"