Draft for ScatterNDOfShape
xadupre committed May 2, 2024
1 parent 27d4b5c commit 927522d
Showing 4 changed files with 162 additions and 36 deletions.
4 changes: 2 additions & 2 deletions operators/contrib/contrib.cc
@@ -18,8 +18,8 @@ FxLoadCustomOpFactory LoadCustomOpClasses_Contrib = []() -> CustomOpArray& {

CustomCudaStructV2("FastGelu", contrib::FastGelu<ortc::MFloat16>),
CustomCudaStructV2("FastGelu", contrib::FastGelu<ortc::BFloat16>),
[]() { return std::shared_ptr<ortc::OrtCustomOps>(std::make_unique<ScatterNDOfShapeOp<float>>().release()) },
[]() { return std::shared_ptr<ortc::OrtCustomOps>(std::make_unique<ScatterNDOfShapeOp<ortc::MFloat16>>().release()) }
[]() { return std::shared_ptr<OrtCustomOp>(std::make_unique<contrib::ScatterNDOfShapeOp<float>>().release()); },
[]() { return std::shared_ptr<OrtCustomOp>(std::make_unique<contrib::ScatterNDOfShapeOp<ortc::MFloat16>>().release()); }
#endif
#endif
);
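The corrected entries return std::shared_ptr<OrtCustomOp> built from the namespace-qualified contrib::ScatterNDOfShapeOp<T>. Once the extensions library is rebuilt, the op is loaded the same way the tests below do it; a minimal sketch, assuming a CUDA build of onnxruntime plus onnxruntime-extensions (the model file name is a placeholder):

import onnxruntime as ort
from onnxruntime_extensions import get_library_path

opts = ort.SessionOptions()
# The shared library registers ScatterNDOfShape along with the other contrib ops.
opts.register_custom_ops_library(get_library_path())
# "model.onnx" stands in for any model that references the custom op.
sess = ort.InferenceSession("model.onnx", opts, providers=["CUDAExecutionProvider"])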
77 changes: 49 additions & 28 deletions operators/contrib/cuda/scatter_nd_of_shape.cu
@@ -1,9 +1,21 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "scatter_nd_of_shape.h"
#include <cublas_v2.h>
#include <cuda_bf16.h>
#include <cuda_fp16.h>

namespace ortops {
namespace contrib {

// Element count implied by a shape: the product of its dimensions.
// When first >= 0, only the leading `first` dimensions are multiplied.
template <class NTYPE>
NTYPE flattened_dimension(const std::vector<NTYPE>& values, int64_t first = -1) {
  NTYPE r = 1;
  auto end = first < 0 ? values.end() : values.begin() + first;
  for (auto it = values.begin(); it != end; ++it)
    r *= *it;
  return r;
}

#define _ENFORCE(cond, msg) \
if (!(cond)) ORTX_CXX_API_THROW(msg, ORT_RUNTIME_EXCEPTION);
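The flattened_dimension helper added above returns the element count implied by a shape, that is, the product of its dimensions; the same invariant in two lines of Python:

import math

assert math.prod([2, 2, 3]) == 12  # product of all dims equals the element count
assert math.prod([]) == 1          # an empty shape (a scalar) holds one element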
@@ -89,7 +101,7 @@ ONNXTensorElementDataType ScatterNDOfShapeOp<float>::GetInputType(std::size_t in
}

template <>
ONNXTensorElementDataType ScatterNDOfShapeOp<half>::GetInputType(std::size_t index) const {
ONNXTensorElementDataType ScatterNDOfShapeOp<ortc::MFloat16>::GetInputType(std::size_t index) const {
switch (index) {
case 0:
case 1:
@@ -142,7 +154,7 @@ ONNXTensorElementDataType ScatterNDOfShapeOp<float>::GetOutputType(std::size_t i
}

template <>
ONNXTensorElementDataType ScatterNDOfShapeOp<half>::GetOutputType(std::size_t index) const {
ONNXTensorElementDataType ScatterNDOfShapeOp<ortc::MFloat16>::GetOutputType(std::size_t index) const {
// D, scale D
switch (index) {
case 0:
@@ -172,7 +184,7 @@ ScatterNDOfShapeKernel<T>::ScatterNDOfShapeKernel(const OrtApi& api,
const OrtKernelInfo* info) {
char value_string[1000];
std::size_t size = 1000;
ThrowOnError(api, api.KernelInfoGetAttribute_string(info, "reduction", value_string, &size));
Ort::ThrowOnError(api.KernelInfoGetAttribute_string(info, "reduction", value_string, &size));
std::string value = value_string;
if (value == "add")
reduction_ = Reduction::Add;
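The constructor maps the "reduction" string attribute onto the Reduction enum declared in the header. A sketch of that mapping in Python; the folded remainder of the constructor may treat other values differently, so the fallback to NONE here is an assumption:

from enum import Enum

class Reduction(Enum):
    NONE = 0
    ADD = 1

def parse_reduction(value: str) -> Reduction:
    # Mirrors the visible branch of the attribute parsing above (assumed fallback).
    return Reduction.ADD if value == "add" else Reduction.NONE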
@@ -204,13 +216,9 @@ void ScatterNDOfShapeKernel<T>::Compute(OrtKernelContext* context) {
cudaStream_t stream = (cudaStream_t)ctx.GetGPUComputeStream();

auto memi = updates.GetTensorMemoryInfo();
_ENFORCE(memi.GetDeviceType() == OrtMemoryInfoDeviceType::OrtMemoryInfoDeviceType_GPU,
"Tensor updates is not on GPU.");

_ENFORCE(memi.GetDeviceType() == OrtMemoryInfoDeviceType::OrtMemoryInfoDeviceType_GPU, "Tensor updates is not on GPU.");
auto mem = shape.GetTensorMemoryInfo();
_ENFORCE(
mem.GetDeviceType() == OrtMemoryInfoDeviceType::OrtMemoryInfoDeviceType_CPU,
"Input shape is not on CPU.");
_ENFORCE(mem.GetDeviceType() == OrtMemoryInfoDeviceType::OrtMemoryInfoDeviceType_CPU, "Input shape is not on CPU.");
const int64_t* X = shape.GetTensorData<int64_t>();
std::vector<int64_t> dims(X, X + dimensions[0]);
output = ctx.GetOutput(0, dims);
@@ -220,19 +228,36 @@ void ScatterNDOfShapeKernel<T>::Compute(OrtKernelContext* context) {
if (reduction_ == Reduction::Add &&
indices_shape[indices_shape.size() - 1] == 1 && input_shape.size() == 2 &&
input_shape[input_shape.size() - 1] >= maxThreadPerBlock_) {
size_t indice_size = static_cast<size_t>(onnx_c_ops::flattened_dimension(indices_shape));
size_t update_size = static_cast<size_t>(onnx_c_ops::flattened_dimension(updates_shape));

_ENFORCE(update_size == indice_size * input_shape[input_shape.size() - 1],
"Size mismatch.");

ComputeNoAtomic(stream, input_shape, indices_shape, output.GetTensorMutableData<T>(),
indices.GetTensorData<int64_t>(), updates.GetTensorData<T>());
size_t indice_size = static_cast<size_t>(flattened_dimension(indices_shape));
size_t update_size = static_cast<size_t>(flattened_dimension(updates_shape));
_ENFORCE(update_size == indice_size * input_shape[input_shape.size() - 1], "Size mismatch.");
ComputeNoAtomic(stream, input_shape, indices_shape, output.GetTensorMutableData<T>(), indices.GetTensorData<int64_t>(), updates.GetTensorData<T>());
} else {
_ENFORCE("This operator can only be used when the indices_shape[-1] == 1 and input_shape is a 2D matrix.");
ORTX_CXX_API_THROW("This operator can only be used when the indices_shape[-1] == 1 and input_shape is a 2D matrix.", ORT_RUNTIME_EXCEPTION);
}
}
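Compute takes the non-atomic path only under the conditions tested above; restated as a Python predicate (256 stands in for the device-derived maxThreadPerBlock_):

def use_non_atomic_path(reduction, indices_shape, input_shape,
                        max_threads_per_block=256):
    return (reduction == "add"
            and indices_shape[-1] == 1           # each index selects a whole row
            and len(input_shape) == 2            # data is a 2-D matrix
            and input_shape[-1] >= max_threads_per_block)  # rows span a full block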

template <typename T>
void _ComputeNoAtomic(cudaStream_t stream, const std::vector<int64_t>& input_shape,
const std::vector<int64_t>& indices_shape, T* output_data,
const int64_t* indices_data, const T* updates_data,
int threads_per_block, int blocks_per_grid, size_t indice_size, size_t nrows, size_t stride) {
dim3 threads(threads_per_block);
dim3 blocks(blocks_per_grid);
addition_inplace_kernel<T><<<blocks, threads, 0, stream>>>(output_data, indices_data, updates_data, indice_size, nrows, stride);
}

template <>
void _ComputeNoAtomic<ortc::MFloat16>(cudaStream_t stream, const std::vector<int64_t>& input_shape,
const std::vector<int64_t>& indices_shape, ortc::MFloat16* output_data,
const int64_t* indices_data, const ortc::MFloat16* updates_data,
int threads_per_block, int blocks_per_grid, size_t indice_size, size_t nrows, size_t stride) {

dim3 threads(threads_per_block);
dim3 blocks(blocks_per_grid);
addition_inplace_kernel<half><<<blocks, threads, 0, stream>>>((half*)output_data, indices_data, (const half*)updates_data, indice_size, nrows, stride);
}
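The MFloat16 specialization above only reinterprets pointers: ortc::MFloat16 and CUDA's half are both 16 bits of IEEE-754 binary16, so the cast changes the static type, not the payload. A numpy illustration of why such a reinterpretation is lossless:

import numpy as np

x = np.array([1.5, -2.0], dtype=np.float16)
bits = x.view(np.uint16)  # reinterpret the same 16-bit payload, no conversion
assert np.array_equal(bits.view(np.float16), x)  # round-trips exactly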

template <typename T>
void ScatterNDOfShapeKernel<T>::ComputeNoAtomic(cudaStream_t& stream,
const std::vector<int64_t>& input_shape,
Expand All @@ -243,8 +268,8 @@ void ScatterNDOfShapeKernel<T>::ComputeNoAtomic(cudaStream_t& stream,
// reduction_ == Reduction::add
// indices_shape[indices_shape.size() - 1] == 1
// input_shape.size() == 2
size_t indice_size = static_cast<size_t>(onnx_c_ops::flattened_dimension(indices_shape));
size_t input_size = static_cast<size_t>(onnx_c_ops::flattened_dimension(input_shape));
size_t indice_size = static_cast<size_t>(flattened_dimension(indices_shape));
size_t input_size = static_cast<size_t>(flattened_dimension(input_shape));
size_t stride = input_shape[input_shape.size() - 1];
size_t nrows = input_size / stride;

@@ -253,15 +278,11 @@
std::vector<uint8_t> processed_once(input_shape[0], 0);

int threads_per_block = std::min(256, maxThreadPerBlock_ / 8);

int blocks_per_grid = (stride + threads_per_block - 1) / threads_per_block;
dim3 threads(threads_per_block);
dim3 blocks(blocks_per_grid);
addition_inplace_kernel<T><<<blocks, threads, 0, stream>>>(
output_data, indices_data, updates_data, indice_size, nrows, stride);
_ComputeNoAtomic(stream, input_shape, indices_shape, output_data, indices_data, updates_data, threads_per_block, blocks_per_grid, indice_size, nrows, stride);
}
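The launch geometry computed above, restated in Python (1024 stands in for the device's maxThreadPerBlock_):

def launch_config(stride, max_threads_per_block=1024):
    threads_per_block = min(256, max_threads_per_block // 8)
    # Ceiling division: enough blocks to give every column of a row one thread.
    blocks_per_grid = (stride + threads_per_block - 1) // threads_per_block
    return threads_per_block, blocks_per_grid

assert launch_config(300) == (128, 3)  # 3 blocks of 128 threads cover 300 columns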

static ScatterNDOfShapeOp<float> _op32;
static ScatterNDOfShapeOp<half> _op16;
static ScatterNDOfShapeOp<ortc::MFloat16> _op16;

} // namespace ortops
} // namespace contrib
18 changes: 13 additions & 5 deletions operators/contrib/cuda/scatter_nd_of_shape.h
@@ -1,10 +1,19 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once

#include "ocos.h"
// #include "cublas_v2.h"
// #include <cuda_runtime.h>
#include <cuda_runtime.h>

#ifdef ORT_SWIFT_PACKAGE_MANAGER_BUILD
#include "onnxruntime/onnxruntime_cxx_api.h"
#else
#include "onnxruntime_cxx_api.h"
#endif

namespace ortops {
namespace contrib {

enum class Reduction : int {
None = 0,
@@ -34,8 +43,7 @@ struct ScatterNDOfShapeKernel {
};

template <typename T>
struct ScatterNDOfShapeOp
: Ort::CustomOpBase<ScatterNDOfShapeOp<T>, ScatterNDOfShapeKernel<T>> {
struct ScatterNDOfShapeOp : Ort::CustomOpBase<ScatterNDOfShapeOp<T>, ScatterNDOfShapeKernel<T>> {
typedef Ort::CustomOpBase<ScatterNDOfShapeOp<T>, ScatterNDOfShapeKernel<T>> parent_type;
ScatterNDOfShapeOp() : parent_type() {}
void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const;
@@ -52,4 +60,4 @@
OrtCustomOpInputOutputCharacteristic GetOutputCharacteristic(std::size_t index) const;
};

} // namespace ortops
} // namespace contrib
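For reference, a node using this op template would appear in a model as sketched below, registered under the ai.onnx.contrib domain this repository uses for its contrib ops:

from onnx import helper

node = helper.make_node(
    "ScatterNDOfShape",
    inputs=["shape", "indices", "updates"],  # shape and indices are int64, updates is T
    outputs=["y"],
    reduction="add",
    domain="ai.onnx.contrib",
)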
99 changes: 98 additions & 1 deletion test/cuda/test_cudaops.py
@@ -1,13 +1,26 @@
import unittest
import numpy as np
from numpy.testing import assert_almost_equal
from onnx import helper, onnx_pb as onnx_proto
from onnx import helper, onnx_pb as onnx_proto, TensorProto
from onnx.reference import ReferenceEvaluator
from onnx.reference.op_run import OpRun
from onnx.reference.ops.op_scatternd import _scatter_nd_impl
from onnxruntime_extensions import make_onnx_model
from onnxruntime_extensions import get_library_path as _get_library_path

import onnxruntime as _ort


class ScatterNDOfShape(OpRun):
op_domain = "ai.onnx.contrib"

def _run(self, shape, indices, updates, reduction=None, strategy=None):
data = np.zeros(shape, dtype=updates.dtype)
y = _scatter_nd_impl(data, indices, updates, reduction=reduction)
return (y,)
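
The class above delegates to onnx's reference implementation; a quick standalone check of those semantics, using the same _scatter_nd_impl helper imported at the top of this file:

import numpy as np
from onnx.reference.ops.op_scatternd import _scatter_nd_impl

data = np.zeros((2, 3), dtype=np.float32)
indices = np.array([[0], [1], [0]], dtype=np.int64)
updates = np.ones((3, 3), dtype=np.float32)
y = _scatter_nd_impl(data, indices, updates, reduction="add")
assert y.tolist() == [[2.0, 2.0, 2.0], [1.0, 1.0, 1.0]]  # row 0 is hit twice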



class TestCudaOps(unittest.TestCase):
@staticmethod
def _create_negpos_test_model(domain='ai.onnx.contrib'):
@@ -117,5 +130,89 @@ def test_cuda_fastgelu_f16(self):
print ('CUDAExecutionProvider not available, test_cuda_fastgelu_f16 skipped.')


def _scatternd_of_shape_cuda(self, reduction, line, itype):
import onnxruntime

model1 = helper.make_model(
helper.make_graph(
[
helper.make_node(
"ScatterND",
inputs=["data", "indices", "updates"],
outputs=["y"],
reduction=reduction,
)
],
"nd",
[
helper.make_tensor_value_info("data", itype, [None, None, None]),
helper.make_tensor_value_info(
"indices", TensorProto.INT64, [None, None]
),
helper.make_tensor_value_info("updates", itype, [None, None, None]),
],
[helper.make_tensor_value_info("y", itype, [None, None, None])],
),
opset_imports=[helper.make_opsetid("", 18)],
ir_version=9,
)

dtype = np.float32 if itype == TensorProto.FLOAT else np.float16
data = np.zeros((2, 2, 3), dtype=dtype)

model2 = helper.make_model(
helper.make_graph(
[
helper.make_node(
"ScatterNDOfShape",
inputs=["shape", "indices", "updates"],
outputs=["y"],
reduction=reduction,
domain="onnx_extended.ortops.optim.cuda",
)
],
"nd",
[
helper.make_tensor_value_info("shape", TensorProto.INT64, [None]),
helper.make_tensor_value_info(
"indices", TensorProto.INT64, [None, None]
),
helper.make_tensor_value_info("updates", itype, [None, None, None]),
],
[helper.make_tensor_value_info("y", itype, [None, None, None])],
),
opset_imports=[
helper.make_opsetid("", 18),
helper.make_opsetid("onnx_extended.ortops.optim.cuda", 1),
],
ir_version=9,
)

indices = np.array([[line], [1 - line], [line]], dtype=np.int64)
if itype == TensorProto.FLOAT:
updates = (2 ** np.arange(18).reshape((3, 2, 3))).astype(dtype)
else:
updates = np.arange(18).reshape((3, 2, 3)).astype(dtype)

feeds1 = dict(data=data, indices=indices, updates=updates)
feeds2 = dict(
shape=np.array([2, 2, 3], dtype=np.int64), indices=indices, updates=updates
)
ref = ReferenceEvaluator(model1, new_ops=[ScatterNDOfShape])
expected = ref.run(None, feeds1)[0]

opts = onnxruntime.SessionOptions()
opts.register_custom_ops_library(_get_library_path())
sess = onnxruntime.InferenceSession(model2.SerializeToString(), opts, providers=["CUDAExecutionProvider"])
got = sess.run(None, feeds2)[0]
self.assertEqual(expected.tolist(), got.tolist())

def test_cuda_scatternd_of_shape(self):
self._scatternd_of_shape_cuda("add", 0, TensorProto.FLOAT)
self._scatternd_of_shape_cuda("add", 0, TensorProto.FLOAT16)
self._scatternd_of_shape_cuda("add", 1, TensorProto.FLOAT)
self._scatternd_of_shape_cuda("add", 1, TensorProto.FLOAT16)


if __name__ == "__main__":
unittest.main()
