Skip to content

Commit

Permalink
Improve kernel name generation for unnamed lambda (kernel templates) (#…
Browse files Browse the repository at this point in the history
…1524)

Signed-off-by: Dmitriy Sobolev <[email protected]>
  • Loading branch information
dmitriy-sobolev authored Nov 22, 2024
1 parent 3436938 commit 36492bf
Show file tree
Hide file tree
Showing 6 changed files with 43 additions and 39 deletions.
18 changes: 9 additions & 9 deletions test/kt/esimd_radix_sort.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -185,13 +185,13 @@ template <typename T, bool IsAscending, std::uint8_t RadixBits, typename KernelP
void
test_general_cases(sycl::queue q, std::size_t size, KernelParam param)
{
test_usm<T, IsAscending, RadixBits, sycl::usm::alloc::shared>(q, size, TestUtils::get_new_kernel_params<0>(param));
test_usm<T, IsAscending, RadixBits, sycl::usm::alloc::device>(q, size, TestUtils::get_new_kernel_params<1>(param));
test_sycl_iterators<T, IsAscending, RadixBits>(q, size, TestUtils::get_new_kernel_params<2>(param));
test_sycl_buffer<T, IsAscending, RadixBits>(q, size, TestUtils::get_new_kernel_params<3>(param));
test_usm<T, IsAscending, RadixBits, sycl::usm::alloc::shared>(q, size, TestUtils::create_new_kernel_param_idx<0>(param));
test_usm<T, IsAscending, RadixBits, sycl::usm::alloc::device>(q, size, TestUtils::create_new_kernel_param_idx<1>(param));
test_sycl_iterators<T, IsAscending, RadixBits>(q, size, TestUtils::create_new_kernel_param_idx<2>(param));
test_sycl_buffer<T, IsAscending, RadixBits>(q, size, TestUtils::create_new_kernel_param_idx<3>(param));
#if _ENABLE_RANGES_TESTING
test_all_view<T, IsAscending, RadixBits>(q, size, TestUtils::get_new_kernel_params<4>(param));
test_subrange_view<T, IsAscending, RadixBits>(q, size, TestUtils::get_new_kernel_params<5>(param));
test_all_view<T, IsAscending, RadixBits>(q, size, TestUtils::create_new_kernel_param_idx<4>(param));
test_subrange_view<T, IsAscending, RadixBits>(q, size, TestUtils::create_new_kernel_param_idx<5>(param));
#endif // _ENABLE_RANGES_TESTING
}

