diff --git a/onert-micro/onert-micro/include/execute/kernels/ArgCommon.h b/onert-micro/onert-micro/include/execute/kernels/ArgCommon.h
new file mode 100644
index 00000000000..7dd3b26be2a
--- /dev/null
+++ b/onert-micro/onert-micro/include/execute/kernels/ArgCommon.h
@@ -0,0 +1,44 @@
+/*
+ * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef ONERT_MICRO_EXECUTE_KERNELS_ARG_COMMON_H
+#define ONERT_MICRO_EXECUTE_KERNELS_ARG_COMMON_H
+
+#include "OMStatus.h"
+
+#include "core/OMUtils.h"
+#include "core/OMKernelData.h"
+
+#include "execute/OMKernelExecutionBuilder.h"
+#include "execute/OMUtils.h"
+#include "execute/OMRuntimeKernel.h"
+#include <functional>
+
+namespace onert_micro
+{
+namespace execute
+{
+
+OMStatus execute_arg_common(
+  const OMExecuteArgs &execute_args,
+  const std::function<OMStatus(const core::OMRuntimeShape &input1_shape, const float *input1_data,
+                               const int *input2_data, const core::OMRuntimeShape &output_shape,
+                               int *output_data)> &f_float);
+
+} // namespace execute
+} // namespace onert_micro
+
+#endif // ONERT_MICRO_EXECUTE_KERNELS_ARG_COMMON_H
diff --git a/onert-micro/onert-micro/include/import/helpers/OMArgCommon.h b/onert-micro/onert-micro/include/import/helpers/OMArgCommon.h
new file mode 100644
index 00000000000..eff7eb7ade3
--- /dev/null
+++ b/onert-micro/onert-micro/include/import/helpers/OMArgCommon.h
@@ -0,0 +1,38 @@
+/*
+ * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef ONERT_MICRO_IMPORT_HELPERS_CONFIGURE_ARG_KERNEL_COMMON_H
+#define ONERT_MICRO_IMPORT_HELPERS_CONFIGURE_ARG_KERNEL_COMMON_H
+
+#include "import/OMKernelConfigureBuilder.h"
+#include "core/OMUtils.h"
+#include "OMStatus.h"
+#include "execute/OMRuntimeKernel.h"
+
+namespace onert_micro
+{
+namespace import
+{
+namespace helpers
+{
+
+OMStatus configure_arg_kernel_common(const OMConfigureArgs &config_args);
+
+} // namespace helpers
+} // namespace import
+} // namespace onert_micro
+
+#endif // ONERT_MICRO_IMPORT_HELPERS_CONFIGURE_ARG_KERNEL_COMMON_H
diff --git a/onert-micro/onert-micro/src/execute/CMakeLists.txt b/onert-micro/onert-micro/src/execute/CMakeLists.txt
index 912af3bd885..cdb8cb2e9db 100644
--- a/onert-micro/onert-micro/src/execute/CMakeLists.txt
+++ b/onert-micro/onert-micro/src/execute/CMakeLists.txt
@@ -16,6 +16,7 @@ set(SOURCES
     OMUtils.cpp
     kernels/ConvolutionCommon.cpp
     kernels/PoolingCommon.cpp
+    kernels/ArgCommon.cpp
     kernels/ReshapeCommon.cpp
     )
 
diff --git a/onert-micro/onert-micro/src/execute/kernels/ArgCommon.cpp b/onert-micro/onert-micro/src/execute/kernels/ArgCommon.cpp
new file mode 100644
index 00000000000..eb82073a0b2
--- /dev/null
+++ b/onert-micro/onert-micro/src/execute/kernels/ArgCommon.cpp
@@ -0,0 +1,94 @@
+/*
+ * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "execute/kernels/ArgCommon.h"
+#include "PALArgMax.h"
+
+using namespace onert_micro;
+using namespace onert_micro::execute;
+
+namespace
+{
+
+constexpr uint32_t input1TensorIdx = 0;
+constexpr uint32_t input2TensorIdx = 1;
+constexpr uint32_t outputTensorIdx = 0;
+
+} // namespace
+
+OMStatus onert_micro::execute::execute_arg_common(
+  const OMExecuteArgs &execute_args,
+  const std::function<OMStatus(const core::OMRuntimeShape &input1_shape, const float *input1_data,
+                               const int *input2_data, const core::OMRuntimeShape &output_shape,
+                               int *output_data)> &f_float)
+{
+  core::OMRuntimeContext &runtime_context = execute_args.runtime_context;
+  core::OMRuntimeStorage &runtime_storage = execute_args.runtime_storage;
+  uint16_t op_index = execute_args.kernel_index;
+  const circle::Tensor *output;
+  const circle::Tensor *input1;
+  const circle::Tensor *input2;
+
+  uint8_t *output_data;
+  uint8_t *input_data;
+  uint8_t *axis_data;
+
+  // Read kernel
+  execute::OMRuntimeKernel runtime_kernel;
+  runtime_kernel.readKernel(op_index, runtime_context);
+
+  output = runtime_kernel.outputs[outputTensorIdx];
+  assert(output != nullptr);
+
+  input1 = runtime_kernel.inputs[input1TensorIdx];
+  assert(input1 != nullptr);
+
+  input2 = runtime_kernel.inputs[input2TensorIdx];
+  assert(input2 != nullptr);
+
+  runtime_kernel.getDataFromStorage(op_index, runtime_storage, runtime_context);
+
+  output_data = runtime_kernel.outputs_data[outputTensorIdx];
+  assert(output_data != nullptr);
+
+  input_data = runtime_kernel.inputs_data[input1TensorIdx];
+  assert(input_data != nullptr);
+
+  axis_data = runtime_kernel.inputs_data[input2TensorIdx];
+  assert(axis_data != nullptr);
+
+  OMStatus status;
+  const core::OMRuntimeShape input1_shape(input1);
+  const core::OMRuntimeShape output_shape(output);
+  switch (input1->type())
+  {
+#ifndef DIS_FLOAT
+    case circle::TensorType_FLOAT32:
+    {
+      status = f_float(input1_shape, reinterpret_cast<const float *>(input_data),
+                       reinterpret_cast<const int *>(axis_data), output_shape,
+                       reinterpret_cast<int *>(output_data));
+    }
+    break;
+#endif // DIS_FLOAT
+    default:
+    {
+      status = UnsupportedType;
+      assert(false && "Unsupported type.");
+    }
+  }
+  return status;
+}
diff --git a/onert-micro/onert-micro/src/execute/kernels/ArgMax.cpp b/onert-micro/onert-micro/src/execute/kernels/ArgMax.cpp
index 12967d44435..dcf08bed5ef 100644
--- a/onert-micro/onert-micro/src/execute/kernels/ArgMax.cpp
+++ b/onert-micro/onert-micro/src/execute/kernels/ArgMax.cpp
@@ -14,80 +14,20 @@
  * limitations under the License.
*/ -#include "execute/OMKernelExecutionBuilder.h" -#include "OMStatus.h" -#include "execute/OMRuntimeKernel.h" - -#include "core/OMRuntimeShape.h" +#include "execute/kernels/ArgCommon.h" #include "PALArgMax.h" using namespace onert_micro; using namespace onert_micro::execute; -namespace -{ -constexpr uint32_t input1TensorIdx = 0; -constexpr uint32_t input2TensorIdx = 1; -constexpr uint32_t outputTensorIdx = 0; -} // namespace - OMStatus onert_micro::execute::execute_kernel_CircleArgMax(const OMExecuteArgs &execute_args) { - core::OMRuntimeContext &runtime_context = execute_args.runtime_context; - core::OMRuntimeStorage &runtime_storage = execute_args.runtime_storage; - uint16_t op_index = execute_args.kernel_index; - const circle::Tensor *output; - const circle::Tensor *input1; - const circle::Tensor *input2; - - uint8_t *output_data; - uint8_t *input_data; - uint8_t *axis_data; - - // Read kernel - execute::OMRuntimeKernel runtime_kernel; - runtime_kernel.readKernel(op_index, runtime_context); - - output = runtime_kernel.outputs[outputTensorIdx]; - assert(output != nullptr); - - input1 = runtime_kernel.inputs[input1TensorIdx]; - assert(input1 != nullptr); - - input2 = runtime_kernel.inputs[input2TensorIdx]; - assert(input2 != nullptr); - - runtime_kernel.getDataFromStorage(op_index, runtime_storage, runtime_context); - - output_data = runtime_kernel.outputs_data[outputTensorIdx]; - assert(output_data != nullptr); - - input_data = runtime_kernel.inputs_data[input1TensorIdx]; - assert(input_data != nullptr); - - axis_data = runtime_kernel.inputs_data[input2TensorIdx]; - assert(axis_data != nullptr); - - OMStatus status; - const core::OMRuntimeShape input1_shape(input1); - const core::OMRuntimeShape output_shape(output); - switch (input1->type()) - { -#ifndef DIS_FLOAT - case circle::TensorType_FLOAT32: - { - status = - onert_micro::execute::pal::ArgMax(input1_shape, reinterpret_cast(input_data), - reinterpret_cast(axis_data), output_shape, - reinterpret_cast(output_data)); - } - break; -#endif // DIS_FLOAT - default: - { - status = UnsupportedType; - assert(false && "Unsupported type."); - } - } - return status; + auto arg_max_float_lambda = [](const core::OMRuntimeShape &input1_shape, const float *input1_data, + const int *input2_data, const core::OMRuntimeShape &output_shape, + int *output_data) { + return onert_micro::execute::pal::ArgMax(input1_shape, input1_data, input2_data, output_shape, + output_data); + }; + + return execute_arg_common(execute_args, arg_max_float_lambda); } diff --git a/onert-micro/onert-micro/src/execute/kernels/ArgMin.cpp b/onert-micro/onert-micro/src/execute/kernels/ArgMin.cpp index b87872d3a30..a4afc72ee99 100644 --- a/onert-micro/onert-micro/src/execute/kernels/ArgMin.cpp +++ b/onert-micro/onert-micro/src/execute/kernels/ArgMin.cpp @@ -14,80 +14,20 @@ * limitations under the License. 
*/ -#include "execute/OMKernelExecutionBuilder.h" -#include "OMStatus.h" -#include "execute/OMRuntimeKernel.h" - -#include "core/OMRuntimeShape.h" +#include "execute/kernels/ArgCommon.h" #include "PALArgMin.h" using namespace onert_micro; using namespace onert_micro::execute; -namespace -{ -constexpr uint32_t input1TensorIdx = 0; -constexpr uint32_t input2TensorIdx = 1; -constexpr uint32_t outputTensorIdx = 0; -} // namespace - OMStatus onert_micro::execute::execute_kernel_CircleArgMin(const OMExecuteArgs &execute_args) { - core::OMRuntimeContext &runtime_context = execute_args.runtime_context; - core::OMRuntimeStorage &runtime_storage = execute_args.runtime_storage; - uint16_t op_index = execute_args.kernel_index; - const circle::Tensor *output; - const circle::Tensor *input1; - const circle::Tensor *input2; - - uint8_t *output_data; - uint8_t *input_data; - uint8_t *axis_data; - - // Read kernel - execute::OMRuntimeKernel runtime_kernel; - runtime_kernel.readKernel(op_index, runtime_context); - - output = runtime_kernel.outputs[outputTensorIdx]; - assert(output != nullptr); - - input1 = runtime_kernel.inputs[input1TensorIdx]; - assert(input1 != nullptr); - - input2 = runtime_kernel.inputs[input2TensorIdx]; - assert(input2 != nullptr); - - runtime_kernel.getDataFromStorage(op_index, runtime_storage, runtime_context); - - output_data = runtime_kernel.outputs_data[outputTensorIdx]; - assert(output_data != nullptr); - - input_data = runtime_kernel.inputs_data[input1TensorIdx]; - assert(input_data != nullptr); - - axis_data = runtime_kernel.inputs_data[input2TensorIdx]; - assert(axis_data != nullptr); - - OMStatus status; - const core::OMRuntimeShape input1_shape(input1); - const core::OMRuntimeShape output_shape(output); - switch (input1->type()) - { -#ifndef DIS_FLOAT - case circle::TensorType_FLOAT32: - { - status = - onert_micro::execute::pal::ArgMin(input1_shape, reinterpret_cast(input_data), - reinterpret_cast(axis_data), output_shape, - reinterpret_cast(output_data)); - } - break; -#endif // DIS_FLOAT - default: - { - status = UnsupportedType; - assert(false && "Unsupported type."); - } - } - return status; + auto arg_max_float_lambda = [](const core::OMRuntimeShape &input1_shape, const float *input1_data, + const int *input2_data, const core::OMRuntimeShape &output_shape, + int *output_data) { + return onert_micro::execute::pal::ArgMin(input1_shape, input1_data, input2_data, output_shape, + output_data); + }; + + return execute_arg_common(execute_args, arg_max_float_lambda); } diff --git a/onert-micro/onert-micro/src/import/CMakeLists.txt b/onert-micro/onert-micro/src/import/CMakeLists.txt index 6297f8feed1..0a792f3d59e 100644 --- a/onert-micro/onert-micro/src/import/CMakeLists.txt +++ b/onert-micro/onert-micro/src/import/CMakeLists.txt @@ -8,6 +8,7 @@ set(SOURCES helpers/OMPadCommon.cpp helpers/OMConfigureTISOKernel.cpp helpers/OMPoolingCommon.cpp + helpers/OMArgCommon.cpp ) # Add configure kernels diff --git a/onert-micro/onert-micro/src/import/helpers/OMArgCommon.cpp b/onert-micro/onert-micro/src/import/helpers/OMArgCommon.cpp new file mode 100644 index 00000000000..d0c04534b24 --- /dev/null +++ b/onert-micro/onert-micro/src/import/helpers/OMArgCommon.cpp @@ -0,0 +1,72 @@ +/* + * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "import/helpers/OMArgCommon.h"
+
+using namespace onert_micro;
+using namespace onert_micro::core;
+
+namespace
+{
+
+constexpr uint32_t input1TensorIdx = 0;
+constexpr uint32_t input2TensorIdx = 1;
+constexpr uint32_t outputTensorIdx = 0;
+
+} // namespace
+
+OMStatus
+onert_micro::import::helpers::configure_arg_kernel_common(const OMConfigureArgs &config_args)
+{
+  OMRuntimeContext &runtime_context = config_args.runtime_context;
+  uint16_t op_index = config_args.kernel_index;
+
+  onert_micro::execute::OMRuntimeKernel runtime_kernel;
+
+  OMStatus status = runtime_kernel.readKernel(op_index, runtime_context);
+  if (status != Ok)
+    return status;
+
+  const circle::Tensor *output = runtime_kernel.outputs[outputTensorIdx];
+  assert(output != nullptr);
+  const circle::Tensor *input1 = runtime_kernel.inputs[input1TensorIdx];
+  assert(input1 != nullptr);
+  const circle::Tensor *input2 = runtime_kernel.inputs[input2TensorIdx];
+  assert(input2 != nullptr);
+
+  const OMRuntimeShape input2_shape(input2);
+
+  // dim tensor must be a scalar or have one element
+  status =
+    utils::checkCondition(input2_shape.dimensionsCount() == 0 or input2_shape.flatSize() == 1);
+  if (status != Ok)
+    return status;
+
+  // output and dim tensor types must be INT32
+  status = utils::checkCondition(output->type() == circle::TensorType_INT32);
+  if (status != Ok)
+    return status;
+  status = utils::checkCondition(input2->type() == circle::TensorType_INT32);
+  if (status != Ok)
+    return status;
+
+  status = utils::checkCondition(input1->type() != circle::TensorType_INT8 and
+                                 input1->type() != circle::TensorType_INT16);
+  if (status != Ok)
+    return status;
+
+  return status;
+}
diff --git a/onert-micro/onert-micro/src/import/kernels/ArgMax.cpp b/onert-micro/onert-micro/src/import/kernels/ArgMax.cpp
index f15f14e7814..b38beaf31df 100644
--- a/onert-micro/onert-micro/src/import/kernels/ArgMax.cpp
+++ b/onert-micro/onert-micro/src/import/kernels/ArgMax.cpp
@@ -14,61 +14,12 @@
  * limitations under the License.
