diff --git a/CMakeLists.txt b/CMakeLists.txt
index 92c886dfc..c9580973d 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -199,6 +199,7 @@ target_link_libraries(xtensor INTERFACE xtl)
 
 OPTION(XTENSOR_ENABLE_ASSERT "xtensor bound check" OFF)
 OPTION(XTENSOR_CHECK_DIMENSION "xtensor dimension check" OFF)
+OPTION(XTENSOR_FORCE_TEMPORARY_MEMORY_IN_ASSIGNMENTS "xtensor force the use of temporary memory when assigning instead of an automatic overlap check" ON)
 OPTION(BUILD_TESTS "xtensor test suite" OFF)
 OPTION(BUILD_BENCHMARK "xtensor benchmark" OFF)
 OPTION(DOWNLOAD_GTEST "build gtest from downloaded sources" OFF)
@@ -219,6 +220,10 @@ if(XTENSOR_CHECK_DIMENSION)
     add_definitions(-DXTENSOR_ENABLE_CHECK_DIMENSION)
 endif()
 
+if(XTENSOR_FORCE_TEMPORARY_MEMORY_IN_ASSIGNMENTS)
+    add_definitions(-DXTENSOR_FORCE_TEMPORARY_MEMORY_IN_ASSIGNMENTS)
+endif()
+
 if(DEFAULT_COLUMN_MAJOR)
     add_definitions(-DXTENSOR_DEFAULT_LAYOUT=layout_type::column_major)
 endif()
diff --git a/include/xtensor/xbroadcast.hpp b/include/xtensor/xbroadcast.hpp
index 798b9cc9d..20b04edab 100644
--- a/include/xtensor/xbroadcast.hpp
+++ b/include/xtensor/xbroadcast.hpp
@@ -118,6 +118,34 @@ namespace xt
         return linear_end(c.expression());
     }
 
+    /*************************************
+     * overlapping_memory_checker_traits *
+     *************************************/
+
+    /**
+     * Overlap check for xbroadcast expressions that expose no memory address.
+     * An empty broadcast is reported as non-overlapping; otherwise the answer
+     * is that of the broadcasted (wrapped) expression, queried recursively.
+     */
+    template <class E>
+    struct overlapping_memory_checker_traits<
+        E,
+        std::enable_if_t<!has_memory_address<E>::value && is_specialization_of<xbroadcast, E>::value>>
+    {
+        static bool check_overlap(const E& expr, const memory_range& dst_range)
+        {
+            if (expr.size() == 0)
+            {
+                return false;
+            }
+            else
+            {
+                using ChildE = std::decay_t<decltype(expr.expression())>;
+                return overlapping_memory_checker_traits<ChildE>::check_overlap(expr.expression(), dst_range);
+            }
+        }
+    };
+
     /**
      * @class xbroadcast
      * @brief Broadcasted xexpression to a specified shape.
diff --git a/include/xtensor/xfunction.hpp b/include/xtensor/xfunction.hpp
index 08a3dc1c1..f11362cdb 100644
--- a/include/xtensor/xfunction.hpp
+++ b/include/xtensor/xfunction.hpp
@@ -162,6 +162,49 @@ namespace xt
     {
     };
 
+    /*************************************
+     * overlapping_memory_checker_traits *
+     *************************************/
+
+    /**
+     * Overlap check for xfunction expressions that expose no memory address.
+     * An empty function is reported as non-overlapping; otherwise the result
+     * is true as soon as one of the function's arguments reports an overlap.
+     */
+    template <class E>
+    struct overlapping_memory_checker_traits<
+        E,
+        std::enable_if_t<!has_memory_address<E>::value && is_specialization_of<xfunction, E>::value>>
+    {
+        // Recursion end: index I is past the last tuple element, nothing left to check.
+        template <std::size_t I = 0, class... T, std::enable_if_t<(I == sizeof...(T)), int> = 0>
+        static bool check_tuple(const std::tuple<T...>&, const memory_range&)
+        {
+            return false;
+        }
+
+        // Check tuple element I, then (short-circuiting) the elements after it.
+        template <std::size_t I = 0, class... T, std::enable_if_t<(I < sizeof...(T)), int> = 0>
+        static bool check_tuple(const std::tuple<T...>& t, const memory_range& dst_range)
+        {
+            using ChildE = std::decay_t<decltype(std::get<I>(t))>;
+            return overlapping_memory_checker_traits<ChildE>::check_overlap(std::get<I>(t), dst_range)
+                   || check_tuple<I + 1>(t, dst_range);
+        }
+
+        static bool check_overlap(const E& expr, const memory_range& dst_range)
+        {
+            if (expr.size() == 0)
+            {
+                return false;
+            }
+            else
+            {
+                return check_tuple(expr.arguments(), dst_range);
+            }
+        }
+    };
+
     /*************
      * xfunction *
      *************/
