From baa13aff5191a9c76ceeb06a6c160439b31ab899 Mon Sep 17 00:00:00 2001 From: Daniel Weindl Date: Thu, 30 Nov 2023 17:11:13 +0100 Subject: [PATCH] Handle events occuring at fixed timepoints without root-finding A first attempt towards #2185 For events that occur at known timepoints, we don't need sundials' root-finding. We can just stop the solver at the respective timepoints and handle the events. To be extended to parameterized but state-independent trigger functions. --- include/amici/forwardproblem.h | 4 +- include/amici/model.h | 10 +++- include/amici/model_dimensions.h | 20 ++++--- include/amici/model_ode.h | 8 ++- include/amici/serialization.h | 1 + models/model_calvetti/model_calvetti.h | 3 +- models/model_dirac/model_dirac.h | 3 +- models/model_events/model_events.h | 3 +- .../model_jakstat_adjoint.h | 3 +- .../model_jakstat_adjoint_o2.h | 3 +- .../model_nested_events/model_nested_events.h | 3 +- models/model_neuron/model_neuron.h | 3 +- models/model_neuron_o2/model_neuron_o2.h | 3 +- models/model_robertson/model_robertson.h | 3 +- models/model_steadystate/model_steadystate.h | 3 +- python/sdist/amici/de_export.py | 44 +++++++++++++++- python/sdist/amici/de_model.py | 18 +++++++ python/tests/test_events.py | 46 ++++++++++++++++ src/forwardproblem.cpp | 52 ++++++++++++++----- src/model.cpp | 19 ++++++- src/model_header.template.h | 4 +- src/solver.cpp | 2 +- src/solver_cvodes.cpp | 12 ++++- tests/cpp/unittests/testExpData.cpp | 3 +- tests/cpp/unittests/testMisc.cpp | 2 + tests/cpp/unittests/testSerialization.cpp | 2 + 26 files changed, 234 insertions(+), 43 deletions(-) diff --git a/include/amici/forwardproblem.h b/include/amici/forwardproblem.h index 91e5d50791..3d3742b794 100644 --- a/include/amici/forwardproblem.h +++ b/include/amici/forwardproblem.h @@ -197,7 +197,9 @@ class ForwardProblem { SimulationState const& getSimulationStateTimepoint(int it) const { if (model->getTimepoint(it) == initial_state_.t) return getInitialSimulationState(); - return timepoint_states_.find(model->getTimepoint(it))->second; + auto map_iter = timepoint_states_.find(model->getTimepoint(it)); + assert(map_iter != timepoint_states_.end()); + return map_iter->second; }; /** diff --git a/include/amici/model.h b/include/amici/model.h index 72f733e6cf..aaa6b1e47d 100644 --- a/include/amici/model.h +++ b/include/amici/model.h @@ -12,7 +12,6 @@ #include "amici/vector.h" #include -#include #include namespace amici { @@ -117,6 +116,8 @@ class Model : public AbstractModel, public ModelDimensions { * @param ndxdotdp_explicit Number of nonzero elements in `dxdotdp_explicit` * @param ndxdotdx_explicit Number of nonzero elements in `dxdotdx_explicit` * @param w_recursion_depth Recursion depth of fw + * @param state_independent_events Map of events with state-independent + * triggers functions, mapping trigger timepoints to event indices. */ Model( ModelDimensions const& model_dimensions, @@ -124,7 +125,8 @@ class Model : public AbstractModel, public ModelDimensions { amici::SecondOrderMode o2mode, std::vector idlist, std::vector z2event, bool pythonGenerated = false, int ndxdotdp_explicit = 0, int ndxdotdx_explicit = 0, - int w_recursion_depth = 0 + int w_recursion_depth = 0, + std::map> state_independent_events = {} ); /** Destructor. */ @@ -1449,6 +1451,8 @@ class Model : public AbstractModel, public ModelDimensions { */ SUNMatrixWrapper const& get_dxdotdp_full() const; + virtual std::vector get_trigger_timepoints() const; + /** * Flag indicating whether for * `amici::Solver::sensi_` == `amici::SensitivityOrder::second` @@ -1462,6 +1466,8 @@ class Model : public AbstractModel, public ModelDimensions { /** Logger */ Logger* logger = nullptr; + std::map> state_independent_events_ = {}; + protected: /** * @brief Write part of a slice to a buffer according to indices specified diff --git a/include/amici/model_dimensions.h b/include/amici/model_dimensions.h index f0679dbe36..b5aa1ba21e 100644 --- a/include/amici/model_dimensions.h +++ b/include/amici/model_dimensions.h @@ -31,6 +31,7 @@ struct ModelDimensions { * @param nz Number of event observables * @param nztrue Number of event observables of the non-augmented model * @param ne Number of events + * @param ne_solver Number of events that require root-finding * @param nspl Number of splines * @param nJ Number of objective functions * @param nw Number of repeating elements @@ -58,11 +59,12 @@ struct ModelDimensions { int const nx_rdata, int const nxtrue_rdata, int const nx_solver, int const nxtrue_solver, int const nx_solver_reinit, int const np, int const nk, int const ny, int const nytrue, int const nz, - int const nztrue, int const ne, int const nspl, int const nJ, - int const nw, int const ndwdx, int const ndwdp, int const ndwdw, - int const ndxdotdw, std::vector ndJydy, int const ndxrdatadxsolver, - int const ndxrdatadtcl, int const ndtotal_cldx_rdata, int const nnz, - int const ubw, int const lbw + int const nztrue, int const ne, int const ne_solver, int const nspl, + int const nJ, int const nw, int const ndwdx, int const ndwdp, + int const ndwdw, int const ndxdotdw, std::vector ndJydy, + int const ndxrdatadxsolver, int const ndxrdatadtcl, + int const ndtotal_cldx_rdata, int const nnz, int const ubw, + int const lbw ) : nx_rdata(nx_rdata) , nxtrue_rdata(nxtrue_rdata) @@ -76,6 +78,7 @@ struct ModelDimensions { , nz(nz) , nztrue(nztrue) , ne(ne) + , ne_solver(ne_solver) , nspl(nspl) , nw(nw) , ndwdx(ndwdx) @@ -104,6 +107,8 @@ struct ModelDimensions { Expects(nztrue >= 0); Expects(nztrue <= nz); Expects(ne >= 0); + Expects(ne_solver >= 0); + Expects(ne >= ne_solver); Expects(nspl >= 0); Expects(nw >= 0); Expects(ndwdx >= 0); @@ -164,7 +169,10 @@ struct ModelDimensions { /** Number of events */ int ne{0}; - /** numer of spline functions in the model */ + /** Number of events that require root-finding */ + int ne_solver{0}; + + /** Number of spline functions in the model */ int nspl{0}; /** Number of common expressions */ diff --git a/include/amici/model_ode.h b/include/amici/model_ode.h index 6567ee5c95..85ef35ac87 100644 --- a/include/amici/model_ode.h +++ b/include/amici/model_ode.h @@ -39,6 +39,8 @@ class Model_ODE : public Model { * @param ndxdotdp_explicit number of nonzero elements dxdotdp_explicit * @param ndxdotdx_explicit number of nonzero elements dxdotdx_explicit * @param w_recursion_depth Recursion depth of fw + * @param state_independent_events Map of events with state-independent + * triggers functions, mapping trigger timepoints to event indices. */ Model_ODE( ModelDimensions const& model_dimensions, @@ -46,12 +48,14 @@ class Model_ODE : public Model { const SecondOrderMode o2mode, std::vector const& idlist, std::vector const& z2event, bool const pythonGenerated = false, int const ndxdotdp_explicit = 0, int const ndxdotdx_explicit = 0, - int const w_recursion_depth = 0 + int const w_recursion_depth = 0, + std::map> state_independent_events + = {} ) : Model( model_dimensions, simulation_parameters, o2mode, idlist, z2event, pythonGenerated, ndxdotdp_explicit, ndxdotdx_explicit, - w_recursion_depth + w_recursion_depth, state_independent_events ) {} void diff --git a/include/amici/serialization.h b/include/amici/serialization.h index 501b56618a..1f29961bcc 100644 --- a/include/amici/serialization.h +++ b/include/amici/serialization.h @@ -260,6 +260,7 @@ void serialize( ar& m.nz; ar& m.nztrue; ar& m.ne; + ar& m.ne_solver; ar& m.nspl; ar& m.nw; ar& m.ndwdx; diff --git a/models/model_calvetti/model_calvetti.h b/models/model_calvetti/model_calvetti.h index 828be82728..2d0569a890 100644 --- a/models/model_calvetti/model_calvetti.h +++ b/models/model_calvetti/model_calvetti.h @@ -46,6 +46,7 @@ class Model_model_calvetti : public amici::Model_DAE { 0, 4, 0, + 0, 1, 38, 53, @@ -207,6 +208,6 @@ class Model_model_calvetti : public amici::Model_DAE { } // namespace model_model_calvetti -} // namespace amici +} // namespace amici #endif /* _amici_model_calvetti_h */ diff --git a/models/model_dirac/model_dirac.h b/models/model_dirac/model_dirac.h index 7a762479b5..b35fbda575 100644 --- a/models/model_dirac/model_dirac.h +++ b/models/model_dirac/model_dirac.h @@ -46,6 +46,7 @@ class Model_model_dirac : public amici::Model_ODE { 0, 2, 0, + 0, 1, 0, 0, @@ -204,6 +205,6 @@ class Model_model_dirac : public amici::Model_ODE { } // namespace model_model_dirac -} // namespace amici +} // namespace amici #endif /* _amici_model_dirac_h */ diff --git a/models/model_events/model_events.h b/models/model_events/model_events.h index df4bb68ae7..648090d990 100644 --- a/models/model_events/model_events.h +++ b/models/model_events/model_events.h @@ -60,6 +60,7 @@ class Model_model_events : public amici::Model_ODE { 2, 6, 0, + 0, 1, 0, 0, @@ -232,6 +233,6 @@ class Model_model_events : public amici::Model_ODE { } // namespace model_model_events -} // namespace amici +} // namespace amici #endif /* _amici_model_events_h */ diff --git a/models/model_jakstat_adjoint/model_jakstat_adjoint.h b/models/model_jakstat_adjoint/model_jakstat_adjoint.h index fdac2a9f94..6d7601947a 100644 --- a/models/model_jakstat_adjoint/model_jakstat_adjoint.h +++ b/models/model_jakstat_adjoint/model_jakstat_adjoint.h @@ -49,6 +49,7 @@ class Model_model_jakstat_adjoint : public amici::Model_ODE { 0, 0, 0, + 0, 1, 2, 1, @@ -210,6 +211,6 @@ class Model_model_jakstat_adjoint : public amici::Model_ODE { } // namespace model_model_jakstat_adjoint -} // namespace amici +} // namespace amici #endif /* _amici_model_jakstat_adjoint_h */ diff --git a/models/model_jakstat_adjoint_o2/model_jakstat_adjoint_o2.h b/models/model_jakstat_adjoint_o2/model_jakstat_adjoint_o2.h index 22ca276067..bfac0b3267 100644 --- a/models/model_jakstat_adjoint_o2/model_jakstat_adjoint_o2.h +++ b/models/model_jakstat_adjoint_o2/model_jakstat_adjoint_o2.h @@ -49,6 +49,7 @@ class Model_model_jakstat_adjoint_o2 : public amici::Model_ODE { 0, 0, 0, + 0, 18, 10, 2, @@ -210,6 +211,6 @@ class Model_model_jakstat_adjoint_o2 : public amici::Model_ODE { } // namespace model_model_jakstat_adjoint_o2 -} // namespace amici +} // namespace amici #endif /* _amici_model_jakstat_adjoint_o2_h */ diff --git a/models/model_nested_events/model_nested_events.h b/models/model_nested_events/model_nested_events.h index 9ff8f519fe..69f9b00b5f 100644 --- a/models/model_nested_events/model_nested_events.h +++ b/models/model_nested_events/model_nested_events.h @@ -49,6 +49,7 @@ class Model_model_nested_events : public amici::Model_ODE { 0, 4, 0, + 0, 1, 0, 0, @@ -210,6 +211,6 @@ class Model_model_nested_events : public amici::Model_ODE { } // namespace model_model_nested_events -} // namespace amici +} // namespace amici #endif /* _amici_model_nested_events_h */ diff --git a/models/model_neuron/model_neuron.h b/models/model_neuron/model_neuron.h index e8f6f5c21f..7f945997b4 100644 --- a/models/model_neuron/model_neuron.h +++ b/models/model_neuron/model_neuron.h @@ -63,6 +63,7 @@ class Model_model_neuron : public amici::Model_ODE { 1, 1, 0, + 0, 1, 0, 0, @@ -238,6 +239,6 @@ class Model_model_neuron : public amici::Model_ODE { } // namespace model_model_neuron -} // namespace amici +} // namespace amici #endif /* _amici_model_neuron_h */ diff --git a/models/model_neuron_o2/model_neuron_o2.h b/models/model_neuron_o2/model_neuron_o2.h index 23df2b9b33..9c1c4b59ee 100644 --- a/models/model_neuron_o2/model_neuron_o2.h +++ b/models/model_neuron_o2/model_neuron_o2.h @@ -65,6 +65,7 @@ class Model_model_neuron_o2 : public amici::Model_ODE { 1, 1, 0, + 0, 5, 2, 2, @@ -242,6 +243,6 @@ class Model_model_neuron_o2 : public amici::Model_ODE { } // namespace model_model_neuron_o2 -} // namespace amici +} // namespace amici #endif /* _amici_model_neuron_o2_h */ diff --git a/models/model_robertson/model_robertson.h b/models/model_robertson/model_robertson.h index 7f4377d785..816dd2db32 100644 --- a/models/model_robertson/model_robertson.h +++ b/models/model_robertson/model_robertson.h @@ -47,6 +47,7 @@ class Model_model_robertson : public amici::Model_DAE { 0, 0, 0, + 0, 1, 1, 2, @@ -209,6 +210,6 @@ class Model_model_robertson : public amici::Model_DAE { } // namespace model_model_robertson -} // namespace amici +} // namespace amici #endif /* _amici_model_robertson_h */ diff --git a/models/model_steadystate/model_steadystate.h b/models/model_steadystate/model_steadystate.h index b61649f9c8..776b754b08 100644 --- a/models/model_steadystate/model_steadystate.h +++ b/models/model_steadystate/model_steadystate.h @@ -46,6 +46,7 @@ class Model_model_steadystate : public amici::Model_ODE { 0, 0, 0, + 0, 1, 2, 2, @@ -204,6 +205,6 @@ class Model_model_steadystate : public amici::Model_ODE { } // namespace model_model_steadystate -} // namespace amici +} // namespace amici #endif /* _amici_model_steadystate_h */ diff --git a/python/sdist/amici/de_export.py b/python/sdist/amici/de_export.py index b1fa02c421..85271b6a6e 100644 --- a/python/sdist/amici/de_export.py +++ b/python/sdist/amici/de_export.py @@ -1425,13 +1425,24 @@ def num_expr(self) -> int: return len(self.sym("w")) def num_events(self) -> int: + """ + Total number of Events (those for which root-functions are added and those without). + + :return: + number of events + """ + return len(self.sym("h")) + + def num_events_solver(self) -> int: """ Number of Events. :return: number of event symbols (length of the root vector in AMICI) """ - return len(self.sym("h")) + return sum( + not event.triggers_at_fixed_timepoint() for event in self.events() + ) def sym(self, name: str) -> sp.Matrix: """ @@ -1750,6 +1761,16 @@ def parse_events(self) -> None: # add roots of heaviside functions self.add_component(root) + # re-order events - first those that require root tracking, then the others + self._events = list( + chain( + itertools.filterfalse( + Event.triggers_at_fixed_timepoint, self._events + ), + filter(Event.triggers_at_fixed_timepoint, self._events), + ) + ) + def get_appearance_counts(self, idxs: List[int]) -> List[int]: """ Counts how often a state appears in the time derivative of @@ -3642,6 +3663,7 @@ def _write_model_header_cpp(self) -> None: "NZ": self.model.num_eventobs(), "NZTRUE": self.model.num_eventobs(), "NEVENT": self.model.num_events(), + "NEVENT_SOLVER": self.model.num_events_solver(), "NOBJECTIVE": "1", "NSPL": len(self.model.splines), "NW": len(self.model.sym("w")), @@ -3736,6 +3758,7 @@ def _write_model_header_cpp(self) -> None: ) ), "Z2EVENT": ", ".join(map(str, self.model._z2event)), + "STATE_INDEPENDENT_EVENTS": self._get_state_independent_event_intializer(), "ID": ", ".join( ( str(float(isinstance(s, DifferentialState))) @@ -3871,6 +3894,25 @@ def _get_symbol_id_initializer_list(self, name: str) -> str: for idx, symbol in enumerate(self.model.sym(name)) ) + def _get_state_independent_event_intializer(self) -> str: + tmp_map = {} + for event_idx, event in enumerate(self.model.events()): + if not event.triggers_at_fixed_timepoint(): + continue + trigger_time = float(event.get_trigger_time()) + try: + tmp_map[trigger_time].append(event_idx) + except KeyError: + tmp_map[trigger_time] = [event_idx] + + def vector_initializer(v): + return f"{{{', '.join(map(str, v))}}}" + + return ", ".join( + f"{{{trigger_time}, {vector_initializer(event_idxs)}}}" + for trigger_time, event_idxs in tmp_map.items() + ) + def _write_c_make_file(self): """Write CMake ``CMakeLists.txt`` file for this model.""" sources = "\n".join( diff --git a/python/sdist/amici/de_model.py b/python/sdist/amici/de_model.py index 77d9013ad2..92d3fd9536 100644 --- a/python/sdist/amici/de_model.py +++ b/python/sdist/amici/de_model.py @@ -8,6 +8,7 @@ from .import_utils import ( RESERVED_SYMBOLS, ObservableTransformation, + amici_time_symbol, cast_to_sym, generate_measurement_symbol, generate_regularization_symbol, @@ -713,3 +714,20 @@ def __eq__(self, other): return self.get_val() == other.get_val() and ( self.get_initial_value() == other.get_initial_value() ) + + def triggers_at_fixed_timepoint(self) -> bool: + """Check whether the event triggers at a (single) fixed time-point.""" + return (amici_time_symbol - self.get_val()).is_Number + + def get_trigger_time(self) -> sp.Float: + """Get the time at which the event triggers. + + Only for events that trigger at a single fixed time-point. + """ + + if not self.triggers_at_fixed_timepoint(): + raise NotImplementedError( + "This event does not trigger at a fixed timepoint." + ) + + return amici_time_symbol - self.get_val() diff --git a/python/tests/test_events.py b/python/tests/test_events.py index d2a177bded..2e5b337a65 100644 --- a/python/tests/test_events.py +++ b/python/tests/test_events.py @@ -704,3 +704,49 @@ def expm(x): from mpmath import expm return np.array(expm(x).tolist()).astype(float) + + +import amici +from amici.antimony_import import antimony2amici +from amici.gradient_check import check_derivatives +from amici.testing import TemporaryDirectoryWinSafe as TemporaryDirectory + + +def test_handling_of_fixed_time_point_event_triggers(): + """If this example requires changes, please also update documentation/python_interface.rst.""" + import os + + os.environ["ENABLE_AMICI_DEBUGGING"] = "TRUE" + ant_model = """ + model test_events_time_based + event_target = 0 + bolus = 1 + at (time > 1): event_target = 1 + at (time > 2): event_target = event_target + bolus + at (time > 3): event_target = 3 + + end + """ + module_name = "test_events_time_based" + with TemporaryDirectory(prefix=module_name, delete=False) as outdir: + antimony2amici( + ant_model, + model_name=module_name, + output_dir=outdir, + verbose=True, + ) + model_module = amici.import_model_module( + module_name=module_name, module_path=outdir + ) + amici_model = model_module.getModel() + amici_model.setTimepoints(np.linspace(0, 4, 200)) + amici_solver = amici_model.getSolver() + rdata = amici.runAmiciSimulation(amici_model, amici_solver) + assert rdata.status == amici.AMICI_SUCCESS + assert (rdata.x[rdata.ts < 1] == 0).all() + assert (rdata.x[(rdata.ts >= 1) & (rdata.ts < 2)] == 1).all() + assert (rdata.x[(rdata.ts >= 2) & (rdata.ts < 3)] == 2).all() + assert (rdata.x[(rdata.ts >= 3)] == 3).all() + + # TODO sensitivities + check_derivatives(amici_model, amici_solver, edata=None) diff --git a/src/forwardproblem.cpp b/src/forwardproblem.cpp index 13946547ef..828826eeba 100644 --- a/src/forwardproblem.cpp +++ b/src/forwardproblem.cpp @@ -9,6 +9,7 @@ #include #include +#include namespace amici { @@ -110,24 +111,51 @@ void ForwardProblem::workForwardProblem() { /* store initial state and sensitivity*/ initial_state_ = getSimulationState(); + // get list of trigger timepoints for fixed-time triggered events + auto trigger_timepoints = model->get_trigger_timepoints(); + auto it_trigger_timepoints = trigger_timepoints.begin(); + /* loop over timepoints */ for (it_ = 0; it_ < model->nt(); it_++) { - auto nextTimepoint = model->getTimepoint(it_); + // next output time-point + auto next_t_out = model->getTimepoint(it_); - if (std::isinf(nextTimepoint)) + if (std::isinf(next_t_out)) break; - if (nextTimepoint > model->t0()) { - // Solve for nextTimepoint - while (t_ < nextTimepoint) { - int status = solver->run(nextTimepoint); - solver->writeSolution(&t_, x_, dx_, sx_, dx_); + if (next_t_out > model->t0()) { + // Solve for next output timepoint + while (t_ < next_t_out) { + // next stop time is either next output timepoint or next + // time-triggered event + auto next_t_event = it_trigger_timepoints != trigger_timepoints.end() + ? *it_trigger_timepoints + : std::numeric_limits::infinity(); + auto next_t_stop = std::min(next_t_out, next_t_event); + int status = solver->run(next_t_stop); + /* sx will be copied from solver on demand if sensitivities are computed */ + solver->writeSolution(&t_, x_, dx_, sx_, dx_); + if (status == AMICI_ILL_INPUT) { - /* clustering of roots => turn off rootfinding */ + /* clustering of roots => turn off root-finding */ solver->turnOffRootFinding(); - } else if (status == AMICI_ROOT_RETURN) { + } else if (status == AMICI_ROOT_RETURN || t_ == next_t_event) { + // solver-tracked or time-triggered event + solver->getRootInfo(roots_found_.data()); + + // check if we are at a trigger timepoint. + // if so, set the root-found flag + if (t_ == next_t_event) { + std::cout << "..." << std::endl; + for (auto ie : model->state_independent_events_[t_]) { + roots_found_[ie] = 1; + std::cout << "ie: " << ie << std::endl; + } + ++it_trigger_timepoints; + } + handleEvent(&tlastroot_, false, false); } } @@ -156,13 +184,9 @@ void ForwardProblem::handleEvent( /* store Heaviside information at event occurrence */ model->froot(t_, x_, dx_, rootvals_); - /* store timepoint at which the event occurred*/ + /* store timepoint at which the event occurred */ discs_.push_back(t_); - /* extract and store which events occurred */ - if (!seflag && !initial_event) { - solver->getRootInfo(roots_found_.data()); - } root_idx_.push_back(roots_found_); rval_tmp_ = rootvals_; diff --git a/src/model.cpp b/src/model.cpp index c50fcc60bf..218870cafe 100644 --- a/src/model.cpp +++ b/src/model.cpp @@ -178,12 +178,14 @@ Model::Model( SimulationParameters simulation_parameters, SecondOrderMode o2mode, std::vector idlist, std::vector z2event, bool const pythonGenerated, int const ndxdotdp_explicit, - int const ndxdotdx_explicit, int const w_recursion_depth + int const ndxdotdx_explicit, int const w_recursion_depth, + std::map> state_independent_events ) : ModelDimensions(model_dimensions) , pythonGenerated(pythonGenerated) , o2mode(o2mode) , idlist(std::move(idlist)) + , state_independent_events_(std::move(state_independent_events)) , derived_state_(model_dimensions) , z2event_(std::move(z2event)) , state_is_non_negative_(nx_solver, false) @@ -297,6 +299,7 @@ bool operator==(ModelDimensions const& a, ModelDimensions const& b) { && (a.nx_solver_reinit == b.nx_solver_reinit) && (a.np == b.np) && (a.nk == b.nk) && (a.ny == b.ny) && (a.nytrue == b.nytrue) && (a.nz == b.nz) && (a.nztrue == b.nztrue) && (a.ne == b.ne) + && (a.ne_solver == b.ne_solver) && (a.nspl == b.nspl) && (a.nw == b.nw) && (a.ndwdx == b.ndwdx) && (a.ndwdp == b.ndwdp) && (a.ndwdw == b.ndwdw) && (a.ndxdotdw == b.ndxdotdw) && (a.ndJydy == b.ndJydy) && (a.nnz == b.nnz) && (a.nJ == b.nJ) @@ -3071,6 +3074,20 @@ void Model::fstotal_cl( ); } +std::vector Model::get_trigger_timepoints() const { + std::vector trigger_timepoints( + state_independent_events_.size(), 0.0 + ); + // collect keys from state_independent_events_ which are the trigger + // timepoints + auto it = trigger_timepoints.begin(); + for (auto const& kv : state_independent_events_) { + *(it++) = kv.first; + } + std::sort(trigger_timepoints.begin(), trigger_timepoints.end()); + return trigger_timepoints; +} + const_N_Vector Model::computeX_pos(const_N_Vector x) { if (any_state_non_negative_) { for (int ix = 0; ix < derived_state_.x_pos_tmp_.getLength(); ++ix) { diff --git a/src/model_header.template.h b/src/model_header.template.h index af05c8ccc5..932fdeb1a0 100644 --- a/src/model_header.template.h +++ b/src/model_header.template.h @@ -121,6 +121,7 @@ class Model_TPL_MODELNAME : public amici::Model_TPL_MODEL_TYPE_UPPER { TPL_NZ, // nz TPL_NZTRUE, // nztrue TPL_NEVENT, // nevent + TPL_NEVENT_SOLVER, // nevent_solver TPL_NSPL, // nspl TPL_NOBJECTIVE, // nobjective TPL_NW, // nw @@ -146,7 +147,8 @@ class Model_TPL_MODELNAME : public amici::Model_TPL_MODEL_TYPE_UPPER { true, // pythonGenerated TPL_NDXDOTDP_EXPLICIT, // ndxdotdp_explicit TPL_NDXDOTDX_EXPLICIT, // ndxdotdx_explicit - TPL_W_RECURSION_DEPTH // w_recursion_depth + TPL_W_RECURSION_DEPTH, // w_recursion_depth + {TPL_STATE_INDEPENDENT_EVENTS} // state-independent events ) { root_initial_values_ = std::vector( rootInitialValues.begin(), rootInitialValues.end() diff --git a/src/solver.cpp b/src/solver.cpp index c114623050..22e1723640 100644 --- a/src/solver.cpp +++ b/src/solver.cpp @@ -159,7 +159,7 @@ void Solver::setup( /* activates stability limit detection */ setStabLimDet(stldet_); - rootInit(model->ne); + rootInit(model->ne_solver); if (nx() == 0) return; diff --git a/src/solver_cvodes.cpp b/src/solver_cvodes.cpp index 7157302c9e..3efbda3f77 100644 --- a/src/solver_cvodes.cpp +++ b/src/solver_cvodes.cpp @@ -1066,9 +1066,17 @@ static int froot(realtype t, N_Vector x, realtype* root, void* user_data) { auto model = dynamic_cast(typed_udata->first); Expects(model); - model->froot(t, x, gsl::make_span(root, model->ne)); + if (model->ne != model->ne_solver) { + // temporary buffer to store all root function values, not only the ones + // tracked by the solver + static std::vector root_buffer(model->ne, 0.0); + model->froot(t, x, root_buffer); + std::copy_n(root_buffer.begin(), model->ne_solver, root); + } else { + model->froot(t, x, gsl::make_span(root, model->ne_solver)); + } return model->checkFinite( - gsl::make_span(root, model->ne), ModelQuantity::root + gsl::make_span(root, model->ne_solver), ModelQuantity::root ); } diff --git a/tests/cpp/unittests/testExpData.cpp b/tests/cpp/unittests/testExpData.cpp index 416a41227b..d6e1a6fff2 100644 --- a/tests/cpp/unittests/testExpData.cpp +++ b/tests/cpp/unittests/testExpData.cpp @@ -4,8 +4,6 @@ #include #include -#include -#include #include #include @@ -49,6 +47,7 @@ class ExpDataTest : public ::testing::Test { nz, // nz nz, // nztrue nmaxevent, // ne + 0, // ne_solver 0, // nspl 0, // nJ 0, // nw diff --git a/tests/cpp/unittests/testMisc.cpp b/tests/cpp/unittests/testMisc.cpp index 80d2c3bc36..14af3a7a82 100644 --- a/tests/cpp/unittests/testMisc.cpp +++ b/tests/cpp/unittests/testMisc.cpp @@ -65,6 +65,7 @@ class ModelTest : public ::testing::Test { nz, // nz nz, // nztrue nmaxevent, // ne + 0, // ne_solver 0, // nspl 0, // nJ 0, // nw @@ -303,6 +304,7 @@ class SolverTest : public ::testing::Test { nz, // nz nz, // nztrue ne, // ne + 0, // ne_solver 0, // nspl 0, // nJ 0, // nw diff --git a/tests/cpp/unittests/testSerialization.cpp b/tests/cpp/unittests/testSerialization.cpp index a516de0880..f59f04d9c7 100644 --- a/tests/cpp/unittests/testSerialization.cpp +++ b/tests/cpp/unittests/testSerialization.cpp @@ -142,6 +142,7 @@ TEST(ModelSerializationTest, ToFile) nz, // nz nz, // nztrue ne, // ne + 0, // ne_solver 0, // nspl 0, // nJ 9, // nw @@ -207,6 +208,7 @@ TEST(ReturnDataSerializationTest, ToString) nz, // nz nz, // nztrue ne, // ne + 0, // ne_solver 0, // nspl 0, // nJ 9, // nw