[rocfft][cufft] DFT update host task to use native command #578

Merged: 11 commits, Oct 14, 2024
8 changes: 4 additions & 4 deletions src/dft/backends/cufft/backward.cpp
@@ -71,7 +71,7 @@ ONEMKL_EXPORT void compute_backward(descriptor_type &desc,
auto inout_acc = inout.template get_access<sycl::access::mode::read_write>(cgh);
commit->add_buffer_workspace_dependency_if_rqd("compute_backward", cgh);

-cgh.host_task([=](sycl::interop_handle ih) {
+detail::cufft_enqueue_task(cgh, [=](sycl::interop_handle ih) {
auto stream = detail::setup_stream(func_name, ih, plan);

auto inout_native = reinterpret_cast<fwd<descriptor_type> *>(
@@ -117,7 +117,7 @@ ONEMKL_EXPORT void compute_backward(descriptor_type &desc,
auto out_acc = out.template get_access<sycl::access::mode::read_write>(cgh);
commit->add_buffer_workspace_dependency_if_rqd("compute_backward", cgh);

-cgh.host_task([=](sycl::interop_handle ih) {
+detail::cufft_enqueue_task(cgh, [=](sycl::interop_handle ih) {
auto stream = detail::setup_stream(func_name, ih, plan);

auto in_native = reinterpret_cast<void *>(
@@ -171,7 +171,7 @@ ONEMKL_EXPORT sycl::event compute_backward(descriptor_type &desc, fwd<descriptor
cgh.depends_on(dependencies);
commit->depend_on_last_usm_workspace_event_if_rqd(cgh);

-cgh.host_task([=](sycl::interop_handle ih) {
+detail::cufft_enqueue_task(cgh, [=](sycl::interop_handle ih) {
auto stream = detail::setup_stream(func_name, ih, plan);

detail::cufft_execute<detail::Direction::Backward, fwd<descriptor_type>>(
@@ -217,7 +217,7 @@ ONEMKL_EXPORT sycl::event compute_backward(descriptor_type &desc, bwd<descriptor
cgh.depends_on(dependencies);
commit->depend_on_last_usm_workspace_event_if_rqd(cgh);

-cgh.host_task([=](sycl::interop_handle ih) {
+detail::cufft_enqueue_task(cgh, [=](sycl::interop_handle ih) {
auto stream = detail::setup_stream(func_name, ih, plan);

detail::cufft_execute<detail::Direction::Backward, fwd<descriptor_type>>(
27 changes: 26 additions & 1 deletion src/dft/backends/cufft/execute_helper.hpp
@@ -125,12 +125,16 @@ void cufft_execute(const std::string &func, CUstream stream, cufftHandle plan, v
}
}
}

#ifndef SYCL_EXT_ONEAPI_ENQUEUE_NATIVE_COMMAND
// If not using the enqueue native command extension, the host task must wait
// for the asynchronous operation to complete. Otherwise the host task would
// report the operation as complete too early.
auto result = cuStreamSynchronize(stream);
if (result != CUDA_SUCCESS) {
throw oneapi::mkl::exception("dft/backends/cufft", func,
"cuStreamSynchronize returned " + std::to_string(result));
}
#endif
}
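To make the synchronization concern concrete, here is a minimal illustrative fragment (not from the PR; error handling is elided and the plan/pointer setup is assumed) showing why the plain host-task path must block:

#include <sycl/sycl.hpp>
#include <cufft.h>

// Illustrative only: with a plain host_task, the returned SYCL event
// completes when the lambda returns, not when the asynchronous cuFFT work
// on the stream finishes.
sycl::event launch_fft(sycl::queue &q, cufftHandle plan, cufftComplex *in,
                       cufftComplex *out) {
    return q.submit([&](sycl::handler &cgh) {
        cgh.host_task([=](sycl::interop_handle ih) {
            // cufftExecC2C only enqueues the transform on the plan's stream.
            cufftExecC2C(plan, in, out, CUFFT_FORWARD);
            // Without a cuStreamSynchronize on that stream here, the host
            // task returns while the FFT may still be running, so the SYCL
            // event fires too early.
        });
    });
}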

inline CUstream setup_stream(const std::string &func, sycl::interop_handle ih, cufftHandle plan) {
@@ -143,6 +147,27 @@ inline CUstream setup_stream(const std::string &func, sycl::interop_handle ih, c
return stream;
}


/** Wrap the interop API used to launch an interop task.
 *
 * @tparam HandlerT The command group handler type
 * @tparam FnT The body of the enqueued task
 *
 * Uses the enqueue native command extension when available, and falls back to
 * the host task interop API otherwise. The extension avoids blocking the host
 * until the enqueued CUDA work has completed.
 */
template <typename HandlerT, typename FnT>
static inline void cufft_enqueue_task(HandlerT&& cgh, FnT&& f) {
#ifdef SYCL_EXT_ONEAPI_ENQUEUE_NATIVE_COMMAND
cgh.ext_codeplay_enqueue_native_command([=](sycl::interop_handle ih){
#else
cgh.host_task([=](sycl::interop_handle ih){
#endif
f(std::move(ih));
});
}

} // namespace oneapi::mkl::dft::cufft::detail

#endif

Review comment (Contributor): Why is it needed to duplicate this in both cuFFT and rocFFT backends and in various domains (BLAS, LAPACK, FFT)? Can't we have one wrapper used across all domains and backends?

Reply (Contributor): My understanding is that the domains have always been separated on purpose. This makes review much easier, as any change affecting common code would technically require an approval from every domain owner. I agree this could be discussed in an issue; to my knowledge there is very little code that could be common across domains, other than the types and exceptions, which are already common.
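As a sketch of what the reviewer's suggestion could look like, a single helper hoisted into code shared across backends might be written as follows (hypothetical: the helper name, namespace, and location are invented for illustration and are not part of this PR):

#include <sycl/sycl.hpp>

namespace oneapi::mkl::detail {

// Hypothetical common wrapper: prefer the native command extension when the
// compiler exposes it, otherwise fall back to a plain host task. In the
// fallback, the task body remains responsible for stream synchronization.
template <typename HandlerT, typename FnT>
inline void enqueue_native_or_host_task(HandlerT &cgh, FnT f) {
#ifdef SYCL_EXT_ONEAPI_ENQUEUE_NATIVE_COMMAND
    cgh.ext_codeplay_enqueue_native_command(
        [=](sycl::interop_handle ih) { f(ih); });
#else
    cgh.host_task([=](sycl::interop_handle ih) { f(ih); });
#endif
}

} // namespace oneapi::mkl::detail

Each backend would then call this helper instead of keeping its own copy; the trade-off, per the reply above, is that any change to it would touch every domain.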
8 changes: 4 additions & 4 deletions src/dft/backends/cufft/forward.cpp
@@ -74,7 +74,7 @@ ONEMKL_EXPORT void compute_forward(descriptor_type &desc,
auto inout_acc = inout.template get_access<sycl::access::mode::read_write>(cgh);
commit->add_buffer_workspace_dependency_if_rqd("compute_forward", cgh);

-cgh.host_task([=](sycl::interop_handle ih) {
+detail::cufft_enqueue_task(cgh, [=](sycl::interop_handle ih) {
auto stream = detail::setup_stream(func_name, ih, plan);

auto inout_native = reinterpret_cast<fwd<descriptor_type> *>(
@@ -119,7 +119,7 @@ ONEMKL_EXPORT void compute_forward(descriptor_type &desc, sycl::buffer<fwd<descr
auto out_acc = out.template get_access<sycl::access::mode::read_write>(cgh);
commit->add_buffer_workspace_dependency_if_rqd("compute_forward", cgh);

-cgh.host_task([=](sycl::interop_handle ih) {
+detail::cufft_enqueue_task(cgh, [=](sycl::interop_handle ih) {
auto stream = detail::setup_stream(func_name, ih, plan);

auto in_native = reinterpret_cast<void *>(
@@ -173,7 +173,7 @@ ONEMKL_EXPORT sycl::event compute_forward(descriptor_type &desc, fwd<descriptor_
cgh.depends_on(dependencies);
commit->depend_on_last_usm_workspace_event_if_rqd(cgh);

-cgh.host_task([=](sycl::interop_handle ih) {
+detail::cufft_enqueue_task(cgh, [=](sycl::interop_handle ih) {
auto stream = detail::setup_stream(func_name, ih, plan);

detail::cufft_execute<detail::Direction::Forward, fwd<descriptor_type>>(
@@ -219,7 +219,7 @@ ONEMKL_EXPORT sycl::event compute_forward(descriptor_type &desc, fwd<descriptor_
cgh.depends_on(dependencies);
commit->depend_on_last_usm_workspace_event_if_rqd(cgh);

-cgh.host_task([=](sycl::interop_handle ih) {
+detail::cufft_enqueue_task(cgh, [=](sycl::interop_handle ih) {
auto stream = detail::setup_stream(func_name, ih, plan);

detail::cufft_execute<detail::Direction::Forward, fwd<descriptor_type>>(
41 changes: 17 additions & 24 deletions src/dft/backends/rocfft/backward.cpp
@@ -78,14 +78,13 @@ ONEMKL_EXPORT void compute_backward(descriptor_type &desc,
auto inout_acc = inout.template get_access<sycl::access::mode::read_write>(cgh);
commit->add_buffer_workspace_dependency_if_rqd("compute_backward", cgh);

-cgh.host_task([=](sycl::interop_handle ih) {
+detail::rocfft_enqueue_task(cgh, [=](sycl::interop_handle ih) {
auto stream = detail::setup_stream(func_name, ih, info);

auto inout_native = reinterpret_cast<void *>(
reinterpret_cast<fwd<descriptor_type> *>(detail::native_mem(ih, inout_acc)) +
offsets[0]);
-detail::execute_checked(func_name, plan, &inout_native, nullptr, info);
-detail::sync_checked(func_name, stream);
+detail::execute_checked(func_name, stream, plan, &inout_native, nullptr, info);
});
});
}
@@ -113,7 +112,7 @@ ONEMKL_EXPORT void compute_backward(descriptor_type &desc,
auto inout_im_acc = inout_im.template get_access<sycl::access::mode::read_write>(cgh);
commit->add_buffer_workspace_dependency_if_rqd("compute_backward", cgh);

-cgh.host_task([=](sycl::interop_handle ih) {
+detail::rocfft_enqueue_task(cgh, [=](sycl::interop_handle ih) {
auto stream = detail::setup_stream(func_name, ih, info);

std::array<void *, 2> inout_native{
@@ -124,8 +123,7 @@ ONEMKL_EXPORT void compute_backward(descriptor_type &desc,
detail::native_mem(ih, inout_im_acc)) +
offsets[0])
};
-detail::execute_checked(func_name, plan, inout_native.data(), nullptr, info);
-detail::sync_checked(func_name, stream);
+detail::execute_checked(func_name, stream, plan, inout_native.data(), nullptr, info);
});
});
}
@@ -148,7 +146,7 @@ ONEMKL_EXPORT void compute_backward(descriptor_type &desc,
auto out_acc = out.template get_access<sycl::access::mode::read_write>(cgh);
commit->add_buffer_workspace_dependency_if_rqd("compute_backward", cgh);

-cgh.host_task([=](sycl::interop_handle ih) {
+detail::rocfft_enqueue_task(cgh, [=](sycl::interop_handle ih) {
const std::string func_name = "compute_backward(desc, in, out)";
auto stream = detail::setup_stream(func_name, ih, info);

@@ -158,8 +156,7 @@ ONEMKL_EXPORT void compute_backward(descriptor_type &desc,
auto out_native = reinterpret_cast<void *>(
reinterpret_cast<fwd<descriptor_type> *>(detail::native_mem(ih, out_acc)) +
offsets[1]);
-detail::execute_checked(func_name, plan, &in_native, &out_native, info);
-detail::sync_checked(func_name, stream);
+detail::execute_checked(func_name, stream, plan, &in_native, &out_native, info);
});
});
}
@@ -184,7 +181,7 @@ ONEMKL_EXPORT void compute_backward(descriptor_type &desc,
auto out_im_acc = out_im.template get_access<sycl::access::mode::read_write>(cgh);
commit->add_buffer_workspace_dependency_if_rqd("compute_backward", cgh);

-cgh.host_task([=](sycl::interop_handle ih) {
+detail::rocfft_enqueue_task(cgh, [=](sycl::interop_handle ih) {
const std::string func_name = "compute_backward(desc, in_re, in_im, out_re, out_im)";
auto stream = detail::setup_stream(func_name, ih, info);

@@ -204,8 +201,7 @@ ONEMKL_EXPORT void compute_backward(descriptor_type &desc,
detail::native_mem(ih, out_im_acc)) +
offsets[1])
};
-detail::execute_checked(func_name, plan, in_native.data(), out_native.data(), info);
-detail::sync_checked(func_name, stream);
+detail::execute_checked(func_name, stream, plan, in_native.data(), out_native.data(), info);
});
});
}
@@ -239,12 +235,11 @@ ONEMKL_EXPORT sycl::event compute_backward(descriptor_type &desc, fwd<descriptor
cgh.depends_on(deps);
commit->depend_on_last_usm_workspace_event_if_rqd(cgh);

-cgh.host_task([=](sycl::interop_handle ih) {
+detail::rocfft_enqueue_task(cgh, [=](sycl::interop_handle ih) {
auto stream = detail::setup_stream(func_name, ih, info);

void *inout_ptr = inout;
-detail::execute_checked(func_name, plan, &inout_ptr, nullptr, info);
-detail::sync_checked(func_name, stream);
+detail::execute_checked(func_name, stream, plan, &inout_ptr, nullptr, info);
});
});
commit->set_last_usm_workspace_event_if_rqd(sycl_event);
@@ -273,12 +268,12 @@ ONEMKL_EXPORT sycl::event compute_backward(descriptor_type &desc, scalar<descrip
cgh.depends_on(deps);
commit->depend_on_last_usm_workspace_event_if_rqd(cgh);

-cgh.host_task([=](sycl::interop_handle ih) {
+detail::rocfft_enqueue_task(cgh, [=](sycl::interop_handle ih) {
auto stream = detail::setup_stream(func_name, ih, info);

std::array<void *, 2> inout_native{ inout_re + offsets[0], inout_im + offsets[0] };
-detail::execute_checked(func_name, plan, inout_native.data(), nullptr, info);
-detail::sync_checked(func_name, stream);
+detail::execute_checked(func_name, stream, plan, inout_native.data(), nullptr, info);
});
});
commit->set_last_usm_workspace_event_if_rqd(sycl_event);
@@ -305,14 +300,13 @@ ONEMKL_EXPORT sycl::event compute_backward(descriptor_type &desc, bwd<descriptor
cgh.depends_on(deps);
commit->depend_on_last_usm_workspace_event_if_rqd(cgh);

-cgh.host_task([=](sycl::interop_handle ih) {
+detail::rocfft_enqueue_task(cgh, [=](sycl::interop_handle ih) {
const std::string func_name = "compute_backward(desc, in, out, deps)";
auto stream = detail::setup_stream(func_name, ih, info);

void *in_ptr = in;
void *out_ptr = out;
-detail::execute_checked(func_name, plan, &in_ptr, &out_ptr, info);
-detail::sync_checked(func_name, stream);
+detail::execute_checked(func_name, stream, plan, &in_ptr, &out_ptr, info);
});
});
commit->set_last_usm_workspace_event_if_rqd(sycl_event);
@@ -336,15 +330,14 @@ ONEMKL_EXPORT sycl::event compute_backward(descriptor_type &desc, scalar<descrip
cgh.depends_on(deps);
commit->depend_on_last_usm_workspace_event_if_rqd(cgh);

-cgh.host_task([=](sycl::interop_handle ih) {
+detail::rocfft_enqueue_task(cgh, [=](sycl::interop_handle ih) {
const std::string func_name =
"compute_backward(desc, in_re, in_im, out_re, out_im, deps)";
auto stream = detail::setup_stream(func_name, ih, info);

std::array<void *, 2> in_native{ in_re + offsets[0], in_im + offsets[0] };
std::array<void *, 2> out_native{ out_re + offsets[1], out_im + offsets[1] };
-detail::execute_checked(func_name, plan, in_native.data(), out_native.data(), info);
-detail::sync_checked(func_name, stream);
+detail::execute_checked(func_name, stream, plan, in_native.data(), out_native.data(), info);
});
});
commit->set_last_usm_workspace_event_if_rqd(sycl_event);
38 changes: 32 additions & 6 deletions src/dft/backends/rocfft/execute_helper.hpp
@@ -76,20 +76,46 @@ inline hipStream_t setup_stream(const std::string &func, sycl::interop_handle &i
}

inline void sync_checked(const std::string &func, hipStream_t stream) {
auto result = hipStreamSynchronize(stream);
if (result != hipSuccess) {
throw oneapi::mkl::exception("dft/backends/rocfft", func,
"hipStreamSynchronize returned " + std::to_string(result));
}
}

-inline void execute_checked(const std::string &func, const rocfft_plan plan, void *in_buffer[],
+inline void execute_checked(const std::string &func, hipStream_t stream, const rocfft_plan plan, void *in_buffer[],
void *out_buffer[], rocfft_execution_info info) {
auto result = rocfft_execute(plan, in_buffer, out_buffer, info);
if (result != rocfft_status_success) {
throw oneapi::mkl::exception("dft/backends/rocfft", func,
"rocfft_execute returned " + std::to_string(result));
}
#ifndef SYCL_EXT_ONEAPI_ENQUEUE_NATIVE_COMMAND
// If not using the enqueue native command extension, the host task must wait
// for the asynchronous operation to complete. Otherwise the host task would
// report the operation as complete too early.
sync_checked(func, stream);
#endif
}

/** Wrap the interop API used to launch an interop task.
 *
 * @tparam HandlerT The command group handler type
 * @tparam FnT The body of the enqueued task
 *
 * Uses the enqueue native command extension when available, and falls back to
 * the host task interop API otherwise. The extension avoids blocking the host
 * until the enqueued HIP work has completed.
 */
template <typename HandlerT, typename FnT>
static inline void rocfft_enqueue_task(HandlerT&& cgh, FnT&& f) {
#ifdef SYCL_EXT_ONEAPI_ENQUEUE_NATIVE_COMMAND
cgh.ext_codeplay_enqueue_native_command([=](sycl::interop_handle ih){
#else
cgh.host_task([=](sycl::interop_handle ih){
#endif
f(std::move(ih));
});
}

} // namespace oneapi::mkl::dft::rocfft::detail
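Condensed from the rocFFT call sites above, the submission pattern after this change looks as follows (a sketch; func_name, plan, info, inout, and deps are as in the diffs, with setup elided):

sycl::event sycl_event = queue.submit([&](sycl::handler &cgh) {
    cgh.depends_on(deps);
    detail::rocfft_enqueue_task(cgh, [=](sycl::interop_handle ih) {
        auto stream = detail::setup_stream(func_name, ih, info);
        void *inout_ptr = inout;
        // execute_checked now receives the stream so that it can call
        // hipStreamSynchronize only in the host-task fallback path.
        detail::execute_checked(func_name, stream, plan, &inout_ptr, nullptr, info);
    });
});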