From acc086cbc48eec0899dc4bb0b3932a44d2575982 Mon Sep 17 00:00:00 2001 From: Artem Balyshev Date: Mon, 12 Aug 2024 20:00:03 +0300 Subject: [PATCH 1/2] [onert-micro] Support GRU This pr adds supporting for circle GRU op. ONE-DCO-1.0-Signed-off-by: Artem Balyshev --- onert-micro/CMakeLists.txt | 2 +- .../include/core/reader/OMCircleReader.h | 1 - .../include/execute/OMRuntimeKernel.h | 2 +- .../include/pal/common/PALGRUCommon.h | 132 +++++++++++++++ .../include/pal/mcu/KernelsToBuild.lst | 1 + .../onert-micro/include/pal/mcu/PALGRU.h | 23 +++ .../onert-micro/src/execute/kernels/GRU.cpp | 157 ++++++++++++++++++ .../src/execute/kernels/tests/GRU.test.cpp | 17 ++ .../onert-micro/src/import/kernels/GRU.cpp | 100 +++++++++++ 9 files changed, 432 insertions(+), 3 deletions(-) create mode 100644 onert-micro/onert-micro/include/pal/common/PALGRUCommon.h create mode 100644 onert-micro/onert-micro/include/pal/mcu/PALGRU.h create mode 100644 onert-micro/onert-micro/src/execute/kernels/GRU.cpp create mode 100644 onert-micro/onert-micro/src/execute/kernels/tests/GRU.test.cpp create mode 100644 onert-micro/onert-micro/src/import/kernels/GRU.cpp diff --git a/onert-micro/CMakeLists.txt b/onert-micro/CMakeLists.txt index d9388173cf4..a43f7e979e7 100644 --- a/onert-micro/CMakeLists.txt +++ b/onert-micro/CMakeLists.txt @@ -70,7 +70,7 @@ else () message(STATUS "FOUND FlatBuffers") - set(SCHEMA_FILE "${NNAS_PROJECT_SOURCE_DIR}/res/CircleSchema/0.6/circle_schema.fbs") + set(SCHEMA_FILE "${NNAS_PROJECT_SOURCE_DIR}/res/CircleSchema/0.8/circle_schema.fbs") # NOTE Copy circle_schema.fbs as schema.fbs to generate "schema_generated.fbs" instead of "circle_schema_generated.fbs" add_custom_command(OUTPUT "${CMAKE_CURRENT_BINARY_DIR}/schema.fbs" diff --git a/onert-micro/onert-micro/include/core/reader/OMCircleReader.h b/onert-micro/onert-micro/include/core/reader/OMCircleReader.h index 90c1d8acc47..5d32d516c05 100644 --- a/onert-micro/onert-micro/include/core/reader/OMCircleReader.h +++ b/onert-micro/onert-micro/include/core/reader/OMCircleReader.h @@ -55,7 +55,6 @@ class OMCircleReader const CircleOperators *operators() const { return _current_subgraph->operators(); } const CircleValues *inputs() const { return _current_subgraph->inputs(); } const CircleValues *outputs() const { return _current_subgraph->outputs(); } - const circle::DataFormat data_format() const { return _current_subgraph->data_format(); } const CircleMetadataSet *metadata() const { return _model->metadata(); } uint32_t num_subgraph() const { return _model->subgraphs()->size(); } diff --git a/onert-micro/onert-micro/include/execute/OMRuntimeKernel.h b/onert-micro/onert-micro/include/execute/OMRuntimeKernel.h index b6f63cdaa48..e33239f7256 100644 --- a/onert-micro/onert-micro/include/execute/OMRuntimeKernel.h +++ b/onert-micro/onert-micro/include/execute/OMRuntimeKernel.h @@ -23,7 +23,7 @@ #include -constexpr static uint32_t maxInputSize = 5; +constexpr static uint32_t maxInputSize = 6; constexpr static uint32_t maxOutputSize = 5; namespace onert_micro diff --git a/onert-micro/onert-micro/include/pal/common/PALGRUCommon.h b/onert-micro/onert-micro/include/pal/common/PALGRUCommon.h new file mode 100644 index 00000000000..072a151fba2 --- /dev/null +++ b/onert-micro/onert-micro/include/pal/common/PALGRUCommon.h @@ -0,0 +1,132 @@ +/* + * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef ONERT_MICRO_EXECUTE_PAL_GRU_COMMON_H +#define ONERT_MICRO_EXECUTE_PAL_GRU_COMMON_H + +#include "OMStatus.h" +#include "core/OMRuntimeShape.h" + +#include "PALUtils.h" +#include "ProcessBroadcastShapes.h" +#include "PALFullyConnected.h" +#include "PALLogistic.h" + +namespace onert_micro +{ +namespace execute +{ +namespace pal +{ +namespace +{ +void calculateGRU(const float *input_data, const float *weight_input_data, + const float *weight_hidden_data, const float *bias_input_data, + const float *bias_hidden_data, float *output_data, + const core::OMRuntimeShape &input_shape, const core::OMRuntimeShape &output_shape, + const core::OMRuntimeShape &weight_input_shape, + const core::OMRuntimeShape &weight_hidden_shape, float *output_input_data, + float *output_hidden_data, const core::OMRuntimeShape &output_shape_fc) +{ + core::FullyConnectedParams op_params{}; + // As FC nodes doesn't have any activations inside GRU, let' use just numeric limits + op_params.float_activation_min = std::numeric_limits::lowest(); + op_params.float_activation_max = std::numeric_limits::max(); + + // FC Input + FullyConnected(op_params, output_data, weight_input_shape, weight_input_data, bias_input_data, + output_shape_fc, output_input_data); + + // FC Hidden + FullyConnected(op_params, input_data, weight_hidden_shape, weight_hidden_data, bias_hidden_data, + output_shape_fc, output_hidden_data); + + int num_elements = output_shape_fc.dims(1) / 3; + + float *second_hidden_part = output_hidden_data + num_elements; + float *second_input_part = output_input_data + num_elements; + + float *third_hidden_part = second_hidden_part + num_elements; + float *third_input_part = second_input_part + num_elements; + + // Calculate Left part + for (int i = 0; i < num_elements; ++i) + { + output_input_data[i] += output_hidden_data[i]; + } + + Logistic(num_elements, output_input_data, output_input_data); + + // Calculate most left add + float *most_left_part_final = output_input_data; + float *first_part = output_input_data; + for (int i = 0; i < num_elements; ++i) + { + output_data[i] *= most_left_part_final[i]; + first_part[i] = 1.0f - first_part[i]; + } + + // Calc third part + for (int i = 0; i < num_elements; ++i) + { + second_hidden_part[i] += second_input_part[i]; + } + Logistic(num_elements, second_hidden_part, second_hidden_part); + + for (int i = 0; i < num_elements; ++i) + { + second_hidden_part[i] *= third_input_part[i]; + second_hidden_part[i] += third_hidden_part[i]; + second_hidden_part[i] = std::tanh(second_hidden_part[i]); + second_hidden_part[i] *= first_part[i]; + output_data[i] += second_hidden_part[i]; + } +} + +} // namespace + +OMStatus GRU(const float *input_data, const float *weight_input_data, + const float *weight_hidden_data, const float *bias_input_data, + const float *bias_hidden_data, const float *hidden_state_data, float *output_data, + float *output_input_data, float *output_hidden_data, + const core::OMRuntimeShape &input_shape, const core::OMRuntimeShape &output_shape, + const core::OMRuntimeShape &weight_input_shape, + const core::OMRuntimeShape &weight_hidden_shape) +{ + const int32_t time = input_shape.dims(0); + + core::OMRuntimeShape output_shape_fc(2); + output_shape_fc.setDim(0, 1); + output_shape_fc.setDim(1, weight_hidden_shape.dims(0)); + + std::memcpy(output_data, hidden_state_data, + output_shape.dims(output_shape.dimensionsCount() - 1) * sizeof(float)); + + for (int i = 0; i < time; ++i) + { + calculateGRU(input_data, weight_input_data, weight_hidden_data, bias_input_data, + bias_hidden_data, output_data, input_shape, output_shape, weight_input_shape, + weight_hidden_shape, output_input_data, output_hidden_data, output_shape_fc); + input_data += input_shape.dims(2); + } + return Ok; +} + +} // namespace pal +} // namespace execute +} // namespace onert_micro + +#endif // ONERT_MICRO_EXECUTE_PAL_GRU_COMMON_H diff --git a/onert-micro/onert-micro/include/pal/mcu/KernelsToBuild.lst b/onert-micro/onert-micro/include/pal/mcu/KernelsToBuild.lst index 83208f90cc3..94d1bff79a2 100644 --- a/onert-micro/onert-micro/include/pal/mcu/KernelsToBuild.lst +++ b/onert-micro/onert-micro/include/pal/mcu/KernelsToBuild.lst @@ -23,6 +23,7 @@ REGISTER_KERNEL(GATHER_ND, GatherND) REGISTER_KERNEL(EXP, Exp) REGISTER_KERNEL(GREATER, Greater) REGISTER_KERNEL(GREATER_EQUAL, GreaterEqual) +REGISTER_KERNEL(GRU, GRU) REGISTER_KERNEL(EXPAND_DIMS, ExpandDims) REGISTER_KERNEL(ELU, Elu) REGISTER_KERNEL(EQUAL, Equal) diff --git a/onert-micro/onert-micro/include/pal/mcu/PALGRU.h b/onert-micro/onert-micro/include/pal/mcu/PALGRU.h new file mode 100644 index 00000000000..75389fe5f7e --- /dev/null +++ b/onert-micro/onert-micro/include/pal/mcu/PALGRU.h @@ -0,0 +1,23 @@ +/* + * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved + * Copyright 2017 The TensorFlow Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef ONERT_MICRO_EXECUTE_PAL_GRU_H +#define ONERT_MICRO_EXECUTE_PAL_GRU_H + +#include "PALGRUCommon.h" + +#endif // ONERT_MICRO_EXECUTE_PAL_GRU_H diff --git a/onert-micro/onert-micro/src/execute/kernels/GRU.cpp b/onert-micro/onert-micro/src/execute/kernels/GRU.cpp new file mode 100644 index 00000000000..e49e6c92ebe --- /dev/null +++ b/onert-micro/onert-micro/src/execute/kernels/GRU.cpp @@ -0,0 +1,157 @@ +/* + * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include "OMStatus.h" + +#include "core/OMUtils.h" +#include "core/OMKernelData.h" +#include "core/memory/OMMemoryManager.h" + +#include "execute/OMKernelExecutionBuilder.h" +#include "execute/OMUtils.h" +#include "execute/OMRuntimeKernel.h" + +#include "PALGRU.h" + +using namespace onert_micro; +using namespace onert_micro::core; +using namespace onert_micro::execute; + +namespace +{ + +constexpr uint32_t inputTensorIdx = 0; +constexpr uint32_t hiddenHiddenTensorIdx = 1; +constexpr uint32_t hiddenHiddenBiasTensorIdx = 2; +constexpr uint32_t hiddenInputTensorIdx = 3; +constexpr uint32_t hiddenInputBiasTensorIdx = 4; +constexpr uint32_t stateTensorIdx = 5; + +constexpr uint32_t outputTensorIdx = 0; + +} // namespace + +// NOTE: doesnt currently support dynamic shapes +OMStatus onert_micro::execute::execute_kernel_CircleGRU(const OMExecuteArgs &execute_args) +{ + core::OMRuntimeContext &runtime_context = execute_args.runtime_context; + core::OMRuntimeStorage &runtime_storage = execute_args.runtime_storage; + uint16_t op_index = execute_args.kernel_index; + + const circle::Tensor *input; + const circle::Tensor *hidden_hidden; + const circle::Tensor *hidden_hidden_bias; + const circle::Tensor *hidden_input; + const circle::Tensor *hidden_input_bias; + const circle::Tensor *state; + + const circle::Tensor *output; + + uint8_t *input_data; + uint8_t *hidden_hidden_data; + uint8_t *hidden_hidden_bias_data; + uint8_t *hidden_input_data; + uint8_t *hidden_input_bias_data; + uint8_t *state_data; + uint8_t *output_data; + + // Read kernel + { + execute::OMRuntimeKernel runtime_kernel; + runtime_kernel.readKernel(op_index, runtime_context); + + input = runtime_kernel.inputs[inputTensorIdx]; + hidden_hidden = runtime_kernel.inputs[hiddenHiddenTensorIdx]; + hidden_hidden_bias = runtime_kernel.inputs[hiddenHiddenBiasTensorIdx]; + hidden_input = runtime_kernel.inputs[hiddenInputTensorIdx]; + hidden_input_bias = runtime_kernel.inputs[hiddenInputBiasTensorIdx]; + state = runtime_kernel.inputs[stateTensorIdx]; + + output = runtime_kernel.outputs[outputTensorIdx]; + assert(input != nullptr); + assert(hidden_hidden != nullptr); + assert(hidden_input != nullptr); + assert(state != nullptr); + // Biases can be nullptr + assert(output != nullptr); + + runtime_kernel.getDataFromStorage(op_index, runtime_storage, runtime_context); + + input_data = runtime_kernel.inputs_data[inputTensorIdx]; + hidden_hidden_data = runtime_kernel.inputs_data[hiddenHiddenTensorIdx]; + hidden_hidden_bias_data = runtime_kernel.inputs_data[hiddenHiddenBiasTensorIdx]; + hidden_input_data = runtime_kernel.inputs_data[hiddenInputTensorIdx]; + hidden_input_bias_data = runtime_kernel.inputs_data[hiddenInputBiasTensorIdx]; + state_data = runtime_kernel.inputs_data[stateTensorIdx]; + + output_data = runtime_kernel.outputs_data[outputTensorIdx]; + assert(input_data != nullptr); + assert(hidden_hidden_data != nullptr); + assert(hidden_input_data != nullptr); + assert(state_data != nullptr); + // Bias can be nullptr + assert(output_data != nullptr); + } + + OMStatus status; + + uint8_t *output_hidden_data; + uint8_t *output_input_data; + + status = + core::memory::OMMemoryManager::allocateMemory(core::OMRuntimeShape(hidden_hidden).flatSize() * + sizeof(core::OMDataType(hidden_hidden->type())), + &output_hidden_data); + if (status != Ok) + return status; + core::memory::OMMemoryManager::allocateMemory(core::OMRuntimeShape(hidden_input).flatSize() * + sizeof(core::OMDataType(hidden_input->type())), + &output_input_data); + if (status != Ok) + return status; + + switch (input->type()) + { +#ifndef DIS_FLOAT + case circle::TensorType_FLOAT32: + { + status = pal::GRU(core::utils::castInputData(input_data), + core::utils::castInputData(hidden_input_data), + core::utils::castInputData(hidden_hidden_data), + core::utils::castInputData(hidden_input_bias_data), + core::utils::castInputData(hidden_hidden_bias_data), + core::utils::castInputData(state_data), + core::utils::castOutputData(output_data), + core::utils::castOutputData(output_input_data), + core::utils::castOutputData(output_hidden_data), + core::OMRuntimeShape(input), core::OMRuntimeShape(output), + core::OMRuntimeShape(hidden_input), core::OMRuntimeShape(hidden_hidden)); + } + break; +#endif // DIS_FLOAT + default: + { + status = UnsupportedType; + assert(false && "Unsupported type."); + } + } + + core::memory::OMMemoryManager::deallocateMemory(output_input_data); + core::memory::OMMemoryManager::deallocateMemory(output_hidden_data); + + return status; +} diff --git a/onert-micro/onert-micro/src/execute/kernels/tests/GRU.test.cpp b/onert-micro/onert-micro/src/execute/kernels/tests/GRU.test.cpp new file mode 100644 index 00000000000..0c8a203a59d --- /dev/null +++ b/onert-micro/onert-micro/src/execute/kernels/tests/GRU.test.cpp @@ -0,0 +1,17 @@ +/* + * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// TODO add tests diff --git a/onert-micro/onert-micro/src/import/kernels/GRU.cpp b/onert-micro/onert-micro/src/import/kernels/GRU.cpp new file mode 100644 index 00000000000..1c76b2f9c5d --- /dev/null +++ b/onert-micro/onert-micro/src/import/kernels/GRU.cpp @@ -0,0 +1,100 @@ +/* + * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "OMStatus.h" + +#include "import/OMKernelConfigureBuilder.h" + +#include "core/OMUtils.h" +#include "core/OMKernelData.h" + +#include "execute/OMRuntimeKernel.h" + +using namespace onert_micro; +using namespace onert_micro::core; + +namespace +{ + +constexpr uint32_t inputTensorIdx = 0; +constexpr uint32_t hiddenHiddenTensorIdx = 1; +constexpr uint32_t hiddenHiddenBiasTensorIdx = 2; +constexpr uint32_t hiddenInputTensorIdx = 3; +constexpr uint32_t hiddenInputBiasTensorIdx = 4; +constexpr uint32_t stateTensorIdx = 5; + +constexpr uint32_t outputTensorIdx = 0; + +} // namespace + +OMStatus onert_micro::import::configure_kernel_CircleGRU(const OMConfigureArgs &config_args) +{ + core::OMRuntimeContext &runtime_context = config_args.runtime_context; + uint16_t op_index = config_args.kernel_index; + + const circle::Tensor *input; + const circle::Tensor *hidden_hidden; + const circle::Tensor *hidden_hidden_bias; + const circle::Tensor *hidden_input; + const circle::Tensor *hidden_input_bias; + const circle::Tensor *state; + + const circle::Tensor *output; + + // Read kernel + execute::OMRuntimeKernel runtime_kernel; + runtime_kernel.readKernel(op_index, runtime_context); + + input = runtime_kernel.inputs[inputTensorIdx]; + hidden_hidden = runtime_kernel.inputs[hiddenHiddenTensorIdx]; + hidden_hidden_bias = runtime_kernel.inputs[hiddenHiddenBiasTensorIdx]; + hidden_input = runtime_kernel.inputs[hiddenInputTensorIdx]; + hidden_input_bias = runtime_kernel.inputs[hiddenInputBiasTensorIdx]; + state = runtime_kernel.inputs[stateTensorIdx]; + + output = runtime_kernel.outputs[outputTensorIdx]; + assert(input != nullptr); + assert(hidden_hidden != nullptr); + assert(hidden_input != nullptr); + assert(state != nullptr); + // Biases can be nullptr + assert(output != nullptr); + + OMStatus status = Ok; + + OMRuntimeShape hidden_hidden_shape(hidden_hidden); + OMRuntimeShape hidden_input_shape(hidden_input); + OMRuntimeShape output_shape(output); + OMRuntimeShape state_shape(state); + + status = utils::checkCondition(hidden_hidden_shape.dims(0) == hidden_input_shape.dims(0)); + if (status != Ok) + return status; + + const int32_t div_factor = 3; + status = + utils::checkCondition(hidden_hidden_shape.dims(0) == + (div_factor * output_shape.dims(output_shape.dimensionsCount() - 1))); + if (status != Ok) + return status; + + status = utils::checkCondition(output_shape.dims(output_shape.dimensionsCount() - 1) == + state_shape.dims(state_shape.dimensionsCount() - 1)); + if (status != Ok) + return status; + + return status; +} From 6a061a9c4ad002a279693826405c5182b5961e0c Mon Sep 17 00:00:00 2001 From: Artem Balyshev Date: Tue, 13 Aug 2024 13:36:56 +0300 Subject: [PATCH 2/2] add test --- .../include/test_models/gru/FloatGRUKernel.h | 175 ++++++++++++++++++ .../include/test_models/gru/TestDataGRUBase.h | 60 ++++++ .../src/execute/kernels/tests/GRU.test.cpp | 30 ++- .../onert-micro/src/import/kernels/GRU.cpp | 4 + 4 files changed, 268 insertions(+), 1 deletion(-) create mode 100644 onert-micro/onert-micro/include/test_models/gru/FloatGRUKernel.h create mode 100644 onert-micro/onert-micro/include/test_models/gru/TestDataGRUBase.h diff --git a/onert-micro/onert-micro/include/test_models/gru/FloatGRUKernel.h b/onert-micro/onert-micro/include/test_models/gru/FloatGRUKernel.h new file mode 100644 index 00000000000..fa49b29c1d0 --- /dev/null +++ b/onert-micro/onert-micro/include/test_models/gru/FloatGRUKernel.h @@ -0,0 +1,175 @@ +/* + * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef ONERT_MICRO_TEST_MODELS_FLOAT_GRU_KERNEL_H +#define ONERT_MICRO_TEST_MODELS_FLOAT_GRU_KERNEL_H + +#include "TestDataGRUBase.h" + +namespace onert_micro +{ +namespace test_model +{ + +namespace gru_float +{ +/* + * GRU Kernel: + * + * Input(1, 1, 6) + * | + * GRU + * | + * Output(1, 1, 5) + */ +unsigned char test_kernel_model_circle[] = { + 0x1c, 0x00, 0x00, 0x00, 0x43, 0x49, 0x52, 0x30, 0x00, 0x00, 0x12, 0x00, 0x18, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x08, 0x00, 0x0c, 0x00, 0x10, 0x00, 0x00, 0x00, 0x14, 0x00, 0x12, 0x00, 0x00, 0x00, + 0x54, 0x00, 0x00, 0x00, 0x80, 0x00, 0x00, 0x00, 0x3c, 0x00, 0x00, 0x00, 0x08, 0x00, 0x00, 0x00, + 0x30, 0x00, 0x00, 0x00, 0x0a, 0x00, 0x00, 0x00, 0xb4, 0x06, 0x00, 0x00, 0x98, 0x06, 0x00, 0x00, + 0xe0, 0x04, 0x00, 0x00, 0x70, 0x03, 0x00, 0x00, 0x1c, 0x03, 0x00, 0x00, 0xd8, 0x02, 0x00, 0x00, + 0x90, 0x02, 0x00, 0x00, 0x38, 0x02, 0x00, 0x00, 0xe8, 0x01, 0x00, 0x00, 0xa8, 0x01, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x09, 0x00, 0x00, 0x00, 0x6e, 0x6e, 0x70, 0x61, 0x63, 0x6b, 0x61, 0x67, + 0x65, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x08, 0x00, 0x00, 0x00, 0x1c, 0x00, 0x00, 0x00, + 0xf4, 0xff, 0xff, 0xff, 0x00, 0x00, 0x00, 0xfb, 0xfb, 0xff, 0xff, 0xff, 0x0c, 0x00, 0x0c, 0x00, + 0x07, 0x00, 0x00, 0x00, 0x00, 0x00, 0x08, 0x00, 0x0c, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x2d, + 0x2d, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00, 0x00, 0x00, 0x0e, 0x00, + 0x18, 0x00, 0x04, 0x00, 0x08, 0x00, 0x0c, 0x00, 0x10, 0x00, 0x14, 0x00, 0x0e, 0x00, 0x00, 0x00, + 0x3c, 0x00, 0x00, 0x00, 0x30, 0x00, 0x00, 0x00, 0x24, 0x00, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x6d, 0x61, 0x69, 0x6e, 0x00, 0x00, 0x00, 0x00, + 0x02, 0x00, 0x00, 0x00, 0xb4, 0x00, 0x00, 0x00, 0x4c, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, + 0x08, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x09, 0x00, 0x00, 0x00, + 0xac, 0x05, 0x00, 0x00, 0xf4, 0x03, 0x00, 0x00, 0x88, 0x02, 0x00, 0x00, 0x2c, 0x02, 0x00, 0x00, + 0xf0, 0x01, 0x00, 0x00, 0xa4, 0x01, 0x00, 0x00, 0x48, 0x01, 0x00, 0x00, 0xf8, 0x00, 0x00, 0x00, + 0xb4, 0x00, 0x00, 0x00, 0x00, 0x00, 0x0e, 0x00, 0x1a, 0x00, 0x08, 0x00, 0x0c, 0x00, 0x10, 0x00, + 0x07, 0x00, 0x14, 0x00, 0x0e, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x20, 0x01, 0x00, 0x00, 0x00, + 0x34, 0x00, 0x00, 0x00, 0x28, 0x00, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00, 0x00, 0x00, 0x0e, 0x00, + 0x10, 0x00, 0x04, 0x00, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x0e, 0x00, 0x00, 0x00, + 0x06, 0x00, 0x00, 0x00, 0x06, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, + 0x08, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x05, 0x00, 0x00, 0x00, + 0x06, 0x00, 0x00, 0x00, 0x07, 0x00, 0x00, 0x00, 0x00, 0x00, 0x0e, 0x00, 0x14, 0x00, 0x00, 0x00, + 0x08, 0x00, 0x0c, 0x00, 0x07, 0x00, 0x10, 0x00, 0x0e, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xfb, + 0x18, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0xc4, 0xfa, 0xff, 0xff, + 0x01, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x06, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, 0xff, 0xff, 0xff, 0xff, 0x02, 0x00, 0x00, 0x00, 0xff, 0xff, 0xff, 0xff, + 0x03, 0x00, 0x00, 0x00, 0x34, 0xfb, 0xff, 0xff, 0x30, 0x00, 0x00, 0x00, 0x09, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, 0x19, 0x00, 0x00, 0x00, 0x53, 0x74, 0x61, 0x74, 0x65, 0x66, 0x75, 0x6c, + 0x50, 0x61, 0x72, 0x74, 0x69, 0x74, 0x69, 0x6f, 0x6e, 0x65, 0x64, 0x43, 0x61, 0x6c, 0x6c, 0x3a, + 0x30, 0x00, 0x00, 0x00, 0x1c, 0xfb, 0xff, 0xff, 0x02, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, + 0x05, 0x00, 0x00, 0x00, 0x68, 0xff, 0xff, 0xff, 0x00, 0x00, 0x00, 0x02, 0x3c, 0x00, 0x00, 0x00, + 0x08, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x73, 0x74, 0x72, 0x69, + 0x64, 0x65, 0x64, 0x5f, 0x73, 0x6c, 0x69, 0x63, 0x65, 0x5f, 0x32, 0x32, 0x00, 0x00, 0x00, 0x00, + 0x26, 0xfd, 0xff, 0xff, 0x04, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, + 0xb4, 0xff, 0xff, 0xff, 0x00, 0x00, 0x00, 0x02, 0x3c, 0x00, 0x00, 0x00, 0x07, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x73, 0x74, 0x72, 0x69, 0x64, 0x65, 0x64, 0x5f, + 0x73, 0x6c, 0x69, 0x63, 0x65, 0x5f, 0x32, 0x31, 0x00, 0x00, 0x00, 0x00, 0x72, 0xfd, 0xff, 0xff, + 0x04, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, + 0x05, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x14, 0x00, + 0x08, 0x00, 0x07, 0x00, 0x0c, 0x00, 0x10, 0x00, 0x0c, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, + 0x38, 0x00, 0x00, 0x00, 0x06, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x0f, 0x00, 0x00, 0x00, + 0x73, 0x74, 0x72, 0x69, 0x64, 0x65, 0x64, 0x5f, 0x73, 0x6c, 0x69, 0x63, 0x65, 0x5f, 0x32, 0x00, + 0xc6, 0xfd, 0xff, 0xff, 0x04, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x00, 0x00, 0xff, 0xff, 0xff, 0xff, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, + 0x60, 0xfc, 0xff, 0xff, 0x24, 0x00, 0x00, 0x00, 0x05, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0x0e, 0x00, 0x00, 0x00, 0x46, 0x75, 0x73, 0x65, 0x64, 0x43, 0x69, 0x72, 0x63, 0x6c, 0x65, 0x47, + 0x52, 0x55, 0x00, 0x00, 0x3c, 0xfc, 0xff, 0xff, 0x03, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, 0x05, 0x00, 0x00, 0x00, 0x98, 0xfc, 0xff, 0xff, 0x48, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00, 0x73, 0x65, 0x71, 0x75, + 0x65, 0x6e, 0x74, 0x69, 0x61, 0x6c, 0x2f, 0x67, 0x72, 0x75, 0x2f, 0x7a, 0x65, 0x72, 0x6f, 0x73, + 0x00, 0x00, 0x00, 0x00, 0x4a, 0xfe, 0xff, 0xff, 0x04, 0x00, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x05, 0x00, 0x00, 0x00, + 0xf0, 0xfc, 0xff, 0xff, 0x58, 0x01, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0x0f, 0x00, 0x00, 0x00, 0x77, 0x68, 0x69, 0x6c, 0x65, 0x2f, 0x4d, 0x61, 0x74, 0x4d, 0x75, 0x6c, + 0x5f, 0x31, 0x31, 0x00, 0x9a, 0xfe, 0xff, 0xff, 0x04, 0x00, 0x00, 0x00, 0x2c, 0x01, 0x00, 0x00, + 0xc0, 0xfb, 0x12, 0xbd, 0x4b, 0xb1, 0x0c, 0x3f, 0x51, 0xbe, 0xa0, 0x3d, 0xdb, 0xcd, 0xca, 0xbe, + 0x77, 0xa7, 0x8d, 0x3e, 0xd8, 0x24, 0xe8, 0x3e, 0xc6, 0xe3, 0xfe, 0x3d, 0xa8, 0x41, 0xf0, 0xbd, + 0x9e, 0x70, 0xf3, 0xbd, 0x50, 0xfc, 0x4b, 0x3e, 0x7f, 0x8b, 0xf0, 0x3d, 0xae, 0xc0, 0x83, 0x3d, + 0xe4, 0xf0, 0x98, 0xbe, 0xd4, 0xd0, 0x7f, 0xbe, 0x80, 0xca, 0x98, 0x39, 0xe6, 0x2c, 0x08, 0xbe, + 0x61, 0x44, 0xdf, 0xbd, 0x67, 0x32, 0x32, 0xbe, 0x6a, 0x61, 0xdf, 0x3e, 0xc3, 0x0c, 0x55, 0x3e, + 0x6c, 0x28, 0x0e, 0xbf, 0xb6, 0x52, 0xf1, 0x3d, 0xb7, 0xd1, 0x3f, 0xbd, 0xa6, 0xf0, 0x9d, 0xbe, + 0xa0, 0xdd, 0xb1, 0x3e, 0xa3, 0x7d, 0x50, 0xbd, 0x3e, 0xd7, 0xe6, 0x3e, 0xe4, 0xb0, 0xe6, 0x3d, + 0x2a, 0xd6, 0xeb, 0x3e, 0xa8, 0xc8, 0x49, 0xbb, 0xdd, 0xdc, 0x6b, 0xbe, 0x66, 0x48, 0xc1, 0x3d, + 0x26, 0x6e, 0x52, 0x3e, 0xfc, 0xd6, 0x64, 0x3d, 0x4f, 0x1d, 0x1f, 0xbf, 0x5f, 0xf0, 0x9e, 0x3e, + 0xe0, 0x6e, 0xad, 0x3c, 0x48, 0x37, 0xe7, 0xbd, 0x36, 0xea, 0x0b, 0xbe, 0x3b, 0x81, 0xf2, 0xbd, + 0x52, 0xe1, 0x56, 0xbc, 0x75, 0x2e, 0xa3, 0xbd, 0x8c, 0x71, 0xc5, 0x3d, 0xf0, 0xaf, 0x0b, 0x3e, + 0x6b, 0x7d, 0xba, 0x3e, 0x4e, 0xbd, 0x93, 0xbe, 0xb3, 0x5c, 0x9c, 0xbe, 0x3c, 0xe2, 0xf3, 0x3c, + 0x39, 0xf1, 0xa0, 0x3d, 0xa0, 0x35, 0x50, 0x3e, 0xfa, 0x87, 0x0e, 0xbe, 0x76, 0xc2, 0x12, 0xbd, + 0x2a, 0xd6, 0x01, 0x3f, 0xa0, 0x77, 0xd0, 0x3c, 0x5a, 0x1f, 0x26, 0x3e, 0x02, 0x59, 0x0b, 0x3e, + 0xef, 0x6c, 0x41, 0xbe, 0x6e, 0x40, 0x4a, 0xbd, 0x2f, 0x89, 0x33, 0x3e, 0x50, 0x54, 0x8a, 0x3e, + 0x4d, 0xbb, 0x9f, 0xbe, 0xfd, 0x54, 0xb3, 0x3e, 0xc8, 0x5b, 0x66, 0xbe, 0xf0, 0xb0, 0x44, 0x3d, + 0x8a, 0x4d, 0x14, 0xbe, 0x9d, 0xf7, 0xd4, 0xbd, 0x38, 0xec, 0xc7, 0xbe, 0xb2, 0x79, 0x76, 0x3e, + 0xb2, 0xc2, 0xdd, 0xbe, 0x44, 0xd9, 0x05, 0xbe, 0x59, 0x34, 0x89, 0x3e, 0x71, 0xf8, 0x2b, 0x3e, + 0x1d, 0x62, 0x24, 0x3f, 0x40, 0xe6, 0x02, 0x3d, 0xef, 0x03, 0xb8, 0x3d, 0x02, 0x00, 0x00, 0x00, + 0x0f, 0x00, 0x00, 0x00, 0x05, 0x00, 0x00, 0x00, 0x58, 0xfe, 0xff, 0xff, 0x98, 0x01, 0x00, 0x00, + 0x02, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x00, 0x00, 0x77, 0x68, 0x69, 0x6c, + 0x65, 0x2f, 0x4d, 0x61, 0x74, 0x4d, 0x75, 0x6c, 0x00, 0x00, 0x06, 0x00, 0x08, 0x00, 0x04, 0x00, + 0x06, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x68, 0x01, 0x00, 0x00, 0x9c, 0xb9, 0x96, 0xbe, + 0x30, 0x84, 0xdc, 0x3e, 0xb8, 0xe2, 0xc0, 0x3d, 0x00, 0xce, 0xf1, 0x3a, 0x28, 0xd6, 0xfb, 0x3d, + 0x49, 0x84, 0x95, 0xbe, 0xcc, 0x0d, 0x52, 0x3e, 0x7c, 0x4e, 0x6e, 0xbe, 0xde, 0xda, 0x4c, 0xbe, + 0x84, 0x5e, 0xda, 0x3e, 0x46, 0x2b, 0xd1, 0x3e, 0x78, 0xc8, 0x71, 0xbe, 0x00, 0xfd, 0x53, 0x3d, + 0x28, 0x4e, 0x91, 0x3e, 0x00, 0x46, 0xd6, 0xba, 0x20, 0x06, 0x97, 0xbe, 0xf4, 0x04, 0xdc, 0xbe, + 0xde, 0xf8, 0x05, 0xbf, 0x62, 0x20, 0x1d, 0xbe, 0x28, 0x28, 0xf9, 0x3d, 0xc6, 0xa0, 0x86, 0xbe, + 0xa2, 0x2f, 0x7f, 0xbe, 0xa0, 0xa1, 0x1d, 0xbd, 0x3c, 0x03, 0xb2, 0x3e, 0xe6, 0xe6, 0x7c, 0xbe, + 0x2e, 0x37, 0xbe, 0xbe, 0x84, 0xb2, 0x86, 0xbd, 0x10, 0x19, 0x56, 0x3e, 0x59, 0x86, 0x01, 0x3f, + 0xfc, 0x54, 0x15, 0x3e, 0xc3, 0xbd, 0x07, 0x3f, 0xa0, 0xcb, 0x5f, 0x3e, 0x6c, 0x19, 0xbb, 0x3e, + 0x9c, 0x98, 0x24, 0xbe, 0x40, 0x57, 0xd1, 0xbc, 0xb0, 0x9c, 0xec, 0xbd, 0x90, 0x19, 0xb4, 0x3d, + 0x59, 0x11, 0xe7, 0xbe, 0x04, 0x11, 0xd7, 0xbd, 0x6a, 0xd8, 0x46, 0xbe, 0xb9, 0xf2, 0x01, 0xbf, + 0x40, 0xe0, 0x2e, 0xbd, 0x9e, 0xe6, 0x9a, 0x3e, 0xa0, 0x27, 0xda, 0xbe, 0x39, 0xe9, 0x04, 0x3f, + 0x5c, 0x2f, 0x2d, 0x3e, 0x18, 0x35, 0x95, 0x3e, 0x5c, 0x67, 0x14, 0x3e, 0xd0, 0xb1, 0x92, 0xbd, + 0xa8, 0x99, 0xe2, 0xbd, 0x00, 0x1e, 0x0e, 0x3e, 0x80, 0x85, 0x7a, 0x3c, 0x88, 0xde, 0xde, 0x3e, + 0x0a, 0x10, 0xc9, 0x3e, 0x28, 0x29, 0x3c, 0xbd, 0xbe, 0x3a, 0xfd, 0x3e, 0x36, 0x76, 0xef, 0xbe, + 0x6e, 0x44, 0xb4, 0x3e, 0xdc, 0xd6, 0x9c, 0xbd, 0xf0, 0xed, 0x9a, 0x3e, 0x90, 0x9c, 0x6b, 0x3d, + 0x0c, 0xc3, 0x32, 0x3e, 0x8a, 0x27, 0x1f, 0xbe, 0x00, 0x64, 0x5f, 0x3a, 0x8e, 0x71, 0xcc, 0x3e, + 0xcf, 0xe7, 0xe1, 0xbe, 0xc6, 0x65, 0xb4, 0x3e, 0xa4, 0x65, 0x6d, 0x3e, 0x31, 0xd8, 0x03, 0x3f, + 0x2c, 0x2a, 0xa8, 0xbd, 0x38, 0x1b, 0xac, 0x3e, 0x60, 0xcc, 0x64, 0x3e, 0x18, 0x4c, 0x0e, 0xbd, + 0x82, 0x5e, 0xa2, 0x3e, 0xde, 0x70, 0xb0, 0xbe, 0x46, 0x07, 0xe6, 0xbe, 0xf6, 0x4a, 0xa8, 0xbe, + 0x90, 0xfa, 0x3f, 0x3e, 0x5c, 0x9a, 0xe9, 0xbe, 0x63, 0x1e, 0xd3, 0xbe, 0x20, 0x74, 0x7e, 0x3d, + 0x20, 0x9c, 0x02, 0xbf, 0xf7, 0x65, 0x02, 0x3f, 0xb6, 0x45, 0xdf, 0x3e, 0x4e, 0xc2, 0x48, 0xbe, + 0xe3, 0x90, 0xa9, 0xbe, 0xc8, 0x36, 0xab, 0x3d, 0xca, 0xc0, 0x22, 0xbe, 0xec, 0x99, 0x26, 0x3e, + 0xd0, 0x91, 0x35, 0x3e, 0x02, 0x00, 0x00, 0x00, 0x0f, 0x00, 0x00, 0x00, 0x06, 0x00, 0x00, 0x00, + 0x0c, 0x00, 0x10, 0x00, 0x04, 0x00, 0x00, 0x00, 0x08, 0x00, 0x0c, 0x00, 0x0c, 0x00, 0x00, 0x00, + 0x28, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x13, 0x00, 0x00, 0x00, + 0x73, 0x65, 0x72, 0x76, 0x69, 0x6e, 0x67, 0x5f, 0x64, 0x65, 0x66, 0x61, 0x75, 0x6c, 0x74, 0x5f, + 0x78, 0x3a, 0x30, 0x00, 0xec, 0xff, 0xff, 0xff, 0x03, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, 0x06, 0x00, 0x00, 0x00, 0x04, 0x00, 0x04, 0x00, 0x04, 0x00, 0x00, 0x00}; + +const std::vector input_data = {7.899295, -4.584313, -2.9251342, + -2.1820352, -10.649105, 1.3530581}; + +const std::vector reference_output_data = {-0.9979859, -0.90550894, 0.025957875, -0.39570245, + -0.8868108}; + +} // namespace gru_float + +class TestDataFloatGRU : public TestDataGRUBase +{ +public: + TestDataFloatGRU() + { + _input_data = gru_float::input_data; + _reference_output_data = gru_float::reference_output_data; + _test_kernel_model_circle = gru_float::test_kernel_model_circle; + } + + ~TestDataFloatGRU() override = default; +}; + +} // namespace test_model +} // namespace onert_micro + +#endif // ONERT_MICRO_TEST_MODELS_FLOAT_GRU_KERNEL_H diff --git a/onert-micro/onert-micro/include/test_models/gru/TestDataGRUBase.h b/onert-micro/onert-micro/include/test_models/gru/TestDataGRUBase.h new file mode 100644 index 00000000000..5c3da425d4a --- /dev/null +++ b/onert-micro/onert-micro/include/test_models/gru/TestDataGRUBase.h @@ -0,0 +1,60 @@ +/* + * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef ONERT_MICRO_TEST_MODELS_GRU_KERNEL_BASE_H +#define ONERT_MICRO_TEST_MODELS_GRU_KERNEL_BASE_H + +#include "test_models/TestDataBase.h" + +namespace onert_micro +{ +namespace test_model +{ + +template class TestDataGRUBase : public TestDataBase +{ +public: + TestDataGRUBase() = default; + + const unsigned char *get_model_ptr() override final { return _test_kernel_model_circle; } + + const std::vector &get_input_data_by_index(int i) override final + { + switch (i) + { + case 0: + return _input_data; + default: + assert(false && "Wrong input index"); + } + } + + const std::vector &get_output_data_by_index(int i) override final + { + assert(i == 0); + return _reference_output_data; + } + +protected: + std::vector _input_data; + std::vector _reference_output_data; + const unsigned char *_test_kernel_model_circle; +}; + +} // namespace test_model +} // namespace onert_micro + +#endif // ONERT_MICRO_TEST_MODELS_GRU_KERNEL_BASE_H diff --git a/onert-micro/onert-micro/src/execute/kernels/tests/GRU.test.cpp b/onert-micro/onert-micro/src/execute/kernels/tests/GRU.test.cpp index 0c8a203a59d..d9d49621947 100644 --- a/onert-micro/onert-micro/src/execute/kernels/tests/GRU.test.cpp +++ b/onert-micro/onert-micro/src/execute/kernels/tests/GRU.test.cpp @@ -14,4 +14,32 @@ * limitations under the License. */ -// TODO add tests +#include "execute/OMTestUtils.h" +#include "test_models/gru/FloatGRUKernel.h" + +namespace onert_micro +{ +namespace execute +{ +namespace testing +{ + +using namespace testing; + +class GRUTest : public ::testing::Test +{ + // Do nothing +}; + +TEST_F(GRUTest, Float_P) +{ + onert_micro::test_model::TestDataFloatGRU test_data_kernel; + std::vector output_data_vector = + onert_micro::execute::testing::checkKernel(1, &test_data_kernel); + EXPECT_THAT(output_data_vector, + FloatArrayNear(test_data_kernel.get_output_data_by_index(0), 0.0001f)); +} + +} // namespace testing +} // namespace execute +} // namespace onert_micro diff --git a/onert-micro/onert-micro/src/import/kernels/GRU.cpp b/onert-micro/onert-micro/src/import/kernels/GRU.cpp index 1c76b2f9c5d..2c1167c98b5 100644 --- a/onert-micro/onert-micro/src/import/kernels/GRU.cpp +++ b/onert-micro/onert-micro/src/import/kernels/GRU.cpp @@ -96,5 +96,9 @@ OMStatus onert_micro::import::configure_kernel_CircleGRU(const OMConfigureArgs & if (status != Ok) return status; + status = utils::checkCondition(input->type() == output->type()); + if (status != Ok) + return status; + return status; }