diff --git a/operators/cuda/negxplus1.h b/operators/cuda/negxplus1.h index 832da6047..5460c37a2 100644 --- a/operators/cuda/negxplus1.h +++ b/operators/cuda/negxplus1.h @@ -4,29 +4,30 @@ #pragma once #include "ocos.h" #include "negxplus1_impl.cuh" +#include "ortx_common.h" namespace contrib { template struct NegXPlus1 { template - OrtStatusPtr OnModelAttach(const TDict& /*dict*/) { - return nullptr; + OrtxStatus OnModelAttach(const TDict& /*dict*/) { + return {}; } - OrtStatusPtr Compute(Ort::Custom::CUDAKernelContext* ctx, + OrtxStatus Compute(Ort::Custom::CUDAKernelContext* ctx, const ortc::Tensor& input, ortc::Tensor& output) const { const T* input_data = input.Data(); T* output_data = output.Allocate(input.Shape()); auto input_length = input.NumberOfElement(); if (0 == input_length) { - return nullptr; + return {}; } LaunchNegXPlus1Kernel(reinterpret_cast(ctx->GetCudaStream()), input_length, input_data, output_data); - return nullptr; + return {}; } };