first draft for NegXPlus1
xadupre committed May 6, 2024
1 parent e645cda commit 6d2442d
Showing 4 changed files with 165 additions and 60 deletions.
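The commit adds a NegXPlus1 custom CUDA operator that computes Y = 1 - X elementwise: the kernel and launcher (operators/cuda/neg1plusx_impl.cu), the launcher declaration (neg1plusx_impl.cuh), the ONNX custom-op wrapper (operators/cuda/negxplus1.h), and a CUDA test that checks the op against a plain ONNX Sub(one, X) graph in float and float16.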
39 changes: 39 additions & 0 deletions operators/cuda/neg1plusx_impl.cu
@@ -0,0 +1,39 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "device_prop.cuh"
#include "utils.cuh"
#include "fast_gelu_impl.cuh"

using namespace Ort::Custom;

template <typename T>
__device__ __inline__ T _neg1plusx(const T x) {
return (T)1 - x;
}

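// Specialization for half: below sm_70 the computation falls back to float
// and converts the result back to half.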
template <>
__device__ __inline__ half _neg1plusx(const half x) {
#if __CUDA_ARCH__ < 700
return __float2half(1 - __half2float(x));
#else
return (half)1 - x;
#endif
}

template <typename T>
__global__ void NegXPlus1Kernel(T* output_data, const T* input_data, int N) {
  // One thread per element; the bound check handles the tail of the last block.
  int id = blockDim.x * blockIdx.x + threadIdx.x;
  if (id >= N)
    return;
  output_data[id] = _neg1plusx(input_data[id]);
}

template <typename T>
cudaError_t LaunchNegXPlus1Kernel(cudaStream_t stream, int input_length, const T* input, T* output) {
  constexpr int blockSize = 256;
  const int gridSize = (input_length + blockSize - 1) / blockSize;
  // The kernel takes the output pointer first and the input second; the caller
  // in negxplus1.h already maps the ONNX type to its CUDA counterpart.
  NegXPlus1Kernel<T><<<gridSize, blockSize, 0, stream>>>(output, input, input_length);
  return cudaGetLastError();
}

// Explicit instantiations so the declarations in neg1plusx_impl.cuh link.
template cudaError_t LaunchNegXPlus1Kernel<float>(cudaStream_t, int, const float*, float*);
template cudaError_t LaunchNegXPlus1Kernel<half>(cudaStream_t, int, const half*, half*);
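The launch geometry rounds up so every element gets a thread: for input_length = 1000 and blockSize = 256, gridSize = (1000 + 255) / 256 = 4. Below is a minimal host-side sketch of driving the launcher directly, outside onnxruntime. It is not part of the commit; the include matches this commit's header, and the main function and buffer names are illustrative only.

// Illustrative driver (not part of the commit); compile together with the
// kernel file, e.g.  nvcc driver.cu neg1plusx_impl.cu
#include <cstdio>
#include <cuda_runtime.h>
#include "neg1plusx_impl.cuh"

int main() {
  const int n = 4;
  const float host_in[n] = {-1.0f, 0.0f, 0.5f, 2.0f};
  float host_out[n];
  float *dev_in = nullptr, *dev_out = nullptr;
  cudaMalloc(&dev_in, n * sizeof(float));
  cudaMalloc(&dev_out, n * sizeof(float));
  cudaMemcpy(dev_in, host_in, n * sizeof(float), cudaMemcpyHostToDevice);
  LaunchNegXPlus1Kernel<float>(nullptr, n, dev_in, dev_out);  // default stream
  cudaMemcpy(host_out, dev_out, n * sizeof(float), cudaMemcpyDeviceToHost);
  for (int i = 0; i < n; ++i)
    printf("1 - %g = %g\n", host_in[i], host_out[i]);  // expect 2, 1, 0.5, -1
  cudaFree(dev_in);
  cudaFree(dev_out);
  return 0;
}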
9 changes: 9 additions & 0 deletions operators/cuda/neg1plusx_impl.cuh
@@ -0,0 +1,9 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once
#include <cuda.h>
#include <cuda_runtime.h>

template <typename T>
cudaError_t LaunchNegXPlus1Kernel(cudaStream_t stream, int input_length, const T* input, T* output);
35 changes: 35 additions & 0 deletions operators/cuda/negxplus1.h
@@ -0,0 +1,35 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once
#include "ocos.h"
#include "negxplus1_impl.cuh"
#include "cuda_type.h"

namespace contrib {

template <typename T>
struct NegXPlus1 {
template <typename TDict>
OrtStatusPtr OnModelAttach(const TDict& /*dict*/) {
return nullptr;
}
OrtStatusPtr Compute(Ort::Custom::CUDAKernelContext* ctx,
const ortc::Tensor<T>& input,
ortc::Tensor<T>& output) const {
const T* input_data = input.Data();
T* output_data = output.Allocate(input.Shape());
auto input_length = input.NumberOfElement();
if (0 == input_length) {
return nullptr;
}
using TT = typename CudaT<T>::MappedType;
LaunchNegXPlus1Kernel<TT>(reinterpret_cast<cudaStream_t>(ctx->GetCudaStream()),
input_length,
reinterpret_cast<const TT*>(input_data),
reinterpret_cast<TT*>(output_data));
return nullptr;
}
};

} // namespace contrib
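The commit itself does not wire the new struct into the op loader. In this repository, CUDA custom ops are listed in operators/cuda/cuda_ops.cc; a hedged sketch of the extra entries follows, assuming the CustomCudaStructV2 helper used by the existing FastGelu op (names to be confirmed against that file).

// Sketch only -- not part of this commit. Assumes the CustomCudaStructV2
// helper used for other CUDA ops in operators/cuda/cuda_ops.cc.
#include "cuda/negxplus1.h"

// ... inside the existing custom-op list, alongside the FastGelu entries:
CustomCudaStructV2("NegXPlus1", contrib::NegXPlus1<float>),
CustomCudaStructV2("NegXPlus1", contrib::NegXPlus1<ortc::MFloat16>),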
142 changes: 82 additions & 60 deletions test/cuda/test_cudaops.py
@@ -1,7 +1,7 @@
import unittest
import numpy as np
from numpy.testing import assert_almost_equal
from onnx import helper, numpy_helper, onnx_pb as onnx_proto, TensorProto
from onnxruntime_extensions import make_onnx_model
from onnxruntime_extensions import get_library_path as _get_library_path

@@ -10,22 +10,17 @@

class TestCudaOps(unittest.TestCase):
@staticmethod
    def _create_negpos_test_model(domain="ai.onnx.contrib"):
nodes = [
            helper.make_node("Identity", ["x"], ["identity1"]),
            helper.make_node("NegPos", ["identity1"], ["neg", "pos"], domain=domain),
]

        input0 = helper.make_tensor_value_info("x", onnx_proto.TensorProto.FLOAT, [None, None])
        output1 = helper.make_tensor_value_info("neg", onnx_proto.TensorProto.FLOAT, [None, None])
        output2 = helper.make_tensor_value_info("pos", onnx_proto.TensorProto.FLOAT, [None, None])

        graph = helper.make_graph(nodes, "test0", [input0], [output1, output2])
model = make_onnx_model(graph)
return model

@@ -34,87 +29,114 @@ def test_cuda_negpos(self):
so.register_custom_ops_library(_get_library_path())
onnx_model = self._create_negpos_test_model()
self.assertIn('op_type: "NegPos"', str(onnx_model))
        sess = _ort.InferenceSession(onnx_model.SerializeToString(), so, providers=["CUDAExecutionProvider"])
        x = np.array([[0.0, 1.0, 1.5], [7.0, 8.0, -5.5]]).astype(np.float32)
        neg, pos = sess.run(None, {"x": x})
diff = x - (neg + pos)
assert_almost_equal(diff, np.zeros(diff.shape))

@staticmethod
    def _create_fastgelu_test_model(domain="ai.onnx.contrib"):
        nodes = [helper.make_node("FastGelu", ["x", "bias"], ["y"], domain=domain)]

        input0 = helper.make_tensor_value_info("x", onnx_proto.TensorProto.FLOAT, [])
        input1 = helper.make_tensor_value_info("bias", onnx_proto.TensorProto.FLOAT, [])
        output0 = helper.make_tensor_value_info("y", onnx_proto.TensorProto.FLOAT, [])

        graph = helper.make_graph(nodes, "test1", [input0, input1], [output0])
model = make_onnx_model(graph)
return model

@staticmethod
    def _create_fastgelu_test_model_f16(domain="ai.onnx.contrib"):
        nodes = [helper.make_node("FastGelu", ["x", "bias"], ["y"], domain=domain)]

        input0 = helper.make_tensor_value_info("x", onnx_proto.TensorProto.FLOAT16, [])
        input1 = helper.make_tensor_value_info("bias", onnx_proto.TensorProto.FLOAT16, [])
        output0 = helper.make_tensor_value_info("y", onnx_proto.TensorProto.FLOAT16, [])

        graph = helper.make_graph(nodes, "test1", [input0, input1], [output0])
model = make_onnx_model(graph)
return model

def test_cuda_fastgelu(self):
eps = _ort.get_available_providers()
        if "CUDAExecutionProvider" in eps:
so = _ort.SessionOptions()
so.register_custom_ops_library(_get_library_path())
onnx_model = self._create_fastgelu_test_model()
self.assertIn('op_type: "FastGelu"', str(onnx_model))
            sess = _ort.InferenceSession(onnx_model.SerializeToString(), so, providers=["CUDAExecutionProvider"])
            x = np.array([0.0, 1.0, 2.0, 3.0, 4.0, 5.0]).astype(np.float32)
bias = np.array([0.0, 0.1, 0.2, 0.3, 0.4, 0.5]).astype(np.float32)
            expected_y = np.array([0.0, 0.9505811, 2.1696784, 3.298689, 4.399991, 5.5]).astype(np.float32)
            y = sess.run(None, {"x": x, "bias": bias})[0]
assert_almost_equal(y, expected_y)
else:
            print("CUDAExecutionProvider not available, test_cuda_fastgelu skipped.")

def test_cuda_fastgelu_f16(self):
eps = _ort.get_available_providers()
        if "CUDAExecutionProvider" in eps:
so = _ort.SessionOptions()
so.register_custom_ops_library(_get_library_path())
onnx_model = self._create_fastgelu_test_model_f16()
self.assertIn('op_type: "FastGelu"', str(onnx_model))
            sess = _ort.InferenceSession(onnx_model.SerializeToString(), so, providers=["CUDAExecutionProvider"])
            x = np.array([0.0, 1.0, 2.0, 3.0, 4.0, 5.0]).astype(np.float16)
bias = np.array([0.0, 0.1, 0.2, 0.3, 0.4, 0.5]).astype(np.float16)
            expected_y = np.array([0.0, 0.95, 2.17, 3.299, 4.4, 5.5]).astype(np.float16)
            y = sess.run(None, {"x": x, "bias": bias})[0]
assert_almost_equal(y, expected_y)
else:
            print("CUDAExecutionProvider not available, test_cuda_fastgelu_f16 skipped.")

def _negxplus1_cuda(self, itype):
        import onnxruntime
        from onnx.reference import ReferenceEvaluator

dtype = np.float32 if itype == TensorProto.FLOAT else np.float16
model1 = helper.make_model(
helper.make_graph(
[helper.make_node("Sub", ["one", "X"], ["Y"])],
"nd",
[helper.make_tensor_value_info("X", itype, [None, None, None])],
[helper.make_tensor_value_info("Y", itype, [None, None, None])],
[numpy_helper.from_array(np.array([1], dtype=dtype), name="one")],
),
opset_imports=[helper.make_opsetid("", 18)],
ir_version=9,
)

model2 = helper.make_model(
helper.make_graph(
[helper.make_node("NegXplus1", ["X"], ["Y"], domain="ai.onnx.contrib")],
"nd",
[helper.make_tensor_value_info("X", itype, [None, None, None])],
[helper.make_tensor_value_info("Y", itype, [None, None, None])],
),
opset_imports=[
helper.make_opsetid("", 18),
helper.make_opsetid("ai.onnx.contrib", 1),
],
ir_version=9,
)

x = (np.arange(18) - 4).reshape((3, 2, 3)).astype(dtype)

feeds1 = dict(X=x)
        ref = ReferenceEvaluator(model1)  # the draft's CReferenceEvaluator is not defined in this file
expected = ref.run(None, feeds1)[0]

opts = onnxruntime.SessionOptions()
opts.register_custom_ops_library(_get_library_path())
sess = onnxruntime.InferenceSession(model2.SerializeToString(), opts, providers=["CUDAExecutionProvider"])
got = sess.run(None, feeds1)[0]
        np.testing.assert_allclose(expected, got, atol=1e-5)

    def test_negxplus1_cuda(self):
        if "CUDAExecutionProvider" not in _ort.get_available_providers():
            print("CUDAExecutionProvider not available, test_negxplus1_cuda skipped.")
            return
        self._negxplus1_cuda(TensorProto.FLOAT)
        self._negxplus1_cuda(TensorProto.FLOAT16)


if __name__ == "__main__":
    unittest.main()
