From 27d4b5c586c802172cbba0bb7c3c70cc2384cf6b Mon Sep 17 00:00:00 2001
From: Xavier Dupre
Date: Thu, 2 May 2024 16:45:48 +0200
Subject: [PATCH] clang

---
 operators/contrib/cuda/scatter_nd_of_shape.cu | 135 +++++++++---------
 operators/contrib/cuda/scatter_nd_of_shape.h  |  33 ++---
 2 files changed, 88 insertions(+), 80 deletions(-)

diff --git a/operators/contrib/cuda/scatter_nd_of_shape.cu b/operators/contrib/cuda/scatter_nd_of_shape.cu
index b9461189b..35c6aab34 100644
--- a/operators/contrib/cuda/scatter_nd_of_shape.cu
+++ b/operators/contrib/cuda/scatter_nd_of_shape.cu
@@ -5,7 +5,8 @@
 
 namespace ortops {
 
-#define _ENFORCE(cond, msg) if (!(cond)) ORTX_CXX_API_THROW(msg, ORT_RUNTIME_EXCEPTION);
+#define _ENFORCE(cond, msg) \
+  if (!(cond)) ORTX_CXX_API_THROW(msg, ORT_RUNTIME_EXCEPTION);
 
 #ifndef HIP_LONG
 #define HIP_LONG int32_t
@@ -17,14 +18,16 @@ namespace ortops {
 
 struct GridDim {
   enum : CUDA_LONG {
-    maxThreadsPerBlock = 256, // max threads per block
-    maxElementsPerThread = 4, // max element processed per thread
+    maxThreadsPerBlock = 256,   // max threads per block
+    maxElementsPerThread = 4,   // max element processed per thread
   };
 };
 
-template <typename T> __device__ __forceinline__ void _add_inplace(T &x, const T a) { x += a; }
+template <typename T>
+__device__ __forceinline__ void _add_inplace(T& x, const T a) { x += a; }
 
-template<> __device__ __forceinline__ void _add_inplace(half &x, const half a) {
+template <>
+__device__ __forceinline__ void _add_inplace(half& x, const half a) {
 #if __CUDA_ARCH__ < 700
   x = __float2half(__half2float(x) + __half2float(a));
 #else
@@ -34,8 +37,8 @@
 template <typename T>
 __global__ void
-addition_inplace_kernel(T *__restrict__ output_data, const int64_t *__restrict__ indices_data,
-                        const T *__restrict__ updates_data, const CUDA_LONG indice_size,
+addition_inplace_kernel(T* __restrict__ output_data, const int64_t* __restrict__ indices_data,
+                        const T* __restrict__ updates_data, const CUDA_LONG indice_size,
                         const CUDA_LONG nrows, const CUDA_LONG stride) {
   HIP_LONG id = blockDim.x * blockIdx.x + threadIdx.x;
   if (id >= stride)
     return;
@@ -55,56 +58,59 @@ addition_inplace_kernel(T *__restrict__ output_data, const int64_t *__restrict__
 //////////////////
 
 template <typename T>
-void *ScatterNDOfShapeOp<T>::CreateKernel(const OrtApi &api, const OrtKernelInfo *info) const {
+void* ScatterNDOfShapeOp<T>::CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const {
   return std::make_unique<ScatterNDOfShapeKernel<T>>(api, info).release();
 }
 
-template <typename T> const char *ScatterNDOfShapeOp<T>::GetName() const {
+template <typename T>
+const char* ScatterNDOfShapeOp<T>::GetName() const {
   return "ScatterNDOfShape";
 }
 
-template <typename T> const char *ScatterNDOfShapeOp<T>::GetExecutionProviderType() const {
+template <typename T>
+const char* ScatterNDOfShapeOp<T>::GetExecutionProviderType() const {
   return "CUDAExecutionProvider";
 }
 
-template <typename T> size_t ScatterNDOfShapeOp<T>::GetInputTypeCount() const { return 3; };
+template <typename T>
+size_t ScatterNDOfShapeOp<T>::GetInputTypeCount() const { return 3; };
 
 template <>
 ONNXTensorElementDataType ScatterNDOfShapeOp<float>::GetInputType(std::size_t index) const {
   switch (index) {
-  case 0:
-  case 1:
-    return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
-  case 2:
-    return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
-  default:
-    ORTX_CXX_API_THROW("Wrong input index.", ORT_RUNTIME_EXCEPTION);
+    case 0:
+    case 1:
+      return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
+    case 2:
+      return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
+    default:
+      ORTX_CXX_API_THROW("Wrong input index.", ORT_RUNTIME_EXCEPTION);
   }
 }
 
 template <>
 ONNXTensorElementDataType ScatterNDOfShapeOp<half>::GetInputType(std::size_t index) const {
   switch (index) {
-  case 0:
-  case 1:
-    return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
-  case 2:
-    return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16;
-  default:
-    ORTX_CXX_API_THROW("Wrong input index.", ORT_RUNTIME_EXCEPTION);
+    case 0:
+    case 1:
+      return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
+    case 2:
+      return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16;
+    default:
+      ORTX_CXX_API_THROW("Wrong input index.", ORT_RUNTIME_EXCEPTION);
   }
 }
 
 template <typename T>
 OrtMemType ScatterNDOfShapeOp<T>::GetInputMemoryType(std::size_t index) const {
   switch (index) {
-  case 0:
-    return OrtMemTypeCPUInput;
-  case 1:
-  case 2:
-    return OrtMemTypeDefault;
-  default:
-    ORTX_CXX_API_THROW("Wrong input index.", ORT_RUNTIME_EXCEPTION);
+    case 0:
+      return OrtMemTypeCPUInput;
+    case 1:
+    case 2:
+      return OrtMemTypeDefault;
+    default:
+      ORTX_CXX_API_THROW("Wrong input index.", ORT_RUNTIME_EXCEPTION);
   }
 }
 
@@ -112,25 +118,26 @@
 template <typename T>
 OrtCustomOpInputOutputCharacteristic
 ScatterNDOfShapeOp<T>::GetInputCharacteristic(std::size_t index) const {
   switch (index) {
-  case 0:
-  case 1:
-  case 2:
-    return OrtCustomOpInputOutputCharacteristic::INPUT_OUTPUT_REQUIRED;
-  default:
-    ORTX_CXX_API_THROW("Wrong output index.", ORT_RUNTIME_EXCEPTION);
+    case 0:
+    case 1:
+    case 2:
+      return OrtCustomOpInputOutputCharacteristic::INPUT_OUTPUT_REQUIRED;
+    default:
+      ORTX_CXX_API_THROW("Wrong output index.", ORT_RUNTIME_EXCEPTION);
   }
 }
 
-template <typename T> size_t ScatterNDOfShapeOp<T>::GetOutputTypeCount() const { return 1; }
+template <typename T>
+size_t ScatterNDOfShapeOp<T>::GetOutputTypeCount() const { return 1; }
 
 template <>
 ONNXTensorElementDataType ScatterNDOfShapeOp<float>::GetOutputType(std::size_t index) const {
   // D, scale D
   switch (index) {
-  case 0:
-    return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
-  default:
-    ORTX_CXX_API_THROW("Wrong output index.", ORT_RUNTIME_EXCEPTION);
+    case 0:
+      return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
+    default:
+      ORTX_CXX_API_THROW("Wrong output index.", ORT_RUNTIME_EXCEPTION);
   }
 }
 
@@ -138,10 +145,10 @@
 template <>
 ONNXTensorElementDataType ScatterNDOfShapeOp<half>::GetOutputType(std::size_t index) const {
   // D, scale D
   switch (index) {
-  case 0:
-    return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16;
-  default:
-    ORTX_CXX_API_THROW("Wrong output index.", ORT_RUNTIME_EXCEPTION);
+    case 0:
+      return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16;
+    default:
+      ORTX_CXX_API_THROW("Wrong output index.", ORT_RUNTIME_EXCEPTION);
   }
 }
 
@@ -149,10 +156,10 @@
 template <typename T>
 OrtCustomOpInputOutputCharacteristic
 ScatterNDOfShapeOp<T>::GetOutputCharacteristic(std::size_t index) const {
   switch (index) {
-  case 0:
-    return OrtCustomOpInputOutputCharacteristic::INPUT_OUTPUT_REQUIRED;
-  default:
-    ORTX_CXX_API_THROW("Wrong output index", ORT_RUNTIME_EXCEPTION);
+    case 0:
+      return OrtCustomOpInputOutputCharacteristic::INPUT_OUTPUT_REQUIRED;
+    default:
+      ORTX_CXX_API_THROW("Wrong output index", ORT_RUNTIME_EXCEPTION);
   }
 }
 
@@ -161,8 +168,8 @@ ScatterNDOfShapeOp<T>::GetOutputCharacteristic(std::size_t index) const {
 ///////////////////
 
 template <typename T>
-ScatterNDOfShapeKernel<T>::ScatterNDOfShapeKernel(const OrtApi &api,
-                                                  const OrtKernelInfo *info) {
+ScatterNDOfShapeKernel<T>::ScatterNDOfShapeKernel(const OrtApi& api,
+                                                  const OrtKernelInfo* info) {
   char value_string[1000];
   std::size_t size = 1000;
   ThrowOnError(api, api.KernelInfoGetAttribute_string(info, "reduction", value_string, &size));
@@ -178,7 +185,8 @@ ScatterNDOfShapeKernel<T>::ScatterNDOfShapeKernel(const OrtApi &api,
   maxThreadPerBlock_ = prop.maxThreadsPerBlock;
 }
 
-template <typename T> void ScatterNDOfShapeKernel<T>::Compute(OrtKernelContext *context) {
+template <typename T>
+void ScatterNDOfShapeKernel<T>::Compute(OrtKernelContext* context) {
   Ort::KernelContext ctx(context);
 
   int n_inputs = ctx.GetInputCount();
@@ -197,13 +205,13 @@ template <typename T> void ScatterNDOfShapeKernel<T>::Compute(OrtKernelContext *
 
   auto memi = updates.GetTensorMemoryInfo();
   _ENFORCE(memi.GetDeviceType() == OrtMemoryInfoDeviceType::OrtMemoryInfoDeviceType_GPU,
-             "Tensor updates is not on GPU.");
+           "Tensor updates is not on GPU.");
 
   auto mem = shape.GetTensorMemoryInfo();
   _ENFORCE(
       mem.GetDeviceType() == OrtMemoryInfoDeviceType::OrtMemoryInfoDeviceType_CPU,
       "Input shape is not on CPU.");
-  const int64_t *X = shape.GetTensorData<int64_t>();
+  const int64_t* X = shape.GetTensorData<int64_t>();
   std::vector<int64_t> dims(X, X + dimensions[0]);
   output = ctx.GetOutput(0, dims);
 
@@ -212,12 +220,11 @@ template <typename T> void ScatterNDOfShapeKernel<T>::Compute(OrtKernelContext *
   if (reduction_ == Reduction::Add && indices_shape[indices_shape.size() - 1] == 1 &&
       input_shape.size() == 2 &&
       input_shape[input_shape.size() - 1] >= maxThreadPerBlock_) {
-
     size_t indice_size = static_cast<size_t>(onnx_c_ops::flattened_dimension(indices_shape));
     size_t update_size = static_cast<size_t>(onnx_c_ops::flattened_dimension(updates_shape));
 
     _ENFORCE(update_size == indice_size * input_shape[input_shape.size() - 1],
-               "Size mismatch.");
+             "Size mismatch.");
 
     ComputeNoAtomic(stream, input_shape, indices_shape, output.GetTensorMutableData<T>(),
                     indices.GetTensorData<int64_t>(), updates.GetTensorData<T>());
@@ -227,11 +234,11 @@ template <typename T> void ScatterNDOfShapeKernel<T>::Compute(OrtKernelContext *
 }
 
 template <typename T>
-void ScatterNDOfShapeKernel<T>::ComputeNoAtomic(cudaStream_t &stream,
-                                                const std::vector<int64_t> &input_shape,
-                                                const std::vector<int64_t> &indices_shape,
-                                                T *output_data, const int64_t *indices_data,
-                                                const T *updates_data) const {
+void ScatterNDOfShapeKernel<T>::ComputeNoAtomic(cudaStream_t& stream,
+                                                const std::vector<int64_t>& input_shape,
+                                                const std::vector<int64_t>& indices_shape,
+                                                T* output_data, const int64_t* indices_data,
+                                                const T* updates_data) const {
   // The kernel is slow if there are a lot of duplicates.
   // reduction_ == Reduction::add
   // indices_shape[indices_shape.size() - 1] == 1
@@ -257,4 +264,4 @@ void ScatterNDOfShapeKernel<T>::ComputeNoAtomic(cudaStream_t &stream,
 static ScatterNDOfShapeOp<float> _op32;
 static ScatterNDOfShapeOp<half> _op16;
 
-} // namespace ortops
+}  // namespace ortops
diff --git a/operators/contrib/cuda/scatter_nd_of_shape.h b/operators/contrib/cuda/scatter_nd_of_shape.h
index 72d725cca..b5a924427 100644
--- a/operators/contrib/cuda/scatter_nd_of_shape.h
+++ b/operators/contrib/cuda/scatter_nd_of_shape.h
@@ -15,18 +15,19 @@ enum class Reduction : int {
 };
 
 /**
-* This kernel implementation the fusion of ConstantOfShape and ScatterND.
-* The implementation does not use OrtLiteCustom as the input shape (first input)
-* is expected to be on CPU wheeras the other outputs are expected to be on CUDA.
-*/
-template <typename T> struct ScatterNDOfShapeKernel {
-  ScatterNDOfShapeKernel(const OrtApi &api, const OrtKernelInfo *info);
-  void Compute(OrtKernelContext *context);
-
-private:
-  void ComputeNoAtomic(cudaStream_t &stream, const std::vector<int64_t> &input_shape,
-                       const std::vector<int64_t> &indices_shape, T *output_data,
-                       const int64_t *indices_data, const T *updates_data) const;
+ * This kernel implements the fusion of ConstantOfShape and ScatterND.
+ * The implementation does not use OrtLiteCustom as the input shape (first input)
+ * is expected to be on CPU whereas the other inputs are expected to be on CUDA.
+ */
+template <typename T>
+struct ScatterNDOfShapeKernel {
+  ScatterNDOfShapeKernel(const OrtApi& api, const OrtKernelInfo* info);
+  void Compute(OrtKernelContext* context);
+
+ private:
+  void ComputeNoAtomic(cudaStream_t& stream, const std::vector<int64_t>& input_shape,
+                       const std::vector<int64_t>& indices_shape, T* output_data,
+                       const int64_t* indices_data, const T* updates_data) const;
 
   Reduction reduction_;
   int maxThreadPerBlock_;
@@ -37,9 +38,9 @@ struct ScatterNDOfShapeOp : Ort::CustomOpBase<ScatterNDOfShapeOp<T>, ScatterNDOfShapeKernel<T>> {
   typedef Ort::CustomOpBase<ScatterNDOfShapeOp<T>, ScatterNDOfShapeKernel<T>> parent_type;
   ScatterNDOfShapeOp() : parent_type() {}
 
-  void *CreateKernel(const OrtApi &api, const OrtKernelInfo *info) const;
-  const char *GetName() const;
-  const char *GetExecutionProviderType() const;
+  void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const;
+  const char* GetName() const;
+  const char* GetExecutionProviderType() const;
 
   std::size_t GetInputTypeCount() const;
   ONNXTensorElementDataType GetInputType(std::size_t index) const;
@@ -51,4 +52,4 @@ struct ScatterNDOfShapeOp
   OrtCustomOpInputOutputCharacteristic GetOutputCharacteristic(std::size_t index) const;
 };
 
-} // namespace ortops
+}  // namespace ortops
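
Reference note (not part of the patch): ScatterNDOfShape fuses ConstantOfShape and ScatterND, so with reduction="add" the output of the requested shape starts at zero and each row of updates is accumulated into the output row named by the matching entry of indices. The sketch below is a plain C++ CPU rendering of that semantics for the 2-D case handled by addition_inplace_kernel (one row index per update row); the function name and the flat std::vector layout are illustrative assumptions, not part of the operator's API.

#include <algorithm>
#include <cstdint>
#include <vector>

// CPU reference for ScatterNDOfShape with reduction="add" on a 2-D output.
// output holds nrows * stride elements, indices holds one row index per
// update row, updates holds indices.size() * stride elements.
template <typename T>
void scatter_nd_of_shape_add_reference(std::vector<T>& output,
                                       const std::vector<int64_t>& indices,
                                       const std::vector<T>& updates,
                                       int64_t stride) {
  // ConstantOfShape part: the destination starts as zeros.
  std::fill(output.begin(), output.end(), T(0));
  // ScatterND part: accumulate each update row into the addressed output row.
  for (size_t i = 0; i < indices.size(); ++i) {
    const int64_t row = indices[i];
    for (int64_t j = 0; j < stride; ++j) {
      output[row * stride + j] += updates[i * stride + j];
    }
  }
}

For example, with an output shape of {4, 3}, indices {0, 2, 0} and three update rows, rows 0 and 2 of the output receive the sums of the corresponding update rows; duplicate indices simply accumulate, which is also why the GPU kernel above serializes over indices per output column instead of using atomics.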