diff --git a/src/adam_core/coordinates/transform.py b/src/adam_core/coordinates/transform.py index f531d5c..91dc71e 100644 --- a/src/adam_core/coordinates/transform.py +++ b/src/adam_core/coordinates/transform.py @@ -1,6 +1,7 @@ from __future__ import annotations import logging +from functools import partial from typing import Literal, Optional, Union import jax.numpy as jnp @@ -563,7 +564,7 @@ def cartesian_to_keplerian( return coords_keplerian -@jit +@partial(jit, static_argnames=("max_iter", "tol")) def _keplerian_to_cartesian_p( coords_keplerian: Union[np.ndarray, jnp.ndarray], mu: float, @@ -699,7 +700,7 @@ def _keplerian_to_cartesian_p( ) -@jit +@partial(jit, static_argnames=("max_iter", "tol")) def _keplerian_to_cartesian_a( coords_keplerian: Union[np.ndarray, jnp.ndarray], mu: float, @@ -782,7 +783,7 @@ def _keplerian_to_cartesian_a( ) -@jit +@partial(jit, static_argnames=("max_iter", "tol")) def _keplerian_to_cartesian_q( coords_keplerian: Union[np.ndarray, jnp.ndarray], mu: float, @@ -1039,7 +1040,7 @@ def cartesian_to_cometary( return coords_cometary -@jit +@partial(jit, static_argnames=("max_iter", "tol")) def _cometary_to_cartesian( coords_cometary: Union[np.ndarray, jnp.ndarray], t0: float, diff --git a/src/adam_core/dynamics/aberrations.py b/src/adam_core/dynamics/aberrations.py index d6b271b..6dd619a 100644 --- a/src/adam_core/dynamics/aberrations.py +++ b/src/adam_core/dynamics/aberrations.py @@ -1,3 +1,4 @@ +from functools import partial from typing import Tuple import jax.numpy as jnp @@ -12,7 +13,7 @@ C = c.C -@jit +@partial(jit, static_argnames=("lt_tol", "mu", "tol", "max_iter")) def _add_light_time( orbit: jnp.ndarray, t0: float, @@ -108,7 +109,7 @@ def _while_condition(p): ) -@jit +@partial(jit, static_argnames=("lt_tol", "mu", "tol", "max_iter")) def add_light_time( orbits: jnp.ndarray, t0: jnp.ndarray, diff --git a/src/adam_core/dynamics/chi.py b/src/adam_core/dynamics/chi.py index 90788c0..ed75fb2 100644 --- a/src/adam_core/dynamics/chi.py +++ b/src/adam_core/dynamics/chi.py @@ -1,3 +1,4 @@ +from functools import partial from typing import Tuple import jax.numpy as jnp @@ -12,7 +13,7 @@ MU = c.MU -@jit +@partial(jit, static_argnames=("mu", "max_iter", "tol")) def calc_chi( r: jnp.ndarray, v: jnp.ndarray, diff --git a/src/adam_core/dynamics/ephemeris.py b/src/adam_core/dynamics/ephemeris.py index 710e35f..2dd6bd0 100644 --- a/src/adam_core/dynamics/ephemeris.py +++ b/src/adam_core/dynamics/ephemeris.py @@ -1,3 +1,4 @@ +from functools import partial from typing import Tuple import jax.numpy as jnp @@ -18,7 +19,7 @@ from .aberrations import _add_light_time, add_stellar_aberration -@jit +@partial(jit, static_argnames=("lt_tol", "max_iter", "tol", "stellar_aberration")) def _generate_ephemeris_2body( propagated_orbit: np.ndarray, observation_time: float, diff --git a/src/adam_core/dynamics/kepler.py b/src/adam_core/dynamics/kepler.py index 47f0cba..f533aed 100644 --- a/src/adam_core/dynamics/kepler.py +++ b/src/adam_core/dynamics/kepler.py @@ -1,3 +1,4 @@ +from functools import partial from typing import Tuple import jax.numpy as jnp @@ -194,7 +195,7 @@ def _calc_parabolic_anomalies(nu: float, e: float) -> Tuple[float, float]: return D, M -@jit +@partial(jit, static_argnames=("max_iter", "tol")) def solve_kepler(e: float, M: float, max_iter: int = 100, tol: float = 1e-15) -> float: """ Solve Kepler's equation for true anomaly (nu) given eccentricity diff --git a/src/adam_core/dynamics/lagrange.py b/src/adam_core/dynamics/lagrange.py index 0114a4b..8a281da 100644 --- a/src/adam_core/dynamics/lagrange.py +++ b/src/adam_core/dynamics/lagrange.py @@ -1,3 +1,4 @@ +from functools import partial from typing import Tuple import jax.numpy as jnp @@ -14,7 +15,7 @@ LAGRANGE_TYPES = Tuple[jnp.float64, jnp.float64, jnp.float64, jnp.float64] -@jit +@partial(jit, static_argnames=("mu", "max_iter", "tol")) def calc_lagrange_coefficients( r: jnp.ndarray, v: jnp.ndarray, diff --git a/src/adam_core/dynamics/propagation.py b/src/adam_core/dynamics/propagation.py index af665b2..61d70a1 100644 --- a/src/adam_core/dynamics/propagation.py +++ b/src/adam_core/dynamics/propagation.py @@ -1,3 +1,5 @@ +from functools import partial + import jax.numpy as jnp import numpy as np from jax import config, jit, vmap @@ -15,7 +17,7 @@ config.update("jax_enable_x64", True) -@jit +@partial(jit, static_argnames=("max_iter", "tol")) def _propagate_2body( orbit: jnp.ndarray, t0: float,