Skip to content

Commit

Permalink
simply tests + add support for non-dynamic simulation in jax
Browse files Browse the repository at this point in the history
  • Loading branch information
FFroehlich committed Oct 22, 2024
1 parent 907acb7 commit 82a01ba
Show file tree
Hide file tree
Showing 5 changed files with 62 additions and 57 deletions.
4 changes: 1 addition & 3 deletions .github/workflows/test_benchmark_collection_models.yml
Original file line number Diff line number Diff line change
Expand Up @@ -59,9 +59,7 @@ jobs:
# retrieve test models
- name: Download and test benchmark collection
run: |
git clone --depth 1 https://github.com/benchmarking-initiative/Benchmark-Models-PEtab.git \
&& export BENCHMARK_COLLECTION="$(pwd)/Benchmark-Models-PEtab/Benchmark-Models/" \
&& pip3 install -e $BENCHMARK_COLLECTION/../src/python \
pip install git+https://github.com/Benchmarking-Initiative/Benchmark-Models-PEtab.git@master#subdirectory=src/python \
&& AMICI_PARALLEL_COMPILE="" tests/benchmark-models/test_benchmark_collection.sh
# run gradient checks
Expand Down
53 changes: 29 additions & 24 deletions python/sdist/amici/jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,48 +49,39 @@ def __init__(self):

@staticmethod
@abstractmethod
def xdot(t, x, args):
...
def xdot(t, x, args): ...

@staticmethod
@abstractmethod
def _w(t, x, p, k, tcl):
...
def _w(t, x, p, k, tcl): ...

@staticmethod
@abstractmethod
def x0(p, k):
...
def x0(p, k): ...

@staticmethod
@abstractmethod
def x_solver(x):
...
def x_solver(x): ...

@staticmethod
@abstractmethod
def x_rdata(x, tcl):
...
def x_rdata(x, tcl): ...

@staticmethod
@abstractmethod
def tcl(x, p, k):
...
def tcl(x, p, k): ...

@staticmethod
@abstractmethod
def y(t, x, p, k, tcl):
...
def y(t, x, p, k, tcl): ...

@staticmethod
@abstractmethod
def sigmay(y, p, k):
...
def sigmay(y, p, k): ...

@staticmethod
@abstractmethod
def Jy(y, my, sigmay):
...
def Jy(y, my, sigmay): ...

def unscale_p(self, p, pscale):
return jax.vmap(
Expand Down Expand Up @@ -136,6 +127,7 @@ def _solve(self, ts, p, k, x0, checkpointed):
saveat=diffrax.SaveAt(ts=ts),
throw=False,
)

return sol.ys, tcl, sol.stats

def _obs(self, ts, x, p, k, tcl):
Expand All @@ -162,21 +154,30 @@ def _run(
my: jnp.ndarray,
pscale: np.ndarray,
checkpointed=True,
dynamic=True,
):
ps = self.unscale_p(p, pscale)
if k_preeq.shape[0] > 0:
x0 = self._preeq(ps, k_preeq)

Check warning on line 161 in python/sdist/amici/jax.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/jax.py#L161

Added line #L161 was not covered by tests
else:
x0 = self.x0(ps, k)
x, tcl, stats = self._solve(ts, ps, k, x0, checkpointed=checkpointed)

if dynamic:
x, tcl, stats = self._solve(
ts, ps, k, x0, checkpointed=checkpointed
)
else:
x = tuple(jnp.array([x0_i] * len(ts)) for x0_i in x0)
tcl = self.tcl(x0, ps, k)
stats = None

Check warning on line 172 in python/sdist/amici/jax.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/jax.py#L170-L172

Added lines #L170 - L172 were not covered by tests
obs = self._obs(ts, x, ps, k, tcl)
my_r = my.reshape((len(ts), -1))
sigmay = self._sigmay(obs, ps, k)
llh = self._loss(obs, sigmay, my_r)
x_rdata = self._x_rdata(x, tcl)
return llh, (x_rdata, obs, stats)

