From 6d2442d276a020a8a1318678bd3415a641cc9fbc Mon Sep 17 00:00:00 2001
From: xadupre
Date: Mon, 6 May 2024 11:37:06 +0000
Subject: [PATCH] first draft for NegXPlus1

---
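Note (kept below the "---" so it stays out of the commit message): NegXPlus1
computes Y = 1 - X elementwise on CUDA for float and float16 tensors. A
minimal usage sketch, assuming the build registers the op as "NegXPlus1"
under the ai.onnx.contrib domain, as the tests below do:

    import numpy as np
    import onnxruntime as ort
    from onnx import TensorProto, helper
    from onnxruntime_extensions import get_library_path, make_onnx_model

    node = helper.make_node("NegXPlus1", ["X"], ["Y"], domain="ai.onnx.contrib")
    graph = helper.make_graph(
        [node],
        "g",
        [helper.make_tensor_value_info("X", TensorProto.FLOAT, [None])],
        [helper.make_tensor_value_info("Y", TensorProto.FLOAT, [None])],
    )
    model = make_onnx_model(graph)

    so = ort.SessionOptions()
    so.register_custom_ops_library(get_library_path())
    sess = ort.InferenceSession(
        model.SerializeToString(), so, providers=["CUDAExecutionProvider"]
    )
    x = np.array([0.25, 1.5, -2.0], dtype=np.float32)
    print(sess.run(None, {"X": x})[0])  # expected: [ 0.75 -0.5   3.  ]

For float16, declare the value infos with TensorProto.FLOAT16 and feed
np.float16 data; the operator maps it to CUDA half via CudaT.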
 operators/cuda/negxplus1_impl.cu  |  45 ++++++++
 operators/cuda/negxplus1_impl.cuh |   9 ++
 operators/cuda/negxplus1.h        |  35 ++++++
 test/cuda/test_cudaops.py         | 147 +++++++++++++++++-----------
 4 files changed, 176 insertions(+), 60 deletions(-)
 create mode 100644 operators/cuda/negxplus1_impl.cu
 create mode 100644 operators/cuda/negxplus1_impl.cuh
 create mode 100644 operators/cuda/negxplus1.h

diff --git a/operators/cuda/negxplus1_impl.cu b/operators/cuda/negxplus1_impl.cu
new file mode 100644
index 000000000..fdccaf85c
--- /dev/null
+++ b/operators/cuda/negxplus1_impl.cu
@@ -0,0 +1,45 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#include "device_prop.cuh"
+#include "utils.cuh"
+#include "negxplus1_impl.cuh"
+
+using namespace Ort::Custom;
+
+template <typename T>
+__device__ __inline__ T _neg1plusx(const T x) {
+  return (T)1 - x;
+}
+
+template <>
+__device__ __inline__ half _neg1plusx(const half x) {
+#if __CUDA_ARCH__ < 700
+  // Avoid native half arithmetic below SM70: compute in float and convert back.
+  return __float2half(1 - __half2float(x));
+#else
+  return (half)1 - x;
+#endif
+}
+
+template <typename T>
+__global__ void _NegXPlus1Kernel(T* output_data, const T* input_data, int N) {
+  int id = blockDim.x * blockIdx.x + threadIdx.x;
+  if (id >= N)
+    return;
+  output_data[id] = _neg1plusx(input_data[id]);
+}
+
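+// Launches _NegXPlus1Kernel with one thread per element: gridSize blocks of
+// blockSize threads on the caller's stream; returns any launch error.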
+template <typename T>
+cudaError_t LaunchNegXPlus1Kernel(cudaStream_t stream, int input_length, const T* input, T* output) {
+  constexpr int blockSize = 256;
+  const int gridSize = (input_length + blockSize - 1) / blockSize;
+  _NegXPlus1Kernel<T><<<gridSize, blockSize, 0, stream>>>(output, input, input_length);
+  return cudaGetLastError();
+}
+
+// Explicit instantiations for the types the operator header maps through CudaT.
+template cudaError_t LaunchNegXPlus1Kernel<float>(cudaStream_t, int, const float*, float*);
+template cudaError_t LaunchNegXPlus1Kernel<half>(cudaStream_t, int, const half*, half*);
diff --git a/operators/cuda/negxplus1_impl.cuh b/operators/cuda/negxplus1_impl.cuh
new file mode 100644
index 000000000..d6fec1cc8
--- /dev/null
+++ b/operators/cuda/negxplus1_impl.cuh
@@ -0,0 +1,9 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#pragma once
+#include <cuda.h>
+#include <cuda_runtime.h>
+
+template <typename T>
+cudaError_t LaunchNegXPlus1Kernel(cudaStream_t stream, int input_length, const T* input, T* output);
\ No newline at end of file
diff --git a/operators/cuda/negxplus1.h b/operators/cuda/negxplus1.h
new file mode 100644
index 000000000..d31885827
--- /dev/null
+++ b/operators/cuda/negxplus1.h
@@ -0,0 +1,35 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#pragma once
+#include "ocos.h"
+#include "negxplus1_impl.cuh"
+#include "cuda_type.h"
+
+namespace contrib {
+
+template <typename T>
+struct NegXPlus1 {
+  template <typename TDict>
+  OrtStatusPtr OnModelAttach(const TDict& /*dict*/) {
+    return nullptr;
+  }
+  OrtStatusPtr Compute(Ort::Custom::CUDAKernelContext* ctx,
+                       const ortc::Tensor<T>& input,
+                       ortc::Tensor<T>& output) const {
+    const T* input_data = input.Data();
+    T* output_data = output.Allocate(input.Shape());
+    auto input_length = input.NumberOfElement();
+    if (0 == input_length) {
+      return nullptr;
+    }
+    using TT = typename CudaT<T>::MappedType;
+    LaunchNegXPlus1Kernel<TT>(reinterpret_cast<cudaStream_t>(ctx->GetCudaStream()),
+                              static_cast<int>(input_length),
+                              reinterpret_cast<const TT*>(input_data),
+                              reinterpret_cast<TT*>(output_data));
+    return nullptr;
+  }
+};
+
+}  // namespace contrib
\ No newline at end of file
diff --git a/test/cuda/test_cudaops.py b/test/cuda/test_cudaops.py
index d868fe675..be4bedd4a 100644
--- a/test/cuda/test_cudaops.py
+++ b/test/cuda/test_cudaops.py
@@ -1,7 +1,7 @@
 import unittest
 import numpy as np
 from numpy.testing import assert_almost_equal
-from onnx import helper, onnx_pb as onnx_proto
+from onnx import helper, numpy_helper, onnx_pb as onnx_proto, TensorProto
 from onnxruntime_extensions import make_onnx_model
 from onnxruntime_extensions import get_library_path as _get_library_path
 
@@ -10,22 +10,17 @@ class TestCudaOps(unittest.TestCase):
     @staticmethod
-    def _create_negpos_test_model(domain='ai.onnx.contrib'):
+    def _create_negpos_test_model(domain="ai.onnx.contrib"):
         nodes = [
-            helper.make_node('Identity', ['x'], ['identity1']),
-            helper.make_node(
-                'NegPos', ['identity1'], ['neg', 'pos'],
-                domain=domain)
+            helper.make_node("Identity", ["x"], ["identity1"]),
+            helper.make_node("NegPos", ["identity1"], ["neg", "pos"], domain=domain),
         ]
 
-        input0 = helper.make_tensor_value_info(
-            'x', onnx_proto.TensorProto.FLOAT, [None, None])
-        output1 = helper.make_tensor_value_info(
-            'neg', onnx_proto.TensorProto.FLOAT, [None, None])
-        output2 = helper.make_tensor_value_info(
-            'pos', onnx_proto.TensorProto.FLOAT, [None, None])
+        input0 = helper.make_tensor_value_info("x", onnx_proto.TensorProto.FLOAT, [None, None])
+        output1 = helper.make_tensor_value_info("neg", onnx_proto.TensorProto.FLOAT, [None, None])
+        output2 = helper.make_tensor_value_info("pos", onnx_proto.TensorProto.FLOAT, [None, None])
 
-        graph = helper.make_graph(nodes, 'test0', [input0], [output1, output2])
+        graph = helper.make_graph(nodes, "test0", [input0], [output1, output2])
         model = make_onnx_model(graph)
         return model
 
     def test_cuda_negpos(self):
         so = _ort.SessionOptions()
@@ -34,88 +29,120 @@ def test_cuda_negpos(self):
         so.register_custom_ops_library(_get_library_path())
         onnx_model = self._create_negpos_test_model()
         self.assertIn('op_type: "NegPos"', str(onnx_model))
-        sess = _ort.InferenceSession(onnx_model.SerializeToString(),
-                                     so,
-                                     providers=['CUDAExecutionProvider'])
-        x = np.array([[0., 1., 1.5], [7., 8., -5.5]]).astype(np.float32)
-        neg, pos = sess.run(None, {'x': x})
+        sess = _ort.InferenceSession(onnx_model.SerializeToString(), so, providers=["CUDAExecutionProvider"])
+        x = np.array([[0.0, 1.0, 1.5], [7.0, 8.0, -5.5]]).astype(np.float32)
+        neg, pos = sess.run(None, {"x": x})
         diff = x - (neg + pos)
         assert_almost_equal(diff, np.zeros(diff.shape))
 
     @staticmethod
-    def _create_fastgelu_test_model(domain='ai.onnx.contrib'):
-        nodes = [
-            helper.make_node(
-                'FastGelu', ['x', 'bias'], ['y'],
-                domain=domain)
-        ]
+    def _create_fastgelu_test_model(domain="ai.onnx.contrib"):
+        nodes = [helper.make_node("FastGelu", ["x", "bias"], ["y"], domain=domain)]
 
-        input0 = helper.make_tensor_value_info(
-            'x', onnx_proto.TensorProto.FLOAT, [])
-        input1 = helper.make_tensor_value_info(
-            'bias', onnx_proto.TensorProto.FLOAT, [])
-        output0 = helper.make_tensor_value_info(
-            'y', onnx_proto.TensorProto.FLOAT, [])
+        input0 = helper.make_tensor_value_info("x", onnx_proto.TensorProto.FLOAT, [])
+        input1 = helper.make_tensor_value_info("bias", onnx_proto.TensorProto.FLOAT, [])
+        output0 = helper.make_tensor_value_info("y", onnx_proto.TensorProto.FLOAT, [])
 
-        graph = helper.make_graph(nodes, 'test1', [input0, input1], [output0])
+        graph = helper.make_graph(nodes, "test1", [input0, input1], [output0])
         model = make_onnx_model(graph)
         return model
 
     @staticmethod
-    def _create_fastgelu_test_model_f16(domain='ai.onnx.contrib'):
-        nodes = [
-            helper.make_node(
-                'FastGelu', ['x', 'bias'], ['y'],
-                domain=domain)
-        ]
+    def _create_fastgelu_test_model_f16(domain="ai.onnx.contrib"):
+        nodes = [helper.make_node("FastGelu", ["x", "bias"], ["y"], domain=domain)]
 
-        input0 = helper.make_tensor_value_info(
-            'x', onnx_proto.TensorProto.FLOAT16, [])
-        input1 = helper.make_tensor_value_info(
-            'bias', onnx_proto.TensorProto.FLOAT16, [])
-        output0 = helper.make_tensor_value_info(
-            'y', onnx_proto.TensorProto.FLOAT16, [])
+        input0 = helper.make_tensor_value_info("x", onnx_proto.TensorProto.FLOAT16, [])
+        input1 = helper.make_tensor_value_info("bias", onnx_proto.TensorProto.FLOAT16, [])
+        output0 = helper.make_tensor_value_info("y", onnx_proto.TensorProto.FLOAT16, [])
 
-        graph = helper.make_graph(nodes, 'test1', [input0, input1], [output0])
+        graph = helper.make_graph(nodes, "test1", [input0, input1], [output0])
         model = make_onnx_model(graph)
         return model
 
     def test_cuda_fastgelu(self):
         eps = _ort.get_available_providers()
-        if 'CUDAExecutionProvider' in eps:
+        if "CUDAExecutionProvider" in eps:
             so = _ort.SessionOptions()
             so.register_custom_ops_library(_get_library_path())
             onnx_model = self._create_fastgelu_test_model()
             self.assertIn('op_type: "FastGelu"', str(onnx_model))
-            sess = _ort.InferenceSession(onnx_model.SerializeToString(),
-                                         so,
-                                         providers=['CUDAExecutionProvider'])
-            x = np.array([0., 1., 2., 3., 4., 5.]).astype(np.float32)
+            sess = _ort.InferenceSession(onnx_model.SerializeToString(), so, providers=["CUDAExecutionProvider"])
+            x = np.array([0.0, 1.0, 2.0, 3.0, 4.0, 5.0]).astype(np.float32)
             bias = np.array([0.0, 0.1, 0.2, 0.3, 0.4, 0.5]).astype(np.float32)
-            expected_y = np.array([0., 0.9505811, 2.1696784, 3.298689, 4.399991, 5.5]).astype(np.float32)
-            y = sess.run(None, {'x': x, 'bias':bias})[0]
+            expected_y = np.array([0.0, 0.9505811, 2.1696784, 3.298689, 4.399991, 5.5]).astype(np.float32)
+            y = sess.run(None, {"x": x, "bias": bias})[0]
             assert_almost_equal(y, expected_y)
         else:
-            print ('CUDAExecutionProvider not available, test_cuda_fastgelu skipped.')
+            print("CUDAExecutionProvider not available, test_cuda_fastgelu skipped.")
 
     def test_cuda_fastgelu_f16(self):
         eps = _ort.get_available_providers()
-        if 'CUDAExecutionProvider' in eps:
+        if "CUDAExecutionProvider" in eps:
             so = _ort.SessionOptions()
             so.register_custom_ops_library(_get_library_path())
             onnx_model = self._create_fastgelu_test_model_f16()
             self.assertIn('op_type: "FastGelu"', str(onnx_model))
-            sess = _ort.InferenceSession(onnx_model.SerializeToString(),
-                                         so,
-                                         providers=['CUDAExecutionProvider'])
-            x = np.array([0., 1., 2., 3., 4., 5.]).astype(np.float16)
+            sess = _ort.InferenceSession(onnx_model.SerializeToString(), so, providers=["CUDAExecutionProvider"])
+            x = np.array([0.0, 1.0, 2.0, 3.0, 4.0, 5.0]).astype(np.float16)
             bias = np.array([0.0, 0.1, 0.2, 0.3, 0.4, 0.5]).astype(np.float16)
-            expected_y = np.array([0., 0.95, 2.17, 3.299, 4.4, 5.5]).astype(np.float16)
-            y = sess.run(None, {'x': x, 'bias':bias})[0]
+            expected_y = np.array([0.0, 0.95, 2.17, 3.299, 4.4, 5.5]).astype(np.float16)
+            y = sess.run(None, {"x": x, "bias": bias})[0]
             assert_almost_equal(y, expected_y)
         else:
-            print ('CUDAExecutionProvider not available, test_cuda_fastgelu_f16 skipped.')
+            print("CUDAExecutionProvider not available, test_cuda_fastgelu_f16 skipped.")
+
+    def _negxplus1_cuda(self, itype):
+        from onnx.reference import ReferenceEvaluator
+
+        dtype = np.float32 if itype == TensorProto.FLOAT else np.float16
+        model1 = helper.make_model(
+            helper.make_graph(
+                [helper.make_node("Sub", ["one", "X"], ["Y"])],
+                "nd",
+                [helper.make_tensor_value_info("X", itype, [None, None, None])],
+                [helper.make_tensor_value_info("Y", itype, [None, None, None])],
+                [numpy_helper.from_array(np.array([1], dtype=dtype), name="one")],
+            ),
+            opset_imports=[helper.make_opsetid("", 18)],
+            ir_version=9,
+        )
+
+        model2 = helper.make_model(
+            helper.make_graph(
+                [helper.make_node("NegXPlus1", ["X"], ["Y"], domain="ai.onnx.contrib")],
+                "nd",
+                [helper.make_tensor_value_info("X", itype, [None, None, None])],
+                [helper.make_tensor_value_info("Y", itype, [None, None, None])],
+            ),
+            opset_imports=[
+                helper.make_opsetid("", 18),
+                helper.make_opsetid("ai.onnx.contrib", 1),
+            ],
+            ir_version=9,
+        )
+
+        x = (np.arange(18) - 4).reshape((3, 2, 3)).astype(dtype)
+        feeds1 = dict(X=x)
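+        # model1 computes Y = 1 - X with standard ONNX operators and serves as
+        # the reference for the custom CUDA kernel exercised through model2.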
+        ref = ReferenceEvaluator(model1)
+        expected = ref.run(None, feeds1)[0]
+
+        opts = _ort.SessionOptions()
+        opts.register_custom_ops_library(_get_library_path())
+        sess = _ort.InferenceSession(model2.SerializeToString(), opts, providers=["CUDAExecutionProvider"])
+        got = sess.run(None, feeds1)[0]
+        np.testing.assert_allclose(expected, got, atol=1e-5)
+
+    def test_negxplus1_cuda(self):
+        eps = _ort.get_available_providers()
+        if "CUDAExecutionProvider" in eps:
+            self._negxplus1_cuda(TensorProto.FLOAT)
+            self._negxplus1_cuda(TensorProto.FLOAT16)
+        else:
+            print("CUDAExecutionProvider not available, test_negxplus1_cuda skipped.")
 
 
 if __name__ == "__main__":
     unittest.main()