Break up _populate_choose_kernel_params
py_choose now uses multiple functions to move host data for choose
kernel parameters to the device
ndgrigorian committed Jan 29, 2025
1 parent ca14052 commit 6b81516
Showing 1 changed file with 114 additions and 71 deletions.
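For readers unfamiliar with the idiom being refactored, the sketch below shows the general staging pattern the new detail:: helpers implement: host data is packed into a USM-host allocation, copied asynchronously into a device allocation, and the owning shared_ptr is kept alive by a trailing host_task. This is a simplified, hypothetical illustration (the helper name stage_to_device does not appear in the commit), not the code being added.

#include <algorithm>
#include <memory>
#include <vector>
#include <sycl/sycl.hpp>

// Hypothetical helper: pack host data into a USM-host buffer, enqueue an
// asynchronous copy into device memory, and keep the owning shared_ptr alive
// via an empty host_task until the copy has completed.
template <typename T>
sycl::event stage_to_device(sycl::queue &q, const std::vector<T> &src, T *dev_dst)
{
    using host_alloc_t = sycl::usm_allocator<T, sycl::usm::alloc::host>;
    auto host_shp = std::make_shared<std::vector<T, host_alloc_t>>(
        src.size(), host_alloc_t(q));
    std::copy(src.begin(), src.end(), host_shp->begin());

    // Asynchronous copy from the USM-host buffer into device memory.
    sycl::event copy_ev = q.copy<T>(host_shp->data(), dev_dst, host_shp->size());

    // The empty host_task exists only to extend the host buffer's lifetime
    // until the copy has finished.
    return q.submit([&](sycl::handler &cgh) {
        cgh.depends_on(copy_ev);
        cgh.host_task([host_shp]() {});
    });
}

The commit factors exactly this kind of logic into the make_host_* builders, a variadic batched_copy, and async_shp_free, as shown in the diff below.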
185 changes: 114 additions & 71 deletions dpnp/backend/extensions/indexing/choose.cpp
@@ -54,77 +54,112 @@ static kernels::choose_fn_ptr_t choose_wrap_dispatch_table[td_ns::num_types]

namespace py = pybind11;

std::vector<sycl::event>
_populate_choose_kernel_params(sycl::queue &exec_q,
std::vector<sycl::event> &host_task_events,
char **device_chc_ptrs,
py::ssize_t *device_shape_strides,
py::ssize_t *device_chc_offsets,
const py::ssize_t *shape,
int shape_len,
std::vector<py::ssize_t> &inp_strides,
std::vector<py::ssize_t> &dst_strides,
std::vector<py::ssize_t> &chc_strides,
std::vector<char *> &chc_ptrs,
std::vector<py::ssize_t> &chc_offsets,
py::ssize_t n_chcs)
namespace detail
{
using ptr_host_allocator_T =
dpctl::tensor::alloc_utils::usm_host_allocator<char *>;
using ptrT = std::vector<char *, ptr_host_allocator_T>;

ptr_host_allocator_T ptr_allocator(exec_q);
std::shared_ptr<ptrT> host_chc_ptrs_shp =
std::make_shared<ptrT>(n_chcs, ptr_allocator);
using host_ptrs_allocator_t =
dpctl::tensor::alloc_utils::usm_host_allocator<char *>;
using ptrs_t = std::vector<char *, host_ptrs_allocator_t>;
using host_ptrs_shp_t = std::shared_ptr<ptrs_t>;

using usm_host_allocatorT =
dpctl::tensor::alloc_utils::usm_host_allocator<py::ssize_t>;
using shT = std::vector<py::ssize_t, usm_host_allocatorT>;
host_ptrs_shp_t make_host_ptrs(sycl::queue &exec_q,
const std::vector<char *> &ptrs)
{
host_ptrs_allocator_t ptrs_allocator(exec_q);
host_ptrs_shp_t host_ptrs_shp =
std::make_shared<ptrs_t>(ptrs.size(), ptrs_allocator);

std::copy(ptrs.begin(), ptrs.end(), host_ptrs_shp->begin());

return host_ptrs_shp;
}

using host_sz_allocator_t =
dpctl::tensor::alloc_utils::usm_host_allocator<py::ssize_t>;
using sz_t = std::vector<py::ssize_t, host_sz_allocator_t>;
using host_sz_shp_t = std::shared_ptr<sz_t>;

usm_host_allocatorT sz_allocator(exec_q);
std::shared_ptr<shT> host_shape_strides_shp =
std::make_shared<shT>(shape_len * (3 + n_chcs), sz_allocator);
host_sz_shp_t make_host_offsets(sycl::queue &exec_q,
const std::vector<py::ssize_t> &offsets)
{
host_sz_allocator_t offsets_allocator(exec_q);
host_sz_shp_t host_offsets_shp =
std::make_shared<sz_t>(offsets.size(), offsets_allocator);

std::copy(offsets.begin(), offsets.end(), host_offsets_shp->begin());

std::shared_ptr<shT> host_chc_offsets_shp =
std::make_shared<shT>(n_chcs, sz_allocator);
return host_offsets_shp;
}

host_sz_shp_t make_host_shape_strides(sycl::queue &exec_q,
py::ssize_t n_chcs,
std::vector<py::ssize_t> &shape,
std::vector<py::ssize_t> &inp_strides,
std::vector<py::ssize_t> &dst_strides,
std::vector<py::ssize_t> &chc_strides)
{
auto nelems = shape.size();
host_sz_allocator_t shape_strides_allocator(exec_q);
host_sz_shp_t host_shape_strides_shp =
std::make_shared<sz_t>(nelems * (3 + n_chcs), shape_strides_allocator);

std::copy(shape, shape + shape_len, host_shape_strides_shp->begin());
std::copy(shape.begin(), shape.end(), host_shape_strides_shp->begin());
std::copy(inp_strides.begin(), inp_strides.end(),
host_shape_strides_shp->begin() + shape_len);
host_shape_strides_shp->begin() + nelems);
std::copy(dst_strides.begin(), dst_strides.end(),
host_shape_strides_shp->begin() + 2 * shape_len);
host_shape_strides_shp->begin() + 2 * nelems);
std::copy(chc_strides.begin(), chc_strides.end(),
host_shape_strides_shp->begin() + 3 * shape_len);
host_shape_strides_shp->begin() + 3 * nelems);

std::copy(chc_ptrs.begin(), chc_ptrs.end(), host_chc_ptrs_shp->begin());
std::copy(chc_offsets.begin(), chc_offsets.end(),
host_chc_offsets_shp->begin());
return host_shape_strides_shp;
}

const sycl::event &device_chc_ptrs_copy_ev = exec_q.copy<char *>(
host_chc_ptrs_shp->data(), device_chc_ptrs, host_chc_ptrs_shp->size());
/* This function expects a queue and one or more
std::pairs of raw device pointers and host shared pointers
(structured as <device_ptr, shared_ptr>),
and enqueues a copy of each host buffer's data into
its corresponding device pointer.
Assumes each device pointer addresses enough memory
to hold the host data.
*/
template <typename... DevHostPairs>
std::vector<sycl::event> batched_copy(sycl::queue &exec_q,
DevHostPairs &&...dev_host_pairs)
{
constexpr std::size_t n = sizeof...(DevHostPairs);
static_assert(n > 0, "batched_copy requires at least one argument");

const sycl::event &device_shape_strides_copy_ev = exec_q.copy<py::ssize_t>(
host_shape_strides_shp->data(), device_shape_strides,
host_shape_strides_shp->size());
std::vector<sycl::event> copy_evs;
copy_evs.reserve(n);
(copy_evs.emplace_back(exec_q.copy(dev_host_pairs.second->data(),
dev_host_pairs.first,
dev_host_pairs.second->size())),
...);

const sycl::event &device_chc_offsets_copy_ev = exec_q.copy<py::ssize_t>(
host_chc_offsets_shp->data(), device_chc_offsets,
host_chc_offsets_shp->size());
return copy_evs;
}

/* This function takes as input a queue, sycl::event dependencies,
and one or more shared_ptrs, and moves them into
a host_task lambda capture, keeping them alive until the
host_task executes.
*/
template <typename... Shps>
sycl::event async_shp_free(sycl::queue &exec_q,
const std::vector<sycl::event> &depends,
Shps &&...shps)
{
constexpr std::size_t n = sizeof...(Shps);
static_assert(n > 0, "async_shp_free requires at least one argument");

const sycl::event &shared_ptr_cleanup_ev =
exec_q.submit([&](sycl::handler &cgh) {
cgh.depends_on({device_chc_offsets_copy_ev,
device_shape_strides_copy_ev,
device_chc_ptrs_copy_ev});
cgh.host_task([host_chc_offsets_shp, host_shape_strides_shp,
host_chc_ptrs_shp]() {});
cgh.depends_on(depends);
cgh.host_task([capture = std::tuple(std::move(shps)...)]() {});
});
host_task_events.push_back(shared_ptr_cleanup_ev);

std::vector<sycl::event> param_pack_deps{device_chc_ptrs_copy_ev,
device_shape_strides_copy_ev,
device_chc_offsets_copy_ev};
return param_pack_deps;
return shared_ptr_cleanup_ev;
}

// copied from dpctl, remove if a similar utility is ever exposed
@@ -149,6 +184,8 @@ std::vector<dpctl::tensor::usm_ndarray> parse_py_chcs(const sycl::queue &q,
return res;
}

} // namespace detail

std::pair<sycl::event, sycl::event>
py_choose(const dpctl::tensor::usm_ndarray &src,
const py::object &py_chcs,
@@ -158,7 +195,7 @@ std::pair<sycl::event, sycl::event>
const std::vector<sycl::event> &depends)
{
std::vector<dpctl::tensor::usm_ndarray> chcs =
parse_py_chcs(exec_q, py_chcs);
detail::parse_py_chcs(exec_q, py_chcs);

// Python list max size must fit into py_ssize_t
py::ssize_t n_chcs = chcs.size();
@@ -310,31 +347,37 @@ std::pair<sycl::event, sycl::event>
host_task_events.reserve(2);

std::vector<sycl::event> pack_deps;
std::vector<py::ssize_t> common_shape;
std::vector<py::ssize_t> src_strides;
std::vector<py::ssize_t> dst_strides;
if (nd == 0) {
// special case where all inputs are scalars
// need to pass src, dst shape=1 and strides=0
// chc_strides already initialized to 0 so ignore
std::array<py::ssize_t, 1> scalar_sh{1};
std::vector<py::ssize_t> src_strides{0};
std::vector<py::ssize_t> dst_strides{0};

pack_deps = _populate_choose_kernel_params(
exec_q, host_task_events, packed_chc_ptrs.get(),
packed_shapes_strides.get(), packed_chc_offsets.get(),
scalar_sh.data(), sh_nelems, src_strides, dst_strides, chc_strides,
chc_ptrs, chc_offsets, n_chcs);
common_shape = {1};
src_strides = {0};
dst_strides = {0};
}
else {
auto src_strides = src.get_strides_vector();
auto dst_strides = dst.get_strides_vector();

pack_deps = _populate_choose_kernel_params(
exec_q, host_task_events, packed_chc_ptrs.get(),
packed_shapes_strides.get(), packed_chc_offsets.get(), src_shape,
sh_nelems, src_strides, dst_strides, chc_strides, chc_ptrs,
chc_offsets, n_chcs);
common_shape = src.get_shape_vector();
src_strides = src.get_strides_vector();
dst_strides = dst.get_strides_vector();
}

auto host_chc_ptrs = detail::make_host_ptrs(exec_q, chc_ptrs);
auto host_chc_offsets = detail::make_host_offsets(exec_q, chc_offsets);
auto host_shape_strides = detail::make_host_shape_strides(
exec_q, n_chcs, common_shape, src_strides, dst_strides, chc_strides);

pack_deps = detail::batched_copy(
exec_q, std::make_pair(packed_chc_ptrs.get(), host_chc_ptrs),
std::make_pair(packed_chc_offsets.get(), host_chc_offsets),
std::make_pair(packed_shapes_strides.get(), host_shape_strides));

host_task_events.push_back(
detail::async_shp_free(exec_q, pack_deps, host_chc_ptrs,
host_chc_offsets, host_shape_strides));

std::vector<sycl::event> all_deps;
all_deps.reserve(depends.size() + pack_deps.size());
all_deps.insert(std::end(all_deps), std::begin(pack_deps),
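As a side note on the C++17 machinery used above: batched_copy expands its parameter pack with a comma-operator fold, enqueueing one copy per <device_ptr, shared_ptr> pair, while async_shp_free keeps any number of shared_ptrs alive by moving them into a std::tuple captured by the host_task lambda. The snippet below is a minimal, SYCL-free illustration of the same fold idiom; it is not part of the commit.

#include <cstdio>
#include <utility>
#include <vector>

// One statement is emitted per element of the parameter pack, left to right,
// mirroring how batched_copy enqueues one copy per <device_ptr, host_shp> pair.
template <typename... Pairs>
std::vector<int> sum_each(Pairs &&...pairs)
{
    std::vector<int> sums;
    sums.reserve(sizeof...(Pairs));
    (sums.emplace_back(pairs.first + pairs.second), ...);
    return sums;
}

int main()
{
    auto s = sum_each(std::make_pair(1, 2), std::make_pair(3, 4), std::make_pair(5, 6));
    std::printf("%d %d %d\n", s[0], s[1], s[2]); // prints: 3 7 11
    return 0;
}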