*/ -#include "import/OMKernelConfigureBuilder.h" -#include "core/OMUtils.h" -#include "OMStatus.h" -#include "execute/OMRuntimeKernel.h" +#include "import/helpers/OMArgCommon.h" using namespace onert_micro; using namespace onert_micro::core; -namespace -{ - -constexpr uint32_t input1TensorIdx = 0; -constexpr uint32_t input2TensorIdx = 1; -constexpr uint32_t outputTensorIdx = 0; - -} // namespace - OMStatus onert_micro::import::configure_kernel_CircleArgMax(const OMConfigureArgs &config_args) { - OMRuntimeContext &runtime_context = config_args.runtime_context; - uint16_t op_index = config_args.kernel_index; - - onert_micro::execute::OMRuntimeKernel runtime_kernel; - - OMStatus status = runtime_kernel.readKernel(op_index, runtime_context); - if (status != Ok) - return status; - - const circle::Tensor *output = runtime_kernel.outputs[outputTensorIdx]; - assert(output != nullptr); - const circle::Tensor *input1 = runtime_kernel.inputs[input1TensorIdx]; - assert(input1 != nullptr); - const circle::Tensor *input2 = runtime_kernel.inputs[input2TensorIdx]; - assert(input2 != nullptr); - - const OMRuntimeShape input2_shape(input2); - - // dim tensor must be a scalar or has one element - status = - utils::checkCondition(input2_shape.dimensionsCount() == 0 or input2_shape.flatSize() == 1); - if (status != Ok) - return status; - - // value and output type must match - status = utils::checkCondition(output->type() == circle::TensorType_INT32); - if (status != Ok) - return status; - status = utils::checkCondition(input2->type() == circle::TensorType_INT32); - if (status != Ok) - return status; - - status = utils::checkCondition(input1->type() != circle::TensorType_INT8 and - input1->type() != circle::TensorType_INT16); - if (status != Ok) - return status; - - return status; + return helpers::configure_arg_kernel_common(config_args); } diff --git a/onert-micro/onert-micro/src/import/kernels/ArgMin.cpp b/onert-micro/onert-micro/src/import/kernels/ArgMin.cpp index 25d6cc71cbc..1f9b9c5e4b1 100644 --- a/onert-micro/onert-micro/src/import/kernels/ArgMin.cpp +++ b/onert-micro/onert-micro/src/import/kernels/ArgMin.cpp @@ -14,61 +14,12 @@ * limitations under the License. 
*/ -#include "import/OMKernelConfigureBuilder.h" -#include "core/OMUtils.h" -#include "OMStatus.h" -#include "execute/OMRuntimeKernel.h" +#include "import/helpers/OMArgCommon.h" using namespace onert_micro; using namespace onert_micro::core; -namespace -{ - -constexpr uint32_t input1TensorIdx = 0; -constexpr uint32_t input2TensorIdx = 1; -constexpr uint32_t outputTensorIdx = 0; - -} // namespace - OMStatus onert_micro::import::configure_kernel_CircleArgMin(const OMConfigureArgs &config_args) { - OMRuntimeContext &runtime_context = config_args.runtime_context; - uint16_t op_index = config_args.kernel_index; - - onert_micro::execute::OMRuntimeKernel runtime_kernel; - - OMStatus status = runtime_kernel.readKernel(op_index, runtime_context); - if (status != Ok) - return status; - - const circle::Tensor *output = runtime_kernel.outputs[outputTensorIdx]; - assert(output != nullptr); - const circle::Tensor *input1 = runtime_kernel.inputs[input1TensorIdx]; - assert(input1 != nullptr); - const circle::Tensor *input2 = runtime_kernel.inputs[input2TensorIdx]; - assert(input2 != nullptr); - - const OMRuntimeShape input2_shape(input2); - - // dim tensor must be a scalar or has one element - status = - utils::checkCondition(input2_shape.dimensionsCount() == 0 or input2_shape.flatSize() == 1); - if (status != Ok) - return status; - - // value and output type must match - status = utils::checkCondition(output->type() == circle::TensorType_INT32); - if (status != Ok) - return status; - status = utils::checkCondition(input2->type() == circle::TensorType_INT32); - if (status != Ok) - return status; - - status = utils::checkCondition(input1->type() != circle::TensorType_INT8 and - input1->type() != circle::TensorType_INT16); - if (status != Ok) - return status; - - return status; + return helpers::configure_arg_kernel_common(config_args); }