diff --git a/operators/cuda/cuda_ops.cc b/operators/cuda/cuda_ops.cc index 175758089..8fc3105b9 100644 --- a/operators/cuda/cuda_ops.cc +++ b/operators/cuda/cuda_ops.cc @@ -8,16 +8,12 @@ #include "cuda/fast_gelu.h" #include "cuda/mul_sigmoid.h" #include "cuda/negxplus1.h" -<<<<<<< HEAD #include "cuda/replace_zero.h" -======= #include "cuda/scatter_nd_of_shape.h" ->>>>>>> f5055466d5376059c2ea74e3cea46e16a537bc0d #include "cuda/transpose_cast.h" #endif FxLoadCustomOpFactory LoadCustomOpClasses_Contrib = []() -> CustomOpArray& { - using AddSharedInputFloat32Type = typename contrib::AddOrMulSharedInput; using MulSharedInputFloat32Type = typename contrib::AddOrMulSharedInput; @@ -28,7 +24,6 @@ FxLoadCustomOpFactory LoadCustomOpClasses_Contrib = []() -> CustomOpArray& { using Transpose2DCastFloat16ToFloat32Type = typename contrib::Transpose2DCast; #endif - static OrtOpLoader op_loader( []() { return nullptr; } #ifdef USE_CUDA diff --git a/operators/cuda/replace_zero.h b/operators/cuda/replace_zero.h index 72f760db2..e7974739d 100644 --- a/operators/cuda/replace_zero.h +++ b/operators/cuda/replace_zero.h @@ -13,6 +13,9 @@ namespace contrib { * * Y = X.copy() * X[X == 0] = c +* +* This pattern usually appears when a tensor is updated with the operators Equal and Where. +* This kernel avoids creating an intermediate zero tensor. */ template struct ReplaceZero { diff --git a/operators/cuda/replace_zero_impl.cu b/operators/cuda/replace_zero_impl.cu index f0c1414f2..43952c303 100644 --- a/operators/cuda/replace_zero_impl.cu +++ b/operators/cuda/replace_zero_impl.cu @@ -12,11 +12,13 @@ using namespace Ort::Custom; -template __device__ __inline__ T _replace_zero(const T x, const T by) { +template +__device__ __inline__ T _replace_zero(const T x, const T by) { return x == (T)0 ? 
by : x; } -template <> __device__ __inline__ half _replace_zero(const half x, const half by) { +template <> +__device__ __inline__ half _replace_zero(const half x, const half by) { #if __CUDA_ARCH__ < 700 return __half2float(x) == 0 ? by : x; #else @@ -25,21 +27,23 @@ template <> __device__ __inline__ half _replace_zero(const half x, const half by } template -__global__ void ReplaceZeroKernel(T *output_data, const T *input_data, CUDA_LONG N, const T by) { +__global__ void ReplaceZeroKernel(T* output_data, const T* input_data, CUDA_LONG N, const T by) { CUDA_LONG id = blockDim.x * blockIdx.x + threadIdx.x; if (id >= N) return; output_data[id] = _replace_zero(input_data[id], by); } -template T _cvt(float value) { return (T)value; } +template +T _cast(float value) { return (T)value; } -template <> half _cvt(float value) { return __float2half(value); } +template <> +half _cast(float value) { return __float2half(value); } template cudaError_t _LaunchReplaceZeroKernel(cudaStream_t stream, int input_length, const T* input_data, T* output_data, float by) { if (input_length == 0) - return cudaGetLastError(); + return cudaGetLastError(); using TT = typename contrib::CudaT::MappedType; CUDA_LONG N = static_cast(input_length); @@ -47,9 +51,9 @@ cudaError_t _LaunchReplaceZeroKernel(cudaStream_t stream, int input_length, cons const int num_threads_per_block = 256; const int num_elements_per_thread = (N + num_threads_per_block - 1) / num_threads_per_block; - TT cby = _cvt(by); + TT cby = _cast(by); ReplaceZeroKernel<<>>( - reinterpret_cast(output_data), reinterpret_cast(input_data), N, cby); + reinterpret_cast(output_data), reinterpret_cast(input_data), N, cby); return cudaGetLastError(); } diff --git a/operators/cuda/replace_zero_impl.cuh b/operators/cuda/replace_zero_impl.cuh index 048a39363..7d975d4d5 100644 --- a/operators/cuda/replace_zero_impl.cuh +++ b/operators/cuda/replace_zero_impl.cuh @@ -6,4 +6,4 @@ #include template -cudaError_t LaunchReplaceZeroKernel(cudaStream_t 
stream, int input_length, const T* input_data, T* output_data, float by); +cudaError_t LaunchReplaceZeroKernel(cudaStream_t stream, int input_length, const T* input_data, T* output_data, float by);