diff --git a/dpnp/backend/extensions/indexing/choose.cpp b/dpnp/backend/extensions/indexing/choose.cpp index 77e1de18c12..e73a9a3892c 100644 --- a/dpnp/backend/extensions/indexing/choose.cpp +++ b/dpnp/backend/extensions/indexing/choose.cpp @@ -54,6 +54,22 @@ static kernels::choose_fn_ptr_t choose_wrap_dispatch_table[td_ns::num_types] namespace py = pybind11; +/* + Calls sycl::malloc_device(sz, q) and returns the result in a std::unique_ptr whose deleter releases it with sycl_free_noexcept(usm, q). Callers detect a failed allocation by comparing get() with nullptr. + + The deleter captures q by reference, so the returned pointer must not outlive the queue it was created from. + When the allocation has to stay alive for an asynchronous kernel, release() it and free the raw pointer from a host_task that depends on that kernel. +*/ +template +auto usm_unique_ptr(std::size_t sz, sycl::queue &q) +{ + using dpctl::tensor::alloc_utils::sycl_free_noexcept; + auto deleter = [&q](T *usm) { sycl_free_noexcept(usm, q); }; + + return std::unique_ptr(sycl::malloc_device(sz, q), + deleter); +} + std::vector _populate_choose_kernel_params(sycl::queue &exec_q, std::vector &host_task_events, @@ -279,9 +295,16 @@ std::pair chc_offsets.push_back(py::ssize_t(0)); } - char **packed_chc_ptrs = sycl::malloc_device(n_chcs, exec_q); + auto fn = mode ? 
choose_clip_dispatch_table[src_type_id][chc_type_id] + : choose_wrap_dispatch_table[src_type_id][chc_type_id]; - if (packed_chc_ptrs == nullptr) { + if (fn == nullptr) { + throw std::runtime_error("Indices must be integer type, got " + + std::to_string(src_type_id)); + } + + auto packed_chc_ptrs = usm_unique_ptr(n_chcs, exec_q); + if (packed_chc_ptrs.get() == nullptr) { throw std::runtime_error( "Unable to allocate packed_chc_ptrs device memory"); } @@ -292,23 +315,15 @@ std::pair // chcs[0].strides, // ..., // chcs[n_chcs].strides] - py::ssize_t *packed_shapes_strides = - sycl::malloc_device((3 + n_chcs) * sh_nelems, exec_q); - - if (packed_shapes_strides == nullptr) { - using dpctl::tensor::alloc_utils::sycl_free_noexcept; - sycl_free_noexcept(packed_chc_ptrs, exec_q); + auto packed_shapes_strides = + usm_unique_ptr((3 + n_chcs) * sh_nelems, exec_q); + if (packed_shapes_strides.get() == nullptr) { throw std::runtime_error( "Unable to allocate packed_shapes_strides device memory"); } - py::ssize_t *packed_chc_offsets = - sycl::malloc_device(n_chcs, exec_q); - - if (packed_chc_offsets == nullptr) { - using dpctl::tensor::alloc_utils::sycl_free_noexcept; - sycl_free_noexcept(packed_chc_ptrs, exec_q); - sycl_free_noexcept(packed_shapes_strides, exec_q); + auto packed_chc_offsets = usm_unique_ptr(n_chcs, exec_q); + if (packed_chc_offsets.get() == nullptr) { throw std::runtime_error( "Unable to allocate packed_chc_offsets device memory"); } @@ -320,9 +335,10 @@ std::pair host_task_events.reserve(2); std::vector pack_deps = _populate_choose_kernel_params( - exec_q, host_task_events, packed_chc_ptrs, packed_shapes_strides, - packed_chc_offsets, src_shape, sh_nelems, src_strides, dst_strides, - chc_strides, chc_ptrs, chc_offsets, n_chcs); + exec_q, host_task_events, packed_chc_ptrs.get(), + packed_shapes_strides.get(), packed_chc_offsets.get(), src_shape, + sh_nelems, src_strides, dst_strides, chc_strides, chc_ptrs, chc_offsets, + n_chcs); std::vector all_deps; 
all_deps.reserve(depends.size() + pack_deps.size()); @@ -330,34 +346,26 @@ std::pair std::end(pack_deps)); all_deps.insert(std::end(all_deps), std::begin(depends), std::end(depends)); - auto fn = mode ? choose_clip_dispatch_table[src_type_id][chc_type_id] - : choose_wrap_dispatch_table[src_type_id][chc_type_id]; - - if (fn == nullptr) { - sycl::event::wait(host_task_events); - using dpctl::tensor::alloc_utils::sycl_free_noexcept; - sycl_free_noexcept(packed_chc_ptrs, exec_q); - sycl_free_noexcept(packed_shapes_strides, exec_q); - sycl_free_noexcept(packed_chc_offsets, exec_q); - throw std::runtime_error("Indices must be integer type, got " + - std::to_string(src_type_id)); - } - sycl::event choose_generic_ev = - fn(exec_q, nelems, n_chcs, sh_nelems, packed_shapes_strides, src_data, - dst_data, packed_chc_ptrs, src_offset, dst_offset, - packed_chc_offsets, all_deps); + fn(exec_q, nelems, n_chcs, sh_nelems, packed_shapes_strides.get(), + src_data, dst_data, packed_chc_ptrs.get(), src_offset, dst_offset, + packed_chc_offsets.get(), all_deps); + + // release usm_unique_ptrs + auto chc_ptrs_ = packed_chc_ptrs.release(); + auto shapes_strides_ = packed_shapes_strides.release(); + auto chc_offsets_ = packed_chc_offsets.release(); // free packed temporaries sycl::event temporaries_cleanup_ev = exec_q.submit([&](sycl::handler &cgh) { cgh.depends_on(choose_generic_ev); const auto &ctx = exec_q.get_context(); + using dpctl::tensor::alloc_utils::sycl_free_noexcept; - cgh.host_task([packed_shapes_strides, packed_chc_ptrs, - packed_chc_offsets, ctx]() { - sycl_free_noexcept(packed_shapes_strides, ctx); - sycl_free_noexcept(packed_chc_ptrs, ctx); - sycl_free_noexcept(packed_chc_offsets, ctx); + cgh.host_task([chc_ptrs_, shapes_strides_, chc_offsets_, ctx]() { + sycl_free_noexcept(chc_ptrs_, ctx); + sycl_free_noexcept(shapes_strides_, ctx); + sycl_free_noexcept(chc_offsets_, ctx); }); });