diff --git a/python/sdist/amici/de_export.py b/python/sdist/amici/de_export.py index 96e25861ab..6451f761a6 100644 --- a/python/sdist/amici/de_export.py +++ b/python/sdist/amici/de_export.py @@ -32,6 +32,7 @@ Set, Tuple, Union, + Literal ) import numpy as np @@ -2838,9 +2839,6 @@ def _generate_c_code(self) -> None: if func_info.generate_body: dec = log_execution_time(f"writing {func_name}.cpp", logger) dec(self._write_function_file)(func_name) - if func_name in sparse_functions and func_info.body: - self._write_function_index(func_name, "colptrs") - self._write_function_index(func_name, "rowvals") for name in self.model.sym_names(): # only generate for those that have nontrivial implementation, @@ -3132,21 +3130,29 @@ def _write_function_file(self, function: str) -> None: self._build_hints.add(fun["build_hint"]) lines.insert(0, fun["include"]) + if function in sparse_functions: + lines.extend(self._generate_function_index(function, "colptrs")) + lines.extend(self._generate_function_index(function, "rowvals")) + # if not body is None: filename = os.path.join(self.model_path, f"{function}.cpp") with open(filename, "w") as fileout: fileout.write("\n".join(lines)) - def _write_function_index(self, function: str, indextype: str) -> None: + def _generate_function_index( + self, function: str, indextype: Literal["colptrs", "rowvals"] + ) -> List[str]: """ - Generate equations and write the C++ code for the function - ``function``. + Generate equations and C++ code for the function ``function``. :param function: name of the function to be written (see ``self.functions``) :param indextype: type of index {'colptrs', 'rowvals'} + + :returns: + The code lines for the respective function index file """ if indextype == "colptrs": values = self.model.colptrs(function) @@ -3233,11 +3239,7 @@ def _write_function_index(self, function: str, indextype: str) -> None: ] ) - filename = f"{function}_{indextype}.cpp" - filename = os.path.join(self.model_path, filename) - - with open(filename, "w") as fileout: - fileout.write("\n".join(lines)) + return lines def _get_function_body(self, function: str, equations: sp.Matrix) -> List[str]: """