diff --git a/CMakeLists.txt b/CMakeLists.txt index bf7ffdf12..8c72e45fd 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -278,6 +278,7 @@ set(HEYOKA_SRC_FILES "${CMAKE_CURRENT_SOURCE_DIR}/src/detail/math_wrappers.cpp" "${CMAKE_CURRENT_SOURCE_DIR}/src/detail/logging_impl.cpp" "${CMAKE_CURRENT_SOURCE_DIR}/src/step_callback.cpp" + "${CMAKE_CURRENT_SOURCE_DIR}/src/llvm_state_mem_cache.cpp" # NOTE: sleef.cpp needs to be compiled even if we are not # building with sleef support on. "${CMAKE_CURRENT_SOURCE_DIR}/src/detail/sleef.cpp" diff --git a/config.hpp.in b/config.hpp.in index 12f098dd7..f528051cf 100644 --- a/config.hpp.in +++ b/config.hpp.in @@ -88,4 +88,17 @@ } // clang-format on +// C++20 constinit. +// NOTE: this seems to be buggy currently on MSVC: +// https://github.com/bluescarni/mppp/issues/291 +#if defined(__cpp_constinit) && (!defined(_MSC_VER) || defined(__clang__)) + +#define HEYOKA_CONSTINIT constinit + +#else + +#define HEYOKA_CONSTINIT + +#endif + #endif diff --git a/doc/changelog.rst b/doc/changelog.rst index ceb634266..4a4f1697c 100644 --- a/doc/changelog.rst +++ b/doc/changelog.rst @@ -7,6 +7,10 @@ Changelog New ~~~ +- Implement an in-memory cache for ``llvm_state``. The cache is used + to avoid re-optimising and re-compiling LLVM code which has + already been optimised and compiled during the program execution + (`#340 `__). - It is now possible to get the LLVM bitcode of an ``llvm_state`` (`#339 `__). @@ -14,6 +18,9 @@ New Changes ~~~~~~~ +- The optimisation level for an ``llvm_state`` is now clamped + within the ``[0, 3]`` range + (`#340 `__). - The LLVM bitcode is now used internally (instead of the textual representation of the IR) when copying and serialising an ``llvm_state`` @@ -25,6 +32,8 @@ Changes Fix ~~~ +- Fix compilation in C++20 mode + (`#340 `__). - Fix the object file of an ``llvm_state`` not being preserved during copy and deserialisation (`#339 `__). diff --git a/include/heyoka/detail/llvm_func_create.hpp b/include/heyoka/detail/llvm_func_create.hpp new file mode 100644 index 000000000..76b7b114c --- /dev/null +++ b/include/heyoka/detail/llvm_func_create.hpp @@ -0,0 +1,55 @@ +// Copyright 2020, 2021, 2022, 2023 Francesco Biscani (bluescarni@gmail.com), Dario Izzo (dario.izzo@gmail.com) +// +// This file is part of the heyoka library. +// +// This Source Code Form is subject to the terms of the Mozilla +// Public License v. 2.0. If a copy of the MPL was not distributed +// with this file, You can obtain one at http://mozilla.org/MPL/2.0/. + +#ifndef HEYOKA_DETAIL_LLVM_FUNC_CREATE_HPP +#define HEYOKA_DETAIL_LLVM_FUNC_CREATE_HPP + +#include +#include +#include +#include + +#include + +#include +#include + +#include + +HEYOKA_BEGIN_NAMESPACE + +namespace detail +{ + +// Helper to create an LLVM function. +// NOTE: the purpose of this helper is to check that the function was created with +// the requested name: LLVM will silently change the name of the created function +// if it already exists in a module, and in some cases we want to prevent +// this from happening. +template +llvm::Function *llvm_func_create(llvm::FunctionType *tp, llvm::Function::LinkageTypes linkage, const std::string &name, + Args &&...args) +{ + llvm::Function *ret = llvm::Function::Create(tp, linkage, name, std::forward(args)...); + assert(ret != nullptr); + + if (ret->getName() != name) { + // Remove function before throwing. 
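+        // NOTE: eraseFromParent() unlinks the newly-created (and silently renamed)
+        // function from the module and destroys it, so the module is left as it
+        // was before the Create() call.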
+ ret->eraseFromParent(); + + throw std::invalid_argument(fmt::format("Unable to create an LLVM function with name '{}'", name)); + } + + return ret; +} + +} // namespace detail + +HEYOKA_END_NAMESPACE + +#endif diff --git a/include/heyoka/llvm_state.hpp b/include/heyoka/llvm_state.hpp index fd05287d9..c05fe8b8f 100644 --- a/include/heyoka/llvm_state.hpp +++ b/include/heyoka/llvm_state.hpp @@ -155,6 +155,10 @@ class HEYOKA_DLL_PUBLIC llvm_state HEYOKA_DLL_LOCAL void check_uncompiled(const char *) const; HEYOKA_DLL_LOCAL void check_compiled(const char *) const; + // Helper to clamp the optimisation level to + // the [0, 3] range. + static unsigned clamp_opt_level(unsigned); + // Implementation details for the variadic constructor. template static auto kw_args_ctor_impl(KwArgs &&...kw_args) @@ -183,6 +187,7 @@ class HEYOKA_DLL_PUBLIC llvm_state return 3; } }(); + opt_level = clamp_opt_level(opt_level); // Fast math flag (defaults to false). auto fmath = [&p]() -> bool { @@ -211,6 +216,10 @@ class HEYOKA_DLL_PUBLIC llvm_state // end of a constructor. HEYOKA_DLL_LOCAL void ctor_setup_math_flags(); + // Low-level implementation details for compilation. + HEYOKA_DLL_LOCAL void compile_impl(); + HEYOKA_DLL_LOCAL void add_obj_trigger(); + // Meta-programming for the kwargs ctor. Enabled if: // - there is at least 1 argument (i.e., cannot act as a def ctor), // - if there is only 1 argument, it cannot be of type llvm_state @@ -237,15 +246,15 @@ class HEYOKA_DLL_PUBLIC llvm_state llvm::Module &module(); ir_builder &builder(); llvm::LLVMContext &context(); - unsigned &opt_level(); [[nodiscard]] const std::string &module_name() const; [[nodiscard]] const llvm::Module &module() const; [[nodiscard]] const ir_builder &builder() const; [[nodiscard]] const llvm::LLVMContext &context() const; - [[nodiscard]] const unsigned &opt_level() const; [[nodiscard]] bool fast_math() const; [[nodiscard]] bool force_avx512() const; + [[nodiscard]] unsigned get_opt_level() const; + void set_opt_level(unsigned); [[nodiscard]] std::string get_ir() const; [[nodiscard]] std::string get_bc() const; @@ -258,21 +267,41 @@ class HEYOKA_DLL_PUBLIC llvm_state void optimise(); [[nodiscard]] bool is_compiled() const; - [[nodiscard]] bool has_object_code() const; void compile(); std::uintptr_t jit_lookup(const std::string &); [[nodiscard]] llvm_state make_similar() const; + + // Cache management. + static std::size_t get_memcache_size(); + static std::size_t get_memcache_limit(); + static void set_memcache_limit(std::size_t); + static void clear_memcache(); }; +namespace detail +{ + +// The value contained in the in-memory cache. +struct llvm_mc_value { + std::string opt_bc, opt_ir, obj; +}; + +// Cache lookup and insertion. +std::optional llvm_state_mem_cache_lookup(const std::string &, unsigned); +void llvm_state_mem_cache_try_insert(std::string, unsigned, llvm_mc_value); + +} // namespace detail + HEYOKA_END_NAMESPACE // Archive version changelog: // - version 1: got rid of the inline_functions setting; // - version 2: added the force_avx512 setting; -// - version 3: added the bitcode snapshot. +// - version 3: added the bitcode snapshot, simplified +// compilation logic. 
BOOST_CLASS_VERSION(heyoka::llvm_state, 3) #endif diff --git a/src/detail/event_detection.cpp b/src/detail/event_detection.cpp index c6af77b14..ea738f747 100644 --- a/src/detail/event_detection.cpp +++ b/src/detail/event_detection.cpp @@ -27,6 +27,7 @@ #include #include #include +#include #if defined(HEYOKA_HAVE_REAL128) @@ -58,6 +59,7 @@ #include #include +#include #include #include #include @@ -259,7 +261,7 @@ std::tuple bracketed_root_find(const T *poly, std::uint32_t order, T lb, // NOTE: iter limit will be derived from the number of binary digits // in the significand. - const auto iter_limit = [&]() { + const boost::uintmax_t iter_limit = [&]() { #if defined(HEYOKA_HAVE_REAL) if constexpr (std::is_same_v) { // NOTE: we use lb here, but any of lb, ub or the poly @@ -267,15 +269,7 @@ std::tuple bracketed_root_find(const T *poly, std::uint32_t order, T lb, // working precision of the root finding scheme. // NOTE: since we use bisection for mppp::real, we need to allow // for more iterations than the number of digits. - const auto ret = boost::numeric_cast(lb.get_prec()); - - // LCOV_EXCL_START - if (ret > std::numeric_limits::max() / 2u) { - throw std::overflow_error("Overflow condition detected in bracketed_root_find()"); - } - // LCOV_EXCL_STOP - - return ret * 2u; + return boost::safe_numerics::safe(lb.get_prec()) * 2; } else { #endif return boost::numeric_cast(std::numeric_limits::digits); @@ -401,12 +395,7 @@ llvm::Function *add_poly_translator_1(llvm_state &s, llvm::Type *fp_t, std::uint auto *ft = llvm::FunctionType::get(builder.getVoidTy(), fargs, false); assert(ft != nullptr); // LCOV_EXCL_LINE // Now create the function. - auto *f = llvm::Function::Create(ft, llvm::Function::ExternalLinkage, "poly_translate_1", &s.module()); - // LCOV_EXCL_START - if (f == nullptr) { - throw std::invalid_argument("Unable to create a function for polynomial translation"); - } - // LCOV_EXCL_STOP + auto *f = llvm_func_create(ft, llvm::Function::ExternalLinkage, "poly_translate_1", &s.module()); // Set the names/attributes of the function arguments. auto *out_ptr = f->args().begin(); @@ -582,12 +571,7 @@ llvm::Function *llvm_add_poly_rtscc(llvm_state &s, llvm::Type *fp_t, std::uint32 auto *ft = llvm::FunctionType::get(builder.getVoidTy(), fargs, false); assert(ft != nullptr); // LCOV_EXCL_LINE // Now create the function. - auto *f = llvm::Function::Create(ft, llvm::Function::ExternalLinkage, "poly_rtscc", &md); - // LCOV_EXCL_START - if (f == nullptr) { - throw std::invalid_argument("Unable to create an rtscc function"); - } - // LCOV_EXCL_STOP + auto *f = llvm_func_create(ft, llvm::Function::ExternalLinkage, "poly_rtscc", &md); // Set the names/attributes of the function arguments. // NOTE: out_ptr1/2 are used both in read and write mode, @@ -694,12 +678,7 @@ llvm::Function *llvm_add_fex_check(llvm_state &s, llvm::Type *fp_t, std::uint32_ auto *ft = llvm::FunctionType::get(builder.getVoidTy(), fargs, false); assert(ft != nullptr); // LCOV_EXCL_LINE // Now create the function. - auto *f = llvm::Function::Create(ft, llvm::Function::ExternalLinkage, "fex_check", &md); - // LCOV_EXCL_START - if (f == nullptr) { - throw std::invalid_argument("Unable to create an fex_check function"); - } - // LCOV_EXCL_STOP + auto *f = llvm_func_create(ft, llvm::Function::ExternalLinkage, "fex_check", &md); // Set the names/attributes of the function arguments. 
auto *cf_ptr = f->args().begin(); diff --git a/src/detail/llvm_helpers.cpp b/src/detail/llvm_helpers.cpp index 27d08bdd0..965eec5c3 100644 --- a/src/detail/llvm_helpers.cpp +++ b/src/detail/llvm_helpers.cpp @@ -73,6 +73,7 @@ #endif +#include #include #include #include @@ -209,7 +210,7 @@ llvm::Type *to_llvm_type_impl(llvm::LLVMContext &c, const std::type_info &tp, bo { const auto it = type_map.find(tp); - const auto *err_msg = "Unable to associate the C++ type '{}' to an LLVM type"; + constexpr auto *err_msg = "Unable to associate the C++ type '{}' to an LLVM type"; if (it == type_map.end()) { // LCOV_EXCL_START @@ -790,11 +791,7 @@ llvm::CallInst *llvm_invoke_external(llvm_state &s, const std::string &name, llv arg_types.push_back(a->getType()); } auto *ft = llvm::FunctionType::get(ret_type, arg_types, false); - callee_f = llvm::Function::Create(ft, llvm::Function::ExternalLinkage, name, &s.module()); - if (callee_f == nullptr) { - throw std::invalid_argument( - fmt::format("Unable to create the prototype for the external function '{}'", name)); - } + callee_f = llvm_func_create(ft, llvm::Function::ExternalLinkage, name, &s.module()); // Add the function attributes. for (const auto &att : attrs) { @@ -3194,8 +3191,7 @@ void llvm_add_inv_kep_E_wrapper(llvm_state &s, llvm::Type *scal_t, std::uint32_t // The return type is void. auto *ft = llvm::FunctionType::get(builder.getVoidTy(), fargs, false); // Create the function - auto *f = llvm::Function::Create(ft, llvm::Function::ExternalLinkage, name, &md); - assert(f != nullptr); // LCOV_EXCL_LINE + auto *f = llvm_func_create(ft, llvm::Function::ExternalLinkage, name, &md); // Fetch the current insertion block. auto *orig_bb = builder.GetInsertBlock(); diff --git a/src/expression_cfunc.cpp b/src/expression_cfunc.cpp index f825a7431..02f0827a6 100644 --- a/src/expression_cfunc.cpp +++ b/src/expression_cfunc.cpp @@ -63,6 +63,7 @@ #include #include +#include #include #include #include @@ -1598,12 +1599,7 @@ auto add_cfunc_impl(llvm_state &s, const std::string &name, const F &fn, std::ui // Append ".strided" to the function name. const auto sname = name + ".strided"; // Now create the function. - auto *f = llvm::Function::Create(ft, llvm::Function::ExternalLinkage, sname, &md); - if (f == nullptr) { - // LCOV_EXCL_START - throw std::invalid_argument(fmt::format("Unable to create a compiled function with name '{}'", sname)); - // LCOV_EXCL_STOP - } + auto *f = llvm_func_create(ft, llvm::Function::ExternalLinkage, sname, &md); // NOTE: a cfunc cannot call itself recursively. f->addFnAttr(llvm::Attribute::NoRecurse); @@ -1661,12 +1657,7 @@ auto add_cfunc_impl(llvm_state &s, const std::string &name, const F &fn, std::ui fargs.pop_back(); ft = llvm::FunctionType::get(builder.getVoidTy(), fargs, false); assert(ft != nullptr); // LCOV_EXCL_LINE - f = llvm::Function::Create(ft, llvm::Function::ExternalLinkage, name, &md); - if (f == nullptr) { - // LCOV_EXCL_START - throw std::invalid_argument(fmt::format("Unable to create a compiled function with name '{}'", name)); - // LCOV_EXCL_STOP - } + f = llvm_func_create(ft, llvm::Function::ExternalLinkage, name, &md); // Set the names/attributes of the function arguments. 
out_ptr = f->args().begin(); diff --git a/src/llvm_state.cpp b/src/llvm_state.cpp index 0f68c1892..3971a73d5 100644 --- a/src/llvm_state.cpp +++ b/src/llvm_state.cpp @@ -53,6 +53,7 @@ #include #include #include +#include #include #include #include @@ -130,6 +131,7 @@ #endif +#include #include #include #include @@ -249,7 +251,7 @@ target_features get_target_features_impl() // Machinery to initialise the native target in // LLVM. This needs to be done only once. // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) -std::once_flag nt_inited; +HEYOKA_CONSTINIT std::once_flag nt_inited; void init_native_target() { @@ -311,9 +313,6 @@ std::uint32_t recommended_simd_size() // Implementation of the jit class. struct llvm_state::jit { - // NOTE: this is the llvm_state containing - // the jit instance. - const llvm_state *m_state = nullptr; std::unique_ptr m_lljit; std::unique_ptr m_tm; std::unique_ptr m_ctx; @@ -322,7 +321,7 @@ struct llvm_state::jit { #endif std::optional m_object_file; - explicit jit(const llvm_state *state) : m_state(state) + jit() { // Ensure the native target is inited. detail::init_native_target(); @@ -353,8 +352,8 @@ struct llvm_state::jit { // LCOV_EXCL_STOP m_lljit = std::move(*lljit); - // Setup the machinery to cache the module's binary code - // when it is lazily generated. + // Setup the machinery to store the module's binary code + // when it is generated. m_lljit->getObjTransformLayer().setTransform([this](std::unique_ptr obj_buffer) { assert(obj_buffer); @@ -426,7 +425,6 @@ struct llvm_state::jit { #endif } - jit() = delete; jit(const jit &) = delete; jit(jit &&) = delete; jit &operator=(const jit &) = delete; @@ -512,7 +510,7 @@ namespace // Helper to load object code into a jit. template -void llvm_state_add_obj_to_jit(Jit &j, const std::string &obj) +void llvm_state_add_obj_to_jit(Jit &j, std::string obj) { llvm::SmallVector buffer(obj.begin(), obj.end()); auto err = j.m_lljit->addObjectFile(std::make_unique(std::move(buffer))); @@ -534,7 +532,7 @@ void llvm_state_add_obj_to_jit(Jit &j, const std::string &obj) // NOTE: this function at the moment is used when m_object_file // is supposed to be empty. assert(!j.m_object_file); - j.m_object_file.emplace(obj); + j.m_object_file.emplace(std::move(obj)); } // Helper to create an LLVM module from bitcode. @@ -575,7 +573,7 @@ auto llvm_state_bc_to_module(const std::string &module_name, const std::string & } // namespace detail llvm_state::llvm_state(std::tuple &&tup) - : m_jitter(std::make_unique(this)), m_opt_level(std::get<1>(tup)), m_fast_math(std::get<2>(tup)), + : m_jitter(std::make_unique()), m_opt_level(std::get<1>(tup)), m_fast_math(std::get<2>(tup)), m_force_avx512(std::get<3>(tup)), m_module_name(std::move(std::get<0>(tup))) { // Create the module. @@ -599,13 +597,13 @@ llvm_state::llvm_state(const llvm_state &other) // NOTE: start off by: // - creating a new jit, // - copying over the options from other. - : m_jitter(std::make_unique(this)), m_opt_level(other.m_opt_level), m_fast_math(other.m_fast_math), + : m_jitter(std::make_unique()), m_opt_level(other.m_opt_level), m_fast_math(other.m_fast_math), m_force_avx512(other.m_force_avx512), m_module_name(other.m_module_name) { - if (other.is_compiled() && other.m_jitter->m_object_file) { - // 'other' was compiled and code was generated. + if (other.is_compiled()) { + // 'other' was compiled. 
// We leave module and builder empty, copy over the - // IR/bitcode snapshots and add the cached compiled module + // IR/bitcode snapshots and add the compiled module // to the jit. m_ir_snapshot = other.m_ir_snapshot; m_bc_snapshot = other.m_bc_snapshot; @@ -613,68 +611,22 @@ llvm_state::llvm_state(const llvm_state &other) // NOLINTNEXTLINE(bugprone-unchecked-optional-access) detail::llvm_state_add_obj_to_jit(*m_jitter, *other.m_jitter->m_object_file); } else { - // 'other' has not been compiled yet, or - // it has been compiled but no code has been - // lazily generated yet. + // 'other' has not been compiled yet. // We will fetch its bitcode and reconstruct - // module and builder. - - // Is other compiled? - const auto other_cmp = other.is_compiled(); - - // Create the module from the bitcode. - // NOTE: branch to avoid expensive copy if other - // has been compiled. - if (other_cmp) { - m_module = detail::llvm_state_bc_to_module(m_module_name, other.m_bc_snapshot, context()); - } else { - m_module = detail::llvm_state_bc_to_module(m_module_name, other.get_bc(), context()); - } + // module and builder. The IR/bitcode snapshots + // are left in their default-constructed (empty) + // state. + m_module = detail::llvm_state_bc_to_module(m_module_name, other.get_bc(), context()); // Create a new builder for the module. m_builder = std::make_unique(context()); // Setup the math flags in the builder. ctor_setup_math_flags(); - - // Compile if needed. - // NOTE: compilation will take care of setting up m_ir_snapshot/m_bc_snapshot. - // If no compilation happens, m_ir_snapshot/m_bc_snapshot are left empty after init. - if (other_cmp) { - // NOTE: we need to temporarily disable optimisations - // before compilation, for the following reason. - // - // Recall that here we are in the case - // in which the 'other' llvm_state has been compiled, but - // no object code has been produced yet. This means the IR - // has already been optimised, and by running another optimisation - // pass now (indirectly, via compile()) we might end - // up modifying the already-optimised IR. - // By temporarily setting m_opt_level to zero, we are preventing - // any modification to the IR and ensuring that, after copying, - // we have exactly reproduced the original llvm_state object. - const auto orig_opt_level = m_opt_level; - m_opt_level = 0; - - compile(); - - // Restore the original optimisation level. - m_opt_level = orig_opt_level; - } } } -// NOTE: this needs to be implemented manually as we need -// to set up correctly the m_state member of the jit instance. -llvm_state::llvm_state(llvm_state &&other) noexcept - : m_jitter(std::move(other.m_jitter)), m_module(std::move(other.m_module)), m_builder(std::move(other.m_builder)), - m_opt_level(other.m_opt_level), m_ir_snapshot(std::move(other.m_ir_snapshot)), - m_bc_snapshot(std::move(other.m_bc_snapshot)), m_fast_math(other.m_fast_math), - m_force_avx512(other.m_force_avx512), m_module_name(std::move(other.m_module_name)) -{ - // Set up m_state. - m_jitter->m_state = this; -} +llvm_state::llvm_state(llvm_state &&) noexcept = default; llvm_state &llvm_state::operator=(const llvm_state &other) { @@ -689,7 +641,6 @@ llvm_state &llvm_state::operator=(const llvm_state &other) // needs to be done in a different order (specifically, we need to // ensure that the LLVM objects in this are destroyed in a specific // order). -// NOTE: we also need to set up correctly the m_state member of the jit instance. 
llvm_state &llvm_state::operator=(llvm_state &&other) noexcept { if (this != &other) { @@ -698,9 +649,6 @@ llvm_state &llvm_state::operator=(llvm_state &&other) noexcept m_module = std::move(other.m_module); m_jitter = std::move(other.m_jitter); - // Set up m_state. - m_jitter->m_state = this; - // The remaining bits. m_opt_level = other.m_opt_level; m_ir_snapshot = std::move(other.m_ir_snapshot); @@ -715,27 +663,29 @@ llvm_state &llvm_state::operator=(llvm_state &&other) noexcept llvm_state::~llvm_state() { - // NOTE: if this has not been moved-from, ensure - // the m_state member of the jit is pointing to this. + // Sanity checks in debug mode. if (m_jitter) { - assert(m_jitter->m_state == this); + if (is_compiled()) { + assert(m_jitter->m_object_file); + assert(!m_builder); + } else { + assert(!m_jitter->m_object_file); + assert(m_builder); + assert(m_ir_snapshot.empty()); + assert(m_bc_snapshot.empty()); + } } + + assert(m_opt_level <= 3u); } template void llvm_state::save_impl(Archive &ar, unsigned) const { - // Start by establishing if the state is compiled and binary - // code has been emitted. - // NOTE: we need both flags when deserializing. + // Start by establishing if the state is compiled. const auto cmp = is_compiled(); ar << cmp; - const auto with_obj = static_cast(m_jitter->m_object_file); - ar << with_obj; - - assert(!with_obj || cmp); - // Store the config options. ar << m_opt_level; ar << m_fast_math; @@ -752,22 +702,27 @@ void llvm_state::save_impl(Archive &ar, unsigned) const ar << get_bc(); } - if (with_obj) { - // Save the object file if available. + if (cmp) { + // Save the object file. // NOLINTNEXTLINE(bugprone-unchecked-optional-access) ar << *m_jitter->m_object_file; } // Save a copy of the IR snapshot if the state - // is compiled and binary code was emitted. + // is compiled. // NOTE: we want this because otherwise we would // need to re-parse the bitcode during des11n to // restore the IR snapshot. - if (cmp && with_obj) { + if (cmp) { ar << m_ir_snapshot; } } +// NOTE: currently loading from an archive won't interact with the +// memory cache - that is, if the archive contains a compiled module +// not in the cache *before* loading, it won't have been inserted in the cache +// *after* loading. I don't think this is an issue at the moment, but if needed +// we can always implement the feature at a later stage. template void llvm_state::load_impl(Archive &ar, unsigned version) { @@ -783,7 +738,7 @@ void llvm_state::load_impl(Archive &ar, unsigned version) // are primitive types, no need to reset the // addresses. - // Load the status flags from the archive. + // Load the compiled status flag from the archive. // NOTE: not sure why clang-tidy wants cmp to be // const here, as clearly ar >> cmp is going to // write something into it. Perhaps some const_cast @@ -792,12 +747,6 @@ void llvm_state::load_impl(Archive &ar, unsigned version) bool cmp{}; ar >> cmp; - // NOLINTNEXTLINE(misc-const-correctness) - bool with_obj{}; - ar >> with_obj; - - assert(!with_obj || cmp); - // Load the config options. // NOLINTNEXTLINE(misc-const-correctness) unsigned opt_level{}; @@ -821,14 +770,14 @@ void llvm_state::load_impl(Archive &ar, unsigned version) // Recover the object file, if available. std::optional obj_file; - if (with_obj) { + if (cmp) { obj_file.emplace(); ar >> *obj_file; } // Recover the IR snapshot, if available. 
std::string ir_snapshot; - if (cmp && with_obj) { + if (cmp) { ar >> ir_snapshot; } @@ -844,20 +793,18 @@ void llvm_state::load_impl(Archive &ar, unsigned version) m_builder.reset(); // Reset the jit with a new one. - m_jitter = std::make_unique(this); + m_jitter = std::make_unique(); - if (cmp && with_obj) { + if (cmp) { // Assign the snapshots. m_ir_snapshot = std::move(ir_snapshot); m_bc_snapshot = std::move(bc_snapshot); // Add the object code to the jit. // NOLINTNEXTLINE(bugprone-unchecked-optional-access) - detail::llvm_state_add_obj_to_jit(*m_jitter, *obj_file); + detail::llvm_state_add_obj_to_jit(*m_jitter, std::move(*obj_file)); } else { - // Clear the existing snapshots - // (they will be replaced with the - // actual ir/bitcode if compilation is needed). + // Clear the existing snapshots. m_ir_snapshot.clear(); m_bc_snapshot.clear(); @@ -869,32 +816,6 @@ void llvm_state::load_impl(Archive &ar, unsigned version) // Setup the math flags in the builder. ctor_setup_math_flags(); - - // Compile if needed. - // NOTE: compilation will take care of setting up m_ir_snapshot/m_bc_snapshot. - // If no compilation happens, m_ir_snapshot/m_bc_snapshot are left empty after - // clearing earlier. - if (cmp) { - // NOTE: we need to temporarily disable optimisations - // before compilation, for the following reason. - // - // Recall that here we are in the case - // in which the serialised llvm_state had been compiled, but - // no object code had been produced yet. This means the IR - // had already been optimised, and by running another optimisation - // pass (indirectly, via compile()) now we might end - // up modifying the already-optimised IR. - // By temporarily setting m_opt_level to zero, we are preventing - // any modification to the IR and ensuring that, after deserialisation, - // we have exactly reproduced the original llvm_state object. - const auto orig_opt_level = m_opt_level; - m_opt_level = 0; - - compile(); - - // Restore the original optimisation level. - m_opt_level = orig_opt_level; - } } // LCOV_EXCL_START } catch (...) { @@ -934,11 +855,6 @@ llvm::LLVMContext &llvm_state::context() return m_jitter->get_context(); } -unsigned &llvm_state::opt_level() -{ - return m_opt_level; -} - const llvm::Module &llvm_state::module() const { check_uncompiled(__func__); @@ -956,11 +872,16 @@ const llvm::LLVMContext &llvm_state::context() const return m_jitter->get_context(); } -const unsigned &llvm_state::opt_level() const +unsigned llvm_state::get_opt_level() const { return m_opt_level; } +void llvm_state::set_opt_level(unsigned opt_level) +{ + m_opt_level = clamp_opt_level(opt_level); +} + bool llvm_state::fast_math() const { return m_fast_math; @@ -971,6 +892,11 @@ bool llvm_state::force_avx512() const return m_force_avx512; } +unsigned llvm_state::clamp_opt_level(unsigned opt_level) +{ + return std::min(opt_level, 3u); +} + void llvm_state::check_uncompiled(const char *f) const { if (!m_module) { @@ -1239,6 +1165,59 @@ void llvm_state::optimise() } } +namespace detail +{ + +namespace +{ + +// The name of the function used to trigger the +// materialisation of object code. +constexpr auto obj_trigger_name = "heyoka.obj_trigger"; + +} // namespace + +} // namespace detail + +// NOTE: this adds a public no-op function to the state which is +// used to trigger the generation of object code after compilation. 
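+// NOTE: the trigger function is looked up in compile_impl() right after the
+// module has been added to the jit: the lookup forces LLJIT to materialise the
+// module, and materialisation in turn makes the object code available through
+// the object transform layer callback installed in the jit constructor.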
+void llvm_state::add_obj_trigger() +{ + auto &bld = builder(); + + auto *ft = llvm::FunctionType::get(bld.getVoidTy(), {}, false); + assert(ft != nullptr); + auto *f = detail::llvm_func_create(ft, llvm::Function::ExternalLinkage, detail::obj_trigger_name, &module()); + + bld.SetInsertPoint(llvm::BasicBlock::Create(context(), "entry", f)); + bld.CreateRetVoid(); +} + +// NOTE: this function is NOT exception-safe, proper cleanup +// needs to be done externally if needed. +void llvm_state::compile_impl() +{ + // Preconditions. + assert(m_module); + assert(m_builder); + assert(m_ir_snapshot.empty()); + assert(m_bc_snapshot.empty()); + + // Store a snapshot of the current IR and bitcode. + m_ir_snapshot = get_ir(); + m_bc_snapshot = get_bc(); + + // Add the module to the jit (this will clear out m_module). + m_jitter->add_module(std::move(m_module)); + + // Clear out the builder, which won't be usable any more. + m_builder.reset(); + + // Trigger object code materialisation via lookup. + jit_lookup(detail::obj_trigger_name); + assert(m_jitter->m_object_file); +} + // NOTE: we need to emphasise in the docs that compilation // triggers an optimisation pass. void llvm_state::compile() @@ -1251,24 +1230,56 @@ void llvm_state::compile() llvm::raw_string_ostream ostr(out); if (llvm::verifyModule(*m_module, &ostr)) { + // LCOV_EXCL_START throw std::runtime_error( fmt::format("The verification of the module '{}' produced an error:\n{}", m_module_name, ostr.str())); + // LCOV_EXCL_STOP } } + // Add the object materialisation trigger function. + // NOTE: do it **after** verification, on the assumption + // that add_obj_trigger() is implemented correctly. Like this, + // if module verification fails, the user still has the option + // to fix the module and re-attempt compilation without having + // altered the module and without having already added the trigger + // function. + add_obj_trigger(); + try { - // Run the optimisation pass. - optimise(); + // Fetch the bitcode *before* optimisation. + auto orig_bc = get_bc(); - // Store a snapshot of the optimised IR and bitcode before compiling. - m_ir_snapshot = get_ir(); - m_bc_snapshot = get_bc(); + // Combine m_opt_level and m_force_avx512 into a single value, + // as they both affect codegen. + assert(m_opt_level <= 3u); + const auto olevel = m_opt_level + (static_cast(m_force_avx512) << 2); - // Add the module (this will clear out m_module). - m_jitter->add_module(std::move(m_module)); + if (auto cached_data = detail::llvm_state_mem_cache_lookup(orig_bc, olevel)) { + // Cache hit. - // Clear out the builder, which won't be usable any more. - m_builder.reset(); + // Assign the snapshots. + m_ir_snapshot = std::move(cached_data->opt_ir); + m_bc_snapshot = std::move(cached_data->opt_bc); + + // Clear out module and builder. + m_module.reset(); + m_builder.reset(); + + // Assign the object file. + detail::llvm_state_add_obj_to_jit(*m_jitter, std::move(cached_data->obj)); + } else { + // Run the optimisation pass. + optimise(); + + // Run the compilation. + compile_impl(); + + // Try to insert orig_bc into the cache. + detail::llvm_state_mem_cache_try_insert(std::move(orig_bc), olevel, + // NOLINTNEXTLINE(bugprone-unchecked-optional-access) + {m_bc_snapshot, m_ir_snapshot, *m_jitter->m_object_file}); + } // LCOV_EXCL_START } catch (...) 
{ // Reset to a def-cted state in case of error, @@ -1285,11 +1296,6 @@ bool llvm_state::is_compiled() const return !m_module; } -bool llvm_state::has_object_code() const -{ - return static_cast(m_jitter->m_object_file); -} - // NOTE: this function will lookup symbol names, // so it does not necessarily return a function // pointer (could be, e.g., a global variable). @@ -1371,10 +1377,7 @@ const std::string &llvm_state::get_object_code() const "Cannot extract the object code from an llvm_state which has not been compiled yet"); } - if (!m_jitter->m_object_file) { - throw std::invalid_argument( - "Cannot extract the object code from an llvm_state if the binary code has not been generated yet"); - } + assert(m_jitter->m_object_file); // NOLINTNEXTLINE(bugprone-unchecked-optional-access) return *m_jitter->m_object_file; @@ -1401,7 +1404,6 @@ std::ostream &operator<<(std::ostream &os, const llvm_state &s) oss << "Module name : " << s.m_module_name << '\n'; oss << "Compiled : " << s.is_compiled() << '\n'; - oss << "Has object code : " << s.has_object_code() << '\n'; oss << "Fast math : " << s.m_fast_math << '\n'; oss << "Force AVX512 : " << s.m_force_avx512 << '\n'; oss << "Optimisation level: " << s.m_opt_level << '\n'; diff --git a/src/llvm_state_mem_cache.cpp b/src/llvm_state_mem_cache.cpp new file mode 100644 index 000000000..821aa0c66 --- /dev/null +++ b/src/llvm_state_mem_cache.cpp @@ -0,0 +1,280 @@ +// Copyright 2020, 2021, 2022, 2023 Francesco Biscani (bluescarni@gmail.com), Dario Izzo (dario.izzo@gmail.com) +// +// This file is part of the heyoka library. +// +// This Source Code Form is subject to the terms of the Mozilla +// Public License v. 2.0. If a copy of the MPL was not distributed +// with this file, You can obtain one at http://mozilla.org/MPL/2.0/. + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +#include +#include + +// This in-memory cache maps the bitcode +// of an LLVM module and an optimisation level to: +// +// - the optimised version of the bitcode, +// - the textual IR corresponding +// to the optimised bitcode, +// - the object code of the optimised bitcode. +// +// The cache invalidation policy is LRU, implemented +// by pairing a linked list to an unordered_map. + +HEYOKA_BEGIN_NAMESPACE + +namespace detail +{ + +namespace +{ + +// Global mutex for thread-safe operations. +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) +HEYOKA_CONSTINIT std::mutex mem_cache_mutex; + +// Definition of the data structures for the cache. +using lru_queue_t = std::list>; + +using lru_key_t = lru_queue_t::iterator; + +struct lru_hasher { + std::size_t operator()(const lru_key_t &k) const noexcept + { + auto seed = std::hash{}(k->first); + boost::hash_combine(seed, k->second); + return seed; + } +}; + +struct lru_cmp { + bool operator()(const lru_key_t &k1, const lru_key_t &k2) const noexcept + { + return *k1 == *k2; + } +}; + +// NOTE: use boost::unordered_map because we need heterogeneous lookup. +using lru_map_t = boost::unordered_map; + +// Global variables for the implementation of the cache. +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) +lru_queue_t lru_queue; + +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables,cert-err58-cpp) +lru_map_t lru_map; + +// Size of the cache. 
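+// This is the running total, in bytes, of the sizes of the optimised bitcode,
+// optimised IR and object code of all the entries currently in the cache.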
+// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) +HEYOKA_CONSTINIT std::size_t mem_cache_size = 0; + +// NOTE: default cache size limit is 2GB. +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) +HEYOKA_CONSTINIT std::uint64_t mem_cache_limit = 2147483648ull; + +// Machinery for heterogeneous lookup into the cache. +// NOTE: this function MUST be invoked while holding the global lock. +auto llvm_state_mem_cache_hl(const std::string &bc, unsigned opt_level) +{ + using compat_key_t = std::pair; + + struct compat_hasher { + std::size_t operator()(const compat_key_t &k) const noexcept + { + auto seed = std::hash{}(k.first); + boost::hash_combine(seed, k.second); + return seed; + } + }; + + struct compat_cmp { + bool operator()(const lru_key_t &k1, const compat_key_t &k2) const noexcept + { + return k1->first == k2.first && k1->second == k2.second; + } + bool operator()(const compat_key_t &k1, const lru_key_t &k2) const noexcept + { + return operator()(k2, k1); + } + }; + + return lru_map.find(std::make_pair(std::cref(bc), opt_level), compat_hasher{}, compat_cmp{}); +} + +// Debug function to run sanity checks on the cache. +// NOTE: this function MUST be invoked while holding the global lock. +void llvm_state_mem_cache_sanity_checks() +{ + assert(lru_queue.size() == lru_map.size()); + + // Check that the computed size of the cache is consistent with mem_cache_size. + assert(std::accumulate(lru_map.begin(), lru_map.end(), boost::safe_numerics::safe(0), + [](const auto &a, const auto &p) { + return a + p.second.opt_bc.size() + p.second.opt_ir.size() + p.second.obj.size(); + }) + == mem_cache_size); +} + +} // namespace + +std::optional llvm_state_mem_cache_lookup(const std::string &bc, unsigned opt_level) +{ + // Lock down. + const std::lock_guard lock(mem_cache_mutex); + + // Sanity checks. + llvm_state_mem_cache_sanity_checks(); + + if (const auto it = llvm_state_mem_cache_hl(bc, opt_level); it == lru_map.end()) { + // Cache miss. + return {}; + } else { + // Cache hit. + + // Move the item to the front of the queue, if needed. + if (const auto queue_it = it->first; queue_it != lru_queue.begin()) { + // NOTE: splice() won't throw. + lru_queue.splice(lru_queue.begin(), lru_queue, queue_it, std::next(queue_it)); + } + + return it->second; + } +} + +void llvm_state_mem_cache_try_insert(std::string bc, unsigned opt_level, llvm_mc_value val) +{ + // Lock down. + const std::lock_guard lock(mem_cache_mutex); + + // Sanity checks. + llvm_state_mem_cache_sanity_checks(); + + // Do a first lookup to check if bc is already in the cache. + // This could happen, e.g., if two threads are compiling the same + // code concurrently. + if (const auto it = llvm_state_mem_cache_hl(bc, opt_level); it != lru_map.end()) { + assert(val.opt_bc == it->second.opt_bc); + assert(val.opt_ir == it->second.opt_ir); + assert(val.obj == it->second.obj); + + return; + } + + // Compute the new cache size. + auto new_cache_size = static_cast(boost::safe_numerics::safe(mem_cache_size) + + val.opt_bc.size() + val.opt_ir.size() + val.obj.size()); + + // Remove items from the cache if we are exceeding + // the limit. + while (new_cache_size > mem_cache_limit && !lru_queue.empty()) { + // Compute the size of the last item in the queue. + const auto cur_it = lru_map.find(std::prev(lru_queue.end())); + assert(cur_it != lru_map.end()); + const auto &cur_val = cur_it->second; + // NOTE: no possibility of overflow here, as cur_size is guaranteed + // not to be greater than mem_cache_size. 
+ const auto cur_size + = static_cast(cur_val.opt_bc.size()) + cur_val.opt_ir.size() + cur_val.obj.size(); + + // NOTE: the next 4 lines cannot throw, which ensures that the + // cache cannot be left in an inconsistent state. + + // Remove the last item in the queue. + lru_map.erase(cur_it); + lru_queue.pop_back(); + + // Update new_cache_size and mem_cache_size. + new_cache_size -= cur_size; + mem_cache_size -= cur_size; + } + + if (new_cache_size > mem_cache_limit) { + // We cleared out the cache and yet insertion of + // bc would still exceed the limit. Exit. + assert(lru_queue.empty()); + assert(mem_cache_size == 0u); + + return; + } + + // Add the new item to the front of the queue. + // NOTE: if this throws, we have not modified lru_map yet, + // no cleanup needed. + lru_queue.emplace_front(std::move(bc), opt_level); + + // Add the new item to the map. + try { + const auto [new_it, ins_flag] = lru_map.emplace(lru_queue.begin(), std::move(val)); + assert(ins_flag); + + // Update mem_cache_size. + mem_cache_size = new_cache_size; + + // LCOV_EXCL_START + } catch (...) { + // Emplacement in lru_map failed, make sure to remove + // the item we just added to lru_queue before re-throwing. + lru_queue.pop_front(); + + throw; + } + // LCOV_EXCL_STOP +} + +} // namespace detail + +std::size_t llvm_state::get_memcache_size() +{ + // Lock down. + const std::lock_guard lock(detail::mem_cache_mutex); + + return detail::mem_cache_size; +} + +std::size_t llvm_state::get_memcache_limit() +{ + // Lock down. + const std::lock_guard lock(detail::mem_cache_mutex); + + return boost::numeric_cast(detail::mem_cache_limit); +} + +void llvm_state::set_memcache_limit(std::size_t new_limit) +{ + // Lock down. + const std::lock_guard lock(detail::mem_cache_mutex); + + detail::mem_cache_limit = boost::numeric_cast(new_limit); +} + +void llvm_state::clear_memcache() +{ + // Lock down. + const std::lock_guard lock(detail::mem_cache_mutex); + + // Sanity checks. + detail::llvm_state_mem_cache_sanity_checks(); + + detail::lru_map.clear(); + detail::lru_queue.clear(); + detail::mem_cache_size = 0; +} + +HEYOKA_END_NAMESPACE diff --git a/src/math/constants.cpp b/src/math/constants.cpp index 34e1563f1..406775b6c 100644 --- a/src/math/constants.cpp +++ b/src/math/constants.cpp @@ -345,7 +345,7 @@ llvm::Function *constant::taylor_c_diff_func(llvm_state &s, llvm::Type *fp_t, st } // NOLINTNEXTLINE(cert-err58-cpp) -const expression pi{func{constant{"pi", detail::pi_constant_func{}, u8"π"}}}; +const expression pi(func(constant("pi", detail::pi_constant_func{}, "π"))); HEYOKA_END_NAMESPACE diff --git a/src/taylor_00.cpp b/src/taylor_00.cpp index d5a92dc17..8e198f0cf 100644 --- a/src/taylor_00.cpp +++ b/src/taylor_00.cpp @@ -61,6 +61,7 @@ #endif #include +#include #include #include #include @@ -189,13 +190,7 @@ auto taylor_add_adaptive_step_with_events(llvm_state &s, const std::string &name auto *ft = llvm::FunctionType::get(builder.getVoidTy(), fargs, false); assert(ft != nullptr); // Now create the function. - auto *f = llvm::Function::Create(ft, llvm::Function::ExternalLinkage, name, &md); - // LCOV_EXCL_START - if (f == nullptr) { - throw std::invalid_argument( - fmt::format("Unable to create a function for an adaptive Taylor stepper with name '{}'", name)); - } - // LCOV_EXCL_STOP + auto *f = llvm_func_create(ft, llvm::Function::ExternalLinkage, name, &md); // NOTE: a step function cannot call itself recursively. 
f->addFnAttr(llvm::Attribute::NoRecurse); @@ -330,11 +325,7 @@ auto taylor_add_adaptive_step(llvm_state &s, const std::string &name, const U &s auto *ft = llvm::FunctionType::get(builder.getVoidTy(), fargs, false); assert(ft != nullptr); // Now create the function. - auto *f = llvm::Function::Create(ft, llvm::Function::ExternalLinkage, name, &md); - if (f == nullptr) { - throw std::invalid_argument( - fmt::format("Unable to create a function for an adaptive Taylor stepper with name '{}'", name)); - } + auto *f = llvm_func_create(ft, llvm::Function::ExternalLinkage, name, &md); // NOTE: a step function cannot call itself recursively. f->addFnAttr(llvm::Attribute::NoRecurse); diff --git a/src/taylor_01.cpp b/src/taylor_01.cpp index f3c4590f9..588b19efe 100644 --- a/src/taylor_01.cpp +++ b/src/taylor_01.cpp @@ -71,6 +71,7 @@ #endif #include +#include #include #include #include @@ -1760,15 +1761,8 @@ void taylor_add_d_out_function(llvm_state &s, llvm::Type *fp_scal_t, std::uint32 auto *ft = llvm::FunctionType::get(builder.getVoidTy(), fargs, false); assert(ft != nullptr); // LCOV_EXCL_LINE // Now create the function. - auto *f = llvm::Function::Create( - ft, external_linkage ? llvm::Function::ExternalLinkage : llvm::Function::InternalLinkage, "d_out_f", - &s.module()); - // LCOV_EXCL_START - if (f == nullptr) { - throw std::invalid_argument( - "Unable to create a function for the dense output in an adaptive Taylor integrator"); - } - // LCOV_EXCL_STOP + auto *f = llvm_func_create(ft, external_linkage ? llvm::Function::ExternalLinkage : llvm::Function::InternalLinkage, + "d_out_f", &s.module()); // Set the names/attributes of the function arguments. auto *out_ptr = f->args().begin(); @@ -1990,12 +1984,7 @@ void continuous_output::add_c_out_function(std::uint32_t order, std::uint32_t auto *ft = llvm::FunctionType::get(builder.getVoidTy(), fargs, false); assert(ft != nullptr); // LCOV_EXCL_LINE // Now create the function. - auto *f = llvm::Function::Create(ft, llvm::Function::ExternalLinkage, "c_out", &md); - // LCOV_EXCL_START - if (f == nullptr) { - throw std::invalid_argument("Unable to create a function for continuous output in a Taylor integrator"); - } - // LCOV_EXCL_STOP + auto *f = detail::llvm_func_create(ft, llvm::Function::ExternalLinkage, "c_out", &md); // Set the names/attributes of the function arguments. auto *out_ptr = f->args().begin(); @@ -2475,12 +2464,7 @@ void continuous_output_batch::add_c_out_function(std::uint32_t order, std::ui auto *ft = llvm::FunctionType::get(builder.getVoidTy(), fargs, false); assert(ft != nullptr); // LCOV_EXCL_LINE // Now create the function. - auto *f = llvm::Function::Create(ft, llvm::Function::ExternalLinkage, "c_out", &md); - // LCOV_EXCL_START - if (f == nullptr) { - throw std::invalid_argument("Unable to create a function for continuous output in a Taylor integrator"); - } - // LCOV_EXCL_STOP + auto *f = detail::llvm_func_create(ft, llvm::Function::ExternalLinkage, "c_out", &md); // Set the names/attributes of the function arguments. auto *out_ptr = f->args().begin(); diff --git a/src/taylor_02.cpp b/src/taylor_02.cpp index 2bb8231c4..4c0f9c217 100644 --- a/src/taylor_02.cpp +++ b/src/taylor_02.cpp @@ -59,6 +59,7 @@ #include #include +#include #include #include #include @@ -1845,11 +1846,7 @@ auto taylor_add_jet_impl(llvm_state &s, const std::string &name, const U &sys, s auto *ft = llvm::FunctionType::get(builder.getVoidTy(), fargs, false); assert(ft != nullptr); // LCOV_EXCL_LINE // Now create the function. 
- auto *f = llvm::Function::Create(ft, llvm::Function::ExternalLinkage, name, &md); - if (f == nullptr) { - throw std::invalid_argument(fmt::format( - "Unable to create a function for the computation of the jet of Taylor derivatives with name '{}'", name)); - } + auto *f = llvm_func_create(ft, llvm::Function::ExternalLinkage, name, &md); // NOTE: a jet function cannot call itself recursively. f->addFnAttr(llvm::Attribute::NoRecurse); diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index 3909efc90..1c512f18a 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -140,6 +140,7 @@ ADD_HEYOKA_TESTCASE(model_fixed_centres) ADD_HEYOKA_TESTCASE(model_rotating) ADD_HEYOKA_TESTCASE(model_mascon) ADD_HEYOKA_TESTCASE(step_callback) +ADD_HEYOKA_TESTCASE(llvm_state_mem_cache) if(HEYOKA_WITH_MPPP AND mp++_WITH_MPFR) ADD_HEYOKA_TESTCASE(event_detection_mp) diff --git a/test/constants.cpp b/test/constants.cpp index 100e94c38..3577bfe7b 100644 --- a/test/constants.cpp +++ b/test/constants.cpp @@ -130,7 +130,7 @@ TEST_CASE("pi stream") oss << heyoka::pi; - REQUIRE(oss.str() == u8"π"); + REQUIRE(oss.str() == "π"); } TEST_CASE("pi diff") diff --git a/test/llvm_state.cpp b/test/llvm_state.cpp index 2e369b675..b32b9e1ee 100644 --- a/test/llvm_state.cpp +++ b/test/llvm_state.cpp @@ -10,12 +10,20 @@ #include #include +#include #include #include #include #include #include +#include +#include +#include +#include +#include +#include + #if defined(HEYOKA_HAVE_REAL128) #include @@ -69,12 +77,27 @@ TEST_CASE("empty state") REQUIRE(!s.get_bc().empty()); REQUIRE(!s.get_ir().empty()); + REQUIRE(s.get_opt_level() == 3u); // Print also some info on the FP types. std::cout << "Double digits : " << std::numeric_limits::digits << '\n'; std::cout << "Long double digits: " << std::numeric_limits::digits << '\n'; } +TEST_CASE("opt level clamping") +{ + // Opt level clamping on construction. + auto s = llvm_state(kw::opt_level = 4u, kw::fast_math = true); + REQUIRE(s.get_opt_level() == 3u); + + // Opt level clamping on setter. + s = llvm_state{kw::mname = "foobarizer"}; + s.set_opt_level(0u); + REQUIRE(s.get_opt_level() == 0u); + s.set_opt_level(42u); + REQUIRE(s.get_opt_level() == 3u); +} + TEST_CASE("copy semantics") { auto [x, y] = make_vars("x", "y"); @@ -88,10 +111,9 @@ TEST_CASE("copy semantics") taylor_add_jet(s, "jet", {x * y, y * x}, 1, 1, true, false); REQUIRE(s.module_name() == "sample state"); - REQUIRE(s.opt_level() == 2u); + REQUIRE(s.get_opt_level() == 2u); REQUIRE(s.fast_math()); REQUIRE(!s.is_compiled()); - REQUIRE(!s.has_object_code()); const auto orig_ir = s.get_ir(); const auto orig_bc = s.get_bc(); @@ -99,10 +121,9 @@ TEST_CASE("copy semantics") auto s2 = s; REQUIRE(s2.module_name() == "sample state"); - REQUIRE(s2.opt_level() == 2u); + REQUIRE(s2.get_opt_level() == 2u); REQUIRE(s2.fast_math()); REQUIRE(!s2.is_compiled()); - REQUIRE(!s2.has_object_code()); REQUIRE(s2.get_ir() == orig_ir); REQUIRE(s2.get_bc() == orig_bc); @@ -119,52 +140,7 @@ TEST_CASE("copy semantics") REQUIRE(jet[3] == 6); } - // Compile, but don't generate code, and copy. - { - std::vector jet{2, 3, 0, 0}; - - llvm_state s{kw::mname = "sample state", kw::opt_level = 2u, kw::fast_math = true}; - - taylor_add_jet(s, "jet", {x * y, y * x}, 1, 1, true, false); - - // On-the-fly testing for string repr. 
- std::ostringstream oss; - oss << s; - const auto orig_repr = oss.str(); - - s.compile(); - - oss.str(""); - oss << s; - const auto compiled_repr = oss.str(); - - REQUIRE(orig_repr != compiled_repr); - - const auto orig_ir = s.get_ir(); - const auto orig_bc = s.get_bc(); - - auto s2 = s; - - REQUIRE(s2.module_name() == "sample state"); - REQUIRE(s2.opt_level() == 2u); - REQUIRE(s2.fast_math()); - REQUIRE(s2.is_compiled()); - REQUIRE(!s2.has_object_code()); - - REQUIRE(s2.get_ir() == orig_ir); - REQUIRE(s2.get_bc() == orig_bc); - - auto jptr = reinterpret_cast(s2.jit_lookup("jet")); - - jptr(jet.data(), nullptr, nullptr); - - REQUIRE(jet[0] == 2); - REQUIRE(jet[1] == 3); - REQUIRE(jet[2] == 6); - REQUIRE(jet[3] == 6); - } - - // Compile, generate code, and copy. + // Compile and copy. { std::vector jet{2, 3, 0, 0}; @@ -182,10 +158,9 @@ TEST_CASE("copy semantics") auto s2 = s; REQUIRE(s2.module_name() == "sample state"); - REQUIRE(s2.opt_level() == 2u); + REQUIRE(s2.get_opt_level() == 2u); REQUIRE(s2.fast_math()); REQUIRE(s2.is_compiled()); - REQUIRE(s2.has_object_code()); REQUIRE(s2.get_ir() == orig_ir); REQUIRE(s2.get_bc() == orig_bc); @@ -218,12 +193,6 @@ TEST_CASE("get object code") s.compile(); - REQUIRE_THROWS_MATCHES( - s.get_object_code(), std::invalid_argument, - Message("Cannot extract the object code from an llvm_state if the binary code has not been generated yet")); - - s.jit_lookup("jet"); - REQUIRE(!s.get_object_code().empty()); } } @@ -232,7 +201,7 @@ TEST_CASE("s11n") { auto [x, y] = make_vars("x", "y"); - // Def-cted state, no compilation, no object file. + // Def-cted state, no compilation. { std::stringstream ss; @@ -256,53 +225,15 @@ TEST_CASE("s11n") } REQUIRE(!s.is_compiled()); - REQUIRE(!s.has_object_code()); REQUIRE(s.get_ir() == orig_ir); REQUIRE(s.get_bc() == orig_bc); REQUIRE(s.module_name() == "foo"); - REQUIRE(s.opt_level() == 3u); + REQUIRE(s.get_opt_level() == 3u); REQUIRE(s.fast_math() == false); REQUIRE(s.force_avx512() == false); } - // Compiled state but without object file. - { - std::stringstream ss; - - llvm_state s{kw::force_avx512 = true, kw::mname = "foo"}; - - taylor_add_jet(s, "jet", {x * y, y * x}, 1, 1, true, false); - - s.compile(); - - const auto orig_ir = s.get_ir(); - const auto orig_bc = s.get_bc(); - - { - boost::archive::binary_oarchive oa(ss); - - oa << s; - } - - s = llvm_state{kw::mname = "sample state", kw::opt_level = 2u, kw::fast_math = true}; - - { - boost::archive::binary_iarchive ia(ss); - - ia >> s; - } - - REQUIRE(s.is_compiled()); - REQUIRE(!s.has_object_code()); - REQUIRE(s.module_name() == "foo"); - REQUIRE(s.get_ir() == orig_ir); - REQUIRE(s.get_bc() == orig_bc); - REQUIRE(s.opt_level() == 3u); - REQUIRE(s.fast_math() == false); - REQUIRE(s.force_avx512() == true); - } - - // Compiled state with object file. + // Compiled state. 
{ std::stringstream ss; @@ -315,8 +246,6 @@ TEST_CASE("s11n") const auto orig_ir = s.get_ir(); const auto orig_bc = s.get_bc(); - s.jit_lookup("jet"); - { boost::archive::binary_oarchive oa(ss); @@ -332,11 +261,10 @@ TEST_CASE("s11n") } REQUIRE(s.is_compiled()); - REQUIRE(s.has_object_code()); REQUIRE(s.get_ir() == orig_ir); REQUIRE(s.get_bc() == orig_bc); REQUIRE(s.module_name() == "foo"); - REQUIRE(s.opt_level() == 3u); + REQUIRE(s.get_opt_level() == 3u); REQUIRE(s.fast_math() == false); auto jptr = reinterpret_cast(s.jit_lookup("jet")); @@ -363,7 +291,7 @@ TEST_CASE("make_similar") s.compile(); REQUIRE(s.module_name() == "sample state"); - REQUIRE(s.opt_level() == 2u); + REQUIRE(s.get_opt_level() == 2u); REQUIRE(s.fast_math()); REQUIRE(s.is_compiled()); REQUIRE(s.force_avx512()); @@ -371,7 +299,7 @@ TEST_CASE("make_similar") auto s2 = s.make_similar(); REQUIRE(s2.module_name() == "sample state"); - REQUIRE(s2.opt_level() == 2u); + REQUIRE(s2.get_opt_level() == 2u); REQUIRE(s2.fast_math()); REQUIRE(s2.force_avx512()); REQUIRE(!s2.is_compiled()); @@ -418,3 +346,24 @@ TEST_CASE("force_avx512") REQUIRE(s5.force_avx512()); } } + +TEST_CASE("existing trigger") +{ + using Catch::Matchers::Message; + + llvm_state s; + + auto &bld = s.builder(); + + auto *ft = llvm::FunctionType::get(bld.getVoidTy(), {}, false); + auto *f = llvm::Function::Create(ft, llvm::Function::ExternalLinkage, "heyoka.obj_trigger", &s.module()); + + bld.SetInsertPoint(llvm::BasicBlock::Create(s.context(), "entry", f)); + bld.CreateRetVoid(); + + REQUIRE_THROWS_MATCHES(s.compile(), std::invalid_argument, + Message("Unable to create an LLVM function with name 'heyoka.obj_trigger'")); + + // Check that the second function was properly cleaned up. + REQUIRE(std::distance(s.module().begin(), s.module().end()) == 1); +} diff --git a/test/llvm_state_mem_cache.cpp b/test/llvm_state_mem_cache.cpp new file mode 100644 index 000000000..02d8525a6 --- /dev/null +++ b/test/llvm_state_mem_cache.cpp @@ -0,0 +1,123 @@ +// Copyright 2020, 2021, 2022, 2023 Francesco Biscani (bluescarni@gmail.com), Dario Izzo (dario.izzo@gmail.com) +// +// This file is part of the heyoka library. +// +// This Source Code Form is subject to the terms of the Mozilla +// Public License v. 2.0. If a copy of the MPL was not distributed +// with this file, You can obtain one at http://mozilla.org/MPL/2.0/. 
+ +#include +#include +#include + +#include "catch.hpp" + +using namespace heyoka; + +TEST_CASE("basic") +{ + REQUIRE(llvm_state::get_memcache_size() == 0u); + REQUIRE(llvm_state::get_memcache_limit() > 0u); + + auto ta = taylor_adaptive{model::pendulum(), {1., 0.}}; + + auto cache_size = llvm_state::get_memcache_size(); + REQUIRE(cache_size > 0u); + + ta = taylor_adaptive{model::pendulum(), {1., 0.}}; + + REQUIRE(llvm_state::get_memcache_size() == cache_size); + + ta = taylor_adaptive{model::pendulum(), {1., 0.}, kw::opt_level = 2u}; + + REQUIRE(llvm_state::get_memcache_size() > cache_size); + cache_size = llvm_state::get_memcache_size(); + + ta = taylor_adaptive{model::pendulum(), {1., 0.}, kw::opt_level = 2u, kw::tol = 1e-12}; + REQUIRE(llvm_state::get_memcache_size() > cache_size); + + llvm_state::clear_memcache(); + REQUIRE(llvm_state::get_memcache_size() == 0u); + + ta = taylor_adaptive{model::pendulum(), {1., 0.}}; + cache_size = llvm_state::get_memcache_size(); + + llvm_state::set_memcache_limit(cache_size); + ta = taylor_adaptive{model::pendulum(), {1., 0.}, kw::tol = 1e-12}; + + REQUIRE(llvm_state::get_memcache_size() < cache_size); + + llvm_state::set_memcache_limit(0); + + ta = taylor_adaptive{model::pendulum(), {1., 0.}, kw::tol = 1e-11}; + + REQUIRE(llvm_state::get_memcache_size() == 0u); +} + +TEST_CASE("priority") +{ + // Check that the least recently used items are evicted first. + llvm_state::clear_memcache(); + llvm_state::set_memcache_limit(2048ull * 1024u * 1024u); + + auto ta = taylor_adaptive{model::pendulum(), {1., 0.}, kw::tol = 1e-11}; + const auto size11 = llvm_state::get_memcache_size(); + + llvm_state::clear_memcache(); + ta = taylor_adaptive{model::pendulum(), {1., 0.}, kw::tol = 1e-15}; + const auto size15 = llvm_state::get_memcache_size(); + + ta = taylor_adaptive{model::pendulum(), {1., 0.}, kw::tol = 1e-12}; + const auto size12 = llvm_state::get_memcache_size() - size15; + + llvm_state::set_memcache_limit(llvm_state::get_memcache_size()); + ta = taylor_adaptive{model::pendulum(), {1., 0.}, kw::tol = 1e-11}; + REQUIRE(llvm_state::get_memcache_size() == size12 + size11); + + // Check that cache hit moves element to the front. + llvm_state::clear_memcache(); + llvm_state::set_memcache_limit(2048ull * 1024u * 1024u); + + ta = taylor_adaptive{model::pendulum(), {1., 0.}, kw::tol = 1e-15}; + ta = taylor_adaptive{model::pendulum(), {1., 0.}, kw::tol = 1e-12}; + ta = taylor_adaptive{model::pendulum(), {1., 0.}, kw::tol = 1e-15}; + + llvm_state::set_memcache_limit(llvm_state::get_memcache_size()); + ta = taylor_adaptive{model::pendulum(), {1., 0.}, kw::tol = 1e-11}; + REQUIRE(llvm_state::get_memcache_size() == size15 + size11); +} + +// A test to check that the cache shrinks at the first +// insertion attempt after set_memcache_limit(). 
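+// NOTE: lowering the limit does not evict anything by itself: eviction happens
+// lazily inside llvm_state_mem_cache_try_insert(), when the next insertion
+// would push the cache size past the new limit.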
+TEST_CASE("shrink test") +{ + llvm_state::clear_memcache(); + llvm_state::set_memcache_limit(2048ull * 1024u * 1024u); + + auto ta = taylor_adaptive{model::pendulum(), {1., 0.}, kw::tol = 1e-11}; + const auto size11 = llvm_state::get_memcache_size(); + + llvm_state::clear_memcache(); + ta = taylor_adaptive{model::pendulum(), {1., 0.}, kw::tol = 1e-15}; + ta = taylor_adaptive{model::pendulum(), {1., 0.}, kw::tol = 1e-12}; + const auto cache_size = llvm_state::get_memcache_size(); + + llvm_state::set_memcache_limit(size11); + REQUIRE(llvm_state::get_memcache_size() == cache_size); + ta = taylor_adaptive{model::pendulum(), {1., 0.}, kw::tol = 1e-11}; + REQUIRE(llvm_state::get_memcache_size() == size11); +} + +// A test to check that the force_avx512 flag is taken +// into account when interacting with the cache. +TEST_CASE("force_avx512 test") +{ + llvm_state::clear_memcache(); + llvm_state::set_memcache_limit(2048ull * 1024u * 1024u); + + auto ta = taylor_adaptive{model::pendulum(), {1., 0.}, kw::tol = 1e-11}; + const auto size11 = llvm_state::get_memcache_size(); + + ta = taylor_adaptive{model::pendulum(), {1., 0.}, kw::tol = 1e-11, kw::force_avx512 = true}; + REQUIRE(llvm_state::get_memcache_size() > size11); +} diff --git a/test/opt_checks.cpp b/test/opt_checks.cpp index 6adbf366c..1ec73f143 100644 --- a/test/opt_checks.cpp +++ b/test/opt_checks.cpp @@ -34,5 +34,5 @@ TEST_CASE("function inlining") ++count; } - REQUIRE(count == 2u); + REQUIRE(count == 3u); }
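A minimal usage sketch of the new cache-management API, assuming the taylor_adaptive<double>/model::pendulum() setup used in the new test file; any llvm_state compilation interacts with the cache in the same way:

#include <iostream>

#include <heyoka/heyoka.hpp>

using namespace heyoka;

int main()
{
    // First compilation: cache miss, the optimised bitcode/IR and the object
    // file produced by the compilation are inserted into the in-memory cache.
    auto ta0 = taylor_adaptive<double>{model::pendulum(), {1., 0.}};
    std::cout << "Cache size after the first compilation : " << llvm_state::get_memcache_size() << '\n';

    // Identical module and optimisation level: cache hit, the optimised code
    // is fetched from the cache and the cache size does not grow.
    auto ta1 = taylor_adaptive<double>{model::pendulum(), {1., 0.}};
    std::cout << "Cache size after the second compilation: " << llvm_state::get_memcache_size() << '\n';

    // The cache limit can be changed (the default is 2GB) and the cache
    // can be cleared globally.
    llvm_state::set_memcache_limit(512ull * 1024u * 1024u);
    llvm_state::clear_memcache();
}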