diff --git a/onert-micro/luci-interpreter/include/luci_interpreter/test_models/gather_nd/FloatGatherNDKernel.h b/onert-micro/luci-interpreter/include/luci_interpreter/test_models/gather_nd/FloatGatherNDKernel.h new file mode 100644 index 00000000000..4a789c76c98 --- /dev/null +++ b/onert-micro/luci-interpreter/include/luci_interpreter/test_models/gather_nd/FloatGatherNDKernel.h @@ -0,0 +1,91 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LUCI_INTERPRETER_TEST_MODELS_FLOAT_GATHER_ND_KERNEL_H +#define LUCI_INTERPRETER_TEST_MODELS_FLOAT_GATHER_ND_KERNEL_H + +#include "TestDataGatherNDBase.h" + +namespace luci_interpreter +{ +namespace test_kernel +{ +namespace gather_nd_float +{ +/* + * GatherND Kernel: + * + * Input(1, 4, 4, 3) + * | + * GatherND + * | + * Output(1, 2, 4, 4, 3) + */ +const unsigned char test_kernel_model_circle[] = { + 0x18, 0x00, 0x00, 0x00, 0x43, 0x49, 0x52, 0x30, 0x00, 0x00, 0x0e, 0x00, 0x14, 0x00, 0x00, 0x00, + 0x0c, 0x00, 0x08, 0x00, 0x10, 0x00, 0x04, 0x00, 0x0e, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, + 0x50, 0x00, 0x00, 0x00, 0x84, 0x01, 0x00, 0x00, 0xa0, 0x01, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0x3c, 0x00, 0x00, 0x00, 0x34, 0x00, 0x00, 0x00, 0x2c, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x06, 0x00, 0x08, 0x00, 0x04, 0x00, 0x06, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0x10, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x8c, 0xff, 0xff, 0xff, 0x90, 0xff, 0xff, 0xff, 0x94, 0xff, 0xff, 0xff, + 0x01, 0x00, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00, 0x00, 0x00, 0x0e, 0x00, 0x18, 0x00, 0x14, 0x00, + 0x10, 0x00, 0x0c, 0x00, 0x08, 0x00, 0x04, 0x00, 0x0e, 0x00, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00, + 0x1c, 0x00, 0x00, 0x00, 0x60, 0x00, 0x00, 0x00, 0x64, 0x00, 0x00, 0x00, 0x68, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, 0x6d, 0x61, 0x69, 0x6e, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, + 0x14, 0x00, 0x00, 0x00, 0x00, 0x00, 0x0e, 0x00, 0x14, 0x00, 0x00, 0x00, 0x10, 0x00, 0x0c, 0x00, + 0x07, 0x00, 0x08, 0x00, 0x0e, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x53, 0x10, 0x00, 0x00, 0x00, + 0x10, 0x00, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00, 0x04, 0x00, 0x04, 0x00, 0x04, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, 0x74, 0x00, 0x00, 0x00, 0x38, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, 0xa4, 0xff, 0xff, 0xff, 0x0c, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, + 0x0c, 0x00, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, 0x6f, 0x66, 0x6d, 0x00, 0x02, 0x00, 0x00, 0x00, + 0x02, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x14, 0x00, 0x10, 0x00, 0x0f, 0x00, + 0x08, 0x00, 0x04, 0x00, 0x0c, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x02, 0x10, 0x00, 0x00, 0x00, 0x07, 0x00, 0x00, 0x00, 0x69, 0x6e, 0x64, 0x69, + 0x63, 0x65, 0x73, 0x00, 0x02, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, + 0x0c, 0x00, 0x10, 0x00, 0x0c, 0x00, 0x00, 0x00, 0x08, 0x00, 0x04, 0x00, 0x0c, 0x00, 0x00, 0x00, + 0x0c, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x05, 0x00, 0x00, 0x00, + 0x70, 0x61, 0x72, 0x61, 0x6d, 0x00, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, + 0x02, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, + 0x0c, 0x00, 0x0c, 0x00, 0x0b, 0x00, 0x00, 0x00, 0x00, 0x00, 0x04, 0x00, 0x0c, 0x00, 0x00, 0x00, + 0x6b, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x6b, 0x11, 0x00, 0x00, 0x00, 0x4f, 0x4e, 0x45, 0x2d, + 0x74, 0x66, 0x6c, 0x69, 0x74, 0x65, 0x32, 0x63, 0x69, 0x72, 0x63, 0x6c, 0x65, 0x00, 0x00, 0x00}; + +const std::vector input_data = {-63.555645, 32.902435, -61.76536, 28.280264, + 66.80893, -13.163652, -66.06793, -15.827837}; +const std::vector reference_output_data = {-61.76536, 28.280264, 66.80893, -13.163652}; + +} // namespace gather_nd_float + +class TestDataFloatGatherND : public TestDataGatherNDBase +{ +public: + TestDataFloatGatherND() + { + _input_data = gather_nd_float::input_data; + _reference_output_data = gather_nd_float::reference_output_data; + _test_kernel_model_circle = gather_nd_float::test_kernel_model_circle; + } + + ~TestDataFloatGatherND() override = default; +}; + +} // namespace test_kernel +} // namespace luci_interpreter + +#endif // LUCI_INTERPRETER_TEST_MODELS_FLOAT_GATHER_ND_KERNEL_H diff --git a/onert-micro/luci-interpreter/include/luci_interpreter/test_models/gather_nd/NegGatherNDKernel.h b/onert-micro/luci-interpreter/include/luci_interpreter/test_models/gather_nd/NegGatherNDKernel.h new file mode 100644 index 00000000000..71cba6bb877 --- /dev/null +++ b/onert-micro/luci-interpreter/include/luci_interpreter/test_models/gather_nd/NegGatherNDKernel.h @@ -0,0 +1,91 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LUCI_INTERPRETER_TEST_MODELS_NEG_GATHER_KERNEL_H +#define LUCI_INTERPRETER_TEST_MODELS_NEG_GATHER_KERNEL_H + +#include "TestDataGatherNDBase.h" + +namespace luci_interpreter +{ +namespace test_kernel +{ +namespace neg_gather_nd_mismatch_input_output_type +{ +/* + * GatherND Kernel with input output type mismatch (should be equal): + * + * Input(1, 4, 4, 3) - FLOAT32 Indices(1, 2, 1) + * \ / + * GatherND + * | + * Output(1, 2, 2, 4) - INT32 + */ +const unsigned char test_kernel_model_circle[] = { + 0x18, 0x00, 0x00, 0x00, 0x43, 0x49, 0x52, 0x30, 0x00, 0x00, 0x0e, 0x00, 0x14, 0x00, 0x00, 0x00, + 0x0c, 0x00, 0x08, 0x00, 0x10, 0x00, 0x04, 0x00, 0x0e, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, + 0x50, 0x00, 0x00, 0x00, 0x88, 0x01, 0x00, 0x00, 0xa4, 0x01, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0x3c, 0x00, 0x00, 0x00, 0x34, 0x00, 0x00, 0x00, 0x2c, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x06, 0x00, 0x08, 0x00, 0x04, 0x00, 0x06, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0x10, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x8c, 0xff, 0xff, 0xff, 0x90, 0xff, 0xff, 0xff, 0x94, 0xff, 0xff, 0xff, + 0x01, 0x00, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00, 0x00, 0x00, 0x0e, 0x00, 0x18, 0x00, 0x14, 0x00, + 0x10, 0x00, 0x0c, 0x00, 0x08, 0x00, 0x04, 0x00, 0x0e, 0x00, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00, + 0x1c, 0x00, 0x00, 0x00, 0x60, 0x00, 0x00, 0x00, 0x64, 0x00, 0x00, 0x00, 0x68, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, 0x6d, 0x61, 0x69, 0x6e, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, + 0x14, 0x00, 0x00, 0x00, 0x00, 0x00, 0x0e, 0x00, 0x14, 0x00, 0x00, 0x00, 0x10, 0x00, 0x0c, 0x00, + 0x07, 0x00, 0x08, 0x00, 0x0e, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x53, 0x10, 0x00, 0x00, 0x00, + 0x10, 0x00, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00, 0x04, 0x00, 0x04, 0x00, 0x04, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, 0x78, 0x00, 0x00, 0x00, 0x3c, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, 0xd8, 0xff, 0xff, 0xff, 0x10, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x02, 0x0c, 0x00, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, 0x6f, 0x66, 0x6d, 0x00, + 0x02, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x14, 0x00, + 0x10, 0x00, 0x0f, 0x00, 0x08, 0x00, 0x04, 0x00, 0x0c, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, + 0x03, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0x10, 0x00, 0x00, 0x00, 0x07, 0x00, 0x00, 0x00, + 0x69, 0x6e, 0x64, 0x69, 0x63, 0x65, 0x73, 0x00, 0x02, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, + 0x02, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x10, 0x00, 0x0c, 0x00, 0x00, 0x00, 0x08, 0x00, 0x04, 0x00, + 0x0c, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, + 0x05, 0x00, 0x00, 0x00, 0x70, 0x61, 0x72, 0x61, 0x6d, 0x00, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, + 0x02, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, + 0x10, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x0c, 0x00, 0x0b, 0x00, 0x00, 0x00, 0x00, 0x00, 0x04, 0x00, + 0x0c, 0x00, 0x00, 0x00, 0x6b, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x6b, 0x11, 0x00, 0x00, 0x00, + 0x4f, 0x4e, 0x45, 0x2d, 0x74, 0x66, 0x6c, 0x69, 0x74, 0x65, 0x32, 0x63, 0x69, 0x72, 0x63, 0x6c, + 0x65, 0x00, 0x00, 0x00}; + +} // namespace neg_gather_nd_mismatch_input_output_type + +class NegTestDataInputOutputTypeMismatchGatherNDKernel : public NegTestDataBase +{ +public: + NegTestDataInputOutputTypeMismatchGatherNDKernel() + { + _test_kernel_model_circle = neg_gather_nd_mismatch_input_output_type::test_kernel_model_circle; + } + + ~NegTestDataInputOutputTypeMismatchGatherNDKernel() override = default; + + const unsigned char *get_model_ptr() override final { return _test_kernel_model_circle; } + +protected: + const unsigned char *_test_kernel_model_circle; +}; + +} // namespace test_kernel +} // namespace luci_interpreter + +#endif // LUCI_INTERPRETER_TEST_MODELS_NEG_GATHER_ND_KERNEL_H diff --git a/onert-micro/luci-interpreter/include/luci_interpreter/test_models/gather_nd/TestDataGatherNDBase.h b/onert-micro/luci-interpreter/include/luci_interpreter/test_models/gather_nd/TestDataGatherNDBase.h new file mode 100644 index 00000000000..a9485ad2c8d --- /dev/null +++ b/onert-micro/luci-interpreter/include/luci_interpreter/test_models/gather_nd/TestDataGatherNDBase.h @@ -0,0 +1,60 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LUCI_INTERPRETER_TEST_MODELS_GATHER_ND_KERNEL_BASE_H +#define LUCI_INTERPRETER_TEST_MODELS_GATHER_ND_KERNEL_BASE_H + +#include "luci_interpreter/test_models/TestDataBase.h" + +namespace luci_interpreter +{ +namespace test_kernel +{ + +template class TestDataGatherNDBase : public TestDataBase +{ +public: + TestDataGatherNDBase() = default; + + const unsigned char *get_model_ptr() override final { return _test_kernel_model_circle; } + + const std::vector &get_input_data_by_index(int i) override final + { + switch (i) + { + case 0: + return _input_data; + default: + assert(false && "Wrong input index"); + } + } + + const std::vector &get_output_data_by_index(int i) override final + { + assert(i == 0); + return _reference_output_data; + } + +protected: + std::vector _input_data; + std::vector _reference_output_data; + const unsigned char *_test_kernel_model_circle; +}; + +} // namespace test_kernel +} // namespace luci_interpreter + +#endif // LUCI_INTERPRETER_TEST_MODELS_GATHER_ND_KERNEL_BASE_H diff --git a/onert-micro/luci-interpreter/pal/cmsisnn/KernelsToBuild.lst b/onert-micro/luci-interpreter/pal/cmsisnn/KernelsToBuild.lst index e92ce5e85f1..c02895a35ba 100644 --- a/onert-micro/luci-interpreter/pal/cmsisnn/KernelsToBuild.lst +++ b/onert-micro/luci-interpreter/pal/cmsisnn/KernelsToBuild.lst @@ -15,6 +15,7 @@ REGISTER_KERNEL(CONV_2D, Conv2D) REGISTER_KERNEL(LOGISTIC, Logistic) REGISTER_KERNEL(LOG, Log) REGISTER_KERNEL(GATHER, Gather) +REGISTER_KERNEL(GATHER_ND, GatherND) REGISTER_KERNEL(EXP, Exp) REGISTER_KERNEL(FULLY_CONNECTED, FullyConnected) REGISTER_KERNEL(GREATER, Greater) diff --git a/onert-micro/luci-interpreter/pal/common/PALGatherND.h b/onert-micro/luci-interpreter/pal/common/PALGatherND.h new file mode 100644 index 00000000000..63aced347c9 --- /dev/null +++ b/onert-micro/luci-interpreter/pal/common/PALGatherND.h @@ -0,0 +1,83 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * Copyright 2020 The TensorFlow Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LUCI_INTERPRETER_PAL_GATHER_ND_COMMON_H +#define LUCI_INTERPRETER_PAL_GATHER_ND_COMMON_H + +#include "PALUtils.h" +#include + +namespace luci_interpreter_pal +{ + +constexpr int MAX_INDICES_ND = 5; + +template +inline void GatherND(luci_interpreter::RuntimeShape params_shape, const ParamsT *param_data, + luci_interpreter::RuntimeShape indices_shape, const IndicesT *index_data, + ParamsT *output_data) +{ + const int indices_dims = indices_shape.dimensionsCount(); + const int indices_nd = indices_shape.dims(indices_dims - 1); + const int params_dims = params_shape.dimensionsCount(); + + int n_slices = 1; + for (int i = 0; i < indices_dims - 1; ++i) + { + n_slices *= indices_shape.dims(i); + } + + // If indices[-1] == params.rank, fetch single elements. + // If indices[-1] < params.rank, fetch slices. + int slice_size = 1; + for (int i = indices_nd; i < params_dims; ++i) + { + slice_size *= params_shape.dims(i); + } + + int params_flat_size = params_shape.flatSize(); + int remain_flat_size = params_flat_size; + + // Number of elements per dimension + int dims_to_count[MAX_INDICES_ND]; + for (int i = 0; i < indices_nd; ++i) + { + dims_to_count[i] = remain_flat_size / params_shape.dims(i); + remain_flat_size = dims_to_count[i]; + } + + for (int i = 0; i < n_slices; ++i) + { + int from_pos = 0; + for (int j = 0; j < indices_nd; ++j) + { + int offset = i * indices_nd + j; + IndicesT index = index_data[offset]; + from_pos += index * dims_to_count[j]; + } + if (from_pos < 0 || from_pos + slice_size > params_flat_size) + { + assert(false && "GatherND error"); + return; + } + std::memcpy(output_data + i * slice_size, param_data + from_pos, sizeof(ParamsT) * slice_size); + } +} + +} // namespace luci_interpreter_pal + +#endif // LUCI_INTERPRETER_PAL_GATHER_ND_COMMON_H diff --git a/onert-micro/luci-interpreter/pal/mcu/KernelsToBuild.lst b/onert-micro/luci-interpreter/pal/mcu/KernelsToBuild.lst index 04e28bb89c4..bd4d2a8b854 100644 --- a/onert-micro/luci-interpreter/pal/mcu/KernelsToBuild.lst +++ b/onert-micro/luci-interpreter/pal/mcu/KernelsToBuild.lst @@ -17,6 +17,7 @@ REGISTER_KERNEL(CONV_2D, Conv2D) REGISTER_KERNEL(LOGISTIC, Logistic) REGISTER_KERNEL(LOG, Log) REGISTER_KERNEL(GATHER, Gather) +REGISTER_KERNEL(GATHER_ND, GatherND) REGISTER_KERNEL(EXP, Exp) REGISTER_KERNEL(GREATER, Greater) REGISTER_KERNEL(GREATER_EQUAL, GreaterEqual) diff --git a/onert-micro/luci-interpreter/src/kernels/GatherND.cpp b/onert-micro/luci-interpreter/src/kernels/GatherND.cpp new file mode 100644 index 00000000000..c9c41393490 --- /dev/null +++ b/onert-micro/luci-interpreter/src/kernels/GatherND.cpp @@ -0,0 +1,70 @@ +/* + * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved + * Copyright 2021 The TensorFlow Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "Builders.h" +#include "kernels/Utils.h" +#include "TISOKernel.h" +#include "PALGatherND.h" + +#include + +namespace luci_interpreter +{ + +void configure_kernel_CircleGatherND(const circle::Operator *cur_op, + BaseRuntimeGraph *runtime_graph) +{ + kernels::TISOKernel kernel(cur_op, runtime_graph); + + LUCI_INTERPRETER_CHECK(Tensor::element_type(kernel.input2()) == DataType::S32); + LUCI_INTERPRETER_CHECK(Tensor::element_type(kernel.input1()) == + Tensor::element_type(kernel.output())); + + const int params_rank = Tensor::num_dims(kernel.input1()); + const int indices_rank = Tensor::num_dims(kernel.input2()); + const int indices_nd = Tensor::dim(kernel.input2(), indices_rank - 1); + + LUCI_INTERPRETER_CHECK(params_rank >= 1); + LUCI_INTERPRETER_CHECK(indices_rank >= 1); + LUCI_INTERPRETER_CHECK(indices_nd <= params_rank); + LUCI_INTERPRETER_CHECK(indices_nd <= luci_interpreter_pal::MAX_INDICES_ND); +} + +void execute_kernel_CircleGatherND(const circle::Operator *cur_op, BaseRuntimeGraph *runtime_graph) +{ + kernels::TISOKernel kernel(cur_op, runtime_graph); + + const uint8_t *params_data = runtime_graph->getDataByTensor(kernel.input1()); + const uint8_t *indies_data = runtime_graph->getConstDataByTensor(kernel.input2()); + uint8_t *output_data = runtime_graph->getDataByTensor(kernel.output()); + + switch (Tensor::element_type(kernel.input1())) + { +#ifndef DIS_FLOAT + case DataType::FLOAT32: + return luci_interpreter_pal::GatherND( + kernels::getTensorRuntimeShape(kernel.input1(), runtime_graph), + kernels::getTensorData(params_data), + kernels::getTensorRuntimeShape(kernel.input2(), runtime_graph), + kernels::getTensorData(indies_data), kernels::getTensorData(output_data)); +#endif // DIS_FLOAT + default: + assert(false && "Unsupported type"); + } +} + +} // namespace luci_interpreter diff --git a/onert-micro/luci-interpreter/src/kernels/GatherND.test.cpp b/onert-micro/luci-interpreter/src/kernels/GatherND.test.cpp new file mode 100644 index 00000000000..27080c0b251 --- /dev/null +++ b/onert-micro/luci-interpreter/src/kernels/GatherND.test.cpp @@ -0,0 +1,88 @@ +/* + * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "kernels/TestUtils.h" +#include "luci_interpreter/test_models/gather_nd/FloatGatherNDKernel.h" +#include "luci_interpreter/test_models/gather_nd/NegGatherNDKernel.h" + +#include "loader/ModuleLoader.h" + +namespace luci_interpreter +{ +namespace +{ + +using namespace testing; + +class GatherNDTest : public ::testing::Test +{ + // Do nothing +}; + +template +std::vector checkGatherNDKernel(test_kernel::TestDataBase *test_data_base) +{ + MemoryManager memory_manager{}; + RuntimeModule runtime_module{}; + bool dealloc_input = true; + + // Load model with single op + auto *model_data_raw = reinterpret_cast(test_data_base->get_model_ptr()); + ModuleLoader::load(&runtime_module, &memory_manager, model_data_raw, dealloc_input); + + auto *main_runtime_graph = runtime_module.getMainGraph(); + assert(main_runtime_graph->getNumOfInputTensors() == 1); + + // Set input data + { + auto *input_tensor_data = reinterpret_cast(main_runtime_graph->configureGraphInput(0)); + std::copy(test_data_base->get_input_data_by_index(0).begin(), + test_data_base->get_input_data_by_index(0).end(), input_tensor_data); + } + + runtime_module.execute(); + + assert(main_runtime_graph->getNumOfOutputTensors() == 1); + + T *output_data = reinterpret_cast(main_runtime_graph->getOutputDataByIndex(0)); + const size_t num_elements = (main_runtime_graph->getOutputDataSizeByIndex(0) / sizeof(T)); + std::vector output_data_vector(output_data, output_data + num_elements); + return output_data_vector; +} + +TEST_F(GatherNDTest, Gather_Float_P) +{ + test_kernel::TestDataFloatGatherND test_data_float_gather; + std::vector output_data_vector = checkGatherNDKernel(&test_data_float_gather); + EXPECT_THAT(output_data_vector, kernels::testing::FloatArrayNear( + test_data_float_gather.get_output_data_by_index(0), 0.0001f)); +} + +TEST_F(GatherNDTest, Input_output_type_mismatch_NEG) +{ + test_kernel::NegTestDataInputOutputTypeMismatchGatherNDKernel test_data_kernel; + + MemoryManager memory_manager{}; + RuntimeModule runtime_module{}; + bool dealloc_input = true; + // Load model with single op + auto *model_data_raw = reinterpret_cast(test_data_kernel.get_model_ptr()); + EXPECT_DEATH(ModuleLoader::load(&runtime_module, &memory_manager, model_data_raw, dealloc_input), + ""); +} + +} // namespace +} // namespace luci_interpreter