Skip to content

Commit

Permalink
gpu: nvidia: Added support for native host task extension
Browse files Browse the repository at this point in the history
  • Loading branch information
ShanoToni committed Sep 26, 2024
1 parent da6d11b commit 856db1d
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 21 deletions.
34 changes: 15 additions & 19 deletions src/gpu/nvidia/cudnn_batch_normalization_executor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -192,14 +192,12 @@ struct bnorm_exec_base_t {
xpu::sycl::interop_memory_arg_t<::sycl::access::mode::write>
arg_scale,
float val, const size_t n) const {
cuda_stream->interop_task([&](::sycl::handler &cgh) {
T *scale_ptr = static_cast<T *>(arg_scale.get_native_pointer(ih));
CUDA_EXECUTE_FUNC(cuMemsetD32Async,
reinterpret_cast<CUdeviceptr>(scale_ptr),
reinterpret_cast<int &>(val), n,
cuda_stream->get_underlying_stream());
cudaDeviceSynchronize();
});
T *scale_ptr = static_cast<T *>(arg_scale.get_native_pointer(ih));
CUDA_EXECUTE_FUNC(cuMemsetD32Async,
reinterpret_cast<CUdeviceptr>(scale_ptr),
reinterpret_cast<int &>(val), n,
cuda_stream->get_underlying_stream());
sync_device();
}

// Handle the cases when mean and var are read-only accessors or nullptr
Expand All @@ -216,17 +214,15 @@ struct bnorm_exec_base_t {
xpu::sycl::interop_memory_arg_t<::sycl::access_mode::write> arg_var,
const size_t n) const {
constexpr T mean_var_val = 0;
cuda_stream->interop_task([&](::sycl::handler &cgh) {
T *mean_ptr = static_cast<T *>(arg_mean.get_native_pointer(ih));
T *var_ptr = static_cast<T *>(arg_var.get_native_pointer(ih));
CUDA_EXECUTE_FUNC(cuMemsetD32Async,
reinterpret_cast<CUdeviceptr>(mean_ptr), mean_var_val, n,
cuda_stream->get_underlying_stream());
CUDA_EXECUTE_FUNC(cuMemsetD32Async,
reinterpret_cast<CUdeviceptr>(var_ptr), mean_var_val, n,
cuda_stream->get_underlying_stream());
cudaDeviceSynchronize();
});
T *mean_ptr = static_cast<T *>(arg_mean.get_native_pointer(ih));
T *var_ptr = static_cast<T *>(arg_var.get_native_pointer(ih));
CUDA_EXECUTE_FUNC(cuMemsetD32Async,
reinterpret_cast<CUdeviceptr>(mean_ptr), mean_var_val, n,
cuda_stream->get_underlying_stream());
CUDA_EXECUTE_FUNC(cuMemsetD32Async,
reinterpret_cast<CUdeviceptr>(var_ptr), mean_var_val, n,
cuda_stream->get_underlying_stream());
sync_device();
}
};

Expand Down
8 changes: 6 additions & 2 deletions src/gpu/nvidia/sycl_cuda_compat.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,13 @@ T get_native_mem(const interop_handle &ih, U acc) {
ih.get_native_mem<::sycl::backend::ext_oneapi_cuda>(acc));
}

// Submits `task` onto the SYCL handler, preferring the Codeplay
// native-command extension when the implementation provides it.
//
// With SYCL_EXT_ONEAPI_ENQUEUE_NATIVE_COMMAND defined, `task` is enqueued
// as a native command: work it submits to the backend (CUDA) stream is
// tracked by the SYCL runtime, so callers do not need a host-side
// device synchronization afterwards. Without the extension, `task` runs
// as an ordinary host task and callers remain responsible for any
// required synchronization (see sync_device() in sycl_cuda_utils.hpp).
//
// HandlerT: the ::sycl::handler type (templated so the same helper works
//           with the extension's handler requirements).
// FnT:      a callable taking the interop handle provided by SYCL.
template <typename HandlerT, typename FnT>
void host_task(HandlerT &cgh, const FnT &task) {
#ifdef SYCL_EXT_ONEAPI_ENQUEUE_NATIVE_COMMAND
    cgh.ext_codeplay_enqueue_native_command(task);
#else
    cgh.host_task(task);
#endif
}

template <typename native_object_t, typename sycl_object_t,
Expand Down
6 changes: 6 additions & 0 deletions src/gpu/nvidia/sycl_cuda_utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,12 @@ inline status_t check_device(dnnl::impl::engine_kind_t eng_kind) {
: status::invalid_arguments);
}

// Blocks the host until all outstanding device work completes — but only
// when the native host-task extension is unavailable. When
// SYCL_EXT_ONEAPI_ENQUEUE_NATIVE_COMMAND is defined, native commands are
// ordered on the stream by the SYCL runtime and no explicit device-wide
// synchronization is required, so this compiles to a no-op.
// NOTE(review): the cudaDeviceSynchronize() return code is discarded here
// — presumably intentional for this best-effort path; confirm against the
// error-handling convention used elsewhere in the backend.
static void sync_device() {
#ifdef SYCL_EXT_ONEAPI_ENQUEUE_NATIVE_COMMAND
    // No-op: stream ordering is handled by the native-command extension.
#else
    cudaDeviceSynchronize();
#endif
}

static void convert_dnnl_dims_array(
const dnnl_dim_t *dims, int *new_dims, int n_dims) {
for (size_t i = 0; i < n_dims; i++) {
Expand Down

0 comments on commit 856db1d

Please sign in to comment.