Skip to content

Commit

Permalink
fix unit test
Browse files Browse the repository at this point in the history
  • Loading branch information
xadupre committed May 6, 2024
1 parent 56f979f commit e499e77
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 9 deletions.
9 changes: 4 additions & 5 deletions operators/cuda/negxplus1.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,10 @@ struct NegXPlus1 {
if (0 == input_length) {
return nullptr;
}
using TT = typename CudaT<T>::MappedType;
LaunchNegXPlus1Kernel<TT>(reinterpret_cast<cudaStream_t>(ctx->GetCudaStream()),
input_length,
reinterpret_cast<const TT*>(input_data),
reinterpret_cast<TT*>(output_data));
LaunchNegXPlus1Kernel<T>(reinterpret_cast<cudaStream_t>(ctx->GetCudaStream()),
input_length,
input_data,
output_data);
return nullptr;
}
};
Expand Down
2 changes: 1 addition & 1 deletion operators/cuda/negxplus1_impl.cu
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ cudaError_t _LaunchNegXPlus1Kernel(cudaStream_t stream, int input_length, const
constexpr int blockSize = 256;
const int gridSize = (input_length + blockSize - 1) / blockSize;
using TT = typename contrib::CudaT<T>::MappedType;
NegXPlus1Kernel<TT><<<gridSize, blockSize, 0, stream>>>((TT*)output, (const TT*)input, input_length);
NegXPlus1Kernel<TT><<<gridSize, blockSize, 0, stream>>>(reinterpret_cast<TT*>(output), reinterpret_cast<const TT*>(input), input_length);
return cudaGetLastError();
}

Expand Down
15 changes: 12 additions & 3 deletions test/cuda/test_cudaops.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,21 @@
import numpy as np
from numpy.testing import assert_almost_equal
from onnx import helper, numpy_helper, onnx_pb as onnx_proto, TensorProto
from onnx.reference import ReferenceEvaluator
from onnx.reference.op_run import OpRun
from onnxruntime_extensions import make_onnx_model
from onnxruntime_extensions import get_library_path as _get_library_path

import onnxruntime as _ort


class NegXPlus1(OpRun):
op_domain = "ai.onnx.contrib"

def _run(self, X):
return (1 - X,)


class TestCudaOps(unittest.TestCase):
@staticmethod
def _create_negpos_test_model(domain="ai.onnx.contrib"):
Expand Down Expand Up @@ -109,7 +118,7 @@ def _negxplus1_cuda(self, itype):

model2 = helper.make_model(
helper.make_graph(
[helper.make_node("NegXplus1", ["X"], ["Y"], domain="ai.onnx.contrib")],
[helper.make_node("NegXPlus1", ["X"], ["Y"], domain="ai.onnx.contrib")],
"nd",
[helper.make_tensor_value_info("X", itype, [None, None, None])],
[helper.make_tensor_value_info("Y", itype, [None, None, None])],
Expand All @@ -125,14 +134,14 @@ def _negxplus1_cuda(self, itype):
x = (np.arange(18) - 4).reshape((3, 2, 3)).astype(dtype)

feeds1 = dict(X=x)
ref = CReferenceEvaluator(model1)
ref = ReferenceEvaluator(model1, new_ops=[NegXPlus1])
expected = ref.run(None, feeds1)[0]

opts = onnxruntime.SessionOptions()
opts.register_custom_ops_library(_get_library_path())
sess = onnxruntime.InferenceSession(model2.SerializeToString(), opts, providers=["CUDAExecutionProvider"])
got = sess.run(None, feeds1)[0]
self.assertEqualArray(expected, got, atol=1e-5)
assert_almost_equal(expected, got, decimal=5)

def test_negxplus1_cuda(self):
self._negxplus1_cuda(TensorProto.FLOAT)
Expand Down

0 comments on commit e499e77

Please sign in to comment.