diff --git a/include/xtensor/xgenerator.hpp b/include/xtensor/xgenerator.hpp
index 551bb7e24..03433adca 100644
--- a/include/xtensor/xgenerator.hpp
+++ b/include/xtensor/xgenerator.hpp
@@ -76,6 +76,25 @@ namespace xt
         using size_type = std::size_t;
     };
 
+    /*************************************
+     * overlapping_memory_checker_traits *
+     *************************************/
+
+    /**
+     * Overlap check for xgenerator expressions that expose no memory address.
+     * Always reports "no overlap", whatever the destination range is.
+     */
+    template <class E>
+    struct overlapping_memory_checker_traits<
+        E,
+        std::enable_if_t<!has_memory_address<E>::value && is_specialization_of<xgenerator, E>::value>>
+    {
+        static bool check_overlap(const E&, const memory_range&)
+        {
+            return false;
+        }
+    };
+
     /**
      * @class xgenerator
      * @brief Multidimensional function operating on indices.
diff --git a/include/xtensor/xsemantic.hpp b/include/xtensor/xsemantic.hpp
index 41f14951c..8aa76cfc9 100644
--- a/include/xtensor/xsemantic.hpp
+++ b/include/xtensor/xsemantic.hpp
@@ -217,6 +217,34 @@ namespace xt
     template <class E, class R = void>
     using disable_xcontainer_semantics = typename std::enable_if<!has_container_semantics<E>::value, R>::type;
 
+
+    template <class D>
+    class xview_semantic;
+
+    /**
+     * Overlap check for expressions deriving from xview_semantic (CRTP) that
+     * expose no memory address. An empty view is reported as non-overlapping;
+     * otherwise the answer is that of the viewed expression, queried recursively.
+     */
+    template <class E>
+    struct overlapping_memory_checker_traits<
+        E,
+        std::enable_if_t<!has_memory_address<E>::value && is_crtp_base_of<xview_semantic, E>::value>>
+    {
+        static bool check_overlap(const E& expr, const memory_range& dst_range)
+        {
+            if (expr.size() == 0)
+            {
+                return false;
+            }
+            else
+            {
+                using ChildE = std::decay_t<decltype(expr.expression())>;
+                return overlapping_memory_checker_traits<ChildE>::check_overlap(expr.expression(), dst_range);
+            }
+        }
+    };
+
     /**
      * @class xview_semantic
      * @brief Implementation of the xsemantic_base interface for
@@ -598,8 +626,24 @@ namespace xt
     template <class E>
     inline auto xsemantic_base<D>::operator=(const xexpression<E>& e) -> derived_type&
     {
+#ifdef XTENSOR_FORCE_TEMPORARY_MEMORY_IN_ASSIGNMENTS
         temporary_type tmp(e);
         return this->derived_cast().assign_temporary(std::move(tmp));
+#else
+        // Build a temporary only when the checker reports that the right-hand
+        // side overlaps the memory of this expression; otherwise assign directly.
+        auto&& this_derived = this->derived_cast();
+        auto memory_checker = make_overlapping_memory_checker(this_derived);
+        if (memory_checker.check_overlap(e.derived_cast()))
+        {
+            temporary_type tmp(e);
+            return this_derived.assign_temporary(std::move(tmp));
+        }
+        else
+        {
+            return this->assign(e);
+        }
+#endif
     }
 
     /**************************************
diff --git a/include/xtensor/xutils.hpp b/include/xtensor/xutils.hpp
index 137d0e70e..21c452489 100644
--- a/include/xtensor/xutils.hpp
+++ b/include/xtensor/xutils.hpp
@@ -119,6 +119,20 @@ namespace xt
         using type = T;
     };
 
+    /***************************************
+     * is_specialization_of implementation *
+     ***************************************/
+
+    template