@eqx.filter_jit
# @eqx.filter_jit
def run(
self,
ts: np.ndarray,
Expand All @@ -185,8 +186,9 @@ def run(
k_preeq: np.ndarray,
my: np.ndarray,
pscale: np.ndarray,
dynamic=True,
):
return self._run(ts, p, k, k_preeq, my, pscale)
return self._run(ts, p, k, k_preeq, my, pscale, dynamic=dynamic)

@eqx.filter_jit
def srun(
Expand All @@ -197,6 +199,7 @@ def srun(
k_preeq: np.ndarray,
my: np.ndarray,
pscale: np.ndarray,
dynamic=True,
):
(llh, (x, obs, stats)), sllh = (
jax.value_and_grad(self._run, 1, True)
Expand All @@ -212,6 +215,7 @@ def s2run(
k_preeq: np.ndarray,
my: np.ndarray,
pscale: np.ndarray,
dynamic=True,
):
(llh, (x, obs, stats)), sllh = (
jax.value_and_grad(self._run, 1, True)
Expand All @@ -232,27 +236,28 @@ def run_simulation(
k_preeq = np.asarray(edata.fixedParametersPreequilibration)
my = np.asarray(edata.getObservedData())
pscale = np.asarray(edata.pscale)
dynamic = np.max(ts) > 0

rdata_kwargs = dict()

if sensitivity_order == amici.SensitivityOrder.none:
(
rdata_kwargs["llh"],
(rdata_kwargs["x"], rdata_kwargs["y"], rdata_kwargs["stats"]),
) = self.run(ts, p, k, k_preeq, my, pscale)
) = self.run(ts, p, k, k_preeq, my, pscale, dynamic)
elif sensitivity_order == amici.SensitivityOrder.first:
(
rdata_kwargs["llh"],
rdata_kwargs["sllh"],
(rdata_kwargs["x"], rdata_kwargs["y"], rdata_kwargs["stats"]),
) = self.srun(ts, p, k, k_preeq, my, pscale)
) = self.srun(ts, p, k, k_preeq, my, pscale, dynamic)
elif sensitivity_order == amici.SensitivityOrder.second:
(
rdata_kwargs["llh"],
rdata_kwargs["sllh"],
rdata_kwargs["s2llh"],
(rdata_kwargs["x"], rdata_kwargs["y"], rdata_kwargs["stats"]),
) = self.s2run(ts, p, k, k_preeq, my, pscale)
) = self.s2run(ts, p, k, k_preeq, my, pscale, dynamic)

for field in rdata_kwargs.keys():
if field == "llh":
Expand Down
16 changes: 15 additions & 1 deletion python/sdist/amici/petab/petab_import.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,9 @@ def import_petab_problem(
model_name: str = None,
compile_: bool = None,
non_estimated_parameters_as_constants=True,
jax=False,
**kwargs,
) -> "amici.Model":
) -> "amici.Model | amici.JAXModel":
"""
Create an AMICI model for a PEtab problem.
Expand All @@ -64,6 +65,9 @@ def import_petab_problem(
model size and simulation times. If sensitivities with respect to those
parameters are required, this should be set to ``False``.
:param jax:
Whether to load the jax version of the model.
:param kwargs:
Additional keyword arguments to be passed to
:meth:`amici.sbml_import.SbmlImporter.sbml2amici` or
Expand Down Expand Up @@ -154,6 +158,16 @@ def import_petab_problem(

# import model
model_module = amici.import_model_module(model_name, model_output_dir)

if jax:
model = model_module.get_jax_model()

logger.info(
f"Successfully loaded jax model {model_name} "
f"from {model_output_dir}."
)
return model

model = model_module.getModel()
check_model(amici_model=model, petab_problem=petab_problem)

Expand Down
12 changes: 2 additions & 10 deletions tests/benchmark-models/test_benchmark_collection.sh
Original file line number Diff line number Diff line change
Expand Up @@ -86,17 +86,9 @@ script_path=$(dirname "$BASH_SOURCE")
script_path=$(cd "$script_path" && pwd)

for model in $models; do
yaml="${model_dir}"/"${model}"/"${model}".yaml

# different naming scheme
if [[ "$model" == "Bertozzi_PNAS2020" ]]; then
yaml="${model_dir}"/"${model}"/problem.yaml
fi

amici_model_dir=test_bmc/"${model}"
amici_model_dir=test_bmc
mkdir -p "$amici_model_dir"
cmd_import="amici_import_petab ${yaml} -o ${amici_model_dir} -n ${model} --flatten"
cmd_run="$script_path/test_petab_model.py -y ${yaml} -d ${amici_model_dir} -m ${model} -c"
cmd_run="$script_path/test_petab_model.py -d ${amici_model_dir} -m ${model} -c"

printf '=%.0s' {1..40}
printf " %s " "${model}"
Expand Down
34 changes: 15 additions & 19 deletions tests/benchmark-models/test_petab_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@

import argparse
import contextlib
import importlib
import logging
import os
import sys
Expand All @@ -29,6 +28,7 @@
)
from timeit import default_timer as timer
from petab.v1.visualize import plot_problem
import benchmark_models_petab

logger = get_logger(f"amici.{__name__}", logging.WARNING)

Expand Down Expand Up @@ -67,15 +67,6 @@ def parse_cli_args():
help="Plot measurement and simulation results",
)

# PEtab problem
parser.add_argument(
"-y",
"--yaml",
dest="yaml_file_name",
required=True,
help="PEtab YAML problem filename",
)

# Corresponding AMICI model
parser.add_argument(
"-m",
Expand All @@ -88,7 +79,7 @@ def parse_cli_args():
"-d",
"--model-dir",
dest="model_directory",
help="Directory containing the AMICI module of the "
help="Parent directory containing the AMICI module of the "
"model to simulate. Required if model is not "
"in python path.",
)
Expand All @@ -113,19 +104,20 @@ def main():

logger.info(
f"Simulating '{args.model_name}' "
f"({args.model_directory}) using PEtab data from "
f"{args.yaml_file_name}"
f"({args.model_directory}) with AMICI"
)

# load PEtab files
problem = petab.Problem.from_yaml(args.yaml_file_name)
problem = benchmark_models_petab.get_problem(args.model_name)
petab.flatten_timepoint_specific_output_overrides(problem)

# load model
if args.model_directory:
sys.path.insert(0, args.model_directory)
model_module = importlib.import_module(args.model_name)
amici_model = model_module.getModel()
from amici.petab.petab_import import import_petab_problem

amici_model = import_petab_problem(
problem,
model_output_dir=Path(args.model_directory) / args.model_name,
)
amici_solver = amici_model.getSolver()

amici_solver.setAbsoluteTolerance(1e-8)
Expand All @@ -145,7 +137,11 @@ def main():
rdatas = res[RDATAS]
llh = res[LLH]

jax_model = model_module.get_jax_model()
jax_model = import_petab_problem(
problem,
model_output_dir=Path(args.model_directory) / args.model_name,
jax=True,
)
simulation_conditions = (
problem.get_simulation_conditions_from_measurement_df()
)
Expand Down

0 comments on commit 82a01ba

Please sign in to comment.