Skip to content

Commit

Permalink
update custom op v2 struct to be able to invoke from eager mode (microsoft#700)
Browse files Browse the repository at this point in the history

Co-authored-by: Cheng Tang <[email protected]>
Co-authored-by: Wenbing Li <[email protected]>
  • Loading branch information
3 people authored Apr 30, 2024
1 parent 0175f90 commit 3b889fc
Show file tree
Hide file tree
Showing 3 changed files with 93 additions and 21 deletions.
39 changes: 37 additions & 2 deletions include/op_def_struct.h
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,35 @@ struct ComputeArgsList<RType (C::*)(Args...) const> {
using ResultType = RType;
};

template<typename, typename T>
struct HasOnModelAttach {
static_assert(
std::integral_constant<T, false>::value,
"Second template parameter needs to be of function type.");
};

// specialization that does the checking

template<typename C, typename Ret, typename... Args>
struct HasOnModelAttach<C, Ret(Args...)> {
private:
template<typename T>
static constexpr auto check(T*)
-> typename
std::is_same<
decltype( std::declval<T>().OnModelAttach( std::declval<Args>()... ) ),
Ret
>::type; // attempt to call it and see if the return type is correct

template<typename>
static constexpr std::false_type check(...);

typedef decltype(check<C>(0)) type;

public:
static constexpr bool value = type::value;
};

template <typename T, typename = void>
struct CustomOp_defined_getInputMemoryType : std::false_type {};

Expand All @@ -96,8 +125,14 @@ struct OrtLiteCustomStructV2 : public OrtLiteCustomOp {
template <typename T>
// Initializes the kernel by invoking whichever OnModelAttach overload the
// kernel type provides.  The diff rendering left the stale pre-change body
// (unconditional OnModelAttach(api, info)) interleaved above the new
// if-constexpr dispatch; only the dispatching body belongs here.
// NOTE(review): fn and t are unused on this path — presumably kept for
// signature parity with the other InitKernel overload; confirm.
static OrtStatusPtr InitKernel(KernelEx& kernel,
                               const OrtApi& api, const OrtKernelInfo& info,
                               RegularComputeType fn, T t) {
  if constexpr (HasOnModelAttach<KernelEx,
                                 OrtStatusPtr(const OrtApi&, const OrtKernelInfo&)>::value) {
    // Legacy signature: kernel consumes the raw ORT api/info pair.
    auto status = kernel.OnModelAttach(api, info);
    return ToApiStatus(status);
  } else {
    // Eager-mode friendly signature: kernel consumes an attribute reader,
    // so it can also be attached without a full OrtKernelInfo.
    auto status = kernel.OnModelAttach(OrtAttributeReader(api, info));
    return ToApiStatus(status);
  }
}

static OrtStatusPtr InitKernel(
Expand Down
4 changes: 2 additions & 2 deletions operators/contrib/cuda/fast_gelu.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@ namespace contrib {

template <typename T>
struct FastGelu {
OrtStatusPtr OnModelAttach(const OrtApi& /*api*/,
const OrtKernelInfo& /*info*/) {
// Attribute-attach hook, templated on the attribute container so the kernel
// can be invoked both through the registered-op path and from eager mode.
// FastGelu reads no attributes here, so the dict is ignored and success
// (nullptr status) is returned unconditionally.
// NOTE(review): assumes TDict is a named-attribute reader such as
// OrtAttributeReader or ortc::NamedArgumentDict — confirm against callers.
template<typename TDict>
OrtStatusPtr OnModelAttach(const TDict& /*dict*/) {
return nullptr;
}
OrtStatusPtr Compute(Ort::Custom::CUDAKernelContext* ctx,
Expand Down
71 changes: 54 additions & 17 deletions test/static_test/test_cuda_eager.cc
Original file line number Diff line number Diff line change
Expand Up @@ -9,45 +9,50 @@

#ifdef USE_CUDA
#include "math/cuda/negpos_def.h"
#include "contrib/cuda/fast_gelu.h"
#include <cuda.h>
#include <cuda_runtime.h>


// Minimal device-memory allocator backed by cudaMalloc/cudaFree, used by the
// mock kernel context in these eager-mode tests.
class CudaAllocator : public Ort::Custom::IAllocator {
 public:
  void* Alloc(size_t size) override {
    void* ptr = nullptr;
    cudaMalloc(&ptr, size);
    return ptr;
  }

  void Free(void* ptr) override {
    cudaFree(ptr);
  }
};

class MockCudaKernelContext : public Ort::Custom::CUDAKernelContext {
public:
MockCudaKernelContext() { cudaStreamCreate(&stream); }
~MockCudaKernelContext() { cudaStreamDestroy(stream); }
void* AllocScratchBuffer(size_t size) override { return nullptr; }
void FreeScratchBuffer(void* p) override {}
void* AllocCudaScratchBuffer(size_t size) override { return nullptr; }
void FreeCudaScratchBuffer(void* p) override {}
void* AllocScratchBuffer(size_t size) override { return malloc(size); }
void FreeScratchBuffer(void* p) override { return free(p);}
void* AllocCudaScratchBuffer(size_t size) override { return cuda_alloc.Alloc(size); }
void FreeCudaScratchBuffer(void* p) override { return cuda_alloc.Free(p); }
void* GetCudaStream() const override { return static_cast<void*>(stream); }
void* GetCublasHandle() const override { return nullptr; }
int GetCudaDeviceId() const override { return 0; }

Ort::Custom::IAllocator* GetCudaAllocator() { return &cuda_alloc;};

private:
cudaStream_t stream;
};

class CudaAllocator : public Ort::Custom::IAllocator {
public:
void* Alloc(size_t size) override {
void* p = nullptr;
cudaMalloc((void**)&p, size);
return p;
}
void Free(void* p) override { cudaFree(p); }
CudaAllocator cuda_alloc;
};

TEST(CudaOp, test_eager_negpos) {
MockCudaKernelContext mock_cuda_kc;
std::vector<float> input_data = {0.0f, 0.2f, -1.3f, 1.5f};
std::unique_ptr<CudaAllocator> cuda_alloc = std::make_unique<CudaAllocator>();
auto cuda_alloc = mock_cuda_kc.GetCudaAllocator();
void* device_input = cuda_alloc->Alloc(sizeof(float) * input_data.size());
cudaMemcpyAsync(device_input, input_data.data(), sizeof(float)*input_data.size(), cudaMemcpyHostToDevice, static_cast<cudaStream_t>(mock_cuda_kc.GetCudaStream()));

ortc::Tensor<float> input(std::vector<int64_t>{2, 2}, device_input);
ortc::Tensor<float> output1(cuda_alloc.get());
ortc::Tensor<float> output2(cuda_alloc.get());
ortc::Tensor<float> output1(cuda_alloc);
ortc::Tensor<float> output2(cuda_alloc);
neg_pos_cuda(&mock_cuda_kc, input, output1, output2);

float* host_output1 = (float*)malloc(sizeof(float) * input_data.size());
Expand All @@ -63,4 +68,36 @@ TEST(CudaOp, test_eager_negpos) {
free(host_output2);
}

// Eager-mode smoke test for contrib::FastGelu<float>: attaches attributes via
// the NamedArgumentDict path and checks FastGelu(x + bias) elementwise.
TEST(CudaOp, test_fastgelu_eager) {
  MockCudaKernelContext mock_cuda_kc;
  auto cuda_stream = static_cast<cudaStream_t>(mock_cuda_kc.GetCudaStream());
  auto cuda_alloc = mock_cuda_kc.GetCudaAllocator();

  // inputs
  std::vector<float> x_data = {0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f};
  void* x_gpu_input = cuda_alloc->Alloc(sizeof(float) * x_data.size());
  cudaMemcpyAsync(x_gpu_input, x_data.data(), sizeof(float) * x_data.size(),
                  cudaMemcpyHostToDevice, cuda_stream);

  std::vector<float> bias_data = {0.0f, 0.1f, 0.2f, 0.3f, 0.4f, 0.5f};
  void* bias_gpu_input = cuda_alloc->Alloc(sizeof(float) * bias_data.size());
  cudaMemcpyAsync(bias_gpu_input, bias_data.data(), sizeof(float) * bias_data.size(),
                  cudaMemcpyHostToDevice, cuda_stream);

  // attach attributes through the eager-mode dictionary path;
  // BUG FIX: status returns were previously ignored.
  ortc::NamedArgumentDict dict({"use_half_2_"},
                               std::make_tuple(false));
  contrib::FastGelu<float> fast_gelu;
  ASSERT_EQ(fast_gelu.OnModelAttach(dict), nullptr);

  ortc::Tensor<float> x(std::vector<int64_t>{6, }, x_gpu_input);
  ortc::Tensor<float> bias(std::vector<int64_t>{6, }, bias_gpu_input);
  ortc::Tensor<float> output(cuda_alloc);
  ASSERT_EQ(fast_gelu.Compute(&mock_cuda_kc, x, &bias, output), nullptr);

  // std::vector instead of malloc — BUG FIX: host_output was never freed.
  std::vector<float> host_output(x_data.size());
  cudaMemcpyAsync(host_output.data(), output.DataRaw(), sizeof(float) * x_data.size(),
                  cudaMemcpyDeviceToHost, cuda_stream);
  // BUG FIX: the D2H copy is asynchronous; without this synchronize the
  // asserts below race the copy and may read stale host memory.
  cudaStreamSynchronize(cuda_stream);

  ASSERT_NEAR(host_output[0], 0.0f, 0.01f);  // gelu(0 + 0) == 0
  ASSERT_NEAR(host_output[1], 0.9505811, 0.01f);
  ASSERT_NEAR(host_output[2], 2.1696784, 0.01f);
  ASSERT_NEAR(host_output[3], 3.298689, 0.01f);
  ASSERT_NEAR(host_output[4], 4.399991, 0.01f);
  ASSERT_NEAR(host_output[5], 5.5, 0.01f);

  // BUG FIX: device input buffers were leaked.
  cuda_alloc->Free(x_gpu_input);
  cuda_alloc->Free(bias_gpu_input);
}

#endif

0 comments on commit 3b889fc

Please sign in to comment.