Skip to content

Commit

Permalink
complete
Browse files Browse the repository at this point in the history
  • Loading branch information
xadupre committed May 6, 2024
1 parent 6d2442d commit 56f979f
Show file tree
Hide file tree
Showing 5 changed files with 56 additions and 43 deletions.
5 changes: 4 additions & 1 deletion operators/cuda/cuda_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

#ifdef USE_CUDA
#include "cuda/fast_gelu.h"
#include "cuda/negxplus1.h"
#endif

FxLoadCustomOpFactory LoadCustomOpClasses_Contrib = []() -> CustomOpArray& {
Expand All @@ -13,10 +14,12 @@ FxLoadCustomOpFactory LoadCustomOpClasses_Contrib = []() -> CustomOpArray& {
#ifdef USE_CUDA
,
CustomCudaStructV2("FastGelu", contrib::FastGelu<float>),
CustomCudaStructV2("NegXPlus1", contrib::NegXPlus1<float>),
#if ORT_API_VERSION >= 16

CustomCudaStructV2("FastGelu", contrib::FastGelu<ortc::MFloat16>),
CustomCudaStructV2("FastGelu", contrib::FastGelu<ortc::BFloat16>)
CustomCudaStructV2("FastGelu", contrib::FastGelu<ortc::BFloat16>),
CustomCudaStructV2("NegXPlus1", contrib::NegXPlus1<ortc::MFloat16>)
#endif
#endif
);
Expand Down
39 changes: 0 additions & 39 deletions operators/cuda/neg1plusx_impl.cu

This file was deleted.

3 changes: 1 addition & 2 deletions operators/cuda/negxplus1.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,11 @@
#pragma once
#include "ocos.h"
#include "negxplus1_impl.cuh"
#include "cuda_type.h"

namespace contrib {

template <typename T>
struct FastGelu {
struct NegXPlus1 {
template <typename TDict>
OrtStatusPtr OnModelAttach(const TDict& /*dict*/) {
return nullptr;
Expand Down
50 changes: 50 additions & 0 deletions operators/cuda/negxplus1_impl.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "device_prop.cuh"
#include "utils.cuh"
#include "negxplus1_impl.cuh"
#include "cuda_type.h"

using namespace Ort::Custom;

// Device helper: returns (1 - x), with the constant 1 converted to T before
// the subtraction so the arithmetic is carried out in T.
template <typename T>
__device__ __inline__ T _negxplus1(const T x) {
  const T one = static_cast<T>(1);
  return one - x;
}

// Specialization of _negxplus1 for CUDA `half`.
// When __CUDA_ARCH__ is at least 700 the subtraction is done directly on
// `half` values. Otherwise (including when __CUDA_ARCH__ is undefined and the
// preprocessor evaluates it as 0) the input is widened to float, the
// subtraction is done in float, and the result is converted back to `half`.
// NOTE(review): the 700 threshold is presumably chosen to avoid relying on
// native half arithmetic on older devices - confirm against project policy.
template <>
__device__ __inline__ half _negxplus1(const half x) {
#if __CUDA_ARCH__ >= 700
  const half one = static_cast<half>(1);
  return one - x;
#else
  const float widened = __half2float(x);
  return __float2half(1.0f - widened);
#endif
}

// Element-wise kernel: output_data[i] = 1 - input_data[i] for every i in [0, N).
// Uses a flat 1D index (blockDim.x * blockIdx.x + threadIdx.x); each thread
// handles at most one element and threads whose index is >= N do nothing, so
// the launch must supply at least N threads in total along x.
template <typename T>
__global__ void NegXPlus1Kernel(T* output_data, const T* input_data, int N) {
  const int idx = blockDim.x * blockIdx.x + threadIdx.x;
  if (idx < N) {
    output_data[idx] = _negxplus1(input_data[idx]);
  }
}

// Launches NegXPlus1Kernel on `stream` to compute output[i] = 1 - input[i]
// for i in [0, input_length).
//
// Parameters:
//   stream       - CUDA stream the kernel is enqueued on.
//   input_length - number of elements to process.
//   input/output - element buffers of type T; they are reinterpreted as the
//                  type given by contrib::CudaT<T>::MappedType before launch.
//                  NOTE(review): assumed to be device-accessible pointers with
//                  at least input_length elements - confirm at call sites.
//
// Returns:
//   cudaSuccess without launching anything when input_length is 0 (a grid of
//   zero blocks is not a valid launch configuration, so the empty case must be
//   handled here), cudaErrorInvalidValue when input_length is negative, and
//   otherwise the result of cudaGetLastError() right after the launch. The
//   launch is asynchronous; errors raised while the kernel runs are not
//   reported by this function.
template <typename T>
cudaError_t _LaunchNegXPlus1Kernel(cudaStream_t stream, int input_length, const T* input, T* output) {
  if (input_length < 0) {
    return cudaErrorInvalidValue;
  }
  if (input_length == 0) {
    // Nothing to compute; avoid an invalid <<<0, ...>>> launch.
    return cudaSuccess;
  }

  constexpr int blockSize = 256;
  // Ceil-division written so it cannot overflow int: input_length >= 1 here,
  // whereas (input_length + blockSize - 1) overflows near INT_MAX.
  const int gridSize = (input_length - 1) / blockSize + 1;

  using TT = typename contrib::CudaT<T>::MappedType;
  NegXPlus1Kernel<TT><<<gridSize, blockSize, 0, stream>>>((TT*)output, (const TT*)input, input_length);
  // Reports launch-configuration errors for the launch above.
  return cudaGetLastError();
}

// float specialization of the public launcher: forwards all arguments
// unchanged to the shared templated implementation and returns its status.
template <>
cudaError_t LaunchNegXPlus1Kernel<float>(cudaStream_t stream, int input_length, const float* input, float* output) {
  const cudaError_t status = _LaunchNegXPlus1Kernel<float>(stream, input_length, input, output);
  return status;
}

// ortc::MFloat16 specialization of the public launcher: forwards all
// arguments unchanged to the shared templated implementation, which maps the
// element type through contrib::CudaT before launching the kernel.
template <>
cudaError_t LaunchNegXPlus1Kernel<ortc::MFloat16>(cudaStream_t stream, int input_length, const ortc::MFloat16* input, ortc::MFloat16* output) {
  const cudaError_t status = _LaunchNegXPlus1Kernel<ortc::MFloat16>(stream, input_length, input, output);
  return status;
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,4 @@
#include <cuda_runtime.h>

template <typename T>
cudaError_t LaunchFastGeluKernel(cudaStream_t stream, int input_length, const T* input, T* output);
cudaError_t LaunchNegXPlus1Kernel(cudaStream_t stream, int input_length, const T* input, T* output);

0 comments on commit 56f979f

Please sign in to comment.