Skip to content

Commit

Permalink
Refactor: __mem_adjusted_barrier
Browse files Browse the repository at this point in the history
  • Loading branch information
dmitriy-sobolev committed Feb 7, 2025
1 parent 8a5d554 commit 4834d3e
Showing 1 changed file with 9 additions and 15 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -173,22 +173,16 @@ struct __subgroup_radix_sort
uint16_t __begin_bit = 0;
constexpr uint16_t __end_bit = sizeof(_KeyT) * ::std::numeric_limits<unsigned char>::digits;

auto __val_mem_adjusted_barrier = [__it]() {
if constexpr (_SLM_tag_val::value)
__dpl_sycl::__group_barrier(__it);
else
__dpl_sycl::__group_barrier(__it, __dpl_sycl::__fence_space_global_and_local{});
};
auto __count_mem_adjusted_barrier = [__it]() {
if constexpr (_SLM_counter::value)
auto __mem_adjusted_barrier = [__it](auto __is_slm) {
if constexpr (decltype(__is_slm)::value)
__dpl_sycl::__group_barrier(__it);
else
__dpl_sycl::__group_barrier(__it, __dpl_sycl::__fence_space_global_and_local{});
};

//copy(move) values construction
__block_load<_ValT>(__wi, __src, __values.__v, __n);
__val_mem_adjusted_barrier(); // TODO: check if the barrier can be removed
__mem_adjusted_barrier(_SLM_tag_val{}); // TODO: check if the barrier can be removed

while (true)
{
Expand Down Expand Up @@ -218,7 +212,7 @@ struct __subgroup_radix_sort
__indices[__i] = *__counters[__i];
*__counters[__i] = __indices[__i] + 1;
}
__count_mem_adjusted_barrier();
__mem_adjusted_barrier(_SLM_counter{});

//2. scan phase
{
Expand All @@ -231,7 +225,7 @@ struct __subgroup_radix_sort
_ONEDPL_PRAGMA_UNROLL
for (uint16_t __i = 1; __i < __bin_count; ++__i)
__bin_sum[__i] = __bin_sum[__i - 1] + __counter_lacc[__wi * __bin_count + __i];
__count_mem_adjusted_barrier();
__mem_adjusted_barrier(_SLM_counter{});

//exclusive scan local sum
uint16_t __sum_scan = __dpl_sycl::__exclusive_scan_over_group(
Expand All @@ -243,7 +237,7 @@ struct __subgroup_radix_sort

if (__wi == 0)
__counter_lacc[0] = 0;
__count_mem_adjusted_barrier();
__mem_adjusted_barrier(_SLM_counter{});
}

_ONEDPL_PRAGMA_UNROLL
Expand All @@ -257,7 +251,7 @@ struct __subgroup_radix_sort
__begin_bit += __radix;

//3. "re-order" phase
__dpl_sycl::__group_barrier(__it);
__mem_adjusted_barrier(_SLM_tag_val{});
if (__begin_bit >= __end_bit)
{
// the last iteration - writing out the result
Expand Down Expand Up @@ -305,7 +299,7 @@ struct __subgroup_radix_sort
__exchange_lacc[__r] = ::std::move(__values.__v[__i]);
}
}
__val_mem_adjusted_barrier();
__mem_adjusted_barrier(_SLM_tag_val{});

_ONEDPL_PRAGMA_UNROLL
for (uint16_t __i = 0; __i < __block_size; ++__i)
Expand All @@ -314,7 +308,7 @@ struct __subgroup_radix_sort
if (__idx < __n)
__values.__v[__i] = ::std::move(__exchange_lacc[__idx]);
}
__val_mem_adjusted_barrier();
__mem_adjusted_barrier(_SLM_tag_val{});
}
}));
});
Expand Down

0 comments on commit 4834d3e

Please sign in to comment.