Skip to content

Commit

Permalink
Merge branch 'main' into scatter
Browse files Browse the repository at this point in the history
  • Loading branch information
xadupre authored May 30, 2024
2 parents e244b1a + b60df02 commit 17db995
Showing 1 changed file with 6 additions and 5 deletions.
11 changes: 6 additions & 5 deletions operators/cuda/negxplus1.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,29 +4,30 @@
#pragma once
#include "ocos.h"
#include "negxplus1_impl.cuh"
#include "ortx_common.h"

namespace contrib {

template <typename T>
struct NegXPlus1 {
template <typename TDict>
OrtStatusPtr OnModelAttach(const TDict& /*dict*/) {
return nullptr;
OrtxStatus OnModelAttach(const TDict& /*dict*/) {
return {};
}
OrtStatusPtr Compute(Ort::Custom::CUDAKernelContext* ctx,
OrtxStatus Compute(Ort::Custom::CUDAKernelContext* ctx,
const ortc::Tensor<T>& input,
ortc::Tensor<T>& output) const {
const T* input_data = input.Data();
T* output_data = output.Allocate(input.Shape());
auto input_length = input.NumberOfElement();
if (0 == input_length) {
return nullptr;
return {};
}
LaunchNegXPlus1Kernel<T>(reinterpret_cast<cudaStream_t>(ctx->GetCudaStream()),
input_length,
input_data,
output_data);
return nullptr;
return {};
}
};

Expand Down

0 comments on commit 17db995

Please sign in to comment.