Skip to content

Commit

Permalink
Use unique_ptrs for temporary device allocations in choose
Browse files Browse the repository at this point in the history
Based on suggestions by @AlexanderKalistratov

Create unique_ptr wraps a device allocation, which still needs to be manually freed
after kernel run, but will be deallocated automatically during validation leading
to launch
  • Loading branch information
ndgrigorian committed Dec 9, 2024
1 parent 6b96da3 commit 8793b54
Showing 1 changed file with 47 additions and 39 deletions.
86 changes: 47 additions & 39 deletions dpnp/backend/extensions/indexing/choose.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,22 @@ static kernels::choose_fn_ptr_t choose_wrap_dispatch_table[td_ns::num_types]

namespace py = pybind11;

/*
Returns an std::unique_ptr wrapping a USM allocation and deleter.
Must still be manually freed by host_task when allocation is needed
for duration of asynchronous kernel execution.
*/
template <typename T>
auto usm_unique_ptr(std::size_t sz, sycl::queue &q)
{
using dpctl::tensor::alloc_utils::sycl_free_noexcept;
auto deleter = [&q](T *usm) { sycl_free_noexcept(usm, q); };

return std::unique_ptr<T, decltype(deleter)>(sycl::malloc_device<T>(sz, q),
deleter);
}

std::vector<sycl::event>
_populate_choose_kernel_params(sycl::queue &exec_q,
std::vector<sycl::event> &host_task_events,
Expand Down Expand Up @@ -279,9 +295,16 @@ std::pair<sycl::event, sycl::event>
chc_offsets.push_back(py::ssize_t(0));
}

char **packed_chc_ptrs = sycl::malloc_device<char *>(n_chcs, exec_q);
auto fn = mode ? choose_clip_dispatch_table[src_type_id][chc_type_id]
: choose_wrap_dispatch_table[src_type_id][chc_type_id];

if (packed_chc_ptrs == nullptr) {
if (fn == nullptr) {
throw std::runtime_error("Indices must be integer type, got " +
std::to_string(src_type_id));
}

auto packed_chc_ptrs = usm_unique_ptr<char *>(n_chcs, exec_q);
if (packed_chc_ptrs.get() == nullptr) {
throw std::runtime_error(
"Unable to allocate packed_chc_ptrs device memory");
}
Expand All @@ -292,23 +315,15 @@ std::pair<sycl::event, sycl::event>
// chcs[0].strides,
// ...,
// chcs[n_chcs].strides]
py::ssize_t *packed_shapes_strides =
sycl::malloc_device<py::ssize_t>((3 + n_chcs) * sh_nelems, exec_q);

if (packed_shapes_strides == nullptr) {
using dpctl::tensor::alloc_utils::sycl_free_noexcept;
sycl_free_noexcept(packed_chc_ptrs, exec_q);
auto packed_shapes_strides =
usm_unique_ptr<py::ssize_t>((3 + n_chcs) * sh_nelems, exec_q);
if (packed_shapes_strides.get() == nullptr) {
throw std::runtime_error(
"Unable to allocate packed_shapes_strides device memory");
}

py::ssize_t *packed_chc_offsets =
sycl::malloc_device<py::ssize_t>(n_chcs, exec_q);

if (packed_chc_offsets == nullptr) {
using dpctl::tensor::alloc_utils::sycl_free_noexcept;
sycl_free_noexcept(packed_chc_ptrs, exec_q);
sycl_free_noexcept(packed_shapes_strides, exec_q);
auto packed_chc_offsets = usm_unique_ptr<py::ssize_t>(n_chcs, exec_q);
if (packed_chc_offsets.get() == nullptr) {
throw std::runtime_error(
"Unable to allocate packed_chc_offsets device memory");
}
Expand All @@ -320,44 +335,37 @@ std::pair<sycl::event, sycl::event>
host_task_events.reserve(2);

std::vector<sycl::event> pack_deps = _populate_choose_kernel_params(
exec_q, host_task_events, packed_chc_ptrs, packed_shapes_strides,
packed_chc_offsets, src_shape, sh_nelems, src_strides, dst_strides,
chc_strides, chc_ptrs, chc_offsets, n_chcs);
exec_q, host_task_events, packed_chc_ptrs.get(),
packed_shapes_strides.get(), packed_chc_offsets.get(), src_shape,
sh_nelems, src_strides, dst_strides, chc_strides, chc_ptrs, chc_offsets,
n_chcs);

std::vector<sycl::event> all_deps;
all_deps.reserve(depends.size() + pack_deps.size());
all_deps.insert(std::end(all_deps), std::begin(pack_deps),
std::end(pack_deps));
all_deps.insert(std::end(all_deps), std::begin(depends), std::end(depends));

auto fn = mode ? choose_clip_dispatch_table[src_type_id][chc_type_id]
: choose_wrap_dispatch_table[src_type_id][chc_type_id];

if (fn == nullptr) {
sycl::event::wait(host_task_events);
using dpctl::tensor::alloc_utils::sycl_free_noexcept;
sycl_free_noexcept(packed_chc_ptrs, exec_q);
sycl_free_noexcept(packed_shapes_strides, exec_q);
sycl_free_noexcept(packed_chc_offsets, exec_q);
throw std::runtime_error("Indices must be integer type, got " +
std::to_string(src_type_id));
}

sycl::event choose_generic_ev =
fn(exec_q, nelems, n_chcs, sh_nelems, packed_shapes_strides, src_data,
dst_data, packed_chc_ptrs, src_offset, dst_offset,
packed_chc_offsets, all_deps);
fn(exec_q, nelems, n_chcs, sh_nelems, packed_shapes_strides.get(),
src_data, dst_data, packed_chc_ptrs.get(), src_offset, dst_offset,
packed_chc_offsets.get(), all_deps);

// release usm_unique_ptrs
auto chc_ptrs_ = packed_chc_ptrs.release();
auto shapes_strides_ = packed_shapes_strides.release();
auto chc_offsets_ = packed_chc_offsets.release();

// free packed temporaries
sycl::event temporaries_cleanup_ev = exec_q.submit([&](sycl::handler &cgh) {
cgh.depends_on(choose_generic_ev);
const auto &ctx = exec_q.get_context();

using dpctl::tensor::alloc_utils::sycl_free_noexcept;
cgh.host_task([packed_shapes_strides, packed_chc_ptrs,
packed_chc_offsets, ctx]() {
sycl_free_noexcept(packed_shapes_strides, ctx);
sycl_free_noexcept(packed_chc_ptrs, ctx);
sycl_free_noexcept(packed_chc_offsets, ctx);
cgh.host_task([chc_ptrs_, shapes_strides_, chc_offsets_, ctx]() {
sycl_free_noexcept(chc_ptrs_, ctx);
sycl_free_noexcept(shapes_strides_, ctx);
sycl_free_noexcept(chc_offsets_, ctx);
});
});

Expand Down

0 comments on commit 8793b54

Please sign in to comment.