diff --git a/dpctl/tensor/libtensor/include/kernels/sorting/topk.hpp b/dpctl/tensor/libtensor/include/kernels/sorting/topk.hpp index ab3185e13c..744517b2c6 100644 --- a/dpctl/tensor/libtensor/include/kernels/sorting/topk.hpp +++ b/dpctl/tensor/libtensor/include/kernels/sorting/topk.hpp @@ -149,8 +149,9 @@ topk_full_merge_sort_impl(sycl::queue &exec_q, std::size_t src_idx = iter_gid * axis_nelems + axis_gid; std::size_t dst_idx = iter_gid * k + axis_gid; - auto res_ind = index_data[src_idx]; - vals_tp[dst_idx] = arg_tp[res_ind]; + const IndexTy res_ind = index_data[src_idx]; + const argTy v = arg_tp[res_ind]; + vals_tp[dst_idx] = v; inds_tp[dst_idx] = res_ind % axis_nelems; }); }); @@ -425,8 +426,9 @@ sycl::event topk_merge_impl( const std::size_t src_idx = iter_gid * alloc_len + axis_gid; const std::size_t dst_idx = gid; - const auto res_ind = index_data[src_idx]; - vals_tp[dst_idx] = arg_tp[res_ind]; + const IndexTy res_ind = index_data[src_idx]; + const argTy v = arg_tp[res_ind]; + vals_tp[dst_idx] = v; inds_tp[dst_idx] = (res_ind % axis_nelems); }); }); @@ -538,11 +540,14 @@ sycl::event topk_radix_impl(sycl::queue &exec_q, const std::size_t dst_idx = gid; const IndexTy res_ind = tmp_tp[src_idx]; - vals_tp[dst_idx] = arg_tp[res_ind]; + const v = arg_tp[res_ind]; + vals_tp[dst_idx] = v; inds_tp[dst_idx] = (res_ind % axis_nelems); }); }); + write_topk_ev.wait(); + sycl::event cleanup_ev = exec_q.submit([&](sycl::handler &cgh) { cgh.depends_on(write_topk_ev);