Skip to content

Commit

Permalink
Add toggling of emit_cvode flag for codegen (#1384)
Browse files Browse the repository at this point in the history
  • Loading branch information
JCGoran authored Aug 16, 2024
1 parent 14ca6a4 commit 6d3b830
Show file tree
Hide file tree
Showing 5 changed files with 210 additions and 5 deletions.
73 changes: 73 additions & 0 deletions src/codegen/codegen_helper_visitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#include "parser/c11_driver.hpp"
#include "visitors/visitor_utils.hpp"

#include "utils/logger.hpp"

namespace nmodl {
namespace codegen {
Expand All @@ -24,6 +25,24 @@ using namespace ast;
using symtab::syminfo::NmodlType;
using symtab::syminfo::Status;

/**
* Check whether a given SOLVE block solves a PROCEDURE with any of the CVode methods
*/
static bool check_procedure_has_cvode(const std::shared_ptr<const ast::Ast>& solve_node,
const std::shared_ptr<const ast::Ast>& procedure_node) {
const auto& solve_block = std::dynamic_pointer_cast<const ast::SolveBlock>(solve_node);
const auto& method = solve_block->get_method();
if (!method) {
return false;
}
const auto& method_name = method->get_node_name();

return procedure_node->get_node_name() == solve_block->get_block_name()->get_node_name() &&
(method_name == codegen::naming::AFTER_CVODE_METHOD ||
method_name == codegen::naming::CVODE_T_METHOD ||
method_name == codegen::naming::CVODE_T_V_METHOD);
}

/**
* How symbols are stored in NEURON? See notes written in markdown file.
*
Expand Down Expand Up @@ -152,6 +171,59 @@ void CodegenHelperVisitor::find_ion_variables(const ast::Program& node) {
}
}

/**
* Find whether or not we need to emit CVODE-related code for NEURON
* Notes: we generate CVODE-related code if and only if:
* - there is exactly ONE block being SOLVEd
* - the block is one of the following types:
* - DERIVATIVE
* - KINETIC
* - PROCEDURE being solved with the `after_cvode`, `cvode_t`, or `cvode_t_v` methods
*/
void CodegenHelperVisitor::check_cvode_codegen(const ast::Program& node) {
// find the breakpoint block
const auto& breakpoint_nodes = collect_nodes(node, {AstNodeType::BREAKPOINT_BLOCK});

// do nothing if there are no BREAKPOINT nodes
if (breakpoint_nodes.empty()) {
return;
}

// there can only be one BREAKPOINT block in the entire program
assert(breakpoint_nodes.size() == 1);

const auto& breakpoint_node = std::dynamic_pointer_cast<const ast::BreakpointBlock>(
breakpoint_nodes[0]);

// all (global) kinetic/derivative nodes
const auto& kinetic_or_derivative_nodes =
collect_nodes(node, {AstNodeType::KINETIC_BLOCK, AstNodeType::DERIVATIVE_BLOCK});

// all (global) procedure nodes
const auto& procedure_nodes = collect_nodes(node, {AstNodeType::PROCEDURE_BLOCK});

// find all SOLVE blocks in that BREAKPOINT block
const auto& solve_nodes = collect_nodes(*breakpoint_node, {AstNodeType::SOLVE_BLOCK});

// check whether any of the SOLVE blocks are solving any PROCEDURE with `after_cvode`,
// `cvode_t`, or `cvode_t_v` methods
const auto using_cvode = std::any_of(
solve_nodes.begin(), solve_nodes.end(), [&procedure_nodes](const auto& solve_node) {
return std::any_of(procedure_nodes.begin(),
procedure_nodes.end(),
[&solve_node](const auto& procedure_node) {
return check_procedure_has_cvode(solve_node, procedure_node);
});
});

// only case when we emit CVODE code is if we have exactly one block, and
// that block is either a KINETIC/DERIVATIVE with any method, or a
// PROCEDURE with `after_cvode` method
if (solve_nodes.size() == 1 && (kinetic_or_derivative_nodes.size() || using_cvode)) {
logger->debug("Will emit code for CVODE");
info.emit_cvode = true;
}
}

/**
* Find non-range variables i.e. ones that are not belong to per instance allocation
Expand Down Expand Up @@ -738,6 +810,7 @@ void CodegenHelperVisitor::visit_program(const ast::Program& node) {
find_non_range_variables();
find_table_variables();
find_neuron_global_variables();
check_cvode_codegen(node);
}


Expand Down
2 changes: 1 addition & 1 deletion src/codegen/codegen_helper_visitor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ class CodegenHelperVisitor: public visitor::ConstAstVisitor {
void find_non_range_variables();
void find_neuron_global_variables();
static void sort_with_mod2c_symbol_order(std::vector<SymbolType>& symbols);

void check_cvode_codegen(const ast::Program& node);
public:
CodegenHelperVisitor() = default;

Expand Down
6 changes: 6 additions & 0 deletions src/codegen/codegen_naming.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,12 @@ static constexpr char CNEXP_METHOD[] = "cnexp";
/// cvode method in nmodl
static constexpr char AFTER_CVODE_METHOD[] = "after_cvode";

/// cvode_t method in nmodl
static constexpr char CVODE_T_METHOD[] = "cvode_t";

/// cvode_t_v method in nmodl
static constexpr char CVODE_T_V_METHOD[] = "cvode_t_v";

/// sparse method in nmodl
static constexpr char SPARSE_METHOD[] = "sparse";

Expand Down
4 changes: 0 additions & 4 deletions src/codegen/codegen_neuron_cpp_visitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1176,10 +1176,6 @@ void CodegenNeuronCppVisitor::print_mechanism_register() {
info.semantics[i].name,
i));
}
if (info.emit_cvode) {
mech_register_args.push_back(
"_nrn_mechanism_field<int>{\"_cvode_ieq\", \"cvodeieq\"} /* 0 */");
}

printer->add_multi_line(fmt::format("{}", fmt::join(mech_register_args, ",\n")));

Expand Down
130 changes: 130 additions & 0 deletions test/unit/codegen/codegen_helper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,24 @@ std::string run_codegen_helper_visitor(const std::string& text) {
return variables;
}

CodegenInfo run_codegen_helper_get_info(const std::string& text) {
const auto& ast = NmodlDriver().parse_string(text);
/// construct symbol table and run codegen helper visitor
SymtabVisitor{}.visit_program(*ast);
KineticBlockVisitor{}.visit_program(*ast);
SymtabVisitor{}.visit_program(*ast);
SteadystateVisitor{}.visit_program(*ast);
SymtabVisitor{}.visit_program(*ast);
NeuronSolveVisitor{}.visit_program(*ast);
SolveBlockVisitor{}.visit_program(*ast);
SymtabVisitor{true}.visit_program(*ast);

CodegenHelperVisitor v;
const auto info = v.analyze(*ast);

return info;
}

SCENARIO("unusual / failing mod files", "[codegen][var_order]") {
GIVEN("cal_mig.mod : USEION variables declared as RANGE") {
std::string nmodl_text = R"(
Expand Down Expand Up @@ -293,3 +311,115 @@ TEST_CASE("Check ion write/read checks") {
}
}
}

SCENARIO("CVODE codegen") {
GIVEN("a mod file with a single KINETIC block") {
std::string input_nmodl = R"(
STATE {
x
}
KINETIC states {
~ x << (a*c/3.2)
}
BREAKPOINT {
SOLVE states METHOD cnexp
})";

const auto& info = run_codegen_helper_get_info(input_nmodl);
THEN("Emit CVODE") {
REQUIRE(info.emit_cvode);
}
}
GIVEN("a mod file with a single DERIVATIVE block") {
std::string input_nmodl = R"(
STATE {
m
}
BREAKPOINT {
SOLVE state METHOD derivimplicit
}
DERIVATIVE state {
m' = 2 * m
}
)";
const auto& info = run_codegen_helper_get_info(input_nmodl);

THEN("Emit CVODE") {
REQUIRE(info.emit_cvode);
}
}
GIVEN("a mod file with a single PROCEDURE block solved with method `after_cvode`") {
std::string input_nmodl = R"(
BREAKPOINT {
SOLVE state METHOD after_cvode
}
PROCEDURE state() {}
)";

const auto& info = run_codegen_helper_get_info(input_nmodl);

THEN("Emit CVODE") {
REQUIRE(info.emit_cvode);
}
}
GIVEN("a mod file with a single PROCEDURE block NOT solved with method `after_cvode`") {
std::string input_nmodl = R"(
BREAKPOINT {
SOLVE state METHOD cnexp
}
PROCEDURE state() {}
)";

const auto& info = run_codegen_helper_get_info(input_nmodl);

THEN("Do not emit CVODE") {
REQUIRE(!info.emit_cvode);
}
}
GIVEN("a mod file with a DERIVATIVE and a KINETIC block") {
std::string input_nmodl = R"(
STATE {
m
x
}
BREAKPOINT {
SOLVE der METHOD derivimplicit
SOLVE kin METHOD cnexp
}
DERIVATIVE der {
m' = 2 * m
}
KINETIC kin {
~ x << (a*c/3.2)
}
)";

const auto& info = run_codegen_helper_get_info(input_nmodl);

THEN("Do not emit CVODE") {
REQUIRE(!info.emit_cvode);
}
}
GIVEN("a mod file with a PROCEDURE and a DERIVATIVE block") {
std::string input_nmodl = R"(
STATE {
m
}
BREAKPOINT {
SOLVE der METHOD derivimplicit
SOLVE func METHOD cnexp
}
DERIVATIVE der {
m' = 2 * m
}
PROCEDURE func() {
}
)";

const auto& info = run_codegen_helper_get_info(input_nmodl);

THEN("Do not emit CVODE") {
REQUIRE(!info.emit_cvode);
}
}
}

0 comments on commit 6d3b830

Please sign in to comment.