Skip to content

Commit

Permalink
docs: experimental
Browse files Browse the repository at this point in the history
Signed-off-by: nstarman <[email protected]>
  • Loading branch information
nstarman committed Oct 21, 2024
1 parent 2b0c01c commit a394a4d
Showing 1 changed file with 102 additions and 26 deletions.
128 changes: 102 additions & 26 deletions src/unxt/_src/experimental.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,20 @@
# pylint: disable=import-error
r"""Experimental features.
"""unxt: Quantities in JAX.
unxt.experimental provides experimental features that are not yet ready for
general use. These features may be removed or changed in the future without
notice.
THIS MODULE IS NOT GUARANTEED TO HAVE A STABLE API!
On some occasions JAX's automatic differentiation functions do not work well
with quantities. This is easily checked by enabling runtime type-checking (see
the docs), which will raise an error if a quantity's units do not match the
expected input / output units of a function. In these cases, you can use the
functions in this module to provide the units to the automatic differentiation
functions. Instead of directly propagating the units through the automatic
differentiation functions, the units are stripped and re-applied, while also
being provided within the function being AD'd.
Copyright (c) 2023 Galactic Dynamics. All rights reserved.
"""
# pylint: disable=import-error

__all__ = ["grad", "hessian", "jacfwd"]

Expand All @@ -21,22 +30,50 @@
from .quantity.core import Quantity
from .typing_ext import Unit
from unxt._src.quantity.api import ustrip
from unxt._src.units.core import units as parse_units

P = ParamSpec("P")
R = TypeVar("R", bound=Quantity)


def grad(
fun: Callable[P, R], argnums: int = 0, *, units: tuple[Unit, ...]
fun: Callable[P, R], argnums: int = 0, *, units: tuple[Unit | str, ...]
) -> Callable[P, R]:
# Gradient of a function with units
"""Gradient of a function with units.
In general, if you can use `quax.quaxify(jax.grad(func))` (or the syntactic
sugar `quax.grad(func)`), that's the better option! The difference from
those functions is how this units are supported. `quaxify` will directly
propagate the units through the automatic differentiation functions. But
sometimes that doesn't work and we need to strip the units and re-apply
them. This function does that, using the ``units`` kwarg.
See Also
--------
jax.grad : The original JAX gradient function.
Examples
--------
>>> import jax.numpy as jnp
>>> import unxt
>>> from unxt import Quantity
>>> def square_volume(x: Quantity["length"]) -> Quantity["volume"]:
... return x ** 3
>>> grad_square_volume = unxt.experimental.grad(square_volume, units=("m",))
>>> grad_square_volume(Quantity(2.0, "m"))
Quantity['area'](Array(12., dtype=float32, weak_type=True), unit='m2')
"""
units_ = tuple(map(parse_units, units))

# Gradient of function, stripping and adding units
@partial(jax.grad, argnums=argnums)
def gradfun_mag(*args: P.args) -> ArrayLike:
args_ = (
(a if unit is None else Quantity(a, unit))
for a, unit in zip(args, units, strict=True)
for a, unit in zip(args, units_, strict=True)
)
return fun(*args_).value

Expand All @@ -45,28 +82,46 @@ def gradfun(*args: P.args, **kw: P.kwargs) -> R:
# inside the function we are taking the grad of.
args_ = tuple(
(a if unit is None else ustrip(unit, a))
for a, unit in zip(args, units, strict=True)
for a, unit in zip(args, units_, strict=True)
)
# Call the grad, returning a Quantity
value = fun(*args)
grad_value = gradfun_mag(*args_)
# Adjust the Quantity by the units of the derivative
# TODO: get Quantity[unit] / unit2 -> Quantity[unit/unit2] working
return type_unparametrized(value)(grad_value, value.unit / units[argnums])
return type_unparametrized(value)(grad_value, value.unit / units_[argnums])

return gradfun


def jacfwd(
fun: Callable[P, R], argnums: int = 0, *, units: tuple[Unit, ...]
fun: Callable[P, R], argnums: int = 0, *, units: tuple[Unit | str, ...]
) -> Callable[P, R]:
"""Jacobian of ``fun`` evaluated column-by-column using forward-mode AD.
In general, if you can use ``quaxed.jacfwd``, that's the better option! The
difference from ``quaxed.jacfwd`` is how this function supports units.
``quaxed.jacfwd`` does `quax.quaxify(jax.jacfwd)`, which will 'strip' the
units when passing through. But sometimes that doesn't work and we need the
units
In general, if you can use `quax.quaxify(jax.jacfwd(func))` (or the
syntactic sugar `quax.jacfwd(func)`), that's the better option! The
difference from those functions is how this units are supported. `quaxify`
will directly propagate the units through the automatic differentiation
functions. But sometimes that doesn't work and we need to strip the units
and re-apply them. This function does that, using the ``units`` kwarg.
See Also
--------
jax.jacfwd : The original JAX jacfwd function.
Examples
--------
>>> import jax.numpy as jnp
>>> import unxt
>>> from unxt import Quantity
>>> def square_volume(x: Quantity["length"]) -> Quantity["volume"]:
... return x ** 3
>>> jacfwd_square_volume = unxt.experimental.jacfwd(square_volume, units=("m",))
>>> jacfwd_square_volume(Quantity(2.0, "m"))
Quantity['area'](Array(12., dtype=float32, weak_type=True), unit='m2')
"""
argnums = eqx.error_if(
Expand All @@ -75,11 +130,13 @@ def jacfwd(
"only int argnums are currently supported",
)

units_ = tuple(map(parse_units, units))

@partial(jax.jacfwd, argnums=argnums)
def jacfun_mag(*args: P.args) -> R:
args_ = (
(a if unit is None else Quantity(a, unit))
for a, unit in zip(args, units, strict=True)
for a, unit in zip(args, units_, strict=True)
)
return fun(*args_)

Expand All @@ -88,34 +145,53 @@ def jacfun(*args: P.args, **kw: P.kwargs) -> R:
# inside the function we are taking the Jacobian of.
args_ = tuple(
(a if unit is None else ustrip(unit, a))
for a, unit in zip(args, units, strict=True)
for a, unit in zip(args, units_, strict=True)
)
# Call the Jacobian, returning a Quantity
value = jacfun_mag(*args_)
# Adjust the Quantity by the units of the derivative
# TODO: check the unit correction
# TODO: get Quantity[unit] / unit2 -> Quantity[unit/unit2] working
return type_unparametrized(value)(value.value, value.unit / units[argnums])
return type_unparametrized(value)(value.value, value.unit / units_[argnums])

return jacfun


def hessian(fun: Callable[P, R], *, units: tuple[Unit, ...]) -> Callable[P, R]:
def hessian(fun: Callable[P, R], *, units: tuple[Unit | str, ...]) -> Callable[P, R]:
"""Hessian.
In general, if you can use ``quaxed.jacfwd``, that's the better option! The
difference from ``quaxed.jacfwd`` is how this function supports units.
``quaxed.jacfwd`` does `quax.quaxify(jax.jacfwd)`, which will 'strip' the
units when passing through. But sometimes that doesn't work and we need the
units
In general, if you can use `quax.quaxify(jax.hessian(func))` (or the
syntactic sugar `quax.hessian(func)`), that's the better option! The
difference from those functions is how this units are supported. `quaxify`
will directly propagate the units through the automatic differentiation
functions. But sometimes that doesn't work and we need to strip the units
and re-apply them. This function does that, using the ``units`` kwarg.
See Also
--------
jax.hessian : The original JAX hessian function.
Examples
--------
>>> import jax.numpy as jnp
>>> import unxt
>>> from unxt import Quantity
>>> def square_volume(x: Quantity["length"]) -> Quantity["volume"]:
... return x ** 3
>>> hessian_square_volume = unxt.experimental.hessian(square_volume, units=("m",))
>>> hessian_square_volume(Quantity(2.0, "m"))
Quantity['length'](Array(12., dtype=float32, weak_type=True), unit='m')
"""
units_ = tuple(map(parse_units, units))

@partial(jax.hessian)
def hessfun_mag(*args: P.args) -> R:
args_ = (
(a if unit is None else Quantity(a, unit))
for a, unit in zip(args, units, strict=True)
for a, unit in zip(args, units_, strict=True)
)
return fun(*args_)

Expand All @@ -131,6 +207,6 @@ def hessfun(*args: P.args, **kw: P.kwargs) -> R:
# Adjust the Quantity by the units of the derivative
# TODO: check the unit correction
# TODO: get Quantity[unit] / unit2 -> Quantity[unit/unit2] working
return type_unparametrized(value)(value.value, value.unit / units[0] ** 2)
return type_unparametrized(value)(value.value, value.unit / units_[0] ** 2)

return hessfun

0 comments on commit a394a4d

Please sign in to comment.