Skip to content

Commit

Permalink
Handle events occuring at fixed timepoints without root-finding (#2227)
Browse files Browse the repository at this point in the history
A first attempt towards #2185

For events that occur at known timepoints, we don't need sundials' root-finding. We can just stop the solver at the respective timepoints and handle the events.

Here, events are sorted such that the `ne_solver` events  that require root-finding by the solver come first and the other `ne - ne_solver` events come after that. The solver only tracks  `ne_solver` roots. 

To be extended to parameterized but state-independent trigger functions at some point.
  • Loading branch information
dweindl authored Dec 11, 2023
1 parent ad8eeb9 commit ecbae3f
Show file tree
Hide file tree
Showing 32 changed files with 363 additions and 71 deletions.
22 changes: 14 additions & 8 deletions include/amici/exception.h
Original file line number Diff line number Diff line change
Expand Up @@ -62,29 +62,35 @@ class AmiException : public std::exception {
};

/**
* @brief cvode exception handler class
* @brief CVODE exception handler class
*/
class CvodeException : public AmiException {
public:
/**
* @brief Constructor
* @param error_code error code returned by cvode function
* @param function cvode function name
* @param error_code error code returned by CVODE function
* @param function CVODE function name
* @param extra Extra text to append to error message
*/
CvodeException(int error_code, char const* function);
CvodeException(
int error_code, char const* function, char const* extra = nullptr
);
};

/**
* @brief ida exception handler class
* @brief IDA exception handler class
*/
class IDAException : public AmiException {
public:
/**
* @brief Constructor
* @param error_code error code returned by ida function
* @param function ida function name
* @param error_code error code returned by IDA function
* @param function IDA function name
* @param extra Extra text to append to error message
*/
IDAException(int error_code, char const* function);
IDAException(
int error_code, char const* function, char const* extra = nullptr
);
};

/**
Expand Down
12 changes: 7 additions & 5 deletions include/amici/forwardproblem.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
#include "amici/vector.h"
#include <amici/amici.h>

#include <memory>
#include <sundials/sundials_direct.h>
#include <vector>

Expand Down Expand Up @@ -197,7 +196,9 @@ class ForwardProblem {
SimulationState const& getSimulationStateTimepoint(int it) const {
if (model->getTimepoint(it) == initial_state_.t)
return getInitialSimulationState();
return timepoint_states_.find(model->getTimepoint(it))->second;
auto map_iter = timepoint_states_.find(model->getTimepoint(it));
assert(map_iter != timepoint_states_.end());
return map_iter->second;
};

/**
Expand Down Expand Up @@ -273,9 +274,9 @@ class ForwardProblem {
/**
* @brief Execute everything necessary for the handling of data points
*
* @param it index of data point
* @param t measurement timepoint
*/
void handleDataPoint(int it);
void handleDataPoint(realtype t);

/**
* @brief Applies the event bolus to the current state
Expand Down Expand Up @@ -368,7 +369,8 @@ class ForwardProblem {
* @brief Array of flags indicating which root has been found.
*
* Array of length nr (ne) with the indices of the user functions gi found
* to have a root. For i = 0, . . . ,nr 1 if gi has a root, and = 0 if not.
* to have a root. For i = 0, . . . ,nr 1 or -1 if gi has a root, and = 0
* if not. See CVodeGetRootInfo for details.
*/
std::vector<int> roots_found_;

Expand Down
21 changes: 19 additions & 2 deletions include/amici/model.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
#include "amici/vector.h"

#include <map>
#include <memory>
#include <vector>

namespace amici {
Expand Down Expand Up @@ -117,14 +116,17 @@ class Model : public AbstractModel, public ModelDimensions {
* @param ndxdotdp_explicit Number of nonzero elements in `dxdotdp_explicit`
* @param ndxdotdx_explicit Number of nonzero elements in `dxdotdx_explicit`
* @param w_recursion_depth Recursion depth of fw
* @param state_independent_events Map of events with state-independent
* triggers functions, mapping trigger timepoints to event indices.
*/
Model(
ModelDimensions const& model_dimensions,
SimulationParameters simulation_parameters,
amici::SecondOrderMode o2mode, std::vector<amici::realtype> idlist,
std::vector<int> z2event, bool pythonGenerated = false,
int ndxdotdp_explicit = 0, int ndxdotdx_explicit = 0,
int w_recursion_depth = 0
int w_recursion_depth = 0,
std::map<realtype, std::vector<int>> state_independent_events = {}
);

/** Destructor. */
Expand Down Expand Up @@ -1449,6 +1451,15 @@ class Model : public AbstractModel, public ModelDimensions {
*/
SUNMatrixWrapper const& get_dxdotdp_full() const;

/**
* @brief Get trigger times for events that don't require root-finding.
*
* @return List of unique trigger points for events that don't require
* root-finding (i.e. that trigger at predetermined timepoints),
* in ascending order.
*/
virtual std::vector<double> get_trigger_timepoints() const;

/**
* Flag indicating whether for
* `amici::Solver::sensi_` == `amici::SensitivityOrder::second`
Expand All @@ -1462,6 +1473,12 @@ class Model : public AbstractModel, public ModelDimensions {
/** Logger */
Logger* logger = nullptr;

/**
* @brief Map of trigger timepoints to event indices for events that don't
* require root-finding.
*/
std::map<realtype, std::vector<int>> state_independent_events_ = {};

protected:
/**
* @brief Write part of a slice to a buffer according to indices specified
Expand Down
7 changes: 5 additions & 2 deletions include/amici/model_dae.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,19 +40,22 @@ class Model_DAE : public Model {
* @param ndxdotdp_explicit number of nonzero elements dxdotdp_explicit
* @param ndxdotdx_explicit number of nonzero elements dxdotdx_explicit
* @param w_recursion_depth Recursion depth of fw
* @param state_independent_events Map of events with state-independent
* triggers functions, mapping trigger timepoints to event indices.
*/
Model_DAE(
ModelDimensions const& model_dimensions,
SimulationParameters simulation_parameters,
const SecondOrderMode o2mode, std::vector<realtype> const& idlist,
std::vector<int> const& z2event, bool const pythonGenerated = false,
int const ndxdotdp_explicit = 0, int const ndxdotdx_explicit = 0,
int const w_recursion_depth = 0
int const w_recursion_depth = 0,
std::map<realtype, std::vector<int>> state_independent_events = {}
)
: Model(
model_dimensions, simulation_parameters, o2mode, idlist, z2event,
pythonGenerated, ndxdotdp_explicit, ndxdotdx_explicit,
w_recursion_depth
w_recursion_depth, state_independent_events
) {
derived_state_.M_ = SUNMatrixWrapper(nx_solver, nx_solver);
auto M_nnz = static_cast<sunindextype>(
Expand Down
20 changes: 14 additions & 6 deletions include/amici/model_dimensions.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ struct ModelDimensions {
* @param nz Number of event observables
* @param nztrue Number of event observables of the non-augmented model
* @param ne Number of events
* @param ne_solver Number of events that require root-finding
* @param nspl Number of splines
* @param nJ Number of objective functions
* @param nw Number of repeating elements
Expand Down Expand Up @@ -58,11 +59,12 @@ struct ModelDimensions {
int const nx_rdata, int const nxtrue_rdata, int const nx_solver,
int const nxtrue_solver, int const nx_solver_reinit, int const np,
int const nk, int const ny, int const nytrue, int const nz,
int const nztrue, int const ne, int const nspl, int const nJ,
int const nw, int const ndwdx, int const ndwdp, int const ndwdw,
int const ndxdotdw, std::vector<int> ndJydy, int const ndxrdatadxsolver,
int const ndxrdatadtcl, int const ndtotal_cldx_rdata, int const nnz,
int const ubw, int const lbw
int const nztrue, int const ne, int const ne_solver, int const nspl,
int const nJ, int const nw, int const ndwdx, int const ndwdp,
int const ndwdw, int const ndxdotdw, std::vector<int> ndJydy,
int const ndxrdatadxsolver, int const ndxrdatadtcl,
int const ndtotal_cldx_rdata, int const nnz, int const ubw,
int const lbw
)
: nx_rdata(nx_rdata)
, nxtrue_rdata(nxtrue_rdata)
Expand All @@ -76,6 +78,7 @@ struct ModelDimensions {
, nz(nz)
, nztrue(nztrue)
, ne(ne)
, ne_solver(ne_solver)
, nspl(nspl)
, nw(nw)
, ndwdx(ndwdx)
Expand Down Expand Up @@ -104,6 +107,8 @@ struct ModelDimensions {
Expects(nztrue >= 0);
Expects(nztrue <= nz);
Expects(ne >= 0);
Expects(ne_solver >= 0);
Expects(ne >= ne_solver);
Expects(nspl >= 0);
Expects(nw >= 0);
Expects(ndwdx >= 0);
Expand Down Expand Up @@ -164,7 +169,10 @@ struct ModelDimensions {
/** Number of events */
int ne{0};

/** numer of spline functions in the model */
/** Number of events that require root-finding */
int ne_solver{0};

/** Number of spline functions in the model */
int nspl{0};

/** Number of common expressions */
Expand Down
7 changes: 5 additions & 2 deletions include/amici/model_ode.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,19 +39,22 @@ class Model_ODE : public Model {
* @param ndxdotdp_explicit number of nonzero elements dxdotdp_explicit
* @param ndxdotdx_explicit number of nonzero elements dxdotdx_explicit
* @param w_recursion_depth Recursion depth of fw
* @param state_independent_events Map of events with state-independent
* triggers functions, mapping trigger timepoints to event indices.
*/
Model_ODE(
ModelDimensions const& model_dimensions,
SimulationParameters simulation_parameters,
const SecondOrderMode o2mode, std::vector<realtype> const& idlist,
std::vector<int> const& z2event, bool const pythonGenerated = false,
int const ndxdotdp_explicit = 0, int const ndxdotdx_explicit = 0,
int const w_recursion_depth = 0
int const w_recursion_depth = 0,
std::map<realtype, std::vector<int>> state_independent_events = {}
)
: Model(
model_dimensions, simulation_parameters, o2mode, idlist, z2event,
pythonGenerated, ndxdotdp_explicit, ndxdotdx_explicit,
w_recursion_depth
w_recursion_depth, state_independent_events
) {}

void
Expand Down
3 changes: 3 additions & 0 deletions include/amici/serialization.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#include <boost/iostreams/device/back_inserter.hpp>
#include <boost/iostreams/stream.hpp>
#include <boost/serialization/array.hpp>
#include <boost/serialization/map.hpp>
#include <boost/serialization/vector.hpp>

/** @file serialization.h Helper functions and forward declarations for
Expand Down Expand Up @@ -143,6 +144,7 @@ void serialize(Archive& ar, amici::Model& m, unsigned int const /*version*/) {
ar& m.sigma_res_;
ar& m.steadystate_computation_mode_;
ar& m.steadystate_sensitivity_mode_;
ar& m.state_independent_events_;
}

/**
Expand Down Expand Up @@ -260,6 +262,7 @@ void serialize(
ar& m.nz;
ar& m.nztrue;
ar& m.ne;
ar& m.ne_solver;
ar& m.nspl;
ar& m.nw;
ar& m.ndwdx;
Expand Down
1 change: 1 addition & 0 deletions matlab/@amimodel/generateC.m
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,7 @@ function generateC(this)
fprintf(fid,[' ' num2str(this.nz) ',\n']);
fprintf(fid,[' ' num2str(this.nztrue) ',\n']);
fprintf(fid,[' ' num2str(this.nevent) ',\n']);
fprintf(fid,[' ' num2str(this.nevent) ',\n']);
fprintf(fid,[' 0,\n']);
fprintf(fid,[' ' num2str(this.ng) ',\n']);
fprintf(fid,[' ' num2str(this.nw) ',\n']);
Expand Down
3 changes: 2 additions & 1 deletion models/model_calvetti/model_calvetti.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ class Model_model_calvetti : public amici::Model_DAE {
0,
0,
4,
4,
0,
1,
38,
Expand Down Expand Up @@ -207,6 +208,6 @@ class Model_model_calvetti : public amici::Model_DAE {

} // namespace model_model_calvetti

} // namespace amici
} // namespace amici

#endif /* _amici_model_calvetti_h */
3 changes: 2 additions & 1 deletion models/model_dirac/model_dirac.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ class Model_model_dirac : public amici::Model_ODE {
0,
0,
2,
2,
0,
1,
0,
Expand Down Expand Up @@ -204,6 +205,6 @@ class Model_model_dirac : public amici::Model_ODE {

} // namespace model_model_dirac

} // namespace amici
} // namespace amici

#endif /* _amici_model_dirac_h */
3 changes: 2 additions & 1 deletion models/model_events/model_events.h
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ class Model_model_events : public amici::Model_ODE {
2,
2,
6,
6,
0,
1,
0,
Expand Down Expand Up @@ -232,6 +233,6 @@ class Model_model_events : public amici::Model_ODE {

} // namespace model_model_events

} // namespace amici
} // namespace amici

#endif /* _amici_model_events_h */
3 changes: 2 additions & 1 deletion models/model_jakstat_adjoint/model_jakstat_adjoint.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ class Model_model_jakstat_adjoint : public amici::Model_ODE {
0,
0,
0,
0,
1,
2,
1,
Expand Down Expand Up @@ -210,6 +211,6 @@ class Model_model_jakstat_adjoint : public amici::Model_ODE {

} // namespace model_model_jakstat_adjoint

} // namespace amici
} // namespace amici

#endif /* _amici_model_jakstat_adjoint_h */
3 changes: 2 additions & 1 deletion models/model_jakstat_adjoint_o2/model_jakstat_adjoint_o2.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ class Model_model_jakstat_adjoint_o2 : public amici::Model_ODE {
0,
0,
0,
0,
18,
10,
2,
Expand Down Expand Up @@ -210,6 +211,6 @@ class Model_model_jakstat_adjoint_o2 : public amici::Model_ODE {

} // namespace model_model_jakstat_adjoint_o2

} // namespace amici
} // namespace amici

#endif /* _amici_model_jakstat_adjoint_o2_h */
3 changes: 2 additions & 1 deletion models/model_nested_events/model_nested_events.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ class Model_model_nested_events : public amici::Model_ODE {
0,
0,
4,
4,
0,
1,
0,
Expand Down Expand Up @@ -210,6 +211,6 @@ class Model_model_nested_events : public amici::Model_ODE {

} // namespace model_model_nested_events

} // namespace amici
} // namespace amici

#endif /* _amici_model_nested_events_h */
3 changes: 2 additions & 1 deletion models/model_neuron/model_neuron.h
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ class Model_model_neuron : public amici::Model_ODE {
1,
1,
1,
1,
0,
1,
0,
Expand Down Expand Up @@ -238,6 +239,6 @@ class Model_model_neuron : public amici::Model_ODE {

} // namespace model_model_neuron

} // namespace amici
} // namespace amici

#endif /* _amici_model_neuron_h */
Loading

0 comments on commit ecbae3f

Please sign in to comment.