diff --git a/dpctl/tensor/libtensor/include/kernels/sorting/radix_sort.hpp b/dpctl/tensor/libtensor/include/kernels/sorting/radix_sort.hpp index 84edb1dd3b..de0011e2ab 100644 --- a/dpctl/tensor/libtensor/include/kernels/sorting/radix_sort.hpp +++ b/dpctl/tensor/libtensor/include/kernels/sorting/radix_sort.hpp @@ -1588,11 +1588,11 @@ sycl::event parallel_radix_sort_impl(sycl::queue &exec_q, using CountT = std::uint32_t; // memory for storing count and offset values - CountT *count_ptr = - sycl::malloc_device<CountT>(n_iters * n_counts, exec_q); - if (nullptr == count_ptr) { - throw std::runtime_error("Could not allocate USM-device memory"); - } + auto count_owner = + dpctl::tensor::alloc_utils::smart_malloc_device<CountT>( + n_iters * n_counts, exec_q); + + CountT *count_ptr = count_owner.get(); constexpr std::uint32_t zero_radix_iter{0}; @@ -1605,25 +1605,17 @@ sycl::event parallel_radix_sort_impl(sycl::queue &exec_q, n_counts, count_ptr, proj_op, is_ascending, depends); - sort_ev = exec_q.submit([=](sycl::handler &cgh) { - cgh.depends_on(sort_ev); - const sycl::context &ctx = exec_q.get_context(); - - using dpctl::tensor::alloc_utils::sycl_free_noexcept; - cgh.host_task( - [ctx, count_ptr]() { sycl_free_noexcept(count_ptr, ctx); }); - }); + sort_ev = dpctl::tensor::alloc_utils::async_smart_free( + exec_q, {sort_ev}, count_owner); return sort_ev; } - ValueT *tmp_arr = - sycl::malloc_device<ValueT>(n_iters * n_to_sort, exec_q); - if (nullptr == tmp_arr) { - using dpctl::tensor::alloc_utils::sycl_free_noexcept; - sycl_free_noexcept(count_ptr, exec_q); - throw std::runtime_error("Could not allocate USM-device memory"); - } + auto tmp_arr_owner = + dpctl::tensor::alloc_utils::smart_malloc_device<ValueT>( + n_iters * n_to_sort, exec_q); + + ValueT *tmp_arr = tmp_arr_owner.get(); // iterations per each bucket assert("Number of iterations must be even" && radix_iters % 2 == 0); @@ -1657,17 +1649,8 @@ sycl::event parallel_radix_sort_impl(sycl::queue &exec_q, } } - sort_ev = 
exec_q.submit([=](sycl::handler &cgh) { - cgh.depends_on(sort_ev); - - const sycl::context &ctx = exec_q.get_context(); - - using dpctl::tensor::alloc_utils::sycl_free_noexcept; - cgh.host_task([ctx, count_ptr, tmp_arr]() { - sycl_free_noexcept(tmp_arr, ctx); - sycl_free_noexcept(count_ptr, ctx); - }); - }); + sort_ev = dpctl::tensor::alloc_utils::async_smart_free( + exec_q, {sort_ev}, tmp_arr_owner, count_owner); } return sort_ev; @@ -1769,13 +1752,12 @@ radix_argsort_axis1_contig_impl(sycl::queue &exec_q, reinterpret_cast<IndexTy *>(res_cp) + iter_res_offset + sort_res_offset; const std::size_t total_nelems = iter_nelems * sort_nelems; - const std::size_t padded_total_nelems = ((total_nelems + 63) / 64) * 64; - IndexTy *workspace = sycl::malloc_device<IndexTy>( - padded_total_nelems + total_nelems, exec_q); + auto workspace_owner = + dpctl::tensor::alloc_utils::smart_malloc_device<IndexTy>(total_nelems, + exec_q); - if (nullptr == workspace) { - throw std::runtime_error("Could not allocate workspace on device"); - } + // get raw USM pointer + IndexTy *workspace = workspace_owner.get(); using IdentityProjT = radix_sort_details::IdentityProj; using IndexedProjT = @@ -1820,14 +1802,8 @@ radix_argsort_axis1_contig_impl(sycl::queue &exec_q, }); }); - sycl::event cleanup_ev = exec_q.submit([&](sycl::handler &cgh) { - cgh.depends_on(map_back_ev); - - const sycl::context &ctx = exec_q.get_context(); - - using dpctl::tensor::alloc_utils::sycl_free_noexcept; - cgh.host_task([ctx, workspace] { sycl_free_noexcept(workspace, ctx); }); - }); + sycl::event cleanup_ev = dpctl::tensor::alloc_utils::async_smart_free( + exec_q, {map_back_ev}, workspace_owner); return cleanup_ev; } diff --git a/dpctl/tensor/libtensor/include/kernels/sorting/topk.hpp b/dpctl/tensor/libtensor/include/kernels/sorting/topk.hpp index 6698aa50de..d1d686a2dd 100644 --- a/dpctl/tensor/libtensor/include/kernels/sorting/topk.hpp +++ b/dpctl/tensor/libtensor/include/kernels/sorting/topk.hpp @@ -30,9 +30,10 @@ #include #include #include 
-#include #include +#include + +#include "kernels/dpctl_tensor_types.hpp" #include "kernels/sorting/merge_sort.hpp" #include "kernels/sorting/radix_sort.hpp" @@ -90,11 +91,11 @@ topk_full_merge_sort_impl(sycl::queue &exec_q, const CompT &comp, const std::vector<sycl::event> &depends) { - IndexTy *index_data = - sycl::malloc_device<IndexTy>(iter_nelems * axis_nelems, exec_q); - if (index_data == nullptr) { - throw std::runtime_error("Unable to allocate device_memory"); - } + auto index_data_owner = + dpctl::tensor::alloc_utils::smart_malloc_device<IndexTy>( + iter_nelems * axis_nelems, exec_q); + // extract USM pointer + IndexTy *index_data = index_data_owner.get(); using IotaKernelName = topk_populate_index_data_krn; @@ -153,14 +154,8 @@ topk_full_merge_sort_impl(sycl::queue &exec_q, }); sycl::event cleanup_host_task_event = - exec_q.submit([&](sycl::handler &cgh) { - cgh.depends_on(write_out_ev); - const sycl::context &ctx = exec_q.get_context(); - - using dpctl::tensor::alloc_utils::sycl_free_noexcept; - cgh.host_task( - [ctx, index_data] { sycl_free_noexcept(index_data, ctx); }); - }); + dpctl::tensor::alloc_utils::async_smart_free(exec_q, {write_out_ev}, + index_data_owner); return cleanup_host_task_event; }; @@ -283,11 +278,11 @@ sycl::event topk_merge_impl( index_comp, depends); } - IndexTy *index_data = - sycl::malloc_device<IndexTy>(iter_nelems * alloc_len, exec_q); - if (index_data == nullptr) { - throw std::runtime_error("Unable to allocate device_memory"); - } + auto index_data_owner = + dpctl::tensor::alloc_utils::smart_malloc_device<IndexTy>( + iter_nelems * alloc_len, exec_q); + // get raw USM pointer + IndexTy *index_data = index_data_owner.get(); // no need to populate index data: SLM will be populated with default // values @@ -427,14 +422,8 @@ sycl::event topk_merge_impl( }); sycl::event cleanup_host_task_event = - exec_q.submit([&](sycl::handler &cgh) { - cgh.depends_on(write_topk_ev); - const sycl::context &ctx = exec_q.get_context(); - - using dpctl::tensor::alloc_utils::sycl_free_noexcept; - 
cgh.host_task( - [ctx, index_data] { sycl_free_noexcept(index_data, ctx); }); - }); + dpctl::tensor::alloc_utils::async_smart_free( + exec_q, {write_topk_ev}, index_data_owner); return cleanup_host_task_event; } @@ -474,15 +463,13 @@ sycl::event topk_radix_impl(sycl::queue &exec_q, const std::size_t total_nelems = iter_nelems * axis_nelems; const std::size_t padded_total_nelems = ((total_nelems + 63) / 64) * 64; - IndexTy *workspace = sycl::malloc_device<IndexTy>( - padded_total_nelems + total_nelems, exec_q); + auto workspace_owner = + dpctl::tensor::alloc_utils::smart_malloc_device<IndexTy>( + padded_total_nelems + total_nelems, exec_q); - IndexTy *tmp_tp = sycl::malloc_device<IndexTy>(total_nelems, exec_q); - - if (nullptr == workspace || nullptr == tmp_tp) { - throw std::runtime_error( - "Not enough device memory for radix sort topk"); - } + // get raw USM pointer + IndexTy *workspace = workspace_owner.get(); + IndexTy *tmp_tp = workspace + padded_total_nelems; using IdentityProjT = radix_sort_details::IdentityProj; using IndexedProjT = @@ -536,17 +523,8 @@ }); }); - sycl::event cleanup_ev = exec_q.submit([&](sycl::handler &cgh) { - cgh.depends_on(write_topk_ev); - - const sycl::context &ctx = exec_q.get_context(); - - using dpctl::tensor::alloc_utils::sycl_free_noexcept; - cgh.host_task([ctx, workspace, tmp_tp] { - sycl_free_noexcept(workspace, ctx); - sycl_free_noexcept(tmp_tp, ctx); - }); - }); + sycl::event cleanup_ev = dpctl::tensor::alloc_utils::async_smart_free( + exec_q, {write_topk_ev}, workspace_owner); return cleanup_ev; } diff --git a/dpctl/tensor/libtensor/include/utils/sycl_alloc_utils.hpp b/dpctl/tensor/libtensor/include/utils/sycl_alloc_utils.hpp index 3ad5f6f36a..87476221d8 100644 --- a/dpctl/tensor/libtensor/include/utils/sycl_alloc_utils.hpp +++ b/dpctl/tensor/libtensor/include/utils/sycl_alloc_utils.hpp @@ -28,6 +28,9 @@ #include #include +#include <memory> +#include <stdexcept> +#include <vector> #include "sycl/sycl.hpp" @@ -73,11 +76,99 @@ void 
sycl_free_noexcept(T *ptr, const sycl::context &ctx) noexcept { } } -template <typename T> void sycl_free_noexcept(T *ptr, sycl::queue &q) noexcept +template <typename T> +void sycl_free_noexcept(T *ptr, const sycl::queue &q) noexcept { sycl_free_noexcept(ptr, q.get_context()); } +class USMDeleter +{ +private: + sycl::context ctx_; + +public: + USMDeleter(const sycl::queue &q) : ctx_(q.get_context()) {} + USMDeleter(const sycl::context &ctx) : ctx_(ctx) {} + + template <typename T> void operator()(T *ptr) const + { + sycl_free_noexcept(ptr, ctx_); + } +}; + +template <typename T> +std::unique_ptr<T, USMDeleter> +smart_malloc(std::size_t count, + const sycl::queue &q, + sycl::usm::alloc kind, + const sycl::property_list &propList = {}) +{ + T *ptr = sycl::malloc<T>(count, q, kind, propList); + if (nullptr == ptr) { + throw std::runtime_error("Unable to allocate device_memory"); + } + + auto usm_deleter = USMDeleter(q); + return std::unique_ptr<T, USMDeleter>(ptr, usm_deleter); +} + +template <typename T> +std::unique_ptr<T, USMDeleter> +smart_malloc_device(std::size_t count, + const sycl::queue &q, + const sycl::property_list &propList = {}) +{ + return smart_malloc<T>(count, q, sycl::usm::alloc::device, propList); +} + +template <typename T> +std::unique_ptr<T, USMDeleter> +smart_malloc_shared(std::size_t count, + const sycl::queue &q, + const sycl::property_list &propList = {}) +{ + return smart_malloc<T>(count, q, sycl::usm::alloc::shared, propList); +} + +template <typename T> +std::unique_ptr<T, USMDeleter> +smart_malloc_host(std::size_t count, + const sycl::queue &q, + const sycl::property_list &propList = {}) +{ + return smart_malloc<T>(count, q, sycl::usm::alloc::host, propList); +} + +template <typename... Args> +sycl::event async_smart_free(sycl::queue &exec_q, + const std::vector<sycl::event> &depends, + Args &&...args) +{ + constexpr std::size_t n = sizeof...(Args); + + std::vector<void *> ptrs; + ptrs.reserve(n); + (ptrs.push_back(reinterpret_cast<void *>(args.get())), ...); + + std::vector<USMDeleter> dels; + dels.reserve(n); + (dels.push_back(args.get_deleter()), ...); + + sycl::event ht_e = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); + + cgh.host_task([ptrs, 
dels]() { + for (size_t i = 0; i < ptrs.size(); ++i) { + dels[i](ptrs[i]); + } + }); + }); + (args.release(), ...); + + return ht_e; +} + } // end of namespace alloc_utils } // end of namespace tensor } // end of namespace dpctl