Expand All @@ -217,11 +217,11 @@ main()
for (auto size : sort_sizes)
{
test_general_cases<TEST_KEY_TYPE, Ascending, TestRadixBits>(
q, size, TestUtils::get_new_kernel_params<0>(params));
q, size, TestUtils::create_new_kernel_param_idx<0>(params));
test_general_cases<TEST_KEY_TYPE, Descending, TestRadixBits>(
q, size, TestUtils::get_new_kernel_params<1>(params));
q, size, TestUtils::create_new_kernel_param_idx<1>(params));
}
test_small_sizes<TEST_KEY_TYPE, Ascending, TestRadixBits>(q, TestUtils::get_new_kernel_params<3>(params));
test_small_sizes<TEST_KEY_TYPE, Ascending, TestRadixBits>(q, TestUtils::create_new_kernel_param_idx<3>(params));
}
catch (const ::std::exception& exc)
{
Expand Down
8 changes: 4 additions & 4 deletions test/kt/esimd_radix_sort_by_key.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -103,13 +103,13 @@ int main()
for (auto size : sort_sizes)
{
test_usm<TEST_KEY_TYPE, TEST_VALUE_TYPE, Ascending, TestRadixBits, sycl::usm::alloc::shared>(
q, size, TestUtils::get_new_kernel_params<0>(params));
q, size, TestUtils::create_new_kernel_param_idx<0>(params));
test_usm<TEST_KEY_TYPE, TEST_VALUE_TYPE, Descending, TestRadixBits, sycl::usm::alloc::shared>(
q, size, TestUtils::get_new_kernel_params<1>(params));
q, size, TestUtils::create_new_kernel_param_idx<1>(params));
test_sycl_buffer<TEST_KEY_TYPE, TEST_VALUE_TYPE, Ascending, TestRadixBits>(
q, size, TestUtils::get_new_kernel_params<2>(params));
q, size, TestUtils::create_new_kernel_param_idx<2>(params));
test_sycl_buffer<TEST_KEY_TYPE, TEST_VALUE_TYPE, Descending, TestRadixBits>(
q, size, TestUtils::get_new_kernel_params<3>(params));
q, size, TestUtils::create_new_kernel_param_idx<3>(params));
}
}
catch (const ::std::exception& exc)
Expand Down
8 changes: 4 additions & 4 deletions test/kt/esimd_radix_sort_by_key_out_of_place.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -148,13 +148,13 @@ main()
for (auto size : sort_sizes)
{
test_usm<TEST_KEY_TYPE, TEST_VALUE_TYPE, Ascending, TestRadixBits, sycl::usm::alloc::shared>(
q, size, TestUtils::get_new_kernel_params<0>(params));
q, size, TestUtils::create_new_kernel_param_idx<0>(params));
test_usm<TEST_KEY_TYPE, TEST_VALUE_TYPE, Descending, TestRadixBits, sycl::usm::alloc::shared>(
q, size, TestUtils::get_new_kernel_params<1>(params));
q, size, TestUtils::create_new_kernel_param_idx<1>(params));
test_sycl_buffer<TEST_KEY_TYPE, TEST_VALUE_TYPE, Ascending, TestRadixBits>(
q, size, TestUtils::get_new_kernel_params<2>(params));
q, size, TestUtils::create_new_kernel_param_idx<2>(params));
test_sycl_buffer<TEST_KEY_TYPE, TEST_VALUE_TYPE, Descending, TestRadixBits>(
q, size, TestUtils::get_new_kernel_params<3>(params));
q, size, TestUtils::create_new_kernel_param_idx<3>(params));
}
}
catch (const ::std::exception& exc)
Expand Down
18 changes: 9 additions & 9 deletions test/kt/esimd_radix_sort_out_of_place.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -210,13 +210,13 @@ template <typename T, bool IsAscending, std::uint8_t RadixBits, typename KernelP
void
test_general_cases(sycl::queue q, std::size_t size, KernelParam param)
{
test_usm<T, IsAscending, RadixBits, sycl::usm::alloc::shared>(q, size, TestUtils::get_new_kernel_params<0>(param));
test_usm<T, IsAscending, RadixBits, sycl::usm::alloc::device>(q, size, TestUtils::get_new_kernel_params<1>(param));
test_sycl_iterators<T, IsAscending, RadixBits>(q, size, TestUtils::get_new_kernel_params<2>(param));
test_sycl_buffer<T, IsAscending, RadixBits>(q, size, TestUtils::get_new_kernel_params<3>(param));
test_usm<T, IsAscending, RadixBits, sycl::usm::alloc::shared>(q, size, TestUtils::create_new_kernel_param_idx<0>(param));
test_usm<T, IsAscending, RadixBits, sycl::usm::alloc::device>(q, size, TestUtils::create_new_kernel_param_idx<1>(param));
test_sycl_iterators<T, IsAscending, RadixBits>(q, size, TestUtils::create_new_kernel_param_idx<2>(param));
test_sycl_buffer<T, IsAscending, RadixBits>(q, size, TestUtils::create_new_kernel_param_idx<3>(param));
#if _ENABLE_RANGES_TESTING
test_all_view<T, IsAscending, RadixBits>(q, size, TestUtils::get_new_kernel_params<4>(param));
test_subrange_view<T, IsAscending, RadixBits>(q, size, TestUtils::get_new_kernel_params<5>(param));
test_all_view<T, IsAscending, RadixBits>(q, size, TestUtils::create_new_kernel_param_idx<4>(param));
test_subrange_view<T, IsAscending, RadixBits>(q, size, TestUtils::create_new_kernel_param_idx<5>(param));
#endif // _ENABLE_RANGES_TESTING
}

