From 600230cbc4e7eff75b363204824d7c9ec3c8f1e1 Mon Sep 17 00:00:00 2001
From: ShanoToni
Date: Mon, 16 Sep 2024 17:18:03 +0100
Subject: [PATCH] gpu: nvidia: Added support for native host task extension

---
 .../cudnn_batch_normalization_executor.hpp | 34 ++++++++-----------
 src/gpu/nvidia/sycl_cuda_compat.hpp        |  8 +++--
 src/gpu/nvidia/sycl_cuda_utils.hpp         |  6 ++++
 3 files changed, 27 insertions(+), 21 deletions(-)

diff --git a/src/gpu/nvidia/cudnn_batch_normalization_executor.hpp b/src/gpu/nvidia/cudnn_batch_normalization_executor.hpp
index c9d9cc8df05..ca742dd4678 100644
--- a/src/gpu/nvidia/cudnn_batch_normalization_executor.hpp
+++ b/src/gpu/nvidia/cudnn_batch_normalization_executor.hpp
@@ -192,14 +192,12 @@ struct bnorm_exec_base_t {
             xpu::sycl::interop_memory_arg_t<::sycl::access::mode::write>
                     arg_scale,
             float val, const size_t n) const {
-        cuda_stream->interop_task([&](::sycl::handler &cgh) {
-            T *scale_ptr = static_cast<T *>(arg_scale.get_native_pointer(ih));
-            CUDA_EXECUTE_FUNC(cuMemsetD32Async,
-                    reinterpret_cast<CUdeviceptr>(scale_ptr),
-                    reinterpret_cast<const int &>(val), n,
-                    cuda_stream->get_underlying_stream());
-            cudaDeviceSynchronize();
-        });
+        T *scale_ptr = static_cast<T *>(arg_scale.get_native_pointer(ih));
+        CUDA_EXECUTE_FUNC(cuMemsetD32Async,
+                reinterpret_cast<CUdeviceptr>(scale_ptr),
+                reinterpret_cast<const int &>(val), n,
+                cuda_stream->get_underlying_stream());
+        sync_device();
     }
 
     // Handle the cases when mean and var are read-only accessors or nullptr
@@ -216,17 +214,15 @@ struct bnorm_exec_base_t {
             xpu::sycl::interop_memory_arg_t<::sycl::access_mode::write> arg_var,
             const size_t n) const {
         constexpr T mean_var_val = 0;
-        cuda_stream->interop_task([&](::sycl::handler &cgh) {
-            T *mean_ptr = static_cast<T *>(arg_mean.get_native_pointer(ih));
-            T *var_ptr = static_cast<T *>(arg_var.get_native_pointer(ih));
-            CUDA_EXECUTE_FUNC(cuMemsetD32Async,
-                    reinterpret_cast<CUdeviceptr>(mean_ptr), mean_var_val, n,
-                    cuda_stream->get_underlying_stream());
-            CUDA_EXECUTE_FUNC(cuMemsetD32Async,
-                    reinterpret_cast<CUdeviceptr>(var_ptr), mean_var_val, n,
-                    cuda_stream->get_underlying_stream());
-            cudaDeviceSynchronize();
-        });
+        T *mean_ptr = static_cast<T *>(arg_mean.get_native_pointer(ih));
+        T *var_ptr = static_cast<T *>(arg_var.get_native_pointer(ih));
+        CUDA_EXECUTE_FUNC(cuMemsetD32Async,
+                reinterpret_cast<CUdeviceptr>(mean_ptr), mean_var_val, n,
+                cuda_stream->get_underlying_stream());
+        CUDA_EXECUTE_FUNC(cuMemsetD32Async,
+                reinterpret_cast<CUdeviceptr>(var_ptr), mean_var_val, n,
+                cuda_stream->get_underlying_stream());
+        sync_device();
     }
 };
 
diff --git a/src/gpu/nvidia/sycl_cuda_compat.hpp b/src/gpu/nvidia/sycl_cuda_compat.hpp
index 197008123c0..5e6196f313d 100644
--- a/src/gpu/nvidia/sycl_cuda_compat.hpp
+++ b/src/gpu/nvidia/sycl_cuda_compat.hpp
@@ -35,9 +35,13 @@ T get_native_mem(const interop_handle &ih, U acc) {
             ih.get_native_mem<::sycl::backend::ext_oneapi_cuda>(acc));
 }
 
-template <typename T>
-void host_task(::sycl::handler &cgh, const T &task) {
+template <typename HandlerT, typename FnT>
+void host_task(HandlerT &cgh, const FnT &task) {
+#ifdef SYCL_EXT_ONEAPI_ENQUEUE_NATIVE_COMMAND
+    cgh.ext_codeplay_enqueue_native_command(task);
+#else
     cgh.host_task(task);
+#endif
 }
 
 template
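
The compat `host_task()` wrapper above now prefers the DPC++ `ext_codeplay_enqueue_native_command` extension when `SYCL_EXT_ONEAPI_ENQUEUE_NATIVE_COMMAND` is defined, and falls back to a regular SYCL `host_task` otherwise; with the native-command path the functor enqueues work onto the backend CUDA stream rather than blocking a host thread. In the same patch the batch-normalization helpers stop opening their own interop task and replace the in-lambda `cudaDeviceSynchronize()` with `sync_device()` (presumably the helper the diffstat attributes to `sycl_cuda_utils.hpp`; that hunk is not shown here). The standalone sketch below illustrates the same dispatch pattern; it assumes a DPC++ compiler targeting the CUDA backend, and the `main()` scaffolding and buffer names are illustrative rather than taken from oneDNN.

```cpp
// Sketch of the host_task dispatch pattern used by the patch (illustrative,
// not oneDNN code). Assumes DPC++ with the CUDA backend; error checking of
// CUDA driver calls is omitted for brevity.
#include <sycl/sycl.hpp>
#include <cuda.h>

// Prefer the native-command extension when the compiler provides it,
// otherwise fall back to an ordinary host task.
template <typename HandlerT, typename FnT>
void host_task(HandlerT &cgh, const FnT &task) {
#ifdef SYCL_EXT_ONEAPI_ENQUEUE_NATIVE_COMMAND
    cgh.ext_codeplay_enqueue_native_command(task);
#else
    cgh.host_task(task);
#endif
}

int main() {
    sycl::queue q {sycl::gpu_selector_v};
    const size_t n = 1024;
    float *data = sycl::malloc_device<float>(n, q);
    const float val = 1.0f;

    q.submit([&](sycl::handler &cgh) {
        host_task(cgh, [=](sycl::interop_handle ih) {
            // Enqueue the fill on the backend CUDA stream; with the
            // native-command extension this does not tie up a host thread.
            CUstream stream
                    = ih.get_native_queue<sycl::backend::ext_oneapi_cuda>();
            cuMemsetD32Async(reinterpret_cast<CUdeviceptr>(data),
                    reinterpret_cast<const unsigned int &>(val), n, stream);
#ifndef SYCL_EXT_ONEAPI_ENQUEUE_NATIVE_COMMAND
            // A plain host task gives no completion guarantee for native
            // work enqueued inside it, so block on the stream here (as the
            // old code did with cudaDeviceSynchronize()).
            cuStreamSynchronize(stream);
#endif
        });
    }).wait();

    sycl::free(data, q);
    return 0;
}
```

Keeping the extension check inside the single compat wrapper means call sites such as the batch-normalization executor do not need to know which path the compiler provides.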