From 06f62175a4991e99deb66b0e16b8ca0f26ba8c44 Mon Sep 17 00:00:00 2001 From: Daniel Weindl Date: Fri, 1 Dec 2023 23:22:33 +0100 Subject: [PATCH] Refactor event handling (#2216) Just extracting 2 functions. No change in functionality. --- include/amici/forwardproblem.h | 15 ++++ src/forwardproblem.cpp | 138 +++++++++++++++++---------------- 2 files changed, 88 insertions(+), 65 deletions(-) diff --git a/include/amici/forwardproblem.h b/include/amici/forwardproblem.h index dfe3bd8f22..91e5d50791 100644 --- a/include/amici/forwardproblem.h +++ b/include/amici/forwardproblem.h @@ -250,6 +250,21 @@ class ForwardProblem { void handleEvent(realtype* tlastroot, bool seflag, bool initial_event); + /** + * @brief Store pre-event model state + * + * @param seflag Secondary event flag + * @param initial_event initial event flag + */ + void store_pre_event_state(bool seflag, bool initial_event); + + /** + * @brief Check for, and if applicable, handle any secondary events + * + * @param tlastroot pointer to the timepoint of the last event + */ + void handle_secondary_event(realtype* tlastroot); + /** * @brief Extract output information for events */ diff --git a/src/forwardproblem.cpp b/src/forwardproblem.cpp index 72fd87c979..13946547ef 100644 --- a/src/forwardproblem.cpp +++ b/src/forwardproblem.cpp @@ -182,6 +182,77 @@ void ForwardProblem::handleEvent( if (model->nz > 0) storeEvent(); + store_pre_event_state(seflag, initial_event); + + if (!initial_event) + model->updateHeaviside(roots_found_); + + applyEventBolus(); + + if (solver->computingFSA()) { + /* compute the new xdot */ + model->fxdot(t_, x_, dx_, xdot_); + applyEventSensiBolusFSA(); + } + + handle_secondary_event(tlastroot); + + /* only reinitialise in the first event fired */ + if (!seflag) { + solver->reInit(t_, x_, dx_); + if (solver->computingFSA()) { + solver->sensReInit(sx_, sdx_); + } + } +} + +void ForwardProblem::storeEvent() { + if (t_ == model->getTimepoint(model->nt() - 1)) { + // call from fillEvent at last timepoint + model->froot(t_, x_, dx_, rootvals_); + for (int ie = 0; ie < model->ne; ie++) { + roots_found_.at(ie) = (nroots_.at(ie) < model->nMaxEvent()) ? 1 : 0; + } + root_idx_.push_back(roots_found_); + } + + if (getRootCounter() < getEventCounter()) { + /* update stored state (sensi) */ + event_states_.at(getRootCounter()) = getSimulationState(); + } else { + /* add stored state (sensi) */ + event_states_.push_back(getSimulationState()); + } + + /* EVENT OUTPUT */ + for (int ie = 0; ie < model->ne; ie++) { + /* only look for roots of the rootfunction not discontinuities */ + if (nroots_.at(ie) >= model->nMaxEvent()) + continue; + + /* only consider transitions false -> true or event filling */ + if (roots_found_.at(ie) != 1 + && t_ != model->getTimepoint(model->nt() - 1)) { + continue; + } + + if (edata && solver->computingASA()) + model->getAdjointStateEventUpdate( + slice(dJzdx_, nroots_.at(ie), model->nx_solver * model->nJ), ie, + nroots_.at(ie), t_, x_, *edata + ); + + nroots_.at(ie)++; + } + + if (t_ == model->getTimepoint(model->nt() - 1)) { + // call from fillEvent at last timepoint + // loop until all events are filled + fillEvents(model->nMaxEvent()); + } +} + +void ForwardProblem::store_pre_event_state(bool seflag, bool initial_event) { /* if we need to do forward sensitivities later on we need to store the old * x and the old xdot */ if (solver->getSensitivityOrder() >= SensitivityOrder::first) { @@ -212,18 +283,9 @@ void ForwardProblem::handleEvent( xdot_disc_.push_back(xdot_); xdot_old_disc_.push_back(xdot_old_); } +} - if (!initial_event) - model->updateHeaviside(roots_found_); - - applyEventBolus(); - - if (solver->computingFSA()) { - /* compute the new xdot */ - model->fxdot(t_, x_, dx_, xdot_); - applyEventSensiBolusFSA(); - } - +void ForwardProblem::handle_secondary_event(realtype* tlastroot) { int secondevent = 0; /* check whether we need to fire a secondary event */ @@ -260,60 +322,6 @@ void ForwardProblem::handleEvent( ); handleEvent(tlastroot, true, false); } - - /* only reinitialise in the first event fired */ - if (!seflag) { - solver->reInit(t_, x_, dx_); - if (solver->computingFSA()) { - solver->sensReInit(sx_, sdx_); - } - } -} - -void ForwardProblem::storeEvent() { - if (t_ == model->getTimepoint(model->nt() - 1)) { - // call from fillEvent at last timepoint - model->froot(t_, x_, dx_, rootvals_); - for (int ie = 0; ie < model->ne; ie++) { - roots_found_.at(ie) = (nroots_.at(ie) < model->nMaxEvent()) ? 1 : 0; - } - root_idx_.push_back(roots_found_); - } - - if (getRootCounter() < getEventCounter()) { - /* update stored state (sensi) */ - event_states_.at(getRootCounter()) = getSimulationState(); - } else { - /* add stored state (sensi) */ - event_states_.push_back(getSimulationState()); - } - - /* EVENT OUTPUT */ - for (int ie = 0; ie < model->ne; ie++) { - /* only look for roots of the rootfunction not discontinuities */ - if (nroots_.at(ie) >= model->nMaxEvent()) - continue; - - /* only consider transitions false -> true or event filling */ - if (roots_found_.at(ie) != 1 - && t_ != model->getTimepoint(model->nt() - 1)) { - continue; - } - - if (edata && solver->computingASA()) - model->getAdjointStateEventUpdate( - slice(dJzdx_, nroots_.at(ie), model->nx_solver * model->nJ), ie, - nroots_.at(ie), t_, x_, *edata - ); - - nroots_.at(ie)++; - } - - if (t_ == model->getTimepoint(model->nt() - 1)) { - // call from fillEvent at last timepoint - // loop until all events are filled - fillEvents(model->nMaxEvent()); - } } void ForwardProblem::handleDataPoint(int /*it*/) {