diff --git a/dpctl/tensor/libtensor/include/kernels/sorting/merge_sort.hpp b/dpctl/tensor/libtensor/include/kernels/sorting/merge_sort.hpp index f3b5030c48..dbf40c10fe 100644 --- a/dpctl/tensor/libtensor/include/kernels/sorting/merge_sort.hpp +++ b/dpctl/tensor/libtensor/include/kernels/sorting/merge_sort.hpp @@ -33,6 +33,7 @@ #include "kernels/dpctl_tensor_types.hpp" #include "kernels/sorting/search_sorted_detail.hpp" +#include "kernels/sorting/sort_utils.hpp" namespace dpctl { @@ -811,20 +812,12 @@ sycl::event stable_argsort_axis1_contig_impl( const size_t total_nelems = iter_nelems * sort_nelems; - sycl::event populate_indexed_data_ev = - exec_q.submit([&](sycl::handler &cgh) { - cgh.depends_on(depends); + using dpctl::tensor::kernels::sort_utils_detail::iota_impl; - const sycl::range<1> range{total_nelems}; + using IotaKernelName = populate_index_data_krn; - using KernelName = - populate_index_data_krn; - - cgh.parallel_for(range, [=](sycl::id<1> id) { - size_t i = id[0]; - res_tp[i] = static_cast(i); - }); - }); + sycl::event populate_indexed_data_ev = iota_impl( + exec_q, res_tp, total_nelems, depends); // Sort segments of the array sycl::event base_sort_ev = @@ -839,21 +832,11 @@ sycl::event stable_argsort_axis1_contig_impl( exec_q, iter_nelems, sort_nelems, res_tp, index_comp, sorted_block_size, {base_sort_ev}); - sycl::event write_out_ev = exec_q.submit([&](sycl::handler &cgh) { - cgh.depends_on(merges_ev); - - auto temp_acc = - merge_sort_detail::GetReadOnlyAccess{}(res_tp, - cgh); - - using KernelName = index_map_to_rows_krn; + using MapBackKernelName = index_map_to_rows_krn; + using dpctl::tensor::kernels::sort_utils_detail::map_back_impl; - const sycl::range<1> range{total_nelems}; - - cgh.parallel_for(range, [=](sycl::id<1> id) { - res_tp[id] = (temp_acc[id] % sort_nelems); - }); - }); + sycl::event write_out_ev = map_back_impl( + exec_q, total_nelems, res_tp, res_tp, sort_nelems, {merges_ev}); return write_out_ev; } diff --git a/dpctl/tensor/libtensor/include/kernels/sorting/radix_sort.hpp b/dpctl/tensor/libtensor/include/kernels/sorting/radix_sort.hpp index dc3da24315..335b285bbf 100644 --- a/dpctl/tensor/libtensor/include/kernels/sorting/radix_sort.hpp +++ b/dpctl/tensor/libtensor/include/kernels/sorting/radix_sort.hpp @@ -38,6 +38,7 @@ #include #include "kernels/dpctl_tensor_types.hpp" +#include "kernels/sorting/sort_utils.hpp" #include "utils/sycl_alloc_utils.hpp" namespace dpctl @@ -62,6 +63,47 @@ class radix_sort_reorder_peer_kernel; template class radix_sort_reorder_kernel; +/*! 
@brief Computes smallest exponent such that `n <= (1 << exponent)` */ +template && + sizeof(SizeT) == sizeof(std::uint64_t), + int> = 0> +std::uint32_t ceil_log2(SizeT n) +{ + if (n <= 1) + return std::uint32_t{1}; + + std::uint32_t exp{1}; + --n; + // if n > 2^b, n = q * 2^b + r for q > 0 and 0 <= r < 2^b + // ceil_log2(q * 2^b + r) == ceil_log2(q * 2^b) == q + ceil_log2(n1) + if (n >= (SizeT{1} << 32)) { + n >>= 32; + exp += 32; + } + if (n >= (SizeT{1} << 16)) { + n >>= 16; + exp += 16; + } + if (n >= (SizeT{1} << 8)) { + n >>= 8; + exp += 8; + } + if (n >= (SizeT{1} << 4)) { + n >>= 4; + exp += 4; + } + if (n >= (SizeT{1} << 2)) { + n >>= 2; + exp += 2; + } + if (n >= (SizeT{1} << 1)) { + n >>= 1; + ++exp; + } + return exp; +} + //---------------------------------------------------------- // bitwise order-preserving conversions to unsigned integers //---------------------------------------------------------- @@ -1144,7 +1186,7 @@ struct subgroup_radix_sort const std::size_t max_slm_size = dev.template get_info() / 2; - const auto n_uniform = 1 << (std::uint32_t(std::log2(n - 1)) + 1); + const auto n_uniform = 1 << ceil_log2(n); const auto req_slm_size_val = sizeof(T) * n_uniform; return ((req_slm_size_val + req_slm_size_counters) <= max_slm_size) @@ -1256,9 +1298,7 @@ struct subgroup_radix_sort const uint16_t id = wi * block_size + i; if (id < n) values[i] = std::move( - this_input_arr[iter_val_offset + - static_cast( - id)]); + this_input_arr[iter_val_offset + id]); } while (true) { @@ -1272,8 +1312,7 @@ struct subgroup_radix_sort // counting phase auto pcounter = get_accessor_pointer(counter_acc) + - static_cast(wi) + - iter_counter_offset; + (wi + iter_counter_offset); // initialize counters #pragma unroll @@ -1348,19 +1387,15 @@ struct subgroup_radix_sort // scan contiguous numbers uint16_t bin_sum[bin_count]; - bin_sum[0] = - counter_acc[iter_counter_offset + - static_cast( - wi * bin_count)]; + const std::size_t counter_offset0 = + iter_counter_offset + wi * bin_count; + bin_sum[0] = counter_acc[counter_offset0]; #pragma unroll for (uint16_t i = 1; i < bin_count; ++i) bin_sum[i] = bin_sum[i - 1] + - counter_acc - [iter_counter_offset + - static_cast( - wi * bin_count + i)]; + counter_acc[counter_offset0 + i]; sycl::group_barrier(ndit.get_group()); @@ -1374,10 +1409,7 @@ struct subgroup_radix_sort // add to local sum, generate exclusive scan result #pragma unroll for (uint16_t i = 0; i < bin_count; ++i) - counter_acc[iter_counter_offset + - static_cast( - wi * bin_count + i + - 1)] = + counter_acc[counter_offset0 + i + 1] = sum_scan + bin_sum[i]; if (wi == 0) @@ -1407,10 +1439,8 @@ struct subgroup_radix_sort if (r < n) { // move the values to source range and // destroy the values - this_output_arr - [iter_val_offset + - static_cast(r)] = - std::move(values[i]); + this_output_arr[iter_val_offset + r] = + std::move(values[i]); } } @@ -1422,8 +1452,7 @@ struct subgroup_radix_sort for (uint16_t i = 0; i < block_size; ++i) { const uint16_t r = indices[i]; if (r < n) - exchange_acc[iter_exchange_offset + - static_cast(r)] = + exchange_acc[iter_exchange_offset + r] = std::move(values[i]); } @@ -1435,8 +1464,7 @@ struct subgroup_radix_sort if (id < n) values[i] = std::move( exchange_acc[iter_exchange_offset + - static_cast( - id)]); + id]); } sycl::group_barrier(ndit.get_group()); @@ -1601,11 +1629,11 @@ sycl::event parallel_radix_sort_impl(sycl::queue &exec_q, using CountT = std::uint32_t; // memory for storing count and offset values - CountT *count_ptr = - sycl::malloc_device(n_iters 
* n_counts, exec_q); - if (nullptr == count_ptr) { - throw std::runtime_error("Could not allocate USM-device memory"); - } + auto count_owner = + dpctl::tensor::alloc_utils::smart_malloc_device( + n_iters * n_counts, exec_q); + + CountT *count_ptr = count_owner.get(); constexpr std::uint32_t zero_radix_iter{0}; @@ -1618,25 +1646,17 @@ sycl::event parallel_radix_sort_impl(sycl::queue &exec_q, n_counts, count_ptr, proj_op, is_ascending, depends); - sort_ev = exec_q.submit([=](sycl::handler &cgh) { - cgh.depends_on(sort_ev); - const sycl::context &ctx = exec_q.get_context(); - - using dpctl::tensor::alloc_utils::sycl_free_noexcept; - cgh.host_task( - [ctx, count_ptr]() { sycl_free_noexcept(count_ptr, ctx); }); - }); + sort_ev = dpctl::tensor::alloc_utils::async_smart_free( + exec_q, {sort_ev}, count_owner); return sort_ev; } - ValueT *tmp_arr = - sycl::malloc_device(n_iters * n_to_sort, exec_q); - if (nullptr == tmp_arr) { - using dpctl::tensor::alloc_utils::sycl_free_noexcept; - sycl_free_noexcept(count_ptr, exec_q); - throw std::runtime_error("Could not allocate USM-device memory"); - } + auto tmp_arr_owner = + dpctl::tensor::alloc_utils::smart_malloc_device( + n_iters * n_to_sort, exec_q); + + ValueT *tmp_arr = tmp_arr_owner.get(); // iterations per each bucket assert("Number of iterations must be even" && radix_iters % 2 == 0); @@ -1670,17 +1690,8 @@ sycl::event parallel_radix_sort_impl(sycl::queue &exec_q, } } - sort_ev = exec_q.submit([=](sycl::handler &cgh) { - cgh.depends_on(sort_ev); - - const sycl::context &ctx = exec_q.get_context(); - - using dpctl::tensor::alloc_utils::sycl_free_noexcept; - cgh.host_task([ctx, count_ptr, tmp_arr]() { - sycl_free_noexcept(tmp_arr, ctx); - sycl_free_noexcept(count_ptr, ctx); - }); - }); + sort_ev = dpctl::tensor::alloc_utils::async_smart_free( + exec_q, {sort_ev}, tmp_arr_owner, count_owner); } return sort_ev; @@ -1782,57 +1793,38 @@ radix_argsort_axis1_contig_impl(sycl::queue &exec_q, reinterpret_cast(res_cp) + iter_res_offset + sort_res_offset; const std::size_t total_nelems = iter_nelems * sort_nelems; - const std::size_t padded_total_nelems = ((total_nelems + 63) / 64) * 64; - IndexTy *workspace = sycl::malloc_device( - padded_total_nelems + total_nelems, exec_q); + auto workspace_owner = + dpctl::tensor::alloc_utils::smart_malloc_device(total_nelems, + exec_q); - if (nullptr == workspace) { - throw std::runtime_error("Could not allocate workspace on device"); - } + // get raw USM pointer + IndexTy *workspace = workspace_owner.get(); using IdentityProjT = radix_sort_details::IdentityProj; using IndexedProjT = radix_sort_details::IndexedProj; const IndexedProjT proj_op{arg_tp, IdentityProjT{}}; - sycl::event iota_ev = exec_q.submit([&](sycl::handler &cgh) { - cgh.depends_on(depends); + using IotaKernelName = radix_argsort_iota_krn; - using KernelName = radix_argsort_iota_krn; + using dpctl::tensor::kernels::sort_utils_detail::iota_impl; - cgh.parallel_for( - sycl::range<1>(total_nelems), [=](sycl::id<1> id) { - size_t i = id[0]; - IndexTy sort_id = static_cast(i); - workspace[i] = sort_id; - }); - }); + sycl::event iota_ev = iota_impl( + exec_q, workspace, total_nelems, depends); sycl::event radix_sort_ev = radix_sort_details::parallel_radix_sort_impl( exec_q, iter_nelems, sort_nelems, workspace, res_tp, proj_op, sort_ascending, {iota_ev}); - sycl::event map_back_ev = exec_q.submit([&](sycl::handler &cgh) { - cgh.depends_on(radix_sort_ev); - - using KernelName = radix_argsort_index_write_out_krn; - - cgh.parallel_for( - 
sycl::range<1>(total_nelems), [=](sycl::id<1> id) { - IndexTy linear_index = res_tp[id]; - res_tp[id] = (linear_index % sort_nelems); - }); - }); - - sycl::event cleanup_ev = exec_q.submit([&](sycl::handler &cgh) { - cgh.depends_on(map_back_ev); + using MapBackKernelName = radix_argsort_index_write_out_krn; + using dpctl::tensor::kernels::sort_utils_detail::map_back_impl; - const sycl::context &ctx = exec_q.get_context(); + sycl::event map_back_ev = map_back_impl( + exec_q, total_nelems, res_tp, res_tp, sort_nelems, {radix_sort_ev}); - using dpctl::tensor::alloc_utils::sycl_free_noexcept; - cgh.host_task([ctx, workspace] { sycl_free_noexcept(workspace, ctx); }); - }); + sycl::event cleanup_ev = dpctl::tensor::alloc_utils::async_smart_free( + exec_q, {map_back_ev}, workspace_owner); return cleanup_ev; } diff --git a/dpctl/tensor/libtensor/include/kernels/sorting/sort_utils.hpp b/dpctl/tensor/libtensor/include/kernels/sorting/sort_utils.hpp new file mode 100644 index 0000000000..f62a6c3fa0 --- /dev/null +++ b/dpctl/tensor/libtensor/include/kernels/sorting/sort_utils.hpp @@ -0,0 +1,124 @@ +//=== sorting.hpp - Implementation of sorting kernels ---*-C++-*--/===// +// +// Data Parallel Control (dpctl) +// +// Copyright 2020-2024 Intel Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//===----------------------------------------------------------------------===// +/// +/// \file +/// This file defines kernels for tensor sort/argsort operations. 
+//===----------------------------------------------------------------------===// + +#pragma once + +#include +#include +#include + +#include + +namespace dpctl +{ +namespace tensor +{ +namespace kernels +{ +namespace sort_utils_detail +{ + +namespace syclexp = sycl::ext::oneapi::experimental; + +template +sycl::event iota_impl(sycl::queue &exec_q, + T *data, + std::size_t nelems, + const std::vector &dependent_events) +{ + constexpr std::uint32_t lws = 256; + constexpr std::uint32_t n_wi = 4; + const std::size_t n_groups = (nelems + n_wi * lws - 1) / (n_wi * lws); + + sycl::range<1> gRange{n_groups * lws}; + sycl::range<1> lRange{lws}; + sycl::nd_range<1> ndRange{gRange, lRange}; + + sycl::event e = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(dependent_events); + cgh.parallel_for(ndRange, [=](sycl::nd_item<1> it) { + const std::size_t gid = it.get_global_id(); + const auto &sg = it.get_sub_group(); + const std::uint32_t lane_id = sg.get_local_id()[0]; + + const std::size_t offset = (gid - lane_id) * n_wi; + const std::uint32_t max_sgSize = sg.get_max_local_range()[0]; + + std::array stripe{}; +#pragma unroll + for (std::uint32_t i = 0; i < n_wi; ++i) { + stripe[i] = T(offset + lane_id + i * max_sgSize); + } + + if (offset + n_wi * max_sgSize < nelems) { + constexpr auto group_ls_props = syclexp::properties{ + syclexp::data_placement_striped + // , syclexp::full_group + }; + + auto out_multi_ptr = sycl::address_space_cast< + sycl::access::address_space::global_space, + sycl::access::decorated::yes>(&data[offset]); + + syclexp::group_store(sg, sycl::span{&stripe[0], n_wi}, + out_multi_ptr, group_ls_props); + } + else { + for (std::size_t idx = offset + lane_id; idx < nelems; + idx += max_sgSize) + { + data[idx] = T(idx); + } + } + }); + }); + + return e; +} + +template +sycl::event map_back_impl(sycl::queue &exec_q, + std::size_t nelems, + const IndexTy *flat_index_data, + IndexTy *reduced_index_data, + std::size_t row_size, + const std::vector &dependent_events) +{ + sycl::event map_back_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(dependent_events); + + cgh.parallel_for( + sycl::range<1>(nelems), [=](sycl::id<1> id) { + const IndexTy linear_index = flat_index_data[id]; + reduced_index_data[id] = (linear_index % row_size); + }); + }); + + return map_back_ev; +} + +} // end of namespace sort_utils_detail +} // end of namespace kernels +} // end of namespace tensor +} // end of namespace dpctl diff --git a/dpctl/tensor/libtensor/include/kernels/sorting/topk.hpp b/dpctl/tensor/libtensor/include/kernels/sorting/topk.hpp index 2674f877c9..0ff3fc4723 100644 --- a/dpctl/tensor/libtensor/include/kernels/sorting/topk.hpp +++ b/dpctl/tensor/libtensor/include/kernels/sorting/topk.hpp @@ -30,13 +30,15 @@ #include #include #include -#include #include +#include + #include "kernels/dpctl_tensor_types.hpp" -#include "merge_sort.hpp" -#include "radix_sort.hpp" -#include "search_sorted_detail.hpp" +#include "kernels/sorting/merge_sort.hpp" +#include "kernels/sorting/radix_sort.hpp" +#include "kernels/sorting/search_sorted_detail.hpp" +#include "kernels/sorting/sort_utils.hpp" #include "utils/sycl_alloc_utils.hpp" #include @@ -69,6 +71,68 @@ void scale_topk_params(const std::uint64_t nelems_per_slm, throw std::runtime_error("Could not construct top k kernel parameters"); } +template +sycl::event write_out_impl(sycl::queue &exec_q, + std::size_t iter_nelems, + std::size_t k, + const argTy *arg_tp, + const IndexTy *index_data, + std::size_t iter_index_stride, + std::size_t 
axis_nelems, + argTy *vals_tp, + IndexTy *inds_tp, + const std::vector &depends) +{ + constexpr std::uint32_t lws = 64; + constexpr std::uint32_t n_wi = 4; + const std::size_t nelems = iter_nelems * k; + const std::size_t n_groups = (nelems + lws * n_wi - 1) / (n_wi * lws); + + sycl::range<1> lRange{lws}; + sycl::range<1> gRange{n_groups * lws}; + sycl::nd_range<1> ndRange{gRange, lRange}; + + sycl::event write_out_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); + + cgh.parallel_for(ndRange, [=](sycl::nd_item<1> it) { + const std::size_t gid = it.get_global_linear_id(); + const auto &sg = it.get_sub_group(); + const std::uint32_t lane_id = sg.get_local_id()[0]; + const std::uint32_t sg_size = sg.get_max_local_range()[0]; + + const std::size_t start_id = + (gid - lane_id) * sg_size * n_wi + lane_id; + +#pragma unroll + for (std::uint32_t i = 0; i < n_wi; ++i) { + const std::size_t data_id = start_id + i * sg_size; + + if (data_id < nelems) { + const std::size_t iter_id = data_id / k; + + /* + const std::size_t axis_gid = data_id - (iter_gid * k); + const std::size_t src_idx = iter_gid * iter_index_stride + + axis_gid; + */ + const std::size_t src_idx = + data_id + iter_id * (iter_index_stride - k); + + const IndexTy res_ind = index_data[src_idx]; + const argTy v = arg_tp[res_ind]; + + const std::size_t dst_idx = data_id; + vals_tp[dst_idx] = v; + inds_tp[dst_idx] = (res_ind % axis_nelems); + } + } + }); + }); + + return write_out_ev; +} + } // namespace topk_detail template @@ -89,26 +153,18 @@ topk_full_merge_sort_impl(sycl::queue &exec_q, const CompT &comp, const std::vector &depends) { - IndexTy *index_data = - sycl::malloc_device(iter_nelems * axis_nelems, exec_q); - if (index_data == nullptr) { - throw std::runtime_error("Unable to allocate device_memory"); - } + auto index_data_owner = + dpctl::tensor::alloc_utils::smart_malloc_device( + iter_nelems * axis_nelems, exec_q); + // extract USM pointer + IndexTy *index_data = index_data_owner.get(); - sycl::event populate_indexed_data_ev = - exec_q.submit([&](sycl::handler &cgh) { - cgh.depends_on(depends); - - auto const &range = sycl::range<1>(iter_nelems * axis_nelems); + using IotaKernelName = topk_populate_index_data_krn; - using KernelName = - topk_populate_index_data_krn; + using dpctl::tensor::kernels::sort_utils_detail::iota_impl; - cgh.parallel_for(range, [=](sycl::id<1> id) { - std::size_t i = id[0]; - index_data[i] = static_cast(i); - }); - }); + sycl::event populate_indexed_data_ev = iota_impl( + exec_q, index_data, iter_nelems * axis_nelems, depends); std::size_t sorted_block_size; // Sort segments of the array @@ -124,35 +180,17 @@ topk_full_merge_sort_impl(sycl::queue &exec_q, exec_q, iter_nelems, axis_nelems, index_data, comp, sorted_block_size, {base_sort_ev}); - sycl::event write_out_ev = exec_q.submit([&](sycl::handler &cgh) { - cgh.depends_on(merges_ev); - - using KernelName = topk_full_merge_map_back_krn; - - cgh.parallel_for(iter_nelems * k, [=](sycl::id<1> id) { - std::size_t gid = id[0]; + using WriteOutKernelName = + topk_full_merge_map_back_krn; - std::size_t iter_gid = gid / k; - std::size_t axis_gid = gid - (iter_gid * k); - - std::size_t src_idx = iter_gid * axis_nelems + axis_gid; - std::size_t dst_idx = iter_gid * k + axis_gid; - - auto res_ind = index_data[src_idx]; - vals_tp[dst_idx] = arg_tp[res_ind]; - inds_tp[dst_idx] = res_ind % axis_nelems; - }); - }); + sycl::event write_out_ev = + topk_detail::write_out_impl( + exec_q, iter_nelems, k, arg_tp, index_data, axis_nelems, + 
axis_nelems, vals_tp, inds_tp, {merges_ev}); sycl::event cleanup_host_task_event = - exec_q.submit([&](sycl::handler &cgh) { - cgh.depends_on(write_out_ev); - const sycl::context &ctx = exec_q.get_context(); - - using dpctl::tensor::alloc_utils::sycl_free_noexcept; - cgh.host_task( - [ctx, index_data] { sycl_free_noexcept(index_data, ctx); }); - }); + dpctl::tensor::alloc_utils::async_smart_free(exec_q, {write_out_ev}, + index_data_owner); return cleanup_host_task_event; }; @@ -275,11 +313,11 @@ sycl::event topk_merge_impl( index_comp, depends); } - IndexTy *index_data = - sycl::malloc_device(iter_nelems * alloc_len, exec_q); - if (index_data == nullptr) { - throw std::runtime_error("Unable to allocate device_memory"); - } + auto index_data_owner = + dpctl::tensor::alloc_utils::smart_malloc_device( + iter_nelems * alloc_len, exec_q); + // get raw USM pointer + IndexTy *index_data = index_data_owner.get(); // no need to populate index data: SLM will be populated with default // values @@ -396,36 +434,17 @@ sycl::event topk_merge_impl( k_rounded, {base_sort_ev}); // Write out top k of the merge-sorted memory - sycl::event write_topk_ev = exec_q.submit([&](sycl::handler &cgh) { - cgh.depends_on(merges_ev); - - using KernelName = - topk_partial_merge_map_back_krn; - - cgh.parallel_for(iter_nelems * k, [=](sycl::id<1> id) { - std::size_t gid = id[0]; + using WriteOutKernelName = + topk_partial_merge_map_back_krn; - std::size_t iter_gid = gid / k; - std::size_t axis_gid = gid - (iter_gid * k); - - std::size_t src_idx = iter_gid * alloc_len + axis_gid; - std::size_t dst_idx = iter_gid * k + axis_gid; - - auto res_ind = index_data[src_idx]; - vals_tp[dst_idx] = arg_tp[res_ind]; - inds_tp[dst_idx] = res_ind % axis_nelems; - }); - }); + sycl::event write_topk_ev = + topk_detail::write_out_impl( + exec_q, iter_nelems, k, arg_tp, index_data, alloc_len, + axis_nelems, vals_tp, inds_tp, {merges_ev}); sycl::event cleanup_host_task_event = - exec_q.submit([&](sycl::handler &cgh) { - cgh.depends_on(write_topk_ev); - const sycl::context &ctx = exec_q.get_context(); - - using dpctl::tensor::alloc_utils::sycl_free_noexcept; - cgh.host_task( - [ctx, index_data] { sycl_free_noexcept(index_data, ctx); }); - }); + dpctl::tensor::alloc_utils::async_smart_free( + exec_q, {write_topk_ev}, index_data_owner); return cleanup_host_task_event; } @@ -465,33 +484,25 @@ sycl::event topk_radix_impl(sycl::queue &exec_q, const std::size_t total_nelems = iter_nelems * axis_nelems; const std::size_t padded_total_nelems = ((total_nelems + 63) / 64) * 64; - IndexTy *workspace = sycl::malloc_device( - padded_total_nelems + total_nelems, exec_q); - - IndexTy *tmp_tp = sycl::malloc_device(total_nelems, exec_q); + auto workspace_owner = + dpctl::tensor::alloc_utils::smart_malloc_device( + padded_total_nelems + total_nelems, exec_q); - if (nullptr == workspace || nullptr == tmp_tp) { - throw std::runtime_error( - "Not enough device memory for radix sort topk"); - } + // get raw USM pointer + IndexTy *workspace = workspace_owner.get(); + IndexTy *tmp_tp = workspace + padded_total_nelems; using IdentityProjT = radix_sort_details::IdentityProj; using IndexedProjT = radix_sort_details::IndexedProj; const IndexedProjT proj_op{arg_tp, IdentityProjT{}}; - sycl::event iota_ev = exec_q.submit([&](sycl::handler &cgh) { - cgh.depends_on(depends); + using IotaKernelName = topk_iota_krn; - using KernelName = topk_iota_krn; + using dpctl::tensor::kernels::sort_utils_detail::iota_impl; - cgh.parallel_for( - sycl::range<1>(total_nelems), 
[=](sycl::id<1> id) { - size_t i = id[0]; - IndexTy sort_id = static_cast(i); - workspace[i] = sort_id; - }); - }); + sycl::event iota_ev = iota_impl( + exec_q, workspace, total_nelems, depends); sycl::event radix_sort_ev = radix_sort_details::parallel_radix_sort_impl( @@ -499,37 +510,15 @@ sycl::event topk_radix_impl(sycl::queue &exec_q, ascending, {iota_ev}); // Write out top k of the temporary - sycl::event write_topk_ev = exec_q.submit([&](sycl::handler &cgh) { - cgh.depends_on(radix_sort_ev); - - using KernelName = topk_radix_map_back_krn; - - cgh.parallel_for(iter_nelems * k, [=](sycl::id<1> id) { - std::size_t gid = id[0]; - - std::size_t iter_gid = gid / k; - std::size_t axis_gid = gid - (iter_gid * k); - - std::size_t src_idx = iter_gid * axis_nelems + axis_gid; - std::size_t dst_idx = iter_gid * k + axis_gid; + using WriteOutKernelName = topk_radix_map_back_krn; - IndexTy res_ind = tmp_tp[src_idx]; - vals_tp[dst_idx] = arg_tp[res_ind]; - inds_tp[dst_idx] = res_ind % axis_nelems; - }); - }); - - sycl::event cleanup_ev = exec_q.submit([&](sycl::handler &cgh) { - cgh.depends_on(write_topk_ev); + sycl::event write_topk_ev = + topk_detail::write_out_impl( + exec_q, iter_nelems, k, arg_tp, tmp_tp, axis_nelems, axis_nelems, + vals_tp, inds_tp, {radix_sort_ev}); - const sycl::context &ctx = exec_q.get_context(); - - using dpctl::tensor::alloc_utils::sycl_free_noexcept; - cgh.host_task([ctx, workspace, tmp_tp] { - sycl_free_noexcept(workspace, ctx); - sycl_free_noexcept(tmp_tp, ctx); - }); - }); + sycl::event cleanup_ev = dpctl::tensor::alloc_utils::async_smart_free( + exec_q, {write_topk_ev}, workspace_owner); return cleanup_ev; } diff --git a/dpctl/tensor/libtensor/include/utils/sycl_alloc_utils.hpp b/dpctl/tensor/libtensor/include/utils/sycl_alloc_utils.hpp index 3ad5f6f36a..f67e1bba1f 100644 --- a/dpctl/tensor/libtensor/include/utils/sycl_alloc_utils.hpp +++ b/dpctl/tensor/libtensor/include/utils/sycl_alloc_utils.hpp @@ -28,6 +28,10 @@ #include #include +#include +#include +#include +#include #include "sycl/sycl.hpp" @@ -73,11 +77,137 @@ void sycl_free_noexcept(T *ptr, const sycl::context &ctx) noexcept } } -template void sycl_free_noexcept(T *ptr, sycl::queue &q) noexcept +template +void sycl_free_noexcept(T *ptr, const sycl::queue &q) noexcept { sycl_free_noexcept(ptr, q.get_context()); } +class USMDeleter +{ +private: + sycl::context ctx_; + +public: + USMDeleter(const sycl::queue &q) : ctx_(q.get_context()) {} + USMDeleter(const sycl::context &ctx) : ctx_(ctx) {} + + template void operator()(T *ptr) const + { + sycl_free_noexcept(ptr, ctx_); + } +}; + +template +std::unique_ptr +smart_malloc(std::size_t count, + const sycl::queue &q, + sycl::usm::alloc kind, + const sycl::property_list &propList = {}) +{ + T *ptr = sycl::malloc(count, q, kind, propList); + if (nullptr == ptr) { + throw std::runtime_error("Unable to allocate device_memory"); + } + + auto usm_deleter = USMDeleter(q); + return std::unique_ptr(ptr, usm_deleter); +} + +template +std::unique_ptr +smart_malloc_device(std::size_t count, + const sycl::queue &q, + const sycl::property_list &propList = {}) +{ + return smart_malloc(count, q, sycl::usm::alloc::device, propList); +} + +template +std::unique_ptr +smart_malloc_shared(std::size_t count, + const sycl::queue &q, + const sycl::property_list &propList = {}) +{ + return smart_malloc(count, q, sycl::usm::alloc::shared, propList); +} + +template +std::unique_ptr +smart_malloc_host(std::size_t count, + const sycl::queue &q, + const sycl::property_list &propList = {}) +{ + 
return smart_malloc(count, q, sycl::usm::alloc::host, propList); +} + +namespace +{ +template struct valid_smart_ptr : public std::false_type +{ +}; + +template +struct valid_smart_ptr &> + : public std::is_same +{ +}; + +template +struct valid_smart_ptr> + : public std::is_same +{ +}; + +// base case +template struct all_valid_smart_ptrs +{ + static constexpr bool value = true; +}; + +template +struct all_valid_smart_ptrs +{ + static constexpr bool value = valid_smart_ptr::value && + (all_valid_smart_ptrs::value); +}; +} // namespace + +template +sycl::event async_smart_free(sycl::queue &exec_q, + const std::vector &depends, + Args &&...args) +{ + constexpr std::size_t n = sizeof...(Args); + static_assert( + n > 0, "async_smart_free requires at least one smart pointer argument"); + + static_assert( + all_valid_smart_ptrs::value, + "async_smart_free requires unique_ptr created with smart_malloc"); + + std::vector ptrs; + ptrs.reserve(n); + (ptrs.push_back(reinterpret_cast(args.get())), ...); + + std::vector dels; + dels.reserve(n); + (dels.push_back(args.get_deleter()), ...); + + sycl::event ht_e = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); + + cgh.host_task([ptrs, dels]() { + for (size_t i = 0; i < ptrs.size(); ++i) { + dels[i](ptrs[i]); + } + }); + }); + (args.release(), ...); + + return ht_e; +} + } // end of namespace alloc_utils } // end of namespace tensor } // end of namespace dpctl diff --git a/dpctl/tests/test_usm_ndarray_top_k.py b/dpctl/tests/test_usm_ndarray_top_k.py new file mode 100644 index 0000000000..a029db005c --- /dev/null +++ b/dpctl/tests/test_usm_ndarray_top_k.py @@ -0,0 +1,220 @@ +# Data Parallel Control (dpctl) +# +# Copyright 2020-2024 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import pytest + +import dpctl.tensor as dpt +from dpctl.tests.helper import get_queue_or_skip, skip_if_dtype_not_supported + + +def _expected_largest_inds(inp, n, shift, k): + "Computed expected top_k indices for mode='largest'" + assert k < n + ones_start_id = shift % (2 * n) + + alloc_dev = inp.device + + if ones_start_id < n: + expected_inds = dpt.arange( + ones_start_id, ones_start_id + k, dtype="i8", device=alloc_dev + ) + else: + # wrap-around + ones_end_id = (ones_start_id + n) % (2 * n) + if ones_end_id >= k: + expected_inds = dpt.arange(k, dtype="i8", device=alloc_dev) + else: + expected_inds = dpt.concat( + ( + dpt.arange(ones_end_id, dtype="i8", device=alloc_dev), + dpt.arange( + ones_start_id, + ones_start_id + k - ones_end_id, + dtype="i8", + device=alloc_dev, + ), + ) + ) + + return expected_inds + + +@pytest.mark.parametrize( + "dtype", + [ + "i1", + "u1", + "i2", + "u2", + "i4", + "u4", + "i8", + "u8", + "f2", + "f4", + "f8", + "c8", + "c16", + ], +) +@pytest.mark.parametrize("n", [33, 43, 255, 511, 1021, 8193]) +def test_top_k_1d_largest(dtype, n): + q = get_queue_or_skip() + skip_if_dtype_not_supported(dtype, q) + + shift, k = 734, 5 + o = dpt.ones(n, dtype=dtype) + z = dpt.zeros(n, dtype=dtype) + oz = dpt.concat((o, z)) + inp = dpt.roll(oz, shift) + + expected_inds = _expected_largest_inds(oz, n, shift, k) + + s = dpt.top_k(inp, k, mode="largest") + assert s.values.shape == (k,) + assert s.values.dtype == inp.dtype + assert s.indices.shape == (k,) + assert dpt.all(s.indices == expected_inds) + assert dpt.all(s.values == dpt.ones(k, dtype=dtype)), s.values + assert dpt.all(s.values == inp[s.indices]), s.indices + + +def _expected_smallest_inds(inp, n, shift, k): + "Computed expected top_k indices for mode='smallest'" + assert k < n + zeros_start_id = (n + shift) % (2 * n) + zeros_end_id = (shift) % (2 * n) + + alloc_dev = inp.device + + if zeros_start_id < zeros_end_id: + expected_inds = dpt.arange( + zeros_start_id, zeros_start_id + k, dtype="i8", device=alloc_dev + ) + else: + if zeros_end_id >= k: + expected_inds = dpt.arange(k, dtype="i8", device=alloc_dev) + else: + expected_inds = dpt.concat( + ( + dpt.arange(zeros_end_id, dtype="i8", device=alloc_dev), + dpt.arange( + zeros_start_id, + zeros_start_id + k - zeros_end_id, + dtype="i8", + device=alloc_dev, + ), + ) + ) + + return expected_inds + + +@pytest.mark.parametrize( + "dtype", + [ + "i1", + "u1", + "i2", + "u2", + "i4", + "u4", + "i8", + "u8", + "f2", + "f4", + "f8", + "c8", + "c16", + ], +) +@pytest.mark.parametrize("n", [37, 39, 61, 255, 257, 513, 1021, 8193]) +def test_top_k_1d_smallest(dtype, n): + q = get_queue_or_skip() + skip_if_dtype_not_supported(dtype, q) + + shift, k = 734, 5 + o = dpt.ones(n, dtype=dtype) + z = dpt.zeros(n, dtype=dtype) + oz = dpt.concat((o, z)) + inp = dpt.roll(oz, shift) + + expected_inds = _expected_smallest_inds(oz, n, shift, k) + + s = dpt.top_k(inp, k, mode="smallest") + assert s.values.shape == (k,) + assert s.values.dtype == inp.dtype + assert s.indices.shape == (k,) + assert dpt.all(s.indices == expected_inds) + assert dpt.all(s.values == dpt.zeros(k, dtype=dtype)), s.values + assert dpt.all(s.values == inp[s.indices]), s.indices + + +# triage failing top k radix implementation on CPU +# replicates from Python behavior of radix sort topk implementation +@pytest.mark.parametrize( + "n", + [ + 33, + 34, + 35, + 36, + 37, + 38, + 39, + 40, + 41, + 42, + 43, + 44, + 45, + 46, + 47, + 48, + 49, + 50, + 61, + 137, + 255, + 511, + 1021, + 8193, + ], +) +def 
test_top_k_largest_1d_radix_i1(n): + get_queue_or_skip() + dt = "i1" + + shift, k = 734, 5 + o = dpt.ones(n, dtype=dt) + z = dpt.zeros(n, dtype=dt) + oz = dpt.concat((o, z)) + inp = dpt.roll(oz, shift) + + expected_inds = _expected_largest_inds(oz, n, shift, k) + + sorted_v = dpt.sort(inp, descending=True, kind="radixsort") + argsorted = dpt.argsort(inp, descending=True, kind="radixsort") + + assert dpt.all(sorted_v == inp[argsorted]) + + topk_vals = dpt.copy(sorted_v[:k]) + topk_inds = dpt.copy(argsorted[:k]) + + assert dpt.all(topk_vals == dpt.ones(k, dtype=dt)) + assert dpt.all(topk_inds == expected_inds) + + assert dpt.all(topk_vals == inp[topk_inds]), topk_inds
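
Note (not part of the patch): the recurring change across merge_sort.hpp, radix_sort.hpp and topk.hpp is replacing paired sycl::malloc_device / sycl_free_noexcept calls with the RAII helpers added to utils/sycl_alloc_utils.hpp. The sketch below shows the intended usage pattern, assuming the helpers behave as declared in this diff: smart_malloc_device throws instead of returning nullptr and yields a std::unique_ptr owning the USM allocation, while async_smart_free schedules the deleters in a host_task ordered after the listed events and releases ownership so the memory is freed exactly once. The function scaled_copy, its kernel names, and its arguments are hypothetical.

    // Sketch only: illustrates the allocation/cleanup pattern adopted in this patch.
    #include <cstdint>
    #include <vector>
    #include <sycl/sycl.hpp>
    #include "utils/sycl_alloc_utils.hpp"

    sycl::event scaled_copy(sycl::queue &q,
                            const std::int64_t *src,
                            std::int64_t *dst,
                            std::size_t n,
                            const std::vector<sycl::event> &depends)
    {
        namespace alloc_utils = dpctl::tensor::alloc_utils;

        // throws std::runtime_error on allocation failure, so no manual
        // nullptr check is needed
        auto tmp_owner = alloc_utils::smart_malloc_device<std::int64_t>(n, q);
        std::int64_t *tmp = tmp_owner.get();

        sycl::event stage1 = q.submit([&](sycl::handler &cgh) {
            cgh.depends_on(depends);
            cgh.parallel_for<class sketch_scale_krn>(
                sycl::range<1>{n}, [=](sycl::id<1> id) {
                    const std::size_t i = id[0];
                    tmp[i] = 2 * src[i];
                });
        });

        sycl::event stage2 = q.submit([&](sycl::handler &cgh) {
            cgh.depends_on(stage1);
            cgh.parallel_for<class sketch_shift_krn>(
                sycl::range<1>{n}, [=](sycl::id<1> id) {
                    const std::size_t i = id[0];
                    dst[i] = tmp[i] + 1;
                });
        });

        // frees tmp in a host_task that depends on stage2; ownership is
        // released from tmp_owner so there is no double free
        return alloc_utils::async_smart_free(q, {stage2}, tmp_owner);
    }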
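
Note (not part of the patch): ceil_log2 replaces the previous `1 << (std::uint32_t(std::log2(n - 1)) + 1)` sizing of the uniform SLM buffer with an integer-only bit scan, avoiding floating-point rounding for large n. The host-side reference below states the intended contract; the loop form is an illustration, not the patch's implementation.

    // Reference semantics: smallest exponent e (clamped to at least 1)
    // such that n <= (uint64_t{1} << e); hence 1 << ceil_log2(n) is the
    // smallest power of two that is >= n and >= 2.
    // Assumes n <= 2^63 so the shift below stays defined.
    #include <cstdint>

    std::uint32_t ceil_log2_reference(std::uint64_t n)
    {
        std::uint32_t exp{1};
        while ((std::uint64_t{1} << exp) < n) {
            ++exp;
        }
        return exp;
    }

    // e.g. ceil_log2_reference(1024) == 10, ceil_log2_reference(1025) == 11,
    // and ceil_log2_reference(1) == 1, matching the early return in the patch.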
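
Note (not part of the patch): both map_back_impl (used by the merge- and radix-argsort paths) and the `% axis_nelems` step inside topk_detail::write_out_impl reduce a flat index of the form `row * row_size + col` back to the within-row index `col`, after the iota-filled index array has been sorted segment by segment. A serial reference of that reduction, for illustration only (the function name is hypothetical):

    #include <cstddef>
    #include <cstdint>
    #include <vector>

    // Each sorted entry encodes row * row_size + col; keep only col so the
    // result holds per-row argsort / top-k indices.
    void map_back_reference(std::vector<std::int64_t> &sorted_flat_ids,
                            std::size_t row_size)
    {
        for (auto &flat_id : sorted_flat_ids) {
            flat_id %= static_cast<std::int64_t>(row_size);
        }
    }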