Expand All @@ -242,11 +242,11 @@ main()
for (auto size : sort_sizes)
{
test_general_cases<TEST_KEY_TYPE, Ascending, TestRadixBits>(
q, size, TestUtils::get_new_kernel_params<0>(params));
q, size, TestUtils::create_new_kernel_param_idx<0>(params));
test_general_cases<TEST_KEY_TYPE, Descending, TestRadixBits>(
q, size, TestUtils::get_new_kernel_params<1>(params));
q, size, TestUtils::create_new_kernel_param_idx<1>(params));
}
test_small_sizes<TEST_KEY_TYPE, Ascending, TestRadixBits>(q, TestUtils::get_new_kernel_params<3>(params));
test_small_sizes<TEST_KEY_TYPE, Ascending, TestRadixBits>(q, TestUtils::create_new_kernel_param_idx<3>(params));
}
catch (const ::std::exception& exc)
{
Expand Down
14 changes: 7 additions & 7 deletions test/kt/single_pass_scan.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -175,28 +175,28 @@ template <typename T, typename BinOp, typename KernelParam>
void
test_general_cases(sycl::queue q, std::size_t size, BinOp bin_op, KernelParam param)
{
test_usm<T, sycl::usm::alloc::shared>(q, size, bin_op, TestUtils::get_new_kernel_params<0>(param));
test_usm<T, sycl::usm::alloc::device>(q, size, bin_op, TestUtils::get_new_kernel_params<1>(param));
test_sycl_iterators<T>(q, size, bin_op, TestUtils::get_new_kernel_params<2>(param));
test_usm<T, sycl::usm::alloc::shared>(q, size, bin_op, TestUtils::create_new_kernel_param_idx<0>(param));
test_usm<T, sycl::usm::alloc::device>(q, size, bin_op, TestUtils::create_new_kernel_param_idx<1>(param));
test_sycl_iterators<T>(q, size, bin_op, TestUtils::create_new_kernel_param_idx<2>(param));
#if _ENABLE_RANGES_TESTING
test_all_view<T>(q, size, bin_op, TestUtils::get_new_kernel_params<3>(param));
test_buffer<T>(q, size, bin_op, TestUtils::get_new_kernel_params<4>(param));
test_all_view<T>(q, size, bin_op, TestUtils::create_new_kernel_param_idx<3>(param));
test_buffer<T>(q, size, bin_op, TestUtils::create_new_kernel_param_idx<4>(param));
#endif
}

template <typename T, typename KernelParam>
void
test_all_cases(sycl::queue q, std::size_t size, KernelParam param)
{
test_general_cases<T>(q, size, std::plus<T>{}, TestUtils::get_new_kernel_params<0>(param));
test_general_cases<T>(q, size, std::plus<T>{}, TestUtils::create_new_kernel_param_idx<0>(param));
#if _PSTL_GROUP_REDUCTION_MULT_INT64_BROKEN
static constexpr bool int64_mult_broken = std::is_integral_v<T> && (sizeof(T) == 8);
#else
static constexpr bool int64_mult_broken = 0;
#endif
if constexpr (!int64_mult_broken)
{
test_general_cases<T>(q, size, std::multiplies<T>{}, TestUtils::get_new_kernel_params<1>(param));
test_general_cases<T>(q, size, std::multiplies<T>{}, TestUtils::create_new_kernel_param_idx<1>(param));
}
}

Expand Down
16 changes: 10 additions & 6 deletions test/support/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -961,17 +961,21 @@ create_new_policy_idx(Policy&& policy)

#if TEST_DPCPP_BACKEND_PRESENT
template <typename KernelName, int idx>
struct __kernel_name_with_idx
struct kernel_name_with_idx
{
};

template <int idx, typename KernelParams>
template <int idx, typename KernelParam>
constexpr auto
get_new_kernel_params(KernelParams)
create_new_kernel_param_idx(KernelParam)
{
return oneapi::dpl::experimental::kt::kernel_param<
KernelParams::data_per_workitem, KernelParams::workgroup_size,
__kernel_name_with_idx<typename KernelParams::kernel_name, idx>>{};
#if TEST_EXPLICIT_KERNEL_NAMES
return oneapi::dpl::experimental::kt::kernel_param<KernelParam::data_per_workitem,
KernelParam::workgroup_size,
kernel_name_with_idx<typename KernelParam::kernel_name, idx>>{};
#else
return KernelParam{};
#endif // TEST_EXPLICIT_KERNEL_NAMES
}
#endif //TEST_DPCPP_BACKEND_PRESENT

Expand Down

0 comments on commit 36492bf

Please sign in to comment.