From 3b889fc42f561888a4eccb85e39fe521774a017f Mon Sep 17 00:00:00 2001
From: "Tang, Cheng"
Date: Tue, 30 Apr 2024 13:53:39 -0700
Subject: [PATCH] update custom op v2 struct to be able to invoke from eager mode (#700)

Co-authored-by: Cheng Tang
Co-authored-by: Wenbing Li <10278425+wenbingl@users.noreply.github.com>
---
 include/op_def_struct.h             | 39 +++++++++++++++-
 operators/contrib/cuda/fast_gelu.h  |  4 +-
 test/static_test/test_cuda_eager.cc | 71 ++++++++++++++++++++++-------
 3 files changed, 93 insertions(+), 21 deletions(-)

diff --git a/include/op_def_struct.h b/include/op_def_struct.h
index 16f396aab..63d2418a5 100644
--- a/include/op_def_struct.h
+++ b/include/op_def_struct.h
@@ -71,6 +71,35 @@ struct ComputeArgsList {
   using ResultType = RType;
 };
 
+template <typename, typename T>
+struct HasOnModelAttach {
+  static_assert(
+      std::integral_constant<T, false>::value,
+      "Second template parameter needs to be of function type.");
+};
+
+// specialization that does the checking
+
+template <typename C, typename Ret, typename... Args>
+struct HasOnModelAttach<C, Ret(Args...)> {
+private:
+  template <typename T>
+  static constexpr auto check(T*)
+      -> typename
+      std::is_same<
+          decltype(std::declval<T>().OnModelAttach(std::declval<Args>()...)),
+          Ret
+      >::type;  // attempt to call it and see if the return type is correct
+
+  template <typename>
+  static constexpr std::false_type check(...);
+
+  typedef decltype(check<C>(0)) type;
+
+public:
+  static constexpr bool value = type::value;
+};
+
 template <typename T, typename = void>
 struct CustomOp_defined_getInputMemoryType : std::false_type {};
 
@@ -96,8 +125,14 @@ struct OrtLiteCustomStructV2 : public OrtLiteCustomOp {
   template <typename T>
   static OrtStatusPtr InitKernel(KernelEx& kernel,
                                  const OrtApi& api, const OrtKernelInfo& info, RegularComputeType fn, T t) {
-    auto status = kernel.OnModelAttach(api, info);
-    return ToApiStatus(status);
+    if constexpr (HasOnModelAttach<KernelEx, OrtStatusPtr(const OrtApi&, const OrtKernelInfo&)>::value) {
+      auto status = kernel.OnModelAttach(api, info);
+      return ToApiStatus(status);
+    }
+    else {
+      auto status = kernel.OnModelAttach(OrtAttributeReader(api, info));
+      return ToApiStatus(status);
+    }
   }
 
   static OrtStatusPtr InitKernel(
diff --git a/operators/contrib/cuda/fast_gelu.h b/operators/contrib/cuda/fast_gelu.h
index 6586e6573..804792529 100644
--- a/operators/contrib/cuda/fast_gelu.h
+++ b/operators/contrib/cuda/fast_gelu.h
@@ -10,8 +10,8 @@ namespace contrib {
 
 template <typename T>
 struct FastGelu {
-  OrtStatusPtr OnModelAttach(const OrtApi& /*api*/,
-                             const OrtKernelInfo& /*info*/) {
+  template <typename TDict>
+  OrtStatusPtr OnModelAttach(const TDict& /*dict*/) {
     return nullptr;
   }
   OrtStatusPtr Compute(Ort::Custom::CUDAKernelContext* ctx,
diff --git a/test/static_test/test_cuda_eager.cc b/test/static_test/test_cuda_eager.cc
index ba2109cef..65de140de 100644
--- a/test/static_test/test_cuda_eager.cc
+++ b/test/static_test/test_cuda_eager.cc
@@ -9,45 +9,50 @@
 #ifdef USE_CUDA
 #include "math/cuda/negpos_def.h"
+#include "contrib/cuda/fast_gelu.h"
 #include <cuda_runtime.h>
 #include <cuda.h>
 
+
+class CudaAllocator : public Ort::Custom::IAllocator {
+public:
+  void* Alloc(size_t size) override {
+    void* p = nullptr;
+    cudaMalloc((void**)&p, size);
+    return p;
+  }
+  void Free(void* p) override { cudaFree(p); }
+};
+
 class MockCudaKernelContext : public Ort::Custom::CUDAKernelContext {
 public:
  MockCudaKernelContext() { cudaStreamCreate(&stream); }
  ~MockCudaKernelContext() { cudaStreamDestroy(stream); }
-  void* AllocScratchBuffer(size_t size) override { return nullptr; }
-  void FreeScratchBuffer(void* p) override {}
-  void* AllocCudaScratchBuffer(size_t size) override { return nullptr; }
-  void FreeCudaScratchBuffer(void* p) override {}
+  void* AllocScratchBuffer(size_t size) override { return malloc(size); }
+  void FreeScratchBuffer(void* p) override { return free(p); }
+  void* AllocCudaScratchBuffer(size_t size) override { return cuda_alloc.Alloc(size); }
+  void FreeCudaScratchBuffer(void* p) override { return cuda_alloc.Free(p); }
 
  void* GetCudaStream() const override { return static_cast<void*>(stream); }
  void* GetCublasHandle() const override { return nullptr; }
  int GetCudaDeviceId() const override { return 0; }
 
+  Ort::Custom::IAllocator* GetCudaAllocator() { return &cuda_alloc; }
+
 private:
  cudaStream_t stream;
-};
-
-class CudaAllocator : public Ort::Custom::IAllocator {
-public:
-  void* Alloc(size_t size) override {
-    void* p = nullptr;
-    cudaMalloc((void**)&p, size);
-    return p;
-  }
-  void Free(void* p) override { cudaFree(p); }
+  CudaAllocator cuda_alloc;
 };
 
 TEST(CudaOp, test_eager_negpos) {
   MockCudaKernelContext mock_cuda_kc;
   std::vector<float> input_data = {0.0f, 0.2f, -1.3f, 1.5f};
-  std::unique_ptr<CudaAllocator> cuda_alloc = std::make_unique<CudaAllocator>();
+  auto cuda_alloc = mock_cuda_kc.GetCudaAllocator();
   void* device_input = cuda_alloc->Alloc(sizeof(float) * input_data.size());
   cudaMemcpyAsync(device_input, input_data.data(), sizeof(float)*input_data.size(), cudaMemcpyHostToDevice, static_cast<cudaStream_t>(mock_cuda_kc.GetCudaStream()));
 
   ortc::Tensor<float> input(std::vector<int64_t>{2, 2}, device_input);
-  ortc::Tensor<float> output1(cuda_alloc.get());
-  ortc::Tensor<float> output2(cuda_alloc.get());
+  ortc::Tensor<float> output1(cuda_alloc);
+  ortc::Tensor<float> output2(cuda_alloc);
   neg_pos_cuda(&mock_cuda_kc, input, output1, output2);
 
   float* host_output1 = (float*)malloc(sizeof(float) * input_data.size());
@@ -63,4 +68,36 @@ TEST(CudaOp, test_eager_negpos) {
   free(host_output2);
 }
 
+TEST(CudaOp, test_fastgelu_eager) {
+
+  MockCudaKernelContext mock_cuda_kc;
+  // inputs
+  std::vector<float> x_data = {0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f};
+  auto cuda_alloc = mock_cuda_kc.GetCudaAllocator();
+  void* x_gpu_input = cuda_alloc->Alloc(sizeof(float) * x_data.size());
+  cudaMemcpyAsync(x_gpu_input, x_data.data(), sizeof(float)*x_data.size(), cudaMemcpyHostToDevice, static_cast<cudaStream_t>(mock_cuda_kc.GetCudaStream()));
+
+  std::vector<float> bias_data = {0.0f, 0.1f, 0.2f, 0.3f, 0.4f, 0.5f};
+  void* bias_gpu_input = cuda_alloc->Alloc(sizeof(float) * bias_data.size());
+  cudaMemcpyAsync(bias_gpu_input, bias_data.data(), sizeof(float)*bias_data.size(), cudaMemcpyHostToDevice, static_cast<cudaStream_t>(mock_cuda_kc.GetCudaStream()));
+
+  ortc::NamedArgumentDict dict({"use_half_2_"},
+                               std::make_tuple(false));
+  contrib::FastGelu<float> fast_gelu;
+  fast_gelu.OnModelAttach(dict);
+
+  ortc::Tensor<float> x(std::vector<int64_t>{6, }, x_gpu_input);
+  ortc::Tensor<float> bias(std::vector<int64_t>{6, }, bias_gpu_input);
+  ortc::Tensor<float> output(cuda_alloc);
+  fast_gelu.Compute(&mock_cuda_kc, x, &bias, output);
+
+  float* host_output = (float*)malloc(sizeof(float) * x_data.size());
+  cudaMemcpyAsync(host_output, output.DataRaw(), sizeof(float)*x_data.size(), cudaMemcpyDeviceToHost, static_cast<cudaStream_t>(mock_cuda_kc.GetCudaStream()));
+  ASSERT_NEAR(host_output[1], 0.9505811, 0.01f);
+  ASSERT_NEAR(host_output[2], 2.1696784, 0.01f);
+  ASSERT_NEAR(host_output[3], 3.298689, 0.01f);
+  ASSERT_NEAR(host_output[4], 4.399991, 0.01f);
+  ASSERT_NEAR(host_output[5], 5.5, 0.01f);
+}
+
 #endif
\ No newline at end of file