diff --git a/src/codegen/codegen_helper_visitor.cpp b/src/codegen/codegen_helper_visitor.cpp index 1c24e2e5b..88d4e424e 100644 --- a/src/codegen/codegen_helper_visitor.cpp +++ b/src/codegen/codegen_helper_visitor.cpp @@ -15,6 +15,7 @@ #include "parser/c11_driver.hpp" #include "visitors/visitor_utils.hpp" +#include "utils/logger.hpp" namespace nmodl { namespace codegen { @@ -24,6 +25,24 @@ using namespace ast; using symtab::syminfo::NmodlType; using symtab::syminfo::Status; +/** + * Check whether a given SOLVE block solves a PROCEDURE with any of the CVode methods + */ +static bool check_procedure_has_cvode(const std::shared_ptr& solve_node, + const std::shared_ptr& procedure_node) { + const auto& solve_block = std::dynamic_pointer_cast(solve_node); + const auto& method = solve_block->get_method(); + if (!method) { + return false; + } + const auto& method_name = method->get_node_name(); + + return procedure_node->get_node_name() == solve_block->get_block_name()->get_node_name() && + (method_name == codegen::naming::AFTER_CVODE_METHOD || + method_name == codegen::naming::CVODE_T_METHOD || + method_name == codegen::naming::CVODE_T_V_METHOD); +} + /** * How symbols are stored in NEURON? See notes written in markdown file. * @@ -152,6 +171,59 @@ void CodegenHelperVisitor::find_ion_variables(const ast::Program& node) { } } +/** + * Find whether or not we need to emit CVODE-related code for NEURON + * Notes: we generate CVODE-related code if and only if: + * - there is exactly ONE block being SOLVEd + * - the block is one of the following types: + * - DERIVATIVE + * - KINETIC + * - PROCEDURE being solved with the `after_cvode`, `cvode_t`, or `cvode_t_v` methods + */ +void CodegenHelperVisitor::check_cvode_codegen(const ast::Program& node) { + // find the breakpoint block + const auto& breakpoint_nodes = collect_nodes(node, {AstNodeType::BREAKPOINT_BLOCK}); + + // do nothing if there are no BREAKPOINT nodes + if (breakpoint_nodes.empty()) { + return; + } + + // there can only be one BREAKPOINT block in the entire program + assert(breakpoint_nodes.size() == 1); + + const auto& breakpoint_node = std::dynamic_pointer_cast( + breakpoint_nodes[0]); + + // all (global) kinetic/derivative nodes + const auto& kinetic_or_derivative_nodes = + collect_nodes(node, {AstNodeType::KINETIC_BLOCK, AstNodeType::DERIVATIVE_BLOCK}); + + // all (global) procedure nodes + const auto& procedure_nodes = collect_nodes(node, {AstNodeType::PROCEDURE_BLOCK}); + + // find all SOLVE blocks in that BREAKPOINT block + const auto& solve_nodes = collect_nodes(*breakpoint_node, {AstNodeType::SOLVE_BLOCK}); + + // check whether any of the SOLVE blocks are solving any PROCEDURE with `after_cvode`, + // `cvode_t`, or `cvode_t_v` methods + const auto using_cvode = std::any_of( + solve_nodes.begin(), solve_nodes.end(), [&procedure_nodes](const auto& solve_node) { + return std::any_of(procedure_nodes.begin(), + procedure_nodes.end(), + [&solve_node](const auto& procedure_node) { + return check_procedure_has_cvode(solve_node, procedure_node); + }); + }); + + // only case when we emit CVODE code is if we have exactly one block, and + // that block is either a KINETIC/DERIVATIVE with any method, or a + // PROCEDURE with `after_cvode` method + if (solve_nodes.size() == 1 && (kinetic_or_derivative_nodes.size() || using_cvode)) { + logger->debug("Will emit code for CVODE"); + info.emit_cvode = true; + } +} /** * Find non-range variables i.e. ones that are not belong to per instance allocation @@ -738,6 +810,7 @@ void CodegenHelperVisitor::visit_program(const ast::Program& node) { find_non_range_variables(); find_table_variables(); find_neuron_global_variables(); + check_cvode_codegen(node); } diff --git a/src/codegen/codegen_helper_visitor.hpp b/src/codegen/codegen_helper_visitor.hpp index b960bdc06..9258f1cf7 100644 --- a/src/codegen/codegen_helper_visitor.hpp +++ b/src/codegen/codegen_helper_visitor.hpp @@ -75,7 +75,7 @@ class CodegenHelperVisitor: public visitor::ConstAstVisitor { void find_non_range_variables(); void find_neuron_global_variables(); static void sort_with_mod2c_symbol_order(std::vector& symbols); - + void check_cvode_codegen(const ast::Program& node); public: CodegenHelperVisitor() = default; diff --git a/src/codegen/codegen_naming.hpp b/src/codegen/codegen_naming.hpp index c7c10eb91..bc657e6f7 100644 --- a/src/codegen/codegen_naming.hpp +++ b/src/codegen/codegen_naming.hpp @@ -32,6 +32,12 @@ static constexpr char CNEXP_METHOD[] = "cnexp"; /// cvode method in nmodl static constexpr char AFTER_CVODE_METHOD[] = "after_cvode"; +/// cvode_t method in nmodl +static constexpr char CVODE_T_METHOD[] = "cvode_t"; + +/// cvode_t_v method in nmodl +static constexpr char CVODE_T_V_METHOD[] = "cvode_t_v"; + /// sparse method in nmodl static constexpr char SPARSE_METHOD[] = "sparse"; diff --git a/src/codegen/codegen_neuron_cpp_visitor.cpp b/src/codegen/codegen_neuron_cpp_visitor.cpp index 8e0f1de5f..b99b45101 100644 --- a/src/codegen/codegen_neuron_cpp_visitor.cpp +++ b/src/codegen/codegen_neuron_cpp_visitor.cpp @@ -1176,10 +1176,6 @@ void CodegenNeuronCppVisitor::print_mechanism_register() { info.semantics[i].name, i)); } - if (info.emit_cvode) { - mech_register_args.push_back( - "_nrn_mechanism_field{\"_cvode_ieq\", \"cvodeieq\"} /* 0 */"); - } printer->add_multi_line(fmt::format("{}", fmt::join(mech_register_args, ",\n"))); diff --git a/test/unit/codegen/codegen_helper.cpp b/test/unit/codegen/codegen_helper.cpp index a8da8c9b8..5e5006f5d 100644 --- a/test/unit/codegen/codegen_helper.cpp +++ b/test/unit/codegen/codegen_helper.cpp @@ -57,6 +57,24 @@ std::string run_codegen_helper_visitor(const std::string& text) { return variables; } +CodegenInfo run_codegen_helper_get_info(const std::string& text) { + const auto& ast = NmodlDriver().parse_string(text); + /// construct symbol table and run codegen helper visitor + SymtabVisitor{}.visit_program(*ast); + KineticBlockVisitor{}.visit_program(*ast); + SymtabVisitor{}.visit_program(*ast); + SteadystateVisitor{}.visit_program(*ast); + SymtabVisitor{}.visit_program(*ast); + NeuronSolveVisitor{}.visit_program(*ast); + SolveBlockVisitor{}.visit_program(*ast); + SymtabVisitor{true}.visit_program(*ast); + + CodegenHelperVisitor v; + const auto info = v.analyze(*ast); + + return info; +} + SCENARIO("unusual / failing mod files", "[codegen][var_order]") { GIVEN("cal_mig.mod : USEION variables declared as RANGE") { std::string nmodl_text = R"( @@ -293,3 +311,115 @@ TEST_CASE("Check ion write/read checks") { } } } + +SCENARIO("CVODE codegen") { + GIVEN("a mod file with a single KINETIC block") { + std::string input_nmodl = R"( + STATE { + x + } + KINETIC states { + ~ x << (a*c/3.2) + } + BREAKPOINT { + SOLVE states METHOD cnexp + })"; + + const auto& info = run_codegen_helper_get_info(input_nmodl); + THEN("Emit CVODE") { + REQUIRE(info.emit_cvode); + } + } + GIVEN("a mod file with a single DERIVATIVE block") { + std::string input_nmodl = R"( + STATE { + m + } + BREAKPOINT { + SOLVE state METHOD derivimplicit + } + DERIVATIVE state { + m' = 2 * m + } + )"; + const auto& info = run_codegen_helper_get_info(input_nmodl); + + THEN("Emit CVODE") { + REQUIRE(info.emit_cvode); + } + } + GIVEN("a mod file with a single PROCEDURE block solved with method `after_cvode`") { + std::string input_nmodl = R"( + BREAKPOINT { + SOLVE state METHOD after_cvode + } + PROCEDURE state() {} + )"; + + const auto& info = run_codegen_helper_get_info(input_nmodl); + + THEN("Emit CVODE") { + REQUIRE(info.emit_cvode); + } + } + GIVEN("a mod file with a single PROCEDURE block NOT solved with method `after_cvode`") { + std::string input_nmodl = R"( + BREAKPOINT { + SOLVE state METHOD cnexp + } + PROCEDURE state() {} + )"; + + const auto& info = run_codegen_helper_get_info(input_nmodl); + + THEN("Do not emit CVODE") { + REQUIRE(!info.emit_cvode); + } + } + GIVEN("a mod file with a DERIVATIVE and a KINETIC block") { + std::string input_nmodl = R"( + STATE { + m + x + } + BREAKPOINT { + SOLVE der METHOD derivimplicit + SOLVE kin METHOD cnexp + } + DERIVATIVE der { + m' = 2 * m + } + KINETIC kin { + ~ x << (a*c/3.2) + } + )"; + + const auto& info = run_codegen_helper_get_info(input_nmodl); + + THEN("Do not emit CVODE") { + REQUIRE(!info.emit_cvode); + } + } + GIVEN("a mod file with a PROCEDURE and a DERIVATIVE block") { + std::string input_nmodl = R"( + STATE { + m + } + BREAKPOINT { + SOLVE der METHOD derivimplicit + SOLVE func METHOD cnexp + } + DERIVATIVE der { + m' = 2 * m + } + PROCEDURE func() { + } + )"; + + const auto& info = run_codegen_helper_get_info(input_nmodl); + + THEN("Do not emit CVODE") { + REQUIRE(!info.emit_cvode); + } + } +}