Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Set static_argnames for relevant jax-jitted functions #118

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions src/adam_core/coordinates/transform.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import logging
from functools import partial
from typing import Literal, Optional, Union

import jax.numpy as jnp
Expand Down Expand Up @@ -563,7 +564,7 @@ def cartesian_to_keplerian(
return coords_keplerian


@jit
@partial(jit, static_argnames=("max_iter", "tol"))
def _keplerian_to_cartesian_p(
coords_keplerian: Union[np.ndarray, jnp.ndarray],
mu: float,
Expand Down Expand Up @@ -699,7 +700,7 @@ def _keplerian_to_cartesian_p(
)


@jit
@partial(jit, static_argnames=("max_iter", "tol"))
def _keplerian_to_cartesian_a(
coords_keplerian: Union[np.ndarray, jnp.ndarray],
mu: float,
Expand Down Expand Up @@ -782,7 +783,7 @@ def _keplerian_to_cartesian_a(
)


@jit
@partial(jit, static_argnames=("max_iter", "tol"))
def _keplerian_to_cartesian_q(
coords_keplerian: Union[np.ndarray, jnp.ndarray],
mu: float,
Expand Down Expand Up @@ -1039,7 +1040,7 @@ def cartesian_to_cometary(
return coords_cometary


@jit
@partial(jit, static_argnames=("max_iter", "tol"))
def _cometary_to_cartesian(
coords_cometary: Union[np.ndarray, jnp.ndarray],
t0: float,
Expand Down
5 changes: 3 additions & 2 deletions src/adam_core/dynamics/aberrations.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from functools import partial
from typing import Tuple

import jax.numpy as jnp
Expand All @@ -12,7 +13,7 @@
C = c.C


@jit
@partial(jit, static_argnames=("lt_tol", "mu", "tol", "max_iter"))
def _add_light_time(
orbit: jnp.ndarray,
t0: float,
Expand Down Expand Up @@ -108,7 +109,7 @@ def _while_condition(p):
)


@jit
@partial(jit, static_argnames=("lt_tol", "mu", "tol", "max_iter"))
def add_light_time(
orbits: jnp.ndarray,
t0: jnp.ndarray,
Expand Down
3 changes: 2 additions & 1 deletion src/adam_core/dynamics/chi.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from functools import partial
from typing import Tuple

import jax.numpy as jnp
Expand All @@ -12,7 +13,7 @@
MU = c.MU


@jit
@partial(jit, static_argnames=("mu", "max_iter", "tol"))
def calc_chi(
r: jnp.ndarray,
v: jnp.ndarray,
Expand Down
3 changes: 2 additions & 1 deletion src/adam_core/dynamics/ephemeris.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from functools import partial
from typing import Tuple

import jax.numpy as jnp
Expand All @@ -18,7 +19,7 @@
from .aberrations import _add_light_time, add_stellar_aberration


@jit
@partial(jit, static_argnames=("lt_tol", "max_iter", "tol", "stellar_aberration"))
def _generate_ephemeris_2body(
propagated_orbit: np.ndarray,
observation_time: float,
Expand Down
3 changes: 2 additions & 1 deletion src/adam_core/dynamics/kepler.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from functools import partial
from typing import Tuple

import jax.numpy as jnp
Expand Down Expand Up @@ -194,7 +195,7 @@ def _calc_parabolic_anomalies(nu: float, e: float) -> Tuple[float, float]:
return D, M


@jit
@partial(jit, static_argnames=("max_iter", "tol"))
def solve_kepler(e: float, M: float, max_iter: int = 100, tol: float = 1e-15) -> float:
"""
Solve Kepler's equation for true anomaly (nu) given eccentricity
Expand Down
3 changes: 2 additions & 1 deletion src/adam_core/dynamics/lagrange.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from functools import partial
from typing import Tuple

import jax.numpy as jnp
Expand All @@ -14,7 +15,7 @@
LAGRANGE_TYPES = Tuple[jnp.float64, jnp.float64, jnp.float64, jnp.float64]


@jit
@partial(jit, static_argnames=("mu", "max_iter", "tol"))
def calc_lagrange_coefficients(
r: jnp.ndarray,
v: jnp.ndarray,
Expand Down
4 changes: 3 additions & 1 deletion src/adam_core/dynamics/propagation.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from functools import partial

import jax.numpy as jnp
import numpy as np
from jax import config, jit, vmap
Expand All @@ -15,7 +17,7 @@
config.update("jax_enable_x64", True)


@jit
@partial(jit, static_argnames=("max_iter", "tol"))
def _propagate_2body(
orbit: jnp.ndarray,
t0: float,
Expand Down
Loading