From a394a4d4fb9f2b8a7c15d13a96e98fff52ff5582 Mon Sep 17 00:00:00 2001 From: nstarman Date: Mon, 21 Oct 2024 12:28:42 -0400 Subject: [PATCH] docs: experimental Signed-off-by: nstarman --- src/unxt/_src/experimental.py | 128 +++++++++++++++++++++++++++------- 1 file changed, 102 insertions(+), 26 deletions(-) diff --git a/src/unxt/_src/experimental.py b/src/unxt/_src/experimental.py index ff20aaa..b724db0 100644 --- a/src/unxt/_src/experimental.py +++ b/src/unxt/_src/experimental.py @@ -1,11 +1,20 @@ -# pylint: disable=import-error +r"""Experimental features. -"""unxt: Quantities in JAX. +unxt.experimental provides experimental features that are not yet ready for +general use. These features may be removed or changed in the future without +notice. -THIS MODULE IS NOT GUARANTEED TO HAVE A STABLE API! +On some occasions JAX's automatic differentiation functions do not work well +with quantities. This is easily checked by enabling runtime type-checking (see +the docs), which will raise an error if a quantity's units do not match the +expected input / output units of a function. In these cases, you can use the +functions in this module to provide the units to the automatic differentiation +functions. Instead of directly propagating the units through the automatic +differentiation functions, the units are stripped and re-applied, while also +being provided within the function being AD'd. -Copyright (c) 2023 Galactic Dynamics. All rights reserved. """ +# pylint: disable=import-error __all__ = ["grad", "hessian", "jacfwd"] @@ -21,22 +30,50 @@ from .quantity.core import Quantity from .typing_ext import Unit from unxt._src.quantity.api import ustrip +from unxt._src.units.core import units as parse_units P = ParamSpec("P") R = TypeVar("R", bound=Quantity) def grad( - fun: Callable[P, R], argnums: int = 0, *, units: tuple[Unit, ...] + fun: Callable[P, R], argnums: int = 0, *, units: tuple[Unit | str, ...] 
) -> Callable[P, R]: - # Gradient of a function with units + """Gradient of a function with units. + + In general, if you can use `quax.quaxify(jax.grad(func))` (or the syntactic + sugar `quaxed.grad(func)`), that's the better option! The difference from + those functions is how units are supported. `quaxify` will directly + propagate the units through the automatic differentiation functions. But + sometimes that doesn't work and we need to strip the units and re-apply + them. This function does that, using the ``units`` kwarg. + + See Also + -------- + jax.grad : The original JAX gradient function. + + Examples + -------- + >>> import jax.numpy as jnp + >>> import unxt + >>> from unxt import Quantity + + >>> def square_volume(x: Quantity["length"]) -> Quantity["volume"]: + ... return x ** 3 + + >>> grad_square_volume = unxt.experimental.grad(square_volume, units=("m",)) + >>> grad_square_volume(Quantity(2.0, "m")) + Quantity['area'](Array(12., dtype=float32, weak_type=True), unit='m2') + + """ + units_ = tuple(map(parse_units, units)) # Gradient of function, stripping and adding units @partial(jax.grad, argnums=argnums) def gradfun_mag(*args: P.args) -> ArrayLike: args_ = ( (a if unit is None else Quantity(a, unit)) - for a, unit in zip(args, units, strict=True) + for a, unit in zip(args, units_, strict=True) ) return fun(*args_).value @@ -45,28 +82,46 @@ def gradfun(*args: P.args, **kw: P.kwargs) -> R: # inside the function we are taking the grad of. 
args_ = tuple( (a if unit is None else ustrip(unit, a)) - for a, unit in zip(args, units, strict=True) + for a, unit in zip(args, units_, strict=True) ) # Call the grad, returning a Quantity value = fun(*args) grad_value = gradfun_mag(*args_) # Adjust the Quantity by the units of the derivative # TODO: get Quantity[unit] / unit2 -> Quantity[unit/unit2] working - return type_unparametrized(value)(grad_value, value.unit / units[argnums]) + return type_unparametrized(value)(grad_value, value.unit / units_[argnums]) return gradfun def jacfwd( - fun: Callable[P, R], argnums: int = 0, *, units: tuple[Unit, ...] + fun: Callable[P, R], argnums: int = 0, *, units: tuple[Unit | str, ...] ) -> Callable[P, R]: """Jacobian of ``fun`` evaluated column-by-column using forward-mode AD. - In general, if you can use ``quaxed.jacfwd``, that's the better option! The - difference from ``quaxed.jacfwd`` is how this function supports units. - ``quaxed.jacfwd`` does `quax.quaxify(jax.jacfwd)`, which will 'strip' the - units when passing through. But sometimes that doesn't work and we need the - units + In general, if you can use `quax.quaxify(jax.jacfwd(func))` (or the + syntactic sugar `quaxed.jacfwd(func)`), that's the better option! The + difference from those functions is how units are supported. `quaxify` + will directly propagate the units through the automatic differentiation + functions. But sometimes that doesn't work and we need to strip the units + and re-apply them. This function does that, using the ``units`` kwarg. + + See Also + -------- + jax.jacfwd : The original JAX jacfwd function. + + Examples + -------- + >>> import jax.numpy as jnp + >>> import unxt + >>> from unxt import Quantity + + >>> def square_volume(x: Quantity["length"]) -> Quantity["volume"]: + ... 
return x ** 3 + + >>> jacfwd_square_volume = unxt.experimental.jacfwd(square_volume, units=("m",)) + >>> jacfwd_square_volume(Quantity(2.0, "m")) + Quantity['area'](Array(12., dtype=float32, weak_type=True), unit='m2') """ argnums = eqx.error_if( @@ -75,11 +130,13 @@ def jacfwd( "only int argnums are currently supported", ) + units_ = tuple(map(parse_units, units)) + @partial(jax.jacfwd, argnums=argnums) def jacfun_mag(*args: P.args) -> R: args_ = ( (a if unit is None else Quantity(a, unit)) - for a, unit in zip(args, units, strict=True) + for a, unit in zip(args, units_, strict=True) ) return fun(*args_) @@ -88,34 +145,53 @@ def jacfun(*args: P.args, **kw: P.kwargs) -> R: # inside the function we are taking the Jacobian of. args_ = tuple( (a if unit is None else ustrip(unit, a)) - for a, unit in zip(args, units, strict=True) + for a, unit in zip(args, units_, strict=True) ) # Call the Jacobian, returning a Quantity value = jacfun_mag(*args_) # Adjust the Quantity by the units of the derivative # TODO: check the unit correction # TODO: get Quantity[unit] / unit2 -> Quantity[unit/unit2] working - return type_unparametrized(value)(value.value, value.unit / units[argnums]) + return type_unparametrized(value)(value.value, value.unit / units_[argnums]) return jacfun -def hessian(fun: Callable[P, R], *, units: tuple[Unit, ...]) -> Callable[P, R]: +def hessian(fun: Callable[P, R], *, units: tuple[Unit | str, ...]) -> Callable[P, R]: """Hessian. - In general, if you can use ``quaxed.jacfwd``, that's the better option! The - difference from ``quaxed.jacfwd`` is how this function supports units. - ``quaxed.jacfwd`` does `quax.quaxify(jax.jacfwd)`, which will 'strip' the - units when passing through. But sometimes that doesn't work and we need the - units + In general, if you can use `quax.quaxify(jax.hessian(func))` (or the + syntactic sugar `quaxed.hessian(func)`), that's the better option! The + difference from those functions is how units are supported. 
`quaxify` + will directly propagate the units through the automatic differentiation + functions. But sometimes that doesn't work and we need to strip the units + and re-apply them. This function does that, using the ``units`` kwarg. + + See Also + -------- + jax.hessian : The original JAX hessian function. + + Examples + -------- + >>> import jax.numpy as jnp + >>> import unxt + >>> from unxt import Quantity + + >>> def square_volume(x: Quantity["length"]) -> Quantity["volume"]: + ... return x ** 3 + + >>> hessian_square_volume = unxt.experimental.hessian(square_volume, units=("m",)) + >>> hessian_square_volume(Quantity(2.0, "m")) + Quantity['length'](Array(12., dtype=float32, weak_type=True), unit='m') """ + units_ = tuple(map(parse_units, units)) @partial(jax.hessian) def hessfun_mag(*args: P.args) -> R: args_ = ( (a if unit is None else Quantity(a, unit)) - for a, unit in zip(args, units, strict=True) + for a, unit in zip(args, units_, strict=True) ) return fun(*args_) @@ -131,6 +207,6 @@ def hessfun(*args: P.args, **kw: P.kwargs) -> R: # Adjust the Quantity by the units of the derivative # TODO: check the unit correction # TODO: get Quantity[unit] / unit2 -> Quantity[unit/unit2] working - return type_unparametrized(value)(value.value, value.unit / units[0] ** 2) + return type_unparametrized(value)(value.value, value.unit / units_[0] ** 2) return hessfun