From 4834d3e767935c90d77c12772c5d57aea3b08bfd Mon Sep 17 00:00:00 2001 From: Dmitriy Sobolev Date: Fri, 7 Feb 2025 17:25:14 +0000 Subject: [PATCH] Refactor: __mem_adjusted_barrier --- .../parallel_backend_sycl_radix_sort_one_wg.h | 24 +++++++------------ 1 file changed, 9 insertions(+), 15 deletions(-) diff --git a/include/oneapi/dpl/pstl/hetero/dpcpp/parallel_backend_sycl_radix_sort_one_wg.h b/include/oneapi/dpl/pstl/hetero/dpcpp/parallel_backend_sycl_radix_sort_one_wg.h index f4327d55439..d288ba9ac6a 100644 --- a/include/oneapi/dpl/pstl/hetero/dpcpp/parallel_backend_sycl_radix_sort_one_wg.h +++ b/include/oneapi/dpl/pstl/hetero/dpcpp/parallel_backend_sycl_radix_sort_one_wg.h @@ -173,14 +173,8 @@ struct __subgroup_radix_sort uint16_t __begin_bit = 0; constexpr uint16_t __end_bit = sizeof(_KeyT) * ::std::numeric_limits<unsigned char>::digits; - auto __val_mem_adjusted_barrier = [__it]() { - if constexpr (_SLM_tag_val::value) - __dpl_sycl::__group_barrier(__it); - else - __dpl_sycl::__group_barrier(__it, __dpl_sycl::__fence_space_global_and_local{}); - }; - auto __count_mem_adjusted_barrier = [__it]() { - if constexpr (_SLM_counter::value) + auto __mem_adjusted_barrier = [__it](auto __is_slm) { + if constexpr (decltype(__is_slm)::value) __dpl_sycl::__group_barrier(__it); else __dpl_sycl::__group_barrier(__it, __dpl_sycl::__fence_space_global_and_local{}); @@ -188,7 +182,7 @@ struct __subgroup_radix_sort //copy(move) values construction __block_load<_ValT>(__wi, __src, __values.__v, __n); - __val_mem_adjusted_barrier(); // TODO: check if the barrier can be removed + __mem_adjusted_barrier(_SLM_tag_val{}); // TODO: check if the barrier can be removed while (true) { @@ -218,7 +212,7 @@ struct __subgroup_radix_sort __indices[__i] = *__counters[__i]; *__counters[__i] = __indices[__i] + 1; } - __count_mem_adjusted_barrier(); + __mem_adjusted_barrier(_SLM_counter{}); //2. 
scan phase { @@ -231,7 +225,7 @@ struct __subgroup_radix_sort _ONEDPL_PRAGMA_UNROLL for (uint16_t __i = 1; __i < __bin_count; ++__i) __bin_sum[__i] = __bin_sum[__i - 1] + __counter_lacc[__wi * __bin_count + __i]; - __count_mem_adjusted_barrier(); + __mem_adjusted_barrier(_SLM_counter{}); //exclusive scan local sum uint16_t __sum_scan = __dpl_sycl::__exclusive_scan_over_group( @@ -243,7 +237,7 @@ struct __subgroup_radix_sort if (__wi == 0) __counter_lacc[0] = 0; - __count_mem_adjusted_barrier(); + __mem_adjusted_barrier(_SLM_counter{}); } _ONEDPL_PRAGMA_UNROLL @@ -257,7 +251,7 @@ struct __subgroup_radix_sort __begin_bit += __radix; //3. "re-order" phase - __dpl_sycl::__group_barrier(__it); + __mem_adjusted_barrier(_SLM_tag_val{}); if (__begin_bit >= __end_bit) { // the last iteration - writing out the result @@ -305,7 +299,7 @@ struct __subgroup_radix_sort __exchange_lacc[__r] = ::std::move(__values.__v[__i]); } } - __val_mem_adjusted_barrier(); + __mem_adjusted_barrier(_SLM_tag_val{}); _ONEDPL_PRAGMA_UNROLL for (uint16_t __i = 0; __i < __block_size; ++__i) @@ -314,7 +308,7 @@ struct __subgroup_radix_sort if (__idx < __n) __values.__v[__i] = ::std::move(__exchange_lacc[__idx]); } - __val_mem_adjusted_barrier(); + __mem_adjusted_barrier(_SLM_tag_val{}); } })); });