Skip to content

Commit

Permalink
Use of OrtxStatus in kernel NegXPlus1 (microsoft#732)
Browse files Browse the repository at this point in the history
* first draft for NegXPlus1

* complete

* fix unit test

* rename one test

* remove test if not cuda

* switch to OrtxStatus

---------

Co-authored-by: Wenbing Li <[email protected]>
  • Loading branch information
xadupre and wenbingl authored May 30, 2024
1 parent 95a49fa commit b60df02
Showing 1 changed file with 6 additions and 5 deletions.
11 changes: 6 additions & 5 deletions operators/cuda/negxplus1.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,29 +4,30 @@
#pragma once
#include "ocos.h"
#include "negxplus1_impl.cuh"
#include "ortx_common.h"

namespace contrib {

template <typename T>
struct NegXPlus1 {
template <typename TDict>
OrtStatusPtr OnModelAttach(const TDict& /*dict*/) {
return nullptr;
OrtxStatus OnModelAttach(const TDict& /*dict*/) {
return {};
}
OrtStatusPtr Compute(Ort::Custom::CUDAKernelContext* ctx,
OrtxStatus Compute(Ort::Custom::CUDAKernelContext* ctx,
const ortc::Tensor<T>& input,
ortc::Tensor<T>& output) const {
const T* input_data = input.Data();
T* output_data = output.Allocate(input.Shape());
auto input_length = input.NumberOfElement();
if (0 == input_length) {
return nullptr;
return {};
}
LaunchNegXPlus1Kernel<T>(reinterpret_cast<cudaStream_t>(ctx->GetCudaStream()),
input_length,
input_data,
output_data);
return nullptr;
return {};
}
};

Expand Down

0 comments on commit b60df02

Please sign in to comment.